Example #1
    def test_almost_orthogonal(self):
        def is_almost_orthogonal(net, lam):
            M = net.parametrizations.weight
            U_orig, S_orig, V_orig = M.original
            S_orig = 1.0 + lam * M.f(S_orig)
            self.assertIsOrthogonal(U_orig)
            self.assertIsOrthogonal(V_orig)
            self.assertHasSingularValues(net.weight, S_orig)
            dist = torch.abs(S_orig - 1.0)
            self.assertTrue((dist < lam + 1e-6).all())

        net = nn.Linear(6, 2)
        geotorch.almost_orthogonal(net, "weight", lam=0.5)
        is_almost_orthogonal(net, 0.5)
        net = nn.Linear(7, 7)
        geotorch.almost_orthogonal(net, "weight", lam=0.3, triv="cayley")
        is_almost_orthogonal(net, 0.3)
        geotorch.almost_orthogonal(net,
                                   "weight",
                                   lam=1.0,
                                   f="tanh",
                                   triv="cayley")
        is_almost_orthogonal(net, 1.0)
        # Try to instantiate it on a vector rather than a matrix
        with self.assertRaises(ValueError):
            geotorch.almost_orthogonal(net, "bias", lam=1.0)
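For reference, geotorch documents almost_orthogonal as constraining the singular values of the weight to the interval [1 - lam, 1 + lam], which is exactly what the test above verifies. A minimal standalone check of that property (assuming only torch and geotorch; torch.linalg.svdvals needs a reasonably recent PyTorch):

import torch
import torch.nn as nn
import geotorch

layer = nn.Linear(6, 2)
geotorch.almost_orthogonal(layer, "weight", lam=0.5)

# All singular values of the constrained weight lie in [0.5, 1.5].
svals = torch.linalg.svdvals(layer.weight)
assert ((svals - 1.0).abs() <= 0.5 + 1e-6).all()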
Example #2
    def __init__(self, input_size, hidden_size):
        super(ExpRNNCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.recurrent_kernel = nn.Linear(hidden_size, hidden_size, bias=False)
        self.input_kernel = nn.Linear(input_size, hidden_size)
        self.nonlinearity = modrelu(hidden_size)

        # Make recurrent_kernel orthogonal
        geotorch.orthogonal(self.recurrent_kernel, "weight")

        self.reset_parameters()
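Since modrelu comes from the surrounding expRNN codebase and is not shown here, the cell cannot be constructed verbatim. A minimal sketch of just the constraint step, showing that the recurrent kernel is orthogonal once geotorch.orthogonal is applied:

import torch
import torch.nn as nn
import geotorch

recurrent_kernel = nn.Linear(32, 32, bias=False)
geotorch.orthogonal(recurrent_kernel, "weight")

# The weight is recomputed from unconstrained parameters and is orthogonal.
W = recurrent_kernel.weight
print(torch.allclose(W @ W.T, torch.eye(32), atol=1e-5))  # True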
Example #3
    def test_orthogonal(self):
        net = nn.Linear(6, 1)
        geotorch.orthogonal(net, "weight")
        self.assertIsOrthogonal(net.weight)
        net = nn.Linear(7, 4)
        geotorch.orthogonal(net, "weight")
        self.assertIsOrthogonal(net.weight)
        net = nn.Linear(7, 7)
        geotorch.orthogonal(net, "weight", triv="cayley")
        self.assertIsOrthogonal(net.weight)
        # Try to instantiate it on a vector rather than a matrix
        with self.assertRaises(ValueError):
            geotorch.orthogonal(net, "bias")
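The point of the parametrization is that ordinary optimizers keep the weight on the manifold: gradients flow to the unconstrained parameters, and the orthogonal weight is recomputed after each step. A minimal sketch of this (layer sizes and learning rate are arbitrary):

import torch
import torch.nn as nn
import geotorch

layer = nn.Linear(7, 7)
geotorch.orthogonal(layer, "weight")
opt = torch.optim.SGD(layer.parameters(), lr=0.1)

x, y = torch.randn(8, 7), torch.randn(8, 7)
loss = ((layer(x) - y) ** 2).mean()
loss.backward()
opt.step()

# Still orthogonal after the unconstrained update.
W = layer.weight
print(torch.allclose(W.T @ W, torch.eye(7), atol=1e-5))  # True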
Example #4
    def __init__(self, input_size, hidden_size):
        super(ExpRNNCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.recurrent_kernel = nn.Linear(hidden_size, hidden_size, bias=False)
        self.input_kernel = nn.Linear(input_size, hidden_size)
        self.nonlinearity = modrelu(hidden_size)

        # Make recurrent_kernel orthogonal
        if args.constraints == "orthogonal":
            geotorch.orthogonal(self.recurrent_kernel, "weight")
        elif args.constraints == "lowrank":
            geotorch.lowrank(self.recurrent_kernel, "weight", hidden_size)
        elif args.constraints == "almostorthogonal":
            geotorch.almost_orthogonal(self.recurrent_kernel, "weight", args.r,
                                       args.f)
        else:
            raise ValueError("Unexpected constraints. Got {}".format(
                args.constraints))

        self.reset_parameters()
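This variant reads a module-level args object that is not shown. A hypothetical argparse setup matching the attributes the snippet uses (the flag names, defaults, and the choice of "sin" for --f are assumptions, not taken from the original project):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--constraints", default="orthogonal",
                    choices=["orthogonal", "lowrank", "almostorthogonal"])
parser.add_argument("--r", type=float, default=0.5)  # lam for almost_orthogonal
parser.add_argument("--f", default="sin")            # scaling function name
args = parser.parse_args()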
Example #5
    def __init__(self,
                 backbone: nn.Module,
                 heads: List[nn.Module],
                 latent_size: int,
                 *args,
                 burn_in_period: int = 20,
                 normalize_losses: bool = False):
        super(RotateOnly, self).__init__()
        num_tasks = len(heads)

        for i in range(num_tasks):
            heads[i] = nn.Sequential(RotateModule(self, i), heads[i])

        # Wrap in a list so the backbone is not registered as a submodule
        self._backbone = [backbone]
        self.heads = heads

        # Parameterize rotations so we can run unconstrained optimization
        for i in range(num_tasks):
            self.register_parameter(
                f'rotation_{i}',
                nn.Parameter(torch.eye(latent_size), requires_grad=True))
            orthogonal(
                self, f'rotation_{i}',
                triv='expm')  # uses exponential map (alternative: cayley)

        # Parameters
        self.num_tasks = num_tasks
        self.latent_size = latent_size
        self.burn_in_period = burn_in_period
        self.normalize_losses = normalize_losses

        self.rep = None
        self.grads = [None for _ in range(num_tasks)]
        self.original_grads = [None for _ in range(num_tasks)]
        self.losses = [None for _ in range(num_tasks)]
        self.initial_losses = [None for _ in range(num_tasks)]
        self.initial_backbone_loss = None
        self.iteration_counter = 0
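Note the pattern used for the rotations above: a plain parameter is registered first, and geotorch's orthogonal is then applied to that attribute, after which reading the attribute always yields an orthogonal matrix. A self-contained sketch of just that pattern (the Toy class is illustrative, not from the original code):

import torch
import torch.nn as nn
from geotorch import orthogonal

class Toy(nn.Module):
    def __init__(self, n):
        super().__init__()
        self.register_parameter(
            "rotation_0",
            nn.Parameter(torch.eye(n), requires_grad=True))
        orthogonal(self, "rotation_0", triv="expm")

m = Toy(4)
R = m.rotation_0  # always orthogonal, however the underlying params change
print(torch.allclose(R @ R.T, torch.eye(4), atol=1e-5))  # True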
Example #6

import torch
import torch.nn as nn
import geotorch
from torch.optim import LBFGS


def _find_rotation_lbfgs(
    X,
    Y,
    tol=1e-6,
    max_iter=100,
    verbose=True,
    center_columns=True,
):
    """
    Finds orthogonal matrix Q, scaling s, and translation b, to

        minimize   sum(norm(X - s * Y @ Q - b)).

    Note that the solution is not in closed form because we are
    minimizing the sum of norms, which is non-trivial given the
    orthogonality constraint on Q. Without the orthogonality
    constraint, the problem can be formulated as a cone program:

        Guoliang Xue & Yinyu Ye (2000). "An Efficient Algorithm for
        Minimizing a Sum of p-Norms." SIAM J. Optim., 10(2), 551–579.

    However, the orthogonality constraint complicates things, so
    we just minimize by gradient methods used in manifold optimization.

        Mario Lezcano-Casado (2019). "Trivializations for gradient-based
        optimization on manifolds." NeurIPS.
    """

    # Convert X and Y to double-precision torch tensors, matching the
    # double() model below (LBFGS also benefits from float64 precision).
    X = torch.tensor(X, dtype=torch.float64)
    Y = torch.tensor(Y, dtype=torch.float64)

    # Check inputs.
    m, n = X.shape
    assert Y.shape == X.shape

    # Orthogonal linear transformation.
    Q = nn.Linear(n, n, bias=False)
    geotorch.orthogonal(Q, "weight")
    Q = Q.double()

    # Allow a rigid translation.
    bias = nn.Parameter(torch.zeros(n, dtype=torch.float64))

    # Collect trainable parameters
    trainable_params = list(Q.parameters())

    if center_columns:
        trainable_params.append(bias)

    # Define rotational alignment, and optimizer.
    optimizer = LBFGS(
        trainable_params,
        max_iter=100,  # number of inner iterations.
        line_search_fn="strong_wolfe",
    )

    def closure():
        optimizer.zero_grad()
        loss = torch.mean(torch.norm(X - Q(Y) - bias, dim=1))
        loss.backward()
        return loss

    # Fit parameters.
    converged = False
    itercount = 0
    while (not converged) and (itercount < max_iter):

        # Update parameters.
        new_loss = optimizer.step(closure).item()

        # Check convergence.
        if itercount != 0:
            improvement = (last_loss - new_loss) / last_loss
            converged = improvement < tol

        last_loss = new_loss

        # Display progress.
        itercount += 1
        if verbose:
            print(f"Iter {itercount}: {last_loss}")
            if converged:
                print("Converged!")

    # Extract result in numpy.
    Q_ = Q.weight.detach().numpy()
    bias_ = bias.detach().numpy()

    return Q_, bias_
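A usage sketch with synthetic numpy data (the noise-free setup and seed are arbitrary). Since the function models X ≈ Q(Y) + bias with Q an nn.Linear, the returned matrix satisfies X ≈ Y @ Q_.T + bias_:

import numpy as np

rng = np.random.default_rng(0)
Y = rng.standard_normal((50, 3))
Q_true, _ = np.linalg.qr(rng.standard_normal((3, 3)))
X = Y @ Q_true + 0.1  # rotated and translated copy of Y

Q_, bias_ = _find_rotation_lbfgs(X, Y, verbose=False)
print(np.linalg.norm(X - Y @ Q_.T - bias_))  # close to zero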