def test_iid_sample_stateful(self):
  # Random fn using stateful samplers.
  def fn(key1, key2, seed=None):
    return [
        tfd.Normal(0., 1.).sample([3, 2], seed=seed), {
            key1: tfd.Poisson([1., 2., 3., 4.]).sample(seed=seed + 1),
            key2: tfd.LogNormal(0., 1.).sample(seed=seed + 2)
        }
    ]

  sample = self.evaluate(
      fn('a', key2='b', seed=test_util.test_seed(sampler_type='stateful')))

  sample_shape = [6, 1]
  iid_fn = vectorization_util.iid_sample(fn, sample_shape=sample_shape)
  iid_sample = self.evaluate(iid_fn('a', key2='b', seed=42))

  # Check that we did not get repeated samples.
  first_sampled_vector = iid_sample[0].flatten()
  self.assertAllGreater(
      (first_sampled_vector[1:] - first_sampled_vector[0])**2, 1e-6)

  expected_iid_shapes = tf.nest.map_structure(
      lambda x: np.concatenate([sample_shape, x.shape], axis=0), sample)
  iid_shapes = tf.nest.map_structure(lambda x: x.shape, iid_sample)
  self.assertAllEqualNested(expected_iid_shapes, iid_shapes)
def _call_execute_model(self,
                        sample_shape,
                        seed,
                        value=None,
                        sample_and_trace_fn=None):
  """Wraps the base `_call_execute_model` with vectorized_map."""
  value_might_have_sample_dims = (
      value is not None and _might_have_excess_ndims(
          # Double-flatten in case any components have structured events.
          flat_value=nest.flatten_up_to(self._single_sample_ndims,
                                        self._model_flatten(value),
                                        check_types=False),
          flat_core_ndims=tf.nest.flatten(self._single_sample_ndims)))
  sample_shape_may_be_nontrivial = (
      distribution_util.shape_may_be_nontrivial(sample_shape))

  if not self.use_vectorized_map or not (sample_shape_may_be_nontrivial or  # pylint: disable=protected-access
                                         value_might_have_sample_dims):
    # No need to auto-vectorize.
    return joint_distribution_lib.JointDistribution._call_execute_model(  # pylint: disable=protected-access
        self,
        sample_shape=sample_shape,
        seed=seed,
        value=value,
        sample_and_trace_fn=sample_and_trace_fn)

  # Set up for autovectorized sampling. To support the `value` arg, we need to
  # first understand which dims are from the model itself, then wrap
  # `_call_execute_model` to batch over all remaining dims.
  value_core_ndims = None
  if value is not None:
    value_core_ndims = tf.nest.map_structure(
        lambda v, nd: None if v is None else nd,
        value,
        self._model_unflatten(self._single_sample_ndims),
        check_types=False)

  vectorized_execute_model_helper = vectorization_util.make_rank_polymorphic(
      lambda v, seed: (  # pylint: disable=g-long-lambda
          joint_distribution_lib.JointDistribution._call_execute_model(  # pylint: disable=protected-access
              self,
              sample_shape=(),
              seed=seed,
              value=v,
              sample_and_trace_fn=sample_and_trace_fn)),
      core_ndims=[value_core_ndims, None],
      validate_args=self.validate_args)
  # Redefine the polymorphic fn to hack around `make_rank_polymorphic`
  # not currently supporting keyword args. This is needed because the
  # `iid_sample` wrapper below expects to pass through a `seed` kwarg.
  vectorized_execute_model = (
      lambda v, seed: vectorized_execute_model_helper(v, seed))  # pylint: disable=unnecessary-lambda

  if sample_shape_may_be_nontrivial:
    vectorized_execute_model = vectorization_util.iid_sample(
        vectorized_execute_model, sample_shape)

  return vectorized_execute_model(value, seed=seed)
def _sample_n(self, sample_shape, seed, value=None, **kwargs):
  value_might_have_sample_dims = False
  if (value is None) and kwargs:
    value = self._resolve_value_from_kwargs(**kwargs)
  if value is not None:
    value = _pad_value_to_full_length(value, self.dtype)
    value = tf.nest.map_structure(
        lambda v: v if v is None else tf.convert_to_tensor(v), value)
    value_might_have_sample_dims = _might_have_excess_ndims(
        flat_value=self._model_flatten(value),
        flat_core_ndims=self._single_sample_ndims)

  if not self.use_vectorized_map or not (
      _might_have_nonzero_size(sample_shape) or value_might_have_sample_dims):
    # No need to auto-vectorize.
    xs = self._call_flat_sample_distributions(
        sample_shape=sample_shape, seed=seed, value=value)[1]
    return self._model_unflatten(xs)

  # Set up for autovectorized sampling. To support the `value` arg, we need to
  # first understand which dims are from the model itself, then wrap
  # `_call_flat_sample_distributions` to batch over all remaining dims.
  value_core_ndims = None
  if value is not None:
    value_core_ndims = tf.nest.map_structure(
        lambda v, nd: None if v is None else nd,
        value,
        self._model_unflatten(self._single_sample_ndims),
        check_types=False)
  batch_flat_sample = vectorization_util.make_rank_polymorphic(
      lambda v, seed: self._call_flat_sample_distributions(  # pylint: disable=g-long-lambda
          sample_shape=(), seed=seed, value=v)[1],
      core_ndims=[value_core_ndims, None],
      validate_args=self.validate_args)

  # Draw samples.
  vectorized_flat_sample = vectorization_util.iid_sample(
      # Redefine the polymorphic fn to hack around `make_rank_polymorphic`
      # not currently supporting keyword args.
      lambda v, seed: batch_flat_sample(v, seed), sample_shape)  # pylint: disable=unnecessary-lambda
  xs = vectorized_flat_sample(value, seed=seed)
  return self._model_unflatten(xs)
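# --- Illustrative sketch (not library code) ---------------------------------
# The two methods above rely on `vectorization_util.make_rank_polymorphic`,
# which takes a function written for inputs of fixed "core" rank and maps it
# over any excess leading batch dimensions (a `None` core rank, as used for
# `seed` above, leaves that argument un-vectorized). A minimal sketch of that
# behavior follows, assuming the internal module import shown; `dot_scaled`
# and the shapes are made up for illustration.
import tensorflow as tf
from tensorflow_probability.python.internal import vectorization_util


def dot_scaled(x, y):
  # Written for a single rank-1 `x` and a scalar `y`.
  return y * tf.reduce_sum(x, axis=-1)


vectorized_dot_scaled = vectorization_util.make_rank_polymorphic(
    dot_scaled, core_ndims=[1, 0])

# A batch of four vectors and four scalars vectorizes to a result of shape
# [4]; with all-ones vectors of length 3 this should give [3., 6., 9., 12.].
batched_result = vectorized_dot_scaled(
    tf.ones([4, 3]), tf.constant([1., 2., 3., 4.]))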
def test_iid_sample_stateless(self):
  sample_shape = [6]
  iid_fn = vectorization_util.iid_sample(tf.random.stateless_normal,
                                         sample_shape=sample_shape)

  warnings.simplefilter('always')
  with warnings.catch_warnings(record=True) as triggered:
    samples = iid_fn(
        [], seed=test_util.test_seed(sampler_type='stateless'))
    self.assertTrue(
        any('may be quite slow' in str(warning.message)
            for warning in triggered))

  # Check that we did not get repeated samples.
  samples_ = self.evaluate(samples)
  self.assertAllGreater((samples_[1:] - samples_[0])**2, 1e-6)
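# --- Illustrative sketch (not library code) ---------------------------------
# End-to-end, the sampling path above composes the two utilities: wrap the
# single-sample computation with `make_rank_polymorphic` so it batches over
# excess dims of `value`, re-wrap it in a plain lambda because the polymorphic
# fn does not accept keyword args, and then apply `iid_sample` to add leading
# `sample_shape` dimensions. A minimal sketch under those assumptions;
# `noisy_value` and all shapes/seeds here are made up for illustration.
import tensorflow as tf
from tensorflow_probability.python.internal import vectorization_util


def noisy_value(v, seed):
  # Written for a scalar `v`; adds standard-normal noise via a stateless seed.
  return v + tf.random.stateless_normal([], seed=seed)


batched_noisy_value = vectorization_util.make_rank_polymorphic(
    noisy_value, core_ndims=[0, None])

# `iid_sample` passes the seed through as a keyword arg, hence the re-wrap.
iid_noisy_value = vectorization_util.iid_sample(
    lambda v, seed: batched_noisy_value(v, seed), sample_shape=[5])

# A rank-1 `v` has one excess dim over its core rank of 0, so with
# sample_shape [5] the draws should have shape [5, 3]. (The stateless seed
# takes the slower looped path that emits the warning tested above.)
draws = iid_noisy_value(
    tf.zeros([3]), seed=tf.constant([4, 8], dtype=tf.int32))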