예제 #1
0
    def replace_pattern(self, graph: Graph, match: dict):
        """Rewrite a matched Minimum node as -Maximum(-a, -b).

        Uses the identity min(a, b) == -max(-a, -b): each input is multiplied
        by a constant -1, the two products feed a Maximum, and the Maximum
        result is multiplied by -1 again.  The final Mul is created with the
        Minimum's output data node passed as ``data_nodes``; the Minimum op
        node is then removed from the graph.

        When both inputs already hold constant values the node is left alone
        (constant propagation takes care of that case).
        """
        minimum_node = match['minimum']
        # Both operands are constants: nothing to rewrite here.
        if (minimum_node.in_node(0).value is not None
                and minimum_node.in_node(1).value is not None):
            return

        base = minimum_node.name
        first_minus_one = Const(
            graph, dict(value=np.array(-1), name=base + '/negate1_const'))
        second_minus_one = Const(
            graph, dict(value=np.array(-1), name=base + '/negate2_const'))
        first_negation = Mul(graph, dict(name=base + '/negate1_'))
        second_negation = Mul(graph, dict(name=base + '/negate2_'))
        max_op = Maximum(graph, dict(name=base + '/Max_'))
        out_minus_one = Const(
            graph, dict(value=np.array(-1), name=base + '/negate_out_const_'))
        out_negation = Mul(graph, dict(scale=-1, name=base + '/negate_out_'))

        # Node creation below follows the same order in which the nested
        # call expression of the original implementation was evaluated.
        negated_a = first_negation.create_node_with_data(
            [minimum_node.in_node(0), first_minus_one.create_node_with_data()])
        negated_b = second_negation.create_node_with_data(
            [minimum_node.in_node(1), second_minus_one.create_node_with_data()])
        max_of_negated = max_op.create_node_with_data([negated_a, negated_b])
        out_negation.create_node_with_data(
            inputs=[max_of_negated, out_minus_one.create_node_with_data()],
            data_nodes=minimum_node.out_node())

        # The Minimum op node itself is no longer needed.
        minimum_node.graph.remove_node(minimum_node.id)
예제 #2
0
    def replace_pattern(self, graph: Graph, match: dict):
        """Replace a matched Clamp node with simpler equivalent operations.

        Input port 0 is the clamped tensor; optional input ports 1 and 2
        supply the min and max bounds.

        * If both bounds are produced by ``Const`` nodes, the Clamp becomes a
          single ``AttributedClamp`` carrying them as ``min``/``max``
          attributes.
        * Otherwise it is decomposed into ``Maximum(x, min)`` (when port 1 is
          connected) followed by ``Minimum(., max)`` (when port 2 is
          connected).

        The original node is renamed to ``<name>/TBR`` and its connections
        are moved onto the replacement, which takes over the original name.

        Raises:
            AssertionError: if neither the min nor the max input is connected.
        """
        clamp = match['clamp']
        name = clamp.soft_get('name', clamp.id)

        min_value = max_value = None
        port_1_exist = clamp.has_port(
            'in', 1) and not clamp.in_port(1).disconnected()
        port_2_exist = clamp.has_port(
            'in', 2) and not clamp.in_port(2).disconnected()
        if not port_1_exist and not port_2_exist:
            # Validate before touching the graph, and with an explicit raise
            # rather than `assert`: asserts are stripped under `python -O`,
            # which would otherwise surface later as an AttributeError on
            # `rename_node(None, ...)` after the clamp was already renamed.
            raise AssertionError(
                'Clamp node should have either min or max input used')

        if port_1_exist and clamp.in_port(1).get_source().node.soft_get(
                'type') == 'Const':
            min_value = clamp.in_port(1).data.get_value()
        if port_2_exist and clamp.in_port(2).get_source().node.soft_get(
                'type') == 'Const':
            max_value = clamp.in_port(2).data.get_value()

        rename_node(clamp, name + '/TBR')
        if min_value is None or max_value is None:
            # At least one bound is not a constant: decompose into
            # Maximum (lower bound) followed by Minimum (upper bound).
            max_node = min_node = None
            if port_1_exist:
                max_node = Maximum(graph, {}).create_node()
                clamp.in_port(0).get_connection().set_destination(
                    max_node.in_port(0))
                clamp.in_port(1).get_connection().set_destination(
                    max_node.in_port(1))
                clamp.out_port(0).get_connection().set_source(
                    max_node.out_port(0))
            if port_2_exist:
                min_node = Minimum(graph, {}).create_node()
                if max_node is not None:
                    # Splice Minimum in after the already-created Maximum.
                    max_node.out_port(0).get_connection().set_source(
                        min_node.out_port(0))
                    max_node.out_port(0).connect(min_node.in_port(0))
                else:
                    clamp.in_port(0).get_connection().set_destination(
                        min_node.in_port(0))
                    clamp.out_port(0).get_connection().set_source(
                        min_node.out_port(0))
                clamp.in_port(2).get_connection().set_destination(
                    min_node.in_port(1))
            # The early check guarantees at least one of the nodes exists;
            # the last op in the chain takes over the original name.
            rename_node(min_node if min_node is not None else max_node, name)
        else:
            # Both bounds are constants: fold them into attributes.
            a_clamp = AttributedClamp(graph, {
                'name': name,
                'min': min_value,
                'max': max_value
            }).create_node()
            rename_node(a_clamp, name)
            clamp.in_port(0).get_connection().set_destination(
                a_clamp.in_port(0))
            clamp.out_port(0).get_connection().set_source(a_clamp.out_port(0))

    def find_and_replace_pattern(self, graph: Graph):
        """Lower every ClipByValueTF op to Maximum(Minimum(x, hi), lo).

        Input port 0 is the tensor.  The source of port 2 feeds the Minimum,
        so it acts as the upper bound.  The source of port 1 feeds the
        Maximum, so it acts as the lower bound.  The new Maximum takes over
        the original node's output connection and name, and the ClipByValueTF
        node is removed from the graph.
        """
        for clip_node in graph.get_op_nodes(op='ClipByValueTF'):
            clip_name = clip_node.soft_get('name', clip_node.id)
            min_op = Minimum(
                graph, {'name': clip_name + '/CLipMinimum'}).create_node()
            max_op = Maximum(
                graph, {'name': clip_name + '/CLipMaximum'}).create_node()

            # Minimum(x, upper bound from port 2) ...
            min_op.in_port(0).connect(clip_node.in_port(0).get_source())
            min_op.in_port(1).connect(clip_node.in_port(2).get_source())
            # ... then Maximum(that, lower bound from port 1).
            max_op.in_port(0).connect(min_op.out_port(0))
            max_op.in_port(1).connect(clip_node.in_port(1).get_source())
            clip_node.out_port(0).get_connection().set_source(
                max_op.out_port(0))

            # Hand the original name to the replacement, then drop the old op.
            rename_nodes([(clip_node, clip_name + '/TBR'),
                          (max_op, clip_name)])
            graph.remove_node(clip_node.id)
예제 #4
0
 def extract(cls, node):
     """Apply Maximum's node attributes to *node*, including a data type.

     The data type is read from the node's protobuf ``T`` attribute
     (``node.pb.attr["T"].type``, presumably a TensorFlow NodeDef) and
     converted with ``tf_dtype_extractor``.  Returns ``cls.enabled``.
     """
     proto_type = node.pb.attr["T"].type
     extra_attrs = {'data_type': tf_dtype_extractor(proto_type)}
     Maximum.update_node_stat(node, extra_attrs)
     return cls.enabled
예제 #5
0
 def extract(cls, node):
     """Delegate to ``Maximum.update_node_stat`` for *node*; return ``cls.enabled``."""
     Maximum.update_node_stat(node)
     is_enabled = cls.enabled
     return is_enabled
예제 #6
0
 def extract(node):
     """Delegate to ``Maximum.update_node_stat`` for *node*; return the defining class's ``enabled``.

     NOTE(review): there is no ``cls``/``self`` parameter, so this is
     presumably a ``@staticmethod`` whose decorator is not visible here --
     confirm.
     """
     Maximum.update_node_stat(node)
     # ``__class__`` is the compiler-provided reference to the class whose
     # body defines this function (it only resolves inside a class body), so
     # this reads the defining class's flag, not a subclass's.
     return __class__.enabled