import torch

from fairseq import distributed_utils, utils

try:
    # Model-parallel helpers used below; the Megatron submodule is optional,
    # so guard the import (an assumed import path for this excerpt).
    from fairseq.model_parallel.megatron.mpu import (
        VocabUtility,
        get_model_parallel_rank,
        get_model_parallel_world_size,
    )
except ImportError:
    pass  # Megatron unavailable; the model-parallel code paths go unused.

# Assumed to be flipped to True elsewhere (e.g. during distributed init)
# when model parallelism is enabled.
_USE_MEGATRON = False


def get_model_parallel_group():
    global _USE_MEGATRON
    if _USE_MEGATRON:
        from fairseq.model_parallel.megatron import mpu

        return mpu.get_model_parallel_group()
    else:
        # None makes torch.distributed collectives use the default group.
        return None
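# Usage sketch (hypothetical tensor; assumes torch.distributed has been
# initialized): collectives accept the group handle directly, and group=None
# falls back to the default (world) process group.
#
#   grad_sq = torch.tensor([9.0])
#   torch.distributed.all_reduce(grad_sq, group=get_model_parallel_group())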
def get_normalized_probs(
    self,
    net_output,
    log_probs,
    sample,
):
    """Get normalized probabilities (or log probs) from a net's output."""
    logits = net_output[0]
    vocab_size = len(self.decoder.dictionary)
    if logits.size(-1) == vocab_size:
        # we have the full set of logits
        return super().get_normalized_probs(net_output, log_probs, sample)

    # else: vocab-parallel logits, need to combine them
    assert logits.dim() == 3

    # Get the partition's vocab indices
    get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size
    partition_vocab_size = logits.size(-1)
    rank = get_model_parallel_rank()
    world_size = get_model_parallel_world_size()
    vocab_start_index, vocab_end_index = get_vocab_range(
        partition_vocab_size,
        rank,
        world_size,
    )

    # Assemble full logits
    full_logits = logits.new_zeros(logits.size(0), logits.size(1), vocab_size)
    full_logits[:, :, vocab_start_index:vocab_end_index] = logits
    torch.distributed.all_reduce(
        full_logits,
        op=torch.distributed.ReduceOp.SUM,
        group=get_model_parallel_group(),
    )

    if log_probs:
        return utils.log_softmax(full_logits, dim=-1)
    else:
        return utils.softmax(full_logits, dim=-1)
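# For reference, a minimal sketch of the vocab-range helper's expected
# behavior under Megatron-style contiguous, equal-size partitioning
# (illustrative stand-in, not the imported VocabUtility implementation;
# world_size is accepted only to mirror the signature):
def _sketch_vocab_range(per_partition_vocab_size, rank, world_size):
    # Rank r owns indices [r * p, (r + 1) * p) for partition size p, e.g.
    # p=25000, world_size=2: rank 0 -> [0, 25000), rank 1 -> [25000, 50000).
    start = rank * per_partition_vocab_size
    end = start + per_partition_vocab_size
    return start, end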
def _aggregate_model_parallel_grad_norm(total_norm):
    # Each model-parallel rank computes the norm over its own parameter
    # shard; square, sum across ranks, and take the root to recover the
    # global norm: sqrt(sum_i norm_i ** 2).
    total_norm = total_norm ** 2
    distributed_utils.all_reduce(total_norm, group=get_model_parallel_group())
    total_norm = total_norm ** 0.5
    return total_norm
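# A single-process stand-in for the reduction above (hypothetical helper,
# illustrative numbers): each rank contributes the norm over its own shard,
# and the global norm is the root of the sum of squares, e.g. local norms
# [3.0, 4.0] -> sqrt(9.0 + 16.0) = 5.0.
def _sketch_combine_local_norms(local_norms):
    return sum(norm ** 2 for norm in local_norms) ** 0.5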