def replace_sub_graph(self, graph: Graph, match: dict):
    """Rewrite the matched ``op`` node as ``Log(Add(Exp(x), 1.0))``.

    The new Log node takes over the original node's name and output
    consumers. The original node is renamed out of the way and its input
    is rerouted into the new Exp node.

    :param graph: graph that contains the matched sub-graph.
    :param match: pattern match; the node to replace is stored under 'op'.
    """
    target = match['op']
    base_name = target.soft_get('name', target.id)

    # Build the replacement chain. Nodes are created in the same order
    # as before (Exp, Add, Log).
    exp_op = Exp(graph, {'name': base_name + '/Exp'}).create_node()
    plus_one = create_op_node_with_second_input(
        graph, Add, float_array([1.0]), {'name': base_name + '/Add'})
    log_op = Log(graph, {'name': base_name + '/Log'}).create_node()

    # Hand the original name to the last node of the chain.
    rename_nodes([(target, base_name + '/Log'), (log_op, base_name)])

    # Reroute the original input into Exp.
    incoming = target.in_port(0).get_connection()
    incoming.set_destination(exp_op.in_port(0))

    # Wire Exp -> Add -> Log.
    plus_one.in_port(0).connect(exp_op.out_port(0))
    log_op.in_port(0).connect(plus_one.out_port(0))

    # Consumers of the original output now read from Log.
    outgoing = target.out_port(0).get_connection()
    outgoing.set_source(log_op.out_port(0))
def replace_op(self, graph: Graph, node: Node):
    """Replace ``node`` with the sub-graph ``Log(Add(1, x))``.

    A constant ``1`` is added to the node's input and the sum is fed to a
    Log node. The Log node takes over the original node's name, and the
    original node is renamed with a ``/delete`` suffix.

    :param graph: graph that owns ``node``.
    :param node: single-input node to decompose.
    :return: list with the id of the new Log node, whose output replaces
        the output of ``node``.
    """
    # Fall back to the node id when no explicit 'name' attribute is set.
    node_name = node.soft_get('name', node.id)

    # Match the constant's dtype to the node's declared data type when
    # available; otherwise use float32.
    const_dtype = node.data_type if node.has_valid('data_type') else np.float32

    const = Const(graph, {
        'value': np.array([1], dtype=const_dtype)
    }).create_node()
    # Use node_name rather than node.name so nodes without an explicit
    # 'name' attribute do not fail here. This matches rename_nodes below.
    add = Add(graph, {'name': node_name + '/Add_'}).create_node()
    log = Log(graph, {'name': node_name + '/Log_'}).create_node()

    # Connect nodes: input -> Add -> Log
    const.out_port(0).connect(add.in_port(0))
    node.in_port(0).get_connection().set_destination(add.in_port(1))
    add.out_port(0).connect(log.in_port(0))

    rename_nodes([(node, node_name + '/delete'), (log, node_name)])

    # The "explicit" version of the return value is: [(log.id, 0)]
    return [log.id]
def extract(cls, node):
    """Attach Log-operation attributes to ``node``.

    Delegates to ``Log.update_node_stat``, which presumably fills in the
    Log op's default attributes (TODO confirm).

    :param node: node being extracted.
    :return: the extractor's ``enabled`` flag.
    """
    Log.update_node_stat(node)
    is_enabled = cls.enabled
    return is_enabled