def __init__(self,
             vocab_size,
             embedding_size,
             embedding_shape,
             use_one_hot_embeddings=False,
             initializer_range=0.02):
    """Create the embedding table and the primitives used to look rows up in it.

    Args:
        vocab_size (int): Number of rows in the embedding table.
        embedding_size (int): Number of columns (width) of each table row.
        embedding_shape: Any sequence; stored as a tuple in ``self.shape``
            (presumably the shape the looked-up embeddings are reshaped to --
            confirm against ``construct``).
        use_one_hot_embeddings (bool): Stored as-is; presumably selects a
            one-hot x MatMul lookup instead of a Gather lookup -- confirm in
            ``construct``. Default: False.
        initializer_range (float): Argument handed to ``TruncatedNormal`` when
            initializing the table. Default: 0.02.
    """
    super(EmbeddingLookup, self).__init__()
    self.vocab_size = vocab_size
    self.use_one_hot_embeddings = use_one_hot_embeddings

    # Trainable [vocab_size, embedding_size] table with truncated-normal init.
    table_shape = [vocab_size, embedding_size]
    table_init = initializer(TruncatedNormal(initializer_range), table_shape)
    self.embedding_table = Parameter(table_init)

    # Primitives and constants used by the lookup.
    self.expand = P.ExpandDims()
    self.shape_flat = (-1,)  # reshape target that flattens to 1-D
    self.gather = P.Gather()
    # One-hot encoding constants (on = 1.0, off = 0.0) plus a MatMul; presumably
    # the one-hot lookup path -- the actual wiring lives in construct.
    self.one_hot = P.OneHot()
    self.on_value = Tensor(1.0, mstype.float32)
    self.off_value = Tensor(0.0, mstype.float32)
    self.array_mul = P.MatMul()
    self.reshape = P.Reshape()
    self.shape = tuple(embedding_shape)
def __init__(self,
             length,
             depth,
             max_relative_position,
             initializer_range,
             use_one_hot_embeddings=False):
    """Set up a learnable table of relative-position embeddings.

    Args:
        length (int): Forwarded unchanged to ``RelaPosMatrixGenerator``.
        depth (int): Width of each embedding row.
        max_relative_position (int): Forwarded to ``RelaPosMatrixGenerator``;
            the table gets ``2 * max_relative_position + 1`` rows (presumably
            one per relative distance in
            ``[-max_relative_position, max_relative_position]`` -- confirm).
        initializer_range (float): Argument handed to ``TruncatedNormal`` when
            initializing the table.
        use_one_hot_embeddings (bool): Stored as-is; presumably selects the
            one-hot x BatchMatMul path over Gather -- confirm in ``construct``.
            Default: False.
    """
    super(RelaPosEmbeddingsGenerator, self).__init__()
    num_rows = 2 * max_relative_position + 1
    self.depth = depth
    self.vocab_size = num_rows
    self.use_one_hot_embeddings = use_one_hot_embeddings

    # Trainable [num_rows, depth] table with truncated-normal init.
    self.embeddings_table = Parameter(
        initializer(TruncatedNormal(initializer_range), [num_rows, depth]))
    self.relative_positions_matrix = RelaPosMatrixGenerator(
        length=length, max_relative_position=max_relative_position)

    self.reshape = P.Reshape()
    self.one_hot = nn.OneHot(depth=num_rows)
    self.shape = P.Shape()
    self.gather = P.Gather()  # used as an index_select over table rows
    self.matmul = P.BatchMatMul()
def __init__(self,
             embedding_size,
             embedding_shape,
             use_relative_positions=False,
             use_token_type=False,
             token_type_vocab_size=16,
             use_one_hot_embeddings=False,
             initializer_range=0.02,
             max_position_embeddings=512,
             dropout_prob=0.1):
    """Build the token-type / position embedding layers and normalization.

    Args:
        embedding_size (int): Width of the token-type and position embeddings,
            and the normalized dimension of the LayerNorm.
        embedding_shape: 3-element sequence stored as a tuple in
            ``self.shape``; its middle element is taken as the sequence length
            used to build ``self.position_ids``.
        use_relative_positions (bool): Stored as-is. Default: False.
        use_token_type (bool): Stored as-is. Default: False.
        token_type_vocab_size (int): Vocabulary size of the token-type
            embedding. Default: 16.
        use_one_hot_embeddings (bool): Forwarded as ``use_one_hot`` to the
            token-type ``nn.Embedding``. Default: False.
        initializer_range (float): Accepted for signature compatibility but
            not referenced here; both ``nn.Embedding`` layers are created
            without an explicit initializer. Default: 0.02.
        max_position_embeddings (int): Vocabulary size of the position
            embedding. Default: 512.
        dropout_prob (float): Drop probability; ``1 - dropout_prob`` is what is
            passed to ``nn.Dropout`` (its first positional argument is treated
            as a keep probability here). Default: 0.1.
    """
    super(EmbeddingPostprocessor, self).__init__()
    self.use_token_type = use_token_type
    self.token_type_vocab_size = token_type_vocab_size
    self.use_one_hot_embeddings = use_one_hot_embeddings
    self.max_position_embeddings = max_position_embeddings
    self.token_type_embedding = nn.Embedding(
        vocab_size=token_type_vocab_size,
        embedding_size=embedding_size,
        use_one_hot=use_one_hot_embeddings)
    self.shape_flat = (-1,)
    self.one_hot = P.OneHot()
    self.on_value = Tensor(1.0, mstype.float32)
    # Off value of a one-hot encoding must be 0.0 (was 0.1, which would blend
    # 10% of every non-selected row into a one-hot x MatMul lookup).
    self.off_value = Tensor(0.0, mstype.float32)
    self.array_mul = P.MatMul()
    self.reshape = P.Reshape()
    self.shape = tuple(embedding_shape)
    self.dropout = nn.Dropout(1 - dropout_prob)
    self.gather = P.Gather()
    self.use_relative_positions = use_relative_positions
    self.slice = P.StridedSlice()
    # Unpacking enforces a 3-element embedding_shape; the middle one is seq len.
    _, seq, _ = self.shape
    self.full_position_embedding = nn.Embedding(
        vocab_size=max_position_embeddings,
        embedding_size=embedding_size,
        use_one_hot=False)
    self.layernorm = nn.LayerNorm((embedding_size,), epsilon=1e-5)
    # Constant ids 0..seq-1 with shape (1, seq), int32.
    self.position_ids = Tensor(np.arange(seq).reshape(-1, seq).astype(np.int32))
    self.add = P.Add()
def construct(self, input_x):
    """Return ``input_x`` with its elements in reverse order along ``self.dim``.

    The reversal is done by gathering along ``self.dim`` with the descending
    index sequence ``n-1, n-2, ..., 0``, where ``n`` is that axis' length.
    """
    axis = self.dim
    axis_len = input_x.shape[axis]
    # Descending indices; gathering with them flips the chosen axis.
    flip_indices = mnp.arange(axis_len - 1, -1, -1)
    return ops.Gather()(input_x, flip_indices, axis)