def test_from_checkpoint(self):
    """Check that ClassyModel.from_checkpoint restores a checkpointed model.

    For both model configs (with and without a head) the model parameters are
    zeroed to simulate training, and a checkpoint is written by CheckpointHook at
    the end of a train phase. The model rebuilt from that checkpoint must be a
    MyTestModel whose parameters are all zero.
    """
    config = get_test_task_config()
    for use_head in [True, False]:
        config["model"] = self.get_model_config(use_head)
        task = build_task(config)
        task.prepare()

        checkpoint_folder = f"{self.base_dir}/{use_head}/"
        input_args = {"config": config}

        # Simulate training by setting the model parameters to zero
        for param in task.model.parameters():
            param.data.zero_()

        checkpoint_hook = CheckpointHook(
            checkpoint_folder, input_args, phase_types=["train"]
        )

        # Create checkpoint dir, save checkpoint
        os.mkdir(checkpoint_folder)
        checkpoint_hook.on_start(task)

        task.train = True
        checkpoint_hook.on_phase_end(task)

        # Model should be checkpointed. load and compare
        checkpoint = load_checkpoint(checkpoint_folder)
        self.assertIsNotNone(checkpoint)

        model = ClassyModel.from_checkpoint(checkpoint)
        self.assertIsInstance(model, MyTestModel)

        # All parameters must be zero
        for param in model.parameters():
            self.assertTrue(torch.all(param.data == 0))
def test_logging(self, mock_get_rank: mock.MagicMock) -> None:
    """
    Verify when LossLrMeterLoggingHook logs, and that the real loss/lr logging
    helpers emit log records.

    With log_freq=5, both _log_loss_meters() and _log_lr() must be invoked by
    on_step() on every non-zero batch index divisible by 5. With log_freq=None
    they must never be invoked by on_step(). on_phase_end() always triggers
    _log_loss_meters(), and triggers _log_lr() for train phases.
    """
    mock_get_rank.return_value = 5

    # Build and prepare a task with small per-replica batch sizes.
    config = get_test_task_config()
    dataset_cfg = config["dataset"]
    dataset_cfg["train"]["batchsize_per_replica"] = 2
    dataset_cfg["test"]["batchsize_per_replica"] = 5
    task = build_task(config)
    task.prepare()

    sample_losses = [1.2, 2.3, 3.4, 4.5]
    local_variables = {}
    task.phase_idx = 0

    for log_freq in (5, None):
        hook = LossLrMeterLoggingHook(log_freq=log_freq)

        # Replace both logging helpers with mocks so we can observe the calls.
        with mock.patch.object(hook, "_log_loss_meters") as loss_log:
            with mock.patch.object(hook, "_log_lr") as lr_log:
                for batch_idx in range(20):
                    task.losses = list(range(batch_idx))
                    hook.on_step(task, local_variables)

                    should_log = (
                        log_freq is not None
                        and batch_idx
                        and batch_idx % log_freq == 0
                    )
                    if should_log:
                        loss_log.assert_called_with(task, local_variables)
                        loss_log.reset_mock()
                        lr_log.assert_called_with(task, local_variables)
                        lr_log.reset_mock()
                    else:
                        loss_log.assert_not_called()
                        lr_log.assert_not_called()

                hook.on_phase_end(task, local_variables)
                loss_log.assert_called_with(task, local_variables)
                if task.train:
                    lr_log.assert_called_with(task, local_variables)

        # Patches are gone: the real helpers must produce log output.
        task.losses = sample_losses
        with self.assertLogs():
            hook._log_loss_meters(task, local_variables)
            hook._log_lr(task, local_variables)

        task.phase_idx += 1
def test_streaming_dataset(self):
    """
    Check a streaming training dataset: the task reports the expected number of
    batches per phase, and the data iterator yields exactly that many batches.
    This holds both initially and after the dataloaders are rebuilt.
    """
    config = get_test_task_config()
    config["dataset"]["train"] = dict(
        name="synthetic_image_streaming",
        split="train",
        crop_size=224,
        class_ratio=0.5,
        num_samples=2000,
        length=4000,
        seed=0,
        batchsize_per_replica=32,
        use_shuffle=True,
    )
    # NOTE(review): presumably num_samples // batchsize_per_replica == 2000 // 32.
    batches_wanted = 62

    task = build_task(config)
    task.prepare()
    task.advance_phase()

    # The task's own batch count must match.
    self.assertEqual(task.num_batches_per_phase, batches_wanted)
    # The iterator must actually produce that many batches.
    self._test_number_of_batches(task.data_iterator, batches_wanted)

    # Rebuilding the dataloaders must give the same count again.
    task.build_dataloaders_for_current_phase()
    task.create_data_iterators()
    self._test_number_of_batches(task.data_iterator, batches_wanted)
def test_torchscripting_using_trace(self):
    """
    Exercise TorchscriptHook with its default settings (the trace path).
    execute_hook performs the end-to-end checks.
    """
    task_config = get_test_task_config()
    output_dir = self.base_dir + "/torchscript_end_test/"
    # No use_trace override here; the script variant passes use_trace=False.
    self.execute_hook(task_config, output_dir, TorchscriptHook(output_dir))
def test_torchscripting_using_script(self):
    """
    Test that the save_torchscript function works as expected with script
    (TorchscriptHook(..., use_trace=False)).

    ResNet.wrapper_cls is set to None for the duration of the test so that the
    ResNet model is torchscriptable. The original value is restored afterwards,
    so the class-level change cannot leak into other tests.
    """
    from unittest import mock

    config = get_test_task_config()
    torchscript_folder = self.base_dir + "/torchscript_end_test/"

    # mock.patch.object restores (or removes, if it was only inherited) the
    # attribute on exit, even when the hook or an assertion raises.
    with mock.patch.object(ResNet, "wrapper_cls", None):
        # create a torchscript hook using script
        torchscript_hook = TorchscriptHook(torchscript_folder, use_trace=False)
        self.execute_hook(config, torchscript_folder, torchscript_hook)
def test_build_task(self):
    """
    build_task yields a ClassificationTask. AMP is disabled by default, and a
    config that requests AMP opt level "O1" via ``amp_args`` still builds.
    """
    base_config = get_test_task_config()
    default_task = build_task(base_config)
    self.assertTrue(isinstance(default_task, ClassificationTask))
    # By default, AMP is disabled.
    self.assertIsNone(default_task.amp_args)

    # A valid AMP opt level on an independent copy of the config.
    amp_config = copy.deepcopy(base_config)
    amp_config["amp_args"] = {"opt_level": "O1"}
    amp_task = build_task(amp_config)
    self.assertTrue(isinstance(amp_task, ClassificationTask))
def test_state_checkpointing(self) -> None:
    """
    Test that the state gets checkpointed without any errors, but only on the
    right phase_type and only if the checkpoint directory exists.
    """
    config = get_test_task_config()
    task = build_task(config)
    task.prepare()

    local_vars = {}
    folder = self.base_dir + "/checkpoint_end_test/"
    expected_args = {"foo": "bar"}

    hook = CheckpointHook(folder, expected_args, phase_types=["train"])

    # While the directory is missing, on_start raises FileNotFoundError and
    # on_phase_end raises AssertionError.
    with self.assertRaises(FileNotFoundError):
        hook.on_start(task)
    with self.assertRaises(AssertionError):
        hook.on_phase_end(task, local_vars)

    # Nothing was written, so loading yields None.
    self.assertIsNone(load_checkpoint(folder))

    # Once the directory exists, on_start succeeds.
    os.mkdir(folder)
    hook.on_start(task)

    # A test phase is not in phase_types: no checkpoint is written.
    task.train = False
    hook.on_phase_end(task, local_vars)
    self.assertIsNone(load_checkpoint(folder))

    # A train phase is in phase_types: a checkpoint is written.
    task.train = True
    hook.on_phase_end(task, local_vars)

    saved = load_checkpoint(folder)
    self.assertIsNotNone(saved)
    for key in ("input_args", "classy_state_dict"):
        self.assertIn(key, saved)
    # Equality of classy_state_dict is covered by a separate test.
    self.assertDictEqual(saved["input_args"], expected_args)
def test_checkpoint_period(self) -> None:
    """
    Check that CheckpointHook honours checkpoint_period.

    Phases alternate between train and test. No checkpoint may exist until
    checkpoint_period phases whose type is listed in phase_types have ended.
    """
    config = get_test_task_config()
    task = build_task(config)
    task.prepare()

    local_vars = {}
    folder = self.base_dir + "/checkpoint_end_test/"
    period = 10

    for phase_types in (["train"], ["train", "test"]):
        hook = CheckpointHook(
            folder,
            {},
            phase_types=phase_types,
            checkpoint_period=period,
        )
        os.mkdir(folder)
        hook.on_start(task)

        # Run period - 1 matching phases (plus any non-matching ones in
        # between); none of them may produce a checkpoint.
        phases_seen = 0
        matching = 0
        while matching < period - 1:
            task.train = phases_seen % 2 == 0
            hook.on_phase_end(task, local_vars)
            self.assertIsNone(load_checkpoint(folder))
            if task.phase_type in phase_types:
                matching += 1
            phases_seen += 1

        # The period-th matching (train) phase must trigger a checkpoint.
        task.train = True
        hook.on_phase_end(task, local_vars)
        self.assertIsNotNone(load_checkpoint(folder))

        # Start the next configuration from an empty directory.
        shutil.rmtree(folder)
def test_get_state(self):
    """
    Smoke test that prepare() succeeds on two tasks. The first is a
    ClassificationTask assembled through its setter methods; the second is
    built from the same config via build_task.

    NOTE(review): despite the name, no state is fetched or compared here.
    """
    config = get_test_task_config()
    loss = build_loss(config["loss"])

    # Same call order as a single fluent chain, split into named steps.
    task = ClassificationTask().set_num_epochs(1).set_loss(loss)
    task = task.set_model(build_model(config["model"]))
    task = task.set_optimizer(build_optimizer(config["optimizer"]))
    for split in ("train", "test"):
        task.set_dataset(build_dataset(config["dataset"][split]), split)
    task.prepare()

    built_task = build_task(config)
    built_task.prepare()
def test_failure(self) -> None:
    """
    Checkpointing to a "test://" path served by TestPathHandler must raise
    TestException from on_phase_end.
    """
    probe_path = "test://foo"
    self.assertFalse(PathManager.exists(probe_path))
    PathManager.register_handler(TestPathHandler())
    # The "test://" prefix must now be routed to TestPathHandler.
    self.assertTrue(PathManager.exists(probe_path))

    hook = CheckpointHook("test://root", {}, phase_types=["train"])
    task = build_task(get_test_task_config())
    task.prepare()

    # Saving the checkpoint has to surface the handler's exception.
    with self.assertRaises(TestException):
        hook.on_phase_end(task)
def test_build_task(self):
    """
    Check build_task handling of the ``amp_opt_level`` setting. AMP is disabled
    by default, a valid opt level ("O1") still yields a ClassificationTask, and
    an invalid opt level ("O5") is rejected with an exception.
    """
    config = get_test_task_config()
    task = build_task(config)
    self.assertIsInstance(task, ClassificationTask)
    # check that AMP is disabled by default
    self.assertIsNone(task.amp_opt_level)

    # test a valid AMP opt level
    config = copy.deepcopy(config)
    config["amp_opt_level"] = "O1"
    task = build_task(config)
    self.assertIsInstance(task, ClassificationTask)

    # test an invalid AMP opt level
    config = copy.deepcopy(config)
    config["amp_opt_level"] = "O5"
    # NOTE(review): the concrete exception type is not visible from here;
    # narrow this assertion once it is confirmed.
    with self.assertRaises(Exception):
        build_task(config)
def test_from_task(self):
    """
    ClassyHubInterface.from_task wraps a task and its model. An interface
    created after swapping a dataset's transform picks up that transform for
    new image datasets.
    """
    task = build_task(get_test_task_config())
    hub = ClassyHubInterface.from_task(task)
    self.assertIsInstance(hub.task, ClassyTask)
    self.assertIsInstance(hub.model, ClassyModel)
    # At this point the transform comes from the task's config.
    self._test_predict_and_extract_features(hub)

    # Install a custom transform on the test split and rebuild the interface.
    split = "test"
    task.datasets[split].transform = TestTransform()
    hub = ClassyHubInterface.from_task(task)
    image_dataset = hub.create_image_dataset(
        image_files=[self.image_path], phase_type=split
    )
    self.assertIsInstance(image_dataset.transform, TestTransform)
def test_streaming_dataset_async(self):
    """
    Test that streaming datasets with ``async_gpu_copy`` return the correct
    number of batches, and that the length is also calculated correctly.

    Requires CUDA; without it the test is reported as skipped rather than
    silently passing.
    """
    if not torch.cuda.is_available():
        # unittest ignores a test's return value (Python 3.11+ warns on a
        # non-None return), so the former `return True` looked like a pass.
        self.skipTest("CUDA is not available")

    config = get_test_task_config()
    dataset_config = {
        "name": "synthetic_image_streaming",
        "split": "train",
        "crop_size": 224,
        "class_ratio": 0.5,
        "num_samples": 2000,
        "length": 4000,
        "seed": 0,
        "batchsize_per_replica": 32,
        "use_shuffle": True,
        "async_gpu_copy": True,
    }
    expected_batches = 62
    config["dataset"]["train"] = dataset_config
    task = build_task(config)
    task.prepare()
    task.advance_phase()

    # test that the number of batches expected is correct
    self.assertEqual(task.num_batches_per_phase, expected_batches)

    # test that the data iterator returns the expected number of batches
    data_iterator = task.get_data_iterator()
    self._test_number_of_batches(data_iterator, expected_batches)

    # test that the dataloader can be rebuilt from the dataset inside it
    task._recreate_data_loader_from_dataset()
    task.create_data_iterator()
    data_iterator = task.get_data_iterator()
    self._test_number_of_batches(data_iterator, expected_batches)
def test_torchscripting(self):
    """
    Check that TorchscriptHook writes a loadable TorchScript file at on_end.

    The loaded module must produce the same output as the task's eager model
    on a random input.
    """
    config = get_test_task_config()
    task = build_task(config)
    task.prepare()

    out_dir = self.base_dir + "/torchscript_end_test/"
    hook = TorchscriptHook(out_dir)

    # Create the TorchScript output directory, then run on_start.
    os.mkdir(out_dir)
    hook.on_start(task)

    task.train = True
    # on_end is expected to write TORCHSCRIPT_FILE into hook.torchscript_folder.
    hook.on_end(task)

    scripted = torch.jit.load(f"{hook.torchscript_folder}/{TORCHSCRIPT_FILE}")

    # Compare the eager task model against the loaded TorchScript module.
    with torch.no_grad():
        model = task.model
        sample = torch.randn((1,) + model.input_shape, dtype=torch.float)
        if torch.cuda.is_available():
            sample = sample.cuda()
        eager_out = model(sample)
        scripted_out = scripted(sample)
        self.assertTrue(torch.allclose(eager_out, scripted_out))
def test_hooks_config_builds_correctly(self):
    """
    A "hooks" entry in the task config is turned into hook instances by
    build_task; here a single LossLrMeterLoggingHook.
    """
    config = get_test_task_config()
    config["hooks"] = [{"name": "loss_lr_meter_logging"}]
    task = build_task(config)
    # Specific assertions report the actual length/type on failure.
    self.assertEqual(len(task.hooks), 1)
    self.assertIsInstance(task.hooks[0], LossLrMeterLoggingHook)
def test_build_task(self):
    """build_task on the default test config produces a ClassificationTask."""
    built = build_task(get_test_task_config())
    self.assertTrue(isinstance(built, ClassificationTask))
def _get_classy_model(self):
    """Build the model described by the "model" section of the test task config."""
    return build_model(get_test_task_config()["model"])
def test_visdom(
    self, mock_visdom_cls: mock.MagicMock, mock_is_primary: mock.MagicMock
) -> None:
    """
    Tests that visdom is populated with plots.

    Runs every combination of (is primary, visdom connection available) and
    alternates train/test phases. After each phase it checks:
    - an empty loss list leaves the hook's metrics untouched;
    - every meter, loss and learning-rate metric is recorded per phase type;
    - visdom.line() is called only on the primary process, after a test phase,
      when the connection is up.
    """
    mock_visdom = mock.create_autospec(Visdom, instance=True)
    mock_visdom_cls.return_value = mock_visdom

    # set up the task and state
    config = get_test_task_config()
    config["dataset"]["train"]["batchsize_per_replica"] = 2
    config["dataset"]["test"]["batchsize_per_replica"] = 5
    task = build_task(config)
    task.prepare()

    losses = [1.2, 2.3, 1.23, 2.33]
    loss_val = sum(losses) / len(losses)

    task.losses = losses

    visdom_server = "localhost"
    visdom_port = 8097

    def assert_metric_series(metrics, key, expected_len):
        # A tracked metric must be a list holding one value per recorded
        # phase of its phase type.
        self.assertIn(key, metrics)
        self.assertIsInstance(metrics[key], list)
        self.assertEqual(len(metrics[key]), expected_len)

    for master, visdom_conn in product([False, True], [False, True]):
        mock_is_primary.return_value = master
        mock_visdom.check_connection.return_value = visdom_conn

        # create a visdom hook
        visdom_hook = VisdomHook(visdom_server, visdom_port)

        mock_visdom_cls.assert_called_once()
        mock_visdom_cls.reset_mock()

        counts = {"train": 0, "test": 0}
        count = 0

        for phase_idx in range(10):
            train = phase_idx % 2 == 0
            task.train = train
            phase_type = "train" if train else "test"

            counts[phase_type] += 1
            count += 1

            # test that the metrics don't change if losses is empty and that
            # visdom.line() is not called
            task.losses = []
            original_metrics = copy.deepcopy(visdom_hook.metrics)
            visdom_hook.on_phase_end(task)
            self.assertDictEqual(original_metrics, visdom_hook.metrics)
            mock_visdom.line.assert_not_called()

            # test that the metrics are updated correctly when losses
            # is non empty
            task.losses = [loss * count for loss in losses]
            visdom_hook.on_phase_end(task)

            # every meter should be present and should have the correct length
            for meter in task.meters:
                for value_name in meter.value:
                    meter_key = phase_type + "_" + meter.name + "_" + value_name
                    assert_metric_series(
                        visdom_hook.metrics, meter_key, counts[phase_type]
                    )

            # the loss metric should be calculated correctly: every loss was
            # scaled by `count`, so the mean is loss_val * count
            loss_key = phase_type + "_loss"
            assert_metric_series(visdom_hook.metrics, loss_key, counts[phase_type])
            self.assertAlmostEqual(
                visdom_hook.metrics[loss_key][-1],
                loss_val * count,
                places=4,
            )

            # the lr metric should be correct
            lr_key = phase_type + "_learning_rate"
            assert_metric_series(visdom_hook.metrics, lr_key, counts[phase_type])
            self.assertAlmostEqual(
                visdom_hook.metrics[lr_key][-1],
                task.optimizer.options_view.lr,
                places=4,
            )

            if master and not train and visdom_conn:
                # visdom.line() should be called once
                mock_visdom.line.assert_called_once()
                mock_visdom.line.reset_mock()
            else:
                # visdom.line() should not be called
                mock_visdom.line.assert_not_called()
def test_writer(self, mock_is_primary_func: mock.MagicMock) -> None:
    """
    Tests that the tensorboard writer writes the correct scalars to SummaryWriter
    iff is_primary() is True.

    Each iteration creates a real SummaryWriter (wrapped by a MagicMock spy).
    The writer is closed at the end of that iteration so its event file is
    released even if an assertion fails.
    """
    for phase_idx, master in product([0, 1, 2], [True, False]):
        train, phase_type = (
            (True, "train") if phase_idx % 2 == 0 else (False, "test")
        )
        mock_is_primary_func.return_value = master

        # set up the task and state
        config = get_test_task_config()
        config["dataset"]["train"]["batchsize_per_replica"] = 2
        config["dataset"]["test"]["batchsize_per_replica"] = 5
        task = build_task(config)
        task.prepare()
        task.advance_phase()
        task.phase_idx = phase_idx
        task.train = train

        losses = [1.23, 4.45, 12.3, 3.4]
        sample_fetch_times = [1.1, 2.2, 3.3, 2.2]

        real_writer = SummaryWriter(self.base_dir)
        try:
            # create a spy on top of the real writer
            summary_writer = mock.MagicMock(wraps=real_writer)

            # create a loss lr tensorboard hook
            tensorboard_plot_hook = TensorboardPlotHook(summary_writer)

            # run the hook in the correct order
            tensorboard_plot_hook.on_phase_start(task)

            # test tasks which do not pass the sample_fetch_times as well
            disable_sample_fetch_times = phase_idx == 0

            for loss, sample_fetch_time in zip(losses, sample_fetch_times):
                task.losses.append(loss)
                step_data = (
                    {}
                    if disable_sample_fetch_times
                    else {"sample_fetch_time": sample_fetch_time}
                )
                task.last_batch = LastBatchInfo(None, None, None, None, step_data)
                tensorboard_plot_hook.on_step(task)

            tensorboard_plot_hook.on_phase_end(task)

            if master:
                # add_scalar() should have been called with the right scalars
                if train:
                    learning_rate_key = f"Learning Rate/{phase_type}"
                    summary_writer.add_scalar.assert_any_call(
                        learning_rate_key,
                        mock.ANY,
                        global_step=mock.ANY,
                        walltime=mock.ANY,
                    )
                avg_loss_key = f"Losses/{phase_type}"
                summary_writer.add_scalar.assert_any_call(
                    avg_loss_key, mock.ANY, global_step=mock.ANY
                )
                for meter in task.meters:
                    for name in meter.value:
                        meter_key = f"Meters/{phase_type}/{meter.name}/{name}"
                        summary_writer.add_scalar.assert_any_call(
                            meter_key, mock.ANY, global_step=mock.ANY
                        )
                # Check the explicit flag instead of relying on the
                # `step_data` variable leaking out of the loop above.
                if not disable_sample_fetch_times:
                    summary_writer.add_scalar.assert_any_call(
                        f"Speed/{phase_type}/cumulative_sample_fetch_time",
                        mock.ANY,
                        global_step=mock.ANY,
                        walltime=mock.ANY,
                    )
            else:
                # add_scalar() shouldn't be called since is_primary() is False
                summary_writer.add_scalar.assert_not_called()

            summary_writer.add_scalar.reset_mock()
        finally:
            # Release the event file opened by this iteration's writer.
            real_writer.close()
def test_writer(self, mock_is_master_func: mock.MagicMock) -> None:
    """
    Tests that the tensorboard writer writes the correct scalars to SummaryWriter
    iff is_master() is True.

    Also checks that on_update()/on_phase_end() each log exactly one warning
    when called before on_phase_start(). Each iteration's real SummaryWriter is
    closed at the end of that iteration so its event file is released.
    """

    def assert_uninitialized_warning(log_watcher):
        # exactly one WARN record mentioning the missing learning_rates state
        self.assertEqual(len(log_watcher.records), 1)
        self.assertEqual(log_watcher.records[0].levelno, logging.WARN)
        self.assertIn("learning_rates is not initialized", log_watcher.output[0])

    for phase_idx, master in product([0, 1, 2], [True, False]):
        train, phase_type = (
            (True, "train") if phase_idx % 2 == 0 else (False, "test")
        )
        mock_is_master_func.return_value = master

        # set up the task and state
        config = get_test_task_config()
        config["dataset"]["train"]["batchsize_per_replica"] = 2
        config["dataset"]["test"]["batchsize_per_replica"] = 5
        task = build_task(config)
        task.prepare()
        task.phase_idx = phase_idx
        task.train = train

        losses = [1.23, 4.45, 12.3, 3.4]
        local_variables = {}

        real_writer = SummaryWriter(self.base_dir)
        try:
            # create a spy on top of the real writer
            summary_writer = mock.MagicMock(wraps=real_writer)

            # create a loss lr tensorboard hook
            tensorboard_plot_hook = TensorboardPlotHook(summary_writer)

            # on_update() before on_phase_start() must log the warning
            with self.assertLogs() as log_watcher:
                tensorboard_plot_hook.on_update(task, local_variables)
            assert_uninitialized_warning(log_watcher)

            # on_phase_end() before on_phase_start() must log the warning too
            with self.assertLogs() as log_watcher:
                tensorboard_plot_hook.on_phase_end(task, local_variables)
            assert_uninitialized_warning(log_watcher)

            summary_writer.add_scalar.reset_mock()

            # run the hook in the correct order
            tensorboard_plot_hook.on_phase_start(task, local_variables)

            for loss in losses:
                task.losses.append(loss)
                tensorboard_plot_hook.on_update(task, local_variables)

            tensorboard_plot_hook.on_phase_end(task, local_variables)

            if master:
                # add_scalar() should have been called with the right scalars
                if train:
                    loss_key = f"{phase_type}_loss"
                    learning_rate_key = f"{phase_type}_learning_rate_updates"
                    summary_writer.add_scalar.assert_any_call(
                        loss_key, mock.ANY, global_step=mock.ANY, walltime=mock.ANY
                    )
                    summary_writer.add_scalar.assert_any_call(
                        learning_rate_key,
                        mock.ANY,
                        global_step=mock.ANY,
                        walltime=mock.ANY,
                    )
                avg_loss_key = f"avg_{phase_type}_loss"
                summary_writer.add_scalar.assert_any_call(
                    avg_loss_key, mock.ANY, global_step=mock.ANY
                )
                for meter in task.meters:
                    for name in meter.value:
                        meter_key = f"{phase_type}_{meter.name}_{name}"
                        summary_writer.add_scalar.assert_any_call(
                            meter_key, mock.ANY, global_step=mock.ANY
                        )
            else:
                # add_scalar() shouldn't be called since is_master() is False
                summary_writer.add_scalar.assert_not_called()

            summary_writer.add_scalar.reset_mock()
        finally:
            # Release the event file opened by this iteration's writer.
            real_writer.close()