def sigma_tf(self, t, X, Y):  # M x 1, M x D, M x 1
    M = self.M
    D = self.D
    return torch.diag_embed(torch.ones([M, D])).to(self.device)  # M x D x D
def covariance_matrix(self):
    covariance_matrix = (torch.matmul(self._unbroadcasted_cov_factor,
                                      self._unbroadcasted_cov_factor.transpose(-1, -2))
                         + torch.diag_embed(self._unbroadcasted_cov_diag))
    return covariance_matrix.expand(self._batch_shape + self._event_shape +
                                    self._event_shape)
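A minimal standalone sketch of the same low-rank-plus-diagonal construction; the factor `W`, diagonal `d`, and shapes below are illustrative, not taken from the class above.

import torch

# Hypothetical shapes: batch of 4 covariances, D = 3, rank K = 2.
W = torch.randn(4, 3, 2)      # plays the role of cov_factor
d = torch.rand(4, 3) + 0.1    # plays the role of cov_diag, kept positive

# Sigma = W W^T + diag(d), assembled exactly as in the method above.
cov = W @ W.transpose(-1, -2) + torch.diag_embed(d)

# Each Sigma is symmetric positive definite, so Cholesky succeeds.
torch.linalg.cholesky(cov)
print(cov.shape)  # torch.Size([4, 3, 3])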
def test_diagflat(dtype, k):
    backend = pytorch_backend.PyTorchBackend()
    array = backend.randn((16,), dtype=dtype, seed=10)
    actual = backend.diagflat(array, k=k)
    expected = torch.diag_embed(array, offset=k)
    np.testing.assert_allclose(expected, actual)
def _eval_covar_matrix(self):
    cf = self.covar_factor
    return cf @ cf.transpose(-1, -2) + torch.diag_embed(self.var)
def forward(self, input, derivative=0):
    """Compute the pairwise distance between the electrons or its derivative.

    When required, the derivative is computed wrt the first electron, i.e.

    .. math:: \\frac{dr_{ij}}{dx_i}

    which is different from

    .. math:: \\frac{dr_{ij}}{dx_j}

    Args:
        input (torch.tensor): positions of the electrons,
            size: Nbatch x [Nelec x Ndim]
        derivative (int, optional): degree of the derivative.
            Defaults to 0.

    Returns:
        torch.tensor: distance (or derivative) matrix,
            Nbatch x Nelec x Nelec if derivative = 0,
            Nbatch x Ndim x Nelec x Nelec if derivative = 1, 2
    """

    # get the distance matrices
    input_ = input.view(-1, self.nelec, self.ndim)
    dist = self._get_distance_quadratic(input_)

    # epsilon on the diagonal needed for backprop
    eps_ = self.eps * \
        torch.diag(dist.new_ones(dist.shape[-1])).expand_as(dist)

    # extract the diagonal, as it can be slightly negative
    # due to numerical noise
    diag = torch.diag_embed(torch.diagonal(dist, dim1=-1, dim2=-2))

    # remove the diagonal and add eps for backprop
    dist = torch.sqrt(dist - diag + eps_)

    if derivative == 0:
        return dist

    elif derivative == 1:
        eps_ = self.eps * \
            torch.diag(dist.new_ones(dist.shape[-1])).expand_as(dist)
        invr = (1. / (dist + eps_)).unsqueeze(1)
        diff_axis = input_.transpose(1, 2).unsqueeze(3)
        diff_axis = diff_axis - diff_axis.transpose(2, 3)
        return diff_axis * invr

    elif derivative == 2:
        eps_ = self.eps * \
            torch.diag(dist.new_ones(dist.shape[-1])).expand_as(dist)
        invr3 = (1. / (dist**3 + eps_)).unsqueeze(1)
        diff_axis = input_.transpose(1, 2).unsqueeze(3)
        diff_axis = (diff_axis - diff_axis.transpose(2, 3))**2
        diff_axis = diff_axis[:, [[1, 2], [2, 0], [0, 1]], ...].sum(2)
        return (diff_axis * invr3)
def backward(ctx, grads):
    # (T, T, N, H) -> (N, H, T)
    grads = torch.diagonal(grads, dim1=0, dim2=1)
    # (N, H, T) -> (N, H, T, T) -> (T, T, N, H)
    grads = torch.diag_embed(grads).permute(2, 3, 0, 1)
    return grads
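A quick standalone check of the `diagonal`/`diag_embed` round trip this backward relies on; the shapes below are made up for illustration.

import torch

T, N, H = 5, 2, 3
g = torch.randn(T, T, N, H)

# Keep only the (t, t) entries, as the backward above does.
diag = torch.diagonal(g, dim1=0, dim2=1)           # (N, H, T)
back = torch.diag_embed(diag).permute(2, 3, 0, 1)  # (T, T, N, H)

# Off-diagonal (t1 != t2) positions are zeroed, diagonal ones preserved.
assert torch.equal(torch.diagonal(back, dim1=0, dim2=1), diag)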
def glasso_predict(data, COLLECT=True):
    print('Running ADMM: augmented lagrangian')
    # predict as a complete batch
    criterion = nn.MSELoss()  # input, target
    m_sig = nn.Sigmoid()
    criterionBCE = nn.BCELoss()
    theta, S = data
    # theta -> K_train x N x N (matrix)
    # S     -> K_train x N x N (observed covariance)
    lambda_f = args.lambda_init
    rho_l1 = args.rho
    alpha_lr = args.alpha_lr
    theta_true = convert_to_torch(theta, TESTING_FLAG=True)
    S = convert_to_torch(S, TESTING_FLAG=True)

    if args.INIT_DIAG == 1:
        print(' extract batchwise diagonals, add offset and take inverse')
        batch_diags = 1 / (torch.diagonal(S, offset=0, dim1=-2, dim2=-1)
                           + args.theta_init_offset)
        theta_init = torch.diag_embed(batch_diags)
    else:
        print('***************** (S+theta_offset*I)^-1 is used')
        theta_init = torch.inverse(S + args.theta_init_offset *
                                   torch.eye(args.N).expand_as(S).type_as(S))

    zero = torch.Tensor([0])
    if USE_CUDA == True:
        zero = zero.cuda()

    # batch size is fixed for testing as 1
    num_batches = 1
    epoch_loss = []
    mse_binary_loss = []
    bce_loss = []
    frob_loss = []
    duality_gap = []
    ans = []

    theta_pred = theta_init
    Sb = S
    U = torch.zeros(Sb.shape).type(Sb.type())
    identity_mat = torch.eye(Sb.shape[-1]).expand_as(Sb)
    if USE_CUDA == True:
        identity_mat = identity_mat.cuda()

    res_conv = []
    obj_true = get_obj_val(theta_true, S, rho_l1)
    for k in range(args.L):
        start = time.time()
        if COLLECT:
            theta_pred_diag = torch.diag_embed(
                torch.diagonal(theta_pred, offset=0, dim1=-2, dim2=-1))
            theta_true_diag = torch.diag_embed(
                torch.diagonal(theta_true, offset=0, dim1=-2, dim2=-1))
            cv_loss, cv_loss_off_diag, obj_pred = get_convergence_loss(
                theta_pred, theta_true), get_convergence_loss(
                    theta_pred - theta_pred_diag,
                    theta_true - theta_true_diag), get_obj_val(
                        theta_pred, S, rho_l1)
            res_conv.append([cv_loss, obj_pred, obj_true, cv_loss_off_diag])

        # step 1: ADMM
        b = 1.0 / lambda_f * Sb - theta_pred + U
        b2_4ac = torch.matmul(b.transpose(-1, -2), b) + 4.0 / lambda_f * identity_mat
        sqrt_term = torch_sqrtm(b2_4ac)  # batch_matrix_sqrt(b2_4ac)
        theta_k1 = 1.0 / 2 * (-1 * b + sqrt_term)

        # step 2: ADMM (soft thresholding)
        theta_pred = torch.sign(theta_k1 + U) * torch.max(
            zero, torch.abs(theta_k1 + U) - rho_l1 / lambda_f)  # Z_k1

        # step 3: ADMM (dual update)
        U = U + theta_k1 - theta_pred

    print('fdr, tpr, fpr, shd, nnz, nnz/theta_pred.size, nnz_true, '
          'nnz_true/theta_true.size, ps, np.linalg.cond(theta_pred), '
          'np.linalg.cond(theta_true)')
    theta_pred = theta_pred.data.cpu().numpy()
    theta_true = theta_true.data.cpu().numpy()
    fdr, tpr, fpr, shd, nnz, nnz_true, ps = metrics.report_metrics(
        theta_true, theta_pred)
    cond_theta_pred, cond_theta_true = np.linalg.cond(
        theta_pred), np.linalg.cond(theta_true)
    print(fdr, tpr, fpr, shd, nnz, nnz / theta_pred.size, nnz_true,
          nnz_true / theta_true.size, ps, cond_theta_pred, cond_theta_true)
    return [
        fdr, tpr, fpr, shd, nnz, nnz_true, ps, cond_theta_pred, cond_theta_true
    ], res_conv
def gista_glasso(data, c=args.c, eps=1e-12, COLLECT=True):
    # G-ISTA for graphical lasso: predict a single precision matrix
    c = torch.Tensor([c])
    criterion = nn.MSELoss()  # input, target
    theta, S = data
    # theta -> K_train x N x N (matrix)
    # S     -> K_train x N x N (observed covariance)
    theta_true = convert_to_torch(theta, TESTING_FLAG=True)
    S = convert_to_torch(S, TESTING_FLAG=True)

    if args.INIT_DIAG == 1:
        print(' extract batchwise diagonals, add offset and take inverse')
        batch_diags = 1 / (torch.diagonal(S, offset=0, dim1=-2, dim2=-1)
                           + args.theta_init_offset)
        theta_init = torch.diag_embed(batch_diags)
    else:
        print('***************** (S+theta_offset*I)^-1 is used')
        theta_init = torch.inverse(S + args.theta_init_offset *
                                   torch.eye(args.N).expand_as(S).type_as(S))

    zero = torch.Tensor([0])
    if USE_CUDA == True:
        zero = zero.cuda()
        c = c.cuda()

    epoch_loss = []
    frob_loss = []
    duality_gap = []
    ans = []

    ll = torch.cholesky(theta_init)  # lower triangular factor
    theta_pred = torch.matmul(ll, ll.transpose(-1, -2))
    min_eig = torch.min(torch.eig(theta_pred)[0][:, 0])

    # all the inputs are ready for G-ISTA
    delta = get_duality_gap(theta_pred, S)  # initial duality gap
    step_size = min_eig**2

    epoch = 0
    res_conv = []
    obj_true = get_obj_val(theta_true, S)
    while (delta > eps and epoch < args.MAX_EPOCH):
        start = time.time()
        if COLLECT:
            theta_pred_diag = torch.diag_embed(
                torch.diagonal(theta_pred, offset=0, dim1=-2, dim2=-1))
            theta_true_diag = torch.diag_embed(
                torch.diagonal(theta_true, offset=0, dim1=-2, dim2=-1))
            cv_loss, cv_loss_off_diag, obj_pred = get_convergence_loss(
                theta_pred, theta_true), -1, get_obj_val(theta_pred, S)
            res_conv.append([cv_loss, obj_pred, obj_true, cv_loss_off_diag])

        theta_pred = torch.matmul(ll, ll.transpose(-1, -2))
        theta_prev = theta_pred.clone()

        # Steps 1 & 2: line search and update theta
        diff_term = S - torch.inverse(theta_pred)
        update_flag = 0
        for j in np.arange(1, 10):
            cj = c**j
            next_theta = eta(theta_pred - cj * step_size * diff_term,
                             cj * step_size)
            if check_conditions(next_theta, theta_prev, S, cj * step_size) == 1:
                # conditions satisfied
                theta_pred = next_theta
                update_flag = 1
                break
        if update_flag == 0:
            print('**********changing the step size to min eigval')
            min_eig = torch.min(torch.eig(theta_pred)[0][:, 0])
            step_size = min_eig**2
            theta_pred = eta(theta_pred - step_size * diff_term, step_size)

        # Step 3: set the next step size
        ll = torch.cholesky(theta_pred)
        step_size = get_step_size(theta_pred, theta_prev)

        # Step 4: calculate the duality gap
        delta = get_duality_gap(theta_pred, S)
        epoch += 1

    theta_pred = theta_pred.data.cpu().numpy()
    theta_true = theta_true.data.cpu().numpy()
    fdr, tpr, fpr, shd, nnz, nnz_true, ps = metrics.report_metrics(
        theta_true, theta_pred)
    cond_theta_pred, cond_theta_true = np.linalg.cond(
        theta_pred), np.linalg.cond(theta_true)
    print('Accuracy metrics: fdr ', fdr, ' tpr ', tpr, ' fpr ', fpr, ' shd ',
          shd, ' nnz ', nnz, ' nnz_true ', nnz_true, ' sign_match ', ps,
          ' pred_cond ', cond_theta_pred, ' true_cond ', cond_theta_true)
    return [
        fdr, tpr, fpr, shd, nnz, nnz_true, ps, cond_theta_pred, cond_theta_true
    ], res_conv
def forward(self, xf, debug=False, manual_debug=False):  # gtvforward
    s = self.weight_sigma
    if self.opt.legacy:
        u = self.cnnu.forward(xf)
        u = u.unsqueeze(1).unsqueeze(1)
    else:
        u = self.uu.forward()
    u_max = self.opt.u_max
    u_min = self.opt.u_min
    if debug:
        self.u = u.clone()
    u = torch.clamp(u, u_min, u_max)

    z = self.opt.H.matmul(
        xf.view(xf.shape[0], xf.shape[1], self.opt.width**2, 1))

    E = self.cnnf.forward(xf)
    Fs = (self.opt.H.matmul(
        E.view(E.shape[0], E.shape[1], self.opt.width**2, 1))**2)
    w = torch.exp(-(Fs.sum(axis=1)) / (s**2))
    if debug:
        msg = (f"Sample WEIGHT SUM: {w[0, :, :].sum().item():.4f} || "
               f"Mean Processed u: {u.mean().item():.4f}")
        self.logger.info(msg)
    w = w.unsqueeze(1).repeat(1, self.opt.channels, 1, 1)

    W = self.base_W.clone()
    Z = W.clone()
    W[:, :, self.opt.connectivity_idx[0],
      self.opt.connectivity_idx[1]] = w.view(xf.shape[0], 3, -1)
    W[:, :, self.opt.connectivity_idx[1],
      self.opt.connectivity_idx[0]] = w.view(xf.shape[0], 3, -1)
    Z[:, :, self.opt.connectivity_idx[0],
      self.opt.connectivity_idx[1]] = torch.abs(z.view(xf.shape[0], 3, -1))
    Z[:, :, self.opt.connectivity_idx[1],
      self.opt.connectivity_idx[0]] = torch.abs(z.view(xf.shape[0], 3, -1))
    Z = torch.max(Z, self.support_zmax)

    # graph Laplacian L = diag(W 1) - W, with the weights rescaled by |z|
    L = W / Z
    L1 = L @ self.support_L
    L = torch.diag_embed(L1.squeeze(-1)) - L

    y = xf.view(xf.shape[0], self.opt.channels, -1, 1)
    xhat = self.qpsolve(L, u, y, self.support_identity, self.opt.channels)

    # GLR 2
    def glr(y, w, u, debug=False, return_dict=None):
        W = self.base_W.clone()
        z = self.opt.H.matmul(y)
        Z = W.clone()
        W[:, :, self.opt.connectivity_idx[0],
          self.opt.connectivity_idx[1]] = w.view(xf.shape[0], 3, -1)
        W[:, :, self.opt.connectivity_idx[1],
          self.opt.connectivity_idx[0]] = w.view(xf.shape[0], 3, -1)
        Z[:, :, self.opt.connectivity_idx[0],
          self.opt.connectivity_idx[1]] = torch.abs(z.view(xf.shape[0], 3, -1))
        Z[:, :, self.opt.connectivity_idx[1],
          self.opt.connectivity_idx[0]] = torch.abs(z.view(xf.shape[0], 3, -1))
        Z = torch.max(Z, self.support_zmax)
        L = W / Z
        L1 = L @ self.support_L
        L = torch.diag_embed(L1.squeeze(-1)) - L
        xhat = self.qpsolve(L, u, y, self.support_identity, self.opt.channels)
        return xhat

    # eight graph-Laplacian regularization passes
    for _ in range(8):
        xhat = glr(xhat, w, u)

    return xhat.view(xhat.shape[0], self.opt.channels, self.opt.width,
                     self.opt.width)
def test_GPyTorchPosterior(self):
    for dtype in (torch.float, torch.double):
        n = 3
        mean = torch.rand(n, dtype=dtype, device=self.device)
        variance = 1 + torch.rand(n, dtype=dtype, device=self.device)
        covar = variance.diag()
        mvn = MultivariateNormal(mean, lazify(covar))
        posterior = GPyTorchPosterior(mvn=mvn)
        # basics
        self.assertEqual(posterior.device.type, self.device.type)
        self.assertTrue(posterior.dtype == dtype)
        self.assertEqual(posterior.event_shape, torch.Size([n, 1]))
        self.assertTrue(torch.equal(posterior.mean, mean.unsqueeze(-1)))
        self.assertTrue(torch.equal(posterior.variance, variance.unsqueeze(-1)))
        # rsample
        samples = posterior.rsample()
        self.assertEqual(samples.shape, torch.Size([1, n, 1]))
        for sample_shape in ([4], [4, 2]):
            samples = posterior.rsample(sample_shape=torch.Size(sample_shape))
            self.assertEqual(samples.shape, torch.Size(sample_shape + [n, 1]))
        # check enabling of approximate root decomposition
        with ExitStack() as es:
            mock_func = es.enter_context(
                mock.patch(
                    ROOT_DECOMP_PATH, return_value=torch.linalg.cholesky(covar)
                )
            )
            es.enter_context(gpt_settings.max_cholesky_size(0))
            es.enter_context(
                gpt_settings.fast_computations(covar_root_decomposition=True)
            )
            # need to clear cache, cannot re-use previous objects
            mvn = MultivariateNormal(mean, lazify(covar))
            posterior = GPyTorchPosterior(mvn=mvn)
            posterior.rsample(sample_shape=torch.Size([4]))
            mock_func.assert_called_once()
        # rsample w/ base samples
        base_samples = torch.randn(4, 3, 1, device=self.device, dtype=dtype)
        # incompatible shapes
        with self.assertRaises(RuntimeError):
            posterior.rsample(
                sample_shape=torch.Size([3]), base_samples=base_samples
            )
        # ensure consistent result
        for sample_shape in ([4], [4, 2]):
            base_samples = torch.randn(
                *sample_shape, 3, 1, device=self.device, dtype=dtype
            )
            samples = [
                posterior.rsample(
                    sample_shape=torch.Size(sample_shape), base_samples=base_samples
                )
                for _ in range(2)
            ]
            self.assertTrue(torch.allclose(*samples))
        # collapse_batch_dims
        b_mean = torch.rand(2, 3, dtype=dtype, device=self.device)
        b_variance = 1 + torch.rand(2, 3, dtype=dtype, device=self.device)
        b_covar = torch.diag_embed(b_variance)
        b_mvn = MultivariateNormal(b_mean, lazify(b_covar))
        b_posterior = GPyTorchPosterior(mvn=b_mvn)
        b_base_samples = torch.randn(4, 1, 3, 1, device=self.device, dtype=dtype)
        b_samples = b_posterior.rsample(
            sample_shape=torch.Size([4]), base_samples=b_base_samples
        )
        self.assertEqual(b_samples.shape, torch.Size([4, 2, 3, 1]))
def lazy_covariance_matrix(self):
    """Get lazy covariance matrix."""
    return CholLazyTensor(torch.diag_embed(self.variance))
def _make_gpytorch_posterior(self, shape, dtype):
    mean = torch.rand(*shape, dtype=dtype, device=self.device)
    variance = 1 + torch.rand(*shape, dtype=dtype, device=self.device)
    covar = torch.diag_embed(variance)
    mvn = MultivariateNormal(mean, lazify(covar))
    return GPyTorchPosterior(mvn=mvn)
def predict(self):
    dataset = self.dataset["test"]
    dataloader = dataset.batch_delivery(self.cfg.batch_size)
    self.model.eval()

    # Slot-one-hot distribution.
    # [slot_num, vocab_len]
    slot_word_dist = F.log_softmax(
        torch.FloatTensor(self.model.get_unnormalized_phi()), dim=-1)
    assert torch.isnan(slot_word_dist).sum().item() == 0

    # Slot-ce distribution.
    # [slot_num, feature_dim]
    slot_mean_dist = torch.FloatTensor(self.model.get_beta_mean())
    # [slot_num, emb_dim]
    slot_stdvar_dist = torch.FloatTensor(
        self.model.get_beta_logvar()).exp().sqrt()

    if self.cfg.use_gpu:
        slot_word_dist = slot_word_dist.cuda()
        slot_mean_dist = slot_mean_dist.cuda()
        slot_stdvar_dist = slot_stdvar_dist.cuda()

    slot_emb_dist = [
        MultivariateNormal(
            loc=slot_mean_dist[k],
            covariance_matrix=torch.diag_embed(slot_stdvar_dist[k]))
        for k in range(self.cfg.slot_num)
    ]

    predictions = []
    with torch.no_grad():
        pbar = tqdm(dataloader, desc="Validating progress")
        for features, candidates in pbar:
            oh, ce, mask, lens, candis = dataset.add_padding(
                features, candidates)
            if self.cfg.use_gpu:
                oh = oh.cuda()
                ce = ce.cuda()
                mask = mask.cuda()
            domains, slots = self.model(oh, ce, mask, compute_loss=False)

            # [batch_size, padded_candi_len, slot_num]
            padded_logps = torch.stack(
                [slots[:, k].unsqueeze(-1) + slot_emb_dist[k].log_prob(ce) +
                 slot_word_dist[k][candis]
                 for k in range(self.cfg.slot_num)], dim=-1).cpu()
            assert torch.isnan(padded_logps).sum().item() == 0

            # iterate over the batch
            for i in range(oh.shape[0]):
                domain = domains[i]
                true_len = lens[i]
                # [candi_len]
                true_candis = candis[i, :true_len]
                # [candi_len, slot_num]
                logps = padded_logps[i, :true_len]
                probs = torch.softmax(logps, dim=-1)
                prediction = [{"domain": domain.item(),
                               "slot": torch.argmax(prob).item(),
                               "prob": prob.data.numpy(),
                               "word": self.vocab.itos[candi_index],
                               } for candi_index, prob in zip(true_candis, probs)]
                predictions.append(prediction)
            pbar.update(1)
    return predictions
def subtract_diag(self, gram):
    # note: the in-place subtraction also modifies the caller's tensor
    diag_elements = torch.diag_embed(torch.diagonal(gram, 0))
    gram -= diag_elements
    return gram
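A small usage sketch, assuming only that `gram` is a square Gram matrix; it shows that `diag_embed(diagonal(.))` isolates exactly the diagonal, so the subtraction zeroes it out.

import torch

# Hypothetical Gram matrix of 4 vectors.
x = torch.randn(4, 6)
gram = x @ x.t()

# Zero out the diagonal while leaving the off-diagonal entries untouched.
gram_no_diag = gram - torch.diag_embed(torch.diagonal(gram, 0))
assert torch.allclose(torch.diagonal(gram_no_diag), torch.zeros(4))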
def gmm_loss(log_market_shares, prod_char, market_masks, own_mat,
             model_ids_to_rows, beta_mat, xw_mat, z_mat, mdraws, mweights,
             theta, gmm_weights, delta_0=None, full_output=False,
             atol=1e-06, rtol=1e-06):
    delta, __ = batch_invert_shares(log_market_shares, prod_char, market_masks,
                                    mdraws, mweights, theta, delta_0=delta_0,
                                    atol=atol, rtol=rtol)
    s, cs, ws = batch_shares(delta, prod_char, market_masks, mdraws, mweights,
                             theta, full_output=True)
    ww = ws * mdraws[:, 0:1, :] * theta[0]
    Jp = torch.diag_embed(ww.sum(dim=2)) - torch.bmm(ww, torch.transpose(cs, 1, 2))
    OJp = Jp * own_mat + torch.diag_embed(1 - market_masks)
    eta, __ = torch.solve(-s[:, :, None], OJp)
    eta = eta.squeeze()
    costs = prod_char[:, :, 0] - eta
    mc = torch.masked_select(costs, market_masks.bool())
    mc[(mc < 0).detach()] = 0.001
    log_mc = torch.log(mc)
    y = torch.cat((torch.masked_select(delta, market_masks.bool()), log_mc), 0)
    beta = beta_mat @ y
    res = y - xw_mat @ beta
    g_hat = z_mat * res[:, None]
    g_hat_agg = model_ids_to_rows @ g_hat
    g_hat_mean = g_hat_agg.mean(dim=0)
    loss = g_hat_mean.t() @ gmm_weights @ g_hat_mean
    if full_output:
        with torch.no_grad():
            fact = 1.0 / (g_hat_agg.size(0) - 1)
            cov_mat = g_hat_agg - g_hat_mean
            cov_mat = fact * cov_mat.t().matmul(cov_mat)
            cov_mat = torch.inverse(cov_mat)
        return loss, delta, beta, cov_mat
    else:
        return loss, delta, beta
def fast_gaussian_multi_integral(
        mu0: torch.Tensor,
        mu1: torch.Tensor,
        var0: torch.Tensor,
        var1: torch.Tensor,
        forward: bool = True,
        need_zeta: bool = True,
        diag1: bool = False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    dim0 = mu0.size()[-1]
    dim1 = mu1.size()[-1]
    assert dim0 > dim1

    if forward:
        mu_pad = (0, dim1)
        var_pad = (0, dim1, 0, dim1)
    else:
        mu_pad = (dim1, 0)
        var_pad = (dim1, 0, dim1, 0)

    lambda0 = torch.inverse(var0)
    if diag1:
        lambda1 = torch.diag_embed(1.0 / var1).float()
    else:
        lambda1 = torch.inverse(var1)

    # use new variables to avoid in-place operations
    mu0_new = mu0.unsqueeze(-1)
    mu1_new = mu1.unsqueeze(-1)
    eta0 = torch.matmul(lambda0, mu0_new).squeeze(-1)
    eta1 = torch.matmul(lambda1, mu1_new).squeeze(-1)

    eta1_pad = F.pad(eta1, mu_pad, 'constant', 0)
    lambda1_pad = F.pad(lambda1, var_pad, 'constant', 0)

    lambda_new = lambda0 + lambda1_pad
    eta_new = eta0 + eta1_pad

    sigma_new = torch.inverse(lambda_new)
    mu_new = torch.matmul(sigma_new, eta_new.unsqueeze(-1)).squeeze(-1)

    if need_zeta:
        zeta0 = calculate_zeta(eta0, lambda0, mu=mu0)
        zeta1 = calculate_zeta(eta1, lambda1, mu=mu1)
        zeta_new = calculate_zeta(eta_new, lambda_new, sig=sigma_new)
        scale = zeta0 + zeta1 - zeta_new
    else:
        scale = None

    select = 1 if forward else 0
    res_mu = torch.split(mu_new, split_size_or_sections=[dim1, dim1],
                         dim=-1)[select]
    res_sigma = torch.split(
        torch.split(sigma_new, split_size_or_sections=[dim1, dim1],
                    dim=-2)[select],
        split_size_or_sections=[dim1, dim1], dim=-1)[select]
    return scale, res_mu, res_sigma
def test_symeig(self):
    dtypes = {"double": torch.double, "float": torch.float}
    for name, dtype in dtypes.items():
        tolerances = self.tolerances["symeig"][name]

        lazy_tensor = self.create_lazy_tensor().detach().requires_grad_(True)
        lazy_tensor_copy = lazy_tensor.clone().detach().requires_grad_(True)
        evaluated = self.evaluate_lazy_tensor(lazy_tensor_copy)

        # Perform forward pass
        with linalg_dtypes(dtype):
            evals_unsorted, evecs_unsorted = lazy_tensor.symeig(eigenvectors=True)
            evecs_unsorted = evecs_unsorted.evaluate()

        # since LazyTensor.symeig does not sort evals, we do this here for the check
        evals, idxr = torch.sort(evals_unsorted, dim=-1, descending=False)
        evecs = torch.gather(
            evecs_unsorted,
            dim=-1,
            index=idxr.unsqueeze(-2).expand(evecs_unsorted.shape),
        )

        evals_actual, evecs_actual = torch.linalg.eigh(evaluated.type(dtype))
        evals_actual = evals_actual.to(dtype=evaluated.dtype)
        evecs_actual = evecs_actual.to(dtype=evaluated.dtype)

        # Check forward pass
        self.assertAllClose(evals, evals_actual, **tolerances)
        lt_from_eigendecomp = evecs @ torch.diag_embed(evals) @ evecs.transpose(-1, -2)
        self.assertAllClose(lt_from_eigendecomp, evaluated, **tolerances)

        # if there are repeated evals, we'll skip checking the eigenvectors for those
        any_evals_repeated = False
        evecs_abs, evecs_actual_abs = evecs.abs(), evecs_actual.abs()
        for idx in itertools.product(*[range(b) for b in evals_actual.shape[:-1]]):
            eval_i = evals_actual[idx]
            if torch.unique(eval_i.detach()).shape[-1] == eval_i.shape[-1]:
                # detach to avoid pytorch/pytorch#41389
                self.assertAllClose(evecs_abs[idx], evecs_actual_abs[idx], **tolerances)
            else:
                any_evals_repeated = True

        # Perform backward pass
        symeig_grad = torch.randn_like(evals)
        ((evals * symeig_grad).sum()).backward()
        ((evals_actual * symeig_grad).sum()).backward()

        # Check grads if there were no repeated evals
        if not any_evals_repeated:
            for arg, arg_copy in zip(lazy_tensor.representation(),
                                     lazy_tensor_copy.representation()):
                if arg_copy.requires_grad and arg_copy.is_leaf and arg_copy.grad is not None:
                    self.assertAllClose(arg.grad, arg_copy.grad, **tolerances)

        # Test with eigenvectors=False
        _, evecs = lazy_tensor.symeig(eigenvectors=False)
        self.assertIsNone(evecs)
def normalize_adj(A):
    # symmetric normalization of a batch of adjacency matrices:
    # D_in^{-1/2} A D_out^{-1/2}
    D_in = torch.diag_embed(1.0 / torch.sqrt(A.sum(dim=1)))
    D_out = torch.diag_embed(1.0 / torch.sqrt(A.sum(dim=2)))
    DA = stacked_spmm(D_in, A)  # swap D_in and D_out
    DAD = stacked_spmm(DA, D_out)
    return DAD
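A hedged usage sketch of the normalization above; `stacked_spmm` is replaced here by a dense `torch.bmm` stand-in, and the adjacency values are illustrative (self-loops keep the degree sums strictly positive).

import torch

def stacked_spmm(a, b):
    # dense stand-in for the batched (sparse) matmul used above
    return torch.bmm(a, b)

# Batch of 2 adjacency matrices over 3 nodes.
A = torch.tensor([[[1., 1., 0.],
                   [1., 1., 1.],
                   [0., 1., 1.]]]).repeat(2, 1, 1)

D_in = torch.diag_embed(1.0 / torch.sqrt(A.sum(dim=1)))
D_out = torch.diag_embed(1.0 / torch.sqrt(A.sum(dim=2)))
DAD = stacked_spmm(stacked_spmm(D_in, A), D_out)
print(DAD.shape)  # torch.Size([2, 3, 3])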
def axe_loss(logits: torch.FloatTensor,
             logit_lengths: torch.Tensor,
             targets: torch.LongTensor,
             target_lengths: torch.Tensor,
             blank_index: torch.LongTensor,
             delta: torch.FloatTensor,
             reduction: str = 'mean',
             label_smoothing: float = None,
             return_a: bool = False
             ) -> Union[torch.FloatTensor, List[torch.Tensor]]:
    """Aligned Cross Entropy

    Marjan Ghazvininejad, Vladimir Karpukhin, Luke Zettlemoyer, Omer Levy,
    arXiv 2020, https://arxiv.org/abs/2004.01655

    Computes the aligned cross entropy loss with a parallel (anti-diagonal) scheme.

    Parameters
    ----------
    logits : ``torch.FloatTensor``, required.
        A tensor of size (batch_size, sequence_length, num_classes) containing
        the unnormalized scores for each class.
    logit_lengths : ``torch.Tensor``, required.
        A tensor of size (batch_size,) containing the lengths of the logits.
    targets : ``torch.LongTensor``, required.
        A tensor of size (batch_size, sequence_length) containing the index of
        the true class for each corresponding step.
    target_lengths : ``torch.Tensor``, required.
        A tensor of size (batch_size,) containing the lengths of the targets.
    blank_index : ``torch.LongTensor``, required.
        The index of the special blank token.
    delta : ``torch.FloatTensor``, required.
        Penalty for the skip-target operation.
    reduction : ``str``, optional.
        Specifies the reduction to apply to the output. Default "mean".
    label_smoothing : ``float``, optional.
        Whether or not to apply label smoothing.
    return_a : ``bool``, optional.
        Whether to return the matrix of conditional axe values. Default is False.
    """
    assert targets.size(0) == logits.size(0), \
        f'Inconsistency of batch size, {targets.size(0)} of targets and {logits.size(0)} of logits.'

    batch_size, logits_sequence_length, num_class = logits.shape
    _, target_sequence_length = targets.shape
    device = logits.device

    # for torch.gather
    targets = targets.unsqueeze(-1)  # batch_size, target_sequence_length, 1

    # (batch_size, target_sequence_length + 1, logits_sequence_length + 1)
    batch_A = torch.zeros(targets.size(0), targets.size(1) + 1,
                          logits.size(1) + 1).to(device)
    batch_blank_index = torch.full((logits.size(0), 1), blank_index,
                                   dtype=torch.long).to(device)

    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)

    # A_{i,0} = A_{i-1,0} - delta * log P_1(Y_i)
    for i in range(1, targets.size(1) + 1):
        # batch_A[:, i, 0] comes from targets[:, i-1, :], since batch_A has an extra 0-th row
        batch_A[:, i, 0] = batch_A[:, i-1, 0] - delta * torch.gather(
            log_probs[:, 0, :], dim=1, index=targets[:, i-1, :]).squeeze(-1)

    # A_{0,j} = A_{0,j-1} - log P_j("BLANK")
    for j in range(1, logits.size(1) + 1):
        # batch_A[:, 0, j] comes from log_probs[:, j-1, :], since batch_A has an extra 0-th column
        batch_A[:, 0, j] = batch_A[:, 0, j-1] - delta * torch.gather(
            log_probs[:, j-1, :], dim=1, index=batch_blank_index).squeeze(-1)

    # flip the logit dim so anti-diagonals can be read off with torch.diagonal
    batch_A_flip = batch_A.flip(-1)      # (batch_size, target_sequence_length + 1, logits_sequence_length + 1)
    log_probs_flip = log_probs.flip(-2)  # (batch_size, sequence_length, num_classes)

    # index maps used to locate the regions corresponding to each anti-diagonal
    map_logits = torch.arange(logits.size(1)) - torch.zeros(targets.size(1), 1)
    map_targets = torch.arange(targets.size(1)).unsqueeze(-1) - torch.zeros((1, logits.size(1)))
    # indices must be int
    map_logits = map_logits.long().to(device)
    map_targets = map_targets.long().to(device)

    for i in range(logits.size(1) - 1, -targets.size(1), -1):
        # batch_A_flip_sets[:, :, :, 0] : batch_A[:, i,   j-1]
        # batch_A_flip_sets[:, :, :, 1] : batch_A[:, i-1, j  ]
        # batch_A_flip_sets[:, :, :, 2] : batch_A[:, i-1, j-1]
        batch_A_flip_sets = torch.cat(
            (batch_A_flip.roll(shifts=-1, dims=-1).unsqueeze(-1),
             batch_A_flip.roll(shifts=1, dims=-2).unsqueeze(-1),
             batch_A_flip.roll(shifts=(1, -1), dims=(-2, -1)).unsqueeze(-1)),
            dim=-1)

        # trimming
        # - the last column (A_{0,j} = A_{0,j-1} - log P_j("BLANK"))
        # - the first row   (A_{i,0} = A_{i-1,0} - delta * log P_1(Y_i))
        batch_A_flip_sets_trim = batch_A_flip_sets[:, 1:, :-1, :]

        # extract the anti-diagonal part
        # (batch, 3, num_diag)
        A_diag = batch_A_flip_sets_trim.diagonal(offset=i, dim1=-3, dim2=-2)
        # (batch, num_diag, 3)
        A_diag = A_diag.transpose(-1, -2)
        num_diag = A_diag.size(1)

        logit_indices = map_logits.diagonal(offset=i, dim1=-2, dim2=-1)
        # log_probs_flip_diag : (batch, num_diag, num_class)
        log_probs_flip_diag = log_probs_flip[:, logit_indices[0]:logit_indices[-1]+1, :]

        target_indices = map_targets.diagonal(offset=i, dim1=-2, dim2=-1)
        # targets_diag : (batch, num_diag, 1)
        targets_diag = targets[:, target_indices[0]:target_indices[-1]+1, :]

        # align, skip_prediction, skip_target
        batch_align = A_diag[:, :, 2] - torch.gather(
            log_probs_flip_diag, dim=2, index=targets_diag).squeeze(-1)
        batch_skip_prediction = A_diag[:, :, 0] - torch.gather(
            log_probs_flip_diag, dim=2,
            index=batch_blank_index.expand(-1, num_diag).unsqueeze(-1)).squeeze(-1)
        batch_skip_target = A_diag[:, :, 1] - delta * torch.gather(
            log_probs_flip_diag, dim=2, index=targets_diag).squeeze(-1)

        # (batch_size, num_diag, 3)
        operations = torch.cat((batch_align.unsqueeze(-1),
                                batch_skip_prediction.unsqueeze(-1),
                                batch_skip_target.unsqueeze(-1)), dim=-1)

        # (batch_size, num_diag)
        diag_axe = torch.min(operations, dim=-1).values

        assert logits.size(1) > targets.size(1), "assuming target length < logit length."
        if i > (logits.size(1) - targets.size(1)):
            # (batch_size, logits_length, logits_length)
            # -> (batch_size, targets_length, logits_length)
            axe = torch.diag_embed(diag_axe, offset=i, dim1=-2, dim2=-1)
            batch_A_flip[:, 1:, :-1] += axe[:, :targets.size(1), :]
        elif i > 0:
            # (batch_size, logits_length, logits_length)
            # -> (batch_size, targets_length, logits_length)
            axe = torch.diag_embed(diag_axe, offset=0, dim1=-2, dim2=-1)
            batch_A_flip[:, 1:, i:i + targets.size(1)] += axe
        else:
            axe = torch.diag_embed(diag_axe, offset=i, dim1=-2, dim2=-1)
            batch_A_flip[:, 1:, :targets.size(1)] += axe

    # recover the correct order in the logit dim
    batch_A = batch_A_flip.flip(-1)

    # remove the 0-th row and column
    _batch_A = batch_A[:, 1:, 1:]

    # Gather A_nm, avoiding masks
    # index_m : (batch_size, target_sequence_length, 1)
    index_m = logit_lengths.unsqueeze(-1).expand(-1, _batch_A.size(1)).unsqueeze(-1).long()
    # gather the m-th column (index_m - 1 to avoid out-of-bounds indexing)
    # batch_A_m : (batch_size, target_sequence_length)
    batch_A_m = torch.gather(_batch_A, dim=2, index=(index_m - 1))
    batch_A_m = batch_A_m.squeeze(-1)

    # index_n : (batch_size, 1)
    index_n = target_lengths.unsqueeze(-1).long()
    # gather the n-th row
    # batch_A_nm : (batch_size, 1)
    batch_A_nm = torch.gather(batch_A_m, dim=1, index=(index_n - 1))
    # batch_A_nm : (batch_size,)
    batch_A_nm = batch_A_nm.squeeze(-1)

    if reduction == "mean":
        axe_nm = batch_A_nm.mean()
    else:
        raise NotImplementedError

    # Refs fairseq nat_loss (actually, I'm not sure this is reasonable):
    # https://github.com/pytorch/fairseq/blob/6f6461b81ac457b381669ebc8ea2d80ea798e53a/fairseq/criterions/nat_loss.py#L70
    if label_smoothing is not None and label_smoothing > 0.0:
        axe_nm = axe_nm * (1.0 - label_smoothing) - log_probs.mean() * label_smoothing

    if return_a:
        return axe_nm, batch_A.detach()

    return axe_nm
def dynam_prediction(self, kp, graph, action=None, eps=5e-2, env=None,
                     node_params=None):
    # kp: B x n_his x n_kp x (mean + covariance)
    # action: B x n_his x n_kp x action_dim
    args = self.args
    nf = args.nf_hidden_dy * 4
    action_dim = args.action_dim
    node_attr_dim = args.node_attr_dim
    edge_attr_dim = args.edge_attr_dim
    edge_type_num = args.edge_type_num

    B, n_his, n_kp, _ = kp.size()

    # node_attr: B x n_kp x node_attr_dim
    # edge_attr: B x n_kp x n_kp x edge_attr_dim
    # edge_type: B x n_kp x n_kp x edge_type_num
    # edge_type_logits: B x n_kp x n_kp x edge_type_num
    node_attr, edge_attr, edge_type, edge_type_logits = graph

    # node_enc: B x n_his x n_kp x nf
    # edge_enc: B x n_his x (n_kp * n_kp) x nf
    node_enc = torch.cat([
        kp,
        node_attr.view(B, 1, n_kp, node_attr_dim).repeat(1, n_his, 1, 1),
        node_params.view(B, 1, n_kp, args.node_params).repeat(1, n_his, 1, 1)
    ], 3)

    edge_enc = torch.cat([
        torch.cat([
            kp[:, :, :, None, :].repeat(1, 1, 1, n_kp, 1),
            kp[:, :, None, :, :].repeat(1, 1, n_kp, 1, 1)
        ], 4),
        edge_attr.view(B, 1, n_kp, n_kp, edge_attr_dim).repeat(1, n_his, 1, 1, 1)
    ], 4)

    node_enc, edge_enc = self.model_dynam_encode(
        node_enc.view(
            B * n_his, n_kp,
            node_attr_dim + (args.state_dim + args.state_dim**2) + args.node_params),
        edge_enc.view(
            B * n_his, n_kp, n_kp,
            edge_attr_dim + 2 * (args.state_dim + args.state_dim**2)),
        edge_type[:, None, :, :, :].repeat(1, n_his, 1, 1, 1).view(
            B * n_his, n_kp, n_kp, edge_type_num),
        start_idx=args.edge_st_idx)

    node_enc = node_enc.view(B, n_his, n_kp, nf)
    edge_enc = edge_enc.view(B, n_his, n_kp * n_kp, nf)

    # node_enc: B x n_kp x n_his x nf
    # edge_enc: B x (n_kp * n_kp) x n_his x nf
    node_enc = node_enc.transpose(1, 2).contiguous().view(B, n_kp, n_his, nf)
    edge_enc = edge_enc.transpose(1, 2).contiguous().view(B, n_kp * n_kp, n_his, nf)

    # node_enc: B x n_kp x n_his x (nf + node_attr_dim + action_dim)
    # kp_node: B x n_kp x n_his x (state_dim + state_dim**2)
    kp_node = kp.transpose(1, 2).contiguous().view(
        B, n_kp, n_his, (args.state_dim + args.state_dim**2))
    node_enc = torch.cat([
        kp_node,
        node_enc,
        node_attr.view(B, n_kp, 1, node_attr_dim).repeat(1, 1, n_his, 1)
    ], 3)

    # edge_enc: B x (n_kp * n_kp) x n_his x (nf + edge_attr_dim + action_dim)
    kp_edge = torch.cat([
        kp_node[:, :, None, :, :].repeat(1, 1, n_kp, 1, 1),
        kp_node[:, None, :, :, :].repeat(1, n_kp, 1, 1, 1)
    ], 4)
    kp_edge = kp_edge.view(B, n_kp**2, n_his,
                           2 * (args.state_dim + args.state_dim**2))

    edge_enc = torch.cat([
        kp_edge,
        edge_enc,
        edge_attr.view(B, n_kp**2, 1, edge_attr_dim).repeat(1, 1, n_his, 1)
    ], 3)

    # append action
    if action is not None:
        action = action[:, :, :, None]
        action_t = action.transpose(1, 2).contiguous()
        action_t_r = action_t[:, :, None, :, :].repeat(1, 1, n_kp, 1, 1).view(
            B, n_kp**2, n_his, action_dim)
        action_t_s = action_t[:, None, :, :, :].repeat(1, n_kp, 1, 1, 1).view(
            B, n_kp**2, n_his, action_dim)
        node_enc = torch.cat([node_enc, action_t], 3)
        edge_enc = torch.cat([edge_enc, action_t_r, action_t_s], 3)

    # node_enc: B x n_kp x nf
    # edge_enc: B x n_kp x n_kp x nf
    node_enc = self.model_dynam_node_forward(
        node_enc.view(B * n_kp, n_his, -1)).view(B, n_kp, nf)
    edge_enc = self.model_dynam_edge_forward(
        edge_enc.view(B * n_kp**2, n_his, -1)).view(B, n_kp, n_kp, nf)

    node_enc = torch.cat([node_enc, node_attr, kp_node[:, :, -1]], 2)
    edge_enc = torch.cat([
        edge_enc,
        edge_attr,
        kp_edge[:, :, -1].view(B, n_kp, n_kp,
                               2 * (args.state_dim + args.state_dim**2))
    ], 3)

    if action is not None:
        action_r = action[:, :, :, None, :].repeat(1, 1, 1, n_kp, 1)
        action_s = action[:, :, None, :, :].repeat(1, 1, n_kp, 1, 1)
        node_enc = torch.cat([node_enc, action[:, -1]], 2)
        edge_enc = torch.cat([edge_enc, action_r[:, -1], action_s[:, -1]], 3)

    kp_pred = self.model_dynam_decode(node_enc, edge_enc, edge_type,
                                      start_idx=args.edge_st_idx,
                                      ignore_edge=True)

    # kp_pred: B x n_kp x (mean + covariance), predicting the change in state
    kp_pred = torch.cat(
        [
            kp[:, -1, :, :args.state_dim] + kp_pred[:, :, :args.state_dim],  # mean
            torch.diag_embed(
                F.relu(kp_pred[:, :, args.state_dim:]) + args.gauss_std).view(
                    B, n_kp, args.state_dim * args.state_dim)  # diagonal covariance, kept > 0
        ],
        dim=2)

    return kp_pred
def __init__(self, dim: int, num_operations: int):
    super().__init__(dim, num_operations)
    # one identity matrix per operation: num_operations x dim x dim
    self.linear_transformations = nn.Parameter(
        torch.diag_embed(torch.ones(()).expand(num_operations, dim)))
    self.translations = nn.Parameter(torch.zeros((self.num_operations, self.dim)))
def ops_2_to_2(self, inputs, dim, normalization='inf', normalization_val=1.0):
    # inputs: N x D x m x m
    diag_part = torch.diagonal(inputs, dim1=-2, dim2=-1)       # N x D x m
    sum_diag_part = torch.sum(diag_part, dim=2, keepdim=True)  # N x D x 1
    sum_of_rows = torch.sum(inputs, dim=3)   # N x D x m
    sum_of_cols = torch.sum(inputs, dim=2)   # N x D x m
    sum_all = torch.sum(sum_of_rows, dim=2)  # N x D

    # op1 - (1234) - extract diag
    op1 = torch.diag_embed(diag_part)  # N x D x m x m

    # op2 - (1234) + (12)(34) - place sum of diag on diag
    op2 = torch.diag_embed(sum_diag_part.repeat(1, 1, dim))  # N x D x m x m

    # op3 - (1234) + (123)(4) - place sum of row i on diag ii
    op3 = torch.diag_embed(sum_of_rows)  # N x D x m x m

    # op4 - (1234) + (124)(3) - place sum of col i on diag ii
    op4 = torch.diag_embed(sum_of_cols)  # N x D x m x m

    # op5 - (1234) + (124)(3) + (123)(4) + (12)(34) + (12)(3)(4) - place sum of all entries on diag
    op5 = torch.diag_embed(torch.unsqueeze(sum_all, dim=2).repeat(1, 1, dim))  # N x D x m x m

    # op6 - (14)(23) + (13)(24) + (24)(1)(3) + (124)(3) + (1234) - place sum of col i on row i
    op6 = torch.unsqueeze(sum_of_cols, dim=3).repeat(1, 1, 1, dim)  # N x D x m x m

    # op7 - (14)(23) + (23)(1)(4) + (234)(1) + (123)(4) + (1234) - place sum of row i on row i
    op7 = torch.unsqueeze(sum_of_rows, dim=3).repeat(1, 1, 1, dim)  # N x D x m x m

    # op8 - (14)(2)(3) + (134)(2) + (14)(23) + (124)(3) + (1234) - place sum of col i on col i
    op8 = torch.unsqueeze(sum_of_cols, dim=2).repeat(1, 1, dim, 1)  # N x D x m x m

    # op9 - (13)(24) + (13)(2)(4) + (134)(2) + (123)(4) + (1234) - place sum of row i on col i
    op9 = torch.unsqueeze(sum_of_rows, dim=2).repeat(1, 1, dim, 1)  # N x D x m x m

    # op10 - (1234) + (14)(23) - identity
    op10 = inputs  # N x D x m x m

    # op11 - (1234) + (13)(24) - transpose
    op11 = inputs.permute(0, 1, 3, 2)  # N x D x m x m

    # op12 - (1234) + (234)(1) - place ii element in row i
    op12 = torch.unsqueeze(diag_part, dim=3).repeat(1, 1, 1, dim)  # N x D x m x m

    # op13 - (1234) + (134)(2) - place ii element in col i
    op13 = torch.unsqueeze(diag_part, dim=2).repeat(1, 1, dim, 1)  # N x D x m x m

    # op14 - (34)(1)(2) + (234)(1) + (134)(2) + (1234) + (12)(34) - place sum of diag in all entries
    op14 = torch.unsqueeze(sum_diag_part, dim=3).repeat(1, 1, dim, dim)  # N x D x m x m

    # op15 - sum of all ops - place sum of all entries in all entries
    op15 = torch.unsqueeze(torch.unsqueeze(sum_all, dim=2), dim=3).repeat(1, 1, dim, dim)  # N x D x m x m

    if normalization is not None:
        float_dim = dim.type(torch.FloatTensor)
        if normalization == 'inf':  # compare strings with ==, not `is`
            op2 = torch.div(op2, float_dim)
            op3 = torch.div(op3, float_dim)
            op4 = torch.div(op4, float_dim)
            op5 = torch.div(op5, float_dim**2)
            op6 = torch.div(op6, float_dim)
            op7 = torch.div(op7, float_dim)
            op8 = torch.div(op8, float_dim)
            op9 = torch.div(op9, float_dim)
            op14 = torch.div(op14, float_dim)
            op15 = torch.div(op15, float_dim**2)

    return [op1, op2, op3, op4, op5, op6, op7, op8, op9, op10, op11, op12,
            op13, op14, op15]
def kabsch_transformation_estimation(x1, x2, weights=None, normalize_w=True,
                                     eps=1e-7, best_k=0, w_threshold=0):
    """Torch differentiable implementation of the weighted Kabsch algorithm
    (https://en.wikipedia.org/wiki/Kabsch_algorithm). Based on the correspondences
    and weights it calculates the optimal rotation matrix in the sense of the
    Frobenius norm (RMSD); based on the estimated rotation matrix it then estimates
    the translation vector, hence solving the Procrustes problem.
    This implementation supports batch inputs.

    Args:
        x1            (torch array): points of the first point cloud [b,n,3]
        x2            (torch array): correspondences for PC1 established in the feature space [b,n,3]
        weights       (torch array): weights denoting if the correspondence is an inlier (~1) or an outlier (~0) [b,n]
        normalize_w   (bool)       : flag for normalizing the weights to sum to 1
        best_k        (int)        : number of correspondences with highest weights to be used (if 0 all are used)
        w_threshold   (float)      : only use weights higher than this threshold (if 0 all are used)

    Returns:
        rot_matrices   (torch array): estimated rotation matrices [b,3,3]
        trans_vectors  (torch array): estimated translation vectors [b,3,1]
        res            (torch array): pointwise residuals (Euclidean distance) [b,n]
        valid_gradient (bool)       : flag denoting if the SVD computation converged (gradient is valid)
    """
    if weights is None:
        weights = torch.ones(x1.shape[0], x1.shape[1]).type_as(x1).to(x1.device)

    if normalize_w:
        sum_weights = torch.sum(weights, dim=1, keepdim=True) + eps
        weights = (weights / sum_weights)

    weights = weights.unsqueeze(2)

    if best_k > 0:
        indices = np.argpartition(weights.cpu().numpy(), -best_k, axis=1)[0, -best_k:, 0]
        weights = weights[:, indices, :]
        x1 = x1[:, indices, :]
        x2 = x2[:, indices, :]

    if w_threshold > 0:
        weights[weights < w_threshold] = 0

    # weighted centroids
    x1_mean = torch.matmul(weights.transpose(1, 2), x1) / (
        torch.sum(weights, dim=1).unsqueeze(1) + eps)
    x2_mean = torch.matmul(weights.transpose(1, 2), x2) / (
        torch.sum(weights, dim=1).unsqueeze(1) + eps)

    x1_centered = x1 - x1_mean
    x2_centered = x2 - x2_mean

    weight_matrix = torch.diag_embed(weights.squeeze(2))

    cov_mat = torch.matmul(x1_centered.transpose(1, 2),
                           torch.matmul(weight_matrix, x2_centered))

    try:
        u, s, v = torch.svd(cov_mat)
    except Exception as e:
        r = torch.eye(3, device=x1.device)
        r = r.repeat(x1_mean.shape[0], 1, 1)
        t = torch.zeros((x1_mean.shape[0], 3, 1), device=x1.device)
        res = transformation_residuals(x1, x2, r, t)
        return r, t, res, True

    tm_determinant = torch.det(torch.matmul(v.transpose(1, 2), u.transpose(1, 2)))

    determinant_matrix = torch.diag_embed(
        torch.cat((torch.ones((tm_determinant.shape[0], 2), device=x1.device),
                   tm_determinant.unsqueeze(1)), 1))

    rotation_matrix = torch.matmul(
        v, torch.matmul(determinant_matrix, u.transpose(1, 2)))

    # translation vector
    translation_matrix = x2_mean.transpose(1, 2) - torch.matmul(
        rotation_matrix, x1_mean.transpose(1, 2))

    # residuals
    res = transformation_residuals(x1, x2, rotation_matrix, translation_matrix)

    return rotation_matrix, translation_matrix, res, False
def _create_marginal_input(self, batch_shape=torch.Size()):
    mat = torch.randn(*batch_shape, 5, 5)
    eye = torch.diag_embed(torch.ones(*batch_shape, 5))
    return MultivariateNormal(torch.randn(*batch_shape, 5),
                              mat @ mat.transpose(-1, -2) + eye)
def forward(self, images, imu_data, prev_lstm_states, prev_pose, prev_state,
            prev_covar, T_imu_cam):
    vis_meas_covar_scale = torch.ones(6, device=images.device)
    vis_meas_covar_scale[0:3] = vis_meas_covar_scale[0:3] * par.k4

    imu_noise_covar = self.get_imu_noise_covar()

    if prev_covar is None:
        prev_covar = torch.diag(self.init_covar_diag_sqrt * self.init_covar_diag_sqrt +
                                par.init_covar_diag_eps).repeat(images.shape[0], 1, 1)

    encoded_images = self.vo_module.encode_image(images)

    num_timesteps = images.size(1) - 1  # equals imu_data.size(1) - 1

    poses_over_timesteps = [prev_pose]
    states_over_timesteps = [prev_state]
    covars_over_timesteps = [prev_covar]
    vis_meas_over_timesteps = []
    vis_meas_covar_over_timesteps = []
    lstm_states = prev_lstm_states
    for k in range(0, num_timesteps):
        # EKF predict
        pred_states, pred_covars = self.ekf_module.predict(
            imu_data[:, k], imu_noise_covar,
            states_over_timesteps[-1], covars_over_timesteps[-1])

        if par.hybrid_recurrency and par.enable_ekf:
            # concatenate the predicted states and covar with the encoded images to feed into the LSTM
            last_pred_state_so3 = IMUKalmanFilter.state_to_so3(pred_states[-1])
            last_pred_covar_flattened = pred_covars[-1].view(
                -1, IMUKalmanFilter.STATE_VECTOR_DIM**2)
            feature_vector = torch.cat([last_pred_state_so3,
                                        last_pred_covar_flattened,
                                        encoded_images[:, k]], -1)
        else:
            feature_vector = encoded_images[:, k]

        # get visual measurement
        vis_meas_and_covar, lstm_states = self.vo_module.forward_one_ts(
            feature_vector, lstm_states)
        vis_meas = vis_meas_and_covar[:, 0:6]

        # process visual measurement covariance
        if par.vis_meas_covar_use_fixed:
            vis_meas_covar_diag = torch.tensor(par.vis_meas_fixed_covar,
                                               dtype=torch.float32,
                                               device=vis_meas.device)
            vis_meas_covar_diag = vis_meas_covar_diag * vis_meas_covar_scale
            vis_meas_covar_diag = vis_meas_covar_diag.repeat(
                vis_meas.shape[0], vis_meas.shape[1], 1)
        else:
            vis_meas_covar_diag = par.vis_meas_covar_init_guess * \
                10 ** (par.vis_meas_covar_beta *
                       torch.tanh(par.vis_meas_covar_gamma * vis_meas_and_covar[:, 6:12]))
        vis_meas_covar_scaled = torch.diag_embed(
            vis_meas_covar_diag / vis_meas_covar_scale.view(1, 6))
        vis_meas_covar = torch.diag_embed(vis_meas_covar_diag)

        # EKF correct
        est_state, est_covar = self.ekf_module.update(
            pred_states[-1], pred_covars[-1],
            vis_meas.unsqueeze(-1), vis_meas_covar_scaled, T_imu_cam)

        new_pose, new_state, new_covar = self.ekf_module.composition(
            poses_over_timesteps[-1], est_state, est_covar)

        poses_over_timesteps.append(new_pose)
        states_over_timesteps.append(new_state)
        covars_over_timesteps.append(new_covar)
        vis_meas_over_timesteps.append(vis_meas)
        vis_meas_covar_over_timesteps.append(vis_meas_covar)

    return torch.stack(vis_meas_over_timesteps, 1), \
           torch.stack(vis_meas_covar_over_timesteps, 1), \
           lstm_states, \
           torch.stack(poses_over_timesteps, 1), \
           torch.stack(states_over_timesteps, 1), \
           torch.stack(covars_over_timesteps, 1)
def get_distribution(self, mean, std):
    if self.action_dim == 1:
        normal = Normal(mean, std)
    else:
        normal = MultivariateNormal(mean, torch.diag_embed(std))
    return normal
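A usage sketch with made-up shapes. Note that `torch.diag_embed(std)` places `std` directly on the diagonal of the covariance matrix, so if `std` holds standard deviations the diagonal covariance would more commonly be `std ** 2`; the snippet above may intend `std` to already hold variances.

import torch
from torch.distributions import MultivariateNormal

mean = torch.zeros(8, 4)        # batch of 8, action_dim = 4 (illustrative)
std = torch.full((8, 4), 0.5)

# diag_embed(std) is used as the (diagonal) covariance matrix here.
dist = MultivariateNormal(mean, torch.diag_embed(std))
actions = dist.rsample()
log_prob = dist.log_prob(actions)
print(actions.shape, log_prob.shape)  # torch.Size([8, 4]) torch.Size([8])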
def sigma_tf(self, t, X, Y):  # M x 1, M x D, M x 1
    return 0.4 * torch.diag_embed(X)  # M x D x D
def LBO(V, F):
    W, V_area = LBO_slim(V, F)
    area_matrix = torch.diag_embed(V_area)
    area_matrix_inv = torch.diag_embed(1 / V_area)
    L = torch.bmm(area_matrix_inv, W)  # VALIDATED
    return L, area_matrix, area_matrix_inv, W
def rho_precision(logs, rho, z_dim):
    # tridiagonal (AR(1)-style) precision matrix: -rho on the off-diagonals,
    # 1 at the two corners and 1 + rho^2 on the interior of the main diagonal,
    # all scaled by exp(-logs) / (1 - rho^2)
    precision_matrix = (torch.diag_embed(-rho.unsqueeze(1).expand(-1, z_dim - 1), offset=-1)
                        + torch.diag_embed(-rho.unsqueeze(1).expand(-1, z_dim - 1), offset=1)
                        + torch.diag_embed(F.pad((1 + rho * rho).unsqueeze(1).expand(-1, z_dim - 2),
                                                 (1, 1), value=1.))) \
        * (torch.exp(-logs) / (1 - rho * rho))[:, None, None]
    return precision_matrix
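A standalone check, assuming the `rho_precision` above is in scope, that this tridiagonal matrix is the inverse of the AR(1) covariance Sigma_ij = exp(logs) * rho**|i-j|; the batch size, dimension, and parameter values are illustrative.

import torch

batch, z_dim = 2, 5
rho = torch.full((batch,), 0.6)
logs = torch.zeros(batch)  # unit marginal variance

# Dense AR(1) covariance: Sigma_ij = exp(logs) * rho**|i-j|
idx = torch.arange(z_dim)
lag = (idx[:, None] - idx[None, :]).abs().float()
cov = torch.exp(logs)[:, None, None] * rho[:, None, None] ** lag

prec = rho_precision(logs, rho, z_dim)
assert torch.allclose(prec, torch.inverse(cov), atol=1e-4)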
def vec_to_diag(vec):
    return _torch.diag_embed(vec, offset=0)
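For reference, a short illustration of `torch.diag_embed` itself: the last input dimension becomes the matrix diagonal, and any leading dimensions are treated as batch dimensions.

import torch

v = torch.arange(1., 4.)
print(torch.diag_embed(v))
# tensor([[1., 0., 0.],
#         [0., 2., 0.],
#         [0., 0., 3.]])

# Batched input: one diagonal matrix per leading index.
vb = torch.randn(4, 2, 3)
print(torch.diag_embed(vb).shape)  # torch.Size([4, 2, 3, 3])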