Example no. 1
    def first_grads(self):
        r"""Compute the first-order gradients of the reduced prediction with
        respect to ``self.train_points`` via a vector-Jacobian product.

        Returns:
            Tensor with the same shape as ``self.train_points`` holding the
            per-point gradients.
        Example:
            >>> self.train_points
            tensor([[0.2513, 0.6246],
                    [0.5098, 0.8833],
                    [0.1971, 0.7218]], device='cuda:0')

            >>> first_grads()
            tensor([[0.3441, 0.1167],
                    [0.4489, 0.2082],
                    [0.3500, 0.1231]],
                    device='cuda:0', grad_fn=<MmBackward>)

        """
        # One entry of ``v`` per training point, so the vjp returns one
        # gradient row per point.
        v = torch.ones(self.train_points.shape[0]).to(self.device)
        predict_vjp = vjp(self.vjp_reducer, self.train_points, v, create_graph=True)

        predict_value = predict_vjp[0]  # forward value of the reduced prediction (unused here)
        first_grad = predict_vjp[1]
        return first_grad
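A minimal, self-contained sketch of the same vjp pattern, assuming a small network net and a per-sample reducer (hypothetical stand-ins for the model and self.vjp_reducer):

import torch
from torch.autograd.functional import vjp

net = torch.nn.Linear(2, 1)

def reducer(x):
    # One scalar per sample, so a vector of ones as ``v`` recovers the
    # per-sample gradient rows.
    return net(x).squeeze(-1)

x = torch.rand(3, 2)
v = torch.ones(x.shape[0])
_, grads = vjp(reducer, x, v, create_graph=True)  # grads has shape (3, 2)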
Example no. 2
    def mvhp(self, train_points):
        r"""Vector-Hessian products (vhp) for multiple outputs. This function is
        as fast as the ``second_grads`` function.

        Returns:
            Tuple of per-output blocks obtained from ``torch.split``; each block
            holds the flattened per-sample second derivatives of one output
            dimension.
        Example:
            >>> self.train_points
            tensor([[0.0293, 0.2130],
                    [0.6103, 0.7452],
                    [0.6648, 0.5782]], device='cuda:0')

            >>> mvhp(train_points)
            tensor([[0.5848, 0.2233, 0.2233, 0.0852, 0.0860, 0.1133, 0.1133, 0.1492],
                    [0.3434, 0.1311, 0.1311, 0.0501, 0.0601, 0.0792, 0.0792, 0.1043],
                    [0.3456, 0.1319, 0.1319, 0.0504, 0.0630, 0.0829, 0.0829, 0.1092]],
                   device='cuda:0', grad_fn=<CatBackward>)

        """
        m, n, k = self.outshape[0], self.outshape[1] * self.inshape[1], self.inshape[1] ** 2
        second_grad_list = []
        for i in range(n):
            # Selector matrix: ones in column i pick out the i-th column of the
            # per-sample Jacobian returned by ``self.mvjp``.
            v = torch.cat([torch.ones(m, 1) * (i == j) for j in range(n)], 1).cuda()
            hessian_vjp = vjp(self.mvjp, train_points, v, create_graph=True)[1]
            second_grad_list.append(hessian_vjp)
        second_grad = torch.cat(second_grad_list, 1)
        second_grad = torch.split(second_grad, k, 1)
        return second_grad
Example no. 3
    def mvjp(self, train_points):
        r"""Vector-Jacobian products (vjp) for multiple outputs. This function is
        as fast as the ``first_grads`` function.

        Args:
            train_points: input points of shape (batch, in_dim).
        Returns:
            [[(dy_1)^1/dx_1, (dy_1)^1/dx_2, (dy_1)^2/dx_1, (dy_1)^2/dx_2],
             [(dy_2)^1/dx_1, (dy_2)^1/dx_2, (dy_2)^2/dx_1, (dy_2)^2/dx_2],
             [(dy_3)^1/dx_1, (dy_3)^1/dx_2, (dy_3)^2/dx_1, (dy_3)^2/dx_2]]
        Example:
            >>> self.train_points
            tensor([[0.3191, 0.2468],
                    [0.0719, 0.2555],
                    [0.7084, 0.0403]], device='cuda:0')

            >>> mvjp(train_points)
            tensor([[ 0.2953, -0.2649, -0.0216,  0.4649],
                    [ 0.2564, -0.2301, -0.0218,  0.4693],
                    [ 0.4057, -0.3640, -0.0194,  0.4185]], device='cuda:0', grad_fn=<MmBackward>)

        """
        m, n = self.outshape[0], self.outshape[1]
        first_grad_list = []
        for i in range(n):
            # Selector matrix: ones in column i pick out output component i for
            # every sample in the batch.
            v = torch.cat([torch.ones(m, 1) * (i == j) for j in range(n)], 1).cuda()
            jacobian_vjp = vjp(self.mvjp_reducer, train_points, v, create_graph=True)[1]
            first_grad_list.append(jacobian_vjp)
        first_grad = torch.cat(first_grad_list, 1)

        return first_grad
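The same selector-vector trick written as a stand-alone sketch, with a hypothetical two-output net in place of self.mvjp_reducer:

import torch
from torch.autograd.functional import vjp

net = torch.nn.Linear(2, 2)
x = torch.rand(3, 2)
m, n = x.shape[0], 2

grads = []
for i in range(n):
    # Ones in column i select output component i for every sample.
    v = torch.cat([torch.ones(m, 1) * (i == j) for j in range(n)], 1)
    grads.append(vjp(net, x, v, create_graph=True)[1])
first_grad = torch.cat(grads, 1)  # shape (3, 4): [dy^1/dx_1, dy^1/dx_2, dy^2/dx_1, dy^2/dx_2]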
Example no. 4
def test_pruning_qp_match_slow_backward():
    """Test c++ vs pure python algo: backward pass"""
    depth = 3
    n_nodes = 2 ** (depth + 1) - 1

    torch.manual_seed(42)
    grad_ds = torch.randn(23, n_nodes)
    grad_ds /= torch.norm(grad_ds, dim=1).unsqueeze(1)

    def pruning_qp_sl_(q, eta):
        return pruning_qp_slow(q, eta, BinarySearchTree(depth))

    for seed in range(10):
        data = make_data(depth, seed=seed)

        for k in range(grad_ds.shape[0]):
            grad_d = grad_ds[k]
            _, (grad_q_fa, grad_eta_fa) = vjp(pruning_qp, data, grad_d)
            _, (grad_q_sl, grad_eta_sl) = vjp(pruning_qp_sl_, data, grad_d)
            assert torch.allclose(grad_q_fa, grad_q_sl)
            assert torch.allclose(grad_eta_fa, grad_eta_sl)
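The same vjp-based gradient check on a toy pair of implementations (hypothetical stand-ins for pruning_qp / pruning_qp_slow), just to show the pattern:

import torch
from torch.autograd.functional import vjp

def fast_impl(x):
    # Vectorised version.
    return (x ** 2).cumsum(0)

def slow_impl(x):
    # Naive loop computing the same prefix sums of squares.
    return torch.stack([(x[: i + 1] ** 2).sum() for i in range(x.shape[0])])

x = torch.randn(5)
grad_out = torch.randn(5)
_, grad_fast = vjp(fast_impl, x, grad_out)
_, grad_slow = vjp(slow_impl, x, grad_out)
assert torch.allclose(grad_fast, grad_slow)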
Example no. 5
def kmeans(max_iter, clusters, features, tolerance=1):
    """Update the cluster centres with Newton-like steps on the clustering
    cost, using vjp for the gradient and vhp for the curvature estimate."""
    t = 0
    converged = False
    hes_v = torch.ones_like(clusters)
    while t < max_iter and not converged:
        # Gradient of the scalar cost with respect to the cluster centres.
        _, jac = vjp(partial(cost, features), clusters, v=torch.tensor(1.0))
        # Hessian-vector product with a vector of ones (row sums of the Hessian).
        _, hes = vhp(partial(cost, features), clusters, v=hes_v)

        new_cluster = clusters - jac / hes
        converged = ((new_cluster - clusters) ** 2).sum() < tolerance
        clusters = new_cluster
        t += 1
    return clusters
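A hypothetical end-to-end use, assuming cost(features, clusters) is the scalar k-means objective (sum of squared distances to the nearest centre); none of these names come from the original snippet:

import torch
from functools import partial
from torch.autograd.functional import vjp, vhp

def cost(features, clusters):
    # Squared distance from every feature to every centre; keep the nearest.
    d = (features.unsqueeze(1) - clusters.unsqueeze(0)) ** 2
    return d.min(dim=1).values.sum()

features = torch.randn(200)
centres = kmeans(max_iter=20, clusters=torch.tensor([-1.0, 1.0]), features=features)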
Example no. 6
    def augmented_dynamics(t, aug_state, p, interpolation):
        # Adjoint (costate) part of the augmented state.
        adj = aug_state[:self._n_states]
        # Forward solution evaluated at time t from the dense interpolant.
        y = interpolation.sol(t)
        with torch.enable_grad():
            t_ = torch.as_tensor(t, dtype=torch.float)
            y_ = torch.as_tensor(y, dtype=torch.float)
            p_ = torch.as_tensor(p, dtype=torch.float)
            adj_ = torch.as_tensor(adj, dtype=torch.float)
            # Vector-Jacobian product of the RHS with the negative adjoint gives
            # the adjoint dynamics and the parameter-gradient contribution.
            dL_y, _, dL_p = functional.vjp(
                lambda y, t, p, tch=True: self._rhs(y, t, p, tch),
                (y_, t_, p_),
                -adj_,
                strict=False,
                create_graph=False)[1]
        return np.hstack(
            (dL_y.detach().numpy(), dL_p.detach().numpy()))
Example no. 7
    def backward(ctx, grad_t, grad_state):
        func = ctx.func
        event_fn = ctx.event_fn
        event_t, state_t = ctx.saved_tensors

        event_t = event_t.detach().clone().requires_grad_(True)
        state_t = state_t.detach().clone().requires_grad_(True)

        f_val = func(event_t, state_t)

        with torch.enable_grad():
            c, (par_dt, dstate) = vjp(event_fn, (event_t, state_t))

        # Total derivative of event_fn wrt t evaluated at event_t.
        dcdt = par_dt + torch.sum(dstate * f_val)

        # Add the gradient from final state to final time value as if a regular odeint was called.
        grad_t = grad_t + torch.sum(grad_state * f_val)

        dstate = dstate * (-grad_t / (dcdt + 1e-12)).reshape_as(c)

        grad_state = grad_state + dstate

        return None, None, None, grad_state