Example #1
def main(hparams):
    train_transform = transforms.Compose([transforms.RandomHorizontalFlip(),
                                          transforms.ToTensor()])
    test_transform = transforms.Compose([transforms.ToTensor()])

    active_set = ActiveLearningDataset(
        CIFAR10(hparams.data_root, train=True, transform=train_transform, download=True),
        pool_specifics={
            'transform': test_transform
        })
    active_set.label_randomly(10)
    heuristic = BALD()
    model = VGG16(active_set, hparams)
    dp = 'dp' if hparams.n_gpus > 1 else None
    trainer = BaalTrainer(max_epochs=3, default_root_dir=hparams.data_root,
                          gpus=hparams.n_gpus, distributed_backend=dp,
                          # The weights of the model will change as it gets
                          # trained; we need to keep a copy (deepcopy) so that
                          # we can reset them.
                          callbacks=[ResetCallback(copy.deepcopy(model.state_dict()))])
    loop = ActiveLearningLoop(active_set, get_probabilities=trainer.predict_on_dataset_generator,
                              heuristic=heuristic,
                              ndata_to_label=hparams.query_size)

    AL_STEPS = 100
    for al_step in range(AL_STEPS):
        print(f'Step {al_step} Dataset size {len(active_set)}')
        trainer.fit(model)
        should_continue = loop.step()
        if not should_continue:
            break
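These excerpts omit their imports. A sketch of what Example #1 assumes, for baal 1.x and torchvision (VGG16 is a user-defined LightningModule that is not shown here), plus a hypothetical entry point whose flag names mirror the attributes main reads:

import argparse
import copy

from torchvision import transforms
from torchvision.datasets import CIFAR10

from baal.active import ActiveLearningDataset
from baal.active.active_loop import ActiveLearningLoop
from baal.active.heuristics import BALD
from baal.utils.pytorch_lightning import BaalTrainer, ResetCallback


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_root', default='/tmp')
    parser.add_argument('--n_gpus', type=int, default=0)
    parser.add_argument('--query_size', type=int, default=100)
    main(parser.parse_args())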
Example #2
def test_reset_callback_resets_weights(a_data_module):
    def reset_fcs(model):
        """Reset all torch.nn.Linear layers."""
        def reset(m):
            if isinstance(m, torch.nn.Linear):
                m.reset_parameters()

        model.apply(reset)

    model = vgg16()
    trainer = BaalTrainer(dataset=a_data_module.active_dataset,
                          max_epochs=3,
                          default_root_dir='/tmp')
    trainer.current_epoch = 10
    initial_weights = copy.deepcopy(model.state_dict())
    initial_params = copy.deepcopy(list(model.parameters()))
    callback = ResetCallback(initial_weights)
    # Modify the params
    reset_fcs(model)
    new_params = model.parameters()
    assert not all(
        torch.eq(p1, p2).all() for p1, p2 in zip(initial_params, new_params))
    # on_train_start restores the saved weights and resets the epoch counter.
    callback.on_train_start(trainer, model)
    new_params = model.parameters()
    assert all(
        torch.eq(p1, p2).all() for p1, p2 in zip(initial_params, new_params))
    assert trainer.current_epoch == 0
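The mechanism under test is plain torch: ResetCallback keeps a deep copy of the state_dict and loads it back when training starts. A self-contained sketch of that round trip (names here are illustrative):

import copy

import torch

layer = torch.nn.Linear(4, 2)
snapshot = copy.deepcopy(layer.state_dict())  # copy taken before training
layer.reset_parameters()                      # stand-in for weights drifting during training
layer.load_state_dict(snapshot)               # what ResetCallback does on train start
assert all(torch.equal(v, snapshot[k]) for k, v in layer.state_dict().items())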
Example #3
def test_predict(a_data_module, a_pl_module):
    trainer = BaalTrainer(dataset=a_data_module.active_dataset,
                          max_epochs=3,
                          default_root_dir='/tmp')
    active_set = a_data_module.active_dataset
    alt = trainer.predict_on_dataset(a_pl_module,
                                     a_data_module.pool_dataloader())
    assert len(alt) == len(active_set.pool)

    # replicate_in_memory=False works as well.
    a_pl_module.hparams.replicate_in_memory = False
    alt = trainer.predict_on_dataset(a_pl_module,
                                     a_data_module.pool_dataloader())
    assert len(alt) == len(active_set.pool)
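predict_on_dataset returns one row per pool item, with the MC-Dropout iterations stacked on the last axis; replicate_in_memory only changes how the stochastic passes are computed (one replicated batch versus a Python loop), not the output shape. A rough sketch of the looped variant in plain torch (assumed, not baal's exact code):

import torch

def mc_predict(model, batch, iterations=20):
    model.train()  # keep dropout stochastic at prediction time
    with torch.no_grad():
        # [batch, classes, iterations]: one slice per stochastic forward pass.
        return torch.stack([model(batch) for _ in range(iterations)], dim=-1)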
Example #4
def test_pl_step(monkeypatch, a_data_module, a_pl_module, hparams):
    active_set = a_data_module.active_dataset
    trainer = BaalTrainer(dataset=active_set,
                          max_epochs=3,
                          default_root_dir='/tmp',
                          query_size=hparams['query_size'])
    # Give everything: both the model and the datamodule.

    before = len(active_set)
    trainer.step(a_pl_module, a_data_module)
    after = len(active_set)

    assert after - before == hparams['query_size']
    # Add the lightning_module manually.
    trainer.accelerator.connect(a_pl_module)

    # Only give the model
    before = len(active_set)
    trainer.step(a_pl_module)
    after = len(active_set)

    assert after - before == hparams['query_size']

    # No model, no dataloader: step() falls back to the connected module and the trainer's dataset.
    before = len(active_set)
    trainer.step()
    after = len(active_set)

    assert after - before == hparams['query_size']
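Under the hood, each step predicts on the pool, ranks the predictions with the heuristic, and labels the top query_size items. The ranking half in isolation (a sketch using fake predictions in baal's [n_pool, n_classes, n_iterations] layout):

import numpy as np

from baal.active.heuristics import BALD

predictions = np.random.rand(100, 10, 20)  # fake MC-Dropout output for a 100-item pool
ranks = BALD()(predictions)                # pool indices, most uncertain first
to_label = ranks[:10]                      # what active_set.label(...) would receive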
Example #5
def main():
    pl.seed_everything(42)
    args = parse_arguments()
    # Create our dataset.
    datamodule = Cifar10DataModule(args.data_root, batch_size=args.batch_size)
    datamodule.active_dataset.label_randomly(10)
    # Get our heuristic to compute uncertainty.
    heuristic = get_heuristic(args.heuristic, shuffle_prop=0.0, reduction='none')
    model = VGG16(**vars(args))  # Instantiate VGG16

    # Make our PL Trainer
    logger = TensorBoardLogger(save_dir=os.path.join('/tmp/', 'logs', 'active'), name='CIFAR10')
    trainer = BaalTrainer.from_argparse_args(args,
                                             # The weights of the model will change as it gets
                                             # trained; we need to keep a copy (deepcopy) so that
                                             # we can reset them.
                                             callbacks=[ResetCallback(copy.deepcopy(model.state_dict()))],
                                             dataset=datamodule.active_dataset,
                                             max_epochs=args.training_duration,
                                             logger=logger,
                                             heuristic=heuristic,
                                             ndata_to_label=args.query_size
                                             )

    AL_STEPS = 100
    for al_step in range(AL_STEPS):
        print(f'Step {al_step} Dataset size {len(datamodule.active_dataset)}')
        trainer.fit(model, datamodule=datamodule)  # Train the model on the labelled set.
        trainer.test(model, datamodule=datamodule)  # Get test performance.
        should_continue = trainer.step(model, datamodule=datamodule)  # Label the top-k most uncertain examples.
        if not should_continue:
            break
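parse_arguments is not shown; one plausible shape, assuming it extends pl.Trainer.add_argparse_args with the flags this script reads (names mirror the attributes used above, defaults are illustrative):

import argparse

import pytorch_lightning as pl


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    parser.add_argument('--data_root', default='/tmp')
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--heuristic', default='bald')
    parser.add_argument('--training_duration', type=int, default=30)
    parser.add_argument('--query_size', type=int, default=100)
    return parser.parse_args()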
Example #6
def test_pl_step():
    hparams = HParams()
    dataset = DummyDataset()
    active_set = ActiveLearningDataset(dataset)
    active_set.label_randomly(10)
    model = DummyPytorchLightning(active_set, hparams)
    ckpt = {}
    save_chkp = model.on_save_checkpoint(ckpt)
    trainer = BaalTrainer(dataset=active_set,
                          max_epochs=3, default_root_dir='/tmp',
                          ndata_to_label=hparams.query_size,
                          callbacks=[ResetCallback(copy.deepcopy(save_chkp))])
    trainer.model = model

    before = len(active_set)
    trainer.step()
    after = len(active_set)

    assert after - before == hparams.query_size
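HParams, DummyDataset and DummyPytorchLightning are test fixtures that are not shown. A plausible minimal shape for the first two (assumed):

from dataclasses import dataclass

import torch
from torch.utils.data import Dataset


@dataclass
class HParams:
    query_size: int = 5
    batch_size: int = 8


class DummyDataset(Dataset):
    def __len__(self):
        return 100

    def __getitem__(self, idx):
        # A fake CIFAR-like sample: random image, cyclic label.
        return torch.randn(3, 32, 32), idx % 10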
Example #7
def test_predict():
    ckpt = {}
    hparams = HParams()
    dataset = DummyDataset()
    active_set = ActiveLearningDataset(dataset)
    active_set.label_randomly(10)
    model = DummyPytorchLightning(active_set, hparams)
    save_chkp = model.on_save_checkpoint(ckpt)
    trainer = BaalTrainer(dataset=active_set,
                          max_epochs=3, default_root_dir='/tmp',
                          callbacks=[ResetCallback(copy.deepcopy(save_chkp))])
    trainer.model = model
    alt = trainer.predict_on_dataset()
    assert len(alt) == len(active_set.pool)
    assert 'active_dataset' in save_chkp
    n_labelled = len(active_set)
    copy_save_chkp = copy.deepcopy(save_chkp)
    active_set.label_randomly(5)

    model.on_load_checkpoint(copy_save_chkp)
    assert len(active_set) == n_labelled
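The 'active_dataset' key suggests the fixture's checkpoint hooks round-trip the labelling state, which is what the last two assertions verify. A plausible implementation, assuming baal's ActiveLearningDataset exposes a state_dict()/load_state_dict() pair:

import pytorch_lightning as pl


class DummyPytorchLightning(pl.LightningModule):
    # Sketch of the two relevant hooks only.
    def on_save_checkpoint(self, checkpoint):
        checkpoint['active_dataset'] = self.active_dataset.state_dict()
        return checkpoint

    def on_load_checkpoint(self, checkpoint):
        self.active_dataset.load_state_dict(checkpoint['active_dataset'])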
Example #8
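This excerpt starts mid-function: active_set, heuristic and params are defined earlier in the script it was taken from. A hypothetical preamble, only to make the fragment readable (initial_pool is an assumed hyperparameter):

from torch.hub import load_state_dict_from_url
from torchvision.models import vgg16

heuristic = BALD()
active_set = ActiveLearningDataset(
    CIFAR10(params.data_root, train=True, transform=transforms.ToTensor(), download=True))
active_set.label_randomly(params.initial_pool)  # assumed field on params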
    model = vgg16(pretrained=False, num_classes=10)
    weights = load_state_dict_from_url(
        'https://download.pytorch.org/models/vgg16-397923af.pth')
    weights = {k: v for k, v in weights.items() if 'classifier.6' not in k}
    model.load_state_dict(weights, strict=False)
    model = PIActiveLearningModel(network=model,
                                  active_dataset=active_set,
                                  hparams=params)

    dp = 'dp' if params.gpus > 1 else None
    trainer = BaalTrainer(
        max_epochs=params.epochs,
        default_root_dir=params.data_root,
        gpus=params.gpus,
        distributed_backend=dp,
        # The weights of the model will change as it gets
        # trained; we need to keep a copy (deepcopy) so that
        # we can reset them.
        callbacks=[ResetCallback(copy.deepcopy(model.state_dict()))],
        dataset=active_set,
        heuristic=heuristic,
        ndata_to_label=params.query_size)

    AL_STEPS = 2000
    for al_step in range(AL_STEPS):
        # TODO fix this
        trainer.current_epoch = 0
        print(f'Step {al_step} Dataset size {len(active_set)}')
        trainer.fit(model)
        should_continue = trainer.step()
        if not should_continue:
            break