예제 #1
0
  def test_calls_with_no_input_args(self, mock_compute_validation,
                                    mock_initialize):
    """With no optional arguments, the start fn passes its input through.

    Neither `mock_initialize` nor `mock_compute_validation` may be called,
    and the returned round number must be 1.
    """
    loop_input = 'input'
    start_fn = training_loop._create_on_loop_start_fn()

    state, round_num = start_fn(loop_input)

    mock_initialize.assert_not_called()
    mock_compute_validation.assert_not_called()
    self.assertEqual(state, loop_input)
    self.assertEqual(round_num, 1)
예제 #2
0
  def test_calls_with_only_validation_fn(self, mock_compute_validation,
                                         mock_initialize):
    """Supplying only `validation_fn` runs validation for round 0.

    `mock_initialize` must stay untouched. `mock_compute_validation` must be
    called once with the input state, round 0 and the given fn. The start fn
    returns the input state unchanged together with round 1.
    """
    loop_input = 'input'
    first_round = 1
    val_fn = mock.MagicMock()
    start_fn = training_loop._create_on_loop_start_fn(validation_fn=val_fn)

    state, round_num = start_fn(loop_input)

    mock_initialize.assert_not_called()
    mock_compute_validation.assert_called_once_with(loop_input,
                                                    first_round - 1,
                                                    val_fn)
    self.assertEqual(state, loop_input)
    self.assertEqual(round_num, first_round)
예제 #3
0
 def test_calls_with_only_checkpoint_manager_and_non_zero_checkpoint_round(
     self, mock_compute_validation, mock_initialize):
   """A checkpoint manager alone delegates to the mocked initializer.

   The initializer mock returns a non-zero round (3). The start fn must
   return exactly that (state, round) pair. It must call the initializer once
   with the loop input and the manager, run no validation, and save no
   checkpoint.
   """
   loop_input = 'input'
   restored_state = 'state'
   restored_round = 3
   mock_initialize.return_value = (restored_state, restored_round)
   ckpt_mngr = mock.create_autospec(checkpoint_manager.FileCheckpointManager)
   start_fn = training_loop._create_on_loop_start_fn(
       file_checkpoint_manager=ckpt_mngr)

   state, round_num = start_fn(loop_input)

   mock_initialize.assert_called_once_with(loop_input, ckpt_mngr)
   mock_compute_validation.assert_not_called()
   ckpt_mngr.save_checkpoint.assert_not_called()
   self.assertEqual(state, restored_state)
   self.assertEqual(round_num, restored_round)
예제 #4
0
  def test_calls_with_only_metrics_managers(self, mock_compute_validation,
                                            mock_initialize):
    """Metrics managers alone get `clear_metrics(0)` but no saved metrics.

    Neither injected mock may be called. The start fn returns the input state
    unchanged together with round 1.
    """
    loop_input = 'input'
    first_round = 1
    managers = [
        mock.create_autospec(metrics_manager.MetricsManager) for _ in range(2)
    ]
    start_fn = training_loop._create_on_loop_start_fn(
        metrics_managers=managers)

    state, round_num = start_fn(loop_input)

    mock_initialize.assert_not_called()
    mock_compute_validation.assert_not_called()
    for mngr in managers:
      mngr.clear_metrics.assert_called_once_with(first_round - 1)
      mngr.save_metrics.assert_not_called()
    self.assertEqual(state, loop_input)
    self.assertEqual(round_num, first_round)
예제 #5
0
 def test_calls_with_metrics_managers_and_validation_fn(
     self, mock_compute_validation, mock_initialize):
   """Validation metrics for round 0 are forwarded to every metrics manager.

   The start fn is built with metrics managers and a validation fn but no
   checkpoint manager. The test expects:

   - the initializer mock is never called;
   - the validation mock is called once for round 0;
   - each manager gets `clear_metrics(0)` once;
   - each manager gets `save_metrics(metrics, 0)` once, where `metrics` is
     the validation mock's return value.
   """
   metric_manager1 = mock.create_autospec(metrics_manager.MetricsManager)
   metric_manager2 = mock.create_autospec(metrics_manager.MetricsManager)
   metrics_managers = [metric_manager1, metric_manager2]
   validation_fn = mock.MagicMock()
   metrics = {'metric1': 2}
   mock_compute_validation.return_value = metrics
   on_loop_start_fn = training_loop._create_on_loop_start_fn(
       metrics_managers=metrics_managers, validation_fn=validation_fn)
   on_loop_start_input = 'input'
   actual_state, actual_round = on_loop_start_fn(on_loop_start_input)
   mock_initialize.assert_not_called()
   expected_state = on_loop_start_input
   expected_round = 1
   # All assertions below expect the round preceding `expected_round` (i.e.
   # 0). Derive it once instead of hard-coding 0 so the validation and
   # metrics-manager expectations cannot drift apart.
   validation_round = expected_round - 1
   mock_compute_validation.assert_called_once_with(expected_state,
                                                   validation_round,
                                                   validation_fn)
   for metr_mngr in metrics_managers:
     metr_mngr.clear_metrics.assert_called_once_with(validation_round)
     metr_mngr.save_metrics.assert_called_once_with(metrics, validation_round)
   self.assertEqual(actual_state, expected_state)
   self.assertEqual(actual_round, expected_round)