def main():
    # dataset config
    if FLAGS.dataset == 'mnist':
        opts = configs.config_mnist
    elif FLAGS.dataset == 'svhn':
        opts = configs.config_svhn
    else:
        assert False, 'Unknown dataset'
    if FLAGS.data_dir:
        opts['data_dir'] = FLAGS.data_dir
    else:
        raise Exception('You must provide a data_dir')

    # Model set up
    opts['model'] = FLAGS.model
    opts['cost'] = FLAGS.cost
    opts['beta'] = FLAGS.beta
    opts['decoder'] = FLAGS.decoder
    opts['net_archi'] = FLAGS.net_archi

    # Create directories
    results_dir = 'results'
    if not tf.io.gfile.isdir(results_dir):
        utils.create_dir(results_dir)
    opts['out_dir'] = os.path.join(results_dir, FLAGS.out_dir)
    if not tf.io.gfile.isdir(opts['out_dir']):
        utils.create_dir(opts['out_dir'])
    out_subdir = os.path.join(opts['out_dir'], opts['model'])
    if not tf.io.gfile.isdir(out_subdir):
        utils.create_dir(out_subdir)
    exp_dir = os.path.join(
        out_subdir,
        '{}_{}_{:%Y_%m_%d_%H_%M}'.format(FLAGS.res_dir, opts['beta'],
                                         datetime.now()),
    )
    opts['exp_dir'] = exp_dir
    if not tf.io.gfile.isdir(exp_dir):
        utils.create_dir(exp_dir)
        utils.create_dir(os.path.join(exp_dir, 'checkpoints'))

    # Verbose
    logging.basicConfig(filename=os.path.join(opts['exp_dir'], 'outputs.log'),
                        level=logging.INFO,
                        format='%(asctime)s - %(message)s')

    # Experiments set up
    opts['lr'] = FLAGS.lr
    opts['it_num'] = FLAGS.num_it
    opts['print_every'] = int(opts['it_num'] / 2.)
    opts['evaluate_every'] = int(opts['it_num'] / 4.)
    opts['save_every'] = 10000000000
    opts['save_final'] = FLAGS.save_model
    opts['save_train_data'] = FLAGS.save_data

    # Reset tf graph
    tf.compat.v1.reset_default_graph()

    # Loading the dataset
    data = DataHandler(opts)
    assert data.train_size >= opts['batch_size'], 'Training set too small'

    # init method
    run = Run(opts, data)

    # Training/testing/vizu
    if FLAGS.mode == "train":
        # Dumping all the configs to the text file
        with utils.o_gfile((exp_dir, 'params.txt'), 'w') as text:
            text.write('Parameters:\n')
            for key in opts:
                text.write('%s : %s\n' % (key, opts[key]))
        run.train()
    else:
        assert False, 'Unknown mode %s' % FLAGS.mode
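# --- Illustrative sketch, not part of the original script --------------------
# main() above reads its configuration from a module-level FLAGS object. A
# minimal declaration covering the attributes it touches could look like this,
# assuming absl-style flags (tf.compat.v1.flags is an alias); the flag types
# are inferred from usage and the defaults here are placeholders.
from absl import flags

FLAGS = flags.FLAGS
flags.DEFINE_string('dataset', 'mnist', 'Dataset to use: mnist or svhn.')
flags.DEFINE_string('data_dir', None, 'Path to the dataset (required).')
flags.DEFINE_string('model', 'WAE', 'Model to train.')
flags.DEFINE_string('cost', 'l2sq', 'Ground cost for the reconstruction term.')
flags.DEFINE_float('beta', 1., 'Regularization weight.')
flags.DEFINE_string('decoder', 'gauss', 'Decoder type.')
flags.DEFINE_string('net_archi', 'mlp', 'Network architecture.')
flags.DEFINE_string('out_dir', 'wae', 'Output subdirectory under results/.')
flags.DEFINE_string('res_dir', 'res', 'Prefix of the experiment directory.')
flags.DEFINE_float('lr', 0.0001, 'Learning rate.')
flags.DEFINE_integer('num_it', 100000, 'Number of training iterations.')
flags.DEFINE_boolean('save_model', False, 'Save the final checkpoint.')
flags.DEFINE_boolean('save_data', False, 'Save the training data/curves.')
flags.DEFINE_string('mode', 'train', 'Run mode.')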
def main():
    # Select dataset to use
    if FLAGS.dataset == 'mnist':
        opts = configs.config_mnist
    elif FLAGS.dataset == 'smallNORB':
        opts = configs.config_smallNORB
    elif FLAGS.dataset == 'celebA':
        opts = configs.config_celebA
    else:
        assert False, 'Unknown dataset'

    # model
    opts['model'] = FLAGS.model
    opts['encoder'] = [FLAGS.encoder] * opts['nlatents']
    # opts['use_sigmoid'] = FLAGS.sigmoid
    opts['archi'] = [FLAGS.net_archi] * opts['nlatents']
    opts['obs_cost'] = FLAGS.cost
    opts['lambda_schedule'] = FLAGS.lmba_schedule
    opts['enc_sigma_pen'] = FLAGS.enc_sigma_pen
    opts['dec_sigma_pen'] = FLAGS.dec_sigma_pen

    # lambda
    lambda_rec = [0.01, 0.1]
    lambda_match = [0.0001, 0.001]
    schedule = ['constant']
    sigmoid = [False]
    lmba = list(itertools.product(schedule, sigmoid, lambda_rec, lambda_match))
    id = (FLAGS.id - 1) % len(lmba)
    sche, sig, lrec, lmatch = lmba[id]
    opts['lambda_schedule'] = sche
    opts['use_sigmoid'] = sig
    opts['lambda_init'] = [
        lrec * log(n + 1.0001) / opts['zdim'][n]
        for n in range(opts['nlatents'] - 1)
    ] + [lmatch / 100]
    opts['lambda'] = [
        lrec**(n + 1) / opts['zdim'][n]
        for n in range(opts['nlatents'] - 1)
    ] + [lmatch]

    # Create directories
    results_dir = 'results'
    if not tf.io.gfile.isdir(results_dir):
        utils.create_dir(results_dir)
    opts['out_dir'] = os.path.join(results_dir, FLAGS.out_dir)
    if not tf.io.gfile.isdir(opts['out_dir']):
        utils.create_dir(opts['out_dir'])
    out_subdir = os.path.join(opts['out_dir'], opts['model'])
    if not tf.io.gfile.isdir(out_subdir):
        utils.create_dir(out_subdir)
    out_subdir = os.path.join(out_subdir, 'l' + sche + '_sig' + str(sig))
    if not tf.io.gfile.isdir(out_subdir):
        utils.create_dir(out_subdir)
    opts['exp_dir'] = FLAGS.res_dir
    if opts['model'] == 'stackedwae':
        exp_dir = os.path.join(
            out_subdir,
            '{}_{}layers_lrec{}_lmatch{}_{:%Y_%m_%d_%H_%M}'.format(
                opts['exp_dir'], opts['nlatents'], lrec, lmatch,
                datetime.now()))
    else:
        exp_dir = os.path.join(
            out_subdir,
            '{}_lmatch{}_{:%Y_%m_%d_%H_%M}'.format(opts['exp_dir'], lmatch,
                                                   datetime.now()))
    opts['exp_dir'] = exp_dir
    if not tf.io.gfile.isdir(exp_dir):
        utils.create_dir(exp_dir)
        utils.create_dir(os.path.join(exp_dir, 'checkpoints'))

    # getting weights path
    if FLAGS.weights_file is not None:
        WEIGHTS_PATH = os.path.join(opts['exp_dir'], 'checkpoints',
                                    FLAGS.weights_file)
    else:
        WEIGHTS_PATH = None

    # Verbose
    logging.basicConfig(filename=os.path.join(exp_dir, 'outputs.log'),
                        level=logging.INFO,
                        format='%(asctime)s - %(message)s')

    # run set up
    opts['vizu_splitloss'] = FLAGS.losses
    opts['vizu_fullrec'] = FLAGS.reconstructions
    opts['vizu_embedded'] = FLAGS.embedded
    opts['vizu_latent'] = FLAGS.latents
    opts['vizu_pz_grid'] = FLAGS.grid
    opts['vizu_stochasticity'] = FLAGS.stoch
    opts['fid'] = FLAGS.fid
    opts['it_num'] = FLAGS.num_it
    opts['print_every'] = int(opts['it_num'] / 4)
    opts['evaluate_every'] = int(opts['it_num'] / 50)
    if FLAGS.batch_size is not None:
        opts['batch_size'] = FLAGS.batch_size
    opts['lr'] = FLAGS.lr
    opts['use_trained'] = FLAGS.use_trained
    opts['save_every'] = 10000000000
    opts['save_final'] = FLAGS.save_model
    opts['save_train_data'] = FLAGS.save_data

    # Reset tf graph
    tf.compat.v1.reset_default_graph()

    # Loading the dataset
    opts['data_dir'] = FLAGS.data_dir
    data = DataHandler(opts)
    assert data.train_size >= opts['batch_size'], 'Training set too small'

    # build model
    run = Run(opts, data)

    # Training/testing/vizu
    if FLAGS.mode == "train":
        # Dumping all the configs to the text file
        with utils.o_gfile((opts['exp_dir'], 'params.txt'), 'w') as text:
            text.write('Parameters:\n')
            for key in opts:
                text.write('%s : %s\n' % (key, opts[key]))
        run.train(WEIGHTS_PATH)
    elif FLAGS.mode == "vizu":
        opts['rec_loss_nsamples'] = 1
        opts['sample_recons'] = False
        run.latent_interpolation(opts['exp_dir'], WEIGHTS_PATH)
    elif FLAGS.mode == "fid":
        run.fid_score(WEIGHTS_PATH)
    elif FLAGS.mode == "test":
        run.test_losses(WEIGHTS_PATH)
    elif FLAGS.mode == "vlae_exp":
        run.vlae_experiment(WEIGHTS_PATH)
    else:
        assert False, 'Unknown mode %s' % FLAGS.mode
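# --- Illustrative sketch, not part of the original script --------------------
# The sweep above maps an integer job id (e.g. a 1-based cluster array index,
# FLAGS.id) to one cell of a hyperparameter grid via itertools.product and a
# modulo. Standalone demonstration of the idiom with the same grid values:
import itertools

schedule = ['constant']
sigmoid = [False]
lambda_rec = [0.01, 0.1]
lambda_match = [0.0001, 0.001]
grid = list(itertools.product(schedule, sigmoid, lambda_rec, lambda_match))
for job_id in range(1, len(grid) + 1):  # job ids are 1-based
    sche, sig, lrec, lmatch = grid[(job_id - 1) % len(grid)]
    print(job_id, sche, sig, lrec, lmatch)
# job_id=1 -> ('constant', False, 0.01, 0.0001); ids beyond len(grid) wrap
# around, so repeats of the same grid can be launched as one larger array job.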
def main():
    # Select dataset to use
    if FLAGS.dataset == 'dsprites':
        opts = configs.config_dsprites
    elif FLAGS.dataset == 'noisydsprites':
        opts = configs.config_noisydsprites
    elif FLAGS.dataset == 'screamdsprites':
        opts = configs.config_screamdsprites
    elif FLAGS.dataset == 'smallNORB':
        opts = configs.config_smallNORB
    elif FLAGS.dataset == '3dshapes':
        opts = configs.config_3dshapes
    elif FLAGS.dataset == '3Dchairs':
        opts = configs.config_3Dchairs
    elif FLAGS.dataset == 'celebA':
        opts = configs.config_celebA
    elif FLAGS.dataset == 'mnist':
        opts = configs.config_mnist
    else:
        assert False, 'Unknown dataset'

    # Set method param
    opts['fid'] = FLAGS.fid
    opts['cost'] = FLAGS.cost  # l2, l2sq, l2sq_norm, l1, xentropy
    if FLAGS.net_archi:
        opts['network'] = net_configs[FLAGS.net_archi]
    else:
        if FLAGS.dataset == 'celebA':
            opts['network'] = net_configs['conv_rae']
        else:
            opts['network'] = net_configs['conv_locatello']

    # Model set up
    opts['model'] = FLAGS.model
    if FLAGS.dataset == 'celebA':
        opts['zdim'] = 32
    elif FLAGS.dataset == '3Dchairs':
        opts['zdim'] = 16
    else:
        opts['zdim'] = 10
    opts['lr'] = 0.0001

    # Objective Function Coefficients
    if FLAGS.dataset == 'celebA':
        if opts['model'] == 'BetaTCVAE':
            beta = [1, 2, 4, 6, 8, 10]
            coef_id = (FLAGS.id - 1) % len(beta)
            opts['obj_fn_coeffs'] = beta[coef_id]
        elif opts['model'] == 'FactorVAE':
            beta = [1, 2, 4, 6, 8, 10]
            # beta = [1, 5, 10, 25, 50, 100]
            coef_id = (FLAGS.id - 1) % len(beta)
            opts['obj_fn_coeffs'] = beta[coef_id]
        elif opts['model'] == 'TCWAE_MWS':
            # beta = [0.1, 0.25, 0.5, 1., 2., 5.]
            # gamma = [0.1, 0.25, 0.5, 1., 2., 5.]
            beta = [1., 2., 5.]
            gamma = [0.1, 0.25, 0.5]
            lmba = list(itertools.product(beta, gamma))
            coef_id = (FLAGS.id - 1) % len(lmba)
            opts['obj_fn_coeffs'] = list(lmba[coef_id])
        elif opts['model'] == 'TCWAE_GAN':
            # beta = [0.1, 0.25, 0.5, 1., 2., 5.]
            # gamma = [0.1, 0.25, 0.5, 1., 2., 5.]
            beta = [1., 2., 5., 10., 15.]
            gamma = [1., 2., 5.]
            lmba = list(itertools.product(beta, gamma))
            coef_id = (FLAGS.id - 1) % len(lmba)
            opts['obj_fn_coeffs'] = list(lmba[coef_id])
        else:
            raise Exception('Unknown {} model for celebA'.format(
                opts['model']))
    elif FLAGS.dataset == '3Dchairs':
        if opts['model'] == 'BetaTCVAE':
            beta = [1, 2, 4, 6, 8, 10]
            coef_id = (FLAGS.id - 1) % len(beta)
            opts['obj_fn_coeffs'] = beta[coef_id]
        elif opts['model'] == 'FactorVAE':
            beta = [1., 2., 4.]
            # beta = [1., 5., 10., 25., 50., 100.]
            coef_id = (FLAGS.id - 1) % len(beta)
            opts['obj_fn_coeffs'] = beta[coef_id]
        elif opts['model'] == 'TCWAE_MWS':
            # beta = [0.1, 0.25, 0.5, 1., 2., 5.]
            # gamma = [0.1, 0.25, 0.5, 1., 2., 5.]
            beta = [1., 2., 5.]
            gamma = [0.1, 0.25, 0.5]
            lmba = list(itertools.product(beta, gamma))
            coef_id = (FLAGS.id - 1) % len(lmba)
            opts['obj_fn_coeffs'] = list(lmba[coef_id])
        elif opts['model'] == 'TCWAE_GAN':
            # beta = [0.1, 0.5, 1., 2., 5., 10.]
            # gamma = [0.1, 0.5, 1., 2., 5., 10.]
            beta = [0.1, 0.5]
            gamma = [0.1, 0.5, 1.]
            lmba = list(itertools.product(beta, gamma))
            coef_id = (FLAGS.id - 1) % len(lmba)
            opts['obj_fn_coeffs'] = list(lmba[coef_id])
        else:
            raise Exception('Unknown {} model for 3Dchairs'.format(
                opts['model']))
    else:
        if opts['model'] == 'BetaTCVAE':
            beta = [1, 2, 4, 6, 8, 10]
            coef_id = (FLAGS.id - 1) % len(beta)
            opts['obj_fn_coeffs'] = beta[coef_id]
        elif opts['model'] == 'FactorVAE':
            beta = [1, 10, 25, 50, 75, 100]
            coef_id = (FLAGS.id - 1) % len(beta)
            opts['obj_fn_coeffs'] = beta[coef_id]
        elif opts['model'] == 'WAE':
            beta = [1, 5, 10, 25, 50, 100]
            coef_id = (FLAGS.id - 1) % len(beta)
            opts['obj_fn_coeffs'] = beta[coef_id]
        elif opts['model'] == 'TCWAE_MWS':
            if opts['cost'] == 'xent':
                beta = [1, 2, 4, 6, 8, 10]
                gamma = [1, 2, 4, 6, 8, 10]
            else:
                beta = [2., 5.]
                gamma = [0.1, 0.25, 0.5, 0.75, 1., 2., 5.]
            lmba1 = list(itertools.product(beta, gamma))
            lmba2 = list(itertools.product(gamma, beta))
            lmba = lmba1 + lmba2
            # beta = [0.1, 0.25, 0.5, 0.75, 1, 2]
            # gamma = [0.1, 0.25, 0.5, 0.75, 1, 2]
            # lmba = list(itertools.product(beta, gamma))
            coef_id = (FLAGS.id - 1) % len(lmba)
            opts['obj_fn_coeffs'] = list(lmba[coef_id])
        elif opts['model'] == 'TCWAE_GAN':
            if opts['cost'] == 'xent':
                beta = [1, 10, 25, 50, 75, 100]
                gamma = [1, 10, 25, 50, 75, 100]
            else:
                # beta = [0.5, 20.]
                # gamma = [0.1, 0.5, 1., 2.5, 5., 7.5, 10.0, 20.]
                # lmba1 = list(itertools.product(beta, gamma))
                # lmba2 = list(itertools.product(gamma, beta))
                # lmba = lmba1 + lmba2
                beta = [0.5, 1, 2.5, 5, 7.5, 10]
                gamma = [0.5, 1.0, 2.5, 5.0, 7.5, 10.0]
            lmba = list(itertools.product(beta, gamma))
            coef_id = (FLAGS.id - 1) % len(lmba)
            opts['obj_fn_coeffs'] = list(lmba[coef_id])
        elif opts['model'] == 'TCWAE_MWS_MI':
            beta = [0.1, 0.25, 0.5, 0.75, 1, 10]
            coef_id = (FLAGS.id - 1) % len(beta)
            opts['obj_fn_coeffs'] = beta[coef_id]
        elif opts['model'] == 'TCWAE_GAN_MI':
            beta = [1, 2, 4, 6, 8, 10]
            coef_id = (FLAGS.id - 1) % len(beta)
            opts['obj_fn_coeffs'] = beta[coef_id]
        else:
            raise NotImplementedError('Model type not recognised')

    # Create directories
    results_dir = 'results'
    if not tf.io.gfile.isdir(results_dir):
        utils.create_dir(results_dir)
    opts['out_dir'] = os.path.join(results_dir, FLAGS.out_dir)
    if not tf.io.gfile.isdir(opts['out_dir']):
        utils.create_dir(opts['out_dir'])
    out_subdir = os.path.join(opts['out_dir'], opts['model'])
    if not tf.io.gfile.isdir(out_subdir):
        utils.create_dir(out_subdir)
    opts['exp_dir'] = FLAGS.res_dir
    if opts['model'] in ('disWAE', 'TCWAE_MWS', 'TCWAE_GAN'):
        exp_dir = os.path.join(
            out_subdir,
            '{}_{}_{}_{:%Y_%m_%d_%H_%M}'.format(opts['exp_dir'],
                                                opts['obj_fn_coeffs'][0],
                                                opts['obj_fn_coeffs'][1],
                                                datetime.now()),
        )
    else:
        exp_dir = os.path.join(
            out_subdir,
            '{}_{}_{:%Y_%m_%d_%H_%M}'.format(opts['exp_dir'],
                                             opts['obj_fn_coeffs'],
                                             datetime.now()),
        )
    opts['exp_dir'] = exp_dir
    if not tf.io.gfile.isdir(exp_dir):
        utils.create_dir(exp_dir)
        utils.create_dir(os.path.join(exp_dir, 'checkpoints'))

    # Verbose
    logging.basicConfig(filename=os.path.join(exp_dir, 'outputs.log'),
                        level=logging.INFO,
                        format='%(asctime)s - %(message)s')

    # Experiments set up
    opts['it_num'] = FLAGS.num_it
    opts['print_every'] = int(opts['it_num'] / 2.)
    opts['evaluate_every'] = int(opts['it_num'] / 4.)
    opts['save_every'] = 10000000000
    opts['save_final'] = FLAGS.save_model
    opts['save_train_data'] = FLAGS.save_data
    opts['vizu_encSigma'] = False

    # Reset tf graph
    tf.compat.v1.reset_default_graph()

    # Loading the dataset
    opts['data_dir'] = FLAGS.data_dir
    opts['stage_to_scratch'] = FLAGS.stage_to_scratch
    opts['scratch_dir'] = FLAGS.scratch_dir
    data = DataHandler(opts)
    assert data.train_size >= opts['batch_size'], 'Training set too small'

    # init method
    run = Run(opts, data)

    # Training/testing/vizu
    if FLAGS.mode == "train":
        # Dumping all the configs to the text file
        with utils.o_gfile((exp_dir, 'params.txt'), 'w') as text:
            text.write('Parameters:\n')
            for key in opts:
                text.write('%s : %s\n' % (key, opts[key]))
        run.train()
    else:
        assert False, 'Unknown mode %s' % FLAGS.mode
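# --- Illustrative sketch, not part of the original script --------------------
# The repeated "if not tf.io.gfile.isdir(d): utils.create_dir(d)" pattern above
# builds results/<out_dir>/<model>/<run_name> one level at a time. For plain
# local paths a single os.makedirs call achieves the same effect; a hypothetical
# helper condensing the pattern:
import os
from datetime import datetime

def ensure_dir(path):
    # Create path (and any missing parents); no-op when it already exists.
    os.makedirs(path, exist_ok=True)
    return path

exp_dir = ensure_dir(os.path.join(
    'results', 'tcwae', 'TCWAE_MWS',
    'res_{}_{:%Y_%m_%d_%H_%M}'.format(1.0, datetime.now())))
ensure_dir(os.path.join(exp_dir, 'checkpoints'))
# The scripts use tf.io.gfile because it also handles GCS-style paths;
# os.makedirs only covers local filesystems.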
def main():
    # Select dataset to use
    if FLAGS.dataset == 'dsprites':
        opts = configs.config_dsprites
    elif FLAGS.dataset == 'noisydsprites':
        opts = configs.config_noisydsprites
    elif FLAGS.dataset == 'screamdsprites':
        opts = configs.config_screamdsprites
    elif FLAGS.dataset == 'smallNORB':
        opts = configs.config_smallNORB
    elif FLAGS.dataset == '3dshapes':
        opts = configs.config_3dshapes
    elif FLAGS.dataset == '3Dchairs':
        opts = configs.config_3Dchairs
    elif FLAGS.dataset == 'celebA':
        opts = configs.config_celebA
    elif FLAGS.dataset == 'mnist':
        opts = configs.config_mnist
    else:
        assert False, 'Unknown dataset'

    # Set method param
    opts['fid'] = FLAGS.fid
    opts['cost'] = FLAGS.cost  # l2, l2sq, l2sq_norm, l1, xentropy
    if FLAGS.net_archi:
        opts['network'] = net_configs[FLAGS.net_archi]
    else:
        if FLAGS.dataset == 'celebA':
            opts['network'] = net_configs['conv_rae']
        else:
            opts['network'] = net_configs['conv_locatello']

    # Model set up
    opts['model'] = FLAGS.model
    if FLAGS.dataset == 'celebA':
        opts['zdim'] = 32
    elif FLAGS.dataset == '3Dchairs':
        opts['zdim'] = 16
    else:
        opts['zdim'] = 10
    opts['lr'] = 0.0001

    # Objective Function Coefficients
    if opts['model'] in ['BetaTCVAE', 'FactorVAE']:
        opts['obj_fn_coeffs'] = FLAGS.beta
    elif opts['model'] in ['TCWAE_MWS', 'TCWAE_GAN']:
        opts['obj_fn_coeffs'] = [FLAGS.beta, FLAGS.gamma]
    else:
        raise Exception('Unknown {} model for {}'.format(
            opts['model'], FLAGS.dataset))

    # Create directories
    results_dir = 'results'
    if not tf.io.gfile.isdir(results_dir):
        utils.create_dir(results_dir)
    opts['out_dir'] = os.path.join(results_dir, FLAGS.out_dir)
    if not tf.io.gfile.isdir(opts['out_dir']):
        utils.create_dir(opts['out_dir'])
    out_subdir = os.path.join(opts['out_dir'], opts['model'])
    if not tf.io.gfile.isdir(out_subdir):
        utils.create_dir(out_subdir)
    opts['exp_dir'] = FLAGS.res_dir
    if opts['model'] in ('disWAE', 'TCWAE_MWS', 'TCWAE_GAN'):
        exp_dir = os.path.join(
            out_subdir,
            '{}_{}_{}_{:%Y_%m_%d_%H_%M}'.format(opts['exp_dir'],
                                                opts['obj_fn_coeffs'][0],
                                                opts['obj_fn_coeffs'][1],
                                                datetime.now()),
        )
    else:
        exp_dir = os.path.join(
            out_subdir,
            '{}_{}_{:%Y_%m_%d_%H_%M}'.format(opts['exp_dir'],
                                             opts['obj_fn_coeffs'],
                                             datetime.now()),
        )
    opts['exp_dir'] = exp_dir
    if not tf.io.gfile.isdir(exp_dir):
        utils.create_dir(exp_dir)
        utils.create_dir(os.path.join(exp_dir, 'checkpoints'))

    # Verbose
    logging.basicConfig(filename=os.path.join(exp_dir, 'outputs.log'),
                        level=logging.INFO,
                        format='%(asctime)s - %(message)s')

    # Experiments set up
    opts['it_num'] = FLAGS.num_it
    opts['print_every'] = int(opts['it_num'] / 2.)
    opts['evaluate_every'] = int(opts['it_num'] / 4.)
    opts['save_every'] = 10000000000
    opts['save_final'] = FLAGS.save_model
    opts['save_train_data'] = FLAGS.save_data
    opts['vizu_encSigma'] = False

    # Reset tf graph
    tf.compat.v1.reset_default_graph()

    # Loading the dataset
    opts['data_dir'] = FLAGS.data_dir
    opts['stage_to_scratch'] = FLAGS.stage_to_scratch
    opts['scratch_dir'] = FLAGS.scratch_dir
    data = DataHandler(opts)
    assert data.train_size >= opts['batch_size'], 'Training set too small'

    # init method
    run = Run(opts, data)

    # Training/testing/vizu
    if FLAGS.mode == "train":
        # Dumping all the configs to the text file
        with utils.o_gfile((exp_dir, 'params.txt'), 'w') as text:
            text.write('Parameters:\n')
            for key in opts:
                text.write('%s : %s\n' % (key, opts[key]))
        run.train()
    else:
        assert False, 'Unknown mode %s' % FLAGS.mode
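# --- Illustrative sketch, not part of the original script --------------------
# The experiment directories above embed a timestamp directly through
# str.format's datetime format spec, '{:%Y_%m_%d_%H_%M}'. Standalone check of
# the naming scheme with a fixed datetime instead of datetime.now():
from datetime import datetime

stamp = '{}_{}_{:%Y_%m_%d_%H_%M}'.format('res', 4.0,
                                         datetime(2020, 5, 17, 9, 30))
assert stamp == 'res_4.0_2020_05_17_09_30'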
def main():
    # Select dataset to use
    if FLAGS.dataset == 'mnist':
        opts = configs.config_mnist
    elif FLAGS.dataset == 'smallNORB':
        opts = configs.config_smallNORB
    elif FLAGS.dataset == 'celebA':
        opts = configs.config_celebA
    else:
        assert False, 'Unknown dataset'

    # model
    opts['model'] = FLAGS.model
    opts['encoder'] = [FLAGS.encoder] * opts['nlatents']
    opts['use_sigmoid'] = FLAGS.sigmoid
    opts['archi'] = [FLAGS.net_archi] * opts['nlatents']
    opts['obs_cost'] = FLAGS.cost
    opts['lambda_schedule'] = FLAGS.lmba_schedule
    opts['enc_sigma_pen'] = FLAGS.enc_sigma_pen
    opts['dec_sigma_pen'] = FLAGS.dec_sigma_pen
    # opts['nlatents'] = 1
    # zdims = [2, 4, 8, 16]
    # id = (FLAGS.id - 1) % len(zdims)
    # opts['zdim'] = [zdims[id], ]
    # opts['lambda_init'] = [1, ]
    # opts['lambda'] = [1., ]
    # beta = opts['lambda']
    # opts['lambda_sigma'] = [1., ]

    # lambda
    beta = [0.0001, 1.]
    id = (FLAGS.id - 1) % len(beta)
    opts['lambda_init'] = [beta[id] for n in range(opts['nlatents'])]
    opts['lambda'] = [1. for n in range(opts['nlatents'])]

    # Create directories
    results_dir = 'results'
    if not tf.io.gfile.isdir(results_dir):
        utils.create_dir(results_dir)
    opts['out_dir'] = os.path.join(results_dir, FLAGS.out_dir)
    if not tf.io.gfile.isdir(opts['out_dir']):
        utils.create_dir(opts['out_dir'])
    out_subdir = os.path.join(opts['out_dir'], opts['model'])
    if not tf.io.gfile.isdir(out_subdir):
        utils.create_dir(out_subdir)
    # out_subdir = os.path.join(out_subdir, 'dz' + str(zdims[id]))
    # if not tf.io.gfile.isdir(out_subdir):
    #     utils.create_dir(out_subdir)
    opts['exp_dir'] = FLAGS.res_dir
    exp_dir = os.path.join(
        out_subdir,
        '{}_{}layers_lreg{}_{:%Y_%m_%d_%H_%M}'.format(opts['exp_dir'],
                                                      opts['nlatents'],
                                                      beta[id],
                                                      datetime.now()))
    opts['exp_dir'] = exp_dir
    if not tf.io.gfile.isdir(exp_dir):
        utils.create_dir(exp_dir)
        utils.create_dir(os.path.join(exp_dir, 'checkpoints'))

    # getting weights path
    if FLAGS.weights_file is not None:
        WEIGHTS_PATH = os.path.join(opts['exp_dir'], 'checkpoints',
                                    FLAGS.weights_file)
    else:
        WEIGHTS_PATH = None

    # Verbose
    logging.basicConfig(filename=os.path.join(exp_dir, 'outputs.log'),
                        level=logging.INFO,
                        format='%(asctime)s - %(message)s')

    # run set up
    opts['vizu_splitloss'] = FLAGS.losses
    opts['vizu_fullrec'] = FLAGS.reconstructions
    opts['vizu_embedded'] = FLAGS.embedded
    opts['vizu_latent'] = FLAGS.latents
    opts['fid'] = FLAGS.fid
    opts['it_num'] = FLAGS.num_it
    opts['print_every'] = int(opts['it_num'] / 4)
    opts['evaluate_every'] = int(opts['it_num'] / 50)
    if FLAGS.batch_size is not None:
        opts['batch_size'] = FLAGS.batch_size
    opts['lr'] = FLAGS.lr
    opts['use_trained'] = FLAGS.use_trained
    opts['save_every'] = 10000000000
    opts['save_final'] = FLAGS.save_model
    opts['save_train_data'] = FLAGS.save_data

    # Reset tf graph
    tf.compat.v1.reset_default_graph()

    # Loading the dataset
    opts['data_dir'] = FLAGS.data_dir
    data = DataHandler(opts)
    assert data.train_size >= opts['batch_size'], 'Training set too small'

    # build model
    run = Run(opts, data)

    # Training/testing/vizu
    if FLAGS.mode == "train":
        # Dumping all the configs to the text file
        with utils.o_gfile((opts['exp_dir'], 'params.txt'), 'w') as text:
            text.write('Parameters:\n')
            for key in opts:
                text.write('%s : %s\n' % (key, opts[key]))
        run.train(WEIGHTS_PATH)
    elif FLAGS.mode == "vizu":
        opts['rec_loss_nsamples'] = 1
        opts['sample_recons'] = False
        run.latent_interpolation(opts['exp_dir'], WEIGHTS_PATH)
    elif FLAGS.mode == "fid":
        run.fid_score(WEIGHTS_PATH)
    elif FLAGS.mode == "test":
        run.test_losses(WEIGHTS_PATH)
    elif FLAGS.mode == "vlae_exp":
        run.vlae_experiment(WEIGHTS_PATH)
    else:
        assert False, 'Unknown mode %s' % FLAGS.mode
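# --- Illustrative sketch, not part of the original script --------------------
# Every script above dumps opts to params.txt as "key : value" lines before
# training starts. A standalone version of that loop with plain file IO (the
# scripts use utils.o_gfile so the same code also works on remote paths):
opts_example = {'model': 'stackedwae', 'lr': 0.0001, 'batch_size': 100}
with open('params.txt', 'w') as text:
    text.write('Parameters:\n')
    for key in opts_example:
        text.write('%s : %s\n' % (key, opts_example[key]))
# params.txt now reads:
# Parameters:
# model : stackedwae
# lr : 0.0001
# batch_size : 100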