import torch
from torch.nn import Module
from torch.nn.functional import normalize


def apply(
    module: Module,
    name: str,
    coeff: float,
    n_power_iterations: int,
    dim: int,
    eps: float,
) -> "SpectralNorm":
    for k, hook in module._forward_pre_hooks.items():
        if isinstance(hook, SpectralNorm) and hook.name == name:
            raise RuntimeError("Cannot register two spectral_norm hooks on "
                               "the same parameter {}".format(name))

    fn = SpectralNorm(coeff, name, n_power_iterations, dim, eps)
    weight = module._parameters[name]

    with torch.no_grad():
        weight_mat = fn.reshape_weight_to_matrix(weight)
        h, w = weight_mat.size()
        # Randomly initialize the singular-vector estimates `u` and `v`.
        u = normalize(weight.new_empty(h).normal_(0, 1), dim=0, eps=fn.eps)
        v = normalize(weight.new_empty(w).normal_(0, 1), dim=0, eps=fn.eps)

    delattr(module, fn.name)
    module.register_parameter(fn.name + "_orig", weight)
    # We still need to assign the weight back as fn.name because all sorts of
    # things may assume that it exists, e.g., when initializing weights.
    # However, we can't assign it directly: it could be an nn.Parameter and
    # would then be added as a parameter. Instead, we register weight.data as
    # a plain attribute.
    setattr(module, fn.name, weight.data)
    module.register_buffer(fn.name + "_u", u)
    module.register_buffer(fn.name + "_v", v)
    module.register_buffer(fn.name + "_sigma", torch.ones(1).to(weight.device))

    module.register_forward_pre_hook(fn)
    module._register_state_dict_hook(SpectralNormStateDictHook(fn))
    module._register_load_state_dict_pre_hook(SpectralNormLoadStateDictPreHook(fn))
    return fn
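# Usage sketch (not part of the original source): a hypothetical convenience
# wrapper, analogous in spirit to torch.nn.utils.spectral_norm, that registers
# the hook via `apply` above. The name `spectral_norm_fc` and the default
# values are assumptions for illustration; `apply` and `remove` are assumed to
# be methods of the `SpectralNorm` class, as the `self` parameter of `remove`
# and the hook registration suggest.
def spectral_norm_fc(module: Module, coeff: float = 0.97, name: str = "weight",
                     n_power_iterations: int = 1, dim: int = 0,
                     eps: float = 1e-12) -> Module:
    SpectralNorm.apply(module, name, coeff, n_power_iterations, dim, eps)
    return module


# Example: constrain a linear layer's spectral norm to at most 0.9. After
# registration, the module exposes `weight_orig`, `weight_u`, `weight_v`, and
# `weight_sigma`, and recomputes `weight` in a forward pre-hook on every call.
#
#   layer = spectral_norm_fc(torch.nn.Linear(20, 40), coeff=0.9)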
def remove(self, module: Module) -> None:
    with torch.no_grad():
        weight = self.compute_weight(module, do_power_iteration=False)
    delattr(module, self.name)
    delattr(module, self.name + "_u")
    delattr(module, self.name + "_v")
    delattr(module, self.name + "_orig")
    module.register_parameter(self.name, torch.nn.Parameter(weight.detach()))
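# Usage sketch (not part of the original source): a hypothetical helper,
# mirroring torch.nn.utils.remove_spectral_norm, that locates the registered
# SpectralNorm forward pre-hook and calls `remove` above to restore a plain
# `weight` parameter. The name `remove_spectral_norm_fc` is an assumption.
def remove_spectral_norm_fc(module: Module, name: str = "weight") -> Module:
    for k, hook in module._forward_pre_hooks.items():
        if isinstance(hook, SpectralNorm) and hook.name == name:
            hook.remove(module)
            del module._forward_pre_hooks[k]
            return module
    raise ValueError("spectral_norm of '{}' not found in {}".format(name, module))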