def test_diabetes():
    """Fit a ``pl.Regression`` evaluator on ``DiabetesDataset`` and check the final result.

    Builds train/validation loaders (batch size 20) over ``DiabetesDataset``,
    trains an ``FCNet`` with ``torch.optim.SGD`` for 100 epochs, and asserts
    that the value returned by ``_get_final_result()`` is below ``2e4``.
    ``_reset()`` is called before the test and, via ``finally``, after it,
    even when training or the assertion fails.
    """
    _reset()
    # Install dummy trial parameters and clear the last recorded metric.
    # NOTE(review): presumably `_get_final_result()` reads `_last_metric`,
    # so clearing it ensures only this run's result is checked — confirm.
    nni.trial._params = {'foo': 'bar', 'parameter_id': 0}
    nni.runtime.platform.test._last_metric = None
    try:
        train_dataset = DiabetesDataset(train=True)
        test_dataset = DiabetesDataset(train=False)
        lightning = pl.Regression(
            optimizer=torch.optim.SGD,
            train_dataloader=pl.DataLoader(train_dataset, batch_size=20),
            val_dataloaders=pl.DataLoader(test_dataset, batch_size=20),
            max_epochs=100,
            progress_bar_refresh_rate=progress_bar_refresh_rate,
        )
        # The network's input width is taken from the second dimension of the
        # training data's `x`; it has a single output.
        lightning._execute(FCNet(train_dataset.x.shape[1], 1))
        # NOTE(review): presumably the final result is a regression loss, and
        # 2e4 is a loose upper bound — confirm.
        assert _get_final_result() < 2e4
    finally:
        # Run the cleanup unconditionally. Previously it was skipped whenever
        # training raised or the assertion failed.
        _reset()
def test_mnist():
    """Fit a ``pl.Classification`` evaluator on MNIST and check the final result.

    Trains ``MNISTModel`` for 2 epochs, using only 25% of the training batches
    per epoch to keep the test short. It then asserts that the value returned
    by ``_get_final_result()`` exceeds ``0.7``. The MNIST datasets are
    downloaded to ``data/mnist`` if not already present. ``_reset()`` is called
    before the test and, via ``finally``, after it, even when training or the
    assertion fails.
    """
    _reset()
    try:
        # Convert images to tensors, then normalize with the MNIST mean/std
        # values passed to `Normalize(mean, std)`.
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, )),
        ])
        # NOTE(review): `bm` is defined elsewhere; presumably it wraps the
        # dataset class so its constructor arguments are recorded — confirm.
        train_dataset = bm(MNIST)(root='data/mnist', train=True, download=True, transform=transform)
        test_dataset = bm(MNIST)(root='data/mnist', train=False, download=True, transform=transform)
        lightning = pl.Classification(
            train_dataloader=pl.DataLoader(train_dataset, batch_size=100),
            val_dataloaders=pl.DataLoader(test_dataset, batch_size=100),
            max_epochs=2,
            limit_train_batches=0.25,  # for faster training
            progress_bar_refresh_rate=progress_bar_refresh_rate,
        )
        # The model class itself (not an instance) is passed here.
        # NOTE(review): presumably `_execute` instantiates it — confirm.
        lightning._execute(MNISTModel)
        # NOTE(review): presumably the final result is validation accuracy.
        assert _get_final_result() > 0.7
    finally:
        # Run the cleanup unconditionally. Previously it was skipped whenever
        # training raised or the assertion failed.
        _reset()
]) valid_transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ]) train_dataset = bm(CIFAR10)(root='data/cifar10', train=True, download=True, transform=train_transform) test_dataset = bm(CIFAR10)(root='data/cifar10', train=False, download=True, transform=valid_transform) trainer = pl.Classification(train_dataloader=pl.DataLoader(train_dataset, batch_size=100), val_dataloaders=pl.DataLoader(test_dataset, batch_size=100), max_epochs=1, limit_train_batches=0.2) simple_startegy = RandomStrategy() exp = RetiariiExperiment(base_model, trainer, [], simple_startegy) exp_config = RetiariiExeConfig('local') exp_config.experiment_name = 'darts_search' exp_config.trial_concurrency = 2 exp_config.max_trial_number = 10 exp_config.trial_gpu_number = 1 exp_config.training_service.use_active_gpu = True