def forward(self, X, Y, train):
        # Initial representation B: two init layers with a diagonal scaling in between.
        B = self.layer_init(X)
        matrix1_temp = torch.diagflat(self.vec[0]).cuda()
        B = self.layer_init1(B.mm(matrix1_temp))
        if train:
            for i, l in enumerate(self.linears):
                X_temp = l(X)
                matrix_temp = torch.diagflat(self.vec[i + 1]).cuda()
                # Closed-form ridge solutions for the feature (P) and label (W) projections.
                P = torch.inverse(X_temp.t().mm(X_temp) + self.gamma * torch.eye(X_temp.size(1)).cuda()).mm(
                    X_temp.t()).mm(B.float())
                W = self.beta * torch.inverse(
                    self.beta * (Y.t().mm(Y)).float() + self.gamma * torch.eye(Y.size(1)).cuda().float()).mm(
                    Y.t().float()).mm(B.float())
                B_temp = (self.alpha * X_temp.mm(P) + self.beta * Y.mm(W)).mm(matrix_temp)
                B = torch.tanh(B_temp)
        else:
            # At inference time the labels Y are unavailable, so only the feature branch is used.
            X_temp = X
            for i, l in enumerate(self.linears):
                X_temp = l(X_temp)
                matrix_temp = torch.diagflat(self.vec[i + 1]).cuda()
                P = torch.inverse(X_temp.t().mm(X_temp) + self.gamma * torch.eye(X_temp.size(1)).cuda()).mm(
                    X_temp.t()).mm(B.float())
                B_temp = (self.alpha * X_temp.mm(P)).mm(matrix_temp)
                B = torch.tanh(B_temp)
        return B
Example #2
def perspective_projection(fov,near,far,filmSize=np.array([1,1]),cropSize=np.array([1,1]),cropOffset=np.array([0,0])):
    """[Reimplementation of Mitsuba perspective_projection function]

    Args:
        fov ([float]): [Field of view in degrees]
        near ([float]): [Near plane]
        far ([float]): [Far plane]
        filmSize ([1x2 ndarray], optional): [Film size]. Defaults to np.array([1,1]).
        cropSize ([1x2 ndarray], optional): [crop size]. Defaults to np.array([1,1]).
        cropOffset ([1x2 ndarray], optional): [Crop offset]. Defaults to np.array([0,0]).

    Returns:
        [4x4 tensor]: [Perspective camera projection matrix]
    """

    aspect = filmSize[0] / filmSize[1]
    rel_offset = cropOffset / filmSize
    rel_size = cropSize / filmSize
    p = perspective(fov,near,far)

    translate = torch.eye(4)
    translate[:3,-1] = torch.Tensor([-1.0, -1.0 / aspect, 0.0])
    scale = torch.diagflat(torch.Tensor([-0.5,-0.5*aspect,1.0,1.0]))
    translateCrop = torch.eye(4)
    translateCrop[:3,-1] = torch.Tensor([-rel_offset[0],-rel_offset[1],0.0])
    scaleCrop = torch.diagflat(torch.Tensor([1/rel_size[0],1/rel_size[1],1.0,1.0]))
    m1 = torch.mm(scaleCrop,torch.mm(translateCrop,torch.mm(scale,torch.mm(translate,p))))
    return m1
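A minimal usage sketch for the projection above; the argument values are hypothetical, and the companion perspective() helper called in the body is assumed to be importable from the same module.

import numpy as np
import torch

# Hypothetical camera parameters: a 2:1 film with no cropping.
proj = perspective_projection(
    fov=45.0, near=0.1, far=100.0,
    filmSize=np.array([2, 1]),
    cropSize=np.array([2, 1]),
    cropOffset=np.array([0, 0]),
)
print(proj.shape)  # expected torch.Size([4, 4])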
def entropic_OT(a, b, M, reg=0.1, maxiter=20, cuda=True):
    """
    Function which computes the autodiff sharp entropic OT loss.
    
    parameters:
        - a : input source measure (TorchTensor (ns))
        - b : input target measure (TorchTensor (nt))
        - M : ground cost between measure support (TorchTensor (ns, nt))
        - reg : entropic ragularization parameter (float)
        - maxiter : number of loop (int)
    
    returns:
        - sharp entropic unbalanced OT loss (float)
    """

    Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

    K = torch.exp(-M / reg).type(Tensor).double()
    v = torch.from_numpy(ot.unif(K.size()[1])).type(Tensor).double()

    for i in range(maxiter):
        Kv = torch.matmul(K, v)
        u = a / Kv
        Ku = torch.matmul(torch.transpose(K, 0, 1), u)
        v = b / Ku

    pi = torch.matmul(torch.diagflat(u), torch.matmul(K, torch.diagflat(v)))
    return torch.sum(pi * M.double())
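A small CPU usage sketch with hypothetical data; it assumes the POT package providing ot.unif is installed, since the function itself already relies on it for the target scaling vector.

import torch
import ot

ns, nt = 8, 6
xs, xt = torch.randn(ns, 2), torch.randn(nt, 2)
M = torch.cdist(xs, xt) ** 2                    # squared-Euclidean ground cost
a = torch.from_numpy(ot.unif(ns)).double()      # uniform source measure
b = torch.from_numpy(ot.unif(nt)).double()      # uniform target measure
loss = entropic_OT(a, b, M, reg=0.1, maxiter=50, cuda=False)
print(loss)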
    def get_maps(self, W, J):
        # compute operators
        N, bs = W.shape[1], self.batch_size
        deg = torch.sum(W.clone(), dim=2, keepdim=True).to(device)
        I = torch.eye((N), device=device).expand(bs, *(N, N))
        A = (W + I).clone().to(device)
        inv_sqrt_DD = torch.sqrt(
            1 / torch.sum(A.clone(), dim=2, keepdim=True)).to(device)

        # filling in ops
        OP = torch.zeros((bs, N, N, J + self.extra_ops), device=device)
        OP[:, :, :, 0] = I.clone()
        W_pow = (W.clone()).to(device)
        for j in range(J - 1):
            OP[:, :, :, j + 1] = (W_pow.clone()).to(device)
            if j == J - 2:
                break
            W_pow = (torch.min(torch.bmm(W_pow, W_pow),
                               torch.ones(*W.size(),
                                          device=device))).to(device)
        for k in range(bs):
            OP[k, :, :, J] = ((torch.diagflat(deg[k])).clone()).to(device)
            OP[k, :, :, J + self.extra_ops - 2] = ((torch.diagflat(
                inv_sqrt_DD[k])).clone()).to(device)
        OP[:, :, :, J + self.extra_ops - 1] = ((1.0 / float(N)) * torch.ones(
            (bs, N, N), device=device)).to(device)

        return OP.to(device), deg.unsqueeze(1).to(device)
Example #5
def multivariate_normal_kl(scale0, scale1, loc0, loc1):
    cov0 = torch.diagflat(scale0 ** 2)
    cov1 = torch.diagflat(scale1 ** 2)

    d0 = MultivariateNormal(loc0, covariance_matrix=cov0)
    d1 = MultivariateNormal(loc1, covariance_matrix=cov1)

    return kl_divergence(d0, d1)
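Usage sketch: KL between two diagonal Gaussians built from per-dimension scales; it assumes MultivariateNormal and kl_divergence were imported from torch.distributions, as the function body requires.

import torch
from torch.distributions import MultivariateNormal, kl_divergence

scale0, scale1 = torch.tensor([1.0, 2.0]), torch.tensor([1.5, 0.5])
loc0, loc1 = torch.zeros(2), torch.ones(2)
print(multivariate_normal_kl(scale0, scale1, loc0, loc1))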
Example #6
 def reparam(self):
     A, B, g, k, c = self.A_, torch.exp(self.B_), self.g_, torch.exp(
         self.k_) - 0.5, torch.sigmoid(self.c_) * 0.82
     S = torch.tril(self.C_, diagonal=-1) + torch.diagflat(
         self.C_.diag().exp())
     T = S.mm(S.t())
     D = torch.diagflat(T.diag().pow(-0.5))
     C = D.mm(S)
     return A, B, g, k, C, c
Example #7
def get_lossb(X, W, I):
    M1 = torch.mm(X.t(), W * I)
    M2 = torch.mm(X.t(), W * (1.0 - I))
    M1 -= torch.diagflat(torch.diagonal(M1))  # -j
    M2 -= torch.diagflat(torch.diagonal(M2))
    v1 = 1.0 / torch.mm(W.t(), I)
    v2 = 1.0 / torch.mm(W.t(), 1.0 - I)
    v1[v1 == float('inf')] = 0.0
    v2[v2 == float('inf')] = 0.0
    return torch.sum((M1 * v1 - M2 * v2)**2.0)
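A usage sketch with hypothetical shapes: X holds n samples of p features, I is an n x p binary indicator matrix, and W an n x 1 weight column that broadcasts over the indicator columns.

import torch

n, p = 6, 3
X = torch.randn(n, p)
I = torch.randint(0, 2, (n, p)).float()
W = torch.rand(n, 1)
print(get_lossb(X, W, I))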
Example #8
def match_histogram(target_tensor, source_tensor, eps=1e-2, mode="avg"):
    if mode == "avg":
        elementwise = True
        random_frame = False
    else:
        elementwise = False
        random_frame = True

    if not isinstance(source_tensor, list):
        source_tensor = [source_tensor]

    output_tensor = th.zeros_like(target_tensor)
    for source in source_tensor:
        target = target_tensor.permute(0, 3, 2, 1)  # Function expects b,w,h,c
        source = source.permute(0, 3, 2, 1)  # Function expects b,w,h,c
        if elementwise:
            source = source.mean(0).unsqueeze(0)
        if random_frame:
            source = source[np.random.randint(0, source.shape[0])].unsqueeze(0)

        matched_tensor = th.zeros_like(target)
        for idx in range(target.shape[0] if elementwise else 1):
            frame = target[idx].unsqueeze(0) if elementwise else target
            _, t, Ct = get_histogram(frame, eps)
            mu_s, _, Cs = get_histogram(source, eps)

            # PCA
            eva_t, eve_t = th.symeig(Ct, eigenvectors=True, upper=True)
            Et = th.sqrt(th.diagflat(eva_t))
            Et[Et != Et] = 0  # Convert nan to 0
            Qt = th.mm(th.mm(eve_t, Et), eve_t.T)

            eva_s, eve_s = th.symeig(Cs, eigenvectors=True, upper=True)
            Es = th.sqrt(th.diagflat(eva_s))
            Es[Es != Es] = 0  # Convert nan to 0
            Qs = th.mm(th.mm(eve_s, Es), eve_s.T)

            ts = th.mm(th.mm(Qs, th.inverse(Qt)), t)

            match = ts.reshape(*frame.permute(0, 3, 1, 2).shape).permute(
                0, 2, 3, 1)
            match += mu_s

            if elementwise:
                matched_tensor[idx] = match
            else:
                matched_tensor = match
        output_tensor += matched_tensor.permute(0, 3, 2,
                                                1) / len(source_tensor)
    return output_tensor
Example #9
def cem_optimize(init_mean, cost_func, init_variance=1., samples=400, precision=1.0e-3, steps=5, nelite=40, alpha=0.1,
                 constraint_mean=None, constraint_variance=(-999999, 999999), device="cpu"):
    """
    cem_optimize minimizes cost_function by iteratively sampling values around the current mean with a set variance.
    Of the sampled values the mean of the nelite number of samples with the lowest cost is the new mean for the next iteration.
    Convergence is met when either the change of the mean during the last iteration is less then precision.
    Or when the maximum number of steps was taken.
    :param init_mean: initial mean to sample new values around
    :param cost_func: varience used for sampling
    :param init_variance: initial variance
    :param samples: number of samples to take around the mean. Ratio of samples to elites is important.
    :param precision: if the change of mean after an iteration is less than precision convergence is met
    :param steps: number of steps
    :param nelite: number of best samples whose mean will be the mean for the next iteration
    :param alpha: softupdate, weight for old mean and variance
    :param constraint_mean: tuple with minimum and maximum mean
    :param constraint_variance: tuple with minumum and maximum variance
    :param device: either gpu or cpu (torch tensor configuration)
    :return:
    """
    control_time_diagnoser.start_log("average_cem_time")
    mean = init_mean
    covariance_matrices = torch.stack([torch.diagflat(torch.tensor([init_variance], device=device)) for _ in range(len(mean))])
    # print(mean.type(), variance.type())
    step = 0
    diff = 9999999
    while diff > precision and step <= steps:
        # we create a distribution with action dimensionality and a batch size corresponding the trajectory length
        # dist.batch_shape == trajectory_len, dist.event_shape == action_space_dim
        dist = distributions.MultivariateNormal(mean, covariance_matrix=covariance_matrices)
        candidates = dist.sample_n(samples).to(device)
        costs = cost_func(candidates)
        # we sort descending because we want a maximum reward
        sorted_idx = torch.argsort(costs, dim=0, descending=True)
        candidates = candidates[sorted_idx]
        elite = candidates[:nelite]
        new_mean = torch.mean(elite, dim=0)
        new_covariance_matrizies = torch.stack([torch.diagflat(v) for v in torch.var(elite, dim=0)])
        # calculate diff for break condition on precision
        diff = torch.mean(torch.abs(mean - new_mean))
        # softupdate mean and variance with alpha
        mean = (1 - alpha) * new_mean + alpha * mean
        covariance_matrices = (1 - alpha) * new_covariance_matrizies + alpha * covariance_matrices
        # print(mean, variance)
        if constraint_mean is not None:
            mean = clip(mean, constraint_mean[0], constraint_mean[1])
        step += 1
    control_time_diagnoser.end_log("average_cem_time")
    return mean
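A minimal usage sketch on a hypothetical problem: pull a length-3 trajectory of scalar actions toward a target by scoring candidates with a negated squared error, so the highest-scoring samples become the elites. It assumes the module-level helpers used above (control_time_diagnoser, clip, distributions) are already in scope.

import torch

target = torch.tensor([[1.0], [0.0], [-1.0]])    # desired 3-step trajectory, 1-dim actions

def score_func(candidates):
    # candidates: (samples, trajectory_len, 1); higher score ranks better,
    # since the optimizer keeps the top-ranked samples as elites.
    return -((candidates - target) ** 2).sum(dim=(1, 2))

init_mean = torch.zeros(3, 1)
best_mean = cem_optimize(init_mean, score_func, init_variance=1.0, samples=200, nelite=20)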
Example #10
def update_m(m_mat_lst,
             a_mat_lst,
             b_mat_lst,
             avg_psi_mat_lst,
             eta=1,
             diagonal=False):
    """
        This function updates the mean according to M = M - B*B^t*E[Psi]*A*A^t.
        :param m_mat_lst: m_mat_lst: A list of matrices in size of P*N.
        :param a_mat_lst: A list of matrices in size of N*N.
        :param b_mat_lst: A list of matrices in size of P*P.
        :param avg_psi_mat_lst: A list of matrices in size of P*N.
        :param eta: .
        :param diagonal: .
        :return:
    """
    if diagonal:
        for i in range(len(m_mat_lst)):
            # M = M - diag(B*B^t)*E[Psi]*diag(A*A^t)
            m_mat_lst[i].copy_(
                torch.add(
                    m_mat_lst[i], -eta,
                    torch.mm(
                        torch.mm(
                            torch.diagflat(
                                torch.diagonal(
                                    torch.mm(
                                        b_mat_lst[i],
                                        torch.transpose(b_mat_lst[i], 0, 1)))),
                            avg_psi_mat_lst[i]),
                        torch.diagflat(
                            torch.diagonal(
                                torch.mm(a_mat_lst[i],
                                         torch.transpose(a_mat_lst[i], 0,
                                                         1)))))))
    else:
        for i in range(len(m_mat_lst)):
            # M = M - B*B^t*E[Psi]*A*A^t
            m_mat_lst[i].copy_(
                torch.add(
                    m_mat_lst[i], -eta,
                    torch.mm(
                        torch.mm(
                            torch.mm(b_mat_lst[i],
                                     torch.transpose(b_mat_lst[i], 0, 1)),
                            avg_psi_mat_lst[i]),
                        torch.mm(a_mat_lst[i],
                                 torch.transpose(a_mat_lst[i], 0, 1)))))
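A single-matrix usage sketch with hypothetical sizes, showing the in-place update M <- M - eta * B*B^t*E[Psi]*A*A^t.

import torch

P, N = 3, 4
m_lst = [torch.randn(P, N)]
a_lst = [torch.eye(N)]
b_lst = [torch.eye(P)]
avg_psi_lst = [torch.randn(P, N)]
update_m(m_lst, a_lst, b_lst, avg_psi_lst, eta=0.5)   # m_lst[0] is modified in place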
Example #11
    def _softmax(self, a, dim=1):
        a = a - a.max(dim=1, keepdim=True)[0]
        infs = torch.diagflat(torch.tensor(self.N * [-1 * float('Inf')])).to(
            self.device)

        a = torch.exp(a + infs)
        return f.normalize(a)
Example #12
File: utils.py  Project: tyuioahxvm/EAE
def check():
    x = torch.tensor(np.random.randn(5, 2),
                     dtype=torch.float,
                     requires_grad=True)
    y = torch.log(x)
    j_t = torch.zeros(5, 2, 2)
    for i in range(5):
        j_t[i, :, :] = torch.diagflat(1 / x[i, :])

    print(j_t)
    j = compute_jacobian(x, y)
    print(j)
    # assert(j_t.data.numpy() == j.data.numpy())

    log_det_j_t = torch.log(torch.abs(1 / x[:, 0] * 1 / x[:, 1])).reshape(
        -1, 1)
    log_det_j = log_det_jacobian(x, y)

    print(log_det_j_t)
    print(log_det_j)
    # assert (log_det_j.data.numpy() == log_det_j_t.data.numpy())
    # print(log_det_j.shape)


#check()
#test()
Example #13
    def _get_sigma_kt(sigma):
        if isinstance(sigma, torch.Tensor):
            try:
                # tensor.item() works if tensor is a scalar, otherwise it throws
                # a value error.
                sigma.item()
                return sigma, "single"
            except ValueError:
                pass

            if sigma.dim() == 1 or sigma.shape[1] == 1:
                return torch.diagflat(sigma), "multi"

            if sigma.dim() != 2:
                raise TypeError(
                    "Sigma can be specified as a 1D or a 2D tensor. "
                    "Found %dD tensor" % (sigma.dim()))
            if sigma.shape[0] != sigma.shape[1]:
                raise TypeError("Sigma passed as a 2D matrix must be square. "
                                "Found dimensions %s" % (sigma.size()))
            return sigma, "multi"
        else:
            try:
                sigma = float(sigma)
                return torch.tensor(sigma, dtype=torch.float64), "single"
            except TypeError:
                raise TypeError("Sigma must be a scalar or a tensor.")
Example #14
    def predict(self, state, thompson_sampling=False):
        if not thompson_sampling:
            net1, net2 = self.network(state)
            net1 = net1.view(self.env.action_space.n, self.n_quantiles)
            net2 = net2.view(self.env.action_space.n, self.n_quantiles)
            action_means = torch.mean((net1 + net2) / 2, dim=1)
            action = action_means.argmax().item()
        else:
            net1, net2 = self.network(state)
            net1_target, net2_target = self.target_network(state)
            net1 = net1.view(self.env.action_space.n, self.n_quantiles)
            net2 = net2.view(self.env.action_space.n, self.n_quantiles)
            net1_target = net1_target.view(self.env.action_space.n,
                                           self.n_quantiles)
            net2_target = net2_target.view(self.env.action_space.n,
                                           self.n_quantiles)
            action_means = torch.mean((net1 + net2) / 2, dim=1)
            action_uncertainties = torch.mean(
                (net1_target - net2_target)**2, dim=1) / 2
            samples = torch.distributions.multivariate_normal.MultivariateNormal(
                action_means,
                covariance_matrix=torch.diagflat(
                    action_uncertainties)).sample()
            #print(samples)
            action = samples.argmax().item()
            if action == action_means.argmax().item():
                self.n_greedy_actions += 1

        #print(action_means,torch.sqrt(action_uncertainties))
        #if self.logging:
        #    self.logger.add_scalar('Uncertainty', action_uncertainties.mean().item(), self.timestep)
        return action
def conv3_l2_reg_orthogonal(mdl, device):
    """
    Regularizer that encourages each weight matrix to be orthogonal (not orthonormal).
    This is to analyze only the effect of orthogonality, not of orthonormal vectors.
    """
    l2_reg = None
    for name, module in mdl.named_children():
        if 'layer' in name:
            for m in module:
                W = m.conv3.weight
                cols = W[0].numel()
                rows = W.shape[0]
                w1 = W.view(-1, cols)
                wt = torch.transpose(w1, 0, 1)
                if rows > cols:
                    mat = torch.matmul(wt, w1)
                else:
                    mat = torch.matmul(w1, wt)

                # Zero the diagonal so only the off-diagonal (non-orthogonality) terms remain.
                w_tmp = mat - torch.diagflat(torch.diagonal(mat))
                b_k = Variable(torch.rand(w_tmp.shape[1], 1)).type(
                    torch.HalfTensor).to(device)

                v1 = torch.matmul(w_tmp, b_k)
                norm1 = torch.norm(v1, 2)
                v2 = torch.div(v1, norm1)
                v3 = torch.matmul(w_tmp, v2)

                if l2_reg is None:
                    l2_reg = (torch.norm(v3, 2))**2
                else:
                    l2_reg = l2_reg + (torch.norm(v3, 2))**2

    return l2_reg
Example #16
def test_lu_affine_transform():
    d = 5
    batch_size = 64
    for _ in range(100):
        lower = (torch.randn(d, d).tril(-1) + torch.eye(d)).double()
        upper = torch.randn(d, d).triu(1).double()
        diag = torch.randn(d).double()
        bias = torch.randn(d).double()
        w = lower @ (upper + torch.diagflat(diag))

        t = LUAffineTransform(lower, upper, diag, bias)

        x = torch.randn(batch_size, d).double()
        y = x @ w.t() + bias

        assert_array_equal(t.w, w)
        assert_array_equal(t.w_inv, w.inverse())

        assert_array_equal(t(x), y)
        assert_array_almost_equal(t.inv(torch.tensor(y)), x, 5)

        assert_array_almost_equal(
            t.log_abs_det_jacobian(x, y).item(),
            torch.log(torch.abs(torch.det(w))),
            decimal=5,
        )
Example #17
def _kl_fcmvg_ffmvg(q, p):
    """
    Computes the KL divergence between a full covariance multivariate gaussian (FCMVG as q) and a fully factorized
    multivariate Gaussian (FFMVG as p)
    Args:
        q (FullCovarianceMultivariateGaussian):
        p: fully factorized multivariate Gaussian

    Returns:
        KL(q || p)
    """

    lower_triang_q = torch.tril(q.cov_lower_triangular, -1)
    lower_triang_q += torch.diagflat(torch.exp(q.logvars))

    kl = 0.5 * (
        torch.sum(p.logvars) -
        2.0 * torch.sum(torch.log(torch.diag(lower_triang_q))) +
        torch.sum(torch.mul(torch.pow(q.mean - p.mean, 2),
                            (-p.logvars).exp())) +
        torch.sum(
            torch.diag(
                (-p.logvars).exp() *
                torch.matmul(lower_triang_q, lower_triang_q.t()))) - q.n)

    return kl
Example #18
def compute_normalized_laplacian(adj):
    rowsum = torch.sum(adj, -1)
    d_inv_sqrt = torch.pow(rowsum, -0.5)
    d_inv_sqrt[torch.isinf(d_inv_sqrt)] = 0.
    d_mat_inv_sqrt = torch.diagflat(d_inv_sqrt)
    L_norm = torch.mm(torch.mm(d_mat_inv_sqrt, adj), d_mat_inv_sqrt)
    return L_norm
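Usage sketch on a small dense adjacency matrix, producing the symmetric normalization D^{-1/2} A D^{-1/2}.

import torch

adj = torch.tensor([[0., 1., 1.],
                    [1., 0., 0.],
                    [1., 0., 0.]])
print(compute_normalized_laplacian(adj))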
    def sample(self, n_samples):
        device = self.mean.device
        samples = torch.zeros(n_samples,
                              self.n,
                              self.m,
                              dtype=self.dtype,
                              device=device,
                              requires_grad=False)

        epsilon_for_log_diag_W_sample = torch.randn(self.nmc,
                                                    self.in_features,
                                                    self.out_features,
                                                    device=device,
                                                    dtype=self.dtype,
                                                    requires_grad=False)
        epsilon_for_low_rank_sample = torch.randn(self.nmc,
                                                  self.rank,
                                                  self.out_features,
                                                  device=device,
                                                  dtype=self.dtype,
                                                  requires_grad=False)
        for i in range(self.m):
            samples[:, :, i] = (
                (torch.matmul(self.cov_low_rank[i, :, :],
                              epsilon_for_low_rank_sample[:, :, i].t()).t() +
                 torch.matmul(
                     torch.diagflat(torch.exp(self.logvars[i, :] / 2.0)),
                     epsilon_for_log_diag_W_sample[:, :, i].t()).t()) +
                self.mean[:, i])

        return samples
Example #20
    def test_rpc_jit(self, num_clients=2, address="127.0.0.1:12346"):
        def run_client(client_id):
            client = postman.Client(address)
            client.connect(10)
            arg = np.full((1, 2), client_id, dtype=np.float32)
            batched_arg = np.full((2,), client_id, dtype=np.float32)

            function_result = client.function(arg)
            batched_function_result = client.batched_function(batched_arg)

            np.testing.assert_array_equal(function_result, np.full((1, 2), client_id))
            np.testing.assert_array_equal(batched_function_result, np.full((2,), client_id))

        clients = [mp.Process(target=run_client, args=(i,)) for i in range(num_clients)]

        linear = torch.nn.Linear(2, 2, bias=False)
        linear.weight.data = torch.diagflat(torch.ones(2))
        module = torch.jit.script(linear)
        server = postman.Server("127.0.0.1:12346")

        server.bind("function", module)
        server.bind("batched_function", module, batch_size=num_clients)

        server.run()

        for p in clients:
            p.start()

        for p in clients:
            p.join()

        server.stop()
Example #21
        def like(m):
            # A = self._softmax(torch.mm(self.P, self.P.t()))
            n = len(m) - 1
            a = list(m) + list(set(range(self.N)) - set(m))
            index = torch.LongTensor(a).to(self.device)

            # t = A[index][:, index]
            # t = self._softmax(self._symmetrize(self.A))[index][:,index]
            if self.sym:
                t = f.normalize(torch.exp(
                    self._symmetrize(self.P[index][:, index])),
                                p=1)
                # t = self._softmax(self._symmetrize(self.P))[index][:, index]
            else:
                t = f.normalize(torch.exp(self.P[index][:, index]), p=1)
            t = t - torch.diagflat(torch.diagonal(t))
            q_ma = (self.j * self.i)[:n]
            r_ma = ((1 - self.i) * self.j)[:n]  # -\
            # (self.j*self.i)[n].flip(1)*self.j[:n]
            q = t * q_ma
            r = t * r_ma
            p = torch.matmul(
                torch.inverse(torch.eye(self.N).to(self.device) - q), r)
            lik = p.diagonal().diagonal(offset=-1).log().sum()

            return lik
 def __init__(self, mu, logvar, covar_add_encoder_vars=True):
     self.covar = empirical_covar(mu)
     if covar_add_encoder_vars:
         self.covar += torch.diagflat(logvar.exp().mean(0))
     self.mean = mu.mean(0)
     self.gaussian = scipy.stats.multivariate_normal(
         self.mean.numpy(), self.covar.numpy())
Example #23
def cubic(
    input: torch.Tensor, value: torch.Tensor, domain: Tuple[float, float] = (0, 1)
) -> torch.Tensor:
    n = value.size(0) - 1
    h = (domain[1] - domain[0]) / n
    A = torch.eye(n + 1) + torch.diagflat(torch.full((n,), 0.5), 1)
    A += A.T
    A[0, 1] = A[-1, -2] = 0
    d = 3 * (value[2:] - 2 * value[1:-1] + value[:-2]) / h ** 2
    d = torch.cat((torch.zeros(1), d, torch.zeros(1))).unsqueeze_(-1)
    z, _ = torch.solve(d, A)

    sampler = VectorSampler(input, domain, n)
    x = torch.linspace(
        domain[0], domain[1], n + 1, dtype=torch.float32, device=input.device
    )
    distance_left = input - sampler.get_left(x)
    distance_right = h - distance_left
    cubic_left = torch.pow(distance_left, 3)
    cubic_right = torch.pow(distance_right, 3)

    z_left = sampler.get_left(z)
    z_right = sampler.get_right(z)
    value_left = sampler.get_left(value)
    value_right = sampler.get_right(value)

    f = z_left * cubic_right + z_right * cubic_left
    f /= 6 * h
    f += (value_right / h - z_right * h / 6) * distance_left
    f += (value_left / h - z_left * h / 6) * distance_right
    return f
Example #24
    def forward(self, adj, h, number_nodes=None):
        for layer in self.gcnlist:
            h = layer(adj, h)
        if self.pooling == 'sum':
            h = torch.sum(h, dim=-2)
        elif self.pooling == 'average':  # number_nodes carries the information of how many padded (None) nodes there are
            if self.avg_type == 0:  # in effect this reduces to a scaled sum, which can also work well
                h = torch.sum(h, dim=-2)
                num_nodes = 1 / torch.sum(number_nodes, dim=-1)
                num = torch.diagflat(num_nodes)
                h = torch.matmul(num, h)
            else:  # true average over the valid nodes
                num_nodes = 1 / torch.sum(number_nodes, dim=-1)
                mask = torch.matmul(num_nodes.unsqueeze(-1), number_nodes)
                h = torch.matmul(mask, h).squeeze(-2)

        elif self.pooling == 'max':
            h, _ = torch.max(h, dim=-2)
        elif self.pooling == 'attention':
            e = torch.matmul(self.P, torch.transpose(h, -1, -2))
            e = F.relu(e)
            e1 = -9e15 * torch.ones_like(e)
            e = torch.where(number_nodes >= 0, e, e1)
            e = F.softmax(e, dim=-1)
            h = torch.matmul(e, h)
            h = torch.squeeze(h, -2)
        else:  # acts purely as a feature extractor, no prediction needed
            return h  # used as the input of the next layer

        pred = self.pred(h)
        #pred = F.softmax(pred,dim=-1)

        return pred
Example #25
    def sample_local_reparam_linear(self, n_sample: int,
                                    in_data: torch.Tensor):
        # Retrieve current device (useful generate the data already in the correct device)
        device = self.mean.device
        epsilon_for_Y_sample = torch.randn(
            n_sample,
            in_data.size(-2),
            self.mean.size(1),
            dtype=self.dtype,
            device=device,
            requires_grad=False)  # type: torch.Tensor

        mean_Y = torch.matmul(in_data, self.mean)

        var_Y = torch.zeros_like(mean_Y, device=device) * torch.ones(
            n_sample, 1, 1, device=device)
        for i in range(self.m):
            cov_lower_triangular = torch.tril(
                self.cov_lower_triangular[i, :, :], -1)
            L_chol = cov_lower_triangular + torch.diagflat(
                torch.exp(self.logvars[:, i]))

            var_Y[:, :, i] = torch.sum(torch.matmul(in_data, L_chol)**2, -1)

        Y = mean_Y + torch.sqrt(
            var_Y + 1e-5) * epsilon_for_Y_sample  # type: torch.Tensor
        return Y
Example #26
    def sample(self, n_samples: int):
        epsilon_for_samples = torch.randn(n_samples,
                                          self.n,
                                          self.m,
                                          dtype=self.dtype,
                                          device=self.mean.device,
                                          requires_grad=False)

        samples = torch.zeros_like(epsilon_for_samples,
                                   device=self.mean.device)

        for i in range(self.m):
            cov_lower_triangular = torch.tril(
                self.cov_lower_triangular[i, :, :], -1)
            cov_lower_triangular += torch.diagflat(
                torch.exp(self.logvars[:, i]))
            samples[:, :, i] = torch.add(
                torch.matmul(cov_lower_triangular,
                             epsilon_for_samples[:, :, i].t()).t(),
                self.mean[:, i])

        samples = torch.add(
            torch.mul(epsilon_for_samples, torch.exp(self.logvars / 2.0)),
            self.mean)  # type: torch.Tensor
        return samples
Example #27
    def test_hessian_simple(self, device):
        def foo(x):
            return x.sin().sum()

        x = torch.randn(3, device=device)
        y = jacrev(jacrev(foo))(x)
        expected = torch.diagflat(-x.sin())
        assert torch.allclose(y, expected)
Example #28
def tc_loss(zs, m):
    means = zs.mean(0).unsqueeze(0)
    res = ((zs.unsqueeze(2) - means.unsqueeze(1))**2).sum(-1)
    pos = torch.diagonal(res, dim1=1, dim2=2)
    offset = torch.diagflat(torch.ones(zs.size(1))).unsqueeze(0).cuda() * 1e6
    neg = (res + offset).min(-1)[0]
    loss = torch.clamp(pos + m - neg, min=0).mean()
    return loss
Example #29
 def sym(self, t, c, s):
     p = self.pca(t, c)
     psp = torch.mm(torch.mm(p, s), p)
     eval_psp, evec_psp = torch.symeig(psp, eigenvectors=True, upper=True)
     e = self.nan2zero(torch.sqrt(torch.diagflat(eval_psp)))
     evec_mm = torch.mm(torch.mm(evec_psp, e), evec_psp.T)
     return torch.mm(
         torch.mm(torch.mm(torch.inverse(p), evec_mm), torch.inverse(p)), t)
Example #30
def normalize_t(mx):
    """Row-normalize sparse matrix in tensor"""
    rowsum = torch.sum(mx, dim=1)
    r_inv = torch.pow(rowsum, -1).flatten()
    r_inv[torch.isinf(r_inv)] = 0.
    r_mat_inv = torch.diagflat(r_inv)
    mx = torch.mm(r_mat_inv, mx)
    return mx
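Usage sketch: row-normalize a small dense matrix so each row sums to 1; an all-zero row stays zero thanks to the inf guard.

import torch

mx = torch.tensor([[1., 3.],
                   [0., 0.]])
print(normalize_t(mx))   # tensor([[0.2500, 0.7500], [0.0000, 0.0000]])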