Example #1
def sample():
    embeder.eval()
    encoder.eval()
    decoder.eval()

    data_iter_val = iter(dataloader_val)
    hidden = encoder.init_hidden(opt.batchSize)

    i = 0
    while i < len(dataloader_val):
        data = data_iter_val.next()
        history, question, answer, answerT, questionL, opt_answer, opt_answerT, answer_ids  = data
        batch_size = question.size(0)

        for rnd in range(10):
            ques, tans = question[:,rnd,:].t(), answerT[:,rnd,:].t()
            #his = history[:,:rnd+1,:].clone().view(-1, his_length).t()
            ans = opt_answer[:,rnd,:,:].clone().view(-1, ans_length).t()
            gt_id = answer_ids[:,rnd]
            ques_input.data.resize_(ques.size()).copy_(ques)
            ans_input.data.resize_(ans.size()).copy_(ans)
            ans_target.data.resize_(tans.size()).copy_(tans)

            ques_emb = embeder(ques_input)
            hidden = repackage_hidden(hidden, batch_size)
            _, hidden = encoder(ques_emb, hidden)

            #output, hidden = decoder(ans_emb, hidden)
            #loss = crit(output, ans_target.view(-1,1))
            #average_loss += loss.data[0]
            #count += 1
            ans_sample_result = torch.Tensor(ans_length, batch_size)
            ans_sample.data.resize_((1, batch_size)).fill_(n_words)
            # sample the result.
            noise_input.data.resize_(ans_length, batch_size, n_words+1)
            noise_input.data.uniform_(0,1)
            for t in range(ans_length):
                ans_sample_embed = embeder(ans_sample)
                output, hidden = decoder(ans_sample_embed, hidden)

                prob = - output
                #_, idx = torch.max(prob, 1)
                one_hot, idx = sampler(prob, noise_input[t], 0.5)

                ans_sample.data.copy_(idx.data)
                ans_sample_result[t].copy_(ans_sample.data)

            ans_sample_txt = decode_txt(itow, ans_sample_result)
            ans_txt = decode_txt(itow, tans)
            ques_txt = decode_txt(itow, questionL[:,rnd,:].t())
            for j in range(len(ans_txt)):
                print('Q: %s --A: %s --Sampled: %s' %(ques_txt[j], ans_txt[j], ans_sample_txt[j]))

            pdb.set_trace()
        i += 1
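All of these examples lean on a repackage_hidden helper that is not shown on this page. A minimal sketch follows, modeled on the common PyTorch language-model idiom of detaching the recurrent state from the previous graph so backpropagation never reaches into earlier batches; the trimming to batch_size is an assumption, made because the examples allocate one hidden state for opt.batchSize and re-use it for possibly smaller final batches.

import torch

def repackage_hidden(h, batch_size):
    # Detach the hidden state from its history and trim it to the
    # current batch size. `h` is assumed to be a Tensor (GRU) or a
    # tuple of Tensors (LSTM), shaped (num_layers, batch, hidden_size).
    if isinstance(h, torch.Tensor):
        return h[:, :batch_size, :].detach()
    return tuple(repackage_hidden(v, batch_size) for v in h)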
Example #3
def val():
    netE.eval()
    netW.eval()
    netG.eval()

    data_iter_val = iter(dataloader_val)
    ques_hidden = netE.init_hidden(opt.batchSize)
    hist_hidden = netE.init_hidden(opt.batchSize)

    i = 0
    average_loss = 0
    rank_all_tmp = []

    while i < len(dataloader_val):
        data = data_iter_val.next()
        image, history, question, answer, answerT, questionL, opt_answer, \
                    opt_answerT, answer_ids, answerLen, opt_answerLen, img_id = data

        batch_size = question.size(0)
        image = image.view(-1, img_feat_size)
        img_input.data.resize_(image.size()).copy_(image)

        for rnd in range(10):
            # get the corresponding round QA and history.
            ques, tans = question[:, rnd, :].t(), opt_answerT[:, rnd, :].clone(
            ).view(-1, ans_length).t()
            his = history[:, :rnd + 1, :].clone().view(-1, his_length).t()
            ans = opt_answer[:, rnd, :, :].clone().view(-1, ans_length).t()
            gt_id = answer_ids[:, rnd]

            his_input.data.resize_(his.size()).copy_(his)
            ques_input.data.resize_(ques.size()).copy_(ques)
            ans_input.data.resize_(ans.size()).copy_(ans)
            ans_target.data.resize_(tans.size()).copy_(tans)

            gt_index.data.resize_(gt_id.size()).copy_(gt_id)

            ques_emb = netW(ques_input, format='index')
            his_emb = netW(his_input, format='index')

            ques_hidden = repackage_hidden(ques_hidden, batch_size)
            hist_hidden = repackage_hidden(hist_hidden, his_input.size(1))

            encoder_feat, ques_hidden = netE(ques_emb, his_emb, img_input, \
                                                ques_hidden, hist_hidden, rnd+1)

            _, ques_hidden = netG(encoder_feat.view(1, -1, opt.ninp),
                                  ques_hidden)

            hidden_replicated = []
            for hid in ques_hidden:
                hidden_replicated.append(hid.view(opt.nlayers, batch_size, 1, \
                    opt.nhid).expand(opt.nlayers, batch_size, 100, opt.nhid).clone().view(opt.nlayers, -1, opt.nhid))
            hidden_replicated = tuple(hidden_replicated)

            ans_emb = netW(ans_input, format='index')

            output, _ = netG(ans_emb, hidden_replicated)
            logprob = -output
            logprob_select = torch.gather(logprob, 1, ans_target.view(-1, 1))

            mask = ans_target.data.eq(0)  # generate the mask
            if isinstance(logprob, Variable):
                mask = Variable(mask, volatile=logprob.volatile)
            logprob_select.masked_fill_(mask.view_as(logprob_select), 0)

            prob = logprob_select.view(ans_length, -1,
                                       100).sum(0).view(-1, 100)

            for b in range(batch_size):
                gt_index.data[b] = gt_index.data[b] + b * 100

            gt_score = prob.view(-1).index_select(0, gt_index)
            sort_score, sort_idx = torch.sort(prob, 1)

            count = sort_score.lt(gt_score.view(-1, 1).expand_as(sort_score))
            rank = count.sum(1) + 1
            rank_all_tmp += list(rank.view(-1).data.cpu().numpy())

        i += 1

    return rank_all_tmp, average_loss
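The rank bookkeeping at the end of val() deserves a note: prob holds, for every dialog, the summed per-token negative log-probability of each of the 100 candidate answers, so lower is better, and the ground truth's rank is one plus the number of candidates with a strictly smaller score. A self-contained toy run of the same lines:

import torch

prob = torch.tensor([[3.2, 0.9, 1.5, 2.1]])   # 1 dialog, 4 candidates, lower = better
gt_score = prob[0, 2].view(1)                 # ground truth is candidate 2 (score 1.5)

sort_score, _ = torch.sort(prob, 1)           # ascending
count = sort_score.lt(gt_score.view(-1, 1).expand_as(sort_score))
rank = count.sum(1) + 1                       # only 0.9 beats it, so rank 2
print(rank)                                   # tensor([2])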
Example #4
def train(epoch):
    netW.train()
    netE.train()
    netG.train()

    lr = adjust_learning_rate(optimizer, epoch, opt.lr)
    data_iter = iter(dataloader)

    ques_hidden = netE.init_hidden(opt.batchSize)
    hist_hidden = netE.init_hidden(opt.batchSize)
    average_loss = 0
    count = 0
    i = 0
    while i < len(dataloader):
        data = data_iter.next()
        image, history, question, answer, answerT, answerLen, answerIdx, \
        questionL, negAnswer, negAnswerLen, negAnswerIdx = data

        batch_size = question.size(0)
        image = image.view(-1, img_feat_size)
        img_input.data.resize_(image.size()).copy_(image)

        for rnd in range(10):
            ques = question[:, rnd, :].t()
            his = history[:, :rnd + 1, :].clone().view(-1, his_length).t()
            ans, tans = answer[:, rnd, :].t(), answerT[:, rnd, :].t()

            his_input.data.resize_(his.size()).copy_(his)
            ques_input.data.resize_(ques.size()).copy_(ques)
            ans_input.data.resize_(ans.size()).copy_(ans)
            ans_target.data.resize_(tans.size()).copy_(tans)

            ques_emb = netW(ques_input, format='index')
            his_emb = netW(his_input, format='index')

            ques_hidden = repackage_hidden(ques_hidden, batch_size)
            hist_hidden = repackage_hidden(hist_hidden, his_input.size(1))

            encoder_feat, ques_hidden = netE(ques_emb, his_emb, img_input, \
                                                ques_hidden, hist_hidden, rnd+1)

            _, ques_hidden = netG(encoder_feat.view(1, -1, opt.ninp),
                                  ques_hidden)

            ans_emb = netW(ans_input)
            logprob, ques_hidden = netG(ans_emb, ques_hidden)
            loss = critG(logprob, ans_target.view(-1, 1))

            loss = loss / torch.sum(ans_target.data.gt(0))
            average_loss += loss.data[0]
            # do backward.
            netW.zero_grad()
            netE.zero_grad()
            netG.zero_grad()
            loss.backward()

            optimizer.step()
            count += 1
        i += 1
        if i % opt.log_interval == 0:
            average_loss /= count
            print("step {} / {} (epoch {}), g_loss {:.3f}, lr = {:.6f}"\
                .format(i, len(dataloader), epoch, average_loss, lr))
            average_loss = 0
            count = 0

    return average_loss
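adjust_learning_rate is another helper that is not defined on this page. A common implementation, sketched here under the assumption of exponential decay per epoch (the decay factor is hypothetical; the original may use a step schedule instead):

def adjust_learning_rate(optimizer, epoch, base_lr, decay=0.97):
    # Exponentially decay the learning rate with the epoch index and
    # write the new value into every parameter group of the optimizer.
    lr = base_lr * (decay ** epoch)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr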
Example #5
def val():
    netE.eval()
    netW.eval()
    netD.eval()

    n_neg = 100
    data_iter_val = iter(dataloader_val)
    ques_hidden = netE.init_hidden(opt.batchSize)
    hist_hidden = netE.init_hidden(opt.batchSize)

    opt_hidden = netD.init_hidden(opt.batchSize)
    i = 0

    average_loss = 0
    rank_all_tmp = []

    while i < len(dataloader_val):
        data = data_iter_val.next()
        image, history, question, answer, answerT, questionL, opt_answer, \
                opt_answerT, answer_ids, answerLen, opt_answerLen, img_id  = data

        batch_size = question.size(0)
        image = image.view(-1, img_feat_size)
        #image = l2_norm(image)
        img_input.data.resize_(image.size()).copy_(image)

        for rnd in range(10):
            # get the corresponding round QA and history.
            ques = question[:,rnd,:].t()
            his = history[:,:rnd+1,:].clone().view(-1, his_length).t()

            opt_ans = opt_answerT[:,rnd,:].clone().view(-1, ans_length).t()
            gt_id = answer_ids[:,rnd]

            ques_input.data.resize_(ques.size()).copy_(ques)
            his_input.data.resize_(his.size()).copy_(his)

            opt_ans_input.data.resize_(opt_ans.size()).copy_(opt_ans)
            gt_index.data.resize_(gt_id.size()).copy_(gt_id)
            opt_len = opt_answerLen[:,rnd,:].clone().view(-1)

            ques_emb = netW(ques_input, format = 'index')
            his_emb = netW(his_input, format = 'index')

            ques_hidden = repackage_hidden(ques_hidden, batch_size)
            hist_hidden = repackage_hidden(hist_hidden, his_input.size(1))

            featD, ques_hidden = netE(ques_emb, his_emb, img_input, \
                                                ques_hidden, hist_hidden, rnd+1)

            opt_ans_emb = netW(opt_ans_input, format = 'index')
            opt_hidden = repackage_hidden(opt_hidden, opt_ans_input.size(1))
            opt_feat = netD(opt_ans_emb, opt_ans_input, opt_hidden, vocab_size)
            opt_feat = opt_feat.view(batch_size, -1, opt.ninp)

            #ans_emb = ans_emb.view(ans_length, -1, 100, opt.nhid)
            featD = featD.view(-1, opt.ninp, 1)
            score = torch.bmm(opt_feat, featD)
            score = score.view(-1, 100)

            for b in range(batch_size):
                gt_index.data[b] = gt_index.data[b] + b*100

            gt_score = score.view(-1).index_select(0, gt_index)
            sort_score, sort_idx = torch.sort(score, 1, descending=True)

            count = sort_score.gt(gt_score.view(-1,1).expand_as(sort_score))
            rank = count.sum(1) + 1
            rank_all_tmp += list(rank.view(-1).data.cpu().numpy())
            
        i += 1
        sys.stdout.write('Evaluating: {:d}/{:d}  \r' \
          .format(i, len(dataloader_val)))
        sys.stdout.flush()

    return rank_all_tmp
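val() returns only the raw list of ground-truth ranks; turning that list into the retrieval metrics printed by the later examples (MRR, R@1/5/10, mean rank) is a one-liner each:

import numpy as np

def rank_metrics(rank_all):
    ranks = np.array(rank_all, dtype='float')
    return {'mrr':  np.mean(1.0 / ranks),
            'R@1':  np.mean(ranks == 1),
            'R@5':  np.mean(ranks <= 5),
            'R@10': np.mean(ranks <= 10),
            'mean': np.mean(ranks)}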
Example #6
def train(epoch):
    netW.train()
    netE.train()
    netD.train()

    lr = adjust_learning_rate(optimizer, epoch, opt.lr)

    ques_hidden = netE.init_hidden(opt.batchSize)
    hist_hidden = netE.init_hidden(opt.batchSize)

    real_hidden = netD.init_hidden(opt.batchSize)
    wrong_hidden = netD.init_hidden(opt.batchSize)

    data_iter = iter(dataloader)

    average_loss = 0
    count = 0
    i = 0

    while i < len(dataloader):

        t1 = time.time()
        data = data_iter.next()
        image, history, question, answer, answerT, answerLen, answerIdx, questionL, \
                                    opt_answerT, opt_answerLen, opt_answerIdx = data

        batch_size = question.size(0)
        image = image.view(-1, img_feat_size)
        img_input.data.resize_(image.size()).copy_(image)

        for rnd in range(10):
            netW.zero_grad()
            netE.zero_grad()
            netD.zero_grad()
            # get the corresponding round QA and history.
            ques = question[:,rnd,:].t()
            his = history[:,:rnd+1,:].clone().view(-1, his_length).t()

            ans = answer[:,rnd,:].t()
            tans = answerT[:,rnd,:].t()
            wrong_ans = opt_answerT[:,rnd,:].clone().view(-1, ans_length).t()

            real_len = answerLen[:,rnd]
            wrong_len = opt_answerLen[:,rnd,:].clone().view(-1)

            ques_input.data.resize_(ques.size()).copy_(ques)
            his_input.data.resize_(his.size()).copy_(his)

            ans_input.data.resize_(ans.size()).copy_(ans)
            ans_target.data.resize_(tans.size()).copy_(tans)
            wrong_ans_input.data.resize_(wrong_ans.size()).copy_(wrong_ans)

            # sample in-batch negative index
            batch_sample_idx.data.resize_(batch_size, opt.neg_batch_sample).zero_()
            sample_batch_neg(answerIdx[:,rnd], opt_answerIdx[:,rnd,:], batch_sample_idx, opt.neg_batch_sample)

            ques_emb = netW(ques_input, format = 'index')
            his_emb = netW(his_input, format = 'index')

            ques_hidden = repackage_hidden(ques_hidden, batch_size)
            hist_hidden = repackage_hidden(hist_hidden, his_input.size(1))

            featD, ques_hidden = netE(ques_emb, his_emb, img_input, \
                                                ques_hidden, hist_hidden, rnd+1)

            ans_real_emb = netW(ans_target, format='index')
            ans_wrong_emb = netW(wrong_ans_input, format='index')

            real_hidden = repackage_hidden(real_hidden, batch_size)
            wrong_hidden = repackage_hidden(wrong_hidden, ans_wrong_emb.size(1))

            real_feat = netD(ans_real_emb, ans_target, real_hidden, vocab_size)
            wrong_feat = netD(ans_wrong_emb, wrong_ans_input, wrong_hidden, vocab_size)

            batch_wrong_feat = wrong_feat.index_select(0, batch_sample_idx.view(-1))
            wrong_feat = wrong_feat.view(batch_size, -1, opt.ninp)
            batch_wrong_feat = batch_wrong_feat.view(batch_size, -1, opt.ninp)

            nPairLoss = critD(featD, real_feat, wrong_feat, batch_wrong_feat)

            average_loss += nPairLoss.data[0]
            nPairLoss.backward()
            optimizer.step()
            count += 1

        i += 1
        if i % opt.log_interval == 0:
            average_loss /= count
            print("step {} / {} (epoch {}), g_loss {:.3f}, lr = {:.6f}"\
                .format(i, len(dataloader), epoch, average_loss, lr))
            average_loss = 0
            count = 0

    return average_loss
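sample_batch_neg fills batch_sample_idx with flat indices into the (batch_size x 100) pool of wrong answers, drawn from other dialogs in the batch, which is how wrong_feat.index_select(0, batch_sample_idx.view(-1)) picks up the extra in-batch negatives. The helper itself is not shown; the following is one plausible reconstruction, not the original (the rejection test against the ground-truth answer id is an assumption):

import random

def sample_batch_neg(answer_idx, neg_answer_idx, sample_idx, num_sample):
    # answer_idx: (batch,) ground-truth answer ids
    # neg_answer_idx: (batch, n_options) candidate answer ids
    # sample_idx: (batch, num_sample) output buffer of flat indices
    batch_size, n_options = neg_answer_idx.size()
    for b in range(batch_size):
        for n in range(num_sample):
            while True:
                b2 = random.randrange(batch_size)
                o = random.randrange(n_options)
                if b2 != b and int(neg_answer_idx[b2, o]) != int(answer_idx[b]):
                    break
            sample_idx[b, n] = b2 * n_options + o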
Example #7
def train(epoch):
    netW_d.train(), netE_d.train(), netE_g.train()
    netD.train(), netG.train(), netW_g.train()

    fake_len = torch.LongTensor(opt.batchSize)
    fake_len = fake_len.cuda()

    n_neg = opt.negative_sample
    ques_hidden1 = netE_d.init_hidden(opt.batchSize)
    ques_hidden2 = netE_g.init_hidden(opt.batchSize)

    hist_hidden1 = netE_d.init_hidden(opt.batchSize)
    hist_hidden2 = netE_g.init_hidden(opt.batchSize)

    real_hidden = netD.init_hidden(opt.batchSize)
    wrong_hidden = netD.init_hidden(opt.batchSize)
    fake_hidden = netD.init_hidden(opt.batchSize)

    data_iter = iter(dataloader)

    err_d = 0
    err_g = 0
    err_lm = 0
    average_loss = 0
    count = 0
    i = 0
    loss_store = []
    while i < len(dataloader):
        t1 = time.time()
        data = data_iter.next()
        image, history, question, answer, answerT, answerLen, answerIdx, questionL, \
                                    opt_answerT, opt_answerLen, opt_answerIdx = data
        batch_size = question.size(0)
        image = image.view(-1, 512)
        img_input.data.resize_(image.size()).copy_(image)

        err_d_tmp = 0
        err_g_tmp = 0
        err_lm_tmp = 0
        err_d_fake_tmp = 0
        err_g_fake_tmp = 0
        for rnd in range(10):
            # get the corresponding round QA and history.
            ques = question[:,rnd,:].t()
            his = history[:,:rnd+1,:].clone().view(-1, his_length).t()

            ans = answer[:,rnd,:].t()
            tans = answerT[:,rnd,:].t()
            wrong_ans = opt_answerT[:,rnd,:].clone().view(-1, ans_length).t()

            real_len = answerLen[:,rnd].long()
            wrong_len = opt_answerLen[:,rnd,:].clone().view(-1)

            ques_input.data.resize_(ques.size()).copy_(ques)
            his_input.data.resize_(his.size()).copy_(his)

            ans_input.data.resize_(ans.size()).copy_(ans)
            ans_target.data.resize_(tans.size()).copy_(tans)
            wrong_ans_input.data.resize_(wrong_ans.size()).copy_(wrong_ans)

            batch_sample_idx.data.resize_(batch_size, opt.neg_batch_sample).zero_()
            sample_batch_neg(answerIdx[:,rnd], opt_answerIdx[:,rnd,:], batch_sample_idx, opt.neg_batch_sample)

            # -----------------------------------------
            # update the Generator using MLE loss.
            # -----------------------------------------
            if opt.update_LM:
                ques_emb_g = netW_g(ques_input, format = 'index')
                his_emb_g = netW_g(his_input, format = 'index')

                ques_hidden1 = repackage_hidden(ques_hidden1, batch_size)
                hist_hidden1 = repackage_hidden(hist_hidden1, his_emb_g.size(1))

                featG, ques_hidden1 = netE_g(ques_emb_g, his_emb_g, img_input, \
                                                    ques_hidden1, hist_hidden1, rnd+1)

                _, ques_hidden1 = netG(featG.view(1, -1, opt.ninp), ques_hidden1)
                # MLE loss for generator
                ans_emb = netW_g(ans_input)
                logprob, _ = netG(ans_emb, ques_hidden1)
                lm_loss = critLM(logprob, ans_target.view(-1, 1))
                lm_loss = lm_loss / torch.sum(ans_target.data.gt(0))

                netW_g.zero_grad()
                netG.zero_grad()
                netE_g.zero_grad()

                lm_loss.backward()
                optimizerLM.step()
                err_lm += lm_loss.data[0]
                err_lm_tmp += lm_loss.data[0]

            # sample the answer using gumble softmax sampler.
            ques_emb_g = netW_g(ques_input, format = 'index')
            his_emb_g = netW_g(his_input, format = 'index')

            ques_hidden1 = repackage_hidden(ques_hidden1, batch_size)
            hist_hidden1 = repackage_hidden(hist_hidden1, his_emb_g.size(1))

            featG, ques_hidden1 = netE_g(ques_emb_g, his_emb_g, img_input, \
                                                ques_hidden1, hist_hidden1, rnd+1)

            _, ques_hidden1 = netG(featG.view(1, -1, opt.ninp), ques_hidden1)

            # Gumble softmax to sample the output.
            fake_onehot = []
            fake_idx = []
            noise_input.data.resize_(ans_length, batch_size, vocab_size+1)
            noise_input.data.uniform_(0,1)

            ans_sample = ans_input[0]
            for di in range(ans_length):
                ans_emb = netW_g(ans_sample, format = 'index')
                logprob, ques_hidden1 = netG(ans_emb.view(1,-1,opt.ninp), ques_hidden1)
                one_hot, idx = sampler(logprob, noise_input[di], opt.gumble_weight)
                fake_onehot.append(one_hot.view(1, -1, vocab_size+1))
                fake_idx.append(idx)
                if di+1 < ans_length:
                    ans_sample = idx

            # convert the list into the tensor variable.
            fake_onehot = torch.cat(fake_onehot, 0)
            fake_idx = torch.cat(fake_idx,0)

            fake_len.resize_(batch_size).fill_(ans_length-1)
            for di in range(ans_length-1, 0, -1):
                fake_len.masked_fill_(fake_idx.data[di].eq(vocab_size), di)

            # generate fake mask.
            fake_mask.data.resize_(fake_idx.size()).fill_(1)
            # get the real, wrong and fake index.
            for b in range(batch_size):
                fake_mask.data[:fake_len[b]+1, b] = 0

            # apply the mask on the fake_idx.
            fake_idx.masked_fill_(fake_mask, 0)

            # get the fake diff mask.
            #fake_diff_mask = torch.sum(fake_idx == ans_target, 0) != 0
            fake_onehot = fake_onehot.view(-1, vocab_size+1)
            
            ######################################
            # Discriminative trained generative model.
            ######################################
            # forward the discriminator again.
            ques_emb_d = netW_d(ques_input, format = 'index')
            his_emb_d = netW_d(his_input, format = 'index')

            ques_hidden2 = repackage_hidden(ques_hidden2, batch_size)
            hist_hidden2 = repackage_hidden(hist_hidden2, his_emb_d.size(1))

            featD, _ = netE_d(ques_emb_d, his_emb_d, img_input, \
                                        ques_hidden2, hist_hidden2, rnd+1)

            ans_real_emb = netW_d(ans_target, format='index')
            #ans_wrong_emb = netW_d(wrong_ans_input, format='index')
            ans_fake_emb = netW_d(fake_onehot, format='onehot')
            ans_fake_emb = ans_fake_emb.view(ans_length, -1, opt.ninp)

            real_hidden = repackage_hidden(real_hidden, batch_size)
            #wrong_hidden = repackage_hidden(wrong_hidden, ans_wrong_emb.size(1))
            fake_hidden = repackage_hidden(fake_hidden, batch_size)

            fake_feat = netD(ans_fake_emb, fake_idx, fake_hidden, vocab_size)
            real_feat = netD(ans_real_emb, ans_target, real_hidden, vocab_size)

            d_g_loss, g_fake = critG(featD, real_feat, fake_feat)#, fake_diff_mask.detach())

            netW_g.zero_grad()
            netG.zero_grad()
            netE_g.zero_grad()

            d_g_loss.backward()
            optimizerG.step()

            err_g += d_g_loss.data[0]
            err_g_tmp += d_g_loss.data[0]
            err_g_fake_tmp += g_fake

            count += 1

        i += 1
        loss_store.append({'iter':i, 'err_lm':err_lm_tmp/10, 'err_d':err_d_tmp/10, 'err_g':err_g_tmp/10, \
                            'd_fake': err_d_fake_tmp/10, 'g_fake':err_g_fake_tmp/10})

        if i % 20 == 0:
            print ('Epoch:%d %d/%d, err_lm %4f, err_d %4f, err_g %4f, d_fake %4f, g_fake %4f' \
                % (epoch, i, len(dataloader), err_lm_tmp/10, err_d_tmp/10, err_g_tmp/10, err_d_fake_tmp/10, \
                    err_g_fake_tmp/10))


    #average_loss = average_loss / count
    err_g = err_g / count
    err_d = err_d / count
    err_lm = err_lm / count

    return err_lm, err_d, err_g, loss_store
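The sampler called above implements what the comments call "gumble softmax" sampling: the pre-drawn uniform noise is turned into Gumbel noise, added to the scores, and a temperature-controlled softmax yields both a differentiable soft one-hot and a hard index to feed back into the decoder. A minimal sketch, assuming the first argument is a log-probability-like score over the vocabulary and the third argument (opt.gumble_weight) is the temperature:

import torch

def sampler(logprob, noise, temperature):
    # Gumbel(0, 1) noise from the provided uniform(0, 1) draws.
    gumbel = -torch.log(-torch.log(noise + 1e-20) + 1e-20)
    y = torch.softmax((logprob + gumbel) / temperature, dim=-1)
    _, idx = torch.max(y, dim=-1)   # hard sample for the next input step
    return y, idx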
Example #8
def eval():

    netE.eval()
    netW.eval()
    netG.eval()

    data_iter_val = iter(dataloader_val)
    ques_hidden = netE.init_hidden(opt.batchSize)
    hist_hidden = netE.init_hidden(opt.batchSize)

    i = 0
    display_count = 0
    average_loss = 0
    rank_all_tmp = []
    while i < len(dataloader_val):
        data = data_iter_val.next()
        image, history, question, answer, answerT, questionL, opt_answer, \
                opt_answerT, answer_ids, answerLen, opt_answerLen, img_id  = data

        batch_size = question.size(0)
        image = image.view(-1, 512)
        img_input.data.resize_(image.size()).copy_(image)

        for rnd in range(10):
            # get the corresponding round QA and history.
            ques, tans = question[:, rnd, :].t(), opt_answerT[:, rnd, :].clone(
            ).view(-1, ans_length).t()
            his = history[:, :rnd + 1, :].clone().view(-1, his_length).t()
            ans = opt_answer[:, rnd, :, :].clone().view(-1, ans_length).t()
            gt_id = answer_ids[:, rnd]

            his_input.data.resize_(his.size()).copy_(his)
            ques_input.data.resize_(ques.size()).copy_(ques)
            ans_input.data.resize_(ans.size()).copy_(ans)
            ans_target.data.resize_(tans.size()).copy_(tans)

            gt_index.data.resize_(gt_id.size()).copy_(gt_id)

            ques_emb = netW(ques_input, format='index')
            his_emb = netW(his_input, format='index')

            ques_hidden = repackage_hidden(ques_hidden, batch_size)
            hist_hidden = repackage_hidden(hist_hidden, his_input.size(1))

            encoder_feat, ques_hidden = netE(ques_emb, his_emb, img_input, \
                                                ques_hidden, hist_hidden, rnd+1)

            _, ques_hidden = netG(encoder_feat.view(1, -1, opt.ninp),
                                  ques_hidden)

            #ans_emb = ans_emb.view(ans_length, -1, 100, opt.nhid)
            ans_score = torch.FloatTensor(batch_size, 100).zero_()
            # extend the hidden
            hidden_replicated = []
            for hid in ques_hidden:
                hidden_replicated.append(hid.view(opt.nlayers, batch_size, 1, \
                    opt.nhid).expand(opt.nlayers, batch_size, 100, opt.nhid).clone().view(opt.nlayers, -1, opt.nhid))
            hidden_replicated = tuple(hidden_replicated)

            ans_emb = netW(ans_input, format='index')

            output, _ = netG(ans_emb, hidden_replicated)
            logprob = -output
            logprob_select = torch.gather(logprob, 1, ans_target.view(-1, 1))

            mask = ans_target.data.eq(0)  # generate the mask
            if isinstance(logprob, Variable):
                mask = Variable(mask, volatile=logprob.volatile)
            logprob_select.masked_fill_(mask.view_as(logprob_select), 0)

            prob = logprob_select.view(ans_length, -1,
                                       100).sum(0).view(-1, 100)

            for b in range(batch_size):
                gt_index.data[b] = gt_index.data[b] + b * 100

            gt_score = prob.view(-1).index_select(0, gt_index)
            sort_score, sort_idx = torch.sort(prob, 1)

            count = sort_score.lt(gt_score.view(-1, 1).expand_as(sort_score))
            rank = count.sum(1) + 1
            rank_all_tmp += list(rank.view(-1).data.cpu().numpy())

        i += 1
        sys.stdout.write('Evaluating: {:d}/{:d}  \r' \
          .format(i, len(dataloader_val)))

        if i % 50 == 0:
            R1 = np.sum(np.array(rank_all_tmp) == 1) / float(len(rank_all_tmp))
            R5 = np.sum(np.array(rank_all_tmp) <= 5) / float(len(rank_all_tmp))
            R10 = np.sum(np.array(rank_all_tmp) <= 10) / float(
                len(rank_all_tmp))
            ave = np.sum(np.array(rank_all_tmp)) / float(len(rank_all_tmp))
            mrr = np.sum(1 / (np.array(rank_all_tmp, dtype='float'))) / float(
                len(rank_all_tmp))
            print('%d/%d: mrr: %f R1: %f R5 %f R10 %f Mean %f' %
                  (i, len(dataloader_val), mrr, R1, R5, R10, ave))

    return rank_all_tmp
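The gather-and-mask block in the middle of eval() scores each candidate answer by summing its per-token log-probabilities while zeroing out padding positions (token id 0). The same three steps in isolation, with toy shapes:

import torch

logprob = torch.rand(4 * 2, 5)                            # 4 steps x 2 sequences, vocab 5
target = torch.tensor([[2, 3], [4, 0], [1, 0], [0, 0]])   # 0 marks padding

picked = torch.gather(logprob, 1, target.view(-1, 1))     # per-token scores, (8, 1)
mask = target.view(-1, 1).eq(0)                           # padding positions
picked = picked.masked_fill(mask, 0)                      # drop them from the sum
seq_score = picked.view(4, 2).sum(0)                      # one total per sequence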
Example #10
def eval_val():
    netW.eval()
    netE.eval()
    netD.eval()

    data_iter_test = iter(dataloader_test)
    ques_hidden = netE.init_bi_hidden(opt.batchSize)
    hist_hidden = netE.init_bi_hidden(opt.batchSize)

    opt_hidden = netD.init_bi_hidden(opt.batchSize)
    i = 0
    display_count = 0
    average_loss = 0

    img_atten = torch.FloatTensor(100 * 30, 10, 7, 7)
    result = []
    while i < len(dataloader_test):  #len(1000):
        data = data_iter_test.next()
        image, history, question, answer, answerT, questionL, opt_answer, \
                opt_answerT, answer_ids, answerLen, opt_answerLen, img_id  = data

        batch_size = question.size(0)
        image = image.view(-1, 2048)
        img_input.data.resize_(image.size()).copy_(image)
        # save_tmp = [[] for j in range(batch_size)]

        for rnd in range(10):
            result_tmp = {'image_id': 1, 'round_id': 1, 'ranks': []}
            # get the corresponding round QA and history.
            ques = question[:, rnd, :].t()
            his = history[:, :rnd + 1, :].clone().view(-1, his_length).t()

            opt_ans = opt_answerT[:, rnd, :].clone().view(-1, ans_length).t()
            # gt_id = answer_ids[:,rnd]

            ques_input.data.resize_(ques.size()).copy_(ques)
            his_input.data.resize_(his.size()).copy_(his)

            # gt_index.data.resize_(gt_id.size()).copy_(gt_id)
            opt_ans_input.data.resize_(opt_ans.size()).copy_(opt_ans)

            # opt_len = opt_answerLen[:,rnd,:].clone().view(-1)

            ques_emb = netW(ques_input, format='index')
            his_emb = netW(his_input, format='index')

            ques_hidden = repackage_hidden(ques_hidden, batch_size)
            hist_hidden = repackage_hidden(hist_hidden, his_input.size(1))

            featD, ques_hidden = netE(ques_emb, his_emb, img_input, \
                                                                    ques_hidden, hist_hidden, rnd+1)

            #img_atten[i*batch_size:(i+1)*batch_size, rnd, :] = img_atten_weight.data.view(batch_size, 7, 7)

            opt_ans_emb = netW(opt_ans_input, format='index')
            opt_hidden = repackage_hidden(opt_hidden, opt_ans_input.size(1))
            opt_feat = netD(opt_ans_emb, opt_ans_input, opt_hidden, vocab_size)
            opt_feat = opt_feat.view(batch_size, -1, opt.ninp)

            featD = featD.view(-1, opt.ninp, 1)
            score = torch.bmm(opt_feat, featD)
            score = score.view(-1, 100)

            sort_score, sort_idx = torch.sort(score, 1, descending=True)
            sort_idx = sort_idx + 1
            sort_idx = sort_idx.squeeze().cpu().data.type(torch.IntTensor)
            sort_result = list(sort_idx)
            sort_final = [1] * 100
            for rank, w in enumerate(sort_result):
                sort_final[w - 1] = rank + 1

            result_tmp['image_id'] = int(img_id)
            result_tmp['round_id'] = int(rnd + 1)
            result_tmp['ranks'] = sort_final
            # sort_result = [rank +'\n' for rank in sort_result]
            result.append(result_tmp)

        i += 1
        if i % 100 == 0:
            print(i)

    return result
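The small loop that builds sort_final inverts the sorted ordering: sort_result lists option ids from best to worst, while the result entry needs ranks[k] to be the rank of option k+1. In isolation:

sort_result = [3, 1, 2]          # option 3 scored best, then option 1, then 2
sort_final = [1] * 3
for rank, w in enumerate(sort_result):
    sort_final[w - 1] = rank + 1
print(sort_final)                # [2, 3, 1]: option 1 has rank 2, option 2 rank 3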
Example #11
def val():
    netE_g.eval()
    netW_g.eval()
    netG.eval()

    n_neg = 100
    ques_hidden1 = netE_g.init_hidden(opt.batchSize)

    hist_hidden1 = netE_g.init_hidden(opt.batchSize)

    bar = progressbar.ProgressBar(max_value=300)  #len(dataloader))
    data_iter_val = iter(dataloader_val)

    count = 0
    i = 0

    result_all = []

    while i < 300:  #len(dataloader_val):
        data = data_iter_val.next()
        image, history, question, answer, answerT, questionL, opt_answer, \
                    opt_answerT, answer_ids, answerLen, opt_answerLen, img_id  = data

        batch_size = question.size(0)
        image = image.view(-1, 512)
        img_input.data.resize_(image.size()).copy_(image)

        save_tmp = [[] for j in range(batch_size)]
        for rnd in range(10):

            # get the corresponding round QA and history.
            ques, tans = question[:, rnd, :].t(), answerT[:, rnd, :].t()
            his = history[:, :rnd + 1, :].clone().view(-1, his_length).t()

            opt_ans = opt_answer[:, rnd, :, :].clone().view(-1, ans_length).t()
            opt_tans = opt_answerT[:, rnd, :].clone().view(-1, ans_length).t()
            gt_id = answer_ids[:, rnd]
            opt_len = opt_answerLen[:, rnd, :].clone().view(-1)

            ques_input.data.resize_(ques.size()).copy_(ques)
            his_input.data.resize_(his.size()).copy_(his)
            opt_ans_input.data.resize_(opt_ans.size()).copy_(opt_ans)
            opt_ans_target.data.resize_(opt_tans.size()).copy_(opt_tans)
            gt_index.data.resize_(gt_id.size()).copy_(gt_id)

            ques_emb_g = netW_g(ques_input, format='index')
            his_emb_g = netW_g(his_input, format='index')

            ques_hidden1 = repackage_hidden(ques_hidden1, batch_size)

            hist_hidden1 = repackage_hidden(hist_hidden1, his_emb_g.size(1))

            featG, ques_hidden1 = netE_g(ques_emb_g, his_emb_g, img_input, \
                                                ques_hidden1, hist_hidden1, rnd+1)

            #featD = l2_norm(featD)
            # Evaluate the Generator:
            _, ques_hidden1 = netG(featG.view(1, -1, opt.ninp), ques_hidden1)
            #_, ques_hidden = netG(encoder_feat.view(1,-1,opt.ninp), ques_hidden)
            # extend the hidden
            sample_ans_input.data.resize_((1, batch_size)).fill_(vocab_size)

            sample_opt = {'beam_size': 1}

            seq, seqLogprobs = netG.sample(netW_g, sample_ans_input,
                                           ques_hidden1, sample_opt)
            ans_sample_txt = decode_txt(itow, seq.t())
            ans_txt = decode_txt(itow, tans)
            ques_txt = decode_txt(itow, questionL[:, rnd, :].t())
            '''
            for j in range(len(ans_txt)):
                print('Q: %s --A: %s --Sampled: %s' %(ques_txt[j], ans_txt[j], ans_sample_txt[j]))
            
            ans_sample_z = [[] for z in range(batch_size)]
            for m in range(5):
                ans_sample_result = torch.Tensor(ans_length, batch_size)
                # sample the result.
                noise_input.data.resize_(ans_length, batch_size, vocab_size+1)
                noise_input.data.uniform_(0,1)
                for t in range(ans_length):
                    ans_emb = netW_g(sample_ans_input, format = 'index')
                    if t == 0:
                        logprob, ques_hidden2 = netG(ans_emb.view(1,-1,opt.ninp), ques_hidden1)
                    else:
                        logprob, ques_hidden2 = netG(ans_emb.view(1,-1,opt.ninp), ques_hidden2)

                    one_hot, idx = sampler(logprob, noise_input[t], opt.gumble_weight)

                    sample_ans_input.data.copy_(idx.data)
                    ans_sample_result[t].copy_(idx.data)

                ans_sample_txt = decode_txt(itow, ans_sample_result)
                for ii in range(batch_size):
                    ans_sample_z[ii].append(ans_sample_txt[ii])
            '''
            ans_txt = decode_txt(itow, tans)
            ques_txt = decode_txt(itow, questionL[:, rnd, :].t())
            #for j in range(len(ans_txt)):
            #    print('Q: %s --A: %s --Sampled: %s' %(ques_txt[j], ans_txt[j], ans_sample_txt[j]))

            for j in range(batch_size):
                save_tmp[j].append({'ques':ques_txt[j], 'gt_ans':ans_txt[j], \
                            'sample_ans':ans_sample_txt[j], 'rnd':rnd, 'img_id':img_id[j]})
        i += 1
        bar.update(i)

        result_all += save_tmp

    return result_all
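decode_txt, used throughout to turn index tensors back into readable strings via the index-to-word table itow, is not defined on this page either. A sketch under two assumptions the original may handle differently: id 0 is padding, ids missing from the table (such as the extra start/stop id) end the sentence, and the table is keyed by string ids, as is common for vocabularies loaded from JSON:

def decode_txt(itow, seq):
    # seq: (seq_len, batch) tensor of word ids -> list of strings.
    sents = []
    for b in range(seq.size(1)):
        words = []
        for t in range(seq.size(0)):
            ix = int(seq[t, b])
            if ix == 0 or str(ix) not in itow:
                break
            words.append(itow[str(ix)])
        sents.append(' '.join(words))
    return sents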
Example #13
def eval():
    netW.eval()
    netE.eval()
    netD.eval()

    data_iter_val = iter(dataloader_val)
    ques_hidden = netE.init_hidden(opt.batchSize)
    hist_hidden = netE.init_hidden(opt.batchSize)

    opt_hidden = netD.init_hidden(opt.batchSize)
    i = 0
    display_count = 0
    average_loss = 0
    rank_all_tmp = []
    result_all = []
    img_atten = torch.FloatTensor(100 * 30, 10, 7, 7)
    while i < len(dataloader_val):#len(1000):
        data = data_iter_val.next()
        image, history, question, answer, answerT, questionL, opt_answer, \
                opt_answerT, answer_ids, answerLen, opt_answerLen, img_id  = data

        batch_size = question.size(0)
        image = image.view(-1, 512)
        img_input.data.resize_(image.size()).copy_(image)
        save_tmp = [[] for j in range(batch_size)]

        for rnd in range(10):
            # get the corresponding round QA and history.
            ques = question[:,rnd,:].t()
            his = history[:,:rnd+1,:].clone().view(-1, his_length).t()

            opt_ans = opt_answerT[:,rnd,:].clone().view(-1, ans_length).t()
            gt_id = answer_ids[:,rnd]

            ques_input.data.resize_(ques.size()).copy_(ques)
            his_input.data.resize_(his.size()).copy_(his)

            gt_index.data.resize_(gt_id.size()).copy_(gt_id)
            opt_ans_input.data.resize_(opt_ans.size()).copy_(opt_ans)

            opt_len = opt_answerLen[:,rnd,:].clone().view(-1)

            ques_emb = netW(ques_input, format = 'index')
            his_emb = netW(his_input, format = 'index')

            ques_hidden = repackage_hidden(ques_hidden, batch_size)
            hist_hidden = repackage_hidden(hist_hidden, his_input.size(1))

            featD, ques_hidden = netE(ques_emb, his_emb, img_input, \
                                                                    ques_hidden, hist_hidden, rnd+1)

            #img_atten[i*batch_size:(i+1)*batch_size, rnd, :] = img_atten_weight.data.view(batch_size, 7, 7)

            opt_ans_emb = netW(opt_ans_input, format = 'index')
            opt_hidden = repackage_hidden(opt_hidden, opt_ans_input.size(1))
            opt_feat = netD(opt_ans_emb, opt_ans_input, opt_hidden, n_words)
            opt_feat = opt_feat.view(batch_size, -1, opt.ninp)

            featD = featD.view(-1, opt.ninp, 1)
            score = torch.bmm(opt_feat, featD)
            score = score.view(-1, 100)

            for b in range(batch_size):
                gt_index.data[b] = gt_index.data[b] + b*100

            gt_score = score.view(-1).index_select(0, gt_index)
            sort_score, sort_idx = torch.sort(score, 1, descending=True)

            count = sort_score.gt(gt_score.view(-1,1).expand_as(sort_score))
            rank = count.sum(1) + 1
            rank_all_tmp += list(rank.view(-1).data.cpu().numpy())

        i += 1

        result_all += save_tmp

        if i % 50 == 0:
            R1 = np.sum(np.array(rank_all_tmp)==1) / float(len(rank_all_tmp))
            R5 =  np.sum(np.array(rank_all_tmp)<=5) / float(len(rank_all_tmp))
            R10 = np.sum(np.array(rank_all_tmp)<=10) / float(len(rank_all_tmp))
            ave = np.sum(np.array(rank_all_tmp)) / float(len(rank_all_tmp))
            mrr = np.sum(1/(np.array(rank_all_tmp, dtype='float'))) / float(len(rank_all_tmp))
            print ('%d/%d: mrr: %f R1: %f R5 %f R10 %f Mean %f' %(i, len(dataloader_val), mrr, R1, R5, R10, ave))

    return img_atten
Example #14
def val():
    netE_g.eval()
    netE_d.eval()
    netW_g.eval()
    netW_d.eval()

    netG.eval()
    netD.eval()

    n_neg = 100
    ques_hidden1 = netE_g.init_hidden(opt.batchSize)
    ques_hidden2 = netE_d.init_hidden(opt.batchSize)

    hist_hidden1 = netE_d.init_hidden(opt.batchSize)
    hist_hidden2 = netE_g.init_hidden(opt.batchSize)

    opt_hidden = netD.init_hidden(opt.batchSize)
    data_iter_val = iter(dataloader_val)

    count = 0
    i = 0
    rank_G = []
    rank_D = []

    while i < len(dataloader_val):
        data = data_iter_val.next()
        image, history, question, answer, answerT, questionL, opt_answer, \
                    opt_answerT, answer_ids, answerLen, opt_answerLen, img_id = data

        batch_size = question.size(0)
        image = image.view(-1, 512)

        img_input.data.resize_(image.size()).copy_(image)
        for rnd in range(10):

            # get the corresponding round QA and history.
            ques = question[:,rnd,:].t()
            his = history[:,:rnd+1,:].clone().view(-1, his_length).t()

            opt_ans = opt_answer[:,rnd,:,:].clone().view(-1, ans_length).t()
            opt_tans = opt_answerT[:,rnd,:].clone().view(-1, ans_length).t()
            gt_id = answer_ids[:,rnd]
            opt_len = opt_answerLen[:,rnd,:].clone().view(-1)

            ques_input.data.resize_(ques.size()).copy_(ques)
            his_input.data.resize_(his.size()).copy_(his)
            opt_ans_input.data.resize_(opt_ans.size()).copy_(opt_ans)
            opt_ans_target.data.resize_(opt_tans.size()).copy_(opt_tans)
            gt_index.data.resize_(gt_id.size()).copy_(gt_id)

            ques_emb_g = netW_g(ques_input, format = 'index')
            his_emb_g = netW_g(his_input, format = 'index')

            ques_emb_d = netW_d(ques_input, format = 'index')
            his_emb_d = netW_d(his_input, format = 'index')

            ques_hidden1 = repackage_hidden(ques_hidden1, batch_size)
            ques_hidden2 = repackage_hidden(ques_hidden2, batch_size)

            hist_hidden1 = repackage_hidden(hist_hidden1, his_emb_g.size(1))
            hist_hidden2 = repackage_hidden(hist_hidden2, his_emb_d.size(1))

            featG, ques_hidden1 = netE_g(ques_emb_g, his_emb_g, img_input, \
                                                ques_hidden1, hist_hidden1, rnd+1)

            featD, _ = netE_d(ques_emb_d, his_emb_d, img_input, \
                                                ques_hidden2, hist_hidden2, rnd+1)
            #featD = l2_norm(featD)
            # Evaluate the Generator:
            _, ques_hidden1 = netG(featG.view(1,-1,opt.ninp), ques_hidden1)
            #_, ques_hidden = netG(encoder_feat.view(1,-1,opt.ninp), ques_hidden)
            # extend the hidden
            hidden_replicated = []
            for hid in ques_hidden1:
                hidden_replicated.append(hid.view(opt.nlayers, batch_size, 1, \
                    opt.nhid).expand(opt.nlayers, batch_size, 100, opt.nhid).clone().view(opt.nlayers, -1, opt.nhid))
            hidden_replicated = tuple(hidden_replicated)

            ans_emb = netW_g(opt_ans_input, format = 'index')

            output, _ = netG(ans_emb, hidden_replicated)
            logprob = - output
            logprob_select = torch.gather(logprob, 1, opt_ans_target.view(-1,1))

            mask = opt_ans_target.data.eq(0)  # generate the mask
            if isinstance(logprob, Variable):
                mask = Variable(mask, volatile=logprob.volatile)
            logprob_select.masked_fill_(mask.view_as(logprob_select), 0)

            prob = logprob_select.view(ans_length, -1, 100).sum(0).view(-1,100)

            for b in range(batch_size):
                gt_index.data[b] = gt_index.data[b] + b*100

            gt_score = prob.view(-1).index_select(0, gt_index)
            sort_score, sort_idx = torch.sort(prob, 1)

            count = sort_score.lt(gt_score.view(-1,1).expand_as(sort_score))
            rank = count.sum(1) + 1
            rank_G += list(rank.view(-1).data.cpu().numpy())

            opt_ans_emb = netW_d(opt_ans_target, format = 'index')
            opt_hidden = repackage_hidden(opt_hidden, opt_ans_target.size(1))
            opt_feat = netD(opt_ans_emb, opt_ans_target, opt_hidden, vocab_size)
            opt_feat = opt_feat.view(batch_size, -1, opt.ninp)

            #ans_emb = ans_emb.view(ans_length, -1, 100, opt.nhid)
            featD = featD.view(-1, opt.ninp, 1)
            score = torch.bmm(opt_feat, featD)
            score = score.view(-1, 100)

            gt_score = score.view(-1).index_select(0, gt_index)
            sort_score, sort_idx = torch.sort(score, 1, descending=True)
            count = sort_score.gt(gt_score.view(-1,1).expand_as(sort_score))
            rank = count.sum(1) + 1
            rank_D += list(rank.view(-1).data.cpu().numpy())

        i += 1

    return rank_G, rank_D
Example #15
def val():
    netE_g.eval()
    netE_d.eval()
    netW_g.eval()
    netW_d.eval()
    netG.eval()
    netD.eval()

    n_neg = 100
    ques_hidden1 = netE_g.init_hidden(opt.batchSize)
    ques_hidden2 = netE_d.init_hidden(opt.batchSize)
    hist_hidden1 = netE_g.init_hidden(opt.batchSize)
    hist_hidden2 = netE_d.init_hidden(opt.batchSize)
    opt_hidden = netD.init_hidden(opt.batchSize)
    data_iter_val = iter(dataloader_val)

    count = 0
    i = 0
    rank_G = []
    rank_D = []

    while i < len(dataloader_val):
        data = data_iter_val.next()
        image, history, question, answer, answerT, questionL, opt_answer, \
                    opt_answerT, answer_ids, answerLen, opt_answerLen, img_id  = data

        batch_size = question.size(0)
        image = image.view(-1, 512)

        img_input.data.resize_(image.size()).copy_(image)
        for rnd in range(10):

            # get the corresponding round QA and history.
            ques = question[:, rnd, :].t()
            his = history[:, :rnd + 1, :].clone().view(-1, his_length).t()

            opt_ans = opt_answer[:, rnd, :, :].clone().view(-1, ans_length).t()
            opt_tans = opt_answerT[:, rnd, :].clone().view(-1, ans_length).t()
            gt_id = answer_ids[:, rnd]
            opt_len = opt_answerLen[:, rnd, :].clone().view(-1)

            ques_input.data.resize_(ques.size()).copy_(ques)
            his_input.data.resize_(his.size()).copy_(his)
            opt_ans_input.data.resize_(opt_ans.size()).copy_(opt_ans)
            opt_ans_target.data.resize_(opt_tans.size()).copy_(opt_tans)
            gt_index.data.resize_(gt_id.size()).copy_(gt_id)

            ques_emb_g = netW_g(ques_input, format='index')
            his_emb_g = netW_g(his_input, format='index')

            ques_emb_d = netW_d(ques_input, format='index')
            his_emb_d = netW_d(his_input, format='index')

            ques_hidden1 = repackage_hidden(ques_hidden1, batch_size)
            ques_hidden2 = repackage_hidden(ques_hidden2, batch_size)

            hist_hidden1 = repackage_hidden(hist_hidden1, his_emb_g.size(1))
            hist_hidden2 = repackage_hidden(hist_hidden2, his_emb_d.size(1))

            featG, ques_hidden1 = netE_g(ques_emb_g, his_emb_g, img_input, \
                                                ques_hidden1, hist_hidden1, rnd+1)

            featD, _ = netE_d(ques_emb_d, his_emb_d, img_input, \
                                                ques_hidden2, hist_hidden2, rnd+1)
            #featD = l2_norm(featD)
            _, ques_hidden1 = netG(featG.view(1, -1, opt.ninp), ques_hidden1)

            hidden_replicated = []
            for hid in ques_hidden1:
                hidden_replicated.append(hid.view(opt.nlayers, batch_size, 1, \
                    opt.nhid).expand(opt.nlayers, batch_size, 100, opt.nhid).clone().view(opt.nlayers, -1, opt.nhid))
            hidden_replicated = tuple(hidden_replicated)

            ans_emb = netW_g(opt_ans_input, format='index')

            output, _ = netG(ans_emb, hidden_replicated)
            logprob = -output
            logprob_select = torch.gather(logprob, 1,
                                          opt_ans_target.view(-1, 1))

            mask = opt_ans_target.data.eq(0)  # generate the mask
            if isinstance(logprob, Variable):
                mask = Variable(mask, volatile=logprob.volatile)

            logprob_select.masked_fill_(mask.view_as(logprob_select), 0)

            prob = logprob_select.view(ans_length, -1,
                                       100).sum(0).view(-1, 100)

            for b in range(batch_size):
                gt_index.data[b] = gt_index.data[b] + b * 100

            gt_score = prob.view(-1).index_select(0, gt_index)
            sort_score, sort_idx = torch.sort(prob, 1)

            count = sort_score.lt(gt_score.view(-1, 1).expand_as(sort_score))
            rank = count.sum(1) + 1
            rank_G += list(rank.view(-1).data.cpu().numpy())

            opt_ans_emb = netW_d(opt_ans_target, format='index')
            opt_hidden = repackage_hidden(opt_hidden, opt_ans_target.size(1))
            opt_feat = netD(opt_ans_emb, opt_ans_target, opt_hidden,
                            vocab_size)
            opt_feat = opt_feat.view(batch_size, -1, opt.ninp)

            #ans_emb = ans_emb.view(ans_length, -1, 100, opt.nhid)
            featD = featD.view(-1, opt.ninp, 1)
            score = torch.bmm(opt_feat, featD)
            score = score.view(-1, 100)

            gt_score = score.view(-1).index_select(0, gt_index)
            sort_score, sort_idx = torch.sort(score, 1, descending=True)
            count = sort_score.gt(gt_score.view(-1, 1).expand_as(sort_score))
            rank = count.sum(1) + 1
            rank_D += list(rank.view(-1).data.cpu().numpy())
        i += 1

        if i % 50 == 0:
            R1 = np.sum(np.array(rank_G) == 1) / float(len(rank_G))
            R5 = np.sum(np.array(rank_G) <= 5) / float(len(rank_G))
            R10 = np.sum(np.array(rank_G) <= 10) / float(len(rank_G))
            ave = np.sum(np.array(rank_G)) / float(len(rank_G))
            mrr = np.sum(1 / (np.array(rank_G, dtype='float'))) / float(
                len(rank_G))
            print('%d/%d: mrr: %f R1: %f R5 %f R10 %f Mean %f' %
                  (i, len(dataloader_val), mrr, R1, R5, R10, ave))

    return rank_G, rank_D
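A typical way to consume the two rank lists returned here is the metric helper sketched after Example #5:

rank_G, rank_D = val()
print('generator    :', rank_metrics(rank_G))
print('discriminator:', rank_metrics(rank_D))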
Example #16
def train(epoch):
    netW.train()
    netE.train()
    netD.train()

    lr = adjust_learning_rate(optimizer, epoch, opt.lr)

    ques_hidden = netE.init_hidden(opt.batchSize)
    hist_hidden = netE.init_hidden(opt.batchSize)

    real_hidden = netD.init_hidden(opt.batchSize)
    wrong_hidden = netD.init_hidden(opt.batchSize)

    data_iter = iter(dataloader)
    bar = progressbar.ProgressBar(maxval=len(dataloader))

    average_loss = 0
    count = 0
    i = 0

    while i < len(dataloader):

        t1 = time.time()
        data = data_iter.next()
        image, history, question, answer, answerT, answerLen, answerIdx, questionL, \
                                    opt_answerT, opt_answerLen, opt_answerIdx = data

        batch_size = question.size(0)
        image = image.view(-1, img_feat_size)
        img_input.data.resize_(image.size()).copy_(image)

        for rnd in range(10):
            netW.zero_grad()
            netE.zero_grad()
            netD.zero_grad()
            # get the corresponding round QA and history.
            ques = question[:,rnd,:].t()
            his = history[:,:rnd+1,:].clone().view(-1, his_length).t()

            ans = answer[:,rnd,:].t()
            tans = answerT[:,rnd,:].t()
            wrong_ans = opt_answerT[:,rnd,:].clone().view(-1, ans_length).t()

            real_len = answerLen[:,rnd]
            wrong_len = opt_answerLen[:,rnd,:].clone().view(-1)

            ques_input.data.resize_(ques.size()).copy_(ques)
            his_input.data.resize_(his.size()).copy_(his)

            ans_input.data.resize_(ans.size()).copy_(ans)
            ans_target.data.resize_(tans.size()).copy_(tans)
            wrong_ans_input.data.resize_(wrong_ans.size()).copy_(wrong_ans)

            # sample in-batch negative index
            batch_sample_idx.data.resize_(batch_size, opt.neg_batch_sample).zero_()
            sample_batch_neg(answerIdx[:,rnd], opt_answerIdx[:,rnd,:], batch_sample_idx, opt.neg_batch_sample)

            ques_emb = netW(ques_input, format = 'index')
            his_emb = netW(his_input, format = 'index')

            ques_hidden = repackage_hidden(ques_hidden, batch_size)
            hist_hidden = repackage_hidden(hist_hidden, his_input.size(1))

            featD, ques_hidden = netE(ques_emb, his_emb, img_input, \
                                                ques_hidden, hist_hidden, rnd+1)

            ans_real_emb = netW(ans_target, format='index')
            ans_wrong_emb = netW(wrong_ans_input, format='index')

            real_hidden = repackage_hidden(real_hidden, batch_size)
            wrong_hidden = repackage_hidden(wrong_hidden, ans_wrong_emb.size(1))

            real_feat = netD(ans_real_emb, ans_target, real_hidden, vocab_size)
            wrong_feat = netD(ans_wrong_emb, wrong_ans_input, wrong_hidden, vocab_size)

            batch_wrong_feat = wrong_feat.index_select(0, batch_sample_idx.view(-1))
            wrong_feat = wrong_feat.view(batch_size, -1, opt.ninp)
            batch_wrong_feat = batch_wrong_feat.view(batch_size, -1, opt.ninp)

            nPairLoss = critD(featD, real_feat, wrong_feat, batch_wrong_feat)

            average_loss += nPairLoss.data[0]
            nPairLoss.backward()
            optimizer.step()
            count += 1

        bar.update(i)
        i += 1

    average_loss = average_loss / count

    return average_loss, lr
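critD above is constructed elsewhere in the script and is not shown. Its call signature is consistent with an n-pair ranking loss over the ground-truth answer, the provided wrong answers, and the in-batch negatives; the following is a hedged sketch of one common formulation, not necessarily the project's exact criterion.

import torch
import torch.nn as nn

class NPairLoss(nn.Module):
    # Sketch of an n-pair loss: push the true answer's score above every
    # negative's score under a dot-product similarity.
    def forward(self, feat, right, wrong):
        # feat:  (batch, ninp)         fused question/history/image feature
        # right: (batch, ninp)         ground-truth answer embedding
        # wrong: (batch, n_neg, ninp)  negative answer embeddings
        right_score = (feat * right).sum(1, keepdim=True)             # (batch, 1)
        wrong_score = torch.bmm(wrong, feat.unsqueeze(2)).squeeze(2)  # (batch, n_neg)
        # log(1 + sum_j exp(s_wrong_j - s_right)) goes to 0 once the true
        # answer dominates all negatives.
        return torch.log1p(torch.exp(wrong_score - right_score).sum(1)).mean()

crit = NPairLoss()
loss = crit(torch.randn(8, 300), torch.randn(8, 300), torch.randn(8, 20, 300))

A second negative term for batch_wrong_feat would be added the same way.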
Example #17
def train(epoch):
    netW_d.train()
    netE_d.train()
    netE_g.train()
    netD.train()
    netG.train()
    netW_g.train()

    fake_len = torch.LongTensor(opt.batchSize)
    fake_len = fake_len.cpu()

    n_neg = opt.negative_sample
    ques_hidden1 = netE_d.init_hidden(opt.batchSize)
    ques_hidden2 = netE_g.init_hidden(opt.batchSize)

    hist_hidden1 = netE_d.init_hidden(opt.batchSize)
    hist_hidden2 = netE_g.init_hidden(opt.batchSize)

    real_hidden = netD.init_hidden(opt.batchSize)
    wrong_hidden = netD.init_hidden(opt.batchSize)
    fake_hidden = netD.init_hidden(opt.batchSize)

    data_iter = iter(dataloader)

    err_d = 0
    err_g = 0
    err_lm = 0
    average_loss = 0
    count = 0
    i = 0
    loss_store = []
    while i < len(dataloader):
        t1 = time.time()
        data = data_iter.next()
        image, history, question, answer, answerT, answerLen, answerIdx, questionL, \
                                    opt_answerT, opt_answerLen, opt_answerIdx = data
        batch_size = question.size(0)
        image = image.view(-1, 512)
        img_input.data.resize_(image.size()).copy_(image)

        err_d_tmp = 0
        err_g_tmp = 0
        err_lm_tmp = 0
        err_d_fake_tmp = 0
        err_g_fake_tmp = 0
        for rnd in range(10):
            # get the corresponding round QA and history.
            ques = question[:, rnd, :].t()
            his = history[:, :rnd + 1, :].clone().view(-1, his_length).t()

            ans = answer[:, rnd, :].t()
            tans = answerT[:, rnd, :].t()
            wrong_ans = opt_answerT[:, rnd, :].clone().view(-1, ans_length).t()

            real_len = answerLen[:, rnd].long()
            wrong_len = opt_answerLen[:, rnd, :].clone().view(-1)

            ques_input.data.resize_(ques.size()).copy_(ques)
            his_input.data.resize_(his.size()).copy_(his)

            ans_input.data.resize_(ans.size()).copy_(ans)
            ans_target.data.resize_(tans.size()).copy_(tans)
            wrong_ans_input.data.resize_(wrong_ans.size()).copy_(wrong_ans)

            batch_sample_idx.data.resize_(batch_size,
                                          opt.neg_batch_sample).zero_()
            sample_batch_neg(answerIdx[:, rnd], opt_answerIdx[:, rnd, :],
                             batch_sample_idx, opt.neg_batch_sample)

            # -----------------------------------------
            # update the Generator using MLE loss.
            # -----------------------------------------
            if opt.update_LM:
                ques_emb_g = netW_g(ques_input, format='index')
                his_emb_g = netW_g(his_input, format='index')

                ques_hidden1 = repackage_hidden(ques_hidden1, batch_size)
                hist_hidden1 = repackage_hidden(hist_hidden1,
                                                his_emb_g.size(1))

                featG, ques_hidden1 = netE_g(ques_emb_g, his_emb_g, img_input, \
                                                    ques_hidden1, hist_hidden1, rnd+1)

                _, ques_hidden1 = netG(featG.view(1, -1, opt.ninp),
                                       ques_hidden1)
                # MLE loss for generator
                ans_emb = netW_g(ans_input)
                logprob, _ = netG(ans_emb, ques_hidden1)
                lm_loss = critLM(logprob, ans_target.view(-1, 1))
                lm_loss = lm_loss / torch.sum(ans_target.data.gt(0))

                netW_g.zero_grad()
                netG.zero_grad()
                netE_g.zero_grad()

                lm_loss.backward()
                optimizerLM.step()
                err_lm += lm_loss.data[0]
                err_lm_tmp += lm_loss.data[0]

            # sample the answer using the Gumbel-softmax sampler.
            ques_emb_g = netW_g(ques_input, format='index')
            his_emb_g = netW_g(his_input, format='index')

            ques_hidden1 = repackage_hidden(ques_hidden1, batch_size)
            hist_hidden1 = repackage_hidden(hist_hidden1, his_emb_g.size(1))

            featG, ques_hidden1 = netE_g(ques_emb_g, his_emb_g, img_input, \
                                                ques_hidden1, hist_hidden1, rnd+1)

            _, ques_hidden1 = netG(featG.view(1, -1, opt.ninp), ques_hidden1)

            # Gumbel softmax to sample the output.
            fake_onehot = []
            fake_idx = []
            noise_input.data.resize_(ans_length, batch_size, vocab_size + 1)
            noise_input.data.uniform_(0, 1)

            ans_sample = ans_input[0]
            for di in range(ans_length):
                ans_emb = netW_g(ans_sample, format='index')
                logprob, ques_hidden1 = netG(ans_emb.view(1, -1, opt.ninp),
                                             ques_hidden1)
                one_hot, idx = sampler(logprob, noise_input[di],
                                       opt.gumble_weight)
                fake_onehot.append(one_hot.view(1, -1, vocab_size + 1))
                fake_idx.append(idx)
                if di + 1 < ans_length:
                    ans_sample = idx

            # convert the list into the tensor variable.
            fake_onehot = torch.cat(fake_onehot, 0)
            fake_idx = torch.cat(fake_idx, 0)

            fake_len.resize_(batch_size).fill_(ans_length - 1)
            for di in range(ans_length - 1, 0, -1):
                fake_len.masked_fill_(fake_idx.data[di].eq(vocab_size), di)

            # generate fake mask.
            fake_mask.data.resize_(fake_idx.size()).fill_(1)
            # zero the mask over each sampled sequence's real tokens.
            for b in range(batch_size):
                fake_mask.data[:fake_len[b] + 1, b] = 0

            # apply the mask on the fake_idx.
            fake_idx.masked_fill_(fake_mask, 0)

            # get the fake diff mask.
            #fake_diff_mask = torch.sum(fake_idx == ans_target, 0) != 0
            fake_onehot = fake_onehot.view(-1, vocab_size + 1)

            ######################################
            # Discriminatively trained generative model.
            ######################################
            # forward the discriminator again.
            ques_emb_d = netW_d(ques_input, format='index')
            his_emb_d = netW_d(his_input, format='index')

            ques_hidden2 = repackage_hidden(ques_hidden2, batch_size)
            hist_hidden2 = repackage_hidden(hist_hidden2, his_emb_d.size(1))

            featD, _ = netE_d(ques_emb_d, his_emb_d, img_input, \
                                        ques_hidden2, hist_hidden2, rnd+1)

            ans_real_emb = netW_d(ans_target, format='index')
            #ans_wrong_emb = netW_d(wrong_ans_input, format='index')
            ans_fake_emb = netW_d(fake_onehot, format='onehot')
            ans_fake_emb = ans_fake_emb.view(ans_length, -1, opt.ninp)

            real_hidden = repackage_hidden(real_hidden, batch_size)
            #wrong_hidden = repackage_hidden(wrong_hidden, ans_wrong_emb.size(1))
            fake_hidden = repackage_hidden(fake_hidden, batch_size)

            fake_feat = netD(ans_fake_emb, fake_idx, fake_hidden, vocab_size)
            real_feat = netD(ans_real_emb, ans_target, real_hidden, vocab_size)

            d_g_loss, g_fake = critG(featD, real_feat,
                                     fake_feat)  #, fake_diff_mask.detach())

            netW_g.zero_grad()
            netG.zero_grad()
            netE_g.zero_grad()

            d_g_loss.backward()
            optimizerG.step()

            err_g += d_g_loss.data[0]
            err_g_tmp += d_g_loss.data[0]
            err_g_fake_tmp += g_fake

            count += 1

        i += 1
        loss_store.append({'iter':i, 'err_lm':err_lm_tmp/10, 'err_d':err_d_tmp/10, 'err_g':err_g_tmp/10, \
                            'd_fake': err_d_fake_tmp/10, 'g_fake':err_g_fake_tmp/10})

        if i % 20 == 0:
            print('Epoch:%d %d/%d, err_lm %4f, err_d %4f, err_g %4f, d_fake %4f, g_fake %4f' \
                % (epoch, i, len(dataloader), err_lm_tmp/10, err_d_tmp/10, err_g_tmp/10, err_d_fake_tmp/10, \
                    err_g_fake_tmp/10))

    #average_loss = average_loss / count
    err_g = err_g / count
    err_d = err_d / count
    err_lm = err_lm / count

    return err_lm, err_d, err_g, loss_store
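The sampler called in the decoding loop above is defined outside these snippets. Its inputs (decoder log-probabilities, pre-drawn Uniform(0, 1) noise, a temperature weight) and its two outputs (a one-hot vector plus sampled indices) match a straight-through Gumbel-softmax sampler; here is a minimal sketch under that assumption.

import torch
import torch.nn.functional as F

def gumbel_sampler(logprob, noise, temperature):
    # logprob: (batch, vocab) log-probabilities from the generator.
    # noise:   (batch, vocab) Uniform(0, 1) draws, as in noise_input above.
    eps = 1e-20
    gumbel = -torch.log(-torch.log(noise + eps) + eps)   # Gumbel(0, 1) noise
    y = F.softmax((logprob + gumbel) / temperature, dim=1)
    idx = y.max(1, keepdim=True)[1]                      # hard sample per row
    one_hot = torch.zeros_like(y).scatter_(1, idx, 1.0)
    # Straight-through estimator: the forward pass uses the hard one-hot,
    # while gradients flow through the soft distribution y.
    one_hot = (one_hot - y).detach() + y
    return one_hot, idx.view(1, -1)

Lower temperatures push the soft sample y toward the hard one-hot, at the cost of noisier gradients.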
Example #18
def eval():
    netW.eval()
    netE.eval()
    netD.eval()

    data_iter_val = iter(dataloader_val)
    ques_hidden = netE.init_hidden(opt.batchSize)
    hist_hidden = netE.init_hidden(opt.batchSize)

    opt_hidden = netD.init_hidden(opt.batchSize)
    i = 0
    display_count = 0
    average_loss = 0
    rank_all_tmp = []
    result_all = []
    img_atten = torch.FloatTensor(100 * 30, 10, 7, 7)
    while i < len(dataloader_val):
        data = data_iter_val.next()
        image, history, question, answer, answerT, questionL, opt_answer, \
                opt_answerT, answer_ids, answerLen, opt_answerLen, img_id  = data

        batch_size = question.size(0)
        image = image.view(-1, 512)
        img_input.data.resize_(image.size()).copy_(image)
        save_tmp = [[] for j in range(batch_size)]

        for rnd in range(10):
            # get the corresponding round QA and history.
            ques = question[:,rnd,:].t()
            his = history[:,:rnd+1,:].clone().view(-1, his_length).t()

            opt_ans = opt_answerT[:,rnd,:].clone().view(-1, ans_length).t()
            gt_id = answer_ids[:,rnd]

            ques_input.data.resize_(ques.size()).copy_(ques)
            his_input.data.resize_(his.size()).copy_(his)

            gt_index.data.resize_(gt_id.size()).copy_(gt_id)
            opt_ans_input.data.resize_(opt_ans.size()).copy_(opt_ans)

            opt_len = opt_answerLen[:,rnd,:].clone().view(-1)

            ques_emb = netW(ques_input, format = 'index')
            his_emb = netW(his_input, format = 'index')

            ques_hidden = repackage_hidden(ques_hidden, batch_size)
            hist_hidden = repackage_hidden(hist_hidden, his_input.size(1))

            featD, ques_hidden = netE(ques_emb, his_emb, img_input, \
                                                                    ques_hidden, hist_hidden, rnd+1)

            #img_atten[i*batch_size:(i+1)*batch_size, rnd, :] = img_atten_weight.data.view(batch_size, 7, 7)

            opt_ans_emb = netW(opt_ans_input, format = 'index')
            opt_hidden = repackage_hidden(opt_hidden, opt_ans_input.size(1))
            opt_feat = netD(opt_ans_emb, opt_ans_input, opt_hidden, n_words)
            opt_feat = opt_feat.view(batch_size, -1, opt.ninp)

            featD = featD.view(-1, opt.ninp, 1)
            score = torch.bmm(opt_feat, featD)
            score = score.view(-1, 100)

            for b in range(batch_size):
                gt_index.data[b] = gt_index.data[b] + b*100

            gt_score = score.view(-1).index_select(0, gt_index)
            sort_score, sort_idx = torch.sort(score, 1, descending=True)

            count = sort_score.gt(gt_score.view(-1,1).expand_as(sort_score))
            rank = count.sum(1) + 1
            rank_all_tmp += list(rank.view(-1).data.cpu().numpy())

        i += 1

        result_all += save_tmp

        if i % 50 == 0:
            R1 = np.sum(np.array(rank_all_tmp) == 1) / float(len(rank_all_tmp))
            R5 = np.sum(np.array(rank_all_tmp) <= 5) / float(len(rank_all_tmp))
            R10 = np.sum(np.array(rank_all_tmp) <= 10) / float(len(rank_all_tmp))
            ave = np.sum(np.array(rank_all_tmp)) / float(len(rank_all_tmp))
            mrr = np.sum(1 / (np.array(rank_all_tmp, dtype='float'))) / float(len(rank_all_tmp))
            print('%d/%d: mrr: %f R1: %f R5: %f R10: %f Mean: %f' % (i, len(dataloader_val), mrr, R1, R5, R10, ave))

    return img_atten
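The metrics printed every 50 batches depend only on the accumulated rank list, so they are easy to factor out. A small numpy helper, sketched here, makes the formulas explicit.

import numpy as np

def retrieval_metrics(ranks):
    # ranks: iterable of 1-based ranks of the ground-truth answer.
    r = np.asarray(ranks, dtype='float')
    return {
        'R@1':  np.mean(r == 1),
        'R@5':  np.mean(r <= 5),
        'R@10': np.mean(r <= 10),
        'mean': r.mean(),
        'mrr':  (1.0 / r).mean(),
    }

print(retrieval_metrics([1, 3, 12, 2, 70]))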
Example #19
def train(epoch):
    netW.train()
    netE.train()
    netG.train()

    lr = adjust_learning_rate(optimizer, epoch, opt.lr)
    data_iter = iter(dataloader)

    ques_hidden = netE.init_hidden(opt.batchSize)
    hist_hidden = netE.init_hidden(opt.batchSize)
    average_loss = 0
    count = 0
    i = 0
    while i < len(dataloader):
        data = data_iter.next()
        image, history, question, answer, answerT, answerLen, answerIdx, \
        questionL, negAnswer, negAnswerLen, negAnswerIdx = data

        batch_size = question.size(0)
        image = image.view(-1, img_feat_size)
        img_input.data.resize_(image.size()).copy_(image)

        for rnd in range(10):
            ques = question[:,rnd,:].t()
            his = history[:,:rnd+1,:].clone().view(-1, his_length).t()
            ans, tans = answer[:,rnd,:].t(), answerT[:,rnd,:].t()

            his_input.data.resize_(his.size()).copy_(his)
            ques_input.data.resize_(ques.size()).copy_(ques)
            ans_input.data.resize_(ans.size()).copy_(ans)
            ans_target.data.resize_(tans.size()).copy_(tans)

            ques_emb = netW(ques_input, format = 'index')
            his_emb = netW(his_input, format = 'index')

            ques_hidden = repackage_hidden(ques_hidden, batch_size)
            hist_hidden = repackage_hidden(hist_hidden, his_input.size(1))

            encoder_feat, ques_hidden = netE(ques_emb, his_emb, img_input, \
                                                ques_hidden, hist_hidden, rnd+1)

            _, ques_hidden = netG(encoder_feat.view(1,-1,opt.ninp), ques_hidden)

            ans_emb = netW(ans_input)
            logprob, ques_hidden = netG(ans_emb, ques_hidden)
            loss = critG(logprob, ans_target.view(-1, 1))

            loss = loss / torch.sum(ans_target.data.gt(0))
            average_loss += loss.data[0]
            # do backward.
            netW.zero_grad()
            netE.zero_grad()
            netG.zero_grad()
            loss.backward()

            optimizer.step()
            count += 1
        i += 1
        if i % opt.log_interval == 0:
            average_loss /= count
            print("step {} / {} (epoch {}), g_loss {:.3f}, lr = {:.6f}"\
                .format(i, len(dataloader), epoch, average_loss, lr))
            average_loss = 0
            count = 0

    return average_loss
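critG above is built elsewhere; the division by torch.sum(ans_target.data.gt(0)) suggests a sum-reduced negative log-likelihood over non-padding tokens, normalized per real token. A hedged sketch with the normalization folded in:

import torch

def masked_nll(logprob, target, pad_idx=0):
    # logprob: (seq * batch, vocab) log-probabilities from the decoder.
    # target:  (seq * batch, 1) gold token ids; pad_idx marks padding.
    nll = -logprob.gather(1, target)                      # per-token NLL
    nll = nll.masked_fill(target.eq(pad_idx), 0)          # ignore padding
    return nll.sum() / target.gt(pad_idx).sum().float()   # per-token average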
Example #20
def val():
    netE.eval()
    netW.eval()
    netG.eval()

    data_iter_val = iter(dataloader_val)
    ques_hidden = netE.init_hidden(opt.batchSize)
    hist_hidden = netE.init_hidden(opt.batchSize)

    i = 0
    average_loss = 0
    rank_all_tmp = []

    while i < len(dataloader_val):
        data = data_iter_val.next()
        image, history, question, answer, answerT, questionL, opt_answer, \
                    opt_answerT, answer_ids, answerLen, opt_answerLen, img_id = data

        batch_size = question.size(0)
        image = image.view(-1, img_feat_size)
        img_input.data.resize_(image.size()).copy_(image)

        for rnd in range(10):
            # get the corresponding round QA and history.
            ques, tans = question[:,rnd,:].t(), opt_answerT[:,rnd,:].clone().view(-1, ans_length).t()
            his = history[:,:rnd+1,:].clone().view(-1, his_length).t()
            ans = opt_answer[:,rnd,:,:].clone().view(-1, ans_length).t()
            gt_id = answer_ids[:,rnd]

            his_input.data.resize_(his.size()).copy_(his)
            ques_input.data.resize_(ques.size()).copy_(ques)
            ans_input.data.resize_(ans.size()).copy_(ans)
            ans_target.data.resize_(tans.size()).copy_(tans)

            gt_index.data.resize_(gt_id.size()).copy_(gt_id)

            ques_emb = netW(ques_input, format = 'index')
            his_emb = netW(his_input, format = 'index')

            ques_hidden = repackage_hidden(ques_hidden, batch_size)
            hist_hidden = repackage_hidden(hist_hidden, his_input.size(1))

            encoder_feat, ques_hidden = netE(ques_emb, his_emb, img_input, \
                                                ques_hidden, hist_hidden, rnd+1)

            _, ques_hidden = netG(encoder_feat.view(1,-1,opt.ninp), ques_hidden)

            hidden_replicated = []
            for hid in ques_hidden:
                hidden_replicated.append(hid.view(opt.nlayers, batch_size, 1, \
                    opt.nhid).expand(opt.nlayers, batch_size, 100, opt.nhid).clone().view(opt.nlayers, -1, opt.nhid))
            hidden_replicated = tuple(hidden_replicated)

            ans_emb = netW(ans_input, format = 'index')

            output, _ = netG(ans_emb, hidden_replicated)
            logprob = - output
            logprob_select = torch.gather(logprob, 1, ans_target.view(-1,1))

            mask = ans_target.data.eq(0)  # generate the mask
            if isinstance(logprob, Variable):
                mask = Variable(mask, volatile=logprob.volatile)
            logprob_select.masked_fill_(mask.view_as(logprob_select), 0)

            prob = logprob_select.view(ans_length, -1, 100).sum(0).view(-1,100)

            for b in range(batch_size):
                gt_index.data[b] = gt_index.data[b] + b*100

            gt_score = prob.view(-1).index_select(0, gt_index)
            sort_score, sort_idx = torch.sort(prob, 1)

            count = sort_score.lt(gt_score.view(-1,1).expand_as(sort_score))
            rank = count.sum(1) + 1
            rank_all_tmp += list(rank.view(-1).data.cpu().numpy())

        i += 1

    return rank_all_tmp, average_loss
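The hidden_replicated block is what lets a single decoder pass score all 100 candidate answers: the encoder's final state is tiled once per candidate and folded into the batch axis. In isolation (all sizes here are illustrative assumptions):

import torch

nlayers, batch_size, nhid, n_cand = 2, 4, 512, 100
h = torch.randn(nlayers, batch_size, nhid)              # one state per dialog

h_rep = (h.view(nlayers, batch_size, 1, nhid)           # add candidate axis
          .expand(nlayers, batch_size, n_cand, nhid)    # broadcast, no copy yet
          .contiguous()
          .view(nlayers, batch_size * n_cand, nhid))    # fold into batch axis
# Row b * n_cand + c now carries dialog b's state for candidate c, matching
# the candidate-major layout produced by opt_answer.view(-1, ans_length).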
Example #21
def eval():
    netW.eval()
    netE.eval()
    netD.eval()

    data_iter_val = iter(dataloader_val)
    ques_hidden = netE.init_hidden(opt.batchSize)
    hist_hidden = netE.init_hidden(opt.batchSize)
    opt_hidden = netD.init_hidden(opt.batchSize)

    i = 0
    ranks_json = []

    while i < len(dataloader_val):
        data = data_iter_val.next()
        image, history, question, answer, answerT, questionL, opt_answer, \
                opt_answerT, answerLen, opt_answerLen, img_id, rounds  = data
        batch_size = question.size(0)
        image = image.view(-1, 36, 2048)  #   image : batchx36x2048
        img_input.data.resize_(image.size()).copy_(image)

        #image2 = image2.view(-1, img_feat_size) #   image : 6272(128x7x7) x 512
        #img_input2.data.resize_(image2.size()).copy_(image2) #6272x512
        save_tmp = [[] for j in range(batch_size)]

        rnd = int(rounds - 1)
        ques = question[:, rnd, :].t()
        his = history[:, :rnd + 1, :].clone().view(-1, his_length).t()
        opt_ans = opt_answerT[:, rnd, :].clone().view(-1, ans_length).t()

        ques_input.data.resize_(ques.size()).copy_(ques)
        his_input.data.resize_(his.size()).copy_(his)

        opt_ans_input.data.resize_(opt_ans.size()).copy_(opt_ans)
        opt_len = opt_answerLen[:, rnd, :].clone().view(-1)

        ques_emb = netW(ques_input, format='index')
        his_emb = netW(his_input, format='index')

        ques_hidden = repackage_hidden(ques_hidden, batch_size)
        hist_hidden = repackage_hidden(hist_hidden, his_input.size(1))

        featD, ques_hidden = netE(ques_emb, his_emb, img_input, \
                                                                    ques_hidden, hist_hidden, rnd+1)

        opt_ans_emb = netW(opt_ans_input, format='index')
        opt_hidden = repackage_hidden(opt_hidden, opt_ans_input.size(1))
        opt_feat = netD(opt_ans_emb, opt_ans_input, opt_hidden, n_words)
        opt_feat = opt_feat.view(batch_size, -1, opt.ninp)

        featD = featD.view(-1, opt.ninp, 1)
        score = torch.bmm(opt_feat, featD)
        score = score.view(-1, 100)

        ranks = scores_to_ranks(score)

        ranks_json.append({
            "image_id": int(img_id),
            "round_id": int(rounds),
            "ranks": ranks.view(-1).tolist()
        })

        i += 1
        sys.stdout.write('Evaluating: {:d}/{:d}  \r' \
          .format(i, len(dataloader_val)))
        sys.stdout.flush()
        #pdb.set_trace()
    return ranks_json
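scores_to_ranks comes from the surrounding evaluation utilities and is not shown. To produce the ranks field of the JSON records above it must map each row of scores to the 1-based rank of every candidate; one standard way to do that, offered only as a sketch, is to invert the descending sort permutation.

import torch

def scores_to_ranks(scores):
    # scores: (batch, n_cand); returns (batch, n_cand) where entry c holds the
    # 1-based rank of candidate c under descending score order.
    _, order = scores.sort(1, descending=True)   # order[b, r] = candidate at rank r
    ranks = torch.zeros_like(order)
    arange = torch.arange(1, scores.size(1) + 1).repeat(scores.size(0), 1)
    ranks.scatter_(1, order, arange)             # invert the permutation
    return ranks

print(scores_to_ranks(torch.randn(2, 5)))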