Code Example #1
File: main.py  Project: pwstt/seqGan_music
def train_discriminator(discriminator, dis_opt, real_data_samples, generator,
                        real_val, d_steps, epochs):
    """
    Training the discriminator on real_data_samples (positive) and generated samples from generator (negative).
    Samples are drawn d_steps times, and the discriminator is trained for epochs epochs.
    """

    pos_val = helpers.positive_sample(real_val, 100)
    neg_val = generator.sample(100, MAX_SEQ_LEN)
    val_inp, val_target = helpers.prepare_discriminator_data(pos_val,
                                                             neg_val,
                                                             gpu=CUDA)

    for d_step in range(d_steps):

        s = helpers.batchwise_sample(generator, POS_NEG_SAMPLES, BATCH_SIZE,
                                     MAX_SEQ_LEN)

        dis_inp, dis_target = helpers.prepare_discriminator_data(
            real_data_samples, s, gpu=CUDA)

        for epoch in range(epochs):
            print('d-step %d epoch %d : ' % (d_step + 1, epoch + 1), end='')
            sys.stdout.flush()
            total_loss = 0
            total_acc = 0

            for i in range(0, 2 * POS_NEG_SAMPLES, BATCH_SIZE):

                inp, target = dis_inp[i:i + BATCH_SIZE], dis_target[i:i + BATCH_SIZE]

                dis_opt.zero_grad()
                out = discriminator.batchClassify(inp)
                loss_fn = nn.BCELoss()
                loss = loss_fn(out, target)
                loss.backward()
                dis_opt.step()

                total_loss += loss.data.item()
                total_acc += torch.sum(
                    (out > 0.5) == (target > 0.5)).data.item()

                if (i / BATCH_SIZE) % ceil(
                        ceil(2 * POS_NEG_SAMPLES / float(BATCH_SIZE)) /
                        10.) == 0:  # roughly every 10% of an epoch
                    print('.', end='')
                    sys.stdout.flush()

            total_loss /= ceil(2 * POS_NEG_SAMPLES / float(BATCH_SIZE))
            total_acc /= float(2 * POS_NEG_SAMPLES)

            val_pred = discriminator.batchClassify(val_inp)
            print(' average_loss = %.4f, train_acc = %.4f, val_acc = %.4f' %
                  (total_loss, total_acc,
                   torch.sum(
                       (val_pred > 0.5) == (val_target > 0.5)).data.item() /
                   200.))
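
Every example on this page delegates batch assembly to helpers.prepare_discriminator_data, whose source is not shown. A minimal sketch consistent with how it is called above (positives labeled 1, negatives labeled 0, jointly shuffled, optionally moved to the GPU) might look like the following; the actual helper may differ per project:

import torch

def prepare_discriminator_data(pos_samples, neg_samples, gpu=False):
    # Stack positives then negatives: (2 * n, seq_len) token ids
    inp = torch.cat((pos_samples, neg_samples), 0).long()
    target = torch.ones(pos_samples.size(0) + neg_samples.size(0))
    target[pos_samples.size(0):] = 0  # negatives get label 0

    # Shuffle inputs and targets with the same permutation
    perm = torch.randperm(target.size(0))
    inp, target = inp[perm], target[perm]

    if gpu:
        inp, target = inp.cuda(), target.cuda()
    return inp, target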
Code Example #2
def train_discriminator(context, real_reply, discriminator, dis_opt, generator,
                        corpus):
    """
    Training the discriminator on real_data_samples (positive) and generated samples from generator (negative).
    Samples are drawn d_steps times, and the discriminator is trained for epochs epochs.
    """
    # Batchsize is 32
    # context is 32 x max_context_size

    fake_reply, _, _ = generator.sample(context.permute(1, 0), MAX_SEQ_LEN)

    # UNCOMMENT FOR PRINTING SAMPLES AND CONTEXT

    # print(corpus.ids_to_tokens([int(i) for i in context[0]]))
    # print("Fake generated reply")
    # print(corpus.ids_to_tokens([int(i) for i in fake_reply[0]]))
    # print("Real  reply")
    # print(corpus.ids_to_tokens([int(i) for i in real_reply[0]]))
    # print(30 * "-")
    dis_opt.zero_grad()
    if DISCRIMINATOR_LM:
        # Reward-based loss: push the discriminator LM to score real replies above fake ones.
        fake_rewards = -torch.mean(discriminator.get_rewards(fake_reply), dim=1)
        real_rewards = -torch.mean(discriminator.get_rewards(real_reply), dim=1)
        loss = -torch.mean(real_rewards - fake_rewards)
    else:
        out_fake = discriminator.batchClassify(context, fake_reply.long())
        out_real = discriminator.batchClassify(context, real_reply.long())

        # Targets: 0 for generated replies, 1 for real ones, on the same device as the outputs.
        fake_targets = torch.zeros(BATCH_SIZE, device=out_fake.device)
        real_targets = torch.ones(BATCH_SIZE, device=out_real.device)

        loss_fn = nn.BCELoss()
        loss_fake = loss_fn(out_fake, fake_targets)
        loss_real = loss_fn(out_real, real_targets)

        loss = loss_real + loss_fake
        total_loss = loss.item()
        correct_real = torch.sum(out_real > 0.5).item() / BATCH_SIZE
        correct_fake = torch.sum(out_fake < 0.5).item() / BATCH_SIZE
        total_acc = (correct_real + correct_fake) / 2
        print(' average_loss = %.4f, train_acc = %.4f' %
              (total_loss, total_acc))
    loss.backward()
    dis_opt.step()
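
For the DISCRIMINATOR_LM branch, the double negation is easy to misread. If get_rewards returns a per-token cost (an assumption; the method is not shown on this page), the loss above reduces to mean(cost on real) - mean(cost on fake), so minimizing it drives the discriminator LM to assign lower cost to real replies and higher cost to fake ones. A small numeric check with made-up values:

import torch

cost_real = torch.tensor([[1.0, 1.2], [0.8, 1.1]])  # hypothetical per-token costs on real replies
cost_fake = torch.tensor([[2.5, 2.1], [1.9, 2.4]])  # hypothetical costs on generated replies
real_rewards = -torch.mean(cost_real, dim=1)
fake_rewards = -torch.mean(cost_fake, dim=1)
loss = -torch.mean(real_rewards - fake_rewards)
print(loss)  # tensor(-1.2000): negative because real replies already score better here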
Code Example #3
    def eval(val_iter, discriminator, generator):
        # validation
        discriminator.eval()
        print('validation :', end=' ')
        total_acc = 0
        num_samples = 0
        total_loss = 0
        for i, data in enumerate(val_iter):
            tgt_data = data.target[0].permute(1, 0)  # batch_size X length
            src_data_wrap = data.source
            ans = data.answer[0]

            if CUDA:
                scr_data = data.source[0].to(device)
                scr_lengths = data.source[1].to(device)
                ans = ans.to(device)
                src_data_wrap = (scr_data, scr_lengths, ans)

            real_samples = tgt_data
            real_lengths = data.target[1]
            passage = src_data_wrap[0].permute(1, 0)

            with torch.no_grad():
                fake_samples, fake_lengths = generator.sample(src_data_wrap)
            # prepare input for prepare_discriminator_data
            fake_samples = fake_samples.cpu()
            fake_lengths = fake_lengths.cpu()
            ans = ans.permute(1, 0).cpu()

            # shuffle data
            dis_inp, dis_target, dis_len, dis_pa, dis_an = helpers.prepare_discriminator_data(
                real_samples, real_lengths, fake_samples, fake_lengths,
                passage, ans, tgt_special)
            inp, target = dis_inp, dis_target
            lengths, pa = dis_len, dis_pa
            an = dis_an

            if CUDA:
                inp = inp.to(device)
                target = target.to(device).type(torch.float)
                lengths = lengths.to(device)
                pa = pa.to(device)
                an = an.to(device)
                pa = (pa, an)

            # inp = (inp, lengths)
            out = discriminator.batchClassify(inp, pa)
            loss_fn = nn.BCELoss()  # todo: should .cuda??
            loss = loss_fn(out, target)
            total_loss += loss.item()
            num_samples += tgt_data.size(0) * 2
            total_acc += torch.sum((out > 0.5) == (target > 0.5)).item()

        total_acc = total_acc * 1.0 / float(num_samples)
        print('loss = %.4f' % (total_loss / (num_samples)), end=' ')
        print('val_acc = %.4f\n' % (total_acc))
        discriminator.train()
        return total_acc
Code Example #4
File: main.py  Project: Hongzhi-Chen/GAN_Learn
def train_discriminator(discriminator, dis_opt, real_data_samples, generator, oracle, d_steps, epochs):
    # Train the discriminator on real data (positive) and data produced by the generator (negative).
    # Samples are drawn d_steps times, and the discriminator is trained for epochs epochs on each draw.

    # Generate a small validation set
    pos_val = oracle.sample(100)
    neg_val = generator.sample(100)
    val_inp, val_target = helpers.prepare_discriminator_data(pos_val, neg_val, gpu=CUDA)

    for d_step in range(d_steps):
        s = helpers.batchwise_sample(generator, POS_NEG_SAMPLES, BATCH_SIZE)
        dis_inp, dis_target = helpers.prepare_discriminator_data(real_data_samples, s, gpu=CUDA)
        for epoch in range(epochs):
            print('d_step %d epoch %d: ' % (d_step + 1, epoch + 1), end='')
            sys.stdout.flush()
            total_loss = 0
            total_acc = 0

            for i in range(0, 2 * POS_NEG_SAMPLES, BATCH_SIZE):
                inp, target = dis_inp[i:i + BATCH_SIZE], dis_target[i:i + BATCH_SIZE]
                dis_opt.zero_grad()
                out = discriminator.batchClassify(inp)
                loss_fn = nn.BCELoss()
                loss = loss_fn(out, target)
                loss.backward()
                dis_opt.step()

                total_loss += loss.item()
                total_acc += torch.sum((out > 0.5) == (target > 0.5)).item()

                if (i / BATCH_SIZE) % ceil(ceil(2 * POS_NEG_SAMPLES / float(BATCH_SIZE)) / 10) == 0:
                    print('.', end='')
                    sys.stdout.flush()

            # average loss per batch, accuracy per sample
            total_loss /= ceil(2 * POS_NEG_SAMPLES / float(BATCH_SIZE))
            total_acc /= float(2 * POS_NEG_SAMPLES)

            val_pred = discriminator.batchClassify(val_inp)
            print(' average_loss = %.4f, train_acc = %.4f, val_acc = %.4f' % (
                total_loss, total_acc, torch.sum((val_pred > 0.5) == (val_target > 0.5)).item() / 200.))
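
None of the projects' Discriminator classes appear on this page. A minimal batchClassify consistent with how it is called in Examples #1 and #4 (a batch of token-id sequences in, one sigmoid probability per sequence out) could be an embedding, a GRU, and a linear head; all layer sizes below are illustrative:

import torch
import torch.nn as nn

class Discriminator(nn.Module):
    def __init__(self, vocab_size, embed_dim=64, hidden_dim=64):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.gru = nn.GRU(embed_dim, hidden_dim, batch_first=True)
        self.head = nn.Linear(hidden_dim, 1)

    def batchClassify(self, inp):
        # inp: (batch, seq_len) token ids -> (batch,) probability that each sequence is real
        emb = self.embed(inp)
        _, h = self.gru(emb)  # h: (1, batch, hidden_dim), last hidden state
        return torch.sigmoid(self.head(h.squeeze(0))).view(-1)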
Code Example #5
File: SeqGAN.py  Project: I-am-Bot/nlp-robustness
def train_discriminator(discriminator, dis_opt, train_tar_batch, train_asr_batch, bsz=10, d_steps=5, epochs=5):

    # no pre-built validation set here; the final training batch is reused as a rough validation check below
    device = discriminator.device
    POS_NEG_SAMPLES = len(train_tar_batch)
    BATCH_SIZE = bsz
    for d_step in range(d_steps):
        #s = helpers.batchwise_sample(generator, POS_NEG_SAMPLES, BATCH_SIZE)
        #dis_inp, dis_target = helpers.prepare_discriminator_data(real_data_samples, s, gpu=CUDA)
        random_batch = [np.random.randint(POS_NEG_SAMPLES) for _ in range(1000)]
        for epoch in range(epochs):
            print('d-step %d epoch %d : ' % (d_step + 1, epoch + 1), end='')
            sys.stdout.flush()
            total_loss = 0
            total_acc = 0

            for step, i in enumerate(random_batch):
                inp_gen = train_tar_batch[i].view(-1, bsz).to(device)
                inp_asr = train_asr_batch[i].transpose(0, 1)[1:].to(device)
                inp = torch.cat((inp_gen, inp_asr), 1)
                # targets: 0 for the generated half of the batch, 1 for the ASR half
                target = np.concatenate((np.zeros(bsz), np.ones(bsz)))
                target = torch.from_numpy(target).float().to(device)

                dis_opt.zero_grad()
                out = discriminator.batchClassify(inp, 2 * bsz)
                
                loss_fn = nn.BCELoss()
                loss = loss_fn(out, target)
                loss.backward()
                dis_opt.step()

                total_loss += loss.item()
                total_acc += torch.sum((out > 0.5) == (target > 0.5)).item()

                if step % np.ceil(len(random_batch) / 10.) == 0:  # roughly every 10% of an epoch
                    print('.', end='')
                    sys.stdout.flush()

            total_loss /= len(random_batch)
            total_acc /= len(random_batch) * 2 * bsz

            #val_pred = discriminator.batchClassify(val_inp)
            print(' average_loss = %.4f, train_acc = %.4f, val_acc = %.4f' % (
                total_loss, total_acc, torch.sum((out > 0.5) == (target > 0.5)).item() / (2 * bsz)))
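
Every example computes training accuracy with the same idiom, torch.sum((out > 0.5) == (target > 0.5)): thresholding both the prediction and the 0/1 target at 0.5 turns the elementwise comparison into a count of correct classifications. A quick standalone check:

import torch

out = torch.tensor([0.9, 0.2, 0.7, 0.4])  # discriminator probabilities
target = torch.tensor([1., 0., 0., 0.])   # 1 = real, 0 = generated
correct = torch.sum((out > 0.5) == (target > 0.5)).item()
print(correct / len(out))  # 0.75: three of the four thresholded predictions match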
Code Example #6
def train_discriminator(discriminator, dis_opt, real_data_samples, generator,
                        oracle, d_steps, epochs, args):
    """
    Training the discriminator on real_data_samples (positive) and generated samples from generator (negative).
    Samples are drawn d_steps times, and the discriminator is trained for epochs epochs.
    """

    # generating a small validation set before training (using oracle and generator)
    pos_val = oracle.sample(100)
    neg_val = generator.sample(100)
    val_inp, val_target = helpers.prepare_discriminator_data(pos_val,
                                                             neg_val,
                                                             gpu=args.cuda)
    val_buffer = torch.zeros(200 * args.max_seq_len, args.vocab_size)
    if args.cuda:
        val_buffer = val_buffer.cuda()
    val_inp_oh = helpers.get_oh(val_inp, val_buffer)

    inp_buf = torch.zeros(args.d_bsz * args.max_seq_len, args.vocab_size)
    if args.cuda:
        inp_buf = inp_buf.cuda()
    num_data = len(real_data_samples)
    for d_step in range(d_steps):
        s = helpers.batchwise_sample(generator, args.num_data)
        dis_inp, dis_target = helpers.prepare_discriminator_data(
            real_data_samples, s, gpu=args.cuda)
        for epoch in range(epochs):
            print('d-step %d epoch %d : ' % (d_step + 1, epoch + 1), end='')
            sys.stdout.flush()
            total_loss = 0
            total_acc = 0

            for i in range(0, 2 * num_data, args.d_bsz):
                if i + args.d_bsz > 2 * num_data:
                    break
                inp, target = dis_inp[i:i + args.d_bsz], dis_target[i:i + args.d_bsz]
                inp_oh = helpers.get_oh(inp, inp_buf)
                dis_opt.zero_grad()
                out = discriminator.batchClassify(inp_oh)
                loss_fn = nn.BCELoss()
                loss = loss_fn(out, target)
                loss.backward()
                dis_opt.step()

                total_loss += loss.item()
                total_acc += torch.sum((out > 0.5) == (target > 0.5)).item()

                if (i / args.d_bsz) % ceil(
                        ceil(2 * num_data / float(args.d_bsz)) /
                        10.) == 0:  # roughly every 10% of an epoch
                    print('.', end='')
                    sys.stdout.flush()

            total_loss /= ceil(2 * num_data / float(args.d_bsz))
            total_acc /= float(2 * num_data)

            val_pred = discriminator.batchClassify(val_inp_oh)
            print(' average_loss = %.4f, train_acc = %.4f, val_acc = %.4f' %
                  (total_loss, total_acc,
                   torch.sum((val_pred > 0.5) == (val_target > 0.5)).item() / 200.))
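
helpers.get_oh is likewise not shown. From the buffer shapes above, (batch * max_seq_len, vocab_size), it presumably expands token ids into one-hot rows, reusing the preallocated buffer to avoid repeated allocation. A sketch under that assumption (the real helper's return shape may differ):

def get_oh(inp, buffer):
    # inp: (batch, seq_len) token ids; buffer: (batch * seq_len, vocab_size), reused across calls
    buffer.zero_()
    buffer.scatter_(1, inp.reshape(-1, 1), 1.0)       # set a 1 at each token's vocab index
    return buffer.view(inp.size(0), inp.size(1), -1)  # (batch, seq_len, vocab_size)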
Code Example #7
def train_discriminator(discriminator, dis_opt, train_iter, generator, out_acc, epochs, ADV_batches = None):
    """
    Training the discriminator on real_data_samples (positive) and generated samples from generator (negative).
    Samples are drawn d_steps times, and the discriminator is trained for epochs epochs.
    """
    def eval(val_iter, discriminator, generator):
        # validation
        discriminator.eval()
        print('validation :', end=' ')
        total_acc = 0
        num_samples = 0
        total_loss = 0
        for i, data in enumerate(val_iter):
            tgt_data = data.target[0].permute(1, 0)  # batch_size X length
            src_data_wrap = data.source
            ans = data.answer[0]

            if CUDA:
                scr_data = data.source[0].to(device)
                scr_lengths = data.source[1].to(device)
                ans = ans.to(device)
                src_data_wrap = (scr_data, scr_lengths, ans)

            real_samples = tgt_data
            real_lengths = data.target[1]
            passage = src_data_wrap[0].permute(1, 0)

            with torch.no_grad():
                fake_samples, fake_lengths = generator.sample(src_data_wrap)
            # prepare input for prepare_discriminator_data
            fake_samples = fake_samples.cpu()
            fake_lengths = fake_lengths.cpu()
            ans = ans.permute(1, 0).cpu()

            # shuffle data
            dis_inp, dis_target, dis_len, dis_pa, dis_an = helpers.prepare_discriminator_data(real_samples, real_lengths,
                                                                                     fake_samples, fake_lengths, passage, ans, tgt_special)
            inp, target = dis_inp, dis_target
            lengths, pa = dis_len, dis_pa
            an = dis_an

            if CUDA:
                inp = inp.to(device)
                target = target.to(device).type(torch.float)  # BCELoss expects float targets
                lengths = lengths.to(device)
                pa = pa.to(device)
                an = an.to(device)
                pa = (pa, an)

            # inp = (inp, lengths)
            out = discriminator.batchClassify(inp, pa)
            loss_fn = nn.BCELoss()   # todo: should .cuda??
            loss = loss_fn(out, target)
            total_loss += loss.item()
            num_samples += tgt_data.size(0) * 2
            total_acc += torch.sum((out > 0.5) == (target > 0.5)).item()

        total_acc = total_acc * 1.0 / float(num_samples)
        print('loss = %.4f' % (total_loss / (num_samples)), end=' ')
        print('val_acc = %.4f\n' % (total_acc))
        discriminator.train()
        return total_acc

    d_step = 0
    while True:
        d_step += 1
        passages = []
        anses = []
        real_samples = []
        fake_samples = []
        real_lengths = []
        fake_lengths = []

        for i, data in enumerate(train_iter):
            if ADV_batches is not None:
                if i+1 == ADV_batches:
                    break

            tgt_data = data.target[0].permute(1, 0)  # batch_size X length
            src_data_wrap = data.source
            ans = data.answer[0]

            if CUDA:
                scr_data = data.source[0].to(device)
                scr_lengths = data.source[1].to(device)
                ans = ans.to(device)
                src_data_wrap = (scr_data, scr_lengths, ans)

            real_sample = tgt_data
            real_length = data.target[1]
            with torch.no_grad():
                fake_sample, fake_length = generator.sample(src_data_wrap)
            fake_sample = fake_sample.cpu()
            fake_length = fake_length.cpu()
            ans = ans.permute(1, 0).cpu()

            # keep lengths as the same in order to pack
            passage = src_data_wrap[0].permute(1, 0)
            pad_len = max_sent_len - passage.size(1)
            m = nn.ConstantPad1d((0, pad_len), src_pad)
            passage = m(passage)
            ans = m(ans)

            # keep lengths as the same in order to pack
            pad_len = max_sent_len - real_sample.size(1)
            m = nn.ConstantPad1d((0, pad_len), tgt_pad)
            real_sample = m(real_sample)

            real_samples.append(real_sample)
            real_lengths.append(real_length)
            fake_samples.append(fake_sample)
            fake_lengths.append(fake_length)
            passages.append(passage)
            anses.append(ans)

        real_samples = torch.cat(real_samples, 0).type(torch.LongTensor)
        real_lengths = torch.cat(real_lengths, 0).type(torch.LongTensor)
        fake_samples = torch.cat(fake_samples, 0).type(torch.LongTensor)
        fake_lengths = torch.cat(fake_lengths, 0).type(torch.LongTensor)
        passages = torch.cat(passages, 0).type(torch.LongTensor)
        anses = torch.cat(anses, 0).type(torch.LongTensor)
        dis_inp, dis_target, dis_len, dis_pa, dis_an = helpers.prepare_discriminator_data(real_samples, real_lengths,
                                                                                   fake_samples, fake_lengths, passages, anses, tgt_special)

        # iterator
        # for i, dis_data in enumerate(dis_iter):
        #     dis_inp = dis_data.question[0]
        #     dis_target = dis_data.target
        #     dis_pa = dis_data.passage[0]
        #     dis_an = dis_data.answer[0]

        # collect discriminator data
        # disc_writer = open("disc.json", "w")
        # question0 = rev.reverse(dis_inp.permute(1,0))
        # answer0 = ans_rev.reverse(dis_an.permute(1, 0))
        # passage0 = src_rev.reverse(dis_pa.permute(1, 0))
        # for i in range(len(dis_inp)):
        #     disc_writer.write("{\"question\": \"" + question0[i][6:] + "\", ")
        #     disc_writer.write("\"answer\": \"" + answer0[i] + "\", ")
        #     disc_writer.write("\"passage\": \"" + passage0[i] + "\", ")
        #     disc_writer.write("\"target\": \"" + str(int(dis_target[i].item())) + "\"}" + "\n")

        # # showcases
        # print(' sample showcase:')
        # show = rev.reverse(dis_inp[:Show_num].permute(1, 0))
        # for i in range(Show_num):
        #     print(show[i])

        for epoch in range(epochs):
            discriminator.train()
            print('\n d-step %d epoch %d : ' % (d_step, epoch + 1), end='')
            total_loss = 0
            total_acc = 0
            true_acc = 0
            num_samples = dis_inp.size(0)

            for i in range(0, num_samples, batch_size):
                inp, target = dis_inp[i: i + batch_size], dis_target[i: i + batch_size]
                # lengths = dis_len[i: i + batch_size]
                pa = dis_pa[i: i + batch_size]
                an = dis_an[i: i + batch_size]
                if CUDA:
                    inp = inp.to(device)
                    target = target.to(device)
                    # lengths = lengths.to(device)
                    an = an.to(device)
                    pa = pa.to(device)
                    pa = (pa, an)

                # inp = (inp, lengths)
                dis_opt.zero_grad()
                out = discriminator.batchClassify(inp, pa) # hidden = none over here
                loss_fn = nn.BCELoss()
                loss = loss_fn(out, target)
                loss.backward()
                dis_opt.step()

                total_loss += loss.item()
                total_acc += torch.sum((out>0.5)==(target>0.5)).item()
                true = (target > 0.5).type(torch.FloatTensor)
                out = out.cpu()
                out_true = out * true
                true_acc += torch.sum(out_true > 0.5).item()

            total_acc = total_acc * 1.0 / float(num_samples)
            true_acc = true_acc * 1.0 / float(num_samples/2)
            print('loss = %.4f, train_acc = %.4f' % (total_loss/(num_samples), total_acc), end=' ')
            print('true_acc = %.4f' % true_acc)
            val_acc = eval(val_iter, discriminator, generator)
            # dis_opt.updateLearningRate(val_acc)


            # todo: when to stop discriminator MLE training (the settings below are ad hoc)
            flag = 0
            if ADV_batches is None:
                if val_acc > out_acc:
                    flag = 1
                    break

                elif d_step+1 == 8 and epoch+1 == 5:
                    flag = 1
                    break

            else:
                if d_step+1 == 4 and epoch+1 == 5:
                    flag = 1
                    break

        if flag == 1:
            break
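
Example #7 pads every batch to a common max_sent_len with nn.ConstantPad1d before concatenating across batches. The module pads the last dimension, so a (batch, seq_len) tensor of token ids works directly; the pad value and target length below are illustrative:

import torch
import torch.nn as nn

src_pad = 0                                  # hypothetical pad token id
batch = torch.randint(1, 10, (2, 5))         # (batch=2, seq_len=5)
pad_len = 8 - batch.size(1)                  # pad up to an assumed max_sent_len of 8
m = nn.ConstantPad1d((0, pad_len), src_pad)  # (left, right) amounts on the last dim
print(m(batch).shape)                        # torch.Size([2, 8])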
Code Example #8
File: main.py  Project: shiyinw/AcapellaTune
def train_discriminator(discriminator, dis_opt, pos_samples, neg_samples,
                        generator, oracle, d_steps, epochs, dataloader):
    """
    Training the discriminator on real_data_samples (positive) and generated samples from generator (negative).
    Samples are drawn d_steps times, and the discriminator is trained for epochs epochs.
    """

    # generate a small validation set before training (from the dataloader)
    pos_val, neg_val = dataloader.sample_valid(100)
    val_inp, val_target = helpers.prepare_discriminator_data(pos_val,
                                                             neg_val,
                                                             gpu=CUDA)

    for d_step in range(d_steps):
        pos_samples = dataloader.sample_pos(POS_NEG_SAMPLES)

        # produce negatives by running the generator over real samples
        hidden = generator.init_hidden(MAX_SEQ_LEN)
        neg_samples, _ = generator(pos_samples, hidden)
        # redraw positives for the discriminator batch
        pos_samples = dataloader.sample_pos(POS_NEG_SAMPLES)
        dis_inp, dis_target = helpers.prepare_discriminator_data(pos_samples,
                                                                 neg_samples,
                                                                 gpu=CUDA)

        for epoch in range(epochs):
            msg = 'd-step %d/%d epoch %d/%d : ' % (d_step + 1, d_steps,
                                                   epoch + 1, epochs)
            log.append(msg)
            print(msg)

            sys.stdout.flush()
            total_loss = 0
            total_acc = 0

            for i in range(0, 2 * POS_NEG_SAMPLES, BATCH_SIZE):
                inp, target = dis_inp[i:i + BATCH_SIZE], dis_target[i:i + BATCH_SIZE]

                dis_opt.zero_grad()
                out = discriminator.batchClassify(inp)
                loss_fn = nn.BCELoss()
                loss = loss_fn(out, target)
                loss.backward()
                dis_opt.step()

                total_loss += loss.data.item()
                total_acc += torch.sum(
                    (out > 0.5) == (target > 0.5)).data.item()

                if (i / BATCH_SIZE) % ceil(
                        ceil(2 * POS_NEG_SAMPLES / float(BATCH_SIZE)) /
                        10.) == 0:  # roughly every 10% of an epoch
                    print('.', end='')
                    sys.stdout.flush()

            total_loss /= ceil(2 * POS_NEG_SAMPLES / float(BATCH_SIZE))
            total_acc /= float(2 * POS_NEG_SAMPLES)

            val_pred = discriminator.batchClassify(val_inp)
            msg = ' average_loss = %.4f, train_acc = %.4f, val_acc = %.4f' % (
                total_loss, total_acc,
                torch.sum(
                    (val_pred > 0.5) == (val_target > 0.5)).data.item() / 200.)

            log.append(msg)
            print(msg)
Code Example #9
def train_discriminator(discriminator,
                        dis_opt,
                        train_iter,
                        generator,
                        out_acc,
                        epochs,
                        ADV_batches=None):
    """
    Training the discriminator on real_data_samples (positive) and generated samples from generator (negative).
    Samples are drawn d_steps times, and the discriminator is trained for epochs epochs.
    """
    def eval(val_iter, discriminator, generator):
        # validation
        discriminator.eval()
        print('validation :', end=' ')
        total_acc = 0
        num_samples = 0
        total_loss = 0
        for i, data in enumerate(val_iter):
            tgt_data = data.target[0].permute(1, 0)  # batch_size X length
            src_data_wrap = data.source
            ans = data.answer[0]

            if CUDA:
                scr_data = data.source[0].to(device)
                scr_lengths = data.source[1].to(device)
                ans = ans.to(device)
                src_data_wrap = (scr_data, scr_lengths, ans)

            real_samples = tgt_data
            real_lengths = data.target[1]
            passage = src_data_wrap[0].permute(1, 0)

            with torch.no_grad():
                fake_samples, fake_lengths = generator.sample(src_data_wrap)
            # prepare input for prepare_discriminator_data
            fake_samples = fake_samples.cpu()
            fake_lengths = fake_lengths.cpu()
            ans = ans.permute(1, 0).cpu()

            # shuffle data
            dis_inp, dis_target, dis_len, dis_pa, dis_an = helpers.prepare_discriminator_data(
                real_samples, real_lengths, fake_samples, fake_lengths,
                passage, ans, tgt_special)
            inp, target = dis_inp, dis_target
            lengths, pa = dis_len, dis_pa
            an = dis_an

            if CUDA:
                inp = inp.to(device)
                target = target.to(device).type(torch.float)
                lengths = lengths.to(device)
                pa = pa.to(device)
                an = an.to(device)
                pa = (pa, an)

            # inp = (inp, lengths)
            out = discriminator.batchClassify(inp, pa)
            loss_fn = nn.BCELoss()  # todo: should .cuda??
            loss = loss_fn(out, target)
            total_loss += loss.item()
            num_samples += tgt_data.size(0) * 2
            total_acc += torch.sum((out > 0.5) == (target > 0.5)).item()

        total_acc = total_acc * 1.0 / float(num_samples)
        print('loss = %.4f' % (total_loss / (num_samples)), end=' ')
        print('val_acc = %.4f\n' % (total_acc))
        discriminator.train()
        return total_acc

    for epoch in range(epochs):
        discriminator.train()
        print('\n epoch %d : ' % (epoch + 1), end='')
        total_loss = 0
        total_acc = 0
        true_acc = 0
        num_samples = 0

        for i, dis_data in enumerate(disc_train_iter):  # disc_train_iter: module-level iterator of discriminator batches
            inp, inp_length = dis_data.question
            target = dis_data.target
            pa, pa_length = dis_data.passage
            ans, ans_length = dis_data.answer
            num_samples += inp.size(1)

            if CUDA:
                pa = pa.transpose(0, 1)
                inp = inp.transpose(0, 1)
                ans = ans.transpose(0, 1)

                inp = inp.to(device)
                target = target.to(device).type(torch.float)
                # lengths = lengths.to(device)
                ans = ans.to(device)
                pa = pa.to(device)
                pa = (pa, ans)

            # inp = (inp, lengths)
            dis_opt.zero_grad()
            out = discriminator.batchClassify(inp,
                                              pa)  # hidden = none over here
            loss_fn = nn.BCELoss()
            loss = loss_fn(out, target)
            loss.backward()
            dis_opt.step()

            total_loss += loss.item()
            total_acc += torch.sum((out > 0.5) == (target > 0.5)).item()
            true = (target > 0.5).type(torch.FloatTensor)
            out = out.cpu()
            out_true = out * true
            true_acc += torch.sum(out_true > 0.5).item()

        total_acc = total_acc * 1.0 / float(num_samples)
        true_acc = true_acc * 1.0 / float(num_samples / 2)
        print('loss = %.4f, train_acc = %.4f' % (total_loss /
                                                 (num_samples), total_acc),
              end=' ')
        print('true_acc = %.4f' % true_acc)
        val_acc = eval(val_iter, discriminator, generator)
        # dis_opt.updateLearningRate(val_acc)

        # todo: when to stop discriminator MLE training (the settings below are ad hoc)
        flag = 0
        if ADV_batches is None:
            if val_acc > out_acc:
                flag = 1
                break