def evaluate_metrics_on_permuted_runs():
  """Evaluates metrics on permuted runs, for across-run metrics only."""
  gin_bindings = [
      ('eval_metrics.Evaluator.metrics = '
       '[@IqrAcrossRuns/singleton(), @LowerCVaROnAcross/singleton()]')
  ]
  n_permutations_per_worker = int(p.n_random_samples / p.n_worker)

  # Parse gin config.
  gin.parse_config_files_and_bindings([p.gin_file], gin_bindings)

  for algo1 in p.algos:
    for algo2 in p.algos:
      for task in p.tasks:
        for i_worker in range(p.n_worker):
          # Get the subdirectories corresponding to each run.
          summary_path_1 = os.path.join(p.data_dir, algo1, task)
          summary_path_2 = os.path.join(p.data_dir, algo2, task)
          run_dirs_1 = eval_metrics.get_run_dirs(summary_path_1, 'train',
                                                 p.runs)
          run_dirs_2 = eval_metrics.get_run_dirs(summary_path_2, 'train',
                                                 p.runs)

          # Evaluate the metrics.
          outfile_prefix = os.path.join(p.metric_values_dir_permuted, '%s_%s' %
                                        (algo1, algo2), task) + '/'
          evaluator = eval_metrics.Evaluator(metrics=gin.REQUIRED)
          evaluator.write_metric_params(outfile_prefix)
          evaluator.evaluate_with_permutations(
              run_dirs_1=run_dirs_1,
              run_dirs_2=run_dirs_2,
              outfile_prefix=outfile_prefix,
              n_permutations=n_permutations_per_worker,
              permutation_start_idx=(n_permutations_per_worker * i_worker),
              random_seed=i_worker)
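# The permutation work above is sharded across workers: each worker evaluates
# int(n_random_samples / n_worker) permutations, offset by permutation_start_idx
# and seeded with its worker index so the shards do not overlap. Below is a
# minimal sketch of the module-level params object `p` the function relies on;
# the SimpleNamespace wrapper and all values are illustrative assumptions, not
# the actual configuration mechanism.
from types import SimpleNamespace

p = SimpleNamespace(
    algos=['ppo', 'sac'],                    # assumed algorithm names
    tasks=['cartpole_swingup'],              # assumed task names
    runs=['run1', 'run2', 'run3'],           # assumed run subdirectory names
    data_dir='/tmp/summaries',               # assumed path to training summaries
    gin_file='/tmp/eval_metrics.gin',        # assumed gin config path
    metric_values_dir_permuted='/tmp/metrics_permuted',
    n_random_samples=1000,                   # total permutations across all workers
    n_worker=10,                             # each worker handles 1000 / 10 = 100
)

evaluate_metrics_on_permuted_runs()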
def evaluate_metrics_on_bootstrapped_runs():
  """Evaluates metrics on bootstrapped runs, for across-run metrics only."""
  gin_bindings = [
      ('eval_metrics.Evaluator.metrics = '
       '[@IqrAcrossRuns/singleton(), @LowerCVaROnAcross/singleton()]')
  ]
  n_bootstraps_per_worker = int(p.n_random_samples / p.n_worker)

  # Parse gin config.
  gin.parse_config_files_and_bindings([p.gin_file], gin_bindings)

  for algo in p.algos:
    for task in p.tasks:
      for i_worker in range(p.n_worker):
        # Get the subdirectories corresponding to each run.
        summary_path = os.path.join(p.data_dir, algo, task)
        run_dirs = eval_metrics.get_run_dirs(summary_path, 'train', p.runs)

        # Evaluate results.
        outfile_prefix = os.path.join(p.metric_values_dir_bootstrapped, algo,
                                      task) + '/'
        evaluator = eval_metrics.Evaluator(metrics=gin.REQUIRED)
        evaluator.write_metric_params(outfile_prefix)
        evaluator.evaluate_with_bootstraps(
            run_dirs=run_dirs,
            outfile_prefix=outfile_prefix,
            n_bootstraps=n_bootstraps_per_worker,
            bootstrap_start_idx=(n_bootstraps_per_worker * i_worker),
            random_seed=i_worker)
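# Note on the sharding arithmetic used in both functions above:
# int(n_random_samples / n_worker) floors, so when the counts do not divide
# evenly the remainder is silently dropped and slightly fewer than
# n_random_samples permutations/bootstraps are evaluated in total. A quick
# worked example with assumed numbers:
samples_requested, workers = 1000, 7
per_worker = int(samples_requested / workers)   # 142
assert per_worker * workers == 994              # 6 requested samples are never drawn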
  def test_get_run_dirs(self, selected_runs, expected_dirs):
    run_dirs = eval_metrics.get_run_dirs(self.test_data_dir, 'train',
                                         selected_runs)
    run_dirs.sort()
    expected = [
        os.path.join(self.test_data_dir, d, 'train') for d in expected_dirs
    ]
    self.assertEqual(run_dirs, expected)
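# The method above is a fragment of a parameterized test case. A minimal sketch
# of the surrounding declaration it assumes; the class name and the example
# parameters below are assumptions, not the original test code.
from absl.testing import parameterized

class GetRunDirsSketch(parameterized.TestCase):

  @parameterized.parameters(
      (None, ['run1', 'run2']),   # selected_runs=None -> all runs (assumed)
      (['run1'], ['run1']),       # a specific subset of runs
  )
  def test_get_run_dirs(self, selected_runs, expected_dirs):
    ...  # body as in the fragment above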
def evaluate_metrics():
  """Evaluates metrics specified in the gin config."""
  # Parse gin config.
  gin.parse_config_files_and_bindings([p.gin_file], [])

  for algo in p.algos:
    for task in p.tasks:
      # Get the subdirectories corresponding to each run.
      summary_path = os.path.join(p.data_dir, algo, task)
      run_dirs = eval_metrics.get_run_dirs(summary_path, 'train', p.runs)

      # Evaluate metrics.
      outfile_prefix = os.path.join(p.metric_values_dir, algo, task) + '/'
      evaluator = eval_metrics.Evaluator(metrics=gin.REQUIRED)
      evaluator.write_metric_params(outfile_prefix)
      evaluator.evaluate(run_dirs=run_dirs, outfile_prefix=outfile_prefix)
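# Hypothetical driver for the three evaluation passes above. In practice each
# pass is typically launched as its own job (each parses its own gin bindings);
# the step names and argument handling here are illustrative only.
import sys

def run_evaluation_step(step_name):
  steps = {
      'metrics': evaluate_metrics,
      'bootstrapped': evaluate_metrics_on_bootstrapped_runs,
      'permuted': evaluate_metrics_on_permuted_runs,
  }
  steps[step_name]()

if __name__ == '__main__':
  run_evaluation_step(sys.argv[1])  # e.g. `python evaluate.py permuted`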