コード例 #1
0
    def test_passes_insufficient_rank_input_through_to_function(self):
        """Under-rank inputs are passed through to the wrapped function.

        When `validate_args` is not set, an argument with fewer dims than its
        declared core ndims reaches `fn` as-is, so `fn` decides whether to
        broadcast it (as `+` does) or reject it (as `tf.linalg.matvec` does).
        """
        # A scalar passed where a vector is declared broadcasts through `+`.
        vectorized_vector_sum = vectorization_util.make_rank_polymorphic(
            lambda a, b: a + b, core_ndims=(1, 1))
        c = vectorized_vector_sum(tf.convert_to_tensor(3.),
                                  tf.convert_to_tensor([1., 2., 3.]))
        self.assertAllClose(c, [4., 5., 6.])

        # A vector passed where a matrix is declared reaches `matvec`, which
        # raises its own rank error.
        vectorized_matvec = vectorization_util.make_rank_polymorphic(
            tf.linalg.matvec, core_ndims=(2, 1))
        # `assertRaisesRegexp` is a deprecated alias removed in Python 3.12.
        with self.assertRaisesRegex(ValueError,
                                    'Shape must be rank 2 but is rank 1'):
            vectorized_matvec(tf.zeros([5]), tf.zeros([2, 1, 5]))
コード例 #2
0
 def __init__(self, *args, **kwargs):
     """Builds the joint bijector, then wraps its methods for auto-batching.

     `_forward` and `_inverse` become rank-polymorphic in their single
     argument, with the min event ndims of the matching direction as core
     ndims. The log-det-Jacobian methods also take `event_ndims`, which is
     left unbatched (core ndims `None`).
     """
     super(_DefaultJointBijectorAutoBatched, self).__init__(*args, **kwargs)
     polymorphic = vectorization_util.make_rank_polymorphic
     forward_ndims = self.forward_min_event_ndims
     inverse_ndims = self.inverse_min_event_ndims

     self._forward = polymorphic(self._forward, core_ndims=[forward_ndims])
     self._inverse = polymorphic(self._inverse, core_ndims=[inverse_ndims])
     # The second argument, `event_ndims`, is not batched.
     self._forward_log_det_jacobian = polymorphic(
         self._forward_log_det_jacobian, core_ndims=[forward_ndims, None])
     self._inverse_log_det_jacobian = polymorphic(
         self._inverse_log_det_jacobian, core_ndims=[inverse_ndims, None])
コード例 #3
0
    def tests_aligns_broadcast_dims_using_core_ndims(self, is_static):
        """Batch dims are aligned by core ndims, matching `tf.linalg.matvec`."""
        np.random.seed(test_util.test_seed() % 2**32)

        def matvec(a, b):
            # Throws an error if either arg has extra dimensions.
            core_a = tf.reshape(a, tf.shape(a)[-2:])
            core_b = tf.reshape(b, tf.shape(b)[-1:])
            return tf.linalg.matvec(core_a, core_b)

        vectorized_matvec = vectorization_util.make_rank_polymorphic(
            matvec,
            core_ndims=(self.maybe_static(2, is_static=is_static),
                        self.maybe_static(1, is_static=is_static)))

        shape_pairs = (
            ([3, 2], [2]),
            ([4, 3, 2], [2]),
            ([4, 3, 2], [5, 1, 2]),
        )
        for a_shape, b_shape in shape_pairs:
            a = self.maybe_static(np.random.randn(*a_shape),
                                  is_static=is_static)
            b = self.maybe_static(np.random.randn(*b_shape),
                                  is_static=is_static)

            expected = tf.linalg.matvec(a, b)
            actual = vectorized_matvec(a, b)
            if is_static:
                self.assertAllEqual(expected.shape, actual.shape)
            self.assertAllEqual(*self.evaluate((expected, actual)))
コード例 #4
0
    def _map_measure_over_dists(self, attr, value):
        """Calls `getattr(d, attr)(x)` for each component distribution `d`.

        Args:
          attr: name of the method to invoke on each component distribution.
          value: structured value for the model; no flattened part may be
            `None`.

        Returns:
          A list with one result per component distribution, in the order
          produced by `_flat_sample_distributions`.

        Raises:
          ValueError: if any flattened part of `value` is `None`.
        """
        # Flatten once and reuse the result for both the `None` check and the
        # value handed to the component distributions.
        flat_value = self._model_flatten(value)
        if any(x is None for x in flat_value):
            raise ValueError(
                'No `value` part can be `None`; saw: {}.'.format(value))
        if value is not None:
            value = flat_value

        def map_measure_fn(value):
            # We always provide a seed, since _flat_sample_distributions will
            # unconditionally split the seed.
            with tf.name_scope('map_measure_fn'):
                constant_seed = samplers.zeros_seed()
                return [
                    getattr(d, attr)(x)
                    for (d, x) in zip(*self._flat_sample_distributions(
                        value=value, seed=constant_seed))
                ]

        if self.use_vectorized_map:
            # Batch over any dims of `value` beyond a single sample's ndims.
            map_measure_fn = vectorization_util.make_rank_polymorphic(
                map_measure_fn,
                core_ndims=[self._single_sample_ndims],
                validate_args=self.validate_args)

        return map_measure_fn(value)
コード例 #5
0
    def test_raises_error_on_insufficient_rank_input(self):
        """`validate_args=True` rejects inputs with fewer dims than core ndims."""
        def matvec(a, b):
            # Throws an error if either arg has extra dimensions.
            return tf.linalg.matvec(tf.reshape(a,
                                               tf.shape(a)[-2:]),
                                    tf.reshape(b,
                                               tf.shape(b)[-1:]))

        vectorized_matvec = vectorization_util.make_rank_polymorphic(
            matvec, core_ndims=(2, 1), validate_args=True)

        # Static check. (`assertRaisesRegexp` is a deprecated alias that was
        # removed in Python 3.12.)
        with self.assertRaisesRegex(
                ValueError, 'Cannot broadcast a Tensor having lower rank'):
            vectorized_matvec(tf.zeros([5]), tf.zeros([2, 1, 5]))

        # Runtime check: with shapes hidden, the failure only surfaces when
        # the graph is evaluated, so this is exercised in graph mode only.
        if not tf.executing_eagerly():
            with self.assertRaisesOpError(
                    'Condition x >= 0 did not hold element-wise'):
                self.evaluate(
                    vectorized_matvec(
                        self.maybe_static(tf.zeros([5]), is_static=False),
                        self.maybe_static(tf.zeros([2, 1, 5]),
                                          is_static=False)))
コード例 #6
0
    def _call_execute_model(self,
                            sample_shape,
                            seed,
                            value=None,
                            sample_and_trace_fn=None):
        """Wraps the base `_call_execute_model` with vectorized_map.

        Defers directly to `JointDistribution._call_execute_model` when
        vectorization is disabled, or when neither `sample_shape` nor `value`
        may add sample dimensions. Otherwise the base method runs one sample at
        a time, is made rank-polymorphic over the extra dims of `value`, and is
        drawn iid via `vectorization_util.iid_sample` when `sample_shape` may
        be nontrivial.
        """
        base_execute_model = (
            joint_distribution_lib.JointDistribution._call_execute_model)  # pylint: disable=protected-access

        if value is None:
            value_might_have_sample_dims = False
        else:
            # Double-flatten in case any components have structured events.
            flat_value = nest.flatten_up_to(self._single_sample_ndims,
                                            self._model_flatten(value),
                                            check_types=False)
            value_might_have_sample_dims = _might_have_excess_ndims(
                flat_value=flat_value,
                flat_core_ndims=tf.nest.flatten(self._single_sample_ndims))
        sample_shape_may_be_nontrivial = (
            distribution_util.shape_may_be_nontrivial(sample_shape))

        needs_vectorization = self.use_vectorized_map and (
            sample_shape_may_be_nontrivial or value_might_have_sample_dims)
        if not needs_vectorization:
            # No need to auto-vectorize.
            return base_execute_model(
                self,
                sample_shape=sample_shape,
                seed=seed,
                value=value,
                sample_and_trace_fn=sample_and_trace_fn)

        # To support the `value` arg, first work out which dims belong to the
        # model itself; the base method is then batched over all other dims.
        if value is None:
            value_core_ndims = None
        else:
            value_core_ndims = tf.nest.map_structure(
                lambda v, nd: None if v is None else nd,
                value,
                self._model_unflatten(self._single_sample_ndims),
                check_types=False)

        def execute_single_sample(v, seed):
            return base_execute_model(
                self,
                sample_shape=(),
                seed=seed,
                value=v,
                sample_and_trace_fn=sample_and_trace_fn)

        polymorphic_execute = vectorization_util.make_rank_polymorphic(
            execute_single_sample,
            core_ndims=[value_core_ndims, None],
            validate_args=self.validate_args)

        # `make_rank_polymorphic` does not currently support keyword args, yet
        # the `iid_sample` wrapper below passes `seed` by keyword; this adapter
        # forwards it positionally.
        def vectorized_execute_model(v, seed):
            return polymorphic_execute(v, seed)

        if sample_shape_may_be_nontrivial:
            vectorized_execute_model = vectorization_util.iid_sample(
                vectorized_execute_model, sample_shape)

        return vectorized_execute_model(value, seed=seed)
コード例 #7
0
 def test_docstring_example(self):
     """A vector core (ndims 1) plus a scalar core (ndims 0) broadcast."""
     vector_plus_scalar = vectorization_util.make_rank_polymorphic(
         lambda a, b: a + b, core_ndims=(1, 0))
     vector = tf.constant([1., 2.])
     scalars = tf.constant([3., 4., 5.])
     expected = [[4., 5.], [5., 6.], [6., 7.]]
     self.assertAllEqual(expected,
                         self.evaluate(vector_plus_scalar(vector, scalars)))
コード例 #8
0
    def test_can_escape_vectorization_with_none_ndims(self):
        """Arguments with `None` core ndims bypass vectorization."""

        # Suppose the original fn supports `None` as an input.
        def fn(x, y):
            x_sum = tf.reduce_sum(x, axis=0)
            return x_sum, (None if y is None else y[0])

        polymorphic_fn = vectorization_util.make_rank_polymorphic(
            fn, core_ndims=[1, None])
        rx, ry = polymorphic_fn([[1., 2., 4.], [3., 5., 7.]], None)
        self.assertAllEqual(rx.shape, [2])
        self.assertIsNone(ry)

        # With `core_ndims=None` the whole tensor reaches `fn`, so `y[0]` is
        # the first row rather than a per-row element.
        single_arg_polymorphic_fn = vectorization_util.make_rank_polymorphic(
            lambda y: fn(tf.convert_to_tensor([1., 2., 3.]), y),
            core_ndims=None)
        _, ry = self.evaluate(
            single_arg_polymorphic_fn(
                tf.convert_to_tensor([[1., 3.], [2., 4.]])))
        self.assertAllEqual(ry, [1., 3.])
コード例 #9
0
    def test_docstring_example_passing_fn_arg(self):
        """A `None` core ndims lets a Python function be passed as an arg."""

        def apply_binop(fn, a, b):
            return fn(a, b)

        def multiply(a, b):
            return a * b

        polymorphic_apply = vectorization_util.make_rank_polymorphic(
            apply_binop, core_ndims=(None, 1, 0))
        result = self.evaluate(
            polymorphic_apply(multiply,
                              tf.constant([1., 2.]),
                              tf.constant([3., 4., 5.])))
        expected = np.array([[3., 6.], [4., 8.], [5., 10.]], dtype=np.float32)
        self.assertAllEqual(result, expected)
コード例 #10
0
    def test_unit_batch_dims_are_flattened(self):
        """Size-1 batch dims work even when `fn` only accepts a vector."""
        # Define `fn` to expect a vector input.
        fn = lambda x: tf.einsum('n->', x)
        # Verify that it won't accept a batch dimension. (`assertRaisesRegexp`
        # is a deprecated alias that was removed in Python 3.12.)
        with self.assertRaisesRegex(Exception, 'rank'):
            fn(tf.zeros([1, 5]))

        polymorphic_fn = vectorization_util.make_rank_polymorphic(
            fn, core_ndims=[1])
        for batch_shape in ([], [1], [1, 1]):
            self.assertEqual(batch_shape,
                             polymorphic_fn(tf.zeros(batch_shape + [5])).shape)
コード例 #11
0
    def test_can_call_with_variable_number_of_args(self):
        """A `*args` function is vectorized using a single integer core ndims."""
        def scalar_sum(*args):
            return sum(tf.reshape(x, []) for x in args)

        vectorized_sum = vectorization_util.make_rank_polymorphic(
            scalar_sum, core_ndims=0)

        scalar = 1.
        vector = np.array([3., 2.]).astype(np.float32)
        matrix = np.array([[1., 2.], [-4., 3.]]).astype(np.float32)
        xs = [scalar, vector, matrix]
        self.assertAllEqual(self.evaluate(vectorized_sum(*xs)), sum(xs))
コード例 #12
0
    def test_rectifies_distribution_batch_shapes(self):
        """Returned distributions get batch shapes matching the vectorized samples."""
        def fn(scale):
            d = tfd.Normal(loc=0, scale=[scale])
            x = d.sample()
            return d, x, d.log_prob(x)

        # Core rank 0 for the single `scale` argument. (Written as a plain
        # int: `(0)` is not a tuple and only obscured that.)
        polymorphic_fn = vectorization_util.make_rank_polymorphic(
            fn, core_ndims=0)
        batch_scale = tf.constant([[4., 2., 5.], [1., 2., 1.]],
                                  dtype=tf.float32)
        d, x, lp = polymorphic_fn(batch_scale)
        self.assertAllEqual(d.batch_shape.as_list(), x.shape.as_list())
        lp2 = d.log_prob(x)
        self.assertAllClose(*self.evaluate((lp, lp2)))
コード例 #13
0
    def _sample_n(self, sample_shape, seed, value=None, **kwargs):
        """Draws samples, auto-vectorizing over `sample_shape` and `value` dims.

        If `value` is not given but `kwargs` are, the value is resolved from
        them. Vectorization is skipped when it is disabled, or when neither
        `sample_shape` nor `value` may contribute extra dimensions.
        """
        if value is None and kwargs:
            value = self._resolve_value_from_kwargs(**kwargs)

        value_might_have_sample_dims = False
        if value is not None:
            value = _pad_value_to_full_length(value, self.dtype)
            value = tf.nest.map_structure(
                lambda v: v if v is None else tf.convert_to_tensor(v), value)
            value_might_have_sample_dims = _might_have_excess_ndims(
                flat_value=self._model_flatten(value),
                flat_core_ndims=self._single_sample_ndims)

        if not (self.use_vectorized_map and (
                _might_have_nonzero_size(sample_shape)
                or value_might_have_sample_dims)):
            # No need to auto-vectorize.
            flat_samples = self._call_flat_sample_distributions(
                sample_shape=sample_shape, seed=seed, value=value)[1]
            return self._model_unflatten(flat_samples)

        # To support the `value` arg, first work out which dims belong to the
        # model itself; `_call_flat_sample_distributions` is then batched over
        # all remaining dims.
        if value is None:
            value_core_ndims = None
        else:
            value_core_ndims = tf.nest.map_structure(
                lambda v, nd: None if v is None else nd,
                value,
                self._model_unflatten(self._single_sample_ndims),
                check_types=False)

        def sample_once(v, seed):
            return self._call_flat_sample_distributions(
                sample_shape=(), seed=seed, value=v)[1]

        batch_flat_sample = vectorization_util.make_rank_polymorphic(
            sample_once,
            core_ndims=[value_core_ndims, None],
            validate_args=self.validate_args)

        # `make_rank_polymorphic` does not currently support keyword args, but
        # `iid_sample` passes `seed` by keyword; forward it positionally.
        def call_positionally(v, seed):
            return batch_flat_sample(v, seed)

        # Draw samples.
        vectorized_flat_sample = vectorization_util.iid_sample(
            call_positionally, sample_shape)
        return self._model_unflatten(vectorized_flat_sample(value, seed=seed))
コード例 #14
0
    def test_can_take_structured_input_and_output(self):
        """Nested (tuple, dict) inputs and (dict, scalar) outputs are supported."""
        # Dummy function that takes a (tuple, dict) pair
        # and returns a (dict, scalar) pair.
        def fn(x, y):
            a, b, c = x
            return {'r': a * b + c}, y['d'] + y['e']

        vectorized_fn = vectorization_util.make_rank_polymorphic(
            fn, core_ndims=0)

        x = (np.array([[2.], [3.]]), np.array(2.), np.array([5., 6., 7.]))
        y = {'d': np.array([[1.]]), 'e': np.array([2., 3., 4.])}
        vectorized_result = self.evaluate(vectorized_fn(x, y))
        # Broadcast the plain (unvectorized) result up to the vectorized shapes.
        expected = tf.nest.map_structure(
            lambda r, v: r * np.ones(v.shape), fn(x, y), vectorized_result)
        self.assertAllClose(expected, vectorized_result)
コード例 #15
0
    def test_unit_batch_dims_are_not_vectorized(self):
        """Size-1 batch dims must not trigger auto-vectorization."""
        if not tf.executing_eagerly():
            self.skipTest('Test relies on eager execution.')

        # `fn` refuses to run anywhere but eagerly, so it fails if the
        # polymorphic wrapper traces it inside a `tf.function`.
        def must_run_eagerly(x):
            if tf.executing_eagerly():
                return x
            raise ValueError(
                'Code is running inside tf.function. This may '
                'indicate that auto-vectorization is being '
                'triggered unnecessarily.')

        polymorphic_fn = vectorization_util.make_rank_polymorphic(
            must_run_eagerly, core_ndims=[0])
        for shape in ([], [1], [1, 1]):
            polymorphic_fn(tf.zeros(shape))
コード例 #16
0
    def _vectorize_member_fn(self, member_fn, core_ndims):
        """Returns a rank-polymorphic version of a pinned-bijector method.

        Args:
          member_fn: callable invoked as `member_fn(bijector, *args)`.
          core_ndims: per-argument core ndims for `*args`, as a list or tuple.

        Returns:
          A function of `*args` that is vectorized jointly over the pinned
          values (core ndims `self._pins_event_ndims`) and over `args`.
        """
        # Pinned values must be treated as *inputs* to vectorized bijector
        # members, since the pins can have batch dimensions that coincide
        # with the other values being transformed. For example, given
        #
        # jd = JointDistributionNamedAutoBatched({
        #   'a': tfd.LogNormal(0., 1.),
        #   'b': lambda a: tfd.Uniform(high=a + tf.ones([3]))})
        # bij = jd.experimental_default_event_space_bijector()
        # sampled = jd.sample([2])  # ==> shape {'a': [2], 'b': [2, 3]}
        #
        # and a pinned sampled value,
        #
        # pinned_jd = jd.experimental_pin(a=sampled['a'])
        # pinned_bij = pinned_jd.experimental_default_event_space_bijector()
        #
        # we'd expect `pinned_bij.forward({'b': sampled['b']})` to return the
        # same value for `b` as `bij.forward(sampled)['b']`, in which
        # each of the two batch elements of `b` is transformed wrt the
        # corresponding batch element of `a`. If we used a naive
        # _DefaultJointBijectorAutoBatched instance for `pinned_bij`, we would
        # instead get a shape error when the pinned value for `a` appears in
        # the model with batch shape `[2]`. The solution is to ensure that the
        # pinned value(s) are passed as input(s) to every bijector method that
        # we autovectorize.
        #
        # The approach implemented here uses a heavy hammer: calling any bijector
        # method rebuilds the pinned JD, creates a support bijector for its
        # unpinned values, and then invokes the requested method on that
        # bijector. (Re)creating all these Python objects incurs overhead in
        # eager mode and during `tf.function` tracing, but has no graph side
        # effects, so repeated execution of the traced function should be
        # efficient.
        def build_and_invoke_pinned_bijector(pins, *args):
            bij = joint_distribution._DefaultJointBijector(  # pylint: disable=protected-access
                self._jd.distribution.experimental_pin(**pins),
                **self._bijector_kwargs)
            return member_fn(bij, *args)

        vectorized_fn_of_pins = vectorization_util.make_rank_polymorphic(
            build_and_invoke_pinned_bijector,
            # `list(...)` so that a tuple `core_ndims` works as well as a list.
            core_ndims=[self._pins_event_ndims] + list(core_ndims))
        return lambda *args: vectorized_fn_of_pins(self._jd.pins, *args)
コード例 #17
0
    def _map_measure_over_dists(self, attr, value):
        """Calls `getattr(d, attr)(x)` for each component distribution `d`.

        Args:
          attr: name of the method to invoke on each component distribution.
          value: structured value for the model; no flattened part may be
            `None`.

        Returns:
          A list with one result per component distribution, in the order
          produced by `_flat_sample_distributions`.

        Raises:
          ValueError: if any flattened part of `value` is `None`.
        """
        # Flatten once and reuse the result for both the `None` check and the
        # value handed to the component distributions.
        flat_value = self._model_flatten(value)
        if any(x is None for x in flat_value):
            raise ValueError(
                'No `value` part can be `None`; saw: {}.'.format(value))
        if value is not None:
            value = flat_value

        def map_measure_fn(value):
            return [
                getattr(d, attr)(x)
                for (d,
                     x) in zip(*self._flat_sample_distributions(value=value))
            ]

        if self.use_vectorized_map:
            # Batch over any dims of `value` beyond a single sample's ndims.
            map_measure_fn = vectorization_util.make_rank_polymorphic(
                map_measure_fn,
                core_ndims=[self._single_sample_ndims],
                validate_args=self.validate_args)

        return map_measure_fn(value)
コード例 #18
0
 def _vectorize_member_fn(self, member_fn, core_ndims):
     """Returns `member_fn` bound to the joint bijector, made rank-polymorphic."""
     def invoke_on_joint_bijector(x):
         return member_fn(self._joint_bijector, x)

     return vectorization_util.make_rank_polymorphic(
         invoke_on_joint_bijector, core_ndims=core_ndims)
コード例 #19
0
        # Express the change as a difference of two rank-one terms, u u^T - w w^T,
        # where u and w equal `increment` except at `idx` (set to a and b). Since
        # a - b = 1 and a + b = increment[idx], the difference adds `increment` to
        # row `idx` and to column `idx`, with the diagonal entry counted once.
        # NOTE(review): assumes `cholesky_update(L, u, multiplier=m)` returns the
        # factor of L L^T + m u u^T -- confirm against `tfp_math` docs.
        a = (increment[..., idx] + 1.) / 2.
        b = (increment[..., idx] - 1.) / 2.
        chol = tfp_math.cholesky_update(chol,
                                        update_vector=_set_vector_index(
                                            increment, idx, a),
                                        multiplier=1)
        chol = tfp_math.cholesky_update(chol,
                                        update_vector=_set_vector_index(
                                            increment, idx, b),
                                        multiplier=-1)

        # The Cholesky decomposition should be unchanged in rows/cols before idx.
        # The mask below has shape [n, 1], so it keeps whole rows i < idx from
        # `orig_chol`; it relies on a statically known `chol.shape[-1]`.
        #
        # TODO(b/229298550): Investigate whether this is really necessary, or if the
        # test failures we see without this line are due to an underlying bug.
        return tf.where((tf.range(chol.shape[-1]) < idx)[..., tf.newaxis],
                        orig_chol, chol)


def _set_vector_index_unbatched(v, idx, x):
    """Returns a copy of `v` with entry `idx` replaced by `x`.

    Mutation-free equivalent of `v[idx] = x`.
    """
    scatter_indices = [[idx]]
    return tf.tensor_scatter_nd_update(v, indices=scatter_indices, updates=[x])


# Batched form of `_set_vector_index_unbatched`: `v` has core rank 1 and `idx`
# and `x` have core rank 0; extra leading (batch) dims are vectorized over.
_set_vector_index = vectorization_util.make_rank_polymorphic(
    _set_vector_index_unbatched, core_ndims=[1, 0, 0])


def _half_logdet(chol):
    """Sums the log of the diagonal of `chol` over its last axis.

    For a triangular factor L with positive diagonal this is log det(L),
    i.e. half of log det(L L^T).
    """
    log_diagonal = tf.math.log(tf.linalg.diag_part(chol))
    return tf.reduce_sum(log_diagonal, axis=-1)