Example #1
    def _start_trajectory_batched(self, state, target_log_prob):
        """Computations needed to start a trajectory."""
        with tf.name_scope('start_trajectory_batched'):
            seed_stream = SeedStream(self._seed_stream,
                                     salt='start_trajectory_batched')
            momentum = [
                tf.random.normal(  # pylint: disable=g-complex-comprehension
                    shape=prefer_static.shape(x),
                    dtype=x.dtype,
                    seed=seed_stream()) for x in state
            ]
            init_energy = compute_hamiltonian(target_log_prob, momentum)

            if MULTINOMIAL_SAMPLE:
                return momentum, init_energy, None

            # Draw a slice variable u ~ Uniform(0, p(initial state, initial
            # momentum)) and compute log u. For numerical stability, we perform this
            # in log space where log u = log (u' * p(...)) = log u' + log
            # p(...) and u' ~ Uniform(0, 1).
            log_slice_sample = tf.math.log1p(
                -tf.random.uniform(shape=prefer_static.shape(init_energy),
                                   dtype=init_energy.dtype,
                                   seed=seed_stream()))
            return momentum, init_energy, log_slice_sample
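
The slice-variable trick described in the comment can be written out in plain NumPy. A minimal scalar sketch with hypothetical values (the real code operates on batched Tensors and only returns log u', adding `init_energy` later when testing candidate states):

import numpy as np

rng = np.random.default_rng(0)

# log u' with u' ~ Uniform(0, 1): since 1 - U is also Uniform(0, 1),
# log1p(-U) is a numerically safe draw of log u'.
log_u_prime = np.log1p(-rng.uniform())

# log u = log(u' * p(state, momentum)) = log u' + log p(state, momentum),
# where log p(state, momentum) is `init_energy` above.
init_energy = -3.2  # hypothetical value of the initial energy
log_u = log_u_prime + init_energy
print(log_u <= init_energy)  # always True, since u <= p(state, momentum)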
Example #2
  def default_exchange_proposed_fn_(num_replica, seed=None):
    """Default function for `exchange_proposed_fn` of `kernel`."""
    seed_stream = SeedStream(seed, 'default_exchange_proposed_fn')

    zero_start = tf.random_uniform([], seed=seed_stream()) > 0.5
    if num_replica % 2 == 0:

      def _exchange():
        flat_exchange = tf.range(num_replica)
        if num_replica > 2:
          start = tf.to_int32(~zero_start)
          end = num_replica - start
          flat_exchange = flat_exchange[start:end]
        return tf.reshape(flat_exchange, [tf.size(flat_exchange) // 2, 2])
    else:

      def _exchange():
        start = tf.to_int32(zero_start)
        end = num_replica - tf.to_int32(~zero_start)
        flat_exchange = tf.range(num_replica)[start:end]
        return tf.reshape(flat_exchange, [tf.size(flat_exchange) // 2, 2])

    def _null_exchange():
      return tf.reshape(tf.to_int32([]), shape=[0, 2])

    return tf.cond(
        tf.random_uniform([], seed=seed_stream()) < prob_exchange, _exchange,
        _null_exchange)
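
The pairing logic above can be reproduced in plain NumPy. A rough sketch (note that `prob_exchange` in the original is captured from the enclosing `default_exchange_proposed_fn(prob_exchange)` closure):

import numpy as np

def propose_exchanges(num_replica, prob_exchange, rng):
    """Pairs adjacent replicas, randomly offsetting where the pairing starts."""
    if rng.uniform() >= prob_exchange:
        return np.zeros([0, 2], dtype=np.int32)  # propose no exchanges
    zero_start = rng.uniform() > 0.5
    flat = np.arange(num_replica)
    if num_replica % 2 == 0:
        if num_replica > 2 and not zero_start:
            flat = flat[1:num_replica - 1]
    else:
        flat = flat[int(zero_start):num_replica - int(not zero_start)]
    return flat.reshape(-1, 2)

print(propose_exchanges(5, prob_exchange=1., rng=np.random.default_rng(0)))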
Example #3
 def _sample_n(self, n, seed=None):
   shape = tf.concat([[n], self._batch_shape_tensor()], axis=0)
   seed = SeedStream(seed, salt="random_horseshoe")
   local_shrinkage = self._half_cauchy.sample(shape, seed=seed())
   shrinkage = self.scale * local_shrinkage
   sampled = tf.random.normal(
       shape=shape, mean=0., stddev=1., dtype=self.scale.dtype, seed=seed())
   return sampled * shrinkage
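
The horseshoe draw above is a scale mixture of normals: a half-Cauchy local shrinkage scales a standard normal. A hypothetical scalar sketch in NumPy:

import numpy as np

rng = np.random.default_rng(0)
scale = 1.0                                      # assumed global scale
local_shrinkage = np.abs(rng.standard_cauchy())  # HalfCauchy(0, 1) draw
sample = rng.normal() * scale * local_shrinkage
print(sample)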
Example #4
  def __init__(self,
               target_log_prob_fn,
               inverse_temperatures,
               make_kernel_fn,
               exchange_proposed_fn=default_exchange_proposed_fn(1.),
               seed=None,
               name=None,
               **kwargs):
    """Instantiates this object.

    Args:
      target_log_prob_fn: Python callable which takes an argument like
        `current_state` (or `*current_state` if it's a list) and returns its
        (possibly unnormalized) log-density under the target distribution.
      inverse_temperatures: `1D` `Tensor` of inverse temperatures to use when
        sampling with each replica. Must have statically known `shape`.
        `inverse_temperatures[0]` produces the states returned by samplers,
        and is typically `== 1.`
      make_kernel_fn: Python callable which takes `target_log_prob_fn` and
        `seed` args and returns a `TransitionKernel` instance.
      exchange_proposed_fn: Python callable which takes a number of replicas
        and returns combinations of replicas proposed for exchange.
      seed: Python integer to seed the random number generator.
        Default value: `None` (i.e., no seed).
      name: Python `str` name prefixed to Ops created by this function.
        Default value: `None` (i.e., "remc_kernel").
      **kwargs: Arguments for `make_kernel_fn`.

    Raises:
      ValueError: `inverse_temperatures` doesn't have statically known 1D shape.
    """
    inverse_temperatures = tf.convert_to_tensor(
        inverse_temperatures, name='inverse_temperatures')

    # Note these are static checks, and don't need to be embedded in the graph.
    inverse_temperatures.shape.assert_is_fully_defined()
    inverse_temperatures.shape.assert_has_rank(1)

    self._seed_stream = SeedStream(seed, salt=name)
    self._seeded_mcmc = seed is not None
    self._parameters = dict(
        target_log_prob_fn=target_log_prob_fn,
        inverse_temperatures=inverse_temperatures,
        num_replica=inverse_temperatures.shape[0].value,
        exchange_proposed_fn=exchange_proposed_fn,
        seed=seed,
        name=name)
    self.replica_kernels = []
    for i in range(self.num_replica):
      self.replica_kernels.append(
          make_kernel_fn(
              target_log_prob_fn=_replica_log_prob_fn(inverse_temperatures[i],
                                                      target_log_prob_fn),
              seed=self._seed_stream()))
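
Each replica kernel targets a tempered version of the target density. A plausible sketch of what a helper like `_replica_log_prob_fn` returns (hedged; the actual TFP helper may differ in detail):

def _replica_log_prob_fn_sketch(inverse_temperature, target_log_prob_fn):
    # Scale the log-density by the replica's inverse temperature.
    def tempered_log_prob(*state_parts):
        return inverse_temperature * target_log_prob_fn(*state_parts)
    return tempered_log_prob

# With inverse_temperatures = [1.0, 0.5, 0.25], replica 0 samples the target
# itself while replicas 1 and 2 sample progressively flatter versions of it.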
Example #5
 def _sample_n(self, n, seed=None):
     seed = SeedStream(seed, "gamma_gamma")
     rate = tf.random_gamma(shape=[n],
                            alpha=self.mixing_concentration,
                            beta=self.mixing_rate,
                            dtype=self.dtype,
                            seed=seed())
     return tf.random_gamma(shape=[],
                            alpha=self.concentration,
                            beta=rate,
                            dtype=self.dtype,
                            seed=seed())
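
The two-stage draw above (a Gamma rate mixed over a Gamma prior) can be written in NumPy, keeping in mind that NumPy's gamma is parameterized by scale = 1 / rate whereas `tf.random_gamma` takes `beta` = rate. Hypothetical scalar parameters:

import numpy as np

rng = np.random.default_rng(0)
mixing_concentration, mixing_rate, concentration = 3.0, 2.0, 5.0

rate = rng.gamma(shape=mixing_concentration, scale=1. / mixing_rate)
sample = rng.gamma(shape=concentration, scale=1. / rate)
print(sample)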
Example #6
 def _sample_3d(self, n, seed=None):
   """Specialized inversion sampler for 3D."""
   seed = SeedStream(seed, salt='von_mises_fisher_3d')
   u_shape = tf.concat([[n], self._batch_shape_tensor()], axis=0)
   z = tf.random_uniform(u_shape, seed=seed(), dtype=self.dtype)
   # TODO(bjp): Higher-order odd dim analytic CDFs are available in [1], could
   # be bisected for bounded sampling runtime (i.e. not rejection sampling).
   # [1]: Inversion sampler via: https://ieeexplore.ieee.org/document/7347705/
   # The inversion is: u = 1 + log(z + (1-z)*exp(-2*kappa)) / kappa
   # We must protect against both kappa and z being zero.
   safe_conc = tf.where(self.concentration > 0,
                        self.concentration,
                        tf.ones_like(self.concentration))
   safe_z = tf.where(z > 0, z, tf.ones_like(z))
   safe_u = 1 + tf.reduce_logsumexp([
       tf.log(safe_z), tf.log1p(-safe_z) - 2 * safe_conc], axis=0) / safe_conc
   # Limit of the above expression as kappa->0 is 2*z-1
   u = tf.where(self.concentration > tf.zeros_like(safe_u), safe_u,
                2 * z - 1)
   # Limit of the expression as z->0 is -1.
   u = tf.where(tf.equal(z, 0), -tf.ones_like(u), u)
   if not self._allow_nan_stats:
     u = tf.check_numerics(u, 'u in _sample_3d')
   return u[..., tf.newaxis]
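
The inversion formula in the comment, without the `tf.where` guards, is straightforward to transcribe. A scalar NumPy sketch including the two limiting cases handled above:

import numpy as np

def vmf3d_inverse_cdf(z, kappa):
    if kappa == 0.:
        return 2. * z - 1.  # limit of the formula as kappa -> 0
    if z == 0.:
        return -1.          # limit of the formula as z -> 0
    return 1. + np.log(z + (1. - z) * np.exp(-2. * kappa)) / kappa

print(vmf3d_inverse_cdf(z=0.3, kappa=4.0))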
Example #7
def random_von_mises(shape, concentration, dtype=tf.float32, seed=None):
    """Samples from the standardized von Mises distribution.

  The distribution is vonMises(loc=0, concentration=concentration), so the mean
  is zero.
  The location can then be changed by adding it to the samples.

  The sampling algorithm is rejection sampling with wrapped Cauchy proposal [1].
  The samples are pathwise differentiable using the approach of [2].

  Arguments:
    shape: The output sample shape.
    concentration: The concentration parameter of the von Mises distribution.
    dtype: The data type of concentration and the outputs.
    seed: (optional) The random seed.

  Returns:
    Differentiable samples of standardized von Mises.

  References:
    [1] Luc Devroye "Non-Uniform Random Variate Generation", Springer-Verlag,
    1986; Chapter 9, p. 473-476.
    http://www.nrbook.com/devroye/Devroye_files/chapter_nine.pdf
    + corrections http://www.nrbook.com/devroye/Devroye_files/errors.pdf
    [2] Michael Figurnov, Shakir Mohamed, Andriy Mnih. "Implicit
    Reparameterization Gradients", 2018.
  """
    seed = SeedStream(seed, salt="von_mises")
    concentration = tf.convert_to_tensor(concentration,
                                         dtype=dtype,
                                         name="concentration")

    @tf.custom_gradient
    def rejection_sample_with_gradient(concentration):
        """Performs rejection sampling for standardized von Mises.

    A nested function is required because @tf.custom_gradient does not handle
    non-tensor inputs such as dtype. Instead, they are captured by the outer
    scope.

    Arguments:
      concentration: The concentration parameter of the distribution.

    Returns:
      Differentiable samples of standardized von Mises.
    """
        r = 1. + tf.sqrt(1. + 4. * concentration**2)
        rho = (r - tf.sqrt(2. * r)) / (2. * concentration)

        s_exact = (1. + rho**2) / (2. * rho)

        # For low concentration, s becomes numerically unstable.
        # To fix that, we use an approximation. Here is the derivation.
        # First-order Taylor expansion at conc = 0 gives
        #   sqrt(1 + 4 concentration^2) ~= 1 + (2 concentration)^2 / 2.
        # Therefore, r ~= 2 + 2 concentration. By plugging this into rho, we have
        #   rho ~= conc + 1 / conc - sqrt(1 + 1 / concentration^2).
        # Let's expand the last term at concentration=0 up to the linear term:
        #   sqrt(1 + 1 / concentration^2) ~= 1 / concentration + concentration / 2
        # Thus, rho ~= concentration / 2. Finally,
        #   s = 1 / (2 rho) + rho / 2 ~= 1 / concentration + concentration / 4.
        # Since concentration is small, we drop the second term and simply use
        #   s ~= 1 / concentration.
        s_approximate = 1. / concentration

        # To compute the cutoff, we compute s_exact using mpmath with 30 decimal
        # digits precision and compare that to the s_exact and s_approximate
        # computed with dtype. Then, the cutoff is the largest concentration for
        # which abs(s_exact - s_exact_mpmath) > abs(s_approximate - s_exact_mpmath).
        s_concentration_cutoff_dict = {
            tf.float16: 1.8e-1,
            tf.float32: 2e-2,
            tf.float64: 1.2e-4,
        }
        s_concentration_cutoff = s_concentration_cutoff_dict[dtype]

        s = tf.where(concentration > s_concentration_cutoff, s_exact,
                     s_approximate)

        def loop_body(should_continue, u, w):
            """Resample the non-accepted points."""
            # We resample u each time completely. Only its sign is used outside the
            # loop, which is random.
            u = tf.random_uniform(shape,
                                  minval=-1.,
                                  maxval=1.,
                                  dtype=dtype,
                                  seed=seed())
            z = tf.cos(np.pi * u)
            # Update the non-accepted points.
            w = tf.where(should_continue, (1. + s * z) / (s + z), w)
            y = concentration * (s - w)

            v = tf.random_uniform(shape,
                                  minval=0.,
                                  maxval=1.,
                                  dtype=dtype,
                                  seed=seed())
            accept = (y * (2. - y) >= v) | (tf.log(y / v) + 1. >= y)
            should_continue = should_continue & (~accept)

            return should_continue, u, w

        _, u, w = tf.while_loop(
            cond=lambda should_continue, *ignore: tf.reduce_any(should_continue
                                                                ),
            body=loop_body,
            loop_vars=(
                tf.ones(shape, dtype=tf.bool),  # should_continue
                tf.zeros(shape, dtype=dtype),  # u
                tf.zeros(shape, dtype=dtype)),  # w
            # The expected number of iterations depends on concentration.
            # It monotonically increases from one iteration for concentration = 0 to
            # sqrt(2 pi / e) ~= 1.52 iterations for concentration = +inf [1].
            # We use a limit of 100 iterations to avoid infinite loops
            # for very large / nan concentration.
            maximum_iterations=100,
        )

        x = tf.sign(u) * tf.math.acos(w)

        def grad(dy):
            """The gradient of the von Mises samples w.r.t. concentration."""
            broadcast_concentration = concentration + tf.zeros_like(x)
            cdf_func = lambda conc: von_mises_cdf(x, conc)
            _, dcdf_dconcentration = _compute_value_and_grad(
                cdf_func, broadcast_concentration)
            inv_prob = tf.exp(-broadcast_concentration * (tf.cos(x) - 1.)) * (
                (2. * np.pi) * tf.math.bessel_i0e(broadcast_concentration))
            # Compute the implicit reparameterization gradient [2],
            # dz/dconc = -(dF(z; conc) / dconc) / p(z; conc)
            ret = dy * (-inv_prob * dcdf_dconcentration)
            # Sum over the sample dimensions. Assume that they are always the first
            # ones.
            num_sample_dimensions = (tf.rank(broadcast_concentration) -
                                     tf.rank(concentration))
            return tf.reduce_sum(ret, axis=tf.range(num_sample_dimensions))

        return x, grad

    return rejection_sample_with_gradient(concentration)
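
Stripping away the custom gradient and the small-concentration approximation for `s`, the forward rejection sampler reduces to a short loop. A scalar NumPy sketch (assumes concentration > 0):

import numpy as np

def sample_von_mises(concentration, rng):
    r = 1. + np.sqrt(1. + 4. * concentration**2)
    rho = (r - np.sqrt(2. * r)) / (2. * concentration)
    s = (1. + rho**2) / (2. * rho)
    while True:
        u = rng.uniform(-1., 1.)
        z = np.cos(np.pi * u)
        w = (1. + s * z) / (s + z)      # wrapped Cauchy proposal
        y = concentration * (s - w)
        v = rng.uniform()
        if y * (2. - y) >= v or np.log(y / v) + 1. >= y:
            return np.sign(u) * np.arccos(w)

print(sample_von_mises(2.0, np.random.default_rng(0)))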
Example #8
  def _sample_n(self, n, seed=None):
    shape = tf.concat([[n], self.batch_shape_tensor()], axis=0)

    has_seed = seed is not None
    seed = SeedStream(seed, salt="zipf")

    minval_u = self._hat_integral(0.5) + 1.
    maxval_u = self._hat_integral(tf.int64.max - 0.5)

    def loop_body(should_continue, k):
      """Resample the non-accepted points."""
      # The range of U is chosen so that the resulting sample K lies in
      # [0, tf.int64.max). The final sample, if accepted, is K + 1.
      u = tf.random.uniform(
          shape,
          minval=minval_u,
          maxval=maxval_u,
          dtype=self.power.dtype,
          seed=seed())

      # Sample the point X from the continuous density h(x) \propto x^(-power).
      x = self._hat_integral_inverse(u)

      # Rejection-inversion requires a `hat` function, h(x) such that
      # \int_{k - .5}^{k + .5} h(x) dx >= pmf(k + 1) for points k in the
      # support. A natural hat function for us is h(x) = x^(-power).
      #
      # After sampling X from h(x), suppose it lies in the interval
      # (K - .5, K + .5) for integer K. Then the corresponding K is accepted if
      # it lies to the left of x_K, where x_K is defined by:
      #   \int_{x_K}^{K + .5} h(x) dx = H(x_K) - H(K + .5) = pmf(K + 1),
      # where H(x) = \int_x^inf h(x) dx.

      # Solving for x_K, we find that x_K = H_inverse(H(K + .5) + pmf(K + 1)).
      # Or, the acceptance condition is X <= H_inverse(H(K + .5) + pmf(K + 1)).
      # Since X = H_inverse(U), this simplifies to U <= H(K + .5) + pmf(K + 1).

      # Update the non-accepted points.
      # Since X \in (K - .5, K + .5), the sample K is chosen as floor(X + 0.5).
      k = tf.where(should_continue, tf.floor(x + 0.5), k)
      accept = (u <= self._hat_integral(k + .5) + tf.exp(self._log_prob(k + 1)))

      return [should_continue & (~accept), k]

    should_continue, samples = tf.while_loop(
        cond=lambda should_continue, *ignore: tf.reduce_any(
            input_tensor=should_continue),
        body=loop_body,
        loop_vars=[
            tf.ones(shape, dtype=tf.bool),  # should_continue
            tf.zeros(shape, dtype=self.power.dtype),  # k
        ],
        parallel_iterations=1 if has_seed else 10,
        maximum_iterations=self.sample_maximum_iterations,
    )
    samples = samples + 1.

    if self.validate_args and dtype_util.is_integer(self.dtype):
      samples = distribution_util.embed_check_integer_casting_closed(
          samples, target_dtype=self.dtype, assert_positive=True)

    samples = tf.cast(samples, self.dtype)

    if self.validate_args:
      npdt = dtype_util.as_numpy_dtype(self.dtype)
      v = npdt(dtype_util.min(npdt) if dtype_util.is_integer(npdt) else np.nan)
      mask = tf.fill(shape, value=v)
      samples = tf.where(should_continue, mask, samples)

    return samples
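
A self-contained scalar version of the rejection-inversion scheme sketched in the comments, for the unnormalized pmf p(k) = k^(-power), k >= 1, with hat h(x) = x^(-power). The two helpers are hypothetical stand-ins for `_hat_integral` / `_hat_integral_inverse` under one particular sign convention; the TFP internals may differ:

import numpy as np

def sample_zipf(power, rng):
    hat_integral = lambda x: x**(1. - power) / (power - 1.)  # H(x) = int_x^inf t^(-power) dt
    hat_integral_inv = lambda u: ((power - 1.) * u)**(1. / (1. - power))
    while True:
        u = rng.uniform(0., hat_integral(0.5))
        x = hat_integral_inv(u)   # X has density proportional to h on [0.5, inf)
        k = np.floor(x + 0.5)     # candidate integer, >= 1
        # Accept iff U <= H(k + .5) + p(k); convexity of h makes the hat dominate p.
        if u <= hat_integral(k + 0.5) + k**(-power):
            return int(k)

print(sample_zipf(2.5, np.random.default_rng(0)))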
Example #9
    def __init__(self,
                 target_log_prob_fn,
                 step_size,
                 max_tree_depth=6,
                 max_energy_diff=1000.,
                 unrolled_leapfrog_steps=1,
                 seed=None,
                 name=None):
        """Initializes this transition kernel.

    Args:
      target_log_prob_fn: Python callable which takes an argument like
        `current_state` (or `*current_state` if it's a list) and returns its
        (possibly unnormalized) log-density under the target distribution.
        Currently only a `target_log_prob_fn` taking a single argument (i.e. the
        state or free parameters of your model) is supported, with the input
        being a 2-D tensor of shape `batch_size * state_part_size`.
      step_size: `Tensor` or Python `list` of `Tensor`s representing the step
        size for the leapfrog integrator. Must broadcast with the shape of
        `current_state`. Larger step sizes lead to faster progress, but
        too-large step sizes make rejection exponentially more likely. When
        possible, it's often helpful to match per-variable step sizes to the
        standard deviations of the target distribution in each variable.
      max_tree_depth: Maximum depth of the tree implicitly built by NUTS. The
        maximum number of leapfrog steps is bounded by `2**max_tree_depth` i.e.
        the number of nodes in a binary tree `max_tree_depth` nodes deep. The
        default setting of 6 takes up to 64 leapfrog steps.
      max_energy_diff: Scalar threshold on the energy difference at each
        leapfrog step; divergent samples are defined as leapfrog steps that
        exceed this threshold. Defaults to 1000.
      unrolled_leapfrog_steps: The number of leapfrogs to unroll per tree
        expansion step. Applies a direct linear multiplier to the maximum
        trajectory length implied by max_tree_depth. Defaults to 1.
      seed: Python integer to seed the random number generator.
      name: Python `str` name prefixed to Ops created by this function.
        Default value: `None` (i.e., 'nuts_kernel').
    """
        with tf.name_scope(name or 'NoUTurnSamplerUnrolled') as name:
            # Process `max_tree_depth` argument.
            max_tree_depth = tf.get_static_value(max_tree_depth)
            if max_tree_depth is None or max_tree_depth < 1:
                raise ValueError(
                    'max_tree_depth must be known statically and >= 1 but was '
                    '{}'.format(max_tree_depth))
            self._max_tree_depth = max_tree_depth

            # Compute parameters derived from `max_tree_depth`.
            instruction_array = build_tree_uturn_instruction(max_tree_depth,
                                                             init_memory=-1)
            [write_instruction, read_instruction
             ] = generate_efficient_write_read_instruction(instruction_array)
            if USE_RAGGED_TENSOR:
                self._write_instruction = tf.constant(write_instruction)
                self._read_instruction = tf.ragged.constant(read_instruction)
            else:
                f = lambda int_iter: write_instruction[int_iter]
                self._write_instruction = {
                    x: functools.partial(f, x)
                    for x in range(len(write_instruction))
                }
                self._read_instruction = read_instruction

            # Process all other arguments.
            self._target_log_prob_fn = target_log_prob_fn
            if not tf.nest.is_nested(step_size):
                step_size = [step_size]
            step_size = [
                tf.convert_to_tensor(s, dtype_hint=tf.float32)
                for s in step_size
            ]
            self._step_size = step_size

            self._parameters = dict(
                target_log_prob_fn=target_log_prob_fn,
                step_size=step_size,
                max_tree_depth=max_tree_depth,
                max_energy_diff=max_energy_diff,
                unrolled_leapfrog_steps=unrolled_leapfrog_steps,
                seed=seed,
                name=name,
            )
            self._seed_stream = SeedStream(seed, salt='nuts_one_step')
            self._unrolled_leapfrog_steps = unrolled_leapfrog_steps
            self._name = name
            self._max_energy_diff = max_energy_diff
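
A quick arithmetic check of the docstring's claim, assuming the bound of `2**max_tree_depth * unrolled_leapfrog_steps` leapfrog steps per draw:

max_tree_depth = 6
unrolled_leapfrog_steps = 1
print(2**max_tree_depth * unrolled_leapfrog_steps)  # 64, i.e. "up to 64 leapfrog steps"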
Example #10
  def _sample_n(self, n, seed=None):
    seed = SeedStream(seed, salt='vom_mises_fisher')
    # The sampling strategy relies on the fact that vMF variates are symmetric
    # about the mean direction. Accordingly, if we have a sampling strategy for
    # the away-from-mean angle, then we can uniformly sample the remaining
    # dimensions on the S^{dim-2} sphere for free, and rotate these samples from a
    # (1, 0, 0, ..., 0)-mode distribution into the target orientation.
    #
    # This is easy to imagine on the 1-sphere (S^1; in 2-D space): sample a
    # von-Mises distributed `x` value in [-1, 1], then uniformly select what
    # amounts to a "up" or "down" additional degree of freedom after unit
    # normalizing, followed by a final rotation to the desired mean direction
    # from a basis of (1, 0).
    #
    # On S^2 (in 3-D), selecting a vMF `x` identifies a circle in `yz` on the
    # unit sphere over which the distribution is uniform, in particular the
    # circle where x = \hat{x} intersects the unit sphere. We pick a point on
    # that circle, then rotate to the desired mean direction from a basis of
    # (1, 0, 0).
    event_dim = self.event_shape[0].value or self._event_shape_tensor()[0]

    sample_batch_shape = tf.concat([[n], self._batch_shape_tensor()], axis=0)
    dim = tf.cast(event_dim - 1, self.dtype)
    if event_dim == 3:
      samples_dim0 = self._sample_3d(n, seed=seed)
    else:
      # Wood'94 provides a rejection algorithm to sample the x coordinate.
      # Wood'94 definition of b:
      # b = (-2 * kappa + tf.sqrt(4 * kappa**2 + dim**2)) / dim
      # https://stats.stackexchange.com/questions/156729 suggests:
      b = dim / (2 * self.concentration +
                 tf.sqrt(4 * self.concentration**2 + dim**2))
      # TODO(bjp): Integrate any useful numerical tricks from hyperspherical VAE
      #     https://github.com/nicola-decao/s-vae-tf/
      x = (1 - b) / (1 + b)
      c = self.concentration * x + dim * tf.log1p(-x**2)
      beta = tf.distributions.Beta(dim / 2, dim / 2)

      def cond_fn(w, should_continue):
        del w
        return tf.reduce_any(should_continue)

      def body_fn(w, should_continue):
        z = beta.sample(sample_shape=sample_batch_shape, seed=seed())
        w = tf.where(should_continue, (1 - (1 + b) * z) / (1 - (1 - b) * z), w)
        w = tf.check_numerics(w, 'w')
        should_continue = tf.logical_and(
            should_continue,
            self.concentration * w + dim * tf.log1p(-x * w) - c <
            tf.log(tf.random_uniform(sample_batch_shape, seed=seed(),
                                     dtype=self.dtype)))
        return w, should_continue

      w = tf.zeros(sample_batch_shape, dtype=self.dtype)
      should_continue = tf.ones(sample_batch_shape, dtype=tf.bool)
      samples_dim0 = tf.while_loop(cond_fn, body_fn, (w, should_continue))[0]
      samples_dim0 = samples_dim0[..., tf.newaxis]
    if not self._allow_nan_stats:
      # Verify samples are w/in -1, 1, with useful error output tensors (top
      # value rather than all values).
      with tf.control_dependencies([
          tf.assert_less_equal(
              samples_dim0, self.dtype.as_numpy_dtype(1.01),
              data=[tf.nn.top_k(tf.reshape(samples_dim0, [-1]))[0]]),
          tf.assert_greater_equal(
              samples_dim0, self.dtype.as_numpy_dtype(-1.01),
              data=[-tf.nn.top_k(tf.reshape(-samples_dim0, [-1]))[0]])]):
        samples_dim0 = tf.identity(samples_dim0)
    samples_otherdims_shape = tf.concat([sample_batch_shape, [event_dim - 1]],
                                        axis=0)
    unit_otherdims = tf.nn.l2_normalize(
        tf.random_normal(samples_otherdims_shape, seed=seed(),
                         dtype=self.dtype),
        axis=-1)
    samples = tf.concat([
        samples_dim0,  # we must avoid sqrt(1 - (>1)**2)
        tf.sqrt(tf.maximum(1 - samples_dim0**2, 0.)) * unit_otherdims
    ], axis=-1)
    samples = tf.nn.l2_normalize(samples, axis=-1)
    if not self._allow_nan_stats:
      samples = tf.check_numerics(samples, 'samples')

    # Runtime assert that samples are unit length.
    if not self._allow_nan_stats:
      worst, idx = tf.nn.top_k(
          tf.reshape(tf.abs(1 - tf.linalg.norm(samples, axis=-1)), [-1]))
      with tf.control_dependencies([
          tf.assert_near(
              self.dtype.as_numpy_dtype(0), worst,
              data=[worst, idx,
                    tf.gather(tf.reshape(samples, [-1, event_dim]), idx)],
              atol=1e-4, summarize=100)]):
        samples = tf.identity(samples)
    # The samples generated are symmetric around a mode at (1, 0, 0, ...., 0).
    # Now, we move the mode to `self.mean_direction` using a rotation matrix.
    if not self._allow_nan_stats:
      # Assert that the basis vector rotates to the mean direction, as expected.
      basis = tf.cast(tf.concat([[1.], tf.zeros([event_dim - 1])], axis=0),
                      self.dtype)
      with tf.control_dependencies([
          tf.assert_less(
              tf.linalg.norm(self._rotate(basis) - self.mean_direction,
                             axis=-1),
              self.dtype.as_numpy_dtype(1e-5))
      ]):
        return self._rotate(samples)
    return self._rotate(samples)
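
The general-dimension branch above follows Wood '94. A standalone scalar NumPy sketch of the rejection step that samples the first coordinate `w` (a hypothetical helper, without the `check_numerics` and assertion plumbing):

import numpy as np

def sample_vmf_first_coord(concentration, event_dim, rng):
    dim = event_dim - 1.
    b = dim / (2. * concentration + np.sqrt(4. * concentration**2 + dim**2))
    x = (1. - b) / (1. + b)
    c = concentration * x + dim * np.log1p(-x**2)
    while True:
        z = rng.beta(dim / 2., dim / 2.)
        w = (1. - (1. + b) * z) / (1. - (1. - b) * z)
        if concentration * w + dim * np.log1p(-x * w) - c >= np.log(rng.uniform()):
            return w

print(sample_vmf_first_coord(concentration=10., event_dim=5, rng=np.random.default_rng(0)))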