Пример #1
0
def predict(args):
    """Restore a trained Discriminator from ``args.save_dir`` and run it
    on ``args.text``.

    Returns the model's prediction, or None when no checkpoint exists in
    the save directory.
    """
    # Pickle streams are binary: opening in text mode corrupts them on
    # Windows and fails outright under Python 3, so use 'rb'.
    with open(os.path.join(args.save_dir, 'config.pkl'), 'rb') as f:
        saved_args = cPickle.load(f)
    with open(os.path.join(args.save_dir, 'combined_vocab.pkl'), 'rb') as f:
        _, vocab = cPickle.load(f)
    model = Discriminator(saved_args, is_training=False)
    with tf.Session() as sess:
        tf.initialize_all_variables().run()
        saver = tf.train.Saver(tf.all_variables())
        ckpt = tf.train.get_checkpoint_state(args.save_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            return model.predict(sess, args.text, vocab)
Пример #2
0
    def __init__(self, cfg):
        """Build generators, discriminator, losses, data loaders,
        optimizers and LR schedulers from ``cfg``; optionally resume all
        of them from a saved checkpoint when ``cfg.TRAIN.RESUME >= 0``.
        """
        self.cfg = cfg
        # Generator refining the old (noisy) label map; N_CLASS in/out.
        self.OldLabel_generator = U_Net(in_ch=cfg.DATASET.N_CLASS,
                                        out_ch=cfg.DATASET.N_CLASS,
                                        side='out')
        # Generator predicting a label map from the 3-channel RGB image.
        self.Image_generator = U_Net(in_ch=3,
                                     out_ch=cfg.DATASET.N_CLASS,
                                     side='in')
        # Patch discriminator over concatenated label map + RGB image.
        self.discriminator = Discriminator(cfg.DATASET.N_CLASS + 3,
                                           cfg.DATASET.IMGSIZE,
                                           patch=True)

        self.criterion_G = GeneratorLoss(cfg.LOSS.LOSS_WEIGHT[0],
                                         cfg.LOSS.LOSS_WEIGHT[1],
                                         cfg.LOSS.LOSS_WEIGHT[2],
                                         ignore_index=cfg.LOSS.IGNORE_INDEX)
        self.criterion_D = DiscriminatorLoss()

        train_dataset = BaseDataset(cfg, split='train')
        valid_dataset = BaseDataset(cfg, split='val')
        self.train_dataloader = data.DataLoader(
            train_dataset,
            batch_size=cfg.DATASET.BATCHSIZE,
            num_workers=8,
            shuffle=True,
            drop_last=True)
        self.valid_dataloader = data.DataLoader(
            valid_dataset,
            batch_size=cfg.DATASET.BATCHSIZE,
            num_workers=8,
            shuffle=True,
            drop_last=True)

        # Output directories for checkpoints and validation images.
        self.ckpt_outdir = os.path.join(cfg.TRAIN.OUTDIR, 'checkpoints')
        if not os.path.isdir(self.ckpt_outdir):
            os.mkdir(self.ckpt_outdir)
        self.val_outdir = os.path.join(cfg.TRAIN.OUTDIR, 'val')
        if not os.path.isdir(self.val_outdir):
            os.mkdir(self.val_outdir)
        # RESUME doubles as the resume flag: >= 0 means "resume from
        # epoch RESUME", negative means "train from scratch".
        self.start_epoch = cfg.TRAIN.RESUME
        self.n_epoch = cfg.TRAIN.N_EPOCH

        # One optimizer over both generators, one for the discriminator.
        self.optimizer_G = torch.optim.Adam(
            [{'params': self.OldLabel_generator.parameters()},
             {'params': self.Image_generator.parameters()}],
            lr=cfg.OPTIMIZER.G_LR,
            betas=(cfg.OPTIMIZER.BETA1, cfg.OPTIMIZER.BETA2),
            weight_decay=cfg.OPTIMIZER.WEIGHT_DECAY)

        self.optimizer_D = torch.optim.Adam(
            [{'params': self.discriminator.parameters(),
              'initial_lr': cfg.OPTIMIZER.D_LR}],
            lr=cfg.OPTIMIZER.D_LR,
            betas=(cfg.OPTIMIZER.BETA1, cfg.OPTIMIZER.BETA2),
            weight_decay=cfg.OPTIMIZER.WEIGHT_DECAY)

        # Poly learning-rate decay (exponent 0.9) over the whole run,
        # stepped per iteration.
        iter_per_epoch = len(train_dataset) // cfg.DATASET.BATCHSIZE
        lambda_poly = lambda iters: pow(
            (1.0 - iters / (cfg.TRAIN.N_EPOCH * iter_per_epoch)), 0.9)
        self.scheduler_G = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_G,
            lr_lambda=lambda_poly,
        )
        self.scheduler_D = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_D,
            lr_lambda=lambda_poly,
        )

        self.logger = logger(cfg.TRAIN.OUTDIR, name='train')
        self.running_metrics = runningScore(n_classes=cfg.DATASET.N_CLASS)

        if self.start_epoch >= 0:
            # Read the checkpoint file once instead of torch.load-ing the
            # same .pth from disk five separate times.
            ckpt_path = os.path.join(
                cfg.TRAIN.OUTDIR, 'checkpoints',
                '{}epoch.pth'.format(self.start_epoch))
            checkpoint = torch.load(ckpt_path)
            self.OldLabel_generator.load_state_dict(checkpoint['model_G_N'])
            self.Image_generator.load_state_dict(checkpoint['model_G_I'])
            self.discriminator.load_state_dict(checkpoint['model_D'])
            self.optimizer_G.load_state_dict(checkpoint['optimizer_G'])
            self.optimizer_D.load_state_dict(checkpoint['optimizer_D'])

            log = "Using the {}th checkpoint".format(self.start_epoch)
            self.logger.info(log)
        self.Image_generator = self.Image_generator.cuda()
        self.OldLabel_generator = self.OldLabel_generator.cuda()
        self.discriminator = self.discriminator.cuda()
        self.criterion_G = self.criterion_G.cuda()
        self.criterion_D = self.criterion_D.cuda()
Пример #3
0
def main():
    """Interactive driver: load the icon dataset, prompt the user for a
    GAN architecture (DCGAN, LSGAN or WGAN), then restore a checkpoint,
    train the chosen model pair, and save a new checkpoint.
    """
    # Define parameters
    img_size = 32
    epochs = 50
    batch_size = 32
    checkpoint_dir = 'models/checkpoints'

    # Create instances of the Generator, the Discriminator, the Critic and the optimizers
    generator = Generator(img_size=img_size)
    discriminator = Discriminator(img_size=img_size)
    critic = Critic(img_size=img_size)

    # Enable memory growth so TF does not grab all GPU memory up front.
    gpu_devices = tf.config.experimental.list_physical_devices('GPU')
    for device in gpu_devices:
        tf.config.experimental.set_memory_growth(device, True)

    # Load the dataset and normalize it
    print('Loading dataset...')
    images = np.load('dataset/LLD_icon_numpy/dataset1.npy', allow_pickle=True)
    print('Finished')

    # Slice the dataset into batches of size batch_size
    print('Splitting the dataset into batches...')
    images = tf.data.Dataset.from_tensor_slices(images).batch(batch_size)
    print('Finished')

    print('Enter which architecture you want to use: DCGAN, LSGAN or WGAN? ')
    architecture = input()  # input() already returns str

    # The three branches differed only in the optimizer factory, the
    # checkpoint factory, the training loop, and whether the adversary is
    # the discriminator or the critic — dispatch instead of copy-pasting
    # the restore/train/save pipeline three times.
    pipelines = {
        'DCGAN': (define_dcgan_optimizers, create_checkpoint_dcgan_lsgan,
                  train_dcgan, discriminator),
        'LSGAN': (define_lsgan_optimizers, create_checkpoint_dcgan_lsgan,
                  train_lsgan, discriminator),
        'WGAN': (define_wgan_optimizers, create_checkpoint_wgan,
                 train_wgan, critic),
    }

    if architecture in pipelines:
        define_optimizers, make_checkpoint, train, adversary = \
            pipelines[architecture]
        generator_optimizer, adversary_optimizer = define_optimizers()

        # Restore checkpoint
        print('Restoring checkpoint...')
        old_checkpoint = make_checkpoint(generator, adversary,
                                         generator_optimizer,
                                         adversary_optimizer)
        restore_checkpoint(old_checkpoint, checkpoint_dir)
        print('Checkpoint restored')

        # Train the chosen GAN on the dataset
        print('Starting training...')
        train(generator, adversary, generator_optimizer, adversary_optimizer,
              images, epochs, batch_size)
        print('Training finished')

        # Create a checkpoint
        print('Creating checkpoint...')
        new_checkpoint = make_checkpoint(generator, adversary,
                                         generator_optimizer,
                                         adversary_optimizer)
        save_checkpoint(new_checkpoint, checkpoint_dir)
        print('Checkpoint created')

    else:
        print('Invalid architecture selected!')

    print('Execution finished')
Пример #4
0
def main():
    """SeqGAN-style training driver for music sequence generation.

    Pipeline (Python 2 / TF1): optional MLE pre-training of the generator
    plus supervised pre-training of the discriminator, then adversarial
    training with policy-gradient rewards from a rollout network.
    Progress (losses, BLEU) is appended to a timestamped log under save/.
    """
    random.seed(SEED)
    np.random.seed(SEED)
    # data loaders declaration
    # loaders for generator, discriminator, and additional validation data loader
    gen_data_loader = Gen_Data_loader(BATCH_SIZE)
    dis_data_loader = Dis_dataloader(BATCH_SIZE)
    eval_data_loader = Gen_Data_loader(BATCH_SIZE)

    # define generator and discriminator
    # general structures are same with the original model
    # learning rates for generator needs heavy tuning for general use
    # l2 reg for D & G also affects performance
    generator = Generator(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM,
                          SEQ_LENGTH, START_TOKEN, GENERATOR_LR, REWARD_GAMMA)
    discriminator = Discriminator(sequence_length=SEQ_LENGTH,
                                  num_classes=2,
                                  vocab_size=vocab_size,
                                  embedding_size=dis_embedding_dim,
                                  filter_sizes=dis_filter_sizes,
                                  num_filters=dis_num_filters,
                                  l2_reg_lambda=dis_l2_reg_lambda)

    # VRAM limitation for efficient deployment
    tf_config = tf.ConfigProto()
    tf_config.gpu_options.allow_growth = True
    sess = tf.Session(config=tf_config)
    sess.run(tf.global_variables_initializer())
    # define saver
    saver = tf.train.Saver(tf.trainable_variables(), max_to_keep=1)
    # generate real data from the true dataset
    gen_data_loader.create_batches(positive_file)
    # generate real validation data from true validation dataset
    eval_data_loader.create_batches(valid_file)

    # Timestamped experiment log; config is dumped first for provenance.
    time = str(datetime.datetime.now())[:-7]
    log = open('save/experiment-log' + str(time) + '.txt', 'w')
    log.write(str(config) + '\n')
    log.write('D loss: original\n')
    log.flush()

    #summary_writer = tf.summary.FileWriter('save/tensorboard/', graph=tf.get_default_graph())

    if config['pretrain'] == True:
        #  pre-train generator
        print 'Start pre-training...'
        log.write('pre-training...\n')
        for epoch in xrange(PRE_GEN_EPOCH):
            # calculate the loss by running an epoch
            loss = pre_train_epoch(sess, generator, gen_data_loader)

            # measure bleu score with the validation set
            bleu_score = calculate_bleu(sess, generator, eval_data_loader)

            # since the real data is the true data distribution, only evaluate the pretraining loss
            # note the absence of the oracle model which is meaningless for general use
            # NOTE: `buffer` shadows the Python 2 builtin of the same name.
            buffer = 'pre-train epoch: ' + str(
                epoch) + ' pretrain_loss: ' + str(loss) + ' bleu: ' + str(
                    bleu_score)
            print(buffer)
            log.write(buffer + '\n')
            log.flush()

            # generate 5 test samples per epoch
            # it automatically samples from the generator and postprocess to midi file
            # midi files are saved to the pre-defined folder
            if epoch == 0:
                generate_samples(sess, generator, BATCH_SIZE, generated_num,
                                 negative_file)
                POST.main(negative_file, 5, str(-1) + '_vanilla_', 'midi')
            elif epoch == PRE_GEN_EPOCH - 1:
                generate_samples(sess, generator, BATCH_SIZE, generated_num,
                                 negative_file)
                POST.main(negative_file, 5,
                          str(-PRE_GEN_EPOCH) + '_vanilla_', 'midi')

        print 'Start pre-training discriminator...'
        # Train 3 epoch on the generated data and do this for 50 times
        # this trick is also in spirit of the original work, but the epoch strategy needs tuning
        for epochs in range(PRE_DIS_EPOCH):
            # Fresh negative samples from the current generator each round.
            generate_samples(sess, generator, BATCH_SIZE, generated_num,
                             negative_file)
            D_loss = 0
            for _ in range(3):
                dis_data_loader.load_train_data(positive_file, negative_file)
                dis_data_loader.reset_pointer()
                for it in xrange(dis_data_loader.num_batch):
                    x_batch, y_batch = dis_data_loader.next_batch()
                    feed = {
                        discriminator.input_x: x_batch,
                        discriminator.input_y: y_batch,
                        discriminator.dropout_keep_prob: dis_dropout_keep_prob
                    }
                    _ = sess.run(discriminator.train_op, feed)
                    D_loss += discriminator.loss.eval(feed, session=sess)
            # Mean D loss over 3 passes x num_batch batches.
            buffer = 'epoch: ' + str(epochs + 1) + '  D loss: ' + str(
                D_loss / dis_data_loader.num_batch / 3)
            print(buffer)
            log.write(buffer + '\n')
            log.flush()

        # save the pre-trained checkpoint for future use
        # if one wants adv. training only, comment out the pre-training section after the save
        save_checkpoint(sess, saver, PRE_GEN_EPOCH, PRE_DIS_EPOCH)

    # define rollout target object
    # the second parameter specifies target update rate
    # the higher rate makes rollout "conservative", with less update from the learned generator
    # we found that higher update rate stabilized learning, constraining divergence of the generator
    rollout = ROLLOUT(generator, ROLLOUT_UPDATE_RATE)

    print '#########################################################################'
    print 'Start Adversarial Training...'
    log.write('adversarial training...\n')
    if config['pretrain'] == False:
        # load checkpoint of pre-trained model
        load_checkpoint(sess, saver)

    # 0.001 to 0.01
    if config['x10adv_g'] == True:
        generator.learning_rate *= 10

    for total_batch in range(TOTAL_BATCH):
        G_loss = 0
        # Train the generator for one step
        for it in range(epochs_generator):
            samples = generator.generate(sess)
            # Policy-gradient rewards from Monte-Carlo rollouts scored by D.
            rewards = rollout.get_reward(sess, samples, config['rollout_num'],
                                         discriminator)
            feed = {generator.x: samples, generator.rewards: rewards}
            _ = sess.run(generator.g_updates, feed_dict=feed)
            G_loss += generator.g_loss.eval(feed, session=sess)

        # Update roll-out parameters
        rollout.update_params()

        # Train the discriminator
        D_loss = 0
        for _ in range(epochs_discriminator):
            generate_samples(sess, generator, BATCH_SIZE, generated_num,
                             negative_file)
            for _ in range(config['epochs_discriminator_multiplier']):
                dis_data_loader.load_train_data(positive_file, negative_file)
                dis_data_loader.reset_pointer()

                for it in xrange(dis_data_loader.num_batch):
                    x_batch, y_batch = dis_data_loader.next_batch()
                    feed = {
                        discriminator.input_x: x_batch,
                        discriminator.input_y: y_batch,
                        discriminator.dropout_keep_prob: dis_dropout_keep_prob
                    }
                    _ = sess.run(discriminator.train_op, feed)
                    D_loss += discriminator.loss.eval(feed, session=sess)

        # measure stability and performance evaluation with bleu score
        bleu_score = calculate_bleu(sess, generator, eval_data_loader)
        buffer = 'epoch: ' + str(total_batch+1) + \
                 ',  G_adv_loss: %.12f' % (G_loss/epochs_generator) + \
                 ',  D loss: %.12f' % (D_loss/epochs_discriminator/config['epochs_discriminator_multiplier']) + \
                 ',  bleu score: %.12f' % bleu_score
        print(buffer)
        log.write(buffer + '\n')
        log.flush()

        # Restart from the pre-trained checkpoint when BLEU collapses
        # below the configured threshold (mode-collapse recovery).
        if config['infinite_loop'] is True:
            if bleu_score < config['loop_threshold']:
                buffer = 'Mode collapse detected, restarting from pretrained model...'
                print(buffer)
                log.write(buffer + '\n')
                log.flush()
                load_checkpoint(sess, saver)

        # generate random test samples and postprocess the sequence to midi file
        generate_samples(sess, generator, BATCH_SIZE, generated_num,
                         negative_file)
        POST.main(negative_file, 5, str(total_batch) + '_vanilla_', 'midi')
    log.close()
Пример #5
0
class Model(object):
    """InfoGAN-style model: a generator, a real/fake discriminator, and
    an auxiliary Q-network that recovers the categorical latent code from
    generated images (TF1 graph mode).
    """

    def __init__(self, class_num, z_dim, batch_size):
        """
        Args:
            class_num: number of categorical latent codes.
            z_dim: dimensionality of the noise vector z.
            batch_size: fixed batch size baked into every placeholder.
        """
        self.input_size = 32  # generated images are input_size x input_size x 1
        self.class_num = class_num
        self.z_dim = z_dim
        self.batch_size = batch_size

        # Weight of the auxiliary (Q) loss term.
        self.Lambda = 10

        self.lr = 0.001

        # generator config
        gen_layer = [512, 256, 128, 1]
        # Spatial size of the generator's initial feature map; each deconv
        # layer upsamples by 2x back up to input_size.
        gen_in_dim = int(self.input_size / 2**(len(gen_layer) - 1))

        # discriminator config
        disc_layer = [1, 64, 128, 256]

        # -- generator -----
        self.gen = Generator([u'gen_reshape', u'gen_deconv'], gen_in_dim,
                             gen_layer)

        # -- discriminator --
        self.disc = Discriminator([u'disc_conv', u'disc_fc'], disc_layer)

        # -- q ---------------
        self.Q_value = Discriminator([u'Q_val_conv', u'Q_val_fc'], disc_layer,
                                     class_num)

    def set_model(self):
        """Build the TF graph: placeholders, losses and train ops."""
        # -- define place holder -------
        self.z = tf.placeholder(tf.float32, [self.batch_size, self.z_dim])
        self.c = tf.placeholder(tf.float32, [self.batch_size, self.class_num])
        self.figs = tf.placeholder(
            tf.float32, [self.batch_size, self.input_size, self.input_size, 1])

        # -- generator loss: make D label generated images as real -----
        gen_figs = self.gen.set_model(self.c, self.z, self.batch_size, True,
                                      False)
        g_logits = self.disc.set_model(gen_figs, True, False)

        self.g_obj = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=g_logits, labels=tf.ones_like(g_logits)))

        self.train_gen = tf.train.AdamOptimizer(self.lr, beta1=0.5).minimize(
            self.g_obj, var_list=self.gen.get_variables())

        # -- q loss: recover the class code c from generated images ----
        q_logits = self.Q_value.set_model(gen_figs, True, False)
        self.q_obj = self.Lambda * tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(logits=q_logits,
                                                    labels=self.c))
        train_var = self.gen.get_variables() + self.Q_value.get_variables()
        # BUG FIX: this op previously minimized self.g_obj, so the Q loss
        # was never optimized; train_q must minimize self.q_obj.
        self.train_q = tf.train.AdamOptimizer(self.lr, beta1=0.5).minimize(
            self.q_obj, var_list=train_var)

        # -- discriminator: real images vs generated images --------
        d_logits = self.disc.set_model(self.figs, True, True)

        d_obj_true = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=d_logits, labels=tf.ones_like(d_logits)))
        d_obj_false = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=g_logits, labels=tf.zeros_like(g_logits)))
        self.d_obj = d_obj_true + d_obj_false
        self.train_disc = tf.train.AdamOptimizer(self.lr, beta1=0.5).minimize(
            self.d_obj, var_list=self.disc.get_variables())

        # -- for figure generation (reuse weights, inference mode) -----
        self.gen_figs = self.gen.set_model(self.c, self.z, self.batch_size,
                                           False, True)

    def training_gen(self, sess, c_list, z_list):
        """Run one generator update; returns the generator loss."""
        _, g_obj = sess.run([self.train_gen, self.g_obj],
                            feed_dict={
                                self.c: c_list,
                                self.z: z_list
                            })
        return g_obj

    def training_disc(self, sess, c_list, z_list, figs):
        """Run one discriminator update; returns the discriminator loss."""
        _, d_obj = sess.run([self.train_disc, self.d_obj],
                            feed_dict={
                                self.c: c_list,
                                self.z: z_list,
                                self.figs: figs
                            })
        return d_obj

    def training_q(self, sess, c_list, z_list):
        """Run one Q-network update; returns the Q loss."""
        _, q_obj = sess.run([self.train_q, self.q_obj],
                            feed_dict={
                                self.c: c_list,
                                self.z: z_list,
                            })
        return q_obj

    def gen_fig(self, sess, c, z):
        """Generate images for codes ``c`` and noise ``z``; returns a
        list of (32, 32, 1) numpy arrays."""
        raw = sess.run(self.gen_figs, feed_dict={self.c: c, self.z: z})
        return [np.reshape(fig, [32, 32, 1]) for fig in raw]
Пример #6
0
def main():
    if sys.argv < 2:
        print "INPUT THE NUMBER OF GPU TO RUN"
        sys.exit(0)

    os.environ["CUDA_VISIBLE_DEVICES"] = sys.argv[1]

    random.seed(SEED)
    np.random.seed(SEED)
    assert START_TOKEN == 0

    gen_data_loader = Gen_Data_loader(BATCH_SIZE)
    likelihood_data_loader = Gen_Data_loader(BATCH_SIZE)  # For testing
    vocab_size = 5000
    dis_data_loader = Dis_dataloader(BATCH_SIZE)

    generator = Generator(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM,
                          SEQ_LENGTH, START_TOKEN)
    target_params = cPickle.load(open('save/target_params.pkl'))

    target_lstm = TARGET_LSTM(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM,
                              SEQ_LENGTH, START_TOKEN,
                              target_params)  # The oracle model

    discriminator = Discriminator(sequence_length=20,
                                  num_classes=2,
                                  vocab_size=vocab_size,
                                  embedding_size=dis_embedding_dim,
                                  filter_sizes=dis_filter_sizes,
                                  num_filters=dis_num_filters,
                                  l2_reg_lambda=dis_l2_reg_lambda)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    # First, use the oracle model to provide the positive examples, which are sampled from the oracle data distribution
    #generate_samples(sess, target_lstm, BATCH_SIZE, generated_num, positive_file)
    gen_data_loader.create_batches(positive_file)

    log = open('save/experiment-log.txt', 'w')
    #  pre-train generator
    print 'Start pre-training...'
    log.write('pre-training...\n')
    for epoch in xrange(PRE_EPOCH_NUM):
        loss = pre_train_epoch(sess, generator, gen_data_loader)
        if epoch % 5 == 0:
            generate_samples(sess, generator, BATCH_SIZE, generated_num,
                             eval_file)
            likelihood_data_loader.create_batches(eval_file)
            test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            print 'pre-train epoch ', epoch, 'test_loss ', test_loss, 'pretrain loss', loss
            buffer = 'epoch:\t' + str(epoch) + '\tnll:\t' + str(
                test_loss) + '\n'
            log.write(buffer)
    print "MLE TRAIN END"

    print 'Start pre-training discriminator...'
    # Train 3 epoch on the generated data and do this for 50 times
    for _ in range(50):
        continue
        generate_samples(sess, generator, BATCH_SIZE, generated_num,
                         negative_file)
        dis_data_loader.load_train_data(positive_file, negative_file)
        for _ in range(3):
            dis_data_loader.reset_pointer()
            for it in xrange(dis_data_loader.num_batch):
                x_batch, y_batch = dis_data_loader.next_batch()
                feed = {
                    discriminator.input_x: x_batch,
                    discriminator.input_y: y_batch,
                    discriminator.dropout_keep_prob: dis_dropout_keep_prob
                }
                _ = sess.run(discriminator.train_op, feed)

    rollout = ROLLOUT(generator, 0.8)

    print '#########################################################################'
    print 'Start Adversarial Training...'
    log.write('adversarial training...\n')
    for total_batch in range(TOTAL_BATCH):
        # Train the generator for one step
        for it in range(1):
            samples = generator.generate(sess)
            #rewards = rollout.get_reward(sess, samples, 16, discriminator)
            rewards = count_reward(samples)
            feed = {generator.x: samples, generator.rewards: rewards}
            _ = sess.run(generator.g_updates, feed_dict=feed)

        # Test
        # this loss has no use
        if total_batch % 5 == 0 or total_batch == TOTAL_BATCH - 1:
            generate_samples(sess, generator, BATCH_SIZE, generated_num,
                             eval_file)
            likelihood_data_loader.create_batches(eval_file)
            test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            buffer = 'epoch:\t' + str(total_batch) + '\tnll:\t' + str(
                test_loss) + '\n'
            print 'total_batch: ', total_batch, 'test_loss: ', test_loss
            log.write(buffer)

        # Update roll-out parameters
        rollout.update_params()

        # Train the discriminator
        for _ in range(5):
            continue
            generate_samples(sess, generator, BATCH_SIZE, generated_num,
                             negative_file)
            dis_data_loader.load_train_data(positive_file, negative_file)

            for _ in range(3):
                dis_data_loader.reset_pointer()
                for it in xrange(dis_data_loader.num_batch):
                    x_batch, y_batch = dis_data_loader.next_batch()
                    feed = {
                        discriminator.input_x: x_batch,
                        discriminator.input_y: y_batch,
                        discriminator.dropout_keep_prob: dis_dropout_keep_prob
                    }
                    _ = sess.run(discriminator.train_op, feed)

    log.close()
Пример #7
0
def main(args):
    """Adversarial (GAN) training loop for a seq2seq NMT generator and a sentence-pair discriminator.

    Loads a bilingual dataset, builds an LSTM generator and a Discriminator,
    then alternates per batch between policy-gradient and MLE updates for the
    generator (coin flip at 50%) and occasional discriminator updates (coin
    flip at 30%).  After each epoch the generator is validated and
    checkpointed; the best model by validation loss is kept separately.

    Args:
        args: parsed command-line namespace; must provide data/optimizer/loop
            hyperparameters (gpuid, data, src_lang, trg_lang, fixed_max_len,
            g_optimizer, d_optimizer, learning rates, epochs, seed, ...).
            Several model-size fields are overwritten below.
    """
    use_cuda = (len(args.gpuid) >= 1)
    print("{0} GPU(s) are available".format(cuda.device_count()))

    # Load dataset
    splits = ['train', 'valid']
    if data.has_binary_files(args.data, splits):
        dataset = data.load_dataset(args.data, splits, args.src_lang,
                                    args.trg_lang, args.fixed_max_len)
    else:
        dataset = data.load_raw_text_dataset(args.data, splits, args.src_lang,
                                             args.trg_lang, args.fixed_max_len)
    if args.src_lang is None or args.trg_lang is None:
        # record inferred languages in args, so that it's saved in checkpoints
        args.src_lang, args.trg_lang = dataset.src, dataset.dst

    print('| [{}] dictionary: {} types'.format(dataset.src,
                                               len(dataset.src_dict)))
    print('| [{}] dictionary: {} types'.format(dataset.dst,
                                               len(dataset.dst_dict)))

    for split in splits:
        print('| {} {} {} examples'.format(args.data, split,
                                           len(dataset.splits[split])))

    # Running-average meters for generator training/validation statistics.
    g_logging_meters = OrderedDict()
    g_logging_meters['train_loss'] = AverageMeter()
    g_logging_meters['valid_loss'] = AverageMeter()
    g_logging_meters['train_acc'] = AverageMeter()
    g_logging_meters['valid_acc'] = AverageMeter()
    g_logging_meters['bsz'] = AverageMeter()  # sentences per batch

    # Same set of meters for the discriminator.
    d_logging_meters = OrderedDict()
    d_logging_meters['train_loss'] = AverageMeter()
    d_logging_meters['valid_loss'] = AverageMeter()
    d_logging_meters['train_acc'] = AverageMeter()
    d_logging_meters['valid_acc'] = AverageMeter()
    d_logging_meters['bsz'] = AverageMeter()  # sentences per batch

    # Set model parameters (hard-coded here, overriding whatever args held).
    args.encoder_embed_dim = 1000
    args.encoder_layers = 2  # 4
    args.encoder_dropout_out = 0
    args.decoder_embed_dim = 1000
    args.decoder_layers = 2  # 4
    args.decoder_out_embed_dim = 1000
    args.decoder_dropout_out = 0
    args.bidirectional = False

    generator = LSTMModel(args,
                          dataset.src_dict,
                          dataset.dst_dict,
                          use_cuda=use_cuda)
    print("Generator loaded successfully!")
    discriminator = Discriminator(args,
                                  dataset.src_dict,
                                  dataset.dst_dict,
                                  use_cuda=use_cuda)
    print("Discriminator loaded successfully!")

    if use_cuda:
        if torch.cuda.device_count() > 1:
            discriminator = torch.nn.DataParallel(discriminator).cuda()
            generator = torch.nn.DataParallel(generator).cuda()
        else:
            generator.cuda()
            discriminator.cuda()
    else:
        discriminator.cpu()
        generator.cpu()

    # adversarial training checkpoints saving path
    if not os.path.exists('checkpoints/sample_relativity'):
        os.makedirs('checkpoints/sample_relativity')
    checkpoints_path = 'checkpoints/sample_relativity/'

    # define loss function
    # g_criterion: token-level NLL for MLE updates (sum-reduced, pad ignored).
    g_criterion = torch.nn.NLLLoss(ignore_index=dataset.dst_dict.pad(),
                                   reduction='sum')
    # d_criterion: binary real/fake loss on discriminator scores.
    d_criterion = torch.nn.BCELoss()
    # pg_criterion: policy-gradient loss weighting log-probs by D's reward.
    pg_criterion = PGLoss(ignore_index=dataset.dst_dict.pad(),
                          size_average=True,
                          reduce=True)

    # fix discriminator word embedding (as Wu et al. do)
    for p in discriminator.embed_src_tokens.parameters():
        p.requires_grad = False
    for p in discriminator.embed_trg_tokens.parameters():
        p.requires_grad = False

    # define optimizer
    # NOTE(review): eval() on args strings builds the optimizer class; fine
    # for trusted CLI input, but getattr(torch.optim, name) would be safer.
    g_optimizer = eval("torch.optim." + args.g_optimizer)(filter(
        lambda x: x.requires_grad, generator.parameters()),
                                                          args.g_learning_rate)

    d_optimizer = eval("torch.optim." + args.d_optimizer)(
        filter(lambda x: x.requires_grad, discriminator.parameters()),
        args.d_learning_rate,
        momentum=args.momentum,
        nesterov=True)

    # start joint training
    best_dev_loss = math.inf
    num_update = 0
    # main training loop
    for epoch_i in range(1, args.epochs + 1):
        logging.info("At {0}-th epoch.".format(epoch_i))

        # Re-seed per epoch so batch order is reproducible per epoch index.
        seed = args.seed + epoch_i
        torch.manual_seed(seed)

        max_positions_train = (args.fixed_max_len, args.fixed_max_len)

        # Initialize dataloader, starting at batch_offset
        trainloader = dataset.train_dataloader(
            'train',
            max_tokens=args.max_tokens,
            max_sentences=args.joint_batch_size,
            max_positions=max_positions_train,
            # seed=seed,
            epoch=epoch_i,
            sample_without_replacement=args.sample_without_replacement,
            sort_by_source_size=(epoch_i <= args.curriculum),
            shard_id=args.distributed_rank,
            num_shards=args.distributed_world_size,
        )

        # reset meters
        for key, val in g_logging_meters.items():
            if val is not None:
                val.reset()
        for key, val in d_logging_meters.items():
            if val is not None:
                val.reset()

        # set training mode
        generator.train()
        discriminator.train()
        update_learning_rate(num_update, 8e4, args.g_learning_rate,
                             args.lr_shrink, g_optimizer)

        for i, sample in enumerate(trainloader):

            if use_cuda:
                # wrap input tensors in cuda tensors
                # NOTE(review): `cuda=cuda` passes the cuda *module*, not a
                # bool; presumably utils.make_variable only truth-tests it —
                # confirm, `cuda=use_cuda` looks intended.
                sample = utils.make_variable(sample, cuda=cuda)

            ## part I: use gradient policy method to train the generator

            # use policy gradient training when random.random() > 50%
            if random.random() >= 0.5:

                print("Policy Gradient Training")

                sys_out_batch, p = generator('PG', epoch_i,
                                             sample)  # 64 X 50 X 6632

                out_batch = sys_out_batch.contiguous().view(
                    -1, sys_out_batch.size(-1))  # (64 * 50) X 6632

                # Greedy decode: argmax token at each position.
                _, prediction = out_batch.topk(1)
                prediction = prediction.squeeze(1)  # 64*50 = 3200
                prediction = torch.reshape(
                    prediction,
                    sample['net_input']['src_tokens'].shape)  # 64 X 50

                # Discriminator score of the sampled translation acts as the
                # PG reward; no gradient flows into D here.
                with torch.no_grad():
                    reward = discriminator(sample['net_input']['src_tokens'],
                                           prediction)  # 64 X 1

                train_trg_batch = sample['target']  # 64 x 50

                pg_loss = pg_criterion(sys_out_batch, train_trg_batch, reward,
                                       use_cuda)
                sample_size = sample['target'].size(
                    0) if args.sentence_avg else sample['ntokens']  # 64
                logging_loss = pg_loss / math.log(2)  # convert nats -> bits
                g_logging_meters['train_loss'].update(logging_loss.item(),
                                                      sample_size)
                logging.debug(
                    f"G policy gradient loss at batch {i}: {pg_loss.item():.3f}, lr={g_optimizer.param_groups[0]['lr']}"
                )
                g_optimizer.zero_grad()
                pg_loss.backward()
                torch.nn.utils.clip_grad_norm_(generator.parameters(),
                                               args.clip_norm)
                g_optimizer.step()

            else:
                # MLE training
                print("MLE Training")

                sys_out_batch, p = generator("MLE", epoch_i, sample)

                out_batch = sys_out_batch.contiguous().view(
                    -1, sys_out_batch.size(-1))  # (64 X 50) X 6632

                train_trg_batch = sample['target'].view(-1)  # 64*50 = 3200
                loss = g_criterion(out_batch, train_trg_batch)

                sample_size = sample['target'].size(
                    0) if args.sentence_avg else sample['ntokens']
                nsentences = sample['target'].size(0)
                # NLLLoss is sum-reduced above, so normalize for logging only.
                logging_loss = loss.data / sample_size / math.log(2)
                g_logging_meters['bsz'].update(nsentences)
                g_logging_meters['train_loss'].update(logging_loss,
                                                      sample_size)
                logging.debug(
                    f"G MLE loss at batch {i}: {g_logging_meters['train_loss'].avg:.3f}, lr={g_optimizer.param_groups[0]['lr']}"
                )
                g_optimizer.zero_grad()
                loss.backward()
                # all-reduce grads and rescale by grad_denom
                for p in generator.parameters():
                    # print(p.size())
                    if p.requires_grad:
                        p.grad.data.div_(sample_size)
                torch.nn.utils.clip_grad_norm_(generator.parameters(),
                                               args.clip_norm)
                g_optimizer.step()

            num_update += 1

            # Train D on ~30% of batches.
            if random.random() >= 0.7:

                # part II: train the discriminator
                bsz = sample['target'].size(0)  # batch_size = 64

                src_sentence = sample['net_input'][
                    'src_tokens']  # 64 x max-len i.e 64 X 50

                # now train with machine translation output i.e generator output
                # NOTE(review): true_sentence/true_labels are computed but
                # never used below — D is only updated on fake pairs here.
                true_sentence = sample['target'].view(-1)  # 64*50 = 3200

                true_labels = torch.ones(
                    sample['target'].size(0)).float()  # 64 length vector

                with torch.no_grad():
                    sys_out_batch, p = generator('MLE', epoch_i,
                                                 sample)  # 64 X 50 X 6632

                out_batch = sys_out_batch.contiguous().view(
                    -1, sys_out_batch.size(-1))  # (64 X 50) X 6632

                _, prediction = out_batch.topk(1)
                prediction = prediction.squeeze(1)  # 64 * 50 = 6632

                fake_labels = torch.zeros(
                    sample['target'].size(0)).float()  # 64 length vector

                fake_sentence = torch.reshape(prediction,
                                              src_sentence.shape)  # 64 X 50

                if use_cuda:
                    fake_labels = fake_labels.cuda()

                disc_out = discriminator(src_sentence, fake_sentence)  # 64 X 1

                d_loss = d_criterion(disc_out.squeeze(1), fake_labels)

                # Fraction of fake pairs correctly scored below 0.5.
                acc = torch.sum(
                    torch.round(disc_out).squeeze(1) ==
                    fake_labels).float() / len(fake_labels)

                d_logging_meters['train_acc'].update(acc)
                d_logging_meters['train_loss'].update(d_loss)
                logging.debug(
                    f"D training loss {d_logging_meters['train_loss'].avg:.3f}, acc {d_logging_meters['train_acc'].avg:.3f} at batch {i}"
                )
                d_optimizer.zero_grad()
                d_loss.backward()
                d_optimizer.step()

        # validation
        # set validation mode
        generator.eval()
        discriminator.eval()
        # Initialize dataloader
        max_positions_valid = (args.fixed_max_len, args.fixed_max_len)
        valloader = dataset.eval_dataloader(
            'valid',
            max_tokens=args.max_tokens,
            max_sentences=args.joint_batch_size,
            max_positions=max_positions_valid,
            skip_invalid_size_inputs_valid_test=True,
            descending=True,  # largest batch first to warm the caching allocator
            shard_id=args.distributed_rank,
            num_shards=args.distributed_world_size,
        )

        # reset meters
        for key, val in g_logging_meters.items():
            if val is not None:
                val.reset()
        for key, val in d_logging_meters.items():
            if val is not None:
                val.reset()

        for i, sample in enumerate(valloader):

            with torch.no_grad():
                if use_cuda:
                    # wrap input tensors in cuda tensors
                    sample = utils.make_variable(sample, cuda=cuda)

                # generator validation
                sys_out_batch, p = generator('test', epoch_i, sample)
                out_batch = sys_out_batch.contiguous().view(
                    -1, sys_out_batch.size(-1))  # (64 X 50) X 6632
                dev_trg_batch = sample['target'].view(-1)  # 64*50 = 3200

                loss = g_criterion(out_batch, dev_trg_batch)
                sample_size = sample['target'].size(
                    0) if args.sentence_avg else sample['ntokens']
                loss = loss / sample_size / math.log(2)
                g_logging_meters['valid_loss'].update(loss, sample_size)
                logging.debug(
                    f"G dev loss at batch {i}: {g_logging_meters['valid_loss'].avg:.3f}"
                )

                # discriminator validation
                bsz = sample['target'].size(0)
                src_sentence = sample['net_input']['src_tokens']
                # train with half human-translation and half machine translation

                # NOTE(review): as in training, the real pair is prepared but
                # only the fake pair is actually scored below.
                true_sentence = sample['target']
                true_labels = torch.ones(sample['target'].size(0)).float()

                with torch.no_grad():
                    sys_out_batch, p = generator('test', epoch_i, sample)

                out_batch = sys_out_batch.contiguous().view(
                    -1, sys_out_batch.size(-1))  # (64 X 50) X 6632

                _, prediction = out_batch.topk(1)
                prediction = prediction.squeeze(1)  # 64 * 50 = 6632

                fake_labels = torch.zeros(sample['target'].size(0)).float()

                fake_sentence = torch.reshape(prediction,
                                              src_sentence.shape)  # 64 X 50

                if use_cuda:
                    fake_labels = fake_labels.cuda()

                disc_out = discriminator(src_sentence, fake_sentence)
                d_loss = d_criterion(disc_out.squeeze(1), fake_labels)
                acc = torch.sum(
                    torch.round(disc_out).squeeze(1) ==
                    fake_labels).float() / len(fake_labels)
                d_logging_meters['valid_acc'].update(acc)
                d_logging_meters['valid_loss'].update(d_loss)
                logging.debug(
                    f"D dev loss {d_logging_meters['valid_loss'].avg:.3f}, acc {d_logging_meters['valid_acc'].avg:.3f} at batch {i}"
                )

        # Checkpoint the whole generator module each epoch (dill handles
        # the non-picklable pieces).
        torch.save(
            generator,
            open(
                checkpoints_path +
                f"sampling_{g_logging_meters['valid_loss'].avg:.3f}.epoch_{epoch_i}.pt",
                'wb'),
            pickle_module=dill)

        # Keep a separate copy of the best model by validation loss.
        if g_logging_meters['valid_loss'].avg < best_dev_loss:
            best_dev_loss = g_logging_meters['valid_loss'].avg
            torch.save(generator,
                       open(checkpoints_path + "best_gmodel.pt", 'wb'),
                       pickle_module=dill)
Пример #8
0
    def train_gan(self):
        """Run the adversarial (GAN) training phase of a CartoonGAN-style model.

        Builds source/target/smooth datasets, restores the generator from a
        previous GAN checkpoint (falling back to the pretraining checkpoint,
        then to scratch), restores the discriminator if available, then trains
        for the remaining epochs, logging metrics and sample images to
        TensorBoard and saving checkpoints after every epoch.
        """
        self.logger.info("Setting up summary writer to record progress on TensorBoard...")
        summary_writer = tf.summary.create_file_writer(self.log_dir)
        self.logger.info(
            f"Starting adversarial training with {self.epochs} epochs, "
            f"batch size: {self.batch_size}..."
        )
        self.logger.info(f"Building `{self.dataset_name}` "
                         "datasets for source/target/smooth domains...")
        # steps_per_epoch comes from the source dataset; target/smooth are
        # assumed to be at least as long.
        ds_source, steps_per_epoch = self.get_dataset(dataset_name=self.dataset_name,
                                                      domain=self.source_domain,
                                                      _type="train",
                                                      batch_size=self.batch_size)
        ds_target, _ = self.get_dataset(dataset_name=self.dataset_name,
                                        domain=self.target_domain,
                                        _type="train",
                                        batch_size=self.batch_size)
        ds_smooth, _ = self.get_dataset(dataset_name=self.dataset_name,
                                        domain=f"{self.target_domain}_smooth",
                                        _type="train",
                                        batch_size=self.batch_size)
        self.logger.info("Setting up optimizer to update generator and discriminator...")
        g_optimizer = tf.keras.optimizers.Adam(learning_rate=self.generator_lr, beta_1=.5)
        d_optimizer = tf.keras.optimizers.Adam(learning_rate=self.discriminator_lr, beta_1=.5)
        if self.multi_scale:
            self.logger.info(f"Initializing generator with "
                             f"batch_size: {self.batch_size}, input_size: multi-scale...")
        else:
            self.logger.info(f"Initializing generator with "
                             f"batch_size: {self.batch_size}, input_size: {self.input_size}...")
        # base_filters=2 in debug mode keeps the model tiny for fast iteration.
        g = Generator(base_filters=2 if self.debug else 64, light=self.light)
        # Build the model's variables by calling it once on a symbolic input.
        g(tf.keras.Input(
            shape=(self.input_size, self.input_size, 3),
            batch_size=self.batch_size))

        self.logger.info(f"Searching existing checkpoints: `{self.generator_checkpoint_prefix}`...")
        try:
            g_checkpoint = tf.train.Checkpoint(generator=g)
            # assert_existing_objects_matched() raises AssertionError when no
            # (matching) checkpoint is found, which drives the fallback below.
            g_checkpoint.restore(
                tf.train.latest_checkpoint(
                    self.generator_checkpoint_dir)).assert_existing_objects_matched()
            self.logger.info(f"Previous checkpoints has been restored.")
            # save_counter == number of completed epochs (checkpoint saved once
            # per epoch at the bottom of the loop).
            trained_epochs = g_checkpoint.save_counter.numpy()
            epochs = self.epochs - trained_epochs
            if epochs <= 0:
                self.logger.info(f"Already trained {trained_epochs} epochs. "
                                 "Set a larger `epochs`...")
                return
            else:
                self.logger.info(f"Already trained {trained_epochs} epochs, "
                                 f"{epochs} epochs left to be trained...")
        except AssertionError as e:
            self.logger.warning(e)
            self.logger.warning(
                "Previous checkpoints are not found, trying to load checkpoints from pretraining..."
            )

            # Fallback: warm-start the generator from the pretraining phase.
            try:
                g_checkpoint = tf.train.Checkpoint(generator=g)
                g_checkpoint.restore(tf.train.latest_checkpoint(
                    os.path.join(
                        self.checkpoint_dir, "pretrain"))).assert_existing_objects_matched()
                self.logger.info("Successfully loaded "
                                 f"`{self.pretrain_checkpoint_prefix}`...")
            except AssertionError:
                self.logger.warning("specified pretrained checkpoint is not found, "
                                    "training from scratch...")

            trained_epochs = 0
            epochs = self.epochs

        if self.multi_scale:
            self.logger.info(f"Initializing discriminator with "
                             f"batch_size: {self.batch_size}, input_size: multi-scale...")
        else:
            self.logger.info(f"Initializing discriminator with "
                             f"batch_size: {self.batch_size}, input_size: {self.input_size}...")
        if self.debug:
            d_base_filters = 2
        elif self.light:
            d_base_filters = 24
        else:
            d_base_filters = 32
        d = Discriminator(base_filters=d_base_filters)
        d(tf.keras.Input(
            shape=(self.input_size, self.input_size, 3),
            batch_size=self.batch_size))

        self.logger.info("Searching existing checkpoints: "
                         f"`{self.discriminator_checkpoint_prefix}`...")
        try:
            # NOTE(review): key is `d=` here but `generator=` above — the keys
            # must match whatever was used when these checkpoints were saved;
            # confirm against the saving code.
            d_checkpoint = tf.train.Checkpoint(d=d)
            d_checkpoint.restore(
                tf.train.latest_checkpoint(
                    self.discriminator_checkpoint_dir)).assert_existing_objects_matched()
            self.logger.info(f"Previous checkpoints has been restored.")
        except AssertionError:
            self.logger.info("specified checkpoint is not found, training from scratch...")

        if not self.disable_sampling:
            # Fixed real batches (train + validation) reused for periodic
            # sample-image summaries throughout training.
            val_files = glob(os.path.join(
                self.data_dir, self.dataset_name, f"test{self.source_domain}", "*"))
            val_real_batch = tf.map_fn(
                lambda fname: self.image_processing(fname, False),
                tf.constant(val_files), tf.float32, back_prop=False)
            real_batch = next(ds_source)
            while real_batch.shape[0] < self.sample_size:
                real_batch = tf.concat((real_batch, next(ds_source)), 0)
            real_batch = real_batch[:self.sample_size]
            with summary_writer.as_default():
                # Images are in [-1, 1]; rescale to uint8 [0, 255] for display.
                img = np.expand_dims(self._save_generated_images(
                    tf.cast((real_batch + 1) * 127.5, tf.uint8),
                    image_name="gan_sample_images.png"), 0,)
                tf.summary.image("gan_sample_images", img, step=0)
                img = np.expand_dims(self._save_generated_images(
                    tf.cast((val_real_batch + 1) * 127.5, tf.uint8),
                    image_name="gan_val_sample_images.png"), 0,)
                tf.summary.image("gan_val_sample_images", img, step=0)
            gc.collect()
        else:
            self.logger.info("Proceeding training without sample images...")

        self.logger.info("Starting training loop...")

        self.logger.info(f"Number of trained epochs: {trained_epochs}, "
                         f"epochs to be trained: {epochs}, "
                         f"batch size: {self.batch_size}")
        for epoch in range(epochs):
            epoch_idx = trained_epochs + epoch + 1

            for step in tqdm(
                    range(1, steps_per_epoch + 1),
                    desc=f'Train {epoch + 1}/{epochs}',
                    total=steps_per_epoch):
                source_images, target_images, smooth_images = (
                    ds_source.next(), ds_target.next(), ds_smooth.next())
                self.train_step(source_images, target_images, smooth_images,
                                g, d, g_optimizer, d_optimizer)

                if step % self.reporting_steps == 0:

                    global_step = (epoch_idx - 1) * steps_per_epoch + step
                    with summary_writer.as_default():
                        for metric, name in self.metric_and_names:
                            tf.summary.scalar(name, metric.result(), step=global_step)
                            metric.reset_states()
                        if not self.disable_sampling:
                            fake_batch = tf.cast(
                                (g(real_batch, training=False) + 1) * 127.5, tf.uint8)
                            img = np.expand_dims(self._save_generated_images(
                                    fake_batch,
                                    image_name=("gan_generated_images_at_epoch_"
                                                f"{epoch_idx}_step_{step}.png")),
                                    0,
                            )
                            tf.summary.image('gan_generated_images', img, step=global_step)

                    self.logger.debug(f"Epoch {epoch_idx}, Step {step} finished, "
                                      f"{global_step * self.batch_size} images processed.")

            # End-of-epoch validation samples; `step` here is the loop
            # variable left over from the inner loop (== steps_per_epoch).
            with summary_writer.as_default():
                if not self.disable_sampling:
                    val_fake_batch = tf.cast(
                        (g(val_real_batch, training=False) + 1) * 127.5, tf.uint8)
                    img = np.expand_dims(self._save_generated_images(
                            val_fake_batch,
                            image_name=("gan_val_generated_images_at_epoch_"
                                        f"{epoch_idx}_step_{step}.png")),
                            0,
                    )
                    tf.summary.image('gan_val_generated_images', img, step=epoch)
            self.logger.info(f"Saving checkpoints after epoch {epoch_idx} ended...")
            g_checkpoint.save(file_prefix=self.generator_checkpoint_prefix)
            d_checkpoint.save(file_prefix=self.discriminator_checkpoint_prefix)

            # Also export plain weights for inference use.
            g.save_weights(os.path.join(self.model_dir, "generator"))
            gc.collect()
        # Drop dataset iterators explicitly to release their buffers.
        del ds_source, ds_target, ds_smooth
        gc.collect()
Пример #9
0
def train(args):
    """Train a GAN (generator + discriminator) with NNabla, with EMA tracking.

    Builds the static computation graphs for the generator/discriminator
    losses, a fixed-latent test generator, and an EMA copy of the generator;
    then runs the alternating D/G update loop, monitoring losses and image
    tiles and saving parameters at regular intervals and at the end.

    Args:
        args: parsed namespace providing context/device settings, model and
            data hyperparameters (batch_size, latent, image_size, lr,
            aug_list, img_path, train_samples), and loop/save intervals
            (max_iter, save_interval, test_interval, monitor_path).
    """
    # Context
    ctx = get_extension_context(
        args.context, device_id=args.device_id, type_config=args.type_config)
    nn.set_default_context(ctx)

    aug_list = args.aug_list

    # Model
    scope_gen = "Generator"
    scope_dis = "Discriminator"
    # generator loss
    z = nn.Variable([args.batch_size, args.latent, 1, 1])
    # NOTE(review): Generator appears to return multiple outputs (x_fake is
    # indexed [0]/[1] below and iterated here) — presumably images at two
    # scales; confirm against the Generator definition.
    x_fake = Generator(z, scope_name=scope_gen, img_size=args.image_size)
    p_fake = Discriminator([augment(xf, aug_list)
                            for xf in x_fake], label="fake", scope_name=scope_dis)
    lossG = loss_gen(p_fake)
    # discriminator loss
    x_real = nn.Variable(
        [args.batch_size, 3, args.image_size, args.image_size])
    x_real_aug = augment(x_real, aug_list)
    # For real inputs the discriminator also returns reconstruction outputs
    # used by the auxiliary self-supervised term in loss_dis_real.
    p_real, rec_imgs, part = Discriminator(
        x_real_aug, label="real", scope_name=scope_dis)
    lossD_fake = loss_dis_fake(p_fake)
    lossD_real = loss_dis_real(p_real, rec_imgs, part, x_real_aug)
    lossD = lossD_fake + lossD_real
    # generator with fixed latent values for test
    # Use train=True even in an inference phase
    z_test = nn.Variable.from_numpy_array(
        np.random.randn(args.batch_size, args.latent, 1, 1))
    x_test = Generator(z_test, scope_name=scope_gen,
                       train=True, img_size=args.image_size)[0]

    # Exponential Moving Average (EMA) model
    # Use train=True even in an inference phase
    scope_gen_ema = "Generator_EMA"
    x_test_ema = Generator(z_test, scope_name=scope_gen_ema,
                           train=True, img_size=args.image_size)[0]
    # Initialize the EMA scope from the live generator, then build the graph
    # that shifts EMA params toward the live ones with decay 0.999.
    copy_params(scope_gen, scope_gen_ema)
    update_ema_var = make_ema_updater(scope_gen_ema, scope_gen, 0.999)

    # Solver
    solver_gen = S.Adam(args.lr, beta1=0.5)
    solver_dis = S.Adam(args.lr, beta1=0.5)
    with nn.parameter_scope(scope_gen):
        params_gen = nn.get_parameters()
        solver_gen.set_parameters(params_gen)
    with nn.parameter_scope(scope_dis):
        params_dis = nn.get_parameters()
        solver_dis.set_parameters(params_dis)

    # Monitor
    monitor = Monitor(args.monitor_path)
    monitor_loss_gen = MonitorSeries(
        "Generator Loss", monitor, interval=10)
    monitor_loss_dis_real = MonitorSeries(
        "Discriminator Loss Real", monitor, interval=10)
    monitor_loss_dis_fake = MonitorSeries(
        "Discriminator Loss Fake", monitor, interval=10)
    monitor_time = MonitorTimeElapsed(
        "Training Time", monitor, interval=10)
    # Images are generated in [-1, 1]; normalize_method maps them to [0, 1].
    monitor_image_tile_train = MonitorImageTile("Image Tile Train", monitor,
                                                num_images=args.batch_size,
                                                interval=1,
                                                normalize_method=lambda x: (x + 1.) / 2.)
    monitor_image_tile_test = MonitorImageTile("Image Tile Test", monitor,
                                               num_images=args.batch_size,
                                               interval=1,
                                               normalize_method=lambda x: (x + 1.) / 2.)
    monitor_image_tile_test_ema = MonitorImageTile("Image Tile Test EMA", monitor,
                                                   num_images=args.batch_size,
                                                   interval=1,
                                                   normalize_method=lambda x: (x + 1.) / 2.)

    # Data Iterator
    rng = np.random.RandomState(141)
    di = data_iterator(args.img_path, args.batch_size,
                       imsize=(args.image_size, args.image_size),
                       num_samples=args.train_samples, rng=rng)

    # Train loop
    for i in range(args.max_iter):
        # Train discriminator
        # Block gradient flow through the fake images so the D update does
        # not touch generator parameters.
        x_fake[0].need_grad = False  # no need backward to generator
        x_fake[1].need_grad = False  # no need backward to generator
        solver_dis.zero_grad()
        x_real.d = di.next()[0]
        z.d = np.random.randn(args.batch_size, args.latent, 1, 1)
        lossD.forward()
        lossD.backward()
        solver_dis.update()

        # Train generator
        # Re-enable gradients so lossG.backward() reaches generator params.
        x_fake[0].need_grad = True  # need backward to generator
        x_fake[1].need_grad = True  # need backward to generator
        solver_gen.zero_grad()
        lossG.forward()
        lossG.backward()
        solver_gen.update()

        # Update EMA model
        update_ema_var.forward()

        # Monitor
        monitor_loss_gen.add(i, lossG.d)
        monitor_loss_dis_real.add(i, lossD_real.d)
        monitor_loss_dis_fake.add(i, lossD_fake.d)
        monitor_time.add(i)

        # Save
        if (i+1) % args.save_interval == 0:
            with nn.parameter_scope(scope_gen):
                nn.save_parameters(os.path.join(
                    args.monitor_path, "Gen_iter{}.h5".format(i+1)))
            with nn.parameter_scope(scope_gen_ema):
                nn.save_parameters(os.path.join(
                    args.monitor_path, "GenEMA_iter{}.h5".format(i+1)))
            with nn.parameter_scope(scope_dis):
                nn.save_parameters(os.path.join(
                    args.monitor_path, "Dis_iter{}.h5".format(i+1)))
        if (i+1) % args.test_interval == 0:
            x_test.forward(clear_buffer=True)
            x_test_ema.forward(clear_buffer=True)
            monitor_image_tile_train.add(i+1, x_fake[0])
            monitor_image_tile_test.add(i+1, x_test)
            monitor_image_tile_test_ema.add(i+1, x_test_ema)

    # Last
    # Final evaluation and parameter dump after the loop completes.
    x_test.forward(clear_buffer=True)
    x_test_ema.forward(clear_buffer=True)
    monitor_image_tile_train.add(args.max_iter, x_fake[0])
    monitor_image_tile_test.add(args.max_iter, x_test)
    monitor_image_tile_test_ema.add(args.max_iter, x_test_ema)
    with nn.parameter_scope(scope_gen):
        nn.save_parameters(os.path.join(args.monitor_path,
                                        "Gen_iter{}.h5".format(args.max_iter)))
    with nn.parameter_scope(scope_gen_ema):
        nn.save_parameters(os.path.join(args.monitor_path,
                                        "GenEMA_iter{}.h5".format(args.max_iter)))
    with nn.parameter_scope(scope_dis):
        nn.save_parameters(os.path.join(args.monitor_path,
                                        "Dis_iter{}.h5".format(args.max_iter)))
Пример #10
0
def main(pretrain_dataset, rl_dataset, args):
    """Run the full SeqGAN-style pipeline.

    Stages: (1) MLE pretraining of the generator, (2) pretraining of the
    discriminator against real data and generator samples, (3) adversarial
    RL training.  Cached pretrained weights are reused unless the
    corresponding ``--force-pretrain*`` flag is set.

    Args:
        pretrain_dataset: corpus used for generator MLE pretraining.
        rl_dataset: corpus used for discriminator / adversarial training.
        args: parsed CLI arguments (cuda flag, force_pretrain switches, ...).
    """

    def banner(title):
        # Section header — emits the same three lines as the original code.
        print('#' * 80)
        print(title)
        print('#' * 80)

    # Seed the Python and NumPy RNGs for reproducibility.
    random.seed(const.SEED)
    np.random.seed(const.SEED)

    # Train / validation loaders over the pretraining corpus.
    pt_train_loader, pt_valid_loader = SplitDataLoader(
        pretrain_dataset, batch_size=const.BATCH_SIZE, drop_last=True).split()

    # Build the two networks.
    generator = Generator(const.VOCAB_SIZE, const.GEN_EMBED_DIM,
                          const.GEN_HIDDEN_DIM, device, args.cuda)
    discriminator = Discriminator(const.VOCAB_SIZE, const.DSCR_EMBED_DIM,
                                  const.DSCR_FILTER_LENGTHS,
                                  const.DSCR_NUM_FILTERS,
                                  const.DSCR_NUM_CLASSES, const.DSCR_DROPOUT)
    generator.to(device)
    discriminator.to(device)

    # Move both models to CUDA when requested and available.
    if args.cuda and torch.cuda.is_available():
        generator = generator.cuda()
        discriminator = discriminator.cuda()

    # ------------------------------------------------------------------
    # Generator pretraining (MLE) — load from cache when possible.
    # ------------------------------------------------------------------
    banner('Generator Pretraining')
    force_gen = args.force_pretrain or args.force_pretrain_gen
    if op.exists(GEN_MODEL_CACHE) and not force_gen:
        print('Loading Pretrained Generator ...')
        state = torch.load(GEN_MODEL_CACHE)
        generator.load_state_dict(state['state_dict'])
        print('::INFO:: DateTime - %s.' % state['datetime'])
        print('::INFO:: Model was trained for %d epochs.' %
              state['epochs'])
        print('::INFO:: Final Training Loss - %.5f' % state['train_loss'])
        print('::INFO:: Final Validation Loss - %.5f' %
              state['valid_loss'])
    else:
        try:
            print('Pretraining Generator with MLE ...')
            GeneratorPretrainer(generator, pt_train_loader, pt_valid_loader,
                                PT_CACHE_DIR, device, args).train()
        except KeyboardInterrupt:
            print('Stopped Generator Pretraining Early.')

    # ------------------------------------------------------------------
    # Discriminator pretraining — load from cache when possible.
    # ------------------------------------------------------------------
    banner('Discriminator Pretraining')
    force_dscr = args.force_pretrain or args.force_pretrain_dscr
    if op.exists(DSCR_MODEL_CACHE) and not force_dscr:
        print("Loading Pretrained Discriminator ...")
        state = torch.load(DSCR_MODEL_CACHE)
        discriminator.load_state_dict(state['state_dict'])
        print('::INFO:: DateTime - %s.' % state['datetime'])
        print('::INFO:: Model was trained on %d data generations.' %
              state['data_gens'])
        print('::INFO:: Model was trained for %d epochs per data generation.' %
              state['epochs_per_gen'])
        print('::INFO:: Final Loss - %.5f' % state['loss'])
    else:
        print('Pretraining Discriminator ...')
        try:
            DiscriminatorPretrainer(discriminator, rl_dataset, PT_CACHE_DIR,
                                    TEMP_DATA_DIR, device,
                                    args).train(generator)
        except KeyboardInterrupt:
            print('Stopped Discriminator Pretraining Early.')

    # ------------------------------------------------------------------
    # Adversarial training.
    # ------------------------------------------------------------------
    banner('Adversarial Training')
    AdversarialRLTrainer(generator, discriminator, rl_dataset, TEMP_DATA_DIR,
                         pt_valid_loader, device, args).train()
Пример #11
0
    def build(self):
        """Build the CycleGAN-style TF1 graph for exactly two datasets.

        For each ordered dataset pair the method wires a forward Generator,
        a reconstruction ("back-to") Generator, and real/fake Discriminators,
        then assembles adversarial + reconstruction losses and the G/D
        optimizer steps.  Results are stored in ``self.nets``, ``self.input``,
        ``self.output``, ``self.losses``, ``self.optimizers`` and
        ``self.savers``.
        """
        # NOTE(review): this Variable is immediately superseded by
        # tf.train.get_or_create_global_step() below and appears unused.
        self.global_step = tf.Variable(0, trainable=False, name='global_step')

        self.nets = OrderedDict()
        datasets = self.config["datasets"]

        # We suppose there are only two datasets; for more, the same
        # construction would have to be repeated for each pair.

        self.G_vars = []
        self.D_vars = []
        # The empty-tuple key collects pair-independent (global) entries.
        self.output[()] = []
        self.losses[()] = []
        self.optimizers[()] = []

        self.global_step = tf.train.get_or_create_global_step()
        # Disabled alternative learning-rate schedules, kept for reference.
        """lr = tf.train.noisy_linear_cosine_decay(self.config['init_lr'],
                            self.global_step,
                            self.config['decay_steps_lr'],
                            num_periods=self.config['cycles_lr'],
                            initial_variance=0.001,
                            alpha=0.0,
                            beta=0.1,
                                name="learning_rate")"""
        # Active schedule: a constant learning rate from the config.
        lr = tf.constant(self.config['init_lr'])
        """lr = tf.multiply(
                tf.multiply(
                    1e-10,
                    tf.pow(
                        10e0,
                        tf.multiply(
                            tf.cast(
                                tf.div(
                                    self.global_step,
                                    10
                                ),
                                tf.float32
                            ),
                            1e-1
                        )
                    )
                ),
                tf.cast(
                    tf.mod(
                        tf.add(
                            self.global_step,
                            1
                        ),
                        2
                    ),
                    tf.float32
                )
        )"""

        self.nets["learning_rate"] = lr
        self.lr = lr
        # Log the learning rate as a pair-independent "loss" entry.
        self.losses[()] = [("learning_rate", self.lr)]

        # Per-dataset input placeholders: one for the pianoroll-like tensor,
        # one for the sequence lengths.
        for d in datasets:
            self.nets[datasets[d]['id']] = {}

            self.output[(datasets[d]["id"], )] = []
            self.losses[(datasets[d]["id"], )] = []
            self.optimizers[(datasets[d]["id"], )] = []

            data_shape = (datasets[d]['sizes']['batch_size'], None,
                          datasets[d]['sizes']['num_timesteps'],
                          datasets[d]['sizes']['num_pitch'],
                          datasets[d]['sizes']['num_track'])
            t_input = tf.placeholder(tf.float32,
                                     data_shape,
                                     name='in_ph_' +
                                     datasets[d]['id'])  # main data input (float placeholder; values presumably boolean-like — TODO confirm)
            t_seqlen = tf.placeholder(tf.float32,
                                      [datasets[d]['sizes']['batch_size']],
                                      name='len_ph_' + datasets[d]['id'])

            self.input[datasets[d]['id']] = (t_input, t_seqlen)

        # m alternates 0/1 per pair to spread the two generators over
        # /gpu:1 and /gpu:2.
        m = 1
        for pair in permutations(datasets, 2):
            self.output[(datasets[pair[0]]["id"],
                         datasets[pair[1]]["id"])] = []
            self.losses[(datasets[pair[0]]["id"],
                         datasets[pair[1]]["id"])] = []
            self.optimizers[(datasets[pair[0]]["id"],
                             datasets[pair[1]]["id"])] = []

            x_ = self.input[datasets[pair[0]]['id']][0]

            # Forward generator: pair[0] -> pair[1].
            with tf.device('/gpu:' + str(1 + m)):
                config = deepcopy(self.config)
                config["gen_from_ds"] = datasets[pair[0]]["sizes"]
                config["gen_to_ds"] = datasets[pair[1]]["sizes"]
                self.nets[pair[0] + "_to_" + pair[1]] = Generator(
                    self.input[datasets[pair[0]]['id']][0],
                    self.input[datasets[pair[0]]['id']][1],
                    config,
                    name=pair[0] + "_to_" + pair[1],
                    reuse=tf.AUTO_REUSE)

            # Reconstruction generator: maps the fake sample back to pair[0]
            # (cycle consistency).  Note the from/to size swap.
            with tf.device('/gpu:' + str(1 + int(not m))):
                config["gen_from_ds"], config["gen_to_ds"] = config[
                    "gen_to_ds"], config["gen_from_ds"]
                self.nets[pair[1] + "_backto_" + pair[0]] = Generator(
                    self.nets[pair[0] + "_to_" + pair[1]].tensor_out,
                    self.nets[pair[0] + "_to_" + pair[1]].tensor_len,
                    config,
                    name=pair[1] + "_to_" + pair[0],
                    reuse=tf.AUTO_REUSE)

            # Expose input / mapped / reconstructed tensors for sampling.
            self.output[(datasets[pair[0]]["id"], )] += [
                (datasets[pair[0]]["id"], "input",
                 self.input[datasets[pair[0]]['id']][0]),
                (datasets[pair[1]]["id"], "sample_mapped",
                 self.nets[pair[0] + "_to_" + pair[1]].tensor_out),
                (datasets[pair[0]]["id"], "sample_recon",
                 self.nets[pair[1] + "_backto_" + pair[0]].tensor_out),
            ]

            # Discriminator on real pair[0] samples (shared via AUTO_REUSE).
            with tf.device('/gpu:' + str(0)):  # originally '/gpu:' + str(1+int(not m))
                config = deepcopy(self.config)
                config["dis_ds"] = datasets[pair[0]]["sizes"]
                self.nets[pair[0] + "_to_" + pair[1] +
                          "_real"] = Discriminator(
                              self.input[datasets[pair[0]]['id']][0],
                              self.input[datasets[pair[0]]['id']][1],
                              config,
                              name=pair[0] + "_D",
                              reuse=tf.AUTO_REUSE)
                if self.nets.get(pair[0] + "_dis") is None:
                    self.nets[pair[0] + "_dis"] = self.nets[pair[0] + "_to_" +
                                                            pair[1] + "_real"]

            # Discriminator on the generated (fake) pair[1] samples.
            with tf.device('/gpu:' + str(0)):  # originally '/gpu:' + str(1+m)
                config["dis_ds"] = datasets[pair[1]]["sizes"]
                self.nets[pair[0] + "_to_" + pair[1] +
                          "_fake"] = Discriminator(
                              self.nets[pair[0] + "_to_" + pair[1]].tensor_out,
                              self.nets[pair[0] + "_to_" + pair[1]].tensor_len,
                              config,
                              name=pair[1] + '_D',
                              reuse=tf.AUTO_REUSE)
                if self.nets.get(pair[1] + "_dis") is None:
                    self.nets[pair[1] + "_dis"] = self.nets[pair[0] + "_to_" +
                                                            pair[1] + "_fake"]

            m = int(not m)

        # Accumulate total generator / discriminator losses over both
        # directions.
        self.nets["G_loss"] = 0
        self.nets["D_loss"] = 0

        ds = list(datasets)
        if len(ds) != 2:
            raise ValueError('It must be size 2')
        for i in range(len(ds)):
            self.D_real = self.nets[ds[i] + "_to_" + ds[not i] + "_real"]
            self.D_fake = self.nets[ds[not i] + "_to_" + ds[i] + "_fake"]
            self.G = self.nets[ds[not i] + "_to_" + ds[i]]
            self.G_inv = self.nets[ds[i] + "_backto_" + ds[not i]]
            self.G_vars += self.G.vars
            self.D_vars += self.D_fake.vars
            # Losses: adversarial (per direction) + cycle reconstruction.
            self.config["dis_ds"] = datasets[ds[i]]["sizes"]
            self.x_ = self.input[datasets[ds[i]]['id']][0]
            self.g_loss, self.d_loss = self.get_adversarial_loss(
                Discriminator, name=ds[i] + "_D")  # Loss_(GAN_i), Loss_(D_i)
            self.cons_loss = self.get_reconstruction_loss()  # Loss_(CONS_j)
            self.nets[
                "G_loss"] = self.nets["G_loss"] + self.g_loss + self.cons_loss
            self.nets["D_loss"] = self.nets["D_loss"] + self.d_loss

            self.losses[(datasets[ds[i]]["id"],
                         datasets[ds[not i]]["id"])] += [(ds[i] + "_" + "G",
                                                          self.g_loss)]
            self.losses[(datasets[ds[i]]["id"],
                         datasets[ds[not i]]["id"])] += [(ds[i] + "_" + "D",
                                                          self.d_loss)]
            self.losses[(datasets[ds[not i]]["id"], )] += [
                (ds[not i] + "_CONS", self.cons_loss)
            ]
            # log(sum + 1) density monitors for mapped and reconstructed
            # outputs (key spelling "densitity" kept as-is: runtime string).
            self.losses[(datasets[ds[not i]]["id"], )] += [
                (ds[i] + "_mapped_densitity",
                 tf.log(tf.reduce_sum(self.G.tensor_out) + 1))
            ]
            self.losses[(datasets[ds[not i]]["id"], )] += [
                (ds[not i] + "_recon_densitity",
                 tf.log(tf.reduce_sum(self.G_inv.tensor_out) + 1))
            ]
            self.losses[(datasets[ds[i]]["id"],
                         datasets[ds[not i]]["id"])] += [
                             (datasets[ds[i]]["id"] + "_diff",
                              tf.reduce_mean(self.x_) -
                              tf.reduce_mean(self.G.tensor_out))
                         ]
        # NOTE(review): `i` below is the leaked loop variable (== 1 after the
        # loop); the totals are registered under the last pair's key only.
        self.losses[(datasets[ds[i]]["id"], datasets[ds[not i]]["id"])] += [
            ("G", self.nets["G_loss"])
        ]
        self.losses[(datasets[ds[i]]["id"], datasets[ds[not i]]["id"])] += [
            ("D", self.nets["D_loss"])
        ]

        self.g_loss = self.nets["G_loss"]
        self.d_loss = self.nets["D_loss"]

        # Optimizers: one minimizer each for the combined G and D losses.
        with tf.variable_scope('Optimizer'):
            self.g_optimizer = self.get_optimizer()
            self.g_step = self.g_optimizer.minimize(self.g_loss,
                                                    self.global_step,
                                                    self.G_vars)

            self.d_optimizer = self.get_optimizer()
            self.d_step = self.d_optimizer.minimize(self.d_loss,
                                                    self.global_step,
                                                    self.D_vars)

            # Apply weight clipping (WGAN critic constraint) after each
            # discriminator step.
            if self.config['gan']['type'] == 'wgan':
                with tf.control_dependencies([self.d_step]):
                    self.d_step = tf.group(*(tf.assign(
                        var,
                        tf.clip_by_value(var,
                                         -self.config['gan']['clip_value'],
                                         self.config['gan']['clip_value']))
                                             for var in self.D_vars))

            # Again keyed by the leaked `i` — see NOTE(review) above.
            self.optimizers[((datasets[ds[i]]["id"],
                              datasets[ds[not i]]["id"]))] += [(0, self.d_step)
                                                               ]
            self.optimizers[((datasets[ds[i]]["id"],
                              datasets[ds[not i]]["id"]))] += [(1, self.g_step)
                                                               ]

        # Saver over the whole graph.
        self.saver = tf.train.Saver()
        self.savers["global"] = self.saver
        # Op that re-initializes all trainable variables (weight reset).
        self.reset_weights = tf.variables_initializer(
            var_list=tf.trainable_variables())
Пример #12
0
def run_tensorflow():
    """Train a StarGAN-style human<->anime translator plus an upsampler.

    Sets up GPU memory growth and mixed-precision policy, builds a single
    multi-domain generator/discriminator pair (domain selected by a constant
    channel appended to the input) and a separate upsampling GAN, then loops
    forever: every 5th step it logs images/scalars to TensorBoard and saves
    a checkpoint.  (Original note: running TF in a function frees GPU RAM.)
    """

    gpus = tf.config.experimental.list_physical_devices("GPU")
    if gpus:
        try:
            # Currently, memory growth needs to be the same across GPUs
            for gpu in gpus:
                tf.config.experimental.set_memory_growth(gpu, True)
            logical_gpus = tf.config.experimental.list_logical_devices("GPU")
            print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
        except RuntimeError as e:
            # Memory growth must be set before GPUs have been initialized
            print(e)

    # Global mixed-precision (float16 compute) policy.
    mixed_precision = tf.keras.mixed_precision.experimental

    policy = mixed_precision.Policy("mixed_float16")
    mixed_precision.set_policy(policy)

    # `batch_size` is a module-level global — TODO confirm where it is set.
    AnimeCleanData = getAnimeCleanData(BATCH_SIZE=batch_size)
    CelebaData = getCelebaData(BATCH_SIZE=batch_size)

    logdir = "./logs/Startrain_data/" + datetime.now().strftime("%Y%m%d-%H%M%S")
    file_writer = tf.summary.create_file_writer(logdir)

    # Loss-scaled Adam optimizers (required so float16 gradients don't
    # underflow): one pair for the main GAN, one pair for the upsampler GAN.
    generator_optimizer = mixed_precision.LossScaleOptimizer(
        tf.keras.optimizers.Adam(2e-4, beta_1=0.5), loss_scale="dynamic"
    )

    discriminator_optimizer = mixed_precision.LossScaleOptimizer(
        tf.keras.optimizers.Adam(2e-4, beta_1=0.5), loss_scale="dynamic"
    )

    up_G_optim = mixed_precision.LossScaleOptimizer(
        tf.keras.optimizers.Adam(2e-4, beta_1=0.5), loss_scale="dynamic"
    )
    up_D_optim = mixed_precision.LossScaleOptimizer(
        tf.keras.optimizers.Adam(2e-4, beta_1=0.5), loss_scale="dynamic"
    )
    up_G = UpsampleGenerator()
    up_D = Discriminator()

    generator = GeneratorV2()
    # input: Batch, 256,256,3
    discriminator = StarDiscriminator()

    checkpoint_path = "./checkpoints/StarTrain"

    # NOTE(review): the checkpoint covers only the main GAN; up_G/up_D and
    # their optimizers are NOT saved or restored.
    ckpt = tf.train.Checkpoint(
        generator = generator,
        discriminator = discriminator,
        generator_optimizer = generator_optimizer,
        discriminator_optimizer = discriminator_optimizer,

    )

    ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

    # if a checkpoint exists, restore the latest checkpoint.
    if ckpt_manager.latest_checkpoint:
        ckpt.restore(ckpt_manager.latest_checkpoint)
        print("Latest checkpoint restored!!")

    # out: Batch, 16, 16, 1
    # x is human, y is anime
    @tf.function
    def trainstep(real_human, real_anime, big_anime):
        """One optimization step for the main GAN and the upsampler GAN."""
        with tf.GradientTape(persistent=True) as tape:
            # Domain conditioning: append a constant 4th channel
            # (+1 = anime target, -1 = human target).
            ones = tf.ones_like(real_human)
            neg_ones = tf.ones_like(real_human) * -1

            def get_domain_anime(img):
                return tf.concat([img, ones], 3)

            def get_domain_human(img):
                return tf.concat([img, neg_ones], 3)

            # Cycle: human -> anime -> human, and anime -> human -> anime.
            fake_anime = generator(get_domain_anime(real_human), training=True)
            cycled_human = generator(get_domain_human(fake_anime), training=True)

            fake_human = generator(get_domain_human(real_anime), training=True)
            cycled_anime = generator(get_domain_anime(fake_human), training=True)

            # same_human and same_anime are used for identity loss.
            same_human = generator(get_domain_human(real_human), training=True)
            same_anime = generator(get_domain_anime(real_anime), training=True)

            # Discriminator returns (real/fake logit, domain-class logit).
            disc_real_human, label_real_human = discriminator(real_human, training=True)
            disc_real_anime, label_real_anime = discriminator(real_anime, training=True)

            disc_fake_human, label_fake_human = discriminator(fake_human, training=True)
            disc_fake_anime, label_fake_anime = discriminator(fake_anime, training=True)

            _, label_cycled_human = discriminator(cycled_human, training=True)
            _, label_cycled_anime = discriminator(cycled_anime, training=True)

            _, label_same_human = discriminator(same_human, training=True)
            _, label_same_anime = discriminator(same_anime, training=True)

            # calculate the loss
            gen_anime_loss = generator_loss(disc_fake_anime)
            gen_human_loss = generator_loss(disc_fake_human)

            total_cycle_loss = cycle_loss(real_human, cycled_human) + cycle_loss(
                real_anime, cycled_anime
            )

            # Domain-classification loss on generated/cycled/identity images
            # (reuses discriminator_loss on the class logits).
            gen_class_loss = (
                discriminator_loss(label_fake_human, label_fake_anime)
                + discriminator_loss(label_cycled_human, label_cycled_anime)
                + discriminator_loss(label_same_human, label_same_anime)
            )

            # Total generator loss = adversarial loss + cycle loss
            total_gen_loss = (
                gen_anime_loss
                + gen_human_loss
                + gen_class_loss
                + total_cycle_loss * 0.1
                + identity_loss(real_anime, same_anime)
                + identity_loss(real_human, same_human)
            )

            tf.print("gen_anime_loss",gen_anime_loss)
            tf.print("gen_human_loss",gen_human_loss)
            tf.print("gen_class_loss",gen_class_loss)
            tf.print("total_cycle_loss",total_cycle_loss)
            tf.print("identity_loss(real_anime, same_anime)",identity_loss(real_anime, same_anime))
            tf.print("identity_loss(real_human, same_human)",identity_loss(real_human, same_human))

            # Scale the loss so float16 gradients don't underflow.
            scaled_total_gen_anime_loss = generator_optimizer.get_scaled_loss(
                total_gen_loss
            )

            disc_human_loss = discriminator_loss(disc_real_human, disc_fake_human)
            disc_anime_loss = discriminator_loss(disc_real_anime, disc_fake_anime)

            # disc_gp_anime = gradient_penalty_star(partial(discriminator, training=True), real_anime,fake_anime )
            # disc_gp_human = gradient_penalty_star(partial(discriminator, training=True), real_human,fake_human )

            disc_loss = disc_human_loss + disc_anime_loss + discriminator_loss(label_real_human,label_real_anime)
            # +disc_gp_anime+disc_gp_human

            scaled_disc_loss = discriminator_optimizer.get_scaled_loss(
                disc_loss
            )

        # Calculate the gradients for generator and discriminator
        # (unscaling reverses the loss scaling above).
        generator_gradients =generator_optimizer.get_unscaled_gradients( tape.gradient(
            scaled_total_gen_anime_loss, generator.trainable_variables
        ))
        discriminator_gradients = discriminator_optimizer.get_unscaled_gradients( tape.gradient(
            scaled_disc_loss, discriminator.trainable_variables
        ))

        generator_optimizer.apply_gradients(
            zip(generator_gradients, generator.trainable_variables)
        )

        discriminator_optimizer.apply_gradients(
            zip(discriminator_gradients, discriminator.trainable_variables)
        )

        # Second tape: train the upsampler GAN on top of the (already
        # updated-this-step) main generator's anime outputs.
        with tf.GradientTape(persistent=True) as tape:
            real_anime_up = up_G(real_anime)
            fake_anime_up = up_G(fake_anime)

            dis_fake_anime_up = up_D(fake_anime_up)
            dis_real_anime_up = up_D(real_anime_up)
            dis_ori_anime = up_D(big_anime)
            # NOTE(review): generator_loss(fake_anime_up) is fed the *image*,
            # not the discriminator output dis_fake_anime_up — looks like a
            # bug; confirm against generator_loss's expected input.
            gen_up_loss =  generator_loss(fake_anime_up) + generator_loss(dis_real_anime_up)*0.1
            dis_up_loss = discriminator_loss(dis_ori_anime,dis_fake_anime_up)+discriminator_loss(dis_ori_anime,dis_real_anime_up)*0.1
            scaled_gen_up_loss = up_G_optim.get_scaled_loss(gen_up_loss)
            scaled_disc_loss = up_D_optim.get_scaled_loss(dis_up_loss)

        up_G_gradients =up_G_optim.get_unscaled_gradients( tape.gradient(
            scaled_gen_up_loss, up_G.trainable_variables
        ))
        up_D_gradients = up_D_optim.get_unscaled_gradients( tape.gradient(
            scaled_disc_loss, up_D.trainable_variables
        ))

        up_G_optim.apply_gradients(
            zip(up_G_gradients, up_G.trainable_variables)
        )

        up_D_optim.apply_gradients(
            zip(up_D_gradients, up_D.trainable_variables)
        )


        # First 10 entries are images, the rest scalars — the logging loop
        # below relies on this ordering (j < 10 split).
        return (
            real_human,
            real_anime,
            fake_anime,
            cycled_human,
            fake_human,
            cycled_anime,
            same_human,
            same_anime,
            fake_anime_up,
            real_anime_up,
            gen_anime_loss,
            gen_human_loss,
            disc_human_loss,
            disc_anime_loss,
            gen_up_loss,
            dis_up_loss
        )

    def process_data_for_display(input_image):
        # Map [-1, 1] images back to [0, 1] for TensorBoard display.
        return input_image * 0.5 + 0.5


    # Names aligned one-to-one with trainstep's return tuple.
    print_string = [
            "real_human",
            "real_anime",
            "fake_anime",
            "cycled_human",
            "fake_human",
            "cycled_anime",
            "same_human",
            "same_anime",
            "fake_anime_up",
            "real_anime_up",
            "gen_anime_loss",
            "gen_human_loss",
            "disc_human_loss",
            "disc_anime_loss",
            "gen_up_loss",
            "dis_up_loss"
    ]

    # Endless training loop; stop externally (Ctrl-C).
    counter = 0
    i = -1
    while True:
        i = i + 1
        counter = counter + 1
        # NOTE(review): next(iter(ds)) creates a fresh iterator each step —
        # verify the datasets reshuffle as intended under this pattern.
        AnimeBatchImage, BigAnimeBatchImage = next(iter(AnimeCleanData))
        CelebaBatchImage = next(iter(CelebaData))
        print(counter)

        # Every 5th step (i % 5 == 0): log results and save a checkpoint.
        if not (i % 5):
            result = trainstep(CelebaBatchImage, AnimeBatchImage,BigAnimeBatchImage)

            with file_writer.as_default():
                for j in range(len(result)):
                    if j<10:
                        tf.summary.image(
                        print_string[j],
                        process_data_for_display(result[j]),
                        step=counter,
                        )
                    else:
                        tf.summary.scalar(
                        print_string[j],
                        result[j],
                        step=counter,
                        )

            ckpt_manager.save()
        else:
            trainstep(CelebaBatchImage, AnimeBatchImage,BigAnimeBatchImage)
Пример #13
0
    shutil.rmtree(args.modelName)
    os.makedirs(args.modelName)

# for handling training over GPU: fall back to CPU unless --use_gpu is set
cpu_device = torch.device('cpu')
fast_device = torch.device('cpu')
if (args.use_gpu):
    fast_device = torch.device('cuda:' + str(args.gpu_device))

# config file storing hyperparameters (imported as a Python module)
config = importlib.import_module(args.config).config

# Initializing the models
#generator_model = Generator(nstack = 8 , inp_dim = 256 , oup_dim = 6)

# NOTE(review): this Discriminator instance is constructed and then
# immediately discarded — the next line rebinds discriminator_model_pose.
discriminator_model_pose = Discriminator(6 +3, config['discriminator']['num_channels'], config['dataset']['num_joints'], config['discriminator']['num_residuals'])

# Effective pose discriminator: Discriminator2 (stacked-hourglass style).
discriminator_model_pose = Discriminator2(nstack = 1 , inp_dim = 256 , oup_dim = 6)

#discriminator_model_conf = Discriminator3(nstack = 1 , inp_dim = 256 , oup_dim = 6)


# Load a pre-trained generator from a fixed checkpoint path.
# NOTE(review): torch.load of a full checkpoint unpickles arbitrary code —
# only load trusted files.
modelpath_g = torch.load('train-model-19/supervised-medical-660-lr-0002/experiment04-batch-size-1/model_20_520.pt')
generator_model = modelpath_g['generator_model']
#print(generator_model)
#print(generator_model)
#print(discriminator_model)

####modelpath_d = torch.load('train-model-16-medical-pre-trainied_pose_conf/pretrained_conf_pose/pretrained_conf_pose_10_500.pt')

####discriminator_model_pose = modelpath_d['discriminator_model_pose']
Пример #14
0
    def __init__(self, generator: Generator, discriminator: Discriminator):
        """Wrap a GAN pair and initialize both networks' weights.

        Args:
            generator: generator network; ``weights_init`` is applied to it
                in place via ``Module.apply``.
            discriminator: discriminator network; initialized the same way.
        """
        super().__init__()

        # ``apply`` walks all submodules recursively and returns the module
        # itself, so the initialized networks are stored directly.
        self.discriminator = discriminator.apply(weights_init)
        self.generator = generator.apply(weights_init)
Пример #15
0
class DCGAN:
    """TF1 graph-mode DCGAN trained on MNIST.

    Builds the generator/discriminator graph and Adam train ops in
    ``__init__``; ``train`` runs alternating D/G updates and periodically
    writes sample grids to ``samples/``.
    """

    # initialise network with learning rate, layer shape etc
    def __init__(self,
                 img_shape,
                 epochs=50000,
                 lr_gen=0.0001,
                 lr_disc=0.0001,
                 z_shape=100,
                 batch_size=64,
                 beta1=0.5,
                 epochs_for_sample=500):
        """Build the full training graph.

        Args:
            img_shape: (rows, cols, channels) of the images.
            epochs: number of training iterations (each on one minibatch).
            lr_gen / lr_disc: Adam learning rates for G and D.
            z_shape: latent vector dimensionality.
            batch_size: minibatch size.
            beta1: Adam beta1 for both optimizers.
            epochs_for_sample: sample-grid interval (in iterations).
        """

        # initalise architecture vars
        self.rows, self.cols, self.channels = img_shape
        self.batch_size = batch_size
        self.epochs = epochs
        self.z_shape = z_shape
        self.epochs_for_sample = epochs_for_sample
        # intialise underlying networks
        self.generator = Generator(img_shape, self.batch_size)
        self.discriminator = Discriminator(img_shape)

        # Full MNIST (train + test) as the real-image pool.
        mnist = tf.keras.datasets.mnist
        (x_train, _), (x_test, _) = mnist.load_data()

        X = np.concatenate([x_train, x_test])
        # As and after training for the generator,sampling will occur. Uses tanh for generator output for
        # best results <--- need to rescale MNIST [0,1] -> [-1,1]
        self.X = X / 127.5 - 1  # Scale between -1 and 1
        # Placeholders for real images and latent noise.
        self.phX = tf.placeholder(tf.float32, [None, self.rows, self.cols])
        self.phZ = tf.placeholder(tf.float32, [None, self.z_shape])

        self.gen_out = self.generator.forward(self.phZ)

        # Discriminator logits on generated and real batches.
        disc_logits_fake = self.discriminator.forward(self.gen_out)
        disc_logits_real = self.discriminator.forward(self.phX)

        # compute cost functions - sigmoid cross entropy (sigmoid as real or fake)
        disc_fake_loss = cost(tf.zeros_like(disc_logits_fake),
                              disc_logits_fake)
        disc_real_loss = cost(tf.ones_like(disc_logits_real), disc_logits_real)

        self.disc_loss = tf.add(disc_fake_loss, disc_real_loss)
        # Non-saturating G loss: make D output "real" on fakes.
        self.gen_loss = cost(tf.ones_like(disc_logits_fake), disc_logits_fake)

        train_vars = tf.trainable_variables()

        # NOTE(review): substring filtering on variable names is fragile —
        # any variable whose name merely *contains* 'd' or 'g' is matched;
        # this relies on the Generator/Discriminator naming scheme. Verify.
        disc_vars = [var for var in train_vars if 'd' in var.name]
        gen_vars = [var for var in train_vars if 'g' in var.name]

        self.disc_train = tf.train.AdamOptimizer(
            lr_disc, beta1=beta1).minimize(self.disc_loss, var_list=disc_vars)
        self.gen_train = tf.train.AdamOptimizer(lr_gen, beta1=beta1).minimize(
            self.gen_loss, var_list=gen_vars)

    def train(self):
        """Run the alternating D/G training loop (one of each per epoch)."""
        init = tf.global_variables_initializer()
        self.sess = tf.Session()
        self.sess.run(init)

        for i in range(self.epochs):
            # Random minibatch of real images + fresh noise for D's step.
            idx = np.random.randint(0, len(self.X), self.batch_size)
            batch_X = self.X[idx]
            batch_Z = np.random.uniform(-1, 1, (self.batch_size, self.z_shape))

            _, d_loss = self.sess.run([self.disc_train, self.disc_loss],
                                      feed_dict={
                                          self.phX: batch_X,
                                          self.phZ: batch_Z
                                      })
            # New noise for the generator update.
            batch_Z = np.random.uniform(-1, 1, (self.batch_size, self.z_shape))
            _, g_loss = self.sess.run([self.gen_train, self.gen_loss],
                                      feed_dict={self.phZ: batch_Z})
            if i % self.epochs_for_sample == 0:
                self.generate_sample(i)
                print(
                    f"Epoch: {i}. Discriminator loss: {d_loss}. Generator loss: {g_loss}"
                )

    def generate_sample(self, epoch):
        """Save a 7x7 grid of generated digits to samples/<epoch>.png.

        Generates a full batch of images but plots only the first 49.
        Assumes the ``samples/`` directory exists — TODO confirm.
        """
        c = 7
        r = 7
        z = np.random.uniform(-1, 1, (self.batch_size, self.z_shape))
        imgs = self.sess.run(self.gen_out, feed_dict={self.phZ: z})
        imgs = imgs * 0.5 + 0.5
        # scale between 0, 1
        fig, axs = plt.subplots(c, r)
        cnt = 0
        for i in range(c):
            for j in range(r):
                axs[i, j].imshow(imgs[cnt, :, :, 0], cmap="gray")
                axs[i, j].axis('off')
                cnt += 1
        fig.savefig("samples/%d.png" % epoch)
        plt.close()
Пример #16
0
class SVM_Classifier:
    """Linear classifier on features from a pretrained GAN discriminator.

    The discriminator (restored from a checkpoint) is used as a frozen
    feature extractor over CIFAR-10 images; a linear SVM or an incremental
    SGD classifier is then fitted on those features.
    """

    def __init__(self, batch_size, image_size=64):
        """Build CIFAR-10 data loaders and restore the trained discriminator.

        Args:
            batch_size: Mini-batch size for the train and test loaders.
            image_size: Input side length expected by the discriminator.
        """
        self.image_size = image_size
        self.device = torch.device("cuda:0" if (
            torch.cuda.is_available()) else "cpu")
        # Timestamped file name the fitted classifier is pickled to.
        self.save_filename = f'model_{datetime.datetime.now().strftime("%a_%H_%M")}.sav'

        # Scale images to [-1, 1] to match the GAN's training preprocessing.
        transform = transforms.Compose([
            # transforms.Resize(self.image_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        self.trainset = torchvision.datasets.CIFAR10(root='./data',
                                                     train=True,
                                                     download=True,
                                                     transform=transform)
        self.trainloader = data.DataLoader(self.trainset,
                                           batch_size=batch_size,
                                           shuffle=True,
                                           num_workers=2)

        self.testset = torchvision.datasets.CIFAR10(root='./data',
                                                    train=False,
                                                    download=True,
                                                    transform=transform)
        self.testloader = data.DataLoader(self.testset,
                                          batch_size=batch_size,
                                          shuffle=False,
                                          num_workers=2)

        # NOTE(review): machine-specific absolute path — consider making
        # this a constructor argument.
        saved_state = torch.load(
            "C:\\Users\\ankit\\Workspaces\\CS7150\\FinalProject\\models\\imagenet\\trained_model_Tue_17_06.pth"
        )
        self.discriminator = Discriminator(ngpu=1,
                                           num_channels=3,
                                           num_features=64,
                                           data_generation_mode=1,
                                           input_size=image_size)
        self.discriminator.load_state_dict(saved_state['discriminator'])
        self.discriminator.eval()  # inference mode: freeze dropout/batch-norm

    def plot_training_data(self):
        """Display a grid of the first 8 training images."""
        real_batch = next(iter(self.trainloader))
        images = real_batch[0][0:8]  # first 8 images of the batch
        plt.figure(figsize=(8, 8))
        plt.axis("off")
        plt.title("Training Images")
        # Bug fix: the original indexed the already-sliced batch again
        # (real_batch[0]), passing a SINGLE image to make_grid so only one
        # picture was shown; pass the whole 8-image batch instead.
        plt.imshow(
            np.transpose(
                vutils.make_grid(images.to(self.device),
                                 padding=2,
                                 normalize=True).cpu(), (1, 2, 0)))
        plt.show()

    def train(self):
        """Fit a linear SVM on discriminator features of ONE training batch
        and pickle the fitted model to ``self.save_filename``."""
        train_data, train_labels = next(iter(self.trainloader))
        modified_train_data = self.discriminator(train_data)
        l2_svm = svm.LinearSVC(verbose=2, max_iter=2000)

        modified_train_data_ndarray = modified_train_data.detach().numpy()
        train_labels_ndarray = train_labels.detach().numpy()
        self.l2_svm = l2_svm.fit(modified_train_data_ndarray,
                                 train_labels_ndarray)

        # Persist the fitted model.
        with open(self.save_filename, 'wb') as file:
            pickle.dump(self.l2_svm, file)

    def train_test_SGD_Classifier(self):
        """Incrementally fit an SGD classifier over the whole training set.

        NOTE(review): the StandardScaler is fitted on the first batch only,
        so its statistics only approximate the full data set; every later
        batch is scaled with those statistics and fed to ``partial_fit``.
        """
        est = make_pipeline(StandardScaler(), SGDClassifier(max_iter=200))
        training_data = self.discriminator(next(iter(self.trainloader))[0])
        training_data = training_data.detach().numpy()
        est.steps[0][1].fit(training_data)

        self.est = est

        # Renamed loop variable: `data` shadowed the module-level
        # torch.utils.data alias used above.
        for i, batch in enumerate(self.trainloader):
            train_data, train_labels = batch
            modified_train_data = self.discriminator(train_data)

            modified_train_data_ndarray = modified_train_data.detach().numpy()
            train_labels_ndarray = train_labels.detach().numpy()
            modified_train_data_ndarray = est.steps[0][1].transform(
                modified_train_data_ndarray)

            # NOTE(review): `classes` is taken from the first batch only
            # (sklearn ignores it on later calls) — a small first batch
            # could miss labels; passing the full label set would be safer.
            est.steps[1][1].partial_fit(
                modified_train_data_ndarray,
                train_labels_ndarray,
                classes=np.unique(train_labels_ndarray))
            print(f'Batch: {i}')

        with open(self.save_filename, 'wb') as file:
            pickle.dump(est.steps[1][1], file)

    def test(self):
        """Evaluate the SGD classifier on the test set; print mean accuracy.

        Requires ``train_test_SGD_Classifier`` to have run first (it sets
        ``self.est``).
        """
        l2_svm = self.est.steps[1][1]
        accuracy = []

        for i, batch in enumerate(self.testloader):
            test_data, test_labels = batch
            modified_test_data = self.discriminator(test_data)

            modified_test_data_ndarray = modified_test_data.detach().numpy()
            test_labels_ndarray = test_labels.detach().numpy()
            modified_test_data_ndarray = self.est.steps[0][1].transform(
                modified_test_data_ndarray)

            predictions = l2_svm.predict(modified_test_data_ndarray)

            accuracy.append(
                metrics.accuracy_score(test_labels_ndarray, predictions))

        print(f'Accuracy: {np.mean(accuracy)}')
Пример #17
0
def main():
    """SeqGAN training driver: pre-train G and D, then adversarial training.

    An oracle LSTM (TARGET_LSTM) plays the role of the true data
    distribution; generator quality is measured as the oracle NLL of the
    generator's samples.
    """
    random.seed(SEED)
    np.random.seed(SEED)
    assert START_TOKEN == 0

    #
    # Declare data loader
    # ----------------------------------------------------------------------------
    gen_data_loader = Gen_Data_loader(BATCH_SIZE)
    likelihood_data_loader = Gen_Data_loader(BATCH_SIZE) # For testing
    vocab_size = 5000
    dis_data_loader = Dis_dataloader(BATCH_SIZE)
    # ----------------------------------------------------------------------------


    #
    # Declare Generator & Discriminator
    # ----------------------------------------------------------------------------
    # declare: generator
    generator = Generator(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM, SEQ_LENGTH, START_TOKEN)
    # Bug fix: the pickle file was opened in text mode and the handle was
    # never closed; read it in binary mode inside a context manager.
    with open('save/target_params.pkl', 'rb') as f:
        target_params = cPickle.load(f)
    target_lstm = TARGET_LSTM(vocab_size, BATCH_SIZE, EMB_DIM, HIDDEN_DIM, SEQ_LENGTH, START_TOKEN, target_params) # The oracle model

    # declare: discriminator
    discriminator = Discriminator(sequence_length=20, num_classes=2,
                                   vocab_size=vocab_size, embedding_size=dis_embedding_dim,
                                   filter_sizes=dis_filter_sizes, num_filters=dis_num_filters,
                                   l2_reg_lambda=dis_l2_reg_lambda)
    # ----------------------------------------------------------------------------

    #
    # Set the session <sess>
    # ----------------------------------------------------------------------------
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())
    # ----------------------------------------------------------------------------

    # Positive examples normally come from the oracle; the generation call
    # is commented out, so <positive_file> is assumed to already exist.
    #generate_samples(sess, target_lstm, BATCH_SIZE, generated_num, positive_file)
    gen_data_loader.create_batches(positive_file)

    log = open('save/experiment-log.txt', 'w')


    #
    # Pre-train <generator> by MLE, tracking the oracle NLL every 5 epochs
    # ----------------------------------------------------------------------------
    print('Start pre-training...')
    log.write('pre-training...\n')
    for epoch in range(PRE_EPOCH_NUM):
        loss = pre_train_epoch(sess, generator, gen_data_loader)
        if epoch % 5 == 0:
            # generate samples by using <generator> and write them to <eval_file>
            generate_samples(sess, generator, BATCH_SIZE, generated_num, eval_file)

            # load samples from file <eval_file>
            likelihood_data_loader.create_batches(eval_file)

            # compute <test_loss> of <target_lstm>, with input <likelihood_data_loader>
            test_loss = target_loss(sess, target_lstm, likelihood_data_loader)

            print('pre-train epoch ', epoch, 'test_loss ', test_loss)
            buffer = 'epoch:\t'+ str(epoch) + '\tnll:\t' + str(test_loss) + '\n'
            log.write(buffer)
    # ----------------------------------------------------------------------------


    #
    # Pre-train <discriminator>: 50 rounds of (generate negatives, 3 epochs)
    # ----------------------------------------------------------------------------
    print('Start pre-training discriminator...')
    for _ in range(50):
        # generate samples by using <generator> and write them to <negative_file>
        generate_samples(sess, generator, BATCH_SIZE, generated_num, negative_file)

        # load samples from file <negative_file>
        dis_data_loader.load_train_data(positive_file, negative_file)

        for _ in range(3):
            dis_data_loader.reset_pointer()
            for it in range(dis_data_loader.num_batch):
                x_batch, y_batch = dis_data_loader.next_batch()
                feed = {discriminator.input_x: x_batch,
                        discriminator.input_y: y_batch,
                        discriminator.dropout_keep_prob: dis_dropout_keep_prob}
                _ = sess.run(discriminator.train_op, feed_dict=feed)
    # ----------------------------------------------------------------------------

    # Monte-Carlo rollout policy used for reward estimation.
    rollout = ROLLOUT(generator, 0.8)

    #
    # Start seqGAN, train <discriminator> and <generator>
    # ----------------------------------------------------------------------------
    print('#########################################################################')
    print('Start Adversarial Training...')
    log.write('adversarial training...\n')
    for total_batch in range(TOTAL_BATCH):

        # ----- G-steps: sample, estimate rewards by rollout, policy update ----
        for it in range(G_STEPS):
            samples = generator.generate(sess)
            rewards = rollout.get_reward(sess, samples, ROLLOUT_NUM, discriminator, SEQ_LENGTH)
            feed = {generator.x: samples,
                    generator.rewards: rewards}
            _ = sess.run(generator.g_updates, feed_dict=feed)
        # --------------------------------------------------------

        # Track the generator's parameters inside the rollout network.
        rollout.update_params()

        # ----- D-steps: refresh negatives and retrain the discriminator ------
        for _ in range(D_STEPS):
            generate_samples(sess, generator, BATCH_SIZE, generated_num, negative_file)
            dis_data_loader.load_train_data(positive_file, negative_file)

            for _ in range(3):
                dis_data_loader.reset_pointer()
                for it in range(dis_data_loader.num_batch):
                    x_batch, y_batch = dis_data_loader.next_batch()
                    feed = {discriminator.input_x: x_batch,
                            discriminator.input_y: y_batch,
                            discriminator.dropout_keep_prob: dis_dropout_keep_prob}
                    _ = sess.run(discriminator.train_op, feed_dict=feed)
        # --------------------------------------------------------
    # ----------------------------------------------------------------------------

    log.close()
Пример #18
0
class WGANGP(Model):
    """Wasserstein GAN with gradient penalty (WGAN-GP), TF1 graph style.

    Naming convention in this class: "real" refers to the discriminator
    (critic) side and "fake" to the generator side.
    """

    def __init__(self,
                 scope_name,
                 channel_min,
                 img_size,
                 generator_size=100,
                 channel_rate=2):
        """Record hyperparameters and create the input placeholders.

        Args:
            scope_name: Variable scope the whole model is built under.
            channel_min: Base channel count for both sub-networks.
            img_size: Image shape as a list, e.g. [H, W, C].
            generator_size: Dimensionality of the latent vector z.
            channel_rate: Channel growth factor (stored but unused here).
        """
        super(WGANGP, self).__init__(scope_name)
        self.scope_name = scope_name
        self.channel_min = channel_min
        self.channel_rate = channel_rate
        self.img_size = img_size
        # Placeholder for real images, shape [batch] + img_size.
        self.input_img = tf.placeholder(tf.float32,
                                        shape=[None] + img_size,
                                        name="input_img")
        # Placeholder for latent noise vectors, shape [batch, generator_size].
        self.input_z = tf.placeholder(tf.float32,
                                      shape=[None, generator_size],
                                      name="input_z")

    def buind(self, train_fn=tf_tools.adam_fn, real_lr=1e-5, fake_lr=1e-5):
        # NOTE(review): "buind" looks like a typo for "build"; kept as-is
        # because external callers may rely on this name.
        """Construct the network graph, losses and both optimizers."""
        with tf.variable_scope(self.scope_name) as scope:
            self.fake_img, self.real_output, self.fake_output = self.buind_network(
            )
            self.var_list = tf.trainable_variables(scope=self.scope_name)
            # Adam with beta1=0, beta2=0.9 for both networks.
            self.real_train_fn = train_fn(real_lr, 0, 0.9)
            self.fake_train_fn = train_fn(fake_lr, 0, 0.9)
            self.build_optimization()

    def buind_network(self, fake_normal=True):
        """Create the critic and generator networks and wire them up.

        Args:
            fake_normal: If True, squash generator output through a sigmoid.

        Returns:
            (fake_img, real_output, fake_output): generated images plus the
            critic's scores on real and generated inputs.
        """
        self.real_network = Discriminator(self.channel_min,
                                          1,
                                          name="discriminator")
        self.fake_network = Generator(self.channel_min,
                                      self.img_size,
                                      name="generator")

        fake_img = self.fake_network.build(self.input_z, times=3)
        if fake_normal:
            fake_img = tf.nn.sigmoid(fake_img)

        real_output = self.real_network.build(self.input_img,
                                              times=3,
                                              normal=False)
        # Reuse the same critic weights for the generated images.
        fake_output = self.real_network.build(fake_img,
                                              times=3,
                                              reuse=tf.AUTO_REUSE,
                                              normal=False)

        return fake_img, real_output, fake_output

    def build_optimization(self):
        """Define the WGAN-GP losses and one training op per network."""
        # Random interpolation between real and generated images.
        # NOTE(review): epsilon is a scalar shared by the whole batch;
        # WGAN-GP commonly samples one epsilon per example — confirm intent.
        epsilon = tf.random_uniform([], 0.0, 1.0)
        x_hat = self.input_img * epsilon + (1 - epsilon) * self.fake_img
        d_hat = self.real_network.build(x_hat,
                                        times=3,
                                        reuse=tf.AUTO_REUSE,
                                        normal=False)
        # Gradient penalty: drive the critic's gradient norm at x_hat to 1.
        gradients = tf.gradients(d_hat, x_hat)[0]
        slopes = tf.sqrt(
            tf.reduce_sum(tf.square(gradients), reduction_indices=[1]))
        gradient_penalty = 10 * tf.reduce_mean((slopes - 1.0)**2)

        # Critic loss: E[D(fake)] - E[D(real)] + penalty.
        self.real_loss_op = tf.reduce_mean(self.fake_output) - tf.reduce_mean(
            self.real_output) + gradient_penalty
        # Generator loss: -E[D(fake)].
        self.fake_loss_op = -tf.reduce_mean(self.fake_output)

        # Per-network step counters, advanced by their minimize() calls.
        self.real_index = tf.Variable(0)
        self.real_train_op = self.real_train_fn.minimize(
            self.real_loss_op,
            var_list=self.real_network.var_list,
            global_step=self.real_index)

        self.fake_index = tf.Variable(0)
        self.fake_train_op = self.fake_train_fn.minimize(
            self.fake_loss_op,
            var_list=self.fake_network.var_list,
            global_step=self.fake_index)

    def predict(self, z):
        """Generate images from latent vectors ``z`` via the default session."""
        session = tf.get_default_session()
        feed_dict = {self.input_z: z}
        output = session.run(self.fake_img, feed_dict=feed_dict)
        return output

    def train(self, img, z, mode='D'):
        """Run one optimization step and return the current losses.

        Args:
            img: Batch of real images.
            z: Batch of latent vectors.
            mode: 'D' trains the critic; anything else trains the generator.

        Returns:
            [critic_loss, generator_loss] evaluated after the step.
        """
        session = tf.get_default_session()
        feed_dict = {self.input_img: img, self.input_z: z}
        if mode == 'D':
            session.run(self.real_train_op, feed_dict=feed_dict)
        else:
            session.run(self.fake_train_op, feed_dict=feed_dict)
        return self.loss(img, z)

    def loss(self, img, z):
        """Evaluate [critic_loss, generator_loss] without updating weights."""
        session = tf.get_default_session()
        feed_dict = {self.input_img: img, self.input_z: z}
        loss = session.run([self.real_loss_op, self.fake_loss_op],
                           feed_dict=feed_dict)
        return loss
Пример #19
0
def main():
    """Curriculum-style SeqGAN training with a growing sequence length.

    Starting from ``args.seq_len``, the generator/discriminator pair is
    created once and then reused while the effective ("true") sequence
    length grows by one per outer iteration, up to SEQ_LENGTH. An oracle
    TARGET_LSTM provides the synthetic data distribution and the NLL
    evaluation metric; logs and learning curves are written under
    ``args.save``.
    """
    random.seed(SEED)
    np.random.seed(SEED)
    seq_len = args.seq_len

    # prepare data
    gen_data_loader = Gen_Data_loader(BATCH_SIZE, SEQ_LENGTH)
    likelihood_data_loader = Gen_Data_loader(BATCH_SIZE,
                                             SEQ_LENGTH)  # For testing
    dis_data_loader = Dis_Data_loader(BATCH_SIZE, SEQ_LENGTH)
    # Models are created on the first pass and reused on later passes
    # (only their true_seq_len is updated).
    generator = None
    discriminator = None
    target_lstm = None

    while seq_len <= SEQ_LENGTH:
        log = open(args.save + '/experiment-log' + str(seq_len) + '.txt', 'w')
        print("Current sequence length is " + str(seq_len))
        print('Args:', args)
        log.write(str(args))
        if generator is None:
            log.write("Init generator")
            print("Init generator")
        else:
            log.write("Used same generator")
            print("Used same generator")
        generator = Generator(
            num_emb=vocab_size,
            batch_size=BATCH_SIZE,
            emb_dim=EMB_DIM,
            num_units=HIDDEN_DIM,
            sequence_length=SEQ_LENGTH,
            start_token=START_TOKEN,
            true_seq_len=seq_len,
            save_model_path=args.save) if generator is None else generator
        generator.true_seq_len = seq_len

        # target_params's size: [15 * 5000 * 32]
        target_params = pickle.load(open('./save/target_params_py3.pkl', 'rb'))
        # The oracle model
        target_lstm = TARGET_LSTM(
            5000, BATCH_SIZE, EMB_DIM, HIDDEN_DIM, SEQ_LENGTH, 0,
            target_params, seq_len) if target_lstm is None else target_lstm
        target_lstm.true_seq_len = seq_len
        discriminator = Discriminator(
            sequence_length=SEQ_LENGTH,
            num_classes=2,
            vocab_size=vocab_size,
            embedding_size=dis_embedding_dim,
            filter_sizes=dis_filter_sizes,
            num_filters=dis_num_filters,
            l2_reg_lambda=dis_l2_reg_lambda,
            save_model_path=args.save
        ) if discriminator is None else discriminator

        # NOTE(review): a new Session is created (and variables re-initialized)
        # on every pass of the while-loop — confirm this reset is intended.
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        sess = tf.Session(config=config)
        sess.run(tf.global_variables_initializer())

        # Positive examples are drawn from the oracle at the current length.
        generate_samples_from_target(sess, target_lstm, BATCH_SIZE,
                                     generated_num, positive_file)
        gen_data_loader.create_batches(positive_file, seq_len)

        # print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
        #
        # likelihood_data_loader.create_batches(positive_file)
        # for i in range(100):
        #     test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
        #     print('my step ', i, 'test_loss ', test_loss)
        #     input("next:")
        # input("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")

        #log = open('save_19_20/experiment-log' + str(seq_len) + '.txt', 'w')
        #  pre-train generator via MLE, tracking oracle NLL every epoch
        print('Start pre-training...')
        log.write('pre-training...\n')
        ans_file = open(args.save + '/learning_cure' + str(seq_len) + '.txt',
                        'w')
        epochs = args.gen_pre_epoch
        #ans_file.write("-------- %s \n" % seq_len)
        for epoch in range(epochs):  # 120
            loss = pre_train_epoch(sess, generator, gen_data_loader)
            if epoch % 1 == 0:
                generate_samples(sess, generator, BATCH_SIZE, generated_num,
                                 eval_file)
                likelihood_data_loader.create_batches(eval_file, seq_len)
                test_loss = target_loss(sess, target_lstm,
                                        likelihood_data_loader)
                print('pre-train epoch ', epoch, 'test_loss ', test_loss)
                buffer = 'epoch:\t' + str(epoch) + '\tnll:\t' + str(
                    test_loss) + '\n'
                log.write(buffer)
                ans_file.write("%s\n" % str(test_loss))

        # Pre-train the discriminator on oracle positives vs generator negatives.
        buffer = 'Start pre-training discriminator...'
        print(buffer)
        log.write(buffer)
        for _ in range(args.disc_pre_epoch):  # 10
            generate_samples(sess, generator, BATCH_SIZE, generated_num,
                             negative_file)
            dis_data_loader.load_train_data(positive_file, negative_file,
                                            seq_len)
            for _ in range(3):
                dis_data_loader.reset_pointer()
                for it in range(dis_data_loader.num_batch):
                    x_batch, y_batch = dis_data_loader.next_batch()
                    feed = {
                        discriminator.input_x: x_batch,
                        discriminator.input_y: y_batch,
                        discriminator.dropout_keep_prob: dis_dropout_keep_prob,
                    }
                    d_loss, d_acc, _ = sess.run([
                        discriminator.loss, discriminator.accuracy,
                        discriminator.train_op
                    ], feed)
            # Report the loss/accuracy of the last batch of this round.
            buffer = "discriminator loss %f acc %f\n" % (d_loss, d_acc)
            print(buffer)

            log.write(buffer)
        ans_file.write("==========\n")
        print("Start Adversarial Training...")
        log.write('adversarial training...')
        TOTAL_BATCH = args.adversarial_epoch
        for total_batch in range(TOTAL_BATCH):
            # Train the generator: sample, compute rollout rewards, policy update.
            for it in range(1):
                samples = generator.generate(sess)
                rewards = generator.get_reward(sess, samples, 16,
                                               discriminator, START_TOKEN)
                a = str(samples[0])
                b = str(rewards[0])
                buffer = "%s\n%s\n\n" % (a, b)
                # print(buffer)
                log.write(buffer)
                rewards_loss = generator.update_with_rewards(
                    sess, samples, rewards, START_TOKEN)

                # good rewards
                # good_samples = gen_data_loader.next_batch()
                # rewards = np.array([[1.0] * SEQ_LENGTH] * BATCH_SIZE)
                # a = str(good_samples[0])
                # b = str(rewards[0])
                # buffer = "%s\n%s\n\n" % (a, b)
                # print(buffer)
                # log.write(buffer)
                # rewards_loss = generator.update_with_rewards(sess, good_samples, rewards, START_TOKEN)

                # little1 good reward
                # litter1_samples = gen_data_loader.next_batch()
                # rewards = generator.get_reward(sess, litter1_samples, 16, discriminator, START_TOKEN)
                # a = str(little1 good reward[0])
                # b = str(rewards[0])
                # buffer = "%s\n%s\n\n" % (a, b)
                # print(buffer)
                # log.write(buffer)
                # rewards_loss = generator.update_with_rewards(sess, litter1_samples, rewards, START_TOKEN)

            # Test: evaluate oracle NLL of fresh generator samples.
            if total_batch % 1 == 0 or total_batch == TOTAL_BATCH - 1:
                generate_samples(sess, generator, BATCH_SIZE, generated_num,
                                 eval_file)
                likelihood_data_loader.create_batches(eval_file, seq_len)
                test_loss = target_loss(sess, target_lstm,
                                        likelihood_data_loader)
                buffer = 'reward-train epoch %s train loss %s test_loss %s\n' % (
                    str(total_batch), str(rewards_loss), str(test_loss))
                print(buffer)
                log.write(buffer)
                ans_file.write("%s\n" % str(test_loss))

            # Periodically checkpoint the generator.
            if total_batch % 20 == 0 or total_batch == TOTAL_BATCH - 1:
                generator.save_model(sess, seq_len)

            # Train the discriminator on refreshed negatives.
            for _ in range(1):
                generate_samples(sess, generator, BATCH_SIZE, generated_num,
                                 negative_file)
                dis_data_loader.load_train_data(positive_file, negative_file,
                                                seq_len)
                for _ in range(3):
                    dis_data_loader.reset_pointer()
                    for it in range(dis_data_loader.num_batch):
                        x_batch, y_batch = dis_data_loader.next_batch()
                        feed = {
                            discriminator.input_x: x_batch,
                            discriminator.input_y: y_batch,
                            discriminator.dropout_keep_prob:
                            dis_dropout_keep_prob,
                        }
                        d_loss, d_acc, _ = sess.run([
                            discriminator.loss, discriminator.accuracy,
                            discriminator.train_op
                        ], feed)
                if total_batch % 5 == 0 or total_batch == TOTAL_BATCH - 1:
                    buffer = "discriminator loss %f acc %f\n" % (d_loss, d_acc)
                    print(buffer)
                    log.write(buffer)
                if total_batch % 20 == 0 or total_batch == TOTAL_BATCH - 1:
                    discriminator.save_model(sess, seq_len)
        seq_len += 1
Пример #20
0
    def __init__(
            self,
            X_train_file='',
            Y_train_file='',
            batch_size=1,
            #image_size=256,
            image_height=240,
            image_width=320,
            use_lsgan=True,
            norm='instance',
            lambda1=10.0,
            lambda2=10.0,
            learning_rate=2e-4,
            beta1=0.5,
            ngf=64):
        """Configure a CycleGAN-style model with two generators and two
        discriminators.

        Args:
          X_train_file: string, X tfrecords file for training
          Y_train_file: string, Y tfrecords file for training
          batch_size: integer, batch size
          image_height: integer, input image height in pixels
          image_width: integer, input image width in pixels
          lambda1: weight for forward cycle loss (X->Y->X)
          lambda2: weight for backward cycle loss (Y->X->Y)
          use_lsgan: boolean, use least-squares GAN loss (disables the
            sigmoid on the discriminators) instead of the vanilla loss
          norm: 'instance' or 'batch'
          learning_rate: float, initial learning rate for Adam
          beta1: float, momentum term of Adam
          ngf: number of gen filters in first conv layer
        """
        self.lambda1 = lambda1
        self.lambda2 = lambda2
        self.use_lsgan = use_lsgan
        # LSGAN operates on raw logits; the sigmoid is only needed for the
        # vanilla GAN loss.
        use_sigmoid = not use_lsgan
        self.batch_size = batch_size
        #self.image_size = image_size
        self.image_height = image_height
        self.image_width = image_width
        self.learning_rate = learning_rate
        self.beta1 = beta1
        self.X_train_file = X_train_file
        self.Y_train_file = Y_train_file

        # Training/inference switch; defaults to training mode.
        self.is_training = tf.placeholder_with_default(True,
                                                       shape=[],
                                                       name='is_training')

        # G: X -> Y translation; D_Y scores images in the Y domain.
        self.G = Generator('G',
                           self.is_training,
                           ngf=ngf,
                           norm=norm,
                           image_height=image_height,
                           image_width=image_width)
        self.D_Y = Discriminator('D_Y',
                                 self.is_training,
                                 norm=norm,
                                 use_sigmoid=use_sigmoid)
        # F: Y -> X translation; D_X scores images in the X domain.
        self.F = Generator('F',
                           self.is_training,
                           norm=norm,
                           image_height=image_height,
                           image_width=image_width)
        self.D_X = Discriminator('D_X',
                                 self.is_training,
                                 norm=norm,
                                 use_sigmoid=use_sigmoid)

        # Placeholders for previously generated images — presumably fed
        # from an image-history pool when training the discriminators
        # (TODO confirm against the training loop).
        self.fake_x = tf.placeholder(
            tf.float32, shape=[batch_size, image_height, image_width, 3])
        self.fake_y = tf.placeholder(
            tf.float32, shape=[batch_size, image_height, image_width, 3])
Пример #21
0
class DCGAN:
    """Conditional DCGAN over Hangul character images stored as TFRecords.

    The entire TF1-style graph (generator, discriminator, losses,
    optimizers and the tf.data input pipeline) is built in ``__init__``;
    ``train`` then opens a session and alternates D / G updates.
    """

    def __init__(self, img_shape, epochs=50000, lr_gen=0.0001, lr_disc=0.0001, z_shape=100, num_classes = 256, batch_size=100, beta1=0.5, epochs_for_sample=500):
        """Build the graph and the TFRecords input pipeline.

        Args:
            img_shape: (rows, cols, channels) of the training images.
            epochs: number of training iterations (batches), not dataset passes.
            lr_gen: Adam learning rate for the generator.
            lr_disc: Adam learning rate for the discriminator.
            z_shape: dimensionality of the generator's noise input.
            num_classes: number of one-hot condition classes.
            batch_size: training batch size.
            beta1: Adam beta1 momentum term.
            epochs_for_sample: record/print losses every this many iterations.
        """
        self.rows, self.cols, self.channels = img_shape
        self.batch_size = batch_size
        self.epochs = epochs
        self.z_shape = z_shape
        self.num_classes = num_classes
        self.epochs_for_sample = epochs_for_sample
        self.generator = Generator(self.z_shape,self.num_classes, img_shape, self.batch_size)
        self.discriminator = Discriminator(self.channels, self.num_classes, img_shape)
        self.samples = []  # reserved for generated samples (never filled in train())
        self.losses = []   # (d_loss, g_loss) tuples appended during training

        self.SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))

        # Default paths.
        self.DEFAULT_LABEL_FILE = os.path.join(self.SCRIPT_PATH, './labels/256-common-hangul.txt')
        self.DEFAULT_TFRECORDS_DIR = os.path.join(self.SCRIPT_PATH, 'tfrecords-output')

        # Input pipeline: read TFRecords files and produce shuffled batches
        # of (label, image) pairs.
        labels = io.open(self.DEFAULT_LABEL_FILE, 'r', encoding='utf-8').read().splitlines()
        # NOTE(review): this local shadows the constructor argument and does
        # NOT update self.num_classes -- the graph below keeps using the value
        # passed to __init__. Confirm the label file length matches it.
        num_classes = len(labels)

        print('Processing data...')

        tf_record_pattern = os.path.join(self.DEFAULT_TFRECORDS_DIR, '%s-*' % 'train')
        self.train_data_files = tf.gfile.Glob(tf_record_pattern)

        # Make tf.data.Dataset
        # If you want to use one more parameter for decode, use 'lambda' for data.map
        dataset = tf.data.TFRecordDataset(self.train_data_files)
        dataset = dataset.map(lambda x: get_image(x, self.num_classes))
        # BUG FIX: was dataset.repeat(self.train_epoch); self.train_epoch is
        # never defined anywhere, which raised AttributeError at construction.
        dataset = dataset.repeat(self.epochs)  # set epoch
        dataset = dataset.shuffle(buffer_size=3 * self.batch_size)  # for getting data in each buffer size data part
        dataset = dataset.batch(self.batch_size)  # set batch size
        dataset = dataset.prefetch(buffer_size=1)  # reduce GPU starvation

        # Make iterator for dataset
        self.iterator = dataset.make_initializable_iterator()
        self.next_element = self.iterator.get_next()

        # Placeholders: real images, noise, and condition labels for G (flat
        # one-hot) and for D (one-hot broadcast over every spatial position).
        self.phX = tf.placeholder(tf.float32, [None, self.rows, self.cols, self.channels])
        self.phZ = tf.placeholder(tf.float32, [None, self.z_shape])
        self.phY_g = tf.placeholder(tf.float32, [None, self.num_classes])
        self.phY_d = tf.placeholder(tf.float32, shape=(None,  self.rows, self.cols, self.num_classes))

        self.gen_out = self.generator.forward(self.phZ, self.phY_g) #output shape of this z is (?, 28, 28, 1)

        disc_logits_fake = self.discriminator.forward(self.gen_out, self.phY_d ) #out put shape of this logit is (?, 1)
        disc_logits_real = self.discriminator.forward(self.phX, self.phY_d ) # out put shape of this logit is (?, 1)

        # Standard GAN objective: D learns real=1 / fake=0; G pushes fake=1.
        disc_fake_loss = cost(tf.zeros_like(disc_logits_fake), disc_logits_fake)
        disc_real_loss = cost(tf.ones_like(disc_logits_real), disc_logits_real)

        self.disc_loss = tf.add(disc_fake_loss, disc_real_loss)
        self.gen_loss = cost(tf.ones_like(disc_logits_fake), disc_logits_fake)

        train_vars = tf.trainable_variables()

        # Split variables by the 'd'/'g' naming convention of the two nets.
        self.disc_vars = [var for var in train_vars if 'd' in var.name]
        self.gen_vars = [var for var in train_vars if 'g' in var.name]

        self.disc_train = tf.train.AdamOptimizer(lr_disc,beta1=beta1).minimize(self.disc_loss, var_list=self.disc_vars)
        self.gen_train = tf.train.AdamOptimizer(lr_gen, beta1=beta1).minimize(self.gen_loss, var_list=self.gen_vars)

    def train(self):
        """Run the alternating D/G training loop for self.epochs batches."""
        init = [tf.global_variables_initializer(), self.iterator.initializer]
        config = tf.ConfigProto()
        config.gpu_options.allow_growth=True
        self.sess = tf.Session(config=config)
        self.sess.run(init)

        # Initialize the queue threads (legacy queue-runner API; tf.data does
        # not need them, kept in case queue-based ops exist in the graph).
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=self.sess, coord=coord)

        epoch_start_time = time.time()
        for i in range(self.epochs):
            # Get a random batch of images and labels.
            train_labels, train_images = self.sess.run(self.next_element)

            # Real image input for Real Discriminator,
            # Get images, reshape and rescale to pass to D
            batch_X = train_images.reshape((self.batch_size, self.rows, self.cols, self.channels))
            batch_X = batch_X * 2 - 1  # rescale [0, 1] -> [-1, 1]

            # Z noise for Generator
            batch_Z = np.random.uniform(-1, 1, (self.batch_size, self.z_shape)) # Shape is [?, 100]

            # Label input for Generator
            batch_Y_g = train_labels
            batch_Y_g = batch_Y_g.reshape([self.batch_size, self.num_classes])

            # Label input for Discriminator: broadcast the one-hot vector
            # across every spatial position of the image.
            batch_Y_d = train_labels
            batch_Y_d = batch_Y_d.reshape([self.batch_size,1,1,self.num_classes])
            batch_Y_d = batch_Y_d * np.ones([self.batch_size, self.rows, self.cols, self.num_classes])

            # One D step on the current noise, then one G step on fresh noise.
            _, d_loss = self.sess.run([self.disc_train, self.disc_loss], feed_dict={self.phX:batch_X, self.phZ:batch_Z, self.phY_g:batch_Y_g, self.phY_d:batch_Y_d})
            batch_Z = np.random.uniform(-1, 1, (self.batch_size, self.z_shape))
            _, g_loss = self.sess.run([self.gen_train, self.gen_loss], feed_dict={self.phX:batch_X, self.phZ:batch_Z, self.phY_g:batch_Y_g, self.phY_d:batch_Y_d})

            if i % self.epochs_for_sample == 0:
                epoch_end_time = time.time()
                per_epoch_ptime = epoch_end_time - epoch_start_time

                print(f"Epoch: {i}. Discriminator loss: {d_loss}. Generator loss: {g_loss}")
                # Save losses to view after training
                self.losses.append((d_loss, g_loss))

        # Save training generator samples
        # NOTE(review): self.samples is never appended to, so this pickles an
        # empty list -- confirm whether sampling was meant to happen above.
        with open('train_samples.pkl', 'wb') as f:
            pkl.dump(self.samples, f)

        # Generate random sample after training
        self.generate_random_sample()

        # Stop queue threads and close session.
        coord.request_stop()
        coord.join(threads)
        self.sess.close()

    def generate_random_sample(self):
        """Restore the generator from ./checkpoints and save a 7x7 image grid."""
        init = tf.global_variables_initializer()
        config = tf.ConfigProto()
        config.gpu_options.allow_growth=True
        self.sess = tf.Session(config=config)
        self.sess.run(init)
        # Only save generator variables
        saver = tf.train.Saver(var_list=self.gen_vars)
        c = 7
        r = 7
        n_sample = 100
        # NOTE(review): z has self.batch_size rows while batch_Y_g has
        # n_sample rows; these only agree when batch_size == 100 -- confirm.
        z = np.random.uniform(-1, 1, (self.batch_size, self.z_shape))

        # Create conditional one-hot vector, with index 5 = 1
        # NOTE(review): the code actually sets column 0 to 4.0, which is
        # neither index 5 nor one-hot -- confirm the intended conditioning.
        batch_Y_g = np.zeros(shape=[n_sample, 256])
        batch_Y_g[:, 0] = 4
        # NOTE(review): train() never saves a checkpoint, so this restore
        # relies on an externally produced 'checkpoints' directory.
        saver.restore(self.sess, tf.train.latest_checkpoint('checkpoints'))
        samples = self.sess.run(self.gen_out, feed_dict={self.phZ:z, self.phY_g:batch_Y_g})

        # Tile the first c*r generated samples into a grid and save to disk.
        fig, axs = plt.subplots(c, r)
        cnt = 0
        for i in range(c):
            for j in range(r):
                axs[i, j].imshow(samples[cnt, :, :, 0], cmap="gray")
                axs[i, j].axis('off')
                cnt += 1
        fig.savefig("generated/generated_test_1.png")
        plt.close()
Example #22
0
    def __init__(
        self,
        batch_size=1,
        image_size=256,
        use_lsgan=True,
        norm='instance',
        lambda1=10,
        lambda2=10,
        learning_rate=2e-4,
        beta1=0.5,
        ngf=32,
        use_gpu=0,
        discrim_inp_ch=28,
    ):
        """Build the CycleGAN generators, discriminators and placeholders.

    Args:
      batch_size: integer, batch size
      image_size: integer, image size (square inputs assumed)
      use_lsgan: boolean, use least-squares GAN loss; when True the
        discriminators omit their final sigmoid
      norm: 'instance' or 'batch'
      lambda1: integer, weight for forward cycle loss (X->Y->X)
      lambda2: integer, weight for backward cycle loss (Y->X->Y)
      learning_rate: float, initial learning rate for Adam
      beta1: float, momentum term of Adam
      ngf: number of gen filters in first conv layer
      use_gpu: int or list of ints, GPU id(s) to use
      discrim_inp_ch: channel count of the fake_x / fake_y placeholders
    """
        # Normalize a single GPU id into a one-element list.
        if type(use_gpu) is int:
            use_gpu = [use_gpu]
        self.use_gpu = use_gpu

        self.lambda1 = lambda1
        self.lambda2 = lambda2
        self.use_lsgan = use_lsgan
        use_sigmoid = not use_lsgan  # LSGAN works on raw logits
        self.batch_size = batch_size
        self.image_size = image_size
        self.learning_rate = learning_rate
        self.beta1 = beta1

        # NOTE(review): called before any losses/graph outputs exist --
        # presumably make_optimizer only constructs the optimizer object;
        # confirm it does not reference loss tensors at this point.
        self.opt = self.make_optimizer()

        # Toggled to False at inference time via feed_dict.
        self.is_training = tf.placeholder_with_default(True,
                                                       shape=[],
                                                       name='is_training')

        # G: X -> Y and F: Y -> X generator pair.
        self.G = Generator('G',
                           self.is_training,
                           output_channel=3,
                           ngf=ngf,
                           norm=norm,
                           image_size=image_size)
        self.D_Y = Discriminator('D_Y',
                                 self.is_training,
                                 norm=norm,
                                 use_sigmoid=use_sigmoid)
        self.F = Generator('F',
                           self.is_training,
                           output_channel=3,
                           norm=norm,
                           image_size=image_size)

        self.D_X = Discriminator('D_X',
                                 self.is_training,
                                 norm=norm,
                                 use_sigmoid=use_sigmoid)

        # Placeholders for history-buffered fake images fed back to the
        # discriminators (image-pool trick).
        self.fake_x = tf.placeholder(
            tf.float32,
            shape=[batch_size, image_size, image_size, discrim_inp_ch])
        self.fake_y = tf.placeholder(
            tf.float32,
            shape=[batch_size, image_size, image_size, discrim_inp_ch])
Example #23
0
    def __init__(self, img_shape, epochs=50000, lr_gen=0.0001, lr_disc=0.0001, z_shape=100, num_classes = 256, batch_size=100, beta1=0.5, epochs_for_sample=500):
        """Build the conditional-DCGAN graph and TFRecords input pipeline.

        Args:
            img_shape: (rows, cols, channels) of the training images.
            epochs: number of training iterations (batches), not dataset passes.
            lr_gen: Adam learning rate for the generator.
            lr_disc: Adam learning rate for the discriminator.
            z_shape: dimensionality of the generator's noise input.
            num_classes: number of one-hot condition classes.
            batch_size: training batch size.
            beta1: Adam beta1 momentum term.
            epochs_for_sample: record/print losses every this many iterations.
        """
        self.rows, self.cols, self.channels = img_shape
        self.batch_size = batch_size
        self.epochs = epochs
        self.z_shape = z_shape
        self.num_classes = num_classes
        self.epochs_for_sample = epochs_for_sample
        self.generator = Generator(self.z_shape,self.num_classes, img_shape, self.batch_size)
        self.discriminator = Discriminator(self.channels, self.num_classes, img_shape)
        self.samples = []  # reserved for generated samples
        self.losses = []   # (d_loss, g_loss) tuples appended during training

        self.SCRIPT_PATH = os.path.dirname(os.path.abspath(__file__))

        # Default paths.
        self.DEFAULT_LABEL_FILE = os.path.join(self.SCRIPT_PATH, './labels/256-common-hangul.txt')
        self.DEFAULT_TFRECORDS_DIR = os.path.join(self.SCRIPT_PATH, 'tfrecords-output')

        # Input pipeline: read TFRecords files and produce shuffled batches
        # of (label, image) pairs.
        labels = io.open(self.DEFAULT_LABEL_FILE, 'r', encoding='utf-8').read().splitlines()
        # NOTE(review): this local shadows the constructor argument and does
        # NOT update self.num_classes -- the graph below keeps using the value
        # passed to __init__. Confirm the label file length matches it.
        num_classes = len(labels)

        print('Processing data...')

        tf_record_pattern = os.path.join(self.DEFAULT_TFRECORDS_DIR, '%s-*' % 'train')
        self.train_data_files = tf.gfile.Glob(tf_record_pattern)

        # Make tf.data.Dataset
        # If you want to use one more parameter for decode, use 'lambda' for data.map
        dataset = tf.data.TFRecordDataset(self.train_data_files)
        dataset = dataset.map(lambda x: get_image(x, self.num_classes))
        # BUG FIX: was dataset.repeat(self.train_epoch); self.train_epoch is
        # never defined anywhere, which raised AttributeError at construction.
        dataset = dataset.repeat(self.epochs)  # set epoch
        dataset = dataset.shuffle(buffer_size=3 * self.batch_size)  # for getting data in each buffer size data part
        dataset = dataset.batch(self.batch_size)  # set batch size
        dataset = dataset.prefetch(buffer_size=1)  # reduce GPU starvation

        # Make iterator for dataset
        self.iterator = dataset.make_initializable_iterator()
        self.next_element = self.iterator.get_next()

        # Placeholders: real images, noise, and condition labels for G (flat
        # one-hot) and for D (one-hot broadcast over every spatial position).
        self.phX = tf.placeholder(tf.float32, [None, self.rows, self.cols, self.channels])
        self.phZ = tf.placeholder(tf.float32, [None, self.z_shape])
        self.phY_g = tf.placeholder(tf.float32, [None, self.num_classes])
        self.phY_d = tf.placeholder(tf.float32, shape=(None,  self.rows, self.cols, self.num_classes))

        self.gen_out = self.generator.forward(self.phZ, self.phY_g) #output shape of this z is (?, 28, 28, 1)

        disc_logits_fake = self.discriminator.forward(self.gen_out, self.phY_d ) #out put shape of this logit is (?, 1)
        disc_logits_real = self.discriminator.forward(self.phX, self.phY_d ) # out put shape of this logit is (?, 1)

        # Standard GAN objective: D learns real=1 / fake=0; G pushes fake=1.
        disc_fake_loss = cost(tf.zeros_like(disc_logits_fake), disc_logits_fake)
        disc_real_loss = cost(tf.ones_like(disc_logits_real), disc_logits_real)

        self.disc_loss = tf.add(disc_fake_loss, disc_real_loss)
        self.gen_loss = cost(tf.ones_like(disc_logits_fake), disc_logits_fake)

        train_vars = tf.trainable_variables()

        # Split variables by the 'd'/'g' naming convention of the two nets.
        self.disc_vars = [var for var in train_vars if 'd' in var.name]
        self.gen_vars = [var for var in train_vars if 'g' in var.name]

        self.disc_train = tf.train.AdamOptimizer(lr_disc,beta1=beta1).minimize(self.disc_loss, var_list=self.disc_vars)
        self.gen_train = tf.train.AdamOptimizer(lr_gen, beta1=beta1).minimize(self.gen_loss, var_list=self.gen_vars)
Example #24
0
# 3. Build the models, losses and optimizers

# 3-1. Build the G model
#  Modified U-Net
generator = Generator().generate()

# Optional visual sanity check of the untrained generator's output.
display_on = False
if display_on:
    inp, re = il.load(PATH + 'train/100.jpg')
    gen_output = generator(inp[tf.newaxis, ...], training=False)
    plt.imshow(gen_output[0, ...])
    ret = input()  # pause so the figure can be inspected

# 3-2. Build the D model
discriminator = Discriminator().generate()

if display_on:
    # Visualize the patch-wise discriminator response on the same pair.
    disc_out = discriminator([inp[tf.newaxis, ...], gen_output],
                             training=False)
    plt.imshow(disc_out[0, ..., -1], vmin=-20, vmax=20, cmap='RdBu_r')
    plt.colorbar()
    ret = input()

# prepare the losses
# NOTE(review): presumably the pix2pix-style weight of the L1 term in the
# generator loss -- confirm against generator_loss below (cut off here).
LAMBDA = 100
loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)
def generator_loss(disc_generated_output, gen_output, target):
    gen_loss = loss_object(tf.ones_like(disc_generated_output),
Example #25
0
                    help='Save states of models every nsave iterations.')
parser.add_argument(
    '--retrain',
    action='store_true',
    help='Whether or not to start training from a previous state.')
args = parser.parse_args()

print("Initializing generator model and optimizer.")
g_net = Generator().cuda()
# NOTE(review): the generator optimizer is given args.learning_rate_d (the
# discriminator learning rate); if a separate --learning_rate_g flag exists
# this looks like a typo -- confirm against the argument definitions above.
g_opt = optim.RMSprop(g_net.parameters(),
                      args.learning_rate_d,
                      weight_decay=args.rmsprop_decay)
g_losses = np.empty(0)  # running history of generator losses

print("Initializing discriminator model and optimizer.")
d_net = Discriminator().cuda()
d_opt = optim.RMSprop(d_net.parameters(),
                      args.learning_rate_d,
                      weight_decay=args.rmsprop_decay)
d_losses = np.empty(0)  # running history of discriminator losses

# Optionally resume both networks from previously saved state dicts.
if args.retrain:
    g_net.load_state_dict(torch.load('../data/generator_state'))
    d_net.load_state_dict(torch.load('../data/discriminator_state'))

print("Beginning training..")
loader = ETL(args.batch_size, args.image_size, args.path)
for iteration in range(args.iterations):

    # Train discriminator
Example #26
0
def run_tensorflow():
    """Build and endlessly train a CycleGAN-style human<->anime translator.

    All TensorFlow state lives inside this function so it can be torn down
    and the GPU RAM freed when the function returns (original author's note).
    Never returns on its own: the training loop below is ``while True``.
    """

    # Enable memory growth so TF does not grab all GPU RAM up front.
    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpus:
        try:
            # Currently, memory growth needs to be the same across GPUs
            for gpu in gpus:
                tf.config.experimental.set_memory_growth(gpu, True)
            logical_gpus = tf.config.experimental.list_logical_devices('GPU')
            print(len(gpus), "Physical GPUs,", len(logical_gpus),
                  "Logical GPUs")
        except RuntimeError as e:
            # Memory growth must be set before GPUs have been initialized
            print(e)

    # policy = tf.keras.mixed_precision.experimental.Policy('mixed_float16')
    # tf.keras.mixed_precision.experimental.set_policy(policy)
    # print('Compute dtype: %s' % policy.compute_dtype)
    # print('Variable dtype: %s' % policy.variable_dtype)

    AnimeCleanData = getAnimeCleanData(BATCH_SIZE=32)
    CelebaData = getCelebaData()

    # Timestamped TensorBoard log directory for this run.
    logdir = "../logs/train_data/" + datetime.now().strftime("%Y%m%d-%H%M%S")
    file_writer = tf.summary.create_file_writer(logdir)

    # AnimeBatchImage = next(iter(AnimeCleanData))
    # CelebaBatchImage = next(iter(CelebaData))
    # print(image.dtype)

    # # checkpoint_path = "./checkpoints/train"

    # # ckpt = tf.train.Checkpoint(generator_to_anime=generator_to_anime,
    # #                            generator_to_human=generator_to_human,
    # #                            discriminator_x=discriminator_x,
    # #                            discriminator_y=discriminator_y,
    # #                            generator_to_anime_optimizer=generator_to_anime_optimizer,
    # #                            generator_to_human_optimizer=generator_to_human_optimizer,
    # #                            discriminator_x_optimizer=discriminator_x_optimizer,
    # #                            discriminator_y_optimizer=discriminator_y_optimizer)

    # # ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

    # # # if a checkpoint exists, restore the latest checkpoint.
    # # if ckpt_manager.latest_checkpoint:
    # #   ckpt.restore(ckpt_manager.latest_checkpoint)
    # #   print ('Latest checkpoint restored!!')

    # One Adam optimizer per network (standard CycleGAN settings).
    generator_to_anime_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
    generator_to_human_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

    discriminator_x_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
    discriminator_y_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

    # input: Batch, 256,256,3
    discriminator_x = Discriminator()
    discriminator_y = Discriminator()
    # out: Batch, 16, 16, 1

    generator_to_anime = Generator()
    generator_to_human = Generator()

    # x is human, y is anime
    @tf.function
    def trainstep(real_human, real_anime):
        """One CycleGAN update: returns generated images and all losses."""
        # persistent=True because tape.gradient is called four times below.
        with tf.GradientTape(persistent=True) as tape:

            fake_anime = generator_to_anime(real_human, training=True)
            cycled_human = generator_to_human(fake_anime, training=True)

            fake_human = generator_to_human(real_anime, training=True)
            cycled_anime = generator_to_anime(fake_human, training=True)

            # same_human and same_anime are used for identity loss.
            same_human = generator_to_human(real_human, training=True)
            same_anime = generator_to_anime(real_anime, training=True)

            disc_real_human = discriminator_x(real_human, training=True)
            disc_real_anime = discriminator_y(real_anime, training=True)

            disc_fake_human = discriminator_x(fake_human, training=True)
            disc_fake_anime = discriminator_y(fake_anime, training=True)

            # calculate the loss
            gen_anime_loss = generator_loss(disc_fake_anime)
            gen_human_loss = generator_loss(disc_fake_human)

            total_cycle_loss = cycle_loss(real_human,
                                          cycled_human) + cycle_loss(
                                              real_anime, cycled_anime)

            # Total generator loss = adversarial loss + cycle loss
            total_gen_anime_loss = gen_anime_loss + total_cycle_loss + identity_loss(
                real_anime, same_anime)
            total_gen_human_loss = gen_human_loss + total_cycle_loss + identity_loss(
                real_human, same_human)

            disc_x_loss = discriminator_loss(disc_real_human, disc_fake_human)
            disc_y_loss = discriminator_loss(disc_real_anime, disc_fake_anime)

        # Calculate the gradients for generator and discriminator
        generator_to_anime_gradients = tape.gradient(
            total_gen_anime_loss, generator_to_anime.trainable_variables)
        generator_to_human_gradients = tape.gradient(
            total_gen_human_loss, generator_to_human.trainable_variables)

        discriminator_x_gradients = tape.gradient(
            disc_x_loss, discriminator_x.trainable_variables)
        discriminator_y_gradients = tape.gradient(
            disc_y_loss, discriminator_y.trainable_variables)

        # Apply the gradients to the optimizer
        generator_to_anime_optimizer.apply_gradients(
            zip(generator_to_anime_gradients,
                generator_to_anime.trainable_variables))

        generator_to_human_optimizer.apply_gradients(
            zip(generator_to_human_gradients,
                generator_to_human.trainable_variables))

        discriminator_x_optimizer.apply_gradients(
            zip(discriminator_x_gradients,
                discriminator_x.trainable_variables))

        discriminator_y_optimizer.apply_gradients(
            zip(discriminator_y_gradients,
                discriminator_y.trainable_variables))

        return fake_anime, cycled_human, fake_human, cycled_anime , same_human , same_anime, \
            gen_anime_loss, gen_human_loss, disc_x_loss, disc_y_loss, total_gen_anime_loss, total_gen_human_loss

    # Infinite training loop; summaries are written every 5th iteration.
    counter = 0
    i = -1
    while True:
        i = i + 1
        counter = counter + 1
        # NOTE(review): iter(...) builds a fresh iterator every pass, so each
        # iteration draws the *first* batch of a new iterator; creating the
        # iterators once before the loop is likely what was intended -- confirm.
        AnimeBatchImage = next(iter(AnimeCleanData))
        CelebaBatchImage = next(iter(CelebaData))

        if not (i % 5):
            # Every 5th step: train AND log images/scalars to TensorBoard.
            fake_anime, cycled_human, fake_human, cycled_anime , same_human , same_anime, \
                gen_anime_loss, gen_human_loss, disc_x_loss, disc_y_loss, total_gen_anime_loss, total_gen_human_loss = trainstep(CelebaBatchImage, AnimeBatchImage)

            with file_writer.as_default():

                tf.summary.image("fake_anime", fake_anime, step=counter)
                tf.summary.image("cycled_human", cycled_human, step=counter)
                tf.summary.image("fake_human", fake_human, step=counter)
                tf.summary.image("cycled_anime", cycled_anime, step=counter)
                tf.summary.image("same_human", same_human, step=counter)
                tf.summary.image("same_anime", same_anime, step=counter)
                tf.summary.scalar("gen_anime_loss",
                                  gen_anime_loss,
                                  step=counter)
                tf.summary.scalar("gen_human_loss",
                                  gen_human_loss,
                                  step=counter)
                tf.summary.scalar("disc_x_loss", disc_x_loss, step=counter)
                tf.summary.scalar("disc_y_loss", disc_y_loss, step=counter)
                tf.summary.scalar("total_gen_anime_loss",
                                  total_gen_anime_loss,
                                  step=counter)
                tf.summary.scalar("total_gen_human_loss",
                                  total_gen_human_loss,
                                  step=counter)

                # tf.summary.image("CelebaBatchImage", CelebaBatchImage, step=counter)
        else:
            # Other steps: train without logging.
            trainstep(CelebaBatchImage, AnimeBatchImage)
 def __init__(self, g_hidden_size, d_hidden_size, char_list):
     """Build a character-level GAN over the alphabet in char_list.

     Args:
         g_hidden_size: hidden size passed to the Generator.
         d_hidden_size: hidden size passed to the Discriminator.
         char_list: sequence of characters; its length is the
             discriminator's input vocabulary size.
     """
     self.char_list = char_list
     self.generator = Generator(g_hidden_size, char_list)
     self.discriminator = Discriminator(len(char_list), d_hidden_size)
Example #28
0
class trainer(object):
    def __init__(self, cfg):
        """Set up models, losses, data loaders, optimizers and LR schedulers.

        If cfg.TRAIN.RESUME >= 0, all model and optimizer states are
        restored from '<OUTDIR>/checkpoints/<epoch>epoch.pth'.

        Args:
            cfg: configuration node with DATASET, LOSS, TRAIN and
                OPTIMIZER sections.
        """
        self.cfg = cfg
        # U-Net that re-predicts the label map from the one-hot old labels,
        # exposing intermediate features on its 'out' side.
        self.OldLabel_generator = U_Net(in_ch=cfg.DATASET.N_CLASS,
                                        out_ch=cfg.DATASET.N_CLASS,
                                        side='out')
        # U-Net that predicts the corrected label map from the RGB image,
        # consuming the features above on its 'in' side.
        self.Image_generator = U_Net(in_ch=3,
                                     out_ch=cfg.DATASET.N_CLASS,
                                     side='in')
        # Patch discriminator over the channel-concat of image + label map.
        self.discriminator = Discriminator(cfg.DATASET.N_CLASS + 3,
                                           cfg.DATASET.IMGSIZE,
                                           patch=True)

        self.criterion_G = GeneratorLoss(cfg.LOSS.LOSS_WEIGHT[0],
                                         cfg.LOSS.LOSS_WEIGHT[1],
                                         cfg.LOSS.LOSS_WEIGHT[2],
                                         ignore_index=cfg.LOSS.IGNORE_INDEX)
        self.criterion_D = DiscriminatorLoss()

        train_dataset = BaseDataset(cfg, split='train')
        valid_dataset = BaseDataset(cfg, split='val')
        self.train_dataloader = data.DataLoader(
            train_dataset,
            batch_size=cfg.DATASET.BATCHSIZE,
            num_workers=8,
            shuffle=True,
            drop_last=True)
        self.valid_dataloader = data.DataLoader(
            valid_dataset,
            batch_size=cfg.DATASET.BATCHSIZE,
            num_workers=8,
            shuffle=True,
            drop_last=True)

        # Output directories for checkpoints and validation visualizations.
        self.ckpt_outdir = os.path.join(cfg.TRAIN.OUTDIR, 'checkpoints')
        if not os.path.isdir(self.ckpt_outdir):
            os.mkdir(self.ckpt_outdir)
        self.val_outdir = os.path.join(cfg.TRAIN.OUTDIR, 'val')
        if not os.path.isdir(self.val_outdir):
            os.mkdir(self.val_outdir)
        self.start_epoch = cfg.TRAIN.RESUME  # < 0 means train from scratch
        self.n_epoch = cfg.TRAIN.N_EPOCH

        # One optimizer over both generators, one over the discriminator.
        self.optimizer_G = torch.optim.Adam(
            [{
                'params': self.OldLabel_generator.parameters()
            }, {
                'params': self.Image_generator.parameters()
            }],
            lr=cfg.OPTIMIZER.G_LR,
            betas=(cfg.OPTIMIZER.BETA1, cfg.OPTIMIZER.BETA2),
            weight_decay=cfg.OPTIMIZER.WEIGHT_DECAY)

        self.optimizer_D = torch.optim.Adam(
            [{
                'params': self.discriminator.parameters(),
                'initial_lr': cfg.OPTIMIZER.D_LR
            }],
            lr=cfg.OPTIMIZER.D_LR,
            betas=(cfg.OPTIMIZER.BETA1, cfg.OPTIMIZER.BETA2),
            weight_decay=cfg.OPTIMIZER.WEIGHT_DECAY)

        # Polynomial (power 0.9) decay over the total number of optimizer
        # steps (epochs * iterations per epoch), applied per iteration.
        iter_per_epoch = len(train_dataset) // cfg.DATASET.BATCHSIZE
        lambda_poly = lambda iters: pow(
            (1.0 - iters / (cfg.TRAIN.N_EPOCH * iter_per_epoch)), 0.9)
        self.scheduler_G = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_G,
            lr_lambda=lambda_poly,
        )
        self.scheduler_D = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_D,
            lr_lambda=lambda_poly,
        )

        self.logger = logger(cfg.TRAIN.OUTDIR, name='train')
        self.running_metrics = runningScore(n_classes=cfg.DATASET.N_CLASS)

        if self.start_epoch >= 0:
            # Read the checkpoint file once; the original re-ran torch.load
            # (a full file read + deserialization) five times on the same path.
            ckpt_path = os.path.join(cfg.TRAIN.OUTDIR, 'checkpoints',
                                     '{}epoch.pth'.format(self.start_epoch))
            checkpoint = torch.load(ckpt_path)
            self.OldLabel_generator.load_state_dict(checkpoint['model_G_N'])
            self.Image_generator.load_state_dict(checkpoint['model_G_I'])
            self.discriminator.load_state_dict(checkpoint['model_D'])
            self.optimizer_G.load_state_dict(checkpoint['optimizer_G'])
            self.optimizer_D.load_state_dict(checkpoint['optimizer_D'])

            log = "Using the {}th checkpoint".format(self.start_epoch)
            self.logger.info(log)
        self.Image_generator = self.Image_generator.cuda()
        self.OldLabel_generator = self.OldLabel_generator.cuda()
        self.discriminator = self.discriminator.cuda()
        self.criterion_G = self.criterion_G.cuda()
        self.criterion_D = self.criterion_D.cuda()

    def train(self):
        """Run adversarial training of the label-correction GAN.

        Per batch: (1) update the discriminator on a detached fake
        (image, corrected-label) pair versus the real (image, one-hot
        new_label) pair; (2) freeze the discriminator and update both
        generators through the combined generator loss. Every 10
        iterations the running losses are plotted to visdom; after each
        epoch the models are evaluated on the validation split, two
        sample predictions are written as a color map, and a checkpoint
        is saved under ``cfg.TRAIN.OUTDIR/checkpoints``.
        """
        # Histories used as the Y-series of the visdom loss/metric plots.
        all_train_iter_total_loss = []
        all_train_iter_corr_loss = []
        all_train_iter_recover_loss = []
        all_train_iter_change_loss = []
        all_train_iter_gan_loss_gen = []
        all_train_iter_gan_loss_dis = []
        all_val_epo_iou = []
        all_val_epo_acc = []
        iter_num = [0]   # X-axis of the training-loss plot
        epoch_num = []   # X-axis of the validation plot
        num_batches = len(self.train_dataloader)

        for epoch_i in range(self.start_epoch + 1, self.n_epoch):
            # Running averages, flushed to the plot every 10 iterations.
            iter_total_loss = AverageTracker()
            iter_corr_loss = AverageTracker()
            iter_recover_loss = AverageTracker()
            iter_change_loss = AverageTracker()
            iter_gan_loss_gen = AverageTracker()
            iter_gan_loss_dis = AverageTracker()
            batch_time = AverageTracker()
            tic = time.time()

            # train
            self.OldLabel_generator.train()
            self.Image_generator.train()
            self.discriminator.train()
            for i, meta in enumerate(self.train_dataloader):

                # meta: (image, old_label, new_label[, name parts]) per BaseDataset.
                image, old_label, new_label = meta[0].cuda(), meta[1].cuda(
                ), meta[2].cuda()
                recover_pred, feats = self.OldLabel_generator(
                    label2onehot(old_label, self.cfg.DATASET.N_CLASS))
                corr_pred = self.Image_generator(image, feats)

                # -------------------
                # Train Discriminator
                # -------------------
                self.discriminator.set_requires_grad(True)
                self.optimizer_D.zero_grad()

                # Detach so discriminator gradients do not reach the generators.
                fake_sample = torch.cat((image, corr_pred), 1).detach()
                # BUG FIX: was `cfg.DATASET.N_CLASS` (bare `cfg` is undefined
                # in this scope and would raise NameError); use self.cfg as
                # everywhere else in this class.
                real_sample = torch.cat(
                    (image, label2onehot(new_label, self.cfg.DATASET.N_CLASS)),
                    1)

                score_fake_d = self.discriminator(fake_sample)
                score_real = self.discriminator(real_sample)

                gan_loss_dis = self.criterion_D(pred_score=score_fake_d,
                                                real_score=score_real)
                gan_loss_dis.backward()
                self.optimizer_D.step()
                self.scheduler_D.step()

                # ---------------
                # Train Generator
                # ---------------
                self.discriminator.set_requires_grad(False)
                self.optimizer_G.zero_grad()

                # Re-score the fake sample WITHOUT detaching so the adversarial
                # gradient flows back into the generators.
                score_fake = self.discriminator(
                    torch.cat((image, corr_pred), 1))

                total_loss, corr_loss, recover_loss, change_loss, gan_loss_gen = self.criterion_G(
                    corr_pred, recover_pred, score_fake, old_label, new_label)

                total_loss.backward()
                self.optimizer_G.step()
                self.scheduler_G.step()

                iter_total_loss.update(total_loss.item())
                iter_corr_loss.update(corr_loss.item())
                iter_recover_loss.update(recover_loss.item())
                iter_change_loss.update(change_loss.item())
                iter_gan_loss_gen.update(gan_loss_gen.item())
                iter_gan_loss_dis.update(gan_loss_dis.item())
                batch_time.update(time.time() - tic)
                tic = time.time()

                log = '{}: Epoch: [{}][{}/{}], Time: {:.2f}, ' \
                      'Total Loss: {:.6f}, Corr Loss: {:.6f}, Recover Loss: {:.6f}, Change Loss: {:.6f}, GAN_G Loss: {:.6f}, GAN_D Loss: {:.6f}'.format(
                    datetime.now(), epoch_i, i, num_batches, batch_time.avg,
                    total_loss.item(), corr_loss.item(), recover_loss.item(), change_loss.item(), gan_loss_gen.item(), gan_loss_dis.item())
                print(log)

                # Flush the running averages to the visdom plot every 10 batches.
                if (i + 1) % 10 == 0:
                    all_train_iter_total_loss.append(iter_total_loss.avg)
                    all_train_iter_corr_loss.append(iter_corr_loss.avg)
                    all_train_iter_recover_loss.append(iter_recover_loss.avg)
                    all_train_iter_change_loss.append(iter_change_loss.avg)
                    all_train_iter_gan_loss_gen.append(iter_gan_loss_gen.avg)
                    all_train_iter_gan_loss_dis.append(iter_gan_loss_dis.avg)
                    iter_total_loss.reset()
                    iter_corr_loss.reset()
                    iter_recover_loss.reset()
                    iter_change_loss.reset()
                    iter_gan_loss_gen.reset()
                    iter_gan_loss_dis.reset()

                    vis.line(X=np.column_stack(
                        np.repeat(np.expand_dims(iter_num, 0), 6, axis=0)),
                             Y=np.column_stack((all_train_iter_total_loss,
                                                all_train_iter_corr_loss,
                                                all_train_iter_recover_loss,
                                                all_train_iter_change_loss,
                                                all_train_iter_gan_loss_gen,
                                                all_train_iter_gan_loss_dis)),
                             opts={
                                 'legend': [
                                     'total_loss', 'corr_loss', 'recover_loss',
                                     'change_loss', 'gan_loss_gen',
                                     'gan_loss_dis'
                                 ],
                                 'linecolor':
                                 np.array([[255, 0, 0], [0, 255, 0],
                                           [0, 0, 255], [255, 255, 0],
                                           [0, 255, 255], [255, 0, 255]]),
                                 'title':
                                 'Train loss of generator and discriminator'
                             },
                             win='Train loss of generator and discriminator')
                    iter_num.append(iter_num[-1] + 1)

            # eval
            self.OldLabel_generator.eval()
            self.Image_generator.eval()
            self.discriminator.eval()
            with torch.no_grad():
                for j, meta in enumerate(self.valid_dataloader):
                    image, old_label, new_label = meta[0].cuda(), meta[1].cuda(
                    ), meta[2].cuda()
                    recover_pred, feats = self.OldLabel_generator(
                        label2onehot(old_label, self.cfg.DATASET.N_CLASS))
                    corr_pred = self.Image_generator(image, feats)
                    preds = np.argmax(corr_pred.cpu().detach().numpy().copy(),
                                      axis=1)
                    target = new_label.cpu().detach().numpy().copy()
                    self.running_metrics.update(target, preds)

                    # Dump a visual sanity check for the first batch only.
                    if j == 0:
                        color_map1 = gen_color_map(preds[0, :]).astype(
                            np.uint8)
                        color_map2 = gen_color_map(preds[1, :]).astype(
                            np.uint8)
                        color_map = cv2.hconcat([color_map1, color_map2])
                        cv2.imwrite(
                            os.path.join(
                                self.val_outdir, '{}epoch*{}*{}.png'.format(
                                    epoch_i, meta[3][0], meta[3][1])),
                            color_map)

            # Note: the metric keys below (with embedded '\t') mirror the
            # running_metrics implementation; index [1] selects the
            # "changed" (positive) class.
            score = self.running_metrics.get_scores()
            oa = score['Overall Acc: \t']
            precision = score['Precision: \t'][1]
            recall = score['Recall: \t'][1]
            iou = score['Class IoU: \t'][1]
            miou = score['Mean IoU: \t']
            self.running_metrics.reset()

            epoch_num.append(epoch_i)
            all_val_epo_acc.append(oa)
            all_val_epo_iou.append(miou)
            vis.line(X=np.column_stack(
                np.repeat(np.expand_dims(epoch_num, 0), 2, axis=0)),
                     Y=np.column_stack((all_val_epo_acc, all_val_epo_iou)),
                     opts={
                         'legend':
                         ['val epoch Overall Acc', 'val epoch Mean IoU'],
                         'linecolor': np.array([[255, 0, 0], [0, 255, 0]]),
                         'title': 'Validate Accuracy and IoU'
                     },
                     win='validate Accuracy and IoU')

            log = '{}: Epoch Val: [{}], ACC: {:.2f}, Recall: {:.2f}, mIoU: {:.4f}' \
                .format(datetime.now(), epoch_i, oa, recall, miou)
            self.logger.info(log)

            # Checkpoint models + optimizers every epoch.
            state = {
                'epoch': epoch_i,
                "acc": oa,
                "recall": recall,
                "iou": miou,
                'model_G_N': self.OldLabel_generator.state_dict(),
                'model_G_I': self.Image_generator.state_dict(),
                'model_D': self.discriminator.state_dict(),
                'optimizer_G': self.optimizer_G.state_dict(),
                'optimizer_D': self.optimizer_D.state_dict()
            }
            save_path = os.path.join(self.cfg.TRAIN.OUTDIR, 'checkpoints',
                                     '{}epoch.pth'.format(epoch_i))
            torch.save(state, save_path)
class GAN:
    """Character-level GAN: a generator/discriminator pair trained with RMSprop."""

    # g_hidden_size: hidden-layer size of the generator
    # d_hidden_size: hidden-layer size of the discriminator
    # char_list: alphabet of characters the generator may emit
    def __init__(self, g_hidden_size, d_hidden_size, char_list):
        self.char_list = char_list
        self.generator = Generator(g_hidden_size, char_list)
        self.discriminator = Discriminator(len(char_list), d_hidden_size)

    def _report_accuracy(self, X_actual, seq_len, num_examples, noise, stage):
        # Regenerate a fresh fake batch and print the discriminator's
        # accuracy on real-vs-fake at this point in training.
        sample = self.generator.generate_tensor(seq_len, num_examples, noise)
        accuracy = self.discriminator.accuracy(X_actual, sample)
        print("accuracy %s generator training: " % stage, accuracy)

    # X_actual: input data from the dataset (not generated)
    # n_epochs: total rounds of alternating training
    # g_epochs / d_epochs: inner training epochs per round for each network
    # *_initial_lr, *_multiplier: RMSprop parameters per network
    # *_batch_size: batch sizes per network
    # num_displayed: when print_progress is True, cap on example words shown
    # (None displays all of them)
    def train(self,
              X_actual,
              seq_len,
              n_epochs,
              g_epochs,
              d_epochs,
              g_initial_lr,
              d_initial_lr,
              g_multiplier,
              d_multiplier,
              g_batch_size,
              d_batch_size,
              print_progress=False,
              num_displayed=None):
        """Alternate discriminator and generator RMSprop updates for n_epochs rounds."""
        num_examples = X_actual.shape[0]
        # TODO: resample the generator's noise input every epoch
        noise = np.random.randn(num_examples, self.generator.input_size)
        for _round in range(n_epochs):

            # Produce fake data, then fit the discriminator on real vs. fake.
            fake_batch = self.generator.generate_tensor(
                seq_len, num_examples, noise)
            self.discriminator.train_RMS(X_actual, fake_batch, d_epochs,
                                         d_initial_lr, d_multiplier,
                                         d_batch_size)

            if print_progress:
                self._report_accuracy(X_actual, seq_len, num_examples, noise,
                                      "before")

            # Fit the generator against the (now fixed) discriminator.
            self.generator.train_RMS(noise, seq_len, self.discriminator,
                                     g_epochs, 1, g_initial_lr, g_multiplier,
                                     g_batch_size)

            if print_progress:
                self._report_accuracy(X_actual, seq_len, num_examples, noise,
                                      "after")
                # Show a sample of what the generator currently produces.
                gen_text = self.generator.generate(seq_len, num_examples,
                                                   noise)
                shown = gen_text if num_displayed is None \
                    else gen_text[:num_displayed]
                for line in shown:
                    print(line)
Пример #30
0
from bottle import route, run, static_file, request, response
from os import getenv, path
import json

from discriminator import Discriminator


def relative_path(target_path):
    """Resolve *target_path* against this module's directory and normalize it."""
    base_dir = path.dirname(__file__)
    return path.normpath(path.join(base_dir, target_path))


@route('/')
def index():
    """Serve the single-page front end, explicitly marked as UTF-8 HTML."""
    response.content_type = 'text/html; charset=utf-8'
    app_root = relative_path('./')
    return static_file('index.html', root=app_root)


@route('/api/upload', method='POST')
def upload():
    """Accept an uploaded file and return the discriminator's prediction as JSON.

    Expects a multipart form field named 'upload'; responds 400 with an
    error body when the field is missing instead of crashing with a 500.
    """
    # Renamed local: it previously shadowed this view function's own name.
    upload_file = request.files.get('upload')
    if upload_file is None:
        # request.files.get returns None for a missing field; the old code
        # would raise AttributeError on .file and surface as a 500.
        response.status = 400
        return json.dumps({'error': "missing 'upload' file field"})
    result = discriminator.predict(upload_file.file)
    return json.dumps(result)


# The model is built once at import time and shared by all request handlers.
discriminator = Discriminator()

# os.getenv returns a str when PORT is set in the environment but the int
# default otherwise; coerce so the port type is consistent either way.
run(host='0.0.0.0', port=int(getenv('PORT', 8080)), debug=True)
Пример #31
0
        os.mkdir(args.save_dir)

    if not os.path.exists(args.samples_dir):
        os.mkdir(args.samples_dir)

    INPUT_SIZE = 784
    SAMPLE_SIZE = 80
    NUM_LABELS = 10
    train_dataset = datasets.MNIST(root='data',
        train=True,
        download=True,
        transform=transforms.ToTensor())
    train_loader = DataLoader(train_dataset, shuffle=True,
        batch_size=args.batch_size)

    model_d = Discriminator()
    model_g = Generator(args.nz)
    criterion = nn.BCELoss()
    input = torch.FloatTensor(args.batch_size, INPUT_SIZE)
    noise = torch.FloatTensor(args.batch_size, (args.nz))
    
    fixed_noise = torch.FloatTensor(SAMPLE_SIZE, args.nz).normal_(0,1)
    fixed_labels = torch.zeros(SAMPLE_SIZE, NUM_LABELS)
    for i in range(NUM_LABELS):
        for j in range(SAMPLE_SIZE // NUM_LABELS):
            fixed_labels[i*(SAMPLE_SIZE // NUM_LABELS) + j, i] = 1.0
    
    label = torch.FloatTensor(args.batch_size)
    one_hot_labels = torch.FloatTensor(args.batch_size, 10)
    if args.cuda:
        model_d.cuda()
logger = Logger(TENSORBOARD_DIRECTORY)

###################################################################################

# Fixed target vectors for the adversarial loss: ones for "real", zeros for
# "fake" (presumably consumed by a BCE-style criterion -- verify at the call
# site). Allocated once, on the GPU when CUDA is enabled.
if is_gpu_mode:
    ones_label = Variable(torch.ones(BATCH_SIZE).cuda())
    zeros_label = Variable(torch.zeros(BATCH_SIZE).cuda())
else:
    ones_label = Variable(torch.ones(BATCH_SIZE))
    zeros_label = Variable(torch.zeros(BATCH_SIZE))

if __name__ == "__main__":
    print 'main'  # NOTE(review): Python 2 print statement -- this script targets Python 2

    # Generator is a Tiramisu (FC-DenseNet) network; discriminator is a
    # project-local model.
    gen_model = Tiramisu()
    disc_model = Discriminator()

    if is_gpu_mode:
        gen_model.cuda()
        disc_model.cuda()
        # gen_model = torch.nn.DataParallel(gen_model).cuda()
        # disc_model = torch.nn.DataParallel(disc_model).cuda()

    # Separate Adam optimizers so the two networks can use different
    # learning rates.
    optimizer_gen = torch.optim.Adam(gen_model.parameters(), lr=LEARNING_RATE_GENERATOR)
    optimizer_disc = torch.optim.Adam(disc_model.parameters(), lr=LEARNING_RATE_DISCRIMINATOR)

    # read imgs
    image_buff_read_index = 0

    # pytorch style
    # NOTE(review): buffer is laid out NCHW but width is passed before
    # height -- confirm INPUT_IMAGE_WIDTH/HEIGHT ordering against data_loader.
    input_img = np.empty(shape=(BATCH_SIZE, 3, data_loader.INPUT_IMAGE_WIDTH, data_loader.INPUT_IMAGE_HEIGHT))
Пример #33
0
def main(unused_argv):
    """SeqGAN training driver (TensorFlow 1.x).

    Pipeline: draw "real" data from an oracle LSTM, MLE-pretrain the
    generator, pretrain the discriminator on real-vs-generated data, then
    alternate policy-gradient generator updates (rewards from Monte-Carlo
    rollouts scored by the discriminator) with discriminator updates,
    logging oracle NLL to save/experiment-log.txt.
    """
    config_train = training_config()
    config_gen = generator_config()
    config_dis = discriminator_config()

    np.random.seed(config_train.seed)

    # The code below assumes token id 0 is the fixed start token.
    assert config_train.start_token == 0
    gen_data_loader = Gen_Data_loader(config_gen.gen_batch_size)
    likelihood_data_loader = Gen_Data_loader(config_gen.gen_batch_size)
    dis_data_loader = Dis_dataloader(config_dis.dis_batch_size)

    generator = Generator(config=config_gen)
    generator.build()

    # Rollout policy used to complete partial sequences for reward estimation.
    rollout_gen = rollout(config=config_gen)

    #Build target LSTM
    target_params = pickle.load(open('save/target_params.pkl','rb'),encoding='iso-8859-1')
    target_lstm = TARGET_LSTM(config=config_gen, params=target_params) # The oracle model


    # Build discriminator
    discriminator = Discriminator(config=config_dis)
    discriminator.build_discriminator()


    # Build optimizer op for pretraining
    pretrained_optimizer = tf.train.AdamOptimizer(config_train.gen_learning_rate)
    var_pretrained = [v for v in tf.trainable_variables() if 'teller' in v.name]
    gradients, variables = zip(
        *pretrained_optimizer.compute_gradients(generator.pretrained_loss, var_list=var_pretrained))
    gradients, _ = tf.clip_by_global_norm(gradients, config_train.grad_clip)
    gen_pre_update = pretrained_optimizer.apply_gradients(zip(gradients, variables))

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    # Sample "real" training data from the oracle into the positive file.
    generate_samples(sess,target_lstm,config_train.batch_size,config_train.generated_num,config_train.positive_file)
    gen_data_loader.create_batches(config_train.positive_file)

    log = open('save/experiment-log.txt','w')
    print('Start pre-training generator....')

    log.write('pre-training...\n')

    # ---- MLE pretraining of the generator ----
    for epoch in range(config_train.pretrained_epoch_num):
        gen_data_loader.reset_pointer()
        for it in range(gen_data_loader.num_batch):
            batch = gen_data_loader.next_batch()
            _,g_loss = sess.run([gen_pre_update,generator.pretrained_loss],feed_dict={generator.input_seqs_pre:batch,
                                                                                      generator.input_seqs_mask:np.ones_like(batch)})

        if epoch % config_train.test_per_epoch == 0:
            # Evaluate: sample a batch of sequences from the generator,
            generate_samples(sess,generator,config_train.batch_size,config_train.generated_num,config_train.eval_file)
            # build a data loader over those generated sequences,
            likelihood_data_loader.create_batches(config_train.eval_file)
            # and score them with the oracle (cross-entropy NLL).
            test_loss = target_loss(sess,target_lstm,likelihood_data_loader)
            # Print and append to the experiment log.
            print('pre-train ',epoch, ' test_loss ',test_loss)
            buffer = 'epoch:\t' + str(epoch) + '\tnll:\t' + str(test_loss) + '\n'
            log.write(buffer)


    print('Start pre-training discriminator...')
    # NOTE(review): both this loop and the inner loop below iterate
    # dis_update_time_pre times, squaring the number of pretraining passes;
    # confirm the inner bound was not meant to be a separate epoch count.
    for t in range(config_train.dis_update_time_pre):
        print("Times: " + str(t))
        generate_samples(sess,generator,config_train.batch_size,config_train.generated_num,config_train.negative_file)
        dis_data_loader.load_train_data(config_train.positive_file,config_train.negative_file)
        for _ in range(config_train.dis_update_time_pre):
            dis_data_loader.reset_pointer()
            for it in range(dis_data_loader.num_batch):
                x_batch,y_batch = dis_data_loader.next_batch()
                feed_dict = {
                    discriminator.input_x : x_batch,
                    discriminator.input_y : y_batch,
                    discriminator.dropout_keep_prob : config_dis.dis_dropout_keep_prob
                }
                _ = sess.run(discriminator.train_op,feed_dict)



    # Build optimizer op for adversarial training
    train_adv_opt = tf.train.AdamOptimizer(config_train.gen_learning_rate)
    gradients, variables = zip(*train_adv_opt.compute_gradients(generator.gen_loss_adv, var_list=var_pretrained))
    gradients, _ = tf.clip_by_global_norm(gradients, config_train.grad_clip)
    train_adv_update = train_adv_opt.apply_gradients(zip(gradients, variables))

    # Initialize global variables of optimizer for adversarial training
    uninitialized_var = [e for e in tf.global_variables() if e not in tf.trainable_variables()]
    init_vars_uninit_op = tf.variables_initializer(uninitialized_var)
    sess.run(init_vars_uninit_op)

    # Start adversarial training
    for total_batch in range(config_train.total_batch):
        for iter_gen in range(config_train.gen_update_time):
            samples = sess.run(generator.sample_word_list_reshpae)

            feed = {'pred_seq_rollout:0':samples}
            reward_rollout = []
            # Monte-Carlo rollouts: estimate per-timestep rewards for each
            # sampled sequence from the discriminator's "real" probability.
            for iter_roll in range(config_train.rollout_num):
                rollout_list = sess.run(rollout_gen.sample_rollout_step,feed_dict=feed)
                # np.vstack stacks the rollout arrays vertically (row-wise).
                rollout_list_stack = np.vstack(rollout_list)
                reward_rollout_seq = sess.run(discriminator.ypred_for_auc,feed_dict={
                    discriminator.input_x:rollout_list_stack,discriminator.dropout_keep_prob:1.0
                })
                # Reward for the completed sequences (final timestep). This
                # feed is loop-invariant; it is recomputed every rollout.
                reward_last_tok = sess.run(discriminator.ypred_for_auc,feed_dict={
                    discriminator.input_x:samples,discriminator.dropout_keep_prob:1.0
                })
                reward_allseq = np.concatenate((reward_rollout_seq,reward_last_tok),axis=0)[:,1]
                reward_tmp = []
                for r in range(config_gen.gen_batch_size):
                    reward_tmp.append(reward_allseq[range(r,config_gen.gen_batch_size * config_gen.sequence_length,config_gen.gen_batch_size)])

                reward_rollout.append(np.array(reward_tmp))
                # NOTE(review): the reward averaging and the generator update
                # below sit INSIDE the rollout loop, so the generator is
                # updated once per rollout on partially averaged rewards;
                # reference SeqGAN implementations do this after the loop --
                # confirm whether this indentation is intentional.
                rewards = np.sum(reward_rollout,axis = 0) / config_train.rollout_num
                _,gen_loss = sess.run([train_adv_update,generator.gen_loss_adv],feed_dict={generator.input_seqs_adv:samples,
                                                                                           generator.rewards:rewards})


        if total_batch % config_train.test_per_epoch == 0 or total_batch == config_train.total_batch - 1:
            # Periodic oracle-NLL evaluation, same scheme as in pretraining.
            generate_samples(sess, generator, config_train.batch_size, config_train.generated_num, config_train.eval_file)
            likelihood_data_loader.create_batches(config_train.eval_file)
            test_loss = target_loss(sess, target_lstm, likelihood_data_loader)
            buffer = 'epoch:\t' + str(total_batch) + '\tnll:\t' + str(test_loss) + '\n'
            print ('total_batch: ', total_batch, 'test_loss: ', test_loss)
            log.write(buffer)


        # Discriminator updates on freshly generated negative samples.
        # NOTE(review): as in pretraining, dis_update_time_adv bounds both
        # the outer and the inner loop -- confirm intent.
        for _ in range(config_train.dis_update_time_adv):
            generate_samples(sess,generator,config_train.batch_size,config_train.generated_num,config_train.negative_file)
            dis_data_loader.load_train_data(config_train.positive_file,config_train.negative_file)

            for _ in range(config_train.dis_update_time_adv):
                dis_data_loader.reset_pointer()
                for it in range(dis_data_loader.num_batch):
                    x_batch,y_batch = dis_data_loader.next_batch()
                    feed = {
                        discriminator.input_x:x_batch,
                        discriminator.input_y:y_batch,
                        discriminator.dropout_keep_prob:config_dis.dis_dropout_keep_prob
                    }
                    _ = sess.run(discriminator.train_op,feed)

    log.close()