def run_test_column_parallel_linear(rank, model_parallel_size, filename, filename_rpc):
    dist_init(rank, model_parallel_size, filename, filename_rpc)
    mpu.initialize_model_parallel(model_parallel_size)
    if torch.distributed.get_rank() == 0:
        print("> testing ColumnParallelLinear with model parallel size: {}".format(model_parallel_size))
    model_parallel_size = mpu.get_model_parallel_world_size()

    seed = 12345
    set_random_seed(seed)
    input_size_coeff = 13
    input_size = input_size_coeff * model_parallel_size
    output_size_coeff = 17
    output_size = output_size_coeff * model_parallel_size
    batch_size = 7

    # Network
    identity_layer = IdentityLayer2D(batch_size, input_size).cuda()
    linear_layer = layers.ColumnParallelLinear(input_size, output_size, keep_master_weight_for_test=True).cuda()
    loss_weight = torch.randn([batch_size, output_size]).cuda()
    # Forward
    input_ = identity_layer()
    output = linear_layer(input_)
    loss = torch.mul(output, loss_weight).sum()
    # Backward
    loss.backward()

    # Values.
    dLdY = loss_weight
    X = identity_layer.weight
    A = linear_layer.master_weight.cuda()
    dLdA = torch.matmul(dLdY.t(), X)
    dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1)
    dLdX = torch.matmul(dLdY, A)

    rank = mpu.get_model_parallel_rank()
    my_dLdA = torch.split(dLdA, output_size_coeff, dim=0)[rank].contiguous().clone()
    error = my_dLdA.sub(linear_layer.weight.grad).abs().max()
    torch.distributed.barrier()
    print(" error in dLdA on global rank {}: {}".format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    my_dLdb = torch.split(dLdb, output_size_coeff, dim=0)[rank].contiguous().clone()
    error = my_dLdb.sub(linear_layer.bias.grad).abs().max()
    torch.distributed.barrier()
    print(" error in dLdb on global rank {}: {}".format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    error = dLdX.sub(identity_layer.weight.grad).abs().max()
    torch.distributed.barrier()
    print(" error in dLdX on global rank {}: {}".format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(" >> passed the test :-)")

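# NOTE: several tests in this file rely on small helpers (`IdentityLayer2D`, `set_random_seed`)
# that are defined elsewhere in the shared test commons, not in this section. As a rough,
# non-authoritative sketch of what the test above assumes about `IdentityLayer2D`: it holds a
# single (m, n) weight parameter and returns it from forward(), so that `identity_layer()`
# produces a tensor whose gradient can be checked against the analytic dLdX.
class _IdentityLayer2DSketch(torch.nn.Module):
    """Illustrative stand-in only; the real helper lives in the shared test commons."""

    def __init__(self, m, n):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(m, n))

    def forward(self):
        return self.weight
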
def run_worker(rank, world_size, args):
    if args.world_size != 0:
        world_size = args.world_size
    dist_init(rank + args.rank_base, world_size, hostname=args.host)
    initialize_model_parallel(1, world_size)
    init_random_seed(0)
    run_mp_worker(args, world_size)

    rpc.shutdown()
    torch.distributed.destroy_process_group()

def _test_func(rank, world_size, fsdp_config, tempfile_name, unused):
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    assert isinstance(fsdp_config, dict), str(fsdp_config)

    class Model(Module):
        def __init__(self):
            super().__init__()
            self.inner = FSDP(Linear(4, 4), **fsdp_config)
            self.outer = Linear(4, 5)

        def forward(self, x):
            # Forward twice.
            i = self.inner(x)
            j = self.inner(x)
            return self.outer(i + j)

    model = FSDP(Model(), **fsdp_config).cuda()
    optim = SGD(model.parameters(), lr=0.1)

    for _ in range(3):
        in_data = torch.rand(64, 4).cuda()
        in_data.requires_grad = True
        out = model(in_data)
        out.sum().backward()
        optim.step()
        optim.zero_grad()

    model.assert_state(TrainingState.IDLE)
    teardown()

def _test_func(rank, world_size, tempfile_name, unused):
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    # Keep initialization deterministic.
    torch.manual_seed(0)

    model = FullyShardedDataParallel(SimpleModuleWithCheckpointing().cuda())
    optim = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    # Collect parameter sizes to ensure these stay consistent through the steps below.
    expected_param_shapes = {name: tuple(param.shape) for name, param in model.named_parameters()}

    # For clarity, this is what `expected_param_shapes` should look like depending on world size:
    assert expected_param_shapes == {
        "_fsdp_wrapped_module.flat_param_0": (12,),
        "_fsdp_wrapped_module._fpw_module.ffn.1._fsdp_wrapped_module.flat_param_0": (6,),
    }, expected_param_shapes

    torch.manual_seed(1 + rank)

    # Train for a step.
    _train_step(model, optim, expected_param_shapes)

    # Now do an eval step.
    _eval_step(model, optim, expected_param_shapes)

    # And finally do another train step.
    _train_step(model, optim, expected_param_shapes)

    teardown()

def _test_func(rank, world_size, model, fsdp_config, tempfile_name, unused, test_case):
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    if test_case["assert_ref_out"]:
        with torch.no_grad():
            weight = model.weight.T.clone().cuda()
            v = torch.Tensor(test_case["inputs"][0][rank]).cuda()
            ref_out = torch.matmul(v, weight)
    model.to("cuda")
    assert isinstance(fsdp_config, dict), str(fsdp_config)
    model = FSDP(model, **fsdp_config)
    optim = SGD(model.parameters(), lr=0.1)
    inputs = test_case["inputs"]
    assert len(inputs) == 1 or not test_case["assert_ref_out"]
    assert len(inputs[0]) >= world_size
    for in_data in inputs:
        in_data = Tensor(in_data[rank]).cuda()
        out = model(in_data)
        out.sum().backward()
        optim.step()
        optim.zero_grad()

    if test_case["assert_ref_out"]:
        torch.testing.assert_allclose(ref_out, out)

    model.assert_state(TrainingState.IDLE)
    teardown()

def _distributed_worker(gpu_id, world_size, with_fsdp, freezing_method, tempfile_name, unused, rank_0_output, expected_state):
    torch.cuda.set_device(gpu_id)

    rank = gpu_id
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True
    batch = torch.randn(size=(2, 3, 224, 224)).cuda()

    model = _create_model(with_fsdp)
    model = model.cuda()

    # freezing the trunk using requires_grad.
    assert freezing_method in ["requires_grad", "grad_to_none"]
    if freezing_method == "requires_grad":
        for param in model.trunk.parameters():
            param.requires_grad = False

    if with_fsdp:
        model = FSDP(model)
    else:
        model = DistributedDataParallel(model, device_ids=[gpu_id])

    if gpu_id == 0:
        print(model)

    target = torch.LongTensor([0, 1]).cuda()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    for iteration in range(3):
        out = model(batch)
        fake_loss = criterion(out, target)
        print("Loss", iteration, ":", fake_loss.item())
        optimizer.zero_grad()
        fake_loss.backward()
        if freezing_method == "grad_to_none":
            for param in model.trunk.parameters():
                param.grad = None
        optimizer.step()

    if with_fsdp:
        fsdp_state = model.state_dict()
        # Move tensors to CPU to compare numerics.
        for k, v in fsdp_state.items():
            fsdp_state[k] = v.cpu()
        assert objects_are_equal(expected_state, fsdp_state, raise_exception=True)
    elif rank == 0:
        state_after = model.module.cpu().state_dict()
        torch.save(state_after, rank_0_output)

    teardown()

def _test_func(rank, world_size, model, fsdp_config, tempfile_name, unused, test_case):
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    my_lr = 0.1

    device = torch.device("cuda")
    if fsdp_config.get("mixed_precision", False):
        dtype = torch.float16
        fsdp_config["fp32_reduce_scatter"] = True
    else:
        dtype = torch.float32

    if test_case["assert_ref_out"]:
        with torch.no_grad():
            # Compute one iteration local output.
            fp32_weight = model.weight.T.clone().to(device)
            weight = fp32_weight.to(dtype)
            v = torch.Tensor(test_case["inputs"][0][rank]).to(device, dtype)
            ref_forward_output_my_rank = torch.matmul(v, weight)
            # Compute one iteration global weight update.
            v = torch.Tensor(test_case["inputs"][0][:world_size]).to(device, dtype)
            grad = v.float().sum(0).repeat(weight.shape[0], 1).div(world_size)
            ref_weight_out = fp32_weight - grad.T * my_lr
            assert ref_weight_out.dtype == torch.float32
    model.to(device)  # not dtype, since FSDP will manage mixed precision internally
    assert isinstance(fsdp_config, dict), str(fsdp_config)
    model = FSDP(model, **fsdp_config)
    optim = SGD(model.parameters(), lr=my_lr)
    inputs = test_case["inputs"]
    assert len(inputs) == 1 or not test_case["assert_ref_out"]
    assert len(inputs[0]) >= world_size
    for in_data in inputs:
        in_data = Tensor(in_data[rank]).to(device, dtype)
        out = model(in_data)
        out.float().sum().backward()
        optim.step()
        optim.zero_grad()
        if test_case["assert_ref_out"]:
            with model.summon_full_params():
                weight_out = model.module.weight.data.T.clone()
            # make sure we can do more fwd/bwd
            loss = model(in_data)
            loss.sum().backward()

    if test_case["assert_ref_out"]:
        torch.testing.assert_allclose(ref_forward_output_my_rank, out)
        torch.testing.assert_allclose(ref_weight_out, weight_out)

    model.assert_state(TrainingState.IDLE)
    teardown()

def _test_func(rank, world_size, tempfile_name, unused, flatten, mixed_precision, amp_context, half_input, fsdp_wrap_ckpt):
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    # Keep initialization deterministic.
    torch.manual_seed(0)

    model = FSDP(
        SimpleModuleWithCheckpointing(flatten, mixed_precision, fsdp_wrap_ckpt).cuda(),
        flatten_parameters=flatten,
        mixed_precision=mixed_precision,
    )
    optim = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    # Collect parameter sizes to ensure these stay consistent through the steps below.
    expected_param_shapes = {name: tuple(param.shape) for name, param in model.named_parameters()}

    # For clarity, this is what `expected_param_shapes` should look like depending on world size:
    if not flatten:
        assert expected_param_shapes == {
            "ffn.0.weight": (5,),
            "ffn.0.bias": (2,),
            "ffn.1.weight": (5,),
            "ffn.1.bias": (2,),
            "ffn.2.weight": (5,),
            "ffn.2.bias": (2,),
        }
    else:
        assert expected_param_shapes == {
            "_fsdp_wrapped_module.flat_param_0": (12,),
            "_fsdp_wrapped_module._fpw_module.ffn.1._fsdp_wrapped_module.flat_param_0": (6,),
        }, expected_param_shapes

    torch.manual_seed(1 + rank)

    # Train for a step.
    _train_step(model, optim, expected_param_shapes, amp_context, mixed_precision, half_input)

    # Now do an eval step.
    _eval_step(model, optim, expected_param_shapes, amp_context, mixed_precision, half_input)

    # And finally do another train step.
    _train_step(model, optim, expected_param_shapes, amp_context, mixed_precision, half_input)

    teardown()

def test_input_type(temp_files, fsdp_config, input_cls):
    """Test FSDP with input being a list or a dict, only single GPU."""
    if torch_version() < (1, 7, 0):
        # This test runs multiple test cases in a single process. On 1.6.0 it
        # throws an error like this:
        #     RuntimeError: Container is already initialized! Cannot initialize it twice!
        pytest.skip("older pytorch doesn't work well with single process dist_init multiple times")

    result = dist_init(rank=0, world_size=1, filename=temp_files[0], filename_rpc=temp_files[1])
    assert result, "Dist init failed"

    assert isinstance(fsdp_config, dict), str(fsdp_config)

    class Model(Module):
        def __init__(self):
            super().__init__()
            self.layer = Linear(4, 4)

        def forward(self, input):
            if isinstance(input, list):
                input = input[0]
            else:
                assert isinstance(input, dict), input
                input = input["in"]
            return self.layer(input)

    model = FSDP(Model(), **fsdp_config).cuda()
    optim = SGD(model.parameters(), lr=0.1)

    for _ in range(5):
        in_data = torch.rand(64, 4).cuda()
        in_data.requires_grad = True
        if input_cls is list:
            in_data = [in_data]
        else:
            assert input_cls is dict
            in_data = {"in": in_data}

        out = model(in_data)
        out.sum().backward()
        optim.step()
        optim.zero_grad()

    model.assert_state(TrainingState.IDLE)
    teardown()

def run_test_get_model_parallel_src_rank(rank, model_parallel_size_):
    dist_init(rank, model_parallel_size_)

    if torch.distributed.get_rank() == 0:
        print("> testing get_model_parallel_src_rank with size {} ...".format(model_parallel_size_))
    model_parallel_size = min(model_parallel_size_, torch.distributed.get_world_size())
    assert not mpu.model_parallel_is_initialized()
    mpu.initialize_model_parallel(model_parallel_size)
    assert mpu.model_parallel_is_initialized()

    # Checks
    src_rank = torch.distributed.get_rank() - mpu.get_model_parallel_rank()
    assert mpu.get_model_parallel_src_rank() == src_rank

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(">> passed the test :-)")

def run_test_cross_entropy(rank, model_parallel_size):
    dist_init(rank, model_parallel_size)

    if torch.distributed.get_rank() == 0:
        print("> testing cross entropy with model parallel size {} ...".format(model_parallel_size))

    mpu.initialize_model_parallel(model_parallel_size)
    model_parallel_size = mpu.get_model_parallel_world_size()

    batch_size = 13
    seq_length = 17
    vocab_size_per_partition = 11
    logits_scale = 1000.0
    vocab_size = vocab_size_per_partition * model_parallel_size
    seed = 1234

    loss_torch, grad_torch = torch_cross_entropy(batch_size, seq_length, vocab_size, logits_scale, seed)
    loss_mpu, grad_mpu = mpu_cross_entropy(batch_size, seq_length, vocab_size, logits_scale, seed)

    error = loss_torch.sub_(loss_mpu).abs().max()
    print(" max error in loss on global rank {}: {}".format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    error = grad_torch.sub_(grad_mpu).abs().max()
    print(" max error in grad on global rank {}: {}".format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(">> passed the test :-)")

def _layer_memory_tracking_ddp_worker(gpu_id: int, sync_files: Tuple[str, str], world_size: int):
    dist_init(world_size=world_size, rank=gpu_id, filename=sync_files[0], filename_rpc=sync_files[1])
    torch.backends.cudnn.deterministic = True

    # Create different inputs on each GPU
    batch_size = 16
    torch.manual_seed(gpu_id)
    fake_inputs = torch.randn(size=(batch_size, 10)).cuda(gpu_id)
    fake_targets = torch.randn(size=(batch_size, 10)).cuda(gpu_id)
    fake_criterion = nn.MSELoss()

    # Create a simple model
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)
    model = nn.Sequential(
        nn.Linear(10, 32),
        nn.ReLU(),
        nn.Linear(32, 32),
        nn.ReLU(),
        nn.Linear(32, 10),
    )
    model = model.cuda(gpu_id)
    ddp_model = DistributedDataParallel(model, device_ids=[gpu_id])

    # Track the model on a forward / backward pass
    tracker = LayerwiseMemoryTracker()
    tracker.monitor(ddp_model)
    fake_criterion(ddp_model(fake_inputs), fake_targets).backward()
    tracker.stop()

    # Check the overall structure of the collected traces
    forward_names = [f"module.{i}" for i in range(5)]
    backward_names = [f"module.{i}" for i in reversed(range(5))]
    trace_names = [t.module_name for t in tracker.memory_traces]
    assert trace_names == (forward_names + backward_names)

def _dist_worker(rank, world_size, files, wrap_middle, test_fn):
    # Get data from files.
    file1, file2, sd_before, sd_after, in_data = files
    sd_before = torch.load(sd_before, map_location=lambda storage, loc: storage.cuda(rank))
    if test_fn == "train":
        sd_after = torch.load(sd_after, map_location=lambda storage, loc: storage.cuda(rank))
    in_data = torch.load(in_data, map_location=lambda storage, loc: storage.cuda(rank))

    result = dist_init(rank=rank, world_size=world_size, filename=file1, filename_rpc=file2)
    assert result, "Dist init failed"

    fsdp_model = FSDP(
        # To debug: first make with_fsdp=False (no inner wrapping) work, then enable inner wrapping
        # and make that work.
        Model(with_fsdp=True, wrap_middle=wrap_middle),
        flatten_parameters=test_fn == "optim_state",
        mixed_precision=False,
        compute_dtype=torch.float16,
    )
    fsdp_model.load_state_dict(sd_before)

    if test_fn == "train":
        _train(fsdp_model, in_data)
        objects_are_equal(sd_after, fsdp_model.state_dict(), raise_exception=True)
    elif test_fn == "eval":
        _eval(fsdp_model, in_data)
    elif test_fn == "optim_state":
        optim = SGD(fsdp_model.parameters(), lr=0.1)
        for _ in range(3):
            out = fsdp_model(in_data)
            out.backward()
            optim.step()
        sd = fsdp_model.gather_full_optim_state_dict(optim)
        if rank == 0:
            # There should be 8 momentum buffers in the state.
            assert len(sd["state"].keys()) == 8
        else:
            assert sd is None, "only rank 0 should have the optim state"
    else:
        assert 0, f"invalid test_fn {test_fn}"

    teardown()

def run_test_initialize_model_parallel(rank, model_parallel_size, filename, filename_rpc):
    dist_init(rank, model_parallel_size, filename, filename_rpc)

    if torch.distributed.get_rank() == 0:
        print("> testing initialize_model_parallel with size {} ...".format(model_parallel_size))
    model_parallel_size_ = min(model_parallel_size, torch.distributed.get_world_size())
    assert not mpu.model_parallel_is_initialized()
    mpu.initialize_model_parallel(model_parallel_size_)
    assert mpu.model_parallel_is_initialized()

    # Checks.
    def check(group, world_size, rank):
        assert world_size == torch.distributed.get_world_size(group=group)
        assert rank == torch.distributed.get_rank(group=group)

    # Model parallel.
    world_size = model_parallel_size_
    rank = torch.distributed.get_rank() % model_parallel_size_
    assert world_size == mpu.get_model_parallel_world_size()
    assert rank == mpu.get_model_parallel_rank()
    check(mpu.get_model_parallel_group(), world_size, rank)

    # Data parallel.
    world_size = torch.distributed.get_world_size() // model_parallel_size_
    rank = torch.distributed.get_rank() // model_parallel_size_
    assert world_size == mpu.get_data_parallel_world_size()
    assert rank == mpu.get_data_parallel_rank()
    check(mpu.get_data_parallel_group(), world_size, rank)

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(">> passed the test :-)")

def run_test_model_parallel_cuda_manual_seed(rank, model_parallel_size, filename, filename_rpc):
    dist_init(rank, model_parallel_size, filename, filename_rpc)

    if torch.distributed.get_rank() == 0:
        print("> testing model parallel cuda manual seed with size {} ...".format(model_parallel_size))

    mpu.initialize_model_parallel(model_parallel_size)
    model_parallel_size = mpu.get_model_parallel_world_size()

    model_parallel_cuda_manual_seed(12345)
    assert torch.cuda.initial_seed() == 12345
    with get_cuda_rng_tracker().fork():
        assert torch.cuda.initial_seed() == (12345 + 2718 + mpu.get_model_parallel_rank())

    # Reset the tracker
    get_cuda_rng_tracker().reset()

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(">> passed the test :-)")

def _distributed_worker(gpu_id, world_size, with_fsdp, with_checkpoint, filename, filename_rpc, expected):
    torch.cuda.set_device(gpu_id)

    rank = gpu_id
    result = dist_init(rank, world_size, filename, filename_rpc)
    assert result, "Dist init failed"

    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True
    batch = torch.randn(size=(2, 3, 224, 224)).cuda()

    model = create_model(with_fsdp, with_checkpoint)
    model = model.cuda()
    if with_fsdp:
        model = to_fsdp(model)
    else:
        model = DistributedDataParallel(model, device_ids=[gpu_id], bucket_cap_mb=500)

    criterion = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=1e-4)

    results = {}
    for iteration in range(3):
        get_cur_mem(gpu_id, results, f"iter {iteration}: start")

        out = model(batch)
        get_cur_mem(gpu_id, results, f"iter {iteration}: after fwd")

        out = sum(o.sum() for o in out[0])
        fake_loss = criterion(out, torch.tensor(0.0).cuda())
        get_cur_mem(gpu_id, results, f"iter {iteration}: after loss")

        fake_loss.backward()
        get_cur_mem(gpu_id, results, f"iter {iteration}: after bwd")

        optimizer.step()
        get_cur_mem(gpu_id, results, f"iter {iteration}: after step")

        # It is important to use `set_to_none` below, not optimizer.zero_grad() to reclaim memory.
        if torch_version() >= (1, 7, 0):
            model.zero_grad(set_to_none=True)
        else:
            for p in model.parameters():
                p.grad = None
        get_cur_mem(gpu_id, results, f"iter {iteration}: done")

    assert results == expected, f"{results} but expected {expected}"

    teardown()

def _test_func(rank, world_size, fsdp_config, tempfile_name, unused):
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    assert isinstance(fsdp_config, dict), str(fsdp_config)

    class Model(Module):
        def __init__(self):
            super().__init__()
            # TODO (Min): for now, we just test pytorch sync_bn here.
            #             this will grow into regnet; testing apex sync_bn, etc.
            self.conv = Conv2d(2, 2, (1, 1))
            self.bn = BatchNorm2d(2)

        def forward(self, x):
            x = self.conv(x)
            x = self.bn(x)
            return x

    # TODO (Min): check DDP equivalency.

    model = Model()
    # Note, different rank may wrap in different order due to different random
    # seeds. But results should be the same.
    if random.randint(0, 1) == 0:
        print("auto_wrap_bn, then convert_sync_batchnorm")
        model = auto_wrap_bn(model)
        model = SyncBatchNorm.convert_sync_batchnorm(model)
    else:
        print("convert_sync_batchnorm, then auto_wrap_bn")
        model = SyncBatchNorm.convert_sync_batchnorm(model)
        model = auto_wrap_bn(model)
    model = FSDP(model, **fsdp_config).cuda()
    optim = SGD(model.parameters(), lr=0.1)

    for _ in range(3):
        in_data = torch.rand(2, 2, 2, 2).cuda()
        in_data.requires_grad = True
        out = model(in_data)
        out.sum().backward()
        optim.step()
        optim.zero_grad()

    model.assert_state(TrainingState.IDLE)
    teardown()

def test_local_state_dict_calls_state_dict_recursion():
    """Testing the case of infinite recursion when FSDP is subclassed"""

    class TestModule(FSDP):
        def __init__(self):
            super().__init__(module=nn.Linear(100, 100))

        def state_dict(self, *args, **kwargs):
            return self.local_state_dict(*args, **kwargs)

    rank = 0
    world_size = 1
    with temp_files_ctx(2) as temp_files:
        result = dist_init(rank, world_size, temp_files[0], temp_files[1])
        assert result, "Dist init failed"

        m = TestModule()
        d = m.state_dict()

        teardown()

def _test_func(rank, world_size, model, fsdp_config, tempfile_name, unused, test_case):
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    my_lr = 0.1

    if test_case["assert_ref_out"]:
        with torch.no_grad():
            # Compute one iteration local output.
            weight = model.weight.T.clone().cuda()
            v = torch.Tensor(test_case["inputs"][0][rank]).cuda()
            ref_forward_output_my_rank = torch.matmul(v, weight)
            # Compute one iteration global weight update.
            v = torch.Tensor(test_case["inputs"][0][:world_size]).cuda()
            grad = v.sum(0).repeat(weight.shape[0], 1).div(world_size)
            ref_weight_out = weight - grad.T * my_lr

    model.to("cuda")
    assert isinstance(fsdp_config, dict), str(fsdp_config)
    model = FSDP(model, **fsdp_config)
    optim = SGD(model.parameters(), lr=my_lr)
    inputs = test_case["inputs"]
    assert len(inputs) == 1 or not test_case["assert_ref_out"]
    assert len(inputs[0]) >= world_size
    for in_data in inputs:
        in_data = Tensor(in_data[rank]).cuda()
        out = model(in_data)
        out.sum().backward()
        optim.step()
        optim.zero_grad()
        if test_case["assert_ref_out"]:
            with model.summon_full_params():
                weight_out = model.module.weight.data.T.clone()
            # make sure we can do more fwd/bwd
            loss = model(in_data)
            loss.sum().backward()

    if test_case["assert_ref_out"]:
        torch.testing.assert_allclose(ref_forward_output_my_rank, out)
        torch.testing.assert_allclose(ref_weight_out, weight_out)

    model.assert_state(TrainingState.IDLE)
    teardown()

def _freeze_distributed_worker(
    gpu_id,
    world_size,
    tempfile_name,
    unused,
):
    torch.cuda.set_device(gpu_id)

    rank = gpu_id
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True
    batch = torch.randn(size=(2, 3, 224, 224)).cuda()

    # The use case for this test is where the weights in the submodule
    # are not frozen but the leftover weights or those contained by the
    # root module are frozen. Refer to issue #758 for a real world example.
    model = FreezeModel()
    model = model.cuda()

    for param in model.head.parameters():
        param.requires_grad = False

    model = FSDP(model)

    if gpu_id == 0:
        print(model)

    target = torch.tensor([0, 1], dtype=torch.long).cuda()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    for iteration in range(3):
        out = model(batch)
        fake_loss = criterion(out, target)
        print("Loss", iteration, ":", fake_loss.item())
        optimizer.zero_grad()
        fake_loss.backward()
        optimizer.step()

    teardown()

def _dist_worker(rank, world_size, files, outer_flat, inner_flat, sharing):
    # Get data from files.
    file1, file2, sd_before, sd_after, in_data = files
    sd_before = torch.load(sd_before, map_location=lambda storage, loc: storage.cuda(rank))
    sd_after = torch.load(sd_after, map_location=lambda storage, loc: storage.cuda(rank))
    in_data = torch.load(in_data, map_location=lambda storage, loc: storage.cuda(rank))

    result = dist_init(rank=rank, world_size=world_size, filename=file1, filename_rpc=file2)
    assert result, "Dist init failed"

    fsdp_model = FSDP(
        Model(with_fsdp=True, inner_flat=inner_flat, sharing=sharing),
        flatten_parameters=outer_flat,
    )
    fsdp_model.load_state_dict(sd_before)

    _train(fsdp_model, in_data)

    objects_are_equal(sd_after, fsdp_model.state_dict(), raise_exception=True)

    teardown()

def test_pre_backward_hook(temp_files):
    """Test FSDP with a model that triggers a pre_backward hook bug."""

    result = dist_init(rank=0, world_size=1, filename=temp_files[0], filename_rpc=temp_files[1])
    assert result, "Dist init failed"

    class Model(Module):
        def __init__(self):
            super().__init__()
            self.l1 = Linear(4, 4).cuda()
            self.l2 = FSDP(Linear(4, 4).cuda())
            self.l3 = Linear(4, 4).cuda()

        def forward(self, x):
            x = self.l1(x)
            x = self.l2(x)
            inner_result = x
            x = self.l3(x)
            return x, inner_result

        def assert_and_clear_grad(self):
            for p in self.parameters():
                assert p.shape in [(4, 4), (4,), (4 * 4 + 4,)], p.shape
                assert p.grad is not None
                p.grad = None

    model = FSDP(Model(), flatten_parameters=False).cuda()
    in_data = torch.rand(1, 4).cuda()
    for _ in range(3):
        out, _ = model(in_data)
        out.sum().backward()
        model.assert_and_clear_grad()

    teardown()

def _test_func(rank, world_size, fsdp_config, tempfile_name, unused):
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    assert isinstance(fsdp_config, dict), str(fsdp_config)

    class InnerModel(Module):
        def __init__(self):
            super().__init__()
            self.layers = Sequential(FSDP(Linear(5, 5), **fsdp_config))

        def forward(self, x):
            return self.layers(x)

    inner_model = InnerModel()
    model = FSDP(inner_model, **fsdp_config).cuda()
    optim = SGD(model.parameters(), lr=0.1)

    for i in range(3):
        input = torch.rand((1, 5), dtype=torch.float).cuda()
        input.requires_grad = True
        output = model(input)
        output.sum().backward()
        optim.step()
        optim.zero_grad()
    input = torch.rand((1, 5), dtype=torch.float).cuda()
    output = model(input)

    model.assert_state(TrainingState.IDLE)

    # second time to rewrap the inner model
    rewrapped_model = FSDP(inner_model, **fsdp_config).cuda()
    rewrapped_output = rewrapped_model(input)

    assert torch.allclose(output, rewrapped_output)
    teardown()

def run_test_set_cuda_rng_state(rank, model_parallel_size, filename, filename_rpc):
    dist_init(rank, model_parallel_size, filename, filename_rpc)

    if torch.distributed.get_rank() == 0:
        print("> testing set_rng_state with size {} ...".format(model_parallel_size))

    mpu.initialize_model_parallel(model_parallel_size)
    model_parallel_size = mpu.get_model_parallel_world_size()

    size = 123
    seed = 1234
    torch.cuda.manual_seed(1234)
    tensor = torch.cuda.FloatTensor(size)

    # Get the state
    rng_state = torch.cuda.get_rng_state()
    rng_state_copy = rng_state.clone()

    # Do some stuff.
    for _ in range(5):
        torch.randn(size, out=tensor)
    result_1 = tensor.clone()

    assert rng_state.sub(rng_state_copy).max() == 0
    assert torch.cuda.get_rng_state().sub(rng_state_copy).max() > 0

    # State should be different.
    new_rng_state = torch.cuda.get_rng_state()
    max_diff = new_rng_state.sub(rng_state).max()
    print(
        " max diff in rng state (should be non-zero) on global rank {}: {}".format(
            torch.distributed.get_rank(), max_diff
        )
    )
    assert max_diff > 0

    # Reset the rng state and do the same stuff.
    random._set_cuda_rng_state(rng_state)
    for _ in range(5):
        torch.randn(size, out=tensor)
    random._set_cuda_rng_state(rng_state)
    for _ in range(5):
        torch.randn(size, out=tensor)
    result_2 = tensor.clone()

    # Results should be the same
    error = result_2.sub(result_1).abs().max()
    print(
        " max error in generated tensors (should be zero) on global rank {}: {}".format(
            torch.distributed.get_rank(), error
        )
    )
    assert error < 1.0e-6

    # Input state should have remained intact.
    error = rng_state.sub(rng_state_copy).max()
    print(
        " max error in rng state (should be zero) on global rank {}: {}".format(torch.distributed.get_rank(), error)
    )
    assert error == 0

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(">> passed the test :-)")

def _distributed_worker(
    gpu_id,
    world_size,
    with_model2,
    with_sync_bn,
    with_fsdp,
    with_checkpoint,
    files,
    mixed_precision,
    flatten,
    wrap_bn,
    fp32_reduce_scatter,
    bucket_cap_mb,
):
    filename, filename_rpc = files[:2]
    filename_loss = files[2:]

    torch.cuda.set_device(gpu_id)

    rank = gpu_id
    result = dist_init(rank, world_size, filename, filename_rpc)
    assert result, "Dist init failed"

    # use False below to debug since error msg is not as good with cudnn.
    torch.backends.cudnn.enabled = True

    # these make things deterministic.
    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Ensure we have multiple forward passes.
    batch = [
        torch.randn(size=(2, 3, 16, 16)).cuda(),
        torch.randn(size=(2, 3, 9, 9)).cuda(),
        torch.randn(size=(2, 3, 9, 9)).cuda(),
    ]

    if mixed_precision and not with_fsdp:
        batch = [x.half() for x in batch]

    model = _create_model(
        with_model2,
        with_sync_bn,
        with_fsdp,
        with_checkpoint,
        mixed_precision,
        flatten,
        wrap_bn,
        fp32_reduce_scatter,
        bucket_cap_mb,
    )
    model = model.cuda()

    if with_fsdp:
        model = FSDP(
            model,
            flatten_parameters=flatten,
            mixed_precision=mixed_precision,
            compute_dtype=torch.float32,
            fp32_reduce_scatter=fp32_reduce_scatter,
            bucket_cap_mb=bucket_cap_mb,
        )
        model.set_gradient_divide_factors(1.0, 2.0, True)
        no_sync_context = contextlib.suppress()
    else:
        # With DDP, we need no_sync and manual gradient reduction below because
        # it can't handle multiple forward pass + checkpointing otherwise.
        model = DistributedDataParallel(model, device_ids=[gpu_id])
        no_sync_context = model.no_sync()

    mp_context = contextlib.suppress()
    if mixed_precision:
        mp_context = torch.cuda.amp.autocast(enabled=True)

    if gpu_id == 0:
        print(model)

    target = torch.tensor([0, 1, 2, 3, 4, 5], dtype=torch.long).cuda()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    losses = {}
    i = 0
    with no_sync_context:
        for iteration in range(3):
            with mp_context:
                out = model(batch)
                loss = criterion(out, target)
                print("Loss", iteration, ":", loss.item())
                losses[f"iter_{i}"] = loss
                i += 1
                optimizer.zero_grad()
                loss.backward()
            # Manual grad reduction, no autocast.
            if not with_fsdp:
                for p in model.parameters():
                    dist.all_reduce(p.grad.data)
                    p.grad.data.div_(2.0)
            # Stepping, no autocast
            optimizer.step()

    # Due to dist.all_reduce code block above with ddp.no_sync, we seem to hit a bug
    # in DDP where tensor.cpu() and torch.save() calls both hang. FSDP is not affected.
    # Therefore, we have to compare losses here instead of states.
    with open(filename_loss[rank], "wb") as f:
        pickle.dump(losses, f)

    teardown()

def init_and_run(fn, args, rank, world_size, filename, filename_rpc):
    dist_init(rank, world_size, filename, filename_rpc)
    group = torch.distributed.new_group()
    fn(rank, group, *args)

def run_test_parallel_embedding(rank, model_parallel_size, filename, filename_rpc):
    dist_init(rank, model_parallel_size, filename, filename_rpc)

    if torch.distributed.get_rank() == 0:
        print("> testing parallel embedding with model parallel size {} ...".format(model_parallel_size))

    mpu.initialize_model_parallel(model_parallel_size)
    model_parallel_size = mpu.get_model_parallel_world_size()

    batch_size = 17
    seq_length = 23
    vocab_size = 48
    hidden_size = 16
    seed = 1236

    set_random_seed(123)
    input_data = torch.LongTensor(size=(batch_size, seq_length)).random_(0, vocab_size).cuda()
    loss_weight = torch.randn([batch_size, seq_length, hidden_size]).cuda()

    set_random_seed(seed)
    embedding_original = torch.nn.Embedding(vocab_size, hidden_size).cuda()

    output = embedding_original(input_data)
    loss_original = torch.mul(output, loss_weight).sum()
    loss_original.backward()

    set_random_seed(seed)
    embedding_parallel = layers.ParallelEmbedding(vocab_size, hidden_size, init_method=init.normal_).cuda()
    output = embedding_parallel(input_data)
    loss_parallel = torch.mul(output, loss_weight).sum()
    loss_parallel.backward()

    set_random_seed(seed)
    embedding_vocab_parallel = layers.VocabParallelEmbedding(vocab_size, hidden_size, init_method=init.normal_).cuda()
    output = embedding_vocab_parallel(input_data)
    loss_vocab_parallel = torch.mul(output, loss_weight).sum()
    loss_vocab_parallel.backward()

    torch.distributed.barrier()
    error = loss_parallel.sub(loss_original).abs()
    print(" error in loss (parallel) on global rank {}: {}".format(torch.distributed.get_rank(), error))
    assert error < 1.0e-12, "error: {}".format(error)

    torch.distributed.barrier()
    error = loss_vocab_parallel.sub(loss_original).abs()
    print(" error in loss (vocab parallel) on global rank {}: {}".format(torch.distributed.get_rank(), error))
    assert error < 1.0e-12, "error: {}".format(error)

    weight_grad_orig = torch.split(embedding_original.weight.grad, hidden_size // model_parallel_size, 1)[
        mpu.get_model_parallel_rank()
    ]
    error = embedding_parallel.weight.grad.sub(weight_grad_orig).abs().max()
    print(" error in grad (parallel) on global rank {}: {}".format(torch.distributed.get_rank(), error))
    assert error < 1.0e-12, "error: {}".format(error)

    weight_grad_orig = torch.split(embedding_original.weight.grad, vocab_size // model_parallel_size, 0)[
        mpu.get_model_parallel_rank()
    ]
    error = embedding_vocab_parallel.weight.grad.sub(weight_grad_orig).abs().max()
    print(" error in grad (vocab parallel) on global rank {}: {}".format(torch.distributed.get_rank(), error))
    assert error < 1.0e-12, "error: {}".format(error)

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(">> passed the test :-)")

def run_test_pipe(rank, world_size, filename, filename_rpc, skip_dist_init=False):
    pipe_world_size = 2

    if world_size == 1:
        return

    if not skip_dist_init:
        dist_init(rank, world_size, filename, filename_rpc)
    else:
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "29502"
        rpc.init_rpc(f"Test{rank}", rank=rank, world_size=world_size)

    mpu.initialize_model_parallel(world_size / pipe_world_size, pipe_world_size)
    model_parallel_size = mpu.get_model_parallel_world_size()
    if torch.distributed.get_rank() == 0:
        print(
            "> testing Sequential + MultiProcessPipe with model parallel size: {}, pipe: {}".format(
                model_parallel_size, pipe_world_size
            )
        )
    chunk_size = 4

    seed = 12345
    set_random_seed(seed)
    input_size_coeff = 3
    input_size = input_size_coeff * model_parallel_size
    output_size_coeff = 7
    output_size = output_size_coeff * model_parallel_size
    batch_size = 3 * chunk_size

    target = torch.rand((batch_size, input_size), requires_grad=True).cuda()
    print(f"target = {target}")

    identity = IdentityLayer2D(batch_size, input_size).cuda()

    pipeline_devices = mpu.get_pipeline_parallel_group()

    set_random_seed(seed)
    model = nn.Sequential(
        layers.ColumnParallelLinear(input_size, output_size, keep_master_weight_for_test=True, bias=False).cuda(),
        nn.ReLU(),
        layers.RowParallelLinear(output_size, input_size, keep_master_weight_for_test=True, bias=False).cuda(),
    )
    set_random_seed(seed)

    reference = [
        nn.Linear(input_size, output_size, bias=False).cuda(),
        nn.ReLU(),
        nn.Linear(output_size, input_size, bias=False).cuda(),
    ]

    print(f"setup {reference[0].weight.size()}, {model[0].weight.size()}, {(input_size, output_size)}")
    print(f"setup {reference[2].weight.size()}, {(output_size, input_size)}")

    reference[0].weight = Parameter(model[0].get_master_weight().clone()).cuda()
    reference[2].weight = Parameter(model[2].get_master_weight().clone()).cuda()

    reference = nn.Sequential(*reference)

    def grad_graph(depth, grad):
        result = depth * " " + str(grad)
        if grad:
            for x in grad.next_functions:
                result += "\n" + grad_graph(depth + 1, x[0])
        return result

    def check_weights(x, y, key: str, index=None):
        for i in [2, 0]:
            if index is not None and i != index:
                continue
            left = x[i].get_master_weight()
            right = y[i].weight.data
            if not torch.allclose(left, right, atol=1.0e-6) or index is not None:
                print(f"check_weights {key}-{i}: left = {left}, \nright = {right}")
            if not torch.equal(left, right):
                print(f"check_weights NOT_EQUAL {key}-{i}: left = {left}, \nright = {right}")
            assert torch.allclose(left, right, atol=1.0e-6)

    def dump_opt_params(opt):
        for i, group in enumerate(opt.param_groups):
            for j, p in enumerate(group["params"]):
                print(f"{torch.distributed.get_rank()}:param {(i,j)} = {p}")
                print(f"{torch.distributed.get_rank()}:param.grad {(i,j)} = {p.grad}")

    def forward_model(model_, target, step=False):
        optimizer = torch.optim.SGD(model_.parameters(), lr=0.01, momentum=0.9)
        optimizer.zero_grad()
        model_.zero_grad()
        output = model_(identity())
        loss = nn.MSELoss()
        model_.zero_grad()
        if step:
            loss(output, target).backward()
            saved_weight_0 = model_[0].weight.data.clone()
            saved_weight_2 = model_[2].weight.data.clone()
            dump_opt_params(optimizer)
            optimizer.step()
            assert not torch.allclose(saved_weight_0, model_[0].weight.data, atol=1.0e-6)
            assert not torch.allclose(saved_weight_2, model_[2].weight.data, atol=1.0e-6)
        return output

    output = forward_model(model, target)
    reference_output = forward_model(reference, target)

    error = reference_output.sub(output).max()
    torch.distributed.barrier()
    assert error < 1.0e-6

    output = forward_model(model, target)
    error = reference_output.sub(output).max()
    torch.distributed.barrier()
    assert error < 1.0e-6

    output = forward_model(model, target)
    error = reference_output.sub(output).max()
    torch.distributed.barrier()
    assert error < 1.0e-6

    check_weights(model, reference, "before")
    saved_weight_0 = model[0].weight.data.clone()
    saved_weight_2 = model[2].weight.data.clone()
    output = forward_model(model, target, step=True)
    error = reference_output.sub(output).max()
    assert error < 1.0e-6

    model[0].weight.data = saved_weight_0
    model[2].weight.data = saved_weight_2

    worker_map = {i: f"Test{i}" for i in range(torch.distributed.get_world_size())}

    if pipe_world_size == 2:
        print("actually doing pipe stuff now")
        assert torch.equal(saved_weight_0, model[0].weight.data)
        assert torch.equal(saved_weight_2, model[2].weight.data)
        pipe_model = MultiProcessPipe(
            model,
            [2, 1],
            group=pipeline_devices,
            worker_map=worker_map,
            input_device=torch.cuda.current_device(),
            chunks=chunk_size,
        ).cuda()
        torch.distributed.barrier()
        pipe_rank = torch.distributed.get_rank(group=mpu.get_pipeline_parallel_group())
        print(f"pipe rank is {pipe_rank}")
        if pipe_rank == 0:
            assert torch.equal(saved_weight_0, pipe_model[0].weight.data)
        else:
            if not torch.equal(saved_weight_2, pipe_model[0].weight.data):
                print(f"ne {pipe_rank}: left\n{saved_weight_2}\nright:\n{pipe_model[0].weight.data}")
            assert torch.equal(saved_weight_2, pipe_model[0].weight.data)
        optimizer = torch.optim.SGD(pipe_model.parameters(), lr=0.01, momentum=0.9)
        optimizer.zero_grad()
        if pipe_rank == 0:
            assert torch.equal(saved_weight_0, pipe_model[0].weight.data)
            print(f"runner {rank}:\n{pipe_model[0].weight.data}")
        else:
            assert torch.equal(saved_weight_2, pipe_model[0].weight.data)
            print(f"runner {rank}:\n{pipe_model[0].weight.data}")

        if torch.distributed.get_rank(mpu.get_pipeline_parallel_group()) == 1:
            check_weights(model, reference, "pre-pipe", index=2)
        else:
            check_weights(model, reference, "pre-pipe", index=0)

        pipe_output = pipe_model(identity())
        print(f"exited pipe for {rank}")
        forward_model(reference, target, step=True)

        print(f"pipe_output {rank} = {pipe_output}")
        print(f"reference_output {rank} = {reference_output}")

        torch.distributed.barrier()

        if torch.distributed.get_rank(mpu.get_pipeline_parallel_group()) == 1:
            error = reference_output.sub(pipe_output.cuda()).max()
            if error >= 1.0e-6:
                print(f"error bad {error}")
            assert error < 1.0e-6

            loss = nn.MSELoss()
            failed = False
            pipe_output.retain_grad()
            with torch.autograd.profiler.profile() as prof:
                try:
                    loss(pipe_output, target).backward()
                except Exception as e:
                    failed = True
                    print(f"got {e} while doing backward, deadlock?")
            if failed:
                raise RuntimeError("failed somehow")
            dump_opt_params(optimizer)
            optimizer.step()

            print("calling check_weights on master")
            check_weights(model, reference, "pipe", index=2)
            print(f"waiting for barrier on master, pid={os.getpid()}")
        else:
            print(f"calling backwards on slave, pid={os.getpid()}")
            failed = False
            with torch.autograd.profiler.profile() as prof:
                try:
                    pipe_model.back_helper(pipe_output)
                except Exception as e:
                    failed = True
                    print(f"got {e} while doing backward, deadlock?")
            if failed:
                raise RuntimeError("failed somehow")
            dump_opt_params(optimizer)
            print("calling step on slave")
            optimizer.step()
            print("calling check_weights on slave")
            check_weights(model, reference, "pipe", index=0)
            print("waiting for barrier on slave")

        pipe_model.zero_grad()
        torch.distributed.barrier()

        pipe_model.eval()
        pipe_output = pipe_model(identity())
        updated_ref_output = forward_model(reference, target)
        if torch.distributed.get_rank(mpu.get_pipeline_parallel_group()) == 1:
            error = updated_ref_output.sub(pipe_output.cuda()).max()
            print(f"outputs are ref:\n{updated_ref_output}\npipe:\n{pipe_output}")
            assert error < 1.0e-6
        torch.distributed.barrier()

        print(f"finished waiting for barrier on, pid={os.getpid()}")

    print(f"really exited pipe for {rank}")

    rpc.shutdown()
    torch.distributed.destroy_process_group()

def run_test_initialize_affine_weight(rank, model_parallel_size, filename, filename_rpc):
    dist_init(rank, model_parallel_size, filename, filename_rpc)
    mpu.initialize_model_parallel(model_parallel_size)
    if torch.distributed.get_rank() == 0:
        print("> testing initialize_affine_weight with model parallel size: {}".format(model_parallel_size))
    model_parallel_size = mpu.get_model_parallel_world_size()

    seed = 12345
    input_size_coeff = 13
    input_size = input_size_coeff * model_parallel_size
    output_size_coeff = 17
    output_size = output_size_coeff * model_parallel_size

    # ---------------
    # Column parallel
    # ---------------
    weight = torch.empty(output_size_coeff, input_size)
    set_random_seed(seed)
    layers._initialize_affine_weight(weight, output_size, input_size, output_size_coeff, 0, torch.nn.init.normal_)
    # Target.
    set_random_seed(seed)
    master_weight = torch.empty(output_size, input_size)
    torch.nn.init.normal_(master_weight)
    rank = mpu.get_model_parallel_rank()
    my_weight = torch.split(master_weight, output_size_coeff, dim=0)[rank].contiguous().clone()

    # Compare.
    error = weight.sub(my_weight).abs().max()
    torch.distributed.barrier()
    print(
        " column parallel max error (should be zero) on global rank {}: {}".format(
            torch.distributed.get_rank(), error
        )
    )
    assert error < 1.0e-6

    # ------------
    # Row parallel
    # ------------
    weight = torch.empty(output_size, input_size_coeff)
    set_random_seed(seed)
    layers._initialize_affine_weight(weight, output_size, input_size, input_size_coeff, 1, torch.nn.init.normal_)
    # Target.
    set_random_seed(seed)
    master_weight = torch.empty(output_size, input_size)
    torch.nn.init.normal_(master_weight)
    rank = mpu.get_model_parallel_rank()
    my_weight = torch.split(master_weight, input_size_coeff, dim=1)[rank].contiguous().clone()

    # Compare.
    error = weight.sub(my_weight).abs().max()
    torch.distributed.barrier()
    print(" row parallel max error (should be zero) on global rank {}: {}".format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(" >> passed the test :-)")

def run_test_cuda_rng_tracker(rank, model_parallel_size, filename, filename_rpc):
    dist_init(rank, model_parallel_size, filename, filename_rpc)

    if torch.distributed.get_rank() == 0:
        print("> testing cuda rng tracker with size {} ...".format(model_parallel_size))

    mpu.initialize_model_parallel(model_parallel_size)
    model_parallel_size = mpu.get_model_parallel_world_size()

    seed_1 = 1234
    seed_2 = 4321
    size = [12, 21]
    tensor = torch.cuda.FloatTensor(size)

    # Set to seed_1 and generate two tensors.
    torch.cuda.manual_seed(seed_1)
    torch.randn(size, out=tensor)
    target_11 = tensor.clone()
    torch.randn(size, out=tensor)
    target_12 = tensor.clone()

    # Set to seed_2 and generate two tensors.
    torch.cuda.manual_seed(seed_2)
    torch.randn(size, out=tensor)
    target_21 = tensor.clone()
    torch.randn(size, out=tensor)
    target_22 = tensor.clone()

    # Now if we interleave seed_1 and seed_2,
    # we should still get the same tensors
    torch.cuda.manual_seed(seed_1)
    get_cuda_rng_tracker().add("test", seed_2)

    torch.randn(size, out=tensor)
    result_11 = tensor.clone()

    with get_cuda_rng_tracker().fork("test"):
        torch.randn(size, out=tensor)
        result_21 = tensor.clone()

    torch.randn(size, out=tensor)
    result_12 = tensor.clone()

    with get_cuda_rng_tracker().fork("test"):
        torch.randn(size, out=tensor)
        result_22 = tensor.clone()

    diff = result_11.sub(result_21).abs().max()
    diff = min(diff, result_12.sub(result_22).abs().max())
    print(
        " max diff in generated tensors (should be non-zero) on global rank {}: {}".format(
            torch.distributed.get_rank(), diff
        )
    )
    assert diff > 1.0e-6
    error = max(result_11.sub(target_11).abs().max(), result_12.sub(target_12).abs().max())
    error = max(error, result_21.sub(target_21).abs().max())
    error = max(error, result_22.sub(target_22).abs().max())
    print(
        " max error in generated tensors (should be zero) on global rank {}: {}".format(
            torch.distributed.get_rank(), error
        )
    )
    assert error < 1.0e-6

    # Reset the tracker
    get_cuda_rng_tracker().reset()

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(">> passed the test :-)")