Example #1
def run_benchmark(args, pos_args):
  devices = xm.get_xla_supported_devices(max_devices=args.max_devices)
  shape = [int(x) for x in args.shape.split(',')]

  send_list = []
  for i in range(0, len(devices)):
    mb = []
    for j in range(0, args.prefetch):
      mb.append(torch.randn(*shape))
    send_list.append(mb)

  def threadfn(i):
    device = devices[i]
    xdevices = [device] * len(send_list[i])
    for n in range(0, args.test_count):
      with xu.TimedScope(msg='Send[{}][{}]: '.format(i, n), printfn=print):
        _ = torch_xla._XLAC._xla_tensors_from_aten(send_list[i], xdevices)

  threads = []
  for i in range(0, len(devices)):
    t = threading.Thread(target=threadfn, args=(i,))
    t.start()
    threads.append(t)
  for t in threads:
    t.join()
  print(torch_xla._XLAC._xla_metrics_report())
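
A minimal command-line driver for the benchmark above might look like the following sketch. The flag names mirror the attributes run_benchmark() reads (shape, prefetch, test_count, max_devices); the default values and the entry-point wiring are illustrative assumptions, not part of the original script.

import argparse

def _parse_args():
  # Hypothetical driver: flag names follow the attributes read in
  # run_benchmark() above; the defaults are illustrative assumptions.
  parser = argparse.ArgumentParser(
      description='XLA host-to-device transfer benchmark')
  parser.add_argument('--shape', type=str, default='128,1024')
  parser.add_argument('--prefetch', type=int, default=4)
  parser.add_argument('--test_count', type=int, default=10)
  parser.add_argument('--max_devices', type=int, default=None)
  return parser.parse_known_args()

if __name__ == '__main__':
  args, pos_args = _parse_args()
  run_benchmark(args, pos_args)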
Example #2
  def test(self):
    devices = xm.get_xla_supported_devices()
    batch_size = xu.getenv_as('BATCH_SIZE', int, defval=4)
    sample_count = xu.getenv_as('SAMPLE_COUNT', int, defval=10)
    train_loader = xu.SampleGenerator(
        data=(torch.zeros(batch_size, 3, 224,
                          224), torch.zeros(batch_size, dtype=torch.int64)),
        sample_count=sample_count * len(devices))

    def loop_fn(model, loader, device, context):
      loss_fn = nn.NLLLoss()
      optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

      for data, target in loader:
        with xu.TimedScope(msg='Training loop: ', printfn=None):
          optimizer.zero_grad()
          output = xu.timed(lambda: model(data), msg='Model: ', printfn=None)
          loss = xu.timed(
              lambda: loss_fn(output, target), msg='Loss: ', printfn=None)
          xu.timed(loss.backward, msg='LossBkw: ', printfn=None)
          xu.timed(
              lambda: xm.optimizer_step(optimizer), msg='Step: ', printfn=None)
          self.assertLess(loss.cpu().item(), 3.0)

    model_parallel = dp.DataParallel(
        torchvision.models.resnet18, device_ids=devices)
    model_parallel(loop_fn, train_loader)
Example #3
def inference():

    model = model_from_name(args.model_name)

    devices = (xm.get_xla_supported_devices(
        max_devices=args.num_cores) if args.num_cores != 0 else [])

    model.load_state_dict(torch.load(args.weight_file))

    model_parallel = dp.DataParallel(model, device_ids=devices)

    patient = extract_patient(args.csv_file_path, return_label=False)
    # The number of samples must be a multiple of num_cores * batch_size, so
    # pad the patient list with duplicates from the front.
    extra = args.num_cores * args.batch_size - len(patient) % (
        args.num_cores * args.batch_size)
    patient = np.concatenate([patient, patient[:extra]])

    ds = RSNADataset(patient, path=args.path, transform=train_transforms)
    loader = D.DataLoader(ds,
                          batch_size=args.batch_size,
                          shuffle=False,
                          num_workers=args.num_workers)

    result = model_parallel(infer_loop, loader)
    result = np.array(result)
    score, patientid = [result[:, i] for i in range(result.shape[-1])]
    score = np.concatenate(score)
    patientid = np.concatenate(patientid)

    prediction_to_df(score, patientid, args.subm_file)
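
The infer_loop callable passed to model_parallel is defined elsewhere in the original project and is not shown here. Purely as a hedged sketch, a loop compatible with the (model, loader, device, context) signature used throughout these examples, assuming the loader yields (image, patient_id) batches and a per-sample sigmoid score is wanted, might look like:

import numpy as np
import torch

def infer_loop(model, loader, device, context):
    # Hypothetical per-device inference loop; returns one (scores, ids) pair
    # per device, matching how inference() above unpacks the results.
    model.eval()
    scores, ids = [], []
    with torch.no_grad():
        for data, patient_id in loader:
            output = torch.sigmoid(model(data))
            scores.append(output.cpu().numpy())
            ids.append(np.asarray(patient_id))
    return np.concatenate(scores), np.concatenate(ids)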
Example #4
    def _is_device_tpu() -> bool:
        """
        Check if TPU devices are available

        Return:
            A boolean value indicating if TPU devices are available
        """

        return len(xm.get_xla_supported_devices("TPU")) > 0
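
A small usage sketch for the check above (the enclosing class is omitted here, so the test is repeated inline): fall back to the CPU when no TPU device is visible to this process.

import torch
import torch_xla.core.xla_model as xm

def pick_device() -> torch.device:
    # Prefer an XLA TPU device when one is visible, otherwise use the CPU.
    if len(xm.get_xla_supported_devices("TPU")) > 0:
        return xm.xla_device()
    return torch.device("cpu")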
Example #5
def test__xla_dist_model_run_parallel_n_threads_without_sync():
    # tests issue : https://github.com/pytorch/ignite/issues/1096
    import torch_xla.core.xla_model as xm
    from joblib import delayed, Parallel

    devices = xm.get_xla_supported_devices()
    folds = 1
    d = 0
    if len(devices) > 5:
        folds = 5
        d = 1
    Parallel(n_jobs=folds, backend="threading")(delayed(main_fold)(i + d) for i in range(folds))
Example #6
    def _is_device_tpu() -> bool:
        """Check if TPU devices are available.

        Return:
            A boolean value indicating if TPU devices are available
        """
        # For the TPU Pod training process, for example, if we have
        # TPU v3-32 with 4 VMs, the world size would be 4 and as
        # we would have to use `torch_xla.distributed.xla_dist` for
        # multiple VMs and TPU_CONFIG won't be available, running
        # `xm.get_xla_supported_devices("TPU")` won't be possible.
        return (xm.xrt_world_size() > 1) or bool(
            xm.get_xla_supported_devices("TPU"))
Example #7
 def test(self):
   devices = [torch.device(x) for x in xm.get_xla_supported_devices()]
   A = 3.11
   B = 4.09
   batch_size = 128 * len(devices)
   gen = xu.FnDataGenerator(
       lambda x: x * A + B, batch_size, _gen_tensor, dims=[8], count=10)
   para_loader = pl.ParallelLoader(gen, devices)
   for device in devices:
     loader = para_loader.per_device_loader(device)
     for data, target in loader:
       self.assertEqual(data.device, device)
       self.assertEqual(target.device, device)
Example #8
def run_thread_per_device(rank: int, processes: int,
                          fn: Callable[..., R]) -> Dict[int, R]:
  """Runs `fn` in a separate thread on each visible device.

  Args:
    rank: rank of current process
    processes: number of processes on this host
    fn: Function to run on all devices

  Returns:
    Dict of the form {thread_rank: return_value}, where return_value is the
    result of calling `fn`.
  """
  if device_type() == 'TPU':
    configure_tpu_topology(rank, processes)

  xm.set_replication(xm.xla_device(), xm.get_xla_supported_devices())
  threads = len(xm.get_xla_supported_devices())

  def _thread_fn(fn, device_index):

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
      # Assumes same number of threads per process
      set_global_ordinal(rank * threads + device_index)
      set_local_ordinal(rank * threads + device_index)

      return fn(*args, **kwargs)

    return wrapper

  with concurrent.futures.ThreadPoolExecutor(max_workers=threads) as executor:
    futures = {executor.submit(_thread_fn(fn, i)): i for i in range(threads)}

    results = {
        futures[f]: f.result() for f in concurrent.futures.as_completed(futures)
    }

  return results
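
A hedged usage sketch: assuming the module-level helpers referenced above (device_type, configure_tpu_topology, set_global_ordinal, set_local_ordinal) are importable, a single-process host could fan a trivial function out over its local devices like this; the rank and process count are illustrative.

import torch_xla.core.xla_model as xm

def _report_device():
  # Runs once per thread; with replication set up above, each thread works
  # against its own XLA device.
  return str(xm.xla_device())

# Single host, single process: rank 0 of 1.
results = run_thread_per_device(0, 1, _report_device)
print(results)  # e.g. {0: 'xla:0', 1: 'xla:1', ...}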
Example #9
def train():
    set_seeds()
    logging.info('Loading masks...')
    with open(args.json_file, 'r') as f:
        masks = json.load(f)

    # for example use only 200 images
    filename = list(masks.keys())[:200]

    global devices, num_steps_per_epoch

    devices = (xm.get_xla_supported_devices(
        max_devices=args.num_cores) if args.num_cores != 0 else [])

    logging.info('Start training model')
    if args.model_name == 'deeplabv3_resnet50':
        m = torchvision.models.segmentation.deeplabv3_resnet50(False)
    else:
        m = torchvision.models.segmentation.fcn_resnet50(False)

    m.classifier[-1] = torch.nn.Conv2d(m.classifier[-1].in_channels, 46, 1)
    # wrapped for parallel training
    model = dp.DataParallel(m, device_ids=devices)

    ds = FashionDataset(filename,
                        masks,
                        path=args.data_path,
                        transform=train_transform,
                        size=(256, 256))
    loader = D.DataLoader(ds,
                          batch_size=args.batch_size,
                          shuffle=True,
                          num_workers=args.num_worker)

    num_steps_per_epoch = len(loader) // len(devices)

    for epoch in range(1, args.epochs + 1):

        train_loss = model(train_loop_fn, loader)
        train_loss = np.array(train_loss).mean()

        logging.info('[Epoch {:3d}] Train loss: {:.3f}'.format(
            epoch, train_loss))

    # Save weights
    state_dict = model.models[0].to('cpu').state_dict()
    torch.save(state_dict, args.save_file)
    logging.info('')
    logging.info('Model saved\n')
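
train_loop_fn is defined elsewhere in the original project and is not shown here. Purely as a hedged sketch, a loop compatible with the (model, loader, device, context) signature that dp.DataParallel uses throughout these examples, assuming the loader yields (image, mask) batches, might look like:

import torch.nn as nn
import torch.optim as optim
import torch_xla.core.xla_model as xm

def train_loop_fn(model, loader, device, context):
    # Hypothetical per-device loop: torchvision segmentation models return a
    # dict, so the 'out' head is trained against integer mask targets.
    loss_fn = nn.CrossEntropyLoss()
    optimizer = context.getattr_or(
        'optimizer', lambda: optim.SGD(model.parameters(), lr=1e-3, momentum=0.9))
    model.train()
    losses = []
    for data, target in loader:
        optimizer.zero_grad()
        output = model(data)['out']
        loss = loss_fn(output, target.long())
        loss.backward()
        xm.optimizer_step(optimizer)
        losses.append(loss.item())
    return sum(losses) / max(len(losses), 1)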
Example #10
 def test(self):
     devices = xm.get_xla_supported_devices()
     A = 3.11
     B = 4.09
     batch_size = 128 * len(devices)
     gen = xu.FnDataGenerator(lambda x: x * A + B,
                              batch_size,
                              _gen_tensor,
                              dims=[8],
                              count=10)
     para_loader = pl.ParallelLoader(gen, batch_size, devices)
     for x, (data, target) in para_loader:
         for device in devices:
             dx = para_loader.to(data, device)
             self.assertEqual(dx.device, torch.device(device))
Example #11
 def test(self):
   devices = xm.get_xla_supported_devices()
   for device in reversed(devices):
     t = _gen_tensor(8, 12)
     tto = t.to(device=torch.device(device))
     self.assertEqual(tto.device, torch.device(device))
   t = _gen_tensor(8, 12).to(device=torch.device(devices[0]))
   for device in devices[1:]:
     tto = t.to(device=torch.device(device))
     self.assertEqual(tto.device, torch.device(device))
   for i in range(0, len(devices) - 1):
     dev0 = devices[i]
     dev1 = devices[i + 1]
     t0 = torch.zeros(4, 4, device=torch.device(dev0))
     t1 = t0.to(device=torch.device(dev1))
     t0 = t0 + torch.ones_like(t0, device=torch.device(dev0))
     t1 = t1 + torch.ones_like(t1, device=torch.device(dev1))
     self.assertEqual(t0.cpu(), t1.cpu())
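
A compact standalone variant of the same check, assuming only torch and torch_xla.core.xla_model are imported: round-trip a CPU tensor through every visible XLA device and confirm the values survive the copies.

import torch
import torch_xla.core.xla_model as xm

cpu_t = torch.randn(8, 12)
for device in xm.get_xla_supported_devices():
  xla_t = cpu_t.to(device=torch.device(device))
  assert torch.allclose(xla_t.cpu(), cpu_t)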
Example #12
 def __init__(self, network, device_ids=None):
   if device_ids is None:
     device_ids = xm.get_xla_supported_devices()
   self._device_ids = [str(x) for x in device_ids]
   self._native_run = False
   self._models = []
   self._contexts = []
   module = network if isinstance(network, torch.nn.Module) else network()
   for device in device_ids:
     device_module = deepcopy(module).to(device=torch.device(device))
     self._models.append(device_module)
     self._contexts.append(Context(torch.device(device)))
   if not self._models:
     # No XLA device, push a vanilla network in.
     device = self._get_model_device(module)
     self._models.append(module)
     self._device_ids.append(device)
     self._contexts.append(Context(torch.device(device)))
     self._native_run = True
Example #13
def xla_device(n: Optional[int] = None,
               devkind: Optional[str] = None) -> torch.device:
  """Returns an XLA device.

  Args:
    n: Index of XLA device within visible devices. If not set, use local
      ordinal (default 0) to select an addressable device.
    devkind: Type of device to return. Should match `device_type()`.

  Returns:
    A `torch.device` representing an XLA device.
  """
  devices = xm.get_xla_supported_devices(devkind=devkind)
  device_index = n or (local_ordinal(default=0) % addressable_device_count())
  if device_index >= len(devices):
    raise IndexError('Device index {} out of range in {}'.format(
        device_index, devices))

  torch_xla._XLAC._xla_set_default_device(devices[device_index])
  return torch.device(devices[device_index])
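
A usage sketch for the helper above, assuming it is imported alongside the module's other utilities (local_ordinal, addressable_device_count): place a tensor on the default addressable device, or pass an explicit index.

import torch

t = torch.ones(2, 2, device=xla_device())     # local ordinal picks the device
t0 = torch.zeros(2, 2, device=xla_device(0))  # explicit device index
print(t.device, t0.device)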
Example #14
def train():
    set_seeds()
    
    global num_steps_per_epoch   

    ## Create and wrap model ##
    model = model_from_name(args.model_name)
    
    devices = (
        xm.get_xla_supported_devices(max_devices=args.num_cores) if args.num_cores != 0 else [])
    model_parallel = dp.DataParallel(model, device_ids=devices)
    logging.info('Model {} loaded and wrapped'.format(args.model_name))
    logging.info('')
    
    ## Create dataset and loader    
    patient, label = extract_patient(args.csv_file_path, return_label=True)
    
    if args.use_image:
        ds = RSNAImages(patient, label, path=args.path, transform=train_transforms)
    else:
        ds = RSNADataset(patient, label, path=args.path, transform=train_transforms)
        
    loader = D.DataLoader(ds, batch_size=args.batch_size,
                          shuffle=True, num_workers=args.num_workers)
    logging.info('Dataset created\n')
    logging.info('Start training model\n')

    num_steps_per_epoch = len(loader) // len(devices)

    for epoch in range(1, args.num_epochs + 1):
        start_time = time.time()
        model_parallel(train_loop_fn, loader)
        logging.info('')
        logging.info('Epoch training time: {:.2f} minutes\n'.format((time.time() - start_time) / 60))
        
    # Save weights
    state_dict = model_parallel.models[0].to('cpu').state_dict()
    torch.save(state_dict, args.save_pht)
    logging.info('')
    logging.info('Model saved')
    logging.info('')
Example #15
    def test_xla_sharded_tensor(self):
        # Test XLAShardedTensor basic properties

        # Simple 1-D sharding
        num_devices = len(xm.get_xla_supported_devices("TPU"))
        mesh_shape = (1, num_devices)
        partition_spec = (1, )
        t1 = torch.tensor([2.0, 3.0],
                          dtype=torch.float,
                          device=xm.xla_device())
        t1_sharded = XLAShardedTensor(t1, mesh_shape, partition_spec)
        t2 = torch.tensor([2.0, 3.0],
                          dtype=torch.float,
                          device=xm.xla_device())
        t3 = torch.add(t1_sharded, t2)

        assert isinstance(
            t3,
            XLAShardedTensor), "Sharded ops should return XLAShardedTensor."
        assert t3.size() == t1.size(
        ), "Sharded output should return unpartitioned tensor size."

        device_ids = np.array(range(num_devices))
        device_assignment = list(device_ids.reshape(mesh_shape))
Example #16
def train_cifar():
    print('==> Preparing data..')

    if FLAGS.fake_data:
        train_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 3, 32, 32),
                  torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=50000 // FLAGS.batch_size // xm.xrt_world_size())
        test_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 3, 32, 32),
                  torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=10000 // FLAGS.batch_size // xm.xrt_world_size())
    else:
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])

        train_dataset = torchvision.datasets.CIFAR10(root=FLAGS.datadir,
                                                     train=True,
                                                     download=True,
                                                     transform=transform_train)
        test_dataset = torchvision.datasets.CIFAR10(root=FLAGS.datadir,
                                                    train=False,
                                                    download=True,
                                                    transform=transform_test)
        train_sampler = None
        test_sampler = None
        if xm.xrt_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=True)
            test_sampler = torch.utils.data.distributed.DistributedSampler(
                test_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=False)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=FLAGS.batch_size,
            sampler=train_sampler,
            shuffle=False if train_sampler else True,
            num_workers=FLAGS.num_workers)
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=FLAGS.batch_size,
            sampler=test_sampler,
            shuffle=False,
            num_workers=FLAGS.num_workers)

    torch.manual_seed(42)

    devices = (xm.get_xla_supported_devices(
        max_devices=FLAGS.num_cores) if FLAGS.num_cores != 0 else [])
    # Pass [] as device_ids to run using the PyTorch/CPU engine.
    model = torchvision.models.resnet18 if FLAGS.use_torchvision else ResNet18
    model_parallel = dp.DataParallel(model, device_ids=devices)

    def train_loop_fn(model, loader, device, context):
        loss_fn = nn.CrossEntropyLoss()
        optimizer = context.getattr_or(
            'optimizer', lambda: optim.SGD(model.parameters(),
                                           lr=FLAGS.lr,
                                           momentum=FLAGS.momentum,
                                           weight_decay=5e-4))
        tracker = xm.RateTracker()

        model.train()
        for x, (data, target) in loader:
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(FLAGS.batch_size)
            if x % FLAGS.log_steps == 0:
                test_utils.print_training_update(device, x, loss.item(),
                                                 tracker.rate(),
                                                 tracker.global_rate())

    def test_loop_fn(model, loader, device, context):
        total_samples = 0
        correct = 0
        model.eval()
        for x, (data, target) in loader:
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
            total_samples += data.size()[0]

        accuracy = 100.0 * correct / total_samples
        test_utils.print_test_update(device, accuracy)
        return accuracy

    accuracy = 0.0
    for epoch in range(1, FLAGS.num_epochs + 1):
        model_parallel(train_loop_fn, train_loader)
        accuracies = model_parallel(test_loop_fn, test_loader)
        accuracy = sum(accuracies) / len(accuracies)
        print("Epoch: {}, Mean Accuracy: {:.2f}%".format(epoch, accuracy))
        if FLAGS.metrics_debug:
            print(met.metrics_report())

    return accuracy
Example #17
def train_unet():

  # logging setup
  logger_name = 'train_logger'
  logger = initializeLogging(
      os.path.join(logdir, 'train_history.txt'), logger_name)

  # checkpointing setup
  checkpoint_frequency = log_steps


  torch.manual_seed(1)

  '''
  train_dataset = datasets.MNIST(
      datadir,
      train=True,
      download=True,
      transform=transforms.Compose(
          [transforms.ToTensor(),
           transforms.Normalize((0.1307,), (0.3081,))]))
  train_dataset_len = len(train_dataset)
  test_dataset = datasets.MNIST(
      datadir,
      train=False,
      transform=transforms.Compose(
          [transforms.ToTensor(),
           transforms.Normalize((0.1307,), (0.3081,))]))
  '''

  img_dim = 1024
  normalize = transforms.Normalize(
      mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

  train_dataset = satelliteDataSet(
      os.path.join(datadir, 'train'),
      transforms.Compose([
        transforms.RandomResizedCrop(img_dim),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize
      ]))
  train_dataset_len = len(train_dataset)

  resize_dim = max(img_dim, 256)
  test_dataset = satelliteDataSet(
      os.path.join(datadir, 'train'),
      transforms.Compose([
          transforms.Resize(resize_dim),
          transforms.CenterCrop(img_dim),
          transforms.ToTensor(),
          normalize,
      ]))

  train_sampler = None
  test_sampler = None
  if xm.xrt_world_size() > 1:
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True)
    test_sampler = torch.utils.data.distributed.DistributedSampler(
        test_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=False)
  train_loader = torch.utils.data.DataLoader(
      train_dataset,
      batch_size=batch_size,
      sampler=train_sampler,
      drop_last=drop_last,
      shuffle=False if train_sampler else True,
      num_workers=num_workers)
  test_loader = torch.utils.data.DataLoader(
      test_dataset,
      batch_size=batch_size,
      sampler=test_sampler,
      drop_last=drop_last,
      shuffle=False,
      num_workers=num_workers)

  maxItr = num_epochs * len(train_loader) // train_loader.batch_size + 1

  devices = (
      xm.get_xla_supported_devices(
          max_devices=num_cores) if num_cores != 0 else [])
  # Scale learning rate to num cores
  lr = 0.0001
  lr = lr * max(len(devices), 1)
  # Pass [] as device_ids to run using the PyTorch/CPU engine.
  model_parallel = dp.DataParallel(UNet, device_ids=devices)

  def train_loop_fn(model, loader, device, context):
    loss_fn = nn.CrossEntropyLoss()
    optimizer = context.getattr_or(
        'optimizer',
        lambda: optim.Adam(model.parameters(), lr=lr))
    tracker = xm.RateTracker()

    model.train()
    print('# of iterations: {}'.format(maxItr))
    logger.info('# of iterations: {}'.format(maxItr))
    optimizer.zero_grad()
    for x, (data, target) in enumerate(loader):
      data = target[0].permute(0,3,1,2)
      target = target[1]
      output = model(data)
      loss = loss_fn(output, target.long())
      #_, preds = torch.max(output, 1)
      loss.backward()
      
      # backprop every log_step iterations
      if x % log_steps == 0:
        xm.optimizer_step(optimizer)
        optimizer.zero_grad()

      tracker.add(batch_size)

      # compute the confusion matrix and IoU
      #print(preds.shape)
      #print(target.shape)
      
      #val_conf = np.zeros((num_classes, num_classes))
      #val_conf = val_conf + confusion_matrix(
      #    target[target >= 0].view(-1).cpu().numpy(),
      #    preds[target >= 0].view(-1).cpu().numpy())
      #pos = np.sum(val_conf, 1)
      #res = np.sum(val_conf, 0)
      #tp = np.diag(val_conf)
      #iou = np.mean(tp / np.maximum(1, pos + res - tp))

      #logger.info('device: {}, x: {}, loss: {}, tracker_rate: {}, tracker_global_rate: {}'.format(device, x, loss.item(), tracker.rate(), tracker.global_rate()))
      print('device: {}, x: {}, loss: {}, tracker_rate: {}, tracker_global_rate: {}'.format(device, x, loss.item(), tracker.rate(), tracker.global_rate()))
      
      if x % log_steps == 0:
        logger.info('device: {}, x: {}, loss: {}, tracker_rate: {}, tracker_global_rate: {}'.format(device, x, loss.item(), tracker.rate(), tracker.global_rate()))

  def test_loop_fn(model, loader, device, context):
    total_samples = 0
    correct = 0
    model.eval()
    for data, target in loader:
      data = target[0].permute(0,3,1,2)
      target = target[1] 
      output = model(data)
      #pred = output.max(1, keepdim=True)[1].float()
      _, preds = torch.max(output, 1)
      preds = preds.float()
      correct += preds.eq(target.view_as(preds)).sum().item()
      total_samples += target.shape[1]**2
      print('device: {}, Running Accuracy: {}'.format(device, correct/total_samples))

    accuracy = 100.0 * correct / total_samples
    test_utils.print_test_update(device, accuracy)
    logger.info('TEST: device: {}, accuracy: {}'.format(device, accuracy))
    return accuracy

  accuracy = 0.0
  writer = test_utils.get_summary_writer(logdir)
  num_devices = len(
      xm.xla_replication_devices(devices)) if len(devices) > 1 else 1
  num_training_steps_per_epoch = train_dataset_len // (
      batch_size * num_devices)
  print('total epochs: {}'.format(num_epochs))

  for epoch in range(1, num_epochs + 1):
    print(epoch)
    print(train_loader)
    
    model_parallel(train_loop_fn, train_loader)
    accuracies = model_parallel(test_loop_fn, test_loader)
    accuracy = mean(accuracies)
    
    print('Epoch: {}, Mean Accuracy: {:.2f}%'.format(epoch, accuracy))
    logger.info('Epoch: {}, Mean Accuracy: {:.2f}%'.format(epoch, accuracy))
    
    global_step = (epoch - 1) * num_training_steps_per_epoch

    if metrics_debug:
      print(met.metrics_report())
      logger.info(met.metrics_report())
      
    logger.info('saving checkpoint. epoch: {}'.format(epoch))
    torch.save(model_parallel, os.path.join(logdir,'model_parallel_chkpt.pt'))
    logger.info('checkpoint saved. epoch: {}'.format(epoch))
       


  test_utils.close_summary_writer(writer)
  return accuracy
Example #18
def train_mnist(FLAGS):

    DTYPE = torch.float32

    torch.manual_seed(1)

    dims = (
        FLAGS.batch_size,
        1,
        784,
    )

    train_dataset_len = FLAGS.steps_per_epoch if FLAGS.steps_per_epoch else 60000
    train_loader = xu.SampleGenerator(
        data=(
            torch.ones(dims, dtype=DTYPE,),
            torch.ones(
                FLAGS.batch_size,
                dtype=torch.int64 if not _MSE_LOSS else DTYPE,
            ),
        ),
        sample_count=train_dataset_len
        // FLAGS.batch_size
        // xm.xrt_world_size(),
    )
    test_loader = xu.SampleGenerator(
        data=(
            torch.ones(dims, dtype=DTYPE,),
            torch.ones(
                FLAGS.batch_size,
                dtype=torch.int64 if not _MSE_LOSS else DTYPE,
            ),
        ),
        sample_count=10000 // FLAGS.batch_size // xm.xrt_world_size(),
    )

    devices = (
        xm.get_xla_supported_devices(max_devices=FLAGS.num_cores)
        if FLAGS.num_cores != 0
        else []
    )

    """ 
    Non multi-processing
    """
    # Scale learning rate to num cores
    lr = FLAGS.lr * max(len(devices), 1)

    model = MNIST(FLAGS)
    model_parallel = dp.DataParallel(
        model,
        device_ids=devices,
    )

    writer = None
    if xm.is_master_ordinal():
        writer = test_utils.get_summary_writer(FLAGS.logdir)

    # Just some step closure output
    def train_output_fn(outputs, ctx, args, tracker):
        if ctx.step > 0 and args.log_steps and ctx.step % args.log_steps == 0:
            now_time = time.time()
            if hasattr(ctx, 'start_time') and ctx.start_time:
                per_step_time = (now_time - ctx.start_time) / (
                    ctx.step - ctx.last_step_timed
                )
                steps_per_second = 1 / per_step_time
                print(
                    f'[{xm.get_ordinal()}] Round-trip step time: '
                    f'{per_step_time} seconds, steps per second: {steps_per_second}'
                )
                if tracker:
                    _train_update(
                        device=device,
                        step=ctx.step,
                        loss=outputs[0],
                        tracker=tracker,
                        epoch=epoch,
                        writer=writer,
                    )
                print(f'BEGIN Train step {ctx.step}')
                ctx.start_time = time.time()
                ctx.last_step_timed = ctx.step
            else:
                ctx.start_time = time.time()
                ctx.last_step_timed = ctx.step
        ctx.step += 1

    def train_loop_fn(model, loader, device=None, context=None):
        lr_adder = 0.0

        if _MSE_LOSS:
            loss_fn = nn.MSELoss()
        else:
            loss_fn = nn.NLLLoss()
        optimizer = context.getattr_or(
            'optimizer',
            lambda: optim.SGD(
                model.parameters(),
                lr=lr + lr_adder,
                momentum=FLAGS.momentum,
            ),
        )

        tracker = xm.RateTracker()

        model.train()

        def train_inner_loop_fn(batch, ctx):
            step = ctx.step
            print(f'Step {step}')
            data = batch[0]
            target = batch[1]
            optimizer.zero_grad()
            output = model(data)

            loss = loss_fn(output, target)
            loss.backward()

            xm.optimizer_step(
                optimizer,
                barrier=False,
            )

            if (
                FLAGS.log_steps != 0
                and (
                    FLAGS.log_steps == 1
                    or (step > 0 and step % FLAGS.log_steps == 0)
                )
            ):
                xm.add_step_closure(
                    _train_update,
                    args=(device, step, loss, tracker, epoch, writer),
                )

            if step == 0:
                xm.master_print(f"End TRAIN step {step}")

            ctx.step += 1
            return [loss]

        step = 0
        # Train
        print(f'Starting new epoch train loop... (epoch={epoch})')
        for step, (data, target) in enumerate(loader):
            if step % FLAGS.step_print_interval == 0:
                xm.master_print(f"Begin TRAIN Step: {step}")
            context.step = step

            if not FLAGS.use_autograph:
                outputs = train_inner_loop_fn((data, target), context)
            else:
                outputs = ptwse.flow.runner.maybe_run_converted(
                    train_inner_loop_fn,
                    (data, target),
                    context
                )

        xm.master_print(f"Saving model...")
        _save_checkpoint(FLAGS, device, None, model, is_epoch=True)
        xm.master_print(f"Model saved")

    def test_loop_fn(model, loader, device, context):
        print("***********************")
        print("ENTERING TEST FUNCTION")
        print("***********************")
        print('Evaluating...')
        total_samples = 0
        correct = 0
        model.eval()
        for step, (data, target) in enumerate(loader):
            if step >= FLAGS.test_max_step:
                break
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            if FLAGS.mp:
                correct += pred.eq(target.view_as(pred)).sum()
            else:
                correct += pred.eq(target.view_as(pred)).sum().item()
            total_samples += data.size()[0]

        if FLAGS.mp:
            this_accuracy = 100.0 * correct.item() / total_samples
            print("CALLING: mesh_reduce('test_accuracy')")
            this_accuracy = xm.mesh_reduce(
                'test_accuracy', this_accuracy, np.mean
            )
            print("BACK FROM: mesh_reduce('test_accuracy')")
        else:
            this_accuracy = 100.0 * correct / total_samples
            test_utils.print_test_update(device, this_accuracy)
        print("***********************")
        print("LEAVING TEST FUNCTION")
        print("***********************")
        return this_accuracy

    #
    # Set up for
    #
    accuracy = 0.0

    num_devices = (
        len(xm.xla_replication_devices(devices)) if len(devices) > 1 else 1
    )

    if not FLAGS.steps_per_epoch:
        num_training_steps_per_epoch = train_dataset_len // (
            FLAGS.batch_size * num_devices
        )
    else:
        num_training_steps_per_epoch = FLAGS.steps_per_epoch
    max_accuracy = 0.0

    #
    # Epoch loop
    #
    for epoch in range(1, FLAGS.num_epochs + 1):
        #
        # Train
        #
        device = xm.xla_device()
        ctx = dp.Context(device=device)
        ctx.tracker = xm.RateTracker()
        ctx.step = 0
        train_loop_fn(model, train_loader, device, ctx)

        #
        # Test
        #
        if FLAGS.run_test:
            with ptwse.scope.proxy_disabled(disabled=FLAGS.test_off_proxy):
                accuracies = model_parallel(test_loop_fn, test_loader)
            accuracy = mean(accuracies)
            print(
                'Epoch: {}, Mean Accuracy: {:.2f}%'.format(epoch, accuracy)
            )

        global_step = (epoch - 1) * num_training_steps_per_epoch
        max_accuracy = max(accuracy, max_accuracy)

        test_utils.write_to_summary(
            writer,
            global_step,
            dict_to_write={'Accuracy/test': accuracy},
            write_xla_metrics=True,
        )

        if FLAGS.metrics_debug:
            xm.master_print(met.metrics_report())

    test_utils.close_summary_writer(writer)
    xm.master_print('Max Accuracy: {:.2f}%'.format(max_accuracy))
    return max_accuracy
Example #19
import torch
import torch_xla
import torch_xla.core.xla_model as xm

print("Starting...")
print("Supported xla devices")
print(xm.get_xla_supported_devices())
t = torch.randn(2, 2, device=xm.xla_device())
print(t.device)
print(t)
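
A slightly larger smoke test along the same lines (an illustrative extension, not part of the original script): run a small matmul on every supported device and pull the result back to the host.

for device in xm.get_xla_supported_devices():
    a = torch.randn(4, 4, device=torch.device(device))
    b = torch.randn(4, 4, device=torch.device(device))
    print(device, (a @ b).cpu().sum().item())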
Example #20
def train_imagenet():
    print('==> Preparing data..')
    img_dim = get_model_property('img_dim')
    if FLAGS.fake_data:
        train_dataset_len = 1200000  # Roughly the size of Imagenet dataset.
        train_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim),
                  torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=train_dataset_len // FLAGS.batch_size //
            xm.xrt_world_size())
        test_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.test_set_batch_size, 3, img_dim, img_dim),
                  torch.zeros(FLAGS.test_set_batch_size, dtype=torch.int64)),
            sample_count=50000 // FLAGS.batch_size // xm.xrt_world_size())
    else:
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        train_dataset = torchvision.datasets.ImageFolder(
            os.path.join(FLAGS.datadir, 'train'),
            transforms.Compose([
                transforms.RandomResizedCrop(img_dim),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))
        train_dataset_len = len(train_dataset.imgs)
        resize_dim = max(img_dim, 256)
        test_dataset = torchvision.datasets.ImageFolder(
            os.path.join(FLAGS.datadir, 'val'),
            # Matches Torchvision's eval transforms except Torchvision uses size
            # 256 resize for all models both here and in the train loader. Their
            # version crashes during training on 299x299 images, e.g. inception.
            transforms.Compose([
                transforms.Resize(resize_dim),
                transforms.CenterCrop(img_dim),
                transforms.ToTensor(),
                normalize,
            ]))

        train_sampler = None
        test_sampler = None
        if xm.xrt_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=True)
            test_sampler = torch.utils.data.distributed.DistributedSampler(
                test_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=False)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=FLAGS.batch_size,
            sampler=train_sampler,
            shuffle=False if train_sampler else True,
            num_workers=FLAGS.num_workers)
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=FLAGS.test_set_batch_size,
            sampler=test_sampler,
            shuffle=False,
            num_workers=FLAGS.num_workers)

    torch.manual_seed(42)

    devices = (xm.get_xla_supported_devices(
        max_devices=FLAGS.num_cores) if FLAGS.num_cores != 0 else [])
    # Pass [] as device_ids to run using the PyTorch/CPU engine.
    torchvision_model = get_model_property('model_fn')
    model_parallel = dp.DataParallel(torchvision_model, device_ids=devices)

    def train_loop_fn(model, loader, device, context):
        loss_fn = nn.CrossEntropyLoss()
        optimizer = context.getattr_or(
            'optimizer', lambda: optim.SGD(model.parameters(),
                                           lr=FLAGS.lr,
                                           momentum=FLAGS.momentum,
                                           weight_decay=5e-4))
        lr_scheduler = context.getattr_or(
            'lr_scheduler', lambda: schedulers.wrap_optimizer_with_scheduler(
                optimizer,
                scheduler_type=getattr(FLAGS, 'lr_scheduler_type', None),
                scheduler_divisor=getattr(FLAGS, 'lr_scheduler_divisor', None),
                scheduler_divide_every_n_epochs=getattr(
                    FLAGS, 'lr_scheduler_divide_every_n_epochs', None),
                num_steps_per_epoch=num_training_steps_per_epoch,
                summary_writer=writer if xm.is_master_ordinal() else None))
        tracker = xm.RateTracker()
        model.train()
        for x, (data, target) in loader:
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(FLAGS.batch_size)
            if x % FLAGS.log_steps == 0:
                test_utils.print_training_update(device, x, loss.item(),
                                                 tracker.rate(),
                                                 tracker.global_rate())
            if lr_scheduler:
                lr_scheduler.step()

    def test_loop_fn(model, loader, device, context):
        total_samples = 0
        correct = 0
        model.eval()
        for x, (data, target) in loader:
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
            total_samples += data.size()[0]

        accuracy = 100.0 * correct / total_samples
        test_utils.print_test_update(device, accuracy)
        return accuracy

    accuracy = 0.0
    writer = SummaryWriter(log_dir=FLAGS.logdir) if FLAGS.logdir else None
    num_devices = len(
        xm.xla_replication_devices(devices)) if len(devices) > 1 else 1
    num_training_steps_per_epoch = train_dataset_len // (FLAGS.batch_size *
                                                         num_devices)
    for epoch in range(1, FLAGS.num_epochs + 1):
        model_parallel(train_loop_fn, train_loader)
        accuracies = model_parallel(test_loop_fn, test_loader)
        accuracy = mean(accuracies)
        print('Epoch: {}, Mean Accuracy: {:.2f}%'.format(epoch, accuracy))
        global_step = (epoch - 1) * num_training_steps_per_epoch
        test_utils.add_scalar_to_summary(writer, 'Accuracy/test', accuracy,
                                         global_step)
        if FLAGS.metrics_debug:
            print(torch_xla._XLAC._xla_metrics_report())

    return accuracy
Example #21
def train_mnist():
    torch.manual_seed(1)

    if FLAGS.fake_data:
        train_dataset_len = 60000  # Number of images in MNIST dataset.
        train_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 1, 28, 28),
                  torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=train_dataset_len // FLAGS.batch_size //
            xm.xrt_world_size())
        test_loader = xu.SampleGenerator(
            data=(torch.zeros(FLAGS.batch_size, 1, 28, 28),
                  torch.zeros(FLAGS.batch_size, dtype=torch.int64)),
            sample_count=10000 // FLAGS.batch_size // xm.xrt_world_size())
    else:
        train_dataset = datasets.MNIST(FLAGS.datadir,
                                       train=True,
                                       download=True,
                                       transform=transforms.Compose([
                                           transforms.ToTensor(),
                                           transforms.Normalize((0.1307, ),
                                                                (0.3081, ))
                                       ]))
        train_dataset_len = len(train_dataset)
        test_dataset = datasets.MNIST(FLAGS.datadir,
                                      train=False,
                                      transform=transforms.Compose([
                                          transforms.ToTensor(),
                                          transforms.Normalize((0.1307, ),
                                                               (0.3081, ))
                                      ]))
        train_sampler = None
        test_sampler = None
        if xm.xrt_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=True)
            test_sampler = torch.utils.data.distributed.DistributedSampler(
                test_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=False)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=FLAGS.batch_size,
            sampler=train_sampler,
            drop_last=FLAGS.drop_last,
            shuffle=False if train_sampler else True,
            num_workers=FLAGS.num_workers)
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=FLAGS.batch_size,
            sampler=test_sampler,
            drop_last=FLAGS.drop_last,
            shuffle=False,
            num_workers=FLAGS.num_workers)

    devices = (xm.get_xla_supported_devices(
        max_devices=FLAGS.num_cores) if FLAGS.num_cores != 0 else [])
    # Scale learning rate to num cores
    lr = FLAGS.lr * max(len(devices), 1)
    # Pass [] as device_ids to run using the PyTorch/CPU engine.
    model_parallel = dp.DataParallel(MNIST, device_ids=devices)

    def train_loop_fn(model, loader, device, context):
        loss_fn = nn.NLLLoss()
        optimizer = context.getattr_or(
            'optimizer', lambda: optim.SGD(
                model.parameters(), lr=lr, momentum=FLAGS.momentum))
        tracker = xm.RateTracker()

        model.train()
        for x, (data, target) in enumerate(loader):
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(FLAGS.batch_size)
            if x % FLAGS.log_steps == 0:
                test_utils.print_training_update(device, x, loss.item(),
                                                 tracker.rate(),
                                                 tracker.global_rate())

    def test_loop_fn(model, loader, device, context):
        total_samples = 0
        correct = 0
        model.eval()
        for data, target in loader:
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
            total_samples += data.size()[0]

        accuracy = 100.0 * correct / total_samples
        test_utils.print_test_update(device, accuracy)
        return accuracy

    accuracy = 0.0
    writer = test_utils.get_summary_writer(FLAGS.logdir)
    num_devices = len(
        xm.xla_replication_devices(devices)) if len(devices) > 1 else 1
    num_training_steps_per_epoch = train_dataset_len // (FLAGS.batch_size *
                                                         num_devices)
    for epoch in range(1, FLAGS.num_epochs + 1):
        model_parallel(train_loop_fn, train_loader)
        accuracies = model_parallel(test_loop_fn, test_loader)
        accuracy = mean(accuracies)
        print('Epoch: {}, Mean Accuracy: {:.2f}%'.format(epoch, accuracy))
        global_step = (epoch - 1) * num_training_steps_per_epoch
        test_utils.add_scalar_to_summary(writer, 'Accuracy/test', accuracy,
                                         global_step)
        if FLAGS.metrics_debug:
            print(met.metrics_report())

    test_utils.close_summary_writer(writer)
    return accuracy
Example #22
def train_imagenet():
    print("==> Preparing data..")
    img_dim = get_model_property("img_dim")
    if FLAGS.fake_data:
        train_dataset_len = 1200000  # Roughly the size of Imagenet dataset.
        train_loader = xu.SampleGenerator(
            data=(
                torch.zeros(FLAGS.batch_size, 3, img_dim, img_dim),
                torch.zeros(FLAGS.batch_size, dtype=torch.int64),
            ),
            sample_count=train_dataset_len // FLAGS.batch_size // xm.xrt_world_size(),
        )
        test_loader = xu.SampleGenerator(
            data=(
                torch.zeros(FLAGS.test_set_batch_size, 3, img_dim, img_dim),
                torch.zeros(FLAGS.test_set_batch_size, dtype=torch.int64),
            ),
            sample_count=50000 // FLAGS.batch_size // xm.xrt_world_size(),
        )
    else:
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )
        train_dataset = torchvision.datasets.ImageFolder(
            os.path.join(FLAGS.datadir, "train"),
            transforms.Compose(
                [
                    transforms.RandomResizedCrop(img_dim),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    normalize,
                ]
            ),
        )
        train_dataset_len = len(train_dataset.imgs)
        resize_dim = max(img_dim, 256)
        test_dataset = torchvision.datasets.ImageFolder(
            os.path.join(FLAGS.datadir, "val"),
            # Matches Torchvision's eval transforms except Torchvision uses size
            # 256 resize for all models both here and in the train loader. Their
            # version crashes during training on 299x299 images, e.g. inception.
            transforms.Compose(
                [
                    transforms.Resize(resize_dim),
                    transforms.CenterCrop(img_dim),
                    transforms.ToTensor(),
                    normalize,
                ]
            ),
        )
        train_sampler = None
        test_sampler = None
        if xm.xrt_world_size() > 1:
            train_sampler = torch.utils.data.distributed.DistributedSampler(
                train_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=True,
            )
            test_sampler = torch.utils.data.distributed.DistributedSampler(
                test_dataset,
                num_replicas=xm.xrt_world_size(),
                rank=xm.get_ordinal(),
                shuffle=False,
            )
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=FLAGS.batch_size,
            sampler=train_sampler,
            drop_last=True,
            shuffle=False if train_sampler else True,
            num_workers=FLAGS.num_workers,
        )
        test_loader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=FLAGS.test_set_batch_size,
            sampler=test_sampler,
            drop_last=False,
            shuffle=False,
            num_workers=FLAGS.num_workers,
        )

    torch.manual_seed(42)

    devices = (
        xm.get_xla_supported_devices(max_devices=FLAGS.num_cores)
        if FLAGS.num_cores != 0
        else ["cpu"]
    )
    print("use tpu devices", devices)
    # Pass [] as device_ids to run using the PyTorch/CPU engine.
    torchvision_model = get_model_property("model_fn")
    model_parallel = dp.DataParallel(torchvision_model, device_ids=devices)

    def train_loop_fn(model, loader, device, context):
        loss_fn = nn.CrossEntropyLoss()
        optimizer = context.getattr_or(
            "optimizer",
            lambda: optim.SGD(
                model.parameters(),
                lr=FLAGS.lr,
                momentum=FLAGS.momentum,
                weight_decay=1e-4,
            ),
        )
        lr_scheduler = context.getattr_or(
            "lr_scheduler",
            lambda: schedulers.wrap_optimizer_with_scheduler(
                optimizer,
                scheduler_type=getattr(FLAGS, "lr_scheduler_type", None),
                scheduler_divisor=getattr(FLAGS, "lr_scheduler_divisor", None),
                scheduler_divide_every_n_epochs=getattr(
                    FLAGS, "lr_scheduler_divide_every_n_epochs", None
                ),
                num_steps_per_epoch=num_training_steps_per_epoch,
                summary_writer=writer if xm.is_master_ordinal() else None,
            ),
        )
        tracker = xm.RateTracker()
        model.train()
        total_samples = 0
        correct = 0
        top5_accuracys = 0
        losses = 0
        for x, (data, target) in loader:
            optimizer.zero_grad()
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum()
            total_samples += data.size()[0]
            top5_accuracys += topk_accuracy(output, target, topk=5)
            loss = loss_fn(output, target)
            loss.backward()
            losses += loss.item()
            xm.optimizer_step(optimizer)
            tracker.add(FLAGS.batch_size)
            if x % FLAGS.log_steps == 0:
                print(
                    "[{}]({}) Loss={:.5f} Top-1 ACC = {:.2f} Rate={:.2f} GlobalRate={:.2f} Time={}".format(
                        str(device),
                        x,
                        loss.item(),
                        (100.0 * correct / total_samples).item(),
                        tracker.rate(),
                        tracker.global_rate(),
                        time.asctime(),
                    )
                )

            if lr_scheduler:
                lr_scheduler.step()
        return (
            losses / (x + 1),
            (100.0 * correct / total_samples).item(),
            (top5_accuracys / (x + 1)).item(),
        )

    def test_loop_fn(model, loader, device, context):
        total_samples = 0
        correct = 0
        top5_accuracys = 0
        model.eval()
        for x, (data, target) in loader:
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
            total_samples += data.size()[0]
            top5_accuracys += topk_accuracy(output, target, topk=5).item()

        accuracy = 100.0 * correct / total_samples
        test_utils.print_test_update(device, accuracy)
        return accuracy, top5_accuracys

    accuracy = 0.0
    writer = SummaryWriter(FLAGS.logdir) if FLAGS.logdir else None
    num_devices = len(xm.xla_replication_devices(devices)) if len(devices) > 1 else 1
    num_training_steps_per_epoch = train_dataset_len // (FLAGS.batch_size * num_devices)
    max_accuracy = 0.0
    print("train_loader_len", len(train_loader), num_training_steps_per_epoch)
    print("test_loader_len", len(test_loader))
    for epoch in range(1, FLAGS.num_epochs + 1):
        global_step = (epoch - 1) * num_training_steps_per_epoch
        # Train evaluate
        metrics = model_parallel(train_loop_fn, train_loader)
        losses, accuracies, top5_accuracys = zip(*metrics)
        loss = mean(losses)
        accuracy = mean(accuracies)
        top5_accuracy = mean(top5_accuracys)
        print(
            "Epoch: {} (Train), Loss {}, Mean Top-1 Accuracy: {:.2f} Top-5 accuracy: {}".format(
                epoch, loss, accuracy, top5_accuracy
            )
        )
        test_utils.add_scalar_to_summary(writer, "Loss/train", loss, global_step)
        test_utils.add_scalar_to_summary(
            writer, "Top-1 Accuracy/train", accuracy, global_step
        )
        test_utils.add_scalar_to_summary(
            writer, "Top-5 Accuracy/train", top5_accuracy, global_step
        )

        # Test evaluate
        metrics = model_parallel(test_loop_fn, test_loader)
        accuracies, top5_accuracys = zip(*metrics)
        top5_accuracys = sum(top5_accuracys)
        top5_accuracy = top5_accuracys / len(test_loader)
        accuracy = mean(accuracies)
        print(
            "Epoch: {} (Valid), Mean Top-1 Accuracy: {:.2f} Top-5 accuracy: {}".format(
                epoch, accuracy, top5_accuracy
            )
        )
        test_utils.add_scalar_to_summary(
            writer, "Top-1 Accuracy/test", accuracy, global_step
        )
        test_utils.add_scalar_to_summary(
            writer, "Top-5 Accuracy/test", top5_accuracy, global_step
        )
        if FLAGS.metrics_debug:
            print(met.metrics_report())
        if accuracy > max_accuracy:
            max_accuracy = max(accuracy, max_accuracy)
    torch.save(
        model_parallel.models[0].to("cpu").state_dict(),
        f"./reports/resnet50_model-{epoch}.pt",
    )
    model_parallel.models[0].to(devices[0])

    test_utils.close_summary_writer(writer)
    print("Max Accuracy: {:.2f}%".format(accuracy))
    return max_accuracy
Example #23
!pip3 install -r "/content/drive/My Drive/htfl/Freeze.txt"

import os
print(os.environ["COLAB_TPU_ADDR"])
import torch

# imports the torch_xla package
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.data_parallel as dp
import torch_xla.distributed.xla_multiprocessing as xmp
num_cores = 10
devices = (
    xm.get_xla_supported_devices(
        max_devices=num_cores) if num_cores != 0 else [])
print("Devices: {}".format(devices))

os.chdir('/content/drive/My Drive/htfl/')
!rm train.py -f 
import os
import sys
from google.colab import files
uploaded = files.upload()

os.chdir('/content/drive/My Drive/htfl/config')
!rm bert_config.json -f 
import os
import sys
from google.colab import files
uploaded = files.upload()
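
A small optional sanity check (an assumption, not part of the original notebook) could be added after the device query to fail fast when the Colab runtime has no TPU attached:

assert devices, ('No XLA devices found; check COLAB_TPU_ADDR and that the '
                 'Colab runtime type is set to TPU.')
print("Real devices: {}".format(xm.xla_real_devices(devices)))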
Example #24
        args.b,
        settings.MEAN,
        settings.STD,
    )

    Flag = {}
    Flag['lr'] = args.lr
    Flag['epoch'] = args.e
    Flag['warm'] = args.warm
    Flag['batch_size'] = args.b
    Flag['milestones'] = settings.MILESTONES
    Flag['ignore_idx'] = train_data_loader.dataset.ignore_index

    len(train_data_loader.dataset)
    net = UNet(3, train_data_loader.dataset.class_num)
    devices = xm.get_xla_supported_devices()
    print(devices)
    net = dp.DataParallel(net, device_ids=devices)

    iter_per_epoch = len(train_data_loader) / 8

    best_iou = 0
    for epoch in range(1, args.e + 1):
        print('training epoch {}'.format(epoch))
        t1 = time.time()

        net(train_loop_fn, train_data_loader)

        print(time.time() - t1)

        #result = net(test_loop_fn, test_data_loader)
Example #25
def measure_tpu(warmups, steps, h_emb, h_indices, h_offsets, args):

    import torch_xla
    import torch_xla.core.xla_model as xm
    import os

    tsize = int(os.environ.get("MODEL_PARTITION_SIZE", 3000000))

    def syncTPU(tensor):
        torch_xla._XLAC._xla_sync_multi([tensor],
                                        devices=[],
                                        wait=True,
                                        sync_xla_data=True)

    alldev = xm.get_xla_supported_devices()
    allrealdev = xm.xla_real_devices(alldev)
    print("Found {0} devices: {1}".format(len(allrealdev), allrealdev))

    dev = xm.xla_device()
    if (args.features > tsize):
        if args.usexlabag:
            tsplit = torch.split(h_emb.embtable.weight, tsize, dim=0)
        else:
            tsplit = torch.split(h_emb.weight, tsize, dim=0)
        tsplit = list(tsplit)
        for i, chunk in enumerate(tsplit):
            tsplit[i] = chunk.to(dev)

        t = nn.Parameter(torch.ones(10, 10))
        if args.usexlabag:
            h_emb.embtable.weight = t
            t_emb = h_emb.to(dev)
            tsplit = torch.cat(tsplit)
            t_emb.embtable.weight = nn.Parameter(tsplit)
            print("Xla EMB weight shape: ", t_emb.embtable.weight.shape,
                  " on device: ", str(dev))
        else:
            h_emb.weight = t
            t_emb = h_emb.to(dev)
            tsplit = torch.cat(tsplit)
            t_emb.weight = nn.Parameter(tsplit)
            print("EMB weight shape: ", t_emb.weight.shape, " on device: ",
                  str(dev))
    else:
        t_emb = h_emb.to(dev)

    t_indices = h_indices.to(dev)
    t_offsets = h_offsets.to(dev)

    emb_times = 0.0
    start1 = time.perf_counter()
    for i in range(warmups + steps):
        start = time.perf_counter()
        results = t_emb(t_indices, t_offsets)
        syncTPU(results)
        end = time.perf_counter()
        print("Time: {0:.6f} ".format(end - start))
        if (i >= warmups):
            emb_times += end - start

    end1 = time.perf_counter()

    return end1 - start1, emb_times, results
Example #26
0
def train_unet():
    print('==> Preparing data..')
    img_dim = 1024

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = satelliteDataSet(
        os.path.join(datadir, 'train'),
        transforms.Compose([
            transforms.RandomResizedCrop(img_dim),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(), normalize
        ]))
    train_dataset_len = len(train_dataset)

    resize_dim = max(img_dim, 256)
    test_dataset = satelliteDataSet(
        os.path.join(datadir, 'train'),
        transforms.Compose([
            transforms.Resize(resize_dim),
            transforms.CenterCrop(img_dim),
            transforms.ToTensor(),
            normalize,
        ]))

    train_sampler = None
    test_sampler = None
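    # Shard the datasets across replicas when more than one XLA process is running.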
    if xm.xrt_world_size() > 1:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset,
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            shuffle=True)
        test_sampler = torch.utils.data.distributed.DistributedSampler(
            test_dataset,
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            shuffle=False)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        sampler=train_sampler,
        shuffle=False if train_sampler else True,
        num_workers=num_workers)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=test_set_batch_size,
                                              sampler=test_sampler,
                                              shuffle=False,
                                              num_workers=num_workers)

    torch.manual_seed(42)

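    # Enumerate the XLA devices visible to this process (up to num_cores) and replicate
    # the model across them.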
    devices = xm.get_xla_supported_devices(max_devices=num_cores)

    torchvision_model = UNet()
    model_parallel = dp.DataParallel(torchvision_model, device_ids=devices)

    def train_loop_fn(model, loader, device, context):
        loss_fn = nn.CrossEntropyLoss()
        optimizer = context.getattr_or(
            'optimizer', lambda: optim.SGD(model.parameters(),
                                           lr=lr,
                                           momentum=momentum,
                                           weight_decay=1e-4))
        lr_scheduler = context.getattr_or(
            'lr_scheduler', lambda: schedulers.wrap_optimizer_with_scheduler(
                optimizer,
                scheduler_type=lr_scheduler_type,
                scheduler_divisor=lr_scheduler_divisor,
                scheduler_divide_every_n_epochs=
                lr_scheduler_divide_every_n_epochs,
                num_steps_per_epoch=num_training_steps_per_epoch,
                summary_writer=None))
        tracker = xm.RateTracker()
        model.train()
        for x, (data, target) in loader:
            optimizer.zero_grad()
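            # Incoming batches are NHWC; the model expects NCHW.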
            data = data.permute(0, 3, 1, 2)
            output = model(data)
            print('passed through model')
            loss = loss_fn(output, target)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(batch_size)
            if x % log_steps == 0:
                print(
                    'device: {}, x: {}, loss: {}, tracker: {}, tracker_global: {} '
                    .format(device, x, loss.item(), tracker.rate(),
                            tracker.global_rate()))
            if lr_scheduler:
                lr_scheduler.step()

    def test_loop_fn(model, loader, device, context):
        total_samples = 0
        correct = 0
        model.eval()
        for x, (data, target) in loader:
            output = model(data)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
            total_samples += data.size()[0]

        accuracy = 100.0 * correct / total_samples
        #test_utils.print_test_update(device, accuracy)
        print('device: {}, accuracy: {}'.format(device, accuracy))
        return accuracy

    accuracy = 0.0
    num_devices = len(
        xm.xla_replication_devices(devices)) if len(devices) > 1 else 1
    num_training_steps_per_epoch = train_dataset_len // (batch_size *
                                                         num_devices)
    for epoch in range(1, num_epochs + 1):
        model_parallel(train_loop_fn, train_loader)
        accuracies = model_parallel(test_loop_fn, test_loader)
        accuracy = mean(accuracies)
        print('Epoch: {}, Mean Accuracy: {:.2f}%'.format(epoch, accuracy))
        global_step = (epoch - 1) * num_training_steps_per_epoch
        print('global step: {}'.format(global_step))

    return accuracy
Example #27
0
def worker_fn(rank, flags):
    args = flags['args']
    world_size = flags['world_size']

    distributed = args.distributed
    is_primary = rank == 0

    mlp.logging.use_fancy_colors()

    # ########## EXPERIMENT SETUP  ###########
    torch.random.manual_seed(args.seed)  # For reproducibility

    if distributed:
        logger_name = f"[Device {rank}] {os.path.basename(__file__)}"
    else:
        logger_name = os.path.basename(__file__)

    logger = get_logger(logger_name)
    # ########################################

    # ############## DEVICE SETUP ##############
    xla_available = len(xm.get_xla_supported_devices()) > 0
    if not xla_available:
        logger.error("No XLA devices available, unable to train")
        return

    if distributed:
        logger.info(
            f"Distributed training on multiple XLA devices: using XLA device {rank}/{world_size}"
        )
    else:
        logger.info(f"Single XLA device mode: using XLA device {rank}")

    device = xm.xla_device()
    # ########################################

    # ########## SETUP BATCH DATASETS ##########
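    # The primary process loads/prepares the data first; the other processes block at the
    # rendezvous barrier and only call load_data() afterwards.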
    if distributed and not is_primary:
        xm.rendezvous("loading_data")

    training_data, test_data = load_data()

    if distributed and is_primary:
        xm.rendezvous("loading_data")

    training_sampler = None
    validation_sampler = None
    if distributed:
        training_sampler = torch.utils.data.distributed.DistributedSampler(
            training_data, num_replicas=world_size, rank=rank)
        validation_sampler = torch.utils.data.distributed.DistributedSampler(
            test_data, num_replicas=world_size, rank=rank)

    training_dataset = torch.utils.data.DataLoader(
        training_data,
        batch_size=args.batch_size,
        shuffle=(training_sampler is None),
        sampler=training_sampler,
        num_workers=3)

    # Using the test set as a validation set, just for demonstration purposes
    validation_dataset = torch.utils.data.DataLoader(
        test_data,
        batch_size=args.batch_size,
        shuffle=(validation_sampler is None),
        sampler=validation_sampler,
        num_workers=3)
    # ##########################################

    # ############ BUILD THE MODEL #############
    classifier = build_model(args.hidden_size)

    train_model = TrainModel(classifier, device)

    # Move the model to this process's assigned XLA device
    classifier.to(device)
    # ############################################

    # ############ SETUP OPTIMIZER #############
    optimizer = torch.optim.Adam(classifier.parameters(),
                                 lr=args.learning_rate)
    # ##########################################

    # ############# SETUP TRAINING ##############
    trainer = mlp.trainers.DefaultTrainer(optimizers=optimizer,
                                          model_components=classifier)

    model_hyper_parameters = {"hidden_size": args.hidden_size}

    callbacks = create_callbacks_for(trainer, args.experiment_name,
                                     model_hyper_parameters, is_primary,
                                     validation_dataset,
                                     args.progress_log_period)

    manager = mlp.trainers.TrainingManager(trainer,
                                           training_dataset,
                                           num_epochs=args.num_epochs,
                                           callbacks=callbacks,
                                           experiment_data={"args": args})

    trainer.set_training_model(train_model)
    # ##########################################

    # ################# START! #################
    manager.start_training()
    # ##########################################

    logger.info("DONE.")
Example #28
0
  def test_num_local_devices(self):
    self.assertLen(xm.get_xla_supported_devices(),
                   pjrt.addressable_device_count())
Example #29
0
  def test_get_real_xla_devices(self):
    devices = xm.get_xla_supported_devices()
    xla_devices = torch_xla._XLAC._xla_real_devices(devices)
    for device, xdevice in zip(devices, xla_devices):
      self.assertTrue(re.match(r'(CPU|GPU|TPU):\d+$', xdevice) is not None)
Example #30
0
            print("c device: ", c.device, type(c), c.dtype)
            print("c[2x2] : ", c.narrow(0, 0, 2).narrow(1, 0, 2))
            print("------")
            print(
                "GPU Time is {0:.6f} seconds, rate {1:.3f} GFlops for iter {2} "
                .format(elap1,
                        m * n * k * 2 * 1.0 / (elap1 * 1000000000 / steps),
                        steps))
            print("------\n")

    if (args.testtpu):
        # import torch_xla
        import torch_xla.core.xla_model as xm

        alldev = xm.get_xla_supported_devices()
        allrealdev = xm.xla_real_devices(alldev)
        print("Found {0} XLA devices: {1}".format(len(allrealdev), allrealdev))
        print(torch.__version__)

        dev = xm.xla_device()
        a = a.to(dev)
        b = b.to(dev)
        c = c.to(dev)
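        # Warm-up run first (covers XLA compilation), then the timed measurement.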
        measure_xla(a, b, warmups, m)
        elap1 = measure_xla(a, b, steps, m)

        print("c device: ", c.device, type(c), c.dtype)
        print("c[2x2] : ", c.narrow(0, 0, 2).narrow(1, 0, 2))
        print("------")
        print(