def test_save_checkpoint(config: ModelConfigBase) -> None:
    """
    Test that checkpoints are saved correctly.

    The student model and the mean teacher model have the weights of their nn.Conv3d layers filled with two
    distinct constants (1.0 and 2.0), a checkpoint is saved, and both models are restored into a fresh
    ModelAndInfo. The restored Conv3d weights must equal the constants exactly.
    """
    config.mean_teacher_alpha = 0.999
    model_and_info = ModelAndInfo(config, model_execution_mode=ModelExecutionMode.TEST, checkpoint_path=None)
    model_and_info.try_create_model_and_load_from_checkpoint()
    model_and_info.try_create_mean_teacher_model_and_load_from_checkpoint()
    model_and_info.try_create_optimizer_and_load_from_checkpoint()

    def get_constant_init_function(constant: float) -> Callable:
        """Return a function for nn.Module.apply that fills the weight of every exact nn.Conv3d with constant."""
        def init(layer: nn.Module) -> None:
            if type(layer) is nn.Conv3d:
                layer.weight.data.fill_(constant)  # type: ignore
        return init

    def assert_conv3d_weights_equal(model: nn.Module, constant: float) -> None:
        """Assert that every exact nn.Conv3d in model has all weights equal to constant."""
        for module in model.modules():
            if type(module) is nn.Conv3d:
                weight = module.weight.detach()  # type: ignore
                assert torch.equal(weight, torch.full_like(weight, constant))

    assert model_and_info.mean_teacher_model is not None  # for mypy
    # Distinct constants, so that swapping student and mean teacher weights on save/load would be detected.
    model_and_info.model.apply(get_constant_init_function(1.0))
    model_and_info.mean_teacher_model.apply(get_constant_init_function(2.0))

    epoch = 3
    checkpoint_path = config.get_path_to_checkpoint(epoch=epoch)
    # exist_ok=True replaces a separate isdir check, avoiding the check-then-create race.
    os.makedirs(checkpoint_path.parent, exist_ok=True)
    model_and_info.save_checkpoint(epoch=epoch)

    model_and_info_restored = ModelAndInfo(config, model_execution_mode=ModelExecutionMode.TEST,
                                           checkpoint_path=checkpoint_path)
    model_and_info_restored.try_create_model_load_from_checkpoint_and_adjust()
    model_and_info_restored.try_create_mean_teacher_model_load_from_checkpoint_and_adjust()
    assert model_and_info_restored.mean_teacher_model is not None  # for mypy

    assert_conv3d_weights_equal(model_and_info_restored.model, 1.0)
    assert_conv3d_weights_equal(model_and_info_restored.mean_teacher_model, 2.0)
def test_try_create_mean_teacher_model_and_load_from_checkpoint(config: ModelConfigBase, checkpoint_path: str) -> None:
    """
    Exercise mean teacher model creation in three situations: no checkpoint path, a checkpoint path that does
    not exist, and a valid checkpoint path. In every case the mean teacher model object must be created with the
    type that matches the config (BaseModel for segmentation configs, DeviceAwareModule otherwise).
    """
    config.mean_teacher_alpha = 0.999

    def check_mean_teacher_type(info: ModelAndInfo) -> None:
        # Segmentation configs produce a BaseModel; all other configs a DeviceAwareModule.
        wanted = BaseModel if isinstance(config, SegmentationModelBase) else DeviceAwareModule
        assert isinstance(info.mean_teacher_model, wanted)

    # Scenario 1: no checkpoint path. Reading the model before creation must raise.
    info = ModelAndInfo(config, model_execution_mode=ModelExecutionMode.TEST, checkpoint_path=None)
    with pytest.raises(ValueError):
        info.mean_teacher_model
    loaded = info.try_create_mean_teacher_model_and_load_from_checkpoint()
    assert loaded
    check_mean_teacher_type(info)

    # Scenario 2: the checkpoint file does not exist. Loading reports failure, but the current code still
    # creates the model itself; only the checkpoint loading fails.
    missing_path = full_ml_test_data_path("non_exist.pth.tar")
    info = ModelAndInfo(config, model_execution_mode=ModelExecutionMode.TEST, checkpoint_path=missing_path)
    loaded = info.try_create_mean_teacher_model_and_load_from_checkpoint()
    assert not loaded
    check_mean_teacher_type(info)

    # Scenario 3: valid checkpoint. Loading succeeds and the epoch stored in the checkpoint is picked up.
    valid_path = full_ml_test_data_path(checkpoint_path)
    info = ModelAndInfo(config, model_execution_mode=ModelExecutionMode.TEST, checkpoint_path=valid_path)
    loaded = info.try_create_mean_teacher_model_and_load_from_checkpoint()
    assert loaded
    check_mean_teacher_type(info)
    assert info.checkpoint_epoch == 1
def test_mean_teacher_model(test_output_dirs: OutputFolderForTests) -> None:
    """
    Test training and weight updates of the mean teacher model computation.

    One training step is run first without and then with a mean teacher. The student weights of both runs must
    agree. The mean teacher weight stored in the checkpoint must match
    alpha * initial_mean_teacher_weight + (1 - alpha) * student_weight.
    """
    def _get_parameters_of_model(model: DeviceAwareModule) -> Any:
        """
        Returns the iterator of model parameters, unwrapping a DataParallelModel if needed.
        """
        if isinstance(model, DataParallelModel):
            return model.module.parameters()
        return model.parameters()

    config = DummyClassification()
    config.set_output_to(test_output_dirs.root_dir)
    checkpoint_handler = get_default_checkpoint_handler(model_config=config,
                                                        project_root=test_output_dirs.root_dir)
    config.num_epochs = 1
    # Set train batch size to be arbitrary big to ensure we have only one training step
    # i.e. one mean teacher update.
    config.train_batch_size = 100
    # Train without mean teacher
    model_train(config, checkpoint_handler=checkpoint_handler)

    # Retrieve the weight after one epoch
    model_and_info = ModelAndInfo(config=config,
                                  model_execution_mode=ModelExecutionMode.TEST,
                                  checkpoint_path=config.get_path_to_checkpoint(epoch=1))
    model_and_info.try_create_model_and_load_from_checkpoint()
    model = model_and_info.model
    model_weight = next(_get_parameters_of_model(model))

    # Get the starting weight of the mean teacher model. Re-seeding first; the test assumes this reproduces the
    # initialization that model_train performs.
    ml_util.set_random_seed(config.get_effective_random_seed())
    model_and_info_mean_teacher = ModelAndInfo(config=config,
                                               model_execution_mode=ModelExecutionMode.TEST,
                                               checkpoint_path=None)
    model_and_info_mean_teacher.try_create_model_and_load_from_checkpoint()
    model_and_info_mean_teacher.try_create_mean_teacher_model_and_load_from_checkpoint()
    mean_teach_model = model_and_info_mean_teacher.mean_teacher_model
    assert mean_teach_model is not None  # for mypy
    initial_weight_mean_teacher_model = next(_get_parameters_of_model(mean_teach_model))

    # Now train with mean teacher and check the update of the weight
    alpha = 0.999
    config.mean_teacher_alpha = alpha
    model_train(config, checkpoint_handler=checkpoint_handler)

    # Retrieve weight of mean teacher model saved in the checkpoint
    model_and_info_mean_teacher = ModelAndInfo(config=config,
                                               model_execution_mode=ModelExecutionMode.TEST,
                                               checkpoint_path=config.get_path_to_checkpoint(epoch=1))
    model_and_info_mean_teacher.try_create_mean_teacher_model_and_load_from_checkpoint()
    mean_teacher_model = model_and_info_mean_teacher.mean_teacher_model
    assert mean_teacher_model is not None  # for mypy
    result_weight = next(_get_parameters_of_model(mean_teacher_model))

    # Retrieve the associated student weight
    model_and_info_mean_teacher.try_create_model_and_load_from_checkpoint()
    student_model = model_and_info_mean_teacher.model
    student_model_weight = next(_get_parameters_of_model(student_model))

    # Assert that the student weight corresponds to the weight of a simple training without mean teacher
    # computation
    assert student_model_weight.allclose(model_weight)

    # Check the update of the parameters. Compare with a tolerance rather than exact equality: the training code
    # may evaluate the moving average with a different order of floating point operations, so bitwise equality
    # with this expression is not guaranteed.
    expected_weight = alpha * initial_weight_mean_teacher_model + (1 - alpha) * student_model_weight
    assert torch.allclose(expected_weight, result_weight)