def train():
    """Train a VAE on MNIST and periodically visualize the latent space.

    Validation runs after every training epoch. Every 10 epochs the model
    is checkpointed and, when the latent code is 2-D, random samples over
    an interpolated latent grid and the learned 2-D latent space are saved.
    """
    FLAGS = get_args()
    # training dataflow: preprocessed, shuffled MNIST training split
    train_data = MNISTData('train',
                           data_dir=DATA_PATH,
                           shuffle=True,
                           pf=preprocess_im,
                           batch_dict_name=['im', 'label'])
    train_data.setup(epoch_val=0, batch_size=FLAGS.bsize)
    valid_data = MNISTData('test',
                           data_dir=DATA_PATH,
                           shuffle=True,
                           pf=preprocess_im,
                           batch_dict_name=['im', 'label'])
    valid_data.setup(epoch_val=0, batch_size=FLAGS.bsize)

    with tf.variable_scope('VAE') as scope:
        model = VAE(n_code=FLAGS.ncode, wd=0)
        model.create_train_model()

    # Generation model shares weights with the training model via scope reuse.
    with tf.variable_scope('VAE') as scope:
        scope.reuse_variables()
        valid_model = VAE(n_code=FLAGS.ncode, wd=0)
        valid_model.create_generate_model(b_size=400)

    trainer = Trainer(model,
                      valid_model,
                      train_data,
                      init_lr=FLAGS.lr,
                      save_path=SAVE_PATH)
    if FLAGS.ncode == 2:
        # 20 x 20 grid of interpolated codes -> 400 points in 2-D latent space
        z = distribution.interpolate(plot_size=20)
        z = np.reshape(z, (400, 2))
        visualizer = Visualizer(model, save_path=SAVE_PATH)
    else:
        z = None
    generator = Generator(generate_model=valid_model, save_path=SAVE_PATH)

    sessconfig = tf.ConfigProto()
    sessconfig.gpu_options.allow_growth = True
    with tf.Session(config=sessconfig) as sess:
        writer = tf.summary.FileWriter(SAVE_PATH)
        saver = tf.train.Saver()
        sess.run(tf.global_variables_initializer())
        writer.add_graph(sess.graph)

        for epoch_id in range(FLAGS.maxepoch):
            trainer.train_epoch(sess, summary_writer=writer)
            trainer.valid_epoch(sess, summary_writer=writer)
            if epoch_id % 10 == 0:
                saver.save(sess, '{}vae-epoch-{}'.format(SAVE_PATH, epoch_id))
                if FLAGS.ncode == 2:
                    generator.generate_samples(sess,
                                               plot_size=20,
                                               z=z,
                                               file_id=epoch_id)
                    visualizer.viz_2Dlatent_variable(sess,
                                                     valid_data,
                                                     file_id=epoch_id)
        # Fix: checkpoint the final state and flush summaries. Previously
        # epochs after the last `% 10 == 0` boundary were never saved and
        # the FileWriter was never closed, risking unflushed summaries.
        if FLAGS.maxepoch > 0:
            saver.save(sess, '{}vae-epoch-{}'.format(SAVE_PATH, epoch_id))
        writer.close()
def semisupervised_train():
    """ Function for semisupervised training (Fig 8 in the paper)

    Validation will be processed after each epoch of training.
    Loss of each modules will be averaged and saved in summaries
    every 100 steps.
    """

    FLAGS = get_args()
    # load dataset
    # Unlabeled stream draws from the full training set; the labeled
    # stream is capped at 1280 samples to simulate the semisupervised
    # setting described in the paper.
    train_data_unlabel = read_train_data(FLAGS.bsize)
    train_data_label = read_train_data(FLAGS.bsize, n_use_sample=1280)
    train_data = {'unlabeled': train_data_unlabel, 'labeled': train_data_label}
    valid_data = read_valid_data(FLAGS.bsize)

    # create an AAE model for semisupervised training
    train_model = AAE(n_code=FLAGS.ncode,
                      wd=0,
                      n_class=10,
                      add_noise=FLAGS.noise,
                      enc_weight=FLAGS.encw,
                      gen_weight=FLAGS.genw,
                      dis_weight=FLAGS.disw,
                      cat_dis_weight=FLAGS.ydisw,
                      cat_gen_weight=FLAGS.ygenw,
                      cls_weight=FLAGS.clsw)
    train_model.create_semisupervised_train_model()

    # create an separated AAE model for semisupervised validation
    # shared weights with training model
    cls_valid_model = AAE(n_code=FLAGS.ncode, n_class=10)
    cls_valid_model.create_semisupervised_test_model()

    # initialize a trainer for training
    trainer = Trainer(train_model,
                      cls_valid_model=cls_valid_model,
                      generate_model=None,
                      train_data=train_data,
                      init_lr=FLAGS.lr,
                      save_path=SAVE_PATH)

    sessconfig = tf.ConfigProto()
    sessconfig.gpu_options.allow_growth = True
    with tf.Session(config=sessconfig) as sess:
        writer = tf.summary.FileWriter(SAVE_PATH)
        sess.run(tf.global_variables_initializer())
        writer.add_graph(sess.graph)
        for epoch_id in range(FLAGS.maxepoch):
            # one semisupervised epoch over both labeled/unlabeled streams
            trainer.train_semisupervised_epoch(sess,
                                               ae_dropout=FLAGS.dropout,
                                               summary_writer=writer)
            # classification validation after each training epoch
            trainer.valid_semisupervised_epoch(sess,
                                               valid_data,
                                               summary_writer=writer)
        # NOTE(review): no tf.train.Saver is created here, so the trained
        # model is never checkpointed and the FileWriter is never closed;
        # confirm this is intended.
# --- Esempio n. 3 ---
def supervised_train():
    """ Function for supervised training (Fig 6 in the paper)

    Validation will be processed after each epoch of training.
    Loss of each modules will be averaged and saved in summaries
    every 100 steps. Every 10 epochs, 10 different style for 10 digits
    will be saved.
    """

    FLAGS = get_args()
    # load dataset
    train_data = read_train_data(FLAGS.bsize)
    valid_data = read_valid_data(FLAGS.bsize)

    # create an AAE model for supervised training
    model = AAE(n_code=FLAGS.ncode, wd=0, n_class=10, 
                use_supervise=True, add_noise=FLAGS.noise,
                enc_weight=FLAGS.encw, gen_weight=FLAGS.genw,
                dis_weight=FLAGS.disw)
    model.create_train_model()

    # Create an separated AAE model for supervised validation
    # shared weights with training model. This model is used to
    # generate 10 different style for 10 digits for every 10 epochs.
    valid_model = AAE(n_code=FLAGS.ncode, use_supervise=True, n_class=10)
    valid_model.create_generate_style_model(n_sample=10)

    # initialize a trainer for training
    trainer = Trainer(model, valid_model, train_data,
                      init_lr=FLAGS.lr, save_path=SAVE_PATH)
    # initialize a generator for generating style images
    generator = Generator(
        generate_model=valid_model, save_path=SAVE_PATH, n_labels=10)

    sessconfig = tf.ConfigProto()
    sessconfig.gpu_options.allow_growth = True
    with tf.Session(config=sessconfig) as sess:
        writer = tf.summary.FileWriter(SAVE_PATH)
        saver = tf.train.Saver()
        sess.run(tf.global_variables_initializer())
        writer.add_graph(sess.graph)

        for epoch_id in range(FLAGS.maxepoch):
            # one epoch of adversarial autoencoder training (z-GAN part)
            trainer.train_z_gan_epoch(
                sess, ae_dropout=FLAGS.dropout, summary_writer=writer)
            trainer.valid_epoch(sess, dataflow=valid_data, summary_writer=writer)
            
            if epoch_id % 10 == 0:
                # checkpoint and dump a 10-style x 10-digit sample grid
                saver.save(sess, '{}aae-epoch-{}'.format(SAVE_PATH, epoch_id))
                generator.sample_style(sess, valid_data, plot_size=10,
                                       file_id=epoch_id, n_sample=10)
        # final checkpoint after the last epoch
        # NOTE(review): raises NameError if FLAGS.maxepoch == 0 because
        # epoch_id is unbound; confirm maxepoch is always >= 1.
        saver.save(sess, '{}aae-epoch-{}'.format(SAVE_PATH, epoch_id))
# --- Esempio n. 4 ---
def train():
    """Train GoogLeNet on CIFAR-10, optionally fine-tuning ImageNet weights."""
    FLAGS = get_args()
    # Dataflow objects for the training and testing splits.
    train_data, valid_data = loader.load_cifar(cifar_path=DATA_PATH,
                                               batch_size=FLAGS.bsize,
                                               subtract_mean=True)

    # Convolutional layers start from ImageNet weights only when fine-tuning.
    pre_trained_path = PRETRINED_PATH if FLAGS.finetune else None

    # Network built in training mode.
    net_train = GoogLeNet_cifar(n_channel=3,
                                n_class=10,
                                pre_trained_path=pre_trained_path,
                                bn=True,
                                wd=0,
                                sub_imagenet_mean=False,
                                conv_trainable=True,
                                fc_trainable=True)
    net_train.create_train_model()

    # Weight-sharing network built in inference mode for validation.
    net_eval = GoogLeNet_cifar(n_channel=3,
                               n_class=10,
                               bn=True,
                               sub_imagenet_mean=False)
    net_eval.create_test_model()

    # Trainer object drives the optimization loop.
    trainer = Trainer(net_train, net_eval, train_data, init_lr=FLAGS.lr)

    with tf.Session() as sess:
        summary_writer = tf.summary.FileWriter(SAVE_PATH)
        ckpt_saver = tf.train.Saver()
        sess.run(tf.global_variables_initializer())
        summary_writer.add_graph(sess.graph)
        for epoch_id in range(FLAGS.maxepoch):
            # One pass over the training set.
            trainer.train_epoch(sess,
                                keep_prob=FLAGS.keep_prob,
                                summary_writer=summary_writer)
            # Evaluate on the held-out split after each epoch.
            trainer.valid_epoch(sess,
                                dataflow=valid_data,
                                summary_writer=summary_writer)
            ckpt_saver.save(
                sess, '{}inception-cifar-epoch-{}'.format(SAVE_PATH, epoch_id))
        ckpt_saver.save(sess,
                        '{}inception-cifar-epoch-{}'.format(SAVE_PATH,
                                                            epoch_id))
        summary_writer.close()
# --- Esempio n. 5 ---
def train():
    """Train VGG on CIFAR-10 from scratch, validating after every epoch.

    A checkpoint is written to SAVE_PATH after each epoch; training and
    validation summaries are written to the same directory.
    """
    FLAGS = get_args()
    # dataflows for the training/testing splits (mean-subtracted images)
    train_data, valid_data = loader.load_cifar(cifar_path=DATA_PATH,
                                               batch_size=FLAGS.bsize,
                                               substract_mean=True)

    # training model: batch norm + weight decay, trainable weights
    train_model = VGG_CIFAR10(n_channel=3,
                              n_class=10,
                              pre_trained_path=None,
                              bn=True,
                              wd=5e-3,
                              trainable=True,
                              sub_vgg_mean=False)
    train_model.create_train_model()

    # weight-sharing model used for evaluation only
    valid_model = VGG_CIFAR10(n_channel=3,
                              n_class=10,
                              bn=True,
                              sub_vgg_mean=False)
    valid_model.create_test_model()

    trainer = Trainer(train_model, valid_model, train_data, init_lr=FLAGS.lr)

    with tf.Session() as sess:
        writer = tf.summary.FileWriter(SAVE_PATH)
        saver = tf.train.Saver()
        sess.run(tf.global_variables_initializer())
        writer.add_graph(sess.graph)
        for epoch_id in range(FLAGS.maxepoch):
            trainer.train_epoch(sess,
                                keep_prob=FLAGS.keep_prob,
                                summary_writer=writer)
            trainer.valid_epoch(sess,
                                dataflow=valid_data,
                                summary_writer=writer)
            saver.save(sess,
                       '{}vgg-cifar-epoch-{}'.format(SAVE_PATH, epoch_id))
        # Fix: the original re-saved the last epoch's checkpoint a second
        # time to the identical path (pure redundancy, already saved in
        # the loop) and never closed the FileWriter, so buffered summaries
        # could be lost. Close it to flush.
        writer.close()
def train():
    """ Function for unsupervised training and incorporate
        label info in adversarial regularization
        (Fig 1 and 3 in the paper)

    Validation will be processed after each epoch of training.
    Loss of each modules will be averaged and saved in summaries
    every 100 steps. Random samples and learned latent space will
    be saved for every 10 epochs.
    """

    FLAGS = get_args()
    # image size for visualization. plot_size * plot_size digits will be visualized.
    plot_size = 20

    # Use 10000 labels info to train latent space
    n_use_label = 10000
    n_use_sample = 50000
    # load data
    train_data = read_train_data(FLAGS.bsize,
                                 n_use_label=n_use_label,
                                 n_use_sample=n_use_sample)
    valid_data = read_valid_data(FLAGS.bsize)

    # create an AAE model for training
    model = AAE(n_code=FLAGS.ncode,
                wd=0,
                n_class=10,
                use_label=FLAGS.label,
                add_noise=FLAGS.noise,
                enc_weight=FLAGS.encw,
                gen_weight=FLAGS.genw,
                dis_weight=FLAGS.disw)
    model.create_train_model()

    # Create an separated AAE model for validation shared weights
    # with training model. This model is used to
    # randomly sample model data every 10 epoches.
    valid_model = AAE(n_code=FLAGS.ncode, n_class=10)
    valid_model.create_generate_model(b_size=400)

    # initialize a trainer for training
    trainer = Trainer(model,
                      valid_model,
                      train_data,
                      distr_type=FLAGS.dist_type,
                      use_label=FLAGS.label,
                      init_lr=FLAGS.lr,
                      save_path=SAVE_PATH)
    # Initialize a visualizer and a generator to monitor learned
    # latent space and data generation.
    # Latent space visualization only for code dim = 2
    if FLAGS.ncode == 2:
        visualizer = Visualizer(model, save_path=SAVE_PATH)
    generator = Generator(generate_model=valid_model,
                          save_path=SAVE_PATH,
                          distr_type=FLAGS.dist_type,
                          n_labels=10,
                          use_label=FLAGS.label)

    sessconfig = tf.ConfigProto()
    sessconfig.gpu_options.allow_growth = True
    with tf.Session(config=sessconfig) as sess:
        writer = tf.summary.FileWriter(SAVE_PATH)
        saver = tf.train.Saver()
        sess.run(tf.global_variables_initializer())
        writer.add_graph(sess.graph)

        for epoch_id in range(FLAGS.maxepoch):
            # one epoch of AAE training (reconstruction + z-space GAN)
            trainer.train_z_gan_epoch(sess,
                                      ae_dropout=FLAGS.dropout,
                                      summary_writer=writer)
            trainer.valid_epoch(sess,
                                dataflow=valid_data,
                                summary_writer=writer)

            if epoch_id % 10 == 0:
                # checkpoint and dump random samples from the generator
                saver.save(sess, '{}aae-epoch-{}'.format(SAVE_PATH, epoch_id))
                generator.generate_samples(sess,
                                           plot_size=plot_size,
                                           file_id=epoch_id)
                if FLAGS.ncode == 2:
                    # visualize the learned 2-D latent space
                    # (visualizer only exists under this same guard above)
                    visualizer.viz_2Dlatent_variable(sess,
                                                     valid_data,
                                                     file_id=epoch_id)
        # final checkpoint after the last epoch
        saver.save(sess, '{}aae-epoch-{}'.format(SAVE_PATH, epoch_id))
# --- Esempio n. 7 ---
def train():
    """Train VGG on CIFAR-10 under an external GPU scheduler.

    After every epoch the process connects to a scheduler service on
    port 5555 and reports progress. When the scheduler operation is
    'g' (grow) or 's' (shrink), the current model is checkpointed, a
    restart command carrying the updated GPU set is sent back, and this
    process exits so the scheduler can relaunch it on the new GPUs.
    """
    FLAGS = get_args()
    train_data, valid_data = loader.load_cifar(cifar_path=FLAGS.data_path,
                                               batch_size=FLAGS.bsize,
                                               substract_mean=True)

    # training model: batch norm + weight decay
    train_model = VGG_CIFAR10(n_channel=3,
                              n_class=10,
                              pre_trained_path=None,
                              bn=True,
                              wd=5e-3,
                              trainable=True,
                              sub_vgg_mean=False)
    train_model.create_train_model()

    # weight-sharing model used for validation
    valid_model = VGG_CIFAR10(n_channel=3,
                              n_class=10,
                              bn=True,
                              sub_vgg_mean=False)
    valid_model.create_test_model()

    trainer = Trainer(train_model, valid_model, train_data, init_lr=FLAGS.lr)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    with tf.Session(config=config) as sess:
        writer = tf.compat.v1.summary.FileWriter(FLAGS.save_path)
        saver = tf.compat.v1.train.Saver()
        # Fix: initialize variables BEFORE restoring a checkpoint. The
        # original restored first and then ran the initializer, which
        # overwrote every restored weight with fresh random values.
        sess.run(tf.global_variables_initializer())
        if FLAGS.saved_model != '':
            saver.restore(sess, FLAGS.saved_model)

        writer.add_graph(sess.graph)
        for epoch_id in range(FLAGS.maxepoch):
            trainer.train_epoch(sess,
                                keep_prob=FLAGS.keep_prob,
                                summary_writer=writer)
            trainer.valid_epoch(sess,
                                dataflow=valid_data,
                                summary_writer=writer)

            # connection part: report to the scheduler after each epoch
            msg = {}
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            server_address = ('0.0.0.0', 5555)
            sock.connect(server_address)
            try:
                # NOTE(review): scheduler_op and target_gpus are not
                # defined in this function -- presumably module-level
                # state set elsewhere; confirm.
                # Fix: `scheduler_op == 'g' or 's'` was always truthy
                # because the non-empty string 's' is true on its own.
                if scheduler_op in ('g', 's'):
                    saved_model = '{}vgg-cifar-epoch-{}'.format(
                        FLAGS.save_path, epoch_id)
                    saver.save(sess, saved_model)
                    # generate restart command
                    if scheduler_op == 'g':
                        # grow: append the newly granted GPUs
                        sch_gpus = FLAGS.gpus.split(',') + target_gpus
                    else:
                        # shrink: drop the revoked GPUs. Fix: the original
                        # used list subtraction, which raises TypeError.
                        sch_gpus = [g for g in FLAGS.gpus.split(',')
                                    if g not in target_gpus]

                    current_gpus = ','.join(sch_gpus)

                    config_command = '--train --port ' + str(FLAGS.port) + ' --data_path ' + str(FLAGS.data_path) +\
                                     ' --saved_model ' + saved_model + ' --save_path ' + FLAGS.save_path + ' --lr ' +\
                                     str(FLAGS.lr) + ' --bsize ' + str(FLAGS.bsize) + ' --keep_prob ' +\
                                     str(FLAGS.keep_prob) + ' --maxepoch ' + str(FLAGS.maxepoch-epoch_id-1) +\
                                     ' --gpus ' + current_gpus
                    # send command back to the scheduler
                    msg['config'] = config_command
                    msg['id'] = FLAGS.id
                    msg['ep'] = epoch_id + 1
                    msg['gpus'] = sch_gpus
                    sock.sendall(dict_to_binary(msg))
                    # leave so the scheduler can restart this job
                    exit()
                else:
                    msg['id'] = FLAGS.id
                    msg['ep'] = epoch_id + 1
                    sock.sendall(dict_to_binary(msg))
            finally:
                # Fix: the original leaked one socket per epoch.
                sock.close()

        saver.save(sess, '{}vgg-cifar-epoch-{}'.format(FLAGS.save_path,
                                                       epoch_id))
# --- Esempio n. 8 ---
def train_type_1():
    """Train a DCGAN or LSGAN on MNIST/CelebA and sample images each epoch."""
    FLAGS = get_args()

    # choose the GAN variant by name
    gan_classes = {'lsgan': LSGAN, 'dcgan': DCGAN}
    if FLAGS.gan_type not in gan_classes:
        raise ValueError('Wrong GAN type!')
    gan_model = gan_classes[FLAGS.gan_type]
    print('**** {} ****'.format(FLAGS.gan_type.upper()))

    save_path = os.path.join(SAVE_PATH, FLAGS.gan_type) + '/'

    # dataset-specific dataflow and image geometry
    if FLAGS.dataset == 'celeba':
        train_data = loader.load_celeba(FLAGS.bsize, data_path=CELEBA_PATH)
        im_size, n_channels = 64, 3
    else:
        train_data = loader.load_mnist(FLAGS.bsize, data_path=MNIST_PATH)
        im_size, n_channels = 28, 1

    # model used for training updates
    train_model = gan_model(input_len=FLAGS.zlen,
                            im_size=im_size,
                            n_channels=n_channels)
    train_model.create_train_model()

    # weight-sharing model used only for sampling
    generate_model = gan_model(input_len=FLAGS.zlen,
                               im_size=im_size,
                               n_channels=n_channels)
    generate_model.create_generate_model()

    # trainer drives the alternating G/D updates
    trainer = Trainer(train_model,
                      train_data,
                      moniter_gradient=False,
                      init_lr=FLAGS.lr,
                      save_path=save_path)
    # generator produces sample grids from the sampling model
    generator = Generator(generate_model,
                          keep_prob=FLAGS.keep_prob,
                          save_path=save_path)

    sess_config = tf.ConfigProto()
    sess_config.gpu_options.allow_growth = True
    with tf.Session(config=sess_config) as sess:
        summary_writer = tf.summary.FileWriter(save_path)
        ckpt_saver = tf.train.Saver()
        sess.run(tf.global_variables_initializer())
        summary_writer.add_graph(sess.graph)
        for epoch_id in range(FLAGS.maxepoch):
            trainer.train_epoch(sess,
                                keep_prob=FLAGS.keep_prob,
                                n_g_train=FLAGS.ng,
                                n_d_train=FLAGS.nd,
                                summary_writer=summary_writer)
            # monitor generation quality after every epoch
            generator.random_sampling(sess, plot_size=10, file_id=epoch_id)
            generator.viz_interpolate(sess, file_id=epoch_id)
            if FLAGS.zlen == 2:
                generator.viz_2D_manifold(sess, plot_size=20, file_id=epoch_id)

            ckpt_saver.save(sess, '{}gan-{}-epoch-{}'.format(
                save_path, FLAGS.gan_type, epoch_id))
        ckpt_saver.save(sess, '{}gan-{}-epoch-{}'.format(
            save_path, FLAGS.gan_type, epoch_id))