def collate_tokens_tpu(values,
                       pad_idx,
                       eos_idx=None,
                       left_pad=False,
                       move_eos_to_beginning=False):
  # Copied over from fairseq.data_utils, and modified so that num_columns
  # in the output tensor is not too variable.

  # correcting columns
  global PAD_TO_LENGTH
  size = max(v.size(0) for v in values)
  if size > PAD_TO_LENGTH:
    xu.eprint(
        'I had to change PAD_TO_LENGTH from {} to {}, this is going to '
        'trigger graph recompiles'.format(PAD_TO_LENGTH, size))
    PAD_TO_LENGTH = size
  size = PAD_TO_LENGTH
  # done correcting
  res = values[0].new(len(values), size).fill_(pad_idx)

  def copy_tensor(src, dst):
    assert dst.numel() == src.numel()
    if move_eos_to_beginning:
      assert src[-1] == eos_idx
      dst[0] = eos_idx
      dst[1:] = src[:-1]
    else:
      dst.copy_(src)

  for i, v in enumerate(values):
    copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][:len(v)])
  return res
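
# Example (hypothetical values, assuming PAD_TO_LENGTH has been initialized,
# e.g. from args.pad_to_length): every batch is padded out to the same number
# of columns, so XLA sees a fixed input shape and avoids graph recompiles.
#
#   PAD_TO_LENGTH = 8
#   batch = [torch.tensor([4, 5, 6, 7, 2]), torch.tensor([9, 3, 2])]
#   padded = collate_tokens_tpu(batch, pad_idx=1, eos_idx=2)
#   # padded.shape == (2, 8); shorter rows are right-padded with pad_idx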
def parse_args():
  # We need to control certain flags here.
  # e.g. parallelization needs to be suppressed and deferred to torch_xla flags
  # e.g. input tensor shapes need to be controlled via
  #   max_sentences, required_batch_size_multiple
  parser = options.get_training_parser()
  parser.add_argument('--num_cores', type=int, default=8)
  parser.add_argument('--pad_to_length', type=int, default=64)
  parser.add_argument('--log_steps', type=int, default=20)
  parser.add_argument('--use_gpu', action='store_true')
  parser.add_argument('--metrics_debug', action='store_true')
  FLAGS = options.parse_args_and_arch(parser)
  if not FLAGS.use_gpu:
    if FLAGS.fp16:
      raise RuntimeError(
          '--fp16 was provided, this is controlled by env var XLA_USE_BF16')
    if FLAGS.distributed_world_size > 1:
      xu.eprint('suppressing "distributed_world_size"')
      FLAGS.distributed_world_size = 1
    if FLAGS.distributed_init_method is not None:
      xu.eprint('suppressing "distributed_init_method"')
      FLAGS.distributed_init_method = None
    if FLAGS.max_sentences != FLAGS.required_batch_size_multiple:
      batch_size = max(
          filter(lambda r: r is not None,
                 [FLAGS.max_sentences, FLAGS.required_batch_size_multiple]))
      xu.eprint(
          '"max_sentences" and "required_batch_size_multiple" must be equal'
          ' to have good performance on TPUs. Using {}'.format(batch_size))
      FLAGS.max_sentences = batch_size
      FLAGS.required_batch_size_multiple = batch_size
    if (FLAGS.max_sentences_valid is not None and
        FLAGS.max_sentences_valid != FLAGS.max_sentences):
      FLAGS.max_sentences_valid = FLAGS.max_sentences
      xu.eprint('"max_sentences_valid" and "max_sentences" must be equal'
                ' to have good performance on TPUs. Using {}'.format(
                    FLAGS.max_sentences))
    if FLAGS.max_tokens is not None:
      xu.eprint('"max_tokens" needs to be None for better TPU performance')
      FLAGS.max_tokens = None
  return FLAGS
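
# Example invocation (hypothetical script name, data path and architecture),
# combining the TPU-specific flags defined above with standard fairseq
# training flags; max-sentences and required-batch-size-multiple are kept
# equal, as parse_args() enforces:
#
#   python fairseq_train_tpu.py $DATA_DIR \
#     --arch transformer_vaswani_wmt_en_de_big \
#     --max-sentences 64 --required-batch-size-multiple 64 \
#     --num_cores 8 --pad_to_length 64 --log_steps 20 --metrics_debug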
def run(self):
  bench_name = self._get_parent_class().__name__
  try:
    self.setup()
    # Do one warmup run.
    self.bench()
  except Exception as e:
    xu.eprint('Failed running benchmark "{}": {}'.format(bench_name, e))
    return
  try:
    start = time.time()
    now = start
    count = 0
    while self.test_time > (now - start):
      self.bench()
      count += 1
      now = time.time()
    print('{}: {:.3f}ms per loop'.format(bench_name,
                                         1000.0 * (now - start) / count))
    xu.get_print_fn()(torch_xla._XLAC._xla_metrics_report())
  except Exception as e:
    xu.eprint('Failed running benchmark "{}": {}'.format(bench_name, e))
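
# Minimal sketch of the interface run() relies on (the benchmark base class is
# not shown in this section, so the names below are assumptions): a subclass is
# expected to provide setup(), bench() and a test_time value in seconds, plus
# _get_parent_class() for the report name.
#
#   class MatMulBench(BenchBase):   # BenchBase is hypothetical here
#     test_time = 5.0               # keep looping bench() for ~5 seconds
#
#     def setup(self):
#       self.a = torch.randn(256, 256, device=xm.xla_device())
#
#     def bench(self):
#       torch.mm(self.a, self.a)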
def main_tpu(args):

  def log_step(step_type, device, step, tracker=None, metrics_debug=False):
    msg = '{}/ {}, device {}, step {}'.format(step_type, utils.now(), device,
                                              step)
    if tracker:
      rates = tracker.rate(), tracker.global_rate()
      msg += ', Rate={:.2f}, Global Rate={:.2f}'.format(*rates)
    return msg

  def train_loop_fn(model, loader, device, context):
    trainer = trainers[str(device)]
    stats = None
    tracker = xm.RateTracker()
    for i, samples in loader:
      if i and not (i % args.log_steps):
        print(
            log_step(
                'training',
                device,
                i,
                tracker=tracker,
                metrics_debug=args.metrics_debug))
      _log_output = trainer.train_step(samples)
      xm.optimizer_step(trainer.optimizer)
      tracker.add(len(samples) * args.max_sentences)  # n_batches * batch_size
    stats = fairseq_train.get_training_stats(trainer)
    return tracker, stats

  def valid_loop_fn(model, loader, device, context):
    trainer = trainers[str(device)]
    # reset validation loss meters
    for k in ['valid_loss', 'valid_nll_loss']:
      meter = trainer.get_meter(k)
      if meter is not None:
        meter.reset()
    extra_meters = collections.defaultdict(lambda: AverageMeter())
    for i, sample in loader:
      if not (i % args.log_steps):
        print(
            log_step(
                'validation',
                device,
                i,
                tracker=None,
                metrics_debug=args.metrics_debug))
      log_output = trainer.valid_step(sample)
      for k, v in log_output.items():
        if k in ['loss', 'nll_loss', 'ntokens', 'nsentences', 'sample_size']:
          continue
        extra_meters[k].update(v)
    stats = fairseq_train.get_valid_stats(trainer)
    for k, meter in extra_meters.items():
      stats[k] = meter.avg
    return stats

  def validate_subset(args, trainers, task, epoch_itr, subset):
    print('Validating the subset "{}"'.format(subset))
    # Initialize data iterator
    itr = task.get_batch_iterator(
        dataset=task.dataset(subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences_valid,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            list(trainers.values())[0].get_model().max_positions(),
        ),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        seed=args.seed,
        num_workers=args.num_workers).next_epoch_itr(shuffle=False)
    progress = progress_bar.build_progress_bar(
        args,
        itr,
        epoch_itr.epoch,
        prefix='valid on \'{}\' subset'.format(subset),
        no_progress_bar='simple')
    stats_per_device = model_parallel(valid_loop_fn, progress)
    valid_losses = [stats['loss'].avg for stats in stats_per_device]
    print('validation stats on subset "{}" - {}'.format(subset, utils.now()))
    for stats in stats_per_device:
      # all trainers step in lockstep, so the first one's update count is used
      progress.print(
          stats,
          tag=subset,
          step=list(trainers.values())[0].get_num_updates())
    return valid_losses

  def validate(args, trainers, task, epoch_itr, subsets):
    valid_losses = {
        subset: validate_subset(args, trainers, task, epoch_itr, subset)
        for subset in subsets
    }
    return valid_losses

  def initialize_loader_for_epoch(args, epoch_itr):
    if epoch_itr.epoch <= len(args.update_freq):
      update_freq = args.update_freq[epoch_itr.epoch - 1]
    else:
      update_freq = args.update_freq[-1]
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=False,
        shuffle=(epoch_itr.epoch >= args.curriculum))
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args, itr, epoch_itr.epoch, no_progress_bar='simple')
    return progress

  def keep_training(lr, epoch_itr, trainers):
    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = min(trainer.get_lr() for trainer in trainers.values())
    n_updates = max(trainer.get_num_updates() for trainer in trainers.values())
    return ((lr > args.min_lr) and (epoch_itr.epoch < max_epoch) and
            (n_updates < max_update))

  xu.eprint('Args')
  for key, val in args.__dict__.items():
    xu.eprint('\t{} {}'.format(key, val))
  xu.eprint('---------')

  devices = xm.get_xla_supported_devices(max_devices=args.num_cores)
  task, trainers, model_parallel, epoch_itr, lr, valid_subsets = prepare_task(
      args, devices)

  train_meter = StopwatchMeter()
  train_meter.start()
  while keep_training(lr, epoch_itr, trainers):
    # TRAINING
    print('Epoch {} begin {}'.format(epoch_itr.epoch + 1, utils.now()))
    progress = initialize_loader_for_epoch(args, epoch_itr)
    out = model_parallel(train_loop_fn, progress)
    trackers, stats_ = zip(*out)
    print('Epoch {} Training stats:'.format(epoch_itr.epoch))
    for device, trainer in trainers.items():
      stats = fairseq_train.get_training_stats(trainer)
      print('device {}'.format(device))
      progress.print(stats, tag=device)
    print('Epoch {} Tracker Rates:'.format(epoch_itr.epoch))
    for tracker in trackers:
      rates = tracker.rate(), tracker.global_rate()
      print('\tRate={:.2f}, Global Rate={:.2f}'.format(*rates))
    print('Epoch {} end {}'.format(epoch_itr.epoch, utils.now()))
    if args.metrics_debug:
      print(torch_xla._XLAC._xla_metrics_report())

    # VALIDATION
    if not args.disable_validation and epoch_itr.epoch % args.validate_interval == 0:
      valid_losses = validate(args, trainers, task, epoch_itr, valid_subsets)

      # only use average first validation loss from the first device
      # to update the learning rate
      vloss = valid_losses[valid_subsets[0]][0]
      print('old learning rate: {}'.format(lr))
      lr = trainers[devices[0]].lr_step(epoch_itr.epoch, vloss)
      print('new learning rate: {}'.format(lr))

      # save checkpoint
      if epoch_itr.epoch % args.save_interval == 0:
        checkpoint_utils.save_checkpoint(args, trainers[devices[0]], epoch_itr,
                                         vloss)

    if args.metrics_debug:
      print(torch_xla._XLAC._xla_metrics_report())

  train_meter.stop()
  print('| done training in {:.1f} seconds'.format(train_meter.sum))
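
# Hypothetical entry point sketch (the actual __main__ block is outside this
# section, and may also handle the --use_gpu path): parse the TPU-aware flags,
# then run the training loop above.
#
#   if __name__ == '__main__':
#     FLAGS = parse_args()
#     main_tpu(FLAGS)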
def _force_release_tensors():
  count = torch_xla._XLAC._xla_force_release_all_data()
  xu.eprint('Forcefully released %d device handles' % count)
  torch_xla._XLAC._xla_flush_lazy_releases()
  return count