Example #1
    def replace_sub_graph(self, graph: Graph, match: dict):
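        """Replace the matched Split/Cast subgraph with a Range node.

        Builds Range(start=0, limit=batch, step=1) over the batch dimension
        of the tensor feeding 'split', reroutes the consumers of 'cast' to
        the Range output, and removes every matched node.
        """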
        source_connection = match['split'].in_port(0).get_connection()
        source_node = source_connection.get_source().node
        cast_node = match['cast']

        range_node = Range(graph, {
            'name': source_node.id + '/Range'
        }).create_node()
        start_node = Const(graph, {
            'name': range_node.id + '/Start',
            'value': int64_array(0)
        }).create_node()

        step_node = Const(graph, {
            'name': range_node.id + '/Step',
            'value': int64_array(1)
        }).create_node()
        input_shape_node = Shape(graph, {
            'name': start_node.id + '/Shape'
        }).create_node()
        input_shape_node.in_port(0).connect(source_node.out_port(0))

        limit_node_1D = node_to_get_batch_value(input_shape_node)
        limit_node = create_op_node_with_second_input(
            graph, Squeeze, int64_array([0]),
            {'name': source_node.id + '/batch_0D_value'}, limit_node_1D)

        range_node.in_port(0).connect(start_node.out_port(0))
        range_node.in_port(1).connect(limit_node.out_port(0))
        range_node.in_port(2).connect(step_node.out_port(0))
        cast_node.out_port(0).get_connection().set_source(
            range_node.out_port(0))

        graph.remove_nodes_from([node.id for node in match.values()])
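
For intuition, a minimal NumPy sketch of what the inserted subgraph computes; the function name and test tensor are illustrative, not part of the transform:

import numpy as np

def range_over_batch(x: np.ndarray) -> np.ndarray:
    # Shape -> batch value -> Squeeze to 0-D -> Range(start=0, limit=batch, step=1)
    limit = np.int64(x.shape[0])
    return np.arange(0, limit, 1, dtype=np.int64)

print(range_over_batch(np.zeros((4, 3, 2, 2))))  # [0 1 2 3]
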
Example #2
    def decompose_shuffle_channel(node: Node):
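        """Decompose ShuffleChannel into Reshape -> Transpose -> Reshape.

        The input is reshaped to [batch, group, channels // group, -1],
        transposed with order (0, 2, 1, 3) to interleave the groups, and
        reshaped back to the original input shape.
        """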
        graph = node.graph
        name = node.soft_get('name', node.id)

        rename_node(node, name + '/to_be_removed')

        shape = Shape(graph, dict(name=name + '/InputShape')).create_node()
        shape.in_port(0).connect(node.in_port(0).get_source())

        # Reshape [input_batch, group, input_channels/group, -1]
        batch = node_to_get_batch_value(shape)
        group = Const(
            graph, dict(name=name + '/Rows',
                        value=int64_array([node.group]))).create_node()
        const = Const(graph, dict(name=name + '/Const',
                                  value=int64_array([-1]))).create_node()

        input_channels = node_to_get_features_dimension_value(shape)
        output_channels = create_op_node_with_second_input(
            graph, Div, np.int64(node.group),
            {'name': name + '/Cols'}, input_node=input_channels)
        i_output_channels = Cast(graph, {
            'name': output_channels.name + '/Convert',
            'dst_type': np.int64
        }).create_node()
        output_channels.out_port(0).connect(i_output_channels.in_port(0))

        reshape_split_dim = new_shape_node_from_shape_nodes(
            [batch, group, i_output_channels, const])
        reshape_split_node = Reshape(
            graph, dict(name=name + '/Reshape_split_')).create_node()
        reshape_split_dim.out_port(0).connect(reshape_split_node.in_port(1))

        # Transpose(0, 2, 1, 3)
        transpose_node = create_op_node_with_second_input(
            graph, Transpose, int64_array([0, 2, 1, 3]),
            {'name': name + '/Transpose_'}, input_node=reshape_split_node)

        # Reshape back to input shape
        reshape_concat = Reshape(graph, dict(name=name)).create_node()
        rename_node(reshape_concat, name)

        shape.out_port(0).connect(reshape_concat.in_port(1))
        transpose_node.out_port(0).connect(reshape_concat.in_port(0))

        # Final connections
        node.in_port(0).get_connection().set_destination(
            reshape_split_node.in_port(0))
        node.out_port(0).get_connection().set_source(
            reshape_concat.out_port(0))
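
A self-contained NumPy sketch of the same decomposition (names are illustrative); it shows that the Reshape -> Transpose -> Reshape chain preserves the input shape while interleaving the channel groups:

import numpy as np

def shuffle_channels(x: np.ndarray, group: int) -> np.ndarray:
    n, c = x.shape[0], x.shape[1]
    y = x.reshape(n, group, c // group, -1)  # Reshape [batch, group, C // group, -1]
    y = y.transpose(0, 2, 1, 3)              # Transpose(0, 2, 1, 3)
    return y.reshape(x.shape)                # Reshape back to the input shape

x = np.arange(2 * 6 * 2 * 2).reshape(2, 6, 2, 2)
assert shuffle_channels(x, group=2).shape == x.shape
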
Example #3
    def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
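        """Decompose GroupNorm into Reshape -> MVN -> Reshape -> Mul -> Add.

        The input is reshaped to [batch * num_groups, C // num_groups,
        <spatial dims>] so MVN normalizes each group independently; the
        result is reshaped back to the initial shape and scaled/shifted
        with the gamma and beta constants broadcast as [1, C, 1, ...].
        """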
        group_norm_node = match['op']
        group_norm_num_input_dims = len(
            group_norm_node.in_port(0).data.get_shape())

        # node computing initial GroupNorm input shape
        initial_shape_op_node = Shape(graph, {
            'name': group_norm_node.name + '/Shape'
        }).create_node()
        initial_shape_op_node.in_port(0).connect(
            group_norm_node.in_port(0).get_source())

        initial_batch_dim_node = node_to_get_batch_value(initial_shape_op_node)
        initial_features_dim_node = node_to_get_features_dimension_value(
            initial_shape_op_node)
        initial_spatial_dims_node = node_to_get_spatial_dimensions_value(
            initial_shape_op_node)
        group_size_node = Const(
            graph, {
                'value': int64_array([group_norm_node.num_groups]),
                'name': group_norm_node.name + '/GroupSize'
            }).create_node()

        # calculate "features // group_size" value
        reciprocal_group_size_node = Const(
            graph, {
                'value': np.array([1.0 / group_norm_node.num_groups]),
                'name': group_norm_node.name + '/ReciprocalGroupSize'
            }).create_node()

        c_div_g_node = Mul(graph, {}).create_node()
        c_div_g_node.in_port(0).connect(initial_features_dim_node.out_port(0))
        c_div_g_node.in_port(1).connect(reciprocal_group_size_node.out_port(0))

        batch_mul_group_size_node = Mul(graph, {}).create_node()
        batch_mul_group_size_node.in_port(0).connect(
            initial_batch_dim_node.out_port(0))
        batch_mul_group_size_node.in_port(1).connect(
            group_size_node.out_port(0))

        # create new node which concatenates several dims to one
        new_shape_node = new_shape_node_from_shape_nodes([
            batch_mul_group_size_node, c_div_g_node, initial_spatial_dims_node
        ])

        reshape_for_mvn_node = Reshape(graph, {}).create_node()

        group_norm_node.in_port(0).get_connection().set_destination(
            reshape_for_mvn_node.in_port(0))
        reshape_for_mvn_node.in_port(1).connect(new_shape_node.out_port(0))

        # Reshape the gamma and beta constants to correct layout from [C] to [1,C], [1,C,1], [1,C,1,1] etc
        gamma_beta_shape = np.ones([group_norm_num_input_dims], dtype=np.int64)
        gamma_beta_shape[1] = -1

        gamma_value = group_norm_node.in_port(1).get_source().data.get_value()
        beta_value = group_norm_node.in_port(2).get_source().data.get_value()
        assert gamma_value is not None, 'The gamma should be constant'
        assert beta_value is not None, 'The beta should be constant'
        gamma_value = np.reshape(gamma_value, gamma_beta_shape)
        group_norm_node.in_port(1).get_source().data.set_value(gamma_value)
        beta_value = np.reshape(beta_value, gamma_beta_shape)
        group_norm_node.in_port(2).get_source().data.set_value(beta_value)

        # MVN
        mvn_node = MVN(
            graph, {
                'name': group_norm_node.name + '/MVN',
                'across_channels': 1,
                'normalize_variance': 1,
                'eps': group_norm_node.eps
            }).create_node()
        mvn_node.in_port(0).connect(reshape_for_mvn_node.out_port(0))

        # reshape to the initial shape before multiplying with gamma and adding beta
        reshape_to_initial_shape_node = Reshape(graph, {}).create_node()
        reshape_to_initial_shape_node.in_port(0).connect(mvn_node.out_port(0))
        reshape_to_initial_shape_node.in_port(1).connect(
            initial_shape_op_node.out_port(0))

        mul_node = Mul(graph, {'name': mvn_node.name + '/Mul'}).create_node()
        mul_node.in_port(0).connect(reshape_to_initial_shape_node.out_port(0))
        group_norm_node.in_port(1).get_connection().set_destination(
            mul_node.in_port(1))

        add_node = Add(graph, {'name': mul_node.name + '/Add'}).create_node()
        add_node.in_port(0).connect(mul_node.out_port(0))
        group_norm_node.in_port(2).get_connection().set_destination(
            add_node.in_port(1))

        group_norm_node.out_port(0).get_connection().set_source(
            add_node.out_port(0))
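
For reference, a NumPy sketch of the numeric effect of this decomposition, assuming eps is added inside the square root (the MVN eps semantics are not visible in this snippet); names are illustrative:

import numpy as np

def group_norm(x, gamma, beta, num_groups, eps):
    n, c = x.shape[0], x.shape[1]
    y = x.reshape(n * num_groups, c // num_groups, *x.shape[2:])  # [batch * G, C // G, spatial...]
    axes = tuple(range(1, y.ndim))                                # MVN across all non-batch axes
    y = (y - y.mean(axis=axes, keepdims=True)) / np.sqrt(y.var(axis=axes, keepdims=True) + eps)
    y = y.reshape(x.shape)                                        # back to the initial shape
    bshape = [1, c] + [1] * (x.ndim - 2)                          # gamma/beta as [1, C, 1, ...]
    return y * gamma.reshape(bshape) + beta.reshape(bshape)

x = np.random.rand(2, 6, 4, 4).astype(np.float32)
out = group_norm(x, np.ones(6), np.zeros(6), num_groups=3, eps=1e-5)
assert out.shape == x.shape
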
Example #4
    def replace_pattern(graph: Graph, match: dict):
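        """Convert Interpolate into Interp (linear mode) or Resample.

        Derives the width/height/factor attributes from the input and
        output shapes; for TF models with keep_shape_ops enabled, the
        target shape is assembled from a Shape subgraph so the Resample
        stays reshape-able.
        """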
        node = match['interpolate']

        assert 1 in node.in_ports() and not node.in_port(1).disconnected() and \
               node.in_port(1).data.get_value() is not None, \
            'Interpolate node {} is corrupted: no 1-port input found'.format(node.soft_get('name', node.id))

        # common
        mode = node.mode
        assert mode in ['linear', 'nearest', 'cubic', 'area']
        in_shape = node.in_port(0).data.get_shape()
        assert in_shape is not None and len(in_shape) in [4, 5]
        out_shape = node.out_port(0).data.get_shape()
        assert out_shape is not None and len(out_shape) in [4, 5]
        in_height, in_width = in_shape[2], in_shape[3]
        out_height, out_width = out_shape[2], out_shape[3]
        factor = factor_update(
            None if not node.has_valid('factor') else node.factor,
            [float(out_height) / in_height, float(out_width) / in_width],
            [in_height, in_width], [out_height, out_width],
            node.soft_get('name'))
        update_attrs = {
            'width': out_width,
            'height': out_height,
            'factor': factor,
        }

        if (node.has_valid('shrink_factor')
                and node.has_valid('zoom_factor')) or factor is None:
            del update_attrs['factor']
            if node.has('factor'):
                del node['factor']

        if ((node.has_valid('shrink_factor') and node.shrink_factor != 1) or
            (node.has_valid('zoom_factor') and node.zoom_factor != 1) or 'factor' in update_attrs) \
                and ((not node.has_valid('width') or node.width == 0) and
                     (not node.has_valid('height') or node.height == 0)):
            update_attrs['width'] = 0
            update_attrs['height'] = 0

        # specific
        if mode in ['nearest', 'cubic', 'area'] or node.has_and_set('convert_to_resample'):
            assert not node.align_corners
            assert node.pads_begin == 0 and node.pads_end == 0
            update_attrs['resample_type'] = InterpolateToInterpOrResample.type_map[mode]
            ResampleOp.update_node_stat(node, update_attrs)

            if not graph.graph['cmd_params'].keep_shape_ops or graph.graph['fw'] != 'tf':
                node.in_port(1).disconnect()
            else:
                # we avoid making resample non-reshapable for tf version
                shape = Shape(graph, {}).create_node()
                node.in_port(0).get_source().connect(shape.in_port(0))

                batch = node_to_get_batch_value(shape)
                features = node_to_get_features_dimension_value(shape)
                full_shape = new_shape_node_from_shape_nodes(
                    [batch, features,
                     node.in_port(1).get_source().node])
                node.in_port(1).get_connection().set_source(
                    full_shape.out_port(0))
                full_shape['override_output_shape'] = True

        elif mode == 'linear':
            assert len(in_shape) == 4, 'Interp does not support 5D input'
            update_attrs.update({
                'pad_beg': node.pads_begin,
                'pad_end': node.pads_end,
                'align_corners': node.align_corners,
            })
            InterpOp.update_node_stat(node, update_attrs)
            node.in_port(1).disconnect()
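
A small sketch of the "common" attribute computation above, with a hypothetical stand-in for factor_update that returns the scale only when both spatial factors agree (the real helper's behavior is not shown in this snippet):

def infer_resample_attrs(in_shape, out_shape):
    in_height, in_width = in_shape[2], in_shape[3]
    out_height, out_width = out_shape[2], out_shape[3]
    factors = [out_height / in_height, out_width / in_width]
    factor = factors[0] if factors[0] == factors[1] else None  # assumed factor_update behavior
    return {'width': out_width, 'height': out_height, 'factor': factor}

print(infer_resample_attrs([1, 3, 10, 10], [1, 3, 20, 20]))  # {'width': 20, 'height': 20, 'factor': 2.0}
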
Example #5
    def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
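        """GroupNorm decomposition, as in Example #3, with the new shape
        computed in floating point.

        Shape components are cast to the inference data type before the
        multiplications, cast back to int64 for the Reshape, and the MVN
        axes are produced by a Range(1, rank, 1) subgraph.
        """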
        group_norm_node = match['op']
        group_norm_num_input_dims = len(
            group_norm_node.in_port(0).data.get_shape())

        # node computing initial GroupNorm input shape
        initial_shape_op_node = Shape(graph, {
            'name': group_norm_node.name + '/Shape'
        }).create_node()
        initial_shape_op_node.in_port(0).connect(
            group_norm_node.in_port(0).get_source())

        initial_shape_op_node_float = Cast(
            graph, {
                'name': initial_shape_op_node.name + '/to_float',
                'dst_type': data_type_str_to_np(graph.graph['cmd_params'].data_type)
            }).create_node()
        initial_shape_op_node.out_port(0).connect(
            initial_shape_op_node_float.in_port(0))

        initial_batch_dim_node = node_to_get_batch_value(
            initial_shape_op_node_float)
        initial_features_dim_node = node_to_get_features_dimension_value(
            initial_shape_op_node_float)
        initial_spatial_dims_node_int = node_to_get_spatial_dimensions_value(
            initial_shape_op_node)
        initial_spatial_dims_node = Cast(
            graph, {
                'name': initial_spatial_dims_node_int.name + '/to_float',
                'dst_type': data_type_str_to_np(graph.graph['cmd_params'].data_type)
            }).create_node()
        initial_spatial_dims_node_int.out_port(0).connect(
            initial_spatial_dims_node.in_port(0))

        group_size_node = Const(
            graph, {
                'value': int64_array([group_norm_node.num_groups]),
                'name': group_norm_node.name + '/GroupSize'
            }).create_node()

        # calculate "features // group_size" value
        reciprocal_group_size_node = Const(
            graph, {
                'value': np.array([1.0 / group_norm_node.num_groups]),
                'name': group_norm_node.name + '/ReciprocalGroupSize'
            }).create_node()

        c_div_g_node = Mul(graph, {}).create_node()
        c_div_g_node.in_port(0).connect(initial_features_dim_node.out_port(0))
        c_div_g_node.in_port(1).connect(reciprocal_group_size_node.out_port(0))

        batch_mul_group_size_node = Mul(graph, {}).create_node()
        batch_mul_group_size_node.in_port(0).connect(
            initial_batch_dim_node.out_port(0))
        batch_mul_group_size_node.in_port(1).connect(
            group_size_node.out_port(0))

        # create new node which concatenates several dims to one
        new_shape_node_float = new_shape_node_from_shape_nodes([
            batch_mul_group_size_node, c_div_g_node, initial_spatial_dims_node
        ])
        new_shape_node = Cast(graph, {
            'name': new_shape_node_float.name + '/to_int64',
            'dst_type': np.int64
        }).create_node()
        new_shape_node_float.out_port(0).connect(new_shape_node.in_port(0))

        reshape_for_mvn_node = Reshape(graph, {}).create_node()

        group_norm_node.in_port(0).get_connection().set_destination(
            reshape_for_mvn_node.in_port(0))
        reshape_for_mvn_node.in_port(1).connect(new_shape_node.out_port(0))

        # Reshape the gamma and beta constants to correct layout from [C] to [1,C], [1,C,1], [1,C,1,1] etc
        gamma_beta_shape = np.ones([group_norm_num_input_dims], dtype=np.int64)
        gamma_beta_shape[1] = -1

        gamma_value = group_norm_node.in_port(1).get_source().data.get_value()
        beta_value = group_norm_node.in_port(2).get_source().data.get_value()
        assert gamma_value is not None, 'The gamma should be constant'
        assert beta_value is not None, 'The beta should be constant'
        gamma_value = np.reshape(gamma_value, gamma_beta_shape)
        group_norm_node.in_port(1).get_source().data.set_value(gamma_value)
        beta_value = np.reshape(beta_value, gamma_beta_shape)
        group_norm_node.in_port(2).get_source().data.set_value(beta_value)

        # MVN
        mvn_node = MVN(
            graph, {
                'name': group_norm_node.name + '/MVN',
                'normalize_variance': 1,
                'eps': group_norm_node.eps,
                'eps_mode': 'inside_sqrt'
            }).create_node()
        mvn_node.in_port(0).connect(reshape_for_mvn_node.out_port(0))

        # MVN axes
        _, rank = get_shape_and_rank_nodes_by_port(
            mvn_node.in_port(0).get_connection().get_source(),
            return_as_a_scalar=True)
        rng = create_op_with_const_inputs(graph, Range, {
            0: int64_array(1),
            2: int64_array(1)
        }, {
            'name': group_norm_node.name + '/Range',
            'output_type': np.int64
        })
        mvn_node.in_port(1).connect(rng.out_port(0))
        rng.in_port(1).connect(rank.out_port(0))

        # reshape to the initial shape before multiplying with gamma and adding beta
        reshape_to_initial_shape_node = Reshape(graph, {}).create_node()
        reshape_to_initial_shape_node.in_port(0).connect(mvn_node.out_port(0))
        reshape_to_initial_shape_node.in_port(1).connect(
            initial_shape_op_node.out_port(0))

        mul_node = Mul(graph, {'name': mvn_node.name + '/Mul'}).create_node()
        mul_node.in_port(0).connect(reshape_to_initial_shape_node.out_port(0))
        group_norm_node.in_port(1).get_connection().set_destination(
            mul_node.in_port(1))

        add_node = Add(graph, {'name': mul_node.name + '/Add'}).create_node()
        add_node.in_port(0).connect(mul_node.out_port(0))
        group_norm_node.in_port(2).get_connection().set_destination(
            add_node.in_port(1))

        group_norm_node.out_port(0).get_connection().set_source(
            add_node.out_port(0))
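
Finally, a NumPy sketch of the float shape arithmetic that distinguishes this variant: the shape components are multiplied in floating point and cast back to int64 before feeding Reshape, and the MVN axes come from a Range over the rank (values are illustrative):

import numpy as np

n, c, h, w, num_groups = 2, 6, 4, 4, 2
shape_f = np.array([n, c, h, w], dtype=np.float32)  # Shape cast to the inference float type
batch_mul_g = shape_f[0] * num_groups               # batch * GroupSize
c_div_g = shape_f[1] * (1.0 / num_groups)           # features * ReciprocalGroupSize
new_shape = np.array([batch_mul_g, c_div_g, shape_f[2], shape_f[3]]).astype(np.int64)
print(new_shape)  # [4 3 4 4]

mvn_axes = np.arange(1, new_shape.size, 1, dtype=np.int64)  # Range(1, rank, 1): non-batch axes
print(mvn_axes)  # [1 2 3]
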