Example No. 1
    def test_from_checkpoint(self):
        config = get_test_task_config()
        for use_head in [True, False]:
            config["model"] = self.get_model_config(use_head)
            task = build_task(config)
            task.prepare()

            checkpoint_folder = f"{self.base_dir}/{use_head}/"
            input_args = {"config": config}

            # Simulate training by setting the model parameters to zero
            for param in task.model.parameters():
                param.data.zero_()

            checkpoint_hook = CheckpointHook(
                checkpoint_folder, input_args, phase_types=["train"]
            )

            # Create checkpoint dir, save checkpoint
            os.mkdir(checkpoint_folder)
            checkpoint_hook.on_start(task)

            task.train = True
            checkpoint_hook.on_phase_end(task)

            # Model should be checkpointed. load and compare
            checkpoint = load_checkpoint(checkpoint_folder)

            model = ClassyModel.from_checkpoint(checkpoint)
            self.assertTrue(isinstance(model, MyTestModel))

            # All parameters must be zero
            for param in model.parameters():
                self.assertTrue(torch.all(param.data == 0))
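A minimal stand-alone sketch of the same round trip, reusing only the calls exercised above; the import paths and the save_and_reload helper are assumptions, not part of this test suite.

    # Hypothetical helper; import paths are assumed to mirror this suite's imports.
    import os

    from classy_vision.generic.util import load_checkpoint
    from classy_vision.hooks import CheckpointHook
    from classy_vision.models import ClassyModel


    def save_and_reload(task, config, checkpoint_folder):
        """Checkpoint a prepared task and rebuild its model from disk."""
        os.makedirs(checkpoint_folder, exist_ok=True)
        hook = CheckpointHook(checkpoint_folder, {"config": config}, phase_types=["train"])
        hook.on_start(task)
        task.train = True
        hook.on_phase_end(task)  # writes the checkpoint for the train phase
        return ClassyModel.from_checkpoint(load_checkpoint(checkpoint_folder))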
Example No. 2
    def test_logging(self, mock_get_rank: mock.MagicMock) -> None:
        """
        Test that the logging happens as expected and the loss and lr values are
        correct.
        """
        rank = 5
        mock_get_rank.return_value = rank

        # set up the task and state
        config = get_test_task_config()
        config["dataset"]["train"]["batchsize_per_replica"] = 2
        config["dataset"]["test"]["batchsize_per_replica"] = 5
        task = build_task(config)
        task.prepare()

        losses = [1.2, 2.3, 3.4, 4.5]

        local_variables = {}
        task.phase_idx = 0

        for log_freq in [5, None]:
            # create a loss lr meter hook
            loss_lr_meter_hook = LossLrMeterLoggingHook(log_freq=log_freq)

            # check that _log_loss_meters() and _log_lr() are called after
            # on_step() every log_freq batches and again after on_phase_end()
            with mock.patch.object(loss_lr_meter_hook,
                                   "_log_loss_meters") as mock_fn:
                with mock.patch.object(loss_lr_meter_hook,
                                       "_log_lr") as mock_lr_fn:
                    num_batches = 20

                    for i in range(num_batches):
                        task.losses = list(range(i))
                        loss_lr_meter_hook.on_step(task, local_variables)
                        if log_freq is not None and i and i % log_freq == 0:
                            mock_fn.assert_called_with(task, local_variables)
                            mock_fn.reset_mock()
                            mock_lr_fn.assert_called_with(
                                task, local_variables)
                            mock_lr_fn.reset_mock()
                            continue
                        mock_fn.assert_not_called()
                        mock_lr_fn.assert_not_called()

                    loss_lr_meter_hook.on_phase_end(task, local_variables)
                    mock_fn.assert_called_with(task, local_variables)
                    if task.train:
                        mock_lr_fn.assert_called_with(task, local_variables)

            # test _log_loss_meters() and _log_lr() directly
            task.losses = losses

            with self.assertLogs():
                loss_lr_meter_hook._log_loss_meters(task, local_variables)
                loss_lr_meter_hook._log_lr(task, local_variables)

            task.phase_idx += 1
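The schedule that the mock assertions encode is just index arithmetic: with log_freq=5 and 20 batches, logging fires on steps 5, 10, and 15, plus once more at phase end. A plain-Python restatement (no ClassyVision API involved):

    log_freq = 5
    num_batches = 20
    # on_step() logs when log_freq is set, i > 0, and i is a multiple of log_freq
    logging_steps = [
        i for i in range(num_batches)
        if log_freq is not None and i and i % log_freq == 0
    ]
    assert logging_steps == [5, 10, 15]  # on_phase_end() always logs once more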
Example No. 3
    def test_streaming_dataset(self):
        """
        Test that streaming datasets return the correct number of batches, and that
        the length is also calculated correctly.
        """
        config = get_test_task_config()
        dataset_config = {
            "name": "synthetic_image_streaming",
            "split": "train",
            "crop_size": 224,
            "class_ratio": 0.5,
            "num_samples": 2000,
            "length": 4000,
            "seed": 0,
            "batchsize_per_replica": 32,
            "use_shuffle": True,
        }
        expected_batches = 62
        config["dataset"]["train"] = dataset_config
        task = build_task(config)
        task.prepare()
        task.advance_phase()
        # test that the number of batches expected is correct
        self.assertEqual(task.num_batches_per_phase, expected_batches)

        # test that the data iterator returns the expected number of batches
        data_iterator = task.data_iterator
        self._test_number_of_batches(data_iterator, expected_batches)

        # test that the dataloader can be rebuilt
        task.build_dataloaders_for_current_phase()
        task.create_data_iterators()
        data_iterator = task.data_iterator
        self._test_number_of_batches(data_iterator, expected_batches)
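The expected_batches value follows directly from the config: 2000 samples at a per-replica batch size of 32 yield 62 full batches, assuming the loader drops the trailing partial batch of 16 samples.

    num_samples = 2000
    batchsize_per_replica = 32
    expected_batches = num_samples // batchsize_per_replica  # 62; remainder of 16 is dropped
    assert expected_batches == 62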
Example No. 4
    def test_torchscripting_using_trace(self):
        """
        Test that the save_torchscript function works as expected with trace
        """
        config = get_test_task_config()
        torchscript_folder = self.base_dir + "/torchscript_end_test/"

        # create a torchscript hook using trace
        torchscript_hook = TorchscriptHook(torchscript_folder)
        self.execute_hook(config, torchscript_folder, torchscript_hook)
Example No. 5
    def test_torchscripting_using_script(self):
        """
        Test that the save_torchscript function works as expected with script
        """
        config = get_test_task_config()
        # Setting wrapper_cls to None to make ResNet model torchscriptable
        ResNet.wrapper_cls = None
        torchscript_folder = self.base_dir + "/torchscript_end_test/"

        # create a torchscript hook using script
        torchscript_hook = TorchscriptHook(torchscript_folder, use_trace=False)
        self.execute_hook(config, torchscript_folder, torchscript_hook)
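The two tests above differ only in how the model is converted: the hook defaults to torch.jit.trace, while use_trace=False switches it to torch.jit.script, which is why ResNet.wrapper_cls is cleared first to keep the model scriptable. A hedged sketch of the two constructions:

    # Both hooks write a TorchScript artifact into the given folder;
    # use_trace is the only knob the two tests above exercise.
    torchscript_folder = "/tmp/torchscript_example/"  # hypothetical output folder
    trace_hook = TorchscriptHook(torchscript_folder)                    # tracing (default)
    script_hook = TorchscriptHook(torchscript_folder, use_trace=False)  # scripting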
Example No. 6
    def test_build_task(self):
        config = get_test_task_config()
        task = build_task(config)
        self.assertTrue(isinstance(task, ClassificationTask))
        # check that AMP is disabled by default
        self.assertIsNone(task.amp_args)

        # test a valid AMP opt level
        config = copy.deepcopy(config)
        config["amp_args"] = {"opt_level": "O1"}
        task = build_task(config)
        self.assertTrue(isinstance(task, ClassificationTask))
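Enabling AMP is plain dictionary surgery on the task config; "O1" is one of the apex opt levels the builder accepts, and an invalid level (as a later example shows) makes build_task raise.

    import copy

    amp_config = copy.deepcopy(config)           # config and build_task as above
    amp_config["amp_args"] = {"opt_level": "O1"}  # mixed precision via apex
    amp_task = build_task(amp_config)
    assert amp_task.amp_args is not None  # populated, unlike the default-None case above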
Example No. 7
    def test_state_checkpointing(self) -> None:
        """
        Test that the state gets checkpointed without any errors, but only on the
        right phase_type and only if the checkpoint directory exists.
        """
        config = get_test_task_config()
        task = build_task(config)
        task.prepare()

        local_variables = {}
        checkpoint_folder = self.base_dir + "/checkpoint_end_test/"
        input_args = {"foo": "bar"}

        # create a checkpoint hook
        checkpoint_hook = CheckpointHook(checkpoint_folder,
                                         input_args,
                                         phase_types=["train"])

        # checkpoint directory doesn't exist
        # call the on start function
        with self.assertRaises(FileNotFoundError):
            checkpoint_hook.on_start(task)
        # call the on end phase function
        with self.assertRaises(AssertionError):
            checkpoint_hook.on_phase_end(task, local_variables)
        # try loading a non-existent checkpoint
        checkpoint = load_checkpoint(checkpoint_folder)
        self.assertIsNone(checkpoint)

        # create checkpoint dir, verify on_start hook runs
        os.mkdir(checkpoint_folder)
        checkpoint_hook.on_start(task)

        # Phase_type is test, expect no checkpoint
        task.train = False
        # call the on end phase function
        checkpoint_hook.on_phase_end(task, local_variables)
        checkpoint = load_checkpoint(checkpoint_folder)
        self.assertIsNone(checkpoint)

        task.train = True
        # call the on end phase function
        checkpoint_hook.on_phase_end(task, local_variables)
        # model should be checkpointed. load and compare
        checkpoint = load_checkpoint(checkpoint_folder)
        self.assertIsNotNone(checkpoint)
        for key in ["input_args", "classy_state_dict"]:
            self.assertIn(key, checkpoint)
        # not testing for equality of classy_state_dict, that is tested in
        # a separate test
        self.assertDictEqual(checkpoint["input_args"], input_args)
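As the final assertions show, the loaded checkpoint is a plain dict carrying at least the input_args given to the hook and the serialized task state under classy_state_dict; a short sketch of inspecting one:

    checkpoint = load_checkpoint(checkpoint_folder)
    if checkpoint is not None:
        print(sorted(checkpoint.keys()))  # contains "classy_state_dict" and "input_args"
        print(checkpoint["input_args"])   # {"foo": "bar"} for this test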
Example No. 8
    def test_checkpoint_period(self) -> None:
        """
        Test that the checkpoint_period works as expected.
        """
        config = get_test_task_config()
        task = build_task(config)
        task.prepare()

        local_variables = {}
        checkpoint_folder = self.base_dir + "/checkpoint_end_test/"
        checkpoint_period = 10

        for phase_types in [["train"], ["train", "test"]]:
            # create a checkpoint hook
            checkpoint_hook = CheckpointHook(
                checkpoint_folder,
                {},
                phase_types=phase_types,
                checkpoint_period=checkpoint_period,
            )

            # create checkpoint dir
            os.mkdir(checkpoint_folder)

            # call the on start function
            checkpoint_hook.on_start(task)

            # shouldn't create any checkpoints until there are checkpoint_period
            # phases which are in phase_types
            count = 0
            valid_phase_count = 0
            while valid_phase_count < checkpoint_period - 1:
                task.train = count % 2 == 0
                # call the on end phase function
                checkpoint_hook.on_phase_end(task, local_variables)
                checkpoint = load_checkpoint(checkpoint_folder)
                self.assertIsNone(checkpoint)
                valid_phase_count += 1 if task.phase_type in phase_types else 0
                count += 1

            # create a phase which is in phase_types
            task.train = True
            # call the on end phase function
            checkpoint_hook.on_phase_end(task, local_variables)
            # model should be checkpointed. load and compare
            checkpoint = load_checkpoint(checkpoint_folder)
            self.assertIsNotNone(checkpoint)
            # delete the checkpoint dir
            shutil.rmtree(checkpoint_folder)
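The counting rule the loop verifies is simple: only phase ends whose phase_type is listed in phase_types advance the period counter, and a checkpoint is written on every checkpoint_period-th qualifying phase. A plain-Python restatement:

    checkpoint_period = 10
    phase_types = ["train"]
    phases = ["train" if i % 2 == 0 else "test" for i in range(40)]  # alternate, as above

    qualifying = 0
    checkpoints_written = 0
    for phase_type in phases:
        if phase_type in phase_types:
            qualifying += 1
            if qualifying % checkpoint_period == 0:
                checkpoints_written += 1
    assert checkpoints_written == 2  # 20 qualifying train phases // 10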
Example No. 9
    def test_get_state(self):
        config = get_test_task_config()
        loss = build_loss(config["loss"])
        task = (
            ClassificationTask().set_num_epochs(1).set_loss(loss).set_model(
                build_model(config["model"])).set_optimizer(
                    build_optimizer(config["optimizer"])))
        for phase_type in ["train", "test"]:
            dataset = build_dataset(config["dataset"][phase_type])
            task.set_dataset(dataset, phase_type)

        task.prepare()

        task = build_task(config)
        task.prepare()
Example No. 10
    def test_failure(self) -> None:
        self.assertFalse(PathManager.exists("test://foo"))
        PathManager.register_handler(TestPathHandler())
        # make sure that TestPathHandler is being used
        self.assertTrue(PathManager.exists("test://foo"))

        checkpoint_folder = "test://root"

        checkpoint_hook = CheckpointHook(checkpoint_folder, {}, phase_types=["train"])
        config = get_test_task_config()
        task = build_task(config)
        task.prepare()

        # we should raise an exception while trying to save the checkpoint
        with self.assertRaises(TestException):
            checkpoint_hook.on_phase_end(task)
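TestPathHandler and TestException are helpers from this test suite; the sketch below shows the general wiring with a hypothetical handler, assuming the fvcore-style PathHandler API (the import path and protected method names are assumptions):

    from fvcore.common.file_io import PathHandler, PathManager


    class FailingPathHandler(PathHandler):
        """Hypothetical handler that claims the test:// prefix and fails on open."""

        def _get_supported_prefixes(self):
            return ["test://"]

        def _exists(self, path, **kwargs):
            return True

        def _open(self, path, mode="r", **kwargs):
            raise OSError(f"simulated failure opening {path}")


    PathManager.register_handler(FailingPathHandler())
    assert PathManager.exists("test://foo")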
Example No. 11
    def test_build_task(self):
        config = get_test_task_config()
        task = build_task(config)
        self.assertTrue(isinstance(task, ClassificationTask))
        # check that AMP is disabled by default
        self.assertIsNone(task.amp_opt_level)

        # test a valid AMP opt level
        config = copy.deepcopy(config)
        config["amp_opt_level"] = "O1"
        task = build_task(config)
        self.assertTrue(isinstance(task, ClassificationTask))

        # test an invalid AMP opt level
        config = copy.deepcopy(config)
        config["amp_opt_level"] = "O5"
        with self.assertRaises(Exception):
            task = build_task(config)
Example No. 12
    def test_from_task(self):
        config = get_test_task_config()
        task = build_task(config)
        hub_interface = ClassyHubInterface.from_task(task)

        self.assertIsInstance(hub_interface.task, ClassyTask)
        self.assertIsInstance(hub_interface.model, ClassyModel)

        # this will pick up the transform from the task's config
        self._test_predict_and_extract_features(hub_interface)

        # test that the correct transform is picked up
        phase_type = "test"
        test_transform = TestTransform()
        task.datasets[phase_type].transform = test_transform
        hub_interface = ClassyHubInterface.from_task(task)
        dataset = hub_interface.create_image_dataset(
            image_files=[self.image_path], phase_type=phase_type)
        self.assertIsInstance(dataset.transform, TestTransform)
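A hedged sketch of the same flow outside a test; the classy_vision.hub import path is an assumption, and image_paths stands in for a list of local image files:

    from classy_vision.hub import ClassyHubInterface

    hub_interface = ClassyHubInterface.from_task(task)
    dataset = hub_interface.create_image_dataset(
        image_files=image_paths, phase_type="test"
    )
    # the dataset picks up whatever transform the task registered for the "test" phase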
Example No. 13
    def test_streaming_dataset_async(self):
        """
        Test that streaming datasets with async GPU copies return the correct
        number of batches, and that the length is also calculated correctly.
        """

        if not torch.cuda.is_available():
            # async GPU copies require CUDA; skip on CPU-only machines
            return

        config = get_test_task_config()
        dataset_config = {
            "name": "synthetic_image_streaming",
            "split": "train",
            "crop_size": 224,
            "class_ratio": 0.5,
            "num_samples": 2000,
            "length": 4000,
            "seed": 0,
            "batchsize_per_replica": 32,
            "use_shuffle": True,
            "async_gpu_copy": True,
        }
        expected_batches = 62
        config["dataset"]["train"] = dataset_config
        task = build_task(config)
        task.prepare()
        task.advance_phase()

        # test that the number of batches expected is correct
        self.assertEqual(task.num_batches_per_phase, expected_batches)

        # test that the data iterator returns the expected number of batches
        data_iterator = task.get_data_iterator()
        self._test_number_of_batches(data_iterator, expected_batches)

        # test that the dataloader can be rebuilt from the dataset inside it
        task._recreate_data_loader_from_dataset()
        task.create_data_iterator()
        data_iterator = task.get_data_iterator()
        self._test_number_of_batches(data_iterator, expected_batches)
Example No. 14
    def test_torchscripting(self):
        """
        Test that the save_torchscript function works as expected.
        """
        config = get_test_task_config()
        task = build_task(config)
        task.prepare()

        torchscript_folder = self.base_dir + "/torchscript_end_test/"

        # create a torchscript hook
        torchscript_hook = TorchscriptHook(torchscript_folder)

        # create checkpoint dir, verify on_start hook runs
        os.mkdir(torchscript_folder)
        torchscript_hook.on_start(task)

        task.train = True
        # call the on end function
        torchscript_hook.on_end(task)

        # load torchscript file
        torchscript_file_name = (
            f"{torchscript_hook.torchscript_folder}/{TORCHSCRIPT_FILE}")
        torchscript = torch.jit.load(torchscript_file_name)
        # compare the original model's outputs with the torchscripted model's
        with torch.no_grad():
            batchsize = 1
            model = task.model
            input_data = torch.randn((batchsize, ) + model.input_shape,
                                     dtype=torch.float)
            if torch.cuda.is_available():
                input_data = input_data.cuda()
            checkpoint_out = model(input_data)
            torchscript_out = torchscript(input_data)
            self.assertTrue(torch.allclose(checkpoint_out, torchscript_out))
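Consuming the artifact the hook wrote needs nothing from ClassyVision: torch.jit.load returns a module that is called like the original model. The sketch below reuses the names from the test above; TORCHSCRIPT_FILE is the filename constant this test imports alongside the hook.

    scripted = torch.jit.load(f"{torchscript_folder}/{TORCHSCRIPT_FILE}")
    scripted.eval()
    with torch.no_grad():
        # keep the input on the same device the model was scripted on
        prediction = scripted(torch.randn((1,) + model.input_shape))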
Example No. 15
    def test_hooks_config_builds_correctly(self):
        config = get_test_task_config()
        config["hooks"] = [{"name": "loss_lr_meter_logging"}]
        task = build_task(config)
        self.assertTrue(len(task.hooks) == 1)
        self.assertTrue(isinstance(task.hooks[0], LossLrMeterLoggingHook))
Example No. 16
    def test_build_task(self):
        config = get_test_task_config()
        task = build_task(config)
        self.assertTrue(isinstance(task, ClassificationTask))
Example No. 17
    def _get_classy_model(self):
        config = get_test_task_config()
        model_config = config["model"]
        return build_model(model_config)
Example No. 18
    def test_visdom(self, mock_visdom_cls: mock.MagicMock,
                    mock_is_primary: mock.MagicMock) -> None:
        """
        Tests that visdom is populated with plots.
        """
        mock_visdom = mock.create_autospec(Visdom, instance=True)
        mock_visdom_cls.return_value = mock_visdom

        # set up the task and state
        config = get_test_task_config()
        config["dataset"]["train"]["batchsize_per_replica"] = 2
        config["dataset"]["test"]["batchsize_per_replica"] = 5
        task = build_task(config)
        task.prepare()

        losses = [1.2, 2.3, 1.23, 2.33]
        loss_val = sum(losses) / len(losses)

        task.losses = losses

        visdom_server = "localhost"
        visdom_port = 8097

        for master, visdom_conn in product([False, True], [False, True]):
            mock_is_primary.return_value = master
            mock_visdom.check_connection.return_value = visdom_conn

            # create a visdom hook
            visdom_hook = VisdomHook(visdom_server, visdom_port)

            mock_visdom_cls.assert_called_once()
            mock_visdom_cls.reset_mock()

            counts = {"train": 0, "test": 0}
            count = 0

            for phase_idx in range(10):
                train = phase_idx % 2 == 0
                task.train = train
                phase_type = "train" if train else "test"

                counts[phase_type] += 1
                count += 1

                # test that the metrics don't change if losses is empty and that
                # visdom.line() is not called
                task.losses = []
                original_metrics = copy.deepcopy(visdom_hook.metrics)
                visdom_hook.on_phase_end(task)
                self.assertDictEqual(original_metrics, visdom_hook.metrics)
                mock_visdom.line.assert_not_called()

                # test that the metrics are updated correctly when losses
                # is non empty
                task.losses = [loss * count for loss in losses]
                visdom_hook.on_phase_end(task)

                # every meter should be present and should have the correct length
                for meter in task.meters:
                    for key in meter.value:
                        key = phase_type + "_" + meter.name + "_" + key
                        self.assertTrue(
                            key in visdom_hook.metrics
                            and type(visdom_hook.metrics[key]) == list
                            and len(visdom_hook.metrics[key])
                            == counts[phase_type])

                # the loss metric should be calculated correctly
                loss_key = phase_type + "_loss"
                self.assertTrue(loss_key in visdom_hook.metrics
                                and type(visdom_hook.metrics[loss_key]) == list
                                and len(visdom_hook.metrics[loss_key])
                                == counts[phase_type])
                self.assertAlmostEqual(
                    visdom_hook.metrics[loss_key][-1],
                    loss_val * count,
                    places=4,
                )

                # the lr metric should be correct
                lr_key = phase_type + "_learning_rate"
                self.assertTrue(
                    lr_key in visdom_hook.metrics
                    and type(visdom_hook.metrics[lr_key]) == list
                    and len(visdom_hook.metrics[lr_key]) == counts[phase_type])
                self.assertAlmostEqual(
                    visdom_hook.metrics[lr_key][-1],
                    task.optimizer.options_view.lr,
                    places=4,
                )

                if master and not train and visdom_conn:
                    # visdom.line() should be called once
                    mock_visdom.line.assert_called_once()
                    mock_visdom.line.reset_mock()
                else:
                    # visdom.line() should not be called
                    mock_visdom.line.assert_not_called()
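The metric names the assertions rebuild follow a fixed convention: one entry per meter value keyed as <phase_type>_<meter_name>_<key>, plus <phase_type>_loss and <phase_type>_learning_rate. A plain-Python restatement with placeholder meter names:

    phase_type = "train"
    meter_name, meter_key = "accuracy", "top_1"  # placeholders, not real meter values
    visdom_metric_keys = [
        f"{phase_type}_{meter_name}_{meter_key}",  # e.g. train_accuracy_top_1
        f"{phase_type}_loss",
        f"{phase_type}_learning_rate",
    ]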
Example No. 19
    def test_writer(self, mock_is_primary_func: mock.MagicMock) -> None:
        """
        Tests that the tensorboard writer writes the correct scalars to SummaryWriter
        iff is_primary() is True.
        """
        for phase_idx, master in product([0, 1, 2], [True, False]):
            train, phase_type = ((True, "train") if phase_idx % 2 == 0 else
                                 (False, "test"))
            mock_is_primary_func.return_value = master

            # set up the task and state
            config = get_test_task_config()
            config["dataset"]["train"]["batchsize_per_replica"] = 2
            config["dataset"]["test"]["batchsize_per_replica"] = 5
            task = build_task(config)
            task.prepare()
            task.advance_phase()
            task.phase_idx = phase_idx
            task.train = train

            losses = [1.23, 4.45, 12.3, 3.4]
            sample_fetch_times = [1.1, 2.2, 3.3, 2.2]

            summary_writer = SummaryWriter(self.base_dir)
            # create a spy on top of summary_writer
            summary_writer = mock.MagicMock(wraps=summary_writer)

            # create a loss lr tensorboard hook
            tensorboard_plot_hook = TensorboardPlotHook(summary_writer)

            # run the hook in the correct order
            tensorboard_plot_hook.on_phase_start(task)

            # test tasks which do not pass the sample_fetch_times as well
            disable_sample_fetch_times = phase_idx == 0

            for loss, sample_fetch_time in zip(losses, sample_fetch_times):
                task.losses.append(loss)
                step_data = ({} if disable_sample_fetch_times else {
                    "sample_fetch_time": sample_fetch_time
                })
                task.last_batch = LastBatchInfo(None, None, None, None,
                                                step_data)
                tensorboard_plot_hook.on_step(task)

            tensorboard_plot_hook.on_phase_end(task)

            if master:
                # add_scalar() should have been called with the right scalars
                if train:
                    learning_rate_key = f"Learning Rate/{phase_type}"
                    summary_writer.add_scalar.assert_any_call(
                        learning_rate_key,
                        mock.ANY,
                        global_step=mock.ANY,
                        walltime=mock.ANY,
                    )
                avg_loss_key = f"Losses/{phase_type}"
                summary_writer.add_scalar.assert_any_call(avg_loss_key,
                                                          mock.ANY,
                                                          global_step=mock.ANY)
                for meter in task.meters:
                    for name in meter.value:
                        meter_key = f"Meters/{phase_type}/{meter.name}/{name}"
                        summary_writer.add_scalar.assert_any_call(
                            meter_key, mock.ANY, global_step=mock.ANY)
                if step_data:
                    summary_writer.add_scalar.assert_any_call(
                        f"Speed/{phase_type}/cumulative_sample_fetch_time",
                        mock.ANY,
                        global_step=mock.ANY,
                        walltime=mock.ANY,
                    )
            else:
                # add_scalar() shouldn't be called since is_primary() is False
                summary_writer.add_scalar.assert_not_called()
            summary_writer.add_scalar.reset_mock()
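A hedged sketch of driving TensorboardPlotHook with a real SummaryWriter instead of a spy; only the lifecycle calls exercised above are used, and the prepared task and log directory are assumed:

    from torch.utils.tensorboard import SummaryWriter

    writer = SummaryWriter(log_dir="/tmp/classy_tb")  # hypothetical log directory
    hook = TensorboardPlotHook(writer)

    hook.on_phase_start(task)
    for loss in [1.23, 4.45, 12.3, 3.4]:
        task.losses.append(loss)
        hook.on_step(task)
    hook.on_phase_end(task)  # writes Losses/<phase_type>, Meters/..., and speed scalars
    writer.close()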
Example No. 20
    def test_writer(self, mock_is_master_func: mock.MagicMock) -> None:
        """
        Tests that the tensorboard writer writes the correct scalars to SummaryWriter
        iff is_master() is True.
        """
        for phase_idx, master in product([0, 1, 2], [True, False]):
            train, phase_type = ((True, "train") if phase_idx % 2 == 0 else
                                 (False, "test"))
            mock_is_master_func.return_value = master

            # set up the task and state
            config = get_test_task_config()
            config["dataset"]["train"]["batchsize_per_replica"] = 2
            config["dataset"]["test"]["batchsize_per_replica"] = 5
            task = build_task(config)
            task.prepare()
            task.phase_idx = phase_idx
            task.train = train

            losses = [1.23, 4.45, 12.3, 3.4]

            local_variables = {}

            summary_writer = SummaryWriter(self.base_dir)
            # create a spy on top of summary_writer
            summary_writer = mock.MagicMock(wraps=summary_writer)

            # create a loss lr tensorboard hook
            tensorboard_plot_hook = TensorboardPlotHook(summary_writer)

            # test that the hook logs a warning and doesn't write anything to
            # the writer if on_phase_start() is not called for initialization
            # before on_update() is called.
            with self.assertLogs() as log_watcher:
                tensorboard_plot_hook.on_update(task, local_variables)

            self.assertTrue(
                len(log_watcher.records) == 1
                and log_watcher.records[0].levelno == logging.WARN and
                "learning_rates is not initialized" in log_watcher.output[0])

            # test that the hook logs a warning and doesn't write anything to
            # the writer if on_phase_start() is not called for initialization
            # before on_phase_end() is called.
            with self.assertLogs() as log_watcher:
                tensorboard_plot_hook.on_phase_end(task, local_variables)

            self.assertTrue(
                len(log_watcher.records) == 1
                and log_watcher.records[0].levelno == logging.WARN and
                "learning_rates is not initialized" in log_watcher.output[0])
            summary_writer.add_scalar.reset_mock()

            # run the hook in the correct order
            tensorboard_plot_hook.on_phase_start(task, local_variables)

            for loss in losses:
                task.losses.append(loss)
                tensorboard_plot_hook.on_update(task, local_variables)

            tensorboard_plot_hook.on_phase_end(task, local_variables)

            if master:
                # add_scalar() should have been called with the right scalars
                if train:
                    loss_key = f"{phase_type}_loss"
                    learning_rate_key = f"{phase_type}_learning_rate_updates"
                    summary_writer.add_scalar.assert_any_call(
                        loss_key,
                        mock.ANY,
                        global_step=mock.ANY,
                        walltime=mock.ANY)
                    summary_writer.add_scalar.assert_any_call(
                        learning_rate_key,
                        mock.ANY,
                        global_step=mock.ANY,
                        walltime=mock.ANY,
                    )
                avg_loss_key = f"avg_{phase_type}_loss"
                summary_writer.add_scalar.assert_any_call(avg_loss_key,
                                                          mock.ANY,
                                                          global_step=mock.ANY)
                for meter in task.meters:
                    for name in meter.value:
                        meter_key = f"{phase_type}_{meter.name}_{name}"
                        summary_writer.add_scalar.assert_any_call(
                            meter_key, mock.ANY, global_step=mock.ANY)
            else:
                # add_scalar() shouldn't be called since is_master() is False
                summary_writer.add_scalar.assert_not_called()
            summary_writer.add_scalar.reset_mock()