示例#1
0
    def replace_op(self, graph: Graph, node: Node):
        """
        Decompose a Softsign node into primitive ops: feature / (abs(feature) + 1).

        Three new ops are created in ``graph``: an Abs of the Softsign input,
        an Add of that result with a constant 1, and a Div of the original
        input by the Add output.

        :param graph: graph that owns ``node``
        :param node: the Softsign node being decomposed
        :return: one-element list holding the id of the Div node that computes
                 the replacement value (presumably the replacement framework
                 reroutes the Softsign's consumers to it -- not shown here)
        """
        # |feature|
        magnitude = Abs(graph, {'name': "abs_" + node.id}).create_node()
        magnitude.in_port(0).connect(node.in_port(0).get_source())

        # |feature| + 1; np.ones([1]) is a 1-element float64 array.
        unit = np.ones([1])
        denominator = create_op_node_with_second_input(graph, Add, unit, {"name": node.id + "_plus_1"})
        denominator.in_port(0).connect(magnitude.out_port(0))

        # feature / (|feature| + 1): the numerator is the Softsign's own input.
        quotient = Div(graph, {"name": "div_" + node.id}).create_node()
        feature_source = node.in_port(0).get_source()
        quotient.in_port(0).connect(feature_source)
        quotient.in_port(1).connect(denominator.out_port(0))

        return [quotient.id]
示例#2
0
    def insert_abs_max(self, model_graph, node, type_stat, node_name,
                       **kwargs):
        """Attach an ``Abs -> ReduceMax`` statistic branch to ``node``'s first output.

        :param model_graph: graph compared against ``node.graph``; when they
            differ, a data node with shape ``[1]`` is created for the ReduceMax op
        :param node: node whose first output is measured
        :param type_stat: statistic tag, used in the ReduceMax name and passed
            to ``self.insert_result``
        :param node_name: name fragment used to build the new op names
        :param kwargs: ``granularity`` is forwarded to ``self.find_axis``
        :return: ``(True, node.name)`` when ``self.find_axis`` returns a string;
            otherwise whatever ``self.insert_result`` returns
        """
        reduce_axes = self.find_axis(node, kwargs.get('granularity'))
        if isinstance(reduce_axes, str):
            # NOTE(review): a string from find_axis appears to signal that no
            # in-graph reduction is built; confirm against find_axis.
            return (True, node.name)

        owner_graph = node.graph

        # create_node_with_data returns the Abs output data node; its producer is the Abs op.
        abs_output = Abs(owner_graph, {
            "name": f'abs_{node_name}'
        }).create_node_with_data([node.out_node(0)])
        abs_op = abs_output.in_node(0)

        reduce_attrs = dict(name=f'{type_stat}_{node_name}')
        reduce_max = create_op_node_with_second_input(
            owner_graph, ReduceMax, int64_array(reduce_axes), reduce_attrs)

        if owner_graph != model_graph:
            # node lives outside model_graph (presumably a subgraph -- confirm),
            # so the ReduceMax output gets an explicit data node.
            Op.create_data_node(reduce_max.graph, reduce_max, {'shape': [1]})
        reduce_max['fullname'] = reset_node_fullname(node.fullname, reduce_max.name)

        abs_op.out_port(0).connect(reduce_max.in_port(0))
        return self.insert_result(model_graph, node, reduce_max, type_stat)
示例#3
0
 def insert_abs_max(self, node, type_stat, node_name, **kwargs):
     """Attach an ``Abs -> ReduceMax`` statistic branch to ``node``'s output port 0.

     :param node: node whose output is measured
     :param type_stat: statistic tag, used in the Abs op name and passed to
         ``self.insert_result``
     :param node_name: name fragment used to build the new op names
     :param kwargs: ``granularity`` is forwarded to ``self.find_axis``
     :return: ``node.name`` when ``self.find_axis`` returns a string,
         otherwise ``None`` after the branch is registered via ``self.insert_result``
     """
     granularity = kwargs.get('granularity')
     reduce_axes = self.find_axis(node, granularity)
     if isinstance(reduce_axes, str):
         return node.name

     # NOTE(review): the Abs op (not the ReduceMax) receives the
     # ``type_stat + node_name`` name; kept as-is because other code may look
     # nodes up by name -- confirm this is intended.
     abs_op = Abs(node.graph, {"name": type_stat + node_name}).create_node()
     abs_op.in_port(0).connect(node.out_port(0))

     reduce_attrs = dict(name='abs_max_' + node_name)
     reduce_max = create_op_node_with_second_input(
         node.graph, ReduceMax, int64_array(reduce_axes), reduce_attrs)
     abs_op.out_port(0).connect(reduce_max.in_port(0))

     self.insert_result(node, reduce_max, type_stat)
     return None
示例#4
0
 def extract(cls, node):
     """Stamp ``node`` with the attributes of the ``Abs`` operation.

     :param node: node to be typed as Abs via ``Abs.update_node_stat``
     :return: the extractor class's ``enabled`` attribute
     """
     Abs.update_node_stat(node)
     is_enabled = cls.enabled
     return is_enabled