def construct(self, input_ids, input_mask, input_position=None, attention_mask=None, layer_past=None):
    """
    Run the PanGu Alpha forward pass.

    Args:
        input_ids: Token ids fed to ``self.word_embedding``. When
            ``self.eod_reset`` is falsy its shape is unpacked as
            ``(batch_size, seq_length)``.
        input_mask: Handed to ``self.get_attention_mask``; only read when
            ``self.eod_reset`` is falsy.
        input_position: Position ids. Used as given only when
            ``self.eod_reset`` is set; otherwise rebuilt from the sequence
            length and tiled once per batch row.
        attention_mask: Attention mask. Used as given only when
            ``self.eod_reset`` is set; otherwise rebuilt from ``input_mask``.
        layer_past: Forwarded to every block and to the top query layer;
            replaced by ``self.past`` when ``self.use_past`` is falsy.

    Returns:
        Tuple ``(output_state, present_layer, embedding_table)``, where
        ``present_layer`` holds the second output of each block followed by
        the second output of the top query layer.
    """
    if not self.use_past:
        layer_past = self.past

    token_embedding, embedding_table = self.word_embedding(input_ids)

    if not self.eod_reset:
        # Caller-supplied positions and mask are discarded in this mode.
        batch_size, seq_length = F.shape(input_ids)
        position_ids = F.tuple_to_array(F.make_range(seq_length))
        input_position = P.Tile()(position_ids, (batch_size, 1))
        attention_mask = self.get_attention_mask(input_mask)

    pos_embedding = self.position_embedding(input_position)
    hidden_states = self.dropout(self.add(token_embedding, pos_embedding))
    hidden_states = P.Cast()(hidden_states, mstype.float16)
    attention_mask = self.expand_dims(attention_mask, 1)

    # Collect the "present" output of every layer, in layer order.
    present_layer = ()
    for layer_idx in range(self.num_layers):
        hidden_states, layer_present = self.blocks[layer_idx](hidden_states,
                                                              attention_mask, layer_past)
        present_layer = present_layer + (layer_present,)

    normed_state = F.cast(self.layernorm(hidden_states), self.dtype)

    # The top query layer receives a second embedding of the same positions.
    top_query = self.top_query_embedding(input_position)
    output_state, top_present = self.top_query_layer(normed_state, top_query,
                                                     attention_mask, layer_past)
    present_layer = present_layer + (top_present,)

    return output_state, present_layer, embedding_table
示例#2
0
    def construct(self, input_ids, input_mask, layer_past=None):
        """
        Run the GPT forward pass.

        Args:
            input_ids: Token ids fed to ``self.word_embedding``; its shape is
                unpacked as ``(batch_size, seq_length)``.
            input_mask: Handed to ``self.get_attention_mask``.
            layer_past: Forwarded to every block; replaced by ``self.past``
                when ``self.use_past`` is falsy.

        Returns:
            Tuple ``(output_state, present_layer, embedding_table)``, where
            ``present_layer`` holds the second output of each block in order.
        """
        if not self.use_past:
            layer_past = self.past

        token_embedding, embedding_table = self.word_embedding(input_ids)

        # Position ids are built from the sequence length and tiled per batch row.
        batch_size, seq_length = F.shape(input_ids)
        position_ids = F.tuple_to_array(F.make_range(seq_length))
        position_ids = P.Tile()(position_ids, (batch_size, 1))

        hidden_states = token_embedding + self.position_embedding(position_ids)
        hidden_states = P.Cast()(hidden_states, mstype.float16)

        mask = self.get_attention_mask(input_mask)
        mask = P.ExpandDims()(mask, 1)

        present_layer = ()
        for layer_idx in range(self.num_layers):
            hidden_states, layer_present = self.blocks[layer_idx](hidden_states,
                                                                  mask, layer_past)
            present_layer = present_layer + (layer_present,)

        return self.layernorm(hidden_states), present_layer, embedding_table
示例#3
0
def transpose(a, axes=None):
    """
    Reverse or permute the axes of a tensor and return the result.

    Args:
        a (Tensor): The tensor to be transposed.
        axes (Union[None, tuple, list]): The new axis order. If None, the
            order of all axes is reversed. Default is None.

    Returns:
        Tensor, the transposed tensor.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore.numpy as np
        >>> x = np.ones((1,2,3))
        >>> x = np.transpose(x)
        >>> print(x.shape)
        (3, 2, 1)
    """
    if axes is not None:
        # Explicit permutation: validate it, then hand it to the operator.
        checked_axes = _check_shape_compile(axes)
        return P.Transpose()(a, checked_axes)

    # No permutation given: reverse the full axis range (ndim-1, ..., 0).
    ndim = F.tuple_len(F.shape(a))
    reversed_axes = F.tuple_reversed(F.make_range(0, ndim))
    return P.Transpose()(a, reversed_axes)
示例#4
0
def swapaxes(x, axis1, axis2):
    """
    Interchange two axes of a tensor.

    Args:
        x (Tensor): The tensor whose axes are swapped.
        axis1 (int): First axis.
        axis2 (int): Second axis.

    Returns:
        Transposed Tensor. Has the same data type as the original tensor x.
        When both axes refer to the same dimension, x itself is returned.

    Raises:
        TypeError: If axis1 or axis2 is not integer.
        ValueError: If axis1 or axis2 is not in the range from -ndim to ndim-1.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore
        >>> import mindspore.numpy as mnp
        >>> from mindspore import Tensor
        >>> import numpy as onp
        >>> input_x = Tensor(onp.ones((2,3,4)), mindspore.float32)
        >>> output = mnp.swapaxes(input_x, 0, 2)
        >>> print(output.shape)
        (4, 3, 2)
    """
    _check_is_int(axis1)
    _check_is_int(axis2)

    ndim = F.tuple_len(F.shape(x))

    normalized = _check_axes_range((axis1, axis2), ndim)
    low, high = normalized[0], normalized[1]

    if low == high:
        return x
    if low > high:
        # Order the pair so the slices below can assume low < high.
        low, high = high, low

    order = F.make_range(0, ndim)
    # Layout: [axes before low] [high] [axes strictly between] [low] [axes after high]
    swapped = order[0:low] + order[high:high+1] + \
        order[low+1:high] + order[low:low+1]
    if high + 1 < ndim:
        # The tail is only sliced when it is non-empty, as in the original code.
        swapped = swapped + order[high+1:]

    return P.Transpose()(x, swapped)
示例#5
0
 def construct(self):
     """
     Build the matrix of clipped, non-negative relative positions.

     Returns:
         The pairwise difference of two ``self._length`` x ``self._length``
         index matrices, clipped to
         ``[self._min_relative_position, self._max_relative_position]`` and
         then shifted up by ``self._max_relative_position``.
     """
     # NOTE(review): assumes self.range_mat is a reshape-style op taking
     # (tensor, shape) -- confirm against the class constructor.
     length = self._length
     square = (length, length)

     index_row = self.cast(F.tuple_to_array(F.make_range(length)), mstype.int32)
     index_col = self.range_mat(index_row, (length, -1))
     row_tiled = self.tile(index_row, (length,))
     col_tiled = self.tile(index_col, (1, length))
     row_matrix = self.range_mat(row_tiled, square)
     col_matrix = self.range_mat(col_tiled, square)

     distances = self.sub(row_matrix, col_matrix)
     clipped = C.clip_by_value(distances,
                               self._min_relative_position,
                               self._max_relative_position)
     # Shift values to be >= 0. Each integer still uniquely identifies a
     # relative position difference.
     return clipped + self._max_relative_position
 def construct(self, input_tensor, output_weights, positions):
     """
     Compute log-probabilities for the entries selected by ``positions``.

     Args:
         input_tensor: Sequence output; flattened to
             ``self.shape_flat_sequence_tensor`` before the gather.
         output_weights: Weight matrix multiplied with the transformed
             hidden states via ``self.matmul``.
         positions: Indices to gather; offset per leading-dimension entry
             before being flattened to ``self.last_idx``.

     Returns:
         The output of ``self.log_softmax`` applied to the biased logits.
     """
     # One offset per entry of the leading dimension of input_tensor.
     # NOTE(review): presumably the start of each sequence inside the
     # flattened tensor -- confirm against self.seq_length_tensor.
     leading_index = F.tuple_to_array(F.make_range(P.Shape()(input_tensor)[0]))
     offsets = self.reshape(leading_index * self.seq_length_tensor,
                            self.shape_flat_offsets)
     flat_indices = self.reshape(positions + offsets, self.last_idx)
     flat_sequence = self.reshape(input_tensor,
                                  self.shape_flat_sequence_tensor)

     selected = self.gather(flat_sequence, flat_indices, 0)
     selected = self.cast(selected, self.compute_type)
     output_weights = self.cast(output_weights, self.compute_type)

     transformed = self.layernorm(self.dense(selected))
     logits = self.cast(self.matmul(transformed, output_weights), self.dtype)
     logits = logits + self.output_bias
     return self.log_softmax(logits)
示例#7
0
def _check_axis_valid(axes, ndim):
    """
    Checks axes are valid given ndim, and returns axes that can be passed
    to the built-in operator (non-negative, int or tuple).

    Args:
        axes (Union[None, int, tuple]): Axis or axes to validate. None selects
            every axis.
        ndim (int): Number of dimensions the axes refer to.

    Returns:
        The result of ``F.make_range(ndim)`` when ``axes`` is None, otherwise
        a tuple with each axis reduced modulo ``ndim``.

    Raises:
        ValueError: If the normalized axes contain a duplicate, or if ``axes``
            is neither None, an int, nor a tuple. Range errors are delegated
            to ``_check_axis_in_range``.
    """
    if axes is None:
        return F.make_range(ndim)
    if isinstance(axes, int):
        _ = _check_axis_in_range(axes, ndim)
        return (axes % ndim, )
    if isinstance(axes, tuple):
        for axis in axes:
            _ = _check_axis_in_range(axis, ndim)
        axes = tuple(map(lambda x: x % ndim, axes))
        # Duplicates are detected after normalization, so e.g. -1 and ndim-1 clash.
        if any(axes.count(el) > 1 for el in axes):
            raise ValueError('duplicate value in \'axis\'')
        return axes
    # Unsupported types used to fall through to the "duplicate value" error,
    # which was misleading. ValueError is kept so existing handlers still work.
    raise ValueError('\'axis\' must be None, an int or a tuple of ints')
示例#8
0
 def get_axis(self, x):
     """Return ``F.make_range(0, ndim)`` for ``x``, where ndim is the length of its shape."""
     ndim = F.tuple_len(F.shape(x))
     return F.make_range(0, ndim)
示例#9
0
def get_axis(x):
    """Return ``F.make_range(0, ndim)`` for ``x``, where ndim is the length of its shape."""
    ndim = F.tuple_len(P.Shape()(x))
    return F.make_range(0, ndim)
示例#10
0
 def construct(self, x):
     """
     Look up embeddings for the positions of ``x``.

     Args:
         x: Input whose ``shape[1]`` is taken as the sequence length. Only
             its shape is read; its values are not used.

     Returns:
         The output of ``self.emb`` for the position ids built from the
         sequence length.
     """
     seq_length = x.shape[1]
     # The position ids are not tiled per batch row, so the lookup is
     # indexed by position only.
     input_position = F.tuple_to_array(F.make_range(seq_length))
     return self.emb(input_position)
示例#11
0
def rollaxis(x, axis, start=0):
    """
    Roll the specified axis backwards, until it lies in the given position.
    The positions of the other axes do not change relative to one another.

    Args:
        x (Tensor): A Tensor to be transposed.
        axis (int): The axis to be rolled.
        start (int):
            - When start >= 0:
                - When start <= axis: the axis is rolled back until it lies in this position (start).
                - When start > axis: the axis is rolled until it lies before this position (start).
            - When start < 0: the start will be normalized as follows:
                start ........... Normalized start
                -(x.ndim+1)       raise ValueError
                -x.ndim           0
                ...               ...
                -1                x.ndim-1
                0                 0
                ...               ...
                x.ndim            x.ndim
                x.ndim+1          raise ValueError

    Returns:
        Transposed Tensor. Has the same data type as the original tensor x.
        When the axis is already in the requested position, x itself is
        returned.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Raises:
        TypeError: If axis or start is not integer.
        ValueError: If axis is not in the range from -ndim to ndim-1 or
            start is not in the range from -ndim to ndim.

    Examples:
        >>> import mindspore
        >>> import mindspore.numpy as mnp
        >>> from mindspore import Tensor
        >>> import numpy as onp
        >>> input_x = Tensor(onp.ones((2,3,4)), mindspore.float32)
        >>> output = mnp.rollaxis(input_x, 0, 2)
        >>> print(output.shape)
        (3, 2, 4)
    """
    _check_is_int(axis)
    _check_is_int(start)

    ndim = F.tuple_len(F.shape(x))

    axis = _check_axes_range(axis, ndim)
    start = _check_start_normalize(start, ndim)

    # start == axis and start == axis + 1 both leave the axis where it is.
    offset = start - axis
    if offset >= 0 and offset <= 1:
        return x

    order = F.make_range(0, ndim)
    moved = order[axis:axis+1]
    if start < axis:
        # Move the axis towards the front: [..start) [axis] [start..axis) (axis..]
        rolled = order[0:start] + moved + order[start:axis]
        if axis + 1 < ndim:
            rolled = rolled + order[axis+1:]
    else:
        # Here start > axis + 1: [..axis) (axis..start) [axis] [start..]
        rolled = order[0:axis] + order[axis+1:start] + moved
        if start < ndim:
            rolled = rolled + order[start:]

    return P.Transpose()(x, rolled)