def _test_func(rank, world_size, tempfile_name, unused):
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    # Keep initialization deterministic.
    torch.manual_seed(0)

    model = FullyShardedDataParallel(SimpleModuleWithCheckpointing().cuda())
    optim = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    # Collect parameter sizes to ensure these stay consistent through the steps below.
    expected_param_shapes = {name: tuple(param.shape) for name, param in model.named_parameters()}

    # For clarity, this is what `expected_param_shapes` should look like depending on world size:
    assert expected_param_shapes == {
        "_fsdp_wrapped_module.flat_param_0": (12,),
        "_fsdp_wrapped_module._fpw_module.ffn.1._fsdp_wrapped_module.flat_param_0": (6,),
    }, expected_param_shapes

    torch.manual_seed(1 + rank)

    # Train for a step.
    _train_step(model, optim, expected_param_shapes)

    # Now do an eval step.
    _eval_step(model, optim, expected_param_shapes)

    # And finally do another train step.
    _train_step(model, optim, expected_param_shapes)

    teardown()
def _test_func(rank, world_size, fsdp_config, tempfile_name, unused):
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    assert isinstance(fsdp_config, dict), str(fsdp_config)

    class Model(Module):
        def __init__(self):
            super().__init__()
            self.inner = FSDP(Linear(4, 4), **fsdp_config)
            self.outer = Linear(4, 5)

        def forward(self, x):
            # Forward twice.
            i = self.inner(x)
            j = self.inner(x)
            return self.outer(i + j)

    model = FSDP(Model(), **fsdp_config).cuda()
    optim = SGD(model.parameters(), lr=0.1)

    for _ in range(3):
        in_data = torch.rand(64, 4).cuda()
        in_data.requires_grad = True
        out = model(in_data)
        out.sum().backward()
        optim.step()
        optim.zero_grad()

    model.assert_state(TrainingState.IDLE)
    teardown()
def _test_func(rank, world_size, model, fsdp_config, tempfile_name, unused, test_case):
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    if test_case["assert_ref_out"]:
        with torch.no_grad():
            weight = model.weight.T.clone().cuda()
            v = torch.Tensor(test_case["inputs"][0][rank]).cuda()
            ref_out = torch.matmul(v, weight)
    model.to("cuda")
    assert isinstance(fsdp_config, dict), str(fsdp_config)
    model = FSDP(model, **fsdp_config)
    optim = SGD(model.parameters(), lr=0.1)
    inputs = test_case["inputs"]
    assert len(inputs) == 1 or not test_case["assert_ref_out"]
    assert len(inputs[0]) >= world_size
    for in_data in inputs:
        in_data = Tensor(in_data[rank]).cuda()
        out = model(in_data)
        out.sum().backward()
        optim.step()
        optim.zero_grad()

    if test_case["assert_ref_out"]:
        torch.testing.assert_allclose(ref_out, out)

    model.assert_state(TrainingState.IDLE)
    teardown()
def _distributed_worker(
    gpu_id, world_size, with_fsdp, freezing_method, tempfile_name, unused, rank_0_output, expected_state
):
    torch.cuda.set_device(gpu_id)

    rank = gpu_id
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True

    batch = torch.randn(size=(2, 3, 224, 224)).cuda()

    model = _create_model(with_fsdp)
    model = model.cuda()

    # freezing the trunk using requires_grad.
    assert freezing_method in ["requires_grad", "grad_to_none"]
    if freezing_method == "requires_grad":
        for param in model.trunk.parameters():
            param.requires_grad = False

    if with_fsdp:
        model = FSDP(model)
    else:
        model = DistributedDataParallel(model, device_ids=[gpu_id])

    if gpu_id == 0:
        print(model)

    target = torch.LongTensor([0, 1]).cuda()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    for iteration in range(3):
        out = model(batch)
        fake_loss = criterion(out, target)
        print("Loss", iteration, ":", fake_loss.item())

        optimizer.zero_grad()
        fake_loss.backward()
        if freezing_method == "grad_to_none":
            for param in model.trunk.parameters():
                param.grad = None
        optimizer.step()

    if with_fsdp:
        fsdp_state = model.state_dict()
        # Move tensors to CPU to compare numerics.
        for k, v in fsdp_state.items():
            fsdp_state[k] = v.cpu()
        assert objects_are_equal(expected_state, fsdp_state, raise_exception=True)
    elif rank == 0:
        state_after = model.module.cpu().state_dict()
        torch.save(state_after, rank_0_output)

    teardown()
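# Each worker above takes the process rank (or GPU id) as its first argument, which matches
# the calling convention of `torch.multiprocessing.spawn`. Below is a minimal launcher sketch,
# assuming one process per visible GPU and fresh temp files for the file-based rendezvous.
# It is illustrative only: the `expected_state` placeholder and `_launch_freezing_worker` name
# are assumptions, not part of the test suite; the real test derives the expected state from a
# reference (e.g. DDP) run.
import tempfile

import torch
import torch.multiprocessing as mp


def _launch_freezing_worker():
    # Spawn one process per GPU; `spawn` supplies the first positional argument (gpu_id).
    world_size = torch.cuda.device_count()
    tempfile_name = tempfile.mkstemp()[1]
    unused = tempfile.mkstemp()[1]
    rank_0_output = tempfile.mkstemp()[1]
    expected_state = {}  # placeholder; the real test loads a reference state dict here
    mp.spawn(
        _distributed_worker,
        args=(world_size, True, "requires_grad", tempfile_name, unused, rank_0_output, expected_state),
        nprocs=world_size,
        join=True,
    )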
def _test_func(
    rank, world_size, tempfile_name, unused, flatten, mixed_precision, amp_context, half_input, fsdp_wrap_ckpt
):
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    # Keep initialization deterministic.
    torch.manual_seed(0)

    model = FSDP(
        SimpleModuleWithCheckpointing(flatten, mixed_precision, fsdp_wrap_ckpt).cuda(),
        flatten_parameters=flatten,
        mixed_precision=mixed_precision,
    )
    optim = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    # Collect parameter sizes to ensure these stay consistent through the steps below.
    expected_param_shapes = {name: tuple(param.shape) for name, param in model.named_parameters()}

    # For clarity, this is what `expected_param_shapes` should look like depending on world size:
    if not flatten:
        assert expected_param_shapes == {
            "ffn.0.weight": (5,),
            "ffn.0.bias": (2,),
            "ffn.1.weight": (5,),
            "ffn.1.bias": (2,),
            "ffn.2.weight": (5,),
            "ffn.2.bias": (2,),
        }
    else:
        assert expected_param_shapes == {
            "_fsdp_wrapped_module.flat_param_0": (12,),
            "_fsdp_wrapped_module._fpw_module.ffn.1._fsdp_wrapped_module.flat_param_0": (6,),
        }, expected_param_shapes

    torch.manual_seed(1 + rank)

    # Train for a step.
    _train_step(model, optim, expected_param_shapes, amp_context, mixed_precision, half_input)

    # Now do an eval step.
    _eval_step(model, optim, expected_param_shapes, amp_context, mixed_precision, half_input)

    # And finally do another train step.
    _train_step(model, optim, expected_param_shapes, amp_context, mixed_precision, half_input)

    teardown()
def _test_func(rank, world_size, model, fsdp_config, tempfile_name, unused, test_case):
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    my_lr = 0.1

    device = torch.device("cuda")
    if fsdp_config.get("mixed_precision", False):
        dtype = torch.float16
        fsdp_config["fp32_reduce_scatter"] = True
    else:
        dtype = torch.float32

    if test_case["assert_ref_out"]:
        with torch.no_grad():
            # Compute one iteration local output.
            fp32_weight = model.weight.T.clone().to(device)
            weight = fp32_weight.to(dtype)
            v = torch.Tensor(test_case["inputs"][0][rank]).to(device, dtype)
            ref_forward_output_my_rank = torch.matmul(v, weight)
            # Compute one iteration global weight update.
            v = torch.Tensor(test_case["inputs"][0][:world_size]).to(device, dtype)
            grad = v.float().sum(0).repeat(weight.shape[0], 1).div(world_size)
            ref_weight_out = fp32_weight - grad.T * my_lr
            assert ref_weight_out.dtype == torch.float32
    model.to(device)  # not dtype, since FSDP will manage mixed precision internally
    assert isinstance(fsdp_config, dict), str(fsdp_config)
    model = FSDP(model, **fsdp_config)
    optim = SGD(model.parameters(), lr=my_lr)
    inputs = test_case["inputs"]
    assert len(inputs) == 1 or not test_case["assert_ref_out"]
    assert len(inputs[0]) >= world_size
    for in_data in inputs:
        in_data = Tensor(in_data[rank]).to(device, dtype)
        out = model(in_data)
        out.float().sum().backward()
        optim.step()
        optim.zero_grad()

        if test_case["assert_ref_out"]:
            with model.summon_full_params():
                weight_out = model.module.weight.data.T.clone()
            # make sure we can do more fwd/bwd
            loss = model(in_data)
            loss.sum().backward()

    if test_case["assert_ref_out"]:
        torch.testing.assert_allclose(ref_forward_output_my_rank, out)
        torch.testing.assert_allclose(ref_weight_out, weight_out)

    model.assert_state(TrainingState.IDLE)
    teardown()
def test_input_type(temp_files, fsdp_config, input_cls):
    """Test FSDP with input being a list or a dict, only single GPU."""
    if torch_version() < (1, 7, 0):
        # This test runs multiple test cases in a single process. On 1.6.0 it
        # throws an error like this:
        #     RuntimeError: Container is already initialized! Cannot initialize it twice!
        pytest.skip("older pytorch doesn't work well with single process dist_init multiple times")

    result = dist_init(rank=0, world_size=1, filename=temp_files[0], filename_rpc=temp_files[1])
    assert result, "Dist init failed"

    assert isinstance(fsdp_config, dict), str(fsdp_config)

    class Model(Module):
        def __init__(self):
            super().__init__()
            self.layer = Linear(4, 4)

        def forward(self, input):
            if isinstance(input, list):
                input = input[0]
            else:
                assert isinstance(input, dict), input
                input = input["in"]
            return self.layer(input)

    model = FSDP(Model(), **fsdp_config).cuda()
    optim = SGD(model.parameters(), lr=0.1)

    for _ in range(5):
        in_data = torch.rand(64, 4).cuda()
        in_data.requires_grad = True
        if input_cls is list:
            in_data = [in_data]
        else:
            assert input_cls is dict
            in_data = {"in": in_data}

        out = model(in_data)
        out.sum().backward()
        optim.step()
        optim.zero_grad()

    model.assert_state(TrainingState.IDLE)
    teardown()
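# For context, a test with this signature is normally driven by pytest fixtures and
# parametrization. The sketch below is illustrative only: the fixture wiring, the config
# values, and the `test_input_type_example` name are assumptions, not what the real suite
# uses (it supplies `temp_files` and `fsdp_config` through its own fixtures).
import pytest


@pytest.mark.parametrize("input_cls", [list, dict])
@pytest.mark.parametrize("fsdp_config", [{"flatten_parameters": True}, {"flatten_parameters": False}])
def test_input_type_example(tmp_path, fsdp_config, input_cls):
    # Build two temp file paths for the file-based init and RPC rendezvous.
    temp_files = [str(tmp_path / "init"), str(tmp_path / "rpc")]
    test_input_type(temp_files, fsdp_config, input_cls)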
def _dist_worker(rank, world_size, files, wrap_middle, test_fn):

    # Get data from files.
    file1, file2, sd_before, sd_after, in_data = files
    sd_before = torch.load(sd_before, map_location=lambda storage, loc: storage.cuda(rank))
    if test_fn == "train":
        sd_after = torch.load(sd_after, map_location=lambda storage, loc: storage.cuda(rank))
    in_data = torch.load(in_data, map_location=lambda storage, loc: storage.cuda(rank))

    result = dist_init(rank=rank, world_size=world_size, filename=file1, filename_rpc=file2)
    assert result, "Dist init failed"

    fsdp_model = FSDP(
        # To debug: first make with_fsdp=False (no inner wrapping) work, then enable inner wrapping
        # and make that work.
        Model(with_fsdp=True, wrap_middle=wrap_middle),
        flatten_parameters=test_fn == "optim_state",
        mixed_precision=False,
        compute_dtype=torch.float16,
    )
    fsdp_model.load_state_dict(sd_before)

    if test_fn == "train":
        _train(fsdp_model, in_data)
        objects_are_equal(sd_after, fsdp_model.state_dict(), raise_exception=True)
    elif test_fn == "eval":
        _eval(fsdp_model, in_data)
    elif test_fn == "optim_state":
        optim = SGD(fsdp_model.parameters(), lr=0.1)
        for _ in range(3):
            out = fsdp_model(in_data)
            out.backward()
            optim.step()
        sd = fsdp_model.gather_full_optim_state_dict(optim)
        if rank == 0:
            # There should be 8 momentum buffers in the state.
            assert len(sd["state"].keys()) == 8
        else:
            assert sd is None, "only rank 0 should have the optim state"
    else:
        assert 0, f"invalid test_fn {test_fn}"

    teardown()
def _distributed_worker(gpu_id, world_size, with_fsdp, with_checkpoint, filename, filename_rpc, expected):
    torch.cuda.set_device(gpu_id)

    rank = gpu_id
    result = dist_init(rank, world_size, filename, filename_rpc)
    assert result, "Dist init failed"

    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True
    batch = torch.randn(size=(2, 3, 224, 224)).cuda()

    model = create_model(with_fsdp, with_checkpoint)
    model = model.cuda()

    if with_fsdp:
        model = to_fsdp(model)
    else:
        model = DistributedDataParallel(model, device_ids=[gpu_id], bucket_cap_mb=500)

    criterion = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=1e-4)

    results = {}
    for iteration in range(3):
        get_cur_mem(gpu_id, results, f"iter {iteration}: start")

        out = model(batch)
        get_cur_mem(gpu_id, results, f"iter {iteration}: after fwd")

        out = sum(o.sum() for o in out[0])
        fake_loss = criterion(out, torch.tensor(0.0).cuda())
        get_cur_mem(gpu_id, results, f"iter {iteration}: after loss")

        fake_loss.backward()
        get_cur_mem(gpu_id, results, f"iter {iteration}: after bwd")

        optimizer.step()
        get_cur_mem(gpu_id, results, f"iter {iteration}: after step")

        # It is important to use `set_to_none` below, not optimizer.zero_grad() to reclaim memory.
        if torch_version() >= (1, 7, 0):
            model.zero_grad(set_to_none=True)
        else:
            for p in model.parameters():
                p.grad = None
        get_cur_mem(gpu_id, results, f"iter {iteration}: done")

    assert results == expected, f"{results} but expected {expected}"

    teardown()
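# `get_cur_mem` is a helper defined elsewhere in the test file. A minimal sketch of a
# compatible implementation is shown below, assuming it records the currently allocated
# CUDA memory rounded to whole MB (the later memory test tolerates 1MB differences, which
# suggests MB granularity). Illustrative only; the real helper may differ, e.g. it may also
# reset peak statistics.
def get_cur_mem(gpu_id, results, tag):
    """Record the currently allocated CUDA memory (in MB) under `tag`."""
    torch.cuda.synchronize(gpu_id)
    results[tag] = round(torch.cuda.memory_allocated(gpu_id) / 1024 / 1024)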
def _test_func(rank, world_size, fsdp_config, tempfile_name, unused):
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    assert isinstance(fsdp_config, dict), str(fsdp_config)

    class Model(Module):
        def __init__(self):
            super().__init__()
            # TODO (Min): for now, we just test pytorch sync_bn here.
            #             this will grow into regnet; testing apex sync_bn, etc.
            self.conv = Conv2d(2, 2, (1, 1))
            self.bn = BatchNorm2d(2)

        def forward(self, x):
            x = self.conv(x)
            x = self.bn(x)
            return x

    # TODO (Min): check DDP equivalency.

    model = Model()
    # Note: different ranks may wrap in a different order due to different random
    # seeds, but the results should be the same.
    if random.randint(0, 1) == 0:
        print("auto_wrap_bn, then convert_sync_batchnorm")
        model = auto_wrap_bn(model)
        model = SyncBatchNorm.convert_sync_batchnorm(model)
    else:
        print("convert_sync_batchnorm, then auto_wrap_bn")
        model = SyncBatchNorm.convert_sync_batchnorm(model)
        model = auto_wrap_bn(model)
    model = FSDP(model, **fsdp_config).cuda()
    optim = SGD(model.parameters(), lr=0.1)

    for _ in range(3):
        in_data = torch.rand(2, 2, 2, 2).cuda()
        in_data.requires_grad = True
        out = model(in_data)
        out.sum().backward()
        optim.step()
        optim.zero_grad()

    model.assert_state(TrainingState.IDLE)
    teardown()
def _test_func(rank, world_size, model, fsdp_config, tempfile_name, unused, test_case):
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    my_lr = 0.1

    if test_case["assert_ref_out"]:
        with torch.no_grad():
            # Compute one iteration local output.
            weight = model.weight.T.clone().cuda()
            v = torch.Tensor(test_case["inputs"][0][rank]).cuda()
            ref_forward_output_my_rank = torch.matmul(v, weight)
            # Compute one iteration global weight update.
            v = torch.Tensor(test_case["inputs"][0][:world_size]).cuda()
            grad = v.sum(0).repeat(weight.shape[0], 1).div(world_size)
            ref_weight_out = weight - grad.T * my_lr

    model.to("cuda")
    assert isinstance(fsdp_config, dict), str(fsdp_config)
    model = FSDP(model, **fsdp_config)
    optim = SGD(model.parameters(), lr=my_lr)
    inputs = test_case["inputs"]
    assert len(inputs) == 1 or not test_case["assert_ref_out"]
    assert len(inputs[0]) >= world_size
    for in_data in inputs:
        in_data = Tensor(in_data[rank]).cuda()
        out = model(in_data)
        out.sum().backward()
        optim.step()
        optim.zero_grad()

        if test_case["assert_ref_out"]:
            with model.summon_full_params():
                weight_out = model.module.weight.data.T.clone()
            # make sure we can do more fwd/bwd
            loss = model(in_data)
            loss.sum().backward()

    if test_case["assert_ref_out"]:
        torch.testing.assert_allclose(ref_forward_output_my_rank, out)
        torch.testing.assert_allclose(ref_weight_out, weight_out)

    model.assert_state(TrainingState.IDLE)
    teardown()
def _freeze_distributed_worker(
    gpu_id,
    world_size,
    tempfile_name,
    unused,
):
    torch.cuda.set_device(gpu_id)

    rank = gpu_id
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True
    batch = torch.randn(size=(2, 3, 224, 224)).cuda()

    # The use case for this test is where the weights in the submodule
    # are not frozen but the leftover weights or those contained by the
    # root module are frozen. Refer to issue #758 for a real world example.
    model = FreezeModel()
    model = model.cuda()

    for param in model.head.parameters():
        param.requires_grad = False

    model = FSDP(model)

    if gpu_id == 0:
        print(model)

    target = torch.tensor([0, 1], dtype=torch.long).cuda()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    for iteration in range(3):
        out = model(batch)
        fake_loss = criterion(out, target)
        print("Loss", iteration, ":", fake_loss.item())

        optimizer.zero_grad()
        fake_loss.backward()
        optimizer.step()

    teardown()
def test_local_state_dict_calls_state_dict_recursion():
    """Testing the case of infinite recursion when FSDP is subclassed"""

    class TestModule(FSDP):
        def __init__(self):
            super().__init__(module=nn.Linear(100, 100))

        def state_dict(self, *args, **kwargs):
            return self.local_state_dict(*args, **kwargs)

    rank = 0
    world_size = 1
    with temp_files_ctx(2) as temp_files:
        result = dist_init(rank, world_size, temp_files[0], temp_files[1])
        assert result, "Dist init failed"

        m = TestModule()
        d = m.state_dict()

        teardown()
def _dist_worker(rank, world_size, files, outer_flat, inner_flat, sharing):

    # Get data from files.
    file1, file2, sd_before, sd_after, in_data = files
    sd_before = torch.load(sd_before, map_location=lambda storage, loc: storage.cuda(rank))
    sd_after = torch.load(sd_after, map_location=lambda storage, loc: storage.cuda(rank))
    in_data = torch.load(in_data, map_location=lambda storage, loc: storage.cuda(rank))

    result = dist_init(rank=rank, world_size=world_size, filename=file1, filename_rpc=file2)
    assert result, "Dist init failed"

    fsdp_model = FSDP(Model(with_fsdp=True, inner_flat=inner_flat, sharing=sharing), flatten_parameters=outer_flat)
    fsdp_model.load_state_dict(sd_before)

    _train(fsdp_model, in_data)

    objects_are_equal(sd_after, fsdp_model.state_dict(), raise_exception=True)

    teardown()
def test_pre_backward_hook(temp_files):
    """Test FSDP with a model that triggers a pre_backward hook bug."""
    result = dist_init(rank=0, world_size=1, filename=temp_files[0], filename_rpc=temp_files[1])
    assert result, "Dist init failed"

    class Model(Module):
        def __init__(self):
            super().__init__()
            self.l1 = Linear(4, 4).cuda()
            self.l2 = FSDP(Linear(4, 4).cuda())
            self.l3 = Linear(4, 4).cuda()

        def forward(self, x):
            x = self.l1(x)
            x = self.l2(x)
            inner_result = x
            x = self.l3(x)
            return x, inner_result

        def assert_and_clear_grad(self):
            for p in self.parameters():
                assert p.shape in [(4, 4), (4,), (4 * 4 + 4,)], p.shape
                assert p.grad is not None
                p.grad = None

    model = FSDP(Model(), flatten_parameters=False).cuda()
    in_data = torch.rand(1, 4).cuda()
    for _ in range(3):
        out, _ = model(in_data)
        out.sum().backward()
        model.assert_and_clear_grad()

    teardown()
def _test_func(rank, world_size, fsdp_config, tempfile_name, unused):
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    assert isinstance(fsdp_config, dict), str(fsdp_config)

    class InnerModel(Module):
        def __init__(self):
            super().__init__()
            self.layers = Sequential(FSDP(Linear(5, 5), **fsdp_config))

        def forward(self, x):
            return self.layers(x)

    inner_model = InnerModel()
    model = FSDP(inner_model, **fsdp_config).cuda()
    optim = SGD(model.parameters(), lr=0.1)

    for i in range(3):
        input = torch.rand((1, 5), dtype=torch.float).cuda()
        input.requires_grad = True
        output = model(input)
        output.sum().backward()
        optim.step()
        optim.zero_grad()

    input = torch.rand((1, 5), dtype=torch.float).cuda()
    output = model(input)

    model.assert_state(TrainingState.IDLE)

    # second time to rewrap the inner model
    rewrapped_model = FSDP(inner_model, **fsdp_config).cuda()
    rewrapped_output = rewrapped_model(input)

    assert torch.allclose(output, rewrapped_output)
    teardown()
def _distributed_worker(
    gpu_id,
    world_size,
    with_model2,
    with_sync_bn,
    with_fsdp,
    with_checkpoint,
    files,
    mixed_precision,
    flatten,
    wrap_bn,
    fp32_reduce_scatter,
    bucket_cap_mb,
):
    filename, filename_rpc = files[:2]
    filename_loss = files[2:]

    torch.cuda.set_device(gpu_id)

    rank = gpu_id
    result = dist_init(rank, world_size, filename, filename_rpc)
    assert result, "Dist init failed"

    # use False below to debug since error msg is not as good with cudnn.
    torch.backends.cudnn.enabled = True

    # these make things deterministic.
    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Ensure we have multiple forward passes.
    batch = [
        torch.randn(size=(2, 3, 16, 16)).cuda(),
        torch.randn(size=(2, 3, 9, 9)).cuda(),
        torch.randn(size=(2, 3, 9, 9)).cuda(),
    ]

    if mixed_precision and not with_fsdp:
        batch = [x.half() for x in batch]

    model = _create_model(
        with_model2,
        with_sync_bn,
        with_fsdp,
        with_checkpoint,
        mixed_precision,
        flatten,
        wrap_bn,
        fp32_reduce_scatter,
        bucket_cap_mb,
    )
    model = model.cuda()

    if with_fsdp:
        model = FSDP(
            model,
            flatten_parameters=flatten,
            mixed_precision=mixed_precision,
            compute_dtype=torch.float32,
            fp32_reduce_scatter=fp32_reduce_scatter,
            bucket_cap_mb=bucket_cap_mb,
        )
        model.set_gradient_divide_factors(1.0, 2.0, True)
        no_sync_context = contextlib.suppress()
    else:
        # With DDP, we need no_sync and manual gradient reduction below because
        # it can't handle multiple forward pass + checkpointing otherwise.
        model = DistributedDataParallel(model, device_ids=[gpu_id])
        no_sync_context = model.no_sync()

    mp_context = contextlib.suppress()
    if mixed_precision:
        mp_context = torch.cuda.amp.autocast(enabled=True)

    if gpu_id == 0:
        print(model)

    target = torch.tensor([0, 1, 2, 3, 4, 5], dtype=torch.long).cuda()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    losses = {}
    i = 0
    with no_sync_context:
        for iteration in range(3):
            with mp_context:
                out = model(batch)
                loss = criterion(out, target)
                print("Loss", iteration, ":", loss.item())
                losses[f"iter_{i}"] = loss
                i += 1

                optimizer.zero_grad()
                loss.backward()

            # Manual grad reduction, no autocast.
            if not with_fsdp:
                for p in model.parameters():
                    dist.all_reduce(p.grad.data)
                    p.grad.data.div_(2.0)

            # Stepping, no autocast
            optimizer.step()

    # Due to dist.all_reduce code block above with ddp.no_sync, we seem to hit a bug
    # in DDP where tensor.cpu() and torch.save() calls both hang. FSDP is not affected.
    # Therefore, we have to compare losses here instead of states.
    with open(filename_loss[rank], "wb") as f:
        pickle.dump(losses, f)

    teardown()
def _distributed_worker(
    gpu_id, world_size, with_fsdp, with_checkpoint, filename, filename_rpc, expected, model_hidden_dim, fsdp_config
):
    torch.cuda.set_device(gpu_id)

    rank = gpu_id
    result = dist_init(rank, world_size, filename, filename_rpc)
    assert result, "Dist init failed"

    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True

    # Note that FSDP auto-casts the input in AMP mode. So we don't need to call half() here.
    batch = torch.randn(size=(2, 3, 224, 224)).cuda()

    model = create_model(with_fsdp, with_checkpoint, model_hidden_dim, fsdp_config)
    model = model.cuda()

    if with_fsdp:
        model = to_fsdp(model, fsdp_config)
    else:
        model = DistributedDataParallel(model, device_ids=[gpu_id], bucket_cap_mb=500)

    # We enable momentum so that after the first iteration, the optimizer state is added
    # to the total memory used.
    criterion = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)

    # Set AMP context if needed.
    context = contextlib.suppress()
    if "mixed_precision" in fsdp_config and fsdp_config["mixed_precision"]:
        context = torch.cuda.amp.autocast(enabled=True)

    # We have observed that sometimes after 3rd iteration, 4th one can fail (not on this
    # test but on much bigger scale tests). We run 4 iterations here just in case it happens.
    iterations = 4

    results = {}  # results of memory stats
    for iteration in range(iterations):
        get_cur_mem(gpu_id, results, f"iter {iteration}: start")

        with context:
            out = model(batch)
            get_cur_mem(gpu_id, results, f"iter {iteration}: after fwd")

            out = sum(o.sum() for o in out[0])
            fake_loss = criterion(out, torch.tensor(0.0).cuda())
            get_cur_mem(gpu_id, results, f"iter {iteration}: after loss")

        fake_loss.backward()
        get_cur_mem(gpu_id, results, f"iter {iteration}: after bwd")

        optimizer.step()
        get_cur_mem(gpu_id, results, f"iter {iteration}: after step")

        # It is important to use `set_to_none` below, not optimizer.zero_grad() to reclaim memory.
        if torch_version() >= (1, 7, 0):
            model.zero_grad(set_to_none=True)
        else:
            for p in model.parameters():
                p.grad = None
        get_cur_mem(gpu_id, results, f"iter {iteration}: done")

    dump_all_tensors(gpu_id)
    print(results)

    def cmp(results, expected):
        ret = ""
        assert results.keys() == expected.keys(), f"{list(results.keys())} vs. {list(expected.keys())}"
        for k, v in results.items():
            exp = expected[k]
            if abs(exp - v) > 1:  # allow 1MB rounding differences
                ret += f"{k}: got {v}, expected {exp}\n"
        return ret

    output = cmp(results, expected)
    assert not output, output

    teardown()
def _distributed_worker(
    gpu_id,
    world_size,
    fsdp_config,
    tempfile,
    tempfile_rpc,
):
    torch.cuda.set_device(gpu_id)

    rank = gpu_id
    result = dist_init(rank, world_size, tempfile, tempfile_rpc)
    assert result, "Dist init failed"

    # Save the original torch.distributed.all_gather function since we will
    # patch it to include an artificial delay.
    orig_all_gather = torch.distributed.all_gather

    def run(compute_cycles, all_gather_cycles):
        has_params = all_gather_cycles > 0
        model = _create_model(fsdp_config, compute_cycles, has_params)

        # Get the input and set the input's requires_grad to True because
        # we have a fake compute in the forward pass.
        batch = torch.rand(1).cuda()
        batch.requires_grad = True

        # We run 20 iterations but only collect timing data from the minimal 10
        # data points because nondeterministic system events can disturb the timing.
        cpu_iter = Min10()
        cpu_wait = Min10()
        gpu_compute = Min10()
        gpu_total = Min10()
        for _ in range(20):
            # Get two events for measuring the overall time.
            e1 = Event(enable_timing=True)
            e2 = Event(enable_timing=True)

            cpu_start = time.process_time()

            all_gather_called = False

            def _delayed_all_gather(*args, **kwargs):
                nonlocal all_gather_called
                all_gather_called = True
                torch.cuda._sleep(all_gather_cycles)
                return orig_all_gather(*args, **kwargs)

            # forward pass
            #
            # Even though both e1 & e2 are on the compute stream, since
            # compute depends on all_gather, e2-e1 includes all_gather time.
            e1.record()
            with patch("torch.distributed.all_gather", _delayed_all_gather):
                out = model(batch)
                if has_params and world_size > 1:
                    assert all_gather_called
                else:
                    assert not all_gather_called
            e2.record()

            # backward pass
            out.backward()
            if torch_version() >= (1, 7, 0):
                model.zero_grad(set_to_none=True)
            else:
                for p in model.parameters():
                    p.grad = None

            cpu_iter_time = time.process_time() - cpu_start

            # wait for gpu
            out.item()
            cpu_wait_for_gpu_time = time.process_time() - cpu_start - cpu_iter_time

            # get sum of the compute time
            times = []
            for mod in model.modules():
                if not isinstance(mod, Layer):
                    continue
                times.append(mod.get_time())

            # get gpu compute + all_gather time
            overall_gpu_time = e1.elapsed_time(e2)

            cpu_iter.add(cpu_iter_time)
            cpu_wait.add(cpu_wait_for_gpu_time)
            gpu_compute.add(sum(times))
            gpu_total.add(overall_gpu_time)

        del model
        return {
            "cpu_iter": cpu_iter.avg(),
            "cpu_wait": cpu_wait.avg(),
            "gpu_compute": gpu_compute.avg(),
            "gpu_total": gpu_total.avg(),
        }

    sleep_cycles = int(100 * get_cycles_per_ms())

    e1 = run(0, 0)  # no compute, no all-gather
    e2 = run(0, sleep_cycles)  # no compute, only all-gather
    e3 = run(sleep_cycles, 0)  # only compute, no all-gather
    e4 = run(sleep_cycles, sleep_cycles)  # both compute and all-gather
    debug_string = f"\nrank{rank}:\n  e1: {e1}\n  e2: {e2}\n  e3: {e3}\n  e4: {e4}"
    print(debug_string)

    # Check the cpu/gpu timing. CPU should run ahead of GPU. Therefore, cpu-gpu
    # wait should be long, except when there is no real work on GPU.
    #
    # If the assertions fail below, we likely have a cpu-gpu wait in the forward/backward pass.
    short = [e1["cpu_iter"], e2["cpu_iter"], e3["cpu_iter"], e4["cpu_iter"], e1["cpu_wait"]]
    long = [e3["cpu_wait"], e4["cpu_wait"]]
    if world_size == 1:
        short.append(e2["cpu_wait"])  # all gather should not be happening.
    else:
        long.append(e2["cpu_wait"])  # all gather should happen and prolong the cpu-gpu wait.
    for s in short:
        for l in long:
            # 10X longer is a safe margin, since the GPU work timing is around 100X more
            # of that of the CPU.
            assert s * 10 < l, f"{s} * 10 < {l} in " + debug_string

    # Check the GPU timing.
    short = [e1["gpu_compute"], e1["gpu_total"], e2["gpu_compute"]]
    long = [e3["gpu_compute"], e3["gpu_total"], e4["gpu_compute"], e4["gpu_total"]]
    if world_size == 1:
        short.append(e2["gpu_total"])  # all gather should not be happening.
    else:
        long.append(e2["gpu_total"])  # all gather should happen and prolong the cpu-gpu wait.
    for s in short:
        for l in long:
            # 10X longer is a safe margin, since the time is around 100X longer
            # when there is work on GPU vs. no work.
            assert s * 10 < l, f"{s} * 10 < {l} in " + debug_string

    # Check the GPU overlapping when there is all-gather.
    if world_size > 1:
        compute_only = e3["gpu_compute"]
        all_gather_only = e2["gpu_total"]
        both = e4["gpu_total"]
        assert compute_only + all_gather_only > 1.1 * both, (
            f"{compute_only} + {all_gather_only} > 1.1 * {both} in " + debug_string
        )

    teardown()
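# `Min10` is another helper not shown in this section. A minimal sketch of a class with the
# interface used above (`add`/`avg`) follows, assuming it averages the 10 smallest samples to
# filter out timing noise, as the comment in `run` suggests. Illustrative only; the real
# helper may be implemented differently.
class Min10:
    """Collect samples and average the 10 smallest ones."""

    def __init__(self):
        self.data = []

    def add(self, new_data):
        self.data.append(new_data)

    def avg(self):
        smallest = sorted(self.data)[:10]
        return sum(smallest) / len(smallest)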
def _distributed_worker(
    rank,
    world_size,
    fsdp_config,
    fsdp_wrap_bn,
    ddp_mixed_precision,
    tempfile_name,
    unused,
    state_before,
    inputs,
    rank_0_output,
    state_after,
    sync_bn,
    conv_bias,
    linear_bias,
):
    torch.backends.cudnn.deterministic = True

    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    ddp = True
    if fsdp_config:
        ddp = False
        assert isinstance(fsdp_config, dict), str(fsdp_config)
        if fsdp_config["mixed_precision"]:
            # To match DDP in AMP -O1, we need fp32 reduce scatter.
            fsdp_config["fp32_reduce_scatter"] = True

    model = Model(conv_bias, linear_bias)
    model.load_state_dict(state_before)
    model = model.cuda()

    class DummyScaler:
        def scale(self, loss):
            return loss

        def step(self, optim):
            optim.step()

        def update(self):
            pass

    scaler = DummyScaler()
    if ddp:
        if sync_bn == "pytorch":
            model = pytorch_bn_converter(model)
        model = DDP(model, device_ids=[rank], broadcast_buffers=True)
        if ddp_mixed_precision:
            scaler = GradScaler()
    else:
        # Note: different ranks may wrap in a different order due to different random
        # seeds, but the results should be the same.
        if random.randint(0, 1) == 0:
            print(f"auto_wrap_bn {fsdp_wrap_bn}, then sync_bn {sync_bn}")
            if fsdp_wrap_bn:
                model = auto_wrap_bn(model, _single_rank_pg)
            if sync_bn == "pytorch":
                model = pytorch_bn_converter(model)
        else:
            print(f"sync_bn {sync_bn}, then auto_wrap_bn {fsdp_wrap_bn}")
            if sync_bn == "pytorch":
                model = pytorch_bn_converter(model)
            if fsdp_wrap_bn:
                model = auto_wrap_bn(model, _single_rank_pg)
        model = FSDP(model, **fsdp_config).cuda()
        if fsdp_config["mixed_precision"]:
            scaler = ShardedGradScaler()

    # Print the model for verification.
    if rank == 0:
        print(model)

    optim = SGD(model.parameters(), lr=0.1)
    loss_func = CrossEntropyLoss()

    for in_data in inputs[rank]:
        in_data = in_data.cuda()
        context = contextlib.suppress()
        if ddp and ddp_mixed_precision:
            in_data = in_data.half()
            context = torch.cuda.amp.autocast(enabled=True)
        if not ddp and fsdp_config["mixed_precision"]:
            context = torch.cuda.amp.autocast(enabled=True)
        with context:
            out = model(in_data)
            fake_label = torch.zeros(1, dtype=torch.long).cuda()
            loss = loss_func(out.unsqueeze(0), fake_label)
        scaler.scale(loss).backward()
        scaler.step(optim)
        scaler.update()
        optim.zero_grad()

    if ddp:
        # Save the rank 0 state_dict to the output file.
        if rank == 0:
            state_after = model.module.cpu().state_dict()
            torch.save(state_after, rank_0_output)
    else:
        model.assert_state(TrainingState.IDLE)
        # Ensure the final state equals state_after.
        fsdp_state = model.state_dict()
        # Move tensors to CPU to compare numerics.
        for k, v in fsdp_state.items():
            fsdp_state[k] = v.cpu()
        # Change False to True to enable this when you want to debug the mismatch.
        if False and rank == 0:

            def dump(d):
                for k, v in d.items():
                    print(k, v)

            dump(state_after)
            dump(fsdp_state)
        # If sync_bn is used, all ranks should have the same state, so we can compare with
        # rank 0 state on every rank. Otherwise, only compare rank 0 with rank 0.
        if sync_bn != "none" or rank == 0:
            assert objects_are_equal(state_after, fsdp_state, raise_exception=True)

    teardown()