def modify_forward(
    model: nn.Module,
    complexity_computer: ComplexityComputer,
    prefix: str = "",
    patch_attr: str = None,
) -> nn.Module:
    """
    Patch ``model`` in place so its forward pass measures a per-module
    quantity such as FLOPs.

    A module is patched when ``is_leaf`` reports it as a leaf, or when
    ``patch_attr`` is given and the module has an attribute of that name
    (for instance its own ``flops()`` function). Patching swaps the module's
    class for the one returned by
    ``_patched_computation_module(module, complexity_computer, prefix)``.
    A patched module is expected to report for all of its children, so the
    walk does not descend below it. Any other module has its named children
    visited, with ``prefix`` extended by ``"." + child_name``.

    A module can be reachable along several paths. Anything already carrying
    ``orig_type`` has been patched before and is returned untouched.

    Args:
        model: module to patch, modified in place.
        complexity_computer: forwarded to ``_patched_computation_module``.
        prefix: dotted path of ``model``, forwarded to the patcher.
        patch_attr: optional attribute name that marks a non-leaf module as
            able to report for its whole subtree.

    Returns:
        ``model`` itself.
    """
    # Already patched through another path: do not wrap it a second time.
    if hasattr(model, "orig_type"):
        return model

    patch_here = is_leaf(model) or (
        patch_attr is not None and hasattr(model, patch_attr)
    )
    if patch_here:
        model.__class__ = _patched_computation_module(
            model, complexity_computer, prefix
        )
        return model

    for child_name, submodule in model.named_children():
        modify_forward(
            submodule,
            complexity_computer,
            prefix=f"{prefix}.{child_name}",
            patch_attr=patch_attr,
        )
    return model
def restore_forward(model):
    """
    Restore the original forward of every patched module inside ``model``.

    A module counts as patched when it carries an ``orig_type`` attribute,
    which holds the class it had before patching. Every such module, leaf
    or not, gets ``orig_type`` assigned back to its ``__class__``. Modules
    without the marker are skipped rather than raising ``AttributeError``,
    so unpatched parts of the model (or a model that was never patched)
    are handled safely.

    The walk is pre-order: a module is restored first, then each of its
    children is visited recursively.

    Args:
        model: root module to restore, modified in place.

    Returns:
        ``model`` itself.
    """
    if hasattr(model, "orig_type"):
        model.__class__ = model.orig_type
    for child in model.children():
        restore_forward(child)
    return model
def modify_forward(model, compute_list, compute_fn):
    """
    Patch ``model`` in place so that leaf modules measure a quantity such as
    FLOPs during their forward pass.

    Each module that ``is_leaf`` reports as a leaf has its class replaced by
    ``_patched_computation_module(module, compute_list, compute_fn)``. The
    traversal then continues into the module's children whether or not it
    was patched.

    Returns:
        ``model`` itself.
    """
    if is_leaf(model):
        patched_cls = _patched_computation_module(model, compute_list, compute_fn)
        model.__class__ = patched_cls

    for submodule in model.children():
        modify_forward(submodule, compute_list, compute_fn)

    return model
def modify_forward(model, flops_list):
    """
    Patch ``model`` in place so that leaf modules measure FLOPs in their
    forward pass.

    Each module that ``is_leaf`` reports as a leaf has its class replaced by
    ``_flops_module(module, flops_list)``. ``flops_list`` is handed to that
    helper unchanged (presumably it collects the measurements; confirm in
    ``_flops_module``). Children are visited recursively whether or not the
    current module was patched.

    Returns:
        ``model`` itself.
    """
    if is_leaf(model):
        flops_cls = _flops_module(model, flops_list)
        model.__class__ = flops_cls

    for submodule in model.children():
        modify_forward(submodule, flops_list)

    return model
def restore_forward(model):
    """
    Restore the original forward of every patched module inside ``model``.

    Patched modules are recognised by their ``orig_type`` attribute, which
    holds the pre-patch class, and are switched back to it.

    The previous version decided by shape instead: it restored any leaf or
    any module defining ``flops`` and did not descend further. A non-leaf
    module with ``flops()`` whose leaves were patched, but which was not
    patched itself, then raised ``AttributeError`` and left its children
    patched. Keying on the marker and always visiting children handles both
    patching schemes: leaves only, or ``flops`` modules stopping the
    descent. Unpatched modules are left untouched.

    Args:
        model: root module to restore, modified in place.

    Returns:
        ``model`` itself.
    """
    if hasattr(model, "orig_type"):
        model.__class__ = model.orig_type
    for child in model.children():
        restore_forward(child)
    return model
def count_params(model):
    """
    Count the number of parameters in a model.

    A leaf module (as decided by ``is_leaf``) contributes one of two counts:

    - masked modules (like LGC) that carry a ``_mask``: the sum of that
      mask cast to ``long``, i.e. the number of set entries for a 0/1 mask;
    - any other leaf: the total element count of all its parameters.

    A non-leaf module contributes the parameters registered directly on it,
    plus the recursive counts of its children.

    Args:
        model: module to inspect.

    Returns:
        The parameter count as an ``int``.

    Raises:
        AssertionError: if ``model`` is not an ``nn.Module``.
    """
    assert isinstance(model, nn.Module)

    # A leaf passed in directly is counted the same way as a leaf child.
    # Previously only children were inspected, so a bare leaf contributed
    # nothing of its own.
    if is_leaf(model):
        if hasattr(model, "_mask"):  # for masked modules (like LGC)
            # FIXME: BatchNorm parameters in LGC are not counted.
            return model._mask.long().sum().item()
        return sum(p.nelement() for p in model.parameters())

    # Parameters owned by the container itself (e.g. an nn.Parameter
    # attribute) are not reached by the per-child recursion below.
    count = sum(p.nelement() for p in model.parameters(recurse=False))
    for child in model.children():
        count += count_params(child)
    return count