def is_valid_fsdp_model(model: FSDP) -> bool:
    """
    Tell whether the FSDP modules nested inside ``model`` are in a valid state.

    Every ``(name, module)`` pair produced by ``model.named_modules()`` is
    inspected. The model is reported as invalid when a FSDP module with a
    non-empty name (i.e. a nested module rather than the top-level entry) has
    a ``_is_root`` attribute that is anything other than ``None``: such a
    sub-module has already had its root status decided.

    Returns:
        ``True`` when no nested FSDP module has ``_is_root`` set,
        ``False`` otherwise.
    """
    # NOTE(review): the check is "is not None", so a nested module whose
    # _is_root is False is reported as invalid too - confirm this is intended.
    has_flagged_sub_module = any(
        isinstance(sub_module, FSDP)
        and qualified_name != ""
        and sub_module._is_root is not None
        for qualified_name, sub_module in model.named_modules()
    )
    return not has_flagged_sub_module
def fsdp_recursive_reset_lazy_init(fsdp_module: FSDP):
    """
    Before the first forward pass, an FSDP module might have been initialized
    for instance by calling load_state_dict or load_local_state_dict to
    reload a previous training checkpoint.

    This function walks through ``fsdp_module`` and all of its sub-modules
    and calls ``_reset_lazy_init`` on each FSDP module whose ``_is_root``
    attribute is not ``None``.

    Args:
        fsdp_module: module whose whole tree of sub-modules is walked.
    """
    # named_modules() already yields the module itself and every descendant,
    # so one flat pass is enough: re-queueing the descendants of each visited
    # module would only revisit the same module objects over and over (a
    # number of visits growing exponentially with the nesting depth).
    #
    # The listing is materialized before any reset happens and is walked back
    # to front, so that descendants are handled before their ancestors.
    for _, module in reversed(list(fsdp_module.named_modules())):
        if isinstance(module, FSDP) and module._is_root is not None:
            module._reset_lazy_init()