Exemplo n.º 1
0
def test_diabetes():
    """Fit ``FCNet`` on ``DiabetesDataset`` through ``pl.Regression`` and check the metric.

    The evaluator is run directly via ``_execute`` for 100 epochs with SGD,
    and the value returned by ``_get_final_result()`` must be below ``2e4``.

    ``_reset()`` is called before the test body and again in a ``finally``
    clause, so the module-level state overwritten here is cleaned up even
    when training raises or the assertion fails.
    """
    _reset()
    try:
        # Hand-written trial parameters and a cleared last metric;
        # presumably this lets the evaluator run outside a real NNI
        # trial -- TODO confirm against the nni test platform.
        nni.trial._params = {'foo': 'bar', 'parameter_id': 0}
        nni.runtime.platform.test._last_metric = None

        train_dataset = DiabetesDataset(train=True)
        test_dataset = DiabetesDataset(train=False)
        lightning = pl.Regression(
            optimizer=torch.optim.SGD,
            train_dataloader=pl.DataLoader(train_dataset, batch_size=20),
            val_dataloaders=pl.DataLoader(test_dataset, batch_size=20),
            max_epochs=100,
            progress_bar_refresh_rate=progress_bar_refresh_rate)

        # The second dimension of the training inputs is FCNet's first
        # argument (presumably the feature count); the second argument is 1.
        lightning._execute(FCNet(train_dataset.x.shape[1], 1))

        final_result = _get_final_result()
        assert final_result < 2e4, (
            f'final result {final_result!r} is not below 2e4')
    finally:
        # Always undo the state changes above, including on failure.
        _reset()
Exemplo n.º 2
0
def test_mnist():
    """Train ``MNISTModel`` on MNIST through ``pl.Classification`` and check the metric.

    The evaluator runs for 2 epochs on a quarter of the training batches,
    and the value returned by ``_get_final_result()`` must exceed ``0.7``.

    ``_reset()`` is called before the test body and again in a ``finally``
    clause, so cleanup still happens when the download, the training run
    or the assertion fails.
    """
    _reset()
    try:
        transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.1307, ), (0.3081, ))])
        # Both splits are stored under data/mnist and downloaded on demand.
        train_dataset = bm(MNIST)(root='data/mnist',
                                  train=True,
                                  download=True,
                                  transform=transform)
        test_dataset = bm(MNIST)(root='data/mnist',
                                 train=False,
                                 download=True,
                                 transform=transform)
        lightning = pl.Classification(
            train_dataloader=pl.DataLoader(train_dataset, batch_size=100),
            val_dataloaders=pl.DataLoader(test_dataset, batch_size=100),
            max_epochs=2,
            limit_train_batches=0.25,  # for faster training
            progress_bar_refresh_rate=progress_bar_refresh_rate)

        # MNISTModel itself is passed here, not an instance of it.
        lightning._execute(MNISTModel)

        final_result = _get_final_result()
        assert final_result > 0.7, (
            f'final result {final_result!r} is not above 0.7')
    finally:
        # Always run the cleanup, including on failure.
        _reset()
Exemplo n.º 3
0
    ])
    valid_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    train_dataset = bm(CIFAR10)(root='data/cifar10',
                                train=True,
                                download=True,
                                transform=train_transform)
    test_dataset = bm(CIFAR10)(root='data/cifar10',
                               train=False,
                               download=True,
                               transform=valid_transform)
    trainer = pl.Classification(train_dataloader=pl.DataLoader(train_dataset,
                                                               batch_size=100),
                                val_dataloaders=pl.DataLoader(test_dataset,
                                                              batch_size=100),
                                max_epochs=1,
                                limit_train_batches=0.2)

    simple_startegy = RandomStrategy()

    exp = RetiariiExperiment(base_model, trainer, [], simple_startegy)

    exp_config = RetiariiExeConfig('local')
    exp_config.experiment_name = 'darts_search'
    exp_config.trial_concurrency = 2
    exp_config.max_trial_number = 10
    exp_config.trial_gpu_number = 1
    exp_config.training_service.use_active_gpu = True