def _test_state_dict_device(self, config, rank, group, pure_fp16=False, **model_kwargs):
    model = TransformerWithSharedParams(group, **model_kwargs)
    if pure_fp16:
        assert not config["mixed_precision"]
        model = model.half()
    fsdp_model = FSDP(model, group, **config)
    if not config["cpu_offload"]:
        fsdp_model = fsdp_model.cuda()
    autocast = fsdp_model.mixed_precision or pure_fp16
    self._train_for_several_steps(fsdp_model, 1, autocast)

    sd = fsdp_model.state_dict()

    # With cpu_offload (or an explicit CPU state_dict_device) the state dict should
    # live on CPU; otherwise it should stay on GPU.
    sd_device = config.get("state_dict_device")
    for k, v in sd.items():
        if config["cpu_offload"] or (sd_device is not None and sd_device.type == "cpu"):
            assert v.device.type == "cpu", v.device.type
        else:
            assert v.device.type == "cuda", v.device.type

    # All floating-point entries should be fp16 in the pure_fp16 case, fp32 otherwise.
    expected_dtype = torch.float16 if pure_fp16 else torch.float32
    for k, v in sd.items():
        if not torch.is_floating_point(v):
            continue
        assert v.dtype == expected_dtype, f"{v.dtype} != {expected_dtype}"
def _distributed_worker(
    gpu_id, world_size, with_fsdp, freezing_method, tempfile_name, unused, rank_0_output, expected_state
):
    torch.cuda.set_device(gpu_id)

    rank = gpu_id
    result = dist_init(rank, world_size, tempfile_name, unused)
    assert result, "Dist init failed"

    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True
    batch = torch.randn(size=(2, 3, 224, 224)).cuda()

    model = _create_model(with_fsdp)
    model = model.cuda()

    # Freeze the trunk either by turning off requires_grad up front, or by setting its
    # gradients to None after every backward pass.
    assert freezing_method in ["requires_grad", "grad_to_none"]
    if freezing_method == "requires_grad":
        for param in model.trunk.parameters():
            param.requires_grad = False

    if with_fsdp:
        model = FSDP(model)
    else:
        model = DistributedDataParallel(model, device_ids=[gpu_id])

    if gpu_id == 0:
        print(model)

    target = torch.LongTensor([0, 1]).cuda()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    for iteration in range(3):
        out = model(batch)
        fake_loss = criterion(out, target)
        print("Loss", iteration, ":", fake_loss.item())
        optimizer.zero_grad()
        fake_loss.backward()
        if freezing_method == "grad_to_none":
            for param in model.trunk.parameters():
                param.grad = None
        optimizer.step()

    if with_fsdp:
        fsdp_state = model.state_dict()
        # Move tensors to CPU to compare numerics.
        for k, v in fsdp_state.items():
            fsdp_state[k] = v.cpu()
        assert objects_are_equal(expected_state, fsdp_state, raise_exception=True)
    elif rank == 0:
        state_after = model.module.cpu().state_dict()
        torch.save(state_after, rank_0_output)

    teardown()
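# Hedged launcher sketch (not from the original source) showing how a test might spawn
# `_distributed_worker` above on every GPU. The temporary file names and the
# pre-computed `expected_state` (a CPU state dict from a reference run) are
# placeholders that the real test would prepare beforehand.
def _launch_freezing_test(world_size=2, with_fsdp=True, freezing_method="requires_grad", expected_state=None):
    import tempfile

    import torch.multiprocessing as mp

    temp_init = tempfile.mkstemp()[1]  # rendezvous file consumed by dist_init
    temp_unused = tempfile.mkstemp()[1]
    rank_0_output = tempfile.mkstemp()[1]
    # mp.spawn supplies the process index as the first argument (gpu_id here).
    mp.spawn(
        _distributed_worker,
        args=(world_size, with_fsdp, freezing_method, temp_init, temp_unused, rank_0_output, expected_state),
        nprocs=world_size,
        join=True,
    )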
def _dist_worker(rank, world_size, files, wrap_middle, test_fn):
    # Get data from files.
    file1, file2, sd_before, sd_after, in_data = files
    sd_before = torch.load(sd_before, map_location=lambda storage, loc: storage.cuda(rank))
    if test_fn == "train":
        sd_after = torch.load(sd_after, map_location=lambda storage, loc: storage.cuda(rank))
    in_data = torch.load(in_data, map_location=lambda storage, loc: storage.cuda(rank))

    result = dist_init(rank=rank, world_size=world_size, filename=file1, filename_rpc=file2)
    assert result, "Dist init failed"

    fsdp_model = FSDP(
        # To debug: first make with_fsdp=False (no inner wrapping) work, then enable inner
        # wrapping and make that work.
        Model(with_fsdp=True, wrap_middle=wrap_middle),
        flatten_parameters=(test_fn == "optim_state"),
        mixed_precision=False,
        compute_dtype=torch.float16,
    )
    fsdp_model.load_state_dict(sd_before)

    if test_fn == "train":
        _train(fsdp_model, in_data)
        objects_are_equal(sd_after, fsdp_model.state_dict(), raise_exception=True)
    elif test_fn == "eval":
        _eval(fsdp_model, in_data)
    elif test_fn == "optim_state":
        optim = SGD(fsdp_model.parameters(), lr=0.1)
        for _ in range(3):
            out = fsdp_model(in_data)
            out.backward()
            optim.step()
        sd = fsdp_model.gather_full_optim_state_dict(optim)
        if rank == 0:
            # There should be 8 momentum buffers in the state.
            assert len(sd["state"].keys()) == 8
        else:
            assert sd is None, "only rank 0 should have the optim state"
    else:
        assert 0, f"invalid test_fn {test_fn}"

    teardown()
def _test_module_state_dict(cls, config, rank, group):
    ddp_model = cls.get_wrapped_model(group, cuda_first=False, config=config)
    autocast = ddp_model.mixed_precision
    cls._train_for_several_steps(ddp_model, 2, autocast)
    state_1 = ddp_model.state_dict()
    # You must make a new FSDP instance to use module.load_state_dict.
    unwrapped_model = TransformerWithSharedParams(group)
    unwrapped_model.load_state_dict(state_1)
    new_ddp_model = FSDP(unwrapped_model, group, **config).cuda()
    cls._train_for_several_steps(new_ddp_model, 2, autocast)
    # Loading a full state dict back into the already-wrapped (sharded) instance is
    # expected to fail.
    try:
        ddp_model.load_state_dict(new_ddp_model.state_dict())
        assert False, "ddp_model.load_state_dict(new_ddp_model.state_dict()) succeeded"
    except Exception:
        pass
def _test_identical_outputs(
    cls, model_init_fn, config, rank, group, num_steps=2, use_cuda=True, lr=0.01, ref_ddp_fn=None, norm_type=2
):
    if config.get("mixed_precision", False):
        autocast = True
        # Force the compute dtype to be torch.float32 so that we get identical results
        # to PyTorch DDP when using autocast. Note that this will cause the all-gather
        # to happen in FP32, which is slower than necessary in most cases.
        config["compute_dtype"] = torch.float32
    else:
        autocast = False

    # Establish reference behavior with PyTorch DDP (+ optionally autocast).
    model = model_init_fn(group=group, wrapper_config=None).cuda()
    if ref_ddp_fn is None:
        model = nn.parallel.DistributedDataParallel(
            model, device_ids=[rank], output_device=rank, process_group=group
        )
    else:
        model = ref_ddp_fn(model, group)
    ref_loss = cls._train_for_several_steps(model, num_steps, autocast, lr=lr, norm_type=norm_type)
    ref_state_dict = model.module.state_dict()
    if config.get("cpu_offload", False):
        for k in ref_state_dict.keys():
            ref_state_dict[k] = ref_state_dict[k].cpu()

    # Confirm we get the same behavior using FullyShardedDataParallel.
    model = FullyShardedDataParallel(model_init_fn(group=group, wrapper_config=config), group, **config)
    if use_cuda:
        model = model.cuda()
    else:
        assert next(model.parameters()).device == torch.device("cpu")
    shard_loss = cls._train_for_several_steps(model, num_steps, autocast, lr=lr, norm_type=norm_type)
    shard_state_dict = model.state_dict()

    try:
        torch.testing.assert_allclose(ref_loss, shard_loss)
        assert objects_are_equal(ref_state_dict, shard_state_dict, raise_exception=True)
    except (AssertionError, RuntimeError) as e:
        raise Exception(f"FullyShardedDataParallel didn't match PyTorch DDP using config: {config}\n\n {e}")

    if config.get("flatten_parameters", True):
        metadata = model.local_metadata_dict()
        assert isinstance(metadata, dict)
def _dist_worker(rank, world_size, files, outer_flat, inner_flat, sharing):
    # Get data from files.
    file1, file2, sd_before, sd_after, in_data = files
    sd_before = torch.load(sd_before, map_location=lambda storage, loc: storage.cuda(rank))
    sd_after = torch.load(sd_after, map_location=lambda storage, loc: storage.cuda(rank))
    in_data = torch.load(in_data, map_location=lambda storage, loc: storage.cuda(rank))

    result = dist_init(rank=rank, world_size=world_size, filename=file1, filename_rpc=file2)
    assert result, "Dist init failed"

    fsdp_model = FSDP(Model(with_fsdp=True, inner_flat=inner_flat, sharing=sharing), flatten_parameters=outer_flat)
    fsdp_model.load_state_dict(sd_before)

    _train(fsdp_model, in_data)

    objects_are_equal(sd_after, fsdp_model.state_dict(), raise_exception=True)

    teardown()
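# Hedged driver sketch (not from the original source) for `_dist_worker` above: the
# real test presumably saves a reference state dict before and after training, plus an
# input batch, to temporary files, then spawns one worker per GPU. The helper
# `_make_reference_states` and the input shape below are assumptions for illustration.
def _launch_shared_params_test(world_size, outer_flat, inner_flat, sharing):
    import tempfile

    import torch
    import torch.multiprocessing as mp

    paths = tuple(tempfile.mkstemp()[1] for _ in range(5))
    _, _, sd_before_path, sd_after_path, in_data_path = paths
    sd_before, sd_after = _make_reference_states()  # hypothetical helper
    torch.save(sd_before, sd_before_path)
    torch.save(sd_after, sd_after_path)
    torch.save(torch.rand(2, 16), in_data_path)  # placeholder input shape
    # mp.spawn passes the process index (rank) as the first positional argument.
    mp.spawn(
        _dist_worker,
        args=(world_size, paths, outer_flat, inner_flat, sharing),
        nprocs=world_size,
        join=True,
    )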
def _test_state_dict_device(self, config, rank, group, pure_fp16=False, **model_kwargs):
    model = TransformerWithSharedParams(group, **model_kwargs)
    if pure_fp16:
        assert not config["mixed_precision"]
        model = model.half()
    fsdp_model = FullyShardedDataParallel(model, group, **config)
    if not config["cpu_offload"]:
        fsdp_model = fsdp_model.cuda()
    autocast = fsdp_model.mixed_precision or pure_fp16
    self._train_for_several_steps(fsdp_model, 1, autocast)

    sd = fsdp_model.state_dict()

    sd_device = config.get("state_dict_device")
    for k, v in sd.items():
        if config["cpu_offload"] or (sd_device is not None and sd_device.type == "cpu"):
            assert v.device.type == "cpu", v.device.type
        else:
            assert v.device.type == "cuda", v.device.type

    expected_dtype = torch.float16 if pure_fp16 else torch.float32
    # Collect buffer names with the FSDP / FlattenParams wrapper prefixes stripped so
    # they can be matched against keys in the state dict.
    buffers = {
        k.replace("_fsdp_wrapped_module.", "").replace("_fpw_module.", "")
        for k, _ in fsdp_model.named_buffers()
    }
    for k, v in sd.items():
        if not torch.is_floating_point(v):
            continue
        if k in buffers:
            assert v.dtype == fsdp_model.buffer_dtype, f"{v.dtype} != {fsdp_model.buffer_dtype}"
        else:
            assert v.dtype == expected_dtype, f"{v.dtype} != {expected_dtype}"
def main(local_rank, *args):
    torch.backends.cudnn.benchmark = True
    init_method = "tcp://%s:%s" % ("0.0.0.0", "9999")
    torch.distributed.init_process_group(backend="nccl", rank=local_rank, world_size=8, init_method=init_method)
    print("[Train]: Time = %s, Initialized Dist Process for Rank = %s" % (get_time_string(), local_rank))
    device = torch.device(f"cuda:{local_rank}")  # Unique only on individual node.
    torch.cuda.set_device(device)

    fsdp_params = dict(
        mixed_precision=True,
        flatten_parameters=True,
        bucket_cap_mb=25,
        reshard_after_forward=False,
        fp32_reduce_scatter=False,
        cpu_offload=False,
        move_grads_to_cpu=False,
        process_group=torch.distributed.group.WORLD,
    )
    with enable_wrap(wrapper_cls=FullyShardedDDP, **fsdp_params):
        # Nested wrap() / checkpoint_wrapper() calls exercise inner FSDP wrapping with
        # activation checkpointing and CPU offload of activations.
        nn_model = nn.Sequential(
            nn.Linear(200, 200),
            wrap(
                checkpoint_wrapper(
                    nn.Sequential(
                        nn.Linear(200, 200),
                        nn.Linear(200, 200),
                        wrap(checkpoint_wrapper(nn.Linear(200, 200), offload_to_cpu=True)),
                        checkpoint_wrapper(nn.GELU(), offload_to_cpu=True),
                        nn.Linear(200, 200),
                    ),
                    offload_to_cpu=True,
                )
            ),
            checkpoint_wrapper(nn.GELU(), offload_to_cpu=True),
            nn.LayerNorm(200, eps=1e-7),
            nn.Linear(200, 64),
        ).cuda()
        model = FullyShardedDDP(nn_model, **fsdp_params)

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, eps=1e-7, weight_decay=1e-2, betas=(0.9, 0.99))
    optimizer.zero_grad(set_to_none=True)

    for i in range(1000):
        optimizer.zero_grad(set_to_none=True)
        fake_inputs = torch.randn(32, 200, device=device)
        fake_labels = torch.randn(32, 64, device=device)
        outputs = model(fake_inputs)
        loss = ((outputs - fake_labels) ** 2).mean()
        loss.backward()
        model.clip_grad_norm_(1.0)
        optimizer.step()
        if i % 100 == 0:
            print("Loss = %s, rank = %s" % (loss.item(), local_rank))

    # Rebuild the network as a plain nn.Module (no FSDP / wrap) and load the
    # consolidated state dict into it.
    state_dict = model.state_dict()
    nn_model = nn.Sequential(
        nn.Linear(200, 200),
        nn.Sequential(
            nn.Linear(200, 200),
            nn.Linear(200, 200),
            nn.Linear(200, 200),
            checkpoint_wrapper(nn.GELU(), offload_to_cpu=True),
            nn.Linear(200, 200),
        ),
        checkpoint_wrapper(nn.GELU(), offload_to_cpu=True),
        nn.LayerNorm(200, eps=1e-7),
        nn.Linear(200, 64),
    ).cuda()
    nn_model.load_state_dict(state_dict)
    print("[Train]: Time = %s, Trainable Params = %s" % (get_time_string(), numel(nn_model) / 1_000_000))
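# Hedged launcher sketch (not part of the original script): `main` hard-codes
# world_size=8 and a TCP rendezvous on 0.0.0.0:9999, so on a single 8-GPU node it
# could be driven roughly as below. The `__main__` guard and the spawn call are
# assumptions, not taken from the source.
if __name__ == "__main__":
    import torch.multiprocessing as mp

    # mp.spawn supplies local_rank (0..7) as the first positional argument of `main`.
    mp.spawn(main, args=(), nprocs=8, join=True)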