def main():
    """Decode ten samples from the prior of a restored WAE and save them.

    Only the ``dir64`` experiment configuration is supported. The decoded
    images are laid out side by side and written to ``img.png``.

    Raises:
        ValueError: if ``--exp`` is not ``dir64``.
    """
    # Raise rather than assert: an assert is stripped under ``python -O``,
    # which would leave ``opts`` unbound.
    if FLAGS.exp != 'dir64':
        raise ValueError('Unknown experiment configuration')
    opts = configs.config_dir64
    if FLAGS.zdim is not None:
        opts['zdim'] = FLAGS.zdim

    # logging.basicConfig is a no-op once the root logger has a handler, so a
    # DEBUG call followed by an INFO call is the same as choosing one level.
    level = logging.DEBUG if opts['verbose'] else logging.INFO
    logging.basicConfig(level=level, format='%(asctime)s - %(message)s')

    wae = WAE(opts)
    wae.restore_checkpoint(FLAGS.checkpoint)

    batch_noise = wae.sample_pz(10)
    sample_gen = wae.sess.run(wae.decoded, feed_dict={
        wae.sample_noise: batch_noise,
        wae.is_training: False,
    })

    # Concatenate the samples horizontally, then map the decoder output
    # (presumably in [-1, 1]; TODO confirm) to [0, 1]. Clip so imshow does
    # not have to clip out-of-range float pixels itself.
    img = np.hstack(sample_gen)
    img = np.clip((img + 1.0) / 2, 0.0, 1.0)
    plt.imshow(img)
    plt.savefig('img.png')
def main():
    """Configure a WAE experiment from command-line flags and train it.

    The base config is ``configs.config_<FLAGS.exp>``. Every override flag
    that is not None replaces the matching ``opts`` entry. The resolved
    parameters are written to ``<work_dir>/params.txt`` before training.

    Raises:
        ValueError: if the experiment name is unknown or the training set
            holds fewer points than one batch.
    """
    known_experiments = ('celebA', 'celebA_small', 'mnist', 'mnist_small',
                         'dsprites', 'grassli', 'grassli_small', 'dir64')
    # Raise rather than assert so the check survives ``python -O``.
    if FLAGS.exp not in known_experiments:
        raise ValueError('Unknown experiment configuration')
    # Each experiment name maps onto the module attribute config_<name>.
    opts = getattr(configs, 'config_' + FLAGS.exp)

    # (flag attribute, opts key) pairs, applied in order; an unset flag
    # (None) keeps the config's value.
    overrides = (
        ('zdim', 'zdim'),
        ('lr', 'lr'),
        ('z_test', 'z_test'),
        ('lambda_schedule', 'lambda_schedule'),
        ('work_dir', 'work_dir'),
        ('wae_lambda', 'lambda'),
        ('enc_noise', 'e_noise'),
        ('epoch_num', 'epoch_num'),
    )
    for flag_name, key in overrides:
        value = getattr(FLAGS, flag_name)
        if value is not None:
            opts[key] = value

    # basicConfig only takes effect once, so pick the level up front.
    level = logging.DEBUG if opts['verbose'] else logging.INFO
    logging.basicConfig(level=level, format='%(asctime)s - %(message)s')

    utils.create_dir(opts['work_dir'])
    utils.create_dir(os.path.join(opts['work_dir'], 'checkpoints'))

    # Record every resolved parameter in params.txt.
    with utils.o_gfile((opts['work_dir'], 'params.txt'), 'w') as text:
        text.write('Parameters:\n')
        for key, value in opts.items():
            text.write('%s : %s\n' % (key, value))

    data = DataHandler(opts)
    if data.num_points < opts['batch_size']:
        raise ValueError('Training set too small')

    wae = WAE(opts)
    wae.train(data)
def main():
    """Train a WAE on MNIST with ``configs.config_mnist``.

    ``opts['mode']`` is hard-coded to ``'train'``. The ``'test'`` branch
    (improved sampling) runs only if that assignment is changed.

    Raises:
        ValueError: on a Gaussian encoder paired with a non-normal prior, a
            training set smaller than one batch, or an unknown mode.
    """
    opts = configs.config_mnist
    opts['mode'] = 'train'

    # basicConfig only takes effect once, so pick the level up front.
    level = logging.DEBUG if opts['verbose'] else logging.INFO
    logging.basicConfig(level=level, format='%(asctime)s - %(message)s')

    # Validate before creating any directories. Raise instead of
    # ``assert False`` + ``return``, which silently returns under -O.
    if opts['e_noise'] == 'gaussian' and opts['pz'] != 'normal':
        raise ValueError(
            'Gaussian encoders compatible only with Gaussian prior')

    utils.create_dir(opts['work_dir'])
    utils.create_dir(os.path.join(opts['work_dir'], 'checkpoints'))

    # Record every resolved parameter in params.txt.
    with utils.o_gfile((opts['work_dir'], 'params.txt'), 'w') as text:
        text.write('Parameters:\n')
        for key, value in opts.items():
            text.write('%s : %s\n' % (key, value))

    data = DataHandler(opts)
    if data.num_points < opts['batch_size']:
        raise ValueError('Training set too small')

    if opts['mode'] == 'train':
        wae = WAE(opts, data.num_points)
        wae.train(data)
    elif opts['mode'] == 'test':
        improved_wae.improved_sampling(opts)
    else:
        raise ValueError('Unknown mode: %s' % opts['mode'])
def main():
    """Decode a latent-space interpolation between two dataset images.

    Encodes ``data.data[0]`` and ``data.data[1]`` with a restored WAE. It then
    decodes ten codes evenly spaced on the segment between the two encodings,
    with both endpoints included, and saves them side by side to
    ``analogy.png``.

    Raises:
        ValueError: if ``--exp`` is not ``dir64``.
    """
    # Raise rather than assert so the check survives ``python -O``.
    if FLAGS.exp != 'dir64':
        raise ValueError('Unknown experiment configuration')
    opts = configs.config_dir64
    if FLAGS.zdim is not None:
        opts['zdim'] = FLAGS.zdim

    # basicConfig only takes effect once, so pick the level up front.
    level = logging.DEBUG if opts['verbose'] else logging.INFO
    logging.basicConfig(level=level, format='%(asctime)s - %(message)s')

    data = DataHandler(opts)
    wae = WAE(opts)
    wae.restore_checkpoint(FLAGS.checkpoint)

    batch_img = data.data[0:2]
    enc_vec = wae.sess.run(wae.encoded, feed_dict={
        wae.sample_points: batch_img,
        wae.is_training: False,
    })

    num_steps = 10
    start, end = enc_vec[0], enc_vec[1]
    # Interpolation weights 0, 1/9, ..., 1. Stepping by (end - start) / 10
    # for i = 0..9 would stop at 90% and never decode the second image's
    # code. The reshape broadcasts one weight per frame over the code axes.
    weights = np.linspace(0.0, 1.0, num_steps).reshape(
        (-1,) + (1,) * np.ndim(start))
    gen_vec = (start + weights * (end - start)).astype(np.float32)

    sample_gen = wae.sess.run(wae.decoded, feed_dict={
        wae.sample_noise: gen_vec,
        wae.is_training: False,
    })
    # Frames side by side; map the decoder output (presumably in [-1, 1];
    # TODO confirm) to [0, 1] for display.
    img = np.hstack(sample_gen)
    img = (img + 1.0) / 2
    plt.imshow(img)
    plt.savefig('analogy.png')
def main():
    """Configure a WAE from command-line flags and train it or sample from it.

    ``--mode train`` builds and trains a WAE. ``--mode test`` runs
    ``improved_wae.improved_sampling`` and requires ``--checkpoint``. Flags
    that are not None override entries of ``configs.config_<FLAGS.exp>``.

    Raises:
        ValueError: on an unknown experiment or mode, a missing checkpoint in
            test mode, a Gaussian encoder with a non-normal prior, or a
            training set smaller than one batch.
    """
    known_experiments = ('celebA', 'celebA_small', 'mnist', 'mnist_small',
                         'dsprites', 'grassli', 'grassli_small')
    # Raise rather than assert so every check survives ``python -O``.
    if FLAGS.exp not in known_experiments:
        raise ValueError('Unknown experiment configuration')
    opts = getattr(configs, 'config_' + FLAGS.exp)

    opts['mode'] = FLAGS.mode
    # Reject an unknown mode up front; previously it fell through every
    # branch and did nothing after loading the data.
    if opts['mode'] not in ('train', 'test'):
        raise ValueError('Unknown mode: %s' % opts['mode'])
    if opts['mode'] == 'test':
        if FLAGS.checkpoint is None:
            raise ValueError('Checkpoint must be provided')
        opts['checkpoint'] = FLAGS.checkpoint

    # (flag attribute, opts key) pairs; an unset flag keeps the config value.
    overrides = (
        ('zdim', 'zdim'),
        ('pz', 'pz'),
        ('lr', 'lr'),
        ('w_aef', 'w_aef'),
        ('z_test', 'z_test'),
        ('lambda_schedule', 'lambda_schedule'),
        ('work_dir', 'work_dir'),
        ('wae_lambda', 'lambda'),
        ('enc_noise', 'e_noise'),
    )
    for flag_name, key in overrides:
        value = getattr(FLAGS, flag_name)
        if value is not None:
            opts[key] = value

    # basicConfig only takes effect once, so pick the level up front.
    level = logging.DEBUG if opts['verbose'] else logging.INFO
    logging.basicConfig(level=level, format='%(asctime)s - %(message)s')

    if opts['e_noise'] == 'gaussian' and opts['pz'] != 'normal':
        raise ValueError(
            'Gaussian encoders compatible only with Gaussian prior')

    utils.create_dir(opts['work_dir'])
    utils.create_dir(os.path.join(opts['work_dir'], 'checkpoints'))

    # Record every resolved parameter in params.txt.
    with utils.o_gfile((opts['work_dir'], 'params.txt'), 'w') as text:
        text.write('Parameters:\n')
        for key, value in opts.items():
            text.write('%s : %s\n' % (key, value))

    data = DataHandler(opts)
    if data.num_points < opts['batch_size']:
        raise ValueError('Training set too small')

    if opts['mode'] == 'train':
        wae = WAE(opts)
        wae.train(data)
    else:
        improved_wae.improved_sampling(opts)
def main():
    """Configure a Sinkhorn-WAE experiment from command-line flags and run it.

    The base config is ``configs.config_<FLAGS.exp>``, overridden by flags.
    Depending on ``--mode`` this does one of the following:
    - ``train``: trains a WAE, logging to Neptune when ``NEPTUNE_API_TOKEN``
      is set in the environment.
    - ``test``: runs improved sampling.
    - ``generate``: generates samples via ``fideval``.
    - ``draw``: draws pictures via ``picture_plot``.

    Raises:
        ValueError: on an unknown experiment or mode, a missing checkpoint in
            test mode, ``recalculate_size < batch_size``, ``mover_ratio``
            outside [0, 1], a Gaussian encoder with a non-normal prior, or a
            training set smaller than one batch.
    """
    known_experiments = (
        'celebA', 'celebA_small', 'mnist', 'mnist_ord', 'mnist_small',
        'dsprites', 'grassli', 'grassli_small', 'syn_constant_uniform',
        'syn_2_constant_uniform', 'checkers', 'noise', 'noise_unif')
    # Raise rather than assert so every check survives ``python -O``.
    if FLAGS.exp not in known_experiments:
        raise ValueError('Unknown experiment configuration')
    opts = getattr(configs, 'config_' + FLAGS.exp)

    def override(flag_name, key=None):
        """Copy FLAGS.<flag_name> into opts[key or flag_name] unless None."""
        value = getattr(FLAGS, flag_name)
        if value is not None:
            opts[key or flag_name] = value

    opts['exp'] = FLAGS.exp
    opts['seed'] = FLAGS.seed
    opts['mode'] = FLAGS.mode
    if opts['mode'] not in ('train', 'test', 'generate', 'draw'):
        raise ValueError('Unknown mode: %s' % opts['mode'])
    if opts['mode'] == 'test':
        if FLAGS.checkpoint is None:
            raise ValueError('Checkpoint must be provided')
        opts['checkpoint'] = FLAGS.checkpoint

    override('batch_size')
    if FLAGS.recalculate_size is not None:
        opts['recalculate_size'] = FLAGS.recalculate_size
        if opts['recalculate_size'] < opts['batch_size']:
            raise ValueError(
                "recalculate_size should be at least as large as batch_size")
    else:
        # Default to one batch. Nothing below may reassign this key from the
        # flag, or an unset flag would overwrite the default with None.
        opts['recalculate_size'] = opts['batch_size']

    override('zdim')
    override('pz')
    override('lr')
    override('lr_schedule')
    override('w_aef')
    override('z_test')
    override('lambda_schedule')
    override('work_dir')
    override('wae_lambda', 'lambda')
    override('enc_noise', 'e_noise')
    override('z_test_scope')
    override('length_lambda')
    # Always written: an unset --grad_clip stores None.
    opts['grad_clip'] = FLAGS.grad_clip
    override('rec_lambda')
    override('zxz_lambda')
    override('train_size')
    # nat_size falls back to the --train_size flag, which may itself be None.
    opts['nat_size'] = (FLAGS.nat_size if FLAGS.nat_size is not None
                        else FLAGS.train_size)
    # These flags are copied unconditionally, None included.
    for key in ('nat_resampling', 'sinkhorn_sparse', 'sinkhorn_sparsifier',
                'sparsifier_freq', 'sinkhorn_unbiased',
                'feed_by_score_from_epoch', 'stay_lambda', 'mover_ratio'):
        opts[key] = getattr(FLAGS, key)
    if not 0 <= opts['mover_ratio'] <= 1:
        raise ValueError("mover_ratio must be in [0,1]")
    override('sinkhorn_iters')
    override('sinkhorn_epsilon')
    override('name')
    override('tags')
    override('epoch_num')
    override('e_pretrain')
    override('shuffle')

    # DEBUG logging is switched off here regardless of opts['verbose'].
    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s - %(message)s')

    if opts['e_noise'] == 'gaussian' and opts['pz'] != 'normal':
        raise ValueError(
            'Gaussian encoders compatible only with Gaussian prior')

    utils.create_dir(opts['work_dir'])
    utils.create_dir(os.path.join(opts['work_dir'], 'checkpoints'))

    # Record every resolved parameter in params.txt.
    with utils.o_gfile((opts['work_dir'], 'params.txt'), 'w') as text:
        text.write('Parameters:\n')
        for key, value in opts.items():
            text.write('%s : %s\n' % (key, value))

    data = DataHandler(opts)
    if data.num_points < opts['batch_size']:
        raise ValueError('Training set too small')

    # An explicit train_size restricts training to that many points;
    # otherwise the whole dataset is used.
    if opts.get('train_size') is not None:
        train_size = opts['train_size']
    else:
        train_size = data.num_points
    print("Train size:", train_size)
    if opts['nat_size'] == -1:
        opts['nat_size'] = train_size

    use_neptune = "NEPTUNE_API_TOKEN" in os.environ
    mode = opts['mode']
    if mode == 'train':
        experiment = None
        try:
            if use_neptune:
                neptune.init(
                    project_qualified_name="csadrian/global-sinkhorn")
                experiment = neptune.create_experiment(
                    params=opts, name=opts['name'],
                    upload_source_files=['wae.py', 'run.py', 'models.py'])
                for tag in opts['tags'].split(','):
                    neptune.append_tag(tag)
            wae = WAE(opts, train_size)
            data.num_points = train_size
            wae.train(data)
        finally:
            # Stop the Neptune run even if construction or training raised.
            if experiment is not None:
                experiment.stop()
    elif mode == 'test':
        improved_wae.improved_sampling(opts)
    elif mode == 'generate':
        fideval.generate(opts)
    else:  # 'draw' -- the mode was validated above.
        picture_plot.createimgs(opts)
def main():
    """Configure a WAE, optionally with a composite "smart" cost, and run it.

    The base config is ``configs.config_<FLAGS.exp>``; flags that are not
    None override its entries. With ``--smart_cost``, ``opts['cost']`` is
    rebuilt as a list of ``(name, weight)`` terms from the ``*_w`` flags.
    ``--mode train`` trains the model; ``--mode test`` calls ``wae.test()``
    and requires ``--checkpoint``.

    Raises:
        ValueError: on an unknown experiment or mode, a missing checkpoint in
            test mode, a Gaussian encoder with a non-normal prior, or a
            training set smaller than one batch.
    """
    known_experiments = ('celebA', 'celebA_small', 'celebA_ae_patch_var',
                         'celebA_sylvain_adv', 'celebA_adv', 'mnist',
                         'mnist_small', 'dsprites', 'grassli',
                         'grassli_small', 'cifar')
    # Raise rather than assert so every check survives ``python -O``.
    if FLAGS.exp not in known_experiments:
        raise ValueError('Unknown experiment configuration')
    opts = getattr(configs, 'config_' + FLAGS.exp)

    opts['mode'] = FLAGS.mode
    if opts['mode'] not in ('train', 'test'):
        raise ValueError('Unknown mode: %s' % opts['mode'])
    if opts['mode'] == 'test':
        if FLAGS.checkpoint is None:
            raise ValueError('Checkpoint must be provided')
        opts['checkpoint'] = FLAGS.checkpoint

    def override(flag_name, key=None):
        """Copy FLAGS.<flag_name> into opts[key or flag_name] unless None."""
        value = getattr(FLAGS, flag_name)
        if value is not None:
            opts[key or flag_name] = value

    override('zdim')
    override('pz')
    override('lr')
    override('w_aef')
    override('z_test')
    override('lambda_schedule')
    override('work_dir')
    override('wae_lambda', 'lambda')
    override('celebA_crop')
    override('enc_noise', 'e_noise')
    override('e_num_filters')
    override('g_num_filters')

    if FLAGS.smart_cost:
        # Replace the configured cost with the terms whose weights were given.
        opts['cost'] = []
        if FLAGS.patch_var_w is not None:
            opts['cost'].append(('patch_variances', FLAGS.patch_var_w))
        if FLAGS.l2sq_w is not None:
            opts['cost'].append(('l2sq', FLAGS.l2sq_w))
        if (FLAGS.sylvain_adv_c_w is not None
                and FLAGS.sylvain_emb_c_w is not None):
            adv_c_w = FLAGS.sylvain_adv_c_w
            emb_c_w = FLAGS.sylvain_emb_c_w
            opts['cost'].append(
                ('_sylvain_recon_loss_using_disc_conv', [adv_c_w, emb_c_w]))
            # NOTE(review): nesting inferred from flattened source; these
            # zero weights are assumed to apply only with the adversarial
            # cost -- confirm.
            opts['cross_p_w'] = 0
            opts['diag_p_w'] = 0

    override('adv_c_num_units')
    override('adv_c_patches_size')
    override('adv_use_sq')

    # basicConfig only takes effect once, so pick the level up front.
    level = logging.DEBUG if opts['verbose'] else logging.INFO
    logging.basicConfig(level=level, format='%(asctime)s - %(message)s')

    if opts['e_noise'] == 'gaussian' and opts['pz'] != 'normal':
        raise ValueError(
            'Gaussian encoders compatible only with Gaussian prior')

    utils.create_dir(opts['work_dir'])
    utils.create_dir(os.path.join(opts['work_dir'], 'checkpoints'))

    # Record every resolved parameter in params.txt.
    with utils.o_gfile((opts['work_dir'], 'params.txt'), 'w') as text:
        text.write('Parameters:\n')
        for key, value in opts.items():
            text.write('%s : %s\n' % (key, value))

    data = DataHandler(opts)
    if data.num_points < opts['batch_size']:
        raise ValueError('Training set too small')

    wae = WAE(opts)
    if opts['mode'] == 'train':
        wae.train(data)
    else:
        wae.test()
def main():
    """Train a continual-learning WAE on each task of a split dataset in turn.

    For every task the model is trained and then evaluated with
    ``wae.test(data, True)``. The model's ``former`` state is captured after
    each task and passed into the next task's training call.
    - Task 0 disables the transfer, regularisation and ``f`` loss terms.
    - Later tasks scale the reconstruction and transfer weights by
      ``opts['task_num']``.
    """
    # Other datasets used with this script: 'shuffle_mnist', and 'cifar_10'
    # together with configs.config_cifar.
    dataset = 'disjoint_mnist'
    opts = configs.config_mnist
    task_num = opts['task_num']
    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s - %(message)s')

    data = setdata.set_data(dataset, task_num)
    # Routine progress is logged at INFO; ERROR is reserved for failures.
    logging.info(opts)

    session_config = tf.ConfigProto(
        gpu_options=tf.GPUOptions(allow_growth=True))
    with tf.Session(config=session_config) as sess:
        wae = WAE(opts, sess)
        sess.run(wae.init)
        trainable = tf.trainable_variables()
        logging.info(trainable)
        logging.info(len(trainable))

        # Former-task state before any task is trained (former_init is
        # expected to return zero-valued placeholders).
        (random_z, pseudo_G_z, pseudo_T_z, w_old, b_old, f_g_z,
         pred) = wae.former_init()

        for i in range(len(data)):
            if i == 0:
                # First task: no transfer, regularisation or f loss.
                lambda_list = {
                    'wae_lambda': 1,
                    'rec_lambda': 0.1,
                    'trans_lambda': 0.0,
                    'reg_lambda': 0.0,
                    'f_lambda': 0.0,
                    'main_lambda': 1.0,
                }
            else:
                lambda_list = {
                    'wae_lambda': 1.0,
                    'rec_lambda': 0.5 * opts['task_num'],
                    'trans_lambda': 0.01 * opts['task_num'],
                    'reg_lambda': 0.01,
                    'f_lambda': 1.0,
                    'main_lambda': 1.0,
                }
            logging.info("task %d:", i)
            logging.info(lambda_list)
            wae.train(data, i, random_z, pseudo_G_z, pseudo_T_z, w_old,
                      b_old, f_g_z, pred, lambda_list)
            wae.test(data, True)
            # Capture this task's state for training the next task.
            (random_z, pseudo_G_z, pseudo_T_z, w_old, b_old, f_g_z,
             pred) = wae.former()
            if i == 0:
                logging.debug(
                    'former state types: %s %s %s; pseudo_G_z shape: %s',
                    type(random_z), type(pseudo_G_z), type(pseudo_T_z[-1]),
                    pseudo_G_z.shape)
def main():
    """Build a WAE for the selected dataset and run the requested mode.

    Outputs (``params.txt`` and ``checkpoints/``) are created under
    ``<opts['method']>/<opts['work_dir']>``. ``--mode`` selects which WAE
    method to call: ``train``, ``test``, ``reg`` or ``vizu``. Each is called
    with ``(data, opts['work_dir'], FLAGS.weights_file)``.

    Raises:
        ValueError: on an unknown experiment or mode, or a training set
            smaller than one batch.
    """
    known_experiments = ('celebA', 'celebA_small', 'mnist', 'mnist_small',
                         'dsprites', 'grassli', 'grassli_small')
    # Raise rather than assert so every check survives ``python -O``.
    if FLAGS.exp not in known_experiments:
        raise ValueError('Unknown experiment configuration')
    opts = getattr(configs, 'config_' + FLAGS.exp)

    # Reject an unknown mode before any directories or models are built;
    # previously it silently did nothing at the very end.
    if FLAGS.mode not in ('train', 'test', 'reg', 'vizu'):
        raise ValueError('Unknown mode: %s' % FLAGS.mode)

    # Training method and working directory (empty flags keep the config).
    if FLAGS.method:
        opts['method'] = FLAGS.method
    if FLAGS.work_dir:
        opts['work_dir'] = FLAGS.work_dir

    # basicConfig only takes effect once, so pick the level up front.
    level = logging.DEBUG if opts['verbose'] else logging.INFO
    logging.basicConfig(level=level, format='%(asctime)s - %(message)s')

    utils.create_dir(opts['method'])
    work_dir = os.path.join(opts['method'], opts['work_dir'])
    utils.create_dir(work_dir)
    utils.create_dir(os.path.join(work_dir, 'checkpoints'))

    # Record every resolved parameter in params.txt.
    with utils.o_gfile((work_dir, 'params.txt'), 'w') as text:
        text.write('Parameters:\n')
        for key, value in opts.items():
            text.write('%s : %s\n' % (key, value))

    data = DataHandler(opts)
    if data.num_points < opts['batch_size']:
        raise ValueError('Training set too small')

    tf.reset_default_graph()
    wae = WAE(opts)

    # NOTE(review): the directories above live under <method>/<work_dir>,
    # but the WAE methods receive only opts['work_dir']. Confirm that the WAE
    # prepends opts['method'] itself.
    run_mode = getattr(wae, FLAGS.mode)
    run_mode(data, opts['work_dir'], FLAGS.weights_file)
if DATASET == 'grid': data_dist = Grid() elif DATASET == 'low_dim_embed': data_dist = LowDimEmbed() elif DATASET == 'color_mnist': data_dist = CMNIST(os.path.join('data', 'mnist')) elif DATASET == 'cifar_100': data_dist = CIFAR100(os.path.join('data', 'cifar-100')) NOISE_PERTURB = 0.005 if METHOD == 'pae': from pae import PAE model = PAE(data_dist, noise_dist, flags, args) elif METHOD == 'wae': from wae import WAE model = WAE(data_dist, noise_dist, flags, args) elif METHOD == 'cougan': noise_dist = NormalNoise(flags.NOISE_DIM) from cougan import CoulombGAN model = CoulombGAN(data_dist, noise_dist, flags, args) elif METHOD == 'bigan': from bigan import BiGAN model = BiGAN(data_dist, noise_dist, flags, args) elif METHOD == 'veegan': if DATASET == 'grid': flags.NOISE_DIM = 254 # refer to author's implementation flags.GEN_ARCH[0] = flags.NOISE_DIM flags.ENC_ARCH[-1] = flags.NOISE_DIM flags.DISC_ARCH[0] = flags.NOISE_DIM + flags.DATA_DIM #flags.DISC_ARCH = [flags.NOISE_DIM+flags.DATA_DIM, 128, 1] elif DATASET == 'low_dim_embed':