Example #1
    def replace_sub_graph(self, graph: Graph, match: dict):
        fbn = match['fbn']
        input = fbn.in_node(0)
        log.debug('Found potential MVN pattern after {} with name {}'.format(input.op, input.name))
        if input.id != match['mean'].in_node(0).id or input.id != match['sqdiff'].in_node(0).id:
            return

        log.debug('Confirmed MVN pattern after {} with name {}'.format(input.op, input.name))

        mvn = MVN(graph, dict(
            name=fbn.name + '/MVN_',
            eps=fbn.eps,
            eps_mode='outside_sqrt',
            normalize_variance=1
        ))
        mvn.attrs['old_infer'] = mvn.attrs['infer']
        mvn.attrs['infer'] = __class__.infer

        mul = Mul(graph, dict(operation='mul', name=fbn.name + '/Mul_'))
        add = Add(graph, dict(operation='sum', name=fbn.name + '/Add_'))

        input_gamma = fbn.in_node(1)
        input_beta = fbn.in_node(2)

        mean_reduction = match['mean'].in_node(1)
        variance_reduction = match['variance'].in_node(1)

        new_subgraph = add.create_node([
            mul.create_node([
                mvn.create_node([input, mean_reduction, variance_reduction]),
                input_gamma
            ]),
            input_beta
        ])
        fbn.replace_node(new_subgraph)
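
The subgraph built here is MVN followed by a per-channel scale and shift. As a cross-check, a minimal NumPy sketch of the same computation, assuming the usual MVN semantics in which `eps_mode='outside_sqrt'` adds `eps` after taking the square root (shapes and values are illustrative):

    import numpy as np

    def mvn_reference(x, axes, eps, eps_mode='outside_sqrt'):
        # Mean-variance normalization over the given axes.
        mean = x.mean(axis=axes, keepdims=True)
        var = x.var(axis=axes, keepdims=True)
        if eps_mode == 'inside_sqrt':
            return (x - mean) / np.sqrt(var + eps)
        return (x - mean) / (np.sqrt(var) + eps)

    x = np.random.randn(2, 4, 8, 8).astype(np.float32)
    gamma = np.random.randn(4, 1, 1).astype(np.float32)  # illustrative gamma/beta
    beta = np.random.randn(4, 1, 1).astype(np.float32)
    y = mvn_reference(x, axes=(2, 3), eps=1e-3) * gamma + beta  # MVN -> Mul -> Add
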
Example #2
    def replace_op(self, graph: Graph, node: Node):

        # Add new nodes
        mvn = MVN(graph, {
            'eps': node.epsilon,
            'name': node.name + '/Ins_Norm/MVN_',
        }).create_node()
        mul = Mul(graph, {
            'axis': 1,
            'name': node.name + '/Ins_Norm/mul_'
        }).create_node()
        add = Add(graph, {
            'axis': 1,
            'name': node.name + '/Ins_Norm/add_'
        }).create_node()

        # Connect nodes
        node.in_port(0).get_connection().set_destination(mvn.in_port(0))
        node.in_port(1).get_connection().set_destination(mul.in_port(1))
        node.in_port(2).get_connection().set_destination(add.in_port(1))

        mvn.out_port(0).connect(mul.in_port(0))
        mul.out_port(0).connect(add.in_port(0))

        return [add.id]
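
The same decomposition written directly in NumPy, assuming standard InstanceNormalization semantics (per-sample, per-channel statistics over the spatial dims, with `epsilon` inside the square root):

    import numpy as np

    def instance_norm_reference(x, gamma, beta, epsilon):
        # Per-(sample, channel) statistics over the spatial dimensions.
        mean = x.mean(axis=(2, 3), keepdims=True)
        var = x.var(axis=(2, 3), keepdims=True)
        y = (x - mean) / np.sqrt(var + epsilon)                            # MVN
        return y * gamma.reshape(1, -1, 1, 1) + beta.reshape(1, -1, 1, 1)  # Mul, Add
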
Example #3
    @classmethod
    def extract(cls, node):
        name = node.soft_get('name', node.id)
        axes = onnx_attr(node,
                         'axes',
                         'ints',
                         default=np.array([0, 2, 3], dtype=np.int64),
                         dst_type=lambda x: np.array(x, dtype=np.int64))

        if axes is not None:
            if 0 in axes:
                raise Error('Reduction over the batch dimension in node "{}" '
                            'is not supported by the backend.'.format(name))
            # Dimension 4 (if present in the input tensor) should also be in
            # the list of reduction axes. That case is handled on the MVN Op
            # side, because the input shape is not available at this stage.
            for i in (2, 3):
                if i not in axes:
                    raise Error(
                        'Reduction over spatial dimensions in node "{}" '
                        'is obligatory for the backend.'.format(name))

        attrs = {
            'eps': 1e-9,
            'across_channels': 1 if 1 in axes else 0,
            'normalize_variance': 1,
            'axes': axes
        }
        MVN.update_node_stat(node, attrs)
        return cls.enabled
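
The `axes` attribute decides whether normalization crosses the channel dimension. For an NCHW input, the mapping applied above is, in sketch form (axes values are hypothetical):

    import numpy as np

    def across_channels_flag(axes):
        # Mirrors the extractor above: axis 1 (channels) in the reduction
        # set means statistics are shared across channels.
        return 1 if 1 in axes else 0

    assert across_channels_flag(np.array([1, 2, 3])) == 1  # across channels
    assert across_channels_flag(np.array([2, 3])) == 0     # per channel
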
Example #4
    def find_and_replace_pattern(self, graph: Graph):
        for node in graph.get_op_nodes(op='MVNCaffe'):
            node_name = node.soft_get('name', node.id)

            start_axis = 2
            if node['across_channels'] == 1:
                start_axis = 1

            rank = Rank(graph, {'name': node_name + '/Rank'}).create_node()

            # create range of axes based on `start_axis` and rank of input
            rng = create_op_with_const_inputs(graph, Range, {
                0: int64_array(start_axis),
                2: int64_array(1)
            }, {
                'name': node_name + '/Range',
                'output_type': np.int64
            })
            rng.in_port(1).connect(rank.out_port(0))

            new_mvn = MVN(
                graph, {
                    'eps': node.soft_get('eps', 1e-9),
                    'eps_mode': 'inside_sqrt',
                    'normalize_variance': node.soft_get(
                        'normalize_variance', 1)
                }).create_node([node.in_port(0).get_source().node, rng])
            new_mvn.in_port(0).get_connection().add_destination(
                rank.in_port(0))
            node.out_port(0).get_connection().set_source(new_mvn.out_port(0))
            rename_nodes([(node, node_name + '/tbd'), (new_mvn, node_name)])

            graph.remove_node(node.id)
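
At runtime the Rank/Range subgraph evaluates to the list of axes from `start_axis` up to (but not including) the input rank, which is what Caffe MVN reduces over. A sketch of the equivalent computation:

    import numpy as np

    def mvn_axes(start_axis, rank):
        # Range(start_axis, rank, 1) from the subgraph above.
        return np.arange(start_axis, rank, 1, dtype=np.int64)

    mvn_axes(1, 4)  # across_channels == 1, 4D input -> [1, 2, 3]
    mvn_axes(2, 4)  # per-channel MVN,      4D input -> [2, 3]
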
Example #5
    @classmethod
    def extract(cls, node):
        proto_layer = node.pb
        param = proto_layer.mvn_param

        attrs = collect_attributes(param)

        # update the attributes of the node
        MVN.update_node_stat(node, attrs)
        return cls.enabled
Example #6
    def replace_op(self, graph: Graph, node: Node):
        prefix = node.name + '/InstanceNormalization'
        mvn = MVN(graph, dict(name=prefix + '/MVN', eps=node.epsilon))
        mul = Mul(graph, dict(name=prefix + '/Mul', axis=1))
        add = Add(graph, dict(name=prefix + '/Add', axis=1))

        new_subgraph = add.create_node([
            mul.create_node(
                [mvn.create_node([node.in_node(0)]),
                 node.in_node(1)]),
            node.in_node(2)
        ])

        return [new_subgraph.id]
Example #7
    @staticmethod
    def replace_sub_graph(graph: Graph, match: dict):
        mvn = MVN(graph, dict(
            name=match['truediv'].name + '/MVN_',
            eps_mode='outside_sqrt',
            normalize_variance=1
        ))
        mvn.attrs['old_infer'] = mvn.attrs['infer']
        mvn.attrs['infer'] = __class__.infer

        mean_reduction = match['mean'].in_node(1)
        variance_reduction = match['variance'].in_node(1)
        pow2 = match['pow'].in_node(1)
        eps = match['add'].in_node(0 if match['add'].in_node(0).id != match['variance'].id else 1)

        new_subgraph = mvn.create_node([match['mean'].in_node(0), mean_reduction, variance_reduction, pow2, eps])

        match['truediv'].replace_node(new_subgraph)
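
The matched subgraph is the unrolled TF-style form of MVN. Written out in NumPy (the comments give my reading of which pattern key each step corresponds to):

    import numpy as np

    def unrolled_mvn(x, axes, eps):
        mean = x.mean(axis=axes, keepdims=True)                        # 'mean'
        variance = np.square(x - mean).mean(axis=axes, keepdims=True)  # 'variance'
        denom = np.power(variance + eps, 0.5)                          # 'add', 'pow'
        return (x - mean) / denom                                      # 'truediv'
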
Example #8
    def replace_pattern(self, graph: Graph, match: dict):
        node = match['op']
        node.is_training = False

        shape = node.in_port(1).data.get_shape()
        assert shape is not None, 'The shape of scale input of the BatchNorm node {} is not defined'.format(node.name)

        bn_mean = Const(graph, {'name': node.name + '/mean', 'value': np.zeros(shape, dtype=np.float32),
                                'override_output_shape': True}).create_node()
        bn_std = Const(graph, {'name': node.name + '/std', 'value': np.ones(shape, dtype=np.float32),
                               'override_output_shape': True}).create_node()
        node.in_port(3).get_connection().set_source(bn_mean.out_port(0))
        node.in_port(4).get_connection().set_source(bn_std.out_port(0))

        # save the original shape
        original_shape = Shape(graph, {'name': node.in_port(0).get_source().node.soft_get('name')}).create_node()
        original_shape.in_port(0).connect(node.in_port(0).get_source())

        mvn = MVN(graph, {'name': node.name + '/mvn_', 'eps': node.soft_get('eps', 1e-6),
                          'override_output_shape': True}).create_node()
        node.in_port(0).get_connection().insert_node(mvn)

        reshape_4d = create_op_node_with_second_input(graph, Reshape, int64_array([1, -1, 0, 0]),
                                                      {'override_output_shape': True,
                                                       'name': node.soft_get('name') + '/fused_batch_and_channels'})
        mvn.in_port(0).get_connection().insert_node(reshape_4d)

        # restore original shape
        reshape_back = Reshape(graph, {'name': mvn.soft_get('name') + '/restore_shape',
                                       'override_output_shape': True}).create_node()
        reshape_back.in_port(1).connect(original_shape.out_port(0))
        mvn.out_port(0).get_connection().insert_node(reshape_back)
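
The Reshape pattern `[1, -1, 0, 0]` relies on the convention that `0` copies the corresponding input dim and `-1` is inferred, so batch and channels are fused into one axis before MVN. The shape effect, sketched with hypothetical sizes:

    import numpy as np

    x = np.random.randn(2, 3, 8, 8).astype(np.float32)  # hypothetical N, C, H, W
    fused = x.reshape(1, -1, 8, 8)                       # [1, N*C, H, W] == [1, 6, 8, 8]
    # Per-channel MVN on this layout normalizes each original (n, c) slice
    # over its spatial dims, which is what training-mode BatchNorm computes.
    restored = fused.reshape(x.shape)                    # the 'restore_shape' Reshape
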
Example #9
    @classmethod
    def extract(cls, node):
        name = node.soft_get('name', node.id)
        axes = onnx_attr(node, 'axes', 'ints',
                         default=np.array([0, 2, 3], dtype=np.int64),
                         dst_type=lambda x: np.array(x, dtype=np.int64))

        axes = Const(node.graph, {'value': axes, 'name': name + '/Axes'}).create_node()
        node.add_input_port(1, skip_if_exist=True)
        node.in_port(1).connect(axes.out_port(0))

        attrs = {
            'eps': 1e-9,
            'normalize_variance': 1,
            'eps_mode': 'outside_sqrt'
        }

        MVN.update_node_stat(node, attrs)
        return cls.enabled
Example #10
    @staticmethod
    def replace_sub_graph(graph: Graph, match: dict):
        mvn = MVN(
            graph,
            dict(name=match['truediv'].name + '/MVN_',
                 required_reduction_indices=[1, 2]
                 if graph.graph['layout'] == 'NHWC' else [2, 3]))
        mvn.attrs['old_infer'] = mvn.attrs['infer']
        mvn.attrs['infer'] = __class__.infer

        mean_reduction = match['mean'].in_node(1)
        variance_reduction = match['variance'].in_node(1)
        pow2 = match['pow'].in_node(1)
        eps = match['add'].in_node(
            0 if match['add'].in_node(0).id != match['variance'].id else 1)

        new_subgraph = mvn.create_node([
            match['mean'].in_node(0), mean_reduction, variance_reduction, pow2,
            eps
        ])

        match['truediv'].replace_node(new_subgraph)
Example #11
    def replace_sub_graph(self, graph: Graph, match: dict):
        inp = match['pool0']
        inp_port = inp.in_port(0).get_source()

        # take/check the values of the add, pow and axes for ReduceMean
        pow_param = match['pow_param']
        add_param = match['add_param']
        if add_param.value.size == 1 and pow_param.value.size == 1 and add_param.value.item() <= 1e-05 \
                and pow_param.value.item() == 0.5 and match['pool0_param'].value == match['pool1_param'].value:
            log.debug('Found LayerNorm pattern after {} with name {}'.format(
                inp_port.node.op, inp_port.node.name))
            mvn = MVN(
                graph, {
                    'eps': add_param.value.item(),
                    'axes': match['pool1_param'].value,
                    'normalize_variance': 1
                }).create_node()
            div_name = match['div'].soft_get('name', match['div'].id)
            rename_nodes([(match['div'], div_name + '/to_be_removed'),
                          (mvn, div_name)])

            inp_port.connect(mvn.in_port(0))
            match['div'].out_port(0).get_connection().set_source(
                mvn.out_port(0))
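
The guard above checks that the matched subgraph really is `(x - mean) / (var + eps) ** 0.5` with a tiny eps and identical reduction axes for both ReduceMeans. In NumPy, the matched expression (pattern-key comments are my reading of the matcher):

    import numpy as np

    def matched_layer_norm(x, axes, eps):
        mu = x.mean(axis=axes, keepdims=True)                   # 'pool0'
        var = np.square(x - mu).mean(axis=axes, keepdims=True)  # 'pool1'
        return (x - mu) / np.power(var + eps, 0.5)              # 'add', 'pow', 'div'
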
Example #12
    def replace_op(self, graph: Graph, node: Node):
        axis = Const(graph, {'value': int64_array([-1])}).create_node()
        mvn = MVN(
            graph,
            dict(name=node.name + '/mvn',
                 eps=node.module.eps,
                 normalize_variance=True,
                 eps_mode='inside_sqrt')).create_node([node.in_node(0), axis])

        weight = node.module.weight.detach().numpy()
        bias = node.module.bias.detach().numpy()

        w = Const(graph, {'value': weight}).create_node()
        b = Const(graph, {'value': bias}).create_node()
        mul = Mul(graph, dict(name=node.name + '/mul')).create_node([mvn, w])
        add = Add(graph, dict(name=node.name + '/add')).create_node([mul, b])
        return [add.id]
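
For a PyTorch `LayerNorm` over the last dimension, the emitted MVN/Mul/Add chain computes, in NumPy terms (eps inside the square root, then an elementwise weight and bias):

    import numpy as np

    def layer_norm_last_dim(x, weight, bias, eps):
        mu = x.mean(axis=-1, keepdims=True)
        var = x.var(axis=-1, keepdims=True)
        return (x - mu) / np.sqrt(var + eps) * weight + bias  # MVN -> Mul -> Add
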
Example #13
    def replace_op(self, graph: Graph, node: Node):
        name = node.soft_get('name', node.id)

        # create range of axes for MVN based on `start_axis` and rank of input
        rank = Rank(graph, {'name': name + '/Rank'}).create_node()
        rng = create_op_with_const_inputs(graph, Range, {
            0: int64_array(2),
            2: int64_array(1)
        }, {
            'name': name + '/Range',
            'output_type': np.int64
        })
        mvn = MVN(
            graph, {
                'eps': node.epsilon,
                'eps_mode': 'inside_sqrt',
                'normalize_variance': 1,
                'name': name + '/Ins_Norm/MVN_',
            }).create_node()
        node.in_port(0).get_connection().set_destination(mvn.in_port(0))
        rng.out_port(0).connect(mvn.in_port(1))
        mul = Mul(graph, {
            'axis': 1,
            'name': name + '/Ins_Norm/mul_'
        }).create_node()
        mvn.out_port(0).connect(mul.in_port(0))
        node.in_port(1).get_connection().set_destination(mul.in_port(1))
        add = Add(graph, {
            'axis': 1,
            'name': name + '/Ins_Norm/add_'
        }).create_node()
        mul.out_port(0).connect(add.in_port(0))
        node.in_port(2).get_connection().set_destination(add.in_port(1))

        mvn.in_port(0).get_connection().add_destination(rank.in_port(0))
        rng.in_port(1).connect(rank.out_port(0))

        rename_nodes([(node, name + '/TBD'), (add, name)])

        return [add.id]
Example #14
    def replace_sub_graph(self, graph: Graph, match: dict):
        if not check_applicability(match):
            return

        reshape = match['reshape']
        div_name = match['division'].name

        input_shape = Shape(
            graph, dict(name=div_name + '/shape/MVN_T_')).create_node()
        shape_of_reshape = reshape.in_port(1).get_connection().get_source().node.value
        c1, c2 = shape_of_reshape[1], shape_of_reshape[2]
        c = c1 * c2

        new_reshape = create_op_node_with_second_input(
            graph, Reshape, int64_array([0, 0, 0, c1, c2]),
            dict(name=div_name + '/first_reshape/MVN_T_'))
        permute_order = int64_array([0, 1, 2, 4, 3])
        first_permute = create_op_node_with_second_input(
            graph, Transpose, permute_order,
            dict(name=div_name + '/first_permute/MVN_T_'), new_reshape)

        add = match['add']
        variance = match['variance']
        add_first_input = add.in_port(0).get_connection().get_source().node
        eps_port_num = 0 if add_first_input.id != variance.id else 1
        eps = add.in_port(eps_port_num).get_connection().get_source().node
        mvn_node = MVN(
            graph,
            dict(name=div_name + '/MVN/MVN_T_',
                 required_reduction_indices=[1, 2, 3],
                 eps=eps.value)).create_node()
        first_permute.out_port(0).connect(mvn_node.in_port(0))

        second_permute = create_op_node_with_second_input(
            graph, Transpose, permute_order,
            dict(name=div_name + '/second_permute/MVN_T_'), mvn_node)
        new_reshape2 = Reshape(graph,
                               dict(name=div_name +
                                    '/second_reshape/MVN_T_')).create_node()
        second_permute.out_port(0).connect(new_reshape2.in_port(0))
        gamma_source = match['gamma_identity'].in_port(0).get_connection().get_source()
        gamma_val = np.reshape(gamma_source.node.value, int64_array([1, 1, 1, c]))
        new_mul = create_op_node_with_second_input(
            graph, Mul, gamma_val, dict(name=match['mul'].name + '/MVN_T_'),
            new_reshape2)
        beta_source = match['beta_identity'].in_port(0).get_connection().get_source()
        beta_val = np.reshape(beta_source.node.value, int64_array([1, 1, 1, c]))
        new_add2 = create_op_node_with_second_input(
            graph, Add, beta_val, dict(name=match['add2'].name + '/MVN_T_'),
            new_mul)

        transpose_connection = match['transpose'].in_port(0).get_connection()
        before_transpose = transpose_connection.get_source().node
        transpose_connection.set_destination(new_reshape.in_port(0))
        input_shape.out_port(0).connect(new_reshape2.in_port(1))
        before_transpose.out_port(0).connect(input_shape.in_port(0))
        match['transpose2'].out_port(0).get_connection().set_source(
            new_add2.out_port(0))
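
The shape bookkeeping above, traced in NumPy with hypothetical sizes (a `0` in the Reshape pattern keeps the input dim): the last axis is split into `(c1, c2)`, the two factors are swapped, MVN runs over axes `[1, 2, 3]`, and the inverse permute plus reshape restores the original layout:

    import numpy as np

    x = np.random.randn(1, 8, 8, 6).astype(np.float32)  # hypothetical NHWC input
    c1, c2 = 2, 3                                        # factors of C == 6
    y = x.reshape(1, 8, 8, c1, c2)                       # Reshape [0, 0, 0, c1, c2]
    y = y.transpose(0, 1, 2, 4, 3)                       # first Transpose
    # ... MVN over required_reduction_indices [1, 2, 3] ...
    z = y.transpose(0, 1, 2, 4, 3).reshape(x.shape)      # second Transpose + Reshape
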
Example #15
    def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
        group_norm_node = match['op']
        group_norm_num_input_dims = len(
            group_norm_node.in_port(0).data.get_shape())

        # node computing initial GroupNorm input shape
        initial_shape_op_node = Shape(graph, {
            'name': group_norm_node.name + '/Shape'
        }).create_node()
        initial_shape_op_node.in_port(0).connect(
            group_norm_node.in_port(0).get_source())

        initial_batch_dim_node = node_to_get_batch_value(initial_shape_op_node)
        initial_features_dim_node = node_to_get_features_dimension_value(
            initial_shape_op_node)
        initial_spatial_dims_node = node_to_get_spatial_dimensions_value(
            initial_shape_op_node)
        group_size_node = Const(
            graph, {
                'value': int64_array([group_norm_node.num_groups]),
                'name': group_norm_node.name + '/GroupSize'
            }).create_node()

        # calculate "features // group_size" value
        reciprocal_group_size_node = Const(
            graph, {
                'value': np.array([1.0 / group_norm_node.num_groups]),
                'name': group_norm_node.name + '/ReciprocalGroupSize'
            }).create_node()

        c_div_g_node = Mul(graph, {}).create_node()
        c_div_g_node.in_port(0).connect(initial_features_dim_node.out_port(0))
        c_div_g_node.in_port(1).connect(reciprocal_group_size_node.out_port(0))

        batch_mul_group_size_node = Mul(graph, {}).create_node()
        batch_mul_group_size_node.in_port(0).connect(
            initial_batch_dim_node.out_port(0))
        batch_mul_group_size_node.in_port(1).connect(
            group_size_node.out_port(0))

        # create new node which concatenates several dims to one
        new_shape_node = new_shape_node_from_shape_nodes([
            batch_mul_group_size_node, c_div_g_node, initial_spatial_dims_node
        ])

        reshape_for_mvn_node = Reshape(graph, {}).create_node()

        group_norm_node.in_port(0).get_connection().set_destination(
            reshape_for_mvn_node.in_port(0))
        reshape_for_mvn_node.in_port(1).connect(new_shape_node.out_port(0))

        # Reshape the gamma and beta constants to correct layout from [C] to [1,C], [1,C,1], [1,C,1,1] etc
        gamma_beta_shape = np.ones([group_norm_num_input_dims], dtype=np.int64)
        gamma_beta_shape[1] = -1

        gamma_value = group_norm_node.in_port(1).get_source().data.get_value()
        beta_value = group_norm_node.in_port(2).get_source().data.get_value()
        assert gamma_value is not None, 'The gamma should be constant'
        assert beta_value is not None, 'The beta should be constant'
        gamma_value = np.reshape(gamma_value, gamma_beta_shape)
        group_norm_node.in_port(1).get_source().data.set_value(gamma_value)
        beta_value = np.reshape(beta_value, gamma_beta_shape)
        group_norm_node.in_port(2).get_source().data.set_value(beta_value)

        # MVN
        mvn_node = MVN(
            graph, {
                'name': group_norm_node.name + '/MVN',
                'across_channels': 1,
                'normalize_variance': 1,
                'eps': group_norm_node.eps
            }).create_node()
        mvn_node.in_port(0).connect(reshape_for_mvn_node.out_port(0))

        # reshape to the initial shape before multiplying with gamma and adding beta
        reshape_to_initial_shape_node = Reshape(graph, {}).create_node()
        reshape_to_initial_shape_node.in_port(0).connect(mvn_node.out_port(0))
        reshape_to_initial_shape_node.in_port(1).connect(
            initial_shape_op_node.out_port(0))

        mul_node = Mul(graph, {'name': mvn_node.name + '/Mul'}).create_node()
        mul_node.in_port(0).connect(reshape_to_initial_shape_node.out_port(0))
        group_norm_node.in_port(1).get_connection().set_destination(
            mul_node.in_port(1))

        add_node = Add(graph, {'name': mul_node.name + '/Add'}).create_node()
        add_node.in_port(0).connect(mul_node.out_port(0))
        group_norm_node.in_port(2).get_connection().set_destination(
            add_node.in_port(1))

        group_norm_node.out_port(0).get_connection().set_source(
            add_node.out_port(0))
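
A NumPy reference for the whole decomposition, assuming a 4D NCHW input: the reshape fuses batch with groups so that across-channel MVN normalizes each group independently, and gamma/beta are applied on the restored layout:

    import numpy as np

    def group_norm_reference(x, gamma, beta, num_groups, eps):
        n, c, h, w = x.shape
        y = x.reshape(n * num_groups, c // num_groups, h, w)  # Reshape for MVN
        mu = y.mean(axis=(1, 2, 3), keepdims=True)            # across_channels=1
        var = y.var(axis=(1, 2, 3), keepdims=True)
        y = ((y - mu) / np.sqrt(var + eps)).reshape(n, c, h, w)
        return y * gamma.reshape(1, -1, 1, 1) + beta.reshape(1, -1, 1, 1)
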
Example #16
    def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
        group_norm_node = match['op']
        group_norm_num_input_dims = len(
            group_norm_node.in_port(0).data.get_shape())

        # node computing initial GroupNorm input shape
        initial_shape_op_node = Shape(graph, {
            'name': group_norm_node.name + '/Shape'
        }).create_node()
        initial_shape_op_node.in_port(0).connect(
            group_norm_node.in_port(0).get_source())

        initial_shape_op_node_float = Cast(
            graph, {
                'name': initial_shape_op_node.name + '/to_float',
                'dst_type': data_type_str_to_np(graph.graph['cmd_params'].data_type)
            }).create_node()
        initial_shape_op_node.out_port(0).connect(
            initial_shape_op_node_float.in_port(0))

        initial_batch_dim_node = node_to_get_batch_value(
            initial_shape_op_node_float)
        initial_features_dim_node = node_to_get_features_dimension_value(
            initial_shape_op_node_float)
        initial_spatial_dims_node_int = node_to_get_spatial_dimensions_value(
            initial_shape_op_node)
        initial_spatial_dims_node = Cast(
            graph, {
                'name': initial_spatial_dims_node_int.name + '/to_float',
                'dst_type': data_type_str_to_np(graph.graph['cmd_params'].data_type)
            }).create_node()
        initial_spatial_dims_node_int.out_port(0).connect(
            initial_spatial_dims_node.in_port(0))

        group_size_node = Const(
            graph, {
                'value': int64_array([group_norm_node.num_groups]),
                'name': group_norm_node.name + '/GroupSize'
            }).create_node()

        # calculate "features // group_size" value
        reciprocal_group_size_node = Const(
            graph, {
                'value': np.array([1.0 / group_norm_node.num_groups]),
                'name': group_norm_node.name + '/ReciprocalGroupSize'
            }).create_node()

        c_div_g_node = Mul(graph, {}).create_node()
        c_div_g_node.in_port(0).connect(initial_features_dim_node.out_port(0))
        c_div_g_node.in_port(1).connect(reciprocal_group_size_node.out_port(0))

        batch_mul_group_size_node = Mul(graph, {}).create_node()
        batch_mul_group_size_node.in_port(0).connect(
            initial_batch_dim_node.out_port(0))
        batch_mul_group_size_node.in_port(1).connect(
            group_size_node.out_port(0))

        # create new node which concatenates several dims to one
        new_shape_node_float = new_shape_node_from_shape_nodes([
            batch_mul_group_size_node, c_div_g_node, initial_spatial_dims_node
        ])
        new_shape_node = Cast(graph, {
            'name': new_shape_node_float.name + '/to_int64',
            'dst_type': np.int64
        }).create_node()
        new_shape_node_float.out_port(0).connect(new_shape_node.in_port(0))

        reshape_for_mvn_node = Reshape(graph, {}).create_node()

        group_norm_node.in_port(0).get_connection().set_destination(
            reshape_for_mvn_node.in_port(0))
        reshape_for_mvn_node.in_port(1).connect(new_shape_node.out_port(0))

        # Reshape the gamma and beta constants to correct layout from [C] to [1,C], [1,C,1], [1,C,1,1] etc
        gamma_beta_shape = np.ones([group_norm_num_input_dims], dtype=np.int64)
        gamma_beta_shape[1] = -1

        gamma_value = group_norm_node.in_port(1).get_source().data.get_value()
        beta_value = group_norm_node.in_port(2).get_source().data.get_value()
        assert gamma_value is not None, 'The gamma should be constant'
        assert beta_value is not None, 'The beta should be constant'
        gamma_value = np.reshape(gamma_value, gamma_beta_shape)
        group_norm_node.in_port(1).get_source().data.set_value(gamma_value)
        beta_value = np.reshape(beta_value, gamma_beta_shape)
        group_norm_node.in_port(2).get_source().data.set_value(beta_value)

        # MVN
        mvn_node = MVN(
            graph, {
                'name': group_norm_node.name + '/MVN',
                'normalize_variance': 1,
                'eps': group_norm_node.eps,
                'eps_mode': 'inside_sqrt'
            }).create_node()
        mvn_node.in_port(0).connect(reshape_for_mvn_node.out_port(0))

        # MVN axes
        _, rank = get_shape_and_rank_nodes_by_port(
            mvn_node.in_port(0).get_connection().get_source(),
            return_as_a_scalar=True)
        rng = create_op_with_const_inputs(graph, Range, {
            0: int64_array(1),
            2: int64_array(1)
        }, {
            'name': group_norm_node.name + '/Range',
            'output_type': np.int64
        })
        mvn_node.in_port(1).connect(rng.out_port(0))
        rng.in_port(1).connect(rank.out_port(0))

        # reshape to the initial shape before multiplying with gamma and adding beta
        reshape_to_initial_shape_node = Reshape(graph, {}).create_node()
        reshape_to_initial_shape_node.in_port(0).connect(mvn_node.out_port(0))
        reshape_to_initial_shape_node.in_port(1).connect(
            initial_shape_op_node.out_port(0))

        mul_node = Mul(graph, {'name': mvn_node.name + '/Mul'}).create_node()
        mul_node.in_port(0).connect(reshape_to_initial_shape_node.out_port(0))
        group_norm_node.in_port(1).get_connection().set_destination(
            mul_node.in_port(1))

        add_node = Add(graph, {'name': mul_node.name + '/Add'}).create_node()
        add_node.in_port(0).connect(mul_node.out_port(0))
        group_norm_node.in_port(2).get_connection().set_destination(
            add_node.in_port(1))

        group_norm_node.out_port(0).get_connection().set_source(
            add_node.out_port(0))
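
This variant differs from the previous one mainly in how the target shape is assembled: the shape components are cast to float for the Mul-based arithmetic and cast back to int64 before feeding Reshape, and the MVN axes come from `Range(1, rank, 1)` instead of the `across_channels` attribute. The shape arithmetic, sketched with hypothetical numbers:

    import numpy as np

    shape = np.array([2, 6, 4, 4], dtype=np.int64)  # hypothetical N, C, H, W
    num_groups = 2
    n, c = float(shape[0]), float(shape[1])
    new_shape_float = np.array([n * num_groups, c * (1.0 / num_groups), 4.0, 4.0])
    new_shape = new_shape_float.astype(np.int64)            # Cast back -> [4, 3, 4, 4]
    axes = np.arange(1, new_shape.size, 1, dtype=np.int64)  # Range(1, rank) -> [1, 2, 3]
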