def test(self):
  devices = xm.get_xla_supported_devices()
  batch_size = xu.getenv_as('BATCH_SIZE', int, defval=4)
  sample_count = xu.getenv_as('SAMPLE_COUNT', int, defval=10)
  train_loader = xu.SampleGenerator(
      data=(torch.zeros(batch_size, 3, 224, 224),
            torch.zeros(batch_size, dtype=torch.int64)),
      sample_count=sample_count * len(devices))

  def loop_fn(model, loader, device, context):
    loss_fn = nn.NLLLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

    for data, target in loader:
      with xu.TimedScope(msg='Training loop: ', printfn=None):
        optimizer.zero_grad()
        output = xu.timed(lambda: model(data), msg='Model: ', printfn=None)
        loss = xu.timed(
            lambda: loss_fn(output, target), msg='Loss: ', printfn=None)
        xu.timed(loss.backward, msg='LossBkw: ', printfn=None)
        xu.timed(
            lambda: xm.optimizer_step(optimizer), msg='Step: ', printfn=None)
        self.assertLess(loss.cpu().item(), 3.0)

  model_parallel = dp.DataParallel(
      torchvision.models.resnet18, device_ids=devices)
  model_parallel(loop_fn, train_loader)
def _regular_health_check(self):
  uneven_health_timeout = xu.getenv_as('XLA_UNEVEN_HEARTBEAT_TIMEOUT', int,
                                       900)
  even_health_timeout = xu.getenv_as('XLA_EVEN_HEARTBEAT_TIMEOUT', int, 1800)
  while True:
    self._check_client_mesh_health(uneven_health_timeout,
                                   even_health_timeout)
    time.sleep(self.HEARTBEAT_CHECK_PERIOD)
def mark_step():
  if xu.getenv_as('XLA_EMIT_STEPLOG', bool, False):
    print('torch_xla.core.xla_model::mark_step', file=sys.stderr, flush=True)
  torch_xla._XLAC._xla_step_marker(
      torch_xla._XLAC._xla_get_default_device(), [],
      wait=xu.getenv_as('XLA_SYNC_WAIT', bool, False))
  # Only emit metrics from the first local device index, to avoid emitting the
  # same values from different threads.
  if is_master_ordinal():
    ms.save_metrics()
  _run_step_closures()
  _TLS.all_reduce_token = None
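# Hedged usage sketch (not one of the snippets above): in a manual torch_xla
# training loop, xm.mark_step() is called once per iteration to cut the lazy
# tensor graph and trigger execution. `train_one_epoch` and the model/loader
# arguments are illustrative names, not torch_xla APIs.
import torch_xla.core.xla_model as xm

def train_one_epoch(model, loss_fn, optimizer, loader):
  device = xm.xla_device()
  model.train()
  for data, target in loader:
    data, target = data.to(device), target.to(device)
    optimizer.zero_grad()
    loss = loss_fn(model(data), target)
    loss.backward()
    xm.optimizer_step(optimizer)  # all-reduce gradients, then optimizer.step()
    xm.mark_step()  # materialize the step; also runs step closures and metrics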
def _check_client_mesh_health(self, uneven_health_timeout,
                              even_health_timeout):
  min_delay = max(uneven_health_timeout, even_health_timeout) + 1
  count = None
  now = time.time()
  if xu.getenv_as('XLA_DEBUG_LOG_HEARTBEATS', bool, False):
    self.logger.info(
        'Worker Heartbeats: {}'.format(self._last_heartbeats),
        extra={
            'clientip': '',
            'ordinal': ''
        })
  for cw_hb in self._last_heartbeats.values():
    min_delay = min(min_delay, now - cw_hb['last_time'])
    if count is None:
      count = cw_hb['count']
    elif count >= 0 and count != cw_hb['count']:
      count = -1

  if count < 0 and min_delay > uneven_health_timeout:
    self._error_queue.put(
        RuntimeError('Client mesh is unhealthy with uneven heartbeats'))
  elif count > 0 and min_delay > even_health_timeout:
    self._error_queue.put(
        RuntimeError('Client mesh is unhealthy with even heartbeats'))
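# Hedged, self-contained sketch of the heartbeat classification above: `count`
# keeps the shared value while every worker reports the same heartbeat count
# ("even"), and is forced to -1 as soon as two workers disagree ("uneven").
# `classify_heartbeats` is an illustrative helper, not part of torch_xla.
def classify_heartbeats(counts):
  count = None
  for c in counts:
    if count is None:
      count = c
    elif count >= 0 and count != c:
      count = -1
  return 'uneven' if count is not None and count < 0 else 'even'

assert classify_heartbeats([7, 7, 7]) == 'even'
assert classify_heartbeats([7, 9, 7]) == 'uneven'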
def __init__(self, smooth_factor=None):
  self._smooth_factor = xu.getenv_as(
      'RATE_TRACKER_SMOOTHING', float,
      0.4) if smooth_factor is None else smooth_factor
  self._start_time = time.time()
  self._partial_time = self._start_time
  self._partial_count = 0.0
  self._partial_rate = None
  self._count = 0.0
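# Hedged sketch, NOT torch_xla's actual RateTracker update: a generic
# exponential moving average, shown only to illustrate what a smoothing factor
# such as RATE_TRACKER_SMOOTHING typically controls (how much weight the
# previous estimate keeps versus the newest sample). `ema_update` is a made-up
# helper for illustration.
def ema_update(previous_rate, new_rate, smooth_factor=0.4):
  if previous_rate is None:
    return new_rate
  return smooth_factor * previous_rate + (1.0 - smooth_factor) * new_rate

rate = None
for sample in (100.0, 120.0, 80.0):
  rate = ema_update(rate, sample)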
def mark_step():
  torch_xla._XLAC._xla_step_marker(
      torch_xla._XLAC._xla_get_default_device(), [],
      wait=xu.getenv_as('XLA_SYNC_WAIT', bool, False))
  # Only emit metrics from the first local device index, to avoid emitting the
  # same values from different threads.
  if is_master_ordinal():
    ms.save_metrics()
def xrt_world_size(defval=1):
  """Retrieves the number of devices taking part in the replication.

  Args:
    defval (int, optional): The default value to be returned in case there is
      no replication information available.
      Default: 1

  Returns:
    The number of devices taking part in the replication.
  """
  return xu.getenv_as(xenv.WORLD_SIZE, int, defval=defval)
def prepare_for_compare(self, tx, ty):
  print_tensors = xu.getenv_as('TEST_PRINT_TENSORS', bool, defval=False)
  x, y = tx, ty
  if type(x) == torch.Tensor:
    x = tx.to(device='cpu')
    if print_tensors:
      print('Tensor X ({}):\n{}'.format(tx.device, x), file=sys.stderr)
  if type(y) == torch.Tensor:
    y = ty.to(device='cpu')
    if print_tensors:
      print('Tensor Y ({}):\n{}'.format(ty.device, y), file=sys.stderr)
  return x, y
def assert_on_losses(args, train_loss=None, valid_loss=None):
  if xu.getenv_as('XLA_USE_BF16', bool, False):
    return
  if args.target_valid_loss is not None:
    assert valid_loss is not None and args.target_valid_loss > valid_loss, \
        'valid loss is {}, target is {}'.format(valid_loss,
                                                args.target_valid_loss)
  if args.target_train_loss is not None:
    assert train_loss is not None and args.target_train_loss > train_loss, \
        'train loss is {}, target is {}'.format(train_loss,
                                                args.target_train_loss)
def get_ordinal(defval=0): """Retrieves the replication ordinal of the current process. The ordinals range from 0 to `xrt_world_size()` minus 1. Args: defval (int, optional): The default value to be returned in case there is no replication information available. Default: 0 Returns: The replication ordinal of the current process. """ return xu.getenv_as(xenv.ORDINAL, int, defval=defval)
def get_local_ordinal(defval=0): """Retrieves the replication local ordinal of the current process. The local ordinals range from 0 to the number of local devices minus 1. Args: defval (int, optional): The default value to be returned in case there is no replication information available. Default: 0 Returns: The replication local ordinal of the current process. """ return xu.getenv_as(xenv.LOCAL_ORDINAL, int, defval=defval)
def get_local_ordinal(defval=0): """Retrieves the replication local ordinal of the current thread. The local ordinals range from 0 to the number of local devices minus 1. Args: defval (int, optional): The default value to be returned in case there is no replication information available. Default: 0 Returns: The replication local ordinal of the current thread. """ if pjrt.using_pjrt(): return pjrt.local_ordinal(defval) ordinal = xu.getenv_as(xenv.LOCAL_ORDINAL, int, defval=-1) if ordinal >= 0: return ordinal return getattr(_get_device_context(), 'device_index', defval)
def __call__(self, *args, **kwargs):
  """Performs the PyTorch operation based on XLA tensors.

  Args:
    args: The PyTorch XLA tensors which are inputs of the operation.
    kwargs: Keyword arguments passed to the lowering function. These are
      Python scalars and cannot be XLA tensors.

  Returns:
    The PyTorch tensors wrapping the values returned by the XLA lowering
    function.
  """
  shapes = xb.tensor_shape(args)
  key = pickle.dumps([shapes, kwargs])
  with self._lock:
    computation = self._computations.get(key, None)
    if computation is None:
      computation = xb.create_computation(self._name, self._opfn, shapes,
                                          **kwargs)
      self._computations[key] = computation
      if xu.getenv_as('XLA_OP_PRINT_COMPUTATIONS', bool, False):
        print(xb.get_computation_hlo(computation), file=sys.stderr)
  result = torch_xla._XLAC._xla_user_computation(self._opname, args,
                                                 computation)
  return result[0] if len(result) == 1 else result
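# Hedged sketch of the caching pattern used above, with the torch_xla builder
# replaced by a stand-in: the compiled artifact is memoized under a key derived
# from the input shapes plus the scalar kwargs, so each distinct signature is
# built only once. `CachedOp` and `build_fn` are illustrative names, not
# torch_xla APIs.
import pickle
import threading

class CachedOp:

  def __init__(self, build_fn):
    self._build = build_fn
    self._cache = {}
    self._lock = threading.Lock()

  def get(self, shapes, **kwargs):
    key = pickle.dumps([shapes, kwargs])
    with self._lock:
      computation = self._cache.get(key, None)
      if computation is None:
        computation = self._build(shapes, **kwargs)
        self._cache[key] = computation
    return computation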
def device_type() -> Optional[str]:
  """Returns the current PjRt device type."""
  return xu.getenv_as(xenv.PJRT_DEVICE, str)
def num_visible_tpu_chips(default: int = 4) -> int:
  """Returns number of TPU chips visible to current process."""
  visible_devices = xu.getenv_as(xenv.TPU_VISIBLE_DEVICES, str)
  return len(visible_devices.split(',')) if visible_devices else default
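# Hedged sketch of how the TPU_VISIBLE_DEVICES parsing above behaves for a few
# inputs. `_count_chips` is a hypothetical standalone copy of that one-line
# logic, not part of torch_xla.
def _count_chips(visible_devices, default=4):
  return len(visible_devices.split(',')) if visible_devices else default

assert _count_chips('0,1,2,3') == 4
assert _count_chips('0') == 1
assert _count_chips('') == 4  # unset or empty falls back to the default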
def __init__(self, args):
  self.args = args
  self.device = xm.xla_device()
  self.test_time = xu.getenv_as('BENCH_TEST_TIME', float, 5.0)
  torch.manual_seed(42)
    self.b = ForTest1(self, a)

    xdata = {
        2: (11, ['a', 'b'], 17),
        'w': [12, 'q', 12.33],
        17.09: set(['a', 'b', 21]),
    }
    data = ForTest2(xdata)
    wids = []

    def convert(x):
      wids.append(id(x))
      return x

    xu.for_each_instance_rewrite(data,
                                 lambda x: isinstance(x, (int, str, float)),
                                 convert)
    self.assertEqual(len(wids), 11)


if __name__ == '__main__':
  torch.set_default_tensor_type('torch.FloatTensor')
  torch.manual_seed(42)
  torch_xla._XLAC._xla_set_use_full_mat_mul_precision(
      use_full_mat_mul_precision=True)
  test = unittest.main(verbosity=FLAGS.verbosity, exit=False)
  if xu.getenv_as('METRICS_DEBUG', bool, defval=False):
    print(met.metrics_report())
  sys.exit(0 if test.result.wasSuccessful() else 1)
def xrt_world_size(defval=1):
  return xu.getenv_as(xenv.WORLD_SIZE, int, defval=defval)
def setup(self):
  self.size = xu.getenv_as('ADD_MUL_DIV_SIZE', int, 100)
  self.a = torch.rand(self.size, self.size)
  self.b = torch.rand(self.size, self.size).abs() + 1.0
def get_ordinal(defval=0):
  return xu.getenv_as(xenv.ORDINAL, int, defval=defval)
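# Hedged usage sketch of xu.getenv_as itself, the helper every snippet above
# relies on: it reads an environment variable, converts it to the requested
# type, and returns `defval` when the variable is unset. The variable names
# below are made up for illustration.
import os
import torch_xla.utils.utils as xu

os.environ['MY_BATCH_SIZE'] = '32'
batch_size = xu.getenv_as('MY_BATCH_SIZE', int, defval=4)    # -> 32
test_time = xu.getenv_as('MY_TEST_TIME', float, defval=5.0)  # unset -> 5.0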
def main_tpu(args):

  def prepare_task(args, xla_device):
    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    for valid_sub_split in args.valid_subset.split(','):
      task.load_dataset(valid_sub_split, combine=True, epoch=0)

    # Build models and criteria to print some metadata
    torch.manual_seed(args.seed)
    model, criterion = task.build_model(args), task.build_criterion(args)
    xm.master_print(model)
    xm.master_print('| model {}, criterion {}'.format(
        args.arch, criterion.__class__.__name__))
    xm.master_print('| num. model params: {} (num. trained: {})'.format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad)))
    model = model.to(xla_device)
    trainer = Trainer(args, task, model, criterion, xla_device=xla_device)
    lr = trainer.get_lr()

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
    # we overwrite distributed args here to shard data using torch_xla's
    # distributed training.
    trainer.args.distributed_rank = xm.get_ordinal()
    trainer.args.distributed_world_size = xm.xrt_world_size()
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)
    trainer.args.distributed_rank = 0
    trainer.args.distributed_world_size = 1
    trainer.meters_to_device(xla_device)
    valid_subsets = args.valid_subset.split(',')
    ordinal = xm.get_ordinal(defval=-1)
    device_str = (
        str(xla_device)
        if ordinal < 0 else '{}/{}'.format(xla_device, ordinal))
    return task, trainer, model, epoch_itr, lr, valid_subsets, device_str

  def train_loop_fn(device, trainer, loader, last_batch_index):
    """This is the main training loop. It trains for 1 epoch."""

    def print_training_update(trainer, progress, args, i):
      stats = get_training_stats(trainer, args=args)
      stats['now'] = now()
      progress.log(stats, tag='train', step=trainer.get_num_updates())
      progress.print_mid_epoch(i + 1, force=True)

    stats, log_output, skip_stat_keys = None, None, {'clip'}
    max_update = args.max_update or math.inf
    for i, samples in enumerate(loader, start=epoch_itr.iterations_in_epoch):
      if i == last_batch_index:
        # last batches are incomplete
        break
      log_output = trainer.train_step(samples)
      reset_perf_training_meters(trainer, i, ignore_index=10)
      if (not (i % args.log_steps)) or (i == last_batch_index - 1):
        step_args = trainer, progress, args, i
        xm.add_step_closure(print_training_update, args=step_args)
      num_updates = trainer.get_num_updates()
      if (not args.disable_validation and args.save_interval_updates > 0 and
          num_updates % args.save_interval_updates == 0 and num_updates > 0):
        vloss = validate_subset(args, device, trainer, task, epoch_itr,
                                valid_subsets[0])
        checkpoint_utils.save_checkpoint(
            args,
            trainer,
            epoch_itr,
            vloss.item(),
            epoch=epoch,
            end_of_epoch=False)
      if num_updates >= max_update:
        break

  def valid_loop_fn(args, device, trainer, progress, loader,
                    last_batch_index):
    extra_meters = collections.defaultdict(lambda: AverageMeter())
    for i, sample in enumerate(loader):
      if i == last_batch_index:
        # last batches are of different size, will cause recompilations
        break
      log_output = trainer.valid_step(sample)
      for k, v in log_output.items():
        if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']:
          continue
        extra_meters[k].update(v)
    stats = get_valid_stats(trainer, args)
    for k, meter in extra_meters.items():
      stats[k] = meter.avg
    return stats

  def validate_subset(args, device, trainer, task, epoch_itr, subset):
    xm.master_print('Validating the subset "{}", {}'.format(subset, now()))
    # Initialize data iterator
    # we're not sharding the validation set
    itr = task.get_batch_iterator(
        dataset=task.dataset(subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences_valid,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            trainer.get_model().max_positions(),
        ),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        seed=args.seed,
        num_workers=args.num_workers).next_epoch_itr(shuffle=False)
    progress = progress_bar.build_progress_bar(
        args,
        itr,
        epoch_itr.epoch,
        prefix='valid on {} \'{}\' subset'.format(device, subset),
        no_progress_bar='simple')
    para_loader = pl.ParallelLoader(progress, [xla_device])
    reset_validation_loss_meters(trainer)
    stats = valid_loop_fn(args, device, trainer, progress,
                          para_loader.per_device_loader(xla_device),
                          len(progress) - 1)
    progress_bar.progress_bar_print(
        progress,
        stats,
        step=trainer.get_num_updates(),
        force=True,
        tag='validate-{}'.format(subset),
        flush_writer=True)
    xm.master_print('Validated the subset "{}", {}'.format(subset, now()))
    return stats['loss'].avg

  def validate_subsets(args, device, trainer, task, epoch_itr, subsets):
    valid_losses = {
        subset: validate_subset(args, device, trainer, task, epoch_itr,
                                subset) for subset in subsets
    }
    return valid_losses

  def keep_training(lr, epoch_itr, trainer):
    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr, n_updates = trainer.get_lr(), trainer.get_num_updates()
    return ((lr > args.min_lr) and (epoch_itr.epoch < max_epoch) and
            (n_updates < max_update))

  if xu.getenv_as('XLA_USE_BF16', bool, False):
    xm.master_print(
        'WARNING: bfloat16 is enabled. Note that fairseq meters such as '
        'loss will accumulate the numerator, and increment the denominator.'
        ' Due to lack of precision in higher numbers in bfloat16, these '
        'meters will report invalid values after a while.',
        fd=sys.stderr)

  xm.master_print('Args', fd=sys.stderr)
  for key, val in args.__dict__.items():
    xm.master_print('\t{} {}'.format(key, val), fd=sys.stderr)

  # `xla_device` is `torch.device` and `device` is `str`
  xla_device = xm.xla_device()
  task, trainer, model, epoch_itr, lr, valid_subsets, device = prepare_task(
      args, xla_device)

  train_meter = StopwatchMeter()
  train_meter.start()
  while keep_training(lr, epoch_itr, trainer):
    # TRAINING
    epoch = epoch_itr.epoch + 1
    xm.master_print('Epoch {} begin {}'.format(epoch, now()))
    progress = initialize_loader_for_epoch(
        args, epoch_itr, prefix='training on {}'.format(device))
    skip_stat_keys = {'clip'}
    if args.suppress_loss_report:
      skip_stat_keys.update({'loss', 'nll_loss', 'gnorm'})
    progress.set_keys_to_skip_mid_epoch(skip_stat_keys)
    para_loader = pl.ParallelLoader(progress, [xla_device])
    train_loop_fn(device, trainer, para_loader.per_device_loader(xla_device),
                  len(progress) - 1)
    training_stats = get_training_stats(trainer, args=args)
    tloss = training_stats['loss'].avg.item()
    progress_bar.progress_bar_print(
        progress,
        training_stats,
        tag='train',
        force=True,
        step=trainer.get_num_updates(),
        log_xla_metrics=True,
        flush_writer=True)
    xm.master_print('Epoch {} end {}'.format(epoch_itr.epoch, now()))
    if args.metrics_debug:
      xm.master_print(met.metrics_report())
    reset_training_meters(trainer)

    # VALIDATION
    if not args.disable_validation and epoch_itr.epoch % args.validate_interval == 0:
      valid_losses = validate_subsets(args, device, trainer, task, epoch_itr,
                                      valid_subsets)

      # only use average first validation loss to update learning rate
      vloss = valid_losses[valid_subsets[0]].item()
      xm.master_print('old learning rate: {}'.format(lr))
      lr = trainer.lr_step(epoch_itr.epoch, vloss)
      xm.master_print('new learning rate: {}'.format(lr))
      if args.metrics_debug:
        xm.master_print(met.metrics_report())
    else:
      vloss = None

    # save checkpoint
    if epoch_itr.epoch % args.save_interval == 0:
      checkpoint_utils.save_checkpoint(
          args, trainer, epoch_itr, vloss, epoch=epoch, end_of_epoch=True)

  train_meter.stop()
  xm.master_print('| done training in {:.1f} seconds'.format(train_meter.sum))
  assert_on_losses(args, train_loss=tloss, valid_loss=vloss)