def _check_grad_norms(self, grad_norm):
    """Check that grad norms are consistent across workers."""
    if self._grad_norm_buf is not None:
        self._grad_norm_buf.zero_()
        self._grad_norm_buf[self.data_parallel_rank] = grad_norm
        distributed_utils.all_reduce(
            self._grad_norm_buf, group=self.data_parallel_process_group)

        def is_consistent(tensor):
            max_abs_diff = torch.max(torch.abs(tensor - tensor[0]))
            # require both finiteness and agreement (within relative tolerance)
            # with the rank-0 value; "or" would pass any all-finite buffer
            return (torch.isfinite(tensor).all()
                    and (max_abs_diff / (tensor[0] + 1e-6) < 1e-6).all())

        if not is_consistent(self._grad_norm_buf):
            pretty_detail = "\n".join(
                "rank {:3d} = {:.8f}".format(r, n)
                for r, n in enumerate(self._grad_norm_buf.tolist()))
            error_detail = "grad_norm across the workers:\n{}\n".format(
                pretty_detail)
            # use FloatingPointError to trigger NanDetector
            raise FloatingPointError(
                "Fatal error: gradients are inconsistent between workers. "
                "Try --ddp-backend=no_c10d. "
                "Or are you mixing up different generations of GPUs in training?"
                + "\n" + "-" * 80 + "\n{}\n".format(error_detail) + "-" * 80)
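# A minimal, hypothetical sketch of the relative-tolerance test used by
# is_consistent() above, runnable on a single machine with toy "per-rank"
# grad-norm buffers; it is not part of the trainer and assumes nothing
# beyond torch itself.
import torch

def _toy_is_consistent(tensor):
    # Consistent when every entry is finite and agrees with the rank-0 value
    # to roughly 1e-6 relative tolerance.
    max_abs_diff = torch.max(torch.abs(tensor - tensor[0]))
    return bool(torch.isfinite(tensor).all()
                and (max_abs_diff / (tensor[0] + 1e-6) < 1e-6).all())

print(_toy_is_consistent(torch.tensor([3.141592, 3.141592, 3.141592])))    # True
print(_toy_is_consistent(torch.tensor([3.141592, 3.141592, 2.718281])))    # False
print(_toy_is_consistent(torch.tensor([3.141592, float('inf'), 3.141592])))  # False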
def _check_grad_norms(self, grad_norm):
    """Check that grad norms are consistent across workers."""
    if self._grad_norm_buf is not None:
        self._grad_norm_buf.zero_()
        self._grad_norm_buf[self.args.distributed_rank] = grad_norm
        distributed_utils.all_reduce(self._grad_norm_buf)
        if not (self._grad_norm_buf == self._grad_norm_buf[0]).all():
            raise RuntimeError(
                "Fatal error: gradients are inconsistent between workers. "
                "Try --ddp-backend=no_c10d.")
def _reduce(input_):
    """All-reduce the input tensor across the model parallel group."""
    group = get_model_parallel_group()

    # Bypass the function if we are using only 1 GPU.
    if get_world_size(group=group) == 1:
        return input_

    # All-reduce.
    all_reduce(input_, group=group)

    return input_
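# A self-contained, hypothetical sketch of the same "bypass on a single GPU,
# otherwise all-reduce in place" pattern, using torch.distributed directly with
# a one-process gloo group so it runs locally; the names below are illustrative
# and not part of the model-parallel utilities above.
import os
import torch
import torch.distributed as dist

def reduce_in_place(tensor, group=None):
    # No other ranks to talk to: skip the collective entirely.
    if dist.get_world_size(group=group) == 1:
        return tensor
    # Sum the tensor element-wise across all ranks, writing the result in place.
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group)
    return tensor

if __name__ == '__main__':
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29500')
    dist.init_process_group('gloo', rank=0, world_size=1)
    print(reduce_in_place(torch.ones(4)))  # unchanged with a single rank
    dist.destroy_process_group()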
def _sync_sample_ratios(self, ratios):
    # In case the ratios are not precisely the same across processes,
    # and to ensure every process updates the ratios at the same pace.
    ratios = torch.DoubleTensor(ratios)
    if torch.distributed.is_initialized():
        if torch.cuda.is_available():
            # .cuda() returns a copy, so move the tensor first and reduce the
            # tensor we actually keep; otherwise the result would be discarded.
            ratios = ratios.cuda()
            distributed_utils.all_reduce(ratios)
        else:
            distributed_utils.all_reduce(ratios)
    ret = ratios.cpu()
    ret = ret.numpy()
    return ret
def _fast_stat_sync_sum(
    self,
    logging_outputs: List[Dict[str, Any]],
    *extra_stats_to_sum,
    min_buffer_size: int = 50,
):
    """
    Sync logging outputs across workers. fast_stat_sync_sum is
    faster than all_gather_list_sync, but is only suitable when
    logging outputs are scalars and can be summed.
    """
    num_extra = len(extra_stats_to_sum)
    if len(logging_outputs) > 0:
        sorted_keys = sorted(logging_outputs[0].keys())
        stats = [0.] + list(extra_stats_to_sum) + [
            sum(log.get(k, 0) for log in logging_outputs)
            for k in sorted_keys
        ]
        stats = stats + [0.] * (min_buffer_size - len(stats))
        buf = torch.cuda.DoubleTensor(stats)
    else:
        buf = torch.zeros(min_buffer_size, dtype=torch.double, device='cuda')
        buf[0] = 1.  # flag to indicate we should fallback to _all_gather_list_sync

    # stats buffer is organized like:
    #   0: flag to indicate whether fast-stat-sync should be disabled
    #   1-i: extra_stats_to_sum
    #   i-j: values from logging_outputs (sorted by key)
    #   j-min_buffer_size: padded with 0s
    distributed_utils.all_reduce(buf)

    buf = buf.tolist()
    fallback = buf[0]
    if fallback > 0.:
        # fallback to _all_gather_list_sync
        return self._all_gather_list_sync(logging_outputs, *extra_stats_to_sum)
    else:
        extra_stats_to_sum, stats = buf[1:num_extra + 1], buf[num_extra + 1:]
        stats = [{k: stats[i] for i, k in enumerate(sorted_keys)}]
        return [stats] + extra_stats_to_sum
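# A hypothetical single-process sketch of the buffer layout described above:
# [flag, extra stats..., per-key sums..., zero padding]. Summing these buffers
# across ranks (simulated here with a plain tensor sum instead of all_reduce)
# yields the aggregated stats, and any rank with empty logging outputs raises
# the fallback flag. The helper names are illustrative only.
import torch

def pack(logging_outputs, extra, size=10):
    if logging_outputs:
        keys = sorted(logging_outputs[0].keys())
        vals = [0.] + list(extra) + [
            sum(log.get(k, 0) for log in logging_outputs) for k in keys]
        vals += [0.] * (size - len(vals))
        return torch.tensor(vals, dtype=torch.double), keys
    buf = torch.zeros(size, dtype=torch.double)
    buf[0] = 1.  # signal that this rank had nothing to sum
    return buf, []

rank0, keys = pack([{'loss': 1.5, 'ntokens': 10}], extra=[3])
rank1, _ = pack([{'loss': 2.5, 'ntokens': 20}], extra=[4])
total = rank0 + rank1  # stands in for distributed_utils.all_reduce
# flag stays 0, extras occupy index 1, per-key sums follow
print(total[0].item() == 0, dict(zip(keys, total[2:2 + len(keys)].tolist())))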
def _check_grad_norms(self, grad_norm):
    """Check that grad norms are consistent across workers."""
    if self._grad_norm_buf is not None:
        self._grad_norm_buf.zero_()
        self._grad_norm_buf[self.data_parallel_rank] = grad_norm
        distributed_utils.all_reduce(
            self._grad_norm_buf, group=self.data_parallel_process_group)
        if not self._is_grad_norms_consistent(self._grad_norm_buf):
            pretty_detail = "\n".join(
                "rank {:3d} = {:.8f}".format(r, n)
                for r, n in enumerate(self._grad_norm_buf.tolist()))
            error_detail = "grad_norm across the workers:\n{}\n".format(
                pretty_detail)
            raise RuntimeError(
                "Fatal error: gradients are inconsistent between workers. "
                "Try --ddp-backend=no_c10d. "
                "Or are you mixing up different generations of GPUs in training?"
                + "\n" + "-" * 80 + "\n{}\n".format(error_detail) + "-" * 80)
def _sync_sample_ratios(self, ratios):
    # in case the ratios are not precisely the same across processes
    # also to ensure every process updates the ratios at the same pace
    # fixme: this has a bug only on tir?
    # ratios = torch.DoubleTensor(ratios)
    if torch.distributed.is_initialized():
        if torch.cuda.is_available():
            # .cuda() returns a copy for CPU tensors, so reduce the tensor we keep.
            ratios = ratios.cuda()
            distributed_utils.all_reduce(ratios)
        else:
            distributed_utils.all_reduce(ratios)
        ret = ratios.cpu()
        ret = ret.numpy()
        # Ad-hoc FIX!
        if self.remapped_lang_ids is not None:
            ret = ret[self.remapped_lang_ids]
    else:
        ret = ratios.cpu()
        ret = ret.numpy()
        # Ad-hoc FIX!
        if self.remapped_lang_ids is not None:
            ret = ret[self.remapped_lang_ids]
    return ret
def _fast_stat_sync_sum(self, logging_outputs: List[Dict[str, Any]],
                        *extra_stats_to_sum):
    """
    Sync logging outputs across workers. fast_stat_sync_sum is
    faster than all_gather_list_sync, but is only suitable when
    logging outputs are scalars and can be summed.
    """
    num_extra = len(extra_stats_to_sum)
    if len(logging_outputs) > 0:
        sorted_keys = sorted(logging_outputs[0].keys())
        stats = list(extra_stats_to_sum) + [
            sum(log.get(k, 0) for log in logging_outputs)
            for k in sorted_keys
        ]
        buf = torch.cuda.DoubleTensor(stats)
        # When the number of batches is not evenly divisible by the
        # number of GPUs, logging_outputs will be empty for some
        # workers in the last iteration. But we still need to know
        # the keys and buffer size, so we cache the state in case it
        # needs to be reused by this worker later.
        self._fss_buf = buf
        self._fss_sorted_keys = sorted_keys
    elif self._fss_buf is not None:
        buf = self._fss_buf
        buf.zero_()
        buf[:num_extra] = torch.cuda.DoubleTensor(extra_stats_to_sum)
        sorted_keys = self._fss_sorted_keys
    else:
        raise RuntimeError(
            'fast_stat_sync failed, perhaps (# GPUs) > (# batches)?')

    distributed_utils.all_reduce(buf)

    buf = buf.tolist()
    extra_stats_to_sum, stats = buf[:num_extra], buf[num_extra:]
    stats = [{k: stats[i] for i, k in enumerate(sorted_keys)}]
    return [stats] + extra_stats_to_sum
def forward(ctx, vocab_parallel_logits, target):
    # Copy so the input remains unchanged.
    logits = vocab_parallel_logits.clone()
    # Maximum value along vocab dimension across all GPUs.
    logits_max = torch.max(logits, dim=-1)[0]
    all_reduce(logits_max, op='max', group=get_model_parallel_group())
    # Subtract the maximum value.
    logits.sub_(logits_max.unsqueeze(dim=-1))
    # Sum of exponential of logits along vocab dimension across all GPUs.
    exp_logits = logits.exp()
    sum_exp_logits = exp_logits.sum(dim=-1)
    all_reduce(sum_exp_logits, op='sum', group=get_model_parallel_group())

    # Get the partition's vocab indices.
    get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size
    partition_vocab_size = vocab_parallel_logits.size()[-1]
    rank = get_model_parallel_rank()
    world_size = get_model_parallel_world_size()
    vocab_start_index, vocab_end_index = get_vocab_range(
        partition_vocab_size, rank, world_size)

    # Create a mask of valid vocab ids (1 means it needs to be masked).
    target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
    masked_target = target.clone() - vocab_start_index
    masked_target[target_mask] = 0

    # Get predicted-logits = logits[target].
    # For simplicity, we convert logits to a 2-D tensor with size
    # [*, partition-vocab-size] and target to a 1-D tensor of size [*].
    logits_2d = logits.view(-1, partition_vocab_size)
    masked_target_1d = masked_target.view(-1)
    arange_1d = torch.arange(start=0, end=logits_2d.size()[0],
                             device=logits_2d.device)
    predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
    predicted_logits = predicted_logits_1d.view_as(target)
    predicted_logits[target_mask] = 0.0
    # All reduce is needed to get the chunks from other GPUs.
    all_reduce(predicted_logits, op='sum', group=get_model_parallel_group())

    # Loss = log(sum(exp(logits))) - predicted-logit.
    loss = torch.log(sum_exp_logits) - predicted_logits

    # Store softmax, target-mask and masked-target for backward pass.
    exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
    ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)

    return loss
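# A hypothetical single-GPU reference for the loss computed above: with the full
# (non-partitioned) vocabulary, log(sum(exp(logits))) - logits[target], computed
# with the same max-subtraction for numerical stability, matches
# torch.nn.functional.cross_entropy with reduction='none'. Shapes and values
# below are toy examples, not part of the parallel implementation.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(3, 7)          # [tokens, vocab]
target = torch.randint(0, 7, (3,))  # [tokens]

logits_max = logits.max(dim=-1, keepdim=True)[0]
shifted = logits - logits_max       # subtract max for numerical stability
sum_exp = shifted.exp().sum(dim=-1)
predicted = shifted[torch.arange(3), target]
loss = sum_exp.log() - predicted

assert torch.allclose(loss, F.cross_entropy(logits, target, reduction='none'))
print(loss)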
def validate(args, trainer, task, epoch_itr, subsets, test_bleu=False,
             summary_writer=None):
    """Evaluate the model on the validation set(s) and return the losses."""
    valid_losses = []
    distributed_utils.barrier(args, "validate1_%d" % trainer.get_num_updates())
    for subset in subsets:
        # Initialize data iterator
        def get_itr():
            itr = task.get_batch_iterator(
                dataset=task.dataset(subset),
                max_tokens=args.max_tokens,
                max_sentences=args.max_sentences_valid,
                max_positions=utils.resolve_max_positions(
                    task.max_positions(),
                    trainer.get_model().max_positions(),
                ),
                ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
                required_batch_size_multiple=8,
                seed=args.seed,
                num_shards=args.distributed_world_size,
                shard_id=args.distributed_rank,
            ).next_epoch_itr(shuffle=False)
            progress = progress_bar.build_progress_bar(
                args, itr, epoch_itr.epoch,
                prefix='valid on \'{}\' subset'.format(subset),
                no_progress_bar='simple')
            return progress

        progress = get_itr()
        num_dataset = task.dataset(subset).num_dataset

        # reset validation loss meters
        for k in ['valid_loss', 'valid_nll_loss']:
            meter = trainer.get_meter(k)
            if meter is not None:
                meter.reset()
        extra_meters = collections.defaultdict(lambda: AverageMeter())

        for sample in progress:
            log_output = trainer.valid_step(sample)
            for k, v in log_output.items():
                if k in ['loss', 'nll_loss', 'ntokens', 'nsentences',
                         'sample_size']:
                    continue
                extra_meters[k].update(v)

        bleu_scorers = [
            bleu.Scorer(task.target_dictionary.pad(),
                        task.target_dictionary.eos(),
                        task.target_dictionary.unk())
            for _ in range(num_dataset)
        ] if test_bleu else None

        # log validation stats
        stats = get_valid_stats(trainer)
        for k, meter in extra_meters.items():
            stats[k] = meter.avg

        if bleu_scorers is not None:
            # test bleu
            print("| test bleu.")
            sample_size = [0 for _ in range(num_dataset)]
            bleu_scores = [0 for _ in range(num_dataset)]
            progress = get_itr()

            tgt_str_files = []
            hypo_str_files = []
            for ds_id in range(num_dataset):
                tgt_str_path = task.dataset(subset).dataset_names[ds_id] + '.tgt.txt'
                hypo_str_path = task.dataset(subset).dataset_names[ds_id] + '.hypo.txt'
                tgt_str_files.append(
                    open(os.path.join(args.save_dir, tgt_str_path), 'w',
                         encoding='utf-8'))
                hypo_str_files.append(
                    open(os.path.join(args.save_dir, hypo_str_path), 'w',
                         encoding='utf-8'))

            def print_to_file(dataset_id, tgt_str, hypo_str):
                tgt_str_files[dataset_id].write(tgt_str + '\n')
                hypo_str_files[dataset_id].write(hypo_str + '\n')

            for sample in progress:
                trainer.test_bleu_step(sample, bleu_scorers, print_to_file)
                if 'dataset_id' in sample:
                    for ds_id in range(num_dataset):
                        sample_size[ds_id] += (
                            sample['dataset_id'] == ds_id).int().sum().item()
                elif 'id' in sample:
                    sample_size[0] += len(sample['id'])

            for f in tgt_str_files + hypo_str_files:
                f.close()

            distributed_utils.barrier(
                args, "validate2_%d" % trainer.get_num_updates())

            for ds_id in range(num_dataset):
                try:
                    bleu_scores[ds_id] = bleu_scorers[ds_id].score() * sample_size[ds_id]
                except Exception:
                    bleu_scores[ds_id] = 0

            sample_size = torch.Tensor(sample_size).cuda()
            bleu_scores = torch.Tensor(bleu_scores).cuda()
            if args.distributed_world_size > 1:
                all_reduce(sample_size)
                all_reduce(bleu_scores)

            bleu_dict = {}
            for ds_id in range(num_dataset):
                if sample_size[ds_id].item() > 0:
                    name = "bleu_" + task.dataset(subset).dataset_names[ds_id]
                    bleu_dict[name] = stats[name] = (
                        bleu_scores[ds_id].item() / sample_size[ds_id].item())
                    try:
                        train_ds_id = task.dataset('train').dataset_names.index(
                            task.dataset(subset).dataset_names[ds_id])
                        task.dataset('train').student_scores[train_ds_id] = bleu_dict[name]
                    except ValueError:
                        pass

            output_path = os.path.join(args.save_dir, 'val_bleu.json')
            json.dump(bleu_dict, open(output_path, 'w'))

        progress.print(stats)
        if summary_writer is not None:
            summary_writer.log_stats('val/' + subset, stats,
                                     trainer.get_num_updates())

        valid_losses.append(stats['valid_loss'])
    return valid_losses
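# A hypothetical local illustration of the weighted aggregation performed in
# validate() above: each worker contributes score * sample_size and sample_size,
# both are summed across workers (simulated here with a plain sum instead of
# all_reduce), and the corpus-level score is the ratio of the two sums.
import torch

worker_scores = torch.tensor([30.0, 20.0])   # per-worker BLEU
worker_sizes = torch.tensor([100.0, 300.0])  # per-worker sentence counts

weighted = worker_scores * worker_sizes      # what each rank puts in its buffer
total_weighted = weighted.sum()              # stands in for all_reduce(bleu_scores)
total_size = worker_sizes.sum()              # stands in for all_reduce(sample_size)

print((total_weighted / total_size).item())  # 22.5, the size-weighted average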
def _aggregate_model_parallel_grad_norm(total_norm):
    total_norm = total_norm ** 2
    distributed_utils.all_reduce(total_norm,
                                 group=get_model_parallel_group())
    total_norm = total_norm ** 0.5
    return total_norm
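# A hypothetical local check of the aggregation rule used above: the global L2
# norm of a gradient sharded across model-parallel ranks equals the square root
# of the sum of squared per-shard norms, which is what all-reducing
# total_norm ** 2 and taking the square root computes. The sharding below is
# simulated on one device.
import torch

full_grad = torch.randn(8, 16)
shards = full_grad.chunk(4, dim=1)           # pretend each rank holds one shard

per_shard_sq = torch.stack([s.norm() ** 2 for s in shards])
aggregated = per_shard_sq.sum() ** 0.5       # stands in for the all_reduce sum

assert torch.allclose(aggregated, full_grad.norm())
print(aggregated)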
def train_step(self, samples, dummy_batch=False, assistant=None,
               assistant_queue=None, weights=None):
    """Do forward, backward and parameter update."""
    # Set seed based on args.seed and the update number so that we get
    # reproducible results when resuming from checkpoints
    seed = self.args.seed + self.get_num_updates()
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    self.model.train()
    self.zero_grad()

    if not dummy_batch:
        self.meters['train_wall'].start()

    # forward and backward pass
    logging_outputs, sample_sizes, ooms = [], [], 0
    for i, sample in enumerate(samples):
        sample = self._prepare_sample(sample)
        if sample is None:
            # when sample is None, run forward/backward on a dummy batch
            # and ignore the resulting gradients
            sample = self._prepare_sample(self._dummy_batch)
            ignore_grad = True
        else:
            ignore_grad = False

        try:
            if self.args.distributed_world_size > 1:
                # Whenever *samples* contains more than one mini-batch, we
                # want to accumulate gradients locally and only call
                # all-reduce in the last backwards pass. Currently the
                # *need_reduction* flag is only supported by
                # LegacyDistributedDataParallel.
                if i < len(samples) - 1:
                    self.model.accumulate_grads = True
                else:
                    self.model.accumulate_grads = False

            # forward and backward
            if self.args.assistant:
                losses, sample_size, logging_output, precisions = self.task.train_step(
                    sample, self.model, self.criterion, self.optimizer,
                    ignore_grad)
            elif self.args.spl:
                losses, sample_size, logging_output, precisions = self.task.train_step(
                    sample, self.model, self.criterion, self.optimizer,
                    ignore_grad, lambda_t=self.lambda_t)
            else:
                losses, sample_size, logging_output = self.task.train_step(
                    sample, self.model, self.criterion, self.optimizer,
                    ignore_grad)

            # record new losses
            if self.args.spl and not dummy_batch:
                y_lengths = utils.get_len(
                    sample['target'].cpu().numpy(),
                    self.task.target_dictionary.pad())
                norm_losses = np.divide(losses.detach().cpu().numpy(),
                                        y_lengths)
                self.loss_chart[sample['id'].cpu().numpy()] = torch.from_numpy(
                    norm_losses).type(torch.FloatTensor).cuda()
                if self.args.distributed_world_size > 1:
                    all_reduce(self.loss_chart, op=MIN_OP)

            # prepare data for assistant training
            if assistant is not None and np.random.rand() < SEC_TRAIN_RATIO:
                sec_batch_size = sample['id'].size(0)
                indices_sec = np.random.choice(sample['id'].size(0),
                                               sec_batch_size)
                x = sample['net_input']['src_tokens'][indices_sec]
                y = sample['target'][indices_sec]
                l = losses[indices_sec]
                x = x.cpu().numpy()
                y = y.cpu().numpy()
                l = l.detach().cpu().numpy()
                keep_probs = assistant.train_step(x, y, l)
            elif assistant_queue is not None and np.random.rand() < SEC_TRAIN_RATIO:
                sec_batch_size = sample['id'].size(0)
                local_indices_sec = np.random.choice(sample['id'].size(0),
                                                     sec_batch_size)
                global_indices_sec = sample['id'][local_indices_sec].cpu().numpy()
                l = losses[local_indices_sec]
                l = l.detach().cpu().numpy()
                if not assistant_queue.full():
                    assistant_queue.put((global_indices_sec, l), block=False)
                else:
                    _ = assistant_queue.get()
                    assistant_queue.put((global_indices_sec, l), block=False)

            if not ignore_grad:
                logging_outputs.append(logging_output)
                sample_sizes.append(sample_size)
        except RuntimeError as e:
            if 'out of memory' in str(e):
                print('| WARNING: ran out of memory, skipping batch')
                ooms += 1
                self.zero_grad()
            else:
                print(sample, flush=True, force=True)
                raise e

    if dummy_batch:
        return None

    # gather logging outputs from all replicas
    if self.args.distributed_world_size > 1:
        logging_outputs, sample_sizes, ooms = zip(
            *distributed_utils.all_gather_list(
                [logging_outputs, sample_sizes, ooms]))
        logging_outputs = list(chain.from_iterable(logging_outputs))
        sample_sizes = list(chain.from_iterable(sample_sizes))
        ooms = sum(ooms)

    if ooms == self.args.distributed_world_size * len(samples):
        print('| WARNING: OOM in all workers, skipping update')
        self.zero_grad()
        return None

    # aggregate logging outputs and sample sizes
    logging_output = self.task.aggregate_logging_outputs(
        logging_outputs, self.criterion)
    sample_size = self.task.grad_denom(sample_sizes, self.criterion)

    if not all(k in logging_output for k in ['ntokens', 'nsentences']):
        raise Exception(
            ('Please update the {}.aggregate_logging_outputs() method to '
             'return ntokens and nsentences').format(
                 self.task.__class__.__name__))

    try:
        # normalize grads by sample size
        self.optimizer.multiply_grads(self.args.distributed_world_size /
                                      float(sample_size))

        # clip grads
        grad_norm = self.optimizer.clip_grad_norm(self.args.clip_norm)

        # take an optimization step
        self.optimizer.step()
        self._num_updates += 1

        # update learning rate
        self.lr_scheduler.step_update(self._num_updates)

        # update meters
        ntokens = logging_output.get('ntokens', 0)
        nsentences = logging_output.get('nsentences', 0)
        self.meters['wps'].update(ntokens)
        self.meters['ups'].update(1.)
        self.meters['wpb'].update(ntokens)
        self.meters['bsz'].update(nsentences)
        self.meters['gnorm'].update(grad_norm)
        self.meters['clip'].update(1. if grad_norm > self.args.clip_norm and
                                   self.args.clip_norm > 0 else 0.)
        self.meters['oom'].update(ooms)
        self.meters['train_loss'].update(logging_output.get('loss', 0),
                                         sample_size)
        if 'nll_loss' in logging_output:
            self.meters['train_nll_loss'].update(
                logging_output.get('nll_loss', 0), ntokens)
    except OverflowError as e:
        print('| WARNING: overflow detected, ' + str(e))
        self.zero_grad()
        logging_output = None

    if self.args.fp16:
        self.meters['loss_scale'].reset()
        self.meters['loss_scale'].update(self.optimizer.scaler.loss_scale)

    self.meters['train_wall'].stop()

    return logging_output
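# A hypothetical numeric check of the gradient scaling used in train_step()
# above: DDP leaves the average of per-worker gradient sums in .grad, so
# multiplying by world_size / sample_size recovers the sum of per-token
# gradients normalized by the total sample size (sample_size here is assumed
# to be the all-worker total returned by grad_denom). Values are toy numbers.
world_size = 4
per_worker_grad_sums = [8.0, 10.0, 6.0, 12.0]  # sum of per-token grads on each rank
sample_size = 72.0                             # total tokens across all ranks

ddp_averaged = sum(per_worker_grad_sums) / world_size    # what DDP leaves in .grad
rescaled = ddp_averaged * (world_size / sample_size)     # multiply_grads(...)
reference = sum(per_worker_grad_sums) / sample_size      # desired per-token average

assert abs(rescaled - reference) < 1e-12
print(rescaled)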