Example #1
def disentangled_inferred_prior_loss(qZ_X: Distribution,
                                     only_mean: bool = False,
                                     lambda_offdiag: float = 2.,
                                     lambda_diag: float = 1.) -> Tensor:
  r""" Disentangled inferred prior (DIP) matches the covariance of the prior
  distributions with the inferred prior

  Uses `cov(z_mean) = E[z_mean*z_mean^T] - E[z_mean]E[z_mean]^T`.

  Arguments:
    qZ_X : `tensorflow_probability.Distribution`
    only_mean : A Boolean. If `True`, apply the DIP constraint only to the
      mean of the latents, `Cov[E(z)]` (i.e. type 'i'); otherwise use
      `E[Cov(z)] + Cov[E(z)]` (i.e. type 'ii').
    lambda_offdiag : A scalar. Weight for penalizing the off-diagonal part of
      the covariance matrix.
    lambda_diag : A scalar. Weight for penalizing the diagonal.

  Reference:
    Kumar, A., Sattigeri, P., Balakrishnan, A., 2018. Variational Inference of
      Disentangled Latent Concepts from Unlabeled Observations.
      arXiv:1711.00848 [cs, stat].
    Github code https://github.com/IBM/AIX360
    Github code https://github.com/google-research/disentanglement_lib

  """
  z_mean = qZ_X.mean()
  shape = z_mean.shape
  if len(shape) > 2:
    # [sample_shape * batch_size, zdim]
    z_mean = tf.reshape(
        z_mean, (tf.cast(tf.reduce_prod(shape[:-1]), tf.int32),) + shape[-1:])
  expectation_z_mean_z_mean_t = tf.reduce_mean(tf.expand_dims(z_mean, 2) *
                                               tf.expand_dims(z_mean, 1),
                                               axis=0)
  expectation_z_mean = tf.reduce_mean(z_mean, axis=0)
  # cov_zmean [zdim, zdim]
  cov_zmean = tf.subtract(
      expectation_z_mean_z_mean_t,
      tf.expand_dims(expectation_z_mean, 1) *
      tf.expand_dims(expectation_z_mean, 0))
  # Eq(5)
  if only_mean:
    z_cov = cov_zmean
  else:
    z_var = qZ_X.variance()
    if len(shape) > 2:
      z_var = tf.reshape(
          z_var, (tf.cast(tf.reduce_prod(shape[:-1]), tf.int32),) + shape[-1:])
    # mean_zcov [zdim, zdim]
    mean_zcov = tf.reduce_mean(tf.linalg.diag(z_var), axis=0)
    z_cov = cov_zmean + mean_zcov
  # Eq(6) and Eq(7)
  # z_cov [zdim, zdim]
  # z_cov_diag [zdim]
  # z_cov_offdiag [zdim, zdim]
  z_cov_diag = tf.linalg.diag_part(z_cov)
  z_cov_offdiag = z_cov - tf.linalg.diag(z_cov_diag)
  return lambda_offdiag * tf.reduce_sum(z_cov_offdiag ** 2) + \
    lambda_diag * tf.reduce_sum((z_cov_diag - 1.) ** 2)
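
# Hypothetical usage sketch (not from the original source): evaluating the DIP
# regularizer on a factorized Gaussian posterior, assuming the function above
# is in scope together with TensorFlow and TensorFlow Probability.
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
qZ_X = tfd.Normal(loc=tf.random.normal([32, 10]),
                  scale=tf.nn.softplus(tf.random.normal([32, 10])))
dip = disentangled_inferred_prior_loss(qZ_X, only_mean=False,
                                       lambda_offdiag=2., lambda_diag=1.)
# `dip` is a scalar penalty (DIP-VAE-II form) added to the VAE objective.
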
Example #2
def tfd_analytic_sample(n: int, dist: tfd.Distribution,
                        limits: ztyping.ObsTypeInput):
    """Sample analytically with a `tfd.Distribution` within the limits. No preprocessing.

    Args:
        n: Number of samples to get
        dist: Distribution to sample from
        limits: Limits to sample from within

    Returns:
        The sampled data with the number of samples and the number of observables.
    """
    lower_bound, upper_bound = limits.rect_limits
    lower_prob_lim = dist.cdf(lower_bound)
    upper_prob_lim = dist.cdf(upper_bound)

    shape = (n, 1)
    prob_sample = z.random.uniform(shape=shape,
                                   minval=lower_prob_lim,
                                   maxval=upper_prob_lim)
    prob_sample.set_shape((None, 1))
    try:
        sample = dist.quantile(prob_sample)
    except NotImplementedError:
        raise AnalyticSamplingNotImplementedError
    sample.set_shape((None, limits.n_obs))
    return sample
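
# A minimal standalone sketch of the same inverse-CDF trick with plain
# TensorFlow Probability, assuming a single-observable distribution and
# explicit bounds instead of zfit's `limits` object (names below are made up
# for illustration).
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

def truncated_quantile_sample(n, dist, lower, upper):
    # Uniform samples in [cdf(lower), cdf(upper)] pushed through the quantile
    # function are distributed as `dist` truncated to [lower, upper].
    u = tf.random.uniform((n, 1), minval=dist.cdf(lower), maxval=dist.cdf(upper))
    return dist.quantile(u)

samples = truncated_quantile_sample(1000, tfd.Normal(0., 1.), -1., 2.)
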
Example #3
 def _mi_loss(
     self,
     Q: Sequence[Distribution],
     py_z: Distribution,
     training: Optional[bool] = None,
     which_latents_sampling: Optional[List[int]] = None,
 ) -> Tuple[tf.Tensor, List[tf.Tensor]]:
     ## sample the prior
     batch_shape = Q[0].batch_shape_tensor()
     if which_latents_sampling is None:
         which_latents_sampling = list(range(len(Q)))
     z_prime = [
         q.KL_divergence.prior.sample(batch_shape)
         if i in which_latents_sampling else tf.stop_gradient(
             tf.convert_to_tensor(q)) for i, q in enumerate(Q)
     ]
     if len(z_prime) == 1:
         z_prime = z_prime[0]
     ## decoding
     px = self.decode(z_prime, training=training)[0]
     if px.reparameterization_type == NOT_REPARAMETERIZED:
         x = px.mean()
     else:
         x = tf.convert_to_tensor(px)
     # should not stop gradient here, the generator needs to be updated
     # x = tf.stop_gradient(x)
     Q_prime = self.encode(x, training=training)
     qy_z = self.predict_factors(latents=Q_prime, training=training)
     ## y ~ p(y|z); stopping the gradient here is important to prevent the
     # encoder from being updated twice. This significantly increases
     # stability; otherwise the encoder and latents often get NaN gradients.
     if self.reverse_mi:  # D_kl(p(y|z)||q(y|z))
         y_samples = tf.stop_gradient(py_z.sample())
         Dkl = py_z.log_prob(y_samples) - qy_z.log_prob(y_samples)
     else:  # D_kl(q(y|z)||p(y|z))
         y_samples = tf.stop_gradient(qy_z.sample())
         Dkl = qy_z.log_prob(y_samples) - py_z.log_prob(y_samples)
     ## only calculate MI for unsupervised data
     mi_y = tf.reduce_mean(Dkl)
     ## mutual information (we want to maximize this, hence, add it to the llk)
     if training:
         mi_y = tf.cond(
             self.step >= self.steps_without_mi,
             true_fn=lambda: mi_y,
             false_fn=lambda: tf.stop_gradient(mi_y),
         )
     else:
         mi_y = tf.stop_gradient(mi_y)
     mi_y = self.mi_coef * mi_y
     ## this value is just for monitoring
     mi_z = []
     for q, z in zip(as_tuple(Q_prime), as_tuple(z_prime)):
         mi = tf.reduce_mean(tf.stop_gradient(q.log_prob(z)))
         mi = tf.cond(tf.math.is_nan(mi),
                      true_fn=lambda: 0.,
                      false_fn=lambda: tf.clip_by_value(mi, -1e8, 1e8))
         mi_z.append(mi)
     return mi_y, mi_z
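
# The core of the estimator above is a single-sample Monte-Carlo KL term,
# D_kl(q(y|z) || p(y|z)) ~= E_q[log q(y) - log p(y)]; a minimal standalone
# sketch with plain TFP distributions (the distributions here are made up
# for illustration).
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
qy_z = tfd.OneHotCategorical(logits=tf.random.normal([32, 10]), dtype=tf.float32)
py_z = tfd.OneHotCategorical(logits=tf.zeros([32, 10]), dtype=tf.float32)
y_samples = tf.stop_gradient(qy_z.sample())
Dkl = tf.reduce_mean(qy_z.log_prob(y_samples) - py_z.log_prob(y_samples))
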
Example #4
def total_correlation(z_samples: Tensor, qZ_X: Distribution) -> Tensor:
    r"""Estimate of total correlation using Gaussian distribution on a batch.

  We need to compute the expectation over a batch of:
  `E_j [log(q(z(x_j))) - log(prod_l q(z(x_j)_l))]`

  We ignore the constants as they do not matter for the minimization.
  The constant should be equal to
  `(num_latents - 1) * log(batch_size * dataset_size)`

  If `alpha = gamma = 1`, Eq(4) can be written as `ELBO + (1 - beta) * TC`.
  (i.e. `(1. - beta) * total_correlation(z_sampled, qZ_X)`)

  Parameters
  ----------
  z_samples : Tensor
      shape `[batch_size, num_latents]` - tensor with sampled representation.
  qZ_X : Distribution
      the posterior distribution, shape `[batch_size, num_latents]`

  Note
  ----
  This involves computing a pairwise log-probability matrix; the memory
      complexity is up to `O(n*n*d)`.

  Returns
  -------
  Total correlation estimated on a batch.

  References
  ----------
  Chen, R.T.Q., Li, X., Grosse, R., Duvenaud, D., 2019. Isolating Sources of
      Disentanglement in Variational Autoencoders. arXiv:1802.04942 [cs, stat].
  Github code https://github.com/google-research/disentanglement_lib
  """
    gaus = Normal(loc=tf.expand_dims(qZ_X.mean(), 0),
                  scale=tf.expand_dims(qZ_X.stddev(), 0))
    # Compute log(q(z(x_j)|x_i)) for every sample in the batch, which is a
    # tensor of size [batch_size, batch_size, num_latents]. In the following
    # comments, [batch_size, batch_size, num_latents] are indexed by [j, i, l].
    log_qz_prob = gaus.log_prob(tf.expand_dims(z_samples, 1))
    # Compute log prod_l p(z(x_j)_l) = sum_l(log(sum_i(q(z(x_j)_l|x_i)))
    # + constant) for each sample in the batch, which is a vector of size
    # [batch_size,].
    log_qz_product = tf.reduce_sum(tf.reduce_logsumexp(log_qz_prob,
                                                       axis=1,
                                                       keepdims=False),
                                   axis=1,
                                   keepdims=False)
    # Compute log(q(z(x_j))) as log(sum_i(q(z(x_j)|x_i))) + constant =
    # log(sum_i(prod_l q(z(x_j)_l|x_i))) + constant.
    log_qz = tf.reduce_logsumexp(tf.reduce_sum(log_qz_prob,
                                               axis=2,
                                               keepdims=False),
                                 axis=1,
                                 keepdims=False)
    return tf.reduce_mean(log_qz - log_qz_product)
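
# Hypothetical usage sketch, assuming `total_correlation` above is in scope and
# that `Normal` refers to `tfp.distributions.Normal`.
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
qZ_X = tfd.Normal(loc=tf.random.normal([64, 10]),
                  scale=tf.nn.softplus(tf.random.normal([64, 10])))
z_samples = qZ_X.sample()                # [batch_size, num_latents]
tc = total_correlation(z_samples, qZ_X)  # scalar TC estimate for this batch
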
Example #5
def maximum_mean_discrepancy(
    qZ: Distribution,
    pZ: Distribution,
    q_sample_shape: Union[int, List[int]] = (),
    p_sample_shape: Union[int, List[int]] = 100,
    kernel: Literal['gaussian', 'linear',
                    'polynomial'] = 'gaussian') -> Tensor:
    r""" is a distance-measure between distributions p(X) and q(Y) which is
  defined as the squared distance between their embeddings in the a
  "reproducing kernel Hilbert space".

  Given n examples from p(X) and m samples from q(Y), one can formulate a
  test statistic based on the empirical estimate of the MMD:

  MMD^2(P, Q) = || \E{\phi(x)} - \E{\phi(y)} ||^2
              = \E{ K(x, x) } + \E{ K(y, y) } - 2 \E{ K(x, y) }

  Arguments:
    qZ : the posterior `Distribution`; `q_sample_shape` samples are drawn from
      it (its already-sampled values are reused if `q_sample_shape` is `None`).
    pZ : the prior `Distribution`; `p_sample_shape` samples are drawn from it.
    kernel : the kernel embedding to use: 'gaussian', 'linear' or 'polynomial'.

  Reference:
    Gretton, A., Borgwardt, K., Rasch, M.J., Scholkopf, B., Smola, A.J., 2008.
      "A Kernel Method for the Two-Sample Problem". arXiv:0805.2368 [cs].

  """
    assert isinstance(
        qZ, Distribution
    ), 'qZ must be instance of tensorflow_probability.Distribution'
    assert isinstance(
        pZ, Distribution
    ), 'pZ must be instance of tensorflow_probability.Distribution'
    # prepare the samples
    if q_sample_shape is None:  # reuse sampled examples
        x = tf.convert_to_tensor(qZ)
    else:
        x = qZ.sample(q_sample_shape)
    y = pZ.sample(p_sample_shape)
    # select the kernel
    if kernel == 'gaussian':
        kernel = gaussian_kernel
    elif kernel == 'linear':
        kernel = linear_kernel
    elif kernel == 'polynomial':
        kernel = polynomial_kernel
    else:
        raise NotImplementedError("No support for kernel: '%s'" % kernel)
    k_xx = kernel(x, x)
    k_yy = kernel(y, y)
    k_xy = kernel(x, y)
    return (tf.reduce_mean(k_xx) + tf.reduce_mean(k_yy) -
            2 * tf.reduce_mean(k_xy))
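
# A self-contained sketch of the same empirical MMD^2 estimator with an RBF
# kernel; `gaussian_kernel`, `linear_kernel` and `polynomial_kernel` above come
# from the surrounding library and are not reproduced here.
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

def rbf_kernel(x, y, length_scale=1.0):
    # pairwise k(x, y) = exp(-||x - y||^2 / (2 * length_scale^2))
    sq_dist = tf.reduce_sum((x[:, None, :] - y[None, :, :])**2, axis=-1)
    return tf.exp(-sq_dist / (2.0 * length_scale**2))

qZ = tfd.MultivariateNormalDiag(loc=tf.ones(8))
pZ = tfd.MultivariateNormalDiag(loc=tf.zeros(8))
x, y = qZ.sample(100), pZ.sample(100)
mmd2 = (tf.reduce_mean(rbf_kernel(x, x)) + tf.reduce_mean(rbf_kernel(y, y)) -
        2.0 * tf.reduce_mean(rbf_kernel(x, y)))
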
Example #6
def get_fixed_topology_bijector(
        dist: tfd.Distribution,
        topology_pins: tp.Dict[str, TensorflowTreeTopology]):
    if isinstance(dist,
                  BaseTreeDistribution) and dist.tree_name in topology_pins:
        topology = topology_pins[dist.tree_name]
        tree_bijector: TreeRatioBijector = (
            dist.experimental_default_event_space_bijector(topology=topology))
        return FixedTopologyRootedTreeBijector(
            topology,
            tree_bijector.bijectors.node_heights,
            # TODO: Make sure dist has fixed sampling times
            sampling_times=dist.sampling_times,
        )
    else:
        return dist.experimental_default_event_space_bijector()
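
# The fallback branch simply delegates to TFP's default event-space bijector;
# a minimal illustration with an ordinary distribution (not a tree prior):
import tensorflow_probability as tfp

tfd = tfp.distributions
beta = tfd.Beta(2.0, 3.0)
bijector = beta.experimental_default_event_space_bijector()
constrained = bijector(0.3)  # maps an unconstrained real into (0, 1)
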
Example #7
def slice_distribution(index, dist: tfd.Distribution,
                       name=None) -> tfd.Distribution:
  r""" Apply indexing on distribution parameters and return another
  `Distribution` """
  assert isinstance(dist, tfd.Distribution), \
    "dist must be instance of Distribution, but given: %s" % str(type(dist))
  if name is None:
    name = dist.name
  ## compound distribution
  if isinstance(dist, tfd.Independent):
    return tfd.Independent(
        distribution=slice_distribution(index, dist.distribution),
        reinterpreted_batch_ndims=dist.reinterpreted_batch_ndims,
        name=name)
  elif isinstance(dist, ZeroInflated):
    return ZeroInflated(
        count_distribution=slice_distribution(index, dist.count_distribution),
        inflated_distribution=slice_distribution(index,
                                                 dist.inflated_distribution),
        name=name)

  # this is a very ad-hoc solution
  params = dist.parameters.copy()
  for key, val in list(params.items()):
    if isinstance(val, (np.ndarray, tf.Tensor)):
      params[key] = tf.gather(val, indices=index, axis=0)
  return dist.__class__(**params)
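
# Hypothetical usage sketch, assuming `slice_distribution` above is in scope:
# take the first two batch entries of a batched Gaussian.
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
dist = tfd.Normal(loc=tf.zeros([10, 3]), scale=tf.ones([10, 3]))
sub = slice_distribution(index=[0, 1], dist=dist)  # batch_shape [2, 3]
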
Example #8
def estimate_Izx(fn_px_z: Callable[[tf.Tensor], tfd.Distribution],
                 pz: tfd.Distribution,
                 n_samples_z: int = 10000,
                 n_mcmc_x: int = 100,
                 batch_size: int = 32,
                 verbose: bool = True):
  log_px_z = []
  prog = tqdm(desc='I(Z;X)',
              total=n_samples_z * n_mcmc_x,
              unit='samples',
              disable=not verbose)
  for start in range(0, n_samples_z, batch_size):
    batch_z = min(n_samples_z - start, batch_size)
    z = pz.sample(batch_z)
    px_z = fn_px_z(z)
    batch_llk = []
    for start_x in range(0, n_mcmc_x, batch_size):
      batch_x = min(n_mcmc_x - start_x, batch_size)
      x = px_z.sample(batch_x)
      batch_llk.append(px_z.log_prob(x))
      prog.update(batch_z * batch_x)
    batch_llk = tf.concat(batch_llk, axis=0)
    log_px_z.append(batch_llk)
  ## finalize
  prog.clear()
  prog.close()
  log_px_z = tf.concat(log_px_z, axis=1)  # [n_mcmc_x, n_samples_z]
  ## calculate the MI
  log_px = tf.reduce_logsumexp(log_px_z, axis=1, keepdims=True) - \
    tf.math.log(tf.cast(n_samples_z, tf.float32))
  H_x = tf.reduce_mean(log_px)  # estimate of E[log p(x)] (negative entropy)
  mi = tf.reduce_mean(log_px_z - log_px)
  return mi
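
# Hypothetical usage sketch with a toy Gaussian decoder, assuming `estimate_Izx`
# above is in scope; the decoder layer here is made up for illustration.
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
decoder = tf.keras.layers.Dense(5)
pz = tfd.MultivariateNormalDiag(loc=tf.zeros(2))
fn_px_z = lambda z: tfd.Independent(tfd.Normal(loc=decoder(z), scale=1.),
                                    reinterpreted_batch_ndims=1)
mi = estimate_Izx(fn_px_z, pz, n_samples_z=256, n_mcmc_x=64, verbose=False)
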
Example #9
def tfd_analytic_sample(n: int, dist: tfd.Distribution, limits: ztyping.ObsTypeInput):
    """Sample analytically with a `tfd.Distribution` within the limits. No preprocessing.

    Args:
        n: Number of samples to get
        dist: Distribution to sample from
        limits: Limits to sample from within

    Returns:
        `tf.Tensor` (n, n_obs): The sampled data with the number of samples and the number of observables.
    """
    if limits.n_limits > 1:
        raise NotImplementedError
    (lower_bound,), (upper_bound,) = limits.limits
    lower_prob_lim = dist.cdf(lower_bound)
    upper_prob_lim = dist.cdf(upper_bound)

    prob_sample = ztf.random_uniform(shape=(n, limits.n_obs), minval=lower_prob_lim,
                                     maxval=upper_prob_lim)
    sample = dist.quantile(prob_sample)
    return sample

    def predict_f_samples(self,
                          predict_at: tf.Tensor,
                          sample_shape: Union[tuple, list],
                          data: Tuple[tf.Tensor, tf.Tensor] = None,
                          w_distrib: tfd.Distribution = None,
                          full_cov: bool = False,
                          use_weight_space: bool = None,
                          **kwargs) -> tf.Tensor:
        """
    Generate samples from random function $f$ evaluated at test locations.
    """
        if use_weight_space is None:
            if self.basis is None:
                num_weights = predict_at.shape[-1]
            else:
                num_weights = self.basis.units
            num_predict = predict_at.shape[-2] if full_cov else 1
            use_weight_space = num_predict > num_weights

        # Sample via weight-space representation
        if use_weight_space:
            if w_distrib is None:
                w_samples = self.predict_w_samples(sample_shape, data,
                                                   **kwargs)
            else:
                w_samples = w_distrib.sample(sample_shape)

            if self.basis is None:
                x = predict_at
            else:
                x = self.basis(predict_at)

            f_samples = tf.matmul(w_samples, x, transpose_b=True)
            return f_samples

        # Sample via function-space representation
        f_distrib = self.predict_f(predict_at,
                                   data=data,
                                   w_distrib=w_distrib,
                                   full_cov=full_cov,
                                   **kwargs)

        return f_distrib.sample(sample_shape)
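
# The weight-space branch above reduces to "features times sampled weights",
# f(x) = phi(x) @ w with w drawn from a weight prior or posterior; a minimal
# standalone sketch (identity features and standard-normal weights are
# assumptions made for illustration):
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
num_test, num_weights = 50, 3
x = tf.random.normal([num_test, num_weights])          # here phi(x) = x
w_distrib = tfd.MultivariateNormalDiag(loc=tf.zeros(num_weights))
w_samples = w_distrib.sample(10)                       # [10, num_weights]
f_samples = tf.matmul(w_samples, x, transpose_b=True)  # [10, num_test]
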
    def update_w_samples(self,
                         w_distrib: tfd.Distribution,
                         w_samples: tf.Tensor,
                         data: Tuple[tf.Tensor, tf.Tensor],
                         noisy: bool = True) -> tf.Tensor:
        """
    Use Matheron's rule to directly condition weights drawn from the prior
    on observations (x, y).
    """
        x, y = self.preprocess_data(data)
        yhat = tf.matmul(x, w_samples[..., None, :], transpose_b=True)
        yhat += tf.sqrt(self.likelihood.variance) * tf.random.normal(
            shape=yhat.shape, dtype=yhat.dtype)
        resid = y - yhat  # implicit negative
        noise_var = self.likelihood.variance
        if isinstance(w_distrib, tfd.MultivariateNormalDiag):
            D = w_distrib.stddev()
            B = D[..., None, :] * x
            precis = tf.matmul(B, B, transpose_b=True)
            if noisy:
                precis += noise_var * tf.eye(precis.shape[-2],
                                             dtype=precis.dtype)
            sqprec = jitter_cholesky(precis)
            solv = parallel_solve(tf.linalg.cholesky_solve, sqprec, resid)
            update = D * tf.squeeze(tf.matmul(solv, B, transpose_a=True), -2)
        elif isinstance(w_distrib, tfd.MultivariateNormalTriL):
            L = w_distrib.scale
            B = tf.matmul(x, L)
            precis = tf.matmul(B, B, transpose_b=True)
            if noisy:
                precis += noise_var * tf.eye(precis.shape[-2],
                                             dtype=precis.dtype)
            sqprec = jitter_cholesky(precis)
            solv = parallel_solve(tf.linalg.cholesky_solve, sqprec, resid)
            update = tf.matmul(tf.matmul(tf.squeeze(solv, -1), B),
                               L,
                               transpose_b=True)
        else:
            raise NotImplementedError

        return w_samples + update
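
# Matheron's rule in its generic form, as a standalone sketch: a prior sample
# w ~ N(0, S) is conditioned on observations y = X w + e, e ~ N(0, s2 * I), via
#   w_post = w + S X^T (X S X^T + s2 * I)^{-1} (y - X w - e).
# Everything below is illustrative and independent of the class above.
import tensorflow as tf

def matheron_update(w, X, y, S, s2):
    # w: [d] prior weight sample, X: [n, d] features, y: [n] targets,
    # S: [d, d] prior covariance, s2: observation noise variance.
    e = tf.sqrt(s2) * tf.random.normal(tf.shape(y))
    resid = y - tf.linalg.matvec(X, w) - e
    K = X @ S @ tf.transpose(X) + s2 * tf.eye(tf.shape(X)[0])
    solv = tf.linalg.solve(K, resid[:, None])[:, 0]
    return w + tf.linalg.matvec(S @ tf.transpose(X), solv)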