import logging
import socket
import warnings
from argparse import Namespace

import torch
import torch.distributed as dist

from fairseq.dataclass.configs import FairseqConfig

logger = logging.getLogger(__name__)


def distributed_init(args):
    if args.distributed_world_size == 1:
        raise ValueError(
            'Cannot initialize distributed with distributed_world_size=1')

    if not getattr(args, 'tpu', False):
        if torch.distributed.is_initialized():
            warnings.warn(
                'Distributed is already initialized, cannot initialize twice!')
        else:
            logger.info('distributed init (rank {}): {}'.format(
                args.distributed_rank,
                args.distributed_init_method,
            ))
            dist.init_process_group(
                backend=args.distributed_backend,
                init_method=args.distributed_init_method,
                world_size=args.distributed_world_size,
                rank=args.distributed_rank,
            )
            logger.info('initialized host {} as rank {}'.format(
                socket.gethostname(),
                args.distributed_rank,
            ))

            # perform a dummy all-reduce to initialize the NCCL communicator
            if torch.cuda.is_available():
                dist.all_reduce(torch.zeros(1).cuda())

        args.distributed_rank = torch.distributed.get_rank()
    else:
        import torch_xla.core.xla_model as xm
        assert xm.xrt_world_size() == args.distributed_world_size
        args.device_id = xm.get_local_ordinal()
        args.distributed_rank = xm.get_ordinal()
        xm.rendezvous('distributed_init')  # wait for all workers
        xm.mark_step()

    if is_master(args):
        logging.getLogger().setLevel(logging.INFO)
    else:
        logging.getLogger().setLevel(logging.WARNING)

    if args.model_parallel_size > 1:
        try:
            from fairseq.model_parallel.megatron.mpu import (
                get_model_parallel_rank,
                initialize_model_parallel,
                model_parallel_cuda_manual_seed,
            )
        except ImportError:
            raise ImportError('\n\nPlease install the megatron submodule:'
                              '\n\n  git submodule update --init '
                              'fairseq/model_parallel/megatron')
        initialize_model_parallel(args.model_parallel_size)
        model_parallel_cuda_manual_seed(args.seed)
        model_part_number = get_model_parallel_rank()
        args.checkpoint_suffix += '-model_part-{0}'.format(model_part_number)
    return args.distributed_rank
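
# Note: `is_master` (used above and in the cfg-based variant further below) is defined
# elsewhere in fairseq's distributed utilities; it is assumed here to reduce to a
# rank-zero check, roughly:
#
#     def is_master(args):
#         return args.distributed_rank == 0

# The indented method that follows is excerpted from a model-parallel model class
# (note `self`, `super()` and `self.decoder`); its enclosing class definition and the
# Megatron mpu imports it relies on (VocabUtility, get_model_parallel_rank,
# get_model_parallel_world_size, get_model_parallel_group), as well as `fairseq.utils`,
# are not shown in this excerpt.
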
    def get_normalized_probs(
        self,
        net_output,
        log_probs,
        sample=None,
    ):
        """Get normalized probabilities (or log probs) from a net's output."""

        logits = net_output[0]
        vocab_size = len(self.decoder.dictionary)

        if logits.size(-1) == vocab_size:
            # we have the full set of logits
            return super().get_normalized_probs(net_output, log_probs, sample)
        # else: vocab-parallel logits, need to combine them

        assert logits.dim() == 3

        # Get the partition's vocab indices
        get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size
        partition_vocab_size = logits.size(-1)
        rank = get_model_parallel_rank()
        world_size = get_model_parallel_world_size()
        vocab_start_index, vocab_end_index = get_vocab_range(
            partition_vocab_size, rank, world_size,
        )

        # Assemble full logits
        full_logits = logits.new_zeros(logits.size(0), logits.size(1), vocab_size)
        full_logits[:, :, vocab_start_index:vocab_end_index] = logits
        torch.distributed.all_reduce(
            full_logits,
            op=torch.distributed.ReduceOp.SUM,
            group=get_model_parallel_group(),
        )

        if log_probs:
            return utils.log_softmax(full_logits, dim=-1)
        else:
            return utils.softmax(full_logits, dim=-1)
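
# Illustration (hypothetical, standalone; not part of fairseq): how a per-partition
# vocab range maps back into the full vocabulary. This mirrors what Megatron's
# VocabUtility.vocab_range_from_per_partition_vocab_size is assumed to compute: each
# model-parallel rank owns one contiguous slice of the vocabulary.
def _example_vocab_range(per_partition_vocab_size, rank, world_size):
    # slice owned by `rank`; `world_size` is kept only to mirror the utility's signature
    start = rank * per_partition_vocab_size
    end = start + per_partition_vocab_size
    return start, end


# e.g. a 50,000-entry vocab split across 4 model-parallel ranks (12,500 per rank):
# rank 2 owns indices [25000, 37500), which is exactly the slice that
# `full_logits[:, :, vocab_start_index:vocab_end_index] = logits` fills in above
# before the SUM all-reduce combines the slices contributed by every rank.
assert _example_vocab_range(12500, rank=2, world_size=4) == (25000, 37500)


# The variant of distributed_init below is the later, Hydra/omegaconf-based rewrite of
# the function at the top of this excerpt; it takes a structured FairseqConfig (or a
# legacy argparse Namespace, which it converts) instead of a flat `args` namespace.
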
def distributed_init(cfg: FairseqConfig):
    if isinstance(cfg, Namespace):
        from fairseq.dataclass.utils import convert_namespace_to_omegaconf

        cfg = convert_namespace_to_omegaconf(cfg)

    if not cfg.common.tpu:
        if torch.distributed.is_initialized():
            warnings.warn(
                "Distributed is already initialized, cannot initialize twice!")
        else:
            logger.info("distributed init (rank {}): {}".format(
                cfg.distributed_training.distributed_rank,
                cfg.distributed_training.distributed_init_method,
            ))
            dist.init_process_group(
                backend=cfg.distributed_training.distributed_backend,
                init_method=cfg.distributed_training.distributed_init_method,
                world_size=cfg.distributed_training.distributed_world_size,
                rank=cfg.distributed_training.distributed_rank,
            )
            logger.info("initialized host {} as rank {}".format(
                socket.gethostname(),
                cfg.distributed_training.distributed_rank,
            ))

            # perform a dummy all-reduce to initialize the NCCL communicator
            if torch.cuda.is_available():
                dist.all_reduce(torch.zeros(1).cuda())

        cfg.distributed_training.distributed_rank = torch.distributed.get_rank()
    else:
        import torch_xla.core.xla_model as xm

        assert xm.xrt_world_size() == cfg.distributed_training.distributed_world_size
        cfg.distributed_training.device_id = xm.get_local_ordinal()
        cfg.distributed_training.distributed_rank = xm.get_ordinal()
        xm.rendezvous("distributed_init")  # wait for all workers
        xm.mark_step()

    if is_master(cfg.distributed_training):
        logging.getLogger().setLevel(logging.INFO)
    else:
        logging.getLogger().setLevel(logging.WARNING)

    if cfg.common.model_parallel_size > 1:
        try:
            from fairseq.model_parallel.megatron.mpu import (
                get_model_parallel_rank,
                initialize_model_parallel,
                model_parallel_cuda_manual_seed,
            )
        except ImportError:
            raise ImportError("\n\nPlease install the megatron submodule:"
                              "\n\n  git submodule update --init "
                              "fairseq/model_parallel/megatron")
        initialize_model_parallel(cfg.common.model_parallel_size)
        model_parallel_cuda_manual_seed(cfg.common.seed)
        model_part_number = get_model_parallel_rank()
        cfg.checkpoint.checkpoint_suffix += "-model_part-{0}".format(model_part_number)
    return cfg.distributed_training.distributed_rank
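
# Usage sketch (hypothetical; not fairseq's actual launcher): spawn one worker process
# per GPU on a single node and call distributed_init in each of them. `build_cfg` is a
# placeholder for constructing the config (including distributed_init_method and
# distributed_world_size); on a single node the rank is assumed to equal the local
# device id.
def _example_worker(device_id, cfg):
    cfg.distributed_training.device_id = device_id
    cfg.distributed_training.distributed_rank = device_id
    torch.cuda.set_device(device_id)
    distributed_init(cfg)
    # ... continue with training / evaluation on this rank ...


def _example_main(build_cfg):
    import torch.multiprocessing as mp

    cfg = build_cfg()  # placeholder: returns a FairseqConfig / DictConfig
    mp.spawn(
        _example_worker,
        args=(cfg,),
        nprocs=cfg.distributed_training.distributed_world_size,
    )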