Example #1
    def __init__(self, opt):
        super(AttEncodeDecodeARNet, self).__init__()

        self.token_cnt = opt.token_cnt
        self.word_cnt = opt.word_cnt
        self.lstm_size = opt.lstm_size
        self.drop_prob = opt.drop_prob
        self.input_encoding_size = opt.input_encoding_size
        self.encode_time_step = opt.code_truncate
        self.decode_time_step = opt.comment_truncate
        self.ss_prob = opt.ss_prob

        self.encoding_feat_size = opt.lstm_size
        self.encoding_att_size = opt.encoding_att_size
        self.att_hidden_size = opt.att_hidden_size

        self.encode_lstm = LSTMCore(self.input_encoding_size, self.lstm_size,
                                    self.drop_prob)
        self.decode_lstm = LSTMSoftAttentionCore(
            self.input_encoding_size, self.lstm_size, self.encoding_feat_size,
            self.encoding_att_size, self.att_hidden_size, self.drop_prob)

        self.embed = nn.Embedding(self.token_cnt + 1, self.input_encoding_size)
        self.logit = nn.Linear(self.lstm_size, self.word_cnt)
        self.init_weights()

        # ARNet
        self.rcst_weight = opt.rcst_weight
        self.rcst_lstm = LSTMCore(self.lstm_size, self.lstm_size,
                                  self.drop_prob)
        self.h_2_pre_h = nn.Linear(self.lstm_size, self.lstm_size)
        self.rcst_init_weights()
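
Each of these constructors reads its hyperparameters from a single opt object. Below is a minimal sketch of building one, assuming an argparse.Namespace; the field names mirror the attributes read above, while the concrete values are illustrative assumptions rather than the original settings:

from argparse import Namespace

opt = Namespace(
    token_cnt=30000,          # code-token vocabulary size (assumed)
    word_cnt=10000,           # comment-word vocabulary size (assumed)
    lstm_size=512,
    drop_prob=0.5,
    input_encoding_size=512,
    code_truncate=300,        # encoder time steps
    comment_truncate=20,      # decoder time steps
    ss_prob=0.0,              # scheduled-sampling probability
    encoding_att_size=300,    # number of encoder states attended over
    att_hidden_size=512,
    rcst_weight=0.005,        # lambda weighting the ARNet reconstruction loss
)
# model = AttEncodeDecodeARNet(opt)  # also needs LSTMCore / LSTMSoftAttentionCore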
Example #2
    def __init__(self, opt):
        super(EncoderDecoder, self).__init__()

        self.vocab_size = opt.vocab_size
        self.input_encoding_size = opt.input_encoding_size
        self.lstm_size = opt.lstm_size
        self.drop_prob_lm = opt.drop_prob_lm
        self.seq_length = opt.seq_length
        self.fc_feat_size = opt.fc_feat_size

        self.img_embed = nn.Linear(self.fc_feat_size, self.input_encoding_size)
        self.LSTMCore = LSTMCore(self.input_encoding_size, self.lstm_size, self.drop_prob_lm)
        
        self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size)
        self.logit = nn.Linear(self.lstm_size, self.vocab_size)

        self.init_weights()
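
All of these examples assume an LSTMCore(input_size, lstm_size, drop_prob) module that advances one time step and returns (output, (h, c)), with states shaped (1, batch, lstm_size) and outputs (batch, lstm_size). Here is a minimal sketch of that contract built on nn.LSTMCell; it reconstructs the assumed interface, not the original implementation:

import torch
import torch.nn as nn

class LSTMCore(nn.Module):
    """Sketch of the single-step LSTM contract assumed by these snippets:
    xt is (batch, input_size), state is a pair of (1, batch, lstm_size)
    tensors, and the returned output is (batch, lstm_size)."""
    def __init__(self, input_size, lstm_size, drop_prob):
        super(LSTMCore, self).__init__()
        self.cell = nn.LSTMCell(input_size, lstm_size)
        self.dropout = nn.Dropout(drop_prob)

    def forward(self, xt, state):
        h, c = self.cell(xt, (state[0].squeeze(0), state[1].squeeze(0)))
        return self.dropout(h), (h.unsqueeze(0), c.unsqueeze(0))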
Example #3
    def __init__(self, opt):
        super(ReviewNet, self).__init__()

        self.token_cnt = opt.token_cnt
        self.word_cnt = opt.word_cnt
        self.lstm_size = opt.lstm_size
        self.drop_prob = opt.drop_prob

        self.input_encoding_size = opt.input_encoding_size
        self.encode_time_step = opt.code_truncate
        self.decode_time_step = opt.comment_truncate
        self.ss_prob = opt.ss_prob

        self.encoding_feat_size = opt.lstm_size
        self.encoding_att_size = opt.encoding_att_size
        self.att_hidden_size = opt.att_hidden_size
        self.num_review_steps = opt.num_review_steps
        self.drop_prob_reason = opt.drop_prob_reason

        # encoder
        self.encode_lstm = LSTMCore(self.input_encoding_size, self.lstm_size, self.drop_prob)

        # reviewer
        self.review_steps = nn.ModuleList([LSTMSoftAttentionNoInputCore(self.input_encoding_size,
                                                                        self.lstm_size,
                                                                        self.encoding_feat_size,
                                                                        self.encoding_att_size,
                                                                        self.att_hidden_size,
                                                                        self.drop_prob_reason)
                                           for _ in range(self.num_review_steps)])

        # decoder
        self.decode_lstm = LSTMSoftAttentionCore(self.input_encoding_size,
                                                 self.lstm_size,
                                                 self.encoding_feat_size,
                                                 self.num_review_steps,
                                                 self.att_hidden_size,
                                                 self.drop_prob)

        self.embed = nn.Embedding(self.token_cnt + 1, self.input_encoding_size)
        self.logit = nn.Linear(self.lstm_size, self.word_cnt)
        self.init_weights()
Example #4
    def __init__(self, opt):
        super(EncodeDecode, self).__init__()

        self.token_cnt = opt.token_cnt
        self.word_cnt = opt.word_cnt
        self.lstm_size = opt.lstm_size
        self.drop_prob = opt.drop_prob
        self.input_encoding_size = opt.input_encoding_size
        self.encode_time_step = opt.code_truncate
        self.decode_time_step = opt.comment_truncate
        self.ss_prob = opt.ss_prob

        self.encode_lstm = LSTMCore(self.input_encoding_size, self.lstm_size,
                                    self.drop_prob)
        self.decode_lstm = LSTMCore(self.input_encoding_size, self.lstm_size,
                                    self.drop_prob)

        self.embed = nn.Embedding(self.token_cnt + 1, self.input_encoding_size)
        self.logit = nn.Linear(self.lstm_size, self.word_cnt)

        self.init_weights()
Example #5
    def __init__(self, opt):
        super(EncodeDecodeARNet, self).__init__()

        self.token_cnt = opt.token_cnt
        self.word_cnt = opt.word_cnt
        self.lstm_size = opt.lstm_size
        self.drop_prob = opt.drop_prob
        self.input_encoding_size = opt.input_encoding_size
        self.encode_time_step = opt.code_truncate
        self.decode_time_step = opt.comment_truncate

        self.encode_lstm = LSTMCore(self.input_encoding_size, self.lstm_size,
                                    self.drop_prob)
        self.decode_lstm = LSTMCore(self.input_encoding_size, self.lstm_size,
                                    self.drop_prob)

        self.embed = nn.Embedding(self.token_cnt + 1, self.input_encoding_size)
        self.logit = nn.Linear(self.lstm_size, self.word_cnt)
        self.init_weights()

        # params of ARNet
        self.rcst_weight = opt.reconstruct_weight
        self.rcst_lstm = LSTMCore(self.lstm_size, self.lstm_size,
                                  self.drop_prob)
        self.h_2_pre_h = nn.Linear(self.lstm_size, self.lstm_size)
        self.rcst_init_weights()
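
The ARNet additions (rcst_lstm, h_2_pre_h, rcst_weight) implement a reconstruction regularizer: at each decoding step the current hidden state is fed through an extra LSTM and a linear map, and the result is penalized for deviating from the previous hidden state. A condensed sketch of one step of that loss follows, with shapes inferred from the forward passes in the later examples; this is not the verbatim implementation:

import torch

def arnet_step_loss(rcst_lstm, h_2_pre_h, decode_output, rcst_state,
                    pre_h, mask_t, rcst_weight):
    # decode_output: (batch, lstm_size); pre_h: (1, batch, lstm_size);
    # mask_t: (batch,) 0/1 mask for the current decoding step.
    rcst_output, rcst_state = rcst_lstm(decode_output, rcst_state)
    rcst_h = h_2_pre_h(rcst_output)       # predict the previous hidden state
    diff = rcst_h - pre_h.squeeze(0)      # (batch, lstm_size)
    masked_sq = (diff * diff) * mask_t.unsqueeze(1)
    loss = rcst_weight * masked_sq.sum() / mask_t.sum()
    return loss, rcst_state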
Example #6
    def __init__(self, opt):
        super(EncoderDecoder, self).__init__()

        self.vocab_size = opt.vocab_size
        self.input_encoding_size = opt.input_encoding_size
        self.lstm_size = opt.lstm_size
        self.drop_prob_lm = opt.drop_prob_lm
        self.ss_prob = opt.ss_prob
        self.seq_length = opt.seq_length
        self.fc_feat_size = opt.fc_feat_size

        self.img_embed = nn.Linear(self.fc_feat_size, self.input_encoding_size)
        self.LSTMCore = LSTMCore(self.input_encoding_size, self.lstm_size, self.drop_prob_lm)

        self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size)
        self.logit = nn.Linear(self.lstm_size, self.vocab_size)
        self.init_weights()

        # parameters of ARNet
        self.rcst_weight = opt.rcst_weight
        self.rcst_lstm = LSTMCore(self.lstm_size, self.lstm_size, self.drop_prob_lm)
        self.h_2_pre_h = nn.Linear(self.lstm_size, self.lstm_size)
        self.rcst_init_weights()
Example #7
    def __init__(self, opt):
        super(ShowAttendTellModel, self).__init__()

        self.vocab_size = opt.vocab_size
        self.input_encoding_size = opt.input_encoding_size
        self.lstm_size = opt.lstm_size
        # self.drop_prob_lm = opt.drop_prob_lm
        self.drop_prob_lm = 0.1
        self.seq_length = opt.seq_length

        self.fc_feat_size = opt.fc_feat_size
        self.conv_feat_size = opt.conv_feat_size
        self.conv_att_size = opt.conv_att_size
        self.att_hidden_size = opt.att_hidden_size
        self.ss_prob = opt.ss_prob  # scheduled sampling probability

        self.fc2h = nn.Linear(self.fc_feat_size, self.lstm_size)

        self.core = LSTMSoftAttentionCore(self.input_encoding_size,
                                          self.lstm_size, self.conv_feat_size,
                                          self.conv_att_size,
                                          self.att_hidden_size,
                                          self.drop_prob_lm)

        self.embed = nn.Embedding(self.vocab_size + 1,
                                  self.input_encoding_size)
        self.logit = nn.Linear(self.lstm_size, self.vocab_size)

        # add the following parameters for ARNet
        self.rcst_time = opt.rcst_time
        self.rcst_scale = opt.rcst_weight  # lambda in ARNet
        self.rcstLSTM = LSTMCore(
            self.lstm_size, self.lstm_size,
            self.drop_prob_lm)  # ARNet is realized by LSTM network
        self.h_2_pre_h = nn.Linear(
            self.lstm_size, self.lstm_size)  # fully connected layer in ARNet

        self.init_weights()
Example #8
class AttEncodeDecodeARNet(nn.Module):
    def __init__(self, opt):
        super(AttEncodeDecodeARNet, self).__init__()

        self.token_cnt = opt.token_cnt
        self.word_cnt = opt.word_cnt
        self.lstm_size = opt.lstm_size
        self.drop_prob = opt.drop_prob
        self.input_encoding_size = opt.input_encoding_size
        self.encode_time_step = opt.code_truncate
        self.decode_time_step = opt.comment_truncate
        self.ss_prob = opt.ss_prob

        self.encoding_feat_size = opt.lstm_size
        self.encoding_att_size = opt.encoding_att_size
        self.att_hidden_size = opt.att_hidden_size

        self.encode_lstm = LSTMCore(self.input_encoding_size, self.lstm_size,
                                    self.drop_prob)
        self.decode_lstm = LSTMSoftAttentionCore(
            self.input_encoding_size, self.lstm_size, self.encoding_feat_size,
            self.encoding_att_size, self.att_hidden_size, self.drop_prob)

        self.embed = nn.Embedding(self.token_cnt + 1, self.input_encoding_size)
        self.logit = nn.Linear(self.lstm_size, self.word_cnt)
        self.init_weights()

        # ARNet
        self.rcst_weight = opt.rcst_weight
        self.rcst_lstm = LSTMCore(self.lstm_size, self.lstm_size,
                                  self.drop_prob)
        self.h_2_pre_h = nn.Linear(self.lstm_size, self.lstm_size)
        self.rcst_init_weights()

    def init_weights(self):
        self.embed.weight.data.uniform_(-0.1, 0.1)
        self.logit.weight.data.uniform_(-0.1, 0.1)
        self.logit.bias.data.fill_(0)

    def init_hidden(self, batch_size):
        weight = next(self.parameters()).data
        init_h = Variable(weight.new(1, batch_size, self.lstm_size).zero_())
        init_c = Variable(weight.new(1, batch_size, self.lstm_size).zero_())
        init_state = (init_h, init_c)
        return init_state

    # init params of ARNet
    def rcst_init_weights(self):
        self.h_2_pre_h.weight.data.uniform_(-0.1, 0.1)
        self.h_2_pre_h.bias.data.fill_(0)

    def forward(self, code_matrix, comment_matrix, comment_mask):
        batch_size = code_matrix.size(0)
        encode_state = self.init_hidden(batch_size)
        decode_logit_seq = []
        outputs = []

        # encoder
        encode_hidden_states = []
        for i in range(self.encode_time_step):
            encode_words = code_matrix[:, i].clone()

            if code_matrix[:, i].data.sum() == 0:
                break

            encode_xt = self.embed(encode_words)
            encode_output, encode_state = self.encode_lstm.forward(
                encode_xt, encode_state)
            encode_hidden_states.append(encode_output)
        encode_hidden_states = torch.cat(
            [_.unsqueeze(1) for _ in encode_hidden_states],
            1)  # batch x 300 x 512

        # decoder
        decode_state = (encode_state[0].clone(), encode_state[1].clone())
        rcst_state = (encode_state[0].clone(), encode_state[1].clone())
        pre_h = encode_state[0].clone()
        rcst_loss = 0.0

        for i in range(self.decode_time_step):
            if i >= 1 and self.ss_prob > 0.0:
                sample_prob = comment_mask.data.new(batch_size).uniform_(0, 1)
                sample_mask = sample_prob < self.ss_prob
                if sample_mask.sum() == 0:
                    it = comment_matrix[:, i].clone()
                else:
                    sample_ind = sample_mask.nonzero().view(-1)
                    it = comment_matrix[:, i].data.clone()
                    prob_prev = torch.exp(
                        outputs[-1].data
                    )  # fetch prev distribution: shape Nx(M+1)
                    it.index_copy_(
                        0, sample_ind,
                        torch.multinomial(prob_prev, 1).view(-1).index_select(
                            0, sample_ind))
                    it = Variable(it, requires_grad=False)
            else:
                it = comment_matrix[:, i].clone()

            if i >= 1 and comment_matrix[:, i].data.sum() == 0:
                break

            decode_xt = self.embed(it)
            decode_output, decode_state = self.decode_lstm.forward(
                decode_xt, encode_hidden_states, decode_state)
            decode_logit_words = F.log_softmax(self.logit(decode_output))
            decode_logit_seq.append(decode_logit_words)
            outputs.append(decode_logit_words)

            # ARNet part
            rcst_output, rcst_state = self.rcst_lstm.forward(
                decode_output, rcst_state)
            rcst_h = self.h_2_pre_h(rcst_output)
            rcst_diff = rcst_h - pre_h
            rcst_mask = comment_mask[:, i].contiguous().view(
                -1, batch_size).repeat(1, self.lstm_size)
            cur_rcst_loss = torch.sum(
                torch.sum(torch.mul(rcst_diff, rcst_diff) * rcst_mask, dim=1))
            rcst_loss += cur_rcst_loss * self.rcst_weight / torch.sum(
                comment_mask[:, i])

            # update previous hidden state
            pre_h = decode_state[0].clone()

        # aggregate
        decode_logit_seq = torch.cat(
            [_.unsqueeze(1) for _ in decode_logit_seq], 1).contiguous()

        return decode_logit_seq, rcst_loss

    def sample(self, code_matrix, init_index, eos_index):
        batch_size = code_matrix.size(0)
        encode_state = self.init_hidden(batch_size)

        seq = []
        seqLogprobs = []
        logprobs_all = []

        # encoder
        encode_hidden_states = []
        for i in range(self.encode_time_step):
            encode_words = code_matrix[:, i].clone()

            if code_matrix[:, i].data.sum() == 0:
                break

            encode_xt = self.embed(encode_words)
            encode_output, encode_state = self.encode_lstm.forward(
                encode_xt, encode_state)
            encode_hidden_states.append(encode_output)
        encode_hidden_states = torch.cat(
            [_.unsqueeze(1) for _ in encode_hidden_states], 1)

        # decoder
        decode_state = (encode_state[0].clone(), encode_state[1].clone())
        for i in range(self.decode_time_step):
            if i == 0:
                it = code_matrix.data.new(batch_size).long().fill_(init_index)
                decode_xt = self.embed(
                    Variable(it, requires_grad=False).cuda())
                decode_output, decode_state = self.decode_lstm.forward(
                    decode_xt, encode_hidden_states, decode_state)
            else:
                max_logprobs, it = torch.max(logprobs.data, 1)
                it = it.view(-1).long()

                if it.sum() == eos_index:  # stop check; strictly correct only for batch size 1
                    break

                decode_xt = self.embed(
                    Variable(it, requires_grad=False).cuda())
                decode_output, decode_state = self.decode_lstm.forward(
                    decode_xt, encode_hidden_states, decode_state)

                seq.append(it)
                seqLogprobs.append(max_logprobs.view(-1))

            logprobs = F.log_softmax(self.logit(decode_output))
            logprobs_all.append(logprobs)

        greedy_seq = torch.cat([_.unsqueeze(1) for _ in seq], 1).contiguous()
        greedy_seq_probs = torch.cat([_.unsqueeze(1) for _ in seqLogprobs],
                                     1).contiguous()
        greedy_logprobs_all = torch.cat([_.unsqueeze(1) for _ in logprobs_all],
                                        1).contiguous()

        return greedy_seq, greedy_seq_probs, greedy_logprobs_all
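
A hypothetical call into the greedy sample method above, assuming model = AttEncodeDecodeARNet(opt) has been built as in the sketch under Example #1; the token indices and dummy inputs are illustrative, and a GPU is required because sample calls .cuda() internally:

import torch

code_matrix = torch.zeros(4, opt.code_truncate).long().cuda()  # batch of 4
code_matrix[:, 0] = 5                                          # dummy code tokens
seq, seq_logprobs, logprobs_all = model.sample(
    code_matrix, init_index=1, eos_index=2)
# seq: (batch, steps) greedy token ids; seq_logprobs: their log-probabilities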
Example #9
class ReviewNet(nn.Module):
    def __init__(self, opt):
        super(ReviewNet, self).__init__()

        self.token_cnt = opt.token_cnt
        self.word_cnt = opt.word_cnt
        self.lstm_size = opt.lstm_size
        self.drop_prob = opt.drop_prob

        self.input_encoding_size = opt.input_encoding_size
        self.encode_time_step = opt.code_truncate
        self.decode_time_step = opt.comment_truncate
        self.ss_prob = opt.ss_prob

        self.encoding_feat_size = opt.lstm_size
        self.encoding_att_size = opt.encoding_att_size
        self.att_hidden_size = opt.att_hidden_size
        self.num_review_steps = opt.num_review_steps
        self.drop_prob_reason = opt.drop_prob_reason

        # encoder
        self.encode_lstm = LSTMCore(self.input_encoding_size, self.lstm_size, self.drop_prob)

        # reviewer
        self.review_steps = nn.ModuleList([LSTMSoftAttentionNoInputCore(self.input_encoding_size,
                                                                        self.lstm_size,
                                                                        self.encoding_feat_size,
                                                                        self.encoding_att_size,
                                                                        self.att_hidden_size,
                                                                        self.drop_prob_reason)
                                           for _ in range(self.num_review_steps)])

        # decoder
        self.decode_lstm = LSTMSoftAttentionCore(self.input_encoding_size,
                                                 self.lstm_size,
                                                 self.encoding_feat_size,
                                                 self.num_review_steps,
                                                 self.att_hidden_size,
                                                 self.drop_prob)

        self.embed = nn.Embedding(self.token_cnt + 1, self.input_encoding_size)
        self.logit = nn.Linear(self.lstm_size, self.word_cnt)
        self.init_weights()

    def init_weights(self):
        self.embed.weight.data.uniform_(-0.1, 0.1)
        self.logit.weight.data.uniform_(-0.1, 0.1)
        self.logit.bias.data.fill_(0)

    def init_hidden(self, batch_size):
        weight = next(self.parameters()).data
        init_h = Variable(weight.new(1, batch_size, self.lstm_size).zero_())
        init_c = Variable(weight.new(1, batch_size, self.lstm_size).zero_())
        init_state = (init_h, init_c)
        return init_state

    def forward(self, code_matrix, comment_matrix, current_comment_mask_cuda):
        batch_size = code_matrix.size(0)
        encode_state = self.init_hidden(batch_size)
        decode_logit_seq = []
        outputs = []

        # encoder
        encode_hidden_states = []
        for i in range(self.encode_time_step):
            encode_words = code_matrix[:, i].clone()
            encode_xt = self.embed(encode_words)
            encode_output, encode_state = self.encode_lstm.forward(encode_xt, encode_state)
            encode_hidden_states.append(encode_output)
        encode_hidden_states = torch.cat([_.unsqueeze(1) for _ in encode_hidden_states], 1)  # batch x 300 x 512

        # reviewer
        review_state = (encode_state[0].clone(), encode_state[1].clone())
        thought = []
        for i in range(self.num_review_steps):
            review_output, review_state = self.review_steps[i].forward(encode_hidden_states, review_state)
            thought.append(review_output)
        thought_vectors = torch.stack(thought).transpose(0, 1).cuda().contiguous()  # thought vectors

        # decoder
        decode_state = (encode_state[0].clone(), encode_state[1].clone())
        for i in range(self.decode_time_step):
            if i >= 1 and self.ss_prob > 0.0:
                sample_prob = current_comment_mask_cuda.data.new(batch_size).uniform_(0, 1)
                sample_mask = sample_prob < self.ss_prob
                if sample_mask.sum() == 0:
                    it = comment_matrix[:, i].clone()
                else:
                    sample_ind = sample_mask.nonzero().view(-1)
                    it = comment_matrix[:, i].data.clone()
                    prob_prev = torch.exp(outputs[-1].data)  # fetch prev distribution: shape Nx(M+1)
                    it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind))
                    it = Variable(it, requires_grad=False)
            else:
                it = comment_matrix[:, i].clone()

            if i >= 1 and comment_matrix[:, i].data.sum() == 0:
                break

            decode_xt = self.embed(it)
            decode_output, decode_state = self.decode_lstm.forward(decode_xt, thought_vectors, decode_state)

            decode_logit_words = F.log_softmax(self.logit(decode_output))
            decode_logit_seq.append(decode_logit_words)
            outputs.append(decode_logit_words)

        # aggregate
        decode_logit_seq = torch.cat([_.unsqueeze(1) for _ in decode_logit_seq], 1).contiguous()

        return decode_logit_seq

    def sample(self, code_matrix, init_index, eos_index):
        batch_size = code_matrix.size(0)
        encode_state = self.init_hidden(batch_size)

        seq = []
        seqLogprobs = []
        logprobs_all = []

        # encoder
        encode_hidden_states = []
        for i in range(self.encode_time_step):
            encode_words = code_matrix[:, i].clone()

            if code_matrix[:, i].data.sum() == 0:
                break
            encode_xt = self.embed(encode_words)
            encode_output, encode_state = self.encode_lstm.forward(encode_xt, encode_state)
            encode_hidden_states.append(encode_output)
        encode_hidden_states = torch.cat([_.unsqueeze(1) for _ in encode_hidden_states], 1)

        # reviewer
        review_state = (encode_state[0].clone(), encode_state[1].clone())
        thought = []
        for i in range(self.num_review_steps):
            review_output, review_state = self.review_steps[i].forward(encode_hidden_states, review_state)
            thought.append(review_output)
        thought_vectors = torch.stack(thought).transpose(0, 1).cuda().contiguous()  # thought vectors

        # decoder
        decode_state = (encode_state[0].clone(), encode_state[1].clone())
        for i in range(self.decode_time_step):
            if i == 0:
                it = code_matrix.data.new(batch_size).long().fill_(init_index)
                decode_xt = self.embed(Variable(it, requires_grad=False).cuda())
                decode_output, decode_state = self.decode_lstm.forward(decode_xt, thought_vectors, decode_state)
            else:
                max_logprobs, it = torch.max(logprobs.data, 1)
                it = it.view(-1).long()

                if it.sum() == eos_index:  # stop check; strictly correct only for batch size 1
                    break

                decode_xt = self.embed(Variable(it, requires_grad=False).cuda())
                decode_output, decode_state = self.decode_lstm.forward(decode_xt, thought_vectors, decode_state)

                seq.append(it)
                seqLogprobs.append(max_logprobs.view(-1))

            logprobs = F.log_softmax(self.logit(decode_output))
            logprobs_all.append(logprobs)

        # aggregate
        greedy_seq = torch.cat([_.unsqueeze(1) for _ in seq], 1).contiguous()
        greedy_seq_probs = torch.cat([_.unsqueeze(1) for _ in seqLogprobs], 1).contiguous()
        greedy_logprobs_all = torch.cat([_.unsqueeze(1) for _ in logprobs_all], 1).contiguous()

        return greedy_seq, greedy_seq_probs, greedy_logprobs_all
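
ReviewNet inserts num_review_steps input-less attention steps between encoder and decoder: each step attends over all encoder states and emits one thought vector, and the decoder then attends over those thought vectors instead of the raw encoder states. A condensed sketch of the reviewer loop, assuming each LSTMSoftAttentionNoInputCore maps (encoder_states, state) to (output, state):

import torch

def run_reviewer(review_steps, encode_hidden_states, encode_state):
    review_state = (encode_state[0].clone(), encode_state[1].clone())
    thought = []
    for step in review_steps:  # the nn.ModuleList built in __init__
        review_output, review_state = step(encode_hidden_states, review_state)
        thought.append(review_output)
    # (num_review_steps, batch, lstm_size) -> (batch, num_review_steps, lstm_size)
    return torch.stack(thought).transpose(0, 1).contiguous()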
Example #10
class EncodeDecodeARNet(nn.Module):
    def __init__(self, opt):
        super(EncodeDecodeARNet, self).__init__()

        self.token_cnt = opt.token_cnt
        self.word_cnt = opt.word_cnt
        self.lstm_size = opt.lstm_size
        self.drop_prob = opt.drop_prob
        self.input_encoding_size = opt.input_encoding_size
        self.encode_time_step = opt.code_truncate
        self.decode_time_step = opt.comment_truncate

        self.encode_lstm = LSTMCore(self.input_encoding_size, self.lstm_size,
                                    self.drop_prob)
        self.decode_lstm = LSTMCore(self.input_encoding_size, self.lstm_size,
                                    self.drop_prob)

        self.embed = nn.Embedding(self.token_cnt + 1, self.input_encoding_size)
        self.logit = nn.Linear(self.lstm_size, self.word_cnt)
        self.init_weights()

        # params of ARNet
        self.rcst_weight = opt.reconstruct_weight
        self.rcst_lstm = LSTMCore(self.lstm_size, self.lstm_size,
                                  self.drop_prob)
        self.h_2_pre_h = nn.Linear(self.lstm_size, self.lstm_size)
        self.rcst_init_weights()

    def init_weights(self):
        self.embed.weight.data.uniform_(-0.1, 0.1)
        self.logit.weight.data.uniform_(-0.1, 0.1)
        self.logit.bias.data.fill_(0)

    def init_hidden(self, batch_size):
        weight = next(self.parameters()).data
        init_h = Variable(weight.new(1, batch_size, self.lstm_size).zero_())
        init_c = Variable(weight.new(1, batch_size, self.lstm_size).zero_())
        init_state = (init_h, init_c)
        return init_state

    # init params of ARNet
    def rcst_init_weights(self):
        self.h_2_pre_h.weight.data.uniform_(-0.1, 0.1)
        self.h_2_pre_h.bias.data.fill_(0)

    # copy weights from a model pre-trained with cross entropy (MLE)
    def copy_weights(self, model_path):
        src_weights = torch.load(model_path)
        own_dict = self.state_dict()
        for key, var in src_weights.items():
            print("copy weights: {}  size: {}".format(key, var.size()))
            own_dict[key].copy_(var)

    def forward(self, code_matrix, comment_matrix, comment_mask):
        batch_size = code_matrix.size(0)
        encode_state = self.init_hidden(batch_size)
        decode_logit_seq = []

        # encoder
        for i in range(self.encode_time_step):
            encode_words = code_matrix[:, i].clone()

            if code_matrix[:, i].data.sum() == 0:
                break

            encode_xt = self.embed(encode_words)
            encode_output, encode_state = self.encode_lstm.forward(
                encode_xt, encode_state)

        # decoder
        decode_state = (encode_state[0].clone(), encode_state[1].clone())
        rcst_state = (encode_state[0].clone(), encode_state[1].clone())
        pre_h = encode_state[0].clone()
        rcst_loss = 0.0

        for i in range(self.decode_time_step):
            decode_words = comment_matrix[:, i].clone()

            if comment_matrix[:, i].data.sum() == 0:
                break

            decode_xt = self.embed(decode_words)
            decode_output, decode_state = self.decode_lstm.forward(
                decode_xt, decode_state)

            decode_logit_words = F.log_softmax(self.logit(decode_output))
            decode_logit_seq.append(decode_logit_words)

            # ARNet
            rcst_output, rcst_state = self.rcst_lstm.forward(
                decode_output, rcst_state)
            rcst_h = self.h_2_pre_h(rcst_output)

            rcst_diff = rcst_h - pre_h
            rcst_mask = comment_mask[:, i].contiguous().view(
                -1, batch_size).repeat(1, self.lstm_size)

            cur_rcst_loss = torch.sum(
                torch.sum(torch.mul(rcst_diff, rcst_diff) * rcst_mask, dim=1))
            rcst_loss += cur_rcst_loss * self.rcst_weight / torch.sum(
                comment_mask[:, i])

            # update previous hidden state
            pre_h = decode_state[0].clone()

        # aggregate
        decode_logit_seq = torch.cat(
            [_.unsqueeze(1) for _ in decode_logit_seq], 1).contiguous()

        return decode_logit_seq, rcst_loss

    def sample(self, code_matrix, init_index, eos_index):
        batch_size = code_matrix.size(0)
        encode_state = self.init_hidden(batch_size)

        seq = []
        seqLogprobs = []
        logprobs_all = []

        # encoder
        for i in range(self.encode_time_step):
            encode_words = code_matrix[:, i].clone()

            if code_matrix[:, i].data.sum() == 0:
                break

            encode_xt = self.embed(encode_words)
            encode_output, encode_state = self.encode_lstm.forward(
                encode_xt, encode_state)

        # decoder
        decode_state = (encode_state[0].clone(), encode_state[1].clone())
        for i in range(self.decode_time_step):
            if i == 0:
                it = code_matrix.data.new(batch_size).long().fill_(init_index)
                decode_xt = self.embed(
                    Variable(it, requires_grad=False).cuda())
                decode_output, decode_state = self.decode_lstm.forward(
                    decode_xt, decode_state)
            else:
                max_logprobs, it = torch.max(logprobs.data, 1)
                it = it.view(-1).long()

                if it.sum() == eos_index:  # stop check; strictly correct only for batch size 1
                    break

                decode_xt = self.embed(
                    Variable(it, requires_grad=False).cuda())
                decode_output, decode_state = self.decode_lstm.forward(
                    decode_xt, decode_state)

                seq.append(it)
                seqLogprobs.append(max_logprobs.view(-1))

            logprobs = F.log_softmax(self.logit(decode_output))
            logprobs_all.append(logprobs)

        # aggregate
        greedy_seq = torch.cat([_.unsqueeze(1) for _ in seq], 1).contiguous()
        greedy_seq_probs = torch.cat([_.unsqueeze(1) for _ in seqLogprobs],
                                     1).contiguous()
        greedy_logprobs_all = torch.cat([_.unsqueeze(1) for _ in logprobs_all],
                                        1).contiguous()

        return greedy_seq, greedy_seq_probs, greedy_logprobs_all
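
copy_weights warm-starts ARNet training from a cross-entropy (MLE) pre-trained checkpoint. A hypothetical usage, reusing the opt sketch from Example #1; the checkpoint path is a placeholder:

model = EncodeDecodeARNet(opt)
model.copy_weights('pretrained/encode_decode_xe.pth')  # placeholder path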
Example #11
class ShowAttendTellModel(nn.Module):
    def __init__(self, opt):
        super(ShowAttendTellModel, self).__init__()

        self.vocab_size = opt.vocab_size
        self.input_encoding_size = opt.input_encoding_size
        self.lstm_size = opt.lstm_size
        # self.drop_prob_lm = opt.drop_prob_lm
        self.drop_prob_lm = 0.1
        self.seq_length = opt.seq_length

        self.fc_feat_size = opt.fc_feat_size
        self.conv_feat_size = opt.conv_feat_size
        self.conv_att_size = opt.conv_att_size
        self.att_hidden_size = opt.att_hidden_size
        self.ss_prob = opt.ss_prob  # scheduled sampling probability

        self.fc2h = nn.Linear(self.fc_feat_size, self.lstm_size)

        self.core = LSTMSoftAttentionCore(self.input_encoding_size,
                                          self.lstm_size, self.conv_feat_size,
                                          self.conv_att_size,
                                          self.att_hidden_size,
                                          self.drop_prob_lm)

        self.embed = nn.Embedding(self.vocab_size + 1,
                                  self.input_encoding_size)
        self.logit = nn.Linear(self.lstm_size, self.vocab_size)

        # add the following parameters for ARNet
        self.rcst_time = opt.rcst_time
        self.rcst_scale = opt.rcst_weight  # lambda in ARNet
        self.rcstLSTM = LSTMCore(
            self.lstm_size, self.lstm_size,
            self.drop_prob_lm)  # ARNet is realized by LSTM network
        self.h_2_pre_h = nn.Linear(
            self.lstm_size, self.lstm_size)  # fully connected layer in ARNet

        self.init_weights()

    def init_weights(self):
        initrange = 0.1
        self.embed.weight.data.uniform_(-initrange, initrange)
        self.fc2h.weight.data.uniform_(-initrange, initrange)
        self.logit.weight.data.uniform_(-initrange, initrange)
        self.logit.bias.data.fill_(0)

        # initialize weights of parameters in ARNet
        self.h_2_pre_h.weight.data.uniform_(-initrange, initrange)
        self.h_2_pre_h.bias.data.fill_(0)

    def copy_weights(self, model_path):
        """
        Initialize the weights of parameters from the model 
        which is pre-trained by Cross Entropy (MLE)
        """
        src_weights = torch.load(model_path)
        own_dict = self.state_dict()
        for key, var in src_weights.items():
            print("copy weights: {}  size: {}".format(key, var.size()))
            own_dict[key].copy_(var)

    def forward(self, fc_feats, att_feats, seq):
        batch_size = fc_feats.size(0)

        init_h = self.fc2h(fc_feats)
        init_h = init_h.unsqueeze(0)
        init_c = init_h.clone()
        state = (init_h, init_c)

        outputs = []
        for i in range(seq.size(1)):
            if i >= 1 and self.ss_prob > 0.0:
                sample_prob = fc_feats.data.new(batch_size).uniform_(0, 1)
                sample_mask = sample_prob < self.ss_prob
                if sample_mask.sum() == 0:
                    it = seq[:, i].clone()
                else:
                    sample_ind = sample_mask.nonzero().view(-1)
                    it = seq[:, i].data.clone()
                    prob_prev = torch.exp(
                        outputs[-1].data
                    )  # fetch prev distribution: shape Nx(M+1)
                    it.index_copy_(
                        0, sample_ind,
                        torch.multinomial(prob_prev, 1).view(-1).index_select(
                            0, sample_ind))
                    it = Variable(it, requires_grad=False)
            else:
                it = seq[:, i].clone()

            # break if all the sequences end
            if i >= 1 and seq[:, i].data.sum() == 0:
                break

            xt = self.embed(it)

            output, state = self.core.forward(xt, att_feats, state)
            output = F.log_softmax(self.logit(output.squeeze(0)), dim=1)
            outputs.append(output)

        return torch.cat([_.unsqueeze(1) for _ in outputs],
                         1).contiguous()  # batch * 19 * vocab_size

    # reconstruction part
    def rcst_forward(self, fc_feats, att_feats, seq, mask):
        batch_size = fc_feats.size(0)

        init_h = self.fc2h(fc_feats)
        init_h = init_h.unsqueeze(0)
        init_c = init_h.clone()
        state = (init_h, init_c)

        rcst_init_h = init_h.clone()
        rcst_init_c = init_c.clone()
        rcst_state = (rcst_init_h, rcst_init_c)

        pre_h = []
        output_logits = []
        rcst_loss = 0.0

        for i in range(seq.size(1) - 1):
            if i >= 1 and self.ss_prob > 0.0:  # otherwise no need to sample
                sample_prob = fc_feats.data.new(batch_size).uniform_(0, 1)
                sample_mask = sample_prob < self.ss_prob
                if sample_mask.sum() == 0:
                    it = seq[:, i].clone()
                else:
                    sample_ind = sample_mask.nonzero().view(-1)
                    it = seq[:, i].data.clone()
                    prob_prev = torch.exp(
                        output_logits[-1].data
                    )  # fetch prev distribution: shape Nx(M+1)
                    it.index_copy_(
                        0, sample_ind,
                        torch.multinomial(prob_prev, 1).view(-1).index_select(
                            0, sample_ind))
                    it = Variable(it, requires_grad=False)
            else:
                it = seq[:, i].clone()

            # break if all the sequences end
            if i >= 1 and seq[:, i].data.sum() == 0:
                break

            xt = self.embed(it)
            output, state = self.core.forward(xt, att_feats, state)
            logit_words = F.log_softmax(self.logit(output.squeeze(0)))
            output_logits.append(logit_words)

            if i >= 1:
                rcst_output, rcst_state = self.rcstLSTM.forward(
                    output, rcst_state)
                rcst_h = F.leaky_relu(self.h_2_pre_h(rcst_output))
                rcst_t = pre_h[i - 1].squeeze(dim=0)

                # -1 means not changing the size of that dimension,
                # http://pytorch.org/docs/master/tensors.html
                rcst_mask = mask[:,
                                 i].contiguous().view(batch_size, -1).expand(
                                     batch_size, self.lstm_size)
                rcst_diff = rcst_h - rcst_t
                rcst_loss += torch.sum(
                    torch.sum(torch.mul(rcst_diff, rcst_diff) * rcst_mask,
                              dim=1)) / batch_size * self.rcst_scale

            # update the previous hidden state
            pre_h.append(state[0].clone())

        output_logits = torch.cat([_.unsqueeze(1) for _ in output_logits],
                                  1).contiguous()

        return output_logits, rcst_loss

    def sample_beam(self, fc_feats, att_feats, init_index, opt={}):
        beam_size = opt.get('beam_size', 10)  # default to 10 when beam_size is not provided
        batch_size = fc_feats.size(0)
        fc_feat_size = fc_feats.size(1)

        seq = torch.LongTensor(self.seq_length, batch_size).zero_()
        seqLogprobs = torch.FloatTensor(self.seq_length, batch_size)

        top_seq = []
        top_prob = [[] for _ in range(batch_size)]

        self.done_beams = [[] for _ in range(batch_size)]

        for k in range(batch_size):
            init_h = self.fc2h(fc_feats[k].unsqueeze(0).expand(
                beam_size, fc_feat_size))
            init_h = init_h.unsqueeze(0)
            init_c = init_h.clone()
            state = (init_h, init_c)

            att_feats_current = att_feats[k].unsqueeze(0).expand(
                beam_size, att_feats.size(1), att_feats.size(2))
            att_feats_current = att_feats_current.contiguous()

            beam_seq = torch.LongTensor(self.seq_length, beam_size).zero_()
            beam_seq_logprobs = torch.FloatTensor(self.seq_length,
                                                  beam_size).zero_()
            beam_logprobs_sum = torch.zeros(
                beam_size)  # running sum of logprobs for each beam
            for t in range(self.seq_length + 1):
                if t == 0:  # input <bos>
                    it = fc_feats.data.new(beam_size).long().fill_(init_index)
                    xt = self.embed(Variable(it, requires_grad=False))
                    # xt = self.img_embed(fc_feats[k:k+1]).expand(beam_size, self.input_encoding_size)
                else:
                    """perform a beam merge. that is,
                    for every previous beam we now many new possibilities to branch out
                    we need to resort our beams to maintain the loop invariant of keeping
                    the top beam_size most likely sequences."""
                    logprobsf = logprobs.float()  # cast to float for cheaper indexing
                    # ys: beam_size * (vocab_size + 1)
                    ys, ix = torch.sort(
                        logprobsf, 1, True
                    )  # sorted array of logprobs along each previous beam (last true = descending)
                    candidates = []
                    cols = min(beam_size, ys.size(1))
                    rows = beam_size
                    if t == 1:  # at first time step only the first beam is active
                        rows = 1
                    for c in range(cols):
                        for q in range(rows):
                            # compute logprob of expanding beam q with word in (sorted) position c
                            local_logprob = ys[q, c]
                            candidate_logprob = beam_logprobs_sum[
                                q] + local_logprob
                            if t > 1 and beam_seq[t - 2, q] == 0:
                                continue
                            candidates.append({
                                'c': ix.data[q, c],
                                'q': q,
                                'p': candidate_logprob.data[0],
                                'r': local_logprob.data[0]
                            })

                    if len(candidates) == 0:
                        break
                    candidates = sorted(candidates, key=lambda x: -x['p'])

                    # construct new beams
                    new_state = [_.clone() for _ in state]
                    if t > 1:
                        # we'll need these as reference when we fork beams around
                        beam_seq_prev = beam_seq[:t - 1].clone()
                        beam_seq_logprobs_prev = beam_seq_logprobs[:t -
                                                                   1].clone()

                    for vix in range(min(beam_size, len(candidates))):
                        v = candidates[vix]
                        # fork beam index q into index vix
                        if t > 1:
                            beam_seq[:t - 1, vix] = beam_seq_prev[:, v['q']]
                            beam_seq_logprobs[:t - 1,
                                              vix] = beam_seq_logprobs_prev[:, v[
                                                  'q']]

                        # rearrange recurrent states
                        for state_ix in range(len(new_state)):
                            # copy over state in previous beam q to new beam at vix
                            new_state[state_ix][0, vix] = state[state_ix][
                                0, v['q']]  # dimension one is time step

                        # append new end terminal at the end of this beam
                        beam_seq[t - 1,
                                 vix] = v['c']  # c'th word is the continuation
                        beam_seq_logprobs[t - 1,
                                          vix] = v['r']  # the raw logprob here
                        beam_logprobs_sum[vix] = v[
                            'p']  # the new (sum) logprob along this beam

                        if v['c'] == 0 or t == self.seq_length:
                            # END token special case here, or we reached the end.
                            # add the beam to a set of done beams
                            self.done_beams[k].append({
                                'seq':
                                beam_seq[:, vix].clone(),
                                'logps':
                                beam_seq_logprobs[:, vix].clone(),
                                'p':
                                beam_logprobs_sum[vix]
                            })

                    # encode as vectors
                    it = beam_seq[t - 1]
                    xt = self.embed(Variable(it.cuda()))

                if t >= 1:
                    state = new_state

                output, state = self.core.forward(xt, att_feats_current, state)
                logprobs = F.log_softmax(self.logit(output))

            self.done_beams[k] = sorted(self.done_beams[k],
                                        key=lambda x: -x['p'])
            seq[:, k] = self.done_beams[k][0][
                'seq']  # the first beam has highest cumulative score
            seqLogprobs[:, k] = self.done_beams[k][0]['logps']

            # save result
            l = len(self.done_beams[k])
            top_seq_cur = torch.LongTensor(l, self.seq_length).zero_()

            for temp_index in range(l):
                top_seq_cur[temp_index] = self.done_beams[k][temp_index][
                    'seq'].clone()
                top_prob[k].append(self.done_beams[k][temp_index]['p'])

            top_seq.append(top_seq_cur)
        # return the samples and their log likelihoods
        return seq.transpose(0, 1), seqLogprobs.transpose(0,
                                                          1), top_seq, top_prob

    def sample(self, fc_feats, att_feats, init_index, opt={}):
        sample_max = opt.get('sample_max', 1)
        beam_size = opt.get('beam_size', 1)
        temperature = opt.get('temperature', 1.0)

        if beam_size > 1:
            return self.sample_beam(fc_feats, att_feats, init_index, opt)

        batch_size = fc_feats.size(0)
        seq = []
        seqLogprobs = []
        logprobs_all = []

        init_h = self.fc2h(fc_feats)
        init_h = init_h.unsqueeze(0)
        init_c = init_h.clone()
        state = (init_h, init_c)

        for t in range(self.seq_length):
            if t == 0:  # input BOS, 304
                it = fc_feats.data.new(batch_size).long().fill_(init_index)
            elif sample_max:
                sampleLogprobs, it = torch.max(logprobs.data, 1)
                it = it.view(-1).long()
            else:
                if temperature == 1.0:
                    prob_prev = torch.exp(logprobs.data).cpu(
                    )  # fetch prev distribution: shape Nx(M+1)
                else:
                    # scale logprobs by temperature
                    prob_prev = torch.exp(torch.div(logprobs.data,
                                                    temperature)).cpu()

                it = torch.multinomial(prob_prev, 1).cuda()

                sampleLogprobs = logprobs.gather(
                    1,
                    Variable(it, requires_grad=False).cuda(
                    ))  # gather the logprobs at sampled positions
                it = it.view(
                    -1).long()  # and flatten indices for downstream processing

            xt = self.embed(Variable(it, requires_grad=False).cuda())

            if t >= 1:
                # stop when all finished
                if t == 1:
                    unfinished = it > 0
                else:
                    unfinished = unfinished * (it > 0)
                if unfinished.sum() == 0:
                    break
                it = it * unfinished.type_as(it)
                seq.append(it)
                seqLogprobs.append(sampleLogprobs.view(-1))

            output, state = self.core.forward(xt, att_feats, state)

            logprobs = F.log_softmax(self.logit(output))
            logprobs_all.append(logprobs)

        greedy_seq = torch.cat([_.unsqueeze(1) for _ in seq], 1)
        greedy_seqLogprobs = torch.cat([_.unsqueeze(1) for _ in seqLogprobs],
                                       1)
        greedy_logprobs_all = torch.cat([_.unsqueeze(1) for _ in logprobs_all],
                                        1).contiguous()

        return greedy_seq, greedy_seqLogprobs, greedy_logprobs_all

    def teacher_forcing_get_hidden_states(self, fc_feats, att_feats, seq):
        batch_size = fc_feats.size(0)

        init_h = self.fc2h(fc_feats)
        init_h = init_h.unsqueeze(0)
        init_c = init_h.clone()
        state = (init_h, init_c)
        outputs = []

        for i in range(seq.size(1)):
            if i >= 1 and self.ss_prob > 0.0:  # otherwise no need to sample
                sample_prob = fc_feats.data.new(batch_size).uniform_(0, 1)
                sample_mask = sample_prob < self.ss_prob
                if sample_mask.sum() == 0:
                    it = seq[:, i].clone()
                else:
                    sample_ind = sample_mask.nonzero().view(-1)
                    it = seq[:, i].data.clone()
                    prob_prev = torch.exp(
                        outputs[-1].data
                    )  # fetch prev distribution: shape Nx(M+1)
                    it.index_copy_(
                        0, sample_ind,
                        torch.multinomial(prob_prev, 1).view(-1).index_select(
                            0, sample_ind))
                    it = Variable(it, requires_grad=False)
            else:
                it = seq[:, i].clone()

            # break if all the sequences end
            if i >= 1 and seq[:, i].data.sum() == 0:
                break

            xt = self.embed(it)

            output, state = self.core.forward(xt, att_feats, state)
            if batch_size == 1:
                output = F.log_softmax(self.logit(output), dim=1)
            else:
                output = F.log_softmax(self.logit(output.squeeze(0)), dim=1)
            outputs.append(output)

        # return hidden states
        return state[0], outputs

    def free_running_get_hidden_states(self, fc_feats, att_feats, init_index,
                                       end_index):
        batch_size = fc_feats.size(0)
        logprobs_all = []

        init_h = self.fc2h(fc_feats)
        init_h = init_h.unsqueeze(0)
        init_c = init_h.clone()
        state = (init_h, init_c)

        for t in range(self.seq_length):
            if t == 0:  # input BOS
                it = fc_feats.data.new(batch_size).long().fill_(init_index)

            xt = self.embed(Variable(it, requires_grad=False))

            output, state = self.core.forward(xt, att_feats, state)

            if batch_size == 1:
                logprobs = F.log_softmax(self.logit(output), dim=1)
            else:
                logprobs = F.log_softmax(self.logit(output.squeeze(0)), dim=1)
            logprobs_all.append(logprobs)

            _, it = torch.max(logprobs.data, 1)
            it = it.view(-1).long()
            if it.cpu().numpy()[0] == end_index:
                break

        return state[0], logprobs_all
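
sample dispatches to sample_beam whenever beam_size > 1, in which case it returns four values instead of three. A hypothetical call; sat_model, the feature shapes, and the token index are illustrative assumptions:

import torch

fc_feats = torch.randn(2, opt.fc_feat_size).cuda()
att_feats = torch.randn(2, opt.conv_att_size, opt.conv_feat_size).cuda()
seq, seq_logprobs, top_seq, top_prob = sat_model.sample(
    fc_feats, att_feats, init_index=1, opt={'beam_size': 3})
# seq: (batch, seq_length) best beam per image; top_seq/top_prob: all kept beams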
Example #12
class AttEncodeDecode(nn.Module):
    def __init__(self, opt):
        super(AttEncodeDecode, self).__init__()

        self.token_cnt = opt.token_cnt
        self.word_cnt = opt.word_cnt
        self.lstm_size = opt.lstm_size
        self.drop_prob = opt.drop_prob
        self.input_encoding_size = opt.input_encoding_size
        self.encode_time_step = opt.code_truncate
        self.decode_time_step = opt.comment_truncate
        self.ss_prob = opt.ss_prob

        self.encoding_feat_size = opt.lstm_size
        self.encoding_att_size = opt.encoding_att_size
        self.att_hidden_size = opt.att_hidden_size

        self.encode_lstm = LSTMCore(self.input_encoding_size, self.lstm_size,
                                    self.drop_prob)
        self.decode_lstm = LSTMSoftAttentionCore(
            self.input_encoding_size, self.lstm_size, self.encoding_feat_size,
            self.encoding_att_size, self.att_hidden_size, self.drop_prob)

        self.embed = nn.Embedding(self.token_cnt + 1, self.input_encoding_size)
        self.logit = nn.Linear(self.lstm_size, self.word_cnt)
        self.init_weights()

    def init_weights(self):
        self.embed.weight.data.uniform_(-0.1, 0.1)
        self.logit.weight.data.uniform_(-0.1, 0.1)
        self.logit.bias.data.fill_(0)

    def init_hidden(self, batch_size):
        weight = next(self.parameters()).data
        init_h = Variable(weight.new(1, batch_size, self.lstm_size).zero_())
        init_c = Variable(weight.new(1, batch_size, self.lstm_size).zero_())
        init_state = (init_h, init_c)
        return init_state

    def copy_weights(self, model_path):
        src_weights = torch.load(model_path)
        own_dict = self.state_dict()
        for key, var in own_dict.items():
            print("copy weights: {}  size: {}".format(key, var.size()))
            own_dict[key].copy_(src_weights[key])

    def forward(self, code_matrix, comment_matrix, comment_mask):
        batch_size = code_matrix.size(0)
        encode_state = self.init_hidden(batch_size)
        decode_logit_seq = []
        outputs = []

        # encoder
        encode_hidden_states = []
        for i in range(self.encode_time_step):
            encode_words = code_matrix[:, i].clone()

            if code_matrix[:, i].data.sum() == 0:
                break

            encode_xt = self.embed(encode_words)
            encode_output, encode_state = self.encode_lstm.forward(
                encode_xt, encode_state)
            encode_hidden_states.append(encode_output)
        encode_hidden_states = torch.cat(
            [_.unsqueeze(1) for _ in encode_hidden_states], 1)

        # decoder
        decode_state = (encode_state[0].clone(), encode_state[1].clone())
        for i in range(self.decode_time_step):
            if i >= 1 and self.ss_prob > 0.0:
                sample_prob = comment_mask.data.new(batch_size).uniform_(0, 1)
                sample_mask = sample_prob < self.ss_prob
                if sample_mask.sum() == 0:
                    it = comment_matrix[:, i].clone()
                else:
                    sample_ind = sample_mask.nonzero().view(-1)
                    it = comment_matrix[:, i].data.clone()
                    prob_prev = torch.exp(
                        outputs[-1].data
                    )  # fetch prev distribution: shape Nx(M+1)
                    it.index_copy_(
                        0, sample_ind,
                        torch.multinomial(prob_prev, 1).view(-1).index_select(
                            0, sample_ind))
                    it = Variable(it, requires_grad=False)
            else:
                it = comment_matrix[:, i].clone()

            if i >= 1 and comment_matrix[:, i].data.sum() == 0:
                break

            decode_xt = self.embed(it)
            decode_output, decode_state = self.decode_lstm.forward(
                decode_xt, encode_hidden_states, decode_state)

            decode_logit_words = F.log_softmax(self.logit(decode_output))
            decode_logit_seq.append(decode_logit_words)
            outputs.append(decode_logit_words)

        decode_logit_seq = torch.cat(
            [_.unsqueeze(1) for _ in decode_logit_seq], 1).contiguous()

        return decode_logit_seq

    def teacher_forcing_get_hidden_states(self, code_matrix, comment_matrix,
                                          comment_mask, eos_index):
        batch_size = code_matrix.size(0)
        encode_state = self.init_hidden(batch_size)
        outputs = []

        # encoder
        encode_hidden_states = []
        for i in range(self.encode_time_step):
            encode_words = code_matrix[:, i].clone()
            if code_matrix[:, i].data.sum() == 0:
                break
            encode_xt = self.embed(encode_words)
            encode_output, encode_state = self.encode_lstm.forward(
                encode_xt, encode_state)
            encode_hidden_states.append(encode_output)
        encode_hidden_states = torch.cat(
            [_.unsqueeze(1) for _ in encode_hidden_states], 1)

        # decoder
        decode_state = (encode_state[0].clone(), encode_state[1].clone())
        for i in range(self.decode_time_step):
            if i >= 1 and self.ss_prob > 0.0:
                sample_prob = comment_mask.data.new(batch_size).uniform_(0, 1)
                sample_mask = sample_prob < self.ss_prob
                if sample_mask.sum() == 0:
                    it = comment_matrix[:, i].clone()
                else:
                    sample_ind = sample_mask.nonzero().view(-1)
                    it = comment_matrix[:, i].data.clone()
                    prob_prev = torch.exp(
                        outputs[-1].data
                    )  # fetch prev distribution: shape Nx(M+1)
                    it.index_copy_(
                        0, sample_ind,
                        torch.multinomial(prob_prev, 1).view(-1).index_select(
                            0, sample_ind))
                    it = Variable(it, requires_grad=False)
            else:
                it = comment_matrix[:, i].clone()
            if it.cpu().data[0] == eos_index:
                break
            decode_xt = self.embed(it)
            decode_output, decode_state = self.decode_lstm.forward(
                decode_xt, encode_hidden_states, decode_state)

        return decode_state[0]

    def free_running_get_hidden_states(self, code_matrix, init_index,
                                       eos_index):
        batch_size = code_matrix.size(0)
        encode_state = self.init_hidden(batch_size)
        seq = []
        seqLogprobs = []
        logprobs_all = []

        # encoder
        encode_hidden_states = []
        for i in range(self.encode_time_step):
            encode_words = code_matrix[:, i].clone()

            if code_matrix[:, i].data.sum() == 0:
                break
            encode_xt = self.embed(encode_words)
            encode_output, encode_state = self.encode_lstm.forward(
                encode_xt, encode_state)
            encode_hidden_states.append(encode_output)
        encode_hidden_states = torch.cat(
            [_.unsqueeze(1) for _ in encode_hidden_states], 1)

        # decoder
        decode_state = (encode_state[0].clone(), encode_state[1].clone())
        for i in range(self.decode_time_step):
            if i == 0:
                it = code_matrix.data.new(batch_size).long().fill_(init_index)
                decode_xt = self.embed(
                    Variable(it, requires_grad=False).cuda())
                decode_output, decode_state = self.decode_lstm.forward(
                    decode_xt, encode_hidden_states, decode_state)
            else:
                max_logprobs, it = torch.max(logprobs.data, 1)
                it = it.view(-1).long()
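                # again, the EOS test only looks at the first example
                # (batch_size == 1 assumed)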
                if it.cpu()[0] == eos_index:
                    break
                decode_xt = self.embed(
                    Variable(it, requires_grad=False).cuda())
                decode_output, decode_state = self.decode_lstm.forward(
                    decode_xt, encode_hidden_states, decode_state)
                seq.append(it)
                seqLogprobs.append(max_logprobs.view(-1))
            logprobs = F.log_softmax(self.logit(decode_output), dim=-1)
            logprobs_all.append(logprobs)

        return decode_state[0]

    def sample(self, code_matrix, init_index, eos_index):
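        # Greedy batch decoding: returns the sampled comment tokens, their
        # log-probs, and the full per-step log-prob distributions.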
        batch_size = code_matrix.size(0)
        encode_state = self.init_hidden(batch_size)

        seq = []
        seqLogprobs = []
        logprobs_all = []

        # encoder
        encode_hidden_states = []
        for i in range(self.encode_time_step):
            encode_words = code_matrix[:, i].clone()

            if code_matrix[:, i].data.sum() == 0:
                break
            encode_xt = self.embed(encode_words)
            encode_output, encode_state = self.encode_lstm.forward(
                encode_xt, encode_state)
            encode_hidden_states.append(encode_output)

        encode_hidden_states = torch.cat(
            [_.unsqueeze(1) for _ in encode_hidden_states], 1)

        # decoder
        decode_state = (encode_state[0].clone(), encode_state[1].clone())
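        # the decoder starts from the encoder's final (h, c) pair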
        for i in range(self.decode_time_step):
            if i == 0:
                it = code_matrix.data.new(batch_size).long().fill_(init_index)
                decode_xt = self.embed(
                    Variable(it, requires_grad=False).cuda())
                decode_output, decode_state = self.decode_lstm.forward(
                    decode_xt, encode_hidden_states, decode_state)
            else:
                max_logprobs, it = torch.max(logprobs.data, 1)
                it = it.view(-1).long()

                # stop once every sequence in the batch has emitted EOS
                if (it == eos_index).sum() == batch_size:
                    break

                decode_xt = self.embed(
                    Variable(it, requires_grad=False).cuda())
                decode_output, decode_state = self.decode_lstm.forward(
                    decode_xt, encode_hidden_states, decode_state)

                seq.append(it)
                seqLogprobs.append(max_logprobs.view(-1))

            logprobs = F.log_softmax(self.logit(decode_output), dim=-1)
            logprobs_all.append(logprobs)

        greedy_seq = torch.cat([_.unsqueeze(1) for _ in seq], 1).contiguous()
        greedy_seq_probs = torch.cat([_.unsqueeze(1) for _ in seqLogprobs],
                                     1).contiguous()
        greedy_logprobs_all = torch.cat([_.unsqueeze(1) for _ in logprobs_all],
                                        1).contiguous()

        return greedy_seq, greedy_seq_probs, greedy_logprobs_all
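
# A minimal usage sketch for the attention model above (hypothetical names:
# `model`, `code_matrix`, and the start/EOS indices are assumptions, not
# part of the original listing):
#
#     seq, seq_probs, all_logprobs = model.sample(code_matrix, init_index=1,
#                                                 eos_index=2)
#
# `seq` holds the greedily decoded comment tokens, one row per example.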
Ejemplo n.º 13
0
class EncoderDecoder(nn.Module):
    def __init__(self, opt):
        super(EncoderDecoder, self).__init__()

        self.vocab_size = opt.vocab_size
        self.input_encoding_size = opt.input_encoding_size
        self.lstm_size = opt.lstm_size
        self.drop_prob_lm = opt.drop_prob_lm
        self.seq_length = opt.seq_length
        self.fc_feat_size = opt.fc_feat_size

        self.img_embed = nn.Linear(self.fc_feat_size, self.input_encoding_size)
        self.LSTMCore = LSTMCore(self.input_encoding_size, self.lstm_size, self.drop_prob_lm)
        
        self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size)
        self.logit = nn.Linear(self.lstm_size, self.vocab_size)

        self.init_weights()

    def init_weights(self):
        initrange = 0.1
        self.img_embed.weight.data.uniform_(-initrange, initrange)
        self.img_embed.bias.data.fill_(0)
        self.embed.weight.data.uniform_(-initrange, initrange)
        self.logit.weight.data.uniform_(-initrange, initrange)
        self.logit.bias.data.fill_(0)

    def init_hidden(self, batch_size):
        weight = next(self.parameters()).data
        return (Variable(weight.new(1, batch_size, self.lstm_size).zero_()),
                Variable(weight.new(1, batch_size, self.lstm_size).zero_()))

    def forward(self, fc_feats, seq):
        batch_size = fc_feats.size(0)
        state = self.init_hidden(batch_size)
        outputs = []

        for i in range(seq.size(1)):
            if i == 0:
                xt = self.img_embed(fc_feats)
            else:
                it = seq[:, i-1].clone()
                if seq[:, i-1].data.sum() == 0:
                    break
                xt = self.embed(it)
            output, state = self.LSTMCore.forward(xt, state)

            if i > 0:
                output = F.log_softmax(self.logit(output.squeeze(0)), dim=-1)
                outputs.append(output)

        return torch.cat([_.unsqueeze(1) for _ in outputs], 1).contiguous()

    def sample_beam(self, fc_feats, init_index, opt={}):
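        # Beam search: keep the `beam_size` best partial captions per image,
        # expand each with its top-scoring next words, and move a beam to
        # `done_beams` when it emits the END token (index 0) or reaches
        # `seq_length`.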
        beam_size = opt.get('beam_size', 3)  # fall back to a beam size of 3 when opt does not provide one
        batch_size = fc_feats.size(0)

        seq = torch.LongTensor(self.seq_length, batch_size).zero_()
        seqLogprobs = torch.FloatTensor(self.seq_length, batch_size)

        top_seq = []
        top_prob = [[] for _ in range(batch_size)]
        done_beams = [[] for _ in range(batch_size)]

        for k in range(batch_size):
            state = self.init_hidden(beam_size)

            beam_seq = torch.LongTensor(self.seq_length, beam_size).zero_()
            beam_seq_logprobs = torch.FloatTensor(self.seq_length, beam_size).zero_()
            beam_logprobs_sum = torch.zeros(beam_size)  # running sum of logprobs for each beam

            for t in range(self.seq_length + 1):
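                # t == 0 feeds the image embedding, t == 1 the start token,
                # t >= 2 expands the current beams with their best words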
                if t == 0:
                    xt = self.img_embed(fc_feats[k:k+1]).expand(beam_size, self.input_encoding_size)

                elif t == 1:
                    it = fc_feats.data.new(beam_size).long().fill_(init_index)
                    xt = self.embed(Variable(it, requires_grad=False))

                else:
                    logprobsf = logprobs.float()
                    ys, ix = torch.sort(logprobsf, 1, True)
                    candidates = []
                    cols = min(beam_size, ys.size(1))
                    rows = beam_size

                    if t == 2:  # at first time step only the first beam is active
                        rows = 1

                    for c in range(cols):
                        for q in range(rows):
                            # compute logprob of expanding beam q with word in (sorted) position c
                            local_logprob = ys[q, c]
                            candidate_logprob = beam_logprobs_sum[q] + local_logprob
                            candidates.append({'c': ix.data[q, c],
                                               'q': q,
                                               'p': candidate_logprob.data[0],
                                               'r': local_logprob.data[0]})

                    candidates = sorted(candidates, key=lambda x: -x['p'])

                    # construct new beams
                    new_state = [_.clone() for _ in state]
                    if t > 2:
                        # we'll need these as reference when we fork beams around
                        beam_seq_prev = beam_seq[:t-2].clone()
                        beam_seq_logprobs_prev = beam_seq_logprobs[:t-2].clone()

                    for vix in range(beam_size):
                        v = candidates[vix]
                        # fork beam index q into index vix
                        if t > 2:
                            beam_seq[:t - 2, vix] = beam_seq_prev[:, v['q']]
                            beam_seq_logprobs[:t - 2, vix] = beam_seq_logprobs_prev[:, v['q']]

                        # rearrange recurrent states
                        for state_ix in range(len(new_state)):
                            # copy over state in previous beam q to new beam at vix
                            new_state[state_ix][0, vix] = state[state_ix][0, v['q']]  # dim 0 is the single layer axis

                        # append new end terminal at the end of this beam
                        beam_seq[t - 2, vix] = v['c']  # c'th word is the continuation
                        beam_seq_logprobs[t - 2, vix] = v['r']  # the raw logprob here
                        beam_logprobs_sum[vix] = v['p']  # the new (sum) logprob along this beam

                        if v['c'] == 0 or t == self.seq_length:
                            # END token special case here, or we reached the end.
                            # add the beam to a set of done beams
                            done_beams[k].append({'seq': beam_seq[:, vix].clone(),
                                                  'logps': beam_seq_logprobs[:, vix].clone(),
                                                  'p': beam_logprobs_sum[vix]})

                    # encode as vectors
                    it = beam_seq[t - 2]
                    xt = self.embed(Variable(it.cuda()))

                if t >= 2:
                    state = new_state

                output, state = self.LSTMCore.forward(xt, state)

                logprobs = F.log_softmax(self.logit(output), dim=-1)

            done_beams[k] = sorted(done_beams[k], key=lambda x: -x['p'])
            seq[:, k] = done_beams[k][0]['seq']  # the first beam has highest cumulative score
            seqLogprobs[:, k] = done_beams[k][0]['logps']

            # save result
            l = len(done_beams[k])
            top_seq_cur = torch.LongTensor(l, self.seq_length).zero_()

            for temp_index in range(l):
                top_seq_cur[temp_index] = done_beams[k][temp_index]['seq'].clone()
                top_prob[k].append(done_beams[k][temp_index]['p'])

            top_seq.append(top_seq_cur)

        # return the samples and their log likelihoods
        return seq.transpose(0, 1), seqLogprobs.transpose(0, 1), top_seq, top_prob

    def sample(self, fc_feats, init_index, opt={}):
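        # Greedy sampling by default; defers to beam search when the caller
        # passes opt={'beam_size': k} with k > 1.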
        beam_size = opt.get('beam_size', 1)

        if beam_size > 1:
            return self.sample_beam(fc_feats, init_index, opt)

        batch_size = fc_feats.size(0)
        seq = []
        seqLogprobs = []
        logprobs_all = []
        state = self.init_hidden(batch_size)

        for t in range(self.seq_length):
            if t == 0:
                xt = self.img_embed(fc_feats)
            else:
                if t == 1:
                    it = fc_feats.data.new(batch_size).long().fill_(init_index)
                else:
                    sampleLogprobs, it = torch.max(logprobs.data, 1)
                    it = it.view(-1).long()
                xt = self.embed(Variable(it, requires_grad=False).cuda())
            if t >= 2:
                if t == 2:
                    unfinished = it > 0
                else:
                    unfinished *= (it > 0)
                if unfinished.sum() == 0:
                    break
                it = it * unfinished.type_as(it)
                seq.append(it)
                seqLogprobs.append(sampleLogprobs.view(-1))

            output, state = self.LSTMCore.forward(xt, state)

            logprobs = F.log_softmax(self.logit(output), dim=-1)
            logprobs_all.append(logprobs)

        return torch.cat([_.unsqueeze(1) for _ in seq], 1), \
               torch.cat([_.unsqueeze(1) for _ in seqLogprobs], 1), \
               torch.cat([_.unsqueeze(1) for _ in logprobs_all], 1).contiguous()

    def teacher_forcing_get_hidden_states(self, fc_feats, seq):
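        # Teacher forcing for the plain (no-attention) model: condition on
        # the image feature at t == 0, then feed the ground-truth tokens;
        # returns the final hidden state and the per-step log-probs.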
        batch_size = fc_feats.size(0)
        state = self.init_hidden(batch_size)
        outputs = []

        for i in range(seq.size(1)):
            if i == 0:
                xt = self.img_embed(fc_feats)
            else:
                it = seq[:, i-1].clone()
                if seq[:, i-1].data.sum() == 0:
                    break
                xt = self.embed(it)
            output, state = self.LSTMCore.forward(xt, state)
            if i > 0:
                if batch_size == 1:
                    output = F.log_softmax(self.logit(output), dim=-1)
                else:
                    output = F.log_softmax(self.logit(output.squeeze(0)), dim=-1)
                outputs.append(output)

        return state[0], outputs

    def free_running_get_hidden_states(self, fc_feats, init_index, end_index):
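        # Free-running counterpart: greedy argmax feedback after the start
        # token; note the EOS test below inspects only the first example,
        # so it assumes batch_size == 1.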
        batch_size = fc_feats.size(0)
        seq = []
        seqLogprobs = []
        logprobs_all = []
        state = self.init_hidden(batch_size)

        for t in range(self.seq_length):
            if t == 0:
                xt = self.img_embed(fc_feats)
            if t == 1:
                it = fc_feats.data.new(batch_size).long().fill_(init_index)
                xt = self.embed(Variable(it, requires_grad=False).cuda())
            if t >= 2:
                sampleLogprobs, it = torch.max(logprobs.data, 1)
                it = it.view(-1).long()
                if it.cpu().numpy()[0] == end_index:
                    break
                xt = self.embed(Variable(it, requires_grad=False).cuda())
                seq.append(it)
                seqLogprobs.append(sampleLogprobs.view(-1))

            output, state = self.LSTMCore.forward(xt, state)
            logprobs = F.log_softmax(self.logit(output), dim=-1)
            logprobs_all.append(logprobs)

        return state[0], logprobs_all
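
# A minimal usage sketch for EncoderDecoder. Everything below is an
# illustrative assumption (option values, tensor shapes, start index), not
# part of the original listing; it also assumes a CUDA device, matching the
# .cuda() calls above, plus the usual header: import torch, torch.nn as nn,
# torch.nn.functional as F, and from torch.autograd import Variable.
if __name__ == '__main__':
    import argparse

    opt = argparse.Namespace(vocab_size=1000, input_encoding_size=256,
                             lstm_size=512, drop_prob_lm=0.5,
                             seq_length=20, fc_feat_size=2048)
    model = EncoderDecoder(opt).cuda()

    # a batch of 4 image features, as would come from a CNN's fc layer
    fc_feats = Variable(torch.randn(4, opt.fc_feat_size).cuda())

    # greedy decoding (beam_size defaults to 1, returns a 3-tuple)
    seq, seq_logprobs, logprobs_all = model.sample(fc_feats, init_index=1)

    # beam search with 3 beams per image (returns the 4-tuple of sample_beam)
    beam_seq, beam_logprobs, top_seq, top_prob = model.sample(
        fc_feats, init_index=1, opt={'beam_size': 3})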