def valid_step(self, sample):
    """Do forward pass in evaluation mode."""
    self.model.eval()
    self._num_val_iterations += 1

    # forward pass
    sample, sample_size = self._prepare_sample(sample)
    with torch.no_grad():
        loss, oom_fwd = self._forward(sample)
    logging_output = {
        'ntokens': sample['ntokens'] if sample is not None else 0,
        'nsentences': sample['target'].size(0) if sample is not None else 0,
        'sample_size': sample_size
    }
    loss = loss.item() if loss is not None else 0
    assert not oom_fwd, 'Ran out of memory during validation'

    # gather logging outputs from all GPUs
    if self.args.distributed_world_size > 1:
        losses, sample_sizes, logging_outputs = zip(*distributed_utils.all_gather_list(
            (loss, sample_size, logging_output)
        ))
    else:
        losses = [loss]
        sample_sizes = [sample_size]
        logging_outputs = [logging_output]

    # TODO: check when ntokens != sample_size
    ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
    weight = sum(log.get('sample_size', 0) for log in logging_outputs)
    scaled_loss = sum(losses) / weight / math.log(2)
    return scaled_loss
def all_gather_from_master(args, data: List) -> List:
    if args.distributed_world_size == 1:
        return data

    gathered_data = distributed_utils.all_gather_list(data)
    # Converts [[x0, y0, z0, ...], [x1, y1, z1, ...], [x2, y2, z2, ...], ...]
    # to [[x0, x1, x2, ...], [y0, y1, y2, ...], [z0, z1, z2, ...], ...]
    gathered_data_list = list(zip(*gathered_data))

    output_data = []
    for data_index, all_data in enumerate(gathered_data_list):
        # The master's (process 0) data is guaranteed to be in position 0.
        master_data = all_data[0]
        # Sanity check that only the master returned any result.
        if master_data is None:
            raise RuntimeError(
                f"Input data element {data_index} of all_gather_from_master "
                f"returned None from master. Results from all processes: {all_data}"
            )
        for i in range(1, len(all_data)):
            if all_data[i] is not None:
                raise RuntimeError(
                    f"Input data element {data_index} of all_gather_from_master "
                    f"should have returned None from non-master process {i}. "
                    f"Results from all processes: {all_data}")
        output_data.append(master_data)
    return output_data
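# A minimal usage sketch for all_gather_from_master (not part of the snippet
# above): only the master rank computes some evaluation statistics while every
# other rank passes None placeholders of the same length, so the master's
# values end up on every process. `compute_corpus_stats` and
# `args.distributed_rank` are hypothetical stand-ins for whatever the caller
# actually uses.
def broadcast_master_stats(args):
    if args.distributed_rank == 0:
        stats = list(compute_corpus_stats())  # hypothetical, e.g. [bleu, chrf]
    else:
        stats = [None, None]
    # every rank now receives the master's [bleu, chrf]
    bleu, chrf = all_gather_from_master(args, stats)
    return bleu, chrf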
def valid_step(self, sample, raise_oom=False):
    """Do forward pass in evaluation mode."""
    with torch.no_grad():
        self.model.eval()
        self.criterion.eval()

        sample = self._prepare_sample(sample)
        if sample is None:
            sample = self._prepare_sample(self._dummy_batch)
            ignore_results = True
        else:
            ignore_results = False

        try:
            _loss, sample_size, logging_output = self.task.valid_step(
                sample, self.model, self.criterion)
        except RuntimeError as e:
            if 'out of memory' in str(e) and not raise_oom:
                print('| WARNING: ran out of memory, retrying batch')
                for p in self.model.parameters():
                    if p.grad is not None:
                        p.grad = None  # free some memory
                if self.cuda:
                    torch.cuda.empty_cache()
                return self.valid_step(sample, raise_oom=True)
            else:
                raise e

        if ignore_results:
            logging_output, sample_size = {}, 0

    # gather logging outputs from all replicas
    if self.args.distributed_world_size > 1:
        logging_output, sample_size = zip(
            *distributed_utils.all_gather_list(
                [logging_output, sample_size],
            ))
        logging_output = list(logging_output)
        sample_size = list(sample_size)
    else:
        logging_output = [logging_output]
        sample_size = [sample_size]

    # aggregate logging outputs and sample sizes
    logging_output = self.task.aggregate_logging_outputs(
        logging_output, self.get_criterion())
    sample_size = self.task.grad_denom(sample_size, self.get_criterion())

    # update meters for validation
    ntokens = logging_output.get('ntokens', 0)
    self.meters['valid_loss'].update(logging_output.get('loss', 0), sample_size)
    if 'valid_acc' in self.meters:
        self.meters['valid_acc'].update(logging_output.get('acc', 0), sample_size)
    if 'nll_loss' in logging_output:
        self.meters['valid_nll_loss'].update(
            logging_output.get('nll_loss', 0), ntokens)

    return logging_output
def validate(self, data_loader, args):
    was_training = self.training
    self.eval()

    keys = ['loss', 'sample_size']
    logging_info_list = []
    with torch.no_grad():
        with tqdm(total=len(data_loader), desc=f"Evaluation", file=sys.stdout) as pbar:
            for step, batch in enumerate(data_loader):
                loss_sum, logging_info = self(**batch)
                logging_info = {k: logging_info[k] for k in keys}
                logging_info_list.append(logging_info)
                pbar.update(1)

    if was_training:
        self.train()

    stats = {k: sum(x[k] for x in logging_info_list) for k in keys}

    # handle distributed evaluation
    if args.multi_gpu:
        stats = distributed_utils.all_gather_list(stats)
        stats = {k: sum(x[k] for x in stats) for k in keys}

    valid_result = {'ppl': math.exp(stats['loss'] / stats['sample_size'])}
    return valid_result
def valid_step(self, sample):
    """Do forward pass in evaluation mode."""
    # forward pass
    sample = self._prepare_sample(sample)
    _loss, sample_size, logging_output, oom_fwd = self._forward(sample, eval=True)
    assert not oom_fwd, 'Ran out of memory during validation'

    # gather logging outputs from all GPUs
    if self.args.distributed_world_size > 1:
        sample_sizes, logging_outputs = zip(
            *distributed_utils.all_gather_list((sample_size, logging_output)))
    else:
        sample_sizes = [sample_size]
        logging_outputs = [logging_output]

    # aggregate stats and logging outputs
    ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
    grad_denom = self.criterion.__class__.grad_denom(sample_sizes)
    agg_logging_output = self.criterion.__class__.aggregate_logging_outputs(
        logging_outputs)

    # update loss meters for validation
    if 'loss' in agg_logging_output:
        self.meters['valid_loss'].update(agg_logging_output['loss'], grad_denom)
    # criterions can optionally log the NLL loss too
    if 'nll_loss' in agg_logging_output:
        self.meters['valid_nll_loss'].update(
            agg_logging_output['nll_loss'], ntokens)

    return agg_logging_output
def _all_gather_list_sync(
    self,
    logging_outputs: List[Dict[str, Any]],
    *extra_stats_to_sum,
    ignore=False,
):
    """
    Sync logging outputs across workers. all_gather_list_sync is
    suitable when logging outputs are complex types.
    """
    if self.tpu:
        raise NotImplementedError
    if ignore:
        logging_outputs = []
    results = list(
        zip(*distributed_utils.all_gather_list(
            [logging_outputs] + list(extra_stats_to_sum),
            max_size=getattr(self.cfg.common, "all_gather_list_size", 16384),
            group=self.data_parallel_process_group,
        )))
    logging_outputs, extra_stats_to_sum = results[0], results[1:]
    logging_outputs = list(chain.from_iterable(logging_outputs))
    extra_stats_to_sum = [sum(s) for s in extra_stats_to_sum]
    return logging_outputs, extra_stats_to_sum
def _all_gather_predictions(predictions):
    ready = False
    all_ready = False
    reduced_predictions = []
    max_size = 65000
    while not all_ready:
        lst_len = len(predictions)
        size = 2000  # some extra space for python stuff
        n = 0
        while n < lst_len:
            str_len = len(predictions[n].encode('utf8')) + 8  # per string pickle overhead
            if size + str_len >= max_size:
                break
            size += str_len
            n += 1
        chunk = predictions[:n]
        predictions = predictions[n:]
        if not predictions:
            ready = True
        chunk = (ready, chunk)
        torch.cuda.synchronize()
        gathered = distributed_utils.all_gather_list(chunk, max_size=65000)
        torch.cuda.synchronize()
        reduced_predictions += [t[1] for t in gathered]
        all_ready = all([t[0] for t in gathered])

    reduced_predictions = [
        item for sublist in reduced_predictions for item in sublist
    ]

    return reduced_predictions
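# Hypothetical driver for the chunked gather above: each rank decodes its own
# shard of the data and then exchanges the resulting strings in pickled chunks
# that stay under all_gather_list's 65000-byte buffer. `decode_shard` is an
# assumed helper, not part of the snippet.
local_predictions = decode_shard()  # list of str produced by this rank
all_predictions = _all_gather_predictions(local_predictions)
# Note: chunks from different ranks are interleaved round by round, so sort by
# sample id afterwards if a deterministic global order is required.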
def _update_params(self):
    # gather logging outputs from all replicas
    sample_sizes = self._buffered_stats['sample_sizes']
    logging_outputs = self._buffered_stats['logging_outputs']
    ooms_fwd = self._buffered_stats['ooms_fwd']
    ooms_bwd = self._buffered_stats['ooms_bwd']
    if self.args.distributed_world_size > 1:
        sample_sizes, logging_outputs, ooms_fwd, ooms_bwd = map(
            lambda l: list(chain.from_iterable(l)),
            zip(*distributed_utils.all_gather_list(
                (sample_sizes, logging_outputs, ooms_fwd, ooms_bwd))))
    ooms_fwd = sum(ooms_fwd)
    ooms_bwd = sum(ooms_bwd)

    if ooms_fwd == self.args.distributed_world_size:
        print('| WARNING: OOM in all workers, skipping batch')
        self.zero_grad()
        return None

    # aggregate stats and logging outputs
    ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
    nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
    agg_logging_output = self.criterion.__class__.aggregate_logging_outputs(
        logging_outputs)
    grad_denom = self.criterion.__class__.grad_denom(sample_sizes)

    try:
        # all-reduce and rescale gradients, then take an optimization step
        grad_norm = self._all_reduce_and_rescale(grad_denom)
        self._opt()

        # update meters
        self.meters['wps'].update(ntokens)
        self.meters['ups'].update(1.)
        self.meters['wpb'].update(ntokens)
        self.meters['bsz'].update(nsentences)
        if grad_norm is not None:
            self.meters['gnorm'].update(grad_norm)
            self.meters['clip'].update(
                1. if grad_norm > self.args.clip_norm else 0.)
        self.meters['oom'].update(ooms_fwd + ooms_bwd)

        # update loss meters for training
        if 'loss' in agg_logging_output:
            self.meters['train_loss'].update(agg_logging_output['loss'], grad_denom)
        # criterions can optionally log the NLL loss too
        if 'nll_loss' in agg_logging_output:
            self.meters['train_nll_loss'].update(
                agg_logging_output['nll_loss'], ntokens)
    except OverflowError as e:
        self.zero_grad()
        print('| WARNING: overflow detected, ' + str(e))

    self.clear_buffered_stats()

    return agg_logging_output
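from itertools import chain

# Illustration of the gather-and-flatten idiom used in _update_params, shown
# without the actual collective (a sketch assuming two workers that each
# buffered two micro-batches; all_gather_list would return one tuple per rank).
gathered = [
    ([3, 5], [{'loss': 1.0}, {'loss': 2.0}], [0, 0], [0, 0]),  # worker 0
    ([4, 6], [{'loss': 1.5}, {'loss': 2.5}], [1, 0], [0, 0]),  # worker 1
]
sample_sizes, logging_outputs, ooms_fwd, ooms_bwd = map(
    lambda l: list(chain.from_iterable(l)), zip(*gathered))
assert sample_sizes == [3, 5, 4, 6]
assert sum(ooms_fwd) == 1 and sum(ooms_bwd) == 0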
def validate(self, data_loader, args):
    gc.collect()

    keys = [
        'masked_context_token_loss',
        'masked_context_token_num',
        'masked_column_token_loss',
        'masked_column_token_num'
    ]
    if self.config.predict_cell_tokens:
        keys += [
            'masked_cell_token_loss',
            'masked_cell_token_num'
        ]

    was_training = self.training
    self.eval()

    logging_info_list = []
    with torch.no_grad():
        with tqdm(total=len(data_loader), desc=f"Evaluation", file=sys.stdout) as pbar:
            for step, batch in enumerate(data_loader):
                loss_sum, logging_info = self(**batch)
                logging_info = {k: logging_info[k] for k in keys}
                logging_info_list.append(logging_info)
                pbar.update(1)

    if was_training:
        self.train()

    stats = {
        k: sum(x[k] for x in logging_info_list)
        for k in keys
    }

    # handle distributed evaluation
    if args.multi_gpu:
        stats = distributed_utils.all_gather_list(stats)
        stats = {
            k: sum(x[k] for x in stats)
            for k in keys
        }

    valid_result = {
        'masked_context_token_ppl': math.exp(
            stats['masked_context_token_loss'] / stats['masked_context_token_num']),
        'masked_column_token_ppl': math.exp(
            stats['masked_column_token_loss'] / stats['masked_column_token_num'])
    }
    if self.config.predict_cell_tokens:
        valid_result['masked_cell_token_ppl'] = math.exp(
            stats['masked_cell_token_loss'] / stats['masked_cell_token_num'])

    return valid_result
def mlm_eval_step(self, sample, raise_oom=False):
    """Do forward pass in evaluation mode."""
    with torch.no_grad():
        self.model.eval()
        self.criterion.eval()

        sample = self._prepare_sample(sample)
        if sample is None:
            sample = self._prepare_sample(self._dummy_batch)
            ignore_results = True
        else:
            ignore_results = False

        # mlm_eval_step only returns a logging output; sample_size is only
        # used as a placeholder in the gather below
        sample_size = 0
        try:
            logging_output = self.task.mlm_eval_step(
                sample, self.model, self.criterion
            )
        except RuntimeError as e:
            if "out of memory" in str(e):
                self._log_oom(e)
                if not raise_oom:
                    print(
                        "| WARNING: ran out of memory in validation step, retrying batch"
                    )
                    for p in self.model.parameters():
                        if p.grad is not None:
                            p.grad = None  # free some memory
                    if self.cuda:
                        torch.cuda.empty_cache()
                    # retry the batch through the same MLM evaluation path
                    return self.mlm_eval_step(sample, raise_oom=True)
            raise e

        if ignore_results:
            logging_output, sample_size = {}, 0

    # gather logging outputs from all replicas
    if self.args.distributed_world_size > 1:
        logging_output, sample_size = zip(
            *distributed_utils.all_gather_list([logging_output, sample_size])
        )
        logging_output = list(logging_output)
        sample_size = list(sample_size)
    else:
        logging_output = [logging_output]
        sample_size = [sample_size]

    # aggregate logging outputs and sample sizes
    logging_output = self.task.aggregate_logging_outputs(
        logging_output, self.get_criterion()
    )
    return logging_output
def _all_gather_bleu_scorer(scorer):
    stats = distributed_utils.all_gather_list(scorer.stat)
    bleu_stat = bleu.BleuStat()
    bleu_stat.reflen = reduce(lambda x, y: x + y, [s.reflen for s in stats])
    bleu_stat.predlen = reduce(lambda x, y: x + y, [s.predlen for s in stats])
    bleu_stat.match1 = reduce(lambda x, y: x + y, [s.match1 for s in stats])
    bleu_stat.count1 = reduce(lambda x, y: x + y, [s.count1 for s in stats])
    bleu_stat.match2 = reduce(lambda x, y: x + y, [s.match2 for s in stats])
    bleu_stat.count2 = reduce(lambda x, y: x + y, [s.count2 for s in stats])
    bleu_stat.match3 = reduce(lambda x, y: x + y, [s.match3 for s in stats])
    bleu_stat.count3 = reduce(lambda x, y: x + y, [s.count3 for s in stats])
    bleu_stat.match4 = reduce(lambda x, y: x + y, [s.match4 for s in stats])
    bleu_stat.count4 = reduce(lambda x, y: x + y, [s.count4 for s in stats])
    scorer.stat = bleu_stat
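# An equivalent, more compact reduction over the same fields (a sketch,
# assuming bleu.BleuStat simply exposes these counters as settable attributes):
def _all_gather_bleu_scorer_compact(scorer):
    stats = distributed_utils.all_gather_list(scorer.stat)
    bleu_stat = bleu.BleuStat()
    for name in ('reflen', 'predlen',
                 'match1', 'count1', 'match2', 'count2',
                 'match3', 'count3', 'match4', 'count4'):
        setattr(bleu_stat, name, sum(getattr(s, name) for s in stats))
    scorer.stat = bleu_stat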
def prune_step(self, sample, raise_oom=False):
    """Do forward and backward pass in evaluation mode."""
    self.model.eval()
    self.zero_grad()

    sample = self._prepare_sample(sample)
    if sample is None:
        sample = self._prepare_sample(self._dummy_batch)
        ignore_results = True
    else:
        ignore_results = False

    try:
        sample_size, logging_output = self.task.prune_step(
            sample, self.model, self.criterion,
        )
    except RuntimeError as e:
        if 'out of memory' in str(e) and not raise_oom:
            print('| WARNING: ran out of memory, retrying batch')
            for p in self.model.parameters():
                if p.grad is not None:
                    del p.grad  # free some memory
            torch.cuda.empty_cache()
            return self.prune_step(sample, raise_oom=True)
        else:
            raise e

    if ignore_results:
        logging_output, sample_size = {}, 0

    # gather logging outputs from all replicas
    if self.args.distributed_world_size > 1:
        logging_output, sample_size = zip(*distributed_utils.all_gather_list(
            [logging_output, sample_size],
        ))
        logging_output = list(logging_output)
        sample_size = list(sample_size)
    else:
        logging_output = [logging_output]
        sample_size = [sample_size]

    # aggregate logging outputs and sample sizes
    logging_output = self.task.aggregate_logging_outputs(
        logging_output, self.criterion
    )
    sample_size = self.task.grad_denom(
        sample_size, self.criterion
    )

    return logging_output
def _all_gather_list_sync(self, logging_outputs: List[Dict[str, Any]], *extra_stats_to_sum):
    """
    Sync logging outputs across workers. all_gather_list_sync is
    suitable when logging outputs are complex types.
    """
    results = list(
        zip(*distributed_utils.all_gather_list(
            [logging_outputs] + list(extra_stats_to_sum),
            max_size=getattr(self.args, 'all_gather_list_size', 16384),
        )))
    logging_outputs, extra_stats_to_sum = results[0], results[1:]
    logging_outputs = list(chain.from_iterable(logging_outputs))
    extra_stats_to_sum = [sum(s) for s in extra_stats_to_sum]
    return [logging_outputs] + extra_stats_to_sum
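# Hypothetical call site (a sketch; `trainer`, `ooms` and `sample_size` are
# assumed names): the per-batch logging dicts are concatenated across workers
# while the trailing scalars are summed, and the returned list unpacks as
# [logging_outputs, *summed_extras].
logging_outputs, total_ooms, total_sample_size = trainer._all_gather_list_sync(
    logging_outputs, ooms, sample_size)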
def _forward(self, sample, eval=False):
    # prepare model and optimizer
    if eval:
        self.model.eval()
    else:
        self.model.train()
        self.optimizer.zero_grad()

    loss = None
    sample_size = 0
    logging_output = {
        'ntokens': sample['ntokens'] if sample is not None else 0,
        'nsentences': sample['target'].size(0) if sample is not None else 0,
    }
    oom = 0
    if sample is not None:
        try:
            with utils.maybe_no_grad(eval):
                # calculate loss and sample size
                loss, sample_size, logging_output_ = self.criterion(self.model, sample)
                logging_output.update(logging_output_)
        except RuntimeError as e:
            if not eval and 'out of memory' in str(e):
                print('| WARNING: ran out of memory, skipping batch')
                oom = 1
                loss = None
                if hasattr(torch.cuda, 'empty_cache'):
                    torch.cuda.empty_cache()
            else:
                raise e

    # synchronize logging outputs for multi-GPU training
    if self.args.distributed_world_size > 1:
        sample_sizes, logging_outputs, ooms = zip(*list(
            distributed_utils.all_gather_list((sample_size, logging_output, oom))))
        ooms = sum(ooms)
    else:
        sample_sizes = [sample_size]
        logging_outputs = [logging_output]
        ooms = oom

    return loss, sample_sizes, logging_outputs, ooms
def __init__(self, cfg: FairseqConfig, task, model, criterion, quantizer=None): if isinstance(cfg, Namespace): logger.warning( "argparse.Namespace configuration is deprecated! Automatically converting to OmegaConf" ) cfg = convert_namespace_to_omegaconf(cfg) self.cfg = cfg self.task = task # catalog shared parameters shared_params = _catalog_shared_params(model) self.tpu = cfg.common.tpu self.cuda = torch.cuda.is_available( ) and not cfg.common.cpu and not self.tpu if self.cuda: self.device = torch.device("cuda") elif self.tpu: self.device = utils.get_tpu_device() else: self.device = torch.device("cpu") # copy model and criterion to current device/dtype self._criterion = criterion self._model = model if cfg.common.fp16: self._criterion = self._criterion.half() self._model = self._model.half() elif cfg.common.bf16: self._criterion = self._criterion.to(dtype=torch.bfloat16) self._model = self._model.to(dtype=torch.bfloat16) if not cfg.distributed_training.pipeline_model_parallel: self._criterion = self._criterion.to(device=self.device) self._model = self._model.to(device=self.device) self.pipeline_model_parallel = cfg.distributed_training.pipeline_model_parallel self.last_device = None if self.cuda and self.pipeline_model_parallel: self.last_device = torch.device( cfg.distributed_training.pipeline_devices[-1]) # check that shared parameters are preserved after device transfer for shared_param in shared_params: ref = _get_module_by_path(self._model, shared_param[0]) for path in shared_param[1:]: logger.info("detected shared parameter: {} <- {}".format( shared_param[0], path)) _set_module_by_path(self._model, path, ref) self._dummy_batch = None # indicates we don't have a dummy batch at first self._lr_scheduler = None self._num_updates = 0 self._num_xla_compiles = 0 # for TPUs self._optim_history = None self._optimizer = None self._warn_once = set() self._wrapped_criterion = None self._wrapped_model = None # TODO(myleott): support tpu if self.cuda and self.data_parallel_world_size > 1: self._grad_norm_buf = torch.cuda.DoubleTensor( self.data_parallel_world_size) else: self._grad_norm_buf = None self.quantizer = quantizer if self.quantizer is not None: self.quantizer.set_trainer(self) # get detailed cuda environment if self.cuda: self.cuda_env = utils.CudaEnvironment() if self.data_parallel_world_size > 1: self.cuda_env_arr = distributed_utils.all_gather_list( self.cuda_env, group=distributed_utils.get_global_group()) else: self.cuda_env_arr = [self.cuda_env] if self.data_parallel_rank == 0: utils.CudaEnvironment.pretty_print_cuda_env_list( self.cuda_env_arr) else: self.cuda_env = None self.cuda_env_arr = None metrics.log_start_time("wall", priority=790, round=0) self._start_time = time.time() self._previous_training_time = 0 self._cumulative_training_time = None
def train_step(self, samples, dummy_batch=False, raise_oom=False):
    """Do forward, backward and parameter update."""
    if self._dummy_batch is None:
        self._dummy_batch = samples[0]

    self._set_seed()
    self.model.train()
    self.criterion.train()
    self.zero_grad()

    if not dummy_batch:
        self.meters["train_wall"].start()

    # forward and backward pass
    logging_outputs, sample_sizes, ooms = [], [], 0
    for i, sample in enumerate(samples):
        sample = self._prepare_sample(sample)
        if sample is None:
            # when sample is None, run forward/backward on a dummy batch
            # and ignore the resulting gradients
            sample = self._prepare_sample(self._dummy_batch)
            ignore_grad = True
        else:
            ignore_grad = False

        def maybe_no_sync():
            """
            Whenever *samples* contains more than one mini-batch, we
            want to accumulate gradients locally and only call
            all-reduce in the last backwards pass.
            """
            if (self.args.distributed_world_size > 1
                    and hasattr(self.model, "no_sync")
                    and i < len(samples) - 1):
                return self.model.no_sync()
            else:
                return contextlib.ExitStack()  # dummy contextmanager

        try:
            with maybe_no_sync():
                # forward and backward
                loss, sample_size, logging_output = self.task.train_step(
                    sample, self.model, self.criterion, self.optimizer, ignore_grad)

            if not ignore_grad:
                logging_outputs.append(logging_output)
                sample_sizes.append(sample_size)

                if self.fast_stat_sync:
                    self._all_reduce_list[0] += sample_size
                    self._all_reduce_list[1] += logging_output.get("nsentences", 0.0)
                    self._all_reduce_list[2] += logging_output.get("loss", 0.0)
                    self._all_reduce_list[3] += logging_output.get("nll_loss", 0.0)
                    self._all_reduce_list[4] += logging_output.get("ntokens", 0.0)
        except RuntimeError as e:
            if "out of memory" in str(e):
                self._log_oom(e)
                if raise_oom:
                    raise e
                print(
                    "| WARNING: attempting to recover from OOM in forward/backward pass",
                    file=sys.stderr)
                ooms += 1
                self.zero_grad()
            else:
                raise e

        if self.fast_stat_sync:
            self._all_reduce_list[5] += ooms

    if ooms > 0 and self._oom_batch is not None:
        self.handle_ooms(ooms)

    if dummy_batch:
        return None

    # gather logging outputs from all replicas
    if self.fast_stat_sync:
        # rework all_gather_list
        all_reduce_list_tensor = torch.cuda.DoubleTensor(self._all_reduce_list)
        if self._sync_stats():
            torch.distributed.all_reduce(all_reduce_list_tensor)
            # Normalize loss and nll_loss by "sample_size"
            # and convert to log base 2
            all_reduce_list_tensor[2:4].div_(
                (all_reduce_list_tensor[0:1] *
                 torch.log(torch.cuda.DoubleTensor([2]))))
        self._all_reduce_list = all_reduce_list_tensor.tolist()
        logging_output = {}
        [
            sample_size,
            logging_output["nsentences"],
            logging_output["loss"],
            logging_output["nll_loss"],
            logging_output["ntokens"],
            ooms,
        ] = self._all_reduce_list
    elif self._sync_stats():
        logging_outputs, sample_sizes, ooms, prev_norms = zip(
            *distributed_utils.all_gather_list([
                logging_outputs, sample_sizes, ooms, self._prev_grad_norm
            ]))
        logging_outputs = list(chain.from_iterable(logging_outputs))
        sample_sizes = list(chain.from_iterable(sample_sizes))
        ooms = sum(ooms)

        if not self.args.use_bmuf:
            assert all(
                norm == prev_norms[0] for norm in prev_norms
            ) or all(
                math.isnan(norm) or math.isinf(norm) for norm in prev_norms
            ), "Fatal error: gradients are inconsistent between workers"

    self.meters["oom"].update(ooms, len(samples))
    if ooms == self.args.distributed_world_size * len(samples):
        print("| WARNING: OOM in all workers, skipping update")
        self.zero_grad()
        return None

    if not self.fast_stat_sync:
        # aggregate logging outputs and sample sizes
        logging_output = self.task.aggregate_logging_outputs(
            logging_outputs, self.get_criterion())
        sample_size = self.task.grad_denom(sample_sizes, self.get_criterion())

    if not all(k in logging_output for k in ["ntokens", "nsentences"]):
        raise Exception(
            ("Please update the {}.aggregate_logging_outputs() method to "
             "return ntokens and nsentences").format(self.task.__class__.__name__))

    try:
        # normalize grads by sample size
        if sample_size > 0:
            self.optimizer.multiply_grads(
                self.args.distributed_world_size / float(sample_size))

        # clip grads
        grad_norm = self.optimizer.clip_grad_norm(self.args.clip_norm)
        self._prev_grad_norm = grad_norm

        # take an optimization step
        self.optimizer.step()
        self.set_num_updates(self.get_num_updates() + 1)

        # task specific update per step
        self.task.update_step(self._num_updates)

        # update meters
        ntokens = logging_output.get("ntokens", 0)
        nsentences = logging_output.get("nsentences", 0)
        self.meters["wps"].update(ntokens)
        self.meters["ups"].update(1.0)
        self.meters["wpb"].update(ntokens)
        self.meters["bsz"].update(nsentences)
        self.meters["gnorm"].update(grad_norm)
        self.meters["clip"].update(
            1.0 if grad_norm > self.args.clip_norm and self.args.clip_norm > 0 else 0.0)
        self.meters["train_loss"].update(logging_output.get("loss", 0), sample_size)
        if "train_acc" in self.meters:
            self.meters["train_acc"].update(logging_output.get("acc", 0), sample_size)
        if "nll_loss" in logging_output:
            self.meters["train_nll_loss"].update(
                logging_output.get("nll_loss", 0), ntokens)

        # clear CUDA cache to reduce memory fragmentation
        if (self.args.empty_cache_freq > 0
                and ((self.get_num_updates() + self.args.empty_cache_freq - 1)
                     % self.args.empty_cache_freq) == 0
                and torch.cuda.is_available()
                and not self.args.cpu):
            torch.cuda.empty_cache()
    except OverflowError as e:
        print("| WARNING: overflow detected, " + str(e))
        self.zero_grad()
        logging_output = None
    except RuntimeError as e:
        if "out of memory" in str(e):
            self._log_oom(e)
            print("| ERROR: OOM during optimization, irrecoverable")
        raise e

    if self.args.fp16:
        self.meters["loss_scale"].reset()
        self.meters["loss_scale"].update(self.optimizer.scaler.loss_scale)

    self.clear_buffered_stats()
    self.meters["train_wall"].stop()

    return logging_output
def train_step(self, samples, dummy_batch=False, assistant=None, assistant_queue=None, weights=None):
    """Do forward, backward and parameter update."""
    # Set seed based on args.seed and the update number so that we get
    # reproducible results when resuming from checkpoints
    seed = self.args.seed + self.get_num_updates()
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    self.model.train()
    self.zero_grad()

    if not dummy_batch:
        self.meters['train_wall'].start()

    # forward and backward pass
    logging_outputs, sample_sizes, ooms = [], [], 0
    for i, sample in enumerate(samples):
        sample = self._prepare_sample(sample)
        if sample is None:
            # when sample is None, run forward/backward on a dummy batch
            # and ignore the resulting gradients
            sample = self._prepare_sample(self._dummy_batch)
            ignore_grad = True
        else:
            ignore_grad = False

        try:
            if self.args.distributed_world_size > 1:
                # Whenever *samples* contains more than one mini-batch, we
                # want to accumulate gradients locally and only call
                # all-reduce in the last backwards pass. Currently the
                # *need_reduction* flag is only supported by
                # LegacyDistributedDataParallel.
                if i < len(samples) - 1:
                    self.model.accumulate_grads = True
                else:
                    self.model.accumulate_grads = False

            # forward and backward
            if self.args.assistant:
                losses, sample_size, logging_output, precisions = self.task.train_step(
                    sample, self.model, self.criterion, self.optimizer, ignore_grad)
            elif self.args.spl:
                losses, sample_size, logging_output, precisions = self.task.train_step(
                    sample, self.model, self.criterion, self.optimizer, ignore_grad,
                    lambda_t=self.lambda_t)
            else:
                losses, sample_size, logging_output = self.task.train_step(
                    sample, self.model, self.criterion, self.optimizer, ignore_grad)

            # record new losses
            if self.args.spl and not dummy_batch:
                y_lengths = utils.get_len(
                    sample['target'].cpu().numpy(),
                    self.task.target_dictionary.pad())
                norm_losses = np.divide(losses.detach().cpu().numpy(), y_lengths)
                self.loss_chart[sample['id'].cpu().numpy()] = torch.from_numpy(
                    norm_losses,
                ).type(torch.FloatTensor).cuda()
                if self.args.distributed_world_size > 1:
                    all_reduce(self.loss_chart, op=MIN_OP)

            # prepare data for assistant training
            if assistant is not None and np.random.rand() < SEC_TRAIN_RATIO:
                sec_batch_size = sample['id'].size(0)
                indices_sec = np.random.choice(sample['id'].size(0), sec_batch_size)
                x = sample['net_input']['src_tokens'][indices_sec]
                y = sample['target'][indices_sec]
                l = losses[indices_sec]
                x = x.cpu().numpy()
                y = y.cpu().numpy()
                l = l.detach().cpu().numpy()
                keep_probs = assistant.train_step(x, y, l)
            elif assistant_queue is not None and np.random.rand() < SEC_TRAIN_RATIO:
                sec_batch_size = sample['id'].size(0)
                local_indices_sec = np.random.choice(sample['id'].size(0), sec_batch_size)
                global_indices_sec = sample['id'][local_indices_sec].cpu().numpy()
                l = losses[local_indices_sec]
                l = l.detach().cpu().numpy()
                if not assistant_queue.full():
                    assistant_queue.put((global_indices_sec, l), block=False)
                else:
                    _ = assistant_queue.get()
                    assistant_queue.put((global_indices_sec, l), block=False)

            if not ignore_grad:
                logging_outputs.append(logging_output)
                sample_sizes.append(sample_size)
        except RuntimeError as e:
            if 'out of memory' in str(e):
                print('| WARNING: ran out of memory, skipping batch')
                ooms += 1
                self.zero_grad()
            else:
                print(sample, flush=True, force=True)
                raise e

    if dummy_batch:
        return None

    # gather logging outputs from all replicas
    if self.args.distributed_world_size > 1:
        logging_outputs, sample_sizes, ooms = zip(
            *distributed_utils.all_gather_list(
                [logging_outputs, sample_sizes, ooms],
            ))
        logging_outputs = list(chain.from_iterable(logging_outputs))
        sample_sizes = list(chain.from_iterable(sample_sizes))
        ooms = sum(ooms)

    if ooms == self.args.distributed_world_size * len(samples):
        print('| WARNING: OOM in all workers, skipping update')
        self.zero_grad()
        return None

    # aggregate logging outputs and sample sizes
    logging_output = self.task.aggregate_logging_outputs(
        logging_outputs, self.criterion)
    sample_size = self.task.grad_denom(sample_sizes, self.criterion)

    if not all(k in logging_output for k in ['ntokens', 'nsentences']):
        raise Exception(
            ('Please update the {}.aggregate_logging_outputs() method to '
             'return ntokens and nsentences').format(self.task.__class__.__name__))

    try:
        # normalize grads by sample size
        self.optimizer.multiply_grads(self.args.distributed_world_size / float(sample_size))

        # clip grads
        grad_norm = self.optimizer.clip_grad_norm(self.args.clip_norm)

        # take an optimization step
        self.optimizer.step()
        self._num_updates += 1

        # update learning rate
        self.lr_scheduler.step_update(self._num_updates)

        # update meters
        ntokens = logging_output.get('ntokens', 0)
        nsentences = logging_output.get('nsentences', 0)
        self.meters['wps'].update(ntokens)
        self.meters['ups'].update(1.)
        self.meters['wpb'].update(ntokens)
        self.meters['bsz'].update(nsentences)
        self.meters['gnorm'].update(grad_norm)
        self.meters['clip'].update(
            1. if grad_norm > self.args.clip_norm and self.args.clip_norm > 0 else 0.)
        self.meters['oom'].update(ooms)
        self.meters['train_loss'].update(logging_output.get('loss', 0), sample_size)
        if 'nll_loss' in logging_output:
            self.meters['train_nll_loss'].update(
                logging_output.get('nll_loss', 0), ntokens)
    except OverflowError as e:
        print('| WARNING: overflow detected, ' + str(e))
        self.zero_grad()
        logging_output = None

    if self.args.fp16:
        self.meters['loss_scale'].reset()
        self.meters['loss_scale'].update(self.optimizer.scaler.loss_scale)

    self.meters['train_wall'].stop()

    return logging_output
def main(cfg: DictConfig, override_args=None): if isinstance(cfg, Namespace): cfg = convert_namespace_to_omegaconf(cfg) utils.import_user_module(cfg.common) assert ( cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None ), "Must specify batch size either with --max-tokens or --batch-size" use_fp16 = cfg.common.fp16 use_cuda = torch.cuda.is_available() and not cfg.common.cpu if use_cuda: torch.cuda.set_device(cfg.distributed_training.device_id) if cfg.distributed_training.distributed_world_size > 1: data_parallel_world_size = distributed_utils.get_data_parallel_world_size( ) data_parallel_rank = distributed_utils.get_data_parallel_rank() else: data_parallel_world_size = 1 data_parallel_rank = 0 if override_args is not None: overrides = vars(override_args) overrides.update(eval(getattr(override_args, "model_overrides", "{}"))) else: overrides = None # Load ensemble logger.info("loading model(s) from {}".format(cfg.common_eval.path)) models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task( [cfg.common_eval.path], arg_overrides=overrides, suffix=cfg.checkpoint.checkpoint_suffix, ) model = models[0] # Move models to GPU for model in models: if use_fp16: model.half() if use_cuda: model.cuda() # Print args logger.info(saved_cfg) # Build criterion criterion = task.build_criterion(saved_cfg.criterion) criterion.eval() for subset in cfg.dataset.valid_subset.split(","): try: task.load_dataset(subset, combine=False, epoch=1, task_cfg=saved_cfg.task) dataset = task.dataset(subset) except KeyError: raise Exception("Cannot find dataset: " + subset) # Initialize data iterator itr = task.get_batch_iterator( dataset=dataset, max_tokens=cfg.dataset.max_tokens, max_sentences=cfg.dataset.batch_size, max_positions=utils.resolve_max_positions( task.max_positions(), *[m.max_positions() for m in models], ), ignore_invalid_inputs=cfg.dataset. skip_invalid_size_inputs_valid_test, required_batch_size_multiple=cfg.dataset. required_batch_size_multiple, seed=cfg.common.seed, num_shards=data_parallel_world_size, shard_id=data_parallel_rank, num_workers=cfg.dataset.num_workers, data_buffer_size=cfg.dataset.data_buffer_size, ).next_epoch_itr(shuffle=False) progress = progress_bar.progress_bar( itr, log_format=cfg.common.log_format, log_interval=cfg.common.log_interval, prefix=f"valid on '{subset}' subset", default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"), ) log_outputs = [] for i, sample in enumerate(progress): sample = utils.move_to_cuda(sample) if use_cuda else sample _loss, _sample_size, log_output = task.valid_step( sample, model, criterion) progress.log(log_output, step=i) log_outputs.append(log_output) if data_parallel_world_size > 1: log_outputs = distributed_utils.all_gather_list( log_outputs, max_size=cfg.common.all_gather_list_size, group=distributed_utils.get_data_parallel_group(), ) log_outputs = list(chain.from_iterable(log_outputs)) with metrics.aggregate() as agg: task.reduce_metrics(log_outputs, criterion) log_output = agg.get_smoothed_values() progress.print(log_output, tag=subset, step=i)
def train_step(self, sample, update_params=True, last_step=False):
    """Do forward, backward and parameter update."""
    # Set seed based on args.seed and the update number so that we get
    # reproducible results when resuming from checkpoints
    seed = self.args.seed + self.get_num_updates()
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    self.model.train()
    if isinstance(self.model, DDP):
        if last_step:
            self.model.disable_allreduce()
        else:
            self.model.enable_allreduce()

    # forward and backward pass
    sample, sample_size = self._prepare_sample(sample)
    loss, oom_fwd = self._forward(sample)

    # If this is a last batch forward pass is skipped on some workers
    # Batch with sample_size 0 is not accounted for in weighted loss
    logging_output = {
        'ntokens': sample['ntokens'] if sample is not None else 0,
        'nsentences': sample['target'].size(0) if sample is not None else 0,
        'loss': utils.item(loss.data) if loss is not None else 0,
        'sample_size': sample_size
    }
    oom_bwd = self._backward(loss)

    # buffer stats and logging outputs
    self._buffered_stats['sample_sizes'].append(sample_size)
    self._buffered_stats['logging_outputs'].append(logging_output)
    self._buffered_stats['ooms_fwd'].append(oom_fwd)
    self._buffered_stats['ooms_bwd'].append(oom_bwd)

    # update parameters
    if update_params and not last_step:
        # gather logging outputs from all replicas
        sample_sizes = self._buffered_stats['sample_sizes']
        logging_outputs = self._buffered_stats['logging_outputs']
        ooms_fwd = self._buffered_stats['ooms_fwd']
        ooms_bwd = self._buffered_stats['ooms_bwd']
        if self.args.distributed_world_size > 1:
            sample_sizes, logging_outputs, ooms_fwd, ooms_bwd = map(
                lambda l: list(chain.from_iterable(l)),
                zip(*distributed_utils.all_gather_list(
                    (sample_sizes, logging_outputs, ooms_fwd, ooms_bwd)
                ))
            )
        ooms_fwd = sum(ooms_fwd)
        ooms_bwd = sum(ooms_bwd)
        ooms = ooms_fwd + ooms_bwd  # this is always <= distributed_world_size

        if ooms == self.args.distributed_world_size:
            print('| WARNING: OOM in all workers, skipping batch')
            self.zero_grad()
            return

        # aggregate stats and logging outputs
        grad_denom = sum(sample_sizes)
        for p in self.model.parameters():
            if p.requires_grad and p.grad is not None:
                p.grad /= grad_denom

        self._opt()

        # Handle logging
        sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
        ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
        self.throughput_meter.update(ntokens)
        info_log_data = {
            'tokens/s': self.throughput_meter.avg,
            'tokens': ntokens,
            'loss': sum(log.get('loss', 0) for log in logging_outputs) / sample_size / math.log(2)
        }
        debug_log_data = {
            'batch_size': sum(log.get('nsentences', 0) for log in logging_outputs),
            'lr': self.get_lr(),
            'grad_denom': grad_denom,
            'updates': 1
        }
        DLLogger.log(step=self._num_updates, data=info_log_data, verbosity=0)
        DLLogger.log(step=self._num_updates, data=debug_log_data, verbosity=1)

        self.clear_buffered_stats()
def train_step(self, sample, update_params=True):
    """Do forward, backward and parameter update."""
    # Set seed based on args.seed and the update number so that we get
    # reproducible results when resuming from checkpoints
    seed = self.args.seed + self.get_num_updates()
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # forward and backward pass
    sample = self._prepare_sample(sample)
    loss, sample_size, logging_output, oom_fwd = self._forward(sample)
    oom_bwd = self._backward(loss)

    # buffer stats and logging outputs
    self._buffered_stats['sample_sizes'].append(sample_size)
    self._buffered_stats['logging_outputs'].append(logging_output)
    self._buffered_stats['ooms_fwd'].append(oom_fwd)
    self._buffered_stats['ooms_bwd'].append(oom_bwd)

    # update parameters
    if update_params:
        # gather logging outputs from all replicas
        sample_sizes = self._buffered_stats['sample_sizes']
        logging_outputs = self._buffered_stats['logging_outputs']
        ooms_fwd = self._buffered_stats['ooms_fwd']
        ooms_bwd = self._buffered_stats['ooms_bwd']
        if self.args.distributed_world_size > 1:
            sample_sizes, logging_outputs, ooms_fwd, ooms_bwd = map(
                lambda l: list(chain.from_iterable(l)),
                zip(*distributed_utils.all_gather_list(
                    (sample_sizes, logging_outputs, ooms_fwd, ooms_bwd))))
        ooms_fwd = sum(ooms_fwd)
        ooms_bwd = sum(ooms_bwd)

        if ooms_fwd == self.args.distributed_world_size:
            print('| WARNING: OOM in all workers, skipping batch')
            self.zero_grad()
            return None

        # aggregate stats and logging outputs
        ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
        nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
        agg_logging_output = self.criterion.__class__.aggregate_logging_outputs(
            logging_outputs)
        grad_denom = self.criterion.__class__.grad_denom(sample_sizes)

        try:
            # all-reduce and rescale gradients, then take an optimization step
            grad_norm = self._all_reduce_and_rescale(grad_denom)
            self._opt()

            # update meters
            self.meters['wps'].update(ntokens)
            self.meters['ups'].update(1.)
            self.meters['wpb'].update(ntokens)
            self.meters['bsz'].update(nsentences)
            if grad_norm is not None:
                self.meters['gnorm'].update(grad_norm)
                self.meters['clip'].update(
                    1. if grad_norm > self.args.clip_norm else 0.)
            self.meters['oom'].update(ooms_fwd + ooms_bwd)

            # update loss meters for training
            if 'loss' in agg_logging_output:
                self.meters['train_loss'].update(
                    agg_logging_output['loss'], grad_denom)
            # criterions can optionally log the NLL loss too
            if 'nll_loss' in agg_logging_output:
                self.meters['train_nll_loss'].update(
                    agg_logging_output['nll_loss'], ntokens)
        except OverflowError as e:
            self.zero_grad()
            print('| WARNING: overflow detected, ' + str(e))

        self.clear_buffered_stats()

        return agg_logging_output
    else:
        return None  # buffering updates
def valid_step(self, sample, raise_oom=False): """Do forward pass in evaluation mode.""" with torch.no_grad(): self.model.eval() self.criterion.eval() if self.args.float_valid: self.model.float() self.criterion.float() sample = self._prepare_sample(sample, force_float=self.args.float_valid) if sample is None: sample = self._prepare_sample(self._dummy_batch) ignore_results = True else: ignore_results = False try: _loss, sample_size, logging_output = self.task.valid_step( sample, self.model, self.criterion ) src_target_hypo_str = self._get_decoding(sample) except RuntimeError as e: if "out of memory" in str(e): self._log_oom(e) if not raise_oom: print( "| WARNING: ran out of memory in validation step, retrying batch" ) for p in self.model.parameters(): if p.grad is not None: p.grad = None # free some memory if self.cuda: torch.cuda.empty_cache() return self.valid_step(sample, raise_oom=True) raise e if ignore_results: logging_output, sample_size = {}, 0 src_target_hypo_str = [] # gather logging outputs from all replicas if self.args.distributed_world_size > 1: logging_output, sample_size, src_target_hypo_str = zip( *distributed_utils.all_gather_list( [logging_output, sample_size, src_target_hypo_str], max_size=getattr(self.args, 'all_gather_list_size', 16384), ) ) logging_output = list(logging_output) sample_size = list(sample_size) else: logging_output = [logging_output] sample_size = [sample_size] src_target_hypo_str = [src_target_hypo_str] # aggregate logging outputs and sample sizes logging_output = self.task.aggregate_logging_outputs( logging_output, self.get_criterion() ) sample_size = self.task.grad_denom(sample_size, self.get_criterion()) # update meters for validation ntokens = logging_output.get("ntokens", 0) self.meters["valid_loss"].update(logging_output.get("loss", 0), sample_size) if "valid_acc" in self.meters: self.meters["valid_acc"].update(logging_output.get("acc", 0), sample_size) if "nll_loss" in logging_output: self.meters["valid_nll_loss"].update( logging_output.get("nll_loss", 0), ntokens ) logging_output['src_target_hypo_str'] = src_target_hypo_str if self.args.float_valid: self.model.half() self.criterion.half() return logging_output
def main(args, override_args=None): utils.import_user_module(args) assert args.max_tokens is not None or args.max_sentences is not None, \ 'Must specify batch size either with --max-tokens or --max-sentences' use_fp16 = args.fp16 use_cuda = torch.cuda.is_available() and not args.cpu if use_cuda: torch.cuda.set_device(args.device_id) if override_args is not None: overrides = vars(override_args) overrides.update(eval(getattr(override_args, 'model_overrides', '{}'))) else: overrides = None # Load ensemble logger.info('loading model(s) from {}'.format(args.path)) models, model_args, task = checkpoint_utils.load_model_ensemble_and_task( [args.path], arg_overrides=overrides, suffix=getattr(args, "checkpoint_suffix", ""), ) model = models[0] # Move models to GPU for model in models: if use_fp16: model.half() if use_cuda: model.cuda() # Print args logger.info(model_args) # Build criterion criterion = task.build_criterion(model_args) criterion.eval() for subset in args.valid_subset.split(','): try: task.load_dataset(subset, combine=False, epoch=1) dataset = task.dataset(subset) except KeyError: raise Exception('Cannot find dataset: ' + subset) # Initialize data iterator itr = task.get_batch_iterator( dataset=dataset, max_tokens=args.max_tokens, max_sentences=args.max_sentences, max_positions=utils.resolve_max_positions( task.max_positions(), *[m.max_positions() for m in models], ), ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test, required_batch_size_multiple=args.required_batch_size_multiple, seed=args.seed, num_shards=args.distributed_world_size, shard_id=args.distributed_rank, num_workers=args.num_workers, data_buffer_size=args.data_buffer_size, ).next_epoch_itr(shuffle=False) progress = progress_bar.progress_bar( itr, log_format=args.log_format, log_interval=args.log_interval, prefix=f"valid on '{subset}' subset", default_log_format=('tqdm' if not args.no_progress_bar else 'simple'), ) log_outputs = [] for i, sample in enumerate(progress): sample = utils.move_to_cuda(sample) if use_cuda else sample _loss, _sample_size, log_output = task.valid_step( sample, model, criterion) progress.log(log_output, step=i) log_outputs.append(log_output) if args.distributed_world_size > 1: log_outputs = distributed_utils.all_gather_list( log_outputs, max_size=getattr(args, 'all_gather_list_size', 16384), ) log_outputs = list(chain.from_iterable(log_outputs)) with metrics.aggregate() as agg: task.reduce_metrics(log_outputs, criterion) log_output = agg.get_smoothed_values() progress.print(log_output, tag=subset, step=i)
def train_step(self, samples, dummy_batch=False, raise_oom=False):
    """Do forward, backward and parameter update."""
    if self._dummy_batch is None:
        self._dummy_batch = samples[0]

    self._set_seed()
    self.model.train()
    self.criterion.train()
    self.zero_grad()

    if not dummy_batch:
        self.meters['train_wall'].start()

    # forward and backward pass
    logging_outputs, sample_sizes, ooms = [], [], 0
    for i, sample in enumerate(samples):
        sample = self._prepare_sample(sample)
        if sample is None:
            # when sample is None, run forward/backward on a dummy batch
            # and ignore the resulting gradients
            sample = self._prepare_sample(self._dummy_batch)
            ignore_grad = True
        else:
            ignore_grad = False

        try:
            if self.args.distributed_world_size > 1:
                # Whenever *samples* contains more than one mini-batch, we
                # want to accumulate gradients locally and only call
                # all-reduce in the last backwards pass. Currently the
                # *accumulate_grads* flag is only supported by
                # LegacyDistributedDataParallel.
                if i < len(samples) - 1:
                    self.model.accumulate_grads = True
                else:
                    self.model.accumulate_grads = False

            # forward and backward
            loss, sample_size, logging_output = self.task.train_step(
                sample, self.model, self.criterion, self.optimizer, ignore_grad
            )

            if not ignore_grad:
                logging_outputs.append(logging_output)
                sample_sizes.append(sample_size)
        except RuntimeError as e:
            if 'out of memory' in str(e):
                msg = (
                    '| WARNING: ran out of memory with exception: '
                    + '{};'.format(e)
                    + '\n Skipping batch'
                )
                # TODO: print should really go to logger, this print goes
                # to stdout, which is buffered, which in many case is not
                # printed out if another exception happens
                # print(msg)
                print(msg, file=sys.stderr)
                if raise_oom:
                    raise ValueError(msg)
                ooms += 1
                self.zero_grad()
            else:
                raise e

    if ooms > 0 and self._oom_batch is not None:
        self.handle_ooms(ooms)

    if dummy_batch:
        return None

    # gather logging outputs from all replicas
    if self.args.distributed_world_size > 1 and (
        (not self.args.use_bmuf)
        or (
            self.args.use_bmuf
            and (self.get_num_updates() + 1) % self.args.global_sync_iter == 0
        )
    ):
        logging_outputs, sample_sizes, ooms, prev_norms = \
            zip(*distributed_utils.all_gather_list(
                [logging_outputs, sample_sizes, ooms, self._prev_grad_norm],
            ))
        logging_outputs = list(chain.from_iterable(logging_outputs))
        sample_sizes = list(chain.from_iterable(sample_sizes))
        ooms = sum(ooms)

        if not self.args.use_bmuf:
            assert (
                all(norm == prev_norms[0] for norm in prev_norms)
                or all(math.isnan(norm) or math.isinf(norm) for norm in prev_norms)
            ), 'Fatal error: gradients are inconsistent between workers'

    self.meters['oom'].update(ooms, len(samples))
    if ooms == self.args.distributed_world_size * len(samples):
        print('| WARNING: OOM in all workers, skipping update')
        self.zero_grad()
        return None

    # aggregate logging outputs and sample sizes
    logging_output = self.task.aggregate_logging_outputs(
        logging_outputs, self.criterion
    )
    sample_size = self.task.grad_denom(sample_sizes, self.criterion)

    if not all(k in logging_output for k in ['ntokens', 'nsentences']):
        raise Exception((
            'Please update the {}.aggregate_logging_outputs() method to '
            'return ntokens and nsentences'
        ).format(self.task.__class__.__name__))

    try:
        # normalize grads by sample size
        if sample_size > 0:
            self.optimizer.multiply_grads(self.args.distributed_world_size / float(sample_size))

        # clip grads
        grad_norm = self.optimizer.clip_grad_norm(self.args.clip_norm)
        self._prev_grad_norm = grad_norm

        # take an optimization step
        self.optimizer.step()
        self.set_num_updates(self.get_num_updates() + 1)

        # task specific update per step
        self.task.update_step(self._num_updates)

        # update meters
        ntokens = logging_output.get('ntokens', 0)
        nsentences = logging_output.get('nsentences', 0)
        self.meters['wps'].update(ntokens)
        self.meters['ups'].update(1.)
        self.meters['wpb'].update(ntokens)
        self.meters['bsz'].update(nsentences)
        self.meters['gnorm'].update(grad_norm)
        self.meters['clip'].update(
            1. if grad_norm > self.args.clip_norm and self.args.clip_norm > 0 else 0.
        )
        self.meters['train_loss'].update(logging_output.get('loss', 0), sample_size)
        if 'train_acc' in self.meters:
            self.meters['train_acc'].update(
                logging_output.get('acc', 0), sample_size)
        if 'nll_loss' in logging_output:
            self.meters['train_nll_loss'].update(logging_output.get('nll_loss', 0), ntokens)
    except OverflowError as e:
        print('| WARNING: overflow detected, ' + str(e))
        self.zero_grad()
        logging_output = None

    if self.args.fp16:
        self.meters['loss_scale'].reset()
        self.meters['loss_scale'].update(self.optimizer.scaler.loss_scale)

    self.meters['train_wall'].stop()

    return logging_output
def __init__(self, args, task, model, criterion, quantizer=None): self.args = args self.task = task # catalog shared parameters shared_params = _catalog_shared_params(model) self.tpu = getattr(args, 'tpu', False) self.cuda = torch.cuda.is_available() and not args.cpu and not self.tpu if self.cuda: self.device = torch.device('cuda') elif self.tpu: self.device = utils.get_tpu_device(args) else: self.device = torch.device('cpu') # copy model and criterion to current device/dtype self._criterion = criterion self._model = model if self.tpu: import torch_xla.core.xla_model as xm self._model = xm.send_cpu_data_to_device(self._model, self.device) if args.fp16: self._criterion = self._criterion.half() self._model = self._model.half() elif args.bf16: self._criterion = self._criterion.to(dtype=torch.bfloat16) self._model = self._model.to(dtype=torch.bfloat16) self._criterion = self._criterion.to(device=self.device) self._model = self._model.to(device=self.device) # check that shared parameters are preserved after device transfer for shared_param in shared_params: ref = _get_module_by_path(self._model, shared_param[0]) for path in shared_param[1:]: logger.info( 'detected shared parameter: {} <- {}'.format(shared_param[0], path) ) _set_module_by_path(self._model, path, ref) self._dummy_batch = "DUMMY" # indicates we don't have a dummy batch at first self._lr_scheduler = None self._num_updates = 0 self._num_xla_compiles = 0 # for TPUs self._optim_history = None self._optimizer = None self._warn_once = set() self._wrapped_criterion = None self._wrapped_model = None # TODO(myleott): support tpu if self.cuda and self.data_parallel_world_size > 1: self._grad_norm_buf = torch.cuda.DoubleTensor(self.data_parallel_world_size) else: self._grad_norm_buf = None self.quantizer = quantizer if self.quantizer is not None: self.quantizer.set_trainer(self) # get detailed cuda environment if self.cuda: self.cuda_env = utils.CudaEnvironment() if self.data_parallel_world_size > 1: self.cuda_env_arr = distributed_utils.all_gather_list(self.cuda_env) else: self.cuda_env_arr = [self.cuda_env] if self.data_parallel_rank == 0: utils.CudaEnvironment.pretty_print_cuda_env_list(self.cuda_env_arr) else: self.cuda_env = None self.cuda_env_arr = None metrics.log_start_time("wall", priority=790, round=0) self._start_time = time.time() self._previous_training_time = 0 self._cumulative_training_time = None
def train_step(self, samples, dummy_batch=False): """Do forward, backward and parameter update.""" # Set seed based on args.seed and the update number so that we get # reproducible results when resuming from checkpoints seed = self.args.seed + self.get_num_updates() torch.manual_seed(seed) torch.cuda.manual_seed(seed) self.model.train() self.zero_grad() if not dummy_batch: self.meters['train_wall'].start() # forward and backward pass logging_outputs, sample_sizes, ooms = [], [], 0 for i, sample in enumerate(samples): sample = self._prepare_sample(sample) if sample is None: # when sample is None, run forward/backward on a dummy batch # and ignore the resulting gradients sample = self._prepare_sample(self._dummy_batch) ignore_grad = True else: ignore_grad = False try: # forward and backward loss, sample_size, logging_output = self.task.train_step( sample, self.model, self.criterion, self.optimizer, ignore_grad ) if self.args.distributed_world_size > 1: # only all-reduce gradients in the last backwards pass if i < len(samples) - 1: self.model.need_reduction = False else: self.model.need_reduction = True if not ignore_grad: logging_outputs.append(logging_output) sample_sizes.append(sample_size) except RuntimeError as e: if 'out of memory' in str(e): print('| WARNING: ran out of memory, skipping batch') ooms += 1 self.zero_grad() else: raise e if dummy_batch: return None # gather logging outputs from all replicas if self.args.distributed_world_size > 1: logging_outputs, sample_sizes, ooms = zip(*distributed_utils.all_gather_list( [logging_outputs, sample_sizes, ooms], )) logging_outputs = list(chain.from_iterable(logging_outputs)) sample_sizes = list(chain.from_iterable(sample_sizes)) ooms = sum(ooms) if ooms == self.args.distributed_world_size: print('| WARNING: OOM in all workers, skipping update') self.zero_grad() return None # aggregate logging outputs and sample sizes sample_size = self.task.grad_denom(sample_sizes, self.criterion) logging_output = self.task.aggregate_logging_outputs( logging_outputs, self.criterion ) if not all(k in logging_output for k in ['ntokens', 'nsentences']): raise Exception(( 'Please update the {}.aggregate_logging_outputs() method to ' 'return ntokens and nsentences' ).format(self.task.__class__.__name__)) try: self._num_iterations += 1 # normalize grads by sample size self.optimizer.multiply_grads(self.args.distributed_world_size / float(sample_size)) # clip grads grad_norm = self.optimizer.clip_grad_norm(self.args.clip_norm) # take an optimization step self.optimizer.step() self._num_updates += 1 # update learning rate self.lr_scheduler.step_update(self._num_updates) # update meters ntokens = logging_output.get('ntokens', 0) nsentences = logging_output.get('nsentences', 0) self.meters['wps'].update(ntokens) self.meters['ups'].update(1.) self.meters['wpb'].update(ntokens) self.meters['bsz'].update(nsentences) self.meters['gnorm'].update(grad_norm) self.meters['clip'].update( 1. if grad_norm > self.args.clip_norm and self.args.clip_norm > 0 else 0. ) self.meters['oom'].update(ooms) self.meters['train_loss'].update(logging_output.get('loss', 0), sample_size) if 'nll_loss' in logging_output: self.meters['train_nll_loss'].update(logging_output.get('nll_loss', 0), ntokens) except OverflowError as e: print('| WARNING: overflow detected, ' + str(e)) self.zero_grad() logging_output = None if self.args.fp16: self.meters['loss_scale'].reset() self.meters['loss_scale'].update(self.optimizer.scaler.loss_scale) self.meters['train_wall'].stop() return logging_output
def train_step(self, samples, dummy_batch=False, domain=None):
    """Do forward, backward and parameter update."""
    self._set_seed()
    self.model.train()
    self.criterion.train()
    self.zero_grad()

    if not dummy_batch:
        self.meters['train_wall'].start()

    # forward and backward pass
    logging_outputs, sample_sizes, ooms = [], [], 0
    for i, sample in enumerate(samples):
        sample = self._prepare_sample(sample)
        if sample is None:
            # when sample is None, run forward/backward on a dummy batch
            # and ignore the resulting gradients
            sample = self._prepare_sample(self._dummy_batch)
            ignore_grad = True
        else:
            ignore_grad = False

        try:
            if self.args.distributed_world_size > 1:
                # Whenever *samples* contains more than one mini-batch, we
                # want to accumulate gradients locally and only call
                # all-reduce in the last backwards pass. Currently the
                # *accumulate_grads* flag is only supported by
                # LegacyDistributedDataParallel.
                if i < len(samples) - 1:
                    self.model.accumulate_grads = True
                else:
                    self.model.accumulate_grads = False

            # forward and backward
            loss, sample_size, logging_output = self.task.train_step(
                sample, self.model, self.criterion, self.optimizer,
                ignore_grad, domain=domain, num_updates=self._num_updates)

            if not ignore_grad:
                logging_outputs.append(logging_output)
                sample_sizes.append(sample_size)
        except RuntimeError as e:
            if 'out of memory' in str(e):
                print('| WARNING: ran out of memory, skipping batch')
                ooms += 1
                self.zero_grad()
            else:
                raise e

    if ooms > 0 and self._oom_batch is not None:
        self.handle_ooms(ooms)

    if dummy_batch:
        return None

    # gather logging outputs from all replicas
    if self.args.distributed_world_size > 1:
        logging_outputs, sample_sizes, ooms, prev_norms = \
            zip(*distributed_utils.all_gather_list(
                [logging_outputs, sample_sizes, ooms, self._prev_grad_norm],
            ))
        logging_outputs = list(chain.from_iterable(logging_outputs))
        sample_sizes = list(chain.from_iterable(sample_sizes))
        ooms = sum(ooms)
        assert all(norm == prev_norms[0] for norm in prev_norms), \
            'Fatal error: gradients are inconsistent between workers'

    self.meters['oom'].update(ooms, len(samples))
    if ooms == self.args.distributed_world_size * len(samples):
        print('| WARNING: OOM in all workers, skipping update')
        self.zero_grad()
        return None

    # aggregate logging outputs and sample sizes
    logging_output = self.task.aggregate_logging_outputs(
        logging_outputs, self.criterion)
    sample_size = self.task.grad_denom(sample_sizes, self.criterion)

    if not all(k in logging_output for k in ['ntokens', 'nsentences']):
        raise Exception(
            ('Please update the {}.aggregate_logging_outputs() method to '
             'return ntokens and nsentences').format(
                 self.task.__class__.__name__))

    try:
        # normalize grads by sample size
        self.optimizer.multiply_grads(
            self.args.distributed_world_size / float(sample_size))

        # clip grads
        grad_norm = self.optimizer.clip_grad_norm(self.args.clip_norm)
        self._prev_grad_norm = grad_norm

        # take an optimization step
        self.optimizer.step()
        self._num_updates += 1

        # update learning rate
        self.lr_scheduler.step_update(self._num_updates)

        # update meters
        ntokens = logging_output.get('ntokens', 0)
        nsentences = logging_output.get('nsentences', 0)
        self.meters['wps'].update(ntokens)
        self.meters['ups'].update(1.)
        self.meters['wpb'].update(ntokens)
        self.meters['bsz'].update(nsentences)
        self.meters['gnorm'].update(grad_norm)
        self.meters['clip'].update(
            1. if grad_norm > self.args.clip_norm and self.args.clip_norm > 0 else 0.)
        self.meters['train_loss'].update(logging_output.get('loss', 0), sample_size)
        if 'nll_loss' in logging_output:
            self.meters['train_nll_loss'].update(
                logging_output.get('nll_loss', 0), ntokens)
    except OverflowError as e:
        print('| WARNING: overflow detected, ' + str(e))
        self.zero_grad()
        logging_output = None

    if self.args.fp16:
        self.meters['loss_scale'].reset()
        self.meters['loss_scale'].update(self.optimizer.scaler.loss_scale)

    self.meters['train_wall'].stop()

    return logging_output
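# A minimal sketch of the local-accumulation idea behind the *accumulate_grads*
# flag above: suppress the gradient all-reduce for every micro-batch except the
# last one. Modern torch.nn.parallel.DistributedDataParallel exposes this
# directly as the no_sync() context manager. The helper and argument names
# below (accumulate_and_sync, ddp_model, micro_batches, loss_fn) are
# hypothetical, not part of the code above.
import contextlib


def accumulate_and_sync(ddp_model, micro_batches, loss_fn, optimizer):
    optimizer.zero_grad()
    for i, (inputs, targets) in enumerate(micro_batches):
        last = (i == len(micro_batches) - 1)
        # no_sync() skips the all-reduce; on the last micro-batch we fall
        # through to a plain nullcontext so gradients are synchronized once.
        ctx = ddp_model.no_sync() if not last else contextlib.nullcontext()
        with ctx:
            loss = loss_fn(ddp_model(inputs), targets)
            loss.backward()
    optimizer.step()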
def compute_bleu(reference_corpus, translation_corpus, args, max_order=4,
                 use_bp=True):
    """Computes BLEU score of translated segments against one or more references.

    Args:
        reference_corpus: list of references for each translation. Each
            reference should be tokenized into a list of tokens.
        translation_corpus: list of translations to score. Each translation
            should be tokenized into a list of tokens.
        args: CLI arguments
        max_order: Maximum n-gram order to use when computing BLEU score.
        use_bp: boolean, whether to apply brevity penalty.

    Returns:
        BLEU score.
    """
    reference_length = 0
    translation_length = 0
    bp = 1.0
    geo_mean = 0

    matches_by_order = [0] * max_order
    possible_matches_by_order = [0] * max_order
    precisions = []

    for (references, translations) in zip(reference_corpus, translation_corpus):
        reference_length += len(references)
        translation_length += len(translations)
        ref_ngram_counts = _get_ngrams_with_counter(references, max_order)
        translation_ngram_counts = _get_ngrams_with_counter(translations, max_order)

        overlap = dict((ngram, min(count, translation_ngram_counts[ngram]))
                       for ngram, count in ref_ngram_counts.items())

        for ngram in overlap:
            matches_by_order[len(ngram) - 1] += overlap[ngram]
        for ngram in translation_ngram_counts:
            possible_matches_by_order[len(ngram) - 1] += translation_ngram_counts[ngram]

    precisions = [0] * max_order
    smooth = 1.0

    # do reductions of matches_by_order and possible_matches_by_order
    if args.distributed_world_size > 1:
        stats = RefBleuStats(matches_by_order, possible_matches_by_order,
                             reference_length, translation_length)
        all_stats = distributed_utils.all_gather_list(stats)
        stats = reduce(lambda a, b: a + b, all_stats)
        matches_by_order = stats.matches_by_order
        possible_matches_by_order = stats.possible_matches_by_order
        reference_length = stats.reference_length
        translation_length = stats.translation_length

    for i in range(0, max_order):
        if possible_matches_by_order[i] > 0:
            if matches_by_order[i] > 0:
                precisions[i] = float(matches_by_order[i]) / possible_matches_by_order[i]
            else:
                smooth *= 2
                precisions[i] = 1.0 / (smooth * possible_matches_by_order[i])
        else:
            precisions[i] = 0.0

    if max(precisions) > 0:
        p_log_sum = sum(math.log(p) for p in precisions if p)
        geo_mean = math.exp(p_log_sum / max_order)

    if use_bp:
        if reference_length > 0:
            ratio = translation_length / reference_length
            bp = math.exp(1 - 1. / ratio) if ratio < 1.0 else 1.0
        else:
            bp = 1.0
    bleu = geo_mean * bp
    return np.float32(bleu) * 100.0
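# compute_bleu() above relies on a RefBleuStats container that is not shown in
# this snippet. A minimal sketch of what such a class could look like, assuming
# all it needs is element-wise addition so that
# functools.reduce(lambda a, b: a + b, all_stats) pools per-worker counts.
# The class name RefBleuStatsSketch is ours, not the original implementation.
from dataclasses import dataclass
from typing import List


@dataclass
class RefBleuStatsSketch:
    matches_by_order: List[int]
    possible_matches_by_order: List[int]
    reference_length: int
    translation_length: int

    def __add__(self, other: "RefBleuStatsSketch") -> "RefBleuStatsSketch":
        # element-wise sum of n-gram counts and plain sum of the lengths
        return RefBleuStatsSketch(
            [a + b for a, b in zip(self.matches_by_order, other.matches_by_order)],
            [a + b for a, b in zip(self.possible_matches_by_order,
                                   other.possible_matches_by_order)],
            self.reference_length + other.reference_length,
            self.translation_length + other.translation_length,
        )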
def valid_step(self, sample, raise_oom=False):
    """Do forward pass in evaluation mode."""
    with torch.no_grad():
        self.model.eval()

        sample = self._prepare_sample(sample)
        if sample is None:
            sample = self._prepare_sample(self._dummy_batch)
            ignore_results = True
        else:
            ignore_results = False

        try:
            _loss, sample_size, logging_output = self.task.get_loss(
                self.model, self.criterion, sample, is_valid=True)
        except RuntimeError as e:
            if 'out of memory' in str(e) and not raise_oom:
                print('| WARNING: ran out of memory, retrying batch')
                for p in self.model.parameters():
                    if p.grad is not None:
                        del p.grad  # free some memory
                torch.cuda.empty_cache()
                return self.valid_step(sample, raise_oom=True)
            else:
                raise e

        if ignore_results:
            logging_output, sample_size = {}, 0

        # gather logging outputs from all replicas
        if self.args.distributed_world_size > 1:
            logging_output, sample_size = zip(
                *distributed_utils.all_gather_list(
                    [logging_output, sample_size],
                ))
            logging_output = list(logging_output)
            sample_size = list(sample_size)
        else:
            logging_output = [logging_output]
            sample_size = [sample_size]

        # extra hacky!
        if hasattr(self.task, 'aggregate_extra_metrics'):
            extra_metrics = self.task.aggregate_extra_metrics(logging_output)
        else:
            extra_metrics = None

        # aggregate logging outputs and sample sizes
        logging_output = self.criterion._aggregate_logging_outputs(logging_output)
        sample_size = self.criterion.__class__.grad_denom(sample_size)

        if extra_metrics is not None:
            logging_output['extra_metrics'] = extra_metrics

        # update meters for validation
        ntokens = logging_output.get('ntokens', 0)
        self.meters['valid_loss'].update(logging_output.get('loss', 0), sample_size)
        if 'nll_loss' in logging_output:
            self.meters['valid_nll_loss'].update(
                logging_output.get('nll_loss', 0), ntokens)

        if 'extra_metrics' in logging_output:
            for n, m in self.meters['task'].items():
                m.update(*logging_output['extra_metrics'][n])

    return logging_output
def train_step(self, samples, dummy_batch=False, raise_oom=False):
    """Do forward, backward and parameter update."""
    if self._dummy_batch is None:
        self._dummy_batch = samples[0]

    self._set_seed()
    self.model.train()
    self.criterion.train()
    self.zero_grad()

    if not dummy_batch:
        self.meters['train_wall'].start()

    # forward and backward pass
    logging_outputs, sample_sizes, ooms = [], [], 0
    for i, sample in enumerate(samples):
        sample = self._prepare_sample(sample)
        if sample is None:
            # when sample is None, run forward/backward on a dummy batch
            # and ignore the resulting gradients
            sample = self._prepare_sample(self._dummy_batch)
            ignore_grad = True
        else:
            ignore_grad = False

        def maybe_no_sync():
            """
            Whenever *samples* contains more than one mini-batch, we
            want to accumulate gradients locally and only call
            all-reduce in the last backwards pass.
            """
            if (
                self.args.distributed_world_size > 1
                and hasattr(self.model, 'no_sync')
                and i < len(samples) - 1
            ):
                return self.model.no_sync()
            else:
                return contextlib.ExitStack()  # dummy contextmanager

        try:
            with maybe_no_sync():
                # forward and backward
                loss, sample_size, logging_output = self.task.train_step(
                    sample, self.model, self.criterion, self.optimizer,
                    ignore_grad
                )

            if not ignore_grad:
                logging_outputs.append(logging_output)
                sample_sizes.append(sample_size)

                if self.fast_stat_sync:
                    self._all_reduce_list[0] += sample_size
                    self._all_reduce_list[1] += logging_output.get('nsentences', 0.0)
                    self._all_reduce_list[2] += logging_output.get('loss', 0.0)
                    self._all_reduce_list[3] += logging_output.get('nll_loss', 0.0)
                    self._all_reduce_list[4] += logging_output.get('ntokens', 0.0)
        except RuntimeError as e:
            if 'out of memory' in str(e):
                msg = (
                    '| WARNING: ran out of memory with exception: '
                    + '{};'.format(e)
                    + '\n Skipping batch'
                )
                # TODO: print should really go to logger; this print goes
                # to stdout, which is buffered, and in many cases is not
                # printed out if another exception happens
                # print(msg)
                print(msg, file=sys.stderr)
                if raise_oom:
                    raise ValueError(msg)
                ooms += 1
                self.zero_grad()
            else:
                raise e

    if self.fast_stat_sync:
        self._all_reduce_list[5] += ooms

    if ooms > 0 and self._oom_batch is not None:
        self.handle_ooms(ooms)

    if dummy_batch:
        return None

    # gather logging outputs from all replicas
    if self.fast_stat_sync:
        # rework all_gather_list
        all_reduce_list_tensor = torch.cuda.DoubleTensor(self._all_reduce_list)
        if self._sync_stats():
            torch.distributed.all_reduce(all_reduce_list_tensor)
        # Normalize loss and nll_loss by "sample_size"
        # and convert to log base 2
        all_reduce_list_tensor[2:4].div_(
            (
                all_reduce_list_tensor[0:1]
                * torch.log(torch.cuda.DoubleTensor([2]))
            )
        )
        self._all_reduce_list = all_reduce_list_tensor.tolist()
        logging_output = {}
        [
            sample_size,
            logging_output['nsentences'],
            logging_output['loss'],
            logging_output['nll_loss'],
            logging_output['ntokens'],
            ooms,
        ] = self._all_reduce_list
    elif self._sync_stats():
        logging_outputs, sample_sizes, ooms, prev_norms = \
            zip(*distributed_utils.all_gather_list(
                [logging_outputs, sample_sizes, ooms, self._prev_grad_norm],
            ))
        logging_outputs = list(chain.from_iterable(logging_outputs))
        sample_sizes = list(chain.from_iterable(sample_sizes))
        ooms = sum(ooms)

        if not self.args.use_bmuf:
            assert (
                all(norm == prev_norms[0] for norm in prev_norms)
                or all(math.isnan(norm) or math.isinf(norm) for norm in prev_norms)
            ), 'Fatal error: gradients are inconsistent between workers'

    self.meters['oom'].update(ooms, len(samples))
    if ooms == self.args.distributed_world_size * len(samples):
        print('| WARNING: OOM in all workers, skipping update')
        self.zero_grad()
        return None

    if not self.fast_stat_sync:
        # aggregate logging outputs and sample sizes
        logging_output = self.task.aggregate_logging_outputs(
            logging_outputs, self.get_criterion()
        )
        sample_size = self.task.grad_denom(sample_sizes, self.get_criterion())

    if not all(k in logging_output for k in ['ntokens', 'nsentences']):
        raise Exception((
            'Please update the {}.aggregate_logging_outputs() method to '
            'return ntokens and nsentences'
        ).format(self.task.__class__.__name__))

    try:
        # normalize grads by sample size
        if sample_size > 0:
            self.optimizer.multiply_grads(
                self.args.distributed_world_size / float(sample_size))

        # clip grads
        grad_norm = self.optimizer.clip_grad_norm(self.args.clip_norm)
        self._prev_grad_norm = grad_norm

        # take an optimization step
        self.optimizer.step()
        self.set_num_updates(self.get_num_updates() + 1)

        # task specific update per step
        self.task.update_step(self._num_updates)

        # update meters
        ntokens = logging_output.get('ntokens', 0)
        nsentences = logging_output.get('nsentences', 0)
        self.meters['wps'].update(ntokens)
        self.meters['ups'].update(1.)
        self.meters['wpb'].update(ntokens)
        self.meters['bsz'].update(nsentences)
        self.meters['gnorm'].update(grad_norm)
        self.meters['clip'].update(
            1. if grad_norm > self.args.clip_norm and self.args.clip_norm > 0 else 0.
        )
        self.meters['train_loss'].update(logging_output.get('loss', 0), sample_size)
        if 'train_acc' in self.meters:
            self.meters['train_acc'].update(
                logging_output.get('acc', 0), sample_size)
        if 'nll_loss' in logging_output:
            self.meters['train_nll_loss'].update(
                logging_output.get('nll_loss', 0), ntokens)
    except OverflowError as e:
        print('| WARNING: overflow detected, ' + str(e))
        self.zero_grad()
        logging_output = None

    if self.args.fp16:
        self.meters['loss_scale'].reset()
        self.meters['loss_scale'].update(self.optimizer.scaler.loss_scale)

    self.clear_buffered_stats()
    self.meters['train_wall'].stop()

    return logging_output
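# A minimal sketch of the "fast stat sync" idea used above: instead of pickling
# Python objects through all_gather_list, pack the scalar statistics into a
# single tensor and issue one all_reduce. This assumes a process group has
# already been initialized when running distributed; the helper name and
# argument list below are illustrative, not fairseq APIs.
import torch
import torch.distributed as dist


def fast_stat_sync(sample_size, nsentences, loss, nll_loss, ntokens, ooms):
    stats = torch.tensor(
        [sample_size, nsentences, loss, nll_loss, ntokens, ooms],
        dtype=torch.double,
        device='cuda' if torch.cuda.is_available() else 'cpu',
    )
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(stats)  # element-wise sum across workers
    return stats.tolist()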
def _main(args, output_file):
    logging.basicConfig(
        format='%(asctime)s | %(levelname)s | %(name)s | %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        level=logging.INFO,
        stream=output_file,
    )
    logger = logging.getLogger('fairnr_cli.render')

    utils.import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    logger.info(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)

    # Load ensemble
    logger.info('loading model(s) from {}'.format(args.path))
    models, _model_args = checkpoint_utils.load_model_ensemble(
        args.path.split(os.pathsep),
        arg_overrides=eval(args.model_overrides),
        task=task,
    )

    # Optimize ensemble for generation
    for model in models:
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    logger.info(model)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        seed=args.seed,
        num_workers=args.num_workers).next_epoch_itr(shuffle=False)

    # Initialize generator
    gen_timer = StopwatchMeter()
    generator = task.build_generator(args)
    shard_id, world_size = args.distributed_rank, args.distributed_world_size
    output_files = []
    if generator.test_poses is not None:
        total_frames = generator.test_poses.shape[0]
        _frames = int(np.floor(total_frames / world_size))
        step = shard_id * _frames
        frames = _frames if shard_id < (world_size - 1) else total_frames - step
    else:
        step = shard_id * args.render_num_frames
        frames = args.render_num_frames

    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        for i, sample in enumerate(t):
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            gen_timer.start()

            step, _output_files = task.inference_step(
                generator, models, [sample, step, frames])
            output_files += _output_files

            gen_timer.stop(500)
            wps_meter.update(500)
            t.log({'wps': round(wps_meter.avg)})

    timestamp = generator.save_images(
        output_files, steps='shard{}'.format(shard_id),
        combine_output=args.render_combine_output)

    # join videos from all GPUs and delete temp files
    try:
        timestamps = distributed_utils.all_gather_list(timestamp)
    except Exception:
        timestamps = [timestamp]

    if shard_id == 0:
        generator.merge_videos(timestamps)
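# The render loop above splits total_frames across workers: every shard gets
# floor(total_frames / world_size) frames starting at shard_id * that count,
# and the last shard absorbs the remainder. A small standalone sketch of the
# same arithmetic (the function name shard_frame_range is ours, not part of
# fairnr):
def shard_frame_range(total_frames: int, shard_id: int, world_size: int):
    per_shard = total_frames // world_size
    start = shard_id * per_shard
    count = per_shard if shard_id < world_size - 1 else total_frames - start
    return start, count


# e.g. 10 frames over 3 workers -> (0, 3), (3, 3), (6, 4)
assert [shard_frame_range(10, s, 3) for s in range(3)] == [(0, 3), (3, 3), (6, 4)]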