class AttEncodeDecodeARNet(nn.Module):
    """Attention encoder/decoder with an ARNet reconstruction regulariser.

    ``encode_lstm`` reads the embedded ``code_matrix`` column by column;
    ``decode_lstm`` attends over the stacked encoder outputs while predicting
    ``comment_matrix``; the ARNet branch (``rcst_lstm`` + ``h_2_pre_h``) maps
    each decoder output to a reconstruction of the previous decoder hidden
    state, and the masked squared error is accumulated as ``rcst_loss``.

    NOTE(review): the cores are assumed to return ``(output, (h, c))`` with
    ``output`` shaped (batch, lstm_size) and ``h`` shaped like
    ``init_hidden`` i.e. (1, batch, lstm_size) -- TODO confirm against
    ``LSTMCore`` / ``LSTMSoftAttentionCore``.
    """

    def __init__(self, opt):
        super(AttEncodeDecodeARNet, self).__init__()
        self.token_cnt = opt.token_cnt
        self.word_cnt = opt.word_cnt
        self.lstm_size = opt.lstm_size
        self.drop_prob = opt.drop_prob
        self.input_encoding_size = opt.input_encoding_size
        self.encode_time_step = opt.code_truncate
        self.decode_time_step = opt.comment_truncate
        self.ss_prob = opt.ss_prob
        self.encoding_feat_size = opt.lstm_size
        self.encoding_att_size = opt.encoding_att_size
        self.att_hidden_size = opt.att_hidden_size

        self.encode_lstm = LSTMCore(self.input_encoding_size, self.lstm_size,
                                    self.drop_prob)
        # This class never defines ``drop_prob_lm``; referencing it raised
        # AttributeError at construction, so use the configured drop_prob.
        self.decode_lstm = LSTMSoftAttentionCore(
            self.input_encoding_size, self.lstm_size, self.encoding_feat_size,
            self.encoding_att_size, self.att_hidden_size, self.drop_prob)
        self.embed = nn.Embedding(self.token_cnt + 1, self.input_encoding_size)
        self.logit = nn.Linear(self.lstm_size, self.word_cnt)
        self.init_weights()

        # ARNet
        self.rcst_weight = opt.rcst_weight
        self.rcst_lstm = LSTMCore(self.lstm_size, self.lstm_size,
                                  self.drop_prob)
        self.h_2_pre_h = nn.Linear(self.lstm_size, self.lstm_size)
        self.rcst_init_weights()

    def init_weights(self):
        """Uniformly initialise the embedding and output projection."""
        self.embed.weight.data.uniform_(-0.1, 0.1)
        self.logit.weight.data.uniform_(-0.1, 0.1)
        self.logit.bias.data.fill_(0)

    def init_hidden(self, batch_size):
        """Return a zero ``(h, c)`` pair shaped (1, batch_size, lstm_size).

        The tensors are created with ``weight.new`` so they share the type and
        device of the model parameters.
        """
        weight = next(self.parameters()).data
        init_h = Variable(weight.new(1, batch_size, self.lstm_size).zero_())
        init_c = Variable(weight.new(1, batch_size, self.lstm_size).zero_())
        init_state = (init_h, init_c)
        return init_state

    # init params of ARNet
    def rcst_init_weights(self):
        """Uniformly initialise the ARNet hidden-state projection."""
        self.h_2_pre_h.weight.data.uniform_(-0.1, 0.1)
        self.h_2_pre_h.bias.data.fill_(0)

    def forward(self, code_matrix, comment_matrix, comment_mask):
        """Teacher-forced decoding plus the ARNet reconstruction loss.

        Args:
            code_matrix: code-token ids, read one column ``[:, i]`` per encoder
                step; encoding stops at the first all-zero column.
            comment_matrix: comment-word ids, one column per decoder step;
                decoding stops at the first all-zero column after step 0.
            comment_mask: per-step weights (``[:, i]``) used to mask and
                normalise the reconstruction loss.

        Returns:
            ``(decode_logit_seq, rcst_loss)``: per-step log-probabilities
            stacked along dim 1, and the weighted reconstruction loss summed
            over decoder steps.
        """
        batch_size = code_matrix.size(0)
        encode_state = self.init_hidden(batch_size)
        decode_logit_seq = []
        outputs = []

        # encoder
        encode_hidden_states = []
        for i in range(self.encode_time_step):
            encode_words = code_matrix[:, i].clone()
            if code_matrix[:, i].data.sum() == 0:
                break
            encode_xt = self.embed(encode_words)
            encode_output, encode_state = self.encode_lstm.forward(
                encode_xt, encode_state)
            encode_hidden_states.append(encode_output)
        # (batch, encoded_steps, lstm_size) attention memory
        encode_hidden_states = torch.cat(
            [_.unsqueeze(1) for _ in encode_hidden_states], 1)

        # decoder and ARNet both start from the final encoder state
        decode_state = (encode_state[0].clone(), encode_state[1].clone())
        rcst_state = (encode_state[0].clone(), encode_state[1].clone())
        pre_h = encode_state[0].clone()
        rcst_loss = 0.0
        for i in range(self.decode_time_step):
            if i >= 1 and self.ss_prob > 0.0:
                # scheduled sampling: replace some ground-truth inputs with
                # tokens drawn from the previous step's distribution
                sample_prob = comment_mask.data.new(batch_size).uniform_(0, 1)
                sample_mask = sample_prob < self.ss_prob
                if sample_mask.sum() == 0:
                    it = comment_matrix[:, i].clone()
                else:
                    sample_ind = sample_mask.nonzero().view(-1)
                    it = comment_matrix[:, i].data.clone()
                    # fetch prev distribution: shape Nx(M+1)
                    prob_prev = torch.exp(outputs[-1].data)
                    it.index_copy_(
                        0, sample_ind,
                        torch.multinomial(prob_prev, 1).view(-1).index_select(
                            0, sample_ind))
                    it = Variable(it, requires_grad=False)
            else:
                it = comment_matrix[:, i].clone()
            if i >= 1 and comment_matrix[:, i].data.sum() == 0:
                break

            decode_xt = self.embed(it)
            decode_output, decode_state = self.decode_lstm.forward(
                decode_xt, encode_hidden_states, decode_state)
            decode_logit_words = F.log_softmax(self.logit(decode_output),
                                               dim=1)
            decode_logit_seq.append(decode_logit_words)
            outputs.append(decode_logit_words)

            # ARNet part: reconstruct the previous decoder hidden state from
            # the current decoder output.
            rcst_output, rcst_state = self.rcst_lstm.forward(
                decode_output, rcst_state)
            rcst_h = self.h_2_pre_h(rcst_output)
            # Flatten both sides to (batch, lstm_size) and broadcast the
            # per-example mask across the hidden dimension row-wise. A
            # (1, batch * lstm_size) repeated mask neither matches this
            # layout nor broadcasts for batch_size > 1.
            rcst_diff = (rcst_h.contiguous().view(batch_size, -1) -
                         pre_h.contiguous().view(batch_size, -1))
            rcst_mask = comment_mask[:, i].contiguous().view(
                batch_size, 1).expand_as(rcst_diff)
            cur_rcst_loss = torch.sum(
                torch.sum(torch.mul(rcst_diff, rcst_diff) * rcst_mask, dim=1))
            rcst_loss += cur_rcst_loss * self.rcst_weight / torch.sum(
                comment_mask[:, i])
            # update previous hidden state
            pre_h = decode_state[0].clone()

        # aggregate
        decode_logit_seq = torch.cat(
            [_.unsqueeze(1) for _ in decode_logit_seq], 1).contiguous()
        return decode_logit_seq, rcst_loss

    def sample(self, code_matrix, init_index, eos_index):
        """Greedy decoding.

        Args:
            code_matrix: code-token ids (see ``forward``).
            init_index: token id fed to the decoder at step 0.
            eos_index: decoding stops once every sequence in the batch emits
                this id at the same step.

        Returns:
            ``(greedy_seq, greedy_seq_probs, greedy_logprobs_all)``: chosen ids
            and their log-probabilities (one column per emitted step), and the
            full log-probability tensors of every decoder step, stacked on
            dim 1.
        """
        batch_size = code_matrix.size(0)
        encode_state = self.init_hidden(batch_size)
        seq = []
        seqLogprobs = []
        logprobs_all = []

        # encoder
        encode_hidden_states = []
        for i in range(self.encode_time_step):
            encode_words = code_matrix[:, i].clone()
            if code_matrix[:, i].data.sum() == 0:
                break
            encode_xt = self.embed(encode_words)
            encode_output, encode_state = self.encode_lstm.forward(
                encode_xt, encode_state)
            encode_hidden_states.append(encode_output)
        encode_hidden_states = torch.cat(
            [_.unsqueeze(1) for _ in encode_hidden_states], 1)

        # decoder
        decode_state = (encode_state[0].clone(), encode_state[1].clone())
        for i in range(self.decode_time_step):
            if i == 0:
                # ``new`` keeps the ids on code_matrix's device, so no
                # hard-coded .cuda() is needed (CPU inference now works).
                it = code_matrix.data.new(batch_size).long().fill_(init_index)
                decode_xt = self.embed(Variable(it, requires_grad=False))
                decode_output, decode_state = self.decode_lstm.forward(
                    decode_xt, encode_hidden_states, decode_state)
            else:
                max_logprobs, it = torch.max(logprobs.data, 1)
                it = it.view(-1).long()
                # Per-element test; ``it.sum() == eos_index`` only worked
                # for a batch of one.
                if (it == eos_index).all():
                    break
                decode_xt = self.embed(Variable(it, requires_grad=False))
                decode_output, decode_state = self.decode_lstm.forward(
                    decode_xt, encode_hidden_states, decode_state)
                seq.append(it)
                seqLogprobs.append(max_logprobs.view(-1))
            logprobs = F.log_softmax(self.logit(decode_output), dim=1)
            logprobs_all.append(logprobs)

        greedy_seq = torch.cat([_.unsqueeze(1) for _ in seq], 1).contiguous()
        greedy_seq_probs = torch.cat(
            [_.unsqueeze(1) for _ in seqLogprobs], 1).contiguous()
        greedy_logprobs_all = torch.cat(
            [_.unsqueeze(1) for _ in logprobs_all], 1).contiguous()
        return greedy_seq, greedy_seq_probs, greedy_logprobs_all
class ReviewNet(nn.Module):
    """Encoder / reviewer / decoder model.

    ``encode_lstm`` reads the embedded ``code_matrix``; ``num_review_steps``
    attention-only reviewer cores (``review_steps``) attend over the encoder
    outputs and each produces one "thought" vector; ``decode_lstm`` then
    attends over the stacked thought vectors while predicting the comment.
    """

    def __init__(self, opt):
        super(ReviewNet, self).__init__()
        self.token_cnt = opt.token_cnt
        self.word_cnt = opt.word_cnt
        self.lstm_size = opt.lstm_size
        self.drop_prob = opt.drop_prob
        self.input_encoding_size = opt.input_encoding_size
        self.encode_time_step = opt.code_truncate
        self.decode_time_step = opt.comment_truncate
        self.ss_prob = opt.ss_prob
        self.encoding_feat_size = opt.lstm_size
        self.encoding_att_size = opt.encoding_att_size
        self.att_hidden_size = opt.att_hidden_size
        self.num_review_steps = opt.num_review_steps
        self.drop_prob_reason = opt.drop_prob_reason
        # encoder
        self.encode_lstm = LSTMCore(self.input_encoding_size, self.lstm_size,
                                    self.drop_prob)
        # reviewer: one core per review step
        self.review_steps = nn.ModuleList([
            LSTMSoftAttentionNoInputCore(
                self.input_encoding_size, self.lstm_size,
                self.encoding_feat_size, self.encoding_att_size,
                self.att_hidden_size, self.drop_prob_reason)
            for _ in range(self.num_review_steps)])
        # decoder: attends over the num_review_steps thought vectors
        self.decode_lstm = LSTMSoftAttentionCore(
            self.input_encoding_size, self.lstm_size, self.encoding_feat_size,
            self.num_review_steps, self.att_hidden_size, self.drop_prob)
        self.embed = nn.Embedding(self.token_cnt + 1, self.input_encoding_size)
        self.logit = nn.Linear(self.lstm_size, self.word_cnt)
        self.init_weights()

    def init_weights(self):
        """Uniformly initialise the embedding and output projection."""
        self.embed.weight.data.uniform_(-0.1, 0.1)
        self.logit.weight.data.uniform_(-0.1, 0.1)
        self.logit.bias.data.fill_(0)

    def init_hidden(self, batch_size):
        """Return a zero ``(h, c)`` pair shaped (1, batch_size, lstm_size).

        The tensors are created with ``weight.new`` so they share the type and
        device of the model parameters.
        """
        weight = next(self.parameters()).data
        init_h = Variable(weight.new(1, batch_size, self.lstm_size).zero_())
        init_c = Variable(weight.new(1, batch_size, self.lstm_size).zero_())
        init_state = (init_h, init_c)
        return init_state

    def forward(self, code_matrix, comment_matrix, current_comment_mask_cuda):
        """Teacher-forced decoding (with optional scheduled sampling).

        Args:
            code_matrix: code-token ids, one column ``[:, i]`` per encoder step.
            comment_matrix: comment-word ids, one column per decoder step;
                decoding stops at the first all-zero column after step 0.
            current_comment_mask_cuda: only used here as a tensor factory for
                the scheduled-sampling random draws.

        Returns:
            Per-step log-probabilities stacked along dim 1.
        """
        batch_size = code_matrix.size(0)
        encode_state = self.init_hidden(batch_size)
        decode_logit_seq = []
        outputs = []

        # encoder
        # NOTE(review): unlike ``sample`` this loop does not stop at an
        # all-padding column, so training and inference may attend over a
        # different number of positions. Left unchanged because the attention
        # core may rely on a fixed ``encoding_att_size`` -- TODO confirm.
        encode_hidden_states = []
        for i in range(self.encode_time_step):
            encode_words = code_matrix[:, i].clone()
            encode_xt = self.embed(encode_words)
            encode_output, encode_state = self.encode_lstm.forward(
                encode_xt, encode_state)
            encode_hidden_states.append(encode_output)
        encode_hidden_states = torch.cat(
            [_.unsqueeze(1) for _ in encode_hidden_states], 1)

        # reviewer
        review_state = (encode_state[0].clone(), encode_state[1].clone())
        thought = []
        for i in range(self.num_review_steps):
            review_output, review_state = self.review_steps[i].forward(
                encode_hidden_states, review_state)
            thought.append(review_output)
        # (batch, num_review_steps, ...) thought vectors; they already live on
        # the model's device, so no hard-coded .cuda() is needed.
        thought_vectors = torch.stack(thought).transpose(0, 1).contiguous()

        # decoder
        decode_state = (encode_state[0].clone(), encode_state[1].clone())
        for i in range(self.decode_time_step):
            if i >= 1 and self.ss_prob > 0.0:
                sample_prob = current_comment_mask_cuda.data.new(
                    batch_size).uniform_(0, 1)
                sample_mask = sample_prob < self.ss_prob
                if sample_mask.sum() == 0:
                    it = comment_matrix[:, i].clone()
                else:
                    sample_ind = sample_mask.nonzero().view(-1)
                    it = comment_matrix[:, i].data.clone()
                    # fetch prev distribution: shape Nx(M+1)
                    prob_prev = torch.exp(outputs[-1].data)
                    it.index_copy_(
                        0, sample_ind,
                        torch.multinomial(prob_prev, 1).view(-1).index_select(
                            0, sample_ind))
                    it = Variable(it, requires_grad=False)
            else:
                it = comment_matrix[:, i].clone()
            if i >= 1 and comment_matrix[:, i].data.sum() == 0:
                break
            decode_xt = self.embed(it)
            decode_output, decode_state = self.decode_lstm.forward(
                decode_xt, thought_vectors, decode_state)
            decode_logit_words = F.log_softmax(self.logit(decode_output),
                                               dim=1)
            decode_logit_seq.append(decode_logit_words)
            outputs.append(decode_logit_words)

        # aggregate
        decode_logit_seq = torch.cat(
            [_.unsqueeze(1) for _ in decode_logit_seq], 1).contiguous()
        return decode_logit_seq

    def sample(self, code_matrix, init_index, eos_index):
        """Greedy decoding through encoder, reviewer and decoder.

        Returns ``(greedy_seq, greedy_seq_probs, greedy_logprobs_all)``.
        Decoding stops once every sequence emits ``eos_index`` at the same
        step, or after ``decode_time_step`` steps.
        """
        batch_size = code_matrix.size(0)
        encode_state = self.init_hidden(batch_size)
        seq = []
        seqLogprobs = []
        logprobs_all = []

        # encoder: stop at the first all-padding column
        encode_hidden_states = []
        for i in range(self.encode_time_step):
            encode_words = code_matrix[:, i].clone()
            if code_matrix[:, i].data.sum() == 0:
                break
            encode_xt = self.embed(encode_words)
            encode_output, encode_state = self.encode_lstm.forward(
                encode_xt, encode_state)
            encode_hidden_states.append(encode_output)
        encode_hidden_states = torch.cat(
            [_.unsqueeze(1) for _ in encode_hidden_states], 1)

        # reviewer
        review_state = (encode_state[0].clone(), encode_state[1].clone())
        thought = []
        for i in range(self.num_review_steps):
            review_output, review_state = self.review_steps[i].forward(
                encode_hidden_states, review_state)
            thought.append(review_output)
        thought_vectors = torch.stack(thought).transpose(0, 1).contiguous()

        # decoder
        decode_state = (encode_state[0].clone(), encode_state[1].clone())
        for i in range(self.decode_time_step):
            if i == 0:
                it = code_matrix.data.new(batch_size).long().fill_(init_index)
                decode_xt = self.embed(Variable(it, requires_grad=False))
                decode_output, decode_state = self.decode_lstm.forward(
                    decode_xt, thought_vectors, decode_state)
            else:
                max_logprobs, it = torch.max(logprobs.data, 1)
                it = it.view(-1).long()
                # per-element EOS test (a summed test only works for batch 1)
                if (it == eos_index).all():
                    break
                decode_xt = self.embed(Variable(it, requires_grad=False))
                decode_output, decode_state = self.decode_lstm.forward(
                    decode_xt, thought_vectors, decode_state)
                seq.append(it)
                seqLogprobs.append(max_logprobs.view(-1))
            logprobs = F.log_softmax(self.logit(decode_output), dim=1)
            logprobs_all.append(logprobs)

        # aggregate
        greedy_seq = torch.cat([_.unsqueeze(1) for _ in seq], 1).contiguous()
        greedy_seq_probs = torch.cat(
            [_.unsqueeze(1) for _ in seqLogprobs], 1).contiguous()
        greedy_logprobs_all = torch.cat(
            [_.unsqueeze(1) for _ in logprobs_all], 1).contiguous()
        return greedy_seq, greedy_seq_probs, greedy_logprobs_all
class EncodeDecodeARNet(nn.Module):
    """Plain (no attention) encoder/decoder with an ARNet regulariser.

    ``encode_lstm`` reads the embedded ``code_matrix``; ``decode_lstm`` starts
    from the final encoder state and predicts ``comment_matrix``; the ARNet
    branch (``rcst_lstm`` + ``h_2_pre_h``) maps each decoder output to a
    reconstruction of the previous decoder hidden state and the masked squared
    error is returned as ``rcst_loss``.

    NOTE(review): cores are assumed to return ``(output, (h, c))`` with
    ``output`` (batch, lstm_size) and ``h`` shaped like ``init_hidden`` --
    TODO confirm against ``LSTMCore``.
    """

    def __init__(self, opt):
        super(EncodeDecodeARNet, self).__init__()
        self.token_cnt = opt.token_cnt
        self.word_cnt = opt.word_cnt
        self.lstm_size = opt.lstm_size
        self.drop_prob = opt.drop_prob
        self.input_encoding_size = opt.input_encoding_size
        self.encode_time_step = opt.code_truncate
        self.decode_time_step = opt.comment_truncate
        self.encode_lstm = LSTMCore(self.input_encoding_size, self.lstm_size,
                                    self.drop_prob)
        self.decode_lstm = LSTMCore(self.input_encoding_size, self.lstm_size,
                                    self.drop_prob)
        self.embed = nn.Embedding(self.token_cnt + 1, self.input_encoding_size)
        self.logit = nn.Linear(self.lstm_size, self.word_cnt)
        self.init_weights()
        # params of ARNet
        self.rcst_weight = opt.reconstruct_weight
        self.rcst_lstm = LSTMCore(self.lstm_size, self.lstm_size,
                                  self.drop_prob)
        self.h_2_pre_h = nn.Linear(self.lstm_size, self.lstm_size)
        self.rcst_init_weights()

    def init_weights(self):
        """Uniformly initialise the embedding and output projection."""
        self.embed.weight.data.uniform_(-0.1, 0.1)
        self.logit.weight.data.uniform_(-0.1, 0.1)
        self.logit.bias.data.fill_(0)

    def init_hidden(self, batch_size):
        """Return a zero ``(h, c)`` pair shaped (1, batch_size, lstm_size)."""
        weight = next(self.parameters()).data
        init_h = Variable(weight.new(1, batch_size, self.lstm_size).zero_())
        init_c = Variable(weight.new(1, batch_size, self.lstm_size).zero_())
        init_state = (init_h, init_c)
        return init_state

    # init
    def rcst_init_weights(self):
        """Uniformly initialise the ARNet hidden-state projection."""
        self.h_2_pre_h.weight.data.uniform_(-0.1, 0.1)
        self.h_2_pre_h.bias.data.fill_(0)

    # copy weights from pre-trained model with cross entropy
    def copy_weights(self, model_path):
        """Copy every tensor of the checkpoint at ``model_path`` into this model.

        Each key of the loaded state dict must exist in ``self.state_dict()``;
        keys present only in this model (e.g. the ARNet layers) are left as is.
        ``torch.load`` unpickles the file, so only load trusted checkpoints.
        """
        src_weights = torch.load(model_path)
        own_dict = self.state_dict()
        for key, var in src_weights.items():
            print("copy weights: {} size: {}".format(key, var.size()))
            own_dict[key].copy_(var)

    def forward(self, code_matrix, comment_matrix, comment_mask):
        """Teacher-forced decoding plus the ARNet reconstruction loss.

        Returns ``(decode_logit_seq, rcst_loss)``: per-step log-probabilities
        stacked along dim 1, and the weighted reconstruction loss summed over
        decoder steps (each step normalised by ``sum(comment_mask[:, i])``).
        """
        batch_size = code_matrix.size(0)
        encode_state = self.init_hidden(batch_size)
        decode_logit_seq = []

        # encoder: stop at the first all-padding column
        for i in range(self.encode_time_step):
            encode_words = code_matrix[:, i].clone()
            if code_matrix[:, i].data.sum() == 0:
                break
            encode_xt = self.embed(encode_words)
            encode_output, encode_state = self.encode_lstm.forward(
                encode_xt, encode_state)

        # decoder and ARNet both start from the final encoder state
        decode_state = (encode_state[0].clone(), encode_state[1].clone())
        rcst_state = (encode_state[0].clone(), encode_state[1].clone())
        pre_h = encode_state[0].clone()
        rcst_loss = 0.0
        for i in range(self.decode_time_step):
            decode_words = comment_matrix[:, i].clone()
            if comment_matrix[:, i].data.sum() == 0:
                break
            decode_xt = self.embed(decode_words)
            decode_output, decode_state = self.decode_lstm.forward(
                decode_xt, decode_state)
            decode_logit_words = F.log_softmax(self.logit(decode_output),
                                               dim=1)
            decode_logit_seq.append(decode_logit_words)

            # ARNet. The original unpacked ``rcst_state, rcst_state = ...``,
            # discarding the output and then feeding the (h, c) tuple to
            # ``h_2_pre_h``, which raised a TypeError.
            rcst_output, rcst_state = self.rcst_lstm.forward(
                decode_output, rcst_state)
            rcst_h = self.h_2_pre_h(rcst_output)
            # Flatten to (batch, lstm_size) so the per-example mask broadcasts
            # row-wise; a (1, batch * lstm_size) repeated mask does not.
            rcst_diff = (rcst_h.contiguous().view(batch_size, -1) -
                         pre_h.contiguous().view(batch_size, -1))
            rcst_mask = comment_mask[:, i].contiguous().view(
                batch_size, 1).expand_as(rcst_diff)
            cur_rcst_loss = torch.sum(
                torch.sum(torch.mul(rcst_diff, rcst_diff) * rcst_mask, dim=1))
            rcst_loss += cur_rcst_loss * self.rcst_weight / torch.sum(
                comment_mask[:, i])
            # update previous hidden state
            pre_h = decode_state[0].clone()

        # aggregate
        decode_logit_seq = torch.cat(
            [_.unsqueeze(1) for _ in decode_logit_seq], 1).contiguous()
        return decode_logit_seq, rcst_loss

    def sample(self, code_matrix, init_index, eos_index):
        """Greedy decoding.

        Returns ``(greedy_seq, greedy_seq_probs, greedy_logprobs_all)``.
        Decoding stops once every sequence emits ``eos_index`` at the same
        step, or after ``decode_time_step`` steps.
        """
        batch_size = code_matrix.size(0)
        encode_state = self.init_hidden(batch_size)
        seq = []
        seqLogprobs = []
        logprobs_all = []

        # encoder
        for i in range(self.encode_time_step):
            encode_words = code_matrix[:, i].clone()
            if code_matrix[:, i].data.sum() == 0:
                break
            encode_xt = self.embed(encode_words)
            encode_output, encode_state = self.encode_lstm.forward(
                encode_xt, encode_state)

        # decoder
        decode_state = (encode_state[0].clone(), encode_state[1].clone())
        for i in range(self.decode_time_step):
            if i == 0:
                # ``new`` keeps the ids on code_matrix's device (no .cuda()).
                it = code_matrix.data.new(batch_size).long().fill_(init_index)
                decode_xt = self.embed(Variable(it, requires_grad=False))
                decode_output, decode_state = self.decode_lstm.forward(
                    decode_xt, decode_state)
            else:
                max_logprobs, it = torch.max(logprobs.data, 1)
                it = it.view(-1).long()
                # per-element EOS test (a summed test only works for batch 1)
                if (it == eos_index).all():
                    break
                decode_xt = self.embed(Variable(it, requires_grad=False))
                decode_output, decode_state = self.decode_lstm.forward(
                    decode_xt, decode_state)
                seq.append(it)
                seqLogprobs.append(max_logprobs.view(-1))
            logprobs = F.log_softmax(self.logit(decode_output), dim=1)
            logprobs_all.append(logprobs)

        # aggregate
        greedy_seq = torch.cat([_.unsqueeze(1) for _ in seq], 1).contiguous()
        greedy_seq_probs = torch.cat(
            [_.unsqueeze(1) for _ in seqLogprobs], 1).contiguous()
        greedy_logprobs_all = torch.cat(
            [_.unsqueeze(1) for _ in logprobs_all], 1).contiguous()
        return greedy_seq, greedy_seq_probs, greedy_logprobs_all
class ShowAttendTellModel(nn.Module):
    """Show-Attend-Tell captioning model with an ARNet reconstruction branch.

    ``fc2h`` maps ``fc_feats`` to the initial LSTM hidden state (the cell
    state starts as a copy); ``core`` is a soft-attention LSTM over
    ``att_feats``; ``logit`` projects hidden states to ``vocab_size`` scores.
    The ARNet branch (``rcstLSTM`` + ``h_2_pre_h``) is only used by
    ``rcst_forward``.
    """
    def __init__(self, opt):
        super(ShowAttendTellModel, self).__init__()
        self.vocab_size = opt.vocab_size
        self.input_encoding_size = opt.input_encoding_size
        self.lstm_size = opt.lstm_size
        # self.drop_prob_lm = opt.drop_prob_lm
        # NOTE(review): dropout is hard-coded to 0.1 instead of read from opt.
        self.drop_prob_lm = 0.1
        self.seq_length = opt.seq_length
        self.fc_feat_size = opt.fc_feat_size
        self.conv_feat_size = opt.conv_feat_size
        self.conv_att_size = opt.conv_att_size
        self.att_hidden_size = opt.att_hidden_size
        self.ss_prob = opt.ss_prob  # Schedule sampling probability
        self.fc2h = nn.Linear(self.fc_feat_size, self.lstm_size)
        self.core = LSTMSoftAttentionCore(self.input_encoding_size,
                                          self.lstm_size,
                                          self.conv_feat_size,
                                          self.conv_att_size,
                                          self.att_hidden_size,
                                          self.drop_prob_lm)
        self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size)
        self.logit = nn.Linear(self.lstm_size, self.vocab_size)
        # add the following parameters for ARNet
        self.rcst_time = opt.rcst_time
        self.rcst_scale = opt.rcst_weight  # lambda in ARNet
        self.rcstLSTM = LSTMCore(
            self.lstm_size, self.lstm_size,
            self.drop_prob_lm)  # ARNet is realized by LSTM network
        self.h_2_pre_h = nn.Linear(
            self.lstm_size, self.lstm_size)  # fully connected layer in ARNet
        self.init_weights()

    def init_weights(self):
        """Uniformly initialise embedding, fc2h, logit and ARNet projection."""
        initrange = 0.1
        self.embed.weight.data.uniform_(-initrange, initrange)
        self.fc2h.weight.data.uniform_(-initrange, initrange)
        self.logit.weight.data.uniform_(-initrange, initrange)
        self.logit.bias.data.fill_(0)
        # initialize weights of parameters in ARNet
        self.h_2_pre_h.weight.data.uniform_(-initrange, initrange)
        self.h_2_pre_h.bias.data.fill_(0)

    def copy_weights(self, model_path):
        """
        Initialize parameters from a model pre-trained with Cross Entropy (MLE).

        Every key of the loaded state dict must exist in this model; keys that
        exist only here (e.g. the ARNet layers) keep their current values.
        ``torch.load`` unpickles the file -- only load trusted checkpoints.
        """
        src_weights = torch.load(model_path)
        own_dict = self.state_dict()
        for key, var in src_weights.items():
            print("copy weights: {} size: {}".format(key, var.size()))
            own_dict[key].copy_(var)

    def forward(self, fc_feats, att_feats, seq):
        """Teacher-forced decoding with optional scheduled sampling.

        Iterates over the columns of ``seq`` and stops early at the first
        all-zero column after step 0. Returns the per-step log-probabilities
        concatenated along dim 1.
        """
        batch_size = fc_feats.size(0)
        init_h = self.fc2h(fc_feats)
        init_h = init_h.unsqueeze(0)
        init_c = init_h.clone()
        state = (init_h, init_c)
        outputs = []
        for i in range(seq.size(1)):
            if i >= 1 and self.ss_prob > 0.0:
                # scheduled sampling: replace some ground-truth inputs with
                # tokens drawn from the previous step's distribution
                sample_prob = fc_feats.data.new(batch_size).uniform_(0, 1)
                sample_mask = sample_prob < self.ss_prob
                if sample_mask.sum() == 0:
                    it = seq[:, i].clone()
                else:
                    sample_ind = sample_mask.nonzero().view(-1)
                    it = seq[:, i].data.clone()
                    prob_prev = torch.exp(
                        outputs[-1].data
                    )  # fetch prev distribution: shape Nx(M+1)
                    it.index_copy_(
                        0, sample_ind,
                        torch.multinomial(prob_prev, 1).view(-1).index_select(
                            0, sample_ind))
                    it = Variable(it, requires_grad=False)
            else:
                it = seq[:, i].clone()
            # break if all the sequences end
            if i >= 1 and seq[:, i].data.sum() == 0:
                break
            xt = self.embed(it)
            output, state = self.core.forward(xt, att_feats, state)
            output = F.log_softmax(self.logit(output.squeeze(0)), dim=1)
            outputs.append(output)
        return torch.cat([_.unsqueeze(1) for _ in outputs], 1).contiguous()  # batch * 19 * vocab_size

    # reconstruction (ARNet) part
    def rcst_forward(self, fc_feats, att_feats, seq, mask):
        """Teacher-forced decoding plus the ARNet reconstruction loss.

        Runs ``seq.size(1) - 1`` decoder steps. From step 1 onward, ``rcstLSTM``
        consumes the decoder output and ``leaky_relu(h_2_pre_h(.))`` predicts
        the decoder hidden state of the previous step; the squared error,
        masked by ``mask[:, i]``, is summed, divided by ``batch_size`` and
        scaled by ``rcst_scale``.

        Returns:
            ``(output_logits, rcst_loss)``.
        """
        batch_size = fc_feats.size(0)
        init_h = self.fc2h(fc_feats)
        init_h = init_h.unsqueeze(0)
        init_c = init_h.clone()
        state = (init_h, init_c)
        # the ARNet LSTM starts from the same initial state as the decoder
        rcst_init_h = init_h.clone()
        rcst_init_c = init_c.clone()
        rcst_state = (rcst_init_h, rcst_init_c)
        pre_h = []  # decoder hidden state after every step
        output_logits = []
        rcst_loss = 0.0
        for i in range(seq.size(1) - 1):
            if i >= 1 and self.ss_prob > 0.0:  # otherwise no need to sample
                sample_prob = fc_feats.data.new(batch_size).uniform_(0, 1)
                sample_mask = sample_prob < self.ss_prob
                if sample_mask.sum() == 0:
                    it = seq[:, i].clone()
                else:
                    sample_ind = sample_mask.nonzero().view(-1)
                    it = seq[:, i].data.clone()
                    prob_prev = torch.exp(
                        output_logits[-1].data
                    )  # fetch prev distribution: shape Nx(M+1)
                    it.index_copy_(
                        0, sample_ind,
                        torch.multinomial(prob_prev, 1).view(-1).index_select(
                            0, sample_ind))
                    it = Variable(it, requires_grad=False)
            else:
                it = seq[:, i].clone()
            # break if all the sequences end
            if i >= 1 and seq[:, i].data.sum() == 0:
                break
            xt = self.embed(it)
            output, state = self.core.forward(xt, att_feats, state)
            # NOTE(review): implicit softmax dim; relies on a 2-D logit tensor.
            logit_words = F.log_softmax(self.logit(output.squeeze(0)))
            output_logits.append(logit_words)
            if i >= 1:
                rcst_output, rcst_state = self.rcstLSTM.forward(
                    output, rcst_state)
                rcst_h = F.leaky_relu(self.h_2_pre_h(rcst_output))
                # target: decoder hidden state stored at the previous step
                rcst_t = pre_h[i - 1].squeeze(dim=0)
                # -1 means not changing the size of that dimension,
                # http://pytorch.org/docs/master/tensors.html
                rcst_mask = mask[:, i].contiguous().view(batch_size, -1).expand(
                    batch_size, self.lstm_size)
                rcst_diff = rcst_h - rcst_t
                rcst_loss += torch.sum(
                    torch.sum(torch.mul(rcst_diff, rcst_diff) * rcst_mask,
                              dim=1)) / batch_size * self.rcst_scale
            # update previous hidden state
            pre_h.append(state[0].clone())
        output_logits = torch.cat([_.unsqueeze(1) for _ in output_logits],
                                  1).contiguous()
        return output_logits, rcst_loss

    def sample_beam(self, fc_feats, att_feats, init_index, opt={}):
        """Beam search, run independently for each element of the batch.

        Args:
            init_index: token id fed at step 0.
            opt: only ``opt.get('beam_size', 10)`` is read here (the dict is
                never mutated).

        Returns:
            ``(seq, seqLogprobs, top_seq, top_prob)``: best sequence and its
            per-step log-probs, each transposed to (batch, seq_length); a list
            with one LongTensor (num_done_beams, seq_length) per example; and
            per-example lists of the finished beams' cumulative scores.

        NOTE(review): ``candidate_logprob.data[0]`` / ``local_logprob.data[0]``
        assume the pre-0.4 Variable API (indexing a 0-dim tensor raises on
        newer PyTorch) -- confirm the targeted version.
        """
        beam_size = opt.get('beam_size', 10)  # default beam_size to 10 when opt does not provide it
        batch_size = fc_feats.size(0)
        fc_feat_size = fc_feats.size(1)
        seq = torch.LongTensor(self.seq_length, batch_size).zero_()
        seqLogprobs = torch.FloatTensor(self.seq_length, batch_size)
        top_seq = []
        top_prob = [[] for _ in range(batch_size)]
        self.done_beams = [[] for _ in range(batch_size)]
        for k in range(batch_size):
            # replicate example k across the beam
            init_h = self.fc2h(fc_feats[k].unsqueeze(0).expand(
                beam_size, fc_feat_size))
            init_h = init_h.unsqueeze(0)
            init_c = init_h.clone()
            state = (init_h, init_c)
            att_feats_current = att_feats[k].unsqueeze(0).expand(
                beam_size, att_feats.size(1), att_feats.size(2))
            att_feats_current = att_feats_current.contiguous()
            beam_seq = torch.LongTensor(self.seq_length, beam_size).zero_()
            beam_seq_logprobs = torch.FloatTensor(self.seq_length,
                                                  beam_size).zero_()
            beam_logprobs_sum = torch.zeros(
                beam_size)  # running sum of logprobs for each beam
            for t in range(self.seq_length + 1):
                if t == 0:  # input <bos>
                    it = fc_feats.data.new(beam_size).long().fill_(init_index)
                    xt = self.embed(Variable(it, requires_grad=False))
                    # xt = self.img_embed(fc_feats[k:k+1]).expand(beam_size, self.input_encoding_size)
                else:
                    """perform a beam merge. that is, for every previous beam we now many new possibilities to branch out we need to resort our beams to maintain the loop invariant of keeping the top beam_size most likely sequences."""
                    logprobsf = logprobs.float(
                    )  # lets go to CPU for more efficiency in indexing operations
                    # ys: beam_size * (Vab_size + 1)
                    ys, ix = torch.sort(
                        logprobsf, 1, True
                    )  # sorted array of logprobs along each previous beam (last true = descending)
                    candidates = []
                    cols = min(beam_size, ys.size(1))
                    rows = beam_size
                    if t == 1:  # at first time step only the first beam is active
                        rows = 1
                    for c in range(cols):
                        for q in range(rows):
                            # compute logprob of expanding beam q with word in (sorted) position c
                            local_logprob = ys[q, c]
                            candidate_logprob = beam_logprobs_sum[
                                q] + local_logprob
                            # a beam whose last token was 0 (END) is not extended
                            if t > 1 and beam_seq[t - 2, q] == 0:
                                continue
                            candidates.append({
                                'c': ix.data[q, c],
                                'q': q,
                                'p': candidate_logprob.data[0],
                                'r': local_logprob.data[0]
                            })
                    # every beam has ended: nothing left to expand
                    if len(candidates) == 0:
                        break
                    candidates = sorted(candidates, key=lambda x: -x['p'])
                    # construct new beams
                    new_state = [_.clone() for _ in state]
                    if t > 1:
                        # well need these as reference when we fork beams around
                        beam_seq_prev = beam_seq[:t - 1].clone()
                        beam_seq_logprobs_prev = beam_seq_logprobs[:t - 1].clone()
                    for vix in range(min(beam_size, len(candidates))):
                        v = candidates[vix]
                        # fork beam index q into index vix
                        if t > 1:
                            beam_seq[:t - 1, vix] = beam_seq_prev[:, v['q']]
                            beam_seq_logprobs[:t - 1, vix] = beam_seq_logprobs_prev[:, v[
                                'q']]
                        # rearrange recurrent states
                        for state_ix in range(len(new_state)):
                            # copy over state in previous beam q to new beam at vix
                            new_state[state_ix][0, vix] = state[state_ix][
                                0, v['q']]  # dimension one is time step
                        # append new end terminal at the end of this beam
                        beam_seq[t - 1, vix] = v['c']  # c'th word is the continuation
                        beam_seq_logprobs[t - 1, vix] = v['r']  # the raw logprob here
                        beam_logprobs_sum[vix] = v[
                            'p']  # the new (sum) logprob along this beam
                        if v['c'] == 0 or t == self.seq_length:
                            # END token special case here, or we reached the end.
                            # add the beam to a set of done beams
                            self.done_beams[k].append({
                                'seq': beam_seq[:, vix].clone(),
                                'logps': beam_seq_logprobs[:, vix].clone(),
                                'p': beam_logprobs_sum[vix]
                            })
                    # encode as vectors
                    it = beam_seq[t - 1]
                    xt = self.embed(Variable(it.cuda()))
                if t >= 1:
                    state = new_state
                output, state = self.core.forward(xt, att_feats_current, state)
                logprobs = F.log_softmax(self.logit(output))
            # best finished beam first
            self.done_beams[k] = sorted(self.done_beams[k], key=lambda x: -x['p'])
            seq[:, k] = self.done_beams[k][0][
                'seq']  # the first beam has highest cumulative score
            seqLogprobs[:, k] = self.done_beams[k][0]['logps']
            # save result
            l = len(self.done_beams[k])
            top_seq_cur = torch.LongTensor(l, self.seq_length).zero_()
            for temp_index in range(l):
                top_seq_cur[temp_index] = self.done_beams[k][temp_index][
                    'seq'].clone()
                top_prob[k].append(self.done_beams[k][temp_index]['p'])
            top_seq.append(top_seq_cur)
        # return the samples and their log likelihoods
        return seq.transpose(0, 1), seqLogprobs.transpose(0, 1), top_seq, top_prob

    def sample(self, fc_feats, att_feats, init_index, opt={}):
        """Greedy (``sample_max``) or temperature sampling; beam search if beam_size > 1.

        ``opt`` keys read: ``sample_max`` (default 1), ``beam_size`` (default 1),
        ``temperature`` (default 1.0). Sampling stops when every sequence has
        emitted a 0 token or after ``seq_length`` steps; tokens after a
        sequence's first 0 are zeroed.

        Returns ``(greedy_seq, greedy_seqLogprobs, greedy_logprobs_all)``.

        NOTE(review): the ``.cuda()`` calls make this path GPU-only.
        """
        sample_max = opt.get('sample_max', 1)
        beam_size = opt.get('beam_size', 1)
        temperature = opt.get('temperature', 1.0)
        if beam_size > 1:
            return self.sample_beam(fc_feats, att_feats, init_index, opt)
        batch_size = fc_feats.size(0)
        seq = []
        seqLogprobs = []
        logprobs_all = []
        init_h = self.fc2h(fc_feats)
        init_h = init_h.unsqueeze(0)
        init_c = init_h.clone()
        state = (init_h, init_c)
        for t in range(self.seq_length):
            if t == 0:  # input BOS, 304
                it = fc_feats.data.new(batch_size).long().fill_(init_index)
            elif sample_max:
                sampleLogprobs, it = torch.max(logprobs.data, 1)
                it = it.view(-1).long()
            else:
                if temperature == 1.0:
                    prob_prev = torch.exp(logprobs.data).cpu(
                    )  # fetch prev distribution: shape Nx(M+1)
                else:
                    # scale logprobs by temperature
                    prob_prev = torch.exp(torch.div(logprobs.data,
                                                    temperature)).cpu()
                it = torch.multinomial(prob_prev, 1).cuda()
                sampleLogprobs = logprobs.gather(
                    1, Variable(it, requires_grad=False).cuda(
                    ))  # gather the logprobs at sampled positions
                it = it.view(
                    -1).long()  # and flatten indices for downstream processing
            xt = self.embed(Variable(it, requires_grad=False).cuda())
            if t >= 1:
                # stop when all finished
                if t == 1:
                    unfinished = it > 0
                else:
                    unfinished = unfinished * (it > 0)
                if unfinished.sum() == 0:
                    break
                it = it * unfinished.type_as(it)
                seq.append(it)
                seqLogprobs.append(sampleLogprobs.view(-1))
            output, state = self.core.forward(xt, att_feats, state)
            logprobs = F.log_softmax(self.logit(output))
            logprobs_all.append(logprobs)
        greedy_seq = torch.cat([_.unsqueeze(1) for _ in seq], 1)
        greedy_seqLogprobs = torch.cat([_.unsqueeze(1) for _ in seqLogprobs], 1)
        greedy_logprobs_all = torch.cat([_.unsqueeze(1) for _ in logprobs_all],
                                        1).contiguous()
        return greedy_seq, greedy_seqLogprobs, greedy_logprobs_all

    def teacher_forcing_get_hidden_states(self, fc_feats, att_feats, seq):
        """Teacher-forced pass returning ``(state[0], outputs)``.

        ``state[0]`` is the decoder hidden state after the last executed step;
        ``outputs`` is the list of per-step log-probabilities.
        """
        batch_size = fc_feats.size(0)
        init_h = self.fc2h(fc_feats)
        init_h = init_h.unsqueeze(0)
        init_c = init_h.clone()
        state = (init_h, init_c)
        outputs = []
        for i in range(seq.size(1)):
            if i >= 1 and self.ss_prob > 0.0:  # otherwise no need to sample
                sample_prob = fc_feats.data.new(batch_size).uniform_(0, 1)
                sample_mask = sample_prob < self.ss_prob
                if sample_mask.sum() == 0:
                    it = seq[:, i].clone()
                else:
                    sample_ind = sample_mask.nonzero().view(-1)
                    it = seq[:, i].data.clone()
                    prob_prev = torch.exp(
                        outputs[-1].data
                    )  # fetch prev distribution: shape Nx(M+1)
                    it.index_copy_(
                        0, sample_ind,
                        torch.multinomial(prob_prev, 1).view(-1).index_select(
                            0, sample_ind))
                    it = Variable(it, requires_grad=False)
            else:
                it = seq[:, i].clone()
            # break if all the sequences end
            if i >= 1 and seq[:, i].data.sum() == 0:
                break
            xt = self.embed(it)
            output, state = self.core.forward(xt, att_feats, state)
            # presumably avoids squeeze(0) dropping the batch dim when the
            # batch has a single element -- TODO confirm core output shape
            if batch_size == 1:
                output = F.log_softmax(self.logit(output), dim=1)
            else:
                output = F.log_softmax(self.logit(output.squeeze(0)), dim=1)
            outputs.append(output)
        # return hidden states
        return state[0], outputs

    def free_running_get_hidden_states(self, fc_feats, att_feats, init_index, end_index):
        """Greedy free-running pass returning ``(state[0], logprobs_all)``.

        Stops after ``seq_length`` steps or when the argmax token of the first
        batch element equals ``end_index``.
        """
        batch_size = fc_feats.size(0)
        logprobs_all = []
        init_h = self.fc2h(fc_feats)
        init_h = init_h.unsqueeze(0)
        init_c = init_h.clone()
        state = (init_h, init_c)
        for t in range(self.seq_length):
            if t == 0:  # input BOS
                it = fc_feats.data.new(batch_size).long().fill_(init_index)
            xt = self.embed(Variable(it, requires_grad=False))
            output, state = self.core.forward(xt, att_feats, state)
            if batch_size == 1:
                logprobs = F.log_softmax(self.logit(output), dim=1)
            else:
                logprobs = F.log_softmax(self.logit(output.squeeze(0)), dim=1)
            logprobs_all.append(logprobs)
            # feed back the greedy token
            _, it = torch.max(logprobs.data, 1)
            it = it.view(-1).long()
            if it.cpu().numpy()[0] == end_index:
                break
        return state[0], logprobs_all
class AttEncodeDecode(nn.Module):
    """Attention encoder/decoder mapping a matrix of code tokens to comment words.

    ``encode_lstm`` reads ``code_matrix`` one column (time step) at a time and
    its per-step outputs are stacked into the memory that ``decode_lstm``
    attends over while emitting comment words through ``logit``.
    """

    def __init__(self, opt):
        super(AttEncodeDecode, self).__init__()
        self.token_cnt = opt.token_cnt
        self.word_cnt = opt.word_cnt
        self.lstm_size = opt.lstm_size
        self.drop_prob = opt.drop_prob
        self.input_encoding_size = opt.input_encoding_size
        self.encode_time_step = opt.code_truncate
        self.decode_time_step = opt.comment_truncate
        self.ss_prob = opt.ss_prob
        # The attended features are encoder LSTM outputs, so they are lstm_size wide.
        self.encoding_feat_size = opt.lstm_size
        self.encoding_att_size = opt.encoding_att_size
        self.att_hidden_size = opt.att_hidden_size

        self.encode_lstm = LSTMCore(self.input_encoding_size, self.lstm_size,
                                    self.drop_prob)
        # This class only defines ``drop_prob`` (there is no ``drop_prob_lm``),
        # so the decoder uses the same dropout probability as the encoder.
        self.decode_lstm = LSTMSoftAttentionCore(
            self.input_encoding_size, self.lstm_size, self.encoding_feat_size,
            self.encoding_att_size, self.att_hidden_size, self.drop_prob)
        # One table embeds both code tokens and comment words; the extra row is
        # presumably for index 0, which the loops below treat as padding.
        self.embed = nn.Embedding(self.token_cnt + 1, self.input_encoding_size)
        self.logit = nn.Linear(self.lstm_size, self.word_cnt)
        self.init_weights()

    def init_weights(self):
        """Uniformly initialise the embedding and output projection; zero the bias."""
        self.embed.weight.data.uniform_(-0.1, 0.1)
        self.logit.weight.data.uniform_(-0.1, 0.1)
        self.logit.bias.data.fill_(0)

    def init_hidden(self, batch_size):
        """Return a zero ``(h, c)`` pair, each shaped (1, batch_size, lstm_size).

        The tensors are created with ``weight.new`` so they share the dtype and
        device of the model's first parameter.
        """
        weight = next(self.parameters()).data
        init_h = Variable(weight.new(1, batch_size, self.lstm_size).zero_())
        init_c = Variable(weight.new(1, batch_size, self.lstm_size).zero_())
        init_state = (init_h, init_c)
        return init_state

    def copy_weights(self, model_path):
        """Copy every entry of ``self.state_dict()`` from the file at ``model_path``.

        The file is read with ``torch.load``; a key missing from it raises
        ``KeyError``. Each copied key and its size are printed.
        """
        src_weights = torch.load(model_path)
        own_dict = self.state_dict()
        for key, var in own_dict.items():
            print("copy weights: {} size: {}".format(key, var.size()))
            own_dict[key].copy_(src_weights[key])

    def _encode(self, code_matrix):
        """Run the encoder until a column of ``code_matrix`` is all zeros.

        Returns ``(encode_hidden_states, encode_state)``. The first is the
        per-step encoder outputs concatenated along a new dim 1 (presumably
        batch x steps x lstm_size); the second is the final ``(h, c)``.
        """
        encode_state = self.init_hidden(code_matrix.size(0))
        step_outputs = []
        for i in range(self.encode_time_step):
            # A column summing to zero is treated as "no more code tokens".
            if code_matrix[:, i].data.sum() == 0:
                break
            encode_xt = self.embed(code_matrix[:, i].clone())
            encode_output, encode_state = self.encode_lstm.forward(
                encode_xt, encode_state)
            step_outputs.append(encode_output)
        encode_hidden_states = torch.cat(
            [h.unsqueeze(1) for h in step_outputs], 1)
        return encode_hidden_states, encode_state

    def _decoder_input(self, comment_matrix, comment_mask, prev_logprobs, i,
                       batch_size):
        """Choose the decoder input tokens for step ``i`` (scheduled sampling).

        With probability ``self.ss_prob`` per batch row (and only for i >= 1),
        the ground-truth token ``comment_matrix[:, i]`` is replaced by a sample
        from the previous step's distribution ``exp(prev_logprobs[-1])``.
        ``prev_logprobs`` must therefore be non-empty whenever sampling occurs.
        """
        if i >= 1 and self.ss_prob > 0.0:
            sample_prob = comment_mask.data.new(batch_size).uniform_(0, 1)
            sample_mask = sample_prob < self.ss_prob
            if sample_mask.sum() == 0:
                return comment_matrix[:, i].clone()
            sample_ind = sample_mask.nonzero().view(-1)
            it = comment_matrix[:, i].data.clone()
            # previous step's distribution, shape N x (vocabulary size)
            prob_prev = torch.exp(prev_logprobs[-1].data)
            it.index_copy_(
                0, sample_ind,
                torch.multinomial(prob_prev, 1).view(-1).index_select(
                    0, sample_ind))
            return Variable(it, requires_grad=False)
        return comment_matrix[:, i].clone()

    def forward(self, code_matrix, comment_matrix, comment_mask):
        """Teacher-forced decoding (with optional scheduled sampling).

        Returns the per-step log-softmax outputs of ``logit`` stacked along
        dim 1. Decoding stops early at the first step i >= 1 whose
        ``comment_matrix`` column sums to zero.
        """
        batch_size = code_matrix.size(0)
        encode_hidden_states, encode_state = self._encode(code_matrix)

        decode_state = (encode_state[0].clone(), encode_state[1].clone())
        decode_logit_seq = []
        for i in range(self.decode_time_step):
            it = self._decoder_input(comment_matrix, comment_mask,
                                     decode_logit_seq, i, batch_size)
            if i >= 1 and comment_matrix[:, i].data.sum() == 0:
                break
            decode_xt = self.embed(it)
            decode_output, decode_state = self.decode_lstm.forward(
                decode_xt, encode_hidden_states, decode_state)
            decode_logit_seq.append(F.log_softmax(self.logit(decode_output)))

        return torch.cat(
            [_.unsqueeze(1) for _ in decode_logit_seq], 1).contiguous()

    def teacher_forcing_get_hidden_states(self, code_matrix, comment_matrix,
                                          comment_mask, eos_index):
        """Decode with teacher forcing and return the final decoder hidden state.

        Stops when the first batch element's input token equals ``eos_index``.
        """
        batch_size = code_matrix.size(0)
        encode_hidden_states, encode_state = self._encode(code_matrix)

        decode_state = (encode_state[0].clone(), encode_state[1].clone())
        outputs = []
        for i in range(self.decode_time_step):
            it = self._decoder_input(comment_matrix, comment_mask, outputs, i,
                                     batch_size)
            # Only element 0 is inspected, so this assumes batch size 1.
            if it.cpu().data[0] == eos_index:
                break
            decode_xt = self.embed(it)
            decode_output, decode_state = self.decode_lstm.forward(
                decode_xt, encode_hidden_states, decode_state)
            if self.ss_prob > 0.0:
                # Scheduled sampling at the next step draws from these
                # log-probs; without recording them outputs[-1] would fail.
                outputs.append(F.log_softmax(self.logit(decode_output)))
        return decode_state[0]

    def free_running_get_hidden_states(self, code_matrix, init_index,
                                       eos_index):
        """Greedily decode from ``init_index`` and return the final hidden state.

        Stops when the first batch element's greedy token equals ``eos_index``.
        """
        batch_size = code_matrix.size(0)
        encode_hidden_states, encode_state = self._encode(code_matrix)

        decode_state = (encode_state[0].clone(), encode_state[1].clone())
        for i in range(self.decode_time_step):
            if i == 0:
                it = code_matrix.data.new(batch_size).long().fill_(init_index)
            else:
                _, it = torch.max(logprobs.data, 1)
                it = it.view(-1).long()
                # Only element 0 is inspected, so this assumes batch size 1.
                if it.cpu()[0] == eos_index:
                    break
            # ``it`` already lives on the input's device, so no .cuda() needed.
            decode_xt = self.embed(Variable(it, requires_grad=False))
            decode_output, decode_state = self.decode_lstm.forward(
                decode_xt, encode_hidden_states, decode_state)
            logprobs = F.log_softmax(self.logit(decode_output))
        return decode_state[0]

    def sample(self, code_matrix, init_index, eos_index):
        """Greedy decoding starting from ``init_index``.

        Returns ``(greedy_seq, greedy_seq_probs, greedy_logprobs_all)``: the
        chosen tokens, their log-probs (both stacked along dim 1, excluding the
        start token), and the full log-softmax output of every step.
        Decoding stops once every sequence in the batch emits ``eos_index`` at
        the same step (for batch size 1: as soon as it emits EOS).
        """
        batch_size = code_matrix.size(0)
        encode_hidden_states, encode_state = self._encode(code_matrix)

        decode_state = (encode_state[0].clone(), encode_state[1].clone())
        seq = []
        seqLogprobs = []
        logprobs_all = []
        for i in range(self.decode_time_step):
            if i == 0:
                it = code_matrix.data.new(batch_size).long().fill_(init_index)
            else:
                max_logprobs, it = torch.max(logprobs.data, 1)
                it = it.view(-1).long()
                # Previously ``it.sum() == eos_index``, which is meaningless for
                # batch size > 1; identical behaviour for batch size 1.
                if (it != eos_index).sum() == 0:
                    break
            decode_xt = self.embed(Variable(it, requires_grad=False))
            decode_output, decode_state = self.decode_lstm.forward(
                decode_xt, encode_hidden_states, decode_state)
            if i > 0:
                seq.append(it)
                seqLogprobs.append(max_logprobs.view(-1))
            logprobs = F.log_softmax(self.logit(decode_output))
            logprobs_all.append(logprobs)

        greedy_seq = torch.cat([_.unsqueeze(1) for _ in seq], 1).contiguous()
        greedy_seq_probs = torch.cat(
            [_.unsqueeze(1) for _ in seqLogprobs], 1).contiguous()
        greedy_logprobs_all = torch.cat(
            [_.unsqueeze(1) for _ in logprobs_all], 1).contiguous()
        return greedy_seq, greedy_seq_probs, greedy_logprobs_all
class EncoderDecoder(nn.Module):
    """LSTM captioning model without attention.

    At step 0 the projected feature vector ``img_embed(fc_feats)`` is the LSTM
    input; later steps consume word embeddings. ``logit`` maps LSTM outputs to
    ``vocab_size`` scores.
    """

    def __init__(self, opt):
        super(EncoderDecoder, self).__init__()
        self.vocab_size = opt.vocab_size
        self.input_encoding_size = opt.input_encoding_size
        self.lstm_size = opt.lstm_size
        self.drop_prob_lm = opt.drop_prob_lm
        self.seq_length = opt.seq_length
        self.fc_feat_size = opt.fc_feat_size

        self.img_embed = nn.Linear(self.fc_feat_size, self.input_encoding_size)
        self.LSTMCore = LSTMCore(self.input_encoding_size, self.lstm_size,
                                 self.drop_prob_lm)
        self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size)
        self.logit = nn.Linear(self.lstm_size, self.vocab_size)
        self.init_weights()

    def init_weights(self):
        """Uniformly initialise projection/embedding weights; zero the biases."""
        initrange = 0.1
        self.img_embed.weight.data.uniform_(-initrange, initrange)
        self.img_embed.bias.data.fill_(0)
        self.embed.weight.data.uniform_(-initrange, initrange)
        self.logit.weight.data.uniform_(-initrange, initrange)
        self.logit.bias.data.fill_(0)

    def init_hidden(self, batch_size):
        """Return a zero ``(h, c)`` pair, each shaped (1, batch_size, lstm_size)."""
        weight = next(self.parameters()).data
        return (Variable(weight.new(1, batch_size, self.lstm_size).zero_()),
                Variable(weight.new(1, batch_size, self.lstm_size).zero_()))

    def _to_embed_device(self, it):
        """Return index tensor ``it`` on the same device as ``self.embed``.

        Replaces hard-coded ``.cuda()`` calls so the model also runs on CPU;
        ``it`` is returned unchanged when the embedding lives on the CPU.
        """
        weight = self.embed.weight.data
        if weight.is_cuda:
            return it.cuda(weight.get_device())
        return it

    def forward(self, fc_feats, seq):
        """Teacher-forced pass; returns per-step log-probs stacked along dim 1.

        ``seq[:, i-1]`` is the input at step i; decoding stops when that column
        sums to zero. The image step (i == 0) produces no output.
        """
        batch_size = fc_feats.size(0)
        state = self.init_hidden(batch_size)
        outputs = []

        for i in range(seq.size(1)):
            if i == 0:
                xt = self.img_embed(fc_feats)
            else:
                it = seq[:, i - 1].clone()
                if seq[:, i - 1].data.sum() == 0:
                    break
                xt = self.embed(it)

            output, state = self.LSTMCore.forward(xt, state)
            if i > 0:
                output = F.log_softmax(self.logit(output.squeeze(0)))
                outputs.append(output)

        return torch.cat([_.unsqueeze(1) for _ in outputs], 1).contiguous()

    def sample_beam(self, fc_feats, init_index, opt=None):
        """Beam search, run independently for every item of the batch.

        Returns ``(seq, seqLogprobs, top_seq, top_prob)``: the best sequence
        and its per-token log-probs (batch x seq_length), plus for each batch
        item every finished beam (``top_seq[k]``) and its summed log-prob
        (``top_prob[k]``), best first.
        """
        opt = {} if opt is None else opt
        # Use a beam width of 3 when ``opt`` does not provide ``beam_size``.
        beam_size = opt.get('beam_size', 3)
        batch_size = fc_feats.size(0)

        seq = torch.LongTensor(self.seq_length, batch_size).zero_()
        seqLogprobs = torch.FloatTensor(self.seq_length, batch_size)
        top_seq = []
        top_prob = [[] for _ in range(batch_size)]
        done_beams = [[] for _ in range(batch_size)]

        for k in range(batch_size):
            state = self.init_hidden(beam_size)
            beam_seq = torch.LongTensor(self.seq_length, beam_size).zero_()
            beam_seq_logprobs = torch.FloatTensor(self.seq_length, beam_size).zero_()
            # running sum of logprobs for each beam
            beam_logprobs_sum = torch.zeros(beam_size)
            # t == 0 feeds the image, t == 1 the start token; tokens chosen at
            # step t are stored in row t - 2 of beam_seq.
            for t in range(self.seq_length + 1):
                if t == 0:
                    xt = self.img_embed(fc_feats[k:k + 1]).expand(
                        beam_size, self.input_encoding_size)
                elif t == 1:
                    it = fc_feats.data.new(beam_size).long().fill_(init_index)
                    xt = self.embed(Variable(it, requires_grad=False))
                else:
                    logprobsf = logprobs.float()
                    ys, ix = torch.sort(logprobsf, 1, True)
                    candidates = []
                    cols = min(beam_size, ys.size(1))
                    rows = beam_size
                    if t == 2:  # at the first word step only beam 0 is active
                        rows = 1
                    for c in range(cols):
                        for q in range(rows):
                            # logprob of extending beam q with its c-th best word
                            local_logprob = ys[q, c]
                            candidate_logprob = beam_logprobs_sum[q] + local_logprob
                            # NOTE(review): ``.data[0]`` on a scalar is PyTorch<=0.3
                            # style; newer releases need ``.item()``.
                            candidates.append({'c': ix.data[q, c], 'q': q,
                                               'p': candidate_logprob.data[0],
                                               'r': local_logprob.data[0]})
                    candidates = sorted(candidates, key=lambda x: -x['p'])

                    # construct new beams
                    new_state = [_.clone() for _ in state]
                    if t > 2:
                        # snapshots used as reference while beams are forked
                        beam_seq_prev = beam_seq[:t - 2].clone()
                        beam_seq_logprobs_prev = beam_seq_logprobs[:t - 2].clone()
                    for vix in range(beam_size):
                        v = candidates[vix]
                        # fork beam index q into index vix
                        if t > 2:
                            beam_seq[:t - 2, vix] = beam_seq_prev[:, v['q']]
                            beam_seq_logprobs[:t - 2, vix] = beam_seq_logprobs_prev[:, v['q']]
                        # copy the recurrent state of beam q to beam vix
                        # (dimension 0 is the time step)
                        for state_ix in range(len(new_state)):
                            new_state[state_ix][0, vix] = state[state_ix][0, v['q']]
                        beam_seq[t - 2, vix] = v['c']  # the chosen word
                        beam_seq_logprobs[t - 2, vix] = v['r']  # its raw logprob
                        beam_logprobs_sum[vix] = v['p']  # new cumulative logprob
                        if v['c'] == 0 or t == self.seq_length:
                            # END token (index 0) or length limit: beam is done.
                            done_beams[k].append({
                                'seq': beam_seq[:, vix].clone(),
                                'logps': beam_seq_logprobs[:, vix].clone(),
                                'p': v['p']})
                            # Retire the beam so it is not extended past END
                            # and re-added to done_beams on later steps.
                            beam_logprobs_sum[vix] = -1000

                    # beam_seq is a CPU tensor; move it next to the embedding.
                    it = beam_seq[t - 2]
                    xt = self.embed(Variable(self._to_embed_device(it)))

                if t >= 2:
                    state = new_state
                output, state = self.LSTMCore.forward(xt, state)
                logprobs = F.log_softmax(self.logit(output))

            done_beams[k] = sorted(done_beams[k], key=lambda x: -x['p'])
            # the first beam has the highest cumulative score
            seq[:, k] = done_beams[k][0]['seq']
            seqLogprobs[:, k] = done_beams[k][0]['logps']

            # keep every finished beam for this item
            n_done = len(done_beams[k])
            top_seq_cur = torch.LongTensor(n_done, self.seq_length).zero_()
            for temp_index in range(n_done):
                top_seq_cur[temp_index] = done_beams[k][temp_index]['seq'].clone()
                top_prob[k].append(done_beams[k][temp_index]['p'])
            top_seq.append(top_seq_cur)

        # return the samples and their log likelihoods
        return seq.transpose(0, 1), seqLogprobs.transpose(0, 1), top_seq, top_prob

    def sample(self, fc_feats, init_index, opt=None):
        """Greedy decoding, or beam search when ``opt['beam_size'] > 1``.

        Returns ``(seq, seqLogprobs, logprobs_all)`` stacked along dim 1.
        Tokens after a sequence emits 0 are zeroed; decoding stops once every
        sequence has emitted 0.
        """
        opt = {} if opt is None else opt
        beam_size = opt.get('beam_size', 1)
        if beam_size > 1:
            return self.sample_beam(fc_feats, init_index, opt)

        batch_size = fc_feats.size(0)
        seq = []
        seqLogprobs = []
        logprobs_all = []
        state = self.init_hidden(batch_size)

        for t in range(self.seq_length):
            if t == 0:
                xt = self.img_embed(fc_feats)
            else:
                if t == 1:
                    it = fc_feats.data.new(batch_size).long().fill_(init_index)
                else:
                    sampleLogprobs, it = torch.max(logprobs.data, 1)
                    it = it.view(-1).long()
                xt = self.embed(
                    Variable(self._to_embed_device(it), requires_grad=False))

            if t >= 2:
                if t == 2:
                    unfinished = it > 0
                else:
                    unfinished = unfinished * (it > 0)
                if unfinished.sum() == 0:
                    break
                it = it * unfinished.type_as(it)
                seq.append(it)
                seqLogprobs.append(sampleLogprobs.view(-1))

            output, state = self.LSTMCore.forward(xt, state)
            logprobs = F.log_softmax(self.logit(output))
            logprobs_all.append(logprobs)

        return torch.cat([_.unsqueeze(1) for _ in seq], 1), \
            torch.cat([_.unsqueeze(1) for _ in seqLogprobs], 1), \
            torch.cat([_.unsqueeze(1) for _ in logprobs_all], 1).contiguous()

    def teacher_forcing_get_hidden_states(self, fc_feats, seq):
        """Teacher-forced pass returning ``(final_h, per_step_logprobs)``."""
        batch_size = fc_feats.size(0)
        state = self.init_hidden(batch_size)
        outputs = []

        for i in range(seq.size(1)):
            if i == 0:
                xt = self.img_embed(fc_feats)
            else:
                it = seq[:, i - 1].clone()
                if seq[:, i - 1].data.sum() == 0:
                    break
                xt = self.embed(it)

            output, state = self.LSTMCore.forward(xt, state)
            if i > 0:
                # squeeze(0) would also drop the batch dim when batch_size == 1
                if batch_size == 1:
                    output = F.log_softmax(self.logit(output))
                else:
                    output = F.log_softmax(self.logit(output.squeeze(0)))
                outputs.append(output)

        return state[0], outputs

    def free_running_get_hidden_states(self, fc_feats, init_index, end_index):
        """Greedy decoding returning ``(final_h, per_step_logprobs)``.

        Stops when the first batch element's greedy token equals ``end_index``.
        """
        batch_size = fc_feats.size(0)
        logprobs_all = []
        state = self.init_hidden(batch_size)

        for t in range(self.seq_length):
            if t == 0:
                xt = self.img_embed(fc_feats)
            elif t == 1:
                it = fc_feats.data.new(batch_size).long().fill_(init_index)
                xt = self.embed(
                    Variable(self._to_embed_device(it), requires_grad=False))
            else:
                _, it = torch.max(logprobs.data, 1)
                it = it.view(-1).long()
                # Only element 0 is inspected, so this assumes batch size 1.
                if it.cpu().numpy()[0] == end_index:
                    break
                xt = self.embed(
                    Variable(self._to_embed_device(it), requires_grad=False))

            output, state = self.LSTMCore.forward(xt, state)
            logprobs = F.log_softmax(self.logit(output))
            logprobs_all.append(logprobs)

        return state[0], logprobs_all