def __init__(self, network, device_ids=None, batchdim=0, drop_last=False):
    """Replicate ``network`` across ``device_ids`` for data-parallel execution.

    Args:
      network: A ``torch.nn.Module`` instance, or a zero-argument callable
        returning one. The module is obtained once and deep-copied onto every
        device, so all replicas start from the same weights and the caller's
        module is not moved.
      device_ids: Iterable of device names accepted by ``torch.device``.
        Defaults to ``xm.get_xla_supported_devices()``. If it is empty, the
        module itself is used as-is on the device reported by
        ``self._get_model_device`` and ``self._native_run`` is set to True.
      batchdim: Stored as ``self._batchdim`` (presumably the dimension inputs
        are split along elsewhere in the class — not read here).
      drop_last: Stored as ``self._drop_last``; not read here.
    """
    if device_ids is None:
        device_ids = xm.get_xla_supported_devices()
    # Materialize once and use only this list below: ``device_ids`` may be a
    # one-shot iterator that ``list()`` has just exhausted.
    self._device_ids = list(device_ids)
    self._batchdim = batchdim
    self._drop_last = drop_last
    self._native_run = False
    # Cross-device replication is only set up when there is more than one
    # device to replicate across.
    if len(self._device_ids) > 1:
        replication_devices = xm.xla_replication_devices(self._device_ids)
        self._replication = xm.Replication(self._device_ids,
                                           replication_devices)
    else:
        self._replication = None
    self._models = []
    self._contexts = []
    module = network if isinstance(network, torch.nn.Module) else network()
    for device in self._device_ids:
        # Copy before moving so every replica starts from identical weights.
        device_module = deepcopy(module).to(device=torch.device(device))
        self._models.append(device_module)
        self._contexts.append(Context(torch.device(device)))
    if not self._models:
        # No XLA device, push a vanilla network in.
        device = self._get_model_device(module)
        self._models.append(module)
        self._device_ids.append(device)
        self._contexts.append(Context(torch.device(device)))
        self._native_run = True
def __init__(self, network, device_ids=None, batchdim=0, drop_last=False):
    """Replicate ``network`` across ``device_ids`` for data-parallel execution.

    Args:
      network: A zero-argument callable returning a ``torch.nn.Module``, or a
        ``torch.nn.Module`` instance. The module is obtained once and
        deep-copied onto every device, so all replicas start from identical
        weights.
      device_ids: Iterable of device names accepted by ``torch.device``.
        Defaults to ``xm.get_xla_supported_devices()``. If it is empty, a
        single un-moved module is stored instead.
      batchdim: Stored as ``self._batchdim`` (presumably the dimension inputs
        are split along elsewhere in the class — not read here).
      drop_last: Stored as ``self._drop_last``; not read here.
    """
    from copy import deepcopy

    if device_ids is None:
        device_ids = xm.get_xla_supported_devices()
    # Materialize once and use only this list below: ``device_ids`` may be a
    # one-shot iterator that ``list()`` has just exhausted.
    self._device_ids = list(device_ids)
    self._batchdim = batchdim
    self._drop_last = drop_last
    replication_devices = (
        xm.xla_replication_devices(self._device_ids)
        if self._device_ids else None)
    self._replication = (
        xm.Replication(self._device_ids, replication_devices)
        if replication_devices else None)
    # Build the network once and copy it per device. Calling ``network()``
    # for each device would typically draw a fresh random initialization
    # per replica, so the replicas would not start in sync.
    module = network if isinstance(network, torch.nn.Module) else network()
    self._models = [
        deepcopy(module).to(device=torch.device(device))
        for device in self._device_ids
    ]
    if not self._models:
        # No XLA device, push a vanilla network in.
        self._models.append(module)
def train_imagenet():
    """Train and evaluate the configured torchvision model on ImageNet.

    Builds train/test loaders (constant zero tensors when ``FLAGS.fake_data``
    is set, otherwise ``ImageFolder`` datasets under ``FLAGS.datadir``), wraps
    the model in ``dp.DataParallel`` over the devices selected by
    ``FLAGS.num_cores`` (an empty list when it is 0), then runs
    ``FLAGS.num_epochs`` rounds of training followed by evaluation.

    Returns:
      The mean of the per-call test accuracies (percent) returned by
      ``model_parallel(test_loop_fn, ...)`` for the last epoch, or 0.0 if no
      epoch ran.
    """
    print('==> Preparing data..')
    img_dim = get_model_property('img_dim')
    if FLAGS.fake_data:
        train_dataset_len = 1200000  # Roughly the size of Imagenet dataset.
        train_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim),
                  torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=train_dataset_len // FLAGS.batch_size //
            xm.xrt_world_size())
        # Each generated test sample is a batch of ``test_set_batch_size``
        # images, so the 50000-image validation set is divided by that size
        # (not by the training batch size).
        test_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.test_set_batch_size, 3, img_dim, img_dim),
                  torch.zeros(FLAGS.test_set_batch_size, dtype=torch.int64)),
            sample_count=50000 // FLAGS.test_set_batch_size //
            xm.xrt_world_size())
    else:
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        train_dataset = torchvision.datasets.ImageFolder(
            os.path.join(FLAGS.datadir, 'train'),
            transforms.Compose([
                transforms.RandomResizedCrop(img_dim),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))
        train_dataset_len = len(train_dataset.imgs)
        resize_dim = max(img_dim, 256)
        test_dataset = torchvision.datasets.ImageFolder(
            os.path.join(FLAGS.datadir, 'val'),
            # Matches Torchvision's eval transforms except Torchvision uses size
            # 256 resize for all models both here and in the train loader. Their
            # version crashes during training on 299x299 images, e.g. inception.
            transforms.Compose([
                transforms.Resize(resize_dim),
                transforms.CenterCrop(img_dim),
                transforms.ToTensor(),
                normalize,
            ]))

        train_sampler = None
        test_sampler = None
        if xm.xrt_world_size() > 1:
            # Shard both datasets across processes by ordinal.
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=True)
            test_sampler = torch.utils.data.distributed.DistributedSampler(
                test_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=False)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=FLAGS.batch_size,
            sampler=train_sampler,
            # DataLoader rejects ``shuffle=True`` together with a sampler.
            # Compare with None rather than relying on truthiness:
            # DistributedSampler implements __len__, so an empty shard would
            # be falsy and wrongly enable shuffle.
            shuffle=train_sampler is None,
            num_workers=FLAGS.num_workers)
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=FLAGS.test_set_batch_size,
            sampler=test_sampler,
            shuffle=False,
            num_workers=FLAGS.num_workers)

    torch.manual_seed(42)

    # Pass [] as device_ids to run using the PyTorch/CPU engine.
    devices = (
        xm.get_xla_supported_devices(max_devices=FLAGS.num_cores)
        if FLAGS.num_cores != 0 else [])
    torchvision_model = get_model_property('model_fn')
    model_parallel = dp.DataParallel(torchvision_model, device_ids=devices)

    def train_loop_fn(model, loader, device, context):
        """Run one training pass of ``model`` over ``loader`` on ``device``."""
        loss_fn = nn.CrossEntropyLoss()
        # ``getattr_or`` presumably returns the attribute already stored on
        # ``context`` or builds it with the lambda, so optimizer/scheduler
        # state persists across epochs — confirm against Context.
        optimizer = context.getattr_or(
            'optimizer',
            lambda: optim.SGD(
                model.parameters(),
                lr=FLAGS.lr,
                momentum=FLAGS.momentum,
                weight_decay=5e-4))
        # ``writer`` and ``num_training_steps_per_epoch`` are assigned in the
        # enclosing scope below, before this function is first called.
        lr_scheduler = context.getattr_or(
            'lr_scheduler',
            lambda: schedulers.wrap_optimizer_with_scheduler(
                optimizer,
                scheduler_type=getattr(FLAGS, 'lr_scheduler_type', None),
                scheduler_divisor=getattr(FLAGS, 'lr_scheduler_divisor', None),
                scheduler_divide_every_n_epochs=getattr(
                    FLAGS, 'lr_scheduler_divide_every_n_epochs', None),
                num_steps_per_epoch=num_training_steps_per_epoch,
                summary_writer=writer if xm.is_master_ordinal() else None))
        tracker = xm.RateTracker()

        model.train()
        for x, (data, target) in loader:
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(FLAGS.batch_size)
            if x % FLAGS.log_steps == 0:
                test_utils.print_training_update(device, x, loss.item(),
                                                 tracker.rate(),
                                                 tracker.global_rate())
            # The scheduler is stepped once per training step.
            if lr_scheduler:
                lr_scheduler.step()

    def test_loop_fn(model, loader, device, context):
        """Return top-1 accuracy (percent) of ``model`` over ``loader``."""
        total_samples = 0
        correct = 0
        model.eval()
        for _, (data, target) in loader:
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
            total_samples += data.size()[0]

        accuracy = 100.0 * correct / total_samples
        test_utils.print_test_update(device, accuracy)
        return accuracy

    accuracy = 0.0
    writer = SummaryWriter(log_dir=FLAGS.logdir) if FLAGS.logdir else None
    num_devices = len(
        xm.xla_replication_devices(devices)) if len(devices) > 1 else 1
    num_training_steps_per_epoch = train_dataset_len // (
        FLAGS.batch_size * num_devices)
    try:
        for epoch in range(1, FLAGS.num_epochs + 1):
            model_parallel(train_loop_fn, train_loader)
            accuracies = model_parallel(test_loop_fn, test_loader)
            accuracy = mean(accuracies)
            print('Epoch: {}, Mean Accuracy: {:.2f}%'.format(epoch, accuracy))
            global_step = (epoch - 1) * num_training_steps_per_epoch
            test_utils.add_scalar_to_summary(writer, 'Accuracy/test', accuracy,
                                             global_step)
            if FLAGS.metrics_debug:
                print(torch_xla._XLAC._xla_metrics_report())
    finally:
        # Flush any buffered summary events and release the event file.
        if writer is not None:
            writer.close()

    return accuracy