def disentangled_inferred_prior_loss(qZ_X: Distribution,
                                     only_mean: bool = False,
                                     lambda_offdiag: float = 2.,
                                     lambda_diag: float = 1.) -> Tensor:
    r"""Disentangled inferred prior (DIP) matches the covariance of the
    inferred prior (the aggregated posterior) to that of the prior.

    Uses `cov(z_mean) = E[z_mean*z_mean^T] - E[z_mean]E[z_mean]^T`.

    Arguments:
        qZ_X : `tensorflow_probability.Distribution`
        only_mean : A Boolean. If `True`, apply the DIP constraint only on the
            mean of the latents `Cov[E(z)]` (i.e. type 'i'); otherwise, use
            `E[Cov(z)] + Cov[E(z)]` (i.e. type 'ii').
        lambda_offdiag : A Scalar. Weight for penalizing the off-diagonal part
            of the covariance matrix.
        lambda_diag : A Scalar. Weight for penalizing the diagonal.

    Reference:
        Kumar, A., Sattigeri, P., Balakrishnan, A., 2018. Variational Inference
            of Disentangled Latent Concepts from Unlabeled Observations.
            arXiv:1711.00848 [cs, stat].
        Github code https://github.com/IBM/AIX360
        Github code https://github.com/google-research/disentanglement_lib
    """
    z_mean = qZ_X.mean()
    shape = z_mean.shape
    if len(shape) > 2:
        # collapse sample dimensions: [sample_shape * batch_size, zdim]
        z_mean = tf.reshape(
            z_mean,
            (tf.cast(tf.reduce_prod(shape[:-1]), tf.int32),) + shape[-1:])
    expectation_z_mean_z_mean_t = tf.reduce_mean(
        tf.expand_dims(z_mean, 2) * tf.expand_dims(z_mean, 1), axis=0)
    expectation_z_mean = tf.reduce_mean(z_mean, axis=0)
    # cov_zmean [zdim, zdim]
    cov_zmean = tf.subtract(
        expectation_z_mean_z_mean_t,
        tf.expand_dims(expectation_z_mean, 1) *
        tf.expand_dims(expectation_z_mean, 0))
    # Eq(5)
    if only_mean:
        z_cov = cov_zmean
    else:
        z_var = qZ_X.variance()
        if len(shape) > 2:
            z_var = tf.reshape(
                z_var,
                (tf.cast(tf.reduce_prod(shape[:-1]), tf.int32),) + shape[-1:])
        # mean_zcov [zdim, zdim]
        mean_zcov = tf.reduce_mean(tf.linalg.diag(z_var), axis=0)
        z_cov = cov_zmean + mean_zcov
    # Eq(6) and Eq(7)
    # z_cov [zdim, zdim]
    # z_cov_diag [zdim]
    # z_cov_offdiag [zdim, zdim]
    z_cov_diag = tf.linalg.diag_part(z_cov)
    z_cov_offdiag = z_cov - tf.linalg.diag(z_cov_diag)
    return lambda_offdiag * tf.reduce_sum(z_cov_offdiag ** 2) + \
        lambda_diag * tf.reduce_sum((z_cov_diag - 1.) ** 2)

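
# Usage sketch (not part of the original module): computing the DIP penalty for
# a batch of factorized Gaussian posteriors; batch size and latent dimension
# below are illustrative.
import tensorflow as tf
import tensorflow_probability as tfp

qZ_X = tfp.distributions.Normal(loc=tf.random.normal([64, 10]),
                                scale=tf.nn.softplus(tf.random.normal([64, 10])))
dip_penalty = disentangled_inferred_prior_loss(qZ_X, only_mean=False,
                                               lambda_offdiag=2., lambda_diag=1.)
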
def tfd_analytic_sample(n: int, dist: tfd.Distribution, limits: ztyping.ObsTypeInput):
    """Sample analytically with a `tfd.Distribution` within the limits. No preprocessing.

    Args:
        n: Number of samples to get.
        dist: Distribution to sample from.
        limits: Limits to sample from within.

    Returns:
        The sampled data with the number of samples and the number of observables.
    """
    lower_bound, upper_bound = limits.rect_limits
    lower_prob_lim = dist.cdf(lower_bound)
    upper_prob_lim = dist.cdf(upper_bound)

    shape = (n, 1)
    prob_sample = z.random.uniform(shape=shape,
                                   minval=lower_prob_lim,
                                   maxval=upper_prob_lim)
    prob_sample.set_shape((None, 1))
    try:
        sample = dist.quantile(prob_sample)
    except NotImplementedError:
        raise AnalyticSamplingNotImplementedError
    sample.set_shape((None, limits.n_obs))
    return sample

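
# Standalone sketch of the inverse-CDF trick used above, without zfit's
# Space/limits machinery (the bounds and sample count here are illustrative):
# draw uniforms between cdf(lower) and cdf(upper), then push them through the
# quantile function so every sample lands inside the limits.
import tensorflow as tf
import tensorflow_probability as tfp

dist = tfp.distributions.Normal(loc=0., scale=1.)
lower, upper = -1.0, 2.0
u = tf.random.uniform([1000, 1], minval=dist.cdf(lower), maxval=dist.cdf(upper))
truncated_samples = dist.quantile(u)  # all values lie in [lower, upper]
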
def _mi_loss(
    self,
    Q: Sequence[Distribution],
    py_z: Distribution,
    training: Optional[bool] = None,
    which_latents_sampling: Optional[List[int]] = None,
) -> Tuple[tf.Tensor, List[tf.Tensor]]:
    ## sample the prior
    batch_shape = Q[0].batch_shape_tensor()
    if which_latents_sampling is None:
        which_latents_sampling = list(range(len(Q)))
    z_prime = [
        q.KL_divergence.prior.sample(batch_shape)
        if i in which_latents_sampling else
        tf.stop_gradient(tf.convert_to_tensor(q))
        for i, q in enumerate(Q)
    ]
    if len(z_prime) == 1:
        z_prime = z_prime[0]
    ## decoding
    px = self.decode(z_prime, training=training)[0]
    if px.reparameterization_type == NOT_REPARAMETERIZED:
        x = px.mean()
    else:
        x = tf.convert_to_tensor(px)
    # do not stop the gradient here: the generator needs to be updated
    # x = tf.stop_gradient(x)
    Q_prime = self.encode(x, training=training)
    qy_z = self.predict_factors(latents=Q_prime, training=training)
    ## y ~ p(y|z). Stopping the gradient here is important: it prevents the
    # encoder from being updated twice, which significantly improves stability;
    # otherwise the encoder and latents often get NaN gradients.
    if self.reverse_mi:
        # D_kl(p(y|z)||q(y|z))
        y_samples = tf.stop_gradient(py_z.sample())
        Dkl = py_z.log_prob(y_samples) - qy_z.log_prob(y_samples)
    else:
        # D_kl(q(y|z)||p(y|z))
        y_samples = tf.stop_gradient(qy_z.sample())
        Dkl = qy_z.log_prob(y_samples) - py_z.log_prob(y_samples)
    ## only calculate the MI for unsupervised data
    mi_y = tf.reduce_mean(Dkl)
    ## mutual information (we want to maximize this, hence add it to the llk)
    if training:
        mi_y = tf.cond(
            self.step >= self.steps_without_mi,
            true_fn=lambda: mi_y,
            false_fn=lambda: tf.stop_gradient(mi_y),
        )
    else:
        mi_y = tf.stop_gradient(mi_y)
    mi_y = self.mi_coef * mi_y
    ## this value is just for monitoring
    mi_z = []
    for q, z in zip(as_tuple(Q_prime), as_tuple(z_prime)):
        mi = tf.reduce_mean(tf.stop_gradient(q.log_prob(z)))
        mi = tf.cond(tf.math.is_nan(mi),
                     true_fn=lambda: 0.,
                     false_fn=lambda: tf.clip_by_value(mi, -1e8, 1e8))
        mi_z.append(mi)
    return mi_y, mi_z

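
# Minimal standalone sketch (hypothetical, not from the original class) of the
# single-sample KL estimate used for the mutual-information term above, assuming
# p(y|z) and q(y|z) are OneHotCategorical factor predictors:
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

py_z = tfd.OneHotCategorical(logits=tf.random.normal([32, 10]), dtype=tf.float32)
qy_z = tfd.OneHotCategorical(logits=tf.random.normal([32, 10]), dtype=tf.float32)
y = tf.stop_gradient(qy_z.sample())        # y ~ q(y|z), gradients blocked
dkl = qy_z.log_prob(y) - py_z.log_prob(y)  # one-sample estimate of D_kl(q(y|z)||p(y|z))
mi_y_estimate = tf.reduce_mean(dkl)
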
def total_correlation(z_samples: Tensor, qZ_X: Distribution) -> Tensor:
    r"""Estimate of total correlation using a Gaussian distribution on a batch.

    We need to compute the expectation over a batch of:
    `E_j [log(q(z(x_j))) - log(prod_l q(z(x_j)_l))]`

    We ignore the constants as they do not matter for the minimization.
    The constant should be equal to
    `(num_latents - 1) * log(batch_size * dataset_size)`.

    If `alpha = gamma = 1`, Eq(4) can be written as `ELBO + (1 - beta) * TC`
    (i.e. `(1. - beta) * total_correlation(z_sampled, qZ_X)`).

    Parameters
    ----------
    z_samples : Tensor
        shape `[batch_size, num_latents]` - tensor with sampled representation.
    qZ_X : Distribution
        the posterior distribution, shape `[batch_size, num_latents]`

    Note
    ----
    This involves computing pair-wise distances, with memory complexity up to
    `O(n*n*d)`.

    Returns
    -------
    Total correlation estimated on a batch.

    References
    ----------
    Chen, R.T.Q., Li, X., Grosse, R., Duvenaud, D., 2019. Isolating Sources of
        Disentanglement in Variational Autoencoders. arXiv:1802.04942 [cs, stat].
    Github code https://github.com/google-research/disentanglement_lib
    """
    gaus = Normal(loc=tf.expand_dims(qZ_X.mean(), 0),
                  scale=tf.expand_dims(qZ_X.stddev(), 0))
    # Compute log(q(z(x_j)|x_i)) for every sample in the batch, which is a
    # tensor of size [batch_size, batch_size, num_latents]. In the following
    # comments, [batch_size, batch_size, num_latents] are indexed by [j, i, l].
    log_qz_prob = gaus.log_prob(tf.expand_dims(z_samples, 1))
    # Compute log prod_l p(z(x_j)_l) = sum_l(log(sum_i(q(z(x_j)_l|x_i)))
    # + constant) for each sample in the batch, which is a vector of size
    # [batch_size,].
    log_qz_product = tf.reduce_sum(
        tf.reduce_logsumexp(log_qz_prob, axis=1, keepdims=False),
        axis=1,
        keepdims=False)
    # Compute log(q(z(x_j))) as log(sum_i(q(z(x_j)|x_i))) + constant =
    # log(sum_i(prod_l q(z(x_j)_l|x_i))) + constant.
    log_qz = tf.reduce_logsumexp(
        tf.reduce_sum(log_qz_prob, axis=2, keepdims=False),
        axis=1,
        keepdims=False)
    return tf.reduce_mean(log_qz - log_qz_product)

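
# Usage sketch (illustrative shapes, not part of the original module):
import tensorflow as tf
from tensorflow_probability.python.distributions import Normal

qZ_X = Normal(loc=tf.random.normal([128, 6]), scale=tf.ones([128, 6]))
z_samples = qZ_X.sample()                 # [batch_size, num_latents]
tc = total_correlation(z_samples, qZ_X)   # scalar TC estimate for this batch
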
def maximum_mean_discrepancy(
    qZ: Distribution,
    pZ: Distribution,
    q_sample_shape: Union[int, List[int]] = (),
    p_sample_shape: Union[int, List[int]] = 100,
    kernel: Literal['gaussian', 'linear', 'polynomial'] = 'gaussian') -> Tensor:
    r"""Maximum mean discrepancy (MMD) is a distance measure between
    distributions p(X) and q(Y), defined as the squared distance between their
    embeddings in a reproducing kernel Hilbert space.

    Given n samples from p(X) and m samples from q(Y), one can formulate a test
    statistic based on the empirical estimate of the MMD:

    `MMD^2(P, Q) = || E[phi(x)] - E[phi(y)] ||^2
                 = E[K(x, x)] + E[K(y, y)] - 2 * E[K(x, y)]`

    Arguments:
        qZ : the posterior distribution.
        pZ : the prior distribution.
        q_sample_shape : sample shape for the posterior; if `None`, reuse the
            already-sampled values of `qZ`.
        p_sample_shape : sample shape for the prior.
        kernel : which kernel to use: 'gaussian', 'linear' or 'polynomial'.

    Reference:
        Gretton, A., Borgwardt, K., Rasch, M.J., Scholkopf, B., Smola, A.J., 2008.
            "A Kernel Method for the Two-Sample Problem". arXiv:0805.2368 [cs].
    """
    assert isinstance(qZ, Distribution), \
        'qZ must be an instance of tensorflow_probability.Distribution'
    assert isinstance(pZ, Distribution), \
        'pZ must be an instance of tensorflow_probability.Distribution'
    # prepare the samples
    if q_sample_shape is None:  # reuse sampled examples
        x = tf.convert_to_tensor(qZ)
    else:
        x = qZ.sample(q_sample_shape)
    y = pZ.sample(p_sample_shape)
    # select the kernel
    if kernel == 'gaussian':
        kernel = gaussian_kernel
    elif kernel == 'linear':
        kernel = linear_kernel
    elif kernel == 'polynomial':
        kernel = polynomial_kernel
    else:
        raise NotImplementedError("No support for kernel: '%s'" % kernel)
    k_xx = kernel(x, x)
    k_yy = kernel(y, y)
    k_xy = kernel(x, y)
    return (tf.reduce_mean(k_xx) + tf.reduce_mean(k_yy) -
            2 * tf.reduce_mean(k_xy))

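
# Usage sketch (not part of the original module), assuming the kernel helpers
# referenced above (gaussian_kernel, ...) are importable from the same module:
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

qZ = tfd.MultivariateNormalDiag(loc=tf.random.normal([16]), scale_diag=tf.ones([16]))
pZ = tfd.MultivariateNormalDiag(loc=tf.zeros([16]), scale_diag=tf.ones([16]))
mmd = maximum_mean_discrepancy(qZ, pZ, q_sample_shape=100, p_sample_shape=100,
                               kernel='gaussian')
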
def get_fixed_topology_bijector(
    dist: tfd.Distribution,
    topology_pins: tp.Dict[str, TensorflowTreeTopology],
):
    if isinstance(dist, BaseTreeDistribution) and dist.tree_name in topology_pins:
        topology = topology_pins[dist.tree_name]
        tree_bijector: TreeRatioBijector = (
            dist.experimental_default_event_space_bijector(topology=topology))
        return FixedTopologyRootedTreeBijector(
            topology,
            tree_bijector.bijectors.node_heights,
            # TODO: Make sure dist has fixed sampling times
            sampling_times=dist.sampling_times,
        )
    else:
        return dist.experimental_default_event_space_bijector()

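
# Sketch of the fallback branch (hypothetical, not from the original tests):
# for a non-tree distribution the function simply returns TFP's default
# event-space bijector.
import tensorflow as tf
import tensorflow_probability as tfp

d = tfp.distributions.Dirichlet(concentration=tf.ones(4))
bij = get_fixed_topology_bijector(d, topology_pins={})
# `bij` maps unconstrained reals onto the simplex support of `d`.
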
def slice_distribution(index,
                       dist: tfd.Distribution,
                       name=None) -> tfd.Distribution:
    r"""Apply indexing on the distribution's parameters and return another
    `Distribution`."""
    assert isinstance(dist, tfd.Distribution), \
        "dist must be an instance of Distribution, but given: %s" % str(type(dist))
    if name is None:
        name = dist.name
    ## compound distributions
    if isinstance(dist, tfd.Independent):
        return tfd.Independent(
            distribution=slice_distribution(index, dist.distribution),
            reinterpreted_batch_ndims=dist.reinterpreted_batch_ndims,
            name=name)
    elif isinstance(dist, ZeroInflated):
        return ZeroInflated(
            count_distribution=slice_distribution(index, dist.count_distribution),
            inflated_distribution=slice_distribution(index,
                                                     dist.inflated_distribution),
            name=name)
    # this is a very ad-hoc solution: gather along the batch axis of every
    # array-like parameter, then rebuild the distribution
    params = dist.parameters.copy()
    for key, val in list(params.items()):
        if isinstance(val, (np.ndarray, tf.Tensor)):
            params[key] = tf.gather(val, indices=index, axis=0)
    return dist.__class__(**params)

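
# Usage sketch (illustrative, not part of the original module): keep only the
# first 10 elements of the batch dimension of an Independent Gaussian.
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

d = tfd.Independent(tfd.Normal(loc=tf.zeros([32, 8]), scale=tf.ones([32, 8])),
                    reinterpreted_batch_ndims=1)
d10 = slice_distribution(tf.range(10), d)  # batch_shape [10], event_shape [8]
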
def estimate_Izx(fn_px_z: Callable[[tf.Tensor], tfd.Distribution],
                 pz: tfd.Distribution,
                 n_samples_z: int = 10000,
                 n_mcmc_x: int = 100,
                 batch_size: int = 32,
                 verbose: bool = True):
    log_px_z = []
    prog = tqdm(desc='I(Z;X)',
                total=n_samples_z * n_mcmc_x,
                unit='samples',
                disable=not verbose)
    for start in range(0, n_samples_z, batch_size):
        batch_z = min(n_samples_z - start, batch_size)
        z = pz.sample(batch_z)
        px_z = fn_px_z(z)
        batch_llk = []
        for start_x in range(0, n_mcmc_x, batch_size):
            batch_x = min(n_mcmc_x - start_x, batch_size)
            x = px_z.sample(batch_x)
            batch_llk.append(px_z.log_prob(x))
            prog.update(batch_z * batch_x)
        batch_llk = tf.concat(batch_llk, axis=0)
        log_px_z.append(batch_llk)
    ## finalize
    prog.clear()
    prog.close()
    log_px_z = tf.concat(log_px_z, axis=1)  # [n_mcmc_x, n_samples_z]
    ## calculate the MI
    log_px = tf.reduce_logsumexp(log_px_z, axis=1, keepdims=True) - \
        tf.math.log(tf.cast(n_samples_z, tf.float32))
    mi = tf.reduce_mean(log_px_z - log_px)
    return mi

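
# Usage sketch with a hypothetical linear-Gaussian decoder (the weight matrix
# and sample counts below are illustrative):
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

W = tf.random.normal([4, 8])

def fn_px_z(z):
    # p(x|z): isotropic Gaussian whose mean is a linear map of z
    return tfd.Independent(tfd.Normal(loc=tf.matmul(z, W), scale=1.),
                           reinterpreted_batch_ndims=1)

pz = tfd.MultivariateNormalDiag(loc=tf.zeros(4), scale_diag=tf.ones(4))
mi = estimate_Izx(fn_px_z, pz, n_samples_z=256, n_mcmc_x=64,
                  batch_size=32, verbose=False)
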
def tfd_analytic_sample(n: int, dist: tfd.Distribution, limits: ztyping.ObsTypeInput):
    """Sample analytically with a `tfd.Distribution` within the limits. No preprocessing.

    Args:
        n: Number of samples to get.
        dist: Distribution to sample from.
        limits: Limits to sample from within.

    Returns:
        `tf.Tensor` (n, n_obs): The sampled data with the number of samples and
            the number of observables.
    """
    if limits.n_limits > 1:
        raise NotImplementedError
    (lower_bound,), (upper_bound,) = limits.limits
    lower_prob_lim = dist.cdf(lower_bound)
    upper_prob_lim = dist.cdf(upper_bound)

    prob_sample = ztf.random_uniform(shape=(n, limits.n_obs),
                                     minval=lower_prob_lim,
                                     maxval=upper_prob_lim)
    sample = dist.quantile(prob_sample)
    return sample

def predict_f_samples(self,
                      predict_at: tf.Tensor,
                      sample_shape: Union[tuple, list],
                      data: Tuple[tf.Tensor, tf.Tensor] = None,
                      w_distrib: tfd.Distribution = None,
                      full_cov: bool = False,
                      use_weight_space: bool = None,
                      **kwargs) -> tf.Tensor:
    """Generate samples from the random function $f$ evaluated at test locations."""
    if use_weight_space is None:
        if self.basis is None:
            num_weights = predict_at.shape[-1]
        else:
            num_weights = self.basis.units
        num_predict = predict_at.shape[-2] if full_cov else 1
        use_weight_space = num_predict > num_weights

    # Sample via the weight-space representation
    if use_weight_space:
        if w_distrib is None:
            w_samples = self.predict_w_samples(sample_shape, data, **kwargs)
        else:
            w_samples = w_distrib.sample(sample_shape)

        if self.basis is None:
            x = predict_at
        else:
            x = self.basis(predict_at)

        f_samples = tf.matmul(w_samples, x, transpose_b=True)
        return f_samples

    # Sample via the function-space representation
    f_distrib = self.predict_f(predict_at,
                               data=data,
                               w_distrib=w_distrib,
                               full_cov=full_cov,
                               **kwargs)
    return f_distrib.sample(sample_shape)

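
# Standalone sketch of the weight-space path above (names are illustrative, not
# the model's API): draw w ~ q(w) and evaluate f(X) = Phi(X) w^T.
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

w_distrib = tfd.MultivariateNormalDiag(loc=tf.zeros(16), scale_diag=tf.ones(16))
w_samples = w_distrib.sample(8)               # [8, 16]
phi_x = tf.random.normal([100, 16])           # basis features Phi(X)
f_samples = tf.matmul(w_samples, phi_x, transpose_b=True)  # [8, 100]
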
def update_w_samples(self,
                     w_distrib: tfd.Distribution,
                     w_samples: tf.Tensor,
                     data: Tuple[tf.Tensor, tf.Tensor],
                     noisy: bool = True) -> tf.Tensor:
    """Use Matheron's rule to directly condition weights drawn from the prior
    on observations (x, y)."""
    x, y = self.preprocess_data(data)
    yhat = tf.matmul(x, w_samples[..., None, :], transpose_b=True)
    yhat += tf.sqrt(self.likelihood.variance) * tf.random.normal(
        shape=yhat.shape, dtype=yhat.dtype)
    resid = y - yhat  # the noise sample enters with an implicit negative sign
    noise_var = self.likelihood.variance

    if isinstance(w_distrib, tfd.MultivariateNormalDiag):
        D = w_distrib.stddev()
        B = D[..., None, :] * x
        precis = tf.matmul(B, B, transpose_b=True)
        if noisy:
            precis += noise_var * tf.eye(precis.shape[-2], dtype=precis.dtype)
        sqprec = jitter_cholesky(precis)
        solv = parallel_solve(tf.linalg.cholesky_solve, sqprec, resid)
        update = D * tf.squeeze(tf.matmul(solv, B, transpose_a=True), -2)
    elif isinstance(w_distrib, tfd.MultivariateNormalTriL):
        L = w_distrib.scale
        B = tf.matmul(x, L)
        precis = tf.matmul(B, B, transpose_b=True)
        if noisy:
            precis += noise_var * tf.eye(precis.shape[-2], dtype=precis.dtype)
        sqprec = jitter_cholesky(precis)
        solv = parallel_solve(tf.linalg.cholesky_solve, sqprec, resid)
        update = tf.matmul(tf.matmul(tf.squeeze(solv, -1), B), L, transpose_b=True)
    else:
        raise NotImplementedError

    return w_samples + update

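
# Minimal standalone sketch of Matheron's rule for a Bayesian linear model with
# a standard-normal prior w ~ N(0, I) and y = X w + eps, eps ~ N(0, s2 I)
# (all names and sizes below are illustrative):
import tensorflow as tf

n, d, s2 = 20, 8, 0.1
X = tf.random.normal([n, d])
y = tf.linalg.matvec(X, tf.random.normal([d])) + tf.sqrt(s2) * tf.random.normal([n])

w_prior = tf.random.normal([d])                    # sample from the prior
eps = tf.sqrt(s2) * tf.random.normal([n])          # fresh noise sample
resid = y - (tf.linalg.matvec(X, w_prior) + eps)   # same "implicit negative" as above
K = tf.matmul(X, X, transpose_b=True) + s2 * tf.eye(n)   # X Sigma X^T + s2 I, Sigma = I
alpha = tf.linalg.solve(K, resid[:, None])[:, 0]
w_posterior = w_prior + tf.linalg.matvec(X, alpha, transpose_a=True)
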