예제 #1
0
 def setUp(self):
   """Trains the unsupervised test model into a temp dir without cleanup.

   The trained model's path is stored in `self.model_dir`.
   """
   super(PostprocessTest, self).setUp()
   model_tempdir = self.create_tempdir(
       "model", cleanup=absltest.TempFileCleanup.OFF)
   self.model_dir = model_tempdir.full_path
   gin_configs = [
       resources.get_file("config/tests/methods/unsupervised/train_test.gin"),
   ]
   gin_bindings = []
   # Second positional argument is the overwrite flag.
   train.train_with_gin(self.model_dir, True, gin_configs, gin_bindings)
예제 #2
0
def main(unused_argv):
    """Trains a model from the gin configs and bindings given on the command line.

    The config files handed to training are FLAGS.gin_config followed, when
    FLAGS.gin_file_name is set, by that file resolved against
    FLAGS.gin_model_dir.
    """
    # Build a fresh list so the flag value is never mutated, and tolerate the
    # flag being unset (None) instead of failing on `None + list`.
    config_files = list(FLAGS.gin_config or [])
    if FLAGS.gin_file_name is not None:
        config_files.append(
            os.path.join(FLAGS.gin_model_dir, FLAGS.gin_file_name))
    # Report every config file actually passed to training, including the
    # optional model config file (previously only --gin_config was shown).
    print('config_file: {}'.format(config_files))
    train.train_with_gin(FLAGS.model_dir, FLAGS.overwrite, config_files,
                         FLAGS.gin_bindings)
예제 #3
0
 def setUp(self):
     """Trains a test model, then postprocesses it with the mean config.

     Both the model directory (`self.model_dir`) and the postprocessed output
     directory (`self.output_dir`) are temp dirs created with cleanup off.
     """
     super(EvaluateTest, self).setUp()
     model_tempdir = self.create_tempdir(
         "model", cleanup=absltest.TempFileCleanup.OFF)
     self.model_dir = model_tempdir.full_path
     train_gin = resources.get_file(
         "config/tests/methods/unsupervised/train_test.gin")
     train.train_with_gin(self.model_dir, True, [train_gin])

     output_tempdir = self.create_tempdir(
         "output", cleanup=absltest.TempFileCleanup.OFF)
     self.output_dir = output_tempdir.full_path
     mean_gin = resources.get_file(
         "config/tests/postprocessing/postprocess_test_configs/mean.gin")
     postprocess.postprocess_with_gin(
         self.model_dir, self.output_dir, True, [mean_gin])
예제 #4
0
    def setUp(self):
        """Trains two identically-configured test models in separate dirs.

        Also creates an output directory (`self.output_dir`). Every directory
        is an absltest temp dir with cleanup turned off.
        """
        super(EvaluateTest, self).setUp()
        no_cleanup = absltest.TempFileCleanup.OFF
        self.model1_dir = self.create_tempdir(
            "model1/model", cleanup=no_cleanup).full_path
        self.model2_dir = self.create_tempdir(
            "model2/model", cleanup=no_cleanup).full_path
        train_gin = resources.get_file(
            "config/tests/methods/unsupervised/train_test.gin")
        # Reset gin so leftover (possibly locked) config from an earlier test
        # does not interfere with training.
        gin.clear_config()
        for model_dir in (self.model1_dir, self.model2_dir):
            train.train_with_gin(model_dir, True, [train_gin])

        self.output_dir = self.create_tempdir(
            "output", cleanup=no_cleanup).full_path
예제 #5
0
 def test_visualize_sigmoid(self, activation):
     """Trains with the given reconstruction activation, then visualizes.

     The visualization is kept small: one animation, four frames and 100
     points.
     """
     model_subdir = "model_{}".format(activation)
     self.model_dir = self.create_tempdir(
         model_subdir, cleanup=absltest.TempFileCleanup.OFF).full_path
     train_gin = resources.get_file(
         "config/tests/methods/unsupervised/train_test.gin")
     bindings = ["reconstruction_loss.activation = '{}'".format(activation)]
     train.train_with_gin(self.model_dir, True, [train_gin], bindings)
     vis_subdir = "visualization_{}".format(activation)
     vis_dir = self.create_tempdir(vis_subdir).full_path
     visualize_model.visualize(
         self.model_dir,
         vis_dir,
         True,
         num_animations=1,
         num_frames=4,
         num_points_irs=100)
예제 #6
0
def _config_name(config_path):
  """Returns the base name of a gin config path with ".gin" removed.

  Every occurrence of ".gin" in the base name is stripped, e.g.
  "path/to/mean.gin" -> "mean".
  """
  return os.path.basename(config_path).replace(".gin", "")


def main(unused_argv):
  """Reproduces one model of a study: train, postprocess, then evaluate.

  The study and model are chosen by FLAGS.study and FLAGS.model_num. If
  FLAGS.model_dir is set, training is skipped and that directory is used as
  the model. If FLAGS.only_print is set, only the configurations are printed.

  Outputs go under FLAGS.output_directory, which may contain "{study}" and
  "{model_num}" placeholders. It defaults to "output/{study}/{model_num}"
  when training, or "output" when an existing model directory is given.
  """
  # Obtain the study to reproduce.
  study = reproduce.STUDIES[FLAGS.study]

  # Print the hyperparameter settings.
  if FLAGS.model_dir is None:
    study.print_model_config(FLAGS.model_num)
  else:
    print("Model directory (skipped training):")
    print("--")
    print(FLAGS.model_dir)
  print()
  study.print_postprocess_config()
  print()
  study.print_eval_config()
  if FLAGS.only_print:
    return

  # Set correct output directory.
  if FLAGS.output_directory is not None:
    output_directory = FLAGS.output_directory
  elif FLAGS.model_dir is None:
    output_directory = os.path.join("output", "{study}", "{model_num}")
  else:
    output_directory = "output"

  # Insert model number and study name into path if necessary.
  output_directory = output_directory.format(model_num=str(FLAGS.model_num),
                                             study=str(FLAGS.study))

  # Model training (if model directory is not provided).
  if FLAGS.model_dir is None:
    model_bindings, model_config_file = study.get_model_config(FLAGS.model_num)
    logging.info("Training model...")
    model_dir = os.path.join(output_directory, "model")
    model_bindings = [
        "model.name = '{}'".format(_config_name(model_config_file)),
        "model.model_num = {}".format(FLAGS.model_num),
    ] + model_bindings
    train.train_with_gin(model_dir, FLAGS.overwrite, [model_config_file],
                         model_bindings)
  else:
    logging.info("Skipped training...")
    model_dir = FLAGS.model_dir

  # NOTE: visualization of reconstructions, samples and latent space
  # traversals is currently disabled. To re-enable it:
  #   visualize_dir = os.path.join(output_directory, "visualizations")
  #   visualize_model.visualize(model_dir, visualize_dir, FLAGS.overwrite)

  # We fix the random seed for the postprocessing and evaluation steps (each
  # config gets a different but reproducible seed derived from a master seed of
  # 0). The model seed was set via the gin bindings and configs of the study.
  # Seeds are drawn sequentially from this one RandomState, so the loop order
  # below determines which seed each step gets; keep it stable.
  random_state = np.random.RandomState(0)

  # We extract the different representations and save them to disk.
  postprocess_config_files = sorted(study.get_postprocess_config_files())
  for config in postprocess_config_files:
    post_name = _config_name(config)
    logging.info("Extracting representation %s...", post_name)
    post_dir = os.path.join(output_directory, "postprocessed", post_name)
    postprocess_bindings = [
        "postprocess.random_seed = {}".format(random_state.randint(2**32)),
        "postprocess.name = '{}'".format(post_name)
    ]
    postprocess.postprocess_with_gin(model_dir, post_dir, FLAGS.overwrite,
                                     [config], postprocess_bindings)

  # Iterate through the disentanglement metrics for every representation.
  eval_configs = sorted(study.get_eval_config_files())
  logging.info("Evaluation configs: %s", eval_configs)
  for config in postprocess_config_files:
    post_name = _config_name(config)
    post_dir = os.path.join(output_directory, "postprocessed", post_name)
    # Now, we compute all the specified scores.
    for gin_eval_config in eval_configs:
      metric_name = _config_name(gin_eval_config)
      logging.info("Computing metric '%s' on '%s'...", metric_name, post_name)
      metric_dir = os.path.join(output_directory, "metrics", post_name,
                                metric_name)
      eval_bindings = [
          "evaluation.random_seed = {}".format(random_state.randint(2**32)),
          "evaluation.name = '{}'".format(metric_name)
      ]
      evaluate.evaluate_with_gin(post_dir, metric_dir, FLAGS.overwrite,
                                 [gin_eval_config], eval_bindings)
예제 #7
0
 def test_train_model(self, gin_configs, gin_bindings):
   """Runs training once for the given gin configs and bindings."""
   # Start from an empty gin state: if an earlier test failed, the gin config
   # may still be locked, which would make this test fail too.
   gin.clear_config()
   scratch_dir = self.create_tempdir().full_path
   train.train_with_gin(scratch_dir, True, gin_configs, gin_bindings)