def update_tril(entries, D):
    tril = torch.zeros(D, D)
    # softplus keeps the diagonal entries strictly positive
    tril[range(D), range(D)] = softplus(entries[0:D])
    # mask selecting the strictly lower-triangular positions
    off_idx = torch.tril_indices(D, D)[0] != torch.tril_indices(D, D)[1]
    a, b = torch.tril_indices(D, D)[:, off_idx]
    tril[a, b] = entries[D:]
    return tril
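# A minimal usage sketch for update_tril (not from the original source); it
# assumes the module-level softplus above is torch.nn.functional.softplus.
# For D = 3, the first D entries fill the diagonal and the remaining
# D * (D - 1) // 2 entries fill the strictly lower triangle.
import torch
from torch.nn.functional import softplus

D = 3
entries = torch.randn(D + D * (D - 1) // 2)
L = update_tril(entries, D)
assert torch.equal(L, L.tril())      # the upper triangle stays zero
assert (L.diagonal() > 0).all()      # softplus keeps the diagonal positive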
def get_sparse_config(self, in_dim, out_dim, sparsity_level):
    '''Get num_diag and num_coeff.

    Given the dimensions of the weight matrix
        in_dim: number of columns
        out_dim: number of rows
    we want to find the right diagonal shift "d" s.t.
        N(d) < thr(desired sparsity) < N(d+1).
    We search as follows:
        - If N(0) is below thr: try N(n) for n = -1..-out_dim
        - Else: try N(n) for n = 1..in_dim

    input: the two dimensions of the weight matrix
    output: tuple (num_coeff, num_diag)
    '''
    total_el = in_dim * out_dim
    thr = int(total_el * (1 - sparsity_level))  # just truncate the fraction
    for num_diag in range(out_dim):  # lower-triangular band shifted down by num_diag
        non_zeros = torch.tril_indices(out_dim, in_dim, -num_diag).size()[1]
        if non_zeros < thr:
            break
    print(f"sparsity: {(total_el - non_zeros) / total_el * 100 :.1f} %"
          f" vs. desired sparsity {sparsity_level * 100} %")
    return non_zeros, num_diag
def synchronize(self):
    for h in self.handles:
        hvd.synchronize(h)
    if self.merge:
        self._tensor_group.pull_alltensors()
        self._tensor_group.clear_group_flags()
    for name in self._name_tensors:
        tensor, comm_tensor = self._name_tensors[name]
        if self.symmetric:
            if self.fp16:
                comm_tensor = comm_tensor.float()
            lower_indices = torch.tril_indices(tensor.shape[0], tensor.shape[1], device=tensor.device)
            upper_indices = torch.triu_indices(tensor.shape[0], tensor.shape[1], device=tensor.device)
            tensor[upper_indices[0], upper_indices[1]] = comm_tensor
            tensor[lower_indices[0], lower_indices[1]] = tensor.t()[lower_indices[0], lower_indices[1]]
        else:
            if self.fp16:
                comm_tensor = comm_tensor.float()
            tensor.copy_(comm_tensor)
        if self.op == hvd.Average:
            tensor.div_(hvd.size())
    self._name_tensors.clear()
    self.handles.clear()
def _assemble_tril(diag: torch.Tensor, lower_vec: torch.Tensor) -> torch.Tensor:
    dim = diag.shape[-1]
    L = torch.diag_embed(diag)  # L is lower-triangular
    i, j = torch.tril_indices(dim, dim, offset=-1)
    L[..., i, j] = lower_vec
    return L
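# Usage sketch for _assemble_tril with illustrative shapes (not from the
# original source): a batch of 8 diagonals of length 4 plus 4 * 3 // 2 = 6
# strictly-lower entries per item gives batched lower-triangular factors.
import torch

diag = torch.rand(8, 4) + 0.1
lower_vec = torch.randn(8, 4 * 3 // 2)
L = _assemble_tril(diag, lower_vec)
assert L.shape == (8, 4, 4)
assert torch.equal(L, L.tril())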
def _call(self, x):
    # invert len(x) = dim * (dim + 1) / 2 to recover the matrix dimension
    dim = int((-1 + math.sqrt(1 + 8 * x.shape[0])) / 2)
    tril = torch.zeros((dim, dim), dtype=x.dtype)
    tril_indices = torch.tril_indices(row=dim, col=dim, offset=0)
    tril[tril_indices[0], tril_indices[1]] = x
    # exponentiate the diagonal so it is strictly positive
    tril[range(dim), range(dim)] = tril.diag().exp()
    return tril
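# Quick check of the length-to-dimension inversion used in _call above (a
# standalone sketch, not part of the original class): a dim x dim lower
# triangle holds dim * (dim + 1) // 2 entries, so a length-6 input implies
# dim == 3.
import math

assert int((-1 + math.sqrt(1 + 8 * 6)) / 2) == 3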
def forward(self, state):
    """
    forwards input through the network
    :param state: (B, ds)
    :return: mean vector (B, da) and cholesky factorization of covariance matrix (B, da, da)
    """
    device = state.device
    B = state.size(0)
    ds = self.ds
    da = self.da
    action_low = torch.from_numpy(self.env.action_space.low)[None, ...].to(device)  # (1, da)
    action_high = torch.from_numpy(self.env.action_space.high)[None, ...].to(device)  # (1, da)
    x = F.relu(self.lin1(state))
    x = F.relu(self.lin2(x))
    mean = torch.sigmoid(self.mean_layer(x))  # (B, da)
    mean = action_low + (action_high - action_low) * mean
    cholesky_vector = self.cholesky_layer(x)  # (B, (da*(da+1))//2)
    cholesky_diag_index = torch.arange(da, dtype=torch.long) + 1
    cholesky_diag_index = (cholesky_diag_index * (cholesky_diag_index + 1)) // 2 - 1
    cholesky_vector[:, cholesky_diag_index] = F.softplus(cholesky_vector[:, cholesky_diag_index])
    tril_indices = torch.tril_indices(row=da, col=da, offset=0)
    cholesky = torch.zeros(size=(B, da, da), dtype=torch.float32).to(device)
    cholesky[:, tril_indices[0], tril_indices[1]] = cholesky_vector
    return mean, cholesky
def _unflatten_tril(x):
    """Unflattens a vector into a lower triangular matrix of shape `dim x dim`."""
    n, dim = x.shape
    idxs = torch.tril_indices(dim, dim)
    tril = torch.zeros(n, dim, dim)
    tril[:, idxs[0, :], idxs[1, :]] = x
    return tril
def lower_vector_to_matrix(
    lower_flat: torch.Tensor,
    matrix_dim: Optional[int] = None
) -> torch.Tensor:
    """Convert a valid vector to a lower triangular matrix.

    Parameters
    ----------
    lower_flat : torch.Tensor
        Flattened lower-triangular values.
    matrix_dim : Optional[int]
        Dimension of the output matrix; inferred from the vector length if None.

    Returns
    -------
    torch.Tensor
    """
    shape = lower_flat.shape
    if matrix_dim is None:
        # An (N, N) matrix has L = N * (N + 1) / 2 values in its
        # lower triangular region (including the diagonal), so
        # N = 0.5 * (8 * L + 1) ** 0.5 - 0.5
        matrix_dim = int(0.5 * (8 * shape[-1] + 1) ** 0.5 - 0.5)
    matrix_shape = shape[:-1] + (matrix_dim, matrix_dim)
    lower = torch.zeros(matrix_shape)
    lower_idx = torch.tril_indices(matrix_dim, matrix_dim)
    lower[..., lower_idx[0], lower_idx[1]] = lower_flat
    return lower
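# Usage sketch for lower_vector_to_matrix (illustrative sizes, not from the
# original source): a length-10 vector maps to a 4 x 4 lower-triangular
# matrix because 4 * 5 // 2 == 10, so matrix_dim is inferred as 4.
import torch

flat = torch.arange(10, dtype=torch.float32)
lower = lower_vector_to_matrix(flat)
assert lower.shape == (4, 4)
assert torch.equal(lower, lower.tril())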
def symmetry(x, mode="real"):
    center = (x.shape[1]) // 2
    u = torch.arange(center)
    v = torch.arange(center)
    diag1 = torch.arange(center, x.shape[1])
    diag2 = torch.arange(center, x.shape[1])
    diag_indices = torch.stack((diag1, diag2))
    grid = torch.tril_indices(x.shape[1], x.shape[1], -1)
    x_sym = torch.cat((grid[0].reshape(-1, 1), diag_indices[0].reshape(-1, 1)))
    y_sym = torch.cat((grid[1].reshape(-1, 1), diag_indices[1].reshape(-1, 1)))
    x = torch.rot90(x, 1, dims=(1, 2))
    i = center + (center - x_sym)
    j = center + (center - y_sym)
    u = center - (center - x_sym)
    v = center - (center - y_sym)
    if mode == "real":
        x[:, i, j] = x[:, u, v]
    if mode == "imag":
        x[:, i, j] = -x[:, u, v]
    return torch.rot90(x, 3, dims=(1, 2))
def __init__(self, input_dim, hidden_dim, output_dim, K):
    super(MDN, self).__init__()
    self.input_dim = input_dim
    self.output_dim = output_dim
    self.K = K
    self.tril_indices = torch.tril_indices(row=output_dim, col=output_dim, offset=-1)
    # initialize linear transformations
    self.lin_input_to_hidden = nn.Linear(input_dim, hidden_dim)
    self.lin_hidden_to_hidden = nn.Linear(hidden_dim, hidden_dim)
    self.lin_hidden_to_mix_components = nn.Linear(hidden_dim, K)
    self.lin_hidden_to_loc = nn.Linear(hidden_dim, output_dim * K)
    self.lin_hidden_to_offdiag = nn.Linear(hidden_dim, K)
    self.lin_hidden_to_sigma = nn.Linear(hidden_dim, output_dim * K)
    # initialize non-linearities
    self.relu = nn.ReLU()
    self.softmax = nn.Softmax(dim=1)
    self.softplus = nn.Softplus()
    self.bn1 = nn.BatchNorm1d(hidden_dim)
    self.bn2 = nn.BatchNorm1d(hidden_dim)
    self.dropout1 = nn.Dropout(p=0.5)
    self.dropout2 = nn.Dropout(p=0.3)
def pt_meddistance(X, subsample=None, seed=283):
    """
    Compute the median of pairwise distances (not distance squared) of points
    in the matrix. Useful as a heuristic for setting the Gaussian kernel's width.

    Parameters
    ----------
    X : n x d torch tensor

    Returns
    -------
    median distance (a scalar, not a torch tensor)
    """
    n = X.shape[0]
    if subsample is None:
        D = torch.sqrt(pt_dist2_matrix(X, X))
        I = torch.tril_indices(n, n, -1)
        Tri = D[I[0], I[1]]
        med = torch.median(Tri)
        return med.item()
    else:
        assert subsample > 0
        with NumpySeedContext(seed=seed):
            ind = np.random.choice(n, min(subsample, n), replace=False)
        # recursion just once
        return pt_meddistance(X[ind], None, seed=seed)
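# The same median heuristic sketched with torch.cdist instead of the module's
# pt_dist2_matrix helper (an assumption for illustration, not the original
# implementation): take the median over the strictly lower triangle of the
# pairwise-distance matrix so each pair is counted once and the zero diagonal
# is excluded.
import torch

X = torch.randn(100, 3)
D = torch.cdist(X, X)                     # n x n pairwise Euclidean distances
i, j = torch.tril_indices(100, 100, -1)   # strictly lower triangle
med = D[i, j].median().item()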
def _interaction(self, bottom_mlp_output, embedding_outputs, batch_size):
    """Interaction

    "dot" interaction is a bit tricky to implement and test. Break it out from
    forward so that it can be tested independently.

    Args:
        bottom_mlp_output (Tensor):
        embedding_outputs (list): Sequence of tensors
        batch_size (int):
    """
    concat = torch.cat([bottom_mlp_output] + embedding_outputs, dim=1)
    if self._interaction_op == "dot" and not self._self_interaction:
        concat = concat.view((batch_size, -1, self._embedding_dim))
        if concat.dtype == torch.half:
            interaction_output = dotBasedInteract(concat, bottom_mlp_output)
        else:  # Legacy path
            interaction = torch.bmm(concat, torch.transpose(concat, 1, 2))
            tril_indices_row, tril_indices_col = torch.tril_indices(
                interaction.shape[1], interaction.shape[2], offset=-1)
            interaction_flat = interaction[:, tril_indices_row, tril_indices_col]
            # concatenate dense features and interactions
            zero_padding = torch.zeros(
                concat.shape[0], 1, dtype=concat.dtype, device=concat.device)
            interaction_output = torch.cat(
                (bottom_mlp_output, interaction_flat, zero_padding), dim=1)
    elif self._interaction_op == "cat":
        interaction_output = concat
    else:
        raise NotImplementedError
    return interaction_output
def synchronize(self):
    self.merged_comm.synchronize()
    for h in self.handles:
        handle, names, tensors, comm_tensors, rank = h
        if rank != hvd.rank():
            continue
        name = ','.join(names)
        offset = 0
        buf = self.merged_tensors[name]
        if self.fp16:
            buf = buf.float()
        for i, t in enumerate(tensors):
            numel = comm_tensors[i].numel()
            comm_tensor = buf.data[offset:offset+numel]
            if self.symmetric:
                lower_indices = torch.tril_indices(t.shape[0], t.shape[1], device=t.device)
                upper_indices = torch.triu_indices(t.shape[0], t.shape[1], device=t.device)
                t[upper_indices[0], upper_indices[1]] = comm_tensor
                t[lower_indices[0], lower_indices[1]] = t.t()[lower_indices[0], lower_indices[1]]
            else:
                t.copy_(comm_tensor.view(t.shape))
            t.div_(hvd.size())
            offset += numel
    self.handles.clear()
def rsample(self, sample_shape=torch.Size()):
    """
    References
    ----------
    - Sawyer, S. (2007). Wishart Distributions and Inverse-Wishart Sampling.
      https://www.math.wustl.edu/~sawyer/hmhandouts/Wishart.pdf
    - Anderson, T. W. (2003). An Introduction to Multivariate Statistical
      Analysis (3rd ed.). John Wiley & Sons, Inc.
    - Odell, P. L. & Feiveson, A. H. (1966). A Numerical Procedure to Generate
      a Sample Covariance Matrix. Journal of the American Statistical
      Association, 61(313):199-203.
    - Ku, Y.-C. & Bloomfield, P. (2010). Generating Random Wishart Matrices
      with Fractional Degrees of Freedom in OX.
    """
    shape = torch.Size(sample_shape) + self.batch_shape
    dtype, device = self.concentration.dtype, self.concentration.device
    D = self.event_shape[-1]
    df = 2. * self.concentration  # type: torch.Tensor
    i = torch.arange(D, dtype=dtype, device=device)
    concentration = .5 * (df.unsqueeze(-1) - i).expand(shape + (D,))
    V = 2. * torch._standard_gamma(concentration)
    N = torch.randn(*shape, D * (D - 1) // 2, dtype=dtype, device=device)
    T = torch.diag_embed(V.sqrt())  # T is lower-triangular
    i, j = torch.tril_indices(D, D, offset=-1)
    T[..., i, j] = N
    M = self.scale_tril @ T
    W = M @ M.transpose(-2, -1)
    return W
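# Standalone sketch of the Bartlett construction used in rsample above
# (concrete sizes, detached from the distribution class): the diagonal of T
# holds square roots of gamma draws, the strict lower triangle holds standard
# normals, and T @ T.T is a Wishart(df, I) sample.
import torch

D, df = 4, 7.0
i = torch.arange(D, dtype=torch.float64)
V = 2.0 * torch._standard_gamma(0.5 * (df - i))
T = torch.diag_embed(V.sqrt())
r, c = torch.tril_indices(D, D, offset=-1)
T[r, c] = torch.randn(D * (D - 1) // 2, dtype=torch.float64)
W = T @ T.T
assert torch.allclose(W, W.T)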
def __init__(self, block_size, offset=0, inverse=True, **kwargs):
    """Applies repeated lower-triangular block diagonal matrices.

    Let H be the size of a lower-triangular block A. This layer applies:

        [ A, 0, 0, ...
          0, A, 0, ...
          0, 0, A, ...
          ., ., ., ... ]

    Args:
        block_size (int): Block size H
        offset (int, optional): Offset of A along the diagonal. Defaults to 0.
        inverse (bool, optional): Specifies which direction should be modelled
            directly. The matrix for the inverse A^-1 is computed by
            torch.inverse. Defaults to True.
    """
    super().__init__()
    self.block_size = block_size
    self.n_params = (block_size**2 - block_size) // 2 + block_size
    self.params = nn.Parameter(torch.zeros(self.n_params))
    mask = torch.zeros((block_size, block_size), dtype=torch.bool)
    idx = torch.tril_indices(block_size, block_size)
    mask[tuple(idx)] = 1
    self.mask = nn.Parameter(mask, requires_grad=False)
    diag_idx = (torch.arange(block_size) + 1).cumsum(0) - 1
    self.diag_idx = nn.Parameter(diag_idx, requires_grad=False)
    self.offset = offset % self.block_size
    self._inverse = inverse
def log_cholesky_transform(x):
    r"""Perform the log cholesky transform on a vector of values.

    This turns a vector of :math:`\frac{N(N+1)}{2}` unconstrained values into a
    valid :math:`N \times N` covariance matrix.

    References
    ----------
    - Jose C. Pinheiro & Douglas M. Bates. `Unconstrained Parameterizations for
      Variance-Covariance Matrices <https://dx.doi.org/10.1007/BF00140873>`_
      *Statistics and Computing*, 1996.
    """
    if get_backend() == "pytorch":
        import numpy as np
        import torch
        N = int((np.sqrt(1 + 8 * torch.numel(x)) - 1) / 2)
        E = torch.zeros((N, N), dtype=get_datatype())
        tril_ix = torch.tril_indices(row=N, col=N, offset=0)
        E[..., tril_ix[0], tril_ix[1]] = x
        E[..., range(N), range(N)] = torch.exp(torch.diagonal(E))
        return E @ torch.transpose(E, -1, -2)
    else:
        import tensorflow as tf
        import tensorflow_probability as tfp
        E = tfp.math.fill_triangular(x)
        E = tf.linalg.set_diag(E, tf.exp(tf.linalg.tensor_diag_part(E)))
        return E @ tf.transpose(E)
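# Standalone sketch of the same transform for N = 3 in the PyTorch backend
# (torch only, without the module's get_backend / get_datatype helpers):
import torch

x = torch.randn(6)
E = torch.zeros(3, 3)
i, j = torch.tril_indices(3, 3)
E[i, j] = x
E[range(3), range(3)] = E.diagonal().exp()   # positive diagonal
cov = E @ E.T
assert torch.allclose(cov, cov.T)            # a valid symmetric covariance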
def interact_features(self, x, ly):
    if self.arch_interaction_op == "dot":
        # concatenate dense and sparse features
        (batch_size, d) = x.shape
        T = torch.cat([x] + ly, dim=1).view((batch_size, -1, d))
        # perform a dot product
        Z = torch.bmm(T, torch.transpose(T, 1, 2))
        # append dense feature with the interactions (into a row vector)
        # approach 1: all
        # Zflat = Z.view((batch_size, -1))
        # approach 2: unique
        _, ni, nj = Z.shape
        offset = 0 if self.arch_interaction_itself else -1
        li, lj = torch.tril_indices(ni, nj, offset=offset)
        Zflat = Z[:, li, lj]
        # concatenate dense features and interactions
        R = torch.cat([x] + [Zflat], dim=1)
    elif self.arch_interaction_op == "cat":
        # concatenation features (into a row vector)
        R = torch.cat([x] + ly, dim=1)
    else:
        sys.exit(
            "ERROR: --arch-interaction-op=" + self.arch_interaction_op
            + " is not supported"
        )
    return R
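# Size sketch for the "dot" branch above (illustrative numbers, not from the
# source): with one dense vector and 26 embeddings of dimension d, T stacks
# ni = 27 vectors per sample, and the strictly lower triangle keeps
# 27 * 26 // 2 = 351 unique pairwise dot products in Zflat.
import torch

ni = 27
li, lj = torch.tril_indices(ni, ni, offset=-1)
assert li.numel() == ni * (ni - 1) // 2      # 351 interaction terms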
def x_to_xp_L(x, n):
    """Return (x, L, ln(diag(L))) given net output x

    :param x: Neural net output, (n_batch, n_outputs)
    :param n: Number of variables

    Note that the net predicts:
        (point estimates,
         ln of diagonal entries of L,
         off-diagonal entries of L (flattened))
    """
    if not isinstance(x, torch.Tensor):
        x = torch.Tensor(x)
    _, n_out = x.shape
    assert n_out == dds.n_out(n, 'correlated'), "Wrong number of outputs"
    # Split off the covariance terms from x
    x_p, x_diag, x_nondiag = x[:, :n], x[:, n:2 * n], x[:, 2 * n:]
    # Create diagonal matrices
    L = torch.diag_embed(torch.exp(x_diag))
    # Get indices of elements below main diagonal
    row_indices, col_indices = torch.tril_indices(n, n, -1)
    # Add off-diagonal entries in-place
    L[:, row_indices, col_indices] += x_nondiag
    return x_p, L, x_diag
def eval(self, scene_tree):
    tables = scene_tree.find_nodes_by_type(Table)
    all_dists = []
    for table in tables:
        objs = [
            node for node in scene_tree.get_children_recursive(table)
            if isinstance(node, TabletopObjectTypes)
            and not isinstance(scene_tree.get_parent(node), SteamerBottom)
            and not isinstance(node, FirstChopstick)
            and not isinstance(node, SecondChopstick)
        ]
        if len(objs) <= 1:
            print("no objects")
            continue
        xys = torch.stack([obj.translation[:2] for obj in objs], axis=0)
        keepout_dists = torch.tensor([obj.KEEPOUT_RADIUS for obj in objs])
        N = xys.shape[0]
        xys_rowwise = xys.unsqueeze(1).expand(-1, N, -1)
        keepout_dists_rowwise = keepout_dists.unsqueeze(1).expand(-1, N)
        xys_colwise = xys.unsqueeze(0).expand(N, -1, -1)
        keepout_dists_colwise = keepout_dists.unsqueeze(0).expand(N, -1)
        dists = (xys_rowwise - xys_colwise).square().sum(axis=-1)
        keepout_dists = (keepout_dists_rowwise + keepout_dists_colwise)
        # Get only lower triangular non-diagonal elems
        rows, cols = torch.tril_indices(N, N, -1)
        # Make sure pairwise dists > keepout dists
        dists = (dists - keepout_dists.square())[rows, cols].reshape(-1, 1)
        all_dists.append(dists)
    if len(all_dists) > 0:
        return torch.cat(all_dists, axis=0)
    else:
        return torch.empty(size=(0, 1))
def forward(self, bottom_output):
    """
    Args:
        bottom_output (Tensor): with shape [batch_size, num_features, embedding_dim];
            the first vector along dim 1 is the bottom MLP output, the rest are
            embedding vectors
    """
    # The first vector in bottom_output is from bottom mlp
    bottom_mlp_output = bottom_output.narrow(1, 0, 1).squeeze()
    if self._interaction_op == "dot":
        if bottom_output.dtype == torch.half:
            interaction_output = dotBasedInteract(bottom_output, bottom_mlp_output)
        else:  # Legacy path
            interaction = torch.bmm(bottom_output, torch.transpose(bottom_output, 1, 2))
            tril_indices_row, tril_indices_col = torch.tril_indices(
                interaction.shape[1], interaction.shape[2], offset=-1)
            interaction_flat = interaction[:, tril_indices_row, tril_indices_col]
            # concatenate dense features and interactions
            zero_padding = torch.zeros(
                bottom_output.shape[0], 1, dtype=bottom_output.dtype, device=bottom_output.device)
            interaction_output = torch.cat(
                (bottom_mlp_output, interaction_flat, zero_padding), dim=1)
    elif self._interaction_op == "cat":
        interaction_output = bottom_output
    else:
        raise NotImplementedError
    top_mlp_output = self.top_mlp(interaction_output)
    return top_mlp_output
def __init__(self, num_nodes, learn_edge_weight, edge_weight, num_features,
             num_hiddens, num_classes, K, dropout=0.5):
    super(EEGNet, self).__init__()
    self.num_nodes = num_nodes
    self.num_features = num_features
    self.num_hiddens = num_hiddens
    self.xs, self.ys = torch.tril_indices(self.num_nodes, self.num_nodes, offset=0)
    edge_weight = edge_weight.reshape(
        self.num_nodes, self.num_nodes)[self.xs, self.ys]  # lower triangular values (including the diagonal)
    # edge_weight_gconv = torch.zeros(self.num_hiddens, 1)
    # self.edge_weight_gconv = nn.Parameter(edge_weight_gconv, requires_grad=True)
    # nn.init.xavier_uniform_(self.edge_weight_gconv)
    self.edge_weight = nn.Parameter(edge_weight, requires_grad=learn_edge_weight)
    self.dropout = dropout
    self.chebconv_single = ChebConv(num_features, 1, K, node_dim=0)
    self.chebconv0 = ChebConv(num_features, num_hiddens[0], K, node_dim=0)
    self.chebconv1 = ChebConv(num_hiddens[0], 1, K, node_dim=0)
    # self.fc1 = nn.Linear(num_nodes, num_hiddens)
    self.fc2 = nn.Linear(num_nodes, num_classes)
def vec2tril(vec, m=None):
    '''
    Arguments:
        vec: K x ((m * (m + 1)) // 2)
        m: integer, if None, inferred from last dimension.

    Returns:
        Batch of lower triangular matrices

        tril: K x m x m
    '''
    if m is None:
        D = vec.size(-1)
        m = (torch.tensor(8. * D + 1).sqrt() - 1.) / 2.
        m = m.long().item()
    batch_shape = vec.shape[:-1]
    idx = torch.tril_indices(m, m)
    tril = torch.zeros(*batch_shape, m, m, device=vec.device)
    tril[..., idx[0], idx[1]] = vec
    # ensure positivity constraint of cholesky diagonals
    mask = torch.eye(m, device=vec.device).bool()
    tril = torch.where(mask, F.softplus(tril), tril)
    return tril
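# Usage sketch for vec2tril (illustrative sizes; assumes torch.nn.functional
# is imported as F, as in the function): 6 == 3 * 4 // 2 values per row give
# a batch of 3 x 3 Cholesky factors with softplus-positive diagonals.
import torch
import torch.nn.functional as F

vec = torch.randn(5, 6)
chol = vec2tril(vec)
assert chol.shape == (5, 3, 3)
assert (chol.diagonal(dim1=-2, dim2=-1) > 0).all()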
def __init__(self, Y_dim, device):
    super(DoubleGaussianNLL, self).__init__(Y_dim, device)
    self.tril_idx = torch.tril_indices(
        self.Y_dim, self.Y_dim, offset=0, device=device)  # lower-triangular indices
    self.tril_len = len(self.tril_idx[0])
    self.out_dim = self.Y_dim**2 + 3 * self.Y_dim + 1
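# Dimension bookkeeping for out_dim above (an interpretation, not stated in
# the source): two mean vectors (2 * Y_dim) + two lower-triangular Cholesky
# factors (2 * Y_dim * (Y_dim + 1) // 2) + one mixture weight, which equals
# Y_dim**2 + 3 * Y_dim + 1.
Y_dim = 4
assert 2 * Y_dim + 2 * Y_dim * (Y_dim + 1) // 2 + 1 == Y_dim**2 + 3 * Y_dim + 1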
def backward(ctx, grad_output):
    if grad_output is None:
        return None, None
    res_m = res_c = None
    need_m, need_c = ctx.needs_input_grad[0:2]
    if need_c or need_m:
        m, c = ctx.saved_tensors
        m_cond, c_cond = make_condition(0, m, c)
        v = diagonal(c, dim1=-2, dim2=-1)
        p = phi(m, v)
        P = Phi(m_cond, c_cond)
        grad_m = -P * p
        grad_output_u1 = grad_output.unsqueeze(-1)
        res_m = grad_output_u1 * grad_m
        if need_c:
            d = c.shape[-1]
            # d==1 should never happen here
            if d == 2:
                P2 = 1
            else:
                trilind = tril_indices(d, d - 1, offset=-1)
                m_cond2, c_cond2 = make_condition(0, m_cond, c_cond)
                Q_l = Phi(m_cond2[..., trilind[0], trilind[1], :],
                          c_cond2[..., trilind[0], trilind[1], :, :])
                P2 = zeros(*Q_l.shape[:-1], d, d, dtype=Q_l.dtype)
                P2[..., trilind[0], trilind[1]] = Q_l
                P2[..., trilind[1], trilind[0]] = Q_l
            p2 = phi2_sub(m, c)
            hess = p2 * P2
            D = -(m * grad_m + (hess * c).sum(-1)) / v
            grad_c = .5 * (hess + diag_embed(D))
            res_c = grad_output_u1.unsqueeze(-1) * grad_c
    return res_m, res_c
def __init__(self, num_tasks, noise_covar, rank=0, task_correlation_prior=None, batch_shape=torch.Size()):
    """
    Args:
        num_tasks (int): Number of tasks.
        noise_covar (:obj:`gpytorch.module.Module`): A model for the noise covariance.
            This can be a simple homoskedastic noise model, or a GP that is to be
            fitted on the observed measurement errors.
        rank (int): The rank of the task noise covariance matrix to fit. If `rank`
            is set to 0, then a diagonal covariance matrix is fit.
        task_correlation_prior (:obj:`gpytorch.priors.Prior`): Prior to use over the
            task noise correlation matrix. Only used when `rank` > 0.
        batch_shape (torch.Size): Number of batches.
    """
    super().__init__(noise_covar=noise_covar)
    if rank != 0:
        if rank > num_tasks:
            raise ValueError(f"Cannot have rank ({rank}) greater than num_tasks ({num_tasks})")
        tidcs = torch.tril_indices(num_tasks, rank, dtype=torch.long)
        self.tidcs = tidcs[:, 1:]  # (1, 1) must be 1.0, no need to parameterize this
        task_noise_corr = torch.randn(*batch_shape, self.tidcs.size(-1))
        self.register_parameter("task_noise_corr", torch.nn.Parameter(task_noise_corr))
        if task_correlation_prior is not None:
            self.register_prior(
                "MultitaskErrorCorrelationPrior", task_correlation_prior, lambda m: m._eval_corr_matrix
            )
    elif task_correlation_prior is not None:
        raise ValueError("Can only specify task_correlation_prior if rank>0")
    self.num_tasks = num_tasks
    self.rank = rank
def get_weights(self, device):
    # Generate the full weights.
    # return: weights of shape (hidden_dim * 4, input_dim + hidden_dim)
    # input to hidden
    w_ih = None
    ind = torch.tril_indices(self.hidden_dim, self.input_dim, -self.in_num_diags, device=device)
    weights_f = torch.zeros([self.hidden_dim, self.input_dim], device=device)
    for coeffs in self.in_coeffss:
        if self.dropout_dct:
            coeffs = self.wdrop(coeffs)
        weights = self.to_weights(coeffs, ind, weights_f, self.in_dct_layer, self.hid_dct_layer)
        if w_ih is not None:
            w_ih = torch.cat([w_ih, weights], dim=0)
        else:
            w_ih = weights
    # hidden to hidden
    w_hh = None
    ind = torch.tril_indices(self.hidden_dim, self.hidden_dim, -self.hidden_num_diags, device=device)
    weights_f = torch.zeros([self.hidden_dim, self.hidden_dim], device=device)
    for coeffs in self.hid_coeffss:
        if self.dropout_dct:
            coeffs = self.wdrop(coeffs)
        weights = self.to_weights(coeffs, ind, weights_f, self.hid_dct_layer, self.hid_dct_layer)
        if w_hh is not None:
            w_hh = torch.cat([w_hh, weights], dim=0)
        else:
            w_hh = weights
    # concatenate both
    weights = torch.cat([w_ih, w_hh], dim=1)
    return weights
def __call__(self, x, target):
    """
    :param x: output segmentation, shape [*, C, *]
    :param Sigma: co-variance coefficients. It can be:
        (1) If no_covar==False, shape [*, C(C+1)/2, *] organized row-first according
            to tril_indices from torch and numpy:
            [rho_11, rho_12, ..., rho_1C, rho_22, rho_23, ..., rho_2C, ..., rho_CC]
            with rho_ii = exp(.) > 0 encoding the variances and rho_ij = tanh(.)
            encoding the correlations. The covariance matrix M is such that
            M[i][j] = rho_ij * sqrt(rho_ii) * sqrt(rho_jj)
        (2) If no_covar==True, shape [*, C, *], assuming that all non-diagonal
            coefficients are zero. We assume it has the form
            [sigma_1**2, sigma_2**2, ..., sigma_C**2]
    :param target: true segmentation, shape [*, C, *]
    :return: log-likelihood for logistic regression with uncertainty
    """
    if isinstance(x, list):
        # should happen just for regression, where sigma_prediction is used in metric (utils)
        x, Sigma = x[0], x[1]
    log_Sigma = Sigma
    Sigma = torch.exp(log_Sigma) + 1e-6
    # if Sigma.min() < 1e-6:
    #     print(f'Warning min Sigma {Sigma.min()}')
    C, ndims = x.shape[1], x.ndim
    if self.no_covar:
        # Simplified Case
        assert C == Sigma.shape[1] and Sigma.ndim == ndims, \
            "Inconsistent shape for input data and covariance: {} vs {}".format(x.shape, Sigma.shape)
        assert torch.all(Sigma > 0), "Negative values found in Sigma"
        inv_Sigma = 1. / Sigma  # shape [*, C, *]
        # logdet_sigma = torch.log(torch.prod(Sigma, dim=1))  # shape [*, *]
        logdet_sigma = torch.sum(log_Sigma, dim=1)  # shape [*, *]
        err = (target - x)  # shape [*, C, *]
        return ((err * inv_Sigma * err).sum(dim=1) + logdet_sigma.squeeze()).mean()
    else:
        # General Case
        assert (C * (C + 1)) // 2 == Sigma.shape[1] and Sigma.ndim == ndims, \
            "Inconsistent shape for input data and covariance: {} vs {}".format(x.shape, Sigma.shape)
        # permutes the 2nd dim to last, keeping others unchanged (in v1.9, eq. to torch.moveaxis(1, -1))
        swap_channel_last = (0,) + tuple(range(2, ndims)) + (1,)
        # First, re-arrange covar matrix to have shape [*, *, C, C]
        covar_shape = (Sigma.shape[0],) + Sigma.shape[2:] + (C, C)
        tril_ind = torch.tril_indices(row=C, col=C, offset=0)
        triu_ind = torch.triu_indices(row=C, col=C, offset=0)
        Sigma_ = torch.zeros(covar_shape, device=x.device)
        Sigma_[..., tril_ind[0], tril_ind[1]] = Sigma.permute(swap_channel_last)
        Sigma_[..., triu_ind[0], triu_ind[1]] = Sigma.permute(swap_channel_last)
        # Then compute determinant and inverse of covariance matrices
        logdet_sigma = torch.logdet(Sigma_)  # shape [*, *]
        inv_sigma = torch.inverse(Sigma_)  # shape [*, *, C, C]
        # Finally, compute log-likelihood of multivariate gaussian distribution
        err = (target - x).permute(swap_channel_last).unsqueeze(-1)  # shape [*, *, C, 1]
        return ((err.transpose(-1, -2) @ inv_sigma @ err).squeeze() + logdet_sigma.squeeze()).mean()
def forward(self, x):
    hidden = self.softplus(self.fc1(x))
    x_loc = self.fc21(hidden)
    x_scale_ = self.fc22(hidden)
    x_scale = torch.zeros(x_scale_.shape[:-1] + (self.x_dim, self.x_dim))
    idx = torch.tril_indices(self.x_dim, self.x_dim)
    x_scale[..., idx[0], idx[1]] = x_scale_[..., :]
    return x_loc, x_scale
def __init__(self, Y_dim, device, Y_mean=None, Y_std=None):
    super(DoubleGaussianBNNPosterior, self).__init__(Y_dim, device, Y_mean, Y_std)
    self.tril_idx = torch.tril_indices(
        self.Y_dim, self.Y_dim, offset=0, device=device)  # lower-triangular indices
    self.tril_len = len(self.tril_idx[0])
    self.out_dim = self.Y_dim**2 + 3 * self.Y_dim + 1
def __init__(self, Y_dim, device):
    super(FullRankGaussianNLL, self).__init__(Y_dim, device)
    self.tril_idx = torch.tril_indices(self.Y_dim, self.Y_dim, offset=0, device=device)  # lower-triangular idx
    self.tril_len = len(self.tril_idx[0])
    self.out_dim = self.Y_dim + self.Y_dim * (self.Y_dim + 1) // 2