# Module-level imports and logger that this snippet relies on.
import logging
import socket
import warnings

import torch
import torch.distributed as dist

logger = logging.getLogger(__name__)


def distributed_init(args):
    """Initialize torch.distributed (and, if enabled, Megatron model parallelism) for this worker."""
    if args.distributed_world_size == 1:
        raise ValueError(
            'Cannot initialize distributed with distributed_world_size=1')

    if not getattr(args, 'tpu', False):
        if torch.distributed.is_initialized():
            warnings.warn(
                'Distributed is already initialized, cannot initialize twice!')
        else:
            logger.info('distributed init (rank {}): {}'.format(
                args.distributed_rank,
                args.distributed_init_method,
            ))
            dist.init_process_group(
                backend=args.distributed_backend,
                init_method=args.distributed_init_method,
                world_size=args.distributed_world_size,
                rank=args.distributed_rank,
            )
            logger.info('initialized host {} as rank {}'.format(
                socket.gethostname(),
                args.distributed_rank,
            ))

            # perform a dummy all-reduce to initialize the NCCL communicator
            if torch.cuda.is_available():
                dist.all_reduce(torch.zeros(1).cuda())

        args.distributed_rank = torch.distributed.get_rank()
    else:
        # TPU path: ranks and devices come from torch_xla rather than torch.distributed
        import torch_xla.core.xla_model as xm
        assert xm.xrt_world_size() == args.distributed_world_size
        args.device_id = xm.get_local_ordinal()
        args.distributed_rank = xm.get_ordinal()
        xm.rendezvous('distributed_init')  # wait for all workers
        xm.mark_step()

    # only the master process logs at INFO level; other workers stay quieter
    if is_master(args):
        logging.getLogger().setLevel(logging.INFO)
    else:
        logging.getLogger().setLevel(logging.WARNING)

    if args.model_parallel_size > 1:
        try:
            from fairseq.model_parallel.megatron.mpu import (
                get_model_parallel_rank,
                initialize_model_parallel,
                model_parallel_cuda_manual_seed,
            )
        except ImportError:
            raise ImportError(
                '\n\nPlease install the megatron submodule:'
                '\n\n git submodule update --init '
                'fairseq/model_parallel/megatron')
        initialize_model_parallel(args.model_parallel_size)
        model_parallel_cuda_manual_seed(args.seed)
        # each model-parallel partition writes its own checkpoint shard
        model_part_number = get_model_parallel_rank()
        args.checkpoint_suffix += '-model_part-{0}'.format(model_part_number)

    return args.distributed_rank
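# The `is_master` helper used above is not defined in this snippet; fairseq
# treats rank 0 as the master process. A minimal sketch of that behavior:
def is_master(args):
    # works for both the args-based and the cfg.distributed_training-based
    # variants below, since both expose a `distributed_rank` attribute
    return args.distributed_rank == 0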
# Method of a vocab-parallel decoder model (shown here without its enclosing
# class). The Megatron helpers are assumed to be importable from the fairseq
# Megatron submodule's `mpu` package.
import torch

from fairseq import utils
from fairseq.model_parallel.megatron.mpu import (
    get_model_parallel_group,
    get_model_parallel_rank,
    get_model_parallel_world_size,
    VocabUtility,
)


def get_normalized_probs(
    self,
    net_output,
    log_probs,
    sample,
):
    """Get normalized probabilities (or log probs) from a net's output."""
    logits = net_output[0]
    vocab_size = len(self.decoder.dictionary)
    if logits.size(-1) == vocab_size:
        # we have the full set of logits
        return super().get_normalized_probs(net_output, log_probs, sample)

    # else: vocab-parallel logits, need to combine them
    assert logits.dim() == 3

    # Get the partition's vocab indices
    get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size
    partition_vocab_size = logits.size(-1)
    rank = get_model_parallel_rank()
    world_size = get_model_parallel_world_size()
    vocab_start_index, vocab_end_index = get_vocab_range(
        partition_vocab_size, rank, world_size,
    )

    # Assemble full logits: each rank scatters its slice into a zero tensor of
    # full vocab size, then an all-reduce (sum) across the model-parallel group
    # reconstructs the complete logits on every rank.
    full_logits = logits.new_zeros(logits.size(0), logits.size(1), vocab_size)
    full_logits[:, :, vocab_start_index:vocab_end_index] = logits
    torch.distributed.all_reduce(
        full_logits,
        op=torch.distributed.ReduceOp.SUM,
        group=get_model_parallel_group(),
    )

    if log_probs:
        return utils.log_softmax(full_logits, dim=-1)
    else:
        return utils.softmax(full_logits, dim=-1)
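# `VocabUtility.vocab_range_from_per_partition_vocab_size` comes from the
# Megatron submodule and is not shown above. A minimal sketch of the expected
# behavior, assuming every rank owns a contiguous, equally sized vocab slice
# (the class name here is illustrative, not the real Megatron class):
class _VocabUtilitySketch:
    @staticmethod
    def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size):
        # rank r owns logits for vocab ids [r * P, (r + 1) * P), where P is the
        # per-partition vocab size; world_size only fixes how many slices exist
        vocab_start_index = rank * per_partition_vocab_size
        vocab_end_index = vocab_start_index + per_partition_vocab_size
        return vocab_start_index, vocab_end_index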
# Newer, config-based variant of distributed_init. In addition to the imports
# used by the first snippet, it relies on these (paths assumed to match recent fairseq):
from argparse import Namespace

from fairseq.dataclass.configs import FairseqConfig


def distributed_init(cfg: FairseqConfig):
    """Initialize torch.distributed for this worker, driven by a FairseqConfig."""
    # accept a legacy argparse Namespace and convert it to the OmegaConf-backed config
    if isinstance(cfg, Namespace):
        from fairseq.dataclass.utils import convert_namespace_to_omegaconf
        cfg = convert_namespace_to_omegaconf(cfg)

    if not cfg.common.tpu:
        if torch.distributed.is_initialized():
            warnings.warn(
                "Distributed is already initialized, cannot initialize twice!")
        else:
            logger.info("distributed init (rank {}): {}".format(
                cfg.distributed_training.distributed_rank,
                cfg.distributed_training.distributed_init_method,
            ))
            dist.init_process_group(
                backend=cfg.distributed_training.distributed_backend,
                init_method=cfg.distributed_training.distributed_init_method,
                world_size=cfg.distributed_training.distributed_world_size,
                rank=cfg.distributed_training.distributed_rank,
            )
            logger.info("initialized host {} as rank {}".format(
                socket.gethostname(),
                cfg.distributed_training.distributed_rank,
            ))

            # perform a dummy all-reduce to initialize the NCCL communicator
            if torch.cuda.is_available():
                dist.all_reduce(torch.zeros(1).cuda())

        cfg.distributed_training.distributed_rank = torch.distributed.get_rank()
    else:
        # TPU path: ranks and devices come from torch_xla rather than torch.distributed
        import torch_xla.core.xla_model as xm
        assert xm.xrt_world_size() == cfg.distributed_training.distributed_world_size
        cfg.distributed_training.device_id = xm.get_local_ordinal()
        cfg.distributed_training.distributed_rank = xm.get_ordinal()
        xm.rendezvous("distributed_init")  # wait for all workers
        xm.mark_step()

    # only the master process logs at INFO level; other workers stay quieter
    if is_master(cfg.distributed_training):
        logging.getLogger().setLevel(logging.INFO)
    else:
        logging.getLogger().setLevel(logging.WARNING)

    if cfg.common.model_parallel_size > 1:
        try:
            from fairseq.model_parallel.megatron.mpu import (
                get_model_parallel_rank,
                initialize_model_parallel,
                model_parallel_cuda_manual_seed,
            )
        except ImportError:
            raise ImportError(
                "\n\nPlease install the megatron submodule:"
                "\n\n git submodule update --init "
                "fairseq/model_parallel/megatron")
        initialize_model_parallel(cfg.common.model_parallel_size)
        model_parallel_cuda_manual_seed(cfg.common.seed)
        # each model-parallel partition writes its own checkpoint shard
        model_part_number = get_model_parallel_rank()
        cfg.checkpoint.checkpoint_suffix += "-model_part-{0}".format(model_part_number)

    return cfg.distributed_training.distributed_rank
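# A hypothetical per-process entry point sketching how distributed_init might
# be driven on a single node with torch.multiprocessing. The names
# `_distributed_worker`, `spawn_workers`, `main_fn`, and the TCP port below are
# illustrative assumptions, not fairseq's actual launcher.
import torch.multiprocessing as mp


def _distributed_worker(device_id, main_fn, cfg):
    # one process per GPU; on a single node the local device id doubles as the rank
    cfg.distributed_training.device_id = device_id
    cfg.distributed_training.distributed_rank = device_id
    if torch.cuda.is_available():
        torch.cuda.set_device(device_id)
    cfg.distributed_training.distributed_rank = distributed_init(cfg)
    main_fn(cfg)


def spawn_workers(main_fn, cfg):
    # rendezvous via TCP on localhost; port choice is an assumption
    cfg.distributed_training.distributed_init_method = "tcp://localhost:12355"
    mp.spawn(
        _distributed_worker,
        args=(main_fn, cfg),
        nprocs=cfg.distributed_training.distributed_world_size,
    )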