def val():
    netE_g.eval()
    netW_g.eval()
    netG.eval()

    n_neg = 100
    ques_hidden1 = netE_g.init_hidden(opt.batchSize)

    hist_hidden1 = netE_g.init_hidden(opt.batchSize)

    bar = progressbar.ProgressBar(max_value=len(dataloader_val))
    data_iter_val = iter(dataloader_val)

    count = 0
    i = 0

    result_all = []
    # print('length of dataloader: ', len(dataloader_val))
    while i < len(dataloader_val):
        data = next(data_iter_val)
        image, history, question, answer, answerT, questionL, opt_answer, \
                    opt_answerT, answer_ids, answerLen, opt_answerLen, img_id  = data

        batch_size = question.size(0)
        image = image.view(-1, 512)
        with torch.no_grad():
            img_input.resize_(image.size()).copy_(image)

        save_tmp = [[] for j in range(batch_size)]
        for rnd in range(10):

            # get the corresponding round QA and history.
            ques, tans = question[:, rnd, :].t(), answerT[:, rnd, :].t()
            his = history[:, :rnd + 1, :].clone().view(-1, his_length).t()

            opt_ans = opt_answer[:, rnd, :, :].clone().view(-1, ans_length).t()
            opt_tans = opt_answerT[:, rnd, :].clone().view(-1, ans_length).t()
            gt_id = answer_ids[:, rnd]
            opt_len = opt_answerLen[:, rnd, :].clone().view(-1)

            ques_input = torch.LongTensor(ques.size()).cuda()
            ques_input.copy_(ques)

            his_input = torch.LongTensor(his.size()).cuda()
            his_input.copy_(his)

            opt_ans_input = torch.LongTensor(opt_ans.size()).cuda()
            opt_ans_input.copy_(opt_ans)

            opt_ans_target = torch.LongTensor(opt_tans.size()).cuda()
            opt_ans_target.copy_(opt_tans)

            gt_index = torch.LongTensor(gt_id.size())
            gt_index.copy_(gt_id)

            ques_emb_g = netW_g(ques_input, format='index')
            his_emb_g = netW_g(his_input, format='index')

            ques_hidden1 = repackage_hidden_new(ques_hidden1, batch_size)

            hist_hidden1 = repackage_hidden_new(hist_hidden1,
                                                his_emb_g.size(1))

            featG, ques_hidden1 = netE_g(ques_emb_g, his_emb_g, img_input, \
                                                ques_hidden1, hist_hidden1, rnd+1)

            #featD = l2_norm(featD)
            # Evaluate the Generator:
            _, ques_hidden1 = netG(featG.view(1, -1, opt.ninp), ques_hidden1)
            #_, ques_hidden = netG(encoder_feat.view(1,-1,opt.ninp), ques_hidden)
            # extend the hidden
            sample_ans_input = torch.LongTensor(1, opt.batchSize).cuda()
            sample_ans_input.resize_((1, batch_size)).fill_(vocab_size)
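            # index vocab_size is the special token used to start decoding;
            # netG.sample then generates an answer from the fused encoder state.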

            sample_opt = {'beam_size': 1}

            seq, seqLogprobs = netG.sample(netW_g, sample_ans_input,
                                           ques_hidden1, sample_opt)
            ans_sample_txt = decode_txt(itow, seq.t())
            ans_txt = decode_txt(itow, tans)
            ques_txt = decode_txt(itow, questionL[:, rnd, :].t())
            '''
            for j in range(len(ans_txt)):
                print('Q: %s --A: %s --Sampled: %s' %(ques_txt[j], ans_txt[j], ans_sample_txt[j]))
            
            ans_sample_z = [[] for z in range(batch_size)]
            for m in range(5):
                ans_sample_result = torch.Tensor(ans_length, batch_size)
                # sample the result.
                noise_input.data.resize_(ans_length, batch_size, vocab_size+1)
                noise_input.data.uniform_(0,1)
                for t in range(ans_length):
                    ans_emb = netW_g(sample_ans_input, format = 'index')
                    if t == 0:
                        logprob, ques_hidden2 = netG(ans_emb.view(1,-1,opt.ninp), ques_hidden1)
                    else:
                        logprob, ques_hidden2 = netG(ans_emb.view(1,-1,opt.ninp), ques_hidden2)

                    one_hot, idx = sampler(logprob, noise_input[t], opt.gumble_weight)

                    sample_ans_input.data.copy_(idx.data)
                    ans_sample_result[t].copy_(idx.data)

                ans_sample_txt = decode_txt(itow, ans_sample_result)
                for ii in range(batch_size):
                    ans_sample_z[ii].append(ans_sample_txt[ii])
            '''
            ans_txt = decode_txt(itow, tans)
            ques_txt = decode_txt(itow, questionL[:, rnd, :].t())
            #for j in range(len(ans_txt)):
            #    print('Q: %s --A: %s --Sampled: %s' %(ques_txt[j], ans_txt[j], ans_sample_txt[j]))

            for j in range(batch_size):
                save_tmp[j].append({'ques':ques_txt[j], 'gt_ans':ans_txt[j], \
                            'sample_ans':ans_sample_txt[j], 'rnd':rnd, 'img_id':img_id[j].item()})
        i += 1
        bar.update(i)

        result_all += save_tmp

    return result_all
def train(epoch):
    netW.train()
    netE.train()
    netD.train()

    lr = adjust_learning_rate(optimizer, epoch, opt.lr)

    ques_hidden = netE.init_hidden(opt.batchSize)
    hist_hidden = netE.init_hidden(opt.batchSize)

    real_hidden = netD.init_hidden(opt.batchSize)
    wrong_hidden = netD.init_hidden(opt.batchSize)

    data_iter = iter(dataloader)

    average_loss = 0
    count = 0
    i = 0

    while i < len(dataloader):

        t1 = time.time()
        data = next(data_iter)
        image, history, question, answer, answerT, answerLen, answerIdx, questionL, \
        opt_answerT, opt_answerLen, opt_answerIdx = data

        batch_size = question.size(0)
        image = image.view(-1, img_feat_size)

        with torch.no_grad():
            img_input.resize_(image.size()).copy_(image)

        for rnd in range(10):
            netW.zero_grad()
            netE.zero_grad()
            netD.zero_grad()
            # get the corresponding round QA and history.
            ques = question[:, rnd, :].t()
            his = history[:, :rnd + 1, :].clone().view(-1, his_length).t()

            ans = answer[:, rnd, :].t()
            tans = answerT[:, rnd, :].t()
            wrong_ans = opt_answerT[:, rnd, :].clone().view(-1, ans_length).t()

            real_len = answerLen[:, rnd]
            wrong_len = opt_answerLen[:, rnd, :].clone().view(-1)

            ques_input = torch.LongTensor(ques.size())
            ques_input.copy_(ques)

            his_input = torch.LongTensor(his.size())
            his_input.copy_(his)

            ans_input = torch.LongTensor(ans.size())
            ans_input.copy_(ans)

            ans_target = torch.LongTensor(tans.size())
            ans_target.copy_(tans)

            wrong_ans_input = torch.LongTensor(wrong_ans.size())
            wrong_ans_input.copy_(wrong_ans)

            # sample in-batch negative index
            batch_sample_idx = torch.zeros(batch_size,
                                           opt.neg_batch_sample,
                                           dtype=torch.long)
            sample_batch_neg(answerIdx[:, rnd], opt_answerIdx[:, rnd, :],
                             batch_sample_idx, opt.neg_batch_sample)

            ques_emb = netW(ques_input, format='index')
            his_emb = netW(his_input, format='index')

            ques_hidden = repackage_hidden_new(ques_hidden, batch_size)
            hist_hidden = repackage_hidden_new(hist_hidden, his_input.size(1))

            featD, ques_hidden = netE(ques_emb, his_emb, img_input, \
                                      ques_hidden, hist_hidden, rnd + 1)

            ans_real_emb = netW(ans_target, format='index')
            ans_wrong_emb = netW(wrong_ans_input, format='index')

            real_hidden = repackage_hidden_new(real_hidden, batch_size)
            wrong_hidden = repackage_hidden_new(wrong_hidden,
                                                ans_wrong_emb.size(1))

            real_feat = netD(ans_real_emb, ans_target, real_hidden, vocab_size)
            wrong_feat = netD(ans_wrong_emb, wrong_ans_input, wrong_hidden,
                              vocab_size)

            batch_wrong_feat = wrong_feat.index_select(
                0, batch_sample_idx.view(-1))
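            # Two sets of negatives: the dataset-provided wrong answers (wrong_feat)
            # and in-batch negatives gathered via batch_sample_idx (batch_wrong_feat).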
            wrong_feat = wrong_feat.view(
                batch_size, -1,
                opt.ninp)  # (batch_size, negative_sample, ninp)
            batch_wrong_feat = batch_wrong_feat.view(
                batch_size, -1,
                opt.ninp)  # (batch_size, crossover_negative_sample, ninp)

            # All the correct answers are present at the beginning of the combined_scores
            combined_scores, l2_norm = feat2score(
                featD, real_feat, wrong_feat,
                batch_wrong_feat)  # (batch_size, 1 + n_s + n_b_s)
            lambs, nPairLoss = critD(combined_scores)
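            # critD returns the n-pair loss together with lambs, which is fed to
            # combined_scores.backward() below as the upstream gradient (instead of
            # calling nPairLoss.backward() directly).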

            average_loss += nPairLoss.data.item()
            l2_norm.backward(retain_graph=True)
            combined_scores.backward(lambs)
            optimizer.step()
            count += 1

        i += 1
        if i % opt.log_interval == 0:
            average_loss /= count
            print("step {} / {} (epoch {}), g_loss {:.3f}, lr = {:.6f}" \
                  .format(i, len(dataloader), epoch, average_loss, lr))
            average_loss = 0
            count = 0

    return average_loss, lr
def val():
    # ques_hidden = netE.init_hidden(opt.batchSize)

    netE.eval()
    netW.eval()
    netD.eval()

    n_neg = 100
    data_iter_val = iter(dataloader_val)
    ques_hidden = netE.init_hidden(opt.batchSize)
    print('DSuccess')
    hist_hidden = netE.init_hidden(opt.batchSize)

    opt_hidden = netD.init_hidden(opt.batchSize)
    i = 0

    average_loss = 0
    rank_all_tmp = []

    while i < len(dataloader_val):
        data = next(data_iter_val)
        image, history, question, answer, answerT, questionL, opt_answer, \
        opt_answerT, answer_ids, answerLen, opt_answerLen, img_id = data

        batch_size = question.size(0)
        image = image.view(-1, img_feat_size)
        # image = l2_norm(image)
        with torch.no_grad():
            img_input.resize_(image.size()).copy_(image)

        for rnd in range(10):
            # get the corresponding round QA and history.
            ques = question[:, rnd, :].t()
            his = history[:, :rnd + 1, :].clone().view(-1, his_length).t()

            opt_ans = opt_answerT[:, rnd, :].clone().view(-1, ans_length).t()
            gt_id = answer_ids[:, rnd]

            ques_input = torch.LongTensor(ques.size())
            ques_input.copy_(ques)

            his_input = torch.LongTensor(his.size())
            his_input.copy_(his)

            opt_ans_input = torch.LongTensor(opt_ans.size())
            opt_ans_input.copy_(opt_ans)

            gt_index = torch.LongTensor(gt_id.size())
            gt_index.copy_(gt_id)

            opt_len = opt_answerLen[:, rnd, :].clone().view(-1)

            ques_emb = netW(ques_input, format='index')
            his_emb = netW(his_input, format='index')

            ques_hidden = repackage_hidden_new(ques_hidden, batch_size)
            hist_hidden = repackage_hidden_new(hist_hidden, his_input.size(1))

            featD, ques_hidden = netE(ques_emb, his_emb, img_input, \
                                      ques_hidden, hist_hidden, rnd + 1)

            opt_ans_emb = netW(opt_ans_input, format='index')
            opt_hidden = repackage_hidden_new(opt_hidden,
                                              opt_ans_input.size(1))
            opt_feat = netD(opt_ans_emb, opt_ans_input, opt_hidden, vocab_size)
            opt_feat = opt_feat.view(batch_size, -1, opt.ninp)

            # ans_emb = ans_emb.view(ans_length, -1, 100, opt.nhid)
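            # score each of the 100 candidate answers by the dot product between its
            # feature and the encoder feature: (batch, 100, ninp) x (batch, ninp, 1).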
            featD = featD.view(-1, opt.ninp, 1)
            score = torch.bmm(opt_feat, featD)
            score = score.view(-1, 100)

            for b in range(batch_size):
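                # score is flattened to (batch_size * 100,) below, so shift each
                # example's ground-truth index by b * 100 before indexing.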
                gt_index.data[b] = gt_index.data[b] + b * 100

            gt_score = score.view(-1).index_select(0, gt_index)
            sort_score, sort_idx = torch.sort(score, 1, descending=True)

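            # rank of the ground-truth answer = 1 + number of candidates with a
            # strictly higher score.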
            count = sort_score.gt(gt_score.view(-1, 1).expand_as(sort_score))
            rank = count.sum(1) + 1
            rank_all_tmp += list(rank.view(-1).data.cpu().numpy())

        i += 1
        sys.stdout.write('Evaluating: {:d}/{:d}  \r' \
                         .format(i, len(dataloader_val)))
        sys.stdout.flush()

    return rank_all_tmp
Example #4
def eval():

    netE.eval()
    netW.eval()
    netG.eval()

    data_iter_val = iter(dataloader_val)
    ques_hidden = netE.init_hidden(opt.batchSize)
    hist_hidden = netE.init_hidden(opt.batchSize)

    i = 0
    display_count = 0
    average_loss = 0
    rank_all_tmp = []
    result_all = []

    early_stop = int(opt.early_stop / opt.batchSize)
    dataloader_size = min(len(dataloader_val), early_stop)

    print('early_stop: {}'.format(early_stop))
    print('dataloader_size: {}'.format(dataloader_size))

    while i < dataloader_size:
        data = next(data_iter_val)
        image, history, question, answer, answerT, questionL, opt_answer, \
                opt_answerT, answer_ids, answerLen, opt_answerLen, img_id  = data

        batch_size = question.size(0)
        image = image.view(-1, 512)

        with torch.no_grad():
            img_input.resize_(image.size()).copy_(image)
        # img_input.data.resize_(image.size()).copy_(image)

        save_tmp = [[] for j in range(batch_size)]

        for rnd in range(10):
            # get the corresponding round QA and history.
            ques, tans = question[:,rnd,:].t(), opt_answerT[:,rnd,:].clone().view(-1, ans_length).t()
            his = history[:,:rnd+1,:].clone().view(-1, his_length).t()
            ans = opt_answer[:,rnd,:,:].clone().view(-1, ans_length).t()
            gt_id = answer_ids[:,rnd]

            # his_input.data.resize_(his.size()).copy_(his)
            # ques_input.data.resize_(ques.size()).copy_(ques)
            # ans_input.data.resize_(ans.size()).copy_(ans)
            # ans_target.data.resize_(tans.size()).copy_(tans)
            #
            # gt_index.data.resize_(gt_id.size()).copy_(gt_id)

            his_input = torch.LongTensor(his.size()).cuda()
            his_input.copy_(his)

            ques_input = torch.LongTensor(ques.size()).cuda()
            ques_input.copy_(ques)

            ans_input = torch.LongTensor(ans.size()).cuda()
            ans_input.copy_(ans)

            ans_target = torch.LongTensor(tans.size()).cuda()
            ans_target.copy_(tans)

            gt_index = torch.LongTensor(gt_id.size()).cuda()
            gt_index.copy_(gt_id)

            ques_emb = netW(ques_input, format = 'index')
            his_emb = netW(his_input, format = 'index')


            ques_hidden = repackage_hidden_new(ques_hidden, batch_size)
            hist_hidden = repackage_hidden_new(hist_hidden, his_input.size(1))

            encoder_feat, ques_hidden = netE(ques_emb, his_emb, img_input, \
                                                ques_hidden, hist_hidden, rnd+1)

            _, ques_hidden = netG(encoder_feat.view(1,-1,opt.ninp), ques_hidden)

            #ans_emb = ans_emb.view(ans_length, -1, 100, opt.nhid)
            ans_score = torch.FloatTensor(batch_size, 100).zero_()
            # extend the hidden
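            # replicate each example's decoder hidden state 100 times so that all 100
            # candidate answers can be scored in a single forward pass of netG.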
            hidden_replicated = []
            for hid in ques_hidden:
                hidden_replicated.append(hid.view(opt.nlayers, batch_size, 1, \
                    opt.nhid).expand(opt.nlayers, batch_size, 100, opt.nhid).clone().view(opt.nlayers, -1, opt.nhid))
            hidden_replicated = tuple(hidden_replicated)

            ans_emb = netW(ans_input, format = 'index')

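            # Score every candidate answer by its summed per-token score under the
            # generator: gather the value at each target word, zero out padding
            # positions, and sum over the answer length (lower totals rank better,
            # hence the ascending sort below).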
            output, _ = netG(ans_emb, hidden_replicated)
            logprob = - output
            logprob_select = torch.gather(logprob, 1, ans_target.view(-1,1))

            mask = ans_target.data.eq(0)  # mask out padding positions (token index 0)
            logprob_select.masked_fill_(mask.view_as(logprob_select), 0)

            prob = logprob_select.view(ans_length, -1, 100).sum(0).view(-1,100)

            for b in range(batch_size):
                gt_index.data[b] = gt_index.data[b] + b*100

            gt_score = prob.view(-1).index_select(0, gt_index)
            sort_score, sort_idx = torch.sort(prob, 1)

            count = sort_score.lt(gt_score.view(-1,1).expand_as(sort_score))
            rank = count.sum(1) + 1
            gt_rank_cpu = rank.view(-1).data.cpu().numpy()

            # --------------------- get the top 10 answers -------------------------

            answer_list = tans  # 9 x bs*100
            new_sorted_idx = torch.LongTensor(batch_size * 10)
            for b in range(batch_size):
                new_sorted_idx[b * 10:b * 10 + 10] = sort_idx[b, :10] + b * 100

            ans_array = tans.index_select(1, new_sorted_idx)
            ans_list = decode_txt(itow, ans_array)

            ques_txt = decode_txt(itow, questionL[:, rnd, :].t())
            ans_txt = decode_txt(itow, tans)

            for b in range(batch_size):
                data_dict = {}
                data_dict['ques'] = ques_txt[b]
                data_dict['gt_ans'] = ans_txt[gt_index[b]]
                data_dict['top10_ans'] = ans_list[b * 10:(b + 1) * 10]
                data_dict['rnd'] = rnd
                data_dict['image_id'] = img_id[b].item()
                data_dict['gt_ans_rank'] = str(gt_rank_cpu[b])
                save_tmp[b].append(data_dict)
            #------------------------------------------------------------------------

            rank_all_tmp += list(rank.view(-1).data.cpu().numpy())

        result_all += save_tmp
        i += 1
        sys.stdout.write('Evaluating: {:d}/{:d}  \r' \
          .format(i, len(dataloader_val)))

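        # Running retrieval metrics over the ranks collected so far: R@k is the
        # fraction of ground-truth answers ranked within the top k, MRR is the mean
        # reciprocal rank, and Mean is the average rank.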
        if i % 50 == 0:
            R1 = np.sum(np.array(rank_all_tmp)==1) / float(len(rank_all_tmp))
            R5 =  np.sum(np.array(rank_all_tmp)<=5) / float(len(rank_all_tmp))
            R10 = np.sum(np.array(rank_all_tmp)<=10) / float(len(rank_all_tmp))
            ave = np.sum(np.array(rank_all_tmp)) / float(len(rank_all_tmp))
            mrr = np.sum(1/(np.array(rank_all_tmp, dtype='float'))) / float(len(rank_all_tmp))
            logger.warning('%d/%d: mrr: %f R1: %f R5 %f R10 %f Mean %f' %(i, len(dataloader_val), mrr, R1, R5, R10, ave))

    return (rank_all_tmp, result_all)
Example #5
def eval():
    netW.eval()
    netE.eval()
    netD.eval()

    data_iter_val = iter(dataloader_val)
    ques_hidden = netE.init_hidden(opt.batchSize)
    hist_hidden = netE.init_hidden(opt.batchSize)

    opt_hidden = netD.init_hidden(opt.batchSize)
    i = 0
    display_count = 0
    average_loss = 0
    rank_all_tmp = []
    result_all = []
    img_atten = torch.FloatTensor(100 * 30, 10, 7, 7)

    early_stop = int(opt.early_stop / opt.batchSize)
    dataloader_size = min(len(dataloader_val), early_stop)

    print('early_stop: {}'.format(early_stop))
    print('dataloader_size: {}'.format(dataloader_size))

    while i < dataloader_size:
        data = next(data_iter_val)
        image, history, question, answer, answerT, questionL, opt_answer, \
                opt_answerT, answer_ids, answerLen, opt_answerLen, img_id  = data

        batch_size = question.size(0)
        image = image.view(-1, 512)
        with torch.no_grad():
            img_input.resize_(image.size()).copy_(image)

        save_tmp = [[] for j in range(batch_size)]

        for rnd in range(10):
            #todo: remove this hard coded rnd = 5 after verifying!
            # rnd=5
            # get the corresponding round QA and history.
            ques, tans = question[:, rnd, :].t(), answerT[:, rnd, :].t()
            his = history[:, :rnd + 1, :].clone().view(-1, his_length).t()

            opt_ans = opt_answerT[:, rnd, :].clone().view(-1, ans_length).t()
            gt_id = answer_ids[:, rnd]

            ques_input = torch.LongTensor(ques.size()).cuda()
            ques_input.copy_(ques)

            his_input = torch.LongTensor(his.size()).cuda()
            his_input.copy_(his)

            gt_index = torch.LongTensor(gt_id.size()).cuda()
            gt_index.copy_(gt_id)

            opt_ans_input = torch.LongTensor(opt_ans.size()).cuda()
            opt_ans_input.copy_(opt_ans)

            opt_len = opt_answerLen[:, rnd, :].clone().view(-1)

            ques_emb = netW(ques_input, format='index')
            his_emb = netW(his_input, format='index')

            ques_hidden = repackage_hidden_new(ques_hidden, batch_size)
            hist_hidden = repackage_hidden_new(hist_hidden, his_input.size(1))

            featD, ques_hidden = netE(ques_emb, his_emb, img_input, \
                                                                    ques_hidden, hist_hidden, rnd+1)

            #img_atten[i*batch_size:(i+1)*batch_size, rnd, :] = img_atten_weight.data.view(batch_size, 7, 7)

            opt_ans_emb = netW(opt_ans_input, format='index')
            opt_hidden = repackage_hidden_new(opt_hidden,
                                              opt_ans_input.size(1))
            opt_feat = netD(opt_ans_emb, opt_ans_input, opt_hidden, n_words)
            opt_feat = opt_feat.view(batch_size, -1, opt.ninp)

            featD = featD.view(-1, opt.ninp, 1)
            score = torch.bmm(opt_feat, featD)
            score = score.view(-1, 100)

            for b in range(batch_size):
                gt_index.data[b] = gt_index.data[b] + b * 100

            gt_score = score.view(-1).index_select(0, gt_index)
            sort_score, sort_idx = torch.sort(score, 1, descending=True)

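            # Recover the word indices of the top-n ranked candidate answers with
            # numpy fancy indexing so they can be decoded back to text for logging.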
            opt_answer_cur_ques = opt_answerT.detach().numpy()[:, rnd, :, :]  # (batch_size, 100, 9)
            top_sort_idx = sort_idx.cpu().detach().numpy()[:, 0:opt.topn]
            first_dim_indices = np.broadcast_to(
                np.arange(batch_size).reshape(batch_size, 1),
                (batch_size, opt.topn)).reshape(batch_size * opt.topn)
            top_ans_word_indices = opt_answer_cur_ques[
                first_dim_indices,
                top_sort_idx.reshape(batch_size * opt.topn), :].reshape(
                    batch_size, opt.topn, 9)
            top_ans_txt_rank_wise = []  # (topn, batch_size) list of decoded answer strings

            ques_txt = decode_txt(itow, questionL[:, rnd, :].t())
            ans_txt = decode_txt(itow, tans)

            for pos in range(opt.topn):
                top_temp = decode_txt(
                    itow,
                    torch.tensor(top_ans_word_indices[:, pos, :]).t())
                top_ans_txt_rank_wise.append(top_temp)

            top_ans_txt = np.array(top_ans_txt_rank_wise, dtype=str).T

            count = sort_score.gt(gt_score.view(-1, 1).expand_as(sort_score))
            rank = count.sum(1) + 1
            gt_rank_cpu = rank.view(-1).data.cpu().numpy()
            rank_all_tmp += list(rank.view(-1).data.cpu().numpy())

            sort_score_cpu = sort_score.data.cpu().numpy()

            for b in range(batch_size):
                save_tmp[b].append({
                    "ques": ques_txt[b],
                    "gt_ans": ans_txt[b],
                    "top_10_disc_ans": top_ans_txt.tolist()[b],
                    "gt_ans_rank": str(gt_rank_cpu[b]),
                    "rnd": rnd,
                    "img_id": img_id[b].item(),
                    "top_10_scores": sort_score_cpu[b][:opt.topn].tolist()
                })

        i += 1

        result_all += save_tmp

        if i % opt.log_iter == 0:
            R1 = np.sum(np.array(rank_all_tmp) == 1) / float(len(rank_all_tmp))
            R5 = np.sum(np.array(rank_all_tmp) <= 5) / float(len(rank_all_tmp))
            R10 = np.sum(np.array(rank_all_tmp) <= 10) / float(
                len(rank_all_tmp))
            ave = np.sum(np.array(rank_all_tmp)) / float(len(rank_all_tmp))
            mrr = np.sum(1 / (np.array(rank_all_tmp, dtype='float'))) / float(
                len(rank_all_tmp))
            logger.warning('%d/%d: mrr: %f R1: %f R5 %f R10 %f Mean %f' %
                           (i, len(dataloader_val), mrr, R1, R5, R10, ave))

    return (rank_all_tmp, result_all)
Example #6
def train(epoch):
    netW_d.train(), netE_d.train(), netE_g.train()
    netD.train(), netG.train(), netW_g.train()

    fake_len = torch.LongTensor(opt.batchSize)
    fake_len = fake_len.cuda()

    n_neg = opt.negative_sample
    ques_hidden1 = netE_d.init_hidden(opt.batchSize)
    ques_hidden2 = netE_g.init_hidden(opt.batchSize)

    hist_hidden1 = netE_d.init_hidden(opt.batchSize)
    hist_hidden2 = netE_g.init_hidden(opt.batchSize)

    real_hidden = netD.init_hidden(opt.batchSize)
    wrong_hidden = netD.init_hidden(opt.batchSize)
    fake_hidden = netD.init_hidden(opt.batchSize)

    data_iter = iter(dataloader)

    err_d = 0
    err_g = 0
    err_lm = 0
    average_loss = 0
    count = 0
    i = 0
    loss_store = []
    while i < len(dataloader):
        t1 = time.time()
        data = next(data_iter)
        image, history, question, answer, answerT, answerLen, answerIdx, questionL, \
                                    opt_answerT, opt_answerLen, opt_answerIdx, _ = data
        batch_size = question.size(0)
        image = image.view(-1, 512)
        with torch.no_grad():
            img_input.resize_(image.size()).copy_(image)

        err_d_tmp = 0
        err_g_tmp = 0
        err_lm_tmp = 0
        err_d_fake_tmp = 0
        err_g_fake_tmp = 0
        for rnd in range(10):
            # get the corresponding round QA and history.
            ques = question[:, rnd, :].t()
            his = history[:, :rnd + 1, :].clone().view(-1, his_length).t()

            ans = answer[:, rnd, :].t()
            tans = answerT[:, rnd, :].t()
            wrong_ans = opt_answerT[:, rnd, :].clone().view(-1, ans_length).t()

            real_len = answerLen[:, rnd].long()
            wrong_len = opt_answerLen[:, rnd, :].clone().view(-1)

            ques_input = torch.LongTensor(ques.size()).cuda()
            ques_input.copy_(ques)

            his_input = torch.LongTensor(his.size()).cuda()
            his_input.copy_(his)

            ans_input = torch.LongTensor(ans.size()).cuda()
            ans_input.copy_(ans)

            ans_target = torch.LongTensor(tans.size()).cuda()
            ans_target.copy_(tans)

            wrong_ans_input = torch.LongTensor(wrong_ans.size()).cuda()
            wrong_ans_input.copy_(wrong_ans)

            batch_sample_idx = torch.zeros(batch_size,
                                           opt.neg_batch_sample,
                                           dtype=torch.long).cuda()
            sample_batch_neg(answerIdx[:, rnd], opt_answerIdx[:, rnd, :],
                             batch_sample_idx, opt.neg_batch_sample)

            # -----------------------------------------
            # update the Generator using MLE loss.
            # -----------------------------------------
            if opt.update_LM:
                ques_emb_g = netW_g(ques_input, format='index')
                his_emb_g = netW_g(his_input, format='index')

                ques_hidden1 = repackage_hidden_new(ques_hidden1, batch_size)
                hist_hidden1 = repackage_hidden_new(hist_hidden1,
                                                    his_emb_g.size(1))

                featG, ques_hidden1 = netE_g(ques_emb_g, his_emb_g, img_input, \
                                                    ques_hidden1, hist_hidden1, rnd+1)

                _, ques_hidden1 = netG(featG.view(1, -1, opt.ninp),
                                       ques_hidden1)
                # MLE loss for generator
                ans_emb = netW_g(ans_input)
                logprob, _ = netG(ans_emb, ques_hidden1)
                lm_loss = critLM(logprob, ans_target.view(-1, 1))
                lm_loss = lm_loss / torch.sum(ans_target.data.gt(0))
                # total loss = discriminator_loss + alpha*lm_loss
                lm_loss = opt.alpha * lm_loss

                netW_g.zero_grad()
                netG.zero_grad()
                netE_g.zero_grad()

                lm_loss.backward()
                optimizerLM.step()
                err_lm += lm_loss.data.item()
                err_lm_tmp += lm_loss.data.item()

            # sample the answer using the Gumbel-softmax sampler.
            ques_emb_g = netW_g(ques_input, format='index')
            his_emb_g = netW_g(his_input, format='index')

            ques_hidden1 = repackage_hidden_new(ques_hidden1, batch_size)
            hist_hidden1 = repackage_hidden_new(hist_hidden1,
                                                his_emb_g.size(1))

            featG, ques_hidden1 = netE_g(ques_emb_g, his_emb_g, img_input, \
                                                ques_hidden1, hist_hidden1, rnd+1)

            _, ques_hidden1 = netG(featG.view(1, -1, opt.ninp), ques_hidden1)

            # Gumbel softmax to sample the output.
            fake_onehot = []
            fake_idx = []

            noise_input = torch.FloatTensor(ans_length, batch_size,
                                            vocab_size + 1).cuda()
            noise_input.data.uniform_(0, 1)

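            # Unroll the decoder one token at a time: at each step the Gumbel-softmax
            # sampler converts the predicted distribution into a (one_hot, idx) pair,
            # and the sampled index is fed back as the next input token.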
            ans_sample = ans_input[0]
            for di in range(ans_length):
                ans_emb = netW_g(ans_sample, format='index')
                logprob, ques_hidden1 = netG(ans_emb.view(1, -1, opt.ninp),
                                             ques_hidden1)
                one_hot, idx = sampler(logprob, noise_input[di],
                                       opt.gumble_weight)
                fake_onehot.append(one_hot.view(1, -1, vocab_size + 1))
                fake_idx.append(idx)
                if di + 1 < ans_length:
                    ans_sample = idx

            # convert the list into the tensor variable.
            fake_onehot = torch.cat(fake_onehot, 0)
            fake_idx = torch.cat(fake_idx, 0)

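            # Determine each sampled answer's length: scanning backwards leaves
            # fake_len at the earliest position where the end-of-answer token
            # (index vocab_size) was drawn.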
            fake_len.resize_(batch_size).fill_(ans_length - 1)
            for di in range(ans_length - 1, 0, -1):
                fake_len.masked_fill_(fake_idx.data[di].eq(vocab_size), di)

            # generate fake mask.
            #----------------------------------------------------------------------------
            fake_mask = torch.ByteTensor(fake_idx.size()).cuda().fill_(1)
            #----------------------------------------------------------------------------
            # clear the mask over each sampled answer's valid tokens; positions after
            # the answer length stay masked and are zeroed out below.
            for b in range(batch_size):
                fake_mask.data[:fake_len[b] + 1, b] = 0

            # apply the mask on the fake_idx.
            fake_idx.masked_fill_(fake_mask, 0)

            # get the fake diff mask.
            #fake_diff_mask = torch.sum(fake_idx == ans_target, 0) != 0
            fake_onehot = fake_onehot.view(-1, vocab_size + 1)

            ######################################
            # Discriminatively trained generative model.
            ######################################
            # forward the discriminator again.
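            # Embed the ground-truth and the sampled (fake) answer with the
            # discriminator-side embedding (netW_d), encode both with netD, and let
            # critG turn their scores against featD into the adversarial loss that
            # updates the generator via optimizerG.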
            ques_emb_d = netW_d(ques_input, format='index')
            his_emb_d = netW_d(his_input, format='index')

            ques_hidden2 = repackage_hidden_new(ques_hidden2, batch_size)
            hist_hidden2 = repackage_hidden_new(hist_hidden2,
                                                his_emb_d.size(1))

            featD, _ = netE_d(ques_emb_d, his_emb_d, img_input, \
                                        ques_hidden2, hist_hidden2, rnd+1)

            ans_real_emb = netW_d(ans_target, format='index')
            #ans_wrong_emb = netW_d(wrong_ans_input, format='index')
            ans_fake_emb = netW_d(fake_onehot, format='onehot')
            ans_fake_emb = ans_fake_emb.view(ans_length, -1, opt.ninp)

            real_hidden = repackage_hidden_new(real_hidden, batch_size)
            #wrong_hidden = repackage_hidden_new(wrong_hidden, ans_wrong_emb.size(1))
            fake_hidden = repackage_hidden_new(fake_hidden, batch_size)

            fake_feat = netD(ans_fake_emb, fake_idx, fake_hidden, vocab_size)
            real_feat = netD(ans_real_emb, ans_target, real_hidden, vocab_size)

            d_g_loss, g_fake = critG(featD, real_feat,
                                     fake_feat)  #, fake_diff_mask.detach())

            netW_g.zero_grad()
            netG.zero_grad()
            netE_g.zero_grad()

            d_g_loss.backward()
            optimizerG.step()

            err_g += d_g_loss.data.item()
            err_g_tmp += d_g_loss.data.item()
            err_g_fake_tmp += g_fake

            count += 1

        i += 1
        loss_store.append({'iter':i, 'err_lm':err_lm_tmp/10, 'err_d':err_d_tmp/10, 'err_g':err_g_tmp/10, \
                            'd_fake': err_d_fake_tmp/10, 'g_fake':err_g_fake_tmp/10})

        if i % opt.log_interval == 0:
            print ('Epoch:%d %d/%d, err_lm %4f, err_d %4f, err_g %4f, d_fake %4f, g_fake %4f' \
                % (epoch, i, len(dataloader), err_lm_tmp/10, err_d_tmp/10, err_g_tmp/10, err_d_fake_tmp/10, \
                    err_g_fake_tmp/10))

    #average_loss = average_loss / count
    err_g = err_g / count
    err_d = err_d / count
    err_lm = err_lm / count

    return err_lm, err_d, err_g, loss_store
Example #7
def val():
    netE_g.eval()
    netE_d.eval()
    netW_g.eval()
    netW_d.eval()

    netG.eval()
    netD.eval()

    n_neg = 100
    ques_hidden1 = netE_g.init_hidden(opt.batchSize)
    ques_hidden2 = netE_d.init_hidden(opt.batchSize)

    hist_hidden1 = netE_d.init_hidden(opt.batchSize)
    hist_hidden2 = netE_g.init_hidden(opt.batchSize)

    opt_hidden = netD.init_hidden(opt.batchSize)
    data_iter_val = iter(dataloader_val)

    count = 0
    i = 0
    rank_G = []
    rank_D = []

    while i < len(dataloader_val):
        data = next(data_iter_val)
        image, history, question, answer, answerT, questionL, opt_answer, \
                    opt_answerT, answer_ids, answerLen, opt_answerLen, img_id = data

        batch_size = question.size(0)
        image = image.view(-1, 512)

        with torch.no_grad():
            img_input.resize_(image.size()).copy_(image)

        for rnd in range(10):

            # get the corresponding round QA and history.
            ques = question[:, rnd, :].t()
            his = history[:, :rnd + 1, :].clone().view(-1, his_length).t()

            opt_ans = opt_answer[:, rnd, :, :].clone().view(-1, ans_length).t()
            opt_tans = opt_answerT[:, rnd, :].clone().view(-1, ans_length).t()
            gt_id = answer_ids[:, rnd]
            opt_len = opt_answerLen[:, rnd, :].clone().view(-1)

            #-----------------------------------------------------------------
            his_input = torch.LongTensor(his.size()).cuda()
            his_input.copy_(his)

            ques_input = torch.LongTensor(ques.size()).cuda()
            ques_input.copy_(ques)

            opt_ans_input = torch.LongTensor(opt_ans.size()).cuda()
            opt_ans_input.copy_(opt_ans)

            opt_ans_target = torch.LongTensor(opt_tans.size()).cuda()
            opt_ans_target.copy_(opt_tans)

            gt_index = torch.LongTensor(gt_id.size()).cuda()
            gt_index.copy_(gt_id)

            #-----------------------------------------------------------------------

            ques_emb_g = netW_g(ques_input, format='index')
            his_emb_g = netW_g(his_input, format='index')

            ques_emb_d = netW_d(ques_input, format='index')
            his_emb_d = netW_d(his_input, format='index')

            ques_hidden1 = repackage_hidden_new(ques_hidden1, batch_size)
            ques_hidden2 = repackage_hidden_new(ques_hidden2, batch_size)

            hist_hidden1 = repackage_hidden_new(hist_hidden1,
                                                his_emb_g.size(1))
            hist_hidden2 = repackage_hidden_new(hist_hidden2,
                                                his_emb_d.size(1))

            featG, ques_hidden1 = netE_g(ques_emb_g, his_emb_g, img_input, \
                                                ques_hidden1, hist_hidden1, rnd+1)

            featD, _ = netE_d(ques_emb_d, his_emb_d, img_input, \
                                                ques_hidden2, hist_hidden2, rnd+1)
            #featD = l2_norm(featD)
            # Evaluate the Generator:
            _, ques_hidden1 = netG(featG.view(1, -1, opt.ninp), ques_hidden1)
            #_, ques_hidden = netG(encoder_feat.view(1,-1,opt.ninp), ques_hidden)
            # extend the hidden
            hidden_replicated = []
            for hid in ques_hidden1:
                hidden_replicated.append(hid.view(opt.nlayers, batch_size, 1, \
                    opt.nhid).expand(opt.nlayers, batch_size, 100, opt.nhid).clone().view(opt.nlayers, -1, opt.nhid))
            hidden_replicated = tuple(hidden_replicated)

            ans_emb = netW_g(opt_ans_input, format='index')

            output, _ = netG(ans_emb, hidden_replicated)
            logprob = -output
            logprob_select = torch.gather(logprob, 1,
                                          opt_ans_target.view(-1, 1))

            mask = opt_ans_target.data.eq(0)  # mask out padding positions (token index 0)
            logprob_select.masked_fill_(mask.view_as(logprob_select), 0)

            prob = logprob_select.view(ans_length, -1,
                                       100).sum(0).view(-1, 100)

            for b in range(batch_size):
                gt_index.data[b] = gt_index.data[b] + b * 100

            gt_score = prob.view(-1).index_select(0, gt_index)
            sort_score, sort_idx = torch.sort(prob, 1)

            count = sort_score.lt(gt_score.view(-1, 1).expand_as(sort_score))
            rank = count.sum(1) + 1
            rank_G += list(rank.view(-1).data.cpu().numpy())

            opt_ans_emb = netW_d(opt_ans_target, format='index')
            opt_hidden = repackage_hidden_new(opt_hidden,
                                              opt_ans_target.size(1))
            opt_feat = netD(opt_ans_emb, opt_ans_target, opt_hidden,
                            vocab_size)
            opt_feat = opt_feat.view(batch_size, -1, opt.ninp)

            #ans_emb = ans_emb.view(ans_length, -1, 100, opt.nhid)
            featD = featD.view(-1, opt.ninp, 1)
            score = torch.bmm(opt_feat, featD)
            score = score.view(-1, 100)

            gt_score = score.view(-1).index_select(0, gt_index)
            sort_score, sort_idx = torch.sort(score, 1, descending=True)
            count = sort_score.gt(gt_score.view(-1, 1).expand_as(sort_score))
            rank = count.sum(1) + 1
            rank_D += list(rank.view(-1).data.cpu().numpy())

        i += 1

    return rank_G, rank_D
Example #8
def val():
    netE.eval()
    netW.eval()
    netG.eval()

    data_iter_val = iter(dataloader_val)
    ques_hidden = netE.init_hidden(opt.batchSize)
    hist_hidden = netE.init_hidden(opt.batchSize)

    i = 0
    average_loss = 0
    rank_all_tmp = []

    while i < len(dataloader_val):
        data = next(data_iter_val)
        image, history, question, answer, answerT, questionL, opt_answer, \
                    opt_answerT, answer_ids, answerLen, opt_answerLen, img_id = data

        batch_size = question.size(0)
        image = image.view(-1, img_feat_size)

        with torch.no_grad():
            img_input.resize_(image.size()).copy_(image)

        for rnd in range(10):
            # get the corresponding round QA and history.
            ques = question[:, rnd, :].t()
            tans = opt_answerT[:, rnd, :].clone().view(-1, ans_length).t()
            his = history[:, :rnd + 1, :].clone().view(-1, his_length).t()
            ans = opt_answer[:, rnd, :, :].clone().view(-1, ans_length).t()
            gt_id = answer_ids[:, rnd]

            # his_input.data.resize_(his.size()).copy_(his)
            # ques_input.data.resize_(ques.size()).copy_(ques)
            # ans_input.data.resize_(ans.size()).copy_(ans)
            # ans_target.data.resize_(tans.size()).copy_(tans)

            his_input = torch.LongTensor(his.size())
            his_input.copy_(his)

            ques_input = torch.LongTensor(ques.size())
            ques_input.copy_(ques)

            ans_input = torch.LongTensor(ans.size())
            ans_input.copy_(ans)

            ans_target = torch.LongTensor(tans.size())
            ans_target.copy_(tans)

            gt_index = torch.LongTensor(gt_id.size())
            gt_index.copy_(gt_id)

            ques_emb = netW(ques_input, format='index')
            his_emb = netW(his_input, format='index')

            ques_hidden = repackage_hidden_new(ques_hidden, batch_size)
            hist_hidden = repackage_hidden_new(hist_hidden, his_input.size(1))

            encoder_feat, ques_hidden = netE(ques_emb, his_emb, img_input, \
                                                ques_hidden, hist_hidden, rnd+1)

            _, ques_hidden = netG(encoder_feat.view(1, -1, opt.ninp),
                                  ques_hidden)

            hidden_replicated = []
            for hid in ques_hidden:
                hidden_replicated.append(hid.view(opt.nlayers, batch_size, 1, \
                    opt.nhid).expand(opt.nlayers, batch_size, 100, opt.nhid).clone().view(opt.nlayers, -1, opt.nhid))
            hidden_replicated = tuple(hidden_replicated)

            ans_emb = netW(ans_input, format='index')

            output, _ = netG(ans_emb, hidden_replicated)
            logprob = -output
            logprob_select = torch.gather(logprob, 1, ans_target.view(-1, 1))

            mask = ans_target.data.eq(0)  # mask out padding positions (token index 0)
            logprob_select.masked_fill_(mask.view_as(logprob_select), 0)

            prob = logprob_select.view(ans_length, -1,
                                       100).sum(0).view(-1, 100)

            for b in range(batch_size):
                gt_index.data[b] = gt_index.data[b] + b * 100

            gt_score = prob.view(-1).index_select(0, gt_index)
            sort_score, sort_idx = torch.sort(prob, 1)

            count = sort_score.lt(gt_score.view(-1, 1).expand_as(sort_score))
            rank = count.sum(1) + 1
            rank_all_tmp += list(rank.view(-1).data.cpu().numpy())

        i += 1

    return rank_all_tmp, average_loss
Example #9
def train(epoch):
    netW.train()
    netE.train()
    netG.train()

    lr = adjust_learning_rate(optimizer, epoch, opt.lr)
    data_iter = iter(dataloader)

    ques_hidden = netE.init_hidden(opt.batchSize)
    hist_hidden = netE.init_hidden(opt.batchSize)
    average_loss = 0
    count = 0
    i = 0
    while i < len(dataloader):
        data = next(data_iter)
        image, history, question, answer, answerT, answerLen, answerIdx, \
        questionL, negAnswer, negAnswerLen, negAnswerIdx = data

        batch_size = question.size(0)
        image = image.view(-1, img_feat_size)
        with torch.no_grad():
            img_input.resize_(image.size()).copy_(image)

        for rnd in range(10):
            ques = question[:, rnd, :].t()
            his = history[:, :rnd + 1, :].clone().view(-1, his_length).t()
            ans, tans = answer[:, rnd, :].t(), answerT[:, rnd, :].t()

            his_input = torch.LongTensor(his.size())
            his_input.copy_(his)

            ques_input = torch.LongTensor(ques.size())
            ques_input.copy_(ques)

            ans_input = torch.LongTensor(ans.size())
            ans_input.copy_(ans)

            ans_target = torch.LongTensor(tans.size())
            ans_target.copy_(tans)

            ques_emb = netW(ques_input, format='index')
            his_emb = netW(his_input, format='index')

            ques_hidden = repackage_hidden_new(ques_hidden, batch_size)
            hist_hidden = repackage_hidden_new(hist_hidden, his_input.size(1))

            encoder_feat, ques_hidden = netE(ques_emb, his_emb, img_input, \
                                                ques_hidden, hist_hidden, rnd+1)

            _, ques_hidden = netG(encoder_feat.view(1, -1, opt.ninp),
                                  ques_hidden)

            ans_emb = netW(ans_input)
            logprob, ques_hidden = netG(ans_emb, ques_hidden)
            loss = critG(logprob, ans_target.view(-1, 1))

            loss = loss / torch.sum(ans_target.data.gt(0))
            average_loss += loss.data.item()
            # do backward.
            netW.zero_grad()
            netE.zero_grad()
            netG.zero_grad()
            loss.backward()

            optimizer.step()
            count += 1
        i += 1
        if i % opt.log_interval == 0:
            average_loss /= count
            print("step {} / {} (epoch {}), g_loss {:.3f}, lr = {:.6f}"\
                .format(i, len(dataloader), epoch, average_loss, lr))
            average_loss = 0
            count = 0

    return average_loss, lr
def eval():
    netE.eval()
    netW.eval()
    netG.eval()

    data_iter_val = iter(dataloader_val)
    ques_hidden = netE.init_hidden(opt.batchSize)
    hist_hidden = netE.init_hidden(opt.batchSize)

    i = 0
    display_count = 0
    average_loss = 0
    rank_all_tmp = []
    while i < len(dataloader_val):
        data = next(data_iter_val)
        image, history, question, answer, answerT, questionL, opt_answer, \
        opt_answerT, answer_ids, answerLen, opt_answerLen, img_id = data

        batch_size = question.size(0)
        image = image.view(-1, 512)

        with torch.no_grad():
            img_input.resize_(image.size()).copy_(image)
        # img_input.data.resize_(image.size()).copy_(image)

        for rnd in range(10):
            # get the corresponding round QA and history.
            ques, tans = question[:, rnd, :].t(), opt_answerT[:, rnd, :].clone().view(-1, ans_length).t()
            his = history[:, :rnd + 1, :].clone().view(-1, his_length).t()
            ans = opt_answer[:, rnd, :, :].clone().view(-1, ans_length).t()
            gt_id = answer_ids[:, rnd]

            # his_input.data.resize_(his.size()).copy_(his)
            # ques_input.data.resize_(ques.size()).copy_(ques)
            # ans_input.data.resize_(ans.size()).copy_(ans)
            # ans_target.data.resize_(tans.size()).copy_(tans)
            #
            # gt_index.data.resize_(gt_id.size()).copy_(gt_id)

            his_input = torch.LongTensor(his.size()).cpu()
            his_input.copy_(his)

            ques_input = torch.LongTensor(ques.size()).cpu()
            ques_input.copy_(ques)

            ans_input = torch.LongTensor(ans.size()).cpu()
            ans_input.copy_(ans)

            ans_target = torch.LongTensor(tans.size()).cpu()
            ans_target.copy_(tans)

            gt_index = torch.LongTensor(gt_id.size()).cpu()
            gt_index.copy_(gt_id)

            ques_emb = netW(ques_input, format='index')
            his_emb = netW(his_input, format='index')

            ques_hidden = repackage_hidden_new(ques_hidden, batch_size)
            hist_hidden = repackage_hidden_new(hist_hidden, his_input.size(1))

            encoder_feat, ques_hidden = netE(ques_emb, his_emb, img_input, \
                                             ques_hidden, hist_hidden, rnd + 1)

            _, ques_hidden = netG(encoder_feat.view(1, -1, opt.ninp), ques_hidden)

            # ans_emb = ans_emb.view(ans_length, -1, 100, opt.nhid)
            ans_score = torch.FloatTensor(batch_size, 100).zero_()
            # extend the hidden
            hidden_replicated = []
            for hid in ques_hidden:
                hidden_replicated.append(hid.view(opt.nlayers, batch_size, 1, \
                                                  opt.nhid).expand(opt.nlayers, batch_size, 100, opt.nhid).clone().view(
                    opt.nlayers, -1, opt.nhid))
            hidden_replicated = tuple(hidden_replicated)

            ans_emb = netW(ans_input, format='index')

            output, _ = netG(ans_emb, hidden_replicated)
            logprob = - output
            logprob_select = torch.gather(logprob, 1, ans_target.view(-1, 1))

            mask = ans_target.data.eq(0)  # mask out padding positions (token index 0)
            logprob_select.masked_fill_(mask.view_as(logprob_select), 0)

            prob = logprob_select.view(ans_length, -1, 100).sum(0).view(-1, 100)

            for b in range(batch_size):
                gt_index.data[b] = gt_index.data[b] + b * 100

            gt_score = prob.view(-1).index_select(0, gt_index)
            sort_score, sort_idx = torch.sort(prob, 1)

            count = sort_score.lt(gt_score.view(-1, 1).expand_as(sort_score))
            rank = count.sum(1) + 1
            rank_all_tmp += list(rank.view(-1).data.cpu().numpy())

        i += 1
        sys.stdout.write('Evaluating: {:d}/{:d}  \r' \
                         .format(i, len(dataloader_val)))

        if i % 50 == 0:
            R1 = np.sum(np.array(rank_all_tmp) == 1) / float(len(rank_all_tmp))
            R5 = np.sum(np.array(rank_all_tmp) <= 5) / float(len(rank_all_tmp))
            R10 = np.sum(np.array(rank_all_tmp) <= 10) / float(len(rank_all_tmp))
            ave = np.sum(np.array(rank_all_tmp)) / float(len(rank_all_tmp))
            mrr = np.sum(1 / (np.array(rank_all_tmp, dtype='float'))) / float(len(rank_all_tmp))
            print('%d/%d: mrr: %f R1: %f R5 %f R10 %f Mean %f' % (i, len(dataloader_val), mrr, R1, R5, R10, ave))

    return rank_all_tmp
Example #11
    ques_hidden = netE.init_hidden(opt.batchSize)
    hist_hidden = netE.init_hidden(opt.batchSize)

    ind = his.size(1)

    his_input = torch.LongTensor(his.size())
    his_input.copy_(his)

    ques_input = torch.LongTensor(ques.size())
    ques_input.copy_(ques)

    ques_emb = netW(ques_input, format='index')
    his_emb = netW(his_input, format='index')

    ques_hidden = repackage_hidden_new(ques_hidden, 1)
    hist_hidden = repackage_hidden_new(hist_hidden, his_input.size(1))

    encoder_feat, ques_hidden = netE(ques_emb, his_emb, img_input, \
                                        ques_hidden, hist_hidden, ind)

    _, ques_hidden = netG(encoder_feat.view(1, -1, opt.ninp), ques_hidden)

    # generate ans based on ques_hidden
    # using netG(x , ques_hidden)

    # Gumbel softmax to sample the output.
    ans_length = 16
    fake_onehot = []
    fake_idx = []
    noise_input = torch.FloatTensor(
Example #12
def train(epoch):
    netW.train()
    netE.train()
    netD.train()

    lr = adjust_learning_rate(optimizer, epoch, opt.lr)

    ques_hidden = netE.init_hidden(opt.batchSize)
    hist_hidden = netE.init_hidden(opt.batchSize)

    real_hidden = netD.init_hidden(opt.batchSize)
    wrong_hidden = netD.init_hidden(opt.batchSize)

    data_iter = iter(dataloader)

    average_loss = 0
    avg_dist_summary = np.zeros(3, dtype=float)
    smooth_avg_dist_summary = np.zeros(3, dtype=float)
    count = 0
    i = 0

    # size of data to work on
    early_stop = int(opt.early_stop / opt.batchSize)
    dataloader_size = min(len(dataloader), early_stop)
    while i < dataloader_size:

        t1 = time.time()
        data = next(data_iter)
        image, history, question, answer, answerT, answerLen, answerIdx, questionL, \
                                    opt_answerT, opt_answerLen, opt_answerIdx, opt_selected_probs = data

        batch_size = question.size(0)
        image = image.view(-1, img_feat_size)

        with torch.no_grad():
            img_input.resize_(image.size()).copy_(image)

        for rnd in range(10):
            netW.zero_grad()
            netE.zero_grad()
            netD.zero_grad()
            # get the corresponding round QA and history.
            ques = question[:,rnd,:].t()
            his = history[:,:rnd+1,:].clone().view(-1, his_length).t()

            ans = answer[:,rnd,:].t()
            tans = answerT[:,rnd,:].t()
            wrong_ans = opt_answerT[:,rnd,:].clone().view(-1, ans_length).t()
            opt_selected_probs_for_rnd = opt_selected_probs[:, rnd, :, :]

            real_len = answerLen[:,rnd]
            wrong_len = opt_answerLen[:,rnd,:].clone().view(-1)

            ques_input = torch.LongTensor(ques.size()).cuda()
            ques_input.copy_(ques)

            his_input = torch.LongTensor(his.size()).cuda()
            his_input.copy_(his)

            ans_input = torch.LongTensor(ans.size()).cuda()
            ans_input.copy_(ans)

            ans_target = torch.LongTensor(tans.size()).cuda()
            ans_target.copy_(tans)

            wrong_ans_input = torch.LongTensor(wrong_ans.size()).cuda()
            wrong_ans_input.copy_(wrong_ans)

            opt_selected_probs_for_rnd_input = torch.FloatTensor(opt_selected_probs_for_rnd.size()).cuda()
            opt_selected_probs_for_rnd_input.copy_(opt_selected_probs_for_rnd)

            # # sample in-batch negative index
            # batch_sample_idx = torch.zeros(batch_size, opt.neg_batch_sample, dtype=torch.long).cuda()
            # sample_batch_neg(answerIdx[:,rnd], opt_answerIdx[:,rnd,:], batch_sample_idx, opt.neg_batch_sample)

            ques_emb = netW(ques_input, format = 'index')
            his_emb = netW(his_input, format = 'index')

            ques_hidden = repackage_hidden_new(ques_hidden, batch_size)
            hist_hidden = repackage_hidden_new(hist_hidden, his_input.size(1))

            featD, ques_hidden = netE(ques_emb, his_emb, img_input, \
                                                ques_hidden, hist_hidden, rnd+1)

            ans_real_emb = netW(ans_target, format='index')
            ans_wrong_emb = netW(wrong_ans_input, format='index')

            real_hidden = repackage_hidden_new(real_hidden, batch_size)
            wrong_hidden = repackage_hidden_new(wrong_hidden, ans_wrong_emb.size(1))

            real_feat = netD(ans_real_emb, ans_target, real_hidden, vocab_size)
            wrong_feat = netD(ans_wrong_emb, wrong_ans_input, wrong_hidden, vocab_size)

            # batch_wrong_feat = wrong_feat.index_select(0, batch_sample_idx.view(-1))
            wrong_feat = wrong_feat.view(batch_size, -1, opt.ninp)
            # batch_wrong_feat = batch_wrong_feat.view(batch_size, -1, opt.ninp)

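            # n-pair style loss between the encoder feature and the real / candidate
            # answer features; the per-candidate probabilities loaded above are passed
            # through to critD alongside them.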
            nPairLoss, dist_summary, smooth_dist_summary = \
                critD(featD, real_feat, wrong_feat, opt_selected_probs_for_rnd_input)

            average_loss += nPairLoss.data.item()
            avg_dist_summary += dist_summary.cpu().detach().numpy()
            smooth_avg_dist_summary += smooth_dist_summary.cpu().detach().numpy()
            nPairLoss.backward()
            optimizer.step()
            count += 1

        i += 1
        if i % opt.log_interval == 0:
            average_loss /= count
            avg_dist_summary = avg_dist_summary/np.sum(avg_dist_summary)
            smooth_avg_dist_summary = smooth_avg_dist_summary/np.sum(smooth_avg_dist_summary)
            print("step {} / {} (epoch {}), g_loss {:.3f}, lr = {:.6f}, CEN dist: {}, CEN smooth: {}"\
                .format(i, len(dataloader), epoch, average_loss, lr, avg_dist_summary, smooth_avg_dist_summary))
            average_loss = 0
            avg_dist_summary = np.zeros(3, dtype=float)
            smooth_avg_dist_summary = np.zeros(3, dtype=float)
            count = 0

    return average_loss, lr