Exemplo n.º 1
0
    def _train_and_eval_local(self,
                              params,
                              check_output_values=False,
                              max_final_loss=10.,
                              skip_eval=False,
                              use_test_preprocessor=True):
        """Delegates to test_util.train_and_eval with a local run callback.

        Each run invokes one of this test case's benchmark helpers and wraps
        its result in a one-element list. The run type passed to the callback
        is ignored, so every run uses the same helper.

        Args:
          params: Parameters forwarded to test_util.train_and_eval; each run
            receives its own (possibly modified) copy as `inner_params`.
          check_output_values: Forwarded unchanged to test_util.train_and_eval.
          max_final_loss: Forwarded unchanged to test_util.train_and_eval.
          skip_eval: Forwarded unchanged to test_util.train_and_eval.
          use_test_preprocessor: When True, runs use
            self._run_benchmark_cnn_with_black_and_white_images; otherwise
            they use self._run_benchmark_cnn.

        Returns:
          Whatever test_util.train_and_eval returns.
        """
        # TODO(reedwm): check_output_values should default to True and be enabled
        # on every test. Currently, if check_output_values=True and the calls to
        # tf.set_random_seed(...) and np.seed(...) are passed certain seed values in
        # benchmark_cnn.py, then most tests will fail. This indicates the tests
        # are brittle and could fail with small changes when
        # check_output_values=True, so check_output_values defaults to False for
        # now.

        def run_fn(run_type, inner_params):
            # The same helper serves training and evaluation runs alike.
            del run_type
            # Look the helper up on each call (not once up front) so the
            # attribute access happens exactly when a run executes.
            runner = (self._run_benchmark_cnn_with_black_and_white_images
                      if use_test_preprocessor else self._run_benchmark_cnn)
            return [runner(inner_params)]

        return test_util.train_and_eval(self, run_fn, params,
                                        check_output_values=check_output_values,
                                        max_final_loss=max_final_loss,
                                        skip_eval=skip_eval)
  def _test_distributed(self,
                        test_name,
                        num_workers,
                        num_ps,
                        params,
                        num_controllers=0,
                        check_output_values=False,
                        skip=None):
    """Delegates to test_util.train_and_eval, spawning processes per run.

    Each run writes under os.path.join(test_name, run_type). Non-evaluation
    runs are spawned with the requested num_workers / num_ps /
    num_controllers. The 'Evaluation' run is always spawned with one worker,
    no parameter servers and no controllers. A 'distributed_replicated'
    variable_update is swapped for 'replicated' for that run.

    Args:
      test_name: Path prefix joined with each run type to form the run's
        output directory.
      num_workers: Worker count for non-evaluation runs.
      num_ps: Parameter-server count for non-evaluation runs.
      params: Parameters forwarded to test_util.train_and_eval. They appear
        to be a namedtuple, since `_replace` is called on them.
      num_controllers: Controller count for non-evaluation runs.
      check_output_values: Forwarded unchanged to test_util.train_and_eval.
      skip: Forwarded unchanged to test_util.train_and_eval.

    Returns:
      Whatever test_util.train_and_eval returns.
    """
    # TODO(reedwm): check_output_values should default to True and be enabled
    # on every test. See the TODO in benchmark_cnn_test.py.
    def run_fn(run_type, inner_params):
      output_dir_path = os.path.join(test_name, run_type)
      if run_type != 'Evaluation':
        return _spawn_benchmark_processes(output_dir_path, num_workers, num_ps,
                                          num_controllers, inner_params)

      # Distributed evaluation is not supported, so a single process is used.
      # It is still spawned as a separate process: evaluating in the current
      # process would allocate GPU memory and make later test methods fail.
      eval_params = inner_params
      if eval_params.variable_update == 'distributed_replicated':
        eval_params = eval_params._replace(variable_update='replicated')
      return _spawn_benchmark_processes(
          output_dir_path, num_workers=1, num_ps=0, num_controllers=0,
          params=eval_params)

    return test_util.train_and_eval(
        self,
        run_fn,
        params,
        check_output_values=check_output_values,
        skip=skip)