Example #1
def update_tril(entries, D):
    # assumes: from torch.nn.functional import softplus
    tril = torch.zeros(D, D)
    # softplus keeps the diagonal strictly positive
    tril[range(D), range(D)] = softplus(entries[0:D])
    idx = torch.tril_indices(D, D)  # compute the index grid once
    off_idx = idx[0] != idx[1]      # mask out the diagonal
    a, b = idx[:, off_idx]
    tril[a, b] = entries[D:]
    return tril
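A minimal usage sketch, assuming softplus comes from torch.nn.functional (the function above uses it unqualified):

import torch
from torch.nn.functional import softplus  # assumed import for update_tril

D = 3
entries = torch.randn(D + D * (D - 1) // 2)  # 3 diagonal + 3 off-diagonal entries
L = update_tril(entries, D)
assert torch.all(L.diag() > 0)    # softplus keeps the diagonal positive
assert torch.all(L.triu(1) == 0)  # upper triangle stays zero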
Example #2
    def get_sparse_config(self, in_dim, out_dim, sparsity_level):
        '''Get num_diagonals and num_coeffs.

        Given the dimensions of the matrix:
          in_dim: number of columns
          out_dim: number of rows

        We want to find the right diagonal shift "d" s.t.

        N(d) < thr(desired sparsity) <= N(d-1)

        We search as follows:
        - If: N(0) is below thr: try N(n) for n = -1..-out_dim
        - Else: try N(n) for n = 1..in_dim

        input: 2 dimensions of the weight matrix
        output: tuple (num_diagonal, num_coeff)
        '''

        total_el = in_dim * out_dim
        thr = int(total_el * (1 - sparsity_level))  # just truncate fraction.

        for num_diag in range(out_dim):  # shifted lower-triangular pattern.
            non_zeros = torch.tril_indices(out_dim, in_dim,
                                           -num_diag).size()[1]
            if non_zeros < thr:
                break

        print(f"sparsity: {(total_el - non_zeros) / total_el * 100 :.1f} %"
              f" vs. desired sparsity {sparsity_level * 100} %")
        return non_zeros, num_diag
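A quick standalone check of the count N(d) the loop relies on: N(d) is the number of entries on or below the diagonal shifted down by d, and it shrinks as d grows.

import torch

out_dim, in_dim = 4, 6
for d in range(out_dim):
    n_d = torch.tril_indices(out_dim, in_dim, -d).size(1)
    print(d, n_d)  # prints 10, 6, 3, 1 for a 4 x 6 matrix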
Example #3
 def synchronize(self):
     for h in self.handles:
         hvd.synchronize(h)
     if self.merge:
         self._tensor_group.pull_alltensors()
         self._tensor_group.clear_group_flags()
     for name in self._name_tensors:
         tensor, comm_tensor = self._name_tensors[name]
         if self.symmetric:
             if self.fp16:
                 comm_tensor = comm_tensor.float()
             lower_indices = torch.tril_indices(tensor.shape[0],
                                                tensor.shape[1],
                                                device=tensor.device)
             upper_indices = torch.triu_indices(tensor.shape[0],
                                                tensor.shape[1],
                                                device=tensor.device)
             tensor[upper_indices[0], upper_indices[1]] = comm_tensor
             tensor[lower_indices[0],
                    lower_indices[1]] = tensor.t()[lower_indices[0],
                                                   lower_indices[1]]
         else:
             if self.fp16:
                 comm_tensor = comm_tensor.float()
             tensor.copy_(comm_tensor)  # copy back whether or not fp16 is on
         if self.op == hvd.Average:
             tensor.div_(hvd.size())
     self._name_tensors.clear()
     self.handles.clear()
Example #4
def _assemble_tril(diag: torch.Tensor,
                   lower_vec: torch.Tensor) -> torch.Tensor:
    dim = diag.shape[-1]
    L = torch.diag_embed(diag)  # L is lower-triangular
    i, j = torch.tril_indices(dim, dim, offset=-1)
    L[..., i, j] = lower_vec
    return L
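A brief sketch of how such a factor is used: a positive diagonal plus dim*(dim-1)/2 strictly-lower entries gives a valid Cholesky factor (the values here are illustrative).

import torch

diag = torch.ones(3)                       # positive diagonal => valid Cholesky factor
lower_vec = torch.randn(3 * (3 - 1) // 2)  # the 3 strictly-lower entries
L = _assemble_tril(diag, lower_vec)
cov = L @ L.T                              # symmetric positive definite
assert torch.allclose(torch.linalg.cholesky(cov), L, atol=1e-6)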
Example #5
 def _call(self, x):
     dim = int((-1 + math.sqrt(1 + 8 * x.shape[0])) / 2)
     tril = torch.zeros((dim, dim), dtype=x.dtype)
     tril_indices = torch.tril_indices(row=dim, col=dim, offset=0)
     tril[tril_indices[0], tril_indices[1]] = x
     tril[range(dim), range(dim)] = tril.diag().exp()
     return tril
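The dim computed above inverts the triangular-number identity n = dim * (dim + 1) / 2; a quick round-trip check:

import math

for dim in range(1, 8):
    n = dim * (dim + 1) // 2  # number of entries in the lower triangle
    assert dim == int((-1 + math.sqrt(1 + 8 * n)) / 2)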
Example #6
 def forward(self, state):
     """
     forwards input through the network
     :param state: (B, ds)
     :return: mean vector (B, da) and cholesky factorization of covariance matrix (B, da, da)
     """
     device = state.device
     B = state.size(0)
     ds = self.ds
     da = self.da
     action_low = torch.from_numpy(self.env.action_space.low)[None, ...].to(
         device)  # (1, da)
     action_high = torch.from_numpy(
         self.env.action_space.high)[None, ...].to(device)  # (1, da)
     x = F.relu(self.lin1(state))
     x = F.relu(self.lin2(x))
     mean = torch.sigmoid(self.mean_layer(x))  # (B, da)
     mean = action_low + (action_high - action_low) * mean
     cholesky_vector = self.cholesky_layer(x)  # (B, (da*(da+1))//2)
     cholesky_diag_index = torch.arange(da, dtype=torch.long) + 1
     cholesky_diag_index = (cholesky_diag_index *
                            (cholesky_diag_index + 1)) // 2 - 1
     cholesky_vector[:, cholesky_diag_index] = F.softplus(
         cholesky_vector[:, cholesky_diag_index])
     tril_indices = torch.tril_indices(row=da, col=da, offset=0)
     cholesky = torch.zeros(size=(B, da, da),
                            dtype=torch.float32).to(device)
     cholesky[:, tril_indices[0], tril_indices[1]] = cholesky_vector
     return mean, cholesky
Example #7
def _unflatten_tril(x):
    """Unflattens a vector into a lower triangular matrix of shape `dim x dim`."""
    n, dim = x.shape
    idxs = torch.tril_indices(dim, dim)
    tril = torch.zeros(n, dim, dim)
    tril[:, idxs[0, :], idxs[1, :]] = x
    return tril
Example #8
def lower_vector_to_matrix(
    lower_flat: torch.Tensor, matrix_dim: Optional[int] = None
) -> torch.Tensor:
    """Convert a valid vector to a lower triangular matrix.

    Parameters
    ----------
    lower_flat : torch.Tensor
        Flattened lower-triangular values (diagonal included).
    matrix_dim : Optional[int]
        Dimension of the output matrix; inferred from the vector length if None.

    Returns
    -------
    torch.Tensor

    """
    shape = lower_flat.shape

    if matrix_dim is None:
        # (N, N) matrix has L =  N * (N + 1) / 2 values in its
        # lower triangular region (including diagonals).

        # N = 0.5 * sqrt(8 * L + 1)**0.5 - 0.5
        matrix_dim = int(0.5 * (8 * shape[-1] + 1) ** 0.5 - 0.5)

    matrix_shape = shape[:-1] + (matrix_dim, matrix_dim)

    lower = torch.zeros(matrix_shape)
    lower_idx = torch.tril_indices(matrix_dim, matrix_dim)
    lower[..., lower_idx[0], lower_idx[1]] = lower_flat
    return lower
Example #9
def symmetry(x, mode="real"):
    center = (x.shape[1]) // 2

    diag1 = torch.arange(center, x.shape[1])
    diag2 = torch.arange(center, x.shape[1])
    diag_indices = torch.stack((diag1, diag2))
    grid = torch.tril_indices(x.shape[1], x.shape[1], -1)

    x_sym = torch.cat(
        (grid[0].reshape(-1, 1), diag_indices[0].reshape(-1, 1)),
    )
    y_sym = torch.cat(
        (grid[1].reshape(-1, 1), diag_indices[1].reshape(-1, 1)),
    )
    x = torch.rot90(x, 1, dims=(1, 2))
    i = center + (center - x_sym)
    j = center + (center - y_sym)
    u = center - (center - x_sym)
    v = center - (center - y_sym)
    if mode == "real":
        x[:, i, j] = x[:, u, v]
    if mode == "imag":
        x[:, i, j] = -x[:, u, v]
    return torch.rot90(x, 3, dims=(1, 2))
Example #10
    def __init__(self, input_dim, hidden_dim, output_dim, K):
        super(MDN, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.K = K
        self.tril_indices = torch.tril_indices(row=output_dim,
                                               col=output_dim,
                                               offset=-1)
        # initialize linear transformations
        self.lin_input_to_hidden = nn.Linear(input_dim, hidden_dim)
        self.lin_hidden_to_hidden = nn.Linear(hidden_dim, hidden_dim)
        self.lin_hidden_to_mix_components = nn.Linear(hidden_dim, K)
        self.lin_hidden_to_loc = nn.Linear(hidden_dim, output_dim * K)
        self.lin_hidden_to_offdiag = nn.Linear(hidden_dim, K)
        self.lin_hidden_to_sigma = nn.Linear(hidden_dim, output_dim * K)

        # initialize non-linearities
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)
        self.softplus = nn.Softplus()
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.bn2 = nn.BatchNorm1d(hidden_dim)
        self.dropout1 = nn.Dropout(p=0.5)
        self.dropout2 = nn.Dropout(p=0.3)
Example #11
def pt_meddistance(X, subsample=None, seed=283):
    """
    Compute the median of pairwise distances (not distance squared) of points
    in the matrix.  Useful as a heuristic for setting Gaussian kernel's width.

    Parameters
    ----------
    X : n x d torch tensor

    Return
    ------
    median distance (a scalar, not a torch tensor)
    """
    n = X.shape[0]
    if subsample is None:
        D = torch.sqrt(pt_dist2_matrix(X, X))
        I = torch.tril_indices(n, n, -1)
        Tri = D[I[0], I[1]]
        med = torch.median(Tri)
        return med.item()
    else:
        assert subsample > 0
        with NumpySeedContext(seed=seed):
            ind = np.random.choice(n, min(subsample, n), replace=False)
        # recursion just once
        return pt_meddistance(X[ind], None, seed=seed)
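A usage sketch; pt_dist2_matrix (squared pairwise distances) and NumpySeedContext come from the same project and are assumed importable.

import torch

X = torch.randn(200, 5)                  # 200 points in 5 dimensions
width = pt_meddistance(X)                # plain float, e.g. a Gaussian kernel width
cheap = pt_meddistance(X, subsample=50)  # estimate from a random 50-point subset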
Example #12
    def _interaction(self, bottom_mlp_output, embedding_outputs, batch_size):
        """Interaction

        "dot" interaction is a bit tricky to implement and test. Break it out from forward so that it can be tested
        independently.

        Args:
            bottom_mlp_output (Tensor):
            embedding_outputs (list): Sequence of tensors
            batch_size (int):
        """
        concat = torch.cat([bottom_mlp_output] + embedding_outputs, dim=1)
        if self._interaction_op == "dot" and not self._self_interaction:
            concat = concat.view((batch_size, -1, self._embedding_dim))
            if concat.dtype == torch.half:
                interaction_output = dotBasedInteract(concat, bottom_mlp_output)
            else:  # Legacy path
                interaction = torch.bmm(concat, torch.transpose(concat, 1, 2))
                tril_indices_row, tril_indices_col = torch.tril_indices(
                    interaction.shape[1], interaction.shape[2], offset=-1)
                interaction_flat = interaction[:, tril_indices_row, tril_indices_col]

                # concatenate dense features and interactions
                zero_padding = torch.zeros(
                    concat.shape[0], 1, dtype=concat.dtype, device=concat.device)
                interaction_output = torch.cat((bottom_mlp_output, interaction_flat, zero_padding), dim=1)

        elif self._interaction_op == "cat":
            interaction_output = concat
        else:
            raise NotImplementedError

        return interaction_output
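A small sketch of the "dot" interaction count: with F feature vectors, offset=-1 keeps exactly the F*(F-1)/2 unique pairs, excluding self-interactions.

import torch

B, F_feat, D = 2, 4, 3
concat = torch.randn(B, F_feat, D)
inter = torch.bmm(concat, concat.transpose(1, 2))    # (B, F, F) pairwise dot products
r, c = torch.tril_indices(F_feat, F_feat, offset=-1)
flat = inter[:, r, c]                                # unique pairs only
assert flat.shape == (B, F_feat * (F_feat - 1) // 2)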
Example #13
    def synchronize(self):
        self.merged_comm.synchronize()
        for h in self.handles:
            handle, names, tensors, comm_tensors, rank = h

            if rank != hvd.rank():
                continue

            name = ','.join(names)
            offset = 0
            buf = self.merged_tensors[name]

            if self.fp16:
                buf = buf.float()

            for i, t in enumerate(tensors):
                numel = comm_tensors[i].numel()
                comm_tensor = buf.data[offset:offset+numel]

                if self.symmetric:
                    lower_indices = torch.tril_indices(t.shape[0], t.shape[1], device=t.device)
                    upper_indices = torch.triu_indices(t.shape[0], t.shape[1], device=t.device)
                    t[upper_indices[0], upper_indices[1]] = comm_tensor
                    t[lower_indices[0], lower_indices[1]] = t.t()[lower_indices[0], lower_indices[1]]
                else:
                    t.copy_(comm_tensor.view(t.shape))
                t.div_(hvd.size())
                offset += numel 
        self.handles.clear()
Example #14
 def rsample(self, sample_shape=torch.Size()):
     """
     References
     ----------
     - Sawyer, S. (2007). Wishart Distributions and Inverse-Wishart Sampling.
       https://www.math.wustl.edu/~sawyer/hmhandouts/Wishart.pdf
     - Anderson, T. W. (2003). An Introduction to Multivariate Statistical Analysis (3rd ed.).
       John Wiley & Sons, Inc.
     - Odell, P. L. & Feiveson, A. H. (1966). A Numerical Procedure to Generate a Sample
       Covariance Matrix. Journal of the American Statistical Association, 61(313):199-203.
      - Ku, Y.-C. & Bloomfield, P. (2010). Generating Random Wishart Matrices with Fractional
        Degrees of Freedom in Ox.
     """
     shape = torch.Size(sample_shape) + self.batch_shape
     dtype, device = self.concentration.dtype, self.concentration.device
     D = self.event_shape[-1]
     df = 2. * self.concentration  # type: torch.Tensor
     i = torch.arange(D, dtype=dtype, device=device)
     concentration = .5 * (df.unsqueeze(-1) - i).expand(shape + (D, ))
     V = 2. * torch._standard_gamma(concentration)
     N = torch.randn(*shape, D * (D - 1) // 2, dtype=dtype, device=device)
     T = torch.diag_embed(V.sqrt())  # T is lower-triangular
     i, j = torch.tril_indices(D, D, offset=-1)
     T[..., i, j] = N
     M = self.scale_tril @ T
     W = M @ M.transpose(-2, -1)
     return W
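The sampler above is Bartlett's decomposition: chi-squared marginals on the diagonal of T, standard normals strictly below it. A standalone sketch of one draw with identity scale and a hypothetical df = 5:

import torch

D, df = 3, 5.
i = torch.arange(D, dtype=torch.float)
V = 2. * torch._standard_gamma(.5 * (df - i))  # chi^2(df - i) diagonal marginals
T = torch.diag_embed(V.sqrt())
r, c = torch.tril_indices(D, D, offset=-1)
T[r, c] = torch.randn(D * (D - 1) // 2)
W = T @ T.T  # one Wishart(df, I) draw; averaging many should approach df * I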
Example #15
    def __init__(self, block_size, offset=0, inverse=True, **kwargs):
        """Applies repeated lower-triangular block diagonal matrices.
        Let H be the size of a lower-triangular block diagonal matrix A.
        This layer applies:
        [
            A, 0, 0, ...
            0, A, 0, ...
            0, 0, A, ...
            ., ., ., ...
        ]

        Args:
            block_size (int): Block size H
            offset (int, optional): Offset of A along the diagonal. Defaults to 0.
            inverse (bool, optional): Specifies which direction should be modelled directly.
                The matrix for the inverse A^-1 is computed with torch.inverse. Defaults to True.
        """
        super().__init__()
        self.block_size = block_size
        self.n_params = (block_size**2 - block_size) // 2 + block_size
        self.params = nn.Parameter(torch.zeros(self.n_params))
        mask = torch.zeros((block_size, block_size), dtype=torch.bool)
        idx = torch.tril_indices(block_size, block_size)
        mask[tuple(idx)] = 1
        self.mask = nn.Parameter(mask, requires_grad=False)
        diag_idx = (torch.arange(block_size) + 1).cumsum(0) - 1
        self.diag_idx = nn.Parameter(diag_idx, requires_grad=False)
        self.offset = offset % self.block_size
        self._inverse = inverse
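__init__ only stores the mask and the flat parameter vector; a hypothetical sketch of how a forward pass could scatter them into one dense block A:

import torch

block_size = 3
n_params = (block_size ** 2 - block_size) // 2 + block_size  # 6 for a 3 x 3 block
params = torch.randn(n_params)
idx = torch.tril_indices(block_size, block_size)
A = torch.zeros(block_size, block_size)
A[idx[0], idx[1]] = params  # fill the lower triangle, diagonal included
print(A)                    # the repeated lower-triangular block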
Example #16
def log_cholesky_transform(x):
    r"""Perform the log cholesky transform on a vector of values.

    This turns a vector of :math:`\frac{N(N+1)}{2}` unconstrained values into a
    valid :math:`N \times N` covariance matrix.


    References
    ----------

    - Jose C. Pinheiro & Douglas M. Bates.  `Unconstrained Parameterizations
      for Variance-Covariance Matrices
      <https://dx.doi.org/10.1007/BF00140873>`_ *Statistics and Computing*,
      1996.
    """

    if get_backend() == "pytorch":
        import numpy as np
        import torch

        N = int((np.sqrt(1 + 8 * torch.numel(x)) - 1) / 2)
        E = torch.zeros((N, N), dtype=get_datatype())
        tril_ix = torch.tril_indices(row=N, col=N, offset=0)
        E[..., tril_ix[0], tril_ix[1]] = x
        E[..., range(N), range(N)] = torch.exp(torch.diagonal(E))
        return E @ torch.transpose(E, -1, -2)
    else:
        import tensorflow as tf
        import tensorflow_probability as tfp

        E = tfp.math.fill_triangular(x)
        E = tf.linalg.set_diag(E, tf.exp(tf.linalg.tensor_diag_part(E)))
        return E @ tf.transpose(E)
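A short usage sketch for the PyTorch branch (assumes get_backend() reports "pytorch" and get_datatype() yields a float dtype):

import torch

N = 3
x = torch.randn(N * (N + 1) // 2)  # N(N+1)/2 unconstrained values
cov = log_cholesky_transform(x)    # N x N covariance matrix
assert torch.allclose(cov, cov.T)              # symmetric
assert (torch.linalg.eigvalsh(cov) > 0).all()  # positive definite by construction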
Example #17
    def interact_features(self, x, ly):
        if self.arch_interaction_op == "dot":
            # concatenate dense and sparse features
            (batch_size, d) = x.shape
            T = torch.cat([x] + ly, dim=1).view((batch_size, -1, d))
            # perform a dot product
            Z = torch.bmm(T, torch.transpose(T, 1, 2))
            # append dense feature with the interactions (into a row vector)
            # approach 1: all
            # Zflat = Z.view((batch_size, -1))
            # approach 2: unique
            _, ni, nj = Z.shape
            offset = 0 if self.arch_interaction_itself else -1
            li, lj = torch.tril_indices(ni, nj, offset=offset)
            Zflat = Z[:, li, lj]
            # concatenate dense features and interactions
            R = torch.cat([x] + [Zflat], dim=1)
        elif self.arch_interaction_op == "cat":
            # concatenation features (into a row vector)
            R = torch.cat([x] + ly, dim=1)
        else:
            sys.exit(
                "ERROR: --arch-interaction-op="
                + self.arch_interaction_op
                + " is not supported"
            )

        return R
Example #18
def x_to_xp_L(x, n):
    """Return (x, L, ln(diag(L))) given net output x
    :param x: Neural net output, (n_batch, n_outputs)
    :param n: Number of variables

    Note that the net predicts: (
        point estimates,
        ln of diagonal entries of L,
        off-diagonal entries of L (flattened)
    )
    """
    if not isinstance(x, torch.Tensor):
        x = torch.Tensor(x)
    _, n_out = x.shape
    assert n_out == dds.n_out(n, 'correlated'), "Wrong number of outputs"

    # Split off the covariance terms from x
    x_p, x_diag, x_nondiag = x[:, :n], x[:, n:2 * n], x[:, 2 * n:]

    # Create diagonal matrices
    L = torch.diag_embed(torch.exp(x_diag))

    # Get indices of elements below main diagonal
    row_indices, col_indices = torch.tril_indices(n, n, -1)

    # Add off-diagonal entries in-place
    L[:, row_indices, col_indices] += x_nondiag

    return x_p, L, x_diag
Example #19
    def eval(self, scene_tree):
        tables = scene_tree.find_nodes_by_type(Table)
        all_dists = []
        for table in tables:
            objs = [
                node for node in scene_tree.get_children_recursive(table)
                if isinstance(node, TabletopObjectTypes)
                and not isinstance(scene_tree.get_parent(node), SteamerBottom)
                and not isinstance(node, FirstChopstick)
                and not isinstance(node, SecondChopstick)
            ]
            if len(objs) <= 1:
                print("no objects")
                continue
            xys = torch.stack([obj.translation[:2] for obj in objs], axis=0)
            keepout_dists = torch.tensor([obj.KEEPOUT_RADIUS for obj in objs])
            N = xys.shape[0]
            xys_rowwise = xys.unsqueeze(1).expand(-1, N, -1)
            keepout_dists_rowwise = keepout_dists.unsqueeze(1).expand(-1, N)
            xys_colwise = xys.unsqueeze(0).expand(N, -1, -1)
            keepout_dists_colwise = keepout_dists.unsqueeze(0).expand(N, -1)
            dists = (xys_rowwise - xys_colwise).square().sum(axis=-1)
            keepout_dists = (keepout_dists_rowwise + keepout_dists_colwise)

            # Get only lower triangular non-diagonal elems
            rows, cols = torch.tril_indices(N, N, -1)
            # Make sure pairwise dists > keepout dists
            dists = (dists - keepout_dists.square())[rows, cols].reshape(-1, 1)
            all_dists.append(dists)
        if len(all_dists) > 0:
            return torch.cat(all_dists, axis=0)
        else:
            return torch.empty(size=(0, 1))
Example #20
    def forward(self, bottom_output):
        """

        Args:
            numerical_input (Tensor): with shape [batch_size, num_numerical_features]
            categorical_inputs (Tensor): with shape [num_categorical_features, batch_size]
        """
        # The first vector in bottom_output is from bottom mlp
        bottom_mlp_output = bottom_output.narrow(1, 0, 1).squeeze()
        if self._interaction_op == "dot":
            if bottom_output.dtype == torch.half:
                interaction_output = dotBasedInteract(bottom_output, bottom_mlp_output)
            else:  # Legacy path
                interaction = torch.bmm(bottom_output, torch.transpose(bottom_output, 1, 2))
                tril_indices_row, tril_indices_col = torch.tril_indices(
                    interaction.shape[1], interaction.shape[2], offset=-1)
                interaction_flat = interaction[:, tril_indices_row, tril_indices_col]

                # concatenate dense features and interactions
                zero_padding = torch.zeros(
                    bottom_output.shape[0], 1, dtype=bottom_output.dtype, device=bottom_output.device)
                interaction_output = torch.cat((bottom_mlp_output, interaction_flat, zero_padding), dim=1)

        elif self._interaction_op == "cat":
            interaction_output = bottom_output
        else:
            raise NotImplementedError

        top_mlp_output = self.top_mlp(interaction_output)

        return top_mlp_output
Example #21
    def __init__(self,
                 num_nodes,
                 learn_edge_weight,
                 edge_weight,
                 num_features,
                 num_hiddens,
                 num_classes,
                 K,
                 dropout=0.5):
        super(EEGNet, self).__init__()
        self.num_nodes = num_nodes
        self.num_features = num_features
        self.num_hiddens = num_hiddens
        self.xs, self.ys = torch.tril_indices(self.num_nodes,
                                              self.num_nodes,
                                              offset=0)
        edge_weight = edge_weight.reshape(
            self.num_nodes,
            self.num_nodes)[self.xs, self.ys]  # lower triangular values (diagonal included)
        # edge_weight_gconv = torch.zeros(self.num_hiddens,1)
        # self.edge_weight_gconv = nn.Parameter(edge_weight_gconv, requires_grad=True)
        # nn.init.xavier_uniform_(self.edge_weight_gconv)
        self.edge_weight = nn.Parameter(edge_weight,
                                        requires_grad=learn_edge_weight)
        self.dropout = dropout
        self.chebconv_single = ChebConv(num_features, 1, K, node_dim=0)
        self.chebconv0 = ChebConv(num_features, num_hiddens[0], K, node_dim=0)
        self.chebconv1 = ChebConv(num_hiddens[0], 1, K, node_dim=0)

        # self.fc1 = nn.Linear(num_nodes, num_hiddens)
        self.fc2 = nn.Linear(num_nodes, num_classes)
Example #22
File: gp_utils.py Project: pgsrv/var-gp
def vec2tril(vec, m=None):
    '''
    Arguments:
      vec: K x ((m * (m + 1)) // 2)
      m: integer, if None, inferred from last dimension.

    Returns:
      Batch of lower triangular matrices

      tril: K x m x m
    '''
    if m is None:
        D = vec.size(-1)
        m = (torch.tensor(8. * D + 1).sqrt() - 1.) / 2.
        m = m.long().item()

    batch_shape = vec.shape[:-1]

    idx = torch.tril_indices(m, m)

    tril = torch.zeros(*batch_shape, m, m, device=vec.device)
    tril[..., idx[0], idx[1]] = vec

    # ensure positivity constraint of cholesky diagonals
    mask = torch.eye(m, device=vec.device).bool()
    tril = torch.where(mask, F.softplus(tril), tril)

    return tril
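A quick usage sketch (F is torch.nn.functional, as the softplus call above implies):

import torch

vec = torch.randn(4, 6)  # K = 4 and m * (m + 1) // 2 = 6, so m = 3 is inferred
tril = vec2tril(vec)
assert tril.shape == (4, 3, 3)
assert (tril.diagonal(dim1=-2, dim2=-1) > 0).all()  # softplus'd diagonal is positive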
Example #23
 def __init__(self, Y_dim, device):
     super(DoubleGaussianNLL, self).__init__(Y_dim, device)
     self.tril_idx = torch.tril_indices(
         self.Y_dim, self.Y_dim, offset=0,
         device=device)  # lower-triangular indices
     self.tril_len = len(self.tril_idx[0])
     self.out_dim = self.Y_dim**2 + 3 * self.Y_dim + 1
Example #24
 def backward(ctx, grad_output):
     if grad_output is None:
         return None, None
     res_m = res_c = None
     need_m, need_c = ctx.needs_input_grad[0:2]
     if need_c or need_m:
         m, c = ctx.saved_tensors
         m_cond, c_cond = make_condition(0, m, c)
         v = diagonal(c, dim1=-2, dim2=-1)
         p = phi(m, v)
         P = Phi(m_cond, c_cond)
         grad_m = -P * p
         grad_output_u1 = grad_output.unsqueeze(-1)
         res_m = grad_output_u1 * grad_m
         if need_c:
             d = c.shape[-1]  # d==1 should never happen here
             if d == 2:
                 P2 = 1
             else:
                 trilind = tril_indices(d, d - 1, offset=-1)
                 m_cond2, c_cond2 = make_condition(0, m_cond, c_cond)
                 Q_l = Phi(m_cond2[..., trilind[0], trilind[1], :],
                           c_cond2[..., trilind[0], trilind[1], :, :])
                 P2 = zeros(*Q_l.shape[:-1], d, d, dtype=Q_l.dtype)
                 P2[..., trilind[0], trilind[1]] = Q_l
                 P2[..., trilind[1], trilind[0]] = Q_l
             p2 = phi2_sub(m, c)
             hess = p2 * P2
             D = -(m * grad_m + (hess * c).sum(-1)) / v
             grad_c = .5 * (hess + diag_embed(D))
             res_c = grad_output_u1.unsqueeze(-1) * grad_c
     return res_m, res_c
Example #25
 def __init__(self, num_tasks, noise_covar, rank=0, task_correlation_prior=None, batch_shape=torch.Size()):
     """
     Args:
         num_tasks (int):
             Number of tasks.
         noise_covar (:obj:`gpytorch.module.Module`):
             A model for the noise covariance. This can be a simple homoskedastic noise model, or a GP
             that is to be fitted on the observed measurement errors.
         rank (int):
             The rank of the task noise covariance matrix to fit. If `rank` is set to 0, then a diagonal covariance
             matrix is fit.
         task_correlation_prior (:obj:`gpytorch.priors.Prior`):
             Prior to use over the task noise correlation matrix. Only used when `rank` > 0.
         batch_shape (torch.Size):
             Number of batches.
     """
     super().__init__(noise_covar=noise_covar)
     if rank != 0:
         if rank > num_tasks:
             raise ValueError(f"Cannot have rank ({rank}) greater than num_tasks ({num_tasks})")
         tidcs = torch.tril_indices(num_tasks, rank, dtype=torch.long)
         self.tidcs = tidcs[:, 1:]  # (1, 1) must be 1.0, no need to parameterize this
         task_noise_corr = torch.randn(*batch_shape, self.tidcs.size(-1))
         self.register_parameter("task_noise_corr", torch.nn.Parameter(task_noise_corr))
         if task_correlation_prior is not None:
             self.register_prior(
                 "MultitaskErrorCorrelationPrior", task_correlation_prior, lambda m: m._eval_corr_matrix
             )
     elif task_correlation_prior is not None:
         raise ValueError("Can only specify task_correlation_prior if rank>0")
     self.num_tasks = num_tasks
     self.rank = rank
Example #26
    def get_weights(self, device):
        # Generate the full weights.
        # return: weights of shape (hidden_dim * 4, input_dim + hidden_dim)

        # input to hidden
        w_ih = None
        ind = torch.tril_indices(self.hidden_dim,
                                 self.input_dim,
                                 -self.in_num_diags,
                                 device=device)
        weights_f = torch.zeros([self.hidden_dim, self.input_dim],
                                device=device)
        for coeffs in self.in_coeffss:
            if self.dropout_dct:
                coeffs = self.wdrop(coeffs)
            weights = self.to_weights(coeffs, ind, weights_f,
                                      self.in_dct_layer, self.hid_dct_layer)

            if w_ih is not None:
                w_ih = torch.cat([w_ih, weights], dim=0)
            else:
                w_ih = weights

        # hidden to hidden
        w_hh = None
        ind = torch.tril_indices(self.hidden_dim,
                                 self.hidden_dim,
                                 -self.hidden_num_diags,
                                 device=device)
        weights_f = torch.zeros([self.hidden_dim, self.hidden_dim],
                                device=device)

        for coeffs in self.hid_coeffss:
            if self.dropout_dct:
                coeffs = self.wdrop(coeffs)

            weights = self.to_weights(coeffs, ind, weights_f,
                                      self.hid_dct_layer, self.hid_dct_layer)

            if w_hh is not None:
                w_hh = torch.cat([w_hh, weights], dim=0)
            else:
                w_hh = weights

        # concatenate both
        weights = torch.cat([w_ih, w_hh], dim=1)
        return weights
Example #27
    def __call__(self, x, target):
        """
        :param x: output segmentation, shape [*, C, *]
        :param Sigma: covariance coefficients (unpacked from x when it is a list). It can be:
            (1) If no_covar==False, shape [*, C(C+1)/2, *], organized row-first according to tril_indices
               from torch and numpy: [rho_11, rho_12, ..., rho_1C, rho_22, rho_23, ..., rho_2C, ..., rho_CC]
               where rho_ii = exp(.) > 0 encodes the variances and rho_ij = tanh(.) encodes the correlations.
               The covariance matrix M is s.t. M[i][j] = rho_ij * sqrt(rho_ii) * sqrt(rho_jj)
            (2) If no_covar==True, shape [*, C, *], assuming that all non-diagonal coeffs are zero. We assume it
                has the form [sigma_1**2, sigma_2**2, ..., sigma_C**2]
        :param target: true segmentation, shape [*, C, *]
        :return: log-likelihood for logistic regression with uncertainty
        """

        if isinstance(
                x, list
        ):  #should happen just for regression, where sigma_prediction is used in metric (utils)
            x, Sigma = x[0], x[1]
            log_Sigma = Sigma
            Sigma = torch.exp(log_Sigma) + 1e-6
            #if Sigma.min() < 1e-6:
            #    print(f'Warning min Sigma {Sigma.min()}')

        C, ndims = x.shape[1], x.ndim

        if self.no_covar:
            # Simplified Case
            assert C == Sigma.shape[1] and Sigma.ndim == ndims,\
                "Inconsistent shape for input data and covariance: {} vs {}".format(x.shape, Sigma.shape)
            assert torch.all(Sigma > 0), "Negative values found in Sigma"
            inv_Sigma = 1. / Sigma  # shape [*, C, *]
            #logdet_sigma = torch.log(torch.prod(Sigma, dim=1)) # shape [*, *]
            logdet_sigma = torch.sum(log_Sigma, dim=1)  # shape [*, *]
            err = (target - x)  # shape [*, C, *]
            return ((err * inv_Sigma * err).sum(dim=1) +
                    logdet_sigma.squeeze()).mean()
        else:
            # General Case
            assert (C * (C+1))//2 == Sigma.shape[1] and Sigma.ndim == ndims, \
                "Inconsistent shape for input data and covariance: {} vs {}".format(x.shape, Sigma.shape)
            # permutes the 2nd dim to last, keeping other unchanged (in v1.9, eq. to torch.moveaxis(1, -1))
            swap_channel_last = (0, ) + tuple(range(2, ndims)) + (1, )
            # First, re-arrange covar matrix to have shape [*, *, C, C]
            covar_shape = (Sigma.shape[0], ) + Sigma.shape[2:] + (C, C)
            # NOTE: the row-first tril scatter is not the transpose of the row-first triu
            # scatter for C > 2; scatter through triu (matching the docstring order), then
            # mirror the strictly-upper part to obtain a symmetric matrix.
            triu_ind = torch.triu_indices(row=C, col=C, offset=0)
            Sigma_ = torch.zeros(covar_shape, device=x.device)
            Sigma_[..., triu_ind[0],
                   triu_ind[1]] = Sigma.permute(swap_channel_last)
            Sigma_ = Sigma_ + Sigma_.triu(1).transpose(-1, -2)
            # Then compute determinant and inverse of covariance matrices
            logdet_sigma = torch.logdet(Sigma_)  # shape [*, *]
            inv_sigma = torch.inverse(Sigma_)  # shape [*, *, C, C]
            # Finally, compute log-likelihood of the multivariate Gaussian distribution
            err = (target - x).permute(swap_channel_last).unsqueeze(
                -1)  # shape [*, *, C, 1]
            return ((err.transpose(-1, -2) @ inv_sigma @ err).squeeze() +
                    logdet_sigma.squeeze()).mean()
Example #28
 def forward(self, x):
     hidden = self.softplus(self.fc1(x))
     x_loc = self.fc21(hidden)
     x_scale_ = self.fc22(hidden)
     x_scale = torch.zeros(x_scale_.shape[:-1] + (self.x_dim, self.x_dim))
     idx = torch.tril_indices(self.x_dim, self.x_dim)
     x_scale[..., idx[0], idx[1]] = x_scale_[..., :]
     return x_loc, x_scale
Example #29
 def __init__(self, Y_dim, device, Y_mean=None, Y_std=None):
     super(DoubleGaussianBNNPosterior,
           self).__init__(Y_dim, device, Y_mean, Y_std)
     self.tril_idx = torch.tril_indices(
         self.Y_dim, self.Y_dim, offset=0,
         device=device)  # lower-triangular indices
     self.tril_len = len(self.tril_idx[0])
     self.out_dim = self.Y_dim**2 + 3 * self.Y_dim + 1
Example #30
 def __init__(self, Y_dim, device):
     super(FullRankGaussianNLL, self).__init__(Y_dim, device)
     self.tril_idx = torch.tril_indices(self.Y_dim,
                                        self.Y_dim,
                                        offset=0,
                                        device=device)  # lower-triang idx
     self.tril_len = len(self.tril_idx[0])
     self.out_dim = self.Y_dim + self.Y_dim * (self.Y_dim + 1) // 2