def gradient_beta(self, alpha, beta, C, eps, output_layer=None,
                  return_loss=False, computation=None):
    """Compute the gradient of the Sinkhorn loss relative to beta."""
    if computation is None:
        computation = self.gradient_computation
    alpha, beta, C = check_tensor(alpha, beta, C, device=self.device)
    if computation == 'autodiff':
        beta = check_tensor(beta, device=self.device, requires_grad=True)
    f, g, _ = self(alpha, beta, C, eps, output_layer=output_layer)
    res = self._get_grad_beta(f, g, alpha, beta, C, eps,
                              return_loss=return_loss)
    if return_loss:
        return get_np(res[0]), get_np(res[1])
    return get_np(res)
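# Hedged illustration of the autodiff path above: `_get_grad_beta` is defined
# elsewhere in the class, but once `beta.requires_grad=True` the same kind of
# gradient can be obtained directly from a scalar dual objective. The objective
# written below is the generic entropic-OT dual, an assumption rather than a
# copy of `self._loss_fn`.
def _autodiff_grad_beta_sketch(f, g, alpha, beta, C, eps):
    dual = (f * alpha).sum() + (g * beta).sum() \
        - eps * torch.exp((f[:, None] + g[None, :] - C) / eps).sum()
    # Differentiate the scalar objective with respect to beta only.
    grad_beta, = torch.autograd.grad(dual, beta)
    return grad_beta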
def forward(self, alpha, beta, C, eps, output_layer=None, log_iters=None,
            log_callbacks=DEFAULT_CALLBACKS):
    n_alpha, n_beta = C.shape

    if output_layer is None:
        output_layer = self.n_layers
    elif output_layer > self.n_layers:
        raise ValueError("Requested output from out-of-bound layer "
                         "output_layer={} (n_layers={})"
                         .format(output_layer, self.n_layers))

    if log_iters is None:
        log_iters = [output_layer]

    if self.log_domain:
        g = torch.zeros_like(beta)
    else:
        v = torch.ones_like(beta)
        K = torch.exp(- C / eps)

    # Compute the successive layers of the network
    log = []
    for id_layer in range(output_layer):
        if self.log_domain:
            g_hat = g
            f = eps * (torch.log(alpha) - log_dot_exp(C, g, eps))
            g = eps * (torch.log(beta) - log_dot_exp(C.t(), f, eps))
        else:
            v_hat = v
            u = alpha / torch.matmul(v, K.t())
            v = beta / torch.matmul(u, K)

        # Check whether the dual variables have stopped moving.
        if self.tol is not None and id_layer % 10 == 0:
            if self.log_domain:
                err = torch.norm(g - g_hat)
            else:
                err = torch.norm(v - v_hat)
            if err < self.tol:
                break

        if self.verbose > 0 and (id_layer + 1) % 100 == 0:
            print(f"{(id_layer + 1) / output_layer:6.1%}" + '\b' * 6,
                  end='', flush=True)

        if id_layer + 1 in log_iters:
            if not self.log_domain:
                f, g = eps * torch.log(u), eps * torch.log(v)
            rec = {k: get_np(CALLBACKS[k](self, f, g, alpha, beta, C, eps))
                   for k in log_callbacks}
            rec['iter'] = id_layer
            log.append(rec)

    if not self.log_domain:
        f, g = eps * torch.log(u), eps * torch.log(v)

    if log_iters is not None:
        return f, g, log
    return f, g, None
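# Hedged sketch of the stabilized product used in the log-domain updates above.
# `log_dot_exp(C, g, eps)` is defined elsewhere in the code base; the version
# below is only an assumption about what it computes: a row-wise log-sum-exp of
# (g - C) / eps, so that f = eps * (log(alpha) - log_dot_exp(C, g, eps)) matches
# the standard log-domain Sinkhorn update.
def _log_dot_exp_sketch(C, g, eps):
    # logsumexp over the columns of (g[None, :] - C) / eps, one value per row of C
    return torch.logsumexp((g[None, :] - C) / eps, dim=1)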
def transform(self, alpha, beta, C, eps, output_layer=None, log_iters=None,
              log_callbacks=DEFAULT_CALLBACKS, requires_grad=False):
    """Compute the dual variables associated with the transport plan.

    The transport plan can be recovered using the formula:

        P = exp(f / eps)[:, None] * exp(-C / eps) * exp(g / eps)[None]
    """
    # Compat numpy
    alpha, beta, C = check_tensor(alpha, beta, C, device=self.device)
    beta = check_tensor(beta, requires_grad=True)

    with nullcontext() if requires_grad else torch.no_grad():
        f, g, log = self(alpha, beta, C, eps, output_layer=output_layer,
                         log_iters=log_iters, log_callbacks=log_callbacks)

    return (get_np(f), get_np(g)), log
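# Usage sketch for the formula quoted in the docstring above: rebuild the plan
# from the potentials returned by `transform`. The helper name and the marginal
# check are illustrative only, not part of the class API.
def _plan_from_potentials_sketch(f, g, C, eps):
    P = np.exp(f / eps)[:, None] * np.exp(-C / eps) * np.exp(g / eps)[None]
    # For a converged run, P.sum(axis=1) should be close to alpha and
    # P.sum(axis=0) close to beta.
    return P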
def compute_loss(self, alpha, beta, C, eps, primal=False):
    """Compute the loss along the network's layers.

    Parameters
    ----------
    alpha : ndarray, shape (n_alpha,)
        First input distribution.
    beta : ndarray, shape (n_beta,)
        Second input distribution.
    C : ndarray, shape (n_alpha, n_beta)
        Cost matrix between the samples of each distribution.
    eps : float
        Entropic regularization parameter.
    primal : boolean (default: False)
        If set to True, output the primal loss function. Else, output the
        dual loss.
    """
    alpha, beta, C = check_tensor(alpha, beta, C, device=self.device)
    loss = []
    with torch.no_grad():
        for output_layer in range(self.n_layers):
            f, g, _ = self(alpha, beta, C, eps,
                           output_layer=output_layer + 1)
            loss.append(get_np(self._loss_fn(f, g, alpha, beta, C, eps,
                                             primal=primal)))
    return np.array(loss)
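# Usage sketch: `compute_loss` returns one value per unrolled layer, which is a
# convenient way to monitor the convergence of the Sinkhorn iterations. `net` is
# a placeholder for an instantiated network, not part of this module.
def _layerwise_losses_sketch(net, alpha, beta, C, eps):
    dual = net.compute_loss(alpha, beta, C, eps, primal=False)
    primal = net.compute_loss(alpha, beta, C, eps, primal=True)
    # The gap between primal and dual values should shrink with the layer index.
    return primal - dual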
def score(self, alpha, beta, C, eps, primal=False, output_layer=None):
    """Compute the loss for the network's output.

    Parameters
    ----------
    alpha : ndarray, shape (n_alpha,)
        First input distribution.
    beta : ndarray, shape (n_beta,)
        Second input distribution.
    C : ndarray, shape (n_alpha, n_beta)
        Cost matrix between the samples of each distribution.
    eps : float
        Entropic regularization parameter.
    primal : boolean (default: False)
        If set to True, output the primal loss function. Else, output the
        dual loss.
    output_layer : int (default: None)
        Layer to output from. It should be smaller than the number of layers
        of the network. If set to None, output the network's last layer.

    Returns
    -------
    loss : float
        Regularized Sinkhorn loss between alpha and beta for the cost C and
        regularization eps.
    """
    alpha, beta, C = check_tensor(alpha, beta, C, device=self.device)
    with torch.no_grad():
        f, g, _ = self(alpha, beta, C, eps, output_layer=output_layer)
        return get_np(self._loss_fn(f, g, alpha, beta, C, eps,
                                    primal=primal))
def _get_default_step(self, x, D):
    with torch.no_grad():
        n_dim, _ = D.shape
        D_ = get_np(D)
        L_B = np.linalg.norm(D_.T.dot(D_ ** (self.p - 1)), ord=2)
        step = 1 / L_B
        # NOTE: the Lipschitz-based estimate above is overridden by a
        # hard-coded step size.
        step = .1
    return step
def get_jacobian_beta(self, alpha, beta, C, eps, output_layer=None):
    """Compute the Jacobian of the dual variable g relative to beta."""
    n_features, = beta.shape
    alpha = check_tensor(alpha, device=self.device)
    beta = check_tensor(beta, device=self.device, requires_grad=True)
    C = check_tensor(C, device=self.device)

    # Construct the matrix used to probe the Jacobian.
    beta = beta.squeeze()
    beta = beta.repeat(n_features, 1)
    f, g, _ = self(alpha, beta, C, eps, output_layer=output_layer)
    return get_np(torch.autograd.grad(
        g, beta, grad_outputs=torch.eye(n_features))[0])
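# Hedged sanity check for the autodiff Jacobian above: a centered finite
# difference on the potentials returned by `transform`. `net` is a placeholder
# for an instantiated network; depending on conventions, the result may be the
# transpose of the autodiff output.
def _finite_diff_jacobian_sketch(net, alpha, beta, C, eps, delta=1e-6):
    n_features = beta.shape[0]
    J = np.zeros((n_features, n_features))
    for j in range(n_features):
        e_j = np.zeros(n_features)
        e_j[j] = delta
        (_, g_plus), _ = net.transform(alpha, beta + e_j, C, eps)
        (_, g_minus), _ = net.transform(alpha, beta - e_j, C, eps)
        J[:, j] = (g_plus - g_minus) / (2 * delta)
    return J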
def _get_default_step(self, x, D, reg):
    with torch.no_grad():
        n_dim, _ = D.shape
        L_D = np.linalg.norm(get_np(D), ord=2) ** 2 / n_dim
        step = 1 / (L_D / 4 + reg)
    return step
def _get_default_step(self, x, A, B, C, u, v):
    with torch.no_grad():
        n_dim, _ = B.shape
        L_B = np.linalg.norm(get_np(B), ord=2)
        step = 1 / L_B
    return step
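# The `_get_default_step` helpers above broadly follow the same pattern:
# estimate a Lipschitz constant through a spectral norm and return its inverse
# as the step size. A minimal standalone illustration of that pattern, for an
# arbitrary matrix D:
def _inverse_lipschitz_step_sketch(D):
    # Lipschitz constant of the gradient of x -> 0.5 * ||D @ x||^2
    L = np.linalg.norm(D, ord=2) ** 2
    return 1 / L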