示例#1
0
    def step(self, epoch=None):
        """Advance the scheduler by one training step.

        The parent scheduler's ``step()`` is only invoked when the learning
        rate reported by ``get_lr()`` moved by at least
        ``self._min_delta_to_update_lr`` since the last applied value;
        otherwise only the step counter is advanced. When a summary writer
        is configured, the optimizer's current learning rate is also logged
        under the ``'LearningRate'`` tag.

        Args:
          epoch: Not read by this implementation.
        """
        new_lr = self.get_lr()[0]
        lr_delta = abs(new_lr - self._previous_lr)
        lr_changed = lr_delta >= self._min_delta_to_update_lr

        # Outside of warmup epochs every step of an epoch shares one learning
        # rate, so the (unchanged) value does not need to be re-applied.
        if not lr_changed:
            # The parent's step() is skipped, so do its counter bookkeeping
            # here instead.
            self._step_count += 1
        else:
            super(WarmupAndExponentialDecayScheduler, self).step()
            self._previous_lr = new_lr

        if not self._summary_writer:
            return

        # Warmup epochs are logged at every step. Other epochs are logged only
        # when the step counter is a multiple of the steps per epoch, because
        # the whole epoch uses a single learning rate.
        should_log = self._is_warmup_epoch()
        if not should_log:
            should_log = self._step_count % self._num_steps_per_epoch == 0
        if should_log:
            applied_lr = self.optimizer.param_groups[0]['lr']
            test_utils.add_scalar_to_summary(self._summary_writer,
                                             'LearningRate', applied_lr,
                                             self._step_count)
示例#2
0
def train_mnist():
    """Train and evaluate the MNIST model through ``dp.DataParallel``.

    Data comes either from synthetic all-zero batches (``FLAGS.fake_data``)
    or from ``datasets.MNIST`` under ``FLAGS.datadir``. The base learning
    rate is multiplied by the number of devices in use (at least 1). After
    every epoch the mean of the per-device test accuracies is printed and
    written to the summary writer under ``'Accuracy/test'``.

    Returns:
      The mean test accuracy of the last epoch, or 0.0 when
      ``FLAGS.num_epochs`` is smaller than 1.
    """
    torch.manual_seed(1)

    if not FLAGS.fake_data:

        def build_transform():
            # A fresh pipeline per dataset: tensor conversion + normalization.
            return transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307, ), (0.3081, )),
            ])

        train_dataset = datasets.MNIST(FLAGS.datadir,
                                       train=True,
                                       download=True,
                                       transform=build_transform())
        train_dataset_len = len(train_dataset)
        test_dataset = datasets.MNIST(FLAGS.datadir,
                                      train=False,
                                      transform=build_transform())

        # Samplers are only needed when more than one replica participates.
        train_sampler = test_sampler = None
        if xm.xrt_world_size() > 1:
            replica_kwargs = dict(num_replicas=xm.xrt_world_size(),
                                  rank=xm.get_ordinal())
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset, shuffle=True, **replica_kwargs)
            test_sampler = torch.utils.data.distributed.DistributedSampler(
                test_dataset, shuffle=False, **replica_kwargs)

        shared_loader_kwargs = dict(batch_size=FLAGS.batch_size,
                                    drop_last=FLAGS.drop_last,
                                    num_workers=FLAGS.num_workers)
        # Shuffling is delegated to the sampler whenever one is in use.
        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   sampler=train_sampler,
                                                   shuffle=not train_sampler,
                                                   **shared_loader_kwargs)
        test_loader = torch.utils.data.DataLoader(test_dataset,
                                                  sampler=test_sampler,
                                                  shuffle=False,
                                                  **shared_loader_kwargs)
    else:

        def fake_batch():
            # One all-zero (images, labels) batch with MNIST-like shapes.
            return (torch.zeros(FLAGS.batch_size, 1, 28, 28),
                    torch.zeros(FLAGS.batch_size, dtype=torch.int64))

        train_dataset_len = 60000  # Number of images in MNIST dataset.
        world_size = xm.xrt_world_size()
        train_loader = xu.SampleGenerator(
            data=fake_batch(),
            sample_count=train_dataset_len // FLAGS.batch_size // world_size)
        test_loader = xu.SampleGenerator(
            data=fake_batch(),
            sample_count=10000 // FLAGS.batch_size // world_size)

    # Pass [] as device_ids to run using the PyTorch/CPU engine.
    if FLAGS.num_cores != 0:
        devices = xm.get_xla_supported_devices(max_devices=FLAGS.num_cores)
    else:
        devices = []
    # Scale learning rate to num cores
    lr = FLAGS.lr * max(len(devices), 1)
    model_parallel = dp.DataParallel(MNIST, device_ids=devices)

    def train_loop_fn(model, loader, device, context):
        """Run one pass over `loader`, updating `model` with SGD."""

        def make_optimizer():
            return optim.SGD(model.parameters(),
                             lr=lr,
                             momentum=FLAGS.momentum)

        loss_fn = nn.NLLLoss()
        # The optimizer is fetched from (or created into) `context`.
        optimizer = context.getattr_or('optimizer', make_optimizer)
        tracker = xm.RateTracker()

        model.train()
        for step, (data, target) in enumerate(loader):
            optimizer.zero_grad()
            loss = loss_fn(model(data), target)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(FLAGS.batch_size)
            if step % FLAGS.log_steps != 0:
                continue
            test_utils.print_training_update(device, step, loss.item(),
                                             tracker.rate(),
                                             tracker.global_rate())

    def test_loop_fn(model, loader, device, context):
        """Evaluate `model` on `loader` and return its accuracy in percent."""
        num_correct = 0
        num_seen = 0
        model.eval()
        for data, target in loader:
            _, pred = model(data).max(1, keepdim=True)
            num_correct += pred.eq(target.view_as(pred)).sum().item()
            num_seen += data.size()[0]

        accuracy = 100.0 * num_correct / num_seen
        test_utils.print_test_update(device, accuracy)
        return accuracy

    accuracy = 0.0
    writer = test_utils.get_summary_writer(FLAGS.logdir)
    if len(devices) > 1:
        num_devices = len(xm.xla_replication_devices(devices))
    else:
        num_devices = 1
    num_training_steps_per_epoch = train_dataset_len // (FLAGS.batch_size *
                                                         num_devices)
    for epochs_done in range(FLAGS.num_epochs):
        epoch = epochs_done + 1
        model_parallel(train_loop_fn, train_loader)
        accuracy = mean(model_parallel(test_loop_fn, test_loader))
        print('Epoch: {}, Mean Accuracy: {:.2f}%'.format(epoch, accuracy))
        # The scalar is recorded at the step count reached before this epoch.
        test_utils.add_scalar_to_summary(
            writer, 'Accuracy/test', accuracy,
            epochs_done * num_training_steps_per_epoch)
        if FLAGS.metrics_debug:
            print(met.metrics_report())

    test_utils.close_summary_writer(writer)
    return accuracy
示例#3
0
def train_imagenet():
    """Train and evaluate an ImageNet model through ``dp.DataParallel``.

    The model constructor and input image size come from
    ``get_model_property``. Data comes either from synthetic all-zero
    batches (``FLAGS.fake_data``) or from ``ImageFolder`` datasets located
    at ``FLAGS.datadir/train`` and ``FLAGS.datadir/val``. After every epoch
    the mean of the per-device test accuracies is printed and written to
    the summary writer under ``'Accuracy/test'``.

    Returns:
      The mean test accuracy of the last epoch, or 0.0 when
      ``FLAGS.num_epochs`` is smaller than 1.
    """
    print('==> Preparing data..')
    img_dim = get_model_property('img_dim')
    if FLAGS.fake_data:
        train_dataset_len = 1200000  # Roughly the size of Imagenet dataset.
        train_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim),
                  torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=train_dataset_len // FLAGS.batch_size //
            xm.xrt_world_size())
        # Each fake test batch holds FLAGS.test_set_batch_size images, so the
        # number of batches has to be derived from that size (not from the
        # training batch size) for the 50000 fake test images to add up.
        test_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.test_set_batch_size, 3, img_dim, img_dim),
                  torch.zeros(FLAGS.test_set_batch_size, dtype=torch.int64)),
            sample_count=50000 // FLAGS.test_set_batch_size //
            xm.xrt_world_size())
    else:
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        train_dataset = torchvision.datasets.ImageFolder(
            os.path.join(FLAGS.datadir, 'train'),
            transforms.Compose([
                transforms.RandomResizedCrop(img_dim),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))
        train_dataset_len = len(train_dataset.imgs)
        resize_dim = max(img_dim, 256)
        test_dataset = torchvision.datasets.ImageFolder(
            os.path.join(FLAGS.datadir, 'val'),
            # Matches Torchvision's eval transforms except Torchvision uses size
            # 256 resize for all models both here and in the train loader. Their
            # version crashes during training on 299x299 images, e.g. inception.
            transforms.Compose([
                transforms.Resize(resize_dim),
                transforms.CenterCrop(img_dim),
                transforms.ToTensor(),
                normalize,
            ]))

        # Samplers are only needed when more than one replica participates.
        train_sampler = None
        test_sampler = None
        if xm.xrt_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=True)
            test_sampler = torch.utils.data.distributed.DistributedSampler(
                test_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=False)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=FLAGS.batch_size,
            sampler=train_sampler,
            drop_last=FLAGS.drop_last,
            # Shuffling is delegated to the sampler whenever one is in use.
            shuffle=False if train_sampler else True,
            num_workers=FLAGS.num_workers)
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=FLAGS.test_set_batch_size,
            sampler=test_sampler,
            drop_last=FLAGS.drop_last,
            shuffle=False,
            num_workers=FLAGS.num_workers)

    torch.manual_seed(42)

    devices = (xm.get_xla_supported_devices(
        max_devices=FLAGS.num_cores) if FLAGS.num_cores != 0 else [])
    # Pass [] as device_ids to run using the PyTorch/CPU engine.
    torchvision_model = get_model_property('model_fn')
    model_parallel = dp.DataParallel(torchvision_model, device_ids=devices)

    def train_loop_fn(model, loader, device, context):
        """Run one pass over `loader`, updating `model` with SGD."""
        loss_fn = nn.CrossEntropyLoss()
        # Optimizer and scheduler are fetched from (or created into) `context`.
        optimizer = context.getattr_or(
            'optimizer', lambda: optim.SGD(model.parameters(),
                                           lr=FLAGS.lr,
                                           momentum=FLAGS.momentum,
                                           weight_decay=1e-4))
        # `num_training_steps_per_epoch` and `writer` are bound further down
        # in train_imagenet(), before this function is first invoked.
        lr_scheduler = context.getattr_or(
            'lr_scheduler', lambda: schedulers.wrap_optimizer_with_scheduler(
                optimizer,
                scheduler_type=getattr(FLAGS, 'lr_scheduler_type', None),
                scheduler_divisor=getattr(FLAGS, 'lr_scheduler_divisor', None),
                scheduler_divide_every_n_epochs=getattr(
                    FLAGS, 'lr_scheduler_divide_every_n_epochs', None),
                num_steps_per_epoch=num_training_steps_per_epoch,
                summary_writer=writer if xm.is_master_ordinal() else None))
        tracker = xm.RateTracker()
        model.train()
        for x, (data, target) in enumerate(loader):
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(FLAGS.batch_size)
            if x % FLAGS.log_steps == 0:
                test_utils.print_training_update(device, x, loss.item(),
                                                 tracker.rate(),
                                                 tracker.global_rate())
            # The scheduler is stepped once per training step when present.
            if lr_scheduler:
                lr_scheduler.step()

    def test_loop_fn(model, loader, device, context):
        """Evaluate `model` on `loader` and return its accuracy in percent."""
        total_samples = 0
        correct = 0
        model.eval()
        for data, target in loader:
            output = model(data)
            # Index of the highest score along dim 1 is the predicted class.
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
            total_samples += data.size()[0]

        accuracy = 100.0 * correct / total_samples
        test_utils.print_test_update(device, accuracy)
        return accuracy

    accuracy = 0.0
    writer = test_utils.get_summary_writer(FLAGS.logdir)
    num_devices = len(
        xm.xla_replication_devices(devices)) if len(devices) > 1 else 1
    num_training_steps_per_epoch = train_dataset_len // (FLAGS.batch_size *
                                                         num_devices)
    for epoch in range(1, FLAGS.num_epochs + 1):
        model_parallel(train_loop_fn, train_loader)
        accuracies = model_parallel(test_loop_fn, test_loader)
        accuracy = mean(accuracies)
        print('Epoch: {}, Mean Accuracy: {:.2f}%'.format(epoch, accuracy))
        # The scalar is recorded at the step count reached before this epoch.
        global_step = (epoch - 1) * num_training_steps_per_epoch
        test_utils.add_scalar_to_summary(writer, 'Accuracy/test', accuracy,
                                         global_step)
        if FLAGS.metrics_debug:
            print(met.metrics_report())

    test_utils.close_summary_writer(writer)
    return accuracy
示例#4
0
def train_mnist():
    """Train and evaluate the MNIST model on this process's XLA device.

    Data comes either from synthetic all-zero batches (``FLAGS.fake_data``)
    or from ``datasets.MNIST`` stored in a per-ordinal sub-directory of
    ``FLAGS.datadir``. The base learning rate is multiplied by
    ``xm.xrt_world_size()``. The test loader is built without a sampler.
    The summary writer is only created on the master ordinal; other
    ordinals pass ``None`` to the ``test_utils`` summary helpers.

    Returns:
      The test accuracy measured after the last epoch, or 0.0 when
      ``FLAGS.num_epochs`` is smaller than 1.
    """
    torch.manual_seed(1)

    if not FLAGS.fake_data:

        def build_transform():
            # A fresh pipeline per dataset: tensor conversion + normalization.
            return transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307, ), (0.3081, )),
            ])

        train_root = os.path.join(FLAGS.datadir, str(xm.get_ordinal()))
        train_dataset = datasets.MNIST(train_root,
                                       train=True,
                                       download=True,
                                       transform=build_transform())
        test_root = os.path.join(FLAGS.datadir, str(xm.get_ordinal()))
        test_dataset = datasets.MNIST(test_root,
                                      train=False,
                                      download=True,
                                      transform=build_transform())

        # Only the training set gets a sampler, and only with >1 replica.
        train_sampler = None
        if xm.xrt_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=True)

        shared_loader_kwargs = dict(batch_size=FLAGS.batch_size,
                                    drop_last=FLAGS.drop_last,
                                    num_workers=FLAGS.num_workers)
        # Shuffling is delegated to the sampler whenever one is in use.
        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   sampler=train_sampler,
                                                   shuffle=not train_sampler,
                                                   **shared_loader_kwargs)
        test_loader = torch.utils.data.DataLoader(test_dataset,
                                                  shuffle=False,
                                                  **shared_loader_kwargs)
    else:

        def fake_batch():
            # One all-zero (images, labels) batch with MNIST-like shapes.
            return (torch.zeros(FLAGS.batch_size, 1, 28, 28),
                    torch.zeros(FLAGS.batch_size, dtype=torch.int64))

        world_size = xm.xrt_world_size()
        train_loader = xu.SampleGenerator(
            data=fake_batch(),
            sample_count=60000 // FLAGS.batch_size // world_size)
        test_loader = xu.SampleGenerator(
            data=fake_batch(),
            sample_count=10000 // FLAGS.batch_size // world_size)

    # Scale learning rate to num cores
    lr = FLAGS.lr * xm.xrt_world_size()

    device = xm.xla_device()
    model = MNIST().to(device)
    if xm.is_master_ordinal():
        writer = test_utils.get_summary_writer(FLAGS.logdir)
    else:
        writer = None
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=FLAGS.momentum)
    loss_fn = nn.NLLLoss()

    def train_loop_fn(loader):
        """Run one training pass over `loader`."""
        tracker = xm.RateTracker()

        model.train()
        for step, (data, target) in enumerate(loader):
            optimizer.zero_grad()
            loss = loss_fn(model(data), target)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(FLAGS.batch_size)
            if step % FLAGS.log_steps != 0:
                continue
            # Progress reporting is handed to a step closure, which receives
            # the loss object itself.
            xm.add_step_closure(_train_update,
                                args=(device, step, loss, tracker))

    def test_loop_fn(loader):
        """Evaluate the model on `loader`; return accuracy in percent."""
        num_correct = 0
        num_seen = 0
        model.eval()
        for data, target in loader:
            _, pred = model(data).max(1, keepdim=True)
            num_correct += pred.eq(target.view_as(pred)).sum().item()
            num_seen += data.size()[0]

        accuracy = 100.0 * num_correct / num_seen
        test_utils.print_test_update(device, accuracy)
        return accuracy

    accuracy = 0.0
    for epochs_done in range(FLAGS.num_epochs):
        epoch = epochs_done + 1

        # A new ParallelLoader is wrapped around each loader every epoch.
        train_device_loader = pl.ParallelLoader(train_loader, [device])
        train_loop_fn(train_device_loader.per_device_loader(device))
        xm.master_print('Finished training epoch {}'.format(epoch))

        test_device_loader = pl.ParallelLoader(test_loader, [device])
        accuracy = test_loop_fn(test_device_loader.per_device_loader(device))
        test_utils.add_scalar_to_summary(writer, 'Accuracy/test', accuracy,
                                         epoch)
        if FLAGS.metrics_debug:
            print(met.metrics_report())

    test_utils.close_summary_writer(writer)
    return accuracy
示例#5
0
def train_cifar():
  """Train and evaluate a ResNet18 on CIFAR10 through ``dp.DataParallel``.

  Data comes either from synthetic all-zero batches (``FLAGS.fake_data``)
  or from ``torchvision.datasets.CIFAR10`` under ``FLAGS.datadir``. The
  model is ``torchvision.models.resnet18`` when ``FLAGS.use_torchvision``
  is set and ``ResNet18`` otherwise. After every epoch the mean of the
  per-device test accuracies is printed and written to the summary writer
  under ``'Accuracy/test'``.

  Returns:
    The highest per-epoch mean test accuracy observed, or 0.0 when
    ``FLAGS.num_epochs`` is smaller than 1.
  """
  print('==> Preparing data..')

  if FLAGS.fake_data:
    train_dataset_len = 50000  # Number of example in CIFAR train set.
    train_loader = xu.SampleGenerator(
        data=(torch.zeros(FLAGS.batch_size, 3, 32,
                          32), torch.zeros(FLAGS.batch_size,
                                           dtype=torch.int64)),
        sample_count=train_dataset_len // FLAGS.batch_size
        // xm.xrt_world_size())
    test_loader = xu.SampleGenerator(
        data=(torch.zeros(FLAGS.batch_size, 3, 32,
                          32), torch.zeros(FLAGS.batch_size,
                                           dtype=torch.int64)),
        sample_count=10000 // FLAGS.batch_size // xm.xrt_world_size())
  else:
    # Training images are randomly cropped and flipped; test images are only
    # converted and normalized.
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    train_dataset = torchvision.datasets.CIFAR10(
        root=FLAGS.datadir,
        train=True,
        download=True,
        transform=transform_train)
    train_dataset_len = len(train_dataset)
    test_dataset = torchvision.datasets.CIFAR10(
        root=FLAGS.datadir,
        train=False,
        download=True,
        transform=transform_test)
    # Samplers are only needed when more than one replica participates.
    train_sampler = None
    test_sampler = None
    if xm.xrt_world_size() > 1:
      train_sampler = torch.utils.data.distributed.DistributedSampler(
          train_dataset, num_replicas=xm.xrt_world_size(),
          rank=xm.get_ordinal(), shuffle=True)
      test_sampler = torch.utils.data.distributed.DistributedSampler(
          test_dataset, num_replicas=xm.xrt_world_size(),
          rank=xm.get_ordinal(), shuffle=False)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=FLAGS.batch_size,
        sampler=train_sampler,
        drop_last=FLAGS.drop_last,
        # Shuffling is delegated to the sampler whenever one is in use.
        shuffle=False if train_sampler else True,
        num_workers=FLAGS.num_workers)
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=FLAGS.batch_size,
        sampler=test_sampler,
        drop_last=FLAGS.drop_last,
        shuffle=False,
        num_workers=FLAGS.num_workers)

  torch.manual_seed(42)

  devices = (
      xm.get_xla_supported_devices(
          max_devices=FLAGS.num_cores) if FLAGS.num_cores != 0 else [])
  # Pass [] as device_ids to run using the PyTorch/CPU engine.
  model = torchvision.models.resnet18 if FLAGS.use_torchvision else ResNet18
  model_parallel = dp.DataParallel(model, device_ids=devices)

  def train_loop_fn(model, loader, device, context):
    """Run one pass over `loader`, updating `model` with SGD."""
    loss_fn = nn.CrossEntropyLoss()
    # The optimizer is fetched from (or created into) `context`.
    optimizer = context.getattr_or(
        'optimizer', lambda: optim.SGD(
            model.parameters(),
            lr=FLAGS.lr,
            momentum=FLAGS.momentum,
            weight_decay=5e-4))
    tracker = xm.RateTracker()

    model.train()
    for x, (data, target) in enumerate(loader):
      optimizer.zero_grad()
      output = model(data)
      loss = loss_fn(output, target)
      loss.backward()
      xm.optimizer_step(optimizer)
      tracker.add(FLAGS.batch_size)
      if x % FLAGS.log_steps == 0:
        test_utils.print_training_update(device, x, loss.item(),
                                         tracker.rate(),
                                         tracker.global_rate())

  def test_loop_fn(model, loader, device, context):
    """Evaluate `model` on `loader` and return its accuracy in percent."""
    total_samples = 0
    correct = 0
    model.eval()
    for data, target in loader:
      output = model(data)
      # Index of the highest score along dim 1 is the predicted class.
      pred = output.max(1, keepdim=True)[1]
      correct += pred.eq(target.view_as(pred)).sum().item()
      total_samples += data.size()[0]

    accuracy = 100.0 * correct / total_samples
    test_utils.print_test_update(device, accuracy)
    return accuracy

  accuracy = 0.0
  writer = test_utils.get_summary_writer(FLAGS.logdir)
  num_devices = len(
      xm.xla_replication_devices(devices)) if len(devices) > 1 else 1
  num_training_steps_per_epoch = train_dataset_len // (
      FLAGS.batch_size * num_devices)
  max_accuracy = 0.0
  for epoch in range(1, FLAGS.num_epochs + 1):
    model_parallel(train_loop_fn, train_loader)
    accuracies = model_parallel(test_loop_fn, test_loader)
    accuracy = mean(accuracies)
    max_accuracy = max(accuracy, max_accuracy)
    print('Epoch: {}, Mean Accuracy: {:.2f}%'.format(epoch, accuracy))
    # The scalar is recorded at the step count reached before this epoch.
    global_step = (epoch - 1) * num_training_steps_per_epoch
    test_utils.add_scalar_to_summary(writer, 'Accuracy/test', accuracy,
                                     global_step)
    if FLAGS.metrics_debug:
      print(met.metrics_report())

  test_utils.close_summary_writer(writer)
  # Report the best epoch, which is also the value returned to the caller
  # (the last epoch's accuracy is not necessarily the maximum).
  print('Max Accuracy: {:.2f}%'.format(max_accuracy))
  return max_accuracy