def main():
    """Train a SeqGAN generator/discriminator pair against a fixed oracle.

    Pipeline: (1) sample "real" data from the oracle, (2) MLE-pretrain the
    generator, (3) pretrain the discriminator on real-vs-fake samples,
    (4) adversarial training where the generator is updated by policy
    gradient using Monte-Carlo rollout rewards from the discriminator.

    NOTE(review): this file defines `main` twice; this first definition is
    shadowed by the later one at import time — confirm which is intended.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--cuda', default=False, action='store_true',
                        help='Enable CUDA')
    args = parser.parse_args()
    # Use CUDA only when both requested on the CLI and actually available.
    # (Original used a redundant `True if ... else False` ternary.)
    use_cuda = args.cuda and torch.cuda.is_available()

    # Seed stdlib and numpy RNGs for reproducibility.
    random.seed(SEED)
    np.random.seed(SEED)

    netG = Generator(VOCAB_SIZE, G_EMB_SIZE, G_HIDDEN_SIZE, G_LR, use_cuda)
    netD = Discriminator(VOCAB_SIZE, D_EMB_SIZE, D_NUM_CLASSES,
                         D_FILTER_SIZES, D_NUM_FILTERS, DROPOUT, D_LR,
                         D_L2_REG, use_cuda)
    oracle = Oracle(VOCAB_SIZE, G_EMB_SIZE, G_HIDDEN_SIZE, use_cuda)

    # Generate synthetic "real" data from the oracle.
    print('Generating data...')
    generate_samples(oracle, BATCH_SIZE, GENERATED_NUM, REAL_FILE)

    # --- Pretrain generator (MLE on oracle samples) ---------------------
    gen_set = GeneratorDataset(REAL_FILE)
    genloader = DataLoader(dataset=gen_set, batch_size=BATCH_SIZE,
                           shuffle=True)
    print('\nPretraining generator...\n')
    for epoch in range(PRE_G_EPOCHS):
        loss = netG.pretrain(genloader)
        # BUG FIX: original printed 0-based `epoch` here but 1-based
        # `epoch + 1` for the val loss below, so the two log lines for the
        # same epoch disagreed; both are now 1-based.
        print('Epoch {} pretrain generator training loss: {}'.format(
            epoch + 1, loss))

        # Validate against the oracle on freshly generated samples.
        generate_samples(netG, BATCH_SIZE, GENERATED_NUM, EVAL_FILE)
        val_set = GeneratorDataset(EVAL_FILE)
        valloader = DataLoader(dataset=val_set, batch_size=BATCH_SIZE,
                               shuffle=True)
        loss = oracle.val(valloader)
        print('Epoch {} pretrain generator val loss: {}'.format(
            epoch + 1, loss))

    # --- Pretrain discriminator -----------------------------------------
    print('\nPretraining discriminator...\n')
    for epoch in range(D_STEPS):
        # Regenerate negatives each step so netD tracks the current netG.
        generate_samples(netG, BATCH_SIZE, GENERATED_NUM, FAKE_FILE)
        dis_set = DiscriminatorDataset(REAL_FILE, FAKE_FILE)
        disloader = DataLoader(dataset=dis_set, batch_size=BATCH_SIZE,
                               shuffle=True)
        for _ in range(K_STEPS):
            loss = netD.dtrain(disloader)
            print('Epoch {} pretrain discriminator training loss: {}'.format(
                epoch + 1, loss))

    # --- Adversarial training -------------------------------------------
    rollout = Rollout(netG, update_rate=ROLLOUT_UPDATE_RATE,
                      rollout_num=ROLLOUT_NUM)
    print('\n#####################################################')
    print('Adversarial training...\n')
    for epoch in range(TOTAL_EPOCHS):
        # Generator: policy-gradient steps with rollout rewards from netD.
        for _ in range(G_STEPS):
            netG.pgtrain(BATCH_SIZE, SEQUENCE_LEN, rollout, netD)

        # Discriminator: retrain on the updated generator's output.
        for d_step in range(D_STEPS):
            generate_samples(netG, BATCH_SIZE, GENERATED_NUM, FAKE_FILE)
            dis_set = DiscriminatorDataset(REAL_FILE, FAKE_FILE)
            disloader = DataLoader(dataset=dis_set, batch_size=BATCH_SIZE,
                                   shuffle=True)
            for k_step in range(K_STEPS):
                loss = netD.dtrain(disloader)
                print(
                    'D_step {}, K-step {} adversarial discriminator training loss: {}'
                    .format(d_step + 1, k_step + 1, loss))
        # Sync the rollout policy toward the updated generator.
        rollout.update_params()

        # Per-epoch validation against the oracle.
        generate_samples(netG, BATCH_SIZE, GENERATED_NUM, EVAL_FILE)
        val_set = GeneratorDataset(EVAL_FILE)
        valloader = DataLoader(dataset=val_set, batch_size=BATCH_SIZE,
                               shuffle=True)
        loss = oracle.val(valloader)
        print('Epoch {} adversarial generator val loss: {}'.format(
            epoch + 1, loss))
def main():
    """Train a generator against a discriminator using an evolutionary
    strategy (mutate-and-select over a population of generator clones)
    instead of policy gradients; a fixed oracle provides the val signal.

    NOTE(review): this file defines `main` twice; this later definition
    shadows the earlier rollout-based one at import time.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--cuda', default=False, action='store_true',
                        help='Enable CUDA')
    args = parser.parse_args()
    # Use CUDA only when both requested on the CLI and actually available.
    # (Original used a redundant `True if ... else False` ternary.)
    use_cuda = args.cuda and torch.cuda.is_available()

    netG = Generator(VOCAB_SIZE, G_EMB_SIZE, G_HIDDEN_SIZE, use_cuda)
    netD = Discriminator(VOCAB_SIZE, D_EMB_SIZE, D_NUM_CLASSES,
                         D_FILTER_SIZES, D_NUM_FILTERS, DROPOUT, use_cuda)
    oracle = Oracle(VOCAB_SIZE, G_EMB_SIZE, G_HIDDEN_SIZE, use_cuda)
    if use_cuda:
        netG, netD, oracle = netG.cuda(), netD.cuda(), oracle.cuda()
    netG.create_optim(G_LR)
    netD.create_optim(D_LR, D_L2_REG)

    # Generate synthetic "real" data from the oracle.
    print('Generating data...')
    generate_samples(oracle, BATCH_SIZE, GENERATED_NUM, REAL_FILE)

    # --- Pretrain generator (MLE on oracle samples) ---------------------
    gen_set = GeneratorDataset(REAL_FILE)
    genloader = DataLoader(dataset=gen_set, batch_size=BATCH_SIZE,
                           shuffle=True)
    print('\nPretraining generator...\n')
    for epoch in range(PRE_G_EPOCHS):
        loss = netG.pretrain(genloader)
        print('Epoch {} pretrain generator training loss: {}'.format(
            epoch + 1, loss))

        # Validate against the oracle on freshly generated samples.
        generate_samples(netG, BATCH_SIZE, GENERATED_NUM, EVAL_FILE)
        val_set = GeneratorDataset(EVAL_FILE)
        valloader = DataLoader(dataset=val_set, batch_size=BATCH_SIZE,
                               shuffle=True)
        loss = oracle.val(valloader)
        print('Epoch {} pretrain generator val loss: {}'.format(
            epoch + 1, loss))

    # --- Pretrain discriminator -----------------------------------------
    print('\nPretraining discriminator...\n')
    for epoch in range(PRE_D_EPOCHS):
        # Regenerate negatives each epoch so netD tracks the current netG.
        generate_samples(netG, BATCH_SIZE, GENERATED_NUM, FAKE_FILE)
        dis_set = DiscriminatorDataset(REAL_FILE, FAKE_FILE)
        disloader = DataLoader(dataset=dis_set, batch_size=BATCH_SIZE,
                               shuffle=True)
        for k_step in range(K_STEPS):
            loss = netD.dtrain(disloader)
            print(
                'Epoch {} K-step {} pretrain discriminator training loss: {}'.
                format(epoch + 1, k_step + 1, loss))

    # --- Adversarial training (evolutionary generator updates) ----------
    print('\nStarting adversarial training...')
    for epoch in range(TOTAL_EPOCHS):
        # Seed the population with independent copies of the current netG.
        nets = [copy.deepcopy(netG) for _ in range(POPULATION_SIZE)]
        population = [(net, evaluate(net, netD)) for net in nets]
        for g_step in range(G_STEPS):
            t_start = time.time()
            # Rank by fitness, best first.
            population.sort(key=lambda p: p[1], reverse=True)
            rewards = [p[1] for p in population[:PARENTS_COUNT]]
            reward_mean = np.mean(rewards)
            reward_max = np.max(rewards)
            reward_std = np.std(rewards)
            # BUG FIX: was 0-based `epoch`/`g_step` while every other log
            # line in this script is 1-based; now consistent.
            print(
                "Epoch %d step %d: reward_mean=%.2f, reward_max=%.2f, reward_std=%.2f, time=%.2f s"
                % (epoch + 1, g_step + 1, reward_mean, reward_max,
                   reward_std, time.time() - t_start))
            elite = population[0]
            # Next generation: keep the elite unchanged, fill the rest with
            # mutated children of the top PARENTS_COUNT individuals.
            prev_population = population
            population = [elite]
            for _ in range(POPULATION_SIZE - 1):
                parent_idx = np.random.randint(0, PARENTS_COUNT)
                parent = prev_population[parent_idx][0]
                net = mutate_net(parent, use_cuda)
                # BUG FIX: original called `evaluate(parent, netD)`, so each
                # child was stored with its parent's fitness and mutations
                # were never actually scored; evaluate the mutated net.
                fitness = evaluate(net, netD)
                population.append((net, fitness))
        # Promote the best individual found this epoch to be the generator.
        netG = elite[0]

        # Discriminator: retrain on the promoted generator's output.
        for d_step in range(D_STEPS):
            generate_samples(netG, BATCH_SIZE, GENERATED_NUM, FAKE_FILE)
            dis_set = DiscriminatorDataset(REAL_FILE, FAKE_FILE)
            disloader = DataLoader(dataset=dis_set, batch_size=BATCH_SIZE,
                                   shuffle=True)
            for k_step in range(K_STEPS):
                loss = netD.dtrain(disloader)
                print(
                    'D_step {}, K-step {} adversarial discriminator training loss: {}'
                    .format(d_step + 1, k_step + 1, loss))

        # Per-epoch validation against the oracle.
        generate_samples(netG, BATCH_SIZE, GENERATED_NUM, EVAL_FILE)
        val_set = GeneratorDataset(EVAL_FILE)
        valloader = DataLoader(dataset=val_set, batch_size=BATCH_SIZE,
                               shuffle=True)
        loss = oracle.val(valloader)
        print('Epoch {} adversarial generator val loss: {}'.format(
            epoch + 1, loss))