def _xla_run(model, input, device='TPU'):
  """Trace `model` with XLA and run it on `input`.

  A tuple/list `input` is treated as one batch per replica core (the model
  is traced on the first batch); any other input runs single-core.
  Returns regular tensors converted back from the XLA outputs.
  """
  if not isinstance(input, (tuple, list)):
    # Single-core path: trace on the input itself, return the first output.
    xla_model = xm.XlaModel(model, [input], full_conv_precision=True)
    return xla_model(input)[0]
  # Replicated path: one device string per input batch.
  replica_devices = [
      '{}:{}'.format(device, ordinal) for ordinal in range(len(input))
  ]
  xla_model = xm.XlaModel(
      model,
      input[0],
      num_cores=len(input),
      devices=replica_devices,
      full_conv_precision=True)
  return xm.convert_to_tensors(xla_model(*input))
def test(self):
  """Train AxPlusB through the XlaModel helpers and expect 100% accuracy."""
  batch_size = 128
  scaler = torch.Tensor([[1.0 / batch_size]])

  def loss_fn(x, y):
    # Scaled squared error: (x - y)^T (x - y) / batch_size.
    diff = x - y
    return diff.t().mm(diff).mm(scaler)

  A = 3.11
  B = 4.09
  gen = xu.FnDataGenerator(
      lambda x: x * A + B, batch_size, _gen_tensor, count=100)
  model = AxPlusB(dims=(batch_size, 1))
  xla_model = xm.XlaModel(
      model, [_gen_tensor(batch_size, 1)],
      target=_gen_tensor(batch_size, 1),
      loss_fn=loss_fn,
      num_cores=1,
      devices=[':0'])
  optimizer = optim.SGD(xla_model.parameters_list(), lr=0.1, momentum=0.5)
  xla_model.train(gen, optimizer, batch_size, log_fn=None)

  def eval_fn(output, target):
    # Count predictions whose squared error falls below the 1e-5 threshold.
    mloss = (output - target) * (output - target)
    threshold = torch.ones_like(mloss) * 1e-5
    hits = torch.le(mloss, threshold).sum()
    return mloss.mean().item(), hits.item()

  gen = xu.FnDataGenerator(lambda x: x * A + B, batch_size, _gen_tensor)
  accuracy = xla_model.test(gen, eval_fn, batch_size, log_fn=None)
  self.assertEqual(accuracy, 100.0)
def train_imagenet():
  """Train ResNet-50 on ImageNet via torch/XLA.

  Returns the test accuracy measured after the final epoch.
  """
  print('==> Preparing data..')
  normalize = transforms.Normalize(
      mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  train_dataset = torchvision.datasets.ImageFolder(
      os.path.join(FLAGS.datadir, 'train'),
      transforms.Compose([
          transforms.RandomResizedCrop(224),
          transforms.RandomHorizontalFlip(),
          transforms.ToTensor(),
          normalize,
      ]))
  train_loader = torch.utils.data.DataLoader(
      train_dataset,
      batch_size=FLAGS.batch_size,
      shuffle=True,
      num_workers=FLAGS.num_workers)
  # BUGFIX: evaluation must be deterministic. The previous code applied the
  # training-time random augmentations (RandomResizedCrop + flip) to the val
  # split and shuffled it; use the standard Resize + CenterCrop pipeline and
  # no shuffling instead.
  test_dataset = torchvision.datasets.ImageFolder(
      os.path.join(FLAGS.datadir, 'val'),
      transforms.Compose([
          transforms.Resize(256),
          transforms.CenterCrop(224),
          transforms.ToTensor(),
          normalize,
      ]))
  test_loader = torch.utils.data.DataLoader(
      test_dataset,
      batch_size=FLAGS.batch_size,
      shuffle=False,
      num_workers=FLAGS.num_workers)

  torch.manual_seed(42)

  print('==> Building model..')
  momentum = 0.9
  lr = 0.1
  # Log less often as the replica count grows.
  log_interval = max(1, int(10 / FLAGS.num_cores))

  model = torchvision.models.resnet50()
  cross_entropy_loss = nn.CrossEntropyLoss()

  # Trace the model once on zero tensors of the right shape; one device
  # string per replica core.
  devices = [':{}'.format(n) for n in range(0, FLAGS.num_cores)]
  inputs = torch.zeros(FLAGS.batch_size, 3, 224, 224)
  target = torch.zeros(FLAGS.batch_size, dtype=torch.int64)
  xla_model = xm.XlaModel(
      model, [inputs],
      loss_fn=cross_entropy_loss,
      target=target,
      num_cores=FLAGS.num_cores,
      devices=devices)
  optimizer = optim.SGD(
      xla_model.parameters_list(), lr=lr, momentum=momentum, weight_decay=5e-4)

  for epoch in range(1, FLAGS.num_epochs + 1):
    xla_model.train(
        train_loader, optimizer, FLAGS.batch_size, log_interval=log_interval)
    if FLAGS.metrics_debug:
      print(torch_xla._XLAC._xla_metrics_report())
    accuracy = xla_model.test(
        test_loader, _cross_entropy_loss_eval_fn(cross_entropy_loss),
        FLAGS.batch_size)
    # Simple exponential learning-rate decay after every epoch.
    xm.update_optimizer_state(optimizer, 'lr', lambda x: x / 1.025)
  return accuracy
def _xla_run(model, input, device='TPU'):
  """Trace `model` and run it; replicated when `input` is a tuple/list.

  For the replicated case the result is a tuple (one entry per replica) of
  tuples of plain tensors; for the single-core case it is the first output.
  """
  if isinstance(input, (tuple, list)):
    devices = ['{}:{}'.format(device, core) for core in range(len(input))]
    xla_model = xm.XlaModel(
        model,
        input[0],
        num_cores=len(input),
        devices=devices,
        full_conv_precision=True)
    output_xla = xla_model(*input)
    # Convert every replica's XLA outputs back to regular tensors.
    return tuple(
        tuple(o.to_tensor() for o in xm.as_list(replica_outputs))
        for replica_outputs in output_xla)
  xla_model = xm.XlaModel(model, [input], full_conv_precision=True)
  return xla_model(input)[0]
def train_mnist():
  """Train the MNIST convnet through the XlaModel train/test helpers.

  Returns the test accuracy measured after the final epoch.
  """
  torch.manual_seed(1)
  # Training settings; the learning rate scales with the replica count.
  lr = 0.01 * FLAGS.num_cores
  momentum = 0.5
  log_interval = max(1, int(10 / FLAGS.num_cores))

  # The normalization transform is stateless, so both loaders can share it.
  mnist_transform = transforms.Compose([
      transforms.ToTensor(),
      transforms.Normalize((0.1307, ), (0.3081, ))
  ])
  train_loader = torch.utils.data.DataLoader(
      datasets.MNIST(
          FLAGS.datadir, train=True, download=True, transform=mnist_transform),
      batch_size=FLAGS.batch_size,
      shuffle=True,
      num_workers=FLAGS.num_workers)
  test_loader = torch.utils.data.DataLoader(
      datasets.MNIST(FLAGS.datadir, train=False, transform=mnist_transform),
      batch_size=FLAGS.batch_size,
      shuffle=True,
      num_workers=FLAGS.num_workers)

  model = MNIST()

  # Trace the model on zero tensors shaped like a training batch, with one
  # device string per replica core.
  devices = [':{}'.format(core) for core in range(FLAGS.num_cores)]
  inputs = torch.zeros(FLAGS.batch_size, 1, 28, 28)
  target = torch.zeros(FLAGS.batch_size, dtype=torch.int64)
  xla_model = xm.XlaModel(
      model, [inputs],
      loss_fn=F.nll_loss,
      target=target,
      num_cores=FLAGS.num_cores,
      devices=devices)
  optimizer = optim.SGD(xla_model.parameters_list(), lr=lr, momentum=momentum)

  for epoch in range(1, FLAGS.num_epochs + 1):
    xla_model.train(
        train_loader, optimizer, FLAGS.batch_size, log_interval=log_interval)
    if FLAGS.metrics_debug:
      print(torch_xla._XLAC._xla_metrics_report())
    accuracy = xla_model.test(test_loader, xm.category_eval_fn(F.nll_loss),
                              FLAGS.batch_size)
  return accuracy
def test(self):
  """Fit y = A*x + B with a manual SGD loop; loss must converge to zero."""
  A = 3.11
  B = 4.09
  model = AxPlusB(dims=(1, 1))
  xla_model = xm.XlaModel(model, [_gen_tensor(1, 1)])
  optimizer = optim.SGD(xla_model.parameters_list(), lr=0.1, momentum=0.5)
  criterion = SquareLoss()
  loss = None
  for _ in range(100):
    optimizer.zero_grad()
    sample = _gen_tensor(1, 1)
    expected = sample * A + B
    output = xla_model(sample)
    loss = criterion(output[0], expected)
    # Backprop through the CPU-side loss, then through the XLA graph.
    loss.backward()
    xla_model.backward(output)
    optimizer.step()
  self.assertEqualRel(loss.sum(), torch.tensor(0.0))
def test(self):
  """Fit y = A*x + B on generated batches; loss must converge to zero."""
  A = 3.11
  B = 4.09
  batch_size = 128
  data_gen = FnDataGenerator(lambda x: x * A + B, batch_size, count=100)
  model = AxPlusB(dims=(batch_size, 1))
  xla_model = xm.XlaModel(model, [torch.randn(batch_size, 1)])
  optimizer = optim.SGD(xla_model.parameters_list(), lr=0.1, momentum=0.5)
  criterion = SquareLoss()
  loss = None
  for data, target in data_gen:
    optimizer.zero_grad()
    output = xla_model(data)
    loss = criterion(output[0], target)
    # Backprop through the CPU-side loss, then through the XLA graph.
    loss.backward()
    xla_model.backward(output)
    optimizer.step()
  self.assertEqualRel(loss.sum(), torch.tensor(0.0))
def compareModel(self, model, input, rel_err=0.05, abs_err=1e-4):
  """Assert XLA forward output and gradients match native PyTorch.

  Runs the same input through the XLA-traced model and the native model,
  compares outputs, then drives a backward pass on both and compares every
  parameter gradient within the given relative/absolute tolerances.
  """
  xla_model = xm.XlaModel(model, [input], full_conv_precision=True)
  output_xla = xla_model(input)
  output = model(input)
  self.assertEqualRel(
      output,
      xm.convert_to_tensors(output_xla)[0],
      rel_err=rel_err,
      abs_err=abs_err)
  grad_output = _gen_tensor(*output.shape)  # random gradients
  grad_output.grad = grad_output.data
  output.backward(grad_output)
  xla_model.backward([grad_output])
  xla_updated_params = [p.grad.to_tensor() for p in xla_model.parameters()[0]]
  updated_params = [p.grad for p in model.parameters()]
  self.assertEqual(len(xla_updated_params), len(updated_params))
  for xla_grad, native_grad in zip(xla_updated_params, updated_params):
    self.assertEqualRel(
        xla_grad, native_grad, rel_err=rel_err, abs_err=abs_err)
def train_cifar():
  """Train ResNet18 on CIFAR-10 (or synthetic data) via torch/XLA.

  Returns the test accuracy measured after the final epoch.
  """
  print('==> Preparing data..')
  if FLAGS.fake_data:
    # Synthetic zero batches shaped like CIFAR-10, for smoke testing without
    # downloading the dataset.
    train_loader = xu.SampleGenerator(
        data=torch.zeros(FLAGS.batch_size, 3, 32, 32),
        target=torch.zeros(FLAGS.batch_size, dtype=torch.int64),
        sample_count=50000 // FLAGS.batch_size)
    test_loader = xu.SampleGenerator(
        data=torch.zeros(FLAGS.batch_size, 3, 32, 32),
        target=torch.zeros(FLAGS.batch_size, dtype=torch.int64),
        sample_count=10000 // FLAGS.batch_size)
  else:
    # Normalization is stateless; share one instance across both pipelines.
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                     (0.2023, 0.1994, 0.2010))
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])
    trainset = torchvision.datasets.CIFAR10(
        root=FLAGS.datadir,
        train=True,
        download=True,
        transform=transform_train)
    train_loader = torch.utils.data.DataLoader(
        trainset,
        batch_size=FLAGS.batch_size,
        shuffle=True,
        num_workers=FLAGS.num_workers)
    testset = torchvision.datasets.CIFAR10(
        root=FLAGS.datadir,
        train=False,
        download=True,
        transform=transform_test)
    test_loader = torch.utils.data.DataLoader(
        testset,
        batch_size=FLAGS.batch_size,
        shuffle=False,
        num_workers=FLAGS.num_workers)

  torch.manual_seed(42)

  print('==> Building model..')
  momentum = 0.9
  lr = 0.1
  # Log less often as the replica count grows.
  log_interval = max(1, int(10 / FLAGS.num_cores))

  model = ResNet18()

  # Trace the model on zero tensors shaped like a training batch, with one
  # device string per replica core.
  devices = [':{}'.format(core) for core in range(FLAGS.num_cores)]
  inputs = torch.zeros(FLAGS.batch_size, 3, 32, 32)
  target = torch.zeros(FLAGS.batch_size, dtype=torch.int64)
  xla_model = xm.XlaModel(
      model, [inputs],
      loss_fn=F.nll_loss,
      target=target,
      num_cores=FLAGS.num_cores,
      devices=devices)
  optimizer = optim.SGD(
      xla_model.parameters_list(), lr=lr, momentum=momentum, weight_decay=5e-4)

  log_fn = xm.get_log_fn(logdir=FLAGS.logdir)
  for epoch in range(1, FLAGS.num_epochs + 1):
    xla_model.train(
        train_loader,
        optimizer,
        FLAGS.batch_size,
        log_interval=log_interval,
        metrics_debug=FLAGS.metrics_debug,
        log_fn=log_fn)
    accuracy = xla_model.test(
        test_loader,
        xm.category_eval_fn(F.nll_loss),
        FLAGS.batch_size,
        log_fn=log_fn)
    # Simple exponential learning-rate decay after every epoch.
    xm.update_optimizer_state(optimizer, 'lr', lambda x: x / 1.025)
  return accuracy
def train_mnist():
  """Single-core MNIST training written against the raw XlaModel API.

  Drives forward/backward manually (no XlaModel.train/test helpers),
  evaluates after every epoch, and returns the accuracy (percent) from the
  last epoch's eval pass.
  """
  # This manual loop only drives a single replica.
  assert FLAGS.num_cores == 1
  torch.manual_seed(1)
  # Training settings
  lr = 0.01
  momentum = 0.5
  log_interval = 5
  train_loader = torch.utils.data.DataLoader(
      datasets.MNIST(
          FLAGS.datadir,
          train=True,
          download=True,
          transform=transforms.Compose([
              transforms.ToTensor(),
              transforms.Normalize((0.1307,), (0.3081,))
          ])),
      batch_size=FLAGS.batch_size,
      shuffle=True,
      num_workers=FLAGS.num_workers)
  test_loader = torch.utils.data.DataLoader(
      datasets.MNIST(
          FLAGS.datadir,
          train=False,
          transform=transforms.Compose([
              transforms.ToTensor(),
              transforms.Normalize((0.1307,), (0.3081,))
          ])),
      batch_size=FLAGS.batch_size,
      shuffle=True,
      num_workers=FLAGS.num_workers)
  model = MNIST()
  # Trace the model on a zero batch of the training input shape.
  inputs = torch.zeros(FLAGS.batch_size, 1, 28, 28)
  xla_model = xm.XlaModel(model, [inputs])
  optimizer = optim.SGD(xla_model.parameters_list(), lr=lr, momentum=momentum)
  loss_fn = nn.NLLLoss()
  accuracy = None
  for epoch in range(1, FLAGS.num_epochs + 1):
    # Training loop for epoch.
    start_time = time.time()
    processed = 0
    for batch_idx, (data, target) in enumerate(train_loader):
      # The model was traced at a fixed batch size, so a trailing partial
      # batch is skipped rather than fed through.
      if data.size()[0] != FLAGS.batch_size:
        break
      optimizer.zero_grad()
      y = xla_model(data)
      # The loss is computed on the CPU-side output; requires_grad must be
      # set so loss.backward() produces the gradient that
      # xla_model.backward(y) then propagates through the XLA graph.
      y[0].requires_grad = True
      loss = loss_fn(y[0], target)
      loss.backward()
      xla_model.backward(y)
      optimizer.step()
      processed += FLAGS.batch_size
      if batch_idx % log_interval == 0:
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\t'
              'Loss: {:.6f}\tSamples/sec: {:.1f}'.format(
                  epoch, processed,
                  len(train_loader) * FLAGS.batch_size,
                  100. * batch_idx / len(train_loader), loss,
                  processed / (time.time() - start_time)))
    # Eval loop for epoch.
    start_time = time.time()
    correct_count = 0
    test_loss = 0
    count = 0
    for batch_idx, (data, target) in enumerate(test_loader):
      # Same fixed-batch-size constraint as the training loop.
      if data.size()[0] != FLAGS.batch_size:
        break
      y = xla_model(data)
      test_loss += loss_fn(y[0], target).sum().item()
      pred = y[0].max(1, keepdim=True)[1]  # index of the max log-probability
      correct_count += pred.eq(target.view_as(pred)).sum().item()
      count += FLAGS.batch_size
    test_loss /= count
    accuracy = 100.0 * correct_count / count
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%), '
          'Samples/sec: {:.1f}\n'.format(test_loss, correct_count, count,
                                         accuracy,
                                         count / (time.time() - start_time)))
    # Debug metric dumping.
    if FLAGS.metrics_debug:
      print(torch_xla._XLAC._xla_metrics_report())
  return accuracy