def test_activation_softplus_infer(self):
    """Run SoftPlus shape/value inference on a 3-node graph and check the output.

    The input tensor is [-1, 0, 1, 20]. The expected outputs below match
    ln(1 + e^x) evaluated element-wise on that input.
    """
    input_value = np.array([-1.0, 0.0, 1.0, 20.0])
    edges = [
        ('node_1', 'activation_node'),
        ('activation_node', 'node_3'),
    ]
    node_overrides = {
        'node_1': {'value': input_value},
        'activation_node': {
            'op': 'SoftPlus',
            'operation': SoftPlus.operation,
        },
        'node_3': {'value': None},
    }
    graph = build_graph(self.nodes_attributes, edges, node_overrides)
    graph.graph['layout'] = 'NCHW'

    SoftPlus.infer(Node(graph, 'activation_node'))

    output_attrs = graph.node['node_3']
    actual_shape = output_attrs['shape']
    actual_value = output_attrs['value']

    expected_shape = np.array([4])
    expected_value = np.array([0.3132617, 0.6931472, 1.3132617, 20.0])

    # Compare by index so a result shorter than expected raises instead of
    # being silently truncated (as zip() would do).
    for idx, dim in enumerate(expected_shape):
        self.assertEqual(actual_shape[idx], dim)
    for idx, val in enumerate(expected_value):
        self.assertAlmostEqual(actual_value[idx], val)
def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]):
    """Replace the matched sub-graph running from `exp` to `ln` with one SoftPlus node.

    NOTE(review): the matched pattern is presumably ln(exp(x) + 1), defined in
    this class's pattern(), which is not visible here. Confirm there.

    :param graph: graph being transformed in place.
    :param match: mapping with at least the keys 'ln' and 'exp'.
    """
    log_node = match['ln']
    exp_node = match['exp']
    original_name = log_node.soft_get('name', log_node.id)

    softplus_node = SoftPlus(graph, {}).create_node()

    # SoftPlus takes the same producer that fed the Exp node.
    exp_input_source = exp_node.in_port(0).get_source()
    softplus_node.in_port(0).connect(exp_input_source)

    # Re-source the connection leaving `ln` so SoftPlus's output drives it.
    log_output_connection = log_node.out_port(0).get_connection()
    log_output_connection.set_source(softplus_node.out_port(0))

    # SoftPlus takes over the Log node's name. The old node gets a '/TBR'
    # suffix (presumably "to be removed").
    rename_nodes([
        (log_node, original_name + '/TBR'),
        (softplus_node, original_name),
    ])
def extract(cls, node):
    """Apply SoftPlus node attributes to `node` and return `cls.enabled`.

    Delegates to SoftPlus.update_node_stat with no extra attributes.
    """
    extra_attrs = {}
    SoftPlus.update_node_stat(node, extra_attrs)
    return cls.enabled