Exemple #1
0
  def test(self):
    """Trains resnet18 through DataParallel on zero-filled fake data.

    The batch size and the per-device sample count come from BATCH_SIZE and
    SAMPLE_COUNT (defaults 4 and 10, read via ``xu.getenv_as``). Every phase
    of a step is wrapped in ``xu.timed`` and the loss of each step must stay
    below 3.0.
    """
    devices = xm.get_xla_supported_devices()
    batch_size = xu.getenv_as('BATCH_SIZE', int, defval=4)
    sample_count = xu.getenv_as('SAMPLE_COUNT', int, defval=10)
    fake_images = torch.zeros(batch_size, 3, 224, 224)
    fake_labels = torch.zeros(batch_size, dtype=torch.int64)
    train_loader = xu.SampleGenerator(
        data=(fake_images, fake_labels),
        sample_count=sample_count * len(devices))

    def loop_fn(model, loader, device, context):
      criterion = nn.NLLLoss()
      sgd = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

      for inputs, labels in loader:
        with xu.TimedScope(msg='Training loop: ', printfn=None):
          sgd.zero_grad()
          # The lambdas are invoked immediately by xu.timed, so capturing the
          # loop variables here is safe.
          logits = xu.timed(lambda: model(inputs), msg='Model: ', printfn=None)
          step_loss = xu.timed(
              lambda: criterion(logits, labels), msg='Loss: ', printfn=None)
          xu.timed(step_loss.backward, msg='LossBkw: ', printfn=None)
          xu.timed(lambda: xm.optimizer_step(sgd), msg='Step: ', printfn=None)
          self.assertLess(step_loss.cpu().item(), 3.0)

    dp.DataParallel(torchvision.models.resnet18, device_ids=devices)(
        loop_fn, train_loader)
Exemple #2
0
 def _regular_health_check():
     """Checks client mesh health forever, once per HEARTBEAT_CHECK_PERIOD.

     Both timeouts are read once, before the loop starts:
     XLA_UNEVEN_HEARTBEAT_TIMEOUT (default 900) and
     XLA_EVEN_HEARTBEAT_TIMEOUT (default 1800). ``self`` is captured from the
     enclosing scope.
     """
     timeouts = (
         xu.getenv_as('XLA_UNEVEN_HEARTBEAT_TIMEOUT', int, 900),
         xu.getenv_as('XLA_EVEN_HEARTBEAT_TIMEOUT', int, 1800),
     )
     while True:
         self._check_client_mesh_health(*timeouts)
         time.sleep(self.HEARTBEAT_CHECK_PERIOD)
Exemple #3
0
def mark_step():
  """Ends the current step on the default XLA device.

  Logs a line to stderr when XLA_EMIT_STEPLOG is enabled. It then issues the
  step marker, passing XLA_SYNC_WAIT as its ``wait`` flag, and saves metrics
  on the master ordinal only. Finally it runs the pending step closures and
  clears the thread-local all-reduce token.
  """
  emit_steplog = xu.getenv_as('XLA_EMIT_STEPLOG', bool, False)
  if emit_steplog:
    print('torch_xla.core.xla_model::mark_step', file=sys.stderr, flush=True)
  default_device = torch_xla._XLAC._xla_get_default_device()
  sync_wait = xu.getenv_as('XLA_SYNC_WAIT', bool, False)
  torch_xla._XLAC._xla_step_marker(default_device, [], wait=sync_wait)
  # Metrics are saved by the first local device index only, so identical
  # values are not emitted from several threads.
  if is_master_ordinal():
    ms.save_metrics()
  _run_step_closures()
  _TLS.all_reduce_token = None
Exemple #4
0
    def _check_client_mesh_health(self, uneven_health_timeout,
                                  even_health_timeout):
        """Puts a RuntimeError on ``self._error_queue`` if the mesh looks stuck.

        Each value of ``self._last_heartbeats`` must provide ``'last_time'``
        (compared against ``time.time()``) and ``'count'``. The method finds
        the smallest delay since any worker's last heartbeat. It also checks
        whether all workers report the same heartbeat count.

        - Counts differ and the delay exceeds ``uneven_health_timeout``:
          an "uneven heartbeats" error is queued.
        - Counts agree, are positive, and the delay exceeds
          ``even_health_timeout``: an "even heartbeats" error is queued.

        If no heartbeat has been recorded yet, nothing is checked.

        Args:
          uneven_health_timeout: Maximum delay tolerated while workers'
            heartbeat counts disagree.
          even_health_timeout: Maximum delay tolerated while all workers
            report the same heartbeat count.
        """
        # Starts above both timeouts so the min() below always lowers it.
        min_delay = max(uneven_health_timeout, even_health_timeout) + 1
        count = None
        now = time.time()
        if xu.getenv_as('XLA_DEBUG_LOG_HEARTBEATS', bool, False):
            self.logger.info('Worker Heartbeats: {}'.format(
                self._last_heartbeats),
                             extra={
                                 'clientip': '',
                                 'ordinal': ''
                             })

        for cw_hb in self._last_heartbeats.values():
            min_delay = min(min_delay, now - cw_hb['last_time'])
            if count is None:
                count = cw_hb['count']
            elif count >= 0 and count != cw_hb['count']:
                # -1 marks "counts differ between workers".
                count = -1

        if count is None:
            # No heartbeats recorded yet. There is nothing to compare, and
            # comparing None with an int below would raise TypeError.
            return
        if count < 0 and min_delay > uneven_health_timeout:
            self._error_queue.put(
                RuntimeError(
                    'Client mesh is unhealthy with uneven heartbeats'))
        elif count > 0 and min_delay > even_health_timeout:
            self._error_queue.put(
                RuntimeError('Client mesh is unhealthy with even heartbeats'))
Exemple #5
0
 def __init__(self, smooth_factor=None):
     """Initializes counters and timestamps.

     Args:
       smooth_factor: Smoothing factor. When None, it is read from
         RATE_TRACKER_SMOOTHING via ``xu.getenv_as`` (default 0.4).
     """
     if smooth_factor is None:
         smooth_factor = xu.getenv_as('RATE_TRACKER_SMOOTHING', float, 0.4)
     self._smooth_factor = smooth_factor
     started_at = time.time()
     self._start_time = started_at
     self._partial_time = started_at
     self._partial_count = 0.0
     self._partial_rate = None
     self._count = 0.0
Exemple #6
0
def mark_step():
    """Issues a step marker on the default XLA device, then saves metrics.

    XLA_SYNC_WAIT (default False) is passed as the marker's ``wait`` flag.
    Metrics are saved on the master ordinal only.
    """
    device = torch_xla._XLAC._xla_get_default_device()
    should_wait = xu.getenv_as('XLA_SYNC_WAIT', bool, False)
    torch_xla._XLAC._xla_step_marker(device, [], wait=should_wait)
    # Only the first local device index saves metrics, so identical values
    # are not emitted from several threads.
    if not is_master_ordinal():
        return
    ms.save_metrics()
Exemple #7
0
def xrt_world_size(defval=1):
    """Retrieves the number of devices taking part in the replication.

    Args:
      defval (int, optional): The value returned when no replication
        information is available. Default: 1

    Returns:
      The number of devices taking part in the replication, read as an int
      from ``xenv.WORLD_SIZE`` via ``xu.getenv_as``.
    """
    world_size = xu.getenv_as(xenv.WORLD_SIZE, int, defval=defval)
    return world_size
Exemple #8
0
 def prepare_for_compare(self, tx, ty):
   """Copies tensor operands to the CPU so they can be compared.

   Non-tensor values are returned unchanged. When TEST_PRINT_TENSORS is
   enabled (read via ``xu.getenv_as``), each converted tensor is printed to
   stderr together with its original device.

   Args:
     tx: First value to compare (a tensor or any other value).
     ty: Second value to compare (a tensor or any other value).

   Returns:
     An ``(x, y)`` tuple in which tensors have been moved to the CPU.
   """
   print_tensors = xu.getenv_as('TEST_PRINT_TENSORS', bool, defval=False)

   def to_cpu(value, label):
     # isinstance (instead of an exact type() match) also handles tensor
     # subclasses such as nn.Parameter, which would otherwise stay on the
     # XLA device.
     if not isinstance(value, torch.Tensor):
       return value
     cpu_value = value.to(device='cpu')
     if print_tensors:
       print('Tensor {} ({}):\n{}'.format(label, value.device, cpu_value),
             file=sys.stderr)
     return cpu_value

   return to_cpu(tx, 'X'), to_cpu(ty, 'Y')
Exemple #9
0
def assert_on_losses(args, train_loss=None, valid_loss=None):
    """Fails when the final losses do not beat the configured targets.

    The check is skipped entirely when XLA_USE_BF16 is enabled. Otherwise,
    each of ``args.target_valid_loss`` and ``args.target_train_loss`` that is
    not None must be strictly greater than the matching loss.

    Args:
      args: Object exposing ``target_valid_loss`` and ``target_train_loss``.
      train_loss: Final training loss, or None if unavailable.
      valid_loss: Final validation loss, or None if unavailable.

    Raises:
      AssertionError: if a target is set and the matching loss is None or
        not strictly below it. It is raised explicitly rather than through
        ``assert``, so the check still runs under ``python -O``.
    """
    if xu.getenv_as('XLA_USE_BF16', bool, False):
        return
    if args.target_valid_loss is not None:
        if valid_loss is None or not args.target_valid_loss > valid_loss:
            raise AssertionError('valid loss is {}, target is {}'.format(
                valid_loss, args.target_valid_loss))
    if args.target_train_loss is not None:
        if train_loss is None or not args.target_train_loss > train_loss:
            raise AssertionError('train loss is {}, target is {}'.format(
                train_loss, args.target_train_loss))
Exemple #10
0
def get_ordinal(defval=0):
    """Retrieves the replication ordinal of the current process.

    Ordinals range from 0 to ``xrt_world_size()`` minus 1.

    Args:
      defval (int, optional): The value returned when no replication
        information is available. Default: 0

    Returns:
      The replication ordinal of the current process, read as an int from
      ``xenv.ORDINAL`` via ``xu.getenv_as``.
    """
    ordinal = xu.getenv_as(xenv.ORDINAL, int, defval=defval)
    return ordinal
Exemple #11
0
def get_local_ordinal(defval=0):
    """Retrieves the replication local ordinal of the current process.

    Local ordinals range from 0 to the number of local devices minus 1.

    Args:
      defval (int, optional): The value returned when no replication
        information is available. Default: 0

    Returns:
      The replication local ordinal of the current process, read as an int
      from ``xenv.LOCAL_ORDINAL`` via ``xu.getenv_as``.
    """
    local_ordinal = xu.getenv_as(xenv.LOCAL_ORDINAL, int, defval=defval)
    return local_ordinal
Exemple #12
0
def get_local_ordinal(defval=0):
    """Retrieves the replication local ordinal of the current thread.

    Under PjRt the value comes from ``pjrt.local_ordinal``. Otherwise a
    non-negative ``xenv.LOCAL_ORDINAL`` wins. If that is absent or negative,
    the ``device_index`` of the current device context is used (``defval``
    if it has none).

    Args:
      defval (int, optional): The value returned when no replication
        information is available. Default: 0

    Returns:
      The replication local ordinal of the current thread.
    """
    if not pjrt.using_pjrt():
        env_ordinal = xu.getenv_as(xenv.LOCAL_ORDINAL, int, defval=-1)
        if env_ordinal < 0:
            return getattr(_get_device_context(), 'device_index', defval)
        return env_ordinal
    return pjrt.local_ordinal(defval)
Exemple #13
0
  def __call__(self, *args, **kwargs):
    """Runs the user XLA operation on the given XLA tensors.

    Built computations are cached in ``self._computations``, keyed by the
    pickled ``[input shapes, kwargs]`` pair. Lookup and creation happen while
    holding ``self._lock``. With XLA_OP_PRINT_COMPUTATIONS enabled, the HLO of
    each newly built computation is printed to stderr.

    Args:
      args: The PyTorch XLA tensors which are inputs of the operation.
      kwargs: Keyword arguments passed to the lowering function. These are
        Python scalars and cannot be XLA tensors.
    Returns:
      The PyTorch tensors wrapping the values returned by the XLA lowering
      function. A single result is returned unwrapped.
    """
    shapes = xb.tensor_shape(args)
    cache_key = pickle.dumps([shapes, kwargs])
    with self._lock:
      computation = self._computations.get(cache_key)
      if computation is None:
        computation = xb.create_computation(self._name, self._opfn, shapes,
                                            **kwargs)
        self._computations[cache_key] = computation
        if xu.getenv_as('XLA_OP_PRINT_COMPUTATIONS', bool, False):
          print(xb.get_computation_hlo(computation), file=sys.stderr)
    outputs = torch_xla._XLAC._xla_user_computation(self._opname, args,
                                                    computation)
    if len(outputs) == 1:
      return outputs[0]
    return outputs
Exemple #14
0
def device_type() -> Optional[str]:
  """Returns the current PjRt device type (``xenv.PJRT_DEVICE``), if set."""
  pjrt_device = xu.getenv_as(xenv.PJRT_DEVICE, str)
  return pjrt_device
Exemple #15
0
def num_visible_tpu_chips(default: int = 4) -> int:
  """Returns the number of TPU chips visible to the current process.

  Counts the comma-separated entries of ``xenv.TPU_VISIBLE_DEVICES``. When it
  is unset or empty, ``default`` is returned instead.
  """
  visible_devices = xu.getenv_as(xenv.TPU_VISIBLE_DEVICES, str)
  if not visible_devices:
    return default
  # N separators always delimit N + 1 entries (same as len(s.split(','))).
  return visible_devices.count(',') + 1
Exemple #16
0
 def __init__(self, args):
     """Stores ``args``, selects the XLA device and seeds torch with 42.

     ``test_time`` is read from BENCH_TEST_TIME (default 5.0).
     """
     self.args = args
     self.device = xm.xla_device()
     benchmark_time = xu.getenv_as('BENCH_TEST_TIME', float, 5.0)
     self.test_time = benchmark_time
     torch.manual_seed(42)
Exemple #17
0
        self.b = ForTest1(self, a)

    xdata = {
        2: (11, ['a', 'b'], 17),
        'w': [12, 'q', 12.33],
        17.09: set(['a', 'b', 21]),
    }
    data = ForTest2(xdata)

    wids = []

    def convert(x):
      wids.append(id(x))
      return x

    xu.for_each_instance_rewrite(data,
                                 lambda x: isinstance(x, (int, str, float)),
                                 convert)
    self.assertEqual(len(wids), 11)


if __name__ == '__main__':
  # Fixed tensor type, seed and full matmul precision for the whole run.
  torch.set_default_tensor_type('torch.FloatTensor')
  torch.manual_seed(42)
  torch_xla._XLAC._xla_set_use_full_mat_mul_precision(
      use_full_mat_mul_precision=True)
  test_program = unittest.main(verbosity=FLAGS.verbosity, exit=False)
  if xu.getenv_as('METRICS_DEBUG', bool, defval=False):
    print(met.metrics_report())
  exit_code = 0 if test_program.result.wasSuccessful() else 1
  sys.exit(exit_code)
Exemple #18
0
def xrt_world_size(defval=1):
    """Returns the replication world size from ``xenv.WORLD_SIZE``, else ``defval``."""
    world_size = xu.getenv_as(xenv.WORLD_SIZE, int, defval=defval)
    return world_size
Exemple #19
0
 def setup(self):
     """Builds two random square matrices of side ADD_MUL_DIV_SIZE (default 100).

     ``self.b`` is ``|rand| + 1.0``, so every entry of it is at least 1.0.
     """
     self.size = xu.getenv_as('ADD_MUL_DIV_SIZE', int, 100)
     side = self.size
     self.a = torch.rand(side, side)
     self.b = torch.rand(side, side).abs().add(1.0)
Exemple #20
0
def get_ordinal(defval=0):
    """Returns the replication ordinal from ``xenv.ORDINAL``, else ``defval``."""
    ordinal = xu.getenv_as(xenv.ORDINAL, int, defval=defval)
    return ordinal
Exemple #21
0
def main_tpu(args):
    """Trains a fairseq model on the current XLA device.

    Builds the task, model, criterion and trainer, restores the latest
    checkpoint, then runs epochs while ``keep_training`` holds. Each epoch:

    - trains through a ``pl.ParallelLoader``;
    - prints training stats;
    - every ``args.validate_interval`` epochs, validates and steps the
      learning rate with the first validation subset's loss;
    - every ``args.save_interval`` epochs, saves a checkpoint.

    At the end the last train/valid losses are checked with
    ``assert_on_losses``.

    Args:
      args: The fairseq argument namespace (accessed by attribute).
    """

    def prepare_task(args, xla_device):
        # Setup task, e.g., translation, language modeling, etc.
        task = tasks.setup_task(args)

        # Load valid dataset (we load training data below, based on the latest checkpoint)
        for valid_sub_split in args.valid_subset.split(','):
            task.load_dataset(valid_sub_split, combine=True, epoch=0)

        # Build models and criteria to print some metadata
        torch.manual_seed(args.seed)
        model, criterion = task.build_model(args), task.build_criterion(args)
        xm.master_print(model)
        xm.master_print('| model {}, criterion {}'.format(
            args.arch, criterion.__class__.__name__))
        xm.master_print('| num. model params: {} (num. trained: {})'.format(
            sum(p.numel() for p in model.parameters()),
            sum(p.numel() for p in model.parameters() if p.requires_grad)))
        model = model.to(xla_device)
        trainer = Trainer(args, task, model, criterion, xla_device=xla_device)
        lr = trainer.get_lr()

        # Load the latest checkpoint if one is available and restore the
        # corresponding train iterator
        # we overwrite distributed args here to shard data using torch_xla's
        # distributed training.
        trainer.args.distributed_rank = xm.get_ordinal()
        trainer.args.distributed_world_size = xm.xrt_world_size()
        extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)
        trainer.args.distributed_rank = 0
        trainer.args.distributed_world_size = 1
        trainer.meters_to_device(xla_device)
        valid_subsets = args.valid_subset.split(',')
        ordinal = xm.get_ordinal(defval=-1)
        device_str = (
            str(xla_device) if ordinal < 0 else
            '{}/{}'.format(xla_device, ordinal)
        )
        return task, trainer, model, epoch_itr, lr, valid_subsets, device_str

    def train_loop_fn(device, trainer, loader, last_batch_index):
        """
        This is the main training loop. It trains for 1 epoch.

        ``progress``, ``epoch``, ``epoch_itr``, ``task`` and ``valid_subsets``
        are not parameters. They are closure variables read from the
        enclosing ``main_tpu`` scope, where they are bound before this
        function is called.
        """

        def print_training_update(trainer, progress, args, i):
            stats = get_training_stats(trainer, args=args)
            stats['now'] = now()
            progress.log(stats, tag='train', step=trainer.get_num_updates())
            progress.print_mid_epoch(i+1, force=True)

        max_update = args.max_update or math.inf
        for i, samples in enumerate(loader, start=epoch_itr.iterations_in_epoch):
            if i == last_batch_index:
                # last batches are incomplete
                break
            trainer.train_step(samples)
            reset_perf_training_meters(trainer, i, ignore_index=10)
            if (not (i % args.log_steps)) or (i == last_batch_index-1):
                step_args = trainer, progress, args, i
                xm.add_step_closure(print_training_update, args=step_args)
            num_updates = trainer.get_num_updates()
            if (
                not args.disable_validation
                and args.save_interval_updates > 0
                and num_updates % args.save_interval_updates == 0
                and num_updates > 0
            ):
                vloss = validate_subset(
                    args, device, trainer, task, epoch_itr, valid_subsets[0]
                )
                checkpoint_utils.save_checkpoint(
                    args, trainer, epoch_itr, vloss.item(),
                    epoch=epoch, end_of_epoch=False,
                )
            if num_updates >= max_update:
                break


    def valid_loop_fn(
        args, device, trainer, progress, loader, last_batch_index
    ):
        extra_meters = collections.defaultdict(AverageMeter)
        for i, sample in enumerate(loader):
            if i == last_batch_index:
                # last batches are of different size, will cause recompilations
                break
            log_output = trainer.valid_step(sample)
            for k, v in log_output.items():
                if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']:
                    continue
                extra_meters[k].update(v)
        stats = get_valid_stats(trainer, args)
        for k, meter in extra_meters.items():
            stats[k] = meter.avg
        return stats

    def validate_subset(args, device, trainer, task, epoch_itr, subset):
        xm.master_print('Validating the subset "{}", {}'.format(subset, now()))
        # Initialize data iterator
        # we're not sharding the validation set
        itr = task.get_batch_iterator(
            dataset=task.dataset(subset),
            max_tokens=args.max_tokens,
            max_sentences=args.max_sentences_valid,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                trainer.get_model().max_positions(),
            ),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=args.required_batch_size_multiple,
            seed=args.seed,
            num_workers=args.num_workers
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.build_progress_bar(
            args, itr, epoch_itr.epoch,
            prefix='valid on {} \'{}\' subset'.format(device, subset),
            no_progress_bar='simple'
        )
        # `xla_device` is a closure variable bound in the enclosing scope.
        para_loader = pl.ParallelLoader(progress, [xla_device])
        reset_validation_loss_meters(trainer)
        stats = valid_loop_fn(
            args, device, trainer, progress,
            para_loader.per_device_loader(xla_device), len(progress) - 1
        )
        progress_bar.progress_bar_print(
            progress, stats, step=trainer.get_num_updates(), force=True,
            tag='validate-{}'.format(subset), flush_writer=True,
        )
        xm.master_print('Validated the subset "{}", {}'.format(subset, now()))
        return stats['loss'].avg

    def validate_subsets(args, device, trainer, task, epoch_itr, subsets):
        valid_losses = {
            subset: validate_subset(
                args, device, trainer, task, epoch_itr, subset
            )
            for subset in subsets
        }
        return valid_losses

    def keep_training(lr, epoch_itr, trainer):
        # Train until the learning rate gets too small
        max_epoch = args.max_epoch or math.inf
        max_update = args.max_update or math.inf
        # The `lr` argument is ignored: the current value is re-read from the
        # trainer.
        lr, n_updates = trainer.get_lr(), trainer.get_num_updates()
        return ((lr > args.min_lr) and (epoch_itr.epoch < max_epoch) and
            (n_updates < max_update))

    if xu.getenv_as('XLA_USE_BF16', bool, False):
        xm.master_print(
            'WARNING: bfloat16 is enabled. Note that fairseq meters such as '
            'loss will accumulate the numerator, and increment the denominator.'
            ' Due to lack of precision in higher numbers in bfloat16, these '
            'meters will report invalid values after a while.',
            fd=sys.stderr
        )

    xm.master_print('Args', fd=sys.stderr)
    for key, val in args.__dict__.items():
        xm.master_print('\t{} {}'.format(key, val), fd=sys.stderr)
    # `xla_device` is `torch.device` and `device` is `str`
    xla_device = xm.xla_device()
    task, trainer, model, epoch_itr, lr, valid_subsets, device = prepare_task(
        args, xla_device)

    train_meter = StopwatchMeter()
    train_meter.start()
    # Bound up front so the final loss check does not hit an unbound local
    # when the loop body never runs (e.g. resuming an already finished run).
    tloss = None
    vloss = None
    while keep_training(lr, epoch_itr, trainer):
        # TRAINING
        epoch = epoch_itr.epoch + 1
        xm.master_print('Epoch {} begin {}'.format(epoch, now()))
        progress = initialize_loader_for_epoch(
            args, epoch_itr, prefix='training on {}'.format(device),
        )
        skip_stat_keys = {'clip'}
        if args.suppress_loss_report:
            skip_stat_keys.update({'loss', 'nll_loss', 'gnorm'})
        progress.set_keys_to_skip_mid_epoch(skip_stat_keys)
        para_loader = pl.ParallelLoader(progress, [xla_device])
        train_loop_fn(
            device, trainer, para_loader.per_device_loader(xla_device),
            len(progress) - 1
        )
        training_stats = get_training_stats(trainer, args=args)
        tloss = training_stats['loss'].avg.item()
        progress_bar.progress_bar_print(
            progress, training_stats, tag='train', force=True,
            step=trainer.get_num_updates(), log_xla_metrics=True,
            flush_writer=True,
        )
        xm.master_print('Epoch {} end {}'.format(epoch_itr.epoch, now()))
        if args.metrics_debug:
            xm.master_print(met.metrics_report())
        reset_training_meters(trainer)

        # VALIDATION
        if not args.disable_validation and epoch_itr.epoch % args.validate_interval == 0:
            valid_losses = validate_subsets(
                args, device, trainer, task, epoch_itr, valid_subsets
            )

            # only use average first validation loss to update learning rate
            vloss = valid_losses[valid_subsets[0]].item()
            xm.master_print('old learning rate: {}'.format(lr))
            lr = trainer.lr_step(epoch_itr.epoch, vloss)
            xm.master_print('new learning rate: {}'.format(lr))
            if args.metrics_debug:
                xm.master_print(met.metrics_report())
        else:
            vloss = None

        # save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            checkpoint_utils.save_checkpoint(
                args, trainer, epoch_itr, vloss,
                epoch=epoch, end_of_epoch=True,
            )

    train_meter.stop()
    xm.master_print('| done training in {:.1f} seconds'.format(train_meter.sum))
    assert_on_losses(args, train_loss=tloss, valid_loss=vloss)