def fill_triu(
    shape: tuple[int, ...],
    triu_tensor: torch.Tensor,
) -> torch.Tensor:
    """Reconstruct a symmetric 2D tensor from its flattened upper triangle.

    Usage:
        >>> x = torch.rand(10, 10)
        >>> x = x + x.t()                      # x must be symmetric
        >>> triu_x = get_triu(x)
        >>> x_new = fill_triu((10, 10), triu_x)
        >>> assert torch.equal(x, x_new)       # true

    Args:
        shape (tuple): ``(rows, cols)`` of the output tensor. Only the square
            case ``rows == cols`` yields a fully initialised symmetric result.
        triu_tensor (tensor): flattened upper triangle (diagonal included) in
            ``torch.triu_indices(rows, cols)`` row-major order, e.g. as
            returned by ``get_triu()``.

    Returns:
        Symmetric tensor with ``shape`` whose upper and lower triangles are
        filled with the data in ``triu_tensor``.

    Raises:
        ValueError: if ``shape`` is not 2-dimensional.
    """
    if len(shape) != 2:
        raise ValueError('shape must be 2 dimensional')
    rows, cols = shape
    dst_tensor = triu_tensor.new_empty(shape)
    # Scatter the flat values into the upper triangle (diagonal included).
    upper = torch.triu_indices(rows, cols, device=triu_tensor.device)
    dst_tensor[upper[0], upper[1]] = triu_tensor
    # Mirror the strict upper triangle onto the lower one by writing through a
    # transposed view: element (i, j) of the view is dst_tensor[j, i].
    strict = torch.triu_indices(rows, rows, 1, device=dst_tensor.device)
    dst_tensor.transpose(0, 1)[strict[0], strict[1]] = dst_tensor[strict[0], strict[1]]
    return dst_tensor
def forward(self, A_src, A_tgt, F_src, F_tgt, U_src, U_tgt, w1=1, w2=1):
    """Compute the global affinity matrix ``M``.

    ``M = diagMp + (K1 * vec(M_e)) @ K2^T`` where ``K1 = kronecker_torch(G2, G1)``,
    ``K2 = kronecker_torch(H2, H1)``, ``diagMp = batch_diagonal(vec(M_p))``
    and the product with ``vec(M_e)`` scales column j of K1 by ``vec(M_e)[j]``
    (i.e. ``K1 @ diag(vec(M_e))``), following the affinity-matrix
    factorisation of Zhou and De la Torre.

    Args:
        A_src, A_tgt: per-batch edge vectors; entries ``> 0`` mark edges. Each
            entry corresponds to one strict upper-triangle (row, col) pair of
            ``U_src``/``U_tgt``'s last two dims (presumably a flattened upper
            triangular adjacency -- TODO confirm against caller).
        F_src, F_tgt: node features; ``shape[0]`` is the batch size and
            ``shape[-1]`` the node count.
        U_src, U_tgt: node embeddings used for node-to-node similarity.
        w1, w2: multipliers for the two learned weight blocks.

    Returns:
        The batched affinity matrix ``M``.
    """
    # (A) Retrieve shape parameters
    n, m = F_src.shape[-1], F_tgt.shape[-1]
    batch_num = F_src.shape[0]
    device = F_src.device

    # (B) Construct the symmetrised, non-negative weight matrix
    lambda1 = self.relu(self.lambda1 + self.lambda1.transpose(0, 1)) * w1
    lambda2 = self.relu(self.lambda2 + self.lambda2.transpose(0, 1)) * w2
    weight = torch.cat((torch.cat((lambda1, lambda2)),
                        torch.cat((lambda2, lambda1))), 1)

    # (C) Construct node/edge incidence matrices G1, H1 (source), G2, H2 (target)
    G1 = torch.zeros(batch_num, n, A_src.shape[1], device=device)
    H1 = torch.zeros(batch_num, n, A_src.shape[1], device=device)
    G2 = torch.zeros(batch_num, m, A_tgt.shape[1], device=device)
    H2 = torch.zeros(batch_num, m, A_tgt.shape[1], device=device)
    # Strict upper-triangle (row, col) pairs. They are built on the device of
    # the boolean masks below, because a CPU tensor cannot be indexed with a
    # CUDA mask.
    a = torch.triu_indices(U_src.shape[1], U_src.shape[2], offset=1, device=A_src.device)
    b = torch.triu_indices(U_tgt.shape[1], U_tgt.shape[2], offset=1, device=A_tgt.device)
    for i in range(batch_num):
        src_edges = A_src[i] > 0
        tgt_edges = A_tgt[i] > 0
        G1[i, a[0, src_edges], src_edges] = 1
        H1[i, a[1, src_edges], src_edges] = 1
        G2[i, b[0, tgt_edges], tgt_edges] = 1
        H2[i, b[1, tgt_edges], tgt_edges] = 1

    # (D) Reshape edge features for further use
    X = reshape_edge_feature(F_src, G1, H1)
    Y = reshape_edge_feature(F_tgt, G2, H2)

    # (E) Edge-to-edge (M_e) and node-to-node (M_p) similarities
    M_e = torch.bmm(
        torch.bmm(X.permute(0, 2, 1), weight.expand(X.shape[0], -1, -1)), Y)
    M_p = torch.bmm(U_src.permute(0, 2, 1), U_tgt)

    # (F) Assemble M via the affinity-matrix factorisation [Zhou and De la Torre]
    K1 = kronecker_torch(G2, G1)
    K2 = kronecker_torch(H2, H1)
    diagMp = batch_diagonal(M_p.view(M_p.shape[0], -1))
    # Column j of K1 is scaled by vec(M_e)[j]; broadcasting replaces the former
    # per-column matmul loop (and its .squeeze() pitfall on singleton dims).
    # Assumes K1.shape[2] == M_e[0].numel(), which holds by construction.
    vec_Me = M_e.view(M_e.shape[0], 1, -1)
    K1_scaled = K1 * vec_Me
    M = diagMp + torch.bmm(K1_scaled, K2.permute(0, 2, 1))
    return M
def allreduce_async_(self, name, tensor, op=hvd.Average):
    """Queue an asynchronous Horovod sum-allreduce for ``tensor``.

    ``op`` is stored on ``self.op``. The payload is built as follows:
      * ``self.symmetric``: only the upper triangle (diagonal included) of the
        square ``tensor`` is sent;
      * ``self.fp16``: the payload is cast to half precision; with
        ``self.residual`` the rounding error is stored per ``name`` and added
        back before the next cast.
    ``(tensor, payload)`` is recorded in ``self._name_tensors[name]``. With
    ``self.merge`` the payload is pushed into ``self._tensor_group`` and an
    allreduce starts only when the group hands back a merged tensor; otherwise
    the payload is allreduced directly. Every started handle is appended to
    ``self.handles``.
    """
    self.op = op

    if self.symmetric:
        side = tensor.shape[0]
        rows, cols = torch.triu_indices(side, side, device=tensor.device)
        payload = tensor[rows, cols]
    else:
        payload = tensor

    if self.fp16:
        if self.residual:
            if name not in self._residuals:
                self._residuals[name] = payload.new_zeros(payload.shape)
            # NOTE: when not symmetric, payload *is* the caller's tensor, so
            # this add_ modifies it in place.
            payload.add_(self._residuals[name])
        half_payload = payload.half()
        if self.residual:
            self._residuals[name] = payload - half_payload
        payload = half_payload

    self._name_tensors[name] = (tensor, payload)

    if not self.merge:
        handle = hvd.allreduce_async_(payload, op=hvd.Sum)
        self.handles.append(handle)
        return

    merged_name, merged_tensor = self._tensor_group.push_tensor(name, payload)
    if merged_tensor is None:
        return
    torch.cuda.current_stream().synchronize()
    handle = hvd.allreduce_async_(merged_tensor, op=hvd.Sum, name=self.prefix + merged_name)
    self.handles.append(handle)
def forward(self, x, calculation_mode='einsum'):
    """Multi-head self-attention over ``x`` of shape ``(b, t, k)``.

    Args:
        x: input of shape (batch, seq_len, emb). ``k`` is also used as the
            per-head size (each projection yields ``h * k`` features).
        calculation_mode: ``'einsum'`` computes attention with
            ``torch.einsum``; any other value uses the equivalent
            folded-heads ``bmm`` path.

    Returns:
        ``self.unifyheads`` applied to the concatenated heads (b, t, h*k).
    """
    b, t, k = x.shape
    h = self.heads
    # Project and split into heads: (b, t, h*k) -> (b, t, h, k)
    keys = self.tokeys(x).view(b, t, h, k)
    queries = self.toqueries(x).view(b, t, h, k)
    values = self.tovalues(x).view(b, t, h, k)

    if calculation_mode == 'einsum':
        # weights: (b, h, t, t) scaled dot products
        weights = torch.einsum('bthk, bihk -> bhti', [queries, keys]) / (k ** 0.5)
        if self.masked:
            # Causal mask on the trailing (query, key) axes of every head.
            rows, cols = torch.triu_indices(t, t, offset=1)
            weights[:, :, rows, cols] = float('-inf')
        weights = torch.softmax(weights, dim=-1)
        # (b, h, t, t) x (b, t, h, k) -> (b, t, h, k)
        out = torch.einsum('bhte, behk -> bthk', weights, values)
        out = out.reshape(b, t, h * k)
    else:
        # Fold heads into the batch dimension: (b*h, t, k)
        keys = keys.transpose(1, 2).contiguous().view(b * h, t, k)
        queries = queries.transpose(1, 2).contiguous().view(b * h, t, k)
        values = values.transpose(1, 2).contiguous().view(b * h, t, k)

        # Scale q and k by k**(1/4) each (overall 1/sqrt(k)).
        queries = queries / (k ** (1 / 4))
        keys = keys / (k ** (1 / 4))

        # dot: (b*h, t, t) raw attention weights
        dot = torch.bmm(queries, keys.transpose(1, 2))
        if self.masked:
            rows, cols = torch.triu_indices(t, t, offset=1)
            dot[:, rows, cols] = float('-inf')
        dot = F.softmax(dot, dim=2)

        out = torch.bmm(dot, values).view(b, h, t, k)
        out = out.transpose(1, 2).contiguous().view(b, t, h * k)

    return self.unifyheads(out)
def finv(wsh):
    """Inverse transform 1.

    Takes a real block matrix representing a complex Hermitian matrix and
    converts it to the upper triangular part (diagonal included) of that
    complex matrix.

    Arguments
    ---------
    wsh : tensor
        Input matrix of shape ``(*, 2C, 2C)``. Complex entry (i, j) is read
        from the 2x2 block starting at ``wsh[..., 2i, 2j]``.

    Returns
    -------
    ws : tensor
        Tensor of shape ``(*, 2, P)`` with ``P = C * (C + 1) / 2``, in
        ``torch.triu_indices(C, C)`` (row-major) order. ``ws[..., 0, :]``
        takes ``wsh[..., 2i, 2j]`` (presumably the real part) and
        ``ws[..., 1, :]`` takes ``-wsh[..., 2i, 2j + 1]`` (presumably the
        imaginary part).
    """
    # Dimensions
    D = wsh.dim()
    C = wsh.shape[D - 1] // 2
    P = C * (C + 1) // 2

    # Output matrix
    ws = torch.zeros(wsh.shape[0:(D - 2)] + (2, P), dtype=wsh.dtype, device=wsh.device)
    ids = torch.triu_indices(C, C, device=wsh.device)
    ws[..., 0, :] = wsh[..., ids[0] * 2, ids[1] * 2]
    ws[..., 1, :] = -1 * wsh[..., ids[0] * 2, ids[1] * 2 + 1]
    return ws
def convert_A_to_Avec(A):
    """Convert BxNxN symmetric matrices to BxM vectors of their unique values.

    The upper triangle (diagonal included) of each matrix is flattened in
    ``torch.triu_indices`` row-major order, so ``M = N * (N + 1) / 2``. A
    single NxN matrix is also accepted. The result is ``squeeze()``-d, which
    drops every singleton dimension, not only a batch of size one.
    """
    batched = A if A.dim() >= 3 else A.unsqueeze(dim=0)
    side = batched.shape[1]
    rows, cols = torch.triu_indices(side, side)
    return batched[:, rows, cols].squeeze()
def pairwise_similarity(tensors):
    """
    Computes the pairwise similarity overlap of the tensors.

    Parameters
    ----------
    tensors : list of torch.Tensor
        A list of binary vectors. Each entry can be either a single vector
        tensor (one incoming area) or a tuple of tensors (multiple incoming
        areas). ``None`` entries are ignored.

    Returns
    -------
    similarity : torch.Tensor or float
        The pairwise :math:`L_{0/1}` similarity: the mean dot product over all
        distinct pairs of vectors, divided by ``K_ACTIVE``. This is a 0-dim
        tensor for single-area input and the NaN-ignoring mean over areas
        (``np.nanmean``) for multi-area input. ``np.nan`` is returned when
        fewer than two non-None entries are given.
    """
    present = [t for t in tensors if t is not None]
    if len(present) < 2:
        return np.nan

    if isinstance(present[0], torch.Tensor):
        stacked = torch.stack(present)
        overlap = stacked.matmul(stacked.t())
        count = len(stacked)
        upper_i, upper_j = torch.triu_indices(row=count, col=count, offset=1)
        mean_overlap = overlap[upper_i, upper_j].mean()
        mean_overlap /= K_ACTIVE
        return mean_overlap

    # Multiple incoming areas: score each area on its own, then average while
    # ignoring areas that came back NaN.
    per_area = [pairwise_similarity(area) for area in zip(*present)]
    return np.nanmean(per_area)
def get_pairwise_distance_matrix(self, particle_tensor):
    '''Return the symmetric matrix of pairwise Euclidean distances.

    Input: 2D tensor of particles, one particle per row (N, D).
    Output: (N, N) distance matrix with a zero diagonal, allocated on
    ``self.device``.
    '''
    count = particle_tensor.shape[0]
    # Condensed distances, one per unordered pair (i < j) in triu row-major order.
    condensed = torch.nn.functional.pdist(input=particle_tensor, p=2)
    rows, cols = torch.triu_indices(row=count, col=count, offset=1)
    dist_matrix = torch.zeros((count, count), device=self.device)
    # Write each pair into both triangles; the result is symmetric, so no
    # transpose round-trip is required.
    dist_matrix[rows, cols] = condensed
    dist_matrix[cols, rows] = condensed
    return dist_matrix
def forward(self, qk, v, query_len = None, input_mask = None, input_attn_mask = None):
    """Full attention with shared query/key vectors.

    Queries are the first ``query_len`` (default: all) rows of ``qk``; keys
    are the L2-normalised rows of ``qk``. A token is prevented from attending
    to itself via ``TOKEN_SELF_ATTN_VALUE``. Optional padding masks and
    causal masking (``self.causal``) set logits to ``max_neg_value(dot)``.

    Returns:
        ``(out, attention_probabilities, torch.empty(0))``.
    """
    b, seq_len, dim = qk.shape
    t = default(query_len, seq_len)
    q = qk[:, 0:t]
    keys = F.normalize(qk, 2, dim=-1).type(q.type())

    dot = torch.einsum('bie,bje->bij', q, keys) * (dim ** -0.5)

    # qk attention requires tokens not attend to self
    diag = torch.arange(t)
    dot[:, diag, diag] = TOKEN_SELF_ATTN_VALUE
    masked_value = max_neg_value(dot)

    # Input mask for padding in variable-length sequences
    if input_mask is not None:
        pair_mask = input_mask[:, 0:t, None] * input_mask[:, None, :]
        pair_mask = F.pad(pair_mask, (0, seq_len - pair_mask.shape[-1]), value=True)
        dot.masked_fill_(~pair_mask, masked_value)

    # Mask for post qk attention logits of the input sequence
    if input_attn_mask is not None:
        attn_mask = F.pad(input_attn_mask, (0, seq_len - input_attn_mask.shape[-1]), value=True)
        dot.masked_fill_(~attn_mask, masked_value)

    if self.causal:
        above_i, above_j = torch.triu_indices(t, t, 1)
        dot[:, above_i, above_j] = masked_value

    dot = dot.softmax(dim=-1)
    out = torch.einsum('bij,bje->bie', dot, v)
    return out, dot, torch.empty(0)
def simple_embedding_truth(coords, truth_label_by_hits, device='cpu'):
    """Build labelled pairwise-distance vectors for every truth category.

    For each unique label ``cat`` (in ``torch.unique`` sorted order):
      * in-category distances: every distinct pair (i < j) of hits labelled
        ``cat``, with truth ``1``;
      * out-of-category distances: every (hit in ``cat``, hit not in ``cat``)
        pair, with truth ``-1``.
    All distances are divided by ``(number of unique labels - 1)``.

    Args:
        coords: tensor of per-hit coordinates, one row per hit (presumably
            (hits, dims) -- confirm against caller).
        truth_label_by_hits: 1D tensor of per-hit labels.
        device: unused; kept for backward compatibility.

    Returns:
        List of ``(dists, truth)`` tensor pairs, one per unique label;
        ``truth`` is int64.
    """
    uniques = torch.unique(truth_label_by_hits)
    # NOTE(review): with a single unique label this divisor is 0 and all
    # distances become inf/nan -- confirm callers always pass >= 2 labels.
    norm = uniques.size(0) - 1
    out_truths = []  # list of (dists, truth) tensor pairs

    for cat in uniques:
        thecat = cat.item()
        in_cat = coords[truth_label_by_hits == thecat]
        not_cat = coords[truth_label_by_hits != thecat]

        # Distinct in-category pairs (i < j). The flat index i + n * j reads
        # element (j, i), which equals (i, j) because cdist is symmetric.
        in_cat_dists = torch.cdist(in_cat, in_cat)
        n_in = in_cat_dists.size(0)
        in_idxs = torch.triu_indices(n_in, n_in, offset=1, device=in_cat.device)
        in_idxs = in_idxs[0] + n_in * in_idxs[1]
        in_cat_dists = in_cat_dists.view(-1)[in_idxs] / norm

        # All in-category vs out-of-category distances.
        # TODO (original author): there's a factor of 2 here to deal with.
        not_cat_dists = torch.cdist(in_cat, not_cat).flatten() / norm

        # Build the final labelled distance vectors.
        dists = torch.cat([in_cat_dists, not_cat_dists], dim=0)
        truth = torch.cat([torch.ones_like(in_cat_dists, dtype=torch.int64),
                           torch.full_like(not_cat_dists, -1, dtype=torch.int64)], dim=0)
        out_truths.append((dists, truth))
    return out_truths
def params2orb(params, coeffs, with_penalty):
    """Rotate orbital coefficients with ``exp(K)`` built from ``params``.

    params: (*, nparams) -- values placed in the strict upper triangle of an
        nao x nao generator, in ``torch.triu_indices`` row-major order.
    coeffs: (*, nao, norb)
    The generator is made anti-Hermitian (K - K^H), so ``matrix_exp`` of it is
    unitary. Returns the rotated orbitals; when ``with_penalty`` is true, also
    returns a zero penalty tensor of shape (1,).
    """
    nao = coeffs.shape[-2]
    nparams = params.shape[-1]
    batch_shape = params.shape[:-1]

    upper = torch.triu_indices(nao, nao, offset=1)[..., :nparams]
    generator = torch.zeros((*batch_shape, nao, nao), dtype=params.dtype, device=params.device)
    generator[..., upper[0], upper[1]] = params
    generator = generator - generator.transpose(-2, -1).conj()

    ortho_orb = torch.matrix_exp(generator) @ coeffs
    if not with_penalty:
        return ortho_orb
    penalty = torch.zeros((1, ), dtype=params.dtype, device=params.device)
    return ortho_orb, penalty
def mask(matrices, mask_val, mask_diagonal=True):
    """In place, set the upper triangle of every matrix in a batch to ``mask_val``.

    matrices: tensor of shape (b, h, w); modified in place, returns None.
    mask_diagonal: when True the main diagonal is masked as well, otherwise
        only the strictly upper part.
    """
    _, height, width = matrices.size()
    first_diag = 0 if mask_diagonal else 1
    rows, cols = torch.triu_indices(height, width, offset=first_diag)
    matrices[:, rows, cols] = mask_val
def forward(self, data):
    """Run the dense graph network on a PyG batch.

    Node features go through ``self.inp`` and are densified per graph. The
    dense adjacency is augmented, on its upper triangle (diagonal included),
    with the Euclidean distance between features 0 and 1 of each node pair.

    Returns:
        ``(sigmoid(data.edge_attr), cand_ids, cand_p4)``.
    """
    x = self.inp(data.x)
    xdense, mask = torch_geometric.utils.to_dense_batch(x, data.batch)
    adj_dense = torch_geometric.utils.to_dense_adj(data.edge_index, data.batch, data.edge_attr)

    n_nodes = adj_dense.shape[1]
    # Indices live on the adjacency's own device (previously a module-level
    # ``device`` global, which could disagree with the data).
    rows, cols = torch.triu_indices(n_nodes, n_nodes, device=adj_dense.device)
    adj_dense[:, rows, cols] += torch.sqrt(
        (xdense[:, rows, 0] - xdense[:, cols, 0])**2
        + (xdense[:, rows, 1] - xdense[:, cols, 1])**2)

    x = self.conv(xdense, adj_dense, mask)[mask]
    cand_ids = self.nn1(x)
    cand_p4 = data.x[:, len(elem_to_id):len(elem_to_id) + 4] + self.nn2(
        torch.cat([cand_ids, x], axis=-1))
    # torch.sigmoid: same result as the deprecated torch.nn.functional.sigmoid
    return torch.sigmoid(data.edge_attr), cand_ids, cand_p4
def DeepCoral(source, target):
    """Correlation loss between the strict upper triangles of two matrix batches.

    For each batch item, the strictly upper-triangular entries of ``source``
    and ``target`` (shape ``(B, R, C)``) are flattened and their Pearson
    correlation is computed. Returns ``1 - mean(correlation)`` over the batch.
    """
    rows, cols = torch.triu_indices(source.shape[1], source.shape[2], offset=1)
    src_flat = source[:, rows, cols]
    tgt_flat = target[:, rows, cols]

    src_centered = src_flat - torch.mean(src_flat, 1, keepdim=True)
    tgt_centered = tgt_flat - torch.mean(tgt_flat, 1, keepdim=True)

    # Per-sample dot product of the centred vectors (unnormalised covariance).
    batch = src_centered.shape[0]
    covariance = torch.bmm(src_centered.unsqueeze(1), tgt_centered.unsqueeze(2)).view(batch)
    src_norm = torch.sqrt(torch.sum(src_centered ** 2, dim=1))
    tgt_norm = torch.sqrt(torch.sum(tgt_centered ** 2, dim=1))
    correlation = covariance / (src_norm * tgt_norm)
    return 1 - torch.mean(correlation)
def synchronize(self):
    """Wait for pending allreduces and write the results back into the tensors.

    Each ``(tensor, comm_tensor)`` in ``self._name_tensors`` is unpacked:
    fp16 payloads are converted back with ``.float()``. Symmetric payloads hold
    only the upper triangle, which is scattered back and mirrored onto the
    lower triangle; other payloads are copied whole. When ``self.op`` is
    ``hvd.Average`` the summed result is divided by ``hvd.size()``. The
    bookkeeping containers are cleared afterwards.
    """
    for pending in self.handles:
        hvd.synchronize(pending)
    if self.merge:
        self._tensor_group.pull_alltensors()
        self._tensor_group.clear_group_flags()

    for tensor, comm_tensor in self._name_tensors.values():
        if self.fp16:
            comm_tensor = comm_tensor.float()
        if self.symmetric:
            lower = torch.tril_indices(tensor.shape[0], tensor.shape[1], device=tensor.device)
            upper = torch.triu_indices(tensor.shape[0], tensor.shape[1], device=tensor.device)
            tensor[upper[0], upper[1]] = comm_tensor
            tensor[lower[0], lower[1]] = tensor.t()[lower[0], lower[1]]
        else:
            tensor.copy_(comm_tensor)
        if self.op == hvd.Average:
            tensor.div_(hvd.size())

    self._name_tensors.clear()
    self.handles.clear()
def forward(self, x, mask=False):
    """Multi-head self-attention over ``x`` of shape ``(b, t, k)``.

    Args:
        x: input of shape (batch, seq_len, emb). ``k`` is also used as the
            per-head size.
        mask: not read; causal masking is controlled by ``self.mask``. Kept
            for backward compatibility.

    Returns:
        ``self.unifyhead`` applied to the concatenated heads (b, t, h*k).
    """
    b, t, k = x.size()
    h = self.heads

    # Project and split into heads: (b, t, h*k) -> (b, t, h, k)
    queries = self.toqueries(x).view(b, t, h, k)
    keys = self.tokeys(x).view(b, t, h, k)
    values = self.tovalues(x).view(b, t, h, k)

    # Fold heads into the batch dimension. .contiguous() is required before
    # .view() on a transposed tensor.
    queries = queries.transpose(1, 2).contiguous().view(b * h, t, k)
    keys = keys.transpose(1, 2).contiguous().view(b * h, t, k)
    values = values.transpose(1, 2).contiguous().view(b * h, t, k)

    # Scale q and k by k**(1/4) each (overall 1/sqrt(k)) so softmax inputs
    # stay small and gradients do not vanish.
    queries = queries / (k ** (1 / 4))
    keys = keys / (k ** (1 / 4))

    # dot: (b*h, t, t) raw attention weights
    dot = torch.bmm(queries, keys.transpose(1, 2))

    if self.mask:
        # For text generation, hide the future part of the sequence. The mask
        # must be t x t (sequence length, not embedding size) and strictly
        # above the diagonal, otherwise row 0 is fully masked and softmax
        # yields NaN.
        rows, cols = torch.triu_indices(t, t, offset=1)
        dot[:, rows, cols] = float('-inf')

    dot = F.softmax(dot, dim=2)

    # Apply the self attention to the values
    out = torch.bmm(dot, values).view(b, h, t, k)
    out = out.transpose(1, 2).contiguous().view(b, t, h * k)
    return self.unifyhead(out)
def pdist(self, x, squared=False):
    """Distances between all distinct pairs (i < j) of the matrices in ``x``.

    ``x`` must satisfy ``x.ndim == 3`` (a batch of matrices). For each pair, in
    ``torch.triu_indices`` order, ``tb.axat`` is applied to the first output of
    ``self.invchol`` for item i and to ``x[j]`` (presumably computing
    A X A^T -- confirm), and the result is reduced by ``self._norm_log``.
    """
    assert x.ndim == 3
    count = x.shape[0]
    inv_factors, _ = self.invchol(x)
    first, second = torch.triu_indices(count, count, 1, device=x.device)
    congruent = tb.axat(inv_factors[first], x[second])
    return self._norm_log(congruent, squared=squared)
def synchronize(self):
    """Wait for merged reduces and scatter the results back into their tensors.

    Handles whose destination ``rank`` differs from ``hvd.rank()`` are
    skipped. For the rest, the merged buffer is split into consecutive chunks,
    one per original tensor. Symmetric chunks hold an upper triangle, which is
    scattered back and mirrored onto the lower triangle; other chunks are
    copied whole. Every tensor is then divided by ``hvd.size()``.
    """
    self.merged_comm.synchronize()
    for _handle, names, tensors, comm_tensors, rank in self.handles:
        if rank != hvd.rank():
            continue
        buf = self.merged_tensors[','.join(names)]
        if self.fp16:
            buf = buf.float()
        offset = 0
        for tensor, sent in zip(tensors, comm_tensors):
            numel = sent.numel()
            chunk = buf.data[offset:offset + numel]
            if self.symmetric:
                lower = torch.tril_indices(tensor.shape[0], tensor.shape[1], device=tensor.device)
                upper = torch.triu_indices(tensor.shape[0], tensor.shape[1], device=tensor.device)
                tensor[upper[0], upper[1]] = chunk
                tensor[lower[0], lower[1]] = tensor.t()[lower[0], lower[1]]
            else:
                tensor.copy_(chunk.view(tensor.shape))
            tensor.div_(hvd.size())
            offset += numel
    self.handles.clear()
def test_stein_pdiv(seed, d):
    """Batched ``stein_pdiv`` must equal ``stein_div`` evaluated pair by pair.

    ``seed`` is not referenced in the body (presumably a pytest fixture that
    seeds the RNG -- confirm); ``d`` is the matrix size.
    """
    count = 10
    spd = SPD(2)
    xs = spd.rand(count, ir=1.0, out=torch.empty(count, d, d, dtype=torch.float64))
    pdivs = spd.stein_pdiv(xs)
    left, right = torch.triu_indices(count, count, 1)
    expected = spd.stein_div(xs[left], xs[right])
    assert_allclose(expected, pdivs, atol=1e-4)
def forward(self, x1, x2, **params):
    """Compute the matrix of pairwise ``self.op`` results for two batches.

    The first dimension of ``x1`` and ``x2`` is the batch dimension. If both
    inputs are the same data, only the upper triangle (diagonal included) is
    evaluated and then mirrored with ``self._triu_to_full``. Otherwise all
    index pairs come from ``self._get_index_pairs``. ``**params`` is not used.

    Returns:
        ``(n1, n2)`` tensor of dtype ``self.dtype`` on ``self.device``.
    """
    n1 = x1.shape[0]
    n2 = x2.shape[0]

    # If the operation is computed on one dataset, skip redundant index pairs.
    triu = x1.size() == x2.size() and torch.equal(x1, x2)
    if triu:
        inds = torch.triu_indices(n1, n2)
    else:
        inds = self._get_index_pairs(n1, n2)  # index pairs without looping

    # Expand data so that one batched call of op covers every required pair.
    x1_batch = x1[inds[0]]
    x2_batch = x2[inds[1]]
    result = self.op(x1_batch, x2_batch)
    # op may return several outputs; isinstance also accepts namedtuples,
    # which an exact ``type(...) == tuple`` check rejected.
    if isinstance(result, tuple):
        result = result[self.result_index]

    # Convert flat output to the result matrix (e.g. a distance matrix).
    D = torch.zeros(n1, n2, dtype=self.dtype, device=self.device)
    D[inds[0], inds[1]] = result.to(dtype=self.dtype)
    if triu:
        # Mirror the upper triangle so the full matrix is recovered.
        D = self._triu_to_full(D)
    return D
def reduce_async_(self, names, tensors, rank):
    """Pack ``tensors`` and start a merged reduce towards ``rank``.

    With ``self.fp16`` the payloads are half-precision copies. With
    ``self.symmetric`` only the upper triangle (diagonal included) of each
    square tensor is sent. Several payloads are packed into one flat buffer
    cached in ``self.merged_tensors`` under the comma-joined ``names``; a
    single payload is reduced directly. The tuple
    ``(handle, names, tensors, payloads, rank)`` is appended to
    ``self.handles``.
    """
    payloads = [t.half() for t in tensors] if self.fp16 else tensors
    if self.symmetric:
        triangles = []
        for tensor in payloads:
            side = tensor.shape[0]
            rows, cols = torch.triu_indices(side, side, device=tensor.device)
            triangles.append(tensor[rows, cols])
        payloads = triangles

    name = ','.join(names)
    if len(payloads) > 1:
        # One flat staging buffer per name group, allocated on first use.
        if name not in self.merged_tensors:
            total = sum(t.numel() for t in payloads)
            self.merged_tensors[name] = payloads[0].new_zeros(total)
        buf = self.merged_tensors[name]
        offset = 0
        for t in payloads:
            numel = t.numel()
            buf.data[offset:offset + numel].copy_(t.view(numel))
            offset += numel
    else:
        # A single payload is reduced in place, without a staging copy.
        self.merged_tensors[name] = payloads[0]
        buf = payloads[0]

    handle = self.merged_comm.reduce(buf, rank)
    self.handles.append((handle, names, tensors, payloads, rank))
def neighbor_pairs_nopbc(padding_mask: Tensor, coordinates: Tensor, cutoff: float) -> Tensor:
    """Compute pairs of atoms that are neighbors (doesn't use PBC).

    This function bypasses the calculation of shifts and duplication
    of atoms in order to make calculations faster.

    Arguments:
        padding_mask (:class:`torch.Tensor`): boolean tensor of shape
            (molecules, atoms) for padding mask. 1 == is padding.
        coordinates (:class:`torch.Tensor`): tensor of shape
            (molecules, atoms, 3) for atom coordinates.
        cutoff (float): the cutoff inside which atoms are considered pairs.

    Returns:
        :class:`torch.Tensor`: long tensor of shape (2, pairs). Each column
        holds the two atom indices of a pair whose distance is ``<= cutoff``,
        offset by ``molecule * atoms`` so that they index the flattened
        (molecules * atoms) atom dimension. Pairs that involve a padding atom
        are excluded.
    """
    coordinates = coordinates.detach()
    current_device = coordinates.device
    num_atoms = padding_mask.shape[1]
    num_mols = padding_mask.shape[0]
    # Every unique intra-molecule atom pair (i < j).
    p12_all = torch.triu_indices(num_atoms, num_atoms, 1, device=current_device)
    p12_all_flattened = p12_all.view(-1)
    pair_coordinates = coordinates.index_select(1, p12_all_flattened).view(num_mols, 2, -1, 3)
    distances = (pair_coordinates[:, 0, ...] - pair_coordinates[:, 1, ...]).norm(2, -1)
    # A pair is padding if either of its atoms is padding; push it past the cutoff.
    padding_mask = padding_mask.index_select(1, p12_all_flattened).view(num_mols, 2, -1).any(dim=1)
    distances.masked_fill_(padding_mask, math.inf)
    in_cutoff = (distances <= cutoff).nonzero()
    molecule_index, pair_index = in_cutoff.unbind(1)
    # Shift per-molecule atom indices into the flattened atom dimension.
    molecule_index *= num_atoms
    atom_index12 = p12_all[:, pair_index] + molecule_index
    return atom_index12
def triu_index(num_species: int) -> Tensor:
    """Return a symmetric (num_species, num_species) table of species-pair indices.

    Entries [i, j] and [j, i] both hold the position of the unordered pair
    (i, j), i <= j, in ``torch.triu_indices(num_species, num_species)`` order.
    """
    first, second = torch.triu_indices(num_species, num_species).unbind(0)
    positions = torch.arange(first.numel(), dtype=torch.long)
    table = torch.zeros(num_species, num_species, dtype=torch.long)
    table[first, second] = positions
    table[second, first] = positions
    return table
def second_order(self, batch_size, index, values, embeddings, n_fields, embedding_dim, mats):
    # type: (int, Tensor, Tensor, Tensor, int, int, List[Tensor]) -> Tensor
    """Attention-weighted pairwise (second-order) field interactions.

    For every field pair (i < j) the element-wise product of the two
    embeddings is scored by ``relu(e @ attention_w + attention_b) @ attention_h``
    and softmaxed over pairs. Each pair's product is scaled by its own
    attention weight, projected onto ``attention_p`` and summed over pairs.
    ``index`` and ``values`` are not used here.

    Returns a tensor of shape (batch_size,) (assuming ``attention_p`` has
    ``embedding_dim`` elements).
    """
    attention_w, attention_b, attention_h, attention_p = mats
    num_pairs = n_fields * (n_fields - 1) // 2
    fields = embeddings.view(batch_size, n_fields, embedding_dim)

    pair_idx = torch.triu_indices(n_fields, n_fields, 1)
    left = torch.index_select(fields, 1, pair_idx[0])
    right = torch.index_select(fields, 1, pair_idx[1])
    interactions = left * right  # (batch, num_pairs, embedding_dim)

    hidden = torch.matmul(interactions.view(batch_size, num_pairs, embedding_dim), attention_w)
    hidden = torch.relu(hidden.view(batch_size, -1) + attention_b.view(-1).repeat(num_pairs))
    scores = torch.matmul(hidden.view(batch_size, num_pairs, -1), attention_h)
    attention_weight = F.softmax(scores, dim=1)

    # Broadcast one weight per pair over that pair's embedding dims. The
    # previous ``.repeat(1, embedding_dim)`` tiled weights in pair order while
    # the interactions are laid out pair-major, which misaligned them whenever
    # embedding_dim > 1 and there was more than one pair.
    weighted = attention_weight.view(batch_size, num_pairs, 1) * interactions
    projected = torch.matmul(weighted, attention_p.view(-1))
    return projected.sum(1)
def cal_sal_rank_loss(real_pred, lite_pred, target, margin=0):
    """Ranking loss that teaches ``lite_pred`` the temporal ordering of ``real_pred``.

    real_pred, lite_pred: (B, T, K) scores. For each batch item, the scores of
    class ``target[b, 0]`` are taken at every step. For every pair of steps
    x < y the pseudo label is +1 if real[x] > real[y] and -1 otherwise, and
    ``F.margin_ranking_loss`` (mean reduction) is applied to lite[x], lite[y].
    """
    B, T, K = real_pred.shape
    # (B, T, 1) gather index selecting each item's target class at every step.
    cls = target[:, 0].long().view(B, 1, 1).expand(B, T, 1)
    real_cfd = real_pred.gather(2, cls.to(real_pred.device)).squeeze(2)  # (B, T)
    lite_cfd = lite_pred.gather(2, cls.to(lite_pred.device)).squeeze(2)  # (B, T)

    # All step pairs x < y: upper triangle of (T-1)x(T-1) (diagonal included)
    # with the column index shifted by one.
    x, y = torch.triu_indices(T - 1, T - 1) + torch.tensor([[0], [1]])
    real_cfd_x, real_cfd_y = real_cfd[:, x], real_cfd[:, y]
    lite_cfd_x, lite_cfd_y = lite_cfd[:, x], lite_cfd[:, y]

    psu_label = (real_cfd_x > real_cfd_y).double() * 2 - 1
    return F.margin_ranking_loss(lite_cfd_x, lite_cfd_y, psu_label, margin=margin, reduction="mean")
def forward(self, q, k, v, attn_mask=None):
    """Multi-head dot-product attention.

    ``q``, ``k``, ``v`` have shape (b, n, h*d) and are split into
    ``self.n_heads`` heads. Logits are ``(q * self.scale) . k``. Optional
    masking covers ``attn_mask`` (False entries are masked), self-attention
    suppression (``self.shared_qk``) and causal masking (``self.causal``).
    With ``self.store_attention`` the attention map is kept on CPU in
    ``self.attention``.

    Returns:
        Tensor of shape (b, n, h*d).
    """
    n = q.size(1)
    q, k, v = (rearrange(t, 'b n (h d) -> b h n d', h=self.n_heads) for t in (q, k, v))

    # classic dot-product attention
    dots = torch.einsum('bhid,bhjd->bhij', q * self.scale, k)

    if exists(attn_mask):
        dots.masked_fill_(~attn_mask, MASK_VAL)
        del attn_mask

    if self.shared_qk:
        # With shared query/key vectors, discourage attending to oneself.
        diag = torch.arange(n)
        dots[:, :, diag, diag] = SELF_ATTN_MASK_VAL

    if self.causal:
        rows, cols = torch.triu_indices(n, n, 1)
        dots[:, :, rows, cols] = MASK_VAL

    attn = F.softmax(dots, -1)
    if self.store_attention:
        self.attention = attn.detach().cpu()
    attn = self.dropout(attn)

    out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
    return rearrange(out, 'b h n d -> b n (h d)')
def bcast_async_(self, names, tensors, rank):
    """Pack ``tensors`` into one flat buffer and start an async broadcast from ``rank``.

    With ``self.fp16`` the payloads are half-precision copies. With
    ``self.symmetric`` only the upper triangle (diagonal included) of each
    square tensor is sent. The buffer is cached in ``self.merged_tensors``
    under the comma-joined ``names``; even a single payload is copied into
    it. ``(handle, names, tensors, payloads)`` is appended to ``self.handles``.
    """
    payloads = [t.half() for t in tensors] if self.fp16 else tensors
    if self.symmetric:
        triangles = []
        for tensor in payloads:
            side = tensor.shape[0]
            rows, cols = torch.triu_indices(side, side, device=tensor.device)
            triangles.append(tensor[rows, cols])
        payloads = triangles

    name = ','.join(names)
    if name not in self.merged_tensors:
        total = sum(t.numel() for t in payloads)
        self.merged_tensors[name] = payloads[0].new_zeros(total)
    buf = self.merged_tensors[name]

    offset = 0
    for t in payloads:
        numel = t.numel()
        buf.data[offset:offset + numel].copy_(t.view(numel))
        offset += numel

    handle = hvd.broadcast_async_(buf, rank)
    self.handles.append((handle, names, tensors, payloads))
def get_weights(self, device):
    """Generate the full recurrent weight matrix from the stored coefficients.

    Returns ``cat([W_ih, W_hh], dim=1)``. ``W_ih`` stacks along dim 0 one
    ``self.to_weights`` result per entry of ``self.in_coeffss``, and ``W_hh``
    one per entry of ``self.hid_coeffss``. Coefficients pass through
    ``self.wdrop`` first when ``self.dropout_dct`` is set. (The original
    comment gave the shape as (hidden_dim * 4, input_dim * hidden_dim);
    presumably the second dim is input_dim + hidden_dim -- confirm.)
    """
    def stack(coeffss, num_diags, cols, first_layer, second_layer):
        # Upper-triangle indices of a (hidden_dim x cols) matrix from diagonal
        # offset num_diags, plus one zero buffer shared by every to_weights call.
        ind = torch.triu_indices(self.hidden_dim, cols, num_diags, device=device)
        scratch = torch.zeros([self.hidden_dim, cols], device=device)
        stacked = None
        for coeffs in coeffss:
            if self.dropout_dct:
                coeffs = self.wdrop(coeffs)
            block = self.to_weights(coeffs, ind, scratch, first_layer, second_layer)
            # Concatenate incrementally (not once at the end) to keep the exact
            # semantics if to_weights returns / mutates the shared buffer.
            stacked = block if stacked is None else torch.cat([stacked, block], dim=0)
        return stacked

    # input to hidden
    w_ih = stack(self.in_coeffss, self.in_num_diags, self.input_dim,
                 self.in_dct_layer, self.hid_dct_layer)
    # hidden to hidden
    w_hh = stack(self.hid_coeffss, self.hidden_num_diags, self.hidden_dim,
                 self.hid_dct_layer, self.hid_dct_layer)
    return torch.cat([w_ih, w_hh], dim=1)
def __call__(self, x, target):
    """Gaussian negative log-likelihood (up to constants) with predicted (co)variance.

    :param x: list ``[prediction, log_Sigma]``. ``prediction`` has shape
        [*, C, *] (channels on dim 1). ``log_Sigma`` is exponentiated,
        ``Sigma = exp(log_Sigma) + 1e-6``, and interpreted as:
        (1) if ``self.no_covar`` is False: shape [*, C(C+1)/2, *], the upper
            triangle (diagonal included) of the C x C covariance matrix in
            row-first ``torch.triu_indices`` order
            ``[s_11, s_12, ..., s_1C, s_22, s_23, ..., s_2C, ..., s_CC]``,
            mirrored to make the matrix symmetric;
        (2) if ``self.no_covar`` is True: shape [*, C, *], the diagonal
            variances ``[sigma_1**2, ..., sigma_C**2]`` (off-diagonals zero).
    :param target: true values, shape [*, C, *].
    :return: mean over positions of ``err^T Sigma^{-1} err + log det Sigma``.
    :raises TypeError: if ``x`` is not a ``[prediction, log_Sigma]`` list
        (previously this surfaced as an UnboundLocalError).
    """
    if not isinstance(x, list):
        raise TypeError(
            "expected x to be a list [prediction, log_Sigma], got {}".format(type(x)))
    x, log_Sigma = x[0], x[1]
    Sigma = torch.exp(log_Sigma) + 1e-6

    C, ndims = x.shape[1], x.ndim
    if self.no_covar:
        # Simplified case: diagonal covariance
        assert C == Sigma.shape[1] and Sigma.ndim == ndims,\
            "Inconsistent shape for input data and covariance: {} vs {}".format(x.shape, Sigma.shape)
        assert torch.all(Sigma > 0), "Negative values found in Sigma"
        inv_Sigma = 1. / Sigma  # shape [*, C, *]
        # log det of a diagonal matrix = sum of log variances
        logdet_sigma = torch.sum(log_Sigma, dim=1)  # shape [*, *]
        err = (target - x)  # shape [*, C, *]
        return ((err * inv_Sigma * err).sum(dim=1) + logdet_sigma.squeeze()).mean()

    # General case: full symmetric covariance
    assert (C * (C + 1)) // 2 == Sigma.shape[1] and Sigma.ndim == ndims, \
        "Inconsistent shape for input data and covariance: {} vs {}".format(x.shape, Sigma.shape)
    # Permute dim 1 (channels) to the last position, keeping the others in order.
    swap_channel_last = (0, ) + tuple(range(2, ndims)) + (1, )
    # Re-arrange the covariance into matrices of shape [*, *, C, C].
    covar_shape = (Sigma.shape[0], ) + Sigma.shape[2:] + (C, C)
    triu_ind = torch.triu_indices(row=C, col=C, offset=0, device=x.device)
    entries = Sigma.permute(swap_channel_last)  # [*, *, C(C+1)/2]
    Sigma_ = torch.zeros(covar_shape, dtype=Sigma.dtype, device=x.device)
    # Fill the upper triangle in row-first order, then write the same values
    # at the transposed positions. (Filling the lower triangle in tril_indices
    # order instead misplaced entries and broke symmetry for C >= 3.)
    Sigma_[..., triu_ind[0], triu_ind[1]] = entries
    Sigma_[..., triu_ind[1], triu_ind[0]] = entries

    # Determinant and inverse of the covariance matrices
    logdet_sigma = torch.logdet(Sigma_)  # shape [*, *]
    inv_sigma = torch.inverse(Sigma_)  # shape [*, *, C, C]

    # Log-likelihood of the multivariate Gaussian (up to constants)
    err = (target - x).permute(swap_channel_last).unsqueeze(-1)  # shape [*, *, C, 1]
    return ((err.transpose(-1, -2) @ inv_sigma @ err).squeeze()
            + logdet_sigma.squeeze()).mean()
def q(self) -> torch.Tensor:
    """Return the rate matrix ``Q = R @ diag(frequencies)`` with zero row sums.

    ``R`` is symmetric with a zero diagonal; its strict upper triangle (in
    ``torch.triu_indices`` order) is ``self.rates[self.mapping.tensor]``,
    mirrored to the lower triangle. The diagonal of ``Q`` is set to minus the
    sum of each row's off-diagonal entries.
    """
    n = self.state_count
    rows, cols = torch.triu_indices(n, n, 1)
    pair_rates = self.rates[self.mapping.tensor]
    # Allocate R next to the rates: a CPU buffer cannot hold or be multiplied
    # with CUDA parameters.
    R = torch.zeros((n, n), dtype=self.rates.dtype, device=self.rates.device)
    R[rows, cols] = pair_rates
    R[cols, rows] = pair_rates
    Q = R @ self.frequencies.diag()
    # R's diagonal is zero, so the row sum here covers only off-diagonal rates.
    diag = torch.arange(len(Q))
    Q[diag, diag] = -torch.sum(Q, dim=1)
    return Q