from pathlib import Path

import torch.nn as nn
from torch.optim import SGD

import ignite.distributed as idist
from ignite.engine import create_supervised_trainer
from ignite.handlers import FastaiLRFinder


def _test_distrib_integration_mnist(dirname, device):
    from torch.utils.data import DataLoader
    from torchvision.datasets import MNIST
    from torchvision.transforms import Compose, Normalize, ToTensor

    data_transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])

    train_loader = DataLoader(
        MNIST(download=True, root="/tmp", transform=data_transform, train=True),
        batch_size=256,
        shuffle=True,
    )

    class DummyModel(nn.Module):
        def __init__(self, n_channels=10, out_channels=1, flatten_input=False):
            super(DummyModel, self).__init__()
            self.net = nn.Sequential(
                nn.Flatten() if flatten_input else nn.Identity(),
                nn.Linear(n_channels, out_channels),
            )

        def forward(self, x):
            return self.net(x)

    # 28x28 MNIST images are flattened to 784 features, mapped to 10 classes
    model = DummyModel(n_channels=784, out_channels=10, flatten_input=True)
    model = model.to(device)

    optimizer = SGD(model.parameters(), lr=1e-4, momentum=0.0)
    to_save = {"model": model, "optimizer": optimizer}
    engine = create_supervised_trainer(model, optimizer, nn.CrossEntropyLoss(), device=device)

    lr_finder = FastaiLRFinder()
    with lr_finder.attach(engine, to_save) as trainer_with_finder:
        trainer_with_finder.run(train_loader)

    # plot() is exercised on every rank; only rank 0 saves the figure to disk
    lr_finder.plot()

    if idist.get_rank() == 0:
        ax = lr_finder.plot(skip_end=0)
        filepath = Path(dirname) / "distrib_dummy.jpg"
        ax.figure.savefig(filepath)
        assert filepath.exists()

    sug_lr = lr_finder.lr_suggestion()
    assert 1e-3 <= sug_lr <= 1

    lr_finder.apply_suggested_lr(optimizer)
    assert optimizer.param_groups[0]["lr"] == sug_lr
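# --- Not part of the test above: a minimal, hypothetical launcher sketch. ---
# The "gloo" backend, the process count, and the wrapper name are assumptions.
# idist.spawn calls the wrapped function as fn(process_index, *args), and
# idist.device() resolves the device for the current process.
import tempfile


def _run_lr_finder_test(local_rank, dirname):
    _test_distrib_integration_mnist(dirname, idist.device())


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp_dir:
        idist.spawn("gloo", _run_lr_finder_test, args=(tmp_dir,), nproc_per_node=2)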
# TODO Replace with Ignite parallelization
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)

trainer, evaluator = get_trainer_evaluator(opts)(model, optimizer, criterion, device, loaders, loggers)
# Bare engine wrapping the same process function, used only for the LR range test
raw_trainer = Engine(trainer._process_function)

# Handlers
lr_finder = FastaiLRFinder()
to_save = {'model': model, 'optimizer': optimizer}
with lr_finder.attach(raw_trainer, to_save=to_save) as trainer_with_lr_finder:
    trainer_with_lr_finder.run(loaders['train'])
lr_finder.get_results()
lr_finder.plot()
# Feed the suggested LR to the scheduler built below via opts
opts.lr = lr_finder.lr_suggestion()

opts.total_steps = len(loaders['train']) * opts.epochs
scheduler = get_scheduler(opts, optimizer)
trainer.add_event_handler(Events.ITERATION_COMPLETED, scheduler)
to_save['scheduler'] = scheduler

save_handler = Checkpoint(
    to_save,
    DiskSaver(opts.checkpoints_dir),
    n_saved=2,
    filename_prefix='best',
    score_function=score_function,
    score_name=opts.score_name,
)
evaluator.add_event_handler(EvaluatorEvents.VALIDATION_COMPLETED, save_handler)
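# --- Sketches of pieces referenced above but defined elsewhere (assumptions). ---
# Checkpoint calls score_function(engine) and keeps the n_saved highest-scoring
# checkpoints; a typical implementation reads a metric off the evaluator state.
# The metric key coming from opts.score_name is an assumption.
def score_function(engine):
    return engine.state.metrics[opts.score_name]


# EvaluatorEvents.VALIDATION_COMPLETED is a custom event not shown here; one way
# to declare and register it with Ignite (names assumed, not confirmed by this
# snippet) is via EventEnum:
from ignite.engine import EventEnum


class EvaluatorEvents(EventEnum):
    VALIDATION_COMPLETED = 'validation_completed'


evaluator.register_events(*EvaluatorEvents)
# The validation driver would then fire it once metrics are computed:
#     evaluator.fire_event(EvaluatorEvents.VALIDATION_COMPLETED)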