Example 1
def _test_multiprocessing(self, rank, group, config):
    # Build the model in a one-worker Pool, passing a DummyProcessGroup in
    # place of the real process group so the arguments can be pickled across
    # the process boundary, then run a single training step with the real group.
    mp = torch.multiprocessing.Pool(1)
    dummy_group = DummyProcessGroup(rank=group.rank(), size=group.size())
    model = mp.apply(self._get_model, (dummy_group, config))
    # Move to GPU unless CPU offload keeps parameters on the CPU.
    if not config["cpu_offload"]:
        model = model.cuda()
    self._one_step(model, group)
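The snippet above only needs the dummy group to report a rank and a world size. A minimal sketch of that contract, assuming nothing beyond the constructor and the rank()/size() calls used in these examples (import path as in Example 2):

# Sketch: DummyProcessGroup only has to answer rank() and size() here.
from fairscale.utils.testing import DummyProcessGroup

group = DummyProcessGroup(rank=0, size=1)
assert group.rank() == 0
assert group.size() == 1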
Example 2
import contextlib

import torch

# fairseq-side imports assumed by this snippet
from fairseq.dataclass.configs import DistributedTrainingConfig
from fairseq.distributed import utils as dist_utils


@contextlib.contextmanager
def fsdp_enable_wrap(cfg: DistributedTrainingConfig,
                     use_sharded_state: bool = False):
    try:
        from fairscale.nn import enable_wrap
    except ImportError:
        raise ImportError(
            "Cannot find FullyShardedDataParallel. "
            "Please install fairscale with: pip install fairscale")
    if cfg.memory_efficient_fp16:
        assert cfg.fp16  # memory_efficient_fp16 should imply fp16
    group = dist_utils.get_data_parallel_group()
    if group is None and cfg.distributed_world_size == 1:
        # Single-process runs get a picklable stand-in for a real process group.
        from fairscale.utils.testing import DummyProcessGroup
        group = DummyProcessGroup(rank=0, size=1)
    fsdp_config = {
        "process_group": group,
        "reshard_after_forward": not cfg.no_reshard_after_forward,
        "mixed_precision": cfg.fp16 and not cfg.memory_efficient_fp16,
        "fp32_reduce_scatter": cfg.fp32_reduce_scatter,
        "flatten_parameters": True,
        "cpu_offload": cfg.cpu_offload,
        "compute_dtype": torch.float16 if cfg.fp16 else torch.float32,
        "bucket_cap_mb": cfg.bucket_cap_mb,
    }
    with enable_wrap(use_sharded_state=use_sharded_state, **fsdp_config):
        yield
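Because fsdp_enable_wrap is a context manager (note the @contextlib.contextmanager decorator), a call site builds its model inside the with-block. A minimal usage sketch, assuming a populated DistributedTrainingConfig named cfg and fairscale's wrap helper; build_model is a placeholder, not part of the snippet above:

from fairscale.nn import wrap

with fsdp_enable_wrap(cfg, use_sharded_state=False):
    # wrap() picks up the fsdp_config passed to enable_wrap above.
    model = wrap(build_model(cfg))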
Example 3
def setUp(self) -> None:
    # torch.cuda.amp.autocast is only available from PyTorch 1.6 onwards.
    version = torch.__version__.split(".")[:2]
    major, minor = int(version[0]), int(version[1])
    if major < 1 or (major == 1 and minor < 6):
        raise unittest.SkipTest(
            "Need pytorch version >= 1.6 due to autocast")
    self.process_group = DummyProcessGroup(rank=0, size=1)
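For reference, the feature behind this version gate is torch.cuda.amp.autocast, which first shipped in PyTorch 1.6. A minimal sketch with a placeholder layer and input, unrelated to the test class itself:

import torch

model = torch.nn.Linear(8, 8).cuda()
x = torch.randn(4, 8, device="cuda")
with torch.cuda.amp.autocast():
    y = model(x)  # ops inside this region run in mixed precision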
Example 4
def test_no_sync_before_first_forward(self):
    # Entering no_sync() before the model has ever run a forward pass should
    # still work: gradients are simply accumulated locally inside the context.
    group = DummyProcessGroup(rank=0, size=1)
    model = self.get_wrapped_model(group, config={})
    batch = model.module.get_input(torch.device("cuda"))
    with model.no_sync():
        output = model(*batch)
        loss = model.module.get_loss(batch, output)
        loss.backward()
    # A normal forward/backward after leaving the context synchronizes as usual.
    output = model(*batch)
    loss = model.module.get_loss(batch, output)
    loss.backward()
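The same no_sync() context is what a gradient-accumulation loop would use: skip gradient synchronization on every micro-batch except the last. A sketch under that assumption, with micro_batches as a placeholder list of input tuples following the test's pattern:

import contextlib

for i, batch in enumerate(micro_batches):
    ctx = model.no_sync() if i < len(micro_batches) - 1 else contextlib.nullcontext()
    with ctx:
        loss = model.module.get_loss(batch, model(*batch))
        loss.backward()  # gradients accumulate locally; only the last backward syncs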
Example 5
def setUp(self) -> None:
    self.process_group = DummyProcessGroup(rank=0, size=1)
Example 6
def setUp(self) -> None:
    # For all the tests here, we use a fake group and flatten being False,
    # since those should not affect how wrapping works.
    self.process_group = DummyProcessGroup(rank=0, size=1)