Example #1
0
def _xla_run(model, input, device='TPU'):
    """Run `model` through XLA on `input`.

    A tuple/list input is treated as one batch per replica: the model is
    replicated over `len(input)` devices and the per-replica outputs are
    converted back to regular tensors. Any other input runs on a single core
    and the lone output is returned directly.
    """
    if not isinstance(input, (tuple, list)):
        # Single-core path: wrap the input for tracing, unwrap the output.
        xla_model = xm.XlaModel(model, [input], full_conv_precision=True)
        return xla_model(input)[0]
    # Replicated path: one '<device>:<n>' entry per input batch.
    replica_devices = ['{}:{}'.format(device, core) for core in range(len(input))]
    xla_model = xm.XlaModel(model,
                            input[0],
                            num_cores=len(input),
                            devices=replica_devices,
                            full_conv_precision=True)
    return xm.convert_to_tensors(xla_model(*input))
Example #2
0
    def test(self):
        """End-to-end check: trains AxPlusB on samples of y = A*x + B through
        the XLA model wrapper, then asserts that 100% of test predictions fall
        within tolerance of the target."""
        batch_size = 128
        # 1x1 scaling matrix that turns the summed squared error into a
        # per-sample mean inside loss_fn.
        scaler = torch.Tensor([[1.0 / batch_size]])

        def loss_fn(x, y):
            # Sum-of-squares loss: diff^T @ diff yields a 1x1 matrix, then
            # scaled by 1/batch_size via matrix multiply.
            diff = x - y
            sloss = diff.t().mm(diff)
            return sloss.mm(scaler)

        A = 3.11
        B = 4.09
        # Generator yielding 100 (x, A*x + B) training batches.
        gen = xu.FnDataGenerator(lambda x: x * A + B,
                                 batch_size,
                                 _gen_tensor,
                                 count=100)
        model = AxPlusB(dims=(batch_size, 1))
        # Trace the model on a single core (device ':0'); dummy input/target
        # tensors fix the traced shapes.
        xla_model = xm.XlaModel(model, [_gen_tensor(batch_size, 1)],
                                target=_gen_tensor(batch_size, 1),
                                loss_fn=loss_fn,
                                num_cores=1,
                                devices=[':0'])
        optimizer = optim.SGD(xla_model.parameters_list(),
                              lr=0.1,
                              momentum=0.5)
        xla_model.train(gen, optimizer, batch_size, log_fn=None)

        def eval_fn(output, target):
            # Returns (mean squared error, count of elements whose squared
            # error is <= 1e-5, i.e. counted as "correct").
            mloss = (output - target) * (output - target)
            error = torch.ones_like(mloss) * 1e-5
            count = torch.le(mloss, error).sum()
            return mloss.mean().item(), count.item()

        # Fresh generator for evaluation; expect every prediction in tolerance.
        gen = xu.FnDataGenerator(lambda x: x * A + B, batch_size, _gen_tensor)
        accuracy = xla_model.test(gen, eval_fn, batch_size, log_fn=None)
        self.assertEqual(accuracy, 100.0)
Example #3
0
def train_imagenet():
    """Train a torchvision ResNet-50 on ImageNet through the XLA wrapper.

    Builds train/val pipelines from FLAGS.datadir, replicates the model over
    FLAGS.num_cores XLA devices, trains for FLAGS.num_epochs epochs with SGD
    (decaying the learning rate after each epoch), and returns the accuracy
    reported by the final epoch's evaluation (None if num_epochs < 1).
    """
    print('==> Preparing data..')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_dataset = torchvision.datasets.ImageFolder(
        os.path.join(FLAGS.datadir, 'train'),
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=FLAGS.batch_size, shuffle=True,
        num_workers=FLAGS.num_workers)
    # Fix: the validation pipeline previously reused the training-time random
    # augmentations (RandomResizedCrop + RandomHorizontalFlip), which makes
    # the reported accuracy noisy. Use the standard deterministic eval
    # transform: Resize(256) -> CenterCrop(224).
    test_dataset = torchvision.datasets.ImageFolder(
        os.path.join(FLAGS.datadir, 'val'),
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]))
    # Fix: no reason to shuffle the evaluation set.
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=FLAGS.batch_size, shuffle=False,
        num_workers=FLAGS.num_workers)

    torch.manual_seed(42)

    print('==> Building model..')
    momentum = 0.9
    lr = 0.1
    # Log roughly every 10 steps' worth of work, scaled by the replica count.
    log_interval = max(1, int(10 / FLAGS.num_cores))

    model = torchvision.models.resnet50()
    cross_entropy_loss = nn.CrossEntropyLoss()

    # One XLA device per replica; dummy input/target tensors fix the traced
    # shapes for the ImageNet geometry (batch, 3, 224, 224).
    devices = [':{}'.format(n) for n in range(0, FLAGS.num_cores)]
    inputs = torch.zeros(FLAGS.batch_size, 3, 224, 224)
    target = torch.zeros(FLAGS.batch_size, dtype=torch.int64)
    xla_model = xm.XlaModel(model, [inputs], loss_fn=cross_entropy_loss,
                            target=target, num_cores=FLAGS.num_cores,
                            devices=devices)
    optimizer = optim.SGD(xla_model.parameters_list(), lr=lr,
                          momentum=momentum, weight_decay=5e-4)

    accuracy = None  # Fix: defined even when FLAGS.num_epochs < 1.
    for epoch in range(1, FLAGS.num_epochs + 1):
        xla_model.train(train_loader, optimizer, FLAGS.batch_size,
                        log_interval=log_interval)
        if FLAGS.metrics_debug:
            # Dump XLA compilation/execution metrics for debugging.
            print(torch_xla._XLAC._xla_metrics_report())
        accuracy = xla_model.test(test_loader,
                                  _cross_entropy_loss_eval_fn(cross_entropy_loss),
                                  FLAGS.batch_size)
        # Gentle per-epoch learning-rate decay.
        xm.update_optimizer_state(optimizer, 'lr', lambda x: x / 1.025)
    return accuracy
Example #4
0
def _xla_run(model, input, device='TPU'):
    """Execute `model` via XLA on `input`.

    For a tuple/list input the model is replicated over one device per batch
    and the result is a tuple (one entry per replica) of tuples of plain
    tensors. For a single input the model runs on one core and the single
    output is returned.
    """
    if not isinstance(input, (tuple, list)):
        # Single-core path: wrap the input for tracing, unwrap the output.
        xla_model = xm.XlaModel(model, [input], full_conv_precision=True)
        return xla_model(input)[0]
    num_replicas = len(input)
    replica_devices = ['{}:{}'.format(device, idx) for idx in range(num_replicas)]
    xla_model = xm.XlaModel(model,
                            input[0],
                            num_cores=num_replicas,
                            devices=replica_devices,
                            full_conv_precision=True)
    replica_results = xla_model(*input)
    # Convert each replica's XLA outputs back to regular tensors.
    return tuple(
        tuple(o.to_tensor() for o in xm.as_list(per_replica))
        for per_replica in replica_results)
Example #5
0
def train_mnist():
    """Train the MNIST model across FLAGS.num_cores XLA cores and return the
    accuracy from the final epoch's evaluation."""
    torch.manual_seed(1)
    # Training settings
    # Learning rate scales linearly with the number of replicas.
    lr = 0.01 * FLAGS.num_cores
    momentum = 0.5
    log_interval = max(1, int(10 / FLAGS.num_cores))

    train_loader = torch.utils.data.DataLoader(datasets.MNIST(
        FLAGS.datadir,
        train=True,
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ])),
                                               batch_size=FLAGS.batch_size,
                                               shuffle=True,
                                               num_workers=FLAGS.num_workers)
    test_loader = torch.utils.data.DataLoader(datasets.MNIST(
        FLAGS.datadir,
        train=False,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ])),
                                              batch_size=FLAGS.batch_size,
                                              shuffle=True,
                                              num_workers=FLAGS.num_workers)

    model = MNIST()

    # Trace the model.
    # One XLA device per core; dummy input/target tensors fix the traced
    # shapes (batch, 1, 28, 28) and int64 class labels.
    devices = [':{}'.format(n) for n in range(0, FLAGS.num_cores)]
    inputs = torch.zeros(FLAGS.batch_size, 1, 28, 28)
    target = torch.zeros(FLAGS.batch_size, dtype=torch.int64)
    xla_model = xm.XlaModel(model, [inputs],
                            loss_fn=F.nll_loss,
                            target=target,
                            num_cores=FLAGS.num_cores,
                            devices=devices)
    optimizer = optim.SGD(xla_model.parameters_list(),
                          lr=lr,
                          momentum=momentum)

    for epoch in range(1, FLAGS.num_epochs + 1):
        xla_model.train(train_loader,
                        optimizer,
                        FLAGS.batch_size,
                        log_interval=log_interval)
        if FLAGS.metrics_debug:
            # Dump XLA compilation/execution metrics for debugging.
            print(torch_xla._XLAC._xla_metrics_report())
        accuracy = xla_model.test(test_loader, xm.category_eval_fn(F.nll_loss),
                                  FLAGS.batch_size)
    return accuracy
Example #6
0
 def test(self):
     """Fits AxPlusB to random points of y = 3.11*x + 4.09 with manual SGD
     steps through the XLA wrapper; the final loss should be ~0."""
     slope = 3.11
     intercept = 4.09
     model = AxPlusB(dims=(1, 1))
     xla_model = xm.XlaModel(model, [_gen_tensor(1, 1)])
     optimizer = optim.SGD(xla_model.parameters_list(),
                           lr=0.1,
                           momentum=0.5)
     square_loss = SquareLoss()
     loss = None
     for _step in range(100):
         optimizer.zero_grad()
         x = _gen_tensor(1, 1)
         target = x * slope + intercept
         y = xla_model(x)
         loss = square_loss(y[0], target)
         # Host-side backward populates y[0].grad, then the XLA model
         # propagates it into the traced graph's parameters.
         loss.backward()
         xla_model.backward(y)
         optimizer.step()
     self.assertEqualRel(loss.sum(), torch.tensor(0.0))
Example #7
0
 def test(self):
     """Trains AxPlusB on 100 generated (x, 3.11*x + 4.09) batches with
     manual SGD steps through the XLA wrapper; the final loss should be ~0."""
     slope = 3.11
     intercept = 4.09
     batch_size = 128
     data_gen = FnDataGenerator(lambda v: v * slope + intercept, batch_size, count=100)
     model = AxPlusB(dims=(batch_size, 1))
     xla_model = xm.XlaModel(model, [torch.randn(batch_size, 1)])
     optimizer = optim.SGD(xla_model.parameters_list(),
                           lr=0.1,
                           momentum=0.5)
     square_loss = SquareLoss()
     loss = None
     for x, target in data_gen:
         optimizer.zero_grad()
         y = xla_model(x)
         loss = square_loss(y[0], target)
         # Host-side backward populates y[0].grad, then the XLA model
         # propagates it into the traced graph's parameters.
         loss.backward()
         xla_model.backward(y)
         optimizer.step()
     self.assertEqualRel(loss.sum(), torch.tensor(0.0))
Example #8
0
 def compareModel(self, model, input, rel_err=0.05, abs_err=1e-4):
     """Runs `model` both natively and through XLA on `input`, then asserts
     that the forward outputs and all parameter gradients agree within the
     given relative/absolute tolerances."""
     xla_model = xm.XlaModel(model, [input], full_conv_precision=True)
     output_xla = xla_model(input)
     output = model(input)
     # Forward outputs must match between the native and XLA runs.
     self.assertEqualRel(output,
                         xm.convert_to_tensors(output_xla)[0],
                         rel_err=rel_err,
                         abs_err=abs_err)
     grad_output = _gen_tensor(*output.shape)  # random gradients
     grad_output.grad = grad_output.data
     # Backprop the same upstream gradient through both models.
     output.backward(grad_output)
     xla_model.backward([grad_output])
     # Parameter gradients of replica 0 from the XLA run, converted back to
     # plain tensors...
     xla_updated_params = [
         p.grad.to_tensor() for p in xla_model.parameters()[0]
     ]
     # ...compared pairwise against the native model's parameter gradients.
     updated_params = [p.grad for p in model.parameters()]
     self.assertEqual(len(xla_updated_params), len(updated_params))
     for i in range(0, len(updated_params)):
         self.assertEqualRel(xla_updated_params[i],
                             updated_params[i],
                             rel_err=rel_err,
                             abs_err=abs_err)
Example #9
0
def train_cifar():
    """Train a ResNet18 on CIFAR-10 (or synthetic zero data when
    FLAGS.fake_data is set) across FLAGS.num_cores XLA cores and return the
    accuracy from the final epoch's evaluation."""
    print('==> Preparing data..')

    if FLAGS.fake_data:
        # Synthetic zero tensors sized like CIFAR-10, for benchmarking
        # without downloading the dataset.
        train_loader = xu.SampleGenerator(
            data=torch.zeros(FLAGS.batch_size, 3, 32, 32),
            target=torch.zeros(FLAGS.batch_size, dtype=torch.int64),
            sample_count=50000 // FLAGS.batch_size)
        test_loader = xu.SampleGenerator(
            data=torch.zeros(FLAGS.batch_size, 3, 32, 32),
            target=torch.zeros(FLAGS.batch_size, dtype=torch.int64),
            sample_count=10000 // FLAGS.batch_size)
    else:
        # Training-time augmentation: random crop with padding + flip.
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])

        # Deterministic eval transform (normalize only).
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])

        trainset = torchvision.datasets.CIFAR10(root=FLAGS.datadir,
                                                train=True,
                                                download=True,
                                                transform=transform_train)
        train_loader = torch.utils.data.DataLoader(
            trainset,
            batch_size=FLAGS.batch_size,
            shuffle=True,
            num_workers=FLAGS.num_workers)

        testset = torchvision.datasets.CIFAR10(root=FLAGS.datadir,
                                               train=False,
                                               download=True,
                                               transform=transform_test)
        test_loader = torch.utils.data.DataLoader(
            testset,
            batch_size=FLAGS.batch_size,
            shuffle=False,
            num_workers=FLAGS.num_workers)

    torch.manual_seed(42)

    print('==> Building model..')
    momentum = 0.9
    lr = 0.1
    # Log roughly every 10 steps' worth of work, scaled by the replica count.
    log_interval = max(1, int(10 / FLAGS.num_cores))

    model = ResNet18()

    # One XLA device per core; dummy input/target tensors fix the traced
    # shapes (batch, 3, 32, 32) and int64 class labels.
    devices = [':{}'.format(n) for n in range(0, FLAGS.num_cores)]
    inputs = torch.zeros(FLAGS.batch_size, 3, 32, 32)
    target = torch.zeros(FLAGS.batch_size, dtype=torch.int64)
    xla_model = xm.XlaModel(model, [inputs],
                            loss_fn=F.nll_loss,
                            target=target,
                            num_cores=FLAGS.num_cores,
                            devices=devices)
    optimizer = optim.SGD(xla_model.parameters_list(),
                          lr=lr,
                          momentum=momentum,
                          weight_decay=5e-4)

    log_fn = xm.get_log_fn(logdir=FLAGS.logdir)
    for epoch in range(1, FLAGS.num_epochs + 1):
        xla_model.train(train_loader,
                        optimizer,
                        FLAGS.batch_size,
                        log_interval=log_interval,
                        metrics_debug=FLAGS.metrics_debug,
                        log_fn=log_fn)
        accuracy = xla_model.test(test_loader,
                                  xm.category_eval_fn(F.nll_loss),
                                  FLAGS.batch_size,
                                  log_fn=log_fn)
        # Gentle per-epoch learning-rate decay.
        xm.update_optimizer_state(optimizer, 'lr', lambda x: x / 1.025)
    return accuracy
Example #10
0
def train_mnist():
    """Train the MNIST model on a single XLA core with a hand-written
    train/eval loop (no XlaModel.train/test helpers) and return the final
    epoch's test accuracy."""
    assert FLAGS.num_cores == 1
    torch.manual_seed(1)
    # Training settings
    lr = 0.01
    momentum = 0.5
    log_interval = 5

    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST(FLAGS.datadir, train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=FLAGS.batch_size, shuffle=True, num_workers=FLAGS.num_workers)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST(FLAGS.datadir, train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])),
        batch_size=FLAGS.batch_size, shuffle=True, num_workers=FLAGS.num_workers)

    model = MNIST()

    # Trace the model with a dummy batch to fix the input shape.
    inputs = torch.zeros(FLAGS.batch_size, 1, 28, 28)
    xla_model = xm.XlaModel(model, [inputs])
    optimizer = optim.SGD(xla_model.parameters_list(), lr=lr, momentum=momentum)
    loss_fn = nn.NLLLoss()
    accuracy = None
    for epoch in range(1, FLAGS.num_epochs + 1):
        # Training loop for epoch.
        start_time = time.time()
        processed = 0
        for batch_idx, (data, target) in enumerate(train_loader):
            # The traced graph has a fixed batch size: skip the trailing
            # partial batch.
            if data.size()[0] != FLAGS.batch_size:
                break
            optimizer.zero_grad()
            y = xla_model(data)
            # Enable grads on the XLA output so loss.backward() populates
            # y[0].grad, which xla_model.backward(y) then propagates into
            # the traced graph's parameters.
            y[0].requires_grad = True
            loss = loss_fn(y[0], target)
            loss.backward()
            xla_model.backward(y)
            optimizer.step()
            processed += FLAGS.batch_size
            if batch_idx % log_interval == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\t'
                      'Loss: {:.6f}\tSamples/sec: {:.1f}'.format(
                          epoch, processed,
                          len(train_loader) * FLAGS.batch_size,
                          100. * batch_idx / len(train_loader), loss,
                          processed / (time.time() - start_time)))

        # Eval loop for epoch.
        start_time = time.time()
        correct_count = 0
        test_loss = 0
        count = 0
        for batch_idx, (data, target) in enumerate(test_loader):
            if data.size()[0] != FLAGS.batch_size:
                break
            y = xla_model(data)
            # NOTE(review): NLLLoss reduces to a batch mean by default, so
            # summing per-batch means and later dividing by the total sample
            # count scales the reported loss by 1/batch_size — confirm this
            # is intended (the classic example uses reduction='sum' here).
            test_loss += loss_fn(y[0], target).sum().item()
            # Predicted class = argmax over log-probabilities.
            pred = y[0].max(1, keepdim=True)[1]
            correct_count += pred.eq(target.view_as(pred)).sum().item()
            count += FLAGS.batch_size

        test_loss /= count
        accuracy = 100.0 * correct_count / count
        print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%), '
              'Samples/sec: {:.1f}\n'.format(test_loss, correct_count, count,
                                             accuracy,
                                             count / (time.time() - start_time)))
        # Debug metric dumping.
        if FLAGS.metrics_debug:
            print(torch_xla._XLAC._xla_metrics_report())

    return accuracy