def test_reset_callback_resets_weights(a_data_module):
    """ResetCallback restores the saved weights and zeroes the epoch counter."""

    def randomize_linear_layers(net):
        """Re-initialize every torch.nn.Linear layer of *net* in place."""

        def _reinit(module):
            if isinstance(module, torch.nn.Linear):
                module.reset_parameters()

        net.apply(_reinit)

    model = vgg16()
    trainer = BaalTrainer(dataset=a_data_module.active_dataset,
                          max_epochs=3, default_root_dir='/tmp')
    trainer.fit_loop.current_epoch = 10
    saved_state = copy.deepcopy(model.state_dict())
    saved_params = copy.deepcopy(list(model.parameters()))
    callback = ResetCallback(saved_state)

    # Perturb the weights so that the restore has a visible effect.
    randomize_linear_layers(model)
    assert not all(torch.eq(before, after).all()
                   for before, after in zip(saved_params, model.parameters()))

    # The callback must restore the original weights on train start...
    callback.on_train_start(trainer, model)
    assert all(torch.eq(before, after).all()
               for before, after in zip(saved_params, model.parameters()))
    # ...and reset the trainer's epoch counter.
    assert trainer.current_epoch == 0
def test_reset_callback_resets_weights_without_trainer():
    """ResetCallback restores the saved weights when no trainer is supplied.

    NOTE(review): this function previously had the same name as the test
    above it, so this later ``def`` silently shadowed the earlier one and
    pytest never collected or ran the first test. Renaming this variant
    lets both tests execute.
    """

    def reset_fcs(model):
        """Reset all torch.nn.Linear layers."""

        def reset(m):
            if isinstance(m, torch.nn.Linear):
                m.reset_parameters()

        model.apply(reset)

    model = vgg16()
    initial_weights = copy.deepcopy(model.state_dict())
    initial_params = copy.deepcopy(list(model.parameters()))
    callback = ResetCallback(initial_weights)

    # Modify the params so that the restore has a visible effect.
    reset_fcs(model)
    new_params = model.parameters()
    assert not all(torch.eq(p1, p2).all()
                   for p1, p2 in zip(initial_params, new_params))

    # Passing None as the trainer — presumably on_train_start tolerates a
    # missing trainer when only weights are restored; confirm against
    # ResetCallback's implementation.
    callback.on_train_start(None, model)
    new_params = model.parameters()
    assert all(torch.eq(p1, p2).all()
               for p1, p2 in zip(initial_params, new_params))