def replace_in(model: nn.Module, target: nn.Module, replacement: nn.Module) -> bool:
    """Replace the first occurrence of ``target`` found below ``model``.

    Children are visited depth-first in ``named_children()`` order. The first
    child that *is* ``target`` (identity, not equality) is swapped for
    ``replacement`` inside its direct parent, and the search stops.

    Args:
        model: Root module whose descendants are searched. ``model`` itself is
            never compared against ``target``.
        target: The exact module instance to look for.
        replacement: Module installed in the slot ``target`` occupied.

    Returns:
        True if a replacement was made, False if ``target`` was not found.
    """
    for name, submodule in model.named_children():
        if submodule is target:
            # Every container (ModuleList, Sequential, ModuleDict, plain member)
            # registers children in ``_modules`` under their string name, and
            # ModuleList/Sequential ``__setitem__`` themselves end in
            # ``setattr(self, key, module)``. A plain setattr therefore puts the
            # replacement in the same slot. Indexing a ModuleList with the
            # *string* ``name`` raises TypeError, so it must not be used.
            setattr(model, name, replacement)
            return True
        # Recursing into a leaf is harmless: it has no children and returns False.
        if replace_in(submodule, target, replacement):
            return True
    return False
def __setattr__(self, name, value):
    """Set an attribute while ignoring the recursive aspect of Pipes.

    ``_flags`` and ``input`` are written straight onto the instance with
    ``object.__setattr__``. Any other attribute is first mirrored into
    ``self.components`` (once ``self._flags['components_initialized']`` is
    truthy). It is then stored through ``Module.__setattr__`` while
    ``self._flags['recursive_get']`` is temporarily 0.
    """
    if name in ('_flags', 'input'):
        # Bypass Module bookkeeping and the components mirror for these names.
        object.__setattr__(self, name, value)
        return
    if self._flags['components_initialized']:
        self.components[name] = value
    # NOTE(review): recursive_get is presumably consulted by this class's
    # attribute lookup; it is switched off for the duration of the Module call.
    self._flags['recursive_get'] = 0
    try:
        Module.__setattr__(self, name, value)
    finally:
        # Restore even if Module.__setattr__ raises (e.g. TypeError when a
        # non-Parameter is assigned to a registered parameter name). Otherwise
        # the flag would stay 0 for the rest of the object's life.
        self._flags['recursive_get'] = 1
def __setattr__(self, name: str, value: Any):
    """Dispatch attribute assignment on the attribute name.

    ``model`` is written directly onto the instance via ``object.__setattr__``,
    skipping ``Module.__setattr__``'s bookkeeping (so, e.g., a Module value is
    not registered as a child of ``self``). Every other name is handled by
    ``Module.__setattr__`` as usual.
    """
    if name == 'model':
        object.__setattr__(self, name, value)
        return
    Module.__setattr__(self, name, value)