def test_norm(device, norm_type, mixed_precision):
    """Compare a checkpoint-wrapped model against its reference for a norm layer type.

    Both models are built by ``get_model`` (without and with checkpointing),
    given identical weights, trained for one SGD step on the same input, and
    their state dicts are compared before and after the step (torch >= 1.6
    only). On CPU with a non-fp32 precision the forward pass is expected to
    raise ``RuntimeError`` and the test stops after observing it.
    """
    if device == "cuda" and not torch.cuda.is_available():
        pytest.skip("Skip due to lack of GPU")

    use_half = mixed_precision != "fp32"
    can_compare_states = torch_version() >= (1, 6, 0)

    # Build the input first, then the reference and checkpointed models, and
    # copy the reference weights over so that both start out identical.
    in_data = torch.rand(2, 2, 3, 3).to(device)
    m_ref = get_model(norm_type, False, mixed_precision).to(device)
    m_cpt = get_model(norm_type, True, mixed_precision).to(device)
    m_cpt.load_state_dict(m_ref.state_dict())

    if can_compare_states:
        # This comparison is known to fail on 1.5.1, hence the version gate.
        assert objects_are_equal(m_ref.state_dict(), m_cpt.state_dict())

    if use_half:
        in_data = in_data.half()

    # Checkpointing needs an input that requires grad.
    in_data.requires_grad = True

    for net in (m_ref, m_cpt):
        sgd = SGD(net.parameters(), lr=0.1)
        if device == "cpu" and use_half:
            # Half precision "batch_norm"/"layer_norm" is not implemented on
            # CPU: the forward must raise, and there is nothing left to check.
            with pytest.raises(RuntimeError):
                net(in_data)
            return
        # Every other combination is expected to train normally.
        net(in_data).sum().backward()
        sgd.step()

    if can_compare_states:
        assert objects_are_equal(m_ref.state_dict(), m_cpt.state_dict())
def test_single_run():
    """Shard the transformer model into two pieces and run one forward pass.

    Checks that ``shard_model`` yields exactly two shards, that each shard
    holds the expected number of parameters, and that feeding the output of
    one shard into the next produces a tensor of the expected final shape.
    """
    if torch_version() < (1, 8, 0):
        pytest.skip("requires torch version >= 1.8.0")
    # Imported lazily: only reached on a new enough torch version.
    from fairscale.experimental.nn.auto_shard import shard_model

    model = TransformerModel(ntokens, emsize, nhead, nhid, nlayers, dropout)
    sharded_model = shard_model(model)
    assert len(sharded_model) == 2, "Length of sharded model is incorrect."

    expected_param_nums = [5998600, 5785383]
    for expected, shard in zip(expected_param_nums, sharded_model):
        # The root module's parameters() already covers every sub-module,
        # which is the count the original read back from the "" entry.
        assert expected == sum(p.numel() for p in shard.parameters())

    src_mask = torch.randn((35, 35), dtype=torch.float32)
    src = torch.randint(1, ntokens, (35, 20))

    # The first shard takes (src, src_mask); later shards consume whatever the
    # previous shard returned, unpacking it only when it is a list.
    activations = [src, src_mask]
    for shard in sharded_model:
        if isinstance(activations, list):
            activations = shard(*activations)
        else:
            activations = shard(activations)

    assert activations.size() == torch.Size([35, 20, 28783])
def test_correctness(use_fp16, checkpoint_activation, num_microbatches, use_auto_shard):
    """Train a regular model and an offloaded model and check that they agree.

    The combination is skipped when auto-sharding needs a newer torch, when
    the AMP APIs are missing, or when micro-batches are requested without
    activation checkpointing.
    """
    if use_auto_shard and torch_version() < (1, 8, 0):
        pytest.skip("auto_shard requires torch version >= 1.8.0")

    needs_amp = use_fp16 or checkpoint_activation
    if needs_amp and not hasattr(torch.cuda.amp, "custom_fwd"):
        pytest.skip(f"AMP APIs are not supported in torch version {torch.__version__}")

    if num_microbatches > 1 and not checkpoint_activation:
        pytest.skip("We only support microbatches with activation offloading.")

    device, offload_device = _init()
    model = _get_model()
    # The offload path trains either the auto-sharded model or the model as is.
    offload_model = shard_model(model) if use_auto_shard else model

    reference = _train_reg_model(model, device, offload_device)
    offloaded = _train_offload_model(
        offload_model,
        device,
        offload_device,
        use_fp16=use_fp16,
        checkpoint_activation=checkpoint_activation,
        num_microbatches=num_microbatches,
    )

    rmodel, ropt, rloss = reference
    omodel, oopt, oloss = offloaded
    _check_parity(rmodel.cpu(), omodel.cpu(), ropt, oopt, rloss, oloss)
def test_smaller_than_world_size(world_size, test_case, fsdp_config):
    """Test FSDP with uneven divide of parameter shards.

    Spawns ``world_size`` processes running ``_test_func`` on a stack of small
    bias-free linear layers whose last layer has fewer parameters than there
    are ranks. Two temporary files are created for the workers and removed
    again once every worker has joined.
    """
    import os  # local import: only needed for the temp-file handling below

    if torch_version() < (1, 6, 0):
        pytest.skip("older pytorch doesn't support reduce_scatter in gloo backend")
    if world_size > torch.cuda.device_count():
        pytest.skip("Not enough GPUs.")

    temp_paths = []
    try:
        for _ in range(2):
            # mkstemp() returns an *open* descriptor together with the path.
            # Only the path is needed, so close the descriptor right away
            # instead of leaking it for the lifetime of the test process.
            fd, path = tempfile.mkstemp()
            os.close(fd)
            temp_paths.append(path)
        temp_file_name, unused = temp_paths

        model = Sequential(
            Linear(3, 3, bias=False),
            Linear(3, 4, bias=False),
            Linear(4, 5, bias=False),
            Linear(5, 4, bias=False),
            Linear(4, 3, bias=False),
            Linear(3, 1, bias=False),
            Linear(1, 1, bias=False),  # param here is smaller than world_size if unflattened.
        )
        mp.spawn(
            _test_func,
            args=(world_size, model, fsdp_config, temp_file_name, unused, test_case),
            nprocs=world_size,
            join=True,
        )
    finally:
        # join=True means every worker has exited by now, so the files can go.
        for path in temp_paths:
            try:
                os.remove(path)
            except FileNotFoundError:
                pass
def setUp(self):
    """Skip the test unless the environment can run multi-GPU distributed code.

    The rules are checked in order and the first one that matches raises
    ``unittest.SkipTest`` with its message.
    """
    # Predicates are wrapped in lambdas so that each one is only evaluated
    # when all the previous rules have passed.
    skip_rules = (
        (
            lambda: torch_version() < (1, 6, 0),
            "Need pytorch version >= 1.6 due to lack of reduce_scatter",
        ),
        (
            lambda: not torch.cuda.is_available(),
            "CUDA not available, skipping test",
        ),
        (
            lambda: sys.platform == "win32",
            "NCCL doesn't support Windows, skipping test",
        ),
        (
            lambda: torch.cuda.device_count() < 2,
            "distributed tests require 2+ GPUs, skipping",
        ),
    )
    for must_skip, reason in skip_rules:
        if must_skip():
            raise unittest.SkipTest(reason)
def test_regnet(temp_files, ddp_ref, precision, flatten, sync_bn):
    """Run the FSDP workers against the DDP reference results in ``ddp_ref``.

    Builds the FSDP configuration from the test parameters, randomly decides
    whether batch norm layers get wrapped, and spawns ``_world_size``
    processes running ``_distributed_worker``.
    """
    if torch_version() < (1, 6, 0):
        pytest.skip("older pytorch doesn't support reduce_scatter")

    state_before, inputs, conv_bias, linear_bias, all_states_after = ddp_ref
    state_after = all_states_after[(precision, sync_bn)]

    mixed = precision == "mixed"
    fsdp_config = {
        "mixed_precision": mixed,
        "flatten_parameters": flatten == "flatten",
    }

    # With a linear bias, DDP's AMP O1 and FSDP's default AMP O1.5 differ, so
    # FSDP is forced to AMP O1 by computing in float32.
    if linear_bias:
        fsdp_config["compute_dtype"] = torch.float32

    if mixed and torch_cuda_version() < (11, 0):
        pytest.skip("Only CUDA 11 is supported with AMP equivalency")

    # BN is wrapped half of the time ...
    wrap_bn = random.randint(0, 1) != 0
    # ... except that mixed precision + sync_bn always wraps it, due to an
    # error of sync_bn wrapping, regardless of compute_dtype.
    if mixed and sync_bn != "none":
        wrap_bn = True

    # When BN is not wrapped (i.e. not in full precision), FSDP's compute_dtype
    # has to be fp32 to match DDP; otherwise numerical errors show up in BN's
    # running_mean/running_var buffers.
    if mixed and not wrap_bn:
        fsdp_config["compute_dtype"] = torch.float32

    world_size = _world_size
    worker_args = (
        world_size,
        fsdp_config,
        wrap_bn,
        None,
        temp_files[0],
        temp_files[1],
        state_before,
        inputs,
        None,
        state_after,
        sync_bn,
        conv_bias,
        linear_bias,
    )
    mp.spawn(_distributed_worker, args=worker_args, nprocs=world_size, join=True)
def test_basic(device):
    """Compare loss, grad norm and memory statistics across checkpointing variants.

    Four ``BasicModel`` variants are run through ``get_loss_and_gnorm`` on the
    same input: no checkpointing, PyTorch checkpointing, fairscale
    checkpointing, and fairscale checkpointing with CPU offload. Their "loss"
    and "gnorm" entries must be equal; on non-CPU devices the remaining
    entries must match the hard-coded memory statistics below.
    """
    if "cuda" in device and not torch.cuda.is_available():
        pytest.skip("test requires a GPU")

    input = torch.rand(2, 16, 32).requires_grad_(True)
    # NOTE(review): the exact memory numbers asserted below presumably depend
    # on each model being rebound to the same `model` name before the next
    # measurement - keep this statement order when editing.
    model = BasicModel().to(device)
    no_cpt = get_loss_and_gnorm(model, input.to(device))

    model = BasicModel(use_pytorch_checkpoint=True).to(device)
    pyt_cpt = get_loss_and_gnorm(model, input.to(device))

    model = BasicModel(use_fairscale_checkpoint=True).to(device)
    fairscale_cpt = get_loss_and_gnorm(model, input.to(device))

    model = BasicModel(use_fairscale_checkpoint=True, offload_to_cpu=True).to(device)
    fairscale_cpt_offload = get_loss_and_gnorm(model, input.to(device))

    # Check for correctness.
    for key in "loss", "gnorm":
        if not (no_cpt[key] == pyt_cpt[key] == fairscale_cpt[key] == fairscale_cpt_offload[key]):
            # Dump all four result dicts before failing, to ease debugging.
            print(no_cpt, pyt_cpt, fairscale_cpt, fairscale_cpt_offload)
            assert 0
        # Drop the compared key so that only the memory entries remain for the
        # dict comparisons further down.
        del no_cpt[key]
        del pyt_cpt[key]
        del fairscale_cpt[key]
        del fairscale_cpt_offload[key]

    # Check for memory usage for cuda only.
    if "cpu" in device:
        return

    # Expected "mem_peak" per variant, in the order:
    # no_cpt, pyt_cpt, fairscale_cpt, fairscale_cpt_offload.
    mem_peaks = [98816, 103424, 103424, 107520]
    if torch_version() < (1, 7, 0):
        # Older torch behaves slightly differently
        mem_peaks = [102400, 103424, 103424, 107520]

    # Each assert passes the actual dict as its message so that a failure
    # shows the measured values.
    assert no_cpt == {"mem_0": 38912, "mem_peak": mem_peaks[0], "mem_after_fwd": 64000, "mem_after_bwd": 74240}, no_cpt
    assert pyt_cpt == {
        "mem_0": 38912,
        "mem_peak": mem_peaks[1],
        "mem_after_fwd": 43520,
        "mem_after_bwd": 74240,
    }, pyt_cpt
    assert fairscale_cpt == {
        "mem_0": 38912,
        "mem_peak": mem_peaks[2],
        "mem_after_fwd": 43520,
        "mem_after_bwd": 74240,
    }, fairscale_cpt
    assert fairscale_cpt_offload == {
        "mem_0": 38912,
        "mem_peak": mem_peaks[3],
        "mem_after_fwd": 43520,
        "mem_after_bwd": 74240,
    }, fairscale_cpt_offload
def test_torch_version():
    """Check ``torch_version`` parsing on empty, malformed and suffixed strings."""
    # (version string, expected parsed tuple)
    cases = (
        ("", tuple()),
        ("bad format", tuple()),
        ("1.9.0", (1, 9, 0)),
        ("1.10.0a0+gitbc6fc3e", (1, 10, 0)),
        ("1.7.0+cu102", (1, 7, 0)),
        ("1.10.0a0+fb", (1, 10, 0)),
    )
    for version_string, expected in cases:
        assert torch_version(version_string) == expected
def test_input_type(temp_files, fsdp_config, input_cls):
    """Test FSDP with input being a list or a dict, only single GPU.

    Initialises a single-rank distributed environment, wraps a tiny linear
    model in FSDP and runs five SGD steps, feeding the input wrapped either in
    a list or in a dict depending on ``input_cls``. ``teardown()`` is always
    called once initialisation succeeded, even when the body fails.
    """
    if torch_version() < (1, 7, 0):
        # This test runs multiple test cases in a single process. On 1.6.0 it
        # throw an error like this:
        #     RuntimeError: Container is already initialized! Cannot initialize it twice!
        pytest.skip("older pytorch doesn't work well with single process dist_init multiple times")

    result = dist_init(rank=0, world_size=1, filename=temp_files[0], filename_rpc=temp_files[1])
    assert result, "Dist init failed"

    # Several cases run in this same process (see the skip comment above), so
    # a failure must not skip teardown(): that would leave the distributed
    # state initialised for the cases that follow.
    try:
        assert isinstance(fsdp_config, dict), str(fsdp_config)

        class Model(Module):
            """Linear layer that accepts its input wrapped in a list or a dict."""

            def __init__(self):
                super().__init__()
                self.layer = Linear(4, 4)

            def forward(self, input):
                # Unwrap the tensor from a one-element list or from the "in" key.
                if isinstance(input, list):
                    input = input[0]
                else:
                    assert isinstance(input, dict), input
                    input = input["in"]
                return self.layer(input)

        model = FSDP(Model(), **fsdp_config).cuda()
        optim = SGD(model.parameters(), lr=0.1)

        for _ in range(5):
            in_data = torch.rand(64, 4).cuda()
            in_data.requires_grad = True
            if input_cls is list:
                in_data = [in_data]
            else:
                assert input_cls is dict
                in_data = {"in": in_data}

            out = model(in_data)
            out.sum().backward()
            optim.step()
            optim.zero_grad()

        model.assert_state(TrainingState.IDLE)
    finally:
        teardown()
def test_train_and_eval_with_checkpointing():
    """Spawn two workers running ``_test_func`` with two temporary files."""
    if torch_version() < (1, 6, 0):
        pytest.skip("older pytorch doesn't support reduce_scatter")

    world_size = 2

    # The context manager supplies the two temporary file names the workers
    # receive.
    with temp_files_ctx(2) as file_names:
        temp_file_name, unused = file_names
        worker_args = (world_size, temp_file_name, unused)
        mp.spawn(_test_func, args=worker_args, nprocs=world_size, join=True)
def test_dynaimc_conditionals_auto_wrapped():
    """Check that a branched network split into three shards matches the original.

    The output of running the shards one after the other must be close to the
    output of the unsharded model on the same random input.
    """
    if torch_version() < (1, 8, 0):
        pytest.skip("requires torch version >= 1.8.0")
    # Imported lazily: only reached on a new enough torch version.
    from fairscale.experimental.nn.auto_shard import shard_model

    features = 10
    num_shards = 3

    model = BranchedNetwork(features)
    sharded_model = shard_model(model, num_shards)
    assert len(sharded_model) == 3

    input_ = torch.randn(3, features)
    model_output = model(input_)

    # Thread the activation through the shards in order.
    activation = input_
    for shard in sharded_model:
        activation = shard(activation)

    assert torch.allclose(model_output, activation)
def check(exp, res):
    """Assert that two loss mappings have the same keys and close values.

    ``exp`` and ``res`` must list their keys in the same order. Tolerances are
    relaxed when ``with_model2`` and ``mixed_precision`` (both taken from the
    enclosing scope) are set on torch >= 1.9.
    """
    exp_keys = list(exp.keys())
    res_keys = list(res.keys())
    assert exp_keys == res_keys, f"{exp_keys} vs. {res_keys}"

    # On CI, with longer model2, mixed precision and 1.9, even ddp vs. ddp+ckpt
    # has larger errors, hence the looser tolerances in that case.
    relaxed = with_model2 and mixed_precision and torch_version() >= (1, 9, 0)
    rtol, atol = (1e-3, 1e-4) if relaxed else (1e-4, 1e-5)

    for key in exp_keys:
        torch.testing.assert_allclose(exp[key], res[key], rtol=rtol, atol=atol)
def test1(precision, flatten):
    """Spawn two FSDP workers running ``_test_func`` with the given configuration.

    ``precision == "mixed"`` turns mixed precision on and
    ``flatten == "flatten"`` turns parameter flattening on. Two temporary
    files are created for the workers and removed once they have joined.
    """
    import os  # local import: only needed for the temp-file handling below

    if torch_version() < (1, 6, 0):
        pytest.skip("older pytorch doesn't support reduce_scatter")

    temp_paths = []
    try:
        for _ in range(2):
            # mkstemp() hands back an *open* descriptor plus the path; only
            # the path is used, so close the descriptor instead of leaking it.
            fd, path = tempfile.mkstemp()
            os.close(fd)
            temp_paths.append(path)
        temp_file_name, unused = temp_paths

        fsdp_config = {}
        fsdp_config["mixed_precision"] = precision == "mixed"
        fsdp_config["flatten_parameters"] = flatten == "flatten"

        # Some bugs only show up when we are in world_size > 1 due to sharding
        # changing the tensor dimensions.
        world_size = 2
        mp.spawn(
            _test_func,
            args=(world_size, fsdp_config, temp_file_name, unused),
            nprocs=world_size,
            join=True,
        )
    finally:
        # join=True means every worker has exited by now, so the files can go.
        for path in temp_paths:
            try:
                os.remove(path)
            except FileNotFoundError:
                pass
def test_one_iteration(world_size, test_case, fsdp_config):
    """Test FSDP with uneven divide of parameter shards.

    Spawns ``world_size`` processes running ``_test_func`` on a single
    bias-free 3x3 linear layer. Two temporary files are created for the
    workers and removed again once every worker has joined.
    """
    import os  # local import: only needed for the temp-file handling below

    if torch_version() < (1, 6, 0):
        pytest.skip("older pytorch doesn't support reduce_scatter")

    if world_size > torch.cuda.device_count():
        pytest.skip("Not enough GPUs.")

    temp_paths = []
    try:
        for _ in range(2):
            # mkstemp() hands back an *open* descriptor plus the path; only
            # the path is used, so close the descriptor instead of leaking it.
            fd, path = tempfile.mkstemp()
            os.close(fd)
            temp_paths.append(path)
        temp_file_name, unused = temp_paths

        # TODO (Min): we may want to extend this to a simple 2 layer model so that it covers
        #             more cases in FSDP. Also, assert_ref_out can be extended to multiple
        #             iterations. This could be a good bootcamp task. I should file a github
        #             issue once we merge.
        model = Linear(3, 3, bias=False)
        mp.spawn(
            _test_func,
            args=(world_size, model, fsdp_config, temp_file_name, unused, test_case),
            nprocs=world_size,
            join=True,
        )
    finally:
        # join=True means every worker has exited by now, so the files can go.
        for path in temp_paths:
            try:
                os.remove(path)
            except FileNotFoundError:
                pass
def test_ddp_parity(
    reduce_buffer_size,
    grad_accumulation,
    change_train_graph,
    fp16_reduction,
    clip_grad_norm,
    amp,
    manual_reduction,
    multiple_fw,
):
    """Spawn one ``run_ddp_parity`` worker per visible GPU for this configuration.

    Skipped on torch < 1.8.0 and when manual reduction is combined with a
    changing train graph.
    """
    if torch_version() < (1, 8, 0):
        pytest.skip("pytorch version >= 1.8.0 required")
    if manual_reduction and change_train_graph:
        pytest.skip("Skipping changing model and grad accumulation combination, makes little sense")

    world_size = torch.cuda.device_count()
    backend = dist.Backend.NCCL

    with temp_files_ctx(num=1) as temp_files:
        # mp.spawn prepends the rank, so these are the worker's remaining
        # positional arguments, in order.
        worker_args = (
            world_size,
            backend,
            temp_files[0],
            reduce_buffer_size,
            grad_accumulation,
            change_train_graph,
            fp16_reduction,
            clip_grad_norm,
            amp,
            manual_reduction,
            multiple_fw,
        )
        mp.spawn(run_ddp_parity, args=worker_args, nprocs=world_size, join=True)
def test(world_size, precision, flatten):
    """
    This test simulates wrapping the module after training to run inference.

    This is required in cases where later in a session, the model is wrapped again in FSDP but
    contains nested FSDP wrappers within the module.

    ``precision == "mixed"`` turns mixed precision on and ``flatten == "flatten"`` turns parameter
    flattening on. Two temporary files are created for the workers and removed once they have
    joined.
    """
    import os  # local import: only needed for the temp-file handling below

    if torch_version() < (1, 6, 0):
        pytest.skip("older pytorch doesn't support reduce_scatter")

    temp_paths = []
    try:
        for _ in range(2):
            # mkstemp() hands back an *open* descriptor plus the path; only
            # the path is used, so close the descriptor instead of leaking it.
            fd, path = tempfile.mkstemp()
            os.close(fd)
            temp_paths.append(path)
        temp_file_name, unused = temp_paths

        fsdp_config = {
            "mixed_precision": precision == "mixed",
            "flatten_parameters": flatten == "flatten",
        }

        mp.spawn(
            _test_func,
            args=(world_size, fsdp_config, temp_file_name, unused),
            nprocs=world_size,
            join=True,
        )
    finally:
        # join=True means every worker has exited by now, so the files can go.
        for path in temp_paths:
            try:
                os.remove(path)
            except FileNotFoundError:
                pass
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import functools
import tempfile

import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from fairscale.nn import MOELayer, Top2Gate
from fairscale.utils import torch_version

pytestmark = pytest.mark.skipif(
    not (torch.cuda.is_available() and torch_version() >= (1, 8, 0)), reason="cuda and torch>=1.8.0 required"
)

devices = ["cuda"]


def pg_worker(rank, world_size, init_file, func, *args):
    """Run ``func(*args)`` inside an NCCL process group set up from ``init_file``.

    The group is initialised through a ``file://`` rendezvous, the CUDA device
    matching ``rank`` is selected, one all-reduce is issued on a zero tensor,
    then ``func`` runs and the group is destroyed.
    """
    dist.init_process_group(
        backend=dist.Backend.NCCL,
        rank=rank,
        world_size=world_size,
        init_method="file://" + init_file,
    )
    torch.cuda.set_device(rank)

    # One collective on a dummy tensor before handing over to the test body.
    probe = torch.zeros(1).cuda()
    dist.all_reduce(probe)

    func(*args)
    dist.destroy_process_group()
import unittest from parameterized import parameterized import pytest import torch from torch import nn import torch.distributed from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper from fairscale.nn.data_parallel import FullyShardedDataParallel, OffloadConfig, TrainingState from fairscale.utils import torch_version from fairscale.utils.testing import dist_init, spawn_for_all_world_sizes # Note: We need the nightly version for SSD offload to work. Hence I am checking for the next PyTorch release. print(f"torch version {torch_version()}") pytestmark = pytest.mark.skipif(torch_version() < (1, 11, 0), reason="requires torch version >= 1.11.0") # How to use remote-pdb: https://gist.github.com/sshleifer/9d43351957179c13606e015b072927d4 # All helper functions called by spawn must be either @classmethod, @staticmethod class DistributedTest(unittest.TestCase): def setUp(self): if not torch.cuda.is_available(): raise unittest.SkipTest("CUDA not available, skipping test") if sys.platform == "win32": raise unittest.SkipTest( "NCCL doesn't support Windows, skipping test") if torch.cuda.device_count() < 2: raise unittest.SkipTest(
def check_pytorch_version() -> None:
    """Raise if the installed PyTorch is older than 1.9.

    Raises:
        Exception: when ``torch_version()`` is below ``(1, 9, 0)``.
    """
    if torch_version() >= (1, 9, 0):
        return
    raise Exception("DistributedPipeline requires PyTorch version 1.9 or higher")
def __init__(
    self,
    module: nn.Sequential,
    balance: Optional[Iterable[int]] = None,
    *,
    devices: Optional[Devices] = None,
    chunks: int = chunks,
    checkpoint: str = checkpoint,
    deferred_batch_norm: bool = False,
) -> None:
    """Validate the arguments, split ``module`` into partitions and build the pipeline.

    Args:
        module: sequential module to partition.
        balance: per-partition layer counts; required despite the ``None``
            default.
        devices: devices to place the partitions on; defaults to every
            visible CUDA device.
        chunks: number of micro-batches; must be positive.
        checkpoint: one of ``'always'``, ``'except_last'`` or ``'never'``.
        deferred_batch_norm: when true, ``module`` is first converted with
            ``DeferredBatchNorm.convert_deferred_batch_norm``.

    Raises:
        ValueError: if ``balance`` is missing, ``chunks`` is not positive,
            ``checkpoint`` is not a known mode, or the module cannot be split
            with the given balance.
    """
    super().__init__()

    # This class was upstreamed to PyTorch; warn users on versions that have it.
    if torch_version()[:2] >= (1, 8):
        warnings.warn(
            "fairscale.nn.Pipe has been upstreamed to PyTorch as torch.distributed.pipeline.sync.Pipe. "
            "It is now deprecated and will be removed in a future version of fairscale. "
            "The PyTorch API has minor changes. Please see https://pytorch.org/docs/stable/pipeline.html for details.",
            DeprecationWarning,
        )

    chunks = int(chunks)
    checkpoint = str(checkpoint)

    # Argument validation, in the same order as the errors are documented.
    if balance is None:
        raise ValueError(recommend_auto_balance("balance is required"))
    if chunks <= 0:
        raise ValueError("number of chunks must be positive integer")
    if checkpoint not in ("always", "except_last", "never"):
        raise ValueError("checkpoint is not one of 'always', 'except_last', or 'never'")

    verify_module(module)

    # The integrity of the underlying skippable modules is static, so it can
    # be verified here, before any forward() call.
    verify_skippables(module)

    self.chunks = chunks
    self.checkpoint = checkpoint

    if deferred_batch_norm:
        module = DeferredBatchNorm.convert_deferred_batch_norm(module, chunks)

    if devices is None:
        devices = range(torch.cuda.device_count())
    devices = cast(List[torch.device], [torch.device(d) for d in devices])

    try:
        split_result = split_module(module, balance, devices)
    except BalanceError as exc:
        raise ValueError(recommend_auto_balance(str(exc)))
    self.partitions, self.balance, self.devices = split_result

    verify_splitting(module, self.partitions, self.balance, self.devices)

    self._copy_streams: List[List[AbstractStream]] = []
    self._skip_layout = inspect_skip_layout(self.partitions)

    # Separate CUDA streams for copy.
    copy_streams = self._ensure_copy_streams()

    # The micro-batch index where the checkpointing stops.
    if self.checkpoint == "always":
        checkpoint_stop = self.chunks
    elif self.checkpoint == "except_last":
        checkpoint_stop = self.chunks - 1
    else:
        # Only "never" can reach this branch: the mode was validated above.
        checkpoint_stop = 0

    self.pipeline = Pipeline(self.partitions, self.devices, copy_streams, self._skip_layout, checkpoint_stop)
def run(compute_cycles, all_gather_cycles):
    """Time forward/backward iterations of a model with artificially delayed all-gathers.

    Builds a model with ``_create_model``, runs 20 iterations in which
    ``torch.distributed.all_gather`` (and ``_all_gather_base`` when present)
    are patched to sleep for ``all_gather_cycles`` before calling the real
    implementation, and records CPU and GPU timings per iteration.

    Args:
        compute_cycles: forwarded to ``_create_model``.
        all_gather_cycles: cycle count passed to ``torch.cuda._sleep`` before
            each all-gather; a value of 0 builds the model without parameters.

    Returns:
        dict with the ``avg()`` of each ``Min10`` collector under the keys
        "cpu_iter", "cpu_wait", "gpu_compute" and "gpu_total".
    """
    has_params = all_gather_cycles > 0
    model = _create_model(fsdp_config, compute_cycles, has_params)

    # Get the input and sets the input's requires_grad to True because
    # we have a fake compute in the forward pass.
    batch = torch.rand(1).cuda()
    batch.requires_grad = True

    # We run 20 iterations but only collect timing data from the minimal 10
    # data points because nondeterministic system events can disturb the timing.
    cpu_iter = Min10()
    cpu_wait = Min10()
    gpu_compute = Min10()
    gpu_total = Min10()
    for _ in range(20):
        # Get two events for measuring the overall time.
        e1 = Event(enable_timing=True)
        e2 = Event(enable_timing=True)

        cpu_start = time.process_time()

        # Flags flipped by the patched collectives below; reset every iteration.
        all_gather_called = False
        all_gather_base_called = False

        def _delayed_all_gather(*args, **kwargs):
            # Record the call, delay on the GPU, then run the real all_gather.
            nonlocal all_gather_called
            all_gather_called = True
            torch.cuda._sleep(all_gather_cycles)
            return orig_all_gather(*args, **kwargs)

        def _delayed_all_gather_base(*args, **kwargs):
            # Same as above for the _all_gather_base variant.
            nonlocal all_gather_base_called
            all_gather_base_called = True
            torch.cuda._sleep(all_gather_cycles)
            assert orig_all_gather_base
            return orig_all_gather_base(*args, **kwargs)

        method_string_all_gather_base = "torch.distributed._all_gather_base"
        if hasattr(torch.distributed, "_all_gather_base") is False:
            # no such method, to make mock_all_gather_base 0 invocation, use an impossible name
            method_string_all_gather_base = "math.nan"
            pass

        # forward pass
        #
        # Even though both e1 & e2 are on the compute stream, since
        # compute depends on all_gather, e2-e1 includes all_gather time.
        e1.record()
        with patch("torch.distributed.all_gather", _delayed_all_gather):
            with patch(method_string_all_gather_base, _delayed_all_gather_base):
                out = model(batch)
                # An all-gather is only expected when the model has parameters
                # and there is more than one rank.
                if has_params and world_size > 1:
                    assert all_gather_called or all_gather_base_called
                else:
                    assert not all_gather_called and not all_gather_base_called
        e2.record()

        # backward pass
        out.backward()
        # Drop the gradients; older torch has no set_to_none argument.
        if torch_version() >= (1, 7, 0):
            model.zero_grad(set_to_none=True)
        else:
            for p in model.parameters():
                p.grad = None

        cpu_iter_time = time.process_time() - cpu_start

        # wait for gpu
        out.item()
        cpu_wait_for_gpu_time = time.process_time() - cpu_start - cpu_iter_time

        # get sum of the compute time
        times = []
        for mod in model.modules():
            if not isinstance(mod, Layer):
                continue
            times.append(mod.get_time())

        # get gpu compute + all_gather time
        overall_gpu_time = e1.elapsed_time(e2)

        cpu_iter.add(cpu_iter_time)
        cpu_wait.add(cpu_wait_for_gpu_time)
        gpu_compute.add(sum(times))
        gpu_total.add(overall_gpu_time)

    del model
    return {
        "cpu_iter": cpu_iter.avg(),
        "cpu_wait": cpu_wait.avg(),
        "gpu_compute": gpu_compute.avg(),
        "gpu_total": gpu_total.avg(),
    }
def run_ddp_parity(
    rank,
    world_size,
    backend,
    temp_file_name,
    reduce_buffer_size,
    grad_accumulation,
    change_train_graph,
    fp16_reduction,
    clip_grad_norm,
    amp,
    manual_reduction,
    multiple_fw,
):
    """Per-rank worker: train one model under DDP and ShardedDDP and assert parity.

    A model is wrapped once in ``ShardedDataParallel`` (with an ``OSS``
    optimizer) and a deep copy of it in PyTorch ``DDP`` (with plain SGD). Both
    are stepped on the same inputs for ``NUMBER_BATCHS`` iterations and their
    parameters are compared after every step.

    Args:
        rank: rank of this process; also used as CUDA device index and seed.
        world_size: number of processes in the group.
        backend: backend passed to ``dist.init_process_group``.
        temp_file_name: file used for the ``file://`` rendezvous.
        reduce_buffer_size: forwarded to ``ShardedDataParallel``.
        grad_accumulation: when true, each closure accumulates over 3 steps.
        change_train_graph: when true, the trainability of the first
            parameter is flipped after the first batch.
        fp16_reduction: enables ``reduce_fp16`` on ShardedDDP and the fp16
            compression communication hook on DDP.
        clip_grad_norm: when true (and the graph is not changed), gradient
            norms are clipped on both sides and compared.
        amp: when true, both sides run under autocast with a grad scaler.
        manual_reduction: when true, ShardedDDP runs its last step under
            ``no_sync()`` and reduces explicitly with ``model.reduce()``.
        multiple_fw: forwarded to ``_get_mlp_emb``.
    """
    dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)

    device = torch.device("cuda")
    torch.cuda.set_device(rank)
    torch.manual_seed(rank)
    np.random.seed(rank)
    NUMBER_BATCHS = 5

    # Test all combinations: AMP, Accumulate, Change train graph, reduce buckets
    print(
        f"{rank}: Checking configuration: accumulate {grad_accumulation}"
        + f" - change train graph {change_train_graph}"
        + f" - amp {amp}"
        + f" - manual reduction {manual_reduction}"
        + f" - buffers {reduce_buffer_size}"
        + f" - multiple FW {multiple_fw}",
        flush=True,
    )

    # The API should be the exact same in between the sharded and non-sharded variants, generic closure
    def closure(model, scaler, input_tensor, should_accumulate, _manual_reduction=False):
        accumulate_steps = 3 if should_accumulate else 1

        model.zero_grad()

        def step():
            # One forward/backward pass, scaled when a grad scaler is in use.
            if scaler is not None:
                with torch.cuda.amp.autocast():
                    loss = model(input_tensor).abs().sum()
                    scaler.scale(loss).backward()
            else:
                loss = model(input_tensor).abs().sum()
                loss.backward()

        # All steps but the last one run without gradient synchronisation
        # when accumulating.
        with model.no_sync() if should_accumulate else suppress():
            for _ in range(accumulate_steps - 1):
                step()

        if not _manual_reduction:
            step()
        else:
            # Manual reduction: keep the last step unsynchronised and trigger
            # the reduction explicitly afterwards.
            with model.no_sync():
                step()

            model.reduce()

    # Any model works. Add one different buffer per rank
    model = _get_mlp_emb(multiple_fw)
    model.register_buffer("test_buffer", torch.ones((1)) * rank)
    model.to(device)

    # Make sure that the model starts with non-trainable, so that we check for the buckets to be
    # properly reassigned when/if this changes
    next(model.parameters()).requires_grad = False

    sharded_optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-4, momentum=0.99)
    sharded_ddp_model = ShardedDataParallel(
        module=model,
        sharded_optimizer=sharded_optimizer,
        broadcast_buffers=True,
        reduce_buffer_size=reduce_buffer_size,
        reduce_fp16=fp16_reduction,
    )

    ddp_model_single = copy.deepcopy(model)
    ddp_optimizer = torch.optim.SGD(ddp_model_single.parameters(), lr=1e-4, momentum=0.99)
    ddp_model = DDP(ddp_model_single, device_ids=[rank], broadcast_buffers=True, find_unused_parameters=True)

    if fp16_reduction:
        # `dist` is only a local alias of torch.distributed and cannot be used
        # as the root of an import statement; the hooks live in the
        # `ddp_comm_hooks` package of torch.distributed.algorithms.
        from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import fp16_compress_hook

        ddp_model.register_comm_hook(state=None, hook=fp16_compress_hook)  # type: ignore

    ddp_scaler = TorchGradScaler() if amp else None
    sharded_scaler = ShardedGradScaler() if amp else None

    # The model should be synchronized in between the ranks at construction time, check that
    check_same_model_params(sharded_ddp_model, ddp_model)

    # Typical training loop, check that we get the exact same results as DDP
    for i in range(NUMBER_BATCHS):
        input_tensor = _get_random_inputs(device)

        # input_tensor is bound as a default so each closure keeps the tensor
        # of the iteration that created it.
        def ddp_closure(input_tensor=input_tensor):
            return closure(ddp_model, ddp_scaler, input_tensor, grad_accumulation)

        def sharded_closure(input_tensor=input_tensor):
            return closure(
                sharded_ddp_model,
                sharded_scaler,
                input_tensor,
                grad_accumulation,
                _manual_reduction=manual_reduction,
            )

        # Step/scale both
        for _scaler, _closure, _optimizer in (
            (ddp_scaler, ddp_closure, ddp_optimizer),
            (sharded_scaler, sharded_closure, sharded_optimizer),
        ):
            if _scaler is not None:
                _ = _closure(input_tensor)
                _scaler.step(_optimizer)
                _scaler.update()
            else:
                _optimizer.step(_closure())

        check_same_model_params(sharded_ddp_model, ddp_model, f"Rank: {rank} - Step {i} broke")

        # Check that the two grad norm are equivalent
        # NOTE: The grads can occasionally be NaNs, the scaler will skip the step in that case
        # This is not ShardedDDP specific. If the grads are not NaN for DDP then they should also
        # be valid for ShardedDDP
        # NOTE: DDP does not handle parameters trainability being changed after the fact, see
        # https://github.com/pytorch/pytorch/blob/5781aec74ef00284e0262817a649278c2e8072bf/torch/nn/parallel/distributed.py#L471
        if clip_grad_norm and not change_train_graph:
            # error_if_nonfinite is only passed on torch >= 1.9.
            if torch_version() >= (1, 9, 0):
                total_norm = torch.nn.utils.clip_grad_norm_(
                    ddp_model.parameters(), 0.3, norm_type=2.0, error_if_nonfinite=False
                )  # type: ignore
            else:
                total_norm = torch.nn.utils.clip_grad_norm_(ddp_model.parameters(), 0.3, norm_type=2.0)  # type: ignore

            if not torch.isnan(total_norm):
                oss_total_norm = sharded_optimizer.clip_grad_norm(0.3, norm_type=2.0)
                allclose = torch.allclose(oss_total_norm, total_norm, atol=1e-2 if amp else 1e-8)

                if not allclose:
                    # Debug helper if this unit test does not pass, compare the gradients in between DDP and ShardedDDP
                    for idx, (p_ddp, p_sdp) in enumerate(zip(ddp_model.parameters(), sharded_ddp_model.parameters())):
                        if p_ddp.grad is not None:
                            if p_sdp.grad is not None:
                                print(rank, idx, torch.norm(p_ddp.grad), torch.norm(p_sdp.grad), flush=True)
                            else:
                                print(rank, idx, torch.norm(p_ddp.grad), "not owned", flush=True)

                assert (
                    allclose
                ), f"torch and fairscale should return the same grad norm\n {oss_total_norm} vs {total_norm}"
            else:
                print(rank, "NaN grad norm in DDP", flush=True)

        # Flip the trainability of the first parameter back and forth
        if i == 0 and change_train_graph:
            next(sharded_ddp_model.parameters()).requires_grad = not next(sharded_ddp_model.parameters()).requires_grad
            next(ddp_model.parameters()).requires_grad = not next(ddp_model.parameters()).requires_grad
            check_same_model_params(sharded_ddp_model, ddp_model, f"Rank: {rank} - Trainability refresh {i} broke")

    dist.destroy_process_group()
import numpy as np
import pytest
import torch
from torch.cuda.amp import GradScaler as TorchGradScaler
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn import Linear, Sequential
from torch.nn.parallel import DistributedDataParallel as DDP

from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim import OSS
from fairscale.utils import torch_version
from fairscale.utils.testing import check_same_model_params, skip_if_no_cuda, skip_if_single_gpu, temp_files_ctx

# ShardedGradScaler is only imported on torch >= 1.8.0, so the name is
# undefined on older versions.
if torch_version() >= (1, 8, 0):
    from fairscale.optim.grad_scaler import ShardedGradScaler

""" Check that ShardedDDP gets the same results as DDP in a variety of scenarii """

# Values of the fp16-reduction option to exercise; True is only added when the
# check below succeeds.
# NOTE(review): hasattr() looks up one attribute whose name is the whole
# string and does not follow dots, so this condition presumably never holds
# and the True case is never added - confirm whether that is intended.
_test_fp16_reduction = [False]

if hasattr(dist, "algorithms.ddp_com_hooks.default_hooks"):
    _test_fp16_reduction.append(True)

# Values of the AMP option to exercise; True is added when autocast exists.
_test_amp = [False]
if hasattr(torch.cuda.amp, "autocast"):
    _test_amp.append(True)

EMB_SIZE = 32
def dist_init(rank: int, world_size: int, filename: str, filename_rpc: str = "") -> bool:
    """
    Initialize torch distributed, based on a temporary file shared across ranks, which makes it possible for unrelated
    tests to be run concurrently.

    Args:
        rank: rank of the calling process.
        world_size: total number of ranks.
        filename: path used for the ``file://`` init method of the process group.
        filename_rpc: path used for the ``file://`` init method of RPC.

    Returns:
        ``True`` once initialization is done. ``False`` if not enough GPUs are present in the system, or, on
        torch < 1.6, if ``world_size`` is 1 and CUDA is not available.

    .. warning: This limits the usecase to all ranks being on the same node
    """

    # Shut down RPC first and ignore any failure.
    # NOTE(review): presumably clears state left over by an earlier test in
    # the same process - confirm.
    try:
        torch.distributed.rpc.shutdown()
    except Exception:
        pass

    print(f"dist init r={rank}, world={world_size}")

    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["RANK"] = str(rank)
    url = "file://" + filename
    url_rpc = "file://" + filename_rpc

    if torch_version() >= (1, 6, 0):
        # NCCL when CUDA is available, gloo otherwise.
        backend = "nccl" if torch.cuda.is_available() else "gloo"
        if backend == "nccl" and torch.cuda.device_count() < world_size:
            logging.warning("Requested world size cannot be reached on this machine, not enough GPUs")
            return False

        torch.distributed.init_process_group(backend=backend, rank=rank, world_size=world_size, init_method=url)

        tp_options = {"init_method": url_rpc}
        # Workaround for bug in torch v1.8.0. Should be fixed in v1.8.1
        if torch_version() == (1, 8, 0):
            if torch.cuda.is_available():
                # Workaround for https://github.com/pytorch/pytorch/issues/53844
                tp_options["_transports"] = ["ibv", "uv"]  # type: ignore
            else:
                # Workaround for https://github.com/pytorch/pytorch/issues/54266
                tp_options["_channels"] = ["mpt_uv", "basic", "cuda_ipc", "cuda_gdr", "cuda_xth", "cuda_basic"]  # type: ignore

        rpc.init_rpc(
            f"Test{rank}",
            rank=rank,
            world_size=world_size,
            backend=rpc.BackendType.TENSORPIPE,
            rpc_backend_options=rpc.TensorPipeRpcBackendOptions(**tp_options),
        )

    else:
        # Older torch: RPC only for multi-rank runs, otherwise a plain NCCL
        # process group when CUDA is available.
        if world_size > 1:
            # TensorPipe is not available in Torch 1.5
            rpc.init_rpc(
                name=f"Test{rank}",
                rank=rank,
                world_size=world_size,
                rpc_backend_options=rpc.ProcessGroupRpcBackendOptions(init_method=url_rpc),
            )
        elif torch.cuda.is_available():
            torch.distributed.init_process_group(backend="nccl", rank=rank, world_size=world_size, init_method=url)
        else:
            return False

    # Map each rank onto a device, wrapping around when there are more ranks
    # than devices.
    if torch.cuda.is_available() and torch.cuda.device_count():
        torch.cuda.set_device(rank % torch.cuda.device_count())

    return True
from typing import Any, Dict, List, NamedTuple, Tuple import pytest import torch import torch.distributed.autograd as dist_autograd from torch.distributed.nn import RemoteModule from torch.distributed.optim import DistributedOptimizer import torch.distributed.rpc as rpc import torch.multiprocessing as mp import torch.nn as nn from fairscale.experimental.nn.distributed_pipeline import DistributedLoss, DistributedPipeline, PipelineModulesGraph from fairscale.utils import torch_version pytestmark = pytest.mark.skipif( not torch.cuda.is_available() or torch_version() < (1, 9, 0), reason="CPU tests fail right now and all tests require torch version >= 1.9.0.", ) CPU_DEVICES = ["worker0/cpu", "worker1/cpu"] GPU_DEVICES = ["worker0/cuda:0", "worker1/cuda:1"] if torch.cuda.is_available(): DEVICES = [CPU_DEVICES, GPU_DEVICES] else: DEVICES = [CPU_DEVICES] def rpc_worker(rank, world_size, init_file, func, *args): options = rpc.TensorPipeRpcBackendOptions(init_method="file://" + init_file) for i in range(world_size): options.set_device_map("worker" + str(i), {rank: i})
def test_multiple_forward_checkpoint(precision, flatten, wrap_bn, model_type, bn_type):
    """Ensure ddp == fsdp == fsdp+ckpt when modules run several forward passes.

    The first configuration (ddp without checkpointing) provides the reference
    losses; every later configuration is compared against it with ``check``.
    """
    use_mixed = precision == "mixed"
    use_flatten = flatten == "flatten"
    auto_wrap_bn = wrap_bn == "auto_wrap_bn"
    fp32_reduce_scatter = use_mixed or None  # True in mixed precision, else None
    use_model2 = model_type == "model2"
    use_sync_bn = bn_type == "sync_bn"

    if (1, 7, 0) <= torch_version() < (1, 8, 0) and use_sync_bn:
        # SyncBN is buggy in 1.7, errors like:
        # E         File "/home/circleci/venv/lib/python3.8/site-packages/torch/nn/modules/_functions.py", line 13, in forward
        # E           dtype=running_mean.dtype,
        # E       AttributeError: 'NoneType' object has no attribute 'dtype'
        pytest.skip("SyncBatchNorm in 1.7 is buggy")

    if use_sync_bn and not auto_wrap_bn:
        pytest.skip("SyncBatchNorm requires auto_wrap_bn")

    if use_flatten and torch_version() < (1, 8, 0):
        # 1.6 and 1.7 throws this error:
        #   RuntimeError: Trying to backward through the graph a second time, but the saved
        #   intermediate results have already been freed. Specify retain_graph=True when calling
        #   backward the first time.
        pytest.skip("older pytorch throws error when flatten is used")

    def check(exp, res):
        """Assert both loss mappings share their keys and hold close values."""
        exp_keys = list(exp.keys())
        res_keys = list(res.keys())
        assert exp_keys == res_keys, f"{exp_keys} vs. {res_keys}"
        # On CI, with longer model2, mixed precision and 1.9, even ddp vs.
        # ddp+ckpt has larger errors, hence the looser tolerances.
        relaxed = use_model2 and use_mixed and torch_version() >= (1, 9, 0)
        rtol, atol = (1e-3, 1e-4) if relaxed else (1e-4, 1e-5)
        for key in exp_keys:
            torch.testing.assert_allclose(exp[key], res[key], rtol=rtol, atol=atol)

    world_size = 2

    # Checkpointing is only exercised together with fsdp, which leaves:
    # ddp, fsdp and fsdp+ckpt, in that order.
    configurations = [
        (with_fsdp, with_checkpoint)
        for with_fsdp in (False, True)
        for with_checkpoint in (False, True)
        if with_fsdp or not with_checkpoint
    ]

    expected_losses = None
    for with_fsdp, with_checkpoint in configurations:
        final_losses = _get_cached_results(
            world_size,
            use_model2,
            use_sync_bn,
            with_fsdp,
            with_checkpoint,
            use_mixed,
            use_flatten,
            auto_wrap_bn,
            fp32_reduce_scatter,
        )
        if expected_losses is None:
            # The first result becomes the reference for all the others.
            expected_losses = final_losses
            continue

        print(f"checking: fsdp {with_fsdp} ckpt {with_checkpoint} with ddp+no_ckpt")
        check(expected_losses, final_losses)
def test_multiple_forward_checkpoint(precision, flatten, wrap_bn, model_type, bn_type):
    """Check that DDP and FSDP agree across checkpointing and bucketing settings.

    Each argument is a string flag that is turned into a boolean by comparing
    it with one fixed value ("mixed", "flatten", "auto_wrap_bn", "model2",
    "sync_bn"). The first configuration that yields results (DDP without
    checkpointing or bucketing) becomes the reference; every FSDP
    configuration is then compared against it key by key.
    """
    mixed_precision = precision == "mixed"
    flatten = flatten == "flatten"
    wrap_bn = wrap_bn == "auto_wrap_bn"
    with_model2 = model_type == "model2"
    with_sync_bn = bn_type == "sync_bn"
    fp32_reduce_scatter = True if mixed_precision else None

    if with_sync_bn and (1, 7, 0) <= torch_version() < (1, 8, 0):
        # SyncBN is buggy in 1.7; it fails inside torch/nn/modules/_functions.py with
        #   AttributeError: 'NoneType' object has no attribute 'dtype'
        pytest.skip("SyncBatchNorm in 1.7 is buggy")

    if with_sync_bn and not wrap_bn:
        pytest.skip("SyncBatchNorm requires auto_wrap_bn")

    if flatten and torch_version() < (1, 8, 0):
        # 1.6 and 1.7 raise:
        #   RuntimeError: Trying to backward through the graph a second time, but the saved
        #   intermediate results have already been freed. Specify retain_graph=True when calling
        #   backward the first time.
        pytest.skip("older pytorch throws error when flatten is used")

    # On CI, with the longer model2, mixed precision and torch >= 1.9, even ddp vs. ddp+ckpt
    # shows larger errors, so the tolerances are loosened for that combination.
    loosen = with_model2 and mixed_precision and torch_version() >= (1, 9, 0)
    rtol, atol = (1e-3, 1e-4) if loosen else (1e-4, 1e-5)

    def assert_losses_match(reference, candidate):
        # Both results must expose the same keys, in the same order.
        ref_keys = list(reference.keys())
        cand_keys = list(candidate.keys())
        assert ref_keys == cand_keys, f"{ref_keys} vs. {cand_keys}"
        for name in ref_keys:
            torch.testing.assert_allclose(reference[name], candidate[name], rtol=rtol, atol=atol)

    world_size = 2

    # Ensure ddp == fsdp when modules are called multiple times per forward pass, with and
    # without checkpointing, forward counters and reducer bucketing.
    #
    # Bucketing is toggled because the asynchronous gradient reduction it induces can interact
    # with multiple forward passes in complex ways. For example, in the midst of a sharded
    # backward pass, `parameter.grad` may only be `None` or an unsharded gradient tensor; the
    # sharded tensor is set at the end of the backward pass. A test that only runs with bucketing
    # enabled might not catch violations of this invariant: for very small models like the ones
    # used here, bucketing delays gradient reduction until all gradient computation is done, so a
    # reduction that incorrectly sets `.grad` to the _sharded_ variant could go unnoticed.
    # Running with and without bucketing verifies that reduction and computation interact
    # correctly.
    #
    # Entries are (with_fsdp, with_checkpoint, with_bucketing); the DDP baseline uses neither
    # checkpointing nor bucketing.
    combinations = [(False, False, False)]
    combinations += [(True, ckpt, bucketing) for ckpt in (False, True) for bucketing in (False, True)]

    print("")
    print("Testing the following configurations:")
    for with_fsdp, with_checkpoint, with_bucketing in combinations:
        print(f" fsdp {with_fsdp} ckpt {with_checkpoint} bucketing {with_bucketing}")

    expected_losses = None
    for with_fsdp, with_checkpoint, with_bucketing in combinations:
        bucket_cap_mb = 25 if with_bucketing else 0
        final_losses = _get_cached_results(
            world_size,
            with_model2,
            with_sync_bn,
            with_fsdp,
            with_checkpoint,
            mixed_precision,
            flatten,
            wrap_bn,
            fp32_reduce_scatter,
            bucket_cap_mb,
        )
        if expected_losses is None:
            # First available result is the reference for all later comparisons.
            expected_losses = final_losses
            continue
        print(f"checking: fsdp {with_fsdp} ckpt {with_checkpoint} bucketing {with_bucketing} with ddp+no_ckpt")
        assert_losses_match(expected_losses, final_losses)
def _distributed_worker(
    gpu_id, world_size, with_fsdp, with_checkpoint, filename, filename_rpc, expected, model_hidden_dim, fsdp_config
):
    """Run a short training loop on one GPU and compare recorded memory stats with `expected`.

    The worker binds itself to `gpu_id` (also used as its rank), initialises the
    distributed state, and wraps the model either with `to_fsdp` or with
    `DistributedDataParallel`. It then runs four iterations, taking a memory
    snapshot via `get_cur_mem` at fixed points. The collected `results` must
    have the same keys as `expected`, and each value may differ by at most 1
    (the original comment calls this a 1MB rounding allowance).
    """
    torch.cuda.set_device(gpu_id)

    initialized = dist_init(gpu_id, world_size, filename, filename_rpc)
    assert initialized, "Dist init failed"

    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True

    # FSDP auto-casts the input in AMP mode, so there is no need to call half() here.
    inputs = torch.randn(size=(2, 3, 224, 224)).cuda()

    model = create_model(with_fsdp, with_checkpoint, model_hidden_dim, fsdp_config).cuda()
    if with_fsdp:
        model = to_fsdp(model, fsdp_config)
    else:
        model = DistributedDataParallel(model, device_ids=[gpu_id], bucket_cap_mb=500)

    # Momentum is enabled so that, after the first iteration, the optimizer state is counted
    # in the total memory used.
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)

    # Use autocast only when the FSDP config asks for mixed precision; otherwise a no-op context.
    use_amp = "mixed_precision" in fsdp_config and fsdp_config["mixed_precision"]
    amp_context = torch.cuda.amp.autocast(enabled=True) if use_amp else contextlib.suppress()

    results = {}  # memory stats, keyed by "iter <n>: <stage>"

    def snapshot(label):
        get_cur_mem(gpu_id, results, label)

    # It has been observed that sometimes, after the 3rd iteration, the 4th one can fail (not in
    # this test but in much bigger ones). Run 4 iterations just in case.
    iterations = 4
    for iteration in range(iterations):
        snapshot(f"iter {iteration}: start")

        with amp_context:
            out = model(inputs)
            snapshot(f"iter {iteration}: after fwd")

            out = sum(o.sum() for o in out[0])
            fake_loss = loss_fn(out, torch.tensor(0.0).cuda())
            snapshot(f"iter {iteration}: after loss")

        fake_loss.backward()
        snapshot(f"iter {iteration}: after bwd")

        optimizer.step()
        snapshot(f"iter {iteration}: after step")

        # Gradients must be set to None (not zeroed via optimizer.zero_grad()) to reclaim memory.
        if torch_version() < (1, 7, 0):
            for param in model.parameters():
                param.grad = None
        else:
            model.zero_grad(set_to_none=True)
        snapshot(f"iter {iteration}: done")

    dump_all_tensors(gpu_id)
    print(results)

    def mismatch_report(actual, reference):
        # One line per stage whose value is off by more than 1 (allow 1MB rounding differences).
        assert actual.keys() == reference.keys(), f"{list(actual.keys())} vs. {list(reference.keys())}"
        return "".join(
            f"{stage}: got {value}, expected {reference[stage]}\n"
            for stage, value in actual.items()
            if abs(reference[stage] - value) > 1
        )

    output = mismatch_report(results, expected)
    assert not output, output

    teardown()
all_reduce_handle = dist.all_reduce(total_count, group=process_group, async_op=True) mean = torch.mean(input, dim=dim, keepdim=True) meansqr = torch.mean(input * input, dim=dim, keepdim=True) vec = torch.cat([mean, meansqr]) all_reduce_handle.wait() vec = vec * (count / total_count) dist.all_reduce(vec, group=process_group) mean, meansqr = vec.chunk(2) var = meansqr - mean * mean invstd = torch.rsqrt(var + eps) return mean, var, invstd, total_count if torch_version()[:2] >= (1, 7): _forward = torch.jit.script(_forward) # type: ignore _track_running_stats = torch.jit.script( _track_running_stats) # type: ignore class _SyncBatchNormFunction(torch.autograd.Function): """ An autograd function used to avoid storing activations for intermediate results. NOTE: Even though the mean and var are passed into this function, we do the entire backward, including mean and var, here. We have to calculate statistics outside this function in order to avoid multiple all_reduces when using checkpointing. """ @staticmethod # type: ignore
@torch_spawn([3])
@pytest.mark.skipif("OMPI_COMM_WORLD_RANK" not in os.environ, reason="mpi required")
def rpc_multiple_tensors():
    """Declare two helper modules for a multi-tensor RPC pipeline test.

    NOTE(review): the body only defines `FuseTwo` and `SplitTwo`; neither is
    instantiated or run here, so this test currently asserts nothing. It looks
    like an unfinished test -- confirm.
    """

    class FuseTwo(nn.Module):
        """Combine two inputs into one by addition."""

        def forward(self, left, right):
            return left + right

    class SplitTwo(nn.Module):
        """Turn one input into a pair: the input itself and twice the input."""

        def forward(self, inputs):
            doubled = 2 * inputs
            return (inputs, doubled)


@torch_spawn([2])
@pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="no mpi")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
# TODO(msb) Fix this
@pytest.mark.skipif(torch_version() >= (1, 8, 0), reason="disabled for torch 1.8.0")
def construct_only_rank_zero():
    """Check that `PipeRPCWrapper` can be constructed on rank 0 but not on another rank.

    Rank 0 builds the wrapper and then shuts RPC down. Any other rank first
    calls `rpc.shutdown()` and then expects its own construction attempt to
    raise `AssertionError`.
    """
    model = [nn.Linear(10, 10), nn.ReLU()]

    if torch.distributed.get_rank() != 0:
        # Must enter the rpc loop to complete the PipeRPCWrapper constructor running on rank 0.
        rpc.shutdown()
        with pytest.raises(AssertionError):
            PipeRPCWrapper(model, [1, 1], worker_map=get_worker_map())
        return

    PipeRPCWrapper(model, [1, 1], worker_map=get_worker_map())
    rpc.shutdown()