Example #1
    def backward(ctx, grad_output):
        L, control_var, epsilon = ctx.saved_tensors
        B, C = control_var
        g = grad_output
        loc_grad = sum_leftmost(grad_output, -1)

        # compute the rep trick gradient
        epsilon_jb = epsilon.unsqueeze(-2)
        g_ja = g.unsqueeze(-1)
        diff_L_ab = sum_leftmost(g_ja * epsilon_jb, -2)

        # modulate the velocity fields with infinitesimal rotations, i.e. apply the control variate
        gL = torch.matmul(g, L)
        eps_gL_ab = sum_leftmost(gL.unsqueeze(-1) * epsilon.unsqueeze(-2), -2)
        xi_ab = eps_gL_ab - eps_gL_ab.t()
        BC_lab = B.unsqueeze(-1) * C.unsqueeze(-2)
        diff_L_ab += (xi_ab.unsqueeze(0) * BC_lab).sum(0)
        L_grad = torch.tril(diff_L_ab)

        # compute control_var grads
        diff_B = (L_grad.unsqueeze(0) * C.unsqueeze(-2) *
                  xi_ab.unsqueeze(0)).sum(2)
        diff_C = (L_grad.t().unsqueeze(0) * B.unsqueeze(-2) *
                  xi_ab.t().unsqueeze(0)).sum(2)
        diff_CV = torch.stack([diff_B, diff_C])

        return loc_grad, L_grad, diff_CV, None
Example #2
    def backward(ctx, grad_output):
        jitter = 1.0e-8  # do i really need this?
        z, epsilon, L = ctx.saved_tensors

        dim = L.shape[0]
        g = grad_output
        loc_grad = sum_leftmost(grad_output, -1)

        identity = eye_like(g, dim)
        R_inv = torch.triangular_solve(identity, L.t(), transpose=False, upper=True)[0]

        z_ja = z.unsqueeze(-1)
        g_R_inv = torch.matmul(g, R_inv).unsqueeze(-2)
        epsilon_jb = epsilon.unsqueeze(-2)
        g_ja = g.unsqueeze(-1)
        diff_L_ab = 0.5 * sum_leftmost(g_ja * epsilon_jb + g_R_inv * z_ja, -2)

        Sigma_inv = torch.mm(R_inv, R_inv.t())
        V, D, _ = torch.svd(Sigma_inv + jitter)
        D_outer = D.unsqueeze(-1) + D.unsqueeze(0)

        expand_tuple = tuple([-1] * (z.dim() - 1) + [dim, dim])
        z_tilde = identity * torch.matmul(z, V).unsqueeze(-1).expand(*expand_tuple)
        g_tilde = identity * torch.matmul(g, V).unsqueeze(-1).expand(*expand_tuple)

        Y = sum_leftmost(torch.matmul(z_tilde, torch.matmul(1.0 / D_outer, g_tilde)), -2)
        Y = torch.mm(V, torch.mm(Y, V.t()))
        Y = Y + Y.t()

        Tr_xi_Y = torch.mm(torch.mm(Sigma_inv, Y), R_inv) - torch.mm(Y, torch.mm(Sigma_inv, R_inv))
        diff_L_ab += 0.5 * Tr_xi_Y
        L_grad = torch.tril(diff_L_ab)

        return loc_grad, L_grad, None
Example #3
    def backward(ctx, grad_output):

        z, coord_scale, locs, component_logits, pis = ctx.saved_tensors
        K = component_logits.size(-1)
        batch_dims = coord_scale.dim() - 1
        g = grad_output  # l b i

        z_tilde = z / coord_scale  # l b i
        locs_tilde = locs / coord_scale.unsqueeze(-2)  # b j i
        mu_ab = locs_tilde.unsqueeze(-2) - locs_tilde.unsqueeze(-3)  # b k j i
        mu_ab_norm = torch.pow(mu_ab, 2.0).sum(-1).sqrt()  # b k j
        mu_ab /= mu_ab_norm.unsqueeze(-1)  # b k j i
        diagonals = torch.empty((K, ), dtype=torch.long, device=z.device)
        torch.arange(K, out=diagonals)
        mu_ab[..., diagonals, diagonals, :] = 0.0

        mu_ll_ab = (locs_tilde.unsqueeze(-2) * mu_ab).sum(-1)  # b k j
        z_ll_ab = (z_tilde.unsqueeze(-2).unsqueeze(-2) * mu_ab).sum(
            -1)  # l b k j
        z_perp_ab = z_tilde.unsqueeze(-2).unsqueeze(
            -2) - z_ll_ab.unsqueeze(-1) * mu_ab  # l b k j i
        z_perp_ab_sqr = torch.pow(z_perp_ab, 2.0).sum(-1)  # l b k j

        epsilons = z_tilde.unsqueeze(-2) - locs_tilde  # l b j i
        log_qs = -0.5 * torch.pow(epsilons, 2.0)  # l b j i
        log_q_j = log_qs.sum(-1, keepdim=True)  # l b j 1
        log_q_j_max = torch.max(log_q_j, -2, keepdim=True)[0]
        q_j_prime = torch.exp(log_q_j - log_q_j_max)  # l b j 1
        q_j = torch.exp(log_q_j)  # l b j 1

        q_tot = (pis.unsqueeze(-1) * q_j).sum(-2)  # l b 1
        q_tot_prime = (pis.unsqueeze(-1) * q_j_prime).sum(-2).unsqueeze(
            -1)  # l b 1 1

        root_two = math.sqrt(2.0)
        mu_ll_ba = torch.transpose(mu_ll_ab, -1, -2)
        logits_grad = torch.erf((z_ll_ab - mu_ll_ab) / root_two) - torch.erf(
            (z_ll_ab + mu_ll_ba) / root_two)
        logits_grad *= torch.exp(-0.5 * z_perp_ab_sqr)  # l b k j

        #                 bi      lbi                               bkji
        mu_ab_sigma_g = ((coord_scale * g).unsqueeze(-2).unsqueeze(-2) *
                         mu_ab).sum(-1)  # l b k j
        logits_grad *= -mu_ab_sigma_g * pis.unsqueeze(-2)  # l b k j
        logits_grad = pis * sum_leftmost(
            logits_grad.sum(-1) / q_tot, -(1 + batch_dims))  # b k
        logits_grad *= math.sqrt(0.5 * math.pi)

        #           b j                 l b j 1   l b i             l b 1 1
        prefactor = pis.unsqueeze(-1) * q_j_prime * g.unsqueeze(
            -2) / q_tot_prime  # l b j i
        locs_grad = sum_leftmost(prefactor, -(2 + batch_dims))  # b j i
        coord_scale_grad = sum_leftmost(prefactor * epsilons,
                                        -(2 + batch_dims)).sum(-2)  # b i

        return locs_grad, coord_scale_grad, logits_grad, None, None, None
Example #4
    def posterior(self, obs):
        concentration1 = self._latent.concentration1
        concentration0 = self._latent.concentration0
        total_count = self._conditional.total_count
        reduce_dims = len(obs.size()) - len(concentration1.size())
        # Unexpand total_count to have the same shape as concentration0.
        # Raise exception if this isn't possible.
        total_count = sum_leftmost(total_count, reduce_dims)
        summed_obs = sum_leftmost(obs, reduce_dims)
        return dist.Beta(concentration1 + summed_obs,
                         total_count + concentration0 - summed_obs,
                         validate_args=self._latent._validate_args)
Example #5
    def backward(ctx, grad_output):
        z, coord_scale, component_logits, component_scale, pis, coeffs = ctx.saved_tensors
        dim = coord_scale.size(0)
        g = grad_output  # l i
        g = g.unsqueeze(-2)  # l 1 i

        component_scale_sqr = torch.pow(component_scale, 2.0)  # j
        epsilons = z / coord_scale  # l i
        epsilons_sqr = torch.pow(epsilons, 2.0)  # l i
        r_sqr = epsilons_sqr.sum(-1, keepdim=True)  # l
        r_sqr_j = r_sqr / component_scale_sqr  # l j
        coord_scale_product = coord_scale.prod()
        component_scale_power = torch.pow(component_scale, float(dim))

        q_j = torch.exp(-0.5 * r_sqr_j) / math.pow(2.0 * math.pi,
                                                   0.5 * float(dim))  # l j
        q_j /= coord_scale_product * component_scale_power  # l j
        q_tot = (pis * q_j).sum(-1, keepdim=True)  # l

        Phi_j = torch.exp(-0.5 * r_sqr_j)  # l j
        exponents = -torch.arange(1., int(dim / 2) + 1., 1.)
        if z.dim() > 1:
            r_j_poly = r_sqr_j.unsqueeze(-1).expand(-1, -1,
                                                    int(dim / 2))  # l j d/2
        else:
            r_j_poly = r_sqr_j.unsqueeze(-1).expand(-1,
                                                    int(dim / 2))  # l j d/2
        r_j_poly = coeffs * torch.pow(r_j_poly, exponents)
        Phi_j *= r_j_poly.sum(-1)
        if dim % 2 == 1:
            root_two = math.sqrt(2.0)
            extra_term = coeffs[-1] * math.sqrt(0.5 * math.pi) * (
                1.0 - torch.erf(r_sqr_j.sqrt() / root_two))  # l j
            Phi_j += extra_term * torch.pow(r_sqr_j, -0.5 * float(dim))

        logits_grad = (z.unsqueeze(-2) * Phi_j.unsqueeze(-1) * g).sum(
            -1)  # l j
        logits_grad /= q_tot
        logits_grad = sum_leftmost(logits_grad, -1) * math.pow(
            2.0 * math.pi, -0.5 * float(dim))
        logits_grad = pis * logits_grad / (component_scale_power *
                                           coord_scale_product)
        logits_grad = logits_grad - logits_grad.sum() * pis

        prefactor = pis.unsqueeze(-1) * q_j.unsqueeze(
            -1) * g / q_tot.unsqueeze(-1)  # l j i
        coord_scale_grad = sum_leftmost(prefactor * epsilons.unsqueeze(-2), -1)
        component_scale_grad = sum_leftmost(
            (prefactor * z.unsqueeze(-2)).sum(-1) / component_scale, -1)

        return coord_scale_grad, logits_grad, component_scale_grad, None, None, None, None
Example #6
    def posterior(self, obs):
        concentration = self._latent.concentration
        rate = self._latent.rate
        reduce_dims = len(obs.size()) - len(rate.size())
        num_obs = obs.shape[:reduce_dims].numel()
        summed_obs = sum_leftmost(obs, reduce_dims)
        return dist.Gamma(concentration + summed_obs, rate + num_obs)
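This looks like the analogous Gamma-Poisson conjugate update: the summed counts are added to the concentration and the number of observations to the rate. A minimal standalone sketch with illustrative numbers (the values and the import path are assumptions, not from the example):

import torch
import pyro.distributions as dist
from pyro.distributions.util import sum_leftmost  # assumed import path

concentration = torch.tensor(2.0)          # Gamma prior shape (illustrative)
rate = torch.tensor(1.0)                   # Gamma prior rate (illustrative)
obs = torch.tensor([3.0, 0.0, 2.0, 4.0])   # Poisson counts

reduce_dims = obs.dim() - rate.dim()         # one leftmost (sample) dim to sum out
num_obs = obs.shape[:reduce_dims].numel()    # 4 observations
summed_obs = sum_leftmost(obs, reduce_dims)  # 9.0
posterior = dist.Gamma(concentration + summed_obs, rate + num_obs)  # Gamma(11, 5)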
Example #7
def test_sum_leftmost():
    x = torch.ones(2, 3, 4)
    assert sum_leftmost(x, 0).shape == (2, 3, 4)
    assert sum_leftmost(x, 1).shape == (3, 4)
    assert sum_leftmost(x, 2).shape == (4, )
    assert sum_leftmost(x, -1).shape == (4, )
    assert sum_leftmost(x, -2).shape == (3, 4)
    assert sum_leftmost(x, INF).shape == ()  # INF presumably float('inf') defined in the test module
Example #8
def test_sum_leftmost():
    x = torch.ones(2, 3, 4)
    assert sum_leftmost(x, 0).shape == (2, 3, 4)
    assert sum_leftmost(x, 1).shape == (3, 4)
    assert sum_leftmost(x, 2).shape == (4,)
    assert sum_leftmost(x, -1).shape == (4,)
    assert sum_leftmost(x, -2).shape == (3, 4)
    assert sum_leftmost(x, float('inf')).shape == ()
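The two tests above pin down what sum_leftmost does: it sums out the given number of leftmost dimensions, a negative count keeps that many rightmost dimensions, and an infinite count sums everything. A minimal sketch consistent with those tests (not necessarily Pyro's exact implementation):

import torch

def sum_leftmost(value, dim):
    # Sum out the `dim` leftmost dimensions of `value`; a negative `dim` keeps
    # that many rightmost dimensions, and an infinite `dim` sums everything.
    if dim >= value.dim():  # also covers dim == float('inf')
        return value.sum()
    if dim < 0:
        dim = value.dim() + dim
    if dim == 0:
        return value
    return value.reshape(-1, *value.shape[dim:]).sum(0)

x = torch.ones(2, 3, 4)
assert sum_leftmost(x, -2).shape == (3, 4)
assert sum_leftmost(x, float('inf')).shape == ()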
Example #9
    def backward(ctx, grad_output):

        z, scales, locs, logits, pis = ctx.saved_tensors
        dim = scales.size(-1)
        K = logits.size(-1)
        g = grad_output  # l b i
        g = g.unsqueeze(-2)  # l b 1 i
        batch_dims = locs.dim() - 2

        locs_tilde = locs / scales  # b j i
        sigma_0 = torch.min(scales, -2, keepdim=True)[0]  # b 1 i
        z_shift = (z.unsqueeze(-2) - locs) / sigma_0  # l b j i
        z_tilde = z.unsqueeze(-2) / scales - locs_tilde  # l b j i

        mu_cd = locs.unsqueeze(-2) - locs.unsqueeze(-3)  # b c d i
        mu_cd_norm = torch.pow(mu_cd, 2.0).sum(-1).sqrt()  # b c d
        mu_cd /= mu_cd_norm.unsqueeze(-1)  # b c d i
        diagonals = torch.empty((K, ), dtype=torch.long, device=z.device)
        torch.arange(K, out=diagonals)
        mu_cd[..., diagonals, diagonals, :] = 0.0

        mu_ll_cd = (locs.unsqueeze(-2) * mu_cd).sum(-1)  # b c d
        z_ll_cd = (z.unsqueeze(-2).unsqueeze(-2) * mu_cd).sum(-1)  # l b c d
        z_perp_cd = (z.unsqueeze(-2).unsqueeze(-2) -
                     z_ll_cd.unsqueeze(-1) * mu_cd)  # l b c d i
        z_perp_cd_sqr = torch.pow(z_perp_cd, 2.0).sum(-1)  # l b c d

        shift_indices = torch.empty((dim, ), dtype=torch.long, device=z.device)
        torch.arange(dim, out=shift_indices)
        shift_indices = shift_indices - 1
        shift_indices[0] = 0

        z_shift_cumsum = torch.pow(z_shift, 2.0)
        z_shift_cumsum = z_shift_cumsum.sum(-1, keepdim=True) - torch.cumsum(
            z_shift_cumsum, dim=-1)  # l b j i
        z_tilde_cumsum = torch.cumsum(torch.pow(z_tilde, 2.0),
                                      dim=-1)  # l b j i
        z_tilde_cumsum = torch.index_select(z_tilde_cumsum, -1, shift_indices)
        z_tilde_cumsum[..., 0] = 0.0
        r_sqr_ji = z_shift_cumsum + z_tilde_cumsum  # l b j i

        log_scales = torch.log(scales)  # b j i
        epsilons_sqr = torch.pow(z_tilde, 2.0)  # l b j i
        log_qs = (-0.5 * epsilons_sqr - 0.5 * math.log(2.0 * math.pi) -
                  log_scales)  # l b j i
        log_q_j = log_qs.sum(-1, keepdim=True)  # l b j 1
        q_j = torch.exp(log_q_j)  # l b j 1
        q_tot = (pis * q_j.squeeze(-1)).sum(-1)  # l b
        q_tot = q_tot.unsqueeze(-1)  # l b 1

        root_two = math.sqrt(2.0)
        shift_log_scales = log_scales[..., shift_indices]
        shift_log_scales[..., 0] = 0.0
        sigma_products = torch.cumsum(shift_log_scales, dim=-1).exp()  # b j i

        reverse_indices = torch.tensor(range(dim - 1, -1, -1),
                                       dtype=torch.long,
                                       device=z.device)
        reverse_log_sigma_0 = sigma_0.log()[..., reverse_indices]  # b 1 i
        sigma_0_products = torch.cumsum(reverse_log_sigma_0,
                                        dim=-1).exp()[..., reverse_indices -
                                                      1]  # b 1 i
        sigma_0_products[..., -1] = 1.0
        sigma_products *= sigma_0_products

        logits_grad = torch.erf(z_tilde / root_two) - torch.erf(
            z_shift / root_two)  # l b j i
        logits_grad *= torch.exp(-0.5 * r_sqr_ji)  # l b j i
        logits_grad = (logits_grad * g / sigma_products).sum(-1)  # l b j
        logits_grad = sum_leftmost(logits_grad / q_tot, -1 - batch_dims)  # b j
        logits_grad *= 0.5 * math.pow(2.0 * math.pi, -0.5 * (dim - 1))
        logits_grad = -pis * logits_grad
        logits_grad = logits_grad - logits_grad.sum(-1, keepdim=True) * pis

        mu_ll_dc = torch.transpose(mu_ll_cd, -1, -2)
        v_cd = torch.erf((z_ll_cd - mu_ll_cd) / root_two) - torch.erf(
            (z_ll_cd + mu_ll_dc) / root_two)
        v_cd *= torch.exp(-0.5 * z_perp_cd_sqr)  # l b c d
        mu_cd_g = (g.unsqueeze(-2) * mu_cd).sum(-1)  # l b c d
        v_cd *= (-mu_cd_g * pis.unsqueeze(-2) * 0.5 *
                 math.pow(2.0 * math.pi, -0.5 * (dim - 1)))  # l b c d
        v_cd = pis * sum_leftmost(v_cd.sum(-1) / q_tot, -1 - batch_dims)
        logits_grad += v_cd

        prefactor = pis.unsqueeze(-1) * q_j * g / q_tot.unsqueeze(-1)
        locs_grad = sum_leftmost(prefactor, -2 - batch_dims)
        scales_grad = sum_leftmost(prefactor * z_tilde, -2 - batch_dims)

        return locs_grad, scales_grad, logits_grad, None, None, None