def construct(self, input_ids, input_mask, input_position=None, attention_mask=None, layer_past=None):
    """PanGu Alpha model"""
    if not self.use_past:
        layer_past = self.past
    input_embedding, embedding_table = self.word_embedding(input_ids)
    if not self.eod_reset:
        batch_size, seq_length = F.shape(input_ids)
        input_position = F.tuple_to_array(F.make_range(seq_length))
        input_position = P.Tile()(input_position, (batch_size, 1))
        attention_mask = self.get_attention_mask(input_mask)
    position_embedding = self.position_embedding(input_position)
    hidden_states = self.add(input_embedding, position_embedding)
    hidden_states = self.dropout(hidden_states)
    hidden_states = P.Cast()(hidden_states, mstype.float16)
    attention_mask = self.expand_dims(attention_mask, 1)
    present_layer = ()
    for i in range(self.num_layers):
        hidden_states, present = self.blocks[i](hidden_states, attention_mask, layer_past)
        present_layer = present_layer + (present,)
    output_state = self.layernorm(hidden_states)
    output_state = F.cast(output_state, self.dtype)
    top_query_hidden_states = self.top_query_embedding(input_position)
    output_state, present = self.top_query_layer(output_state, top_query_hidden_states,
                                                 attention_mask, layer_past)
    present_layer = present_layer + (present,)
    return output_state, present_layer, embedding_table

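# A minimal NumPy sketch (not MindSpore) of the position-id pattern above:
# F.make_range(seq_length) yields the tuple (0, ..., seq_length - 1),
# F.tuple_to_array turns it into a 1-D tensor, and P.Tile replicates it once per
# batch row. The helper name and shapes below are illustrative assumptions.
import numpy as np

def make_position_ids(batch_size, seq_length):
    """Replicate [0, seq_length) across the batch, like P.Tile()(..., (batch_size, 1))."""
    input_position = np.arange(seq_length)           # F.tuple_to_array(F.make_range(seq_length))
    return np.tile(input_position, (batch_size, 1))  # P.Tile()(input_position, (batch_size, 1))

assert make_position_ids(2, 4).shape == (2, 4)
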
def construct(self, input_ids, input_mask, layer_past=None):
    """GPT model"""
    if not self.use_past:
        layer_past = self.past
    input_embedding, embedding_table = self.word_embedding(input_ids)
    batch_size, seq_length = F.shape(input_ids)
    input_position = F.tuple_to_array(F.make_range(seq_length))
    input_position = P.Tile()(input_position, (batch_size, 1))
    position_embedding = self.position_embedding(input_position)
    hidden_states = input_embedding + position_embedding
    hidden_states = P.Cast()(hidden_states, mstype.float16)
    attention_mask = self.get_attention_mask(input_mask)
    attention_mask = P.ExpandDims()(attention_mask, 1)
    present_layer = ()
    for i in range(self.num_layers):
        hidden_states, present = self.blocks[i](hidden_states, attention_mask, layer_past)
        present_layer = present_layer + (present,)
    output_state = self.layernorm(hidden_states)
    return output_state, present_layer, embedding_table

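# get_attention_mask is defined elsewhere in the repo; the NumPy sketch below is an
# assumption based on the usual causal-LM convention (outer product of the padding
# mask with itself, intersected with a lower-triangular matrix), not the verified
# MindSpore implementation. P.ExpandDims(..., 1) then adds the broadcast axis for
# the attention heads.
import numpy as np

def causal_attention_mask(input_mask):
    """input_mask: (batch, seq) of 0/1 -> (batch, 1, seq, seq) attention mask."""
    seq = input_mask.shape[1]
    pairwise = input_mask[:, :, None] * input_mask[:, None, :]  # zero out padded (i, j) pairs
    causal = np.tril(np.ones((seq, seq)))                       # position i attends only to j <= i
    return (pairwise * causal)[:, None, :, :]                   # ExpandDims on axis 1

mask = causal_attention_mask(np.array([[1, 1, 1, 0]]))
assert mask.shape == (1, 1, 4, 4)
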
def transpose(a, axes=None):
    """
    Reverse or permute the axes of an array; returns the modified array.

    Args:
        a (Tensor): a tensor to be transposed
        axes (Union[None, tuple, list]): the axes order, if axes is None, transpose
            the entire tensor. Default is None.

    Returns:
        Tensor, the transposed tensor array.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore.numpy as np
        >>> x = np.ones((1,2,3))
        >>> x = np.transpose(x)
        >>> print(x.shape)
        (3, 2, 1)
    """
    if axes is None:
        shape = F.shape(a)
        length = F.tuple_len(shape)
        perm = F.make_range(0, length)
        new_order = F.tuple_reversed(perm)
        return P.Transpose()(a, new_order)
    axes = _check_shape_compile(axes)
    return P.Transpose()(a, axes)

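# In the default path (axes is None), the F.make_range / F.tuple_reversed pair simply
# reverses the axis order; a plain-NumPy equivalent of that branch:
import numpy as np

a = np.ones((1, 2, 3))
new_order = tuple(range(a.ndim))[::-1]          # F.tuple_reversed(F.make_range(0, ndim))
assert np.transpose(a, new_order).shape == (3, 2, 1)
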
def swapaxes(x, axis1, axis2):
    """
    Interchange two axes of a tensor.

    Args:
        x (Tensor): A Tensor to be transposed.
        axis1 (int): First axis.
        axis2 (int): Second axis.

    Returns:
        Transposed Tensor. Has the same data type as the original tensor x.

    Raises:
        TypeError: If axis1 or axis2 is not integer.
        ValueError: If axis1 or axis2 is not in the range from -ndim to ndim-1.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import mindspore
        >>> import mindspore.numpy as mnp
        >>> from mindspore import Tensor
        >>> import numpy as onp
        >>> input_x = Tensor(onp.ones((2,3,4)), mindspore.float32)
        >>> output = mnp.swapaxes(input_x, 0, 2)
        >>> print(output.shape)
        (4, 3, 2)
    """
    _check_is_int(axis1)
    _check_is_int(axis2)
    shape = F.shape(x)
    ndim = F.tuple_len(shape)
    axes = _check_axes_range((axis1, axis2), ndim)
    axis1, axis2 = axes[0], axes[1]
    if axis1 == axis2:
        return x
    if axis1 > axis2:
        axis1, axis2 = axis2, axis1
    perm = F.make_range(0, ndim)
    new_perm = None
    if axis2 + 1 < ndim:
        new_perm = perm[0:axis1] + perm[axis2:axis2+1] + \
            perm[axis1+1:axis2] + perm[axis1:axis1+1] + perm[axis2+1:]
    else:
        new_perm = perm[0:axis1] + perm[axis2:axis2+1] + \
            perm[axis1+1:axis2] + perm[axis1:axis1+1]
    return P.Transpose()(x, new_perm)

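# A plain-Python sketch of the permutation assembled by swapaxes, checked against
# numpy.swapaxes. The helper name build_swap_perm is hypothetical. In plain Python
# perm[axis2+1:] is simply empty when axis2 is the last axis, so no special case is
# needed here; the explicit else branch above presumably guards graph-mode slicing.
import numpy as np

def build_swap_perm(ndim, axis1, axis2):
    """Identity permutation with axis1 and axis2 exchanged, via the same slice arithmetic."""
    if axis1 > axis2:
        axis1, axis2 = axis2, axis1
    perm = tuple(range(ndim))
    return (perm[0:axis1] + perm[axis2:axis2+1] +
            perm[axis1+1:axis2] + perm[axis1:axis1+1] + perm[axis2+1:])

x = np.ones((2, 3, 4))
assert np.transpose(x, build_swap_perm(x.ndim, 0, 2)).shape == np.swapaxes(x, 0, 2).shape
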
def construct(self):
    """position matrix generator"""
    range_vec_row_out = self.cast(F.tuple_to_array(F.make_range(self._length)), mstype.int32)
    range_vec_col_out = self.range_mat(range_vec_row_out, (self._length, -1))
    tile_row_out = self.tile(range_vec_row_out, (self._length,))
    tile_col_out = self.tile(range_vec_col_out, (1, self._length))
    range_mat_out = self.range_mat(tile_row_out, (self._length, self._length))
    transpose_out = self.range_mat(tile_col_out, (self._length, self._length))
    distance_mat = self.sub(range_mat_out, transpose_out)
    distance_mat_clipped = C.clip_by_value(distance_mat,
                                           self._min_relative_position,
                                           self._max_relative_position)
    # Shift values to be >=0. Each integer still uniquely identifies a
    # relative position difference.
    final_mat = distance_mat_clipped + self._max_relative_position
    return final_mat

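# A NumPy sketch of the same matrix: pairwise position differences distance[i, j] = j - i,
# clipped to [min_relative_position, max_relative_position] and shifted so the smallest
# bucket index is 0 (the usual relative-position-representation scheme). The concrete
# bounds below are illustrative assumptions.
import numpy as np

def relative_position_matrix(length, min_rel, max_rel):
    pos = np.arange(length)
    distance = pos[None, :] - pos[:, None]   # distance[i, j] = j - i
    clipped = np.clip(distance, min_rel, max_rel)
    return clipped + max_rel                 # shift values to be >= 0

print(relative_position_matrix(4, -2, 2))
# [[2 3 4 4]
#  [1 2 3 4]
#  [0 1 2 3]
#  [0 0 1 2]]
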
def construct(self, input_tensor, output_weights, positions):
    """Get output log_probs"""
    rng = F.tuple_to_array(F.make_range(P.Shape()(input_tensor)[0]))
    flat_offsets = self.reshape(rng * self.seq_length_tensor, self.shape_flat_offsets)
    flat_position = self.reshape(positions + flat_offsets, self.last_idx)
    flat_sequence_tensor = self.reshape(input_tensor, self.shape_flat_sequence_tensor)
    input_tensor = self.gather(flat_sequence_tensor, flat_position, 0)
    input_tensor = self.cast(input_tensor, self.compute_type)
    output_weights = self.cast(output_weights, self.compute_type)
    input_tensor = self.dense(input_tensor)
    input_tensor = self.layernorm(input_tensor)
    logits = self.matmul(input_tensor, output_weights)
    logits = self.cast(logits, self.dtype)
    logits = logits + self.output_bias
    log_probs = self.log_softmax(logits)
    return log_probs

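# A NumPy sketch of the flattened gather above: to select the masked positions out of
# a (batch, seq, hidden) tensor with a single 1-D gather, each row's positions are
# offset by row_index * seq_length before the tensor is flattened to (batch*seq, hidden).
# All shapes and values below are illustrative assumptions.
import numpy as np

batch, seq, hidden = 2, 8, 4
input_tensor = np.arange(batch * seq * hidden, dtype=float).reshape(batch, seq, hidden)
positions = np.array([[1, 5], [0, 7]])                   # masked positions per batch row

flat_offsets = (np.arange(batch) * seq).reshape(-1, 1)   # rng * self.seq_length_tensor
flat_position = (positions + flat_offsets).reshape(-1)   # indices into the flat batch*seq axis
flat_sequence = input_tensor.reshape(-1, hidden)         # shape_flat_sequence_tensor
gathered = flat_sequence[flat_position]                  # self.gather(..., flat_position, 0)

assert np.array_equal(gathered[0], input_tensor[0, 1])
assert np.array_equal(gathered[3], input_tensor[1, 7])
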
def _check_axis_valid(axes, ndim):
    """
    Checks axes are valid given ndim, and returns axes that can be passed
    to the built-in operator (non-negative, int or tuple)
    """
    if isinstance(axes, int):
        _ = _check_axis_in_range(axes, ndim)
        return (axes % ndim,)
    if isinstance(axes, tuple):
        for axis in axes:
            _ = _check_axis_in_range(axis, ndim)
        axes = tuple(map(lambda x: x % ndim, axes))
        if all(axes.count(el) <= 1 for el in axes):
            return axes
    if axes is None:
        axes = F.make_range(ndim)
        return axes
    raise ValueError('duplicate value in \'axis\'')

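# A pure-Python sketch of the same normalization rules (the range validation done by
# _check_axis_in_range is elided here): an int becomes a 1-tuple, negative axes wrap
# with % ndim, None expands to all axes, and duplicates raise. The sketch name is
# hypothetical.
def check_axis_valid_sketch(axes, ndim):
    if isinstance(axes, int):
        return (axes % ndim,)
    if isinstance(axes, tuple):
        axes = tuple(ax % ndim for ax in axes)
        if all(axes.count(el) <= 1 for el in axes):
            return axes
    if axes is None:
        return tuple(range(ndim))
    raise ValueError("duplicate value in 'axis'")

assert check_axis_valid_sketch(-1, 3) == (2,)
assert check_axis_valid_sketch((0, -1), 3) == (0, 2)
assert check_axis_valid_sketch(None, 3) == (0, 1, 2)
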
def get_axis(self, x):
    """Return the tuple of all axis indices of x, i.e. (0, ..., ndim - 1)."""
    shape = F.shape(x)
    length = F.tuple_len(shape)
    perm = F.make_range(0, length)
    return perm

def get_axis(x):
    """Return the tuple of all axis indices of x, i.e. (0, ..., ndim - 1)."""
    shape_op = P.Shape()
    shape = shape_op(x)
    length = F.tuple_len(shape)
    perm = F.make_range(0, length)
    return perm

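# Both get_axis variants return the full axes tuple (0, ..., ndim - 1), typically fed
# to a reduction op so it contracts every dimension; in NumPy terms:
import numpy as np

x = np.ones((2, 3, 4))
axes = tuple(range(x.ndim))        # what get_axis builds via F.make_range
assert axes == (0, 1, 2)
assert x.sum(axis=axes) == 24.0    # reducing over all axes yields a scalar
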
def construct(self, x):
    batch_size, seq_length = x.shape[0], x.shape[1]
    input_position = F.tuple_to_array(F.make_range(seq_length))
    # Tiling across the batch is disabled; the (seq_length,) position vector is
    # embedded once and broadcast, leaving batch_size unused here.
    # input_position = P.Tile()(input_position, (batch_size, 1))
    return self.emb(input_position)

def rollaxis(x, axis, start=0):
    """
    Roll the specified axis backwards, until it lies in the given position.
    The positions of the other axes do not change relative to one another.

    Args:
        x (Tensor): A Tensor to be transposed.
        axis (int): The axis to be rolled.
        start (int):
            - When start >= 0:
                - When start <= axis: the axis is rolled back until it lies in
                  this position (start).
                - When start > axis: the axis is rolled until it lies before this
                  position (start).
            - When start < 0: the start will be normalized as follows:

              =============   ================
              start           Normalized start
              =============   ================
              -(x.ndim+1)     raise ValueError
              -x.ndim         0
              ...             ...
              -1              x.ndim-1
              0               0
              ...             ...
              x.ndim          x.ndim
              x.ndim+1        raise ValueError
              =============   ================

    Returns:
        Transposed Tensor. Has the same data type as the original tensor x.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Raises:
        TypeError: If axis or start is not integer.
        ValueError: If axis is not in the range from -ndim to ndim-1 or
            start is not in the range from -ndim to ndim.

    Examples:
        >>> import mindspore
        >>> import mindspore.numpy as mnp
        >>> from mindspore import Tensor
        >>> import numpy as onp
        >>> input_x = Tensor(onp.ones((2,3,4)), mindspore.float32)
        >>> output = mnp.rollaxis(input_x, 0, 2)
        >>> print(output.shape)
        (3, 2, 4)
    """
    _check_is_int(axis)
    _check_is_int(start)
    shape = F.shape(x)
    ndim = F.tuple_len(shape)
    axis = _check_axes_range(axis, ndim)
    start = _check_start_normalize(start, ndim)
    if start - axis >= 0 and start - axis <= 1:
        return x
    perm = F.make_range(0, ndim)
    new_perm = None
    if start < axis:
        if axis + 1 < ndim:
            new_perm = perm[0:start] + perm[axis:axis+1] + \
                perm[start:axis] + perm[axis+1:]
        else:
            new_perm = perm[0:start] + perm[axis:axis+1] + perm[start:axis]
    if start > axis:
        if start < ndim:
            new_perm = perm[0:axis] + perm[axis+1:start] + \
                perm[axis:axis+1] + perm[start:]
        else:
            new_perm = perm[0:axis] + perm[axis+1:start] + \
                perm[axis:axis+1]
    return P.Transpose()(x, new_perm)

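# A plain-Python check of the start > axis branch of rollaxis (the docstring example:
# axis=0, start=2, ndim=3) against numpy.rollaxis. The helper name rollaxis_perm is
# hypothetical.
import numpy as np

def rollaxis_perm(ndim, axis, start):
    """Rebuild the permutation from the slice arithmetic above (start > axis, start < ndim)."""
    perm = tuple(range(ndim))
    return perm[0:axis] + perm[axis+1:start] + perm[axis:axis+1] + perm[start:]

x = np.ones((2, 3, 4))
assert np.transpose(x, rollaxis_perm(3, 0, 2)).shape == np.rollaxis(x, 0, 2).shape == (3, 2, 4)
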