Example #1
import copy

from baal.active import ActiveLearningDataset
from baal.utils.pytorch_lightning import BaalTrainer, ResetCallback


# HParams, DummyDataset and DummyPytorchLightning are test fixtures
# defined in the surrounding test module.
def test_pl_step():
    hparams = HParams()
    dataset = DummyDataset()
    active_set = ActiveLearningDataset(dataset)
    active_set.label_randomly(10)
    model = DummyPytorchLightning(active_set, hparams)
    ckpt = {}
    save_chkp = model.on_save_checkpoint(ckpt)
    trainer = BaalTrainer(dataset=active_set,
                          max_epochs=3, default_root_dir='/tmp',
                          ndata_to_label=hparams.query_size,
                          callbacks=[ResetCallback(copy.deepcopy(save_chkp))])
    trainer.model = model

    before = len(active_set)
    trainer.step()
    after = len(active_set)

    assert after - before == hparams.query_size
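
What trainer.step() does here: predict on the unlabelled pool, rank the pool with the trainer's heuristic, and label the top query_size samples. Below is a minimal hand-rolled sketch of that iteration; the ModelWrapper instance named wrapper and the heuristic are assumptions, but BaalTrainer.step() automates essentially this loop:

# Sketch of one active-learning step done by hand.
predictions = wrapper.predict_on_dataset(active_set.pool,
                                         batch_size=32,
                                         iterations=20,  # MC-Dropout samples
                                         use_cuda=False)
to_label = heuristic(predictions)        # pool indices, most informative first
active_set.label(to_label[:query_size])  # move them into the labelled set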
Example #2
# a_data_module, a_pl_module and hparams are pytest fixtures provided
# by the surrounding test module.
def test_pl_step(monkeypatch, a_data_module, a_pl_module, hparams):
    active_set = a_data_module.active_dataset
    trainer = BaalTrainer(dataset=active_set,
                          max_epochs=3,
                          default_root_dir='/tmp',
                          query_size=hparams['query_size'])
    # Pass both the lightning module and the datamodule.

    before = len(active_set)
    trainer.step(a_pl_module, a_data_module)
    after = len(active_set)

    assert after - before == hparams['query_size']
    # Attach the lightning module to the trainer manually.
    trainer.accelerator.connect(a_pl_module)

    # Pass only the model.
    before = len(active_set)
    trainer.step(a_pl_module)
    after = len(active_set)

    assert after - before == hparams['query_size']

    # Pass neither the model nor a dataloader.
    before = len(active_set)
    trainer.step()
    after = len(active_set)

    assert after - before == hparams['query_size']
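
The pytest fixtures used above (a_data_module, a_pl_module, hparams) are defined elsewhere in the test suite. A plausible minimal sketch of the hparams fixture, for context; the value is hypothetical, and only 'query_size' is actually read by the test:

import pytest

@pytest.fixture
def hparams():
    # Hypothetical fixture; only 'query_size' is read by the test.
    return {'query_size': 5}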
Example #3
import copy

from torchvision import transforms
from torchvision.datasets import CIFAR10

from baal.active import ActiveLearningDataset
from baal.active.heuristics import BALD
from baal.utils.pytorch_lightning import BaalTrainer, ResetCallback


# VGG16 here is the experiment's own LightningModule wrapper
# (defined elsewhere), not torchvision's vgg16.
def main(hparams):
    train_transform = transforms.Compose(
        [transforms.RandomHorizontalFlip(),
         transforms.ToTensor()])
    test_transform = transforms.Compose([transforms.ToTensor()])

    active_set = ActiveLearningDataset(
        CIFAR10(hparams.data_root,
                train=True,
                transform=train_transform,
                download=True),
        pool_specifics={'transform': test_transform})
    active_set.label_randomly(10)
    heuristic = BALD()
    model = VGG16(active_set, hparams)
    dp = 'dp' if hparams.n_gpus > 1 else None
    trainer = BaalTrainer(
        max_epochs=3,
        default_root_dir=hparams.data_root,
        gpus=hparams.n_gpus,
        distributed_backend=dp,
        # The weights of the model will change as it gets
        # trained; we need to keep a copy (deepcopy) so that
        # we can reset them.
        callbacks=[ResetCallback(copy.deepcopy(model.state_dict()))],
        dataset=active_set,
        heuristic=heuristic,
        ndata_to_label=hparams.query_size)

    AL_STEPS = 100
    for al_step in range(AL_STEPS):
        # TODO Issue 95 Make PL trainer epoch self-aware
        trainer.current_epoch = 0
        print(f'Step {al_step} Dataset size {len(active_set)}')
        trainer.fit(model)
        should_continue = trainer.step()
        if not should_continue:
            break
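
main() reads hparams.data_root, hparams.n_gpus and hparams.query_size, so it presumably pairs with argument parsing along these lines; the flag names and defaults below are assumptions:

import argparse

def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_root', type=str, default='/tmp')
    parser.add_argument('--n_gpus', type=int, default=0)
    parser.add_argument('--query_size', type=int, default=100)
    return parser.parse_args()

if __name__ == '__main__':
    main(parse_arguments())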
Example #4
    # (Snippet starts mid-function.) Download pretrained VGG16 weights;
    # presumably via torch.hub.load_state_dict_from_url.
    weights = load_state_dict_from_url(
        'https://download.pytorch.org/models/vgg16-397923af.pth')
    weights = {k: v for k, v in weights.items() if 'classifier.6' not in k}
    model.load_state_dict(weights, strict=False)
    model = PIActiveLearningModel(network=model,
                                  active_dataset=active_set,
                                  hparams=params)

    dp = 'dp' if params.gpus > 1 else None
    trainer = BaalTrainer(
        max_epochs=params.epochs,
        default_root_dir=params.data_root,
        gpus=params.gpus,
        distributed_backend=dp,
        # The weights of the model will change as it gets
        # trained; we need to keep a copy (deepcopy) so that
        # we can reset them.
        callbacks=[ResetCallback(copy.deepcopy(model.state_dict()))],
        dataset=active_set,
        heuristic=heuristic,
        ndata_to_label=params.query_size)

    AL_STEPS = 2000
    for al_step in range(AL_STEPS):
        # TODO Make the PL trainer epoch self-aware (same issue as example #3)
        trainer.current_epoch = 0
        print(f'Step {al_step} Dataset size {len(active_set)}')
        trainer.fit(model)
        should_continue = trainer.step()
        if not should_continue:
            break
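
Both training loops rely on ResetCallback to restore the saved weights before each fit, so every active-learning step retrains from the same initialization. A minimal sketch of such a callback, assuming pytorch-lightning's standard Callback hooks; baal ships its own implementation, and this version is illustrative only:

from pytorch_lightning import Callback

class ResetCallback(Callback):
    def __init__(self, weights):
        self.weights = weights  # deepcopy of the initial state_dict

    def on_train_start(self, trainer, pl_module):
        # Reload the initial weights so training restarts from scratch.
        pl_module.load_state_dict(self.weights)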