def _test_chechpoint_roundtrip(self, use_global_step: bool, num_checkpoints: Optional[int]=5):
    """
    Performs saving/restoring roundtrip, either with or without using `global_step`.
    Note that if `global_step` is used the save will create one checkpoint for each
    value of the global step.

    :param use_global_step: if True, a global step tensor is created and advanced
        by 10 on each checkpoint, and its restored value is verified as well.
    :param num_checkpoints: number of save iterations to perform before restoring.
    """
    with tempfile.TemporaryDirectory() as tmp_event_dir:
        # Create a variable and do several checkpoints
        with session_context(tf.Graph()) as session:
            dummy_var = self._create_dummy_variable(session)
            monitor_context = mon.MonitorContext()
            monitor_context.session = session
            if use_global_step:
                monitor_context.global_step_tensor = mon.create_global_step(session)
            monitor_task = mon.CheckpointTask(tmp_event_dir)
            for i in range(num_checkpoints):
                # Give the variable (and optionally the global step) a distinct
                # value per checkpoint so the restored values identify which
                # checkpoint was loaded.
                session.run(dummy_var.assign(i))
                if use_global_step:
                    session.run(monitor_context.global_step_tensor.assign(10 * i))
                monitor_task(monitor_context)
        # Restore the session and read the variables.
        # Verify if the latest checkpoint was restored.
        with session_context(tf.Graph()) as session:
            dummy_var = self._create_dummy_variable(session)
            global_step_tensor = mon.create_global_step(session) if use_global_step else None
            mon.restore_session(session, tmp_event_dir)
            # The last values assigned were (num_checkpoints - 1) and
            # 10 * (num_checkpoints - 1) respectively.
            self.assertEqual(session.run(dummy_var), num_checkpoints - 1)
            if use_global_step:
                self.assertEqual(session.run(global_step_tensor), 10 * (num_checkpoints - 1))
def run_tensorboard_task(task_factory: Callable[[str], mon.BaseTensorBoardTask]) -> Dict:
    """
    Runs a tensorboard monitoring task, reads summary from the created event file
    and returns decoded proto values in a dictionary keyed by summary tag.

    :param task_factory: task factory that takes the event directory as an argument.
    :return: dictionary mapping each summary tag to its proto value.
    """
    summary = {}
    with tempfile.TemporaryDirectory() as tmp_event_dir:
        monitor_task = task_factory(tmp_event_dir)
        # Reuse the model's session when the task carries a model, otherwise
        # create a fresh one.
        if monitor_task.model is not None:
            session = monitor_task.model.enquire_session()
        else:
            session = tf.Session()
        global_step_tensor = mon.create_global_step(session)
        # Flush immediately so the event file is complete when we read it below.
        monitor_task.with_flush_immediately(True)
        monitor_context = mon.MonitorContext()
        monitor_context.session = session
        monitor_context.global_step_tensor = global_step_tensor
        monitor_task(monitor_context)
        # There should be one event file in the temporary directory.
        # Path.iterdir() is already an iterator, so next() can consume it directly.
        event_file = str(next(pathlib.Path(tmp_event_dir).iterdir()))
        for e in tf.train.summary_iterator(event_file):
            for v in e.summary.value:
                summary[v.tag] = v
    return summary
def test_update_scipy_optimiser(self):
    """
    Checks that the `update_optimiser` function sets the ScipyOptimizer state to
    the model parameters. Also checks that it sets the `optimiser_updated` flag
    to True. This is verified by observing that every parameter value changes on
    each optimiser step.
    """
    with session_context(tf.Graph()):
        model = create_linear_model()
        optimiser = gpflow.train.ScipyOptimizer()
        context = mon.MonitorContext()
        context.session = model.enquire_session()
        context.optimiser = optimiser
        w, b, var = model.w.value, model.b.value, model.var.value
        call_count = 0

        def step_callback(*args, **kwargs):
            nonlocal model, optimiser, context, w, b, var, call_count
            context.optimiser_updated = False
            mon.update_optimiser(context, *args, **kwargs)
            w_new, b_new, var_new = model.enquire_session().run(
                [model.w.unconstrained_tensor, model.b.unconstrained_tensor,
                 model.var.unconstrained_tensor])
            # np.alltrue was deprecated and removed in NumPy 2.0;
            # np.all is the supported equivalent.
            self.assertTrue(np.all(np.not_equal(w, w_new)))
            self.assertTrue(np.all(np.not_equal(b, b_new)))
            self.assertTrue(np.all(np.not_equal(var, var_new)))
            self.assertTrue(context.optimiser_updated)
            call_count += 1
            w, b, var = w_new, b_new, var_new

        optimiser.minimize(model, maxiter=10, step_callback=step_callback)
        # Ensure the callback (and hence the assertions above) actually ran.
        self.assertGreater(call_count, 0)
def test_print_timings(self):
    """
    Checks the rate calculations performed by PrintTimingsTask. The actual
    printing is replaced with a mock and only its call arguments are inspected.
    """
    with session_context(tf.Graph()):
        task = mon.PrintTimingsTask()
        task._print_timings = mock.MagicMock()

        ctx = mon.MonitorContext()
        ctx.session = tf.Session()
        ctx.global_step_tensor = mon.create_global_step(ctx.session)
        ctx.init_global_step = 100

        # First call: 10 iterations / 50 global steps since the start.
        ctx.iteration_no = 10
        ctx.total_time = 20.0
        ctx.optimisation_time = 16.0
        ctx.session.run(ctx.global_step_tensor.assign(150))
        task(ctx)
        first_args = task._print_timings.call_args_list[0][0]
        self.assertTupleEqual(first_args, (10, 150, 0.5, 0.5, 3.125, 3.125))

        # Second call: instantaneous rates now differ from the overall averages.
        ctx.iteration_no = 24
        ctx.total_time = 30.0
        ctx.optimisation_time = 24.0
        ctx.session.run(ctx.global_step_tensor.assign(196))
        task(ctx)
        second_args = task._print_timings.call_args_list[1][0]
        self.assertTupleEqual(second_args, (24, 196, 0.8, 1.4, 4.0, 5.75))
def test_sleep_lower_bound(self):
    """
    Checks that the sleep task suspends execution for at least the requested
    period, up to a generous tolerance (asks for 0.2s, asserts > 0.1s elapsed).
    """
    sleep_task = mon.SleepTask(0.2)
    started = mon.get_hr_time()
    sleep_task(mon.MonitorContext())
    self.assertGreater(mon.get_hr_time() - started, 0.1)
def test_call_condition(self):
    """
    Checks that a task decorated with a condition runs only on iterations where
    the condition holds - here even iteration numbers 0, 2 and 4, i.e. 3 calls.
    """
    task = _DummyMonitorTask().with_condition(
        lambda ctx: ctx.iteration_no % 2 == 0)
    ctx = mon.MonitorContext()
    for iteration in range(5):
        ctx.iteration_no = iteration
        task(ctx)
    self.assertEqual(task.call_count, 3)
def test_condition(self):
    """
    Checks the periodic condition based on the iteration number: with interval 5
    it must fire exactly 7 times over 37 iterations.
    """
    ctx = mon.MonitorContext()
    condition = mon.PeriodicIterationCondition(5)
    hits = 0
    for iteration in range(37):
        ctx.iteration_no = iteration
        if condition(ctx):
            hits += 1
    self.assertEqual(hits, 7)
def test_exit_condition(self):
    """
    Checks that, once the optimisation is finished, only tasks whose exit
    condition is True are executed.
    """
    skipped_task = _DummyMonitorTask().with_exit_condition(False)
    executed_task = _DummyMonitorTask().with_exit_condition(True)
    ctx = mon.MonitorContext()
    ctx.optimisation_finished = True
    skipped_task(ctx)
    executed_task(ctx)
    self.assertEqual(skipped_task.call_count, 0)
    self.assertEqual(executed_task.call_count, 1)
def test_call_timing(self, mock_timer):
    """
    Checks that a monitoring task tracks both its last execution time and the
    accumulated execution time. The timer is mocked to return fixed instants;
    each task call consumes two readings (start and end).
    """
    mock_timer.side_effect = [1.0, 3.5, 4.0, 6.0]
    task = _DummyMonitorTask()
    ctx = mon.MonitorContext()

    task(ctx)  # spans 1.0 -> 3.5
    self.assertEqual(task.total_time, 2.5)
    self.assertEqual(task.last_call_time, 2.5)

    task(ctx)  # spans 4.0 -> 6.0
    self.assertEqual(task.total_time, 4.5)
    self.assertEqual(task.last_call_time, 2.0)
def test_condition(self):
    """
    Checks the generic condition driven by an arbitrary increasing sequence of
    trigger points. Each step below lists the expected internal `_next` trigger,
    the iteration number to test and the expected condition outcome.
    """
    trigger_points = iter([2, 5, 6, 9])
    ctx = mon.MonitorContext()
    condition = mon.GenericCondition(lambda c: c.iteration_no, trigger_points)
    steps = [
        (2, 1, False),
        (2, 3, True),
        (5, 4, False),
        (5, 7, True),
        (9, 8, False),
    ]
    for expected_next, iteration, expected_result in steps:
        self.assertEqual(condition._next, expected_next)
        ctx.iteration_no = iteration
        self.assertEqual(condition(ctx), expected_result)
def test_on_iteration_timing(self, mock_timer):
    """
    Checks how the Monitor accumulates the total running time and the total
    optimisation time, using a mocked timer with fixed readings. In each call
    to _on_iteration the timer is consulted twice - at the beginning and at the
    end of the call.
    """
    mock_timer.side_effect = [1.0, 3.5, 4.0, 6.0, 7.0]
    ctx = mon.MonitorContext()
    monitor = mon.Monitor([], context=ctx)

    monitor._on_iteration()
    self.assertEqual(monitor._context.total_time, 2.5)
    self.assertEqual(monitor._context.optimisation_time, 2.5)

    monitor._on_iteration()
    self.assertEqual(monitor._context.total_time, 5.0)
    self.assertEqual(monitor._context.optimisation_time, 4.5)
def test_callback(self):
    """Checks that CallbackTask invokes its wrapped callback exactly once per call."""
    callback = mock.MagicMock()
    task = mon.CallbackTask(callback)
    task(mon.MonitorContext())
    self.assertEqual(callback.call_count, 1)