def testShapes(self, is_static):
        # 5x5 grid of index points in R^2 and flatten to 25x2
        index_points = np.linspace(-4., 4., 5, dtype=np.float64)
        index_points = np.stack(np.meshgrid(index_points, index_points),
                                axis=-1)
        index_points = np.reshape(index_points, [-1, 2])
        # ==> shape = [25, 2]
        batched_index_points = np.expand_dims(np.stack([index_points] * 6), -3)
        # ==> shape = [6, 1, 25, 2]

        # 9 inducing index points in R^2
        inducing_index_points = np.linspace(-4., 4., 3, dtype=np.float64)
        inducing_index_points = np.stack(np.meshgrid(inducing_index_points,
                                                     inducing_index_points),
                                         axis=-1)
        inducing_index_points = np.reshape(inducing_index_points, [-1, 2])
        # ==> shape = [9, 2]

        variational_inducing_observations_loc = np.zeros([3, 9],
                                                         dtype=np.float64)
        variational_inducing_observations_scale = np.eye(9, dtype=np.float64)

        # Kernel with batch_shape [2, 4, 1, 1]
        amplitude = np.array([1., 2.], np.float64).reshape([2, 1, 1, 1])
        length_scale = np.array([.1, .2, .3, .4],
                                np.float64).reshape([1, 4, 1, 1])

        jitter = np.float64(1e-6)
        observation_noise_variance = np.float64(1e-2)

        if not is_static:
            amplitude = tf1.placeholder_with_default(amplitude, shape=None)
            length_scale = tf1.placeholder_with_default(length_scale,
                                                        shape=None)
            batched_index_points = tf1.placeholder_with_default(
                batched_index_points, shape=None)

            inducing_index_points = tf1.placeholder_with_default(
                inducing_index_points, shape=None)
            variational_inducing_observations_loc = tf1.placeholder_with_default(
                variational_inducing_observations_loc, shape=None)
            variational_inducing_observations_scale = tf1.placeholder_with_default(
                variational_inducing_observations_scale, shape=None)

        kernel = tfp.math.psd_kernels.ExponentiatedQuadratic(
            amplitude, length_scale)

        vgp = tfd.VariationalGaussianProcess(
            kernel=kernel,
            index_points=batched_index_points,
            inducing_index_points=inducing_index_points,
            variational_inducing_observations_loc=(
                variational_inducing_observations_loc),
            variational_inducing_observations_scale=(
                variational_inducing_observations_scale),
            observation_noise_variance=observation_noise_variance,
            jitter=jitter)

        batch_shape = [2, 4, 6, 3]
        event_shape = [25]
        sample_shape = [9, 3]

        samples = vgp.sample(sample_shape, seed=test_util.test_seed())

        if is_static or tf.executing_eagerly():
            self.assertAllEqual(vgp.observation_noise_variance.shape,
                                tf.TensorShape([]))
            self.assertAllEqual(vgp.predictive_noise_variance.shape,
                                tf.TensorShape([]))
            self.assertAllEqual(vgp.batch_shape_tensor(), batch_shape)
            self.assertAllEqual(vgp.event_shape_tensor(), event_shape)
            self.assertAllEqual(samples.shape,
                                sample_shape + batch_shape + event_shape)
            self.assertAllEqual(vgp.batch_shape, batch_shape)
            self.assertAllEqual(vgp.event_shape, event_shape)
        else:
            self.assertAllEqual(
                self.evaluate(tf.shape(vgp.observation_noise_variance)), [])
            self.assertAllEqual(
                self.evaluate(tf.shape(vgp.predictive_noise_variance)), [])
            self.assertAllEqual(self.evaluate(vgp.batch_shape_tensor()),
                                batch_shape)
            self.assertAllEqual(self.evaluate(vgp.event_shape_tensor()),
                                event_shape)
            self.assertAllEqual(
                self.evaluate(samples).shape,
                sample_shape + batch_shape + event_shape)
            self.assertIsNone(tensorshape_util.rank(samples.shape))
            self.assertIsNone(tensorshape_util.rank(vgp.batch_shape))
            self.assertEqual(tensorshape_util.rank(vgp.event_shape), 1)
            self.assertIsNone(
                tf.compat.dimension_value(
                    tensorshape_util.dims(vgp.event_shape)[0]))
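# A minimal sketch (not part of the original test) of the batch-shape arithmetic the
# assertions above rely on: the VGP batch shape is the broadcast of the kernel batch
# shape [2, 4, 1, 1], the index-point batch shape [6, 1], and the variational loc
# batch shape [3]. Plain NumPy broadcasting reproduces the expected [2, 4, 6, 3].
import numpy as np

kernel_batch = np.zeros([2, 4, 1, 1])
index_point_batch = np.zeros([6, 1])
variational_loc_batch = np.zeros([3])
broadcast_shape = np.broadcast(kernel_batch, index_point_batch,
                               variational_loc_batch).shape
assert broadcast_shape == (2, 4, 6, 3)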
Example #2
    def _parameter_control_dependencies(self, is_init):
        assertions = []

        # For `logits` and `probs`, we only want to have an assertion on what the
        # user actually passed. For now, we access the underlying categorical's
        # _logits and _probs directly. After the 2019-10-01 deprecation, it would
        # also work to use .logits() and .probs().
        logits = self._categorical._logits
        probs = self._categorical._probs
        outcomes = self._outcomes
        validate_args = self._validate_args

        # Build all shape and dtype checks during the `is_init` call.
        if is_init:

            def validate_equal_last_dim(tensor_a, tensor_b, message):
                event_size_a = tf.compat.dimension_value(tensor_a.shape[-1])
                event_size_b = tf.compat.dimension_value(tensor_b.shape[-1])
                if event_size_a is not None and event_size_b is not None:
                    if event_size_a != event_size_b:
                        raise ValueError(message)
                elif validate_args:
                    return assert_util.assert_equal(tf.shape(tensor_a)[-1],
                                                    tf.shape(tensor_b)[-1],
                                                    message=message)

            message = 'Size of outcomes must be greater than 0.'
            if tensorshape_util.num_elements(outcomes.shape) is not None:
                if tensorshape_util.num_elements(outcomes.shape) == 0:
                    raise ValueError(message)
            elif validate_args:
                assertions.append(
                    tf.assert_greater(tf.size(outcomes), 0, message=message))

            if logits is not None:
                maybe_assert = validate_equal_last_dim(
                    outcomes,
                    # pylint: disable=protected-access
                    self._categorical._logits,
                    # pylint: enable=protected-access
                    message=
                    'Last dimension of outcomes and logits must be equal size.'
                )
                if maybe_assert:
                    assertions.append(maybe_assert)

            if probs is not None:
                maybe_assert = validate_equal_last_dim(
                    outcomes,
                    probs,
                    message=
                    'Last dimension of outcomes and probs must be equal size.')
                if maybe_assert:
                    assertions.append(maybe_assert)

            message = 'Rank of outcomes must be 1.'
            ndims = tensorshape_util.rank(outcomes.shape)
            if ndims is not None:
                if ndims != 1:
                    raise ValueError(message)
            elif validate_args:
                assertions.append(
                    assert_util.assert_rank(outcomes, 1, message=message))

        if not validate_args:
            assert not assertions  # Should never happen.
            return []

        if is_init != tensor_util.is_ref(outcomes):
            assertions.append(
                assert_util.assert_equal(
                    tf.math.is_strictly_increasing(outcomes),
                    True,
                    message='outcomes is not strictly increasing.'))

        return assertions
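# A hedged sketch (the helper name is illustrative, not from the source) of the
# validation pattern used above: when shapes are statically known, invalid values
# raise immediately at construction time; otherwise, and only when
# `validate_args=True`, a graph-time assertion op is returned for the caller to
# collect.
import tensorflow as tf

def check_last_dims_match(a, b, validate_args, message):
    size_a = tf.compat.dimension_value(a.shape[-1])
    size_b = tf.compat.dimension_value(b.shape[-1])
    if size_a is not None and size_b is not None:
        # Static check: fail eagerly while the distribution is being built.
        if size_a != size_b:
            raise ValueError(message)
        return None
    if validate_args:
        # Dynamic check: defer to an assertion evaluated with the graph.
        return tf.debugging.assert_equal(
            tf.shape(a)[-1], tf.shape(b)[-1], message=message)
    return None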
Example #3
    def _log_prob(self, x):
        if self.input_output_cholesky:
            x_sqrt = x
        else:
            # Complexity: O(nbk**3)
            x_sqrt = tf.linalg.cholesky(x)

        batch_shape = self.batch_shape_tensor()
        event_shape = self.event_shape_tensor()
        x_ndims = tf.rank(input=x_sqrt)
        num_singleton_axes_to_prepend = (
            tf.maximum(tf.size(input=batch_shape) + 2, x_ndims) - x_ndims)
        x_with_prepended_singletons_shape = tf.concat([
            tf.ones([num_singleton_axes_to_prepend], dtype=tf.int32),
            tf.shape(input=x_sqrt)
        ], 0)
        x_sqrt = tf.reshape(x_sqrt, x_with_prepended_singletons_shape)
        ndims = tf.rank(x_sqrt)
        # sample_ndims = ndims - batch_ndims - event_ndims
        sample_ndims = ndims - tf.size(input=batch_shape) - 2
        sample_shape = tf.shape(input=x_sqrt)[:sample_ndims]

        # We need to be able to pre-multiply each matrix by its corresponding
        # batch scale matrix. Since a Distribution Tensor supports multiple
        # samples per batch, this means we need to reshape the input matrix `x`
        # so that the first b dimensions are batch dimensions and the last two
        # are of shape [dimension, dimensions*number_of_samples]. Doing these
        # gymnastics allows us to do a batch_solve.
        #
        # After we're done with sqrt_solve (the batch operation) we need to undo
        # this reshaping so what we're left with is a Tensor partitionable by
        # sample, batch, event dimensions.

        # Complexity: O(nbk**2) since transpose must access every element.
        scale_sqrt_inv_x_sqrt = x_sqrt
        perm = tf.concat(
            [tf.range(sample_ndims, ndims),
             tf.range(0, sample_ndims)], 0)
        scale_sqrt_inv_x_sqrt = tf.transpose(a=scale_sqrt_inv_x_sqrt,
                                             perm=perm)
        last_dim_size = (
            tf.cast(self.dimension, dtype=tf.int32) * tf.reduce_prod(
                input_tensor=x_with_prepended_singletons_shape[:sample_ndims]))
        shape = tf.concat([
            x_with_prepended_singletons_shape[sample_ndims:-2],
            [tf.cast(self.dimension, dtype=tf.int32), last_dim_size]
        ],
                          axis=0)
        scale_sqrt_inv_x_sqrt = tf.reshape(scale_sqrt_inv_x_sqrt, shape)

        # Complexity: O(nbM*k) where M is the complexity of the operator solving a
        # vector system. For LinearOperatorLowerTriangular, each solve is O(k**2) so
        # this step has complexity O(nbk^3).
        scale_sqrt_inv_x_sqrt = self.scale_operator.solve(
            scale_sqrt_inv_x_sqrt)

        # Undo the reshaping that made the tensor batch-op ready.
        # Complexity: O(nbk**2)
        shape = tf.concat([
            tf.shape(input=scale_sqrt_inv_x_sqrt)[:-2], event_shape,
            sample_shape
        ],
                          axis=0)
        scale_sqrt_inv_x_sqrt = tf.reshape(scale_sqrt_inv_x_sqrt, shape)
        perm = tf.concat([
            tf.range(ndims - sample_ndims, ndims),
            tf.range(0, ndims - sample_ndims)
        ], 0)
        scale_sqrt_inv_x_sqrt = tf.transpose(a=scale_sqrt_inv_x_sqrt,
                                             perm=perm)

        # Write V = SS', X = LL'. Then:
        # tr[inv(V) X] = tr[inv(S)' inv(S) L L']
        #              = tr[inv(S) L L' inv(S)']
        #              = tr[(inv(S) L) (inv(S) L)']
        #              = sum_{ik} (inv(S) L)_{ik}**2
        # The second equality follows from the cyclic permutation property.
        # Complexity: O(nbk**2)
        trace_scale_inv_x = tf.reduce_sum(
            input_tensor=tf.square(scale_sqrt_inv_x_sqrt), axis=[-2, -1])

        # Complexity: O(nbk)
        half_log_det_x = tf.reduce_sum(input_tensor=tf.math.log(
            tf.linalg.diag_part(x_sqrt)),
                                       axis=[-1])

        # Complexity: O(nbk**2)
        log_prob = ((self.df - self.dimension - 1.) * half_log_det_x -
                    0.5 * trace_scale_inv_x - self.log_normalization())

        # Set shape hints.
        # Try to merge what we know from the input x with what we know from the
        # parameters of this distribution.
        if tensorshape_util.rank(
                x.shape) is not None and tensorshape_util.rank(
                    self.batch_shape) is not None:
            log_prob.set_shape(
                tf.broadcast_static_shape(x.shape[:-2], self.batch_shape))

        return log_prob
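# A hedged numerical check (not from the original source) of the trace identity in
# the comments of `_log_prob` above: with V = S S' and X = L L',
# tr(inv(V) X) equals the sum of squares of the entries of inv(S) L.
import numpy as np

rng = np.random.default_rng(0)
a = rng.normal(size=(3, 3))
b = rng.normal(size=(3, 3))
V = a @ a.T + 3. * np.eye(3)          # an SPD scale matrix, V = S S'
X = b @ b.T + 3. * np.eye(3)          # an SPD observation, X = L L'
S = np.linalg.cholesky(V)
L = np.linalg.cholesky(X)
lhs = np.trace(np.linalg.inv(V) @ X)
rhs = np.sum(np.linalg.solve(S, L) ** 2)
assert np.isclose(lhs, rhs)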
    def _sample_control_dependencies(self, x):
        """Helper which validates sample arg, e.g., input to `log_prob`."""
        x_ndims = (tf.rank(x) if tensorshape_util.rank(x.shape) is None else
                   tensorshape_util.rank(x.shape))
        event_ndims = (tf.size(self.event_shape_tensor())
                       if tensorshape_util.rank(self.event_shape) is None else
                       tensorshape_util.rank(self.event_shape))
        batch_ndims = (tf.size(self.batch_shape_tensor())
                       if tensorshape_util.rank(self.batch_shape) is None else
                       tensorshape_util.rank(self.batch_shape))
        expected_batch_event_ndims = batch_ndims + event_ndims

        if (isinstance(x_ndims, int)
                and isinstance(expected_batch_event_ndims, int)):
            if x_ndims < expected_batch_event_ndims:
                raise NotImplementedError(
                    'Broadcasting is not supported; too few batch and event dims '
                    '(expected at least {}, saw {}).'.format(
                        expected_batch_event_ndims, x_ndims))
            ndims_assertion = []
        elif self.validate_args:
            ndims_assertion = [
                assert_util.assert_greater_equal(
                    x_ndims,
                    expected_batch_event_ndims,
                    message=('Broadcasting is not supported; too few '
                             'batch and event dims.'),
                    name='assert_batch_and_event_ndims_large_enough'),
            ]

        if (tensorshape_util.is_fully_defined(self.batch_shape)
                and tensorshape_util.is_fully_defined(self.event_shape)):
            expected_batch_event_shape = np.int32(
                tensorshape_util.concatenate(self.batch_shape,
                                             self.event_shape))
        else:
            expected_batch_event_shape = tf.concat([
                self.batch_shape_tensor(),
                self.event_shape_tensor(),
            ],
                                                   axis=0)

        sample_ndims = x_ndims - expected_batch_event_ndims
        if isinstance(sample_ndims, int):
            sample_ndims = max(sample_ndims, 0)
        if (isinstance(sample_ndims, int)
                and tensorshape_util.is_fully_defined(x.shape[sample_ndims:])):
            actual_batch_event_shape = np.int32(x.shape[sample_ndims:])
        else:
            sample_ndims = tf.maximum(sample_ndims, 0)
            actual_batch_event_shape = tf.shape(x)[sample_ndims:]

        assertions = []
        if (isinstance(expected_batch_event_shape, np.ndarray)
                and isinstance(actual_batch_event_shape, np.ndarray)):
            if any(expected_batch_event_shape != actual_batch_event_shape):
                raise NotImplementedError('Broadcasting is not supported; '
                                          'unexpected batch and event shape '
                                          '(expected {}, saw {}).'.format(
                                              expected_batch_event_shape,
                                              actual_batch_event_shape))
            assertions.extend(ndims_assertion)
        elif self.validate_args:
            with tf.control_dependencies(ndims_assertion):
                shape_assertion = assert_util.assert_equal(
                    expected_batch_event_shape,
                    actual_batch_event_shape,
                    message=('Broadcasting is not supported; '
                             'unexpected batch and event shape.'),
                    name='assert_batch_and_event_shape_same')
            assertions.append(shape_assertion)

        return assertions
    def _parameter_control_dependencies(self, is_init):
        assertions = []

        axis = None
        paddings = None

        if is_init != tensor_util.is_ref(self.axis):
            # First we check the shape of the axis argument.
            msg = 'Argument `axis` must be scalar or vector.'
            if tensorshape_util.rank(self.axis.shape) is not None:
                if tensorshape_util.rank(self.axis.shape) > 1:
                    raise ValueError(msg)
            elif self.validate_args:
                if axis is None: axis = tf.convert_to_tensor(self.axis)
                assertions.append(
                    assert_util.assert_rank_at_most(axis, 1, message=msg))
            # Next we check the values of the axis argument.
            axis_ = tf.get_static_value(self.axis)
            msg = 'Argument `axis` must be negative.'
            if axis_ is not None:
                if np.any(axis_ > -1):
                    raise ValueError(msg)
            elif self.validate_args:
                if axis is None: axis = tf.convert_to_tensor(self.axis)
                assertions.append(assert_util.assert_less(axis, 0,
                                                          message=msg))
            msg = 'Argument `axis` elements must be unique.'
            if axis_ is not None:
                if len(np.array(axis_).reshape(-1)) != len(np.unique(axis_)):
                    raise ValueError(msg)
            elif self.validate_args:
                if axis is None: axis = tf.convert_to_tensor(self.axis)
                assertions.append(
                    assert_util.assert_equal(ps.size0(axis),
                                             ps.size0(ps.setdiff1d(axis)),
                                             message=msg))

        if is_init != tensor_util.is_ref(self.paddings):
            # First we check the shape of the paddings argument.
            msg = 'Argument `paddings` must be a vector of pairs.'
            if tensorshape_util.is_fully_defined(self.paddings.shape):
                shape = np.int32(self.paddings.shape)
                if len(shape) != 2 or shape[0] < 1 or shape[1] != 2:
                    raise ValueError(msg)
            elif self.validate_args:
                if paddings is None:
                    paddings = tf.convert_to_tensor(self.paddings)
                with tf.control_dependencies([
                        assert_util.assert_equal(tf.rank(paddings),
                                                 2,
                                                 message=msg)
                ]):
                    shape = tf.shape(paddings)
                    assertions.extend([
                        assert_util.assert_greater(shape[0], 0, message=msg),
                        assert_util.assert_equal(shape[1], 2, message=msg),
                    ])
            # Next we check the values of the paddings argument.
            paddings_ = tf.get_static_value(self.paddings)
            msg = 'Argument `paddings` must be non-negative.'
            if paddings_ is not None:
                if np.any(paddings_ < 0):
                    raise ValueError(msg)
            elif self.validate_args:
                if paddings is None:
                    paddings = tf.convert_to_tensor(self.paddings)
                assertions.append(
                    assert_util.assert_greater(paddings, -1, message=msg))

        if is_init != (tensor_util.is_ref(self.axis)
                       and tensor_util.is_ref(self.paddings)):
            axis_ = tf.get_static_value(self.axis)
            if axis_ is None and axis is None:
                axis = tf.convert_to_tensor(self.axis)
            len_axis = ps.size0(
                ps.reshape(axis if axis_ is None else axis_, shape=-1))

            paddings_ = tf.get_static_value(self.paddings)
            if paddings_ is None and paddings is None:
                paddings = tf.convert_to_tensor(self.paddings)
            len_paddings = ps.size0(
                paddings if paddings_ is None else paddings_)

            msg = ('Arguments `axis` and `paddings` must have the same number '
                   'of elements.')
            if (ps.is_numpy(len_axis) and ps.is_numpy(len_paddings)):
                if len_axis != len_paddings:
                    raise ValueError(
                        msg + ' Saw: {}, {}.'.format(self.axis, self.paddings))
            elif self.validate_args:
                assertions.append(
                    assert_util.assert_equal(len_axis,
                                             len_paddings,
                                             message=msg))

        return assertions
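# A small, hedged illustration (values illustrative, not from the source) of the
# static constraints checked above: `axis` must be at most rank 1 with negative,
# unique entries, and `paddings` must be a [k, 2] matrix of non-negative values
# with one row per `axis` entry.
import numpy as np

axis = np.array([-2, -1])
paddings = np.array([[1, 0],
                     [0, 2]])
assert axis.ndim <= 1 and np.all(axis < 0)
assert len(np.unique(axis)) == axis.size
assert paddings.ndim == 2 and paddings.shape[1] == 2 and np.all(paddings >= 0)
assert paddings.shape[0] == axis.size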
Example #6
    def _parameter_control_dependencies(self, is_init):
        """Validate parameters."""
        bw, bh, kd = None, None, None
        try:
            shape = tf.broadcast_static_shape(self.bin_widths.shape,
                                              self.bin_heights.shape)
        except ValueError as e:
            raise ValueError(
                '`bin_widths`, `bin_heights` must broadcast: {}'.format(
                    str(e)))
        bin_sizes_shape = shape
        try:
            shape = tf.broadcast_static_shape(shape[:-1],
                                              self.knot_slopes.shape[:-1])
        except ValueError as e:
            raise ValueError(
                '`bin_widths`, `bin_heights`, and `knot_slopes` must broadcast on '
                'batch axes: {}'.format(str(e)))

        assertions = []
        if (tensorshape_util.is_fully_defined(bin_sizes_shape[-1:])
                and tensorshape_util.is_fully_defined(
                    self.knot_slopes.shape[-1:])):
            if tensorshape_util.rank(self.knot_slopes.shape) > 0:
                num_interior_knots = tensorshape_util.dims(
                    bin_sizes_shape)[-1] - 1
                if tensorshape_util.dims(self.knot_slopes.shape)[-1] not in (
                        1, num_interior_knots):
                    raise ValueError(
                        'Innermost axis of non-scalar `knot_slopes` must broadcast with '
                        '{}; got {}.'.format(num_interior_knots,
                                             self.knot_slopes.shape))
        # elif self.validate_args:
        #   if is_init != any(tensor_util.is_ref(t)
        #       for t in (self.bin_widths, self.bin_heights, self.knot_slopes)):
        #     bw = tf.convert_to_tensor(self.bin_widths) if bw is None else bw
        #     bh = tf.convert_to_tensor(self.bin_heights) if bh is None else bh
        #     kd = _ensure_at_least_1d(self.knot_slopes) if kd is None else kd
        #     shape = tf.broadcast_dynamic_shape(
        #         tf.shape((bw + bh)[..., :-1]), tf.shape(kd))
        #     assertions.append(
        #         assert_util.assert_greater(
        #             tf.shape(shape)[0],
        #             tf.zeros([], dtype=shape.dtype),
        #             message='`(bin_widths + bin_heights)[..., :-1]` must broadcast '
        #             'with `knot_slopes` to at least 1-D.'))

        if not self.validate_args:
            assert not assertions
            return assertions

        # if (is_init != tensor_util.is_ref(self.bin_widths) or
        #     is_init != tensor_util.is_ref(self.bin_heights)):
        #   bw = tf.convert_to_tensor(self.bin_widths) if bw is None else bw
        #   bh = tf.convert_to_tensor(self.bin_heights) if bh is None else bh
        #   assertions += [
        #       assert_util.assert_near(
        #           tf.reduce_sum(bw, axis=-1),
        #           tf.reduce_sum(bh, axis=-1),
        #           message='`sum(bin_widths, axis=-1)` must equal '
        #           '`sum(bin_heights, axis=-1)`.'),
        #   ]
        # if is_init != tensor_util.is_ref(self.bin_widths):
        #   bw = tf.convert_to_tensor(self.bin_widths) if bw is None else bw
        #   assertions += [
        #       assert_util.assert_positive(
        #           bw, message='`bin_widths` must be positive.'),
        #   ]
        # if is_init != tensor_util.is_ref(self.bin_heights):
        #   bh = tf.convert_to_tensor(self.bin_heights) if bh is None else bh
        #   assertions += [
        #       assert_util.assert_positive(
        #           bh, message='`bin_heights` must be positive.'),
        #   ]
        # if is_init != tensor_util.is_ref(self.knot_slopes):
        #   kd = _ensure_at_least_1d(self.knot_slopes) if kd is None else kd
        #   assertions += [
        #       assert_util.assert_positive(
        #           kd, message='`knot_slopes` must be positive.'),
        #   ]
        return assertions
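# A hedged sketch of the shape relationship validated above (these appear to be the
# parameters of a rational quadratic spline bijector): `num_bins` bins imply
# `num_bins - 1` interior knots, so a non-scalar `knot_slopes` must have innermost
# dimension 1 or `num_bins - 1`.
import numpy as np

num_bins = 4
bin_widths = np.full([num_bins], 1. / num_bins)
bin_heights = np.full([num_bins], 1. / num_bins)
knot_slopes = np.ones([num_bins - 1])   # or np.ones([1]), which broadcasts
assert knot_slopes.shape[-1] in (1, bin_widths.shape[-1] - 1)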
Example #7
    def _batch_shape(self):
        if tensorshape_util.rank(self.samples.shape) is None:
            return tf.TensorShape(None)
        return self.samples.shape[:self._samples_axis]
    def make_response_likelihood(self, w, x):
        if tensorshape_util.rank(w.shape) == 1:
            y_bar = tf.matmul(w[tf.newaxis], x)[0]
        else:
            y_bar = tf.matmul(w, x)
        return tfd.Normal(loc=y_bar, scale=tf.ones_like(y_bar))  # [n]
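# A hedged usage sketch of the helper above (data and shapes are illustrative): with
# weights `w` of shape [d] and features `x` of shape [d, n], the returned
# distribution is a batch of n unit-scale Normals centered at w @ x.
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
w = tf.constant([0.5, -1.0])
x = tf.random.normal([2, 10])
y_bar = tf.matmul(w[tf.newaxis], x)[0]               # shape [10]
likelihood = tfd.Normal(loc=y_bar, scale=tf.ones_like(y_bar))
y = tf.random.normal([10])
log_lik = tf.reduce_sum(likelihood.log_prob(y))      # scalar log-likelihood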
Example #9
    def __init__(self,
                 logits=None,
                 probs=None,
                 dtype=tf.int32,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="Categorical"):
        """Initialize Categorical distributions using class log-probabilities.

    Args:
      logits: An N-D `Tensor`, `N >= 1`, representing the log probabilities
        of a set of Categorical distributions. The first `N - 1` dimensions
        index into a batch of independent distributions and the last dimension
        represents a vector of logits for each class. Only one of `logits` or
        `probs` should be passed in.
      probs: An N-D `Tensor`, `N >= 1`, representing the probabilities
        of a set of Categorical distributions. The first `N - 1` dimensions
        index into a batch of independent distributions and the last dimension
        represents a vector of probabilities for each class. Only one of
        `logits` or `probs` should be passed in.
      dtype: The type of the event samples (default: int32).
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            self._logits, self._probs = distribution_util.get_logits_and_probs(
                logits=logits,
                probs=probs,
                validate_args=validate_args,
                multidimensional=True,
                name=name)

            if validate_args:
                self._logits = distribution_util.embed_check_categorical_event_shape(
                    self._logits)

            logits_shape_static = tensorshape_util.with_rank_at_least(
                self._logits.shape, 1)
            if tensorshape_util.rank(logits_shape_static) is not None:
                self._batch_rank = tf.convert_to_tensor(
                    value=tensorshape_util.rank(logits_shape_static) - 1,
                    dtype=tf.int32,
                    name="batch_rank")
            else:
                with tf.name_scope("batch_rank"):
                    self._batch_rank = tf.rank(self._logits) - 1

            logits_shape = tf.shape(input=self._logits, name="logits_shape")
            num_categories = tf.compat.dimension_value(logits_shape_static[-1])
            if num_categories is not None:
                self._num_categories = tf.convert_to_tensor(
                    value=num_categories,
                    dtype=tf.int32,
                    name="num_categories")
            else:
                with tf.name_scope("num_categories"):
                    self._num_categories = logits_shape[self._batch_rank]

            if tensorshape_util.is_fully_defined(logits_shape_static[:-1]):
                self._batch_shape_val = tf.constant(
                    logits_shape_static[:-1].as_list(),
                    dtype=tf.int32,
                    name="batch_shape")
            else:
                with tf.name_scope("batch_shape"):
                    self._batch_shape_val = logits_shape[:-1]
        super(Categorical, self).__init__(
            dtype=dtype,
            reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            parameters=parameters,
            graph_parents=[self._logits, self._probs],
            name=name)
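# A hedged usage sketch for the constructor documented above: `logits` of shape
# [2, 3] defines a batch of two Categorical distributions over three classes.
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
dist = tfd.Categorical(logits=tf.math.log([[0.2, 0.3, 0.5],
                                           [0.6, 0.3, 0.1]]))
samples = dist.sample(4)         # shape [4, 2], int32 class indices
probs = dist.probs_parameter()   # shape [2, 3]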
    def reduce_fn(operands, inits, axis=None, keepdims=False):
        """Applies `reducer` to the given operands along the given axes.

    Args:
      operands: tuple of tensors, all having the same shape.
      inits: tuple of scalar tensors, with dtypes aligned to those of operands.
      axis: The axis or axes to reduce. One of `None`, an `int` or a sequence of
        `int`. `None` is taken to mean "reduce all axes".
      keepdims: When `True`, we do not squeeze away the reduced dims, instead
        returning values with singleton dims in those axes.

    Returns:
      reduced: A tuple of the reduced operands.
    """
        # Static shape consistency checks.
        args_shape = operands[0].shape
        for arg in operands[1:]:
            args_shape = tensorshape_util.merge_with(args_shape, arg.shape)
        ndims = tensorshape_util.rank(args_shape)
        if ndims is None:
            raise ValueError(
                'Rank of at least one of `operands` must be known statically.')
        # Ensure the 'axis' arg is a tuple of non-negative ints.
        axis = np.arange(ndims) if axis is None else np.array(axis)
        if axis.ndim > 1:
            raise ValueError(
                '`axis` must be `None`, an `int`, or a sequence of '
                '`int`, but got {}'.format(axis))
        axis = np.reshape(axis, [-1])
        axis = np.where(axis < 0, axis + ndims, axis)
        axis = tuple(int(ax) for ax in axis)

        if JAX_MODE:
            from jax import lax  # pylint: disable=g-import-not-at-top
            result = lax.reduce(operands,
                                init_values=inits,
                                dimensions=axis,
                                computation=reducer)
        elif (tf.executing_eagerly()
              or not control_flow_util.GraphOrParentsInXlaContext(
                  tf1.get_default_graph())):
            result = _variadic_reduce(operands,
                                      init=inits,
                                      axis=axis,
                                      reducer=reducer)
        else:
            result = _xla_reduce(operands, inits, axis)

        if keepdims:
            axis_nhot = ps.reduce_sum(ps.one_hot(axis,
                                                 depth=ndims,
                                                 on_value=True,
                                                 off_value=False,
                                                 dtype=tf.bool),
                                      axis=0)
            in_shape = args_shape
            if not tensorshape_util.is_fully_defined(in_shape):
                in_shape = tf.shape(operands[0])
            final_shape = ps.where(axis_nhot, 1, in_shape)
            result = tf.nest.map_structure(
                lambda t: tf.reshape(t, final_shape), result)
        return result
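# A hedged, NumPy-only illustration of the axis normalization performed above:
# `None` means all axes, negative entries are wrapped, and the result is a tuple of
# non-negative Python ints (the helper name below is illustrative).
import numpy as np

def normalize_axis(axis, ndims):
    axis = np.arange(ndims) if axis is None else np.array(axis)
    if axis.ndim > 1:
        raise ValueError('`axis` must be None, an int, or a sequence of ints.')
    axis = np.reshape(axis, [-1])
    axis = np.where(axis < 0, axis + ndims, axis)
    return tuple(int(ax) for ax in axis)

assert normalize_axis(None, 3) == (0, 1, 2)
assert normalize_axis(-1, 3) == (2,)
assert normalize_axis([0, -1], 3) == (0, 2)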
Example #11
def _rank(input, name=None):  # pylint: disable=redefined-builtin,unused-argument
    if not hasattr(input, 'shape'):
        input = (tf.convert_to_tensor(input)
                 if tf.get_static_value(input) is None else np.array(input))
    ndims_ = tensorshape_util.rank(getattr(input, 'shape', None))
    return tf.rank(input) if ndims_ is None else np.int32(ndims_)
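# A hedged usage sketch: when the static rank is known, `_rank` returns a NumPy
# integer; for inputs whose rank is only known at graph execution time it falls
# back to a `tf.rank` op.
import tensorflow as tf

x = tf.zeros([3, 4])
assert _rank(x) == 2   # static shape is known, so this is np.int32(2)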
def get_fldj_theoretical(bijector,
                         x,
                         event_ndims,
                         inverse_event_ndims=None,
                         input_to_unconstrained=None,
                         output_to_unconstrained=None):
    """Numerically approximate the forward log det Jacobian of a bijector.

  We compute the Jacobian of the chain
  output_to_unconst_vec(bijector(inverse(input_to_unconst_vec))) so that
  we're working with a full rank matrix.  We then adjust the resulting Jacobian
  for the unconstraining bijectors.

  Bijectors that constrain / unconstrain their inputs/outputs may not be
  testable with this method, since the composition above may reduce the test
  to something trivial.  However, bijectors that map within constrained spaces
  should be fine.

  Args:
    bijector: the bijector whose Jacobian we wish to approximate
    x: the value for which we want to approximate the Jacobian. x must have
      a single batch dimension for compatibility with tape.batch_jacobian.
    event_ndims: number of dimensions in an event
    inverse_event_ndims: Integer describing the number of event dimensions for
      the bijector codomain. If None, then the value of `event_ndims` is used.
    input_to_unconstrained: bijector that maps the input to the above bijector
      to an unconstrained 1-D vector.  If unspecified, flatten the input into
      a 1-D vector according to its event_ndims.
    output_to_unconstrained: bijector that maps the output of the above bijector
      to an unconstrained 1-D vector.  If unspecified, flatten the input into
      a 1-D vector according to its event_ndims.

  Returns:
    A gradient-based approximation to the log det Jacobian of bijector.forward
    evaluated at x.
  """
    if inverse_event_ndims is None:
        inverse_event_ndims = event_ndims
    if input_to_unconstrained is None:
        input_to_unconstrained = reshape_bijector.Reshape(
            event_shape_in=x.shape[tensorshape_util.rank(x.shape) -
                                   event_ndims:],
            event_shape_out=[-1])
    if output_to_unconstrained is None:
        output_to_unconstrained = reshape_bijector.Reshape(
            event_shape_in=x.shape[tensorshape_util.rank(x.shape) -
                                   event_ndims:],
            event_shape_out=[-1])

    x = tf.convert_to_tensor(x)
    x_unconstrained = 1 * input_to_unconstrained.forward(x)

    with tf.GradientTape(persistent=True) as tape:
        tape.watch(x_unconstrained)
        f_x = bijector.forward(input_to_unconstrained.inverse(x_unconstrained))
        f_x_unconstrained = output_to_unconstrained.forward(f_x)
    jacobian = tape.batch_jacobian(f_x_unconstrained,
                                   x_unconstrained,
                                   experimental_use_pfor=False)
    logging.vlog(1, 'Jacobian: %s', jacobian)

    log_det_jacobian = 0.5 * tf.linalg.slogdet(
        tf.matmul(jacobian, jacobian, adjoint_a=True)).log_abs_determinant

    return (log_det_jacobian + input_to_unconstrained.forward_log_det_jacobian(
        x, event_ndims=event_ndims) -
            output_to_unconstrained.forward_log_det_jacobian(
                f_x, event_ndims=inverse_event_ndims))
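# A hedged usage sketch (bijector and values are illustrative): compare the
# gradient-based estimate above with the analytic forward log-det Jacobian of a
# simple bijector. `x` carries a single batch dimension, as required by
# `tape.batch_jacobian`.
import tensorflow as tf
import tensorflow_probability as tfp

tfb = tfp.bijectors
bijector = tfb.Softplus()
x = tf.constant([[0.5, 1.0, 2.0]])   # shape [1, 3]: one batch of a length-3 event
numerical = get_fldj_theoretical(bijector, x, event_ndims=1)
analytic = bijector.forward_log_det_jacobian(x, event_ndims=1)
# `numerical` and `analytic` should agree closely for this elementwise bijector.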
def index_remapping_gather(params,
                           indices,
                           axis=0,
                           indices_axis=0,
                           name='index_remapping_gather'):
    """Gather values from `axis` of `params` using `indices_axis` of `indices`.

  The shape of `indices` must broadcast to that of `params` when
  their `indices_axis` and `axis` (respectively) are aligned:

  ```python
  # params.shape:
  [p[0],  ..., ...,         p[axis], ..., ..., p[rank(params) - 1]]
  # indices.shape:
        [i[0], ..., i[indices_axis], ..., i[rank(indices) - 1]]
  ```

  In particular, `params` must have at least as many
  leading dimensions as `indices` (`axis >= indices_axis`), and at least as many
  trailing dimensions (`rank(params) - axis >= rank(indices) - indices_axis`).

  The `result` has the same shape as `params`, except that the dimension
  of size `p[axis]` is replaced by one of size `i[indices_axis]`:

  ```python
  # result.shape:
  [p[0],  ..., ..., i[indices_axis], ..., ..., p[rank(params) - 1]]
  ```

  In the case where `rank(params) == 5`, `rank(indices) == 3`, `axis = 2`, and
  `indices_axis = 1`, the result is given by

   ```python
   # alignment is:                       v axis
   # params.shape    ==   [p[0], p[1], p[2], p[3], p[4]]
   # indices.shape   ==         [i[0], i[1], i[2]]
   #                                     ^ indices_axis
   result[i, j, k, l, m] = params[i, j, indices[j, k, l], l, m]
  ```

  Args:
    params:  `N-D` `Tensor` (`N > 0`) from which to gather values.
      Number of dimensions must be known statically.
    indices: `Tensor` with values in `{0, ..., params.shape[axis] - 1}`, whose
      shape broadcasts to that of `params` as described above.
    axis: Python `int` axis of `params` from which to gather.
    indices_axis: Python `int` axis of `indices` to align with the `axis`
      over which `params` is gathered.
    name: String name for scoping created ops.

  Returns:
    `Tensor` composed of elements of `params`.

  Raises:
    ValueError: If shape/rank requirements are not met.
  """
    with tf.name_scope(name):
        params = tf.convert_to_tensor(params, name='params')
        indices = tf.convert_to_tensor(indices, name='indices')

        params_ndims = tensorshape_util.rank(params.shape)
        indices_ndims = tensorshape_util.rank(indices.shape)
        # `axis` dtype must match ndims, which are 64-bit Python ints.
        axis = tf.get_static_value(
            ps.convert_to_shape_tensor(axis, dtype=tf.int64))
        indices_axis = tf.get_static_value(
            ps.convert_to_shape_tensor(indices_axis, dtype=tf.int64))

        if params_ndims is None:
            raise ValueError(
                'Rank of `params` must be known statically. This is due to '
                'tf.gather not accepting a `Tensor` for `batch_dims`.')

        if axis is None:
            raise ValueError(
                '`axis` must be known statically. This is due to '
                'tf.gather not accepting a `Tensor` for `batch_dims`.')

        if indices_axis is None:
            raise ValueError(
                '`indices_axis` must be known statically. This is due to '
                'tf.gather not accepting a `Tensor` for `batch_dims`.')

        if indices_axis > axis:
            raise ValueError(
                '`indices_axis` should be <= `axis`, but was {} > {}'.format(
                    indices_axis, axis))

        if params_ndims < 1:
            raise ValueError(
                'Rank of params should be `> 0`, but was {}'.format(
                    params_ndims))

        if indices_ndims is not None and indices_ndims < 1:
            raise ValueError(
                'Rank of indices should be `> 0`, but was {}'.format(
                    indices_ndims))

        if (indices_ndims is not None
                and (indices_ndims - indices_axis > params_ndims - axis)):
            raise ValueError(
                '`rank(params) - axis` ({} - {}) must be >= `rank(indices) - '
                'indices_axis` ({} - {}), but was not.'.format(
                    params_ndims, axis, indices_ndims, indices_axis))

        # `tf.gather` requires the axis to be the rightmost batch ndim. So, we
        # transpose `indices_axis` to be the rightmost dimension of `indices`...
        transposed_indices = dist_util.move_dimension(indices,
                                                      source_idx=indices_axis,
                                                      dest_idx=-1)

        # ... and `axis` to be the corresponding (aligned as in the docstring)
        # dimension of `params`.
        broadcast_indices_ndims = indices_ndims + (axis - indices_axis)
        transposed_params = dist_util.move_dimension(
            params, source_idx=axis, dest_idx=broadcast_indices_ndims - 1)

        # Next we broadcast `indices` so that its shape has the same prefix as
        # `params.shape`.
        transposed_params_shape = ps.shape(transposed_params)
        result_shape = ps.concat([
            transposed_params_shape[:broadcast_indices_ndims - 1],
            ps.shape(indices)[indices_axis:indices_axis + 1],
            transposed_params_shape[broadcast_indices_ndims:]
        ],
                                 axis=0)
        broadcast_indices = ps.broadcast_to(
            transposed_indices, result_shape[:broadcast_indices_ndims])

        result_t = tf.gather(transposed_params,
                             broadcast_indices,
                             batch_dims=broadcast_indices_ndims - 1,
                             axis=broadcast_indices_ndims - 1)
        return dist_util.move_dimension(result_t,
                                        source_idx=broadcast_indices_ndims - 1,
                                        dest_idx=axis)
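# A hedged usage sketch (values illustrative): remap the columns of a 2-D `params`
# with a 1-D `indices`, aligning `indices_axis=0` with `axis=1`, so that
# result[i, j] = params[i, indices[j]].
import tensorflow as tf

params = tf.constant([[10., 20., 30.],
                      [40., 50., 60.]])   # shape [2, 3]
indices = tf.constant([2, 0])             # shape [2]
result = index_remapping_gather(params, indices, axis=1, indices_axis=0)
# ==> [[30., 10.],
#      [60., 40.]]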
Example #14
def brier_decomposition(labels, logits, name=None):
    r"""Decompose the Brier score into uncertainty, resolution, and reliability.

  [Proper scoring rules][1] measure the quality of probabilistic predictions;
  any proper scoring rule admits a [unique decomposition][2] as
  `Score = Uncertainty - Resolution + Reliability`, where:

  * `Uncertainty` is a generalized entropy of the average predictive
    distribution; it can be either positive or negative.
  * `Resolution` is a generalized variance of individual predictive
    distributions; it is always non-negative.  Differences in predictions reveal
    information, which is why a larger resolution improves the predictive score.
  * `Reliability` is a measure of calibration of predictions against the true
    frequency of events.  It is always non-negative, and a lower value indicates
    better calibration.

  This method estimates the above decomposition for the case of the Brier
  scoring rule for discrete outcomes.  For this, we need to discretize the space
  of probability distributions; we choose a simple partition of the space into
  `nlabels` events: given a distribution `p` over `nlabels` outcomes, the index
  `k` for which `p_k > p_i` for all `i != k` determines the discretization
  outcome; that is, `p in M_k`, where `M_k` is the set of all distributions for
  which `p_k` is the largest value among all probabilities.

  The estimation error of each component is O(k/n), where n is the number
  of instances and k is the number of labels.  There may be an error of this
  order when compared to `brier_score`.

  #### References
  [1]: Tilmann Gneiting, Adrian E. Raftery.
       Strictly Proper Scoring Rules, Prediction, and Estimation.
       Journal of the American Statistical Association, Vol. 102, 2007.
       https://www.stat.washington.edu/raftery/Research/PDF/Gneiting2007jasa.pdf
  [2]: Jochen Broecker.  Reliability, sufficiency, and the decomposition of
       proper scores.
       Quarterly Journal of the Royal Meteorological Society, Vol. 135, 2009.
       https://rmets.onlinelibrary.wiley.com/doi/epdf/10.1002/qj.456

  Args:
    labels: Tensor, (n,), with tf.int32 or tf.int64 elements containing ground
      truth class labels in the range [0, nlabels).
    logits: Tensor, (n, nlabels), with logits for n instances and nlabels.
    name: Python `str` name prefixed to Ops created by this function.

  Returns:
    uncertainty: Tensor, scalar, the uncertainty component of the
      decomposition.
    resolution: Tensor, scalar, the resolution component of the decomposition.
    reliability: Tensor, scalar, the reliability component of the
      decomposition.
  """
    with tf.name_scope(name or 'brier_decomposition'):
        labels = tf.convert_to_tensor(labels)
        logits = tf.convert_to_tensor(logits)
        num_classes = logits.shape[-1]

        # Compute pbar, the average distribution
        pred_class = tf.argmax(logits, axis=-1, output_type=labels.dtype)

        if tensorshape_util.rank(logits.shape) > 2:
            flatten, unflatten = _make_flatten_unflatten_fns(logits.shape[:-2])

            def fn_to_map(args):
                yhat, y = args
                return tf.math.confusion_matrix(yhat,
                                                y,
                                                num_classes=num_classes,
                                                dtype=logits.dtype)

            confusion_matrix = tf.map_fn(
                fn_to_map,
                [flatten(pred_class), flatten(labels)],
                fn_output_signature=logits.dtype)
            confusion_matrix = unflatten(confusion_matrix)
        else:
            confusion_matrix = tf.math.confusion_matrix(
                pred_class,
                labels,
                num_classes=num_classes,
                dtype=logits.dtype)

        dist_weights = tf.reduce_sum(confusion_matrix, axis=-1)
        dist_weights /= tf.reduce_sum(dist_weights, axis=-1, keepdims=True)
        pbar = tf.reduce_sum(confusion_matrix, axis=-2)
        pbar /= tf.reduce_sum(pbar, axis=-1, keepdims=True)

        eps = np.finfo(dtype_util.as_numpy_dtype(confusion_matrix.dtype)).eps
        # dist_mean[k,:] contains the empirical distribution for the set M_k
        # Some outcomes may not realize, corresponding to dist_weights[k] = 0
        dist_mean = confusion_matrix / (
            eps + tf.reduce_sum(confusion_matrix, axis=-1, keepdims=True))

        # Uncertainty: quadratic entropy of the average label distribution
        uncertainty = -tf.reduce_sum(tf.square(pbar), axis=-1)

        # Resolution: expected quadratic divergence of predictive to mean
        resolution = tf.square(tf.expand_dims(pbar, -1) - dist_mean)
        resolution = tf.reduce_sum(dist_weights *
                                   tf.reduce_sum(resolution, axis=-1),
                                   axis=-1)

        # Reliability: expected quadratic divergence of predictive to true
        if tensorshape_util.rank(logits.shape) > 2:
            # TODO(b/139094519): Avoid using tf.map_fn here.
            prob_true = tf.map_fn(
                lambda args: tf.gather(args[0], args[1]),
                [flatten(dist_mean), flatten(pred_class)],
                fn_output_signature=dist_mean.dtype)
            prob_true = unflatten(prob_true)
        else:
            prob_true = tf.gather(dist_mean, pred_class, axis=0)
        log_prob_true = tf.math.log(prob_true)

        log_prob_pred = logits - tf.math.reduce_logsumexp(
            logits, axis=-1, keepdims=True)

        log_reliability = _reduce_log_l2_exp(log_prob_pred,
                                             log_prob_true,
                                             axis=-1)
        log_reliability = tf.math.reduce_logsumexp(
            log_reliability,
            axis=-1,
        )

        num_samples = tf.cast(tf.shape(logits)[-2], logits.dtype)
        reliability = tf.exp(log_reliability - tf.math.log(num_samples))

        return uncertainty, resolution, reliability
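# A hedged usage sketch: three classes, four instances. Per the docstring above, the
# three returned scalars approximately satisfy
# brier_score = uncertainty - resolution + reliability.
import tensorflow as tf

labels = tf.constant([0, 1, 2, 1])
logits = tf.math.log([[0.7, 0.2, 0.1],
                      [0.1, 0.8, 0.1],
                      [0.2, 0.2, 0.6],
                      [0.3, 0.4, 0.3]])
uncertainty, resolution, reliability = brier_decomposition(labels, logits)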
Example #15
    def __init__(self,
                 cat,
                 components,
                 validate_args=False,
                 allow_nan_stats=True,
                 use_static_graph=False,
                 name='Mixture'):
        """Initialize a Mixture distribution.

    A `Mixture` is defined by a `Categorical` (`cat`, representing the
    mixture probabilities) and a list of `Distribution` objects
    all having matching dtype, batch shape, event shape, support, and continuity
    properties (the components).

    The `num_classes` of `cat` must be possible to infer at graph construction
    time and match `len(components)`.

    Args:
      cat: A `Categorical` distribution instance, representing the probabilities
        of `components`.
      components: A list or tuple of `Distribution` instances.
        Each instance must have the same type, be defined on the same domain,
        and have matching `event_shape` and `batch_shape`.
      validate_args: Python `bool`, default `False`. If `True`, raise a runtime
        error if batch or event ranks are inconsistent between cat and any of
        the distributions. This is only checked if the ranks cannot be
        determined statically at graph construction time.
      allow_nan_stats: Boolean, default `True`. If `False`, raise an
        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
        batch member. If `True`, batch members with valid parameters leading to
        undefined statistics will return NaN for this statistic.
      use_static_graph: Calls to `sample` will not rely on dynamic tensor
        indexing, allowing for some static graph compilation optimizations, but
        at the expense of sampling all underlying distributions in the mixture.
        (Possibly useful when running on TPUs).
        Default value: `False` (i.e., use dynamic indexing).
      name: A name for this distribution (optional).

    Raises:
      TypeError: If cat is not a `Categorical`, or `components` is not
        a list or tuple, or the elements of `components` are not
        instances of `Distribution`, or do not have matching `dtype`.
      ValueError: If `components` is an empty list or tuple, or its
        elements do not have a statically known event rank.
        If `cat.num_classes` cannot be inferred at graph creation time,
        or the constant value of `cat.num_classes` is not equal to
        `len(components)`, or all `components` and `cat` do not have
        matching static batch shapes, or all components do not
        have matching static event shapes.
    """
        parameters = dict(locals())
        with tf.name_scope(name) as name:

            if not isinstance(cat, categorical.Categorical):
                raise TypeError(
                    'cat must be a Categorical distribution, but saw: %s' %
                    cat)
            if not components:
                raise ValueError(
                    'components must be a non-empty list or tuple')
            if not isinstance(components, (list, tuple)):
                raise TypeError(
                    'components must be a list or tuple, but saw: %s' %
                    components)
            if not all(
                    isinstance(c, distribution.Distribution)
                    for c in components):
                raise TypeError(
                    'all entries in components must be Distribution instances'
                    ' but saw: %s' % components)

            dtype = components[0].dtype
            if not all(d.dtype == dtype for d in components):
                raise TypeError(
                    'All components must have the same dtype, but saw '
                    'dtypes: %s' % [(d.name, d.dtype) for d in components])

            static_event_shape = components[0].event_shape
            static_batch_shape = cat.batch_shape
            for di, d in enumerate(components):
                if not tensorshape_util.is_compatible_with(
                        static_batch_shape, d.batch_shape):
                    raise ValueError(
                        'components[{}] batch shape must be compatible with cat '
                        'shape and other component batch shapes ({} vs {})'.
                        format(di, static_batch_shape, d.batch_shape))
                if not tensorshape_util.is_compatible_with(
                        static_event_shape, d.event_shape):
                    raise ValueError(
                        'components[{}] event shape must be compatible with other '
                        'component event shapes ({} vs {})'.format(
                            di, static_event_shape, d.event_shape))
                static_event_shape = tensorshape_util.merge_with(
                    static_event_shape, d.event_shape)
                static_batch_shape = tensorshape_util.merge_with(
                    static_batch_shape, d.batch_shape)
            if tensorshape_util.rank(static_event_shape) is None:
                raise ValueError(
                    'Expected to know rank(event_shape) from components, but '
                    'none of the components provide a static number of ndims')

            # pylint: disable=protected-access
            cat_dist_param = cat._probs if cat._logits is None else cat._logits
            # pylint: enable=protected-access
            static_num_components = tf.compat.dimension_value(
                cat_dist_param.shape[-1])
            if static_num_components is None:
                raise ValueError(
                    'Could not infer number of classes from cat and unable '
                    'to compare this value to the number of components passed in.'
                )
            if static_num_components != len(components):
                raise ValueError(
                    'cat.num_classes != len(components): %d vs. %d' %
                    (static_num_components, len(components)))

            self._cat = cat
            self._components = list(components)
            self._num_components = static_num_components
            self._static_event_shape = static_event_shape
            self._static_batch_shape = static_batch_shape
            self._use_static_graph = use_static_graph

            super(Mixture, self).__init__(
                dtype=dtype,
                reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                parameters=parameters,
                name=name)
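# A hedged usage sketch for the constructor above: a two-component scalar Gaussian
# mixture. `cat` supplies the mixing weights, and the components share dtype, batch
# shape, and event shape, as the constructor's checks require.
import tensorflow_probability as tfp

tfd = tfp.distributions
mix = tfd.Mixture(
    cat=tfd.Categorical(probs=[0.3, 0.7]),
    components=[
        tfd.Normal(loc=-1., scale=0.5),
        tfd.Normal(loc=2., scale=1.0),
    ])
samples = mix.sample(5)    # shape [5]
lp = mix.log_prob(0.)      # scalar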
Example #16
    def _test_slicing(self, data, dist):
        strm = tfp_test_util.test_seed_stream()
        batch_shape = dist.batch_shape
        slices = data.draw(valid_slices(batch_shape))
        slice_str = 'dist[{}]'.format(', '.join(stringify_slices(slices)))
        # Make sure the slice string appears in Hypothesis' attempted example log
        hp.note('Using slice ' + slice_str)
        if not slices:  # Nothing further to check.
            return
        sliced_zeros = np.zeros(batch_shape)[slices]
        sliced_dist = dist[slices]

        # Check that slicing modifies batch shape as expected.
        self.assertAllEqual(sliced_zeros.shape, sliced_dist.batch_shape)

        if not sliced_zeros.size:
            # TODO(b/128924708): Fix distributions that fail on degenerate empty
            #     shapes, e.g. Multinomial, DirichletMultinomial, ...
            return

        # Check that sampling of sliced distributions executes.
        with no_tf_rank_errors():
            samples = self.evaluate(dist.sample(seed=strm()))
            sliced_samples = self.evaluate(sliced_dist.sample(seed=strm()))

        # Come up with the slices for samples (which must also include event dims).
        sample_slices = (tuple(slices) if isinstance(
            slices, collections.abc.Sequence) else (slices, ))
        if Ellipsis not in sample_slices:
            sample_slices += (Ellipsis, )
        sample_slices += tuple([slice(None)] *
                               tensorshape_util.rank(dist.event_shape))

        # Report sub-sliced samples (on which we compare log_prob) to hypothesis.
        hp.note('Sample(s) for testing log_prob ' +
                str(samples[sample_slices]))

        # Check that sampling a sliced distribution produces the same shape as
        # slicing the samples from the original.
        self.assertAllEqual(samples[sample_slices].shape, sliced_samples.shape)

        # Check that a sliced distribution can compute the log_prob of its own
        # samples (up to numerical validation errors).
        with no_tf_rank_errors():
            try:
                lp = self.evaluate(dist.log_prob(samples))
            except tf.errors.InvalidArgumentError:
                # TODO(b/129271256): d.log_prob(d.sample()) should not fail
                #     validate_args checks.
                # We only tolerate this case for the non-sliced dist.
                return
            sliced_lp = self.evaluate(
                sliced_dist.log_prob(samples[sample_slices]))

        # Check that the sliced dist's log_prob agrees with slicing the original's
        # log_prob.
        # TODO(b/128708201): Better numerics for Geometric/Beta?
        # Eigen can return quite different results for packet vs non-packet ops.
        # To work around this, we use a much larger rtol for the last 3
        # (assuming packet size 4) elements.
        packetized_lp = lp[slices].reshape(-1)[:-3]
        packetized_sliced_lp = sliced_lp.reshape(-1)[:-3]
        rtol = (0.1 if any(x in dist.name for x in ('Geometric', 'Beta',
                                                    'Dirichlet')) else 0.02)
        self.assertAllClose(packetized_lp, packetized_sliced_lp, rtol=rtol)
        possibly_nonpacket_lp = lp[slices].reshape(-1)[-3:]
        possibly_nonpacket_sliced_lp = sliced_lp.reshape(-1)[-3:]
        self.assertAllClose(possibly_nonpacket_lp,
                            possibly_nonpacket_sliced_lp,
                            rtol=0.4,
                            atol=1e-4)
Example #17
def gather_squeeze(params, indices):
    rank = tensorshape_util.rank(indices.shape)
    if rank is None:
        raise ValueError('`indices` must have statically known rank.')
    return tf.gather(params, indices, axis=-1,
                     batch_dims=rank - 1)[..., 0]
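A small, self-contained illustration (made-up values) of the batched gather this helper performs: one index per batch row is looked up along the last axis and the trailing singleton dimension is dropped.

```python
import tensorflow as tf

params = tf.constant([[10., 11., 12.],
                      [20., 21., 22.]])   # shape [2, 3]
indices = tf.constant([[2], [0]])         # one index per batch row, shape [2, 1]

# Same operation as gather_squeeze(params, indices):
out = tf.gather(params, indices, axis=-1, batch_dims=1)[..., 0]
# ==> [12., 20.]
```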
Example #18
def pivoted_cholesky(matrix, max_rank, diag_rtol=1e-3, name=None):
    """Computes the (partial) pivoted cholesky decomposition of `matrix`.

  The pivoted Cholesky is a low rank approximation of the Cholesky decomposition
  of `matrix`, i.e. as described in [(Harbrecht et al., 2012)][1]. The
  currently-worst-approximated diagonal element is selected as the pivot at each
  iteration. This yields from a `[B1...Bn, N, N]` shaped `matrix` a `[B1...Bn,
  N, K]` shaped rank-`K` approximation `lr` such that `lr @ lr.T ~= matrix`.
  Note that, unlike the Cholesky decomposition, `lr` is not triangular even in
  a rectangular-matrix sense. However, under a permutation it could be made
  triangular (it has one more zero in each column as you move to the right).

  Such a matrix can be useful as a preconditioner for conjugate gradient
  optimization, i.e. as in [(Wang et al. 2019)][2], as matmuls and solves can be
  cheaply done via the Woodbury matrix identity, as implemented by
  `tf.linalg.LinearOperatorLowRankUpdate`.

  Args:
    matrix: Floating point `Tensor` batch of symmetric, positive definite
      matrices.
    max_rank: Scalar `int` `Tensor`, the rank at which to truncate the
      approximation.
    diag_rtol: Scalar floating point `Tensor` (same dtype as `matrix`). If the
      errors of all diagonal elements of `lr @ lr.T` are each lower than
      `element * diag_rtol`, iteration is permitted to terminate early.
    name: Optional name for the op.

  Returns:
    lr: Low rank pivoted Cholesky approximation of `matrix`.

  #### References

  [1]: H Harbrecht, M Peters, R Schneider. On the low-rank approximation by the
       pivoted Cholesky decomposition. _Applied numerical mathematics_,
       62(4):428-440, 2012.

  [2]: K. A. Wang et al. Exact Gaussian Processes on a Million Data Points.
       _arXiv preprint arXiv:1903.08114_, 2019. https://arxiv.org/abs/1903.08114
  """
    with tf.compat.v2.name_scope(name or 'pivoted_cholesky'):
        dtype = dtype_util.common_dtype([matrix, diag_rtol],
                                        preferred_dtype=tf.float32)
        matrix = tf.convert_to_tensor(value=matrix, name='matrix', dtype=dtype)
        if tensorshape_util.rank(matrix.shape) is None:
            raise NotImplementedError(
                'Rank of `matrix` must be known statically')

        max_rank = tf.convert_to_tensor(value=max_rank,
                                        name='max_rank',
                                        dtype=tf.int64)
        max_rank = tf.minimum(
            max_rank,
            prefer_static.shape(matrix, out_type=tf.int64)[-1])
        diag_rtol = tf.convert_to_tensor(value=diag_rtol,
                                         dtype=dtype,
                                         name='diag_rtol')
        matrix_diag = tf.linalg.diag_part(matrix)
        # matrix is P.D., therefore all matrix_diag > 0, so we don't need abs.
        orig_error = tf.reduce_max(input_tensor=matrix_diag, axis=-1)

        def cond(m, pchol, perm, matrix_diag):
            """Condition for `tf.while_loop` continuation."""
            del pchol
            del perm
            error = tf.linalg.norm(tensor=matrix_diag, ord=1, axis=-1)
            max_err = tf.reduce_max(input_tensor=error / orig_error)
            return (m < max_rank) & (tf.equal(m, 0) | (max_err > diag_rtol))

        batch_dims = tensorshape_util.rank(matrix.shape) - 2

        def batch_gather(params, indices, axis=-1):
            return tf.gather(params, indices, axis=axis, batch_dims=batch_dims)

        def body(m, pchol, perm, matrix_diag):
            """Body of a single `tf.while_loop` iteration."""
            # Here is roughly a numpy, non-batched version of what's going to happen.
            # (See also Algorithm 1 of Harbrecht et al.)
            # 1: maxi = np.argmax(matrix_diag[perm[m:]]) + m
            # 2: maxval = matrix_diag[perm][maxi]
            # 3: perm[m], perm[maxi] = perm[maxi], perm[m]
            # 4: row = matrix[perm[m]][perm[m + 1:]]
            # 5: row -= np.sum(pchol[:m][perm[m + 1:]] * pchol[:m][perm[m]], axis=-2)
            # 6: pivot = np.sqrt(maxval); row /= pivot
            # 7: row = np.concatenate([[[pivot]], row], -1)
            # 8: matrix_diag[perm[m:]] -= row**2
            # 9: pchol[m, perm[m:]] = row

            # Find the maximal position of the (remaining) permuted diagonal.
            # Steps 1, 2 above.
            permuted_diag = batch_gather(matrix_diag, perm[..., m:])
            maxi = tf.argmax(input=permuted_diag,
                             axis=-1,
                             output_type=tf.int64)[..., tf.newaxis]
            maxval = batch_gather(permuted_diag, maxi)
            maxi = maxi + m
            maxval = maxval[..., 0]
            # Update perm: Swap perm[...,m] with perm[...,maxi]. Step 3 above.
            perm = _swap_m_with_i(perm, m, maxi)
            # Step 4.
            row = batch_gather(matrix, perm[..., m:m + 1], axis=-2)
            row = batch_gather(row, perm[..., m + 1:])
            # Step 5.
            prev_rows = pchol[..., :m, :]
            prev_rows_perm_m_onward = batch_gather(prev_rows, perm[...,
                                                                   m + 1:])
            prev_rows_pivot_col = batch_gather(prev_rows, perm[..., m:m + 1])
            row -= tf.reduce_sum(input_tensor=prev_rows_perm_m_onward *
                                 prev_rows_pivot_col,
                                 axis=-2)[..., tf.newaxis, :]
            # Step 6.
            pivot = tf.sqrt(maxval)[..., tf.newaxis, tf.newaxis]
            # Step 7.
            row = tf.concat([pivot, row / pivot], axis=-1)
            # TODO(b/130899118): Pad grad fails with int64 paddings.
            # Step 8.
            paddings = tf.concat([
                tf.zeros([prefer_static.rank(pchol) - 1, 2], dtype=tf.int32),
                [[tf.cast(m, tf.int32), 0]]
            ],
                                 axis=0)
            diag_update = tf.pad(tensor=row**2, paddings=paddings)[..., 0, :]
            reverse_perm = _invert_permutation(perm)
            matrix_diag -= batch_gather(diag_update, reverse_perm)
            # Step 9.
            row = tf.pad(tensor=row, paddings=paddings)
            # TODO(bjp): Defer the reverse permutation all-at-once at the end?
            row = batch_gather(row, reverse_perm)
            pchol_shape = pchol.shape
            pchol = tf.concat([pchol[..., :m, :], row, pchol[..., m + 1:, :]],
                              axis=-2)
            tensorshape_util.set_shape(pchol, pchol_shape)
            return m + 1, pchol, perm, matrix_diag

        m = np.int64(0)
        pchol = tf.zeros_like(matrix[..., :max_rank, :])
        matrix_shape = prefer_static.shape(matrix, out_type=tf.int64)
        perm = tf.broadcast_to(prefer_static.range(matrix_shape[-1]),
                               matrix_shape[:-1])
        _, pchol, _, _ = tf.while_loop(cond=cond,
                                       body=body,
                                       loop_vars=(m, pchol, perm, matrix_diag))
        pchol = tf.linalg.matrix_transpose(pchol)
        tensorshape_util.set_shape(
            pchol, tensorshape_util.concatenate(matrix_diag.shape, [None]))
        return pchol
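A short usage sketch (illustrative values): build a small SPD matrix, take its rank-3 pivoted Cholesky via the public `tfp.math.pivoted_cholesky`, and wrap the factor as a low-rank preconditioner as the docstring suggests.

```python
import tensorflow as tf
import tensorflow_probability as tfp

# Small SPD matrix: A = B B^T + eps * I.
b = tf.random.normal([6, 6], dtype=tf.float64)
matrix = tf.matmul(b, b, transpose_b=True) + 1e-3 * tf.eye(6, dtype=tf.float64)

lr = tfp.math.pivoted_cholesky(matrix, max_rank=3)  # shape [6, 3]
approx = tf.matmul(lr, lr, transpose_b=True)        # lr @ lr.T ~= matrix

# Cheap matmuls/solves via the Woodbury identity.
precond = tf.linalg.LinearOperatorLowRankUpdate(
    base_operator=tf.linalg.LinearOperatorIdentity(6, dtype=tf.float64),
    u=lr)
```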
Example #19
    def __init__(self,
                 initial_distribution,
                 transition_distribution,
                 observation_distribution,
                 num_steps,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="HiddenMarkovModel"):
        """Initialize hidden Markov model.

    Args:
      initial_distribution: A `Categorical`-like instance.
        Determines probability of first hidden state in Markov chain.
        The number of categories must match the number of categories of
        `transition_distribution` as well as both the rightmost batch
        dimension of `transition_distribution` and the rightmost batch
        dimension of `observation_distribution`.
      transition_distribution: A `Categorical`-like instance.
        The rightmost batch dimension indexes the probability distribution
        of each hidden state conditioned on the previous hidden state.
      observation_distribution: A `tfp.distributions.Distribution`-like
        instance.  The rightmost batch dimension indexes the distribution
        of each observation conditioned on the corresponding hidden state.
      num_steps: The number of steps taken in Markov chain. A python `int`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
        Default value: `False`.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
        Default value: `True`.
      name: Python `str` name prefixed to Ops created by this class.
        Default value: "HiddenMarkovModel".

    Raises:
      ValueError: if `num_steps` is not at least 1.
      ValueError: if `initial_distribution` does not have scalar `event_shape`.
      ValueError: if `transition_distribution` does not have scalar
        `event_shape.`
      ValueError: if `transition_distribution` and `observation_distribution`
        are fully defined but don't have matching rightmost dimension.
    """

        parameters = dict(locals())

        # pylint: disable=protected-access
        with tf.name_scope(name) as name:
            self._runtime_assertions = []  # pylint: enable=protected-access

            num_steps = tf.convert_to_tensor(value=num_steps, name="num_steps")
            if validate_args:
                self._runtime_assertions += [
                    assert_util.assert_equal(
                        tf.rank(num_steps),
                        0,
                        message="`num_steps` must be a scalar")
                ]
                self._runtime_assertions += [
                    assert_util.assert_greater_equal(
                        num_steps,
                        1,
                        message="`num_steps` must be at least 1.")
                ]

            self._initial_distribution = initial_distribution
            self._observation_distribution = observation_distribution
            self._transition_distribution = transition_distribution

            if (initial_distribution.event_shape is not None
                    and tensorshape_util.rank(
                        initial_distribution.event_shape) != 0):
                raise ValueError(
                    "`initial_distribution` must have scalar `event_dim`s")
            elif validate_args:
                self._runtime_assertions += [
                    assert_util.assert_equal(
                        tf.shape(initial_distribution.event_shape_tensor())[0],
                        0,
                        message="`initial_distribution` must have scalar"
                        "`event_dim`s")
                ]

            if (transition_distribution.event_shape is not None
                    and tensorshape_util.rank(
                        transition_distribution.event_shape) != 0):
                raise ValueError(
                    "`transition_distribution` must have scalar `event_dim`s")
            elif validate_args:
                self._runtime_assertions += [
                    assert_util.assert_equal(
                        tf.shape(
                            transition_distribution.event_shape_tensor())[0],
                        0,
                        message="`transition_distribution` must have scalar"
                        "`event_dim`s")
                ]

            if (tensorshape_util.dims(transition_distribution.batch_shape)
                    is not None and tensorshape_util.rank(
                        transition_distribution.batch_shape) == 0):
                raise ValueError(
                    "`transition_distribution` can't have scalar batches")
            elif validate_args:
                self._runtime_assertions += [
                    assert_util.assert_greater(
                        tf.size(transition_distribution.batch_shape_tensor()),
                        0,
                        message="`transition_distribution` can't have scalar "
                        "batches")
                ]

            if (tensorshape_util.dims(observation_distribution.batch_shape)
                    is not None and tensorshape_util.rank(
                        observation_distribution.batch_shape) == 0):
                raise ValueError(
                    "`observation_distribution` can't have scalar batches")
            elif validate_args:
                self._runtime_assertions += [
                    assert_util.assert_greater(
                        tf.size(observation_distribution.batch_shape_tensor()),
                        0,
                        message="`observation_distribution` can't have scalar "
                        "batches")
                ]

            # Infer number of hidden states and check consistency
            # between transitions and observations
            with tf.control_dependencies(self._runtime_assertions):
                self._num_states = (
                    (tensorshape_util.dims(transition_distribution.batch_shape)
                     is not None and tensorshape_util.as_list(
                         transition_distribution.batch_shape)[-1])
                    or transition_distribution.batch_shape_tensor()[-1])

                observation_states = (
                    (tensorshape_util.dims(
                        observation_distribution.batch_shape) is not None
                     and tensorshape_util.as_list(
                         observation_distribution.batch_shape)[-1])
                    or observation_distribution.batch_shape_tensor()[-1])

            if (tf.is_tensor(self._num_states)
                    or tf.is_tensor(observation_states)):
                if validate_args:
                    self._runtime_assertions += [
                        assert_util.assert_equal(
                            self._num_states,
                            observation_states,
                            message="`transition_distribution` and "
                            "`observation_distribution` must agree on "
                            "last dimension of batch size")
                    ]
            elif self._num_states != observation_states:
                raise ValueError("`transition_distribution` and "
                                 "`observation_distribution` must agree on "
                                 "last dimension of batch size")

            self._log_init = _extract_log_probs(self._num_states,
                                                initial_distribution)
            self._log_trans = _extract_log_probs(self._num_states,
                                                 transition_distribution)

            self._num_steps = num_steps
            self._num_states = tf.shape(self._log_init)[-1]

            self._underlying_event_rank = tf.size(
                self._observation_distribution.event_shape_tensor())

            num_steps_ = tf.get_static_value(num_steps)
            if num_steps_ is not None:
                self.static_event_shape = tf.TensorShape([
                    num_steps_
                ]).concatenate(self._observation_distribution.event_shape)
            else:
                self.static_event_shape = None

            with tf.control_dependencies(self._runtime_assertions):
                self.static_batch_shape = tf.broadcast_static_shape(
                    self._initial_distribution.batch_shape,
                    tf.broadcast_static_shape(
                        self._transition_distribution.batch_shape[:-1],
                        self._observation_distribution.batch_shape[:-1]))

            # pylint: disable=protected-access
            super(HiddenMarkovModel, self).__init__(
                dtype=self._observation_distribution.dtype,
                reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                parameters=parameters,
                name=name)
            # pylint: enable=protected-access

            self._parameters = parameters
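A minimal usage sketch of the constructor above (values are illustrative): two hidden states with Gaussian emissions over 7 steps.

```python
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

hmm = tfd.HiddenMarkovModel(
    initial_distribution=tfd.Categorical(probs=[0.8, 0.2]),
    transition_distribution=tfd.Categorical(probs=[[0.7, 0.3],
                                                   [0.2, 0.8]]),
    observation_distribution=tfd.Normal(loc=[0., 15.], scale=[5., 10.]),
    num_steps=7)

obs = hmm.sample()              # a length-7 observation sequence
lp = hmm.log_prob(tf.zeros(7))  # log-likelihood of an all-zeros sequence
```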
Example #20
def _ones_like(input, dtype=None, name=None):  # pylint: disable=redefined-builtin
  s = _shape(input)
  if is_numpy(s):
    return np.ones(s, _numpy_dtype(dtype or input.dtype))
  return tf.ones(s, dtype or input.dtype, name)
ones_like = _copy_docstring(tf.ones_like, _ones_like)

range = _prefer_static(  # pylint: disable=redefined-builtin
    tf.range,
    lambda start, limit=None, delta=1, dtype=None, name='range': np.arange(  # pylint: disable=g-long-lambda
        start, limit, delta, _numpy_dtype(dtype)))

rank = _copy_docstring(
    tf.rank,
    lambda input, name=None: (  # pylint: disable=redefined-builtin,g-long-lambda
        tf.rank(input) if tensorshape_util.rank(input.shape) is None else
        tensorshape_util.rank(input.shape)))

reduce_all = _prefer_static(
    tf.reduce_all,
    lambda input_tensor, axis=None, keepdims=False, name=None: np.all(  # pylint: disable=g-long-lambda
        input_tensor, axis, keepdims=keepdims))

reduce_any = _prefer_static(
    tf.reduce_any,
    lambda input_tensor, axis=None, keepdims=False, name=None: np.any(  # pylint: disable=g-long-lambda
        input_tensor, axis, keepdims=keepdims))

reduce_prod = _prefer_static(
    tf.reduce_prod,
    lambda input_tensor, axis=None, keepdims=False, name=None: np.prod(  # pylint: disable=g-long-lambda
        input_tensor, axis, keepdims=keepdims))
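A small sketch of the prefer-static behavior these wrappers provide (assuming the internal module path below): when the inputs are already NumPy values, the NumPy branch runs and a concrete value comes back; otherwise the TF op is used and a `Tensor` is returned.

```python
import numpy as np
import tensorflow as tf
from tensorflow_probability.python.internal import prefer_static as ps

x = tf.zeros([3, 4])
ps.rank(x)                        # ==> 2, a plain Python int (shape is static)
ps.reduce_prod(np.array([2, 3]))  # ==> 6 as a NumPy value (static branch)
ps.reduce_prod(tf.shape(x))       # falls back to tf.reduce_prod, returns a Tensor
```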
Example #21
def _event_shape(self):
  if tensorshape_util.rank(self.samples.shape) is None:
    return tf.TensorShape(None)
  return self.samples.shape[self._samples_axis + 1:]
def _kl_independent(a, b, name='kl_independent'):
    """Batched KL divergence `KL(a || b)` for Independent distributions.

  We can leverage the fact that
  ```
  KL(Independent(a) || Independent(b)) = sum(KL(a || b))
  ```
  where the sum is over the `reinterpreted_batch_ndims`.

  Args:
    a: Instance of `Independent`.
    b: Instance of `Independent`.
    name: (optional) name to use for created ops. Default 'kl_independent'.

  Returns:
    Batchwise `KL(a || b)`.

  Raises:
    ValueError: If the event space for `a` and `b`, or their underlying
      distributions don't match.
  """
    p = a.distribution
    q = b.distribution

    # The KL between any two (non)-batched distributions is a scalar.
    # Given that the KL between two factored distributions is the sum, i.e.
    # KL(p1(x)p2(y) || q1(x)q2(y)) = KL(p1 || q1) + KL(q1 || q2), we compute
    # KL(p || q) and do a `reduce_sum` on the reinterpreted batch dimensions.
    if (tensorshape_util.is_fully_defined(a.event_shape)
            and tensorshape_util.is_fully_defined(b.event_shape)):
        if a.event_shape == b.event_shape:
            if p.event_shape == q.event_shape:
                num_reduce_dims = (tensorshape_util.rank(a.event_shape) -
                                   tensorshape_util.rank(p.event_shape))
                reduce_dims = [-i - 1 for i in range(0, num_reduce_dims)]

                return tf.reduce_sum(kullback_leibler.kl_divergence(p,
                                                                    q,
                                                                    name=name),
                                     axis=reduce_dims)
            else:
                raise NotImplementedError(
                    'KL between Independents with different '
                    'event shapes not supported.')
        else:
            raise ValueError('Event shapes do not match.')
    else:
        p_event_shape_tensor = p.event_shape_tensor()
        q_event_shape_tensor = q.event_shape_tensor()
        # NOTE: We could optimize by passing the event_shape_tensor of p and q
        # to a.event_shape_tensor() and b.event_shape_tensor().
        a_event_shape_tensor = a.event_shape_tensor()
        b_event_shape_tensor = b.event_shape_tensor()
        with tf.control_dependencies([
                assert_util.assert_equal(a_event_shape_tensor,
                                         b_event_shape_tensor,
                                         message='Event shapes do not match.'),
                assert_util.assert_equal(p_event_shape_tensor,
                                         q_event_shape_tensor,
                                         message='Event shapes do not match.'),
        ]):
            num_reduce_dims = (prefer_static.rank_from_shape(
                a_event_shape_tensor, a.event_shape) -
                               prefer_static.rank_from_shape(
                                   p_event_shape_tensor, p.event_shape))
            reduce_dims = prefer_static.range(-num_reduce_dims, 0, 1)
            return tf.reduce_sum(kullback_leibler.kl_divergence(p,
                                                                q,
                                                                name=name),
                                 axis=reduce_dims)
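A small sketch of the identity used here (standard `tfd` alias assumed): the KL between two `Independent` Normals over R^3 equals the sum of the three coordinate-wise KLs.

```python
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

a = tfd.Independent(tfd.Normal(loc=tf.zeros(3), scale=1.),
                    reinterpreted_batch_ndims=1)
b = tfd.Independent(tfd.Normal(loc=tf.ones(3), scale=2.),
                    reinterpreted_batch_ndims=1)

kl_joint = tfd.kl_divergence(a, b)                                   # scalar
kl_sum = tf.reduce_sum(tfd.kl_divergence(a.distribution, b.distribution))
# kl_joint == kl_sum (up to floating-point error)
```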
Example #23
def find_bins(x,
              edges,
              extend_lower_interval=False,
              extend_upper_interval=False,
              dtype=None,
              name=None):
    """Bin values into discrete intervals.

  Given `edges = [c0, ..., cK]`, defining intervals
  `I0 = [c0, c1)`, `I1 = [c1, c2)`, ..., `I_{K-1} = [c_{K-1}, cK]`,
  this function returns `bins`, such that:
  `edges[bins[i]] <= x[i] < edges[bins[i] + 1]`.

  Args:
    x:  Numeric `N-D` `Tensor` with `N > 0`.
    edges:  `Tensor` of same `dtype` as `x`.  The first dimension indexes edges
      of intervals.  Must either be `1-D` or have
      `x.shape[1:] == edges.shape[1:]`.  If `rank(edges) > 1`, `edges[k]`
      designates a shape `edges.shape[1:]` `Tensor` of bin edges for the
      corresponding dimensions of `x`.
    extend_lower_interval:  Python `bool`.  If `True`, extend the lowest
      interval `I0` to `(-inf, c1]`.
    extend_upper_interval:  Python `bool`.  If `True`, extend the upper
      interval `I_{K-1}` to `[c_{K-1}, +inf)`.
    dtype: The output type (`int32` or `int64`). `Default value:` `x.dtype`.
      This affects the output values when `x` is below/above the intervals,
      which will be `-1`/`K+1` for `int` types and `NaN` for `float`s.
      At indices where `x` is `NaN`, the output values will be `0` for `int`
      types and `NaN` for floats.
    name:  A Python string name to prepend to created ops. Default: 'find_bins'

  Returns:
    bins: `Tensor` with same `shape` as `x` and `dtype`.
      Has whole number values.  `bins[i] = k` means the `x[i]` falls into the
      `kth` bin, ie, `edges[bins[i]] <= x[i] < edges[bins[i] + 1]`.

  Raises:
    ValueError:  If `edges.shape[0]` is determined to be less than 2.

  #### Examples

  Cut a `1-D` array

  ```python
  x = [0., 5., 6., 10., 20.]
  edges = [0., 5., 10.]
  tfp.stats.find_bins(x, edges)
  ==> [0., 0., 1., 1., np.nan]
  ```

  Cut `x` into its deciles

  ```python
  x = tf.random.uniform(shape=(100, 200))
  decile_edges = tfp.stats.quantiles(x, num_quantiles=10)
  bins = tfp.stats.find_bins(x, edges=decile_edges)
  bins.shape
  ==> (100, 200)
  tf.reduce_mean(bins == 0.)
  ==> approximately 0.1
  tf.reduce_mean(bins == 1.)
  ==> approximately 0.1
  ```

  """
    # TFP users may be surprised to see the "action" in the leftmost dim of
    # edges, rather than the rightmost (event) dim.  Why?
    # 1. Most likely you created edges by getting quantiles over samples, and
    #    quantile/percentile return these edges in the leftmost (sample) dim.
    # 2. Say you have event_shape = [5], then we expect the bin will be different
    #    for all 5 events, so the index of the bin should not be in the event dim.
    with tf.name_scope(name or 'find_bins'):
        in_type = dtype_util.common_dtype([x, edges], dtype_hint=tf.float32)
        edges = tf.convert_to_tensor(edges, name='edges', dtype=in_type)
        x = tf.convert_to_tensor(x, name='x', dtype=in_type)

        if (tf.compat.dimension_value(edges.shape[0]) is not None
                and tf.compat.dimension_value(edges.shape[0]) < 2):
            raise ValueError(
                'First dimension of `edges` must have length > 1 to index 1 or '
                'more bin. Found: {}'.format(edges.shape))

        flattening_x = (tensorshape_util.rank(edges.shape) == 1
                        and tensorshape_util.rank(x.shape) > 1)

        if flattening_x:
            x_orig_shape = ps.shape(x)
            x = tf.reshape(x, [-1])

        if dtype is None:
            dtype = in_type
        dtype = tf.as_dtype(dtype)

        # Move first dims into the rightmost.
        x_permed = distribution_util.rotate_transpose(x, shift=-1)
        edges_permed = distribution_util.rotate_transpose(edges, shift=-1)

        # If...
        #   x_permed = [0, 1, 6., 10]
        #   edges = [0, 5, 10.]
        #   ==> almost_output = [0, 1, 2, 2]
        searchsorted_type = dtype if dtype in [tf.int32, tf.int64] else None
        almost_output_permed = tf.searchsorted(sorted_sequence=edges_permed,
                                               values=x_permed,
                                               side='right',
                                               out_type=searchsorted_type)
        # Move the rightmost dims back to the leftmost.
        almost_output = tf.cast(
            distribution_util.rotate_transpose(almost_output_permed, shift=1),
            dtype)

        # In above example, we want [0, 0, 1, 1], so correct this here.
        bins = tf.clip_by_value(almost_output - 1, tf.cast(0, dtype),
                                tf.cast(tf.shape(edges)[0] - 2, dtype))

        if not extend_lower_interval:
            low_fill = np.nan if dtype_util.is_floating(dtype) else -1
            bins = tf.where(x < tf.expand_dims(edges[0], 0),
                            tf.cast(low_fill, dtype), bins)

        if not extend_upper_interval:
            up_fill = (np.nan if dtype_util.is_floating(dtype) else
                       tf.shape(edges)[0] - 1)
            bins = tf.where(x > tf.expand_dims(edges[-1], 0),
                            tf.cast(up_fill, dtype), bins)

        if flattening_x:
            bins = tf.reshape(bins, x_orig_shape)

        return bins
    def __init__(self,
                 distribution,
                 reinterpreted_batch_ndims=None,
                 validate_args=False,
                 name=None):
        """Construct an `Independent` distribution.

    Args:
      distribution: The base distribution instance to transform. Typically an
        instance of `Distribution`.
      reinterpreted_batch_ndims: Scalar, integer number of rightmost batch dims
        which will be regarded as event dims. When `None` all but the first
        batch axis (batch axis 0) will be transferred to event dimensions
        (analogous to `tf.layers.flatten`).
      validate_args: Python `bool`.  Whether to validate input with asserts.
        If `validate_args` is `False`, and the inputs are invalid,
        correct behavior is not guaranteed.
      name: The name for ops managed by the distribution.
        Default value: `Independent + distribution.name`.

    Raises:
      ValueError: if `reinterpreted_batch_ndims` exceeds
        `distribution.batch_ndims`
    """
        parameters = dict(locals())
        with tf.name_scope(name
                           or ('Independent' + distribution.name)) as name:
            self._distribution = distribution

            if reinterpreted_batch_ndims is None:
                # If possible, statically infer reinterpreted_batch_ndims.
                batch_ndims = tensorshape_util.rank(distribution.batch_shape)
                if batch_ndims is not None:
                    self._static_reinterpreted_batch_ndims = max(
                        0, batch_ndims - 1)
                    self._reinterpreted_batch_ndims = tf.convert_to_tensor(
                        self._static_reinterpreted_batch_ndims,
                        dtype_hint=tf.int32,
                        name='reinterpreted_batch_ndims')
                else:
                    self._reinterpreted_batch_ndims = None
                    self._static_reinterpreted_batch_ndims = None

            else:
                self._reinterpreted_batch_ndims = tensor_util.convert_nonref_to_tensor(
                    reinterpreted_batch_ndims,
                    dtype_hint=tf.int32,
                    name='reinterpreted_batch_ndims')
                static_val = tf.get_static_value(
                    self._reinterpreted_batch_ndims)
                self._static_reinterpreted_batch_ndims = (
                    None if static_val is None else int(static_val))

            super(Independent, self).__init__(
                dtype=self._distribution.dtype,
                reparameterization_type=self._distribution.
                reparameterization_type,
                validate_args=validate_args,
                allow_nan_stats=self._distribution.allow_nan_stats,
                parameters=parameters,
                name=name)
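A minimal usage sketch of the constructor above (standard `tfd` alias assumed): reinterpreting the last batch dimension of a `[3, 2]`-batch of Normals as an event dimension.

```python
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

base = tfd.Normal(loc=tf.zeros([3, 2]), scale=1.)   # batch_shape [3, 2]
ind = tfd.Independent(base, reinterpreted_batch_ndims=1)

ind.batch_shape                  # ==> [3]
ind.event_shape                  # ==> [2]
ind.log_prob(tf.zeros([3, 2]))   # shape [3]: one joint log-prob per batch member
```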
Example #25
def _replace_event_shape_in_tensorshape(input_tensorshape, event_shape_in,
                                        event_shape_out):
    """Replaces the event shape dims of a `TensorShape`.

  Args:
    input_tensorshape: a `TensorShape` instance in which to attempt replacing
      event shape.
    event_shape_in: `Tensor` shape representing the event shape expected to
      be present in (rightmost dims of) `input_tensorshape`. Must be compatible
      with the rightmost dims of `input_tensorshape`.
    event_shape_out: `Tensor` shape representing the new event shape, i.e.,
      the replacement of `event_shape_in`.

  Returns:
    output_tensorshape: `TensorShape` with the rightmost `event_shape_in`
      replaced by `event_shape_out`. Might be partially defined, i.e.,
      `TensorShape(None)`.
    is_validated: Python `bool` indicating static validation happened.

  Raises:
    ValueError: if we can determine the event shape portion of
      `input_tensorshape` as well as `event_shape_in` both statically, and they
      are not compatible. "Compatible" here means that they are identical on
      any dims that are not -1 in `event_shape_in`.
  """
    event_shape_in_ndims = tensorshape_util.num_elements(event_shape_in.shape)
    if tensorshape_util.rank(
            input_tensorshape) is None or event_shape_in_ndims is None:
        return tf.TensorShape(None), False  # Not is_validated.

    input_non_event_ndims = tensorshape_util.rank(
        input_tensorshape) - event_shape_in_ndims
    if input_non_event_ndims < 0:
        raise ValueError(
            'Input has lower rank ({}) than `event_shape_ndims` ({}).'.format(
                tensorshape_util.rank(input_tensorshape),
                event_shape_in_ndims))

    input_non_event_tensorshape = input_tensorshape[:input_non_event_ndims]
    input_event_tensorshape = input_tensorshape[input_non_event_ndims:]

    # Check that `input_event_shape_` and `event_shape_in` are compatible in the
    # sense that they have equal entries in any position that isn't a `-1` in
    # `event_shape_in`. Note that our validations at construction time ensure
    # there is at most one such entry in `event_shape_in`.
    event_shape_in_ = tf.get_static_value(event_shape_in)
    is_validated = (tensorshape_util.is_fully_defined(input_event_tensorshape)
                    and event_shape_in_ is not None)
    if is_validated:
        input_event_shape_ = np.int32(input_event_tensorshape)
        mask = event_shape_in_ >= 0
        explicit_input_event_shape_ = input_event_shape_[mask]
        explicit_event_shape_in_ = event_shape_in_[mask]
        if not np.all(explicit_input_event_shape_ == explicit_event_shape_in_):
            raise ValueError(
                'Input `event_shape` does not match `event_shape_in` '
                '({} vs {}).'.format(input_event_shape_, event_shape_in_))

    event_tensorshape_out = tensorshape_util.constant_value_as_shape(
        event_shape_out)
    if tensorshape_util.rank(event_tensorshape_out) is None:
        output_tensorshape = tf.TensorShape(None)
    else:
        output_tensorshape = tensorshape_util.concatenate(
            input_non_event_tensorshape, event_tensorshape_out)
    return output_tensorshape, is_validated
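The shape bookkeeping above, sketched with concrete (made-up) shapes rather than a call to the private helper itself: the rightmost `event_shape_in` dims of the input are swapped for `event_shape_out`.

```python
import tensorflow as tf

input_tensorshape = tf.TensorShape([3, 4, 2, 5])
event_shape_in = [2, 5]     # rightmost dims expected in the input
event_shape_out = [10]      # replacement event shape

non_event = input_tensorshape[:input_tensorshape.rank - len(event_shape_in)]
output_tensorshape = non_event.concatenate(event_shape_out)
# ==> TensorShape([3, 4, 10])
```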
def additional_check(dist):
    return (tensorshape_util.rank(dist.event_shape) == 2
            and int(dist.event_shape[0]) == int(dist.event_shape[1]))
Example #27
  def _test_slicing(self, data, dist_name, dist):
    strm = test_util.test_seed_stream()
    batch_shape = dist.batch_shape
    slices = data.draw(dhps.valid_slices(batch_shape))
    slice_str = 'dist[{}]'.format(', '.join(dhps.stringify_slices(
        slices)))
    # Make sure the slice string appears in Hypothesis' attempted example log
    hp.note('Using slice ' + slice_str)
    if not slices:  # Nothing further to check.
      return
    sliced_zeros = np.zeros(batch_shape)[slices]
    sliced_dist = dist[slices]
    hp.note('Using sliced distribution {}.'.format(sliced_dist))

    # Check that slicing modifies batch shape as expected.
    self.assertAllEqual(sliced_zeros.shape, sliced_dist.batch_shape)

    if not sliced_zeros.size:
      # TODO(b/128924708): Fix distributions that fail on degenerate empty
      #     shapes, e.g. Multinomial, DirichletMultinomial, ...
      return

    # Check that sampling of sliced distributions executes.
    with tfp_hps.no_tf_rank_errors():
      samples = self.evaluate(dist.sample(seed=strm()))
      sliced_dist_samples = self.evaluate(sliced_dist.sample(seed=strm()))

    # Come up with the slices for samples (which must also include event dims).
    sample_slices = (
        tuple(slices) if isinstance(slices, collections.Sequence) else
        (slices,))
    if Ellipsis not in sample_slices:
      sample_slices += (Ellipsis,)
    sample_slices += tuple([slice(None)] *
                           tensorshape_util.rank(dist.event_shape))

    sliced_samples = samples[sample_slices]

    # Report sub-sliced samples (on which we compare log_prob) to hypothesis.
    hp.note('Sample(s) for testing log_prob ' + str(sliced_samples))

    # Check that sampling a sliced distribution produces the same shape as
    # slicing the samples from the original.
    self.assertAllEqual(sliced_samples.shape, sliced_dist_samples.shape)

    # Check that a sliced distribution can compute the log_prob of its own
    # samples (up to numerical validation errors).
    with tfp_hps.no_tf_rank_errors():
      try:
        lp = self.evaluate(dist.log_prob(samples))
      except tf.errors.InvalidArgumentError:
        # TODO(b/129271256): d.log_prob(d.sample()) should not fail
        #     validate_args checks.
        # We only tolerate this case for the non-sliced dist.
        return
      sliced_lp = self.evaluate(sliced_dist.log_prob(sliced_samples))

    # Check that the sliced dist's log_prob agrees with slicing the original's
    # log_prob.

    # This `hp.assume` is suppressing array sizes that cause the sliced and
    # non-sliced distribution to follow different Eigen code paths.  Those
    # different code paths lead to arbitrarily large variations in the results
    # at parameter settings that Hypothesis is all too good at finding.  Since
    # the purpose of this test is just to check that we got slicing right, those
    # discrepancies are a distraction.
    # TODO(b/140229057): Remove this `hp.assume`, if and when Eigen's numerics
    # become index-independent.
    all_packetized = (
        _all_packetized(dist) and _all_packetized(sliced_dist) and
        _all_packetized(samples) and _all_packetized(sliced_samples))
    hp.note('Packetization check {}'.format(all_packetized))
    all_non_packetized = (
        _all_non_packetized(dist) and _all_non_packetized(sliced_dist) and
        _all_non_packetized(samples) and _all_non_packetized(sliced_samples))
    hp.note('Non-packetization check {}'.format(all_non_packetized))
    hp.assume(all_packetized or all_non_packetized)

    self.assertAllClose(lp[slices], sliced_lp,
                        atol=SLICING_LOGPROB_ATOL[dist_name],
                        rtol=SLICING_LOGPROB_RTOL[dist_name])
Example #28
    def _sample_n(self, n, seed=None):
        seeds = samplers.split_seed(seed,
                                    n=self.num_components + 1,
                                    salt='Mixture')
        try:
            seed_stream = SeedStream(seed, salt='Mixture')
        except TypeError as e:  # Can happen for Tensor seed.
            seed_stream = None
            seed_stream_err = e
        if self._use_static_graph:
            # This sampling approach is almost the same as the approach used by
            # `MixtureSameFamily`. The differences are due to having a list of
            # `Distribution` objects rather than a single object, and maintaining
            # random seed management that is consistent with the non-static code
            # path.
            samples = []
            cat_samples = self.cat.sample(n, seed=seeds[0])

            for c in range(self.num_components):
                try:
                    samples.append(self.components[c].sample(n,
                                                             seed=seeds[c +
                                                                        1]))
                    if seed_stream is not None:
                        seed_stream()
                except TypeError as e:
                    if ('Expected int for argument' not in str(e)
                            and TENSOR_SEED_MSG_PREFIX not in str(e)):
                        raise
                    if seed_stream is None:
                        raise seed_stream_err
                    msg = (
                        'Falling back to stateful sampling for `components[{}]` {} of '
                        'type `{}`. Please update to use `tf.random.stateless_*` RNGs. '
                        'This fallback may be removed after 20-Aug-2020. ({})')
                    warnings.warn(
                        msg.format(c, self.components[c].name,
                                   type(self.components[c]), str(e)))
                    samples.append(self.components[c].sample(
                        n, seed=seed_stream()))
            stack_axis = -1 - tensorshape_util.rank(self._static_event_shape)
            x = tf.stack(samples, axis=stack_axis)  # [n, B, k, E]
            npdt = dtype_util.as_numpy_dtype(x.dtype)
            mask = tf.one_hot(
                indices=cat_samples,  # [n, B]
                depth=self._num_components,  # == k
                on_value=npdt(1),
                off_value=npdt(0))  # [n, B, k]
            mask = distribution_util.pad_mixture_dimensions(
                mask, self, self._cat,
                tensorshape_util.rank(
                    self._static_event_shape))  # [n, B, k, [1]*e]
            return tf.reduce_sum(x * mask, axis=stack_axis)  # [n, B, E]

        n = tf.convert_to_tensor(n, name='n')
        static_n = tf.get_static_value(n)
        n = int(static_n) if static_n is not None else n
        cat_samples = self.cat.sample(n, seed=seeds[0])

        static_samples_shape = cat_samples.shape
        if tensorshape_util.is_fully_defined(static_samples_shape):
            samples_shape = tensorshape_util.as_list(static_samples_shape)
            samples_size = tensorshape_util.num_elements(static_samples_shape)
        else:
            samples_shape = tf.shape(cat_samples)
            samples_size = tf.size(cat_samples)
        static_batch_shape = self.batch_shape
        if tensorshape_util.is_fully_defined(static_batch_shape):
            batch_shape = tensorshape_util.as_list(static_batch_shape)
            batch_size = tensorshape_util.num_elements(static_batch_shape)
        else:
            batch_shape = tf.shape(cat_samples)[1:]
            batch_size = tf.reduce_prod(batch_shape)
        static_event_shape = self.event_shape
        if tensorshape_util.is_fully_defined(static_event_shape):
            event_shape = np.array(
                tensorshape_util.as_list(static_event_shape), dtype=np.int32)
        else:
            event_shape = None

        # Get indices into the raw cat sampling tensor. We will
        # need these to stitch sample values back out after sampling
        # within the component partitions.
        samples_raw_indices = tf.reshape(tf.range(0, samples_size),
                                         samples_shape)

        # Partition the raw indices so that we can use
        # dynamic_stitch later to reconstruct the samples from the
        # known partitions.
        partitioned_samples_indices = tf.dynamic_partition(
            data=samples_raw_indices,
            partitions=cat_samples,
            num_partitions=self.num_components)

        # Copy the batch indices n times, as we will need to know
        # these to pull out the appropriate rows within the
        # component partitions.
        batch_raw_indices = tf.reshape(tf.tile(tf.range(0, batch_size), [n]),
                                       samples_shape)

        # Explanation of the dynamic partitioning below:
        #   batch indices are, e.g., [0, 1, 0, 1, 0, 1]
        # Suppose partitions are:
        #     [1 1 0 0 1 1]
        # After partitioning, batch indices are cut as:
        #     [batch_indices[x] for x in 2, 3]
        #     [batch_indices[x] for x in 0, 1, 4, 5]
        # i.e.
        #     [1 1] and [0 0 0 0]
        # Now we sample n=2 from part 0 and n=4 from part 1.
        # For part 0 we want samples from batch entries 1, 1 (samples 0, 1),
        # and for part 1 we want samples from batch entries 0, 0, 0, 0
        #   (samples 0, 1, 2, 3).
        partitioned_batch_indices = tf.dynamic_partition(
            data=batch_raw_indices,
            partitions=cat_samples,
            num_partitions=self.num_components)
        samples_class = [None for _ in range(self.num_components)]

        for c in range(self.num_components):
            n_class = tf.size(partitioned_samples_indices[c])
            try:
                samples_class_c = self.components[c].sample(n_class,
                                                            seed=seeds[c + 1])
                if seed_stream is not None:
                    seed_stream()
            except TypeError as e:
                if ('Expected int for argument' not in str(e)
                        and TENSOR_SEED_MSG_PREFIX not in str(e)):
                    raise
                if seed_stream is None:
                    raise seed_stream_err
                msg = (
                    'Falling back to stateful sampling for `components[{}]` {} of '
                    'type `{}`. Please update to use `tf.random.stateless_*` RNGs. '
                    'This fallback may be removed after 20-Aug-2020. ({})')
                warnings.warn(
                    msg.format(c, self.components[c].name,
                               type(self.components[c]), str(e)))
                samples_class_c = self.components[c].sample(n_class,
                                                            seed=seed_stream())

            if event_shape is None:
                batch_ndims = prefer_static.rank_from_shape(batch_shape)
                event_shape = tf.shape(samples_class_c)[1 + batch_ndims:]

            # Pull out the correct batch entries from each index.
            # To do this, we may have to flatten the batch shape.

            # For sample s, batch element b of component c, we get the
            # partitioned batch indices from
            # partitioned_batch_indices[c]; and shift each element by
            # the sample index. The final lookup can be thought of as
            # a matrix gather along locations (s, b) in
            # samples_class_c where the n_class rows correspond to
            # samples within this component and the batch_size columns
            # correspond to batch elements within the component.
            #
            # Thus the lookup index is
            #   lookup[c, i] = batch_size * s[i] + b[c, i]
            # for i = 0 ... n_class[c] - 1.
            lookup_partitioned_batch_indices = (
                batch_size * tf.range(n_class) + partitioned_batch_indices[c])
            samples_class_c = tf.reshape(
                samples_class_c,
                tf.concat([[n_class * batch_size], event_shape], 0))
            samples_class_c = tf.gather(samples_class_c,
                                        lookup_partitioned_batch_indices,
                                        name='samples_class_c_gather')
            samples_class[c] = samples_class_c

        # Stitch back together the samples across the components.
        lhs_flat_ret = tf.dynamic_stitch(indices=partitioned_samples_indices,
                                         data=samples_class)
        # Reshape back to proper sample, batch, and event shape.
        ret = tf.reshape(lhs_flat_ret,
                         tf.concat([samples_shape, event_shape], 0))
        tensorshape_util.set_shape(
            ret,
            tensorshape_util.concatenate(static_samples_shape,
                                         self.event_shape))
        return ret
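A tiny sketch of the partition/stitch round trip the non-static path relies on (made-up data): indices are partitioned by component, each partition is processed separately, and `tf.dynamic_stitch` restores the original sample order.

```python
import tensorflow as tf

data = tf.constant([10., 11., 12., 13., 14., 15.])
partitions = tf.constant([1, 1, 0, 0, 1, 1])      # plays the role of cat_samples

indices = tf.dynamic_partition(tf.range(6), partitions, 2)
parts = tf.dynamic_partition(data, partitions, 2)
processed = [p * 100. for p in parts]             # stand-in for per-component sampling
restored = tf.dynamic_stitch(indices, processed)
# ==> [1000., 1100., 1200., 1300., 1400., 1500.]  (original order preserved)
```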
    def _test_slicing(self, data, dist, batch_shape):
        slices = data.draw(valid_slices(batch_shape))
        slice_str = 'dist[{}]'.format(', '.join(stringify_slices(slices)))
        logging.info('slice used: %s', slice_str)
        # Make sure the slice string appears in Hypothesis' attempted example log,
        # by drawing and discarding it.
        data.draw(hps.just(slice_str))
        if not slices:  # Nothing further to check.
            return
        sliced_zeros = np.zeros(batch_shape)[slices]
        sliced_dist = dist[slices]
        self.assertAllEqual(sliced_zeros.shape, sliced_dist.batch_shape)

        try:
            seed = data.draw(
                hpnp.arrays(dtype=np.int64, shape=[]).filter(lambda x: x != 0))
            samples = self.evaluate(dist.sample(seed=maybe_seed(seed)))

            if not sliced_zeros.size:
                # TODO(b/128924708): Fix distributions that fail on degenerate empty
                #     shapes, e.g. Multinomial, DirichletMultinomial, ...
                return

            sliced_samples = self.evaluate(
                sliced_dist.sample(seed=maybe_seed(seed)))
        except NotImplementedError as e:
            raise
        except tf.errors.UnimplementedError as e:
            if 'Unhandled input dimensions' in str(e) or 'rank not in' in str(
                    e):
                # Some cases can fail with 'Unhandled input dimensions \d+' or
                # 'inputs rank not in [0,6]: \d+'
                return
            raise

        # Come up with the slices for samples (which must also include event dims).
        sample_slices = (tuple(slices) if isinstance(
            slices, collections.Sequence) else (slices, ))
        if Ellipsis not in sample_slices:
            sample_slices += (Ellipsis, )
        sample_slices += tuple([slice(None)] *
                               tensorshape_util.rank(dist.event_shape))

        # Report sub-sliced samples (on which we compare log_prob) to hypothesis.
        data.draw(hps.just(samples[sample_slices]))
        self.assertAllEqual(samples[sample_slices].shape, sliced_samples.shape)
        try:
            try:
                lp = self.evaluate(dist.log_prob(samples))
            except tf.errors.InvalidArgumentError:
                # TODO(b/129271256): d.log_prob(d.sample()) should not fail
                #     validate_args checks.
                # We only tolerate this case for the non-sliced dist.
                return
            sliced_lp = self.evaluate(
                sliced_dist.log_prob(samples[sample_slices]))
        except tf.errors.UnimplementedError as e:
            if 'Unhandled input dimensions' in str(e) or 'rank not in' in str(
                    e):
                # Some cases can fail with 'Unhandled input dimensions \d+' or
                # 'inputs rank not in [0,6]: \d+'
                return
            raise
        # TODO(b/128708201): Better numerics for Geometric/Beta?
        # Eigen can return quite different results for packet vs non-packet ops.
        # To work around this, we use a much larger rtol for the last 3
        # (assuming packet size 4) elements.
        packetized_lp = lp[slices].reshape(-1)[:-3]
        packetized_sliced_lp = sliced_lp.reshape(-1)[:-3]
        rtol = (0.1 if any(x in dist.name for x in ('Geometric', 'Beta',
                                                    'Dirichlet')) else 0.02)
        self.assertAllClose(packetized_lp, packetized_sliced_lp, rtol=rtol)
        possibly_nonpacket_lp = lp[slices].reshape(-1)[-3:]
        possibly_nonpacket_sliced_lp = sliced_lp.reshape(-1)[-3:]
        rtol = 0.4
        self.assertAllClose(possibly_nonpacket_lp,
                            possibly_nonpacket_sliced_lp,
                            rtol=rtol)
    def _distributional_transform(self, x):
        """Performs distributional transform of the mixture samples.

    Distributional transform removes the parameters from samples of a
    multivariate distribution by applying conditional CDFs:
      (F(x_1), F(x_2 | x_1), ..., F(x_d | x_1, ..., x_d-1))
    (the indexing is over the "flattened" event dimensions).
    The result is a sample from a product of Uniform[0, 1] distributions.

    We assume that the components are factorized, so the conditional CDFs become
      F(x_i | x_1, ..., x_i-1) = sum_k w_i^k F_k (x_i),
    where w_i^k is the posterior mixture weight: for i > 0
      w_i^k = w_k prob_k(x_1, ..., x_i-1) / sum_k' w_k' prob_k'(x_1, ..., x_i-1)
    and w_0^k = w_k is the mixture probability of the k-th component.

    Arguments:
      x: Sample of mixture distribution

    Returns:
      Result of the distributional transform
    """

        if tensorshape_util.rank(x.shape) is None:
            # tf.math.softmax raises an error when applied to inputs of undefined
            # rank.
            raise ValueError(
                "Distributional transform does not support inputs of "
                "undefined rank.")

        # Obtain factorized components distribution and assert that it's
        # a scalar distribution.
        if isinstance(self._components_distribution, independent.Independent):
            univariate_components = self._components_distribution.distribution
        else:
            univariate_components = self._components_distribution

        with tf.control_dependencies([
                assert_util.assert_equal(
                    univariate_components.is_scalar_event(),
                    True,
                    message="`univariate_components` must have scalar event")
        ]):
            x_padded = self._pad_sample_dims(x)  # [S, B, 1, E]
            log_prob_x = univariate_components.log_prob(
                x_padded)  # [S, B, k, E]
            cdf_x = univariate_components.cdf(x_padded)  # [S, B, k, E]

            # log prob_k (x_1, ..., x_i-1)
            cumsum_log_prob_x = tf.reshape(
                tf.math.cumsum(
                    # [S*prod(B)*k, prod(E)]
                    tf.reshape(log_prob_x, [-1, self._event_size]),
                    exclusive=True,
                    axis=-1),
                tf.shape(log_prob_x))  # [S, B, k, E]

            logits_mix_prob = distribution_utils.pad_mixture_dimensions(
                self.mixture_distribution.logits_parameter(), self,
                self.mixture_distribution, self._event_ndims)  # [B, k, 1]

            # Logits of the posterior weights: log w_k + log prob_k (x_1, ..., x_i-1)
            log_posterior_weights_x = logits_mix_prob + cumsum_log_prob_x

            component_axis = tensorshape_util.rank(x.shape) - self._event_ndims
            posterior_weights_x = tf.math.softmax(log_posterior_weights_x,
                                                  axis=component_axis)
            return tf.reduce_sum(posterior_weights_x * cdf_x,
                                 axis=component_axis)
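For the first coordinate the posterior weights reduce to the mixture weights, so the transform is just a weighted sum of component CDFs; a small numeric sketch (illustrative values, standard `tfd` alias):

```python
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

weights = tf.constant([0.3, 0.7])
components = tfd.Normal(loc=[-1., 2.], scale=[1., 0.5])   # k = 2 scalar components

x1 = tf.constant(0.5)
# First entry of the distributional transform: F(x_1) = sum_k w_k F_k(x_1).
f_x1 = tf.reduce_sum(weights * components.cdf(x1))
```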