예제 #1
0
def test_save_checkpoint(config: ModelConfigBase) -> None:
    """
    Test that checkpoints are saved correctly.

    The Conv3d weights of the student model are filled with 1.0 and those of the mean teacher model with 2.0.
    A checkpoint is then saved for a fixed epoch, and a fresh ModelAndInfo that is pointed at that checkpoint
    must expose exactly those constant weights again for both models.

    :param config: The model configuration under test. Its mean_teacher_alpha is set to 0.999 before any
        model is created (presumably this is what enables the mean teacher model - confirm against ModelAndInfo).
    """
    config.mean_teacher_alpha = 0.999

    model_and_info = ModelAndInfo(config,
                                  model_execution_mode=ModelExecutionMode.TEST,
                                  checkpoint_path=None)
    model_and_info.try_create_model_and_load_from_checkpoint()
    model_and_info.try_create_mean_teacher_model_and_load_from_checkpoint()
    model_and_info.try_create_optimizer_and_load_from_checkpoint()

    def get_constant_init_function(constant: float) -> Callable:
        """
        Creates a function for use in nn.Module.apply that fills the weights of all Conv3d layers
        with the given constant.
        """
        def init(layer: nn.Module) -> None:
            if isinstance(layer, nn.Conv3d):
                layer.weight.data.fill_(constant)  # type: ignore
        return init

    def assert_conv3d_weights_equal(model: nn.Module, expected: float) -> None:
        """
        Asserts that every Conv3d layer of the model has all of its weights equal to the expected constant.
        """
        for module in model.modules():
            if isinstance(module, nn.Conv3d):
                weight = module.weight.detach()
                assert torch.equal(weight, torch.full_like(weight, expected))  # type: ignore

    assert model_and_info.mean_teacher_model is not None  # for mypy

    # Use different constants for the two models, so that a mix-up of student and mean teacher weights
    # in the checkpoint would be detected after restoring.
    model_and_info.model.apply(get_constant_init_function(1.0))
    model_and_info.mean_teacher_model.apply(get_constant_init_function(2.0))

    epoch = 3

    checkpoint_path = config.get_path_to_checkpoint(epoch=epoch)
    # exist_ok avoids the check-then-create race of a separate isdir() test.
    os.makedirs(checkpoint_path.parent, exist_ok=True)
    model_and_info.save_checkpoint(epoch=epoch)

    model_and_info_restored = ModelAndInfo(config,
                                           model_execution_mode=ModelExecutionMode.TEST,
                                           checkpoint_path=checkpoint_path)
    model_and_info_restored.try_create_model_load_from_checkpoint_and_adjust()
    model_and_info_restored.try_create_mean_teacher_model_load_from_checkpoint_and_adjust()

    assert model_and_info_restored.mean_teacher_model is not None  # for mypy

    assert_conv3d_weights_equal(model_and_info_restored.model, 1.0)
    assert_conv3d_weights_equal(model_and_info_restored.mean_teacher_model, 2.0)
예제 #2
0
def test_try_create_mean_teacher_model_and_load_from_checkpoint(config: ModelConfigBase, checkpoint_path: str) -> None:
    """
    Test creation of the mean teacher model in three scenarios: without a checkpoint path, with a path to
    a checkpoint file named "non_exist.pth.tar", and with the checkpoint given by `checkpoint_path`.

    :param config: The model configuration to create the mean teacher model from.
    :param checkpoint_path: Checkpoint location, resolved via full_ml_test_data_path.
    """
    config.mean_teacher_alpha = 0.999
    # Segmentation configs must produce a BaseModel, all other configs a DeviceAwareModule.
    expected_model_type = BaseModel if isinstance(config, SegmentationModelBase) else DeviceAwareModule

    # Scenario 1: no checkpoint path provided
    without_checkpoint = ModelAndInfo(config,
                                      model_execution_mode=ModelExecutionMode.TEST,
                                      checkpoint_path=None)

    # Reading the mean teacher model before it has been created must fail.
    with pytest.raises(ValueError):
        without_checkpoint.mean_teacher_model

    loaded = without_checkpoint.try_create_mean_teacher_model_and_load_from_checkpoint()
    assert loaded
    assert isinstance(without_checkpoint.mean_teacher_model, expected_model_type)

    # Scenario 2: invalid checkpoint path provided
    invalid_checkpoint = ModelAndInfo(config,
                                      model_execution_mode=ModelExecutionMode.TEST,
                                      checkpoint_path=full_ml_test_data_path("non_exist.pth.tar"))
    loaded = invalid_checkpoint.try_create_mean_teacher_model_and_load_from_checkpoint()
    assert not loaded
    # Current code assumes that even if this function returns False, the model itself was created, only the
    # checkpoint loading failed.
    assert isinstance(invalid_checkpoint.mean_teacher_model, expected_model_type)

    # Scenario 3: valid checkpoint path provided
    valid_checkpoint = ModelAndInfo(config,
                                    model_execution_mode=ModelExecutionMode.TEST,
                                    checkpoint_path=full_ml_test_data_path(checkpoint_path))
    loaded = valid_checkpoint.try_create_mean_teacher_model_and_load_from_checkpoint()
    assert loaded
    assert isinstance(valid_checkpoint.mean_teacher_model, expected_model_type)
    assert valid_checkpoint.checkpoint_epoch == 1
예제 #3
0
def test_mean_teacher_model(test_output_dirs: OutputFolderForTests) -> None:
    """
    Test training and weight updates of the mean teacher model computation.

    A DummyClassification model is trained for one epoch without a mean teacher, then again with
    mean_teacher_alpha set. The student weight of the second run must be close to the weight of the first
    run, and the saved mean teacher weight must equal
    alpha * (initial mean teacher weight) + (1 - alpha) * (student weight).

    :param test_output_dirs: Provides the root folder that is used for outputs and as the project root.
    """
    def _get_parameters_of_model(model: DeviceAwareModule) -> Any:
        """
        Returns the iterator of model parameters, taken from the wrapped module if the model
        is a DataParallelModel.
        """
        unwrapped = model.module if isinstance(model, DataParallelModel) else model
        return unwrapped.parameters()

    config = DummyClassification()
    config.set_output_to(test_output_dirs.root_dir)
    checkpoint_handler = get_default_checkpoint_handler(model_config=config,
                                                        project_root=test_output_dirs.root_dir)

    config.num_epochs = 1
    # An arbitrarily large train batch size is meant to give exactly one training step,
    # i.e. exactly one mean teacher update.
    config.train_batch_size = 100
    # First run: training without mean teacher
    model_train(config, checkpoint_handler=checkpoint_handler)

    # Read the first parameter tensor of the model as saved after one epoch
    plain_run = ModelAndInfo(config=config,
                             model_execution_mode=ModelExecutionMode.TEST,
                             checkpoint_path=config.get_path_to_checkpoint(epoch=1))
    plain_run.try_create_model_and_load_from_checkpoint()
    plain_run_model = plain_run.model
    plain_run_weight = next(_get_parameters_of_model(plain_run_model))

    # Re-seed before creating fresh models, to get the starting weight of the mean teacher model
    ml_util.set_random_seed(config.get_effective_random_seed())

    untrained = ModelAndInfo(config=config,
                             model_execution_mode=ModelExecutionMode.TEST,
                             checkpoint_path=None)
    untrained.try_create_model_and_load_from_checkpoint()
    untrained.try_create_mean_teacher_model_and_load_from_checkpoint()
    untrained_teacher = untrained.mean_teacher_model
    assert untrained_teacher is not None  # for mypy
    teacher_start_weight = next(_get_parameters_of_model(untrained_teacher))

    # Second run: training with mean teacher, to check the update of the weight
    alpha = 0.999
    config.mean_teacher_alpha = alpha
    model_train(config, checkpoint_handler=checkpoint_handler)

    # Read the weight of the mean teacher model that was saved in the checkpoint
    teacher_run = ModelAndInfo(config=config,
                               model_execution_mode=ModelExecutionMode.TEST,
                               checkpoint_path=config.get_path_to_checkpoint(1))
    teacher_run.try_create_mean_teacher_model_and_load_from_checkpoint()
    trained_teacher = teacher_run.mean_teacher_model
    assert trained_teacher is not None  # for mypy
    teacher_end_weight = next(_get_parameters_of_model(trained_teacher))
    # Read the associated student weight from the same checkpoint
    teacher_run.try_create_model_and_load_from_checkpoint()
    student = teacher_run.model
    student_weight = next(_get_parameters_of_model(student))

    # The student weight must correspond to the weight of a simple training without mean teacher
    # computation
    assert student_weight.allclose(plain_run_weight)

    # The mean teacher weight must be the weighted average of its starting weight and the student weight
    expected_teacher_weight = alpha * teacher_start_weight + (1 - alpha) * student_weight
    assert torch.all(expected_teacher_weight == teacher_end_weight)