    def test_writer(self, mock_is_primary_func: mock.MagicMock) -> None:
        """
        Tests that the tensorboard writer calls SummaryWriter with the model
        iff is_primary() is True.
        """
        mock_summary_writer = mock.create_autospec(SummaryWriter, instance=True)

        task = get_test_classy_task()
        task.prepare()

        for primary in [False, True]:
            mock_is_primary_func.return_value = primary
            model_configs = get_test_model_configs()

            for model_config in model_configs:
                model = build_model(model_config)
                task.base_model = model

                # create a model tensorboard hook
                model_tensorboard_hook = ModelTensorboardHook(mock_summary_writer)

                model_tensorboard_hook.on_start(task)

                if primary:
                    # add_graph should be called once with the model as the
                    # first argument
                    mock_summary_writer.add_graph.assert_called_once()
                    self.assertEqual(
                        mock_summary_writer.add_graph.call_args[0][0], model
                    )
                else:
                    # add_graph shouldn't be called since is_primary() is False
                    mock_summary_writer.add_graph.assert_not_called()
                mock_summary_writer.reset_mock()
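
# Hedged usage sketch: outside the test, the hook is normally attached to a
# task so that on_start() fires when training begins. The log_dir value and
# the task.set_hooks() call are assumptions based on the public TensorBoard
# and Classy Vision APIs, not something this test exercises.
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="runs/example")  # hypothetical log directory
task = get_test_classy_task()
task.set_hooks([ModelTensorboardHook(writer)])
task.prepare()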
    def test_get_model_dummy_input(self):
        for config in get_test_model_configs():
            # pass in a dummy model for the cuda check
            model = build_model(config)
            batchsize = 8

            # input_key is a list
            input_key = ["audio", "video"]
            input_shape = [[3, 40, 100], [4, 16, 223, 223]]  # dummy input shapes
            result = util.get_model_dummy_input(model, input_shape, input_key, batchsize)
            self.assertEqual(result.keys(), {"audio", "video"})
            for i in range(len(input_key)):
                self.assertEqual(
                    result[input_key[i]].size(), tuple([batchsize] + input_shape[i])
                )

            # input_key is a string
            input_key = "video"
            input_shape = [4, 16, 223, 223]
            result = util.get_model_dummy_input(model, input_shape, input_key, batchsize)
            self.assertEqual(result.keys(), {"video"})
            self.assertEqual(result[input_key].size(), tuple([batchsize] + input_shape))

            # input_key is None
            input_key = None
            input_shape = [4, 16, 223, 223]
            result = util.get_model_dummy_input(model, input_shape, input_key, batchsize)
            self.assertEqual(result.size(), tuple([batchsize] + input_shape))
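
# A minimal sketch (not the library implementation) of the contract the
# assertions above imply for util.get_model_dummy_input: a dict of zero
# tensors keyed by input_key when input_key is a list or a string, and a
# bare tensor when input_key is None. The real helper also uses `model`
# (per the comment above, for a cuda/device check), which this sketch ignores.
import torch

def dummy_input_sketch(model, input_shape, input_key, batchsize):
    def make(shape):
        return torch.zeros([batchsize] + shape)

    if input_key is None:
        return make(input_shape)
    if isinstance(input_key, str):
        return {input_key: make(input_shape)}
    return {key: make(shape) for key, shape in zip(input_key, input_shape)}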
    def test_model_complexity_hook(self) -> None:
        model_configs = get_test_model_configs()

        task = get_test_classy_task()
        task.prepare()

        # create a model complexity hook
        model_complexity_hook = ModelComplexityHook()

        for model_config in model_configs:
            model = build_model(model_config)

            task.base_model = model

            with self.assertLogs():
                model_complexity_hook.on_start(task)
    def test_model_complexity(self) -> None:
        """
        Test that the number of parameters and the FLOPs are calculated correctly.
        """
        model_configs = get_test_model_configs()
        expected_mega_flops = [4122, 4274, 106152]
        expected_params = [25557032, 25028904, 43009448]
        local_variables = {}

        task = get_test_classy_task()
        task.prepare()

        # create a model complexity hook
        model_complexity_hook = ModelComplexityHook()

        for model_config, mega_flops, params in zip(model_configs,
                                                    expected_mega_flops,
                                                    expected_params):
            model = build_model(model_config)

            task.base_model = model

            with self.assertLogs() as log_watcher:
                model_complexity_hook.on_start(task, local_variables)

            # there should be 2 log statements generated
            self.assertEqual(len(log_watcher.output), 2)

            # first statement - either the MFLOPs or a warning
            if mega_flops is not None:
                match = re.search(
                    r"FLOPs for forward pass: (?P<mega_flops>[-+]?\d*\.\d+|\d+) MFLOPs",
                    log_watcher.output[0],
                )
                self.assertIsNotNone(match)
                self.assertEqual(mega_flops, float(match.group("mega_flops")))
            else:
                self.assertIn("Model contains unsupported modules",
                              log_watcher.output[0])

            # second statement
            match = re.search(
                r"Number of parameters in model: (?P<params>[-+]?\d*\.\d+|\d+)",
                log_watcher.output[1],
            )
            self.assertIsNotNone(match)
            self.assertEqual(params, float(match.group("params")))
# Example #5
    def test_complexity_calculation_resnext(self) -> None:
        model_configs = get_test_model_configs()
        # make sure there are three configs returned
        self.assertEqual(len(model_configs), 3)

        # expected values, tested only at the 10^6 scale to allow minor
        # deviations as the models change
        expected_m_flops = [4122, 7850, 8034]
        expected_m_params = [25, 44, 44]
        expected_m_activations = [11, 16, 21]

        for model_config, m_flops, m_params, m_activations in zip(
            model_configs, expected_m_flops, expected_m_params, expected_m_activations
        ):
            model = build_model(model_config)
            self.assertEqual(compute_activations(model) // 10 ** 6, m_activations)
            self.assertEqual(compute_flops(model) // 10 ** 6, m_flops)
            self.assertEqual(count_params(model) // 10 ** 6, m_params)
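
# Hedged usage sketch: the same profiler helpers can be pointed at any single
# model to report complexity at the 10^6 scale used above. `model` is assumed
# to be a ClassyModel built via build_model(), as in the test.
def profile_model(model):
    print(f"MFLOPs:        {compute_flops(model) // 10 ** 6}")
    print(f"M params:      {count_params(model) // 10 ** 6}")
    print(f"M activations: {compute_activations(model) // 10 ** 6}")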
# Example #6
class TestClassyModel(unittest.TestCase):
    model_configs = get_test_model_configs()

    def _get_config(self, model_config):
        return {
            "name": "classification_task",
            "num_epochs": 12,
            "loss": {"name": "test_loss"},
            "dataset": {
                "name": "imagenet",
                "batchsize_per_replica": 8,
                "use_pairs": False,
                "num_samples_per_phase": None,
                "use_shuffle": {"train": True, "test": False},
            },
            "meters": [],
            "model": model_config,
            "optimizer": {"name": "test_opt"},
        }

    def _compare_model_state(self, state, state2):
        compare_model_state(self, state, state2)

    def test_build_model(self):
        for cfg in self.model_configs:
            config = self._get_config(cfg)
            model = build_model(config["model"])
            self.assertIsInstance(model, ClassyModel)
            self.assertTrue(
                type(model.input_shape) == tuple and len(model.input_shape) == 3
            )
            self.assertTrue(
                type(model.output_shape) == tuple and len(model.output_shape) == 2
            )
            self.assertTrue(type(model.model_depth) == int)

    def test_get_set_state(self):
        config = self._get_config(self.model_configs[0])
        model = build_model(config["model"])
        # use a fixed random input; torch.Tensor(...) is uninitialized and
        # could contain NaNs, which would make the allclose comparison flaky
        fake_input = torch.randn(1, 3, 224, 224)
        model.eval()
        state = model.get_classy_state()
        with torch.no_grad():
            output = model(fake_input)

        model2 = build_model(config["model"])
        model2.set_classy_state(state)

        # compare the states
        state2 = model2.get_classy_state()
        self._compare_model_state(state, state2)

        model2.eval()
        with torch.no_grad():
            output2 = model2(fake_input)
        self.assertTrue(torch.allclose(output, output2))

        # test deep_copy by assigning a deep copied state to model2
        # and then changing the original model's state
        state = model.get_classy_state(deep_copy=True)

        model3 = build_model(config["model"])
        state3 = model3.get_classy_state()

        # load the deep-copied state into model2, then overwrite model's own
        # state; the copy held by model2 should be unaffected
        model2.set_classy_state(state)
        model.set_classy_state(state3)

        # compare the states
        state2 = model2.get_classy_state()
        self._compare_model_state(state, state2)

    def test_get_set_head_states(self):
        config = copy.deepcopy(self._get_config(self.model_configs[0]))
        head_configs = config["model"]["heads"]
        config["model"]["heads"] = []
        model = build_model(config["model"])
        trunk_state = model.get_classy_state()

        heads = defaultdict(dict)
        for head_config in head_configs:
            head = build_head(head_config)
            heads[head_config["fork_block"]][head.unique_id] = head
        model.set_heads(heads)
        model_state = model.get_classy_state()

        # the heads should be the same ones we set
        self.assertEqual(len(heads), len(model.get_heads()))
        for block_name, hs in model.get_heads().items():
            self.assertEqual(hs, heads[block_name])

        model._clear_heads()
        self._compare_model_state(model.get_classy_state(), trunk_state)

        model.set_heads(heads)
        self._compare_model_state(model.get_classy_state(), model_state)
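
# Hedged recap of the head-attachment pattern exercised above: heads are built
# from config, grouped by the trunk block they fork from, and attached with
# set_heads(). The literal head_config below is hypothetical; real configs
# come from the model config's "heads" list, as in test_get_set_head_states.
# `model` is assumed to be a trunk-only ClassyModel built with build_model().
from collections import defaultdict

head_config = {
    "name": "fully_connected",  # assumed registered head name
    "unique_id": "default_head",
    "fork_block": "block3-2",   # hypothetical trunk block
    "num_classes": 1000,
    "in_plane": 2048,
}
heads = defaultdict(dict)
head = build_head(head_config)
heads[head_config["fork_block"]][head.unique_id] = head
model.set_heads(heads)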