def _sync_sample_ratios(self, ratios):
    # in case the ratios are not precisely the same across processes
    # also to ensure every process updates the ratios at the same pace
    ratios = torch.DoubleTensor(ratios)
    if torch.distributed.is_initialized():
        if torch.cuda.is_available():
            distributed_utils.all_reduce(
                ratios.cuda(), group=distributed_utils.get_data_parallel_group()
            )
        else:
            distributed_utils.all_reduce(
                ratios, group=distributed_utils.get_data_parallel_group()
            )
    ret = ratios.cpu()
    ret = ret.numpy()
    return ret
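# A minimal standalone sketch of the same synchronization idea using plain
# torch.distributed instead of fairseq's distributed_utils helpers; the function
# name `sync_ratios_example` and the averaging step are illustrative assumptions,
# not part of the code above.
import torch
import torch.distributed as dist


def sync_ratios_example(ratios, group=None):
    # Sum the per-process ratios, then divide by the world size so that every
    # rank ends up with identical averaged values.
    t = torch.DoubleTensor(ratios)
    if dist.is_initialized():
        if torch.cuda.is_available():
            t = t.cuda()
        dist.all_reduce(t, op=dist.ReduceOp.SUM, group=group)
        t = t / dist.get_world_size(group=group)
    return t.cpu().numpy()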
# fsdp_enable_wrap is a generator-based context manager (the decorator below
# requires `import contextlib`), so callers can wrap model construction in a
# `with` block.
@contextlib.contextmanager
def fsdp_enable_wrap(cfg: DistributedTrainingConfig, use_sharded_state: bool = False):
    try:
        from fairscale.nn import enable_wrap
    except ImportError:
        raise ImportError(
            "Cannot find FullyShardedDataParallel. "
            "Please install fairscale with: pip install fairscale"
        )
    if cfg.memory_efficient_fp16:
        assert cfg.fp16  # memory_efficient_fp16 should imply fp16
    group = dist_utils.get_data_parallel_group()
    if group is None and cfg.distributed_world_size == 1:
        # single-process runs have no real process group, so fall back to a dummy one
        from fairscale.utils.testing import DummyProcessGroup

        group = DummyProcessGroup(rank=0, size=1)
    fsdp_config = {
        "process_group": group,
        "reshard_after_forward": not cfg.no_reshard_after_forward,
        "mixed_precision": cfg.fp16 and not cfg.memory_efficient_fp16,
        "fp32_reduce_scatter": cfg.fp32_reduce_scatter,
        "flatten_parameters": True,
        "cpu_offload": cfg.cpu_offload,
        "compute_dtype": torch.float16 if cfg.fp16 else torch.float32,
        "bucket_cap_mb": cfg.bucket_cap_mb,
    }
    with enable_wrap(use_sharded_state=use_sharded_state, **fsdp_config):
        yield
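# Usage sketch, not taken from the code above: it assumes fairseq's companion
# `fsdp_wrap` helper from the same module and treats `build_model` as a
# hypothetical callable that constructs the module to be sharded.
def build_sharded_model(cfg, build_model):
    # Build the model inside the context so wrapped submodules pick up the
    # process group, precision, and offload settings assembled in fsdp_config.
    with fsdp_enable_wrap(cfg, use_sharded_state=False):
        model = fsdp_wrap(build_model(cfg))
    return model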
def data_parallel_process_group(self):
    # expose the process group used for data-parallel collectives (e.g. gradient all-reduce)
    return distributed_utils.get_data_parallel_group()
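# Minimal sketch of how the data-parallel group is typically consumed, e.g. to
# average a scalar metric across data-parallel workers; `trainer` is a
# hypothetical object exposing the accessor above.
import torch
import torch.distributed as dist


def average_across_data_parallel(value: float, trainer) -> float:
    group = trainer.data_parallel_process_group()
    if group is None or not dist.is_initialized():
        return value
    t = torch.tensor([value], dtype=torch.float64)
    if torch.cuda.is_available():
        t = t.cuda()
    dist.all_reduce(t, group=group)  # defaults to SUM
    return (t / dist.get_world_size(group=group)).item()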