Пример #1
0
def _test_searchspace_on_dataset(searchspace, dataset='cifar10', arch=None):
    """Build one architecture from ``searchspace`` and run a tiny training job on fake data.

    This is a smoke test. It instantiates a single concrete model from the search
    space and runs one truncated epoch of classification training/validation on
    ``FakeData`` images. It only checks that the model can be built and run end
    to end.

    Args:
        searchspace: PyTorch module defining the search space; passed to
            ``extract_mutation_from_pt_module``.
        dataset: Fake-data shape to use: ``'cifar10'`` (3x32x32 images, 10
            classes) or ``'imagenet'`` (3x224x224 images, 1000 classes).
        arch: Optional mapping from mutator label to chosen value. When ``None``,
            an architecture is sampled via ``try_mutation_until_success`` (called
            with ``10``, presumably a retry limit -- TODO confirm) and the chosen
            values are collected into this mapping.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    # Per-dataset fake-data shape: (image_size, num_classes).
    dataset_specs = {
        'cifar10': ((3, 32, 32), 10),
        'imagenet': ((3, 224, 224), 1000),
    }
    # Validate before doing any work. Otherwise an unknown name would skip every
    # dataset branch and fail later with an UnboundLocalError.
    if dataset not in dataset_specs:
        raise ValueError(
            f'Unsupported dataset {dataset!r}; expected one of {sorted(dataset_specs)}'
        )
    image_size, num_classes = dataset_specs[dataset]

    def make_fake_data():
        # Train and validation splits are built with identical arguments.
        return FakeData(size=200,
                        image_size=image_size,
                        num_classes=num_classes,
                        transform=transforms.ToTensor())

    _reset()
    try:
        model, mutators = extract_mutation_from_pt_module(searchspace)

        if arch is None:
            model = try_mutation_until_success(model, mutators, 10)
            # Record the value chosen by each mutator so the same architecture
            # can be re-instantiated under ``fixed_arch`` below.
            arch = {
                mut.mutator.label: _unpack_if_only_one(mut.samples)
                for mut in model.history
            }

        print('Selected model:', arch)
        with fixed_arch(arch):
            model = model.python_class(**model.python_init_params)

        train_dataloader = pl.DataLoader(make_fake_data(), batch_size=4, shuffle=True)
        valid_dataloader = pl.DataLoader(make_fake_data(), batch_size=6)

        # Keep the run tiny: one epoch, capped at a few batches each.
        evaluator = pl.Classification(
            train_dataloader=train_dataloader,
            val_dataloaders=valid_dataloader,
            export_onnx=False,
            max_epochs=1,
            limit_train_batches=2,
            limit_val_batches=3,
        )
        evaluator.fit(model)
    finally:
        # Cleanup to avoid affecting later test cases. It sits in ``finally`` so
        # it also runs when sampling, construction or training raises.
        _reset()
Пример #2
0
 def _get_model_with_mutators(self, pytorch_model):
     """Convert ``pytorch_model`` using ``extract_mutation_from_pt_module``.

     The helper's result is returned unchanged. Elsewhere it is unpacked as a
     ``(model, mutators)`` pair.
     """
     converted = extract_mutation_from_pt_module(pytorch_model)
     return converted