コード例 #1
0
ファイル: profiler.py プロジェクト: yangchen918/ClassyVision
def modify_forward(
    model: nn.Module,
    complexity_computer: ComplexityComputer,
    prefix: str = "",
    patch_attr: str = None,
) -> nn.Module:
    """
    Patch the modules of ``model`` so a complexity measure (e.g. FLOPs) can be
    collected during the forward pass.

    The module tree is walked depth-first. A module is patched in place (its
    ``__class__`` is replaced) when it is a leaf or when it exposes the
    attribute named by ``patch_attr``; a patched module is expected to report
    for all of its children, so the walk does not descend below it.

    Args:
        model: root of the (sub)tree to process.
        complexity_computer: forwarded to ``_patched_computation_module``.
        prefix: dotted name accumulated so far; each child is visited with
            ``f"{prefix}.{name}"``.
        patch_attr: optional attribute name that marks a non-leaf module as
            patchable (like ``flops`` for FLOPs computation).

    Returns:
        The same ``model`` object that was passed in.
    """
    # The same module can be reached through several paths; the original
    # comment treats the presence of ``orig_type`` as "already patched".
    if hasattr(model, "orig_type"):
        return model

    has_patch_attr = patch_attr is not None and hasattr(model, patch_attr)
    if is_leaf(model) or has_patch_attr:
        # Swap the class in place and stop recursing here.
        patched_cls = _patched_computation_module(model, complexity_computer, prefix)
        model.__class__ = patched_cls
        return model

    # Neither a leaf nor explicitly patchable: descend into the children.
    for child_name, child_module in model.named_children():
        modify_forward(
            child_module,
            complexity_computer,
            prefix=f"{prefix}.{child_name}",
            patch_attr=patch_attr,
        )
    return model
コード例 #2
0
ファイル: section5.py プロジェクト: francois-rozet/info8003-1
def dql(model: nn.Module,
        ts: TrainingSet,
        epochs: Tuple[int, int],
        normed: bool = False):
    '''Double Q-learning training.

    Builds a loader from ``ts`` and an Adam optimizer (lr=0.001) over the
    model's parameters, runs ``ql_init`` once, then performs ``epochs[0]``
    outer rounds. Each round takes a fresh snapshot of ``model`` and runs
    ``epochs[1]`` calls of ``ql_epoch`` against that snapshot.

    The snapshot is created with ``model.__class__()``, so the model class
    must be constructible without arguments.
    '''

    batches = ts_loader(ts)
    adam = optim.Adam(model.parameters(), lr=0.001)

    ql_init(model, batches, adam)

    outer_rounds = tqdm.tqdm(range(epochs[0]))
    for _outer in outer_rounds:
        # Snapshot of the current weights; it is not updated again until
        # the next outer round replaces it.
        snapshot = model.__class__()
        snapshot.load_state_dict(model.state_dict())

        for _inner in range(epochs[1]):
            ql_epoch(model, snapshot, batches, adam, normed)
コード例 #3
0
    def _reduce_last_block_by_factor(inv_residual: nn.Module, factor: int,
                                     dilation: int):
        """Build a new block of the same class with reduced channel counts.

        Constructor arguments are recovered from attributes of
        ``inv_residual`` that share a name with a parameter of its
        ``__init__``. ``dilation`` is then overridden, and ``exp_c`` and
        ``out_c`` are floor-divided by ``factor``.

        Returns:
            A freshly constructed instance of ``inv_residual``'s class; the
            input block itself is not modified.

        Raises:
            AssertionError: if ``inv_residual`` is not an InvertedResidual.
            KeyError: if ``exp_c`` or ``out_c`` could not be recovered.
        """
        assert isinstance(inv_residual, InvertedResidual), \
            "Block is not of type InvertedResidual"

        init_params = inspect.getfullargspec(inv_residual.__init__).args
        # Keep only the parameters mirrored as attributes on the instance;
        # names with no matching attribute are skipped.
        module_args = {
            param: getattr(inv_residual, param)
            for param in init_params
            if hasattr(inv_residual, param)
        }

        module_args["dilation"] = dilation
        for channel_key in ("exp_c", "out_c"):
            module_args[channel_key] = int(module_args[channel_key] // factor)

        return inv_residual.__class__(**module_args)
コード例 #4
0
ファイル: dependentmodule.py プロジェクト: yhqjohn/MetaNN
 def _make_subclass(cls, module: Module):
     if not isinstance(module, cls):
         module.__class__ = type("Dependent" + type(module).__name__, (cls, type(module)), {})
         module._reinit()
     return module