def test_iid_sample_stateful(self):
        """`iid_sample` lifts a structured, stateful-seed sampler.

        A random function returning `[tensor, {key1: tensor, key2: tensor}]`,
        where every component is drawn with a stateful sampler, is wrapped
        with `vectorization_util.iid_sample`. The test checks that the batched
        Normal draws are not repeated and that every output component gains
        `sample_shape` as leading dimensions.
        """

        def structured_sampler(key1, key2, seed=None):
            # Each component uses a different integer offset of the same seed.
            normal_draw = tfd.Normal(0., 1.).sample([3, 2], seed=seed)
            poisson_draw = tfd.Poisson([1., 2., 3., 4.]).sample(seed=seed + 1)
            lognormal_draw = tfd.LogNormal(0., 1.).sample(seed=seed + 2)
            return [normal_draw, {key1: poisson_draw, key2: lognormal_draw}]

        stateful_seed = test_util.test_seed(sampler_type='stateful')
        single_draw = self.evaluate(
            structured_sampler('a', key2='b', seed=stateful_seed))

        sample_shape = [6, 1]
        batched_sampler = vectorization_util.iid_sample(
            structured_sampler, sample_shape=sample_shape)
        batched_draw = self.evaluate(batched_sampler('a', key2='b', seed=42))

        # Every flattened Normal value must differ from the first one, i.e.
        # the batch must not consist of repeated draws.
        normal_values = batched_draw[0].flatten()
        squared_gaps = (normal_values[1:] - normal_values[0])**2
        self.assertAllGreater(squared_gaps, 1e-6)

        def with_sample_dims(x):
            return np.concatenate([sample_shape, x.shape], axis=0)

        expected_shapes = tf.nest.map_structure(with_sample_dims, single_draw)
        actual_shapes = tf.nest.map_structure(lambda x: x.shape, batched_draw)
        self.assertAllEqualNested(expected_shapes, actual_shapes)
    def _call_execute_model(self,
                            sample_shape,
                            seed,
                            value=None,
                            sample_and_trace_fn=None):
        """Wraps the base `_call_execute_model` with vectorized_map.

        If `self.use_vectorized_map` is falsy, or if neither `sample_shape`
        nor `value` may carry dimensions beyond the model's single-sample
        ndims, this defers directly to
        `JointDistribution._call_execute_model`.

        Otherwise the base implementation is run with `sample_shape=()` and
        made rank-polymorphic (via `make_rank_polymorphic`) over any extra
        dims of `value`. When `sample_shape` may be nontrivial, it is further
        wrapped with `iid_sample`.

        Args:
          sample_shape: Shape of the samples to draw.
          seed: Seed forwarded to the base implementation (through
            `iid_sample` when that wrapper is used).
          value: Optional structure of values passed to the base
            implementation; `None` leaves get `None` core ndims.
          sample_and_trace_fn: Forwarded unchanged to the base implementation.

        Returns:
          The result of the base `_call_execute_model`, vectorized over any
          extra dimensions.
        """
        base_execute_model = (
            joint_distribution_lib.JointDistribution._call_execute_model)  # pylint: disable=protected-access

        value_might_have_sample_dims = False
        if value is not None:
            # Double-flatten in case any components have structured events.
            flat_value = nest.flatten_up_to(self._single_sample_ndims,
                                            self._model_flatten(value),
                                            check_types=False)
            value_might_have_sample_dims = _might_have_excess_ndims(
                flat_value=flat_value,
                flat_core_ndims=tf.nest.flatten(self._single_sample_ndims))

        sample_shape_may_be_nontrivial = (
            distribution_util.shape_may_be_nontrivial(sample_shape))

        needs_vectorization = self.use_vectorized_map and (
            sample_shape_may_be_nontrivial or value_might_have_sample_dims)
        if not needs_vectorization:
            # No need to auto-vectorize.
            return base_execute_model(
                self,
                sample_shape=sample_shape,
                seed=seed,
                value=value,
                sample_and_trace_fn=sample_and_trace_fn)

        # Autovectorized path: record which dims of `value` belong to the
        # model itself, so that `make_rank_polymorphic` batches over all the
        # remaining ones.
        if value is None:
            value_core_ndims = None
        else:
            value_core_ndims = tf.nest.map_structure(
                lambda v, nd: None if v is None else nd,
                value,
                self._model_unflatten(self._single_sample_ndims),
                check_types=False)

        def _execute_single_sample(v, seed):
            # The base implementation, invoked for a single sample.
            return base_execute_model(
                self,
                sample_shape=(),
                seed=seed,
                value=v,
                sample_and_trace_fn=sample_and_trace_fn)

        rank_polymorphic_execute_model = (
            vectorization_util.make_rank_polymorphic(
                _execute_single_sample,
                core_ndims=[value_core_ndims, None],
                validate_args=self.validate_args))

        # `make_rank_polymorphic` does not currently support keyword args,
        # but the `iid_sample` wrapper below passes `seed` as a keyword, so
        # re-expose the polymorphic fn through a plain positional signature.
        def vectorized_execute_model(v, seed):
            return rank_polymorphic_execute_model(v, seed)

        if sample_shape_may_be_nontrivial:
            vectorized_execute_model = vectorization_util.iid_sample(
                vectorized_execute_model, sample_shape)

        return vectorized_execute_model(value, seed=seed)
# Example #3
# 0
    def _sample_n(self, sample_shape, seed, value=None, **kwargs):
        """Draws flat samples and unflattens them into the model structure.

        If `value` is not given but keyword arguments are, `value` is
        resolved from those kwargs. When `self.use_vectorized_map` is falsy,
        or neither `sample_shape` nor `value` may need extra dims, sampling
        goes straight through `_call_flat_sample_distributions`. Otherwise a
        single-sample call is made rank-polymorphic over the extra dims of
        `value` and wrapped with `iid_sample` for `sample_shape`.

        Args:
          sample_shape: Shape of the samples to draw.
          seed: Seed passed to the sampler (through `iid_sample` on the
            vectorized path).
          value: Optional structure of values; it is padded to full length
            and its non-`None` leaves are converted to tensors.
          **kwargs: Used to resolve `value` when `value` is `None`.

        Returns:
          The sampled values in the model's (unflattened) structure.
        """
        if value is None and kwargs:
            value = self._resolve_value_from_kwargs(**kwargs)

        value_might_have_sample_dims = False
        if value is not None:
            value = _pad_value_to_full_length(value, self.dtype)
            value = tf.nest.map_structure(
                lambda v: None if v is None else tf.convert_to_tensor(v),
                value)
            value_might_have_sample_dims = _might_have_excess_ndims(
                flat_value=self._model_flatten(value),
                flat_core_ndims=self._single_sample_ndims)

        wants_vectorization = self.use_vectorized_map and (
            _might_have_nonzero_size(sample_shape)
            or value_might_have_sample_dims)
        if not wants_vectorization:
            # No need to auto-vectorize.
            flat_draws = self._call_flat_sample_distributions(
                sample_shape=sample_shape, seed=seed, value=value)[1]
            return self._model_unflatten(flat_draws)

        # Autovectorized path: record which dims of `value` belong to the
        # model itself, so `make_rank_polymorphic` batches over the rest.
        value_core_ndims = None if value is None else tf.nest.map_structure(
            lambda v, nd: None if v is None else nd,
            value,
            self._model_unflatten(self._single_sample_ndims),
            check_types=False)

        def _flat_sample_one(v, seed):
            # A single (sample_shape=()) flat draw.
            return self._call_flat_sample_distributions(
                sample_shape=(), seed=seed, value=v)[1]

        batch_flat_sample = vectorization_util.make_rank_polymorphic(
            _flat_sample_one,
            core_ndims=[value_core_ndims, None],
            validate_args=self.validate_args)

        # `make_rank_polymorphic` does not currently support keyword args,
        # while `iid_sample` passes `seed` as a keyword; forward positionally.
        def _positional_flat_sample(v, seed):
            return batch_flat_sample(v, seed)

        # Draw samples.
        vectorized_flat_sample = vectorization_util.iid_sample(
            _positional_flat_sample, sample_shape)
        return self._model_unflatten(vectorized_flat_sample(value, seed=seed))
    def test_iid_sample_stateless(self):
        """`iid_sample` over a stateless sampler warns and yields distinct draws.

        Wraps `tf.random.stateless_normal` with `iid_sample` and checks that
        a warning containing 'may be quite slow' is emitted, and that the six
        resulting draws are not repeats of the first.
        """
        sample_shape = [6]
        iid_fn = vectorization_util.iid_sample(tf.random.stateless_normal,
                                               sample_shape=sample_shape)

        with warnings.catch_warnings(record=True) as triggered:
            # Set the filter *inside* `catch_warnings` so it is restored when
            # the block exits. Calling it beforehand would permanently change
            # the process-wide filters and leak into every later test.
            # 'always' ensures the warning is recorded even if it was already
            # issued earlier from the same location.
            warnings.simplefilter('always')
            samples = iid_fn(
                [], seed=test_util.test_seed(sampler_type='stateless'))
        self.assertTrue(
            any('may be quite slow' in str(warning.message)
                for warning in triggered))

        # Check that we did not get repeated samples.
        samples_ = self.evaluate(samples)
        self.assertAllGreater((samples_[1:] - samples_[0])**2, 1e-6)