import numpy as np
import tensorflow as tf


def estimate_mi_tc_dwkld(z, z_mu, z_log_sigma_sq, N=200000):
    # computationally cheaper to compute MI, TC, and DWKLD together: all
    # three reuse the pairwise log-density matrix built below
    # (minibatch-weighted sampling; N is the dataset size)
    z_sigma = tf.sqrt(tf.exp(z_log_sigma_sq))
    batch_size, z_dim = int_shape(z_mu)
    # entry (j, i, k) holds dimension k of sample z_i under posterior q(z|x_j)
    z_b = tf.stack([z for i in range(batch_size)], axis=0)
    z_mu_b = tf.stack([z_mu for i in range(batch_size)], axis=1)
    z_sigma_b = tf.stack([z_sigma for i in range(batch_size)], axis=1)
    z_norm = (z_b - z_mu_b) / z_sigma_b
    dist = tf.distributions.Normal(loc=0., scale=1.)
    # log q(z_i | x_j): standard-normal log-density of the whitened sample
    # minus the log sigma_j change-of-variables (Jacobian) term
    log_probs = dist.log_prob(z_norm) - tf.log(z_sigma_b)
    # off-diagonal entries are importance-weighted by (N-1)/(batch_size-1);
    # the diagonal (a sample under its own posterior) keeps weight 1
    ratio = np.log(float(N - 1) / (batch_size - 1)) * np.ones(
        (batch_size, batch_size))
    np.fill_diagonal(ratio, 0.)
    ratio_b = np.stack([ratio for i in range(z_dim)], axis=-1)
    # E[log q(z)]: sum over dimensions, then log-sum-exp over the batch
    lse_sum = tf.reduce_mean(
        log_sum_exp(tf.reduce_sum(log_probs, axis=-1) + ratio, axis=0))
    # E[sum_k log q(z_k)]: log-sum-exp over the batch per dimension, then sum
    sum_lse = tf.reduce_mean(
        tf.reduce_sum(log_sum_exp(log_probs + ratio_b, axis=0), axis=-1))
    lse_sum -= tf.log(float(N))
    sum_lse -= tf.log(float(N)) * float(z_dim)
    kld = compute_gaussian_kld(z_mu, z_log_sigma_sq)
    cond_entropy = tf.reduce_mean(
        compute_gaussian_entropy(z_mu, z_log_sigma_sq))
    mi = -lse_sum - cond_entropy    # I(x; z) = E[log q(z|x)] - E[log q(z)]
    tc = lse_sum - sum_lse          # TC = E[log q(z)] - E[sum_k log q(z_k)]
    dwkld = sum_lse + (kld + cond_entropy)  # so that KL = MI + TC + DWKLD
    return mi, tc, dwkld
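
# The two Gaussian helpers called above are presumably defined elsewhere in
# this repo; the following is a minimal sketch, assuming a diagonal-Gaussian
# posterior q(z|x) = N(z_mu, exp(z_log_sigma_sq)) and a standard-normal prior
# p(z) = N(0, I), consistent with how the estimators use them.
def compute_gaussian_kld(z_mu, z_log_sigma_sq):
    # batch-mean of the closed-form KL(q(z|x) || N(0, I)) for diagonal Gaussians
    return tf.reduce_mean(0.5 * tf.reduce_sum(
        tf.square(z_mu) + tf.exp(z_log_sigma_sq) - z_log_sigma_sq - 1.,
        axis=-1))


def compute_gaussian_entropy(z_mu, z_log_sigma_sq):
    # per-example differential entropy of a diagonal Gaussian:
    # 0.5 * sum_k (log(2*pi*e) + log sigma_k^2)
    # (z_mu does not affect the entropy; kept for a uniform signature)
    return 0.5 * tf.reduce_sum(
        np.log(2. * np.pi * np.e) + z_log_sigma_sq, axis=-1)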
def estimate_tc(z, z_mu, z_log_sigma_sq, N=200000):
    # total-correlation-only variant of the estimator above, using the same
    # minibatch-weighted sampling scheme
    z_sigma = tf.sqrt(tf.exp(z_log_sigma_sq))
    batch_size, z_dim = int_shape(z_mu)
    z_b = tf.stack([z for i in range(batch_size)], axis=0)
    z_mu_b = tf.stack([z_mu for i in range(batch_size)], axis=1)
    z_sigma_b = tf.stack([z_sigma for i in range(batch_size)], axis=1)
    z_norm = (z_b - z_mu_b) / z_sigma_b
    dist = tf.distributions.Normal(loc=0., scale=1.)
    # log q(z_i | x_j), including the -log sigma_j change-of-variables term
    log_probs = dist.log_prob(z_norm) - tf.log(z_sigma_b)
    ratio = np.log(float(N - 1) / (batch_size - 1)) * np.ones(
        (batch_size, batch_size))
    np.fill_diagonal(ratio, 0.)
    ratio_b = np.stack([ratio for i in range(z_dim)], axis=-1)
    lse_sum = tf.reduce_mean(
        log_sum_exp(tf.reduce_sum(log_probs, axis=-1) + ratio, axis=0))
    sum_lse = tf.reduce_mean(
        tf.reduce_sum(log_sum_exp(log_probs + ratio_b, axis=0), axis=-1))
    # TC = (lse_sum - log N) - (sum_lse - z_dim * log N)
    return lse_sum - sum_lse + tf.log(float(N)) * (float(z_dim) - 1)
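
# int_shape and log_sum_exp are shared utilities (in PixelCNN++-style
# codebases they live in a nn.py module); minimal sketches consistent with
# their usage above and below:
def int_shape(x):
    # static tensor shape as a list of Python ints
    return list(map(int, x.get_shape()))


def log_sum_exp(x, axis=-1):
    # numerically stable log-sum-exp that reduces the given axis
    m = tf.reduce_max(x, axis)
    m2 = tf.reduce_max(x, axis, keepdims=True)
    return m + tf.log(tf.reduce_sum(tf.exp(x - m2), axis))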
def discretized_mix_logistic_loss(x, l, sum_all=True, masks=None):
    """ log-likelihood for mixture of discretized logistics, assumes the data has been rescaled to [-1,1] interval """
    xs = int_shape(x)  # true image (i.e. labels) to regress to, e.g. (B,32,32,3)
    ls = int_shape(l)  # predicted distribution, e.g. (B,32,32,100)
    nr_mix = int(ls[-1] / 10)
    # here and below: unpacking the params of the mixture of logistics
    logit_probs = l[:, :, :, :nr_mix]
    l = tf.reshape(l[:, :, :, nr_mix:], xs + [nr_mix * 3])
    means = l[:, :, :, :, :nr_mix]
    log_scales = tf.maximum(l[:, :, :, :, nr_mix:2 * nr_mix], -7.)
    coeffs = tf.nn.tanh(l[:, :, :, :, 2 * nr_mix:3 * nr_mix])
    x = tf.reshape(x, xs + [1]) + tf.zeros(xs + [nr_mix])
    # here and below: getting the means and adjusting them based on preceding sub-pixels
    m2 = tf.reshape(means[:, :, :, 1, :] + coeffs[:, :, :, 0, :] * x[:, :, :, 0, :],
                    [xs[0], xs[1], xs[2], 1, nr_mix])
    m3 = tf.reshape(means[:, :, :, 2, :] + coeffs[:, :, :, 1, :] * x[:, :, :, 0, :] +
                    coeffs[:, :, :, 2, :] * x[:, :, :, 1, :],
                    [xs[0], xs[1], xs[2], 1, nr_mix])
    means = tf.concat([tf.reshape(means[:, :, :, 0, :],
                                  [xs[0], xs[1], xs[2], 1, nr_mix]), m2, m3], 3)
    centered_x = x - means
    inv_stdv = tf.exp(-log_scales)
    plus_in = inv_stdv * (centered_x + 1. / 255.)
    cdf_plus = tf.nn.sigmoid(plus_in)
    min_in = inv_stdv * (centered_x - 1. / 255.)
    cdf_min = tf.nn.sigmoid(min_in)
    log_cdf_plus = plus_in - tf.nn.softplus(plus_in)  # log probability for edge case of 0 (before scaling)
    log_one_minus_cdf_min = -tf.nn.softplus(min_in)  # log probability for edge case of 255 (before scaling)
    cdf_delta = cdf_plus - cdf_min  # probability for all other cases
    mid_in = inv_stdv * centered_x
    # log probability in the center of the bin, to be used in extreme cases (not actually used in our code)
    log_pdf_mid = mid_in - log_scales - 2. * tf.nn.softplus(mid_in)

    # now select the right output: left edge case, right edge case, normal case,
    # extremely low prob case (doesn't actually happen for us)

    # this is what we are really doing, but using the robust version below for
    # extreme cases in other applications and to avoid NaN issue with tf.select()
    # log_probs = tf.select(x < -0.999, log_cdf_plus, tf.select(x > 0.999, log_one_minus_cdf_min, tf.log(cdf_delta)))

    # robust version, that still works if probabilities are below 1e-5 (which never happens in our code)
    # tensorflow backpropagates through tf.select() by multiplying with zero instead of
    # selecting: this requires us to use some ugly tricks to avoid potential NaNs
    # the 1e-12 in tf.maximum(cdf_delta, 1e-12) is never actually used as output,
    # it's purely there to get around the tf.select() gradient issue
    # if the probability on a sub-pixel is below 1e-5, we use an approximation based on the
    # assumption that the log-density is constant in the bin of the observed sub-pixel value
    log_probs = tf.where(x < -0.999, log_cdf_plus,
                         tf.where(x > 0.999, log_one_minus_cdf_min,
                                  tf.where(cdf_delta > 1e-5,
                                           tf.log(tf.maximum(cdf_delta, 1e-12)),
                                           log_pdf_mid - np.log(127.5))))

    log_probs = tf.reduce_sum(log_probs, 3) + log_prob_from_logits(logit_probs)
    lse = log_sum_exp(log_probs)
    if masks is not None:
        assert lse.shape == masks.shape, "shape of masks does not match the log_sum_exp outputs"
        lse *= (1 - masks)  # masked positions do not contribute to the loss
    if sum_all:
        return -tf.reduce_sum(lse)
    else:
        return -tf.reduce_sum(lse, [1, 2])
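
# log_prob_from_logits is the remaining assumed utility: a numerically stable
# log-softmax over the last (mixture) axis, sketched here:
def log_prob_from_logits(x):
    # subtract the max before exponentiating to prevent overflow
    axis = len(x.get_shape()) - 1
    m = tf.reduce_max(x, axis, keepdims=True)
    return x - m - tf.log(tf.reduce_sum(tf.exp(x - m), axis, keepdims=True))


# Hypothetical usage sketch (TF 1.x graph mode): images rescaled to [-1, 1],
# with a stand-in single conv layer in place of the real network, producing
# the required 10 * nr_mix output channels per pixel (here nr_mix = 10).
if __name__ == '__main__':
    x_in = tf.placeholder(tf.float32, shape=[16, 32, 32, 3])
    l_out = tf.layers.conv2d(x_in, filters=100, kernel_size=3, padding='same')
    loss = discretized_mix_logistic_loss(x_in, l_out, sum_all=True)
    train_op = tf.train.AdamOptimizer(1e-4).minimize(loss)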