def test(self):
    """Train AxPlusB through XlaModel.train() and expect 100% eval accuracy.

    Fits y = A*x + B with SGD on generated data, then evaluates with a
    per-sample squared-error tolerance of 1e-5.
    """
    batch_size = 128
    # Scale the dot-product loss by 1/batch_size to get a mean squared error.
    loss_scale = torch.Tensor([[1.0 / batch_size]])

    def loss_fn(x, y):
        # Squared-error loss: (x - y)^T (x - y), scaled down to a mean.
        residual = x - y
        return residual.t().mm(residual).mm(loss_scale)

    slope = 3.11
    offset = 4.09
    train_gen = xu.FnDataGenerator(
        lambda x: x * slope + offset, batch_size, _gen_tensor, count=100)
    model = AxPlusB(dims=(batch_size, 1))
    xla_model = xm.XlaModel(
        model, [_gen_tensor(batch_size, 1)],
        target=_gen_tensor(batch_size, 1),
        loss_fn=loss_fn,
        num_cores=1,
        devices=[':0'])
    optimizer = optim.SGD(xla_model.parameters_list(), lr=0.1, momentum=0.5)
    xla_model.train(train_gen, optimizer, batch_size, log_fn=None)

    def eval_fn(output, target):
        # A sample counts as correct when its squared error is <= 1e-5.
        sq_err = (output - target) * (output - target)
        tolerance = torch.ones_like(sq_err) * 1e-5
        hits = torch.le(sq_err, tolerance).sum()
        return sq_err.mean().item(), hits.item()

    eval_gen = xu.FnDataGenerator(lambda x: x * slope + offset, batch_size,
                                  _gen_tensor)
    accuracy = xla_model.test(eval_gen, eval_fn, batch_size, log_fn=None)
    self.assertEqual(accuracy, 100.0)
def test(self):
    """Verify ParallelLoader.to() places a batch on every supported device."""
    device_list = xm.get_xla_supported_devices()
    slope = 3.11
    intercept = 4.09
    batch_size = 128 * len(device_list)
    data_gen = xu.FnDataGenerator(
        lambda x: x * slope + intercept,
        batch_size,
        _gen_tensor,
        dims=[8],
        count=10)
    loader = pl.ParallelLoader(data_gen, batch_size, device_list)
    for _, (data, _) in loader:
        # Each batch must be movable to every device and report that device.
        for device in device_list:
            moved = loader.to(data, device)
            self.assertEqual(moved.device, torch.device(device))
def test(self):
    """Drive an XlaModel training loop by hand and check the loss vanishes.

    Fits y = A*x + B with SGD over 100 generated batches, then asserts the
    final loss is (relatively) equal to zero.
    """
    slope = 3.11
    intercept = 4.09
    batch_size = 128
    data_gen = xu.FnDataGenerator(
        lambda x: x * slope + intercept, batch_size, _gen_tensor, count=100)
    model = AxPlusB(dims=(batch_size, 1))
    xla_model = xm.XlaModel(model, [_gen_tensor(batch_size, 1)])
    optimizer = optim.SGD(xla_model.parameters_list(), lr=0.1, momentum=0.5)
    square_loss = SquareLoss()
    loss = None
    for inputs, target in data_gen:
        optimizer.zero_grad()
        outputs = xla_model(inputs)
        loss = square_loss(outputs[0], target)
        # Backprop through the loss first, then through the XLA model graph.
        loss.backward()
        xla_model.backward(outputs)
        optimizer.step()
    # After 100 SGD steps the linear fit should be near-exact.
    self.assertEqualRel(loss.sum(), torch.tensor(0.0))