Ejemplo n.º 1
0
    def gradient_beta(self, alpha, beta, C, eps, output_layer=None,
                      return_loss=False, computation=None):
        """Gradient of the Sinkhorn loss with respect to beta.

        Parameters
        ----------
        alpha, beta, C, eps :
            Inputs of the network; converted with ``check_tensor``.
        output_layer : int or None
            Layer whose dual variables are used; forwarded to ``self(...)``.
        return_loss : bool (default: False)
            Forwarded to ``self._get_grad_beta``. When True, its result is
            indexed as a pair and both members are converted to numpy.
        computation : str or None
            Gradient strategy; falls back to ``self.gradient_computation``.
            The value ``'autodiff'`` makes beta track gradients before the
            forward pass.

        Returns
        -------
        The output of ``self._get_grad_beta`` converted with ``get_np``
        (element-wise when ``return_loss`` is True).
        """
        if computation is None:
            mode = self.gradient_computation
        else:
            mode = computation

        alpha, beta, C = check_tensor(alpha, beta, C, device=self.device)
        if mode == 'autodiff':
            # Autodiff needs beta to record the operations of the forward pass.
            beta = check_tensor(beta, device=self.device, requires_grad=True)

        f, g, _ = self(alpha, beta, C, eps, output_layer=output_layer)
        out = self._get_grad_beta(f, g, alpha, beta, C, eps,
                                  return_loss=return_loss)
        if not return_loss:
            return get_np(out)
        return get_np(out[0]), get_np(out[1])
Ejemplo n.º 2
0
    def forward(self, alpha, beta, C, eps, output_layer=None, log_iters=None,
                log_callbacks=DEFAULT_CALLBACKS):
        """Run ``output_layer`` Sinkhorn iterations and return the potentials.

        Parameters
        ----------
        alpha, beta : tensors
            Input distributions; the potential ``f`` is paired with alpha and
            ``g`` with beta.
        C : tensor, shape (n_alpha, n_beta)
            Cost matrix.
        eps : float
            Entropic regularization parameter.
        output_layer : int or None
            Number of iterations (layers) to run, between 1 and
            ``self.n_layers``. Defaults to ``self.n_layers``.
        log_iters : container of int or None
            1-based iteration numbers at which the callbacks are evaluated.
            Defaults to ``[output_layer]``.
        log_callbacks : iterable of str
            Keys into ``CALLBACKS``; each is evaluated as
            ``CALLBACKS[k](self, f, g, alpha, beta, C, eps)``.

        Returns
        -------
        f, g : tensors
            Dual potentials after the last iteration performed.
        log : list of dict
            One record per logged iteration, mapping each callback key to its
            value (converted with ``get_np``) plus ``'iter'``, the 0-based
            index of the iteration.

        Raises
        ------
        ValueError
            If ``output_layer`` is smaller than 1 or larger than
            ``self.n_layers``.
        """
        if output_layer is None:
            output_layer = self.n_layers
        # With zero iterations the potentials would never be assigned, so an
        # output_layer below 1 is rejected as well.
        if not 1 <= output_layer <= self.n_layers:
            raise ValueError("Requested output from out-of-bound layer "
                             "output_layer={} (n_layers={})"
                             .format(output_layer, self.n_layers))

        if log_iters is None:
            log_iters = [output_layer]

        if self.log_domain:
            g = torch.zeros_like(beta)
        else:
            v = torch.ones_like(beta)
            # Kernel exp(-C / eps) used by the multiplicative updates.
            K = torch.exp(- C / eps)

        # Compute the following layers
        log = []
        for id_layer in range(output_layer):
            if self.log_domain:
                g_hat = g
                f = eps * (torch.log(alpha) - log_dot_exp(C, g, eps))
                g = eps * (torch.log(beta) - log_dot_exp(C.t(), f, eps))
            else:
                v_hat = v
                u = alpha / torch.matmul(v, K.t())
                v = beta / torch.matmul(u, K)

            # Stop early once the variables are not moving anymore. The check
            # runs on every 10th iteration and compares the norm of the last
            # update with the tolerance ``self.tol``.
            if self.tol is not None and id_layer % 10 == 0:
                if self.log_domain:
                    err = torch.norm(g - g_hat)
                else:
                    err = torch.norm(v - v_hat)
                if err < self.tol:
                    break

            if self.verbose > 0 and (id_layer + 1) % 100 == 0:
                # The trailing backspaces move the cursor back so that the
                # next progress value overwrites this one.
                print(f"{(id_layer + 1) / output_layer:6.1%}" + '\b'*6,
                      end='', flush=True)
            if id_layer + 1 in log_iters:
                if not self.log_domain:
                    f, g = eps * torch.log(u), eps * torch.log(v)
                rec = {k: get_np(CALLBACKS[k](self, f, g, alpha,
                                              beta, C, eps))
                       for k in log_callbacks}
                rec['iter'] = id_layer
                log.append(rec)

        if not self.log_domain:
            f, g = eps * torch.log(u), eps * torch.log(v)

        return f, g, log
Ejemplo n.º 3
0
    def transform(self, alpha, beta, C, eps, output_layer=None, log_iters=None,
                  log_callbacks=DEFAULT_CALLBACKS, requires_grad=False):
        """Compute the dual variables associated to the transport plan.

        The transport plan can be recovered using the formula:
            P = exp(f / eps)[:, None] * exp(-C / eps) * exp(g / eps)[None]

        Parameters
        ----------
        alpha, beta, C, eps :
            Inputs of the network; numpy arrays are accepted and converted
            with ``check_tensor``.
        output_layer, log_iters, log_callbacks :
            Forwarded unchanged to the forward pass ``self(...)``.
        requires_grad : bool (default: False)
            If False, the forward pass runs under ``torch.no_grad()``.

        Returns
        -------
        (f, g) : tuple
            Dual variables converted with ``get_np``.
        log :
            Third value returned by the forward pass, passed through as-is.
        """
        # Compat numpy
        alpha, beta, C = check_tensor(alpha, beta, C, device=self.device)
        # Keep beta on the network's device, as the other callers of
        # check_tensor with requires_grad do.
        beta = check_tensor(beta, device=self.device, requires_grad=True)

        with nullcontext() if requires_grad else torch.no_grad():
            f, g, log = self(alpha, beta, C, eps, output_layer=output_layer,
                             log_iters=log_iters, log_callbacks=log_callbacks)

        return (get_np(f), get_np(g)), log
Ejemplo n.º 4
0
    def compute_loss(self, alpha, beta, C, eps, primal=False):
        """Evaluate the loss after each layer of the network.

        Parameters
        ----------
        alpha : ndarray, shape (n_alpha,)
            First input distribution.
        beta : ndarray, shape (n_beta,)
            Second input distribution.
        C : ndarray, shape (n_alpha, n_beta)
            Cost matrix between the samples of each distribution.
        eps : float
            Entropic regularization parameter.
        primal : boolean (default: False)
            If True, evaluate the primal loss; otherwise the dual loss.

        Returns
        -------
        loss : ndarray
            Entry ``k`` holds the loss obtained with ``output_layer=k + 1``,
            for ``k`` in ``range(self.n_layers)``.
        """
        alpha, beta, C = check_tensor(alpha, beta, C, device=self.device)
        per_layer = []
        with torch.no_grad():
            for depth in range(1, self.n_layers + 1):
                # Each depth re-runs the network from its first layer.
                f, g, _ = self(alpha, beta, C, eps, output_layer=depth)
                value = self._loss_fn(f, g, alpha, beta, C, eps, primal=primal)
                per_layer.append(get_np(value))
        return np.array(per_layer)
Ejemplo n.º 5
0
    def score(self, alpha, beta, C, eps, primal=False, output_layer=None):
        """Compute the loss for the network's output

        Parameters
        ----------
        alpha : ndarray, shape (n_alpha,)
            First input distribution.
        beta : ndarray, shape (n_beta,)
            Second input distribution.
        C : ndarray, shape (n_alpha, n_beta)
            Cost matrix between the samples of each distribution.
        eps : float
            Entropic regularization parameter
        primal : boolean (default: False)
            If set to True, output the primal loss function. Else, output the
            dual loss.
        output_layer : int (default: None)
            Layer to output from. It must not exceed the number of layers of
            the network. If set to None, output the network's last layer.

        Returns
        -------
        loss : float or ndarray
            Value of ``self._loss_fn`` (primal or dual loss) evaluated on the
            dual variables of the selected layer, converted with ``get_np``.
        """
        alpha, beta, C = check_tensor(alpha, beta, C, device=self.device)
        # Evaluating the loss needs no gradient tracking.
        with torch.no_grad():
            f, g, _ = self(alpha, beta, C, eps, output_layer=output_layer)
            return get_np(self._loss_fn(f, g, alpha, beta, C, eps,
                                        primal=primal))
Ejemplo n.º 6
0
 def _get_default_step(self, x, D):
     with torch.no_grad():
         n_dim, _ = D.shape
         D_ = get_np(D)
         L_B = np.linalg.norm(D_.T.dot(D_**(self.p - 1)), ord=2)
         step = 1 / L_B
         step = .1
     return step
Ejemplo n.º 7
0
    def get_jacobian_beta(self, alpha, beta, C, eps, output_layer=None):
        """Compute the Jacobian of the scale dual variable g relative to beta.

        beta is flattened and replicated ``n_features`` times into a batch,
        and the network is run once on that batch. Back-propagating the
        identity matrix as ``grad_outputs`` gives, in row ``i``, the gradient
        of ``g[i, i]`` with respect to the ``i``-th copy of beta.

        Parameters
        ----------
        alpha, beta, C, eps :
            Inputs of the network; converted with ``check_tensor``.
        output_layer : int or None
            Layer whose dual variable ``g`` is differentiated.

        Returns
        -------
        jacobian : ndarray, shape (n_features, n_features)
            With ``n_features`` the number of elements of beta. Presumably
            ``J[i, k] = d g_i / d beta_k``, assuming the forward pass treats
            each row of the batch independently — TODO confirm.
        """
        alpha = check_tensor(alpha, device=self.device)
        # beta must track gradients so that g can be differentiated w.r.t. it.
        beta = check_tensor(beta, device=self.device, requires_grad=True)
        C = check_tensor(C, device=self.device)

        # Construct the matrix to probe the Jacobian: one copy of beta per
        # output coordinate.
        beta = beta.reshape(-1)
        n_features = beta.shape[0]
        beta = beta.repeat(n_features, 1)
        _, g, _ = self(alpha, beta, C, eps, output_layer=output_layer)
        # The probe must live on the same device and dtype as g.
        probe = torch.eye(n_features, dtype=g.dtype, device=g.device)
        return get_np(torch.autograd.grad(g, beta, grad_outputs=probe)[0])
Ejemplo n.º 8
0
 def _get_default_step(self, x, D, reg):
     """Return the default step size ``1 / (L_D / 4 + reg)``.

     ``L_D`` is the squared spectral norm (``ord=2``) of ``D`` divided by the
     number of rows of ``D``. ``D`` must be 2-D; ``x`` is not used.
     """
     with torch.no_grad():
         n_rows, _ = D.shape
         sq_norm = np.linalg.norm(get_np(D), ord=2)**2
         l_d = sq_norm / n_rows
         step = 1 / (l_d / 4 + reg)
     return step
Ejemplo n.º 9
0
 def _get_default_step(self, x, A, B, C, u, v):
     """Return ``1 / ||B||_2``, the inverse of the spectral norm of ``B``.

     Only ``B`` is read, and it must be 2-D; the other arguments are ignored.
     """
     with torch.no_grad():
         n_rows, _ = B.shape  # unpacking rejects a B that is not 2-D
         b_norm = np.linalg.norm(get_np(B), ord=2)
         step = 1 / b_norm
     return step