def replace_sub_graph(self, graph: Graph, match: dict):
    node = match['flatten']
    name = node.soft_get('name', node.id)

    assert node.has_valid('axis'), 'Flatten {} should have `axis` attribute extracted, but it\'s not'.format(name)
    axis = node.axis

    if axis == 0:
        dim = Const(graph, {'value': int64_array([1, -1])}).create_node()
    elif axis == 1:
        dim = Const(graph, {'value': int64_array([0, -1])}).create_node()
    else:
        shape = Shape(graph, {'name': name + '/input_shape'}).create_node()

        idxs = list(range(axis)) if axis > 0 else list(range(axis, 0))

        axis_shape_portion = node_to_get_shape_value_of_indices(shape, idxs)
        first_dims = create_op_node_with_second_input(graph, ReduceProd, int64_array([0]), {'keep_dims': True})
        second_dims = Const(graph, {'value': int64_array([-1])}).create_node()

        node.in_port(0).get_source().connect(shape.in_port(0))
        axis_shape_portion.out_port(0).connect(first_dims.in_port(0))

        order_of_dims = [first_dims, second_dims] if axis > 0 else [second_dims, first_dims]

        dim = new_shape_node_from_shape_nodes(order_of_dims)

    reshape_node = Reshape(graph, {'name': node.id + '/Reshape'}).create_node()
    reshape_node.in_port(1).connect(dim.out_port(0))

    node.out_port(0).get_connection().set_source(reshape_node.out_port(0))
    node.in_port(0).get_connection().set_destination(reshape_node.in_port(0))
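
# A minimal numpy sketch (not part of the transformation above) of the Caffe-style
# Flatten semantics the emitted Reshape must reproduce for a non-negative axis;
# `flatten_reference` is a hypothetical helper for illustration only.
import numpy as np

def flatten_reference(x: np.ndarray, axis: int) -> np.ndarray:
    if axis == 0:
        return x.reshape(1, -1)           # matches the Const [1, -1] branch
    if axis == 1:
        return x.reshape(x.shape[0], -1)  # matches the Const [0, -1] branch (0 keeps the dim)
    # general branch: ReduceProd over shape[:axis], concatenated with -1
    return x.reshape(int(np.prod(x.shape[:axis])), -1)

assert flatten_reference(np.zeros((2, 3, 4, 5)), 2).shape == (6, 20)
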
def replace_pattern(self, graph: Graph, match: dict):
    matmul = match['matmul']
    reshape = match['reshape']
    other_input_port_idx = 0 if match['matmul'].in_port(0).get_source().node.id == match['other_input'].id else 1
    shape_source = match['matmul'].in_port(other_input_port_idx).get_source()

    initial_reshape_pattern = reshape.in_port(1).data.get_value()
    if len(initial_reshape_pattern) != 2:
        return

    reshape_is_A_input = matmul.in_port(0).get_source().node.id == reshape.id
    if reshape_is_A_input:
        idx = -1 if matmul.transpose_b else -2
    else:
        idx = -2 if matmul.transpose_a else -1
    idx = get_canonical_axis_index(initial_reshape_pattern, idx)

    shape_name = shape_source.node.soft_get('name', shape_source.node.id)
    shape = Shape(graph, {'name': shape_name + '/Shape'}).create_node()
    shape.in_port(0).connect(shape_source)
    C = node_to_get_shape_value_of_indices(shape, [idx])
    N = Const(graph, {'name': shape_name + '/MinusOne', 'value': int64_array([-1])}).create_node()

    if reshape_is_A_input:
        reshape_pattern = [C, N] if matmul.transpose_a else [N, C]
    else:
        reshape_pattern = [N, C] if matmul.transpose_b else [C, N]

    new_reshape_pattern = new_shape_node_from_shape_nodes(reshape_pattern)
    reshape.in_port(1).get_connection().set_source(new_reshape_pattern.out_port(0))
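
# A small sketch (hypothetical helper, assuming standard MatMul shape rules:
# A is [M, K], or [K, M] with transpose_a; B is [K, N], or [N, K] with transpose_b)
# of how the common dimension C = K is located on the MatMul counterpart input:
def common_dim_axis(reshape_is_A_input: bool, transpose_a: bool, transpose_b: bool) -> int:
    if reshape_is_A_input:
        # the counterpart is B; its K axis is -2, or -1 when B is transposed
        return -1 if transpose_b else -2
    # the counterpart is A; its K axis is -1, or -2 when A is transposed
    return -2 if transpose_a else -1

assert common_dim_axis(True, False, False) == -2
assert common_dim_axis(False, False, False) == -1
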
def decompose_shuffle_channel(node: Node):
    graph = node.graph
    name = node.soft_get('name', node.id)

    rename_node(node, name + '/to_be_removed')

    shape = Shape(graph, dict(name=name + '/InputShape')).create_node()
    shape.in_port(0).connect(node.in_port(0).get_source())

    # Reshape [input_batch, group, input_channels/group, -1]
    batch = node_to_get_batch_value(shape)
    group = Const(graph, dict(name=name + '/Rows', value=int64_array([node.group]))).create_node()
    const = Const(graph, dict(name=name + '/Const', value=int64_array([-1]))).create_node()

    input_channels = node_to_get_features_dimension_value(shape)
    output_channels = create_op_node_with_second_input(
        graph, Div, np.int64(node.group), {'name': name + '/Cols'}, input_node=input_channels)
    i_output_channels = Cast(graph, {'name': output_channels.name + '/Convert', 'dst_type': np.int64}).create_node()
    output_channels.out_port(0).connect(i_output_channels.in_port(0))

    reshape_split_dim = new_shape_node_from_shape_nodes([batch, group, i_output_channels, const])
    reshape_split_node = Reshape(graph, dict(name=name + '/Reshape_split_')).create_node()
    reshape_split_dim.out_port(0).connect(reshape_split_node.in_port(1))

    # Transpose(0, 2, 1, 3)
    transpose_node = create_op_node_with_second_input(
        graph, Transpose, int64_array([0, 2, 1, 3]), {'name': name + '/Transpose_'},
        input_node=reshape_split_node)

    # Reshape back to input shape
    reshape_concat = Reshape(graph, dict(name=name)).create_node()
    rename_node(reshape_concat, name)

    shape.out_port(0).connect(reshape_concat.in_port(1))
    transpose_node.out_port(0).connect(reshape_concat.in_port(0))

    # Final connections
    node.in_port(0).get_connection().set_destination(reshape_split_node.in_port(0))
    node.out_port(0).get_connection().set_source(reshape_concat.out_port(0))
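
# A numpy sketch (illustration only, assuming NCHW layout) of the channel shuffle
# that the Reshape -> Transpose -> Reshape chain above implements:
import numpy as np

def shuffle_channels_reference(x: np.ndarray, group: int) -> np.ndarray:
    n, c, h, w = x.shape
    return (x.reshape(n, group, c // group, -1)  # split C into [group, C // group], flatten H*W
             .transpose(0, 2, 1, 3)              # swap the two channel sub-axes
             .reshape(n, c, h, w))               # restore the input shape

x = np.arange(2 * 6 * 4 * 4).reshape(2, 6, 4, 4)
assert shuffle_channels_reference(x, 2).shape == x.shape
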
def replace_pattern(graph: Graph, match: dict):
    """
    Works around the Tile operation not being supported by Inference Engine for 3D tensors
    (Tile is supported for 2D and 4D tensors): searches for Tiles with 3D input/output
    shapes and wraps them in Reshapes.

    Example: Tile (axis=1, tiles=16):
        in_shape: [1, 1, 101]
        out_shape: [1, 16, 101]

    Old behaviour:
        Tile -> [1, 16, 101]
    New behaviour:
        Reshape [1, 1, 101, 1] -> Tile -> [1, 16, 101, 1] -> Reshape [1, 16, 101]
    """
    node = match['tile']
    name = node.soft_get('name', node.id)

    out_shape = node.out_port(0).data.get_shape()
    assert out_shape is not None, 'Output shape is undefined for {} in back phase'.format(name)
    if out_shape.size != 3:
        return

    inp_shape = node.in_port(0).data.get_shape()
    assert inp_shape is not None, 'Input shape is undefined for {} in back phase'.format(name)

    unsqueeze_dim = Const(graph, {'name': name + '/3D_Tile_Unsqueeze_dim', 'value': int64_array([3])}).create_node()
    unsqueeze = Unsqueeze(graph, {'name': name + '/3D_Tile_Unsqueeze', 'override_output_shape': True}).create_node()
    unsqueeze_dim.out_port(0).connect(unsqueeze.in_port(1))

    const = Const(graph, {'name': name + '/additional_axis', 'value': int64_array([1])}).create_node()
    new_tiles = new_shape_node_from_shape_nodes([node.in_port(1).get_source().node, const])
    node.in_port(1).get_connection().set_source(new_tiles.out_port(0))

    squeeze_dim = Const(graph, {'name': name + '/3D_Tile_Squeeze_dim', 'value': int64_array([3])}).create_node()
    squeeze = Squeeze(graph, {'name': name + '/3D_Tile_Squeeze', 'override_output_shape': True}).create_node()
    squeeze_dim.out_port(0).connect(squeeze.in_port(1))

    source = node.in_port(0).get_source()
    node.in_port(0).get_connection().set_source(unsqueeze.out_port(0))
    unsqueeze.in_port(0).connect(source)

    node.out_port(0).get_connection().set_source(squeeze.out_port(0))
    node.out_port(0).connect(squeeze.in_port(0))

    node['override_output_shape'] = True
    new_tiles['override_output_shape'] = True
    node['need_shape_inference'] = True
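
# A numpy check (illustration only) that the Unsqueeze -> 4D Tile -> Squeeze chain
# is equivalent to the original 3D Tile from the docstring example:
import numpy as np

x = np.random.rand(1, 1, 101)
direct = np.tile(x, (1, 16, 1))                      # old behaviour: 3D Tile
via_4d = np.tile(x[..., np.newaxis], (1, 16, 1, 1))  # Unsqueeze(3) + tiles extended with 1
assert np.array_equal(np.squeeze(via_4d, axis=3), direct)
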
def replace_pattern(graph: Graph, match: Dict[str, Node]):
    node = match['op']
    name = node.soft_get('name', node.id)

    input_shape = node.in_port(0).data.get_shape()
    second_input_shape = node.in_port(1).data.get_shape()

    begin_mask = np.zeros(len(input_shape), dtype=np.int64)
    end_mask = np.zeros(len(input_shape), dtype=np.int64)
    for i in node.axes:
        end_mask[i] = np.int64(1)

    new_axis_mask = np.zeros(len(input_shape), dtype=np.int64)
    shrink_axis_mask = np.zeros(len(input_shape), dtype=np.int64)
    ellipsis_mask = np.zeros(len(input_shape), dtype=np.int64)

    ss = create_op_with_const_inputs(graph, StridedSlice,
                                     port_value_dict={1: np.zeros(len(input_shape), dtype=np.int64)},
                                     op_attrs={'name': 'StridedSlice', 'begin_mask': begin_mask,
                                               'end_mask': end_mask, 'new_axis_mask': new_axis_mask,
                                               'shrink_axis_mask': shrink_axis_mask,
                                               'ellipsis_mask': ellipsis_mask})
    if input_shape.size == second_input_shape.size:
        end = Shape(graph, dict(name=name + '/End')).create_node()
        end.in_port(0).connect(node.in_port(1).get_source())
        ss.in_port(2).connect(end.out_port(0))
    else:
        shape_like, rank_like = get_shape_and_rank_nodes_by_port(node.in_port(1).get_source())
        end_first_part = get_shape_values_by_range_idxs(shape_like, rank_like, 0, node.axes[-1], include_end=True)
        if input_shape.size - 1 == node.axes[-1]:
            ss.in_port(2).connect(end_first_part.out_port(0))
        else:
            shape, rank = get_shape_and_rank_nodes_by_port(node.in_port(0).get_source())
            end_second_part = get_shape_values_by_range_idxs(shape, rank, node.axes[-1], -1,
                                                             include_begin=False, include_end=True)
            end = new_shape_node_from_shape_nodes([end_first_part, end_second_part])
            ss.in_port(2).connect(end.out_port(0))

    node.in_port(0).get_connection().set_destination(ss.in_port(0))
    node.in_port(1).disconnect()
    node.out_port(0).get_connection().set_source(ss.out_port(0))

    rename_nodes([(node, name + '/ShouldBeDeleted'), (ss, name)])
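
# A numpy sketch (assuming MXNet slice_like semantics, which this pattern targets):
# the input is cropped to the second input's extent along `axes`, which is exactly
# the `end` tensor the StridedSlice above assembles from the shape sub-graphs.
import numpy as np

def slice_like_reference(x: np.ndarray, like: np.ndarray, axes) -> np.ndarray:
    index = [slice(None)] * x.ndim
    for a in axes:
        index[a] = slice(0, like.shape[a])  # begin is always 0, end comes from `like`
    return x[tuple(index)]

x, like = np.zeros((4, 10, 10)), np.zeros((4, 6, 7))
assert slice_like_reference(x, like, axes=[1, 2]).shape == (4, 6, 7)
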
def replace_sub_graph(self, graph: Graph, match: dict):
    node = match['flatten']
    name = node.soft_get('name', node.id)

    assert node.has_valid('axis'), 'Flatten {} has no mandatory `axis` attribute'.format(name)
    assert node.has_valid('end_axis'), 'Flatten {} has no mandatory `end_axis` attribute'.format(name)
    axis = node.axis
    end_axis = node.end_axis

    if end_axis == -1 and axis >= 0:
        begin_dims = Const(graph, {'value': int64_array([0] * axis)}).create_node()
        middle_dim = Const(graph, {'value': int64_array([-1])}).create_node()
        end_dims = Const(graph, {'value': int64_array([])}).create_node()
    else:
        rank = Rank(graph, {'name': name + '/input_rank'}).create_node()
        node.in_port(0).get_source().connect(rank.in_port(0))

        shape = Shape(graph, {'name': name + '/input_shape'}).create_node()
        node.in_port(0).get_source().connect(shape.in_port(0))

        begin_dims = get_shape_values_by_range_idxs(shape=shape, rank=rank, begin=0, end=axis)
        middle_dims = get_shape_values_by_range_idxs(shape=shape, rank=rank, begin=axis, end=end_axis,
                                                     include_end=True)
        end_dims = get_shape_values_by_range_idxs(shape=shape, rank=rank, begin=end_axis, end=-1,
                                                  include_begin=False, include_end=True)

        middle_dim = create_op_node_with_second_input(graph, ReduceProd, int64_array([0]), {'keep_dims': True})
        middle_dims.out_port(0).connect(middle_dim.in_port(0))

    dim = new_shape_node_from_shape_nodes([begin_dims, middle_dim, end_dims])

    original_name = node.soft_get('name')
    abandoned_name = original_name + '/ShouldBeDeleted'
    reshape_node = Reshape(graph, {}).create_node()
    # Keep the original name on the new Reshape node to avoid confusion after renaming
    rename_nodes([(node, abandoned_name), (reshape_node, original_name)])

    reshape_node.in_port(1).connect(dim.out_port(0))
    node.out_port(0).get_connection().set_source(reshape_node.out_port(0))
    node.in_port(0).get_connection().set_destination(reshape_node.in_port(0))
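
# A numpy reference (illustration only, non-negative `axis`) of the dims
# [axis, end_axis] being collapsed into a single dimension, as the ReduceProd
# over the middle shape slice does above:
import numpy as np

def flatten_range_reference(x: np.ndarray, axis: int, end_axis: int) -> np.ndarray:
    end_axis = end_axis % x.ndim  # canonicalize a possibly negative end_axis
    middle = int(np.prod(x.shape[axis:end_axis + 1]))  # ReduceProd over the middle dims
    return x.reshape(x.shape[:axis] + (middle,) + x.shape[end_axis + 1:])

assert flatten_range_reference(np.zeros((2, 3, 4, 5)), 1, 2).shape == (2, 12, 5)
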
def mxrepeat_decomposition(node: Node):
    graph = node.graph
    name = node.soft_get('name', node.id)

    rename_node(node, name + '/to_be_removed')

    # Unsqueeze
    input_rank = Rank(graph, {'name': name + '/Rank'}).create_node()
    node.in_port(0).get_source().connect(input_rank.in_port(0))

    axis = get_canonical_axis_index_node(input_rank, node.axis)
    unsqueeze_axis = create_op_node_with_second_input(
        graph, Add, int64_array([1]), {'name': name + '/Unsqueeze/Axis'}, input_node=axis)

    unsqueeze = Unsqueeze(graph, {'name': name + '/Unsqueeze'}).create_node()
    unsqueeze.in_port(1).connect(unsqueeze_axis.out_port(0))

    # Tile (1, 1, ..., repeats, ..., 1)
    # we generate the tile array according to the following table:
    #       parts:       |      first      |  repeats |    second   |
    #       i:           | 0, 1, ..., axis,| axis + 1,| ..., rank+1 |
    #       tile_array:  | 1, 1, ...,  1  ,| repeats ,| ...,  1     |
    one = Const(graph, {'name': name + '/Broadcast/One', 'value': int64_array([1])}).create_node()

    first_ones = Broadcast(graph, {'name': name + '/Broadcast/Ones_first_part'}).create_node()
    first_ones.in_port(0).connect(one.out_port(0))
    first_ones.in_port(1).connect(unsqueeze_axis.out_port(0))

    repeats = Const(graph, {'name': name + '/repeats', 'value': int64_array([node.repeats])}).create_node()

    second_ones = Broadcast(graph, {'name': name + '/Broadcast/Ones_second_part'}).create_node()
    second_part_broadcast_shape = Sub(graph, {'name': name + '/Broadcast/Shape/second_part'}).create_node()
    second_part_broadcast_shape.in_port(0).connect(input_rank.out_port(0))
    second_part_broadcast_shape.in_port(1).connect(unsqueeze_axis.out_port(0))
    second_ones.in_port(0).connect(one.out_port(0))
    second_ones.in_port(1).connect(second_part_broadcast_shape.out_port(0))

    tile_repeats = new_shape_node_from_shape_nodes([first_ones, repeats, second_ones])
    tile = Tile(graph, {'name': name + '/Tile'}).create_node()
    tile.in_port(1).connect(tile_repeats.out_port(0))

    # Reshape (input_shape[:axis], input_shape[axis] * repeats, input_shape[axis+1:])
    # we generate the reshape dim array according to the following table:
    #       parts:      |   first    |             rep              |   second  |
    #       i:          | 0, 1, ... ,|             axis,            | ..., rank |
    #       dim_array:  | inp_sh[i] ,| input_shape[axis] * repeats ,| inp_sh[i] |
    input_shape = Shape(graph, {'name': name + '/Shape'}).create_node()
    node.in_port(0).get_source().connect(input_shape.in_port(0))

    first_input_shape_part = get_shape_values_by_range_idxs(
        input_shape, input_rank, begin=0, end=node.axis, include_begin=True, include_end=False)

    original_axis_dim = create_op_with_const_inputs(
        graph, Gather, {2: int64_array(0)}, {'name': name + '/OriginalDim'}, input_node=input_shape)
    original_axis_dim.in_port(1).connect(axis.out_port(0))

    repeated_dimension = Mul(graph, {'name': name + '/RepeatedDim'}).create_node()
    repeated_dimension.in_port(0).connect(original_axis_dim.out_port(0))
    repeated_dimension.in_port(1).connect(repeats.out_port(0))

    second_input_shape_part = get_shape_values_by_range_idxs(
        input_shape, input_rank, begin=node.axis, end=-1, include_begin=False, include_end=True)

    output_shape = new_shape_node_from_shape_nodes(
        [first_input_shape_part, repeated_dimension, second_input_shape_part])

    reshape = Reshape(graph, {'name': name}).create_node()
    rename_node(reshape, name)
    reshape.in_port(1).connect(output_shape.out_port(0))

    # Final connections
    node.in_port(0).get_connection().set_destination(unsqueeze.in_port(0))
    tile.in_port(0).connect(unsqueeze.out_port(0))
    reshape.in_port(0).connect(tile.out_port(0))
    node.out_port(0).get_connection().set_source(reshape.out_port(0))
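
# A numpy sketch (illustration only) of the Unsqueeze -> Tile -> Reshape chain,
# which must reproduce MXNet's repeat (i.e. np.repeat) along `axis`:
import numpy as np

x, axis, repeats = np.arange(6).reshape(2, 3), 1, 2
y = np.expand_dims(x, axis + 1)                  # Unsqueeze at axis + 1
y = np.tile(y, (1, 1, repeats))                  # Tile (1, ..., repeats, ..., 1)
y = y.reshape(x.shape[0], x.shape[1] * repeats)  # Reshape merges axis and the tiled axis
assert np.array_equal(y, np.repeat(x, repeats, axis=axis))
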
def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
    group_norm_node = match['op']
    group_norm_num_input_dims = len(group_norm_node.in_port(0).data.get_shape())

    # node computing initial GroupNorm input shape
    initial_shape_op_node = Shape(graph, {'name': group_norm_node.name + '/Shape'}).create_node()
    initial_shape_op_node.in_port(0).connect(group_norm_node.in_port(0).get_source())

    initial_batch_dim_node = node_to_get_batch_value(initial_shape_op_node)
    initial_features_dim_node = node_to_get_features_dimension_value(initial_shape_op_node)
    initial_spatial_dims_node = node_to_get_spatial_dimensions_value(initial_shape_op_node)
    group_size_node = Const(graph, {'value': int64_array([group_norm_node.num_groups]),
                                    'name': group_norm_node.name + '/GroupSize'}).create_node()

    # calculate "features // group_size" value
    reciprocal_group_size_node = Const(graph, {'value': np.array([1.0 / group_norm_node.num_groups]),
                                               'name': group_norm_node.name + '/ReciprocalGroupSize'}).create_node()

    c_div_g_node = Mul(graph, {}).create_node()
    c_div_g_node.in_port(0).connect(initial_features_dim_node.out_port(0))
    c_div_g_node.in_port(1).connect(reciprocal_group_size_node.out_port(0))

    batch_mul_group_size_node = Mul(graph, {}).create_node()
    batch_mul_group_size_node.in_port(0).connect(initial_batch_dim_node.out_port(0))
    batch_mul_group_size_node.in_port(1).connect(group_size_node.out_port(0))

    # create new node which concatenates several dims to one
    new_shape_node = new_shape_node_from_shape_nodes([batch_mul_group_size_node, c_div_g_node,
                                                      initial_spatial_dims_node])

    reshape_for_mvn_node = Reshape(graph, {}).create_node()

    group_norm_node.in_port(0).get_connection().set_destination(reshape_for_mvn_node.in_port(0))
    reshape_for_mvn_node.in_port(1).connect(new_shape_node.out_port(0))

    # Reshape the gamma and beta constants to correct layout from [C] to [1,C], [1,C,1], [1,C,1,1] etc
    gamma_beta_shape = np.ones([group_norm_num_input_dims], dtype=np.int64)
    gamma_beta_shape[1] = -1

    gamma_value = group_norm_node.in_port(1).get_source().data.get_value()
    beta_value = group_norm_node.in_port(2).get_source().data.get_value()
    assert gamma_value is not None, 'The gamma should be constant'
    assert beta_value is not None, 'The beta should be constant'
    gamma_value = np.reshape(gamma_value, gamma_beta_shape)
    group_norm_node.in_port(1).get_source().data.set_value(gamma_value)
    beta_value = np.reshape(beta_value, gamma_beta_shape)
    group_norm_node.in_port(2).get_source().data.set_value(beta_value)

    # MVN
    mvn_node = MVN(graph, {'name': group_norm_node.name + '/MVN',
                           'across_channels': 1,
                           'normalize_variance': 1,
                           'eps': group_norm_node.eps}).create_node()
    mvn_node.in_port(0).connect(reshape_for_mvn_node.out_port(0))

    # reshape to the initial shape before multiplying with gamma and adding beta
    reshape_to_initial_shape_node = Reshape(graph, {}).create_node()
    reshape_to_initial_shape_node.in_port(0).connect(mvn_node.out_port(0))
    reshape_to_initial_shape_node.in_port(1).connect(initial_shape_op_node.out_port(0))

    mul_node = Mul(graph, {'name': mvn_node.name + '/Mul'}).create_node()
    mul_node.in_port(0).connect(reshape_to_initial_shape_node.out_port(0))
    group_norm_node.in_port(1).get_connection().set_destination(mul_node.in_port(1))

    add_node = Add(graph, {'name': mul_node.name + '/Add'}).create_node()
    add_node.in_port(0).connect(mul_node.out_port(0))
    group_norm_node.in_port(2).get_connection().set_destination(add_node.in_port(1))

    group_norm_node.out_port(0).get_connection().set_source(add_node.out_port(0))
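
# A numpy reference (illustration only, assuming NC... layout and eps added
# inside the square root) of the GroupNorm decomposition above: reshape to
# [N*G, C//G, ...], MVN over all non-batch axes, reshape back, scale and shift:
import numpy as np

def group_norm_reference(x, gamma, beta, num_groups, eps):
    n, c = x.shape[:2]
    y = x.reshape((n * num_groups, c // num_groups) + x.shape[2:])
    axes = tuple(range(1, y.ndim))               # MVN across all but the merged batch axis
    y = (y - y.mean(axes, keepdims=True)) / np.sqrt(y.var(axes, keepdims=True) + eps)
    y = y.reshape(x.shape)
    per_channel = (1, -1) + (1,) * (x.ndim - 2)  # gamma/beta broadcast as [1, C, 1, ...]
    return y * gamma.reshape(per_channel) + beta.reshape(per_channel)

out = group_norm_reference(np.random.rand(2, 4, 8, 8), np.ones(4), np.zeros(4),
                           num_groups=2, eps=1e-5)
assert out.shape == (2, 4, 8, 8)
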
def replace_pattern(graph: Graph, match: dict):
    node = match['interpolate']
    assert 1 in node.in_ports() and not node.in_port(1).disconnected() and \
           node.in_port(1).data.get_value() is not None, \
        'Interpolate node {} is corrupted: no 1-port input found'.format(node.soft_get('name', node.id))

    # common
    mode = node.mode
    assert mode in ['linear', 'nearest', 'cubic', 'area']
    in_shape = node.in_port(0).data.get_shape()
    assert in_shape is not None and len(in_shape) in [4, 5]
    out_shape = node.out_port(0).data.get_shape()
    assert out_shape is not None and len(out_shape) in [4, 5]
    in_height, in_width = in_shape[2], in_shape[3]
    out_height, out_width = out_shape[2], out_shape[3]

    factor = factor_update(
        None if not node.has_valid('factor') else node.factor,
        [float(out_height) / in_height, float(out_width) / in_width],
        [in_height, in_width], [out_height, out_width], node.soft_get('name'))

    update_attrs = {
        'width': out_width,
        'height': out_height,
        'factor': factor,
    }

    if (node.has_valid('shrink_factor') and node.has_valid('zoom_factor')) or factor is None:
        del update_attrs['factor']
        if node.has('factor'):
            del node['factor']

    if ((node.has_valid('shrink_factor') and node.shrink_factor != 1) or
        (node.has_valid('zoom_factor') and node.zoom_factor != 1) or 'factor' in update_attrs) \
            and ((not node.has_valid('width') or node.width == 0) and
                 (not node.has_valid('height') or node.height == 0)):
        update_attrs['width'] = 0
        update_attrs['height'] = 0

    # specific
    if mode in ['nearest', 'cubic', 'area'] or node.has_and_set('convert_to_resample'):
        assert not node.align_corners
        assert node.pads_begin == 0 and node.pads_end == 0
        update_attrs['resample_type'] = InterpolateToInterpOrResample.type_map[mode]
        ResampleOp.update_node_stat(node, update_attrs)

        if not graph.graph['cmd_params'].keep_shape_ops or graph.graph['fw'] != 'tf':
            node.in_port(1).disconnect()
        else:
            # we avoid making Resample non-reshape-able for the TF version
            shape = Shape(graph, {}).create_node()
            node.in_port(0).get_source().connect(shape.in_port(0))

            batch = node_to_get_batch_value(shape)
            features = node_to_get_features_dimension_value(shape)
            full_shape = new_shape_node_from_shape_nodes([batch, features, node.in_port(1).get_source().node])
            node.in_port(1).get_connection().set_source(full_shape.out_port(0))
            full_shape['override_output_shape'] = True
    elif mode == 'linear':
        assert len(in_shape) == 4, 'Interp does not support 5D input'
        update_attrs.update({
            'pad_beg': node.pads_begin,
            'pad_end': node.pads_end,
            'align_corners': node.align_corners,
        })
        InterpOp.update_node_stat(node, update_attrs)
        node.in_port(1).disconnect()
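
# A small sketch (illustration only) of the `full_shape` assembled in the TF branch
# above: the 1-port of Interpolate carries only the target spatial sizes, so N and C
# are taken from the runtime input shape to keep Resample reshape-able:
import numpy as np

input_shape = np.array([1, 3, 32, 32], dtype=np.int64)  # NCHW input shape at runtime
target_spatial = np.array([64, 64], dtype=np.int64)     # value on Interpolate's 1-port
full_shape = np.concatenate((input_shape[:1],           # batch (node_to_get_batch_value)
                             input_shape[1:2],          # features (node_to_get_features_dimension_value)
                             target_spatial))
assert np.array_equal(full_shape, [1, 3, 64, 64])
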
def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
    group_norm_node = match['op']
    group_norm_num_input_dims = len(group_norm_node.in_port(0).data.get_shape())

    # node computing initial GroupNorm input shape
    initial_shape_op_node = Shape(graph, {'name': group_norm_node.name + '/Shape'}).create_node()
    initial_shape_op_node.in_port(0).connect(group_norm_node.in_port(0).get_source())

    initial_shape_op_node_float = Cast(
        graph, {'name': initial_shape_op_node.name + '/to_float',
                'dst_type': data_type_str_to_np(graph.graph['cmd_params'].data_type)}).create_node()
    initial_shape_op_node.out_port(0).connect(initial_shape_op_node_float.in_port(0))

    initial_batch_dim_node = node_to_get_batch_value(initial_shape_op_node_float)
    initial_features_dim_node = node_to_get_features_dimension_value(initial_shape_op_node_float)
    initial_spatial_dims_node_int = node_to_get_spatial_dimensions_value(initial_shape_op_node)
    initial_spatial_dims_node = Cast(
        graph, {'name': initial_spatial_dims_node_int.name + '/to_float',
                'dst_type': data_type_str_to_np(graph.graph['cmd_params'].data_type)}).create_node()
    initial_spatial_dims_node_int.out_port(0).connect(initial_spatial_dims_node.in_port(0))

    group_size_node = Const(graph, {'value': int64_array([group_norm_node.num_groups]),
                                    'name': group_norm_node.name + '/GroupSize'}).create_node()

    # calculate "features // group_size" value
    reciprocal_group_size_node = Const(graph, {'value': np.array([1.0 / group_norm_node.num_groups]),
                                               'name': group_norm_node.name + '/ReciprocalGroupSize'}).create_node()

    c_div_g_node = Mul(graph, {}).create_node()
    c_div_g_node.in_port(0).connect(initial_features_dim_node.out_port(0))
    c_div_g_node.in_port(1).connect(reciprocal_group_size_node.out_port(0))

    batch_mul_group_size_node = Mul(graph, {}).create_node()
    batch_mul_group_size_node.in_port(0).connect(initial_batch_dim_node.out_port(0))
    batch_mul_group_size_node.in_port(1).connect(group_size_node.out_port(0))

    # create new node which concatenates several dims to one
    new_shape_node_float = new_shape_node_from_shape_nodes([batch_mul_group_size_node, c_div_g_node,
                                                            initial_spatial_dims_node])
    new_shape_node = Cast(graph, {'name': new_shape_node_float.name + '/to_int64',
                                  'dst_type': np.int64}).create_node()
    new_shape_node_float.out_port(0).connect(new_shape_node.in_port(0))

    reshape_for_mvn_node = Reshape(graph, {}).create_node()

    group_norm_node.in_port(0).get_connection().set_destination(reshape_for_mvn_node.in_port(0))
    reshape_for_mvn_node.in_port(1).connect(new_shape_node.out_port(0))

    # Reshape the gamma and beta constants to correct layout from [C] to [1,C], [1,C,1], [1,C,1,1] etc
    gamma_beta_shape = np.ones([group_norm_num_input_dims], dtype=np.int64)
    gamma_beta_shape[1] = -1

    gamma_value = group_norm_node.in_port(1).get_source().data.get_value()
    beta_value = group_norm_node.in_port(2).get_source().data.get_value()
    assert gamma_value is not None, 'The gamma should be constant'
    assert beta_value is not None, 'The beta should be constant'
    gamma_value = np.reshape(gamma_value, gamma_beta_shape)
    group_norm_node.in_port(1).get_source().data.set_value(gamma_value)
    beta_value = np.reshape(beta_value, gamma_beta_shape)
    group_norm_node.in_port(2).get_source().data.set_value(beta_value)

    # MVN
    mvn_node = MVN(graph, {'name': group_norm_node.name + '/MVN',
                           'normalize_variance': 1,
                           'eps': group_norm_node.eps,
                           'eps_mode': 'inside_sqrt'}).create_node()
    mvn_node.in_port(0).connect(reshape_for_mvn_node.out_port(0))

    # MVN axes
    _, rank = get_shape_and_rank_nodes_by_port(mvn_node.in_port(0).get_connection().get_source(),
                                               return_as_a_scalar=True)
    rng = create_op_with_const_inputs(graph, Range, {0: int64_array(1), 2: int64_array(1)},
                                      {'name': group_norm_node.name + '/Range', 'output_type': np.int64})
    mvn_node.in_port(1).connect(rng.out_port(0))
    rng.in_port(1).connect(rank.out_port(0))

    # reshape to the initial shape before multiplying with gamma and adding beta
    reshape_to_initial_shape_node = Reshape(graph, {}).create_node()
    reshape_to_initial_shape_node.in_port(0).connect(mvn_node.out_port(0))
    reshape_to_initial_shape_node.in_port(1).connect(initial_shape_op_node.out_port(0))

    mul_node = Mul(graph, {'name': mvn_node.name + '/Mul'}).create_node()
    mul_node.in_port(0).connect(reshape_to_initial_shape_node.out_port(0))
    group_norm_node.in_port(1).get_connection().set_destination(mul_node.in_port(1))

    add_node = Add(graph, {'name': mul_node.name + '/Add'}).create_node()
    add_node.in_port(0).connect(mul_node.out_port(0))
    group_norm_node.in_port(2).get_connection().set_destination(add_node.in_port(1))

    group_norm_node.out_port(0).get_connection().set_source(add_node.out_port(0))
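
# A short sketch (illustration only) of the MVN axes produced by the Range sub-graph
# above: Range(1, rank, 1) yields every axis of the reshaped [N*G, C//G, ...] tensor
# except the merged batch axis; the float Casts exist only because the shape
# arithmetic is done in the model's floating-point type and cast back to int64
# before feeding Reshape.
import numpy as np

rank = 4  # rank of the reshaped tensor, produced at runtime by get_shape_and_rank_nodes_by_port
mvn_axes = np.arange(1, rank, dtype=np.int64)  # Range(start=1, stop=rank, step=1)
assert np.array_equal(mvn_axes, [1, 2, 3])
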