Example #1
def sinkhorn_logspace(logits_rows, logits_cols, costs, n_steps,
                      entropy_strength):
    """Sinkhorn algorithm for (unbalanced) entropy-regularized optimal transport.

  The updates are computed in log-space and are thus more stable.

  Args:
    logits_rows: (..., n) tensor with the logits of the row-sum constraint
    logits_cols: (..., m) tensor with the logits of the column-sum constraint
    costs: (..., n, m) tensor holding the transportation costs
    n_steps: How many Sinkhorn iterations to perform.
    entropy_strength: The strength of the entropic regularizer

  Returns:
    (..., n, m) tensor with the computed optimal transport matrices
  """
    assert n_steps > 0
    assert entropy_strength > 0

    logits_rows = tf.expand_dims(logits_rows, axis=-1)
    logits_cols = tf.expand_dims(logits_cols, axis=-2)
    log_kernel = -costs / entropy_strength + logits_rows + logits_cols

    log_lbd_cols = tf.zeros_like(logits_cols)
    for _ in range(n_steps):
        log_lbd_rows = logits_rows - tf.reduce_logsumexp(
            log_kernel + log_lbd_cols, axis=-1, keepdims=True)
        log_lbd_cols = logits_cols - tf.reduce_logsumexp(
            log_kernel + log_lbd_rows, axis=-2, keepdims=True)
    return tf.exp(log_lbd_cols + log_kernel + log_lbd_rows)
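A minimal usage sketch (not part of the original snippet; the marginals, costs, and hyperparameters below are made up for illustration):

import tensorflow as tf

n, m = 4, 5
logits_rows = tf.zeros([n])            # uniform row marginal (up to normalization)
logits_cols = tf.zeros([m])            # uniform column marginal
costs = tf.random.uniform([n, m])      # arbitrary transport costs

plan = sinkhorn_logspace(logits_rows, logits_cols, costs,
                         n_steps=50, entropy_strength=0.1)
print(plan.shape)  # (4, 5)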
Example #2
 def testNoWeights(self):
     logx_ = np.array([[0., -1, 1000.], [0, 1, -1000.], [-5, 0, 5]])
     logx = tf.constant(logx_)
     expected = tf.reduce_logsumexp(logx, axis=-1)
     grad_expected, _ = tfp.math.value_and_gradient(
         lambda logx: tf.reduce_logsumexp(logx, axis=-1), logx)
     actual, actual_sgn = tfp.math.reduce_weighted_logsumexp(
         logx, axis=-1, return_sign=True)
     grad_actual, _ = tfp.math.value_and_gradient(
         lambda logx: tfp.math.reduce_weighted_logsumexp(logx, axis=-1),
         logx)
     [
         actual_,
         actual_sgn_,
         grad_actual_,
         expected_,
         grad_expected_,
     ] = self.evaluate([
         actual,
         actual_sgn,
         grad_actual,
         expected,
         grad_expected,
     ])
     self.assertAllEqual(expected_, actual_)
     self.assertAllEqual(grad_expected_, grad_actual_)
     self.assertAllEqual([1., 1, 1], actual_sgn_)
Example #3
def forward(a0, a, e, x, num):
    """Markov chain forward algorithm, with tf graph construction."""
    emit_lp = tf.matmul(e, x, transpose_b=True)
    f = a0 + emit_lp[:, 0]
    for j in range(1, num):
        f = tf.reduce_logsumexp(f[:, None] + a, axis=0) + emit_lp[:, j]

    return tf.reduce_logsumexp(f)
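A toy invocation with made-up parameters (two hidden states, alphabet of size 3, a length-4 one-hot observation sequence), just to show the expected shapes:

import tensorflow as tf

num_states, alphabet, num = 2, 3, 4
a0 = tf.math.log(tf.constant([0.6, 0.4]))                        # initial state log-probs
a = tf.math.log(tf.constant([[0.7, 0.3], [0.2, 0.8]]))           # transition log-probs
e = tf.math.log(tf.fill([num_states, alphabet], 1. / alphabet))  # emission log-probs
x = tf.one_hot([0, 2, 1, 1], depth=alphabet)                     # one-hot observations
log_marginal = forward(a0, a, e, x, num)                         # scalar log-likelihood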
Example #4
def forward_mean(a0, a, e, num):
    """Mean of HMM (by individual location)."""
    f = []
    f.append(a0[None, :])
    for j in range(1, num):
        f.append(
            tf.reduce_logsumexp(f[-1][0, :, None] + a, axis=0, keepdims=True))
    fm = tf.concat(f, axis=0)
    hmean = tf.reduce_logsumexp(fm[:, :, None] + e[None, :, :], axis=1)

    return tf.exp(hmean)
Example #5
    def compute_log_conditional_distribution(self, X):
        print("---Tracing---log_conditional_distribub")
        before_reduce_sum = tf.map_fn(lambda y: self.compute_log_pdf(X, y),
                                      self.y_unique,
                                      fn_output_signature=tf.float32)

        reduce_sum = tf.reduce_logsumexp(
            before_reduce_sum, axis=-1) + tf.expand_dims(
                tf.math.log(self.logits_y + self.stable), axis=1)

        log_joint_prob = tf.transpose(reduce_sum)

        return log_joint_prob - tf.reduce_logsumexp(
            log_joint_prob, axis=-1, keepdims=True)
Example #6
def reduce_logmeanexp(input_tensor, axis=None, keepdims=False, name=None):
    """Computes `log(mean(exp(input_tensor)))`.

  Reduces `input_tensor` along the dimensions given in `axis`.  Unless
  `keepdims` is true, the rank of the tensor is reduced by 1 for each entry in
  `axis`. If `keepdims` is true, the reduced dimensions are retained with length
  1.

  If `axis` has no entries, all dimensions are reduced, and a tensor with a
  single element is returned.

  This function is more numerically stable than `log(reduce_mean(exp(input)))`.
  It avoids overflows caused by taking the exp of large inputs and underflows
  caused by taking the log of small inputs.

  Args:
    input_tensor: The tensor to reduce. Should have numeric type.
    axis: The dimensions to reduce. If `None` (the default), reduces all
      dimensions. Must be in the range `[-rank(input_tensor),
      rank(input_tensor))`.
    keepdims:  Boolean.  Whether to keep the axis as singleton dimensions.
      Default value: `False` (i.e., squeeze the reduced dimensions).
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., `'reduce_logmeanexp'`).

  Returns:
    log_mean_exp: The reduced tensor.
  """
    with tf.name_scope(name or 'reduce_logmeanexp'):
        lse = tf.reduce_logsumexp(input_tensor, axis=axis, keepdims=keepdims)
        n = prefer_static.size(input_tensor) // prefer_static.size(lse)
        log_n = tf.math.log(tf.cast(n, lse.dtype))
        return lse - log_n
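The snippet relies on TFP's internal `prefer_static` module; the public wrapper `tfp.math.reduce_logmeanexp` exposes the same computation. A quick sanity check of the identity log(mean(exp(x))) == logsumexp(x) - log(n) on small, illustrative inputs:

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

x = tf.constant([[0., 1., 2.], [3., 4., 5.]])
stable = tfp.math.reduce_logmeanexp(x, axis=-1)
naive = tf.math.log(tf.reduce_mean(tf.exp(x), axis=-1))   # unstable reference
np.testing.assert_allclose(stable.numpy(), naive.numpy(), rtol=1e-6)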
Example #7
  def _entropy(self):
    if self._logits is None:
      # If we only have probs, there's not much we can do to ensure numerical
      # precision.
      probs = tf.convert_to_tensor(self._probs)
      return -tf.reduce_sum(
          tf.math.multiply_no_nan(tf.math.log(probs), probs),
          axis=-1)
    # The following result can be derived as follows. Let s[i] be a logit.
    # The entropy is:
    #   H = -sum_i(p[i] * log(p[i]))
    #     = -sum_i(p[i] * (s[i] - logsumexp(s)))
    #     = logsumexp(s) - sum_i(p[i] * s[i])
    logits = tf.convert_to_tensor(self._logits)
    logits = logits - tf.reduce_max(logits, axis=-1, keepdims=True)
    lse_logits = tf.reduce_logsumexp(logits, axis=-1)

    # TODO(b/161014180): Workaround to support correct gradient calculations
    # with -inf logits.
    masked_logits = tf.where(
        (tf.math.is_inf(logits) & (logits < 0)),
        tf.cast(1.0, dtype=logits.dtype), logits)
    return lse_logits - tf.reduce_sum(
        tf.math.multiply_no_nan(masked_logits, tf.math.exp(logits)),
        axis=-1) / tf.math.exp(lse_logits)
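A quick numerical check of the identity used above, H = logsumexp(s) - sum_i p[i] * s[i], on made-up logits:

import tensorflow as tf

logits = tf.constant([1.0, 2.0, 0.5])
probs = tf.nn.softmax(logits)
h_direct = -tf.reduce_sum(probs * tf.math.log(probs))
h_logits = tf.reduce_logsumexp(logits) - tf.reduce_sum(probs * logits)
tf.debugging.assert_near(h_direct, h_logits)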
Example #8
def _masked_logmeanexp(input_tensor, mask_tensor, axis=None):
    """Compute log(mean(exp(input_tensor))) on masked elements.

  Args:
    input_tensor: `float`-like `Tensor` to be reduced.
    mask_tensor: `bool`-like `Tensor` of the same shape of `input_tensor`.
      Only the elements from `input_tensor` with the mask of `True` in
      `mask_tensor` will be selected for calculation.
    axis: The dimensions to sum across.
      Default value: `None`, i.e. all dimensions will be reduced.

  Returns:
    reduced_tensor: `float`-like `Tensor` contains the reduced result.
  """
    with tf.name_scope('masked_logmeanexp'):
        input_tensor = tf.convert_to_tensor(input_tensor,
                                            dtype_hint=tf.float32,
                                            name='input_tensor')
        mask_tensor = tf.convert_to_tensor(mask_tensor,
                                           dtype_hint=tf.bool,
                                           name='mask_tensor')
        # To mask out a value from log space, one could push the value to -inf.
        masked_input = _get_masked_scores(input_tensor, mask_tensor)
        log_n = tf.math.log(
            tf.cast(tf.reduce_sum(tf.where(mask_tensor, 1, 0), axis=axis),
                    masked_input.dtype))
        return tf.reduce_logsumexp(masked_input, axis=axis) - log_n
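The helper `_get_masked_scores` is not shown above. A plausible stand-in (an assumption, not the library's exact helper) pushes masked-out entries to -inf so they carry zero weight in the subsequent logsumexp:

import tensorflow as tf

def _get_masked_scores(scores, mask):
    # Masked-out positions become -inf, i.e. zero weight in exp-space.
    neg_inf = tf.cast(float('-inf'), scores.dtype)
    return tf.where(mask, scores, neg_inf)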
Example #9
 def _log_prob(self, x):
   with tf.control_dependencies(self._runtime_assertions):
     x = self._pad_sample_dims(x)
     log_prob_x = self.components_distribution.log_prob(x)  # [S, B, k]
     log_mix_prob = tf.nn.log_softmax(
         self.mixture_distribution.logits_parameter(), axis=-1)  # [B, k]
     return tf.reduce_logsumexp(log_prob_x + log_mix_prob, axis=-1)  # [S, B]
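The same identity, log p(x) = logsumexp_k(log pi_k + log p_k(x)), checked against `tfd.MixtureSameFamily` on made-up parameters:

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

mix = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(logits=[0.3, -0.3]),
    components_distribution=tfd.Normal(loc=[-1., 1.], scale=[0.5, 0.5]))
x = tf.constant(0.2)
by_hand = tf.reduce_logsumexp(
    tf.nn.log_softmax(tf.constant([0.3, -0.3])) +
    tfd.Normal(loc=[-1., 1.], scale=[0.5, 0.5]).log_prob(x))
tf.debugging.assert_near(mix.log_prob(x), by_hand)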
Example #10
    def importance_weighted_divergence_fn(q_samples):
        q_lp = precomputed_surrogate_log_prob
        if q_lp is None:
            q_lp = surrogate_posterior.log_prob(q_samples)
        target_log_prob = nest_util.call_fn(target_log_prob_fn, q_samples)
        log_weights = target_log_prob - q_lp

        # Explicitly break out `importance_sample_size` as a separate axis.
        log_weights = tf.reshape(
            log_weights,
            ps.concat([[-1, importance_sample_size],
                       ps.shape(log_weights)[1:]],
                      axis=0))
        log_sum_weights = tf.reduce_logsumexp(log_weights, axis=1)
        log_avg_weights = log_sum_weights - tf.math.log(
            tf.cast(importance_sample_size, dtype=log_weights.dtype))

        if gradient_estimator == GradientEstimators.DOUBLY_REPARAMETERIZED:
            # Adapted from original implementation at
            # https://github.com/google-research/google-research/blob/master/dreg_estimators/model.py
            normalized_weights = tf.stop_gradient(
                tf.nn.softmax(log_weights, axis=1))
            log_weights_with_stopped_q = tf.reshape(
                target_log_prob -
                stopped_surrogate_posterior.log_prob(q_samples),
                ps.shape(log_weights))
            dreg_objective = tf.reduce_sum(log_weights_with_stopped_q *
                                           tf.square(normalized_weights),
                                           axis=1)
            # Replace the objective's gradient with the doubly-reparameterized
            # gradient.
            log_avg_weights = tf.stop_gradient(log_avg_weights) + (
                dreg_objective - tf.stop_gradient(dreg_objective))

        return discrepancy_fn(log_avg_weights)
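Isolating the averaging step: in log space, averaging K importance weights is a logsumexp minus log K. A small sketch with random stand-in weights (the sample axis is placed first here purely for simplicity):

import tensorflow as tf

importance_sample_size = 4
log_weights = tf.random.normal([importance_sample_size, 3])   # [K, batch]
log_avg_weights = (tf.reduce_logsumexp(log_weights, axis=0) -
                   tf.math.log(tf.cast(importance_sample_size, log_weights.dtype)))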
Example #11
 def _sample_3d(self, n, mean_direction, concentration, seed=None):
     """Specialized inversion sampler for 3D."""
     u_shape = ps.concat(
         [[n],
          self._batch_shape_tensor(mean_direction=mean_direction,
                                   concentration=concentration)],
         axis=0)
     z = samplers.uniform(u_shape, seed=seed, dtype=self.dtype)
     # TODO(bjp): Higher-order odd dim analytic CDFs are available in [1], could
     # be bisected for bounded sampling runtime (i.e. not rejection sampling).
     # [1]: Inversion sampler via: https://ieeexplore.ieee.org/document/7347705/
     # The inversion is: u = 1 + log(z + (1-z)*exp(-2*kappa)) / kappa
     # We must protect against both kappa and z being zero.
     safe_conc = tf.where(concentration > 0, concentration,
                          tf.ones_like(concentration))
     safe_z = tf.where(z > 0, z, tf.ones_like(z))
     safe_u = 1 + tf.reduce_logsumexp(
         [tf.math.log(safe_z),
          tf.math.log1p(-safe_z) - 2 * safe_conc],
         axis=0) / safe_conc
     # Limit of the above expression as kappa->0 is 2*z-1
     u = tf.where(concentration > 0., safe_u, 2 * z - 1)
     # Limit of the expression as z->0 is -1.
     u = tf.where(tf.equal(z, 0), -tf.ones_like(u), u)
     if not self._allow_nan_stats:
         u = tf.debugging.check_numerics(u, 'u in _sample_3d')
     return u[..., tf.newaxis]
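The `reduce_logsumexp` above is just a numerically stable evaluation of log(z + (1 - z) * exp(-2 * kappa)); a quick check with made-up values:

import tensorflow as tf

z = tf.constant(0.3)
kappa = tf.constant(2.0)
stable = tf.reduce_logsumexp(
    [tf.math.log(z), tf.math.log1p(-z) - 2. * kappa], axis=0)
naive = tf.math.log(z + (1. - z) * tf.exp(-2. * kappa))
tf.debugging.assert_near(stable, naive)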
Example #12
 def _log_variance(self):
     # Following calculation is based on law of total variance:
     #
     # Var[Z] = E[Var[Z | V]] + Var[E[Z | V]]
     #
     # where,
     #
     # Z|v ~ interpolate_affine[v](distribution)
     # V ~ mixture_distribution
     #
     # thus,
     #
     # E[Var[Z | V]] = sum{ prob[d] Var[d] : d=0, ..., deg-1 }
     # Var[E[Z | V]] = sum{ prob[d] (Mean[d] - Mean)**2 : d=0, ..., deg-1 }
     v = tf.stack(
         [
             # log(self.distribution.variance()) = log(Var[d]) = log(rate[d])
             self.distribution.log_rate,
             # log((Mean[d] - Mean)**2)
             2. * tf.math.log(
                 tf.abs(self.distribution.mean() -
                        self._mean()[..., tf.newaxis])),
         ],
         axis=-1)
     return tf.reduce_logsumexp(
         self.mixture_distribution.logits[..., tf.newaxis] + v,
         axis=[-2, -1])
Example #13
    def _log_prob(self, y, **kwargs):
        distribution_kwargs, bijector_kwargs = self._kwargs_split_fn(kwargs)

        # For caching to work, it is imperative that the bijector is the first to
        # modify the input.
        x = self.bijector.inverse(y, **bijector_kwargs)
        event_ndims = tf.nest.map_structure(ps.rank_from_shape,
                                            self._event_shape_tensor(),
                                            self.event_shape)

        ildj = self.bijector.inverse_log_det_jacobian(y,
                                                      event_ndims=event_ndims,
                                                      **bijector_kwargs)
        if self.bijector._is_injective:  # pylint: disable=protected-access
            base_log_prob = self.distribution.log_prob(x,
                                                       **distribution_kwargs)
            return base_log_prob + tf.cast(ildj, base_log_prob.dtype)

        # Compute log_prob on each element of the inverse image.
        lp_on_fibers = []
        for x_i, ildj_i in zip(x, ildj):
            base_log_prob = self.distribution.log_prob(x_i,
                                                       **distribution_kwargs)
            lp_on_fibers.append(base_log_prob +
                                tf.cast(ildj_i, base_log_prob.dtype))
        return tf.reduce_logsumexp(tf.stack(lp_on_fibers), axis=0)
Example #14
    def _log_prob(self, y, **kwargs):
        distribution_kwargs, bijector_kwargs = self._kwargs_split_fn(kwargs)
        override_event_shape = tf.convert_to_tensor(self._override_event_shape)
        override_batch_shape = tf.convert_to_tensor(self._override_batch_shape)
        base_is_scalar_batch = self.distribution.is_scalar_batch()

        # For caching to work, it is imperative that the bijector is the first to
        # modify the input.
        x = self.bijector.inverse(y, **bijector_kwargs)
        event_ndims = self._maybe_get_static_event_ndims(override_event_shape)

        ildj = self.bijector.inverse_log_det_jacobian(y,
                                                      event_ndims=event_ndims,
                                                      **bijector_kwargs)
        if self.bijector._is_injective:  # pylint: disable=protected-access
            return self._finish_log_prob_for_one_fiber(y, x, ildj, event_ndims,
                                                       override_event_shape,
                                                       override_batch_shape,
                                                       base_is_scalar_batch,
                                                       **distribution_kwargs)

        lp_on_fibers = [
            self._finish_log_prob_for_one_fiber(  # pylint: disable=g-complex-comprehension
                y, x_i, ildj_i, event_ndims, override_event_shape,
                override_batch_shape, base_is_scalar_batch,
                **distribution_kwargs) for x_i, ildj_i in zip(x, ildj)
        ]
        return tf.reduce_logsumexp(tf.stack(lp_on_fibers), axis=0)
Example #15
 def _log_cdf(self, x):
     x = self._pad_sample_dims(x)
     log_cdf_x = self.components_distribution.log_cdf(x)  # [S, B, k]
     log_mix_prob = tf.nn.log_softmax(self.mixture_distribution.logits,
                                      axis=-1)  # [B, k]
     return tf.reduce_logsumexp(input_tensor=log_cdf_x + log_mix_prob,
                                axis=-1)  # [S, B]
Example #16
def make_hmm_params(vxln,
                    vcln,
                    uln,
                    rln,
                    lln,
                    transfer_mats,
                    eps=1e-32,
                    dtype=tf.float32):
    """Assemble the HMM parameters based on the s, u and r parameters."""

    # Assemble transition matrices.
    ulnfull = uln[:, None, :] * tf.ones((1, 3, 1), dtype=dtype)
    rlnfull = rln[:, None, :] * tf.ones((1, 3, 1), dtype=dtype)
    a0 = (tf.einsum('ijk,ijkl->l', ulnfull, transfer_mats['ut0']) +
          tf.einsum('ijk,ijkl->l', rlnfull, transfer_mats['rt0']) +
          (-1 / eps) * transfer_mats['nt0'])
    a = (tf.einsum('ijk,ijklf->lf', ulnfull, transfer_mats['ut']) +
         tf.einsum('ijk,ijklf->lf', rlnfull, transfer_mats['rt']) +
         (-1 / eps) * transfer_mats['nt'])

    # Assemble emission matrix.
    seqln = (tf.einsum('ij,ik->kj', vxln, transfer_mats['vxt']) +
             tf.einsum('ij,ik->kj', vcln, transfer_mats['vct']))
    if lln is not None:
        # Substitution matrix.
        e = tf.reduce_logsumexp(seqln[:, :, None] + lln[None, :, :], axis=1)
    else:
        e = seqln

    return a0, a, e
Example #17
 def testNoWeights(self):
   logx_ = np.array([[0., -1, 1000.],
                     [0, 1, -1000.],
                     [-5, 0, 5]])
   logx = tf.constant(logx_)
   with tf.GradientTape() as tape:
     tape.watch(logx)
     expected = tf.reduce_logsumexp(input_tensor=logx, axis=-1)
   grad_expected = tape.gradient(expected, logx)
   with tf.GradientTape() as tape:
     tape.watch(logx)
     actual, actual_sgn = tfp.math.reduce_weighted_logsumexp(
         logx, axis=-1, return_sign=True)
   grad_actual = tape.gradient(actual, logx)
   [
       actual_,
       actual_sgn_,
       grad_actual_,
       expected_,
       grad_expected_,
   ] = self.evaluate([
       actual,
       actual_sgn,
       grad_actual,
       expected,
       grad_expected,
   ])
   self.assertAllEqual(expected_, actual_)
   self.assertAllEqual(grad_expected_, grad_actual_)
   self.assertAllEqual([1., 1, 1], actual_sgn_)
Example #18
    def _entropy(self):
        if self._logits is None:
            # If we only have probs, there's not much we can do to ensure numerical
            # precision.
            probs = tf.convert_to_tensor(self._probs)
            return -tf.reduce_sum(
                tf.math.multiply_no_nan(tf.math.log(probs), probs), axis=-1)

        # The following result can be derived as follows. Write log(p[i]) as:
        # s[i]-m-lse(s-m) where m=max(s), then you have:
        #   sum_i exp(s[i]-m-lse(s-m)) (s[i] - m - lse(s-m))
        #   = -m - lse(s-m) + sum_i s[i] exp(s[i]-m-lse(s-m))
        #   = -m - lse(s-m) + (1/exp(lse(s-m))) sum_i s[i] exp(s[i]-m)
        #   = -m - lse(s-m) + (1/sumexp(s-m)) sum_i s[i] exp(s[i]-m)
        # Write x[i]=s[i]-m then you have:
        #   = -m - lse(x) + (1/sum_exp(x)) sum_i s[i] exp(x[i])
        # Negating this result gives the Shannon (discrete) entropy.
        logits = tf.convert_to_tensor(self._logits)
        m = tf.reduce_max(logits, axis=-1, keepdims=True)
        x = logits - m
        lse_logits = m[..., 0] + tf.reduce_logsumexp(x, axis=-1)
        sum_exp_x = tf.reduce_sum(tf.math.exp(x), axis=-1)
        return lse_logits - tf.reduce_sum(tf.math.multiply_no_nan(
            logits, tf.math.exp(x)),
                                          axis=-1) / sum_exp_x
Example #19
 def _log_prob(self, x):
     x = self._pad_sample_dims(x)
     log_prob_x = self.components_distribution.log_prob(x)  # [S, B, k]
     log_mix_prob = tf.math.log_softmax(
         self.mixture_distribution.logits_parameter(), axis=-1)  # [B, k]
     return tf.reduce_logsumexp(log_prob_x + log_mix_prob,
                                axis=-1)  # [S, B]
Example #20
 def _assert_valid_sample(self, x):
   if not self.validate_args:
     return x
   return distribution_util.with_dependencies([
       assert_util.assert_non_positive(x),
       assert_util.assert_near(
           tf.zeros([], dtype=self.dtype), tf.reduce_logsumexp(x, axis=[-1])),
   ], x)
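What the assertion enforces, in isolation: a valid sample is a vector of log-probabilities, so exp(x) sums to one and logsumexp(x) is (near) zero. Illustrative values:

import tensorflow as tf

x = tf.math.log(tf.constant([0.2, 0.3, 0.5]))
tf.debugging.assert_near(tf.reduce_logsumexp(x, axis=[-1]), 0.)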
Example #21
 def _mean(self, distributions=None):
   if distributions is None:
     distributions = self.poisson_and_mixture_distributions()
   dist, mixture_dist = distributions
   return tf.exp(
       tf.reduce_logsumexp(
           mixture_dist.logits + dist.log_rate,
           axis=-1))
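The identity relies on the mixture logits being normalized log-probabilities; under that assumption the mean is sum_k prob[k] * rate[k]. A small numeric check on made-up values:

import tensorflow as tf

rate = tf.constant([1.0, 2.0, 3.0])
log_probs = tf.nn.log_softmax(tf.constant([0.1, -0.4, 0.3]))
mean_direct = tf.reduce_sum(tf.exp(log_probs) * rate, axis=-1)
mean_via_lse = tf.exp(tf.reduce_logsumexp(log_probs + tf.math.log(rate), axis=-1))
tf.debugging.assert_near(mean_via_lse, mean_direct)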
Example #22
    def _log_prob(self, value):
        # The argument `value` is a tensor of sequences of observations.
        # `observation_batch_shape` is the shape of that tensor with the
        # sequence part removed.
        # `observation_batch_shape` is then broadcast to the full batch shape
        # to give the `batch_shape` that defines the shape of the result.

        observation_tensor_shape = tf.shape(value)
        observation_distribution = self.observation_distribution
        underlying_event_rank = tf.size(
            observation_distribution.event_shape_tensor())
        observation_batch_shape = observation_tensor_shape[:-1 -
                                                           underlying_event_rank]
        # value :: observation_batch_shape num_steps observation_event_shape
        batch_shape = tf.broadcast_dynamic_shape(observation_batch_shape,
                                                 self.batch_shape_tensor())
        num_states = self.transition_distribution.batch_shape_tensor()[-1]
        log_init = _extract_log_probs(num_states, self.initial_distribution)
        # log_init :: batch_shape num_states
        log_init = tf.broadcast_to(
            log_init, tf.concat([batch_shape, [num_states]], axis=0))
        log_transition = _extract_log_probs(num_states,
                                            self.transition_distribution)

        # `observation_event_shape` is the shape of each sequence of observations
        # emitted by the model.
        observation_event_shape = observation_tensor_shape[
            -1 - underlying_event_rank:]
        working_obs = tf.broadcast_to(
            value, tf.concat([batch_shape, observation_event_shape], axis=0))
        # working_obs :: batch_shape observation_event_shape
        r = underlying_event_rank

        # Move index into sequence of observations to front so we can apply
        # tf.foldl
        working_obs = distribution_util.move_dimension(working_obs, -1 - r, 0)
        # working_obs :: num_steps batch_shape underlying_event_shape
        working_obs = tf.expand_dims(working_obs, -1 - r)
        # working_obs :: num_steps batch_shape 1 underlying_event_shape

        observation_probs = observation_distribution.log_prob(working_obs)

        # observation_probs :: num_steps batch_shape num_states

        def forward_step(log_prev_step, log_prob_observation):
            return _log_vector_matrix(log_prev_step,
                                      log_transition) + log_prob_observation

        fwd_prob = tf.foldl(forward_step,
                            observation_probs,
                            initializer=log_init)
        # fwd_prob :: batch_shape num_states

        log_prob = tf.reduce_logsumexp(fwd_prob, axis=-1)
        # log_prob :: batch_shape

        return log_prob
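`_log_vector_matrix` and `_extract_log_probs` are helpers defined elsewhere in the library. For the former, a minimal stand-in (assumed, not the exact implementation) is a vector-matrix product carried out in log space:

import tensorflow as tf

def _log_vector_matrix(log_vec, log_mat):
    # log( exp(log_vec) @ exp(log_mat) ), computed without leaving log space.
    return tf.reduce_logsumexp(log_vec[..., :, None] + log_mat, axis=-2)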
Example #23
    def test_patching(self, exp, log, expm1, log1p, logsumexp, softplus):

        exp_calls = 0
        expm1_calls = 0
        log_calls = 0
        log1p_calls = 0
        logsumexp_calls = 0
        softplus_calls = 0
        with tfp.experimental.math.patch_manual_special_functions():
            tf.exp(0.)
            exp_calls += 1
            self.assertEqual(exp_calls, exp.call_count)

            tf.math.exp(0.)
            exp_calls += 1
            self.assertEqual(exp_calls, exp.call_count)

            tf.math.log(0.)
            log_calls += 1
            self.assertEqual(log_calls, log.call_count)

            tf.math.expm1(0.)
            expm1_calls += 1
            self.assertEqual(expm1_calls, expm1.call_count)

            tf.math.log1p(0.)
            log1p_calls += 1
            self.assertEqual(log1p_calls, log1p.call_count)

            tf.math.reduce_logsumexp(0.)
            logsumexp_calls += 1
            self.assertEqual(logsumexp_calls, logsumexp.call_count)

            tf.reduce_logsumexp(0.)
            logsumexp_calls += 1
            self.assertEqual(logsumexp_calls, logsumexp.call_count)

            tf.math.softplus(0.)
            softplus_calls += 1
            self.assertEqual(softplus_calls, softplus.call_count)

            tf.nn.softplus(0.)
            softplus_calls += 1
            self.assertEqual(softplus_calls, softplus.call_count)
Example #24
 def _forward_log_det_jacobian(self, x):
     # This code is similar to tf.math.log_softmax but different because we have
     # an implicit zero column to handle. I.e., instead of:
     #   reduce_sum(logits - reduce_sum(exp(logits), dim))
     # we must do:
     #   log_normalization = 1 + reduce_sum(exp(logits))
     #   -log_normalization + reduce_sum(logits - log_normalization)
     np1 = prefer_static.cast(1 + prefer_static.shape(x)[-1], dtype=x.dtype)
     return (0.5 * prefer_static.log(np1) + tf.reduce_sum(x, axis=-1) -
             np1 * tf.math.softplus(tf.reduce_logsumexp(x, axis=-1)))
Example #25
    def log_joint_prob(self, X):
        print("----Tracing-log_joint_prob")
        before_reduce_sum = tf.map_fn(lambda y: self.compute_log_pdf(X, y),
                                      self.y_unique,
                                      fn_output_signature=tf.float32)

        reduce_sum = tf.reduce_logsumexp(
            before_reduce_sum, axis=-1) + tf.expand_dims(
                tf.math.log(self.logits_y + self.stable), axis=1)
        return tf.transpose(reduce_sum)
Example #26
 def test1DLarge(self):
     # This test ensures that the operation is correct even when the naive
     # implementation would overflow.
     x = tf.convert_to_tensor(np.arange(20) * 20.0, dtype=tf.float32)
     result_fused = self.evaluate(tfp.math.log_cumsum_exp(x))
     result_map = self.evaluate(
         tf.map_fn(lambda i: tf.reduce_logsumexp(x[:i + 1]),
                   tf.range(tf.shape(x)[0]),
                   dtype=x.dtype))
     self.assertAllClose(result_fused, result_map)
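Core TensorFlow also ships a fused running logsumexp, `tf.math.cumulative_logsumexp`; on recent versions both should agree up to numerical precision (a small check on toy data):

import tensorflow as tf
import tensorflow_probability as tfp

x = tf.constant([0., 10., 20., 30.])
tf.debugging.assert_near(tfp.math.log_cumsum_exp(x),
                         tf.math.cumulative_logsumexp(x))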
Example #27
def encode(x,
           uln0,
           rln0,
           lln0,
           latent_length,
           latent_alphabet_size,
           alphabet_size,
           padded_data_length,
           transfer_mats,
           dtype=tf.float64,
           eps=1e-32):
    """First layer of encoder, using the MuE mean."""

    # Set initial sequence (replace -inf with a large negative number).
    vxln = tf.maximum(tf.math.log(x), -1e32)

    # Set insert biases to uniform distribution.
    vcln = -np.log(alphabet_size) * tf.ones_like(vxln)

    # Set deletion and insertion parameters.
    uln = tf.ones((padded_data_length, 2),
                  dtype=dtype) * (uln0 - tf.reduce_logsumexp(uln0))[None, :]
    rln = tf.ones((padded_data_length, 2),
                  dtype=dtype) * (rln0 - tf.reduce_logsumexp(rln0))[None, :]
    lln = lln0 - tf.reduce_logsumexp(lln0, axis=1, keepdims=True)

    # Build HiddenMarkovModel, with one-hot encoded output.
    a0_enc, a_enc, e_enc = make_hmm_params(vxln,
                                           vcln,
                                           uln,
                                           rln,
                                           lln,
                                           transfer_mats,
                                           eps=eps,
                                           dtype=dtype)

    hmm_enc = tfpd.HiddenMarkovModel(tfpd.Categorical(logits=a0_enc),
                                     tfpd.Categorical(logits=a_enc),
                                     tfpd.OneHotCategorical(logits=e_enc),
                                     latent_length)

    return hmm_mean(hmm_enc, latent_length)
Example #28
 def neg_log_likelihood(state):
     state_ext = tf.expand_dims(state, 0)
     linear_part = tf.matmul(state_ext, x_data)
     linear_part_ex = tf.stack(
         [tf.zeros_like(linear_part), linear_part], axis=0)
     term1 = tf.squeeze(
         tf.matmul(tf.reduce_logsumexp(linear_part_ex, axis=0), y_data),
         -1)
     term2 = (0.5 * tf.reduce_sum(state_ext * state_ext, axis=-1) -
              tf.reduce_sum(linear_part, axis=-1))
     return tf.squeeze(term1 + term2)
Example #29
 def _log_prob(self, x):
     x = tf.convert_to_tensor(x, name='x')
     distribution_log_probs = [d.log_prob(x) for d in self.components]
     cat_log_probs = self._cat_probs(log_probs=True)
     final_log_probs = [
         cat_lp + d_lp
         for (cat_lp, d_lp) in zip(cat_log_probs, distribution_log_probs)
     ]
     concat_log_probs = tf.stack(final_log_probs, 0)
     log_sum_exp = tf.reduce_logsumexp(concat_log_probs, axis=[0])
     return log_sum_exp
Example #30
 def _log_cdf(self, x):
     x = tf.convert_to_tensor(x, name='x')
     distribution_log_cdfs = [d.log_cdf(x) for d in self.components]
     cat_log_probs = self._cat_probs(log_probs=True)
     final_log_cdfs = [
         cat_lp + d_lcdf
         for (cat_lp, d_lcdf) in zip(cat_log_probs, distribution_log_cdfs)
     ]
     concatted_log_cdfs = tf.stack(final_log_cdfs, axis=0)
     mixture_log_cdf = tf.reduce_logsumexp(concatted_log_cdfs, axis=[0])
     return mixture_log_cdf