Ejemplo n.º 1
0
class EvaluateTest(parameterized.TestCase):
    """Runs `evaluate.evaluate_with_gin` once per test evaluation config.

    Before every test case, `setUp` trains a model with `train_test.gin` and
    postprocesses it with `mean.gin`. Each parameterized case then evaluates
    the postprocessed output with one config from
    `config/tests/evaluation/evaluate_test_configs`.
    """

    def setUp(self):
        super(EvaluateTest, self).setUp()
        # Reset gin before training: if an earlier test failed, the gin config
        # may still be locked/populated and training here would fail for
        # reasons unrelated to this test.
        gin.clear_config()
        self.model_dir = self.create_tempdir(
            "model", cleanup=absltest.TempFileCleanup.OFF).full_path
        model_config = resources.get_file(
            "config/tests/methods/unsupervised/train_test.gin")
        train.train_with_gin(self.model_dir, True, [model_config])
        self.output_dir = self.create_tempdir(
            "output", cleanup=absltest.TempFileCleanup.OFF).full_path
        postprocess_config = resources.get_file(
            "config/tests/postprocessing/postprocess_test_configs/mean.gin")
        postprocess.postprocess_with_gin(self.model_dir, self.output_dir, True,
                                         [postprocess_config])

    @parameterized.parameters(
        list(
            resources.get_files_in_folder(
                "config/tests/evaluation/evaluate_test_configs")))
    def test_evaluate(self, gin_config):
        # We clear the gin config before running. Otherwise, if a prior test fails,
        # the gin config is locked and the current test fails.
        gin.clear_config()
        # Evaluate the postprocessed representation into a fresh temp dir.
        evaluate.evaluate_with_gin(self.output_dir,
                                   self.create_tempdir().full_path, True,
                                   [gin_config])
Ejemplo n.º 2
0
class ReasonTestFromScratch(parameterized.TestCase):
  """Runs `reason.reason_with_gin` for every from-scratch test gin config."""

  @parameterized.parameters(
      list(
          resources.get_files_in_folder(
              "config/tests/abstract_reasoning/from_scratch")))
  def test_reason_from_scratch(self, gin_config):
    # Reset gin first: a previously failed test can leave the config locked,
    # which would make this case fail independently of its own config.
    gin.clear_config()
    result_dir = self.create_tempdir().full_path
    # The first argument is None; presumably this means "no pre-trained input
    # model" (the test name says "from scratch") -- confirm in reason_with_gin.
    reason.reason_with_gin(None, result_dir, True, [gin_config])
Ejemplo n.º 3
0
class EvaluateTest(parameterized.TestCase):
    """Evaluates two separately trained models with each UDR test config."""

    def setUp(self):
        super(EvaluateTest, self).setUp()
        keep_files = absltest.TempFileCleanup.OFF
        self.model1_dir = self.create_tempdir(
            "model1/model", cleanup=keep_files).full_path
        self.model2_dir = self.create_tempdir(
            "model2/model", cleanup=keep_files).full_path
        train_config = resources.get_file(
            "config/tests/methods/unsupervised/train_test.gin")
        # Start training from a clean gin state.
        gin.clear_config()
        for trained_dir in (self.model1_dir, self.model2_dir):
            train.train_with_gin(trained_dir, True, [train_config])

        self.output_dir = self.create_tempdir(
            "output", cleanup=keep_files).full_path

    @parameterized.parameters(
        list(resources.get_files_in_folder("config/tests/methods/udr")))
    def test_evaluate(self, gin_config):
        # Reset gin first: a previously failed test can leave the config
        # locked, which would make this case fail as well.
        gin.clear_config()
        gin.parse_config_files_and_bindings([gin_config], None)
        both_models = [self.model1_dir, self.model2_dir]
        evaluate.evaluate(both_models, self.output_dir)
Ejemplo n.º 4
0
 def get_eval_config_files(self):
     """Return, as a list, the stage-1 metric config files of the study."""
     metric_dir = "config/abstract_reasoning_study_v1/stage1/metric_configs/"
     found = resources.get_files_in_folder(metric_dir)
     return list(found)
Ejemplo n.º 5
0
 def get_postprocess_config_files(self):
     """Return, as a list, the stage-1 postprocessing config files."""
     postprocess_dir = (
         "config/abstract_reasoning_study_v1/stage1/postprocess_configs/")
     found = resources.get_files_in_folder(postprocess_dir)
     return list(found)
Ejemplo n.º 6
0
 def get_eval_config_files(self):
   """Return, as a list, the metric config files of unsupervised study v1."""
   metric_dir = "config/unsupervised_study_v1/metric_configs/"
   found = resources.get_files_in_folder(metric_dir)
   return list(found)
Ejemplo n.º 7
0
 def get_postprocess_config_files(self):
   """Return, as a list, the postprocessing configs of unsupervised study v1."""
   postprocess_dir = "config/unsupervised_study_v1/postprocess_configs/"
   found = resources.get_files_in_folder(postprocess_dir)
   return list(found)
Ejemplo n.º 8
0
 def get_eval_config_files(self):
   """Return, as a list, the evaluation gin configs used by the test suite."""
   eval_dir = "config/tests/evaluation/evaluate_test_configs"
   found = resources.get_files_in_folder(eval_dir)
   return list(found)
Ejemplo n.º 9
0
 def get_postprocess_config_files(self):
   """Return, as a list, the postprocessing configs used by the test suite."""
   postprocess_dir = "config/tests/postprocessing/postprocess_test_configs"
   found = resources.get_files_in_folder(postprocess_dir)
   return list(found)