Example #1
0
def test_reset_callback_resets_weights(a_data_module):
    """ResetCallback restores the saved weights and rewinds the epoch counter.

    The model's ``torch.nn.Linear`` layers are re-initialised so its
    parameters differ from a saved ``state_dict`` snapshot; calling
    ``on_train_start`` must load that snapshot back and set
    ``trainer.current_epoch`` to 0.
    """
    # NOTE(review): another function named ``test_reset_callback_resets_weights``
    # is defined further down. Within one module the later ``def`` rebinds the
    # name, so pytest collects only that one and this test never runs.
    # Rename one of them if both are meant to be executed.

    def reset_fcs(model):
        """Re-initialise every torch.nn.Linear layer of ``model`` in place."""

        def reset(m):
            if isinstance(m, torch.nn.Linear):
                m.reset_parameters()

        # ``Module.apply`` calls ``reset`` on every submodule recursively.
        model.apply(reset)

    model = vgg16()
    trainer = BaalTrainer(dataset=a_data_module.active_dataset,
                          max_epochs=3,
                          default_root_dir='/tmp')
    # Pretend training has progressed so that the epoch reset is observable.
    trainer.fit_loop.current_epoch = 10
    initial_weights = copy.deepcopy(model.state_dict())
    initial_params = copy.deepcopy(list(model.parameters()))
    callback = ResetCallback(initial_weights)

    # Modify the params.
    reset_fcs(model)
    new_params = list(model.parameters())
    # Guard against zip() silently truncating a mismatched comparison.
    assert len(new_params) == len(initial_params)
    # At least one parameter tensor must have changed after the re-init.
    assert not all(
        torch.equal(p1, p2) for p1, p2 in zip(initial_params, new_params))

    callback.on_train_start(trainer, model)
    new_params = list(model.parameters())
    assert len(new_params) == len(initial_params)
    # Every parameter tensor must be back to its snapshot value.
    assert all(
        torch.equal(p1, p2) for p1, p2 in zip(initial_params, new_params))
    assert trainer.current_epoch == 0
def test_reset_callback_resets_weights():
    """ResetCallback restores the saved weights on ``on_train_start``.

    The model's ``torch.nn.Linear`` layers are re-initialised so its
    parameters differ from a saved ``state_dict`` snapshot; calling
    ``on_train_start`` must load that snapshot back.
    """
    # NOTE(review): a function with this same name is defined above. This
    # later ``def`` rebinds the name in the module, so the earlier test is
    # shadowed and never collected by pytest.

    def reset_fcs(model):
        """Re-initialise every torch.nn.Linear layer of ``model`` in place."""

        def reset(m):
            if isinstance(m, torch.nn.Linear):
                m.reset_parameters()

        # ``Module.apply`` calls ``reset`` on every submodule recursively.
        model.apply(reset)

    model = vgg16()
    initial_weights = copy.deepcopy(model.state_dict())
    initial_params = copy.deepcopy(list(model.parameters()))
    callback = ResetCallback(initial_weights)

    # Modify the params.
    reset_fcs(model)
    new_params = list(model.parameters())
    # Guard against zip() silently truncating a mismatched comparison.
    assert len(new_params) == len(initial_params)
    # At least one parameter tensor must have changed after the re-init.
    assert not all(torch.equal(p1, p2) for p1, p2 in zip(initial_params, new_params))

    # The trainer argument is passed as None: only weight restoration is
    # checked here, not any trainer state.
    callback.on_train_start(None, model)
    new_params = list(model.parameters())
    assert len(new_params) == len(initial_params)
    # Every parameter tensor must be back to its snapshot value.
    assert all(torch.equal(p1, p2) for p1, p2 in zip(initial_params, new_params))