Example #1
    def default_exchange_proposed_fn_(num_replica, seed=None):
        """Default function for `exchange_proposed_fn` of `kernel`."""
        num_replica = tf.to_int32(num_replica)

        seed = distributions_util.gen_new_seed(seed,
                                               'default_exchange_proposed_fn')
        random_uniform = tf.random_uniform([], seed=seed)
        accept_proposed_exchange = random_uniform < probs

        seed = distributions_util.gen_new_seed(seed,
                                               'default_exchange_proposed_fn')
        zero_start = tf.random_uniform([], seed=seed) > 0.5
        if num_replica % 2 == 0:
            exchange_proposed = tf.where(
                zero_start, tf.range(num_replica),
                tf.sparse_to_dense(tf.range(num_replica - 2), (num_replica, ),
                                   tf.range(1, num_replica - 1)))
            exchange_proposed_n = tf.where(zero_start, num_replica // 2,
                                           num_replica // 2 - 1)
        else:
            exchange_proposed = tf.where(zero_start, tf.range(num_replica - 1),
                                         tf.range(1, num_replica))
            exchange_proposed_n = num_replica // 2

        exchange_proposed = tf.reshape(exchange_proposed,
                                       (num_replica // 2, 2))
        exchange_proposed = tf.where(accept_proposed_exchange,
                                     exchange_proposed,
                                     tf.zeros_like(exchange_proposed))
        exchange_proposed_n = tf.where(accept_proposed_exchange,
                                       exchange_proposed_n,
                                       tf.zeros_like(exchange_proposed_n))
        return exchange_proposed, exchange_proposed_n
Example #2
  def default_exchange_proposed_fn_(num_replica, seed=None):
    """Default function for `exchange_proposed_fn` of `kernel`."""
    num_replica = tf.to_int32(num_replica)

    seed = distributions_util.gen_new_seed(seed, 'default_exchange_proposed_fn')
    random_uniform = tf.random_uniform([], seed=seed)
    accept_proposed_exchange = random_uniform < probs

    seed = distributions_util.gen_new_seed(seed, 'default_exchange_proposed_fn')
    zero_start = tf.random_uniform([], seed=seed) > 0.5
    if num_replica % 2 == 0:
      exchange_proposed = tf.where(
          zero_start, tf.range(num_replica),
          tf.sparse_to_dense(tf.range(num_replica - 2), (num_replica,),
                             tf.range(1, num_replica - 1)))
      exchange_proposed_n = tf.where(zero_start, num_replica // 2,
                                     num_replica // 2 - 1)
    else:
      exchange_proposed = tf.where(
          zero_start, tf.range(num_replica - 1), tf.range(1, num_replica))
      exchange_proposed_n = num_replica // 2

    exchange_proposed = tf.reshape(exchange_proposed, (num_replica // 2, 2))
    exchange_proposed = tf.where(accept_proposed_exchange, exchange_proposed,
                                 tf.zeros_like(exchange_proposed))
    exchange_proposed_n = tf.where(accept_proposed_exchange,
                                   exchange_proposed_n,
                                   tf.zeros_like(exchange_proposed_n))
    return exchange_proposed, exchange_proposed_n
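For intuition, here is a minimal NumPy sketch (not the library code) of the same proposal scheme: with probability `probs` adjacent replicas are paired starting at index 0 or at index 1 (chosen by a coin flip), otherwise no exchange is proposed. The helper name `propose_adjacent_exchanges` is illustrative.

import numpy as np

def propose_adjacent_exchanges(num_replica, probs, rng=np.random):
    """Sketch of the adjacent-pair proposal used by `default_exchange_proposed_fn`."""
    if rng.uniform() >= probs:           # drop the whole proposal
        return np.zeros((num_replica // 2, 2), dtype=np.int32), 0
    zero_start = rng.uniform() > 0.5     # start pairing at replica 0 or at replica 1
    start = 0 if zero_start else 1
    stop = num_replica if (num_replica - start) % 2 == 0 else num_replica - 1
    pairs = np.arange(start, stop, dtype=np.int32).reshape(-1, 2)
    # Pad to a fixed (num_replica // 2, 2) shape, mirroring the TF version.
    padded = np.zeros((num_replica // 2, 2), dtype=np.int32)
    padded[:len(pairs)] = pairs
    return padded, len(pairs)

pairs, n = propose_adjacent_exchanges(5, probs=1.0)
print(pairs[:n])   # e.g. [[0 1] [2 3]] or [[1 2] [3 4]]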
Example #3
 def _sample_n(self, n, seed=None):
     # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[] then get
     # ids as a [n]-shaped vector.
     batch_size = (np.prod(self.batch_shape.as_list(), dtype=np.int32)
                   if self.batch_shape.is_fully_defined() else
                   math_ops.reduce_prod(self.batch_shape_tensor()))
     ids = self._mixture_distribution.sample(
         sample_shape=concat_vectors([n],
                                     distribution_util.pick_vector(
                                         self.is_scalar_batch(),
                                         np.int32([]), [batch_size])),
         seed=distribution_util.gen_new_seed(
             seed, "poisson_lognormal_quadrature_compound"))
     # Stride `quadrature_polynomial_degree` for `batch_size` number of times.
     offset = math_ops.range(start=0,
                             limit=batch_size * self._degree,
                             delta=self._degree,
                             dtype=ids.dtype)
     ids += offset
     rate = array_ops.gather(
         array_ops.reshape(self.distribution.rate, shape=[-1]), ids)
     rate = array_ops.reshape(rate,
                              shape=concat_vectors(
                                  [n], self.batch_shape_tensor()))
     return random_ops.random_poisson(lam=rate,
                                      shape=[],
                                      dtype=self.dtype,
                                      seed=seed)
Example #4
  def one_step(self, current_state, previous_kernel_results):
    with tf.name_scope(
        name=mcmc_util.make_name(self.name, 'rwm', 'one_step'),
        values=[self.seed,
                current_state,
                previous_kernel_results.target_log_prob]):
      with tf.name_scope('initialize'):
        current_state_parts = (list(current_state)
                               if mcmc_util.is_list_like(current_state)
                               else [current_state])
        current_state_parts = [tf.convert_to_tensor(s, name='current_state')
                               for s in current_state_parts]

      self._seed_stream = distributions_util.gen_new_seed(
          self._seed_stream, salt='rwm_kernel_proposal')
      new_state_fn = self.new_state_fn
      next_state_parts = new_state_fn(current_state_parts, self._seed_stream)
      # Compute `target_log_prob` so its available to MetropolisHastings.
      next_target_log_prob = self.target_log_prob_fn(*next_state_parts)

      def maybe_flatten(x):
        return x if mcmc_util.is_list_like(current_state) else x[0]

      return [
          maybe_flatten(next_state_parts),
          UncalibratedRandomWalkResults(
              log_acceptance_correction=tf.zeros(
                  shape=tf.shape(next_target_log_prob),
                  dtype=next_target_log_prob.dtype.base_dtype),
              target_log_prob=next_target_log_prob,
          ),
      ]
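The kernel above advances `self._seed_stream` with `gen_new_seed` before each use so that repeated calls draw decorrelated randomness. A minimal stand-in for that pattern (the helper `new_seed` below is hypothetical, not the TF implementation) hashes the previous seed together with a salt:

import hashlib

def new_seed(seed, salt):
    """Derive a fresh integer seed from (seed, salt); returns None if seed is None."""
    if seed is None:
        return None
    digest = hashlib.sha256('{}_{}'.format(seed, salt).encode('utf-8')).hexdigest()
    return int(digest[:8], 16) & 0x7FFFFFFF  # keep the result in int32 range

seed = 42
for step in range(3):
    seed = new_seed(seed, salt='rwm_kernel_proposal')
    print(seed)   # a different, reproducible seed at every step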
Example #5
    def _apply_variational_kernel(self, inputs):
        if (not isinstance(self.kernel_posterior, tfd.Independent)
                or not isinstance(self.kernel_posterior.distribution,
                                  tfd.Normal)):
            raise TypeError(
                '`DenseFlipout` requires '
                '`kernel_posterior_fn` produce an instance of '
                '`tf.distributions.Independent(tf.distributions.Normal)` '
                '(saw: \"{}\").'.format(self.kernel_posterior.name))
        self.kernel_posterior_affine = tfd.Normal(
            loc=tf.zeros_like(self.kernel_posterior.distribution.loc),
            scale=self.kernel_posterior.distribution.scale)
        self.kernel_posterior_affine_tensor = (self.kernel_posterior_tensor_fn(
            self.kernel_posterior_affine))
        self.kernel_posterior_tensor = None

        input_shape = tf.shape(inputs)
        batch_shape = input_shape[:-1]

        sign_input = tfp_layers_util.random_sign(input_shape,
                                                 dtype=inputs.dtype,
                                                 seed=self.seed)
        sign_output = tfp_layers_util.random_sign(
            tf.concat([batch_shape, tf.expand_dims(self.units, 0)], 0),
            dtype=inputs.dtype,
            seed=distribution_util.gen_new_seed(self.seed,
                                                salt='dense_flipout'))
        perturbed_inputs = self._matmul(
            inputs * sign_input,
            self.kernel_posterior_affine_tensor) * sign_output

        outputs = self._matmul(inputs, self.kernel_posterior.distribution.loc)
        outputs += perturbed_inputs
        return outputs
Example #6
 def body(i, next_replica_idx):
   """`tf.while_loop` body."""
   ratio = (
       sampled_replica_ratios[next_replica_idx[exchange_proposed[i, 0]]]
       - sampled_replica_ratios[next_replica_idx[exchange_proposed[i, 1]]])
   ratio *= (
       self.inverse_temperatures[exchange_proposed[i, 1]]
       - self.inverse_temperatures[exchange_proposed[i, 0]])
   self._seed_stream = distributions_util.gen_new_seed(
       self._seed_stream, salt='replica_exchange_one_step')
   log_uniform = tf.log(tf.random_uniform(
       shape=tf.shape(ratio),
       dtype=ratio.dtype.base_dtype,
       seed=self._seed_stream))
   exchange = log_uniform < ratio
   exchange_op = tf.sparse_to_dense(
       [exchange_proposed[i, 0], exchange_proposed[i, 1]],
       [self.num_replica],
       [next_replica_idx[exchange_proposed[i, 1]] -
        next_replica_idx[exchange_proposed[i, 0]],
        next_replica_idx[exchange_proposed[i, 0]] -
        next_replica_idx[exchange_proposed[i, 1]]])
   next_replica_idx = tf.cond(exchange,
                              lambda: next_replica_idx + exchange_op,
                              lambda: next_replica_idx)
   return [i + 1, next_replica_idx]
Example #7
 def body(i, next_replica_idx):
     """`tf.while_loop` body."""
     ratio = (sampled_replica_ratios[next_replica_idx[
         exchange_proposed[i, 0]]] - sampled_replica_ratios[
             next_replica_idx[exchange_proposed[i, 1]]])
     ratio *= (self.inverse_temperatures[exchange_proposed[i, 1]] -
               self.inverse_temperatures[exchange_proposed[i, 0]])
     self._seed_stream = distributions_util.gen_new_seed(
         self._seed_stream, salt='replica_exchange_one_step')
     log_uniform = tf.log(
         tf.random_uniform(shape=tf.shape(ratio),
                           dtype=ratio.dtype.base_dtype,
                           seed=self._seed_stream))
     exchange = log_uniform < ratio
     exchange_op = tf.sparse_to_dense(
         [exchange_proposed[i, 0], exchange_proposed[i, 1]],
         [self.num_replica], [
             next_replica_idx[exchange_proposed[i, 1]] -
             next_replica_idx[exchange_proposed[i, 0]],
             next_replica_idx[exchange_proposed[i, 0]] -
             next_replica_idx[exchange_proposed[i, 1]]
         ])
     next_replica_idx = tf.cond(
         exchange, lambda: next_replica_idx + exchange_op,
         lambda: next_replica_idx)
     return [i + 1, next_replica_idx]
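Both `body` variants implement the Metropolis test for a replica exchange: swapping replicas `a` and `b` is accepted when `log u < (log_prob[a] - log_prob[b]) * (beta_b - beta_a)`, where `log_prob` is the untempered target log-density. A NumPy sketch of that rule (illustrative only, not the TF op):

import numpy as np

def maybe_swap(replica_idx, a, b, log_prob, inverse_temps, rng=np.random):
    """Metropolis test for exchanging replicas a and b (sketch).

    `log_prob[i]` is the untempered target log-density at replica i's state and
    `inverse_temps[i]` its inverse temperature.
    """
    ratio = (log_prob[replica_idx[a]] - log_prob[replica_idx[b]]) * (
        inverse_temps[b] - inverse_temps[a])
    if np.log(rng.uniform()) < ratio:
        replica_idx = replica_idx.copy()
        replica_idx[a], replica_idx[b] = replica_idx[b], replica_idx[a]
    return replica_idx

idx = maybe_swap(np.arange(4), a=0, b=1,
                 log_prob=np.array([-1.0, -5.0, -9.0, -20.0]),
                 inverse_temps=np.array([1.0, 0.5, 0.25, 0.1]))
print(idx)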
Example #8
 def _sample_n(self, n, seed=None):
   # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[] then get
   # ids as a [n]-shaped vector.
   batch_size = (np.prod(self.batch_shape.as_list(), dtype=np.int32)
                 if self.batch_shape.is_fully_defined()
                 else math_ops.reduce_prod(self.batch_shape_tensor()))
   ids = self._mixture_distribution.sample(
       sample_shape=concat_vectors(
           [n],
           distribution_util.pick_vector(
               self.is_scalar_batch(),
               np.int32([]),
               [batch_size])),
       seed=distribution_util.gen_new_seed(
           seed, "poisson_lognormal_quadrature_compound"))
   # Stride `quadrature_degree` for `batch_size` number of times.
   offset = math_ops.range(start=0,
                           limit=batch_size * len(self.quadrature_probs),
                           delta=len(self.quadrature_probs),
                           dtype=ids.dtype)
   ids += offset
   rate = array_ops.gather(
       array_ops.reshape(self.distribution.rate, shape=[-1]), ids)
   rate = array_ops.reshape(
       rate, shape=concat_vectors([n], self.batch_shape_tensor()))
   return random_ops.random_poisson(
       lam=rate, shape=[], dtype=self.dtype, seed=seed)
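The comments above describe striding the mixture ids by the quadrature degree so each batch member gathers from its own block of quadrature rates. A small NumPy sketch of that indexing, with made-up shapes:

import numpy as np

batch_size, degree, n = 3, 4, 5
# Flattened [batch_size, degree] grid of per-quadrature-point Poisson rates.
rate_grid = np.arange(batch_size * degree, dtype=np.float64)

# ids[i, b] selects one of the `degree` quadrature points for batch member b.
ids = np.random.randint(degree, size=(n, batch_size))
# Stride by `degree` so batch member b indexes into its own block of the grid.
offset = np.arange(0, batch_size * degree, degree)
rate = rate_grid[ids + offset]   # shape [n, batch_size]
print(rate.shape)                # (5, 3)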
Example #9
  def _apply_variational_kernel(self, inputs):
    if (not isinstance(self.kernel_posterior, tfd.Independent) or
        not isinstance(self.kernel_posterior.distribution, tfd.Normal)):
      raise TypeError(
          '`DenseFlipout` requires '
          '`kernel_posterior_fn` produce an instance of '
          '`tf.distributions.Independent(tf.distributions.Normal)` '
          '(saw: \"{}\").'.format(self.kernel_posterior.name))
    self.kernel_posterior_affine = tfd.Normal(
        loc=tf.zeros_like(self.kernel_posterior.distribution.loc),
        scale=self.kernel_posterior.distribution.scale)
    self.kernel_posterior_affine_tensor = (
        self.kernel_posterior_tensor_fn(self.kernel_posterior_affine))
    self.kernel_posterior_tensor = None

    input_shape = tf.shape(inputs)
    batch_shape = input_shape[:-1]

    sign_input = random_rademacher(
        input_shape,
        dtype=inputs.dtype,
        seed=self.seed)
    sign_output = random_rademacher(
        tf.concat([batch_shape,
                   tf.expand_dims(self.units, 0)], 0),
        dtype=inputs.dtype,
        seed=distribution_util.gen_new_seed(
            self.seed, salt='dense_flipout'))
    perturbed_inputs = self._matmul(
        inputs * sign_input, self.kernel_posterior_affine_tensor) * sign_output

    outputs = self._matmul(inputs, self.kernel_posterior.distribution.loc)
    outputs += perturbed_inputs
    return outputs
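For intuition, a NumPy sketch of the Flipout trick used by `_apply_variational_kernel`: the deterministic term is `inputs @ loc`, while the stochastic term shares one kernel-noise sample across the batch but multiplies the inputs and the outputs by independent random signs, which decorrelates the perturbations between examples. The function below is illustrative, not the `tfp.layers` implementation:

import numpy as np

def flipout_dense_sketch(inputs, loc, scale, rng=np.random):
    """Illustrative Flipout forward pass (sketch only)."""
    batch, in_dim = inputs.shape
    out_dim = loc.shape[1]
    kernel_noise = rng.standard_normal((in_dim, out_dim)) * scale   # zero-mean posterior noise
    sign_in = rng.choice([-1.0, 1.0], size=(batch, in_dim))
    sign_out = rng.choice([-1.0, 1.0], size=(batch, out_dim))
    perturbation = ((inputs * sign_in) @ kernel_noise) * sign_out
    return inputs @ loc + perturbation

x = np.random.randn(8, 5)
y = flipout_dense_sketch(x, loc=np.zeros((5, 3)), scale=0.1)
print(y.shape)   # (8, 3)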
Example #10
 def generate_one(d):
   seed[0] = distributions_util.gen_new_seed(
       seed[0], salt='mcmc_sample_halton_sequence_4')
   fn = lambda _: tf.random_shuffle(tf.range(d), seed=seed[0])
   return tf.map_fn(
       fn,
       sample_range,
        parallel_iterations=1 if seed[0] is not None else 10)
Example #11
 def _sample_n(self, n, seed=None):
     if seed is None:
         seed = distribution_util.gen_new_seed(
             seed=np.random.randint(2**32 - 1), salt="autoregressive")
     samples = self.distribution0.sample(n, seed=seed)
     for _ in range(self._num_steps):
         samples = self.distribution_fn(samples).sample(seed=seed)
     return samples
Example #12
 def generate_one(d):
   seed[0] = distributions_util.gen_new_seed(
       seed[0], salt='mcmc_sample_halton_sequence_4')
   fn = lambda _: tf.random_shuffle(tf.range(d), seed=seed[0])
   return tf.map_fn(
       fn,
       sample_range,
       parallel_iterations=1 if seed[0] is not None else 10)
Example #13
  def _sample_n(self, n, seed):
    batch_shape = self.batch_shape_tensor()
    event_shape = self.event_shape_tensor()
    batch_ndims = array_ops.shape(batch_shape)[0]

    ndims = batch_ndims + 3  # sample_ndims=1, event_ndims=2
    shape = array_ops.concat([[n], batch_shape, event_shape], 0)

    # Complexity: O(nbk**2)
    x = random_ops.random_normal(shape=shape,
                                 mean=0.,
                                 stddev=1.,
                                 dtype=self.dtype,
                                 seed=seed)

    # Complexity: O(nbk)
    # This parametrization is equivalent to Chi2, i.e.,
    # ChiSquared(k) == Gamma(alpha=k/2, beta=1/2)
    g = random_ops.random_gamma(shape=[n],
                                alpha=self._multi_gamma_sequence(
                                    0.5 * self.df, self.dimension),
                                beta=0.5,
                                dtype=self.dtype,
                                seed=distribution_util.gen_new_seed(
                                    seed, "wishart"))

    # Complexity: O(nbk**2)
    x = array_ops.matrix_band_part(x, -1, 0)  # Tri-lower.

    # Complexity: O(nbk)
    x = array_ops.matrix_set_diag(x, math_ops.sqrt(g))

    # Make batch-op ready.
    # Complexity: O(nbk**2)
    perm = array_ops.concat([math_ops.range(1, ndims), [0]], 0)
    x = array_ops.transpose(x, perm)
    shape = array_ops.concat([batch_shape, [event_shape[0]], [-1]], 0)
    x = array_ops.reshape(x, shape)

    # Complexity: O(nbM) where M is the complexity of the operator solving a
    # vector system. E.g., for OperatorPDDiag, each matmul is O(k**2), so
    # this complexity is O(nbk**2). For OperatorPDCholesky, each matmul is
    # O(k^3) so this step has complexity O(nbk^3).
    x = self.scale_operator_pd.sqrt_matmul(x)

    # Undo make batch-op ready.
    # Complexity: O(nbk**2)
    shape = array_ops.concat([batch_shape, event_shape, [n]], 0)
    x = array_ops.reshape(x, shape)
    perm = array_ops.concat([[ndims - 1], math_ops.range(0, ndims - 1)], 0)
    x = array_ops.transpose(x, perm)

    if not self.cholesky_input_output_matrices:
      # Complexity: O(nbk^3)
      x = math_ops.matmul(x, x, adjoint_b=True)

    return x
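The sampler above relies on the identity noted in its comment, `ChiSquared(k) == Gamma(alpha=k/2, beta=1/2)` (the Bartlett-style construction puts the square roots of such draws on the diagonal). A quick NumPy check of the identity; note that NumPy's gamma is parameterized by `scale = 1/beta`:

import numpy as np

k, n = 7, 200000
rng = np.random.RandomState(0)
chi2 = rng.chisquare(k, size=n)
gamma = rng.gamma(shape=0.5 * k, scale=2.0, size=n)   # scale = 1 / beta = 2
print(chi2.mean(), gamma.mean())   # both close to k
print(chi2.var(), gamma.var())     # both close to 2 * k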
Example #14
 def _sample_n(self, n, seed=None):
   if seed is None:
     seed = distribution_util.gen_new_seed(
         seed=np.random.randint(2**32 - 1),
         salt="autoregressive")
   samples = self.distribution0.sample(n, seed=seed)
   for _ in range(self._num_steps):
     samples = self.distribution_fn(samples).sample(seed=seed)
   return samples
Example #15
 def _fn(state_parts, seed):
   next_state_parts = []
   for state in state_parts:
     # Mutate seed with each use.
     seed = distributions_util.gen_new_seed(
         seed, salt='random_walk_cauchy_increment')
     next_state_parts.append(state + cauchy.sample(
         sample_shape=state.shape, seed=seed))
   return next_state_parts
Example #16
    def _sample_n(self, n, seed):
        batch_shape = self.batch_shape_tensor()
        event_shape = self.event_shape_tensor()
        batch_ndims = array_ops.shape(batch_shape)[0]

        ndims = batch_ndims + 3  # sample_ndims=1, event_ndims=2
        shape = array_ops.concat([[n], batch_shape, event_shape], 0)

        # Complexity: O(nbk**2)
        x = random_ops.random_normal(shape=shape,
                                     mean=0.,
                                     stddev=1.,
                                     dtype=self.dtype,
                                     seed=seed)

        # Complexity: O(nbk)
        # This parametrization is equivalent to Chi2, i.e.,
        # ChiSquared(k) == Gamma(alpha=k/2, beta=1/2)
        g = random_ops.random_gamma(
            shape=[n],
            alpha=self._multi_gamma_sequence(0.5 * self.df, self.dimension),
            beta=0.5,
            dtype=self.dtype,
            seed=distribution_util.gen_new_seed(seed, "wishart"))

        # Complexity: O(nbk**2)
        x = array_ops.matrix_band_part(x, -1, 0)  # Tri-lower.

        # Complexity: O(nbk)
        x = array_ops.matrix_set_diag(x, math_ops.sqrt(g))

        # Make batch-op ready.
        # Complexity: O(nbk**2)
        perm = array_ops.concat([math_ops.range(1, ndims), [0]], 0)
        x = array_ops.transpose(x, perm)
        shape = array_ops.concat([batch_shape, [event_shape[0]], [-1]], 0)
        x = array_ops.reshape(x, shape)

        # Complexity: O(nbM) where M is the complexity of the operator solving a
        # vector system. E.g., for OperatorPDDiag, each matmul is O(k**2), so
        # this complexity is O(nbk**2). For OperatorPDCholesky, each matmul is
        # O(k^3) so this step has complexity O(nbk^3).
        x = self.scale_operator_pd.sqrt_matmul(x)

        # Undo make batch-op ready.
        # Complexity: O(nbk**2)
        shape = array_ops.concat([batch_shape, event_shape, [n]], 0)
        x = array_ops.reshape(x, shape)
        perm = array_ops.concat([[ndims - 1], math_ops.range(0, ndims - 1)], 0)
        x = array_ops.transpose(x, perm)

        if not self.cholesky_input_output_matrices:
            # Complexity: O(nbk^3)
            x = math_ops.matmul(x, x, adjoint_b=True)

        return x
Example #17
 def _fn(state_parts, seed):
     next_state_parts = []
     for state in state_parts:
         # Mutate seed with each use.
         seed = distributions_util.gen_new_seed(
             seed, salt='random_walk_cauchy_increment')
         next_state_parts.append(
             state +
             cauchy.sample(sample_shape=state.shape, seed=seed))
     return next_state_parts
Example #18
    def __init__(self,
                 target_log_prob_fn,
                 inverse_temperatures,
                 make_kernel_fn,
                 exchange_proposed_fn=default_exchange_proposed_fn(1.),
                 seed=None,
                 name=None,
                 **kwargs):
        """Instantiates this object.

    Args:
      target_log_prob_fn: Python callable which takes an argument like
        `current_state` (or `*current_state` if it's a list) and returns its
        (possibly unnormalized) log-density under the target distribution.
      inverse_temperatures: sequence of inverse temperatures at which to sample
        with each replica. Must have statically known `rank` and statically
        known leading shape, i.e.,
        `inverse_temperatures.shape[0].value is not None`.
      make_kernel_fn: Python callable which takes target_log_prob_fn and seed
        args and returns a TransitionKernel instance.
      exchange_proposed_fn: Python callable which takes a number of replicas
        and returns combinations of replicas proposed for exchange, along with
        the number of such combinations.
      seed: Python integer to seed the random number generator.
        Default value: `None` (i.e., no seed).
      name: Python `str` name prefixed to Ops created by this function.
        Default value: `None` (i.e., "remc_kernel").
      **kwargs: Arguments for `make_kernel_fn`.

    Raises:
      ValueError: if `inverse_temperatures` doesn't have statically known rank
        and statically known leading shape
    """
        if inverse_temperatures.shape.ndims is None or \
           inverse_temperatures.shape[0].value is None:
            raise ValueError(
                '"inverse_temperatures" must have statically known rank '
                'and statically known leading shape')
        self._seed_stream = seed  # This will be mutated with use.
        self._parameters = dict(target_log_prob_fn=target_log_prob_fn,
                                inverse_temperatures=inverse_temperatures,
                                num_replica=inverse_temperatures.shape[0],
                                exchange_proposed_fn=exchange_proposed_fn,
                                seed=seed,
                                name=name)
        self.replica_kernels = []
        for i in range(self.num_replica):
            self._seed_stream = distributions_util.gen_new_seed(
                self._seed_stream, salt='replica_kernels')
            self.replica_kernels.append(
                make_kernel_fn(target_log_prob_fn=_replica_log_prob_fn(
                    inverse_temperatures[i], target_log_prob_fn),
                               seed=self._seed_stream))
Example #19
 def _sample_n(self, n, seed=None):
     # Here we use the fact that if:
     # lam ~ Gamma(concentration=total_count, rate=(1-probs)/probs)
     # then X ~ Poisson(lam) is Negative Binomially distributed.
     rate = random_ops.random_gamma(shape=[n],
                                    alpha=self.total_count,
                                    beta=math_ops.exp(-self.logits),
                                    dtype=self.dtype,
                                    seed=seed)
     return random_ops.random_poisson(rate,
                                      shape=[],
                                      dtype=self.dtype,
                                      seed=distribution_util.gen_new_seed(
                                          seed, "negative_binom"))
Example #20
 def _sample_n(self, n, seed=None):
   # Here we use the fact that if:
   # lam ~ Gamma(concentration=total_count, rate=(1-probs)/probs)
   # then X ~ Poisson(lam) is Negative Binomially distributed.
   rate = random_ops.random_gamma(
       shape=[n],
       alpha=self.total_count,
       beta=math_ops.exp(-self.logits),
       dtype=self.dtype,
       seed=seed)
   return random_ops.random_poisson(
       rate,
       shape=[],
       dtype=self.dtype,
       seed=distribution_util.gen_new_seed(seed, "negative_binom"))
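The comment states the gamma-Poisson mixture fact this sampler uses: with `lam ~ Gamma(total_count, (1 - probs)/probs)`, a Poisson draw with rate `lam` is negative-binomially distributed. A NumPy check of the first two moments (assuming the mean `total_count * probs / (1 - probs)` parameterization used here):

import numpy as np

r, p, n = 5.0, 0.3, 200000              # total_count and probs as in the snippet above
rng = np.random.RandomState(0)
lam = rng.gamma(shape=r, scale=p / (1. - p), size=n)   # NumPy scale = 1 / rate
x = rng.poisson(lam)
print(x.mean(), r * p / (1. - p))        # mixture mean matches the NB mean
print(x.var(), r * p / (1. - p) ** 2)    # and the NB variance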
Example #21
def _randomize(coeffs, radixes, seed=None):
  """Applies the Owen (2017) randomization to the coefficients."""
  given_dtype = coeffs.dtype
  coeffs = tf.to_int32(coeffs)
  num_coeffs = tf.shape(coeffs)[-1]
  radixes = tf.reshape(tf.to_int32(radixes), shape=[-1])
  seed = distributions_util.gen_new_seed(
      seed, salt='mcmc_sample_halton_sequence_3')
  perms = _get_permutations(num_coeffs, radixes, seed=seed)
  perms = tf.reshape(perms, shape=[-1])
  radix_sum = tf.reduce_sum(radixes)
  radix_offsets = tf.reshape(tf.cumsum(radixes, exclusive=True),
                             shape=[-1, 1])
  offsets = radix_offsets + tf.range(num_coeffs) * radix_sum
  permuted_coeffs = tf.gather(perms, coeffs + offsets)
  return tf.cast(permuted_coeffs, dtype=given_dtype)
Example #22
 def _sample_n(self, n, seed=None):
     expanded_concentration1 = array_ops.ones_like(
         self.total_concentration, dtype=self.dtype) * self.concentration1
     expanded_concentration0 = array_ops.ones_like(
         self.total_concentration, dtype=self.dtype) * self.concentration0
     gamma1_sample = random_ops.random_gamma(shape=[n],
                                             alpha=expanded_concentration1,
                                             dtype=self.dtype,
                                             seed=seed)
     gamma2_sample = random_ops.random_gamma(
         shape=[n],
         alpha=expanded_concentration0,
         dtype=self.dtype,
         seed=distribution_util.gen_new_seed(seed, "beta"))
     beta_sample = gamma1_sample / (gamma1_sample + gamma2_sample)
     return beta_sample
Example #23
def _randomize(coeffs, radixes, seed=None):
  """Applies the Owen (2017) randomization to the coefficients."""
  given_dtype = coeffs.dtype
  coeffs = tf.to_int32(coeffs)
  num_coeffs = tf.shape(coeffs)[-1]
  radixes = tf.reshape(tf.to_int32(radixes), shape=[-1])
  seed = distributions_util.gen_new_seed(
      seed, salt='mcmc_sample_halton_sequence_3')
  perms = _get_permutations(num_coeffs, radixes, seed=seed)
  perms = tf.reshape(perms, shape=[-1])
  radix_sum = tf.reduce_sum(radixes)
  radix_offsets = tf.reshape(tf.cumsum(radixes, exclusive=True),
                             shape=[-1, 1])
  offsets = radix_offsets + tf.range(num_coeffs) * radix_sum
  permuted_coeffs = tf.gather(perms, coeffs + offsets)
  return tf.cast(permuted_coeffs, dtype=given_dtype)
Example #24
  def __init__(self, target_log_prob_fn, inverse_temperatures,
               make_kernel_fn,
               exchange_proposed_fn=default_exchange_proposed_fn(1.),
               seed=None, name=None, **kwargs):
    """Instantiates this object.

    Args:
      target_log_prob_fn: Python callable which takes an argument like
        `current_state` (or `*current_state` if it's a list) and returns its
        (possibly unnormalized) log-density under the target distribution.
      inverse_temperatures: sequence of inverse temperatures at which to sample
        with each replica. Must have statically known `rank` and statically
        known leading shape, i.e.,
        `inverse_temperatures.shape[0].value is not None`.
      make_kernel_fn: Python callable which takes target_log_prob_fn and seed
        args and returns a TransitionKernel instance.
      exchange_proposed_fn: Python callable which takes a number of replicas
        and returns combinations of replicas proposed for exchange, along with
        the number of such combinations.
      seed: Python integer to seed the random number generator.
        Default value: `None` (i.e., no seed).
      name: Python `str` name prefixed to Ops created by this function.
        Default value: `None` (i.e., "remc_kernel").
      **kwargs: Arguments for `make_kernel_fn`.

    Raises:
      ValueError: if `inverse_temperatures` doesn't have statically known rank
        and statically known leading shape
    """
    if inverse_temperatures.shape.ndims is None or \
       inverse_temperatures.shape[0].value is None:
      raise ValueError('"inverse_temperatures" must have statically known rank '
                       'and statically known leading shape')
    self._seed_stream = seed  # This will be mutated with use.
    self._parameters = dict(target_log_prob_fn=target_log_prob_fn,
                            inverse_temperatures=inverse_temperatures,
                            num_replica=inverse_temperatures.shape[0],
                            exchange_proposed_fn=exchange_proposed_fn,
                            seed=seed, name=name)
    self.replica_kernels = []
    for i in range(self.num_replica):
      self._seed_stream = distributions_util.gen_new_seed(
          self._seed_stream, salt='replica_kernels')
      self.replica_kernels.append(make_kernel_fn(
          target_log_prob_fn=_replica_log_prob_fn(
              inverse_temperatures[i], target_log_prob_fn),
          seed=self._seed_stream))
Example #25
 def _sample_n(self, n, seed=None):
   expanded_concentration1 = array_ops.ones_like(
       self.total_concentration, dtype=self.dtype) * self.concentration1
   expanded_concentration0 = array_ops.ones_like(
       self.total_concentration, dtype=self.dtype) * self.concentration0
   gamma1_sample = random_ops.random_gamma(
       shape=[n],
       alpha=expanded_concentration1,
       dtype=self.dtype,
       seed=seed)
   gamma2_sample = random_ops.random_gamma(
       shape=[n],
       alpha=expanded_concentration0,
       dtype=self.dtype,
       seed=distribution_util.gen_new_seed(seed, "beta"))
   beta_sample = gamma1_sample / (gamma1_sample + gamma2_sample)
   return beta_sample
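The Beta sampler above uses the classic construction `Beta(a, b) = G1 / (G1 + G2)` with independent `G1 ~ Gamma(a)` and `G2 ~ Gamma(b)`. A NumPy check of the resulting moments:

import numpy as np

a, b, n = 2.0, 5.0, 200000
rng = np.random.RandomState(0)
g1 = rng.gamma(shape=a, size=n)
g2 = rng.gamma(shape=b, size=n)
beta_from_gammas = g1 / (g1 + g2)
print(beta_from_gammas.mean(), a / (a + b))                           # Beta mean
print(beta_from_gammas.var(), a * b / ((a + b) ** 2 * (a + b + 1)))   # Beta variance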
Example #26
 def _sample_n(self, n, seed=None):
     n_draws = math_ops.cast(self.total_count, dtype=dtypes.int32)
     k = self.event_shape_tensor()[0]
     unnormalized_logits = array_ops.reshape(math_ops.log(
         random_ops.random_gamma(shape=[n],
                                 alpha=self.concentration,
                                 dtype=self.dtype,
                                 seed=seed)),
                                             shape=[-1, k])
     draws = random_ops.multinomial(logits=unnormalized_logits,
                                    num_samples=n_draws,
                                    seed=distribution_util.gen_new_seed(
                                        seed, salt="dirichlet_multinomial"))
     x = math_ops.reduce_sum(array_ops.one_hot(draws, depth=k), -2)
     final_shape = array_ops.concat(
         [[n], self.batch_shape_tensor(), [k]], 0)
     return array_ops.reshape(x, final_shape)
Example #27
 def _sample_n(self, n, seed=None):
   n_draws = math_ops.cast(self.total_count, dtype=dtypes.int32)
   k = self.event_shape_tensor()[0]
   unnormalized_logits = array_ops.reshape(
       math_ops.log(random_ops.random_gamma(
           shape=[n],
           alpha=self.concentration,
           dtype=self.dtype,
           seed=seed)),
       shape=[-1, k])
   draws = random_ops.multinomial(
       logits=unnormalized_logits,
       num_samples=n_draws,
       seed=distribution_util.gen_new_seed(seed, salt="dirichlet_multinomial"))
   x = math_ops.reduce_sum(array_ops.one_hot(draws, depth=k), -2)
   final_shape = array_ops.concat([[n], self.batch_shape_tensor(), [k]], 0)
   return array_ops.reshape(x, final_shape)
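Both variants draw `Gamma(concentration)` variables and pass their logs to `multinomial` as unnormalized logits, which is equivalent to first normalizing the gamma draws into a `Dirichlet(concentration)` probability vector and then drawing multinomial counts. A NumPy sketch of the Dirichlet-from-gammas step:

import numpy as np

concentration = np.array([1.0, 2.0, 3.0])
n = 100000
rng = np.random.RandomState(0)
g = rng.gamma(shape=concentration, size=(n, len(concentration)))
p_from_gammas = g / g.sum(axis=-1, keepdims=True)    # normalized gammas
p_dirichlet = rng.dirichlet(concentration, size=n)   # direct Dirichlet draws
print(p_from_gammas.mean(axis=0))   # both close to concentration / concentration.sum()
print(p_dirichlet.mean(axis=0))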
Example #28
def flipout_dense(units, inputs, loc, scale, sample, seed=None):
    outputs = tf.matmul(inputs, loc)
    if sample:
        kernel_noise = reparametrize(0., scale, sample=True)
        input_shape = tf.shape(inputs)
        batch_shape = input_shape[:-1]

        sign_input = random_sign(input_shape, dtype=inputs.dtype, seed=seed)
        sign_output = random_sign(
            tf.concat([batch_shape, tf.expand_dims(units, 0)], 0),
            dtype=inputs.dtype,
            seed=distribution_util.gen_new_seed(seed, salt="dense_flipout"))

        perturbed_inputs = tf.matmul(inputs * sign_input,
                                     kernel_noise) * sign_output
        outputs += perturbed_inputs
    return outputs
Example #29
 def _sample_n(self, n, seed=None):
   # The sampling method comes from the fact that if:
   #   X ~ Normal(0, 1)
   #   Z ~ Chi2(df)
   #   Y = X / sqrt(Z / df)
   # then:
   #   Y ~ StudentT(df).
   shape = array_ops.concat([[n], self.batch_shape_tensor()], 0)
   normal_sample = random_ops.random_normal(shape, dtype=self.dtype, seed=seed)
   df = self.df * array_ops.ones(self.batch_shape_tensor(), dtype=self.dtype)
   gamma_sample = random_ops.random_gamma(
       [n],
       0.5 * df,
       beta=0.5,
       dtype=self.dtype,
       seed=distribution_util.gen_new_seed(seed, salt="student_t"))
   samples = normal_sample * math_ops.rsqrt(gamma_sample / df)
   return samples * self.scale + self.loc  # Abs(scale) not wanted.
Example #30
 def _sample_n(self, n, seed=None):
   # The sampling method comes from the fact that if:
   #   X ~ Normal(0, 1)
   #   Z ~ Chi2(df)
   #   Y = X / sqrt(Z / df)
   # then:
   #   Y ~ StudentT(df).
   shape = array_ops.concat([[n], self.batch_shape_tensor()], 0)
   normal_sample = random_ops.random_normal(shape, dtype=self.dtype, seed=seed)
   df = self.df * array_ops.ones(self.batch_shape_tensor(), dtype=self.dtype)
   gamma_sample = random_ops.random_gamma(
       [n],
       0.5 * df,
       beta=0.5,
       dtype=self.dtype,
       seed=distribution_util.gen_new_seed(seed, salt="student_t"))
   samples = normal_sample * math_ops.rsqrt(gamma_sample / df)
   return samples * self.scale + self.loc  # Abs(scale) not wanted.
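The comment spells out the construction used here: `Y = X / sqrt(Z / df)` with `X ~ Normal(0, 1)` and `Z ~ Chi2(df)` is Student-t distributed, and `Chi2(df)` is drawn as `Gamma(df/2, beta=1/2)`. A NumPy check against `standard_t`:

import numpy as np

df, n = 5.0, 200000
rng = np.random.RandomState(0)
z = rng.standard_normal(n)
g = rng.gamma(shape=0.5 * df, scale=2.0, size=n)   # Chi2(df) == Gamma(df/2, beta=1/2)
t = z / np.sqrt(g / df)
print(t.var(), df / (df - 2.0))                    # Student-t variance for df > 2
print(rng.standard_t(df, size=n).var())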
Example #31
def flipout_conv(filters,
                 rank,
                 inputs,
                 loc,
                 scale,
                 sample,
                 conv_op,
                 seed=None):
    outputs = conv_op(inputs, loc)
    if sample:
        kernel_noise = reparametrize(0., scale, sample=True)
        input_shape = tf.shape(inputs)
        output_shape = tf.shape(outputs)
        batch_shape = tf.expand_dims(input_shape[0], 0)
        channels = input_shape[-1]

        sign_input = random_sign(tf.concat(
            [batch_shape, tf.expand_dims(channels, 0)], 0),
                                 dtype=inputs.dtype,
                                 seed=seed)
        sign_output = random_sign(
            tf.concat([batch_shape, tf.expand_dims(filters, 0)], 0),
            dtype=inputs.dtype,
            seed=distribution_util.gen_new_seed(seed, salt="conv_flipout"))
        for _ in range(rank):
            sign_input = tf.expand_dims(sign_input, 1)
            sign_output = tf.expand_dims(sign_output, 1)

        sign_input = tf.tile(sign_input,
                             [1] + [input_shape[i + 1]
                                    for i in range(rank)] + [1])
        sign_output = tf.tile(sign_output,
                              [1] + [output_shape[i + 1]
                                     for i in range(rank)] + [1])

        perturbed_inputs = conv_op(inputs * sign_input,
                                   kernel_noise) * sign_output

        outputs += perturbed_inputs
    return outputs
Example #32
  def _fn(state_parts, seed):
    """Adds a normal perturbation to the input state.

    Args:
      state_parts: A list of `Tensor`s of any shape and real dtype representing
        the state parts of the `current_state` of the Markov chain.
      seed: `int` or None. The random seed for this `Op`. If `None`, no seed is
        applied.
        Default value: `None`.

    Returns:
      perturbed_state_parts: A Python `list` of `Tensor`s with the same shape
        and type as `state_parts`.

    Raises:
      ValueError: if `scale` does not broadcast with `state_parts`.
    """
    with tf.name_scope(name, 'random_walk_normal_fn',
                       values=[state_parts, scale, seed]):
      scales = scale if mcmc_util.is_list_like(scale) else [scale]
      if len(scales) == 1:
        scales *= len(state_parts)
      if len(state_parts) != len(scales):
        raise ValueError('`scale` must broadcast with `state_parts`.')
      next_state_parts = []
      for scale_part, state_part in zip(scales, state_parts):
        # Mutate seed with each use.
        seed = distributions_util.gen_new_seed(
            seed, salt='random_walk_normal_fn')
        next_state_parts.append(tf.random_normal(
            mean=state_part,
            stddev=scale_part,
            shape=tf.shape(state_part),
            dtype=state_part.dtype.base_dtype,
            seed=seed))
      return next_state_parts
Example #33
    def one_step(self, current_state, previous_kernel_results):
        with tf.name_scope(self.name, 'hmc_kernel', [
                self.step_size, self.num_leapfrog_steps, self.seed,
                current_state, previous_kernel_results.target_log_prob,
                previous_kernel_results.grads_target_log_prob
        ]):
            with tf.name_scope('initialize'):
                [
                    current_state_parts,
                    step_sizes,
                    current_target_log_prob,
                    current_grads_target_log_prob,
                ] = _prepare_args(
                    self.target_log_prob_fn,
                    current_state,
                    self.step_size,
                    previous_kernel_results.target_log_prob,
                    previous_kernel_results.grads_target_log_prob,
                    maybe_expand=True)

                current_momentums = []
                for s in current_state_parts:
                    # Note:
                    # - We mutate seed state so subsequent calls are not correlated.
                    # - We mutate seed BEFORE using it just in case users supplied the
                    #   same seed to an outer kernel, e.g., `MetropolisHastings`.
                    self._seed = distributions_util.gen_new_seed(
                        self.seed, salt='hmc_kernel_momentums')
                    current_momentums.append(
                        tf.random_normal(shape=tf.shape(s),
                                         dtype=s.dtype.base_dtype,
                                         seed=self.seed))

                num_leapfrog_steps = tf.convert_to_tensor(
                    self.num_leapfrog_steps,
                    dtype=tf.int32,
                    name='num_leapfrog_steps')

            independent_chain_ndims = distributions_util.prefer_static_rank(
                current_target_log_prob)

            [
                next_momentums,
                next_state_parts,
                next_target_log_prob,
                next_grads_target_log_prob,
            ] = _leapfrog_integrator(current_momentums,
                                     self.target_log_prob_fn,
                                     current_state_parts, step_sizes,
                                     num_leapfrog_steps,
                                     current_target_log_prob,
                                     current_grads_target_log_prob)

            def maybe_flatten(x):
                return x if mcmc_util.is_list_like(current_state) else x[0]

            return [
                maybe_flatten(next_state_parts),
                UncalibratedHamiltonianMonteCarloKernelResults(
                    log_acceptance_correction=
                    _compute_log_acceptance_correction(
                        current_momentums, next_momentums,
                        independent_chain_ndims),
                    target_log_prob=next_target_log_prob,
                    grads_target_log_prob=next_grads_target_log_prob,
                ),
            ]
Example #34
def kernel(target_log_prob_fn,
           current_state,
           step_size,
           num_leapfrog_steps,
           seed=None,
           current_target_log_prob=None,
           current_grads_target_log_prob=None,
           name=None):
  """Runs one iteration of Hamiltonian Monte Carlo.

  Hamiltonian Monte Carlo (HMC) is a Markov chain Monte Carlo (MCMC)
  algorithm that takes a series of gradient-informed steps to produce
  a Metropolis proposal. This function applies one step of HMC to
  randomly update the variable `x`.

  This function can update multiple chains in parallel. It assumes that all
  leftmost dimensions of `current_state` index independent chain states (and are
  therefore updated independently). The output of `target_log_prob_fn()` should
  sum log-probabilities across all event dimensions. Slices along the rightmost
  dimensions may have different target distributions; for example,
  `current_state[0, :]` could have a different target distribution from
  `current_state[1, :]`. This is up to `target_log_prob_fn()`. (The number of
  independent chains is `tf.size(target_log_prob_fn(*current_state))`.)

  #### Examples:

  ##### Simple chain with warm-up.

  ```python
  tfd = tf.contrib.distributions

  # Tuning acceptance rates:
  dtype = np.float32
  target_accept_rate = 0.631
  num_warmup_iter = 500
  num_chain_iter = 500

  x = tf.get_variable(name="x", initializer=dtype(1))
  step_size = tf.get_variable(name="step_size", initializer=dtype(1))

  target = tfd.Normal(loc=dtype(0), scale=dtype(1))

  new_x, other_results = hmc.kernel(
      target_log_prob_fn=target.log_prob,
      current_state=x,
      step_size=step_size,
      num_leapfrog_steps=3)[:4]

  x_update = x.assign(new_x)

  step_size_update = step_size.assign_add(
      step_size * tf.where(
        other_results.acceptance_probs > target_accept_rate,
        0.01, -0.01))

  warmup = tf.group([x_update, step_size_update])

  tf.global_variables_initializer().run()

  sess.graph.finalize()  # No more graph building.

  # Warm up the sampler and adapt the step size
  for _ in xrange(num_warmup_iter):
    sess.run(warmup)

  # Collect samples without adapting step size
  samples = np.zeros([num_chain_iter])
  for i in xrange(num_chain_iter):
    _, x_, target_log_prob_, grad_ = sess.run([
        x_update,
        x,
        other_results.target_log_prob,
        other_results.grads_target_log_prob])
    samples[i] = x_

  print(samples.mean(), samples.std())
  ```

  ##### Sample from more complicated posterior.

  I.e.,

  ```none
    W ~ MVN(loc=0, scale=sigma * eye(dims))
    for i=1...num_samples:
      X[i] ~ MVN(loc=0, scale=eye(dims))
      eps[i] ~ Normal(loc=0, scale=1)
      Y[i] = X[i].T * W + eps[i]
  ```

  ```python
  tfd = tf.contrib.distributions

  def make_training_data(num_samples, dims, sigma):
    dt = np.asarray(sigma).dtype
    zeros = tf.zeros(dims, dtype=dt)
    x = tfd.MultivariateNormalDiag(
        loc=zeros).sample(num_samples, seed=1)
    w = tfd.MultivariateNormalDiag(
        loc=zeros,
        scale_identity_multiplier=sigma).sample(seed=2)
    noise = tfd.Normal(
        loc=dt(0),
        scale=dt(1)).sample(num_samples, seed=3)
    y = tf.tensordot(x, w, axes=[[1], [0]]) + noise
    return y, x, w

  def make_prior(sigma, dims):
    # p(w | sigma)
    return tfd.MultivariateNormalDiag(
        loc=tf.zeros([dims], dtype=sigma.dtype),
        scale_identity_multiplier=sigma)

  def make_likelihood(x, w):
    # p(y | x, w)
    return tfd.MultivariateNormalDiag(
        loc=tf.tensordot(x, w, axes=[[1], [0]]))

  # Setup assumptions.
  dtype = np.float32
  num_samples = 150
  dims = 10
  num_iters = int(5e3)

  true_sigma = dtype(0.5)
  y, x, true_weights = make_training_data(num_samples, dims, true_sigma)

  # Estimate of `log(true_sigma)`.
  log_sigma = tf.get_variable(name="log_sigma", initializer=dtype(0))
  sigma = tf.exp(log_sigma)

  # State of the Markov chain.
  weights = tf.get_variable(
      name="weights",
      initializer=np.random.randn(dims).astype(dtype))

  prior = make_prior(sigma, dims)

  def joint_log_prob_fn(w):
    # f(w) = log p(w, y | x)
    return prior.log_prob(w) + make_likelihood(x, w).log_prob(y)

  weights_update = weights.assign(
      hmc.kernel(target_log_prob_fn=joint_log_prob_fn,
                 current_state=weights,
                 step_size=0.1,
                 num_leapfrog_steps=5)[0])

  with tf.control_dependencies([weights_update]):
    loss = -prior.log_prob(weights)

  optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
  log_sigma_update = optimizer.minimize(loss, var_list=[log_sigma])

  sess.graph.finalize()  # No more graph building.

  tf.global_variables_initializer().run()

  sigma_history = np.zeros(num_iters, dtype)
  weights_history = np.zeros([num_iters, dims], dtype)

  for i in xrange(num_iters):
    _, sigma_, weights_, _ = sess.run([log_sigma_update, sigma, weights])
    weights_history[i, :] = weights_
    sigma_history[i] = sigma_

  true_weights_ = sess.run(true_weights)

  # Should converge to something close to true_sigma.
  plt.plot(sigma_history);
  plt.ylabel("sigma");
  plt.xlabel("iteration");
  ```

  Args:
    target_log_prob_fn: Python callable which takes an argument like
      `current_state` (or `*current_state` if it's a list) and returns its
      (possibly unnormalized) log-density under the target distribution.
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s). The first `r` dimensions index
      independent chains, `r = tf.rank(target_log_prob_fn(*current_state))`.
    step_size: `Tensor` or Python `list` of `Tensor`s representing the step size
      for the leapfrog integrator. Must broadcast with the shape of
      `current_state`. Larger step sizes lead to faster progress, but too-large
      step sizes make rejection exponentially more likely. When possible, it's
      often helpful to match per-variable step sizes to the standard deviations
      of the target distribution in each variable.
    num_leapfrog_steps: Integer number of steps to run the leapfrog integrator
      for. Total progress per HMC step is roughly proportional to `step_size *
      num_leapfrog_steps`.
    seed: Python integer to seed the random number generator.
    current_target_log_prob: (Optional) `Tensor` representing the value of
      `target_log_prob_fn` at the `current_state`. The only reason to
      specify this argument is to reduce TF graph size.
      Default value: `None` (i.e., compute as needed).
    current_grads_target_log_prob: (Optional) Python list of `Tensor`s
      representing the gradient of `current_target_log_prob` with respect to
      `current_state`. Must have the same shape as `current_state`. The
      only reason to specify this argument is to reduce TF graph size.
      Default value: `None` (i.e., compute as needed).
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., "hmc_kernel").

  Returns:
    accepted_state: Tensor or Python list of `Tensor`s representing the state(s)
      of the Markov chain(s) at each result step. Has same shape as
      `current_state`.
    kernel_results: `collections.namedtuple` of internal calculations used to
      advance the chain.

  Raises:
    ValueError: if there isn't one `step_size` or a list with the same length
      as `current_state`.
  """
  with ops.name_scope(
      name, "hmc_kernel",
      [current_state, step_size, num_leapfrog_steps, seed,
       current_target_log_prob, current_grads_target_log_prob]):
    with ops.name_scope("initialize"):
      [current_state_parts, step_sizes, current_target_log_prob,
       current_grads_target_log_prob] = _prepare_args(
           target_log_prob_fn, current_state, step_size,
           current_target_log_prob, current_grads_target_log_prob,
           maybe_expand=True)
      independent_chain_ndims = distributions_util.prefer_static_rank(
          current_target_log_prob)
      current_momentums = []
      for s in current_state_parts:
        current_momentums.append(random_ops.random_normal(
            shape=array_ops.shape(s),
            dtype=s.dtype.base_dtype,
            seed=seed))
        seed = distributions_util.gen_new_seed(
            seed, salt="hmc_kernel_momentums")

      num_leapfrog_steps = ops.convert_to_tensor(
          num_leapfrog_steps,
          dtype=dtypes.int32,
          name="num_leapfrog_steps")
    [
        proposed_momentums,
        proposed_state_parts,
        proposed_target_log_prob,
        proposed_grads_target_log_prob,
    ] = _leapfrog_integrator(current_momentums,
                             target_log_prob_fn,
                             current_state_parts,
                             step_sizes,
                             num_leapfrog_steps,
                             current_target_log_prob,
                             current_grads_target_log_prob)

    energy_change = _compute_energy_change(current_target_log_prob,
                                           current_momentums,
                                           proposed_target_log_prob,
                                           proposed_momentums,
                                           independent_chain_ndims)

    # u < exp(min(-energy, 0)),  where u~Uniform[0,1)
    # ==> -log(u) >= max(e, 0)
    # ==> -log(u) >= e
    # (Perhaps surprisingly, we don't have a better way to obtain a random
    # uniform from positive reals, i.e., `tf.random_uniform(minval=0,
    # maxval=np.inf)` won't work.)
    random_uniform = random_ops.random_uniform(
        shape=array_ops.shape(energy_change),
        dtype=energy_change.dtype,
        seed=seed)
    random_positive = -math_ops.log(random_uniform)
    is_accepted = random_positive >= energy_change

    accepted_target_log_prob = array_ops.where(is_accepted,
                                               proposed_target_log_prob,
                                               current_target_log_prob)

    accepted_state_parts = [_choose(is_accepted,
                                    proposed_state_part,
                                    current_state_part,
                                    independent_chain_ndims)
                            for current_state_part, proposed_state_part
                            in zip(current_state_parts, proposed_state_parts)]

    accepted_grads_target_log_prob = [
        _choose(is_accepted,
                proposed_grad,
                grad,
                independent_chain_ndims)
        for proposed_grad, grad
        in zip(proposed_grads_target_log_prob, current_grads_target_log_prob)]

    maybe_flatten = lambda x: x if _is_list_like(current_state) else x[0]
    return [
        maybe_flatten(accepted_state_parts),
        KernelResults(
            acceptance_probs=math_ops.exp(math_ops.minimum(-energy_change, 0.)),
            current_grads_target_log_prob=accepted_grads_target_log_prob,
            current_target_log_prob=accepted_target_log_prob,
            energy_change=energy_change,
            is_accepted=is_accepted,
            proposed_grads_target_log_prob=proposed_grads_target_log_prob,
            proposed_state=maybe_flatten(proposed_state_parts),
            proposed_target_log_prob=proposed_target_log_prob,
            random_positive=random_positive,
        ),
    ]
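The comment before `random_uniform` in `kernel` notes that accepting when `-log(u) >= energy_change` is equivalent to the usual `u < exp(min(-energy_change, 0))` Metropolis test. A small NumPy check of that equivalence:

import numpy as np

rng = np.random.RandomState(0)
energy_change = 2.0 * rng.randn(100000)
u = rng.uniform(size=energy_change.shape)
accept_a = u < np.exp(np.minimum(-energy_change, 0.))
accept_b = -np.log(u) >= energy_change
print((accept_a == accept_b).mean())   # 1.0; they differ only on a measure-zero boundary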
Example #35
  def one_step(self, current_state, previous_kernel_results):
    """Takes one step of the TransitionKernel.

    Args:
      current_state: `Tensor` or Python `list` of `Tensor`s representing the
        current state(s) of the Markov chain(s).
      previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
        `list` of `Tensor`s representing internal calculations made within the
        previous call to this function (or as returned by `bootstrap_results`).

    Returns:
      next_state: `Tensor` or Python `list` of `Tensor`s representing the
        next state(s) of the Markov chain(s).
      kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
        `Tensor`s representing internal calculations made within this function.
        This includes replica states.
    """
    with tf.name_scope(
        name=mcmc_util.make_name(self.name, 'remc', 'one_step'),
        values=[current_state, previous_kernel_results]):
      sampled_replica_states, sampled_replica_results = zip(*[
          rk.one_step(previous_kernel_results.replica_states[i],
                      previous_kernel_results.replica_results[i])
          for i, rk in enumerate(self.replica_kernels)])
      sampled_replica_states = list(sampled_replica_states)
      sampled_replica_results = list(sampled_replica_results)

      sampled_replica_results_modified = [
          srr._replace(target_log_prob=srr.target_log_prob /
                       self.inverse_temperatures[i])
          if 'target_log_prob' in srr._fields
          else srr._replace(accepted_results=srr.accepted_results._replace(
              target_log_prob=srr.accepted_results.target_log_prob /
              self.inverse_temperatures[i]))
          for i, srr in enumerate(sampled_replica_results)
      ]

      sampled_replica_ratios = [
          srr.target_log_prob if 'target_log_prob' in srr._fields
          else srr.accepted_results.target_log_prob
          for i, srr in enumerate(sampled_replica_results_modified)]
      sampled_replica_ratios = tf.stack(sampled_replica_ratios, axis=-1)

      next_replica_idx = tf.range(self.num_replica)
      self._seed_stream = distributions_util.gen_new_seed(
          self._seed_stream, salt='replica_exchange_one_step')
      exchange_proposed, exchange_proposed_n = self.exchange_proposed_fn(
          self.num_replica, seed=self._seed_stream)
      i = tf.constant(0)

      def cond(i, next_replica_idx):  # pylint: disable=unused-argument
        return tf.less(i, exchange_proposed_n)

      def body(i, next_replica_idx):
        """`tf.while_loop` body."""
        ratio = (
            sampled_replica_ratios[next_replica_idx[exchange_proposed[i, 0]]]
            - sampled_replica_ratios[next_replica_idx[exchange_proposed[i, 1]]])
        ratio *= (
            self.inverse_temperatures[exchange_proposed[i, 1]]
            - self.inverse_temperatures[exchange_proposed[i, 0]])
        self._seed_stream = distributions_util.gen_new_seed(
            self._seed_stream, salt='replica_exchange_one_step')
        log_uniform = tf.log(tf.random_uniform(
            shape=tf.shape(ratio),
            dtype=ratio.dtype.base_dtype,
            seed=self._seed_stream))
        exchange = log_uniform < ratio
        exchange_op = tf.sparse_to_dense(
            [exchange_proposed[i, 0], exchange_proposed[i, 1]],
            [self.num_replica],
            [next_replica_idx[exchange_proposed[i, 1]] -
             next_replica_idx[exchange_proposed[i, 0]],
             next_replica_idx[exchange_proposed[i, 0]] -
             next_replica_idx[exchange_proposed[i, 1]]])
        next_replica_idx = tf.cond(exchange,
                                   lambda: next_replica_idx + exchange_op,
                                   lambda: next_replica_idx)
        return [i + 1, next_replica_idx]

      next_replica_idx = tf.while_loop(
          cond, body, loop_vars=[i, next_replica_idx])[1]

      def _prep(list_):
        return list(
            tf.case({tf.equal(next_replica_idx[i], j):
                     _stateful_lambda(list_[j])
                     for j in range(self.num_replica)}, exclusive=True)
            for i in range(self.num_replica))
      next_replica_states = _prep(sampled_replica_states)
      next_replica_results = _prep(sampled_replica_results_modified)

      next_replica_results = [
          nrr._replace(target_log_prob=nrr.target_log_prob *
                       self.inverse_temperatures[i])
          if 'target_log_prob' in nrr._fields
          else nrr._replace(accepted_results=nrr.accepted_results._replace(
              target_log_prob=nrr.accepted_results.target_log_prob *
              self.inverse_temperatures[i]))
          for i, nrr in enumerate(next_replica_results)
      ]

      next_state = tf.identity(next_replica_states[0])
      kernel_results = ReplicaExchangeMCKernelResults(
          replica_states=next_replica_states,
          replica_results=next_replica_results,
          next_replica_idx=next_replica_idx,
          exchange_proposed=exchange_proposed,
          exchange_proposed_n=exchange_proposed_n,
          sampled_replica_states=sampled_replica_states,
          sampled_replica_results=sampled_replica_results,
      )

      return next_state, kernel_results
Example #36
def sample_halton_sequence(dim,
                           num_results=None,
                           sequence_indices=None,
                           dtype=tf.float32,
                           randomized=True,
                           seed=None,
                           name=None):
  r"""Returns a sample from the `dim` dimensional Halton sequence.

  Warning: The sequence elements take values only between 0 and 1. Care must be
  taken to appropriately transform the domain of a function if it differs from
  the unit cube before evaluating integrals using Halton samples. It is also
  important to remember that quasi-random numbers without randomization are not
  a replacement for pseudo-random numbers in every context. Quasi random numbers
  are completely deterministic and typically have significant negative
  autocorrelation unless randomization is used.

  Computes the members of the low discrepancy Halton sequence in dimension
  `dim`. The `dim`-dimensional sequence takes values in the unit hypercube in
  `dim` dimensions. Currently, only dimensions up to 1000 are supported. The
  prime base for the k-th axes is the k-th prime starting from 2. For example,
  if `dim` = 3, then the bases will be [2, 3, 5] respectively and the first
  element of the non-randomized sequence will be: [0.5, 0.333, 0.2]. For a more
  complete description of the Halton sequences see
  [here](https://en.wikipedia.org/wiki/Halton_sequence). For low discrepancy
  sequences and their applications see
  [here](https://en.wikipedia.org/wiki/Low-discrepancy_sequence).

  If `randomized` is true, this function produces a scrambled version of the
  Halton sequence introduced by [Owen (2017)][1]. For the advantages of
  randomization of low discrepancy sequences see [here](
  https://en.wikipedia.org/wiki/Quasi-Monte_Carlo_method#Randomization_of_quasi-Monte_Carlo).

  The number of samples produced is controlled by the `num_results` and
  `sequence_indices` parameters. The user must supply either `num_results` or
  `sequence_indices` but not both.
  The former is the number of samples to produce starting from the first
  element. If `sequence_indices` is given instead, the specified elements of
  the sequence are generated. For example, sequence_indices=tf.range(10) is
  equivalent to specifying num_results=10.

  #### Examples

  ```python
  import tensorflow as tf
  import tensorflow_probability as tfp

  # Produce the first 1000 members of the Halton sequence in 3 dimensions.
  num_results = 1000
  dim = 3
  sample = tfp.mcmc.sample_halton_sequence(
    dim,
    num_results=num_results,
    seed=127)

  # Evaluate the integral of x_1 * x_2^2 * x_3^3  over the three dimensional
  # hypercube.
  powers = tf.range(1.0, limit=dim + 1)
  integral = tf.reduce_mean(tf.reduce_prod(sample ** powers, axis=-1))
  true_value = 1.0 / tf.reduce_prod(powers + 1.0)
  with tf.Session() as session:
    values = session.run((integral, true_value))

  # Produces a relative absolute error of 1.7%.
  print ("Estimated: %f, True Value: %f" % values)

  # Now skip the first 1000 samples and recompute the integral with the next
  # thousand samples. The sequence_indices argument can be used to do this.


  sequence_indices = tf.range(start=1000, limit=1000 + num_results,
                              dtype=tf.int32)
  sample_leaped = tfp.mcmc.sample_halton_sequence(
      dim,
      sequence_indices=sequence_indices,
      seed=111217)

  integral_leaped = tf.reduce_mean(tf.reduce_prod(sample_leaped ** powers,
                                                  axis=-1))
  with tf.Session() as session:
    values = session.run((integral_leaped, true_value))
  # Now produces a relative absolute error of 0.05%.
  print ("Leaped Estimated: %f, True Value: %f" % values)
  ```

  Args:
    dim: Positive Python `int` representing each sample's `event_size`. Must
      not be greater than 1000.
    num_results: (Optional) positive Python `int`. The number of samples to
      generate. Either this parameter or sequence_indices must be specified but
      not both. If this parameter is None, then the behaviour is determined by
      the `sequence_indices`.
      Default value: `None`.
    sequence_indices: (Optional) `Tensor` of dtype int32 and rank 1. The
      elements of the sequence to compute specified by their position in the
      sequence. The entries index into the Halton sequence starting with 0 and
      hence, must be whole numbers. For example, sequence_indices=[0, 5, 6] will
      produce the first, sixth and seventh elements of the sequence. If this
      parameter is None, then the `num_results` parameter must be specified
      which gives the number of desired samples starting from the first sample.
      Default value: `None`.
    dtype: (Optional) The dtype of the sample. One of: `float16`, `float32` or
      `float64`.
      Default value: `tf.float32`.
    randomized: (Optional) bool indicating whether to produce a randomized
      Halton sequence. If True, applies the randomization described in
      [Owen (2017)][1].
      Default value: `True`.
    seed: (Optional) Python integer to seed the random number generator. Only
      used if `randomized` is True. If not supplied and `randomized` is True,
      no seed is set.
      Default value: `None`.
    name:  (Optional) Python `str` describing ops managed by this function. If
      not supplied the name of this function is used.
      Default value: "sample_halton_sequence".

  Returns:
    halton_elements: Elements of the Halton sequence. `Tensor` of supplied dtype
      and `shape` `[num_results, dim]` if `num_results` was specified or shape
      `[s, dim]` where s is the size of `sequence_indices` if `sequence_indices`
      were specified.

  Raises:
    ValueError: if both `sequence_indices` and `num_results` were specified or
      if dimension `dim` is less than 1 or greater than 1000.

  #### References

  [1]: Art B. Owen. A randomized Halton algorithm in R. _arXiv preprint
       arXiv:1706.02808_, 2017. https://arxiv.org/abs/1706.02808
  """
  if dim < 1 or dim > _MAX_DIMENSION:
    raise ValueError(
        'Dimension must be between 1 and {}. Supplied {}'.format(_MAX_DIMENSION,
                                                                 dim))
  if (num_results is None) == (sequence_indices is None):
    raise ValueError('Either `num_results` or `sequence_indices` must be'
                     ' specified but not both.')

  if not dtype.is_floating:
    raise ValueError('dtype must be of `float`-type')

  with tf.name_scope(name, 'sample', values=[sequence_indices]):
    # Here and in the following, the shape layout is as follows:
    # [sample dimension, event dimension, coefficient dimension].
    # The coefficient dimension is an intermediate axes which will hold the
    # weights of the starting integer when expressed in the (prime) base for
    # an event dimension.
    indices = _get_indices(num_results, sequence_indices, dtype)
    radixes = tf.constant(_PRIMES[0:dim], dtype=dtype, shape=[dim, 1])

    max_sizes_by_axes = _base_expansion_size(tf.reduce_max(indices),
                                             radixes)

    max_size = tf.reduce_max(max_sizes_by_axes)

    # The powers of the radixes that we will need. Note that there is a bit
    # of an excess here. Suppose we need the place value coefficients of 7
    # in base 2 and 3. For 2, we will have 3 digits but we only need 2 digits
    # for base 3. However, we can only create rectangular tensors so we
    # store both expansions in a [2, 3] tensor. This leads to the problem that
    # we might end up attempting to raise large numbers to large powers. For
    # example, base 2 expansion of 1024 has 10 digits. If we were in 10
    # dimensions, then the 10th prime (29) we will end up computing 29^10 even
    # though we don't need it. We avoid this by setting the exponents for each
    # axes to 0 beyond the maximum value needed for that dimension.
    exponents_by_axes = tf.tile([tf.range(max_size)], [dim, 1])

    # The mask is true for those coefficients that are irrelevant.
    weight_mask = exponents_by_axes >= max_sizes_by_axes
    capped_exponents = tf.where(
        weight_mask,
        tf.zeros_like(exponents_by_axes),
        exponents_by_axes)
    weights = radixes ** capped_exponents
    # The following computes the base b expansion of the indices. Suppose,
    # x = a0 + a1*b + a2*b^2 + ... Then, performing a floor div of x with
    # the vector (1, b, b^2, b^3, ...) will produce
    # (a0 + s1 * b, a1 + s2 * b, ...) where s_i are coefficients we don't care
    # about. Noting that all a_i < b by definition of place value expansion,
    # we see that taking the elements mod b of the above vector produces the
    # place value expansion coefficients.
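    # For example, with x = 6 and b = 2: floor_div(6, (1, 2, 4)) = (6, 3, 1),
    # and (6, 3, 1) mod 2 = (0, 1, 1), the binary digits of 6 (least
    # significant first).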
    coeffs = tf.floor_div(indices, weights)
    coeffs *= 1. - tf.cast(weight_mask, dtype)
    coeffs %= radixes
    if not randomized:
      coeffs /= radixes
      return tf.reduce_sum(coeffs / weights, axis=-1)
    seed = distributions_util.gen_new_seed(
        seed, salt='mcmc_sample_halton_sequence_1')
    coeffs = _randomize(coeffs, radixes, seed=seed)
    # Remove the contribution from randomizing the trailing zero for the
    # axes where max_size_by_axes < max_size. This will be accounted
    # for separately below (using zero_correction).
    coeffs *= 1. - tf.cast(weight_mask, dtype)
    coeffs /= radixes
    base_values = tf.reduce_sum(coeffs / weights, axis=-1)

    # The randomization used in Owen (2017) does not leave 0 invariant. While
    # we have accounted for the randomization of the first `max_size_by_axes`
    # coefficients, we still need to correct for the trailing zeros. Luckily,
    # this is equivalent to adding a uniform random value scaled so the first
    # `max_size_by_axes` coefficients are zero. The following statements perform
    # this correction.
    seed = distributions_util.gen_new_seed(
        seed, salt='mcmc_sample_halton_sequence_2')
    zero_correction = tf.random_uniform([dim, 1], seed=seed, dtype=dtype)
    zero_correction /= radixes ** max_sizes_by_axes
    return base_values + tf.reshape(zero_correction, [-1])
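
A small pure-Python cross-check (not from the source above) of the docstring's
claim that the first non-randomized 3-dimensional element is [0.5, 0.333, 0.2];
`radical_inverse` is a hypothetical helper, not part of the API:

def radical_inverse(index, base):
    # Reverse the base-`base` digits of `index` behind the radix point.
    result, fraction = 0.0, 1.0 / base
    while index > 0:
        result += (index % base) * fraction
        index //= base
        fraction /= base
    return result

print([radical_inverse(1, b) for b in (2, 3, 5)])  # [0.5, 0.333..., 0.2]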
Example #37
  def _sample_n(self, n, seed=None):
    if self._use_static_graph:
      # This sampling approach is almost the same as the approach used by
      # `MixtureSameFamily`. The differences are due to having a list of
      # `Distribution` objects rather than a single object, and maintaining
      # random seed management that is consistent with the non-static code path.
      samples = []
      cat_samples = self.cat.sample(n, seed=seed)
      for c in range(self.num_components):
        seed = distribution_util.gen_new_seed(seed, "mixture")
        samples.append(self.components[c].sample(n, seed=seed))
      x = tf.stack(samples, -self._static_event_shape.ndims - 1)  # [n, B, k, E]
      npdt = x.dtype.as_numpy_dtype
      mask = tf.one_hot(
          indices=cat_samples,  # [n, B]
          depth=self._num_components,  # == k
          on_value=np.ones([], dtype=npdt),
          off_value=np.zeros([], dtype=npdt))  # [n, B, k]
      mask = distribution_util.pad_mixture_dimensions(
          mask, self, self._cat,
          self._static_event_shape.ndims)                   # [n, B, k, [1]*e]
      return tf.reduce_sum(
          x * mask, axis=-1 - self._static_event_shape.ndims)  # [n, B, E]

    with tf.control_dependencies(self._assertions):
      n = tf.convert_to_tensor(n, name="n")
      static_n = tensor_util.constant_value(n)
      n = int(static_n) if static_n is not None else n
      cat_samples = self.cat.sample(n, seed=seed)

      static_samples_shape = cat_samples.get_shape()
      if static_samples_shape.is_fully_defined():
        samples_shape = static_samples_shape.as_list()
        samples_size = static_samples_shape.num_elements()
      else:
        samples_shape = tf.shape(cat_samples)
        samples_size = tf.size(cat_samples)
      static_batch_shape = self.batch_shape
      if static_batch_shape.is_fully_defined():
        batch_shape = static_batch_shape.as_list()
        batch_size = static_batch_shape.num_elements()
      else:
        batch_shape = self.batch_shape_tensor()
        batch_size = tf.reduce_prod(batch_shape)
      static_event_shape = self.event_shape
      if static_event_shape.is_fully_defined():
        event_shape = np.array(static_event_shape.as_list(), dtype=np.int32)
      else:
        event_shape = self.event_shape_tensor()

      # Get indices into the raw cat sampling tensor. We will
      # need these to stitch sample values back out after sampling
      # within the component partitions.
      samples_raw_indices = tf.reshape(tf.range(0, samples_size), samples_shape)

      # Partition the raw indices so that we can use
      # dynamic_stitch later to reconstruct the samples from the
      # known partitions.
      partitioned_samples_indices = tf.dynamic_partition(
          data=samples_raw_indices,
          partitions=cat_samples,
          num_partitions=self.num_components)

      # Copy the batch indices n times, as we will need to know
      # these to pull out the appropriate rows within the
      # component partitions.
      batch_raw_indices = tf.reshape(
          tf.tile(tf.range(0, batch_size), [n]), samples_shape)

      # Explanation of the dynamic partitioning below:
      #   batch indices are i.e., [0, 1, 0, 1, 0, 1]
      # Suppose partitions are:
      #     [1 1 0 0 1 1]
      # After partitioning, batch indices are cut as:
      #     [batch_indices[x] for x in 2, 3]
      #     [batch_indices[x] for x in 0, 1, 4, 5]
      # i.e.
      #     [1 1] and [0 0 0 0]
      # Now we sample n=2 from part 0 and n=4 from part 1.
      # For part 0 we want samples from batch entries 1, 1 (samples 0, 1),
      # and for part 1 we want samples from batch entries 0, 0, 0, 0
      #   (samples 0, 1, 2, 3).
      partitioned_batch_indices = tf.dynamic_partition(
          data=batch_raw_indices,
          partitions=cat_samples,
          num_partitions=self.num_components)
      samples_class = [None for _ in range(self.num_components)]

      for c in range(self.num_components):
        n_class = tf.size(partitioned_samples_indices[c])
        seed = distribution_util.gen_new_seed(seed, "mixture")
        samples_class_c = self.components[c].sample(n_class, seed=seed)

        # Pull out the correct batch entries from each index.
        # To do this, we may have to flatten the batch shape.

        # For sample s, batch element b of component c, we get the
        # partitioned batch indices from
        # partitioned_batch_indices[c]; and shift each element by
        # the sample index. The final lookup can be thought of as
        # a matrix gather along locations (s, b) in
        # samples_class_c where the n_class rows correspond to
        # samples within this component and the batch_size columns
        # correspond to batch elements within the component.
        #
        # Thus the lookup index is
        #   lookup[c, i] = batch_size * s[i] + b[c, i]
        # for i = 0 ... n_class[c] - 1.
        lookup_partitioned_batch_indices = (
            batch_size * tf.range(n_class) + partitioned_batch_indices[c])
        samples_class_c = tf.reshape(
            samples_class_c, tf.concat([[n_class * batch_size], event_shape],
                                       0))
        samples_class_c = tf.gather(
            samples_class_c,
            lookup_partitioned_batch_indices,
            name="samples_class_c_gather")
        samples_class[c] = samples_class_c

      # Stitch back together the samples across the components.
      lhs_flat_ret = tf.dynamic_stitch(
          indices=partitioned_samples_indices, data=samples_class)
      # Reshape back to proper sample, batch, and event shape.
      ret = tf.reshape(
          lhs_flat_ret, tf.concat(
              [samples_shape, self.event_shape_tensor()], 0))
      ret.set_shape(
          tf.TensorShape(static_samples_shape).concatenate(self.event_shape))
      return ret
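
A toy NumPy illustration (not part of the original method) of the
partition-then-stitch round trip used above; the values and partitions are
made up:

import numpy as np

values = np.array([10, 11, 12, 13, 14, 15])
parts = np.array([1, 1, 0, 0, 1, 1])
raw_idx = np.arange(values.size)
# "dynamic_partition": split the flat indices by component.
idx_by_part = [raw_idx[parts == c] for c in range(2)]
# Pretend each component produces its own samples for its slots ...
samples_by_part = [values[idx] + 100 * c for c, idx in enumerate(idx_by_part)]
# ... then "dynamic_stitch" scatters them back to the original positions.
out = np.empty_like(values)
for idx, samples in zip(idx_by_part, samples_by_part):
    out[idx] = samples
print(out)  # [110 111  12  13 114 115]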
Example #38
def sample_chain(num_results,
                 target_log_prob_fn,
                 current_state,
                 step_size,
                 num_leapfrog_steps,
                 num_burnin_steps=0,
                 num_steps_between_results=0,
                 seed=None,
                 current_target_log_prob=None,
                 current_grads_target_log_prob=None,
                 name=None):
    """Runs multiple iterations of one or more Hamiltonian Monte Carlo chains.

  Hamiltonian Monte Carlo (HMC) is a Markov chain Monte Carlo (MCMC) algorithm
  that takes a series of gradient-informed steps to produce a Metropolis
  proposal. This function samples from an HMC Markov chain at `current_state`
  and whose stationary distribution has log-unnormalized-density
  `target_log_prob_fn()`.

  This function samples from multiple chains in parallel. It assumes that the
  leftmost dimensions of (each) `current_state` (part) index an independent
  chain.  The function `target_log_prob_fn()` sums log-probabilities across
  event dimensions (i.e., current state (part) rightmost dimensions). Each
  element of the output of `target_log_prob_fn()` represents the (possibly
  unnormalized) log-probability of the joint distribution over (all) the current
  state (parts).

  The `current_state` can be represented as a single `Tensor` or a `list` of
  `Tensors` which collectively represent the current state. When specifying a
  `list`, one must also specify a list of `step_size`s.

  Only one out of every `num_steps_between_results + 1` steps is included in the
  returned results. This "thinning" comes at a cost of reduced statistical
  power, while reducing memory requirements and autocorrelation. For more
  discussion see [1].

  [1]: "Statistically efficient thinning of a Markov chain sampler."
       Art B. Owen. April 2017.
       http://statweb.stanford.edu/~owen/reports/bestthinning.pdf

  #### Examples:

  ##### Sample from a diagonal-variance Gaussian.

  ```python
  tfd = tf.contrib.distributions

  def make_likelihood(true_variances):
    return tfd.MultivariateNormalDiag(
        scale_diag=tf.sqrt(true_variances))

  dims = 10
  dtype = np.float32
  true_variances = tf.linspace(dtype(1), dtype(3), dims)
  likelihood = make_likelihood(true_variances)

  states, kernel_results = hmc.sample_chain(
      num_results=1000,
      target_log_prob_fn=likelihood.log_prob,
      current_state=tf.zeros(dims),
      step_size=0.5,
      num_leapfrog_steps=2,
      num_burnin_steps=500)

  # Compute sample stats.
  sample_mean = tf.reduce_mean(states, axis=0)
  sample_var = tf.reduce_mean(
      tf.squared_difference(states, sample_mean),
      axis=0)
  ```

  ##### Sampling from factor-analysis posteriors with known factors.

  I.e.,

  ```none
  for i=1..n:
    w[i] ~ Normal(0, eye(d))            # prior
    x[i] ~ Normal(loc=matmul(w[i], F))  # likelihood
  ```

  where `F` denotes factors.

  ```python
  tfd = tf.contrib.distributions

  def make_prior(dims, dtype):
    return tfd.MultivariateNormalDiag(
        loc=tf.zeros(dims, dtype))

  def make_likelihood(weights, factors):
    return tfd.MultivariateNormalDiag(
        loc=tf.tensordot(weights, factors, axes=[[0], [-1]]))

  # Setup data.
  num_weights = 10
  num_factors = 4
  num_chains = 100
  dtype = np.float32

  prior = make_prior(num_weights, dtype)
  weights = prior.sample(num_chains)
  factors = np.random.randn(num_factors, num_weights).astype(dtype)
  x = make_likelihood(weights, factors).sample(num_chains)

  def target_log_prob(w):
    # Target joint is: `f(w) = p(w, x | factors)`.
    return prior.log_prob(w) + make_likelihood(w, factors).log_prob(x)

  # Get `num_results` samples from `num_chains` independent chains.
  chains_states, kernels_results = hmc.sample_chain(
      num_results=1000,
      target_log_prob_fn=target_log_prob,
      current_state=tf.zeros([num_chains, num_weights], dtype),
      step_size=0.1,
      num_leapfrog_steps=2,
      num_burnin_steps=500)

  # Compute sample stats.
  sample_mean = tf.reduce_mean(chains_states, axis=[0, 1])
  sample_var = tf.reduce_mean(
      tf.squared_difference(chains_states, sample_mean),
      axis=[0, 1])
  ```

  Args:
    num_results: Integer number of Markov chain draws.
    target_log_prob_fn: Python callable which takes an argument like
      `current_state` (or `*current_state` if it's a list) and returns its
      (possibly unnormalized) log-density under the target distribution.
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s). The first `r` dimensions index
      independent chains, `r = tf.rank(target_log_prob_fn(*current_state))`.
    step_size: `Tensor` or Python `list` of `Tensor`s representing the step size
      for the leapfrog integrator. Must broadcast with the shape of
      `current_state`. Larger step sizes lead to faster progress, but too-large
      step sizes make rejection exponentially more likely. When possible, it's
      often helpful to match per-variable step sizes to the standard deviations
      of the target distribution in each variable.
    num_leapfrog_steps: Integer number of steps to run the leapfrog integrator
      for. Total progress per HMC step is roughly proportional to `step_size *
      num_leapfrog_steps`.
    num_burnin_steps: Integer number of chain steps to take before starting to
      collect results.
      Default value: 0 (i.e., no burn-in).
    num_steps_between_results: Integer number of chain steps between collecting
      a result. Only one out of every `num_steps_between_results + 1` steps is
      included in the returned results. This "thinning" comes at a cost of
      reduced statistical power, while reducing memory requirements and
      autocorrelation. For more discussion see [1].
      Default value: 0 (i.e., no subsampling).
    seed: Python integer to seed the random number generator.
    current_target_log_prob: (Optional) `Tensor` representing the value of
      `target_log_prob_fn` at the `current_state`. The only reason to specify
      this argument is to reduce TF graph size.
      Default value: `None` (i.e., compute as needed).
    current_grads_target_log_prob: (Optional) Python list of `Tensor`s
      representing gradient of `target_log_prob` at the `current_state` and wrt
      the `current_state`. Must have same shape as `current_state`. The only
      reason to specify this argument is to reduce TF graph size.
      Default value: `None` (i.e., compute as needed).
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., "hmc_sample_chain").

  Returns:
    accepted_states: Tensor or Python list of `Tensor`s representing the
      state(s) of the Markov chain(s) at each result step. Has same shape as
      input `current_state` but with a prepended `num_results`-size dimension.
    kernel_results: `collections.namedtuple` of internal calculations used to
      advance the chain.
  """
    with ops.name_scope(name, "hmc_sample_chain", [
            num_results, current_state, step_size, num_leapfrog_steps,
            num_burnin_steps, num_steps_between_results, seed,
            current_target_log_prob, current_grads_target_log_prob
    ]):
        with ops.name_scope("initialize"):
            [
                current_state,
                step_size,
                current_target_log_prob,
                current_grads_target_log_prob,
            ] = _prepare_args(target_log_prob_fn, current_state, step_size,
                              current_target_log_prob,
                              current_grads_target_log_prob)

        def _run_chain(num_steps, current_state, seed, kernel_results):
            """Runs the chain(s) for `num_steps`."""
            def _loop_body(iter_, current_state, kernel_results):
                return [iter_ + 1] + list(
                    kernel(target_log_prob_fn, current_state, step_size,
                           num_leapfrog_steps, seed,
                           kernel_results.current_target_log_prob,
                           kernel_results.current_grads_target_log_prob))

            return control_flow_ops.while_loop(
                cond=lambda iter_, *args: iter_ < num_steps,
                body=_loop_body,
                loop_vars=[0, current_state,
                           kernel_results])[1:]  # Lop-off "iter_".

        def _scan_body(args_list, _):
            """Closure which implements `tf.scan` body."""
            current_state, kernel_results = args_list
            return _run_chain(num_steps_between_results + 1, current_state,
                              seed, kernel_results)

        current_state, kernel_results = _run_chain(
            num_burnin_steps, current_state,
            distributions_util.gen_new_seed(seed,
                                            salt="hmc_sample_chain_burnin"),
            _make_dummy_kernel_results(current_state, current_target_log_prob,
                                       current_grads_target_log_prob))
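        # The scan below records one state per result; each scan step advances
        # the chain `num_steps_between_results + 1` times, so the intermediate
        # (thinned) states are not stored.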

        return functional_ops.scan(
            fn=_scan_body,
            elems=array_ops.zeros(num_results,
                                  dtype=dtypes.bool),  # Dummy arg.
            initializer=[current_state, kernel_results])
Example #39
    def testDenseFlipout(self):
        batch_size, in_size, out_size = 2, 3, 4
        with self.test_session() as sess:
            (kernel_posterior, kernel_prior, kernel_divergence, bias_posterior,
             bias_prior, bias_divergence, layer, inputs, outputs,
             kl_penalty) = self._testDenseSetUp(tfp.layers.DenseFlipout,
                                                batch_size,
                                                in_size,
                                                out_size,
                                                seed=44)

            expected_kernel_posterior_affine = tfd.Normal(
                loc=tf.zeros_like(kernel_posterior.result_loc),
                scale=kernel_posterior.result_scale)
            expected_kernel_posterior_affine_tensor = (
                expected_kernel_posterior_affine.sample(seed=42))

            sign_input = tf.random_uniform([batch_size, in_size],
                                           minval=0,
                                           maxval=2,
                                           dtype=tf.int32,
                                           seed=layer.seed)
            sign_input = tf.cast(2 * sign_input - 1, inputs.dtype)
            sign_output = tf.random_uniform(
                [batch_size, out_size],
                minval=0,
                maxval=2,
                dtype=tf.int32,
                seed=distribution_util.gen_new_seed(layer.seed,
                                                    salt='dense_flipout'))
            sign_output = tf.cast(2 * sign_output - 1, inputs.dtype)
            perturbed_inputs = tf.matmul(
                inputs * sign_input, expected_kernel_posterior_affine_tensor)
            perturbed_inputs *= sign_output
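            # Flipout: a single sampled weight perturbation is shared across
            # the batch, but multiplying by independent random sign vectors on
            # the input and output sides decorrelates it across examples.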

            expected_outputs = tf.matmul(inputs, kernel_posterior.result_loc)
            expected_outputs += perturbed_inputs
            expected_outputs += bias_posterior.result_sample

            [
                expected_outputs_,
                actual_outputs_,
                expected_kernel_divergence_,
                actual_kernel_divergence_,
                expected_bias_,
                actual_bias_,
                expected_bias_divergence_,
                actual_bias_divergence_,
            ] = sess.run([
                expected_outputs,
                outputs,
                kernel_divergence.result,
                kl_penalty[0],
                bias_posterior.result_sample,
                layer.bias_posterior_tensor,
                bias_divergence.result,
                kl_penalty[1],
            ])

            self.assertAllClose(expected_bias_,
                                actual_bias_,
                                rtol=1e-6,
                                atol=0.)
            self.assertAllClose(expected_outputs_,
                                actual_outputs_,
                                rtol=1e-6,
                                atol=0.)
            self.assertAllClose(expected_kernel_divergence_,
                                actual_kernel_divergence_,
                                rtol=1e-6,
                                atol=0.)
            self.assertAllClose(expected_bias_divergence_,
                                actual_bias_divergence_,
                                rtol=1e-6,
                                atol=0.)

            self.assertAllEqual([[
                kernel_posterior.distribution, kernel_prior.distribution, None
            ]], kernel_divergence.args)

            self.assertAllEqual([[
                bias_posterior.distribution, bias_prior.distribution,
                bias_posterior.result_sample
            ]], bias_divergence.args)
Example #40
  def _testConvFlipout(self, layer_class):
    batch_size, depth, height, width, channels, filters = 2, 4, 4, 4, 3, 5
    with self.test_session() as sess:
      (kernel_posterior, kernel_prior, kernel_divergence,
       bias_posterior, bias_prior, bias_divergence, layer, inputs,
       outputs, kl_penalty, kernel_shape) = self._testConvSetUp(
           layer_class, batch_size,
           depth=depth, height=height, width=width, channels=channels,
           filters=filters, seed=44)

      convolution_op = nn_ops.Convolution(
          tensor_shape.TensorShape(inputs.shape),
          filter_shape=tensor_shape.TensorShape(kernel_shape),
          padding="SAME")

      expected_kernel_posterior_affine = normal_lib.Normal(
          loc=array_ops.zeros_like(kernel_posterior.result_loc),
          scale=kernel_posterior.result_scale)
      expected_kernel_posterior_affine_tensor = (
          expected_kernel_posterior_affine.sample(seed=42))

      expected_outputs = convolution_op(
          inputs, kernel_posterior.distribution.loc)

      input_shape = array_ops.shape(inputs)
      output_shape = array_ops.shape(expected_outputs)
      batch_shape = array_ops.expand_dims(input_shape[0], 0)
      channels = input_shape[-1]
      rank = len(inputs.get_shape()) - 2

      sign_input = random_ops.random_uniform(
          array_ops.concat([batch_shape,
                            array_ops.expand_dims(channels, 0)], 0),
          minval=0,
          maxval=2,
          dtype=dtypes.int32,
          seed=layer.seed)
      sign_input = math_ops.cast(2 * sign_input - 1, inputs.dtype)
      sign_output = random_ops.random_uniform(
          array_ops.concat([batch_shape,
                            array_ops.expand_dims(filters, 0)], 0),
          minval=0,
          maxval=2,
          dtype=dtypes.int32,
          seed=distribution_util.gen_new_seed(
              layer.seed, salt="conv_flipout"))
      sign_output = math_ops.cast(2 * sign_output - 1, inputs.dtype)
      for _ in range(rank):
        sign_input = array_ops.expand_dims(sign_input, 1)  # 2D ex: (B, 1, 1, C)
        sign_output = array_ops.expand_dims(sign_output, 1)

      sign_input = array_ops.tile(  # tile for element-wise op broadcasting
          sign_input,
          [1] + [input_shape[i + 1] for i in range(rank)] + [1])
      sign_output = array_ops.tile(
          sign_output,
          [1] + [output_shape[i + 1] for i in range(rank)] + [1])

      perturbed_inputs = convolution_op(
          inputs * sign_input, expected_kernel_posterior_affine_tensor)
      perturbed_inputs *= sign_output

      expected_outputs += perturbed_inputs
      expected_outputs = nn.bias_add(expected_outputs,
                                     bias_posterior.result_sample,
                                     data_format="NHWC")

      [
          expected_outputs_, actual_outputs_,
          expected_kernel_divergence_, actual_kernel_divergence_,
          expected_bias_, actual_bias_,
          expected_bias_divergence_, actual_bias_divergence_,
      ] = sess.run([
          expected_outputs, outputs,
          kernel_divergence.result, kl_penalty[0],
          bias_posterior.result_sample, layer.bias_posterior_tensor,
          bias_divergence.result, kl_penalty[1],
      ])

      self.assertAllClose(
          expected_bias_, actual_bias_,
          rtol=1e-6, atol=0.)
      self.assertAllClose(
          expected_outputs_, actual_outputs_,
          rtol=1e-6, atol=0.)
      self.assertAllClose(
          expected_kernel_divergence_, actual_kernel_divergence_,
          rtol=1e-6, atol=0.)
      self.assertAllClose(
          expected_bias_divergence_, actual_bias_divergence_,
          rtol=1e-6, atol=0.)

      self.assertAllEqual(
          [[kernel_posterior.distribution, kernel_prior.distribution, None]],
          kernel_divergence.args)

      self.assertAllEqual(
          [[bias_posterior.distribution,
            bias_prior.distribution,
            bias_posterior.result_sample]],
          bias_divergence.args)
Example #41
    def _sample_n(self, n, seed=None):
        with ops.control_dependencies(self._assertions):
            n = ops.convert_to_tensor(n, name="n")
            static_n = tensor_util.constant_value(n)
            n = int(static_n) if static_n is not None else n
            pi_samples = self.pi.sample(n, seed=seed)

            static_samples_shape = pi_samples.get_shape()
            if static_samples_shape.is_fully_defined():
                samples_shape = static_samples_shape.as_list()
                samples_size = static_samples_shape.num_elements()
            else:
                samples_shape = array_ops.shape(pi_samples)
                samples_size = array_ops.size(pi_samples)
            static_batch_shape = self.batch_shape
            if static_batch_shape.is_fully_defined():
                batch_shape = static_batch_shape.as_list()
                batch_size = static_batch_shape.num_elements()
            else:
                batch_shape = self.batch_shape_tensor()
                batch_size = math_ops.reduce_prod(batch_shape)
            static_event_shape = self.event_shape
            if static_event_shape.is_fully_defined():
                event_shape = np.array(static_event_shape.as_list(),
                                       dtype=np.int32)
            else:
                event_shape = self.event_shape_tensor()

            # Get indices into the raw pi sampling tensor. We will
            # need these to stitch sample values back out after sampling
            # within the component partitions.
            samples_raw_indices = array_ops.reshape(
                math_ops.range(0, samples_size), samples_shape)

            # Partition the raw indices so that we can use
            # dynamic_stitch later to reconstruct the samples from the
            # known partitions.
            partitioned_samples_indices = data_flow_ops.dynamic_partition(
                data=samples_raw_indices,
                partitions=pi_samples,
                num_partitions=self.num_dist)

            # Copy the batch indices n times, as we will need to know
            # these to pull out the appropriate rows within the
            # component partitions.
            batch_raw_indices = array_ops.reshape(
                array_ops.tile(math_ops.range(0, batch_size), [n]),
                samples_shape)

            # Explanation of the dynamic partitioning below:
            #   batch indices are i.e., [0, 1, 0, 1, 0, 1]
            # Suppose partitions are:
            #     [1 1 0 0 1 1]
            # After partitioning, batch indices are cut as:
            #     [batch_indices[x] for x in 2, 3]
            #     [batch_indices[x] for x in 0, 1, 4, 5]
            # i.e.
            #     [1 1] and [0 0 0 0]
            # Now we sample n=2 from part 0 and n=4 from part 1.
            # For part 0 we want samples from batch entries 1, 1 (samples 0, 1),
            # and for part 1 we want samples from batch entries 0, 0, 0, 0
            #   (samples 0, 1, 2, 3).
            partitioned_batch_indices = data_flow_ops.dynamic_partition(
                data=batch_raw_indices,
                partitions=pi_samples,
                num_partitions=self.num_dist)
            samples_class = [None for _ in range(self.num_dist)]

            for c in range(self.num_dist):
                n_class = array_ops.size(partitioned_samples_indices[c])
                seed = distribution_util.gen_new_seed(seed, "ZeroInflated")
                samples_class_c = self.dist[c].sample(n_class, seed=seed)

                # Pull out the correct batch entries from each index.
                # To do this, we may have to flatten the batch shape.

                # For sample s, batch element b of component c, we get the
                # partitioned batch indices from
                # partitioned_batch_indices[c]; and shift each element by
                # the sample index. The final lookup can be thought of as
                # a matrix gather along locations (s, b) in
                # samples_class_c where the n_class rows correspond to
                # samples within this component and the batch_size columns
                # correspond to batch elements within the component.
                #
                # Thus the lookup index is
                #   lookup[c, i] = batch_size * s[i] + b[c, i]
                # for i = 0 ... n_class[c] - 1.
                lookup_partitioned_batch_indices = (
                    batch_size * math_ops.range(n_class) +
                    partitioned_batch_indices[c])
                samples_class_c = array_ops.reshape(
                    samples_class_c,
                    array_ops.concat([[n_class * batch_size], event_shape], 0))
                samples_class_c = array_ops.gather(
                    samples_class_c,
                    lookup_partitioned_batch_indices,
                    name="samples_class_c_gather")
                samples_class[c] = samples_class_c

            # Stitch back together the samples across the dist.
            lhs_flat_ret = data_flow_ops.dynamic_stitch(
                indices=partitioned_samples_indices, data=samples_class)
            # Reshape back to proper sample, batch, and event shape.
            ret = array_ops.reshape(
                lhs_flat_ret,
                array_ops.concat(
                    [samples_shape, self.event_shape_tensor()], 0))
            ret.set_shape(
                tensor_shape.TensorShape(static_samples_shape).concatenate(
                    self.event_shape))
            return ret
Example #42
  def testDenseFlipout(self):
    batch_size, in_size, out_size = 2, 3, 4
    with self.test_session() as sess:
      (kernel_posterior, kernel_prior, kernel_divergence,
       bias_posterior, bias_prior, bias_divergence, layer, inputs,
       outputs, kl_penalty) = self._testDenseSetUp(
           prob_layers_lib.DenseFlipout,
           batch_size, in_size, out_size, seed=44)

      expected_kernel_posterior_affine = normal_lib.Normal(
          loc=array_ops.zeros_like(kernel_posterior.result_loc),
          scale=kernel_posterior.result_scale)
      expected_kernel_posterior_affine_tensor = (
          expected_kernel_posterior_affine.sample(seed=42))

      sign_input = random_ops.random_uniform(
          [batch_size, in_size],
          minval=0,
          maxval=2,
          dtype=dtypes.int32,
          seed=layer.seed)
      sign_input = math_ops.cast(2 * sign_input - 1, inputs.dtype)
      sign_output = random_ops.random_uniform(
          [batch_size, out_size],
          minval=0,
          maxval=2,
          dtype=dtypes.int32,
          seed=distribution_util.gen_new_seed(
              layer.seed, salt="dense_flipout"))
      sign_output = math_ops.cast(2 * sign_output - 1, inputs.dtype)
      perturbed_inputs = math_ops.matmul(
          inputs * sign_input, expected_kernel_posterior_affine_tensor)
      perturbed_inputs *= sign_output

      expected_outputs = math_ops.matmul(inputs, kernel_posterior.result_loc)
      expected_outputs += perturbed_inputs
      expected_outputs += bias_posterior.result_sample

      [
          expected_outputs_, actual_outputs_,
          expected_kernel_divergence_, actual_kernel_divergence_,
          expected_bias_, actual_bias_,
          expected_bias_divergence_, actual_bias_divergence_,
      ] = sess.run([
          expected_outputs, outputs,
          kernel_divergence.result, kl_penalty[0],
          bias_posterior.result_sample, layer.bias_posterior_tensor,
          bias_divergence.result, kl_penalty[1],
      ])

      self.assertAllClose(
          expected_bias_, actual_bias_,
          rtol=1e-6, atol=0.)
      self.assertAllClose(
          expected_outputs_, actual_outputs_,
          rtol=1e-6, atol=0.)
      self.assertAllClose(
          expected_kernel_divergence_, actual_kernel_divergence_,
          rtol=1e-6, atol=0.)
      self.assertAllClose(
          expected_bias_divergence_, actual_bias_divergence_,
          rtol=1e-6, atol=0.)

      self.assertAllEqual(
          [[kernel_posterior.distribution, kernel_prior.distribution, None]],
          kernel_divergence.args)

      self.assertAllEqual(
          [[bias_posterior.distribution,
            bias_prior.distribution,
            bias_posterior.result_sample]],
          bias_divergence.args)
Example #43
def init_momentum(s):
  return random_ops.random_normal(
      shape=array_ops.shape(s),
      dtype=s.dtype.base_dtype,
      seed=distributions_util.gen_new_seed(
          seed, salt="hmc_kernel_momentums"))
Example #44
    def one_step(self, current_state, previous_kernel_results):
        """Takes one step of the TransitionKernel.

    Args:
      current_state: `Tensor` or Python `list` of `Tensor`s representing the
        current state(s) of the Markov chain(s).
      previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
        `list` of `Tensor`s representing internal calculations made within the
        previous call to this function (or as returned by `bootstrap_results`).

    Returns:
      next_state: `Tensor` or Python `list` of `Tensor`s representing the
        next state(s) of the Markov chain(s).
      kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
        `Tensor`s representing internal calculations made within this function.

    Raises:
      ValueError: if `inner_kernel` results doesn't contain the member
        "target_log_prob".
    """
        # Take one inner step.
        [
            proposed_state,
            proposed_results,
        ] = self.inner_kernel.one_step(
            current_state, previous_kernel_results.accepted_results)

        if (not has_target_log_prob(proposed_results)
                or not has_target_log_prob(
                    previous_kernel_results.accepted_results)):
            raise ValueError('"target_log_prob" must be a member of '
                             '`inner_kernel` results.')

        # Compute log(acceptance_ratio).
        to_sum = [
            proposed_results.target_log_prob,
            -previous_kernel_results.accepted_results.target_log_prob
        ]
        try:
            to_sum.append(proposed_results.log_acceptance_correction)
        except AttributeError:
            warnings.warn(
                'Supplied inner `TransitionKernel` does not have a '
                '`log_acceptance_correction`. Assuming its value is `0.`')
        log_accept_ratio = mcmc_util.safe_sum(to_sum,
                                              name='compute_log_accept_ratio')

        # If proposed state reduces likelihood: randomly accept.
        # If proposed state increases likelihood: always accept.
        # I.e., u < min(1, accept_ratio),  where u ~ Uniform[0,1)
        #       ==> log(u) < log_accept_ratio
        # Note:
        # - We mutate seed state so subsequent calls are not correlated.
        # - We mutate seed BEFORE using it just in case users supplied the
        #   same seed to the inner kernel.
        self._seed = distributions_util.gen_new_seed(
            self.seed, salt='metropolis_hastings_one_step')
        log_uniform = tf.log(
            tf.random_uniform(
                shape=tf.shape(proposed_results.target_log_prob),
                dtype=proposed_results.target_log_prob.dtype.base_dtype,
                seed=self.seed))
        is_accepted = log_uniform < log_accept_ratio

        independent_chain_ndims = distributions_util.prefer_static_rank(
            proposed_results.target_log_prob)

        next_state = mcmc_util.choose(is_accepted, proposed_state,
                                      current_state, independent_chain_ndims)

        accepted_results = type(proposed_results)(
            **dict([(fn,
                     mcmc_util.choose(
                         is_accepted, getattr(proposed_results, fn),
                         getattr(previous_kernel_results.accepted_results, fn),
                         independent_chain_ndims))
                    for fn in proposed_results._fields]))

        return [
            next_state,
            MetropolisHastingsKernelResults(
                accepted_results=accepted_results,
                is_accepted=is_accepted,
                log_accept_ratio=log_accept_ratio,
                proposed_state=proposed_state,
                proposed_results=proposed_results,
            )
        ]
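
A minimal NumPy rendering (not from the source) of the acceptance rule spelled
out in the comments above; `log_accept_ratio` is assumed to be a scalar or an
array of log acceptance ratios:

import numpy as np

def mh_accept(log_accept_ratio, rng=np.random):
    # Equivalent to u < min(1, accept_ratio) with u ~ Uniform[0, 1).
    u = rng.uniform(size=np.shape(log_accept_ratio))
    return np.log(u) < log_accept_ratio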
Example #45
    def _testConvFlipout(self, layer_class):
        batch_size, depth, height, width, channels, filters = 2, 4, 4, 4, 3, 5
        with self.test_session() as sess:
            (kernel_posterior, kernel_prior, kernel_divergence, bias_posterior,
             bias_prior, bias_divergence, layer, inputs, outputs, kl_penalty,
             kernel_shape) = self._testConvSetUp(layer_class,
                                                 batch_size,
                                                 depth=depth,
                                                 height=height,
                                                 width=width,
                                                 channels=channels,
                                                 filters=filters,
                                                 seed=44)

            convolution_op = nn_ops.Convolution(
                tf.TensorShape(inputs.shape),
                filter_shape=tf.TensorShape(kernel_shape),
                padding='SAME')

            expected_kernel_posterior_affine = tfd.Normal(
                loc=tf.zeros_like(kernel_posterior.result_loc),
                scale=kernel_posterior.result_scale)
            expected_kernel_posterior_affine_tensor = (
                expected_kernel_posterior_affine.sample(seed=42))

            expected_outputs = convolution_op(
                inputs, kernel_posterior.distribution.loc)

            input_shape = tf.shape(inputs)
            output_shape = tf.shape(expected_outputs)
            batch_shape = tf.expand_dims(input_shape[0], 0)
            channels = input_shape[-1]
            rank = len(inputs.get_shape()) - 2

            sign_input = tf.random_uniform(tf.concat(
                [batch_shape, tf.expand_dims(channels, 0)], 0),
                                           minval=0,
                                           maxval=2,
                                           dtype=tf.int32,
                                           seed=layer.seed)
            sign_input = tf.cast(2 * sign_input - 1, inputs.dtype)
            sign_output = tf.random_uniform(
                tf.concat(
                    [batch_shape, tf.expand_dims(filters, 0)], 0),
                minval=0,
                maxval=2,
                dtype=tf.int32,
                seed=distribution_util.gen_new_seed(layer.seed,
                                                    salt='conv_flipout'))
            sign_output = tf.cast(2 * sign_output - 1, inputs.dtype)
            for _ in range(rank):
                sign_input = tf.expand_dims(sign_input,
                                            1)  # 2D ex: (B, 1, 1, C)
                sign_output = tf.expand_dims(sign_output, 1)

            sign_input = tf.tile(  # tile for element-wise op broadcasting
                sign_input,
                [1] + [input_shape[i + 1] for i in range(rank)] + [1])
            sign_output = tf.tile(
                sign_output,
                [1] + [output_shape[i + 1] for i in range(rank)] + [1])

            perturbed_inputs = convolution_op(
                inputs * sign_input, expected_kernel_posterior_affine_tensor)
            perturbed_inputs *= sign_output

            expected_outputs += perturbed_inputs
            expected_outputs = tf.nn.bias_add(expected_outputs,
                                              bias_posterior.result_sample,
                                              data_format='NHWC')

            [
                expected_outputs_,
                actual_outputs_,
                expected_kernel_divergence_,
                actual_kernel_divergence_,
                expected_bias_,
                actual_bias_,
                expected_bias_divergence_,
                actual_bias_divergence_,
            ] = sess.run([
                expected_outputs,
                outputs,
                kernel_divergence.result,
                kl_penalty[0],
                bias_posterior.result_sample,
                layer.bias_posterior_tensor,
                bias_divergence.result,
                kl_penalty[1],
            ])

            self.assertAllClose(expected_bias_,
                                actual_bias_,
                                rtol=1e-6,
                                atol=0.)
            self.assertAllClose(expected_outputs_,
                                actual_outputs_,
                                rtol=1e-6,
                                atol=0.)
            self.assertAllClose(expected_kernel_divergence_,
                                actual_kernel_divergence_,
                                rtol=1e-6,
                                atol=0.)
            self.assertAllClose(expected_bias_divergence_,
                                actual_bias_divergence_,
                                rtol=1e-6,
                                atol=0.)

            self.assertAllEqual([[
                kernel_posterior.distribution, kernel_prior.distribution, None
            ]], kernel_divergence.args)

            self.assertAllEqual([[
                bias_posterior.distribution, bias_prior.distribution,
                bias_posterior.result_sample
            ]], bias_divergence.args)
Пример #46
0
def sample_halton_sequence(dim,
                           num_results=None,
                           sequence_indices=None,
                           dtype=tf.float32,
                           randomized=True,
                           seed=None,
                           name=None):
  r"""Returns a sample from the `dim` dimensional Halton sequence.

  Warning: The sequence elements take values only between 0 and 1. Care must be
  taken to appropriately transform the domain of a function if it differs from
  the unit cube before evaluating integrals using Halton samples. It is also
  important to remember that quasi-random numbers without randomization are not
  a replacement for pseudo-random numbers in every context. Quasi random numbers
  are completely deterministic and typically have significant negative
  autocorrelation unless randomization is used.

  Computes the members of the low discrepancy Halton sequence in dimension
  `dim`. The `dim`-dimensional sequence takes values in the unit hypercube in
  `dim` dimensions. Currently, only dimensions up to 1000 are supported. The
  prime base for the k-th axes is the k-th prime starting from 2. For example,
  if `dim` = 3, then the bases will be [2, 3, 5] respectively and the first
  element of the non-randomized sequence will be: [0.5, 0.333, 0.2]. For a more
  complete description of the Halton sequences see
  [here](https://en.wikipedia.org/wiki/Halton_sequence). For low discrepancy
  sequences and their applications see
  [here](https://en.wikipedia.org/wiki/Low-discrepancy_sequence).

  If `randomized` is true, this function produces a scrambled version of the
  Halton sequence introduced by [Owen (2017)][1]. For the advantages of
  randomization of low discrepancy sequences see [here](
  https://en.wikipedia.org/wiki/Quasi-Monte_Carlo_method#Randomization_of_quasi-Monte_Carlo).

  The number of samples produced is controlled by the `num_results` and
  `sequence_indices` parameters. The user must supply either `num_results` or
  `sequence_indices` but not both.
  The former is the number of samples to produce starting from the first
  element. If `sequence_indices` is given instead, the specified elements of
  the sequence are generated. For example, sequence_indices=tf.range(10) is
  equivalent to specifying n=10.

  #### Examples

  ```python
  import tensorflow as tf
  import tensorflow_probability as tfp

  # Produce the first 1000 members of the Halton sequence in 3 dimensions.
  num_results = 1000
  dim = 3
  sample = tfp.mcmc.sample_halton_sequence(
    dim,
    num_results=num_results,
    seed=127)

  # Evaluate the integral of x_1 * x_2^2 * x_3^3  over the three dimensional
  # hypercube.
  powers = tf.range(1.0, limit=dim + 1)
  integral = tf.reduce_mean(tf.reduce_prod(sample ** powers, axis=-1))
  true_value = 1.0 / tf.reduce_prod(powers + 1.0)
  with tf.Session() as session:
    values = session.run((integral, true_value))

  # Produces a relative absolute error of 1.7%.
  print ("Estimated: %f, True Value: %f" % values)

  # Now skip the first 1000 samples and recompute the integral with the next
  # thousand samples. The sequence_indices argument can be used to do this.


  sequence_indices = tf.range(start=1000, limit=1000 + num_results,
                              dtype=tf.int32)
  sample_leaped = tfp.mcmc.sample_halton_sequence(
      dim,
      sequence_indices=sequence_indices,
      seed=111217)

  integral_leaped = tf.reduce_mean(tf.reduce_prod(sample_leaped ** powers,
                                                  axis=-1))
  with tf.Session() as session:
    values = session.run((integral_leaped, true_value))
  # Now produces a relative absolute error of 0.05%.
  print ("Leaped Estimated: %f, True Value: %f" % values)
  ```

  Args:
    dim: Positive Python `int` representing each sample's `event_size.` Must
      not be greater than 1000.
    num_results: (Optional) positive Python `int`. The number of samples to
      generate. Either this parameter or sequence_indices must be specified but
      not both. If this parameter is None, then the behaviour is determined by
      the `sequence_indices`.
      Default value: `None`.
    sequence_indices: (Optional) `Tensor` of dtype int32 and rank 1. The
      elements of the sequence to compute specified by their position in the
      sequence. The entries index into the Halton sequence starting with 0 and
      hence, must be whole numbers. For example, sequence_indices=[0, 5, 6] will
      produce the first, sixth and seventh elements of the sequence. If this
      parameter is None, then the `num_results` parameter must be specified
      which gives the number of desired samples starting from the first sample.
      Default value: `None`.
    dtype: (Optional) The dtype of the sample. One of: `float16`, `float32` or
      `float64`.
      Default value: `tf.float32`.
    randomized: (Optional) bool indicating whether to produce a randomized
      Halton sequence. If True, applies the randomization described in
      [Owen (2017)][1].
      Default value: `True`.
    seed: (Optional) Python integer to seed the random number generator. Only
      used if `randomized` is True. If not supplied and `randomized` is True,
      no seed is set.
      Default value: `None`.
    name:  (Optional) Python `str` describing ops managed by this function. If
      not supplied the name of this function is used.
      Default value: "sample_halton_sequence".

  Returns:
    halton_elements: Elements of the Halton sequence. `Tensor` of supplied dtype
      and `shape` `[num_results, dim]` if `num_results` was specified or shape
      `[s, dim]` where s is the size of `sequence_indices` if `sequence_indices`
      were specified.

  Raises:
    ValueError: if both `sequence_indices` and `num_results` were specified or
      if dimension `dim` is less than 1 or greater than 1000.

  #### References

  [1]: Art B. Owen. A randomized Halton algorithm in R. _arXiv preprint
       arXiv:1706.02808_, 2017. https://arxiv.org/abs/1706.02808
  """
  if dim < 1 or dim > _MAX_DIMENSION:
    raise ValueError(
        'Dimension must be between 1 and {}. Supplied {}'.format(_MAX_DIMENSION,
                                                                 dim))
  if (num_results is None) == (sequence_indices is None):
    raise ValueError('Either `num_results` or `sequence_indices` must be'
                     ' specified but not both.')

  if not dtype.is_floating:
    raise ValueError('dtype must be of `float`-type')

  with tf.name_scope(name, 'sample', values=[sequence_indices]):
    # Here and in the following, the shape layout is as follows:
    # [sample dimension, event dimension, coefficient dimension].
    # The coefficient dimension is an intermediate axis which will hold the
    # weights of the starting integer when expressed in the (prime) base for
    # an event dimension.
    indices = _get_indices(num_results, sequence_indices, dtype)
    radixes = tf.constant(_PRIMES[0:dim], dtype=dtype, shape=[dim, 1])

    max_sizes_by_axes = _base_expansion_size(tf.reduce_max(indices),
                                             radixes)

    max_size = tf.reduce_max(max_sizes_by_axes)

    # The powers of the radixes that we will need. Note that there is a bit
    # of an excess here. Suppose we need the place value coefficients of 7
    # in base 2 and 3. For 2, we will have 3 digits but we only need 2 digits
    # for base 3. However, we can only create rectangular tensors so we
    # store both expansions in a [2, 3] tensor. This leads to the problem that
    # we might end up attempting to raise large numbers to large powers. For
    # example, the base 2 expansion of 1024 has 11 digits. If we were in 10
    # dimensions, then for the 10th prime (29) we would end up computing 29^10
    # even though we don't need it. We avoid this by setting the exponents for
    # each axis to 0 beyond the maximum value needed for that dimension.
    exponents_by_axes = tf.tile([tf.range(max_size)], [dim, 1])

    # The mask is true for those coefficients that are irrelevant.
    weight_mask = exponents_by_axes >= max_sizes_by_axes
    capped_exponents = tf.where(
        weight_mask,
        tf.zeros_like(exponents_by_axes),
        exponents_by_axes)
    weights = radixes ** capped_exponents
    # The following computes the base b expansion of the indices. Suppose,
    # x = a0 + a1*b + a2*b^2 + ... Then, performing a floor div of x with
    # the vector (1, b, b^2, b^3, ...) will produce
    # (a0 + s1 * b, a1 + s2 * b, ...) where s_i are coefficients we don't care
    # about. Noting that all a_i < b by definition of place value expansion,
    # we see that taking the elements mod b of the above vector produces the
    # place value expansion coefficients.
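    # For example (hypothetical numbers): for x = 7 in base b = 2 the weights
    # are (1, 2, 4); tf.floor_div(7, (1, 2, 4)) = (7, 3, 1), and taking that
    # result mod 2 gives (1, 1, 1), the binary digits of 7 (7 = 1 + 1*2 + 1*4).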
    coeffs = tf.floor_div(indices, weights)
    coeffs *= 1. - tf.cast(weight_mask, dtype)
    coeffs %= radixes
    if not randomized:
      coeffs /= radixes
      return tf.reduce_sum(coeffs / weights, axis=-1)
    seed = distributions_util.gen_new_seed(
        seed, salt='mcmc_sample_halton_sequence_1')
    coeffs = _randomize(coeffs, radixes, seed=seed)
    # Remove the contribution from randomizing the trailing zero for the
    # axes where max_sizes_by_axes < max_size. This will be accounted
    # for separately below (using zero_correction).
    coeffs *= 1. - tf.cast(weight_mask, dtype)
    coeffs /= radixes
    base_values = tf.reduce_sum(coeffs / weights, axis=-1)

    # The randomization used in Owen (2017) does not leave 0 invariant. While
    # we have accounted for the randomization of the first `max_sizes_by_axes`
    # coefficients, we still need to correct for the trailing zeros. Luckily,
    # this is equivalent to adding a uniform random value scaled so the first
    # `max_sizes_by_axes` coefficients are zero. The following statements perform
    # this correction.
    seed = distributions_util.gen_new_seed(
        seed, salt='mcmc_sample_halton_sequence_2')
    zero_correction = tf.random_uniform([dim, 1], seed=seed, dtype=dtype)
    zero_correction /= radixes ** max_sizes_by_axes
    return base_values + tf.reshape(zero_correction, [-1])
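For reference, the place-value logic above mirrors the classic radical-inverse construction. A minimal NumPy sketch of that idea (the helper name and the 1-based index range are illustrative assumptions, not part of the library):

import numpy as np

def radical_inverse(index, base):
    # Reflect the base-`base` digits of `index` about the radix point:
    # a0 + a1*b + a2*b^2 + ...  ->  a0/b + a1/b^2 + a2/b^3 + ...
    result, factor = 0.0, 1.0 / base
    while index > 0:
        result += (index % base) * factor
        index //= base
        factor /= base
    return result

# First few points of a 2-D Halton-style sequence over the primes 2 and 3.
points = np.array([[radical_inverse(i, b) for b in (2, 3)] for i in range(1, 6)])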
Example #47
0
 def init_momentum(s):
     return random_ops.random_normal(
         shape=array_ops.shape(s),
         dtype=s.dtype.base_dtype,
         seed=distributions_util.gen_new_seed(
             seed, salt="hmc_kernel_momentums"))
Example #48
0
  def _sample_n(self, n, seed=None):
    with ops.control_dependencies(self._assertions):
      n = ops.convert_to_tensor(n, name="n")
      static_n = tensor_util.constant_value(n)
      n = int(static_n) if static_n is not None else n
      cat_samples = self.cat.sample(n, seed=seed)

      static_samples_shape = cat_samples.get_shape()
      if static_samples_shape.is_fully_defined():
        samples_shape = static_samples_shape.as_list()
        samples_size = static_samples_shape.num_elements()
      else:
        samples_shape = array_ops.shape(cat_samples)
        samples_size = array_ops.size(cat_samples)
      static_batch_shape = self.batch_shape
      if static_batch_shape.is_fully_defined():
        batch_shape = static_batch_shape.as_list()
        batch_size = static_batch_shape.num_elements()
      else:
        batch_shape = self.batch_shape_tensor()
        batch_size = math_ops.reduce_prod(batch_shape)
      static_event_shape = self.event_shape
      if static_event_shape.is_fully_defined():
        event_shape = np.array(static_event_shape.as_list(), dtype=np.int32)
      else:
        event_shape = self.event_shape_tensor()

      # Get indices into the raw cat sampling tensor. We will
      # need these to stitch sample values back out after sampling
      # within the component partitions.
      samples_raw_indices = array_ops.reshape(
          math_ops.range(0, samples_size), samples_shape)

      # Partition the raw indices so that we can use
      # dynamic_stitch later to reconstruct the samples from the
      # known partitions.
      partitioned_samples_indices = data_flow_ops.dynamic_partition(
          data=samples_raw_indices,
          partitions=cat_samples,
          num_partitions=self.num_components)

      # Copy the batch indices n times, as we will need to know
      # these to pull out the appropriate rows within the
      # component partitions.
      batch_raw_indices = array_ops.reshape(
          array_ops.tile(math_ops.range(0, batch_size), [n]), samples_shape)

      # Explanation of the dynamic partitioning below:
      #   batch indices are i.e., [0, 1, 0, 1, 0, 1]
      # Suppose partitions are:
      #     [1 1 0 0 1 1]
      # After partitioning, batch indices are cut as:
      #     [batch_indices[x] for x in 2, 3]
      #     [batch_indices[x] for x in 0, 1, 4, 5]
      # i.e.
      #     [1 1] and [0 0 0 0]
      # Now we sample n=2 from part 0 and n=4 from part 1.
      # For part 0 we want samples from batch entries 1, 1 (samples 0, 1),
      # and for part 1 we want samples from batch entries 0, 0, 0, 0
      #   (samples 0, 1, 2, 3).
      partitioned_batch_indices = data_flow_ops.dynamic_partition(
          data=batch_raw_indices,
          partitions=cat_samples,
          num_partitions=self.num_components)
      samples_class = [None for _ in range(self.num_components)]

      for c in range(self.num_components):
        n_class = array_ops.size(partitioned_samples_indices[c])
        seed = distribution_util.gen_new_seed(seed, "mixture")
        samples_class_c = self.components[c].sample(n_class, seed=seed)

        # Pull out the correct batch entries from each index.
        # To do this, we may have to flatten the batch shape.

        # For sample s, batch element b of component c, we get the
        # partitioned batch indices from
        # partitioned_batch_indices[c]; and shift each element by
        # the sample index. The final lookup can be thought of as
        # a matrix gather along locations (s, b) in
        # samples_class_c where the n_class rows correspond to
        # samples within this component and the batch_size columns
        # correspond to batch elements within the component.
        #
        # Thus the lookup index is
        #   lookup[c, i] = batch_size * s[i] + b[c, i]
        # for i = 0 ... n_class[c] - 1.
        lookup_partitioned_batch_indices = (
            batch_size * math_ops.range(n_class) +
            partitioned_batch_indices[c])
        samples_class_c = array_ops.reshape(
            samples_class_c,
            array_ops.concat([[n_class * batch_size], event_shape], 0))
        samples_class_c = array_ops.gather(
            samples_class_c, lookup_partitioned_batch_indices,
            name="samples_class_c_gather")
        samples_class[c] = samples_class_c

      # Stitch back together the samples across the components.
      lhs_flat_ret = data_flow_ops.dynamic_stitch(
          indices=partitioned_samples_indices, data=samples_class)
      # Reshape back to proper sample, batch, and event shape.
      ret = array_ops.reshape(lhs_flat_ret,
                              array_ops.concat([samples_shape,
                                                self.event_shape_tensor()], 0))
      ret.set_shape(
          tensor_shape.TensorShape(static_samples_shape).concatenate(
              self.event_shape))
      return ret
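The partition/stitch bookkeeping above can be seen on a toy tensor. A minimal sketch with hypothetical values, using `tf.dynamic_partition` and `tf.dynamic_stitch`, the public aliases of the `data_flow_ops` used in the example:

import tensorflow as tf

data = tf.constant([10, 11, 12, 13, 14, 15])
cat = tf.constant([1, 1, 0, 0, 1, 1])                  # component id per sample
parts = tf.dynamic_partition(data, cat, 2)             # [[12, 13], [10, 11, 14, 15]]
index_parts = tf.dynamic_partition(tf.range(6), cat, 2)
restored = tf.dynamic_stitch(index_parts, parts)       # back to [10, 11, ..., 15]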
Example #49
0
    def _as_dataset(self,
                    sample_batch_size=None,
                    num_steps=None,
                    num_parallel_calls=None):
        """Creates a dataset that returns entries from the buffer.

    The dataset behaves differently depending on whether `num_steps` is provided.
    If `num_steps is None`, then entire episodes are sampled uniformly at
    random from the buffer.  If `num_steps` is an integer, then batches of
    length-`num_steps` subsequences are returned.

    We attempt to shuffle the entries in the batches as well as possible.  The
    algorithm for this is roughly:

    1. Shuffle all the TFRecord files found with prefix
       "{file_prefix}_{experiment_id}_*".

    2. Read from `sample_batch_size` TFRecord files in parallel.

    If `num_steps is not None`:

    3. For each file, create blocks of size `num_steps` for all records
      in that file with shift 1.  Shuffle these blocks, possibly dropping some
      depending on the value of `dataset_block_keep_prob`, and return them
      as a stream.

    4. Interleave the block streams coming from each file and shuffle the
      results.

    5. Create batches from the shuffled blocks.

    6. Parse the batches to match the shape and dtype of `self._data_spec`.

    If `num_steps is None`:

    3. For each file, read in entire sequences of episodes, parse them to match
       the shape and dtype of `self._data_spec`, and emit the episodes.

    4. Interleave the episodes coming from each file and shuffle the
       results.

    5. Return a stream of individual episodes, which can be combined via
       e.g. `tf.data.Dataset.padded_batch()`, `bucket_by_sequence_length()`,
       or other batching approach.

    **NOTE** If `num_steps is None`, then the class properties
    `dataset_block_keep_prob` and `dataset_batch_drop_remainder` are ignored.

    Args:
      sample_batch_size: (Optional.) An optional batch_size to specify the
        number of items to return. See `as_dataset` documentation.  This
        argument should be `None` iff `num_steps` is not `None`.
      num_steps: (Optional.) Scalar int.  How many contiguous frames to get
        per entry. Default is `None`: return full-length episodes.  This
        argument should be `None` iff `sample_batch_size` is not `None`.
      num_parallel_calls: (Optional.) Number of parallel calls to use in the
        dataset pipeline when interleaving reads from parallel TFRecord files.

    Returns:
      A dataset of type tf.data.Dataset, elements of which are 2-tuples of:
        - An item or sequence of items sampled uniformly from the buffer.
        - BufferInfo namedtuple, containing the episode ids.

    Raises:
      ValueError: If `sample_batch_size is None` but `num_steps is not None`,
        or if `sample_batch_size is not None` but `num_steps is None`.
      ValueError: If the data spec contains lists that must be converted to
        tuples.
    """
        if num_steps is None:
            if sample_batch_size is not None:
                raise ValueError(
                    'When num_steps is None, sample_batch_size must be '
                    'None but saw: %s' % (sample_batch_size, ))
        else:
            if sample_batch_size is None or sample_batch_size <= 0:
                raise ValueError(
                    'When num_steps is not None, sample_batch_size must be '
                    'an integer > 0, saw: %s' % (sample_batch_size, ))

        # data_nest.flatten does not flatten Python lists, but tf.nest.flatten does.
        flat_data_spec = tf.nest.flatten(self._data_spec)
        if flat_data_spec != data_nest.flatten(self._data_spec):
            raise ValueError(
                'Cannot perform gather; data spec contains lists and this conflicts '
                'with gathering operator.  Convert any lists to tuples.  '
                'For example, if your spec looks like [a, b, c], '
                'change it to (a, b, c).  Spec structure is:\n  {}'.format(
                    tf.nest.map_structure(lambda spec: spec.dtype,
                                          self._data_spec)))

        filename_seed = distributions_util.gen_new_seed(self._data.seed,
                                                        salt='filename_seed')

        batch_seed = distributions_util.gen_new_seed(self._data.seed,
                                                     salt='batch_seed')

        drop_block_seed = distributions_util.gen_new_seed(self._data.seed,
                                                          salt='drop_block')

        # TODO(b/128998627): Use a different seed for each file by mapping a count
        # with the filename and doing the seed generation in graph mode.
        per_episode_seed = distributions_util.gen_new_seed(
            self._data.seed, salt='per_episode_seed')

        block_keep_prob = self._data.dataset_block_keep_prob
        dropping_blocks = (tf.is_tensor(block_keep_prob)
                           or block_keep_prob != 1.0)
        if dropping_blocks:
            # empty_block_ds is in format (is_real_data=False, empty_data)
            empty_block_ds = tf.data.Dataset.from_tensors(
                (False, tf.fill([num_steps], '')))

            def select_true_or_empty(_):
                # When this returns 0, select the true block.  When this returns 1,
                # select the empty block.
                return tf.cast(
                    tf.random.uniform(
                        (), seed=drop_block_seed) > block_keep_prob, tf.int64)

            true_or_empty_block_selector_ds = (
                tf.data.experimental.Counter().map(select_true_or_empty))

        def list_and_shuffle_files(_):
            filenames = tf.io.matching_files(
                tf.strings.join(
                    (self._data.file_prefix, self._data.experiment_id, '*'),
                    separator='_'))
            shuffled = tf.random.shuffle(filenames, seed=filename_seed)
            return shuffled

        def parse_blocks_from_record(records):
            """Decode `FeatureList` tensor `records`.

      Args:
        records: `tf.string` tensor of shape either `[]` or `[batch_size]`.

      Outputs:
        A struct matching `self._data_spec` containing tensors.
        If `num_steps is not None`, it contains tensors with shape
        `[batch_size, num_steps, ...]`; otherwise they have shape `[...]`.
      """
            # If `num_steps is None`, then:
            #  records is shaped [].
            #  features is shaped [len(flatten(self._data_spec))].
            # otherwise:
            #  records is shaped [batch_size].
            #  features is shaped [batch_size, len(flatten(self._data_spec))].
            _, features = tf.io.decode_proto(
                bytes=records,
                message_type='tensorflow.FeatureList',
                field_names=['feature'],
                output_types=[tf.string])
            features = features.pop()
            num_features = len(flat_data_spec)
            features = tf.unstack(features, num_features, axis=-1)
            decoded_features = []
            for feature, spec in zip(features, flat_data_spec):
                decoded_feature = _decode_feature(feature,
                                                  spec,
                                                  has_outer_dims=num_steps
                                                  is not None)
                decoded_features.append(decoded_feature)
            return tf.nest.pack_sequence_as(self._data_spec, decoded_features)

        def read_and_block_fixed_length_tfrecord_file(filename):
            """Read records from `filename`, window them into fixed len blocks.

      This function also optionally subsamples and shuffles the blocks.

      Windowed records from filename come as a stream and prior to subsampling
      and shuffling, the stream contains blocks of the form:

         [r0, r1, ..., r_{num_steps - 1}]
         [r1, r2, ..., r_{num_steps}]
         [r2, r3, ..., r_{num_steps + 1}]
         ...

      Args:
        filename: A scalar string `Tensor` with the TFRecord filename.

      Returns:
        A `tf.data.Dataset` instance.
      """
            def drop_or_batch_window(ds):
                if not dropping_blocks:
                    return ds.batch(num_steps, drop_remainder=True)
                else:
                    # batched_ds is in format (is_real_data=True, true_ds)
                    batched_ds = tf.data.Dataset.zip(
                        (tf.data.Dataset.from_tensors(True),
                         ds.batch(num_steps, drop_remainder=True)))
                    return (tf.data.experimental.choose_from_datasets(
                        (batched_ds, empty_block_ds),
                        true_or_empty_block_selector_ds).take(1).filter(
                            lambda is_real_data, _: is_real_data).map(
                                lambda _, true_block: true_block))

            return (tf.data.TFRecordDataset(
                filename,
                compression_type=_compression_type_string(
                    self._data.record_options)).window(num_steps,
                                                       shift=1,
                                                       stride=1,
                                                       drop_remainder=True).
                    flat_map(drop_or_batch_window).shuffle(
                        buffer_size=self._data.per_file_shuffle_buffer_size,
                        seed=per_episode_seed))

        def read_and_block_variable_length_tfrecord_file(filename):
            """Read records from `filename`, window them into variable len blocks."""
            def create_ta(spec):
                return tf.TensorArray(size=0,
                                      dynamic_size=True,
                                      element_shape=spec.shape,
                                      dtype=spec.dtype)

            empty_tas = tf.nest.map_structure(create_ta, self._data_spec)

            def parse_and_block_on_episode_boundaries(partial_tas, record):
                frame = parse_blocks_from_record(record)
                updated_tas = tf.nest.map_structure(
                    lambda ta, f: ta.write(ta.size(), f), partial_tas, frame)
                # If we see a LAST field, then emit empty TAs for the state and updated
                # TAs for the output.  Otherwise emit updated TAs for the state and
                # empty TAs for the output (the empty output TAs will be filtered).
                return tf.cond(tf.equal(frame.step_type, StepType.LAST),
                               lambda: (empty_tas, updated_tas), lambda:
                               (updated_tas, empty_tas))

            stack_tas = lambda tas: tf.nest.map_structure(
                lambda ta: ta.stack(), tas)
            remove_intermediate_arrays = lambda tas: tas.step_type.size() > 0

            return (tf.data.TFRecordDataset(
                filename,
                compression_type=_compression_type_string(
                    self._data.record_options)).apply(
                        tf.data.experimental.scan(
                            empty_tas, parse_and_block_on_episode_boundaries)).
                    filter(remove_intermediate_arrays).map(stack_tas).shuffle(
                        buffer_size=self._data.per_file_shuffle_buffer_size,
                        seed=per_episode_seed))

        interleave_shuffle_buffer_size = (
            (num_parallel_calls or sample_batch_size or 4) *
            self._data.sampling_dataset_timesteps_per_episode_hint)

        if num_steps is None:
            read_and_block_fn = read_and_block_variable_length_tfrecord_file
        else:
            read_and_block_fn = read_and_block_fixed_length_tfrecord_file

        # Use tf.data.Dataset.from_tensors(0).map(...) to call the map() code once
        # per initialization.  This means that when the iterator is reinitialized,
        # we get a new list of files.
        ds = (
            tf.data.Dataset.from_tensors(0).map(
                list_and_shuffle_files).flat_map(
                    tf.data.Dataset.from_tensor_slices)
            # Interleave between blocks of records from different files.
            .interleave(read_and_block_fn,
                        cycle_length=max(
                            num_parallel_calls or sample_batch_size or 0, 4),
                        block_length=1,
                        num_parallel_calls=(
                            num_parallel_calls
                            or tf.data.experimental.AUTOTUNE)).shuffle(
                                buffer_size=interleave_shuffle_buffer_size,
                                seed=batch_seed))

        # Batch and parse the blocks.  If `num_steps is None`, parsing has already
        # happened and we're not batching.
        if num_steps is not None:
            ds = (ds.batch(batch_size=sample_batch_size,
                           drop_remainder=self._data.drop_remainder).map(
                               parse_blocks_from_record))

        return ds
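The fixed-length blocking inside `read_and_block_fixed_length_tfrecord_file` boils down to the `window`/`flat_map` pattern below. A minimal sketch on hypothetical integer "records":

import tensorflow as tf

num_steps = 3
frames = tf.data.Dataset.range(6)   # stands in for the records of one file
blocks = frames.window(num_steps, shift=1, drop_remainder=True).flat_map(
    lambda window: window.batch(num_steps, drop_remainder=True))
# Yields [0 1 2], [1 2 3], [2 3 4], [3 4 5].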
Example #50
0
    def _sample_n(self, n, seed=None):
        if self._use_static_graph:
            # This sampling approach is almost the same as the approach used by
            # `MixtureSameFamily`. The differences are due to having a list of
            # `Distribution` objects rather than a single object, and maintaining
            # random seed management that is consistent with the non-static code path.
            samples = []
            cat_samples = self.cat.sample(n, seed=seed)
            for c in range(self.num_components):
                seed = distribution_util.gen_new_seed(seed, "mixture")
                samples.append(self.components[c].sample(n, seed=seed))
            x = array_ops.stack(samples, -self._static_event_shape.ndims -
                                1)  # [n, B, k, E]
            npdt = x.dtype.as_numpy_dtype
            mask = array_ops.one_hot(
                indices=cat_samples,  # [n, B]
                depth=self._num_components,  # == k
                on_value=np.ones([], dtype=npdt),
                off_value=np.zeros([], dtype=npdt))  # [n, B, k]
            mask = distribution_utils.pad_mixture_dimensions(
                mask, self, self._cat,
                self._static_event_shape.ndims)  # [n, B, k, [1]*e]
            return math_ops.reduce_sum(
                x * mask,
                axis=-1 - self._static_event_shape.ndims)  # [n, B, E]

        with ops.control_dependencies(self._assertions):
            n = ops.convert_to_tensor(n, name="n")
            static_n = tensor_util.constant_value(n)
            n = int(static_n) if static_n is not None else n
            cat_samples = self.cat.sample(n, seed=seed)

            static_samples_shape = cat_samples.get_shape()
            if static_samples_shape.is_fully_defined():
                samples_shape = static_samples_shape.as_list()
                samples_size = static_samples_shape.num_elements()
            else:
                samples_shape = array_ops.shape(cat_samples)
                samples_size = array_ops.size(cat_samples)
            static_batch_shape = self.batch_shape
            if static_batch_shape.is_fully_defined():
                batch_shape = static_batch_shape.as_list()
                batch_size = static_batch_shape.num_elements()
            else:
                batch_shape = self.batch_shape_tensor()
                batch_size = math_ops.reduce_prod(batch_shape)
            static_event_shape = self.event_shape
            if static_event_shape.is_fully_defined():
                event_shape = np.array(static_event_shape.as_list(),
                                       dtype=np.int32)
            else:
                event_shape = self.event_shape_tensor()

            # Get indices into the raw cat sampling tensor. We will
            # need these to stitch sample values back out after sampling
            # within the component partitions.
            samples_raw_indices = array_ops.reshape(
                math_ops.range(0, samples_size), samples_shape)

            # Partition the raw indices so that we can use
            # dynamic_stitch later to reconstruct the samples from the
            # known partitions.
            partitioned_samples_indices = data_flow_ops.dynamic_partition(
                data=samples_raw_indices,
                partitions=cat_samples,
                num_partitions=self.num_components)

            # Copy the batch indices n times, as we will need to know
            # these to pull out the appropriate rows within the
            # component partitions.
            batch_raw_indices = array_ops.reshape(
                array_ops.tile(math_ops.range(0, batch_size), [n]),
                samples_shape)

            # Explanation of the dynamic partitioning below:
            #   batch indices are i.e., [0, 1, 0, 1, 0, 1]
            # Suppose partitions are:
            #     [1 1 0 0 1 1]
            # After partitioning, batch indices are cut as:
            #     [batch_indices[x] for x in 2, 3]
            #     [batch_indices[x] for x in 0, 1, 4, 5]
            # i.e.
            #     [1 1] and [0 0 0 0]
            # Now we sample n=2 from part 0 and n=4 from part 1.
            # For part 0 we want samples from batch entries 1, 1 (samples 0, 1),
            # and for part 1 we want samples from batch entries 0, 0, 0, 0
            #   (samples 0, 1, 2, 3).
            partitioned_batch_indices = data_flow_ops.dynamic_partition(
                data=batch_raw_indices,
                partitions=cat_samples,
                num_partitions=self.num_components)
            samples_class = [None for _ in range(self.num_components)]

            for c in range(self.num_components):
                n_class = array_ops.size(partitioned_samples_indices[c])
                seed = distribution_util.gen_new_seed(seed, "mixture")
                samples_class_c = self.components[c].sample(n_class, seed=seed)

                # Pull out the correct batch entries from each index.
                # To do this, we may have to flatten the batch shape.

                # For sample s, batch element b of component c, we get the
                # partitioned batch indices from
                # partitioned_batch_indices[c]; and shift each element by
                # the sample index. The final lookup can be thought of as
                # a matrix gather along locations (s, b) in
                # samples_class_c where the n_class rows correspond to
                # samples within this component and the batch_size columns
                # correspond to batch elements within the component.
                #
                # Thus the lookup index is
                #   lookup[c, i] = batch_size * s[i] + b[c, i]
                # for i = 0 ... n_class[c] - 1.
                lookup_partitioned_batch_indices = (
                    batch_size * math_ops.range(n_class) +
                    partitioned_batch_indices[c])
                samples_class_c = array_ops.reshape(
                    samples_class_c,
                    array_ops.concat([[n_class * batch_size], event_shape], 0))
                samples_class_c = array_ops.gather(
                    samples_class_c,
                    lookup_partitioned_batch_indices,
                    name="samples_class_c_gather")
                samples_class[c] = samples_class_c

            # Stitch back together the samples across the components.
            lhs_flat_ret = data_flow_ops.dynamic_stitch(
                indices=partitioned_samples_indices, data=samples_class)
            # Reshape back to proper sample, batch, and event shape.
            ret = array_ops.reshape(
                lhs_flat_ret,
                array_ops.concat(
                    [samples_shape, self.event_shape_tensor()], 0))
            ret.set_shape(
                tensor_shape.TensorShape(static_samples_shape).concatenate(
                    self.event_shape))
            return ret
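The static-graph branch above draws from every component and then keeps one draw per sample with a one-hot mask. A minimal NumPy sketch with hypothetical shapes:

import numpy as np

n, k = 4, 3                              # draws, mixture components
x = np.random.randn(n, k)                # stands in for the stacked component samples
cat = np.random.randint(0, k, size=n)    # which component each draw belongs to
mask = np.eye(k)[cat]                    # one-hot mask, shape [n, k]
selected = (x * mask).sum(axis=-1)       # shape [n]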
Example #51
0
    def one_step(self, current_state, previous_kernel_results):
        """Takes one step of the TransitionKernel.

    Args:
      current_state: `Tensor` or Python `list` of `Tensor`s representing the
        current state(s) of the Markov chain(s).
      previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
        `list` of `Tensor`s representing internal calculations made within the
        previous call to this function (or as returned by `bootstrap_results`).

    Returns:
      next_state: `Tensor` or Python `list` of `Tensor`s representing the
        next state(s) of the Markov chain(s).
      kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
        `Tensor`s representing internal calculations made within this function.
        This includes replica states.
    """
        with tf.name_scope(name=mcmc_util.make_name(self.name, 'remc',
                                                    'one_step'),
                           values=[current_state, previous_kernel_results]):
            sampled_replica_states, sampled_replica_results = zip(*[
                rk.one_step(previous_kernel_results.replica_states[i],
                            previous_kernel_results.replica_results[i])
                for i, rk in enumerate(self.replica_kernels)
            ])
            sampled_replica_states = list(sampled_replica_states)
            sampled_replica_results = list(sampled_replica_results)

            sampled_replica_results_modified = [
                srr._replace(target_log_prob=srr.target_log_prob /
                             self.inverse_temperatures[i])
                if 'target_log_prob' in srr._fields else srr._replace(
                    accepted_results=srr.accepted_results._replace(
                        target_log_prob=srr.accepted_results.target_log_prob /
                        self.inverse_temperatures[i]))
                for i, srr in enumerate(sampled_replica_results)
            ]

            sampled_replica_ratios = [
                srr.target_log_prob if 'target_log_prob' in srr._fields else
                srr.accepted_results.target_log_prob
                for i, srr in enumerate(sampled_replica_results_modified)
            ]
            sampled_replica_ratios = tf.stack(sampled_replica_ratios, axis=-1)

            next_replica_idx = tf.range(self.num_replica)
            self._seed_stream = distributions_util.gen_new_seed(
                self._seed_stream, salt='replica_exchange_one_step')
            exchange_proposed, exchange_proposed_n = self.exchange_proposed_fn(
                self.num_replica, seed=self._seed_stream)
            i = tf.constant(0)

            def cond(i, next_replica_idx):  # pylint: disable=unused-argument
                return tf.less(i, exchange_proposed_n)

            def body(i, next_replica_idx):
                """`tf.while_loop` body."""
                ratio = (sampled_replica_ratios[next_replica_idx[
                    exchange_proposed[i, 0]]] - sampled_replica_ratios[
                        next_replica_idx[exchange_proposed[i, 1]]])
                ratio *= (self.inverse_temperatures[exchange_proposed[i, 1]] -
                          self.inverse_temperatures[exchange_proposed[i, 0]])
                self._seed_stream = distributions_util.gen_new_seed(
                    self._seed_stream, salt='replica_exchange_one_step')
                log_uniform = tf.log(
                    tf.random_uniform(shape=tf.shape(ratio),
                                      dtype=ratio.dtype.base_dtype,
                                      seed=self._seed_stream))
                exchange = log_uniform < ratio
                exchange_op = tf.sparse_to_dense(
                    [exchange_proposed[i, 0], exchange_proposed[i, 1]],
                    [self.num_replica], [
                        next_replica_idx[exchange_proposed[i, 1]] -
                        next_replica_idx[exchange_proposed[i, 0]],
                        next_replica_idx[exchange_proposed[i, 0]] -
                        next_replica_idx[exchange_proposed[i, 1]]
                    ])
                next_replica_idx = tf.cond(
                    exchange, lambda: next_replica_idx + exchange_op,
                    lambda: next_replica_idx)
                return [i + 1, next_replica_idx]

            next_replica_idx = tf.while_loop(cond,
                                             body,
                                             loop_vars=[i,
                                                        next_replica_idx])[1]

            def _prep(list_):
                return list(
                    tf.case(
                        {
                            tf.equal(next_replica_idx[i], j): _stateful_lambda(
                                list_[j])
                            for j in range(self.num_replica)
                        },
                        exclusive=True) for i in range(self.num_replica))

            next_replica_states = _prep(sampled_replica_states)
            next_replica_results = _prep(sampled_replica_results_modified)

            next_replica_results = [
                nrr._replace(target_log_prob=nrr.target_log_prob *
                             self.inverse_temperatures[i])
                if 'target_log_prob' in nrr._fields else nrr._replace(
                    accepted_results=nrr.accepted_results._replace(
                        target_log_prob=nrr.accepted_results.target_log_prob *
                        self.inverse_temperatures[i]))
                for i, nrr in enumerate(next_replica_results)
            ]

            next_state = tf.identity(next_replica_states[0])
            kernel_results = ReplicaExchangeMCKernelResults(
                replica_states=next_replica_states,
                replica_results=next_replica_results,
                next_replica_idx=next_replica_idx,
                exchange_proposed=exchange_proposed,
                exchange_proposed_n=exchange_proposed_n,
                sampled_replica_states=sampled_replica_states,
                sampled_replica_results=sampled_replica_results,
            )

            return next_state, kernel_results
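The swap test inside `body` reduces to a Metropolis comparison between two inverse temperatures. A minimal sketch with hypothetical numbers (variable names are illustrative):

import numpy as np

log_prob_0, log_prob_1 = -3.0, -7.0   # un-tempered log-densities of the two replica states
beta_0, beta_1 = 1.0, 0.5             # their inverse temperatures
ratio = (log_prob_0 - log_prob_1) * (beta_1 - beta_0)
accept_swap = np.log(np.random.uniform()) < ratio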
Example #52
0
def kernel(target_log_prob_fn,
           current_state,
           step_size,
           num_leapfrog_steps,
           seed=None,
           current_target_log_prob=None,
           current_grads_target_log_prob=None,
           name=None):
    """Runs one iteration of Hamiltonian Monte Carlo.

  Hamiltonian Monte Carlo (HMC) is a Markov chain Monte Carlo (MCMC)
  algorithm that takes a series of gradient-informed steps to produce
  a Metropolis proposal. This function applies one step of HMC to
  randomly update the variable `x`.

  This function can update multiple chains in parallel. It assumes that all
  leftmost dimensions of `current_state` index independent chain states (and are
  therefore updated independently). The output of `target_log_prob_fn()` should
  sum log-probabilities across all event dimensions. Slices along the rightmost
  dimensions may have different target distributions; for example,
  `current_state[0, :]` could have a different target distribution from
  `current_state[1, :]`. This is up to `target_log_prob_fn()`. (The number of
  independent chains is `tf.size(target_log_prob_fn(*current_state))`.)

  #### Examples:

  ##### Simple chain with warm-up.

  ```python
  tfd = tf.contrib.distributions

  # Tuning acceptance rates:
  dtype = np.float32
  target_accept_rate = 0.631
  num_warmup_iter = 500
  num_chain_iter = 500

  x = tf.get_variable(name="x", initializer=dtype(1))
  step_size = tf.get_variable(name="step_size", initializer=dtype(1))

  target = tfd.Normal(loc=dtype(0), scale=dtype(1))

  new_x, other_results = hmc.kernel(
      target_log_prob_fn=target.log_prob,
      current_state=x,
      step_size=step_size,
      num_leapfrog_steps=3)[:4]

  x_update = x.assign(new_x)

  step_size_update = step_size.assign_add(
      step_size * tf.where(
        other_results.acceptance_probs > target_accept_rate,
        0.01, -0.01))

  warmup = tf.group([x_update, step_size_update])

  tf.global_variables_initializer().run()

  sess.graph.finalize()  # No more graph building.

  # Warm up the sampler and adapt the step size
  for _ in xrange(num_warmup_iter):
    sess.run(warmup)

  # Collect samples without adapting step size
  samples = np.zeros([num_chain_iter])
  for i in xrange(num_chain_iter):
    _, x_, target_log_prob_, grad_ = sess.run([
        x_update,
        x,
        other_results.target_log_prob,
        other_results.grads_target_log_prob])
    samples[i] = x_

  print(samples.mean(), samples.std())
  ```

  ##### Sample from more complicated posterior.

  I.e.,

  ```none
    W ~ MVN(loc=0, scale=sigma * eye(dims))
    for i=1...num_samples:
      X[i] ~ MVN(loc=0, scale=eye(dims))
      eps[i] ~ Normal(loc=0, scale=1)
      Y[i] = X[i].T * W + eps[i]
  ```

  ```python
  tfd = tf.contrib.distributions

  def make_training_data(num_samples, dims, sigma):
    dt = np.asarray(sigma).dtype
    zeros = tf.zeros(dims, dtype=dt)
    x = tfd.MultivariateNormalDiag(
        loc=zeros).sample(num_samples, seed=1)
    w = tfd.MultivariateNormalDiag(
        loc=zeros,
        scale_identity_multiplier=sigma).sample(seed=2)
    noise = tfd.Normal(
        loc=dt(0),
        scale=dt(1)).sample(num_samples, seed=3)
    y = tf.tensordot(x, w, axes=[[1], [0]]) + noise
    return y, x, w

  def make_prior(sigma, dims):
    # p(w | sigma)
    return tfd.MultivariateNormalDiag(
        loc=tf.zeros([dims], dtype=sigma.dtype),
        scale_identity_multiplier=sigma)

  def make_likelihood(x, w):
    # p(y | x, w)
    return tfd.MultivariateNormalDiag(
        loc=tf.tensordot(x, w, axes=[[1], [0]]))

  # Setup assumptions.
  dtype = np.float32
  num_samples = 150
  dims = 10
  num_iters = int(5e3)

  true_sigma = dtype(0.5)
  y, x, true_weights = make_training_data(num_samples, dims, true_sigma)

  # Estimate of `log(true_sigma)`.
  log_sigma = tf.get_variable(name="log_sigma", initializer=dtype(0))
  sigma = tf.exp(log_sigma)

  # State of the Markov chain.
  weights = tf.get_variable(
      name="weights",
      initializer=np.random.randn(dims).astype(dtype))

  prior = make_prior(sigma, dims)

  def joint_log_prob_fn(w):
    # f(w) = log p(w, y | x)
    return prior.log_prob(w) + make_likelihood(x, w).log_prob(y)

  weights_update = weights.assign(
      hmc.kernel(target_log_prob_fn=joint_log_prob_fn,
                 current_state=weights,
                 step_size=0.1,
                 num_leapfrog_steps=5)[0])

  with tf.control_dependencies([weights_update]):
    loss = -prior.log_prob(weights)

  optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
  log_sigma_update = optimizer.minimize(loss, var_list=[log_sigma])

  tf.global_variables_initializer().run()

  sess.graph.finalize()  # No more graph building.

  sigma_history = np.zeros(num_iters, dtype)
  weights_history = np.zeros([num_iters, dims], dtype)

  for i in xrange(num_iters):
    _, sigma_, weights_, _ = sess.run([log_sigma_update, sigma, weights])
    weights_history[i, :] = weights_
    sigma_history[i] = sigma_

  true_weights_ = sess.run(true_weights)

  # Should converge to something close to true_sigma.
  plt.plot(sigma_history);
  plt.ylabel("sigma");
  plt.xlabel("iteration");
  ```

  Args:
    target_log_prob_fn: Python callable which takes an argument like
      `current_state` (or `*current_state` if it's a list) and returns its
      (possibly unnormalized) log-density under the target distribution.
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s). The first `r` dimensions index
      independent chains, `r = tf.rank(target_log_prob_fn(*current_state))`.
    step_size: `Tensor` or Python `list` of `Tensor`s representing the step size
      for the leapfrog integrator. Must broadcast with the shape of
      `current_state`. Larger step sizes lead to faster progress, but too-large
      step sizes make rejection exponentially more likely. When possible, it's
      often helpful to match per-variable step sizes to the standard deviations
      of the target distribution in each variable.
    num_leapfrog_steps: Integer number of steps to run the leapfrog integrator
      for. Total progress per HMC step is roughly proportional to `step_size *
      num_leapfrog_steps`.
    seed: Python integer to seed the random number generator.
    current_target_log_prob: (Optional) `Tensor` representing the value of
      `target_log_prob_fn` at the `current_state`. The only reason to
      specify this argument is to reduce TF graph size.
      Default value: `None` (i.e., compute as needed).
    current_grads_target_log_prob: (Optional) Python list of `Tensor`s
      representing gradient of `current_target_log_prob` at the `current_state`
      and wrt the `current_state`. Must have same shape as `current_state`. The
      only reason to specify this argument is to reduce TF graph size.
      Default value: `None` (i.e., compute as needed).
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., "hmc_kernel").

  Returns:
    accepted_state: Tensor or Python list of `Tensor`s representing the state(s)
      of the Markov chain(s) at each result step. Has same shape as
      `current_state`.
    kernel_results: `collections.namedtuple` of internal calculations used to
      advance the chain.

  Raises:
    ValueError: if there isn't one `step_size` or a list with same length as
      `current_state`.
  """
    with ops.name_scope(name, "hmc_kernel", [
            current_state, step_size, num_leapfrog_steps, seed,
            current_target_log_prob, current_grads_target_log_prob
    ]):
        with ops.name_scope("initialize"):
            [
                current_state_parts, step_sizes, current_target_log_prob,
                current_grads_target_log_prob
            ] = _prepare_args(target_log_prob_fn,
                              current_state,
                              step_size,
                              current_target_log_prob,
                              current_grads_target_log_prob,
                              maybe_expand=True)
            independent_chain_ndims = distributions_util.prefer_static_rank(
                current_target_log_prob)
            current_momentums = []
            for s in current_state_parts:
                current_momentums.append(
                    random_ops.random_normal(shape=array_ops.shape(s),
                                             dtype=s.dtype.base_dtype,
                                             seed=seed))
                seed = distributions_util.gen_new_seed(
                    seed, salt="hmc_kernel_momentums")

            num_leapfrog_steps = ops.convert_to_tensor(
                num_leapfrog_steps,
                dtype=dtypes.int32,
                name="num_leapfrog_steps")
        [
            proposed_momentums,
            proposed_state_parts,
            proposed_target_log_prob,
            proposed_grads_target_log_prob,
        ] = _leapfrog_integrator(current_momentums, target_log_prob_fn,
                                 current_state_parts, step_sizes,
                                 num_leapfrog_steps, current_target_log_prob,
                                 current_grads_target_log_prob)

        energy_change = _compute_energy_change(current_target_log_prob,
                                               current_momentums,
                                               proposed_target_log_prob,
                                               proposed_momentums,
                                               independent_chain_ndims)

        # u < exp(min(-energy, 0)),  where u~Uniform[0,1)
        # ==> -log(u) >= max(e, 0)
        # ==> -log(u) >= e
        # (Perhaps surprisingly, we don't have a better way to obtain a random
        # uniform from positive reals, i.e., `tf.random_uniform(minval=0,
        # maxval=np.inf)` won't work.)
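        # For example (hypothetical numbers): if energy_change = 0.7 then
        # exp(min(-0.7, 0)) ~= 0.50, and indeed P(-log(u) >= 0.7) = exp(-0.7),
        # so the proposal is accepted with probability ~0.50.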
        random_uniform = random_ops.random_uniform(
            shape=array_ops.shape(energy_change),
            dtype=energy_change.dtype,
            seed=seed)
        random_positive = -math_ops.log(random_uniform)
        is_accepted = random_positive >= energy_change

        accepted_target_log_prob = array_ops.where(is_accepted,
                                                   proposed_target_log_prob,
                                                   current_target_log_prob)

        accepted_state_parts = [
            _choose(is_accepted, proposed_state_part, current_state_part,
                    independent_chain_ndims)
            for current_state_part, proposed_state_part in zip(
                current_state_parts, proposed_state_parts)
        ]

        accepted_grads_target_log_prob = [
            _choose(is_accepted, proposed_grad, grad, independent_chain_ndims)
            for proposed_grad, grad in zip(proposed_grads_target_log_prob,
                                           current_grads_target_log_prob)
        ]

        maybe_flatten = lambda x: x if _is_list_like(current_state) else x[0]
        return [
            maybe_flatten(accepted_state_parts),
            KernelResults(
                acceptance_probs=math_ops.exp(
                    math_ops.minimum(-energy_change, 0.)),
                current_grads_target_log_prob=accepted_grads_target_log_prob,
                current_target_log_prob=accepted_target_log_prob,
                energy_change=energy_change,
                is_accepted=is_accepted,
                proposed_grads_target_log_prob=proposed_grads_target_log_prob,
                proposed_state=maybe_flatten(proposed_state_parts),
                proposed_target_log_prob=proposed_target_log_prob,
                random_positive=random_positive,
            ),
        ]
Example #53
0
def sample_chain(
    num_results,
    target_log_prob_fn,
    current_state,
    step_size,
    num_leapfrog_steps,
    num_burnin_steps=0,
    num_steps_between_results=0,
    seed=None,
    current_target_log_prob=None,
    current_grads_target_log_prob=None,
    name=None):
  """Runs multiple iterations of one or more Hamiltonian Monte Carlo chains.

  Hamiltonian Monte Carlo (HMC) is a Markov chain Monte Carlo (MCMC) algorithm
  that takes a series of gradient-informed steps to produce a Metropolis
  proposal. This function samples from an HMC Markov chain at `current_state`
  and whose stationary distribution has log-unnormalized-density
  `target_log_prob_fn()`.

  This function samples from multiple chains in parallel. It assumes that the
  leftmost dimensions of (each) `current_state` (part) index an independent
  chain.  The function `target_log_prob_fn()` sums log-probabilities across
  event dimensions (i.e., current state (part) rightmost dimensions). Each
  element of the output of `target_log_prob_fn()` represents the (possibly
  unnormalized) log-probability of the joint distribution over (all) the current
  state (parts).

  The `current_state` can be represented as a single `Tensor` or a `list` of
  `Tensors` which collectively represent the current state. When specifying a
  `list`, one must also specify a list of `step_size`s.

  Only one out of every `num_steps_between_results + 1` steps is included in the
  returned results. This "thinning" comes at a cost of reduced statistical
  power, while reducing memory requirements and autocorrelation. For more
  discussion see [1].

  [1]: "Statistically efficient thinning of a Markov chain sampler."
       Art B. Owen. April 2017.
       http://statweb.stanford.edu/~owen/reports/bestthinning.pdf

  #### Examples:

  ##### Sample from a diagonal-variance Gaussian.

  ```python
  tfd = tf.contrib.distributions

  def make_likelihood(true_variances):
    return tfd.MultivariateNormalDiag(
        scale_diag=tf.sqrt(true_variances))

  dims = 10
  dtype = np.float32
  true_variances = tf.linspace(dtype(1), dtype(3), dims)
  likelihood = make_likelihood(true_variances)

  states, kernel_results = hmc.sample_chain(
      num_results=1000,
      target_log_prob_fn=likelihood.log_prob,
      current_state=tf.zeros(dims),
      step_size=0.5,
      num_leapfrog_steps=2,
      num_burnin_steps=500)

  # Compute sample stats.
  sample_mean = tf.reduce_mean(states, axis=0)
  sample_var = tf.reduce_mean(
      tf.squared_difference(states, sample_mean),
      axis=0)
  ```

  ##### Sampling from factor-analysis posteriors with known factors.

  I.e.,

  ```none
  for i=1..n:
    w[i] ~ Normal(0, eye(d))            # prior
    x[i] ~ Normal(loc=matmul(w[i], F))  # likelihood
  ```

  where `F` denotes factors.

  ```python
  tfd = tf.contrib.distributions

  def make_prior(dims, dtype):
    return tfd.MultivariateNormalDiag(
        loc=tf.zeros(dims, dtype))

  def make_likelihood(weights, factors):
    return tfd.MultivariateNormalDiag(
        loc=tf.tensordot(weights, factors, axes=[[0], [-1]]))

  # Setup data.
  num_weights = 10
  num_factors = 4
  num_chains = 100
  dtype = np.float32

  prior = make_prior(num_weights, dtype)
  weights = prior.sample(num_chains)
  factors = np.random.randn(num_factors, num_weights).astype(dtype)
  x = make_likelihood(weights, factors).sample(num_chains)

  def target_log_prob(w):
    # Target joint is: `f(w) = p(w, x | factors)`.
    return prior.log_prob(w) + make_likelihood(w, factors).log_prob(x)

  # Get `num_results` samples from `num_chains` independent chains.
  chains_states, kernels_results = hmc.sample_chain(
      num_results=1000,
      target_log_prob_fn=target_log_prob,
      current_state=tf.zeros([num_chains, num_weights], dtype),
      step_size=0.1,
      num_leapfrog_steps=2,
      num_burnin_steps=500)

  # Compute sample stats.
  sample_mean = tf.reduce_mean(chains_states, axis=[0, 1])
  sample_var = tf.reduce_mean(
      tf.squared_difference(chains_states, sample_mean),
      axis=[0, 1])
  ```

  Args:
    num_results: Integer number of Markov chain draws.
    target_log_prob_fn: Python callable which takes an argument like
      `current_state` (or `*current_state` if it's a list) and returns its
      (possibly unnormalized) log-density under the target distribution.
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s). The first `r` dimensions index
      independent chains, `r = tf.rank(target_log_prob_fn(*current_state))`.
    step_size: `Tensor` or Python `list` of `Tensor`s representing the step size
      for the leapfrog integrator. Must broadcast with the shape of
      `current_state`. Larger step sizes lead to faster progress, but too-large
      step sizes make rejection exponentially more likely. When possible, it's
      often helpful to match per-variable step sizes to the standard deviations
      of the target distribution in each variable.
    num_leapfrog_steps: Integer number of steps to run the leapfrog integrator
      for. Total progress per HMC step is roughly proportional to `step_size *
      num_leapfrog_steps`.
    num_burnin_steps: Integer number of chain steps to take before starting to
      collect results.
      Default value: 0 (i.e., no burn-in).
    num_steps_between_results: Integer number of chain steps between collecting
      a result. Only one out of every `num_steps_between_results + 1` steps is
      included in the returned results. This "thinning" comes at a cost of
      reduced statistical power, while reducing memory requirements and
      autocorrelation. For more discussion see [1].
      Default value: 0 (i.e., no subsampling).
    seed: Python integer to seed the random number generator.
    current_target_log_prob: (Optional) `Tensor` representing the value of
      `target_log_prob_fn` at the `current_state`. The only reason to specify
      this argument is to reduce TF graph size.
      Default value: `None` (i.e., compute as needed).
    current_grads_target_log_prob: (Optional) Python list of `Tensor`s
      representing gradient of `target_log_prob` at the `current_state` and wrt
      the `current_state`. Must have same shape as `current_state`. The only
      reason to specify this argument is to reduce TF graph size.
      Default value: `None` (i.e., compute as needed).
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., "hmc_sample_chain").

  Returns:
    accepted_states: Tensor or Python list of `Tensor`s representing the
      state(s) of the Markov chain(s) at each result step. Has same shape as
      input `current_state` but with a prepended `num_results`-size dimension.
    kernel_results: `collections.namedtuple` of internal calculations used to
      advance the chain.
  """
  with ops.name_scope(
      name, "hmc_sample_chain",
      [num_results, current_state, step_size, num_leapfrog_steps,
       num_burnin_steps, num_steps_between_results, seed,
       current_target_log_prob, current_grads_target_log_prob]):
    with ops.name_scope("initialize"):
      [
          current_state,
          step_size,
          current_target_log_prob,
          current_grads_target_log_prob,
      ] = _prepare_args(
          target_log_prob_fn, current_state, step_size,
          current_target_log_prob, current_grads_target_log_prob)
    def _run_chain(num_steps, current_state, seed, kernel_results):
      """Runs the chain(s) for `num_steps`."""
      def _loop_body(iter_, current_state, kernel_results):
        return [iter_ + 1] + list(kernel(
            target_log_prob_fn,
            current_state,
            step_size,
            num_leapfrog_steps,
            seed,
            kernel_results.current_target_log_prob,
            kernel_results.current_grads_target_log_prob))
      return control_flow_ops.while_loop(
          cond=lambda iter_, *args: iter_ < num_steps,
          body=_loop_body,
          loop_vars=[0, current_state, kernel_results])[1:]  # Lop-off "iter_".

    def _scan_body(args_list, _):
      """Closure which implements `tf.scan` body."""
      current_state, kernel_results = args_list
      return _run_chain(num_steps_between_results + 1, current_state, seed,
                        kernel_results)

    current_state, kernel_results = _run_chain(
        num_burnin_steps,
        current_state,
        distributions_util.gen_new_seed(
            seed, salt="hmc_sample_chain_burnin"),
        _make_dummy_kernel_results(
            current_state,
            current_target_log_prob,
            current_grads_target_log_prob))

    return functional_ops.scan(
        fn=_scan_body,
        elems=array_ops.zeros(num_results, dtype=dtypes.bool),  # Dummy arg.
        initializer=[current_state, kernel_results])
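Putting burn-in, thinning, and result collection together: the burn-in loop runs `num_burnin_steps` kernel steps, and each of the `num_results` scan iterations runs `num_steps_between_results + 1` more. A minimal sketch of that bookkeeping with hypothetical settings:

num_results = 1000
num_burnin_steps = 500
num_steps_between_results = 4

total_kernel_steps = num_burnin_steps + num_results * (num_steps_between_results + 1)
# 500 + 1000 * 5 = 5500 calls to `kernel` for 1000 retained draws.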