def replace_pattern(self, graph: Graph, match: dict):
    gather = match['GatherND']
    gather_name = gather.soft_get('name', gather.id)
    input_shape = gather.in_node(0).shape
    indices = gather.in_node(1).value
    if indices is None:
        # this pass requires constant indices values
        return

    # 0. Check that GatherND can be replaced by Gather
    gather_idx = self.indices_check(indices, input_shape)
    if gather_idx is None:
        log.warning("Node {} with op=GatherND can't be normalized to op=Gather.".format(gather_name))
        return

    # 1. Add Reshape and connect
    new_shape = int64_array([-1] + list(input_shape[indices.shape[-1]:]))
    reshape = create_op_node_with_second_input(graph, Reshape, new_shape,
                                               {'name': gather_name + '/Reshape_for_GatherND/'})
    gather.in_port(0).get_connection().set_destination(reshape.in_port(0))

    # 2. Change indices from ND to 1D
    new_indices = np.reshape(np.take(indices, indices=[gather_idx], axis=-1), [-1])
    rename_node(gather, gather_name + '/to_delete')

    # 3. Create new Gather operation and reconnect all inputs/outputs
    new_gather = create_op_with_const_inputs(graph, Gather, {1: new_indices, 2: int64_array(0)},
                                             {'name': gather_name})
    rename_node(new_gather, gather_name)

    reshape.out_port(0).connect(new_gather.in_port(0))
    gather.out_port(0).get_connection().set_source(new_gather.out_port(0))

    # 4. Remove the old GatherND node
    graph.remove_node(gather.id)
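# A minimal NumPy sketch (illustration only, not part of the transformation) of the
# equivalence this pass relies on. Assumption based on the indices_check contract: all
# index columns except one are zero and the corresponding leading dims are 1, so a single
# column of the ND indices is already a valid flat index after the Reshape.
import numpy as np

data = np.arange(24).reshape(1, 3, 1, 8)                  # only dim 1 is non-trivial
indices = np.array([[0, 2, 0], [0, 0, 0]])                # column 1 is the varying one

gather_nd = np.stack([data[tuple(i)] for i in indices])   # GatherND reference

flat = data.reshape(-1, 8)                                # Reshape([-1] + shape[3:])
new_indices = np.take(indices, [1], axis=-1).reshape(-1)  # extract the varying column
gather = np.take(flat, new_indices, axis=0)               # Gather with axis=0

assert np.array_equal(gather_nd, gather)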
def find_and_replace_pattern(self, graph: Graph):
    for node in graph.get_op_nodes(squeeze_axis=True):
        name = node.soft_get('name', node.id)
        for out_port in node.out_ports().values():
            if node.has_valid('axis'):
                squeeze_node = create_op_with_const_inputs(graph, Squeeze, {1: np.array(node.axis)},
                                                           {'name': name + '/Squeeze_'})
                out_port.get_connection().insert_node(squeeze_node)
            elif node.is_in_port_connected(1):
                squeeze_node = Squeeze(graph, {'name': name + '/Squeeze_'}).create_node()
                out_port.get_connection().insert_node(squeeze_node)
                node.in_port(1).get_connection().add_destination(squeeze_node.in_port(1))
            else:
                raise Error('Unknown axis to squeeze for node {}'.format(name))
def replace_pattern(graph: Graph, match: Dict[str, Node]):
    node = match['op']
    name = node.soft_get('name', node.id)

    input_shape = node.in_port(0).data.get_shape()
    second_input_shape = node.in_port(1).data.get_shape()

    begin_mask = np.zeros(len(input_shape), dtype=np.int64)
    end_mask = np.zeros(len(input_shape), dtype=np.int64)

    for i in node.axes:
        end_mask[i] = np.int64(1)

    new_axis_mask = np.zeros(len(input_shape), dtype=np.int64)
    shrink_axis_mask = np.zeros(len(input_shape), dtype=np.int64)
    ellipsis_mask = np.zeros(len(input_shape), dtype=np.int64)

    ss = create_op_with_const_inputs(graph, StridedSlice,
                                     port_value_dict={1: np.zeros(len(input_shape), dtype=np.int64)},
                                     op_attrs={'name': 'StridedSlice', 'begin_mask': begin_mask,
                                               'end_mask': end_mask, 'new_axis_mask': new_axis_mask,
                                               'shrink_axis_mask': shrink_axis_mask,
                                               'ellipsis_mask': ellipsis_mask})
    if input_shape.size == second_input_shape.size:
        end = Shape(graph, dict(name=name + '/End')).create_node()
        end.in_port(0).connect(node.in_port(1).get_source())
        ss.in_port(2).connect(end.out_port(0))
    else:
        shape_like, rank_like = get_shape_and_rank_nodes_by_port(node.in_port(1).get_source())
        end_first_part = get_shape_values_by_range_idxs(shape_like, rank_like, 0, node.axes[-1],
                                                        include_end=True)
        if input_shape.size - 1 == node.axes[-1]:
            ss.in_port(2).connect(end_first_part.out_port(0))
        else:
            shape, rank = get_shape_and_rank_nodes_by_port(node.in_port(0).get_source())
            end_second_part = get_shape_values_by_range_idxs(shape, rank, node.axes[-1], -1,
                                                             include_begin=False, include_end=True)
            end = new_shape_node_from_shape_nodes([end_first_part, end_second_part])
            ss.in_port(2).connect(end.out_port(0))

    node.in_port(0).get_connection().set_destination(ss.in_port(0))
    node.in_port(1).disconnect()
    node.out_port(0).get_connection().set_source(ss.out_port(0))

    rename_nodes([(node, name + '/ShouldBeDeleted'), (ss, name)])
def transpose_nchw_to_nhwc(op_node: Node, port_info: str, input_port: int):
    graph = op_node.graph
    permutation_data_node = get_node_with_permutation(op_node, port_info)
    rank = len(permutation_data_node.shape)
    assert rank >= 4, 'Rank must be at least 4 for NCHW to NHWC permutation on node {}.'.format(op_node.id)

    # build the permutation by moving the last axis to position 1
    perm = list(range(rank))
    perm.insert(1, perm.pop())
    perm = int64_array(perm)

    transpose_name = op_node.soft_get('name', op_node.id) + '/Transpose'
    from mo.front.tf.graph_utils import create_op_with_const_inputs  # avoiding recursive imports
    transpose = create_op_with_const_inputs(graph, Transpose, {1: perm},
                                            {'name': transpose_name, 'override_output_shape': True})
    op_node.in_port(input_port).get_connection().insert_node(transpose)
    transpose.infer(transpose)
def replace_pattern(graph: Graph, match: dict):
    node = match['proposal']
    assert len(node.in_ports()) == 3, "Proposal op must have exactly 3 input ports"
    im_info_shape = node.in_port(2).data.get_shape()
    assert im_info_shape is not None

    if np.array_equal(im_info_shape, [1, 6]):
        log.error('The model contains Proposal layer "{}" with input of shape [1, 6]. Inference Engine '
                  'implementation of the Proposal layer uses only the first 4 values (indices 0, 1, 2 and 3). '
                  'Elements with indices 4 and 5 will be ignored.'.format(node.soft_get('name', node.id)),
                  extra={'is_warning': True})

        cropped_im_info = create_op_with_const_inputs(graph, StridedSlice, {
            1: np.array([0, 0], dtype=np.int32),
            2: np.array([1, 3], dtype=np.int32),
            3: np.array([1, 1], dtype=np.int32)
        }, {
            'name': 'cropped_im_info',
            'begin_mask': int64_array([1, 1]),
            'end_mask': int64_array([1, 1]),
            'new_axis_mask': int64_array([0, 0]),
            'shrink_axis_mask': int64_array([0, 0]),
            'ellipsis_mask': int64_array([0, 0]),
            'override_output_shape': True,
        })

        node.in_port(2).get_connection().insert_node(cropped_im_info)

        # update im_info_shape so that the next 'if' statement becomes true
        im_info_shape = int64_array([1, 3])

    if np.array_equal(im_info_shape, [1, 3]) or np.array_equal(im_info_shape, [1, 4]):
        reshape = create_op_node_with_second_input(graph, Reshape, [im_info_shape[1]],
                                                   {'name': 'im_info/Reshape'})
        node.in_port(2).get_connection().set_destination(reshape.in_port(0))
        reshape.out_port(0).connect(node.in_port(2))
def transpose(op_node: Node, port_info: str, input_port: int):
    graph = op_node.graph
    permutation_data_node = get_node_with_permutation(op_node, port_info)
    assert permutation_data_node.has_and_set('permutation'), \
        'Data node "{}" does not have permutation for node {}, port_info "{}".'.format(
            permutation_data_node.id, op_node.id, port_info)
    permutation = permutation_data_node.permutation
    if len(permutation.perm) == 0:
        return

    transpose_name = op_node.soft_get('name', op_node.id) + '/Transpose'
    from mo.front.tf.graph_utils import create_op_with_const_inputs  # avoiding recursive imports
    transpose = create_op_with_const_inputs(graph, Transpose, {1: permutation.perm},
                                            {'name': transpose_name, 'override_output_shape': True})
    op_node.in_port(input_port).get_connection().insert_node(transpose)
    transpose.infer(transpose)
def find_and_replace_pattern(self, graph: Graph):
    for attr_clamp in graph.get_op_nodes(op='AttributedClamp'):
        original_name = attr_clamp.soft_get('name', attr_clamp.id)
        rename_node(attr_clamp, original_name + '/TBR')
        min_value = attr_clamp.soft_get('min', np.finfo(np.float32).min)
        max_value = attr_clamp.soft_get('max', np.finfo(np.float32).max)

        new_clamp = create_op_with_const_inputs(graph, Clamp,
                                                {1: np.array(min_value, dtype=np.float32),
                                                 2: np.array(max_value, dtype=np.float32)},
                                                {'name': original_name})
        rename_node(new_clamp, original_name)

        attr_clamp.in_port(0).get_connection().set_destination(new_clamp.in_port(0))
        attr_clamp.out_port(0).get_connection().set_source(new_clamp.out_port(0))
        graph.remove_node(attr_clamp.id)
def replace_interpolate_pattern(graph: Graph, match: dict):
    split = match['split']
    scale = int64_array([get_split_scale(split)])
    axis = int(split.in_port(1).get_connection().get_source().node.value)
    split_node_name = split.name

    shape_node = Shape(graph, dict(name=split_node_name + '/Shape_')).create_node()
    scales_node = Const(graph, dict(name=split_node_name + '/scales_', value=scale)).create_node()
    mul_node = Mul(graph, dict(name=split_node_name + '/Mul_')).create_node()
    scales_node.out_port(0).connect(mul_node.in_port(1))

    strided_slice_node = create_op_with_const_inputs(graph, StridedSlice,
                                                     {1: int64_array([axis]), 2: int64_array([axis + 1])},
                                                     {'name': split_node_name + '/StridedSlice_',
                                                      'begin_mask': int64_array([1]),
                                                      'end_mask': int64_array([1]),
                                                      'new_axis_mask': int64_array([0]),
                                                      'shrink_axis_mask': int64_array([0]),
                                                      'ellipsis_mask': int64_array([0])})
    shape_node.out_port(0).connect(strided_slice_node.in_port(0))
    strided_slice_node.out_port(0).connect(mul_node.in_port(0))

    interp_node = Interpolate(graph, dict(name=split_node_name + '/Interpolate_',
                                          axes=int64_array([axis]), mode='nearest')).create_node()
    mul_node.out_port(0).connect(interp_node.in_port(1))

    match['concat'].out_port(0).get_connection().set_source(interp_node.out_port(0))

    split_connection = split.in_port(0).get_connection()
    split_connection.set_destination(interp_node.in_port(0))
    split_connection.get_source().connect(shape_node.in_port(0))
def find_and_replace_pattern(self, graph: Graph):
    for node in graph.get_op_nodes(op='ThresholdedRelu'):
        name = node.soft_get('name', node.id)

        greater = create_op_with_const_inputs(graph, Greater, {1: float_array([node.alpha])})
        greater.in_port(0).connect(node.in_port(0).get_source())
        float_greater = Cast(graph, {
            'dst_type': data_type_str_to_np(graph.graph['cmd_params'].data_type)
        }).create_node()
        greater.out_port(0).connect(float_greater.in_port(0))

        mul = Mul(graph, {}).create_node()
        node.out_port(0).get_connection().set_source(mul.out_port(0))
        mul.in_port(0).connect(node.in_port(0).get_source())
        mul.in_port(1).connect(float_greater.out_port(0))

        rename_nodes([(node, name + '/TBR'), (mul, name)])
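# A minimal NumPy sketch (illustration only) of the decomposition applied above:
# ThresholdedRelu(x) = x * float(x > alpha), i.e. Greater -> Cast -> Mul.
import numpy as np

x = np.array([-1.0, 0.5, 2.0], dtype=np.float32)
alpha = 1.0
reference = np.where(x > alpha, x, 0.0)          # ThresholdedRelu semantics
decomposed = x * (x > alpha).astype(np.float32)  # Greater -> Cast -> Mul
assert np.array_equal(reference, decomposed)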
def replace_pattern(graph: Graph, match: dict):
    node = match['normalize']

    # rename the Normalize node since it will no longer be the output node after the transformation
    output_name = node.soft_get('name', node.id)
    normalizel2_name = output_name + '/normalizel2'
    rename_node(node, normalizel2_name)

    assert node.in_port(0).data.get_shape().size in [2, 3, 4]
    assert node.has_valid('across_spatial')
    assert node.has_valid('channel_shared')
    assert node.has_valid('eps')

    if 'bin' in node.in_edge(1):
        del node.in_edge(1)['bin']

    weights = node.in_port(1).data.get_value()
    assert weights is not None
    # in the code below we intentionally use get_source() to get the out port. Because updating the out port will
    # update the Const node 'value' and 'shape' attributes
    if node.channel_shared or all(weights == weights[0]):
        node.in_port(1).get_source().data.set_value(np.array([weights[0]]))
    else:
        new_shape = np.ones((len(node.in_port(0).data.get_shape())), dtype=np.int64)
        new_shape[1] = -1
        node.in_port(1).get_source().data.set_value(np.array(weights).reshape(new_shape))

    mul = Mul(graph, {'name': output_name}).create_node()
    rename_node(mul, output_name)

    if not node.across_spatial:
        axes = int64_array([1])
    else:
        axes = int64_array(np.arange(start=1, stop=node.in_port(0).data.get_shape().size))

    normalizel2 = create_op_with_const_inputs(graph, NormalizeL2Op, {1: axes},
                                              {'eps_mode': 'add', 'eps': node.eps})

    node.out_port(0).get_connection().set_source(mul.out_port(0))
    node.in_port(1).get_connection().get_source().connect(mul.in_port(1))
    normalizel2.out_port(0).connect(mul.in_port(0))
    node.in_port(0).get_connection().set_destination(normalizel2.in_port(0))
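# A minimal NumPy sketch (illustration only, assuming NCHW input and across_spatial=False):
# Normalize(x, w) decomposes into NormalizeL2 over the channel axis with eps_mode='add',
# followed by a Mul with the weights reshaped to [1, -1, 1, 1]. Values are arbitrary.
import numpy as np

x = np.random.rand(1, 3, 2, 2).astype(np.float32)
w = np.array([0.5, 1.0, 2.0], dtype=np.float32).reshape(1, -1, 1, 1)
eps = 1e-10

l2 = x / np.sqrt(np.sum(x * x, axis=1, keepdims=True) + eps)  # NormalizeL2, eps 'add'
normalized = l2 * w                                           # trailing Mul with weights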
def insert_pre_processing(graph: Graph, input_node: Node, node_mean_scale_values: np.array,
                          preprocessing_name: str):
    assert preprocessing_name in ['scale', 'mean']
    if node_mean_scale_values.get(preprocessing_name) is None:
        return
    user_value = node_mean_scale_values[preprocessing_name]
    value = 1 / user_value if preprocessing_name == 'scale' else user_value * (-1)
    optimize_value = int(preprocessing_name == 'scale')
    op = Mul if preprocessing_name == 'scale' else Add
    if all([x == optimize_value for x in value]):
        # the pre-processing is an identity (scale of 1 or mean of 0), nothing to insert
        return

    assert input_node.has_valid('shape')
    features_dim_idx = get_features_dim(graph.graph['layout'], len(input_node.shape))
    assert compatible_dims(value.size, input_node.shape[features_dim_idx]) or value.size == 1

    shape = np.ones(len(input_node.shape), dtype=np.int64)
    shape[features_dim_idx] = value.size
    value = value.reshape(shape)

    name = input_node.soft_get('name', input_node.id) + '/' + preprocessing_name
    preprocessing = create_op_with_const_inputs(graph, op=op, port_value_dict={1: value},
                                                op_attrs={'name': name})

    for dst in input_node.out_port(0).get_destinations():
        if dst.node.soft_get('type') != 'ShapeOf':
            # After the insertion of additional operations, Model Optimizer should keep the link to the
            # input layer: a Parameter node in the framework should map to a Parameter node in the IR.
            # For this reason 'fw_tensor_debug_info' should be kept in the data node.
            dst.get_connection().set_source(preprocessing.out_port(0), "source")

    input_node.out_port(0).connect(preprocessing.in_port(0))
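# A minimal NumPy sketch (illustration only) of the inserted pre-processing: 'mean'
# becomes Add(-mean) and 'scale' becomes Mul(1/scale), so together the network consumes
# (x - mean) / scale. The per-channel values below are arbitrary examples.
import numpy as np

x = np.random.rand(1, 3, 4, 4).astype(np.float32)
mean = np.array([104., 117., 123.], dtype=np.float32).reshape(1, 3, 1, 1)
scale = np.array([58.8, 57.1, 57.4], dtype=np.float32).reshape(1, 3, 1, 1)

preprocessed = (x + mean * -1) * (1 / scale)   # Add, then Mul, as inserted above
assert np.allclose(preprocessed, (x - mean) / scale)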
def replace_pattern(self, graph: Graph, match: dict):
    node = match['cell']
    cell_name = node.soft_get('name', node.id)
    cell_type = node.soft_get('type')
    WR_input_id = node.soft_get('wr_input_id')
    hidden_size_coef = node.soft_get('gates_count')
    hidden_size = node.get_attrs()["hidden_size"]

    # default values for RNNCell/GRUCell
    additional_port_id = 4
    if cell_type == "LSTMCell":
        additional_port_id = 5

    WR_shape = node.in_port(WR_input_id).data.get_shape()
    assert WR_shape is not None, "Undefined 'WR' input shape for Cell node '{}'".format(cell_name)

    num_elements_in_WR = np.prod(WR_shape)
    input_size = (num_elements_in_WR / (hidden_size_coef * hidden_size)) - hidden_size

    # Reshape
    reshape = create_op_node_with_second_input(graph, Reshape,
                                               int64_array([hidden_size_coef * hidden_size,
                                                            hidden_size + input_size]),
                                               {'name': cell_name + '/Dims'})

    # VariadicSplit
    split = create_op_with_const_inputs(graph, VariadicSplit,
                                        {1: int64_array(1), 2: int64_array([input_size, hidden_size])},
                                        {'out_ports_count': 2, 'name': cell_name + '/Split'},
                                        reshape)

    # Cell
    node.in_port(WR_input_id).get_connection().set_destination(reshape.in_port(0))

    node.add_input_port(additional_port_id, skip_if_exist=True)
    assert node.in_port(additional_port_id).disconnected()

    # (x, y, WR, B) -> (x, y, W, R, B(additional_port))
    node.in_port(additional_port_id - 1).get_connection().set_destination(node.in_port(additional_port_id))

    split.out_port(0).connect(node.in_port(additional_port_id - 2))
    split.out_port(1).connect(node.in_port(additional_port_id - 1))
def make_interpolate_reshapeable(interpolate):
    assert interpolate.soft_get('type') == 'Interpolate'

    axes = Interpolate.get_axes(interpolate)
    input_shape = interpolate.in_port(0).data.get_shape()
    output_shape = interpolate.out_port(0).data.get_shape()

    if not np.all(np.remainder(output_shape, input_shape) == 0) and \
            not np.all(np.remainder(input_shape, output_shape) == 0):
        return

    graph = interpolate.graph
    name = interpolate.soft_get('name', interpolate.id)

    shape = Shape(graph, {'name': name + '/ShapeOf'}).create_node()
    shape.in_port(0).connect(interpolate.in_port(0).get_source())
    gather = create_op_with_const_inputs(graph, Gather,
                                         {1: np.array(axes, dtype=np.int32), 2: int64_array(0)},
                                         {'name': shape.name + '/Gathered'}, shape)
    multipliers = output_shape[axes] / input_shape[axes]
    mul = create_op_node_with_second_input(graph, Mul, multipliers,
                                           {'name': gather.name + '/Multiplied'}, gather)

    interpolate.in_port(1).get_connection().set_source(mul.out_port(0))
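# A minimal NumPy sketch (illustration only): instead of keeping a constant target size,
# the pass recomputes it as ShapeOf(input)[axes] * (output_shape[axes] / input_shape[axes]),
# which stays valid when the input spatial size changes at runtime.
import numpy as np

input_shape = np.array([1, 3, 10, 10])
output_shape = np.array([1, 3, 20, 20])
axes = np.array([2, 3])

multipliers = output_shape[axes] / input_shape[axes]   # constant, here [2., 2.]
sizes = input_shape[axes] * multipliers                # ShapeOf -> Gather -> Mul
assert np.array_equal(sizes, output_shape[axes])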
def insert_transpose(graph: Graph, input_port: Port, before_input=True):
    input_rank = len(input_port.data.get_shape())
    if input_rank > 3:
        if before_input:
            axis_order = np.concatenate((int64_array([0]),
                                         int64_array(list(range(2, input_rank))),
                                         int64_array([1])))
            source_node = input_port.get_source().node
            transpose_name = source_node.soft_get('name', source_node.id) + '/TransposeToNHWC'
        else:
            axis_order = np.concatenate((int64_array([0]),
                                         int64_array([input_rank - 1]),
                                         int64_array(list(range(1, input_rank - 1)))))
            transpose_name = input_port.node.soft_get('name', input_port.node.id) + '/TransposeToNCHW'
            input_port.node['need_shape_inference'] = True
            input_port.node['override_output_shape'] = True
        transpose = create_op_with_const_inputs(graph, Transpose, {1: axis_order}, {'name': transpose_name})
        input_port.get_connection().insert_node(transpose)
        transpose['need_shape_inference'] = True
        transpose['override_output_shape'] = True
def replace_op(self, graph: Graph, node: Node):
    name = node.soft_get('name', node.id)

    # create the range of axes [2, rank) for MVN based on the rank of the input
    rank = Rank(graph, {'name': name + '/Rank'}).create_node()
    rng = create_op_with_const_inputs(graph, Range, {0: int64_array(2), 2: int64_array(1)},
                                      {'name': name + '/Range', 'output_type': np.int64})
    mvn = MVN(graph, {'eps': node.epsilon, 'eps_mode': 'inside_sqrt', 'normalize_variance': 1,
                      'name': name + '/Ins_Norm/MVN_'}).create_node()
    node.in_port(0).get_connection().set_destination(mvn.in_port(0))
    rng.out_port(0).connect(mvn.in_port(1))
    mul = Mul(graph, {'axis': 1, 'name': name + '/Ins_Norm/mul_'}).create_node()
    mvn.out_port(0).connect(mul.in_port(0))
    node.in_port(1).get_connection().set_destination(mul.in_port(1))
    add = Add(graph, {'axis': 1, 'name': name + '/Ins_Norm/add_'}).create_node()
    mul.out_port(0).connect(add.in_port(0))
    node.in_port(2).get_connection().set_destination(add.in_port(1))

    mvn.in_port(0).get_connection().add_destination(rank.in_port(0))
    rng.in_port(1).connect(rank.out_port(0))

    rename_nodes([(node, name + '/TBD'), (add, name)])

    return [add.id]
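# A minimal NumPy sketch (illustration only) of the decomposition: InstanceNormalization
# is MVN over the spatial axes [2, ..., rank-1] followed by a per-channel Mul (scale)
# and Add (bias). Shapes and values below are arbitrary examples.
import numpy as np

x = np.random.rand(2, 3, 4, 5).astype(np.float32)
gamma = np.random.rand(3).astype(np.float32)
beta = np.random.rand(3).astype(np.float32)
eps = 1e-5

axes = tuple(range(2, x.ndim))                         # Range(2, Rank(x), 1)
mean = x.mean(axis=axes, keepdims=True)
var = x.var(axis=axes, keepdims=True)
mvn = (x - mean) / np.sqrt(var + eps)                  # MVN, eps_mode='inside_sqrt'
out = mvn * gamma.reshape(1, -1, 1, 1) + beta.reshape(1, -1, 1, 1)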
def replace_sub_graph(self, graph: Graph, match: dict):
    ctc_greedy_decoder_tf = match['decoder']
    cast = match['cast']
    sparse_to_dense = match['sparse_to_dense']
    sparse_to_dense_name = sparse_to_dense.soft_get('name', sparse_to_dense.id)
    ctc_greedy_decoder_tf_name = ctc_greedy_decoder_tf.soft_get('name', ctc_greedy_decoder_tf.id)

    # transpose the input data from [T, N, C] to [N, T, C], the layout supported by the
    # CTCGreedyDecoderSeqLen op
    ctc_data_permute = create_op_with_const_inputs(graph, Transpose, {1: int64_array([1, 0, 2])},
                                                   {'name': ctc_greedy_decoder_tf_name + '/ctc_data_permute'})

    assert ctc_greedy_decoder_tf.has_valid('merge_repeated'), \
        'The CTCGreedyDecoderSeqLen node "{}" misses "merge_repeated" attribute'.format(
            ctc_greedy_decoder_tf_name)

    ctc_greedy_decoder_tf.in_port(0).get_source().connect(ctc_data_permute.in_port(0))
    merge_repeated_tf = ctc_greedy_decoder_tf.merge_repeated
    ctc_greedy_decoder = CTCGreedyDecoderSeqLenOp(graph, {'name': sparse_to_dense_name,
                                                          'merge_repeated': merge_repeated_tf}).create_node()
    rename_nodes([(sparse_to_dense, sparse_to_dense_name + '/AbandonedName'),
                  (ctc_greedy_decoder, sparse_to_dense_name)])
    ctc_greedy_decoder.in_port(0).connect(ctc_data_permute.out_port(0))
    ctc_greedy_decoder_tf.in_port(1).get_source().connect(ctc_greedy_decoder.in_port(1))

    # set output of the new sub-graph as a source for SparseToDense consumer
    sparse_to_dense.out_port(0).get_connection().set_source(ctc_greedy_decoder.out_port(0))

    # remove no longer needed nodes
    graph.remove_nodes_from([sparse_to_dense.id, cast.id, ctc_greedy_decoder_tf.id])
def find_and_replace_pattern(self, graph: Graph):
    for node in graph.get_op_nodes(op='AttributedSplit'):
        name = node.soft_get('name', node.id)

        axis = node.soft_get('axis', None)
        assert axis is not None, \
            "AttributedSplit should have `axis` parameter set, but it's not for node {}".format(name)

        num_splits = node.soft_get('num_splits', None)
        assert num_splits is not None, \
            "AttributedSplit should have `num_splits` parameter set, but it's not for node {}".format(name)

        split = create_op_with_const_inputs(graph, Split, {1: np.int64(axis)},
                                            {'name': name + '/Split', 'num_splits': num_splits})

        for idx, port in node.out_ports().items():
            port.get_connection().set_source(split.out_port(idx))
        node.in_port(0).get_connection().set_destination(split.in_port(0))
        graph.remove_node(node.id)
def find_and_replace_pattern(self, graph: Graph):
    for fake_output in graph.get_op_nodes(op='FakeOutput'):
        name = fake_output.soft_get('name', fake_output.id)

        producer = fake_output.in_port(0).get_source().node
        producer_outputs = 0
        for port in producer.out_ports().values():
            if not port.disconnected():
                producer_outputs += 1

        if producer_outputs != 1:
            # At this stage we don't know the type of output, so we rely on MO transformation which updates the
            # Const type for elementwise operations in case of input data types mismatch
            add = create_op_with_const_inputs(graph, Add, {1: int64_array(0)}, {'can_be_fused': False})
            rename_nodes([(fake_output, name + '/TBD'), (add, name)])

            fake_output.in_port(0).get_connection().set_destination(add.in_port(0))
            fake_output.out_port(0).get_connection().set_source(add.out_port(0))
        else:
            result_in_port = fake_output.out_port(0).get_destination()
            result_in_port.disconnect()
            fake_output.in_port(0).get_connection().set_destination(result_in_port)
            rename_nodes([(fake_output, name + '/TBD'), (producer, name)])
def replace_op(self, graph: Graph, node: Node):
    node_name = node.soft_get('name', node.id)
    rename_node(node, node_name + '/TBR')

    sqr_node = Mul(graph, {}).create_node()
    reduce_sum_node = ReduceSum(graph, {'keep_dims': node.soft_get('keep_dims', 0),
                                        'axis': node.soft_get('axis', None)}).create_node()
    sqrt_node = create_op_with_const_inputs(graph, Pow, {1: float_array(0.5)})
    rename_node(sqrt_node, node_name)

    # Connect nodes
    node.in_port(0).get_connection().set_destination(sqr_node.in_port(0))
    sqr_node.in_port(0).get_connection().add_destination(sqr_node.in_port(1))
    sqr_node.out_port(0).connect(reduce_sum_node.in_port(0))
    reduce_sum_node.out_port(0).connect(sqrt_node.in_port(0))

    return [sqrt_node.id]
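# A minimal NumPy sketch (illustration only) of the decomposition applied above:
# ReduceL2(x, axis) == Pow(ReduceSum(Mul(x, x), axis), 0.5).
import numpy as np

x = np.random.rand(2, 3).astype(np.float32)
axis = 1
reference = np.sqrt(np.sum(x * x, axis=axis))          # L2 norm along the axis
decomposed = np.power(np.sum(x * x, axis=axis), 0.5)   # Mul -> ReduceSum -> Pow(0.5)
assert np.allclose(reference, decomposed)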
def replace_sub_graph(graph: Graph, match: dict, **kwargs):
    random_uniform_node = match['random_uniform']
    random_uniform_node_name = random_uniform_node.soft_get('name', random_uniform_node.id)
    log.error("Possible dropout block with RandomUniform is detected. "
              "Replacing {} with a Broadcast with constant value of 0.5, "
              "assuming that it is executed in inference mode.".format(random_uniform_node_name),
              extra={'is_warning': True})
    data_type = match['add_const'].data_type

    broadcast_node = create_op_with_const_inputs(graph, Broadcast, {0: np.array([0.5], dtype=data_type)},
                                                 {'mode': 'numpy',
                                                  'name': random_uniform_node_name + '/Broadcast'})
    rename_nodes([(random_uniform_node, random_uniform_node_name + '/ToBeRemoved'),
                  (broadcast_node, random_uniform_node_name)])
    random_uniform_node.in_port(0).get_connection().set_destination(broadcast_node.in_port(1))
    random_uniform_node.out_port(0).get_connection().set_source(broadcast_node.out_port(0))
def convert_ifft_to_dft(self, graph: Graph, mx_fft: Node):
    mx_fft_name = mx_fft.soft_get('name', mx_fft.id)
    rank_node = Rank(graph, {'name': mx_fft_name + '/rank'}).create_node()
    sub_node = create_op_with_const_inputs(graph, Sub, {1: int64_array(1)}, {'name': mx_fft_name + '/Sub'})
    rank_node.out_port(0).connect(sub_node.in_port(0))
    broadcast_node0 = create_op_with_const_inputs(graph, Broadcast, {0: int64_array(0)},
                                                  {'name': mx_fft_name + '/broadcast'})
    sub_node.out_port(0).connect(broadcast_node0.in_port(1))
    concat_node = create_op_with_const_inputs(graph, Concat, {1: int64_array([-1, 2])},
                                              {'name': mx_fft_name + '/new_shape', 'in_ports_count': 2,
                                               'axis': 0}, broadcast_node0)

    reshape_node = Reshape(graph, {'name': mx_fft_name + '/reshape'}).create_node()
    concat_node.out_port(0).connect(reshape_node.in_port(1))

    mx_fft_connection = mx_fft.in_port(0).get_connection()
    mx_fft_connection.set_destination(reshape_node.in_port(0))
    mx_fft_connection.get_source().connect(rank_node.in_port(0))

    dft_node = create_op_with_const_inputs(graph, IDFT, {1: int64_array([-1])},
                                           {'name': mx_fft_name + '/idft', 'in_ports_count': 2},
                                           reshape_node)

    split_node = create_op_with_const_inputs(graph, Split, {1: int64_array(-1)},
                                             {'name': mx_fft_name + '/split', 'num_splits': 2},
                                             dft_node)
    squeeze_node = create_op_with_const_inputs(graph, Squeeze, {1: int64_array([-1])}, {}, split_node)

    mx_fft.out_port(0).get_connection().set_source(squeeze_node.out_port(0))
    rename_nodes([(mx_fft, mx_fft_name + '/to_be_removed'), (squeeze_node, mx_fft_name)])
def transform_keras_rnn_input_slicing(external_match: dict, internal_match: dict):
    """
    Transforms TensorFlow 2 input slicing into use of the axis attribute for the Loop node input port
    :param external_match: a match used for handling a part of the main graph responsible for input slicing
    :param internal_match: a match used for handling a part of the body graph responsible for input slicing
    """
    loop_node = external_match['while']
    unstack_node = external_match['unstack']
    body_graph = loop_node['body']

    tensor_list_get_item_node = internal_match['slicing']
    unstack_placeholder = internal_match['tensor_list']
    tensor_list_get_item_node_name = tensor_list_get_item_node.soft_get('name', tensor_list_get_item_node.id)

    # 1. process the body graph to avoid unsupported operations: TensorListGetItem and TensorListSetItem
    # replace TensorListGetItem with Squeeze node and iterate through slices using axis for input port
    squeeze_list_element = create_op_with_const_inputs(body_graph, Squeeze, {1: int64_array(0)},
                                                       {'name': 'TensorListGetItemSqueeze'})
    tensor_list_get_item_node.in_port(0).get_connection().set_destination(squeeze_list_element.in_port(0))
    tensor_list_get_item_node.out_port(0).get_connection().set_source(squeeze_list_element.out_port(0))
    rename_nodes([(tensor_list_get_item_node, tensor_list_get_item_node_name + '/AbandonedName'),
                  (squeeze_list_element, tensor_list_get_item_node_name)])
    unstack_placeholder_layer_id = unstack_placeholder.internal_layer_id
    Loop.update_port_map_value_ext(loop_node.input_port_map, 'internal_layer_id',
                                   unstack_placeholder_layer_id, 'axis', 0)

    # 2. process the main graph around the Loop node to avoid unsupported operations:
    # TensorListFromTensor, TensorListReserve, and TensorListStack
    # remove TensorListFromTensor and pass a tensor to Loop as is
    unstack_node.out_port(0).get_connection().set_source(
        unstack_node.in_port(0).get_connection().get_source())
def replace_pattern(self, graph: Graph, match: dict):
    node = match['op']
    node_name = node.soft_get('name', node.id)
    node.is_training = False

    shape = node.in_port(1).data.get_shape()
    assert shape is not None, 'The shape of scale input of the BatchNorm node {} is not defined'.format(node_name)

    bn_mean = Const(graph, {'name': node_name + '/mean', 'value': np.zeros(shape, dtype=np.float32),
                            'override_output_shape': True}).create_node()
    bn_std = Const(graph, {'name': node_name + '/std', 'value': np.ones(shape, dtype=np.float32),
                           'override_output_shape': True}).create_node()
    node.in_port(3).get_connection().set_source(bn_mean.out_port(0))
    node.in_port(4).get_connection().set_source(bn_std.out_port(0))

    # save the original shape
    original_shape = Shape(graph, {'name': node.in_port(0).get_source().node.soft_get('name')}).create_node()
    original_shape.in_port(0).connect(node.in_port(0).get_source())

    input_rank = len(node.in_port(0).data.get_shape())
    rng = create_op_with_const_inputs(graph, Range,
                                      {0: int64_array(2), 1: int64_array(input_rank), 2: int64_array(1)},
                                      {'name': node_name + '/Range', 'output_type': np.int64})
    mvn = MVN(graph, {'name': node_name + '/mvn_', 'eps': node.soft_get('eps', 1e-6),
                      'eps_mode': 'outside_sqrt', 'normalize_variance': 1,
                      'override_output_shape': True}).create_node()
    node.in_port(0).get_connection().insert_node(mvn)
    mvn.in_port(1).connect(rng.out_port(0))

    reshape_4d = create_op_node_with_second_input(graph, Reshape, int64_array([1, -1, 0, 0]),
                                                  {'override_output_shape': True,
                                                   'name': node_name + '/fused_batch_and_channels'})
    mvn.in_port(0).get_connection().insert_node(reshape_4d)

    # restore the original shape
    reshape_back = Reshape(graph, {'name': node_name + '/restore_shape',
                                   'override_output_shape': True}).create_node()
    reshape_back.in_port(1).connect(original_shape.out_port(0))
    mvn.out_port(0).get_connection().insert_node(reshape_back)
def make_interpolate_reshapeable(interpolate, concat):
    assert interpolate.soft_get('type') == 'Interpolate'
    assert concat.soft_get('type') == 'Concat'

    output_shape = interpolate.out_port(0).data.get_shape()

    interp_axes = [get_canonical_axis_index(output_shape, axis) for axis in Interpolate.get_axes(interpolate)]
    concat_axis = get_canonical_axis_index(output_shape, concat.axis)
    if concat_axis in interp_axes:
        return

    concat_srcs = [port.get_source() for port in concat.in_ports().values() if not port.disconnected()]
    non_interp_concat_srcs = [src for src in concat_srcs if src.node.soft_get('type') != 'Interpolate']
    if len(non_interp_concat_srcs) == 0:
        return

    graph = interpolate.graph
    src = non_interp_concat_srcs[0]

    shape = Shape(graph, {'name': src.node.soft_get('name', src.node.id) + '/Shape'}).create_node()
    shape.in_port(0).connect(src)
    gather = create_op_with_const_inputs(graph, Gather,
                                         {1: np.array(interp_axes, dtype=np.int32), 2: int64_array(0)},
                                         {'name': shape.name + '/Gathered'}, shape)
    interpolate.in_port(1).get_connection().set_source(gather.out_port(0))
def replace_op(self, graph: Graph, node: Node):
    ss_node = create_op_with_const_inputs(graph, Split, {1: int64_array(1)},
                                          {'name': 'Split_eltwise_' + node.name,
                                           'num_splits': node['num_inputs']})

    inp = node.get_inputs()
    in_node = inp[0][0]
    edge_attrs = inp[0][1]
    graph.add_edge(in_node, ss_node.id, **edge_attrs)
    if ss_node.num_splits == 2:
        if node['operation'] == 'mul':
            eltwise_node = Mul(graph, attrs={'name': 'Eltwise_' + node.name}).create_node()
        elif node['operation'] == 'sum':
            eltwise_node = Add(graph, attrs={'name': 'Eltwise_' + node.name}).create_node()
        else:
            raise Error('Error on replacing Kaldi eltwise: unknown type ' + node['operation'])
    elif ss_node.num_splits > 2:
        eltwise_node = EltwiseN(graph, attrs={'name': 'Eltwise_' + node.name,
                                              'operation': node['operation']}).create_node()
    else:
        raise Error('Error on replacing Kaldi eltwise')
    for i in range(ss_node.num_splits):
        ss_node.out_port(i).get_connection().set_destination(eltwise_node.in_port(i))
    return [eltwise_node.id]
def replace_op(self, graph: Graph, node: Node):
    # save the original node name to use it in the new Pad op instance
    original_name = node.soft_get('name', node.id)
    rename_node(node, original_name + '/TBR')

    new_pad = Pad(graph, {'mode': node.soft_get('mode', None)}).create_node()
    rename_node(new_pad, original_name)

    node.in_port(0).get_connection().set_destination(new_pad.in_port(0))

    if node.soft_get('mode') == 'constant':
        # the input with the fill value is an optional third input in ONNX
        if not node.in_port(2).disconnected():
            node.in_port(2).get_connection().set_destination(new_pad.in_port(3))
        else:
            new_pad.in_port(3).connect(Const(graph, {'value': 0.0}).create_node().out_port(0))

    # convert ONNX representation of the pads as [2 * N] to MO representation: [N] and [N]
    split_pads = create_op_with_const_inputs(graph, Split, {1: int64_array(0)}, {'num_splits': 2})
    node.in_port(1).get_connection().set_destination(split_pads.in_port(0))
    split_pads.out_port(0).connect(new_pad.in_port(1))
    split_pads.out_port(1).connect(new_pad.in_port(2))

    return [new_pad.id]
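# A minimal NumPy sketch (illustration only) of the pads conversion: ONNX stores pads as
# [x1_begin, x2_begin, ..., x1_end, x2_end, ...], so splitting the vector in two along
# axis 0 yields the separate pads_begin/pads_end inputs expected by the Pad op.
import numpy as np

onnx_pads = np.array([1, 2, 0, 3], dtype=np.int64)     # 2-D example: begins, then ends
pads_begin, pads_end = np.split(onnx_pads, 2, axis=0)  # Split with num_splits=2, axis=0
assert np.array_equal(pads_begin, [1, 2]) and np.array_equal(pads_end, [0, 3])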
def sub_to_add_replacement(sub: Node):
    # we execute this transformation for V10 IR later on middle phase despite graph_condition
    # so we prevent Sub replacement on shape-calculating sub-graphs
    if sub.in_port(0).data.get_value() is not None and sub.in_port(1).data.get_value() is not None:
        return

    graph = sub.graph
    name = sub.soft_get('name', sub.id)

    # keep the Add name the same as Sub -- because of mathematical equality of output tensors
    rename_node(node=sub, name=name + '/to_be_removed')

    # reconnect Sub inputs/outputs to Add
    add = Add(graph, {'name': name}).create_node()
    rename_node(add, name)
    sub.in_port(0).get_connection().set_destination(add.in_port(0))
    sub.in_port(1).get_connection().set_destination(add.in_port(1))
    sub.out_port(0).get_connection().set_source(add.out_port(0))

    # restore mathematical equivalence to the Sub operation: Sub(A, B) = Add(A, Mul(B, -1))
    const_dtype = sub.soft_get('data_type', np.float32)
    negate = create_op_with_const_inputs(graph, Mul, {1: np.array(-1, dtype=const_dtype)},
                                         {'name': name + '/neg_'})
    add.in_port(1).get_connection().insert_node(negate)
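# A minimal NumPy check (illustration only) of the rewrite's mathematical equivalence:
# Sub(a, b) == Add(a, Mul(b, -1)).
import numpy as np

a = np.array([3.0, 5.0], dtype=np.float32)
b = np.array([1.0, 7.0], dtype=np.float32)
assert np.array_equal(a - b, a + b * np.float32(-1))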
def replace_sub_graph(self, graph: Graph, match: dict):
    inp = match['pool0']
    inp_port = inp.in_port(0).get_source()

    # check the values of the add, pow and pooling parameters of the matched pattern
    pow_param = match['pow_param']
    add_param = match['add_param']
    if add_param.value.size == 1 and pow_param.value.size == 1 and add_param.value.item() <= 1e-05 \
            and pow_param.value.item() == 0.5 and match['pool0_param'].value == match['pool1_param'].value:
        log.debug('Found LayerNorm pattern after {} with name {}'.format(inp_port.node.op, inp_port.node.name))
        mvn = create_op_with_const_inputs(graph, MVN, {1: match['pool1_param'].value},
                                          {'eps': add_param.value.item(), 'normalize_variance': 1,
                                           'eps_mode': 'inside_sqrt'})
        div_name = match['div'].soft_get('name', match['div'].id)
        rename_nodes([(match['div'], div_name + '/to_be_removed'), (mvn, div_name)])

        inp_port.connect(mvn.in_port(0))
        match['div'].out_port(0).get_connection().set_source(mvn.out_port(0))
def make_interpolate_reshape_able(self, interpolate: Node, concat: Node):
    assert interpolate.soft_get('type') == 'Interpolate'
    assert concat.soft_get('type') == 'Concat'
    interp_axes = interpolate.soft_get('axes', None)
    interp_axes = interp_axes if interp_axes is None else int64_array(interp_axes)
    concat_axis = self.get_concat_axis(concat)

    if concat_axis is None or interp_axes is None \
            or np.any(interp_axes < 0) or concat_axis < 0 \
            or concat_axis in interp_axes:
        # checks that interpolate axes and concat axis are valid and do not intersect
        return

    non_interp_concat_srcs = self.get_non_interpolate_concat_sources(concat)
    if not len(non_interp_concat_srcs):
        # there is no Concat input to take the shape from
        return

    graph = interpolate.graph
    src = non_interp_concat_srcs[0]

    shape = Shape(graph, {'name': src.node.soft_get('name', src.node.id) + '/Shape'}).create_node()
    shape.in_port(0).connect(src)
    gather = create_op_with_const_inputs(graph, Gather,
                                         {1: np.array(interpolate.axes, dtype=np.int32), 2: int64_array(0)},
                                         {'name': shape.name + '/Gathered'}, input_node=shape)
    interpolate.in_port(1).get_connection().set_source(gather.out_port(0))
def replace_pattern(self, graph: Graph, match: dict):
    node = match['node']
    node_name = node.soft_get('name', node.id)

    if 2 in node.in_ports() and not node.in_port(2).disconnected():
        # Third input represents output shape. Cutting its value according to scheme:
        # [N, C, spatial_dim_0, ..., spatial_dim_n] -> [spatial_dim_0, ..., spatial_dim_n]
        in_rank = node.in_port(0).data.get_shape().size

        shape_src = node.in_port(2).get_source()
        node.in_port(2).disconnect()

        ss_0 = create_op_with_const_inputs(graph, StridedSlice,
                                           {1: np.array([2], dtype=np.int32),
                                            2: np.array([in_rank], dtype=np.int32),
                                            3: np.array([1], dtype=np.int32)},
                                           {'name': node_name + '/ss_0_port',
                                            'begin_mask': np.array([1], dtype=np.int32),
                                            'end_mask': np.array([0], dtype=np.int32),
                                            'new_axis_mask': np.array([0], dtype=np.int32),
                                            'shrink_axis_mask': np.array([0], dtype=np.int32),
                                            'ellipsis_mask': np.array([0], dtype=np.int32)})

        shape_src.connect(ss_0.in_port(0))
        ss_0.out_port(0).connect(node.in_port(2))

        # Specification: *padding amount* is deduced from relation of input and output spatial shapes
        del node['pad']

    elif node.has_valid('original_output_spatial_shape'):
        # the node had a fixed output spatial shape set in the original framework, so we restore it here
        const = Const(graph, {'value': int64_array(node.original_output_spatial_shape),
                              'name': node_name + '/original_spatial_shape'}).create_node()
        node.add_input_port(2, skip_if_exist=True)
        const.out_port(0).connect(node.in_port(2))

        # Specification: *padding amount* is deduced from relation of input and output spatial shapes
        del node['pad']

    group = node.soft_get('group', 1)

    if group != 1:
        assert group > 1

        weights_shape = node.in_port(1).data.get_shape()
        assert weights_shape is not None
        I = node.in_port(0).data.get_shape()[1]
        assert I % group == 0
        assert node.output % group == 0

        new_shape = int64_array([group, I / group, node.output / group, *weights_shape[2:]])

        assert np.prod(weights_shape) == np.prod(new_shape), \
            'Initial weights shape {}, grouped weights shape {}'.format(weights_shape, new_shape)
        reshape = create_op_node_with_second_input(graph, Reshape, int64_array(new_shape),
                                                   {'override_output_shape': True},
                                                   node.in_port(1).get_source().node)
        node.in_port(1).get_connection().set_source(reshape.out_port(0))

        node['type'] = 'GroupConvolutionBackpropData'
    else:
        node['type'] = 'ConvolutionBackpropData'
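# A minimal NumPy sketch (illustration only, shapes are assumptions) of the grouped-weights
# reshape above: deconvolution weights [I, O/group, *spatial] are regrouped into
# [group, I/group, O/group, *spatial] without changing the element count.
import numpy as np

group, I, output = 2, 4, 6
weights = np.zeros((I, output // group, 3, 3))         # [I, O/group, kH, kW]
new_shape = (group, I // group, output // group, 3, 3)
assert np.prod(weights.shape) == np.prod(new_shape)
regrouped = weights.reshape(new_shape)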