def test_writer(self, mock_is_primary_func: mock.MagicMock) -> None:
    """
    Check that ModelTensorboardHook.on_start() records the model graph through
    the writer's add_graph() only when is_primary() returns True.

    For each primary / non-primary setting and every test model config, a
    fresh hook is built around one shared autospec'd SummaryWriter mock; the
    mock is reset after each model so the call-count checks are per model.
    """
    writer = mock.create_autospec(SummaryWriter, instance=True)

    classy_task = get_test_classy_task()
    classy_task.prepare()

    for is_primary_rank in (False, True):
        mock_is_primary_func.return_value = is_primary_rank

        for config in get_test_model_configs():
            built_model = build_model(config)
            classy_task.base_model = built_model

            hook = ModelTensorboardHook(writer)
            hook.on_start(classy_task)

            add_graph = writer.add_graph
            if not is_primary_rank:
                # a non-primary process must not write the graph at all
                add_graph.assert_not_called()
            else:
                # exactly one add_graph() call, with the model as the first
                # positional argument (checked before reading call_args,
                # which is None when no call happened)
                add_graph.assert_called_once()
                first_arg = add_graph.call_args[0][0]
                self.assertEqual(first_arg, built_model)

            writer.reset_mock()
def test_profiler(
    self,
    mock_summarize_profiler_info: mock.MagicMock,
    mock_profile_cls: mock.MagicMock,
) -> None:
    """
    Check that ProfilerHook.on_start() runs the profiler and passes the
    resulting profile to summarize_profiler_info.

    The mocked profile class returns a context-manager mock whose __enter__
    yields a sentinel; that sentinel is what summarize_profiler_info must
    receive. Exercised for both the default and the video test task.
    """
    mock_summarize_profiler_info.return_value = ""

    profile_ctx = mock.MagicMock()
    entered_profile = mock.MagicMock()
    profile_ctx.__enter__.return_value = entered_profile
    mock_profile_cls.return_value = profile_ctx

    candidate_tasks = (get_test_classy_task(), get_test_classy_video_task())
    for current_task in candidate_tasks:
        current_task.prepare()

        # build a profiler hook and require on_start() to emit log output
        hook = ProfilerHook()
        with self.assertLogs():
            hook.on_start(current_task)

        # exactly one profile object, constructed with use_cuda=True
        mock_profile_cls.assert_called_once_with(use_cuda=True)
        mock_profile_cls.reset_mock()

        # the summarizer is invoked once, with the entered profile first
        mock_summarize_profiler_info.assert_called_once()
        summarized = mock_summarize_profiler_info.call_args[0][0]
        mock_summarize_profiler_info.reset_mock()
        self.assertEqual(summarized, entered_profile)
def test_model_complexity_hook(self) -> None:
    """
    Smoke-test ModelComplexityHook.on_start(): for every test model config,
    attaching the built model to the task and calling on_start() must emit
    at least one log record.
    """
    configs = get_test_model_configs()

    classy_task = get_test_classy_task()
    classy_task.prepare()

    complexity_hook = ModelComplexityHook()

    for config in configs:
        classy_task.base_model = build_model(config)
        with self.assertLogs():
            complexity_hook.on_start(classy_task)
def test_model_complexity(self) -> None:
    """
    Test that the number of parameters and the FLOPs reported by
    ModelComplexityHook.on_start() are calculated correctly.

    Each test model config is paired with its expected MFLOPs and parameter
    count. on_start() must emit exactly two log lines: the first carries the
    MFLOPs (or an unsupported-modules message when the expected MFLOPs is
    None), the second carries the parameter count.
    """
    configs = get_test_model_configs()
    # (expected MFLOPs, expected parameter count), one pair per test config
    expectations = [
        (4122, 25557032),
        (4274, 25028904),
        (106152, 43009448),
    ]
    hook_locals = {}

    classy_task = get_test_classy_task()
    classy_task.prepare()

    complexity_hook = ModelComplexityHook()

    for config, (expected_mflops, expected_param_count) in zip(
        configs, expectations
    ):
        classy_task.base_model = build_model(config)

        with self.assertLogs() as captured:
            complexity_hook.on_start(classy_task, hook_locals)

        # exactly two records: the FLOPs line (or a warning) and the params line
        self.assertEqual(len(captured.output), 2)
        flops_line, params_line = captured.output

        if expected_mflops is None:
            self.assertIn("Model contains unsupported modules", flops_line)
        else:
            flops_match = re.search(
                r"FLOPs for forward pass: (?P<mega_flops>[-+]?\d*\.\d+|\d+) MFLOPs",
                flops_line,
            )
            self.assertIsNotNone(flops_match)
            self.assertEqual(
                expected_mflops, float(flops_match.group("mega_flops"))
            )

        params_match = re.search(
            r"Number of parameters in model: (?P<params>[-+]?\d*\.\d+|\d+)",
            params_line,
        )
        self.assertIsNotNone(params_match)
        self.assertEqual(expected_param_count, float(params_match.group("params")))
def test_progress_bar(
    self, mock_is_master: mock.MagicMock, mock_progressbar_pkg: mock.MagicMock
) -> None:
    """
    Tests that the progress bar is created, updated and destroyed correctly.

    Scenarios covered:
    * on the master process, on_phase_start() creates a ProgressBar sized to
      the phase and starts it, on_step() advances it (and keeps reporting
      num_batches once the end is reached), and on_phase_end() finishes it;
    * calling on_step()/on_phase_end() on a hook whose on_phase_start() was
      never called must not raise and must not create a bar;
    * when is_master() is False, no bar is created and progress_bar stays None.
    """
    bar = mock.create_autospec(progressbar.ProgressBar, instance=True)
    bar_cls = mock_progressbar_pkg.ProgressBar
    bar_cls.return_value = bar

    mock_is_master.return_value = True

    classy_task = get_test_classy_task()
    classy_task.prepare()
    classy_task.advance_phase()

    total = classy_task.num_batches_per_phase
    # guard: the per-step checks below are vacuous unless the phase has batches
    self.assertGreater(total, 0)

    hook = ProgressBarHook()

    # starting the phase builds ProgressBar(num_batches) and starts it
    hook.on_phase_start(classy_task)
    bar_cls.assert_called_once_with(total)
    bar.start.assert_called_once_with()
    bar.start.reset_mock()
    bar_cls.reset_mock()

    # each step updates the bar with the 1-based count of steps taken so far
    for step in range(1, total + 1):
        hook.on_step(classy_task)
        bar.update.assert_called_once_with(step)
        bar.update.reset_mock()

    # extra steps past the end keep reporting num_batches
    for _ in range(total):
        hook.on_step(classy_task)
        bar.update.assert_called_once_with(total)
        bar.update.reset_mock()

    # ending the phase finishes the bar
    hook.on_phase_end(classy_task)
    bar.finish.assert_called_once_with()
    bar.finish.reset_mock()

    # a hook that never saw on_phase_start() must tolerate the other callbacks
    hook = ProgressBarHook()
    try:
        hook.on_step(classy_task)
        hook.on_phase_end(classy_task)
    except Exception as e:
        self.fail(
            "Received Exception when on_phase_start() isn't called: {}".format(e)
        )
    bar_cls.assert_not_called()

    # non-master processes never construct a progress bar
    mock_is_master.return_value = False
    hook = ProgressBarHook()
    try:
        hook.on_phase_start(classy_task)
        hook.on_step(classy_task)
        hook.on_phase_end(classy_task)
    except Exception as e:
        self.fail("Received Exception when is_master() is False: {}".format(e))
    self.assertIsNone(hook.progress_bar)
    bar_cls.assert_not_called()
def test_time_metrics(
    self,
    mock_get_rank: mock.MagicMock,
    mock_report_str: mock.MagicMock,
    mock_time: mock.MagicMock,
) -> None:
    """
    Tests that TimeMetricsHook tracks and reports phase timing correctly.

    For every combination of log_freq (5 or None) and train/test phase:
    * on_phase_start() records the (mocked) start time and stores a PerfStats
      instance under local_variables["perf_stats"];
    * on_phase_end() does not raise when task.losses is empty;
    * on_step() calls _log_performance_metrics() at every non-zero step index
      divisible by log_freq (never when log_freq is None), and on_phase_end()
      always calls it;
    * _log_performance_metrics() logs, at INFO level, the average batch time
      and, for training phases, the per-rank step time breakdown.

    Finally, calling on_phase_end() on a hook whose on_phase_start() was never
    called must log exactly two warnings.
    """
    rank = 5
    mock_get_rank.return_value = rank

    mock_report_str.return_value = ""
    local_variables = {}

    for log_freq, train in product([5, None], [True, False]):
        # create a time metrics hook
        time_metrics_hook = TimeMetricsHook(log_freq=log_freq)

        phase_type = "train" if train else "test"

        task = get_test_classy_task()
        task.prepare()
        task.train = train

        # on_phase_start() should set the start time and perf_stats
        start_time = 1.2
        mock_time.return_value = start_time
        time_metrics_hook.on_phase_start(task, local_variables)
        self.assertEqual(time_metrics_hook.start_time, start_time)
        self.assertIsInstance(local_variables.get("perf_stats"), PerfStats)

        # test that the code doesn't raise an exception if losses is empty
        try:
            time_metrics_hook.on_phase_end(task, local_variables)
        except Exception as e:
            self.fail("Received Exception when losses is []: {}".format(e))

        # check that _log_performance_metrics() is called after on_step()
        # every log_freq batches and after on_phase_end()
        with mock.patch.object(
            time_metrics_hook, "_log_performance_metrics"
        ) as mock_fn:
            num_batches = 20

            for i in range(num_batches):
                task.losses = list(range(i))
                time_metrics_hook.on_step(task, local_variables)
                if log_freq is not None and i and i % log_freq == 0:
                    mock_fn.assert_called_with(task, local_variables)
                    mock_fn.reset_mock()
                    continue
                mock_fn.assert_not_called()

            time_metrics_hook.on_phase_end(task, local_variables)
            mock_fn.assert_called_with(task, local_variables)

        task.losses = [0.23, 0.45, 0.34, 0.67]

        # expected: (10.4 - 1.2) s over 4 batches -> 2.3 s per batch
        end_time = 10.4
        avg_batch_time_ms = 2.3 * 1000
        mock_time.return_value = end_time

        # test _log_performance_metrics()
        with self.assertLogs() as log_watcher:
            time_metrics_hook._log_performance_metrics(task, local_variables)

        # there should be 2 info logs for train and 1 for test
        self.assertEqual(len(log_watcher.output), 2 if train else 1)
        self.assertTrue(
            all(
                log_record.levelno == logging.INFO
                for log_record in log_watcher.records
            )
        )
        match = re.search(
            (
                r"Average {} batch time \(ms\) for {} batches: "
                r"(?P<avg_batch_time>[-+]?\d*\.\d+|\d+)"
            ).format(phase_type, len(task.losses)),
            log_watcher.output[0],
        )
        self.assertIsNotNone(match)
        self.assertAlmostEqual(
            avg_batch_time_ms, float(match.group("avg_batch_time")), places=4
        )
        if train:
            self.assertIn(
                f"Train step time breakdown (rank {rank})", log_watcher.output[1]
            )

    # if on_phase_start() is not called, 2 warnings should be logged
    # create a new time metrics hook with fresh (empty) local variables
    local_variables = {}
    time_metrics_hook_new = TimeMetricsHook()

    with self.assertLogs() as log_watcher:
        time_metrics_hook_new.on_phase_end(task, local_variables)

    self.assertEqual(len(log_watcher.output), 2)
    # logging.WARNING is the documented name; logging.WARN is a legacy alias
    self.assertTrue(
        all(
            log_record.levelno == logging.WARNING
            for log_record in log_watcher.records
        )
    )