def estimate_mi_tc_dwkld(z, z_mu, z_log_sigma_sq, N=2e5):
    """Jointly estimate the mutual information (mi), total correlation (tc) and
    dimension-wise KLD (dwkld); it is computationally cheaper to compute the
    three terms together than separately. N is the dataset size used by the
    minibatch-based estimator of the aggregate posterior."""
    z_sigma = tf.sqrt(tf.exp(z_log_sigma_sq))
    batch_size, z_dim = int_shape(z_mu)

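    # Pairwise standardization: z_norm[i, j, :] holds sample z[j] standardized
    # under the posterior q(z | x_i), shape (batch_size, batch_size, z_dim).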
    z_b = tf.stack([z for i in range(batch_size)], axis=0)
    z_mu_b = tf.stack([z_mu for i in range(batch_size)], axis=1)
    z_sigma_b = tf.stack([z_sigma for i in range(batch_size)], axis=1)
    z_norm = (z_b - z_mu_b) / z_sigma_b

    dist = tf.distributions.Normal(loc=0., scale=1.)
    # log q(z[j] | x_i) per dimension: standard-normal log-density of the
    # standardized value plus the change-of-variables term -log(sigma_i)
    log_probs = dist.log_prob(z_norm) - tf.log(z_sigma_b)
    # Importance weights: the sample's own posterior (the diagonal) gets
    # log-weight 0, every other posterior gets log((N - 1) / (batch_size - 1));
    # the common log(N) normalizer is subtracted further below.
    ratio = np.log(float(N - 1) / (batch_size - 1)) * np.ones(
        (batch_size, batch_size))
    np.fill_diagonal(ratio, 0.)
    ratio_b = np.stack([ratio for i in range(z_dim)], axis=-1)

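    # lse_sum estimates E_q(z)[log q(z)] and sum_lse estimates
    # sum_j E_q(z_j)[log q(z_j)], up to the log(N) correction applied just below.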
    lse_sum = tf.reduce_mean(
        log_sum_exp(tf.reduce_sum(log_probs, axis=-1) + ratio, axis=0))
    sum_lse = tf.reduce_mean(
        tf.reduce_sum(log_sum_exp(log_probs + ratio_b, axis=0), axis=-1))
    lse_sum -= tf.log(float(N))
    sum_lse -= tf.log(float(N)) * float(z_dim)

    kld = compute_gaussian_kld(z_mu, z_log_sigma_sq)
    cond_entropy = tf.reduce_mean(
        compute_gaussian_entropy(z_mu, z_log_sigma_sq))

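    # KLD decomposition:
    #   mi    = E_q[log q(z|x)] - E_q[log q(z)]             (index-code MI)
    #   tc    = E_q[log q(z)] - E_q[log prod_j q(z_j)]      (total correlation)
    #   dwkld = E_q[log prod_j q(z_j)] - E_q[log p(z)]      (dimension-wise KLD)
    # with E_q[log q(z|x)] = -cond_entropy and E_q[log p(z)] = -(kld + cond_entropy).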
    mi = -lse_sum - cond_entropy
    tc = lse_sum - sum_lse
    dwkld = sum_lse + (kld + cond_entropy)

    return mi, tc, dwkld
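

# Usage sketch (illustrative, not part of the original training code): the three
# estimated terms can be recombined into a beta-TCVAE-style penalty; alpha, beta
# and gamma below are placeholder hyper-parameters.
def example_decomposed_kld_penalty(z, z_mu, z_log_sigma_sq,
                                   alpha=1., beta=6., gamma=1., N=2e5):
    mi, tc, dwkld = estimate_mi_tc_dwkld(z, z_mu, z_log_sigma_sq, N=N)
    return alpha * mi + beta * tc + gamma * dwkld
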
def estimate_tc(z, z_mu, z_log_sigma_sq, N=200000):
    """Estimate only the total correlation term; the estimator mirrors
    estimate_mi_tc_dwkld above (see the comments there). N is the dataset size."""
    z_sigma = tf.sqrt(tf.exp(z_log_sigma_sq))
    batch_size, z_dim = int_shape(z_mu)

    z_b = tf.stack([z for i in range(batch_size)], axis=0)
    z_mu_b = tf.stack([z_mu for i in range(batch_size)], axis=1)
    z_sigma_b = tf.stack([z_sigma for i in range(batch_size)], axis=1)
    z_norm = (z_b - z_mu_b) / z_sigma_b

    dist = tf.distributions.Normal(loc=0., scale=1.)
    # include the -log(sigma) change-of-variables term, as in estimate_mi_tc_dwkld
    log_probs = dist.log_prob(z_norm) - tf.log(z_sigma_b)
    ratio = np.log(float(N - 1) / (batch_size - 1)) * np.ones(
        (batch_size, batch_size))
    np.fill_diagonal(ratio, 0.)
    ratio_b = np.stack([ratio for i in range(z_dim)], axis=-1)

    lse_sum = tf.reduce_mean(
        log_sum_exp(tf.reduce_sum(log_probs, axis=-1) + ratio, axis=0))
    sum_lse = tf.reduce_mean(
        tf.reduce_sum(log_sum_exp(log_probs + ratio_b, axis=0), axis=-1))
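    # equivalent to (lse_sum - log N) - (sum_lse - z_dim * log N) in the joint estimator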
    return lse_sum - sum_lse + tf.log(float(N)) * (float(z_dim) - 1)
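

# Sketch (assumption, not from the original code): a total-correlation penalty can
# be added to an existing objective, e.g.
#   tc = estimate_tc(z, z_mu, z_log_sigma_sq, N=dataset_size)
#   loss = recon_loss + kld + (beta - 1.) * tc
# where recon_loss, kld, beta and dataset_size are placeholders.
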
def discretized_mix_logistic_loss(x, l, sum_all=True, masks=None):
    """ log-likelihood for mixture of discretized logistics, assumes the data has been rescaled to [-1,1] interval """
    xs = int_shape(x) # true image (i.e. labels) to regress to, e.g. (B,32,32,3)
    ls = int_shape(l) # predicted distribution, e.g. (B,32,32,100)
    nr_mix = int(ls[-1] / 10) # here and below: unpacking the params of the mixture of logistics
    logit_probs = l[:,:,:,:nr_mix]
    l = tf.reshape(l[:,:,:,nr_mix:], xs + [nr_mix*3])
    means = l[:,:,:,:,:nr_mix]
    log_scales = tf.maximum(l[:,:,:,:,nr_mix:2*nr_mix], -7.)
    coeffs = tf.nn.tanh(l[:,:,:,:,2*nr_mix:3*nr_mix])

    x = tf.reshape(x, xs + [1]) + tf.zeros(xs + [nr_mix]) # here and below: getting the means and adjusting them based on preceding sub-pixels
    m2 = tf.reshape(means[:,:,:,1,:] + coeffs[:, :, :, 0, :] * x[:, :, :, 0, :], [xs[0],xs[1],xs[2],1,nr_mix])
    m3 = tf.reshape(means[:, :, :, 2, :] + coeffs[:, :, :, 1, :] * x[:, :, :, 0, :] + coeffs[:, :, :, 2, :] * x[:, :, :, 1, :], [xs[0],xs[1],xs[2],1,nr_mix])
    means = tf.concat([tf.reshape(means[:,:,:,0,:], [xs[0],xs[1],xs[2],1,nr_mix]), m2, m3],3)
    centered_x = x - means
    inv_stdv = tf.exp(-log_scales)

    plus_in = inv_stdv * (centered_x + 1./255.)
    cdf_plus = tf.nn.sigmoid(plus_in)
    min_in = inv_stdv * (centered_x - 1./255.)
    cdf_min = tf.nn.sigmoid(min_in)
    log_cdf_plus = plus_in - tf.nn.softplus(plus_in) # log probability for edge case of 0 (before scaling)
    log_one_minus_cdf_min = -tf.nn.softplus(min_in) # log probability for edge case of 255 (before scaling)
    cdf_delta = cdf_plus - cdf_min # probability for all other cases
    mid_in = inv_stdv * centered_x
    log_pdf_mid = mid_in - log_scales - 2.*tf.nn.softplus(mid_in) # log probability in the center of the bin, to be used in extreme cases (not actually used in our code)

    # now select the right output: left edge case, right edge case, normal case, extremely low prob case (doesn't actually happen for us)

    # this is what we are really doing, but we use the robust version below to handle extreme cases in other applications and to avoid the NaN issue with tf.where()
    # log_probs = tf.where(x < -0.999, log_cdf_plus, tf.where(x > 0.999, log_one_minus_cdf_min, tf.log(cdf_delta)))

    # robust version that still works if probabilities are below 1e-5 (which never happens in our code)
    # tensorflow backpropagates through tf.where() by multiplying with zero instead of selecting: this requires us to use some ugly tricks to avoid potential NaNs
    # the 1e-12 in tf.maximum(cdf_delta, 1e-12) is never actually used as output, it's purely there to get around the tf.where() gradient issue
    # if the probability on a sub-pixel is below 1e-5, we use an approximation based on the assumption that the log-density is constant in the bin of the observed sub-pixel value
    log_probs = tf.where(x < -0.999, log_cdf_plus, tf.where(x > 0.999, log_one_minus_cdf_min, tf.where(cdf_delta > 1e-5, tf.log(tf.maximum(cdf_delta, 1e-12)), log_pdf_mid - np.log(127.5))))

    log_probs = tf.reduce_sum(log_probs,3) + log_prob_from_logits(logit_probs)

    lse = log_sum_exp(log_probs)

    if masks is not None:
        assert lse.shape == masks.shape, "shape of masks does not match the log_sum_exp outputs"
        # zero out the contribution of positions where masks == 1
        lse *= (1 - masks)
    if sum_all:
        return -tf.reduce_sum(lse)
    else:
        return -tf.reduce_sum(lse,[1,2])
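

# Usage sketch (illustrative only; x_target, model_out and dataset-specific shapes
# are placeholders): with sum_all=False the function returns one negative
# log-likelihood per image (in nats), which converts to bits per dimension as
#   nll_per_image = discretized_mix_logistic_loss(x_target, model_out, sum_all=False)
#   bits_per_dim = tf.reduce_mean(nll_per_image) / (np.log(2.) * np.prod(int_shape(x_target)[1:]))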