import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

# LSTMCore and LSTMSoftAttentionCore are the plain / soft-attention LSTM cells
# defined elsewhere in this project; they are not shown in this example.


class ShowAttendTellModel(nn.Module):
    def __init__(self, opt):
        super(ShowAttendTellModel, self).__init__()

        self.vocab_size = opt.vocab_size
        self.input_encoding_size = opt.input_encoding_size
        self.lstm_size = opt.lstm_size
        # self.drop_prob_lm = opt.drop_prob_lm
        self.drop_prob_lm = 0.1  # hard-coded here instead of opt.drop_prob_lm
        self.seq_length = opt.seq_length

        self.fc_feat_size = opt.fc_feat_size
        self.conv_feat_size = opt.conv_feat_size
        self.conv_att_size = opt.conv_att_size
        self.att_hidden_size = opt.att_hidden_size
        self.ss_prob = opt.ss_prob  # Schedule sampling probability

        self.fc2h = nn.Linear(self.fc_feat_size, self.lstm_size)

        self.core = LSTMSoftAttentionCore(self.input_encoding_size,
                                          self.lstm_size, self.conv_feat_size,
                                          self.conv_att_size,
                                          self.att_hidden_size,
                                          self.drop_prob_lm)

        self.embed = nn.Embedding(self.vocab_size + 1,
                                  self.input_encoding_size)
        self.logit = nn.Linear(self.lstm_size, self.vocab_size)

        # add the following parameters for ARNet
        self.rcst_time = opt.rcst_time
        self.rcst_scale = opt.rcst_weight  # lambda in ARNet
        self.rcstLSTM = LSTMCore(
            self.lstm_size, self.lstm_size,
            self.drop_prob_lm)  # ARNet is realized by LSTM network
        self.h_2_pre_h = nn.Linear(
            self.lstm_size, self.lstm_size)  # fully connected layer in ARNet

        self.init_weights()

    def init_weights(self):
        initrange = 0.1
        self.embed.weight.data.uniform_(-initrange, initrange)
        self.fc2h.weight.data.uniform_(-initrange, initrange)
        self.logit.weight.data.uniform_(-initrange, initrange)
        self.logit.bias.data.fill_(0)

        # initialize weights of parameters in ARNet
        self.h_2_pre_h.weight.data.uniform_(-initrange, initrange)
        self.h_2_pre_h.bias.data.fill_(0)

    def copy_weights(self, model_path):
        """
        Initialize the weights of parameters from the model 
        which is pre-trained by Cross Entropy (MLE)
        """
        src_weights = torch.load(model_path)
        own_dict = self.state_dict()
        for key, var in src_weights.items():
            print("copy weights: {}  size: {}".format(key, var.size()))
            own_dict[key].copy_(var)
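
    # A minimal usage sketch (not in the original example; the checkpoint
    # filename below is hypothetical):
    #
    #   model = ShowAttendTellModel(opt)
    #   model.copy_weights('xe_pretrained.pth')  # warm-start from the MLE-trained model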

    def forward(self, fc_feats, att_feats, seq):
        batch_size = fc_feats.size(0)

        init_h = self.fc2h(fc_feats)
        init_h = init_h.unsqueeze(0)
        init_c = init_h.clone()
        state = (init_h, init_c)

        outputs = []
        for i in range(seq.size(1)):
            if i >= 1 and self.ss_prob > 0.0:
                sample_prob = fc_feats.data.new(batch_size).uniform_(0, 1)
                sample_mask = sample_prob < self.ss_prob
                if sample_mask.sum() == 0:
                    it = seq[:, i].clone()
                else:
                    sample_ind = sample_mask.nonzero().view(-1)
                    it = seq[:, i].data.clone()
                    prob_prev = torch.exp(
                        outputs[-1].data
                    )  # fetch prev distribution: shape Nx(M+1)
                    it.index_copy_(
                        0, sample_ind,
                        torch.multinomial(prob_prev, 1).view(-1).index_select(
                            0, sample_ind))
                    it = Variable(it, requires_grad=False)
            else:
                it = seq[:, i].clone()

            # break if all the sequences end
            if i >= 1 and seq[:, i].data.sum() == 0:
                break

            xt = self.embed(it)

            output, state = self.core.forward(xt, att_feats, state)
            output = F.log_softmax(self.logit(output.squeeze(0)), dim=1)
            outputs.append(output)

        return torch.cat([_.unsqueeze(1) for _ in outputs],
                         1).contiguous()  # batch_size * seq_length * vocab_size

    # reconstruction (ARNet) part
    def rcst_forward(self, fc_feats, att_feats, seq, mask):
        batch_size = fc_feats.size(0)

        init_h = self.fc2h(fc_feats)
        init_h = init_h.unsqueeze(0)
        init_c = init_h.clone()
        state = (init_h, init_c)

        rcst_init_h = init_h.clone()
        rcst_init_c = init_c.clone()
        rcst_state = (rcst_init_h, rcst_init_c)

        pre_h = []
        output_logits = []
        rcst_loss = 0.0

        for i in range(seq.size(1) - 1):
            if i >= 1 and self.ss_prob > 0.0:  # otherwise no need to sample
                sample_prob = fc_feats.data.new(batch_size).uniform_(0, 1)
                sample_mask = sample_prob < self.ss_prob
                if sample_mask.sum() == 0:
                    it = seq[:, i].clone()
                else:
                    sample_ind = sample_mask.nonzero().view(-1)
                    it = seq[:, i].data.clone()
                    prob_prev = torch.exp(
                        output_logits[-1].data
                    )  # fetch prev distribution: shape Nx(M+1)
                    it.index_copy_(
                        0, sample_ind,
                        torch.multinomial(prob_prev, 1).view(-1).index_select(
                            0, sample_ind))
                    it = Variable(it, requires_grad=False)
            else:
                it = seq[:, i].clone()

            # break if all the sequences end
            if i >= 1 and seq[:, i].data.sum() == 0:
                break

            xt = self.embed(it)
            output, state = self.core.forward(xt, att_feats, state)
            logit_words = F.log_softmax(self.logit(output.squeeze(0)), dim=1)
            output_logits.append(logit_words)

            if i >= 1:
                rcst_output, rcst_state = self.rcstLSTM.forward(
                    output, rcst_state)
                rcst_h = F.leaky_relu(self.h_2_pre_h(rcst_output))
                rcst_t = pre_h[i - 1].squeeze(dim=0)

                # -1 means not changing the size of that dimension,
                # http://pytorch.org/docs/master/tensors.html
                rcst_mask = mask[:,
                                 i].contiguous().view(batch_size, -1).expand(
                                     batch_size, self.lstm_size)
                rcst_diff = rcst_h - rcst_t
                rcst_loss += torch.sum(
                    torch.sum(torch.mul(rcst_diff, rcst_diff) * rcst_mask,
                              dim=1)) / batch_size * self.rcst_scale

            # update the previous hidden state
            pre_h.append(state[0].clone())

        output_logits = torch.cat([_.unsqueeze(1) for _ in output_logits],
                                  1).contiguous()

        return output_logits, rcst_loss
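
    # In equation form, the loss accumulated above is (B = batch size,
    # lambda = self.rcst_scale, o_t = decoder output at step t, h_{t-1} =
    # previous decoder hidden state):
    #
    #   L_rcst = (lambda / B) * sum_t sum_i mask[i, t]
    #            * || leaky_relu(W_r * ARNetLSTM(o_t)_i + b_r) - (h_{t-1})_i ||^2
    #
    # i.e. the reconstruction LSTM followed by h_2_pre_h tries to regress the
    # previous hidden state from the current decoder output, with padded
    # positions masked out.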

    def sample_beam(self, fc_feats, att_feats, init_index, opt={}):
        beam_size = opt.get('beam_size', 10)  # default to 10 if not given in opt
        batch_size = fc_feats.size(0)
        fc_feat_size = fc_feats.size(1)

        seq = torch.LongTensor(self.seq_length, batch_size).zero_()
        seqLogprobs = torch.FloatTensor(self.seq_length, batch_size)

        top_seq = []
        top_prob = [[] for _ in range(batch_size)]

        self.done_beams = [[] for _ in range(batch_size)]

        for k in range(batch_size):
            init_h = self.fc2h(fc_feats[k].unsqueeze(0).expand(
                beam_size, fc_feat_size))
            init_h = init_h.unsqueeze(0)
            init_c = init_h.clone()
            state = (init_h, init_c)

            att_feats_current = att_feats[k].unsqueeze(0).expand(
                beam_size, att_feats.size(1), att_feats.size(2))
            att_feats_current = att_feats_current.contiguous()

            beam_seq = torch.LongTensor(self.seq_length, beam_size).zero_()
            beam_seq_logprobs = torch.FloatTensor(self.seq_length,
                                                  beam_size).zero_()
            beam_logprobs_sum = torch.zeros(
                beam_size)  # running sum of logprobs for each beam
            for t in range(self.seq_length + 1):
                if t == 0:  # input <bos>
                    it = fc_feats.data.new(beam_size).long().fill_(init_index)
                    xt = self.embed(Variable(it, requires_grad=False))
                    # xt = self.img_embed(fc_feats[k:k+1]).expand(beam_size, self.input_encoding_size)
                else:
                    """perform a beam merge. that is,
                    for every previous beam we now many new possibilities to branch out
                    we need to resort our beams to maintain the loop invariant of keeping
                    the top beam_size most likely sequences."""
                    # use a float copy for cheaper indexing operations below
                    logprobsf = logprobs.float()
                    # ys: beam_size * (vocab_size + 1)
                    ys, ix = torch.sort(
                        logprobsf, 1, True
                    )  # sorted array of logprobs along each previous beam (last True = descending)
                    candidates = []
                    cols = min(beam_size, ys.size(1))
                    rows = beam_size
                    if t == 1:  # at first time step only the first beam is active
                        rows = 1
                    for c in range(cols):
                        for q in range(rows):
                            # compute logprob of expanding beam q with word in (sorted) position c
                            local_logprob = ys[q, c]
                            candidate_logprob = beam_logprobs_sum[
                                q] + local_logprob
                            if t > 1 and beam_seq[t - 2, q] == 0:
                                continue
                            candidates.append({
                                'c': ix.data[q, c],
                                'q': q,
                                'p': candidate_logprob.data[0],
                                'r': local_logprob.data[0]
                            })

                    if len(candidates) == 0:
                        break
                    candidates = sorted(candidates, key=lambda x: -x['p'])

                    # construct new beams
                    new_state = [_.clone() for _ in state]
                    if t > 1:
                        # we'll need these as reference when we fork beams around
                        beam_seq_prev = beam_seq[:t - 1].clone()
                        beam_seq_logprobs_prev = beam_seq_logprobs[:t -
                                                                   1].clone()

                    for vix in range(min(beam_size, len(candidates))):
                        v = candidates[vix]
                        # fork beam index q into index vix
                        if t > 1:
                            beam_seq[:t - 1, vix] = beam_seq_prev[:, v['q']]
                            beam_seq_logprobs[:t - 1,
                                              vix] = beam_seq_logprobs_prev[:, v[
                                                  'q']]

                        # rearrange recurrent states
                        for state_ix in range(len(new_state)):
                            # copy over state in previous beam q to new beam at vix
                            new_state[state_ix][0, vix] = state[state_ix][
                                0, v['q']]  # dimension one is time step

                        # append new end terminal at the end of this beam
                        beam_seq[t - 1,
                                 vix] = v['c']  # c'th word is the continuation
                        beam_seq_logprobs[t - 1,
                                          vix] = v['r']  # the raw logprob here
                        beam_logprobs_sum[vix] = v[
                            'p']  # the new (sum) logprob along this beam

                        if v['c'] == 0 or t == self.seq_length:
                            # END token special case here, or we reached the end.
                            # add the beam to a set of done beams
                            self.done_beams[k].append({
                                'seq':
                                beam_seq[:, vix].clone(),
                                'logps':
                                beam_seq_logprobs[:, vix].clone(),
                                'p':
                                beam_logprobs_sum[vix]
                            })

                    # encode as vectors
                    it = beam_seq[t - 1]
                    xt = self.embed(Variable(it.cuda()))

                if t >= 1:
                    state = new_state

                output, state = self.core.forward(xt, att_feats_current, state)
                logprobs = F.log_softmax(self.logit(output), dim=1)

            self.done_beams[k] = sorted(self.done_beams[k],
                                        key=lambda x: -x['p'])
            seq[:, k] = self.done_beams[k][0][
                'seq']  # the first beam has highest cumulative score
            seqLogprobs[:, k] = self.done_beams[k][0]['logps']

            # save result
            l = len(self.done_beams[k])
            top_seq_cur = torch.LongTensor(l, self.seq_length).zero_()

            for temp_index in range(l):
                top_seq_cur[temp_index] = self.done_beams[k][temp_index][
                    'seq'].clone()
                top_prob[k].append(self.done_beams[k][temp_index]['p'])

            top_seq.append(top_seq_cur)
        # return the samples and their log likelihoods
        return seq.transpose(0, 1), seqLogprobs.transpose(0,
                                                          1), top_seq, top_prob

    def sample(self, fc_feats, att_feats, init_index, opt={}):
        sample_max = opt.get('sample_max', 1)
        beam_size = opt.get('beam_size', 1)
        temperature = opt.get('temperature', 1.0)

        if beam_size > 1:
            return self.sample_beam(fc_feats, att_feats, init_index, opt)

        batch_size = fc_feats.size(0)
        seq = []
        seqLogprobs = []
        logprobs_all = []

        init_h = self.fc2h(fc_feats)
        init_h = init_h.unsqueeze(0)
        init_c = init_h.clone()
        state = (init_h, init_c)

        for t in range(self.seq_length):
            if t == 0:  # input BOS
                it = fc_feats.data.new(batch_size).long().fill_(init_index)
            elif sample_max:
                sampleLogprobs, it = torch.max(logprobs.data, 1)
                it = it.view(-1).long()
            else:
                if temperature == 1.0:
                    prob_prev = torch.exp(logprobs.data).cpu(
                    )  # fetch prev distribution: shape Nx(M+1)
                else:
                    # scale logprobs by temperature
                    prob_prev = torch.exp(torch.div(logprobs.data,
                                                    temperature)).cpu()

                it = torch.multinomial(prob_prev, 1).cuda()

                sampleLogprobs = logprobs.gather(
                    1,
                    Variable(it, requires_grad=False).cuda(
                    ))  # gather the logprobs at sampled positions
                it = it.view(
                    -1).long()  # and flatten indices for downstream processing

            xt = self.embed(Variable(it, requires_grad=False).cuda())

            if t >= 1:
                # stop when all finished
                if t == 1:
                    unfinished = it > 0
                else:
                    unfinished = unfinished * (it > 0)
                if unfinished.sum() == 0:
                    break
                it = it * unfinished.type_as(it)
                seq.append(it)
                seqLogprobs.append(sampleLogprobs.view(-1))

            output, state = self.core.forward(xt, att_feats, state)

            logprobs = F.log_softmax(self.logit(output), dim=1)
            logprobs_all.append(logprobs)

        greedy_seq = torch.cat([_.unsqueeze(1) for _ in seq], 1)
        greedy_seqLogprobs = torch.cat([_.unsqueeze(1) for _ in seqLogprobs],
                                       1)
        greedy_logprobs_all = torch.cat([_.unsqueeze(1) for _ in logprobs_all],
                                        1).contiguous()

        return greedy_seq, greedy_seqLogprobs, greedy_logprobs_all

    def teacher_forcing_get_hidden_states(self, fc_feats, att_feats, seq):
        batch_size = fc_feats.size(0)

        init_h = self.fc2h(fc_feats)
        init_h = init_h.unsqueeze(0)
        init_c = init_h.clone()
        state = (init_h, init_c)
        outputs = []

        for i in range(seq.size(1)):
            if i >= 1 and self.ss_prob > 0.0:  # otherwise no need to sample
                sample_prob = fc_feats.data.new(batch_size).uniform_(0, 1)
                sample_mask = sample_prob < self.ss_prob
                if sample_mask.sum() == 0:
                    it = seq[:, i].clone()
                else:
                    sample_ind = sample_mask.nonzero().view(-1)
                    it = seq[:, i].data.clone()
                    prob_prev = torch.exp(
                        outputs[-1].data
                    )  # fetch prev distribution: shape Nx(M+1)
                    it.index_copy_(
                        0, sample_ind,
                        torch.multinomial(prob_prev, 1).view(-1).index_select(
                            0, sample_ind))
                    it = Variable(it, requires_grad=False)
            else:
                it = seq[:, i].clone()

            # break if all the sequences end
            if i >= 1 and seq[:, i].data.sum() == 0:
                break

            xt = self.embed(it)

            output, state = self.core.forward(xt, att_feats, state)
            if batch_size == 1:
                output = F.log_softmax(self.logit(output), dim=1)
            else:
                output = F.log_softmax(self.logit(output.squeeze(0)), dim=1)
            outputs.append(output)

        # return the hidden states
        return state[0], outputs

    def free_running_get_hidden_states(self, fc_feats, att_feats, init_index,
                                       end_index):
        batch_size = fc_feats.size(0)
        logprobs_all = []

        init_h = self.fc2h(fc_feats)
        init_h = init_h.unsqueeze(0)
        init_c = init_h.clone()
        state = (init_h, init_c)

        for t in range(self.seq_length):
            if t == 0:  # input BOS
                it = fc_feats.data.new(batch_size).long().fill_(init_index)

            xt = self.embed(Variable(it, requires_grad=False))

            output, state = self.core.forward(xt, att_feats, state)

            if batch_size == 1:
                logprobs = F.log_softmax(self.logit(output), dim=1)
            else:
                logprobs = F.log_softmax(self.logit(output.squeeze(0)), dim=1)
            logprobs_all.append(logprobs)

            _, it = torch.max(logprobs.data, 1)
            it = it.view(-1).long()
            if it.cpu().numpy()[0] == end_index:  # only checks the first sequence (batch_size == 1)
                break

        return state[0], logprobs_all
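

# Minimal usage sketch for the ARNet variant above (not part of the original
# example). The hyperparameter values and tensor shapes here are illustrative
# assumptions, and it only runs where the project's LSTMSoftAttentionCore and
# LSTMCore cells are importable in this module.
def _arnet_training_step_sketch():
    from argparse import Namespace

    # assumed hyperparameters; the real values come from the project's opts
    opt = Namespace(vocab_size=9487, input_encoding_size=512, lstm_size=512,
                    seq_length=20, fc_feat_size=2048, conv_feat_size=2048,
                    conv_att_size=196, att_hidden_size=512, ss_prob=0.0,
                    rcst_time=1, rcst_weight=0.005)
    model = ShowAttendTellModel(opt)

    batch_size = 4
    fc_feats = Variable(torch.randn(batch_size, opt.fc_feat_size))
    # attention features, assumed layout: batch * conv_att_size * conv_feat_size
    att_feats = Variable(torch.randn(batch_size, opt.conv_att_size,
                                     opt.conv_feat_size))
    # dummy captions / mask: non-zero word ids so no sequence "ends" early
    seq = Variable(torch.ones(batch_size, opt.seq_length).long())
    mask = Variable(torch.ones(batch_size, opt.seq_length))

    # cross-entropy branch: per-step log-probabilities over the vocabulary
    log_probs = model(fc_feats, att_feats, seq)

    # ARNet branch: same decoding pass plus the hidden-state reconstruction
    # loss, already scaled by rcst_weight inside rcst_forward()
    rcst_log_probs, rcst_loss = model.rcst_forward(fc_feats, att_feats, seq,
                                                   mask)
    return log_probs, rcst_log_probs, rcst_loss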


# Example 2: the same ShowAttendTellModel without the ARNet components
class ShowAttendTellModel(nn.Module):
    def __init__(self, opt):
        super(ShowAttendTellModel, self).__init__()

        self.vocab_size = opt.vocab_size
        self.input_encoding_size = opt.input_encoding_size
        self.lstm_size = opt.lstm_size
        # self.drop_prob_lm = opt.drop_prob_lm
        self.drop_prob_lm = 0.1  # hard-coded here instead of opt.drop_prob_lm
        self.seq_length = opt.seq_length

        self.fc_feat_size = opt.fc_feat_size
        self.conv_feat_size = opt.conv_feat_size
        self.conv_att_size = opt.conv_att_size
        self.att_hidden_size = opt.att_hidden_size
        self.ss_prob = opt.ss_prob  # Schedule sampling probability

        self.fc2h = nn.Linear(self.fc_feat_size, self.lstm_size)

        self.core = LSTMSoftAttentionCore(self.input_encoding_size,
                                          self.lstm_size, self.conv_feat_size,
                                          self.conv_att_size,
                                          self.att_hidden_size,
                                          self.drop_prob_lm)

        self.embed = nn.Embedding(self.vocab_size + 1,
                                  self.input_encoding_size)
        self.logit = nn.Linear(self.lstm_size, self.vocab_size)

        self.init_weights()

    def init_weights(self):
        initrange = 0.1
        self.embed.weight.data.uniform_(-initrange, initrange)
        self.fc2h.weight.data.uniform_(-initrange, initrange)
        self.logit.weight.data.uniform_(-initrange, initrange)
        self.logit.bias.data.fill_(0)

    def forward(self, fc_feats, att_feats, seq):
        batch_size = fc_feats.size(0)

        init_h = self.fc2h(fc_feats)
        init_h = init_h.unsqueeze(0)
        init_c = init_h.clone()
        state = (init_h, init_c)
        outputs = []

        for i in range(seq.size(1) - 1):
            if i >= 1 and self.ss_prob > 0.0:
                sample_prob = fc_feats.data.new(batch_size).uniform_(0, 1)
                sample_mask = sample_prob < self.ss_prob
                if sample_mask.sum() == 0:
                    it = seq[:, i].clone()
                else:
                    sample_ind = sample_mask.nonzero().view(-1)
                    it = seq[:, i].data.clone()
                    prob_prev = torch.exp(
                        outputs[-1].data
                    )  # fetch prev distribution: shape Nx(M+1)
                    it.index_copy_(
                        0, sample_ind,
                        torch.multinomial(prob_prev, 1).view(-1).index_select(
                            0, sample_ind))
                    it = Variable(it, requires_grad=False)
            else:
                it = seq[:, i].clone()

            # break if all the sequences end
            if seq[:, i].data.sum() == 0:
                break

            xt = self.embed(it)

            output, state = self.core.forward(xt, att_feats, state)
            output = F.log_softmax(self.logit(output.squeeze(0)), dim=1)
            outputs.append(output)

        vocab_log_probs = torch.cat([_.unsqueeze(1) for _ in outputs],
                                    1).contiguous()

        return vocab_log_probs  # e.g. batch * 19 * vocab_size

    def sample_beam(self, fc_feats, att_feats, init_index, opt={}):
        beam_size = opt.get('beam_size', 3)
        batch_size = fc_feats.size(0)
        fc_feat_size = fc_feats.size(1)

        seq = torch.LongTensor(self.seq_length, batch_size).zero_()
        seqLogprobs = torch.FloatTensor(self.seq_length, batch_size)

        top_seq = []
        top_prob = [[] for _ in range(batch_size)]

        self.done_beams = [[] for _ in range(batch_size)]

        for k in range(batch_size):
            init_h = self.fc2h(fc_feats[k].unsqueeze(0).expand(
                beam_size, fc_feat_size))
            init_h = init_h.unsqueeze(0)
            init_c = init_h.clone()
            state = (init_h, init_c)

            att_feats_current = att_feats[k].unsqueeze(0).expand(
                beam_size, att_feats.size(1), att_feats.size(2))
            att_feats_current = att_feats_current.contiguous()

            beam_seq = torch.LongTensor(self.seq_length, beam_size).zero_()
            beam_seq_logprobs = torch.FloatTensor(self.seq_length,
                                                  beam_size).zero_()
            beam_logprobs_sum = torch.zeros(
                beam_size)  # running sum of logprobs for each beam

            for t in range(self.seq_length + 1):
                if t == 0:
                    it = fc_feats.data.new(beam_size).long().fill_(init_index)
                    xt = self.embed(Variable(it, requires_grad=False))
                else:
                    # use a float copy for cheaper indexing operations below
                    logprobsf = logprobs.float()

                    # sorted array of logprobs along each previous beam (last True = descending)
                    # ys: beam_size * (vocab_size + 1)
                    ys, ix = torch.sort(logprobsf, 1, True)

                    candidates = []

                    cols = min(beam_size, ys.size(1))
                    rows = beam_size
                    if t == 1:  # at first time step only the first beam is active
                        rows = 1

                    for c in range(cols):
                        for q in range(rows):
                            # compute logprob of expanding beam q with word in (sorted) position c
                            local_logprob = ys[q, c]
                            candidate_logprob = beam_logprobs_sum[
                                q] + local_logprob
                            if t > 1 and beam_seq[t - 2, q] == 0:
                                continue
                            candidates.append({
                                'c': ix.data[q, c],
                                'q': q,
                                'p': candidate_logprob.data[0],
                                'r': local_logprob.data[0]
                            })

                    if len(candidates) == 0:
                        break
                    candidates = sorted(candidates, key=lambda x: -x['p'])

                    # construct new beams
                    new_state = [_.clone() for _ in state]
                    if t > 1:
                        # we'll need these as reference when we fork beams around
                        beam_seq_prev = beam_seq[:t - 1].clone()
                        beam_seq_logprobs_prev = beam_seq_logprobs[:t -
                                                                   1].clone()

                    for vix in range(min(beam_size, len(candidates))):
                        v = candidates[vix]
                        # fork beam index q into index vix
                        if t > 1:
                            beam_seq[:t - 1, vix] = beam_seq_prev[:, v['q']]
                            beam_seq_logprobs[:t - 1,
                                              vix] = beam_seq_logprobs_prev[:, v[
                                                  'q']]

                        # rearrange recurrent states
                        for state_ix in range(len(new_state)):
                            # copy over state in previous beam q to new beam at vix
                            new_state[state_ix][0, vix] = state[state_ix][
                                0, v['q']]  # dimension one is time step

                        # append new end terminal at the end of this beam
                        beam_seq[t - 1,
                                 vix] = v['c']  # c'th word is the continuation
                        beam_seq_logprobs[t - 1,
                                          vix] = v['r']  # the raw logprob here
                        beam_logprobs_sum[vix] = v[
                            'p']  # the new (sum) logprob along this beam

                        if v['c'] == 0 or t == self.seq_length:
                            # END token special case here, or we reached the end.
                            # add the beam to a set of done beams
                            self.done_beams[k].append({
                                'seq':
                                beam_seq[:, vix].clone(),
                                'logps':
                                beam_seq_logprobs[:, vix].clone(),
                                'p':
                                beam_logprobs_sum[vix]
                            })

                    # encode as vectors
                    it = beam_seq[t - 1]
                    xt = self.embed(Variable(it.cuda(), requires_grad=False))
                    state = new_state

                output, state = self.core.forward(xt, att_feats_current, state)
                logprobs = F.log_softmax(self.logit(output), dim=1)

            self.done_beams[k] = sorted(self.done_beams[k],
                                        key=lambda x: -x['p'])
            seq[:, k] = self.done_beams[k][0][
                'seq']  # the first beam has highest cumulative score
            seqLogprobs[:, k] = self.done_beams[k][0]['logps']

            # save result
            l = len(self.done_beams[k])
            top_seq_cur = torch.LongTensor(l, self.seq_length).zero_()

            for temp_index in range(l):
                top_seq_cur[temp_index] = self.done_beams[k][temp_index][
                    'seq'].clone()
                top_prob[k].append(self.done_beams[k][temp_index]['p'])

            top_seq.append(top_seq_cur)

        # return the samples and their log likelihoods
        return seq.transpose(0, 1), seqLogprobs.transpose(0,
                                                          1), top_seq, top_prob

    def sample(self, fc_feats, att_feats, init_index, opt={}):
        sample_max = opt.get('sample_max', 1)
        beam_size = opt.get('beam_size', 1)
        temperature = opt.get('temperature', 1.0)

        if beam_size > 1:
            return self.sample_beam(fc_feats, att_feats, init_index, opt)

        batch_size = fc_feats.size(0)
        seq = []
        seqLogprobs = []
        logprobs_all = []

        init_h = self.fc2h(fc_feats)
        init_h = init_h.unsqueeze(0)
        init_c = init_h.clone()
        state = (init_h, init_c)

        for t in range(self.seq_length):
            if t == 0:
                it = fc_feats.data.new(batch_size).long().fill_(init_index)
            elif sample_max:
                sampleLogprobs, it = torch.max(logprobs.data, 1)
                it = it.view(-1).long()
            else:
                if temperature == 1.0:
                    prob_prev = torch.exp(logprobs.data).cpu(
                    )  # fetch prev distribution: shape Nx(M+1)
                else:
                    # scale logprobs by temperature
                    prob_prev = torch.exp(torch.div(logprobs.data,
                                                    temperature)).cpu()

                it = torch.multinomial(prob_prev, 1).cuda()
                sampleLogprobs = logprobs.gather(
                    1,
                    Variable(it, requires_grad=False).cuda(
                    ))  # gather the logprobs at sampled positions
                it = it.view(
                    -1).long()  # and flatten indices for downstream processing

            xt = self.embed(Variable(it, requires_grad=False).cuda())

            if t >= 1:
                if t == 1:
                    unfinished = it > 0
                else:
                    unfinished *= (it > 0)
                if unfinished.sum() == 0:
                    break
                it = it * unfinished.type_as(it)
                seq.append(it)
                seqLogprobs.append(sampleLogprobs.view(-1))

            output, state = self.core.forward(xt, att_feats, state)

            logprobs = F.log_softmax(self.logit(output), dim=1)
            logprobs_all.append(logprobs)

        greedy_seq = torch.cat([_.unsqueeze(1) for _ in seq], 1)
        greedy_seqLogprobs = torch.cat([_.unsqueeze(1) for _ in seqLogprobs],
                                       1)
        greedy_logprobs_all = torch.cat([_.unsqueeze(1) for _ in logprobs_all],
                                        1).contiguous()

        return greedy_seq, greedy_seqLogprobs, greedy_logprobs_all

    def teacher_forcing_get_hidden_states(self, fc_feats, att_feats, seq):
        batch_size = fc_feats.size(0)

        init_h = self.fc2h(fc_feats)
        init_h = init_h.unsqueeze(0)
        init_c = init_h.clone()
        state = (init_h, init_c)
        outputs = []

        for i in range(seq.size(1)):
            if i >= 1 and self.ss_prob > 0.0:  # otherwise no need to sample
                sample_prob = fc_feats.data.new(batch_size).uniform_(0, 1)
                sample_mask = sample_prob < self.ss_prob
                if sample_mask.sum() == 0:
                    it = seq[:, i].clone()
                else:
                    sample_ind = sample_mask.nonzero().view(-1)
                    it = seq[:, i].data.clone()
                    prob_prev = torch.exp(
                        outputs[-1].data
                    )  # fetch prev distribution: shape Nx(M+1)
                    it.index_copy_(
                        0, sample_ind,
                        torch.multinomial(prob_prev, 1).view(-1).index_select(
                            0, sample_ind))
                    it = Variable(it, requires_grad=False)
            else:
                it = seq[:, i].clone()

            # break if all the sequences end
            if i >= 1 and seq[:, i].data.sum() == 0:
                break

            xt = self.embed(it)

            output, state = self.core.forward(xt, att_feats, state)
            if batch_size == 1:
                output = F.log_softmax(self.logit(output), dim=1)
            else:
                output = F.log_softmax(self.logit(output.squeeze(0)), dim=1)
            outputs.append(output)

        # return hidden states
        return state[0], outputs

    def free_running_get_hidden_states(self, fc_feats, att_feats, init_index,
                                       end_index):
        batch_size = fc_feats.size(0)
        logprobs_all = []

        init_h = self.fc2h(fc_feats)
        init_h = init_h.unsqueeze(0)
        init_c = init_h.clone()
        state = (init_h, init_c)

        for t in range(self.seq_length):
            if t == 0:  # input BOS
                it = fc_feats.data.new(batch_size).long().fill_(init_index)

            xt = self.embed(Variable(it, requires_grad=False))

            output, state = self.core.forward(xt, att_feats, state)

            if batch_size == 1:
                logprobs = F.log_softmax(self.logit(output), dim=1)
            else:
                logprobs = F.log_softmax(self.logit(output.squeeze(0)), dim=1)
            logprobs_all.append(logprobs)

            _, it = torch.max(logprobs.data, 1)
            it = it.view(-1).long()
            if it.cpu().numpy()[0] == end_index:  # only checks the first sequence (batch_size == 1)
                break

        return state[0], logprobs_all
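

# Decoding sketch for the plain model above (not part of the original example).
# bos_index is the assumed id of the <bos> token; sample() and sample_beam()
# call .cuda() on their inputs, so this assumes a CUDA-capable setup and a
# defined LSTMSoftAttentionCore.
def _decode_sketch(model, fc_feats, att_feats, bos_index=1):
    # greedy decoding: take the argmax word at every step
    greedy_seq, greedy_logprobs, _ = model.sample(
        fc_feats, att_feats, bos_index, {'sample_max': 1, 'beam_size': 1})

    # beam search: the best sequence per image plus every finished candidate
    # beam (top_seq) and its cumulative log-probability (top_prob)
    beam_seq, beam_logprobs, top_seq, top_prob = model.sample(
        fc_feats, att_feats, bos_index, {'beam_size': 3})

    return greedy_seq, beam_seq, top_seq, top_prob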