def main(argv):
    """Train a model with the adagvae gin configs, extract it and score DCI.

    All artefacts are written below
    ``<directory of this script>/FLAGS.base_dir/FLAGS.output_dir`` in the
    sub-directories ``model``, ``representation`` and ``metrics/dci``.
    Existing outputs are always overwritten so the script can be rerun.

    Args:
        argv: Command-line arguments; unused.
    """
    del argv  # Unused.

    # Root directory that receives every output of this run.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    run_dir = os.path.join(script_dir, FLAGS.base_dir, FLAGS.output_dir)

    # Always overwrite previous outputs (makes rerunning the script painless).
    overwrite = True

    model_dir = os.path.join(run_dir, "model")
    representation_dir = os.path.join(run_dir, "representation")
    dci_dir = os.path.join(run_dir, "metrics", "dci")

    # Command-line overrides layered on top of the training gin config.
    training_bindings = [
        "model.random_seed = {}".format(FLAGS.seed),
        "subset.name = '{}'".format(FLAGS.subset),
        "encoder.num_latent = {}".format(FLAGS.dim),
    ]

    # 1. Train the model; the training setup itself lives in the gin config.
    train.train_with_gin(
        model_dir,
        overwrite,
        ['baselines/adagvae/adagvae_train.gin'],
        training_bindings)

    # 2. Extract the representation; the settings live in the gin config.
    postprocess.postprocess_with_gin(
        model_dir,
        representation_dir,
        overwrite,
        ['baselines/adagvae/adagvae_postprocess.gin'])

    # 3. Compute the DCI metric on the extracted representation.
    evaluate.evaluate_with_gin(
        representation_dir,
        dci_dir,
        overwrite,
        ['baselines/adagvae/adagvae_dci.gin'])
def test_evaluate(self, gin_config):
    """Run `evaluate.evaluate_with_gin` once with the given gin config file.

    Args:
        gin_config: Path of the gin config file passed to the evaluation.
    """
    # Reset gin first: if an earlier test failed, the gin config stays locked
    # and this test would fail as well.
    gin.clear_config()

    input_dir = self.output_dir
    result_dir = self.create_tempdir().full_path
    evaluate.evaluate_with_gin(input_dir, result_dir, True, [gin_config])
def eval(study, output_directory, model_dir):
    """Postprocess a trained model and compute every metric of the study.

    For each postprocess config of `study`, a representation is extracted
    into ``<output_directory>/postprocessed/<post_name>``. Afterwards every
    evaluation config of the study is run on every extracted representation
    and stored in ``<output_directory>/metrics/<post_name>/<metric_name>``.

    Args:
        study: Object providing `get_postprocess_config_files()` and
            `get_eval_config_files()`.
        output_directory: Root directory for the postprocessed
            representations and the metric results.
        model_dir: Directory of the trained model to postprocess.
    """
    # Postprocessing and evaluation seeds are derived from a master seed of 0,
    # so each config gets a different but reproducible seed. The model seed
    # was already fixed through the gin bindings/configs of the study.
    rng = np.random.RandomState(0)

    def _config_stem(config_path):
        # File name of the gin config with the ".gin" part removed.
        return os.path.basename(config_path).replace(".gin", "")

    postprocessed_root = os.path.join(output_directory, "postprocessed")

    # Stage 1: extract the different representations and save them to disk.
    post_configs = sorted(study.get_postprocess_config_files())
    for post_config in post_configs:
        name = _config_stem(post_config)
        logging.info("Extracting representation %s...", name)
        target_dir = os.path.join(postprocessed_root, name)
        bindings = [
            "postprocess.random_seed = {}".format(rng.randint(2**16)),
            "postprocess.name = '{}'".format(name),
        ]
        postprocess.postprocess_with_gin(
            model_dir, target_dir, FLAGS.overwrite, [post_config], bindings)

    # Stage 2: compute all specified scores on every representation.
    metric_configs = sorted(study.get_eval_config_files())
    for post_config in post_configs:
        name = _config_stem(post_config)
        source_dir = os.path.join(postprocessed_root, name)
        for metric_config in metric_configs:
            metric = _config_stem(metric_config)
            logging.info("Computing metric '%s' on '%s'...", metric, name)
            target_dir = os.path.join(
                output_directory, "metrics", name, metric)
            bindings = [
                "evaluation.random_seed = {}".format(rng.randint(2**16)),
                "evaluation.name = '{}'".format(metric),
            ]
            evaluate.evaluate_with_gin(
                source_dir, target_dir, FLAGS.overwrite, [metric_config],
                bindings)
def main(unused_argv):
    """Run the complete 3dshapes pipeline for one model variant and seed.

    The stages run in order: preprocessing (train/validation split), VAE
    training plus visualization, mean-representation extraction, the MIG,
    BetaVAE, FactorVAE and DCI metrics, and finally two downstream tasks
    (factor regression and reconstruction). All outputs are written below
    ``3dshapes_models/<FLAGS.model><FLAGS.dataset>_<FLAGS.rng>``.

    Args:
        unused_argv: Command-line arguments; unused.

    Raises:
        ValueError: If ``FLAGS.model`` is not one of the supported variants.
    """
    base_path = "3dshapes_models"

    # Resolve the training config first so that an unsupported model name
    # fails immediately instead of after the preprocessing stage has run.
    gin_files = {
        "vae": "3d_shape_vae.gin",
        "bvae": "3d_shape_bvae.gin",
        "b8vae": "3d_shape_b8vae.gin",
        "fvae": "3d_shape_fvae.gin",
        "btcvae": "3d_shape_btcvae.gin",
        "annvae": "3d_shape_annvae.gin",
        "randomvae": "3d_shape_randomvae.gin",
    }
    if FLAGS.model not in gin_files:
        raise ValueError(
            "Unsupported model '%s'; expected one of: %s"
            % (FLAGS.model, ", ".join(sorted(gin_files))))
    gin_file = gin_files[FLAGS.model]

    # Name of the per-model train/validation split created by preprocessing.
    split_name = FLAGS.dataset + '_' + FLAGS.model + '_' + str(FLAGS.rng)

    # --- Preprocessing
    # Built outside the retry loop so that a malformed flag value surfaces as
    # an error instead of being retried forever.
    preprocess_gin_bindings = [
        "dataset.name = '%s'" % (FLAGS.dataset),
        "preprocess.preprocess_fn = @split_train_and_validation_per_model",
        "split_train_and_validation_per_model.random_seed = %d" % (FLAGS.rng)
    ]
    # Preprocessing is retried every 30 seconds until it succeeds.
    # NOTE(review): presumably this waits for a resource shared with other
    # runs -- confirm. Only `Exception` is caught so that Ctrl-C and
    # SystemExit still terminate the script, and the failure is reported
    # rather than silently swallowed.
    while True:
        try:
            print("\n\n*- Preprocessing '%s' \n\n" % (FLAGS.dataset))
            preprocess.preprocess_with_gin(
                FLAGS.dataset,
                FLAGS.model,
                overwrite=FLAGS.overwrite,
                gin_config_files=None,
                gin_bindings=preprocess_gin_bindings)
            print("\n\n*- Preprocessing DONE \n\n")
            break
        except Exception as err:  # Deliberate best-effort retry.
            print("\n\n*- Preprocessing failed (%r); retrying in 30 s \n\n"
                  % (err,))
            time.sleep(30)

    # --- Training
    print("\n\n*- Training '%s' \n\n" % (FLAGS.model))
    vae_gin_bindings = [
        "model.random_seed = %d" % (FLAGS.rng),
        "dataset.name = '%s'" % (split_name)
    ]
    vae_path = os.path.join(
        base_path, FLAGS.model + FLAGS.dataset + '_' + str(FLAGS.rng))
    model_path = os.path.join(vae_path, "model")
    unsupervised_train_partial.train_with_gin(
        model_path, FLAGS.overwrite, [gin_file], vae_gin_bindings)
    visualize_model.visualize(
        model_path, os.path.join(vae_path, "vis"), FLAGS.overwrite)
    # The splits are only needed for training; remove them afterwards.
    preprocess.destroy_train_and_validation_splits(split_name)
    print("\n\n*- Training DONE \n\n")

    # --- Postprocessing
    print("\n\n*- Postprocessing '%s' \n\n" % (FLAGS.model))
    postprocess_gin_bindings = [
        "postprocess.postprocess_fn = @mean_representation",
        "dataset.name='dummy_data'",
        "postprocess.random_seed = %d" % (FLAGS.rng)
    ]
    representation_path = os.path.join(vae_path, "representation")
    postprocess.postprocess_with_gin(
        model_path,
        representation_path,
        FLAGS.overwrite,
        gin_config_files=None,
        gin_bindings=postprocess_gin_bindings)
    print("\n\n*- Postprocessing DONE \n\n")

    # --- Evaluate disentanglement metrics
    print("\n\n*- Evaluating MIG.")
    gin_bindings = [
        "evaluation.evaluation_fn = @mig",
        "dataset.name='3dshapes'",
        "evaluation.random_seed = 0",
        "mig.num_train = 10000",
        "discretizer.discretizer_fn = @histogram_discretizer",
        "discretizer.num_bins = 20"
    ]
    result_path = os.path.join(vae_path, "metrics", "mig")
    evaluate.evaluate_with_gin(
        representation_path, result_path, FLAGS.overwrite,
        gin_bindings=gin_bindings)

    print("\n\n*- Evaluating BetaVEA.")
    gin_bindings = [
        "evaluation.evaluation_fn = @beta_vae_sklearn",
        "dataset.name='3dshapes'",
        "evaluation.random_seed = 0",
        "beta_vae_sklearn.batch_size = 16",
        "beta_vae_sklearn.num_train = 10000",
        "beta_vae_sklearn.num_eval = 5000",
        "discretizer.discretizer_fn = @histogram_discretizer",
        "discretizer.num_bins = 20"
    ]
    result_path = os.path.join(vae_path, "metrics", "bvae")
    evaluate.evaluate_with_gin(
        representation_path, result_path, FLAGS.overwrite,
        gin_bindings=gin_bindings)

    print("\n\n*- Evaluating FactorVAE.")
    gin_bindings = [
        "evaluation.evaluation_fn = @factor_vae_score",
        "dataset.name='3dshapes'",
        "evaluation.random_seed = 0",
        "factor_vae_score.batch_size = 16",
        "factor_vae_score.num_train = 10000",
        "factor_vae_score.num_eval = 5000",
        "factor_vae_score.num_variance_estimate = 10000",
        "discretizer.discretizer_fn = @histogram_discretizer",
        "discretizer.num_bins = 20"
    ]
    result_path = os.path.join(vae_path, "metrics", "fvae")
    evaluate.evaluate_with_gin(
        representation_path, result_path, FLAGS.overwrite,
        gin_bindings=gin_bindings)

    print("\n\n*- Evaluating DCI.")
    gin_bindings = [
        "evaluation.evaluation_fn = @dci",
        "dataset.name='3dshapes'",
        "evaluation.random_seed = 0",
        "dci.batch_size = 16",
        "dci.num_train = 10000",
        "dci.num_test = 5000",
        "discretizer.discretizer_fn = @histogram_discretizer",
        "discretizer.num_bins = 20"
    ]
    result_path = os.path.join(vae_path, "metrics", "dci")
    evaluate.evaluate_with_gin(
        representation_path, result_path, FLAGS.overwrite,
        gin_bindings=gin_bindings)
    print("\n\n*- Evaluation COMPLETED \n\n")

    # --- Downstream tasks
    print("\n\n*- Training downstream factor regression '%s' \n\n"
          % (FLAGS.model))
    downstream_regression_train_gin_bindings = [
        "evaluation.evaluation_fn = @downstream_regression_on_representations",
        "dataset.name = '3dshapes_task'",
        "evaluation.random_seed = 0",
        "downstream_regression_on_representations.holdout_dataset_name = '3dshapes_holdout'",
        "downstream_regression_on_representations.num_train = [127500]",
        "downstream_regression_on_representations.num_test = 22500",
        "downstream_regression_on_representations.num_holdout = 80000",
        "predictor.predictor_fn = @mlp_regressor",
        "mlp_regressor.hidden_layer_sizes = [16, 8]",
        "mlp_regressor.activation = 'logistic'",
        "mlp_regressor.max_iter = 50",
        "mlp_regressor.random_state = 0"
    ]
    result_path = os.path.join(vae_path, "metrics", "factor_regression")
    evaluate.evaluate_with_gin(
        representation_path,
        result_path,
        FLAGS.overwrite,
        gin_config_files=None,
        gin_bindings=downstream_regression_train_gin_bindings)
    print("\n\n*- Training downstream factor regression DONE \n\n")

    print("\n\n*- Training downstream reconstruction '%s' \n\n"
          % (FLAGS.model))
    downstream_reconstruction_train_gin_bindings = [
        "supervised_model.model = @downstream_decoder()",
        "supervised_model.batch_size = 64",
        "supervised_model.training_steps = 30000",
        "supervised_model.eval_steps = 1000",
        "supervised_model.random_seed = 0",
        "supervised_model.holdout_dataset_name = '3dshapes_holdout'",
        "dataset.name='3dshapes_task'",
        "decoder_optimizer.optimizer_fn = @AdamOptimizer",
        "AdamOptimizer.beta1 = 0.9",
        "AdamOptimizer.beta2 = 0.999",
        "AdamOptimizer.epsilon = 1e-08",
        "AdamOptimizer.learning_rate = 0.0001",
        "AdamOptimizer.name = 'Adam'",
        "AdamOptimizer.use_locking = False",
        "decoder.decoder_fn = @deconv_decoder",
        "reconstruction_loss.loss_fn = @l2_loss"
    ]
    result_path = os.path.join(vae_path, "metrics", "reconstruction")
    supervised_train_partial.train_with_gin(
        result_path,
        representation_path,
        FLAGS.overwrite,
        gin_bindings=downstream_reconstruction_train_gin_bindings)
    visualize_model.visualize_supervised(
        result_path, representation_path, os.path.join(result_path, "vis"),
        FLAGS.overwrite)
    print("\n\n*- Training downstream reconstruction DONE \n\n")
    print("\n\n*- Training & evaluation COMPLETED \n\n")
# We use the Mutual Information Gap (with a low number of samples to make it # faster). To learn more, have a look at the different scores in # disentanglement_lib.evaluation.evaluate.metrics and the predefined .gin # configuration files in # disentanglement_lib/config/unsupervised_study_v1/metrics_configs/(...). gin_bindings = [ "evaluation.evaluation_fn = @mig", "dataset.name='auto'", "evaluation.random_seed = 0", "mig.num_train=1000", "discretizer.discretizer_fn = @histogram_discretizer", "discretizer.num_bins = 20" ] for path in [path_vae, path_custom_vae]: result_path = os.path.join(path, "metrics", "mig") representation_path = os.path.join(path, "representation") evaluate.evaluate_with_gin(representation_path, result_path, overwrite, gin_bindings=gin_bindings) # 5. Compute a custom disentanglement metric for both models. # ------------------------------------------------------------------------------ # The following function implements a dummy metric. Note that all metrics get # ground_truth_data, representation_function, random_state arguments by the # evaluation protocol, while all other arguments have to be configured via gin. @gin.configurable( "custom_metric", blacklist=["ground_truth_data", "representation_function", "random_state"]) def compute_custom_metric(ground_truth_data, representation_function, random_state, num_train=gin.REQUIRED,
def main(unused_argv):
    """Evaluate MIG, BetaVAE, FactorVAE and DCI on a stored representation.

    The representation is read from
    ``3dshapes_models/<FLAGS.model><FLAGS.dataset>_<FLAGS.rng>/representation``
    and each metric result is written to a ``metrics/<name>_10000``
    sub-directory next to it.

    Args:
        unused_argv: Command-line arguments; unused.
    """
    base_path = "3dshapes_models"
    print("\n\n*- Evaluating '%s' \n\n" % (FLAGS.model))

    run_name = FLAGS.model + FLAGS.dataset + '_' + str(FLAGS.rng)
    vae_path = os.path.join(base_path, run_name)
    representation_path = os.path.join(vae_path, "representation")
    print(vae_path, representation_path)

    # Bindings shared by all four metrics; they surround the metric-specific
    # ones so each assembled list keeps its original order.
    dataset_and_seed = [
        "dataset.name='3dshapes'",
        "evaluation.random_seed = 0",
    ]
    discretizer = [
        "discretizer.discretizer_fn = @histogram_discretizer",
        "discretizer.num_bins = 20",
    ]

    # One row per metric:
    # (banner, result sub-directory, evaluation function binding,
    #  metric-specific bindings).
    # NOTE(review): the sub-directories end in "_10000" while the training
    # sample counts below are 100000 -- confirm which one is intended.
    metric_runs = (
        ("\n\n*- Evaluating MIG.",
         "mig_10000",
         "evaluation.evaluation_fn = @mig",
         ["mig.num_train = 100000"]),
        ("\n\n*- Evaluating BetaVEA.",
         "bvae_10000",
         "evaluation.evaluation_fn = @beta_vae_sklearn",
         ["beta_vae_sklearn.batch_size = 16",
          "beta_vae_sklearn.num_train = 100000",
          "beta_vae_sklearn.num_eval = 5000"]),
        ("\n\n*- Evaluating FactorVAE.",
         "fvae_10000",
         "evaluation.evaluation_fn = @factor_vae_score",
         ["factor_vae_score.batch_size = 16",
          "factor_vae_score.num_train = 100000",
          "factor_vae_score.num_eval = 5000",
          "factor_vae_score.num_variance_estimate = 100000"]),
        ("\n\n*- Evaluating DCI.",
         "dci_10000",
         "evaluation.evaluation_fn = @dci",
         ["dci.batch_size = 16",
          "dci.num_train = 100000",
          "dci.num_test = 5000"]),
    )

    for banner, subdir, evaluation_fn, metric_specific in metric_runs:
        print(banner)
        gin_bindings = (
            [evaluation_fn] + dataset_and_seed + metric_specific + discretizer)
        result_path = os.path.join(vae_path, "metrics", subdir)
        evaluate.evaluate_with_gin(
            representation_path, result_path, FLAGS.overwrite,
            gin_bindings=gin_bindings)

    print("\n\n*- Evaluation COMPLETED \n\n")
def main():
    """Train one model of a sweep, then postprocess, evaluate and visualize it.

    The run directory is ``<result_dir>/GroupVAE-<hyps joined by '-'>`` for
    ``--model_name GroupVAE`` and ``<result_dir>/LieVAE-<vae_beta>-<seed>``
    for ``--model_name vae``. It receives the sub-directories ``model``,
    ``postprocessed``, ``metrics`` and ``visualizations``.

    ``--hyps`` is an underscore-separated list whose six fields are bound, in
    order, to ``GroupVAE.hy_rec``, ``GroupVAE.hy_mat``, ``GroupVAE.hy_oth``,
    ``GroupVAE.hy_spl``, ``GroupVAE.hy_ncut`` and ``model.random_seed``.

    Invalid ``--model_name`` or ``--hyps`` values are rejected through
    ``parser.error`` before any work is done.
    """
    parser = argparse.ArgumentParser(description='Project description.')
    parser.add_argument('--result_dir',
                        help='Results directory.',
                        type=str,
                        default='/mnt/hdd/repo_results/Ramiel/sweep')
    parser.add_argument('--study',
                        help='Name of the study.',
                        type=str,
                        default='unsupervised_study_v1')
    parser.add_argument('--model_gin',
                        help='Name of the gin config.',
                        type=str,
                        default='test_model.gin')
    parser.add_argument('--model_name',
                        help='Name of the model.',
                        type=str,
                        default='GroupVAE')
    parser.add_argument('--vae_beta',
                        help='Beta-VAE beta.',
                        type=str,
                        default='1')
    parser.add_argument('--hyps',
                        help='Hyperparameters of rec_mat_oth_spl_seed.',
                        type=str,
                        default='1_1_1_1_1_0')
    parser.add_argument('--overwrite',
                        help='Whether to overwrite output directory.',
                        type=_str_to_bool,
                        default=False)
    parser.add_argument('--dataset',
                        help='Dataset.',
                        type=str,
                        default='dsprites_full')
    parser.add_argument('--recons_type',
                        help='Reconstruction loss type.',
                        type=str,
                        default='bernoulli_loss')
    args = parser.parse_args()

    # Fail fast on arguments the rest of the script cannot handle. Without
    # these checks an unknown model name left `dir_name` unbound and a short
    # --hyps value raised an IndexError, both only after setup had started.
    if args.model_name not in ("GroupVAE", "vae"):
        parser.error(
            "--model_name must be 'GroupVAE' or 'vae', got '{}'".format(
                args.model_name))
    hyps = args.hyps.split('_')
    if len(hyps) < 6:
        parser.error(
            "--hyps needs at least 6 underscore-separated fields, got "
            "'{}'".format(args.hyps))

    # 1. Settings
    study = reproduce.STUDIES[args.study]
    args.hyps = hyps
    print()
    study.print_postprocess_config()
    print()
    study.print_eval_config()

    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpus:
        try:
            # Currently, memory growth needs to be the same across GPUs.
            for gpu in gpus:
                tf.config.experimental.set_memory_growth(gpu, True)
            logical_gpus = tf.config.experimental.list_logical_devices('GPU')
            print(len(gpus), "Physical GPUs,", len(logical_gpus),
                  "Logical GPUs")
        except RuntimeError as e:
            # Memory growth must be set before GPUs have been initialized.
            print(e)

    # Call training module to train the custom model.
    if args.model_name == "GroupVAE":
        dir_name = "GroupVAE-" + "-".join(args.hyps)
    else:  # "vae"; any other name was rejected above.
        dir_name = "LieVAE-" + args.vae_beta + "-" + args.hyps[5]
    output_directory = os.path.join(args.result_dir, dir_name)
    model_dir = os.path.join(output_directory, "model")
    gin_bindings = [
        "model.model = @" + args.model_name + "()",
        "vae.beta = " + args.vae_beta,
        "GroupVAE.hy_rec = " + args.hyps[0],
        "GroupVAE.hy_mat = " + args.hyps[1],
        "GroupVAE.hy_oth = " + args.hyps[2],
        "GroupVAE.hy_spl = " + args.hyps[3],
        "GroupVAE.hy_ncut = " + args.hyps[4],
        "model.random_seed = " + args.hyps[5],
        "dataset.name = '" + args.dataset + "'",
        "reconstruction_loss.loss_fn = @" + args.recons_type
    ]
    train.train_with_gin(model_dir, args.overwrite, [args.model_gin],
                         gin_bindings)

    # We fix the random seed for the postprocessing and evaluation steps
    # (each config gets a different but reproducible seed derived from a
    # master seed of 0). The model seed was set via the gin bindings and
    # configs of the study.
    random_state = np.random.RandomState(0)

    # We extract the different representations and save them to disk.
    postprocess_config_files = sorted(study.get_postprocess_config_files())
    for config in postprocess_config_files:
        post_name = os.path.basename(config).replace(".gin", "")
        print("Extracting representation " + post_name + "...")
        post_dir = os.path.join(output_directory, "postprocessed", post_name)
        postprocess_bindings = [
            "postprocess.random_seed = {}".format(
                random_state.randint(2**32)),
            "postprocess.name = '{}'".format(post_name)
        ]
        postprocess.postprocess_with_gin(model_dir, post_dir, args.overwrite,
                                         [config], postprocess_bindings)

    # Iterate through the disentanglement metrics.
    eval_configs = sorted(study.get_eval_config_files())
    # Metric configs that are skipped during evaluation.
    blacklist = ['downstream_task_logistic_regression.gin']
    # blacklist = [
    #     'downstream_task_logistic_regression.gin', 'beta_vae_sklearn.gin',
    #     'dci.gin', 'downstream_task_boosted_trees.gin', 'mig.gin',
    #     'modularity_explicitness.gin', 'sap_score.gin', 'unsupervised.gin'
    # ]
    for config in postprocess_config_files:
        post_name = os.path.basename(config).replace(".gin", "")
        post_dir = os.path.join(output_directory, "postprocessed", post_name)
        # Now, we compute all the specified scores.
        for gin_eval_config in eval_configs:
            if os.path.basename(gin_eval_config) in blacklist:
                continue
            metric_name = os.path.basename(gin_eval_config).replace(
                ".gin", "")
            print("Computing metric " + metric_name + " on " + post_name +
                  "...")
            metric_dir = os.path.join(output_directory, "metrics", post_name,
                                      metric_name)
            eval_bindings = [
                "evaluation.random_seed = {}".format(
                    random_state.randint(2**32)),
                "evaluation.name = '{}'".format(metric_name)
            ]
            evaluate.evaluate_with_gin(post_dir, metric_dir, args.overwrite,
                                       [gin_eval_config], eval_bindings)

    # We visualize reconstructions, samples and latent space traversals.
    visualize_dir = os.path.join(output_directory, "visualizations")
    visualize_model.visualize(model_dir, visualize_dir, args.overwrite)
def main():
    """Reproduce one model of a study: train, visualize, postprocess, evaluate.

    Training is skipped when ``--model_dir`` points at an existing model.
    With ``--only_print`` the function returns after printing the
    hyperparameter, postprocess and evaluation settings. Results are stored
    in ``--output_directory``; when that is omitted they go to
    ``output/<study>/<model_num>`` (training run) or ``output`` (model
    directory given).
    """
    parser = argparse.ArgumentParser(description='Project description.')
    parser.add_argument(
        '--study',
        help='Name of the study.',
        type=str,
        default='unsupervised_study_v1')
    parser.add_argument(
        '--output_directory',
        help='Output directory of experiments.',
        type=str,
        default=None)
    parser.add_argument(
        '--model_dir',
        help='Directory to take trained model from.',
        type=str,
        default=None)
    parser.add_argument(
        '--model_num',
        help='Integer with model number to train.',
        type=int,
        default=None)
    parser.add_argument(
        '--only_print',
        help='Whether to only print the hyperparameter settings.',
        type=_str_to_bool,
        default=False)
    parser.add_argument(
        '--overwrite',
        help='Whether to overwrite output directory.',
        type=_str_to_bool,
        default=False)
    args = parser.parse_args()
    # logging.set_verbosity('error')
    # logging.set_stderrthreshold('error')

    physical_gpus = tf.config.experimental.list_physical_devices('GPU')
    if physical_gpus:
        try:
            # Memory growth currently has to be identical across all GPUs.
            for device in physical_gpus:
                tf.config.experimental.set_memory_growth(device, True)
            logical_gpus = tf.config.experimental.list_logical_devices('GPU')
            print(len(physical_gpus), "Physical GPUs,", len(logical_gpus),
                  "Logical GPUs")
        except RuntimeError as e:
            # Memory growth must be set before GPUs have been initialized.
            print(e)

    # Obtain the study to reproduce.
    study = reproduce.STUDIES[args.study]
    train_model = args.model_dir is None

    # Print the hyperparameter settings.
    if train_model:
        study.print_model_config(args.model_num)
    else:
        print("Model directory (skipped training):")
        print("--")
        print(args.model_dir)
    print()
    study.print_postprocess_config()
    print()
    study.print_eval_config()
    if args.only_print:
        return

    # Pick the output directory; the default depends on whether we train.
    output_directory = args.output_directory
    if output_directory is None:
        if train_model:
            output_directory = os.path.join("output", "{study}", "{model_num}")
        else:
            output_directory = "output"

    # Insert model number and study name into the path if necessary.
    output_directory = output_directory.format(
        model_num=str(args.model_num), study=str(args.study))

    # Model training (only if no model directory was provided).
    if train_model:
        model_bindings, model_config_file = study.get_model_config(
            args.model_num)
        print("Training model...")
        model_dir = os.path.join(output_directory, "model")
        # The ".gin" removal is applied to the complete binding string.
        name_binding = "model.name = '{}'".format(
            os.path.basename(model_config_file))
        name_binding = name_binding.replace(".gin", "")
        number_binding = "model.model_num = {}".format(args.model_num)
        model_bindings = [name_binding, number_binding] + model_bindings
        train.train_with_gin(model_dir, args.overwrite, [model_config_file],
                             model_bindings)
    else:
        print("Skipped training...")
        model_dir = args.model_dir

    # Visualize reconstructions, samples and latent space traversals.
    visualize_model.visualize(
        model_dir, os.path.join(output_directory, "visualizations"),
        args.overwrite)

    # Postprocessing and evaluation seeds are derived from a master seed of 0
    # (each config gets a different but reproducible seed). The model seed
    # was set via the gin bindings and configs of the study.
    rng = np.random.RandomState(0)

    def _config_stem(config_path):
        # File name of the gin config with the ".gin" part removed.
        return os.path.basename(config_path).replace(".gin", "")

    # Extract the different representations and save them to disk.
    post_configs = sorted(study.get_postprocess_config_files())
    for post_config in post_configs:
        post_name = _config_stem(post_config)
        print("Extracting representation %s..." % post_name)
        post_dir = os.path.join(output_directory, "postprocessed", post_name)
        postprocess_bindings = [
            "postprocess.random_seed = {}".format(rng.randint(2**32)),
            "postprocess.name = '{}'".format(post_name),
        ]
        postprocess.postprocess_with_gin(
            model_dir, post_dir, args.overwrite, [post_config],
            postprocess_bindings)

    # Iterate through the disentanglement metrics.
    metric_configs = sorted(study.get_eval_config_files())
    # Metric configs that are skipped during evaluation.
    blacklist = ['downstream_task_logistic_regression.gin']
    for post_config in post_configs:
        post_name = _config_stem(post_config)
        post_dir = os.path.join(output_directory, "postprocessed", post_name)
        # Compute all the specified scores on this representation.
        for metric_config in metric_configs:
            if os.path.basename(metric_config) in blacklist:
                continue
            metric_name = _config_stem(metric_config)
            print("Computing metric '%s' on '%s'..." % (metric_name,
                                                        post_name))
            metric_dir = os.path.join(output_directory, "metrics", post_name,
                                      metric_name)
            eval_bindings = [
                "evaluation.random_seed = {}".format(rng.randint(2**32)),
                "evaluation.name = '{}'".format(metric_name),
            ]
            evaluate.evaluate_with_gin(
                post_dir, metric_dir, args.overwrite, [metric_config],
                eval_bindings)
model_path = os.path.join(path, "model") postprocess_gin = ["postprocess.gin"] # This contains the settings. # postprocess.postprocess_with_gin defines the standard extraction protocol. postprocess.postprocess_with_gin(model_path, representation_path, overwrite, postprocess_gin) # 5. Compute metrics for all models. metrics = glob.glob('../disentanglement_lib/disentanglement_lib/config/unsupervised_study_v1/metric_configs/*.gin') blacklist = ['downstream_task_logistic_regression.gin'] for path in [path_custom_vae]: for metric in metrics: if os.path.basename(metric) not in blacklist: result_path = os.path.join(path, "metrics", os.path.basename(metric).replace('.gin', '')) representation_path = os.path.join(path, "representation") evaluate.evaluate_with_gin(representation_path, result_path, overwrite, gin_config_files=[metric]) # 6. Aggregate the results. # ------------------------------------------------------------------------------ # In the previous steps, we saved the scores to several output directories. We # can aggregate all the results using the following command. pattern = os.path.join(base_path, "*/metrics/*/results/aggregate/evaluation.json") results_path = os.path.join(base_path, "results.json") aggregate_results.aggregate_results_to_json(pattern, results_path) # 7. Print out the final Pandas data frame with the results. # ------------------------------------------------------------------------------ # The aggregated results contains for each computed metric all the configuration # options and all the results captured in the steps along the pipeline. This
result_path, representation_path, overwrite, gin_bindings=downstream_reconstruction_train_gin_bindings ) #["3d_shape_classifier.gin"])#gin_bindings=gin_bindings) pa = 1 / 0 downstream_train_gin_bindings = [ "evaluation.evaluation_fn = @downstream_regression_on_representations", "dataset.name = '3dshapes_task'", "evaluation.random_seed = 111", "downstream_regression_on_representations.num_train = [127500]", "downstream_regression_on_representations.num_test = 22500", "predictor.predictor_fn = @mlp_regressor", "mlp_regressor.hidden_layer_sizes = [32, 16]", "mlp_regressor.activation = 'logistic'", "mlp_regressor.max_iter = 100", "mlp_regressor.random_state = 0" ] for path in [path_vae]: result_path = os.path.join(path, "metrics", "TEST_factor_regression") representation_path = os.path.join(path, "representation") evaluate.evaluate_with_gin( representation_path, result_path, overwrite, gin_bindings=downstream_train_gin_bindings ) #["3d_shape_classifier.gin"])#gin_bindings=gin_bindings) pa = 1 / 0