Example #1
def _get_permutations(num_results, dims, seed=None):
    """Uniform iid sample from the space of permutations.

  Draws a sample of size `num_results` from the group of permutations of degrees
  specified by the `dims` tensor. These are packed together into one tensor
  such that each row is one sample from each of the dimensions in `dims`. For
  example, if dims = [2,3] and num_results = 2, the result is a tensor of shape
  [2, 2 + 3] and the first row of the result might look like:
  [1, 0, 2, 0, 1]. The first two elements are a permutation over 2 elements
  while the next three are a permutation over 3 elements.

  Args:
    num_results: A positive scalar `Tensor` of integral type. The number of
      draws from the discrete uniform distribution over the permutation groups.
    dims: A 1D `Tensor` of the same dtype as `num_results`. The degree of the
      permutation groups from which to sample.
    seed: (Optional) Python integer to seed the random number generator.

  Returns:
    permutations: A `Tensor` of shape `[num_results, sum(dims)]` and the same
    dtype as `dims`.
  """
    sample_range = tf.range(num_results)
    stream = SeedStream(seed, salt='MCMCSampleHaltonSequence3')

    def generate_one(d):
        seed = stream()
        fn = lambda _: tf.random.shuffle(tf.range(d), seed=seed)
        return tf.map_fn(fn,
                         sample_range,
                         parallel_iterations=1 if seed is not None else 10)

    return tf.concat([generate_one(d) for d in tf.unstack(dims)], axis=-1)
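A quick usage sketch for the helper above, with the imports the snippet itself needs (eager mode assumed; the concrete arguments are illustrative):

import tensorflow as tf
from tensorflow_probability.python.util.seed_stream import SeedStream

# Two draws; each row packs a permutation of degree 2 next to one of degree 3.
perms = _get_permutations(num_results=2, dims=tf.constant([2, 3]), seed=17)
print(perms.shape)  # (2, 5); columns 0-1 permute {0, 1}, columns 2-4 permute {0, 1, 2}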
Example #2
    def _start_trajectory_batched(self, state, target_log_prob):
        """Computations needed to start a trajectory."""
        with tf.name_scope('start_trajectory_batched'):
            seed_stream = SeedStream(self._seed_stream,
                                     salt='start_trajectory_batched')
            momentum = [
                tf.random.normal(  # pylint: disable=g-complex-comprehension
                    shape=prefer_static.shape(x),
                    dtype=x.dtype,
                    seed=seed_stream()) for x in state
            ]
            init_energy = compute_hamiltonian(target_log_prob, momentum)

            if MULTINOMIAL_SAMPLE:
                return momentum, init_energy, None

            # Draw a slice variable u ~ Uniform(0, p(initial state, initial
            # momentum)) and compute log u. For numerical stability, we perform
            # this in log space, where log u = log(u' * p(...)) = log u' +
            # log p(...) and u' ~ Uniform(0, 1).
            log_slice_sample = tf.math.log1p(
                -tf.random.uniform(shape=prefer_static.shape(init_energy),
                                   dtype=init_energy.dtype,
                                   seed=seed_stream()))
            return momentum, init_energy, log_slice_sample
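`MULTINOMIAL_SAMPLE` and `compute_hamiltonian` come from the enclosing NUTS module. As a sketch of the convention assumed here (not the module's verbatim code), the "energy" is the log of the unnormalized canonical density: the target log-density minus a standard-normal kinetic energy.

import tensorflow as tf

def compute_hamiltonian(target_log_prob, momentum_parts):
  # target_log_prob carries the batch shape; each momentum part may carry
  # extra event dimensions on the right, which the kinetic energy sums over.
  # Assumes static ranks for brevity.
  batch_rank = len(target_log_prob.shape)
  kinetic = sum(
      tf.reduce_sum(tf.square(m), axis=list(range(batch_rank, len(m.shape))))
      for m in momentum_parts)
  return target_log_prob - 0.5 * kinetic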
Example #3
  def default_exchange_proposed_fn_(num_replica, seed=None):
    """Default function for `exchange_proposed_fn` of `kernel`."""
    seed_stream = SeedStream(seed, 'default_exchange_proposed_fn')

    zero_start = tf.random.uniform([], seed=seed_stream()) > 0.5
    if num_replica % 2 == 0:

      def _exchange():
        flat_exchange = tf.range(num_replica)
        if num_replica > 2:
          start = tf.cast(~zero_start, dtype=tf.int32)
          end = num_replica - start
          flat_exchange = flat_exchange[start:end]
        return tf.reshape(flat_exchange, [tf.size(input=flat_exchange) // 2, 2])
    else:

      def _exchange():
        start = tf.cast(zero_start, dtype=tf.int32)
        end = num_replica - tf.cast(~zero_start, dtype=tf.int32)
        flat_exchange = tf.range(num_replica)[start:end]
        return tf.reshape(flat_exchange, [tf.size(input=flat_exchange) // 2, 2])

    def _null_exchange():
      return tf.reshape(tf.cast([], dtype=tf.int32), shape=[0, 2])

    return tf.cond(
        pred=tf.random.uniform([], seed=seed_stream()) < prob_exchange,
        true_fn=_exchange,
        false_fn=_null_exchange)
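`prob_exchange` is a free variable above: in TFP this function is built by a factory that fixes the exchange probability via closure. A minimal sketch of that wrapper (shape hedged from the old `tfp.mcmc.default_exchange_proposed_fn`):

def default_exchange_proposed_fn(prob_exchange):
  """Returns an `exchange_proposed_fn` proposing exchanges w.p. `prob_exchange`."""

  def default_exchange_proposed_fn_(num_replica, seed=None):
    ...  # body exactly as above; `prob_exchange` resolves to the argument here

  return default_exchange_proposed_fn_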
Example #4
    def __init__(self,
                 target_log_prob_fn,
                 step_size,
                 num_leapfrog_steps,
                 state_gradients_are_stopped=False,
                 step_size_update_fn=None,
                 seed=None,
                 store_parameters_in_results=False,
                 name=None):
        """Initializes this transition kernel.

    Args:
      target_log_prob_fn: Python callable which takes an argument like
        `current_state` (or `*current_state` if it's a list) and returns its
        (possibly unnormalized) log-density under the target distribution.
      step_size: `Tensor` or Python `list` of `Tensor`s representing the step
        size for the leapfrog integrator. Must broadcast with the shape of
        `current_state`. Larger step sizes lead to faster progress, but
        too-large step sizes make rejection exponentially more likely. When
        possible, it's often helpful to match per-variable step sizes to the
        standard deviations of the target distribution in each variable.
      num_leapfrog_steps: Integer number of steps to run the leapfrog integrator
        for. Total progress per HMC step is roughly proportional to
        `step_size * num_leapfrog_steps`.
      state_gradients_are_stopped: Python `bool` indicating that the proposed
        new state should be run through `tf.stop_gradient`. This is
        particularly useful when combining optimization over samples from the
        HMC chain.
        Default value: `False` (i.e., do not apply `stop_gradient`).
      step_size_update_fn: Python `callable` taking current `step_size`
        (typically a `tf.Variable`) and `kernel_results` (typically
        `collections.namedtuple`) and returns updated step_size (`Tensor`s).
        Default value: `None` (i.e., do not update `step_size` automatically).
      seed: Python integer to seed the random number generator.
      store_parameters_in_results: If `True`, then `step_size` and
        `num_leapfrog_steps` are written to and read from eponymous fields in
        the kernel results objects returned from `one_step` and
        `bootstrap_results`. This allows wrapper kernels to adjust those
        parameters on the fly. This is incompatible with `step_size_update_fn`,
        which must be set to `None`.
      name: Python `str` name prefixed to Ops created by this function.
        Default value: `None` (i.e., 'hmc_kernel').
    """
        if step_size_update_fn and store_parameters_in_results:
            raise ValueError('It is invalid to simultaneously specify '
                             '`step_size_update_fn` and set '
                             '`store_parameters_in_results` to `True`.')
        self._seed_stream = SeedStream(seed, salt='hmc')
        self._impl = metropolis_hastings.MetropolisHastings(
            inner_kernel=UncalibratedHamiltonianMonteCarlo(
                target_log_prob_fn=target_log_prob_fn,
                step_size=step_size,
                num_leapfrog_steps=num_leapfrog_steps,
                state_gradients_are_stopped=state_gradients_are_stopped,
                seed=self._seed_stream(),
                name=name or 'hmc_kernel',
                store_parameters_in_results=store_parameters_in_results),
            seed=self._seed_stream())
        self._parameters = self._impl.inner_kernel.parameters.copy()
        self._parameters['step_size_update_fn'] = step_size_update_fn
        self._parameters['seed'] = seed
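A minimal end-to-end sketch of driving this kernel with `tfp.mcmc.sample_chain` (stateful-seed era API; the exact kwargs are hedged):

import tensorflow as tf
import tensorflow_probability as tfp

kernel = tfp.mcmc.HamiltonianMonteCarlo(
    target_log_prob_fn=lambda x: -0.5 * tf.reduce_sum(x**2, axis=-1),  # std normal
    step_size=0.5,
    num_leapfrog_steps=3,
    seed=42)
samples = tfp.mcmc.sample_chain(
    num_results=500,
    current_state=tf.zeros([4, 2]),  # 4 chains over a 2-D state
    kernel=kernel,
    trace_fn=None)  # shape [500, 4, 2]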
Example #5
    def __init__(self,
                 target_log_prob_fn,
                 step_size,
                 num_leapfrog_steps,
                 state_gradients_are_stopped=False,
                 seed=None,
                 store_parameters_in_results=False,
                 name=None):
        """Initializes this transition kernel.

    Args:
      target_log_prob_fn: Python callable which takes an argument like
        `current_state` (or `*current_state` if it's a list) and returns its
        (possibly unnormalized) log-density under the target distribution.
      step_size: `Tensor` or Python `list` of `Tensor`s representing the step
        size for the leapfrog integrator. Must broadcast with the shape of
        `current_state`. Larger step sizes lead to faster progress, but
        too-large step sizes make rejection exponentially more likely. When
        possible, it's often helpful to match per-variable step sizes to the
        standard deviations of the target distribution in each variable.
      num_leapfrog_steps: Integer number of steps to run the leapfrog integrator
        for. Total progress per HMC step is roughly proportional to
        `step_size * num_leapfrog_steps`.
      state_gradients_are_stopped: Python `bool` indicating that the proposed
        new state should be run through `tf.stop_gradient`. This is
        particularly useful when combining optimization over samples from the
        HMC chain.
        Default value: `False` (i.e., do not apply `stop_gradient`).
      seed: Python integer to seed the random number generator. Deprecated, pass
        seed to `tfp.mcmc.sample_chain`.
      store_parameters_in_results: If `True`, then `step_size` and
        `num_leapfrog_steps` are written to and read from eponymous fields in
        the kernel results objects returned from `one_step` and
        `bootstrap_results`. This allows wrapper kernels to adjust those
        parameters on the fly.
      name: Python `str` name prefixed to Ops created by this function.
        Default value: `None` (i.e., 'hmc_kernel').
    """
        if seed is not None and tf.executing_eagerly():
            # TODO(b/68017812): Re-enable once TFE supports `tf.random.shuffle` seed.
            raise NotImplementedError(
                'Specifying a `seed` when running eagerly is '
                'not currently supported. To run in Eager '
                'mode with a seed, pass the seed to '
                '`tfp.mcmc.sample_chain`.')
        if not store_parameters_in_results:
            mcmc_util.warn_if_parameters_are_not_simple_tensors(
                dict(step_size=step_size,
                     num_leapfrog_steps=num_leapfrog_steps))
        self._seed_stream = SeedStream(seed, salt='uncalibrated_hmc_one_step')
        self._parameters = dict(
            target_log_prob_fn=target_log_prob_fn,
            step_size=step_size,
            num_leapfrog_steps=num_leapfrog_steps,
            state_gradients_are_stopped=state_gradients_are_stopped,
            seed=seed,
            name=name,
            store_parameters_in_results=store_parameters_in_results,
        )
        self._momentum_dtype = None
Example #6
 def _inner(seed):
     seed_stream = SeedStream(seed, '_inner')
     x = tf.random.normal(sample_shape,
                          dtype=internal_dtype,
                          seed=seed_stream())
     # This implicitly broadcasts alpha up to sample shape.
     v = 1 + c * x
     return (x, v), v > 0.
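This is the proposal step of the Marsaglia & Tsang (2000) gamma rejection sampler; `c`, `sample_shape`, and `internal_dtype` are closure variables of the enclosing sampler. A self-contained sketch of one accept/reject step under that algorithm (an illustration, not the enclosing module's code; assumes concentration >= 1):

import tensorflow as tf

def gamma_proposal_step(concentration, seed=None):
  d = concentration - 1. / 3.
  c = tf.math.rsqrt(9. * d)
  x = tf.random.normal([], seed=seed)
  v = (1. + c * x) ** 3.  # the same `v` as in `_inner` above, cubed
  u = tf.random.uniform([])
  accept = (v > 0.) & (
      tf.math.log(u) < 0.5 * x**2 + d - d * v + d * tf.math.log(v))
  return tf.where(accept, d * v, float('nan'))  # NaN marks a rejected proposal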
Example #7
 def randomized_computation(seed):
   seed_stream = SeedStream(seed, 'batched_rejection_sampler')
   proposed_samples, proposed_values = proposal(seed_stream())
   good_samples_mask = tf.less_equal(
       proposed_values * tf.random.uniform(
           proposed_samples.shape, maxval=1., seed=seed_stream()),
       target(proposed_samples))
   return proposed_samples, good_samples_mask
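`proposal` and `target` are arguments of the enclosing batched rejection sampler: `proposal` returns candidate samples plus the envelope values at those samples, and a candidate is kept when `envelope * U <= target`. A toy, hedged instantiation this closure could wrap (uniform envelope over a semicircle density):

import tensorflow as tf

def proposal(seed):
  # 1000 uniform candidates on [0, 1); the dominating envelope is constant 1.
  samples = tf.random.uniform([1000], seed=seed)
  return samples, tf.ones_like(samples)

def target(x):
  # Unnormalized semicircle density on [0, 1); bounded above by the envelope.
  return tf.sqrt(tf.maximum(0., 1. - x**2))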
Example #8
  def _sample_n(self, n, seed):
    df = tf.convert_to_tensor(self.df)
    batch_shape = self._batch_shape_tensor(df)
    event_shape = self._event_shape_tensor()
    batch_ndims = tf.shape(batch_shape)[0]

    ndims = batch_ndims + 3  # sample_ndims=1, event_ndims=2
    shape = tf.concat([[n], batch_shape, event_shape], 0)
    stream = SeedStream(seed, salt='Wishart')

    # Complexity: O(nbk**2)
    x = tf.random.normal(
        shape=shape, mean=0., stddev=1., dtype=self.dtype, seed=stream())

    # Complexity: O(nbk)
    # This parameterization is equivalent to Chi2, i.e.,
    # ChiSquared(k) == Gamma(alpha=k/2, beta=1/2)
    expanded_df = df * tf.ones(
        self._scale.batch_shape_tensor(),
        dtype=dtype_util.base_dtype(df.dtype))

    g = tf.random.gamma(
        shape=[n],
        alpha=self._multi_gamma_sequence(0.5 * expanded_df, self._dimension()),
        beta=0.5,
        dtype=self.dtype,
        seed=stream())

    # Complexity: O(nbk**2)
    x = tf.linalg.band_part(x, -1, 0)  # Tri-lower.

    # Complexity: O(nbk)
    x = tf.linalg.set_diag(x, tf.sqrt(g))

    # Make batch-op ready.
    # Complexity: O(nbk**2)
    perm = tf.concat([tf.range(1, ndims), [0]], 0)
    x = tf.transpose(a=x, perm=perm)
    shape = tf.concat([batch_shape, [event_shape[0]], [event_shape[1] * n]], 0)
    x = tf.reshape(x, shape)

    # Complexity: O(nbM) where M is the complexity of the operator solving a
    # vector system. For LinearOperatorLowerTriangular, each matmul is O(k^3) so
    # this step has complexity O(nbk^3).
    x = self._scale.matmul(x)

    # Undo make batch-op ready.
    # Complexity: O(nbk**2)
    shape = tf.concat([batch_shape, event_shape, [n]], 0)
    x = tf.reshape(x, shape)
    perm = tf.concat([[ndims - 1], tf.range(0, ndims - 1)], 0)
    x = tf.transpose(a=x, perm=perm)

    if not self.input_output_cholesky:
      # Complexity: O(nbk**3)
      x = tf.matmul(x, x, adjoint_b=True)

    return x
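This is the Bartlett decomposition: a lower-triangular standard normal with chi-distributed diagonal, pushed through the scale operator. A quick empirical sanity check via the public API (a sketch; `WishartTriL` naming per recent TFP):

import tensorflow as tf
import tensorflow_probability as tfp

w = tfp.distributions.WishartTriL(df=5., scale_tril=tf.eye(3))
x = w.sample(20000, seed=42)
print(tf.reduce_mean(x, axis=0))  # ~ df * scale = 5 * I for the identity scale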
Example #9
 def _sample_n(self, n, seed):
   with tf.compat.v1.control_dependencies(self._runtime_assertions):
     seed = SeedStream(seed, salt="ZeroInflated")
     mask = self.inflated_distribution.sample(n, seed())
     samples = self.count_distribution.sample(n, seed())
     mask, samples = _broadcast_rate(mask, samples)
     # mask = 1 => new_sample = 0
     # mask = 0 => new_sample = sample
     return samples * tf.cast(1 - mask, samples.dtype)
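`_broadcast_rate` is a helper elided from this snippet. A hypothetical stand-in consistent with how it is used here (both tensors broadcast to a common shape):

import tensorflow as tf

def _broadcast_rate(mask, samples):
  # Hypothetical sketch: broadcast both inputs to their common shape so the
  # zero-inflation mask lines up elementwise with the count samples.
  shape = tf.broadcast_dynamic_shape(tf.shape(mask), tf.shape(samples))
  return tf.broadcast_to(mask, shape), tf.broadcast_to(samples, shape)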
Example #10
 def randomized_computation(seed):
   seed_stream = SeedStream(seed, 'batched_rejection_sampler')
   proposed_samples, proposed_values = proposal_fn(seed_stream())
   good_samples_mask = tf.less_equal(
       proposed_values * tf.random.uniform(
           prefer_static.shape(proposed_samples),
           seed=seed_stream(),
           dtype=dtype),
       target_fn(proposed_samples))
   return proposed_samples, good_samples_mask
Example #11
    def _sample_n(self, n, seed=None):
        seeds = samplers.split_seed(seed,
                                    n=self.num_components + 1,
                                    salt='Mixture')
        try:
            seed_stream = SeedStream(seed, salt='Mixture')
        except TypeError as e:  # Can happen for Tensor seed.
            seed_stream = None
            seed_stream_err = e

        # This sampling approach is almost the same as the approach used by
        # `MixtureSameFamily`. The differences are due to having a list of
        # `Distribution` objects rather than a single object.
        samples = []
        cat_samples = self.cat.sample(n, seed=seeds[0])

        for c in range(self.num_components):
            try:
                samples.append(self.components[c].sample(n, seed=seeds[c + 1]))
                if seed_stream is not None:
                    seed_stream()
            except TypeError as e:
                if ('Expected int for argument' not in str(e)
                        and TENSOR_SEED_MSG_PREFIX not in str(e)):
                    raise
                if seed_stream is None:
                    raise seed_stream_err
                msg = (
                    'Falling back to stateful sampling for `components[{}]` {} of '
                    'type `{}`. Please update to use `tf.random.stateless_*` RNGs. '
                    'This fallback may be removed after 20-Aug-2020. ({})')
                warnings.warn(
                    msg.format(c, self.components[c].name,
                               type(self.components[c]), str(e)))
                samples.append(self.components[c].sample(n,
                                                         seed=seed_stream()))
        stack_axis = -1 - tensorshape_util.rank(self._static_event_shape)
        x = tf.stack(samples, axis=stack_axis)  # [n, B, k, E]
        # TODO(b/170730865): Is all this masking stuff really called for?
        npdt = dtype_util.as_numpy_dtype(x.dtype)
        mask = tf.one_hot(
            indices=cat_samples,  # [n, B]
            depth=self._num_components,  # == k
            on_value=npdt(1),
            off_value=npdt(0))  # [n, B, k]
        mask = distribution_util.pad_mixture_dimensions(
            mask, self, self._cat,
            tensorshape_util.rank(
                self._static_event_shape))  # [n, B, k, [1]*e]
        if x.dtype.is_floating:
            masked = tf.math.multiply_no_nan(x, mask)
        else:
            masked = x * mask
        return tf.reduce_sum(masked, axis=stack_axis)  # [n, B, E]
Example #12
    def _flat_sample_distributions(self,
                                   sample_shape=(),
                                   seed=None,
                                   value=None):
        """Executes `model`, creating both samples and distributions."""
        ds = []
        values_out = []
        seed = SeedStream(seed, salt='JointDistributionCoroutine')
        gen = self._model_coroutine()
        index = 0
        d = next(gen)
        if self._require_root and not isinstance(d, self.Root):
            raise ValueError('First distribution yielded by coroutine must '
                             'be wrapped in `Root`.')
        try:
            while True:
                actual_distribution = d.distribution if isinstance(
                    d, self.Root) else d
                ds.append(actual_distribution)
                if (value is not None and len(value) > index
                        and value[index] is not None):
                    seed()  # Ensure reproducibility even when xs are (partially) set.

                    def convert_tree_to_tensor(x, dtype_hint):
                        return tf.convert_to_tensor(x, dtype_hint=dtype_hint)

                    # This signature does not allow kwarg names. Applies
                    # `convert_to_tensor` on the next value.
                    next_value = nest.map_structure_up_to(
                        ds[-1].dtype,  # shallow_tree
                        convert_tree_to_tensor,  # func
                        value[index],  # x
                        ds[-1].dtype)  # dtype_hint
                else:
                    next_value = actual_distribution.sample(
                        sample_shape=sample_shape
                        if isinstance(d, self.Root) else (),
                        seed=seed())

                if self._validate_args:
                    with tf.control_dependencies(
                            self._assert_compatible_shape(
                                index, sample_shape, next_value)):
                        values_out.append(
                            tf.nest.map_structure(tf.identity, next_value))
                else:
                    values_out.append(next_value)

                index += 1
                d = gen.send(next_value)
        except StopIteration:
            pass
        return ds, values_out
Example #13
 def _sample_n(self, n, seed=None):
     scale = tf.convert_to_tensor(self.scale)
     shape = tf.concat([[n], tf.shape(scale)], axis=0)
     seed = SeedStream(seed, salt='random_horseshoe')
     local_shrinkage = self._half_cauchy.sample(shape, seed=seed())
     shrinkage = scale * local_shrinkage
     sampled = tf.random.normal(shape=shape,
                                mean=0.,
                                stddev=1.,
                                dtype=scale.dtype,
                                seed=seed())
     return sampled * shrinkage
Example #14
    def _sample_n(self, n, seed):
        seed = SeedStream(seed, salt='MixtureSameFamily')
        x = self.components_distribution.sample(n, seed=seed())  # [n, B, k, E]

        event_shape = None
        event_ndims = tensorshape_util.rank(self.event_shape)
        if event_ndims is None:
            event_shape = self.components_distribution.event_shape_tensor()
            event_ndims = prefer_static.rank_from_shape(event_shape)
        event_ndims_static = tf.get_static_value(event_ndims)

        num_components = None
        if event_ndims_static is not None:
            num_components = tf.compat.dimension_value(
                x.shape[-1 - event_ndims_static])
        # We could also check if num_components can be computed statically from
        # self.mixture_distribution's logits or probs.
        if num_components is None:
            num_components = tf.shape(x)[-1 - event_ndims]

        # TODO(jvdillon): Consider using tf.gather (by way of index unrolling).
        npdt = dtype_util.as_numpy_dtype(x.dtype)
        mask = tf.one_hot(
            indices=self.mixture_distribution.sample(
                n, seed=seed()),  # [n, B] or [n]
            depth=num_components,
            on_value=npdt(1),
            off_value=npdt(0))  # [n, B, k] or [n, k]

        # Pad `mask` to [n, B, k, [1]*e] or [n, [1]*b, k, [1]*e] .
        batch_ndims = prefer_static.rank(x) - event_ndims - 1
        mask_batch_ndims = prefer_static.rank(mask) - 1
        pad_ndims = batch_ndims - mask_batch_ndims
        mask_shape = prefer_static.shape(mask)
        mask = tf.reshape(
            mask,
            shape=prefer_static.concat([
                mask_shape[:-1],
                prefer_static.ones([pad_ndims], dtype=tf.int32),
                mask_shape[-1:],
                prefer_static.ones([event_ndims], dtype=tf.int32),
            ], axis=0))

        ret = tf.reduce_sum(x * mask, axis=-1 - event_ndims)  # [n, B, E]

        if self._reparameterize:
            if event_shape is None:
                event_shape = self.components_distribution.event_shape_tensor()
            ret = self._reparameterize_sample(ret, event_shape=event_shape)

        return ret
Example #15
  def __init__(self,
               target_log_prob_fn,
               inverse_temperatures,
               make_kernel_fn,
               exchange_proposed_fn=default_exchange_proposed_fn(1.),
               seed=None,
               name=None):
    """Instantiates this object.

    Args:
      target_log_prob_fn: Python callable which takes an argument like
        `current_state` (or `*current_state` if it's a list) and returns its
        (possibly unnormalized) log-density under the target distribution.
      inverse_temperatures: `1D` `Tensor` of inverse temperatures to use when
        sampling with each replica. Must have statically known `shape`.
        `inverse_temperatures[0]` produces the states returned by samplers,
        and is typically == 1.
      make_kernel_fn: Python callable which takes target_log_prob_fn and seed
        args and returns a TransitionKernel instance.
      exchange_proposed_fn: Python callable which takes a number of replicas
        and returns combinations of replicas proposed for exchange.
      seed: Python integer to seed the random number generator.
        Default value: `None` (i.e., no seed).
      name: Python `str` name prefixed to Ops created by this function.
        Default value: `None` (i.e., "remc_kernel").

    Raises:
      ValueError: `inverse_temperatures` doesn't have statically known 1D shape.
    """
    inverse_temperatures = tf.convert_to_tensor(
        value=inverse_temperatures, name='inverse_temperatures')

    # Note these are static checks, and don't need to be embedded in the graph.
    inverse_temperatures.shape.assert_is_fully_defined()
    inverse_temperatures.shape.assert_has_rank(1)

    self._seed_stream = SeedStream(seed, salt=name)
    self._seeded_mcmc = seed is not None
    self._parameters = dict(
        target_log_prob_fn=target_log_prob_fn,
        inverse_temperatures=inverse_temperatures,
        num_replica=tf.compat.dimension_value(inverse_temperatures.shape[0]),
        exchange_proposed_fn=exchange_proposed_fn,
        seed=seed,
        name=name)
    self.replica_kernels = []
    for i in range(self.num_replica):
      self.replica_kernels.append(
          make_kernel_fn(
              target_log_prob_fn=_replica_log_prob_fn(inverse_temperatures[i],
                                                      target_log_prob_fn),
              seed=self._seed_stream()))
Example #16
 def _sample_n(self, n, seed=None):
   seed = SeedStream(seed, 'beta')
   concentration1 = tf.convert_to_tensor(self.concentration1)
   concentration0 = tf.convert_to_tensor(self.concentration0)
   shape = self._batch_shape_tensor(concentration1, concentration0)
   expanded_concentration1 = tf.broadcast_to(concentration1, shape)
   expanded_concentration0 = tf.broadcast_to(concentration0, shape)
   gamma1_sample = tf.random.gamma(
       shape=[n], alpha=expanded_concentration1, dtype=self.dtype, seed=seed())
   gamma2_sample = tf.random.gamma(
       shape=[n], alpha=expanded_concentration0, dtype=self.dtype, seed=seed())
   beta_sample = gamma1_sample / (gamma1_sample + gamma2_sample)
   return beta_sample
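The construction relies on the gamma-ratio fact: for independent G1 ~ Gamma(a, 1) and G2 ~ Gamma(b, 1), G1 / (G1 + G2) ~ Beta(a, b). A quick empirical check, as a sketch:

import tensorflow as tf

a, b = 2., 5.
g1 = tf.random.gamma([100000], alpha=a, seed=1)
g2 = tf.random.gamma([100000], alpha=b, seed=2)
beta_draws = g1 / (g1 + g2)
print(float(tf.reduce_mean(beta_draws)))  # ~ a / (a + b) = 2/7 ≈ 0.286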
Example #17
 def _sample_n(self, n, seed):
     seed = SeedStream(seed, salt="ZeroInflated")
     mask = self.inflated_distribution.sample(n, seed())
     samples = self.count_distribution.sample(n, seed())
     tf.assert_equal(
         tf.rank(samples) >= tf.rank(mask),
         True,
         message=f"Cannot broadcast zero inflated mask of shape {mask.shape} "
         f"to sample shape {samples.shape}")
     samples, mask = _make_broadcastable(samples, mask)
     # mask = 1 => new_sample = 0
     # mask = 0 => new_sample = sample
     return samples * tf.cast(1 - mask, samples.dtype)
Example #18
 def _sample_n(self, n, seed=None):
   # Here we use the fact that if:
   # lam ~ Gamma(concentration=total_count, rate=(1-probs)/probs)
   # then X ~ Poisson(lam) is Negative Binomially distributed.
   logits = self._logits_parameter_no_checks()
   stream = SeedStream(seed, salt='NegativeBinomial')
   rate = tf.random.gamma(
       shape=[n],
       alpha=self.total_count,
       beta=tf.exp(-logits),
       dtype=self.dtype,
       seed=stream())
   return tf.random.poisson(
       lam=rate, shape=[], dtype=self.dtype, seed=stream())
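Per the comment, this is the gamma-Poisson mixture: lam ~ Gamma(total_count, (1-p)/p) followed by Poisson(lam) yields a negative binomial; note `tf.exp(-logits) = (1 - probs) / probs`. A standalone sketch checking the implied mean:

import tensorflow as tf

total_count, probs = 8., 0.3
logits = tf.math.log(probs) - tf.math.log1p(-probs)
rate = tf.random.gamma([200000], alpha=total_count, beta=tf.exp(-logits), seed=3)
nb_draws = tf.random.poisson(lam=rate, shape=[], seed=4)
print(float(tf.reduce_mean(nb_draws)))  # ~ total_count * probs / (1 - probs) ≈ 3.43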
Example #19
    def _sample_n(self, n, seed=None):
        seed_stream = SeedStream(seed, 'beta_binomial')

        total_count, concentration1, concentration0 = self._params_list_as_tensors()

        batch_shape_tensor = self.batch_shape_tensor()
        probs = beta.Beta(tf.broadcast_to(concentration1, batch_shape_tensor),
                          concentration0,
                          validate_args=self.validate_args).sample(
                              n, seed=seed_stream())
        return binomial.Binomial(
            total_count, probs=probs,
            validate_args=self.validate_args).sample(seed=seed_stream())
Example #20
    def resample(log_weights, current_state, particle_info, seed=None):
      """Resample particles based on importance weights."""
      with tf.name_scope('resample_particles'):
        seed = SeedStream(seed, salt='resample_particles')
        resampling_indexes = tf.random.categorical(
            [log_weights], ps.reduce_prod(ps.shape(log_weights)), seed=seed())
        next_state = tf.nest.map_structure(
            lambda x: tf.reshape(tf.gather(x, resampling_indexes), ps.shape(x)),
            current_state)
        next_particle_info = tf.nest.map_structure(
            lambda x: tf.reshape(tf.gather(x, resampling_indexes), ps.shape(x)),
            particle_info)

        return next_state, next_particle_info
Example #21
    def _sample_n(self, n, seed):
        # only for MixtureSameFamilySampleFix
        import warnings
        from tensorflow_probability.python.distributions import independent
        from tensorflow_probability.python.internal import dtype_util
        from tensorflow_probability.python.internal import prefer_static
        from tensorflow_probability.python.internal import samplers
        from tensorflow_probability.python.internal import tensorshape_util
        from tensorflow_probability.python.util.seed_stream import SeedStream
        from tensorflow_probability.python.util.seed_stream import (
            TENSOR_SEED_MSG_PREFIX, )

        components_seed, mix_seed = samplers.split_seed(
            seed, salt="MixtureSameFamily")
        try:
            seed_stream = SeedStream(seed, salt="MixtureSameFamily")
        except TypeError as e:  # Can happen for Tensor seeds.
            seed_stream = None
            seed_stream_err = e
        try:
            mix_sample = self.mixture_distribution.sample(
                n, seed=mix_seed)  # [n, B] or [n]
        except TypeError as e:
            if "Expected int for argument" not in str(
                    e) and TENSOR_SEED_MSG_PREFIX not in str(e):
                raise
            if seed_stream is None:
                raise seed_stream_err
            msg = (
                "Falling back to stateful sampling for `mixture_distribution` "
                "{} of type `{}`. Please update to use `tf.random.stateless_*` "
                "RNGs. This fallback may be removed after 20-Aug-2020. ({})")
            warnings.warn(
                msg.format(
                    self.mixture_distribution.name,
                    type(self.mixture_distribution),
                    str(e),
                ))
            mix_sample = self.mixture_distribution.sample(
                n, seed=seed_stream())  # [n, B] or [n]
        _seed = int(components_seed[0].numpy())
        ret = tf.stack(
            [
                self.components_distribution[i_component.numpy()].sample(
                    seed=_seed + i) for i, i_component in enumerate(mix_sample)
            ],
            axis=0,
        )
        return ret
Example #22
    def __init__(self,
                 target_log_prob_fn,
                 new_state_fn=None,
                 seed=None,
                 name=None):
        if new_state_fn is None:
            new_state_fn = random_walk_normal_fn()

        self._target_log_prob_fn = target_log_prob_fn
        self._seed_stream = SeedStream(seed, salt='RandomWalkMetropolis')
        self._name = name
        self._parameters = dict(target_log_prob_fn=target_log_prob_fn,
                                new_state_fn=new_state_fn,
                                seed=seed,
                                name=name)
Example #23
def _randomize(coeffs, radixes, seed=None):
    """Applies the Owen (2017) randomization to the coefficients."""
    given_dtype = coeffs.dtype
    coeffs = tf.cast(coeffs, dtype=tf.int32)
    num_coeffs = tf.shape(coeffs)[-1]
    radixes = tf.reshape(tf.cast(radixes, dtype=tf.int32), shape=[-1])
    stream = SeedStream(seed, salt='MCMCSampleHaltonSequence2')
    perms = _get_permutations(num_coeffs, radixes, seed=stream())
    perms = tf.reshape(perms, shape=[-1])
    radix_sum = tf.reduce_sum(radixes)
    radix_offsets = tf.reshape(tf.cumsum(radixes, exclusive=True),
                               shape=[-1, 1])
    offsets = radix_offsets + tf.range(num_coeffs) * radix_sum
    permuted_coeffs = tf.gather(perms, coeffs + offsets)
    return tf.cast(permuted_coeffs, dtype=given_dtype)
Example #24
    def __init__(self,
                 target_log_prob_fn,
                 inverse_temperatures,
                 make_kernel_fn,
                 swap_proposal_fn=default_swap_proposal_fn(1.),
                 state_includes_replicas=False,
                 seed=None,
                 validate_args=False,
                 name=None):
        """Instantiates this object.

    Args:
      target_log_prob_fn: Python callable which takes an argument like
        `current_state` (or `*current_state` if it's a list) and returns its
        (possibly unnormalized) log-density under the target distribution.
      inverse_temperatures: `Tensor` of inverse temperatures to temper each
        replica. The leftmost dimension is the `num_replica` and the
        second dimension through the rightmost can provide different temperature
        to different batch members, doing a left-justified broadcast.
      make_kernel_fn: Python callable which takes a `target_log_prob_fn`
        arg and returns a `tfp.mcmc.TransitionKernel` instance. Passing a
        function taking `(target_log_prob_fn, seed)` is deprecated but
        supported until 2020-09-20.
      swap_proposal_fn: Python callable which takes a number of replicas and
        returns `swaps`, a shape `[num_replica] + batch_shape` `Tensor`, where
        axis 0 indexes a permutation of `{0,..., num_replica-1}`, designating
        replicas to swap.
      state_includes_replicas: Boolean indicating whether the leftmost dimension
        of each state sample should index replicas. If `True`, the leftmost
        dimension of the `current_state` kwarg to `tfp.mcmc.sample_chain` will
        be interpreted as indexing replicas.
      seed: Python integer to seed the random number generator. Deprecated, pass
        seed to `tfp.mcmc.sample_chain`. Default value: `None` (i.e., no seed).
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      name: Python `str` name prefixed to Ops created by this function.
        Default value: `None` (i.e., "remc_kernel").

    Raises:
      ValueError: `inverse_temperatures` doesn't have statically known 1D shape.
    """
        self._parameters = {k: v for k, v in locals().items() if v is not self}
        self._state_includes_replicas = state_includes_replicas
        self._seed_stream = SeedStream(seed, salt='replica_mc')
Example #25
    def __init__(self,
                 target_log_prob_fn,
                 new_state_fn=None,
                 seed=None,
                 name=None):
        """Initializes this transition kernel.

    Args:
      target_log_prob_fn: Python callable which takes an argument like
        `current_state` (or `*current_state` if it's a list) and returns its
        (possibly unnormalized) log-density under the target distribution.
      new_state_fn: Python callable which takes a list of state parts and a
        seed; returns a same-type `list` of `Tensor`s, each being a perturbation
        of the input state parts. The perturbation distribution is assumed to be
        a symmetric distribution centered at the input state part.
        Default value: `None` which is mapped to
          `tfp.mcmc.random_walk_normal_fn()`.
      seed: Python integer to seed the random number generator. Deprecated, pass
        seed to `tfp.mcmc.sample_chain`.
      name: Python `str` name prefixed to Ops created by this function.
        Default value: `None` (i.e., 'rwm_kernel').

    Returns:
      next_state: Tensor or Python list of `Tensor`s representing the state(s)
        of the Markov chain(s) at each result step. Has same shape as
        `current_state`.
      kernel_results: `collections.namedtuple` of internal calculations used to
        advance the chain.

    Raises:
      ValueError: if there isn't one `scale` or a list with same length as
        `current_state`.
    """
        if new_state_fn is None:
            new_state_fn = random_walk_normal_fn()

        seed_stream = SeedStream(seed, salt='rwm')
        mh_kwargs = {} if seed is None else dict(seed=seed_stream())
        uncal_kwargs = {} if seed is None else dict(seed=seed_stream())
        self._impl = metropolis_hastings.MetropolisHastings(
            inner_kernel=UncalibratedRandomWalk(
                target_log_prob_fn=target_log_prob_fn,
                new_state_fn=new_state_fn,
                name=name,
                **uncal_kwargs),
            **mh_kwargs)
Example #26
def make_rwmh_kernel_fn(target_log_prob_fn, init_state, scalings, seed=None):
    """Generate a Random Walk MH kernel."""
    with tf.name_scope('make_rwmh_kernel_fn'):
        seed = SeedStream(seed, salt='make_rwmh_kernel_fn')
        state_std = [
            tf.math.reduce_std(x, axis=0, keepdims=True) for x in init_state
        ]
        step_size = [
            s * ps.cast(  # pylint: disable=g-complex-comprehension
                mcmc_util.left_justified_expand_dims_like(scalings, s),
                s.dtype) for s in state_std
        ]
        return random_walk_metropolis.RandomWalkMetropolis(
            target_log_prob_fn,
            new_state_fn=random_walk_metropolis.random_walk_normal_fn(
                scale=step_size),
            seed=seed)
Example #27
 def _sample_n(self, n, seed=None):
     concentration = tf.convert_to_tensor(self.concentration)
     mixing_concentration = tf.convert_to_tensor(self.mixing_concentration)
     mixing_rate = tf.convert_to_tensor(self.mixing_rate)
     seed = SeedStream(seed, 'gamma_gamma')
     rate = tf.random.gamma(
         shape=[n],
         # Be sure to draw enough rates for the fully-broadcasted gamma-gamma.
         alpha=mixing_concentration + tf.zeros_like(concentration),
         beta=mixing_rate,
         dtype=self.dtype,
         seed=seed())
     return tf.random.gamma(shape=[],
                            alpha=concentration,
                            beta=rate,
                            dtype=self.dtype,
                            seed=seed())
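This is the compound-gamma construction: rate ~ Gamma(a0, b0), then X | rate ~ Gamma(a, rate), so E[X] = a * b0 / (a0 - 1) for a0 > 1. A standalone sketch of the same two-stage draw:

import tensorflow as tf

concentration, mixing_concentration, mixing_rate = 3., 4., 2.
rate = tf.random.gamma([200000], alpha=mixing_concentration, beta=mixing_rate, seed=5)
x = tf.random.gamma([], alpha=concentration, beta=rate, seed=6)
print(float(tf.reduce_mean(x)))  # ~ 3 * 2 / (4 - 1) = 2.0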
Example #28
    def _sample_n(self, n, seed=None):
        # Like with the univariate Student's t, sampling can be implemented as a
        # ratio of samples from a multivariate gaussian with the appropriate
        # covariance matrix and a sample from the chi-squared distribution.
        seed = SeedStream(seed, salt='multivariate t')

        loc = tf.broadcast_to(self.loc, self._sample_shape())
        mvn = mvn_linear_operator.MultivariateNormalLinearOperator(
            loc=tf.zeros_like(loc), scale=self.scale)
        normal_samp = mvn.sample(n, seed=seed())

        df = tf.broadcast_to(self.df, self.batch_shape_tensor())
        chi2 = chi2_lib.Chi2(df=df)
        chi2_samp = chi2.sample(n, seed=seed())

        return (
            self._loc +
            normal_samp * tf.math.rsqrt(chi2_samp / self._df)[..., tf.newaxis])
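The ratio here is the standard normal-over-chi representation: X = loc + Z * sqrt(df / chi2), with Z multivariate normal and chi2 ~ ChiSquared(df), is Student's t with df degrees of freedom. A 1-D sanity-check sketch (using ChiSquared(df) == 2 * Gamma(df/2, rate=1)):

import tensorflow as tf

df, n = 5., 200000
z = tf.random.normal([n], seed=7)
chi2 = 2. * tf.random.gamma([n], alpha=df / 2., seed=8)
t_draws = z * tf.math.rsqrt(chi2 / df)
print(float(tf.math.reduce_variance(t_draws)))  # ~ df / (df - 2) = 5/3 ≈ 1.67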
Example #29
    def _sample_n(self, n, seed=None):
        distribution0 = self._get_distribution0()

        if self._num_steps is not None:
            num_steps = tf.convert_to_tensor(self._num_steps)
            num_steps_static = tf.get_static_value(num_steps)
        else:
            num_steps_static = tensorshape_util.num_elements(
                distribution0.event_shape)
            if num_steps_static is None:
                num_steps = tf.reduce_prod(distribution0.event_shape_tensor())

        stateless_seed = samplers.sanitize_seed(seed, salt='Autoregressive')
        stateful_seed = None
        try:
            samples = distribution0.sample(n, seed=stateless_seed)
            is_stateful_sampler = False
        except TypeError as e:
            if ('Expected int for argument' not in str(e)
                    and TENSOR_SEED_MSG_PREFIX not in str(e)):
                raise
            msg = (
                'Falling back to stateful sampling for `distribution_fn(sample0)` of '
                'type `{}`. Please update to use `tf.random.stateless_*` RNGs. '
                'This fallback may be removed after 20-Aug-2020. ({})')
            warnings.warn(
                msg.format(distribution0.name, type(distribution0), str(e)))
            stateful_seed = SeedStream(seed, salt='Autoregressive')()
            samples = distribution0.sample(n, seed=stateful_seed)
            is_stateful_sampler = True

        seed = stateful_seed if is_stateful_sampler else stateless_seed

        if num_steps_static is not None:
            for _ in range(num_steps_static):
                # pylint: disable=not-callable
                samples = self.distribution_fn(samples).sample(seed=seed)
        else:
            # pylint: disable=not-callable
            samples = tf.foldl(
                lambda s, _: self.distribution_fn(s).sample(seed=seed),
                elems=tf.range(0, num_steps),
                initializer=samples)
        return samples
Example #30
    def _flat_sample_distributions(self,
                                   sample_shape=(),
                                   seed=None,
                                   value=None):
        """Executes `model`, creating both samples and distributions."""
        ds = []
        values_out = []
        seed = SeedStream(seed, salt='JointDistributionCoroutine')
        gen = self._model()
        index = 0
        d = next(gen)
        if not isinstance(d, self.Root):
            raise ValueError('First distribution yielded by coroutine must '
                             'be wrapped in `Root`.')
        try:
            while True:
                actual_distribution = d.distribution if isinstance(
                    d, self.Root) else d
                ds.append(actual_distribution)
                if (value is not None and len(value) > index
                        and value[index] is not None):
                    seed()
                    next_value = value[index]
                else:
                    next_value = actual_distribution.sample(
                        sample_shape=sample_shape
                        if isinstance(d, self.Root) else (),
                        seed=seed())

                if self._validate_args:
                    with tf.control_dependencies(
                            self._assert_compatible_shape(
                                index, sample_shape, next_value)):
                        values_out.append(
                            tf.nest.map_structure(tf.identity, next_value))
                else:
                    values_out.append(next_value)

                index += 1
                d = gen.send(next_value)
        except StopIteration:
            pass
        return ds, values_out