def test_almost_orthogonal(self):
    def is_almost_orthogonal(net, lam):
        M = net.parametrizations.weight
        U_orig, S_orig, V_orig = M.original
        S_orig = 1.0 + lam * M.f(S_orig)
        self.assertIsOrthogonal(U_orig)
        self.assertIsOrthogonal(V_orig)
        self.assertHasSingularValues(net.weight, S_orig)
        # Every singular value must lie within lam of 1
        dist = torch.abs(S_orig - 1.0)
        self.assertTrue((dist < lam + 1e-6).all())

    net = nn.Linear(6, 2)
    geotorch.almost_orthogonal(net, "weight", lam=0.5)
    is_almost_orthogonal(net, 0.5)

    net = nn.Linear(7, 7)
    geotorch.almost_orthogonal(net, "weight", lam=0.3, triv="cayley")
    is_almost_orthogonal(net, 0.3)

    geotorch.almost_orthogonal(net, "weight", lam=1.0, f="tanh", triv="cayley")
    is_almost_orthogonal(net, 1.0)

    # Try to instantiate it on a vector rather than a matrix
    with self.assertRaises(ValueError):
        geotorch.orthogonal(net, "bias")
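For reference, a minimal sketch (not part of the test suite) of the property this test asserts: after `geotorch.almost_orthogonal(net, "weight", lam)`, every singular value of `net.weight` lies within `lam` of 1. It assumes a recent PyTorch with `torch.linalg.svdvals`.

import torch
import torch.nn as nn
import geotorch

net = nn.Linear(6, 2)
geotorch.almost_orthogonal(net, "weight", lam=0.5)
S = torch.linalg.svdvals(net.weight)  # singular values of the constrained weight
print(((S >= 0.5 - 1e-6) & (S <= 1.5 + 1e-6)).all())  # tensor(True)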
def __init__(self, input_size, hidden_size):
    super(ExpRNNCell, self).__init__()
    self.input_size = input_size
    self.hidden_size = hidden_size
    self.recurrent_kernel = nn.Linear(hidden_size, hidden_size, bias=False)
    self.input_kernel = nn.Linear(input_size, hidden_size)
    self.nonlinearity = modrelu(hidden_size)

    # Make recurrent_kernel orthogonal
    geotorch.orthogonal(self.recurrent_kernel, "weight")

    self.reset_parameters()
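A hedged usage sketch of the cell above, assuming `modrelu` from the surrounding codebase is importable. It only checks the effect of the parametrization: the recurrent weight is recomputed as an orthogonal matrix on every access.

import torch

cell = ExpRNNCell(input_size=3, hidden_size=5)
W = cell.recurrent_kernel.weight  # produced through the geotorch parametrization
print(torch.allclose(W.T @ W, torch.eye(5), atol=1e-5))  # True: W is orthogonal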
def test_orthogonal(self):
    net = nn.Linear(6, 1)
    geotorch.orthogonal(net, "weight")
    self.assertIsOrthogonal(net.weight)

    net = nn.Linear(7, 4)
    geotorch.orthogonal(net, "weight")
    self.assertIsOrthogonal(net.weight)

    net = nn.Linear(7, 7)
    geotorch.orthogonal(net, "weight", triv="cayley")
    self.assertIsOrthogonal(net.weight)

    # Try to instantiate it on a vector rather than a matrix
    with self.assertRaises(ValueError):
        geotorch.orthogonal(net, "bias")
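To illustrate why this matters (a sketch, assuming nothing beyond PyTorch and geotorch): the constraint lives in the parametrization, not in the optimizer, so a plain unconstrained gradient step preserves orthogonality with no projection or retraction code.

import torch
import torch.nn as nn
import geotorch

net = nn.Linear(7, 4)
geotorch.orthogonal(net, "weight")
opt = torch.optim.SGD(net.parameters(), lr=0.1)

loss = net(torch.randn(8, 7)).pow(2).sum()
loss.backward()
opt.step()

W = net.weight  # shape (4, 7): rows stay orthonormal after the step
print(torch.allclose(W @ W.T, torch.eye(4), atol=1e-5))  # True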
def __init__(self, input_size, hidden_size):
    super(ExpRNNCell, self).__init__()
    self.input_size = input_size
    self.hidden_size = hidden_size
    self.recurrent_kernel = nn.Linear(hidden_size, hidden_size, bias=False)
    self.input_kernel = nn.Linear(input_size, hidden_size)
    self.nonlinearity = modrelu(hidden_size)

    # Constrain recurrent_kernel according to the requested parametrization
    if args.constraints == "orthogonal":
        geotorch.orthogonal(self.recurrent_kernel, "weight")
    elif args.constraints == "lowrank":
        geotorch.lowrank(self.recurrent_kernel, "weight", hidden_size)
    elif args.constraints == "almostorthogonal":
        geotorch.almost_orthogonal(self.recurrent_kernel, "weight", args.r, args.f)
    else:
        raise ValueError(
            "Unexpected constraints. Got {}".format(args.constraints))

    self.reset_parameters()
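A sketch of the command-line flags this constructor reads. The flag names are inferred from the attribute accesses (`args.constraints`, `args.r`, `args.f`); the defaults and help strings are illustrative assumptions, not the original script's definitions.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--constraints", default="orthogonal",
                    choices=["orthogonal", "lowrank", "almostorthogonal"])
parser.add_argument("--r", type=float, default=1.0,
                    help="radius (lam) for the almost-orthogonal constraint")
parser.add_argument("--f", default="sin",
                    help='singular-value scaling function, e.g. "sin" or "tanh"')
args = parser.parse_args()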
def __init__(self, backbone: nn.Module, heads: List[nn.Module], latent_size: int, *args,
             burn_in_period: int = 20, normalize_losses: bool = False):
    super(RotateOnly, self).__init__()
    num_tasks = len(heads)

    for i in range(num_tasks):
        heads[i] = nn.Sequential(RotateModule(self, i), heads[i])

    self._backbone = [backbone]
    self.heads = heads

    # Parameterize rotations so we can run unconstrained optimization
    for i in range(num_tasks):
        self.register_parameter(
            f'rotation_{i}',
            nn.Parameter(torch.eye(latent_size), requires_grad=True))
        orthogonal(self, f'rotation_{i}',
                   triv='expm')  # uses exponential map (alternative: cayley)

    # Parameters
    self.num_tasks = num_tasks
    self.latent_size = latent_size
    self.burn_in_period = burn_in_period
    self.normalize_losses = normalize_losses

    self.rep = None
    self.grads = [None for _ in range(num_tasks)]
    self.original_grads = [None for _ in range(num_tasks)]
    self.losses = [None for _ in range(num_tasks)]
    self.initial_losses = [None for _ in range(num_tasks)]
    self.initial_backbone_loss = None
    self.iteration_counter = 0
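A sketch of the invariant the parametrization buys, assuming `model` is a constructed instance of this class: each `rotation_i` attribute yields an orthogonal matrix on every access, even while an ordinary optimizer updates the underlying unconstrained parameter.

import torch

R = getattr(model, 'rotation_0')  # (latent_size, latent_size), built via expm
I = torch.eye(model.latent_size)
print(torch.allclose(R.T @ R, I, atol=1e-5))  # True: R is orthogonal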
import torch
import torch.nn as nn
import geotorch
from torch.optim import LBFGS


def _find_rotation_lbfgs(
    X,
    Y,
    tol=1e-6,
    max_iter=100,
    verbose=True,
    center_columns=True,
):
    """
    Finds an orthogonal matrix Q and a translation b to minimize
    sum(norm(X - Y @ Q - b)).

    Note that the solution is not in closed form because we are
    minimizing the sum of norms, which is non-trivial given the
    orthogonality constraint on Q. Without the orthogonality
    constraint, the problem can be formulated as a cone program:

        Guoliang Xue & Yinyu Ye (2000). "An Efficient Algorithm for
        Minimizing a Sum of p-Norms." SIAM J. Optim., 10(2), 551–579.

    However, the orthogonality constraint complicates things, so we
    just minimize by gradient methods used in manifold optimization.

        Mario Lezcano-Casado (2019). "Trivializations for gradient-based
        optimization on manifolds." NeurIPS.
    """
    # Convert X and Y to pytorch tensors.
    X = torch.tensor(X)
    Y = torch.tensor(Y)

    # Check inputs.
    m, n = X.shape
    assert Y.shape == X.shape

    # Orthogonal linear transformation.
    Q = nn.Linear(n, n, bias=False)
    geotorch.orthogonal(Q, "weight")
    Q = Q.double()

    # Allow a rigid translation.
    bias = nn.Parameter(torch.zeros(n, dtype=torch.float64))

    # Collect trainable parameters.
    trainable_params = list(Q.parameters())
    if center_columns:
        trainable_params.append(bias)

    # Define optimizer for the rotational alignment.
    optimizer = LBFGS(
        trainable_params,
        max_iter=100,  # number of inner iterations.
        line_search_fn="strong_wolfe",
    )

    def closure():
        optimizer.zero_grad()
        loss = torch.mean(torch.norm(X - Q(Y) - bias, dim=1))
        loss.backward()
        return loss

    # Fit parameters.
    converged = False
    itercount = 0
    while (not converged) and (itercount < max_iter):

        # Update parameters.
        new_loss = optimizer.step(closure).item()

        # Check convergence via relative improvement of the loss.
        if itercount != 0:
            improvement = (last_loss - new_loss) / last_loss
            converged = improvement < tol
        last_loss = new_loss

        # Display progress.
        itercount += 1
        if verbose:
            print(f"Iter {itercount}: {last_loss}")
            if converged:
                print("Converged!")

    # Extract result in numpy.
    Q_ = Q.weight.detach().numpy()
    bias_ = bias.detach().numpy()

    return Q_, bias_
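A hedged usage sketch of `_find_rotation_lbfgs` on synthetic data, recovering a known orthogonal transform and offset. Note the fitted model is `X ≈ Y @ Q_hat.T + b_hat`, since `Q(Y)` applies the transpose of the stored `nn.Linear` weight; the residual printed below should be small but depends on the solver's stopping tolerance.

import numpy as np

rng = np.random.default_rng(0)
n = 4
Q_true, _ = np.linalg.qr(rng.standard_normal((n, n)))  # random orthogonal matrix
Y = rng.standard_normal((200, n))
X = Y @ Q_true + 1.0  # rotate, then translate every row

Q_hat, b_hat = _find_rotation_lbfgs(X, Y, verbose=False)
err = np.abs(Y @ Q_hat.T + b_hat - X).max()
print(f"max residual: {err:.2e}")  # small, near the solver's tolerance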