コード例 #1
0
 def lightning():
     """Build a ``pl.Classification`` evaluator for one shortened training epoch.

     ``train_dataset``, ``test_dataset`` and ``progress_bar_refresh_rate`` are
     not parameters: they are looked up in the enclosing scope. Both data
     loaders use ``batch_size=100``.
     """
     train_loader = pl.DataLoader(train_dataset, batch_size=100)
     val_loader = pl.DataLoader(test_dataset, batch_size=100)
     # NOTE(review): `progress_bar_refresh_rate` is a deprecated Lightning Trainer
     # flag (later versions use `enable_progress_bar`, as another example in this
     # file does) -- confirm it is accepted by the pinned Lightning version.
     return pl.Classification(
         train_dataloader=train_loader,
         val_dataloaders=val_loader,
         max_epochs=1,
         limit_train_batches=0.1,  # a fraction of the training batches, for faster training
         progress_bar_refresh_rate=progress_bar_refresh_rate,
     )
コード例 #2
0
def test_mnist():
    """Train ``MNISTModel`` on MNIST through a ``pl.Classification`` evaluator.

    The run is kept short (2 epochs, 25% of the training batches). The test
    passes when the value returned by ``_get_final_result()`` exceeds 0.7
    (presumably the final validation metric -- confirm against that helper).

    ``_reset()`` is called before the run and, through ``finally``, after it.
    The cleanup therefore also happens when training raises or the assertion
    fails, so no leftover state can affect later test cases.
    """
    _reset()
    try:
        # NOTE(review): the Normalize constants are presumably the usual MNIST
        # mean/std -- confirm.
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
        train_dataset = nni.trace(MNIST)(root='data/mnist', train=True, download=True, transform=transform)
        test_dataset = nni.trace(MNIST)(root='data/mnist', train=False, download=True, transform=transform)
        lightning = pl.Classification(train_dataloader=pl.DataLoader(train_dataset, batch_size=100),
                                      val_dataloaders=pl.DataLoader(test_dataset, batch_size=100),
                                      max_epochs=2, limit_train_batches=0.25,  # for faster training
                                      progress_bar_refresh_rate=progress_bar_refresh_rate)
        lightning._execute(MNISTModel)
        final_result = _get_final_result()
        assert final_result > 0.7, 'final result {} is not above 0.7'.format(final_result)
    finally:
        # Always clean up, even on failure, so later tests start from a clean state.
        _reset()
コード例 #3
0
def _test_searchspace_on_dataset(searchspace, dataset='cifar10', arch=None):
    """Instantiate one architecture from *searchspace* and train it briefly on fake data.

    Args:
        searchspace: the search-space module passed to
            ``extract_mutation_from_pt_module``.
        dataset: ``'cifar10'`` or ``'imagenet'``. It selects the image size and
            number of classes of the ``FakeData`` used in place of the real
            dataset.
        arch: mapping from mutator label to chosen value. When ``None``, a
            mutated model is obtained with
            ``try_mutation_until_success(model, mutators, 10)`` and the mapping
            is read back from that model's ``history``.

    Raises:
        ValueError: if *dataset* is not one of the supported names.

    ``_reset()`` runs before the run and, through ``finally``, after it. A
    failing search space therefore does not leave state behind for later tests.
    """
    # (image_size, num_classes) of the fake data standing in for each dataset.
    fake_data_specs = {
        'cifar10': ((3, 32, 32), 10),
        'imagenet': ((3, 224, 224), 1000),
    }
    # Validate before doing any work: an unknown name used to leave train_data
    # unbound and fail later with an UnboundLocalError.
    if dataset not in fake_data_specs:
        raise ValueError('Unsupported dataset {!r}; expected one of {}'.format(
            dataset, sorted(fake_data_specs)))
    image_size, num_classes = fake_data_specs[dataset]

    _reset()
    try:
        model, mutators = extract_mutation_from_pt_module(searchspace)

        if arch is None:
            # NOTE(review): 10 is presumably the maximum number of attempts -- confirm.
            model = try_mutation_until_success(model, mutators, 10)
            arch = {
                mut.mutator.label: _unpack_if_only_one(mut.samples)
                for mut in model.history
            }

        print('Selected model:', arch)
        with fixed_arch(arch):
            model = model.python_class(**model.python_init_params)

        train_data = FakeData(size=200,
                              image_size=image_size,
                              num_classes=num_classes,
                              transform=transforms.ToTensor())
        valid_data = FakeData(size=200,
                              image_size=image_size,
                              num_classes=num_classes,
                              transform=transforms.ToTensor())

        train_dataloader = pl.DataLoader(train_data, batch_size=4, shuffle=True)
        valid_dataloader = pl.DataLoader(valid_data, batch_size=6)

        # Keep the run tiny: one epoch, 2 training / 3 validation batches, no ONNX export.
        evaluator = pl.Classification(
            train_dataloader=train_dataloader,
            val_dataloaders=valid_dataloader,
            export_onnx=False,
            max_epochs=1,
            limit_train_batches=2,
            limit_val_batches=3,
        )
        evaluator.fit(model)
    finally:
        # cleanup to avoid affecting later test cases, even when the run fails
        _reset()
コード例 #4
0
ファイル: test_experiment.py プロジェクト: yinfupai/nni
def get_mnist_evaluator():
    """Return a ``pl.Classification`` evaluator over MNIST, capped for quick runs.

    Images are resized to 32x32, converted to tensors and normalized. Each
    loader is built with ``64`` passed positionally (the batch size for a
    torch-style DataLoader -- assumed for ``pl.DataLoader``). Training is
    limited to one epoch of 20 training and 20 validation batches.
    """
    preprocess = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    def make_loader(is_train):
        # Same traced-dataset construction for both splits; only `train` differs.
        split = nni.trace(MNIST)('data/mnist', download=True, train=is_train, transform=preprocess)
        return pl.DataLoader(split, 64)

    train_loader = make_loader(True)
    valid_loader = make_loader(False)
    return pl.Classification(
        train_dataloader=train_loader,
        val_dataloaders=valid_loader,
        limit_train_batches=20,
        limit_val_batches=20,
        max_epochs=1,
    )
コード例 #5
0
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    valid_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    train_dataset = serialize(CIFAR10, root='data/cifar10', train=True, download=True, transform=train_transform)
    test_dataset = serialize(CIFAR10, root='data/cifar10', train=False, download=True, transform=valid_transform)
    trainer = pl.Classification(train_dataloader=pl.DataLoader(train_dataset, batch_size=100),
                                val_dataloaders=pl.DataLoader(test_dataset, batch_size=100),
                                max_epochs=1, limit_train_batches=0.2)

    simple_strategy = strategy.Random()

    exp = RetiariiExperiment(base_model, trainer, [], simple_strategy)

    exp_config = RetiariiExeConfig('remote')
    exp_config.experiment_name = 'darts_search'
    exp_config.trial_concurrency = 2
    exp_config.max_trial_number = 10
    exp_config.trial_gpu_number = 1
    exp_config.training_service.use_active_gpu = True
    exp_config.training_service.reuse_mode = True
    exp_config.training_service.gpu_indices = [0, 1, 2]
コード例 #6
0
ファイル: test.py プロジェクト: SparkSnail/nni
    ])

    train_dataset = serialize(CIFAR10,
                              root='data/cifar10',
                              train=True,
                              download=True,
                              transform=train_transform)
    test_dataset = serialize(CIFAR10,
                             root='data/cifar10',
                             train=False,
                             download=True,
                             transform=valid_transform)
    trainer = pl.Classification(train_dataloader=pl.DataLoader(train_dataset,
                                                               batch_size=100),
                                val_dataloaders=pl.DataLoader(test_dataset,
                                                              batch_size=100),
                                max_epochs=1,
                                limit_train_batches=0.2,
                                progress_bar_refresh_rate=0)

    simple_strategy = strategy.Random()

    exp = RetiariiExperiment(base_model, trainer, [], simple_strategy)

    exp_config = RetiariiExeConfig('local')
    exp_config.experiment_name = 'darts_search'
    exp_config.trial_concurrency = 2
    exp_config.max_trial_number = 10
    exp_config.trial_gpu_number = 1
    exp_config.training_service.use_active_gpu = True
    exp_config.training_service.gpu_indices = [1, 2]
コード例 #7
0
    ])

    train_dataset = serialize(CIFAR10,
                              root='data/cifar10',
                              train=True,
                              download=True,
                              transform=train_transform)
    test_dataset = serialize(CIFAR10,
                             root='data/cifar10',
                             train=False,
                             download=True,
                             transform=valid_transform)
    trainer = pl.Classification(train_dataloader=pl.DataLoader(train_dataset,
                                                               batch_size=100),
                                val_dataloaders=pl.DataLoader(test_dataset,
                                                              batch_size=100),
                                max_epochs=1,
                                limit_train_batches=0.2,
                                enable_progress_bar=False)

    simple_strategy = strategy.Random()

    exp = RetiariiExperiment(base_model, trainer, [], simple_strategy)

    exp_config = RetiariiExeConfig('local')
    exp_config.experiment_name = 'darts_search'
    exp_config.trial_concurrency = 2
    exp_config.max_trial_number = 10
    exp_config.trial_gpu_number = 1
    exp_config.training_service.use_active_gpu = True
    exp_config.training_service.gpu_indices = [1, 2]