Esempio n. 1
0
    def _make_params(self):
        """Replace the weight ``self.name`` on ``self.module`` with three new parameters.

        Reads the tensor ``w = getattr(self.module, self.name)`` and registers
        on ``self.module``:

        * ``<name>_u`` -- vector of length ``w.shape[0]`` drawn from N(0, 1),
          passed through ``l2_normalize``, ``requires_grad=False``;
        * ``<name>_v`` -- vector whose length is the column count of ``w``
          flattened to ``(w.shape[0], -1)``, drawn from N(0, 1), passed through
          ``l2_normalize``, ``requires_grad=False``;
        * ``<name>_bar`` -- a trainable ``Parameter`` wrapping ``w``'s data.

        The original ``<name>`` entry is deleted from ``module._parameters``.
        (Presumably u/v are power-iteration vectors for spectral
        normalisation -- confirm against the caller.)

        Raises:
            ValueError: if ``w`` has no rows (``w.shape[0] == 0``).
        """
        w = getattr(self.module, self.name)
        weight = w.data

        height = weight.shape[0]
        if height == 0:
            raise ValueError(
                "cannot build parameters for '{}': weight has no rows".format(
                    self.name))
        # Column count of ``weight`` viewed as (height, -1). Derived from
        # numel() instead of ``w.view(height, -1)`` because view() raises on
        # non-contiguous weights (e.g. channels_last convolution weights).
        width = weight.numel() // height

        u = Parameter(weight.new_empty(height).normal_(0, 1), requires_grad=False)
        v = Parameter(weight.new_empty(width).normal_(0, 1), requires_grad=False)
        u.data = l2_normalize(u.data)
        v.data = l2_normalize(v.data)
        w_bar = Parameter(weight)

        del self.module._parameters[self.name]

        self.module.register_parameter(self.name + "_u", u)
        self.module.register_parameter(self.name + "_v", v)
        self.module.register_parameter(self.name + "_bar", w_bar)
Esempio n. 2
0
    def _make_params(self):
        """Replace ``self.module.weight`` with three new parameters on ``get_device()``.

        Reads the tensor ``w = self.module.weight`` and registers on
        ``self.module``, each placed on the device returned by ``get_device()``:

        * ``weight_u`` -- vector of length ``w.shape[0]`` drawn from N(0, 1),
          passed through ``l2_normalize``, ``requires_grad=False``;
        * ``weight_v`` -- vector whose length is the column count of ``w``
          flattened to ``(w.shape[0], -1)``, drawn from N(0, 1), passed through
          ``l2_normalize``, ``requires_grad=False``;
        * ``weight_bar`` -- a trainable ``Parameter`` holding ``w``'s data.

        The original ``weight`` entry is deleted from ``module._parameters``.

        Raises:
            ValueError: if ``w`` has no rows (``w.shape[0] == 0``).
        """
        name = 'weight'

        w = getattr(self.module, name)
        weight = w.data

        height = weight.shape[0]
        if height == 0:
            raise ValueError(
                "cannot build parameters for '{}': weight has no rows".format(name))
        # Column count of ``weight`` viewed as (height, -1). Derived from
        # numel() instead of ``w.view(height, -1)`` because view() raises on
        # non-contiguous weights (e.g. channels_last convolution weights).
        width = weight.numel() // height

        device = get_device()
        # Move the raw tensors first and wrap them in Parameter afterwards:
        # ``Parameter(...).to(device)`` yields a plain Tensor (non-leaf for
        # w_bar) whenever the device actually changes, and register_parameter()
        # rejects that with a TypeError. The random draw still happens on the
        # weight's device, so the same RNG stream is used as before.
        u = Parameter(weight.new_empty(height).normal_(0, 1).to(device),
                      requires_grad=False)
        v = Parameter(weight.new_empty(width).normal_(0, 1).to(device),
                      requires_grad=False)
        u.data = l2_normalize(u.data)
        v.data = l2_normalize(v.data)
        w_bar = Parameter(weight.to(device))

        del self.module._parameters[name]

        self.module.register_parameter(name + "_u", u)
        self.module.register_parameter(name + "_v", v)
        self.module.register_parameter(name + "_bar", w_bar)
Esempio n. 3
0
def remove_parametrization(module, tensor_name, leave_parametrized=True):
    r"""Removes parametrizations active on the parameter ``tensor_name``.

    If ``leave_parametrized == True``, the attribute ``tensor_name`` of
    ``module`` will be set to its current output: the parametrized tensor.
    If ``leave_parametrized == False``, it will be set to its unparametrized
    value, that is, ``module.parametrizations[tensor_name].original_tensor()``.

    In both cases ``tensor_name`` is re-registered with the same kind it had
    before: as a parameter if the original tensor was a ``Parameter``,
    otherwise as a buffer.

    .. warning ::

        If the parametrization changes the size of the tensor and the parametrization
        is on a parameter being optimized, since this function will register a new
        parameter, the parameters on the optimizer have to be manually updated via
        ``optim.params = model.parameters()`` after calling this method.

    Args:
        module (nn.Module): module from which to remove the parametrization
        tensor_name (str): name of the parametrization to be removed
        leave_parametrized (bool, optional): leave the attribute ``tensor_name``
            parametrized or not. Default: ``True``

    Raises:
        ValueError: if ``module`` has no parametrization on ``tensor_name``.
        TypeError: if, once no parametrization is left, the module's class
            does not have exactly one base class.
    """

    if not is_parametrized(module, tensor_name):
        raise ValueError(
            "Module {} does not have a parametrization on {}".format(
                module, tensor_name))

    # TODO: implement the removal recursively; for now the parametrization on
    # ``tensor_name`` is dropped in a single step below.
    parametrization = module.parametrizations[tensor_name]
    original = parametrization.original_tensor()
    # Parametrization on a parameter or a buffer
    is_parameter = isinstance(original, Parameter)
    if leave_parametrized:
        t = getattr(module, tensor_name)
        if t.size() == original.size():
            # Same shape: overwrite the data in place, so the original tensor
            # object (and any optimizer reference to it) is kept.
            original.data = t
        elif is_parameter:
            # Shape changed, so a new Parameter is required. Carry over the
            # original's requires_grad so a frozen parameter stays frozen,
            # exactly as the in-place branch above does.
            original = Parameter(t, requires_grad=original.requires_grad)
        else:
            # Buffers are not trained: drop any autograd graph attached while
            # computing ``t`` through the parametrization (the in-place branch
            # above discards it too, via ``.data``).
            original = t.detach()

    # Remove the caching mechanism if it has one
    remove_caching(module, tensor_name)
    # Delete the property that manages the parametrization
    delattr(module.__class__, tensor_name)
    # Delete the parametrization
    delattr(module.parametrizations, tensor_name)

    # Re-register the (possibly updated) tensor with the kind it had before.
    if is_parameter:
        module.register_parameter(tensor_name, original)
    else:
        module.register_buffer(tensor_name, original)

    # Roll back the fancy parametrized class if no other
    # buffer or parameter is currently parametrized
    if not is_parametrized(module):
        # Delete the associated class.
        # NOTE(review): assumes the generated class was stored in this
        # module's globals() under its __qualname__; KeyError otherwise.
        del globals()[module.__class__.__qualname__]
        # Restore class
        parents = module.__class__.__bases__
        # If everything's working as expected, this should never throw
        if len(parents) != 1:
            raise TypeError(
                "Found a Parametrized module with more than "
                "one parent class. This is currently not supported.")
        module.__class__ = parents[0]
        delattr(module, "parametrizations")