Example #1
0
    def apply(module: Module, name: str, n_power_iterations: int, dim: int,
              eps: float, L: float) -> 'SpectralNorm':
        """Install a ``SpectralNorm`` forward pre-hook for one parameter of ``module``.

        The parameter stored under ``name`` is re-registered as
        ``name + "_orig"``, a plain (non-parameter) attribute ``name`` is put
        back in its place, and two random unit vectors are registered as the
        buffers ``name + "_u"`` and ``name + "_v"``.

        Args:
            module: Module whose parameter is being wrapped.
            name: Key of the parameter in ``module._parameters``.
            n_power_iterations: Forwarded unchanged to ``SpectralNorm``.
            dim: Forwarded unchanged to ``SpectralNorm``.
            eps: Forwarded to ``SpectralNorm``; ``fn.eps`` is then used when
                normalizing the initial ``u`` and ``v``.
            L: Forwarded to ``SpectralNorm`` as the keyword ``L``.
                NOTE(review): presumably a target norm / Lipschitz bound --
                confirm against the ``SpectralNorm`` constructor.

        Returns:
            The ``SpectralNorm`` instance that was registered as the hook.

        Raises:
            RuntimeError: If a ``SpectralNorm`` pre-hook with the same
                ``name`` is already registered on ``module``.
        """
        already_hooked = any(
            isinstance(existing, SpectralNorm) and existing.name == name
            for existing in module._forward_pre_hooks.values())
        if already_hooked:
            raise RuntimeError(
                "Cannot register two spectral_norm hooks on "
                "the same parameter {}".format(name))

        hook_fn = SpectralNorm(name, n_power_iterations, dim, eps, L=L)
        param = module._parameters[name]

        with torch.no_grad():
            rows, cols = hook_fn.reshape_weight_to_matrix(param).size()
            # Random starting vectors, one per side of the weight matrix;
            # `u` is drawn before `v`.
            u_init = normalize(
                param.new_empty(rows).normal_(0, 1), dim=0, eps=hook_fn.eps)
            v_init = normalize(
                param.new_empty(cols).normal_(0, 1), dim=0, eps=hook_fn.eps)

        # Move the real parameter to "<name>_orig".
        delattr(module, hook_fn.name)
        module.register_parameter(hook_fn.name + "_orig", param)
        # Other code (e.g. weight initialisation) may still look the weight up
        # under its original name, so keep that attribute alive. Assigning the
        # nn.Parameter itself would register it as a parameter again, hence
        # the plain `.data` tensor.
        setattr(module, hook_fn.name, param.data)
        module.register_buffer(hook_fn.name + "_u", u_init)
        module.register_buffer(hook_fn.name + "_v", v_init)

        module.register_forward_pre_hook(hook_fn)
        module._register_state_dict_hook(SpectralNormStateDictHook(hook_fn))
        module._register_load_state_dict_pre_hook(
            SpectralNormLoadStateDictPreHook(hook_fn))
        return hook_fn
Example #2
0
    def apply(module: Module, name: str, input_shape, n_power_iterations: int,
              dim: int, eps: float) -> 'SpectralNorm':
        """Install a ``SpectralNorm`` forward pre-hook using an input-shaped ``v``.

        The parameter stored under ``name`` is re-registered as
        ``name + "_orig"`` and a plain (non-parameter) attribute ``name`` is
        put back in its place. A single buffer ``name + "_v"`` is registered:
        a random tensor of shape ``(1, *input_shape)`` scaled to unit norm
        over all of its elements. No ``name + "_u"`` buffer is created here.

        Args:
            module: Module whose parameter is being wrapped.
            name: Key of the parameter in ``module._parameters``.
            input_shape: Sizes appended after a leading ``1`` to form the
                shape of ``v``.
            n_power_iterations: Forwarded unchanged to ``SpectralNorm``.
            dim: Forwarded unchanged to ``SpectralNorm``.
            eps: Forwarded to ``SpectralNorm``; ``fn.eps`` is then used when
                normalizing ``v``.

        Returns:
            The ``SpectralNorm`` instance that was registered as the hook.

        Raises:
            RuntimeError: If a ``SpectralNorm`` pre-hook with the same
                ``name`` is already registered on ``module``.
        """
        already_hooked = any(
            isinstance(existing, SpectralNorm) and existing.name == name
            for existing in module._forward_pre_hooks.values())
        if already_hooked:
            raise RuntimeError(
                "Cannot register two spectral_norm hooks on "
                "the same parameter {}".format(name))

        hook_fn = SpectralNorm(name, n_power_iterations, dim, eps)
        param = module._parameters[name]

        with torch.no_grad():
            # Draw noise shaped like one input sample, normalize it as a flat
            # vector, then restore the sample shape.
            noise = param.new_empty(1, *input_shape).normal_(0, 1)
            unit_flat = normalize(noise.flatten(), dim=0, eps=hook_fn.eps)
            v_init = unit_flat.reshape(1, *input_shape)

        # Move the real parameter to "<name>_orig".
        delattr(module, hook_fn.name)
        module.register_parameter(hook_fn.name + "_orig", param)
        # Other code (e.g. weight initialisation) may still look the weight up
        # under its original name, so keep that attribute alive. Assigning the
        # nn.Parameter itself would register it as a parameter again, hence
        # the plain `.data` tensor.
        setattr(module, hook_fn.name, param.data)
        # Only `v` is tracked in this variant; there is no matrix-shaped `u`.
        module.register_buffer(hook_fn.name + "_v", v_init)

        module.register_forward_pre_hook(hook_fn)
        module._register_state_dict_hook(SpectralNormStateDictHook(hook_fn))
        module._register_load_state_dict_pre_hook(
            SpectralNormLoadStateDictPreHook(hook_fn))
        return hook_fn
Example #3
0
    def apply(module: Module, name: str, n_power_iterations: int, dim: int,
              eps: float) -> 'QSpectralNorm':
        """Install a ``QSpectralNorm`` forward pre-hook on a four-component module.

        The module's ``r_weight``, ``i_weight``, ``j_weight`` and ``k_weight``
        parameters are each re-registered under ``<component>_orig`` and
        replaced by a plain (non-parameter) attribute holding the same data.
        The four components are also assembled into one 4x4 block matrix,
        whose reshaped size determines the lengths of the random unit vectors
        registered as the buffers ``name + "_u"`` and ``name + "_v"``.

        Note that ``name`` only identifies the hook and prefixes the two
        buffers; the component parameter names are fixed.

        Args:
            module: Module owning the four component parameters.
            name: Identifier for the hook and prefix for the ``_u`` / ``_v``
                buffers.
            n_power_iterations: Forwarded unchanged to ``QSpectralNorm``.
            dim: Forwarded unchanged to ``QSpectralNorm``.
            eps: Forwarded to ``QSpectralNorm``; ``fn.eps`` is then used when
                normalizing the initial ``u`` and ``v``.

        Returns:
            The ``QSpectralNorm`` instance that was registered as the hook.

        Raises:
            RuntimeError: If a ``QSpectralNorm`` pre-hook with the same
                ``name`` is already registered on ``module``.
        """
        if any(isinstance(hook, QSpectralNorm) and hook.name == name
               for hook in module._forward_pre_hooks.values()):
            raise RuntimeError(
                "Cannot register two spectral_norm hooks on "
                "the same parameter {}".format(name))

        fn = QSpectralNorm(name, n_power_iterations, dim, eps)

        component_names = ('r_weight', 'i_weight', 'j_weight', 'k_weight')
        weight_r, weight_i, weight_j, weight_k = (
            module._parameters[key] for key in component_names)
        components = (weight_r, weight_i, weight_j, weight_k)

        # The assembled matrix is only needed for its shape, dtype and device,
        # so build it without recording an autograd graph on the parameters.
        with torch.no_grad():
            # Each row block is concatenated along dim=1, then the four rows
            # along dim=0.
            # NOTE(review): the sign pattern looks like the real-valued matrix
            # of a quaternion (Hamilton) product -- confirm.
            weight = torch.cat([
                torch.cat([weight_r, -weight_i, -weight_j, -weight_k], dim=1),
                torch.cat([weight_i, weight_r, -weight_k, weight_j], dim=1),
                torch.cat([weight_j, weight_k, weight_r, -weight_i], dim=1),
                torch.cat([weight_k, -weight_j, weight_i, weight_r], dim=1),
            ], dim=0)

            weight_mat = fn.reshape_weight_to_matrix(weight)
            h, w = weight_mat.size()
            # randomly initialize `u` and `v`
            u = normalize(weight.new_empty(h).normal_(0, 1), dim=0, eps=fn.eps)
            v = normalize(weight.new_empty(w).normal_(0, 1), dim=0, eps=fn.eps)

        # Move every component parameter to "<component>_orig".
        for key in component_names:
            delattr(module, key)
        for key, param in zip(component_names, components):
            module.register_parameter(key + "_orig", param)

        # Other code (e.g. weight initialisation) may still look the weights
        # up under their original names, so keep those attributes alive.
        # Assigning the nn.Parameter itself would register it as a parameter
        # again, hence the plain `.data` tensors.
        for key, param in zip(component_names, components):
            setattr(module, key, param.data)

        module.register_buffer(fn.name + "_u", u)
        module.register_buffer(fn.name + "_v", v)

        module.register_forward_pre_hook(fn)
        module._register_state_dict_hook(QSpectralNormStateDictHook(fn))
        module._register_load_state_dict_pre_hook(
            QSpectralNormLoadStateDictPreHook(fn))
        return fn