def find_and_replace_pattern(self, graph: Graph):
    """Replace every `MVNCaffe` node with an MVN node taking an explicit axes input.

    The normalized axes are built in-graph as Range(start_axis, Rank(data), 1):
    start_axis is 1 when `across_channels` == 1 and 2 otherwise. `eps` and
    `normalize_variance` are copied from the original node (defaults 1e-9 and 1).
    The new MVN node takes over the original node's name; the original is removed.

    :param graph: graph to transform in place
    """
    for node in graph.get_op_nodes(op='MVNCaffe'):
        node_name = node.soft_get('name', node.id)

        start_axis = 2
        if node['across_channels'] == 1:
            start_axis = 1

        rank = Rank(graph, {'name': node_name + '/Rank'}).create_node()

        # create range of axes based on `start_axis` and rank of input
        rng = create_op_with_const_inputs(graph, Range, {0: int64_array(start_axis), 2: int64_array(1)},
                                          {'name': node_name + '/Range', 'output_type': np.int64})
        rng.in_port(1).connect(rank.out_port(0))

        new_mvn = MVN(graph, {'eps': node.soft_get('eps', 1e-9),
                              'eps_mode': 'inside_sqrt',
                              'normalize_variance': node.soft_get('normalize_variance', 1)}).create_node()
        # Move the existing data connection rather than passing the producer *node* to create_node():
        # create_node([producer, ...]) wires the producer's output port 0, which would pick the wrong
        # tensor when MVNCaffe is fed from another output port of a multi-output node.
        node.in_port(0).get_connection().set_destination(new_mvn.in_port(0))
        rng.out_port(0).connect(new_mvn.in_port(1))
        # The rank is taken from the same tensor MVN normalizes.
        new_mvn.in_port(0).get_connection().add_destination(rank.in_port(0))

        node.out_port(0).get_connection().set_source(new_mvn.out_port(0))
        rename_nodes([(node, node_name + '/tbd'), (new_mvn, node_name)])
        graph.remove_node(node.id)
def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]):
    """Give a Reduce node that has exactly one connected input an explicit axes input on port 1.

    If the node has a valid `axis` attribute, that value is moved into a Const on port 1 and
    the attribute is deleted. Otherwise, the axes become Range(0, Rank(data), 1), i.e. the
    reduction runs over every dimension of the data input.

    :param graph: graph being transformed
    :param match: matched sub-graph; the Reduce node is stored under 'reduce'
    """
    reduce_node = match['reduce']
    n_connected = sum(1 for port in reduce_node.in_ports().values() if not port.disconnected())
    if n_connected != 1:
        # Only nodes with exactly one connected input are rewritten.
        return

    reduce_name = reduce_node.soft_get('name', reduce_node.id)

    if not reduce_node.has_valid('axis'):
        # No usable 'axis' attribute: reduce over all dimensions -> axes = Range(0, Rank(data), 1).
        # NOTE(review): an older comment said a None 'axis' still yields a Const input holding None;
        # here has_valid() sends that case to this Range branch — confirm against has_valid().
        axes_range = create_op_with_const_inputs(graph, Range, {0: int64_array(0), 2: int64_array(1)},
                                                 dict(name=reduce_name + '/axes'))
        data_rank = Rank(graph, dict(name=reduce_name + '/range_end')).create_node()
        reduce_node.in_port(0).get_connection().get_source().connect(data_rank.in_port(0))
        data_rank.out_port(0).connect(axes_range.in_port(1))
        reduce_node.add_input_port(1, skip_if_exist=True)
        axes_range.out_port(0).connect(reduce_node.in_port(1))
        return

    # Move the static 'axis' attribute into a Const feeding port 1.
    axis_const = Const(graph, {'name': reduce_name + '/axis', 'value': reduce_node.axis}).create_node()
    reduce_node.add_input_port(1, skip_if_exist=True)
    axis_const.out_port(0).connect(reduce_node.in_port(1))
    del graph.node[reduce_node.id]['axis']
def find_and_replace_pattern(self, graph: Graph):
    """Replace ATen `embedding_bag` nodes (only mode 0, "sum") with EmbeddingBag*Sum operations.

    * Port 2 absent or disconnected -> EmbeddingBagPackedSum (no offsets input).
    * Otherwise -> EmbeddingBagOffsetsSum, with the ATen port 2 moved to its port 2.
    Ports 0 and 1 and output 0 are moved unchanged.

    When ATen port 3 (per-sample weights) is connected:
    * packed case: it goes to EmbeddingBagPackedSum port 2;
    * offsets case: it goes to port 4. The embedding table (port 0) is then extended with one
      row of zeros via Concat on axis 0, and port 3 (default index) is fed with the original
      first dimension of the table, i.e. the index of that appended zero row.
      NOTE(review): assumes the embedding table is 2-D (rows x embedding size) — confirm.

    :param graph: graph to transform in place
    """
    for node in graph.get_op_nodes(op='ATen', operator='embedding_bag'):
        assert node.soft_get('mode') == 0, 'ATen::embedding_bag has unsupported mode, only "sum" ' \
                                           'mode is supported for node {}.'.format(node.id)
        node_name = node.soft_get('name', node.id)
        # Free the original name so the replacement node can take it.
        rename_node(node, node_name + '/TBR')
        is_packed = False
        if len(node.in_ports()) < 3 or node.in_port(2).disconnected():
            is_packed = True
            embedding_bag = EmbeddingBagPackedSum(graph, {'name': node_name}).create_node()
        else:
            embedding_bag = EmbeddingBagOffsetsSum(graph, {'name': node_name}).create_node()
            node.in_port(2).get_connection().set_destination(embedding_bag.in_port(2))
        rename_node(embedding_bag, node_name)
        node.in_port(0).get_connection().set_destination(embedding_bag.in_port(0))
        node.in_port(1).get_connection().set_destination(embedding_bag.in_port(1))
        node.out_port(0).get_connection().set_source(embedding_bag.out_port(0))
        if len(node.in_ports()) == 4 and not node.in_port(3).disconnected():
            if is_packed:
                node.in_port(3).get_connection().set_destination(embedding_bag.in_port(2))
            else:
                # connect per_sample_weights
                node.in_port(3).get_connection().set_destination(embedding_bag.in_port(4))

                # Shape and rank of the (original) embedding table; both get their input below,
                # once the table connection has been moved onto the Concat.
                weights_shape_node = Shape(graph, {'name': node_name + '/WeightsShape'}).create_node()

                weights_rank_node = Rank(graph, {'name': node_name + '/WeightsRank'}).create_node()
                last_dim_node = get_canonical_axis_index_node(weights_rank_node, -1)
                weights_last_dim = get_shape_values_by_indices_node(weights_shape_node, last_dim_node)

                # Number of rows in the original table == index of the zero row appended below.
                weights_first_dim = node_to_get_shape_value_of_indices(weights_shape_node, [0])

                # Broadcast(0, [last_dim]) -> a 1-D row of zeros with the table's last-dim length.
                zero_col_node = create_op_with_const_inputs(graph, Broadcast, {0: int64_array([0])},
                                                            {'name': node_name + '/Broadcast'})
                zero_col_node.in_port(1).connect(weights_last_dim.out_port(0))

                # Unsqueeze on axis 0 turns the zero row into a [1, last_dim] block for the Concat.
                default_embeddings_node = create_op_with_const_inputs(graph, Unsqueeze, {1: int64_array(0)},
                                                                      {'name': node_name + '/Unsqueeze'})
                default_embeddings_node.in_port(0).connect(zero_col_node.out_port(0))

                # expand embedding table with zeros
                weights_concat = Concat(graph, {'axis': 0, 'in_ports_count': 2,
                                                'name': node_name + '/Concat'}).create_node()
                embedding_bag.in_port(0).get_connection().set_destination(weights_concat.in_port(0))
                # Shape/Rank read the original table (the Concat's first input), not the expanded one.
                weights_concat.in_port(0).get_connection().add_destination(weights_shape_node.in_port(0))
                weights_concat.in_port(0).get_connection().add_destination(weights_rank_node.in_port(0))
                weights_concat.in_port(1).connect(default_embeddings_node.out_port(0))
                weights_concat.out_port(0).connect(embedding_bag.in_port(0))

                # point default index to expanded part of embedding table
                weights_first_dim.out_port(0).connect(embedding_bag.in_port(3))
def convert_ifft_to_dft(self, graph: Graph, mx_fft: Node):
    """Decompose an MXNet inverse-FFT node into Reshape -> IDFT -> Split -> Squeeze.

    The input is reshaped with target shape [0] * (rank - 1) + [-1, 2], which splits the last
    axis into (real, imag) pairs. NOTE(review): this presumably relies on Reshape's special-zero
    semantics, where 0 copies the input dim — confirm. IDFT runs with axes=[-1]. Its output is
    split in two along the last axis. The part wired into Squeeze (presumably output 0, the real
    component) is squeezed on axis -1 and takes over `mx_fft`'s name.

    :param graph: graph being transformed
    :param mx_fft: the MXNet FFT node to replace
    """
    mx_fft_name = mx_fft.soft_get('name', mx_fft.id)

    # new shape = concat(broadcast(0, rank - 1), [-1, 2])
    rank_node = Rank(graph, {'name': mx_fft_name + '/rank'}).create_node()
    sub_node = create_op_with_const_inputs(graph, Sub, {1: int64_array(1)}, {'name': mx_fft_name + '/Sub'})
    rank_node.out_port(0).connect(sub_node.in_port(0))
    broadcast_node0 = create_op_with_const_inputs(graph, Broadcast, {0: int64_array(0)},
                                                  {'name': mx_fft_name + '/broadcast'})
    sub_node.out_port(0).connect(broadcast_node0.in_port(1))
    concat_node = create_op_with_const_inputs(graph, Concat, {1: int64_array([-1, 2])},
                                              {'name': mx_fft_name + '/new_shape', 'in_ports_count': 2, 'axis': 0},
                                              broadcast_node0)

    reshape_node = Reshape(graph, {'name': mx_fft_name + '/reshape'}).create_node()
    concat_node.out_port(0).connect(reshape_node.in_port(1))

    # Move the data input onto the Reshape and also feed it to Rank.
    mx_fft_connection = mx_fft.in_port(0).get_connection()
    mx_fft_connection.set_destination(reshape_node.in_port(0))
    mx_fft_connection.get_source().connect(rank_node.in_port(0))

    dft_node = create_op_with_const_inputs(graph, IDFT, {1: int64_array([-1])},
                                           {'name': mx_fft_name + '/idft', 'in_ports_count': 2},
                                           reshape_node)

    # Separate the two components along the last axis and drop that axis.
    split_node = create_op_with_const_inputs(graph, Split, {1: int64_array(-1)},
                                             {'name': mx_fft_name + '/split', 'num_splits': 2},
                                             dft_node)
    squeeze_node = create_op_with_const_inputs(graph, Squeeze, {1: int64_array([-1])}, {}, split_node)

    mx_fft.out_port(0).get_connection().set_source(squeeze_node.out_port(0))
    rename_nodes([(mx_fft, mx_fft_name + '/to_be_removed'), (squeeze_node, mx_fft_name)])
def find_and_replace_pattern(self, graph: Graph):
    """Replace every global Pooling node with a Reduce over axes [2, rank) with keep_dims=True.

    The Reduce class is chosen through `self.pool_method_to_reduce_type[pool_method]`. The
    axes are built in-graph as Range(2, Rank(input), 1). NHWC layout is rejected by an assert.

    :param graph: graph to transform in place
    """
    global_poolings = graph.get_op_nodes(type='Pooling', global_pool=True)
    if len(global_poolings) == 0:
        return

    layout = graph.graph['layout']
    assert layout != 'NHWC', 'Global pooling transformation depends on layout (NHWC not enabled)'

    for pooling in global_poolings:
        name = pooling.soft_get('name', pooling.id)
        assert pooling.has_valid('pool_method'), 'Global Pooling {} has no `pool_method` attribute'.format(name)
        method = pooling['pool_method']
        assert method in self.pool_method_to_reduce_type, \
            'Unexpected Global Pooling method `{}` for node `{}`'.format(method, name)
        reduce_op_class = self.pool_method_to_reduce_type[method]

        reduce = reduce_op_class(graph, {'name': name + '/reduce', 'keep_dims': True}).create_node()

        # Consumers of the pooling now read from the Reduce.
        pooling.out_port(0).get_connection().set_source(reduce.out_port(0))
        src = pooling.in_port(0).get_connection().get_source()

        # Connect the (still unconnected) Reduce data input to the pooling's producer.
        # NOTE(review): the pooling node itself stays connected to `src`; presumably it is
        # dropped by later graph cleanup since it no longer has consumers — confirm.
        reduce.in_port(0).get_connection().set_source(src)

        # axes = Range(2, Rank(input), 1)
        start = Const(graph, {'value': int64_array(2)}).create_node()
        end = Rank(graph, {'name': name + '/input_rank'}).create_node()
        delta = Const(graph, {'value': int64_array(1)}).create_node()

        axis = Range(graph, {'name': name + '/global_pooling_reduce_axis'}).create_node()

        axis.in_port(0).connect(start.out_port(0))
        src.connect(end.in_port(0))
        axis.in_port(1).connect(end.out_port(0))
        axis.in_port(2).connect(delta.out_port(0))

        axis.out_port(0).connect(reduce.in_port(1))

        log.debug('Global {} pooling was converted to reduce: `{}`'.format(method, name))
def convert_fft_to_dft(self, graph: Graph, mx_fft: Node):
    """Decompose an MXNet forward-FFT node into Unsqueeze -> Pad -> DFT -> Reshape.

    For a real input of rank r and last dimension d:
    1. Unsqueeze(-1) appends a size-1 axis: [..., d, 1].
    2. Pad appends one zero to that axis (pads_begin = zeros(r + 1),
       pads_end = zeros(r + 1) with 1 at index r): [..., d, 2], i.e. complex values with a
       zero imaginary part.
    3. DFT with axes=[-1] gives [..., d, 2].
    4. Reshape with target [0] * (r - 1) + [-1] merges the last two axes. The result has rank r
       and interleaved real/imag values, [..., 2 * d], which is the MXNet `fft` output layout.
       NOTE(review): the 0 entries rely on Reshape's special-zero semantics (0 copies the input
       dim) — confirm.

    The previous target [0] * (r - 1) + [-1, 2] produced exactly the DFT output shape. That made
    the Reshape a no-op and left an extra trailing axis of size 2 in the converted model.

    :param graph: graph being transformed
    :param mx_fft: the MXNet FFT node to replace
    """
    mx_fft_name = mx_fft.soft_get('name', mx_fft.id)
    unsqueeze_node = create_op_with_const_inputs(graph, Unsqueeze, {1: int64_array([-1])},
                                                 {'name': mx_fft_name + '/Unsqueeze'})
    rank_node = Rank(graph, {'name': mx_fft_name + '/Rank'}).create_node()

    # Move the data input onto the Unsqueeze and also feed it to Rank.
    mx_fft_connection = mx_fft.in_port(0).get_connection()
    mx_fft_connection.set_destination(unsqueeze_node.in_port(0))
    mx_fft_connection.get_source().connect(rank_node.in_port(0))

    # pads_begin = zeros(rank + 1)
    add_node = create_op_with_const_inputs(graph, Add, {1: int64_array(1)},
                                           {'name': mx_fft_name + '/Add'}, rank_node)
    broadcast_node1 = create_op_with_const_inputs(graph, Broadcast, {0: int64_array(0)},
                                                  {'name': mx_fft_name + '/Pad_broadcast'})
    add_node.out_port(0).connect(broadcast_node1.in_port(1))

    # pads_end = zeros(rank + 1) with element [rank] (the new last axis) set to 1
    scatter_node = create_op_with_const_inputs(graph, ScatterUpdate,
                                               {2: int64_array(1), 3: int64_array(0)},
                                               {'name': mx_fft_name + '/ScatterUpdate'})
    broadcast_node1.out_port(0).connect(scatter_node.in_port(0))
    rank_node.out_port(0).connect(scatter_node.in_port(1))

    pad_node = Pad(graph, {'name': mx_fft_name + '/Pad', 'mode': 'constant'}).create_node([unsqueeze_node,
                                                                                             broadcast_node1,
                                                                                             scatter_node])

    dft_node = create_op_with_const_inputs(graph, DFT, {1: int64_array([-1])},
                                           {'name': mx_fft_name + '/DFT', 'in_ports_count': 2},
                                           pad_node)

    # new shape = concat(broadcast(0, rank - 1), [-1]): merge [..., d, 2] into [..., 2 * d]
    sub_node = create_op_with_const_inputs(graph, Sub, {1: int64_array(1)}, {'name': mx_fft_name + '/Sub'})
    rank_node.out_port(0).connect(sub_node.in_port(0))
    broadcast_node2 = create_op_with_const_inputs(graph, Broadcast, {0: int64_array(0)},
                                                  {'name': mx_fft_name + '/Reshape_broadcast'})
    sub_node.out_port(0).connect(broadcast_node2.in_port(1))
    concat_node = create_op_with_const_inputs(graph, Concat, {1: int64_array([-1])},
                                              {'name': mx_fft_name + '/New_shape', 'in_ports_count': 2, 'axis': 0},
                                              broadcast_node2)

    reshape_node = Reshape(graph, {}).create_node([dft_node, concat_node])

    mx_fft.out_port(0).get_connection().set_source(reshape_node.out_port(0))
    rename_nodes([(mx_fft, mx_fft_name + '/to_be_removed'), (reshape_node, mx_fft_name)])
def replace_sub_graph(self, graph: Graph, match: dict):
    """Replace a matched Flatten node with an equivalent Reshape.

    Flatten collapses the input dimensions in the inclusive range [axis, end_axis] into one
    dimension. The Reshape target shape is built from three parts:
    begin_dims (dims before `axis`), middle_dim (product of dims axis..end_axis) and end_dims
    (dims after `end_axis`). The Reshape takes over the Flatten node's name.

    :param graph: graph being transformed
    :param match: matched sub-graph; the Flatten node is stored under 'flatten'
    """
    node = match['flatten']
    name = node.soft_get('name', node.id)

    assert node.has_valid('axis'), 'Flatten {} has no mandatory `axis` attribute'.format(name)
    assert node.has_valid('end_axis'), 'Flatten {} has no mandatory `end_axis` attribute'.format(name)

    axis = node.axis
    end_axis = node.end_axis

    if end_axis == -1 and axis >= 0:
        # Everything from `axis` to the end is flattened, so the target shape is static:
        # `axis` zeros (NOTE(review): relies on Reshape copying dims for 0 — confirm) followed by -1.
        begin_dims = Const(graph, {'value': int64_array([0] * axis)}).create_node()
        middle_dim = Const(graph, {'value': int64_array([-1])}).create_node()
        end_dims = Const(graph, {'value': int64_array([])}).create_node()
    else:
        # General case: slice the runtime input shape into the three parts.
        rank = Rank(graph, {'name': name + '/input_rank'}).create_node()
        node.in_port(0).get_source().connect(rank.in_port(0))

        shape = Shape(graph, {'name': name + '/input_shape'}).create_node()
        node.in_port(0).get_source().connect(shape.in_port(0))

        begin_dims = get_shape_values_by_range_idxs(
            shape=shape, rank=rank, begin=0, end=axis)
        middle_dims = get_shape_values_by_range_idxs(
            shape=shape, rank=rank, begin=axis, end=end_axis, include_end=True)
        end_dims = get_shape_values_by_range_idxs(
            shape=shape, rank=rank, begin=end_axis, end=-1, include_begin=False, include_end=True)

        # Multiply the middle dims into a single 1-element value (keep_dims=True).
        middle_dim = create_op_node_with_second_input(graph, ReduceProd, int64_array([0]), {'keep_dims': True})
        middle_dims.out_port(0).connect(middle_dim.in_port(0))

    dim = new_shape_node_from_shape_nodes([begin_dims, middle_dim, end_dims])

    reshape_node = Reshape(graph, {}).create_node()
    # Keep node with the same name to avoid confusion with renaming. Use `name` (which falls back
    # to node.id) instead of a bare soft_get('name'), so a nameless node is not renamed to
    # soft_get's placeholder default.
    rename_nodes([(node, name + '/ShouldBeDeleted'), (reshape_node, name)])

    reshape_node.in_port(1).connect(dim.out_port(0))
    node.out_port(0).get_connection().set_source(reshape_node.out_port(0))
    node.in_port(0).get_connection().set_destination(reshape_node.in_port(0))
def replace_op(self, graph: Graph, node: Node):
    """Decompose `node` into MVN -> Mul -> Add (the '/Ins_Norm/' names suggest instance norm).

    * port 0 of `node` (data) feeds MVN, normalized over axes Range(2, Rank(data), 1) with
      eps = node.epsilon, eps_mode 'inside_sqrt' and normalize_variance 1;
    * port 1 of `node` is the second input of the Mul;
    * port 2 of `node` is the second input of the Add.

    :return: single-element list with the id of the final Add node, which takes over `node`'s name
    """
    base = node.soft_get('name', node.id)

    # create range of axes for MVN based on `start_axis` and rank of input
    data_rank = Rank(graph, {'name': base + '/Rank'}).create_node()
    axes = create_op_with_const_inputs(graph, Range,
                                       {0: int64_array(2), 2: int64_array(1)},
                                       {'name': base + '/Range', 'output_type': np.int64})

    mvn_attrs = {
        'eps': node.epsilon,
        'eps_mode': 'inside_sqrt',
        'normalize_variance': 1,
        'name': base + '/Ins_Norm/MVN_',
    }
    mvn = MVN(graph, mvn_attrs).create_node()
    node.in_port(0).get_connection().set_destination(mvn.in_port(0))
    axes.out_port(0).connect(mvn.in_port(1))

    # Chain the elementwise stages: previous stage -> port 0, original input `src_port` -> port 1.
    tail = mvn
    for op_cls, suffix, src_port in ((Mul, '/Ins_Norm/mul_', 1), (Add, '/Ins_Norm/add_', 2)):
        stage = op_cls(graph, {'axis': 1, 'name': base + suffix}).create_node()
        tail.out_port(0).connect(stage.in_port(0))
        node.in_port(src_port).get_connection().set_destination(stage.in_port(1))
        tail = stage

    # The rank is computed on the same tensor MVN normalizes.
    mvn.in_port(0).get_connection().add_destination(data_rank.in_port(0))
    axes.in_port(1).connect(data_rank.out_port(0))

    rename_nodes([(node, base + '/TBD'), (tail, base)])
    return [tail.id]
def find_and_replace_pattern(self, graph: Graph):
    """Normalize SpaceToBatch / BatchToSpace inputs to the separate begin/end form.

    1. The [N, 2] pads/crops tensor on port 2 is transposed to [2, N], split in two, squeezed,
       and wired to ports 2 (begin) and 3 (end).
    2. Ports 1 (block_shape), 2 (begin) and 3 (end) are each padded to the rank of the data
       input. Each input gets one leading element and (rank(data) - len(block_shape) - 1)
       trailing elements. block_shape is padded with ones (pad value 1 on its Pad port 3).
       begin/end get no explicit pad value (NOTE(review): presumably Pad's default 0 — confirm).
       Each padded input is preceded by a Cast to int64.

    :param graph: graph to transform in place
    """
    for node in graph.get_op_nodes(op='SpaceToBatch') + graph.get_op_nodes(op='BatchToSpace'):
        node.add_input_port(3, skip_if_exist=True)

        # convert TF representation of the pads/crops as [N, 2] to IE representation: [N] and [N]
        transposed_pads = create_op_with_const_inputs(graph, Transpose, {1: int64_array([1, 0])})
        node.in_port(2).get_connection().set_destination(transposed_pads.in_port(0))
        split_pads = create_op_with_const_inputs(graph, Split, {1: int64_array(0)}, {'num_splits': 2})
        transposed_pads.out_port(0).connect(split_pads.in_port(0))
        for port_ind in range(2):
            node.in_port(port_ind + 2).connect(split_pads.out_port(port_ind))
            # Each split output is [1, N]; squeeze axis 0 to get [N].
            node.in_port(port_ind + 2).get_connection().insert_node(
                create_op_with_const_inputs(graph, Squeeze, {1: int64_array([0])}))

        # add zeros/ones to related inputs to align it with data input
        in0_rank = Rank(graph, {'name': node.name + '/rank_0'}).create_node()
        # NOTE: despite the '/rank_1' name, this is the Shape of port 1 (block_shape), i.e. its length.
        in1_shape = Shape(graph, {'name': node.name + '/rank_1'}).create_node()

        # diff_size = rank(data) - len(block_shape); diff = diff_size - 1 (trailing pad amount)
        diff_size = Sub(graph, {'name': node.name + '/sub_0'}).create_node()
        diff = Sub(graph, {'name': node.name + '/sub_1'}).create_node()

        # [1] is used both as pads_begin for every Pad and as the subtrahend for `diff`.
        const_begin = Const(graph, {'value': int64_array([1])}).create_node()
        # Pad value for block_shape.
        const_pad_val = Const(graph, {'value': int64_array(1)}).create_node()

        block_shape = Pad(graph, {'name': node.name + '/aligned_block_shape', 'mode': 'constant'}).create_node()

        # in case of SpaceToBatch begin = pads_begin, end = pads_end
        # in case of BatchToSpace begin = crops_begin, end = crops_end
        new_begin_name = '/aligned_pads_begin'
        new_end_name = '/aligned_pads_end'
        if node.type == 'BatchToSpace':
            new_begin_name = '/aligned_crops_begin'
            new_end_name = '/aligned_crops_end'

        begin = Pad(graph, {'name': node.name + new_begin_name, 'mode': 'constant'}).create_node()
        end = Pad(graph, {'name': node.name + new_end_name, 'mode': 'constant'}).create_node()

        # Rank is unsqueezed to a 1-element 1-D tensor so it can be combined with the Shape output.
        in0_rank_1d = create_op_node_with_second_input(graph, Unsqueeze, int64_array([0]),
                                                       {'name': node.name + '/1d_rank_of_0'}, in0_rank)

        node.in_port(0).get_source().connect(in0_rank.in_port(0))
        node.in_port(1).get_source().connect(in1_shape.in_port(0))
        in0_rank_1d.out_port(0).connect(diff_size.in_port(0))
        in1_shape.out_port(0).connect(diff_size.in_port(1))
        diff_size.out_port(0).connect(diff.in_port(0))
        const_begin.out_port(0).connect(diff.in_port(1))
        const_pad_val.out_port(0).connect(block_shape.in_port(3))

        # Wrap ports 1, 2, 3 of `node` with the corresponding Pad: original input -> Cast(i64) -> Pad -> node.
        inputs_array = [block_shape, begin, end]
        for idx, input_to_node in enumerate(inputs_array):
            name_of_input_to_node = input_to_node.name
            node.in_port(idx + 1).get_connection().set_destination(input_to_node.in_port(0))
            const_begin.out_port(0).connect(input_to_node.in_port(1))
            diff.out_port(0).connect(input_to_node.in_port(2))
            input_to_node.out_port(0).connect(node.in_port(idx + 1))
            convert = Cast(graph, {'name': name_of_input_to_node + '/i64', 'dst_type': np.int64}).create_node()
            input_to_node.in_port(0).get_connection().insert_node(convert)
def replace_resize(graph: Graph, resize: Node):
    """Replace an ONNX Resize-10 node with an opset4 Interpolate in 'scales' mode.

    Interpolate inputs:
    * port 0: the Resize data input;
    * port 1 (sizes): floor(ShapeOf(data) * (scales + 1e-5)) cast to int64 and sliced to [2:];
    * port 2 (scales): the Resize scales input sliced to [2:];
    * port 3 (axes): Range(2, Rank(data), 1).
    Mode is 'linear_onnx' when resize.mode == 'linear', otherwise 'nearest'. The Interpolate
    takes over the Resize node's name.

    :param graph: graph being transformed
    :param resize: the Resize-10 node to replace
    """
    log.debug("Converting of ONNX Resize-10 to Interpolate-4 "
              "is triggered for node {}.".format(resize.soft_get('name', resize.id)))

    resize_name = resize.soft_get('name', resize.id)

    rank_node = Rank(graph, {'name': resize_name + '/max_axes'}).create_node()
    range_node = create_op_with_const_inputs(graph, Range, {0: int64_array(2), 2: int64_array(1)},
                                             {'name': resize_name + '/axes'})

    # Both StridedSlices take elements [2:]: begin=[2] is used, end is masked out.
    # NOTE(review): assumes the MO mask convention where begin_mask=1 means "use begin" — confirm.
    sizes_ss = create_op_with_const_inputs(graph, StridedSlice,
                                           {1: int64_array([2]),
                                            2: int64_array([0]),
                                            3: int64_array([1])},
                                           {'name': resize_name + '/sizes_ss',
                                            'begin_mask': int64_array([1]),
                                            'end_mask': int64_array([0]),
                                            'new_axis_mask': int64_array([0]),
                                            'shrink_axis_mask': int64_array([0]),
                                            'ellipsis_mask': int64_array([0])})
    scales_ss = create_op_with_const_inputs(graph, StridedSlice,
                                            {1: int64_array([2]),
                                             2: int64_array([0]),
                                             3: int64_array([1])},
                                            {'name': resize_name + '/scales_ss',
                                             'begin_mask': int64_array([1]),
                                             'end_mask': int64_array([0]),
                                             'new_axis_mask': int64_array([0]),
                                             'shrink_axis_mask': int64_array([0]),
                                             'ellipsis_mask': int64_array([0])})

    rank_node.out_port(0).connect(range_node.in_port(1))

    interpolate_node = Interpolate(graph, {'version': 'opset4',
                                           'mode': 'linear_onnx' if resize.mode == 'linear' else 'nearest',
                                           'coordinate_transformation_mode': 'asymmetric',
                                           'cube_coeff': -0.75,
                                           'nearest_mode': 'simple',
                                           'pads_begin': int64_array([0]),
                                           'pads_end': int64_array([0]),
                                           'antialias': 0,
                                           'shape_calculation_mode': 'scales',
                                           'in_ports_count': 4}).create_node()

    range_node.out_port(0).connect(interpolate_node.in_port(3))
    shape_of = Shape(graph, {'name': resize_name + '/ShapeOf'}).create_node()

    # When we calculate 'sizes' input as floor(input_shape * scales), we can get incorrect 'sizes' if, e.g.,
    # scales = [1.0, 1.0, 1.33333, 2.0], input_shape = [1, 3, 30, 200], because
    # input_shape * scales = [1, 3, 39.9999, 400], and floor(input_shape * scales)[2] == 39, not 40.
    # Maybe we need to calculate 'sizes' input as floor(input_shape * scales + eps), where eps is some small
    # floating point number, e.g. 1.0e-5. But, in this case, if scales = [1.0, 1.0, 1.33333, 2.0],
    # input_shape = [1, 3, 30, 200], floor(input_shape * scales + eps) = 39, not 40, because
    # input_shape[2] * scales[2] + 1.0e-5 = 39.99991.
    # Hence, we need to calculate 'sizes' as floor(input_shape * (scales + eps)):
    # 30 * (1.33333 + 1.0e-5) = 40.0002, and floor gives 40.
    add_node = create_op_with_const_inputs(graph, Add,
                                           {1: float_array([1.0e-5])},
                                           {'name': resize_name + '/Add'})

    dst_dtype = np.float32  # even if data_type=FP16 use float32 for shape values

    cast_shape_to_float = Cast(graph, {'dst_type': dst_dtype}).create_node()

    shape_of.out_port(0).connect(cast_shape_to_float.in_port(0))
    mul_node = Mul(graph, {'name': resize_name + '/Mul'}).create_node([cast_shape_to_float, add_node])
    floor_node = Floor(graph, {'name': resize_name + '/Floor'}).create_node([mul_node])
    cast_mul_result_to_int = Cast(graph, {'dst_type': np.int64}).create_node([floor_node])
    cast_mul_result_to_int.out_port(0).connect(sizes_ss.in_port(0))
    sizes_ss.out_port(0).connect(interpolate_node.in_port(1))

    scales_ss.out_port(0).connect(interpolate_node.in_port(2))

    # Move data and scales onto the new nodes, then fan the original producers out to
    # ShapeOf/Rank (data) and to the eps Add (scales).
    connection_of_resize_input = resize.in_port(0).get_connection()
    connection_of_resize_input.set_destination(interpolate_node.in_port(0))

    connection_of_scales = resize.in_port(1).get_connection()
    connection_of_scales.set_destination(scales_ss.in_port(0))

    connection_of_resize_input.get_source().connect(shape_of.in_port(0))
    connection_of_resize_input.get_source().connect(rank_node.in_port(0))
    connection_of_scales.get_source().connect(add_node.in_port(0))

    rename_nodes([(resize, resize_name + '/delete'), (interpolate_node, resize_name)])
    resize.out_port(0).get_connection().set_source(interpolate_node.out_port(0))
def mxrepeat_decomposition(node: Node):
    """Decompose an MXNet repeat node into Unsqueeze -> Tile -> Reshape.

    The input is unsqueezed after `node.axis`, tiled `node.repeats` times along that new axis,
    and reshaped back so that dimension `axis` becomes input_shape[axis] * repeats. The
    Reshape takes over the node's name; the original node is renamed '<name>/to_be_removed'.

    :param node: the repeat node; reads `node.axis` (may be negative) and `node.repeats`
    """
    graph = node.graph
    name = node.soft_get('name', node.id)

    rename_node(node, name + '/to_be_removed')

    # Unsqueeze
    input_rank = Rank(graph, {'name': name + '/Rank'}).create_node()
    node.in_port(0).get_source().connect(input_rank.in_port(0))

    # `axis` is node.axis normalized against the input rank; the new axis is inserted right after it.
    axis = get_canonical_axis_index_node(input_rank, node.axis)
    unsqueeze_axis = create_op_node_with_second_input(graph, Add, int64_array([1]),
                                                      {'name': name + '/Unsqueeze/Axis'}, input_node=axis)

    unsqueeze = Unsqueeze(graph, {'name': name + '/Unsqueeze'}).create_node()
    unsqueeze.in_port(1).connect(unsqueeze_axis.out_port(0))

    # Tile (1, 1, ..., repeats, ..., 1)
    # we generate tile array according to the following table (the unsqueezed tensor has rank + 1
    # dims, indices 0 .. rank):
    # parts:       |  first           |  repeats   |  second          |
    # i:           |  0, 1, ..., axis |  axis + 1  |  axis + 2 .. rank |
    # tile_array:  |  1, 1, ..., 1    |  repeats   |  1, ..., 1       |
    # i.e. (axis + 1) ones, then `repeats`, then (rank - axis - 1) ones.
    one = Const(graph, {'name': name + '/Broadcast/One', 'value': int64_array([1])}).create_node()

    first_ones = Broadcast(graph, {'name': name + '/Broadcast/Ones_first_part'}).create_node()
    first_ones.in_port(0).connect(one.out_port(0))
    first_ones.in_port(1).connect(unsqueeze_axis.out_port(0))

    repeats = Const(graph, {'name': name + '/repeats', 'value': int64_array([node.repeats])}).create_node()

    second_ones = Broadcast(graph, {'name': name + '/Broadcast/Ones_second_part'}).create_node()
    second_part_broadcast_shape = Sub(graph, {'name': name + '/Broadcast/Shape/second_part'}).create_node()
    second_part_broadcast_shape.in_port(0).connect(input_rank.out_port(0))
    second_part_broadcast_shape.in_port(1).connect(unsqueeze_axis.out_port(0))
    second_ones.in_port(0).connect(one.out_port(0))
    second_ones.in_port(1).connect(second_part_broadcast_shape.out_port(0))

    tile_repeats = new_shape_node_from_shape_nodes([first_ones, repeats, second_ones])
    tile = Tile(graph, {'name': name + '/Tile'}).create_node()
    tile.in_port(1).connect(tile_repeats.out_port(0))

    # Reshape (input_shape[:axis], input_shape[axis] * repeats, input_shape[axis+1:])
    # we generate reshape dim array according to the following table:
    # parts:       |  first              |  rep                           |  second             |
    # i:           |  0, 1, ..., axis - 1 |  axis                          |  axis + 1 .. rank - 1 |
    # dim_array:   |  inp_sh[i]          |  input_shape[axis] * repeats   |  inp_sh[i]          |
    input_shape = Shape(graph, {'name': name + '/Shape'}).create_node()
    node.in_port(0).get_source().connect(input_shape.in_port(0))

    # NOTE(review): these slices use the raw node.axis (possibly negative); presumably the helper
    # normalizes it against input_rank — confirm.
    first_input_shape_part = get_shape_values_by_range_idxs(
        input_shape, input_rank, begin=0, end=node.axis, include_begin=True, include_end=False)

    original_axis_dim = create_op_with_const_inputs(
        graph, Gather, {2: int64_array(0)}, {'name': name + '/OriginalDim'}, input_node=input_shape)
    original_axis_dim.in_port(1).connect(axis.out_port(0))

    repeated_dimention = Mul(graph, {'name': name + '/RepeatedDim'}).create_node()
    repeated_dimention.in_port(0).connect(original_axis_dim.out_port(0))
    repeated_dimention.in_port(1).connect(repeats.out_port(0))

    second_input_shape_part = get_shape_values_by_range_idxs(
        input_shape, input_rank, begin=node.axis, end=-1, include_begin=False, include_end=True)

    output_shape = new_shape_node_from_shape_nodes(
        [first_input_shape_part, repeated_dimention, second_input_shape_part])

    reshape = Reshape(graph, {'name': name}).create_node()
    rename_node(reshape, name)
    reshape.in_port(1).connect(output_shape.out_port(0))

    # Final connections
    node.in_port(0).get_connection().set_destination(unsqueeze.in_port(0))
    tile.in_port(0).connect(unsqueeze.out_port(0))
    reshape.in_port(0).connect(tile.out_port(0))
    node.out_port(0).get_connection().set_source(reshape.out_port(0))
def extract(cls, node: Node):
    """Set `node` up as a Rank operation with int32 output.

    :param node: node whose attributes are updated in place
    :return: the class-level `enabled` flag
    """
    rank_attrs = {'output_type': np.int32}
    Rank.update_node_stat(node, rank_attrs)
    return cls.enabled