def train_cifar():
  """Trains ResNet18 on CIFAR-10 across the available XLA devices.

  All configuration (batch size, data dir, epochs, LR, ...) is read from the
  module-level FLAGS object.

  Returns:
    The mean test accuracy (percent) across devices for the final epoch.
  """
  print('==> Preparing data..')
  if FLAGS.fake_data:
    # Synthetic all-zero batches shaped like CIFAR-10 (3x32x32); sample
    # counts mirror the real 50k/10k split, divided across replicas.
    train_loader = xu.SampleGenerator(
        data=(torch.zeros(FLAGS.batch_size, 3, 32, 32),
              torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
        sample_count=50000 // FLAGS.batch_size // xm.xrt_world_size())
    test_loader = xu.SampleGenerator(
        data=(torch.zeros(FLAGS.batch_size, 3, 32, 32),
              torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
        sample_count=10000 // FLAGS.batch_size // xm.xrt_world_size())
  else:
    # Standard CIFAR-10 augmentation and per-channel normalization constants.
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])
    train_dataset = torchvision.datasets.CIFAR10(
        root=FLAGS.datadir,
        train=True,
        download=True,
        transform=transform_train)
    test_dataset = torchvision.datasets.CIFAR10(
        root=FLAGS.datadir,
        train=False,
        download=True,
        transform=transform_test)
    train_sampler = None
    test_sampler = None
    if xm.xrt_world_size() > 1:
      # Shard the datasets so that each replica sees a disjoint slice.
      train_sampler = torch.utils.data.distributed.DistributedSampler(
          train_dataset,
          num_replicas=xm.xrt_world_size(),
          rank=xm.get_ordinal(),
          shuffle=True)
      test_sampler = torch.utils.data.distributed.DistributedSampler(
          test_dataset,
          num_replicas=xm.xrt_world_size(),
          rank=xm.get_ordinal(),
          shuffle=False)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=FLAGS.batch_size,
        sampler=train_sampler,
        # DataLoader forbids shuffle=True together with a sampler.
        shuffle=False if train_sampler else True,
        num_workers=FLAGS.num_workers)
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=FLAGS.batch_size,
        sampler=test_sampler,
        shuffle=False,
        num_workers=FLAGS.num_workers)

  torch.manual_seed(42)

  devices = (
      xm.get_xla_supported_devices(
          max_devices=FLAGS.num_cores) if FLAGS.num_cores != 0 else [])
  # Pass [] as device_ids to run using the PyTorch/CPU engine.
  model = torchvision.models.resnet18 if FLAGS.use_torchvision else ResNet18
  model_parallel = dp.DataParallel(model, device_ids=devices)

  def train_loop_fn(model, loader, device, context):
    # Runs one training epoch on a single device replica.
    loss_fn = nn.CrossEntropyLoss()
    # The optimizer is created once per device and cached on the per-device
    # context so it persists across epochs.
    optimizer = context.getattr_or(
        'optimizer', lambda: optim.SGD(
            model.parameters(),
            lr=FLAGS.lr,
            momentum=FLAGS.momentum,
            weight_decay=5e-4))
    tracker = xm.RateTracker()

    model.train()
    for x, (data, target) in loader:
      optimizer.zero_grad()
      output = model(data)
      loss = loss_fn(output, target)
      loss.backward()
      xm.optimizer_step(optimizer)
      tracker.add(FLAGS.batch_size)
      if x % FLAGS.log_steps == 0:
        test_utils.print_training_update(device, x, loss.item(),
                                         tracker.rate(),
                                         tracker.global_rate())

  def test_loop_fn(model, loader, device, context):
    # Evaluates on a single device replica; returns accuracy in percent.
    total_samples = 0
    correct = 0
    model.eval()
    for x, (data, target) in loader:
      output = model(data)
      pred = output.max(1, keepdim=True)[1]
      correct += pred.eq(target.view_as(pred)).sum().item()
      total_samples += data.size()[0]

    accuracy = 100.0 * correct / total_samples
    test_utils.print_test_update(device, accuracy)
    return accuracy

  accuracy = 0.0
  for epoch in range(1, FLAGS.num_epochs + 1):
    model_parallel(train_loop_fn, train_loader)
    # One accuracy value per device; report their mean.
    accuracies = model_parallel(test_loop_fn, test_loader)
    accuracy = sum(accuracies) / len(accuracies)
    print("Epoch: {}, Mean Accuracy: {:.2f}%".format(epoch, accuracy))
    if FLAGS.metrics_debug:
      print(torch_xla._XLAC._xla_metrics_report())

  return accuracy
def train_imagenet():
  """Trains an ImageNet model on a single XLA device.

  The model constructor and its input size come from get_model_property();
  data paths and hyperparameters come from the module-level FLAGS object.

  Returns:
    The test accuracy (percent) measured after the final epoch.
  """
  print('==> Preparing data..')
  img_dim = get_model_property('img_dim')
  if FLAGS.fake_data:
    train_dataset_len = 1200000  # Roughly the size of Imagenet dataset.
    train_loader = xu.SampleGenerator(
        data=(torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim),
              torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
        sample_count=train_dataset_len // FLAGS.batch_size //
        xm.xrt_world_size())
    test_loader = xu.SampleGenerator(
        data=(torch.zeros(FLAGS.test_set_batch_size, 3, img_dim, img_dim),
              torch.zeros(FLAGS.test_set_batch_size, dtype=torch.int64)),
        # Bug fix: the number of fake test batches must be derived from the
        # test batch size (which sizes the tensors above), not the train
        # batch size.
        sample_count=50000 // FLAGS.test_set_batch_size //
        xm.xrt_world_size())
  else:
    # Standard ImageNet per-channel normalization constants.
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    train_dataset = torchvision.datasets.ImageFolder(
        os.path.join(FLAGS.datadir, 'train'),
        transforms.Compose([
            transforms.RandomResizedCrop(img_dim),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    train_dataset_len = len(train_dataset.imgs)
    resize_dim = max(img_dim, 256)
    test_dataset = torchvision.datasets.ImageFolder(
        os.path.join(FLAGS.datadir, 'val'),
        # Matches Torchvision's eval transforms except Torchvision uses size
        # 256 resize for all models both here and in the train loader. Their
        # version crashes during training on 299x299 images, e.g. inception.
        transforms.Compose([
            transforms.Resize(resize_dim),
            transforms.CenterCrop(img_dim),
            transforms.ToTensor(),
            normalize,
        ]))

    train_sampler = None
    if xm.xrt_world_size() > 1:
      # Shard the training set so that each replica sees a disjoint slice.
      train_sampler = torch.utils.data.distributed.DistributedSampler(
          train_dataset,
          num_replicas=xm.xrt_world_size(),
          rank=xm.get_ordinal(),
          shuffle=True)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=FLAGS.batch_size,
        sampler=train_sampler,
        # DataLoader forbids shuffle=True together with a sampler.
        shuffle=False if train_sampler else True,
        num_workers=FLAGS.num_workers)
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=FLAGS.test_set_batch_size,
        shuffle=False,
        num_workers=FLAGS.num_workers)

  torch.manual_seed(42)

  device = xm.xla_device()
  model = get_model_property('model_fn')().to(device)
  writer = SummaryWriter(log_dir=FLAGS.logdir) if FLAGS.logdir else None
  optimizer = optim.SGD(
      model.parameters(),
      lr=FLAGS.lr,
      momentum=FLAGS.momentum,
      weight_decay=5e-4)
  num_training_steps_per_epoch = train_dataset_len // (
      FLAGS.batch_size * xm.xrt_world_size())
  # Scheduler flags are optional; getattr() tolerates their absence.
  lr_scheduler = schedulers.wrap_optimizer_with_scheduler(
      optimizer,
      scheduler_type=getattr(FLAGS, 'lr_scheduler_type', None),
      scheduler_divisor=getattr(FLAGS, 'lr_scheduler_divisor', None),
      scheduler_divide_every_n_epochs=getattr(
          FLAGS, 'lr_scheduler_divide_every_n_epochs', None),
      num_steps_per_epoch=num_training_steps_per_epoch,
      # Only the master ordinal writes scheduler summaries.
      summary_writer=writer if xm.is_master_ordinal() else None)
  loss_fn = nn.CrossEntropyLoss()

  def train_loop_fn(loader):
    # One training epoch; the loader yields (step, (data, target)) tuples.
    tracker = xm.RateTracker()
    model.train()
    for x, (data, target) in loader:
      optimizer.zero_grad()
      output = model(data)
      loss = loss_fn(output, target)
      loss.backward()
      xm.optimizer_step(optimizer)
      tracker.add(FLAGS.batch_size)
      # The scheduler is stepped per optimizer step, not per epoch.
      if lr_scheduler:
        lr_scheduler.step()
      if x % FLAGS.log_steps == 0:
        test_utils.print_training_update(device, x, loss.item(),
                                         tracker.rate(),
                                         tracker.global_rate())

  def test_loop_fn(loader):
    # Full evaluation pass; returns accuracy in percent.
    total_samples = 0
    correct = 0
    model.eval()
    for x, (data, target) in loader:
      output = model(data)
      pred = output.max(1, keepdim=True)[1]
      correct += pred.eq(target.view_as(pred)).sum().item()
      total_samples += data.size()[0]

    accuracy = 100.0 * correct / total_samples
    test_utils.print_test_update(device, accuracy)
    return accuracy

  accuracy = 0.0
  for epoch in range(1, FLAGS.num_epochs + 1):
    para_loader = dp.ParallelLoader(train_loader, [device])
    train_loop_fn(para_loader.per_device_loader(device))

    para_loader = dp.ParallelLoader(test_loader, [device])
    accuracy = test_loop_fn(para_loader.per_device_loader(device))
    print('Epoch: {}, Mean Accuracy: {:.2f}%'.format(epoch, accuracy))
    test_utils.add_scalar_to_summary(writer, 'Accuracy/test', accuracy, epoch)
    if FLAGS.metrics_debug:
      print(torch_xla._XLAC._xla_metrics_report())

  return accuracy
def train_imagenet():
  """Trains an ImageNet model across the available XLA devices.

  The model constructor and its input size come from get_model_property();
  data paths and hyperparameters come from the module-level FLAGS object.

  Returns:
    The mean test accuracy (percent) across devices for the final epoch.
  """
  print('==> Preparing data..')
  img_dim = get_model_property('img_dim')
  if FLAGS.fake_data:
    train_loader = xu.SampleGenerator(
        data=(torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim),
              torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
        # 1200000 is roughly the size of the Imagenet dataset.
        sample_count=1200000 // FLAGS.batch_size // xm.xrt_world_size())
    test_loader = xu.SampleGenerator(
        data=(torch.zeros(FLAGS.test_set_batch_size, 3, img_dim, img_dim),
              torch.zeros(FLAGS.test_set_batch_size, dtype=torch.int64)),
        # Bug fix: the number of fake test batches must be derived from the
        # test batch size (which sizes the tensors above), not the train
        # batch size.
        sample_count=50000 // FLAGS.test_set_batch_size //
        xm.xrt_world_size())
  else:
    # Standard ImageNet per-channel normalization constants.
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    train_dataset = torchvision.datasets.ImageFolder(
        os.path.join(FLAGS.datadir, 'train'),
        transforms.Compose([
            transforms.RandomResizedCrop(img_dim),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    resize_dim = max(img_dim, 256)
    test_dataset = torchvision.datasets.ImageFolder(
        os.path.join(FLAGS.datadir, 'val'),
        # Matches Torchvision's eval transforms except Torchvision uses size
        # 256 resize for all models both here and in the train loader. Their
        # version crashes during training on 299x299 images, e.g. inception.
        transforms.Compose([
            transforms.Resize(resize_dim),
            transforms.CenterCrop(img_dim),
            transforms.ToTensor(),
            normalize,
        ]))

    train_sampler = None
    test_sampler = None
    if xm.xrt_world_size() > 1:
      # Shard the datasets so that each replica sees a disjoint slice.
      train_sampler = torch.utils.data.distributed.DistributedSampler(
          train_dataset,
          num_replicas=xm.xrt_world_size(),
          rank=xm.get_ordinal(),
          shuffle=True)
      test_sampler = torch.utils.data.distributed.DistributedSampler(
          test_dataset,
          num_replicas=xm.xrt_world_size(),
          rank=xm.get_ordinal(),
          shuffle=False)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=FLAGS.batch_size,
        sampler=train_sampler,
        # DataLoader forbids shuffle=True together with a sampler.
        shuffle=False if train_sampler else True,
        num_workers=FLAGS.num_workers)
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=FLAGS.test_set_batch_size,
        sampler=test_sampler,
        shuffle=False,
        num_workers=FLAGS.num_workers)

  torch.manual_seed(42)

  devices = (
      xm.get_xla_supported_devices(
          max_devices=FLAGS.num_cores) if FLAGS.num_cores != 0 else [])
  # Pass [] as device_ids to run using the PyTorch/CPU engine.
  torchvision_model = get_model_property('model_fn')
  model_parallel = dp.DataParallel(torchvision_model, device_ids=devices)

  def train_loop_fn(model, loader, device, context):
    # Runs one training epoch on a single device replica.
    loss_fn = nn.CrossEntropyLoss()
    # The optimizer is created once per device and cached on the per-device
    # context so it persists across epochs.
    optimizer = context.getattr_or(
        'optimizer', lambda: optim.SGD(
            model.parameters(),
            lr=FLAGS.lr,
            momentum=FLAGS.momentum,
            weight_decay=5e-4))
    tracker = xm.RateTracker()

    model.train()
    for x, (data, target) in loader:
      optimizer.zero_grad()
      output = model(data)
      loss = loss_fn(output, target)
      loss.backward()
      xm.optimizer_step(optimizer)
      tracker.add(FLAGS.batch_size)
      if x % FLAGS.log_steps == 0:
        test_utils.print_training_update(device, x, loss.item(),
                                         tracker.rate(),
                                         tracker.global_rate())

  def test_loop_fn(model, loader, device, context):
    # Evaluates on a single device replica; returns accuracy in percent.
    total_samples = 0
    correct = 0
    model.eval()
    for x, (data, target) in loader:
      output = model(data)
      pred = output.max(1, keepdim=True)[1]
      correct += pred.eq(target.view_as(pred)).sum().item()
      total_samples += data.size()[0]

    accuracy = 100.0 * correct / total_samples
    test_utils.print_test_update(device, accuracy)
    return accuracy

  accuracy = 0.0
  writer = SummaryWriter(log_dir=FLAGS.logdir) if FLAGS.logdir else None
  for epoch in range(1, FLAGS.num_epochs + 1):
    model_parallel(train_loop_fn, train_loader)
    # One accuracy value per device; report their mean.
    accuracies = model_parallel(test_loop_fn, test_loader)
    accuracy = mean(accuracies)
    print("Epoch: {}, Mean Accuracy: {:.2f}%".format(epoch, accuracy))
    test_utils.add_scalar_to_summary(writer, 'Accuracy/test', accuracy, epoch)
    if FLAGS.metrics_debug:
      print(torch_xla._XLAC._xla_metrics_report())

  return accuracy
def train_mnist():
  """Trains the MNIST model across the available XLA devices.

  All configuration (batch size, data dir, epochs, LR, ...) is read from the
  module-level FLAGS object.

  Returns:
    The mean test accuracy (percent) across devices for the final epoch.
  """
  torch.manual_seed(1)

  if FLAGS.fake_data:
    # Synthetic all-zero batches shaped like MNIST (1x28x28); sample counts
    # mirror the real 60k/10k split, divided across replicas.
    train_loader = xu.SampleGenerator(
        data=(torch.zeros(FLAGS.batch_size, 1, 28, 28),
              torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
        sample_count=60000 // FLAGS.batch_size // xm.xrt_world_size())
    test_loader = xu.SampleGenerator(
        data=(torch.zeros(FLAGS.batch_size, 1, 28, 28),
              torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
        sample_count=10000 // FLAGS.batch_size // xm.xrt_world_size())
  else:
    train_dataset = datasets.MNIST(
        FLAGS.datadir,
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.1307, ), (0.3081, ))]))
    test_dataset = datasets.MNIST(
        FLAGS.datadir,
        train=False,
        transform=transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.1307, ), (0.3081, ))]))
    train_sampler = None
    test_sampler = None
    if xm.xrt_world_size() > 1:
      # Shard the datasets so that each replica sees a disjoint slice.
      train_sampler = torch.utils.data.distributed.DistributedSampler(
          train_dataset,
          num_replicas=xm.xrt_world_size(),
          rank=xm.get_ordinal(),
          shuffle=True)
      test_sampler = torch.utils.data.distributed.DistributedSampler(
          test_dataset,
          num_replicas=xm.xrt_world_size(),
          rank=xm.get_ordinal(),
          shuffle=False)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=FLAGS.batch_size,
        sampler=train_sampler,
        # DataLoader forbids shuffle=True together with a sampler.
        shuffle=False if train_sampler else True,
        num_workers=FLAGS.num_workers)
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=FLAGS.batch_size,
        sampler=test_sampler,
        shuffle=False,
        num_workers=FLAGS.num_workers)

  devices = (
      xm.get_xla_supported_devices(
          max_devices=FLAGS.num_cores) if FLAGS.num_cores != 0 else [])
  # Scale learning rate to num cores
  lr = FLAGS.lr * max(len(devices), 1)
  # Pass [] as device_ids to run using the PyTorch/CPU engine.
  model_parallel = dp.DataParallel(MNIST, device_ids=devices)

  def train_loop_fn(model, loader, device, context):
    # Runs one training epoch on a single device replica.
    loss_fn = nn.NLLLoss()
    # The optimizer is created once per device and cached on the per-device
    # context so it persists across epochs.
    optimizer = context.getattr_or(
        'optimizer', lambda: optim.SGD(
            model.parameters(), lr=lr, momentum=FLAGS.momentum))
    tracker = xm.RateTracker()

    model.train()
    for x, (data, target) in loader:
      optimizer.zero_grad()
      output = model(data)
      loss = loss_fn(output, target)
      loss.backward()
      xm.optimizer_step(optimizer)
      tracker.add(FLAGS.batch_size)
      if x % FLAGS.log_steps == 0:
        test_utils.print_training_update(device, x, loss.item(),
                                         tracker.rate(),
                                         tracker.global_rate())

  def test_loop_fn(model, loader, device, context):
    # Evaluates on a single device replica; returns accuracy in percent.
    total_samples = 0
    correct = 0
    model.eval()
    for x, (data, target) in loader:
      output = model(data)
      pred = output.max(1, keepdim=True)[1]
      correct += pred.eq(target.view_as(pred)).sum().item()
      total_samples += data.size()[0]

    accuracy = 100.0 * correct / total_samples
    test_utils.print_test_update(device, accuracy)
    return accuracy

  accuracy = 0.0
  for epoch in range(1, FLAGS.num_epochs + 1):
    model_parallel(train_loop_fn, train_loader)
    # One accuracy value per device; report their mean.
    accuracies = model_parallel(test_loop_fn, test_loader)
    accuracy = sum(accuracies) / len(accuracies)
    if FLAGS.metrics_debug:
      print(torch_xla._XLAC._xla_metrics_report())

  return accuracy
def _get_device_spec(device):
  """Formats a device as 'device/ordinal', or just str(device) when the
  current process has no ordinal assigned."""
  ordinal = xm.get_ordinal(defval=-1)
  if ordinal < 0:
    return str(device)
  return '{}/{}'.format(device, ordinal)