Example #1
    def replace_op(self, graph: nx.MultiDiGraph, node: Node):
        reciprocal = Power(
            graph, dict(scale=1, power=-1, shift=0,
                        name=node.name + '/power_'))
        out_node = reciprocal.create_node([node.in_node(0)])

        return [out_node.id]
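Every example on this page drives the same Power layer, so as a point of reference here is a minimal numpy sketch of its apparent semantics, assuming, as the attribute names suggest, that Power computes (scale * x + shift) ** power elementwise. power_op is a hypothetical helper, not a Model Optimizer API:

import numpy as np

def power_op(x, scale=1.0, shift=0.0, power=1.0):
    # Elementwise (scale * x + shift) ** power, per the assumption above.
    return (scale * x + shift) ** power

x = np.array([1.0, 2.0, 4.0])
# The reciprocal from Example #1: scale=1, shift=0, power=-1.
print(power_op(x, scale=1, shift=0, power=-1))  # -> [1.   0.5  0.25]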
Example #2
    def replace_pattern(graph: Graph, match: dict):
        consumers = [
            n for n in match if n not in ['mul_add', 'pow_d']
            and not check_node_usages_out_of_match(match, n)
        ]
        if consumers:
            log.warning(
                'Power(mul_add,pow) pattern was detected. Non pattern consumers of nodes: "{}" were found.'
                ' Won\'t replace'.format(', '.join(
                    [match[n].id for n in consumers])))
            return
        mul_add = match['mul_add']
        pow = match['pow']
        new_power = Power(
            graph, {
                'name': mul_add.name + '/fused_power',
                'shift': mul_add.shift,
                'scale': mul_add.scale,
                'power': pow.power
            }).create_node()

        source = mul_add.in_port(0).get_connection().get_source()
        mul_add.in_port(0).disconnect()
        new_power.in_port(0).connect(source)
        pow.out_port(0).get_connection().set_source(new_power.out_port(0))

        log.debug(
            'Power nodes {} and {} were fused to single Power node {}'.format(
                mul_add.name, pow.name, new_power.name))
Example #3
    def replace_op(self, graph: Graph, node: Node):
        power = Power(graph, dict(scale=0, name=node.name + '/Power/')).create_node()

        # Reconnecting inputs to this new node
        node.in_port(0).get_connection().set_destination(power.in_port(0))
        node.out_port(0).get_connection().set_source(power.out_port(0))
        return [power.id]
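A hedged note on Example #3: assuming the defaults shift=0 and power=1, a Power node with scale=0 maps every input to zero, which is presumably why it can replace the matched op outright:

import numpy as np

x = np.array([3.0, -7.0, 42.0])
# scale=0 with assumed defaults shift=0, power=1 zeroes the input.
print((0 * x + 0) ** 1)  # -> [0. 0. 0.]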
Example #4
 def _create_sub(graph: Graph, input_1: Node, port_1: int, input_2: Node,
                 port_2: int):
     negate = Power(graph, dict(scale=-1, name=input_2.name + '/negate_'))
     add = Eltwise(graph, dict(operation='sum',
                               name=input_1.name + '/add_'))
     out_node = add.create_node([(input_1, port_1),
                                 negate.create_node([(input_2, port_2)])])
     return out_node
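What _create_sub builds, sketched in plain numpy under the assumption that Power(scale=-1) negates its input and Eltwise(operation='sum') adds elementwise, i.e. a - b == a + (-1 * b):

import numpy as np

a = np.array([5.0, 2.0])
b = np.array([1.0, 7.0])
negated_b = -1 * b    # the negate Power node
print(a + negated_b)  # -> [ 4. -5.], equal to a - b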
Example #5
    def test_power_two_input_infer3(self):
        graph = self.create_graph(single_input=False)
        power_node = Node(graph, 'power')
        input2 = Node(graph, 'input2')
        input2.value = None

        Power.infer(power_node)

        self.assertIsNone(power_node.out_node().shape)
Example #6
 def replace_op(self, graph: nx.MultiDiGraph, node: Node):
     negate = Power(graph, dict(scale=-1, name=node.name + '/negate_'))
     add = Eltwise(graph, dict(operation='sum', name=node.name + '/add_'))
     out_node = add.create_node([
         (node.in_node(0), node.in_edge(0)['out']),
         negate.create_node([(node.in_node(1), node.in_edge(1)['out'])])
     ])
     # Replace the edge from out port 0 of the matched node with an edge from out_node with port 0.
     # The "explicit" version of the return value is: [(out_node.id, 0)]
     return [out_node.id]
Example #7
    def test_power_two_input_infer1(self):
        graph = self.create_graph(single_input=False)
        graph.graph['layout'] = 'NCHW'
        power_node = Node(graph, 'power')

        Power.infer(power_node)

        self.assertTrue(
            np.array_equal(power_node.out_node().shape,
                           power_node.in_node(0).shape))
Example #8
    def replace_pattern(graph: Graph, match: dict):
        node = match['op']
        shape = node.in_port(0).data.get_shape().copy()

        assert shape[1] % node.group == 0

        power_node = Power(graph, attrs={'name': node.id + '_power',
                                         'power': node.p}).create_node()

        reshape_node = create_op_node_with_second_input(graph, Reshape,
                                                        int64_array([shape[0], shape[1] // node.group, node.group]),
                                                        {'name': node.id + '_reshape'})
        reshape_node.in_port(0).connect(power_node.out_port(0))

        reducesum_node = create_op_node_with_second_input(graph, ReduceSum,
                                                          int64_array([2]),
                                                          {'name': node.id + '_sum', 'keep_dims': False})
        reducesum_node.in_port(0).connect(reshape_node.out_port(0))

        invpower_node = Power(graph, attrs={'name': node.id + '_invpower',
                                            'power': 1.0 / node.p}).create_node()
        invpower_node.in_port(0).connect(reducesum_node.out_port(0))

        node.in_port(0).get_connection().set_destination(power_node.in_port(0))
        node.out_port(0).get_connection().set_source(invpower_node.out_port(0))
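A hedged numpy model of the decomposition above (pnorm is a hypothetical helper, not MO code): raise to the power p, group consecutive channels, sum within each group, then take the 1/p root:

import numpy as np

def pnorm(x, p, group):
    n, c = x.shape
    assert c % group == 0
    powered = x ** p                                 # Power(power=p)
    grouped = powered.reshape(n, c // group, group)  # Reshape
    summed = grouped.sum(axis=2)                     # ReduceSum over axis 2
    return summed ** (1.0 / p)                       # Power(power=1/p)

x = np.arange(1.0, 9.0).reshape(1, 8)
print(pnorm(x, p=2, group=4))  # 2-norms of channels [1..4] and [5..8]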
Example #9
    def replace_op(self, graph: Graph, node: Node):
        reciprocal = Power(graph, {'scale': 1, 'power': np.float64(-1), 'shift': 0,
                                   'name': node.name + '/reciprocal_'}).create_node()
        mul = Eltwise(graph, {'operation': 'mul', 'name': node.name + '/mul_'}).create_node()

        # Connect nodes
        node.in_port(1).get_connection().set_destination(reciprocal.in_port(0))
        node.in_port(0).get_connection().set_destination(mul.in_port(1))
        reciprocal.out_port(0).connect(mul.in_port(0))

        # The "explicit" version of the return value is: [(out_node.id, 0)])
        return [mul.id]
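The same rewrite in plain numpy, assuming Power(power=-1) is a reciprocal and Eltwise(operation='mul') multiplies elementwise, i.e. a / b == a * b ** -1:

import numpy as np

a = np.array([6.0, 9.0])
b = np.array([2.0, 3.0])
print(a * b ** -1.0)  # -> [3. 3.], equal to a / b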
Example #10
 def extract(node: Node):
     pb = node.pb
     assert pb, 'Protobuf layer can not be empty'
     param = pb.power_param
     attrs = {
         'output_spatial_shape': None,
         'power': param.power,
         'scale': param.scale,
         'shift': param.shift,
     }
     Power.update_node_stat(node, attrs)
     return __class__.enabled
Example #11
    def replace_op(self, graph: nx.MultiDiGraph, node: Node):
        reciprocal = Power(
            graph,
            dict(scale=1,
                 power=np.float64(-1),
                 shift=0,
                 name=node.name + '/reciprocal_'))
        mul = Eltwise(graph, dict(operation='mul', name=node.name + '/mul_'))

        out_node = mul.create_node([
            (node.in_node(0), node.in_edge(0)['out']),
            reciprocal.create_node([(node.in_node(1), node.in_edge(1)['out'])])
        ])
        # Replace the edge from out port 0 of the matched node with an edge from out_node with port 0.
        # The "explicit" version of the return value is: [(out_node.id, 0)]
        return [out_node.id]
Example #12
    def replace_pattern(self, graph: Graph, match: dict):
        node = match['minimum']
        # Constant propagation case
        if node.in_node(0).value is not None and node.in_node(1).value is not None:
            return

        negate_1 = Power(graph, dict(scale=-1, name=node.name + '/negate1_'))
        negate_2 = Power(graph, dict(scale=-1, name=node.name + '/negate2_'))
        maximum = Eltwise(graph, dict(operation='max', name=node.name + '/Max_'))
        negate_output = Power(graph, dict(scale=-1, name=node.name + '/negate_out_'))

        negate_output.create_node_with_data(
            inputs=[maximum.create_node_with_data([negate_1.create_node_with_data([node.in_node(0)]),
                                                   negate_2.create_node_with_data([node.in_node(1)])])],
            data_nodes=node.out_node())
        # Delete minimum vertex
        node.graph.remove_node(node.id)
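The identity behind this replacement, checked in numpy under the same assumptions about Power(scale=-1) and the 'max' Eltwise: min(a, b) == -max(-a, -b):

import numpy as np

a = np.array([1.0, 5.0])
b = np.array([3.0, 2.0])
print(-np.maximum(-a, -b))  # -> [1. 2.]
print(np.minimum(a, b))     # same result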
Example #13
    def replace_pattern(graph: Graph, match: dict):
        op = match['op']
        op_type = op.type

        const_port, tensor_port = get_value_in_port(op), get_tensor_in_port(op)
        if const_port is None or tensor_port is None:
            return
        value = const_port.data.get_value()
        assert value is not None
        if value.size != 1:
            return
        value = value.item(0)

        assert op_type in EltwisesWithScalarInputToPower.eltw_types
        if op_type == 'Add':
            delete_node = value == 0
            Power.update_node_stat(op, {'shift': value})
        elif op_type == 'Multiply':
            delete_node = value == 1
            Power.update_node_stat(op, {'scale': value})
        elif op_type == 'Pow':
            delete_node = value == 1
            Power.update_node_stat(op, {'power': value})

        const_port.disconnect()
        if tensor_port.idx != 0:
            tensor_port.get_connection().set_destination(op.in_port(0))
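A hedged restatement of the scalar-to-attribute mapping above, again assuming Power computes (scale * x + shift) ** power with defaults scale=1, shift=0, power=1:

import numpy as np

# Add(x, c)      -> Power(shift=c)
# Multiply(x, c) -> Power(scale=c)
# Pow(x, c)      -> Power(power=c)
x, c = np.array([2.0, 4.0]), 3.0
assert np.allclose(x + c, (1 * x + c) ** 1)   # Add
assert np.allclose(x * c, (c * x + 0) ** 1)   # Multiply
assert np.allclose(x ** c, (1 * x + 0) ** c)  # Pow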
Example #14
    def replace_pattern(self, graph: Graph, match: dict):
        node = match['op']
        if (node.data_format != b'NHWC' or len(node.in_nodes()) != 5
                or node.in_node(0).value is not None or  # input
                node.in_node(1).value is None or  # scale
                node.in_node(2).value is None or  # offset
                node.in_node(3).value is not None or  # mean
                node.in_node(4).value is not None or  # variance
                node.in_node(1).value.ndim != 1 or
                node.in_node(2).value.ndim != 1):
            return

        scale_mul = Eltwise(
            graph, dict(operation='mul', name=node.name + '/scale_mul_'))
        shift_add = Eltwise(
            graph, dict(operation='sum', name=node.name + '/shift_add_'))
        mean_add = Eltwise(
            graph, dict(operation='sum', name=node.name + '/mean_add_'))
        variance_mul = Eltwise(
            graph, dict(operation='mul', name=node.name + '/variance_mul_'))

        mean_negate = Power(graph,
                            dict(scale=-1, name=node.name + '/mean_negate_'))
        mean_arg = mean_add.create_node_with_data([
            node.in_node(0),
            mean_negate.create_node_with_data([node.in_node(3)])
        ])

        variance_square = Power(
            graph, dict(power=2, name=node.name + '/variance_square_'))
        variance_denom = Power(
            graph,
            dict(shift=node.eps,
                 power=-0.5,
                 name=node.name + '/variance_denom_'))
        variance_arg = variance_mul.create_node_with_data([
            mean_arg,
            variance_denom.create_node_with_data([node.in_node(4)])
        ])

        shift_add.create_node_with_data(
            [scale_mul.create_node_with_data([variance_arg, node.in_node(1)]),
             node.in_node(2)],
            data_nodes=node.out_node())

        node.graph.remove_node(node.id)
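A hedged numpy model of this decomposition; it matches the standard inference-time formula y = scale * (x - mean) * (variance + eps) ** -0.5 + offset. Note that the variance_square node created above does not appear to be wired into the graph in this excerpt. fused_batch_norm is a hypothetical helper:

import numpy as np

def fused_batch_norm(x, scale, offset, mean, variance, eps):
    mean_arg = x + (-1 * mean)        # mean_add(x, mean_negate(mean))
    denom = (variance + eps) ** -0.5  # variance_denom
    # variance_mul, then scale_mul, then shift_add:
    return (mean_arg * denom) * scale + offset

x = np.array([[0.0, 10.0]])
print(fused_batch_norm(x,
                       scale=np.array([2.0, 2.0]),
                       offset=np.array([1.0, 1.0]),
                       mean=np.array([5.0, 5.0]),
                       variance=np.array([4.0, 4.0]),
                       eps=0.0))  # -> [[-4.  6.]]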
Example #15
 def extract(cls, node):
     Power.update_node_stat(node, {'power': -0.5})
     return cls.enabled
Example #16
 def extract(node):
     # update the attributes of the node
     Power.update_node_stat(node, {'power': 1 / 2, 'op': SqrtExtractor.op})
     return __class__.enabled
Example #17
 def extract(node):
     Power.update_node_stat(node, {'power': -0.5})
     return __class__.enabled
Example #18
 def extract(node):
     # update the attributes of the node
     Power.update_node_stat(node, {'power': 2})
     return __class__.enabled
Example #19
 def extract(node):
     Power.update_node_stat(node, {'scale': 0})
     return __class__.enabled
Example #20
 def extract(node: Node):
     Power.update_node_stat(node)
     return __class__.enabled
Example #21
    def replace_pattern(self, graph: Graph, match: dict):
        const = 0.99
        merge = match['merge']
        digits = significant_digits()
        pnorm = Power(
            graph, {
                'name': merge.name + '/reciprocal_',
                'type': 'PNORM',
                'significant': digits[0],
                'to_significant': digits[1],
                'scale': 1,
                'shift': 0,
                'power': get_power_attr()
            }).create_node()
        merge.in_port(0).get_connection().set_destination(pnorm.in_port(0))

        in_shape = pnorm.in_port(0).data.get_shape()
        in_shape = list(in_shape)
        in_shape.insert(0, 1)

        reshape1 = Reshape(graph, {
            'name': merge.name + '/Reshape_Node1'
        }).create_node()
        reshape_dim1 = Const(graph, {
            'value': np.array(in_shape),
            'name': merge.id + '/Reshape_Dim1'
        }).create_node()
        pnorm.out_port(0).connect(reshape1.in_port(0))
        reshape1.in_port(1).connect(reshape_dim1.out_port(0))

        concat_node = Concat(
            graph, {
                'axis': 0,
                'name': merge.name + '/Concat_',
                'override_output_shape': True
            }).create_node()
        const3 = Const(graph, {
            'name': merge.name + '/const_reduce',
            'value': 0
        }).create_node()

        for ii, idx in enumerate(
                range(pnorm.significant, pnorm.to_significant + 1)):
            const_node = Const(
                graph, {
                    'value': float_array(math.pow(const, idx)),
                    'name': merge.name + '/Const_' + str(ii)
                }).create_node()

            mul_node = Mul(graph, {
                'name': merge.name + '/Mul_' + str(ii)
            }).create_node()
            const_node.out_port(0).connect(mul_node.in_port(0))

            # connect the reshaped input to the scaling Mul node
            reshape1.out_port(0).connect(mul_node.in_port(1))
            mul_node2 = Mul(graph, {
                'name': merge.name + '/Mul_Div_' + str(ii)
            }).create_node()

            const_node2 = Const(
                graph, {
                    'value': float_array(math.pow(const, -idx)),
                    'name': merge.name + '/Const_Pow_' + str(ii)
                }).create_node()
            cast_node = ExpOp(graph, {
                'name': merge.name + '/Exp_' + str(idx)
            }).create_node()

            mul_node.out_port(0).connect(cast_node.in_port(0))
            const_node2.out_port(0).connect(mul_node2.in_port(1))
            cast_node.out_port(0).connect(mul_node2.in_port(0))
            concat_node.add_input_port(ii, skip_if_exist=True)
            concat_node.in_port(ii).get_connection().set_source(
                mul_node2.out_port(0))

        in_shape = pnorm.in_port(0).data.get_shape()
        in_shape = list(in_shape)

        reducesum_node = ReduceMean(
            graph, {
                'name': merge.id + '/_pnorm_reduced_sum',
                'keep_dims': True,
                'in_ports_count': 2,
                'shape': in_shape,
                'axis': 0,
                'need_shape_inference': None,
                'infer': reduce_infer
            }).create_node()

        const3.out_port(0).connect(reducesum_node.in_port(1))
        reducesum_node.in_port(0).get_connection().set_source(
            concat_node.out_port(0))

        reshape = Reshape(graph, {
            'name': merge.name + '/Reshape_Node'
        }).create_node()
        reshape_dim = Const(graph, {
            'value': np.array(in_shape),
            'name': merge.id + '/Reshape_Dim'
        }).create_node()
        reducesum_node.out_port(0).connect(reshape.in_port(0))
        reshape.in_port(1).connect(reshape_dim.out_port(0))
        merge.out_port(0).get_connection().set_source(reshape.out_port(0))
Example #22
 def extract(cls, node):
     Power.update_node_stat(node, {'scale': 0})
     return cls.enabled
Example #23
 def extract(node: Node):
     scale = onnx_attr(node, 'scale', 'f', default=np.array(1.0), dst_type=lambda x: np.array(x))
     Power.update_node_stat(node, {'scale': scale})
     return __class__.enabled
Example #24
 def extract(cls, node):
     # update the attributes of the node
     Power.update_node_stat(node, {'power': 2})
     return cls.enabled
Example #25
    def replace_pattern(self, graph: Graph, match: dict):
        assert match['operator'].has('multiplication_transparent_ports')

        port = match['operator'].input_ports_with(match['quantized'])
        assert len(port) >= 1
        if len(port) > 1:
            log.debug(
                'BinarizeWeightsM1P1 cannot apply transformation for data {} because it is consumed more'
                ' than once'.format(match['quantized'].name))
            return

        assert len(port) == 1
        port = port[0]
        applicable = [
            pair for pair in match['operator'].multiplication_transparent_ports
            if pair[0] == port
        ]
        if len(applicable) == 0:
            return

        # Look at the 3rd and 4th inputs of Quantize -- they hold constants that should be passed through.
        # Assume that the constant that should be passed through is a scalar.
        quantize = match['quantize']
        output_low = quantize.in_node(3)
        output_high = quantize.in_node(4)

        if not output_low.has_valid('value') or not output_high.has_valid(
                'value'):
            return

        output_low = output_low.value
        output_high = output_high.value

        # This pass is applicable for binarization only. Other intX variants are not relevant.
        if quantize.levels != 2:
            return

        # Recognize two cases: 0/+1 and -1/+1.
        zp1 = np.all(output_low == 0) or np.all(output_high == 0)
        m1p1 = np.all(-output_low == output_high)
        if (not zp1 and not m1p1) or (zp1 and m1p1):
            log.debug(
                'BinarizeWeightsM1P1 cannot apply transformation for data {} because it doesn\'t have one of'
                ' the 0/+1 or -1/+1 forms.'.format(match['quantized'].name))
            return

        # Recognize scalar
        if len(np.unique(output_low)) != 1 or len(np.unique(output_high)) != 1:
            log.debug(
                'BinarizeWeightsM1P1 cannot apply transformation for data {} because output_low or output_high '
                'cannot be interpreted as scalars.'.format(
                    match['quantized'].name))
            return

        # TODO: Extract real scalar from 3rd and 4th inputs; reusing original tensors is dangerous because
        #       it may have incompatible shape.

        mult_term = quantize.in_node(3) if np.all(
            output_high == 0) else quantize.in_node(4)

        # Patch the inflow path (by dividing by mult_term)
        # Put a new Power/Mul combination here:
        #       ---->---- (here)---> data ---> [3rd/4th ports]quantize ---> quantized ---> operator

        if len(match['quantized'].out_nodes()) > 1:
            log.debug(
                'BinarizeWeightsM1P1: len(match[\'quantized\'].out_nodes()) > 1'
            )
            return
        div_op = Power(graph, {
            'name': quantize.name + '/DivNormalize',
            'power': -1.0
        })
        div_output = div_op.create_node_with_data([mult_term])

        for i in [3, 4]:
            match['quantize'].insert_node_with_data_before(
                match['quantize'].in_node(i),
                Mul,
                dict(name=quantize.name + '/MulNormalize'),
                additional_inputs=[div_output],
            )

        match['quantized'].value = None  # reset value because it will be recomputed
        match['quantize'].infer(match['quantize'])

        # Put a complementary new Mul node here:   operator -->---(here)-----> operator.out_node()

        match['operator'].insert_node_with_data_after(
            match['operator'].out_node(),
            Mul,
            dict(name=match['operator'].name + '/MulNormalize'),
            [mult_term],
        )

        # Disable 'operator' fusion with linear ops, otherwise it will annihilate changes that we just made
        match['operator']['can_be_fused'] = False
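A hedged sketch of the rebalancing idea, with a plain dot product standing in for the real 'multiplication transparent' operator: dividing the quantized input by mult_term (via the Power(power=-1) node) and multiplying the operator's output by mult_term leaves the overall result unchanged:

import numpy as np

s = 0.25                           # mult_term, assumed to be a scalar
w = np.array([0.25, -0.25, 0.25])  # quantized weights in the -s/+s form
x = np.array([1.0, 2.0, 3.0])
print(np.dot(w, x))                # original result -> 0.5
print(s * np.dot(w / s, x))        # rebalanced: same result -> 0.5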