Example #1
0
    def replace_pattern(self, graph: Graph, match: dict):
        """Replace the matched ``op`` node with a chain of Power/Eltwise nodes.

        The rewrite is attempted only when the node has
        ``data_format == b'NHWC'``, exactly five inputs, no known value on
        inputs 0 (input), 3 (mean) and 4 (variance), and known 1-D values on
        inputs 1 (scale) and 2 (offset). In every other case the graph is
        left untouched.

        The new nodes are wired as::

            centered   = sum(input, negate(mean))
            normalized = mul(centered, denom(variance))
            output     = sum(mul(normalized, scale), offset)

        where ``negate`` is ``Power(scale=-1)`` and ``denom`` is
        ``Power(shift=eps, power=-0.5)``. The final ``sum`` writes into the
        original node's output data node, after which the original node is
        removed from the graph.

        NOTE(review): presumably ``Power`` computes
        ``(scale * x + shift) ** power``, which would make this
        ``(x - mean) / sqrt(variance + eps) * scale + offset`` -- confirm
        against the ``Power`` op definition.

        :param graph: graph that receives the newly created nodes.
        :param match: pattern match; only the ``'op'`` entry is read.
        """
        node = match['op']

        # The input count is checked first so the per-port accesses below are
        # only evaluated for nodes with five inputs.
        if node.data_format != b'NHWC' or len(node.in_nodes()) != 5:
            return
        if (node.in_node(0).value is not None  # input must not be constant
                or node.in_node(1).value is None  # scale must be constant
                or node.in_node(2).value is None  # offset must be constant
                or node.in_node(3).value is not None  # mean must not be constant
                or node.in_node(4).value is not None  # variance must not be constant
                or node.in_node(1).value.ndim != 1
                or node.in_node(2).value.ndim != 1):
            return

        scale_mul = Eltwise(
            graph, dict(operation='mul', name=node.name + '/scale_mul_'))
        shift_add = Eltwise(
            graph, dict(operation='sum', name=node.name + '/shift_add_'))
        mean_add = Eltwise(
            graph, dict(operation='sum', name=node.name + '/mean_add_'))
        variance_mul = Eltwise(
            graph, dict(operation='mul', name=node.name + '/variance_mul_'))

        # input + (-mean)
        mean_negate = Power(graph,
                            dict(scale=-1, name=node.name + '/mean_negate_'))
        mean_arg = mean_add.create_node_with_data([
            node.in_node(0),
            mean_negate.create_node_with_data([node.in_node(3)])
        ])

        # (input - mean) * denom(variance); eps is passed as the Power shift.
        variance_denom = Power(
            graph,
            dict(shift=node.eps,
                 power=-0.5,
                 name=node.name + '/variance_denom_'))
        variance_arg = variance_mul.create_node_with_data([
            mean_arg,
            variance_denom.create_node_with_data([node.in_node(4)])
        ])

        # Apply scale, then offset, reusing the original output data node so
        # that downstream consumers stay connected.
        scaled = scale_mul.create_node_with_data(
            [variance_arg, node.in_node(1)])
        shift_add.create_node_with_data([scaled, node.in_node(2)],
                                        data_nodes=node.out_node())

        node.graph.remove_node(node.id)
    def replace_pattern(self, graph: Graph, match: dict):
        """Rewrite the matched ``minimum`` node as negate -> max -> negate.

        Both inputs are passed through ``Power(scale=-1)`` ops, combined with
        an ``Eltwise`` ``max``, and the result goes through one more
        ``Power(scale=-1)`` that writes into the original output data node.
        The original node is then removed from the graph.

        Nothing is changed when both inputs already carry known values
        (the constant propagation case).

        :param graph: graph that receives the newly created nodes.
        :param match: pattern match; only the ``'minimum'`` entry is read.
        """
        minimum = match['minimum']

        # Constant propagation case: both operands are already known.
        both_known = (minimum.in_node(0).value is not None
                      and minimum.in_node(1).value is not None)
        if both_known:
            return

        prefix = minimum.name
        neg_first_op = Power(graph, dict(scale=-1, name=prefix + '/negate1_'))
        neg_second_op = Power(graph, dict(scale=-1, name=prefix + '/negate2_'))
        max_op = Eltwise(graph, dict(operation='max', name=prefix + '/Max_'))
        neg_result_op = Power(graph,
                              dict(scale=-1, name=prefix + '/negate_out_'))

        # Build the chain bottom-up: negate each operand, take the maximum,
        # then negate the result into the original output data node.
        neg_first = neg_first_op.create_node_with_data([minimum.in_node(0)])
        neg_second = neg_second_op.create_node_with_data([minimum.in_node(1)])
        max_of_negated = max_op.create_node_with_data([neg_first, neg_second])
        neg_result_op.create_node_with_data(inputs=[max_of_negated],
                                            data_nodes=minimum.out_node())

        # Delete minimum vertex
        minimum.graph.remove_node(minimum.id)
Example #3
0
    def replace_pattern(self, graph: Graph, match: dict):
        """Move the output scale of a binarizing quantize past its consumer.

        Reads the ``'operator'``, ``'quantize'`` and ``'quantized'`` entries
        of ``match``. When every precondition below holds, the method:

        1. creates a ``Power(power=-1)`` node fed by the chosen output-range
           input of ``quantize`` (``mult_term``);
        2. inserts a ``Mul`` by that node before inputs 3 and 4 of
           ``quantize`` and re-runs the ``quantize`` infer function;
        3. inserts a ``Mul`` by ``mult_term`` after the output of
           ``operator``;
        4. sets ``operator['can_be_fused'] = False``.

        The method returns without modifying the graph when:

        * ``quantized`` feeds more than one input port of ``operator``;
        * that port is not listed as the first element of any pair in
          ``operator.multiplication_transparent_ports``;
        * inputs 3 / 4 of ``quantize`` do not both hold a valid value;
        * ``quantize.levels`` is not 2;
        * the output range is not exactly one of the 0/+1 or -1/+1 forms;
        * the output range values are not uniform (scalar-like);
        * ``quantized`` has more than one consumer.

        :param graph: graph that receives the newly created nodes.
        :param match: pattern match with the three entries listed above.
        """
        operator = match['operator']
        quantized = match['quantized']
        assert operator.has('multiplication_transparent_ports')

        port = operator.input_ports_with(quantized)
        assert len(port) >= 1
        if len(port) > 1:
            log.debug(
                'BinarizeWeightsM1P1 cannot apply transformation for data {} because it consumed more'
                ' than once'.format(quantized.name))
            return

        assert len(port) == 1
        port = port[0]
        # Only ports declared multiplication-transparent may be rescaled.
        if not any(pair[0] == port
                   for pair in operator.multiplication_transparent_ports):
            return

        # Look at inputs 3 and 4 of Quantize -- they have constants that should be passed through.
        # Assume that the constant that should be passed through is a scalar.
        quantize = match['quantize']
        output_low_node = quantize.in_node(3)
        output_high_node = quantize.in_node(4)

        # Both range inputs must be known: every check below compares their
        # values, and a missing value would make those comparisons
        # meaningless.
        if (not output_low_node.has_valid('value')
                or not output_high_node.has_valid('value')):
            return

        output_low = output_low_node.value
        output_high = output_high_node.value

        # This pass is applicable for binarization only. Other intX variants are not relevant.
        if quantize.levels != 2:
            return

        # Recognize two cases: 0/+1 and -1/+1. Exactly one of them must match.
        zp1 = np.all(output_low == 0) or np.all(output_high == 0)
        m1p1 = np.all(-output_low == output_high)
        if (not zp1 and not m1p1) or (zp1 and m1p1):
            log.debug(
                'BinarizeWeightsM1P1 cannot apply transformation for data {} because it does\'t has one of'
                ' 0/+1 or -1/+1 forms.'.format(quantized.name))
            return

        # Recognize scalar
        if len(np.unique(output_low)) != 1 or len(np.unique(output_high)) != 1:
            log.debug(
                'BinarizeWeightsM1P1 cannot apply transformation for data {} because output_low or output_high '
                'cannot be interpreted as scalars.'.format(quantized.name))
            return

        # TODO: Extract real scalar from 3rd and 4th inputs; reusing original tensors is dangerous because
        #       it may have incompatible shape.

        # Use output_low as the multiplier only when output_high is all
        # zeros; otherwise use output_high.
        mult_term = quantize.in_node(3) if np.all(
            output_high == 0) else quantize.in_node(4)

        # Patch inflow path (by dividing by mult_term)
        # Put a new Power/Mul combination here:
        #       ---->---- (here)---> data ---> [3rd/4th ports]quantize ---> quantized ---> operator

        if len(quantized.out_nodes()) > 1:
            log.debug(
                'BinarizeWeightsM1P1: len(match[\'quantized\'].out_nodes()) > 1'
            )
            return
        div_op = Power(graph, {
            'name': quantize.name + '/DivNormalize',
            'power': -1.0
        })
        div_output = div_op.create_node_with_data([mult_term])

        # in_node(i) is looked up on every iteration because the previous
        # insertion has already modified the graph.
        for i in [3, 4]:
            quantize.insert_node_with_data_before(
                quantize.in_node(i),
                Mul,
                dict(name=quantize.name + '/MulNormalize'),
                additional_inputs=[div_output],
            )

        quantized.value = None  # reset value because it will be recomputed
        quantize.infer(quantize)

        # Put a complementary new Mul node here:   operator -->---(here)-----> operator.out_node()

        operator.insert_node_with_data_after(
            operator.out_node(),
            Mul,
            dict(name=operator.name + '/MulNormalize'),
            [mult_term],
        )

        # Disable 'operator' fusion with linear ops, otherwise it will annihilate changes that we just made
        operator['can_be_fused'] = False