コード例 #1
0
ファイル: test_monitor.py プロジェクト: kekeblom/GPflow
    def _test_chechpoint_roundtrip(self, use_global_step: bool, num_checkpoints: Optional[int]=5):
        """
        Writes a series of checkpoints and restores the most recent one into a fresh graph.

        Note that if `global_step` is used the save will create one checkpoint for each value
        of the global step.

        :param use_global_step: when True, a global step tensor is created and set to ``10 * i``
            on the i-th save, and its restored value is verified as well.
        :param num_checkpoints: number of times the checkpoint task is invoked; the dummy
            variable is set to ``i`` before the i-th invocation.
        """

        with tempfile.TemporaryDirectory() as checkpoint_dir:

            # Phase 1: build the variables in a first graph and save several checkpoints.
            with session_context(tf.Graph()) as save_session:
                variable = self._create_dummy_variable(save_session)
                context = mon.MonitorContext()
                context.session = save_session
                if use_global_step:
                    context.global_step_tensor = mon.create_global_step(save_session)
                checkpoint_task = mon.CheckpointTask(checkpoint_dir)

                for step in range(num_checkpoints):
                    save_session.run(variable.assign(step))
                    if use_global_step:
                        save_session.run(context.global_step_tensor.assign(10 * step))
                    checkpoint_task(context)

            # Phase 2: rebuild the same variables in a fresh graph, restore from the directory
            # and check that the values written last are the ones that come back.
            with session_context(tf.Graph()) as restore_session:
                variable = self._create_dummy_variable(restore_session)
                step_tensor = None
                if use_global_step:
                    step_tensor = mon.create_global_step(restore_session)
                mon.restore_session(restore_session, checkpoint_dir)
                final_index = num_checkpoints - 1
                self.assertEqual(restore_session.run(variable), final_index)
                if use_global_step:
                    self.assertEqual(restore_session.run(step_tensor), 10 * final_index)
コード例 #2
0
ファイル: test_monitor.py プロジェクト: kekeblom/GPflow
def run_tensorboard_task(task_factory: Callable[[str], mon.BaseTensorBoardTask]) -> Dict:
    """
    Runs a tensorboard monitoring task, reads summary from the created event file and returns
    decoded proto values in a dictionary.

    :param task_factory: task factory that takes the event directory as an argument.
    :return: dictionary mapping each summary tag to the last summary value proto read for that
        tag from the event file.
    """

    summary = {}

    with tempfile.TemporaryDirectory() as tmp_event_dir:

        monitor_task = task_factory(tmp_event_dir)

        # Reuse the model's session when the task has a model; otherwise create one here.
        # A session created here is owned by this function and must be closed on exit.
        owns_session = monitor_task.model is None
        session = tf.Session() if owns_session else monitor_task.model.enquire_session()
        try:
            global_step_tensor = mon.create_global_step(session)

            # NOTE(review): presumably makes the task flush its writer on every call so the
            # event file can be read back right away — confirm against BaseTensorBoardTask.
            monitor_task.with_flush_immediately(True)

            monitor_context = mon.MonitorContext()
            monitor_context.session = session
            monitor_context.global_step_tensor = global_step_tensor

            monitor_task(monitor_context)

            # There should be one event file in the temporary directory.
            event_file = str(next(pathlib.Path(tmp_event_dir).iterdir()))

            for event in tf.train.summary_iterator(event_file):
                for value in event.summary.value:
                    summary[value.tag] = value
        finally:
            # Only close a session we created; the model's session belongs to the model.
            if owns_session:
                session.close()

    return summary
コード例 #3
0
    def test_update_scipy_optimiser(self):
        """
        Checks that the `update_optimiser` function sets the ScipyOptimizer state to the model
        parameters. Also checks that it sets the `optimiser_updated` flag to True.

        On every optimiser step the callback resets the flag, calls `update_optimiser`, and
        asserts that every element of each parameter differs from its value at the previous
        step.
        """

        with session_context(tf.Graph()):
            model = create_linear_model()
            optimiser = gpflow.train.ScipyOptimizer()
            context = mon.MonitorContext()
            context.session = model.enquire_session()
            context.optimiser = optimiser
            # NOTE(review): the initial values are read via `.value`, while later values are
            # unconstrained tensors; they may differ for parameters with a transform — confirm.
            w, b, var = model.w.value, model.b.value, model.var.value
            call_count = 0

            def step_callback(*args, **kwargs):
                # Only the names rebound below need `nonlocal`; model/optimiser/context are read.
                nonlocal w, b, var, call_count
                context.optimiser_updated = False
                mon.update_optimiser(context, *args, **kwargs)
                w_new, b_new, var_new = model.enquire_session().run([model.w.unconstrained_tensor,
                                                                     model.b.unconstrained_tensor,
                                                                     model.var.unconstrained_tensor])
                # `np.all` replaces `np.alltrue`, which was removed in NumPy 2.0.
                self.assertTrue(np.all(np.not_equal(w, w_new)))
                self.assertTrue(np.all(np.not_equal(b, b_new)))
                self.assertTrue(np.all(np.not_equal(var, var_new)))
                self.assertTrue(context.optimiser_updated)
                call_count += 1
                w, b, var = w_new, b_new, var_new

            optimiser.minimize(model, maxiter=10, step_callback=step_callback)
            self.assertGreater(call_count, 0)
コード例 #4
0
ファイル: test_monitor.py プロジェクト: kekeblom/GPflow
    def test_print_timings(self):
        """
        Tests rate calculation for the PrintTimingsTask (doesn't test the actual printing).

        `_print_timings` is replaced by a mock so that the arguments of each call can be
        inspected. NOTE(review): the expected tuples are consistent with
        (iteration_no, global_step, iterations / total_time,
        iterations since last call / time since last call,
        (global_step - init_global_step) / optimisation_time,
        global steps since last call / optimisation time since last call) — this reading is
        derived from the numbers below; confirm against PrintTimingsTask.
        """
        with session_context(tf.Graph()) as session:
            monitor_task = mon.PrintTimingsTask()
            monitor_task._print_timings = mock.MagicMock()
            monitor_context = mon.MonitorContext()
            # Use the session managed by session_context instead of opening an extra
            # tf.Session() that would never be closed.
            monitor_context.session = session
            monitor_context.global_step_tensor = mon.create_global_step(session)
            monitor_context.init_global_step = 100

            # First call
            monitor_context.iteration_no = 10
            monitor_context.total_time = 20.0
            monitor_context.optimisation_time = 16.0
            session.run(monitor_context.global_step_tensor.assign(150))
            monitor_task(monitor_context)
            args = monitor_task._print_timings.call_args_list[0][0]
            self.assertTupleEqual(args, (10, 150, 0.5, 0.5, 3.125, 3.125))

            # Second call
            monitor_context.iteration_no = 24
            monitor_context.total_time = 30.0
            monitor_context.optimisation_time = 24.0
            session.run(monitor_context.global_step_tensor.assign(196))
            monitor_task(monitor_context)
            args = monitor_task._print_timings.call_args_list[1][0]
            self.assertTupleEqual(args, (24, 196, 0.8, 1.4, 4.0, 5.75))
コード例 #5
0
ファイル: test_monitor.py プロジェクト: kekeblom/GPflow
 def test_sleep_lower_bound(self):
     """
     Checks that SleepTask pauses execution for at least roughly the requested duration.

     The task is asked to sleep for 0.2 s; the 0.1 s lower bound leaves room for timer
     imprecision.
     """
     sleep_task = mon.SleepTask(0.2)
     t_before = mon.get_hr_time()
     sleep_task(mon.MonitorContext())
     t_after = mon.get_hr_time()
     self.assertGreater(t_after - t_before, 0.1)
コード例 #6
0
ファイル: test_monitor.py プロジェクト: kekeblom/GPflow
 def test_call_condition(self):
     """
     Checks that a task's condition decides whether a call actually executes the task.

     The condition accepts even iteration numbers only, so over iterations 0..4 the task is
     expected to run three times.
     """
     def is_even_iteration(context):
         return context.iteration_no % 2 == 0

     task = _DummyMonitorTask().with_condition(is_even_iteration)
     context = mon.MonitorContext()
     for iteration in range(5):
         context.iteration_no = iteration
         task(context)
     self.assertEqual(task.call_count, 3)
コード例 #7
0
ファイル: test_monitor.py プロジェクト: kekeblom/GPflow
 def test_condition(self):
     """
     Checks PeriodicIterationCondition: with a period of 5, iterations 0..36 are expected to
     trigger the condition 7 times.
     """
     context = mon.MonitorContext()
     periodic = mon.PeriodicIterationCondition(5)
     hits = 0
     for iteration in range(37):
         context.iteration_no = iteration
         hits += 1 if periodic(context) else 0
     self.assertEqual(hits, 7)
コード例 #8
0
ファイル: test_monitor.py プロジェクト: kekeblom/GPflow
 def test_exit_condition(self):
     """
     Checks that once optimisation has finished, a task runs only if its exit condition is set.
     """
     skipped_task = _DummyMonitorTask().with_exit_condition(False)
     executed_task = _DummyMonitorTask().with_exit_condition(True)
     finished_context = mon.MonitorContext()
     finished_context.optimisation_finished = True
     for task in (skipped_task, executed_task):
         task(finished_context)
     self.assertEqual(skipped_task.call_count, 0)
     self.assertEqual(executed_task.call_count, 1)
コード例 #9
0
ファイル: test_monitor.py プロジェクト: kekeblom/GPflow
 def test_call_timing(self, mock_timer):
     """
     Checks that a task records both the duration of its latest call and the accumulated
     duration of all calls.

     The mocked timer returns 1.0, 3.5, 4.0, 6.0; the expected durations (2.5 s, then 2.0 s)
     are consistent with one reading at the start and one at the end of each call.
     """
     mock_timer.side_effect = [1.0, 3.5, 4.0, 6.0]
     task = _DummyMonitorTask()
     context = mon.MonitorContext()
     # (expected total_time, expected last_call_time) after each successive call
     for expected_total, expected_last in ((2.5, 2.5), (4.5, 2.0)):
         task(context)
         self.assertEqual(task.total_time, expected_total)
         self.assertEqual(task.last_call_time, expected_last)
コード例 #10
0
ファイル: test_monitor.py プロジェクト: kekeblom/GPflow
 def test_condition(self):
     """
     Checks GenericCondition against an explicit sequence of trigger points.

     Each step lists the value expected in `condition._next` before the call, the iteration
     number to evaluate and the expected result. At iteration 7 both 5 and 6 have been passed,
     so the next pending trigger point afterwards is 9.
     """
     trigger_points = iter([2, 5, 6, 9])
     context = mon.MonitorContext()
     condition = mon.GenericCondition(lambda ctx: ctx.iteration_no, trigger_points)
     steps = (
         (2, 1, False),
         (2, 3, True),
         (5, 4, False),
         (5, 7, True),
         (9, 8, False),
     )
     for expected_next, iteration, expected in steps:
         self.assertEqual(condition._next, expected_next)
         context.iteration_no = iteration
         self.assertEqual(condition(context), expected)
コード例 #11
0
 def test_on_iteration_timing(self, mock_timer):
     """
     Checks that the Monitor accumulates the total running time and the optimisation time
     across successive `_on_iteration` calls.
     """
     mock_timer.side_effect = [1.0, 3.5, 4.0, 6.0, 7.0]
     monitor = mon.Monitor([], context=mon.MonitorContext())
     # In each call to the _on_iteration the timer is called twice - at the beginning and at
     # the end of the call.
     # (expected total_time, expected optimisation_time) after each successive call
     for expected_total, expected_optimisation in ((2.5, 2.5), (5.0, 4.5)):
         monitor._on_iteration()
         self.assertEqual(monitor._context.total_time, expected_total)
         self.assertEqual(monitor._context.optimisation_time, expected_optimisation)
コード例 #12
0
ファイル: test_monitor.py プロジェクト: kekeblom/GPflow
    def test_callback(self):
        """Checks that calling a CallbackTask once invokes its callback exactly once."""
        callback_mock = mock.MagicMock()
        callback_task = mon.CallbackTask(callback_mock)
        callback_task(mon.MonitorContext())
        self.assertEqual(callback_mock.call_count, 1)