def dPsid1dcov(self, x2, mu, cov):
    """Derivative of Psid1 (expectation involving the kernel derivative)
    with respect to the input covariance ``cov``.

    Args:
        x2: inducing inputs z. # presumably N2 x K — TODO confirm against callers
        mu: input mean m.      # presumably R x 1 x K — TODO confirm
        cov: input covariance S, R x K x K.

    Returns:
        Tensor of per-entry gradients; trailing two dims index S_{ij}.
        # shape per inline comments: (R x 1 x N2 x K) x (K x K)
    """
    # A + S, where A = diag(lengthscales^2); unsqueeze(0) broadcasts over R
    new_cov = self.lengthscales.squeeze(1).diag().pow(2).unsqueeze(
        0) + cov  # (S + A) # R x K x K
    new_cov_inv = batch_inverse_psd(new_cov)  # (S + A)^{-1}
    scaled_diff = self.scaled_paired_diff(
        mu, x2, new_cov)  # (S + A)^{-1}(m - z) is R x 1 x N2 x K
    # outer structure (S + A)^{-1} dS/dS_{ij} applied to the scaled difference
    term0 = new_cov_inv.unsqueeze(1).unsqueeze(1).unsqueeze(-1) * (
        scaled_diff.unsqueeze(-2).unsqueeze(-2))
    term0_diag = batch_make_diag(term0)
    # symmetrise: S is symmetric, so off-diagonal S_{ij} perturbations count
    # twice; the diagonal correction avoids double-counting S_{ii}
    term1 = self.Psi1(x2, mu, cov).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) * (
        term0 + term0.transpose(-2, -1) - term0_diag
    )  # correcting for symmetric cov
    # product-rule term through dPsi1/dcov
    term2 = -self.dPsi1dcov(
        x2, mu, cov).unsqueeze(-3) * scaled_diff.unsqueeze(-1).unsqueeze(
            -1)  # (R x 1 x N2 x K) x (K x K)
    return term1 + term2
def dPsi2dcov(self, x2, mu, cov, x3=None):
    """Derivative of the Psi2 statistic with respect to the input
    covariance ``cov``.

    Args:
        x2: first set of inducing inputs z.  # presumably M x K — TODO confirm
        x3: second set of inducing inputs z'; defaults to ``x2``.
        mu: input mean m.                    # presumably R x 1 x K — TODO confirm
        cov: input covariance S, R x K x K.

    Returns:
        Symmetrised gradient tensor.  # shape per inline comment:
        # (R x M x M) x (K x K)
    """
    if x3 is None:
        x3 = x2  # we actually want this to point to the same object
    # (1/2 A + S) with A = diag(lengthscales^2)
    lengthscales_new_squared = 0.5 * self.lengthscales.squeeze(
        1).diag().pow(2).unsqueeze(0) + cov  # (1/2 A + S)
    x2_new = 0.5 * (x2.unsqueeze(2) + x3.unsqueeze(1))  # (z + z') / 2
    mu_x2_new_diffs = mu.unsqueeze(
        1) - x2_new  # R x M x M x K is (m - (z + z') / 2 )
    # determinant factor (2 S A^{-1} + I); float_type is a module-level dtype
    fromDet = 2 * cov.matmul(
        (1. / self.lengthscales.squeeze(1).pow(2)).diag()) + torch.eye(
            self.lengthscales.size(0), device=x2.device).type(
                float_type)  # (2 S A^{-1} + I)
    # NOTE(review): torch.solve is deprecated/removed in newer PyTorch
    # (torch.linalg.solve); kept for the project's pinned torch version.
    scaled_diffs, _ = torch.solve(mu_x2_new_diffs.transpose(-1, -2),
                                  lengthscales_new_squared.unsqueeze(-3))
    # quadratic-form part of the gradient
    term1 = scaled_diffs.transpose(
        -1, -2
    ).unsqueeze(-1) * scaled_diffs.transpose(-1, -2).unsqueeze(
        -2
    )  # (m - (z + z') / 2 )'(1/2A + S)^{-1} dS/dS_{ij} (1/2A + S)^{-1} (m - (z + z') / 2 )
    term2 = batch_inverse_psd(fromDet).div(
        self.lengthscales.pow(2))  # A^{-1} (2 S A^{-1} + I)^{-1}
    dcov = self.Psi2(x2, mu, cov, x3).unsqueeze(-1).unsqueeze(-1) * (
        0.5 * term1 - term2.unsqueeze(1).unsqueeze(1)
    )  # (R x M x M) x (K x K)
    # symmetrise for symmetric S, correcting the doubled diagonal
    return dcov + dcov.transpose(-2, -1) - batch_make_diag(dcov)
def get_KullbackLeibler_grad(self, model, m, S, idx):
    """Gradients of the Kullback-Leibler divergence term w.r.t. the
    state mean ``m`` and covariance ``S``.

    Args:
        model: object exposing ``transfunc`` with f / dfdm / dfdS / dfdx /
            dffdm / dffdS / ddfdxdm / ddfdxdS evaluations.
        m: state mean.        # presumably R x 1 x K — TODO confirm
        S: state covariance.  # presumably R x K x K — TODO confirm
        idx: index into the per-grid linearisation tensors A_grid / b_grid.

    Returns:
        (dEdm, dEdS) with the singleton dims squeezed out
        (see the "compact representation" comment below).
    """
    # gradients for Kullback leibler divergence
    # gradients wrt m are (R x 1 x 1) x (1 x K)
    with torch.no_grad():
        dEdm_grid = 0.5 * model.transfunc.dffdm(m, S) \
            - (model.transfunc.dfdm(m, S) * (self.b_grid[idx].unsqueeze(-1).unsqueeze(-1))).sum(2, keepdim=True) \
            + m.matmul(self.A_grid[idx].transpose(-1, -2)).matmul(self.A_grid[idx]).unsqueeze(1).unsqueeze(1) \
            - self.b_grid[idx].matmul(self.A_grid[idx]).unsqueeze(1).unsqueeze(1) \
            + model.transfunc.f(m, S).matmul(self.A_grid[idx]).unsqueeze(1).unsqueeze(1) \
            + (model.transfunc.dfdm(m, S) * m.matmul(self.A_grid[idx].transpose(-1, -2)).unsqueeze(-1).unsqueeze(-1)).sum(2, keepdim=True) \
            + (model.transfunc.ddfdxdm(m, S) * self.A_grid[idx].matmul(S).unsqueeze(-1).unsqueeze(-1)).sum(1, keepdim=True).sum(2, keepdim=True)
        # gradients wrt S are (R x 1 x 1) x (K x K)
        # part of gradient that is already symmetrised (since grads come from
        # transition function, which expects proper gradients)
        dEdS_grid_sym = 0.5 * model.transfunc.dffdS(m, S) \
            - (model.transfunc.dfdS(m, S) * (self.b_grid[idx].unsqueeze(-1).unsqueeze(-1))).sum(2, keepdim=True) \
            + (model.transfunc.dfdS(m, S) * m.matmul(self.A_grid[idx].transpose(-1, -2)).unsqueeze(-1).unsqueeze(-1)).sum(2, keepdim=True) \
            + (model.transfunc.ddfdxdS(m, S) * self.A_grid[idx].matmul(S).unsqueeze(-1).unsqueeze(-1)).sum(1, keepdim=True).sum(2, keepdim=True)
        # asymmetric part, to be symmetrised explicitly below
        dEdS_grid_asym = self.A_grid[idx].transpose(-1, -2).matmul(model.transfunc.dfdx(m, S)).unsqueeze(1).unsqueeze(1) \
            + 0.5 * self.A_grid[idx].transpose(-1, -2).matmul(self.A_grid[idx]).unsqueeze(1).unsqueeze(1)
        dEdS_grid = dEdS_grid_sym + dEdS_grid_asym + dEdS_grid_asym.transpose(
            -2, -1) - batch_make_diag(
                dEdS_grid_asym)  # account for symmetry in S
        # return more compact representation of gradients
        return dEdm_grid.squeeze(1).squeeze(1), dEdS_grid.squeeze(1).squeeze(1)
def dOutdS(self, m, S):
    """Gradients of the projected output mean/covariance w.r.t. the state
    covariance S.

    The output covariance is a quadratic form in ``self.Subspace``
    # presumably Subspace is D x K — TODO confirm against the constructor
    so its S-gradient is an outer product of subspace rows; the output mean
    does not depend on S, so its gradient is identically zero.

    Returns:
        (dmudS, dcovdS), each of shape 1 x D x K x K.
    """
    with torch.no_grad():
        basis = self.Subspace.unsqueeze(-1)
        # per-output-dim outer products: D x K x K
        outer = basis.permute(1, 2, 0) * basis.permute(1, 0, 2)
        # symmetrise: S is symmetric, so each off-diagonal S_{ij} perturbation
        # contributes twice; subtract the diagonal to avoid double-counting
        grad_cov = outer + outer.transpose(-1, -2) - batch_make_diag(outer)
        # mean has no S-dependence -> zero gradient of matching shape
        grad_mu = torch.zeros_like(grad_cov)
        # outputs are 1 x D x K x K
        return grad_mu.unsqueeze(0), grad_cov.unsqueeze(0)
def dPsid1Psi1dcov(self, x2, mu, cov, x3=None):
    """Derivative of the <d1k k> cross statistic with respect to the input
    covariance ``cov``.

    Args:
        x2: first set of inducing inputs.  # presumably M2 x K — TODO confirm
        x3: second set of inducing inputs; defaults to ``x2``.
        mu: input mean, required shape R x 1 x K.
        cov: input covariance S, R x K x K.

    Returns:
        Gradient tensor.  # per inline comment: *(R x N2 x N3 x K) x (K x K)

    Raises:
        ValueError: if ``mu`` is not an R x 1 x K vector.
    """
    if mu.size(-2) != 1:  # Psi2 not working for matrix version
        # BUGFIX: previously `raise ('...')` raised a bare string, which is
        # itself a TypeError in Python 3; raise a proper exception type.
        raise ValueError(
            'dimensionality mismatch: input mean needs to be a R x 1 x K vector'
        )
    if x3 is None:
        x3 = x2  # we actually want this to point to the same object
    new_cov = 0.5 * self.lengthscales.squeeze(1).diag().pow(2).unsqueeze(
        0) + cov  # ( 1/2 A + S)
    x2_scaled = x2.div(self.lengthscales.pow(2).view(1, -1))  # s'A^{-1}
    x2_x3_sum = 0.5 * (x2.unsqueeze(-2) + x3.unsqueeze(-3)
                       )  # R x M2 x M3 x K is (s + z)/2
    x2_x3_sum_scaled = x2_x3_sum.matmul(
        cov.unsqueeze(-3).div(self.lengthscales.pow(2).view(
            -1, 1)))  # (S A^{-1} (s + z) / 2)' R x M2 x M3 x K
    mu_x2_new_diffs = 0.5 * mu.unsqueeze(
        -2
    ) + x2_x3_sum_scaled  # R x M2 x M3 x K is (1/2 m + S A^{-1} (s + z) / 2 )
    # NOTE(review): torch.solve is deprecated/removed in newer PyTorch
    # (torch.linalg.solve); kept for the project's pinned torch version.
    scaled_diffs, _ = torch.solve(
        mu_x2_new_diffs.unsqueeze(-1),
        new_cov.unsqueeze(-3).unsqueeze(-3))  # R x M2 x M3 x K x 1
    new_cov_inv = batch_inverse_psd(new_cov)  # inv( 1/2 A + S)
    term0 = -new_cov_inv.unsqueeze(1).unsqueeze(1).unsqueeze(-1) * (
        scaled_diffs.transpose(-2, -1).unsqueeze(-3) -
        x2_x3_sum.div(self.lengthscales.pow(2).view(
            1, -1)).unsqueeze(-2).unsqueeze(-2))
    term0_diag = batch_make_diag(term0)
    # symmetrise for symmetric cov, correcting the doubled diagonal
    term1 = self.Psi2(x2, mu, cov, x3).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) * (
        term0 + term0.transpose(-2, -1) - term0_diag
    )  # correcting for symmetric cov
    term2 = self.dPsi2dcov(x2, mu, cov, x3).unsqueeze(-3) * (
        scaled_diffs.squeeze(-1) - x2_scaled.unsqueeze(-2)
    ).unsqueeze(-1).unsqueeze(-1)  # *(R x N2 x N3 x K) x (K x K)
    return term1 + term2
def dPsi1dcov(self, x2, mu, cov):
    """Derivative of Psi1 with respect to the input covariance.

    Output is (R x 1 x M) x (K x K).
    # NOTE(review): the old docstring said "wrt mu", but the name and the
    # body (which differentiates through cov) indicate cov — confirm intent.
    """
    # (A + S): squared lengthscales on the diagonal plus the input covariance
    a_plus_s = self.lengthscales.squeeze(1).diag().pow(
        2).unsqueeze(0) + cov  # (A + S)
    diffs = self.scaled_paired_diff(mu, x2, a_plus_s)  # (A + S)^{-1} (m - z)
    # quadratic form (m - z)'(A + S)^{-1} dS/dS_{ij} (A + S)^{-1} (m - z)
    quad = diffs.unsqueeze(-1) * diffs.unsqueeze(-2)
    # A^{-1} (S A^{-1} + I)^{-1} = (A + S)^{-1}, R x K x K
    inv = batch_inverse_psd(a_plus_s)
    grad = 0.5 * self.Psi1(x2, mu, cov).unsqueeze(-1).unsqueeze(-1) * (
        quad - inv.unsqueeze(1).unsqueeze(1))
    # symmetrise: cov is symmetric, so off-diagonal perturbations count twice;
    # subtract the diagonal to avoid double-counting S_{ii}
    return grad + grad.transpose(-2, -1) - batch_make_diag(
        grad)  # (R x N1 x N2) x (K x K)
def dPsid1Psid2dcov(self, x2, mu, cov, x3=None):
    """Derivative of <d1k d2k> with respect to the input covariance.

    Args:
        x2: first set of inducing inputs.  # presumably M2 x K — TODO confirm
        x3: second set of inducing inputs; defaults to ``x2``.
        mu: input mean, required shape R x 1 x K.
        cov: input covariance S, R x K x K.

    Returns:
        Gradient tensor.  # per inline comment: R x M2 x M3 x K x K x K x K

    Raises:
        ValueError: if ``mu`` is not an R x 1 x K vector.
    """
    if mu.size(-2) != 1:  # Psi2 not working for matrix version
        # BUGFIX: previously `raise ('...')` raised a bare string, which is
        # itself a TypeError in Python 3; raise a proper exception type.
        raise ValueError(
            'dimensionality mismatch: input mean needs to be a R x 1 x K vector'
        )
    if x3 is None:
        x3 = x2  # we actually want this to point to the same object
    new_cov = 0.5 * self.lengthscales.squeeze(1).diag().pow(2).unsqueeze(
        0) + cov  # ( 1/2 A + S)
    x2_scaled = x2.div(self.lengthscales.pow(2).view(1, -1))  # s'A^{-1}
    x3_scaled = x3.div(self.lengthscales.pow(2).view(1, -1))  # s'A^{-1}
    x2_x3_sum = 0.5 * (x2.unsqueeze(-2) + x3.unsqueeze(-3)
                       )  # R x M2 x M3 x K is (s + z)/2
    x2_x3_sum_scaled = x2_x3_sum.matmul(
        cov.div(self.lengthscales.pow(2).view(
            -1, 1)).unsqueeze(-3))  # (S A^{-1} (s + z) / 2)' R x M2 x M3 x K
    mu_x2_new_diffs = 0.5 * mu.unsqueeze(
        -2
    ) + x2_x3_sum_scaled  # R x M2 x M3 x K is (1/2 m + S A^{-1} (s + z) / 2 )
    new_cov_inv = batch_inverse_psd(new_cov)  # inv( 1/2 A + S)
    # BUGFIX: removed a dead torch.solve call whose result was immediately
    # overwritten by the solve below (same right-hand side, different layout).
    # NOTE(review): torch.solve is deprecated/removed in newer PyTorch
    # (torch.linalg.solve); kept for the project's pinned torch version.
    scaled_diffs, _ = torch.solve(
        mu_x2_new_diffs.unsqueeze(-1),
        new_cov.unsqueeze(-3).unsqueeze(-3))  # R x M2 x M3 x K x 1
    meanx2 = scaled_diffs - x2_scaled.unsqueeze(-2).unsqueeze(
        -1)  # ... x R x M2 x M3 x K x 1
    meanx3 = scaled_diffs - x3_scaled.unsqueeze(-3).unsqueeze(
        -1)  # ... x R x M2 x M3 x K x 1
    Sigma = 0.5 * new_cov_inv.matmul(cov).div(
        self.lengthscales.pow(2).view(1, -1))  # R x K x K
    AinvdSigmaAinv = 0.25 * (
        new_cov_inv.unsqueeze(-2).unsqueeze(-1) *
        new_cov_inv.unsqueeze(-3).unsqueeze(-2)).unsqueeze(-5).unsqueeze(
            -5)  # R x 1 x 1 x K x K x K x K
    AinvdSigmaAinv_diag = batch_make_diag(AinvdSigmaAinv)
    # symmetrise for symmetric cov, correcting the doubled diagonal
    AinvdSigmaAinv = AinvdSigmaAinv + AinvdSigmaAinv.transpose(
        -2, -1) - AinvdSigmaAinv_diag
    term0 = -new_cov_inv.unsqueeze(1).unsqueeze(1).unsqueeze(-1) * (
        scaled_diffs.transpose(-2, -1) -
        x2_x3_sum.div(self.lengthscales.pow(2).view(
            1, -1)).unsqueeze(-2)).unsqueeze(-2)
    term0_diag = batch_make_diag(term0)
    Ainvdmean = (term0 + term0.transpose(-2, -1) -
                 term0_diag).unsqueeze(-3)
    term1 = self.dPsi2dcov(x2, mu, cov, x3).unsqueeze(-3).unsqueeze(-3) * (
        Sigma.unsqueeze(-3).unsqueeze(-3) +
        meanx2 * meanx3.transpose(-2, -1)).unsqueeze(-1).unsqueeze(-1)
    term2 = self.Psi2(x2, mu, cov, x3).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) * \
        (AinvdSigmaAinv
         + meanx2.unsqueeze(-1).unsqueeze(-1) * Ainvdmean.transpose(-4, -3)
         + Ainvdmean * meanx3.transpose(-1, -2).unsqueeze(-1).unsqueeze(-1)
         )  # R x M2 x M3 x K x K x K x K
    return term1 + term2