    def test_writer(self, mock_is_primary_func: mock.MagicMock) -> None:
        """
        Tests that the tensorboard hook writes the model graph to the
        SummaryWriter iff is_primary() is True.
        """
        mock_summary_writer = mock.create_autospec(SummaryWriter,
                                                   instance=True)

        task = get_test_classy_task()
        task.prepare()

        for primary in [False, True]:
            mock_is_primary_func.return_value = primary
            model_configs = get_test_model_configs()

            for model_config in model_configs:
                model = build_model(model_config)
                task.base_model = model

                # create a model tensorboard hook
                model_tensorboard_hook = ModelTensorboardHook(
                    mock_summary_writer)

                model_tensorboard_hook.on_start(task)

                if primary:
                    # add_graph should be called once with the model as the
                    # first arg
                    mock_summary_writer.add_graph.assert_called_once()
                    self.assertEqual(
                        mock_summary_writer.add_graph.call_args[0][0], model)
                else:
                    # add_graph shouldn't be called since is_primary() is False
                    mock_summary_writer.add_graph.assert_not_called()
                mock_summary_writer.reset_mock()
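
    # For reference, a minimal sketch of the contract asserted above (an
    # assumed shape, not ModelTensorboardHook's actual implementation): the
    # hook writes the model graph via SummaryWriter.add_graph, but only on
    # the primary replica. is_primary_func stands in for the function this
    # test patches.
    @staticmethod
    def _sketch_tensorboard_on_start(writer, task, is_primary_func) -> None:
        # log the graph of the task's model on the primary replica only
        if is_primary_func():
            writer.add_graph(task.base_model)
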
    def test_profiler(
        self,
        mock_summarize_profiler_info: mock.MagicMock,
        mock_profile_cls: mock.MagicMock,
    ) -> None:
        """
        Tests that a profile instance is returned by the profiler
        and that the profiler actually ran.
        """
        mock_summarize_profiler_info.return_value = ""

        mock_profile = mock.MagicMock()
        mock_profile_returned = mock.MagicMock()
        mock_profile.__enter__.return_value = mock_profile_returned
        mock_profile_cls.return_value = mock_profile

        for task in [get_test_classy_task(), get_test_classy_video_task()]:
            task.prepare()

            # create a profiler hook
            profiler_hook = ProfilerHook()

            with self.assertLogs():
                profiler_hook.on_start(task)

            # a new profile should be created with use_cuda=True
            mock_profile_cls.assert_called_once_with(use_cuda=True)
            mock_profile_cls.reset_mock()

            # summarize_profiler_info should have been called once with the profile
            mock_summarize_profiler_info.assert_called_once()
            profile = mock_summarize_profiler_info.call_args[0][0]
            mock_summarize_profiler_info.reset_mock()
            self.assertEqual(profile, mock_profile_returned)
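
    # A minimal sketch of the flow the assertions above imply (an assumed
    # shape, not the real ProfilerHook): run a forward pass inside
    # torch.autograd.profiler.profile(use_cuda=True) and hand the resulting
    # profile to the summarizer for logging. forward_pass_fn and summarize_fn
    # are hypothetical stand-ins for however the hook drives the model and
    # for the patched summarize_profiler_info.
    @staticmethod
    def _sketch_profiler_on_start(task, forward_pass_fn, summarize_fn) -> None:
        import torch

        # the profile object yielded by __enter__ is what gets summarized,
        # matching the mock_profile_returned assertion above
        with torch.autograd.profiler.profile(use_cuda=True) as prof:
            forward_pass_fn(task)
        logging.info(summarize_fn(prof))
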
    def test_model_complexity_hook(self) -> None:
        """
        Tests that the model complexity hook runs and logs for each test
        model without raising.
        """
        model_configs = get_test_model_configs()

        task = get_test_classy_task()
        task.prepare()

        # create a model complexity hook
        model_complexity_hook = ModelComplexityHook()

        for model_config in model_configs:
            model = build_model(model_config)

            task.base_model = model

            with self.assertLogs():
                model_complexity_hook.on_start(task)
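
    # A minimal sketch of what on_start() presumably logs here (an assumed
    # shape, not the real ModelComplexityHook): the parameter count is the
    # total number of elements across the model's parameter tensors; the
    # FLOPs computation is omitted since it depends on the hook's internal
    # per-module counters.
    @staticmethod
    def _sketch_complexity_on_start(task) -> None:
        num_params = sum(p.numel() for p in task.base_model.parameters())
        logging.info("Number of parameters in model: %d" % num_params)
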
    def test_model_complexity(self) -> None:
        """
        Test that the number of parameters and the FLOPs are calculated correctly.
        """
        model_configs = get_test_model_configs()
        expected_mega_flops = [4122, 4274, 106152]
        expected_params = [25557032, 25028904, 43009448]

        task = get_test_classy_task()
        task.prepare()

        # create a model complexity hook
        model_complexity_hook = ModelComplexityHook()

        for model_config, mega_flops, params in zip(model_configs,
                                                    expected_mega_flops,
                                                    expected_params):
            model = build_model(model_config)

            task.base_model = model

            with self.assertLogs() as log_watcher:
                model_complexity_hook.on_start(task)

            # there should be 2 log statements generated
            self.assertEqual(len(log_watcher.output), 2)

            # first statement - either the MFLOPs or a warning
            if mega_flops is not None:
                match = re.search(
                    r"FLOPs for forward pass: (?P<mega_flops>[-+]?\d*\.\d+|\d+) MFLOPs",
                    log_watcher.output[0],
                )
                self.assertIsNotNone(match)
                self.assertEqual(mega_flops, float(match.group("mega_flops")))
            else:
                self.assertIn("Model contains unsupported modules",
                              log_watcher.output[0])

            # second statement
            match = re.search(
                r"Number of parameters in model: (?P<params>[-+]?\d*\.\d+|\d+)",
                log_watcher.output[1],
            )
            self.assertIsNotNone(match)
            self.assertEqual(params, float(match.group("params")))
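
    # The parameter regex above can be sanity-checked in isolation; this
    # helper mirrors the pattern used in the assertions (the log-line shape
    # is taken from the test itself, not from the hook's source):
    @staticmethod
    def _parse_param_count(log_line: str) -> float:
        match = re.search(
            r"Number of parameters in model: (?P<params>[-+]?\d*\.\d+|\d+)",
            log_line,
        )
        return float(match.group("params"))

    # e.g. _parse_param_count("INFO:root:Number of parameters in model: "
    #                         "25557032") returns 25557032.0
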
    def test_progress_bar(self, mock_is_master: mock.MagicMock,
                          mock_progressbar_pkg: mock.MagicMock) -> None:
        """
        Tests that the progress bar is created, updated and destroyed correctly.
        """
        mock_progress_bar = mock.create_autospec(progressbar.ProgressBar,
                                                 instance=True)
        mock_progressbar_pkg.ProgressBar.return_value = mock_progress_bar

        mock_is_master.return_value = True

        task = get_test_classy_task()
        task.prepare()
        task.advance_phase()

        num_batches = task.num_batches_per_phase
        # make sure we are checking at least one batch
        self.assertGreater(num_batches, 0)

        # create a progress bar hook
        progress_bar_hook = ProgressBarHook()

        # progressbar.ProgressBar should be init-ed with num_batches
        progress_bar_hook.on_phase_start(task)
        mock_progressbar_pkg.ProgressBar.assert_called_once_with(num_batches)
        mock_progress_bar.start.assert_called_once_with()
        mock_progress_bar.start.reset_mock()
        mock_progressbar_pkg.ProgressBar.reset_mock()

        # on_step should update the progress bar correctly
        for i in range(num_batches):
            progress_bar_hook.on_step(task)
            mock_progress_bar.update.assert_called_once_with(i + 1)
            mock_progress_bar.update.reset_mock()

        # check that even if on_step is called again, the progress bar is
        # only updated with num_batches
        for _ in range(num_batches):
            progress_bar_hook.on_step(task)
            mock_progress_bar.update.assert_called_once_with(num_batches)
            mock_progress_bar.update.reset_mock()

        # finish should be called on the progress bar
        progress_bar_hook.on_phase_end(task)
        mock_progress_bar.finish.assert_called_once_with()
        mock_progress_bar.finish.reset_mock()

        # check that even if the progress bar isn't created, the code doesn't
        # crash
        progress_bar_hook = ProgressBarHook()
        try:
            progress_bar_hook.on_step(task)
            progress_bar_hook.on_phase_end(task)
        except Exception as e:
            self.fail(
                "Received Exception when on_phase_start() isn't called: {}".
                format(e))
        mock_progressbar_pkg.ProgressBar.assert_not_called()

        # check that a progress bar is not created if is_master() returns False
        mock_is_master.return_value = False
        progress_bar_hook = ProgressBarHook()
        try:
            progress_bar_hook.on_phase_start(task)
            progress_bar_hook.on_step(task)
            progress_bar_hook.on_phase_end(task)
        except Exception as e:
            self.fail(
                "Received Exception when is_master() is False: {}".format(e))
        self.assertIsNone(progress_bar_hook.progress_bar)
        mock_progressbar_pkg.ProgressBar.assert_not_called()
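
    # For reference, the lifecycle this test walks through, as a minimal
    # sketch (an assumed shape, not the real ProgressBarHook): create and
    # start the bar on phase start (primary process only), advance it once
    # per step capped at the batch count, and finish it on phase end.
    # is_master_func stands in for the function this test patches.
    @staticmethod
    def _sketch_progress_bar_lifecycle(task, is_master_func) -> None:
        if not is_master_func():
            return
        num_batches = task.num_batches_per_phase
        bar = progressbar.ProgressBar(num_batches)
        bar.start()
        for i in range(num_batches):
            # update is clamped so extra steps can't overflow the bar
            bar.update(min(i + 1, num_batches))
        bar.finish()
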
    def test_time_metrics(
        self,
        mock_get_rank: mock.MagicMock,
        mock_report_str: mock.MagicMock,
        mock_time: mock.MagicMock,
    ) -> None:
        """
        Tests that the time metrics are computed and logged correctly.
        """
        rank = 5
        mock_get_rank.return_value = rank

        mock_report_str.return_value = ""
        local_variables = {}

        for log_freq, train in product([5, None], [True, False]):
            # create a time metrics hook
            time_metrics_hook = TimeMetricsHook(log_freq=log_freq)

            phase_type = "train" if train else "test"

            task = get_test_classy_task()
            task.prepare()
            task.train = train

            # on_phase_start() should set the start time and perf_stats
            start_time = 1.2
            mock_time.return_value = start_time
            time_metrics_hook.on_phase_start(task, local_variables)
            self.assertEqual(time_metrics_hook.start_time, start_time)
            self.assertIsInstance(local_variables.get("perf_stats"), PerfStats)

            # test that the code doesn't raise an exception if losses is empty
            try:
                time_metrics_hook.on_phase_end(task, local_variables)
            except Exception as e:
                self.fail("Received Exception when losses is []: {}".format(e))

            # check that _log_performance_metrics() is called after on_step()
            # every log_freq batches and after on_phase_end()
            with mock.patch.object(time_metrics_hook,
                                   "_log_performance_metrics") as mock_fn:
                num_batches = 20

                for i in range(num_batches):
                    task.losses = list(range(i))
                    time_metrics_hook.on_step(task, local_variables)
                    if log_freq is not None and i and i % log_freq == 0:
                        mock_fn.assert_called_with(task, local_variables)
                        mock_fn.reset_mock()
                        continue
                    mock_fn.assert_not_called()

                time_metrics_hook.on_phase_end(task, local_variables)
                mock_fn.assert_called_with(task, local_variables)

            task.losses = [0.23, 0.45, 0.34, 0.67]

            end_time = 10.4
            avg_batch_time_ms = 2.3 * 1000
            mock_time.return_value = end_time

            # test _log_performance_metrics()
            with self.assertLogs() as log_watcher:
                time_metrics_hook._log_performance_metrics(
                    task, local_variables)

            # there should be 2 info logs for train and 1 for test
            self.assertEqual(len(log_watcher.output), 2 if train else 1)
            self.assertTrue(
                all(log_record.levelno == logging.INFO
                    for log_record in log_watcher.records))
            match = re.search(
                (r"Average {} batch time \(ms\) for {} batches: "
                 r"(?P<avg_batch_time>[-+]?\d*\.\d+|\d+)").format(
                     phase_type, len(task.losses)),
                log_watcher.output[0],
            )
            self.assertIsNotNone(match)
            self.assertAlmostEqual(avg_batch_time_ms,
                                   float(match.group("avg_batch_time")),
                                   places=4)
            if train:
                self.assertIn(f"Train step time breakdown (rank {rank})",
                              log_watcher.output[1])

            # if on_phase_start() is not called, 2 warnings should be logged
            # create a new time metrics hook
            local_variables = {}
            time_metrics_hook_new = TimeMetricsHook()

            with self.assertLogs() as log_watcher:
                time_metrics_hook_new.on_phase_end(task, local_variables)

            self.assertEqual(len(log_watcher.output), 2)
            self.assertTrue(
                all(log_record.levelno == logging.WARN
                    for log_record in log_watcher.records))
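
    # The expected average batch time above is plain arithmetic over the
    # mocked timestamps: (10.4 - 1.2) seconds spread across the 4 recorded
    # losses gives 2.3 s per batch, i.e. the 2.3 * 1000 ms asserted earlier.
    @staticmethod
    def _avg_batch_time_ms(start_time: float, end_time: float,
                           num_batches: int) -> float:
        return 1000 * (end_time - start_time) / num_batches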