def train():
    FLAGS = get_args()

    train_data = MNISTData('train',
                           data_dir=DATA_PATH,
                           shuffle=True,
                           pf=preprocess_im,
                           batch_dict_name=['im', 'label'])
    train_data.setup(epoch_val=0, batch_size=FLAGS.bsize)
    valid_data = MNISTData('test',
                           data_dir=DATA_PATH,
                           shuffle=True,
                           pf=preprocess_im,
                           batch_dict_name=['im', 'label'])
    valid_data.setup(epoch_val=0, batch_size=FLAGS.bsize)

    with tf.variable_scope('VAE') as scope:
        model = VAE(n_code=FLAGS.ncode, wd=0)
        model.create_train_model()

    with tf.variable_scope('VAE') as scope:
        scope.reuse_variables()
        valid_model = VAE(n_code=FLAGS.ncode, wd=0)
        valid_model.create_generate_model(b_size=400)

    trainer = Trainer(model, valid_model, train_data,
                      init_lr=FLAGS.lr, save_path=SAVE_PATH)
    if FLAGS.ncode == 2:
        z = distribution.interpolate(plot_size=20)
        z = np.reshape(z, (400, 2))
        visualizer = Visualizer(model, save_path=SAVE_PATH)
    else:
        z = None
    generator = Generator(generate_model=valid_model, save_path=SAVE_PATH)

    sessconfig = tf.ConfigProto()
    sessconfig.gpu_options.allow_growth = True
    with tf.Session(config=sessconfig) as sess:
        writer = tf.summary.FileWriter(SAVE_PATH)
        saver = tf.train.Saver()
        sess.run(tf.global_variables_initializer())
        writer.add_graph(sess.graph)
        for epoch_id in range(FLAGS.maxepoch):
            trainer.train_epoch(sess, summary_writer=writer)
            trainer.valid_epoch(sess, summary_writer=writer)
            if epoch_id % 10 == 0:
                saver.save(sess, '{}vae-epoch-{}'.format(SAVE_PATH, epoch_id))
                if FLAGS.ncode == 2:
                    generator.generate_samples(sess, plot_size=20, z=z,
                                               file_id=epoch_id)
                    visualizer.viz_2Dlatent_variable(sess, valid_data,
                                                     file_id=epoch_id)
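
# The call to distribution.interpolate above builds the 2-D latent grid that
# is reshaped to (400, 2). Its implementation is not shown here; the helper
# below is a minimal sketch of such a grid (the name, signature, and value
# range are assumptions, not the repo's code).
def interpolate_2d_grid(plot_size=20, code_range=3.0):
    """Return a (plot_size, plot_size, 2) lattice of 2-D latent codes."""
    side = np.linspace(-code_range, code_range, plot_size)
    xv, yv = np.meshgrid(side, side)
    # Callers can reshape the result to (plot_size ** 2, 2), matching
    # the np.reshape(z, (400, 2)) above for plot_size = 20.
    return np.stack([xv, yv], axis=-1)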
def semisupervised_train():
    """ Function for semisupervised training (Fig 8 in the paper)

        Validation will be processed after each epoch of training.
        The loss of each module will be averaged and saved in summaries
        every 100 steps.
    """
    FLAGS = get_args()

    # load dataset
    train_data_unlabel = read_train_data(FLAGS.bsize)
    train_data_label = read_train_data(FLAGS.bsize, n_use_sample=1280)
    train_data = {'unlabeled': train_data_unlabel,
                  'labeled': train_data_label}
    valid_data = read_valid_data(FLAGS.bsize)

    # create an AAE model for semisupervised training
    train_model = AAE(n_code=FLAGS.ncode, wd=0, n_class=10,
                      add_noise=FLAGS.noise, enc_weight=FLAGS.encw,
                      gen_weight=FLAGS.genw, dis_weight=FLAGS.disw,
                      cat_dis_weight=FLAGS.ydisw, cat_gen_weight=FLAGS.ygenw,
                      cls_weight=FLAGS.clsw)
    train_model.create_semisupervised_train_model()

    # create a separate AAE model for semisupervised validation
    # that shares weights with the training model
    cls_valid_model = AAE(n_code=FLAGS.ncode, n_class=10)
    cls_valid_model.create_semisupervised_test_model()

    # initialize a trainer for training
    trainer = Trainer(train_model, cls_valid_model=cls_valid_model,
                      generate_model=None, train_data=train_data,
                      init_lr=FLAGS.lr, save_path=SAVE_PATH)

    sessconfig = tf.ConfigProto()
    sessconfig.gpu_options.allow_growth = True
    with tf.Session(config=sessconfig) as sess:
        writer = tf.summary.FileWriter(SAVE_PATH)
        sess.run(tf.global_variables_initializer())
        writer.add_graph(sess.graph)
        for epoch_id in range(FLAGS.maxepoch):
            trainer.train_semisupervised_epoch(sess,
                                               ae_dropout=FLAGS.dropout,
                                               summary_writer=writer)
            trainer.valid_semisupervised_epoch(sess, valid_data,
                                               summary_writer=writer)
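
# get_args() is defined elsewhere in the repo. Below is a minimal sketch of
# what it could look like, inferred from the FLAGS fields used above (the
# flag names match the code; the default values are assumptions).
import argparse

def get_args_sketch():
    parser = argparse.ArgumentParser()
    parser.add_argument('--bsize', type=int, default=128, help='batch size')
    parser.add_argument('--lr', type=float, default=2e-4, help='initial learning rate')
    parser.add_argument('--ncode', type=int, default=2, help='dimension of the latent code')
    parser.add_argument('--maxepoch', type=int, default=100, help='number of training epochs')
    parser.add_argument('--dropout', type=float, default=1.0, help='keep prob of the autoencoder')
    parser.add_argument('--noise', action='store_true', help='add noise to the encoder input')
    parser.add_argument('--encw', type=float, default=1.0, help='encoder loss weight')
    parser.add_argument('--genw', type=float, default=6.0, help='z generator loss weight')
    parser.add_argument('--disw', type=float, default=6.0, help='z discriminator loss weight')
    parser.add_argument('--ygenw', type=float, default=6.0, help='y generator loss weight')
    parser.add_argument('--ydisw', type=float, default=6.0, help='y discriminator loss weight')
    parser.add_argument('--clsw', type=float, default=1.0, help='classification loss weight')
    return parser.parse_args()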
def supervised_train():
    """ Function for supervised training (Fig 6 in the paper)

        Validation will be processed after each epoch of training.
        The loss of each module will be averaged and saved in summaries
        every 100 steps. Every 10 epochs, 10 different styles for each
        of the 10 digits will be saved.
    """
    FLAGS = get_args()

    # load dataset
    train_data = read_train_data(FLAGS.bsize)
    valid_data = read_valid_data(FLAGS.bsize)

    # create an AAE model for supervised training
    model = AAE(n_code=FLAGS.ncode, wd=0, n_class=10, use_supervise=True,
                add_noise=FLAGS.noise, enc_weight=FLAGS.encw,
                gen_weight=FLAGS.genw, dis_weight=FLAGS.disw)
    model.create_train_model()

    # Create a separate AAE model for supervised validation that shares
    # weights with the training model. This model is used to generate
    # 10 different styles for each of the 10 digits every 10 epochs.
    valid_model = AAE(n_code=FLAGS.ncode, use_supervise=True, n_class=10)
    valid_model.create_generate_style_model(n_sample=10)

    # initialize a trainer for training
    trainer = Trainer(model, valid_model, train_data,
                      init_lr=FLAGS.lr, save_path=SAVE_PATH)
    # initialize a generator for generating style images
    generator = Generator(
        generate_model=valid_model, save_path=SAVE_PATH, n_labels=10)

    sessconfig = tf.ConfigProto()
    sessconfig.gpu_options.allow_growth = True
    with tf.Session(config=sessconfig) as sess:
        writer = tf.summary.FileWriter(SAVE_PATH)
        saver = tf.train.Saver()
        sess.run(tf.global_variables_initializer())
        writer.add_graph(sess.graph)
        for epoch_id in range(FLAGS.maxepoch):
            trainer.train_z_gan_epoch(
                sess, ae_dropout=FLAGS.dropout, summary_writer=writer)
            trainer.valid_epoch(sess, dataflow=valid_data,
                                summary_writer=writer)
            if epoch_id % 10 == 0:
                saver.save(sess, '{}aae-epoch-{}'.format(SAVE_PATH, epoch_id))
                generator.sample_style(sess, valid_data, plot_size=10,
                                       file_id=epoch_id, n_sample=10)
        saver.save(sess, '{}aae-epoch-{}'.format(SAVE_PATH, epoch_id))
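
# create_generate_style_model(n_sample=10) and generator.sample_style produce
# a 10 x 10 grid in which row i shows n_sample styles of digit i. A sketch of
# the one-hot label batch such a grid needs (helper name and layout are
# assumptions, not the repo's code):
def style_label_batch(n_class=10, n_sample=10):
    # Repeat each digit label n_sample times (0 x10, 1 x10, ..., 9 x10),
    # then one-hot encode to shape (n_class * n_sample, n_class).
    labels = np.repeat(np.arange(n_class), n_sample)
    return np.eye(n_class)[labels]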
def train():
    FLAGS = get_args()

    # Create Dataflow objects for the training and testing sets
    train_data, valid_data = loader.load_cifar(
        cifar_path=DATA_PATH, batch_size=FLAGS.bsize, subtract_mean=True)

    pre_trained_path = None
    if FLAGS.finetune:
        # Load the model pre-trained on ImageNet for the convolutional
        # layers when fine-tuning
        pre_trained_path = PRETRINED_PATH

    # Create a training model
    train_model = GoogLeNet_cifar(
        n_channel=3, n_class=10, pre_trained_path=pre_trained_path,
        bn=True, wd=0, sub_imagenet_mean=False,
        conv_trainable=True, fc_trainable=True)
    train_model.create_train_model()

    # Create a validation model
    valid_model = GoogLeNet_cifar(
        n_channel=3, n_class=10, bn=True, sub_imagenet_mean=False)
    valid_model.create_test_model()

    # create a Trainer object for training control
    trainer = Trainer(train_model, valid_model, train_data, init_lr=FLAGS.lr)

    with tf.Session() as sess:
        writer = tf.summary.FileWriter(SAVE_PATH)
        saver = tf.train.Saver()
        sess.run(tf.global_variables_initializer())
        writer.add_graph(sess.graph)
        for epoch_id in range(FLAGS.maxepoch):
            # train one epoch
            trainer.train_epoch(sess, keep_prob=FLAGS.keep_prob,
                                summary_writer=writer)
            # test the model on the validation set after each epoch
            trainer.valid_epoch(sess, dataflow=valid_data,
                                summary_writer=writer)
            saver.save(
                sess, '{}inception-cifar-epoch-{}'.format(SAVE_PATH, epoch_id))
        saver.save(
            sess, '{}inception-cifar-epoch-{}'.format(SAVE_PATH, epoch_id))
        writer.close()
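
# subtract_mean=True in loader.load_cifar presumably removes the dataset mean
# from each image. The exact preprocessing lives inside the loader; a common
# variant (an assumption, not the repo's implementation) is per-channel mean
# subtraction:
def subtract_channel_mean(images):
    """images: float array of shape (N, H, W, C); returns centered images."""
    mean = images.mean(axis=(0, 1, 2), keepdims=True)
    return images - mean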
def train():
    FLAGS = get_args()

    train_data, valid_data = loader.load_cifar(
        cifar_path=DATA_PATH, batch_size=FLAGS.bsize, substract_mean=True)

    train_model = VGG_CIFAR10(
        n_channel=3, n_class=10, pre_trained_path=None,
        bn=True, wd=5e-3, trainable=True, sub_vgg_mean=False)
    train_model.create_train_model()

    valid_model = VGG_CIFAR10(
        n_channel=3, n_class=10, bn=True, sub_vgg_mean=False)
    valid_model.create_test_model()

    trainer = Trainer(train_model, valid_model, train_data, init_lr=FLAGS.lr)

    with tf.Session() as sess:
        writer = tf.summary.FileWriter(SAVE_PATH)
        saver = tf.train.Saver()
        sess.run(tf.global_variables_initializer())
        writer.add_graph(sess.graph)
        for epoch_id in range(FLAGS.maxepoch):
            trainer.train_epoch(sess, keep_prob=FLAGS.keep_prob,
                                summary_writer=writer)
            trainer.valid_epoch(sess, dataflow=valid_data,
                                summary_writer=writer)
            saver.save(sess, '{}vgg-cifar-epoch-{}'.format(SAVE_PATH, epoch_id))
        saver.save(sess, '{}vgg-cifar-epoch-{}'.format(SAVE_PATH, epoch_id))
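
# VGG_CIFAR10 is built with wd=5e-3. How the weight decay enters the loss is
# internal to the model; a typical TF1 pattern (a sketch under that
# assumption, not the repo's code) adds an L2 penalty over the trainable
# weight variables:
def l2_regularized_loss(data_loss, wd=5e-3):
    l2_terms = [tf.nn.l2_loss(v) for v in tf.trainable_variables()
                if 'weights' in v.name or 'kernel' in v.name]
    if not l2_terms:
        return data_loss
    return data_loss + wd * tf.add_n(l2_terms)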
def train(): """ Function for unsupervised training and incorporate label info in adversarial regularization (Fig 1 and 3 in the paper) Validation will be processed after each epoch of training. Loss of each modules will be averaged and saved in summaries every 100 steps. Random samples and learned latent space will be saved for every 10 epochs. """ FLAGS = get_args() # image size for visualization. plot_size * plot_size digits will be visualized. plot_size = 20 # Use 10000 labels info to train latent space n_use_label = 10000 n_use_sample = 50000 # load data train_data = read_train_data(FLAGS.bsize, n_use_label=n_use_label, n_use_sample=n_use_sample) valid_data = read_valid_data(FLAGS.bsize) # create an AAE model for training model = AAE(n_code=FLAGS.ncode, wd=0, n_class=10, use_label=FLAGS.label, add_noise=FLAGS.noise, enc_weight=FLAGS.encw, gen_weight=FLAGS.genw, dis_weight=FLAGS.disw) model.create_train_model() # Create an separated AAE model for validation shared weights # with training model. This model is used to # randomly sample model data every 10 epoches. valid_model = AAE(n_code=FLAGS.ncode, n_class=10) valid_model.create_generate_model(b_size=400) # initialize a trainer for training trainer = Trainer(model, valid_model, train_data, distr_type=FLAGS.dist_type, use_label=FLAGS.label, init_lr=FLAGS.lr, save_path=SAVE_PATH) # Initialize a visualizer and a generator to monitor learned # latent space and data generation. # Latent space visualization only for code dim = 2 if FLAGS.ncode == 2: visualizer = Visualizer(model, save_path=SAVE_PATH) generator = Generator(generate_model=valid_model, save_path=SAVE_PATH, distr_type=FLAGS.dist_type, n_labels=10, use_label=FLAGS.label) sessconfig = tf.ConfigProto() sessconfig.gpu_options.allow_growth = True with tf.Session(config=sessconfig) as sess: writer = tf.summary.FileWriter(SAVE_PATH) saver = tf.train.Saver() sess.run(tf.global_variables_initializer()) writer.add_graph(sess.graph) for epoch_id in range(FLAGS.maxepoch): trainer.train_z_gan_epoch(sess, ae_dropout=FLAGS.dropout, summary_writer=writer) trainer.valid_epoch(sess, dataflow=valid_data, summary_writer=writer) if epoch_id % 10 == 0: saver.save(sess, '{}aae-epoch-{}'.format(SAVE_PATH, epoch_id)) generator.generate_samples(sess, plot_size=plot_size, file_id=epoch_id) if FLAGS.ncode == 2: visualizer.viz_2Dlatent_variable(sess, valid_data, file_id=epoch_id) saver.save(sess, '{}aae-epoch-{}'.format(SAVE_PATH, epoch_id))
def train():
    FLAGS = get_args()

    train_data, valid_data = loader.load_cifar(
        cifar_path=FLAGS.data_path, batch_size=FLAGS.bsize,
        substract_mean=True)

    train_model = VGG_CIFAR10(
        n_channel=3, n_class=10, pre_trained_path=None,
        bn=True, wd=5e-3, trainable=True, sub_vgg_mean=False)
    train_model.create_train_model()

    valid_model = VGG_CIFAR10(
        n_channel=3, n_class=10, bn=True, sub_vgg_mean=False)
    valid_model.create_test_model()

    trainer = Trainer(train_model, valid_model, train_data, init_lr=FLAGS.lr)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        writer = tf.compat.v1.summary.FileWriter(FLAGS.save_path)
        saver = tf.compat.v1.train.Saver()
        sess.run(tf.global_variables_initializer())
        # Restore after initialization so the restored weights are not
        # overwritten by the initializer.
        if FLAGS.saved_model != '':
            saver.restore(sess, FLAGS.saved_model)
        writer.add_graph(sess.graph)
        for epoch_id in range(FLAGS.maxepoch):
            trainer.train_epoch(sess, keep_prob=FLAGS.keep_prob,
                                summary_writer=writer)
            trainer.valid_epoch(sess, dataflow=valid_data,
                                summary_writer=writer)

            # connection part
            msg = {}
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            server_address = ('0.0.0.0', 5555)
            sock.connect(server_address)

            # NOTE: scheduler_op and target_gpus are assumed to be provided
            # elsewhere (e.g., state received from the scheduler); they are
            # not defined in this function.
            if scheduler_op in ('g', 's'):
                saved_model = '{}vgg-cifar-epoch-{}'.format(
                    FLAGS.save_path, epoch_id)
                saver.save(sess, saved_model)
                # generate command
                if scheduler_op == 'g':
                    sch_gpus = FLAGS.gpus.split(',') + target_gpus
                else:
                    sch_gpus = [g for g in FLAGS.gpus.split(',')
                                if g not in target_gpus]
                current_gpus = ','.join(sch_gpus)
                config_command = (
                    '--train --port ' + str(FLAGS.port) +
                    ' --data_path ' + str(FLAGS.data_path) +
                    ' --saved_model ' + saved_model +
                    ' --save_path ' + FLAGS.save_path +
                    ' --lr ' + str(FLAGS.lr) +
                    ' --bsize ' + str(FLAGS.bsize) +
                    ' --keep_prob ' + str(FLAGS.keep_prob) +
                    ' --maxepoch ' + str(FLAGS.maxepoch - epoch_id - 1) +
                    ' --gpus ' + current_gpus)
                # send command back to the scheduler
                msg['config'] = config_command
                msg['id'] = FLAGS.id
                msg['ep'] = epoch_id + 1
                msg['gpus'] = sch_gpus
                sock.sendall(dict_to_binary(msg))
                # leave the scheduler to restart
                exit()
            else:
                msg['id'] = FLAGS.id
                msg['ep'] = epoch_id + 1
                sock.sendall(dict_to_binary(msg))

        saver.save(sess, '{}vgg-cifar-epoch-{}'.format(FLAGS.save_path,
                                                       epoch_id))
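
# dict_to_binary is not shown above. A plausible minimal stand-in (an
# assumption, not the scheduler's actual wire format) serializes the message
# dict as UTF-8 JSON bytes for sock.sendall; the receiving side would decode
# with json.loads(data.decode('utf-8')).
import json

def dict_to_binary_sketch(msg):
    return json.dumps(msg).encode('utf-8')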
def train_type_1():
    FLAGS = get_args()

    if FLAGS.gan_type == 'lsgan':
        gan_model = LSGAN
        print('**** LSGAN ****')
    elif FLAGS.gan_type == 'dcgan':
        gan_model = DCGAN
        print('**** DCGAN ****')
    else:
        raise ValueError('Wrong GAN type!')

    save_path = os.path.join(SAVE_PATH, FLAGS.gan_type)
    save_path += '/'

    # load dataset
    if FLAGS.dataset == 'celeba':
        train_data = loader.load_celeba(FLAGS.bsize, data_path=CELEBA_PATH)
        im_size = 64
        n_channels = 3
    else:
        train_data = loader.load_mnist(FLAGS.bsize, data_path=MNIST_PATH)
        im_size = 28
        n_channels = 1

    # init training model
    train_model = gan_model(input_len=FLAGS.zlen, im_size=im_size,
                            n_channels=n_channels)
    train_model.create_train_model()

    # init generate model
    generate_model = gan_model(input_len=FLAGS.zlen, im_size=im_size,
                               n_channels=n_channels)
    generate_model.create_generate_model()

    # create trainer
    trainer = Trainer(train_model, train_data, moniter_gradient=False,
                      init_lr=FLAGS.lr, save_path=save_path)
    # create generator for sampling
    generator = Generator(generate_model, keep_prob=FLAGS.keep_prob,
                          save_path=save_path)

    sessconfig = tf.ConfigProto()
    sessconfig.gpu_options.allow_growth = True
    with tf.Session(config=sessconfig) as sess:
        writer = tf.summary.FileWriter(save_path)
        saver = tf.train.Saver()
        sess.run(tf.global_variables_initializer())
        writer.add_graph(sess.graph)
        for epoch_id in range(FLAGS.maxepoch):
            trainer.train_epoch(sess, keep_prob=FLAGS.keep_prob,
                                n_g_train=FLAGS.ng, n_d_train=FLAGS.nd,
                                summary_writer=writer)
            generator.random_sampling(sess, plot_size=10, file_id=epoch_id)
            generator.viz_interpolate(sess, file_id=epoch_id)
            if FLAGS.zlen == 2:
                generator.viz_2D_manifold(sess, plot_size=20,
                                          file_id=epoch_id)
            saver.save(sess, '{}gan-{}-epoch-{}'.format(
                save_path, FLAGS.gan_type, epoch_id))
        saver.save(sess, '{}gan-{}-epoch-{}'.format(
            save_path, FLAGS.gan_type, epoch_id))
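
# generator.viz_interpolate decodes a path through latent space. Its
# implementation is not shown here; below is a minimal sketch of the linear
# blend between two codes it might use (function name and step count are
# assumptions, not the repo's code):
def interpolate_codes(z_start, z_end, n_steps=10):
    """Linearly interpolate between two latent codes of shape (zlen,)."""
    alphas = np.linspace(0.0, 1.0, n_steps)[:, None]
    # shape (n_steps, zlen): row k is (1 - a_k) * z_start + a_k * z_end
    return (1.0 - alphas) * z_start[None] + alphas * z_end[None]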