def sample():
    embeder.eval()
    encoder.eval()
    decoder.eval()

    data_iter_val = iter(dataloader_val)
    hidden = encoder.init_hidden(opt.batchSize)

    i = 0
    while i < len(dataloader_val):
        data = data_iter_val.next()
        history, question, answer, answerT, questionL, opt_answer, \
            opt_answerT, answer_ids = data
        batch_size = question.size(0)

        for rnd in range(10):
            ques, tans = question[:, rnd, :].t(), answerT[:, rnd, :].t()
            #his = history[:,:rnd+1,:].clone().view(-1, his_length).t()
            ans = opt_answer[:, rnd, :, :].clone().view(-1, ans_length).t()
            gt_id = answer_ids[:, rnd]

            ques_input.data.resize_(ques.size()).copy_(ques)
            ans_input.data.resize_(ans.size()).copy_(ans)
            ans_target.data.resize_(tans.size()).copy_(tans)

            ques_emb = embeder(ques_input)
            hidden = repackage_hidden(hidden, batch_size)
            _, hidden = encoder(ques_emb, hidden)
            #output, hidden = decoder(ans_emb, hidden)
            #loss = crit(output, ans_target.view(-1, 1))
            #average_loss += loss.data[0]
            #count += 1

            ans_sample_result = torch.Tensor(ans_length, batch_size)
            ans_sample.data.resize_((1, batch_size)).fill_(n_words)

            # sample the result.
            noise_input.data.resize_(ans_length, batch_size, n_words + 1)
            noise_input.data.uniform_(0, 1)
            for t in range(ans_length):
                ans_sample_embed = embeder(ans_sample)
                output, hidden = decoder(ans_sample_embed, hidden)
                prob = -output
                #_, idx = torch.max(prob, 1)
                one_hot, idx = sampler(prob, noise_input[t], 0.5)
                ans_sample.data.copy_(idx.data)
                ans_sample_result[t].copy_(ans_sample.data)

            ans_sample_txt = decode_txt(itow, ans_sample_result)
            ans_txt = decode_txt(itow, tans)
            ques_txt = decode_txt(itow, questionL[:, rnd, :].t())
            for j in range(len(ans_txt)):
                print('Q: %s --A: %s --Sampled: %s' % (ques_txt[j], ans_txt[j], ans_sample_txt[j]))
            pdb.set_trace()
        i += 1
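# `repackage_hidden`, used by every function in this file, is defined
# elsewhere in the repository. A minimal sketch of what it plausibly does in
# this PyTorch 0.3-era code (an assumption, not the repository's actual
# helper): rebuild the RNN hidden state as zeroed Variables of the requested
# batch size, detaching it from the previous batch's graph. Assumes
# `from torch.autograd import Variable` is in scope, as the code above implies.
def repackage_hidden_sketch(h, batch_size):
    if isinstance(h, Variable):
        # detach from history and re-shape for the new batch.
        return Variable(h.data.resize_(h.size(0), batch_size, h.size(2)).zero_())
    else:
        # LSTM hidden states come as (h, c) tuples; recurse over them.
        return tuple(repackage_hidden_sketch(v, batch_size) for v in h)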
def val():
    netE.eval()
    netW.eval()
    netG.eval()

    data_iter_val = iter(dataloader_val)
    ques_hidden = netE.init_hidden(opt.batchSize)
    hist_hidden = netE.init_hidden(opt.batchSize)

    i = 0
    average_loss = 0
    rank_all_tmp = []
    while i < len(dataloader_val):
        data = data_iter_val.next()
        image, history, question, answer, answerT, questionL, opt_answer, \
            opt_answerT, answer_ids, answerLen, opt_answerLen, img_id = data
        batch_size = question.size(0)
        image = image.view(-1, img_feat_size)
        img_input.data.resize_(image.size()).copy_(image)

        for rnd in range(10):
            # get the corresponding round QA and history.
            ques, tans = question[:, rnd, :].t(), \
                opt_answerT[:, rnd, :].clone().view(-1, ans_length).t()
            his = history[:, :rnd + 1, :].clone().view(-1, his_length).t()
            ans = opt_answer[:, rnd, :, :].clone().view(-1, ans_length).t()
            gt_id = answer_ids[:, rnd]

            his_input.data.resize_(his.size()).copy_(his)
            ques_input.data.resize_(ques.size()).copy_(ques)
            ans_input.data.resize_(ans.size()).copy_(ans)
            ans_target.data.resize_(tans.size()).copy_(tans)
            gt_index.data.resize_(gt_id.size()).copy_(gt_id)

            ques_emb = netW(ques_input, format='index')
            his_emb = netW(his_input, format='index')

            ques_hidden = repackage_hidden(ques_hidden, batch_size)
            hist_hidden = repackage_hidden(hist_hidden, his_input.size(1))

            encoder_feat, ques_hidden = netE(ques_emb, his_emb, img_input,
                                             ques_hidden, hist_hidden, rnd + 1)
            _, ques_hidden = netG(encoder_feat.view(1, -1, opt.ninp), ques_hidden)

            # replicate the hidden state once per candidate so all 100
            # options are scored in a single decoder pass.
            hidden_replicated = []
            for hid in ques_hidden:
                hidden_replicated.append(
                    hid.view(opt.nlayers, batch_size, 1, opt.nhid)
                       .expand(opt.nlayers, batch_size, 100, opt.nhid)
                       .clone().view(opt.nlayers, -1, opt.nhid))
            hidden_replicated = tuple(hidden_replicated)

            ans_emb = netW(ans_input, format='index')
            output, _ = netG(ans_emb, hidden_replicated)
            logprob = -output
            logprob_select = torch.gather(logprob, 1, ans_target.view(-1, 1))

            mask = ans_target.data.eq(0)  # generate the mask
            if isinstance(logprob, Variable):
                mask = Variable(mask, volatile=logprob.volatile)
            logprob_select.masked_fill_(mask.view_as(logprob_select), 0)

            prob = logprob_select.view(ans_length, -1, 100).sum(0).view(-1, 100)

            for b in range(batch_size):
                gt_index.data[b] = gt_index.data[b] + b * 100

            # rank of the ground truth among the 100 candidates: one plus the
            # number of options that score strictly better.
            gt_score = prob.view(-1).index_select(0, gt_index)
            sort_score, sort_idx = torch.sort(prob, 1)
            count = sort_score.lt(gt_score.view(-1, 1).expand_as(sort_score))
            rank = count.sum(1) + 1
            rank_all_tmp += list(rank.view(-1).data.cpu().numpy())
        i += 1

    return rank_all_tmp, average_loss
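# Usage sketch: val() returns the raw rank of the ground-truth answer for
# every round; turning those ranks into retrieval metrics follows the same
# formulas the eval() functions later in this file print inline
# (R@1/5/10, mean rank, MRR).
def report_ranks(rank_all):
    ranks = np.array(rank_all, dtype='float')
    R1 = np.sum(ranks == 1) / float(len(ranks))
    R5 = np.sum(ranks <= 5) / float(len(ranks))
    R10 = np.sum(ranks <= 10) / float(len(ranks))
    ave = np.sum(ranks) / float(len(ranks))
    mrr = np.sum(1.0 / ranks) / float(len(ranks))
    print('mrr: %f R1: %f R5 %f R10 %f Mean %f' % (mrr, R1, R5, R10, ave))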
def train(epoch):
    netW.train()
    netE.train()
    netG.train()

    lr = adjust_learning_rate(optimizer, epoch, opt.lr)

    data_iter = iter(dataloader)
    ques_hidden = netE.init_hidden(opt.batchSize)
    hist_hidden = netE.init_hidden(opt.batchSize)

    average_loss = 0
    count = 0
    i = 0
    while i < len(dataloader):
        data = data_iter.next()
        image, history, question, answer, answerT, answerLen, answerIdx, \
            questionL, negAnswer, negAnswerLen, negAnswerIdx = data

        batch_size = question.size(0)
        image = image.view(-1, img_feat_size)
        img_input.data.resize_(image.size()).copy_(image)

        for rnd in range(10):
            ques = question[:, rnd, :].t()
            his = history[:, :rnd + 1, :].clone().view(-1, his_length).t()
            ans, tans = answer[:, rnd, :].t(), answerT[:, rnd, :].t()

            his_input.data.resize_(his.size()).copy_(his)
            ques_input.data.resize_(ques.size()).copy_(ques)
            ans_input.data.resize_(ans.size()).copy_(ans)
            ans_target.data.resize_(tans.size()).copy_(tans)

            ques_emb = netW(ques_input, format='index')
            his_emb = netW(his_input, format='index')

            ques_hidden = repackage_hidden(ques_hidden, batch_size)
            hist_hidden = repackage_hidden(hist_hidden, his_input.size(1))

            encoder_feat, ques_hidden = netE(ques_emb, his_emb, img_input,
                                             ques_hidden, hist_hidden, rnd + 1)
            _, ques_hidden = netG(encoder_feat.view(1, -1, opt.ninp), ques_hidden)

            ans_emb = netW(ans_input)
            logprob, ques_hidden = netG(ans_emb, ques_hidden)
            loss = critG(logprob, ans_target.view(-1, 1))
            loss = loss / torch.sum(ans_target.data.gt(0))
            average_loss += loss.data[0]

            # do backward.
            netW.zero_grad()
            netE.zero_grad()
            netG.zero_grad()
            loss.backward()
            optimizer.step()
            count += 1
        i += 1

        if i % opt.log_interval == 0:
            average_loss /= count
            print("step {} / {} (epoch {}), g_loss {:.3f}, lr = {:.6f}"
                  .format(i, len(dataloader), epoch, average_loss, lr))
            average_loss = 0
            count = 0

    return average_loss
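# `adjust_learning_rate` is called once per epoch but defined elsewhere. A
# plausible sketch (the decay factor and interval below are assumptions, not
# the repository's actual schedule): decay the base rate and push it into the
# optimizer's parameter groups, returning the value used for logging.
def adjust_learning_rate_sketch(optimizer, epoch, lr):
    lr = lr * (0.5 ** (epoch // 10))  # assumed: halve every 10 epochs
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr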
def val():
    netE.eval()
    netW.eval()
    netD.eval()

    n_neg = 100
    data_iter_val = iter(dataloader_val)
    ques_hidden = netE.init_hidden(opt.batchSize)
    hist_hidden = netE.init_hidden(opt.batchSize)
    opt_hidden = netD.init_hidden(opt.batchSize)

    i = 0
    average_loss = 0
    rank_all_tmp = []
    while i < len(dataloader_val):
        data = data_iter_val.next()
        image, history, question, answer, answerT, questionL, opt_answer, \
            opt_answerT, answer_ids, answerLen, opt_answerLen, img_id = data

        batch_size = question.size(0)
        image = image.view(-1, img_feat_size)
        #image = l2_norm(image)
        img_input.data.resize_(image.size()).copy_(image)

        for rnd in range(10):
            # get the corresponding round QA and history.
            ques = question[:, rnd, :].t()
            his = history[:, :rnd + 1, :].clone().view(-1, his_length).t()
            opt_ans = opt_answerT[:, rnd, :].clone().view(-1, ans_length).t()
            gt_id = answer_ids[:, rnd]

            ques_input.data.resize_(ques.size()).copy_(ques)
            his_input.data.resize_(his.size()).copy_(his)
            opt_ans_input.data.resize_(opt_ans.size()).copy_(opt_ans)
            gt_index.data.resize_(gt_id.size()).copy_(gt_id)
            opt_len = opt_answerLen[:, rnd, :].clone().view(-1)

            ques_emb = netW(ques_input, format='index')
            his_emb = netW(his_input, format='index')

            ques_hidden = repackage_hidden(ques_hidden, batch_size)
            hist_hidden = repackage_hidden(hist_hidden, his_input.size(1))

            featD, ques_hidden = netE(ques_emb, his_emb, img_input,
                                      ques_hidden, hist_hidden, rnd + 1)

            opt_ans_emb = netW(opt_ans_input, format='index')
            opt_hidden = repackage_hidden(opt_hidden, opt_ans_input.size(1))
            opt_feat = netD(opt_ans_emb, opt_ans_input, opt_hidden, vocab_size)
            opt_feat = opt_feat.view(batch_size, -1, opt.ninp)

            #ans_emb = ans_emb.view(ans_length, -1, 100, opt.nhid)
            featD = featD.view(-1, opt.ninp, 1)
            score = torch.bmm(opt_feat, featD)
            score = score.view(-1, 100)

            for b in range(batch_size):
                gt_index.data[b] = gt_index.data[b] + b * 100

            gt_score = score.view(-1).index_select(0, gt_index)
            sort_score, sort_idx = torch.sort(score, 1, descending=True)
            count = sort_score.gt(gt_score.view(-1, 1).expand_as(sort_score))
            rank = count.sum(1) + 1
            rank_all_tmp += list(rank.view(-1).data.cpu().numpy())
        i += 1

        sys.stdout.write('Evaluating: {:d}/{:d} \r'.format(i, len(dataloader_val)))
        sys.stdout.flush()

    return rank_all_tmp
def train(epoch):
    netW.train()
    netE.train()
    netD.train()

    lr = adjust_learning_rate(optimizer, epoch, opt.lr)
    ques_hidden = netE.init_hidden(opt.batchSize)
    hist_hidden = netE.init_hidden(opt.batchSize)
    real_hidden = netD.init_hidden(opt.batchSize)
    wrong_hidden = netD.init_hidden(opt.batchSize)

    data_iter = iter(dataloader)

    average_loss = 0
    count = 0
    i = 0
    while i < len(dataloader):
        t1 = time.time()
        data = data_iter.next()
        image, history, question, answer, answerT, answerLen, answerIdx, questionL, \
            opt_answerT, opt_answerLen, opt_answerIdx = data

        batch_size = question.size(0)
        image = image.view(-1, img_feat_size)
        img_input.data.resize_(image.size()).copy_(image)

        for rnd in range(10):
            netW.zero_grad()
            netE.zero_grad()
            netD.zero_grad()

            # get the corresponding round QA and history.
            ques = question[:, rnd, :].t()
            his = history[:, :rnd + 1, :].clone().view(-1, his_length).t()
            ans = answer[:, rnd, :].t()
            tans = answerT[:, rnd, :].t()
            wrong_ans = opt_answerT[:, rnd, :].clone().view(-1, ans_length).t()

            real_len = answerLen[:, rnd]
            wrong_len = opt_answerLen[:, rnd, :].clone().view(-1)

            ques_input.data.resize_(ques.size()).copy_(ques)
            his_input.data.resize_(his.size()).copy_(his)
            ans_input.data.resize_(ans.size()).copy_(ans)
            ans_target.data.resize_(tans.size()).copy_(tans)
            wrong_ans_input.data.resize_(wrong_ans.size()).copy_(wrong_ans)

            # sample in-batch negative index
            batch_sample_idx.data.resize_(batch_size, opt.neg_batch_sample).zero_()
            sample_batch_neg(answerIdx[:, rnd], opt_answerIdx[:, rnd, :],
                             batch_sample_idx, opt.neg_batch_sample)

            ques_emb = netW(ques_input, format='index')
            his_emb = netW(his_input, format='index')

            ques_hidden = repackage_hidden(ques_hidden, batch_size)
            hist_hidden = repackage_hidden(hist_hidden, his_input.size(1))

            featD, ques_hidden = netE(ques_emb, his_emb, img_input,
                                      ques_hidden, hist_hidden, rnd + 1)

            ans_real_emb = netW(ans_target, format='index')
            ans_wrong_emb = netW(wrong_ans_input, format='index')

            real_hidden = repackage_hidden(real_hidden, batch_size)
            wrong_hidden = repackage_hidden(wrong_hidden, ans_wrong_emb.size(1))

            real_feat = netD(ans_real_emb, ans_target, real_hidden, vocab_size)
            wrong_feat = netD(ans_wrong_emb, wrong_ans_input, wrong_hidden, vocab_size)

            batch_wrong_feat = wrong_feat.index_select(0, batch_sample_idx.view(-1))
            wrong_feat = wrong_feat.view(batch_size, -1, opt.ninp)
            batch_wrong_feat = batch_wrong_feat.view(batch_size, -1, opt.ninp)

            nPairLoss = critD(featD, real_feat, wrong_feat, batch_wrong_feat)

            average_loss += nPairLoss.data[0]
            nPairLoss.backward()
            optimizer.step()
            count += 1
        i += 1

        if i % opt.log_interval == 0:
            average_loss /= count
            print("step {} / {} (epoch {}), g_loss {:.3f}, lr = {:.6f}"
                  .format(i, len(dataloader), epoch, average_loss, lr))
            average_loss = 0
            count = 0

    return average_loss
def train(epoch):
    netW_d.train(), netE_d.train(), netE_g.train()
    netD.train(), netG.train(), netW_g.train()

    fake_len = torch.LongTensor(opt.batchSize)
    fake_len = fake_len.cuda()

    n_neg = opt.negative_sample
    ques_hidden1 = netE_d.init_hidden(opt.batchSize)
    ques_hidden2 = netE_g.init_hidden(opt.batchSize)
    hist_hidden1 = netE_d.init_hidden(opt.batchSize)
    hist_hidden2 = netE_g.init_hidden(opt.batchSize)

    real_hidden = netD.init_hidden(opt.batchSize)
    wrong_hidden = netD.init_hidden(opt.batchSize)
    fake_hidden = netD.init_hidden(opt.batchSize)

    data_iter = iter(dataloader)

    err_d = 0
    err_g = 0
    err_lm = 0
    average_loss = 0
    count = 0
    i = 0
    loss_store = []
    while i < len(dataloader):
        t1 = time.time()
        data = data_iter.next()
        image, history, question, answer, answerT, answerLen, answerIdx, questionL, \
            opt_answerT, opt_answerLen, opt_answerIdx = data

        batch_size = question.size(0)
        image = image.view(-1, 512)
        img_input.data.resize_(image.size()).copy_(image)

        err_d_tmp = 0
        err_g_tmp = 0
        err_lm_tmp = 0
        err_d_fake_tmp = 0
        err_g_fake_tmp = 0
        for rnd in range(10):
            # get the corresponding round QA and history.
            ques = question[:, rnd, :].t()
            his = history[:, :rnd + 1, :].clone().view(-1, his_length).t()

            ans = answer[:, rnd, :].t()
            tans = answerT[:, rnd, :].t()
            wrong_ans = opt_answerT[:, rnd, :].clone().view(-1, ans_length).t()

            real_len = answerLen[:, rnd].long()
            wrong_len = opt_answerLen[:, rnd, :].clone().view(-1)

            ques_input.data.resize_(ques.size()).copy_(ques)
            his_input.data.resize_(his.size()).copy_(his)
            ans_input.data.resize_(ans.size()).copy_(ans)
            ans_target.data.resize_(tans.size()).copy_(tans)
            wrong_ans_input.data.resize_(wrong_ans.size()).copy_(wrong_ans)

            batch_sample_idx.data.resize_(batch_size, opt.neg_batch_sample).zero_()
            sample_batch_neg(answerIdx[:, rnd], opt_answerIdx[:, rnd, :],
                             batch_sample_idx, opt.neg_batch_sample)

            # -----------------------------------------
            # update the Generator using MLE loss.
            # -----------------------------------------
            if opt.update_LM:
                ques_emb_g = netW_g(ques_input, format='index')
                his_emb_g = netW_g(his_input, format='index')

                ques_hidden1 = repackage_hidden(ques_hidden1, batch_size)
                hist_hidden1 = repackage_hidden(hist_hidden1, his_emb_g.size(1))

                featG, ques_hidden1 = netE_g(ques_emb_g, his_emb_g, img_input,
                                             ques_hidden1, hist_hidden1, rnd + 1)
                _, ques_hidden1 = netG(featG.view(1, -1, opt.ninp), ques_hidden1)

                # MLE loss for generator
                ans_emb = netW_g(ans_input)
                logprob, _ = netG(ans_emb, ques_hidden1)
                lm_loss = critLM(logprob, ans_target.view(-1, 1))
                lm_loss = lm_loss / torch.sum(ans_target.data.gt(0))

                netW_g.zero_grad()
                netG.zero_grad()
                netE_g.zero_grad()
                lm_loss.backward()
                optimizerLM.step()

                err_lm += lm_loss.data[0]
                err_lm_tmp += lm_loss.data[0]

            # sample the answer using gumble softmax sampler.
            ques_emb_g = netW_g(ques_input, format='index')
            his_emb_g = netW_g(his_input, format='index')

            ques_hidden1 = repackage_hidden(ques_hidden1, batch_size)
            hist_hidden1 = repackage_hidden(hist_hidden1, his_emb_g.size(1))

            featG, ques_hidden1 = netE_g(ques_emb_g, his_emb_g, img_input,
                                         ques_hidden1, hist_hidden1, rnd + 1)
            _, ques_hidden1 = netG(featG.view(1, -1, opt.ninp), ques_hidden1)

            # Gumble softmax to sample the output.
            fake_onehot = []
            fake_idx = []
            noise_input.data.resize_(ans_length, batch_size, vocab_size + 1)
            noise_input.data.uniform_(0, 1)

            ans_sample = ans_input[0]
            for di in range(ans_length):
                ans_emb = netW_g(ans_sample, format='index')
                logprob, ques_hidden1 = netG(ans_emb.view(1, -1, opt.ninp), ques_hidden1)
                one_hot, idx = sampler(logprob, noise_input[di], opt.gumble_weight)
                fake_onehot.append(one_hot.view(1, -1, vocab_size + 1))
                fake_idx.append(idx)
                if di + 1 < ans_length:
                    ans_sample = idx

            # convert the list into the tensor variable.
            fake_onehot = torch.cat(fake_onehot, 0)
            fake_idx = torch.cat(fake_idx, 0)

            fake_len.resize_(batch_size).fill_(ans_length - 1)
            for di in range(ans_length - 1, 0, -1):
                fake_len.masked_fill_(fake_idx.data[di].eq(vocab_size), di)

            # generate fake mask.
            fake_mask.data.resize_(fake_idx.size()).fill_(1)
            # get the real, wrong and fake index.
            for b in range(batch_size):
                fake_mask.data[:fake_len[b] + 1, b] = 0

            # apply the mask on the fake_idx.
            fake_idx.masked_fill_(fake_mask, 0)
            # get the fake diff mask.
            #fake_diff_mask = torch.sum(fake_idx == ans_target, 0) != 0
            fake_onehot = fake_onehot.view(-1, vocab_size + 1)

            ######################################
            # Discriminative trained generative model.
            ######################################
            # forward the discriminator again.
            ques_emb_d = netW_d(ques_input, format='index')
            his_emb_d = netW_d(his_input, format='index')

            ques_hidden2 = repackage_hidden(ques_hidden2, batch_size)
            hist_hidden2 = repackage_hidden(hist_hidden2, his_emb_d.size(1))

            featD, _ = netE_d(ques_emb_d, his_emb_d, img_input,
                              ques_hidden2, hist_hidden2, rnd + 1)

            ans_real_emb = netW_d(ans_target, format='index')
            #ans_wrong_emb = netW_d(wrong_ans_input, format='index')
            ans_fake_emb = netW_d(fake_onehot, format='onehot')
            ans_fake_emb = ans_fake_emb.view(ans_length, -1, opt.ninp)

            real_hidden = repackage_hidden(real_hidden, batch_size)
            #wrong_hidden = repackage_hidden(wrong_hidden, ans_wrong_emb.size(1))
            fake_hidden = repackage_hidden(fake_hidden, batch_size)

            fake_feat = netD(ans_fake_emb, fake_idx, fake_hidden, vocab_size)
            real_feat = netD(ans_real_emb, ans_target, real_hidden, vocab_size)

            d_g_loss, g_fake = critG(featD, real_feat, fake_feat)  #, fake_diff_mask.detach())

            netW_g.zero_grad()
            netG.zero_grad()
            netE_g.zero_grad()
            d_g_loss.backward()
            optimizerG.step()

            err_g += d_g_loss.data[0]
            err_g_tmp += d_g_loss.data[0]
            err_g_fake_tmp += g_fake

            count += 1
        i += 1

        loss_store.append({'iter': i, 'err_lm': err_lm_tmp / 10, 'err_d': err_d_tmp / 10,
                           'err_g': err_g_tmp / 10, 'd_fake': err_d_fake_tmp / 10,
                           'g_fake': err_g_fake_tmp / 10})
        if i % 20 == 0:
            print('Epoch:%d %d/%d, err_lm %4f, err_d %4f, err_g %4f, d_fake %4f, g_fake %4f'
                  % (epoch, i, len(dataloader), err_lm_tmp / 10, err_d_tmp / 10,
                     err_g_tmp / 10, err_d_fake_tmp / 10, err_g_fake_tmp / 10))

    #average_loss = average_loss / count
    err_g = err_g / count
    err_d = err_d / count
    err_lm = err_lm / count

    return err_lm, err_d, err_g, loss_store
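# `sampler` above draws a discrete token while keeping the computation
# differentiable, which is what lets the discriminator's loss reach the
# generator. A minimal straight-through Gumbel-softmax sketch matching the
# call sites (sampler(logprob, uniform_noise, temperature)); the repository's
# actual implementation may differ. Assumes `logprob` and `noise` are 2D
# Variables and `torch.nn.functional` is imported as `F`.
def sampler_sketch(logprob, noise, temperature):
    eps = 1e-10
    # turn uniform (0, 1) noise into Gumbel(0, 1) noise.
    gumbel = -torch.log(-torch.log(noise.data + eps) + eps)
    # temperature-scaled relaxed one-hot over the vocabulary.
    y = F.softmax((logprob + Variable(gumbel)) / temperature, dim=1)
    _, idx = torch.max(y, 1)
    # straight-through estimator: discrete one-hot forward, soft gradient back.
    hard = y.data.clone().zero_().scatter_(1, idx.data.view(-1, 1), 1)
    one_hot = Variable(hard - y.data) + y
    return one_hot, idx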
def eval():
    netE.eval()
    netW.eval()
    netG.eval()

    data_iter_val = iter(dataloader_val)
    ques_hidden = netE.init_hidden(opt.batchSize)
    hist_hidden = netE.init_hidden(opt.batchSize)

    i = 0
    display_count = 0
    average_loss = 0
    rank_all_tmp = []
    while i < len(dataloader_val):
        data = data_iter_val.next()
        image, history, question, answer, answerT, questionL, opt_answer, \
            opt_answerT, answer_ids, answerLen, opt_answerLen, img_id = data

        batch_size = question.size(0)
        image = image.view(-1, 512)
        img_input.data.resize_(image.size()).copy_(image)

        for rnd in range(10):
            # get the corresponding round QA and history.
            ques, tans = question[:, rnd, :].t(), \
                opt_answerT[:, rnd, :].clone().view(-1, ans_length).t()
            his = history[:, :rnd + 1, :].clone().view(-1, his_length).t()
            ans = opt_answer[:, rnd, :, :].clone().view(-1, ans_length).t()
            gt_id = answer_ids[:, rnd]

            his_input.data.resize_(his.size()).copy_(his)
            ques_input.data.resize_(ques.size()).copy_(ques)
            ans_input.data.resize_(ans.size()).copy_(ans)
            ans_target.data.resize_(tans.size()).copy_(tans)
            gt_index.data.resize_(gt_id.size()).copy_(gt_id)

            ques_emb = netW(ques_input, format='index')
            his_emb = netW(his_input, format='index')

            ques_hidden = repackage_hidden(ques_hidden, batch_size)
            hist_hidden = repackage_hidden(hist_hidden, his_input.size(1))

            encoder_feat, ques_hidden = netE(ques_emb, his_emb, img_input,
                                             ques_hidden, hist_hidden, rnd + 1)

            _, ques_hidden = netG(encoder_feat.view(1, -1, opt.ninp), ques_hidden)

            #ans_emb = ans_emb.view(ans_length, -1, 100, opt.nhid)
            ans_score = torch.FloatTensor(batch_size, 100).zero_()

            # extend the hidden
            hidden_replicated = []
            for hid in ques_hidden:
                hidden_replicated.append(
                    hid.view(opt.nlayers, batch_size, 1, opt.nhid)
                       .expand(opt.nlayers, batch_size, 100, opt.nhid)
                       .clone().view(opt.nlayers, -1, opt.nhid))
            hidden_replicated = tuple(hidden_replicated)

            ans_emb = netW(ans_input, format='index')
            output, _ = netG(ans_emb, hidden_replicated)
            logprob = -output
            logprob_select = torch.gather(logprob, 1, ans_target.view(-1, 1))

            mask = ans_target.data.eq(0)  # generate the mask
            if isinstance(logprob, Variable):
                mask = Variable(mask, volatile=logprob.volatile)
            logprob_select.masked_fill_(mask.view_as(logprob_select), 0)

            prob = logprob_select.view(ans_length, -1, 100).sum(0).view(-1, 100)

            for b in range(batch_size):
                gt_index.data[b] = gt_index.data[b] + b * 100

            gt_score = prob.view(-1).index_select(0, gt_index)
            sort_score, sort_idx = torch.sort(prob, 1)
            count = sort_score.lt(gt_score.view(-1, 1).expand_as(sort_score))
            rank = count.sum(1) + 1
            rank_all_tmp += list(rank.view(-1).data.cpu().numpy())
        i += 1

        sys.stdout.write('Evaluating: {:d}/{:d} \r'.format(i, len(dataloader_val)))
        if i % 50 == 0:
            R1 = np.sum(np.array(rank_all_tmp) == 1) / float(len(rank_all_tmp))
            R5 = np.sum(np.array(rank_all_tmp) <= 5) / float(len(rank_all_tmp))
            R10 = np.sum(np.array(rank_all_tmp) <= 10) / float(len(rank_all_tmp))
            ave = np.sum(np.array(rank_all_tmp)) / float(len(rank_all_tmp))
            mrr = np.sum(1 / (np.array(rank_all_tmp, dtype='float'))) / float(len(rank_all_tmp))
            print('%d/%d: mrr: %f R1: %f R5 %f R10 %f Mean %f'
                  % (i, len(dataloader_val), mrr, R1, R5, R10, ave))

    return rank_all_tmp
def eval_val():
    netW.eval()
    netE.eval()
    netD.eval()

    data_iter_test = iter(dataloader_test)
    ques_hidden = netE.init_bi_hidden(opt.batchSize)
    hist_hidden = netE.init_bi_hidden(opt.batchSize)
    opt_hidden = netD.init_bi_hidden(opt.batchSize)

    i = 0
    display_count = 0
    average_loss = 0
    img_atten = torch.FloatTensor(100 * 30, 10, 7, 7)
    result = []
    while i < len(dataloader_test):  #len(1000):
        data = data_iter_test.next()
        image, history, question, answer, answerT, questionL, opt_answer, \
            opt_answerT, answer_ids, answerLen, opt_answerLen, img_id = data

        batch_size = question.size(0)
        image = image.view(-1, 2048)
        img_input.data.resize_(image.size()).copy_(image)

        # save_tmp = [[] for j in range(batch_size)]
        for rnd in range(10):
            result_tmp = {'image_id': 1, 'round_id': 1, 'ranks': []}
            # get the corresponding round QA and history.
            ques = question[:, rnd, :].t()
            his = history[:, :rnd + 1, :].clone().view(-1, his_length).t()
            opt_ans = opt_answerT[:, rnd, :].clone().view(-1, ans_length).t()
            # gt_id = answer_ids[:,rnd]

            ques_input.data.resize_(ques.size()).copy_(ques)
            his_input.data.resize_(his.size()).copy_(his)
            # gt_index.data.resize_(gt_id.size()).copy_(gt_id)
            opt_ans_input.data.resize_(opt_ans.size()).copy_(opt_ans)
            # opt_len = opt_answerLen[:,rnd,:].clone().view(-1)

            ques_emb = netW(ques_input, format='index')
            his_emb = netW(his_input, format='index')

            ques_hidden = repackage_hidden(ques_hidden, batch_size)
            hist_hidden = repackage_hidden(hist_hidden, his_input.size(1))

            featD, ques_hidden = netE(ques_emb, his_emb, img_input,
                                      ques_hidden, hist_hidden, rnd + 1)
            #img_atten[i*batch_size:(i+1)*batch_size, rnd, :] = img_atten_weight.data.view(batch_size, 7, 7)

            opt_ans_emb = netW(opt_ans_input, format='index')
            opt_hidden = repackage_hidden(opt_hidden, opt_ans_input.size(1))
            opt_feat = netD(opt_ans_emb, opt_ans_input, opt_hidden, vocab_size)
            opt_feat = opt_feat.view(batch_size, -1, opt.ninp)

            featD = featD.view(-1, opt.ninp, 1)
            score = torch.bmm(opt_feat, featD)
            score = score.view(-1, 100)

            sort_score, sort_idx = torch.sort(score, 1, descending=True)
            sort_idx = sort_idx + 1
            sort_idx = sort_idx.squeeze().cpu().data.type(torch.IntTensor)
            sort_result = list(sort_idx)
            sort_final = [1] * 100
            for rank, w in enumerate(sort_result):
                sort_final[w - 1] = rank + 1

            result_tmp['image_id'] = int(img_id)
            result_tmp['round_id'] = int(rnd + 1)
            result_tmp['ranks'] = sort_final
            # sort_result = [rank + '\n' for rank in sort_result]
            result.append(result_tmp)
        i += 1
        if i % 100 == 0:
            print(i)

    return result
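# Usage sketch: eval_val() returns one {image_id, round_id, ranks} record per
# round, which matches the VisDial-style submission layout, so it can be
# serialized directly. The helper name and file name below are illustrative.
def save_ranks_sketch(result, path='visdial_ranks.json'):
    import json
    with open(path, 'w') as f:
        json.dump(result, f)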
def val():
    netE_g.eval()
    netW_g.eval()
    netG.eval()

    n_neg = 100
    ques_hidden1 = netE_g.init_hidden(opt.batchSize)
    hist_hidden1 = netE_g.init_hidden(opt.batchSize)

    bar = progressbar.ProgressBar(max_value=300)  #len(dataloader))
    data_iter_val = iter(dataloader_val)

    count = 0
    i = 0
    result_all = []
    while i < 300:  #len(dataloader_val):
        data = data_iter_val.next()
        image, history, question, answer, answerT, questionL, opt_answer, \
            opt_answerT, answer_ids, answerLen, opt_answerLen, img_id = data

        batch_size = question.size(0)
        image = image.view(-1, 512)
        img_input.data.resize_(image.size()).copy_(image)

        save_tmp = [[] for j in range(batch_size)]
        for rnd in range(10):
            # get the corresponding round QA and history.
            ques, tans = question[:, rnd, :].t(), answerT[:, rnd, :].t()
            his = history[:, :rnd + 1, :].clone().view(-1, his_length).t()

            opt_ans = opt_answer[:, rnd, :, :].clone().view(-1, ans_length).t()
            opt_tans = opt_answerT[:, rnd, :].clone().view(-1, ans_length).t()
            gt_id = answer_ids[:, rnd]
            opt_len = opt_answerLen[:, rnd, :].clone().view(-1)

            ques_input.data.resize_(ques.size()).copy_(ques)
            his_input.data.resize_(his.size()).copy_(his)
            opt_ans_input.data.resize_(opt_ans.size()).copy_(opt_ans)
            opt_ans_target.data.resize_(opt_tans.size()).copy_(opt_tans)
            gt_index.data.resize_(gt_id.size()).copy_(gt_id)

            ques_emb_g = netW_g(ques_input, format='index')
            his_emb_g = netW_g(his_input, format='index')

            ques_hidden1 = repackage_hidden(ques_hidden1, batch_size)
            hist_hidden1 = repackage_hidden(hist_hidden1, his_emb_g.size(1))

            featG, ques_hidden1 = netE_g(ques_emb_g, his_emb_g, img_input,
                                         ques_hidden1, hist_hidden1, rnd + 1)
            #featD = l2_norm(featD)

            # Evaluate the Generator:
            _, ques_hidden1 = netG(featG.view(1, -1, opt.ninp), ques_hidden1)
            #_, ques_hidden = netG(encoder_feat.view(1,-1,opt.ninp), ques_hidden)

            # extend the hidden
            sample_ans_input.data.resize_((1, batch_size)).fill_(vocab_size)

            sample_opt = {'beam_size': 1}
            seq, seqLogprobs = netG.sample(netW_g, sample_ans_input, ques_hidden1, sample_opt)

            ans_sample_txt = decode_txt(itow, seq.t())
            ans_txt = decode_txt(itow, tans)
            ques_txt = decode_txt(itow, questionL[:, rnd, :].t())
            '''
            for j in range(len(ans_txt)):
                print('Q: %s --A: %s --Sampled: %s' % (ques_txt[j], ans_txt[j], ans_sample_txt[j]))

            ans_sample_z = [[] for z in range(batch_size)]
            for m in range(5):
                ans_sample_result = torch.Tensor(ans_length, batch_size)
                # sample the result.
                noise_input.data.resize_(ans_length, batch_size, vocab_size + 1)
                noise_input.data.uniform_(0, 1)
                for t in range(ans_length):
                    ans_emb = netW_g(sample_ans_input, format='index')
                    if t == 0:
                        logprob, ques_hidden2 = netG(ans_emb.view(1, -1, opt.ninp), ques_hidden1)
                    else:
                        logprob, ques_hidden2 = netG(ans_emb.view(1, -1, opt.ninp), ques_hidden2)
                    one_hot, idx = sampler(logprob, noise_input[t], opt.gumble_weight)
                    sample_ans_input.data.copy_(idx.data)
                    ans_sample_result[t].copy_(idx.data)

                ans_sample_txt = decode_txt(itow, ans_sample_result)
                for ii in range(batch_size):
                    ans_sample_z[ii].append(ans_sample_txt[ii])
            '''
            ans_txt = decode_txt(itow, tans)
            ques_txt = decode_txt(itow, questionL[:, rnd, :].t())
            #for j in range(len(ans_txt)):
            #    print('Q: %s --A: %s --Sampled: %s' % (ques_txt[j], ans_txt[j], ans_sample_txt[j]))
            for j in range(batch_size):
                save_tmp[j].append({'ques': ques_txt[j], 'gt_ans': ans_txt[j],
                                    'sample_ans': ans_sample_txt[j], 'rnd': rnd,
                                    'img_id': img_id[j]})
        i += 1
        bar.update(i)
        result_all += save_tmp

    return result_all
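# `decode_txt` maps a (seq_len, batch) tensor of token indices back to
# strings through the index-to-word table `itow`. A plausible sketch; the
# stop-token ids and the key type of `itow` (int vs. str) are assumptions:
def decode_txt_sketch(itow, x):
    out = []
    for b in range(x.size(1)):
        words = []
        for t in range(x.size(0)):
            idx = int(x[t][b])
            if idx == 0 or idx == len(itow) + 1:  # assumed <PAD> / <EOS> ids
                break
            words.append(itow[idx])  # itow may key on str(idx) in the repo
        out.append(' '.join(words))
    return out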
def val():
    netE_g.eval()
    netE_d.eval()
    netW_g.eval()
    netW_d.eval()
    netG.eval()
    netD.eval()

    n_neg = 100
    ques_hidden1 = netE_g.init_hidden(opt.batchSize)
    ques_hidden2 = netE_d.init_hidden(opt.batchSize)
    hist_hidden1 = netE_d.init_hidden(opt.batchSize)
    hist_hidden2 = netE_g.init_hidden(opt.batchSize)
    opt_hidden = netD.init_hidden(opt.batchSize)

    data_iter_val = iter(dataloader_val)

    count = 0
    i = 0
    rank_G = []
    rank_D = []
    while i < len(dataloader_val):
        data = data_iter_val.next()
        image, history, question, answer, answerT, questionL, opt_answer, \
            opt_answerT, answer_ids, answerLen, opt_answerLen, img_id = data

        batch_size = question.size(0)
        image = image.view(-1, 512)
        img_input.data.resize_(image.size()).copy_(image)

        for rnd in range(10):
            # get the corresponding round QA and history.
            ques = question[:, rnd, :].t()
            his = history[:, :rnd + 1, :].clone().view(-1, his_length).t()

            opt_ans = opt_answer[:, rnd, :, :].clone().view(-1, ans_length).t()
            opt_tans = opt_answerT[:, rnd, :].clone().view(-1, ans_length).t()
            gt_id = answer_ids[:, rnd]
            opt_len = opt_answerLen[:, rnd, :].clone().view(-1)

            ques_input.data.resize_(ques.size()).copy_(ques)
            his_input.data.resize_(his.size()).copy_(his)
            opt_ans_input.data.resize_(opt_ans.size()).copy_(opt_ans)
            opt_ans_target.data.resize_(opt_tans.size()).copy_(opt_tans)
            gt_index.data.resize_(gt_id.size()).copy_(gt_id)

            ques_emb_g = netW_g(ques_input, format='index')
            his_emb_g = netW_g(his_input, format='index')

            ques_emb_d = netW_d(ques_input, format='index')
            his_emb_d = netW_d(his_input, format='index')

            ques_hidden1 = repackage_hidden(ques_hidden1, batch_size)
            ques_hidden2 = repackage_hidden(ques_hidden2, batch_size)

            hist_hidden1 = repackage_hidden(hist_hidden1, his_emb_g.size(1))
            hist_hidden2 = repackage_hidden(hist_hidden2, his_emb_d.size(1))

            featG, ques_hidden1 = netE_g(ques_emb_g, his_emb_g, img_input,
                                         ques_hidden1, hist_hidden1, rnd + 1)

            featD, _ = netE_d(ques_emb_d, his_emb_d, img_input,
                              ques_hidden2, hist_hidden2, rnd + 1)
            #featD = l2_norm(featD)

            # Evaluate the Generator:
            _, ques_hidden1 = netG(featG.view(1, -1, opt.ninp), ques_hidden1)
            #_, ques_hidden = netG(encoder_feat.view(1,-1,opt.ninp), ques_hidden)

            # extend the hidden
            hidden_replicated = []
            for hid in ques_hidden1:
                hidden_replicated.append(
                    hid.view(opt.nlayers, batch_size, 1, opt.nhid)
                       .expand(opt.nlayers, batch_size, 100, opt.nhid)
                       .clone().view(opt.nlayers, -1, opt.nhid))
            hidden_replicated = tuple(hidden_replicated)

            ans_emb = netW_g(opt_ans_input, format='index')
            output, _ = netG(ans_emb, hidden_replicated)
            logprob = -output
            logprob_select = torch.gather(logprob, 1, opt_ans_target.view(-1, 1))

            mask = opt_ans_target.data.eq(0)  # generate the mask
            if isinstance(logprob, Variable):
                mask = Variable(mask, volatile=logprob.volatile)
            logprob_select.masked_fill_(mask.view_as(logprob_select), 0)

            prob = logprob_select.view(ans_length, -1, 100).sum(0).view(-1, 100)

            for b in range(batch_size):
                gt_index.data[b] = gt_index.data[b] + b * 100

            gt_score = prob.view(-1).index_select(0, gt_index)
            sort_score, sort_idx = torch.sort(prob, 1)
            count = sort_score.lt(gt_score.view(-1, 1).expand_as(sort_score))
            rank = count.sum(1) + 1
            rank_G += list(rank.view(-1).data.cpu().numpy())

            opt_ans_emb = netW_d(opt_ans_target, format='index')
            opt_hidden = repackage_hidden(opt_hidden, opt_ans_target.size(1))
            opt_feat = netD(opt_ans_emb, opt_ans_target, opt_hidden, vocab_size)
            opt_feat = opt_feat.view(batch_size, -1, opt.ninp)

            #ans_emb = ans_emb.view(ans_length, -1, 100, opt.nhid)
            featD = featD.view(-1, opt.ninp, 1)
            score = torch.bmm(opt_feat, featD)
            score = score.view(-1, 100)

            gt_score = score.view(-1).index_select(0, gt_index)
            sort_score, sort_idx = torch.sort(score, 1, descending=True)
            count = sort_score.gt(gt_score.view(-1, 1).expand_as(sort_score))
            rank = count.sum(1) + 1
            rank_D += list(rank.view(-1).data.cpu().numpy())
        i += 1

    return rank_G, rank_D
def val():
    netE_g.eval()
    netE_d.eval()
    netW_g.eval()
    netW_d.eval()
    netG.eval()
    netD.eval()

    n_neg = 100
    ques_hidden1 = netE_g.init_hidden(opt.batchSize)
    ques_hidden2 = netE_d.init_hidden(opt.batchSize)
    hist_hidden1 = netE_g.init_hidden(opt.batchSize)
    hist_hidden2 = netE_d.init_hidden(opt.batchSize)
    opt_hidden = netD.init_hidden(opt.batchSize)

    data_iter_val = iter(dataloader_val)

    count = 0
    i = 0
    rank_G = []
    rank_D = []
    while i < len(dataloader_val):
        data = data_iter_val.next()
        image, history, question, answer, answerT, questionL, opt_answer, \
            opt_answerT, answer_ids, answerLen, opt_answerLen, img_id = data

        batch_size = question.size(0)
        image = image.view(-1, 512)
        img_input.data.resize_(image.size()).copy_(image)

        for rnd in range(10):
            # get the corresponding round QA and history.
            ques = question[:, rnd, :].t()
            his = history[:, :rnd + 1, :].clone().view(-1, his_length).t()

            opt_ans = opt_answer[:, rnd, :, :].clone().view(-1, ans_length).t()
            opt_tans = opt_answerT[:, rnd, :].clone().view(-1, ans_length).t()
            gt_id = answer_ids[:, rnd]
            opt_len = opt_answerLen[:, rnd, :].clone().view(-1)

            ques_input.data.resize_(ques.size()).copy_(ques)
            his_input.data.resize_(his.size()).copy_(his)
            opt_ans_input.data.resize_(opt_ans.size()).copy_(opt_ans)
            opt_ans_target.data.resize_(opt_tans.size()).copy_(opt_tans)
            gt_index.data.resize_(gt_id.size()).copy_(gt_id)

            ques_emb_g = netW_g(ques_input, format='index')
            his_emb_g = netW_g(his_input, format='index')

            ques_emb_d = netW_d(ques_input, format='index')
            his_emb_d = netW_d(his_input, format='index')

            ques_hidden1 = repackage_hidden(ques_hidden1, batch_size)
            ques_hidden2 = repackage_hidden(ques_hidden2, batch_size)

            hist_hidden1 = repackage_hidden(hist_hidden1, his_emb_g.size(1))
            hist_hidden2 = repackage_hidden(hist_hidden2, his_emb_d.size(1))

            featG, ques_hidden1 = netE_g(ques_emb_g, his_emb_g, img_input,
                                         ques_hidden1, hist_hidden1, rnd + 1)

            featD, _ = netE_d(ques_emb_d, his_emb_d, img_input,
                              ques_hidden2, hist_hidden2, rnd + 1)
            #featD = l2_norm(featD)

            _, ques_hidden1 = netG(featG.view(1, -1, opt.ninp), ques_hidden1)

            hidden_replicated = []
            for hid in ques_hidden1:
                hidden_replicated.append(
                    hid.view(opt.nlayers, batch_size, 1, opt.nhid)
                       .expand(opt.nlayers, batch_size, 100, opt.nhid)
                       .clone().view(opt.nlayers, -1, opt.nhid))
            hidden_replicated = tuple(hidden_replicated)

            ans_emb = netW_g(opt_ans_input, format='index')
            output, _ = netG(ans_emb, hidden_replicated)
            logprob = -output
            logprob_select = torch.gather(logprob, 1, opt_ans_target.view(-1, 1))

            mask = opt_ans_target.data.eq(0)  # generate the mask
            if isinstance(logprob, Variable):
                mask = Variable(mask, volatile=logprob.volatile)
            logprob_select.masked_fill_(mask.view_as(logprob_select), 0)

            prob = logprob_select.view(ans_length, -1, 100).sum(0).view(-1, 100)

            for b in range(batch_size):
                gt_index.data[b] = gt_index.data[b] + b * 100

            gt_score = prob.view(-1).index_select(0, gt_index)
            sort_score, sort_idx = torch.sort(prob, 1)
            count = sort_score.lt(gt_score.view(-1, 1).expand_as(sort_score))
            rank = count.sum(1) + 1
            rank_G += list(rank.view(-1).data.cpu().numpy())

            opt_ans_emb = netW_d(opt_ans_target, format='index')
            opt_hidden = repackage_hidden(opt_hidden, opt_ans_target.size(1))
            opt_feat = netD(opt_ans_emb, opt_ans_target, opt_hidden, vocab_size)
            opt_feat = opt_feat.view(batch_size, -1, opt.ninp)

            #ans_emb = ans_emb.view(ans_length, -1, 100, opt.nhid)
            featD = featD.view(-1, opt.ninp, 1)
            score = torch.bmm(opt_feat, featD)
            score = score.view(-1, 100)

            gt_score = score.view(-1).index_select(0, gt_index)
            sort_score, sort_idx = torch.sort(score, 1, descending=True)
            count = sort_score.gt(gt_score.view(-1, 1).expand_as(sort_score))
            rank = count.sum(1) + 1
            rank_D += list(rank.view(-1).data.cpu().numpy())
        i += 1

        if i % 50 == 0:
            R1 = np.sum(np.array(rank_G) == 1) / float(len(rank_G))
            R5 = np.sum(np.array(rank_G) <= 5) / float(len(rank_G))
            R10 = np.sum(np.array(rank_G) <= 10) / float(len(rank_G))
            ave = np.sum(np.array(rank_G)) / float(len(rank_G))
            mrr = np.sum(1 / (np.array(rank_G, dtype='float'))) / float(len(rank_G))
            print('%d/%d: mrr: %f R1: %f R5 %f R10 %f Mean %f'
                  % (i, len(dataloader_val), mrr, R1, R5, R10, ave))

    return rank_G, rank_D
def train(epoch):
    netW.train()
    netE.train()
    netD.train()

    lr = adjust_learning_rate(optimizer, epoch, opt.lr)

    ques_hidden = netE.init_hidden(opt.batchSize)
    hist_hidden = netE.init_hidden(opt.batchSize)

    real_hidden = netD.init_hidden(opt.batchSize)
    wrong_hidden = netD.init_hidden(opt.batchSize)

    data_iter = iter(dataloader)
    bar = progressbar.ProgressBar(maxval=len(dataloader))

    average_loss = 0
    count = 0
    i = 0
    while i < len(dataloader):
        t1 = time.time()
        data = data_iter.next()
        image, history, question, answer, answerT, answerLen, answerIdx, questionL, \
            opt_answerT, opt_answerLen, opt_answerIdx = data

        batch_size = question.size(0)
        image = image.view(-1, img_feat_size)
        img_input.data.resize_(image.size()).copy_(image)

        for rnd in range(10):
            netW.zero_grad()
            netE.zero_grad()
            netD.zero_grad()

            # get the corresponding round QA and history.
            ques = question[:, rnd, :].t()
            his = history[:, :rnd + 1, :].clone().view(-1, his_length).t()

            ans = answer[:, rnd, :].t()
            tans = answerT[:, rnd, :].t()
            wrong_ans = opt_answerT[:, rnd, :].clone().view(-1, ans_length).t()

            real_len = answerLen[:, rnd]
            wrong_len = opt_answerLen[:, rnd, :].clone().view(-1)

            ques_input.data.resize_(ques.size()).copy_(ques)
            his_input.data.resize_(his.size()).copy_(his)
            ans_input.data.resize_(ans.size()).copy_(ans)
            ans_target.data.resize_(tans.size()).copy_(tans)
            wrong_ans_input.data.resize_(wrong_ans.size()).copy_(wrong_ans)

            # sample in-batch negative index
            batch_sample_idx.data.resize_(batch_size, opt.neg_batch_sample).zero_()
            sample_batch_neg(answerIdx[:, rnd], opt_answerIdx[:, rnd, :],
                             batch_sample_idx, opt.neg_batch_sample)

            ques_emb = netW(ques_input, format='index')
            his_emb = netW(his_input, format='index')

            ques_hidden = repackage_hidden(ques_hidden, batch_size)
            hist_hidden = repackage_hidden(hist_hidden, his_input.size(1))

            featD, ques_hidden = netE(ques_emb, his_emb, img_input,
                                      ques_hidden, hist_hidden, rnd + 1)

            ans_real_emb = netW(ans_target, format='index')
            ans_wrong_emb = netW(wrong_ans_input, format='index')

            real_hidden = repackage_hidden(real_hidden, batch_size)
            wrong_hidden = repackage_hidden(wrong_hidden, ans_wrong_emb.size(1))

            real_feat = netD(ans_real_emb, ans_target, real_hidden, vocab_size)
            wrong_feat = netD(ans_wrong_emb, wrong_ans_input, wrong_hidden, vocab_size)

            batch_wrong_feat = wrong_feat.index_select(0, batch_sample_idx.view(-1))
            wrong_feat = wrong_feat.view(batch_size, -1, opt.ninp)
            batch_wrong_feat = batch_wrong_feat.view(batch_size, -1, opt.ninp)

            nPairLoss = critD(featD, real_feat, wrong_feat, batch_wrong_feat)

            average_loss += nPairLoss.data[0]
            nPairLoss.backward()
            optimizer.step()
            count += 1

        bar.update(i)
        i += 1

    average_loss = average_loss / count
    return average_loss, lr
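# `critD` scores the fused question/history/image feature against the
# ground-truth answer, the curated wrong answers, and the in-batch negatives.
# A minimal n-pair objective consistent with those call sites — a sketch
# under stated assumptions, not the repository's exact criterion. Assumes
# `torch.nn` is imported as `nn` and inputs are (B, ninp) / (B, n_neg, ninp).
class NPairLossSketch(nn.Module):
    def __init__(self, ninp):
        super(NPairLossSketch, self).__init__()
        self.ninp = ninp

    def forward(self, feat, right, wrong, batch_wrong):
        feat = feat.view(-1, self.ninp, 1)                          # B x ninp x 1
        right_dis = torch.bmm(right.view(-1, 1, self.ninp), feat)   # B x 1 x 1
        wrong_dis = torch.bmm(wrong, feat)                          # B x n_wrong x 1
        batch_wrong_dis = torch.bmm(batch_wrong, feat)              # B x n_batch x 1
        # log(1 + sum_j exp(s_wrong_j - s_right)): push the right answer's
        # score above every negative's.
        wrong_score = torch.sum(torch.exp(wrong_dis - right_dis.expand_as(wrong_dis)), 1) \
            + torch.sum(torch.exp(batch_wrong_dis - right_dis.expand_as(batch_wrong_dis)), 1)
        return torch.sum(torch.log(wrong_score + 1)) / feat.size(0)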
def eval():
    netW.eval()
    netE.eval()
    netD.eval()

    data_iter_val = iter(dataloader_val)
    ques_hidden = netE.init_hidden(opt.batchSize)
    hist_hidden = netE.init_hidden(opt.batchSize)
    opt_hidden = netD.init_hidden(opt.batchSize)

    i = 0
    display_count = 0
    average_loss = 0
    rank_all_tmp = []
    result_all = []
    img_atten = torch.FloatTensor(100 * 30, 10, 7, 7)
    while i < len(dataloader_val):  #len(1000):
        data = data_iter_val.next()
        image, history, question, answer, answerT, questionL, opt_answer, \
            opt_answerT, answer_ids, answerLen, opt_answerLen, img_id = data

        batch_size = question.size(0)
        image = image.view(-1, 512)
        img_input.data.resize_(image.size()).copy_(image)

        save_tmp = [[] for j in range(batch_size)]
        for rnd in range(10):
            # get the corresponding round QA and history.
            ques = question[:, rnd, :].t()
            his = history[:, :rnd + 1, :].clone().view(-1, his_length).t()
            opt_ans = opt_answerT[:, rnd, :].clone().view(-1, ans_length).t()
            gt_id = answer_ids[:, rnd]

            ques_input.data.resize_(ques.size()).copy_(ques)
            his_input.data.resize_(his.size()).copy_(his)
            gt_index.data.resize_(gt_id.size()).copy_(gt_id)
            opt_ans_input.data.resize_(opt_ans.size()).copy_(opt_ans)
            opt_len = opt_answerLen[:, rnd, :].clone().view(-1)

            ques_emb = netW(ques_input, format='index')
            his_emb = netW(his_input, format='index')

            ques_hidden = repackage_hidden(ques_hidden, batch_size)
            hist_hidden = repackage_hidden(hist_hidden, his_input.size(1))

            featD, ques_hidden = netE(ques_emb, his_emb, img_input,
                                      ques_hidden, hist_hidden, rnd + 1)
            #img_atten[i*batch_size:(i+1)*batch_size, rnd, :] = img_atten_weight.data.view(batch_size, 7, 7)

            opt_ans_emb = netW(opt_ans_input, format='index')
            opt_hidden = repackage_hidden(opt_hidden, opt_ans_input.size(1))
            opt_feat = netD(opt_ans_emb, opt_ans_input, opt_hidden, n_words)
            opt_feat = opt_feat.view(batch_size, -1, opt.ninp)

            featD = featD.view(-1, opt.ninp, 1)
            score = torch.bmm(opt_feat, featD)
            score = score.view(-1, 100)

            for b in range(batch_size):
                gt_index.data[b] = gt_index.data[b] + b * 100

            gt_score = score.view(-1).index_select(0, gt_index)
            sort_score, sort_idx = torch.sort(score, 1, descending=True)
            count = sort_score.gt(gt_score.view(-1, 1).expand_as(sort_score))
            rank = count.sum(1) + 1
            rank_all_tmp += list(rank.view(-1).data.cpu().numpy())
        i += 1

        result_all += save_tmp
        if i % 50 == 0:
            R1 = np.sum(np.array(rank_all_tmp) == 1) / float(len(rank_all_tmp))
            R5 = np.sum(np.array(rank_all_tmp) <= 5) / float(len(rank_all_tmp))
            R10 = np.sum(np.array(rank_all_tmp) <= 10) / float(len(rank_all_tmp))
            ave = np.sum(np.array(rank_all_tmp)) / float(len(rank_all_tmp))
            mrr = np.sum(1 / (np.array(rank_all_tmp, dtype='float'))) / float(len(rank_all_tmp))
            print('%d/%d: mrr: %f R1: %f R5 %f R10 %f Mean %f'
                  % (i, len(dataloader_val), mrr, R1, R5, R10, ave))

    return img_atten
def eval():
    netW.eval()
    netE.eval()
    netD.eval()

    data_iter_val = iter(dataloader_val)
    ques_hidden = netE.init_hidden(opt.batchSize)
    hist_hidden = netE.init_hidden(opt.batchSize)
    opt_hidden = netD.init_hidden(opt.batchSize)

    i = 0
    ranks_json = []
    while i < len(dataloader_val):  #len(1000):
        data = data_iter_val.next()
        image, history, question, answer, answerT, questionL, opt_answer, \
            opt_answerT, answerLen, opt_answerLen, img_id, rounds = data

        batch_size = question.size(0)
        image = image.view(-1, 36, 2048)  # image: batch x 36 x 2048
        img_input.data.resize_(image.size()).copy_(image)
        #image2 = image2.view(-1, img_feat_size)  # image: 6272 (128x7x7) x 512
        #img_input2.data.resize_(image2.size()).copy_(image2)  # 6272 x 512

        save_tmp = [[] for j in range(batch_size)]

        rnd = int(rounds - 1)
        ques = question[:, rnd, :].t()
        his = history[:, :rnd + 1, :].clone().view(-1, his_length).t()
        opt_ans = opt_answerT[:, rnd, :].clone().view(-1, ans_length).t()

        ques_input.data.resize_(ques.size()).copy_(ques)
        his_input.data.resize_(his.size()).copy_(his)
        opt_ans_input.data.resize_(opt_ans.size()).copy_(opt_ans)
        opt_len = opt_answerLen[:, rnd, :].clone().view(-1)

        ques_emb = netW(ques_input, format='index')
        his_emb = netW(his_input, format='index')

        ques_hidden = repackage_hidden(ques_hidden, batch_size)
        hist_hidden = repackage_hidden(hist_hidden, his_input.size(1))

        featD, ques_hidden = netE(ques_emb, his_emb, img_input,
                                  ques_hidden, hist_hidden, rnd + 1)

        opt_ans_emb = netW(opt_ans_input, format='index')
        opt_hidden = repackage_hidden(opt_hidden, opt_ans_input.size(1))
        opt_feat = netD(opt_ans_emb, opt_ans_input, opt_hidden, n_words)
        opt_feat = opt_feat.view(batch_size, -1, opt.ninp)

        featD = featD.view(-1, opt.ninp, 1)
        score = torch.bmm(opt_feat, featD)
        score = score.view(-1, 100)

        ranks = scores_to_ranks(score)
        ranks_json.append({
            "image_id": int(img_id),
            "round_id": int(rounds),
            "ranks": ranks.view(-1).tolist()
        })
        i += 1

        sys.stdout.write('Evaluating: {:d}/{:d} \r'.format(i, len(dataloader_val)))
        sys.stdout.flush()
        #pdb.set_trace()

    return ranks_json
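# `scores_to_ranks` converts an (N, 100) score matrix into the 1-based rank
# of each option, mirroring the manual sort_final computation in eval_val()
# above. A plausible sketch, assuming a plain tensor of scores:
def scores_to_ranks_sketch(scores):
    _, ranked_idx = scores.sort(1, descending=True)
    ranks = ranked_idx.clone().fill_(0)
    for r in range(ranked_idx.size(0)):
        for j in range(ranked_idx.size(1)):
            # the option at ranked_idx[r][j] came j-th in the descending sort.
            ranks[r][ranked_idx[r][j]] = j
    return ranks + 1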