def test_passes_insufficient_rank_input_through_to_function(self):
  vectorized_vector_sum = vectorization_util.make_rank_polymorphic(
      lambda a, b: a + b, core_ndims=(1, 1))
  c = vectorized_vector_sum(tf.convert_to_tensor(3.),
                            tf.convert_to_tensor([1., 2., 3.]))
  self.assertAllClose(c, [4., 5., 6.])

  vectorized_matvec = vectorization_util.make_rank_polymorphic(
      tf.linalg.matvec, core_ndims=(2, 1))
  with self.assertRaisesRegexp(ValueError,
                               'Shape must be rank 2 but is rank 1'):
    vectorized_matvec(tf.zeros([5]), tf.zeros([2, 1, 5]))
def __init__(self, *args, **kwargs):
  super(_DefaultJointBijectorAutoBatched, self).__init__(*args, **kwargs)
  self._forward = vectorization_util.make_rank_polymorphic(
      self._forward,
      core_ndims=[self.forward_min_event_ndims])
  self._inverse = vectorization_util.make_rank_polymorphic(
      self._inverse,
      core_ndims=[self.inverse_min_event_ndims])
  self._forward_log_det_jacobian = vectorization_util.make_rank_polymorphic(
      self._forward_log_det_jacobian,
      # `event_ndims` arg is not batched.
      core_ndims=[self.forward_min_event_ndims, None])
  self._inverse_log_det_jacobian = vectorization_util.make_rank_polymorphic(
      self._inverse_log_det_jacobian,
      # `event_ndims` arg is not batched.
      core_ndims=[self.inverse_min_event_ndims, None])
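# Hedged usage sketch (not part of the class above): an entry of `None` in
# `core_ndims` marks an argument that `make_rank_polymorphic` passes through
# to the wrapped function without vectorizing it, which is why the
# `event_ndims` argument above is excluded from batching. The wrapped lambda
# and inputs below are hypothetical illustrations; the import path assumes
# the TFP-internal `vectorization_util` module.
import tensorflow as tf
from tensorflow_probability.python.internal import vectorization_util

scaled_sum = vectorization_util.make_rank_polymorphic(
    lambda x, weight: weight * tf.reduce_sum(x, axis=-1),
    core_ndims=[1, None])
# `x` is vectorized over its leading batch dimension; `weight` is passed
# through unbatched.
result = scaled_sum(tf.ones([4, 3]), 2.)  # ==> [6., 6., 6., 6.]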
def test_aligns_broadcast_dims_using_core_ndims(self, is_static):
  np.random.seed(test_util.test_seed() % 2**32)

  def matvec(a, b):
    # Throws an error if either arg has extra dimensions.
    return tf.linalg.matvec(tf.reshape(a, tf.shape(a)[-2:]),
                            tf.reshape(b, tf.shape(b)[-1:]))
  vectorized_matvec = vectorization_util.make_rank_polymorphic(
      matvec, core_ndims=(self.maybe_static(2, is_static=is_static),
                          self.maybe_static(1, is_static=is_static)))

  for (a_shape, b_shape) in (([3, 2], [2]),
                             ([4, 3, 2], [2]),
                             ([4, 3, 2], [5, 1, 2])):
    a = self.maybe_static(np.random.randn(*a_shape), is_static=is_static)
    b = self.maybe_static(np.random.randn(*b_shape), is_static=is_static)

    c = tf.linalg.matvec(a, b)
    c_vectorized = vectorized_matvec(a, b)
    if is_static:
      self.assertAllEqual(c.shape, c_vectorized.shape)
    self.assertAllEqual(*self.evaluate((c, c_vectorized)))
def _map_measure_over_dists(self, attr, value):
  if any(x is None for x in self._model_flatten(value)):
    raise ValueError('No `value` part can be `None`; saw: {}.'.format(value))
  if value is not None:
    value = self._model_flatten(value)

  def map_measure_fn(value):
    # We always provide a seed, since _flat_sample_distributions will
    # unconditionally split the seed.
    with tf.name_scope('map_measure_fn'):
      constant_seed = samplers.zeros_seed()
      return [getattr(d, attr)(x) for (d, x) in zip(
          *self._flat_sample_distributions(value=value, seed=constant_seed))]
  if self.use_vectorized_map:
    map_measure_fn = vectorization_util.make_rank_polymorphic(
        map_measure_fn,
        core_ndims=[self._single_sample_ndims],
        validate_args=self.validate_args)

  return map_measure_fn(value)
def test_raises_error_on_insufficient_rank_input(self):
  def matvec(a, b):
    # Throws an error if either arg has extra dimensions.
    return tf.linalg.matvec(tf.reshape(a, tf.shape(a)[-2:]),
                            tf.reshape(b, tf.shape(b)[-1:]))
  vectorized_matvec = vectorization_util.make_rank_polymorphic(
      matvec, core_ndims=(2, 1), validate_args=True)

  # Static check.
  with self.assertRaisesRegexp(
      ValueError, 'Cannot broadcast a Tensor having lower rank'):
    vectorized_matvec(tf.zeros([5]), tf.zeros([2, 1, 5]))

  # Runtime check.
  if not tf.executing_eagerly():
    with self.assertRaisesOpError(
        'Condition x >= 0 did not hold element-wise'):
      self.evaluate(
          vectorized_matvec(
              self.maybe_static(tf.zeros([5]), is_static=False),
              self.maybe_static(tf.zeros([2, 1, 5]), is_static=False)))
def _call_execute_model(self,
                        sample_shape,
                        seed,
                        value=None,
                        sample_and_trace_fn=None):
  """Wraps the base `_call_execute_model` with vectorized_map."""
  value_might_have_sample_dims = (
      value is not None and _might_have_excess_ndims(
          # Double-flatten in case any components have structured events.
          flat_value=nest.flatten_up_to(self._single_sample_ndims,
                                        self._model_flatten(value),
                                        check_types=False),
          flat_core_ndims=tf.nest.flatten(self._single_sample_ndims)))
  sample_shape_may_be_nontrivial = (
      distribution_util.shape_may_be_nontrivial(sample_shape))

  if not self.use_vectorized_map or not (
      sample_shape_may_be_nontrivial or value_might_have_sample_dims):
    # No need to auto-vectorize.
    return joint_distribution_lib.JointDistribution._call_execute_model(  # pylint: disable=protected-access
        self,
        sample_shape=sample_shape,
        seed=seed,
        value=value,
        sample_and_trace_fn=sample_and_trace_fn)

  # Set up for autovectorized sampling. To support the `value` arg, we need to
  # first understand which dims are from the model itself, then wrap
  # `_call_execute_model` to batch over all remaining dims.
  value_core_ndims = None
  if value is not None:
    value_core_ndims = tf.nest.map_structure(
        lambda v, nd: None if v is None else nd,
        value, self._model_unflatten(self._single_sample_ndims),
        check_types=False)

  vectorized_execute_model_helper = vectorization_util.make_rank_polymorphic(
      lambda v, seed: (  # pylint: disable=g-long-lambda
          joint_distribution_lib.JointDistribution._call_execute_model(  # pylint: disable=protected-access
              self,
              sample_shape=(),
              seed=seed,
              value=v,
              sample_and_trace_fn=sample_and_trace_fn)),
      core_ndims=[value_core_ndims, None],
      validate_args=self.validate_args)
  # Redefine the polymorphic fn to hack around `make_rank_polymorphic`
  # not currently supporting keyword args. This is needed because the
  # `iid_sample` wrapper below expects to pass through a `seed` kwarg.
  vectorized_execute_model = (
      lambda v, seed: vectorized_execute_model_helper(v, seed))  # pylint: disable=unnecessary-lambda

  if sample_shape_may_be_nontrivial:
    vectorized_execute_model = vectorization_util.iid_sample(
        vectorized_execute_model, sample_shape)

  return vectorized_execute_model(value, seed=seed)
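# Hedged illustration of the `value_core_ndims` construction used above, with
# a hypothetical model structure and values: parts of `value` that are `None`
# keep a core_ndims of `None`, so `make_rank_polymorphic` leaves them
# unbatched, while the remaining parts are vectorized over any extra leading
# (sample) dimensions. The names and shapes below are assumptions for
# illustration only.
import tensorflow as tf

single_sample_ndims = {'loc': 0, 'x': 1}  # Assumed per-part event ndims.
value = {'loc': None, 'x': tf.ones([5, 3])}  # 'x' carries an extra sample dim.
value_core_ndims = tf.nest.map_structure(
    lambda v, nd: None if v is None else nd,
    value, single_sample_ndims, check_types=False)
# ==> {'loc': None, 'x': 1}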
def test_docstring_example(self):
  add = lambda a, b: a + b
  add_vector_to_scalar = vectorization_util.make_rank_polymorphic(
      add, core_ndims=(1, 0))
  self.assertAllEqual(
      [[4., 5.], [5., 6.], [6., 7.]],
      self.evaluate(add_vector_to_scalar(
          tf.constant([1., 2.]), tf.constant([3., 4., 5.]))))
def test_can_escape_vectorization_with_none_ndims(self):
  # Suppose the original fn supports `None` as an input.
  fn = lambda x, y: (tf.reduce_sum(x, axis=0),
                     y[0] if y is not None else y)

  polymorphic_fn = vectorization_util.make_rank_polymorphic(
      fn, core_ndims=[1, None])
  rx, ry = polymorphic_fn([[1., 2., 4.], [3., 5., 7.]], None)
  self.assertAllEqual(rx.shape, [2])
  self.assertIsNone(ry)

  single_arg_polymorphic_fn = vectorization_util.make_rank_polymorphic(
      lambda y: fn(tf.convert_to_tensor([1., 2., 3.]), y), core_ndims=None)
  rx, ry = self.evaluate(single_arg_polymorphic_fn(
      tf.convert_to_tensor([[1., 3.], [2., 4.]])))
  self.assertAllEqual(ry, [1., 3.])
def test_docstring_example_passing_fn_arg(self):
  def apply_binop(fn, a, b):
    return fn(a, b)
  apply_binop_to_vector_and_scalar = vectorization_util.make_rank_polymorphic(
      apply_binop, core_ndims=(None, 1, 0))
  r = self.evaluate(apply_binop_to_vector_and_scalar(
      lambda a, b: a * b,
      tf.constant([1., 2.]),
      tf.constant([3., 4., 5.])))
  self.assertAllEqual(
      r, np.array([[3., 6.], [4., 8.], [5., 10.]], dtype=np.float32))
def test_unit_batch_dims_are_flattened(self):
  # Define `fn` to expect a vector input.
  fn = lambda x: tf.einsum('n->', x)
  # Verify that it won't accept a batch dimension.
  with self.assertRaisesRegexp(Exception, 'rank'):
    fn(tf.zeros([1, 5]))

  polymorphic_fn = vectorization_util.make_rank_polymorphic(
      fn, core_ndims=[1])
  for batch_shape in ([], [1], [1, 1]):
    self.assertEqual(batch_shape,
                     polymorphic_fn(tf.zeros(batch_shape + [5])).shape)
def test_can_call_with_variable_number_of_args(self):
  def scalar_sum(*args):
    return sum([tf.reshape(x, []) for x in args])
  vectorized_sum = vectorization_util.make_rank_polymorphic(
      scalar_sum, core_ndims=0)

  xs = [1.,
        np.array([3., 2.]).astype(np.float32),
        np.array([[1., 2.], [-4., 3.]]).astype(np.float32)]
  self.assertAllEqual(self.evaluate(vectorized_sum(*xs)), sum(xs))
def test_rectifies_distribution_batch_shapes(self):
  def fn(scale):
    d = tfd.Normal(loc=0, scale=[scale])
    x = d.sample()
    return d, x, d.log_prob(x)

  polymorphic_fn = vectorization_util.make_rank_polymorphic(
      fn, core_ndims=(0,))
  batch_scale = tf.constant([[4., 2., 5.], [1., 2., 1.]], dtype=tf.float32)
  d, x, lp = polymorphic_fn(batch_scale)
  self.assertAllEqual(d.batch_shape.as_list(), x.shape.as_list())
  lp2 = d.log_prob(x)
  self.assertAllClose(*self.evaluate((lp, lp2)))
def _sample_n(self, sample_shape, seed, value=None, **kwargs):
  value_might_have_sample_dims = False
  if (value is None) and kwargs:
    value = self._resolve_value_from_kwargs(**kwargs)
  if value is not None:
    value = _pad_value_to_full_length(value, self.dtype)
    value = tf.nest.map_structure(
        lambda v: v if v is None else tf.convert_to_tensor(v), value)
    value_might_have_sample_dims = _might_have_excess_ndims(
        flat_value=self._model_flatten(value),
        flat_core_ndims=self._single_sample_ndims)

  if not self.use_vectorized_map or not (
      _might_have_nonzero_size(sample_shape) or
      value_might_have_sample_dims):
    # No need to auto-vectorize.
    xs = self._call_flat_sample_distributions(
        sample_shape=sample_shape, seed=seed, value=value)[1]
    return self._model_unflatten(xs)

  # Set up for autovectorized sampling. To support the `value` arg, we need to
  # first understand which dims are from the model itself, then wrap
  # `_call_flat_sample_distributions` to batch over all remaining dims.
  value_core_ndims = None
  if value is not None:
    value_core_ndims = tf.nest.map_structure(
        lambda v, nd: None if v is None else nd,
        value, self._model_unflatten(self._single_sample_ndims),
        check_types=False)
  batch_flat_sample = vectorization_util.make_rank_polymorphic(
      lambda v, seed: self._call_flat_sample_distributions(  # pylint: disable=g-long-lambda
          sample_shape=(), seed=seed, value=v)[1],
      core_ndims=[value_core_ndims, None],
      validate_args=self.validate_args)

  # Draw samples.
  vectorized_flat_sample = vectorization_util.iid_sample(
      # Redefine the polymorphic fn to hack around `make_rank_polymorphic`
      # not currently supporting keyword args.
      lambda v, seed: batch_flat_sample(v, seed),  # pylint: disable=unnecessary-lambda
      sample_shape)
  xs = vectorized_flat_sample(value, seed=seed)
  return self._model_unflatten(xs)
def test_can_take_structured_input_and_output(self):
  # Dummy function that takes a (tuple, dict) pair
  # and returns a (dict, scalar) pair.
  def fn(x, y):
    a, b, c = x
    d, e = y['d'], y['e']
    return {'r': a * b + c}, d + e

  vectorized_fn = vectorization_util.make_rank_polymorphic(fn, core_ndims=0)
  x = (np.array([[2.], [3.]]), np.array(2.), np.array([5., 6., 7.]))
  y = {'d': np.array([[1.]]), 'e': np.array([2., 3., 4.])}
  vectorized_result = self.evaluate(vectorized_fn(x, y))
  result = tf.nest.map_structure(lambda a, b: a * np.ones(b.shape),
                                 fn(x, y),
                                 vectorized_result)
  self.assertAllClose(result, vectorized_result)
def test_unit_batch_dims_are_not_vectorized(self):
  if not tf.executing_eagerly():
    self.skipTest('Test relies on eager execution.')

  # Define a fn that raises if it is traced inside `tf.function`, so we can
  # detect unnecessary auto-vectorization.
  def must_run_eagerly(x):
    if not tf.executing_eagerly():
      raise ValueError('Code is running inside tf.function. This may '
                       'indicate that auto-vectorization is being '
                       'triggered unnecessarily.')
    return x

  polymorphic_fn = vectorization_util.make_rank_polymorphic(
      must_run_eagerly, core_ndims=[0])
  for batch_shape in ([], [1], [1, 1]):
    polymorphic_fn(tf.zeros(batch_shape))
def _vectorize_member_fn(self, member_fn, core_ndims):
  # Pinned values must be treated as *inputs* to vectorized bijector
  # members, since the pins can have batch dimensions that coincide
  # with the other values being transformed. For example, given
  #
  #   jd = JointDistributionNamedAutoBatched({
  #       'a': tfd.LogNormal(0., 1.),
  #       'b': lambda a: tfd.Uniform(high=a + tf.ones([3]))})
  #   bij = jd.experimental_default_event_space_bijector()
  #   sampled = jd.sample([2])  # ==> shape {'a': [2], 'b': [2, 3]}
  #
  # then if we pin a sampled value,
  #
  #   pinned_jd = jd.experimental_pin(a=sampled['a'])
  #   pinned_bij = pinned_jd.experimental_default_event_space_bijector()
  #
  # then we'd expect `pinned_bij.forward({'b': sampled['b']})` to return the
  # same value for `b` as `bij.forward(sampled)['b']`, in which each of the
  # two batch elements of `b` is transformed wrt the corresponding batch
  # element of `a`. If we used a naive _DefaultJointBijectorAutoBatched
  # instance for `pinned_bij`, we would instead get a shape error when the
  # pinned value for `a` appears in the model with batch shape `[2]`. The
  # solution is to ensure that the pinned value(s) are passed as input(s) to
  # every bijector method that we autovectorize.
  #
  # The approach implemented here uses a heavy hammer: calling any bijector
  # method rebuilds the pinned JD, creates a support bijector for its
  # unpinned values, and then invokes the requested method on that bijector.
  # (Re)creating all these Python objects incurs overhead in eager mode and
  # during `tf.function` tracing, but has no graph side effects, so repeated
  # execution of the traced function should be efficient.
  def build_and_invoke_pinned_bijector(pins, *args):
    bij = joint_distribution._DefaultJointBijector(  # pylint: disable=protected-access
        self._jd.distribution.experimental_pin(**pins),
        **self._bijector_kwargs)
    return member_fn(bij, *args)
  vectorized_fn_of_pins = vectorization_util.make_rank_polymorphic(
      build_and_invoke_pinned_bijector,
      core_ndims=[self._pins_event_ndims] + core_ndims)
  return lambda *args: vectorized_fn_of_pins(self._jd.pins, *args)
def _map_measure_over_dists(self, attr, value):
  if any(x is None for x in self._model_flatten(value)):
    raise ValueError('No `value` part can be `None`; saw: {}.'.format(value))
  if value is not None:
    value = self._model_flatten(value)

  def map_measure_fn(value):
    return [getattr(d, attr)(x) for (d, x) in zip(
        *self._flat_sample_distributions(value=value))]
  if self.use_vectorized_map:
    map_measure_fn = vectorization_util.make_rank_polymorphic(
        map_measure_fn,
        core_ndims=[self._single_sample_ndims],
        validate_args=self.validate_args)

  return map_measure_fn(value)
def _vectorize_member_fn(self, member_fn, core_ndims):
  return vectorization_util.make_rank_polymorphic(
      lambda x: member_fn(self._joint_bijector, x),
      core_ndims=core_ndims)
  a = (increment[..., idx] + 1.) / 2.
  b = (increment[..., idx] - 1.) / 2.
  chol = tfp_math.cholesky_update(
      chol,
      update_vector=_set_vector_index(increment, idx, a),
      multiplier=1)
  chol = tfp_math.cholesky_update(
      chol,
      update_vector=_set_vector_index(increment, idx, b),
      multiplier=-1)
  # The Cholesky decomposition should be unchanged in rows/cols before idx.
  #
  # TODO(b/229298550): Investigate whether this is really necessary, or if the
  # test failures we see without this line are due to an underlying bug.
  return tf.where((tf.range(chol.shape[-1]) < idx)[..., tf.newaxis],
                  orig_chol,
                  chol)


def _set_vector_index_unbatched(v, idx, x):
  """Mutation-free equivalent of `v[idx] = x`."""
  return tf.tensor_scatter_nd_update(v, indices=[[idx]], updates=[x])


_set_vector_index = vectorization_util.make_rank_polymorphic(
    _set_vector_index_unbatched, core_ndims=[1, 0, 0])


def _half_logdet(chol):
  return tf.reduce_sum(tf.math.log(tf.linalg.diag_part(chol)), axis=-1)