Example #1
 def _mean(self):
     concentration = tf.convert_to_tensor(self.concentration)
     lim = tf.ones([], dtype=self.dtype)
     valid = concentration < lim
     safe_conc = tf.where(valid, concentration, tf.constant(.5, self.dtype))
     result = lambda: self.loc + self.scale / (1 - safe_conc)
     if self.allow_nan_stats:
         return tf.where(valid, result(),
                         tf.constant(float('nan'), self.dtype))
     with tf.control_dependencies([
             assert_util.assert_less(
                 concentration,
                 lim,
                 message='`mean` is undefined when `concentration >= 1`')
     ]):
         return result()
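The `safe_conc` substitution is the usual double-`tf.where` guard: both branches of a `tf.where` are evaluated, so an out-of-domain `concentration` would produce inf/nan in the unselected branch (and nan gradients everywhere); a harmless dummy value is substituted before the division and the invalid lanes are overwritten afterwards. A minimal sketch of the pattern, assuming nothing beyond TensorFlow:

import tensorflow as tf

c = tf.constant([0.5, 2.0])        # second entry violates c < 1
valid = c < 1.0
safe_c = tf.where(valid, c, 0.5)   # dummy value keeps the division finite
mean = tf.where(valid, 1.0 / (1.0 - safe_c), float('nan'))
# mean == [2.0, nan], and the valid lane's gradient stays nan-free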
Example #2
 def _multi_gamma_sequence(self, a, p, name="multi_gamma_sequence"):
     """Creates sequence used in multivariate (di)gamma; shape = shape(a)+[p]."""
     with self._name_and_control_scope(name):
         # Linspace only takes scalars, so we'll add in the offset afterwards.
         seq = tf.linspace(tf.constant(0., dtype=self.dtype), 0.5 - 0.5 * p,
                           tf.cast(p, tf.int32))
     return seq + tf.expand_dims(a, axis=-1)
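The offsets in `seq` are (1 - j)/2 for j = 1..p, matching the usual definition of the multivariate gamma function, Gamma_p(a) = pi^(p(p-1)/4) * prod_{j=1..p} Gamma(a + (1 - j)/2). A hedged NumPy/SciPy cross-check of that identity (`multigammaln` is SciPy's reference implementation):

import numpy as np
from scipy.special import gammaln, multigammaln

a, p = 3.5, 2
offsets = 0.5 * (1.0 - np.arange(1, p + 1))   # same values as `seq` above
direct = p * (p - 1) / 4.0 * np.log(np.pi) + gammaln(a + offsets).sum()
# direct agrees with multigammaln(a, p) up to floating-point error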
Example #3
def maybe_assert_bernoulli_param_correctness(is_init, validate_args, probs,
                                             probits):
    """Return assertions for `ProbitBernoulli`-type distributions."""
    if is_init:
        x, name = (probs, 'probs') if probits is None else (probits, 'probits')
        if not dtype_util.is_floating(x.dtype):
            raise TypeError(
                'Argument `{}` must have floating type.'.format(name))

    if not validate_args:
        return []

    assertions = []

    if probs is not None:
        if is_init != tensor_util.is_ref(probs):
            probs = tf.convert_to_tensor(probs)
            one = tf.constant(1., probs.dtype)
            assertions += [
                assert_util.assert_non_negative(
                    probs, message='probs has components less than 0.'),
                assert_util.assert_less_equal(
                    probs, one, message='probs has components greater than 1.')
            ]

    return assertions
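These helpers return a list of assertion ops rather than raising eagerly, so callers can gate downstream computation on them. A sketch of how such a list is typically consumed, using the public `tf.debugging` equivalents of the `assert_util` wrappers:

import tensorflow as tf

probs = tf.constant([0.2, 0.7])
assertions = [
    tf.debugging.assert_non_negative(probs),
    tf.debugging.assert_less_equal(probs, tf.constant(1.0, probs.dtype)),
]
with tf.control_dependencies(assertions):
    # Gated on the checks in graph mode; in eager mode they raise immediately.
    logits = tf.math.log(probs) - tf.math.log1p(-probs)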
Example #4
def _ndtr(x):
    """Implements ndtr core logic."""
    half_sqrt_2 = tf.constant(0.5 * np.sqrt(2.),
                              dtype=x.dtype,
                              name="half_sqrt_2")
    w = x * half_sqrt_2
    z = tf.abs(w)
    y = tf.where(z < half_sqrt_2, 1. + tf.math.erf(w),
                 tf.where(w > 0., 2. - tf.math.erfc(z), tf.math.erfc(z)))
    return 0.5 * y
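`_ndtr` computes the standard normal CDF, Phi(x) = 0.5 * (1 + erf(x / sqrt(2))). The three-way branch avoids catastrophic cancellation: for `w` large and negative, `1. + tf.math.erf(w)` adds two nearly opposite numbers, so `erfc` is used instead. A quick reference check against the single-formula version, assuming float64 inputs:

import numpy as np
import tensorflow as tf

x = tf.constant([-3.0, 0.0, 3.0], dtype=tf.float64)
phi = 0.5 * tf.math.erfc(-x / np.sqrt(2.0))  # mathematically equal to _ndtr(x)
# phi ~= [0.00135, 0.5, 0.99865]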
Example #5
 def _variance(self):
     concentration = tf.convert_to_tensor(self.concentration)
     lim = tf.constant(.5, self.dtype)
     valid = concentration < lim
     safe_conc = tf.where(valid, concentration,
                          tf.constant(.25, self.dtype))
     result = lambda: self.scale**2 / ((1 - safe_conc)**2 *
                                       (1 - 2 * safe_conc))
     if self.allow_nan_stats:
         return tf.where(valid, result(),
                         tf.constant(float('nan'), self.dtype))
     with tf.control_dependencies([
             assert_util.assert_less(
                 concentration,
                 lim,
                 message='`variance` is undefined when `concentration >= 0.5`')
     ]):
         return result()
Example #6
    def _forward_log_det_jacobian(self, x):
        # is_constant_jacobian = True for this bijector, hence the
        # `log_det_jacobian` need only be specified for a single input, as this will
        # be tiled to match `event_ndims`.
        if self.scale is None:
            return tf.constant(0., dtype=dtype_util.base_dtype(x.dtype))

        with tf.control_dependencies(self._maybe_collect_assertions()
                                     if self.validate_args else []):
            return self.scale.log_abs_determinant()
Example #7
 def _maybe_assert_valid_y(self, y):
     if not self.validate_args:
         return []
     is_positive = assert_util.assert_non_negative(
         y,
         message='Inverse transformation input must be greater than or '
         'equal to 0.')
     less_than_one = assert_util.assert_less_equal(
         y,
         tf.constant(1., y.dtype),
         message='Inverse transformation input must be less than or equal '
         'to 1.')
     return [is_positive, less_than_one]
Example #8
 def _log_prob(self, x):
     scale = tf.convert_to_tensor(self.scale)
     concentration = tf.convert_to_tensor(self.concentration)
     z = self._z(x, scale, concentration)
     eq_zero = tf.equal(concentration,
                        0)  # Concentration = 0 ==> Exponential.
     nonzero_conc = tf.where(eq_zero, tf.constant(1, self.dtype),
                             concentration)
     where_nonzero = (1 / nonzero_conc + 1) * tf.math.log1p(
         nonzero_conc * z)
     return -tf.math.log(scale) - tf.where(eq_zero, z, where_nonzero)
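With z = (x - loc)/scale, this is the Generalized Pareto log-density -log(scale) - (1/c + 1) * log1p(c * z), which tends to the exponential form -log(scale) - z as c -> 0; `nonzero_conc` merely keeps the unused branch of the `tf.where` free of nans at c = 0. A numeric sketch of that limit, with hypothetical values:

import tensorflow as tf

scale, conc, z = 2.0, 1e-6, 1.5
gpd = -tf.math.log(scale) - (1.0 / conc + 1.0) * tf.math.log1p(conc * z)
expo = -tf.math.log(scale) - z
# both ~= -2.1931 for small conc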
Example #9
 def _log_cdf(self, x):
     scale = tf.convert_to_tensor(self.scale)
     concentration = tf.convert_to_tensor(self.concentration)
     z = self._z(x, scale, concentration)
     eq_zero = tf.equal(concentration,
                        0)  # Concentration = 0 ==> Exponential.
     nonzero_conc = tf.where(eq_zero, tf.constant(1, self.dtype),
                             concentration)
     where_nonzero = tf.math.log1p(-(1 + nonzero_conc * z)**(-1 /
                                                             nonzero_conc))
     where_zero = tf.math.log1p(-tf.exp(-z))
     return tf.where(eq_zero, where_zero, where_nonzero)
Example #10
    def _variance(self):
        concentration = tf.convert_to_tensor(self.concentration)
        scale = tf.convert_to_tensor(self.scale)
        var = (tf.square(scale) / tf.square(concentration - 1.) /
               (concentration - 2.))
        if self.allow_nan_stats:
            assertions = []
        else:
            assertions = [
                assert_util.assert_less(
                    tf.constant(2., dtype=self.dtype),
                    concentration,
                    message='variance undefined when any concentration <= 2')
            ]

        with tf.control_dependencies(assertions):
            return tf.where(concentration > 2., var,
                            dtype_util.as_numpy_dtype(self.dtype)(np.nan))
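The formula is the inverse-gamma variance, Var = scale^2 / ((concentration - 1)^2 * (concentration - 2)), finite only when concentration > 2, which is exactly what the assertion enforces. A one-line worked check with hypothetical parameters:

# concentration = 3, scale = 2:
#   var = 2**2 / (3 - 1)**2 / (3 - 2) = 4 / 4 / 1 = 1.0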
Example #11
 def _sample_n(self, n, seed=None):
     # Inversion samples via inverse CDF.
     loc = tf.convert_to_tensor(self.loc)
     scale = tf.convert_to_tensor(self.scale)
     concentration = tf.convert_to_tensor(self.concentration)
     # Pre-broadcast to ensure we draw enough randomness.
     sample_shp = tf.concat(
         [[n],
          self._batch_shape_tensor(
              loc=loc, scale=scale, concentration=concentration)],
         axis=0)
     logu = tf.math.log1p(
         -tf.random.uniform(sample_shp, dtype=self.dtype, seed=seed))
     eq_zero = tf.equal(concentration, 0)
     safe_conc = tf.where(eq_zero, tf.constant(1, dtype=self.dtype),
                          concentration)
     where_nonzero = loc + scale / safe_conc * tf.math.expm1(
         -safe_conc * logu)
     where_zero = loc - scale * logu
     return tf.where(eq_zero, where_zero, where_nonzero)
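`logu = log(1 - U)` with U ~ Uniform(0, 1) feeds the closed-form quantile function: for c != 0, X = loc + scale/c * ((1 - U)^(-c) - 1), and in the c = 0 limit X = loc - scale * log(1 - U), the exponential case. A minimal sketch of that c = 0 branch, assuming only TensorFlow:

import tensorflow as tf

u = tf.random.uniform([100000], dtype=tf.float64)
samples = -tf.math.log1p(-u)   # X = -log(1 - U) ~ Exponential(rate=1)
# tf.reduce_mean(samples) ~= 1.0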
Example #12
    def _forward(self, x):
        x = tf.convert_to_tensor(x, name='x')
        batch_shape = prefer_static.shape(x)[:-1]

        # Pad zeros on the top row and right column.
        y = fill_triangular.FillTriangular().forward(x)
        rank = prefer_static.rank(y)
        paddings = tf.concat([
            tf.zeros(shape=(rank - 2, 2), dtype=tf.int32),
            tf.constant([[1, 0], [0, 1]], dtype=tf.int32)
        ],
                             axis=0)
        y = tf.pad(y, paddings)

        # Set diagonal to 1s.
        n = prefer_static.shape(y)[-1]
        diag = tf.ones(tf.concat([batch_shape, [n]], axis=-1), dtype=x.dtype)
        y = tf.linalg.set_diag(y, diag)

        # Normalize each row to have Euclidean (L2) norm 1.
        y /= tf.norm(y, axis=-1)[..., tf.newaxis]
        return y
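Because every row of `y` ends up with unit L2 norm, `y @ y^T` has ones on its diagonal, i.e. `y` is the Cholesky factor of a correlation matrix. A small NumPy sketch of that property for an arbitrary row-normalized lower-triangular matrix:

import numpy as np

y = np.tril(np.random.randn(4, 4))
np.fill_diagonal(y, 1.0)                        # unit diagonal, as in the bijector
y /= np.linalg.norm(y, axis=-1, keepdims=True)  # rows now have L2 norm 1
corr = y @ y.T
# np.allclose(np.diag(corr), 1.0) -> True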
Example #13
def maybe_assert_negative_binomial_param_correctness(is_init, validate_args,
                                                     total_count, probs,
                                                     logits):
    """Return assertions for `NegativeBinomial`-type distributions."""
    if is_init:
        x, name = (probs, 'probs') if logits is None else (logits, 'logits')
        if not dtype_util.is_floating(x.dtype):
            raise TypeError(
                'Argument `{}` must have floating type.'.format(name))

    if not validate_args:
        return []

    assertions = []
    if is_init != tensor_util.is_ref(total_count):
        total_count = tf.convert_to_tensor(total_count)
        assertions.extend([
            assert_util.assert_non_negative(
                total_count,
                message='`total_count` has components less than 0.'),
            distribution_util.assert_integer_form(
                total_count,
                message='`total_count` has fractional components.')
        ])
    if probs is not None:
        if is_init != tensor_util.is_ref(probs):
            probs = tf.convert_to_tensor(probs)
            one = tf.constant(1., probs.dtype)
            assertions.extend([
                assert_util.assert_non_negative(
                    probs, message='`probs` has components less than 0.'),
                assert_util.assert_less_equal(
                    probs,
                    one,
                    message='`probs` has components greater than 1.')
            ])

    return assertions
Example #14
    def __init__(self,
                 distribution,
                 bijector,
                 batch_shape=None,
                 event_shape=None,
                 kwargs_split_fn=_default_kwargs_split_fn,
                 validate_args=False,
                 parameters=None,
                 name=None):
        """Construct a Transformed Distribution.

    Args:
      distribution: The base distribution instance to transform. Typically an
        instance of `Distribution`.
      bijector: The object responsible for calculating the transformation.
        Typically an instance of `Bijector`.
      batch_shape: `integer` vector `Tensor` which overrides `distribution`
        `batch_shape`; valid only if `distribution.is_scalar_batch()`.
      event_shape: `integer` vector `Tensor` which overrides `distribution`
        `event_shape`; valid only if `distribution.is_scalar_event()`.
      kwargs_split_fn: Python `callable` which takes a kwargs `dict` and returns
        a tuple of kwargs `dict`s for each of the `distribution` and `bijector`
        parameters respectively.
        Default value: `_default_kwargs_split_fn` (i.e.,
            `lambda kwargs: (kwargs.get('distribution_kwargs', {}),
                             kwargs.get('bijector_kwargs', {}))`)
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      parameters: Locals dict captured by subclass constructor, to be used for
        copy/slice re-instantiation operations.
      name: Python `str` name prefixed to Ops created by this class. Default:
        `bijector.name + distribution.name`.
    """
        parameters = dict(locals()) if parameters is None else parameters
        name = name or (("" if bijector is None else bijector.name) +
                        (distribution.name or ""))
        with tf.name_scope(name) as name:
            self._kwargs_split_fn = (_default_kwargs_split_fn
                                     if kwargs_split_fn is None else
                                     kwargs_split_fn)
            # For convenience we define some handy constants.
            self._zero = tf.constant(0, dtype=tf.int32, name="zero")
            self._empty = tf.constant([], dtype=tf.int32, name="empty")

            # We will keep track of a static and dynamic version of
            # self._is_{batch,event}_override. This way we can do more prior to graph
            # execution, including possibly raising Python exceptions.

            self._override_batch_shape = self._maybe_validate_shape_override(
                batch_shape, distribution.is_scalar_batch(), validate_args,
                "batch_shape")
            self._is_batch_override = prefer_static.logical_not(
                prefer_static.equal(
                    prefer_static.rank_from_shape(self._override_batch_shape),
                    self._zero))
            self._is_maybe_batch_override = bool(
                tf.get_static_value(self._override_batch_shape) is None
                or tf.get_static_value(self._override_batch_shape).size != 0)

            self._override_event_shape = self._maybe_validate_shape_override(
                event_shape, distribution.is_scalar_event(), validate_args,
                "event_shape")
            self._is_event_override = prefer_static.logical_not(
                prefer_static.equal(
                    prefer_static.rank_from_shape(self._override_event_shape),
                    self._zero))
            self._is_maybe_event_override = bool(
                tf.get_static_value(self._override_event_shape) is None
                or tf.get_static_value(self._override_event_shape).size != 0)

            # To convert a scalar distribution into a multivariate distribution we
            # will draw dims from the sample dims, which are otherwise iid. This is
            # easy to do except in the case that the base distribution has batch dims
            # and we're overriding event shape. When that case happens the event dims
            # will incorrectly be to the left of the batch dims. In this case we'll
            # cyclically permute left the new dims.
            self._needs_rotation = prefer_static.reduce_all([
                self._is_event_override,
                prefer_static.logical_not(self._is_batch_override),
                prefer_static.logical_not(distribution.is_scalar_batch())
            ])
            override_event_ndims = prefer_static.rank_from_shape(
                self._override_event_shape)
            self._rotate_ndims = _pick_scalar_condition(
                self._needs_rotation, override_event_ndims, 0)
            # We'll be reducing the head dims (if at all), i.e., this will be []
            # if we don't need to reduce.
            self._reduce_event_indices = prefer_static.range(
                self._rotate_ndims - override_event_ndims, self._rotate_ndims)

        self._distribution = distribution
        self._bijector = bijector
        super(TransformedDistribution, self).__init__(
            dtype=self._distribution.dtype,
            reparameterization_type=self._distribution.reparameterization_type,
            validate_args=validate_args,
            allow_nan_stats=self._distribution.allow_nan_stats,
            parameters=parameters,
            name=name)
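A typical usage sketch, assuming the class is exposed as `tfp.distributions.TransformedDistribution`: a log-normal built by pushing a standard normal through the `Exp` bijector.

import tensorflow_probability as tfp

tfd, tfb = tfp.distributions, tfp.bijectors

log_normal = tfd.TransformedDistribution(
    distribution=tfd.Normal(loc=0., scale=1.),
    bijector=tfb.Exp())
samples = log_normal.sample(3)   # each draw is exp(z) for z ~ Normal(0, 1)
log_normal.log_prob(samples)     # base log_prob plus the inverse log-det-Jacobian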
Example #15
 def _inverse_log_det_jacobian(self, y):
     return tf.constant(0., dtype=dtype_util.base_dtype(y.dtype))
Example #16
 def _forward_log_det_jacobian(self, x):
     return tf.constant(0., dtype=dtype_util.base_dtype(x.dtype))
Example #17
 def _forward_log_det_jacobian(self, x):
     return tf.constant(0., x.dtype)
Example #18
 def _event_shape_tensor(self):
     return tf.constant([], dtype=tf.int32)
Example #19
 def _event_shape_tensor(self):
     return tf.constant([self.dimension, self.dimension], dtype=tf.int32)
Example #20
 def _event_shape_tensor(self, loc=None):
     del loc
     return tf.constant([], dtype=tf.int32)
Example #21
 def _inverse_log_det_jacobian(self, y):
     return tf.constant(0., dtype=y.dtype)
Example #22
def _ndtri(p):
    """Implements ndtri core logic."""

    # Constants used in piece-wise rational approximations. Taken from the cephes
    # library:
    # https://root.cern.ch/doc/v608/SpecFuncCephesInv_8cxx_source.html
    p0 = list(
        reversed([
            -5.99633501014107895267E1, 9.80010754185999661536E1,
            -5.66762857469070293439E1, 1.39312609387279679503E1,
            -1.23916583867381258016E0
        ]))
    q0 = list(
        reversed([
            1.0, 1.95448858338141759834E0, 4.67627912898881538453E0,
            8.63602421390890590575E1, -2.25462687854119370527E2,
            2.00260212380060660359E2, -8.20372256168333339912E1,
            1.59056225126211695515E1, -1.18331621121330003142E0
        ]))
    p1 = list(
        reversed([
            4.05544892305962419923E0, 3.15251094599893866154E1,
            5.71628192246421288162E1, 4.40805073893200834700E1,
            1.46849561928858024014E1, 2.18663306850790267539E0,
            -1.40256079171354495875E-1, -3.50424626827848203418E-2,
            -8.57456785154685413611E-4
        ]))
    q1 = list(
        reversed([
            1.0, 1.57799883256466749731E1, 4.53907635128879210584E1,
            4.13172038254672030440E1, 1.50425385692907503408E1,
            2.50464946208309415979E0, -1.42182922854787788574E-1,
            -3.80806407691578277194E-2, -9.33259480895457427372E-4
        ]))
    p2 = list(
        reversed([
            3.23774891776946035970E0, 6.91522889068984211695E0,
            3.93881025292474443415E0, 1.33303460815807542389E0,
            2.01485389549179081538E-1, 1.23716634817820021358E-2,
            3.01581553508235416007E-4, 2.65806974686737550832E-6,
            6.23974539184983293730E-9
        ]))
    q2 = list(
        reversed([
            1.0, 6.02427039364742014255E0, 3.67983563856160859403E0,
            1.37702099489081330271E0, 2.16236993594496635890E-1,
            1.34204006088543189037E-2, 3.28014464682127739104E-4,
            2.89247864745380683936E-6, 6.79019408009981274425E-9
        ]))

    def _create_polynomial(var, coeffs):
        """Compute n_th order polynomial via Horner's method."""
        coeffs = np.array(coeffs, dtype_util.as_numpy_dtype(var.dtype))
        if not coeffs.size:
            return tf.zeros_like(var)
        return coeffs[0] + _create_polynomial(var, coeffs[1:]) * var

    maybe_complement_p = tf.where(p > -np.expm1(-2.), 1. - p, p)
    # Write in an arbitrary value in place of 0 for p since 0 will cause NaNs
    # later on. The result from the computation when p == 0 is not used so any
    # number that doesn't result in NaNs is fine.
    sanitized_mcp = tf.where(maybe_complement_p <= 0.,
                             dtype_util.as_numpy_dtype(p.dtype)(0.5),
                             maybe_complement_p)

    # Compute x for p > exp(-2): x/sqrt(2pi) = w + w**3 P0(w**2)/Q0(w**2).
    w = sanitized_mcp - 0.5
    ww = w**2
    x_for_big_p = w + w * ww * (_create_polynomial(ww, p0) /
                                _create_polynomial(ww, q0))
    x_for_big_p *= -np.sqrt(2. * np.pi)

    # Compute x for p <= exp(-2): x = z - log(z)/z - (1/z) P(1/z) / Q(1/z),
    # where z = sqrt(-2. * log(p)), and P/Q are chosen between two different
    # arrays based on whether p < exp(-32).
    z = tf.sqrt(-2. * tf.math.log(sanitized_mcp))
    first_term = z - tf.math.log(z) / z
    second_term_small_p = (_create_polynomial(1. / z, p2) /
                           _create_polynomial(1. / z, q2) / z)
    second_term_otherwise = (_create_polynomial(1. / z, p1) /
                             _create_polynomial(1. / z, q1) / z)
    x_for_small_p = first_term - second_term_small_p
    x_otherwise = first_term - second_term_otherwise

    x = tf.where(sanitized_mcp > np.exp(-2.), x_for_big_p,
                 tf.where(z >= 8.0, x_for_small_p, x_otherwise))

    x = tf.where(p > 1. - np.exp(-2.), x, -x)
    infinity_scalar = tf.constant(np.inf, dtype=p.dtype)
    x_nan_replaced = tf.where(p <= 0.0, -infinity_scalar,
                              tf.where(p >= 1.0, infinity_scalar, x))
    return x_nan_replaced
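`_ndtri` is the standard normal quantile function (the inverse of `_ndtr`), built from the same Cephes piecewise rational approximations. A hedged cross-check against SciPy, whose `ndtri` implements the identical function:

from scipy.special import ndtri

ndtri(0.975)  # ~= 1.959964, the 97.5% standard normal quantile
ndtri(0.5)    # == 0.0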
Example #23
    def __init__(self,
                 df,
                 loc=None,
                 scale_identity_multiplier=None,
                 scale_diag=None,
                 scale_tril=None,
                 scale_perturb_factor=None,
                 scale_perturb_diag=None,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="VectorStudentT"):
        """Instantiates the vector Student's t-distributions on `R^k`.

    The `batch_shape` is the broadcast between `df.batch_shape` and
    `Affine.batch_shape` where `Affine` is constructed from `loc` and
    `scale_*` arguments.

    The `event_shape` is the event shape of `Affine.event_shape`.

    Args:
      df: Floating-point `Tensor`. The degrees of freedom of the
        distribution(s). `df` must contain only positive values. Must be
        scalar if `loc`, `scale_*` imply non-scalar batch_shape or must have the
        same `batch_shape` implied by `loc`, `scale_*`.
      loc: Floating-point `Tensor`. If this is set to `None`, no `loc` is
        applied.
      scale_identity_multiplier: Floating-point rank-0 `Tensor` representing a
        scaling done to the identity matrix. When `scale_identity_multiplier =
        scale_diag = scale_tril = None` then `scale += IdentityMatrix`.
        Otherwise no scaled-identity-matrix is added to `scale`.
      scale_diag: Floating-point `Tensor` representing the diagonal matrix.
        `scale_diag` has shape [N1, N2, ..., k], which represents a k x k
        diagonal matrix. When `None` no diagonal term is added to `scale`.
      scale_tril: Floating-point `Tensor` representing the lower triangular
        matrix. `scale_tril` has shape [N1, N2, ..., k, k], which represents a
        k x k lower triangular matrix. When `None` no `scale_tril` term is
        added to `scale`. The upper triangular elements above the diagonal are
        ignored.
      scale_perturb_factor: Floating-point `Tensor` representing factor matrix
        with last two dimensions of shape `(k, r)`. When `None`, no rank-r
        update is added to `scale`.
      scale_perturb_diag: Floating-point `Tensor` representing the diagonal
        matrix. `scale_perturb_diag` has shape [N1, N2, ..., r], which
        represents an r x r diagonal matrix. When `None` low rank updates will
        take the form `scale_perturb_factor * scale_perturb_factor.T`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
        parameters = dict(locals())
        args = [
            df, loc, scale_identity_multiplier, scale_diag, scale_tril,
            scale_perturb_factor, scale_perturb_diag
        ]
        with tf.name_scope(name) as name:
            with tf.name_scope("init"):
                dtype = dtype_util.common_dtype(args, tf.float32)
                df = tf.convert_to_tensor(df, name="df", dtype=dtype)
                # The shape of the _VectorStudentT distribution is governed by the
                # relationship between df.batch_shape and affine.batch_shape. In
                # pseudocode the basic procedure is:
                #   if df.batch_shape is scalar:
                #     if affine.batch_shape is not scalar:
                #       # broadcast distribution.sample so
                #       # it has affine.batch_shape.
                #     self.batch_shape = affine.batch_shape
                #   else:
                #     if affine.batch_shape is scalar:
                #       # let affine broadcasting do its thing.
                #     self.batch_shape = df.batch_shape
                # All of the above magic is actually handled by TransformedDistribution.
                # Here we really only need to collect the affine.batch_shape and decide
                # what we're going to pass in to TransformedDistribution's
                # (override) batch_shape arg.
                affine = affine_bijector.Affine(
                    shift=loc,
                    scale_identity_multiplier=scale_identity_multiplier,
                    scale_diag=scale_diag,
                    scale_tril=scale_tril,
                    scale_perturb_factor=scale_perturb_factor,
                    scale_perturb_diag=scale_perturb_diag,
                    validate_args=validate_args,
                    dtype=dtype)
                distribution = student_t.StudentT(
                    df=df,
                    loc=tf.zeros([], dtype=affine.dtype),
                    scale=tf.ones([], dtype=affine.dtype))
                batch_shape, override_event_shape = (
                    distribution_util.shapes_from_loc_and_scale(
                        affine.shift, affine.scale))
                override_batch_shape = distribution_util.pick_vector(
                    distribution.is_scalar_batch(), batch_shape,
                    tf.constant([], dtype=tf.int32))
                super(_VectorStudentT,
                      self).__init__(distribution=distribution,
                                     bijector=affine,
                                     batch_shape=override_batch_shape,
                                     event_shape=override_event_shape,
                                     validate_args=validate_args,
                                     name=name)
                self._parameters = parameters
Example #24
    def one_step(self, current_state, previous_kernel_results, seed=None):
        """Takes one step of the TransitionKernel.
    Args:
      current_state: `Tensor` or Python `list` of `Tensor`s representing the
        current state(s) of the Markov chain(s).
      previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
        `list` of `Tensor`s representing internal calculations made within the
        previous call to this function (or as returned by `bootstrap_results`).
      seed: Optional, a seed for reproducible sampling.

    Returns:
      next_state: `Tensor` or Python `list` of `Tensor`s representing the
        next state(s) of the Markov chain(s).
      kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
        `Tensor`s representing internal calculations made within this function.
        This includes replica states.
    """

        with tf.name_scope(mcmc_util.make_name(self.name, 'tmc', 'one_step')):
            # Force a read in case the `inverse_temperatures` is a `tf.Variable`.
            inverse_temperatures = tf.convert_to_tensor(
                previous_kernel_results.post_tempering_inverse_temperatures,
                name='inverse_temperatures')

            steps_at_temperature = tf.convert_to_tensor(
                previous_kernel_results.steps_at_temperature,
                name='steps_at_temperature')

            target_score_for_inner_kernel = partial(self.target_score_fn,
                                                    sigma=inverse_temperatures)
            target_log_prob_for_inner_kernel = partial(
                self.target_log_prob_fn, sigma=inverse_temperatures)

            try:
                inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
                    target_log_prob_for_inner_kernel,
                    target_score_for_inner_kernel, inverse_temperatures)
            except TypeError as e:
                if 'argument' not in str(e):
                    raise
                warnings.warn(
                    'The `seed` argument to `ReplicaExchangeMC`\'s `make_kernel_fn` is '
                    'deprecated. `TransitionKernel` instances now receive seeds via '
                    '`one_step`.')
                inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
                    target_log_prob_for_inner_kernel,
                    target_score_for_inner_kernel, inverse_temperatures,
                    self._seed_stream())

            if seed is not None:
                seed = samplers.sanitize_seed(seed)
                inner_seed, swap_seed, logu_seed = samplers.split_seed(
                    seed, n=3, salt='tmc_one_step')
                inner_kwargs = dict(seed=inner_seed)
            else:
                if self._seed_stream.original_seed is not None:
                    warnings.warn(mcmc_util.SEED_CTOR_ARG_DEPRECATION_MSG)
                inner_kwargs = {}
                swap_seed, logu_seed = samplers.split_seed(self._seed_stream())

            if mcmc_util.is_list_like(current_state):
                # We *always* canonicalize the states in the kernel results.
                states = current_state
            else:
                states = [current_state]
            [
                new_state,
                pre_tempering_results,
            ] = inner_kernel.one_step(
                states, previous_kernel_results.post_tempering_results,
                **inner_kwargs)

            # Now that we have run one step, consider lowering the temperature.
            # Propose a new temperature:
            proposed_inverse_temperatures = tf.clip_by_value(
                self.gamma * inverse_temperatures, self.min_temp, 1e6)
            dtype = inverse_temperatures.dtype

            # We will lower the temperature if this new proposed step is
            # compatible with a temperature swap.
            v = new_state[0] - states[0]
            cs = states[0]

            @jax.vmap
            def integrand(t):
                return jnp.sum(self._parameters['target_score_fn'](
                    t * v + cs, inverse_temperatures) * v,
                               axis=-1)

            delta_logp1 = simps(integrand, 0., 1.,
                                self._parameters['num_delta_logp_steps'])

            # Now we compute the reverse
            v = -v
            cs = new_state[0]

            @jax.vmap
            def integrand(t):
                return jnp.sum(self._parameters['target_score_fn'](
                    t * v + cs, proposed_inverse_temperatures) * v,
                               axis=-1)

            delta_logp2 = simps(integrand, 0., 1.,
                                self._parameters['num_delta_logp_steps'])

            log_accept_ratio = (delta_logp1 + delta_logp2)

            log_accept_ratio = tf.where(tf.math.is_finite(log_accept_ratio),
                                        log_accept_ratio,
                                        tf.constant(-np.inf, dtype=dtype))

            # Produce Log[Uniform] draws that are identical at swapped indices.
            log_uniform = tf.math.log(
                samplers.uniform(shape=log_accept_ratio.shape,
                                 dtype=dtype,
                                 seed=logu_seed))

            is_tempering_accepted_mask = tf.less(
                log_uniform,
                log_accept_ratio,
                name='is_tempering_accepted_mask')

            is_min_steps_satisfied = tf.greater(
                steps_at_temperature,
                self.min_steps_per_temp * tf.ones_like(steps_at_temperature),
                name='is_min_steps_satisfied')

            # Only propose tempering if the chain was going to accept this point anyway
            is_tempering_accepted_mask = tf.math.logical_and(
                is_tempering_accepted_mask, pre_tempering_results.is_accepted)

            is_tempering_accepted_mask = tf.math.logical_and(
                is_tempering_accepted_mask, is_min_steps_satisfied)

            # Updating accepted inverse temperatures
            post_tempering_inverse_temperatures = mcmc_util.choose(
                is_tempering_accepted_mask, proposed_inverse_temperatures,
                inverse_temperatures)

            steps_at_temperature = mcmc_util.choose(
                is_tempering_accepted_mask,
                tf.zeros_like(steps_at_temperature), steps_at_temperature + 1)

            # Invalidating and recomputing results
            [
                new_target_log_prob,
                new_grads_target_log_prob,
            ] = mcmc_util.maybe_call_fn_and_grads(
                partial(self.target_log_prob_fn,
                        sigma=post_tempering_inverse_temperatures), new_state)

            # Updating inner kernel results
            post_tempering_results = pre_tempering_results._replace(
                proposed_results=tf.convert_to_tensor(np.nan, dtype=dtype),
                proposed_state=tf.convert_to_tensor(np.nan, dtype=dtype),
            )

            if isinstance(post_tempering_results.accepted_results,
                          hmc.UncalibratedHamiltonianMonteCarloKernelResults):
                post_tempering_results = post_tempering_results._replace(
                    accepted_results=post_tempering_results.accepted_results.
                    _replace(target_log_prob=new_target_log_prob,
                             grads_target_log_prob=new_grads_target_log_prob))
            elif isinstance(
                    post_tempering_results.accepted_results,
                    random_walk_metropolis.UncalibratedRandomWalkResults):
                post_tempering_results = post_tempering_results._replace(
                    accepted_results=post_tempering_results.accepted_results.
                    _replace(target_log_prob=new_target_log_prob))
            else:
                # TODO(b/143702650) Handle other kernels.
                raise NotImplementedError(
                    'Only HMC and RWMH Kernels are handled at this time. Please file a '
                    'request with the TensorFlow Probability team.')

            new_kernel_results = TemperedMCKernelResults(
                pre_tempering_results=pre_tempering_results,
                post_tempering_results=post_tempering_results,
                pre_tempering_inverse_temperatures=inverse_temperatures,
                post_tempering_inverse_temperatures=
                post_tempering_inverse_temperatures,
                tempering_log_accept_ratio=log_accept_ratio,
                steps_at_temperature=steps_at_temperature,
                seed=samplers.zeros_seed() if seed is None else seed,
            )

            return new_state[0], new_kernel_results
Example #25
 def _inverse_log_det_jacobian(self, y):
     # is_constant_jacobian = True for this bijector, hence the
     # `log_det_jacobian` need only be specified for a single input, as this will
     # be tiled to match `event_ndims`.
     return tf.constant(0., dtype=dtype_util.base_dtype(y.dtype))