def replace_pattern(graph: Graph, match: Dict[str, Node]):
    pow = match['inv']
    mul = match['mul']
    const = match['const']
    name = mul.soft_get('name', mul.id)

    # The dividend is the Mul input that is not the Pow node; the divisor is the
    # Pow input that is not the exponent constant.
    dividend_port = mul.in_port(0).get_source() if mul.in_port(1).get_source().node.id == pow.id \
        else mul.in_port(1).get_source()
    divider_port = pow.in_port(0).get_source() if pow.in_port(1).get_source().node.id == const.id \
        else pow.in_port(1).get_source()

    div = Div(graph, {'name': name + '/div'}).create_node()

    mul.out_port(0).get_connection().set_source(div.out_port(0))
    dividend_port.connect(div.in_port(0))
    divider_port.connect(div.in_port(1))
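# A minimal numpy sketch (assumed values, not part of the transformation) of the
# algebraic identity behind the pattern above: Mul(x, Pow(y, -1)) == Div(x, y).
import numpy as np

x, y = np.array([6.0, 9.0]), np.array([2.0, 3.0])
assert np.allclose(x * np.power(y, -1.0), x / y)   # -> [3.0, 3.0] either way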
def floor_div_replacement(floor_div: Node):
    graph = floor_div.graph
    name = floor_div.soft_get('name', floor_div.id)

    div = Div(graph, {'name': name + '/Div'}).create_node()
    floor = Floor(graph, {'name': name}).create_node()
    div.out_port(0).connect(floor.in_port(0))

    div.in_port(0).connect(floor_div.in_port(0).get_source())
    div.in_port(1).connect(floor_div.in_port(1).get_source())
    floor_div.out_port(0).get_connection().set_source(floor.out_port(0))

    graph.remove_node(floor_div.id)
    rename_node(floor, name)
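# A minimal numpy sketch (not part of the transformation) checking that the
# Floor(Div(a, b)) decomposition above matches floor division, negatives included:
import numpy as np

a = np.array([7.0, -7.0, 9.0])
b = np.array([2.0, 2.0, -4.0])
assert np.array_equal(np.floor(a / b), a // b)   # -> [3., -4., -3.]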
def replace_op(self, graph: Graph, node: Node):
    """Replace Softsign according to the formula: feature / (abs(feature) + 1)."""
    abs_node = Abs(graph, {'name': "abs_" + node.id}).create_node()
    abs_node.in_port(0).connect(node.in_port(0).get_source())

    add_node = create_op_node_with_second_input(graph, Add, np.ones([1]),
                                                {"name": node.id + "_plus_1"})
    add_node.in_port(0).connect(abs_node.out_port(0))

    div_node = Div(graph, {"name": "div_" + node.id}).create_node()
    div_node.in_port(0).connect(node.in_port(0).get_source())
    div_node.in_port(1).connect(add_node.out_port(0))

    return [div_node.id]
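# A minimal numpy sketch (not part of the transformation) verifying that the
# Abs -> Add(1) -> Div chain built above matches the Softsign definition:
import numpy as np

x = np.array([-2.0, -0.5, 0.0, 0.5, 2.0])
decomposed = x / (np.abs(x) + np.ones([1]))   # the graph built by replace_op
reference = x / (np.abs(x) + 1.0)             # Softsign: feature / (abs(feature) + 1)
assert np.allclose(decomposed, reference)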
def test_value_propagation(self, a_shape, a_value, b_shape, b_value, elem_type):
    graph = build_graph(nodes_attrs=graph_nodes_attrs,
                        edges=graph_edges,
                        update_attributes={
                            'A': {'shape': int64_array(a_shape), 'value': a_value.astype(elem_type)},
                            'A_data': {'shape': int64_array(a_shape), 'value': a_value.astype(elem_type)},
                            'B': {'shape': int64_array(b_shape), 'value': b_value.astype(elem_type)},
                            'B_data': {'shape': int64_array(b_shape), 'value': b_value.astype(elem_type)},
                        })
    node = Node(graph, 'div')
    node['infer'] = Div(graph, node.attrs()).create_node().infer
    node.infer(node)
    node_data = node.out_port(0).get_destination().data.get_value()

    def func_for_ref():
        # For integer types the reference is floor division, matching Div semantics.
        if np.issubdtype(elem_type, np.integer):
            return lambda a, b: a // b
        else:
            return lambda a, b: a / b

    ref_data = func_for_ref()(a_value, b_value)
    node_data_shape = node_data.shape
    ref_data_shape = ref_data.shape
    msg = "Value propagation for 'div' node is not correct."
    self.assertTrue(node_data_shape == ref_data_shape and np.all(node_data == ref_data), msg)
def extract(cls, node):
    Div.update_node_stat(node, {'data_type': tf_dtype_extractor(node.pb.attr["T"].type)})
    return cls.enabled
def replace_resize(graph: Graph, resize: Node):
    log.debug("Conversion of ONNX Resize-11 to Interpolate-4 is triggered for node {}.".format(
        resize.soft_get('name', resize.id)))

    input_shape = resize.in_port(0).data.get_shape()
    input_rank = len(input_shape)
    resize_name = resize.soft_get('name', resize.id)
    if input_rank not in {4, 5}:
        log.warning('The input shape is not 4D or 5D for op with name {}'.format(resize_name))
        return

    num_of_inputs = len([port for port in resize.in_ports().values() if not port.disconnected()])
    assert num_of_inputs in {3, 4}, \
        "Number of inputs of ONNXResize (with name {}) should be equal to 3 or 4".format(resize_name)

    assert resize.soft_get('coordinate_transformation_mode') != 'tf_crop_and_resize', \
        'Mode tf_crop_and_resize is not supported for op {} with name {}'.format(resize.op, resize_name)

    layout = graph.graph['layout']

    # Spatial dimensions to resize: H..W for 4D inputs, D..W for 5D inputs.
    if input_rank == 4:
        begin_dim = get_height_dim(layout, input_rank)
        end_dim = get_width_dim(layout, input_rank) + 1
    else:
        begin_dim = get_depth_dim(layout, input_rank)
        end_dim = get_width_dim(layout, input_rank) + 1

    sizes_ss = create_op_with_const_inputs(graph, StridedSlice,
                                           {1: int64_array([begin_dim]),
                                            2: int64_array([end_dim]),
                                            3: int64_array([1])},
                                           {'name': resize_name + '/StridedSlice_sizes',
                                            'begin_mask': int64_array([1]),
                                            'end_mask': int64_array([1]),
                                            'new_axis_mask': int64_array([0]),
                                            'shrink_axis_mask': int64_array([0]),
                                            'ellipsis_mask': int64_array([0])})
    scales_ss = create_op_with_const_inputs(graph, StridedSlice,
                                            {1: int64_array([begin_dim]),
                                             2: int64_array([end_dim]),
                                             3: int64_array([1])},
                                            {'name': resize_name + '/StridedSlice_scales',
                                             'begin_mask': int64_array([1]),
                                             'end_mask': int64_array([1]),
                                             'new_axis_mask': int64_array([0]),
                                             'shrink_axis_mask': int64_array([0]),
                                             'ellipsis_mask': int64_array([0])})
    axes_node = Const(graph, {'name': resize_name + '/axis',
                              'value': int64_array(np.arange(begin_dim, end_dim))}).create_node()

    shape_calculation_mode = 'scales' if num_of_inputs == 3 else 'sizes'

    interpolate_node = Interpolate(graph, {'version': 'opset4',
                                           'mode': convert_mode(resize.mode),
                                           'coordinate_transformation_mode': resize.coordinate_transformation_mode,
                                           'cube_coeff': resize.cube_coeff,
                                           'nearest_mode': resize.nearest_mode,
                                           'pads_begin': int64_array([0]),
                                           'pads_end': int64_array([0]),
                                           'antialias': 0,
                                           'shape_calculation_mode': shape_calculation_mode,
                                           'in_ports_count': 4}).create_node()

    axes_node.out_port(0).connect(interpolate_node.in_port(3))
    shape_of = Shape(graph, {'name': resize_name + '/ShapeOf'}).create_node()

    # A small epsilon guards the sizes/scales arithmetic against float rounding.
    add_node = create_op_with_const_inputs(graph, Add, {1: float_array([1.0e-5])},
                                           {'name': resize_name + '/Add'})

    input_data_type = data_type_str_to_np(graph.graph['cmd_params'].data_type)

    if num_of_inputs == 3:
        cast_shape_to_float = Cast(graph, {'dst_type': input_data_type}).create_node()
        mul_node = Mul(graph, {'name': resize_name + '/Mul'}).create_node()
        shape_of.out_port(0).connect(cast_shape_to_float.in_port(0))
        cast_shape_to_float.out_port(0).connect(mul_node.in_port(0))
        cast_add_result_to_int = Cast(graph, {'dst_type': np.int64}).create_node()
        floor_node = Floor(graph, {'name': resize_name + '/Floor'}).create_node()
        mul_node.out_port(0).connect(add_node.in_port(0))
        add_node.out_port(0).connect(floor_node.in_port(0))
        floor_node.out_port(0).connect(cast_add_result_to_int.in_port(0))
        cast_add_result_to_int.out_port(0).connect(sizes_ss.in_port(0))
        sizes_ss.out_port(0).connect(interpolate_node.in_port(1))
        scales_ss.out_port(0).connect(interpolate_node.in_port(2))

        connection_of_resize_input = resize.in_port(0).get_connection()
        connection_of_resize_input.set_destination(interpolate_node.in_port(0))

        connection_of_scales = resize.in_port(2).get_connection()
        connection_of_scales.set_destination(scales_ss.in_port(0))

        connection_of_resize_input.get_source().connect(shape_of.in_port(0))
        connection_of_scales.get_source().connect(mul_node.in_port(1))
    else:
        cast_shape_to_float = Cast(graph, {'dst_type': input_data_type}).create_node()
        cast_sizes_to_float = Cast(graph, {'dst_type': input_data_type}).create_node()
        div_node = Div(graph, {'name': resize_name + '/Div'}).create_node()
        cast_sizes_to_float.out_port(0).connect(div_node.in_port(0))
        cast_shape_to_float.out_port(0).connect(div_node.in_port(1))
        shape_of.out_port(0).connect(cast_shape_to_float.in_port(0))
        div_node.out_port(0).connect(add_node.in_port(0))
        add_node.out_port(0).connect(scales_ss.in_port(0))
        scales_ss.out_port(0).connect(interpolate_node.in_port(2))
        sizes_ss.out_port(0).connect(interpolate_node.in_port(1))

        connection_of_resize_input = resize.in_port(0).get_connection()
        connection_of_resize_input.set_destination(interpolate_node.in_port(0))

        connection_of_sizes = resize.in_port(3).get_connection()
        connection_of_sizes.set_destination(sizes_ss.in_port(0))

        connection_of_resize_input.get_source().connect(shape_of.in_port(0))
        connection_of_sizes.get_source().connect(cast_sizes_to_float.in_port(0))

    rename_nodes([(resize, resize_name + '/delete'), (interpolate_node, resize_name)])
    resize.out_port(0).get_connection().set_source(interpolate_node.out_port(0))
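# A minimal numpy sketch (assumed values) of the sizes/scales arithmetic built
# above: sizes = floor(shape * scales + eps) when only scales are provided, and
# scales = sizes / shape + eps when sizes are provided.
import numpy as np

spatial_shape = np.array([30.0, 40.0])        # spatial dims sliced out by StridedSlice
scales = np.array([2.0, 0.5])
sizes = np.floor(spatial_shape * scales + 1.0e-5).astype(np.int64)   # -> [60, 20]
recovered_scales = sizes / spatial_shape + 1.0e-5                    # -> ~[2.0, 0.5]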
def replace_sub_graph(self, graph: Graph, match: Dict[str, Node]):
    node = match['op']
    name = node.name

    min_port_tuple = (node.in_port(1).get_source().node, node.in_port(1).get_source().idx)
    max_port_tuple = (node.in_port(2).get_source().node, node.in_port(2).get_source().idx)

    node.in_port(1).disconnect()
    node.in_port(2).disconnect()

    # make sure min < max
    min_less_max = Less(graph, {'name': name + '/if_min_less_max'}).create_node(
        [min_port_tuple, max_port_tuple])
    minimum = Select(graph, {'name': name + '/minimum'}).create_node(
        [min_less_max, min_port_tuple, max_port_tuple])
    maximum = Select(graph, {'name': name + '/maximum'}).create_node(
        [min_less_max, max_port_tuple, min_port_tuple])

    # to create zero of limits data type, we multiply it by integer zero
    zero = create_op_node_with_second_input(graph, Mul, int64_array(0),
                                            {'name': name + '/zero'}, input_node=minimum)

    # if 0 < min < max: min_adj = 0 and max_adj = max - min
    min_greater_zero = Greater(graph, {'name': name + '/if_minimum_greater_zero'}).create_node(
        [minimum, zero])
    max_minus_min = Sub(graph, {'name': name + '/max_minus_min'}).create_node([maximum, minimum])
    minimum = Select(graph, {'name': name + '/first_adj_min'}).create_node(
        [min_greater_zero, zero, minimum])
    maximum = Select(graph, {'name': name + '/first_adj_max'}).create_node(
        [min_greater_zero, max_minus_min, maximum])

    # if min < max < 0: min_adj = min - max and max_adj = 0
    max_less_zero = Less(graph, {'name': name + '/if_max_less_zero'}).create_node([maximum, zero])
    min_minus_max = Sub(graph, {'name': name + '/min_minus_max'}).create_node([minimum, maximum])
    minimum = Select(graph, {'name': name + '/second_adj_min'}).create_node(
        [max_less_zero, min_minus_max, minimum])
    maximum = Select(graph, {'name': name + '/second_adj_max'}).create_node(
        [max_less_zero, zero, maximum])

    # scale = (max - min) / (2 ^ num_bits - 1)
    float_range = Sub(graph, {'name': name + '/float_range'}).create_node([maximum, minimum])
    quant_min_value, quant_max_value = int(node.narrow_range), 2 ** node.num_bits - 1
    int_range = Const(graph, dict(name=name + '/int_range',
                                  value=quant_max_value - quant_min_value)).create_node()
    scale = Div(graph, {'name': name + '/scale'}).create_node([float_range, int_range])

    # min_adj = scale * round(min / scale)
    descaled_min = Div(graph, {'name': name + '/descaled_min'}).create_node([minimum, scale])
    rounded_descaled_min = Round(graph, {'name': name + '/rounded_descaled_min'}).create_node(
        [descaled_min])
    min_adj = Mul(graph, {'name': name + '/min_adj'}).create_node([scale, rounded_descaled_min])

    # max_adj = max + min_adj - min
    adjustment = Sub(graph, {'name': name + '/limits_adjustment'}).create_node([min_adj, minimum])
    max_adj = Add(graph, {'name': name + '/max_adj'}).create_node([maximum, adjustment])

    # FakeQuantize operation has 5 inputs instead of 3 inputs in TensorFlow
    node.add_input_port(3, skip_if_exist=True)
    node.add_input_port(4, skip_if_exist=True)

    node.in_port(1).connect(min_adj.out_port(0))
    node.in_port(2).connect(max_adj.out_port(0))
    node.in_port(3).connect(min_adj.out_port(0))
    node.in_port(4).connect(max_adj.out_port(0))

    FakeQuantize.update_node_stat(node, {'levels': node['levels']})
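# A minimal numpy sketch (assumed values) of the range nudging above for the
# 0 < min < max case: the range is shifted to contain zero, then min is snapped
# onto the integer grid defined by the scale.
import numpy as np

num_bits, narrow_range = 8, False
f_min, f_max = 0.3, 6.3
f_min, f_max = (0.0, f_max - f_min) if f_min > 0 else (f_min, f_max)   # -> [0.0, 6.0]
scale = (f_max - f_min) / (2 ** num_bits - 1 - int(narrow_range))
min_adj = scale * np.round(f_min / scale)
max_adj = f_max + min_adj - f_min             # adjusted limits fed into FakeQuantize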
def find_and_replace_pattern(self, graph: Graph):
    for embedding_segments_mean in graph.get_op_nodes(op='EmbeddingSegmentsMean'):
        embedding_segments_mean_name = embedding_segments_mean.soft_get(
            'name', embedding_segments_mean.id)
        embedding_table_input = embedding_segments_mean.in_port(0)
        segment_ids_input = embedding_segments_mean.in_port(2)
        num_segments_input = embedding_segments_mean.in_port(3)

        # TODO: support EmbeddingSegmentsMean with a specified weights vector.
        # This case has not appeared in models so far, so the EmbeddingSegmentsOperation
        # fusion transformations do not handle it either.
        if embedding_segments_mean.is_in_port_connected(5):
            return

        # 1. compute the indices membership matrix, i.e. which indices belong to which object;
        # the shape of this matrix is [num_segments, num_indices]
        non_norm_range_1_to_num_segments = create_op_with_const_inputs(
            graph, Range, {0: int64_array(0), 2: int64_array(1)},
            {'name': embedding_segments_mean_name + '/Range1ToNumSegments',
             'output_type': np.int64})
        num_segments_input.get_connection().add_destination(
            non_norm_range_1_to_num_segments.in_port(1))

        range_1_to_num_segments = ConvertLike(
            graph, {'name': embedding_segments_mean_name + '/Range1ToNumSegmentsNorm'}
        ).create_node()
        range_1_to_num_segments.in_port(0).connect(non_norm_range_1_to_num_segments.out_port(0))
        num_segments_input.get_connection().add_destination(range_1_to_num_segments.in_port(1))

        unsqueeze_range_1_to_num_segments = create_op_with_const_inputs(
            graph, Unsqueeze, {1: int64_array(1)},
            {'name': embedding_segments_mean_name + '/Range1ToNumSegmentsUnsqueeze'})
        unsqueeze_range_1_to_num_segments.in_port(0).connect(range_1_to_num_segments.out_port(0))
        unsqueeze_segment_ids = create_op_with_const_inputs(
            graph, Unsqueeze, {1: int64_array(0)},
            {'name': embedding_segments_mean_name + '/SegmentIdsUnsqueeze'})
        segment_ids_input.get_connection().add_destination(unsqueeze_segment_ids.in_port(0))
        boolean_membership_matrix = Equal(
            graph, {'name': embedding_segments_mean_name + '/BooleanMembershipMatrix'}
        ).create_node()
        boolean_membership_matrix.in_port(0).connect(unsqueeze_range_1_to_num_segments.out_port(0))
        boolean_membership_matrix.in_port(1).connect(unsqueeze_segment_ids.out_port(0))
        shape_of_membership_matrix = Shape(
            graph, {'name': embedding_segments_mean_name + '/ShapeOfMembershipMatrix'}
        ).create_node([boolean_membership_matrix])
        one_scalar_constant = Const(
            graph, {'name': embedding_segments_mean_name + '/OneScalar',
                    'value': int64_array([1])}).create_node()
        one_constant = Broadcast(
            graph, {'name': embedding_segments_mean_name + '/One'}
        ).create_node([one_scalar_constant, shape_of_membership_matrix])
        zero_constant = Const(
            graph, {'name': embedding_segments_mean_name + '/Zero',
                    'value': int64_array(0)}).create_node()
        membership_matrix = Select(
            graph, {'name': embedding_segments_mean_name + '/MembershipMatrix',
                    'auto_broadcast': 'numpy'}
        ).create_node([boolean_membership_matrix, one_constant, zero_constant])

        # 2. compute the number of indices belonging to each object in the batch;
        # these are the normalization coefficients
        num_indices_per_object = create_op_with_const_inputs(
            graph, ReduceSum, {1: int64_array(1)},
            {'name': embedding_segments_mean_name + '/NumIndicesPerObject'})
        num_indices_per_object.in_port(0).connect(membership_matrix.out_port(0))

        # 3. replace a zero coefficient (an object with no indices) with one,
        # because for such an object the single default embedding vector is used
        where_zero_number = Equal(
            graph, {'name': embedding_segments_mean_name + '/WhereZeroIndicesNumber'}
        ).create_node([num_indices_per_object, zero_constant])
        normalized_num_indices_per_object = Select(
            graph, {'name': embedding_segments_mean_name + '/NormNumIndicesPerObject',
                    'auto_broadcast': 'numpy'}
        ).create_node([where_zero_number, one_scalar_constant, num_indices_per_object])

        # 4. cast normalized_num_indices_per_object to the same type as the embedding table
        norm_coefficients = ConvertLike(
            graph, {'name': embedding_segments_mean_name + '/NormCoefficients'}
        ).create_node()
        norm_coefficients.in_port(0).connect(normalized_num_indices_per_object.out_port(0))
        embedding_table_input.get_connection().add_destination(norm_coefficients.in_port(1))

        # 5. replace EmbeddingSegmentsMean with EmbeddingSegmentsSum
        embedding_segments_sum = EmbeddingSegmentsSum(
            graph, {'name': embedding_segments_mean_name + '/EmbeddingSegmentsSum'}
        ).create_node()
        for in_port in embedding_segments_mean.in_ports():
            if embedding_segments_mean.is_in_port_connected(in_port):
                embedding_segments_mean.in_port(in_port).get_connection().set_destination(
                    embedding_segments_sum.in_port(in_port))

        # 6. normalize the EmbeddingSegmentsSum results by the computed coefficients
        result_node = Div(
            graph, {'name': embedding_segments_mean_name + '/Div'}
        ).create_node([embedding_segments_sum, norm_coefficients])
        embedding_segments_mean.out_port(0).get_connection().set_source(result_node.out_port(0))

        rename_nodes([(embedding_segments_mean, embedding_segments_mean_name + '/AbandonedName'),
                      (result_node, embedding_segments_mean_name)])
        graph.remove_nodes_from([embedding_segments_mean.id])
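# A minimal numpy sketch (assumed values) of steps 1-3 and 6 above: build the
# membership matrix from segment ids, count indices per segment, replace zero
# counts with one, and divide the per-segment sums by those counts.
import numpy as np

segment_ids = np.array([0, 0, 2])             # index i belongs to segment segment_ids[i]
num_segments = 3
membership = (np.arange(num_segments)[:, None] == segment_ids[None, :]).astype(np.int64)
counts = membership.sum(axis=1)               # -> [2, 0, 1]
counts = np.where(counts == 0, 1, counts)     # an empty segment keeps the default vector
segment_sums = np.array([[4.0, 6.0], [0.0, 0.0], [3.0, 3.0]])
segment_means = segment_sums / counts[:, None]   # -> [[2., 3.], [0., 0.], [3., 3.]]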
def replace_tf_resize(graph: Graph, resize: Node, interpolation_mode: str):
    resize_name = resize.soft_get('name', resize.id)
    log.debug("Conversion of {} to Interpolate-4 is triggered for node {}.".format(
        resize.op, resize_name))

    num_of_inputs = len([port for port in resize.in_ports().values() if not port.disconnected()])
    assert num_of_inputs == 2, \
        "Number of inputs of {} (with name {}) should be equal to 2".format(resize.op, resize_name)

    attrs_msg = "If half_pixel_centers attribute of the node {} with op {} is True, " \
                "the attribute align_corners must be False"
    assert not resize.half_pixel_centers or (resize.half_pixel_centers and not resize.align_corners), \
        attrs_msg.format(resize_name, resize.op)

    shape = Shape(graph, {'name': resize_name + '/shapeof'}).create_node()

    layout = graph.graph['layout']
    height_dim = get_height_dim(layout, 4)
    width_dim = get_width_dim(layout, 4)

    ss = create_op_with_const_inputs(graph, StridedSlice,
                                     {1: int64_array([height_dim]),
                                      2: int64_array([width_dim + 1]),
                                      3: int64_array([1])},
                                     {'name': resize_name + '/StridedSlice',
                                      'begin_mask': int64_array([1]),
                                      'end_mask': int64_array([1]),
                                      'new_axis_mask': int64_array([0]),
                                      'shrink_axis_mask': int64_array([0]),
                                      'ellipsis_mask': int64_array([0])})

    div_node = Div(graph, {'name': resize_name + '/Div'}).create_node()

    shape_to_float = Cast(graph, dict(dst_type=np.float32)).create_node()
    size_to_float = Cast(graph, dict(dst_type=np.float32)).create_node()

    size_to_float.out_port(0).connect(div_node.in_port(0))
    shape_to_float.out_port(0).connect(div_node.in_port(1))
    ss.out_port(0).connect(shape_to_float.in_port(0))
    shape.out_port(0).connect(ss.in_port(0))

    align_corners = resize.align_corners
    half_pixel_centers = resize.half_pixel_centers

    nearest_mode = 'floor' if interpolation_mode == 'nearest' else 'round_prefer_floor'
    if align_corners:
        coordinate_transformation_mode = 'align_corners'
        if interpolation_mode == 'nearest':
            nearest_mode = 'round_prefer_ceil'
    elif half_pixel_centers:
        coordinate_transformation_mode = 'tf_half_pixel_for_nn' if interpolation_mode == 'nearest' \
            else 'half_pixel'
    else:
        coordinate_transformation_mode = 'asymmetric'

    interpolate4 = create_op_with_const_inputs(graph, Interpolate,
                                               {3: int64_array([height_dim, width_dim])},
                                               {'name': resize_name + '/interpolate_4',
                                                'mode': interpolation_mode,
                                                'antialias': False,
                                                'coordinate_transformation_mode': coordinate_transformation_mode,
                                                'pads_begin': int64_array([0]),
                                                'pads_end': int64_array([0]),
                                                'nearest_mode': nearest_mode,
                                                'cube_coeff': -0.75,
                                                'shape_calculation_mode': 'sizes',
                                                'version': 'opset4',
                                                'in_ports_count': 4})

    resize_input_connection = resize.in_port(0).get_connection()
    resize_input_connection.set_destination(interpolate4.in_port(0))
    resize_input_connection.get_source().connect(shape.in_port(0))

    div_node.out_port(0).connect(interpolate4.in_port(2))

    sizes_connection = resize.in_port(1).get_connection()
    sizes_connection.set_destination(interpolate4.in_port(1))
    sizes_connection.get_source().connect(size_to_float.in_port(0))

    resize.out_port(0).get_connection().set_source(interpolate4.out_port(0))
    rename_nodes([(resize, resize_name + '/delete'), (interpolate4, resize_name)])
def extract(cls, node):
    Div.update_node_stat(node)
    return cls.enabled
def extract(cls, node: Node):
    axis = onnx_attr(node, 'axis', 'i', default=None)
    Div.update_node_stat(node, {'axis': axis})
    return cls.enabled
def replace_sub_graph(self, graph: Graph, match: Dict[str, Node]):
    node = match['op']
    name = node.name

    # Zero Point Nudging : Scale counting
    f_min = node.in_port(1).get_source()
    node.in_port(1).disconnect()
    f_max = node.in_port(2).get_source()
    node.in_port(2).disconnect()

    f_diff = Sub(graph, {'name': name + '/float_range'}).create_node()
    f_max.connect(f_diff.in_port(0))
    f_min.connect(f_diff.in_port(1))

    quant_min_value = int(node.narrow_range)
    quant_max_value = 2 ** node.num_bits - 1
    i_diff = Const(graph, dict(name=name + '/int_range',
                               value=quant_max_value - quant_min_value)).create_node()

    scale = Div(graph, {'name': name + '/scale'}).create_node()
    f_diff.out_port(0).connect(scale.in_port(0))
    i_diff.out_port(0).connect(scale.in_port(1))

    # Zero Point Nudging : ZP from min counting
    descaled_min = Div(graph, {'name': name + '/descaled_min'}).create_node()
    f_min.connect(descaled_min.in_port(0))
    scale.out_port(0).connect(descaled_min.in_port(1))

    zero_point_from_min = Sub(graph, {'name': name + '/zero_point_from_min'}).create_node()
    quant_min = Const(graph, {'value': quant_min_value,
                              'name': name + '/quant_min'}).create_node()
    quant_min.out_port(0).connect(zero_point_from_min.in_port(0))
    descaled_min.out_port(0).connect(zero_point_from_min.in_port(1))

    # Zero Point Nudging : Nudged Zero Point counting
    zp_lesser_q_mi = Less(graph, {'name': name + '/zero_point_from_min_less_quant_min'}).create_node()
    zero_point_from_min.out_port(0).connect(zp_lesser_q_mi.in_port(0))
    quant_min.out_port(0).connect(zp_lesser_q_mi.in_port(1))

    zp_greater_q_ma = Greater(graph, {'name': name + '/zero_point_from_min_greater_quant_max'}).create_node()
    zero_point_from_min.out_port(0).connect(zp_greater_q_ma.in_port(0))
    quant_max = Const(graph, {'value': quant_max_value,
                              'name': name + '/quant_max'}).create_node()
    quant_max.out_port(0).connect(zp_greater_q_ma.in_port(1))

    rounded_zero_point_from_min = Round(graph, {'name': name + '/zero_point_from_min_rounding'}).create_node()
    zero_point_from_min.out_port(0).connect(rounded_zero_point_from_min.in_port(0))

    # nudged_zp = quant_min if zp < quant_min; quant_max if zp > quant_max; else round(zp)
    nudged_zero_point = Select(graph, {'name': name + '/nudging_zp_1_select_less_condition'}).create_node()
    greater_condition = Select(graph, {'name': name + '/nudging_zp_2_select_greater_condition'}).create_node()

    greater_condition.in_port(0).connect(zp_greater_q_ma.out_port(0))
    greater_condition.in_port(1).connect(quant_max.out_port(0))
    greater_condition.in_port(2).connect(rounded_zero_point_from_min.out_port(0))

    nudged_zero_point.in_port(0).connect(zp_lesser_q_mi.out_port(0))
    # the 'then' branch of the less-condition select must clamp to quant_min
    # (the original wiring connected quant_max here, which breaks the clamp)
    nudged_zero_point.in_port(1).connect(quant_min.out_port(0))
    nudged_zero_point.in_port(2).connect(greater_condition.out_port(0))

    nudged_i_min = Sub(graph, {'name': name + '/nudged_i_min'}).create_node()
    quant_min.out_port(0).connect(nudged_i_min.in_port(0))
    nudged_zero_point.out_port(0).connect(nudged_i_min.in_port(1))

    nudged_i_max = Sub(graph, {'name': name + '/nudged_i_max'}).create_node()
    quant_max.out_port(0).connect(nudged_i_max.in_port(0))
    nudged_zero_point.out_port(0).connect(nudged_i_max.in_port(1))

    nudged_min = Mul(graph, {'name': name + '/nudged_min'}).create_node()
    nudged_i_min.out_port(0).connect(nudged_min.in_port(0))
    scale.out_port(0).connect(nudged_min.in_port(1))

    nudged_max = Mul(graph, {'name': name + '/nudged_max'}).create_node()
    nudged_i_max.out_port(0).connect(nudged_max.in_port(0))
    scale.out_port(0).connect(nudged_max.in_port(1))

    nudged_min.out_port(0).connect(node.in_port(1))
    nudged_max.out_port(0).connect(node.in_port(2))

    # FakeQuantize operation has 5 inputs instead of 3 inputs in TensorFlow
    node.add_input_port(3, skip_if_exist=True)
    node.add_input_port(4, skip_if_exist=True)

    node.in_port(3).connect(nudged_min.out_port(0))
    node.in_port(4).connect(nudged_max.out_port(0))

    FakeQuantize.update_node_stat(node, {'levels': node['levels']})
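# A minimal numpy sketch (assumed values) of the zero point nudging above,
# following the TensorFlow fake-quant nudging scheme the subgraph reproduces.
import numpy as np

num_bits, narrow_range = 8, False
quant_min, quant_max = int(narrow_range), 2 ** num_bits - 1
f_min, f_max = -0.9, 5.1
scale = (f_max - f_min) / (quant_max - quant_min)
zero_point_from_min = quant_min - f_min / scale
nudged_zp = np.clip(np.round(zero_point_from_min), quant_min, quant_max)
nudged_min = (quant_min - nudged_zp) * scale   # nudged limits fed into FakeQuantize
nudged_max = (quant_max - nudged_zp) * scale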
def dequantize_data(fake_quantize: Node, dst_type: type, quantized_type: type) -> Node:
    graph = fake_quantize.graph
    quantized_data = fake_quantize.in_port(0).get_source().node
    name = fake_quantize.soft_get('name', fake_quantize.id)

    assert quantized_data.soft_get('type') == 'Convert' and quantized_data.dst_type == quantized_type, \
        "Weights aren't compressed as expected for node {}".format(
            fake_quantize.soft_get('name', fake_quantize.id))

    dequantizing_cast = Cast(graph, dict(
        name=quantized_data.name + "/to_{}".format(np_data_type_to_destination_type(dst_type)),
        dst_type=dst_type, stop_value_propagation=True)).create_node()
    fake_quantize.in_port(0).get_connection().set_destination(dequantizing_cast.in_port(0))

    # limits of dequantize
    in_low = fake_quantize.in_port(1).get_source()
    in_high = fake_quantize.in_port(2).get_source()
    out_low = fake_quantize.in_port(3).get_source()
    out_high = fake_quantize.in_port(4).get_source()

    # scale calculation
    output_range = Sub(graph, {'name': name + '/output_range'}).create_node()
    output_range.in_port(0).connect(out_high)
    output_range.in_port(1).connect(out_low)

    input_range = Sub(graph, {'name': name + '/input_range'}).create_node()
    input_range.in_port(0).connect(in_high)
    input_range.in_port(1).connect(in_low)

    scale = Div(graph, {'name': name + '/scale'}).create_node()
    scale.in_port(0).connect(output_range.out_port(0))
    scale.in_port(1).connect(input_range.out_port(0))

    # shift calculation
    descaled_output_low = Div(graph, {'name': name + '/descaled_output_low'}).create_node()
    descaled_output_low.in_port(0).connect(out_low)
    descaled_output_low.in_port(1).connect(scale.out_port(0))

    shift = Sub(graph, {'name': name + '/zero_point'}).create_node()
    shift.in_port(0).connect(in_low)
    shift.in_port(1).connect(descaled_output_low.out_port(0))

    # DeQuantize(x) == Mul(Sub(x, zero_point), scale)
    sub_zp = Sub(graph, {'name': name + '/minus_zp'}).create_node()
    sub_zp.in_port(0).connect(dequantizing_cast.out_port(0))
    sub_zp.in_port(1).connect(shift.out_port(0))

    mul_scale = Mul(graph, {'name': name + '/multiply_by_scale'}).create_node()
    mul_scale.in_port(0).connect(sub_zp.out_port(0))
    mul_scale.in_port(1).connect(scale.out_port(0))

    fake_quantize.out_port(0).get_connection().set_source(mul_scale.out_port(0))

    graph.remove_nodes_from([fake_quantize.id, fake_quantize.out_node(0)])
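# A minimal numpy sketch (assumed values) of the dequantization arithmetic above:
# scale = (out_high - out_low) / (in_high - in_low),
# zero_point = in_low - out_low / scale, DeQuantize(x) = (x - zero_point) * scale.
import numpy as np

in_low, in_high = 0.0, 255.0                  # storage (quantized) limits
out_low, out_high = -1.0, 1.0                 # original float limits
scale = (out_high - out_low) / (in_high - in_low)
zero_point = in_low - out_low / scale         # -> 127.5
q = np.array([0.0, 127.5, 255.0])
x = (q - zero_point) * scale                  # -> [-1.0, 0.0, 1.0]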
def extract(node):
    Div.update_node_stat(node)
    return __class__.enabled