def find_and_replace_pattern(self, graph: Graph):
    for dequantize_node in graph.get_op_nodes(op='DequantizeLinear'):
        node_name = dequantize_node.soft_get('name', dequantize_node.id)
        axis = dequantize_node.soft_get('axis', None)
        scale_y_shape = dequantize_node.in_port(1).data.get_shape()
        model_data_type = data_type_str_to_np(graph.graph['cmd_params'].data_type)
        cast = Cast(graph, {'dst_type': model_data_type, 'name': node_name + '/Cast'}).create_node()
        dequantize_node.in_port(0).get_connection().set_destination(cast.in_port(0))
        mul = Mul(graph, {'can_be_fused': False}).create_node()

        is_second_port_connected = dequantize_node.is_in_port_connected(2)
        if is_second_port_connected:
            # it is necessary to keep the Subtract node unfused so that the pattern can be matched
            # by offline transformations; see the ConvertQuantizeDequantize transformation in nGraph
            sub = Sub(graph, {'name': node_name + '/Sub', 'zero_point_sub': True}).create_node()
            cast.out_port(0).connect(sub.in_port(0))
            dequantize_node.in_port(2).get_connection().set_destination(sub.in_port(1))
            sub.out_port(0).connect(mul.in_port(0))
        else:
            cast.out_port(0).connect(mul.in_port(0))

        dequantize_node.in_port(1).get_connection().set_destination(mul.in_port(1))
        dequantize_node.out_port(0).get_connection().set_source(mul.out_port(0))
        rename_nodes([(dequantize_node, node_name + '/TBD'), (mul, node_name)])

        assert scale_y_shape is not None
        if axis is not None and len(scale_y_shape) > 0 and scale_y_shape[0] > 1:
            input_shape = cast.in_port(0).data.get_shape()
            target_shape = np.ones(len(input_shape), np.int64)
            target_shape[axis] = input_shape[axis]

            mul_reshape = create_op_with_const_inputs(graph, Reshape, {1: int64_array(target_shape)},
                                                      {'name': node_name + '/Reshape/Mul'})
            mul.in_port(1).get_connection().set_destination(mul_reshape.in_port(0))
            mul_reshape.out_port(0).connect(mul.in_port(1))

            if is_second_port_connected:
                sub_reshape = create_op_with_const_inputs(graph, Reshape, {1: int64_array(target_shape)},
                                                          {'name': node_name + '/Reshape/Sub'})
                sub.in_port(1).get_connection().set_destination(sub_reshape.in_port(0))
                sub_reshape.out_port(0).connect(sub.in_port(1))
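
# A minimal NumPy sketch (illustration only, not used by the transformation above) of the arithmetic the
# Cast -> Sub -> Mul sub-graph realizes, i.e. the ONNX DequantizeLinear semantics y = (x - zero_point) * scale.
# The helper name and the per-tensor scale/zero_point assumption are illustrative.
def _dequantize_linear_reference(x: np.ndarray, scale: np.ndarray, zero_point: np.ndarray) -> np.ndarray:
    # cast the quantized data to float, subtract the zero point, then multiply by the scale
    return (x.astype(np.float32) - zero_point.astype(np.float32)) * scale.astype(np.float32)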
def create_ss_interval_border(graph: Graph, slice_border_port: Port, shape: np.ndarray, axes: np.ndarray,
                              node_name: str):
    """
    This function creates "begin"/"end" parameters for the StridedSlice based on Slice's "starts"/"ends"

    :param graph: graph to operate on.
    :param slice_border_port: node output port that provides "starts"/"ends" values for the Slice.
    :param shape: input shape of the Slice
    :param axes: axes that "starts" and "ends" apply to
    :param node_name: Slice node name
    :return: Concat node that forms "begin"/"end" values for the StridedSlice
    """
    # the value for 'starts' or 'ends' might be the maximum/minimum possible value of int64. Such a value
    # must be converted to the maximum/minimum of int32 because values that big do not fit into int32,
    # which is the type supported by the StridedSlice layer
    clamp = create_op_with_const_inputs(graph, Clamp,
                                        port_value_dict={1: np.iinfo(np.int32).min, 2: np.iinfo(np.int32).max},
                                        op_attrs=dict(name=node_name + '/Clamp'))
    clamp.in_port(0).connect(slice_border_port)
    # we have to convert "starts"/"ends" values from the network to one data type with the constant values
    # created here to prevent type errors in the Concat node
    cast = Cast(graph, dict(name=node_name + '/CastToI64', dst_type=np.int64)).create_node()
    cast.in_port(0).connect(clamp.out_port(0))
    concat = Concat(graph, dict(name=node_name + '/Concat', axis=0)).create_node()
    for value_idx, port_idx in enumerate(axes):
        concat.add_input_port(port_idx)
        # "axes" may not be sorted, so we need to split "starts"/"ends" values and connect each value to the correct
        # Concat input port
        value = create_op_with_const_inputs(graph, Gather,
                                            port_value_dict={1: int64_array([value_idx]), 2: int64_array(0)},
                                            op_attrs={'name': node_name + '/Gather'})
        cast.out_port(0).connect(value.in_port(0))
        value.out_port(0).connect(concat.in_port(port_idx))

    for port_idx in range(len(shape)):
        if not concat.is_in_port_connected(port_idx):
            concat.add_input_port(port_idx)
            # this border value will be ignored by the StridedSlice because of the begin_mask/end_mask
            const = Const(graph, dict(name=node_name + '/Const', value=int64_array([0]))).create_node()
            const.out_port(0).connect(concat.in_port(port_idx))
    return concat
def replace_interpolate_pattern(graph: Graph, match: dict):
    split = match['split']
    scale = float32_array([get_split_scale(split)])
    axis = int(split.in_port(1).get_connection().get_source().node.value)
    split_node_name = split.name
    axis_node = Const(graph, {'name': split_node_name + '/axis', 'value': int64_array([axis])}).create_node()

    shape_node = Shape(graph, dict(name=split_node_name + '/Shape')).create_node()
    scales_node = Const(graph, dict(name=split_node_name + '/scales', value=scale)).create_node()
    mul_node = Mul(graph, dict(name=split_node_name + '/Mul')).create_node()
    scales_node.out_port(0).connect(mul_node.in_port(1))

    strided_slice_node = create_op_with_const_inputs(graph, StridedSlice,
                                                     {1: int64_array([axis]), 2: int64_array([axis + 1])},
                                                     {'name': split_node_name + '/StridedSlice',
                                                      'begin_mask': int64_array([1]),
                                                      'end_mask': int64_array([1]),
                                                      'new_axis_mask': int64_array([0]),
                                                      'shrink_axis_mask': int64_array([0]),
                                                      'ellipsis_mask': int64_array([0])})
    shape_node.out_port(0).connect(strided_slice_node.in_port(0))

    cast_shape_to_float = Cast(graph, {'dst_type': np.float32}).create_node()
    strided_slice_node.out_port(0).connect(cast_shape_to_float.in_port(0))
    cast_shape_to_float.out_port(0).connect(mul_node.in_port(0))

    interp_node = Interpolate(graph,
                              dict(name=split_node_name + '/Interpolate',
                                   mode='nearest',
                                   antialias=0,
                                   pads_begin=int64_array([0]),
                                   pads_end=int64_array([0]),
                                   coordinate_transformation_mode='half_pixel',
                                   nearest_mode='round_prefer_floor',
                                   cube_coeff=-0.75,
                                   version='opset4',
                                   shape_calculation_mode='scales',
                                   in_ports_count=4,
                                   maybe_part_of_sequence=True)).create_node()

    floor_node = Floor(graph, {'name': split_node_name + '/Floor'}).create_node()
    cast_mul_result_to_int = Cast(graph, {'dst_type': np.int64}).create_node()

    mul_node.out_port(0).connect(floor_node.in_port(0))
    floor_node.out_port(0).connect(cast_mul_result_to_int.in_port(0))
    cast_mul_result_to_int.out_port(0).connect(interp_node.in_port(1))
    scales_node.out_port(0).connect(interp_node.in_port(2))
    axis_node.out_port(0).connect(interp_node.in_port(3))

    match['concat'].out_port(0).get_connection().set_source(interp_node.out_port(0))

    split_connection = split.in_port(0).get_connection()
    split_connection.set_destination(interp_node.in_port(0))
    split_connection.get_source().connect(shape_node.in_port(0))
def replace_sub_graph(self, graph: Graph, match: dict):
    div_sqrt = match['op']
    div_sqrt_name = div_sqrt.soft_get('name', div_sqrt.id)
    shape_node = Shape(graph, dict(name=div_sqrt_name + '/Shape')).create_node()
    data_out_port = div_sqrt.in_port(0).get_source()
    shape_node.in_port(0).connect(data_out_port)
    shape_values_node = node_to_get_shape_value_of_indices(shape_node=shape_node, indices=[-1])
    pow_node = AttributedPower(graph, dict(name=div_sqrt_name + '/Sqrt', power=mo_array(0.5))).create_node()

    # Due to the specification, Power must have inputs with the same data type.
    convert_pow_input = Cast(graph, dict(dst_type=np.float32,
                                         name=shape_values_node.name + '/ConvertToFP32')).create_node()
    div_node = Div(graph, dict(name="Div")).create_node()

    shape_values_node.out_port(0).connect(convert_pow_input.in_port(0))
    convert_pow_input.out_port(0).connect(pow_node.in_port(0))
    div_sqrt.in_port(0).get_connection().set_destination(div_node.in_port(0))
    div_node.in_port(1).connect(pow_node.out_port(0))
    div_sqrt.out_port(0).get_connection().set_source(div_node.out_port(0))

    rename_nodes([(div_sqrt, div_sqrt_name + '/ShouldBeDeleted'), (div_node, div_sqrt_name)])
def replace_op(self, graph: Graph, node: Node):
    name = node.soft_get('name', node.id)
    axis = node.soft_get('axis', 0)

    rename_node(node=node, name=name + '/to_be_removed')
    cumsum_node = create_op_node_with_second_input(graph, CumSum, int64_array(axis),
                                                   {'name': name, 'reverse': False, 'exclusive': False})
    rename_node(cumsum_node, name)

    node.in_port(0).get_connection().set_destination(cumsum_node.in_port(0))
    if node.has_valid('mx_out_type') and node['mx_out_type'] is not None:
        rename_node(node=cumsum_node, name=name + '/CumSum')
        convert = Cast(graph, {'name': name, 'dst_type': node['mx_out_type']}).create_node()
        rename_node(convert, name)
        cumsum_node.out_port(0).connect(convert.in_port(0))
        return [convert.id]
    else:
        return [cumsum_node.id]
def replace_op(self, graph: Graph, node: Node):
    if node.has_and_set('inputs_preprocessed'):
        log.debug('Node "{}" has already been preprocessed'.format(node.soft_get('name')))
        return []

    # reshape the tensor with batch indices to 2D
    unsqueeze_node = create_op_node_with_second_input(graph, Unsqueeze, int64_array([1]),
                                                      {'name': node.name + '/Unsqueeze'}, node.in_node(2))

    convert_node = Cast(graph,
                        {'name': unsqueeze_node.name + '/ToFloat',
                         'dst_type': data_type_str_to_np(graph.graph['cmd_params'].data_type)}).create_node()
    convert_node.in_port(0).connect(unsqueeze_node.out_port(0))

    concat_op = Concat(graph, {'axis': 1, 'name': node.name + '/concat_batch_indices_and_boxes',
                               'in_ports_count': 2})
    concat_node = concat_op.create_node([convert_node, node.in_node(1)])

    # do not remove the edge with crop_size because it is needed in the partial infer
    graph.remove_edge(node.in_node(1).id, node.id)

    # input to the CropAndResize contains boxes coordinates in YXYX layout. But the IE layer ROIPooling expects
    # coordinates in the XYXY layout, so a convolution is added here to swap the coordinates
    swapped_box_coordinates_node = add_convolution_to_swap_xy_coordinates(graph, concat_node, 5)

    # reshape the locations tensor to 2D so it can be passed to Eltwise, which will be converted to ScaleShift
    reshape_2d_node = create_op_node_with_second_input(graph, Reshape, int64_array([-1, 5]),
                                                       dict(name=swapped_box_coordinates_node.id + '/reshape_2d_'),
                                                       swapped_box_coordinates_node)
    graph.create_edge(reshape_2d_node, node, 0, 1)

    # do not replace any output edge
    return []
def quantize_data(fake_quantize: Node, dst_type: type, quantized_type: type, mode: str):
    graph = fake_quantize.graph
    name = fake_quantize.soft_get('name', fake_quantize.id)
    levels = fake_quantize.levels

    quantize = fake_quantize.copy_node(dict(name=name + '/Copy', stop_value_propagation=False), graph)
    fake_quantize.in_port(0).get_connection().set_destination(quantize.in_port(0))

    # inherit input limits
    fake_quantize.in_port(1).get_connection().set_destination(quantize.in_port(1))
    fake_quantize.in_port(2).get_connection().set_destination(quantize.in_port(2))

    # calculate output limits for quantized weights
    assert mode in ["signed", "unsigned"]
    i_min_value = -(levels // 2) if mode == "signed" else 0

    i_min = mo_array(i_min_value, dtype=dst_type) if not quantize.in_node(0).shape.size else \
        mo_array([i_min_value], dtype=dst_type)
    i_max = mo_array(levels + i_min - 1, dtype=dst_type)

    assert i_max - i_min == levels - 1
    out_low = Const(graph, dict(name=name + '/Copy/out_low', value=i_min)).create_node()
    out_high = Const(graph, dict(name=name + '/Copy/out_high', value=i_max)).create_node()

    out_low.out_port(0).connect(quantize.in_port(3))
    out_high.out_port(0).connect(quantize.in_port(4))
    out_low.out_port(0).connect(fake_quantize.in_port(1))
    out_high.out_port(0).connect(fake_quantize.in_port(2))

    original_const = quantize.in_port(0).get_source().node
    quantized_data_name = original_const.soft_get('name', original_const.id) + '/quantized'
    cast = Cast(graph, dict(name=quantized_data_name, dst_type=quantized_type,
                            stop_value_propagation=False)).create_node()

    quantize.out_port(0).connect(cast.in_port(0))
    cast.out_port(0).connect(fake_quantize.in_port(0))
def find_and_replace_pattern(self, graph: Graph):
    for node in graph.get_op_nodes(op='ThresholdedRelu'):
        name = node.soft_get('name', node.id)

        greater = create_op_with_const_inputs(graph, Greater, {1: float_array([node.alpha])})
        greater.in_port(0).connect(node.in_port(0).get_source())
        float_greater = Cast(graph,
                             {'dst_type': data_type_str_to_np(graph.graph['cmd_params'].data_type)}).create_node()
        greater.out_port(0).connect(float_greater.in_port(0))
        mul = Mul(graph, {}).create_node()
        node.out_port(0).get_connection().set_source(mul.out_port(0))
        mul.in_port(0).connect(node.in_port(0).get_source())
        mul.in_port(1).connect(float_greater.out_port(0))

        rename_nodes([(node, name + '/TBR'), (mul, name)])
        graph.remove_node(node.id)
def replace_resize(graph: Graph, resize: Node): log.debug("Converting of ONNX Resize-10 to Interpolate-4 " "is triggered for node {}.".format( resize.soft_get('name', resize.id))) resize_name = resize.soft_get('name', resize.id) rank_node = Rank(graph, {'name': resize_name + '/max_axes'}).create_node() range_node = create_op_with_const_inputs(graph, Range, { 0: int64_array(2), 2: int64_array(1) }, {'name': resize_name + '/axes'}) sizes_ss = create_op_with_const_inputs(graph, StridedSlice, { 1: int64_array([2]), 2: int64_array([0]), 3: int64_array([1]) }, { 'name': resize_name + '/sizes_ss', 'begin_mask': int64_array([1]), 'end_mask': int64_array([0]), 'new_axis_mask': int64_array([0]), 'shrink_axis_mask': int64_array([0]), 'ellipsis_mask': int64_array([0]) }) scales_ss = create_op_with_const_inputs( graph, StridedSlice, { 1: int64_array([2]), 2: int64_array([0]), 3: int64_array([1]) }, { 'name': resize_name + '/scales_ss', 'begin_mask': int64_array([1]), 'end_mask': int64_array([0]), 'new_axis_mask': int64_array([0]), 'shrink_axis_mask': int64_array([0]), 'ellipsis_mask': int64_array([0]) }) rank_node.out_port(0).connect(range_node.in_port(1)) interpolate_node = Interpolate( graph, { 'version': 'opset4', 'mode': 'linear_onnx' if resize.mode == 'linear' else 'nearest', 'coordinate_transformation_mode': 'asymmetric', 'cube_coeff': -0.75, 'nearest_mode': 'simple', 'pads_begin': int64_array([0]), 'pads_end': int64_array([0]), 'antialias': 0, 'shape_calculation_mode': 'scales', 'in_ports_count': 4 }).create_node() range_node.out_port(0).connect(interpolate_node.in_port(3)) shape_of = Shape(graph, {'name': resize_name + '/ShapeOf'}).create_node() # When we calculate 'sizes' input as floor(input_shape * scales), we can get incorrect 'sizes' if, e.g., # scales = [1.0, 1.0, 1.33333, 2.0], input_shape = [1, 3, 30, 200], because # input_shape * scales = [1, 3, 39.9999, 400], and floor(input_shape * scales)[2] == 39, not 40. # Maybe we need to calculate 'sizes' input as floor(input_shape * scales + eps), where eps is some small # floating point number, e.g. 1.0e-5. But, in this case, if scales = [1.0, 1.0, 1.333333, 2.0], # input_shape = [1, 3, 30, 200], floor(input_shape * scales + eps) = 39, not 40, because # input_shape[2] * scales[2] + 1.0e-5 = 39.99991. # Hence, we need to calculate 'sizes' as floor(input_shape * (scales + eps)). 
add_node = create_op_with_const_inputs(graph, Add, {1: float_array([1.0e-5])}, {'name': resize_name + '/Add'}) dst_dtype = np.float32 # even if data_type=FP16 use float32 for shape values cast_shape_to_float = Cast(graph, {'dst_type': dst_dtype}).create_node() shape_of.out_port(0).connect(cast_shape_to_float.in_port(0)) mul_node = Mul(graph, { 'name': resize_name + '/Mul' }).create_node([cast_shape_to_float, add_node]) floor_node = Floor(graph, { 'name': resize_name + '/Floor' }).create_node([mul_node]) cast_mul_result_to_int = Cast(graph, { 'dst_type': np.int64 }).create_node([floor_node]) cast_mul_result_to_int.out_port(0).connect(sizes_ss.in_port(0)) sizes_ss.out_port(0).connect(interpolate_node.in_port(1)) scales_ss.out_port(0).connect(interpolate_node.in_port(2)) connection_of_resize_input = resize.in_port(0).get_connection() connection_of_resize_input.set_destination(interpolate_node.in_port(0)) connection_of_scales = resize.in_port(1).get_connection() connection_of_scales.set_destination(scales_ss.in_port(0)) connection_of_resize_input.get_source().connect(shape_of.in_port(0)) connection_of_resize_input.get_source().connect(rank_node.in_port(0)) connection_of_scales.get_source().connect(add_node.in_port(0)) rename_nodes([(resize, resize_name + '/delete'), (interpolate_node, resize_name)]) resize.out_port(0).get_connection().set_source( interpolate_node.out_port(0))
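
# A minimal NumPy sketch (illustration only, not used by the transformation above) of the 'sizes' calculation
# motivated by the comment above, using its example values; the helper name is illustrative.
def _sizes_from_scales_reference(input_shape: np.ndarray, scales: np.ndarray, eps: float = 1.0e-5) -> np.ndarray:
    # with input_shape = [1, 3, 30, 200] and scales = [1.0, 1.0, 1.33333, 2.0]:
    #   np.floor(input_shape * scales)         -> [1, 3, 39, 400]  (one output pixel lost)
    #   np.floor(input_shape * (scales + eps)) -> [1, 3, 40, 400]
    return np.floor(input_shape * (scales + eps)).astype(np.int64)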
def replace_tf_resize(graph: Graph, resize: Node, interpolation_mode: str):
    resize_name = resize.soft_get('name', resize.id)
    log.debug("Converting of {} to Interpolate-4 is triggered for node {}.".format(resize.op, resize_name))

    num_of_inputs = len([port for port in resize.in_ports().values() if not port.disconnected()])
    assert num_of_inputs == 2, \
        "Number of inputs of {} (with name {}) should be equal to 2".format(resize.op, resize_name)

    attrs_msg = "If half_pixel_centers attribute of the node {} with op {} is True, " \
                "the attribute align_corners must be False"
    assert not resize.half_pixel_centers or (resize.half_pixel_centers and not resize.align_corners), \
        attrs_msg.format(resize_name, resize.op)

    shape = Shape(graph, {'name': resize_name + '/shapeof'}).create_node()

    ss = create_op_with_const_inputs(graph, StridedSlice,
                                     {1: int64_array([1]), 2: int64_array([3]), 3: int64_array([1])},
                                     {'name': resize_name + '/StridedSlice',
                                      'begin_mask': int64_array([1]),
                                      'end_mask': int64_array([1]),
                                      'new_axis_mask': int64_array([0]),
                                      'shrink_axis_mask': int64_array([0]),
                                      'ellipsis_mask': int64_array([0])})

    div_node = Div(graph, {'name': resize_name + '/Div'}).create_node()

    shape_to_float = Cast(graph, dict(dst_type=np.float32)).create_node()
    size_to_float = Cast(graph, dict(dst_type=np.float32)).create_node()

    size_to_float.out_port(0).connect(div_node.in_port(0))
    shape_to_float.out_port(0).connect(div_node.in_port(1))
    ss.out_port(0).connect(shape_to_float.in_port(0))
    shape.out_port(0).connect(ss.in_port(0))

    align_corners = resize.align_corners
    half_pixel_centers = resize.half_pixel_centers

    nearest_mode = 'floor' if interpolation_mode == 'nearest' else 'round_prefer_floor'
    if align_corners:
        coordinate_transformation_mode = 'align_corners'
        if interpolation_mode == 'nearest':
            nearest_mode = 'round_prefer_ceil'
    elif half_pixel_centers:
        coordinate_transformation_mode = 'tf_half_pixel_for_nn' if interpolation_mode == 'nearest' else 'half_pixel'
    else:
        coordinate_transformation_mode = 'asymmetric'

    interpolate4 = create_op_with_const_inputs(graph, Interpolate,
                                               {3: int64_array([1, 2])},
                                               {'name': resize_name + '/interpolate_4',
                                                'mode': interpolation_mode,
                                                'antialias': False,
                                                'coordinate_transformation_mode': coordinate_transformation_mode,
                                                'pads_begin': int64_array([0]),
                                                'pads_end': int64_array([0]),
                                                'nearest_mode': nearest_mode,
                                                'cube_coeff': -0.75,
                                                'shape_calculation_mode': 'sizes',
                                                'version': 'opset4',
                                                'in_ports_count': 4})

    resize_input_connection = resize.in_port(0).get_connection()
    resize_input_connection.set_destination(interpolate4.in_port(0))
    resize_input_connection.get_source().connect(shape.in_port(0))

    div_node.out_port(0).connect(interpolate4.in_port(2))

    sizes_connection = resize.in_port(1).get_connection()
    sizes_connection.set_destination(interpolate4.in_port(1))
    sizes_connection.get_source().connect(size_to_float.in_port(0))

    resize.out_port(0).get_connection().set_source(interpolate4.out_port(0))
    rename_nodes([(resize, resize_name + '/delete'), (interpolate4, resize_name)])
def replace_sub_graph(self, graph: Graph, match: dict):
    node = match['op']
    name = node.soft_get('name', node.id)
    axis = node.axis
    input_shape_node = Shape(graph, {'name': name + '/ShapeOf'}).create_node()
    range_node = create_op_with_const_inputs(graph, Range,
                                             {0: mo_array(node.start), 2: mo_array(node.step)},
                                             {'name': name + '/Range'})
    node.in_port(0).get_connection().set_destination(input_shape_node.in_port(0))

    if axis is not None:
        '''
        Replace the arange_like op with the sub-graph:
        Shape - Gather - Range
        '''
        gather_node = create_op_with_const_inputs(graph, Gather,
                                                  {1: int64_array([axis]), 2: int64_array(0)},
                                                  {'name': name + '/Gather'})
        input_shape_node.out_port(0).connect(gather_node.in_port(0))
        gather_node.out_port(0).connect(range_node.in_port(1))
        node.out_port(0).get_connection().set_source(range_node.out_port(0))
        rename_nodes([(node, name + '/ShouldBeDeleted'), (range_node, name)])
    else:
        r'''
        Replace the arange_like op with the sub-graph:
                |
             ShapeOf ----------- |
                |                |
             ReduceProd          |
                |                |
              Range              |
                |                |
             Reshape ----------- |
                |
        '''
        flattened_shape_node = create_op_with_const_inputs(graph, ReduceProd, {1: int64_array([0])},
                                                           {'name': input_shape_node.name + '/ReduceProd',
                                                            'keep_dims': True})
        reshape_backward_node = Reshape(graph, {'name': name + '/Reshape_backward'}).create_node()

        input_shape_node.out_port(0).connect(flattened_shape_node.in_port(0))
        flattened_shape_node.out_port(0).connect(range_node.in_port(1))
        range_node.out_port(0).connect(reshape_backward_node.in_port(0))
        input_shape_node.out_port(0).connect(reshape_backward_node.in_port(1))
        node.out_port(0).get_connection().set_source(reshape_backward_node.out_port(0))
        rename_nodes([(node, name + '/ShouldBeDeleted'), (reshape_backward_node, name)])

    if node.repeat != 1:
        r"""
        First, we generate the correct stop value for Range, like new_stop_value = stop_value // repeat + 1.
        Then each value of the interval is repeated using Tile. After that we can get a longer interval,
        so we reduce it with Slice.

        The sub-graph after the Range node will look like:
        Range - Reshape([-1, 1]) - Tile([1, repeat]) - Reshape(-1) - Slice
        """
        if node.repeat < 1:
            raise Error("Unexpected value {} of the attribute 'repeat' for the node {}".format(node.repeat, name))

        div_node = create_op_with_const_inputs(graph, Div, {1: int64_array([node.repeat])},
                                               {'name': name + '/Divide'})
        add_node = create_op_with_const_inputs(graph, Add, {1: int64_array([1])},
                                               {'name': div_node.name + '/Add'})
        cast_node = Cast(graph, {'name': name + '/ConvertToI64', 'dst_type': np.int64}).create_node()

        cast_node.out_port(0).connect(div_node.in_port(0))
        div_node.out_port(0).connect(add_node.in_port(0))
        range_node.in_port(1).get_connection().set_destination(cast_node.in_port(0))
        add_node.out_port(0).connect(range_node.in_port(1))

        tile_forward_reshape = create_op_with_const_inputs(graph, Reshape, {1: int64_array([-1, 1])},
                                                           {'name': range_node.name + '/ForwardReshape'})
        tile = create_op_with_const_inputs(graph, Tile, {1: int64_array([1, node.repeat])},
                                           {'name': tile_forward_reshape.name + '/Tile'})
        tile_backward_reshape = create_op_with_const_inputs(graph, Reshape, {1: int64_array([-1])},
                                                            {'name': tile.name + '/BackwardReshape'})
        slice_node = create_op_with_const_inputs(graph, Slice,
                                                 {1: int64_array([0]), 3: int64_array([0]), 4: int64_array([1])},
                                                 {'name': tile_backward_reshape.name + '/Slice'})

        tile_forward_reshape.out_port(0).connect(tile.in_port(0))
        tile.out_port(0).connect(tile_backward_reshape.in_port(0))
        tile_backward_reshape.out_port(0).connect(slice_node.in_port(0))
        slice_node.in_port(2).connect(div_node.in_port(0).get_source())

        range_node.out_port(0).get_connection().set_source(slice_node.out_port(0))
        range_node.out_port(0).connect(tile_forward_reshape.in_port(0))

        if axis is not None:
            rename_nodes([(range_node, name + '/Range'), (slice_node, name)])

    # MXNet arange_like op has no stop attribute and the result tensor always matches the input shape, so
    # we have to correct the stop value for the Range node if step != 1 or start != 0
    if node.step != 1:
        # If the step attribute is not an integer, we will generate an interval with a larger size and then
        # reduce it using Slice
        true_elements_count_port = range_node.in_port(1).get_source()
        mul_value = np.ceil(node.step) if node.step > 0 else np.floor(node.step)
        stop_value = create_op_with_const_inputs(graph, Mul,
                                                 port_value_dict={1: mo_array(np.ceil(mul_value))},
                                                 op_attrs={'name': range_node.name + '/Stop'})
        range_node.in_port(1).get_connection().insert_node(stop_value)

        slice_range_values = create_op_with_const_inputs(graph, Slice,
                                                         {1: int64_array([0]), 3: int64_array([0]),
                                                          4: int64_array([1])},
                                                         {'name': range_node.name + '/Slice'})
        slice_range_values.in_port(2).connect(true_elements_count_port)
        range_node.out_port(0).get_connection().insert_node(slice_range_values)

        if axis is not None and node.repeat == 1:
            rename_nodes([(range_node, name + '/Range'), (slice_range_values, name)])

    if node.start != 0:
        correct_stop_value = create_op_with_const_inputs(graph, Add, port_value_dict={1: mo_array(node.start)},
                                                         op_attrs={'name': range_node.name + '/Correct_Stop'})
        range_node.in_port(1).get_connection().insert_node(correct_stop_value)

    # Range node supports only scalar inputs
    squeeze_node = create_op_with_const_inputs(graph, Squeeze, port_value_dict={1: int64_array(0)},
                                               op_attrs={"name": range_node.name + '/Stop/Squeeze'})
    range_node.in_port(1).get_connection().insert_node(squeeze_node)
def replace_sub_graph(self, graph: Graph, match: dict):
    identity_spw = match['identity_spw']
    gather0_1 = match['gather0_1']
    gather0_2 = match['gather0_2']
    greaterequal0 = match['greaterequal0']
    sparse_fill_empty_rows = match['sparse_fill_empty_rows']
    gather = match['gather']
    select = match['select']
    where0 = match['where0']
    sparse_segment_op = match['sparse_segment_op']
    output_node_name = select.soft_get('name', select.id)

    log.debug('Found EmbeddingSparseSegmentsSingleFeature pattern after {} with name {}'.format(
        sparse_fill_empty_rows.op, sparse_fill_empty_rows.name))

    split_for_indices = create_op_with_const_inputs(graph, Split, {1: int64_array(1)}, {'num_splits': 2})
    squeeze_for_indices = create_op_with_const_inputs(graph, Squeeze, {1: int64_array([1])})
    split_for_dense_shape = create_op_with_const_inputs(graph, Split, {1: int64_array(0)}, {'num_splits': 2})
    squeeze_to_scalar = create_op_with_const_inputs(graph, Squeeze, {1: int64_array([0])})

    # TODO: remove the Cast nodes once EmbeddingSegmentSum (new version) supports segment_ids, indices,
    # and num_segments of different integer types, because real TensorFlow models show that this can happen
    cast_indices = Cast(graph, {'name': output_node_name + '/CastIndices', 'dst_type': np.int32}).create_node()
    cast_segment_ids = Cast(graph, {'name': output_node_name + '/CastSegmentIds',
                                    'dst_type': np.int32}).create_node()
    cast_default_value = Cast(graph, {'name': output_node_name + '/CastDefaultValue',
                                      'dst_type': np.int32}).create_node()
    cast_num_segments = Cast(graph, {'name': output_node_name + '/CastSegmentsNumber',
                                     'dst_type': np.int32}).create_node()

    if sparse_segment_op.op == 'SparseSegmentSum':
        embedding_segments_op = EmbeddingSegmentsSum(graph, {'name': output_node_name}).create_node()
    else:
        embedding_segments_op = EmbeddingSegmentsMean(graph, {'name': output_node_name}).create_node()
    rename_nodes([(select, output_node_name + '/AbandonedName'), (embedding_segments_op, output_node_name)])

    # connect parameters table
    gather.in_port(0).get_connection().set_destination(embedding_segments_op.in_port(0))
    # connect indices values
    greaterequal0.in_port(0).get_connection().set_destination(cast_indices.in_port(0))
    embedding_segments_op.in_port(1).connect(cast_indices.out_port(0))
    # split and connect segment ids
    gather0_1.in_port(0).get_connection().set_destination(split_for_indices.in_port(0))
    squeeze_for_indices.in_port(0).connect(split_for_indices.out_port(0))
    cast_segment_ids.in_port(0).connect(squeeze_for_indices.out_port(0))
    embedding_segments_op.in_port(2).connect(cast_segment_ids.out_port(0))
    # split and connect number of segments
    identity_spw.in_port(0).get_connection().set_destination(split_for_dense_shape.in_port(0))
    squeeze_to_scalar.in_port(0).connect(split_for_dense_shape.out_port(0))
    cast_num_segments.in_port(0).connect(squeeze_to_scalar.out_port(0))
    embedding_segments_op.in_port(3).connect(cast_num_segments.out_port(0))
    # connect default value
    sparse_fill_empty_rows.in_port(3).get_connection().set_destination(cast_default_value.in_port(0))
    embedding_segments_op.in_port(4).connect(cast_default_value.out_port(0))
    # no input port for per_sample_weight

    identity_spw.in_port(0).disconnect()
    gather0_1.in_port(0).disconnect()
    gather0_2.in_port(0).disconnect()
    greaterequal0.in_port(0).disconnect()
    sparse_fill_empty_rows.in_port(2).disconnect()
    gather.in_port(0).disconnect()

    select.out_port(0).get_connection().set_source(embedding_segments_op.out_port(0))
    graph.remove_nodes_from([gather0_1.id, gather0_2.id, greaterequal0.id, sparse_fill_empty_rows.id,
                             select.id, where0.id])
def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
    group_norm_node = match['op']
    group_norm_num_input_dims = len(group_norm_node.in_port(0).data.get_shape())

    # node computing initial GroupNorm input shape
    initial_shape_op_node = Shape(graph, {'name': group_norm_node.name + '/Shape'}).create_node()
    initial_shape_op_node.in_port(0).connect(group_norm_node.in_port(0).get_source())

    initial_shape_op_node_float = Cast(
        graph, {'name': initial_shape_op_node.name + '/to_float',
                'dst_type': data_type_str_to_np(graph.graph['cmd_params'].data_type)}).create_node()
    initial_shape_op_node.out_port(0).connect(initial_shape_op_node_float.in_port(0))

    initial_batch_dim_node = node_to_get_batch_value(initial_shape_op_node_float)
    initial_features_dim_node = node_to_get_features_dimension_value(initial_shape_op_node_float)
    initial_spatial_dims_node_int = node_to_get_spatial_dimensions_value(initial_shape_op_node)
    initial_spatial_dims_node = Cast(
        graph, {'name': initial_spatial_dims_node_int.name + '/to_float',
                'dst_type': data_type_str_to_np(graph.graph['cmd_params'].data_type)}).create_node()
    initial_spatial_dims_node_int.out_port(0).connect(initial_spatial_dims_node.in_port(0))

    group_size_node = Const(graph, {'value': int64_array([group_norm_node.num_groups]),
                                    'name': group_norm_node.name + '/GroupSize'}).create_node()

    # calculate "features // group_size" value
    reciprocal_group_size_node = Const(graph, {'value': np.array([1.0 / group_norm_node.num_groups]),
                                               'name': group_norm_node.name + '/ReciprocalGroupSize'}).create_node()

    c_div_g_node = Mul(graph, {}).create_node()
    c_div_g_node.in_port(0).connect(initial_features_dim_node.out_port(0))
    c_div_g_node.in_port(1).connect(reciprocal_group_size_node.out_port(0))

    batch_mul_group_size_node = Mul(graph, {}).create_node()
    batch_mul_group_size_node.in_port(0).connect(initial_batch_dim_node.out_port(0))
    batch_mul_group_size_node.in_port(1).connect(group_size_node.out_port(0))

    # create new node which concatenates several dims to one
    new_shape_node_float = new_shape_node_from_shape_nodes([batch_mul_group_size_node, c_div_g_node,
                                                            initial_spatial_dims_node])
    new_shape_node = Cast(graph, {'name': new_shape_node_float.name + '/to_int64',
                                  'dst_type': np.int64}).create_node()
    new_shape_node_float.out_port(0).connect(new_shape_node.in_port(0))

    reshape_for_mvn_node = Reshape(graph, {}).create_node()
    group_norm_node.in_port(0).get_connection().set_destination(reshape_for_mvn_node.in_port(0))
    reshape_for_mvn_node.in_port(1).connect(new_shape_node.out_port(0))

    # Reshape the gamma and beta constants to correct layout from [C] to [1,C], [1,C,1], [1,C,1,1] etc
    gamma_beta_shape = np.ones([group_norm_num_input_dims], dtype=np.int64)
    gamma_beta_shape[1] = -1

    gamma_value = group_norm_node.in_port(1).get_source().data.get_value()
    beta_value = group_norm_node.in_port(2).get_source().data.get_value()
    assert gamma_value is not None, 'The gamma should be constant'
    assert beta_value is not None, 'The beta should be constant'
    gamma_value = np.reshape(gamma_value, gamma_beta_shape)
    group_norm_node.in_port(1).get_source().data.set_value(gamma_value)
    beta_value = np.reshape(beta_value, gamma_beta_shape)
    group_norm_node.in_port(2).get_source().data.set_value(beta_value)

    # MVN
    mvn_node = MVN(graph, {'name': group_norm_node.name + '/MVN',
                           'normalize_variance': 1,
                           'eps': group_norm_node.eps,
                           'eps_mode': 'inside_sqrt'}).create_node()
    mvn_node.in_port(0).connect(reshape_for_mvn_node.out_port(0))

    # MVN axes
    _, rank = get_shape_and_rank_nodes_by_port(mvn_node.in_port(0).get_connection().get_source(),
                                               return_as_a_scalar=True)
    rng = create_op_with_const_inputs(graph, Range, {0: int64_array(1), 2: int64_array(1)},
                                      {'name': group_norm_node.name + '/Range', 'output_type': np.int64})
    mvn_node.in_port(1).connect(rng.out_port(0))
    rng.in_port(1).connect(rank.out_port(0))

    # reshape to the initial shape before multiplying with gamma and adding beta
    reshape_to_initial_shape_node = Reshape(graph, {}).create_node()
    reshape_to_initial_shape_node.in_port(0).connect(mvn_node.out_port(0))
    reshape_to_initial_shape_node.in_port(1).connect(initial_shape_op_node.out_port(0))

    mul_node = Mul(graph, {'name': mvn_node.name + '/Mul'}).create_node()
    mul_node.in_port(0).connect(reshape_to_initial_shape_node.out_port(0))
    group_norm_node.in_port(1).get_connection().set_destination(mul_node.in_port(1))

    add_node = Add(graph, {'name': mul_node.name + '/Add'}).create_node()
    add_node.in_port(0).connect(mul_node.out_port(0))
    group_norm_node.in_port(2).get_connection().set_destination(add_node.in_port(1))

    group_norm_node.out_port(0).get_connection().set_source(add_node.out_port(0))
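
# A minimal NumPy sketch (illustration only, not used by the transformation above) of the computation the
# decomposition builds for a 4D NCHW input: reshape to [N * G, C // G, H, W], normalize over all non-batch
# axes as the MVN does, reshape back, and apply gamma/beta broadcast over the channel axis. The helper name
# is illustrative.
def _group_norm_reference(x: np.ndarray, gamma: np.ndarray, beta: np.ndarray,
                          num_groups: int, eps: float) -> np.ndarray:
    n, c, h, w = x.shape
    grouped = x.reshape(n * num_groups, c // num_groups, h, w)
    mean = grouped.mean(axis=(1, 2, 3), keepdims=True)
    var = grouped.var(axis=(1, 2, 3), keepdims=True)
    normalized = ((grouped - mean) / np.sqrt(var + eps)).reshape(n, c, h, w)
    return normalized * gamma.reshape(1, -1, 1, 1) + beta.reshape(1, -1, 1, 1)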
def replace_pattern(graph: Graph, match: dict):
    node = match['pool']
    node_name = node.soft_get('name', node.id)

    if node.pool_step is None:
        node.stride = int64_array([1, 1, node.window[-1], node.window[-1]])

    # create Reshape before convolution
    # shape = [in_shape[0], pool_stride, 1, in_shape[1]/pool_stride]
    i_shape = Shape(graph, {'name': node_name + '/Shape'}).create_node()

    dst_dtype = np.float32  # even if data_type=FP16 use float32 for shape values

    shape = Cast(graph, {'name': node_name + '/to_float', 'dst_type': dst_dtype}).create_node()
    i_shape.in_port(0).connect(node.in_port(0).get_source())
    shape.in_port(0).connect(i_shape.out_port(0))

    N, H = node_to_get_shape_value_of_indices(shape, [0]), node_to_get_shape_value_of_indices(shape, [1])

    div = create_op_with_const_inputs(graph, Div, {1: float32_array([node.pool_stride])},
                                      {'name': node_name + '/div_stride_h'})
    div.in_port(0).connect(H.out_port(0))

    concat = create_op_with_const_inputs(graph, Concat,
                                         {1: float32_array([node.pool_stride]), 2: float32_array([1])},
                                         {'name': node_name + '/concat_all_dims',
                                          'in_ports_count': 4, 'axis': 0})
    concat.in_port(0).connect(N.out_port(0))
    concat.in_port(3).connect(div.out_port(0))

    reshape_pattern = Cast(graph, {'name': node_name + '/to_int', 'dst_type': np.int64}).create_node()
    concat.out_port(0).connect(reshape_pattern.in_port(0))

    reshape_in = Reshape(graph, {'name': node_name + '/reshape_in'}).create_node()
    reshape_in.in_port(1).connect(reshape_pattern.out_port(0))

    # create Reshape after Convolution
    reshape_out = create_op_node_with_second_input(graph, Reshape, int64_array([0, -1]),
                                                   {'name': node_name + '/reshape_out'})

    # connect input_reshape_node
    source = node.in_port(0).get_source()
    node.in_port(0).get_connection().set_source(reshape_in.out_port(0))
    reshape_in.in_port(0).connect(source)
    # connect output_reshape_node
    node.out_port(0).get_connection().set_source(reshape_out.out_port(0))
    node.out_port(0).connect(reshape_out.in_port(0))
def replace_resize(graph: Graph, resize: Node): log.debug("Converting of ONNX Resize-11 to Interpolate-4 " "is triggered for node {}.".format( resize.soft_get('name', resize.id))) input_shape = resize.in_port(0).data.get_shape() input_rank = len(input_shape) resize_name = resize.soft_get('name', resize.id) if input_rank not in {4, 5}: log.warning( 'The input shape is not 4D or 5D for op with name {}'.format( resize_name)) return assert (resize.is_in_port_connected(0) and (resize.is_in_port_connected(2) or resize.is_in_port_connected(3))), \ "Scales or sizes inputs must be connected to Node {} with op {}.".format(resize.soft_get("name", resize.id), resize.op) assert resize.soft_get('coordinate_transformation_mode') != 'tf_crop_and_resize', \ 'Mode tf_crop_and_resize is not supported for op {} with name {}'.format(resize.op, resize.soft_get("name", resize.id)) layout = graph.graph['layout'] if input_rank == 4: begin_dim = get_height_dim(layout, input_rank) end_dim = get_width_dim(layout, input_rank) + 1 else: begin_dim = get_depth_dim(layout, input_rank) end_dim = get_width_dim(layout, input_rank) + 1 sizes_ss = create_op_with_const_inputs( graph, StridedSlice, { 1: int64_array([begin_dim]), 2: int64_array([end_dim]), 3: int64_array([1]) }, { 'name': resize_name + '/StridedSlice_sizes', 'begin_mask': int64_array([1]), 'end_mask': int64_array([1]), 'new_axis_mask': int64_array([0]), 'shrink_axis_mask': int64_array([0]), 'ellipsis_mask': int64_array([0]) }) scales_ss = create_op_with_const_inputs( graph, StridedSlice, { 1: int64_array([begin_dim]), 2: int64_array([end_dim]), 3: int64_array([1]) }, { 'name': resize_name + '/StridedSlice_scales', 'begin_mask': int64_array([1]), 'end_mask': int64_array([1]), 'new_axis_mask': int64_array([0]), 'shrink_axis_mask': int64_array([0]), 'ellipsis_mask': int64_array([0]) }) axes_node = Const( graph, { 'name': resize_name + '/axis', 'value': int64_array(np.arange(begin_dim, end_dim)) }).create_node() shape_calculation_mode = 'sizes' if resize.is_in_port_connected( 3) else 'scales' interpolate_node = Interpolate( graph, { 'version': 'opset4', 'mode': convert_mode(resize.mode), 'coordinate_transformation_mode': resize.coordinate_transformation_mode, 'cube_coeff': resize.cube_coeff, 'nearest_mode': resize.nearest_mode, 'pads_begin': int64_array([0]), 'pads_end': int64_array([0]), 'antialias': 0, 'shape_calculation_mode': shape_calculation_mode, 'in_ports_count': 4 }).create_node() axes_node.out_port(0).connect(interpolate_node.in_port(3)) shape_of = Shape(graph, {'name': resize_name + '/ShapeOf'}).create_node() add_node = create_op_with_const_inputs(graph, Add, {1: float_array([1.0e-5])}, {'name': resize_name + '/Add'}) dst_dtype = np.float32 # even if data_type=FP16 use float32 for shape values if not resize.is_in_port_connected(3): cast_shape_to_float = Cast(graph, { 'dst_type': dst_dtype }).create_node() mul_node = Mul(graph, {'name': resize_name + '/Mul'}).create_node() shape_of.out_port(0).connect(cast_shape_to_float.in_port(0)) cast_shape_to_float.out_port(0).connect(mul_node.in_port(0)) cast_add_result_to_int = Cast(graph, { 'dst_type': np.int64 }).create_node() floor_node = Floor(graph, { 'name': resize_name + '/Floor' }).create_node() mul_node.out_port(0).connect(add_node.in_port(0)) add_node.out_port(0).connect(floor_node.in_port(0)) floor_node.out_port(0).connect(cast_add_result_to_int.in_port(0)) cast_add_result_to_int.out_port(0).connect(sizes_ss.in_port(0)) sizes_ss.out_port(0).connect(interpolate_node.in_port(1)) 
scales_ss.out_port(0).connect(interpolate_node.in_port(2)) connection_of_resize_input = resize.in_port(0).get_connection() connection_of_resize_input.set_destination(interpolate_node.in_port(0)) connection_of_scales = resize.in_port(2).get_connection() connection_of_scales.set_destination(scales_ss.in_port(0)) connection_of_resize_input.get_source().connect(shape_of.in_port(0)) connection_of_scales.get_source().connect(mul_node.in_port(1)) else: cast_shape_to_float = Cast(graph, { 'dst_type': dst_dtype }).create_node() cast_sizes_to_float = Cast(graph, { 'dst_type': dst_dtype }).create_node() div_node = Div(graph, {'name': resize_name + '/Div'}).create_node() cast_sizes_to_float.out_port(0).connect(div_node.in_port(0)) cast_shape_to_float.out_port(0).connect(div_node.in_port(1)) shape_of.out_port(0).connect(cast_shape_to_float.in_port(0)) div_node.out_port(0).connect(add_node.in_port(0)) add_node.out_port(0).connect(scales_ss.in_port(0)) scales_ss.out_port(0).connect(interpolate_node.in_port(2)) sizes_ss.out_port(0).connect(interpolate_node.in_port(1)) connection_of_resize_input = resize.in_port(0).get_connection() connection_of_resize_input.set_destination(interpolate_node.in_port(0)) connection_of_sizes = resize.in_port(3).get_connection() connection_of_sizes.set_destination(sizes_ss.in_port(0)) connection_of_resize_input.get_source().connect(shape_of.in_port(0)) connection_of_sizes.get_source().connect( cast_sizes_to_float.in_port(0)) rename_nodes([(resize, resize_name + '/delete'), (interpolate_node, resize_name)]) resize.out_port(0).get_connection().set_source( interpolate_node.out_port(0))
def dequantize_data(fake_quantize: Node, dst_type: type, quantized_type: type) -> Node:
    graph = fake_quantize.graph
    quantized_data = fake_quantize.in_port(0).get_source().node
    name = fake_quantize.soft_get('name', fake_quantize.id)

    assert quantized_data.soft_get('type') == 'Convert' and quantized_data.dst_type == quantized_type, \
        'Weights aren`t compressed as expected for node {}'.format(fake_quantize.soft_get('name', fake_quantize.id))

    dequantizing_cast = Cast(graph, dict(
        name=quantized_data.name + "/to_{}".format(np_data_type_to_destination_type(dst_type)),
        dst_type=dst_type, stop_value_propagation=True)).create_node()
    fake_quantize.in_port(0).get_connection().set_destination(dequantizing_cast.in_port(0))

    # limits of dequantize
    in_low = fake_quantize.in_port(1).get_source()
    in_high = fake_quantize.in_port(2).get_source()
    out_low = fake_quantize.in_port(3).get_source()
    out_high = fake_quantize.in_port(4).get_source()

    # scale calculation
    output_range = Sub(graph, {'name': name + '/output_range'}).create_node()
    output_range.in_port(0).connect(out_high)
    output_range.in_port(1).connect(out_low)

    input_range = Sub(graph, {'name': name + '/input_range'}).create_node()
    input_range.in_port(0).connect(in_high)
    input_range.in_port(1).connect(in_low)

    scale = Div(graph, {'name': name + '/scale'}).create_node()
    scale.in_port(0).connect(output_range.out_port(0))
    scale.in_port(1).connect(input_range.out_port(0))

    # shift calculation
    descaled_output_low = Div(graph, {'name': name + '/descaled_output_low'}).create_node()
    descaled_output_low.in_port(0).connect(out_low)
    descaled_output_low.in_port(1).connect(scale.out_port(0))

    shift = Sub(graph, {'name': name + '/shift'}).create_node()
    shift.in_port(0).connect(in_low)
    shift.in_port(1).connect(descaled_output_low.out_port(0))

    zero = Const(graph, {'name': name + '/zero', 'value': mo_array(0, dtype=dst_type)}).create_node()
    scale_eq_zero = Equal(graph, {'name': name + '/scale_eq_zero'}).create_node()
    scale_eq_zero.in_port(0).connect(scale.out_port(0))
    scale_eq_zero.in_port(1).connect(zero.out_port(0))

    zero_point = Select(graph, {'name': name + '/zero_point'}).create_node()
    zero_point.in_port(0).connect(scale_eq_zero.out_port(0))
    zero_point.in_port(1).connect(zero.out_port(0))
    zero_point.in_port(2).connect(shift.out_port(0))

    # DeQuantize(x) == Mul(Sub(x, zero_point), scale)
    sub_zp = Sub(graph, {'name': name + '/minus_zp'}).create_node()
    sub_zp.in_port(0).connect(dequantizing_cast.out_port(0))
    sub_zp.in_port(1).connect(zero_point.out_port(0))

    mul_scale = Mul(graph, {'name': name + '/mulpiply_by_scale'}).create_node()
    mul_scale.in_port(0).connect(sub_zp.out_port(0))
    mul_scale.in_port(1).connect(scale.out_port(0))

    fake_quantize.out_port(0).get_connection().set_source(mul_scale.out_port(0))

    graph.remove_nodes_from([fake_quantize.id, fake_quantize.out_node(0)])
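
# A minimal NumPy sketch (illustration only, not used by the transformation above) of the values the
# scale/zero_point sub-graph computes; np.where mirrors the Select that guards against a zero scale,
# and the helper name is illustrative.
def _dequantize_weights_reference(x: np.ndarray, in_low: np.ndarray, in_high: np.ndarray,
                                  out_low: np.ndarray, out_high: np.ndarray) -> np.ndarray:
    # scale and shift exactly as the Sub/Div nodes above compute them
    scale = (out_high - out_low) / (in_high - in_low)
    zero_point = np.where(scale == 0.0, np.zeros_like(scale), in_low - out_low / scale)
    # DeQuantize(x) == Mul(Sub(x, zero_point), scale)
    return (x - zero_point) * scale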
def find_and_replace_pattern(self, graph: Graph):
    for quantize_node in graph.get_op_nodes(op='QuantizeLinear'):
        node_name = quantize_node.soft_get('name', quantize_node.id)
        axis = quantize_node.soft_get('axis', None)
        scale_y_shape = quantize_node.in_port(1).data.get_shape()

        if quantize_node.is_in_port_connected(2):
            zerop = quantize_node.in_port(2).get_source().node
        else:
            zerop = Const(graph, {'value': mo_array(0, dtype=np.uint8),
                                  'name': node_name + '/ZeroPoint'}).create_node()

        assert zerop.soft_get('type') == 'Const', 'only constant for zero_point is supported for QuantizeLinear'
        zero_point_type = zerop.value.dtype
        # data type affects range of output values: [-128..127] or [0..255]
        if zero_point_type == np.int8:
            output_low_value = -128.0
            output_high_value = 127.0
        elif zero_point_type == np.uint8:
            output_low_value = 0.0
            output_high_value = 255.0
        else:
            raise Error('Not expected type {} for zero point value in node {}'.format(
                zero_point_type, zerop.soft_get('name')))

        fake_quantize = create_op_with_const_inputs(graph, FakeQuantize,
                                                    {3: float_array(output_low_value),
                                                     4: float_array(output_high_value)},
                                                    {'levels': 256, 'name': node_name + '/FakeQuantize'})
        quantize_node.in_port(0).get_connection().set_destination(fake_quantize.in_port(0))

        # calculate input_low value
        mul_low = create_op_with_const_inputs(graph, Mul, {1: float_array(output_low_value - zerop.value)},
                                              {'name': node_name + '/Mul/Low'})
        quantize_node.in_port(1).get_connection().set_destination(mul_low.in_port(0))
        mul_low.out_port(0).connect(fake_quantize.in_port(1))

        # calculate input_high value
        mul_high = create_op_with_const_inputs(graph, Mul, {1: float_array(output_high_value - zerop.value)},
                                               {'name': node_name + '/Mul/High'})
        mul_low.in_port(0).get_connection().add_destination(mul_high.in_port(0))
        mul_high.out_port(0).connect(fake_quantize.in_port(2))

        cast = Cast(graph, {'dst_type': zero_point_type, 'name': node_name + '/Cast'}).create_node()
        fake_quantize.out_port(0).connect(cast.in_port(0))
        quantize_node.out_port(0).get_connection().set_source(cast.out_port(0))
        rename_nodes([(quantize_node, node_name + '/TBD'), (cast, node_name)])

        assert scale_y_shape is not None
        if axis is not None and len(scale_y_shape) > 0 and scale_y_shape[0] > 1:
            input_shape = fake_quantize.in_port(0).data.get_shape()
            target_shape = np.ones(len(input_shape), np.int64)
            target_shape[axis] = input_shape[axis]

            mul_low_reshape = create_op_with_const_inputs(graph, Reshape, {1: int64_array(target_shape)},
                                                          {'name': node_name + '/Reshape/Mul/Low'})
            mul_high_reshape = create_op_with_const_inputs(graph, Reshape, {1: int64_array(target_shape)},
                                                           {'name': node_name + '/Reshape/Mul/High'})

            fake_quantize.in_port(1).get_connection().set_destination(mul_low_reshape.in_port(0))
            fake_quantize.in_port(2).get_connection().set_destination(mul_high_reshape.in_port(0))

            mul_low_reshape.out_port(0).connect(fake_quantize.in_port(1))
            mul_high_reshape.out_port(0).connect(fake_quantize.in_port(2))
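
# A minimal NumPy sketch (illustration only, not used by the transformation above) of what the
# FakeQuantize + Cast pair built above is intended to compute, i.e. the ONNX QuantizeLinear semantics
# y = saturate(round(x / scale) + zero_point); the helper name and the per-tensor scale/zero_point
# assumption are illustrative.
def _quantize_linear_reference(x: np.ndarray, scale: np.ndarray, zero_point: np.ndarray) -> np.ndarray:
    info = np.iinfo(zero_point.dtype)  # [-128..127] for int8, [0..255] for uint8, as selected above
    rounded = np.round(x / scale) + zero_point.astype(np.float64)
    return np.clip(rounded, info.min, info.max).astype(zero_point.dtype)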