Example #1
 def __init__(self, net_config):
     super(CenterNetMultiPoseLossCell, self).__init__()
     self.network = GatherMultiPoseFeatureCell(net_config)
     self.reduce_sum = ops.ReduceSum()
     self.crit = FocalLoss()
     self.crit_hm_hp = nn.MSELoss() if net_config.mse_loss else self.crit
     self.crit_kp = RegWeightedL1Loss(
     ) if not net_config.dense_hp else nn.L1Loss(reduction='sum')
     self.crit_reg = RegLoss(net_config.reg_loss)
     self.hm_weight = net_config.hm_weight
     self.hm_hp_weight = net_config.hm_hp_weight
     self.hp_weight = net_config.hp_weight
     self.wh_weight = net_config.wh_weight
     self.off_weight = net_config.off_weight
     self.hm_hp = net_config.hm_hp
     self.dense_hp = net_config.dense_hp
     self.reg_offset = net_config.reg_offset
     self.reg_hp_offset = net_config.reg_hp_offset
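     # output-head index chaining: each optional head (hm_hp, reg_offset, reg_hp_offset)
     # shifts the indices of the heads that follow it by one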
     self.hm_hp_ind = 3 if self.hm_hp else 2
     self.reg_ind = self.hm_hp_ind + 1 if self.reg_offset else self.hm_hp_ind
     self.reg_hp_ind = self.reg_ind + 1 if self.reg_hp_offset else self.reg_ind
     # used only for debugging checks
     self.print = ops.Print()
     self.concat = ops.Concat(axis=1)
     self.reshape = ops.Reshape()
Example #2
 def __init__(self):
     super(GatherFlipFeature, self).__init__()
     self.gather_nd = ops.GatherNd()
     self.transpose = ops.Transpose()
     self.perm_list = (1, 0, 2, 3)
     self.shape = ops.Shape()
     self.reshape = ops.Reshape()
Example #3
 def __init__(self):
     super(TransposeGatherFeature, self).__init__()
     self.shape = ops.Shape()
     self.reshape = ops.Reshape()
     self.transpose = ops.Transpose()
     self.perm_list = (0, 2, 3, 1)
     self.gather_feat = GatherFeature()
Example #4
    def __init__(self,
                 length,
                 depth,
                 max_relative_position,
                 initializer_range,
                 use_one_hot_embeddings=False):
        super(RelaPosEmbeddingsGenerator, self).__init__()
        self.depth = depth
        self.vocab_size = max_relative_position * 2 + 1
        self.use_one_hot_embeddings = use_one_hot_embeddings

        self.embeddings_table = Parameter(
            initializer(TruncatedNormal(initializer_range),
                        [self.vocab_size, self.depth]),
            name='embeddings_for_position')

        self.relative_positions_matrix = RelaPosMatrixGenerator(length=length,
                                                                max_relative_position=max_relative_position)
        self.reshape = ops.Reshape()
        self.one_hot = ops.OneHot()
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.shape = ops.Shape()
        self.gather = ops.GatherV2()  # index_select
        self.matmul = ops.BatchMatMul()
Example #5
def pixel_shuffle(tensor, scale_factor):
    """
    Implementation of pixel shuffle using numpy

    Parameters:
    -----------
    tensor: input tensor, shape is [N, C, H, W]
    scale_factor: scale factor to up-sample tensor

    Returns:
    --------
    tensor: tensor after pixel shuffle, shape is [N, C/(r*r), r*H, r*W],
        where r refers to scale factor
    """
    num, ch, height, width = tensor.shape
    # assert ch % (scale_factor * scale_factor) == 0

    new_ch = ch // (scale_factor * scale_factor)
    new_height = height * scale_factor
    new_width = width * scale_factor

    reshape = ops.Reshape()
    tensor = reshape(tensor,
                     (num, new_ch, scale_factor, scale_factor, height, width))
    # new axis: [num, new_ch, height, scale_factor, width, scale_factor]
    transpose = ops.Transpose()
    tensor = transpose(tensor, (0, 1, 4, 2, 5, 3))
    tensor = reshape(tensor, (num, new_ch, new_height, new_width))
    return tensor
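A quick usage sketch (assuming MindSpore in PyNative mode, with the pixel_shuffle function above in scope; the input values are hypothetical):

import numpy as np
from mindspore import Tensor

x = Tensor(np.arange(1 * 8 * 2 * 2, dtype=np.float32).reshape(1, 8, 2, 2))
y = pixel_shuffle(x, 2)  # channels shrink by r*r, spatial dims grow by r
print(y.shape)  # (1, 2, 4, 4)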
Example #6
    def __init__(self, bins=10, momentum=0.0, mu=0.02):
        super(GHMRLoss, self).__init__()
        self.bins = bins
        self.momentum = momentum
        self.mu = mu
        edges_left = np.array([float(x) / bins for x in range(bins)], dtype=np.float32)
        self.edges_left = Tensor(edges_left.reshape((bins, 1, 1, 1, 1)))
        edges_right = np.array([float(x) / bins for x in range(1, bins + 1)], dtype=np.float32)
        edges_right[-1] += 1e-4
        self.edges_right = Tensor(edges_right.reshape((bins, 1, 1, 1, 1)))

        if momentum >= 0:
            self.acc_sum = Parameter(initializer(0, [bins], mstype.float32))

        self.abs = ops.Abs()
        self.sqrt = ops.Sqrt()
        self.cast = ops.Cast()
        self.select = ops.Select()
        self.reshape = ops.Reshape()
        self.reduce_sum = ops.ReduceSum()
        self.max = ops.Maximum()
        self.less = ops.Less()
        self.equal = ops.Equal()
        self.greater = ops.Greater()
        self.logical_and = ops.LogicalAnd()
        self.greater_equal = ops.GreaterEqual()
        self.zeros_like = ops.ZerosLike()
        self.expand_dims = ops.ExpandDims()
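edges_left and edges_right split [0, 1] into equal-width intervals, one per bin, with the last right edge nudged by 1e-4 so a value of exactly 1.0 still lands in the final bin. A plain-numpy sketch of the same bin assignment, using a hypothetical gradient norm:

import numpy as np

bins = 10
edges_left = np.array([x / bins for x in range(bins)], dtype=np.float32)
edges_right = np.array([x / bins for x in range(1, bins + 1)], dtype=np.float32)
edges_right[-1] += 1e-4

g = 0.37  # hypothetical gradient norm in [0, 1]
bin_idx = int(np.flatnonzero((g >= edges_left) & (g < edges_right))[0])
print(bin_idx)  # 3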
Example #7
    def __init__(self, config):
        super(CreateAttentionMaskFromInputMask, self).__init__()
        self.input_mask = None

        self.cast = P.Cast()
        self.reshape = P.Reshape()
        self.shape = (-1, 1, config.seq_length)
Example #8
 def __init__(self, ks):
     super(RegenerateFeatureMap, self).__init__()
     self.ks = ks
     self.shape = ops.Shape()
     self.reshape = ops.Reshape()
     self.split = ops.Split(axis=-1, output_num=ks)
     self.concat = ops.Concat(axis=2)
Example #9
    def __init__(self, config):
        super(GetMaskedLMOutput, self).__init__()
        self.width = config.hidden_size
        self.reshape = ops.Reshape()
        self.gather = ops.GatherV2()

        weight_init = TruncatedNormal(config.initializer_range)
        self.dense = nn.Dense(self.width,
                              config.hidden_size,
                              weight_init=weight_init,
                              activation=config.hidden_act).to_float(config.compute_type)
        self.layernorm = nn.LayerNorm((config.hidden_size,)).to_float(config.compute_type)
        self.output_bias = Parameter(
            initializer(
                'zero',
                config.vocab_size),
            name='output_bias')
        self.matmul = ops.MatMul(transpose_b=True)
        self.log_softmax = nn.LogSoftmax(axis=-1)
        self.shape_flat_offsets = (-1, 1)
        self.rng = Tensor(np.array(range(0, config.batch_size)).astype(np.int32))
        self.last_idx = (-1,)
        self.shape_flat_sequence_tensor = (config.batch_size * config.seq_length, self.width)
        self.seq_length_tensor = Tensor(np.array((config.seq_length,)).astype(np.int32))
        self.cast = ops.Cast()
        self.compute_type = config.compute_type
        self.dtype = config.dtype
Example #10
 def __init__(self):
     super(GatherTopKChannel, self).__init__()
     self.shape = ops.Shape()
     self.reshape = ops.Reshape()
     self.topk = ops.TopK(sorted=True)
     self.cast = ops.Cast()
     self.dtype = ops.DType()
     self.mod = ops.Mod()
     self.div = ops.Div()
Example #11
    def __init__(self, length, max_relative_position):
        super(RelaPosMatrixGenerator, self).__init__()
        self._length = length
        self._max_relative_position = max_relative_position
        self._min_relative_position = -max_relative_position
        self.range_length = -length + 1

        self.tile = P.Tile()
        self.range_mat = P.Reshape()
        self.sub = P.Sub()
        self.expanddims = P.ExpandDims()
        self.cast = P.Cast()
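What this cell computes, shown in plain numpy (a sketch for length=4 and max_relative_position=2): the matrix of pairwise token distances, clipped to the allowed range and shifted so every entry is a valid index into the relative-position embedding table:

import numpy as np

length, max_rel = 4, 2
r = np.arange(length)
distance = r[None, :] - r[:, None]              # pairwise j - i
clipped = np.clip(distance, -max_rel, max_rel)  # clamp to [-max_rel, max_rel]
indices = clipped + max_rel                     # shift into [0, 2 * max_rel]
print(indices)  # row i holds the relative-position ids seen from token i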
Example #12
 def __init__(self):
     super(GetSurroundFeature, self).__init__()
     self.shape = ops.Shape()
     self.concat = ops.Concat(axis=1)
     self.reshape = ops.Reshape()
     self.half = ops.Split(axis=-1, output_num=2)
     self.tile = ops.Tile()
     self.gather_nd = ops.GatherNd()
     self.transpose = ops.Transpose()
     self.perm_list = (0, 2, 3, 1)
     self.order_list = (0, 3, 1, 2)
     self.expand_dims = ops.ExpandDims()
Example #13
    def __init__(self, length, max_relative_position):
        super(RelaPosMatrixGenerator, self).__init__()
        self._length = length
        self._max_relative_position = Tensor(max_relative_position, dtype=mstype.int32)
        self._min_relative_position = Tensor(-max_relative_position, dtype=mstype.int32)
        self.range_length = -length + 1

        self.tile = ops.Tile()
        self.range_mat = ops.Reshape()
        self.sub = ops.Sub()
        self.expanddims = ops.ExpandDims()
        self.cast = ops.Cast()
Example #14
 def __init__(self, config):
     super(BertPretrainingLoss, self).__init__()
     self.vocab_size = config.vocab_size
     self.onehot = ops.OneHot()
     self.on_value = Tensor(1.0, mstype.float32)
     self.off_value = Tensor(0.0, mstype.float32)
     self.reduce_sum = ops.ReduceSum()
     self.reduce_mean = ops.ReduceMean()
     self.reshape = ops.Reshape()
     self.last_idx = (-1,)
     self.neg = ops.Neg()
     self.cast = ops.Cast()
Example #15
 def __init__(self, enable_cpu_gather=True):
     super(GatherFeature, self).__init__()
     self.tile = ops.Tile()
     self.shape = ops.Shape()
     self.concat = ops.Concat(axis=1)
     self.reshape = ops.Reshape()
     self.enable_cpu_gather = enable_cpu_gather
     if self.enable_cpu_gather:
         self.gather_nd = ops.GatherD()
         self.expand_dims = ops.ExpandDims()
     else:
         self.gather_nd = ops.GatherNd()
Example #16
 def __init__(self, num_class=10):
     super(LeNet5, self).__init__()
     self.num_class = num_class
     self.batch_size = 32
     self.conv1 = conv(1, 6, 5)
     self.conv2 = conv(6, 16, 5)
     self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
     self.fc2 = fc_with_initialize(120, 84)
     self.fc3 = fc_with_initialize(84, self.num_class)
     self.relu = nn.ReLU()
     self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
     self.reshape = ops.Reshape()
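A plausible construct for this cell (a sketch following the standard LeNet-5 forward pass; only __init__ is shown above, so the real method may differ in details):

 def construct(self, x):
     x = self.max_pool2d(self.relu(self.conv1(x)))
     x = self.max_pool2d(self.relu(self.conv2(x)))
     x = self.reshape(x, (self.batch_size, -1))
     x = self.relu(self.fc1(x))
     x = self.relu(self.fc2(x))
     return self.fc3(x)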
Example #17
 def __init__(self, mode='l1'):
     super(RegLoss, self).__init__()
     self.reduce_sum = ops.ReduceSum()
     self.cast = ops.Cast()
     self.expand_dims = ops.ExpandDims()
     self.reshape = ops.Reshape()
     self.gather_feature = TransposeGatherFeature()
     if mode == 'l1':
         self.loss = nn.L1Loss(reduction='sum')
     elif mode == 'sl1':
         self.loss = nn.SmoothL1Loss()
     else:
         self.loss = None
Example #18
 def __init__(self, begin, stride):
     super(GetOffsetPosition, self).__init__()
     self.begin = begin
     self.stride = stride
     self.meshgrid = ops.Meshgrid()
     self.shape = ops.Shape()
     self.reshape = ops.Reshape()
     self.cat_a0 = ops.Concat(axis=0)
     self.cat_a1 = ops.Concat(axis=1)
     self.tile = ops.Tile()
     self.dtype = ops.DType()
     self.range = nn.Range(-self.begin, self.begin + 1)
     self.cast = ops.Cast()
Example #19
 def __init__(self, net_config, K=100, enable_nms_fp16=True):
     super(DetectionDecode, self).__init__()
     self.K = K
     self.nms = NMS(enable_nms_fp16=enable_nms_fp16)
     self.shape = ops.Shape()
     self.gather_topk = GatherTopK()
     self.half = ops.Split(axis=-1, output_num=2)
     self.add = ops.TensorAdd()
     self.concat_a2 = ops.Concat(axis=2)
     self.trans_gather_feature = TransposeGatherFeature()
     self.expand_dims = ops.ExpandDims()
     self.reshape = ops.Reshape()
     self.reg_offset = net_config.reg_offset
     self.Sigmoid = nn.Sigmoid()
Example #20
    def construct(self, input_ids, token_type_id, input_mask):
        """construct BertPoetryModel"""
        input_shape = ops.Shape()(input_mask)
        shape_right = (input_shape[0], 1, input_shape[1])
        shape_left = input_shape + (1,)
        input_mask = self.cast(input_mask, mstype.float32)
        mask_left = self.reshape(input_mask, shape_left)
        mask_right = self.reshape(input_mask, shape_right)
        attention_mask = self.batch_matmul(mask_left, mask_right)
        attention_mask = self.multiply(attention_mask, self.lower_triangle_mask)


        sequence_output, _, embedding_tables = self.bert(input_ids, token_type_id, attention_mask)
        bert_output = ops.Reshape()(sequence_output, (-1, self.hidden_size))
        MLM_output = self.MLM_Dense(bert_output)
        MLM_output = self.layer_norm(MLM_output)
        embedding_tables = ops.Cast()(embedding_tables, mstype.float16)
        output = self.matmul(MLM_output, embedding_tables)
        output = ops.Cast()(output, mstype.float32)
        output = output + self.biasadd
        output = ops.Reshape()(output, (-1, self.seq_length, self.num_tokens))

        logits = self.softmax(output)
        return logits
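The mask construction above is the outer product of the padding mask with itself, gated by a lower-triangular matrix so each token attends only to itself and to earlier, non-padded positions. A small numpy equivalent (hypothetical 4-token batch with one padding position):

import numpy as np

input_mask = np.array([[1, 1, 1, 0]], dtype=np.float32)  # last position is padding
outer = input_mask[:, :, None] @ input_mask[:, None, :]  # (1, 4, 4) padding mask
causal = np.tril(np.ones((1, 4, 4), dtype=np.float32))   # lower triangle
attention_mask = outer * causal
print(attention_mask[0])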
Example #21
 def __init__(self, model, config, is_training, dropout_prob=0.0, use_one_hot_embeddings=False):
     super(BertPoetry, self).__init__(auto_prefix=False)
     self.num_tokens = 3191
     self.poetry = model
     self.onehot = ops.OneHot()
     self.on_value = Tensor(1.0, mstype.float32)
     self.off_value = Tensor(0.0, mstype.float32)
     self.reduce_sum = ops.ReduceSum()
     self.reduce_mean = ops.ReduceMean()
     self.reshape = ops.Reshape()
     self.neg = ops.Neg()
     self.cast = ops.Cast()
     self.last_idx = (-1,)
     self.log = ops.Log()
     self.max = ops.ArgMaxWithValue(axis=-1)
Example #22
    def __init__(self, config):
        super(CreateAttentionMaskFromInputMask, self).__init__()
        self.input_mask_from_dataset = config.input_mask_from_dataset
        self.input_mask = None

        if not self.input_mask_from_dataset:
            self.input_mask = initializer(
                "ones", [config.batch_size, config.seq_length], mstype.int32).to_tensor()

        self.cast = ops.Cast()
        self.reshape = ops.Reshape()
        self.shape = (config.batch_size, 1, config.seq_length)
        self.broadcast_ones = initializer(
            "ones", [config.batch_size, config.seq_length, 1], mstype.float32).to_tensor()
        self.batch_matmul = ops.BatchMatMul()
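A sketch of how these pieces are typically combined in construct (assuming the usual BERT flow of broadcasting a ones column against the reshaped row mask; the actual method is not shown in this snippet):

    def construct(self, input_mask):
        if not self.input_mask_from_dataset:
            input_mask = self.input_mask
        input_mask = self.cast(self.reshape(input_mask, self.shape), mstype.float32)
        attention_mask = self.batch_matmul(self.broadcast_ones, input_mask)
        return attention_mask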
Example #23
 def __init__(self, network, optimizer, scale_sense):
     super(TrainAccuStepsWithLossScaleCell,
           self).__init__(network, optimizer, scale_sense)
     self.accumulation = False
     self.accumulation_steps = context.get_auto_parallel_context(
         "grad_accumulation_step")
     self.one = Tensor(np.array([1]).astype(np.int32))
     self.zero = Tensor(np.array([0]).astype(np.int32))
     self.accu_grads = self.weights.clone(prefix="accu_grads", init='zeros')
     self.accu_overflow = Parameter(initializer(0, [1], mstype.int32))
     self.accu_loss = Parameter(initializer(0, [1], mstype.float32))
     self.cast = ops.Cast()
     self.logical_or = ops.LogicalOr()
     self.not_equal = ops.NotEqual()
     self.select = ops.Select()
     self.reshape = ops.Reshape()
Example #24
    def __init__(self,
                 batch_size,
                 seq_length,
                 hidden_size,
                 num_attention_heads=12,
                 attention_probs_dropout_prob=0.1,
                 use_one_hot_embeddings=False,
                 initializer_range=0.02,
                 hidden_dropout_prob=0.1,
                 use_relative_positions=False,
                 compute_type=mstype.float32,
                 enable_fused_layernorm=False):
        super(BertSelfAttention, self).__init__()
        if hidden_size % num_attention_heads != 0:
            raise ValueError("The hidden size (%d) is not a multiple of the number "
                             "of attention heads (%d)" % (hidden_size, num_attention_heads))

        self.size_per_head = int(hidden_size / num_attention_heads)

        self.attention = BertAttention(
            batch_size=batch_size,
            from_tensor_width=hidden_size,
            to_tensor_width=hidden_size,
            from_seq_length=seq_length,
            to_seq_length=seq_length,
            num_attention_heads=num_attention_heads,
            size_per_head=self.size_per_head,
            attention_probs_dropout_prob=attention_probs_dropout_prob,
            use_one_hot_embeddings=use_one_hot_embeddings,
            initializer_range=initializer_range,
            use_relative_positions=use_relative_positions,
            has_attention_mask=True,
            do_return_2d_tensor=True,
            compute_type=compute_type)

        self.output = BertOutput(in_channels=hidden_size,
                                 out_channels=hidden_size,
                                 initializer_range=initializer_range,
                                 dropout_prob=hidden_dropout_prob,
                                 compute_type=compute_type,
                                 enable_fused_layernorm=enable_fused_layernorm)
        self.reshape = ops.Reshape()
        self.shape = (-1, hidden_size)
Example #25
 def __init__(self, num_classes=256, log_scale_min=-7.0, reduce=True):
     super(discretized_mix_logistic_loss, self).__init__()
     self.num_classes = num_classes
     self.log_scale_min = log_scale_min
     self.reduce = reduce
     self.transpose_op = P.Transpose()
     self.exp = P.Exp()
     self.sigmoid = P.Sigmoid()
     self.softplus = Stable_softplus()
     self.log = P.Log()
     self.cast = P.Cast()
     self.logsoftmax = P.LogSoftmax(-1)
     self.expand_dims = P.ExpandDims()
     self.tile = P.Tile()
     self.maximum = P.Maximum()
     self.sums = P.ReduceSum()
     self.lse = log_sum_exp()
     self.reshape = P.Reshape()
     self.factor = self.log(Tensor((self.num_classes - 1) / 2, ms.float32))
Example #26
    def __init__(self,
                 batch_size,
                 hidden_size,
                 seq_length,
                 num_hidden_layers,
                 num_attention_heads=12,
                 intermediate_size=3072,
                 attention_probs_dropout_prob=0.1,
                 use_one_hot_embeddings=False,
                 initializer_range=0.02,
                 hidden_dropout_prob=0.1,
                 use_relative_positions=False,
                 hidden_act="gelu",
                 compute_type=mstype.float32,
                 return_all_encoders=False,
                 enable_fused_layernorm=False):
        super(BertTransformer, self).__init__()
        self.return_all_encoders = return_all_encoders

        layers = []
        for _ in range(num_hidden_layers):
            layer = BertEncoderCell(batch_size=batch_size,
                                    hidden_size=hidden_size,
                                    seq_length=seq_length,
                                    num_attention_heads=num_attention_heads,
                                    intermediate_size=intermediate_size,
                                    attention_probs_dropout_prob=attention_probs_dropout_prob,
                                    use_one_hot_embeddings=use_one_hot_embeddings,
                                    initializer_range=initializer_range,
                                    hidden_dropout_prob=hidden_dropout_prob,
                                    use_relative_positions=use_relative_positions,
                                    hidden_act=hidden_act,
                                    compute_type=compute_type,
                                    enable_fused_layernorm=enable_fused_layernorm)
            layers.append(layer)

        self.layers = nn.CellList(layers)

        self.reshape = ops.Reshape()
        self.shape = (-1, hidden_size)
        self.out_shape = (batch_size, seq_length, hidden_size)
Example #27
 def __init__(self,
              vocab_size,
              embedding_size,
              embedding_shape,
              use_one_hot_embeddings=False,
              initializer_range=0.02):
     super(EmbeddingLookup, self).__init__()
     self.vocab_size = vocab_size
     self.use_one_hot_embeddings = use_one_hot_embeddings
     self.embedding_table = Parameter(initializer
                                      (TruncatedNormal(initializer_range),
                                       [vocab_size, embedding_size]))
     self.expand = P.ExpandDims()
     self.shape_flat = (-1,)
     self.gather = P.Gather()
     self.one_hot = P.OneHot()
     self.on_value = Tensor(1.0, mstype.float32)
     self.off_value = Tensor(0.0, mstype.float32)
     self.array_mul = P.MatMul()
     self.reshape = P.Reshape()
     self.shape = tuple(embedding_shape)
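A sketch of the matching construct (following the common BERT embedding-lookup pattern of a one-hot matmul path versus a plain gather path; only __init__ is shown above, so this is an assumption):

 def construct(self, input_ids):
     flat_ids = self.reshape(input_ids, self.shape_flat)
     if self.use_one_hot_embeddings:
         one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value)
         output_for_reshape = self.array_mul(one_hot_ids, self.embedding_table)
     else:
         output_for_reshape = self.gather(self.embedding_table, flat_ids, 0)
     output = self.reshape(output_for_reshape, self.shape)
     return output, self.embedding_table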
Example #28
    def __init__(self, network, optimizer, scale_update_cell=None):

        super(BertPoetryCell, self).__init__(network, optimizer, scale_update_cell)
        self.network = network
        self.weights = ParameterTuple(network.trainable_params())
        self.optimizer = optimizer
        self.grad = ops.GradOperation(
            get_by_list=True,
            sens_param=True)
        self.reducer_flag = False
        self.allreduce = ops.AllReduce()
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
            self.reducer_flag = True
        self.grad_reducer = None
        if self.reducer_flag:
            mean = context.get_auto_parallel_context("mirror_mean")
            degree = get_group_size()
            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
        self.cast = ops.Cast()
        self.gpu_target = False
        if context.get_context("device_target") == "GPU":
            self.gpu_target = True
            self.float_status = ops.FloatStatus()
            self.addn = ops.AddN()
            self.reshape = ops.Reshape()
        else:
            self.alloc_status = ops.NPUAllocFloatStatus()
            self.get_status = ops.NPUGetFloatStatus()
            self.clear_before_grad = ops.NPUClearFloatStatus()
        self.reduce_sum = ops.ReduceSum(keep_dims=False)
        self.base = Tensor(1, mstype.float32)
        self.less_equal = ops.LessEqual()
        self.hyper_map = ops.HyperMap()
        self.loss_scale = None
        self.loss_scaling_manager = scale_update_cell
        if scale_update_cell:
            self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32),
                                        name="loss_scale")
Example #29
 def __init__(self, net_config, K=100, enable_nms_fp16=True):
     super(MultiPoseDecode, self).__init__()
     self.K = K
     self.nms = NMS(enable_nms_fp16=enable_nms_fp16)
     self.shape = ops.Shape()
     self.gather_topk = GatherTopK()
     self.gather_topk_channel = GatherTopKChannel()
     self.gather_by_ind = GatherFeatureByInd()
     self.half = ops.Split(axis=-1, output_num=2)
     self.half_first = ops.Split(axis=0, output_num=2)
     self.split = ops.Split(axis=-1, output_num=4)
     self.flip_lr = FlipLR()
     self.flip_lr_off = FlipLROff()
     self.flip_tensor = FlipTensor()
     self.concat = ops.Concat(axis=1)
     self.concat_a2 = ops.Concat(axis=2)
     self.concat_a3 = ops.Concat(axis=3)
     self.trans_gather_feature = TransposeGatherFeature()
     self.expand_dims = ops.ExpandDims()
     self.reshape = ops.Reshape()
     self.add = ops.TensorAdd()
     self.dtype = ops.DType()
     self.cast = ops.Cast()
     self.thresh = 0.1
     self.transpose = ops.Transpose()
     self.perm_list = (0, 2, 1, 3)
     self.tile = ops.Tile()
     self.greater = ops.Greater()
     self.square = ops.Square()
     self.sqrt = ops.Sqrt()
     self.reduce_sum = ops.ReduceSum()
     self.min = ops.ArgMinWithValue(axis=3)
     self.max = ops.Maximum()
     self.hm_hp = net_config.hm_hp
     self.dense_hp = net_config.dense_hp
     self.reg_offset = net_config.reg_offset
     self.reg_hp_offset = net_config.reg_hp_offset
     self.hm_hp_ind = 3 if self.hm_hp else 2
     self.reg_ind = self.hm_hp_ind + 1 if self.reg_offset else self.hm_hp_ind
     self.reg_hp_ind = self.reg_ind + 1 if self.reg_hp_offset else self.reg_ind
Example #30
 def __init__(self, config, is_training, num_tokens, dropout_prob=0.0, use_one_hot_embeddings=False):
     super(BertPoetryModel, self).__init__()
     self.bert = BertModel(config, is_training, use_one_hot_embeddings)
     self.num_tokens = num_tokens
     idx = np.arange(config.seq_length)
     mask = idx[None, :] <= idx[:, None]
     self.mask = Tensor([mask], mstype.float32)
     self.MLM_Dense = nn.Dense(config.hidden_size, config.hidden_size,
                               has_bias=True, weight_init=TruncatedNormal(0.02),
                               activation='gelu').to_float(mstype.float16)
     self.layer_norm = nn.LayerNorm((config.hidden_size,))
     self.matmul = ops.MatMul(transpose_b=True)
     self.biasadd = Parameter(initializer('zero', self.num_tokens), name='MLM_output_biasadd')
     self.softmax = ops.Softmax(axis=-1)
     self.seq_length = config.seq_length
     self.hidden_size = config.hidden_size
     self.cast = ops.Cast()
     self.reshape = ops.Reshape()
     self.batch_matmul = ops.BatchMatMul()
     ones = np.ones(shape=(config.batch_size, config.seq_length, config.seq_length))
     self.lower_triangle_mask = Tensor(np.tril(ones), dtype=mstype.float32)
     self.multiply = ops.Mul()