def test_fit_api():
    """The ``fit`` API should accept a model factory, a model class, or a model instance."""
    _reset()
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307, ), (0.3081, )),
    ])
    train_set = nni.trace(MNIST)(root='data/mnist', train=True, download=True, transform=transform)
    test_set = nni.trace(MNIST)(root='data/mnist', train=False, download=True, transform=transform)

    def lightning():
        return pl.Classification(
            train_dataloader=pl.DataLoader(train_set, batch_size=100),
            val_dataloaders=pl.DataLoader(test_set, batch_size=100),
            max_epochs=1,
            limit_train_batches=0.1,  # for faster training
            progress_bar_refresh_rate=progress_bar_refresh_rate)

    # Lightning will have some cache in models / trainers,
    # which is problematic if we call fit multiple times.
    lightning().fit(lambda: MNISTModel())
    lightning().fit(MNISTModel)
    lightning().fit(MNISTModel())
    _reset()
def test_mnist():
    """Train MNISTModel briefly and expect the reported final metric to exceed 0.7."""
    _reset()
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    train_set = nni.trace(MNIST)(root='data/mnist', train=True, download=True, transform=transform)
    test_set = nni.trace(MNIST)(root='data/mnist', train=False, download=True, transform=transform)
    evaluator = pl.Classification(
        train_dataloader=pl.DataLoader(train_set, batch_size=100),
        val_dataloaders=pl.DataLoader(test_set, batch_size=100),
        max_epochs=2,
        limit_train_batches=0.25,  # for faster training
        progress_bar_refresh_rate=progress_bar_refresh_rate)
    evaluator._execute(MNISTModel)
    assert _get_final_result() > 0.7
    _reset()
def _dataset_factory(dataset_type, subset=20):
    """Build ``(train_dataset, valid_dataset)`` for ``'cifar10'`` or ``'imagenet'``.

    When *subset* is truthy, both datasets are randomly subsampled to at most
    that many examples.  Raises ``ValueError`` for an unknown *dataset_type*.
    """
    if dataset_type == 'cifar10':
        normalizer = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        train_dataset = nni.trace(CIFAR10)(
            'data/cifar10', train=True,
            transform=transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop(32, 4),
                transforms.ToTensor(),
                normalizer,
            ]))
        valid_dataset = nni.trace(CIFAR10)(
            'data/cifar10', train=False,
            transform=transforms.Compose([
                transforms.ToTensor(),
                normalizer,
            ]))
    elif dataset_type == 'imagenet':
        normalizer = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        train_dataset = nni.trace(ImageNet)(
            'data/imagenet', split='val',  # no train data available in tests
            transform=transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalizer,
            ]))
        valid_dataset = nni.trace(ImageNet)(
            'data/imagenet', split='val',
            transform=transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalizer,
            ]))
    else:
        raise ValueError(f'Unsupported dataset type: {dataset_type}')

    if subset:
        # Random, unseeded subsampling — intentional for quick smoke tests.
        train_dataset = Subset(train_dataset, np.random.permutation(len(train_dataset))[:subset])
        valid_dataset = Subset(valid_dataset, np.random.permutation(len(valid_dataset))[:subset])
    return train_dataset, valid_dataset
def test_custom_class():
    """Serialization round-trips for traced Foo instances, including nested traces."""
    def roundtrip(obj):
        return nni.load(nni.dump(obj))

    traced = nni.trace(Foo)(3)
    assert roundtrip(traced) == traced

    traced = nni.trace(Foo)(b=2, a=1)
    assert roundtrip(traced) == traced

    # the inner Foo(1) is constructed without tracing
    restored = roundtrip(nni.trace(Foo)(Foo(1), 5))
    assert restored.bb[0] == restored.bb[999] == 6

    traced = nni.trace(Foo)(nni.trace(Foo)(1), 5)
    assert roundtrip(traced) == traced
def test_multiprocessing_dataloader():
    """A traced MNIST dataset must survive DataLoader worker-process pickling."""
    # check whether multi-processing works
    # it's possible to have pickle errors
    mnist = nni.trace(MNIST)(
        root='data/mnist', train=False, download=True,
        transform=nni.trace(transforms.Compose)([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, )),
        ]))
    import nni.retiarii.evaluator.pytorch.lightning as pl
    loader = pl.DataLoader(mnist, batch_size=10, num_workers=2)
    images, labels = next(iter(loader))
    assert images.size() == torch.Size([10, 1, 28, 28])
    assert labels.size() == torch.Size([10])
def test_custom_class():
    """Round-trip checks plus a size sanity check on the serialized form."""
    def roundtrip(obj):
        return nni.load(nni.dump(obj))

    traced = nni.trace(Foo)(3)
    assert roundtrip(traced) == traced

    traced = nni.trace(Foo)(b=2, a=1)
    assert roundtrip(traced) == traced

    dumped = nni.dump(nni.trace(Foo)(Foo(1), 5))
    # should not be too longer if the serialization is correct
    assert len(dumped) > 200

    traced = nni.trace(Foo)(nni.trace(Foo)(1), 5)
    assert roundtrip(traced) == traced
def test_dataset():
    """Traced DataLoader should dump to the expected symbolic form and load back usable."""
    dataset = nni.trace(MNIST)(root='data/mnist', train=False, download=True)
    dataloader = nni.trace(DataLoader)(dataset, batch_size=10)
    dumped_ans = {
        "__symbol__": "path:torch.utils.data.dataloader.DataLoader",
        "__kwargs__": {
            "dataset": {
                "__symbol__": "path:torchvision.datasets.mnist.MNIST",
                "__kwargs__": {
                    "root": "data/mnist",
                    "train": False,
                    "download": True
                }
            },
            "batch_size": 10
        }
    }
    print(nni.dump(dataloader))
    print(nni.dump(dumped_ans))
    assert nni.dump(dataloader) == nni.dump(dumped_ans)
    dataloader = nni.load(nni.dump(dumped_ans))
    assert isinstance(dataloader, DataLoader)

    def check_batch_shapes(transform):
        # dump, reload, and draw one batch to confirm the loader still works
        ds = nni.trace(MNIST)(root='data/mnist', train=False, download=True, transform=transform)
        dl = nni.trace(DataLoader)(ds, batch_size=10)
        x, y = next(iter(nni.load(nni.dump(dl))))
        assert x.size() == torch.Size([10, 1, 28, 28])
        assert y.size() == torch.Size([10])

    # fully traced transform pipeline
    check_batch_shapes(nni.trace(transforms.Compose)([
        nni.trace(transforms.ToTensor)(),
        nni.trace(transforms.Normalize)((0.1307, ), (0.3081, )),
    ]))
    # untraced inner transforms
    check_batch_shapes(nni.trace(transforms.Compose)([
        transforms.ToTensor(),
        transforms.Normalize((0.1307, ), (0.3081, )),
    ]))
def get_mnist_evaluator():
    """Return a lightweight MNIST classification evaluator (1 epoch, capped batches)."""
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    loaders = {}
    for split, is_train in (('train', True), ('valid', False)):
        dataset = nni.trace(MNIST)('data/mnist', download=True, train=is_train, transform=transform)
        loaders[split] = pl.DataLoader(dataset, 64)
    return pl.Classification(
        train_dataloader=loaders['train'],
        val_dataloaders=loaders['valid'],
        limit_train_batches=20,
        limit_val_batches=20,
        max_epochs=1
    )
def test_lightning_earlystop():
    """A traced EarlyStopping callback should survive dump/load of the trainer."""
    import nni.retiarii.evaluator.pytorch.lightning as pl
    from pytorch_lightning.callbacks.early_stopping import EarlyStopping

    trainer = pl.Trainer(callbacks=[nni.trace(EarlyStopping)(monitor="val_loss")])
    trainer = nni.load(nni.dump(trainer))
    assert any(isinstance(cb, EarlyStopping) for cb in trainer.callbacks)
def test_external_class():
    """Tracing classes defined outside the project: OrderedDict and torch Conv2d."""
    from collections import OrderedDict

    od = nni.trace(kw_only=False)(OrderedDict)([('a', 1), ('b', 2)])
    assert od['a'] == 1
    assert od['b'] == 2
    assert nni.dump(od) == '{"a": 1, "b": 2}'

    conv = nni.trace(torch.nn.Conv2d)(3, 16, 3)
    assert conv.in_channels == 3
    assert conv.out_channels == 16
    assert conv.kernel_size == (3, 3)
    expected = (
        r'{"__symbol__": "path:torch.nn.modules.conv.Conv2d", '
        r'"__kwargs__": {"in_channels": 3, "out_channels": 16, "kernel_size": 3}}'
    )
    assert nni.dump(conv) == expected
    conv = nni.load(nni.dump(conv))
    assert conv.kernel_size == (3, 3)
def test_lightning_earlystop():
    """Trainer dump/load keeps the traced EarlyStopping callback, under a platform-specific pickle budget."""
    import nni.retiarii.evaluator.pytorch.lightning as pl
    from pytorch_lightning.callbacks.early_stopping import EarlyStopping

    trainer = pl.Trainer(callbacks=[nni.trace(EarlyStopping)(monitor="val_loss")])
    # larger pickle budget on non-linux platforms
    pickle_size_limit = 4096 if sys.platform == 'linux' else 32768
    trainer = nni.load(nni.dump(trainer, pickle_size_limit=pickle_size_limit))
    assert any(isinstance(cb, EarlyStopping) for cb in trainer.callbacks)
def test_multiprocessing_dataset():
    """A traced Dataset must be picklable enough to cross a process boundary."""
    from torch.utils.data import Dataset
    import multiprocessing

    traced_dataset = nni.trace(Dataset)()
    worker = multiprocessing.Process(
        target=_test_multiprocessing_dataset_worker,
        args=(traced_dataset, ))
    worker.start()
    worker.join()
    assert worker.exitcode == 0
def test_arguments_kind():
    """How positional, var-positional and keyword-only parameters are recorded by trace."""
    def foo(a, b, *c, **d):
        pass

    traced = nni.trace(foo)(1, 2, 3, 4)
    assert traced.trace_args == [1, 2, 3, 4]
    assert traced.trace_kwargs == {}

    traced = nni.trace(foo)(a=1, b=2)
    assert traced.trace_kwargs == dict(a=1, b=2)

    traced = nni.trace(foo)(1, b=2)
    # this is not perfect, but it's safe
    assert traced.trace_kwargs == dict(a=1, b=2)

    def foo(a, *, b=3, c=5):
        pass

    traced = nni.trace(foo)(1, b=2, c=3)
    assert traced.trace_kwargs == dict(a=1, b=2, c=3)

    import torch.nn as nn
    lstm = nni.trace(nn.LSTM)(2, 2)
    assert lstm.input_size == 2
    assert lstm.hidden_size == 2
    assert lstm.trace_args == [2, 2]

    lstm = nni.trace(nn.LSTM)(input_size=2, hidden_size=2)
    assert lstm.trace_kwargs == {'input_size': 2, 'hidden_size': 2}
def test_function():
    """Tracing plain functions: math.sqrt and a local factory function."""
    t = nni.trace(math.sqrt, kw_only=False)(3)
    assert 1 < t < 2
    assert t.trace_symbol == math.sqrt
    assert t.trace_args == [3]
    t = nni.load(nni.dump(t))
    assert 1 < t < 2
    assert not is_traceable(t)  # trace not recovered, expected, limitation

    def simple_class_factory(bb=3.):
        return SimpleClass(1, bb)

    traced = nni.trace(simple_class_factory)(4)
    serialized = nni.dump(traced)
    assert '__kwargs__' in serialized
    traced = nni.load(serialized)
    assert traced._a == 1
    assert is_traceable(traced)
    traced = traced.trace_copy()
    assert is_traceable(traced)
    assert traced.trace_symbol(10)._b == 10
    assert traced.trace_kwargs['bb'] == 4
    assert is_traceable(traced.trace_copy())
def test_generator():
    """A traced optimizer constructed from a generator argument records its kwargs."""
    import torch.nn as nn
    import torch.optim as optim

    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 10, 1)

        def forward(self, x):
            return self.conv(x)

    net = Net()
    adam = nni.trace(optim.Adam)(net.parameters())
    print(adam.trace_kwargs)
def _mnist_net(type_, evaluator_kwargs):
    """Return ``(base_model, evaluator)`` for MNIST according to the model *type_*.

    Raises ``ValueError`` for an unsupported *type_*.
    """
    model_factories = {
        'simple': lambda: SimpleNet(False),
        'simple_value_choice': SimpleNet,
        'value_choice': ValueChoiceConvNet,
        'repeat': RepeatNet,
        'cell': CellNet,
        'custom_op': CustomOpValueChoiceNet,
    }
    if type_ not in model_factories:
        raise ValueError(f'Unsupported type: {type_}')
    base_model = model_factories[type_]()

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307, ), (0.3081, )),
    ])

    def subsampled_loader(train):
        dataset = nni.trace(MNIST)('data/mnist', download=True, train=train, transform=transform)
        # Multi-GPU combined dataloader will break this subset sampler. Expected though.
        sampler = nni.trace(RandomSampler)(dataset, True, int(len(dataset) / 20))
        return nni.trace(DataLoader)(dataset, 64, sampler=sampler)

    evaluator = Classification(
        train_dataloader=subsampled_loader(True),
        val_dataloaders=subsampled_loader(False),
        **evaluator_kwargs)
    return base_model, evaluator
from pathlib import Path

import torch
from torch.optim import Adam

import nni
from nni.compression.experiment.experiment import CompressionExperiment
from nni.compression.experiment.config import CompressionExperimentConfig, TaylorFOWeightPrunerConfig
from vessel import LeNet, finetuner, evaluator, trainer, criterion, device

model = LeNet().to(device)
# pre-training model
finetuner(model)

# wrap the optimizer class with nni.trace so the experiment can serialize and
# re-construct it (same convention used by the pruner examples in this repo)
optimizer = nni.trace(Adam)(model.parameters())
# example input matching LeNet's expected MNIST-like shape — TODO confirm against vessel.LeNet
dummy_input = torch.rand(16, 1, 28, 28).to(device)

# normal experiment setting, no need to set search_space and trial_command
config = CompressionExperimentConfig('local')
config.experiment_name = 'auto compression torch example'
config.trial_concurrency = 1
config.max_trial_number = 10
config.trial_code_directory = Path(__file__).parent
config.tuner.name = 'TPE'
config.tuner.class_args['optimize_mode'] = 'maximize'

# compression experiment specific setting
# single float value means the expected remaining ratio upper limit for flops & params, lower limit for metric
config.compression_setting.flops = 0.2
model = BertForSequenceClassification.from_pretrained(
    'bert-base-cased', num_labels=num_labels).to(device)
# baseline metric before pruning
print('Initial: {}'.format(
    evaluator(model, metric, is_regression, validate_dataloader)))

# prune Linear layers inside the BERT encoder to 90% sparsity
config_list = [{
    'op_types': ['Linear'],
    'op_partial_names': ['bert.encoder'],
    'sparsity': 0.9
}]
p_trainer = functools.partial(trainer, train_dataloader=train_dataloader)

# make sure you have used nni.trace to wrap the optimizer class before initialize
traced_optimizer = nni.trace(Adam)(model.parameters(), lr=2e-5)
# movement pruning with warm-up / cool-down steps — step counts presumably tied
# to the dataset size and epoch count; verify when changing either
pruner = MovementPruner(model, config_list, p_trainer, traced_optimizer, criterion,
                        training_epochs=10, warm_up_step=12272,
                        cool_down_beginning_step=110448)
_, masks = pruner.compress()
pruner.show_pruned_weights()

# metric after pruning
print('Final: {}'.format(
    evaluator(model, metric, is_regression, validate_dataloader)))
def get_optimizer(model):
    """Build a traced SGD optimizer over *model*'s parameters with fixed hyper-parameters."""
    return nni.trace(torch.optim.SGD)(
        model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
class LightningModule(pl.LightningModule): """ Basic wrapper of generated model. Lightning modules used in NNI should inherit this class. """ def set_model(self, model: Union[Type[nn.Module], nn.Module]) -> None: if isinstance(model, nn.Module): self.model = model else: self.model = model() Trainer = nni.trace(pl.Trainer) DataLoader = nni.trace(torch_data.DataLoader) @nni.trace class Lightning(Evaluator): """ Delegate the whole training to PyTorch Lightning. Since the arguments passed to the initialization needs to be serialized, ``LightningModule``, ``Trainer`` or ``DataLoader`` in this file should be used. Another option is to hide dataloader in the Lightning module, in which case, dataloaders are not required for this class to work. Following the programming style of Lightning, metrics sent to NNI should be obtained from ``callback_metrics`` in trainer. Two hooks are added at the end of validation epoch and the end of ``fit``, respectively. The metric name and type depend on the specific task.
# record flops/params of the pretrained model before pruning
pre_flops, pre_params, _ = count_flops_params(
    model, torch.randn([128, 3, 32, 32]).to(device))
g_epoch = 0

# Start to prune and speedup
print('\n' + '=' * 50 + ' START TO PRUNE THE BEST ACCURACY PRETRAINED MODEL ' + '=' * 50)
# prune all Conv2d layers to 50% overall sparsity
config_list = [{
    'total_sparsity': 0.5,
    'op_types': ['Conv2d'],
}]

# make sure you have used nni.trace to wrap the optimizer class before initialize
traced_optimizer = nni.trace(torch.optim.SGD)(model.parameters(),
                                              lr=0.01,
                                              momentum=0.9,
                                              weight_decay=5e-4)
# choose the activation-based pruner variant from the CLI flag
if 'apoz' in args.pruner:
    pruner = ActivationAPoZRankPruner(model, config_list, trainer,
                                      traced_optimizer, criterion,
                                      training_batches=20)
else:
    pruner = ActivationMeanRankPruner(model, config_list, trainer,
                                      traced_optimizer, criterion,
                                      training_batches=20)
import nni


def test_positional_only():
    """Positional-only parameters (PEP 570) land in trace_args, not trace_kwargs."""
    def foo(a, b, /, c):
        pass

    traced = nni.trace(foo)(1, 2, c=3)
    assert traced.trace_args == [1, 2]
    assert traced.trace_kwargs == dict(c=3)