Ejemplo n.º 1
0
 def test_activation_mish_infer(self):
     """Check that Mish.infer fills in the output shape and values of node_3.

     The input tensor [-1, 0, 1, 20] should produce a 1-D output of length 4
     whose values equal x * tanh(ln(1 + e^x)) for each element.
     The lengths are asserted explicitly before the element-wise checks.
     Iterating only over the expected arrays would not catch a result that
     has extra entries.
     """
     graph = build_graph(
         self.nodes_attributes, [('node_1', 'activation_node'),
                                 ('activation_node', 'node_3')],
         {
             'node_1': {
                 'value': np.array([-1.0, 0.0, 1.0, 20.0])
             },
             'activation_node': {
                 'op': 'Mish',
                 'operation': Mish.operation,
             },
             'node_3': {
                 'value': None
             }
         })
     graph.graph['layout'] = 'NCHW'
     activation_node = Node(graph, 'activation_node')
     Mish.infer(activation_node)
     exp_shape = np.array([4])
     res_shape = graph.node['node_3']['shape']
     res_value = graph.node['node_3']['value']
     # Mish(x) = x * tanh(ln(1 + e^x)); Mish(20) equals 20.0 to 7 decimal places.
     exp_value = np.array([-0.30340146, 0.0, 0.8650984, 20.0])
     # Length guards: the zip loops below would silently ignore surplus or
     # missing entries in the inferred shape/value.
     self.assertEqual(len(res_shape), len(exp_shape))
     self.assertEqual(len(res_value), len(exp_value))
     for res_dim, exp_dim in zip(res_shape, exp_shape):
         self.assertEqual(res_dim, exp_dim)
     for res_item, exp_item in zip(res_value, exp_value):
         self.assertAlmostEqual(res_item, exp_item)
Ejemplo n.º 2
0
    def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]):
        """Collapse a matched Mul(x, Tanh(SoftPlus(x))) pattern into one Mish node.

        :param graph: the graph being transformed; the new Mish node is created in it.
        :param match: matched nodes, read via the keys 'mul' and 'softplus'.

        The fusion only happens when the non-Tanh input of Mul and the input of
        SoftPlus come from the same source port; otherwise the method returns
        without touching the graph. On success, the Mish node takes over the
        consumers of Mul's output and Mul's original name. The Mul node is renamed
        with a '/TBR' suffix (presumably "to be removed" -- confirm against the
        framework's cleanup pass).
        """
        softplus_node = match['softplus']
        mul_node = match['mul']
        original_name = mul_node.soft_get('name', mul_node.id)

        # Mul is commutative: if port 0 is fed by Tanh, the raw tensor x arrives on port 1.
        port0_producer = mul_node.in_port(0).get_connection().get_source().node
        data_port = 1 if port0_producer.soft_get('op') == 'Tanh' else 0

        # Fuse only when Mul and SoftPlus consume exactly the same tensor.
        data_source = mul_node.in_port(data_port).get_source()
        if data_source != softplus_node.in_port(0).get_source():
            return

        mish_node = Mish(graph, {}).create_node()
        mish_node.in_port(0).connect(data_source)
        mul_node.out_port(0).get_connection().set_source(mish_node.out_port(0))

        rename_nodes([(mul_node, original_name + '/TBR'), (mish_node, original_name)])
Ejemplo n.º 3
0
 def extract(cls, node):
     """Attach the Mish operation's attributes to *node* via Mish.update_node_stat.

     Returns the class-level ``enabled`` attribute (presumably the front-end's
     "this extractor is active" flag -- confirm against the base extractor class).
     """
     Mish.update_node_stat(node)
     extractor_enabled = cls.enabled
     return extractor_enabled