def lstsq(b, y, alpha=0.01):
    """Batched linear least-squares for pytorch with optional ridge (L2) damping.

    Solves the normal equations ``(b^T b + alpha * I) x = b^T y`` independently
    for every element of the batch.

    Parameters
    ----------
    b : shape(L, M, N)
        Design matrices.
    y : shape(L, M)
        Targets.
    alpha : float, optional
        Constant added to the diagonal of ``b^T b``; ``0`` disables it.

    Returns
    -------
    tuple of (coefficients, model, residuals)
        ``coefficients`` has shape (L, N); ``model`` and ``residuals`` have
        shape (L, M).
    """
    bT = b.transpose(-1, -2)
    AA = torch.bmm(bT, b)
    if alpha != 0:
        # torch.diagonal returns a view, so the in-place add updates AA itself.
        diag = torch.diagonal(AA, dim1=1, dim2=2)
        diag += alpha
    RHS = torch.bmm(bT, y[:, :, None])
    # torch.gesv(B, A) no longer exists; torch.linalg.solve(A, B) solves
    # A X = B and takes the coefficient matrix first.
    X = torch.linalg.solve(AA, RHS)
    fit = torch.bmm(b, X)[..., 0]
    res = y - fit
    return X[..., 0], fit, res
def loss_function(self, *args, **kwargs) -> dict:
    r"""Compute the VAE loss with a DIP regulariser.

    KL(N(\mu, \sigma), N(0, 1)) =
        \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}

    :param args: positional tensors ``(recons, input, mu, log_var)``.
    :param kwargs: accepted but not read.
    :return: dict with keys ``'loss'``, ``'Reconstruction_Loss'``, ``'KLD'``
        (the negated KL sum) and ``'DIP_Loss'``.
    """
    recons, target, mu, log_var = args[0], args[1], args[2], args[3]

    # The minibatch weighting (kwargs['M_N']) is disabled; the KL weight is 1.
    kld_weight = 1

    recons_loss = F.mse_loss(recons, target, reduction='sum')

    kld_per_sample = -0.5 * torch.sum(
        1 + log_var - mu ** 2 - log_var.exp(), dim=1)
    kld_loss = torch.sum(kld_per_sample, dim=0)

    # DIP loss.
    # NOTE(review): the mean is taken over dim=1 (per row), not over the
    # batch dimension -- confirm this is the intended centring.
    centered_mu = mu - mu.mean(dim=1, keepdim=True)  # [B x D]
    cov_mu = centered_mu.t().matmul(centered_mu).squeeze()  # [D x D]

    # DIP-II variance term. It is a scalar (mean of the main diagonal of
    # exp(2 * log_var)) broadcast onto every entry of cov_mu.
    # NOTE(review): confirm against the intended formulation.
    variance_term = torch.mean(
        torch.diagonal((2. * log_var).exp(), dim1=0), dim=0)
    cov_z = cov_mu + variance_term  # [D x D]

    cov_diag = torch.diag(cov_z)  # [D]
    cov_offdiag = cov_z - torch.diag(cov_diag)  # [D x D]
    offdiag_penalty = self.lambda_offdiag * torch.sum(cov_offdiag ** 2)
    diag_penalty = self.lambda_diag * torch.sum((cov_diag - 1) ** 2)
    dip_loss = offdiag_penalty + diag_penalty

    loss = recons_loss + kld_weight * kld_loss + dip_loss
    return {
        'loss': loss,
        'Reconstruction_Loss': recons_loss,
        'KLD': -kld_loss,
        'DIP_Loss': dip_loss,
    }
def forward(self, z_a, z_b):
    """Forward head.

    Builds the cross-correlation matrix of the two batch-normalised
    representations and penalises its distance from the identity.

    Args:
        z_a (Tensor): NxD representation from one randomly augmented image.
        z_b (Tensor): NxD representation from another version of augmentation.

    Returns:
        dict[str, Tensor]: A dictionary of loss components.
    """
    N, _ = z_a.shape

    # empirical cross-correlation matrix of this process' samples
    c = self.bn(z_a).T @ self.bn(z_b)
    c.div_(N)

    # Average the matrix over all processes. N is the local batch size, so
    # the summed result must also be divided by the number of processes.
    # When no process group exists the local matrix is used as is.
    dist = torch.distributed
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(c)
        c.div_(dist.get_world_size())

    # The diagonal is modified in place first; the off-diagonal entries read
    # afterwards are not touched by that update.
    on_diag = torch.diagonal(c).add_(-1).pow_(2).sum()
    off_diag = off_diagonal(c).pow_(2).sum()
    loss = on_diag + self.lambd * off_diag

    losses = dict()
    losses["loss"] = loss
    return losses
def forward(self, X):
    """Run the recurrence over the time axis of ``X``.

    ``X`` is indexed as ``X[:, t, :]`` with ``X.shape[0]`` batches and
    ``X.shape[1]`` time steps. Returns the diagonal of the first
    ``visible_size`` columns of the last pre-activation ``U``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    n_batch = X.shape[0]
    n_steps = X.shape[1]
    n_vis = self.visible_size
    state_shape = (n_batch, n_vis, self.full_size)

    # Recurrent state R and input buffer Y both start at zero.
    R = torch.zeros(state_shape, dtype=torch.float32).to(device)
    Y = torch.zeros(state_shape, dtype=torch.float32).to(device)
    Lambda_batch = self.Lambda.repeat(n_batch, 1).view(*state_shape).to(device)

    # Forward path
    for t in range(n_steps):
        # Tile the current frame so every row of the visible block holds it.
        frame = X[:, t, :].repeat(n_vis, 1)
        frame = frame.view(n_vis, n_batch, n_vis).transpose(0, 1)
        Y[:, :, :n_vis] = frame

        recurrent_part = torch.mul(Lambda_batch, self.W(R))
        input_part = torch.mul((1 - Lambda_batch), Y)
        U = recurrent_part + input_part
        R = self.phi(U)

    return torch.diagonal(U[:, :, :n_vis], dim1=-1, dim2=-2)
def forward(self, x):
    """Stack five diagonal/max reductions of a batch of square matrices.

    in:  N x d x m x m
    out: N x (d * basis) x m, with basis = 5
    """
    batch = x.size(0)
    width = x.size(-1)

    diag = torch.diagonal(x, dim1=2, dim2=3)                # N x d x m
    diag_max = torch.max(diag, 2)[0].unsqueeze(-1)          # N x d x 1
    row_max = torch.max(x, 3)[0]                            # N x d x m
    col_max = torch.max(x, 2)[0]                            # N x d x m
    global_max = torch.max(torch.max(x, 2)[0], 2)[0].unsqueeze(-1)

    basis = [
        diag,
        diag_max.expand_as(diag),
        row_max,
        col_max,
        global_max.expand_as(diag),
    ]
    # Stacking on dim 1 yields N x 5 x d x m, which is then flattened so the
    # basis index is the slow axis of the channel dimension.
    return torch.stack(basis, dim=1).reshape(batch, -1, width)
def forward(self, y1, y2):
    """Return the scaled cross-correlation loss for two views.

    Both views go through ``backbone`` then ``projector``. The loss
    penalises the distance of their cross-correlation matrix from identity.
    """
    embed_a = self.projector(self.backbone(y1))
    embed_b = self.projector(self.backbone(y2))

    # empirical cross-correlation matrix
    c = self.bn(embed_a).T @ self.bn(embed_b)

    # normalise by the global batch size, then sum across all gpus
    global_batch = self.per_device_batch_size * self.trainer.num_processes
    c.div_(global_batch)
    self.all_reduce(c)

    # scale_loss multiplies both terms by a constant factor (the original
    # code base exposes it as --scale-loss).
    # The diagonal is updated in place before the off-diagonal is read.
    diag_penalty = torch.diagonal(c).add_(-1).pow_(2).sum()
    on_diag = diag_penalty.mul(self.scale_loss)
    offdiag_penalty = self.off_diagonal(c).pow_(2).sum()
    off_diag = offdiag_penalty.mul(self.scale_loss)

    return on_diag + self.lambd * off_diag
def train(net, data_loader, train_optimizer):
    """Train ``net`` for one pass over ``data_loader``.

    Reads the module-level names ``batch_size``, ``corr_neg_one``, ``lmbda``,
    ``epoch``, ``epochs``, ``feature_dim`` and ``dataset``. Returns the
    running loss divided by the number of samples counted.
    """
    net.train()
    running_loss = 0.0
    seen = 0
    train_bar = tqdm(data_loader)

    for (pos_1, pos_2), _ in train_bar:
        pos_1 = pos_1.cuda(non_blocking=True)
        pos_2 = pos_2.cuda(non_blocking=True)
        _, out_1 = net(pos_1)
        _, out_2 = net(pos_2)

        # Barlow Twins: normalize the representations along the batch
        # dimension, then build the cross-correlation matrix.
        out_1_norm = (out_1 - out_1.mean(dim=0)) / out_1.std(dim=0)
        out_2_norm = (out_2 - out_2.mean(dim=0)) / out_2.std(dim=0)
        c = torch.matmul(out_1_norm.T, out_2_norm) / batch_size

        on_diag = torch.diagonal(c).add_(-1).pow_(2).sum()
        off_values = off_diagonal(c)
        if corr_neg_one is False:
            # original Barlow Twins loss: push off-diagonal entries to zero
            off_diag = off_values.pow_(2).sum()
        else:
            # HSIC-inspired variant: push off-diagonal entries to minus one
            off_diag = off_values.add_(1).pow_(2).sum()
        loss = on_diag + lmbda * off_diag

        train_optimizer.zero_grad()
        loss.backward()
        train_optimizer.step()

        # Every step is counted as a full batch of ``batch_size`` samples.
        seen += batch_size
        running_loss += loss.item() * batch_size

        off_corr = -1 if corr_neg_one is True else 0
        train_bar.set_description(
            'Train Epoch: [{}/{}] Loss: {:.4f} off_corr:{} lmbda:{:.4f} bsz:{} f_dim:{} dataset: {}'.format(
                epoch, epochs, running_loss / seen, off_corr, lmbda,
                batch_size, feature_dim, dataset))

    return running_loss / seen
def contrastive_loss(z1, z2, temp=1):
    """Contrastive cross-entropy loss between two batches of embeddings.

    For row ``i`` the positive pair is ``(z1[i], z2[i])``. The negatives are
    every other row of ``z1`` and every row of ``z2``. All cosine
    similarities, including the positive one, are divided by ``temp``.

    Args:
        z1: (N, D) embeddings.
        z2: (N, D) embeddings.
        temp: temperature applied to the cosine similarities.

    Returns:
        Scalar tensor: the mean loss over the N rows.
    """
    # get unit vectors
    z1_unit = z1 / torch.norm(z1, p=2, dim=1).view(-1, 1)
    z2_unit = z2 / torch.norm(z2, p=2, dim=1).view(-1, 1)

    # z_i * z_j for all i, j in z1, and z_i * z_k for i in z1, k in z2
    intra_cos_sims = z1_unit @ z1_unit.T
    inter_cos_sims = z1_unit @ z2_unit.T
    logits = torch.cat((intra_cos_sims, inter_cos_sims), dim=1) / temp

    # The positive logit uses the same temperature as the denominator.
    positive_logits = torch.diagonal(inter_cos_sims) / temp

    # Remove the z_i * z_i term from the denominator. Its value is
    # exp(1 / temp), which equals e only when temp == 1, so the actual
    # diagonal entry is subtracted instead of a constant.
    self_terms = torch.exp(torch.diagonal(intra_cos_sims) / temp)
    sum_exp = torch.sum(torch.exp(logits), dim=1) - self_terms

    # -log(exp(p) / S) == log(S) - p
    losses = torch.log(sum_exp) - positive_logits
    return torch.mean(losses)
def forward(self, bert_token_b, bert_segment_b, bert_masks_b, bert_clause_b,
            doc_len, adj, y_mask_b):
    """Encode a batch of documents and predict clause labels and couples.

    ``doc_len`` and ``adj`` are accepted but not read here.

    Returns:
        ``(couples_pred, pred_e, pred_c)``: ``couples_pred`` comes from the
        full table ``T``; ``pred_e`` and ``pred_c`` come from its diagonal.
    """
    token_ids = bert_token_b.to(DEVICE)
    attention = bert_masks_b.to(DEVICE)
    segments = bert_segment_b.to(DEVICE)
    bert_output = self.bert(input_ids=token_ids,
                            attention_mask=attention,
                            token_type_ids=segments)

    # Pick the positions listed in bert_clause_b, then reduce them.
    doc_sents_h = self.batched_index_select(bert_output, bert_clause_b.to(DEVICE))
    reduced = self.reduce(doc_sents_h)

    # Pairwise table over the reduced representations: (B, N, N, H)
    X = self.seq2mat(reduced, reduced)
    X = F.relu(X, inplace=True)

    # A table cell is valid only when both of its positions are valid.
    valid = torch.IntTensor(y_mask_b).to(DEVICE)
    masks = valid.unsqueeze(1) & valid.unsqueeze(2)

    T, _ = self.mdrnn(X, states=None, masks=masks)

    # torch.diagonal moves the diagonal axis last; permute puts it back in
    # the middle before the per-position predictors.
    diagonal = torch.diagonal(T, offset=0, dim1=1, dim2=2).permute(0, 2, 1)
    pred_e, pred_c = self.pred(diagonal)
    couples_pred = self.tab_pred(T)
    return couples_pred, pred_e, pred_c
def backward(ctx, dK):
    r"""
    Backward function from the affinity matrix :math:`\mathbf K` to the two
    saved affinity matrices :math:`\mathbf K_e` and :math:`\mathbf K_p`
    (presumably the edge-wise and node-wise affinities -- confirm against the
    matching forward).

    Returns a 6-tuple ``(dKe, dKp, None, None, None, None)``. A gradient is
    ``None`` when ``ctx.needs_input_grad`` does not request it.
    """
    Ke, Kp = ctx.saved_tensors
    Kro1t, Kro2t = ctx.K
    want_dKe, want_dKp = ctx.needs_input_grad[0], ctx.needs_input_grad[1]

    dKe = None
    dKp = None

    if want_dKe:
        dKe = bilinear_diag_torch(Kro1t, dK.contiguous(), Kro2t)
        # Reshape with the last two dims swapped, then transpose back.
        dKe = dKe.view(Ke.shape[0], Ke.shape[2], Ke.shape[1]).transpose(1, 2)

    if want_dKp:
        dKp = torch.diagonal(dK, dim1=-2, dim2=-1)
        dKp = dKp.view(Kp.shape[0], Kp.shape[2], Kp.shape[1]).transpose(1, 2)

    return dKe, dKp, None, None, None, None
def test(self, validation=False):
    """Evaluate ``netF``/``netC`` on a loader and print accuracy statistics.

    Uses ``source_val_loader`` when ``validation`` is true, otherwise
    ``target_loader``.

    Returns:
        ``(test_loss.data, mean_class_acc, net_class_acc)``.
    """
    self.netF.eval()
    self.netC.eval()
    test_loss = 0
    correct = 0
    size = 0
    num_class = self.nclasses
    output_all = np.zeros((0, num_class))
    # Rows are true labels, columns are predictions.
    confusion_matrix = torch.zeros(num_class, num_class)

    loader = self.source_val_loader if validation else self.target_loader

    with torch.no_grad():
        for imgs, labels in loader:
            imgs = imgs.cuda()
            labels = labels.cuda()
            logits = self.netC(self.netF(imgs))
            output_all = np.concatenate(
                (output_all, logits.data.cpu().numpy()), axis=0)
            size += imgs.size(0)
            # index of the max logit per row
            pred = logits.data.max(1)[1]
            for truth, guess in zip(labels.view(-1), pred.view(-1)):
                confusion_matrix[truth.long(), guess.long()] += 1
            correct += pred.eq(labels.data).cpu().sum()
            test_loss += self.criterion(logits, labels) / len(loader)

    overall_pct = 100. * (float(correct) / size)
    print(
        '\nTest set: Average loss: {:.4f}, Accuracy: {}/{} C ({:.0f}%)\n'.
        format(test_loss, correct, size, overall_pct))

    # Per-class accuracy: correct count over the row total of each class.
    per_class = torch.diagonal(confusion_matrix).float() / torch.sum(
        confusion_matrix, dim=1)
    per_class = per_class * 100.0
    print('Classwise accuracy')
    print(per_class)

    mean_class_acc = torch.mean(per_class)
    net_class_acc = 100. * float(correct) / size
    return test_loss.data, mean_class_acc, net_class_acc
def banddiag(orig_x, lu, ld, fill=0):
    """Extract a band of ``lu + ld + 1`` entries per column of ``orig_x``.

    ``orig_x`` is zero-padded with ``lu`` rows before and ``ld`` rows after
    along dim -2. Entry ``[..., i, k]`` of the first result is the padded
    tensor at row ``i + k``, column ``i``. ``fill`` is accepted but not used.

    Returns:
        ``(band, middle)``: ``middle`` is the slice of the padded tensor
        that holds the original rows.
    """
    top_shape = list(orig_x.shape)
    top_shape[-2] = lu
    bottom_shape = list(orig_x.shape)
    bottom_shape[-2] = ld

    zeros_above = torch.zeros(*top_shape, device=orig_x.device, dtype=orig_x.dtype)
    zeros_below = torch.zeros(*bottom_shape, device=orig_x.device, dtype=orig_x.dtype)
    padded = torch.cat([zeros_above, orig_x, zeros_below], dim=-2)

    # Sliding windows of height lu + ld + 1 along the row axis.
    windows = padded.unfold(-2, lu + ld + 1, 1)
    band = torch.diagonal(windows, 0, -3, -2).transpose(-2, -1)
    middle = padded.narrow(-2, lu, orig_x.shape[-2])
    return band, middle
def forward(self, x, labels):
    '''
    Additive-margin softmax loss on cosine logits.

    input shape (N, in_features); labels shape (N,) with values in
    [0, out_features).

    The rows of ``self.fc.weight`` and the rows of ``x`` are both
    L2-normalised, so the logits are cosines. The previous loop over
    ``self.fc.parameters()`` only rebound a local name and never normalised
    the weights. Assumes ``self.fc`` exposes a 2-D ``weight``, as
    ``nn.Linear`` does.

    Returns the scalar mean negative log-likelihood.
    '''
    assert len(x) == len(labels)
    assert torch.min(labels) >= 0
    assert torch.max(labels) < self.out_features

    weight = F.normalize(self.fc.weight, dim=1)
    x = F.normalize(x, dim=1)
    wf = F.linear(x, weight, getattr(self.fc, 'bias', None))

    # Logit of the true class of every row, with the margin applied.
    target_index = labels.view(-1, 1).long()
    target_logit = wf.gather(1, target_index).squeeze(1)
    numerator = self.s * (target_logit - self.m)

    # Sum of exp over every class except the true one, without a Python loop.
    is_target = torch.zeros_like(wf, dtype=torch.bool).scatter_(1, target_index, True)
    other_sum = torch.exp(self.s * wf).masked_fill(is_target, 0.0).sum(dim=1)

    denominator = torch.exp(numerator) + other_sum
    L = numerator - torch.log(denominator)
    return -torch.mean(L)
def test_to_hermitian(self):
    """Check three properties of the output of ``sm.to_hermitian``."""
    matrices = torch.rand(4, 2, 3, 3)
    hermitian = sm.to_hermitian(matrices)
    real_part = sm.real(hermitian)
    imag_part = sm.imag(hermitian)

    # 1 - the real part is symmetric
    self.assertAllEqual(real_part, real_part.transpose(-1, -2))

    # 2 - the imaginary part has a zero diagonal
    imag_diag = torch.diagonal(imag_part, dim1=-2, dim2=-1)
    self.assertAllEqual(imag_diag, torch.zeros_like(imag_diag))

    # 3 - the strictly upper triangle of the imaginary part is the negated
    #     transpose of the strictly lower triangle
    upper = torch.triu(imag_part, diagonal=1)
    lower = torch.tril(imag_part, diagonal=-1)
    self.assertAllEqual(upper, lower.transpose(-1, -2) * -1)
def keep(state1, state2):
    """Decide, per row, whether a pair of states is kept.

    A row is kept when the first column of both states is finite, some other
    row is closer than ``radius`` in either state, and the acceleration flag
    is set. The flag is shared by all rows, see below. ``radius`` and
    ``acc_abs`` come from the enclosing scope.
    """
    valid_state1 = torch.isfinite(state1[:, 0])
    valid_state2 = torch.isfinite(state2[:, 0])

    dist1 = PedPedPotential.norm_r_ab(PedPedPotential.r_ab(state1))
    dist2 = PedPedPotential.norm_r_ab(PedPedPotential.r_ab(state2))
    near = (dist1 < radius) | (dist2 < radius)
    # A row is never "near" itself: clear the diagonal through its view.
    torch.diagonal(near)[:] = False
    near = torch.any(near, dim=-1)

    # Columns 4:6 are compared against the acceleration threshold.
    accelerating = (torch.abs(state1[:, 4:6]) > acc_abs) | (
        torch.abs(state2[:, 4:6]) > acc_abs)
    accelerating = torch.any(accelerating, dim=-1)

    # keep 10% of samples without acc:
    accelerating[~accelerating] = (
        torch.rand(accelerating[~accelerating].shape) < 0.1)
    # symmetrize: every row gets the same flag
    accelerating[:] = torch.any(accelerating)

    return valid_state1 & valid_state2 & near & accelerating
def decode(self, spin_up, spin_down, num_qubits):
    """Compute one coefficient per state returned by ``seq_gen(num_qubits)``.

    For every state, the matrices selected by its symbols ('0' picks
    ``spin_up``, '1' picks ``spin_down``) are multiplied in order and the
    trace of the product is taken. The original comment gives the inputs as
    shape (1, mps.size, mps.size).

    Returns the coefficients concatenated along dim 1, then squeezed.
    """
    site_matrices = {'0': spin_up, '1': spin_down}
    amplitudes = []

    for symbols in seq_gen(num_qubits):
        product = site_matrices[symbols[0]]
        for symbol in symbols[1:]:
            product = torch.matmul(product, site_matrices[symbol])
        # trace of the product, kept as a trailing dimension of size 1
        trace = torch.sum(
            torch.diagonal(product, dim1=-1, dim2=-2), dim=-1, keepdim=True)
        amplitudes.append(trace)

    first, rest = amplitudes[0], amplitudes[1:]
    return torch.cat([first, *rest], dim=1).squeeze()
def forward(self, x):
    """Compute a log2-normalised diagonal track, its difference and boundaries.

    Steps:
      1. ``iv`` is the diagonal (dims 2, 3) of ``self.di_pool(x)``, divided
         by its overall mean and passed through log2.
      2. ``dv`` is ``top_pool`` of ``iv`` shifted by ``deriv_size`` minus
         ``bottom_pool`` of the unshifted ``iv``.
      3. Position ``k`` is flagged when ``dv[k] < 0`` and ``dv[k + 2] > 0``.

    Returns:
        ``(iv, dv, boundaries)``. ``boundaries`` is a list with one index
        tensor per batch element, taken from channel 0 and offset by
        ``window_radius + deriv_size``.
    """
    iv = self.di_pool(x)
    iv = torch.diagonal(iv, dim1=2, dim2=3)
    iv = torch.log2(iv / torch.mean(iv))

    top = self.top_pool(iv[:, :, self.deriv_size:])
    bottom = self.bottom_pool(iv[:, :, :-self.deriv_size])
    dv = top - bottom

    # new_zeros inherits dv's device and dtype; zeros created on the default
    # device would make torch.cat fail for non-CPU inputs.
    pad = dv.new_zeros(dv.shape[0], dv.shape[1], 2)
    left = torch.cat([pad, dv], dim=2)
    right = torch.cat([dv, pad], dim=2)

    band = (left < 0) & (right > 0)
    band = band[:, :, 2:-2]

    offset = self.window_radius + self.deriv_size
    boundaries = [
        torch.where(band[i, 0])[0] + offset for i in range(band.shape[0])
    ]
    return iv, dv, boundaries
def forward(self, cov: Tensor) -> Tensor:
    """Rescale a covariance matrix by a per-feature factor.

    In training mode the factor is built from the batch mean of the
    variances (the diagonal of ``cov``), which is also passed to
    ``self.track_mean``. Otherwise ``self.running_mean`` is used. When
    ``self.bias_var`` is set, its covariance form is added at the end.
    """
    var = torch.diagonal(cov, dim1=-1, dim2=-2)

    if self.training:
        mean_var = torch.mean(var, dim=0, keepdim=True)
        self.track_mean(mean_var)
    else:
        mean_var = self.running_mean

    std = torch.sqrt(mean_var + self.eps)
    if self.weight is None:
        scale = 1 / std
    else:
        scale = self.weight / std

    # The outer product of the scale gives the factor for every entry.
    cov = cov * torch.matmul(scale.unsqueeze(-1), scale.unsqueeze(-2))

    if self.bias_var is not None:
        cov = cov + mnn_core.nn.functional.var2cov(self.bias_var)
    return cov
def update(self, y_pred: torch.Tensor, cat_c: torch.Tensor):
    """Accumulate the mean pairwise cosine similarity of the top-ranked rows.

    Rows of ``cat_c`` are reordered by descending ``y_pred`` (truncated to
    the top ``self.k`` when ``k`` is set). The average cosine similarity over
    all ordered pairs of distinct rows is added to ``self.ils_cat``, and
    ``self.count`` is incremented.

    NOTE(review): when only one row is selected the divisor is zero --
    confirm callers never pass such input.
    """
    if self.k is None:
        order = torch.argsort(input=y_pred, descending=True)
    else:
        # never request more items than the sequence holds
        k = min(y_pred.shape[0], self.k)
        _, order = torch.topk(input=y_pred, k=k, largest=True)

    ranked = cat_c[order]
    unit_rows = F.normalize(ranked, dim=-1)
    similarity = unit_rows @ unit_rows.T

    # Drop the self-similarities on the diagonal.
    pair_sum = torch.sum(similarity) - torch.sum(torch.diagonal(similarity, 0))
    n_items = ranked.size(0)
    cat_score = pair_sum / (n_items * (n_items - 1))

    self.ils_cat += cat_score
    self.count += 1.0
def forward(ctx, input):
    """Compute batched determinants from an LU factorisation.

    ``input`` is indexed as a batch of square matrices with shape
    (B, n, n). The determinant is the product of the diagonal of U times
    ``(-1) ** s``, where ``s`` counts the pivot entries that differ from the
    identity permutation.

    Saves ``(input, det)`` on ``ctx`` and returns ``det``.
    """
    # Tensor.lu is deprecated in favour of torch.linalg.lu_factor; fall back
    # to it only on builds that lack the new function.
    if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'lu_factor'):
        inp_lu, pivots = torch.linalg.lu_factor(input)
    else:
        inp_lu, pivots = input.lu()

    # Pivots are 1-based. Build the identity permutation on the pivots' own
    # device so the comparison also works for non-CPU inputs.
    n = input.shape[1]
    identity = torch.arange(1, n + 1, device=pivots.device, dtype=pivots.dtype)
    s = (pivots != identity).sum(1).type(torch.get_default_dtype())

    # The packed LU matrix stores U on and above its diagonal, so the
    # diagonal of U can be read directly without unpacking.
    d = torch.diagonal(inp_lu, dim1=-2, dim2=-1).prod(1)

    det = ((-1) ** s * d)
    ctx.save_for_backward(input, det)
    return det
def _F1F2(mean: Tensor,
          cov: Tensor,
          lower: Tensor,
          upper: Tensor) -> Tuple[Tensor, Tensor]:
    """Compute the ``F1`` and ``F2`` terms for censored observations.

    An entry counts as censored on a side when its bound is finite and its
    z-score is within 4 standard deviations on that side. Entries censored
    on neither side are left at zero.
    """
    is_cens_up = torch.isfinite(upper)
    is_cens_lo = torch.isfinite(lower)
    std = torch.diagonal(cov, dim1=-2, dim2=-1).sqrt()

    # mask out the infs before any gradients are being tracked:
    alpha = torch.zeros_like(mean)
    alpha[is_cens_lo] = (lower[is_cens_lo] - mean[is_cens_lo]) / std[is_cens_lo]
    beta = torch.zeros_like(mean)
    beta[is_cens_up] = (upper[is_cens_up] - mean[is_cens_up]) / std[is_cens_up]

    # _F1F2_no_inf is unstable for large z-scores, so those entries are
    # treated like infinite bounds as well.
    is_cens_up = is_cens_up & (beta.data < 4.)
    is_cens_lo = is_cens_lo & (alpha.data > -4.)
    both = is_cens_up & is_cens_lo
    lower_only = is_cens_lo & ~is_cens_up
    upper_only = ~is_cens_lo & is_cens_up

    sqrt_2 = 2. ** .5
    x = alpha / sqrt_2
    y = beta / sqrt_2

    # uncensored entries stay at zero
    F1 = torch.zeros_like(mean)
    F2 = torch.zeros_like(mean)

    # censored both:
    F1[both], F2[both] = _F1F2_no_inf(x[both], y[both])

    # censored lower, uncensored upper:
    x_lo = x[lower_only]
    F1[lower_only] = 1. / erfcx(x_lo)
    F2[lower_only] = x_lo / erfcx(x_lo)

    # uncensored lower, censored upper:
    neg_y_up = -y[upper_only]
    F1[upper_only] = -1. / erfcx(neg_y_up)
    F2[upper_only] = neg_y_up / erfcx(neg_y_up)

    return F1, F2
def _sampson_dist(F, X, Y, if_homo=False): if not if_homo: X = utils_misc._homo(X) Y = utils_misc._homo(Y) if len(X.size()) == 2: nominator = (torch.diag(Y @ F @ X.t()))**2 Fx1 = torch.mm(F, X.t()) Fx2 = torch.mm(F.t(), Y.t()) denom = Fx1[0]**2 + Fx1[1]**2 + Fx2[0]**2 + Fx2[1]**2 else: nominator = (torch.diagonal(Y @ F @ X.transpose(1, 2), dim1=1, dim2=2))**2 Fx1 = torch.matmul(F, X.transpose(1, 2)) Fx2 = torch.matmul(F.transpose(1, 2), Y.transpose(1, 2)) denom = Fx1[:, 0]**2 + Fx1[:, 1]**2 + Fx2[:, 0]**2 + Fx2[:, 1]**2 # print(nominator.size(), denom.size()) errors = nominator / denom return errors
def re_loss_direct(self, x):
    """Closed-form reconstruction loss for linear encoder and decoder maps.

    The result is the sum of three terms:
      * ``log_term``: ``(D / 2) * (log(2 * pi) + 2 * log_s)``
      * ``norm_term``: mean over rows of ``||x (W M - I)||^2``, over ``2 s^2``
      * ``trace_term``: ``sum(diag(W^T W) * exp(log_S))``, over ``2 s^2``

    Here ``W`` is ``decoder.W.weight``, ``M`` is ``encoder.M.weight`` and
    ``s`` is ``exp(decoder.log_s)``.
    """
    W = self.decoder.W.weight
    M = self.encoder.M.weight

    log_term = (self.D / 2) * (np.log(2 * np.pi) + 2 * self.decoder.log_s)
    s = torch.exp(self.decoder.log_s)
    two_s_sq = 2 * s * s

    # Build the identity on the weights' device and dtype; a default (CPU)
    # identity breaks the subtraction when the model lives on another device.
    identity = torch.eye(self.D, device=W.device, dtype=W.dtype)
    x_mat = torch.matmul(W, M) - identity
    x_vect = torch.matmul(x, x_mat)
    x_vect_sq_sum = torch.sum(torch.mul(x_vect, x_vect), dim=-1)
    norm_term = torch.mean(x_vect_sq_sum) / two_s_sq

    S = torch.exp(self.encoder.log_S)
    trace_mat = torch.matmul(torch.transpose(W, 0, 1), W)
    trace_diag = torch.diagonal(trace_mat).reshape(S.shape)
    trace_term = torch.sum(trace_diag * S) / two_s_sq

    return log_term + norm_term + trace_term
def _get_constraints(self, y_train_labeled, n, augment):
    """
    Get constraints for the equivalence matrix based on the labeled data.

    :param y_train_labeled: Labels for the labeled subset of the batch, or
        None when the batch has no labels
    :param n: Batch size (labeled + unlabeled observations)
    :param augment: Number of augmented copies of each input example to use.
        If 0 then no augmentation is performed
    :return: mask: Binary matrix with value 0 in entry (i,j) if it is known
        whether i and j belong to the same class and 1 else
    :return: known: Binary matrix with value 1 in entry (i,j) if it is known
        that i and j belong to the same class and 0 else
    """
    has_labels = y_train_labeled is not None
    if has_labels:
        nl = len(y_train_labeled)

    # Start from "nothing known": mask is all ones, known is all zeros.
    mask = (torch.BoolTensor(n, n).zero_() + 1).to(defaults.device)
    known = torch.zeros(n, n).to(defaults.device)

    if has_labels:
        # Labeled pairs: known exactly, from the one-hot Gram matrix.
        y_labeled_one_hot = opt_utils.one_hot_embedding(
            y_train_labeled, self.params.nclasses)
        known[0:nl, 0:nl] = y_labeled_one_hot.mm(y_labeled_one_hot.t())
        mask[0:nl, 0:nl] = 0

    # Every observation is in the same class as itself.
    torch.diagonal(known).fill_(1)
    torch.diagonal(mask).fill_(0)

    if augment > 0:
        if not has_labels:
            nl = 0
        nu = n - nl
        # Each consecutive group of `augment` unlabeled rows is treated as
        # copies of one example, so the group is known to share a class.
        for i in range(int(nu // augment)):
            start = nl + i * augment
            stop = nl + (i + 1) * augment
            known[start:stop, start:stop] = 1
            mask[start:stop, start:stop] = 0

    return mask, known
def _forward_equivariance(self, src, tgt, equi_src, equi_tgt, T):
    """Combine the invariance loss with a loss on the equivariant features.

    The target equivariant feature is interpolated with ``T`` and both
    features are flattened per batch element. The diagonal of their pairwise
    distance matrix holds the positive distances.

    Returns:
        ``(total_loss, inv_info, equi_info)``. ``inv_info`` is
        ``[inv_loss, acc, fp, cn]`` and ``equi_info`` is
        ``[equi_loss, accuracy, mean positive, mean negative]``.
    """
    inv_loss, acc, fp, cn = self._forward_invariance(src, tgt)

    bdim = src.size(0)

    # equi feature: nb, nc, na. Only the target side is interpolated.
    tgt_feat = self._interpolate(equi_tgt, T, sigma=self.sigma).view(bdim, -1)
    src_feat = equi_src.view(bdim, -1)

    all_dist = pairwise_distance_matrix(src_feat, tgt_feat)
    furthest_positive = torch.diagonal(all_dist)
    closest_negative = batch_hard_negative_mining(all_dist)

    diff = furthest_positive - closest_negative
    if self.loss == 'hard':
        diff = F.relu(diff + self.margin)
    elif self.loss == 'soft':
        diff = F.softplus(diff, beta=self.margin)
    elif self.loss == 'contrastive':
        diff = furthest_positive + F.relu(self.margin - closest_negative)

    # evaluate accuracy: matches between the k nearest indices and gt_idx
    _, idx = torch.topk(all_dist, k=self.k_precision, dim=1, largest=False)
    hits = torch.sum(idx.int() == self.gt_idx).float()
    accuracy = hits / float(bdim)

    equi_loss = diff.mean()
    total_loss = inv_loss + self.alpha * equi_loss

    inv_info = [inv_loss, acc, fp, cn]
    equi_info = [
        equi_loss,
        accuracy,
        furthest_positive.mean(),
        closest_negative.mean(),
    ]
    return total_loss, inv_info, equi_info
def KRt_from_projection(
        P: torch.Tensor,
        eps: float = 1e-6) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    r"""Decompose the Projection matrix into Camera-Matrix, Rotation Matrix and Translation vector.

    Args:
        P: the projection matrix with shape :math:`(B, 3, 4)`.
        eps: value added to the diagonal of the upper-triangular factor
            before its sign is taken.

    Returns:
        - The Camera matrix with shape :math:`(B, 3, 3)`.
        - The Rotation matrix with shape :math:`(B, 3, 3)`.
        - The Translation vector with shape :math:`(B, 3, 1)`.

    Raises:
        AssertionError: if ``P`` is not a 3-dimensional tensor whose last
            two dimensions are ``(3, 4)``.
    """
    if P.shape[-2:] != (3, 4):
        raise AssertionError("P must be of shape [B, 3, 4]")
    if len(P.shape) != 3:
        raise AssertionError

    left_block = P[:, 0:3, 0:3]
    last_column = P[:, 0:3, 3].unsqueeze(-1)

    # Trick to turn QR-decomposition into RQ-decomposition: flip with the
    # anti-diagonal permutation, factorise the transpose, then undo both.
    flip = torch.tensor([[0, 0, 1], [0, 1, 0], [1, 0, 0]],
                        device=P.device, dtype=P.dtype).unsqueeze(0)
    flipped = torch.matmul(flip, left_block).permute(0, 2, 1)
    ortho_mat, upper_mat = linalg_qr(flipped)
    ortho_mat = torch.matmul(flip, ortho_mat.permute(0, 2, 1))
    upper_mat = torch.matmul(
        flip, torch.matmul(upper_mat.permute(0, 2, 1), flip))

    # Turn the diagonal of `upper_mat` positive by moving the signs over to
    # the orthogonal factor.
    diagonals = torch.diagonal(upper_mat, dim1=-2, dim2=-1) + eps
    signs_mat = torch.diag_embed(torch.sign(diagonals))

    K = torch.matmul(upper_mat, signs_mat)
    R = torch.matmul(signs_mat, ortho_mat)
    t = torch.matmul(torch.inverse(K), last_column)
    return K, R, t
def forward(self, AttM, input, label=None, TrainOrTest='Train', clsGroup=None,
            showAtt=False):
    """Score every input against attention-reweighted class classifiers.

    ``label`` and ``clsGroup`` are accepted but not read.

    Returns:
        * Train: ``out``, or ``(out, att)`` when ``showAtt`` is set.
        * otherwise: ``(out, out / scale_cls)``, with ``att`` appended when
          ``showAtt`` is set.
    """
    cls_num = len(AttM)
    x = F.normalize(input)

    # Per-sample attention over attribute dimensions, shifted by one.
    # NOTE(review): softmax is called without an explicit dim -- confirm the
    # implicit axis is the intended one.
    att = self.ReLU(self.img_guaid(x))
    att = F.softmax(att / self.tmp).reshape(-1, 1, self.att_dims)
    att = att + torch.ones_like(att)
    att = att.repeat(1, cls_num, 1)

    # Give every sample its own attention-weighted copy of the attributes.
    AttM = AttM.reshape(1, -1, self.att_dims).repeat(len(input), 1, 1)
    AttM = att * AttM

    hidden = self.ReLU1(self.L1(AttM))
    classifier = self.ReLU2(self.L2(hidden))
    classifier = F.normalize(classifier, p=2, dim=-1, eps=1e-12)

    # Every sample is scored against every sample's classifiers; only the
    # matching pairs on the diagonal are kept.
    out = self.scale_cls * (torch.matmul(classifier, x.t()) + self.bias)
    out = torch.diagonal(out.permute(1, 0, 2), offset=0, dim1=1, dim2=2).t()

    if TrainOrTest == 'Train':
        return (out, att) if showAtt else out
    if showAtt:
        return out, out / self.scale_cls, att
    return out, out / self.scale_cls
def randneg_train(query_encode_func, doc_encode_func, input_query_ids,
                  query_attention_mask, input_doc_ids, doc_attention_mask,
                  other_doc_ids=None, other_doc_attention_mask=None,
                  hard_pair_mask=None):
    """Pairwise softmax loss of each positive score against extra documents.

    The positive score of query ``i`` is its dot product with document
    ``i``. Every query is paired with every document in ``other_doc_ids``,
    and the loss of a pair is the negative log-softmax of the positive score
    over ``[positive, other]``.

    NOTE(review): ``other_doc_ids`` and ``other_doc_attention_mask`` default
    to None but are used unconditionally.

    Returns:
        A 1-tuple with the summed loss divided by the number of pairs, or by
        the sum of ``hard_pair_mask`` when it is given.
    """
    query_embs = query_encode_func(input_query_ids, query_attention_mask)
    doc_embs = doc_encode_func(input_doc_ids, doc_attention_mask)

    with autocast(enabled=False):
        in_batch_scores = torch.matmul(query_embs, doc_embs.T)
        single_positive_scores = torch.diagonal(in_batch_scores, 0)

    # other_doc_ids: batch size, per query doc, length
    other_doc_num = other_doc_ids.shape[0] * other_doc_ids.shape[1]
    flat_ids = other_doc_ids.reshape(other_doc_num, -1)
    flat_mask = other_doc_attention_mask.reshape(other_doc_num, -1)
    other_doc_embs = doc_encode_func(flat_ids, flat_mask)

    with autocast(enabled=False):
        other_scores = torch.matmul(query_embs, other_doc_embs.T).reshape(-1)
        # Repeat each positive once per extra document to line up the pairs.
        positive_scores = single_positive_scores.reshape(-1, 1).repeat(
            1, other_doc_num).reshape(-1)
        pair_logits = torch.cat(
            [positive_scores.unsqueeze(1), other_scores.unsqueeze(1)], dim=1)
        other_loss = -1.0 * F.log_softmax(pair_logits, dim=1)[:, 0]

        if hard_pair_mask is not None:
            hard_pair_mask = hard_pair_mask.reshape(-1)
            other_loss = other_loss * hard_pair_mask
            second_num = hard_pair_mask.sum()
        else:
            second_num = len(other_loss)
        second_loss = other_loss.sum()

    return (second_loss / second_num, )
def duel_forward(self, feature, package_valid_matrix):
    """Dueling head: combine a pooled value with mean-centred advantages.

    The validity mask is the diagonal of channel 0 of
    ``package_valid_matrix``. Masked-out positions get a large negative
    offset in the result.
    """
    feat = self.linear1(feature)
    hidden = self.GAT(self.multilinear1, feat, package_valid_matrix)
    hidden = hidden.transpose(0, 1)

    mask = torch.diagonal(package_valid_matrix[:, :, :, 0], dim1=1, dim2=2)
    padding = (1 - mask).bool()

    encoded = self.transformer_encoder(hidden, src_key_padding_mask=padding)
    adv_hidden = self.adv_encoder_layers(
        encoded, src_key_padding_mask=padding).transpose(1, 0)
    value_hidden = self.value_encoder_layers(
        encoded, src_key_padding_mask=padding).transpose(1, 0)

    adv = self.advlinear(adv_hidden).squeeze(-1)

    # Value stream: sum over valid positions, then project.  -> B * 1
    pooled = torch.sum(value_hidden * mask.unsqueeze(-1), dim=1)
    value = self.outputlayer(pooled)

    # Subtract the mean advantage over the valid positions only.
    valid_count = torch.sum(mask, dim=1)
    mean_adv = torch.sum(adv * mask, dim=1) / valid_count
    ret = (value - mean_adv.unsqueeze(1)) + adv

    # Push invalid positions far below any valid score.
    ret -= 9999999999 * (1 - mask)
    return ret
def forward(self, z1, z2, **kwargs) -> Tensor:
    """Contrastive loss over the similarity matrix, plus optional regularisers.

    Row ``i`` of ``z1`` is paired with row ``i`` of ``z2``. Similarities are
    dot products (cosines when ``self.norm`` is set) divided by ``self.tau``.
    Each regulariser is added only when its weight is positive.
    """
    # Unpacking enforces that z1 is 2-D.
    batch_size, _ = z1.size()

    similarity = torch.einsum('ik,jk->ij', z1, z2)
    if self.norm:
        norm_products = torch.einsum('i,j->ij', z1.norm(dim=1), z2.norm(dim=1))
        similarity = similarity / norm_products
    similarity = torch.exp(similarity / self.tau)

    # Share of each row's total that belongs to the positive pair.
    positive_share = torch.diagonal(similarity) / (similarity.sum(dim=1))
    loss = - torch.log(positive_share).mean()

    if self.variance_reg > 0:
        loss += self.variance_reg * (std_loss(z1) + std_loss(z2))
    if self.covariance_reg > 0:
        loss += self.covariance_reg * (cov_loss(z1) + cov_loss(z2))
    if self.uniformity_reg > 0:
        loss += self.uniformity_reg * uniformity_loss(z1, z2)
    return loss
def pdist(embeddings, squared=False):
    """Pairwise Euclidean distance matrix between the rows of ``embeddings``.

    Args:
        embeddings: 2-D tensor, one embedding per row.
        squared: return squared distances when true.

    Returns:
        Square tensor of non-negative distances; exact zeros stay zero.
    """
    gram = torch.mm(embeddings, torch.transpose(embeddings, 0, 1))
    sq_norm = torch.diagonal(gram, 0)

    # ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2
    distances = sq_norm.view(1, -1) - 2.0 * gram + sq_norm.view(-1, 1)

    # Rounding can leave tiny negative values; clamp them to zero.
    distances = torch.relu(distances)
    if squared:
        return distances

    # The gradient of sqrt is infinite at 0, so exact zeros are nudged by a
    # small epsilon before the root and set back to zero afterwards.
    zero_mask = torch.eq(distances, 0.0).float()
    distances = torch.sqrt(distances + zero_mask * 1e-16)
    return distances * (1.0 - zero_mask)
def zero_mean_covariance(covariance, stability=0.0):
    '''Output covariance of ReLU for zero-mean Gaussian input.

    f(x) = max(x, 0).

    Args:
        covariance: Input covariance matrix (Size, Size).
        stability: For accurate results this should be zero if used in
            training, use a value like 1e-4 for stability.

    Returns:
        Output covariance of ReLU for zero-mean Gaussian input (Size, Size).
    '''
    std = torch.sqrt(torch.diagonal(covariance, 0, -2, -1))
    scale = outer(std)

    # Normalise the covariance by the scale, clamped away from +/-1 by
    # `stability`.
    correlation = (covariance / scale).clamp_(stability - 1.0, 1.0 - stability)
    shape_term = (torch.acos(-correlation) * correlation
                  + torch.sqrt(1.0 - (correlation ** 2.0)) - 1.0)
    cov = scale * shape_term * (1.0 / (2.0 * math.pi))

    # handle degenerate case when we have zero variance:
    # NaN is the only value that differs from itself.
    cov[cov != cov] = 0
    return cov