Exemplo n.º 1
0
def modify_forward(
    model: nn.Module,
    complexity_computer: ComplexityComputer,
    prefix: str = "",
    patch_attr: str = None,
) -> nn.Module:
    """
    Modify the forward pass of ``model`` (or of its submodules) to measure a
    module's parameters, like FLOPs.

    A module is patched, by replacing its ``__class__`` with the type built by
    ``_patched_computation_module``, when it is a leaf (per ``is_leaf``) or
    when ``patch_attr`` is given and the module has an attribute of that name
    (e.g. a ``flops`` method). A patched module is not descended into: it is
    supposed to report the results for all of its children as well. Modules
    that already expose ``orig_type`` are treated as patched and returned
    untouched, so a module reachable through several parents is only patched
    once.

    Args:
        model: module to patch in place.
        complexity_computer: forwarded to ``_patched_computation_module``.
        prefix: dotted name of ``model``. Each child is visited with
            ``f"{prefix}.{name}"``, so with the default ``""`` the child names
            start with a ".".
        patch_attr: optional attribute name that marks a non-leaf module as
            able to report for its whole subtree.

    Returns:
        ``model`` itself, modified in place.
    """
    # Already patched via another path through the module graph: leave it.
    if hasattr(model, "orig_type"):
        return model

    reports_for_subtree = is_leaf(model) or (
        patch_attr is not None and hasattr(model, patch_attr)
    )

    if not reports_for_subtree:
        # Keep descending until we reach modules that can be patched.
        for child_name, child_module in model.named_children():
            modify_forward(
                child_module,
                complexity_computer,
                prefix=f"{prefix}.{child_name}",
                patch_attr=patch_attr,
            )
        return model

    # This module accounts for its children too, so no further recursion.
    model.__class__ = _patched_computation_module(
        model, complexity_computer, prefix
    )
    return model
Exemplo n.º 2
0
def restore_forward(model):
    """
    Restore the original forward in ``model``.

    Every module in the tree that exposes ``orig_type`` (the class it had
    before patching) gets its ``__class__`` switched back to that type. The
    walk always continues into the children, so patched modules at any depth
    are restored.

    Modules without ``orig_type`` are skipped instead of raising
    ``AttributeError``. This makes the call safe on unpatched leaves, and safe
    when a module shared by several parents is visited more than once: a later
    visit finds it already restored, or peels the next patch layer if it was
    patched once per path.

    Args:
        model: module tree to restore in place.

    Returns:
        ``model`` itself.
    """
    # ``orig_type`` is the marker left by patching; only patched modules can
    # (and need to) be restored. Exactly one patch layer is removed per visit.
    if hasattr(model, "orig_type"):
        model.__class__ = model.orig_type
    for child in model.children():
        restore_forward(child)

    return model
Exemplo n.º 3
0
def modify_forward(model, compute_list, compute_fn):
    """
    Modify forward pass to measure a module's parameters, like FLOPs.

    Each leaf module (per ``is_leaf``) reachable from ``model`` has its class
    replaced by the type returned by
    ``_patched_computation_module(module, compute_list, compute_fn)``. Every
    child of every module is visited, including the children of a module that
    was just patched.

    NOTE(review): a module shared by several parents is visited once per
    parent and would be patched once per visit; confirm whether such sharing
    occurs in the models this is applied to.

    Returns:
        ``model`` itself, modified in place.
    """
    if is_leaf(model):
        patched_type = _patched_computation_module(model, compute_list, compute_fn)
        model.__class__ = patched_type

    for submodule in model.children():
        modify_forward(submodule, compute_list, compute_fn)

    return model
Exemplo n.º 4
0
def modify_forward(model, flops_list):
    """
    Modify forward pass to measure FLOPs.

    Each leaf module (per ``is_leaf``) in the tree rooted at ``model`` has its
    class swapped for the type returned by ``_flops_module(module, flops_list)``.
    Every child of every module is visited.

    NOTE(review): a module shared by several parents is visited once per
    parent and would be wrapped once per visit; confirm this cannot happen
    for the models measured here.

    Returns:
        ``model`` itself, modified in place.
    """
    if is_leaf(model):
        flops_type = _flops_module(model, flops_list)
        model.__class__ = flops_type

    for sub in model.children():
        modify_forward(sub, flops_list)

    return model
Exemplo n.º 5
0
def restore_forward(model):
    """
    Restore the original forward in ``model``.

    The tree is walked top-down. A module that exposes ``orig_type`` (the
    class it had before patching) is switched back to that class and is not
    descended into, because a patched module stands for its whole subtree.
    Any other module is descended into, so patched modules below an unpatched
    container are still restored.

    Deciding on ``orig_type`` rather than on "is a leaf or has ``flops``"
    avoids two problems:

    * ``AttributeError`` on unpatched leaves, and on shared modules visited a
      second time after being restored.
    * Silently skipping the subtree of an unpatched module that happens to
      define ``flops``.

    Args:
        model: module tree to restore in place.

    Returns:
        ``model`` itself.
    """
    if hasattr(model, "orig_type"):
        # Remove exactly one patch layer; the patched module covered its
        # children, so they were not patched separately.
        model.__class__ = model.orig_type
    else:
        for child in model.children():
            restore_forward(child)

    return model
Exemplo n.º 6
0
def count_params(model):
    """
    Count the number of parameters in a model.

    The module tree is walked recursively. Each leaf module (per ``is_leaf``)
    contributes one of two amounts:

    * For masked modules (like LGC) that have a ``_mask`` attribute: the sum of
      the mask cast to integers, i.e. the number of set entries for a 0/1 or
      boolean mask.
    * For any other leaf: the total element count of its parameters.

    Parameters registered directly on a non-leaf module (not inside any child)
    are counted too. If ``model`` is itself a leaf, its own parameters are
    counted instead of returning 0.

    Args:
        model: the ``nn.Module`` to inspect.

    Returns:
        The parameter count.

    Raises:
        AssertionError: if ``model`` is not an ``nn.Module``.
    """
    assert isinstance(model, nn.Module)

    if is_leaf(model):
        if hasattr(model, "_mask"):  # for masked modules (like LGC)
            # FIXME: BatchNorm parameters in LGC are not counted.
            return model._mask.long().sum().item()
        # for regular modules
        return sum(p.nelement() for p in model.parameters())

    # Parameters owned directly by this container. recurse=False excludes the
    # children's parameters, which are counted by the recursion below.
    count = sum(p.nelement() for p in model.parameters(recurse=False))
    for child in model.children():
        count += count_params(child)
    return count