def _test_variable_update(self,
                          test_name,
                          num_workers,
                          num_ps,
                          params,
                          num_controllers=0):
    """Tests variables are updated correctly when the given params are used.

    Benchmark processes are spawned through _spawn_benchmark_processes, the
    losses are parsed out of each returned log, and they are compared with
    the losses from TestCNNModel().manually_compute_losses().

    Args:
      test_name: name of the test; used as the parent directory of the
        'variable_update' output directory.
      num_workers: number of workers. Forwarded to the spawned run and to
        manually_compute_losses, and used as the divisor when averaging
        losses for 'distributed_all_reduce'.
      num_ps: forwarded to _spawn_benchmark_processes.
      params: Params for the run. This method reads print_training_accuracy,
        variable_update and use_fp16 from it.
      num_controllers: forwarded to _spawn_benchmark_processes.
    """
    output_dir_path = os.path.join(test_name, 'variable_update')
    logs = _spawn_benchmark_processes(output_dir_path, num_workers, num_ps,
                                      num_controllers, params)
    # One list of losses per returned log.
    actual_losses = [
        [
            output.loss
            for output in test_util.get_training_outputs_from_logs(
                worker_logs, params.print_training_accuracy)
        ]
        for worker_logs in logs
    ]

    inputs = test_util.get_fake_var_update_inputs()
    expected_losses = test_util.TestCNNModel().manually_compute_losses(
        inputs, num_workers, params)
    if params.variable_update == 'distributed_all_reduce':
        # In distributed all reduce, each step, the controller outputs the
        # average of the loss from each worker. So we modify expected losses
        # accordingly. E.g, we change [[1, 2], [4, 5]] to [[2.5, 3.5]]
        expected_losses = [[
            sum(losses) / num_workers for losses in zip(*expected_losses)
        ]]

    # zip() below stops at the shorter of its inputs, so without this check a
    # missing (or entirely empty) set of logs would make the test pass
    # without comparing anything.
    self.assertEqual(
        len(actual_losses), len(expected_losses),
        'Parsed losses from %d log(s) but expected losses for %d.' %
        (len(actual_losses), len(expected_losses)))

    # fp16 runs get a looser relative tolerance.
    rtol = 3e-2 if params.use_fp16 else 1e-5
    for worker_actual_losses, worker_expected_losses in zip(
            actual_losses, expected_losses):
        # Only the first len(expected) actual losses are compared.
        self.assertAllClose(
            worker_actual_losses[:len(worker_expected_losses)],
            worker_expected_losses,
            rtol=rtol,
            atol=0.)
def run_with_test_model(params):
  """Runs a BenchmarkCNN built around TestCNNModel on fake inputs.

  The run happens with benchmark_cnn.LOSS_AND_ACCURACY_DIGITS_TO_SHOW patched
  to 15 via test_util.monkey_patch.

  Args:
    params: Params passed to the BenchmarkCNN constructor.
  """
  test_model = test_util.TestCNNModel()
  fake_images = test_util.get_fake_var_update_inputs()
  with test_util.monkey_patch(
      benchmark_cnn, LOSS_AND_ACCURACY_DIGITS_TO_SHOW=15):
    benchmark = benchmark_cnn.BenchmarkCNN(
        params, dataset=test_util.TestDataSet(), model=test_model)
    # TestCNNModel ignores labels when computing its loss, so any values work
    # here provided there is one label per input.
    num_images = fake_images.shape[0]
    fake_labels = np.array([1] * num_images)
    benchmark.image_preprocessor.set_fake_data(fake_images, fake_labels)
    benchmark.run()
Beispiel #3
0
  def _test_variable_update(self, params):
    """Checks BenchmarkCNN's losses against manually computed losses.

    The fake variable-update inputs are run through
    self._get_benchmark_cnn_losses(), and the resulting losses are compared
    with those returned by TestCNNModel().manually_compute_losses().

    Args:
      params: a Params tuple used to create BenchmarkCNN.
    """
    fake_inputs = test_util.get_fake_var_update_inputs()
    observed = self._get_benchmark_cnn_losses(fake_inputs, params)
    # The single-element unpacking requires manually_compute_losses to return
    # exactly one list of losses.
    [reference] = test_util.TestCNNModel().manually_compute_losses(
        fake_inputs, 1, params)
    # fp16 runs get a looser relative tolerance.
    if params.use_fp16:
      relative_tolerance = 3e-2
    else:
      relative_tolerance = 1e-5
    # Only the first len(reference) observed losses take part in the check.
    compared = observed[:len(reference)]
    self.assertAllClose(
        compared, reference, rtol=relative_tolerance, atol=0.)
Beispiel #4
0
  def _get_benchmark_cnn_losses(self, inputs, params):
    """Runs BenchmarkCNN on `inputs` and returns the losses found in its logs.

    Args:
      inputs: fake input data handed to the image preprocessor; its first
        dimension determines how many labels are generated.
      params: Params used to create BenchmarkCNN; print_training_accuracy is
        also forwarded to the log parser.

    Returns:
      A list with the `loss` of every training output parsed from the logs.
    """
    captured_logs = []
    test_model = test_util.TestCNNModel()
    # While patched, benchmark_cnn.log_fn is the callable built from
    # captured_logs, and LOSS_AND_ACCURACY_DIGITS_TO_SHOW is 15.
    with test_util.monkey_patch(
        benchmark_cnn,
        log_fn=test_util.print_and_add_to_list(captured_logs),
        LOSS_AND_ACCURACY_DIGITS_TO_SHOW=15):
      benchmark = benchmark_cnn.BenchmarkCNN(
          params, dataset=test_util.TestDataSet(), model=test_model)
      # TestCNNModel ignores labels when computing its loss, so any values
      # work here provided there is one label per input.
      num_images = inputs.shape[0]
      fake_labels = np.array([1] * num_images)
      benchmark.image_preprocessor.set_fake_data(inputs, fake_labels)
      benchmark.run()

    losses = []
    for training_output in test_util.get_training_outputs_from_logs(
        captured_logs, params.print_training_accuracy):
      losses.append(training_output.loss)
    return losses