Example #1
# Imports assumed for this example; the baal module paths below match baal's
# PyTorch Lightning integration but may differ slightly between versions.
import copy

from torchvision import transforms
from torchvision.datasets import CIFAR10

from baal.active import ActiveLearningDataset
from baal.active.active_loop import ActiveLearningLoop
from baal.active.heuristics import BALD
from baal.utils.pytorch_lightning import BaalTrainer, ResetCallback

# VGG16 is assumed to be a LightningModule defined elsewhere in the original
# example script; it is not imported from a library here.


def main(hparams):
    train_transform = transforms.Compose([transforms.RandomHorizontalFlip(),
                                          transforms.ToTensor()])
    test_transform = transforms.Compose([transforms.ToTensor()])

    active_set = ActiveLearningDataset(
        CIFAR10(hparams.data_root, train=True, transform=train_transform, download=True),
        pool_specifics={
            'transform': test_transform
        })
    active_set.label_randomly(10)
    heuristic = BALD()
    model = VGG16(active_set, hparams)
    dp = 'dp' if hparams.n_gpus > 1 else None
    trainer = BaalTrainer(max_epochs=3, default_root_dir=hparams.data_root,
                          gpus=hparams.n_gpus, distributed_backend=dp,
                          # The weights of the model will change as it gets
                          # trained; we need to keep a copy (deepcopy) so that
                          # we can reset them.
                          callbacks=[ResetCallback(copy.deepcopy(model.state_dict()))])
    loop = ActiveLearningLoop(active_set, get_probabilities=trainer.predict_on_dataset_generator,
                              heuristic=heuristic,
                              ndata_to_label=hparams.query_size)

    AL_STEPS = 100
    for al_step in range(AL_STEPS):
        print(f'Step {al_step} Dataset size {len(active_set)}')
        trainer.fit(model)
        should_continue = loop.step()
        if not should_continue:
            break
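
The hparams object passed to main only needs the attributes that the snippet reads: data_root, n_gpus and query_size. Below is a minimal sketch of how it could be built with argparse; the flag names and defaults are assumptions for illustration, not part of the original example.

import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Hypothetical flags matching the attributes read inside main() above.
    parser.add_argument('--data_root', default='/tmp/cifar10')
    parser.add_argument('--n_gpus', type=int, default=0)
    parser.add_argument('--query_size', type=int, default=100)
    main(parser.parse_args())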
Example #2
    # The start of this example is truncated: the call that assigns `weights`
    # is cut off. The URL below is presumably passed to a checkpoint-loading
    # helper such as torch.hub.load_state_dict_from_url (assumption; the
    # original call is not shown).
        'https://download.pytorch.org/models/vgg16-397923af.pth')
    weights = {k: v for k, v in weights.items() if 'classifier.6' not in k}
    model.load_state_dict(weights, strict=False)
    model = PIActiveLearningModel(network=model,
                                  active_dataset=active_set,
                                  hparams=params)

    dp = 'dp' if params.gpus > 1 else None
    trainer = BaalTrainer(
        max_epochs=params.epochs,
        default_root_dir=params.data_root,
        gpus=params.gpus,
        distributed_backend=dp,
        # The weights of the model will change as it gets
        # trained; we need to keep a copy (deepcopy) so that
        # we can reset them.
        callbacks=[ResetCallback(copy.deepcopy(model.state_dict()))],
        dataset=active_set,
        heuristic=heuristic,
        ndata_to_label=params.query_size)

    AL_STEPS = 2000
    for al_step in range(AL_STEPS):
        # TODO fix this
        # Reset the Lightning epoch counter so each active-learning step
        # retrains for the full number of epochs instead of exiting because
        # max_epochs was already reached in a previous step.
        trainer.current_epoch = 0
        print(f'Step {al_step} Dataset size {len(active_set)}')
        trainer.fit(model)
        should_continue = trainer.step()
        if not should_continue:
            break
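
Both examples pass heuristic=BALD() to drive the labelling step. Here is a minimal sketch of what the heuristic does on its own, assuming baal's heuristics accept Monte Carlo predictions shaped [n_samples, n_classes, n_iterations] and expose get_uncertainties; the array below is random data for illustration only.

import numpy as np
from baal.active.heuristics import BALD

# Fake MC-dropout predictions: 5 pool samples, 10 classes, 20 stochastic draws.
predictions = np.random.rand(5, 10, 20)
# Normalise over the class axis so each draw is a probability distribution.
predictions /= predictions.sum(axis=1, keepdims=True)

heuristic = BALD()
ranks = heuristic(predictions)                     # indices, most uncertain first
scores = heuristic.get_uncertainties(predictions)  # raw BALD score per sample
print(ranks[:3], scores[:3])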