def _sample_3d(self, n, seed=None):
    """Inversion sampler for the mean-direction coordinate when event dim is 3.

    Uniform draws `z` are mapped through the closed-form inverse CDF
    `u = 1 + log(z + (1 - z) * exp(-2 * kappa)) / kappa`; the `kappa == 0`
    and `z == 0` cases are replaced by their analytic limits.

    Args:
      n: Number of samples to draw.
      seed: Seed (or `SeedStream`) from which the uniform draw's seed is
        derived.

    Returns:
      Tensor of shape `[n] + batch_shape + [1]` holding `u`.
    """
    stream = SeedStream(seed, salt='von_mises_fisher_3d')
    draw_shape = tf.concat([[n], self._batch_shape_tensor()], axis=0)
    z = tf.random.uniform(draw_shape, seed=stream(), dtype=self.dtype)
    # TODO(bjp): Higher-order odd dim analytic CDFs are available in [1], could
    # be bisected for bounded sampling runtime (i.e. not rejection sampling).
    # [1]: Inversion sampler via: https://ieeexplore.ieee.org/document/7347705/
    kappa = self.concentration
    # Substitute 1 for zero kappa / zero z so the log and the division below
    # stay finite; the true limits are patched back in afterwards.
    kappa_or_one = tf.where(kappa > 0, kappa, tf.ones_like(kappa))
    z_or_one = tf.where(z > 0, z, tf.ones_like(z))
    # log(z + (1 - z) * exp(-2 * kappa)), evaluated stably as a logsumexp.
    log_mix = tf.reduce_logsumexp(
        [tf.math.log(z_or_one),
         tf.math.log1p(-z_or_one) - 2 * kappa_or_one],
        axis=0)
    inverted = 1 + log_mix / kappa_or_one
    # kappa -> 0 limit of the inversion is 2 * z - 1.
    kappa_is_positive = kappa > tf.zeros_like(inverted)
    u = tf.where(kappa_is_positive, inverted, 2 * z - 1)
    # z -> 0 limit of the inversion is -1.
    u = tf.where(tf.equal(z, 0), -tf.ones_like(u), u)
    if not self._allow_nan_stats:
        u = tf.debugging.check_numerics(u, 'u in _sample_3d')
    return u[..., tf.newaxis]
# Example No. 2
# 0
 def _sample_n(self, n, seed=None):
     """Draw `n` samples by refining a `distribution0` draw `num_steps` times.

     Starting from a draw of `distribution0`, each step replaces the current
     samples with a draw from `distribution_fn(current_samples)`.

     NOTE(review): a single integer seed is derived up front and reused for
     the initial draw and for every step, rather than drawing a fresh seed per
     step. This is presumably deliberate, so each pass replays the same noise.
     Confirm before changing it, since it determines the sampled values.
     """
     shared_seed = SeedStream(seed, salt="Autoregressive")()
     current = self.distribution0.sample(n, seed=shared_seed)
     for _ in range(self._num_steps):
         step_dist = self.distribution_fn(current)  # pylint: disable=not-callable
         current = step_dist.sample(seed=shared_seed)
     return current
# Example No. 3
# 0
    def _sample_n(self, n, seed=None):
        """Sample by inverting the piecewise-quadratic triangular CDF.

        Left of the peak the CDF is `(x - low)**2 / ((high - low) * (peak - low))`.
        Right of it the CDF is `1 - (high - x)**2 / ((high - low) * (high - peak))`.
        At the peak it equals `(peak - low) / (high - low)`, and that value is
        the threshold that picks which branch inverts each uniform draw.
        """
        low = tf.convert_to_tensor(self.low)
        high = tf.convert_to_tensor(self.high)
        peak = tf.convert_to_tensor(self.peak)

        stream = SeedStream(seed, salt='triangular')
        batch_shape = self._batch_shape_tensor(low=low, high=high, peak=peak)
        u = tf.random.uniform(shape=tf.concat([[n], batch_shape], axis=0),
                              dtype=self.dtype,
                              seed=stream())
        width = high - low
        cdf_at_peak = (peak - low) / width
        # Inverse of the left branch (the CDF is quadratic, hence the sqrt).
        left = low + tf.sqrt(u * width * (peak - low))
        # Inverse of the right branch.
        right = high - tf.sqrt((1. - u) * width * (high - peak))
        return tf.where(u < cdf_at_peak, left, right)
# Example No. 4
# 0
 def _sample_n(self, n, seed=None):
     """Sample the inverse Gaussian with the chi-square transformation method.

     A squared standard normal `y` yields the candidate root (with
     `c = concentration`):
       `x = loc + loc**2 * y / (2c) - loc / (2c) * sqrt(4 * loc * c * y + (loc * y)**2)`.
     `x` is kept when a uniform draw is `<= loc / (loc + x)`; otherwise
     `loc**2 / x` is returned.
     See https://en.wikipedia.org/wiki/Inverse_Gaussian_distribution or
     https://www.jstor.org/stable/2683801
     """
     concentration = tf.convert_to_tensor(self.concentration)
     loc = tf.convert_to_tensor(self.loc)
     stream = SeedStream(seed, 'inverse_gaussian')
     batch_shape = self._batch_shape_tensor(loc=loc, concentration=concentration)
     shape = tf.concat([[n], batch_shape], axis=0)
     chi2 = tf.random.normal(shape, mean=0., stddev=1., seed=stream(),
                             dtype=self.dtype)**2.
     coin = tf.random.uniform(shape, minval=0., maxval=1., seed=stream(),
                              dtype=self.dtype)
     two_c = 2. * concentration
     root = (loc + loc**2. * chi2 / two_c - loc / two_c *
             (4. * loc * concentration * chi2 + (loc * chi2)**2)**0.5)
     return tf.where(coin <= loc / (loc + root), root, loc**2 / root)
# Example No. 5
# 0
 def _flat_sample_distributions(self,
                                sample_shape=(),
                                seed=None,
                                value=None):
     """Build the component distributions in order, sampling missing values.

     Each distribution is created by calling its `dist_fn` on the values of
     all earlier components (chain rule of probability). Entries of `value`
     that are not `None` are converted to tensors and used as-is; the
     remaining entries are sampled.

     Args:
       sample_shape: Sample shape used for a component when its recorded
         `_dist_fn_args` entry is empty or when
         `self._always_use_specified_sample_shape` is set; otherwise the
         component is sampled with shape `()`.
       seed: Seed (or `SeedStream`) from which per-component seeds derive.
       value: Optional sequence with one entry per distribution; `None`
         entries are sampled.

     Returns:
       Tuple `(ds, xs)` of the constructed distributions and their values.

     Raises:
       ValueError: If `value` does not have one entry per distribution.
     """
     # This function additionally depends on:
     #   self._dist_fn_wrapped
     #   self._dist_fn_args
     #   self._always_use_specified_sample_shape
     # The user seed is the entropy source and the class name is the salt,
     # matching the `SeedStream(seed, salt=...)` convention used by the other
     # samplers. With the arguments swapped, a `None` user seed would
     # presumably no longer mean "unseeded".
     seed = SeedStream(seed, salt='JointDistributionSequential')
     ds = []
     xs = [None] * len(self._dist_fn_wrapped) if value is None else list(
         value)
     if len(xs) != len(self._dist_fn_wrapped):
         raise ValueError('Number of `xs`s must match number of '
                          'distributions.')
     for i, (dist_fn, args) in enumerate(
             zip(self._dist_fn_wrapped, self._dist_fn_args)):
         ds.append(dist_fn(*xs[:i]))  # Chain rule of probability.
         if xs[i] is None:
             # TODO(b/129364796): We should ignore args prefixed with `_`; this
             # would mean we more often identify when to use `sample_shape=()`
             # rather than `sample_shape=sample_shape`.
             xs[i] = ds[-1].sample(
                 () if args and not self._always_use_specified_sample_shape
                 else sample_shape,
                 seed=seed())
         else:
             xs[i] = tf.convert_to_tensor(xs[i], dtype_hint=ds[-1].dtype)
             # Consume a seed anyway so later components get the same seeds
             # whether or not earlier values were supplied (reproducibility).
             seed()
     # Note: we could also resolve distributions up to the first non-`None` in
     # `self._model_flatten(value)`, however we omit this feature for simplicity,
     # speed, and because it has not yet been requested.
     return ds, xs
# Example No. 6
# 0
 def _sample_n(self, n, seed=None):
     """Draw `n` horseshoe samples as scaled standard-normal noise.

     Each sample is `eps * (scale * lam)`: `lam` is drawn from
     `self._half_cauchy` and `eps` is standard normal. Both have shape
     `[n] + shape(scale)`.
     """
     scale = tf.convert_to_tensor(self.scale)
     draw_shape = tf.concat([[n], tf.shape(scale)], axis=0)
     stream = SeedStream(seed, salt='random_horseshoe')
     local_scale = self._half_cauchy.sample(draw_shape, seed=stream())
     noise = tf.random.normal(shape=draw_shape,
                              mean=0.,
                              stddev=1.,
                              dtype=scale.dtype,
                              seed=stream())
     return noise * (scale * local_scale)
 def _sample_n(self, n, seed=None):
     """Sample the negative binomial as a Gamma-Poisson mixture.

     If `lam ~ Gamma(concentration=total_count, rate=(1 - probs) / probs)`,
     then `X ~ Poisson(lam)` is negative-binomially distributed. In logit
     parameterization that Gamma rate is `exp(-logits)`.
     """
     logits = self._logits_parameter_no_checks()
     stream = SeedStream(seed, salt='NegativeBinomial')
     gamma_rate = tf.exp(-logits)
     lam = tf.random.gamma(shape=[n],
                           alpha=self.total_count,
                           beta=gamma_rate,
                           dtype=self.dtype,
                           seed=stream())
     return tf.random.poisson(lam=lam, shape=[], dtype=self.dtype,
                              seed=stream())
# Example No. 8
# 0
 def _sample_n(self, n, seed=None):
     """Sample Beta variates as `g1 / (g1 + g0)` from two Gamma draws.

     `g1 ~ Gamma(concentration1)` and `g0 ~ Gamma(concentration0)`. Both
     concentrations are broadcast to the full batch shape before drawing.
     """
     stream = SeedStream(seed, "beta")
     concentration1 = tf.convert_to_tensor(self.concentration1)
     concentration0 = tf.convert_to_tensor(self.concentration0)
     batch_shape = self._batch_shape_tensor(concentration1, concentration0)
     alphas = (tf.broadcast_to(concentration1, batch_shape),
               tf.broadcast_to(concentration0, batch_shape))
     # Drawn in order: concentration1 first, then concentration0.
     g1, g0 = [
         tf.random.gamma(shape=[n], alpha=alpha, dtype=self.dtype,
                         seed=stream())
         for alpha in alphas
     ]
     return g1 / (g1 + g0)
# Example No. 9
# 0
 def _sample_n(self, n, seed=None):
     """Sample hierarchically: draw a Gamma rate, then a Gamma given that rate.

     The rate is `Gamma(mixing_concentration, mixing_rate)`, with
     `mixing_concentration` broadcast against `concentration` so that enough
     rates are drawn for the fully broadcast gamma-gamma batch. The result
     is `Gamma(concentration, rate)`.
     """
     concentration = tf.convert_to_tensor(self.concentration)
     mixing_concentration = tf.convert_to_tensor(self.mixing_concentration)
     mixing_rate = tf.convert_to_tensor(self.mixing_rate)
     stream = SeedStream(seed, 'gamma_gamma')
     rate_alpha = mixing_concentration + tf.zeros_like(concentration)
     rate = tf.random.gamma(shape=[n],
                            alpha=rate_alpha,
                            beta=mixing_rate,
                            dtype=self.dtype,
                            seed=stream())
     return tf.random.gamma(shape=[], alpha=concentration, beta=rate,
                            dtype=self.dtype, seed=stream())
# Example No. 10
# 0
  def _sample_n(self, n, seed=None):
    """Draw `n` samples from the Poisson / quadrature-mixture compound.

    For every sample and batch member, a quadrature index is drawn from the
    mixture distribution. The matching Poisson rate is gathered from
    `dist.rate`, and a Poisson variate is drawn at that rate.

    Args:
      n: Number of samples.
      seed: Seed (or `SeedStream`) from which the seeds of both the
        mixture-index draw and the Poisson draw are derived.

    Returns:
      Tensor of shape `[n] + batch_shape`.
    """
    distributions = self.poisson_and_mixture_distributions()
    dist, mixture_dist = distributions
    batch_size = tensorshape_util.num_elements(self.batch_shape)
    if batch_size is None:
      batch_size = tf.reduce_prod(
          self._batch_shape_tensor(distributions=distributions))
    # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[] then get
    # ids as a [n]-shaped vector.
    # We need to 'sample extra' from the mixture distribution if it doesn't
    # already specify a probs vector for each batch coordinate.
    # We only support this kind of reduced broadcasting, i.e., there is exactly
    # one probs vector for all batch dims or one for each.
    stream = SeedStream(seed, salt='PoissonLogNormalQuadratureCompound')
    ids = mixture_dist.sample(
        sample_shape=concat_vectors(
            [n],
            distribution_util.pick_vector(
                mixture_dist.is_scalar_batch(),
                [batch_size],
                np.int32([]))),
        seed=stream())
    # We need to flatten batch dims in case mixture_dist has its own
    # batch dims.
    ids = tf.reshape(
        ids,
        shape=concat_vectors([n],
                             distribution_util.pick_vector(
                                 self.is_scalar_batch(), np.int32([]),
                                 np.int32([-1]))))

    # Stride `quadrature_size` for `batch_size` number of times.
    offset = tf.range(
        start=0,
        limit=batch_size * self._quadrature_size,
        delta=self._quadrature_size,
        dtype=ids.dtype)
    ids = ids + offset
    rate = tf.gather(tf.reshape(dist.rate, shape=[-1]), ids)
    rate = tf.reshape(
        rate, shape=concat_vectors([n], self._batch_shape_tensor(
            distributions=distributions)))
    # Seed the Poisson draw from the same salted stream as the mixture draw.
    # Passing the raw `seed` here reused the caller's seed verbatim, unlike
    # every other draw in this sampler, so the salt did not apply to it.
    return tf.random.poisson(lam=rate, shape=[], dtype=self.dtype,
                             seed=stream())
# Example No. 11
# 0
  def _sample_n(self, n, seed=None):
    """Sample Student's t as `X / sqrt(Z / df) * scale + loc`.

    Uses the identity: if `X ~ Normal(0, 1)` and `Z ~ Chi2(df)` (drawn here
    as `Gamma(df / 2, rate=1/2)`), then `X / sqrt(Z / df) ~ StudentT(df)`.
    `scale` is applied as given; its absolute value is intentionally not
    taken.
    """
    df = tf.convert_to_tensor(self.df)
    loc = tf.convert_to_tensor(self.loc)
    scale = tf.convert_to_tensor(self.scale)
    batch_shape = self._batch_shape_tensor(df=df, loc=loc, scale=scale)
    stream = SeedStream(seed, 'student_t')

    z = tf.random.normal(tf.concat([[n], batch_shape], 0), dtype=self.dtype,
                         seed=stream())
    # Broadcast df to the batch shape so one chi-square is drawn per element.
    df_full = df * tf.ones(batch_shape, dtype=self.dtype)
    chi2 = tf.random.gamma(
        [n], 0.5 * df_full, beta=0.5, dtype=self.dtype, seed=stream())
    standard_t = z * tf.math.rsqrt(chi2 / df_full)
    return standard_t * scale + loc
 def _sample_n(self, n, seed):
     """Sample every component, then keep the one picked by the mixture draw.

     A one-hot mask of the categorical draws selects (by multiply-and-sum)
     the chosen component for each sample and batch member. When
     `self._reparameterize` is set, the result goes through
     `self._reparameterize_sample`.
     """
     with tf.control_dependencies(self._runtime_assertions):
         stream = SeedStream(seed, salt="MixtureSameFamily")
         draws = self.components_distribution.sample(
             n, seed=stream())  # [n, B, k, E]
         # TODO(jvdillon): Consider using tf.gather (by way of index unrolling).
         np_dtype = dtype_util.as_numpy_dtype(draws.dtype)
         chosen = self.mixture_distribution.sample(n, seed=stream())  # [n, B]
         one_hot = tf.one_hot(indices=chosen,
                              depth=self._num_components,  # == k
                              on_value=np_dtype(1),
                              off_value=np_dtype(0))  # [n, B, k]
         one_hot = distribution_utils.pad_mixture_dimensions(
             one_hot, self, self.mixture_distribution,
             self._event_ndims)  # [n, B, k, [1]*e]
         selected = tf.reduce_sum(draws * one_hot,
                                  axis=-1 - self._event_ndims)  # [n, B, E]
         if not self._reparameterize:
             return selected
         return self._reparameterize_sample(selected)
  def _flat_sample_distributions(self, sample_shape=(), seed=None, value=None):
    """Executes `model`, creating both samples and distributions.

    Runs the coroutine returned by `self._model()`. Each yielded distribution
    (optionally wrapped in `Root`) is recorded. The value sent back into the
    coroutine is either the caller-supplied `value[index]` or a fresh sample.
    `Root` distributions are sampled with `sample_shape`; all others are
    sampled with shape `()`.

    Args:
      sample_shape: Sample shape for `Root` distributions.
      seed: Seed (or `SeedStream`) from which per-distribution seeds derive.
      value: Optional sequence of values; non-`None` entries are used instead
        of sampling the corresponding distribution.

    Returns:
      Tuple `(ds, values_out)` of the yielded distributions and their values.

    Raises:
      ValueError: If the first yielded distribution is not wrapped in `Root`.
    """
    ds = []
    values_out = []
    # The user seed is the entropy source and the class name is the salt,
    # matching the `SeedStream(seed, salt=...)` convention used by the other
    # samplers.
    seed = SeedStream(seed, salt='JointDistributionCoroutine')
    gen = self._model()
    index = 0
    d = next(gen)
    if not isinstance(d, self.Root):
      raise ValueError('First distribution yielded by coroutine must '
                       'be wrapped in `Root`.')
    while True:
      actual_distribution = d.distribution if isinstance(d, self.Root) else d
      ds.append(actual_distribution)
      if (value is not None and len(value) > index and
          value[index] is not None):
        # Consume a seed anyway so later distributions get the same seeds
        # whether or not this value was supplied.
        seed()
        next_value = value[index]
      else:
        next_value = actual_distribution.sample(
            sample_shape=sample_shape if isinstance(d, self.Root) else (),
            seed=seed())

      if self._validate_args:
        with tf.control_dependencies(
            self._assert_compatible_shape(
                index, sample_shape, next_value)):
          values_out.append(tf.identity(next_value))
      else:
        values_out.append(next_value)

      index += 1
      # Only the coroutine's own exhaustion ends the loop; a StopIteration
      # raised elsewhere (e.g. inside a `sample` call) is no longer swallowed.
      try:
        d = gen.send(next_value)
      except StopIteration:
        break
    return ds, values_out
  def _sample_n(self, n, seed=None):
    """Sample via Gamma log-weights followed by a multinomial draw.

    Unnormalized logits are the log of Gamma(concentration) draws, one set
    per sample. `concentration` is broadcast against `total_count` first.
    Those logits and the integer-cast `total_count` are then passed to
    `multinomial.draw_sample` (presumably `total_count` trials per draw;
    confirm against that helper).

    Returns:
      Tensor of shape `[n] + batch_shape + [k]`, where `k` is the event size.
    """
    stream = SeedStream(seed, 'dirichlet_multinomial')

    concentration = tf.convert_to_tensor(self._concentration)
    total_count = tf.convert_to_tensor(self._total_count)

    num_trials = tf.cast(total_count, dtype=tf.int32)
    num_classes = self._event_shape_tensor(concentration)[0]
    alpha = tf.math.multiply(
        tf.ones_like(total_count[..., tf.newaxis]),
        concentration,
        name='alpha')

    gamma_draws = tf.random.gamma(
        shape=[n], alpha=alpha, dtype=self.dtype, seed=stream())
    logits = tf.math.log(gamma_draws)
    draws = multinomial.draw_sample(
        1, num_classes, logits, num_trials, self.dtype, stream())
    batch_shape = self._batch_shape_tensor(concentration, total_count)
    return tf.reshape(draws, tf.concat([[n], batch_shape, [num_classes]], 0))
# Example No. 15
# 0
    def _sample_n(self, n, seed):
        """Draw `n` Wishart samples from a random lower-triangular factor.

        The method builds a lower-triangular matrix `A`. Its strictly-lower
        entries are standard normal, and its diagonal holds `sqrt` of
        `Gamma(alpha, beta=0.5)` (chi-square) draws, where `alpha` comes from
        `self._multi_gamma_sequence(0.5 * df, dimension)`. `A` is then
        left-multiplied by `self.scale_operator`. This appears to be the
        Bartlett decomposition.

        Args:
          n: Number of samples.
          seed: Seed (or `SeedStream`) from which the seeds of the normal and
            gamma draws are derived.

        Returns:
          Tensor of shape `[n] + batch_shape + event_shape`. When
          `self.input_output_cholesky` is true, this is the factor
          `scale_operator @ A`; otherwise it is `x @ adjoint(x)` of that
          factor.
        """
        batch_shape = self.batch_shape_tensor()
        event_shape = self.event_shape_tensor()
        batch_ndims = tf.shape(batch_shape)[0]

        ndims = batch_ndims + 3  # sample_ndims=1, event_ndims=2
        shape = tf.concat([[n], batch_shape, event_shape], 0)
        stream = SeedStream(seed, salt="Wishart")

        # Complexity: O(nbk**2)
        x = tf.random.normal(shape=shape,
                             mean=0.,
                             stddev=1.,
                             dtype=self.dtype,
                             seed=stream())

        # Complexity: O(nbk)
        # This parametrization is equivalent to Chi2, i.e.,
        # ChiSquared(k) == Gamma(alpha=k/2, beta=1/2)
        # Broadcast df to the scale operator's batch shape so one gamma
        # sequence is drawn per batch member.
        expanded_df = self.df * tf.ones(
            self.scale_operator.batch_shape_tensor(),
            dtype=dtype_util.base_dtype(self.df.dtype))

        g = tf.random.gamma(shape=[n],
                            alpha=self._multi_gamma_sequence(
                                0.5 * expanded_df, self.dimension),
                            beta=0.5,
                            dtype=self.dtype,
                            seed=stream())

        # Complexity: O(nbk**2)
        x = tf.linalg.band_part(x, -1, 0)  # Tri-lower.

        # Complexity: O(nbk)
        # Overwrite the diagonal with sqrt of the chi-square draws.
        x = tf.linalg.set_diag(x, tf.sqrt(g))

        # Make batch-op ready.
        # Move the sample axis last ([B, k, k, n]) and fold it into the
        # columns ([B, k, k * n]) so one scale_operator.matmul covers all
        # n samples.
        # Complexity: O(nbk**2)
        perm = tf.concat([tf.range(1, ndims), [0]], 0)
        x = tf.transpose(a=x, perm=perm)
        shape = tf.concat(
            [batch_shape, [event_shape[0]], [event_shape[1] * n]], 0)
        x = tf.reshape(x, shape)

        # Complexity: O(nbM) where M is the complexity of the operator solving a
        # vector system. For LinearOperatorLowerTriangular, each matmul is O(k^3) so
        # this step has complexity O(nbk^3).
        x = self.scale_operator.matmul(x)

        # Undo make batch-op ready.
        # Unfold the columns back to [B, k, k, n], then move the sample axis
        # back to the front.
        # Complexity: O(nbk**2)
        shape = tf.concat([batch_shape, event_shape, [n]], 0)
        x = tf.reshape(x, shape)
        perm = tf.concat([[ndims - 1], tf.range(0, ndims - 1)], 0)
        x = tf.transpose(a=x, perm=perm)

        if not self.input_output_cholesky:
            # Complexity: O(nbk**3)
            x = tf.matmul(x, x, adjoint_b=True)

        return x
# Example No. 16
# 0
    def _sample_n(self, n, seed=None):
        """Draw `n` samples from the mixture.

        There are two code paths:
          * `self._use_static_graph`: sample every component `n` times, stack
            the results, and select one per draw with a one-hot mask of the
            categorical draws (as in `MixtureSameFamily`).
          * otherwise: draw categorical ids and partition the sample indices
            by component with `tf.dynamic_partition`. Each component is
            sampled only as many times as draws were assigned to it. The
            needed batch entries are gathered, and the samples are
            reassembled with `tf.dynamic_stitch`.

        In both paths the categorical draw is seeded with `seed` directly, and
        component `c` uses the `c`-th seed of `SeedStream(seed, salt="Mixture")`.

        Args:
          n: Number of samples.
          seed: Seed (or `SeedStream`) for the categorical and component draws.

        Returns:
          Tensor of shape `[n] + batch_shape + event_shape` (assuming the
          categorical samples have shape `[n] + batch_shape`, as the shape
          comments below indicate).
        """
        if self._use_static_graph:
            with tf.control_dependencies(self._assertions):
                # This sampling approach is almost the same as the approach used by
                # `MixtureSameFamily`. The differences are due to having a list of
                # `Distribution` objects rather than a single object, and maintaining
                # random seed management that is consistent with the non-static code
                # path.
                samples = []
                cat_samples = self.cat.sample(n, seed=seed)
                stream = SeedStream(seed, salt="Mixture")

                for c in range(self.num_components):
                    samples.append(self.components[c].sample(n, seed=stream()))
                stack_axis = -1 - tensorshape_util.rank(
                    self._static_event_shape)
                x = tf.stack(samples, axis=stack_axis)  # [n, B, k, E]
                npdt = dtype_util.as_numpy_dtype(x.dtype)
                mask = tf.one_hot(
                    indices=cat_samples,  # [n, B]
                    depth=self._num_components,  # == k
                    on_value=npdt(1),
                    off_value=npdt(0))  # [n, B, k]
                mask = distribution_util.pad_mixture_dimensions(
                    mask, self, self._cat,
                    tensorshape_util.rank(
                        self._static_event_shape))  # [n, B, k, [1]*e]
                return tf.reduce_sum(x * mask, axis=stack_axis)  # [n, B, E]

        with tf.control_dependencies(self._assertions):
            # Use a Python int for `n` when it is statically known.
            n = tf.convert_to_tensor(n, name="n")
            static_n = tf.get_static_value(n)
            n = int(static_n) if static_n is not None else n
            cat_samples = self.cat.sample(n, seed=seed)

            # Prefer static (Python) shapes/sizes; fall back to dynamic tensors.
            static_samples_shape = cat_samples.shape
            if tensorshape_util.is_fully_defined(static_samples_shape):
                samples_shape = tensorshape_util.as_list(static_samples_shape)
                samples_size = tensorshape_util.num_elements(
                    static_samples_shape)
            else:
                samples_shape = tf.shape(cat_samples)
                samples_size = tf.size(cat_samples)
            static_batch_shape = self.batch_shape
            if tensorshape_util.is_fully_defined(static_batch_shape):
                batch_shape = tensorshape_util.as_list(static_batch_shape)
                batch_size = tensorshape_util.num_elements(static_batch_shape)
            else:
                batch_shape = self.batch_shape_tensor()
                batch_size = tf.reduce_prod(batch_shape)
            static_event_shape = self.event_shape
            if tensorshape_util.is_fully_defined(static_event_shape):
                event_shape = np.array(
                    tensorshape_util.as_list(static_event_shape),
                    dtype=np.int32)
            else:
                event_shape = self.event_shape_tensor()

            # Get indices into the raw cat sampling tensor. We will
            # need these to stitch sample values back out after sampling
            # within the component partitions.
            samples_raw_indices = tf.reshape(tf.range(0, samples_size),
                                             samples_shape)

            # Partition the raw indices so that we can use
            # dynamic_stitch later to reconstruct the samples from the
            # known partitions.
            partitioned_samples_indices = tf.dynamic_partition(
                data=samples_raw_indices,
                partitions=cat_samples,
                num_partitions=self.num_components)

            # Copy the batch indices n times, as we will need to know
            # these to pull out the appropriate rows within the
            # component partitions.
            batch_raw_indices = tf.reshape(
                tf.tile(tf.range(0, batch_size), [n]), samples_shape)

            # Explanation of the dynamic partitioning below:
            #   batch indices are i.e., [0, 1, 0, 1, 0, 1]
            # Suppose partitions are:
            #     [1 1 0 0 1 1]
            # After partitioning, batch indices are cut as:
            #     [batch_indices[x] for x in 2, 3]
            #     [batch_indices[x] for x in 0, 1, 4, 5]
            # i.e.
            #     [1 1] and [0 0 0 0]
            # Now we sample n=2 from part 0 and n=4 from part 1.
            # For part 0 we want samples from batch entries 1, 1 (samples 0, 1),
            # and for part 1 we want samples from batch entries 0, 0, 0, 0
            #   (samples 0, 1, 2, 3).
            partitioned_batch_indices = tf.dynamic_partition(
                data=batch_raw_indices,
                partitions=cat_samples,
                num_partitions=self.num_components)
            samples_class = [None for _ in range(self.num_components)]

            stream = SeedStream(seed, salt="Mixture")

            for c in range(self.num_components):
                # Draw only as many samples from component c as were assigned.
                n_class = tf.size(partitioned_samples_indices[c])
                samples_class_c = self.components[c].sample(n_class,
                                                            seed=stream())

                # Pull out the correct batch entries from each index.
                # To do this, we may have to flatten the batch shape.

                # For sample s, batch element b of component c, we get the
                # partitioned batch indices from
                # partitioned_batch_indices[c]; and shift each element by
                # the sample index. The final lookup can be thought of as
                # a matrix gather along locations (s, b) in
                # samples_class_c where the n_class rows correspond to
                # samples within this component and the batch_size columns
                # correspond to batch elements within the component.
                #
                # Thus the lookup index is
                #   lookup[c, i] = batch_size * s[i] + b[c, i]
                # for i = 0 ... n_class[c] - 1.
                lookup_partitioned_batch_indices = (
                    batch_size * tf.range(n_class) +
                    partitioned_batch_indices[c])
                samples_class_c = tf.reshape(
                    samples_class_c,
                    tf.concat([[n_class * batch_size], event_shape], 0))
                samples_class_c = tf.gather(samples_class_c,
                                            lookup_partitioned_batch_indices,
                                            name="samples_class_c_gather")
                samples_class[c] = samples_class_c

            # Stitch back together the samples across the components.
            lhs_flat_ret = tf.dynamic_stitch(
                indices=partitioned_samples_indices, data=samples_class)
            # Reshape back to proper sample, batch, and event shape.
            ret = tf.reshape(
                lhs_flat_ret,
                tf.concat(
                    [samples_shape, self.event_shape_tensor()], 0))
            # Restore the static shape information lost by the dynamic ops.
            tensorshape_util.set_shape(
                ret,
                tensorshape_util.concatenate(static_samples_shape,
                                             self.event_shape))
            return ret
    def _sample_n(self, n, seed=None):
        """Draw `n` samples from the (bi-)mixture of affine-transformed draws.

        The steps are:
          1. Sample `self.distribution` with sample shape
             `[n] + batch_shape + event_shape`.
          2. Push that draw through every affine in `self.endpoint_affine`.
          3. Draw a quadrature index per sample (and batch member) from
             `self.mixture_distribution`.
          4. Gather the matching interpolation weight from `self.grid`.
          5. Return `weight * x[0] + (1 - weight) * x[1]`.

        Both the base draw and the mixture-index draw take seeds from
        `SeedStream(seed, salt="VectorDiffeomixture")`.

        Raises:
          NotImplementedError: If `self.endpoint_affine` does not yield exactly
            two transformed samples. This is checked only after the sampling
            work above.
        """
        stream = SeedStream(seed, salt="VectorDiffeomixture")
        x = self.distribution.sample(sample_shape=concat_vectors(
            [n], self.batch_shape_tensor(), self.event_shape_tensor()),
                                     seed=stream())  # shape: [n, B, e]
        x = [aff.forward(x) for aff in self.endpoint_affine]

        # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[] then get
        # ids as a [n]-shaped vector.
        batch_size = tensorshape_util.num_elements(self.batch_shape)
        if batch_size is None:
            batch_size = tf.reduce_prod(self.batch_shape_tensor())
        mix_batch_size = tensorshape_util.num_elements(
            self.mixture_distribution.batch_shape)
        if mix_batch_size is None:
            mix_batch_size = tf.reduce_prod(
                self.mixture_distribution.batch_shape_tensor())
        # Draw only the ids the mixture's own batch does not already supply:
        # batch_size // mix_batch_size per sample.
        ids = self.mixture_distribution.sample(sample_shape=concat_vectors(
            [n],
            distribution_util.pick_vector(self.is_scalar_batch(), np.int32([]),
                                          [batch_size // mix_batch_size])),
                                               seed=stream())
        # We need to flatten batch dims in case mixture_distribution has its own
        # batch dims.
        ids = tf.reshape(ids,
                         shape=concat_vectors([n],
                                              distribution_util.pick_vector(
                                                  self.is_scalar_batch(),
                                                  np.int32([]),
                                                  np.int32([-1]))))

        # Stride `components * quadrature_size` for `batch_size` number of times.
        stride = tensorshape_util.num_elements(
            tensorshape_util.with_rank_at_least(self.grid.shape, 2)[-2:])
        if stride is None:
            stride = tf.reduce_prod(tf.shape(self.grid)[-2:])
        offset = tf.range(start=0,
                          limit=batch_size * stride,
                          delta=stride,
                          dtype=ids.dtype)

        # Index into the flattened grid: per-batch base offset plus drawn id.
        weight = tf.gather(tf.reshape(self.grid, shape=[-1]), ids + offset)
        # At this point, weight flattened all batch dims into one.
        # We also need to append a singleton to broadcast with event dims.
        if tensorshape_util.is_fully_defined(self.batch_shape):
            new_shape = [-1] + tensorshape_util.as_list(self.batch_shape) + [1]
        else:
            new_shape = tf.concat(([-1], self.batch_shape_tensor(), [1]),
                                  axis=0)
        weight = tf.reshape(weight, shape=new_shape)

        if len(x) != 2:
            # We actually should have already triggered this exception. However as a
            # policy we're putting this exception wherever we exploit the bimixture
            # assumption.
            raise NotImplementedError(
                "Currently only bimixtures are supported; "
                "len(scale)={} is not 2.".format(len(x)))

        # Alternatively:
        # x = weight * x[0] + (1. - weight) * x[1]
        x = weight * (x[0] - x[1]) + x[1]

        return x
# Example No. 18
# 0
    def _sample_n(self, n, seed=None):
        seed = SeedStream(seed, salt='vom_mises_fisher')
        # The sampling strategy relies on the fact that vMF variates are symmetric
        # about the mean direction. Accordingly, if we have a sampling strategy for
        # the away-from-mean angle, then we can uniformly sample the remaining
        # dimensions on the S^{dim-2} sphere for , and rotate these samples from a
        # (1, 0, 0, ..., 0)-mode distribution into the target orientation.
        #
        # This is easy to imagine on the 1-sphere (S^1; in 2-D space): sample a
        # von-Mises distributed `x` value in [-1, 1], then uniformly select what
        # amounts to a "up" or "down" additional degree of freedom after unit
        # normalizing, followed by a final rotation to the desired mean direction
        # from a basis of (1, 0).
        #
        # On S^2 (in 3-D), selecting a vMF `x` identifies a circle in `yz` on the
        # unit sphere over which the distribution is uniform, in particular the
        # circle where x = \hat{x} intersects the unit sphere. We pick a point on
        # that circle, then rotate to the desired mean direction from a basis of
        # (1, 0, 0).
        event_dim = (tf.compat.dimension_value(self.event_shape[0])
                     or self._event_shape_tensor()[0])

        sample_batch_shape = tf.concat([[n], self._batch_shape_tensor()],
                                       axis=0)
        dim = tf.cast(event_dim - 1, self.dtype)
        if event_dim == 3:
            samples_dim0 = self._sample_3d(n, seed=seed)
        else:
            # Wood'94 provides a rejection algorithm to sample the x coordinate.
            # Wood'94 definition of b:
            # b = (-2 * kappa + tf.sqrt(4 * kappa**2 + dim**2)) / dim
            # https://stats.stackexchange.com/questions/156729 suggests:
            b = dim / (2 * self.concentration +
                       tf.sqrt(4 * self.concentration**2 + dim**2))
            # TODO(bjp): Integrate any useful numerical tricks from hyperspherical VAE
            #     https://github.com/nicola-decao/s-vae-tf/
            x = (1 - b) / (1 + b)
            c = self.concentration * x + dim * tf.math.log1p(-x**2)
            beta = beta_lib.Beta(dim / 2, dim / 2)

            def cond_fn(w, should_continue):
                del w
                return tf.reduce_any(should_continue)

            def body_fn(w, should_continue):
                z = beta.sample(sample_shape=sample_batch_shape, seed=seed())
                # set_shape needed here because of b/139013403
                z.set_shape(w.shape)
                w = tf.where(should_continue,
                             (1 - (1 + b) * z) / (1 - (1 - b) * z), w)
                w = tf.debugging.check_numerics(w, 'w')
                unif = tf.random.uniform(sample_batch_shape,
                                         seed=seed(),
                                         dtype=self.dtype)
                # set_shape needed here because of b/139013403
                unif.set_shape(w.shape)
                should_continue = tf.logical_and(
                    should_continue,
                    self.concentration * w + dim * tf.math.log1p(-x * w) - c <
                    tf.math.log(unif))
                return w, should_continue

            w = tf.zeros(sample_batch_shape, dtype=self.dtype)
            should_continue = tf.ones(sample_batch_shape, dtype=tf.bool)
            samples_dim0 = tf.while_loop(cond=cond_fn,
                                         body=body_fn,
                                         loop_vars=(w, should_continue))[0]
            samples_dim0 = samples_dim0[..., tf.newaxis]
        if not self._allow_nan_stats:
            # Verify samples are w/in -1, 1, with useful error output tensors (top
            # value rather than all values).
            with tf.control_dependencies([
                    assert_util.assert_less_equal(
                        samples_dim0,
                        dtype_util.as_numpy_dtype(self.dtype)(1.01),
                        data=[
                            tf.math.top_k(tf.reshape(samples_dim0, [-1]))[0]
                        ]),
                    assert_util.assert_greater_equal(
                        samples_dim0,
                        dtype_util.as_numpy_dtype(self.dtype)(-1.01),
                        data=[
                            -tf.math.top_k(tf.reshape(-samples_dim0, [-1]))[0]
                        ])
            ]):
                samples_dim0 = tf.identity(samples_dim0)
        samples_otherdims_shape = tf.concat(
            [sample_batch_shape, [event_dim - 1]], axis=0)
        unit_otherdims = tf.math.l2_normalize(tf.random.normal(
            samples_otherdims_shape, seed=seed(), dtype=self.dtype),
                                              axis=-1)
        samples = tf.concat(
            [
                samples_dim0,  # we must avoid sqrt(1 - (>1)**2)
                tf.sqrt(tf.maximum(1 - samples_dim0**2, 0.)) * unit_otherdims
            ],
            axis=-1)
        samples = tf.math.l2_normalize(samples, axis=-1)
        if not self._allow_nan_stats:
            samples = tf.debugging.check_numerics(samples, 'samples')

        # Runtime assert that samples are unit length.
        if not self._allow_nan_stats:
            worst, idx = tf.math.top_k(
                tf.reshape(tf.abs(1 - tf.linalg.norm(samples, axis=-1)), [-1]))
            with tf.control_dependencies([
                    assert_util.assert_near(
                        dtype_util.as_numpy_dtype(self.dtype)(0),
                        worst,
                        data=[
                            worst, idx,
                            tf.gather(tf.reshape(samples, [-1, event_dim]),
                                      idx)
                        ],
                        atol=1e-4,
                        summarize=100)
            ]):
                samples = tf.identity(samples)
        # The samples generated are symmetric around a mode at (1, 0, 0, ...., 0).
        # Now, we move the mode to `self.mean_direction` using a rotation matrix.
        if not self._allow_nan_stats:
            # Assert that the basis vector rotates to the mean direction, as expected.
            basis = tf.cast(
                tf.concat([[1.], tf.zeros([event_dim - 1])], axis=0),
                self.dtype)
            with tf.control_dependencies([
                    assert_util.assert_less(
                        tf.linalg.norm(self._rotate(basis) -
                                       self.mean_direction,
                                       axis=-1),
                        dtype_util.as_numpy_dtype(self.dtype)(1e-5))
            ]):
                return self._rotate(samples)
        return self._rotate(samples)
    def _sample_n(self, n, seed=None):
        """Draws `n` joint samples of the observation sequence.

        Hidden states are sampled first: an initial state, then
        `self._num_steps - 1` transitions via `tf.scan`. Observations are then
        drawn at every step for all `num_states` candidate states, and the one
        matching the sampled hidden state is selected with a one-hot mask.

        Args:
          n: Number of samples to draw.
          seed: Optional seed used to build the `SeedStream`. When not `None`,
            the transition scan runs with `parallel_iterations=1` for
            reproducibility.

        Returns:
          Observations laid out as `n batch_shape steps inner_shape`, where
          `inner_shape` is the observation distribution's `event_shape`.
        """
        with tf.control_dependencies(self._runtime_assertions):
            # NOTE(review): `strm()` is called in a fixed order (initial
            # states, transitions, observations); presumably reordering these
            # calls changes the draws obtained for a given seed.
            strm = SeedStream(seed, salt="HiddenMarkovModel")

            num_states = self._num_states

            batch_shape = self.batch_shape_tensor()
            batch_size = tf.reduce_prod(batch_shape)

            # The batch sizes of the underlying initial distributions and
            # transition distributions might not match the batch size of
            # the HMM distribution.
            # As a result we need to ask for more samples from the
            # underlying distributions and then reshape the results into
            # the correct batch size for the HMM.
            init_repeat = (
                tf.reduce_prod(self.batch_shape_tensor()) // tf.reduce_prod(
                    self._initial_distribution.batch_shape_tensor()))
            init_state = self._initial_distribution.sample(n * init_repeat,
                                                           seed=strm())
            init_state = tf.reshape(init_state, [n, batch_size])
            # init_state :: n batch_size

            # The last batch dimension of the transition distribution is the
            # source-state axis (it is excluded here and reshaped to
            # `num_states` inside `generate_step`).
            transition_repeat = (
                tf.reduce_prod(self.batch_shape_tensor()) // tf.reduce_prod(
                    self._transition_distribution.batch_shape_tensor()[:-1]))

            def generate_step(state, _):
                """Take a single step in Markov chain.

                A successor is drawn for every possible source state. The
                one-hot mask of the current `state` then keeps only the
                successor row belonging to that state. The second argument is
                the unused scan element.
                """

                gen = self._transition_distribution.sample(n *
                                                           transition_repeat,
                                                           seed=strm())
                # gen :: (n * transition_repeat) transition_batch

                new_states = tf.reshape(gen, [n, batch_size, num_states])

                # new_states :: n batch_size num_states

                # NOTE(review): the mask is int32, so this multiply assumes the
                # transition samples are int32 as well -- TODO confirm for
                # non-default transition distributions.
                old_states_one_hot = tf.one_hot(state,
                                                num_states,
                                                dtype=tf.int32)

                # old_states :: n batch_size num_states

                return tf.reduce_sum(old_states_one_hot * new_states, axis=-1)

            def _scan_multiple_steps():
                """Take multiple steps with tf.scan."""
                # Only the length (num_steps - 1) matters: it sets the number
                # of scan iterations; generate_step ignores the element values.
                dummy_index = tf.zeros(self._num_steps - 1, dtype=tf.float32)
                if seed is not None:
                    # Force parallel_iterations to 1 to ensure reproducibility
                    # b/139210489
                    hidden_states = tf.scan(generate_step,
                                            dummy_index,
                                            initializer=init_state,
                                            parallel_iterations=1)
                else:
                    # Invoke default parallel_iterations behavior
                    hidden_states = tf.scan(generate_step,
                                            dummy_index,
                                            initializer=init_state)

                # tf.scan does not emit the initializer, so prepend it.
                # TODO(b/115618503): add/use prepend_initializer to tf.scan
                return tf.concat([[init_state], hidden_states], axis=0)

            # With a single step there are no transitions to take.
            hidden_states = prefer_static.cond(
                self._num_steps > 1, _scan_multiple_steps,
                lambda: init_state[tf.newaxis, ...])

            # The mask uses the observation dtype so it can multiply the
            # observation samples below.
            hidden_one_hot = tf.one_hot(
                hidden_states,
                num_states,
                dtype=self._observation_distribution.dtype)
            # hidden_one_hot :: num_steps n batch_size num_states

            # The observation distribution batch size might not match
            # the required batch size so as with the initial and
            # transition distributions we generate more samples and
            # reshape.
            observation_repeat = (batch_size // tf.reduce_prod(
                self._observation_distribution.batch_shape_tensor()[:-1]))

            possible_observations = self._observation_distribution.sample(
                [self._num_steps, observation_repeat * n], seed=strm())

            inner_shape = self._observation_distribution.event_shape

            # possible_observations :: num_steps (observation_repeat * n)
            #                          observation_batch[:-1] num_states inner_shape

            possible_observations = tf.reshape(
                possible_observations,
                tf.concat([[self._num_steps, n], batch_shape, [num_states],
                           inner_shape],
                          axis=0))

            # possible_observations :: steps n batch_size num_states inner_shape

            # Trailing size-1 dims let the mask broadcast over the event dims.
            hidden_one_hot = tf.reshape(
                hidden_one_hot,
                tf.concat([[self._num_steps, n], batch_shape, [num_states],
                           tf.ones_like(inner_shape)],
                          axis=0))

            # hidden_one_hot :: steps n batch_size num_states "inner_shape"

            # Sum out the num_states axis, which sits just before the event
            # dims; this keeps only the observation of the sampled state.
            observations = tf.reduce_sum(hidden_one_hot *
                                         possible_observations,
                                         axis=-1 - tf.size(inner_shape))

            # observations :: steps n batch_size inner_shape

            # Move the steps axis from the front to just after the batch dims.
            observations = distribution_util.move_dimension(
                observations, 0, 1 + tf.size(batch_shape))

            # returned :: n batch_shape steps inner_shape

            return observations
Exemplo n.º 20
0
    def _sample_n(self, n, seed=None):
        """Draws `n` Zipf samples per batch member by rejection-inversion.

        Each candidate comes from a uniform draw pushed through the inverse
        hat integral and rounded to the nearest integer. It is accepted when
        the uniform draw does not exceed the hat integral at the right edge of
        its rounding cell plus the pmf at that point. Rejected entries are
        redrawn until all are accepted or `self.sample_maximum_iterations`
        passes have run.

        Args:
          n: Number of samples to draw.
          seed: Optional seed. When given, the rejection loop runs with
            `parallel_iterations=1`.

        Returns:
          Tensor of shape `[n] + shape(power)` cast to `self.dtype`. With
          `validate_args`, entries still unaccepted when the loop stops are
          replaced by NaN (floating dtypes) or the dtype minimum (integer
          dtypes).
        """
        power = tf.convert_to_tensor(self.power)
        out_shape = tf.concat([[n], tf.shape(power)], axis=0)

        # Must be read before `seed` is rebound to a SeedStream below.
        deterministic = seed is not None
        seed = SeedStream(seed, salt='zipf')

        # Uniform bounds chosen so the rounded candidate K stays in
        # [0, tf.int64.max); an accepted K becomes the sample K + 1.
        u_low = self._hat_integral(0.5, power=power) + 1.
        u_high = self._hat_integral(tf.int64.max - 0.5, power=power)

        def _any_pending(pending, _):
            return tf.reduce_any(pending)

        def _redraw(pending, candidate):
            """Runs one rejection-inversion pass over the pending entries."""
            u = tf.random.uniform(out_shape,
                                  minval=u_low,
                                  maxval=u_high,
                                  dtype=power.dtype,
                                  seed=seed())
            # Continuous draw X from the hat density h(x) \propto x^(-power).
            x = self._hat_integral_inverse(u, power=power)
            # Only pending entries take a fresh candidate: the integer K with
            # X in (K - .5, K + .5), i.e. floor(X + 0.5).
            candidate = tf.where(pending, tf.floor(x + 0.5), candidate)
            # X lies left of x_K, where H(x_K) - H(K + .5) = pmf(K + 1), iff
            # U <= H(K + .5) + pmf(K + 1), since X = H_inverse(U).
            threshold = (self._hat_integral(candidate + .5, power=power) +
                         tf.exp(self._log_prob(candidate + 1, power=power)))
            accepted = u <= threshold
            return [pending & (~accepted), candidate]

        pending, draws = tf.while_loop(
            cond=_any_pending,
            body=_redraw,
            loop_vars=[
                tf.ones(out_shape, dtype=tf.bool),
                tf.zeros(out_shape, dtype=power.dtype),
            ],
            parallel_iterations=1 if deterministic else 10,
            maximum_iterations=self.sample_maximum_iterations,
        )
        draws = draws + 1.

        if self.validate_args and dtype_util.is_integer(self.dtype):
            draws = distribution_util.embed_check_integer_casting_closed(
                draws, target_dtype=self.dtype, assert_positive=True)

        draws = tf.cast(draws, self.dtype)

        if not self.validate_args:
            return draws

        # Flag entries that never got accepted within the iteration budget.
        np_dtype = dtype_util.as_numpy_dtype(self.dtype)
        if dtype_util.is_integer(np_dtype):
            fill = np_dtype(dtype_util.min(np_dtype))
        else:
            fill = np_dtype(np.nan)
        return tf.where(pending, fill, draws)
Exemplo n.º 21
0
    def _sample_n(self, num_samples, seed=None, name=None):
        """Returns a Tensor of samples from an LKJ distribution.

        The construction follows reference [1], but it grows the Cholesky
        factor of each sample one row at a time instead of the correlation
        matrix itself.

        Args:
          num_samples: Python `int`. The number of samples to draw.
          seed: Python integer seed for RNG
          name: Python `str` name prefixed to Ops created by this function.
            Defaults to `'sample_lkj'` when not given.

        Returns:
          samples: A Tensor of correlation matrices with shape `[n, B, D, D]`,
            where `B` is the shape of the `concentration` parameter, and `D`
            is the `dimension`. If `self.input_output_cholesky` is set, the
            lower-triangular Cholesky factors of those matrices are returned
            instead.

        Raises:
          ValueError: If `dimension` is negative.
          TypeError: If `concentration` does not have a floating dtype.
        """
        if self.dimension < 0:
            raise ValueError(
                'Cannot sample negative-dimension correlation matrices.')
        # Notation below: B is the batch shape, i.e., tf.shape(concentration)
        seed = SeedStream(seed, 'sample_lkj')
        # Honour a caller-supplied `name`; fall back to the default scope only
        # when none is given.
        with tf.name_scope(name or 'sample_lkj'):
            concentration = tf.convert_to_tensor(self.concentration)
            if not dtype_util.is_floating(concentration.dtype):
                raise TypeError(
                    'The concentration argument should have floating type, not '
                    '{}'.format(dtype_util.name(concentration.dtype)))

            # NOTE(review): presumably prepends a `num_samples` axis so that B
            # below includes the sample dimension -- confirm in `_replicate`.
            concentration = _replicate(num_samples, concentration)
            concentration_shape = tf.shape(concentration)
            if self.dimension <= 1:
                # For any dimension <= 1, there is only one possible correlation
                # matrix, and it is also its own Cholesky factor.
                shape = tf.concat(
                    [concentration_shape, [self.dimension, self.dimension]],
                    axis=0)
                return tf.ones(shape=shape, dtype=concentration.dtype)
            beta_conc = concentration + (self.dimension - 2.) / 2.
            beta_dist = beta.Beta(concentration1=beta_conc,
                                  concentration0=beta_conc)

            # Note that the sampler below deviates from [1], by doing the sampling in
            # cholesky space. This does not change the fundamental logic of the
            # sampler, but does speed up the sampling.

            # This is the correlation coefficient between the first two dimensions.
            # This is also `r` in reference [1].
            corr12 = 2. * beta_dist.sample(seed=seed()) - 1.

            # Below we construct the Cholesky of the initial 2x2 correlation matrix,
            # which is of the form:
            # [[1, 0], [r, sqrt(1 - r**2)]], where r is the correlation between the
            # first two dimensions.
            # This is the top-left corner of the cholesky of the final sample.
            first_row = tf.concat([
                tf.ones_like(corr12)[..., tf.newaxis],
                tf.zeros_like(corr12)[..., tf.newaxis]
            ],
                                  axis=-1)
            second_row = tf.concat([
                corr12[..., tf.newaxis],
                tf.sqrt(1 - corr12**2)[..., tf.newaxis]
            ],
                                   axis=-1)

            chol_result = tf.concat([
                first_row[..., tf.newaxis, :], second_row[..., tf.newaxis, :]
            ],
                                    axis=-2)

            for n in range(2, self.dimension):
                # Loop invariant: on entry, result has shape B + [n, n]
                beta_conc = beta_conc - 0.5
                # norm is y in reference [1].
                norm = beta.Beta(concentration1=n / 2.,
                                 concentration0=beta_conc).sample(seed=seed())
                # distance shape: B + [1] for broadcast
                distance = tf.sqrt(norm)[..., tf.newaxis]
                # direction is u in reference [1].
                # direction shape: B + [n]
                direction = _uniform_unit_norm(n, concentration_shape,
                                               concentration.dtype, seed)
                # raw_correlation is w in reference [1].
                raw_correlation = distance * direction  # shape: B + [n]

                # This is the next row in the cholesky of the result,
                # which differs from the construction in reference [1].
                # In the reference, the new row `z` = chol_result @ raw_correlation^T
                # = C @ raw_correlation^T (where as short hand we use C = chol_result).
                # We prove that the below equation is the right row to add to the
                # cholesky, by showing equality with reference [1].
                # Let S be the sample constructed so far, and let `z` be as in
                # reference [1]. Then at this iteration, the new sample S' will be
                # [[S z^T]
                #  [z 1]]
                # In our case we have the cholesky decomposition factor C, so
                # we want our new row x (same size as z) to satisfy:
                #  [[S z^T]  [[C 0]    [[C^T  x^T]         [[CC^T  Cx^T]
                #   [z 1]] =  [x k]]    [0     k]]  =       [xC^t   xx^T + k**2]]
                # Since C @ raw_correlation^T = z = C @ x^T, and C is invertible,
                # we have that x = raw_correlation. Also 1 = xx^T + k**2, so k
                # = sqrt(1 - xx^T) = sqrt(1 - |raw_correlation|**2) = sqrt(1 -
                # distance**2).
                new_row = tf.concat(
                    [raw_correlation,
                     tf.sqrt(1. - norm[..., tf.newaxis])],
                    axis=-1)

                # Finally add this new row, by growing the cholesky of the result:
                # first a zero column on the right (keeping it lower-triangular),
                # then the new row at the bottom.
                chol_result = tf.concat([
                    chol_result,
                    tf.zeros_like(chol_result[..., 0][..., tf.newaxis])
                ],
                                        axis=-1)

                chol_result = tf.concat(
                    [chol_result, new_row[..., tf.newaxis, :]], axis=-2)

            if self.input_output_cholesky:
                return chol_result

            result = tf.matmul(chol_result, chol_result, transpose_b=True)
            # The diagonal for a correlation matrix should always be ones. Due to
            # numerical instability the matmul might not achieve that, so manually set
            # these to ones.
            result = tf.linalg.set_diag(
                result, tf.ones(shape=tf.shape(result)[:-1],
                                dtype=result.dtype))
            # This sampling algorithm can produce near-PSD matrices on which standard
            # algorithms such as `tf.cholesky` or `tf.linalg.self_adjoint_eigvals`
            # fail. Specifically, as documented in b/116828694, around 2% of trials
            # of 900,000 5x5 matrices (distributed according to 9 different
            # concentration parameter values) contained at least one matrix on which
            # the Cholesky decomposition failed.
            return result