Example #1
def main(argv):
    del argv # Unused

    # Save all results in subdirectories of following path
    base_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), FLAGS.base_dir)

    # Overwrite output or not (for rerunning script)
    overwrite = True

    # Results directory of Ada-GVAE
    path_adagvae = os.path.join(base_path, FLAGS.output_dir)

    gin_bindings = [
        "model.random_seed = {}".format(FLAGS.seed),
        "subset.name = '{}'".format(FLAGS.subset),
        "encoder.num_latent = {}".format(FLAGS.dim)
    ]
    # Train model. Training is configured with a gin config
    train.train_with_gin(os.path.join(path_adagvae, 'model'), overwrite,
                         ['baselines/adagvae/adagvae_train.gin'], gin_bindings)

    # Extract mean representation of latent space
    representation_path = os.path.join(path_adagvae, "representation")
    model_path = os.path.join(path_adagvae, "model")
    postprocess_gin = ['baselines/adagvae/adagvae_postprocess.gin']  # This contains the settings.
    postprocess.postprocess_with_gin(model_path, representation_path, overwrite,
                                     postprocess_gin)

    # Compute DCI metric
    result_path = os.path.join(path_adagvae, "metrics", "dci")
    representation_path = os.path.join(path_adagvae, "representation")
    evaluate.evaluate_with_gin(representation_path, result_path, overwrite, ['baselines/adagvae/adagvae_dci.gin'])
Example #2
def test_evaluate(self, gin_config):
    # We clear the gin config before running. Otherwise, if a prior test fails,
    # the gin config is locked and the current test fails.
    gin.clear_config()
    evaluate.evaluate_with_gin(self.output_dir,
                               self.create_tempdir().full_path, True,
                               [gin_config])
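This test method is shown without its surrounding class. Below is a minimal, hypothetical sketch of how it might be embedded in a parameterized test case; the class name, the setUp body, and the specific metric configs passed to @parameterized.parameters are illustrative assumptions, not the original test.

import gin
from absl.testing import absltest
from absl.testing import parameterized
from disentanglement_lib.evaluation import evaluate


class EvaluateTest(parameterized.TestCase):

    def setUp(self):
        super().setUp()
        # Assumption: this directory already contains a postprocessed
        # representation produced by an earlier pipeline step.
        self.output_dir = "path/to/representation"

    @parameterized.parameters(
        ("metric_configs/mig.gin",),
        ("metric_configs/dci.gin",),
    )
    def test_evaluate(self, gin_config):
        # Clear the gin config so a failure in a previous test cannot leave
        # gin locked for this one.
        gin.clear_config()
        evaluate.evaluate_with_gin(self.output_dir,
                                   self.create_tempdir().full_path, True,
                                   [gin_config])


if __name__ == "__main__":
    absltest.main()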
Example #3
def eval(study, output_directory, model_dir):
    # We fix the random seed for the postprocessing and evaluation steps (each
    # config gets a different but reproducible seed derived from a master seed of
    # 0). The model seed was set via the gin bindings and configs of the study.
    random_state = np.random.RandomState(0)

    # We extract the different representations and save them to disk.
    postprocess_config_files = sorted(study.get_postprocess_config_files())
    for config in postprocess_config_files:
        post_name = os.path.basename(config).replace(".gin", "")
        logging.info("Extracting representation %s...", post_name)
        post_dir = os.path.join(output_directory, "postprocessed", post_name)
        postprocess_bindings = [
            "postprocess.random_seed = {}".format(random_state.randint(2**16)),
            "postprocess.name = '{}'".format(post_name)
        ]
        postprocess.postprocess_with_gin(model_dir, post_dir, FLAGS.overwrite,
                                         [config], postprocess_bindings)

    # Iterate through the disentanglement metrics.
    eval_configs = sorted(study.get_eval_config_files())
    for config in postprocess_config_files:
        post_name = os.path.basename(config).replace(".gin", "")
        post_dir = os.path.join(output_directory, "postprocessed", post_name)
        # Now, we compute all the specified scores.
        for gin_eval_config in eval_configs:
            metric_name = os.path.basename(gin_eval_config).replace(".gin", "")
            logging.info("Computing metric '%s' on '%s'...", metric_name,
                         post_name)
            metric_dir = os.path.join(output_directory, "metrics", post_name,
                                      metric_name)
            eval_bindings = [
                "evaluation.random_seed = {}".format(
                    random_state.randint(2**16)),
                "evaluation.name = '{}'".format(metric_name)
            ]
            evaluate.evaluate_with_gin(post_dir, metric_dir, FLAGS.overwrite,
                                       [gin_eval_config], eval_bindings)
Example #4
def main(unused_argv):
    base_path = "3dshapes_models"

    done = False
    while not done:
        try:
            print("\n\n*- Preprocessing '%s' \n\n" % (FLAGS.dataset))
            preprocess_gin_bindings = [
                "dataset.name = '%s'" % (FLAGS.dataset),
                "preprocess.preprocess_fn = @split_train_and_validation_per_model",
                "split_train_and_validation_per_model.random_seed = %d" %
                (FLAGS.rng)
            ]

            preprocess.preprocess_with_gin(FLAGS.dataset,
                                           FLAGS.model,
                                           overwrite=FLAGS.overwrite,
                                           gin_config_files=None,
                                           gin_bindings=preprocess_gin_bindings)
            print("\n\n*- Preprocessing DONE \n\n")
            done = True
        except Exception:
            # Preprocessing failed; wait and retry. (Exception rather than a
            # bare except so KeyboardInterrupt can still stop the loop.)
            time.sleep(30)

    gin_files = {
        "vae": "3d_shape_vae.gin",
        "bvae": "3d_shape_bvae.gin",
        "b8vae": "3d_shape_b8vae.gin",
        "fvae": "3d_shape_fvae.gin",
        "btcvae": "3d_shape_btcvae.gin",
        "annvae": "3d_shape_annvae.gin",
        "randomvae": "3d_shape_randomvae.gin",
    }
    gin_file = gin_files[FLAGS.model]

    print("\n\n*- Training '%s' \n\n" % (FLAGS.model))
    vae_gin_bindings = [
        "model.random_seed = %d" % (FLAGS.rng),
        "dataset.name = '%s'" %
        (FLAGS.dataset + '_' + FLAGS.model + '_' + str(FLAGS.rng))
    ]
    vae_path = os.path.join(base_path,
                            FLAGS.model + FLAGS.dataset + '_' + str(FLAGS.rng))
    train_vae_path = os.path.join(vae_path, 'model')
    unsupervised_train_partial.train_with_gin(train_vae_path, FLAGS.overwrite,
                                              [gin_file], vae_gin_bindings)
    visualize_model.visualize(train_vae_path, vae_path + "/vis",
                              FLAGS.overwrite)
    preprocess.destroy_train_and_validation_splits(FLAGS.dataset + '_' +
                                                   FLAGS.model + '_' +
                                                   str(FLAGS.rng))
    print("\n\n*- Training DONE \n\n")

    print("\n\n*- Postprocessing '%s' \n\n" % (FLAGS.model))
    postprocess_gin_bindings = [
        "postprocess.postprocess_fn = @mean_representation",
        "dataset.name='dummy_data'",
        "postprocess.random_seed = %d" % (FLAGS.rng)
    ]

    representation_path = os.path.join(vae_path, "representation")
    model_path = os.path.join(vae_path, "model")
    postprocess.postprocess_with_gin(model_path,
                                     representation_path,
                                     FLAGS.overwrite,
                                     gin_config_files=None,
                                     gin_bindings=postprocess_gin_bindings)
    print("\n\n*- Postprocessing DONE \n\n")

    # --- Evaluate disentanglement metrics
    print("\n\n*- Evaluating MIG.")
    gin_bindings = [
        "evaluation.evaluation_fn = @mig", "dataset.name='3dshapes'",
        "evaluation.random_seed = 0", "mig.num_train = 10000",
        "discretizer.discretizer_fn = @histogram_discretizer",
        "discretizer.num_bins = 20"
    ]
    result_path = os.path.join(vae_path, "metrics", "mig")
    evaluate.evaluate_with_gin(representation_path,
                               result_path,
                               FLAGS.overwrite,
                               gin_bindings=gin_bindings)

    print("\n\n*- Evaluating BetaVEA.")
    gin_bindings = [
        "evaluation.evaluation_fn = @beta_vae_sklearn",
        "dataset.name='3dshapes'", "evaluation.random_seed = 0",
        "beta_vae_sklearn.batch_size = 16",
        "beta_vae_sklearn.num_train = 10000",
        "beta_vae_sklearn.num_eval = 5000",
        "discretizer.discretizer_fn = @histogram_discretizer",
        "discretizer.num_bins = 20"
    ]
    result_path = os.path.join(vae_path, "metrics", "bvae")
    evaluate.evaluate_with_gin(representation_path,
                               result_path,
                               FLAGS.overwrite,
                               gin_bindings=gin_bindings)

    print("\n\n*- Evaluating FactorVAE.")
    gin_bindings = [
        "evaluation.evaluation_fn = @factor_vae_score",
        "dataset.name='3dshapes'", "evaluation.random_seed = 0",
        "factor_vae_score.batch_size = 16",
        "factor_vae_score.num_train = 10000",
        "factor_vae_score.num_eval = 5000",
        "factor_vae_score.num_variance_estimate = 10000",
        "discretizer.discretizer_fn = @histogram_discretizer",
        "discretizer.num_bins = 20"
    ]

    result_path = os.path.join(vae_path, "metrics", "fvae")
    evaluate.evaluate_with_gin(representation_path,
                               result_path,
                               FLAGS.overwrite,
                               gin_bindings=gin_bindings)

    print("\n\n*- Evaluating DCI.")
    gin_bindings = [
        "evaluation.evaluation_fn = @dci", "dataset.name='3dshapes'",
        "evaluation.random_seed = 0", "dci.batch_size = 16",
        "dci.num_train = 10000", "dci.num_test = 5000",
        "discretizer.discretizer_fn = @histogram_discretizer",
        "discretizer.num_bins = 20"
    ]

    result_path = os.path.join(vae_path, "metrics", "dci")
    evaluate.evaluate_with_gin(representation_path,
                               result_path,
                               FLAGS.overwrite,
                               gin_bindings=gin_bindings)
    print("\n\n*- Evaluation COMPLETED \n\n")

    # --- Downstream tasks
    print("\n\n*- Training downstream factor regression '%s' \n\n" %
          (FLAGS.model))
    downstream_regression_train_gin_bindings = [
        "evaluation.evaluation_fn = @downstream_regression_on_representations",
        "dataset.name = '3dshapes_task'", "evaluation.random_seed = 0",
        "downstream_regression_on_representations.holdout_dataset_name = '3dshapes_holdout'",
        "downstream_regression_on_representations.num_train = [127500]",
        "downstream_regression_on_representations.num_test = 22500",
        "downstream_regression_on_representations.num_holdout = 80000",
        "predictor.predictor_fn = @mlp_regressor",
        "mlp_regressor.hidden_layer_sizes = [16, 8]",
        "mlp_regressor.activation = 'logistic'", "mlp_regressor.max_iter = 50",
        "mlp_regressor.random_state = 0"
    ]

    result_path = os.path.join(vae_path, "metrics", "factor_regression")
    evaluate.evaluate_with_gin(
        representation_path,
        result_path,
        FLAGS.overwrite,
        gin_config_files=None,
        gin_bindings=downstream_regression_train_gin_bindings)
    print("\n\n*- Training downstream factor regression DONE \n\n")

    print("\n\n*- Training downstream reconstruction '%s' \n\n" %
          (FLAGS.model))
    downstream_reconstruction_train_gin_bindings = [
        "supervised_model.model = @downstream_decoder()",
        "supervised_model.batch_size = 64",
        "supervised_model.training_steps = 30000",
        "supervised_model.eval_steps = 1000",
        "supervised_model.random_seed = 0",
        "supervised_model.holdout_dataset_name = '3dshapes_holdout'",
        "dataset.name='3dshapes_task'",
        "decoder_optimizer.optimizer_fn = @AdamOptimizer",
        "AdamOptimizer.beta1 = 0.9", "AdamOptimizer.beta2 = 0.999",
        "AdamOptimizer.epsilon = 1e-08",
        "AdamOptimizer.learning_rate = 0.0001", "AdamOptimizer.name = 'Adam'",
        "AdamOptimizer.use_locking = False",
        "decoder.decoder_fn = @deconv_decoder",
        "reconstruction_loss.loss_fn = @l2_loss"
    ]

    result_path = os.path.join(vae_path, "metrics", "reconstruction")
    supervised_train_partial.train_with_gin(
        result_path,
        representation_path,
        FLAGS.overwrite,
        gin_bindings=downstream_reconstruction_train_gin_bindings)
    visualize_model.visualize_supervised(result_path, representation_path,
                                         result_path + "/vis", FLAGS.overwrite)

    print("\n\n*- Training downstream reconstruction DONE \n\n")
    print("\n\n*- Training & evaluation COMPLETED \n\n")
Example #5
# We use the Mutual Information Gap (with a low number of samples to make it
# faster). To learn more, have a look at the different scores in
# disentanglement_lib.evaluation.metrics and the predefined .gin
# configuration files in
# disentanglement_lib/config/unsupervised_study_v1/metric_configs/(...).
gin_bindings = [
    "evaluation.evaluation_fn = @mig", "dataset.name='auto'",
    "evaluation.random_seed = 0", "mig.num_train=1000",
    "discretizer.discretizer_fn = @histogram_discretizer",
    "discretizer.num_bins = 20"
]
for path in [path_vae, path_custom_vae]:
    result_path = os.path.join(path, "metrics", "mig")
    representation_path = os.path.join(path, "representation")
    evaluate.evaluate_with_gin(representation_path,
                               result_path,
                               overwrite,
                               gin_bindings=gin_bindings)


# 5. Compute a custom disentanglement metric for both models.
# ------------------------------------------------------------------------------
# The following function implements a dummy metric. Note that all metrics get
# ground_truth_data, representation_function, random_state arguments by the
# evaluation protocol, while all other arguments have to be configured via gin.
@gin.configurable(
    "custom_metric",
    blacklist=["ground_truth_data", "representation_function", "random_state"])
def compute_custom_metric(ground_truth_data,
                          representation_function,
                          random_state,
                          num_train=gin.REQUIRED,
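The snippet above is truncated mid-signature. The following is a minimal, hypothetical sketch of what a complete dummy metric could look like under the protocol described above; the batch_size argument, the body, and the returned score dictionary are illustrative assumptions rather than the original code.

import gin
import numpy as np


@gin.configurable(
    "custom_metric_sketch",
    blacklist=["ground_truth_data", "representation_function", "random_state"])
def compute_custom_metric_sketch(ground_truth_data,
                                 representation_function,
                                 random_state,
                                 num_train=gin.REQUIRED,
                                 batch_size=16):
    """Dummy metric: mean absolute latent activation on sampled observations."""
    scores = []
    for _ in range(max(1, num_train // batch_size)):
        # Sample observations from the ground-truth data set and embed them.
        _, observations = ground_truth_data.sample(batch_size, random_state)
        representations = representation_function(observations)
        scores.append(np.mean(np.abs(representations)))
    # Metrics return a dictionary of scores; the evaluation protocol writes it
    # to the results directory.
    return {"custom_metric_sketch.score": float(np.mean(scores))}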
Example #6
def main(unused_argv):
    base_path = "3dshapes_models"

    print("\n\n*- Evaluating '%s' \n\n" % (FLAGS.model))
    vae_path = os.path.join(base_path,
                            FLAGS.model + FLAGS.dataset + '_' + str(FLAGS.rng))
    representation_path = os.path.join(vae_path, "representation")
    print(vae_path, representation_path)

    print("\n\n*- Evaluating MIG.")
    gin_bindings = [
        "evaluation.evaluation_fn = @mig", "dataset.name='3dshapes'",
        "evaluation.random_seed = 0", "mig.num_train = 100000",
        "discretizer.discretizer_fn = @histogram_discretizer",
        "discretizer.num_bins = 20"
    ]
    result_path = os.path.join(vae_path, "metrics", "mig_10000")
    evaluate.evaluate_with_gin(representation_path,
                               result_path,
                               FLAGS.overwrite,
                               gin_bindings=gin_bindings)

    print("\n\n*- Evaluating BetaVEA.")
    gin_bindings = [
        "evaluation.evaluation_fn = @beta_vae_sklearn",
        "dataset.name='3dshapes'", "evaluation.random_seed = 0",
        "beta_vae_sklearn.batch_size = 16",
        "beta_vae_sklearn.num_train = 100000",
        "beta_vae_sklearn.num_eval = 5000",
        "discretizer.discretizer_fn = @histogram_discretizer",
        "discretizer.num_bins = 20"
    ]
    result_path = os.path.join(vae_path, "metrics", "bvae_10000")
    evaluate.evaluate_with_gin(representation_path,
                               result_path,
                               FLAGS.overwrite,
                               gin_bindings=gin_bindings)

    print("\n\n*- Evaluating FactorVAE.")
    gin_bindings = [
        "evaluation.evaluation_fn = @factor_vae_score",
        "dataset.name='3dshapes'", "evaluation.random_seed = 0",
        "factor_vae_score.batch_size = 16",
        "factor_vae_score.num_train = 100000",
        "factor_vae_score.num_eval = 5000",
        "factor_vae_score.num_variance_estimate = 100000",
        "discretizer.discretizer_fn = @histogram_discretizer",
        "discretizer.num_bins = 20"
    ]

    result_path = os.path.join(vae_path, "metrics", "fvae_10000")
    evaluate.evaluate_with_gin(representation_path,
                               result_path,
                               FLAGS.overwrite,
                               gin_bindings=gin_bindings)

    print("\n\n*- Evaluating DCI.")
    gin_bindings = [
        "evaluation.evaluation_fn = @dci", "dataset.name='3dshapes'",
        "evaluation.random_seed = 0", "dci.batch_size = 16",
        "dci.num_train = 100000", "dci.num_test = 5000",
        "discretizer.discretizer_fn = @histogram_discretizer",
        "discretizer.num_bins = 20"
    ]

    result_path = os.path.join(vae_path, "metrics", "dci_10000")
    evaluate.evaluate_with_gin(representation_path,
                               result_path,
                               FLAGS.overwrite,
                               gin_bindings=gin_bindings)

    print("\n\n*- Evaluation COMPLETED \n\n")
Example #7
def main():
    parser = argparse.ArgumentParser(description='Project description.')
    parser.add_argument('--result_dir',
                        help='Results directory.',
                        type=str,
                        default='/mnt/hdd/repo_results/Ramiel/sweep')
    parser.add_argument('--study',
                        help='Name of the study.',
                        type=str,
                        default='unsupervised_study_v1')
    parser.add_argument('--model_gin',
                        help='Name of the gin config.',
                        type=str,
                        default='test_model.gin')
    parser.add_argument('--model_name',
                        help='Name of the model.',
                        type=str,
                        default='GroupVAE')
    parser.add_argument('--vae_beta',
                        help='Beta-VAE beta.',
                        type=str,
                        default='1')
    parser.add_argument('--hyps',
                        help='Hyperparameters of rec_mat_oth_spl_ncut_seed.',
                        type=str,
                        default='1_1_1_1_1_0')
    parser.add_argument('--overwrite',
                        help='Whether to overwrite output directory.',
                        type=_str_to_bool,
                        default=False)
    parser.add_argument('--dataset',
                        help='Dataset.',
                        type=str,
                        default='dsprites_full')
    parser.add_argument('--recons_type',
                        help='Reconstruction loss type.',
                        type=str,
                        default='bernoulli_loss')
    args = parser.parse_args()

    # 1. Settings
    study = reproduce.STUDIES[args.study]
    args.hyps = args.hyps.split('_')
    print()
    study.print_postprocess_config()
    print()
    study.print_eval_config()

    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpus:
        try:
            # Currently, memory growth needs to be the same across GPUs
            for gpu in gpus:
                tf.config.experimental.set_memory_growth(gpu, True)
            logical_gpus = tf.config.experimental.list_logical_devices('GPU')
            print(len(gpus), "Physical GPUs,", len(logical_gpus),
                  "Logical GPUs")
        except RuntimeError as e:
            # Memory growth must be set before GPUs have been initialized
            print(e)

    # Call training module to train the custom model.
    if args.model_name == "GroupVAE":
        dir_name = "GroupVAE-" + "-".join(args.hyps)
    elif args.model_name == "vae":
        dir_name = "LieVAE-" + args.vae_beta + "-" + args.hyps[5]
    output_directory = os.path.join(args.result_dir, dir_name)
    model_dir = os.path.join(output_directory, "model")
    gin_bindings = [
        "model.model = @" + args.model_name + "()",
        "vae.beta = " + args.vae_beta, "GroupVAE.hy_rec = " + args.hyps[0],
        "GroupVAE.hy_mat = " + args.hyps[1],
        "GroupVAE.hy_oth = " + args.hyps[2],
        "GroupVAE.hy_spl = " + args.hyps[3],
        "GroupVAE.hy_ncut = " + args.hyps[4],
        "model.random_seed = " + args.hyps[5],
        "dataset.name = '" + args.dataset + "'",
        "reconstruction_loss.loss_fn = @" + args.recons_type
    ]
    train.train_with_gin(model_dir, args.overwrite, [args.model_gin],
                         gin_bindings)

    # We fix the random seed for the postprocessing and evaluation steps (each
    # config gets a different but reproducible seed derived from a master seed of
    # 0). The model seed was set via the gin bindings and configs of the study.
    random_state = np.random.RandomState(0)

    # We extract the different representations and save them to disk.
    postprocess_config_files = sorted(study.get_postprocess_config_files())
    for config in postprocess_config_files:
        post_name = os.path.basename(config).replace(".gin", "")
        print("Extracting representation " + post_name + "...")
        post_dir = os.path.join(output_directory, "postprocessed", post_name)
        postprocess_bindings = [
            "postprocess.random_seed = {}".format(random_state.randint(2**32)),
            "postprocess.name = '{}'".format(post_name)
        ]
        postprocess.postprocess_with_gin(model_dir, post_dir, args.overwrite,
                                         [config], postprocess_bindings)

    # Iterate through the disentanglement metrics.
    eval_configs = sorted(study.get_eval_config_files())
    blacklist = ['downstream_task_logistic_regression.gin']
    # blacklist = [
    # 'downstream_task_logistic_regression.gin', 'beta_vae_sklearn.gin',
    # 'dci.gin', 'downstream_task_boosted_trees.gin', 'mig.gin',
    # 'modularity_explicitness.gin', 'sap_score.gin', 'unsupervised.gin'
    # ]
    for config in postprocess_config_files:
        post_name = os.path.basename(config).replace(".gin", "")
        post_dir = os.path.join(output_directory, "postprocessed", post_name)
        # Now, we compute all the specified scores.
        for gin_eval_config in eval_configs:
            if os.path.basename(gin_eval_config) not in blacklist:
                metric_name = os.path.basename(gin_eval_config).replace(
                    ".gin", "")
                print("Computing metric " + metric_name + " on " + post_name +
                      "...")
                metric_dir = os.path.join(output_directory, "metrics",
                                          post_name, metric_name)
                eval_bindings = [
                    "evaluation.random_seed = {}".format(
                        random_state.randint(2**32)),
                    "evaluation.name = '{}'".format(metric_name)
                ]
                evaluate.evaluate_with_gin(post_dir, metric_dir,
                                           args.overwrite, [gin_eval_config],
                                           eval_bindings)

    # We visualize reconstructions, samples and latent space traversals.
    visualize_dir = os.path.join(output_directory, "visualizations")
    visualize_model.visualize(model_dir, visualize_dir, args.overwrite)
Example #8
def main():
  parser = argparse.ArgumentParser(description='Project description.')
  parser.add_argument('--study', help='Name of the study.', type=str, default='unsupervised_study_v1')
  parser.add_argument('--output_directory', help='Output directory of experiments.', type=str, default=None)
  parser.add_argument('--model_dir', help='Directory to take trained model from.', type=str, default=None)
  parser.add_argument('--model_num', help='Integer with model number to train.', type=int, default=None)
  parser.add_argument('--only_print', help='Whether to only print the hyperparameter settings.', type=_str_to_bool, default=False)
  parser.add_argument('--overwrite', help='Whether to overwrite output directory.', type=_str_to_bool, default=False)
  args = parser.parse_args()
  # logging.set_verbosity('error')
  # logging.set_stderrthreshold('error')
  gpus = tf.config.experimental.list_physical_devices('GPU')
  if gpus:
    try:
      # Currently, memory growth needs to be the same across GPUs
      for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
      logical_gpus = tf.config.experimental.list_logical_devices('GPU')
      print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
      # Memory growth must be set before GPUs have been initialized
      print(e)

  # Obtain the study to reproduce.
  study = reproduce.STUDIES[args.study]

  # Print the hyperparameter settings.
  if args.model_dir is None:
    study.print_model_config(args.model_num)
  else:
    print("Model directory (skipped training):")
    print("--")
    print(args.model_dir)
  print()
  study.print_postprocess_config()
  print()
  study.print_eval_config()
  if args.only_print:
    return

  # Set correct output directory.
  if args.output_directory is None:
    if args.model_dir is None:
      output_directory = os.path.join("output", "{study}", "{model_num}")
    else:
      output_directory = "output"
  else:
    output_directory = args.output_directory

  # Insert model number and study name into path if necessary.
  output_directory = output_directory.format(model_num=str(args.model_num),
                                             study=str(args.study))

  # Model training (if model directory is not provided).
  if args.model_dir is None:
    model_bindings, model_config_file = study.get_model_config(args.model_num)
    print("Training model...")
    model_dir = os.path.join(output_directory, "model")
    model_bindings = [
        "model.name = '{}'".format(os.path.basename(model_config_file)).replace(
            ".gin", ""),
        "model.model_num = {}".format(args.model_num),
    ] + model_bindings
    train.train_with_gin(model_dir, args.overwrite, [model_config_file],
                         model_bindings)
  else:
    print("Skipped training...")
    model_dir = args.model_dir

  # We visualize reconstructions, samples and latent space traversals.
  visualize_dir = os.path.join(output_directory, "visualizations")
  visualize_model.visualize(model_dir, visualize_dir, args.overwrite)

  # We fix the random seed for the postprocessing and evaluation steps (each
  # config gets a different but reproducible seed derived from a master seed of
  # 0). The model seed was set via the gin bindings and configs of the study.
  random_state = np.random.RandomState(0)

  # We extract the different representations and save them to disk.
  postprocess_config_files = sorted(study.get_postprocess_config_files())
  for config in postprocess_config_files:
    post_name = os.path.basename(config).replace(".gin", "")
    print("Extracting representation %s..." % post_name)
    post_dir = os.path.join(output_directory, "postprocessed", post_name)
    postprocess_bindings = [
        "postprocess.random_seed = {}".format(random_state.randint(2**32)),
        "postprocess.name = '{}'".format(post_name)
    ]
    postprocess.postprocess_with_gin(model_dir, post_dir, args.overwrite,
                                     [config], postprocess_bindings)

  # Iterate through the disentanglement metrics.
  eval_configs = sorted(study.get_eval_config_files())
  blacklist = ['downstream_task_logistic_regression.gin']
  for config in postprocess_config_files:
    post_name = os.path.basename(config).replace(".gin", "")
    post_dir = os.path.join(output_directory, "postprocessed",
                            post_name)
    # Now, we compute all the specified scores.
    for gin_eval_config in eval_configs:
      if os.path.basename(gin_eval_config) not in blacklist:
        metric_name = os.path.basename(gin_eval_config).replace(".gin", "")
        print("Computing metric '%s' on '%s'..." % (metric_name, post_name))
        metric_dir = os.path.join(output_directory, "metrics", post_name,
                                  metric_name)
        eval_bindings = [
            "evaluation.random_seed = {}".format(random_state.randint(2**32)),
            "evaluation.name = '{}'".format(metric_name)
        ]
        evaluate.evaluate_with_gin(post_dir, metric_dir, args.overwrite,
                                   [gin_eval_config], eval_bindings)
Example #9
        model_path = os.path.join(path, "model")
        postprocess_gin = ["postprocess.gin"]  # This contains the settings.
        # postprocess.postprocess_with_gin defines the standard extraction protocol.
        postprocess.postprocess_with_gin(model_path, representation_path,
                                         overwrite, postprocess_gin)

    # 5. Compute metrics for all models.
    metrics = glob.glob('../disentanglement_lib/disentanglement_lib/config/unsupervised_study_v1/metric_configs/*.gin')
    blacklist = ['downstream_task_logistic_regression.gin']
    for path in [path_custom_vae]:
        for metric in metrics:
            if os.path.basename(metric) not in blacklist:
                result_path = os.path.join(path, "metrics", os.path.basename(metric).replace('.gin', ''))
                representation_path = os.path.join(path, "representation")
                evaluate.evaluate_with_gin(representation_path,
                                           result_path,
                                           overwrite,
                                           gin_config_files=[metric])

# 6. Aggregate the results.
# ------------------------------------------------------------------------------
# In the previous steps, we saved the scores to several output directories. We
# can aggregate all the results using the following command.
pattern = os.path.join(base_path,
                       "*/metrics/*/results/aggregate/evaluation.json")
results_path = os.path.join(base_path, "results.json")
aggregate_results.aggregate_results_to_json(pattern, results_path)

# 7. Print out the final Pandas data frame with the results.
# ------------------------------------------------------------------------------
# The aggregated results contains for each computed metric all the configuration
# options and all the results captured in the steps along the pipeline. This
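The original snippet is cut off at this point. As a hedged sketch, the aggregated results could be loaded into a Pandas data frame roughly as follows, assuming the load_aggregated_json_results helper from disentanglement_lib.utils.aggregate_results is available:

# Assumption: load_aggregated_json_results loads the aggregated JSON produced
# above and returns a Pandas data frame with one row per computed metric.
model_results = aggregate_results.load_aggregated_json_results(results_path)
print(model_results)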
Example #10
        result_path,
        representation_path,
        overwrite,
        gin_bindings=downstream_reconstruction_train_gin_bindings
    )  #["3d_shape_classifier.gin"])#gin_bindings=gin_bindings)

# Deliberate crash: dividing by zero raises ZeroDivisionError so the script
# stops here and the experimental code below never runs.
pa = 1 / 0

downstream_train_gin_bindings = [
    "evaluation.evaluation_fn = @downstream_regression_on_representations",
    "dataset.name = '3dshapes_task'", "evaluation.random_seed = 111",
    "downstream_regression_on_representations.num_train = [127500]",
    "downstream_regression_on_representations.num_test = 22500",
    "predictor.predictor_fn = @mlp_regressor",
    "mlp_regressor.hidden_layer_sizes = [32, 16]",
    "mlp_regressor.activation = 'logistic'", "mlp_regressor.max_iter = 100",
    "mlp_regressor.random_state = 0"
]

for path in [path_vae]:
    result_path = os.path.join(path, "metrics", "TEST_factor_regression")
    representation_path = os.path.join(path, "representation")
    evaluate.evaluate_with_gin(
        representation_path,
        result_path,
        overwrite,
        gin_bindings=downstream_train_gin_bindings
    )  #["3d_shape_classifier.gin"])#gin_bindings=gin_bindings)

pa = 1 / 0