Example no. 1
0
    def test_training_fns_called_with_tuple_next(self, total_rounds):
        """Checks selection and `next` calls when `next` returns a tuple."""
        training_process = mock.create_autospec(
            iterative_process.IterativeProcess)
        training_process.initialize.return_value = 'initialize'
        training_process.next.return_value = ('update', {'metric': 1.0})
        training_selection_fn = mock.MagicMock()
        training_selection_fn.return_value = [0]

        training_loop.run_training_process(
            training_process=training_process,
            training_selection_fn=training_selection_fn,
            total_rounds=total_rounds)

        training_process.initialize.assert_called_once()
        # The selection fn is invoked exactly once per round, with the
        # 1-based round number.
        expected_selection_calls = [
            mock.call(round_num) for round_num in range(1, total_rounds + 1)
        ]
        self.assertEqual(training_selection_fn.call_args_list,
                         expected_selection_calls)
        # Round 1 consumes the initial state; every later round consumes the
        # state produced by the previous `next` call.
        expected_next_calls = [
            mock.call('initialize' if round_num == 1 else 'update', [0])
            for round_num in range(1, total_rounds + 1)
        ]
        self.assertEqual(training_process.next.call_args_list,
                         expected_next_calls)
Example no. 2
0
    def test_program_state_manager_called_with_existing_program_state(
            self, version, total_rounds):
        """Verifies saving resumes at the round after a restored version."""
        training_process = mock.create_autospec(
            iterative_process.IterativeProcess)
        training_process.initialize.return_value = 'initialize'
        training_process.next.return_value = ('update', {'metric': 1.0})
        training_selection_fn = mock.MagicMock()
        program_state_manager = mock.AsyncMock()
        # Simulate a pre-existing saved program state at `version`.
        program_state_manager.load_latest.return_value = ('program_state',
                                                          version)

        training_loop.run_training_process(
            training_process=training_process,
            training_selection_fn=training_selection_fn,
            total_rounds=total_rounds,
            program_state_manager=program_state_manager,
            rounds_per_saving_program_state=1)

        program_state_manager.load_latest.assert_called_once()
        # Only rounds after the restored version should be saved again.
        expected_save_calls = [
            mock.call('update', round_num)
            for round_num in range(version + 1, total_rounds + 1)
        ]
        self.assertEqual(program_state_manager.save.call_args_list,
                         expected_save_calls)
Example no. 3
0
    def test_evaluation_fns_called(self, total_rounds, rounds_per_evaluation):
        """Checks evaluation runs at round 0 and every Nth training round."""
        training_process = mock.create_autospec(
            iterative_process.IterativeProcess)
        training_process.initialize.return_value = 'initialize'
        training_process.next.return_value = ('update', {'metric': 1.0})
        training_selection_fn = mock.MagicMock()
        evaluation_fn = mock.create_autospec(computation_base.Computation,
                                             return_value={'metric': 1.0})
        evaluation_selection_fn = mock.MagicMock()
        evaluation_selection_fn.return_value = [0]

        training_loop.run_training_process(
            training_process=training_process,
            training_selection_fn=training_selection_fn,
            total_rounds=total_rounds,
            evaluation_fn=evaluation_fn,
            evaluation_selection_fn=evaluation_selection_fn,
            rounds_per_evaluation=rounds_per_evaluation)

        # Rounds on which a periodic evaluation is expected to fire.
        evaluation_rounds = [
            round_num for round_num in range(1, total_rounds + 1)
            if round_num % rounds_per_evaluation == 0
        ]
        # Round 0 evaluates the initial state; later evaluations are periodic.
        expected_selection_calls = (
            [mock.call(0)] +
            [mock.call(round_num) for round_num in evaluation_rounds])
        self.assertEqual(evaluation_selection_fn.call_args_list,
                         expected_selection_calls)
        expected_evaluation_calls = (
            [mock.call('initialize', [0])] +
            [mock.call('update', [0]) for _ in evaluation_rounds])
        self.assertEqual(evaluation_fn.call_args_list,
                         expected_evaluation_calls)
Example no. 4
0
    def test_metrics_managers_called_with_training_and_evaluation_time_10(
            self):
        """Checks released metrics include the mocked 10-second timings.

        `time.time` is patched to yield paired values 10 seconds apart, so
        every timed section in the loop measures exactly 10.0 seconds.
        """
        training_process = mock.create_autospec(
            iterative_process.IterativeProcess)
        training_process.initialize.return_value = 'initialize'
        # Assign the signature directly, mirroring `next` below. The original
        # set `type_signature.return_value`, which configures a *call* to
        # `type_signature()` rather than the attribute itself, so the
        # assignment never took effect for code reading the attribute.
        training_process.initialize.type_signature = _test_init_fn.type_signature
        training_process.next.return_value = ('update', {'metric': 1.0})
        training_process.next.type_signature = _test_next_fn.type_signature
        training_selection_fn = mock.MagicMock()
        evaluation_fn = mock.MagicMock()
        evaluation_fn.return_value = {'metric': 1.0}
        evaluation_fn.type_signature = _test_evaluation_fn.type_signature
        evaluation_selection_fn = mock.MagicMock()
        metrics_manager = mock.AsyncMock()

        with mock.patch('time.time') as mock_time:
            # Since absl.logging.info uses a call to time.time, we mock it out.
            mock_time.side_effect = [0.0, 10.0] * 3
            with mock.patch('absl.logging.info'):
                training_loop.run_training_process(
                    training_process=training_process,
                    training_selection_fn=training_selection_fn,
                    total_rounds=1,
                    evaluation_fn=evaluation_fn,
                    evaluation_selection_fn=evaluation_selection_fn,
                    metrics_managers=[metrics_manager])

        expected_calls = []
        # Release at round 0: evaluation of the initial state only.
        metrics = collections.OrderedDict([
            ('evaluation/metric', 1.0),
            ('evaluation/evaluation_time_in_seconds', 10.0),
        ])
        metrics_type = computation_types.StructWithPythonType([
            ('evaluation/metric', tf.float32),
            ('evaluation/evaluation_time_in_seconds', tf.float32),
        ], collections.OrderedDict)
        call = mock.call(metrics, metrics_type, 0)
        expected_calls.append(call)
        # Release at round 1: training metrics plus the post-round evaluation.
        metrics = collections.OrderedDict([
            ('metric', 1.0),
            ('training_time_in_seconds', 10.0),
            ('round_number', 1),
            ('evaluation/metric', 1.0),
            ('evaluation/evaluation_time_in_seconds', 10.0),
        ])
        metrics_type = computation_types.StructWithPythonType([
            ('metric', tf.float32),
            ('training_time_in_seconds', tf.float32),
            ('round_number', tf.int32),
            ('evaluation/metric', tf.float32),
            ('evaluation/evaluation_time_in_seconds', tf.float32),
        ], collections.OrderedDict)
        call = mock.call(metrics, metrics_type, 1)
        expected_calls.append(call)
        self.assertEqual(metrics_manager.release.call_args_list,
                         expected_calls)
Example no. 5
0
    def test_metrics_managers_called_without_evaluation(self, total_rounds):
        """Checks all metrics managers get identical per-round releases.

        No evaluation fn is supplied, so each release contains only the
        training metrics, timing, and round number.
        """
        training_process = mock.create_autospec(
            iterative_process.IterativeProcess)
        training_process.initialize.return_value = 'initialize'
        # Assign the signature directly, mirroring `next` below. The original
        # set `type_signature.return_value`, which configures a *call* to
        # `type_signature()` rather than the attribute itself, so the
        # assignment never took effect for code reading the attribute.
        training_process.initialize.type_signature = _test_init_fn.type_signature
        training_process.next.return_value = ('update', {'metric': 1.0})
        training_process.next.type_signature = _test_next_fn.type_signature
        training_selection_fn = mock.MagicMock()
        metrics_manager_1 = mock.AsyncMock()
        metrics_manager_2 = mock.AsyncMock()
        metrics_manager_3 = mock.AsyncMock()
        metrics_managers = [
            metrics_manager_1, metrics_manager_2, metrics_manager_3
        ]

        training_loop.run_training_process(
            training_process=training_process,
            training_selection_fn=training_selection_fn,
            total_rounds=total_rounds,
            metrics_managers=metrics_managers)

        expected_calls = []
        for round_num in range(1, total_rounds + 1):
            # Wall-clock time varies per run, so match it with mock.ANY.
            metrics = collections.OrderedDict([
                ('metric', 1.0),
                ('training_time_in_seconds', mock.ANY),
                ('round_number', round_num),
            ])
            metrics_type = computation_types.StructWithPythonType([
                ('metric', tf.float32),
                ('training_time_in_seconds', tf.float32),
                ('round_number', tf.int32),
            ], collections.OrderedDict)
            call = mock.call(metrics, metrics_type, round_num)
            expected_calls.append(call)
        # Every manager must have received the same sequence of releases.
        for metrics_manager in metrics_managers:
            self.assertEqual(metrics_manager.release.call_args_list,
                             expected_calls)