예제 #1
0
 def test_activation_softplus_infer(self):
     """SoftPlus.infer must propagate the input shape and compute its values.

     The expected values are log(1 + exp(x)) for x in [-1, 0, 1, 20].
     softplus(20) differs from 20.0 by roughly 2e-9, which is within the
     default 7-decimal-place tolerance of assertAlmostEqual.
     """
     graph = build_graph(
         self.nodes_attributes, [('node_1', 'activation_node'),
                                 ('activation_node', 'node_3')],
         {
             'node_1': {
                 'value': np.array([-1.0, 0.0, 1.0, 20.0])
             },
             'activation_node': {
                 'op': 'SoftPlus',
                 'operation': SoftPlus.operation,
             },
             'node_3': {
                 'value': None
             }
         })
     graph.graph['layout'] = 'NCHW'
     activation_node = Node(graph, 'activation_node')
     SoftPlus.infer(activation_node)
     exp_shape = np.array([4])
     res_shape = graph.node['node_3']['shape']
     res_value = graph.node['node_3']['value']
     exp_value = np.array([0.3132617, 0.6931472, 1.3132617, 20.0])
     # The loops below only iterate over the expected arrays, so check the
     # lengths first; otherwise an output with extra dims/elements would pass.
     self.assertEqual(len(res_shape), len(exp_shape))
     for i, value in enumerate(exp_shape):
         self.assertEqual(res_shape[i], value)
     self.assertEqual(len(res_value), len(exp_value))
     for i, value in enumerate(exp_value):
         self.assertAlmostEqual(res_value[i], value)
예제 #2
0
    def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]):
        """Collapse the matched Exp ... Log subgraph into one SoftPlus node.

        The new SoftPlus node is fed by whatever produces the input of the
        matched 'exp' node. It takes over every consumer of the matched 'ln'
        node's output. The Log node is renamed with a '/TBR' suffix, and
        SoftPlus inherits its original name. Presumably the matched pattern
        is log(exp(x) + 1); the pattern itself is defined outside this method.
        """
        log_node = match['ln']
        exp_node = match['exp']

        original_name = log_node.soft_get('name', log_node.id)

        fused = SoftPlus(graph, {}).create_node()

        # Wire SoftPlus to the producer of the Exp input.
        exp_producer = exp_node.in_port(0).get_source()
        fused.in_port(0).connect(exp_producer)

        # Re-point all consumers of the Log output to SoftPlus.
        log_consumers = log_node.out_port(0).get_connection()
        log_consumers.set_source(fused.out_port(0))

        renames = [(log_node, original_name + '/TBR'), (fused, original_name)]
        rename_nodes(renames)
예제 #3
0
 def extract(cls, node):
     """Mark ``node`` as a SoftPlus operation and report extractor status.

     No extra attributes are attached; an empty attribute dict is passed to
     SoftPlus.update_node_stat. Returns ``cls.enabled``.
     """
     extra_attrs = {}
     SoftPlus.update_node_stat(node, extra_attrs)
     return cls.enabled