Code example #1
    def replace_sub_graph(self, graph: Graph, match: dict):
        node = match['flatten']
        name = node.soft_get('name', node.id)

        assert node.has_valid('axis'), \
            'Flatten {} should have the `axis` attribute extracted, but it does not'.format(name)
        axis = node.axis

        if axis == 0:
            dim = Const(graph, {'value': int64_array([1, -1])}).create_node()
        elif axis == 1:
            dim = Const(graph, {'value': int64_array([0, -1])}).create_node()
        else:
            shape = Shape(graph, {'name': name + '/input_shape'}).create_node()

            idxs = list(range(axis)) if axis > 0 else list(range(axis, 0))

            axis_shape_portion = node_to_get_shape_value_of_indices(shape, idxs)
            first_dims = create_op_node_with_second_input(graph, ReduceProd, int64_array([0]),
                                                          {'keep_dims': True})
            second_dims = Const(graph, {'value': int64_array([-1])}).create_node()

            node.in_port(0).get_source().connect(shape.in_port(0))
            axis_shape_portion.out_port(0).connect(first_dims.in_port(0))

            order_of_dims = [first_dims, second_dims] if axis > 0 else [second_dims, first_dims]

            dim = new_shape_node_from_shape_nodes(order_of_dims)

        reshape_node = Reshape(graph, {'name': node.id + '/Reshape'}).create_node()
        reshape_node.in_port(1).connect(dim.out_port(0))

        node.out_port(0).get_connection().set_source(reshape_node.out_port(0))
        node.in_port(0).get_connection().set_destination(reshape_node.in_port(0))
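
The graph the pass emits (Shape, Gather, ReduceProd, Concat) encodes a small piece of arithmetic. A minimal NumPy sketch of the equivalent computation, with an illustrative helper name and test shapes that are not part of the pass:

    import numpy as np

    def flatten_as_reshape(x, axis):
        # ONNX Flatten semantics: collapse dims around `axis` into a 2D result
        if axis == 0:
            return x.reshape(1, -1)
        if axis > 0:
            first = int(np.prod(x.shape[:axis]))  # ReduceProd over shape[:axis]
            return x.reshape(first, -1)           # order_of_dims for axis > 0
        last = int(np.prod(x.shape[axis:]))       # ReduceProd over shape[axis:]
        return x.reshape(-1, last)                # order_of_dims for axis < 0

    x = np.zeros((2, 3, 4, 5))
    assert flatten_as_reshape(x, 2).shape == (6, 20)
    assert flatten_as_reshape(x, -1).shape == (24, 5)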
Code example #2
File: MatMulNormalizer.py  Project: SDxKeeper/dldt
    def replace_pattern(self, graph: Graph, match: dict):
        matmul = match['matmul']
        reshape = match['reshape']
        other_input_port_idx = 0 if match['matmul'].in_port(0).get_source().node.id == match['other_input'].id else 1
        shape_source = match['matmul'].in_port(other_input_port_idx).get_source()
        initial_reshape_pattern = reshape.in_port(1).data.get_value()
        if len(initial_reshape_pattern) != 2:
            return

        reshape_is_A_input = matmul.in_port(0).get_source().node.id == reshape.id
        if reshape_is_A_input:
            idx = -1 if matmul.transpose_b else -2
        else:
            idx = -2 if matmul.transpose_a else -1
        idx = get_canonical_axis_index(initial_reshape_pattern, idx)

        shape_name = shape_source.node.soft_get('name', shape_source.node.id)
        shape = Shape(graph, {'name': shape_name + '/Shape'}).create_node()
        shape.in_port(0).connect(shape_source)
        C = node_to_get_shape_value_of_indices(shape, [idx])
        N = Const(graph, {'name': shape_name + '/MinusOne', 'value': int64_array([-1])}).create_node()

        # the early return above already guarantees a 2-element reshape pattern
        if reshape_is_A_input:
            reshape_pattern = [C, N] if matmul.transpose_a else [N, C]
        else:
            reshape_pattern = [N, C] if matmul.transpose_b else [C, N]
        new_reshape_pattern = new_shape_node_from_shape_nodes(reshape_pattern)
        reshape.in_port(1).get_connection().set_source(new_reshape_pattern.out_port(0))
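
What the rewrite buys is easier to see on plain arrays: the fixed 2D pattern is replaced by one that pins only the contraction dimension C (gathered from the other MatMul input's shape) and infers the rest with -1, so the subgraph survives changes to the leading dimensions. A hedged NumPy sketch with illustrative shapes:

    import numpy as np

    a = np.random.rand(2, 3, 8)         # tensor fed through the Reshape
    b = np.random.rand(8, 5)            # plays the role of 'other_input'

    out_static = a.reshape(6, 8) @ b    # original fixed reshape pattern
    c = b.shape[-2]                     # C = other input's contraction dim
    out_dynamic = a.reshape(-1, c) @ b  # normalized pattern [N, C], N = -1
    assert np.allclose(out_static, out_dynamic)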
Code example #3
    def decompose_shuffle_channel(node: Node):
        graph = node.graph
        name = node.soft_get('name', node.id)

        rename_node(node, name + '/to_be_removed')

        shape = Shape(graph, dict(name=name + '/InputShape')).create_node()
        shape.in_port(0).connect(node.in_port(0).get_source())

        # Reshape [input_batch, group, input_channels/group, -1]
        batch = node_to_get_batch_value(shape)
        group = Const(
            graph, dict(name=name + '/Rows',
                        value=int64_array([node.group]))).create_node()
        const = Const(graph, dict(name=name + '/Const',
                                  value=int64_array([-1]))).create_node()

        input_channels = node_to_get_features_dimension_value(shape)
        output_channels = create_op_node_with_second_input(
            graph,
            Div,
            np.int64(node.group), {'name': name + '/Cols'},
            input_node=input_channels)
        i_output_channels = Cast(graph, {
            'name': output_channels.name + '/Convert',
            'dst_type': np.int64
        }).create_node()
        output_channels.out_port(0).connect(i_output_channels.in_port(0))

        reshape_split_dim = new_shape_node_from_shape_nodes(
            [batch, group, i_output_channels, const])
        reshape_split_node = Reshape(
            graph, dict(name=name + '/Reshape_split_')).create_node()
        reshape_split_dim.out_port(0).connect(reshape_split_node.in_port(1))

        # Transpose(0, 2, 1, 3)
        transpose_node = create_op_node_with_second_input(
            graph,
            Transpose,
            int64_array([0, 2, 1, 3]), {'name': name + '/Transpose_'},
            input_node=reshape_split_node)

        # Reshape back to input shape
        reshape_concat = Reshape(graph, dict(name=name)).create_node()
        rename_node(reshape_concat, name)

        shape.out_port(0).connect(reshape_concat.in_port(1))
        transpose_node.out_port(0).connect(reshape_concat.in_port(0))

        # Final connections
        node.in_port(0).get_connection().set_destination(
            reshape_split_node.in_port(0))
        node.out_port(0).get_connection().set_source(
            reshape_concat.out_port(0))
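
The three emitted nodes implement the classic channel-shuffle trick. A NumPy sketch of the same computation (helper name and shapes are illustrative):

    import numpy as np

    def shuffle_channels(x, group):
        n, c = x.shape[:2]
        y = x.reshape(n, group, c // group, -1)  # Reshape_split_
        y = y.transpose(0, 2, 1, 3)              # Transpose(0, 2, 1, 3)
        return y.reshape(x.shape)                # Reshape back to input shape

    x = np.arange(2 * 6 * 4 * 4).reshape(2, 6, 4, 4)
    assert shuffle_channels(x, 2).shape == x.shape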
Code example #4
    def replace_pattern(graph: Graph, match: dict):
        """
        Works around a Tile configuration not supported by the Inference Engine (Tile supports only 2D and 4D tensors):
        finds Tiles with 3D shapes and wraps them in Reshapes.

        Example: Tile (axis=1, tiles=16):
            in_shape: [1,1,101]
            out_shape: [1,16,101]

        Old behaviour:
            Tile -> [1,16,101]
        New behaviour:
            Reshape [1,1,101,1] -> Tile -> [1,16,101,1] -> Reshape [1,16,101]
        """
        node = match['tile']
        name = node.soft_get('name', node.id)

        out_shape = node.out_port(0).data.get_shape()
        assert out_shape is not None, 'Output shape is undefined for {} in back phase'.format(name)
        if out_shape.size != 3:
            return

        inp_shape = node.in_port(0).data.get_shape()
        assert inp_shape is not None, 'Input shape is undefined for {} in back phase'.format(name)

        unsqueeze_dim = Const(graph, {'name': name + '/3D_Tile_Unsqueeze_dim', 'value': int64_array([3])}).create_node()
        unsqueeze = Unsqueeze(graph, {'name': name + '/3D_Tile_Unsqueeze', 'override_output_shape': True}).create_node()
        unsqueeze_dim.out_port(0).connect(unsqueeze.in_port(1))

        const = Const(graph, {'name': name + '/additional_axis', 'value': int64_array([1])}).create_node()
        new_tiles = new_shape_node_from_shape_nodes([node.in_port(1).get_source().node, const])

        node.in_port(1).get_connection().set_source(new_tiles.out_port(0))

        squeeze_dim = Const(graph, {'name': name + '/3D_Tile_Squeeze_dim', 'value': int64_array([3])}).create_node()
        squeeze = Squeeze(graph, {'name': name + '/3D_Tile_Squeeze', 'override_output_shape': True}).create_node()
        squeeze_dim.out_port(0).connect(squeeze.in_port(1))

        source = node.in_port(0).get_source()
        node.in_port(0).get_connection().set_source(unsqueeze.out_port(0))
        unsqueeze.in_port(0).connect(source)

        node.out_port(0).get_connection().set_source(squeeze.out_port(0))
        node.out_port(0).connect(squeeze.in_port(0))

        node['override_output_shape'] = True
        new_tiles['override_output_shape'] = True
        node['need_shape_inference'] = True
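
The equivalence the docstring claims is easy to verify numerically: tiling a 3D tensor directly gives the same result as unsqueezing a trailing axis, tiling in 4D with an extra tile count of 1, and squeezing the axis back. A NumPy check with illustrative data:

    import numpy as np

    x = np.random.rand(1, 1, 101)
    tiles = np.array([1, 16, 1])              # value on Tile's port 1

    direct = np.tile(x, tiles)                # old behaviour: 3D Tile
    unsqueezed = x[..., np.newaxis]           # Unsqueeze at axis 3
    new_tiles = np.concatenate([tiles, [1]])  # new_shape_node_from_shape_nodes
    squeezed = np.tile(unsqueezed, new_tiles).squeeze(3)  # 4D Tile + Squeeze

    assert np.array_equal(direct, squeezed)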
Code example #5
    def replace_pattern(graph: Graph, match: Dict[str, Node]):
        node = match['op']
        name = node.soft_get('name', node.id)
        input_shape = node.in_port(0).data.get_shape()
        second_input_shape = node.in_port(1).data.get_shape()

        begin_mask = np.zeros(len(input_shape), dtype=np.int64)
        end_mask = np.zeros(len(input_shape), dtype=np.int64)

        for i in node.axes:
            end_mask[i] = np.int64(1)

        new_axis_mask = np.zeros(len(input_shape), dtype=np.int64)
        shrink_axis_mask = np.zeros(len(input_shape), dtype=np.int64)
        ellipsis_mask = np.zeros(len(input_shape), dtype=np.int64)

        ss = create_op_with_const_inputs(graph, StridedSlice,
                                         port_value_dict={1: np.zeros(len(input_shape), dtype=np.int64)},
                                         op_attrs={'name': 'StridedSlice', 'begin_mask': begin_mask,
                                                   'end_mask': end_mask, 'new_axis_mask': new_axis_mask,
                                                   'shrink_axis_mask': shrink_axis_mask,
                                                   'ellipsis_mask': ellipsis_mask})
        if input_shape.size == second_input_shape.size:
            end = Shape(graph, dict(name=name + '/End')).create_node()
            end.in_port(0).connect(node.in_port(1).get_source())
            ss.in_port(2).connect(end.out_port(0))
        else:
            shape_like, rank_like = get_shape_and_rank_nodes_by_port(node.in_port(1).get_source())
            end_first_part = get_shape_values_by_range_idxs(shape_like, rank_like, 0, node.axes[-1], include_end=True)
            if input_shape.size - 1 == node.axes[-1]:
                ss.in_port(2).connect(end_first_part.out_port(0))
            else:
                shape, rank = get_shape_and_rank_nodes_by_port(node.in_port(0).get_source())
                end_second_part = get_shape_values_by_range_idxs(shape, rank, node.axes[-1], -1, include_begin=False,
                                                                 include_end=True)
                end = new_shape_node_from_shape_nodes([end_first_part, end_second_part])
                ss.in_port(2).connect(end.out_port(0))

        node.in_port(0).get_connection().set_destination(ss.in_port(0))
        node.in_port(1).disconnect()
        node.out_port(0).get_connection().set_source(ss.out_port(0))

        rename_nodes([(node, name + '/ShouldBeDeleted'), (ss, name)])
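
For the equal-rank branch, the emitted StridedSlice crops the first input down to the second input's size along node.axes and leaves every other dimension untouched. A NumPy sketch of that slicing (shapes and axes are illustrative; end_mask here follows the pass's internal convention that 1 means "apply the end value"):

    import numpy as np

    x = np.arange(1 * 3 * 7 * 7).reshape(1, 3, 7, 7)
    like = np.zeros((1, 3, 5, 5))   # same rank as x
    axes = [2, 3]                   # stands in for node.axes

    # begin = zeros; `end` comes from Shape(like), enabled only on `axes`
    end = [like.shape[i] if i in axes else x.shape[i] for i in range(x.ndim)]
    out = x[tuple(slice(0, e) for e in end)]
    assert out.shape == like.shape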
Code example #6
    def replace_sub_graph(self, graph: Graph, match: dict):
        node = match['flatten']
        name = node.soft_get('name', node.id)

        assert node.has_valid('axis'), \
            'Flatten {} has no mandatory `axis` attribute'.format(name)
        assert node.has_valid('end_axis'), \
            'Flatten {} has no mandatory `end_axis` attribute'.format(name)

        axis = node.axis
        end_axis = node.end_axis

        if end_axis == -1 and axis >= 0:
            begin_dims = Const(graph, {
                'value': int64_array([0] * axis)
            }).create_node()
            middle_dim = Const(graph, {
                'value': int64_array([-1])
            }).create_node()
            end_dims = Const(graph, {'value': int64_array([])}).create_node()
        else:
            rank = Rank(graph, {'name': name + '/input_rank'}).create_node()
            node.in_port(0).get_source().connect(rank.in_port(0))

            shape = Shape(graph, {'name': name + '/input_shape'}).create_node()
            node.in_port(0).get_source().connect(shape.in_port(0))

            begin_dims = get_shape_values_by_range_idxs(shape=shape,
                                                        rank=rank,
                                                        begin=0,
                                                        end=axis)
            middle_dims = get_shape_values_by_range_idxs(shape=shape,
                                                         rank=rank,
                                                         begin=axis,
                                                         end=end_axis,
                                                         include_end=True)
            end_dims = get_shape_values_by_range_idxs(shape=shape,
                                                      rank=rank,
                                                      begin=end_axis,
                                                      end=-1,
                                                      include_begin=False,
                                                      include_end=True)

            middle_dim = create_op_node_with_second_input(
                graph, ReduceProd, int64_array([0]), {'keep_dims': True})
            middle_dims.out_port(0).connect(middle_dim.in_port(0))

        dim = new_shape_node_from_shape_nodes(
            [begin_dims, middle_dim, end_dims])

        original_name = node.soft_get('name')
        abandoned_name = original_name + '/ShouldBeDeleted'
        reshape_node = Reshape(graph, {}).create_node()
        # Keep a node with the same name to avoid confusion after renaming
        rename_nodes([(node, abandoned_name), (reshape_node, original_name)])
        reshape_node.in_port(1).connect(dim.out_port(0))

        node.out_port(0).get_connection().set_source(reshape_node.out_port(0))
        node.in_port(0).get_connection().set_destination(
            reshape_node.in_port(0))
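
The dim vector assembled here is concat(shape[:axis], prod(shape[axis:end_axis+1]), shape[end_axis+1:]). A NumPy sketch of the resulting reshape (helper name and shapes are illustrative):

    import numpy as np

    def flatten_range(x, axis, end_axis):
        shape = np.array(x.shape)
        end_axis = end_axis % x.ndim  # canonicalize, e.g. -1 -> ndim - 1
        dim = np.concatenate([shape[:axis],                       # begin_dims
                              [shape[axis:end_axis + 1].prod()],  # ReduceProd
                              shape[end_axis + 1:]])              # end_dims
        return x.reshape(dim)

    x = np.zeros((2, 3, 4, 5))
    assert flatten_range(x, 1, 2).shape == (2, 12, 5)
    assert flatten_range(x, 1, -1).shape == (2, 60)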
Code example #7
    def mxrepeat_decomposition(node: Node):
        graph = node.graph
        name = node.soft_get('name', node.id)

        rename_node(node, name + '/to_be_removed')

        # Unsqueeze
        input_rank = Rank(graph, {'name': name + '/Rank'}).create_node()
        node.in_port(0).get_source().connect(input_rank.in_port(0))

        axis = get_canonical_axis_index_node(input_rank, node.axis)
        unsqueeze_axis = create_op_node_with_second_input(
            graph,
            Add,
            int64_array([1]), {'name': name + '/Unsqueeze/Axis'},
            input_node=axis)

        unsqueeze = Unsqueeze(graph, {
            'name': name + '/Unsqueeze'
        }).create_node()
        unsqueeze.in_port(1).connect(unsqueeze_axis.out_port(0))

        # Tile (1, 1, ..., repeats, ..., 1)
        # we generate tile array according to the following table:

        # parts:       |      first      |  repeats |  second     |
        # i:           | 0, 1, ..., axis,| axis + 1,| ..., rank+1 |
        # tile_array:  | 1, 1, ...,  1  ,| repeats ,| ...,   1    |

        one = Const(graph, {
            'name': name + '/Broadcast/One',
            'value': int64_array([1])
        }).create_node()
        first_ones = Broadcast(graph, {
            'name': name + '/Broadcast/Ones_first_part'
        }).create_node()
        first_ones.in_port(0).connect(one.out_port(0))
        first_ones.in_port(1).connect(unsqueeze_axis.out_port(0))

        repeats = Const(graph, {
            'name': name + '/repeats',
            'value': int64_array([node.repeats])
        }).create_node()

        second_ones = Broadcast(graph, {
            'name': name + '/Broadcast/Ones_second_part'
        }).create_node()
        second_part_broadcast_shape = Sub(
            graph, {
                'name': name + '/Broadcast/Shape/second_part'
            }).create_node()
        second_part_broadcast_shape.in_port(0).connect(input_rank.out_port(0))
        second_part_broadcast_shape.in_port(1).connect(
            unsqueeze_axis.out_port(0))
        second_ones.in_port(0).connect(one.out_port(0))
        second_ones.in_port(1).connect(second_part_broadcast_shape.out_port(0))

        tile_repeats = new_shape_node_from_shape_nodes(
            [first_ones, repeats, second_ones])
        tile = Tile(graph, {'name': name + '/Tile'}).create_node()
        tile.in_port(1).connect(tile_repeats.out_port(0))

        # Reshape (input_shape[:axis], input_shape[axis] * repeats, input_shape[axis+1:])
        # we generate reshape dim array according to the following table:

        # parts:       |    first   |                rep           |  second   |
        # i:           | 0, 1, ... ,|               axis,          | ..., rank |
        # dim_array:   | inp_sh[i] ,| input_shape[axis] * repeats ,| inp_sh[i] |

        input_shape = Shape(graph, {'name': name + '/Shape'}).create_node()
        node.in_port(0).get_source().connect(input_shape.in_port(0))

        first_input_shape_part = get_shape_values_by_range_idxs(
            input_shape,
            input_rank,
            begin=0,
            end=node.axis,
            include_begin=True,
            include_end=False)

        original_axis_dim = create_op_with_const_inputs(
            graph,
            Gather, {2: int64_array(0)}, {'name': name + '/OriginalDim'},
            input_node=input_shape)
        original_axis_dim.in_port(1).connect(axis.out_port(0))

        repeated_dimension = Mul(graph, {
            'name': name + '/RepeatedDim'
        }).create_node()
        repeated_dimension.in_port(0).connect(original_axis_dim.out_port(0))
        repeated_dimension.in_port(1).connect(repeats.out_port(0))

        second_input_shape_part = get_shape_values_by_range_idxs(
            input_shape,
            input_rank,
            begin=node.axis,
            end=-1,
            include_begin=False,
            include_end=True)

        output_shape = new_shape_node_from_shape_nodes([
            first_input_shape_part, repeated_dimension, second_input_shape_part
        ])

        reshape = Reshape(graph, {'name': name}).create_node()
        rename_node(reshape, name)
        reshape.in_port(1).connect(output_shape.out_port(0))

        # Final connections
        node.in_port(0).get_connection().set_destination(unsqueeze.in_port(0))
        tile.in_port(0).connect(unsqueeze.out_port(0))
        reshape.in_port(0).connect(tile.out_port(0))
        node.out_port(0).get_connection().set_source(reshape.out_port(0))
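
The Unsqueeze/Tile/Reshape chain reproduces MXNet's Repeat, i.e. np.repeat along one axis. A NumPy sketch, assuming a non-negative axis (the pass canonicalizes negative ones via get_canonical_axis_index_node):

    import numpy as np

    def mx_repeat(x, axis, repeats):
        expanded = np.expand_dims(x, axis + 1)      # Unsqueeze at axis + 1
        tile = np.ones(x.ndim + 1, dtype=np.int64)  # [1, 1, ..., 1]
        tile[axis + 1] = repeats                    # repeats slot of tile_array
        tiled = np.tile(expanded, tile)
        new_shape = list(x.shape)
        new_shape[axis] *= repeats                  # input_shape[axis] * repeats
        return tiled.reshape(new_shape)

    x = np.arange(6).reshape(2, 3)
    assert np.array_equal(mx_repeat(x, 1, 2), np.repeat(x, 2, axis=1))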
Code example #8
    def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
        group_norm_node = match['op']
        group_norm_num_input_dims = len(
            group_norm_node.in_port(0).data.get_shape())

        # node computing initial GroupNorm input shape
        initial_shape_op_node = Shape(graph, {
            'name': group_norm_node.name + '/Shape'
        }).create_node()
        initial_shape_op_node.in_port(0).connect(
            group_norm_node.in_port(0).get_source())

        initial_batch_dim_node = node_to_get_batch_value(initial_shape_op_node)
        initial_features_dim_node = node_to_get_features_dimension_value(
            initial_shape_op_node)
        initial_spatial_dims_node = node_to_get_spatial_dimensions_value(
            initial_shape_op_node)
        group_size_node = Const(
            graph, {
                'value': int64_array([group_norm_node.num_groups]),
                'name': group_norm_node.name + '/GroupSize'
            }).create_node()

        # calculate "features // group_size" value
        reciprocal_group_size_node = Const(
            graph, {
                'value': np.array([1.0 / group_norm_node.num_groups]),
                'name': group_norm_node.name + '/ReciprocalGroupSize'
            }).create_node()

        c_div_g_node = Mul(graph, {}).create_node()
        c_div_g_node.in_port(0).connect(initial_features_dim_node.out_port(0))
        c_div_g_node.in_port(1).connect(reciprocal_group_size_node.out_port(0))

        batch_mul_group_size_node = Mul(graph, {}).create_node()
        batch_mul_group_size_node.in_port(0).connect(
            initial_batch_dim_node.out_port(0))
        batch_mul_group_size_node.in_port(1).connect(
            group_size_node.out_port(0))

        # create new node which concatenates several dims to one
        new_shape_node = new_shape_node_from_shape_nodes([
            batch_mul_group_size_node, c_div_g_node, initial_spatial_dims_node
        ])

        reshape_for_mvn_node = Reshape(graph, {}).create_node()

        group_norm_node.in_port(0).get_connection().set_destination(
            reshape_for_mvn_node.in_port(0))
        reshape_for_mvn_node.in_port(1).connect(new_shape_node.out_port(0))

        # Reshape the gamma and beta constants to correct layout from [C] to [1,C], [1,C,1], [1,C,1,1] etc
        gamma_beta_shape = np.ones([group_norm_num_input_dims], dtype=np.int64)
        gamma_beta_shape[1] = -1

        gamma_value = group_norm_node.in_port(1).get_source().data.get_value()
        beta_value = group_norm_node.in_port(2).get_source().data.get_value()
        assert gamma_value is not None, 'The gamma should be constant'
        assert beta_value is not None, 'The beta should be constant'
        gamma_value = np.reshape(gamma_value, gamma_beta_shape)
        group_norm_node.in_port(1).get_source().data.set_value(gamma_value)
        beta_value = np.reshape(beta_value, gamma_beta_shape)
        group_norm_node.in_port(2).get_source().data.set_value(beta_value)

        # MVN
        mvn_node = MVN(
            graph, {
                'name': group_norm_node.name + '/MVN',
                'across_channels': 1,
                'normalize_variance': 1,
                'eps': group_norm_node.eps
            }).create_node()
        mvn_node.in_port(0).connect(reshape_for_mvn_node.out_port(0))

        # reshape to the initial shape before multiplying with gamma and adding beta
        reshape_to_initial_shape_node = Reshape(graph, {}).create_node()
        reshape_to_initial_shape_node.in_port(0).connect(mvn_node.out_port(0))
        reshape_to_initial_shape_node.in_port(1).connect(
            initial_shape_op_node.out_port(0))

        mul_node = Mul(graph, {'name': mvn_node.name + '/Mul'}).create_node()
        mul_node.in_port(0).connect(reshape_to_initial_shape_node.out_port(0))
        group_norm_node.in_port(1).get_connection().set_destination(
            mul_node.in_port(1))

        add_node = Add(graph, {'name': mul_node.name + '/Add'}).create_node()
        add_node.in_port(0).connect(mul_node.out_port(0))
        group_norm_node.in_port(2).get_connection().set_destination(
            add_node.in_port(1))

        group_norm_node.out_port(0).get_connection().set_source(
            add_node.out_port(0))
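
Numerically, the decomposition is ordinary GroupNorm: fold the groups into the batch dimension so MVN normalizes each group on its own, then restore the shape and apply the per-channel affine. A NumPy sketch with illustrative sizes (it assumes eps is added inside the square root, which the MVN implementation may or may not do):

    import numpy as np

    n, c, h, w, g = 2, 8, 4, 4, 4
    x = np.random.rand(n, c, h, w).astype(np.float32)
    eps = 1e-5

    y = x.reshape(n * g, c // g, h, w)          # [N*G, C/G, H, W]
    mean = y.mean(axis=(1, 2, 3), keepdims=True)
    var = y.var(axis=(1, 2, 3), keepdims=True)  # MVN across non-batch dims
    y = (y - mean) / np.sqrt(var + eps)

    gamma = np.ones((1, c, 1, 1), np.float32)   # reshaped from [C] as above
    beta = np.zeros((1, c, 1, 1), np.float32)
    out = y.reshape(n, c, h, w) * gamma + beta  # restore shape, then affine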
Code example #9
    def replace_pattern(graph: Graph, match: dict):
        node = match['interpolate']

        assert 1 in node.in_ports() and not node.in_port(1).disconnected() and \
               node.in_port(1).data.get_value() is not None, \
            'Interpolate node {} is corrupted: no 1-port input found'.format(node.soft_get('name', node.id))

        # common
        mode = node.mode
        assert mode in ['linear', 'nearest', 'cubic', 'area']
        in_shape = node.in_port(0).data.get_shape()
        assert in_shape is not None and len(in_shape) in [4, 5]
        out_shape = node.out_port(0).data.get_shape()
        assert out_shape is not None and len(out_shape) in [4, 5]
        in_height, in_width = in_shape[2], in_shape[3]
        out_height, out_width = out_shape[2], out_shape[3]
        factor = factor_update(
            node.factor if node.has_valid('factor') else None,
            [float(out_height) / in_height, float(out_width) / in_width],
            [in_height, in_width],
            [out_height, out_width],
            node.soft_get('name'))
        update_attrs = {
            'width': out_width,
            'height': out_height,
            'factor': factor,
        }

        if (node.has_valid('shrink_factor')
                and node.has_valid('zoom_factor')) or factor is None:
            del update_attrs['factor']
            if node.has('factor'):
                del node['factor']

        if ((node.has_valid('shrink_factor') and node.shrink_factor != 1) or
            (node.has_valid('zoom_factor') and node.zoom_factor != 1) or 'factor' in update_attrs) \
                and ((not node.has_valid('width') or node.width == 0) and
                     (not node.has_valid('height') or node.height == 0)):
            update_attrs['width'] = 0
            update_attrs['height'] = 0

        # specific
        if mode in ['nearest', 'cubic', 'area'] or node.has_and_set('convert_to_resample'):
            assert not node.align_corners
            assert node.pads_begin == 0 and node.pads_end == 0
            update_attrs['resample_type'] = InterpolateToInterpOrResample.type_map[mode]
            ResampleOp.update_node_stat(node, update_attrs)

            if not graph.graph['cmd_params'].keep_shape_ops or graph.graph['fw'] != 'tf':
                node.in_port(1).disconnect()
            else:
                # we avoid making resample non-reshapable for tf version
                shape = Shape(graph, {}).create_node()
                node.in_port(0).get_source().connect(shape.in_port(0))

                batch = node_to_get_batch_value(shape)
                features = node_to_get_features_dimension_value(shape)
                full_shape = new_shape_node_from_shape_nodes(
                    [batch, features,
                     node.in_port(1).get_source().node])
                node.in_port(1).get_connection().set_source(
                    full_shape.out_port(0))
                full_shape['override_output_shape'] = True

        elif mode == 'linear':
            assert len(in_shape) == 4, 'Interp does not support 5D input'
            update_attrs.update({
                'pad_beg': node.pads_begin,
                'pad_end': node.pads_end,
                'align_corners': node.align_corners,
            })
            InterpOp.update_node_stat(node, update_attrs)
            node.in_port(1).disconnect()
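
In the tf keep_shape_ops branch above, the 2-element spatial target on port 1 is widened into a full rank-4 shape by prepending the dynamic batch and feature dims. Schematically, with illustrative values:

    import numpy as np

    input_shape = np.array([1, 3, 32, 32])  # NCHW shape of port 0
    target_hw = np.array([64, 64])          # original value on port 1
    full_shape = np.concatenate([input_shape[:1],   # batch
                                 input_shape[1:2],  # features
                                 target_hw])        # -> [1, 3, 64, 64]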
Code example #10
    def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
        group_norm_node = match['op']
        group_norm_num_input_dims = len(
            group_norm_node.in_port(0).data.get_shape())

        # node computing initial GroupNorm input shape
        initial_shape_op_node = Shape(graph, {
            'name': group_norm_node.name + '/Shape'
        }).create_node()
        initial_shape_op_node.in_port(0).connect(
            group_norm_node.in_port(0).get_source())

        initial_shape_op_node_float = Cast(
            graph, {
                'name': initial_shape_op_node.name + '/to_float',
                'dst_type': data_type_str_to_np(graph.graph['cmd_params'].data_type)
            }).create_node()
        initial_shape_op_node.out_port(0).connect(
            initial_shape_op_node_float.in_port(0))

        initial_batch_dim_node = node_to_get_batch_value(
            initial_shape_op_node_float)
        initial_features_dim_node = node_to_get_features_dimension_value(
            initial_shape_op_node_float)
        initial_spatial_dims_node_int = node_to_get_spatial_dimensions_value(
            initial_shape_op_node)
        initial_spatial_dims_node = Cast(
            graph, {
                'name': initial_spatial_dims_node_int.name + '/to_float',
                'dst_type': data_type_str_to_np(graph.graph['cmd_params'].data_type)
            }).create_node()
        initial_spatial_dims_node_int.out_port(0).connect(
            initial_spatial_dims_node.in_port(0))

        group_size_node = Const(
            graph, {
                'value': int64_array([group_norm_node.num_groups]),
                'name': group_norm_node.name + '/GroupSize'
            }).create_node()

        # calculate "features // group_size" value
        reciprocal_group_size_node = Const(
            graph, {
                'value': np.array([1.0 / group_norm_node.num_groups]),
                'name': group_norm_node.name + '/ReciprocalGroupSize'
            }).create_node()

        c_div_g_node = Mul(graph, {}).create_node()
        c_div_g_node.in_port(0).connect(initial_features_dim_node.out_port(0))
        c_div_g_node.in_port(1).connect(reciprocal_group_size_node.out_port(0))

        batch_mul_group_size_node = Mul(graph, {}).create_node()
        batch_mul_group_size_node.in_port(0).connect(
            initial_batch_dim_node.out_port(0))
        batch_mul_group_size_node.in_port(1).connect(
            group_size_node.out_port(0))

        # create new node which concatenates several dims to one
        new_shape_node_float = new_shape_node_from_shape_nodes([
            batch_mul_group_size_node, c_div_g_node, initial_spatial_dims_node
        ])
        new_shape_node = Cast(graph, {
            'name': new_shape_node_float.name + '/to_int64',
            'dst_type': np.int64
        }).create_node()
        new_shape_node_float.out_port(0).connect(new_shape_node.in_port(0))

        reshape_for_mvn_node = Reshape(graph, {}).create_node()

        group_norm_node.in_port(0).get_connection().set_destination(
            reshape_for_mvn_node.in_port(0))
        reshape_for_mvn_node.in_port(1).connect(new_shape_node.out_port(0))

        # Reshape the gamma and beta constants to correct layout from [C] to [1,C], [1,C,1], [1,C,1,1] etc
        gamma_beta_shape = np.ones([group_norm_num_input_dims], dtype=np.int64)
        gamma_beta_shape[1] = -1

        gamma_value = group_norm_node.in_port(1).get_source().data.get_value()
        beta_value = group_norm_node.in_port(2).get_source().data.get_value()
        assert gamma_value is not None, 'The gamma should be constant'
        assert beta_value is not None, 'The beta should be constant'
        gamma_value = np.reshape(gamma_value, gamma_beta_shape)
        group_norm_node.in_port(1).get_source().data.set_value(gamma_value)
        beta_value = np.reshape(beta_value, gamma_beta_shape)
        group_norm_node.in_port(2).get_source().data.set_value(beta_value)

        # MVN
        mvn_node = MVN(
            graph, {
                'name': group_norm_node.name + '/MVN',
                'normalize_variance': 1,
                'eps': group_norm_node.eps,
                'eps_mode': 'inside_sqrt'
            }).create_node()
        mvn_node.in_port(0).connect(reshape_for_mvn_node.out_port(0))

        # MVN axes
        _, rank = get_shape_and_rank_nodes_by_port(
            mvn_node.in_port(0).get_connection().get_source(),
            return_as_a_scalar=True)
        rng = create_op_with_const_inputs(graph, Range, {
            0: int64_array(1),
            2: int64_array(1)
        }, {
            'name': group_norm_node.name + '/Range',
            'output_type': np.int64
        })
        mvn_node.in_port(1).connect(rng.out_port(0))
        rng.in_port(1).connect(rank.out_port(0))

        # reshape to the initial shape before multiplying with gamma and adding beta
        reshape_to_initial_shape_node = Reshape(graph, {}).create_node()
        reshape_to_initial_shape_node.in_port(0).connect(mvn_node.out_port(0))
        reshape_to_initial_shape_node.in_port(1).connect(
            initial_shape_op_node.out_port(0))

        mul_node = Mul(graph, {'name': mvn_node.name + '/Mul'}).create_node()
        mul_node.in_port(0).connect(reshape_to_initial_shape_node.out_port(0))
        group_norm_node.in_port(1).get_connection().set_destination(
            mul_node.in_port(1))

        add_node = Add(graph, {'name': mul_node.name + '/Add'}).create_node()
        add_node.in_port(0).connect(mul_node.out_port(0))
        group_norm_node.in_port(2).get_connection().set_destination(
            add_node.in_port(1))

        group_norm_node.out_port(0).get_connection().set_source(
            add_node.out_port(0))
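
Compared with code example #8, this variant does the batch * G and C / G arithmetic on float copies of the shape values and casts the assembled pattern back to int64 only when it reaches Reshape, and it feeds MVN an explicit axes input built with Range(1, rank, 1) instead of across_channels. The round trip, schematically with illustrative values:

    import numpy as np

    c, g, rank = 8, 4, 4
    c_div_g = np.array([float(c)]) * np.array([1.0 / g])  # float Mul: [2.0]
    reshape_dim = c_div_g.astype(np.int64)                # Cast to int64: [2]
    mvn_axes = np.arange(1, rank, 1)                      # Range -> [1, 2, 3]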