def test_activation_mish_infer(self):
    """Check Mish.infer output shape and values on a constant 1-D input.

    Builds node_1 -> activation_node (Mish) -> node_3 and runs ``Mish.infer``.
    It then verifies that node_3 receives shape [4] and values matching precomputed
    Mish references.
    """
    graph = build_graph(
        self.nodes_attributes,
        [('node_1', 'activation_node'),
         ('activation_node', 'node_3')],
        {
            'node_1': {
                'value': np.array([-1.0, 0.0, 1.0, 20.0])
            },
            'activation_node': {
                'op': 'Mish',
                'operation': Mish.operation,
            },
            'node_3': {
                'value': None
            }
        })
    graph.graph['layout'] = 'NCHW'
    activation_node = Node(graph, 'activation_node')
    Mish.infer(activation_node)

    exp_shape = np.array([4])
    res_shape = graph.node['node_3']['shape']
    res_value = graph.node['node_3']['value']
    # Reference values of mish(x) = x * tanh(ln(1 + exp(x))) for x in [-1, 0, 1, 20];
    # mish(20) equals 20 to well within assertAlmostEqual's default 7 places.
    exp_value = np.array([-0.30340146, 0.0, 0.8650984, 20.0])

    # Compare lengths first: the element loops below only walk the expected arrays,
    # so without this check extra trailing elements in the result would go unnoticed.
    self.assertEqual(len(res_shape), len(exp_shape))
    for i, value in enumerate(exp_shape):
        self.assertEqual(res_shape[i], value)
    self.assertEqual(len(res_value), len(exp_value))
    for i, value in enumerate(exp_value):
        self.assertAlmostEqual(res_value[i], value)
def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]):
    """Fuse a matched Mul/SoftPlus sub-graph into a single Mish node.

    Presumably the matched pattern is ``Mul(x, Tanh(SoftPlus(x)))``. The pattern itself
    is defined outside this method; confirm against it. The fusion happens only when the
    Mul input that is not fed by Tanh and the SoftPlus input come from the same source
    port. Otherwise the method returns without modifying the graph.

    :param graph: graph being transformed; the new Mish node is created in it
    :param match: matched nodes, looked up by the keys 'mul' and 'softplus'
    """
    mul_node = match['mul']
    fused_name = mul_node.soft_get('name', mul_node.id)
    softplus_node = match['softplus']

    # The Tanh branch may feed either Mul input; the data tensor is on the other one.
    tanh_feeds_port0 = mul_node.in_port(0).get_connection().get_source().node.soft_get('op') == 'Tanh'
    data_port_idx = 1 if tanh_feeds_port0 else 0

    # Bail out unless Mul and SoftPlus consume the very same tensor.
    if mul_node.in_port(data_port_idx).get_source() != softplus_node.in_port(0).get_source():
        return

    mish_node = Mish(graph, {}).create_node()
    mish_node.in_port(0).connect(mul_node.in_port(data_port_idx).get_source())
    # Move the consumers of the Mul output over to the Mish output.
    mul_node.out_port(0).get_connection().set_source(mish_node.out_port(0))

    # Mish takes over the original Mul name; the Mul gets a '/TBR' suffix
    # (presumably "to be removed").
    rename_nodes([(mul_node, fused_name + '/TBR'), (mish_node, fused_name)])
def extract(cls, node):
    """Attach the Mish operation attributes to *node*.

    :param node: node being extracted; passed to ``Mish.update_node_stat``
    :return: the value of the class attribute ``enabled``
    """
    Mish.update_node_stat(node)
    extractor_enabled = cls.enabled
    return extractor_enabled