def train(epoch, lr):
    print("in train()")
    # exit()
    src_netW.train()
    tgt_netW.train()
    src_netE_att.train()
    tgt_netE_att.train()

    lr = adjust_learning_rate(optimizer, epoch, lr)
    # print ("in train() 2")

    data_iter = iter(dataloader)
    average_loss = 0
    count = 0
    # i = 0
    print("train_D train()", len(dataloader))
    data_size = len(dataloader)
    t = time.time()
    for i in range(data_size):
        # print ("i", i)
        t1 = time.time()
        data = next(data_iter)
        _, question, answer, answerT, answerLen, answerIdx, questionL, _, opt_answerT, opt_answerLen, opt_answerIdx, answer_ids = data
        # print("answerLen", answerLen)
        # print("answerT", answerT)
        # print("opt_answerT", opt_answerT)
        # print("question", question.size())
        batch_size = question.size(0)
        # print ("batch_size", batch_size)

        # rnd = 0
        src_netW.zero_grad()
        tgt_netW.zero_grad()
        src_netE_att.zero_grad()
        tgt_netE_att.zero_grad()
        ques = question[:, :].t()

        # ans = answer[:,rnd,:].t()
        tans = answerT[:, :].t()

        # print("train_att_D.py opt_answerT", opt_answerT.size())
        wrong_ans = opt_answerT[:, :].clone().view(-1, tgt_length).t()
        # print("train_att_D.py wrong_ans", wrong_ans.size())

        # real_len = answerLen[:]
        wrong_len = opt_answerLen[:, :].clone().view(-1)

        ques_input.data.resize_(ques.size()).copy_(ques)

        # ans_input.data.resize_(ans.size()).copy_(ans)
        ans_target.data.resize_(tans.size()).copy_(tans)
        # print("ans_target", ans_target.size(), ans_target)
        # print("ans_target", ans_target)
        # exit()
        # print("train_att_D.py wrong_ans_input", wrong_ans_input.size())

        wrong_ans_input.data.resize_(wrong_ans.size()).copy_(wrong_ans)
        # print("train_att_D.py wrong_ans_input", wrong_ans_input.size())

        # sample in-batch negative index
        batch_sample_idx.data.resize_(batch_size, opt.neg_batch_sample).zero_()
        sample_batch_neg(answerIdx[:], opt_answerIdx[:, :], batch_sample_idx,
                         opt.neg_batch_sample)

        ques_emb = src_netW(ques_input, format='index')
        # print("train_att_D.py ques_emb", ques_emb.size(), ques_emb.device)
        src_featD, _ = src_netE_att(ques_emb, ques_input)
        # print("train_att_D.py src_featD", src_featD.size(), src_featD.device)

        ans_real_emb = tgt_netW(ans_target, format='index')
        ans_wrong_emb = tgt_netW(wrong_ans_input, format='index')

        tgt_real_feat, _weight_ = tgt_netE_att(ans_real_emb, ans_target)
        tgt_wrong_feat, _weight_ = tgt_netE_att(ans_wrong_emb, wrong_ans_input)
        # print("train_att_D.py train() tgt_wrong_feat _weight_", _weight_, _weight_.size())

        # tgt_wrong_feat, _ = tgt_netE_att(src_featD, ans_wrong_emb, wrong_ans_input, wrong_hidden, tgt_vocab_size)
        # exit()
        # print("tgt_wrong_feat", tgt_wrong_feat.size())
        # print("batch_sample_idx", batch_sample_idx.size())
        # print("batch_sample_idx.view(-1)", batch_sample_idx.view(-1).size())
        batch_wrong_feat = tgt_wrong_feat.index_select(
            0, batch_sample_idx.view(-1))
        tgt_wrong_feat = tgt_wrong_feat.view(batch_size, -1, opt.rnn_size)
        batch_wrong_feat = batch_wrong_feat.view(batch_size, -1, opt.rnn_size)
        # print("src_featD", src_featD.size())
        # print("tgt_real_feat", tgt_real_feat.size())
        # print("tgt_wrong_feat", tgt_wrong_feat.size())
        # print("batch_wrong_feat", batch_wrong_feat.size())
        # exit()
        nPairLoss = critD(src_featD, tgt_real_feat, tgt_wrong_feat,
                          batch_wrong_feat)

        average_loss += nPairLoss.data.item()
        nPairLoss.backward()
        optimizer.step()
        count += 1

        # i += 1
        if i % opt.log_interval == 0:
            average_loss /= count
            print(
                "step {} / {} (epoch {}), g_loss {:.3f}, lr = {:.6f} Time: {:.3f}"
                .format(i, data_size, epoch, average_loss, lr,
                        time.time() - t))
            average_loss = 0
            count = 0
            t = time.time()

        # if i > 10: break

    print("finish train()")
    return average_loss, lr
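The critD criterion called above is not defined in these snippets. As a rough, assumed sketch, an n-pair ranking loss with the same call signature could score the question feature against the ground-truth, wrong-option, and in-batch-negative answer features and penalise any negative that scores above the positive; the class name NPairLoss and the log-sum-exp form are illustrative assumptions, not the repository's actual implementation.

import torch
import torch.nn as nn

class NPairLoss(nn.Module):
    # Assumed sketch of critD(src_featD, tgt_real_feat, tgt_wrong_feat, batch_wrong_feat).
    def forward(self, feat, real_feat, wrong_feat, batch_wrong_feat):
        # feat:             (batch, rnn_size)         question feature
        # real_feat:        (batch, rnn_size)         ground-truth answer feature
        # wrong_feat:       (batch, n_neg, rnn_size)  provided wrong options
        # batch_wrong_feat: (batch, n_bneg, rnn_size) in-batch negatives
        batch_size = feat.size(0)
        real_score = torch.bmm(real_feat.unsqueeze(1), feat.unsqueeze(2)).view(batch_size)
        wrong_score = torch.bmm(wrong_feat, feat.unsqueeze(2)).squeeze(2)
        batch_score = torch.bmm(batch_wrong_feat, feat.unsqueeze(2)).squeeze(2)
        neg_score = torch.cat([wrong_score, batch_score], dim=1)
        # log(1 + sum_j exp(neg_j - pos)) pushes every negative below the positive
        loss = torch.log1p(torch.exp(neg_score - real_score.unsqueeze(1)).sum(dim=1))
        return loss.mean()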
Example #2
def train(epoch):
    netW.train()
    netE.train()
    netD.train()

    lr = adjust_learning_rate(optimizer, epoch, opt.lr)

    ques_hidden = netE.init_bi_hidden(opt.batchSize)
    hist_hidden = netE.init_bi_hidden(opt.batchSize)

    real_hidden = netD.init_hidden(opt.batchSize)
    wrong_hidden = netD.init_hidden(opt.batchSize)

    data_iter = iter(dataloader)
    num = len(dataloader)

    average_loss = 0
    total_loss = 0
    count = 0
    i = 0
    while i < num:

        data = next(data_iter)
        image, history, question, answer, answerLen, answerIdx, questionL, \
        opt_answerT, opt_answerLen, opt_answerIdx = data

        batch_size = question.size(0)
        image = image.view(-1, img_feat_size)
        img_input.data.resize_(image.size()).copy_(image)

        for rnd in range(10):
            netW.zero_grad()
            netE.zero_grad()
            netD.zero_grad()

            # get the corresponding round QA and history.
            ques = question[:, rnd, :]
            his = history[:, :rnd + 1, :].clone().view(-1, his_length)

            tans = answer[:, rnd, :]
            # tans = answerT[:, rnd, :]
            wrong_ans = opt_answerT[:, rnd, :].clone().view(-1, ans_length)

            ques_input.data.resize_(ques.size()).copy_(ques)
            his_input.data.resize_(his.size()).copy_(his)

            # ans_input.data.resize_(ans.size()).copy_(ans)
            ans_target.data.resize_(tans.size()).copy_(tans)
            wrong_ans_input.data.resize_(wrong_ans.size()).copy_(wrong_ans)

            # sample in-batch negative index
            batch_sample_idx.data.resize_(batch_size,
                                          opt.neg_batch_sample).zero_()
            sample_batch_neg(answerIdx[:, rnd], opt_answerIdx[:, rnd, :],
                             batch_sample_idx, opt.neg_batch_sample)

            ques_emb = netW(ques_input, format='index')
            his_emb = netW(his_input, format='index')

            featD, _ = netE(ques_emb, his_emb, img_input, ques_input,
                            his_input, rnd + 1)

            ans_real_emb = netW(ans_target, format='index')
            ans_wrong_emb = netW(wrong_ans_input, format='index')

            # real_hidden = repackage_hidden(real_hidden, batch_size)
            # wrong_hidden = repackage_hidden(wrong_hidden, ans_wrong_emb.size(1))

            real_feat = netD(ans_real_emb, ans_target, vocab_size)
            wrong_feat = netD(ans_wrong_emb, wrong_ans_input, vocab_size)

            batch_wrong_feat = wrong_feat.index_select(
                0, batch_sample_idx.view(-1))
            wrong_feat = wrong_feat.view(batch_size, -1, opt.ninp)
            batch_wrong_feat = batch_wrong_feat.view(batch_size, -1, opt.ninp)

            nPairLoss = critD(featD, real_feat, wrong_feat, batch_wrong_feat)

            average_loss += nPairLoss.item()
            nPairLoss.backward()

            optimizer.step()
            count += 1

        i += 1
        if i % opt.log_interval == 0:
            average_loss /= count
            print("step {} / {} (epoch {}), loss {:.3f}, lr = {:.6f}" \
                  .format(i, len(dataloader), epoch, average_loss, lr))
            total_loss = total_loss + average_loss
            average_loss = 0
            count = 0

    return total_loss, lr
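sample_batch_neg appears in every example but is never shown. A minimal sketch consistent with how batch_sample_idx is used afterwards (row indices into the flattened wrong-answer feature matrix, one set per batch element) might look like the following; the exact rejection rule is an assumption.

import random

def sample_batch_neg(answer_idx, opt_answer_idx, sample_idx, num_sample):
    # Assumed sketch: fill sample_idx (batch x num_sample, long) with random row
    # indices into the flattened option-answer matrix, skipping any option whose
    # answer id equals the ground-truth answer of the current example.
    batch_size = answer_idx.size(0)
    flat_opts = opt_answer_idx.contiguous().view(-1)
    num_opts = flat_opts.size(0)
    for b in range(batch_size):
        gt = answer_idx[b]
        n = 0
        while n < num_sample:
            cand = random.randrange(num_opts)
            if flat_opts[cand] != gt:
                sample_idx[b, n] = cand
                n += 1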
Example #3
def train(epoch):
    netW.train()
    netE.train()
    netD.train()

    lr = adjust_learning_rate(optimizer, epoch, opt.lr)

    ques_hidden = netE.init_hidden(opt.batchSize)
    hist_hidden = netE.init_hidden(opt.batchSize)

    real_hidden = netD.init_hidden(opt.batchSize)
    wrong_hidden = netD.init_hidden(opt.batchSize)

    data_iter = iter(dataloader)
    bar = progressbar.ProgressBar(maxval=len(dataloader))

    average_loss = 0
    count = 0
    i = 0

    while i < len(dataloader):

        t1 = time.time()
        data = next(data_iter)
        image, history, question, answer, answerT, answerLen, answerIdx, questionL, \
                                    opt_answerT, opt_answerLen, opt_answerIdx = data

        batch_size = question.size(0)
        image = image.view(-1, img_feat_size)
        img_input.data.resize_(image.size()).copy_(image)

        for rnd in range(10):
            netW.zero_grad()
            netE.zero_grad()
            netD.zero_grad()
            # get the corresponding round QA and history.
            ques = question[:,rnd,:].t()
            his = history[:,:rnd+1,:].clone().view(-1, his_length).t()

            ans = answer[:,rnd,:].t()
            tans = answerT[:,rnd,:].t()
            wrong_ans = opt_answerT[:,rnd,:].clone().view(-1, ans_length).t()

            real_len = answerLen[:,rnd]
            wrong_len = opt_answerLen[:,rnd,:].clone().view(-1)

            ques_input.data.resize_(ques.size()).copy_(ques)
            his_input.data.resize_(his.size()).copy_(his)

            ans_input.data.resize_(ans.size()).copy_(ans)
            ans_target.data.resize_(tans.size()).copy_(tans)
            wrong_ans_input.data.resize_(wrong_ans.size()).copy_(wrong_ans)

            # sample in-batch negative index
            batch_sample_idx.data.resize_(batch_size, opt.neg_batch_sample).zero_()
            sample_batch_neg(answerIdx[:,rnd], opt_answerIdx[:,rnd,:], batch_sample_idx, opt.neg_batch_sample)

            ques_emb = netW(ques_input, format = 'index')
            his_emb = netW(his_input, format = 'index')

            ques_hidden = repackage_hidden(ques_hidden, batch_size)
            hist_hidden = repackage_hidden(hist_hidden, his_input.size(1))

            featD, ques_hidden = netE(ques_emb, his_emb, img_input, \
                                                ques_hidden, hist_hidden, rnd+1)

            ans_real_emb = netW(ans_target, format='index')
            ans_wrong_emb = netW(wrong_ans_input, format='index')

            real_hidden = repackage_hidden(real_hidden, batch_size)
            wrong_hidden = repackage_hidden(wrong_hidden, ans_wrong_emb.size(1))

            real_feat = netD(ans_real_emb, ans_target, real_hidden, vocab_size)
            wrong_feat = netD(ans_wrong_emb, wrong_ans_input, wrong_hidden, vocab_size)

            batch_wrong_feat = wrong_feat.index_select(0, batch_sample_idx.view(-1))
            wrong_feat = wrong_feat.view(batch_size, -1, opt.ninp)
            batch_wrong_feat = batch_wrong_feat.view(batch_size, -1, opt.ninp)

            nPairLoss = critD(featD, real_feat, wrong_feat, batch_wrong_feat)

            average_loss += nPairLoss.item()
            nPairLoss.backward()
            optimizer.step()
            count += 1

        bar.update(i)
        i += 1

    average_loss = average_loss / count

    return average_loss, lr
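The repackage_hidden helper used here (and repackage_hidden_new in the next example) detaches the recurrent hidden state from the previous graph and re-sizes it for the current batch. A sketch under that assumption:

import torch

def repackage_hidden(h, batch_size):
    # Assumed sketch: cut the hidden state off from its computation history
    # and reset it to zeros for the current batch size.
    if isinstance(h, torch.Tensor):
        # h: (num_layers * num_directions, old_batch, hidden_size)
        return h.detach().new_zeros(h.size(0), batch_size, h.size(2))
    # LSTM hidden state is an (h, c) tuple
    return tuple(repackage_hidden(v, batch_size) for v in h)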
Example #4
def train(epoch):
    netW.train()
    netE.train()
    netD.train()

    lr = adjust_learning_rate(optimizer, epoch, opt.lr)

    ques_hidden = netE.init_hidden(opt.batchSize)
    hist_hidden = netE.init_hidden(opt.batchSize)

    real_hidden = netD.init_hidden(opt.batchSize)
    wrong_hidden = netD.init_hidden(opt.batchSize)

    data_iter = iter(dataloader)

    average_loss = 0
    count = 0
    i = 0

    while i < len(dataloader):

        t1 = time.time()
        data = next(data_iter)
        image, history, question, answer, answerT, answerLen, answerIdx, questionL, \
        opt_answerT, opt_answerLen, opt_answerIdx = data

        batch_size = question.size(0)
        image = image.view(-1, img_feat_size)

        with torch.no_grad():
            img_input.resize_(image.size()).copy_(image)

        for rnd in range(10):
            netW.zero_grad()
            netE.zero_grad()
            netD.zero_grad()
            # get the corresponding round QA and history.
            ques = question[:, rnd, :].t()
            his = history[:, :rnd + 1, :].clone().view(-1, his_length).t()

            ans = answer[:, rnd, :].t()
            tans = answerT[:, rnd, :].t()
            wrong_ans = opt_answerT[:, rnd, :].clone().view(-1, ans_length).t()

            real_len = answerLen[:, rnd]
            wrong_len = opt_answerLen[:, rnd, :].clone().view(-1)

            ques_input = torch.LongTensor(ques.size())
            ques_input.copy_(ques)

            his_input = torch.LongTensor(his.size())
            his_input.copy_(his)

            ans_input = torch.LongTensor(ans.size())
            ans_input.copy_(ans)

            ans_target = torch.LongTensor(tans.size())
            ans_target.copy_(tans)

            wrong_ans_input = torch.LongTensor(wrong_ans.size())
            wrong_ans_input.copy_(wrong_ans)

            # sample in-batch negative index
            batch_sample_idx = torch.zeros(batch_size,
                                           opt.neg_batch_sample,
                                           dtype=torch.long)
            sample_batch_neg(answerIdx[:, rnd], opt_answerIdx[:, rnd, :],
                             batch_sample_idx, opt.neg_batch_sample)

            ques_emb = netW(ques_input, format='index')
            his_emb = netW(his_input, format='index')

            ques_hidden = repackage_hidden_new(ques_hidden, batch_size)
            hist_hidden = repackage_hidden_new(hist_hidden, his_input.size(1))

            featD, ques_hidden = netE(ques_emb, his_emb, img_input, \
                                      ques_hidden, hist_hidden, rnd + 1)

            ans_real_emb = netW(ans_target, format='index')
            ans_wrong_emb = netW(wrong_ans_input, format='index')

            real_hidden = repackage_hidden_new(real_hidden, batch_size)
            wrong_hidden = repackage_hidden_new(wrong_hidden,
                                                ans_wrong_emb.size(1))

            real_feat = netD(ans_real_emb, ans_target, real_hidden, vocab_size)
            wrong_feat = netD(ans_wrong_emb, wrong_ans_input, wrong_hidden,
                              vocab_size)

            batch_wrong_feat = wrong_feat.index_select(
                0, batch_sample_idx.view(-1))
            wrong_feat = wrong_feat.view(
                batch_size, -1,
                opt.ninp)  # (batch_size, negative_sample, ninp)
            batch_wrong_feat = batch_wrong_feat.view(
                batch_size, -1,
                opt.ninp)  # (batch_size, crossover_negative_sample, ninp)

            # All the correct answers are present at the beginning of combined_scores
            combined_scores, l2_norm = feat2score(
                featD, real_feat, wrong_feat,
                batch_wrong_feat)  # (batch_size, 1 + n_s + n_b_s)
            lambs, nPairLoss = critD(combined_scores)

            average_loss += nPairLoss.data.item()
            l2_norm.backward(retain_graph=True)
            combined_scores.backward(lambs)
            optimizer.step()
            count += 1

        i += 1
        if i % opt.log_interval == 0:
            average_loss /= count
            print("step {} / {} (epoch {}), g_loss {:.3f}, lr = {:.6f}" \
                  .format(i, len(dataloader), epoch, average_loss, lr))
            average_loss = 0
            count = 0

    return average_loss, lr
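All of these train() variants call adjust_learning_rate(optimizer, epoch, opt.lr) at the top of the epoch. A step-decay version with the same signature is sketched below; the decay rate and interval are placeholders, not values taken from the repository.

def adjust_learning_rate(optimizer, epoch, base_lr, decay_rate=0.5, decay_every=10):
    # Assumed sketch: decay the learning rate every `decay_every` epochs and
    # write the new value into every parameter group of the optimizer.
    lr = base_lr * (decay_rate ** (epoch // decay_every))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr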
Example #5
def train(epoch):
    netW.train()
    netE.train()
    netD.train()

    lr = adjust_learning_rate(optimizer, epoch, opt.lr)

    ques_hidden = netE.init_hidden(opt.batchSize)
    hist_hidden = netE.init_hidden(opt.batchSize)

    real_hidden = netD.init_hidden(opt.batchSize)
    wrong_hidden = netD.init_hidden(opt.batchSize)

    data_iter = iter(dataloader)

    average_loss = 0
    count = 0
    i = 0

    while i < len(dataloader):

        t1 = time.time()
        data = next(data_iter)
        image, history, question, answer, answerT, answerLen, answerIdx, questionL, \
                                    opt_answerT, opt_answerLen, opt_answerIdx = data

        batch_size = question.size(0)
        image = image.view(-1, img_feat_size)
        img_input.data.resize_(image.size()).copy_(image)

        for rnd in range(10):
            netW.zero_grad()
            netE.zero_grad()
            netD.zero_grad()
            # get the corresponding round QA and history.
            ques = question[:,rnd,:].t()
            his = history[:,:rnd+1,:].clone().view(-1, his_length).t()

            ans = answer[:,rnd,:].t()
            tans = answerT[:,rnd,:].t()
            wrong_ans = opt_answerT[:,rnd,:].clone().view(-1, ans_length).t()

            real_len = answerLen[:,rnd]
            wrong_len = opt_answerLen[:,rnd,:].clone().view(-1)

            ques_input.data.resize_(ques.size()).copy_(ques)
            his_input.data.resize_(his.size()).copy_(his)

            ans_input.data.resize_(ans.size()).copy_(ans)
            ans_target.data.resize_(tans.size()).copy_(tans)
            wrong_ans_input.data.resize_(wrong_ans.size()).copy_(wrong_ans)

            # sample in-batch negative index
            batch_sample_idx.data.resize_(batch_size, opt.neg_batch_sample).zero_()
            sample_batch_neg(answerIdx[:,rnd], opt_answerIdx[:,rnd,:], batch_sample_idx, opt.neg_batch_sample)

            ques_emb = netW(ques_input, format = 'index')
            his_emb = netW(his_input, format = 'index')

            ques_hidden = repackage_hidden(ques_hidden, batch_size)
            hist_hidden = repackage_hidden(hist_hidden, his_input.size(1))

            featD, ques_hidden = netE(ques_emb, his_emb, img_input, \
                                                ques_hidden, hist_hidden, rnd+1)

            ans_real_emb = netW(ans_target, format='index')
            ans_wrong_emb = netW(wrong_ans_input, format='index')

            real_hidden = repackage_hidden(real_hidden, batch_size)
            wrong_hidden = repackage_hidden(wrong_hidden, ans_wrong_emb.size(1))

            real_feat = netD(ans_real_emb, ans_target, real_hidden, vocab_size)
            wrong_feat = netD(ans_wrong_emb, wrong_ans_input, wrong_hidden, vocab_size)

            batch_wrong_feat = wrong_feat.index_select(0, batch_sample_idx.view(-1))
            wrong_feat = wrong_feat.view(batch_size, -1, opt.ninp)
            batch_wrong_feat = batch_wrong_feat.view(batch_size, -1, opt.ninp)

            nPairLoss = critD(featD, real_feat, wrong_feat, batch_wrong_feat)

            average_loss += nPairLoss.item()
            nPairLoss.backward()
            optimizer.step()
            count += 1

        i += 1
        if i % opt.log_interval == 0:
            average_loss /= count
            print("step {} / {} (epoch {}), g_loss {:.3f}, lr = {:.6f}"\
                .format(i, len(dataloader), epoch, average_loss, lr))
            average_loss = 0
            count = 0

    return average_loss
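netW is called with format='index' for token ids and, in the GAN examples below, with format='onehot' for Gumbel-softmax distributions. A minimal embedding module supporting both call styles could look like this; the class name and dropout value are assumptions.

import torch
import torch.nn as nn

class WordEmbed(nn.Module):
    # Assumed sketch of netW: one embedding table shared between hard token
    # indices and soft one-hot distributions coming from the Gumbel sampler.
    def __init__(self, vocab_size, ninp, dropout=0.5):
        super().__init__()
        self.embed = nn.Embedding(vocab_size + 1, ninp, padding_idx=0)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, format='index'):
        if format == 'index':
            emb = self.embed(x)                     # (seq, batch, ninp)
        else:  # 'onehot': x is (seq * batch, vocab_size + 1)
            emb = torch.mm(x, self.embed.weight)    # soft lookup by matrix product
        return self.drop(emb)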
Example #6
def train(epoch):
    netW_d.train()
    netE_d.train()
    netE_g.train()
    netD.train()
    netG.train()
    netW_g.train()

    fake_len = torch.LongTensor(opt.batchSize)
    fake_len = fake_len.cpu()

    n_neg = opt.negative_sample
    ques_hidden1 = netE_d.init_hidden(opt.batchSize)
    ques_hidden2 = netE_g.init_hidden(opt.batchSize)

    hist_hidden1 = netE_d.init_hidden(opt.batchSize)
    hist_hidden2 = netE_g.init_hidden(opt.batchSize)

    real_hidden = netD.init_hidden(opt.batchSize)
    wrong_hidden = netD.init_hidden(opt.batchSize)
    fake_hidden = netD.init_hidden(opt.batchSize)

    data_iter = iter(dataloader)

    err_d = 0
    err_g = 0
    err_lm = 0
    average_loss = 0
    count = 0
    i = 0
    loss_store = []
    while i < len(dataloader):
        t1 = time.time()
        data = next(data_iter)
        image, history, question, answer, answerT, answerLen, answerIdx, questionL, \
                                    opt_answerT, opt_answerLen, opt_answerIdx = data
        batch_size = question.size(0)
        image = image.view(-1, 512)
        img_input.data.resize_(image.size()).copy_(image)

        err_d_tmp = 0
        err_g_tmp = 0
        err_lm_tmp = 0
        err_d_fake_tmp = 0
        err_g_fake_tmp = 0
        for rnd in range(10):
            # get the corresponding round QA and history.
            ques = question[:, rnd, :].t()
            his = history[:, :rnd + 1, :].clone().view(-1, his_length).t()

            ans = answer[:, rnd, :].t()
            tans = answerT[:, rnd, :].t()
            wrong_ans = opt_answerT[:, rnd, :].clone().view(-1, ans_length).t()

            real_len = answerLen[:, rnd].long()
            wrong_len = opt_answerLen[:, rnd, :].clone().view(-1)

            ques_input.data.resize_(ques.size()).copy_(ques)
            his_input.data.resize_(his.size()).copy_(his)

            ans_input.data.resize_(ans.size()).copy_(ans)
            ans_target.data.resize_(tans.size()).copy_(tans)
            wrong_ans_input.data.resize_(wrong_ans.size()).copy_(wrong_ans)

            batch_sample_idx.data.resize_(batch_size,
                                          opt.neg_batch_sample).zero_()
            sample_batch_neg(answerIdx[:, rnd], opt_answerIdx[:, rnd, :],
                             batch_sample_idx, opt.neg_batch_sample)

            # -----------------------------------------
            # update the Generator using MLE loss.
            # -----------------------------------------
            if opt.update_LM:
                ques_emb_g = netW_g(ques_input, format='index')
                his_emb_g = netW_g(his_input, format='index')

                ques_hidden1 = repackage_hidden(ques_hidden1, batch_size)
                hist_hidden1 = repackage_hidden(hist_hidden1,
                                                his_emb_g.size(1))

                featG, ques_hidden1 = netE_g(ques_emb_g, his_emb_g, img_input, \
                                                    ques_hidden1, hist_hidden1, rnd+1)

                _, ques_hidden1 = netG(featG.view(1, -1, opt.ninp),
                                       ques_hidden1)
                # MLE loss for generator
                ans_emb = netW_g(ans_input)
                logprob, _ = netG(ans_emb, ques_hidden1)
                lm_loss = critLM(logprob, ans_target.view(-1, 1))
                lm_loss = lm_loss / torch.sum(ans_target.data.gt(0))

                netW_g.zero_grad()
                netG.zero_grad()
                netE_g.zero_grad()

                lm_loss.backward()
                optimizerLM.step()
                err_lm += lm_loss.item()
                err_lm_tmp += lm_loss.item()

            # sample the answer using the Gumbel-softmax sampler.
            ques_emb_g = netW_g(ques_input, format='index')
            his_emb_g = netW_g(his_input, format='index')

            ques_hidden1 = repackage_hidden(ques_hidden1, batch_size)
            hist_hidden1 = repackage_hidden(hist_hidden1, his_emb_g.size(1))

            featG, ques_hidden1 = netE_g(ques_emb_g, his_emb_g, img_input, \
                                                ques_hidden1, hist_hidden1, rnd+1)

            _, ques_hidden1 = netG(featG.view(1, -1, opt.ninp), ques_hidden1)

            # Gumbel softmax to sample the output.
            fake_onehot = []
            fake_idx = []
            noise_input.data.resize_(ans_length, batch_size, vocab_size + 1)
            noise_input.data.uniform_(0, 1)

            ans_sample = ans_input[0]
            for di in range(ans_length):
                ans_emb = netW_g(ans_sample, format='index')
                logprob, ques_hidden1 = netG(ans_emb.view(1, -1, opt.ninp),
                                             ques_hidden1)
                one_hot, idx = sampler(logprob, noise_input[di],
                                       opt.gumble_weight)
                fake_onehot.append(one_hot.view(1, -1, vocab_size + 1))
                fake_idx.append(idx)
                if di + 1 < ans_length:
                    ans_sample = idx

            # convert the list into the tensor variable.
            fake_onehot = torch.cat(fake_onehot, 0)
            fake_idx = torch.cat(fake_idx, 0)

            fake_len.resize_(batch_size).fill_(ans_length - 1)
            for di in range(ans_length - 1, 0, -1):
                fake_len.masked_fill_(fake_idx.data[di].eq(vocab_size), di)

            # generate fake mask.
            fake_mask.data.resize_(fake_idx.size()).fill_(1)
            # get the real, wrong and fake index.
            for b in range(batch_size):
                fake_mask.data[:fake_len[b] + 1, b] = 0

            # apply the mask on the fake_idx.
            fake_idx.masked_fill_(fake_mask, 0)

            # get the fake diff mask.
            #fake_diff_mask = torch.sum(fake_idx == ans_target, 0) != 0
            fake_onehot = fake_onehot.view(-1, vocab_size + 1)

            ######################################
            # Discriminative trained generative model.
            ######################################
            # forward the discriminator again.
            ques_emb_d = netW_d(ques_input, format='index')
            his_emb_d = netW_d(his_input, format='index')

            ques_hidden2 = repackage_hidden(ques_hidden2, batch_size)
            hist_hidden2 = repackage_hidden(hist_hidden2, his_emb_d.size(1))

            featD, _ = netE_d(ques_emb_d, his_emb_d, img_input, \
                                        ques_hidden2, hist_hidden2, rnd+1)

            ans_real_emb = netW_d(ans_target, format='index')
            #ans_wrong_emb = netW_d(wrong_ans_input, format='index')
            ans_fake_emb = netW_d(fake_onehot, format='onehot')
            ans_fake_emb = ans_fake_emb.view(ans_length, -1, opt.ninp)

            real_hidden = repackage_hidden(real_hidden, batch_size)
            #wrong_hidden = repackage_hidden(wrong_hidden, ans_wrong_emb.size(1))
            fake_hidden = repackage_hidden(fake_hidden, batch_size)

            fake_feat = netD(ans_fake_emb, fake_idx, fake_hidden, vocab_size)
            real_feat = netD(ans_real_emb, ans_target, real_hidden, vocab_size)

            d_g_loss, g_fake = critG(featD, real_feat,
                                     fake_feat)  #, fake_diff_mask.detach())

            netW_g.zero_grad()
            netG.zero_grad()
            netE_g.zero_grad()

            d_g_loss.backward()
            optimizerG.step()

            err_g += d_g_loss.item()
            err_g_tmp += d_g_loss.item()
            err_g_fake_tmp += g_fake

            count += 1

        i += 1
        loss_store.append({'iter':i, 'err_lm':err_lm_tmp/10, 'err_d':err_d_tmp/10, 'err_g':err_g_tmp/10, \
                            'd_fake': err_d_fake_tmp/10, 'g_fake':err_g_fake_tmp/10})

        if i % 20 == 0:
            print ('Epoch:%d %d/%d, err_lm %4f, err_d %4f, err_g %4f, d_fake %4f, g_fake %4f' \
                % (epoch, i, len(dataloader), err_lm_tmp/10, err_d_tmp/10, err_g_tmp/10, err_d_fake_tmp/10, \
                    err_g_fake_tmp/10))

    #average_loss = average_loss / count
    err_g = err_g / count
    err_d = err_d / count
    err_lm = err_lm / count

    return err_lm, err_d, err_g, loss_store
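The sampler(logprob, noise, opt.gumble_weight) call above draws a token per step with the Gumbel-softmax trick, returning a soft one-hot (for differentiable embedding lookup) and the hard index that is fed back as the next input. A sketch under those assumptions:

import torch
import torch.nn.functional as F

def sampler(logprob, noise, temperature):
    # Assumed sketch: perturb log-probabilities with Gumbel noise derived from
    # the supplied uniform(0, 1) tensor, soften with a temperature, and return
    # the soft distribution plus the argmax token index for each batch element.
    gumbel = -torch.log(-torch.log(noise + 1e-20) + 1e-20)
    one_hot = F.softmax((logprob + gumbel) / temperature, dim=1)
    idx = one_hot.argmax(dim=1, keepdim=True)   # (batch, 1) hard indices
    return one_hot, idx.t()                     # idx as (1, batch) so fake_idx stacks along dim 0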
Example #7
def train(epoch):
    netW_d.train()
    netE_d.train()
    netE_g.train()
    netD.train()
    netG.train()
    netW_g.train()

    fake_len = torch.LongTensor(opt.batchSize)
    fake_len = fake_len.cuda()

    n_neg = opt.negative_sample
    ques_hidden1 = netE_d.init_hidden(opt.batchSize)
    ques_hidden2 = netE_g.init_hidden(opt.batchSize)

    hist_hidden1 = netE_d.init_hidden(opt.batchSize)
    hist_hidden2 = netE_g.init_hidden(opt.batchSize)

    real_hidden = netD.init_hidden(opt.batchSize)
    wrong_hidden = netD.init_hidden(opt.batchSize)
    fake_hidden = netD.init_hidden(opt.batchSize)

    data_iter = iter(dataloader)

    err_d = 0
    err_g = 0
    err_lm = 0
    average_loss = 0
    count = 0
    i = 0
    loss_store = []
    while i < len(dataloader):
        t1 = time.time()
        data = next(data_iter)
        image, history, question, answer, answerT, answerLen, answerIdx, questionL, \
                                    opt_answerT, opt_answerLen, opt_answerIdx = data
        batch_size = question.size(0)
        image = image.view(-1, 512)
        img_input.data.resize_(image.size()).copy_(image)

        err_d_tmp = 0
        err_g_tmp = 0
        err_lm_tmp = 0
        err_d_fake_tmp = 0
        err_g_fake_tmp = 0
        for rnd in range(10):
            # get the corresponding round QA and history.
            ques = question[:,rnd,:].t()
            his = history[:,:rnd+1,:].clone().view(-1, his_length).t()

            ans = answer[:,rnd,:].t()
            tans = answerT[:,rnd,:].t()
            wrong_ans = opt_answerT[:,rnd,:].clone().view(-1, ans_length).t()

            real_len = answerLen[:,rnd].long()
            wrong_len = opt_answerLen[:,rnd,:].clone().view(-1)

            ques_input.data.resize_(ques.size()).copy_(ques)
            his_input.data.resize_(his.size()).copy_(his)

            ans_input.data.resize_(ans.size()).copy_(ans)
            ans_target.data.resize_(tans.size()).copy_(tans)
            wrong_ans_input.data.resize_(wrong_ans.size()).copy_(wrong_ans)

            batch_sample_idx.data.resize_(batch_size, opt.neg_batch_sample).zero_()
            sample_batch_neg(answerIdx[:,rnd], opt_answerIdx[:,rnd,:], batch_sample_idx, opt.neg_batch_sample)

            # -----------------------------------------
            # update the Generator using MLE loss.
            # -----------------------------------------
            if opt.update_LM:
                ques_emb_g = netW_g(ques_input, format = 'index')
                his_emb_g = netW_g(his_input, format = 'index')

                ques_hidden1 = repackage_hidden(ques_hidden1, batch_size)
                hist_hidden1 = repackage_hidden(hist_hidden1, his_emb_g.size(1))

                featG, ques_hidden1 = netE_g(ques_emb_g, his_emb_g, img_input, \
                                                    ques_hidden1, hist_hidden1, rnd+1)

                _, ques_hidden1 = netG(featG.view(1, -1, opt.ninp), ques_hidden1)
                # MLE loss for generator
                ans_emb = netW_g(ans_input)
                logprob, _ = netG(ans_emb, ques_hidden1)
                lm_loss = critLM(logprob, ans_target.view(-1, 1))
                lm_loss = lm_loss / torch.sum(ans_target.data.gt(0))

                netW_g.zero_grad()
                netG.zero_grad()
                netE_g.zero_grad()

                lm_loss.backward()
                optimizerLM.step()
                err_lm += lm_loss.item()
                err_lm_tmp += lm_loss.item()

            # sample the answer using the Gumbel-softmax sampler.
            ques_emb_g = netW_g(ques_input, format = 'index')
            his_emb_g = netW_g(his_input, format = 'index')

            ques_hidden1 = repackage_hidden(ques_hidden1, batch_size)
            hist_hidden1 = repackage_hidden(hist_hidden1, his_emb_g.size(1))

            featG, ques_hidden1 = netE_g(ques_emb_g, his_emb_g, img_input, \
                                                ques_hidden1, hist_hidden1, rnd+1)

            _, ques_hidden1 = netG(featG.view(1, -1, opt.ninp), ques_hidden1)

            # Gumbel softmax to sample the output.
            fake_onehot = []
            fake_idx = []
            noise_input.data.resize_(ans_length, batch_size, vocab_size+1)
            noise_input.data.uniform_(0,1)

            ans_sample = ans_input[0]
            for di in range(ans_length):
                ans_emb = netW_g(ans_sample, format = 'index')
                logprob, ques_hidden1 = netG(ans_emb.view(1,-1,opt.ninp), ques_hidden1)
                one_hot, idx = sampler(logprob, noise_input[di], opt.gumble_weight)
                fake_onehot.append(one_hot.view(1, -1, vocab_size+1))
                fake_idx.append(idx)
                if di+1 < ans_length:
                    ans_sample = idx

            # convert the list into the tensor variable.
            fake_onehot = torch.cat(fake_onehot, 0)
            fake_idx = torch.cat(fake_idx,0)

            fake_len.resize_(batch_size).fill_(ans_length-1)
            for di in range(ans_length-1, 0, -1):
                fake_len.masked_fill_(fake_idx.data[di].eq(vocab_size), di)

            # generate fake mask.
            fake_mask.data.resize_(fake_idx.size()).fill_(1)
            # get the real, wrong and fake index.
            for b in range(batch_size):
                fake_mask.data[:fake_len[b]+1, b] = 0

            # apply the mask on the fake_idx.
            fake_idx.masked_fill_(fake_mask, 0)

            # get the fake diff mask.
            #fake_diff_mask = torch.sum(fake_idx == ans_target, 0) != 0
            fake_onehot = fake_onehot.view(-1, vocab_size+1)
            
            ######################################
            # Discriminative trained generative model.
            ######################################
            # forward the discriminator again.
            ques_emb_d = netW_d(ques_input, format = 'index')
            his_emb_d = netW_d(his_input, format = 'index')

            ques_hidden2 = repackage_hidden(ques_hidden2, batch_size)
            hist_hidden2 = repackage_hidden(hist_hidden2, his_emb_d.size(1))

            featD, _ = netE_d(ques_emb_d, his_emb_d, img_input, \
                                        ques_hidden2, hist_hidden2, rnd+1)

            ans_real_emb = netW_d(ans_target, format='index')
            #ans_wrong_emb = netW_d(wrong_ans_input, format='index')
            ans_fake_emb = netW_d(fake_onehot, format='onehot')
            ans_fake_emb = ans_fake_emb.view(ans_length, -1, opt.ninp)

            real_hidden = repackage_hidden(real_hidden, batch_size)
            #wrong_hidden = repackage_hidden(wrong_hidden, ans_wrong_emb.size(1))
            fake_hidden = repackage_hidden(fake_hidden, batch_size)

            fake_feat = netD(ans_fake_emb, fake_idx, fake_hidden, vocab_size)
            real_feat = netD(ans_real_emb, ans_target, real_hidden, vocab_size)

            d_g_loss, g_fake = critG(featD, real_feat, fake_feat)#, fake_diff_mask.detach())

            netW_g.zero_grad()
            netG.zero_grad()
            netE_g.zero_grad()

            d_g_loss.backward()
            optimizerG.step()

            err_g += d_g_loss.item()
            err_g_tmp += d_g_loss.item()
            err_g_fake_tmp += g_fake

            count += 1

        i += 1
        loss_store.append({'iter':i, 'err_lm':err_lm_tmp/10, 'err_d':err_d_tmp/10, 'err_g':err_g_tmp/10, \
                            'd_fake': err_d_fake_tmp/10, 'g_fake':err_g_fake_tmp/10})

        if i % 20 == 0:
            print ('Epoch:%d %d/%d, err_lm %4f, err_d %4f, err_g %4f, d_fake %4f, g_fake %4f' \
                % (epoch, i, len(dataloader), err_lm_tmp/10, err_d_tmp/10, err_g_tmp/10, err_d_fake_tmp/10, \
                    err_g_fake_tmp/10))


    #average_loss = average_loss / count
    err_g = err_g / count
    err_d = err_d / count
    err_lm = err_lm / count

    return err_lm, err_d, err_g, loss_store
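critG(featD, real_feat, fake_feat) in the two GAN examples returns the generator's adversarial loss plus a scalar used for logging (g_fake). One plausible form, sketched here purely as an assumption, asks the generated answer's feature to score at least as well as the real answer under the discriminator's question feature:

import torch
import torch.nn as nn

class GLoss(nn.Module):
    # Assumed sketch of critG: encourage the fake answer feature to match the
    # question feature at least as well as the real answer does, and report the
    # mean fake score so the caller can log it as g_fake.
    def forward(self, feat, real_feat, fake_feat):
        real_score = (feat * real_feat).sum(dim=1)   # (batch,)
        fake_score = (feat * fake_feat).sum(dim=1)
        loss = torch.log1p(torch.exp(real_score - fake_score)).mean()
        return loss, fake_score.mean().item()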
def train(epoch):
    tgt_netW.train()
    src_netW.train()
    src_netE_att.train()
    tgt_netE_att.train()
    netG.train()
    save_path = opt.save_model
    # alpha.train()

    if opt.cuda:
        tgt_netW.cuda()
        src_netW.cuda()
        src_netE_att.cuda()
        tgt_netE_att.cuda()
        netG.cuda()

    fake_len = torch.LongTensor(opt.batch_size)

    fake_len = fake_len.cuda()

    n_neg = opt.negative_sample

    data_iter = iter(dataloader)

    err_d = 0
    err_g = 0
    err_lm = 0
    average_loss = 0
    count = 0
    i = 0
    loss_store = []
    t = time.time()

    while i < len(dataloader)-1:
        t1 = time.time()
        data = next(data_iter)
        # image, history, question, answer, answerT, answerLen, answerIdx, questionL, opt_answerT, opt_answerLen, opt_answerIdx = data
        questionLen, question, answer, answerT, answerLen, answerIdx, questionL, _, opt_answerT, opt_answerLen, opt_answerIdx, answer_ids = data

        # print("question", question,question.size())
        # print("questionL", questionL,questionL.size())
        # print("questionLen", questionLen, questionLen.size())
        # exit()
        batch_size = question.size(0)

        # image = image.view(-1, 512)
        # img_input.data.resize_(image.size()).copy_(image)

        err_d_tmp = 0.
        err_g_tmp = 0.
        err_lm_tmp = 0.
        err_g_fake_tmp = 0.
        err_d_fake_tmp = 0.

        ques = question[:,:].t()
        ques_g = questionL[:,:].t()

        ans = answer[:,:].t()
        if opt.debug:
            ques_words = [src_itos[int(w)] for w in ques_g]
            answer_words = [tgt_itos[int(w)] for w in answer[0]]
            print("train_all.py train() ques_g.size()", ques_g.size())
            print("train_all.py train() answer.size()", answer.size())
            print("train_all.py train() ques_g", ques_g, len(ques_g))
            print("train_all.py train() answer", answer, len(answer))
            print("train_all.py train() ques_words", ques_words, len(ques_words))
            print("train_all.py train() answer_words", answer_words, len(answer_words))
            # print("train_all.py ans", ans.size())
        tans = answerT[:,:].t()
        wrong_ans = opt_answerT[:,:].clone().view(-1, tgt_length).t()

        real_len = answerLen[:].long()
        wrong_len = opt_answerLen[:,:].clone().view(-1)

        ques_input.data.resize_(ques.size()).copy_(ques)
        src_input_g.data.resize_(ques.size()).copy_(ques_g)
        # print("train_all.py ques_input",ques_input, ques_input.size(), ques_input.device)

        ans_input.data.resize_(ans.size()).copy_(ans)
        ans_target.data.resize_(tans.size()).copy_(tans)
        wrong_ans_input.data.resize_(wrong_ans.size()).copy_(wrong_ans)


        batch_sample_idx.data.resize_(batch_size, opt.neg_batch_sample).zero_()
        sample_batch_neg(answerIdx[:], opt_answerIdx[:,:], batch_sample_idx, opt.neg_batch_sample)


        src = src_input_g.view(-1, batch_size, 1).clone().to(additional_device0)
        tgt = ans_input.view(-1, batch_size, 1).clone().to(additional_device0)
        src_length = questionLen.view(batch_size)

        # if opt.update_G:
        fake_onehot = []
        fake_idx = []
        noise_input.data.resize_(tgt_length, batch_size, tgt_vocab_size)
        noise_input.data.uniform_(0,1)
        # print("train_all.py train() noise_input", noise_input.size())

        ans_sample = ans_input[0]
        enc_final, memory_bank = netG.encoder(src, src_length)

        dec_state = netG.decoder.init_decoder_state(src, memory_bank, enc_final)
        memory_bank = tile(memory_bank, 1, dim=1)
        memory_lengths = tile(src_length, 1)
        
        alive_seq = torch.full(
            [opt.batch_size, 1],
            start_token,
            dtype=torch.long,
            device=additional_device0)

        for step in range(tgt_length):

            decoder_input = alive_seq[:, -1].view(1, -1, 1)

            dec_out, dec_state, _ = netG.decoder(decoder_input,
                memory_bank,
                dec_state,
                memory_lengths=memory_lengths,
                step=step)

            logprob = netG.generator(dec_out.squeeze(0))

            one_hot, idx = sampler(logprob, noise_input[step].to(additional_device0), opt.gumble_weight)
            fake_onehot.append(one_hot.view(1, -1, tgt_vocab_size))
            fake_idx.append(idx)
            if step+1 < tgt_length:
                ans_sample = idx

            alive_seq = torch.cat(
                (alive_seq, idx.view(-1,1)), -1)

        fake_onehot = torch.cat(fake_onehot, 0)

        fake_idx = torch.cat(fake_idx,0)
        fake_len = fake_len.resize_(batch_size).fill_(tgt_length-1).clone()
        for di in range(tgt_length-1, 0, -1):
            fake_len.masked_fill_(fake_idx.data[di].eq(tgt_vocab_size), di)

        fake_mask.data.resize_(fake_idx.size()).fill_(1)
        for b in range(batch_size):
            fake_mask.data[:fake_len[b]+1, b] = 0

        fake_idx.masked_fill_(fake_mask.clone(), 0)

        fake_onehot = fake_onehot.view(-1, tgt_vocab_size)
        ques_emb_d = src_netW(ques_input, format = 'index')

        featD, _ = src_netE_att(ques_emb_d, ques_input)
        ans_real_emb = tgt_netW(ans_target, format='index')
        ans_fake_emb = tgt_netW(fake_onehot.to('cuda'), format='onehot')
        ans_fake_emb = ans_fake_emb.view(tgt_length, -1, opt.rnn_size)

        fake_feat, _ = tgt_netE_att(ans_fake_emb.cuda(), fake_idx.cuda())
        real_feat, _ = tgt_netE_att(ans_real_emb.cuda(), ans_target)


        ##############################   update_D   ###########################################
        src_netW.zero_grad()
        tgt_netW.zero_grad()
        src_netE_att.zero_grad()
        tgt_netE_att.zero_grad()
        ans_wrong_emb = tgt_netW(wrong_ans_input, format='index')

        tgt_wrong_feat, _weight_ = tgt_netE_att(ans_wrong_emb, wrong_ans_input)

        batch_wrong_feat = tgt_wrong_feat.index_select(0, batch_sample_idx.view(-1))
        tgt_wrong_feat = tgt_wrong_feat.view(batch_size, -1, opt.rnn_size)
        batch_wrong_feat = batch_wrong_feat.view(batch_size, -1, opt.rnn_size)
        # print("featD",featD.size())
        # print("fake_feat",fake_feat.size())
        # batch_wrong_feat = fake_feat.index_select(0, batch_sample_idx.view(-1))
        # fake_feat = fake_feat.view(batch_size, -1, opt.rnn_size)
        # batch_wrong_feat = batch_wrong_feat.view(batch_size, -1, opt.rnn_size)
        nPairLoss, d_fake = critD(featD, real_feat, tgt_wrong_feat, batch_wrong_feat, fake_feat)
        # print("train_gan.py d_fake", d_fake, float(d_fake))
        nPairLoss.backward(retain_graph=True)
        optimizerD.step()
        err_d += nPairLoss.data.item()
        err_d_tmp += nPairLoss.data.item()
        err_d_fake_tmp += d_fake.data.item()
        ##############################   update_D   ###########################################



        ##############################   update_G  #############################################
        d_g_loss, g_fake = critG(featD, real_feat, fake_feat)
        bleu_score = BLEU_score(ans_target, fake_idx)
        netG.zero_grad()
        d_g_loss.backward()
        optimizerG.step()
        err_g += d_g_loss.item()
        err_g_tmp += d_g_loss.item()
        err_g_fake_tmp += g_fake
        ##############################   update_G  #############################################


        ##############################   update_LM   ###########################################
        outputs, _, _ = netG(src, tgt, src_length, None)  
        logprob = netG.generator(outputs.view(-1, outputs.size(2)))
        lm_loss = critLM(logprob, tgt[1:, :, :].view(-1, 1))
        lm_loss = lm_loss / float(torch.sum(ans_target[1:, :].data.ne(1)))
        lm_loss = lm_loss*bleu_score
        netG.zero_grad()
        lm_loss.backward()
        optimizerLM.step()
        err_lm += lm_loss.item()
        err_lm_tmp += lm_loss.item()
        ##############################   update_LM   ###########################################

        count += 1

        i += 1
        loss_store.append({'iter':i, 'err_lm':err_lm_tmp/10, 'err_g':err_g_tmp/10, 'g_fake':err_g_fake_tmp/10, 'err_d':err_d_tmp})

        if i % opt.report_every == 0:
            print ('Epoch:%d %d/%d, err_lm %4f, err_g %4f, g_fake %4f, err_d %4f, d_fake %4f, Time: %.3f' % (epoch, i, len(dataloader), err_lm_tmp/10, err_g_tmp/10, err_g_fake_tmp/10, err_d_tmp/10, err_d_fake_tmp/10, time.time()-t))
            t = time.time()

        if i % opt.save_checkpoint_steps == 0:
            print("Saving ... ")
            torch.save({'epoch': epoch,
                    'opt': opt,
                    'src_netW': src_netW.state_dict(),
                    'tgt_netW': tgt_netW.state_dict(),
                    'tgt_netE_att': tgt_netE_att.state_dict(),
                    'src_netE_att': src_netE_att.state_dict(),
                    'netG': netG.state_dict(),
                    'optimizerD': optimizerD,
                    'optimizerG': optimizerG,
                    'optimizerLM': optimizerLM,
                    },
                    '%s/epoch_%d_%d.pth' % (save_path, epoch, i))

        # if i >= 20:break
    err_g = err_g / count
    err_lm = err_lm / count
    err_d = err_d / count
    return err_lm, err_g, err_d, loss_store
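The torch.save call inside train() writes epoch, opt, the five module state_dicts, and the optimizer objects into '%s/epoch_%d_%d.pth'. A matching restore helper, sketched against that layout (the function name and map_location choice are assumptions), would be:

import torch

def load_checkpoint(path, src_netW, tgt_netW, src_netE_att, tgt_netE_att, netG):
    # Restore the module weights written by the torch.save call in train().
    # The optimizers were pickled whole rather than as state_dicts, so they are
    # left for the caller to rebuild.
    checkpoint = torch.load(path, map_location='cpu')
    src_netW.load_state_dict(checkpoint['src_netW'])
    tgt_netW.load_state_dict(checkpoint['tgt_netW'])
    src_netE_att.load_state_dict(checkpoint['src_netE_att'])
    tgt_netE_att.load_state_dict(checkpoint['tgt_netE_att'])
    netG.load_state_dict(checkpoint['netG'])
    return checkpoint['epoch'], checkpoint['opt']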