    @staticmethod
    def replace_pattern(graph: Graph, match: Dict[str, Node]):
        pow = match['inv']
        mul = match['mul']
        const = match['const']

        name = mul.soft_get('name', mul.id)

        dividend_port = mul.in_port(0).get_source() if mul.in_port(
            1).get_source().node.id == pow.id else mul.in_port(1).get_source()
        divisor_port = pow.in_port(0).get_source() if pow.in_port(
            1).get_source().node.id == const.id else pow.in_port(
                1).get_source()

        div = Div(graph, {'name': name + '/div'}).create_node()

        mul.out_port(0).get_connection().set_source(div.out_port(0))
        dividend_port.connect(div.in_port(0))
        divisor_port.connect(div.in_port(1))
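
The pattern above fuses Mul(x, Pow(y, -1)) into a single Div(x, y); the matched 'inv' node is presumably a Pow with a constant -1 exponent. A minimal NumPy sketch of the identity the fusion relies on (sample values only, independent of the graph machinery):

import numpy as np

x = np.array([6.0, 9.0, 12.0])
y = np.array([2.0, 3.0, 4.0])

# Mul(x, Pow(y, -1)) computes the same values as Div(x, y)
assert np.allclose(x * np.power(y, -1.0), x / y)
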
    def floor_div_replacement(floor_div: Node):
        graph = floor_div.graph
        name = floor_div.soft_get('name', floor_div.id)

        div = Div(graph, {'name': name + '/Div'}).create_node()
        floor = Floor(graph, {'name': name}).create_node()
        div.out_port(0).connect(floor.in_port(0))

        div.in_port(0).connect(floor_div.in_port(0).get_source())
        div.in_port(1).connect(floor_div.in_port(1).get_source())
        floor_div.out_port(0).get_connection().set_source(floor.out_port(0))

        graph.remove_node(floor_div.id)
        rename_node(floor, name)
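
floor_div_replacement decomposes FloorDiv into Div followed by Floor. A quick NumPy check of that identity (toy values, covering a negative dividend):

import numpy as np

a = np.array([7.0, -7.0, 9.5])
b = np.array([2.0, 2.0, 3.0])

# FloorDiv(a, b) == Floor(Div(a, b))
assert np.array_equal(np.floor(a / b), np.floor_divide(a, b))
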
Example #3
    def replace_op(self, graph: Graph, node: Node):
        """
        Replace Softsign with its decomposition: feature / (abs(feature) + 1)
        """
        abs_node = Abs(graph, {'name': "abs_" + node.id}).create_node()
        abs_node.in_port(0).connect(node.in_port(0).get_source())

        add_node = create_op_node_with_second_input(graph, Add, np.ones(
            [1]), {"name": node.id + "_plus_1"})
        add_node.in_port(0).connect(abs_node.out_port(0))
        div_node = Div(graph, {"name": "div_" + node.id}).create_node()
        div_node.in_port(0).connect(node.in_port(0).get_source())
        div_node.in_port(1).connect(add_node.out_port(0))
        return [div_node.id]
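
The replacement wires Abs -> Add(ones([1])) -> Div so that the output equals feature / (abs(feature) + 1). A small NumPy sketch of the same dataflow (sample input only):

import numpy as np

feature = np.array([-2.0, 0.0, 3.0])

abs_out = np.abs(feature)          # Abs node
add_out = abs_out + np.ones([1])   # Add node with the constant second input
softsign = feature / add_out       # Div node

assert np.allclose(softsign, feature / (np.abs(feature) + 1.0))
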
Example #4
    def test_value_propagation(self, a_shape, a_value, b_shape, b_value,
                               elem_type):
        graph = build_graph(nodes_attrs=graph_nodes_attrs,
                            edges=graph_edges,
                            update_attributes={
                                'A': {
                                    'shape': int64_array(a_shape),
                                    'value': a_value.astype(elem_type)
                                },
                                'A_data': {
                                    'shape': int64_array(a_shape),
                                    'value': a_value.astype(elem_type)
                                },
                                'B': {
                                    'shape': int64_array(b_shape),
                                    'value': b_value.astype(elem_type)
                                },
                                'B_data': {
                                    'shape': int64_array(b_shape),
                                    'value': b_value.astype(elem_type)
                                },
                            })
        node = Node(graph, 'div')
        node['infer'] = Div(graph, node.attrs()).create_node().infer
        node.infer(node)
        node_data = node.out_port(0).get_destination().data.get_value()

        def func_for_ref():
            if np.issubdtype(elem_type, np.integer):
                return lambda a, b: a // b
            else:
                return lambda a, b: a / b

        ref_data = func_for_ref()(a_value, b_value)
        node_data_shape = node_data.shape
        ref_data_shape = ref_data.shape
        msg = "Value propagation for 'div' node is not correct."
        self.assertTrue(
            node_data_shape == ref_data_shape
            and np.all(node_data == ref_data), msg)
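
The reference function mirrors Div value propagation semantics: integer inputs use floor division, floating-point inputs use true division. A standalone illustration of that dtype-dependent behavior (made-up arrays, independent of the test harness):

import numpy as np

int_a, int_b = np.array([7, -7]), np.array([2, 2])
flt_a, flt_b = int_a.astype(np.float32), int_b.astype(np.float32)

assert np.array_equal(int_a // int_b, np.array([3, -4]))  # floors toward -inf
assert np.allclose(flt_a / flt_b, np.array([3.5, -3.5]))  # true division
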
Example #5
    @classmethod
    def extract(cls, node):
        Div.update_node_stat(
            node, {'data_type': tf_dtype_extractor(node.pb.attr["T"].type)})
        return cls.enabled
Example #6
def replace_resize(graph: Graph, resize: Node):
    log.debug("Converting of ONNX Resize-11 to Interpolate-4 "
              "is triggered for node {}.".format(
                  resize.soft_get('name', resize.id)))

    input_shape = resize.in_port(0).data.get_shape()
    input_rank = len(input_shape)
    resize_name = resize.soft_get('name', resize.id)
    if input_rank not in {4, 5}:
        log.warning(
            'The input shape is not 4D or 5D for op with name {}'.format(
                resize_name))
        return

    num_of_inputs = len([
        port for port in resize.in_ports().values() if not port.disconnected()
    ])
    assert num_of_inputs in {3, 4}, \
        "Number of inputs of ONNXResize (with name {}) should be equal to 3 or 4".format(resize_name)

    assert resize.soft_get('coordinate_transformation_mode') != 'tf_crop_and_resize', \
        'Mode tf_crop_and_resize is not supported for op {} with name {}'.format(resize.op, resize_name)

    layout = graph.graph['layout']

    if input_rank == 4:
        begin_dim = get_height_dim(layout, input_rank)
        end_dim = get_width_dim(layout, input_rank) + 1
    else:
        begin_dim = get_depth_dim(layout, input_rank)
        end_dim = get_width_dim(layout, input_rank) + 1

    sizes_ss = create_op_with_const_inputs(
        graph, StridedSlice, {
            1: int64_array([begin_dim]),
            2: int64_array([end_dim]),
            3: int64_array([1])
        }, {
            'name': resize_name + '/StridedSlice_sizes',
            'begin_mask': int64_array([1]),
            'end_mask': int64_array([1]),
            'new_axis_mask': int64_array([0]),
            'shrink_axis_mask': int64_array([0]),
            'ellipsis_mask': int64_array([0])
        })
    scales_ss = create_op_with_const_inputs(
        graph, StridedSlice, {
            1: int64_array([begin_dim]),
            2: int64_array([end_dim]),
            3: int64_array([1])
        }, {
            'name': resize_name + '/StridedSlice_scales',
            'begin_mask': int64_array([1]),
            'end_mask': int64_array([1]),
            'new_axis_mask': int64_array([0]),
            'shrink_axis_mask': int64_array([0]),
            'ellipsis_mask': int64_array([0])
        })
    axes_node = Const(
        graph, {
            'name': resize_name + '/axis',
            'value': int64_array(np.arange(begin_dim, end_dim))
        }).create_node()

    shape_calculation_mode = 'scales' if num_of_inputs == 3 else 'sizes'

    interpolate_node = Interpolate(
        graph, {
            'version': 'opset4',
            'mode': convert_mode(resize.mode),
            'coordinate_transformation_mode':
            resize.coordinate_transformation_mode,
            'cube_coeff': resize.cube_coeff,
            'nearest_mode': resize.nearest_mode,
            'pads_begin': int64_array([0]),
            'pads_end': int64_array([0]),
            'antialias': 0,
            'shape_calculation_mode': shape_calculation_mode,
            'in_ports_count': 4
        }).create_node()

    axes_node.out_port(0).connect(interpolate_node.in_port(3))
    shape_of = Shape(graph, {'name': resize_name + '/ShapeOf'}).create_node()

    add_node = create_op_with_const_inputs(graph, Add,
                                           {1: float_array([1.0e-5])},
                                           {'name': resize_name + '/Add'})

    input_data_type = data_type_str_to_np(graph.graph['cmd_params'].data_type)

    if num_of_inputs == 3:
        cast_shape_to_float = Cast(graph, {
            'dst_type': input_data_type
        }).create_node()
        mul_node = Mul(graph, {'name': resize_name + '/Mul'}).create_node()
        shape_of.out_port(0).connect(cast_shape_to_float.in_port(0))
        cast_shape_to_float.out_port(0).connect(mul_node.in_port(0))
        cast_add_result_to_int = Cast(graph, {
            'dst_type': np.int64
        }).create_node()
        floor_node = Floor(graph, {
            'name': resize_name + '/Floor'
        }).create_node()
        mul_node.out_port(0).connect(add_node.in_port(0))
        add_node.out_port(0).connect(floor_node.in_port(0))
        floor_node.out_port(0).connect(cast_add_result_to_int.in_port(0))
        cast_add_result_to_int.out_port(0).connect(sizes_ss.in_port(0))
        sizes_ss.out_port(0).connect(interpolate_node.in_port(1))
        scales_ss.out_port(0).connect(interpolate_node.in_port(2))

        connection_of_resize_input = resize.in_port(0).get_connection()
        connection_of_resize_input.set_destination(interpolate_node.in_port(0))

        connection_of_scales = resize.in_port(2).get_connection()
        connection_of_scales.set_destination(scales_ss.in_port(0))

        connection_of_resize_input.get_source().connect(shape_of.in_port(0))
        connection_of_scales.get_source().connect(mul_node.in_port(1))
    else:
        cast_shape_to_float = Cast(graph, {
            'dst_type': input_data_type
        }).create_node()
        cast_sizes_to_float = Cast(graph, {
            'dst_type': input_data_type
        }).create_node()
        div_node = Div(graph, {'name': resize_name + '/Div'}).create_node()
        cast_sizes_to_float.out_port(0).connect(div_node.in_port(0))
        cast_shape_to_float.out_port(0).connect(div_node.in_port(1))
        shape_of.out_port(0).connect(cast_shape_to_float.in_port(0))
        div_node.out_port(0).connect(add_node.in_port(0))
        add_node.out_port(0).connect(scales_ss.in_port(0))
        scales_ss.out_port(0).connect(interpolate_node.in_port(2))
        sizes_ss.out_port(0).connect(interpolate_node.in_port(1))

        connection_of_resize_input = resize.in_port(0).get_connection()
        connection_of_resize_input.set_destination(interpolate_node.in_port(0))

        connection_of_sizes = resize.in_port(3).get_connection()
        connection_of_sizes.set_destination(sizes_ss.in_port(0))

        connection_of_resize_input.get_source().connect(shape_of.in_port(0))
        connection_of_sizes.get_source().connect(
            cast_sizes_to_float.in_port(0))

    rename_nodes([(resize, resize_name + '/delete'),
                  (interpolate_node, resize_name)])
    resize.out_port(0).get_connection().set_source(
        interpolate_node.out_port(0))
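
In the 3-input (scales) branch the output sizes are derived as int64(floor(shape * scales + 1.0e-5)), with the spatial dims then extracted by StridedSlice; the small epsilon guards against products that land just below an integer. A NumPy sketch of that arithmetic on an already-sliced H, W pair (made-up values):

import numpy as np

spatial_shape = np.array([224, 224], dtype=np.float32)  # H, W slice of ShapeOf
scales = np.array([2.0, 2.0], dtype=np.float32)

sizes = np.floor(spatial_shape * scales + 1.0e-5).astype(np.int64)
assert np.array_equal(sizes, np.array([448, 448]))
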
Example #7

    def replace_sub_graph(self, graph: Graph, match: Dict[str, Node]):
        node = match['op']
        name = node.name

        min_port_tuple = (node.in_port(1).get_source().node,
                          node.in_port(1).get_source().idx)
        max_port_tuple = (node.in_port(2).get_source().node,
                          node.in_port(2).get_source().idx)

        node.in_port(1).disconnect()
        node.in_port(2).disconnect()

        # make sure min < max
        min_less_max = Less(graph, {
            'name': name + '/if_min_less_max'
        }).create_node([min_port_tuple, max_port_tuple])
        minimum = Select(graph, {
            'name': name + '/minimum'
        }).create_node([min_less_max, min_port_tuple, max_port_tuple])
        maximum = Select(graph, {
            'name': name + '/maximum'
        }).create_node([min_less_max, max_port_tuple, min_port_tuple])

        # to create a zero of the limits' data type, multiply the limit by integer zero
        zero = create_op_node_with_second_input(graph,
                                                Mul,
                                                int64_array(0),
                                                {'name': name + '/zero'},
                                                input_node=minimum)

        # if 0 < min < max: min_adj = 0 and max_adj = max - min
        min_greater_zero = Greater(graph, {
            'name': name + '/if_minimum_greater_zero'
        }).create_node([minimum, zero])
        max_minus_min = Sub(graph, {
            'name': name + '/max_minus_min'
        }).create_node([maximum, minimum])
        minimum = Select(graph, {
            'name': name + '/first_adj_min'
        }).create_node([min_greater_zero, zero, minimum])
        maximum = Select(graph, {
            'name': name + '/first_adj_max'
        }).create_node([min_greater_zero, max_minus_min, maximum])

        # if min < max < 0: min_adj = min - max and max_adj = 0
        max_less_zero = Less(graph, {
            'name': name + '/if_max_less_zero'
        }).create_node([maximum, zero])
        min_minus_max = Sub(graph, {
            'name': name + '/min_minus_max'
        }).create_node([minimum, maximum])
        minimum = Select(graph, {
            'name': name + '/second_adj_min'
        }).create_node([max_less_zero, min_minus_max, minimum])
        maximum = Select(graph, {
            'name': name + '/second_adj_max'
        }).create_node([max_less_zero, zero, maximum])

        # scale = (max - min) / (2 ** num_bits - 1)
        float_range = Sub(graph, {
            'name': name + '/float_range'
        }).create_node([maximum, minimum])
        quant_min_value, quant_max_value = int(
            node.narrow_range), 2**node.num_bits - 1
        int_range = Const(
            graph,
            dict(name=name + '/int_range',
                 value=quant_max_value - quant_min_value)).create_node()
        scale = Div(graph, {
            'name': name + '/scale'
        }).create_node([float_range, int_range])
        # min_adj = scale * round(min / scale)
        descaled_min = Div(graph, {
            'name': name + '/descaled_min'
        }).create_node([minimum, scale])
        rounded_descaled_min = Round(graph, {
            'name': name + '/rounded_descaled_min'
        }).create_node([descaled_min])
        min_adj = Mul(graph, {
            'name': name + '/min_adj'
        }).create_node([scale, rounded_descaled_min])
        # max_adj = max + min_adj - min.
        adjustment = Sub(graph, {
            'name': name + '/limits_adjustment'
        }).create_node([min_adj, minimum])
        max_adj = Add(graph, {
            'name': name + '/max_adj'
        }).create_node([maximum, adjustment])

        # FakeQuantize has 5 inputs, whereas the TensorFlow operation has only 3
        node.add_input_port(3, skip_if_exist=True)
        node.add_input_port(4, skip_if_exist=True)

        node.in_port(1).connect(min_adj.out_port(0))
        node.in_port(2).connect(max_adj.out_port(0))
        node.in_port(3).connect(min_adj.out_port(0))
        node.in_port(4).connect(max_adj.out_port(0))

        FakeQuantize.update_node_stat(node, {'levels': node['levels']})
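
The subgraph implements the TensorFlow limit adjustment: scale = (max - min) / (quant_max - quant_min), min_adj = scale * round(min / scale), max_adj = max + min_adj - min. A scalar NumPy sketch of the arithmetic (num_bits, narrow_range and the float limits are assumed sample values):

import numpy as np

num_bits, narrow_range = 8, False      # assumed attributes
minimum, maximum = -1.05, 2.37         # example limits with 0 inside the range

quant_min, quant_max = int(narrow_range), 2 ** num_bits - 1
scale = (maximum - minimum) / (quant_max - quant_min)

min_adj = scale * np.round(minimum / scale)  # snap min onto the quantized grid
max_adj = maximum + min_adj - minimum        # shift max by the same amount

assert np.isclose((max_adj - min_adj) / scale, quant_max - quant_min)
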
Example #8
    def find_and_replace_pattern(self, graph: Graph):
        for embedding_segments_mean in graph.get_op_nodes(
                op='EmbeddingSegmentsMean'):
            embedding_segments_mean_name = embedding_segments_mean.soft_get(
                'name', embedding_segments_mean.id)
            embedding_table_input = embedding_segments_mean.in_port(0)
            segment_ids_input = embedding_segments_mean.in_port(2)
            num_segments_input = embedding_segments_mean.in_port(3)

            # TODO: support EmbeddingSegmentsMean with a specified weights vector;
            # this case has not appeared in models so far, so the EmbeddingSegmentsOperation
            # fusing transformations do not handle it either
            if embedding_segments_mean.is_in_port_connected(5):
                continue

            # 1. compute the indices membership matrix, i.e. which indices belong to which object;
            # the shape of this matrix is [num_segments, num_indices]
            non_norm_range_1_to_num_segments = create_op_with_const_inputs(
                graph, Range, {
                    0: int64_array(0),
                    2: int64_array(1)
                }, {
                    'name':
                    embedding_segments_mean_name + '/Range1ToNumSegments',
                    'output_type': np.int64
                })
            num_segments_input.get_connection().add_destination(
                non_norm_range_1_to_num_segments.in_port(1))

            range_1_to_num_segments = ConvertLike(graph, {
                'name':
                embedding_segments_mean_name + '/Range1ToNumSegmentsNorm'
            }).create_node()
            range_1_to_num_segments.in_port(0).connect(
                non_norm_range_1_to_num_segments.out_port(0))
            num_segments_input.get_connection().add_destination(
                range_1_to_num_segments.in_port(1))

            unsqueeze_range_1_to_num_segments = create_op_with_const_inputs(
                graph, Unsqueeze, {1: int64_array(1)}, {
                    'name':
                    embedding_segments_mean_name +
                    '/Range1ToNumSegmentsUnsqueeze'
                })
            unsqueeze_range_1_to_num_segments.in_port(0).connect(
                range_1_to_num_segments.out_port(0))
            unsqueeze_segment_ids = create_op_with_const_inputs(
                graph, Unsqueeze, {1: int64_array(0)}, {
                    'name':
                    embedding_segments_mean_name + '/SegmentIdsUnsqueeze'
                })
            segment_ids_input.get_connection().add_destination(
                unsqueeze_segment_ids.in_port(0))
            boolean_membership_matrix = Equal(graph, {
                'name':
                embedding_segments_mean_name + '/BooleanMembershipMatrix'
            }).create_node()
            boolean_membership_matrix.in_port(0).connect(
                unsqueeze_range_1_to_num_segments.out_port(0))
            boolean_membership_matrix.in_port(1).connect(
                unsqueeze_segment_ids.out_port(0))
            shape_of_membership_matrix = Shape(graph, {
                'name':
                embedding_segments_mean_name + '/ShapeOfMembershipMatrix'
            }).create_node([boolean_membership_matrix])
            one_scalar_constant = Const(
                graph, {
                    'name': embedding_segments_mean_name + '/OneScalar',
                    'value': int64_array([1])
                }).create_node()
            one_constant = Broadcast(graph, {
                'name':
                embedding_segments_mean_name + '/One'
            }).create_node([one_scalar_constant, shape_of_membership_matrix])
            zero_constant = Const(
                graph, {
                    'name': embedding_segments_mean_name + '/Zero',
                    'value': int64_array(0)
                }).create_node()
            membership_matrix = Select(
                graph, {
                    'name': embedding_segments_mean_name + '/MembershipMatrix',
                    'auto_broadcast': 'numpy'
                }).create_node(
                    [boolean_membership_matrix, one_constant, zero_constant])

            # 2. compute the number of indices belonging to each object in the batch;
            # these are the normalization coefficients
            num_indices_per_object = create_op_with_const_inputs(
                graph, ReduceSum, {1: int64_array(1)}, {
                    'name':
                    embedding_segments_mean_name + '/NumIndicesPerObject'
                })
            num_indices_per_object.in_port(0).connect(
                membership_matrix.out_port(0))

            # 3. replace a zero coefficient (no indices belong to the object) with one,
            # because for such an object the single default embedding vector is used
            where_zero_number = Equal(graph, {
                'name':
                embedding_segments_mean_name + '/WhereZeroIndicesNumber'
            }).create_node([num_indices_per_object, zero_constant])
            normalized_num_indices_per_object = Select(
                graph, {
                    'name':
                    embedding_segments_mean_name + '/NormNumIndicesPerObject',
                    'auto_broadcast': 'numpy'
                }).create_node([
                    where_zero_number, one_scalar_constant,
                    num_indices_per_object
                ])

            # 4. cast normalized_num_indices_per_object to the same type as the embedding table
            norm_coefficients = ConvertLike(
                graph, {
                    'name': embedding_segments_mean_name + '/NormCoefficients'
                }).create_node()
            norm_coefficients.in_port(0).connect(
                normalized_num_indices_per_object.out_port(0))
            embedding_table_input.get_connection().add_destination(
                norm_coefficients.in_port(1))

            # 5. replace EmbeddingSegmentsMean with EmbeddingSegmentsSum
            embedding_segments_sum = EmbeddingSegmentsSum(
                graph, {
                    'name':
                    embedding_segments_mean_name + '/EmbeddingSegmentsSum'
                }).create_node()
            for in_port in embedding_segments_mean.in_ports():
                if embedding_segments_mean.is_in_port_connected(in_port):
                    embedding_segments_mean.in_port(
                        in_port).get_connection().set_destination(
                            embedding_segments_sum.in_port(in_port))

            # 6. normalize the EmbeddingSegmentsSum results by the computed coefficients
            result_node = Div(graph, {
                'name': embedding_segments_mean_name + '/Div'
            }).create_node([embedding_segments_sum, norm_coefficients])
            embedding_segments_mean.out_port(0).get_connection().set_source(
                result_node.out_port(0))

            rename_nodes([(embedding_segments_mean,
                           embedding_segments_mean_name + '/AbandonedName'),
                          (result_node, embedding_segments_mean_name)])
            graph.remove_nodes_from([embedding_segments_mean.id])
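
The decomposition computes the mean as EmbeddingSegmentsSum divided by the per-segment index counts, with zero counts replaced by one (such segments receive the default embedding vector). A NumPy sketch of steps 1-3 and 6 on toy data:

import numpy as np

segment_ids = np.array([0, 0, 2])                 # 3 indices spread over 4 segments
num_segments = 4
summed = np.array([[2.0], [0.0], [5.0], [0.0]])   # pretend EmbeddingSegmentsSum output

# membership matrix of shape [num_segments, num_indices]
membership = (np.arange(num_segments)[:, None] == segment_ids[None, :]).astype(np.int64)
counts = membership.sum(axis=1, keepdims=True)    # indices per segment
counts = np.where(counts == 0, 1, counts)         # avoid division by zero
mean = summed / counts

assert np.allclose(mean, [[1.0], [0.0], [5.0], [0.0]])
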
Example #9
def replace_tf_resize(graph: Graph, resize: Node, interpolation_mode: str):
    resize_name = resize.soft_get('name', resize.id)
    log.debug(
        "Converting of {} to Interpolate-4 is triggered for node {}.".format(
            resize.op, resize_name))

    num_of_inputs = len([
        port for port in resize.in_ports().values() if not port.disconnected()
    ])
    assert num_of_inputs == 2, \
        "Number of inputs of {} (with name {}) should be equal to 2".format(resize.op, resize_name)

    attrs_msg = "If half_pixel_centers attribute of the node {} with op {} is True, " \
                "the attribute align_corners must be False"
    assert not resize.half_pixel_centers or (resize.half_pixel_centers and not resize.align_corners), \
        attrs_msg.format(resize_name, resize.op)

    shape = Shape(graph, {'name': resize_name + '/shapeof'}).create_node()

    layout = graph.graph['layout']
    height_dim = get_height_dim(layout, 4)
    width_dim = get_width_dim(layout, 4)

    ss = create_op_with_const_inputs(
        graph, StridedSlice, {
            1: int64_array([height_dim]),
            2: int64_array([width_dim + 1]),
            3: int64_array([1])
        }, {
            'name': resize_name + '/StridedSlice',
            'begin_mask': int64_array([1]),
            'end_mask': int64_array([1]),
            'new_axis_mask': int64_array([0]),
            'shrink_axis_mask': int64_array([0]),
            'ellipsis_mask': int64_array([0])
        })

    div_node = Div(graph, {'name': resize_name + '/Div'}).create_node()

    shape_to_float = Cast(graph, dict(dst_type=np.float32)).create_node()
    size_to_float = Cast(graph, dict(dst_type=np.float32)).create_node()

    size_to_float.out_port(0).connect(div_node.in_port(0))
    shape_to_float.out_port(0).connect(div_node.in_port(1))
    ss.out_port(0).connect(shape_to_float.in_port(0))
    shape.out_port(0).connect(ss.in_port(0))

    align_corners = resize.align_corners
    half_pixel_centers = resize.half_pixel_centers

    nearest_mode = 'floor' if interpolation_mode == 'nearest' else 'round_prefer_floor'
    if align_corners:
        coordinate_transformation_mode = 'align_corners'
        if interpolation_mode == 'nearest':
            nearest_mode = 'round_prefer_ceil'
    elif half_pixel_centers:
        coordinate_transformation_mode = 'tf_half_pixel_for_nn' if interpolation_mode == 'nearest' else 'half_pixel'
    else:
        coordinate_transformation_mode = 'asymmetric'

    interpolate4 = create_op_with_const_inputs(
        graph, Interpolate, {3: int64_array([height_dim, width_dim])}, {
            'name': resize_name + '/interpolate_4',
            'mode': interpolation_mode,
            'antialias': False,
            'coordinate_transformation_mode': coordinate_transformation_mode,
            'pads_begin': int64_array([0]),
            'pads_end': int64_array([0]),
            'nearest_mode': nearest_mode,
            'cube_coeff': -0.75,
            'shape_calculation_mode': 'sizes',
            'version': 'opset4',
            'in_ports_count': 4,
        })

    resize_input_connection = resize.in_port(0).get_connection()
    resize_input_connection.set_destination(interpolate4.in_port(0))
    resize_input_connection.get_source().connect(shape.in_port(0))

    div_node.out_port(0).connect(interpolate4.in_port(2))

    sizes_connection = resize.in_port(1).get_connection()
    sizes_connection.set_destination(interpolate4.in_port(1))
    sizes_connection.get_source().connect(size_to_float.in_port(0))

    resize.out_port(0).get_connection().set_source(interpolate4.out_port(0))
    rename_nodes([(resize, resize_name + '/delete'),
                  (interpolate4, resize_name)])
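
Here the Interpolate-4 scales are computed in the graph as float(sizes) / float(input_spatial_shape), and coordinate_transformation_mode / nearest_mode follow the TF attributes. A standalone restatement of that selection logic (same branching as above; the helper name pick_modes is ours):

def pick_modes(interpolation_mode: str, align_corners: bool, half_pixel_centers: bool):
    nearest_mode = 'floor' if interpolation_mode == 'nearest' else 'round_prefer_floor'
    if align_corners:
        if interpolation_mode == 'nearest':
            nearest_mode = 'round_prefer_ceil'
        return 'align_corners', nearest_mode
    if half_pixel_centers:
        return ('tf_half_pixel_for_nn' if interpolation_mode == 'nearest' else 'half_pixel'), nearest_mode
    return 'asymmetric', nearest_mode

assert pick_modes('nearest', True, False) == ('align_corners', 'round_prefer_ceil')
assert pick_modes('linear', False, True) == ('half_pixel', 'round_prefer_floor')
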
Example #10
    @classmethod
    def extract(cls, node):
        Div.update_node_stat(node)
        return cls.enabled
Example #11

    @classmethod
    def extract(cls, node: Node):
        axis = onnx_attr(node, 'axis', 'i', default=None)
        Div.update_node_stat(node, {'axis': axis})
        return cls.enabled
Example #12
    def replace_sub_graph(self, graph: Graph, match: Dict[str, Node]):
        node = match['op']
        name = node.name

        # Zero point nudging: scale calculation
        f_min = node.in_port(1).get_source()
        node.in_port(1).disconnect()
        f_max = node.in_port(2).get_source()
        node.in_port(2).disconnect()

        f_diff = Sub(graph, {'name': name + '/float_range'}).create_node()
        f_max.connect(f_diff.in_port(0))
        f_min.connect(f_diff.in_port(1))

        quant_min_value = int(node.narrow_range)
        quant_max_value = 2 ** node.num_bits - 1
        i_diff = Const(graph, dict(name=name + '/int_range', value=quant_max_value - quant_min_value)).create_node()

        scale = Div(graph, {'name': name + '/scale'}).create_node()
        f_diff.out_port(0).connect(scale.in_port(0))
        i_diff.out_port(0).connect(scale.in_port(1))

        # Zero point nudging: zero point from min calculation
        descaled_min = Div(graph, {'name': name + '/descaled_min'}).create_node()
        f_min.connect(descaled_min.in_port(0))
        scale.out_port(0).connect(descaled_min.in_port(1))

        zero_point_from_min = Sub(graph, {'name': name + '/zero_point_from_min'}).create_node()
        quant_min = Const(graph, {'value': quant_min_value, 'name': name + '/quant_min'}).create_node()
        quant_min.out_port(0).connect(zero_point_from_min.in_port(0))
        descaled_min.out_port(0).connect(zero_point_from_min.in_port(1))

        # Zero point nudging: nudged zero point calculation
        zp_lesser_q_mi = Less(graph, {'name': name + '/zero_point_from_min_less_quant_min'}).create_node()
        zero_point_from_min.out_port(0).connect(zp_lesser_q_mi.in_port(0))
        quant_min.out_port(0).connect(zp_lesser_q_mi.in_port(1))

        zp_greater_q_ma = Greater(graph, {'name': name + '/zero_point_from_min_greater_quant_max'}).create_node()
        zero_point_from_min.out_port(0).connect(zp_greater_q_ma.in_port(0))
        quant_max = Const(graph, {'value': quant_max_value, 'name': name + '/quant_max'}).create_node()
        quant_max.out_port(0).connect(zp_greater_q_ma.in_port(1))

        rounded_zero_point_from_min = Round(graph, {'name': name + '/zero_point_from_min_rounding'}).create_node()
        zero_point_from_min.out_port(0).connect(rounded_zero_point_from_min.in_port(0))

        nudged_zero_point = Select(graph, {'name': name + '/nudging_zp_1_select_less_condition'}).create_node()
        greater_condition = Select(graph, {'name': name + '/nudging_zp_2_select_greater_condition'}).create_node()

        greater_condition.in_port(0).connect(zp_greater_q_ma.out_port(0))
        greater_condition.in_port(1).connect(quant_max.out_port(0))
        greater_condition.in_port(2).connect(rounded_zero_point_from_min.out_port(0))

        nudged_zero_point.in_port(0).connect(zp_lesser_q_mi.out_port(0))
        nudged_zero_point.in_port(1).connect(quant_min.out_port(0))
        nudged_zero_point.in_port(2).connect(greater_condition.out_port(0))

        nudged_i_min = Sub(graph, {'name': name + '/nudged_i_min'}).create_node()
        quant_min.out_port(0).connect(nudged_i_min.in_port(0))
        nudged_zero_point.out_port(0).connect(nudged_i_min.in_port(1))

        nudged_i_max = Sub(graph, {'name': name + '/nudged_i_max'}).create_node()
        quant_max.out_port(0).connect(nudged_i_max.in_port(0))
        nudged_zero_point.out_port(0).connect(nudged_i_max.in_port(1))

        nudged_min = Mul(graph, {'name': name + '/nudged_min'}).create_node()
        nudged_i_min.out_port(0).connect(nudged_min.in_port(0))
        scale.out_port(0).connect(nudged_min.in_port(1))

        nudged_max = Mul(graph, {'name': name + '/nudged_max'}).create_node()
        nudged_i_max.out_port(0).connect(nudged_max.in_port(0))
        scale.out_port(0).connect(nudged_max.in_port(1))

        nudged_min.out_port(0).connect(node.in_port(1))
        nudged_max.out_port(0).connect(node.in_port(2))

        # FakeQuantize has 5 inputs, whereas the TensorFlow operation has only 3
        node.add_input_port(3, skip_if_exist=True)
        node.add_input_port(4, skip_if_exist=True)

        node.in_port(3).connect(nudged_min.out_port(0))
        node.in_port(4).connect(nudged_max.out_port(0))

        FakeQuantize.update_node_stat(node, {'levels': node['levels']})
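
This is the TensorFlow zero-point nudging: zero_point_from_min = quant_min - f_min / scale is clamped to [quant_min, quant_max] (rounded in between), and the nudged float limits become (quant_min - zp) * scale and (quant_max - zp) * scale, which guarantees zero lies exactly on the quantized grid. A scalar NumPy sketch (num_bits, narrow_range and the limits are assumed sample values):

import numpy as np

num_bits, narrow_range = 8, False
f_min, f_max = -1.0, 3.0

quant_min, quant_max = int(narrow_range), 2 ** num_bits - 1
scale = (f_max - f_min) / (quant_max - quant_min)

zero_point_from_min = quant_min - f_min / scale
if zero_point_from_min < quant_min:
    nudged_zp = quant_min
elif zero_point_from_min > quant_max:
    nudged_zp = quant_max
else:
    nudged_zp = np.round(zero_point_from_min)

nudged_min = (quant_min - nudged_zp) * scale
nudged_max = (quant_max - nudged_zp) * scale
assert nudged_min <= 0.0 <= nudged_max  # zero is exactly representable
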
Example #13
    def dequantize_data(fake_quantize: Node, dst_type: type, quantized_type: type) -> Node:
        graph = fake_quantize.graph
        quantized_data = fake_quantize.in_port(0).get_source().node
        name = fake_quantize.soft_get('name', fake_quantize.id)

        assert quantized_data.soft_get('type') == 'Convert' and quantized_data.dst_type == quantized_type, \
            "Weights aren't compressed as expected for node {}".format(name)

        dequantizing_cast = Cast(graph, dict(
            name=quantized_data.name + "/to_{}".format(np_data_type_to_destination_type(dst_type)),
            dst_type=dst_type, stop_value_propagation=True)).create_node()
        fake_quantize.in_port(0).get_connection().set_destination(dequantizing_cast.in_port(0))

        # limits of dequantize
        in_low = fake_quantize.in_port(1).get_source()
        in_high = fake_quantize.in_port(2).get_source()
        out_low = fake_quantize.in_port(3).get_source()
        out_high = fake_quantize.in_port(4).get_source()

        # scale calculation
        output_range = Sub(graph, {'name': name + '/output_range'}).create_node()
        output_range.in_port(0).connect(out_high)
        output_range.in_port(1).connect(out_low)

        input_range = Sub(graph, {'name': name + '/input_range'}).create_node()
        input_range.in_port(0).connect(in_high)
        input_range.in_port(1).connect(in_low)

        scale = Div(graph, {'name': name + '/scale'}).create_node()
        scale.in_port(0).connect(output_range.out_port(0))
        scale.in_port(1).connect(input_range.out_port(0))

        # shift calculation
        descaled_output_low = Div(graph, {'name': name + '/descaled_output_low'}).create_node()
        descaled_output_low.in_port(0).connect(out_low)
        descaled_output_low.in_port(1).connect(scale.out_port(0))

        shift = Sub(graph, {'name': name + '/zero_point'}).create_node()
        shift.in_port(0).connect(in_low)
        shift.in_port(1).connect(descaled_output_low.out_port(0))

        # DeQuantize(x) == Mul(Sub(x, zero_point), scale)
        sub_zp = Sub(graph, {'name': name + '/minus_zp'}).create_node()
        sub_zp.in_port(0).connect(dequantizing_cast.out_port(0))
        sub_zp.in_port(1).connect(shift.out_port(0))

        mul_scale = Mul(graph, {'name': name + '/multiply_by_scale'}).create_node()
        mul_scale.in_port(0).connect(sub_zp.out_port(0))
        mul_scale.in_port(1).connect(scale.out_port(0))

        fake_quantize.out_port(0).get_connection().set_source(mul_scale.out_port(0))

        graph.remove_nodes_from([fake_quantize.id, fake_quantize.out_node(0).id])
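
dequantize_data reconstructs float weights as (q - zero_point) * scale, where scale = (out_high - out_low) / (in_high - in_low) and zero_point = in_low - out_low / scale. A scalar NumPy sketch with made-up quantization limits:

import numpy as np

q = np.array([0.0, 127.5, 255.0])   # values after the dequantizing Cast
in_low, in_high = 0.0, 255.0        # quantized (input) range
out_low, out_high = -1.0, 1.0       # original float (output) range

scale = (out_high - out_low) / (in_high - in_low)
zero_point = in_low - out_low / scale

dequantized = (q - zero_point) * scale
assert np.allclose(dequantized, [-1.0, 0.0, 1.0])
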
Example #14
    @staticmethod
    def extract(node):
        Div.update_node_stat(node)
        return __class__.enabled