def test_reset_callback_resets_weights(a_data_module):
    """ResetCallback should restore its snapshot weights and rewind the epoch.

    A freshly built VGG16 is snapshotted, its Linear layers are
    re-initialised so the weights diverge from the snapshot, and then
    ``ResetCallback.on_train_start`` is expected to put the snapshot
    parameters back and leave ``trainer.current_epoch`` at 0.
    """

    def reinit_linear_layers(net):
        """Call ``reset_parameters()`` on every torch.nn.Linear in ``net``."""

        def _maybe_reset(module):
            if isinstance(module, torch.nn.Linear):
                module.reset_parameters()

        net.apply(_maybe_reset)

    def params_identical(reference, current):
        """True when every tensor pair from the two iterables is element-wise equal."""
        return all(
            torch.eq(ref, cur).all() for ref, cur in zip(reference, current))

    model = vgg16()
    trainer = BaalTrainer(
        dataset=a_data_module.active_dataset,
        max_epochs=3,
        default_root_dir='/tmp',
    )
    # Move the epoch counter away from 0 so the reset is observable.
    trainer.current_epoch = 10

    # Deep copies: the live parameter tensors are modified in place below.
    snapshot_state = copy.deepcopy(model.state_dict())
    snapshot_params = copy.deepcopy(list(model.parameters()))
    reset_callback = ResetCallback(snapshot_state)

    # Perturb the weights so they no longer match the snapshot.
    reinit_linear_layers(model)
    assert not params_identical(snapshot_params, model.parameters())

    # The callback must bring the snapshot weights back...
    reset_callback.on_train_start(trainer, model)
    assert params_identical(snapshot_params, model.parameters())
    # ...and rewind the epoch counter.
    assert trainer.current_epoch == 0
def main(hparams):
    """Run an active-learning loop on CIFAR10 with a VGG16 model and BALD.

    Ten training samples are labelled at random to start. Each step then
    fits the model with a ``BaalTrainer`` (``max_epochs=3``) and calls
    ``trainer.step()``, which is configured with
    ``ndata_to_label=hparams.query_size``. The loop stops after 100 steps
    or as soon as ``trainer.step()`` returns a falsy value.

    Args:
        hparams: namespace-like config. This function reads ``data_root``,
            ``n_gpus`` and ``query_size`` from it and also passes it to
            ``VGG16``.
    """
    augmenting_transform = transforms.Compose(
        [transforms.RandomHorizontalFlip(), transforms.ToTensor()])
    plain_transform = transforms.Compose([transforms.ToTensor()])

    cifar_train = CIFAR10(hparams.data_root, train=True,
                          transform=augmenting_transform, download=True)
    # The pool is given the non-augmenting transform via pool_specifics.
    active_set = ActiveLearningDataset(
        cifar_train, pool_specifics={'transform': plain_transform})
    active_set.label_randomly(10)

    heuristic = BALD()
    model = VGG16(active_set, hparams)

    multi_gpu = hparams.n_gpus > 1
    # Training mutates the model's weights in place, so ResetCallback is
    # handed a deep copy of the state taken before any training happens.
    untrained_state = copy.deepcopy(model.state_dict())
    trainer = BaalTrainer(
        max_epochs=3,
        default_root_dir=hparams.data_root,
        gpus=hparams.n_gpus,
        distributed_backend='dp' if multi_gpu else None,
        callbacks=[ResetCallback(untrained_state)],
        dataset=active_set,
        heuristic=heuristic,
        ndata_to_label=hparams.query_size,
    )

    max_al_steps = 100
    for al_step in range(max_al_steps):
        # TODO Issue 95 Make PL trainer epoch self-aware
        # Until then, the epoch counter is rewound by hand before each fit.
        trainer.current_epoch = 0
        print(f'Step {al_step} Dataset size {len(active_set)}')
        trainer.fit(model)
        if not trainer.step():
            break
'https://download.pytorch.org/models/vgg16-397923af.pth') weights = {k: v for k, v in weights.items() if 'classifier.6' not in k} model.load_state_dict(weights, strict=False) model = PIActiveLearningModel(network=model, active_dataset=active_set, hparams=params) dp = 'dp' if params.gpus > 1 else None trainer = BaalTrainer( max_epochs=params.epochs, default_root_dir=params.data_root, gpus=params.gpus, distributed_backend=dp, # The weights of the model will change as it gets # trained; we need to keep a copy (deepcopy) so that # we can reset them. callbacks=[ResetCallback(copy.deepcopy(model.state_dict()))], dataset=active_set, heuristic=heuristic, ndata_to_label=params.query_size) AL_STEPS = 2000 for al_step in range(AL_STEPS): # TODO fix this trainer.current_epoch = 0 print(f'Step {al_step} Dataset size {len(active_set)}') trainer.fit(model) should_continue = trainer.step() if not should_continue: break