Example #1
def train_discriminator(discriminator, dis_opt, real_data_samples, generator,
                        oracle, d_steps, epochs):
    """
    Training the discriminator on real_data_samples (positive) and generated samples from generator (negative).
    Samples are drawn d_steps times, and the discriminator is trained for epochs epochs.
    """
    # generate a small validation set before training (using oracle and generator)
    pos_val = oracle.sample(100)
    neg_val = generator.sample(100)
    val_inp, val_target = helpers.prepare_discriminator_data(pos_val,
                                                             neg_val,
                                                             gpu=CUDA)

    # training
    for d_step in range(d_steps):
        s = helpers.batchwise_sample(generator, POS_NEG_SAMPLES, BATCH_SIZE)
        dis_inp, dis_target = helpers.prepare_discriminator_data(
            real_data_samples, s, gpu=CUDA)
        for epoch in range(epochs):
            print('d-step %d epoch %d : ' % (d_step + 1, epoch + 1), end='')
            sys.stdout.flush()
            total_loss = 0
            total_acc = 0

            for i in range(0, 2 * POS_NEG_SAMPLES, BATCH_SIZE):
                inp, target = dis_inp[i:i + BATCH_SIZE], dis_target[i:i + BATCH_SIZE]
                dis_opt.zero_grad()
                out = discriminator.batchClassify(inp)
                loss_fn = nn.BCELoss()
                loss = loss_fn(out, target)
                loss.backward()
                dis_opt.step()

                total_loss += loss.item()
                total_acc += torch.sum(
                    (out > 0.5) == (target > 0.5)).item()

                if (i / BATCH_SIZE) % ceil(
                        ceil(2 * POS_NEG_SAMPLES / float(BATCH_SIZE)) /
                        10.) == 0:  # roughly every 10% of an epoch
                    print('.', end='')
                    sys.stdout.flush()

            total_loss /= ceil(2 * POS_NEG_SAMPLES / float(BATCH_SIZE))
            total_acc /= float(2 * POS_NEG_SAMPLES)

            val_pred = discriminator.batchClassify(val_inp)
            print(' average_loss = %.4f, train_acc = %.4f, val_acc = %.4f' %
                  (total_loss, total_acc,
                   torch.sum(
                       (val_pred > 0.5) == (val_target > 0.5)).item() /
                   200.))
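
Every example on this page leans on helpers.prepare_discriminator_data without showing it. A minimal sketch of what it plausibly does, assuming it stacks the positive and negative sequences into one tensor, labels them 1/0, and shuffles them jointly; the body below is a reconstruction, not this page's code:

import torch

def prepare_discriminator_data(pos_samples, neg_samples, gpu=False):
    """Plausible reconstruction (assumption): concatenate positive and
    negative token sequences, label them 1/0, and shuffle jointly so
    batches mix real and generated data."""
    inp = torch.cat((pos_samples, neg_samples), 0).long()
    target = torch.ones(inp.size(0))
    target[pos_samples.size(0):] = 0  # generator samples are the negatives

    perm = torch.randperm(inp.size(0))  # shuffle pos/neg together
    inp, target = inp[perm], target[perm]

    if gpu:
        inp, target = inp.cuda(), target.cuda()
    return inp, target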
Example #2
def train_discriminator(discriminator, dis_opt, real_data_samples, generator, oracle, d_steps, epochs):
    # Train the discriminator on real_data_samples (positive) and generator samples (negative).
    # Samples are drawn d_steps times; the discriminator is trained for `epochs` epochs on each draw.

    # generate a small validation set
    pos_val = oracle.sample(100)
    neg_val = generator.sample(100)
    val_inp, val_target = helpers.prepare_discriminator_data(pos_val, neg_val, gpu=CUDA)

    for d_step in range(d_steps):
        s = helpers.batchwise_sample(generator, POS_NEG_SAMPLES, BATCH_SIZE)
        dis_inp, dis_target = helpers.prepare_discriminator_data(real_data_samples, s, gpu=CUDA)
        for epoch in range(epochs):
            print('d_step %d epoch %d:' % (d_step + 1, epoch + 1), end='')
            sys.stdout.flush()
            total_loss = 0
            total_acc = 0

            for i in range(0, 2 * POS_NEG_SAMPLES, BATCH_SIZE):
                inp, target = dis_inp[i:i + BATCH_SIZE], dis_target[i:i + BATCH_SIZE]
                dis_opt.zero_grad()
                out = discriminator.batchClassify(inp)
                loss_fn = nn.BCELoss()
                loss = loss_fn(out, target)
                loss.backward()
                dis_opt.step()

                total_loss += loss.item()
                total_acc += torch.sum((out > 0.5) == (target > 0.5)).item()

                if (i / BATCH_SIZE) % ceil(
                        ceil(2 * POS_NEG_SAMPLES / float(BATCH_SIZE)) / 10.) == 0:  # roughly every 10% of an epoch
                    print('.', end='')
                    sys.stdout.flush()

            total_loss /= ceil(2 * POS_NEG_SAMPLES / float(BATCH_SIZE))
            total_acc /= float(2 * POS_NEG_SAMPLES)

            val_pred = discriminator.batchClassify(val_inp)
            print(' average_loss = %.4f, train_acc = %.4f, val_acc = %.4f' % (
                total_loss, total_acc, torch.sum((val_pred > 0.5) == (val_target > 0.5)).item() / 200.))
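
helpers.batchwise_sample is likewise used everywhere without being shown. A minimal sketch, assuming it simply draws generator samples in fixed-size chunks and concatenates them; the default batch size is a guess to cover the two-argument call in Example #4:

import torch
from math import ceil

def batchwise_sample(gen, num_samples, batch_size=32):
    """Plausible reconstruction (assumption): call gen.sample(batch_size)
    repeatedly and concatenate, trimming to exactly num_samples rows."""
    samples = []
    for _ in range(int(ceil(num_samples / float(batch_size)))):
        samples.append(gen.sample(batch_size))
    return torch.cat(samples, 0)[:num_samples]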
Example #3
def train_discriminator(dis, dis_opt, real_data_samples, gen, d_steps, epochs):
    """
    Training the discriminator on real_data_samples (positive) and generated samples from generator (negative).
    Samples are drawn d_steps times, and the discriminator is trained for epochs epochs.
    """

    for d_step in range(d_steps):
        s = helpers.batchwise_sample(gen, POS_NEG_SAMPLES, BATCH_SIZE)
        dis_inp, dis_target = helpers.prepare_discriminator_data(
            real_data_samples, s, gpu=CUDA)
        for epoch in range(epochs):
            print('d-step %d epoch %d : ' % (d_step + 1, epoch + 1), end='')
            sys.stdout.flush()
            total_loss = 0
            total_acc = 0

            for i in range(0, 2 * POS_NEG_SAMPLES, BATCH_SIZE):
                inp, target = dis_inp[i:i + BATCH_SIZE], dis_target[i:i + BATCH_SIZE]
                dis_opt.zero_grad()
                out = dis.batchClassify(inp)
                loss_fn = nn.BCELoss()
                loss = loss_fn(out, target)
                loss.backward()
                dis_opt.step()

                total_loss += loss.item()
                total_acc += torch.sum((out > 0.5) == (target > 0.5)).item()

                if (i / BATCH_SIZE) % ceil(
                        ceil(2 * POS_NEG_SAMPLES / float(BATCH_SIZE)) /
                        10.) == 0:  # roughly every 10% of an epoch
                    print('.', end='')
                    sys.stdout.flush()

            total_loss /= ceil(2 * POS_NEG_SAMPLES / float(BATCH_SIZE))
            total_acc /= float(2 * POS_NEG_SAMPLES)
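
Examples #1-#3 also assume a discriminator exposing batchClassify, which maps a batch of token sequences to probabilities in [0, 1] (which is why nn.BCELoss and the out > 0.5 accuracy check work). A toy stand-in satisfying that contract; the GRU architecture below is an assumption, not this page's model:

import torch
import torch.nn as nn

class ToyDiscriminator(nn.Module):
    """Toy stand-in (assumption): embed token indices, run a GRU, and map
    the final hidden state to a probability, matching the batchClassify
    contract the examples above depend on."""

    def __init__(self, emb_dim, hid_dim, vocab_size, max_seq_len):
        super().__init__()
        self.max_seq_len = max_seq_len  # kept for signature parity, unused here
        self.embed = nn.Embedding(vocab_size, emb_dim)
        self.gru = nn.GRU(emb_dim, hid_dim, batch_first=True)
        self.head = nn.Linear(hid_dim, 1)

    def batchClassify(self, inp):
        # inp: (batch_size, seq_len) LongTensor of token indices
        emb = self.embed(inp)    # (batch, seq_len, emb_dim)
        _, h = self.gru(emb)     # h: (1, batch, hid_dim)
        return torch.sigmoid(self.head(h.squeeze(0))).view(-1)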
Example #4
def train_discriminator(discriminator, dis_opt, real_data_samples, generator,
                        oracle, d_steps, epochs, args):
    """
    Training the discriminator on real_data_samples (positive) and generated samples from generator (negative).
    Samples are drawn d_steps times, and the discriminator is trained for epochs epochs.
    """

    # generating a small validation set before training (using oracle and generator)
    pos_val = oracle.sample(100)
    neg_val = generator.sample(100)
    val_inp, val_target = helpers.prepare_discriminator_data(pos_val,
                                                             neg_val,
                                                             gpu=args.cuda)
    val_buffer = torch.zeros(200 * args.max_seq_len, args.vocab_size)
    if args.cuda:
        val_buffer = val_buffer.cuda()
    val_inp_oh = helpers.get_oh(val_inp, val_buffer)

    inp_buf = torch.zeros(args.d_bsz * args.max_seq_len, args.vocab_size)
    if args.cuda:
        inp_buf = inp_buf.cuda()
    num_data = len(real_data_samples)
    for d_step in range(d_steps):
        s = helpers.batchwise_sample(generator, args.num_data)
        dis_inp, dis_target = helpers.prepare_discriminator_data(
            real_data_samples, s, gpu=args.cuda)
        for epoch in range(epochs):
            print('d-step %d epoch %d : ' % (d_step + 1, epoch + 1), end='')
            sys.stdout.flush()
            total_loss = 0
            total_acc = 0

            for i in range(0, 2 * num_data, args.d_bsz):
                if i + args.d_bsz > 2 * num_data:
                    break
                inp, target = dis_inp[i:i + args.d_bsz], dis_target[i:i + args.d_bsz]
                inp_oh = helpers.get_oh(inp, inp_buf)
                dis_opt.zero_grad()
                out = discriminator.batchClassify(inp_oh)
                loss_fn = nn.BCELoss()
                loss = loss_fn(out, target)
                loss.backward()
                dis_opt.step()

                total_loss += loss.item()
                total_acc += torch.sum(
                    (out > 0.5) == (target > 0.5)).item()

                if (i / args.d_bsz) % ceil(
                        ceil(2 * num_data / float(args.d_bsz)) /
                        10.) == 0:  # roughly every 10% of an epoch
                    print('.', end='')
                    sys.stdout.flush()

            total_loss /= ceil(2 * num_data / float(args.d_bsz))
            total_acc /= float(2 * num_data)

            val_pred = discriminator.batchClassify(val_inp_oh)
            print(' average_loss = %.4f, train_acc = %.4f, val_acc = %.4f' %
                  (total_loss, total_acc,
                   torch.sum((val_pred > 0.5) ==
                             (val_target > 0.5)).item() / 200.))
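
Example #4 trains on one-hot inputs via helpers.get_oh and the preallocated buffers sized above. The helper is not shown; a sketch of the likely scatter-based conversion, reusing the buffer so it is not reallocated every batch (the output is a view of that shared buffer, which is fine here because each batch is consumed immediately):

import torch

def get_oh(inp, buf):
    """Plausible reconstruction (assumption): scatter token indices into a
    preallocated (batch * seq_len, vocab_size) zero buffer and view the
    result as (batch, seq_len, vocab_size) one-hot vectors."""
    batch_size, seq_len = inp.size()
    buf.zero_()  # the buffer is reused across batches, so clear it first
    buf.scatter_(1, inp.contiguous().view(-1, 1), 1)
    return buf.view(batch_size, seq_len, -1)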
Example #5
    args.max_seq_len,
    gpu=args.cuda,
)

dis = discriminator.Discriminator(args.d_emb_dim,
                                  args.d_hid_dim,
                                  args.vocab_size,
                                  args.max_seq_len,
                                  gpu=args.cuda)

if args.cuda:
    oracle = oracle.cuda()
    gen = gen.cuda()
    dis = dis.cuda()

oracle_samples = helpers.batchwise_sample(oracle, args.num_data)
if args.oracle_save is not None:
    torch.save(oracle.state_dict(), args.oracle_save)

logger = logger.Logger(args.log_dir)
# GENERATOR MLE TRAINING
gen_optimizer = optim.Adam(gen.parameters(), lr=1e-2)
if args.pre_g_load is not None:
    print("Load pretrained MLE gen")
    gen.load_state_dict(torch.load(args.pre_g_load))
else:
    print('Starting Generator MLE Training...')
    for epoch in range(args.mle_epochs):
        print('epoch %d : ' % (epoch + 1), end='')
        sys.stdout.flush()
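
The snippet cuts off inside the MLE pretraining loop. A hedged sketch of how one such epoch typically continues in SeqGAN-style code; helpers.prepare_generator_data, gen.batchNLLLoss, and args.batch_size are assumed names, not confirmed by the excerpt:

        # continuation sketch (assumption): teacher-forced NLL over the
        # oracle samples drawn above, in batches of args.batch_size
        total_loss = 0
        for i in range(0, args.num_data, args.batch_size):
            inp, target = helpers.prepare_generator_data(
                oracle_samples[i:i + args.batch_size], gpu=args.cuda)
            gen_optimizer.zero_grad()
            loss = gen.batchNLLLoss(inp, target)  # assumed per-batch NLL helper
            loss.backward()
            gen_optimizer.step()
            total_loss += loss.item()

        print(' average_train_NLL = %.4f' %
              (total_loss / ceil(args.num_data / float(args.batch_size))))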