示例#1
0
    def replace_sub_graph(self, graph: Graph, match: dict):
        """Swap the matched node for the chain ``Exp -> Add(1.0) -> Log``.

        The node stored under ``match['op']`` is decomposed into three new
        nodes that together evaluate ``log(exp(x) + 1)``. The final Log node
        takes over the original node's name, while the original node is
        renamed to ``<name>/Log``.

        :param graph: graph in which the replacement nodes are created.
        :param match: pattern match result; only the ``'op'`` entry is read.
        """
        original = match['op']
        base_name = original.soft_get('name', original.id)

        # Build the three replacement nodes (Add gets a constant 1.0 as its
        # second input).
        exp_op = Exp(graph, {'name': base_name + '/Exp'}).create_node()
        plus_one = create_op_node_with_second_input(
            graph, Add, float_array([1.0]), {'name': base_name + '/Add'})
        log_op = Log(graph, {'name': base_name + '/Log'}).create_node()

        # Hand the original name over to the Log node that now produces
        # the result.
        rename_nodes([(original, base_name + '/Log'), (log_op, base_name)])

        # Reroute the original input into Exp.
        incoming = original.in_port(0).get_connection()
        incoming.set_destination(exp_op.in_port(0))

        # Exp -> Add -> Log
        add_input = plus_one.in_port(0)
        add_input.connect(exp_op.out_port(0))
        log_input = log_op.in_port(0)
        log_input.connect(plus_one.out_port(0))

        # Consumers of the original output now read from Log.
        outgoing = original.out_port(0).get_connection()
        outgoing.set_source(log_op.out_port(0))
    def replace_op(self, graph: Graph, node: Node):
        """Replace ``node`` with the sub-graph ``Log(Add(1, x))``.

        A constant ``1`` is created with the node's ``data_type`` when that
        attribute is valid, and with ``np.float32`` otherwise. The original
        input 0 of ``node`` is rerouted to the second input of the new Add
        node, and the new Log node takes over the original node name.

        :param graph: graph in which the replacement nodes are created.
        :param node: the node being replaced.
        :return: single-element list with the id of the new Log node.
        """
        # Safe name lookup: falls back to the node id when 'name' is absent.
        node_name = node.soft_get('name', node.id)
        const_dtype = (node.data_type if node.has_valid('data_type')
                       else np.float32)
        const = Const(graph, {
            'value': np.array([1], dtype=const_dtype)
        }).create_node()
        # Derive the new names from node_name (not node.name) so a node with
        # no 'name' attribute is handled the same way as in the lookup above.
        add = Add(graph, {'name': node_name + '/Add_'}).create_node()
        log = Log(graph, {'name': node_name + '/Log_'}).create_node()

        # Connect nodes: (const, input) -> Add -> Log
        const.out_port(0).connect(add.in_port(0))
        node.in_port(0).get_connection().set_destination(add.in_port(1))
        add.out_port(0).connect(log.in_port(0))
        # The Log node inherits the original name; the replaced node is
        # renamed out of the way.
        rename_nodes([(node, node_name + '/delete'), (log, node_name)])

        # The "explicit" version of the return value is: [(log.id, 0)]
        return [log.id]
示例#3
0
 def extract(cls, node):
     """Apply ``Log.update_node_stat`` to *node* and return ``cls.enabled``."""
     Log.update_node_stat(node)
     is_enabled = cls.enabled
     return is_enabled