def get_data_parallel_group():
    """Return the data parallel group that the calling rank belongs to.

    When the module-level ``_USE_MEGATRON`` flag is truthy, the group is
    obtained from Megatron's ``mpu.get_data_parallel_group()``. Otherwise the
    result of ``get_global_group()`` is returned.
    """
    # The flag is only read here, so no ``global`` declaration is needed.
    if not _USE_MEGATRON:
        return get_global_group()

    # Deferred import: only executed when the Megatron path is taken.
    from fairseq.model_parallel.megatron import mpu

    return mpu.get_data_parallel_group()
def data_parallel_process_group(self):
    """Return the data parallel group via ``get_data_parallel_group()``.

    This is a pure delegation; no state on ``self`` is read.
    """
    group = get_data_parallel_group()
    return group