def test_reset_callback_resets_weights(a_data_module):
    """ResetCallback restores the snapshotted weights and zeroes the epoch.

    The Linear layers of a VGG16 are re-initialised so that its parameters
    no longer match a saved snapshot. ``on_train_start`` must then load the
    snapshot back and leave ``trainer.current_epoch`` at 0.
    """

    def reset_fcs(net):
        """Reset all torch.nn.Linear layers."""

        def reset_linear(layer):
            if isinstance(layer, torch.nn.Linear):
                layer.reset_parameters()

        net.apply(reset_linear)

    def matches_snapshot(snapshot, net):
        # True only when every current parameter equals its snapshotted copy.
        return all(torch.eq(before, after).all()
                   for before, after in zip(snapshot, net.parameters()))

    model = vgg16()
    trainer = BaalTrainer(dataset=a_data_module.active_dataset,
                          max_epochs=3,
                          default_root_dir='/tmp')
    # Move the epoch counter away from 0 so the reset is observable below.
    trainer.fit_loop.current_epoch = 10

    weights_snapshot = copy.deepcopy(model.state_dict())
    params_snapshot = copy.deepcopy(list(model.parameters()))
    callback = ResetCallback(weights_snapshot)

    # Perturb the parameters so they diverge from the snapshot.
    reset_fcs(model)
    assert not matches_snapshot(params_snapshot, model)

    callback.on_train_start(trainer, model)
    assert matches_snapshot(params_snapshot, model)
    assert trainer.current_epoch == 0
def main(hparams):
    """Run a BALD active-learning loop on CIFAR10 with a VGG16 model.

    First 10 random samples are labelled. The loop then alternates between
    training the model and asking ``ActiveLearningLoop`` to label
    ``hparams.query_size`` more samples. It stops after 100 steps, or
    earlier if ``loop.step()`` returns a falsy value.

    Args:
        hparams: namespace read for ``data_root``, ``n_gpus`` and
            ``query_size``; it is also handed to ``VGG16``, which may read
            further attributes.
    """
    augment = transforms.Compose([transforms.RandomHorizontalFlip(),
                                  transforms.ToTensor()])
    to_tensor_only = transforms.Compose([transforms.ToTensor()])

    train_split = CIFAR10(hparams.data_root, train=True,
                          transform=augment, download=True)
    # pool_specifics swaps in the non-augmenting transform; presumably it
    # applies to the unlabelled pool -- TODO confirm against baal docs.
    active_set = ActiveLearningDataset(
        train_split, pool_specifics={'transform': to_tensor_only})
    active_set.label_randomly(10)

    heuristic = BALD()
    model = VGG16(active_set, hparams)
    backend = 'dp' if hparams.n_gpus > 1 else None

    # Training mutates the model's weights in place, so keep a deep copy of
    # the initial state for ResetCallback to restore.
    reset_callback = ResetCallback(copy.deepcopy(model.state_dict()))
    trainer = BaalTrainer(max_epochs=3,
                          default_root_dir=hparams.data_root,
                          gpus=hparams.n_gpus,
                          distributed_backend=backend,
                          callbacks=[reset_callback])

    loop = ActiveLearningLoop(
        active_set,
        get_probabilities=trainer.predict_on_dataset_generator,
        heuristic=heuristic,
        ndata_to_label=hparams.query_size,
    )

    max_al_steps = 100
    for step_idx in range(max_al_steps):
        print(f'Step {step_idx} Dataset size {len(active_set)}')
        trainer.fit(model)
        if not loop.step():
            break
def main():
    """Active learning on CIFAR10 with a VGG16 driven by a BaalTrainer.

    The global seed is fixed to 42 and 10 random samples are labelled.
    Each step then trains and tests the model and calls ``trainer.step`` to
    label more samples chosen by the configured heuristic. The loop runs
    for up to 100 steps and stops early when ``trainer.step`` returns a
    falsy value.
    """
    pl.seed_everything(42)
    args = parse_arguments()

    # Build the dataset and seed the labelled set.
    datamodule = Cifar10DataModule(args.data_root, batch_size=args.batch_size)
    datamodule.active_dataset.label_randomly(10)

    # Heuristic used to score uncertainty.
    heuristic = get_heuristic(args.heuristic, shuffle_prop=0.0,
                              reduction='none')
    model = VGG16(**vars(args))

    log_dir = os.path.join('/tmp/', 'logs', 'active')
    logger = TensorBoardLogger(save_dir=log_dir, name='CIFAR10')

    # Training changes the weights in place; keep a deep copy of the
    # initial state so ResetCallback can restore it.
    initial_state = copy.deepcopy(model.state_dict())
    trainer = BaalTrainer.from_argparse_args(
        args,
        callbacks=[ResetCallback(initial_state)],
        dataset=datamodule.active_dataset,
        max_epochs=args.training_duration,
        logger=logger,
        heuristic=heuristic,
        ndata_to_label=args.query_size,
    )

    max_al_steps = 100
    for step_idx in range(max_al_steps):
        print(f'Step {step_idx} Dataset size {len(datamodule.active_dataset)}')
        trainer.fit(model, datamodule=datamodule)   # train on labelled data
        trainer.test(model, datamodule=datamodule)  # report test metrics
        # Label the top-k most uncertain examples; stop when nothing is left.
        if not trainer.step(model, datamodule=datamodule):
            break
def test_reset_callback_resets_weights():
    """ResetCallback.on_train_start loads the snapshotted weights back.

    The Linear layers of a VGG16 are re-initialised so that its parameters
    differ from the snapshot. The callback is then invoked with no trainer,
    and every parameter must match the snapshot again.
    """

    def reset_fcs(net):
        """Reset all torch.nn.Linear layers."""

        def reset_linear(layer):
            if isinstance(layer, torch.nn.Linear):
                layer.reset_parameters()

        net.apply(reset_linear)

    def matches_snapshot(snapshot, net):
        # True only when every current parameter equals its snapshotted copy.
        return all(torch.eq(before, after).all()
                   for before, after in zip(snapshot, net.parameters()))

    model = vgg16()
    weights_snapshot = copy.deepcopy(model.state_dict())
    params_snapshot = copy.deepcopy(list(model.parameters()))
    callback = ResetCallback(weights_snapshot)

    # Perturb the parameters so they diverge from the snapshot.
    reset_fcs(model)
    assert not matches_snapshot(params_snapshot, model)

    callback.on_train_start(None, model)
    assert matches_snapshot(params_snapshot, model)
def test_pl_step():
    """A single ``BaalTrainer.step`` labels exactly ``hparams.query_size`` items."""
    hparams = HParams()
    active_set = ActiveLearningDataset(DummyDataset())
    active_set.label_randomly(10)

    model = DummyPytorchLightning(active_set, hparams)
    checkpoint = model.on_save_checkpoint({})
    reset_callback = ResetCallback(copy.deepcopy(checkpoint))

    trainer = BaalTrainer(dataset=active_set,
                          max_epochs=3,
                          default_root_dir='/tmp',
                          ndata_to_label=hparams.query_size,
                          callbacks=[reset_callback])
    trainer.model = model

    labelled_before = len(active_set)
    trainer.step()
    assert len(active_set) - labelled_before == hparams.query_size
def test_predict():
    """Check that prediction covers the pool and a checkpoint restores labelling.

    ``predict_on_dataset`` must return one entry per pool item. The saved
    checkpoint must carry ``'active_dataset'``, and loading it after more
    samples were labelled must bring the labelled count back.
    """
    hparams = HParams()
    active_set = ActiveLearningDataset(DummyDataset())
    active_set.label_randomly(10)

    model = DummyPytorchLightning(active_set, hparams)
    checkpoint = model.on_save_checkpoint({})
    trainer = BaalTrainer(dataset=active_set,
                          max_epochs=3,
                          default_root_dir='/tmp',
                          callbacks=[ResetCallback(copy.deepcopy(checkpoint))])
    trainer.model = model

    predictions = trainer.predict_on_dataset()
    assert len(predictions) == len(active_set.pool)
    assert 'active_dataset' in checkpoint

    # Label more samples, then reload the earlier checkpoint: the labelled
    # set size must return to what it was when the checkpoint was taken.
    labelled_count = len(active_set)
    checkpoint_copy = copy.deepcopy(checkpoint)
    active_set.label_randomly(5)
    model.on_load_checkpoint(checkpoint_copy)
    assert len(active_set) == labelled_count
'https://download.pytorch.org/models/vgg16-397923af.pth') weights = {k: v for k, v in weights.items() if 'classifier.6' not in k} model.load_state_dict(weights, strict=False) model = PIActiveLearningModel(network=model, active_dataset=active_set, hparams=params) dp = 'dp' if params.gpus > 1 else None trainer = BaalTrainer( max_epochs=params.epochs, default_root_dir=params.data_root, gpus=params.gpus, distributed_backend=dp, # The weights of the model will change as it gets # trained; we need to keep a copy (deepcopy) so that # we can reset them. callbacks=[ResetCallback(copy.deepcopy(model.state_dict()))], dataset=active_set, heuristic=heuristic, ndata_to_label=params.query_size) AL_STEPS = 2000 for al_step in range(AL_STEPS): # TODO fix this trainer.current_epoch = 0 print(f'Step {al_step} Dataset size {len(active_set)}') trainer.fit(model) should_continue = trainer.step() if not should_continue: break