def apply(module: Module, name: str, n_power_iterations: int, dim: int,
          eps: float, L: float) -> 'SpectralNorm':
    for k, hook in module._forward_pre_hooks.items():
        if isinstance(hook, SpectralNorm) and hook.name == name:
            raise RuntimeError(
                "Cannot register two spectral_norm hooks on "
                "the same parameter {}".format(name))

    fn = SpectralNorm(name, n_power_iterations, dim, eps, L=L)
    weight = module._parameters[name]

    with torch.no_grad():
        weight_mat = fn.reshape_weight_to_matrix(weight)
        h, w = weight_mat.size()
        # randomly initialize `u` and `v`
        u = normalize(weight.new_empty(h).normal_(0, 1), dim=0, eps=fn.eps)
        v = normalize(weight.new_empty(w).normal_(0, 1), dim=0, eps=fn.eps)

    delattr(module, fn.name)
    module.register_parameter(fn.name + "_orig", weight)
    # We still need to assign weight back as fn.name because all sorts of
    # things may assume that it exists, e.g., when initializing weights.
    # However, we can't directly assign as it could be an nn.Parameter and
    # gets added as a parameter. Instead, we register weight.data as a plain
    # attribute.
    setattr(module, fn.name, weight.data)
    module.register_buffer(fn.name + "_u", u)
    module.register_buffer(fn.name + "_v", v)

    module.register_forward_pre_hook(fn)
    module._register_state_dict_hook(SpectralNormStateDictHook(fn))
    module._register_load_state_dict_pre_hook(
        SpectralNormLoadStateDictPreHook(fn))
    return fn
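
# Usage sketch for the variant above (a minimal example, assuming `apply` is
# a @staticmethod on a `SpectralNorm` class patterned after
# torch.nn.utils.spectral_norm, with `L` as this variant's extra Lipschitz
# bound; the layer and hyperparameter values below are illustrative only):
#
#     import torch
#     from torch import nn
#
#     linear = nn.Linear(20, 40)
#     SpectralNorm.apply(linear, name='weight', n_power_iterations=1,
#                        dim=0, eps=1e-12, L=1.0)
#     # `weight` is now recomputed from `weight_orig` by the pre-hook on
#     # every forward call, using the power-iteration buffers `weight_u`
#     # and `weight_v` registered above.
#     out = linear(torch.randn(8, 20))
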
def apply(module: Module, name: str, input_shape, n_power_iterations: int,
          dim: int, eps: float) -> 'SpectralNorm':
    for k, hook in module._forward_pre_hooks.items():
        if isinstance(hook, SpectralNorm) and hook.name == name:
            raise RuntimeError(
                "Cannot register two spectral_norm hooks on "
                "the same parameter {}".format(name))

    fn = SpectralNorm(name, n_power_iterations, dim, eps)
    weight = module._parameters[name]

    with torch.no_grad():
        # Unlike the dense variant above, `v` lives in the module's input
        # space (shape `(1, *input_shape)`) rather than being derived from
        # the reshaped weight matrix; it is normalized as a flat vector.
        v = weight.new_empty(1, *input_shape).normal_(0, 1)
        v = normalize(v.flatten(), dim=0, eps=fn.eps).reshape(1, *input_shape)

    delattr(module, fn.name)
    module.register_parameter(fn.name + "_orig", weight)
    # We still need to assign weight back as fn.name because all sorts of
    # things may assume that it exists, e.g., when initializing weights.
    # However, we can't directly assign as it could be an nn.Parameter and
    # gets added as a parameter. Instead, we register weight.data as a plain
    # attribute.
    setattr(module, fn.name, weight.data)
    # No `u` buffer here: only `v` is tracked, in the input space.
    module.register_buffer(fn.name + "_v", v)

    module.register_forward_pre_hook(fn)
    module._register_state_dict_hook(SpectralNormStateDictHook(fn))
    module._register_load_state_dict_pre_hook(
        SpectralNormLoadStateDictPreHook(fn))
    return fn
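
# Usage sketch (a minimal example, assuming this variant is applied to a
# layer whose per-sample input shape is known ahead of time; storing `v`
# with the input's shape suggests the power iteration runs through the
# module's own forward and transpose operations rather than through the
# reshaped weight matrix):
#
#     import torch
#     from torch import nn
#
#     conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
#     SpectralNorm.apply(conv, name='weight', input_shape=(3, 32, 32),
#                        n_power_iterations=1, dim=0, eps=1e-12)
#     out = conv(torch.randn(4, 3, 32, 32))
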
def apply(module: Module, name: str, n_power_iterations: int, dim: int,
          eps: float) -> 'QSpectralNorm':
    for k, hook in module._forward_pre_hooks.items():
        if isinstance(hook, QSpectralNorm) and hook.name == name:
            raise RuntimeError(
                "Cannot register two spectral_norm hooks on "
                "the same parameter {}".format(name))

    fn = QSpectralNorm(name, n_power_iterations, dim, eps)
    weight_r = module._parameters['r_weight']
    weight_i = module._parameters['i_weight']
    weight_j = module._parameters['j_weight']
    weight_k = module._parameters['k_weight']

    # Assemble the real-valued block matrix of the quaternion weight in the
    # Hamilton product layout: each row of blocks produces one component of
    # the product, so the spectral norm of this matrix is the operator norm
    # of the full quaternion layer.
    cat_kernels_4_r = torch.cat(
        [weight_r, -weight_i, -weight_j, -weight_k], dim=1)
    cat_kernels_4_i = torch.cat(
        [weight_i, weight_r, -weight_k, weight_j], dim=1)
    cat_kernels_4_j = torch.cat(
        [weight_j, weight_k, weight_r, -weight_i], dim=1)
    cat_kernels_4_k = torch.cat(
        [weight_k, -weight_j, weight_i, weight_r], dim=1)
    weight = torch.cat(
        [cat_kernels_4_r, cat_kernels_4_i, cat_kernels_4_j, cat_kernels_4_k],
        dim=0)

    with torch.no_grad():
        weight_mat = fn.reshape_weight_to_matrix(weight)
        h, w = weight_mat.size()
        # randomly initialize `u` and `v`
        u = normalize(weight.new_empty(h).normal_(0, 1), dim=0, eps=fn.eps)
        v = normalize(weight.new_empty(w).normal_(0, 1), dim=0, eps=fn.eps)

    # Re-register each quaternion component under `*_orig`, mirroring the
    # single-parameter case above.
    delattr(module, 'r_weight')
    delattr(module, 'i_weight')
    delattr(module, 'j_weight')
    delattr(module, 'k_weight')
    module.register_parameter('r_weight_orig', weight_r)
    module.register_parameter('i_weight_orig', weight_i)
    module.register_parameter('j_weight_orig', weight_j)
    module.register_parameter('k_weight_orig', weight_k)
    # We still need to assign the weights back under their original names
    # because all sorts of things may assume that they exist, e.g., when
    # initializing weights. However, we can't directly assign as they could
    # be nn.Parameters and get added as parameters. Instead, we register
    # each .data as a plain attribute.
    setattr(module, 'r_weight', weight_r.data)
    setattr(module, 'i_weight', weight_i.data)
    setattr(module, 'j_weight', weight_j.data)
    setattr(module, 'k_weight', weight_k.data)
    module.register_buffer(fn.name + "_u", u)
    module.register_buffer(fn.name + "_v", v)

    module.register_forward_pre_hook(fn)
    module._register_state_dict_hook(QSpectralNormStateDictHook(fn))
    module._register_load_state_dict_pre_hook(
        QSpectralNormLoadStateDictPreHook(fn))
    return fn
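
# Usage sketch (a minimal example, assuming a quaternion layer that exposes
# `r_weight`, `i_weight`, `j_weight`, and `k_weight` parameters; the
# `QuaternionLinear` class and its constructor arguments here are
# hypothetical stand-ins for whatever quaternion module this hook targets):
#
#     qlinear = QuaternionLinear(in_features=32, out_features=32)
#     QSpectralNorm.apply(qlinear, name='weight', n_power_iterations=1,
#                         dim=0, eps=1e-12)
#     # On every forward pass the pre-hook rebuilds the Hamilton block
#     # matrix from the four `*_orig` components and renormalizes it by
#     # its power-iteration estimate of the spectral norm.
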