Example #1
def get_model_parallel_group():
    """Return the Megatron model-parallel process group, or None when Megatron is not in use."""
    global _USE_MEGATRON
    if _USE_MEGATRON:
        from fairseq.model_parallel.megatron import mpu
        return mpu.get_model_parallel_group()
    else:
        return None
Example #2
    def get_normalized_probs(
        self,
        net_output,
        log_probs,
        sample,
    ):
        """Get normalized probabilities (or log probs) from a net's output."""

        logits = net_output[0]
        vocab_size = len(self.decoder.dictionary)

        if logits.size(-1) == vocab_size:
            # we have the full set of logits
            return super().get_normalized_probs(net_output, log_probs, sample)
        # else: vocab-parallel logits, need to combine them

        assert logits.dim() == 3

        # Get the partition's vocab indices
        get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size
        partition_vocab_size = logits.size(-1)
        rank = get_model_parallel_rank()
        world_size = get_model_parallel_world_size()
        vocab_start_index, vocab_end_index = get_vocab_range(
            partition_vocab_size, rank, world_size,
        )

        # Assemble full logits
        full_logits = logits.new_zeros(logits.size(0), logits.size(1), vocab_size)
        full_logits[:, :, vocab_start_index:vocab_end_index] = logits
        torch.distributed.all_reduce(
            full_logits,
            op=torch.distributed.ReduceOp.SUM,
            group=get_model_parallel_group(),
        )

        if log_probs:
            return utils.log_softmax(full_logits, dim=-1)
        else:
            return utils.softmax(full_logits, dim=-1)
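
For reference, here is a minimal, non-distributed sketch of the combine step above. It is my own toy example, not part of the fairseq source: it assumes an even split of an 8-token vocabulary across 2 model-parallel partitions (which is what vocab_range_from_per_partition_vocab_size produces in that case) and simulates the SUM all-reduce by adding zero-padded slices locally.

import torch

# Toy stand-in for the all-reduce in get_normalized_probs: vocab of 8 split
# evenly across 2 model-parallel ranks (4 logits per rank).
vocab_size, world_size = 8, 2
partition_vocab_size = vocab_size // world_size
logits_per_rank = [torch.randn(1, 3, partition_vocab_size) for _ in range(world_size)]

# In the real code every rank zero-pads its own slice into a full-vocab tensor
# and a SUM all-reduce adds the padded tensors together; summing the padded
# slices locally gives the same full_logits.
full_logits = torch.zeros(1, 3, vocab_size)
for rank, part in enumerate(logits_per_rank):
    start = rank * partition_vocab_size  # vocab_start_index for this rank (even split)
    full_logits[:, :, start:start + partition_vocab_size] += part

probs = torch.softmax(full_logits, dim=-1)  # what the method returns when log_probs is False
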
Example #3
def _aggregate_model_parallel_grad_norm(total_norm):
    """Combine per-rank gradient norms into the global norm across the model-parallel group."""
    total_norm = total_norm ** 2
    distributed_utils.all_reduce(total_norm, group=get_model_parallel_group())
    total_norm = total_norm ** 0.5
    return total_norm
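
For context, the square / all-reduce / square-root pattern above works because the L2 norm of parameters sharded across ranks equals the square root of the sum of the squared per-shard norms. A non-distributed sketch with made-up toy tensors (assumptions mine, not fairseq code):

import torch

# Two model-parallel ranks, each holding a disjoint shard of the parameters.
shard_a = torch.tensor([3.0, 0.0])
shard_b = torch.tensor([0.0, 4.0])

local_norms = torch.stack([shard_a.norm(), shard_b.norm()])  # per-rank norms: 3.0 and 4.0
global_norm = (local_norms ** 2).sum() ** 0.5                # square, sum (the all-reduce), sqrt
assert torch.isclose(global_norm, torch.cat([shard_a, shard_b]).norm())  # both give 5.0
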