コード例 #1
0
    def construct(self, x, seq_lengths):
        """Gather elements of `x` so that each sequence is reversed.

        A (batch, position) index grid is built, every position is mirrored
        around the last valid element of its sequence, and the result is used
        as GatherNd indices into `x`.

        Args:
            x: input tensor; `self.batch_dim` and `self.seq_dim` select its
                batch and sequence axes.
            seq_lengths: per-batch sequence lengths; its dtype is reused for
                the index tensors.

        Returns:
            The tensor gathered from `x` with the computed indices.
        """
        sub = ops.Sub()
        expand_dims = ops.ExpandDims()
        transpose = ops.Transpose()

        index_dtype = seq_lengths.dtype
        grid_shape = (x.shape[self.batch_dim], x.shape[self.seq_dim])

        # seq_lengths - 1, laid out as a column so it broadcasts per row.
        last_pos = sub(seq_lengths, ops.OnesLike()(seq_lengths)).view(-1, 1)

        # presumably index grids running along axis 0 / axis 1 respectively;
        # make_shape is defined outside this view -- TODO confirm.
        batch_idx = self.make_shape(grid_shape, index_dtype, 0)
        forward_idx = self.make_shape(grid_shape, index_dtype, 1)

        # Mirror each position; where the mirrored index is negative the
        # original forward index is kept instead.
        mirrored = sub(last_pos, forward_idx)
        is_negative = ops.Less()(mirrored, ops.ZerosLike()(mirrored))
        reverse_idx = ops.Select()(is_negative, forward_idx, mirrored)

        reverse_idx = expand_dims(reverse_idx, 2)
        batch_idx = expand_dims(batch_idx, 2)

        if self.batch_dim > self.seq_dim:
            # Swap the first two axes of the indices and the input alike.
            perm = (1, 0, 2)
            batch_idx = transpose(batch_idx, perm)
            reverse_idx = transpose(reverse_idx, perm)
            x = transpose(x, perm)

        gather_indices = ops.Concat(2)((batch_idx, reverse_idx))
        return ops.GatherNd()(x, gather_indices)
コード例 #2
0
    def construct(self, s_t_hat, encoder_outputs, encoder_feature,
                  enc_padding_mask, coverage):
        """Compute the attention context vector and attention distribution.

        Args:
            s_t_hat: decoder state passed through `self.decode_proj`.
            encoder_outputs: tensor whose shape unpacks to `(b, t_k, n)`.
            encoder_feature: encoder features added to the projected decoder
                features.
            enc_padding_mask: mask multiplied into the softmax output.
            coverage: coverage tensor; only read when `self.is_coverage`.

        Returns:
            Tuple `(c_t, attn_dist, coverage)`; `coverage` is returned
            unchanged when `self.is_coverage` is false.
        """
        b, t_k, n = encoder_outputs.shape

        dec_fea = self.decode_proj(s_t_hat)  # (B, 2 * hidden_dim)
        dec_fea_expand = P.ExpandDims()(dec_fea, 1)
        # BroadcastTo receives the target shape at construction time, not as
        # a call argument.
        dec_fea_expand = P.BroadcastTo((b, t_k, n))(dec_fea_expand)

        # NOTE(review): dec_fea_expand is (b, t_k, n) here while the comments
        # below describe (B * t_k, ...) tensors -- confirm that
        # encoder_feature / coverage_feature broadcast against it as intended.
        att_features = encoder_feature + dec_fea_expand
        if self.is_coverage:
            coverage_input = coverage.view(-1, 1)  # (B * t_k, 1)
            coverage_feature = self.W_c(
                coverage_input)  # (B * t_k, 2 * hidden_dim)
            att_features = att_features + coverage_feature

        e = P.Tanh()(att_features)  # (B * t_k, 2 * hidden_dim)
        scores = self.v(e)  # (B * t_k, 1)
        scores = scores.view(-1, t_k)  # (B, t_k)

        # Mask the softmax output, then renormalise over axis 1.
        attn_dist_ = P.Softmax(1)(scores) * enc_padding_mask  # (B, t_k)
        normalization_factor = P.ReduceSum(True)(attn_dist_, 1)
        attn_dist = attn_dist_ / normalization_factor

        attn_dist = P.ExpandDims()(attn_dist, 1)  # (B, 1, t_k)
        # The primitive must be instantiated first and then applied.
        c_t = P.BatchMatMul()(attn_dist, encoder_outputs)  # (B, 1, n)
        c_t = c_t.view(-1, self.hidden_dim * 2)  # (B, 2 * hidden_dim)

        attn_dist = attn_dist.view(-1, t_k)

        if self.is_coverage:
            coverage = coverage.view(-1, t_k)
            coverage = coverage + attn_dist

        return c_t, attn_dist, coverage
コード例 #3
0
    def construct(self, hidden):
        """Reduce an `(h, c)` state pair through `reduce_h` / `reduce_c`.

        Both tensors are transposed with permutation (1, 0, 2), flattened to
        rows of width `2 * self.hidden_dim`, projected, passed through ReLU
        and given a new leading axis.

        Args:
            hidden: tuple `(h, c)`.

        Returns:
            Tuple `(hidden_reduced_h, hidden_reduced_c)`.
        """
        h, c = hidden  # h, c dim = 2 x b x hidden_dim
        h_in = P.Transpose()(h, (1, 0, 2)).view(-1, self.hidden_dim * 2)
        # Feed the flattened h_in, mirroring the c branch below; the raw `h`
        # was passed here before and h_in was left unused.
        hidden_reduced_h = P.ReLU()(self.reduce_h(h_in))
        hidden_reduced_h = P.ExpandDims()(hidden_reduced_h, 0)
        c_in = P.Transpose()(c, (1, 0, 2)).view(-1, self.hidden_dim * 2)
        hidden_reduced_c = P.ReLU()(self.reduce_c(c_in))
        hidden_reduced_c = P.ExpandDims()(hidden_reduced_c, 0)

        return (hidden_reduced_h, hidden_reduced_c)
コード例 #4
0
ファイル: utils.py プロジェクト: xiaoxiugege/mindspore
    def __init__(self, bins=10, momentum=0.0, mu=0.02):
        """Store the hyper-parameters, build the bin edges and create the ops.

        Args:
            bins: number of bins; also the length of both edge tensors.
            momentum: stored as-is; `acc_sum` is created when it is >= 0.
            mu: stored as-is.
        """
        super(GHMRLoss, self).__init__()
        self.bins = bins
        self.momentum = momentum
        self.mu = mu
        # Left edges: 0/bins, 1/bins, ..., (bins-1)/bins, shaped (bins, 1, 1, 1, 1).
        edges_left = np.array([float(x) / bins for x in range(bins)], dtype=np.float32)
        self.edges_left = Tensor(edges_left.reshape((bins, 1, 1, 1, 1)))
        # Right edges: 1/bins, ..., bins/bins, with the last edge nudged up by 1e-4.
        edges_right = np.array([float(x) / bins for x in range(1, bins + 1)], dtype=np.float32)
        edges_right[-1] += 1e-4
        self.edges_right = Tensor(edges_right.reshape((bins, 1, 1, 1, 1)))

        # The default momentum of 0.0 satisfies this test, so acc_sum exists by default.
        if momentum >= 0:
            self.acc_sum = Parameter(initializer(0, [bins], mstype.float32))

        # Operator handles used elsewhere in the class.
        self.abs = ops.Abs()
        self.sqrt = ops.Sqrt()
        self.cast = ops.Cast()
        self.select = ops.Select()
        self.reshape = ops.Reshape()
        self.reduce_sum = ops.ReduceSum()
        self.max = ops.Maximum()
        self.less = ops.Less()
        self.equal = ops.Equal()
        self.greater = ops.Greater()
        self.logical_and = ops.LogicalAnd()
        self.greater_equal = ops.GreaterEqual()
        self.zeros_like = ops.ZerosLike()
        self.expand_dims = ops.ExpandDims()
コード例 #5
0
    def __init__(self, inc, outc, kernel_size=3, padding=1, stride=1, has_bias=False, modulation=True):
        """Build the convolutions and helper ops of the deformable convolution.

        Args:
            inc: number of input channels.
            outc: number of output channels.
            kernel_size: kernel size; must be odd.
            padding: padding used by `zero_padding` and by the offset conv.
            stride: stride of the offset and modulation convolutions.
            has_bias: forwarded to the main convolution.
            modulation: when true, an additional `m_conv` is created.

        Raises:
            ValueError: if `kernel_size` is even.
        """
        super(DeformConv2d, self).__init__()
        # Validate before building any layer so a bad argument fails fast.
        if kernel_size % 2 == 0:
            raise ValueError("Only odd number is supported, but current kernel size is {}".format(kernel_size))
        self.kernel_size = kernel_size
        self.padding = padding
        self.stride = stride
        self.zero_padding = nn.Pad(((0, 0), (0, 0), (padding, padding), (padding, padding)))
        # The main conv strides by kernel_size.
        self.conv = nn.Conv2d(inc, outc, kernel_size=kernel_size, pad_mode='valid', padding=0,
                              stride=kernel_size, has_bias=has_bias)

        # Offset branch: 2 * kernel_size**2 output channels.
        self.p_conv = nn.Conv2d(inc, 2*kernel_size*kernel_size, kernel_size=self.kernel_size,
                                pad_mode='pad', padding=self.padding, stride=self.stride)

        self.modulation = modulation
        if modulation:
            # Modulation branch: kernel_size**2 output channels.
            self.m_conv = nn.Conv2d(inc, kernel_size*kernel_size, kernel_size=self.kernel_size,
                                    pad_mode='valid', padding=0, stride=self.stride)
        self.N = kernel_size * kernel_size
        self.begin = kernel_size // 2
        self.sigmoid = ops.Sigmoid()
        self.dtype = ops.DType()
        self.perm_list = (0, 2, 3, 1)
        self.transpose = ops.Transpose()
        self.floor = ops.Floor()
        self.half = ops.Split(axis=-1, output_num=2)
        self.clip_value = ClipByValue()
        self.expand_dims = ops.ExpandDims()
        self.shape = ops.Shape()
        self.cast = ops.Cast()
        self._get_offset = GetOffsetPosition(self.begin, self.stride)
        self._get_surround = GetSurroundFeature()
        self._generate_fm = RegenerateFeatureMap(self.kernel_size)
コード例 #6
0
    def construct(self, teacher, student, neg):
        """Accumulate a weighted contrastive term over the feature levels of `self.vgg`.

        For each level, the `self.l1` distance between teacher and student
        features is divided by the summed absolute distance between the
        negative and student features; the ratios are weighted by
        `self.weights` and the total is passed to `self.get_loss`.

        Args:
            teacher: input whose features are treated as constants.
            student: input whose features receive gradients.
            neg: negative input whose features are treated as constants.

        Returns:
            Result of `self.get_loss` applied to the accumulated loss.
        """
        unsqueeze = ops.ExpandDims()  # unsqueeze operator
        sum_over_axis = ops.ReduceSum()

        teacher_feats = self.vgg(teacher)
        student_feats = self.vgg(student)
        neg_feats = self.vgg(neg)

        total = 0
        for idx in range(len(teacher_feats)):
            student_i = student_feats[idx]

            # Add a leading axis, repeat it once per student sample, then swap
            # the first two axes.
            # TODO: Tensor.repeat is only supported from version 1.3 on.
            negatives = unsqueeze(neg_feats[idx], 0)
            negatives = np.repeat(negatives, student_i.shape[0], axis=0)
            negatives = negatives.transpose((1, 0, 2, 3, 4))

            pos_dist = self.l1(stop_gradient(teacher_feats[idx]), student_i)

            # TODO: Tensor.sum is only supported from version 1.3 on, hence
            # the explicit ReduceSum over axis 0.
            neg_abs = (stop_gradient(negatives) - student_i).abs()
            neg_dist = sum_over_axis(neg_abs, 0).mean()

            contrastive = pos_dist / (neg_dist + 1e-7)
            total += self.weights[idx] * contrastive

        return self.get_loss(total)
コード例 #7
0
 def __init__(self):
     """Create the primitives and the reduction axis used by this class."""
     super(log_softmax, self).__init__()
     self.maxi = P.ReduceMax()
     self.log = P.Log()
     self.sums = P.ReduceSum()
     self.exp = P.Exp()
     # Stored axis value (last axis).
     self.axis = -1
     self.concat = P.Concat(-1)
     self.expanddims = P.ExpandDims()
コード例 #8
0
 def __init__(self, mixture_size: int, do_layer_norm: bool = False) -> None:
     """Create the mixing parameters and the primitives used by this class.

     Args:
         mixture_size: number of scalar parameters to create.
         do_layer_norm: stored as-is.
     """
     super(Scalar_mix, self).__init__()
     self.mixture_size = mixture_size
     self.do_layer_norm = do_layer_norm
     # One scalar parameter per mixed input, each initialised to 0.0.
     self.scalar_parameters = ParameterTuple([Parameter(Tensor(np.array([0.0]), mindspore.float32))
                                              for _ in range(mixture_size)])
     self.gamma = Parameter(Tensor(np.array([0.0]), mindspore.float32))
     self.sum = P.ReduceSum()
     self.sqrt = P.Sqrt()
     self.cat = P.Concat()
     # ExpandDims takes no constructor arguments; the axis is given at call
     # time, e.g. self.unsqueeze(x, 0).
     self.unsqueeze = P.ExpandDims()
コード例 #9
0
    def __init__(self, length, max_relative_position):
        """Store the length and relative-position bounds and create the ops.

        Args:
            length: stored as `_length`; also used to derive `range_length`.
            max_relative_position: upper bound; its negation is stored as the
                lower bound.
        """
        super(RelaPosMatrixGenerator, self).__init__()
        self._length = length
        self._max_relative_position = max_relative_position
        self._min_relative_position = -max_relative_position
        # Equals 1 - length.
        self.range_length = -length + 1

        self.tile = P.Tile()
        # Note: despite its name this attribute holds a Reshape primitive.
        self.range_mat = P.Reshape()
        self.sub = P.Sub()
        self.expanddims = P.ExpandDims()
        self.cast = P.Cast()
コード例 #10
0
 def __init__(self):
     """Create the primitives and the axis permutations used by this class."""
     super(GetSurroundFeature, self).__init__()
     self.shape = ops.Shape()
     self.concat = ops.Concat(axis=1)
     self.reshape = ops.Reshape()
     # Splits the last axis into two equal parts.
     self.half = ops.Split(axis=-1, output_num=2)
     self.tile = ops.Tile()
     self.gather_nd = ops.GatherNd()
     self.transpose = ops.Transpose()
     # The two permutations are inverses of each other.
     self.perm_list = (0, 2, 3, 1)
     self.order_list = (0, 3, 1, 2)
     self.expand_dims = ops.ExpandDims()
コード例 #11
0
ファイル: utils.py プロジェクト: xiaoxiugege/mindspore
 def __init__(self, enable_cpu_gather=True):
     """Create the primitives used by this class.

     Args:
         enable_cpu_gather: when true, `gather_nd` is a GatherD primitive and
             an ExpandDims primitive is also created; otherwise `gather_nd`
             is a GatherNd primitive.
     """
     super(GatherFeature, self).__init__()
     self.tile = ops.Tile()
     self.shape = ops.Shape()
     self.concat = ops.Concat(axis=1)
     self.reshape = ops.Reshape()
     self.enable_cpu_gather = enable_cpu_gather
     if self.enable_cpu_gather:
         self.gather_nd = ops.GatherD()
         self.expand_dims = ops.ExpandDims()
     else:
         # The primitive is spelled GatherNd; GatherND is not a valid name.
         self.gather_nd = ops.GatherNd()
コード例 #12
0
    def __init__(self, length, max_relative_position):
        """Store the length and relative-position bounds and create the ops.

        Args:
            length: stored as `_length`; also used to derive `range_length`.
            max_relative_position: upper bound; it and its negation are
                wrapped in int32 tensors.
        """
        super(RelaPosMatrixGenerator, self).__init__()
        self._length = length
        # Both bounds are kept as int32 tensors.
        self._max_relative_position = Tensor(max_relative_position, dtype=mstype.int32)
        self._min_relative_position = Tensor(-max_relative_position, dtype=mstype.int32)
        # Equals 1 - length.
        self.range_length = -length + 1

        self.tile = ops.Tile()
        # Note: despite its name this attribute holds a Reshape primitive.
        self.range_mat = ops.Reshape()
        self.sub = ops.Sub()
        self.expanddims = ops.ExpandDims()
        self.cast = ops.Cast()
コード例 #13
0
ファイル: utils_ms.py プロジェクト: xcnick/deeplearning
    def get_extended_attention_mask(self, attention_mask):
        """Expand an attention mask to 4-D and turn it into an additive bias.

        A 3-D mask gets one new axis inserted at position 1; a 2-D mask gets
        two. The result is `(1.0 - mask) * -10000.0`, so entries equal to 1
        map to 0 and entries equal to 0 map to -10000.

        Args:
            attention_mask: 2-D or 3-D mask tensor.

        Returns:
            The expanded, rescaled mask.

        Raises:
            ValueError: if `attention_mask` is neither 2-D nor 3-D.
        """
        ndim = attention_mask.ndim
        # Reject unsupported ranks up front; they would otherwise fail later
        # with an opaque TypeError on `1.0 - None`.
        if ndim not in (2, 3):
            raise ValueError("attention_mask must have 2 or 3 dimensions")

        expand_dims = ops.ExpandDims()
        extended_attention_mask = expand_dims(attention_mask, 1)
        if ndim == 2:
            extended_attention_mask = expand_dims(extended_attention_mask, 1)

        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        return extended_attention_mask
コード例 #14
0
ファイル: utils.py プロジェクト: xiaoxiugege/mindspore
 def __init__(self, mode='l1'):
     """Create the primitives and select the loss function.

     Args:
         mode: 'l1' selects `nn.L1Loss(reduction='sum')`, 'sl1' selects
             `nn.SmoothL1Loss()`; any other value leaves `self.loss` as None.
     """
     super(RegLoss, self).__init__()
     self.reduce_sum = ops.ReduceSum()
     self.cast = ops.Cast()
     self.expand_dims = ops.ExpandDims()
     self.reshape = ops.Reshape()
     self.gather_feature = TransposeGatherFeature()
     if mode == 'l1':
         self.loss = nn.L1Loss(reduction='sum')
     elif mode == 'sl1':
         self.loss = nn.SmoothL1Loss()
     else:
         # Unknown modes are not rejected here; the loss is simply unset.
         self.loss = None
コード例 #15
0
ファイル: decode.py プロジェクト: yrpang/mindspore
 def __init__(self, net_config, K=100, enable_nms_fp16=True):
     """Create the helper cells and primitives used for decoding.

     Args:
         net_config: configuration object; only `reg_offset` is read here.
         K: stored as `self.K`.
         enable_nms_fp16: forwarded to `NMS`.
     """
     super(DetectionDecode, self).__init__()
     self.K = K
     self.nms = NMS(enable_nms_fp16=enable_nms_fp16)
     self.shape = ops.Shape()
     self.gather_topk = GatherTopK()
     # Splits the last axis into two equal parts.
     self.half = ops.Split(axis=-1, output_num=2)
     self.add = ops.TensorAdd()
     self.concat_a2 = ops.Concat(axis=2)
     self.trans_gather_feature = TransposeGatherFeature()
     self.expand_dims = ops.ExpandDims()
     self.reshape = ops.Reshape()
     self.reg_offset = net_config.reg_offset
     self.Sigmoid = nn.Sigmoid()
コード例 #16
0
    def __init__(self, log_scale_min=-7.0, reduce=True):
        """Store the settings and create the primitives used by this class.

        Args:
            log_scale_min: stored as-is.
            reduce: stored as-is.
        """
        super(mix_gaussian_loss, self).__init__()
        self.log_scale_min = log_scale_min
        self.reduce = reduce
        self.transpose_op = P.Transpose()
        self.maximum = P.Maximum()
        self.tile = P.Tile()
        self.exp = P.Exp()
        # Log-softmax over the last axis.
        self.logsoftmax = P.LogSoftmax(-1)
        self.expand_dims = P.ExpandDims()
        self.sums = P.ReduceSum()
        self.lse = log_sum_exp()

        self.sq = P.Square()
        self.sqrt = P.Sqrt()
        self.const = P.ScalarToArray()
        self.log = P.Log()
コード例 #17
0
 def __init__(self, num_classes=256, log_scale_min=-7.0, reduce=True):
     """Store the settings and create the primitives used by this class.

     Args:
         num_classes: stored as-is and used to precompute `self.factor`.
         log_scale_min: stored as-is.
         reduce: stored as-is.
     """
     super(discretized_mix_logistic_loss, self).__init__()
     self.num_classes = num_classes
     self.log_scale_min = log_scale_min
     self.reduce = reduce
     self.transpose_op = P.Transpose()
     self.exp = P.Exp()
     self.sigmoid = P.Sigmoid()
     self.softplus = Stable_softplus()
     self.log = P.Log()
     self.cast = P.Cast()
     # Log-softmax over the last axis.
     self.logsoftmax = P.LogSoftmax(-1)
     self.expand_dims = P.ExpandDims()
     self.tile = P.Tile()
     self.maximum = P.Maximum()
     self.sums = P.ReduceSum()
     self.lse = log_sum_exp()
     self.reshape = P.Reshape()
     # Precomputed constant: log((num_classes - 1) / 2) as a float32 tensor.
     self.factor = self.log(Tensor((self.num_classes - 1) / 2, ms.float32))
コード例 #18
0
ファイル: resnet.py プロジェクト: yrpang/mindspore
    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 downsample=None,
                 num=0,
                 thres=None):
        """Build the two conv/bn pairs, the two mask blocks and the helper ops.

        Args:
            inplanes: input channel count of the first convolution.
            planes: output channel count of both convolutions.
            stride: stride of the first convolution.
            downsample: stored as-is.
            num: block index; the mask blocks receive `num * 2` and
                `num * 2 + 1`.
            thres: forwarded to both mask blocks.
        """
        super(MaskedBasicblock, self).__init__()

        self.conv_a = _conv3x3(inplanes, planes, stride=stride)
        self.bn_a = _bn(planes)

        self.conv_b = _conv3x3(planes, planes, stride=1)
        self.bn_b = _bn(planes)

        self.downsample = downsample

        # One mask block per convolution, numbered consecutively.
        self.mb1 = MaskBlock(inplanes, planes, num * 2, thres)
        self.mb2 = MaskBlock(planes, planes, num * 2 + 1, thres)
        self.relu = P.ReLU()
        self.expand_dims = ops.ExpandDims()
コード例 #19
0
 def __init__(self,
              vocab_size,
              embedding_size,
              embedding_shape,
              use_one_hot_embeddings=False,
              initializer_range=0.02):
     """Create the embedding table and the primitives used for the lookup.

     Args:
         vocab_size: number of rows of the embedding table.
         embedding_size: number of columns of the embedding table.
         embedding_shape: stored as a tuple in `self.shape`.
         use_one_hot_embeddings: stored as-is.
         initializer_range: argument of the TruncatedNormal initializer.
     """
     super(EmbeddingLookup, self).__init__()
     self.vocab_size = vocab_size
     self.use_one_hot_embeddings = use_one_hot_embeddings
     # Table of shape [vocab_size, embedding_size].
     self.embedding_table = Parameter(initializer
                                      (TruncatedNormal(initializer_range),
                                       [vocab_size, embedding_size]))
     self.expand = P.ExpandDims()
     self.shape_flat = (-1,)
     self.gather = P.Gather()
     self.one_hot = P.OneHot()
     self.on_value = Tensor(1.0, mstype.float32)
     self.off_value = Tensor(0.0, mstype.float32)
     self.array_mul = P.MatMul()
     self.reshape = P.Reshape()
     self.shape = tuple(embedding_shape)
コード例 #20
0
    def __init__(self, log_scale_min=-7.0, reduce=True):
        """Store the settings and create the primitives used by this class.

        Args:
            log_scale_min: stored as-is.
            reduce: stored as-is.
        """
        super(mix_gaussian_loss, self).__init__()
        self.log_scale_min = log_scale_min
        self.reduce = reduce
        self.transpose_op = P.Transpose()
        self.maximum = P.Maximum()
        self.tile = P.Tile()
        self.exp = P.Exp()
        self.expand_dims = P.ExpandDims()
        self.sums = P.ReduceSum()
        self.lse = log_sum_exp()
        self.sq = P.Square()
        self.sqrt = P.Sqrt()
        self.const = P.ScalarToArray()
        self.log = P.Log()
        self.tensor_one = Tensor(1., ms.float32)

        # On a CPU device target the log_softmax class is used instead of the
        # LogSoftmax primitive.
        if context.get_context("device_target") == "CPU":
            self.logsoftmax = log_softmax()
        else:
            self.logsoftmax = P.LogSoftmax(-1)
コード例 #21
0
ファイル: decode.py プロジェクト: louis100/mindspore
 def __init__(self, net_config, K=100, enable_nms_fp16=True):
     """Create the helper cells, primitives and output indices used for decoding.

     Args:
         net_config: configuration object; `hm_hp`, `dense_hp`, `reg_offset`
             and `reg_hp_offset` are read from it.
         K: stored as `self.K`.
         enable_nms_fp16: forwarded to `NMS`.
     """
     super(MultiPoseDecode, self).__init__()
     self.K = K
     self.nms = NMS(enable_nms_fp16=enable_nms_fp16)
     self.shape = ops.Shape()
     self.gather_topk = GatherTopK()
     self.gather_topk_channel = GatherTopKChannel()
     self.gather_by_ind = GatherFeatureByInd()
     # Split helpers: last axis in two, first axis in two, last axis in four.
     self.half = ops.Split(axis=-1, output_num=2)
     self.half_first = ops.Split(axis=0, output_num=2)
     self.split = ops.Split(axis=-1, output_num=4)
     self.flip_lr = FlipLR()
     self.flip_lr_off = FlipLROff()
     self.flip_tensor = FlipTensor()
     self.concat = ops.Concat(axis=1)
     self.concat_a2 = ops.Concat(axis=2)
     self.concat_a3 = ops.Concat(axis=3)
     self.trans_gather_feature = TransposeGatherFeature()
     self.expand_dims = ops.ExpandDims()
     self.reshape = ops.Reshape()
     self.add = ops.TensorAdd()
     self.dtype = ops.DType()
     self.cast = ops.Cast()
     self.thresh = 0.1
     self.transpose = ops.Transpose()
     self.perm_list = (0, 2, 1, 3)
     self.tile = ops.Tile()
     self.greater = ops.Greater()
     self.square = ops.Square()
     self.sqrt = ops.Sqrt()
     self.reduce_sum = ops.ReduceSum()
     self.min = ops.ArgMinWithValue(axis=3)
     self.max = ops.Maximum()
     self.hm_hp = net_config.hm_hp
     self.dense_hp = net_config.dense_hp
     self.reg_offset = net_config.reg_offset
     self.reg_hp_offset = net_config.reg_hp_offset
     # Running indices: each one is the previous index, plus one when the
     # corresponding flag is set.
     self.hm_hp_ind = 3 if self.hm_hp else 2
     self.reg_ind = self.hm_hp_ind + 1 if self.reg_offset else self.hm_hp_ind
     self.reg_hp_ind = self.reg_ind + 1 if self.reg_hp_offset else self.reg_ind
コード例 #22
0
ファイル: clip_grads.py プロジェクト: lvyufeng/elmo_mindspore
def average_gradients(tower_grads):
    """Average gradients across towers, variable by variable.

    Args:
        tower_grads: sequence with one entry per tower; each entry is a
            sequence of `(gradient, variable)` pairs in the same variable
            order.

    Returns:
        A list of `(averaged_gradient, variable)` pairs, one per variable,
        using the variable of the first tower. A pair whose first-tower
        gradient is None is passed through unchanged.
    """
    average_grads = []
    for grad_and_vars in zip(*tower_grads):
        g0, v0 = grad_and_vars[0]
        if g0 is None:
            average_grads.append((g0, v0))
            continue
        # the gradient is type IndexedSlices
        # to do
        # a normal tensor can just do a simple average
        expand_dims = P.ExpandDims()
        grads = tuple(expand_dims(g, 0) for g, _ in grad_and_vars)

        # Average over the 'tower' dimension. ReduceMean has to be
        # instantiated first and then applied to (tensor, axis).
        grad = P.Concat(0)(grads)
        grad = P.ReduceMean()(grad, 0)

        average_grads.append((grad, v0))
    return average_grads
コード例 #23
0
ファイル: bert_ms.py プロジェクト: xcnick/deeplearning
    def construct(
        self,
        input_ids: ms.Tensor,
        token_type_ids: Optional[ms.Tensor] = None,
        position_ids: Optional[ms.Tensor] = None,
    ) -> ms.Tensor:
        """Sum token, position and token-type embeddings, then normalise and apply dropout.

        Args:
            input_ids: token ids; axis 1 is taken as the sequence length.
            token_type_ids: optional; defaults to zeros shaped like
                `input_ids`.
            position_ids: optional; defaults to `0 .. seq_length - 1` with a
                leading axis added.

        Returns:
            The embeddings after layer normalisation and dropout.
        """
        seq_length = input_ids.shape[1]

        # Fill in the optional inputs with their defaults.
        if position_ids is None:
            position_range = nn.Range(0, seq_length)()
            position_ids = ops.ExpandDims()(position_range, 0)
        if token_type_ids is None:
            token_type_ids = ops.zeros_like(input_ids)

        input_embeddings = self.token_embeddings(input_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)
        position_embeddings = self.position_embeddings(position_ids)

        summed = input_embeddings + position_embeddings + \
            token_type_embeddings
        normalized = self.layer_norm(summed)
        return self.dropout(normalized)
コード例 #24
0
    def __init__(self,
                 batch_size,
                 from_tensor_width,
                 to_tensor_width,
                 from_seq_length,
                 to_seq_length,
                 num_attention_heads=1,
                 size_per_head=512,
                 query_act=None,
                 key_act=None,
                 value_act=None,
                 has_attention_mask=False,
                 attention_probs_dropout_prob=0.0,
                 use_one_hot_embeddings=False,
                 initializer_range=0.02,
                 do_return_2d_tensor=False,
                 use_relative_positions=False,
                 compute_type=mstype.float32):
        """Build the projection layers, shapes and primitives of the attention block.

        Args:
            batch_size: batch size used in the precomputed shapes.
            from_tensor_width: input width of the query projection.
            to_tensor_width: input width of the key and value projections.
            from_seq_length: sequence length on the query side.
            to_seq_length: sequence length on the key/value side.
            num_attention_heads: number of heads.
            size_per_head: width of each head.
            query_act: activation of the query projection.
            key_act: activation of the key projection.
            value_act: activation of the value projection.
            has_attention_mask: when true, the mask-related primitives are created.
            attention_probs_dropout_prob: `nn.Dropout` receives one minus this value.
            use_one_hot_embeddings: forwarded to the relative-position generator.
            initializer_range: argument of the TruncatedNormal weight initializer.
            do_return_2d_tensor: selects the 2-D or 3-D form of `shape_return`.
            use_relative_positions: when true, the relative-position
                embeddings generator is created.
            compute_type: dtype of the projections, of `scores_mul` and of the
                saturating cast.
        """

        super(BertAttention, self).__init__()
        self.batch_size = batch_size
        self.from_seq_length = from_seq_length
        self.to_seq_length = to_seq_length
        self.num_attention_heads = num_attention_heads
        self.size_per_head = size_per_head
        self.has_attention_mask = has_attention_mask
        self.use_relative_positions = use_relative_positions

        # Scaling factor 1 / sqrt(size_per_head).
        self.scores_mul = Tensor([1.0 / math.sqrt(float(self.size_per_head))], dtype=compute_type)
        self.reshape = ops.Reshape()
        self.shape_from_2d = (-1, from_tensor_width)
        self.shape_to_2d = (-1, to_tensor_width)
        weight = TruncatedNormal(initializer_range)
        units = num_attention_heads * size_per_head
        self.query_layer = nn.Dense(from_tensor_width,
                                    units,
                                    activation=query_act,
                                    weight_init=weight).to_float(compute_type)
        self.key_layer = nn.Dense(to_tensor_width,
                                  units,
                                  activation=key_act,
                                  weight_init=weight).to_float(compute_type)
        self.value_layer = nn.Dense(to_tensor_width,
                                    units,
                                    activation=value_act,
                                    weight_init=weight).to_float(compute_type)

        self.shape_from = (batch_size, from_seq_length, num_attention_heads, size_per_head)
        self.shape_to = (
            batch_size, to_seq_length, num_attention_heads, size_per_head)

        self.matmul_trans_b = ops.BatchMatMul(transpose_b=True)
        self.multiply = ops.Mul()
        self.transpose = ops.Transpose()
        self.trans_shape = (0, 2, 1, 3)
        self.trans_shape_relative = (2, 0, 1, 3)
        self.trans_shape_position = (1, 2, 0, 3)
        #self.multiply_data = Tensor([-10000.0,], dtype=compute_type)
        # This constant is kept in float32 regardless of compute_type.
        self.multiply_data = Tensor([-10000.0,], dtype=mstype.float32)
        self.batch_num = batch_size * num_attention_heads
        self.matmul = ops.BatchMatMul()

        self.softmax = nn.Softmax()
        self.dropout = nn.Dropout(1 - attention_probs_dropout_prob)

        # These primitives exist only when an attention mask is used.
        if self.has_attention_mask:
            self.expand_dims = ops.ExpandDims()
            self.sub = ops.Sub()
            self.add = ops.TensorAdd()
            self.cast = ops.Cast()
            self.get_dtype = ops.DType()
        if do_return_2d_tensor:
            self.shape_return = (batch_size * from_seq_length, num_attention_heads * size_per_head)
        else:
            self.shape_return = (batch_size, from_seq_length, num_attention_heads * size_per_head)

        self.cast_compute_type = SaturateCast(dst_type=compute_type)
        if self.use_relative_positions:
            # max_relative_position is fixed at 16 here.
            self._generate_relative_positions_embeddings = \
                RelaPosEmbeddingsGenerator(length=to_seq_length,
                                           depth=size_per_head,
                                           max_relative_position=16,
                                           initializer_range=initializer_range,
                                           use_one_hot_embeddings=use_one_hot_embeddings)
コード例 #25
0
    def construct(self, y_t_1, s_t_1, encoder_outputs, encoder_feature,
                  enc_padding_mask, c_t_1, extra_zeros, enc_batch_extend_vocab,
                  coverage, step):
        """Run one decoding step and return the output distribution.

        Args:
            y_t_1: previous target token ids, fed to `self.embedding`.
            s_t_1: previous LSTM state `(h, c)`.
            encoder_outputs: encoder outputs passed to the attention network.
            encoder_feature: encoder features passed to the attention network.
            enc_padding_mask: padding mask passed to the attention network.
            c_t_1: previous context vector, concatenated with the embedding.
            extra_zeros: optional tensor appended to the vocabulary
                distribution along axis 1; may be None.
            enc_batch_extend_vocab: indices used to scatter the attention
                distribution into the output distribution.
            coverage: coverage tensor passed to the attention network.
            step: decoding step index.

        Returns:
            Tuple `(final_dist, s_t, c_t, attn_dist, p_gen, coverage)`;
            `p_gen` is None when `self.pointer_gen` is false.
        """
        if not self.training and step == 0:
            # At the first inference step attention is run on the incoming
            # state; only the resulting coverage is kept, c_t is recomputed
            # below.
            h_decoder, c_decoder = s_t_1
            h_decoder = h_decoder.view(-1, self.hidden_dim)
            c_decoder = c_decoder.view(-1, self.hidden_dim)
            s_t_hat = P.Concat(1)(
                (h_decoder, c_decoder))  # (B, 2 * hidden_dim)
            c_t, _, coverage_next = self.attention_network(
                s_t_hat, encoder_outputs, encoder_feature, enc_padding_mask,
                coverage)
            coverage = coverage_next

        y_t_1_embed = self.embedding(y_t_1)
        x = self.x_content(P.Concat(1)((c_t_1, y_t_1_embed)))
        lstm_out, s_t = self.lstm(P.ExpandDims()(x, 1), s_t_1)

        h_decoder, c_decoder = s_t
        h_decoder = h_decoder.view(-1, self.hidden_dim)
        c_decoder = c_decoder.view(-1, self.hidden_dim)
        s_t_hat = P.Concat(1)((h_decoder, c_decoder))

        c_t, attn_dist, coverage_next = self.attention_network(
            s_t_hat, encoder_outputs, encoder_feature, enc_padding_mask,
            coverage)

        if self.training or step > 0:
            coverage = coverage_next

        p_gen = None
        if self.pointer_gen:
            p_gen_input = P.Concat(1)(
                (c_t, s_t_hat, x))  # (B, 2 * 2 * hidden_dim + embed_dim)
            p_gen = self.p_gen_linear(p_gen_input)
            p_gen = P.Sigmoid()(p_gen)

        output = P.Concat(1)(
            (lstm_out.view(-1, self.hidden_dim), c_t))  # (B, hidden_dim * 3)
        output = self.out1(output)  # (B, hidden_dim)

        output = self.out2(output)  # (B, vocab_size)
        # The primitive is spelled Softmax; P.SoftMax does not exist.
        vocab_dist = P.Softmax(1)(output)

        if self.pointer_gen:
            vocab_dist_ = p_gen * vocab_dist
            attn_dist_ = (1 - p_gen) * attn_dist

            if extra_zeros is not None:
                vocab_dist_ = P.Concat(1)((vocab_dist_, extra_zeros))

            # like pytorch scatter_add: build (batch, vocab-id) index pairs
            # and scatter the attention distribution into the output shape.
            batch_size, attn_len = enc_batch_extend_vocab.shape
            batch_num = range_tensor(0, batch_size)
            batch_num = P.ExpandDims()(batch_num, 1)
            batch_num = P.Tile()(batch_num, (1, attn_len))
            indices = P.Pack(2)((batch_num, enc_batch_extend_vocab))
            shape = (batch_size, vocab_dist_.shape[1])
            attn_dist_ = P.ScatterNd()(indices, attn_dist_, shape)
            final_dist = vocab_dist_ + attn_dist_
        else:
            final_dist = vocab_dist

        return final_dist, s_t, c_t, attn_dist, p_gen, coverage