def test_scaler_cpu_offload_breaks():
    """Pin the known failure of ShardedGradScaler with a CPU-offloaded FSDP model.

    A single-rank NCCL process group is created, one mixed-precision forward /
    scaled backward is run on an FSDP(cpu_offload=True) model, and the scaler's
    step/update is expected to raise RuntimeError (see issue #421). The process
    group and the MASTER_* environment variables are always torn down.
    """
    device = torch.device("cuda")
    torch.cuda.set_device(0)

    # Random port in case the next test run quickly, same port would cause conflict.
    os.environ.update(
        {
            "MASTER_ADDR": "localhost",
            "MASTER_PORT": str(random.randint(2000, 3000)),
        }
    )
    torch.distributed.init_process_group(backend="nccl", rank=0, world_size=1)

    try:
        scaler = ShardedGradScaler()
        model = FullyShardedDataParallel(nn.Linear(5, 5), cpu_offload=True, mixed_precision=True)
        optim = torch.optim.SGD(model.parameters(), lr=1e-3)

        sample = torch.rand((1, 5), dtype=torch.float).to(device)
        optim.zero_grad()
        with autocast():
            prediction = model(sample)
            loss = F.mse_loss(sample, prediction)
        scaler.scale(loss).backward()

        # TODO (Min): Need to fix. Details in issue #421.
        with pytest.raises(RuntimeError):
            scaler.step(optim)
            scaler.update()
    finally:
        # Clean-up is important or the next test in this file may fail to init the PG.
        torch.distributed.destroy_process_group()
        for env_var in ("MASTER_ADDR", "MASTER_PORT"):
            del os.environ[env_var]
def __init__(self, embedding_size: int, with_fsdp: bool, process_group):
    """Build a small conv + MLP network, optionally wrapping some children in FSDP.

    Args:
        embedding_size: feature width; ``conv1``/``conv3`` are built with this
            width (presumably their output channels — ``_conv_block`` is defined
            elsewhere) and ``fc4`` outputs ``embedding_size`` features.
        with_fsdp: when True, ``conv2``, ``conv3``, ``fc1`` and ``fc3`` are each
            wrapped in their own ``FullyShardedDataParallel``; ``conv3`` and
            ``fc3`` use ``flatten_parameters=False``.
        process_group: handed to every inner FSDP wrapper; unused otherwise.
    """
    super().__init__()
    half_size = embedding_size // 2
    wide_size = 2 * embedding_size

    # Keep this creation order: parameter initialisation draws from the global
    # RNG, so reordering would change the initial weights.
    self.conv1 = self._conv_block(3, embedding_size)
    self.conv2: nn.Module = self._conv_block(embedding_size, half_size)
    self.conv3: nn.Module = self._conv_block(half_size, embedding_size)
    self.pool = nn.AdaptiveAvgPool2d(output_size=(1, 1))
    self.flatten = nn.Flatten(start_dim=1)
    self.relu = nn.ReLU()
    self.fc1: nn.Module = nn.Linear(embedding_size, wide_size)
    self.fc2: nn.Module = nn.Linear(wide_size, wide_size)
    self.fc3: nn.Module = nn.Linear(wide_size, embedding_size + 1)
    self.fc4: nn.Module = nn.Linear(embedding_size + 1, embedding_size)

    if not with_fsdp:
        return

    # (child attribute, extra FSDP kwargs) — wrapped in this exact order.
    fsdp_wraps = (
        ("conv2", {}),
        ("conv3", {"flatten_parameters": False}),
        ("fc1", {}),
        ("fc3", {"flatten_parameters": False}),
    )
    for attr_name, extra_kwargs in fsdp_wraps:
        wrapped = FullyShardedDataParallel(getattr(self, attr_name), process_group=process_group, **extra_kwargs)
        setattr(self, attr_name, wrapped)
def init_fsdp_model_from_weights(
    cls,
    model: FullyShardedDataParallel,
    checkpoint: Dict[str, Any],
    weights_path: List[str],
    strict: bool = True,
    head_index: int = -1,
):
    """
    Load checkpoint weights into an FSDP model, choosing the strategy from ``checkpoint["type"]``.

    - ``slice_list``: slices are loaded with ``SlicedCheckpointLoader.load_slice_state_dict``.
      Legacy checkpoints without a ``"classy_state_dict"`` entry keep their slices
      directly under ``"layers"`` and ignore ``weights_path``. Missing weights always
      raise, whatever ``strict`` is.
    - ``consolidated``: full weights are loaded with ``model.load_state_dict``.
    - anything else (sharded): local weights are loaded with ``model.load_local_state_dict``.

    For the last two cases the PyTorch-level load is non-strict and the outcome is
    validated by ``cls._check_load_state_dict_out(out, strict=strict)``; missing
    weights raise only when ``strict`` is True, otherwise the call is a no-op.

    Args:
        model: the FSDP-wrapped model to populate.
        checkpoint: loaded checkpoint dict; must contain a ``"type"`` key.
        weights_path: key path used by ``cls._extract_weights`` to locate the weights.
        strict: whether missing/unexpected keys (or missing weights) are errors.
        head_index: -1 for the trunk, >= 0 to select the weights of that head.

    Raises:
        ValueError: when the weights cannot be found (see the rules above).
    """
    checkpoint_type = checkpoint["type"]
    is_sliced = checkpoint_type == CheckpointItemType.slice_list.name

    if is_sliced and "classy_state_dict" not in checkpoint:
        # Hack for checkpoints consolidated with the "layers" format instead of
        # the "classy_state_dict" format: slices live directly under "layers".
        weights = checkpoint["layers"]
    else:
        weights = cls._extract_weights(checkpoint, weights_path, head_index)

    if weights is None:
        # Sliced checkpoints must always provide weights; the others only when strict.
        if is_sliced or strict:
            raise ValueError(
                f"Could not find weights path: {weights_path}")
        return

    if is_sliced:
        SlicedCheckpointLoader.load_slice_state_dict(model, weights, strict=strict)
        return

    if checkpoint_type == CheckpointItemType.consolidated.name:
        load_result = model.load_state_dict(weights, strict=False)
    else:
        load_result = model.load_local_state_dict(weights, strict=False)
    cls._check_load_state_dict_out(load_result, strict=strict)
def init_fsdp_model_from_weights(
    cls,
    model: FullyShardedDataParallel,
    checkpoint: Dict[str, Any],
    weights_path: List[str],
):
    """
    Load checkpoint weights into an FSDP model according to ``checkpoint["type"]``.

    - ``slice_list``: delegated to ``SlicedCheckpointLoader.init_model_weights``
      (``weights_path`` is not used).
    - ``consolidated``: weights located via ``cls._extract_weights`` and loaded
      with ``model.load_state_dict``.
    - anything else (sharded): weights located the same way and loaded with
      ``model.load_local_state_dict``.
    """
    checkpoint_type = checkpoint["type"]
    if checkpoint_type == CheckpointItemType.slice_list.name:
        SlicedCheckpointLoader.init_model_weights(model, checkpoint)
        return

    weights = cls._extract_weights(checkpoint, weights_path)
    if checkpoint_type == CheckpointItemType.consolidated.name:
        model.load_state_dict(weights)
    else:
        model.load_local_state_dict(weights)
def SwavPrototypesHeadFSDP(
    model_config: AttrDict,
    dims: List[int],
    use_bn: bool,
    num_clusters: int,
    use_bias: bool = True,
    return_embeddings: bool = True,
    skip_last_bn: bool = True,
    normalize_feats: bool = True,
):
    """
    SwAV head specific FSDP wrapping: we keep full precision for the prototypes.

    This is important for convergence: since the prototype layers are
    "normalized" in the update hook, their weights must stay in full precision.
    Their output goes into the loss and is used for clustering, so it must be
    in full precision as well.

    Each ``prototypes<j>`` layer is wrapped in its own FSDP instance using a copy
    of ``model_config.FSDP_CONFIG`` with flattening, mixed precision and fp32
    reduce-scatter disabled and ``compute_dtype=torch.float32``. The head itself
    is then wrapped with the unmodified ``model_config.FSDP_CONFIG`` so that the
    rest of the head honours the configured FSDP settings.

    Args:
        model_config: model configuration; ``FSDP_CONFIG`` holds the FSDP kwargs.
        dims, use_bn, num_clusters, use_bias, return_embeddings, skip_last_bn,
        normalize_feats: forwarded unchanged to ``SwAVPrototypesHead``.

    Returns:
        The head wrapped in ``FullyShardedDataParallel``.
    """
    head = SwAVPrototypesHead(
        model_config=model_config,
        dims=dims,
        use_bn=use_bn,
        num_clusters=num_clusters,
        use_bias=use_bias,
        return_embeddings=return_embeddings,
        skip_last_bn=skip_last_bn,
        normalize_feats=normalize_feats,
    )

    fp32_fsdp_config = model_config.FSDP_CONFIG.copy()
    fp32_fsdp_config["flatten_parameters"] = False
    fp32_fsdp_config["mixed_precision"] = False
    fp32_fsdp_config["fp32_reduce_scatter"] = False
    fp32_fsdp_config["compute_dtype"] = torch.float32

    for j in range(head.nmb_heads):
        prototype_name = "prototypes" + str(j)
        module = getattr(head, prototype_name)
        module = FullyShardedDataParallel(module=module, **fp32_fsdp_config)
        setattr(head, prototype_name, module)

    # Previously wrapped with FSDP defaults, which silently ignored the user's
    # FSDP_CONFIG (mixed precision, flattening, ...). The same keys are already
    # accepted by the inner wrappers above, so they are valid FSDP kwargs.
    return FullyShardedDataParallel(head, **model_config.FSDP_CONFIG)
def _create_model(embedding_size: int, with_fsdp: bool, process_group, flatten_parameters: bool = True):
    """Build a ``ConvolutionalModel`` on the current CUDA device.

    When ``with_fsdp`` is True, the model (whose inner layers are then FSDP-wrapped
    by its own constructor) is additionally wrapped in a root
    ``FullyShardedDataParallel`` using ``process_group`` and ``flatten_parameters``.
    Otherwise the plain model is returned.
    """
    net = ConvolutionalModel(
        with_fsdp=with_fsdp,
        process_group=process_group,
        embedding_size=embedding_size,
    ).cuda()
    if not with_fsdp:
        return net
    return FullyShardedDataParallel(net, process_group=process_group, flatten_parameters=flatten_parameters)
def _test_consolidate_weights(self, config, rank, group, paths=None, transformer=False):
    """Sharded ``local_state_dict`` checkpoints should consolidate into the full ``state_dict``.

    Every rank runs one mixed-precision training step, then saves its local
    weights and FSDP metadata to ``paths[fsdp.rank]``. After a barrier, rank 0
    loads all shards, consolidates them with
    ``FullyShardedDataParallel.consolidate_shard_weights`` and checks that the
    result has exactly the keys and tensor shapes of ``fsdp.state_dict()``.

    Args:
        config: FSDP wrapper config used to build the model.
        rank: unused; the rank is read from ``fsdp.rank``.
        group: process group used to build the model.
        paths: one checkpoint file path per rank. Required despite the ``None`` default.
        transformer: use the wrapped transformer model instead of MixtureOfExperts.

    Raises:
        ValueError: if ``paths`` is None.
    """
    if paths is None:
        raise ValueError("_test_consolidate_weights requires one checkpoint path per rank")

    if transformer:
        fsdp = self.get_wrapped_model(group, config=config).cuda()
    else:
        fsdp = FullyShardedDataParallel(
            MixtureOfExperts(group, wrapper_config=config)).cuda()

    optim = Adam(
        fsdp.parameters(),
        lr=0.01,
    )
    optim.zero_grad()
    with torch.cuda.amp.autocast(enabled=True):
        x = fsdp.module.get_input(torch.device("cuda"))
        output = fsdp(*x)
        loss = fsdp.module.get_loss(x, output).to("cuda")
        fsdp.module.run_backward(loss)
        optim.step()

    # each worker saves a checkpoint with local_state_dict
    cp_data = {
        "weights": {k: v.cpu() for k, v in fsdp.local_state_dict().items()},
        "meta": fsdp.local_metadata_dict(),
    }
    torch.save(cp_data, paths[fsdp.rank])
    # Collective call: every rank must reach it before the barrier.
    full_model_state_dict = fsdp.state_dict()
    torch.distributed.barrier()
    if fsdp.rank > 0:
        return

    all_checkpoints = [torch.load(p) for p in paths]
    consolidated_checkpoint = FullyShardedDataParallel.consolidate_shard_weights(
        shard_weights=[c["weights"] for c in all_checkpoints],
        shard_metadata=[c["meta"] for c in all_checkpoints],
    )

    full_model_extra = set(full_model_state_dict).difference(set(consolidated_checkpoint))
    consolidated_extra = set(consolidated_checkpoint).difference(set(full_model_state_dict))
    msg = f"full model extra keys: {full_model_extra}, consolidated extra {consolidated_extra}"
    # Check the key sets first: otherwise a missing key raises a bare KeyError
    # in the shape loop below and the diagnostic `msg` is never shown.
    assert set(full_model_state_dict.keys()) == set(consolidated_checkpoint.keys()), msg
    for k in full_model_state_dict.keys():
        assert consolidated_checkpoint[k].shape == full_model_state_dict[k].shape, (
            f"shape mismatch for {k}: {consolidated_checkpoint[k].shape} vs {full_model_state_dict[k].shape}"
        )
def test_consolidate_missing_params():
    """Non-strict consolidation must tolerate parameters that have no saved shard.

    Fairseq experts are saved independently from the rest of the model, so an
    expert checkpoint carries metadata for the root module (parameter
    ``missing`` and buffer ``missing.buffer``, neither present in the weights)
    and for the expert itself. With ``strict=False`` only the four expert
    tensors, all prefixed by the expert's FSDP path, should be returned.
    """
    expert_prefix = "decoder.layers.1.moe_layer.experts.0"

    root_entry = {
        "fsdp_path": "",
        "params": {
            "flat_param_0": {
                "names": ["missing"],
                "shapes": [(12, 4)],
                "numels": [12 * 4],
                "padding": 0,
            }
        },
        "no_broadcast_optim_state": False,
        "shared_param_info": [],
    }
    expert_entry = {
        "fsdp_path": expert_prefix,
        "params": {
            "flat_param_0": {
                "names": ["fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias"],
                "shapes": [(4, 4), (4,), (4, 4), (4,)],
                "numels": [16, 4, 16, 4],
                "padding": 0,
            }
        },
        "no_broadcast_optim_state": True,
        "shared_param_info": [],
    }
    shard_metadata = {
        "param_metadata": [root_entry, expert_entry],
        "buffer_names": ["missing.buffer"],
    }

    # 40 == 16 + 4 + 16 + 4: exactly the expert's flattened parameters.
    expert_flat_param = torch.randn(40, dtype=torch.float16)
    shard_weights = {"decoder.layers.1.moe_layer.experts.0.flat_param_0": expert_flat_param}

    merged = FullyShardedDataParallel.consolidate_shard_weights([shard_weights], [shard_metadata], strict=False)

    assert len(merged) == 4
    for name in merged:
        assert name.startswith(expert_prefix), f"{name} doesnt start with {expert_prefix}"
def _test_consolidated_optimizer(self, config, rank, group, optim_fn=torch.optim.SGD, transformer=False):
    """FSDP.gather_full_optim_state_dict() should return something very similar to optimizer.state_dict()

    One training step is run on an FSDP model and on an unwrapped reference
    model. The FSDP optimizer state is gathered on rank 0, and its size is
    compared against the reference optimizer state. The test then checks that
    re-sharding the gathered state with ``get_shard_from_optim_state_dict``
    reproduces this rank's local optimizer state.

    Args:
        config: FSDP wrapper config for the wrapped model.
        rank: unused in this body.
        group: process group used to build both models.
        optim_fn: optimizer class; called with ``lr=0.01`` unless it rejects that
            keyword (TypeError, e.g. Adadelta), in which case defaults are used.
        transformer: use the transformer models instead of NestedWrappedModule.
    """
    # Establish reference behavior.
    if transformer:
        fsdp = self.get_wrapped_model(group, config=config).cuda()
        unwrapped_model = TransformerWithSharedParams(group).cuda()
    else:
        fsdp = FullyShardedDataParallel(
            NestedWrappedModule(group, wrapper_config=config), group, **config).cuda()
        unwrapped_model = NestedWrappedModule(group, wrapper_config=None).cuda()

    try:
        fsdp_optim = optim_fn(
            fsdp.parameters(),
            lr=0.01,
        )
        optim_unwrapped = optim_fn(unwrapped_model.parameters(), lr=0.01)
    except TypeError:  # Adadelta
        fsdp_optim = optim_fn(fsdp.parameters())
        optim_unwrapped = optim_fn(unwrapped_model.parameters())

    fsdp_optim.zero_grad()
    optim_unwrapped.zero_grad()

    # One step on the FSDP model ...
    x = fsdp.module.get_input(torch.device("cuda"))
    output = fsdp(*x)
    loss = fsdp.module.get_loss(x, output).to("cuda")
    fsdp.module.run_backward(loss)
    fsdp_optim.step()

    # ... and the same step, on the same input, for the reference model.
    output = unwrapped_model(*x)
    loss = unwrapped_model.get_loss(x, output)
    unwrapped_model.run_backward(loss)
    optim_unwrapped.step()
    unwrapped_sd = optim_unwrapped.state_dict()

    # Collective call: all ranks participate, only rank 0 receives the result.
    tstart = time()
    sd = fsdp.gather_full_optim_state_dict(fsdp_optim, recipient_rank=0)
    duration = time() - tstart
    # Switching from fairscale.optim.utils.broadcast_object to torch.broadcast_object_list will cause this to raise
    assert duration < fsdp.world_size, f"gather optim state took {duration} seconds, suspect change in _consolidate"

    # Only the recipient rank holds the full state dict.
    if fsdp.rank > 0:
        return

    # The gathered state should describe as many params/elements as the reference.
    assert_equal(len(sd["state"]), len(unwrapped_sd["state"]))
    assert_equal(len(sd["param_groups"][0]["params"]), len(unwrapped_sd["param_groups"][0]["params"]))
    assert_equal(
        sum([first_tensor_numel(v) for k, v in sd["state"].items()]),
        sum([
            first_tensor_numel(v) for k, v in unwrapped_sd["state"].items()
        ]),
    )

    # Re-sharding the full state must reproduce this rank's local optimizer state.
    shard_sd = fsdp.get_shard_from_optim_state_dict(sd)

    original_shard_sd = fsdp_optim.state_dict()
    assert_equal(len(shard_sd["state"]), len(original_shard_sd["state"]))
    assert_equal(shard_sd.keys(), original_shard_sd.keys())
    original_shard_sd = recursive_copy_to_device(original_shard_sd, non_blocking=False, device="cpu")

    assert_equal(
        sum([first_tensor_numel(v) for k, v in shard_sd["state"].items()]),
        sum([
            first_tensor_numel(v) for k, v in original_shard_sd["state"].items()
        ]),
    )
    assert objects_are_equal(shard_sd, original_shard_sd)
def _test_consolidated_optimizer(self, config, rank, group, optim_fn=torch.optim.SGD, transformer=False):
    """FSDP.gather_full_optim_state_dict() should return something very similar to optimizer.state_dict()

    One mixed-precision step is run on an FSDP model and on an unwrapped
    reference model. The test then gathers the FSDP optimizer state on rank 0
    and checks the following:

    - it takes less than ``world_size`` seconds;
    - it allocates no extra CUDA memory;
    - the gathered tensors live on CPU;
    - ``get_shard_from_optim_state_dict`` reproduces this rank's local state.

    Args:
        config: FSDP wrapper config for the wrapped model.
        rank: unused in this body.
        group: process group used to build both models.
        optim_fn: optimizer class; called with ``lr=0.01`` unless it rejects that
            keyword (TypeError, e.g. Adadelta), in which case defaults are used.
        transformer: use the transformer models instead of MixtureOfExperts.
    """
    # Establish reference behavior.
    if transformer:
        unwrapped_model = TransformerWithSharedParams(
            group, wrapper_config=config).cuda()
        fsdp = self.get_wrapped_model(group, config=config).cuda()
    else:
        unwrapped_model = MixtureOfExperts(group, wrapper_config=None).cuda()
        fsdp = FullyShardedDataParallel(
            MixtureOfExperts(group, wrapper_config=config)).cuda()

    try:
        fsdp_optim = optim_fn(
            fsdp.parameters(),
            lr=0.01,
        )
        optim_unwrapped = optim_fn(unwrapped_model.parameters(), lr=0.01)
    except TypeError:  # Adadelta
        fsdp_optim = optim_fn(fsdp.parameters())
        optim_unwrapped = optim_fn(unwrapped_model.parameters())

    fsdp_optim.zero_grad()
    optim_unwrapped.zero_grad()
    with torch.cuda.amp.autocast(enabled=True):
        x = fsdp.module.get_input(torch.device("cuda"))
        output = fsdp(*x)
        loss = fsdp.module.get_loss(x, output).to("cuda")
        fsdp.module.run_backward(loss)
        fsdp_optim.step()

        output = unwrapped_model(*x)
        loss = unwrapped_model.get_loss(x, output)
        unwrapped_model.run_backward(loss)
        optim_unwrapped.step()
    unwrapped_sd = optim_unwrapped.state_dict()

    if not transformer:
        # Exactly one FSDP instance (the last one) is flagged to keep its optimizer
        # state local instead of broadcasting it.
        no_broadcast_children = [
            x for x in fsdp._fsdp_instances if x.no_broadcast_optim_state
        ]
        assert len(no_broadcast_children) == 1
        assert fsdp._fsdp_instances[-1].no_broadcast_optim_state
    # Measure current CUDA allocation (in GB) around the gather.
    torch.cuda.empty_cache()
    cuda_gb_before = torch.cuda.memory_stats(
        fsdp.rank)["allocated_bytes.all.current"] / 1024**3
    tstart = time()
    sd = fsdp.gather_full_optim_state_dict(fsdp_optim, recipient_rank=0)
    duration = time() - tstart
    assert duration < fsdp.world_size, f"gather optim state took {duration} seconds, suspect change in _consolidate"

    cuda_gb_after = torch.cuda.memory_stats(
        fsdp.rank)["allocated_bytes.all.current"] / 1024**3
    mem_usg_gb = cuda_gb_after - cuda_gb_before
    assert mem_usg_gb == 0, f"gather_full_optim_state_dict used {mem_usg_gb:.2f} CUDA GB, max allowed is 0"
    assert cuda_gb_after > 0, "got 0 memory usage, logging is broken"

    # Non-recipient ranks get nothing back.
    if fsdp.rank > 0:
        assert sd is None
        return

    # assert whole state dict on CPU
    for k, v in sd["state"].items():
        for buffer_name, t in v.items():
            if torch.is_tensor(t):
                msg = f"got device {t.device} for {k}: {buffer_name}. expected CPU"
                assert t.device == torch.device("cpu"), msg

    # NOTE(review): `unflat_state` is a reference to sd["state"], not a copy, so the
    # "no side effects" check below only detects rebinding of sd["state"], not
    # in-place mutation by get_shard_from_optim_state_dict. A deep copy would be
    # needed to test that — confirm intent.
    unflat_state = sd["state"]
    assert "uncollected_local_ids" in sd
    shard_sd = fsdp.get_shard_from_optim_state_dict(sd)
    shard_sd = recursive_copy_to_device(shard_sd, non_blocking=False, device="cpu")
    state_after_get_shard = sd["state"]
    assert objects_are_equal(unflat_state, state_after_get_shard)  # no side effects.

    # The gathered state should describe as many params/elements as the reference.
    assert_equal(len(sd["state"]), len(unwrapped_sd["state"]))
    assert_equal(len(sd["param_groups"][0]["params"]), len(unwrapped_sd["param_groups"][0]["params"]))
    assert_equal(
        sum([first_tensor_numel(v) for k, v in sd["state"].items()]),
        sum([
            first_tensor_numel(v) for k, v in unwrapped_sd["state"].items()
        ]),
    )

    # Re-sharding the full state must reproduce this rank's local optimizer state.
    original_shard_sd = fsdp_optim.state_dict()
    assert_equal(len(shard_sd["state"]), len(original_shard_sd["state"]))
    assert_equal(shard_sd.keys(), original_shard_sd.keys())
    original_shard_sd = recursive_copy_to_device(original_shard_sd, non_blocking=False, device="cpu")
    # Before asserting that the dicts are equal, we check keys individually to allow nice tracebacks.
    assert_equal(
        [first_tensor_numel(v) for k, v in shard_sd["state"].items()],
        [
            first_tensor_numel(v) for k, v in original_shard_sd["state"].items()
        ],
    )
    assert_equal(
        [v for k, v in shard_sd["param_groups"][0].items()],
        [v for k, v in original_shard_sd["param_groups"][0].items()],
    )
    assert objects_are_equal(shard_sd["state"], original_shard_sd["state"])
    assert objects_are_equal({k: shard_sd[k] for k in original_shard_sd}, original_shard_sd)
def _worker(gpu_id: int, sync_file: str, world_size: int, embedding_size: int, flatten_parameters: bool):
    """Round-trip an FSDP model through sharded checkpoints and consolidation.

    Each rank trains a nested FSDP ``ConvolutionalModel`` for two steps and saves
    its local shard to ``checkpoint_<gpu_id>.torch``. Every rank then
    consolidates all shards and checks keys/shapes against a non-FSDP model. It
    loads the consolidated weights into a fresh FSDP model and checks that both
    models give the same loss.

    Args:
        gpu_id: CUDA device index, also used as the process rank.
        sync_file: file used for ``file://`` process-group rendezvous.
        world_size: number of participating processes.
        embedding_size: width passed to ``_create_model``.
        flatten_parameters: forwarded to the root FSDP wrapper.
    """
    torch.manual_seed(0)
    torch.cuda.set_device(gpu_id)
    torch.distributed.init_process_group(
        backend="nccl",
        init_method=f"file://{sync_file}",
        world_size=world_size,
        rank=gpu_id,
    )
    process_group = torch.distributed.new_group()

    # Create a dummy model with dummy inputs and targets
    batch_size = 4
    input = torch.randn(size=(batch_size, 3, 32, 32)).cuda()
    target = torch.zeros(size=(batch_size, embedding_size)).cuda()
    model = _create_model(
        with_fsdp=True,
        process_group=process_group,
        embedding_size=embedding_size,
        flatten_parameters=flatten_parameters,
    )
    criterion = nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

    # Train the model for two steps on the same batch
    for epoch in range(2):
        out = model(input)
        loss = criterion(out, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Save one checkpoint file per shard (i.e. per rank)
    # NOTE(review): files are written to the current working directory and never removed.
    cp_data = {
        "weights": {k: v.cpu() for k, v in model.local_state_dict().items()},
        "meta": model.local_metadata_dict(),
    }
    torch.save(cp_data, f"checkpoint_{gpu_id}.torch")

    # Wait for all files to be written on the disk
    dist.barrier()  # type: ignore

    # Reconstruct a full checkpoint from the sharded checkpoints
    all_checkpoints = [_load_sharded_checkpoint(rank) for rank in range(world_size)]
    consolidated_checkpoint = FullyShardedDataParallel.consolidate_shard_weights(
        shard_weights=[c["weights"] for c in all_checkpoints],
        shard_metadata=[c["meta"] for c in all_checkpoints],
    )

    # Check that the reconstructed parameters have the expected keys and shapes
    full_model = _create_model(with_fsdp=False, process_group=process_group, embedding_size=embedding_size)
    full_model_state_dict = full_model.state_dict()
    assert set(full_model_state_dict.keys()) == set(consolidated_checkpoint.keys())
    for k in full_model_state_dict.keys():
        assert consolidated_checkpoint[k].shape == full_model_state_dict[k].shape

    # Verify that the checkpoint can be loaded by a FSDP model
    loaded_model = _create_model(
        with_fsdp=True,
        process_group=process_group,
        embedding_size=embedding_size,
        flatten_parameters=flatten_parameters,
    )
    loaded_model.load_state_dict(consolidated_checkpoint)
    # Reset every FSDP instance's lazy-init state after the load.
    for m in loaded_model.modules():
        if isinstance(m, FullyShardedDataParallel):
            m._reset_lazy_init()

    # Verify that the model saved and the model loaded give the same results
    with torch.no_grad():
        before_checkpoint_loss = criterion(model(input), target).item()
        after_checkpoint_loss = criterion(loaded_model(input), target).item()
        assert before_checkpoint_loss == after_checkpoint_loss
def _consolidate_shards(cls, weights: List[Dict[str, torch.Tensor]], metadata: List[Dict[str, Any]]):
    """Merge per-shard FSDP weights into a single full state dict.

    Args:
        weights: one local state dict per shard (presumably in rank order — confirm).
        metadata: the matching per-shard FSDP metadata, in the same order.

    Returns:
        The result of ``FullyShardedDataParallel.consolidate_shard_weights``.
    """
    logging.info("Consolidating shards...")
    full_state = FullyShardedDataParallel.consolidate_shard_weights(
        shard_weights=weights,
        shard_metadata=metadata,
    )
    return full_state
def _layer_memory_tracking_fsdp_worker(gpu_id: int, sync_files: Tuple[str, str], world_size: int):
    """Check that LayerwiseMemoryTracker records FSDP all-gathers per module.

    A three-layer model is built with the middle and last ``Linear`` layers
    wrapped in their own FSDP instances and the whole model wrapped again at the
    root. All FSDP instances communicate through a ``ProcessGroupTracker``. One
    forward/backward pass is traced, and the ``(module_name, all_gathered,
    cumul_all_gathered)`` tuples of traces with non-zero all-gathers must match a
    fixed expected list.

    Args:
        gpu_id: CUDA device index, also used as the rank.
        sync_files: rendezvous files passed to ``dist_init`` (filename, filename_rpc).
        world_size: number of participating processes.
    """
    dist_init(world_size=world_size, rank=gpu_id, filename=sync_files[0], filename_rpc=sync_files[1])
    torch.backends.cudnn.deterministic = True

    # Create different inputs on each GPU
    batch_size = 16
    torch.manual_seed(gpu_id)
    fake_inputs = torch.randn(size=(batch_size, 10)).cuda(gpu_id)
    fake_targets = torch.randn(size=(batch_size, 10)).cuda(gpu_id)
    fake_criterion = nn.MSELoss()

    # Create a global group and a tracker around it
    group = dist.new_group()
    group = ProcessGroupTracker(group)

    # Create a simple model (same seed on every rank so the weights match)
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)
    model = nn.Sequential(
        nn.Linear(10, 10).cuda(gpu_id),
        nn.ReLU(),
        FullyShardedDataParallel(
            nn.Linear(10, 10).cuda(gpu_id),
            flatten_parameters=False,
            process_group=group,
        ),
        nn.ReLU(),
        FullyShardedDataParallel(
            nn.Linear(10, 10).cuda(gpu_id),
            flatten_parameters=True,
            process_group=group,
        ),
    )
    model = model.cuda(gpu_id)
    dist_model = FullyShardedDataParallel(model, flatten_parameters=False, process_group=group)

    # Track the model on a forward / backward pass
    tracker = LayerwiseMemoryTracker()
    tracker.monitor(dist_model)
    fake_criterion(dist_model(fake_inputs), fake_targets).backward()
    tracker.stop()

    # Check results of all gathers tracking (feature specific to FSDP)
    # NOTE(review): 440 presumably = one Linear(10, 10) (100 weights + 10 biases)
    # in fp32 bytes; the module names depend on FSDP's internal wrapper naming.
    all_gathered_traces = [
        (t.module_name, t.all_gathered, t.cumul_all_gathered)
        for t in tracker.memory_traces
        if t.all_gathered > 0
    ]
    assert all_gathered_traces == [
        ("_fsdp_wrapped_module._fpw_module.0", 440, 440),
        ("_fsdp_wrapped_module._fpw_module.2._fsdp_wrapped_module._fpw_module", 440, 880),
        ("_fsdp_wrapped_module._fpw_module.4._fsdp_wrapped_module._fpw_module", 440, 880),
        ("_fsdp_wrapped_module._fpw_module.4._fsdp_wrapped_module._fpw_module", 440, 0),
        ("_fsdp_wrapped_module._fpw_module.2._fsdp_wrapped_module._fpw_module", 440, 0),
    ], all_gathered_traces
def _norm_computation_worker(gpu_id: int, sync_file: str, world_size: int):
    """Check that LARC_FSDP with a distributed norm matches DDP training.

    The same small MLP is trained for ``num_iterations`` steps twice, once under
    DDP with ``LARC_FSDP(distributed_norm=False)`` and once under FSDP with
    ``LARC_FSDP(distributed_norm=True)``, starting from identical seeds. Rank 0
    records the per-step losses and asserts that both runs agree. The comparison
    is exact on one process; with several processes, losses are rounded to 5 decimals.

    Args:
        gpu_id: CUDA device index, also used as the rank.
        sync_file: rendezvous file passed to ``init_distributed_on_file``.
        world_size: number of participating processes.
    """
    init_distributed_on_file(world_size=world_size, gpu_id=gpu_id, sync_file=sync_file)
    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True

    # Per-rank data: seeded with gpu_id so each rank sees different inputs.
    num_iterations = 10
    batch_size = 128
    torch.manual_seed(gpu_id)
    fake_inputs = torch.randn(size=(num_iterations, batch_size, 129))
    fake_targets = torch.randn(size=(num_iterations, batch_size))

    losses = {}
    for with_fsdp in [False, True]:
        # Re-seed so both runs start from identical weights.
        torch.manual_seed(0)
        torch.cuda.manual_seed(0)
        losses[with_fsdp] = []

        # Create a simple model
        model = nn.Sequential(nn.Linear(129, 128), nn.ReLU(), nn.Linear(128, 10))
        model = model.cuda(gpu_id)

        # Setting up FSDP vs DDP with LARC
        # NOTE(review): the optimizer is built on the unwrapped parameters before
        # wrapping; presumably this relies on flatten_parameters=False keeping
        # the original Parameter objects — confirm.
        larc_config = {
            "clip": False,
            "trust_coefficient": 0.01,
            "eps": 0.00000001
        }
        optimizer = optim.SGD(model.parameters(), lr=1e-2, weight_decay=1e-4, momentum=0.9)
        if with_fsdp:
            model = FullyShardedDataParallel(model, flatten_parameters=False)
            optimizer = LARC_FSDP(optimizer, distributed_norm=True, **larc_config)
        else:
            model = DistributedDataParallel(model, device_ids=[gpu_id])
            optimizer = LARC_FSDP(optimizer, distributed_norm=False, **larc_config)

        # Training loop
        criterion = nn.MSELoss()
        for iteration in range(num_iterations):
            fake_input = fake_inputs[iteration].cuda(gpu_id)
            fake_target = fake_targets[iteration].cuda(gpu_id)
            output = model(fake_input)
            # Collapse the 10 outputs per sample to one scalar to match the targets.
            loss = criterion(output.sum(axis=-1), fake_target)
            if gpu_id == 0:
                losses[with_fsdp].append(loss.item())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    # Only rank 0 recorded losses, so only rank 0 compares them.
    if gpu_id == 0:
        for with_fsdp in [False, True]:
            print(losses[with_fsdp])
            # Multi-process reductions may differ in the last bits: compare rounded.
            if world_size > 1:
                losses[with_fsdp] = [
                    round(loss, 5) for loss in losses[with_fsdp]
                ]
        assert losses[False] == losses[True]
def _layer_memory_tracking_worker(gpu_id: int, sync_file: str, world_size: int):
    """Check that LayerwiseMemoryTracker records FSDP all-gathers per module.

    A three-layer model is built with the middle and last ``Linear`` layers
    wrapped in their own FSDP instances and the whole model wrapped again at the
    root. All FSDP instances communicate through a ``ProcessGroupTracker``. One
    forward/backward pass is traced, and the ``(module_name, all_gathered,
    cumul_all_gathered)`` tuples of traces with non-zero all-gathers must match a
    fixed expected list.

    Args:
        gpu_id: CUDA device index, also used as the rank.
        sync_file: rendezvous file passed to ``init_distributed_on_file``.
        world_size: number of participating processes.
    """
    init_distributed_on_file(world_size=world_size, gpu_id=gpu_id, sync_file=sync_file)
    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True

    # Create different inputs on each GPU
    torch.manual_seed(gpu_id)
    batch_size = 16
    fake_inputs = torch.randn(size=(batch_size, 10)).cuda(gpu_id)
    fake_targets = torch.randn(size=(batch_size, 10)).cuda(gpu_id)
    fake_criterion = nn.MSELoss()

    # Same seed on every rank before building the model so the weights match
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)

    # Create a global group and a tracker around it
    group = dist.new_group()
    group = ProcessGroupTracker(group)

    # Create a simple model
    model = nn.Sequential(
        nn.Linear(10, 10).cuda(gpu_id),
        nn.ReLU(),
        FullyShardedDataParallel(
            nn.Linear(10, 10).cuda(gpu_id),
            flatten_parameters=False,
            process_group=group,
        ),
        nn.ReLU(),
        FullyShardedDataParallel(
            nn.Linear(10, 10).cuda(gpu_id),
            flatten_parameters=True,
            process_group=group,
        ),
    )
    model = model.cuda(gpu_id)
    model = FullyShardedDataParallel(model, flatten_parameters=False, process_group=group)

    # Setup the tracking of the model
    tracker = LayerwiseMemoryTracker()
    tracker.monitor(model)

    # Fake forward / backward pass
    fake_criterion(model(fake_inputs), fake_targets).backward()

    # Collect results of all gathers (the feature specific to FSDP)
    # NOTE(review): 440 presumably = one Linear(10, 10) (100 weights + 10 biases)
    # in fp32 bytes; the module names depend on FSDP's internal wrapper naming.
    tracker.stop()
    all_gathered_traces = [
        (t.module_name, t.all_gathered, t.cumul_all_gathered)
        for t in tracker.memory_traces
        if t.all_gathered > 0
    ]
    assert all_gathered_traces == [
        ("_fsdp_wrapped_module.0", 440, 440),
        ("_fsdp_wrapped_module.2._fsdp_wrapped_module", 440, 880),
        ("_fsdp_wrapped_module.4._fsdp_wrapped_module._fpw_module", 440, 880),
        ("_fsdp_wrapped_module.4._fsdp_wrapped_module._fpw_module", 440, 0),
        ("_fsdp_wrapped_module.2._fsdp_wrapped_module", 440, 0),
    ]