def val():
    """Run the generator over the validation set and collect greedy-sampled
    answers alongside the ground-truth answer text for every dialog round.

    Returns a list with one entry per example; each entry is a list of
    per-round dicts (ques / gt_ans / sample_ans / rnd / img_id).
    Relies on module-level globals: netE_g, netW_g, netG, dataloader_val,
    opt, img_input, his_length, ans_length, vocab_size, itow.
    """
    netE_g.eval()
    netW_g.eval()
    netG.eval()
    n_neg = 100
    ques_hidden1 = netE_g.init_hidden(opt.batchSize)
    hist_hidden1 = netE_g.init_hidden(opt.batchSize)
    bar = progressbar.ProgressBar(max_value=len(dataloader_val))
    data_iter_val = iter(dataloader_val)
    count = 0
    i = 0
    result_all = []
    while i < len(dataloader_val):
        # FIX: .next() is the Python-2 iterator protocol; use next() so this
        # works with Python-3 DataLoader iterators.
        data = next(data_iter_val)
        image, history, question, answer, answerT, questionL, opt_answer, \
            opt_answerT, answer_ids, answerLen, opt_answerLen, img_id = data
        batch_size = question.size(0)
        image = image.view(-1, 512)
        with torch.no_grad():
            img_input.resize_(image.size()).copy_(image)
        save_tmp = [[] for j in range(batch_size)]
        for rnd in range(10):
            # Get the corresponding round QA and history.
            ques, tans = question[:, rnd, :].t(), answerT[:, rnd, :].t()
            his = history[:, :rnd + 1, :].clone().view(-1, his_length).t()
            opt_ans = opt_answer[:, rnd, :, :].clone().view(-1, ans_length).t()
            opt_tans = opt_answerT[:, rnd, :].clone().view(-1, ans_length).t()
            gt_id = answer_ids[:, rnd]
            opt_len = opt_answerLen[:, rnd, :].clone().view(-1)

            ques_input = torch.LongTensor(ques.size()).cuda()
            ques_input.copy_(ques)
            his_input = torch.LongTensor(his.size()).cuda()
            his_input.copy_(his)
            opt_ans_input = torch.LongTensor(opt_ans.size()).cuda()
            opt_ans_input.copy_(opt_ans)
            opt_ans_target = torch.LongTensor(opt_tans.size()).cuda()
            opt_ans_target.copy_(opt_tans)
            gt_index = torch.LongTensor(gt_id.size())
            gt_index.copy_(gt_id)

            # Encode question + history + image, then prime the decoder
            # hidden state with the encoder feature.
            ques_emb_g = netW_g(ques_input, format='index')
            his_emb_g = netW_g(his_input, format='index')
            ques_hidden1 = repackage_hidden_new(ques_hidden1, batch_size)
            hist_hidden1 = repackage_hidden_new(hist_hidden1, his_emb_g.size(1))
            featG, ques_hidden1 = netE_g(ques_emb_g, his_emb_g, img_input,
                                         ques_hidden1, hist_hidden1, rnd + 1)
            _, ques_hidden1 = netG(featG.view(1, -1, opt.ninp), ques_hidden1)

            # Greedy sampling: the start token is vocab_size.
            sample_ans_input = torch.LongTensor(1, opt.batchSize).cuda()
            sample_ans_input.resize_((1, batch_size)).fill_(vocab_size)
            sample_opt = {'beam_size': 1}
            seq, seqLogprobs = netG.sample(netW_g, sample_ans_input,
                                           ques_hidden1, sample_opt)

            ans_sample_txt = decode_txt(itow, seq.t())
            ans_txt = decode_txt(itow, tans)
            ques_txt = decode_txt(itow, questionL[:, rnd, :].t())
            for j in range(batch_size):
                save_tmp[j].append({'ques': ques_txt[j], 'gt_ans': ans_txt[j],
                                    'sample_ans': ans_sample_txt[j], 'rnd': rnd,
                                    'img_id': img_id[j].item()})
        i += 1
        bar.update(i)
        result_all += save_tmp
    return result_all
def train(epoch):
    """Train the discriminative (ranking) model for one epoch.

    For each batch and each of the 10 dialog rounds, encodes the question,
    history and image, embeds the real / wrong / in-batch-negative answers,
    and optimizes an n-pair loss plus an L2 regularizer.

    Returns (average_loss, lr). Note: average_loss is reset at every
    opt.log_interval, so the returned value covers only the trailing
    partial interval.
    """
    netW.train()
    netE.train()
    netD.train()
    lr = adjust_learning_rate(optimizer, epoch, opt.lr)

    ques_hidden = netE.init_hidden(opt.batchSize)
    hist_hidden = netE.init_hidden(opt.batchSize)
    real_hidden = netD.init_hidden(opt.batchSize)
    wrong_hidden = netD.init_hidden(opt.batchSize)

    data_iter = iter(dataloader)
    average_loss = 0
    count = 0
    i = 0
    while i < len(dataloader):
        t1 = time.time()
        # FIX: use next() instead of Python-2-only .next().
        data = next(data_iter)
        image, history, question, answer, answerT, answerLen, answerIdx, questionL, \
            opt_answerT, opt_answerLen, opt_answerIdx = data
        batch_size = question.size(0)
        image = image.view(-1, img_feat_size)
        with torch.no_grad():
            img_input.resize_(image.size()).copy_(image)
        for rnd in range(10):
            netW.zero_grad()
            netE.zero_grad()
            netD.zero_grad()
            # Get the corresponding round QA and history.
            ques = question[:, rnd, :].t()
            his = history[:, :rnd + 1, :].clone().view(-1, his_length).t()
            ans = answer[:, rnd, :].t()
            tans = answerT[:, rnd, :].t()
            wrong_ans = opt_answerT[:, rnd, :].clone().view(-1, ans_length).t()
            real_len = answerLen[:, rnd]
            wrong_len = opt_answerLen[:, rnd, :].clone().view(-1)

            ques_input = torch.LongTensor(ques.size())
            ques_input.copy_(ques)
            his_input = torch.LongTensor(his.size())
            his_input.copy_(his)
            ans_input = torch.LongTensor(ans.size())
            ans_input.copy_(ans)
            ans_target = torch.LongTensor(tans.size())
            ans_target.copy_(tans)
            wrong_ans_input = torch.LongTensor(wrong_ans.size())
            wrong_ans_input.copy_(wrong_ans)

            # Sample in-batch negative indices.
            batch_sample_idx = torch.zeros(batch_size, opt.neg_batch_sample,
                                           dtype=torch.long)
            sample_batch_neg(answerIdx[:, rnd], opt_answerIdx[:, rnd, :],
                             batch_sample_idx, opt.neg_batch_sample)

            ques_emb = netW(ques_input, format='index')
            his_emb = netW(his_input, format='index')
            ques_hidden = repackage_hidden_new(ques_hidden, batch_size)
            hist_hidden = repackage_hidden_new(hist_hidden, his_input.size(1))
            featD, ques_hidden = netE(ques_emb, his_emb, img_input,
                                      ques_hidden, hist_hidden, rnd + 1)

            ans_real_emb = netW(ans_target, format='index')
            ans_wrong_emb = netW(wrong_ans_input, format='index')
            real_hidden = repackage_hidden_new(real_hidden, batch_size)
            wrong_hidden = repackage_hidden_new(wrong_hidden, ans_wrong_emb.size(1))
            real_feat = netD(ans_real_emb, ans_target, real_hidden, vocab_size)
            wrong_feat = netD(ans_wrong_emb, wrong_ans_input, wrong_hidden, vocab_size)

            batch_wrong_feat = wrong_feat.index_select(0, batch_sample_idx.view(-1))
            wrong_feat = wrong_feat.view(
                batch_size, -1, opt.ninp)  # (batch_size, negative_sample, ninp)
            batch_wrong_feat = batch_wrong_feat.view(
                batch_size, -1, opt.ninp)  # (batch_size, crossover_negative_sample, ninp)

            # All the correct answers are present at the beginning of
            # combined_scores: (batch_size, 1 + n_s + n_b_s).
            combined_scores, l2_norm = feat2score(featD, real_feat, wrong_feat,
                                                  batch_wrong_feat)
            lambs, nPairLoss = critD(combined_scores)
            average_loss += nPairLoss.data.item()

            # Two backward passes share the graph: the L2 term first
            # (retain_graph), then the score gradient weighted by lambs.
            l2_norm.backward(retain_graph=True)
            combined_scores.backward(lambs)
            optimizer.step()
            count += 1
        i += 1
        if i % opt.log_interval == 0:
            average_loss /= count
            print("step {} / {} (epoch {}), g_loss {:.3f}, lr = {:.6f}"
                  .format(i, len(dataloader), epoch, average_loss, lr))
            average_loss = 0
            count = 0
    return average_loss, lr
def val():
    """Rank the 100 candidate answers with the discriminative model on the
    validation set and return the ground-truth answer's rank per round.

    Returns a flat list of ranks (1-based; rank 1 means the ground truth
    scored highest among the 100 options).
    """
    netE.eval()
    netW.eval()
    netD.eval()
    n_neg = 100
    data_iter_val = iter(dataloader_val)
    ques_hidden = netE.init_hidden(opt.batchSize)
    print('DSuccess')  # NOTE(review): leftover debug print — consider removing
    hist_hidden = netE.init_hidden(opt.batchSize)
    opt_hidden = netD.init_hidden(opt.batchSize)
    i = 0
    average_loss = 0
    rank_all_tmp = []
    while i < len(dataloader_val):
        # FIX: use next() instead of Python-2-only .next().
        data = next(data_iter_val)
        image, history, question, answer, answerT, questionL, opt_answer, \
            opt_answerT, answer_ids, answerLen, opt_answerLen, img_id = data
        batch_size = question.size(0)
        image = image.view(-1, img_feat_size)
        with torch.no_grad():
            img_input.resize_(image.size()).copy_(image)
        for rnd in range(10):
            # Get the corresponding round QA and history.
            ques = question[:, rnd, :].t()
            his = history[:, :rnd + 1, :].clone().view(-1, his_length).t()
            opt_ans = opt_answerT[:, rnd, :].clone().view(-1, ans_length).t()
            gt_id = answer_ids[:, rnd]

            ques_input = torch.LongTensor(ques.size())
            ques_input.copy_(ques)
            his_input = torch.LongTensor(his.size())
            his_input.copy_(his)
            opt_ans_input = torch.LongTensor(opt_ans.size())
            opt_ans_input.copy_(opt_ans)
            gt_index = torch.LongTensor(gt_id.size())
            gt_index.copy_(gt_id)
            opt_len = opt_answerLen[:, rnd, :].clone().view(-1)

            ques_emb = netW(ques_input, format='index')
            his_emb = netW(his_input, format='index')
            ques_hidden = repackage_hidden_new(ques_hidden, batch_size)
            hist_hidden = repackage_hidden_new(hist_hidden, his_input.size(1))
            featD, ques_hidden = netE(ques_emb, his_emb, img_input,
                                      ques_hidden, hist_hidden, rnd + 1)

            opt_ans_emb = netW(opt_ans_input, format='index')
            opt_hidden = repackage_hidden_new(opt_hidden, opt_ans_input.size(1))
            opt_feat = netD(opt_ans_emb, opt_ans_input, opt_hidden, vocab_size)
            opt_feat = opt_feat.view(batch_size, -1, opt.ninp)

            # Dot product of encoder feature with each of the 100 options.
            featD = featD.view(-1, opt.ninp, 1)
            score = torch.bmm(opt_feat, featD)
            score = score.view(-1, 100)

            # Offset ground-truth indices into the flattened score vector.
            for b in range(batch_size):
                gt_index.data[b] = gt_index.data[b] + b * 100
            gt_score = score.view(-1).index_select(0, gt_index)
            sort_score, sort_idx = torch.sort(score, 1, descending=True)
            # rank = 1 + number of options strictly above the ground truth.
            count = sort_score.gt(gt_score.view(-1, 1).expand_as(sort_score))
            rank = count.sum(1) + 1
            rank_all_tmp += list(rank.view(-1).data.cpu().numpy())
        i += 1
        sys.stdout.write('Evaluating: {:d}/{:d} \r'.format(i, len(dataloader_val)))
        sys.stdout.flush()
    return rank_all_tmp
def eval():
    """Rank the 100 candidate answers with the generative model and dump
    per-round records (question, ground-truth answer, top-10 answers, rank).

    The generator scores each option by summed negative log-probability of
    its tokens; lower is better, so candidates are sorted ascending.
    Returns (rank_all_tmp, result_all).
    """
    netE.eval()
    netW.eval()
    netG.eval()
    data_iter_val = iter(dataloader_val)
    ques_hidden = netE.init_hidden(opt.batchSize)
    hist_hidden = netE.init_hidden(opt.batchSize)
    i = 0
    display_count = 0
    average_loss = 0
    rank_all_tmp = []
    result_all = []
    early_stop = int(opt.early_stop / opt.batchSize)
    dataloader_size = min(len(dataloader_val), early_stop)
    print('early_stop: {}'.format(early_stop))
    print('dataloader_size: {}'.format(dataloader_size))
    while i < dataloader_size:
        # FIX: use next() instead of Python-2-only .next().
        data = next(data_iter_val)
        image, history, question, answer, answerT, questionL, opt_answer, \
            opt_answerT, answer_ids, answerLen, opt_answerLen, img_id = data
        batch_size = question.size(0)
        image = image.view(-1, 512)
        with torch.no_grad():
            img_input.resize_(image.size()).copy_(image)
        save_tmp = [[] for j in range(batch_size)]
        for rnd in range(10):
            # Get the corresponding round QA and history.
            ques, tans = question[:, rnd, :].t(), \
                opt_answerT[:, rnd, :].clone().view(-1, ans_length).t()
            his = history[:, :rnd + 1, :].clone().view(-1, his_length).t()
            ans = opt_answer[:, rnd, :, :].clone().view(-1, ans_length).t()
            gt_id = answer_ids[:, rnd]

            his_input = torch.LongTensor(his.size()).cuda()
            his_input.copy_(his)
            ques_input = torch.LongTensor(ques.size()).cuda()
            ques_input.copy_(ques)
            ans_input = torch.LongTensor(ans.size()).cuda()
            ans_input.copy_(ans)
            ans_target = torch.LongTensor(tans.size()).cuda()
            ans_target.copy_(tans)
            gt_index = torch.LongTensor(gt_id.size()).cuda()
            gt_index.copy_(gt_id)

            ques_emb = netW(ques_input, format='index')
            his_emb = netW(his_input, format='index')
            ques_hidden = repackage_hidden_new(ques_hidden, batch_size)
            hist_hidden = repackage_hidden_new(hist_hidden, his_input.size(1))
            encoder_feat, ques_hidden = netE(ques_emb, his_emb, img_input,
                                             ques_hidden, hist_hidden, rnd + 1)
            _, ques_hidden = netG(encoder_feat.view(1, -1, opt.ninp), ques_hidden)

            ans_score = torch.FloatTensor(batch_size, 100).zero_()
            # Replicate the decoder hidden state 100x — one copy per option.
            hidden_replicated = []
            for hid in ques_hidden:
                hidden_replicated.append(
                    hid.view(opt.nlayers, batch_size, 1, opt.nhid)
                       .expand(opt.nlayers, batch_size, 100, opt.nhid)
                       .clone().view(opt.nlayers, -1, opt.nhid))
            hidden_replicated = tuple(hidden_replicated)

            ans_emb = netW(ans_input, format='index')
            output, _ = netG(ans_emb, hidden_replicated)
            logprob = -output
            logprob_select = torch.gather(logprob, 1, ans_target.view(-1, 1))
            mask = ans_target.data.eq(0)  # mask out padding tokens
            # NOTE(review): legacy pre-0.4 autograd path; `volatile` no
            # longer exists on modern tensors — confirm torch version.
            if isinstance(logprob, Variable):
                mask = Variable(mask, volatile=logprob.volatile)
            logprob_select.masked_fill_(mask.view_as(logprob_select), 0)
            prob = logprob_select.view(ans_length, -1, 100).sum(0).view(-1, 100)

            # Offset ground-truth indices into the flattened score vector.
            for b in range(batch_size):
                gt_index.data[b] = gt_index.data[b] + b * 100
            gt_score = prob.view(-1).index_select(0, gt_index)
            sort_score, sort_idx = torch.sort(prob, 1)
            # rank = 1 + number of options with strictly lower NLL.
            count = sort_score.lt(gt_score.view(-1, 1).expand_as(sort_score))
            rank = count.sum(1) + 1
            gt_rank_cpu = rank.view(-1).data.cpu().numpy()

            # --------------------- get the top 10 answers ---------------------
            answer_list = tans  # 9 x bs*100
            new_sorted_idx = torch.LongTensor(batch_size * 10)
            for b in range(batch_size):
                new_sorted_idx[b * 10:b * 10 + 10] = sort_idx[b, :10] + b * 100
            ans_array = tans.index_select(1, new_sorted_idx)
            ans_list = decode_txt(itow, ans_array)
            ques_txt = decode_txt(itow, questionL[:, rnd, :].t())
            ans_txt = decode_txt(itow, tans)
            for b in range(batch_size):
                data_dict = {}
                data_dict['ques'] = ques_txt[b]
                data_dict['gt_ans'] = ans_txt[gt_index[b]]
                data_dict['top10_ans'] = ans_list[b * 10:(b + 1) * 10]
                data_dict['rnd'] = rnd
                data_dict['image_id'] = img_id[b].item()
                data_dict['gt_ans_rank'] = str(gt_rank_cpu[b])
                save_tmp[b].append(data_dict)
            rank_all_tmp += list(rank.view(-1).data.cpu().numpy())
        result_all += save_tmp
        i += 1
        sys.stdout.write('Evaluating: {:d}/{:d} \r'.format(i, len(dataloader_val)))
        if i % 50 == 0:
            R1 = np.sum(np.array(rank_all_tmp) == 1) / float(len(rank_all_tmp))
            R5 = np.sum(np.array(rank_all_tmp) <= 5) / float(len(rank_all_tmp))
            R10 = np.sum(np.array(rank_all_tmp) <= 10) / float(len(rank_all_tmp))
            ave = np.sum(np.array(rank_all_tmp)) / float(len(rank_all_tmp))
            mrr = np.sum(1 / (np.array(rank_all_tmp, dtype='float'))) / float(len(rank_all_tmp))
            logger.warning('%d/%d: mrr: %f R1: %f R5 %f R10 %f Mean %f'
                           % (i, len(dataloader_val), mrr, R1, R5, R10, ave))
    return (rank_all_tmp, result_all)
def eval():
    """Rank the 100 candidate answers with the discriminative model and dump
    per-round records including the decoded top-`opt.topn` answers and scores.

    Returns (rank_all_tmp, result_all): the flat list of ground-truth ranks
    and the per-example list of per-round result dicts.
    """
    netW.eval()
    netE.eval()
    netD.eval()
    data_iter_val = iter(dataloader_val)
    ques_hidden = netE.init_hidden(opt.batchSize)
    hist_hidden = netE.init_hidden(opt.batchSize)
    opt_hidden = netD.init_hidden(opt.batchSize)
    i = 0
    display_count = 0
    average_loss = 0
    rank_all_tmp = []
    result_all = []
    img_atten = torch.FloatTensor(100 * 30, 10, 7, 7)
    early_stop = int(opt.early_stop / opt.batchSize)
    dataloader_size = min(len(dataloader_val), early_stop)
    print('early_stop: {}'.format(early_stop))
    print('dataloader_size: {}'.format(dataloader_size))
    while i < dataloader_size:
        # FIX: use next() instead of Python-2-only .next().
        data = next(data_iter_val)
        image, history, question, answer, answerT, questionL, opt_answer, \
            opt_answerT, answer_ids, answerLen, opt_answerLen, img_id = data
        batch_size = question.size(0)
        image = image.view(-1, 512)
        with torch.no_grad():
            img_input.resize_(image.size()).copy_(image)
        save_tmp = [[] for j in range(batch_size)]
        for rnd in range(10):
            # Get the corresponding round QA and history.
            ques, tans = question[:, rnd, :].t(), answerT[:, rnd, :].t()
            his = history[:, :rnd + 1, :].clone().view(-1, his_length).t()
            opt_ans = opt_answerT[:, rnd, :].clone().view(-1, ans_length).t()
            gt_id = answer_ids[:, rnd]

            ques_input = torch.LongTensor(ques.size()).cuda()
            ques_input.copy_(ques)
            his_input = torch.LongTensor(his.size()).cuda()
            his_input.copy_(his)
            gt_index = torch.LongTensor(gt_id.size()).cuda()
            gt_index.copy_(gt_id)
            opt_ans_input = torch.LongTensor(opt_ans.size()).cuda()
            opt_ans_input.copy_(opt_ans)
            opt_len = opt_answerLen[:, rnd, :].clone().view(-1)

            ques_emb = netW(ques_input, format='index')
            his_emb = netW(his_input, format='index')
            ques_hidden = repackage_hidden_new(ques_hidden, batch_size)
            hist_hidden = repackage_hidden_new(hist_hidden, his_input.size(1))
            featD, ques_hidden = netE(ques_emb, his_emb, img_input,
                                      ques_hidden, hist_hidden, rnd + 1)

            opt_ans_emb = netW(opt_ans_input, format='index')
            opt_hidden = repackage_hidden_new(opt_hidden, opt_ans_input.size(1))
            opt_feat = netD(opt_ans_emb, opt_ans_input, opt_hidden, n_words)
            opt_feat = opt_feat.view(batch_size, -1, opt.ninp)

            # Dot product of encoder feature with each of the 100 options.
            featD = featD.view(-1, opt.ninp, 1)
            score = torch.bmm(opt_feat, featD)
            score = score.view(-1, 100)

            # Offset ground-truth indices into the flattened score vector.
            for b in range(batch_size):
                gt_index.data[b] = gt_index.data[b] + b * 100
            gt_score = score.view(-1).index_select(0, gt_index)
            sort_score, sort_idx = torch.sort(score, 1, descending=True)

            # Gather the word indices of the top-n ranked options so they
            # can be decoded back to text.
            opt_answer_cur_ques = opt_answerT.detach().numpy()[:, rnd, :, :]  # 5, 100, 9
            top_sort_idx = sort_idx.cpu().detach().numpy()[:, 0:opt.topn]
            first_dim_indices = np.broadcast_to(
                np.arange(batch_size).reshape(batch_size, 1),
                (batch_size, opt.topn)).reshape(batch_size * opt.topn)
            top_ans_word_indices = opt_answer_cur_ques[
                first_dim_indices,
                top_sort_idx.reshape(batch_size * opt.topn), :].reshape(
                    batch_size, opt.topn, 9)
            top_ans_txt_rank_wise = []  # 10, 5 as strings
            ques_txt = decode_txt(itow, questionL[:, rnd, :].t())
            ans_txt = decode_txt(itow, tans)
            for pos in range(opt.topn):
                top_temp = decode_txt(
                    itow, torch.tensor(top_ans_word_indices[:, pos, :]).t())
                top_ans_txt_rank_wise.append(top_temp)
            top_ans_txt = np.array(top_ans_txt_rank_wise, dtype=str).T

            # rank = 1 + number of options strictly above the ground truth.
            count = sort_score.gt(gt_score.view(-1, 1).expand_as(sort_score))
            rank = count.sum(1) + 1
            gt_rank_cpu = rank.view(-1).data.cpu().numpy()
            rank_all_tmp += list(rank.view(-1).data.cpu().numpy())
            sort_score_cpu = sort_score.data.cpu().numpy()
            for b in range(batch_size):
                save_tmp[b].append({
                    "ques": ques_txt[b],
                    "gt_ans": ans_txt[b],
                    "top_10_disc_ans": top_ans_txt.tolist()[b],
                    "gt_ans_rank": str(gt_rank_cpu[b]),
                    "rnd": rnd,
                    "img_id": img_id[b].item(),
                    "top_10_scores": sort_score_cpu[b][:opt.topn].tolist()
                })
        i += 1
        result_all += save_tmp
        if i % opt.log_iter == 0:
            R1 = np.sum(np.array(rank_all_tmp) == 1) / float(len(rank_all_tmp))
            R5 = np.sum(np.array(rank_all_tmp) <= 5) / float(len(rank_all_tmp))
            R10 = np.sum(np.array(rank_all_tmp) <= 10) / float(len(rank_all_tmp))
            ave = np.sum(np.array(rank_all_tmp)) / float(len(rank_all_tmp))
            mrr = np.sum(1 / (np.array(rank_all_tmp, dtype='float'))) / float(len(rank_all_tmp))
            logger.warning('%d/%d: mrr: %f R1: %f R5 %f R10 %f Mean %f'
                           % (i, len(dataloader_val), mrr, R1, R5, R10, ave))
    return (rank_all_tmp, result_all)
def train(epoch):
    """One epoch of adversarial training: optional MLE update of the
    generator, Gumbel-softmax sampling of fake answers, then a
    discriminator-guided generator update.

    Returns (err_lm, err_d, err_g, loss_store) — epoch-average losses and
    the per-iteration loss records.
    """
    netW_d.train(), netE_d.train(), netE_g.train()
    netD.train(), netG.train(), netW_g.train()

    fake_len = torch.LongTensor(opt.batchSize)
    fake_len = fake_len.cuda()
    n_neg = opt.negative_sample
    # NOTE(review): ques_hidden1/hist_hidden1 are initialized from netE_d
    # but used below with netE_g (and vice versa for *_2) — presumably the
    # hidden shapes are identical; confirm this is intentional.
    ques_hidden1 = netE_d.init_hidden(opt.batchSize)
    ques_hidden2 = netE_g.init_hidden(opt.batchSize)
    hist_hidden1 = netE_d.init_hidden(opt.batchSize)
    hist_hidden2 = netE_g.init_hidden(opt.batchSize)
    real_hidden = netD.init_hidden(opt.batchSize)
    wrong_hidden = netD.init_hidden(opt.batchSize)
    fake_hidden = netD.init_hidden(opt.batchSize)

    data_iter = iter(dataloader)
    err_d = 0
    err_g = 0
    err_lm = 0
    average_loss = 0
    count = 0
    i = 0
    loss_store = []
    while i < len(dataloader):
        t1 = time.time()
        # FIX: use next() instead of Python-2-only .next().
        data = next(data_iter)
        image, history, question, answer, answerT, answerLen, answerIdx, questionL, \
            opt_answerT, opt_answerLen, opt_answerIdx, _ = data
        batch_size = question.size(0)
        image = image.view(-1, 512)
        with torch.no_grad():
            img_input.resize_(image.size()).copy_(image)

        err_d_tmp = 0
        err_g_tmp = 0
        err_lm_tmp = 0
        err_d_fake_tmp = 0
        err_g_fake_tmp = 0
        for rnd in range(10):
            # Get the corresponding round QA and history.
            ques = question[:, rnd, :].t()
            his = history[:, :rnd + 1, :].clone().view(-1, his_length).t()
            ans = answer[:, rnd, :].t()
            tans = answerT[:, rnd, :].t()
            wrong_ans = opt_answerT[:, rnd, :].clone().view(-1, ans_length).t()
            real_len = answerLen[:, rnd].long()
            wrong_len = opt_answerLen[:, rnd, :].clone().view(-1)

            ques_input = torch.LongTensor(ques.size()).cuda()
            ques_input.copy_(ques)
            his_input = torch.LongTensor(his.size()).cuda()
            his_input.copy_(his)
            ans_input = torch.LongTensor(ans.size()).cuda()
            ans_input.copy_(ans)
            ans_target = torch.LongTensor(tans.size()).cuda()
            ans_target.copy_(tans)
            wrong_ans_input = torch.LongTensor(wrong_ans.size()).cuda()
            wrong_ans_input.copy_(wrong_ans)

            batch_sample_idx = torch.zeros(batch_size, opt.neg_batch_sample,
                                           dtype=torch.long).cuda()
            sample_batch_neg(answerIdx[:, rnd], opt_answerIdx[:, rnd, :],
                             batch_sample_idx, opt.neg_batch_sample)

            # -----------------------------------------
            # update the Generator using MLE loss.
            # -----------------------------------------
            if opt.update_LM:
                ques_emb_g = netW_g(ques_input, format='index')
                his_emb_g = netW_g(his_input, format='index')
                ques_hidden1 = repackage_hidden_new(ques_hidden1, batch_size)
                hist_hidden1 = repackage_hidden_new(hist_hidden1, his_emb_g.size(1))
                featG, ques_hidden1 = netE_g(ques_emb_g, his_emb_g, img_input,
                                             ques_hidden1, hist_hidden1, rnd + 1)
                _, ques_hidden1 = netG(featG.view(1, -1, opt.ninp), ques_hidden1)
                # MLE loss for generator, normalized by non-pad token count.
                ans_emb = netW_g(ans_input)
                logprob, _ = netG(ans_emb, ques_hidden1)
                lm_loss = critLM(logprob, ans_target.view(-1, 1))
                lm_loss = lm_loss / torch.sum(ans_target.data.gt(0))
                # total loss = discriminator_loss + alpha*lm_loss
                lm_loss = opt.alpha * lm_loss
                netW_g.zero_grad()
                netG.zero_grad()
                netE_g.zero_grad()
                lm_loss.backward()
                optimizerLM.step()
                err_lm += lm_loss.data.item()
                err_lm_tmp += lm_loss.data.item()

            # Sample the answer using the Gumbel-softmax sampler.
            ques_emb_g = netW_g(ques_input, format='index')
            his_emb_g = netW_g(his_input, format='index')
            ques_hidden1 = repackage_hidden_new(ques_hidden1, batch_size)
            hist_hidden1 = repackage_hidden_new(hist_hidden1, his_emb_g.size(1))
            featG, ques_hidden1 = netE_g(ques_emb_g, his_emb_g, img_input,
                                         ques_hidden1, hist_hidden1, rnd + 1)
            _, ques_hidden1 = netG(featG.view(1, -1, opt.ninp), ques_hidden1)

            fake_onehot = []
            fake_idx = []
            noise_input = torch.FloatTensor(ans_length, batch_size,
                                            vocab_size + 1).cuda()
            noise_input.data.uniform_(0, 1)
            ans_sample = ans_input[0]
            for di in range(ans_length):
                ans_emb = netW_g(ans_sample, format='index')
                logprob, ques_hidden1 = netG(ans_emb.view(1, -1, opt.ninp),
                                             ques_hidden1)
                one_hot, idx = sampler(logprob, noise_input[di], opt.gumble_weight)
                fake_onehot.append(one_hot.view(1, -1, vocab_size + 1))
                fake_idx.append(idx)
                if di + 1 < ans_length:
                    ans_sample = idx

            # Convert the lists into tensors.
            fake_onehot = torch.cat(fake_onehot, 0)
            fake_idx = torch.cat(fake_idx, 0)
            # Effective length = position of the last EOS (vocab_size) token.
            fake_len.resize_(batch_size).fill_(ans_length - 1)
            for di in range(ans_length - 1, 0, -1):
                fake_len.masked_fill_(fake_idx.data[di].eq(vocab_size), di)

            # Build the fake mask: zero out everything past each sequence's
            # end, then blank the padded positions of fake_idx.
            # NOTE(review): ByteTensor masks are deprecated in modern torch
            # in favor of bool — confirm the targeted torch version.
            fake_mask = torch.ByteTensor(fake_idx.size()).cuda()
            fake_mask.resize_(fake_idx.size()).fill_(1)
            for b in range(batch_size):
                fake_mask.data[:fake_len[b] + 1, b] = 0
            fake_idx.masked_fill_(fake_mask, 0)
            fake_onehot = fake_onehot.view(-1, vocab_size + 1)

            ######################################
            # Discriminatively trained generative model.
            ######################################
            ques_emb_d = netW_d(ques_input, format='index')
            his_emb_d = netW_d(his_input, format='index')
            ques_hidden2 = repackage_hidden_new(ques_hidden2, batch_size)
            hist_hidden2 = repackage_hidden_new(hist_hidden2, his_emb_d.size(1))
            featD, _ = netE_d(ques_emb_d, his_emb_d, img_input,
                              ques_hidden2, hist_hidden2, rnd + 1)

            ans_real_emb = netW_d(ans_target, format='index')
            ans_fake_emb = netW_d(fake_onehot, format='onehot')
            ans_fake_emb = ans_fake_emb.view(ans_length, -1, opt.ninp)

            real_hidden = repackage_hidden_new(real_hidden, batch_size)
            fake_hidden = repackage_hidden_new(fake_hidden, batch_size)

            fake_feat = netD(ans_fake_emb, fake_idx, fake_hidden, vocab_size)
            real_feat = netD(ans_real_emb, ans_target, real_hidden, vocab_size)

            d_g_loss, g_fake = critG(featD, real_feat, fake_feat)
            netW_g.zero_grad()
            netG.zero_grad()
            netE_g.zero_grad()
            d_g_loss.backward()
            optimizerG.step()

            err_g += d_g_loss.data.item()
            err_g_tmp += d_g_loss.data.item()
            err_g_fake_tmp += g_fake
            count += 1
        i += 1
        loss_store.append({'iter': i, 'err_lm': err_lm_tmp / 10,
                           'err_d': err_d_tmp / 10, 'err_g': err_g_tmp / 10,
                           'd_fake': err_d_fake_tmp / 10,
                           'g_fake': err_g_fake_tmp / 10})
        if i % opt.log_interval == 0:
            print('Epoch:%d %d/%d, err_lm %4f, err_d %4f, err_g %4f, d_fake %4f, g_fake %4f'
                  % (epoch, i, len(dataloader), err_lm_tmp / 10, err_d_tmp / 10,
                     err_g_tmp / 10, err_d_fake_tmp / 10, err_g_fake_tmp / 10))
    err_g = err_g / count
    err_d = err_d / count
    err_lm = err_lm / count
    return err_lm, err_d, err_g, loss_store
def val():
    """Rank the 100 candidate answers with both the generator (by summed
    token NLL) and the discriminator (by dot-product score).

    Returns (rank_G, rank_D): the ground-truth rank lists under each model.
    """
    netE_g.eval()
    netE_d.eval()
    netW_g.eval()
    netW_d.eval()
    netG.eval()
    netD.eval()

    n_neg = 100
    ques_hidden1 = netE_g.init_hidden(opt.batchSize)
    ques_hidden2 = netE_d.init_hidden(opt.batchSize)
    hist_hidden1 = netE_d.init_hidden(opt.batchSize)
    hist_hidden2 = netE_g.init_hidden(opt.batchSize)
    opt_hidden = netD.init_hidden(opt.batchSize)

    data_iter_val = iter(dataloader_val)
    count = 0
    i = 0
    rank_G = []
    rank_D = []
    while i < len(dataloader_val):
        # FIX: use next() instead of Python-2-only .next().
        data = next(data_iter_val)
        image, history, question, answer, answerT, questionL, opt_answer, \
            opt_answerT, answer_ids, answerLen, opt_answerLen, img_id = data
        batch_size = question.size(0)
        image = image.view(-1, 512)
        with torch.no_grad():
            img_input.resize_(image.size()).copy_(image)
        for rnd in range(10):
            # Get the corresponding round QA and history.
            ques = question[:, rnd, :].t()
            his = history[:, :rnd + 1, :].clone().view(-1, his_length).t()
            opt_ans = opt_answer[:, rnd, :, :].clone().view(-1, ans_length).t()
            opt_tans = opt_answerT[:, rnd, :].clone().view(-1, ans_length).t()
            gt_id = answer_ids[:, rnd]
            opt_len = opt_answerLen[:, rnd, :].clone().view(-1)

            his_input = torch.LongTensor(his.size()).cuda()
            his_input.copy_(his)
            ques_input = torch.LongTensor(ques.size()).cuda()
            ques_input.copy_(ques)
            opt_ans_input = torch.LongTensor(opt_ans.size()).cuda()
            opt_ans_input.copy_(opt_ans)
            opt_ans_target = torch.LongTensor(opt_tans.size()).cuda()
            opt_ans_target.copy_(opt_tans)
            gt_index = torch.LongTensor(gt_id.size()).cuda()
            gt_index.copy_(gt_id)

            ques_emb_g = netW_g(ques_input, format='index')
            his_emb_g = netW_g(his_input, format='index')
            ques_emb_d = netW_d(ques_input, format='index')
            his_emb_d = netW_d(his_input, format='index')

            ques_hidden1 = repackage_hidden_new(ques_hidden1, batch_size)
            ques_hidden2 = repackage_hidden_new(ques_hidden2, batch_size)
            hist_hidden1 = repackage_hidden_new(hist_hidden1, his_emb_g.size(1))
            hist_hidden2 = repackage_hidden_new(hist_hidden2, his_emb_d.size(1))

            featG, ques_hidden1 = netE_g(ques_emb_g, his_emb_g, img_input,
                                         ques_hidden1, hist_hidden1, rnd + 1)
            featD, _ = netE_d(ques_emb_d, his_emb_d, img_input,
                              ques_hidden2, hist_hidden2, rnd + 1)

            # --- Evaluate the Generator ---
            _, ques_hidden1 = netG(featG.view(1, -1, opt.ninp), ques_hidden1)
            # Replicate the decoder hidden state 100x — one copy per option.
            hidden_replicated = []
            for hid in ques_hidden1:
                hidden_replicated.append(
                    hid.view(opt.nlayers, batch_size, 1, opt.nhid)
                       .expand(opt.nlayers, batch_size, 100, opt.nhid)
                       .clone().view(opt.nlayers, -1, opt.nhid))
            hidden_replicated = tuple(hidden_replicated)

            ans_emb = netW_g(opt_ans_input, format='index')
            output, _ = netG(ans_emb, hidden_replicated)
            logprob = -output
            logprob_select = torch.gather(logprob, 1, opt_ans_target.view(-1, 1))
            mask = opt_ans_target.data.eq(0)  # mask out padding tokens
            # NOTE(review): legacy pre-0.4 autograd path; `volatile` no
            # longer exists on modern tensors — confirm torch version.
            if isinstance(logprob, Variable):
                mask = Variable(mask, volatile=logprob.volatile)
            logprob_select.masked_fill_(mask.view_as(logprob_select), 0)
            prob = logprob_select.view(ans_length, -1, 100).sum(0).view(-1, 100)

            # Offset ground-truth indices into the flattened score vector.
            for b in range(batch_size):
                gt_index.data[b] = gt_index.data[b] + b * 100
            gt_score = prob.view(-1).index_select(0, gt_index)
            sort_score, sort_idx = torch.sort(prob, 1)
            count = sort_score.lt(gt_score.view(-1, 1).expand_as(sort_score))
            rank = count.sum(1) + 1
            rank_G += list(rank.view(-1).data.cpu().numpy())

            # --- Evaluate the Discriminator ---
            opt_ans_emb = netW_d(opt_ans_target, format='index')
            opt_hidden = repackage_hidden_new(opt_hidden, opt_ans_target.size(1))
            opt_feat = netD(opt_ans_emb, opt_ans_target, opt_hidden, vocab_size)
            opt_feat = opt_feat.view(batch_size, -1, opt.ninp)

            featD = featD.view(-1, opt.ninp, 1)
            score = torch.bmm(opt_feat, featD)
            score = score.view(-1, 100)
            gt_score = score.view(-1).index_select(0, gt_index)
            sort_score, sort_idx = torch.sort(score, 1, descending=True)
            count = sort_score.gt(gt_score.view(-1, 1).expand_as(sort_score))
            rank = count.sum(1) + 1
            rank_D += list(rank.view(-1).data.cpu().numpy())
        i += 1
    return rank_G, rank_D
def val():
    """Rank the 100 candidate answers with the generative model (CPU tensors)
    and return (rank_all_tmp, average_loss).

    Note: average_loss is never accumulated here and is returned as 0; kept
    for interface compatibility with callers.
    """
    netE.eval()
    netW.eval()
    netG.eval()
    data_iter_val = iter(dataloader_val)
    ques_hidden = netE.init_hidden(opt.batchSize)
    hist_hidden = netE.init_hidden(opt.batchSize)
    i = 0
    average_loss = 0
    rank_all_tmp = []
    while i < len(dataloader_val):
        # FIX: use next() instead of Python-2-only .next().
        data = next(data_iter_val)
        image, history, question, answer, answerT, questionL, opt_answer, \
            opt_answerT, answer_ids, answerLen, opt_answerLen, img_id = data
        batch_size = question.size(0)
        image = image.view(-1, img_feat_size)
        with torch.no_grad():
            img_input.resize_(image.size()).copy_(image)
        for rnd in range(10):
            # Get the corresponding round QA and history.
            ques, tans = question[:, rnd, :].t(), \
                opt_answerT[:, rnd, :].clone().view(-1, ans_length).t()
            his = history[:, :rnd + 1, :].clone().view(-1, his_length).t()
            ans = opt_answer[:, rnd, :, :].clone().view(-1, ans_length).t()
            gt_id = answer_ids[:, rnd]

            his_input = torch.LongTensor(his.size())
            his_input.copy_(his)
            ques_input = torch.LongTensor(ques.size())
            ques_input.copy_(ques)
            ans_input = torch.LongTensor(ans.size())
            ans_input.copy_(ans)
            ans_target = torch.LongTensor(tans.size())
            ans_target.copy_(tans)
            gt_index = torch.LongTensor(gt_id.size())
            gt_index.copy_(gt_id)

            ques_emb = netW(ques_input, format='index')
            his_emb = netW(his_input, format='index')
            ques_hidden = repackage_hidden_new(ques_hidden, batch_size)
            hist_hidden = repackage_hidden_new(hist_hidden, his_input.size(1))
            encoder_feat, ques_hidden = netE(ques_emb, his_emb, img_input,
                                             ques_hidden, hist_hidden, rnd + 1)
            _, ques_hidden = netG(encoder_feat.view(1, -1, opt.ninp), ques_hidden)

            # Replicate the decoder hidden state 100x — one copy per option.
            hidden_replicated = []
            for hid in ques_hidden:
                hidden_replicated.append(
                    hid.view(opt.nlayers, batch_size, 1, opt.nhid)
                       .expand(opt.nlayers, batch_size, 100, opt.nhid)
                       .clone().view(opt.nlayers, -1, opt.nhid))
            hidden_replicated = tuple(hidden_replicated)

            ans_emb = netW(ans_input, format='index')
            output, _ = netG(ans_emb, hidden_replicated)
            logprob = -output
            logprob_select = torch.gather(logprob, 1, ans_target.view(-1, 1))
            mask = ans_target.data.eq(0)  # mask out padding tokens
            # NOTE(review): legacy pre-0.4 autograd path; `volatile` no
            # longer exists on modern tensors — confirm torch version.
            if isinstance(logprob, Variable):
                mask = Variable(mask, volatile=logprob.volatile)
            logprob_select.masked_fill_(mask.view_as(logprob_select), 0)
            prob = logprob_select.view(ans_length, -1, 100).sum(0).view(-1, 100)

            # Offset ground-truth indices into the flattened score vector.
            for b in range(batch_size):
                gt_index.data[b] = gt_index.data[b] + b * 100
            gt_score = prob.view(-1).index_select(0, gt_index)
            sort_score, sort_idx = torch.sort(prob, 1)
            # rank = 1 + number of options with strictly lower NLL.
            count = sort_score.lt(gt_score.view(-1, 1).expand_as(sort_score))
            rank = count.sum(1) + 1
            rank_all_tmp += list(rank.view(-1).data.cpu().numpy())
        i += 1
    return rank_all_tmp, average_loss
def train(epoch):
    """Run one generator training epoch over the dialog dataloader.

    For each batch and each of the 10 dialog rounds: encode question/history/
    image with netE, feed the encoder feature into netG, then train netG to
    predict the answer tokens with critG. Uses module-level globals netW /
    netE / netG, critG, optimizer, dataloader, opt, img_input, img_feat_size,
    his_length and repackage_hidden_new.

    Args:
        epoch: current epoch index, used only for LR adjustment and logging.

    Returns:
        (average_loss, lr): running average loss since the last log line
        (reset at every log interval) and the learning rate used.
    """
    netW.train()
    netE.train()
    netG.train()

    lr = adjust_learning_rate(optimizer, epoch, opt.lr)

    data_iter = iter(dataloader)

    ques_hidden = netE.init_hidden(opt.batchSize)
    hist_hidden = netE.init_hidden(opt.batchSize)

    average_loss = 0
    count = 0
    i = 0
    while i < len(dataloader):
        # FIX: iterator.next() does not exist in Python 3; use the builtin.
        data = next(data_iter)
        image, history, question, answer, answerT, answerLen, answerIdx, \
            questionL, negAnswer, negAnswerLen, negAnswerIdx = data

        batch_size = question.size(0)
        image = image.view(-1, img_feat_size)
        with torch.no_grad():
            img_input.resize_(image.size()).copy_(image)

        for rnd in range(10):
            # Get the corresponding round QA and history.
            ques = question[:, rnd, :].t()
            his = history[:, :rnd + 1, :].clone().view(-1, his_length).t()
            ans, tans = answer[:, rnd, :].t(), answerT[:, rnd, :].t()

            his_input = torch.LongTensor(his.size())
            his_input.copy_(his)
            ques_input = torch.LongTensor(ques.size())
            ques_input.copy_(ques)
            ans_input = torch.LongTensor(ans.size())
            ans_input.copy_(ans)
            ans_target = torch.LongTensor(tans.size())
            ans_target.copy_(tans)

            ques_emb = netW(ques_input, format='index')
            his_emb = netW(his_input, format='index')

            ques_hidden = repackage_hidden_new(ques_hidden, batch_size)
            hist_hidden = repackage_hidden_new(hist_hidden, his_input.size(1))

            encoder_feat, ques_hidden = netE(ques_emb, his_emb, img_input,
                                             ques_hidden, hist_hidden, rnd + 1)

            _, ques_hidden = netG(encoder_feat.view(1, -1, opt.ninp), ques_hidden)

            # NOTE(review): every other netW call passes format='index';
            # this one relies on netW's default — confirm it is 'index'.
            ans_emb = netW(ans_input)
            logprob, ques_hidden = netG(ans_emb, ques_hidden)

            loss = critG(logprob, ans_target.view(-1, 1))
            # Normalize by the number of non-padding target tokens.
            loss = loss / torch.sum(ans_target.data.gt(0))
            average_loss += loss.data.item()

            # Backward and step per round.
            netW.zero_grad()
            netE.zero_grad()
            netG.zero_grad()
            loss.backward()
            optimizer.step()

            count += 1

        i += 1
        if i % opt.log_interval == 0:
            average_loss /= count
            print("step {} / {} (epoch {}), g_loss {:.3f}, lr = {:.6f}"
                  .format(i, len(dataloader), epoch, average_loss, lr))
            average_loss = 0
            count = 0

    return average_loss, lr
def eval():
    """Rank the 100 candidate answers per dialog round on the validation set,
    printing running retrieval metrics (MRR, R@1/5/10, mean rank) every 50
    batches.

    Same scoring scheme as val() but keeps the input tensors on CPU and uses
    a hard-coded 512 image-feature size. Keeps the historical function name,
    which shadows the builtin `eval` — callers depend on it.

    Returns:
        rank_all_tmp: list of ground-truth ranks, one per dialog round.
    """
    netE.eval()
    netW.eval()
    netG.eval()

    data_iter_val = iter(dataloader_val)

    ques_hidden = netE.init_hidden(opt.batchSize)
    hist_hidden = netE.init_hidden(opt.batchSize)

    i = 0
    rank_all_tmp = []

    while i < len(dataloader_val):
        # FIX: iterator.next() does not exist in Python 3; use the builtin.
        data = next(data_iter_val)
        image, history, question, answer, answerT, questionL, opt_answer, \
            opt_answerT, answer_ids, answerLen, opt_answerLen, img_id = data

        batch_size = question.size(0)
        # NOTE(review): feature size is hard-coded to 512 here while
        # val()/train() use img_feat_size — confirm they agree.
        image = image.view(-1, 512)
        with torch.no_grad():
            img_input.resize_(image.size()).copy_(image)

        for rnd in range(10):
            # Get the corresponding round QA and history.
            ques = question[:, rnd, :].t()
            tans = opt_answerT[:, rnd, :].clone().view(-1, ans_length).t()
            his = history[:, :rnd + 1, :].clone().view(-1, his_length).t()
            ans = opt_answer[:, rnd, :, :].clone().view(-1, ans_length).t()
            gt_id = answer_ids[:, rnd]

            his_input = torch.LongTensor(his.size()).cpu()
            his_input.copy_(his)
            ques_input = torch.LongTensor(ques.size()).cpu()
            ques_input.copy_(ques)
            ans_input = torch.LongTensor(ans.size()).cpu()
            ans_input.copy_(ans)
            ans_target = torch.LongTensor(tans.size()).cpu()
            ans_target.copy_(tans)
            gt_index = torch.LongTensor(gt_id.size()).cpu()
            gt_index.copy_(gt_id)

            ques_emb = netW(ques_input, format='index')
            his_emb = netW(his_input, format='index')

            ques_hidden = repackage_hidden_new(ques_hidden, batch_size)
            hist_hidden = repackage_hidden_new(hist_hidden, his_input.size(1))

            encoder_feat, ques_hidden = netE(ques_emb, his_emb, img_input,
                                             ques_hidden, hist_hidden, rnd + 1)

            _, ques_hidden = netG(encoder_feat.view(1, -1, opt.ninp), ques_hidden)

            # Extend the hidden state: replicate 100x so each candidate
            # answer is decoded against the same encoder context.
            # (Removed the never-used `ans_score` scratch tensor.)
            hidden_replicated = []
            for hid in ques_hidden:
                hidden_replicated.append(
                    hid.view(opt.nlayers, batch_size, 1, opt.nhid)
                       .expand(opt.nlayers, batch_size, 100, opt.nhid)
                       .clone().view(opt.nlayers, -1, opt.nhid))
            hidden_replicated = tuple(hidden_replicated)

            ans_emb = netW(ans_input, format='index')
            output, _ = netG(ans_emb, hidden_replicated)

            logprob = -output
            logprob_select = torch.gather(logprob, 1, ans_target.view(-1, 1))

            mask = ans_target.data.eq(0)  # generate the mask for padding
            # FIX: dropped the removed `Variable(mask, volatile=...)` wrap —
            # masked_fill_ accepts a plain bool tensor directly.
            logprob_select.masked_fill_(mask.view_as(logprob_select), 0)

            # One summed NLL score per candidate: (batch, 100).
            prob = logprob_select.view(ans_length, -1, 100).sum(0).view(-1, 100)

            for b in range(batch_size):
                gt_index.data[b] = gt_index.data[b] + b * 100
            gt_score = prob.view(-1).index_select(0, gt_index)

            # Rank = 1 + candidates scoring strictly below the ground truth.
            sort_score, sort_idx = torch.sort(prob, 1)
            count = sort_score.lt(gt_score.view(-1, 1).expand_as(sort_score))
            rank = count.sum(1) + 1
            rank_all_tmp += list(rank.view(-1).data.cpu().numpy())

        i += 1
        sys.stdout.write('Evaluating: {:d}/{:d} \r'
                         .format(i, len(dataloader_val)))
        if i % 50 == 0:
            # Hoisted the repeated np.array(rank_all_tmp) conversion.
            ranks = np.array(rank_all_tmp)
            n = float(len(rank_all_tmp))
            R1 = np.sum(ranks == 1) / n
            R5 = np.sum(ranks <= 5) / n
            R10 = np.sum(ranks <= 10) / n
            ave = np.sum(ranks) / n
            mrr = np.sum(1 / ranks.astype('float')) / n
            # FIX: the progress counter printed the constant 1; print i.
            print('%d/%d: mrr: %f R1: %f R5 %f R10 %f Mean %f'
                  % (i, len(dataloader_val), mrr, R1, R5, R10, ave))

    return rank_all_tmp
ques_hidden = netE.init_hidden(opt.batchSize) hist_hidden = netE.init_hidden(opt.batchSize) ind = his.size(1) his_input = torch.LongTensor(his.size()) his_input.copy_(his) ques_input = torch.LongTensor(ques.size()) ques_input.copy_(ques) ques_emb = netW(ques_input, format='index') his_emb = netW(his_input, format='index') ques_hidden = repackage_hidden_new(ques_hidden, 1) hist_hidden = repackage_hidden_new(hist_hidden, his_input.size(1)) encoder_feat, ques_hidden = netE(ques_emb, his_emb, img_input, \ ques_hidden, hist_hidden, ind) _, ques_hidden = netG(encoder_feat.view(1, -1, opt.ninp), ques_hidden) # generate ans based on ques_hidden # using netG(x , ques_hidden) # Gumble softmax to sample the output. ans_length = 16 fake_onehot = [] fake_idx = [] noise_input = torch.FloatTensor(
def train(epoch):
    """Run one discriminator training epoch over the dialog dataloader.

    For each batch and each of the 10 dialog rounds: encode question/history/
    image into featD with netE, embed the real answer and the wrong-answer
    options with netD, and optimize the n-pair loss critD. Uses module-level
    globals netW / netE / netD, critD, optimizer, dataloader, opt, img_input,
    img_feat_size, his_length, ans_length, vocab_size and
    repackage_hidden_new.

    Args:
        epoch: current epoch index, used only for LR adjustment and logging.

    Returns:
        (average_loss, lr): running average loss since the last log line
        (reset at every log interval) and the learning rate used.
    """
    netW.train()
    netE.train()
    netD.train()

    lr = adjust_learning_rate(optimizer, epoch, opt.lr)

    ques_hidden = netE.init_hidden(opt.batchSize)
    hist_hidden = netE.init_hidden(opt.batchSize)
    real_hidden = netD.init_hidden(opt.batchSize)
    wrong_hidden = netD.init_hidden(opt.batchSize)

    data_iter = iter(dataloader)

    average_loss = 0
    avg_dist_summary = np.zeros(3, dtype=float)
    smooth_avg_dist_summary = np.zeros(3, dtype=float)
    count = 0
    i = 0

    # Size of data to work on: cap the number of batches via opt.early_stop.
    early_stop = int(opt.early_stop / opt.batchSize)
    dataloader_size = min(len(dataloader), early_stop)

    while i < dataloader_size:
        # FIX: iterator.next() does not exist in Python 3; use the builtin.
        # (Also removed the unused `t1 = time.time()` timer.)
        data = next(data_iter)
        image, history, question, answer, answerT, answerLen, answerIdx, questionL, \
            opt_answerT, opt_answerLen, opt_answerIdx, opt_selected_probs = data

        batch_size = question.size(0)
        image = image.view(-1, img_feat_size)
        with torch.no_grad():
            img_input.resize_(image.size()).copy_(image)

        for rnd in range(10):
            netW.zero_grad()
            netE.zero_grad()
            netD.zero_grad()

            # Get the corresponding round QA and history.
            # (Removed unused locals real_len / wrong_len.)
            ques = question[:, rnd, :].t()
            his = history[:, :rnd + 1, :].clone().view(-1, his_length).t()
            ans = answer[:, rnd, :].t()
            tans = answerT[:, rnd, :].t()
            wrong_ans = opt_answerT[:, rnd, :].clone().view(-1, ans_length).t()
            opt_selected_probs_for_rnd = opt_selected_probs[:, rnd, :, :]

            ques_input = torch.LongTensor(ques.size()).cuda()
            ques_input.copy_(ques)
            his_input = torch.LongTensor(his.size()).cuda()
            his_input.copy_(his)
            ans_input = torch.LongTensor(ans.size()).cuda()
            ans_input.copy_(ans)
            ans_target = torch.LongTensor(tans.size()).cuda()
            ans_target.copy_(tans)
            wrong_ans_input = torch.LongTensor(wrong_ans.size()).cuda()
            wrong_ans_input.copy_(wrong_ans)
            opt_selected_probs_for_rnd_input = \
                torch.FloatTensor(opt_selected_probs_for_rnd.size()).cuda()
            opt_selected_probs_for_rnd_input.copy_(opt_selected_probs_for_rnd)

            # # sample in-batch negative index
            # batch_sample_idx = torch.zeros(batch_size, opt.neg_batch_sample, dtype=torch.long).cuda()
            # sample_batch_neg(answerIdx[:,rnd], opt_answerIdx[:,rnd,:], batch_sample_idx, opt.neg_batch_sample)

            ques_emb = netW(ques_input, format='index')
            his_emb = netW(his_input, format='index')

            ques_hidden = repackage_hidden_new(ques_hidden, batch_size)
            hist_hidden = repackage_hidden_new(hist_hidden, his_input.size(1))

            featD, ques_hidden = netE(ques_emb, his_emb, img_input,
                                      ques_hidden, hist_hidden, rnd + 1)

            ans_real_emb = netW(ans_target, format='index')
            ans_wrong_emb = netW(wrong_ans_input, format='index')

            real_hidden = repackage_hidden_new(real_hidden, batch_size)
            wrong_hidden = repackage_hidden_new(wrong_hidden, ans_wrong_emb.size(1))

            real_feat = netD(ans_real_emb, ans_target, real_hidden, vocab_size)
            wrong_feat = netD(ans_wrong_emb, wrong_ans_input, wrong_hidden, vocab_size)

            # batch_wrong_feat = wrong_feat.index_select(0, batch_sample_idx.view(-1))
            wrong_feat = wrong_feat.view(batch_size, -1, opt.ninp)
            # batch_wrong_feat = batch_wrong_feat.view(batch_size, -1, opt.ninp)

            nPairLoss, dist_summary, smooth_dist_summary = \
                critD(featD, real_feat, wrong_feat, opt_selected_probs_for_rnd_input)

            average_loss += nPairLoss.data.item()
            avg_dist_summary += dist_summary.cpu().detach().numpy()
            smooth_avg_dist_summary += smooth_dist_summary.cpu().detach().numpy()

            nPairLoss.backward()
            optimizer.step()
            count += 1

        i += 1
        if i % opt.log_interval == 0:
            average_loss /= count
            # Normalize the distance summaries to distributions for logging.
            avg_dist_summary = avg_dist_summary / np.sum(avg_dist_summary)
            smooth_avg_dist_summary = smooth_avg_dist_summary / np.sum(smooth_avg_dist_summary)
            print("step {} / {} (epoch {}), g_loss {:.3f}, lr = {:.6f}, CEN dist: {}, CEN smooth: {}"
                  .format(i, len(dataloader), epoch, average_loss, lr,
                          avg_dist_summary, smooth_avg_dist_summary))
            average_loss = 0
            avg_dist_summary = np.zeros(3, dtype=float)
            smooth_avg_dist_summary = np.zeros(3, dtype=float)
            count = 0

    return average_loss, lr