Exemplo n.º 1
0
def test_reset_callback_resets_weights(a_data_module):
    """Check that ResetCallback restores the snapshot weights and rewinds the epoch.

    The model's linear layers are re-initialised after a snapshot is taken.
    ``on_train_start`` must then bring every parameter back to the snapshot
    and set ``trainer.current_epoch`` back to 0.
    """

    def _params_match(reference, current):
        # True only if every parameter tensor equals its reference element-wise.
        return all(torch.eq(ref, cur).all() for ref, cur in zip(reference, current))

    def _reinit_linear_layers(net):
        """Re-initialise every torch.nn.Linear layer of *net* in place."""
        def _visit(module):
            if isinstance(module, torch.nn.Linear):
                module.reset_parameters()

        net.apply(_visit)

    model = vgg16()
    trainer = BaalTrainer(dataset=a_data_module.active_dataset,
                          max_epochs=3,
                          default_root_dir='/tmp')
    # Pretend training already advanced so the epoch rewind is observable.
    trainer.current_epoch = 10

    # Independent copies: later in-place changes to the model must not leak in.
    snapshot_state = copy.deepcopy(model.state_dict())
    snapshot_params = copy.deepcopy(list(model.parameters()))
    callback = ResetCallback(snapshot_state)

    # After re-initialising the linear layers, the weights must differ from the snapshot...
    _reinit_linear_layers(model)
    assert not _params_match(snapshot_params, model.parameters())

    # ...and the callback must restore them and reset the epoch counter.
    callback.on_train_start(trainer, model)
    assert _params_match(snapshot_params, model.parameters())
    assert trainer.current_epoch == 0
Exemplo n.º 2
0
def main(hparams):
    """Run an active-learning loop with BALD and a VGG16 model on CIFAR-10.

    Args:
        hparams: configuration object. This function reads its ``data_root``,
            ``n_gpus`` and ``query_size`` attributes and also passes it
            unchanged to ``VGG16``.
    """
    # The training data gets a random horizontal flip. The pool is given
    # the ToTensor-only transform through ``pool_specifics``.
    augment = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    plain = transforms.Compose([transforms.ToTensor()])

    cifar_train = CIFAR10(hparams.data_root,
                          train=True,
                          transform=augment,
                          download=True)
    active_set = ActiveLearningDataset(cifar_train,
                                       pool_specifics={'transform': plain})
    # Presumably labels 10 random samples to seed the loop.
    active_set.label_randomly(10)

    heuristic = BALD()
    model = VGG16(active_set, hparams)

    # Request the 'dp' backend only when more than one GPU is configured.
    backend = None
    if hparams.n_gpus > 1:
        backend = 'dp'

    # Training mutates the model's weights in place, so the reset target
    # must be a deep copy taken before any training happens.
    reset_cb = ResetCallback(copy.deepcopy(model.state_dict()))

    trainer = BaalTrainer(
        max_epochs=3,
        default_root_dir=hparams.data_root,
        gpus=hparams.n_gpus,
        distributed_backend=backend,
        callbacks=[reset_cb],
        dataset=active_set,
        heuristic=heuristic,
        ndata_to_label=hparams.query_size)

    max_steps = 100
    for step_idx in range(max_steps):
        # TODO Issue 95 Make PL trainer epoch self-aware
        trainer.current_epoch = 0
        print(f'Step {step_idx} Dataset size {len(active_set)}')
        trainer.fit(model)
        # A falsy return from step() ends the loop early.
        if not trainer.step():
            break
Exemplo n.º 3
0
        'https://download.pytorch.org/models/vgg16-397923af.pth')
    # Drop every entry whose key contains 'classifier.6' (presumably the
    # final classification layer, so it keeps its fresh initialisation for
    # this task's classes — TODO confirm).
    weights = {k: v for k, v in weights.items() if 'classifier.6' not in k}
    # strict=False: the filtered-out keys are now missing from `weights`.
    model.load_state_dict(weights, strict=False)
    # Wrap the partially pretrained network for active learning.
    model = PIActiveLearningModel(network=model,
                                  active_dataset=active_set,
                                  hparams=params)

    # Request the 'dp' backend only when more than one GPU is configured.
    dp = 'dp' if params.gpus > 1 else None
    trainer = BaalTrainer(
        max_epochs=params.epochs,
        default_root_dir=params.data_root,
        gpus=params.gpus,
        distributed_backend=dp,
        # The weights of the model will change as it gets
        # trained; we need to keep a copy (deepcopy) so that
        # we can reset them. The snapshot is taken after the pretrained
        # weights were loaded above, so resets return to those weights.
        callbacks=[ResetCallback(copy.deepcopy(model.state_dict()))],
        dataset=active_set,
        heuristic=heuristic,
        ndata_to_label=params.query_size)

    # Upper bound on the number of active-learning iterations.
    AL_STEPS = 2000
    for al_step in range(AL_STEPS):
        # TODO fix this: the epoch counter is rewound by hand before every
        # fit (presumably so each fit() trains for max_epochs again — confirm).
        trainer.current_epoch = 0
        print(f'Step {al_step} Dataset size {len(active_set)}')
        trainer.fit(model)
        # A falsy return from step() ends the loop early.
        should_continue = trainer.step()
        if not should_continue:
            break