Exemplo n.º 1
0
    def replace_sub_graph(self, graph: Graph, match: dict):
        """
        Decompose the matched power-like node into elementary operations.

        The node's ``scale``, ``shift`` and ``power`` attributes are applied in
        that order, i.e. ``(x * scale + shift) ** power``. A stage is emitted
        only when its attribute differs from the neutral value (1, 0 and 1
        respectively). Consumers of the matched node are finally re-attached to
        the last emitted stage, or straight to the node's input when no stage
        was needed.
        """
        op = match['op']
        tail = op.in_port(0).get_source()

        # (attribute, neutral value, operation class, name suffix)
        stages = (
            ('scale', 1, Mul, '/mul_'),
            ('shift', 0, Add, '/add_'),
            ('power', 1, Pow, '/pow_'),
        )
        for attr_name, neutral, op_class, suffix in stages:
            if op.soft_get(attr_name, neutral) != neutral:
                operand = Const(graph, {'value': np.array(getattr(op, attr_name))}).create_node()
                stage = op_class(graph, {'name': op.name + suffix}).create_node()
                operand.out_port(0).connect(stage.in_port(1))
                tail.connect(stage.in_port(0))
                tail = stage.out_port(0)

        op.out_port(0).get_connection().set_source(tail)
    def replace_pattern(self, graph: Graph, match: dict):
        """
        Decompose a batch-normalization node whose mean and variance are not constant.

        The rewrite applies only when the node uses ``NHWC`` layout, has exactly
        five inputs, its data input (0), mean (3) and variance (4) carry no
        constant value, and its scale (1) and offset (2) are constant 1-D
        tensors. Otherwise the graph is left untouched.

        The node is replaced by the element-wise expression

            ((x + (-1) * mean) * (variance + eps) ** -0.5) * scale + offset

        whose result is written into the node's original output data node.
        The matched node itself is then removed from the graph.
        """
        node = match['op']
        if (node.data_format != b'NHWC' or len(node.in_nodes()) != 5
                or node.in_node(0).value is not None or  # input
                node.in_node(1).value is None or  # scale
                node.in_node(2).value is None or  # offset
                node.in_node(3).value is not None or  # mean
                node.in_node(4).value is not None or  # variance
                node.in_node(1).value.ndim != 1 or
                node.in_node(2).value.ndim != 1):
            return

        scale_mul = Mul(graph, dict(name=node.name + '/scale_mul_'))
        shift_add = Add(graph, dict(name=node.name + '/shift_add_'))
        mean_add = Add(graph, dict(name=node.name + '/mean_add_'))
        variance_mul = Mul(graph, dict(name=node.name + '/variance_mul_'))

        # x - mean, expressed as x + (-1) * mean.
        # The -1 constant gets its own name. It previously shared
        # '/mean_negate_' with the Mul below, which produced two graph nodes
        # with identical names.
        neg_const = Const(
            graph,
            dict(value=np.array(-1), name=node.name + '/mean_negate_const_'))
        mean_negate = Mul(graph, dict(name=node.name + '/mean_negate_'))
        mean_arg = mean_add.create_node_with_data([
            node.in_node(0),
            mean_negate.create_node_with_data(
                [node.in_node(3),
                 neg_const.create_node_with_data()])
        ])

        # (x - mean) * (variance + eps) ** -0.5
        shift_const = Const(
            graph,
            dict(value=node.eps,
                 name=node.name + '/variance_denom_shift_const_'))
        power_const = Const(
            graph,
            dict(value=-0.5, name=node.name + '/variance_denom_power_const_'))
        variance_denom_shift = Add(
            graph, dict(name=node.name + '/variance_denom_shift_'))
        variance_denom_power = Pow(
            graph, dict(name=node.name + '/variance_denom_power_'))
        variance_arg = variance_mul.create_node_with_data([
            mean_arg,
            variance_denom_power.create_node_with_data([
                variance_denom_shift.create_node_with_data(
                    [node.in_node(4),
                     shift_const.create_node_with_data()]),
                power_const.create_node_with_data()
            ])
        ])

        # ... * scale + offset, written into the original output data node.
        shift_add.create_node_with_data([
            scale_mul.create_node_with_data([variance_arg,
                                             node.in_node(1)]),
            node.in_node(2)
        ],
                                        data_nodes=node.out_node())

        node.graph.remove_node(node.id)
 def replace_op(self, graph: Graph, node: Node):
     """
     Replace ``node`` with ``Pow(input, -1.0)``, i.e. its element-wise reciprocal.

     The node's input connection is moved onto the new Pow's first port, and a
     constant exponent of -1.0 feeds the second port.

     :return: list holding the id of the Pow node that replaces ``node``.
     """
     exponent_attrs = dict(value=np.array(-1.),
                           name=node.name + '/reciprocal_pow_const_')
     exponent = Const(graph, exponent_attrs).create_node()
     inverse = Pow(graph, {'name': node.name + '/reciprocal_pow_'}).create_node()

     incoming = node.in_port(0).get_connection()
     incoming.set_destination(inverse.in_port(0))
     exponent.out_port(0).connect(inverse.in_port(1))
     return [inverse.id]
Exemplo n.º 4
0
    def replace_op(self, graph: Graph, node: Node):
        """
        Replace ``node`` with ``Pow(ReduceSum(Pow(x, 2), axis=-1), 0.5)``.

        ``x`` is the node's first input. The three new nodes are chained and
        the id of the final Pow (the square root) is returned.
        """
        source = node.in_node(0)
        two = Const(graph, {'value': np.float32(2.0)}).create_node()
        last_axis = Const(graph, {'value': np.int32(-1)}).create_node()
        half = Const(graph, {'value': np.float32(0.5)}).create_node()

        squared_attrs = dict(name=source.name + '/sq', power=2.0)
        squared = Pow(graph, squared_attrs).create_node([source, two])

        summed_attrs = dict(name=squared.name + '/sum')
        summed = ReduceSum(graph, summed_attrs).create_node([squared, last_axis])

        root_attrs = dict(name=summed.name + '/sqrt', power=0.5)
        root = Pow(graph, root_attrs).create_node([summed, half])
        return [root.id]
    def placeholder_scales(self, placeholder: Node):
        """
        Build the prior-box scale tensor from the runtime input image size.

        The returned node evaluates to
                [1 / im_width, 1 / im_height, 1 / im_width, 1 / im_height]
        and is computed as Shape(placeholder) -> StridedSlice [1:3]
        -> Cast(float32) -> Pow(., -1) -> Gather(indices [1, 0], axis 0)
        -> Concat of that pair with itself along axis 0.

        :param placeholder: input node; its ``shape`` attribute must be a
            4-element ``np.ndarray``.
        :return: the Concat node producing the scales.
        """
        graph = placeholder.graph
        name = placeholder.soft_get('name', placeholder.id)

        shape_value = placeholder.soft_get('shape', None)
        assert shape_value is not None, \
            "[ {} replacer ] Placeholder `{}` should have shape attribute".format(self.replacement_id, name)
        assert isinstance(shape_value, np.ndarray), \
            "[ {} replacer ] Placeholder `{}` shape attribute should be np.ndarray".format(self.replacement_id, name)
        assert shape_value.size == 4, \
            "[ {} replacer ] Placeholder `{}` should be 4D. Shape: {}".format(self.replacement_id, name, shape_value)

        # Runtime shape of the input image.
        image_shape = Shape(graph, {'name': 'input_image_shape'}).create_node()
        image_shape.in_port(0).connect(placeholder.out_port(0))

        # Dimensions [1:3] of the shape, expected to be (height, width).
        slice_begin = Const(graph, {'value': int64_array([1])}).create_node()
        slice_end = Const(graph, {'value': int64_array([3])}).create_node()
        slice_stride = Const(graph, {'value': int64_array([1])}).create_node()
        slice_attrs = {
            'name': name + '/get_h_w',
            'begin_mask': np.array([1]),
            'end_mask': np.array([1]),
            'new_axis_mask': np.array([0]),
            'shrink_axis_mask': np.array([0]),
            'ellipsis_mask': np.array([0]),
        }
        height_width = StridedSlice(graph, slice_attrs).create_node()
        for port_idx, producer in enumerate((image_shape, slice_begin, slice_end, slice_stride)):
            height_width.in_port(port_idx).connect(producer.out_port(0))

        # Element-wise reciprocal of the sliced dimensions.
        minus_one = Const(graph, {'value': float32_array([-1.])}).create_node()
        inverse_hw = Pow(graph, {}).create_node()
        inverse_hw.in_port(0).connect(height_width.out_port(0))
        inverse_hw.in_port(1).connect(minus_one.out_port(0))

        # Power `type_infer` requires inputs to have equal data type
        to_fp32 = Cast(graph, {'dst_type': np.float32}).create_node()
        inverse_hw.in_port(0).get_connection().insert_node(to_fp32)

        # Swap the two elements: (1 / width, 1 / height).
        swap_indices = Const(graph, {'value': int64_array([1, 0])}).create_node()
        gather_axis = Const(graph, {'value': int64_array(0)}).create_node()
        inverse_wh = Gather(graph, {}).create_node()
        inverse_wh.in_port(0).connect(inverse_hw.out_port(0))
        inverse_wh.in_port(1).connect(swap_indices.out_port(0))
        gather_axis.out_port(0).connect(inverse_wh.in_port(2))

        # Duplicate the pair along axis 0.
        priors_scale_node = Concat(graph, {'axis': 0, 'in_ports_count': 2}).create_node()
        for port_idx in (0, 1):
            priors_scale_node.add_input_port(port_idx, skip_if_exist=True)
        for port_idx in (0, 1):
            priors_scale_node.in_port(port_idx).connect(inverse_wh.out_port(0))
        return priors_scale_node
    def replace_op(self, graph: Graph, node: Node):
        """
        Replace ``node`` with ``(a + (-1) * b) ** 2``.

        ``a`` and ``b`` are the node's inputs 0 and 1.

        :return: list holding the id of the Pow node producing the result.
        """
        # Node creation (order preserved).
        minus_one = Const(
            graph, dict(value=np.array(-1),
                        name=node.name + '/negate_const_')).create_node()
        negated_b = Mul(graph, {'name': node.name + '/negate_'}).create_node()
        difference = Add(graph, {'name': node.name + '/add_'}).create_node()
        exponent = Const(graph, {'value': np.array(2)}).create_node()
        square = Pow(graph, {'name': node.name + '/squared_'}).create_node()

        # a -> difference[0]; b -> negated_b[0] * -1 -> difference[1]
        node.in_port(0).get_connection().set_destination(difference.in_port(0))
        node.in_port(1).get_connection().set_destination(negated_b.in_port(0))
        minus_one.out_port(0).connect(negated_b.in_port(1))
        negated_b.out_port(0).connect(difference.in_port(1))

        # difference ** 2
        difference.out_port(0).connect(square.in_port(0))
        exponent.out_port(0).connect(square.in_port(1))

        # Returning the bare id is shorthand for [(square.id, 0)].
        return [square.id]
    def replace_op(self, graph: Graph, node: Node):
        """
        Replace ``node`` (inputs ``a``, ``b``) with ``a * b ** -1``, i.e. ``a / b``.

        :return: list holding the id of the Mul node producing the result.
        """
        exponent = Const(graph, {'value': np.float64(-1)}).create_node()
        inverse = Pow(graph, {'name': node.name + '/reciprocal_'}).create_node()
        product = Mul(graph, {'name': node.name + '/mul_'}).create_node()

        # b -> inverse[0], with -1 as the exponent.
        divisor = node.in_port(1).get_connection()
        divisor.set_destination(inverse.in_port(0))
        exponent.out_port(0).connect(inverse.in_port(1))

        # a -> product[1], b ** -1 -> product[0].
        dividend = node.in_port(0).get_connection()
        dividend.set_destination(product.in_port(1))
        inverse.out_port(0).connect(product.in_port(0))

        # Returning the bare id is shorthand for [(product.id, 0)].
        return [product.id]
    def replace_pattern(graph: Graph, match: [str, Node]):
        """
        Rewrite the matched ``div`` node as ``in0 * in1 ** -1``.

        Input 1 is routed into a Pow with exponent -1. Input 0 goes into a Mul
        together with that Pow's output. Every consumer of the node's output 0
        is then re-attached to the Mul.
        """
        div = match['div']
        minus_one = Const(graph, {'value': np.float64(-1)}).create_node()
        inverted = Pow(graph, {'name': div.name + '/reciprocal_'}).create_node()
        product = Mul(graph, {'name': div.name + '/mul_'}).create_node()

        div.in_port(1).get_connection().set_destination(inverted.in_port(0))
        minus_one.out_port(0).connect(inverted.in_port(1))
        div.in_port(0).get_connection().set_destination(product.in_port(1))
        inverted.out_port(0).connect(product.in_port(0))

        consumers = div.out_port(0).get_connection()
        consumers.set_source(product.out_port(0))
Exemplo n.º 9
0
 def extract(cls, node):
     """
     Attach Pow attributes to ``node``, with ``data_type`` taken from its TF ``T`` attribute.

     :return: ``cls.enabled``.
     """
     tf_type = node.pb.attr["T"].type
     attrs = {'data_type': tf_dtype_extractor(tf_type)}
     Pow.update_node_stat(node, attrs)
     return cls.enabled
Exemplo n.º 10
0
    def replace_pattern(self, graph: Graph, match: dict):
        """
        Replace the matched ``merge`` node with a Pow/Mul/Cast/Concat/ReduceMean/Reshape chain.

        With ``r = x ** -1`` (``x`` being ``merge``'s input 0), the sub-graph computes

            for idx in range(merge.significant, merge.to_significant + 1):
                t_idx = Cast<float32>(10 ** idx * r) * 10 ** -idx
            out = Reshape(ReduceMean(Concat(t_*, axis=0), axis=0, keep_dims=False), [1, 5])

        and re-attaches every consumer of ``merge``'s output 0 to ``out``.

        NOTE(review): several details look suspicious and should be confirmed
        against the original intent (see the inline notes below).
        """

        merge = match['merge']
        # Reciprocal of merge's input 0: Pow(x, -1.0).
        # NOTE(review): 'type': 'PNORM' presumably overrides the operation type Pow sets; confirm intended.
        power = Pow(graph, {
            'name': merge.name + '/reciprocal_',
            'type': 'PNORM'
        }).create_node()
        # Exponent for the Pow above.
        # NOTE(review): named 'negate_const', but it is used as the exponent -1 (reciprocal), not for negation.
        const1 = Const(graph, {
            'value': -1.0,
            'name': merge.name + '/negate_const'
        }).create_node()
        merge.in_port(0).get_connection().set_destination(power.in_port(0))
        const1.out_port(0).connect(power.in_port(1))

        # Gathers one branch per idx along axis 0.
        concat_node = Concat(
            graph, {
                'axis': 0,
                'name': merge.name + '/Concat_',
                'override_output_shape': True
            }).create_node()
        # Reduction axis (0) for the ReduceMean below.
        const3 = Const(graph, {
            'name': merge.name + '/const_reduce',
            'value': 0
        }).create_node()

        # One branch per idx in [merge.significant, merge.to_significant]:
        #     Cast<float32>(10 ** idx * power) * 10 ** -idx  -> concat input ii
        for ii, idx in enumerate(
                range(merge.significant, merge.to_significant + 1, 1)):
            const_node = Const(
                graph, {
                    'value': float_array(math.pow(10.0, idx)),
                    'name': merge.name + '/Const_' + ii.__str__()
                }).create_node()

            mul_node = Mul(graph, {
                'name': merge.name + '/Mul_' + ii.__str__()
            }).create_node()
            const_node.out_port(0).connect(mul_node.in_port(0))

            power.out_port(0).connect(
                mul_node.in_port(1))  # connect to the graph node
            mul_node2 = Mul(graph, {
                'name': merge.name + '/Mul_Div_' + ii.__str__()
            }).create_node()

            const_node2 = Const(
                graph, {
                    'value': float_array(math.pow(10.0, -1 * idx)),
                    'name': merge.name + '/Const_Pow_' + ii.__str__()
                }).create_node()
            # NOTE(review): this name is built from idx, while the other per-branch nodes use ii.
            cast_node = Cast(
                graph, {
                    'name': merge.name + '/Cast_' + idx.__str__(),
                    'dst_type': np.float32
                }).create_node()

            mul_node.out_port(0).connect(cast_node.in_port(0))
            const_node2.out_port(0).connect(mul_node2.in_port(1))
            cast_node.out_port(0).connect(mul_node2.in_port(0))
            concat_node.add_input_port(ii, skip_if_exist=True)
            concat_node.in_port(ii).get_connection().set_source(
                mul_node2.out_port(0))

        # Mean over axis 0 of the concatenated branches, without keeping that axis.
        # NOTE(review): this is a ReduceMean although its name says 'reduced_sum'. Its name
        # (like reshape_dim's below) is built from merge.id rather than merge.name.
        reducesum_node = ReduceMean(
            graph, {
                'name': merge.id + '/_pnorm_reduced_sum',
                'keep_dims': False,
                'in_ports_count': 2,
                'need_shape_inference': None,
                'infer': reduce_infer
            }).create_node()

        const3.out_port(0).connect(reducesum_node.in_port(1))
        reducesum_node.in_port(0).get_connection().set_source(
            concat_node.out_port(0))

        reshape = Reshape(graph, {
            'name': merge.name + '/Reshape_Node'
        }).create_node()
        # NOTE(review): the target shape is hard-coded to [1, 5]. Presumably this assumes a fixed
        # element count; confirm it holds for every merge.significant / to_significant range.
        reshape_dim = Const(graph, {
            'value': np.array([1, 5]),
            'name': merge.id + '/Reshape_Dim'
        }).create_node()
        reducesum_node.out_port(0).connect(reshape.in_port(0))
        reshape.in_port(1).connect(reshape_dim.out_port(0))
        # Consumers of merge now read the reshaped result instead.
        merge.out_port(0).get_connection().set_source(reshape.out_port(0))
Exemplo n.º 11
0
    def replace_pattern(self, graph: Graph, match: dict):
        """
        Factor the output scale out of a binarizing FakeQuantize feeding ``operator``.

        The rewrite applies when all of the following hold:

        * ``quantized`` feeds exactly one input port of ``operator``, and that
          port is listed in ``operator.multiplication_transparent_ports``;
        * ``quantize.levels == 2``;
        * ``quantize`` inputs 3 (output_low) and 4 (output_high) are both
          constant and match exactly one of two forms: one bound all zeros,
          or ``-output_low == output_high``;
        * ``quantized`` has a single consumer.

        ``mult_term`` is input 3 when output_high is all zeros, otherwise
        input 4. When the rewrite applies:

        * inputs 3 and 4 of ``quantize`` are multiplied by ``mult_term ** -1``;
        * ``operator``'s output is multiplied by ``mult_term`` reshaped to
          ``[-1, 1, 1]`` to compensate;
        * ``operator['can_be_fused']`` is set to False.

        In every other case the graph is left unchanged.
        """
        assert match['operator'].has('multiplication_transparent_ports')

        port = match['operator'].input_ports_with(match['quantized'])
        assert len(port) >= 1
        if len(port) > 1:
            log.debug(
                'BinarizeWeightsM1P1 cannot apply transformation for data {} because it consumed more'
                ' than once'.format(match['quantized'].name))
            return

        assert len(port) == 1
        port = port[0]
        applicable = [
            pair for pair in match['operator'].multiplication_transparent_ports
            if pair[0] == port
        ]
        if not applicable:
            return

        # Look at 3-rd and 4-th inputs of FakeQuantize -- they have constants that should be passed through.
        # Assume that the constant that should be passed through is a scalar.
        quantize = match['quantize']
        output_low = quantize.in_node(3)
        output_high = quantize.in_node(4)

        quantize_name = quantize.soft_get('name', quantize.id)

        # Both bounds must be constant. The checks below negate and compare
        # both values, which fails (e.g. ``-None`` raises TypeError) or selects
        # a non-constant mult_term if either one is missing.
        if not output_low.has_valid('value') or not output_high.has_valid('value'):
            return

        output_low = output_low.value
        output_high = output_high.value

        # This pass is applicable for binarization only. Other intX variants are not relevant.
        if quantize.levels != 2:
            return

        # Recognize two cases: 0/+1 and -1/+1.
        zp1 = np.all(output_low == 0) or np.all(output_high == 0)
        m1p1 = np.all(-output_low == output_high)
        if (not zp1 and not m1p1) or (zp1 and m1p1):
            log.debug(
                'BinarizeWeightsM1P1 cannot apply transformation for data {} because it does\'t has one of'
                ' 0/+1 or -1/+1 forms.'.format(match['quantized'].name))
            return

        # TODO: Extract real scalar from 3rd and 4th inputs; reusing original tensors is dangerous because
        #       it may have incompatible shape.

        mult_term = quantize.in_node(3) if np.all(
            output_high == 0) else quantize.in_node(4)

        # Bail out before creating any node, so an early return does not leave
        # orphaned Const/Reshape nodes behind in the graph.
        if len(match['quantized'].out_nodes()) > 1:
            log.debug(
                'BinarizeWeightsM1P1: len(match[\'quantized\'].out_nodes()) > 1'
            )
            return

        new_shape = Const(
            graph, {
                'name': quantize_name + '/Reshape/Shape',
                'value': int64_array([-1, 1, 1])
            }).create_node_with_data()
        reshape = Reshape(graph, {
            'name': quantize_name + '/Reshape'
        }).create_node_with_data([mult_term, new_shape])

        # Patch inflow path (by dividing by mult_term)
        # Put a new Pow/Mul combination here:
        #       ---->---- (here)---> data ---> [3rd/4th ports]quantize ---> quantized ---> operator
        power_of_exponent = Const(graph, {
            'name': quantize_name + '/DivNormalize/Power',
            'value': np.array(-1.0)
        }).create_node_with_data()
        div_op = Pow(graph, {'name': quantize_name + '/DivNormalize'})
        div_output = div_op.create_node_with_data(
            [mult_term, power_of_exponent])

        for i in [3, 4]:
            quantize.insert_node_with_data_before(
                quantize.in_node(i),
                Mul,
                dict(name=quantize_name + '/MulNormalize'),
                additional_inputs=[div_output],
            )

        match[
            'quantized'].value = None  # reset value because it will be recomputed
        quantize.infer(quantize)

        # Put a complementary new Mul node here:   operator -->---(here)-----> operator.out_node()

        match['operator'].insert_node_with_data_after(
            match['operator'].out_node(),
            Mul,
            dict(name=match['operator'].name + '/MulNormalize'),
            [reshape],
        )

        # Disable 'operator' fusion with linear ops, otherwise it will annihilate changes that we just made
        match['operator']['can_be_fused'] = False
Exemplo n.º 12
0
 def extract(cls, node):
     """Attach the Pow operation's attributes to ``node`` and return ``cls.enabled``."""
     Pow.update_node_stat(node)
     enabled = cls.enabled
     return enabled
Exemplo n.º 13
0
 def extract(node: Node):
     """Attach the Pow operation's attributes to ``node`` and return the enclosing class's ``enabled`` flag."""
     Pow.update_node_stat(node)
     enabled = __class__.enabled
     return enabled