def test_catch_grad_grad():
    with temp_files_ctx(num=1) as temp_files:
        # Check that ShardedDDP raises when a parameter's .grad itself requires grad
        dist.init_process_group(init_method="file://" + temp_files[0], backend="gloo", rank=0, world_size=1)

        model = Sequential(Linear(2, 3), Linear(3, 3))
        model.train()
        chained_grad = torch.zeros_like(next(model.parameters()))
        chained_grad.requires_grad = True
        next(model.parameters()).grad = chained_grad

        optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
        ddp_model = ShardedDataParallel(model, optimizer)

        inputs = torch.rand(100, 2)
        with pytest.raises(RuntimeError):
            _ = ddp_model(inputs)

        dist.destroy_process_group()
def _get_cached_results(
    world_size,
    with_model2,
    with_sync_bn,
    with_fsdp,
    with_checkpoint,
    mixed_precision,
    flatten,
    wrap_bn,
    fp32_reduce_scatter,
    bucket_cap_mb,
):
    """Cache the training results to save time.

    For DDP, flatten, wrap_bn etc. don't matter, so those results can be cached.
    """
    if not with_fsdp:
        flatten = None
        wrap_bn = None
        fp32_reduce_scatter = None

    key = (
        world_size,
        with_model2,
        with_sync_bn,
        with_fsdp,
        with_checkpoint,
        mixed_precision,
        flatten,
        wrap_bn,
        fp32_reduce_scatter,
        bucket_cap_mb,
    )
    global _result_cache
    if key not in _result_cache:
        # Get world_size + 2 files: 2 for dist_init and one per rank to save the losses.
        with temp_files_ctx(num=2 + world_size) as temp_files:
            mp.spawn(
                _distributed_worker,
                (
                    world_size,
                    with_model2,
                    with_sync_bn,
                    with_fsdp,
                    with_checkpoint,
                    temp_files,
                    mixed_precision,
                    flatten,
                    wrap_bn,
                    fp32_reduce_scatter,
                    bucket_cap_mb,
                ),
                nprocs=world_size,
            )
            final_losses = {}
            for rank in range(world_size):
                with open(temp_files[2 + rank], "rb") as f:
                    for iter_key, loss in pickle.load(f).items():
                        final_losses[f"rank_{rank}_{iter_key}"] = loss
            _result_cache[key] = final_losses

    return _result_cache[key]
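# Note (added for clarity, not in the original source): the `global _result_cache`
# statement above assumes a module-level cache. A minimal sketch of that declaration,
# placed near the top of the test module, would be:
_result_cache = {}  # keyed by the full configuration tuple built in _get_cached_results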
def test_train_and_eval_with_checkpointing(flatten, mixed_precision, amp_context, half_input, fsdp_wrap_ckpt):
    flatten = flatten == "flat"
    mixed_precision = mixed_precision == "fp16"
    amp_context = amp_context == "autocast"
    half_input = half_input == "halfin"
    fsdp_wrap_ckpt = fsdp_wrap_ckpt == "F->C"

    # Expecting a known bug in 4 out of 32 cases.
    if fsdp_wrap_ckpt and mixed_precision and not flatten:
        pytest.skip("known bug")

    world_size = 2

    with temp_files_ctx(2) as (temp_file_name, unused):
        mp.spawn(
            _test_func,
            args=(
                world_size,
                temp_file_name,
                unused,
                flatten,
                mixed_precision,
                amp_context,
                half_input,
                fsdp_wrap_ckpt,
            ),
            nprocs=world_size,
            join=True,
        )
def test_consolidation(embedding_size: int, flatten_parameters: bool):
    world_size = 2
    with in_temporary_directory():
        with temp_files_ctx(num=1) as temp_files:
            mp.spawn(_worker, (temp_files[0], world_size, embedding_size, flatten_parameters), nprocs=world_size)
def test_gpt2(world_size):
    # Check that having trainable, unused params is fine
    backend = "gloo"
    device = "cuda"
    with temp_files_ctx(num=1) as temp_files:
        mp.spawn(run_test_gpt2, args=(world_size, backend, device, temp_files[0]), nprocs=world_size, join=True)
def test_forward_overlap(world_size, flatten, mixed):
    fsdp_config = {
        "flatten_parameters": flatten == "flatten",
        "mixed_precision": mixed == "mixed",
    }
    with temp_files_ctx(2) as temp_files:
        mp.spawn(
            _distributed_worker,
            (world_size, fsdp_config, temp_files[0], temp_files[1]),
            nprocs=world_size,
        )
def test_train_eval_change():
    world_size = 4
    with temp_files_ctx(num=1) as temp_files:
        mp.spawn(
            run_test_train_eval_change,
            args=(world_size, temp_files[0]),
            nprocs=world_size,
            join=True,
        )
def test_multiple_groups(reduce_buffer_size, backend):
    world_size = 4
    with temp_files_ctx(num=1) as temp_files:
        mp.spawn(
            run_test_multiple_groups,
            args=(world_size, temp_files[0], backend, reduce_buffer_size),
            nprocs=world_size,
            join=True,
        )
def test_ddp_parity_two_optim(reduce_buffer_size):
    world_size = 2
    backend = dist.Backend.NCCL
    with temp_files_ctx(num=1) as temp_files:
        mp.spawn(
            run_ddp_parity_two_optim,
            args=(world_size, backend, temp_files[0], reduce_buffer_size),
            nprocs=world_size,
            join=True,
        )
def run_test(backend, device, world_size, broadcast_buffers, grad_accumulation, reduce_buffer_size, optimizer_type):
    with temp_files_ctx(num=1) as temp_files:
        mp.spawn(
            run_one_step,
            args=(
                world_size,
                backend,
                device,
                temp_files[0],
                broadcast_buffers,
                grad_accumulation,
                reduce_buffer_size,
                optimizer_type,
            ),
            nprocs=world_size,
            join=True,
        )
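# Illustration only (not from the original source): one way a parametrized test could
# delegate to the run_test helper above. The test name and parameter values below are
# assumptions; the repository's actual parametrization may differ.
@pytest.mark.parametrize("broadcast_buffers", [True, False])
@pytest.mark.parametrize("grad_accumulation", [True, False])
def test_step_gloo_cpu(broadcast_buffers, grad_accumulation):
    run_test(
        backend="gloo",
        device="cpu",
        world_size=2,
        broadcast_buffers=broadcast_buffers,
        grad_accumulation=grad_accumulation,
        reduce_buffer_size=0,
        optimizer_type=torch.optim.SGD,
    )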
def test_two_optimizers():
    # Check that ShardedDDP handles a model whose parameters are split across two optimizers
    world_size = 2
    backend = "gloo"
    device = "cpu"
    with temp_files_ctx(num=1) as temp_files:
        mp.spawn(run_test_two_optimizers, args=(world_size, backend, device, temp_files[0]), nprocs=world_size, join=True)
def test_training_change(reduce_buffer_size):
    world_size = 2
    backend = "nccl"
    device = "cuda"
    with temp_files_ctx(num=1) as temp_files:
        mp.spawn(
            run_test_training_change,
            args=(world_size, backend, device, temp_files[0], reduce_buffer_size),
            nprocs=world_size,
            join=True,
        )
def test_ddp_sync_batch_norm():
    # Check that ShardedDDP is compatible with sync batch norm across multiple GPUs
    world_size = 2
    backend = "gloo"
    device = "cuda"
    with temp_files_ctx(num=1) as temp_files:
        mp.spawn(
            run_test_ddp_sync_batch_norm,
            args=(world_size, backend, device, temp_files[0]),
            nprocs=world_size,
            join=True,
        )
def test_device_change(reduce_buffer_size):
    # Check that ShardedDDP handles a device change properly
    world_size = 2
    backend = "nccl"
    with temp_files_ctx(num=1) as temp_files:
        device = "cuda"
        mp.spawn(
            run_test_device_change,
            args=(world_size, backend, device, temp_files[0], reduce_buffer_size),
            nprocs=world_size,
            join=True,
        )
def test_train_and_eval_with_checkpointing():
    if torch_version() < (1, 6, 0):
        pytest.skip("older pytorch doesn't support reduce_scatter")

    world_size = 2

    with temp_files_ctx(2) as (temp_file_name, unused):
        mp.spawn(
            _test_func,
            args=(world_size, temp_file_name, unused),
            nprocs=world_size,
            join=True,
        )
def test_memory_tracking_fsdp():
    """
    Check that we can collect memory traces of a simplistic model
    in the context of FSDP distributed training.
    """
    with temp_files_ctx(num=2) as sync_files:
        world_size = 2
        mp.spawn(
            _layer_memory_tracking_fsdp_worker,
            (sync_files, world_size),
            nprocs=world_size,
        )
def test_inputs(reduce_buffer_size, backend, device):
    # Check that the ShardedDDP wrapper accepts tuple(tensors) as inputs
    world_size = 2
    if backend == "nccl" and device == "cpu":
        pytest.skip("Incompatible combination, or cuda not available")

    with temp_files_ctx(num=1) as temp_files:
        mp.spawn(
            run_test_two_inputs,
            args=(world_size, backend, device, temp_files[0], reduce_buffer_size),
            nprocs=world_size,
            join=True,
        )
def test_multiple_forward_checkpoint(precision, flatten, wrap_bn):
    mixed_precision = precision == "mixed"
    flatten = flatten == "flatten"
    wrap_bn = wrap_bn == "auto_wrap_bn"
    fp32_reduce_scatter = True if mixed_precision else None

    if torch_version() < (1, 8, 0) and flatten:
        # 1.6 and 1.7 throw this error:
        #   RuntimeError: Trying to backward through the graph a second time, but the saved
        #   intermediate results have already been freed. Specify retain_graph=True when calling
        #   backward the first time.
        pytest.skip("older pytorch throws error when flatten is used")

    world_size = 2
    expected_losses = None
    # Ensure ddp == ddp+ckpt == fsdp == fsdp+ckpt.
    for with_fsdp in [False, True]:
        for with_checkpoint in [False, True]:
            # Get 4 files: 2 for dist_init and one per rank to save the losses.
            with temp_files_ctx(num=2 + world_size) as temp_files:
                mp.spawn(
                    _distributed_worker,
                    (
                        world_size,
                        with_fsdp,
                        with_checkpoint,
                        temp_files,
                        mixed_precision,
                        flatten,
                        wrap_bn,
                        fp32_reduce_scatter,
                    ),
                    nprocs=world_size,
                )
                final_losses = {}
                for rank in range(world_size):
                    with open(temp_files[2 + rank], "rb") as f:
                        final_losses[f"rank_{rank}"] = pickle.load(f)
                if expected_losses is None:
                    expected_losses = final_losses
                else:
                    print(f"fsdp: {with_fsdp} ckpt: {with_checkpoint}")
                    assert objects_are_equal(expected_losses, final_losses, raise_exception=True)
def test_local_state_dict_calls_state_dict_recursion():
    """Test the infinite-recursion case when FSDP is subclassed and state_dict delegates to local_state_dict."""

    class TestModule(FSDP):
        def __init__(self):
            super().__init__(module=nn.Linear(100, 100))

        def state_dict(self, *args, **kwargs):
            return self.local_state_dict(*args, **kwargs)

    rank = 0
    world_size = 1
    with temp_files_ctx(2) as temp_files:
        result = dist_init(rank, world_size, temp_files[0], temp_files[1])
        assert result, "Dist init failed"

        m = TestModule()
        d = m.state_dict()

        teardown()
def test_mixed_types():
    with temp_files_ctx(num=1) as temp_files:
        # Check that ShardedDDP handles a model with mixed types (the tripwire option of _get_mlp)
        dist.init_process_group(init_method="file://" + temp_files[0], backend="gloo", rank=0, world_size=1)

        model = _get_mlp(tripwire=True)

        optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
        model = ShardedDataParallel(model, optimizer)
        input_tensor = torch.rand((2, 2))
        _ = model(input_tensor)

        dist.destroy_process_group()
def test_ddp_attributes():
    # Check that ShardedDDP exposes the same attributes as Pytorch's DDP:
    # - is_multi_device_module
    # - device_type
    with temp_files_ctx(num=1) as temp_files:
        dist.init_process_group(init_method="file://" + temp_files[0], backend="gloo", rank=0, world_size=1)

        model = Sequential(Linear(2, 3), Linear(3, 3))
        optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
        ddp_model = ShardedDataParallel(model, optimizer)

        assert hasattr(ddp_model, "is_multi_device_module")
        assert hasattr(ddp_model, "device_type")

        dist.destroy_process_group()
def test_step(broadcast_buffers, grad_accumulation, reduce_buffer_size, optimizer_type, reduce_fp16, setup):
    world_size = 2
    with temp_files_ctx(num=1) as temp_files:
        mp.spawn(
            run_one_step,
            args=(
                world_size,
                setup[0],
                setup[1],
                temp_files[0],
                broadcast_buffers,
                grad_accumulation,
                reduce_buffer_size,
                optimizer_type,
                reduce_fp16,
            ),
            nprocs=world_size,
            join=True,
        )
def test_random_attributes():
    with temp_files_ctx(num=1) as temp_files:
        # Check that ShardedDDP exposes the original module's attributes
        dist.init_process_group(init_method="file://" + temp_files[0], backend="gloo", rank=0, world_size=1)

        model = Sequential(Linear(2, 3), Linear(3, 3))
        model.banana = "sweet"

        optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
        ddp_model = ShardedDataParallel(model, optimizer)

        assert hasattr(ddp_model, "banana")
        assert not hasattr(ddp_model, "orange")

        dist.destroy_process_group()
def test_ddp_parity(
    reduce_buffer_size,
    grad_accumulation,
    change_train_graph,
    fp16_reduction,
    clip_grad_norm,
    amp,
    manual_reduction,
    multiple_fw,
):
    if torch_version() < (1, 8, 0):
        pytest.skip("pytorch version >= 1.8.0 required")
    if manual_reduction and change_train_graph:
        pytest.skip("Skipping the combination of manual reduction and changing the train graph, makes little sense")

    world_size = torch.cuda.device_count()
    backend = dist.Backend.NCCL
    with temp_files_ctx(num=1) as temp_files:
        mp.spawn(
            run_ddp_parity,
            args=(
                world_size,
                backend,
                temp_files[0],
                reduce_buffer_size,
                grad_accumulation,
                change_train_graph,
                fp16_reduction,
                clip_grad_norm,
                amp,
                manual_reduction,
                multiple_fw,
            ),
            nprocs=world_size,
            join=True,
        )
@pytest.fixture()
def temp_files():
    # dist_init needs 2 files
    with temp_files_ctx(2) as files:
        yield files
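# Illustration only (not in the original source): a test consuming the temp_files fixture
# above. The two files map to dist_init's file-based init and RPC init, mirroring the
# positional call in test_local_state_dict_calls_state_dict_recursion. The test name is a
# placeholder.
def test_uses_temp_files(temp_files):
    assert dist_init(0, 1, temp_files[0], temp_files[1]), "Dist init failed"
    teardown()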
def test_fsdp_memory(fsdp, ckpt):
    expected = {
        ("ddp", "no_ckpt"): {
            "iter 0: start": 9, "iter 0: after fwd": 346, "iter 0: after loss": 346,
            "iter 0: after bwd": 14, "iter 0: after step": 14, "iter 0: done": 9,
            "iter 1: start": 9, "iter 1: after fwd": 346, "iter 1: after loss": 346,
            "iter 1: after bwd": 14, "iter 1: after step": 14, "iter 1: done": 9,
            "iter 2: start": 9, "iter 2: after fwd": 346, "iter 2: after loss": 346,
            "iter 2: after bwd": 14, "iter 2: after step": 14, "iter 2: done": 9,
        },
        ("fsdp", "no_ckpt"): {
            "iter 0: start": 3, "iter 0: after fwd": 340, "iter 0: after loss": 340,
            "iter 0: after bwd": 66, "iter 0: after step": 66, "iter 0: done": 3,
            "iter 1: start": 3, "iter 1: after fwd": 340, "iter 1: after loss": 340,
            "iter 1: after bwd": 66, "iter 1: after step": 66, "iter 1: done": 3,
            "iter 2: start": 3, "iter 2: after fwd": 340, "iter 2: after loss": 340,
            "iter 2: after bwd": 66, "iter 2: after step": 66, "iter 2: done": 3,
        },
        ("ddp", "ckpt"): {
            "iter 0: start": 9, "iter 0: after fwd": 57, "iter 0: after loss": 57,
            "iter 0: after bwd": 14, "iter 0: after step": 14, "iter 0: done": 9,
            "iter 1: start": 9, "iter 1: after fwd": 57, "iter 1: after loss": 57,
            "iter 1: after bwd": 14, "iter 1: after step": 14, "iter 1: done": 9,
            "iter 2: start": 9, "iter 2: after fwd": 57, "iter 2: after loss": 57,
            "iter 2: after bwd": 14, "iter 2: after step": 14, "iter 2: done": 9,
        },
        ("fsdp", "ckpt"): {
            "iter 0: start": 3, "iter 0: after fwd": 51, "iter 0: after loss": 51,
            "iter 0: after bwd": 66, "iter 0: after step": 66, "iter 0: done": 3,
            "iter 1: start": 3, "iter 1: after fwd": 51, "iter 1: after loss": 51,
            "iter 1: after bwd": 66, "iter 1: after step": 66, "iter 1: done": 3,
            "iter 2: start": 3, "iter 2: after fwd": 51, "iter 2: after loss": 51,
            "iter 2: after bwd": 66, "iter 2: after step": 66, "iter 2: done": 3,
        },
    }[(fsdp, ckpt)]

    fsdp = fsdp == "fsdp"
    ckpt = ckpt == "ckpt"
    world_size = 2
    with temp_files_ctx(num=2) as temp_files:
        mp.spawn(
            _distributed_worker,
            (world_size, fsdp, ckpt, temp_files[0], temp_files[1], expected),
            nprocs=world_size,
        )
def test_fsdp_memory(fsdp, ckpt):
    expected = {
        ("ddp", "no_ckpt"): {
            "iter 0: start": 9, "iter 0: after fwd": 346, "iter 0: after loss": 346,
            "iter 0: after bwd": 14, "iter 0: after step": 17, "iter 0: done": 13,
            "iter 1: start": 13, "iter 1: after fwd": 350, "iter 1: after loss": 350,
            "iter 1: after bwd": 17, "iter 1: after step": 17, "iter 1: done": 13,
            "iter 2: start": 13, "iter 2: after fwd": 350, "iter 2: after loss": 350,
            "iter 2: after bwd": 17, "iter 2: after step": 17, "iter 2: done": 13,
            "iter 3: start": 13, "iter 3: after fwd": 350, "iter 3: after loss": 350,
            "iter 3: after bwd": 17, "iter 3: after step": 17, "iter 3: done": 13,
        },
        ("fsdp", "no_ckpt"): {
            "iter 0: start": 3, "iter 0: after fwd": 340, "iter 0: after loss": 340,
            "iter 0: after bwd": 16, "iter 0: after step": 18, "iter 0: done": 5,
            "iter 1: start": 5, "iter 1: after fwd": 342, "iter 1: after loss": 342,
            "iter 1: after bwd": 18, "iter 1: after step": 18, "iter 1: done": 5,
            "iter 2: start": 5, "iter 2: after fwd": 342, "iter 2: after loss": 342,
            "iter 2: after bwd": 18, "iter 2: after step": 18, "iter 2: done": 5,
            "iter 3: start": 5, "iter 3: after fwd": 342, "iter 3: after loss": 342,
            "iter 3: after bwd": 18, "iter 3: after step": 18, "iter 3: done": 5,
        },
        ("fsdp_amp_default", "no_ckpt"): {
            "iter 0: start": 28, "iter 0: after fwd": 630, "iter 0: after loss": 630,
            "iter 0: after bwd": 67, "iter 0: after step": 93, "iter 0: done": 54,
            "iter 1: start": 54, "iter 1: after fwd": 657, "iter 1: after loss": 657,
            "iter 1: after bwd": 93, "iter 1: after step": 93, "iter 1: done": 54,
            "iter 2: start": 54, "iter 2: after fwd": 657, "iter 2: after loss": 657,
            "iter 2: after bwd": 93, "iter 2: after step": 93, "iter 2: done": 54,
            "iter 3: start": 54, "iter 3: after fwd": 657, "iter 3: after loss": 657,
            "iter 3: after bwd": 93, "iter 3: after step": 93, "iter 3: done": 54,
        },
        ("fsdp_amp_compute_dtype32", "no_ckpt"): {
            "iter 0: start": 28, "iter 0: after fwd": 657, "iter 0: after loss": 657,
            "iter 0: after bwd": 67, "iter 0: after step": 93, "iter 0: done": 54,
            "iter 1: start": 54, "iter 1: after fwd": 684, "iter 1: after loss": 684,
            "iter 1: after bwd": 93, "iter 1: after step": 93, "iter 1: done": 54,
            "iter 2: start": 54, "iter 2: after fwd": 684, "iter 2: after loss": 684,
            "iter 2: after bwd": 93, "iter 2: after step": 93, "iter 2: done": 54,
            "iter 3: start": 54, "iter 3: after fwd": 684, "iter 3: after loss": 684,
            "iter 3: after bwd": 93, "iter 3: after step": 93, "iter 3: done": 54,
        },
        ("ddp", "ckpt"): {
            "iter 0: start": 9, "iter 0: after fwd": 57, "iter 0: after loss": 57,
            "iter 0: after bwd": 14, "iter 0: after step": 17, "iter 0: done": 13,
            "iter 1: start": 13, "iter 1: after fwd": 61, "iter 1: after loss": 61,
            "iter 1: after bwd": 17, "iter 1: after step": 17, "iter 1: done": 13,
            "iter 2: start": 13, "iter 2: after fwd": 61, "iter 2: after loss": 61,
            "iter 2: after bwd": 17, "iter 2: after step": 17, "iter 2: done": 13,
            "iter 3: start": 13, "iter 3: after fwd": 61, "iter 3: after loss": 61,
            "iter 3: after bwd": 17, "iter 3: after step": 17, "iter 3: done": 13,
        },
        ("fsdp", "ckpt"): {
            "iter 0: start": 3, "iter 0: after fwd": 51, "iter 0: after loss": 51,
            "iter 0: after bwd": 16, "iter 0: after step": 18, "iter 0: done": 5,
            "iter 1: start": 5, "iter 1: after fwd": 53, "iter 1: after loss": 53,
            "iter 1: after bwd": 18, "iter 1: after step": 18, "iter 1: done": 5,
            "iter 2: start": 5, "iter 2: after fwd": 53, "iter 2: after loss": 53,
            "iter 2: after bwd": 18, "iter 2: after step": 18, "iter 2: done": 5,
            "iter 3: start": 5, "iter 3: after fwd": 53, "iter 3: after loss": 53,
            "iter 3: after bwd": 18, "iter 3: after step": 18, "iter 3: done": 5,
        },
        ("fsdp_amp_default", "ckpt"): {
            "iter 0: start": 28, "iter 0: after fwd": 52, "iter 0: after loss": 52,
            "iter 0: after bwd": 67, "iter 0: after step": 93, "iter 0: done": 54,
            "iter 1: start": 54, "iter 1: after fwd": 79, "iter 1: after loss": 79,
            "iter 1: after bwd": 93, "iter 1: after step": 93, "iter 1: done": 54,
            "iter 2: start": 54, "iter 2: after fwd": 79, "iter 2: after loss": 79,
            "iter 2: after bwd": 93, "iter 2: after step": 93, "iter 2: done": 54,
            "iter 3: start": 54, "iter 3: after fwd": 79, "iter 3: after loss": 79,
            "iter 3: after bwd": 93, "iter 3: after step": 93, "iter 3: done": 54,
        },
        ("fsdp_amp_compute_dtype32", "ckpt"): {
            "iter 0: start": 28, "iter 0: after fwd": 52, "iter 0: after loss": 52,
            "iter 0: after bwd": 67, "iter 0: after step": 93, "iter 0: done": 54,
            "iter 1: start": 54, "iter 1: after fwd": 79, "iter 1: after loss": 79,
            "iter 1: after bwd": 93, "iter 1: after step": 93, "iter 1: done": 54,
            "iter 2: start": 54, "iter 2: after fwd": 79, "iter 2: after loss": 79,
            "iter 2: after bwd": 93, "iter 2: after step": 93, "iter 2: done": 54,
            "iter 3: start": 54, "iter 3: after fwd": 79, "iter 3: after loss": 79,
            "iter 3: after bwd": 93, "iter 3: after step": 93, "iter 3: done": 54,
        },
    }[(fsdp, ckpt)]

    # Compute the FSDP config.
    fsdp_config = {}

    # Set mixed precision.
    if "amp" in fsdp:
        fsdp_config["mixed_precision"] = True

    # When compute_dtype is FP32, make sure we use clear_autocast_cache.
    # Setting fp32_reduce_scatter and verbose for more code coverage.
    if "compute_dtype32" in fsdp:
        fsdp_config["compute_dtype"] = torch.float32
        fsdp_config["fp32_reduce_scatter"] = True
        fsdp_config["clear_autocast_cache"] = True
        fsdp_config["verbose"] = True

    # Use a bigger hidden dimension for AMP to increase the model size so that bugs in
    # handling params show up; the base case stays small to keep the test fast.
    # - hidden_dim 128: model size ~4MB
    # - hidden_dim 512: model size ~55MB
    # - hidden_dim 1024: model size ~200MB (seems to be too big for CI tests though)
    model_hidden_dim = 128
    if "amp" in fsdp:
        model_hidden_dim = 512

    # Get the fsdp and checkpoint flags.
    with_fsdp = "fsdp" in fsdp
    with_ckpt = ckpt == "ckpt"

    world_size = 2
    with temp_files_ctx(num=2) as temp_files:
        mp.spawn(
            _distributed_worker,
            (world_size, with_fsdp, with_ckpt, temp_files[0], temp_files[1], expected, model_hidden_dim, fsdp_config),
            nprocs=world_size,
        )
@pytest.fixture()
def temp_files():
    # dist_init needs 2 files, plus 3 files for before-state, after-state and in_data.
    with temp_files_ctx(5) as files:
        yield files