Example 1
 def _sync_sample_ratios(self, ratios):
     # in case the ratios are not precisely the same across processes,
     # and to ensure every process updates the ratios at the same pace
     ratios = torch.DoubleTensor(ratios)
     if torch.distributed.is_initialized():
         if torch.cuda.is_available():
             # keep the tensor returned by all_reduce so the reduction done
             # on the CUDA copy is not discarded
             ratios = distributed_utils.all_reduce(
                 ratios.cuda(),
                 group=distributed_utils.get_data_parallel_group())
         else:
             distributed_utils.all_reduce(
                 ratios, group=distributed_utils.get_data_parallel_group())
     ret = ratios.cpu()
     ret = ret.numpy()
     return ret
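
The example above sums the per-worker sampling ratios over the data-parallel group so that every rank continues with identical values. Below is a minimal sketch of the same idea using plain torch.distributed rather than fairseq's helpers; the function name and the assumption that a process group is already initialized are illustrative.

import torch
import torch.distributed as dist

def sync_sample_ratios(ratios, group=None):
    # Sum the per-process ratios so every rank ends up with the same values;
    # falls back to a no-op when torch.distributed is not initialized.
    t = torch.as_tensor(ratios, dtype=torch.float64)
    if dist.is_initialized():
        dist.all_reduce(t, op=dist.ReduceOp.SUM, group=group)
    return t.numpy()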
Example 2
@contextlib.contextmanager
def fsdp_enable_wrap(cfg: DistributedTrainingConfig,
                     use_sharded_state: bool = False):
    try:
        from fairscale.nn import enable_wrap
    except ImportError:
        raise ImportError(
            "Cannot find FullyShardedDataParallel. "
            "Please install fairscale with: pip install fairscale")
    if cfg.memory_efficient_fp16:
        assert cfg.fp16  # memory_efficient_fp16 should imply fp16
    group = dist_utils.get_data_parallel_group()
    if group is None and cfg.distributed_world_size == 1:
        from fairscale.utils.testing import DummyProcessGroup
        group = DummyProcessGroup(rank=0, size=1)
    fsdp_config = {
        "process_group": group,
        "reshard_after_forward": not cfg.no_reshard_after_forward,
        "mixed_precision": cfg.fp16 and not cfg.memory_efficient_fp16,
        "fp32_reduce_scatter": cfg.fp32_reduce_scatter,
        "flatten_parameters": True,
        "cpu_offload": cfg.cpu_offload,
        "compute_dtype": torch.float16 if cfg.fp16 else torch.float32,
        "bucket_cap_mb": cfg.bucket_cap_mb,
    }
    with enable_wrap(use_sharded_state=use_sharded_state, **fsdp_config):
        yield
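
For context, a hedged sketch of how such a context manager is typically consumed when building a model under a fully sharded backend. The import path, the fsdp_wrap helper, and the cfg/task objects follow one fairseq version and are assumptions here, not part of the example above.

from fairseq.distributed import fsdp_enable_wrap, fsdp_wrap

def build_sharded_model(cfg, task):
    # Inside the enable_wrap() context, fsdp_wrap() wraps the module in
    # FullyShardedDataParallel; outside of it, wrapping is a no-op.
    with fsdp_enable_wrap(cfg.distributed_training):
        model = fsdp_wrap(task.build_model(cfg.model))
    return model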
Example 3
 def data_parallel_process_group(self):
     # expose the process group spanning the data-parallel workers
     return distributed_utils.get_data_parallel_group()
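
A short hedged sketch of how the group returned by such an accessor can be passed to torch.distributed collectives; the helper name and the choice of broadcast are illustrative, not taken from the example.

import torch
import torch.distributed as dist

def broadcast_flag(flag: bool, group, src_rank: int = 0) -> bool:
    # Broadcast a boolean decision from src_rank (a global rank) to the
    # other members of the given data-parallel process group.
    t = torch.tensor([1 if flag else 0], dtype=torch.int64)
    if dist.is_initialized():
        dist.broadcast(t, src=src_rank, group=group)
    return bool(t.item())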