def main(hparams):
    """Run a BALD active-learning experiment with VGG16 on CIFAR10.

    Starts from 10 randomly labelled CIFAR10 training images, then performs
    at most 100 active-learning steps. Each step trains the model (up to
    ``max_epochs=3``) and then calls ``ActiveLearningLoop.step()``, which is
    configured to label ``hparams.query_size`` more samples chosen with the
    BALD heuristic. The loop stops early as soon as ``loop.step()`` returns a
    falsy value.

    Args:
        hparams: Experiment configuration. This function reads ``data_root``
            (CIFAR10 download location and the trainer's
            ``default_root_dir``), ``n_gpus`` and ``query_size``; the whole
            object is also passed to ``VGG16``.
    """
    train_transform = transforms.Compose([transforms.RandomHorizontalFlip(),
                                          transforms.ToTensor()])
    test_transform = transforms.Compose([transforms.ToTensor()])
    active_set = ActiveLearningDataset(
        CIFAR10(hparams.data_root, train=True, transform=train_transform,
                download=True),
        # Overrides applied to the unlabelled pool: swap in the
        # non-augmenting transform (presumably so pool predictions are not
        # perturbed by random flips).
        pool_specifics={
            'transform': test_transform
        })
    active_set.label_randomly(10)
    heuristic = BALD()
    model = VGG16(active_set, hparams)
    # Request the 'dp' backend only when more than one GPU is used.
    dp = 'dp' if hparams.n_gpus > 1 else None
    trainer = BaalTrainer(max_epochs=3, default_root_dir=hparams.data_root,
                          gpus=hparams.n_gpus, distributed_backend=dp,
                          # The weights of the model will change as it gets
                          # trained; we need to keep a copy (deepcopy) so that
                          # we can reset them.
                          callbacks=[ResetCallback(copy.deepcopy(model.state_dict()))])
    loop = ActiveLearningLoop(active_set,
                              get_probabilities=trainer.predict_on_dataset_generator,
                              heuristic=heuristic,
                              ndata_to_label=hparams.query_size)

    AL_STEPS = 100
    for al_step in range(AL_STEPS):
        # The same Trainer is reused for every step and keeps its epoch
        # counter between fit() calls; without this reset every fit() after
        # the first resumes at the last epoch and trains little or nothing.
        # NOTE(review): requires `current_epoch` to be writable in the
        # installed Lightning version (true for the `distributed_backend` era).
        trainer.current_epoch = 0
        print(f'Step {al_step} Dataset size {len(active_set)}')
        trainer.fit(model)
        should_continue = loop.step()
        if not should_continue:
            break
'https://download.pytorch.org/models/vgg16-397923af.pth') weights = {k: v for k, v in weights.items() if 'classifier.6' not in k} model.load_state_dict(weights, strict=False) model = PIActiveLearningModel(network=model, active_dataset=active_set, hparams=params) dp = 'dp' if params.gpus > 1 else None trainer = BaalTrainer( max_epochs=params.epochs, default_root_dir=params.data_root, gpus=params.gpus, distributed_backend=dp, # The weights of the model will change as it gets # trained; we need to keep a copy (deepcopy) so that # we can reset them. callbacks=[ResetCallback(copy.deepcopy(model.state_dict()))], dataset=active_set, heuristic=heuristic, ndata_to_label=params.query_size) AL_STEPS = 2000 for al_step in range(AL_STEPS): # TODO fix this trainer.current_epoch = 0 print(f'Step {al_step} Dataset size {len(active_set)}') trainer.fit(model) should_continue = trainer.step() if not should_continue: break