Example No. 1
    def test(self):
        devices = xm.get_xla_supported_devices()
        batch_size = xu.getenv_as('BATCH_SIZE', int, defval=4)
        sample_count = xu.getenv_as('SAMPLE_COUNT', int, defval=10)
        train_loader = xu.SampleGenerator(
            data=(torch.zeros(batch_size, 3, 224, 224),
                  torch.zeros(batch_size, dtype=torch.int64)),
            sample_count=sample_count * len(devices))

        def loop_fn(model, loader, device, context):
            loss_fn = nn.NLLLoss()
            optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

            for x, (data, target) in loader:
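                # The per-device loader yields (step, (data, target)) tuples,
                # hence the two-level unpacking in the loop header above.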
                with xu.TimedScope(msg='Training loop: ', printfn=None):
                    optimizer.zero_grad()
                    output = xu.timed(lambda: model(data),
                                      msg='Model: ',
                                      printfn=None)
                    loss = xu.timed(lambda: loss_fn(output, target),
                                    msg='Loss: ',
                                    printfn=None)
                    xu.timed(loss.backward, msg='LossBkw: ', printfn=None)
                    xu.timed(lambda: xm.optimizer_step(optimizer),
                             msg='Step: ',
                             printfn=None)
                    self.assertLess(loss.cpu().item(), 3.0)

        model_parallel = dp.DataParallel(torchvision.models.resnet18,
                                         device_ids=devices)
        model_parallel(loop_fn, train_loader)
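
These listings omit their import prologue. A minimal sketch of the aliases the calls above and below rely on is given here; the exact module paths are assumptions and have moved between torch_xla releases (very early snapshots shipped them under a torch_xla_py package), so adjust them to the version you are targeting. Names such as MNIST, ResNet18, get_model_property, test_utils and schedulers are local helpers defined by the example scripts themselves, not library APIs.

import os
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision import datasets

import torch_xla
import torch_xla.core.xla_model as xm             # assumed home of the xm.* helpers
import torch_xla.distributed.data_parallel as dp  # assumed home of dp.DataParallel / dp.ParallelLoader
import torch_xla.utils.utils as xu                # assumed home of xu.SampleGenerator / xu.getenv_as
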
Example No. 2
def train_imagenet():
  print('==> Preparing data..')
  img_dim = get_model_property('img_dim')
  if FLAGS.fake_data:
    train_loader = xu.SampleGenerator(
        data=(torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim),
              torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
        sample_count=1200000 // FLAGS.batch_size)
    test_loader = xu.SampleGenerator(
        data=(torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim),
              torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
        sample_count=50000 // FLAGS.batch_size)
  else:
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    train_dataset = torchvision.datasets.ImageFolder(
        os.path.join(FLAGS.datadir, 'train'),
        transforms.Compose([
            transforms.RandomResizedCrop(img_dim),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=FLAGS.batch_size,
        shuffle=True,
        num_workers=FLAGS.num_workers)
    test_dataset = torchvision.datasets.ImageFolder(
        os.path.join(FLAGS.datadir, 'val'),
        transforms.Compose([
            transforms.RandomResizedCrop(img_dim),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=FLAGS.batch_size,
        shuffle=True,
        num_workers=FLAGS.num_workers)

  torch.manual_seed(42)

  devices = xm.get_xla_supported_devices(max_devices=FLAGS.num_cores)
  # Pass [] as device_ids to run using the PyTorch/CPU engine.
  torchvision_model = get_model_property('model_fn')
  model_parallel = dp.DataParallel(torchvision_model, device_ids=devices)

  def train_loop_fn(model, loader, device, context):
    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.SGD(
        model.parameters(),
        lr=FLAGS.lr,
        momentum=FLAGS.momentum,
        weight_decay=5e-4)
    tracker = xm.RateTracker()
    for x, (data, target) in loader:
      optimizer.zero_grad()
      output = model(data)
      loss = loss_fn(output, target)
      loss.backward()
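      # xm.optimizer_step() reduces gradients across replicas (when more than
      # one device is in use) before applying the optimizer update.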
      xm.optimizer_step(optimizer)
      tracker.add(FLAGS.batch_size)
      if x % FLAGS.log_steps == 0:
        print('[{}]({}) Loss={:.5f} Rate={:.2f}'.format(device, x, loss.item(),
                                                        tracker.rate()))

  def test_loop_fn(model, loader, device, context):
    total_samples = 0
    correct = 0
    for x, (data, target) in loader:
      output = model(data)
      pred = output.max(1, keepdim=True)[1]
      correct += pred.eq(target.view_as(pred)).sum().item()
      total_samples += data.size()[0]

    print('[{}] Accuracy={:.2f}%'.format(device,
                                         100.0 * correct / total_samples))
    return correct / total_samples

  accuracy = 0.0
  for epoch in range(1, FLAGS.num_epochs + 1):
    model_parallel(train_loop_fn, train_loader)
    accuracies = model_parallel(test_loop_fn, test_loader)
    accuracy = sum(accuracies) / len(devices)
    if FLAGS.metrics_debug:
      print(torch_xla._XLAC._xla_metrics_report())

  return accuracy * 100.0
Example No. 3
def train_mnist():
  torch.manual_seed(1)

  if FLAGS.fake_data:
    train_loader = xu.SampleGenerator(
        data=(torch.zeros(FLAGS.batch_size, 1, 28, 28),
              torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
        sample_count=60000 // FLAGS.batch_size)
    test_loader = xu.SampleGenerator(
        data=(torch.zeros(FLAGS.batch_size, 1, 28, 28),
              torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
        sample_count=10000 // FLAGS.batch_size)
  else:
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST(
            FLAGS.datadir,
            train=True,
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,))
            ])),
        batch_size=FLAGS.batch_size,
        shuffle=True,
        num_workers=FLAGS.num_workers)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST(
            FLAGS.datadir,
            train=False,
            transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,))
            ])),
        batch_size=FLAGS.batch_size,
        shuffle=True,
        num_workers=FLAGS.num_workers)

  devices = (
      xm.get_xla_supported_devices(
          max_devices=FLAGS.num_cores) if FLAGS.num_cores != 0 else [])
  # Scale learning rate to num cores
  lr = FLAGS.lr * max(len(devices), 1)
  # Pass [] as device_ids to run using the PyTorch/CPU engine.
  model_parallel = dp.DataParallel(MNIST, device_ids=devices)

  def train_loop_fn(model, loader, device, context):
    loss_fn = nn.NLLLoss()
    optimizer = context.getattr_or(
        'optimizer',
        lambda: optim.SGD(model.parameters(), lr=lr, momentum=FLAGS.momentum))
    tracker = xm.RateTracker()

    model.train()
    for x, (data, target) in loader:
      optimizer.zero_grad()
      output = model(data)
      loss = loss_fn(output, target)
      loss.backward()
      xm.optimizer_step(optimizer)
      tracker.add(FLAGS.batch_size)
      if x % FLAGS.log_steps == 0:
        print('[{}]({}) Loss={:.5f} Rate={:.2f}'.format(device, x, loss.item(),
                                                        tracker.rate()))

  def test_loop_fn(model, loader, device, context):
    total_samples = 0
    correct = 0
    model.eval()
    for x, (data, target) in loader:
      output = model(data)
      pred = output.max(1, keepdim=True)[1]
      correct += pred.eq(target.view_as(pred)).sum().item()
      total_samples += data.size()[0]

    print('[{}] Accuracy={:.2f}%'.format(device,
                                         100.0 * correct / total_samples))
    return correct / total_samples

  accuracy = 0.0
  for epoch in range(1, FLAGS.num_epochs + 1):
    model_parallel(train_loop_fn, train_loader)
    accuracies = model_parallel(test_loop_fn, test_loader)
    accuracy = sum(accuracies) / len(accuracies)
    if FLAGS.metrics_debug:
      print(torch_xla._XLAC._xla_metrics_report())

  return accuracy * 100.0
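
train_mnist() above reads its settings from a module-level FLAGS namespace. The sketch below is a hypothetical driver: the flag names mirror the attributes the function actually reads, but the argparse wiring and the default values are illustrative, not the original script's.

import argparse

parser = argparse.ArgumentParser(description='Illustrative driver for train_mnist()')
parser.add_argument('--datadir', default='/tmp/mnist-data')
parser.add_argument('--batch_size', type=int, default=128)
parser.add_argument('--num_cores', type=int, default=8)
parser.add_argument('--num_epochs', type=int, default=10)
parser.add_argument('--num_workers', type=int, default=4)
parser.add_argument('--log_steps', type=int, default=20)
parser.add_argument('--lr', type=float, default=0.01)
parser.add_argument('--momentum', type=float, default=0.5)
parser.add_argument('--fake_data', action='store_true')
parser.add_argument('--metrics_debug', action='store_true')
FLAGS = parser.parse_args()

if __name__ == '__main__':
  print('Final accuracy: {:.2f}%'.format(train_mnist()))
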
Example No. 4
def train_imagenet():
    print('==> Preparing data..')
    img_dim = get_model_property('img_dim')
    if FLAGS.fake_data:
        train_dataset_len = 1200000  # Roughly the size of Imagenet dataset.
        train_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim),
                  torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=train_dataset_len // FLAGS.batch_size //
            xm.xrt_world_size())
        test_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.test_set_batch_size, 3, img_dim, img_dim),
                  torch.zeros(FLAGS.test_set_batch_size, dtype=torch.int64)),
            sample_count=50000 // FLAGS.batch_size // xm.xrt_world_size())
    else:
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        train_dataset = torchvision.datasets.ImageFolder(
            os.path.join(FLAGS.datadir, 'train'),
            transforms.Compose([
                transforms.RandomResizedCrop(img_dim),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))
        train_dataset_len = len(train_dataset.imgs)
        resize_dim = max(img_dim, 256)
        test_dataset = torchvision.datasets.ImageFolder(
            os.path.join(FLAGS.datadir, 'val'),
            # Matches Torchvision's eval transforms except Torchvision uses size
            # 256 resize for all models both here and in the train loader. Their
            # version crashes during training on 299x299 images, e.g. inception.
            transforms.Compose([
                transforms.Resize(resize_dim),
                transforms.CenterCrop(img_dim),
                transforms.ToTensor(),
                normalize,
            ]))

        train_sampler = None
        if xm.xrt_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=True)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=FLAGS.batch_size,
            sampler=train_sampler,
            shuffle=False if train_sampler else True,
            num_workers=FLAGS.num_workers)
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=FLAGS.test_set_batch_size,
            shuffle=False,
            num_workers=FLAGS.num_workers)

    torch.manual_seed(42)

    device = xm.xla_device()
    model = get_model_property('model_fn')().to(device)
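    # SummaryWriter is assumed to be imported in the prologue, e.g. from
    # torch.utils.tensorboard (or tensorboardX on older setups).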
    writer = SummaryWriter(log_dir=FLAGS.logdir) if FLAGS.logdir else None
    optimizer = optim.SGD(model.parameters(),
                          lr=FLAGS.lr,
                          momentum=FLAGS.momentum,
                          weight_decay=5e-4)
    num_training_steps_per_epoch = train_dataset_len // (FLAGS.batch_size *
                                                         xm.xrt_world_size())
    lr_scheduler = schedulers.wrap_optimizer_with_scheduler(
        optimizer,
        scheduler_type=getattr(FLAGS, 'lr_scheduler_type', None),
        scheduler_divisor=getattr(FLAGS, 'lr_scheduler_divisor', None),
        scheduler_divide_every_n_epochs=getattr(
            FLAGS, 'lr_scheduler_divide_every_n_epochs', None),
        num_steps_per_epoch=num_training_steps_per_epoch,
        summary_writer=writer if xm.is_master_ordinal() else None)
    loss_fn = nn.CrossEntropyLoss()

    def train_loop_fn(loader):
        tracker = xm.RateTracker()
        model.train()
        for x, (data, target) in loader:
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(FLAGS.batch_size)
            if lr_scheduler:
                lr_scheduler.step()
            if x % FLAGS.log_steps == 0:
                test_utils.print_training_update(device, x, loss.item(),
                                                 tracker.rate(),
                                                 tracker.global_rate())

    def test_loop_fn(loader):
        total_samples = 0
        correct = 0
        model.eval()
        for x, (data, target) in loader:
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
            total_samples += data.size()[0]

        accuracy = 100.0 * correct / total_samples
        test_utils.print_test_update(device, accuracy)
        return accuracy

    accuracy = 0.0
    for epoch in range(1, FLAGS.num_epochs + 1):
        para_loader = dp.ParallelLoader(train_loader, [device])
        train_loop_fn(para_loader.per_device_loader(device))

        para_loader = dp.ParallelLoader(test_loader, [device])
        accuracy = test_loop_fn(para_loader.per_device_loader(device))
        print('Epoch: {}, Mean Accuracy: {:.2f}%'.format(epoch, accuracy))
        test_utils.add_scalar_to_summary(writer, 'Accuracy/test', accuracy,
                                         epoch)

        if FLAGS.metrics_debug:
            print(torch_xla._XLAC._xla_metrics_report())

    return accuracy
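
Unlike the DataParallel-based listings, this example drives a single xm.xla_device() per process and partitions data with xm.xrt_world_size() / xm.get_ordinal(), so it is meant to run once in every replica process. A minimal launcher sketch, assuming torch_xla's multiprocessing helper is available as xla_multiprocessing (check the API of the release you actually use):

import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index):
    # Each spawned process sees its own ordinal and XLA device.
    accuracy = train_imagenet()
    print('[{}] final accuracy: {:.2f}%'.format(index, accuracy))

if __name__ == '__main__':
    xmp.spawn(_mp_fn, args=())
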
Example No. 5
def train_cifar():
    print('==> Preparing data..')

    if FLAGS.fake_data:
        train_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 3, 32, 32),
                  torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=50000 // FLAGS.batch_size // xm.xrt_world_size())
        test_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 3, 32, 32),
                  torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=10000 // FLAGS.batch_size // xm.xrt_world_size())
    else:
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])

        train_dataset = torchvision.datasets.CIFAR10(root=FLAGS.datadir,
                                                     train=True,
                                                     download=True,
                                                     transform=transform_train)
        test_dataset = torchvision.datasets.CIFAR10(root=FLAGS.datadir,
                                                    train=False,
                                                    download=True,
                                                    transform=transform_test)
        train_sampler = None
        test_sampler = None
        if xm.xrt_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=True)
            test_sampler = torch.utils.data.distributed.DistributedSampler(
                test_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=False)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=FLAGS.batch_size,
            sampler=train_sampler,
            shuffle=False if train_sampler else True,
            num_workers=FLAGS.num_workers)
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=FLAGS.batch_size,
            sampler=test_sampler,
            shuffle=False,
            num_workers=FLAGS.num_workers)

    torch.manual_seed(42)

    devices = (xm.get_xla_supported_devices(
        max_devices=FLAGS.num_cores) if FLAGS.num_cores != 0 else [])
    # Pass [] as device_ids to run using the PyTorch/CPU engine.
    model = torchvision.models.resnet18 if FLAGS.use_torchvision else ResNet18
    model_parallel = dp.DataParallel(model, device_ids=devices)

    def train_loop_fn(model, loader, device, context):
        loss_fn = nn.CrossEntropyLoss()
        optimizer = context.getattr_or(
            'optimizer', lambda: optim.SGD(model.parameters(),
                                           lr=FLAGS.lr,
                                           momentum=FLAGS.momentum,
                                           weight_decay=5e-4))
        tracker = xm.RateTracker()

        model.train()
        for x, (data, target) in loader:
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(FLAGS.batch_size)
            if x % FLAGS.log_steps == 0:
                test_utils.print_training_update(device, x, loss.item(),
                                                 tracker.rate(),
                                                 tracker.global_rate())

    def test_loop_fn(model, loader, device, context):
        total_samples = 0
        correct = 0
        model.eval()
        for x, (data, target) in loader:
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
            total_samples += data.size()[0]

        accuracy = 100.0 * correct / total_samples
        test_utils.print_test_update(device, accuracy)
        return accuracy

    accuracy = 0.0
    for epoch in range(1, FLAGS.num_epochs + 1):
        model_parallel(train_loop_fn, train_loader)
        accuracies = model_parallel(test_loop_fn, test_loader)
        accuracy = sum(accuracies) / len(accuracies)
        print("Epoch: {}, Mean Accuracy: {:.2f}%".format(epoch, accuracy))
        if FLAGS.metrics_debug:
            print(torch_xla._XLAC._xla_metrics_report())

    return accuracy
Example No. 6
def train_cifar():
    print('==> Preparing data..')

    if FLAGS.fake_data:
        train_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 3, 32, 32),
                  torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=50000 // FLAGS.batch_size)
        test_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 3, 32, 32),
                  torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=10000 // FLAGS.batch_size)
    else:
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])

        trainset = torchvision.datasets.CIFAR10(root=FLAGS.datadir,
                                                train=True,
                                                download=True,
                                                transform=transform_train)
        train_loader = torch.utils.data.DataLoader(
            trainset,
            batch_size=FLAGS.batch_size,
            shuffle=True,
            num_workers=FLAGS.num_workers)

        testset = torchvision.datasets.CIFAR10(root=FLAGS.datadir,
                                               train=False,
                                               download=True,
                                               transform=transform_test)
        test_loader = torch.utils.data.DataLoader(
            testset,
            batch_size=FLAGS.batch_size,
            shuffle=False,
            num_workers=FLAGS.num_workers)

    torch.manual_seed(42)

    devices = xm.get_xla_supported_devices(max_devices=FLAGS.num_cores)
    # Pass [] as device_ids to run using the PyTorch/CPU engine.
    model_parallel = dp.DataParallel(ResNet18, device_ids=devices)

    def train_loop_fn(model, loader, device, context):
        loss_fn = nn.CrossEntropyLoss()
        optimizer = optim.SGD(model.parameters(),
                              lr=FLAGS.lr,
                              momentum=FLAGS.momentum,
                              weight_decay=5e-4)
        tracker = xm.RateTracker()

        for x, (data, target) in loader:
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(FLAGS.batch_size)
            if x % FLAGS.log_steps == 0:
                print('[{}]({}) Loss={:.5f} Rate={:.2f}'.format(
                    device, x, loss.item(), tracker.rate()))

    def test_loop_fn(model, loader, device, context):
        total_samples = 0
        correct = 0
        for x, (data, target) in loader:
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
            total_samples += data.size()[0]

        print('[{}] Accuracy={:.2f}%'.format(device,
                                             100.0 * correct / total_samples))
        return correct / total_samples

    accuracy = 0.0
    for epoch in range(1, FLAGS.num_epochs + 1):
        model_parallel(train_loop_fn, train_loader)
        accuracies = model_parallel(test_loop_fn, test_loader)
        accuracy = sum(accuracies) / len(devices)
        if FLAGS.metrics_debug:
            print(torch_xla._XLAC._xla_metrics_report())

    return accuracy * 100.0
Example No. 7
def train_imagenet():
    print('==> Preparing data..')
    if FLAGS.fake_data:
        train_loader = xu.SampleGenerator(
            data=torch.zeros(FLAGS.batch_size, 3, 224, 224),
            target=torch.zeros(FLAGS.batch_size, dtype=torch.int64),
            sample_count=1200000 // FLAGS.batch_size)
        test_loader = xu.SampleGenerator(
            data=torch.zeros(FLAGS.batch_size, 3, 224, 224),
            target=torch.zeros(FLAGS.batch_size, dtype=torch.int64),
            sample_count=50000 // FLAGS.batch_size)
    else:
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        train_dataset = torchvision.datasets.ImageFolder(
            os.path.join(FLAGS.datadir, 'train'),
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=FLAGS.batch_size,
            shuffle=True,
            num_workers=FLAGS.num_workers)
        test_dataset = torchvision.datasets.ImageFolder(
            os.path.join(FLAGS.datadir, 'val'),
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=FLAGS.batch_size,
            shuffle=True,
            num_workers=FLAGS.num_workers)

    torch.manual_seed(42)

    print('==> Building model..')
    momentum = 0.9
    lr = 0.1
    log_interval = max(1, int(10 / FLAGS.num_cores))

    model = torchvision.models.resnet50()
    cross_entropy_loss = nn.CrossEntropyLoss()

    devices = [':{}'.format(n) for n in range(0, FLAGS.num_cores)]
    inputs = torch.zeros(FLAGS.batch_size, 3, 224, 224)
    target = torch.zeros(FLAGS.batch_size, dtype=torch.int64)
    xla_model = xm.XlaModel(model, [inputs],
                            loss_fn=cross_entropy_loss,
                            target=target,
                            num_cores=FLAGS.num_cores,
                            devices=devices)
    optimizer = optim.SGD(xla_model.parameters_list(),
                          lr=lr,
                          momentum=momentum,
                          weight_decay=5e-4)

    log_fn = xm.get_log_fn(logdir=FLAGS.logdir)
    for epoch in range(1, FLAGS.num_epochs + 1):
        xla_model.train(train_loader,
                        optimizer,
                        FLAGS.batch_size,
                        log_interval=log_interval,
                        metrics_debug=FLAGS.metrics_debug,
                        log_fn=log_fn)
        accuracy = xla_model.test(
            test_loader,
            _cross_entropy_loss_eval_fn(cross_entropy_loss),
            FLAGS.batch_size,
            log_fn=log_fn)
        xm.update_optimizer_state(optimizer, 'lr', lambda x: x / 1.025)
    return accuracy
Example No. 8
def train_mnist():
  torch.manual_seed(1)

  if FLAGS.fake_data:
    train_loader = xu.SampleGenerator(
        data=(torch.zeros(FLAGS.batch_size, 1, 28, 28),
              torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
        sample_count=60000 // FLAGS.batch_size // xm.xrt_world_size())
    test_loader = xu.SampleGenerator(
        data=(torch.zeros(FLAGS.batch_size, 1, 28, 28),
              torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
        sample_count=10000 // FLAGS.batch_size // xm.xrt_world_size())
  else:
    train_dataset = datasets.MNIST(
        FLAGS.datadir,
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.1307,), (0.3081,))]))
    test_dataset = datasets.MNIST(
        FLAGS.datadir,
        train=False,
        transform=transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.1307,), (0.3081,))]))
    train_sampler = None
    if xm.xrt_world_size() > 1:
      train_sampler = torch.utils.data.distributed.DistributedSampler(
          train_dataset,
          num_replicas=xm.xrt_world_size(),
          rank=xm.get_ordinal(),
          shuffle=True)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=FLAGS.batch_size,
        sampler=train_sampler,
        shuffle=False if train_sampler else True,
        num_workers=FLAGS.num_workers)
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=FLAGS.batch_size,
        shuffle=False,
        num_workers=FLAGS.num_workers)

  # Scale learning rate to num cores
  lr = FLAGS.lr * xm.xrt_world_size()

  device = xm.xla_device()
  model = MNIST().to(device)
  optimizer = optim.SGD(model.parameters(), lr=lr, momentum=FLAGS.momentum)
  loss_fn = nn.NLLLoss()

  def train_loop_fn(loader):
    tracker = xm.RateTracker()

    model.train()
    for x, (data, target) in loader:
      optimizer.zero_grad()
      output = model(data)
      loss = loss_fn(output, target)
      loss.backward()
      xm.optimizer_step(optimizer)
      tracker.add(FLAGS.batch_size)
      if x % FLAGS.log_steps == 0:
        test_utils.print_training_update(device, x, loss.item(), tracker.rate(),
                                         tracker.global_rate())

  def test_loop_fn(loader):
    total_samples = 0
    correct = 0
    model.eval()
    for x, (data, target) in loader:
      output = model(data)
      pred = output.max(1, keepdim=True)[1]
      correct += pred.eq(target.view_as(pred)).sum().item()
      total_samples += data.size()[0]

    accuracy = 100.0 * correct / total_samples
    test_utils.print_test_update(device, accuracy)
    return accuracy

  accuracy = 0.0
  for epoch in range(1, FLAGS.num_epochs + 1):
    para_loader = dp.ParallelLoader(train_loader, [device])
    train_loop_fn(para_loader.per_device_loader(device))

    para_loader = dp.ParallelLoader(test_loader, [device])
    accuracy = test_loop_fn(para_loader.per_device_loader(device))
    if FLAGS.metrics_debug:
      print(torch_xla._XLAC._xla_metrics_report())

  return accuracy
Example No. 9
def train_imagenet():
    print('==> Preparing data..')
    img_dim = get_model_property('img_dim')
    if FLAGS.fake_data:
        train_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim),
                  torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=1200000 // FLAGS.batch_size // xm.xrt_world_size())
        test_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.test_set_batch_size, 3, img_dim, img_dim),
                  torch.zeros(FLAGS.test_set_batch_size, dtype=torch.int64)),
            sample_count=50000 // FLAGS.batch_size // xm.xrt_world_size())
    else:
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        train_dataset = torchvision.datasets.ImageFolder(
            os.path.join(FLAGS.datadir, 'train'),
            transforms.Compose([
                transforms.RandomResizedCrop(img_dim),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))
        resize_dim = max(img_dim, 256)
        test_dataset = torchvision.datasets.ImageFolder(
            os.path.join(FLAGS.datadir, 'val'),
            # Matches Torchvision's eval transforms except Torchvision uses size
            # 256 resize for all models both here and in the train loader. Their
            # version crashes during training on 299x299 images, e.g. inception.
            transforms.Compose([
                transforms.Resize(resize_dim),
                transforms.CenterCrop(img_dim),
                transforms.ToTensor(),
                normalize,
            ]))

        train_sampler = None
        test_sampler = None
        if xm.xrt_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=True)
            test_sampler = torch.utils.data.distributed.DistributedSampler(
                test_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=False)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=FLAGS.batch_size,
            sampler=train_sampler,
            shuffle=False if train_sampler else True,
            num_workers=FLAGS.num_workers)
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=FLAGS.test_set_batch_size,
            sampler=test_sampler,
            shuffle=False,
            num_workers=FLAGS.num_workers)

    torch.manual_seed(42)

    devices = (xm.get_xla_supported_devices(
        max_devices=FLAGS.num_cores) if FLAGS.num_cores != 0 else [])
    # Pass [] as device_ids to run using the PyTorch/CPU engine.
    torchvision_model = get_model_property('model_fn')
    model_parallel = dp.DataParallel(torchvision_model, device_ids=devices)

    def train_loop_fn(model, loader, device, context):
        loss_fn = nn.CrossEntropyLoss()
        optimizer = context.getattr_or(
            'optimizer', lambda: optim.SGD(model.parameters(),
                                           lr=FLAGS.lr,
                                           momentum=FLAGS.momentum,
                                           weight_decay=5e-4))
        tracker = xm.RateTracker()
        model.train()
        for x, (data, target) in loader:
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(FLAGS.batch_size)
            if x % FLAGS.log_steps == 0:
                test_utils.print_training_update(device, x, loss.item(),
                                                 tracker.rate(),
                                                 tracker.global_rate())

    def test_loop_fn(model, loader, device, context):
        total_samples = 0
        correct = 0
        model.eval()
        for x, (data, target) in loader:
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
            total_samples += data.size()[0]

        accuracy = 100.0 * correct / total_samples
        test_utils.print_test_update(device, accuracy)
        return accuracy

    accuracy = 0.0
    writer = SummaryWriter(log_dir=FLAGS.logdir) if FLAGS.logdir else None
    for epoch in range(1, FLAGS.num_epochs + 1):
        model_parallel(train_loop_fn, train_loader)
        accuracies = model_parallel(test_loop_fn, test_loader)
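        # 'mean' here assumes something like 'from statistics import mean' in
        # the module prologue.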
        accuracy = mean(accuracies)
        print("Epoch: {}, Mean Accuracy: {:.2f}%".format(epoch, accuracy))
        test_utils.add_scalar_to_summary(writer, 'Accuracy/test', accuracy,
                                         epoch)
        if FLAGS.metrics_debug:
            print(torch_xla._XLAC._xla_metrics_report())

    return accuracy
Example No. 10
def train_cifar():
    print('==> Preparing data..')

    if FLAGS.fake_data:
        train_loader = xu.SampleGenerator(
            data=torch.zeros(FLAGS.batch_size, 3, 32, 32),
            target=torch.zeros(FLAGS.batch_size, dtype=torch.int64),
            sample_count=50000 // FLAGS.batch_size)
        test_loader = xu.SampleGenerator(
            data=torch.zeros(FLAGS.batch_size, 3, 32, 32),
            target=torch.zeros(FLAGS.batch_size, dtype=torch.int64),
            sample_count=10000 // FLAGS.batch_size)
    else:
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])

        trainset = torchvision.datasets.CIFAR10(root=FLAGS.datadir,
                                                train=True,
                                                download=True,
                                                transform=transform_train)
        train_loader = torch.utils.data.DataLoader(
            trainset,
            batch_size=FLAGS.batch_size,
            shuffle=True,
            num_workers=FLAGS.num_workers)

        testset = torchvision.datasets.CIFAR10(root=FLAGS.datadir,
                                               train=False,
                                               download=True,
                                               transform=transform_test)
        test_loader = torch.utils.data.DataLoader(
            testset,
            batch_size=FLAGS.batch_size,
            shuffle=False,
            num_workers=FLAGS.num_workers)

    torch.manual_seed(42)

    print('==> Building model..')
    momentum = 0.9
    lr = 0.1
    log_interval = max(1, int(10 / FLAGS.num_cores))

    model = ResNet18()

    devices = [':{}'.format(n) for n in range(0, FLAGS.num_cores)]
    inputs = torch.zeros(FLAGS.batch_size, 3, 32, 32)
    target = torch.zeros(FLAGS.batch_size, dtype=torch.int64)
    xla_model = xm.XlaModel(model, [inputs],
                            loss_fn=F.nll_loss,
                            target=target,
                            num_cores=FLAGS.num_cores,
                            devices=devices)
    optimizer = optim.SGD(xla_model.parameters_list(),
                          lr=lr,
                          momentum=momentum,
                          weight_decay=5e-4)

    log_fn = xm.get_log_fn(logdir=FLAGS.logdir)
    for epoch in range(1, FLAGS.num_epochs + 1):
        xla_model.train(train_loader,
                        optimizer,
                        FLAGS.batch_size,
                        log_interval=log_interval,
                        metrics_debug=FLAGS.metrics_debug,
                        log_fn=log_fn)
        accuracy = xla_model.test(test_loader,
                                  xm.category_eval_fn(F.nll_loss),
                                  FLAGS.batch_size,
                                  log_fn=log_fn)
        xm.update_optimizer_state(optimizer, 'lr', lambda x: x / 1.025)
    return accuracy
Example No. 11
def train_mnist():
    torch.manual_seed(1)
    # Training settings
    lr = 0.01 * FLAGS.num_cores
    momentum = 0.5
    log_interval = max(1, int(10 / FLAGS.num_cores))

    if FLAGS.fake_data:
        train_loader = xu.SampleGenerator(
            data=torch.zeros(FLAGS.batch_size, 1, 28, 28),
            target=torch.zeros(FLAGS.batch_size, dtype=torch.int64),
            sample_count=60000 // FLAGS.batch_size)
        test_loader = xu.SampleGenerator(
            data=torch.zeros(FLAGS.batch_size, 1, 28, 28),
            target=torch.zeros(FLAGS.batch_size, dtype=torch.int64),
            sample_count=10000 // FLAGS.batch_size)
    else:
        train_loader = torch.utils.data.DataLoader(
            datasets.MNIST(FLAGS.datadir,
                           train=True,
                           download=True,
                           transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize((0.1307, ), (0.3081, ))
                           ])),
            batch_size=FLAGS.batch_size,
            shuffle=True,
            num_workers=FLAGS.num_workers)
        test_loader = torch.utils.data.DataLoader(
            datasets.MNIST(FLAGS.datadir,
                           train=False,
                           transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize((0.1307, ), (0.3081, ))
                           ])),
            batch_size=FLAGS.batch_size,
            shuffle=True,
            num_workers=FLAGS.num_workers)

    model = MNIST()

    # Trace the model.
    devices = [':{}'.format(n) for n in range(0, FLAGS.num_cores)]
    inputs = torch.zeros(FLAGS.batch_size, 1, 28, 28)
    target = torch.zeros(FLAGS.batch_size, dtype=torch.int64)
    xla_model = xm.XlaModel(model, [inputs],
                            loss_fn=F.nll_loss,
                            target=target,
                            num_cores=FLAGS.num_cores,
                            devices=devices)
    optimizer = optim.SGD(xla_model.parameters_list(),
                          lr=lr,
                          momentum=momentum)

    log_fn = xm.get_log_fn(logdir=FLAGS.logdir)
    for epoch in range(1, FLAGS.num_epochs + 1):
        xla_model.train(train_loader,
                        optimizer,
                        FLAGS.batch_size,
                        log_interval=log_interval,
                        metrics_debug=FLAGS.metrics_debug,
                        log_fn=log_fn)
        accuracy = xla_model.test(test_loader,
                                  xm.category_eval_fn(F.nll_loss),
                                  FLAGS.batch_size,
                                  log_fn=log_fn)
    return accuracy
Example No. 12
def train_mnist():
    assert FLAGS.num_cores == 1
    torch.manual_seed(1)
    # Training settings
    lr = 0.01
    momentum = 0.5
    log_interval = 5

    if FLAGS.fake_data:
        train_loader = xu.SampleGenerator(
            data=torch.zeros(FLAGS.batch_size, 1, 28, 28),
            target=torch.zeros(FLAGS.batch_size, dtype=torch.int64),
            sample_count=60000 // FLAGS.batch_size)
        test_loader = xu.SampleGenerator(
            data=torch.zeros(FLAGS.batch_size, 1, 28, 28),
            target=torch.zeros(FLAGS.batch_size, dtype=torch.int64),
            sample_count=10000 // FLAGS.batch_size)
    else:
        train_loader = torch.utils.data.DataLoader(
            datasets.MNIST(FLAGS.datadir,
                           train=True,
                           download=True,
                           transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize((0.1307, ), (0.3081, ))
                           ])),
            batch_size=FLAGS.batch_size,
            shuffle=True,
            num_workers=FLAGS.num_workers)
        test_loader = torch.utils.data.DataLoader(
            datasets.MNIST(FLAGS.datadir,
                           train=False,
                           transform=transforms.Compose([
                               transforms.ToTensor(),
                               transforms.Normalize((0.1307, ), (0.3081, ))
                           ])),
            batch_size=FLAGS.batch_size,
            shuffle=True,
            num_workers=FLAGS.num_workers)

    model = MNIST()

    inputs = torch.zeros(FLAGS.batch_size, 1, 28, 28)
    xla_model = xm.XlaModel(model, [inputs])
    optimizer = optim.SGD(xla_model.parameters_list(),
                          lr=lr,
                          momentum=momentum)
    loss_fn = nn.NLLLoss()
    accuracy = None
    for epoch in range(1, FLAGS.num_epochs + 1):
        # Training loop for epoch.
        start_time = time.time()
        processed = 0
        for batch_idx, (data, target) in enumerate(train_loader):
            if data.size()[0] != FLAGS.batch_size:
                break
            optimizer.zero_grad()
            y = xla_model(data)
            y[0].requires_grad = True
            loss = loss_fn(y[0], target)
            loss.backward()
            xla_model.backward(y)
            optimizer.step()
            processed += FLAGS.batch_size
            if batch_idx % log_interval == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\t'
                      'Loss: {:.6f}\tSamples/sec: {:.1f}'.format(
                          epoch, processed,
                          len(train_loader) * FLAGS.batch_size,
                          100. * batch_idx / len(train_loader), loss,
                          processed / (time.time() - start_time)))

        # Eval loop for epoch.
        start_time = time.time()
        correct_count = 0
        test_loss = 0
        count = 0
        for batch_idx, (data, target) in enumerate(test_loader):
            if data.size()[0] != FLAGS.batch_size:
                break
            y = xla_model(data)
            test_loss += loss_fn(y[0], target).sum().item()
            pred = y[0].max(1, keepdim=True)[1]
            correct_count += pred.eq(target.view_as(pred)).sum().item()
            count += FLAGS.batch_size

        test_loss /= count
        accuracy = 100.0 * correct_count / count
        print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%), '
              'Samples/sec: {:.1f}\n'.format(
                  test_loss, correct_count, count, accuracy,
                  count / (time.time() - start_time)))
        # Debug metric dumping.
        if FLAGS.metrics_debug:
            print(torch_xla._XLAC._xla_metrics_report())

    return accuracy