def main(argv):
  """Entry point: train a convolutional VAE, then evaluate it on the test split.

  Reads configuration from FLAGS, lays out the experiment's output
  directories under ``results_dir/expname``, and records their paths back
  into the params dict that is handed to training and testing.

  Args:
    argv: Positional command-line arguments remaining after flag parsing;
      anything beyond the program name is rejected.

  Raises:
    app.UsageError: If extra positional arguments are supplied.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  logging.info('print this')

  params = FLAGS.flag_values_dict()
  tf.set_random_seed(params['seed'])
  plt.rcParams['savefig.format'] = params['mpl_format']

  # The experiment root is resolved first; every other directory hangs off
  # it. The table below is walked in order, so a parent key is always
  # populated before any child that nests under it (sampledir -> figdir).
  # NOTE(review): utils.make_subdir(parent, name) is assumed to create the
  # directory and return its path -- confirm against utils.
  params['results_dir'] = utils.make_subdir(params['results_dir'],
                                            params['expname'])
  subdir_layout = (
      ('figdir', 'results_dir', 'figs'),
      ('sampledir', 'figdir', 'samples'),
      ('ckptdir', 'results_dir', 'ckpts'),
      ('logdir', 'results_dir', 'logs'),
      ('tensordir', 'results_dir', 'tensors'),
  )
  for key, parent_key, dirname in subdir_layout:
    params[key] = utils.make_subdir(params[parent_key], dirname)

  train_itr, valid_itr, test_itr = dataset_utils.load_dset_unsupervised()

  # Comma-separated flag strings -> lists of ints (presumably per-layer
  # channel counts and kernel sizes for the VAE's conv stack).
  conv_dims = list(map(int, params['conv_dims'].split(',')))
  conv_sizes = list(map(int, params['conv_sizes'].split(',')))

  model = VAE(conv_dims, conv_sizes)
  train_vae(model, train_itr, valid_itr, params)
  test_vae(model, test_itr, params)
def main(argv):
  """Train a VAE on in-distribution classes and evaluate it on held-out ones.

  The classes listed in the ``ood_classes`` flag are treated as
  out-of-distribution (OOD). The model is trained and validated on every
  remaining class, evaluated on the in-distribution test split, and then
  evaluated again on the OOD split.

  Args:
    argv: Positional command-line arguments remaining after flag parsing;
      anything beyond the program name is rejected.

  Raises:
    app.UsageError: If extra positional arguments are supplied, if
      ``ood_classes`` is not a comma-separated list of integers, if it names
      a class outside ``[0, n_classes)``, or if it covers every class
      (leaving nothing to train on).
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  logging.info('print this')
  params = FLAGS.flag_values_dict()

  # Parse and validate the class split before touching the filesystem, so a
  # bad flag fails fast without leaving half-created output directories.
  try:
    ood_classes = [int(x) for x in params['ood_classes'].split(',')]
  except ValueError:
    raise app.UsageError(
        'ood_classes must be a comma-separated list of integers, got %r.' %
        params['ood_classes'])
  # Assume we train on all non-OOD classes of a 10-class dataset
  # (NOTE(review): presumably MNIST, given run_vae_mnist -- confirm).
  n_classes = 10
  invalid = [c for c in ood_classes if not 0 <= c < n_classes]
  if invalid:
    raise app.UsageError(
        'ood_classes must lie in [0, %d), got %s.' % (n_classes, invalid))
  ind_classes = [x for x in range(n_classes) if x not in ood_classes]
  if not ind_classes:
    raise app.UsageError(
        'ood_classes covers every class; no in-distribution classes remain.')

  plt.rcParams['savefig.format'] = params['mpl_format']

  # Output layout rooted at results_dir/expname.
  # NOTE(review): utils.make_subdir(parent, name) is assumed to create the
  # directory and return its path -- confirm against utils.
  params['results_dir'] = utils.make_subdir(params['results_dir'],
                                            params['expname'])
  params['figdir'] = utils.make_subdir(params['results_dir'], 'figs')
  params['sampledir'] = utils.make_subdir(params['figdir'], 'samples')
  params['ckptdir'] = utils.make_subdir(params['results_dir'], 'ckpts')
  params['logdir'] = utils.make_subdir(params['results_dir'], 'logs')
  params['tensordir'] = utils.make_subdir(params['results_dir'], 'tensors')

  (itr_train, itr_valid, itr_test,
   itr_test_ood) = dataset_utils.load_dset_ood_unsupervised(
       ind_classes, ood_classes)

  conv_dims = [int(x) for x in params['conv_dims'].split(',')]
  conv_sizes = [int(x) for x in params['conv_sizes'].split(',')]
  vae = VAE(conv_dims, conv_sizes)

  run_vae_mnist.train_vae(vae, itr_train, itr_valid, params)
  run_vae_mnist.test_vae(vae, itr_test, params)

  # Point tensordir at a separate subdirectory before the OOD evaluation
  # (presumably test_vae writes its outputs under params['tensordir'], so
  # this keeps OOD results apart from the in-distribution ones -- confirm).
  params['tensordir'] = utils.make_subdir(params['results_dir'],
                                          'ood_tensors')
  run_vae_mnist.test_vae(vae, itr_test_ood, params)