def normal_diag_KLdiv_ad(head, mean, scale):
    """
    An example of differentiating normal_diag.KLdiv in all inputs and paramters

    Args:
        head: The adjoint of the output, in other words, some tensors, by which the Jacobians
            will be multiplied
        x: input
        mean: vector of means of MVN
        scale: vector of sigma of MVN with diagonal covariance

    """
    mod = normal_diag.normal_diag(mean, scale).KL_divergence()
    auto_diff_outs = list(akg.differentiate(mod, [mean, scale], head))
    return auto_diff_outs
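A minimal usage sketch (not part of the original source): mean, scale, and head are assumed to be 1-D akg.tvm placeholders, and head is assumed to share the parameter shape; in practice it must match the shape of the KL_divergence output.

import akg

shape, dtype = (16,), "float32"
mean = akg.tvm.placeholder(shape, name="mean", dtype=dtype)
scale = akg.tvm.placeholder(shape, name="scale", dtype=dtype)
# Assumption: head matches the parameter shape; adjust if KL_divergence
# reduces to a scalar.
head = akg.tvm.placeholder(shape, name="head", dtype=dtype)

grads = normal_diag_KLdiv_ad(head, mean, scale)  # [dKL/dmean, dKL/dscale]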
Example #2
def sample_op(mean, scale, eps):
    # Reparameterized sample from the diagonal MVN, driven by noise eps.
    return normal_diag.normal_diag(mean, scale).sample(eps)

def logprob_op(x, mean, scale):
    # Log-probability of x under the diagonal MVN.
    return normal_diag.normal_diag(mean, scale).log_prob(x)

def KLdiv_op(mean, scale):
    # KL divergence of the diagonal MVN.
    return normal_diag.normal_diag(mean, scale).KL_divergence()
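The autodiff pattern from the first example applies to these ops as well; the sketch below is an illustration (assuming akg is imported as above) that differentiates log_prob with respect to x, mean, and scale, with head again being the adjoint of the output.

def normal_diag_logprob_ad(head, x, mean, scale):
    # Adjoints of log_prob with respect to the sample x and both parameters.
    prob = logprob_op(x, mean, scale)
    return list(akg.differentiate(prob, [x, mean, scale], head))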