# Shared imports for the snippets below (assuming the Pyro helpers
# sum_leftmost and eye_like from pyro.distributions.util).
import math

import torch

import pyro.distributions as dist
from pyro.distributions.util import eye_like, sum_leftmost


def backward(ctx, grad_output):
    L, control_var, epsilon = ctx.saved_tensors
    B, C = control_var

    g = grad_output
    loc_grad = sum_leftmost(grad_output, -1)

    # compute the rep trick gradient
    epsilon_jb = epsilon.unsqueeze(-2)
    g_ja = g.unsqueeze(-1)
    diff_L_ab = sum_leftmost(g_ja * epsilon_jb, -2)

    # modulate the velocity fields with infinitesimal rotations,
    # i.e. apply the control variate
    gL = torch.matmul(g, L)
    eps_gL_ab = sum_leftmost(gL.unsqueeze(-1) * epsilon.unsqueeze(-2), -2)
    xi_ab = eps_gL_ab - eps_gL_ab.t()
    BC_lab = B.unsqueeze(-1) * C.unsqueeze(-2)
    diff_L_ab += (xi_ab.unsqueeze(0) * BC_lab).sum(0)

    L_grad = torch.tril(diff_L_ab)

    # compute control_var grads
    diff_B = (L_grad.unsqueeze(0) * C.unsqueeze(-2) * xi_ab.unsqueeze(0)).sum(2)
    diff_C = (L_grad.t().unsqueeze(0) * B.unsqueeze(-2) * xi_ab.t().unsqueeze(0)).sum(2)
    diff_CV = torch.stack([diff_B, diff_C])

    return loc_grad, L_grad, diff_CV, None

def backward(ctx, grad_output):
    jitter = 1.0e-8  # small jitter for numerical stability of the SVD below (possibly unnecessary)
    z, epsilon, L = ctx.saved_tensors
    dim = L.shape[0]
    g = grad_output
    loc_grad = sum_leftmost(grad_output, -1)

    identity = eye_like(g, dim)
    R_inv = torch.triangular_solve(identity, L.t(), transpose=False, upper=True)[0]

    z_ja = z.unsqueeze(-1)
    g_R_inv = torch.matmul(g, R_inv).unsqueeze(-2)
    epsilon_jb = epsilon.unsqueeze(-2)
    g_ja = g.unsqueeze(-1)
    diff_L_ab = 0.5 * sum_leftmost(g_ja * epsilon_jb + g_R_inv * z_ja, -2)

    Sigma_inv = torch.mm(R_inv, R_inv.t())
    V, D, _ = torch.svd(Sigma_inv + jitter)
    D_outer = D.unsqueeze(-1) + D.unsqueeze(0)

    expand_tuple = tuple([-1] * (z.dim() - 1) + [dim, dim])
    z_tilde = identity * torch.matmul(z, V).unsqueeze(-1).expand(*expand_tuple)
    g_tilde = identity * torch.matmul(g, V).unsqueeze(-1).expand(*expand_tuple)

    Y = sum_leftmost(torch.matmul(z_tilde, torch.matmul(1.0 / D_outer, g_tilde)), -2)
    Y = torch.mm(V, torch.mm(Y, V.t()))
    Y = Y + Y.t()

    Tr_xi_Y = torch.mm(torch.mm(Sigma_inv, Y), R_inv) - torch.mm(Y, torch.mm(Sigma_inv, R_inv))
    diff_L_ab += 0.5 * Tr_xi_Y
    L_grad = torch.tril(diff_L_ab)

    return loc_grad, L_grad, None

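
# For context, a minimal sketch of the torch.autograd.Function wiring these
# backward passes assume: the forward draws a reparameterized sample
# z = loc + epsilon @ L.t() and saves exactly the tensors its backward
# unpacks. The class name and the `shape` argument are illustrative, not taken
# from the snippets above. The control-variate variant in the first snippet
# would pair with an analogous forward that also takes control_var and saves
# (L, control_var, epsilon).
class _OMTMVNSampleSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, loc, scale_tril, shape):
        epsilon = torch.randn(shape, dtype=loc.dtype, device=loc.device)
        z = torch.matmul(epsilon, scale_tril.t())
        ctx.save_for_backward(z, epsilon, scale_tril)  # unpacked as z, epsilon, L
        return loc + z

    # backward as in the second snippet above: one gradient per forward input,
    # i.e. (loc_grad, L_grad, None), since `shape` needs no gradient.
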
def backward(ctx, grad_output):
    z, coord_scale, locs, component_logits, pis = ctx.saved_tensors
    K = component_logits.size(-1)
    batch_dims = coord_scale.dim() - 1
    g = grad_output  # l b i

    z_tilde = z / coord_scale  # l b i
    locs_tilde = locs / coord_scale.unsqueeze(-2)  # b j i
    mu_ab = locs_tilde.unsqueeze(-2) - locs_tilde.unsqueeze(-3)  # b k j i
    mu_ab_norm = torch.pow(mu_ab, 2.0).sum(-1).sqrt()  # b k j
    mu_ab /= mu_ab_norm.unsqueeze(-1)  # b k j i
    diagonals = torch.arange(K, dtype=torch.long, device=z.device)
    mu_ab[..., diagonals, diagonals, :] = 0.0

    mu_ll_ab = (locs_tilde.unsqueeze(-2) * mu_ab).sum(-1)  # b k j
    z_ll_ab = (z_tilde.unsqueeze(-2).unsqueeze(-2) * mu_ab).sum(-1)  # l b k j
    z_perp_ab = z_tilde.unsqueeze(-2).unsqueeze(-2) - z_ll_ab.unsqueeze(-1) * mu_ab  # l b k j i
    z_perp_ab_sqr = torch.pow(z_perp_ab, 2.0).sum(-1)  # l b k j

    epsilons = z_tilde.unsqueeze(-2) - locs_tilde  # l b j i
    log_qs = -0.5 * torch.pow(epsilons, 2.0)  # l b j i
    log_q_j = log_qs.sum(-1, keepdim=True)  # l b j 1
    log_q_j_max = torch.max(log_q_j, -2, keepdim=True)[0]
    q_j_prime = torch.exp(log_q_j - log_q_j_max)  # l b j 1
    q_j = torch.exp(log_q_j)  # l b j 1
    q_tot = (pis.unsqueeze(-1) * q_j).sum(-2)  # l b 1
    q_tot_prime = (pis.unsqueeze(-1) * q_j_prime).sum(-2).unsqueeze(-1)  # l b 1 1

    root_two = math.sqrt(2.0)
    mu_ll_ba = torch.transpose(mu_ll_ab, -1, -2)
    logits_grad = torch.erf((z_ll_ab - mu_ll_ab) / root_two) - torch.erf((z_ll_ab + mu_ll_ba) / root_two)
    logits_grad *= torch.exp(-0.5 * z_perp_ab_sqr)  # l b k j

    #                       b i               l b i        b k j i
    mu_ab_sigma_g = ((coord_scale * g).unsqueeze(-2).unsqueeze(-2) * mu_ab).sum(-1)  # l b k j
    logits_grad *= -mu_ab_sigma_g * pis.unsqueeze(-2)  # l b k j
    logits_grad = pis * sum_leftmost(logits_grad.sum(-1) / q_tot, -(1 + batch_dims))  # b k
    logits_grad *= math.sqrt(0.5 * math.pi)

    #                b j               l b j 1       l b i        l b 1 1
    prefactor = pis.unsqueeze(-1) * q_j_prime * g.unsqueeze(-2) / q_tot_prime  # l b j i
    locs_grad = sum_leftmost(prefactor, -(2 + batch_dims))  # b j i
    coord_scale_grad = sum_leftmost(prefactor * epsilons, -(2 + batch_dims)).sum(-2)  # b i

    return locs_grad, coord_scale_grad, logits_grad, None, None, None

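
# A hedged sketch of the sampling (forward) side that the mixture backward
# above differentiates, assuming a single datapoint with no batch dims: pick a
# component from Categorical(pis), then reparameterize inside the chosen
# diagonal-normal component (coord_scale is shared across components here).
# The helper name is illustrative; the real forward must also save
# (z, coord_scale, locs, component_logits, pis) for backward and accept six
# inputs to match the six returned gradients.
def _mix_diag_normal_sample_sketch(locs, coord_scale, pis):
    which = int(dist.Categorical(pis).sample())  # j ~ Categorical(pis)
    eps = torch.randn(locs.size(-1))  # standard normal noise
    return locs[which] + coord_scale * eps  # z = mu_j + sigma * eps
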
def posterior(self, obs):
    concentration1 = self._latent.concentration1
    concentration0 = self._latent.concentration0
    total_count = self._conditional.total_count
    reduce_dims = len(obs.size()) - len(concentration1.size())
    # Sum total_count over the leftmost (observation) dims so that it matches
    # the shape of concentration0 and counts the total number of Bernoulli
    # trials across all observations.
    total_count = sum_leftmost(total_count, reduce_dims)
    summed_obs = sum_leftmost(obs, reduce_dims)
    return dist.Beta(concentration1 + summed_obs,
                     total_count + concentration0 - summed_obs,
                     validate_args=self._latent._validate_args)

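
# A minimal usage sketch of the Beta-Binomial update above, using the
# module-level imports. The bare numbers stand in for self._latent and
# self._conditional and are purely illustrative.
latent = dist.Beta(torch.tensor(2.0), torch.tensor(3.0))
total_count = 10  # Binomial trials per observation
obs = torch.tensor([4.0, 7.0, 5.0])  # three Binomial(10, p) observations
post = dist.Beta(
    latent.concentration1 + obs.sum(),
    latent.concentration0 + obs.numel() * total_count - obs.sum(),
)
# post is Beta(2 + 16, 3 + 30 - 16) = Beta(18, 17)
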
def backward(ctx, grad_output):
    z, coord_scale, component_logits, component_scale, pis, coeffs = ctx.saved_tensors
    dim = coord_scale.size(0)
    g = grad_output  # l i
    g = g.unsqueeze(-2)  # l 1 i

    component_scale_sqr = torch.pow(component_scale, 2.0)  # j
    epsilons = z / coord_scale  # l i
    epsilons_sqr = torch.pow(epsilons, 2.0)  # l i
    r_sqr = epsilons_sqr.sum(-1, keepdim=True)  # l
    r_sqr_j = r_sqr / component_scale_sqr  # l j
    coord_scale_product = coord_scale.prod()
    component_scale_power = torch.pow(component_scale, float(dim))

    q_j = torch.exp(-0.5 * r_sqr_j) / math.pow(2.0 * math.pi, 0.5 * float(dim))  # l j
    q_j /= coord_scale_product * component_scale_power  # l j
    q_tot = (pis * q_j).sum(-1, keepdim=True)  # l

    Phi_j = torch.exp(-0.5 * r_sqr_j)  # l j
    exponents = -torch.arange(1., int(dim / 2) + 1., 1.)
    if z.dim() > 1:
        r_j_poly = r_sqr_j.unsqueeze(-1).expand(-1, -1, int(dim / 2))  # l j d/2
    else:
        r_j_poly = r_sqr_j.unsqueeze(-1).expand(-1, int(dim / 2))  # j d/2
    r_j_poly = coeffs * torch.pow(r_j_poly, exponents)
    Phi_j *= r_j_poly.sum(-1)
    if dim % 2 == 1:
        root_two = math.sqrt(2.0)
        extra_term = coeffs[-1] * math.sqrt(0.5 * math.pi) * (1.0 - torch.erf(r_sqr_j.sqrt() / root_two))  # l j
        Phi_j += extra_term * torch.pow(r_sqr_j, -0.5 * float(dim))

    logits_grad = (z.unsqueeze(-2) * Phi_j.unsqueeze(-1) * g).sum(-1)  # l j
    logits_grad /= q_tot
    logits_grad = sum_leftmost(logits_grad, -1) * math.pow(2.0 * math.pi, -0.5 * float(dim))
    logits_grad = pis * logits_grad / (component_scale_power * coord_scale_product)
    logits_grad = logits_grad - logits_grad.sum() * pis

    prefactor = pis.unsqueeze(-1) * q_j.unsqueeze(-1) * g / q_tot.unsqueeze(-1)  # l j i
    coord_scale_grad = sum_leftmost(prefactor * epsilons.unsqueeze(-2), -1)
    component_scale_grad = sum_leftmost((prefactor * z.unsqueeze(-2)).sum(-1) / component_scale, -1)

    return coord_scale_grad, logits_grad, component_scale_grad, None, None, None, None

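
# Similarly, a brief sketch of the Gaussian scale-mixture sampler the backward
# above pairs with. Zero mean is an assumption, consistent with the absence of
# any locs in the saved tensors: draw a component index, then scale one shared
# standard normal by the chosen component scale.
def _gsm_sample_sketch(coord_scale, component_scale, pis):
    j = int(dist.Categorical(pis).sample())  # component index
    eps = torch.randn(coord_scale.size(0))  # standard normal noise
    return component_scale[j] * coord_scale * eps
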
def posterior(self, obs):
    concentration = self._latent.concentration
    rate = self._latent.rate
    reduce_dims = len(obs.size()) - len(rate.size())
    num_obs = obs.shape[:reduce_dims].numel()
    summed_obs = sum_leftmost(obs, reduce_dims)
    # Gamma-Poisson conjugate update: the shape gains the summed counts,
    # the rate gains the number of observations.
    return dist.Gamma(concentration + summed_obs, rate + num_obs)

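
# The analogous worked example for the Gamma-Poisson update above
# (illustrative numbers): a Gamma(1.5, 2.0) prior on a Poisson rate,
# updated on four observed counts.
prior = dist.Gamma(torch.tensor(1.5), torch.tensor(2.0))
obs = torch.tensor([3.0, 0.0, 2.0, 5.0])
post = dist.Gamma(prior.concentration + obs.sum(), prior.rate + obs.numel())
# post is Gamma(1.5 + 10, 2.0 + 4) = Gamma(11.5, 6.0)
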
def test_sum_leftmost():
    x = torch.ones(2, 3, 4)
    assert sum_leftmost(x, 0).shape == (2, 3, 4)
    assert sum_leftmost(x, 1).shape == (3, 4)
    assert sum_leftmost(x, 2).shape == (4,)
    assert sum_leftmost(x, -1).shape == (4,)
    assert sum_leftmost(x, -2).shape == (3, 4)
    assert sum_leftmost(x, float('inf')).shape == ()

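
# The test above pins down sum_leftmost's contract: sum out the first `dim`
# dimensions of a tensor, with negative `dim` counting from the right and an
# infinite `dim` summing everything. A minimal sketch consistent with those
# cases (the real implementation lives in pyro.distributions.util):
def sum_leftmost_sketch(value, dim):
    if dim < 0:
        dim += value.dim()
    if dim == 0:
        return value
    if dim >= value.dim():  # also handles dim = float('inf')
        return value.sum()
    return value.reshape(-1, *value.shape[dim:]).sum(0)
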
def backward(ctx, grad_output):
    z, scales, locs, logits, pis = ctx.saved_tensors
    dim = scales.size(-1)
    K = logits.size(-1)
    g = grad_output  # l b i
    g = g.unsqueeze(-2)  # l b 1 i
    batch_dims = locs.dim() - 2

    locs_tilde = locs / scales  # b j i
    sigma_0 = torch.min(scales, -2, keepdim=True)[0]  # b 1 i
    z_shift = (z.unsqueeze(-2) - locs) / sigma_0  # l b j i
    z_tilde = z.unsqueeze(-2) / scales - locs_tilde  # l b j i

    mu_cd = locs.unsqueeze(-2) - locs.unsqueeze(-3)  # b c d i
    mu_cd_norm = torch.pow(mu_cd, 2.0).sum(-1).sqrt()  # b c d
    mu_cd /= mu_cd_norm.unsqueeze(-1)  # b c d i
    diagonals = torch.arange(K, dtype=torch.long, device=z.device)
    mu_cd[..., diagonals, diagonals, :] = 0.0

    mu_ll_cd = (locs.unsqueeze(-2) * mu_cd).sum(-1)  # b c d
    z_ll_cd = (z.unsqueeze(-2).unsqueeze(-2) * mu_cd).sum(-1)  # l b c d
    z_perp_cd = z.unsqueeze(-2).unsqueeze(-2) - z_ll_cd.unsqueeze(-1) * mu_cd  # l b c d i
    z_perp_cd_sqr = torch.pow(z_perp_cd, 2.0).sum(-1)  # l b c d

    shift_indices = torch.arange(dim, dtype=torch.long, device=z.device)
    shift_indices = shift_indices - 1
    shift_indices[0] = 0

    z_shift_cumsum = torch.pow(z_shift, 2.0)
    z_shift_cumsum = z_shift_cumsum.sum(-1, keepdim=True) - torch.cumsum(z_shift_cumsum, dim=-1)  # l b j i
    z_tilde_cumsum = torch.cumsum(torch.pow(z_tilde, 2.0), dim=-1)  # l b j i
    z_tilde_cumsum = torch.index_select(z_tilde_cumsum, -1, shift_indices)
    z_tilde_cumsum[..., 0] = 0.0
    r_sqr_ji = z_shift_cumsum + z_tilde_cumsum  # l b j i

    log_scales = torch.log(scales)  # b j i
    epsilons_sqr = torch.pow(z_tilde, 2.0)  # l b j i
    log_qs = -0.5 * epsilons_sqr - 0.5 * math.log(2.0 * math.pi) - log_scales  # l b j i
    log_q_j = log_qs.sum(-1, keepdim=True)  # l b j 1
    q_j = torch.exp(log_q_j)  # l b j 1
    q_tot = (pis * q_j.squeeze(-1)).sum(-1)  # l b
    q_tot = q_tot.unsqueeze(-1)  # l b 1

    root_two = math.sqrt(2.0)
    shift_log_scales = log_scales[..., shift_indices]
    shift_log_scales[..., 0] = 0.0
    sigma_products = torch.cumsum(shift_log_scales, dim=-1).exp()  # b j i

    reverse_indices = torch.arange(dim - 1, -1, -1, dtype=torch.long, device=z.device)
    reverse_log_sigma_0 = sigma_0.log()[..., reverse_indices]  # b 1 i
    sigma_0_products = torch.cumsum(reverse_log_sigma_0, dim=-1).exp()[..., reverse_indices - 1]  # b 1 i
    sigma_0_products[..., -1] = 1.0
    sigma_products *= sigma_0_products

    logits_grad = torch.erf(z_tilde / root_two) - torch.erf(z_shift / root_two)  # l b j i
    logits_grad *= torch.exp(-0.5 * r_sqr_ji)  # l b j i
    logits_grad = (logits_grad * g / sigma_products).sum(-1)  # l b j
    logits_grad = sum_leftmost(logits_grad / q_tot, -1 - batch_dims)  # b j
    logits_grad *= 0.5 * math.pow(2.0 * math.pi, -0.5 * (dim - 1))
    logits_grad = -pis * logits_grad
    logits_grad = logits_grad - logits_grad.sum(-1, keepdim=True) * pis

    mu_ll_dc = torch.transpose(mu_ll_cd, -1, -2)
    v_cd = torch.erf((z_ll_cd - mu_ll_cd) / root_two) - torch.erf((z_ll_cd + mu_ll_dc) / root_two)
    v_cd *= torch.exp(-0.5 * z_perp_cd_sqr)  # l b c d
    mu_cd_g = (g.unsqueeze(-2) * mu_cd).sum(-1)  # l b c d
    v_cd *= -mu_cd_g * pis.unsqueeze(-2) * 0.5 * math.pow(2.0 * math.pi, -0.5 * (dim - 1))  # l b c d
    v_cd = pis * sum_leftmost(v_cd.sum(-1) / q_tot, -1 - batch_dims)
    logits_grad += v_cd

    prefactor = pis.unsqueeze(-1) * q_j * g / q_tot.unsqueeze(-1)
    locs_grad = sum_leftmost(prefactor, -2 - batch_dims)
    scales_grad = sum_leftmost(prefactor * z_tilde, -2 - batch_dims)

    return locs_grad, scales_grad, logits_grad, None, None, None