def setUp(self):
  """Trains a small unsupervised test model for the postprocessing tests.

  Sets `self.model_dir` to a temp directory named "model" (created with
  cleanup disabled) containing the model trained from the
  `config/tests/methods/unsupervised/train_test.gin` config.
  """
  super(PostprocessTest, self).setUp()
  # Imported locally so this block does not rely on the module's imports.
  import gin  # pylint: disable=g-import-not-at-top
  # A previous test that failed inside train_with_gin can leave the gin config
  # parsed and locked; clear it so the parse below does not fail as well.
  gin.clear_config()
  self.model_dir = self.create_tempdir(
      "model", cleanup=absltest.TempFileCleanup.OFF).full_path
  train.train_with_gin(self.model_dir, True, [
      resources.get_file("config/tests/methods/unsupervised/train_test.gin")
  ], [])
def main(unused_argv):
  """Trains a model from the gin configs and bindings given as flags.

  The config files passed to training are FLAGS.gin_config (an unset flag is
  treated as an empty list), followed by
  os.path.join(FLAGS.gin_model_dir, FLAGS.gin_file_name) when
  FLAGS.gin_file_name is set. The model is written to FLAGS.model_dir.
  """
  # Copy the flag value so appending below never mutates the flag itself, and
  # tolerate an unset multi-string flag (None) instead of raising TypeError.
  gin_configs = list(FLAGS.gin_config or [])
  if FLAGS.gin_file_name is not None:
    gin_configs.append(os.path.join(FLAGS.gin_model_dir, FLAGS.gin_file_name))
  # Print every config file that will actually be parsed, including the
  # optional model config file.
  print('config_file: {}'.format(gin_configs))
  train.train_with_gin(FLAGS.model_dir, FLAGS.overwrite, gin_configs,
                       FLAGS.gin_bindings)
def setUp(self):
  """Trains a test model and postprocesses it for the evaluation tests.

  Sets `self.model_dir` (model trained from the unsupervised train_test.gin
  config) and `self.output_dir` (representation produced with the mean.gin
  postprocessing config). Both temp directories are created with cleanup
  disabled.
  """
  super(EvaluateTest, self).setUp()
  # Imported locally so this block does not rely on the module's imports.
  import gin  # pylint: disable=g-import-not-at-top
  # A previous test that failed mid-training can leave the gin config locked;
  # clear it so the config parsing below does not fail as well.
  gin.clear_config()
  self.model_dir = self.create_tempdir(
      "model", cleanup=absltest.TempFileCleanup.OFF).full_path
  model_config = resources.get_file(
      "config/tests/methods/unsupervised/train_test.gin")
  train.train_with_gin(self.model_dir, True, [model_config])
  self.output_dir = self.create_tempdir(
      "output", cleanup=absltest.TempFileCleanup.OFF).full_path
  postprocess_config = resources.get_file(
      "config/tests/postprocessing/postprocess_test_configs/mean.gin")
  postprocess.postprocess_with_gin(self.model_dir, self.output_dir, True,
                                   [postprocess_config])
def setUp(self):
  """Trains two models from the same test config and makes an output dir.

  Sets `self.model1_dir`, `self.model2_dir` and `self.output_dir`; every temp
  directory is created with cleanup disabled.
  """
  super(EvaluateTest, self).setUp()
  no_cleanup = absltest.TempFileCleanup.OFF
  self.model1_dir, self.model2_dir = [
      self.create_tempdir(dir_name, cleanup=no_cleanup).full_path
      for dir_name in ("model1/model", "model2/model")
  ]
  train_config = resources.get_file(
      "config/tests/methods/unsupervised/train_test.gin")
  # Start from an empty gin state in case an earlier test left it locked.
  gin.clear_config()
  for target_dir in (self.model1_dir, self.model2_dir):
    train.train_with_gin(target_dir, True, [train_config])
  self.output_dir = self.create_tempdir(
      "output", cleanup=no_cleanup).full_path
def test_visualize_sigmoid(self, activation):
  """Trains a test model with the given activation and visualizes it.

  Args:
    activation: Value bound to `reconstruction_loss.activation` in gin; it is
      also used to name the model and visualization temp directories.
      Presumably supplied by a parameterized-test decorator outside this view.
  """
  # Imported locally so this block does not rely on the module's imports.
  import gin  # pylint: disable=g-import-not-at-top
  # A previous test that failed mid-training can leave the gin config locked;
  # clear it so training below does not fail as well.
  gin.clear_config()
  activation_binding = (
      "reconstruction_loss.activation = '{}'".format(activation))
  self.model_dir = self.create_tempdir(
      "model_{}".format(activation),
      cleanup=absltest.TempFileCleanup.OFF).full_path
  train.train_with_gin(self.model_dir, True, [
      resources.get_file("config/tests/methods/unsupervised/train_test.gin")
  ], [activation_binding])
  # Small settings (one animation, four frames, 100 points), presumably to keep
  # the test fast.
  visualize_model.visualize(
      self.model_dir,
      self.create_tempdir("visualization_{}".format(activation)).full_path,
      True,
      num_animations=1,
      num_frames=4,
      num_points_irs=100)
def _random_seed(random_state):
  """Draws the next per-config seed in [0, 2**32) from `random_state`."""
  # dtype=np.int64 keeps 2**32 in range where the default integer (C long) is
  # 32 bits (e.g. Windows with NumPy < 2), where the default dtype raises
  # ValueError. Where C long is already 64 bits the drawn values are unchanged.
  return random_state.randint(2**32, dtype=np.int64)


def _print_study_config(study):
  """Prints the model (or given model dir), postprocess and eval configs."""
  if FLAGS.model_dir is None:
    study.print_model_config(FLAGS.model_num)
  else:
    print("Model directory (skipped training):")
    print("--")
    print(FLAGS.model_dir)
    print()
  study.print_postprocess_config()
  print()
  study.print_eval_config()


def _resolve_output_directory():
  """Returns the output directory with {study}/{model_num} placeholders filled."""
  if FLAGS.output_directory is not None:
    template = FLAGS.output_directory
  elif FLAGS.model_dir is None:
    template = os.path.join("output", "{study}", "{model_num}")
  else:
    template = "output"
  # Insert model number and study name into path if necessary.
  return template.format(model_num=str(FLAGS.model_num), study=str(FLAGS.study))


def _train_model_if_needed(study, output_directory):
  """Trains the selected study model unless FLAGS.model_dir is given.

  Returns:
    The directory of the model to postprocess: FLAGS.model_dir if set,
    otherwise `<output_directory>/model` after training into it.
  """
  if FLAGS.model_dir is not None:
    logging.info("Skipped training...")
    return FLAGS.model_dir
  model_bindings, model_config_file = study.get_model_config(FLAGS.model_num)
  logging.info("Training model...")
  model_dir = os.path.join(output_directory, "model")
  model_name = os.path.basename(model_config_file).replace(".gin", "")
  model_bindings = [
      "model.name = '{}'".format(model_name),
      "model.model_num = {}".format(FLAGS.model_num),
  ] + model_bindings
  train.train_with_gin(model_dir, FLAGS.overwrite, [model_config_file],
                       model_bindings)
  return model_dir


def _postprocess_all(study, model_dir, output_directory, random_state):
  """Extracts every representation of `study` and saves it to disk.

  Returns:
    The representation names (config basenames without ".gin"), in the sorted
    config order in which they were processed.
  """
  post_names = []
  for config in sorted(study.get_postprocess_config_files()):
    post_name = os.path.basename(config).replace(".gin", "")
    logging.info("Extracting representation %s...", post_name)
    post_dir = os.path.join(output_directory, "postprocessed", post_name)
    postprocess_bindings = [
        "postprocess.random_seed = {}".format(_random_seed(random_state)),
        "postprocess.name = '{}'".format(post_name)
    ]
    postprocess.postprocess_with_gin(model_dir, post_dir, FLAGS.overwrite,
                                     [config], postprocess_bindings)
    post_names.append(post_name)
  return post_names


def _evaluate_all(study, post_names, output_directory, random_state):
  """Computes every evaluation metric of `study` on each representation."""
  eval_configs = sorted(study.get_eval_config_files())
  logging.info("Evaluation configs: %s", eval_configs)
  for post_name in post_names:
    post_dir = os.path.join(output_directory, "postprocessed", post_name)
    # Now, we compute all the specified scores.
    for gin_eval_config in eval_configs:
      metric_name = os.path.basename(gin_eval_config).replace(".gin", "")
      logging.info("Computing metric '%s' on '%s'...", metric_name, post_name)
      metric_dir = os.path.join(output_directory, "metrics", post_name,
                                metric_name)
      eval_bindings = [
          "evaluation.random_seed = {}".format(_random_seed(random_state)),
          "evaluation.name = '{}'".format(metric_name)
      ]
      evaluate.evaluate_with_gin(post_dir, metric_dir, FLAGS.overwrite,
                                 [gin_eval_config], eval_bindings)


def main(unused_argv):
  """Reproduces one model of a study: train, postprocess, then evaluate.

  Steps, all driven by command-line flags:
    1. Print the configs of FLAGS.study; stop if FLAGS.only_print is set.
    2. Train model FLAGS.model_num unless FLAGS.model_dir is provided.
    3. Extract every postprocessed representation listed by the study.
    4. Compute every evaluation metric of the study on each representation.
  """
  # Obtain the study to reproduce.
  study = reproduce.STUDIES[FLAGS.study]
  _print_study_config(study)
  if FLAGS.only_print:
    return

  output_directory = _resolve_output_directory()
  model_dir = _train_model_if_needed(study, output_directory)

  # NOTE: visualization of reconstructions, samples and latent traversals
  # (visualize_model.visualize into <output>/visualizations) is currently
  # disabled in this script.

  # We fix the random seed for the postprocessing and evaluation steps (each
  # config gets a different but reproducible seed derived from a master seed of
  # 0). The model seed was set via the gin bindings and configs of the study.
  # All postprocessing seeds are drawn before any evaluation seed, so the
  # seed sequence matches a single sequential pass over both stages.
  random_state = np.random.RandomState(0)
  post_names = _postprocess_all(study, model_dir, output_directory,
                                random_state)
  _evaluate_all(study, post_names, output_directory, random_state)
def test_train_model(self, gin_configs, gin_bindings):
  """Trains a model from the given gin configs/bindings; passes if no error."""
  # Reset gin up front: a previously failed test can leave the config locked,
  # which would otherwise make this test fail too.
  gin.clear_config()
  scratch_dir = self.create_tempdir().full_path
  train.train_with_gin(scratch_dir, True, gin_configs, gin_bindings)