def replace_pattern(graph: Graph, match: dict):
    """Surround the matched convolution with Reshapes driven by its ``patch_stride``.

    In front of the convolution the input is reshaped to
    ``[in_shape[0], in_shape[1] / patch_stride, 1, patch_stride]``. That pattern is
    computed in the model data type (``cmd_params.data_type``) and cast to int64
    before it feeds the Reshape. Behind the convolution the result is reshaped with
    the constant pattern ``[0, -1]``.

    :param graph: graph to modify in place.
    :param match: pattern match; ``match['conv']`` is the convolution node and must
        expose a ``patch_stride`` attribute.
    """
    conv = match['conv']
    conv_name = conv.soft_get('name', conv.id)

    # ---- input side: build [N, H / patch_stride, 1, patch_stride] ----
    input_shape = Shape(graph, {'name': conv_name + '/Shape'}).create_node()
    shape_as_float = Cast(graph, {
        'name': conv_name + '/to_float',
        'dst_type': data_type_str_to_np(graph.graph['cmd_params'].data_type),
    }).create_node()
    input_shape.in_port(0).connect(conv.in_port(0).get_source())
    shape_as_float.in_port(0).connect(input_shape.out_port(0))

    batch_dim = node_to_get_shape_value_of_indices(shape_as_float, [0])
    height_dim = node_to_get_shape_value_of_indices(shape_as_float, [1])

    height_per_patch = create_op_with_const_inputs(graph, Div, {1: float_array([conv.patch_stride])},
                                                   {'name': conv_name + '/div_stride_h'})
    height_per_patch.in_port(0).connect(height_dim.out_port(0))

    new_dims = create_op_with_const_inputs(graph, Concat,
                                           {2: float_array([1]), 3: float_array([conv.patch_stride])},
                                           {'name': conv_name + '/concat_all_dims', 'in_ports_count': 4, 'axis': 0})
    new_dims.in_port(0).connect(batch_dim.out_port(0))
    new_dims.in_port(1).connect(height_per_patch.out_port(0))

    # the Reshape pattern is fed as int64
    new_dims_as_int = Cast(graph, {'name': conv_name + '/to_int', 'dst_type': np.int64}).create_node()
    new_dims.out_port(0).connect(new_dims_as_int.in_port(0))

    reshape_in = Reshape(graph, {'name': conv_name + '/reshape_in'}).create_node()
    reshape_in.in_port(1).connect(new_dims_as_int.out_port(0))

    # ---- output side ----
    reshape_out = create_op_node_with_second_input(graph, Reshape, int64_array([0, -1]),
                                                   {'name': conv_name + '/reshape_out'})

    # splice reshape_in between the convolution and its former producer
    producer = conv.in_port(0).get_source()
    conv.in_port(0).get_connection().set_source(reshape_in.out_port(0))
    reshape_in.in_port(0).connect(producer)

    # splice reshape_out between the convolution and its former consumers
    conv.out_port(0).get_connection().set_source(reshape_out.out_port(0))
    conv.out_port(0).connect(reshape_out.in_port(0))
def replace_sub_graph(self, graph: Graph, match: dict):
    """Collapse the matched sparse-embedding sub-graph into one EmbeddingSegmentsSum node.

    Inputs of the new node, as wired below:
      0 - the table feeding ``gather``'s input 0
      1 - the values feeding ``greaterequal0``'s input 0
      2 - output 0 of Split(axis=1) over ``gather0_1``'s input, squeezed on axis 1 and cast to int32
      3 - output 0 of Split(axis=0) over ``identity_spw``'s input, squeezed to a scalar and cast to int32
      4 - ``sparse_fill_empty_rows`` input 3, cast to int32
    No per-sample-weight input is connected. The new node takes over ``select``'s name
    and consumers; the matched nodes are disconnected and (mostly) removed.
    """
    identity_spw = match['identity_spw']
    gather0_1 = match['gather0_1']
    gather0_2 = match['gather0_2']
    greaterequal0 = match['greaterequal0']
    sparse_fill_empty_rows = match['sparse_fill_empty_rows']
    gather = match['gather']
    select = match['select']
    where0 = match['where0']
    output_node_name = select.soft_get('name', select.id)

    log.debug('Found EmbeddingSegmentsSum2 pattern after {} with name {}'.format(sparse_fill_empty_rows.op,
                                                                               sparse_fill_empty_rows.name))

    # helper nodes; the port-1 constants are the Split axis / Squeeze axes
    split_for_indices = create_op_with_const_inputs(graph, Split, {1: int64_array(1)},
                                                    {'num_splits': 2,
                                                     'name': output_node_name + '/SplitForIndices'})
    squeeze_for_indices = create_op_with_const_inputs(graph, Squeeze, {1: int64_array([1])})
    split_for_dense_shape = create_op_with_const_inputs(graph, Split, {1: int64_array(0)},
                                                        {'num_splits': 2,
                                                         'name': output_node_name + '/SplitForDenseShape'})
    squeeze_to_scalar = create_op_with_const_inputs(graph, Squeeze, {1: int64_array([0])})
    cast_segment_ids = Cast(graph, {'name': output_node_name + '/CastSegmentIds',
                                    'dst_type': np.int32}).create_node()
    cast_default_value = Cast(graph, {'name': output_node_name + '/CastDefaultValue',
                                      'dst_type': np.int32}).create_node()
    cast_num_segments = Cast(graph, {'name': output_node_name + '/CastSegmentsNumber',
                                     'dst_type': np.int32}).create_node()
    embedding_segments_sum = EmbeddingSegmentsSum(graph, {'name': output_node_name}).create_node()
    # the new node inherits the name of the pattern's output node
    rename_nodes([(select, output_node_name + '/AbandonedName'), (embedding_segments_sum, output_node_name)])

    # connect parameters table
    gather.in_port(0).get_connection().set_destination(embedding_segments_sum.in_port(0))
    # connect indices values
    greaterequal0.in_port(0).get_connection().set_destination(embedding_segments_sum.in_port(1))
    # split and connect segment ids
    gather0_1.in_port(0).get_connection().set_destination(split_for_indices.in_port(0))
    squeeze_for_indices.in_port(0).connect(split_for_indices.out_port(0))
    # TODO: remove casting once we start to support I64 model input
    cast_segment_ids.in_port(0).connect(squeeze_for_indices.out_port(0))
    embedding_segments_sum.in_port(2).connect(cast_segment_ids.out_port(0))
    # split and connect number of segments
    identity_spw.in_port(0).get_connection().set_destination(split_for_dense_shape.in_port(0))
    squeeze_to_scalar.in_port(0).connect(split_for_dense_shape.out_port(0))
    # TODO: remove casting once we start to support I64 model input
    cast_num_segments.in_port(0).connect(squeeze_to_scalar.out_port(0))
    embedding_segments_sum.in_port(3).connect(cast_num_segments.out_port(0))
    # connect default value
    # TODO: remove casting once we start to support I64 model input
    sparse_fill_empty_rows.in_port(3).get_connection().set_destination(cast_default_value.in_port(0))
    embedding_segments_sum.in_port(4).connect(cast_default_value.out_port(0))
    # no input port for per_sample_weight

    # detach the remaining inputs of the old pattern so it no longer references live producers
    identity_spw.in_port(0).disconnect()
    gather0_1.in_port(0).disconnect()
    gather0_2.in_port(0).disconnect()
    greaterequal0.in_port(0).disconnect()
    sparse_fill_empty_rows.in_port(2).disconnect()
    gather.in_port(0).disconnect()

    select.out_port(0).get_connection().set_source(embedding_segments_sum.out_port(0))
    # NOTE(review): identity_spw and gather are not in this list; presumably they are cleaned up
    # later as dead nodes — confirm
    graph.remove_nodes_from([gather0_1.id, gather0_2.id, greaterequal0.id, sparse_fill_empty_rows.id, select.id, where0.id])
def replace_interpolate_pattern(graph: Graph, match: dict):
    """Replace the matched Split ... Concat upsampling pattern with a single Interpolate.

    The Interpolate works on one axis (the constant on the Split's input 1). Its
    target size is ``floor(input_shape[axis] * scale)``, cast to int64, where ``scale``
    comes from ``get_split_scale(split)``. The scale and axis are also provided as
    constants. The Interpolate takes over the Concat's consumers and the Split's data
    input.
    """
    split = match['split']
    scale_value = np.array([get_split_scale(split)], dtype=np.float32)
    split_axis = int(split.in_port(1).get_connection().get_source().node.value)
    base_name = split.name

    axes_const = Const(graph, {'name': base_name + '/axis', 'value': int64_array([split_axis])}).create_node()
    input_shape = Shape(graph, dict(name=base_name + '/Shape')).create_node()
    scales_const = Const(graph, dict(name=base_name + '/scales', value=scale_value)).create_node()
    scaled_dim = Mul(graph, dict(name=base_name + '/Mul')).create_node()
    scales_const.out_port(0).connect(scaled_dim.in_port(1))

    # input_shape[axis:axis + 1]
    slice_attrs = {
        'name': base_name + '/StridedSlice',
        'begin_mask': int64_array([1]),
        'end_mask': int64_array([1]),
        'new_axis_mask': int64_array([0]),
        'shrink_axis_mask': int64_array([0]),
        'ellipsis_mask': int64_array([0]),
    }
    axis_dim = create_op_with_const_inputs(graph, StridedSlice,
                                           {1: int64_array([split_axis]), 2: int64_array([split_axis + 1])},
                                           slice_attrs)
    input_shape.out_port(0).connect(axis_dim.in_port(0))

    axis_dim_float = Cast(graph, {'dst_type': np.float32}).create_node()
    axis_dim.out_port(0).connect(axis_dim_float.in_port(0))
    axis_dim_float.out_port(0).connect(scaled_dim.in_port(0))

    interpolate = Interpolate(graph, dict(name=base_name + '/Interpolate',
                                          mode='nearest',
                                          antialias=0,
                                          pads_begin=int64_array([0]),
                                          pads_end=int64_array([0]),
                                          coordinate_transformation_mode='half_pixel',
                                          nearest_mode='round_prefer_floor',
                                          cube_coeff=-0.75,
                                          version='opset4',
                                          shape_calculation_mode='scales',
                                          in_ports_count=4,
                                          maybe_part_of_sequence=True)).create_node()

    floored = Floor(graph, {'name': base_name + '/Floor'}).create_node()
    target_size = Cast(graph, {'dst_type': np.int64}).create_node()
    scaled_dim.out_port(0).connect(floored.in_port(0))
    floored.out_port(0).connect(target_size.in_port(0))

    # sizes -> port 1, scales -> port 2, axes -> port 3
    target_size.out_port(0).connect(interpolate.in_port(1))
    scales_const.out_port(0).connect(interpolate.in_port(2))
    axes_const.out_port(0).connect(interpolate.in_port(3))

    match['concat'].out_port(0).get_connection().set_source(interpolate.out_port(0))

    data_connection = split.in_port(0).get_connection()
    data_connection.set_destination(interpolate.in_port(0))
    data_connection.get_source().connect(input_shape.in_port(0))
def replace_op(self, graph: Graph, node: Node):
    """Decompose the node into Broadcast(fill value, shape) followed by ScatterNDUpdate.

    Input mapping, as wired below:
      * port 1 -> Broadcast target shape
      * port 3 (optional) -> fill value, cast to int32; a float32 zero Const is used when absent
      * port 0 -> ScatterNDUpdate port 1
      * port 2 -> ScatterNDUpdate port 2
    The ScatterNDUpdate takes over the node's name.

    :return: single-element list with the id of the ScatterNDUpdate node.
    """
    node_name = node.soft_get('name', node.id)

    # dense tensor of the required shape, pre-filled with the default value
    broadcast = Broadcast(graph, {'name': node_name + '/Broadcast_'}).create_node()
    node.in_port(1).get_connection().set_destination(broadcast.in_port(1))

    if node.in_port(3).disconnected():
        # NOTE(review): float32 zero here vs. int32 on the other branch — confirm the updates'
        # type matches in this case
        fill_value = Const(graph, {'name': broadcast.name + '/FillValue_', 'value': np.float32(0)}).create_node()
        broadcast.in_port(0).connect(fill_value.out_port(0))
    else:
        # TODO: remove casting once we start to support I64 model input
        # the default value is cast to I32 so that it has the same type as the other I32 input
        # of ScatterNDUpdate
        default_as_i32 = Cast(graph, {'name': node_name + '/CastDefaultValue', 'dst_type': np.int32}).create_node()
        node.in_port(3).get_connection().set_destination(default_as_i32.in_port(0))
        broadcast.in_port(0).connect(default_as_i32.out_port(0))

    # write the given values at the given locations of the broadcast tensor
    scatter = ScatterNDUpdate(graph, {'name': node_name + '/ScatterNDUpdate_'}).create_node()
    scatter.in_port(0).connect(broadcast.out_port(0))
    node.in_port(0).get_connection().set_destination(scatter.in_port(1))
    node.in_port(2).get_connection().set_destination(scatter.in_port(2))

    rename_nodes([(node, node_name + "/AbandonedName"), (scatter, node_name)])
    return [scatter.id]
def replace_op(self, graph: Graph, node: Node):
    """Replace the node with CumSum (non-reverse, inclusive), optionally followed by a Cast.

    The axis is read from the node's ``axis`` attribute (default 0). When
    ``mx_out_type`` is set, the CumSum result is cast to that type and the Cast takes
    over the original name; otherwise the CumSum keeps it.

    :return: single-element list with the id of the node producing the final output.
    """
    name = node.soft_get('name', node.id)
    axis = node.soft_get('axis', 0)
    rename_node(node=node, name=name + '/to_be_removed')

    cumsum_attrs = {'name': name, 'reverse': False, 'exclusive': False}
    cumsum = create_op_node_with_second_input(graph, CumSum, int64_array(axis), cumsum_attrs)
    rename_node(cumsum, name)
    node.in_port(0).get_connection().set_destination(cumsum.in_port(0))

    needs_cast = node.has_valid('mx_out_type') and node['mx_out_type'] is not None
    if not needs_cast:
        return [cumsum.id]

    # the Cast becomes the named output, so move the CumSum name aside first
    rename_node(node=cumsum, name=name + '/CumSum')
    cast = Cast(graph, {'name': name, 'dst_type': node['mx_out_type']}).create_node()
    rename_node(cast, name)
    cumsum.out_port(0).connect(cast.in_port(0))
    return [cast.id]
def replace_pattern(graph: Graph, match: dict):
    """Insert a Cast on every input port listed in the node's ``force_precision_in_ports``.

    ``force_precision_in_ports`` maps port index -> precision string (converted with
    ``data_type_str_to_np``). Ports the node does not have, or that are disconnected,
    are skipped.
    """
    node = match['node']
    for port_idx, precision in node.force_precision_in_ports.items():
        if port_idx not in node.in_ports().keys() or node.in_port(port_idx).disconnected():
            continue
        cast_name = node.name + '/Cast_' + str(port_idx)
        converter = Cast(graph, {'name': cast_name, 'dst_type': data_type_str_to_np(precision)}).create_node()
        node.in_port(port_idx).get_connection().insert_node(converter)
def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
    """Split a FakeQuantize into a quantizing copy, a Cast and the original acting as dequantizer.

    Resulting chain:
        data -> FQ copy (original input range, output range [0, levels - 1])
             -> Cast(dst_type, input precision forced to uint8)
             -> original FQ (input range [0, levels - 1], original output range)
    ``dst_type`` is the model data type from ``cmd_params.data_type``.
    """
    initial_fake_quantize = match['quantize']

    new_fake_quantize = initial_fake_quantize.copy_node(dict(name=initial_fake_quantize.name + '/Copy',
                                                             stop_value_propagation=False), graph)

    # the copy takes over the original input range (ports 1 and 2)
    initial_fake_quantize.in_port(1).get_connection().set_destination(new_fake_quantize.in_port(1))
    initial_fake_quantize.in_port(2).get_connection().set_destination(new_fake_quantize.in_port(2))

    dst_type = data_type_str_to_np(graph.graph['cmd_params'].data_type)

    # integer grid [0, levels - 1], stored in the model data type
    i_min = np.array([0.], dtype=dst_type)
    i_max = np.array([initial_fake_quantize.levels - 1.], dtype=dst_type)

    new_out_low_node = Const(graph, dict(name=initial_fake_quantize.name + '/Copy/out_low',
                                         value=i_min)).create_node()
    new_out_high_node = Const(graph, dict(name=initial_fake_quantize.name + '/Copy/out_high',
                                          value=i_max)).create_node()

    # the same grid is the copy's output range and the original's new input range
    new_out_low_node.out_port(0).connect(new_fake_quantize.in_port(3))
    new_out_high_node.out_port(0).connect(new_fake_quantize.in_port(4))
    new_out_low_node.out_port(0).connect(initial_fake_quantize.in_port(1))
    new_out_high_node.out_port(0).connect(initial_fake_quantize.in_port(2))

    cast_node = Cast(graph, dict(name=initial_fake_quantize.name + "/Convert_to_float", dst_type=dst_type,
                                 stop_value_propagation=True)).create_node()
    new_fake_quantize.out_port(0).connect(cast_node.in_port(0))
    # data now flows: producer -> copy -> Cast -> original FQ
    initial_fake_quantize.in_port(0).get_connection().set_destination(new_fake_quantize.in_port(0))
    cast_node.out_port(0).connect(initial_fake_quantize.in_port(0))

    # request uint8 on the Cast's input (consumed by a pass that handles force_precision_in_ports)
    cast_node['force_precision_in_ports'] = {0: 'uint8'}
def insert_do(graph: Graph, replacement_descriptions: dict):
    """Create an ExperimentalDetectronDetectionOutput node and move the old outputs onto it.

    Inputs of the new node:
      0 - output port 1 of node ``ROIFeatureExtractor_2``
      1 - box regressions reshaped to ``[-1, 4 * num_classes]`` (``num_classes`` fixed at 81)
      2 - the class-predictions node named in ``replacement_descriptions``
      3 - a new Parameter ``im_info`` of shape ``[1, 3]``
    The consumers of output 0 of each node in ``replacement_descriptions['do_outputs']``
    are moved to the corresponding output (0, 1, 2) of the new node. A Cast to int64
    is inserted after output 1.
    """
    do_outputs = replacement_descriptions['do_outputs']
    rois_source = Node(graph, 'ROIFeatureExtractor_2')
    num_classes = 81

    regressions_source = Node(graph, replacement_descriptions['box_regressions_input_node'])
    box_regressions = create_op_node_with_second_input(graph, Reshape, int64_array([-1, 4 * num_classes]),
                                                       dict(name='box_regressions'), regressions_source)

    class_predictions = Node(graph, replacement_descriptions['class_predicitons_node'])
    im_info = Parameter(graph, {"name": 'im_info', 'shape': int64_array([1, 3])}).create_node()

    detection_attrs = {
        'name': 'DetectionOutput',
        'class_agnostic_box_regression': 0,
        'deltas_weights': np.array([10.0, 10.0, 5.0, 5.0]),
        'max_delta_log_wh': replacement_descriptions['max_delta_log_wh'],
        'nms_threshold': replacement_descriptions['nms_threshold'],
        'score_threshold': replacement_descriptions['score_threshold'],
        'num_classes': num_classes,
        'max_detections_per_image': replacement_descriptions['max_detections_per_image'],
        'post_nms_count': replacement_descriptions['post_nms_count'],
    }
    detection_output = ExperimentalDetectronDetectionOutput(graph, detection_attrs).create_node()

    rois_source.out_port(1).connect(detection_output.in_port(0))
    box_regressions.out_port(0).connect(detection_output.in_port(1))
    class_predictions.out_port(0).connect(detection_output.in_port(2))
    im_info.out_port(0).connect(detection_output.in_port(3))

    new_ports = [detection_output.out_port(idx) for idx in range(3)]
    old_nodes = [Node(graph, node_id) for node_id in do_outputs]
    for old_node, new_port in zip(old_nodes, new_ports):
        old_node.out_port(0).get_connection().set_source(new_port)

    # the consumer of the second output port of the ExperimentalDetectronDetectionOutput is the Mul node which second
    # input is of type int64 so it is necessary to insert Cast to have data types match
    second_output = detection_output.out_port(1).get_connection()
    second_output.insert_node(Cast(graph, {'dst_type': np.int64}).create_node())
def create_ss_interval_border(graph: Graph, slice_border_port: Port, shape: np.ndarray, axes: np.ndarray,
                              node_name: str):
    """
    This function creates "begin"/"end" parameters for the StridedSlice based on Slice's "starts"/"ends"

    :param graph: graph to operate on.
    :param slice_border_port: node output port that provides "starts"/"ends" values for the Slice.
    :param shape: input shape of the Slice
    :param axes: axes that "starts" and "ends" apply to. Negative values are counted from the end
                 (``axis + len(shape)``), following Slice semantics.
    :param node_name: Slice node name
    :return: Concat node that forms "begin"/"end" values for the StridedSlice
    """
    rank = len(shape)

    # the value for 'starts' or 'ends' might be maximum/minimum possible value of int64. This
    # value must be converted to maximum/minimum of int32 because such big values do not fit into the int32 which is
    # supported by the StridedSlice layer
    clamp = create_op_with_const_inputs(
        graph, Clamp, port_value_dict={1: np.iinfo(np.int32).min, 2: np.iinfo(np.int32).max},
        op_attrs=dict(name=node_name + '/Clamp'))
    clamp.in_port(0).connect(slice_border_port)
    # we have to convert "starts"/"ends" values from the network to one data type with constant values that are created
    # here to prevent type errors in Concat node
    cast = Cast(graph, dict(name=node_name + '/CastToI64', dst_type=np.int64)).create_node()
    cast.in_port(0).connect(clamp.out_port(0))
    concat = Concat(graph, dict(name=node_name + '/Concat', axis=0)).create_node()
    for value_idx, axis in enumerate(axes):
        # A negative axis addresses dimension ``axis + rank``. Using it unnormalized as a Concat port index
        # would leave that dimension to be filled with the placeholder 0 below, silently dropping the value.
        port_idx = axis if axis >= 0 else axis + rank
        concat.add_input_port(port_idx)
        # "axes" may not be sorted, so we need to split "starts"/"ends" values and connect each value to the correct
        # Concat input port
        value = create_op_with_const_inputs(
            graph, Gather, port_value_dict={1: int64_array([value_idx]), 2: int64_array(0)},
            op_attrs={'name': node_name + '/Gather'})
        cast.out_port(0).connect(value.in_port(0))
        value.out_port(0).connect(concat.in_port(port_idx))
    for port_idx in range(rank):
        if not concat.is_in_port_connected(port_idx):
            concat.add_input_port(port_idx)
            # This border value would be ignored in StridedSlice because of the begin_mask\end_mask
            const = Const(graph, dict(name=node_name + '/Const', value=int64_array([0]))).create_node()
            const.out_port(0).connect(concat.in_port(port_idx))
    return concat
def replace_sub_graph(self, graph: Graph, match: dict):
    """Replace the SparseToDense tail of a CTCGreedyDecoder with Squeeze(axes 2, 3) + Cast(int32).

    The matched SparseToDense and Cast nodes are removed, and the new Cast takes over
    the SparseToDense name and consumers. The decoder is flagged with
    ``use_mask_format`` so its sequence length is converted to a mask later. If the
    decoder's second input is not a Parameter, ``static_shape`` is switched on (with
    a logged warning).
    """
    # TODO: Once Inference Engine's CTCGreedyDecoder starts to support sequence length format like in TensorFlow,
    # CTCGreedyDecoderReplacement2 needs to be removed and CTCGreedyDecoderReplacement, a more generic
    # transformation, needs to be adopted for all cases
    decoder = match['decoder']
    cast = match['cast']
    sparse_to_dense = match['sparse_to_dense']
    dense_name = sparse_to_dense.soft_get('name', sparse_to_dense.id)

    # detach SparseToDense and Cast from their producers
    sparse_to_dense.in_port(0).disconnect()
    cast.in_port(0).disconnect()

    # decoder output -> Squeeze(axes 2, 3) -> Cast(int32)
    squeezed = create_op_with_const_inputs(graph, Squeeze, {1: int64_array([2, 3])}, {'name': dense_name})
    squeezed.in_port(0).connect(decoder.out_port(0))
    as_int = Cast(graph, {'name': dense_name + '/CastToInt', 'dst_type': np.int32}).create_node()
    as_int.in_port(0).connect(squeezed.out_port(0))

    # keep the output name of the original graph and hand over the consumers
    rename_nodes([(sparse_to_dense, dense_name + '/AbandonedName'), (as_int, dense_name)])
    sparse_to_dense.out_port(0).get_connection().set_source(as_int.out_port(0))

    graph.remove_nodes_from([sparse_to_dense.id, cast.id])

    # sequence length has to be turned into a mask in the middle phase
    decoder['use_mask_format'] = True

    # a non-Parameter second input forces --static-shape so MO can try to compute its value
    if decoder.in_node(1).soft_get('op') == 'Parameter' or graph.graph['cmd_params'].static_shape:
        return
    log.error("Model can not be translated in a reshape-able way.\n"
              "Model Optimizer key static_shape was turned on to prevent related errors.\n"
              "There will be no success changing input shapes of the model with the help of "
              "InferenceEngine reshape method", extra={'is_warning': True})
    graph.graph['cmd_params'].static_shape = True
def placeholder_scales(self, placeholder: Node):
    """Build the prior-box scale vector from the runtime shape of a 4D placeholder.

    Dimensions 1 and 2 of the placeholder shape (ShapeOf + StridedSlice[1:3]) are cast
    to float32, inverted with ``Pow(x, -1)``, swapped with a Gather and repeated by a
    Concat, giving ``[1 / d2, 1 / d1, 1 / d2, 1 / d1]``. For an NHWC input (presumably
    the case here — confirm) that is ``[1 / im_width, 1 / im_height, 1 / im_width, 1 / im_height]``.

    :param placeholder: input image placeholder node; must carry a 4-element ``shape`` ndarray.
    :return: the Concat node producing the scale vector.
    """
    graph = placeholder.graph
    name = placeholder.soft_get('name', placeholder.id)

    shape_value = placeholder.soft_get('shape', None)
    assert shape_value is not None, \
        "[ {} replacer ] Placeholder `{}` should have shape attribute".format(self.replacement_id, name)
    assert isinstance(shape_value, np.ndarray), \
        "[ {} replacer ] Placeholder `{}` shape attribute should be np.ndarray".format(self.replacement_id, name)
    assert shape_value.size == 4, \
        "[ {} replacer ] Placeholder `{}` should be 4D. Shape: {}".format(self.replacement_id, name, shape_value)

    image_shape = Shape(graph, {'name': 'input_image_shape'}).create_node()
    image_shape.in_port(0).connect(placeholder.out_port(0))

    slice_begin = Const(graph, {'value': int64_array([1])}).create_node()
    slice_end = Const(graph, {'value': int64_array([3])}).create_node()
    slice_stride = Const(graph, {'value': int64_array([1])}).create_node()
    spatial_dims = StridedSlice(graph, {'name': name + '/get_h_w',
                                        'begin_mask': np.array([1]),
                                        'end_mask': np.array([1]),
                                        'new_axis_mask': np.array([0]),
                                        'shrink_axis_mask': np.array([0]),
                                        'ellipsis_mask': np.array([0])}).create_node()
    for port_idx, producer in enumerate((image_shape, slice_begin, slice_end, slice_stride)):
        spatial_dims.in_port(port_idx).connect(producer.out_port(0))

    minus_one = Const(graph, {'value': float32_array([-1.])}).create_node()
    inverted = Pow(graph, {}).create_node()
    inverted.in_port(0).connect(spatial_dims.out_port(0))
    inverted.in_port(1).connect(minus_one.out_port(0))

    # Power `type_infer` requires inputs to have equal data type
    to_fp32 = Cast(graph, {'dst_type': np.float32}).create_node()
    inverted.in_port(0).get_connection().insert_node(to_fp32)

    swap_order = Const(graph, {'value': int64_array([1, 0])}).create_node()
    gather_axis = Const(graph, {'value': int64_array(0)}).create_node()
    swapped = Gather(graph, {}).create_node()
    swapped.in_port(0).connect(inverted.out_port(0))
    swapped.in_port(1).connect(swap_order.out_port(0))
    gather_axis.out_port(0).connect(swapped.in_port(2))

    priors_scale_node = Concat(graph, {'axis': 0, 'in_ports_count': 2}).create_node()
    for port_idx in (0, 1):
        priors_scale_node.add_input_port(port_idx, skip_if_exist=True)
    for port_idx in (0, 1):
        priors_scale_node.in_port(port_idx).connect(swapped.out_port(0))
    return priors_scale_node
def decompose_shuffle_channel(node: Node):
    """Replace a shuffle-channels node with Reshape -> Transpose(0, 2, 1, 3) -> Reshape.

    The input is reshaped to ``[batch, group, channels / group, -1]``, where batch
    and channels come from the batch/feature helper nodes. Axes 1 and 2 are swapped,
    and the result is reshaped back to the runtime input shape. The final Reshape
    takes over the node's name.
    """
    graph = node.graph
    name = node.soft_get('name', node.id)
    rename_node(node, name + '/to_be_removed')

    input_shape = Shape(graph, dict(name=name + '/InputShape')).create_node()
    input_shape.in_port(0).connect(node.in_port(0).get_source())

    # Reshape [input_batch, group, input_channels/group, -1]
    batch = node_to_get_batch_value(input_shape)
    groups = Const(graph, dict(name=name + '/Rows', value=int64_array([node.group]))).create_node()
    rest = Const(graph, dict(name=name + '/Const', value=int64_array([-1]))).create_node()
    channels = node_to_get_features_dimension_value(input_shape)
    channels_per_group = create_op_node_with_second_input(graph, Div, np.int64(node.group),
                                                          {'name': name + '/Cols'}, input_node=channels)
    channels_per_group_i64 = Cast(graph, {'name': channels_per_group.name + '/Convert',
                                          'dst_type': np.int64}).create_node()
    channels_per_group.out_port(0).connect(channels_per_group_i64.in_port(0))
    split_shape = new_shape_node_from_shape_nodes([batch, groups, channels_per_group_i64, rest])

    split_reshape = Reshape(graph, dict(name=name + '/Reshape_split_')).create_node()
    split_shape.out_port(0).connect(split_reshape.in_port(1))

    # swap the group axis and the per-group channel axis
    transpose = create_op_node_with_second_input(graph, Transpose, int64_array([0, 2, 1, 3]),
                                                 {'name': name + '/Transpose_'}, input_node=split_reshape)

    # back to the original input shape
    merge_reshape = Reshape(graph, dict(name=name)).create_node()
    rename_node(merge_reshape, name)
    input_shape.out_port(0).connect(merge_reshape.in_port(1))
    transpose.out_port(0).connect(merge_reshape.in_port(0))

    node.in_port(0).get_connection().set_destination(split_reshape.in_port(0))
    node.out_port(0).get_connection().set_source(merge_reshape.out_port(0))
def dequantize_data(fake_quantize: Node, dst_type: type, quantized_type: type) -> Node:
    """Replace a FakeQuantize fed by compressed weights with an explicit dequantization.

    Built sub-graph:
        scale      = (out_high - out_low) / (in_high - in_low)
        zero_point = in_low - out_low / scale
        y          = (Cast(x, dst_type) - zero_point) * scale
    The Mul takes over the FakeQuantize's consumers.

    :param fake_quantize: FakeQuantize whose input 0 must be a Convert to ``quantized_type``.
    :param dst_type: numpy type the quantized data is cast back to.
    :param quantized_type: expected ``dst_type`` of the Convert feeding ``fake_quantize``.
    NOTE(review): annotated ``-> Node`` but there is no return statement, so it returns None.
    """
    graph = fake_quantize.graph
    quantized_data = fake_quantize.in_port(0).get_source().node
    name = fake_quantize.soft_get('name', fake_quantize.id)

    assert quantized_data.soft_get('type') == 'Convert' and quantized_data.dst_type == quantized_type, \
        'Weights aren`t compressed as expected for node {}'.format(fake_quantize.soft_get('name', fake_quantize.id))

    dequantizing_cast = Cast(graph, dict(
        name=quantized_data.name + "/to_{}".format(np_data_type_to_destination_type(dst_type)),
        dst_type=dst_type, stop_value_propagation=True)).create_node()
    fake_quantize.in_port(0).get_connection().set_destination(dequantizing_cast.in_port(0))

    # limits of dequantize
    in_low = fake_quantize.in_port(1).get_source()
    in_high = fake_quantize.in_port(2).get_source()
    out_low = fake_quantize.in_port(3).get_source()
    out_high = fake_quantize.in_port(4).get_source()

    # scale calculation
    output_range = Sub(graph, {'name': name + '/output_range'}).create_node()
    output_range.in_port(0).connect(out_high)
    output_range.in_port(1).connect(out_low)
    input_range = Sub(graph, {'name': name + '/input_range'}).create_node()
    input_range.in_port(0).connect(in_high)
    input_range.in_port(1).connect(in_low)
    scale = Div(graph, {'name': name + '/scale'}).create_node()
    scale.in_port(0).connect(output_range.out_port(0))
    scale.in_port(1).connect(input_range.out_port(0))

    # shift calculation
    descaled_output_low = Div(graph, {'name': name + '/descaled_output_low'}).create_node()
    descaled_output_low.in_port(0).connect(out_low)
    descaled_output_low.in_port(1).connect(scale.out_port(0))
    shift = Sub(graph, {'name': name + '/zero_point'}).create_node()
    shift.in_port(0).connect(in_low)
    shift.in_port(1).connect(descaled_output_low.out_port(0))

    # DeQuantize(x) == Mul(Sub(x, zero_point), scale)
    sub_zp = Sub(graph, {'name': name + '/minus_zp'}).create_node()
    sub_zp.in_port(0).connect(dequantizing_cast.out_port(0))
    sub_zp.in_port(1).connect(shift.out_port(0))

    mul_scale = Mul(graph, {'name': name + '/mulpiply_by_scale'}).create_node()
    mul_scale.in_port(0).connect(sub_zp.out_port(0))
    mul_scale.in_port(1).connect(scale.out_port(0))

    fake_quantize.out_port(0).get_connection().set_source(mul_scale.out_port(0))

    # NOTE(review): the second element is a Node object, not a node id. If Node is not a graph key,
    # networkx silently skips it. Confirm whether the output data node is meant to be removed here
    # (and whether it is still attached to fake_quantize after set_source) before changing this.
    graph.remove_nodes_from([fake_quantize.id, fake_quantize.out_node(0)])
def find_and_replace_pattern(self, graph: Graph):
    """Decompose every DequantizeLinear into Cast -> [Sub zero_point] -> Mul scale.

    ``y = (Cast(x) - zero_point) * scale``, with the zero point (port 2) optional and
    ``x`` cast to the model data type. When the node has an ``axis`` and the scale's
    first dimension is larger than 1, scale and zero point are reshaped to a shape of
    ones with ``input_shape[axis]`` at ``axis``, so they broadcast along that axis.
    The Mul takes over the node's name.
    """
    def route_through_reshape(consumer, reshape_name, target_shape):
        # put a Reshape(target_shape) in front of consumer's input 1
        reshape = create_op_with_const_inputs(graph, Reshape, {1: int64_array(target_shape)},
                                              {'name': reshape_name})
        consumer.in_port(1).get_connection().set_destination(reshape.in_port(0))
        reshape.out_port(0).connect(consumer.in_port(1))

    for dequantize in graph.get_op_nodes(op='DequantizeLinear'):
        dq_name = dequantize.soft_get('name', dequantize.id)
        axis = dequantize.soft_get('axis', None)
        scale_shape = dequantize.in_port(1).data.get_shape()

        float_type = data_type_str_to_np(graph.graph['cmd_params'].data_type)
        to_float = Cast(graph, {'dst_type': float_type, 'name': dq_name + '/Cast'}).create_node()
        dequantize.in_port(0).get_connection().set_destination(to_float.in_port(0))
        scaled = Mul(graph, {}).create_node()

        has_zero_point = dequantize.is_in_port_connected(2)
        if has_zero_point:
            shifted = Sub(graph, {'name': dq_name + '/Sub'}).create_node()
            to_float.out_port(0).connect(shifted.in_port(0))
            dequantize.in_port(2).get_connection().set_destination(shifted.in_port(1))
            shifted.out_port(0).connect(scaled.in_port(0))
        else:
            to_float.out_port(0).connect(scaled.in_port(0))
        dequantize.in_port(1).get_connection().set_destination(scaled.in_port(1))

        dequantize.out_port(0).get_connection().set_source(scaled.out_port(0))
        rename_nodes([(dequantize, dq_name + '/TBD'), (scaled, dq_name)])

        assert scale_shape is not None
        is_per_axis = axis is not None and len(scale_shape) > 0 and scale_shape[0] > 1
        if not is_per_axis:
            continue

        data_shape = to_float.in_port(0).data.get_shape()
        broadcast_shape = np.ones(len(data_shape), np.int64)
        broadcast_shape[axis] = data_shape[axis]
        route_through_reshape(scaled, dq_name + '/Reshape/Mul', broadcast_shape)
        if has_zero_point:
            route_through_reshape(shifted, dq_name + '/Reshape/Sub', broadcast_shape)
def create_ss_interval_border(graph: Graph, shape, axes, port_to_connect: Port, node_name):
    """Build a StridedSlice begin/end vector: zeros, then the given values, then zeros.

    The values arriving on ``port_to_connect`` are cast to int64 and placed at Concat
    input 1. Concat input 0 holds ``axes[0]`` zeros, and input 2 holds one zero for every
    dimension of ``shape`` after ``axes[-1]``. Assumes ``axes`` is sorted and contiguous
    (TODO confirm with callers).

    :return: the Concat node producing the full-rank vector.
    """
    zeros = np.zeros(len(shape), dtype=np.int64)
    leading = zeros[:axes[0]]
    trailing = zeros[axes[-1] + 1:]

    to_i64 = Cast(graph, dict(name=node_name + '/CastToI64', dst_type=np.int64)).create_node()
    port_to_connect.get_connection().set_destination(to_i64.in_port(0))

    concat_attrs = {'name': node_name + '/Concat', 'axis': 0, 'in_ports_count': 3}
    concat = create_op_with_const_inputs(graph, Concat, port_value_dict={0: leading, 2: trailing},
                                         op_attrs=concat_attrs)
    to_i64.out_port(0).connect(concat.in_port(1))
    return concat
def find_and_replace_pattern(self, graph: Graph):
    """Rewrite every ThresholdedRelu(x) as ``x * Cast(x > alpha)``.

    The boolean mask is cast to the model data type. The Mul takes over the original
    node's name and consumers.
    """
    for thr_relu in graph.get_op_nodes(op='ThresholdedRelu'):
        original_name = thr_relu.soft_get('name', thr_relu.id)

        is_above = create_op_with_const_inputs(graph, Greater, {1: float_array([thr_relu.alpha])})
        data = thr_relu.in_port(0).get_source()
        is_above.in_port(0).connect(data)

        float_type = data_type_str_to_np(graph.graph['cmd_params'].data_type)
        mask = Cast(graph, {'dst_type': float_type}).create_node()
        is_above.out_port(0).connect(mask.in_port(0))

        product = Mul(graph, {}).create_node()
        thr_relu.out_port(0).get_connection().set_source(product.out_port(0))
        product.in_port(0).connect(data)
        product.in_port(1).connect(mask.out_port(0))

        rename_nodes([(thr_relu, original_name + '/TBR'), (product, original_name)])
def quantize_data(fake_quantize: Node, dst_type: type, quantized_type: type, mode: str):
    """Insert a quantizing FakeQuantize copy plus a Cast in front of ``fake_quantize``.

    The copy receives the original data input and input range. It maps onto the
    integer grid ``[i_min, i_min + levels - 1]``, where ``i_min = -(levels // 2)`` for
    ``"signed"`` and ``0`` for ``"unsigned"``. Its output is cast to ``quantized_type``.
    ``fake_quantize`` is then fed by that Cast, and its input range becomes the same
    integer grid.

    :param fake_quantize: FakeQuantize node to split.
    :param dst_type: numpy type of the integer-grid limit constants.
    :param quantized_type: numpy type the quantized values are cast to.
    :param mode: ``"signed"`` or ``"unsigned"``.
    """
    graph = fake_quantize.graph
    name = fake_quantize.soft_get('name', fake_quantize.id)
    levels = fake_quantize.levels

    quantize = fake_quantize.copy_node(dict(name=name + '/Copy', stop_value_propagation=False), graph)
    # hand the data input (port 0) and the input limits (ports 1, 2) over to the copy
    for port_idx in (0, 1, 2):
        fake_quantize.in_port(port_idx).get_connection().set_destination(quantize.in_port(port_idx))

    # output limits of the quantizing copy
    assert mode in ["signed", "unsigned"]
    if mode == "signed":
        lowest = -(levels // 2)
    else:
        lowest = 0
    grid_low = np.array([lowest], dtype=dst_type)
    grid_high = np.array(levels + grid_low - 1, dtype=dst_type)
    assert grid_high - grid_low == levels - 1

    out_low = Const(graph, dict(name=name + '/Copy/out_low', value=grid_low)).create_node()
    out_high = Const(graph, dict(name=name + '/Copy/out_high', value=grid_high)).create_node()
    for limit, port_idx in ((out_low, 3), (out_high, 4)):
        limit.out_port(0).connect(quantize.in_port(port_idx))
    for limit, port_idx in ((out_low, 1), (out_high, 2)):
        limit.out_port(0).connect(fake_quantize.in_port(port_idx))

    weights_source = quantize.in_port(0).get_source().node
    cast_name = weights_source.soft_get('name', weights_source.id) + '/quantized'
    to_quantized = Cast(graph, dict(name=cast_name, dst_type=quantized_type,
                                    stop_value_propagation=False)).create_node()
    quantize.out_port(0).connect(to_quantized.in_port(0))
    to_quantized.out_port(0).connect(fake_quantize.in_port(0))
def find_and_replace_pattern(self, graph: Graph):
    """Cast floating-point RandomUniform outputs to the IR data type when the two differ.

    Nodes marked with ``returns_shape_value`` and non-floating output types are left
    untouched.
    """
    ir_data_type = data_type_str_to_np(graph.graph['cmd_params'].data_type)
    for random_node in graph.get_op_nodes(op='RandomUniform'):
        assert random_node.has_valid('output_type')

        if random_node.has_and_set('returns_shape_value'):
            continue
        out_type = random_node.output_type
        if out_type == ir_data_type or not np.issubdtype(out_type, np.floating):
            continue

        cast_name = random_node.soft_get('name', random_node.id) + "/cast"
        to_ir_type = Cast(graph, {'name': cast_name, 'dst_type': ir_data_type}).create_node()
        random_node.out_port(0).get_connection().insert_node(to_ir_type)
def replace_op(self, graph: Graph, node: Node):
    """Rebuild the boxes input (port 1) of the node in place; the node itself is kept.

    Steps, as wired below:
      * batch indices (input 2) are unsqueezed on axis 1 and cast to the model data type;
      * they are concatenated (axis 1) in front of the boxes (input 1);
      * the original boxes edge is removed;
      * the concatenated tensor goes through the coordinate-swapping convolution and is
        reshaped to ``[-1, 5]``;
      * the result is reconnected to input port 1.
    Returns an empty list, so no output edges are replaced. Nodes already marked
    ``inputs_preprocessed`` are skipped.
    """
    if node.has_and_set('inputs_preprocessed'):
        log.debug('Node "{}" has already been preprocessed'.format(node.soft_get('name')))
        return []
    # reshape tensor with batch indices to 2d
    unsqueeze_node = create_op_node_with_second_input(graph, Unsqueeze, int64_array([1]),
                                                      {'name': node.name + '/Unsqueeze'}, node.in_node(2))

    # cast batch indices to the model data type so they can be concatenated with the boxes
    convert_node = Cast(graph, {'name': unsqueeze_node.name + '/ToFloat',
                                'dst_type': data_type_str_to_np(graph.graph['cmd_params'].data_type)}).create_node()

    convert_node.in_port(0).connect(unsqueeze_node.out_port(0))

    concat_op = Concat(graph, {'axis': 1, 'name': node.name + '/concat_batch_indices_and_boxes',
                               'in_ports_count': 2})
    concat_node = concat_op.create_node([convert_node, node.in_node(1)])

    # do not remove edge with crop_size because it is needed in the partial infer
    graph.remove_edge(node.in_node(1).id, node.id)

    # input to the CropAndResize contains boxes coordinates in YXYX layout. But IE layer ROIPooling expects
    # coordinates in the XYXY layout, so convolution is added here to swap coordinates
    swapped_box_coordinates_node = add_convolution_to_swap_xy_coordinates(graph, concat_node, 5)

    # reshape locations tensor to 2D so it could be passed to Eltwise which will be converted to ScaleShift
    reshape_2d_node = create_op_node_with_second_input(graph, Reshape, int64_array([-1, 5]),
                                                       dict(name=swapped_box_coordinates_node.id + '/reshape_2d_'),
                                                       swapped_box_coordinates_node)
    # feed the rebuilt boxes back into input port 1
    graph.create_edge(reshape_2d_node, node, 0, 1)

    # do not replace any output edge
    return []
def replace_sub_graph(self, graph: Graph, match: dict):
    """Convert TF Slice(input, begin, size) into Slice(input, begin, ends).

    ``ends = begin + size``, except where ``size[i] == -1``; there ``ends[i]`` is
    ``ShapeOf(input)[i]``, selected with a Select. ``ends`` is cast to int64 so both
    Select branches share a type. The new Slice takes over the TF node's name and
    consumers.
    """
    tf_slice_node = match['op']
    slice_name = tf_slice_node.soft_get('name', tf_slice_node.id)
    slice_node = Slice(graph).create_node()
    rename_nodes([(tf_slice_node, slice_name + '/to_be_removed'), (slice_node, slice_name)])
    ends_node = Add(graph, {'name': slice_name + '/ends'}).create_node()

    # reconnect input, begin, and size from TFSlice to the subgraph with Slice
    tf_slice_node.in_port(0).get_connection().set_destination(slice_node.in_port(0))
    tf_slice_node.in_port(1).get_connection().set_destination(slice_node.in_port(1))
    tf_slice_node.in_port(2).get_connection().set_destination(ends_node.in_port(0))
    # ends = size + begin
    slice_node.in_port(1).get_connection().add_destination(ends_node.in_port(1))

    # upper bound used where size == -1
    max_ends = Shape(graph, {'name': slice_name + '/ShapeOf'}).create_node()
    slice_node.in_port(0).get_connection().add_destination(max_ends.in_port(0))

    # check if size[i] == -1, will be applied elementwisely: len(size) = len(begin) = input_rank
    where_max_ends_is_needed = create_op_with_const_inputs(graph, Equal, {0: int64_array(-1)},
                                                           {'name': slice_name + '/where_max_ends_is_needed'})
    ends_node.in_port(0).get_connection().add_destination(where_max_ends_is_needed.in_port(1))
    # select requires equal dtypes, need to convert ends to I64
    ends_casted_to_i64 = Cast(graph, {'name': slice_name + '/CastToI64',
                                      'dst_type': np.int64}).create_node([ends_node])
    # if size[i] == -1 then take max_ends values
    correct_ends = Select(graph, {'name': slice_name + '/chosen_ends'}).create_node([where_max_ends_is_needed,
                                                                                     max_ends, ends_casted_to_i64])
    correct_ends.out_port(0).connect(slice_node.in_port(2))

    tf_slice_node.out_port(0).get_connection().set_source(slice_node.out_port(0))
def convert_inputs_of_specific_ops(graph: Graph):
    """Force selected input ports of selected op types to a fixed precision.

    The table maps an op ``type`` to ``{input_port: precision}``. For each
    matching op node, every listed port that exists and is connected is
    converted as follows:

    * a Const producer has its stored value converted in place;
    * any other producer gets a Cast node inserted on the connection.
    """
    type_port = {'Broadcast': {1: 'int64', 2: 'int64'},
                 'ConvolutionBackpropData': {2: 'int64'},
                 'Deconvolution': {2: 'int64'},
                 'Gather': {2: 'int64'},
                 'GroupConvolutionBackpropData': {2: 'int64'},
                 'Interpolate': {1: 'int64'},
                 'LRN': {1: 'int64'},
                 'NonMaxSuppression': {2: 'int64'},
                 'NormalizeL2': {1: 'int64'},
                 'OneHot': {1: 'int64'},
                 'Pad': {1: 'int64', 2: 'int64'},
                 'PriorBox': {0: 'int64', 1: 'int64'},
                 'PriorBoxClustered': {0: 'int64', 1: 'int64'},
                 'ReduceLogicalAnd': {1: 'int64'},
                 'ReduceLogicalOr': {1: 'int64'},
                 'ReduceMax': {1: 'int64'},
                 'ReduceMean': {1: 'int64'},
                 'ReduceMin': {1: 'int64'},
                 'ReduceProd': {1: 'int64'},
                 'ReduceSum': {1: 'int64'},
                 'Reshape': {1: 'int64'},
                 'Squeeze': {1: 'int64'},
                 'StridedSlice': {1: 'int64', 2: 'int64', 3: 'int64'},
                 'Split': {1: 'int64'},
                 'Tile': {1: 'int64'},
                 'Transpose': {1: 'int64'},
                 'Unsqueeze': {1: 'int64'},
                 'VariadicSplit': {1: 'int64', 2: 'int64'},
                 }

    for op_node in graph.get_op_nodes():
        port_precisions = type_port.get(op_node.soft_get('type'))
        if port_precisions is None:
            continue

        for idx, target_precision in port_precisions.items():
            # skip ports the op does not have or that nothing feeds
            if idx not in op_node.in_ports():
                continue
            port = op_node.in_port(idx)
            if port.disconnected():
                continue

            log.debug('Converting value for the input port "{}" of op "{}" to "{}".'
                      ''.format(idx, op_node.soft_get('name', op_node.id), target_precision))
            target_np_type = data_type_str_to_np(target_precision)
            producer = port.get_source().node

            if producer.type != 'Const':
                # non-constant producer: convert at runtime with a Cast
                port.get_connection().insert_node(Cast(graph, {'dst_type': target_np_type}).create_node())
            else:
                # constant producer: rewrite the stored value directly
                convert_const_node_value_type(producer, target_np_type)
def replace_op(self, graph: Graph, node: Node):
    """Decompose a DequantizeLinear node into Cast -> [Sub] -> Mul.

    * Port 0 (data) is cast to the model data type.
    * If port 2 is connected, its value is subtracted from the cast data.
    * The result is multiplied by the tensor on port 1.

    The final Mul takes over the node's name, and the original node is
    renamed with a ``/TBD`` suffix.

    Returns a one-element list holding the id of that Mul node.
    """
    name = node.soft_get('name', node.id)
    target_type = data_type_str_to_np(graph.graph['cmd_params'].data_type)

    to_target = Cast(graph, {'dst_type': target_type, 'name': name + '/Cast'}).create_node()
    node.in_port(0).get_connection().set_destination(to_target.in_port(0))
    scaled = Mul(graph, {}).create_node()

    if not node.is_in_port_connected(2):
        # no zero point: feed the cast data straight into the multiplication
        to_target.out_port(0).connect(scaled.in_port(0))
    else:
        shifted = Sub(graph, {'name': name + '/Sub'}).create_node()
        to_target.out_port(0).connect(shifted.in_port(0))
        # NOTE(review): the zero point keeps its original dtype (no Cast here);
        # confirm the Sub inputs are type-harmonized elsewhere
        node.in_port(2).get_connection().set_destination(shifted.in_port(1))
        shifted.out_port(0).connect(scaled.in_port(0))

    node.in_port(1).get_connection().set_destination(scaled.in_port(1))

    rename_nodes([(node, name + '/TBD'), (scaled, name)])
    return [scaled.id]
def replace_sub_graph(self, graph: Graph, match: dict):
    """Replace a TF-style slice (input, begin, size) with Slice (input, start, end).

    Ends are computed as ``begin + size`` and cast to int64. For every axis
    where ``size == -1``, the int32 maximum is used as the end instead. The
    new Slice node inherits the matched node's name.
    """
    node = match['op']
    slice_name = node.soft_get('name', node.id)
    slice_node = Slice(graph).create_node()
    rename_nodes([(node, slice_name + '/to_be_removed'), (slice_node, slice_name)])
    eq_node = Equal(graph, {'name': slice_name + '/equal'}).create_node()
    minus_one_node = Const(graph, {'name': slice_name + '/minus_one', 'value': np.array(-1)}).create_node()
    # stand-in end for size == -1; presumably Slice clamps it to the axis length -- confirm
    int32_max_node = Const(graph, {'name': slice_name + '/int32_max', 'value': np.iinfo(np.int32).max}).create_node()
    select_node = Select(graph, {'name': slice_name + '/select'}).create_node()

    # node to convert sizes to ends
    sum_node = Add(graph, {'name': slice_name + '/end_const'}).create_node()

    # reconnect input from tfslice to slice
    node.in_port(0).get_source().connect(slice_node.in_port(0))
    node.in_port(0).disconnect()
    # reconnect begin of tfslice to start of slice
    node.in_port(1).get_source().connect(slice_node.in_port(1))
    node.in_port(1).disconnect()

    # (size -> ends) reconnect begins and sizes to sum to evaluate ends for Slice
    # connect begins to the sum: ends = begin + size
    slice_node.in_port(1).get_source().connect(sum_node.in_port(0))
    node.in_port(2).get_source().connect(sum_node.in_port(1))
    node.in_port(2).disconnect()

    # if size[i] == -1 then take int32_max as end[i]
    sum_node.in_port(1).get_source().connect(eq_node.in_port(0))
    minus_one_node.out_port(0).connect(eq_node.in_port(1))
    # from equal to port 0 (condition) of select
    eq_node.out_port(0).connect(select_node.in_port(0))
    # from int32_max to port 1 of select
    int32_max_node.out_port(0).connect(select_node.in_port(1))
    # from sum to port 2 of select
    sum_node.out_port(0).connect(select_node.in_port(2))
    # out of select to end (port 2 of slice)
    select_node.out_port(0).connect(slice_node.in_port(2))
    # the computed ends are cast to int64 before entering the select
    cast = Cast(graph, dict(name=sum_node.name + '/CastToI64', dst_type=np.int64)).create_node()
    select_node.in_port(2).get_connection().insert_node(cast)
    node.out_port(0).get_connection().set_source(slice_node.out_port(0))
def replace_op(self, graph: Graph, node: Node):
    """Express QuantizeLinear as FakeQuantize (256 levels) followed by a Cast.

    The zero point comes from port 2. It must be a Const, and it defaults to a
    uint8 zero if port 2 is unconnected. Its dtype selects the output range:
    int8 gives [-128, 127] and uint8 gives [0, 255]. The FakeQuantize input
    range is ``port1 * (output_bound - zero_point)``. Port 1 is presumably the
    ONNX ``y_scale``.

    Returns a one-element list with the id of the trailing Cast, which takes
    over the node's name.

    Raises:
        Error: if the zero point dtype is neither int8 nor uint8.
    """
    name = node.soft_get('name', node.id)

    if not node.is_in_port_connected(2):
        zp = Const(graph, {'value': np.array(0, dtype=np.uint8), 'name': name + '/ZeroPoint'}).create_node()
    else:
        zp = node.in_port(2).get_source().node
    assert zp.soft_get('type') == 'Const', 'only constant for zero_point is supported for QuantizeLinear'

    zp_dtype = zp.value.dtype
    # data type affects range of output values: [-128..127] or [0..255]
    if zp_dtype == np.int8:
        out_low, out_high = -128.0, 127.0
    elif zp_dtype == np.uint8:
        out_low, out_high = 0.0, 255.0
    else:
        raise Error('Not expected type {} for zero point value in node {}'.format(
            zp_dtype, zp.soft_get('name')))

    fq = create_op_with_const_inputs(graph, FakeQuantize,
                                     {3: float_array(out_low), 4: float_array(out_high)},
                                     {'levels': 256, 'name': name + '/FakeQuantize'})
    node.in_port(0).get_connection().set_destination(fq.in_port(0))

    # input_low = port1 * (output_low - zero_point)
    low_scale = create_op_with_const_inputs(graph, Mul, {1: float_array(out_low - zp.value)},
                                            {'name': name + '/Mul/Low'})
    node.in_port(1).get_connection().set_destination(low_scale.in_port(0))
    low_scale.out_port(0).connect(fq.in_port(1))

    # input_high = port1 * (output_high - zero_point); reuses the same port-1 source
    high_scale = create_op_with_const_inputs(graph, Mul, {1: float_array(out_high - zp.value)},
                                             {'name': name + '/Mul/High'})
    low_scale.in_port(0).get_connection().add_destination(high_scale.in_port(0))
    high_scale.out_port(0).connect(fq.in_port(2))

    # the result is cast to the zero point's integer dtype
    to_int = Cast(graph, {'dst_type': zp_dtype, 'name': name + '/Cast'}).create_node()
    rename_nodes([(node, name + '/TBD'), (to_int, name)])

    fq.out_port(0).connect(to_int.in_port(0))

    return [to_int.id]
def add_removed_converts(graph: Graph):
    """Re-insert a Cast to float32 after Const ops whose data node requests it.

    This applies to every data node carrying ``Insert_Convert_operation_after=True``.
    A Cast (dst_type float32, value propagation stopped) is placed between its
    producing Const op and the consumer. The Const value itself is converted
    to float16 and the Const is re-inferred.
    """
    for data_node_name in graph.get_nodes_with_attributes(Insert_Convert_operation_after=True):
        data_node = Node(graph, data_node_name)
        # Get access to Const node connected to data node
        const_op = data_node.in_node(0)
        assert const_op.data_type == np.float32, "Error when try to insert Convert operation after Const: {}".\
            format(const_op.soft_get('name'))

        convert_op = Cast(graph, {'dst_type': np.float32,
                                  'name': const_op.name + '/restored_convert',
                                  'stop_value_propagation': True}).create_node()

        # Insert Convert operation after Const operation
        # NOTE(review): get_destination() yields a single consumer port; confirm a
        # flagged Const never feeds more than one consumer here
        consumer_port = const_op.out_port(0).get_connection().get_destination()
        const_op.out_port(0).get_connection().set_destination(convert_op.in_port(0))
        convert_op.out_port(0).connect(consumer_port)

        # Store the Const value as FP16; the inserted Convert restores FP32 for the consumer
        const_op.value, _, _ = convert_blob(const_op.value, np.float16)
        const_op.infer(const_op)
def convert_outputs_of_specific_ops(graph: Graph):
    """Insert a Cast right after selected output ports of selected op types.

    Currently output port 0 of ShapeOf and of NonMaxSuppression is converted
    to int32. The target type comes from the ``{op_type: {port: precision}}``
    table.
    """
    type_port = {
        'ShapeOf': {0: 'int32'},
        'NonMaxSuppression': {0: 'int32'},
    }

    for op_node in graph.get_op_nodes():
        op_type = op_node.soft_get('type')
        if op_type not in type_port:
            continue

        for out_idx, target_precision in type_port[op_type].items():
            if out_idx not in op_node.out_ports():
                continue

            log.debug(
                'Insert Convert after op "{}" to type "{}"'.format(
                    op_node.soft_get('name', op_node.id), target_precision))
            connection = op_node.out_port(out_idx).get_connection()
            cast_node = Cast(graph, {'dst_type': data_type_str_to_np(target_precision)}).create_node()
            connection.insert_node(cast_node)
def replace_pattern(self, graph: Graph, match: dict):
    """Replace the matched ``merge`` node with an explicit arithmetic subgraph.

    The first input of ``merge`` is raised to the power -1. For every ``idx`` in
    ``[merge.significant, merge.to_significant]``, the subgraph computes
    ``cast_f32(10**idx * x**-1) * 10**-idx``. The results are concatenated
    along axis 0, reduced with ReduceMean over axis 0 (keep_dims=False) and
    reshaped to [1, 5]. The Reshape output replaces merge's output.
    """
    merge = match['merge']
    # x ** -1 (reciprocal) of merge's first input
    # NOTE(review): a Pow node is given 'type': 'PNORM' here -- confirm this is intended
    power = Pow(graph, {
        'name': merge.name + '/reciprocal_',
        'type': 'PNORM'
    }).create_node()
    const1 = Const(graph, {
        'value': -1.0,
        'name': merge.name + '/negate_const'
    }).create_node()
    merge.in_port(0).get_connection().set_destination(power.in_port(0))
    const1.out_port(0).connect(power.in_port(1))

    concat_node = Concat(
        graph, {
            'axis': 0,
            'name': merge.name + '/Concat_',
            'override_output_shape': True
        }).create_node()
    # reduction axis for the ReduceMean below
    const3 = Const(graph, {
        'name': merge.name + '/const_reduce',
        'value': 0
    }).create_node()

    # one scaled branch per idx; ii is the branch / Concat input index
    for ii, idx in enumerate(
            range(merge.significant, merge.to_significant + 1, 1)):
        const_node = Const(
            graph, {
                'value': float_array(math.pow(10.0, idx)),
                'name': merge.name + '/Const_' + ii.__str__()
            }).create_node()

        mul_node = Mul(graph, {
            'name': merge.name + '/Mul_' + ii.__str__()
        }).create_node()
        const_node.out_port(0).connect(mul_node.in_port(0))

        power.out_port(0).connect(
            mul_node.in_port(1))  # connect to the graph node
        mul_node2 = Mul(graph, {
            'name': merge.name + '/Mul_Div_' + ii.__str__()
        }).create_node()

        const_node2 = Const(
            graph, {
                'value': float_array(math.pow(10.0, -1 * idx)),
                'name': merge.name + '/Const_Pow_' + ii.__str__()
            }).create_node()
        # NOTE(review): the Cast name uses idx (may be negative) while sibling
        # nodes use ii -- naming is inconsistent but left unchanged
        cast_node = Cast(
            graph, {
                'name': merge.name + '/Cast_' + idx.__str__(),
                'dst_type': np.float32
            }).create_node()

        mul_node.out_port(0).connect(cast_node.in_port(0))
        const_node2.out_port(0).connect(mul_node2.in_port(1))
        cast_node.out_port(0).connect(mul_node2.in_port(0))
        concat_node.add_input_port(ii, skip_if_exist=True)
        concat_node.in_port(ii).get_connection().set_source(
            mul_node2.out_port(0))

    # despite the "reduced_sum" naming, this is a ReduceMean over axis 0
    reducesum_node = ReduceMean(
        graph, {
            'name': merge.id + '/_pnorm_reduced_sum',
            'keep_dims': False,
            'in_ports_count': 2,
            'need_shape_inference': None,
            'infer': reduce_infer
        }).create_node()

    const3.out_port(0).connect(reducesum_node.in_port(1))
    reducesum_node.in_port(0).get_connection().set_source(
        concat_node.out_port(0))

    reshape = Reshape(graph, {
        'name': merge.name + '/Reshape_Node'
    }).create_node()
    # NOTE(review): target shape [1, 5] is hard-coded; assumes the reduced
    # result has exactly 5 elements -- confirm against the matched pattern
    reshape_dim = Const(graph, {
        'value': np.array([1, 5]),
        'name': merge.id + '/Reshape_Dim'
    }).create_node()
    reducesum_node.out_port(0).connect(reshape.in_port(0))
    reshape.in_port(1).connect(reshape_dim.out_port(0))
    merge.out_port(0).get_connection().set_source(reshape.out_port(0))
def replace_resize(graph: Graph, resize: Node):
    """Convert an ONNX Resize-11 node into an opset4 Interpolate node.

    Only 4D and 5D inputs are converted. For other ranks a warning is logged
    and the node is left untouched. The conversion depends on how many inputs
    are connected:

    * 3 inputs: the scales tensor (port 2) drives the result. Sizes are
      derived as ``int64(floor(ShapeOf(x) * scales + 1e-5))`` and
      ``shape_calculation_mode`` is 'scales'.
    * 4 inputs: the sizes tensor (port 3) drives the result. Scales are
      derived as ``sizes / ShapeOf(x) + 1e-5`` and ``shape_calculation_mode``
      is 'sizes'.

    Both sizes and scales are strided-sliced to the spatial axes
    ``[begin_dim, end_dim)`` taken from the graph layout, and those axes are
    passed to Interpolate on port 3.
    """
    log.debug("Converting of ONNX Resize-11 to Interpolate-4 "
              "is triggered for node {}.".format(
                  resize.soft_get('name', resize.id)))

    input_shape = resize.in_port(0).data.get_shape()
    input_rank = len(input_shape)
    resize_name = resize.soft_get('name', resize.id)
    if input_rank not in {4, 5}:
        log.warning(
            'The input shape is not 4D or 5D for op with name {}'.format(
                resize_name))
        return

    num_of_inputs = len([
        port for port in resize.in_ports().values() if not port.disconnected()
    ])
    assert num_of_inputs in {3, 4}, \
        "Number of inputs of ONNXResize (with name {}) should be equal to 3 or 4".format(resize_name)

    assert resize.soft_get('coordinate_transformation_mode') != 'tf_crop_and_resize', \
        'Mode tf_crop_and_resize is not supported for op {} with name {}'.format(resize.op, resize_name)

    # spatial axes: H..W for 4D, D..W for 5D (end is exclusive)
    layout = graph.graph['layout']

    if input_rank == 4:
        begin_dim = get_height_dim(layout, input_rank)
        end_dim = get_width_dim(layout, input_rank) + 1
    else:
        begin_dim = get_depth_dim(layout, input_rank)
        end_dim = get_width_dim(layout, input_rank) + 1

    # slices that keep only the spatial part of the full-rank sizes / scales
    sizes_ss = create_op_with_const_inputs(
        graph, StridedSlice, {
            1: int64_array([begin_dim]),
            2: int64_array([end_dim]),
            3: int64_array([1])
        }, {
            'name': resize_name + '/StridedSlice_sizes',
            'begin_mask': int64_array([1]),
            'end_mask': int64_array([1]),
            'new_axis_mask': int64_array([0]),
            'shrink_axis_mask': int64_array([0]),
            'ellipsis_mask': int64_array([0])
        })
    scales_ss = create_op_with_const_inputs(
        graph, StridedSlice, {
            1: int64_array([begin_dim]),
            2: int64_array([end_dim]),
            3: int64_array([1])
        }, {
            'name': resize_name + '/StridedSlice_scales',
            'begin_mask': int64_array([1]),
            'end_mask': int64_array([1]),
            'new_axis_mask': int64_array([0]),
            'shrink_axis_mask': int64_array([0]),
            'ellipsis_mask': int64_array([0])
        })
    axes_node = Const(
        graph, {
            'name': resize_name + '/axis',
            'value': int64_array(np.arange(begin_dim, end_dim))
        }).create_node()

    shape_calculation_mode = 'scales' if num_of_inputs == 3 else 'sizes'

    interpolate_node = Interpolate(
        graph, {
            'version': 'opset4',
            'mode': convert_mode(resize.mode),
            'coordinate_transformation_mode': resize.coordinate_transformation_mode,
            'cube_coeff': resize.cube_coeff,
            'nearest_mode': resize.nearest_mode,
            'pads_begin': int64_array([0]),
            'pads_end': int64_array([0]),
            'antialias': 0,
            'shape_calculation_mode': shape_calculation_mode,
            'in_ports_count': 4
        }).create_node()

    axes_node.out_port(0).connect(interpolate_node.in_port(3))
    shape_of = Shape(graph, {'name': resize_name + '/ShapeOf'}).create_node()

    # small epsilon added before floor / to the derived scales
    add_node = create_op_with_const_inputs(graph, Add,
                                           {1: float_array([1.0e-5])},
                                           {'name': resize_name + '/Add'})

    input_data_type = data_type_str_to_np(graph.graph['cmd_params'].data_type)

    if num_of_inputs == 3:
        # sizes = int64(floor(float(ShapeOf(x)) * scales + 1e-5))
        cast_shape_to_float = Cast(graph, {
            'dst_type': input_data_type
        }).create_node()
        mul_node = Mul(graph, {'name': resize_name + '/Mul'}).create_node()
        shape_of.out_port(0).connect(cast_shape_to_float.in_port(0))
        cast_shape_to_float.out_port(0).connect(mul_node.in_port(0))
        cast_add_result_to_int = Cast(graph, {
            'dst_type': np.int64
        }).create_node()
        floor_node = Floor(graph, {
            'name': resize_name + '/Floor'
        }).create_node()
        mul_node.out_port(0).connect(add_node.in_port(0))
        add_node.out_port(0).connect(floor_node.in_port(0))
        floor_node.out_port(0).connect(cast_add_result_to_int.in_port(0))
        cast_add_result_to_int.out_port(0).connect(sizes_ss.in_port(0))
        sizes_ss.out_port(0).connect(interpolate_node.in_port(1))
        scales_ss.out_port(0).connect(interpolate_node.in_port(2))

        connection_of_resize_input = resize.in_port(0).get_connection()
        connection_of_resize_input.set_destination(interpolate_node.in_port(0))

        connection_of_scales = resize.in_port(2).get_connection()
        connection_of_scales.set_destination(scales_ss.in_port(0))

        # the data and scales sources also feed the size computation
        connection_of_resize_input.get_source().connect(shape_of.in_port(0))
        connection_of_scales.get_source().connect(mul_node.in_port(1))
    else:
        # scales = float(sizes) / float(ShapeOf(x)) + 1e-5
        cast_shape_to_float = Cast(graph, {
            'dst_type': input_data_type
        }).create_node()
        cast_sizes_to_float = Cast(graph, {
            'dst_type': input_data_type
        }).create_node()
        div_node = Div(graph, {'name': resize_name + '/Div'}).create_node()
        cast_sizes_to_float.out_port(0).connect(div_node.in_port(0))
        cast_shape_to_float.out_port(0).connect(div_node.in_port(1))
        shape_of.out_port(0).connect(cast_shape_to_float.in_port(0))
        div_node.out_port(0).connect(add_node.in_port(0))
        add_node.out_port(0).connect(scales_ss.in_port(0))
        scales_ss.out_port(0).connect(interpolate_node.in_port(2))
        sizes_ss.out_port(0).connect(interpolate_node.in_port(1))

        connection_of_resize_input = resize.in_port(0).get_connection()
        connection_of_resize_input.set_destination(interpolate_node.in_port(0))

        connection_of_sizes = resize.in_port(3).get_connection()
        connection_of_sizes.set_destination(sizes_ss.in_port(0))

        # the data and sizes sources also feed the scale computation
        connection_of_resize_input.get_source().connect(shape_of.in_port(0))
        connection_of_sizes.get_source().connect(
            cast_sizes_to_float.in_port(0))

    # Interpolate takes over the Resize name; consumers are moved to it
    rename_nodes([(resize, resize_name + '/delete'),
                  (interpolate_node, resize_name)])
    resize.out_port(0).get_connection().set_source(
        interpolate_node.out_port(0))
def find_and_replace_pattern(self, graph: Graph):
    """Normalize the inputs of every SpaceToBatch / BatchToSpace node.

    1. The [N, 2] pads/crops tensor on port 2 is transposed, split along
       axis 0 and squeezed into two 1D tensors on ports 2 and 3.
    2. The inputs on ports 1, 2 and 3 (block_shape, begin, end) are each
       padded to the rank of the data input (port 0). Each Pad uses
       pads_begin = [1] and pads_end = rank(data) - len(block_shape) - 1.
       block_shape uses a pad value of 1.
    3. A Cast to int64 is inserted in front of each Pad.
    """
    for node in graph.get_op_nodes(op='SpaceToBatch') + graph.get_op_nodes(
            op='BatchToSpace'):
        node.add_input_port(3, skip_if_exist=True)

        # convert TF representation of the pads/crops as [N, 2] to IE representation: [N] and [N]
        transposed_pads = create_op_with_const_inputs(
            graph, Transpose, {1: int64_array([1, 0])})
        node.in_port(2).get_connection().set_destination(
            transposed_pads.in_port(0))
        split_pads = create_op_with_const_inputs(graph, Split,
                                                 {1: int64_array(0)},
                                                 {'num_splits': 2})
        transposed_pads.out_port(0).connect(split_pads.in_port(0))
        # split outputs are [1, N]; squeeze axis 0 to get [N] on ports 2 and 3
        for port_ind in range(2):
            node.in_port(port_ind + 2).connect(
                split_pads.out_port(port_ind))
            node.in_port(port_ind + 2).get_connection().insert_node(
                create_op_with_const_inputs(graph, Squeeze,
                                            {1: int64_array([0])}))

        # add zeros/ones to related inputs to align it with data input
        in0_rank = Rank(graph, {
            'name': node.name + '/rank_0'
        }).create_node()
        # ShapeOf of the 1D block_shape, i.e. its length (despite the '/rank_1' name)
        in1_shape = Shape(graph, {
            'name': node.name + '/rank_1'
        }).create_node()

        diff_size = Sub(graph, {
            'name': node.name + '/sub_0'
        }).create_node()
        diff = Sub(graph, {'name': node.name + '/sub_1'}).create_node()

        # [1] is used both as pads_begin of every Pad and as the subtrahend in diff
        const_begin = Const(graph, {
            'value': int64_array([1])
        }).create_node()
        # pad value for block_shape (port 3 of its Pad)
        const_pad_val = Const(graph, {
            'value': int64_array(1)
        }).create_node()

        block_shape = Pad(graph, {
            'name': node.name + '/aligned_block_shape',
            'mode': 'constant'
        }).create_node()

        # in case of SpaceToBatch begin = pads_begin, end = pads_end
        # in case of BatchToSpace begin = crops_begin, end = crops_end
        new_begin_name = '/aligned_pads_begin'
        new_end_name = '/aligned_pads_end'
        if node.type == 'BatchToSpace':
            new_begin_name = '/aligned_crops_begin'
            new_end_name = '/aligned_crops_end'

        # NOTE(review): these Pads get no port-3 pad value; presumably Pad's
        # default (0) applies -- confirm
        begin = Pad(graph, {
            'name': node.name + new_begin_name,
            'mode': 'constant'
        }).create_node()
        end = Pad(graph, {
            'name': node.name + new_end_name,
            'mode': 'constant'
        }).create_node()

        in0_rank_1d = create_op_node_with_second_input(
            graph, Unsqueeze, int64_array([0]),
            {'name': node.name + '/1d_rank_of_0'}, in0_rank)

        node.in_port(0).get_source().connect(in0_rank.in_port(0))
        node.in_port(1).get_source().connect(in1_shape.in_port(0))
        # diff = rank(data) - len(block_shape) - 1
        in0_rank_1d.out_port(0).connect(diff_size.in_port(0))
        in1_shape.out_port(0).connect(diff_size.in_port(1))
        diff_size.out_port(0).connect(diff.in_port(0))
        const_begin.out_port(0).connect(diff.in_port(1))
        const_pad_val.out_port(0).connect(block_shape.in_port(3))

        # ports 1, 2, 3 -> block_shape, begin, end Pads
        inputs_array = [block_shape, begin, end]
        for idx, input_to_node in enumerate(inputs_array):
            name_of_input_to_node = input_to_node.name
            node.in_port(idx + 1).get_connection().set_destination(
                input_to_node.in_port(0))
            const_begin.out_port(0).connect(input_to_node.in_port(1))
            diff.out_port(0).connect(input_to_node.in_port(2))
            input_to_node.out_port(0).connect(node.in_port(idx + 1))
            convert = Cast(graph, {
                'name': name_of_input_to_node + '/i64',
                'dst_type': np.int64
            }).create_node()
            input_to_node.in_port(0).get_connection().insert_node(convert)
def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
    """Replace an Upsample node with an opset4 Interpolate node.

    The node is left untouched in three cases: the input is not 4D/5D, the
    scales input (port 1) has no constant value, or scales[0] / scales[1] are
    not ~1. Otherwise the pieces are built as follows:

    * sizes: ``int64(float32(StridedSlice(ShapeOf(x))) * factor)`` over the
      spatial axes;
    * scales: a Const holding the same factor;
    * axes: a Const with the spatial axis indices.

    Interpolate takes over the Upsample name.
    """
    log.debug('UpsampleToResample is triggered')
    upsample = match['upsample']
    upsample_name = upsample.soft_get('name', upsample.id)
    input_shape = upsample.in_port(0).data.get_shape()
    input_shape_rank = len(input_shape)
    if input_shape_rank not in [4, 5]:
        log.warning('The input shape is not 4D or 5D for op {}'.format(
            upsample.soft_get('name')))
        return

    depth_scale = None
    layout = graph.graph['layout']

    if len(upsample.in_nodes()) == 2:
        # scales come from a second input and must be constant-folded
        if upsample.in_node(1).value is None:
            return
        scales = upsample.in_node(1).value
        assert len(scales) in (
            4, 5
        ), 'Supported scales rank is 4 or 5, but it is {} for node {}'.format(
            len(scales), upsample_name)
        # NOTE(review): checks axes 0 and 1 for unit scale regardless of layout;
        # presumably assumes batch/channel are the first two axes -- confirm for NHWC
        if not (math.isclose(scales[0], 1, rel_tol=1e-5)
                and math.isclose(scales[1], 1, rel_tol=1e-5)):
            return
        height_scale = scales[get_height_dim(layout, input_shape_rank)]
        width_scale = scales[get_width_dim(layout, input_shape_rank)]
        if len(scales) == 5:
            depth_scale = scales[get_depth_dim(layout, input_shape_rank)]
    else:
        # NOTE(review): depth_scale stays None on this attribute path; a 5D input
        # here would put None into factor_value below -- confirm only 4D reaches it
        height_scale = upsample['height_scale']
        width_scale = upsample['width_scale']

    if 1 in upsample.in_ports() and not upsample.in_port(1).disconnected():
        upsample.in_port(1).disconnect()

    upsample_name = upsample.soft_get('name', upsample.id)
    shape = Shape(graph, {'name': upsample_name + '/0_port'}).create_node()

    layout = graph.graph['layout']

    # spatial slice start and per-axis multipliers: (H, W) for 4D, (D, H, W) for 5D
    if input_shape_rank == 4:
        begin_value = int64_array(
            [get_height_dim(layout, input_shape_rank)])
        factor_value = float32_array([height_scale, width_scale])
    else:
        begin_value = int64_array(
            [get_depth_dim(layout, input_shape_rank)])
        factor_value = float32_array(
            [depth_scale, height_scale, width_scale])

    ss = create_op_with_const_inputs(
        graph, StridedSlice, {
            1: begin_value,
            2: int64_array([get_width_dim(layout, input_shape_rank) + 1]),
            3: int64_array([1])
        }, {
            'name': upsample_name + '/ss_0_port',
            'begin_mask': int64_array([1]),
            'end_mask': int64_array([1]),
            'new_axis_mask': int64_array([0]),
            'shrink_axis_mask': int64_array([0]),
            'ellipsis_mask': int64_array([0])
        })

    mul = create_op_node_with_second_input(
        graph, Mul, factor_value, {'name': upsample_name + '/factor_mul'})

    # sizes branch: ShapeOf(x) -> StridedSlice -> Mul(factor)
    source = upsample.in_port(0).get_connection().get_source()
    source.connect(shape.in_port(0))
    shape.out_port(0).connect(ss.in_port(0))

    ss.out_port(0).connect(mul.in_port(0))

    # Create Interpolate operation
    if input_shape_rank == 4:
        axes = int64_array([
            get_height_dim(layout, input_shape_rank),
            get_width_dim(layout, input_shape_rank)
        ])
    else:
        axes = int64_array([
            get_depth_dim(layout, input_shape_rank),
            get_height_dim(layout, input_shape_rank),
            get_width_dim(layout, input_shape_rank)
        ])

    axes_node = Const(graph, {
        'name': upsample_name + '/axis',
        'value': axes
    }).create_node()

    interpolate = Interpolate(
        graph, {
            'mode': upsample.attrs()['mode'],
            'antialias': 0,
            'pads_begin': int64_array([0]),
            'pads_end': int64_array([0]),
            'coordinate_transformation_mode': 'half_pixel',
            'nearest_mode': 'round_prefer_floor',
            'cube_coeff': -0.75,
            'shape_calculation_mode': 'scales',
            'version': 'opset4',
            'in_ports_count': 4
        }).create_node()

    upsample.add_input_port(1, skip_if_exist=True)
    assert upsample.in_port(1).disconnected()
    mul.out_port(0).connect(interpolate.in_port(1))
    axes_node.out_port(0).connect(interpolate.in_port(3))

    scales_node = Const(graph, {
        'name': upsample_name + '/scales',
        'value': factor_value
    }).create_node()
    scales_node.out_port(0).connect(interpolate.in_port(2))

    upsample.in_port(0).get_connection().set_destination(
        interpolate.in_port(0))
    upsample.out_port(0).get_connection().set_source(
        interpolate.out_port(0))

    rename_nodes([(upsample, upsample_name + '/delete'),
                  (interpolate, upsample_name)])

    # the sliced shape is multiplied in float32 and the product converted back to int64
    convert_to_float = Cast(graph, dict(dst_type=np.float32)).create_node()
    convert_to_int = Cast(graph, dict(dst_type=np.int64)).create_node()

    mul.in_port(0).get_connection().insert_node(convert_to_float)
    mul.out_port(0).get_connection().insert_node(convert_to_int)