def modify_forward(
    model: nn.Module,
    complexity_computer: ComplexityComputer,
    prefix: str = "",
    patch_attr: str = None,
) -> nn.Module:
    """
    Patch the forward pass of ``model`` so a per-module quantity (e.g. FLOPs)
    can be measured.

    The module tree is walked recursively. A module is patched when it is a
    leaf, or when ``patch_attr`` is given and the module exposes an attribute
    of that name (like a ``flops()`` function for FLOPs computation). A patched
    module is expected to report results for all of its children, so the walk
    does not descend below it.

    Args:
        model: Root of the module (sub)tree to process; modified in place.
        complexity_computer: Forwarded unchanged to
            ``_patched_computation_module``.
        prefix: Name passed along when ``model`` itself gets patched. Children
            are visited with ``"<prefix>.<child name>"``.
        patch_attr: Optional attribute name that marks a non-leaf module as
            patchable.

    Returns:
        The same ``model`` object that was passed in.
    """
    # The same module can be reached through several paths. A module carrying
    # ``orig_type`` is treated as already patched and left untouched.
    if hasattr(model, "orig_type"):
        return model

    should_patch = is_leaf(model) or (
        patch_attr is not None and hasattr(model, patch_attr)
    )

    if not should_patch:
        # Not patchable itself: keep descending, extending the dotted prefix.
        for child_name, submodule in model.named_children():
            modify_forward(
                submodule,
                complexity_computer,
                prefix=f"{prefix}.{child_name}",
                patch_attr=patch_attr,
            )
        return model

    # Swap the instance's class for the patched variant; recursion stops here.
    model.__class__ = _patched_computation_module(
        model, complexity_computer, prefix
    )
    return model
def dql(model: nn.Module, ts: TrainingSet, epochs: Tuple[int, int], normed: bool = False):
    """Double Q-learning training.

    Args:
        model: Network being trained; its class must be constructible with no
            arguments, because each outer round instantiates a fresh copy.
        ts: Training set, turned into a loader via ``ts_loader``.
        epochs: ``(outer, inner)`` counts. ``epochs[0]`` is the number of times
            the goal network is re-created from ``model``; ``epochs[1]`` is the
            number of ``ql_epoch`` calls made against each goal network.
        normed: Forwarded unchanged to ``ql_epoch``.
    """
    batches = ts_loader(ts)
    adam = optim.Adam(model.parameters(), lr=1e-3)
    ql_init(model, batches, adam)

    for _outer in tqdm.tqdm(range(epochs[0])):
        # Build a new instance of the same class and load the current weights
        # into it; it serves as the goal network for the inner epochs.
        target_net = model.__class__()
        current_weights = model.state_dict()
        target_net.load_state_dict(current_weights)

        for _inner in range(epochs[1]):
            ql_epoch(model, target_net, batches, adam, normed)
def _reduce_last_block_by_factor(inv_residual: nn.Module, factor: int, dilation: int):
    """Construct a new block of the same class with reduced channel counts.

    Constructor arguments are recovered from the existing block: for every
    parameter name of its ``__init__``, the instance attribute with the same
    name is used when it exists (names without a matching attribute are
    skipped). ``dilation`` is then overridden, and ``exp_c`` / ``out_c`` are
    floor-divided by ``factor``.

    Args:
        inv_residual: Block to derive from; must be an ``InvertedResidual``.
        factor: Divisor applied to ``exp_c`` and ``out_c``.
        dilation: Dilation value passed to the new block.

    Returns:
        A newly constructed instance of ``inv_residual``'s class. Only
        constructor arguments are carried over by this function.

    Raises:
        AssertionError: If ``inv_residual`` is not an ``InvertedResidual``.
        KeyError: If the block has no ``exp_c`` or ``out_c`` attribute that
            matches an ``__init__`` parameter.
    """
    assert isinstance(inv_residual, InvertedResidual), (
        "Block is not of type InvertedResidual"
    )

    init_params = inspect.getfullargspec(inv_residual.__init__).args
    ctor_kwargs = {
        param: getattr(inv_residual, param)
        for param in init_params
        if hasattr(inv_residual, param)
    }

    # Override the pieces that differ from the source block.
    ctor_kwargs.update(
        dilation=dilation,
        exp_c=int(ctor_kwargs["exp_c"] // factor),
        out_c=int(ctor_kwargs["out_c"] // factor),
    )

    return inv_residual.__class__(**ctor_kwargs)
def _make_subclass(cls, module: Module): if not isinstance(module, cls): module.__class__ = type("Dependent" + type(module).__name__, (cls, type(module)), {}) module._reinit() return module