def check_event_space_bijector_constrains(self, dist, data):
        event_space_bijector = dist.experimental_default_event_space_bijector()
        if event_space_bijector is None:
            return

        total_sample_shape = tensorshape_util.concatenate(
            # Draw a sample shape
            data.draw(tfp_hps.shapes()),
            # Draw a shape that broadcasts with `[batch_shape, inverse_event_shape]`
            # where `inverse_event_shape` is the event shape in the bijector's
            # domain. This is the shape of `y` in R**n, such that
            # x = event_space_bijector(y) has the event shape of the distribution.
            data.draw(
                tfp_hps.broadcasting_shapes(
                    tensorshape_util.concatenate(
                        dist.batch_shape,
                        event_space_bijector.inverse_event_shape(
                            dist.event_shape)),
                    n=1))[0])

        y = data.draw(
            tfp_hps.constrained_tensors(tfp_hps.identity_fn,
                                        total_sample_shape.as_list()))
        with tfp_hps.no_tf_rank_errors():
            x = event_space_bijector(y)
            with tf.control_dependencies(dist._sample_control_dependencies(x)):
                self.evaluate(tf.identity(x))
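Below is a minimal, self-contained sketch of the behavior this test exercises, assuming a recent TensorFlow Probability where `experimental_default_event_space_bijector` is available: the bijector maps unconstrained reals into the distribution's support.

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# Beta's support is (0, 1); its default event space bijector maps all of R
# onto that interval, so any real-valued draw becomes a valid sample value.
dist = tfd.Beta(concentration1=2., concentration0=3., validate_args=True)
bijector = dist.experimental_default_event_space_bijector()

y = tf.constant([-3., 0., 5.])  # unconstrained values in R
x = bijector(y)                 # constrained values in (0, 1)
print(x.numpy())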
Example #2
 def _call_and_reshape_output(
     self,
     fn,
     event_shape_list=None,
     static_event_shape_list=None,
     extra_kwargs=None):
   """Calls `fn` and appropriately reshapes its output."""
   # Note: we take `extra_kwargs` as a dict rather than `**extra_kwargs`
   # because it is possible the user provided extra kwargs would itself
   # have `fn`, `event_shape_list`, `static_event_shape_list` and/or
   # `extra_kwargs` as keys.
   if event_shape_list is None:
     event_shape_list = [self._event_shape_tensor()]
   if static_event_shape_list is None:
     static_event_shape_list = [self.event_shape]
   new_shape = tf.concat(
       [self._batch_shape_unexpanded] + event_shape_list, axis=0)
   result = tf.reshape(fn(**extra_kwargs) if extra_kwargs else fn(),
                       new_shape)
   if (tensorshape_util.rank(self.batch_shape) is not None and
       tensorshape_util.rank(self.event_shape) is not None):
     event_shape = tf.TensorShape([])
     for rss in static_event_shape_list:
       event_shape = tensorshape_util.concatenate(event_shape, rss)
     static_shape = tensorshape_util.concatenate(
         self.batch_shape, event_shape)
     tensorshape_util.set_shape(result, static_shape)
   return result
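For reference, here is a small sketch of the `tensorshape_util.concatenate` helper these examples revolve around; it lives in TFP's internal `tensorshape_util` module (so the import path is not a stable API) and wraps `tf.TensorShape.concatenate`, which propagates unknown ranks.

import tensorflow as tf
from tensorflow_probability.python.internal import tensorshape_util

# Fully defined shapes concatenate on the right.
print(tensorshape_util.concatenate(tf.TensorShape([2, 3]), [4]))  # (2, 3, 4)

# If either operand has unknown rank, the result has unknown rank, which is
# why callers in these examples guard on `tensorshape_util.rank(...) is None`.
print(tensorshape_util.concatenate(tf.TensorShape(None), [4]))    # <unknown>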
Example #3
 def _inverse_event_shape(self, output_shape):
   if not self._maybe_changes_size:
     return output_shape
   output_shape = tensorshape_util.with_rank_at_least(output_shape, 1)
   static_block_sizes = tf.get_static_value(self.block_sizes)
   if static_block_sizes is None:
     return tensorshape_util.concatenate(output_shape[:-1], [None])
   input_size = sum(static_block_sizes)
   return tensorshape_util.concatenate(output_shape[:-1], [input_size])
Example #4
    def _call_reshape_input_output(self,
                                   fn,
                                   x,
                                   input_event_shape=None,
                                   output_event_shape=None,
                                   keep_event_dims=False,
                                   extra_kwargs=None):
        """Calls `fn`, appropriately reshaping its input `x` and output."""
        # Note: we take `extra_kwargs` as a dict rather than `**extra_kwargs`
        # because it is possible the user provided extra kwargs would itself
        # have `fn` and/or `x` as a key.
        if input_event_shape is None:
            static_input_event_shape, input_event_shape_tensor = (
                self.event_shape, self.event_shape_tensor())
        else:
            static_input_event_shape, input_event_shape_tensor = input_event_shape

        if output_event_shape is None:
            if input_event_shape is None:
                static_output_event_shape, output_event_shape_tensor = (
                    static_input_event_shape, input_event_shape_tensor)
            else:
                static_output_event_shape, output_event_shape_tensor = (
                    self.event_shape, self.event_shape_tensor())
        else:
            static_output_event_shape, output_event_shape_tensor = output_event_shape

        sample_shape, static_sample_shape = self._sample_shape(
            x, static_input_event_shape, input_event_shape_tensor)
        old_shape = ps.concat([
            sample_shape,
            self.distribution.batch_shape_tensor(),
            input_event_shape_tensor,
        ], axis=0)
        x_reshape = tf.reshape(x, old_shape)
        result = (fn(x_reshape, **extra_kwargs) if extra_kwargs
                  else fn(x_reshape))
        new_shape = ps.concat([
            sample_shape,
            self._batch_shape_unexpanded,
        ], axis=0)
        if keep_event_dims:
            new_shape = ps.concat([new_shape, output_event_shape_tensor],
                                  axis=0)
        result = tf.reshape(result, new_shape)
        if (tensorshape_util.rank(static_sample_shape) is not None
                and tensorshape_util.rank(self.batch_shape) is not None):
            new_shape = tensorshape_util.concatenate(static_sample_shape,
                                                     self.batch_shape)
            if keep_event_dims:
                new_shape = tensorshape_util.concatenate(
                    new_shape, static_output_event_shape)
            tensorshape_util.set_shape(result, new_shape)
        return result
Example #5
    def _forward_event_shape(self, input_shape):
        input_shape = tensorshape_util.with_rank_at_least(input_shape, 1)
        static_block_sizes = tf.get_static_value(self.block_sizes)
        if static_block_sizes is None:
            return tensorshape_util.concatenate(input_shape[:-1], [None])

        output_size = sum(
            b.forward_event_shape([bs])[0]
            for b, bs in zip(self.bijectors, static_block_sizes))

        return tensorshape_util.concatenate(input_shape[:-1], [output_size])
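A hedged usage sketch of this size bookkeeping via the public `Blockwise` bijector: `SoftmaxCentered` grows its block by one element, so the static output size differs from the input size.

import tensorflow_probability as tfp

tfb = tfp.bijectors

# Blocks of sizes [2, 3]: Exp preserves size, SoftmaxCentered maps a block of
# size 3 to size 4, so the overall event size goes from 5 to 6.
blockwise = tfb.Blockwise(
    bijectors=[tfb.Exp(), tfb.SoftmaxCentered()], block_sizes=[2, 3])
print(blockwise.forward_event_shape([5]))   # (6,)
print(blockwise.inverse_event_shape([6]))   # (5,)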
Example #6
def pad_shape_with_ones(x, ndims, start=-1):
  """Maybe add `ndims` ones to `x.shape` starting at `start`.

  If `ndims` is zero, this is a no-op; otherwise, we will create and return a
  new `Tensor` whose shape is that of `x` with `ndims` ones inserted starting
  at position `start` (i.e., appended on the right when `start=-1`). If the
  shape of `x` is known statically, the shape of the return value will be as
  well.

  Args:
    x: The `Tensor` we'll return a reshaping of.
    ndims: Python `integer` number of ones to pad onto `x.shape`.
    start: Python `integer` specifying where to start padding with ones. Must
      be a negative integer. For instance, a value of `-1` means to pad at the
      end of the shape. Default value: `-1`.
  Returns:
    If `ndims` is zero, `x`; otherwise, a `Tensor` whose shape is that of `x`
    with `ndims` ones inserted starting at position `start`. If possible,
    returns a `Tensor` whose shape is known statically.
  Raises:
    ValueError: if `ndims` is not a Python `integer` greater than or equal to
    zero.
  """
  if not (isinstance(ndims, int) and ndims >= 0):
    raise ValueError(
        '`ndims` must be a Python `integer` greater than or equal to zero. '
        'Got: {}'.format(ndims))
  if not (isinstance(start, int) and start <= -1):
    raise ValueError(
        '`start` must be a Python `integer` less than zero. Got: {}'
        .format(start))
  if ndims == 0:
    return x
  x = tf.convert_to_tensor(value=x)
  original_shape = x.shape
  rank = ps.rank(x)
  first_shape = ps.shape(x)[:rank + start + 1]
  second_shape = ps.shape(x)[rank + start + 1:]
  new_shape = ps.pad(first_shape, paddings=[[0, ndims]], constant_values=1)
  new_shape = ps.concat([new_shape, second_shape], axis=0)
  x = tf.reshape(x, new_shape)
  if start == -1:
    tensorshape_util.set_shape(
        x, tensorshape_util.concatenate(original_shape, [1] * ndims))
  elif tensorshape_util.rank(original_shape) is not None:
    original_ndims = tensorshape_util.rank(original_shape)
    new_shape = tensorshape_util.concatenate(
        original_shape[:original_ndims + start + 1],
        tensorshape_util.concatenate(
            [1] * ndims,
            original_shape[original_ndims + start + 1:]))
    tensorshape_util.set_shape(x, new_shape)
  return x
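A short usage sketch of the helper above (assuming `pad_shape_with_ones` is in scope), showing the default right-padding and an interior `start`:

import tensorflow as tf

x = tf.zeros([2, 3])
print(pad_shape_with_ones(x, ndims=2).shape)            # (2, 3, 1, 1)
print(pad_shape_with_ones(x, ndims=1, start=-2).shape)  # (2, 1, 3)
print(pad_shape_with_ones(x, ndims=0) is x)             # True: a no-op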
Example #7
    def _event_shape(self):
        # The examples index is one position to the left of the feature dims.
        index_points = self.index_points

        if index_points is None:
            return tf.TensorShape([None, self.kernel.num_tasks])
        examples_index = -(self.kernel.feature_ndims + 1)
        shape = tensorshape_util.concatenate(
            index_points.shape[examples_index:examples_index + 1],
            (self.kernel.num_tasks, ))
        if tensorshape_util.rank(shape) is None:
            return tensorshape_util.concatenate(
                [index_points.shape[examples_index:examples_index + 1]],
                [self.kernel.num_tasks])
        return shape
Example #8
 def _batch_shape(self):
     batch_stack = tf.TensorShape(tf.get_static_value(self.batch_stack))
     if (tensorshape_util.rank(batch_stack) is None or
             tensorshape_util.rank(self.distribution.event_shape) is None):
         return tf.TensorShape(None)
     return tensorshape_util.concatenate(batch_stack,
                                         self.distribution.batch_shape)
Example #9
 def _call_reshape_input_output(self, fn, x, extra_kwargs=None):
     """Calls `fn`, appropriately reshaping its input `x` and output."""
     # Note: we take `extra_kwargs` as a dict rather than `**extra_kwargs`
     # because it is possible the user provided extra kwargs would itself
     # have `fn` and/or `x` as a key.
     with tf.control_dependencies(self._runtime_assertions +
                                  self._validate_sample_arg(x)):
         sample_shape, static_sample_shape = self._sample_shape(x)
          old_shape = tf.concat([
              sample_shape,
              self.distribution.batch_shape_tensor(),
              self.event_shape_tensor(),
          ], axis=0)
          x_reshape = tf.reshape(x, old_shape)
          result = (fn(x_reshape, **extra_kwargs) if extra_kwargs
                    else fn(x_reshape))
          new_shape = tf.concat([
              sample_shape,
              self._batch_shape_unexpanded,
          ], axis=0)
         result = tf.reshape(result, new_shape)
         if (tensorshape_util.rank(static_sample_shape) is not None
                 and tensorshape_util.rank(self.batch_shape) is not None):
             new_shape = tensorshape_util.concatenate(
                 static_sample_shape, self.batch_shape)
             tensorshape_util.set_shape(result, new_shape)
         return result
Example #10
def independents(draw,
                 batch_shape=None,
                 event_dim=None,
                 enable_vars=False,
                 depth=None):
    """Strategy for drawing `Independent` distributions.

  The underlying distribution is drawn from the `distributions` strategy.

  Args:
    draw: Hypothesis MacGuffin.  Supplied by `@hps.composite`.
    batch_shape: An optional `TensorShape`.  The batch shape of the resulting
      `Independent` distribution.  Note that the underlying distribution will in
      general have a higher-rank batch shape, to make room for reinterpreting
      some of those dimensions as the `Independent`'s event.  Hypothesis will
      pick one if omitted.
    event_dim: Optional Python int giving the size of each of the underlying
      distribution's parameters' event dimensions.  This is shared across all
      parameters, permitting square event matrices, compatible location and
      scale Tensors, etc. If omitted, Hypothesis will choose one.
    enable_vars: TODO(bjp): Make this `True` all the time and put variable
      initialization in slicing_test.  If `False`, the returned parameters are
      all Tensors, never Variables or DeferredTensor.
    depth: Python `int` giving maximum nesting depth of compound Distributions.

  Returns:
    dists: A strategy for drawing `Independent` distributions with the specified
      `batch_shape` (or an arbitrary one if omitted).
  """
    if depth is None:
        depth = draw(depths())

    reinterpreted_batch_ndims = draw(hps.integers(min_value=0, max_value=2))
    if batch_shape is None:
        batch_shape = draw(tfp_hps.shapes(min_ndims=reinterpreted_batch_ndims))
    else:  # This independent adds some batch dims to its underlying distribution.
        batch_shape = tensorshape_util.concatenate(
            batch_shape,
            draw(
                tfp_hps.shapes(min_ndims=reinterpreted_batch_ndims,
                               max_ndims=reinterpreted_batch_ndims)))
    underlying = draw(
        distributions(batch_shape=batch_shape,
                      event_dim=event_dim,
                      enable_vars=enable_vars,
                      depth=depth - 1))
    logging.info(
        'underlying distribution: %s; parameters used: %s', underlying,
        [k for k, v in six.iteritems(underlying.parameters) if v is not None])
    result_dist = tfd.Independent(
        underlying,
        reinterpreted_batch_ndims=reinterpreted_batch_ndims,
        validate_args=True)
    expected_shape = batch_shape[:len(batch_shape) - reinterpreted_batch_ndims]
    if expected_shape != result_dist.batch_shape:
        msg = ('Independent strategy generated a bad batch shape '
               'for {}, should have been {}.').format(result_dist,
                                                      expected_shape)
        raise AssertionError(msg)
    return result_dist
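A minimal sketch of the object this strategy constructs: `Independent` moves the rightmost `reinterpreted_batch_ndims` batch dimensions of the underlying distribution into its event shape.

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

base = tfd.Normal(loc=tf.zeros([4, 3]), scale=1.)        # batch_shape = [4, 3]
ind = tfd.Independent(base, reinterpreted_batch_ndims=1)
print(ind.batch_shape)   # (4,)
print(ind.event_shape)   # (3,)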
Example #11
 def _event_shape(self):
     sample_shape = tf.TensorShape(tf.get_static_value(self.sample_shape))
     if (tensorshape_util.rank(sample_shape) is None or
             tensorshape_util.rank(self.distribution.event_shape) is None):
         return tf.TensorShape(None)
     return tensorshape_util.concatenate(sample_shape,
                                         self.distribution.event_shape)
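Correspondingly, a sketch of the static event shape computed above: `Sample` prepends `sample_shape` to the base distribution's event shape.

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

base = tfd.MultivariateNormalDiag(loc=tf.zeros(3))   # event_shape = [3]
samp = tfd.Sample(base, sample_shape=[5])
print(samp.event_shape)   # (5, 3)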
Example #12
    def _inverse(self, y):
        # To derive the inverse mapping note that:
        #   y[i] = exp(x[i]) / normalization
        # and
        #   y[end] = 1 / normalization.
        # Thus:
        # x[i] = log(exp(x[i])) - log(y[end]) - log(normalization)
        #      = log(exp(x[i])/normalization) - log(y[end])
        #      = log(y[i]) - log(y[end])

        # Do this first to make sure CSE catches that it'll happen again in
        # _inverse_log_det_jacobian.
        x = tf.math.log(y)

        log_normalization = (-x[..., -1])[..., tf.newaxis]
        x = x[..., :-1] + log_normalization

        # Set shape hints.
        if tensorshape_util.rank(y.shape) is not None:
            last_dim = tf.compat.dimension_value(y.shape[-1])
            shape = tensorshape_util.concatenate(
                y.shape[:-1], None if last_dim is None else last_dim - 1)
            tensorshape_util.set_shape(x, shape)

        return x
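A NumPy sanity check of the derivation in the comment above: the forward map is a softmax of `x` padded with a zero, and the inverse recovers `x` as `log(y[i]) - log(y[-1])`.

import numpy as np

x = np.array([0.5, -1.2])
logits = np.append(x, 0.)                      # pad with the implicit zero
y = np.exp(logits) / np.sum(np.exp(logits))    # forward: softmax-centered
x_recovered = np.log(y[:-1]) - np.log(y[-1])   # inverse, as derived above
print(np.allclose(x, x_recovered))             # True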
Example #13
def make_multivariate_mixture(batch_shape,
                              num_components,
                              event_shape,
                              use_static_graph,
                              batch_shape_tensor=None):
    if batch_shape_tensor is None:
        batch_shape_tensor = batch_shape
    batch_shape_tensor = tf.convert_to_tensor(value=batch_shape_tensor,
                                              dtype=tf.int32)
    logits = tf.random.uniform(
        tf.concat((batch_shape_tensor, [num_components]), 0),
        -1, 1, dtype=tf.float32) - 50.
    tensorshape_util.set_shape(
        logits, tensorshape_util.concatenate(batch_shape, num_components))
    static_batch_and_event_shape = (
        tf.TensorShape(batch_shape).concatenate(event_shape))
    event_shape = tf.convert_to_tensor(value=event_shape, dtype=tf.int32)
    batch_and_event_shape = tf.concat((batch_shape_tensor, event_shape), 0)

    def create_component():
        loc = tf.random.normal(batch_and_event_shape)
        scale_diag = 10 * tf.random.uniform(batch_and_event_shape)
        tensorshape_util.set_shape(loc, static_batch_and_event_shape)
        tensorshape_util.set_shape(scale_diag, static_batch_and_event_shape)
        return tfd.MultivariateNormalDiag(loc=loc, scale_diag=scale_diag)

    components = [create_component() for _ in range(num_components)]
    cat = tfd.Categorical(logits, dtype=tf.int32)
    return tfd.Mixture(cat, components, use_static_graph=use_static_graph)
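Hypothetical usage of the helper above (the argument values are illustrative): the returned `tfd.Mixture` has the requested batch and event shapes.

mix = make_multivariate_mixture(
    batch_shape=[2], num_components=3, event_shape=[4],
    use_static_graph=False)
print(mix.batch_shape)   # (2,)
print(mix.event_shape)   # (4,)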
Example #14
def mixtures_same_family(draw, batch_shape=None, event_dim=None):
    if batch_shape is None:
        # Ensure the components dist has at least one batch dim (a component dim).
        batch_shape = draw(batch_shapes(min_ndims=1, min_lastdimsize=2))
    else:  # This mixture adds a batch dim to its underlying components dist.
        batch_shape = tensorshape_util.concatenate(
            batch_shape,
            draw(batch_shapes(min_ndims=1, max_ndims=1, min_lastdimsize=2)))

    component_dist, _ = distributions(
        draw,
        batch_shape=batch_shape,
        event_dim=event_dim,
        eligibility_filter=lambda name: name != 'MixtureSameFamily')
    logging.info('component distribution: %s; parameters used: %s',
                 component_dist, [
                     k for k, v in six.iteritems(component_dist.parameters)
                     if v is not None
                 ])
    # scalar or same-shaped categorical?
    mixture_batch_shape = draw(
        hps.one_of(hps.just(batch_shape[:-1]), hps.just(tf.TensorShape([]))))
    mixture_dist, _ = distributions(
        draw,
        dist_name='Categorical',
        batch_shape=mixture_batch_shape,
        event_dim=tensorshape_util.as_list(batch_shape)[-1])
    logging.info(
        'mixture distribution: %s; parameters used: %s', mixture_dist,
        [k for k, v in six.iteritems(mixture_dist.parameters)
         if v is not None])
    return (tfd.MixtureSameFamily(components_distribution=component_dist,
                                  mixture_distribution=mixture_dist,
                                  validate_args=True), batch_shape[:-1])
Example #15
    def _stddev(self):
        if distribution_util.is_diagonal_scale(self.scale):
            stddev = tf.abs(self.scale.diag_part())
        elif (isinstance(self.scale, tf.linalg.LinearOperatorLowRankUpdate)
              and self.scale.is_self_adjoint):
            stddev = tf.sqrt(
                tf.linalg.diag_part(self.scale.matmul(self.scale.to_dense())))
        else:
            stddev = tf.sqrt(
                tf.linalg.diag_part(
                    self.scale.matmul(self.scale.to_dense(),
                                      adjoint_arg=True)))

        shape = tensorshape_util.concatenate(self.batch_shape,
                                             self.event_shape)
        has_static_shape = tensorshape_util.is_fully_defined(shape)
        if not has_static_shape:
            shape = tf.concat([
                self.batch_shape_tensor(),
                self.event_shape_tensor(),
            ], 0)

        if has_static_shape and shape == stddev.shape:
            return stddev

        # Add dummy tensor of zeros to broadcast.  This is only necessary if shape
        # != stddev.shape, but we could not determine if this is the case.
        return stddev + tf.zeros(shape, self.dtype)
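A quick sketch of the shape contract implemented above: `stddev()` is broadcast out to `batch_shape + event_shape` even when `scale` is shared across the batch.

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# loc carries batch_shape=[2] and event_shape=[3]; scale_diag is shared.
dist = tfd.MultivariateNormalDiag(loc=tf.zeros([2, 3]), scale_diag=tf.ones([3]))
print(dist.stddev().shape)   # (2, 3) == batch_shape + event_shape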
Example #16
def independents(draw, batch_shape=None, event_dim=None, enable_vars=False):
  reinterpreted_batch_ndims = draw(hps.integers(min_value=0, max_value=2))
  if batch_shape is None:
    batch_shape = draw(
        tfp_hps.batch_shapes(min_ndims=reinterpreted_batch_ndims))
  else:  # This independent adds some batch dims to its underlying distribution.
    batch_shape = tensorshape_util.concatenate(
        batch_shape,
        draw(
            tfp_hps.batch_shapes(
                min_ndims=reinterpreted_batch_ndims,
                max_ndims=reinterpreted_batch_ndims)))
  underlying, batch_shape = draw(
      distributions(
          batch_shape=batch_shape,
          event_dim=event_dim,
          enable_vars=enable_vars,
          eligibility_filter=lambda name: name != 'Independent'))
  logging.info(
      'underlying distribution: %s; parameters used: %s', underlying,
      [k for k, v in six.iteritems(underlying.parameters) if v is not None])
  return (tfd.Independent(
      underlying,
      reinterpreted_batch_ndims=reinterpreted_batch_ndims,
      validate_args=True),
          batch_shape[:len(batch_shape) - reinterpreted_batch_ndims])
Example #17
 def _forward_event_shape(self, input_shape):
   batch_shape, d = input_shape[:-1], tf.compat.dimension_value(
       input_shape[-1])
   if d is None:
     n = None
   else:
     n = vector_size_to_square_matrix_size(d, self.validate_args)
   return tensorshape_util.concatenate(batch_shape, [n, n])
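A sketch of this vector-to-matrix shape mapping through the public `FillTriangular` bijector: a length-`d` vector with `d = n(n+1)/2` becomes an `n x n` triangular matrix.

import tensorflow_probability as tfp

tfb = tfp.bijectors

bij = tfb.FillTriangular()
print(bij.forward_event_shape([6]))      # (3, 3), since 6 == 3 * 4 / 2
print(bij.inverse_event_shape([3, 3]))   # (6,)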
Example #18
 def _inverse_event_shape(self, event_shape):
     num_steps = tf.nest.flatten(event_shape)[0][0]
     head_shape = tf.nest.map_structure(lambda s: s[1:], event_shape)
     tail_shape = tf.nest.map_structure(
         lambda s: tensorshape_util.concatenate([num_steps - 1], s[1:]),
         event_shape)
     return (self.initial_bijector.inverse_event_shape(head_shape),
             self.transition_bijector.inverse_event_shape(tail_shape))
Example #19
 def f(x_unconstrained, batch_shape=batch_shape):
     # Unflatten any batch dimensions now under the tape.
     unflattened_x_unconstrained = tf.reshape(
         x_unconstrained,
         tensorshape_util.concatenate(batch_shape,
                                      x_unconstrained.shape[-1:]))
     f_x = bijector.forward(
         input_to_unconstrained.inverse(unflattened_x_unconstrained))
     return f_x
Example #20
 def _mode_mean_shape(self):
   """Shape for the mode/mean Tensors."""
   shape = tensorshape_util.concatenate(self.batch_shape, self.event_shape)
   has_static_shape = tensorshape_util.is_fully_defined(shape)
   if not has_static_shape:
     shape = tf.concat([
         self.batch_shape_tensor(),
         self.event_shape_tensor(),
     ], 0)
   return shape
Example #21
    def testCholeskyUpdateRandomized(self, data):
        target_bs = data.draw(hpnp.array_shapes())
        chol_bs, u_bs, multiplier_bs = data.draw(
            tfp_hps.broadcasting_shapes(target_bs, 3))
        l = data.draw(hps.integers(min_value=1, max_value=12))

        rng_seed = data.draw(hps.integers(min_value=0, max_value=2**32 - 1))
        rng = np.random.RandomState(seed=rng_seed)
        xs = push_apart(
            rng.uniform(size=tensorshape_util.concatenate(chol_bs, (l, 1))),
            axis=-2)
        hp.note(xs)
        xs = xs.astype(self.dtype)
        xs = tf1.placeholder_with_default(
            xs, shape=xs.shape if self.use_static_shape else None)

        k = tfp.math.psd_kernels.MaternOneHalf()
        jitter = lambda n: tf.linalg.eye(n, dtype=self.dtype) * 5e-5

        mat = k.matrix(xs, xs) + jitter(l)
        chol = tf.linalg.cholesky(mat)

        u = rng.uniform(size=tensorshape_util.concatenate(u_bs, (l, )))
        hp.note(u)
        u = u.astype(self.dtype)
        u = tf1.placeholder_with_default(
            u, shape=u.shape if self.use_static_shape else None)

        multiplier = rng.uniform(size=multiplier_bs)
        hp.note(multiplier)
        multiplier = multiplier.astype(self.dtype)
        multiplier = tf1.placeholder_with_default(
            multiplier,
            shape=multiplier.shape if self.use_static_shape else None)

        new_chol_expected = tf.linalg.cholesky(
            mat + multiplier[..., tf.newaxis, tf.newaxis] *
            tf.linalg.matmul(u[..., tf.newaxis], u[..., tf.newaxis, :]))

        new_chol = tfp.math.cholesky_update(chol, u, multiplier=multiplier)
        self.assertAllClose(new_chol_expected, new_chol, rtol=1e-5, atol=2e-5)
        self.assertAllEqual(tf.linalg.band_part(new_chol, -1, 0), new_chol)
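A small deterministic sketch of the identity the randomized test above checks: `tfp.math.cholesky_update` returns the Cholesky factor of `A + multiplier * u u^T` computed directly from `chol(A)`.

import tensorflow as tf
import tensorflow_probability as tfp

a = tf.constant([[4., 1.], [1., 3.]])   # a positive-definite matrix
u = tf.constant([1., 2.])
chol = tf.linalg.cholesky(a)

updated = tfp.math.cholesky_update(chol, u, multiplier=0.5)
expected = tf.linalg.cholesky(a + 0.5 * u[:, tf.newaxis] * u[tf.newaxis, :])
print(tf.reduce_max(tf.abs(updated - expected)).numpy())   # ~0.0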
Example #22
 def _expand_base_distribution_mean(self):
   """Ensures `self.distribution.mean()` has `[batch, event]` shape."""
   single_draw_shape = concat_vectors(self.batch_shape_tensor(),
                                      self.event_shape_tensor())
   m = tf.reshape(
       self.distribution.mean(),  # A scalar.
       shape=tf.ones_like(single_draw_shape, dtype=tf.int32))
   m = tf.tile(m, multiples=single_draw_shape)
   tensorshape_util.set_shape(
       m, tensorshape_util.concatenate(self.batch_shape, self.event_shape))
   return m
Example #23
 def _inverse_event_shape(self, output_shape):
   batch_shape, n1, n2 = (output_shape[:-2],
                          tf.compat.dimension_value(output_shape[-2]),
                          tf.compat.dimension_value(output_shape[-1]))
   if n1 is None or n2 is None:
     m = None
   elif n1 != n2:
     raise ValueError("Matrix must be square. (saw [{}, {}])".format(n1, n2))
   else:
      m = n1 * (n1 + 1) // 2
   return tensorshape_util.concatenate(batch_shape, [m])
Example #24
    def testDistribution(self, data):
        enable_vars = data.draw(hps.booleans())

        # TODO(b/146572907): Fix `enable_vars` for metadistributions.
        broken_dists = list(EVENT_SPACE_BIJECTOR_IS_BROKEN)
        if enable_vars:
            broken_dists.extend(dhps.INSTANTIABLE_META_DISTS)

        dist = data.draw(
            dhps.distributions(
                enable_vars=enable_vars,
                eligibility_filter=(lambda name: name not in broken_dists)))
        self.evaluate([var.initializer for var in dist.variables])
        self.check_bad_loc_scale(dist)

        event_space_bijector = (
            dist._experimental_default_event_space_bijector())
        if event_space_bijector is None:
            return

        total_sample_shape = tensorshape_util.concatenate(
            # Draw a sample shape
            data.draw(tfp_hps.shapes()),
            # Draw a shape that broadcasts with `[batch_shape, inverse_event_shape]`
            # where `inverse_event_shape` is the event shape in the bijector's
            # domain. This is the shape of `y` in R**n, such that
            # x = event_space_bijector(y) has the event shape of the distribution.
            data.draw(
                tfp_hps.broadcasting_shapes(
                    tensorshape_util.concatenate(
                        dist.batch_shape,
                        event_space_bijector.inverse_event_shape(
                            dist.event_shape)),
                    n=1))[0])

        y = data.draw(
            tfp_hps.constrained_tensors(tfp_hps.identity_fn,
                                        total_sample_shape.as_list()))
        x = event_space_bijector(y)
        with tf.control_dependencies(dist._sample_control_dependencies(x)):
            self.evaluate(tf.identity(x))
Example #25
    def _forward(self, x):
        # Pad the last dim with a zeros vector. We need this because it lets us
        # infer the scale in the inverse function.
        y = distribution_util.pad(x, axis=-1, back=True)

        # Set shape hints.
        if tensorshape_util.rank(x.shape) is not None:
            last_dim = tf.compat.dimension_value(x.shape[-1])
            shape = tensorshape_util.concatenate(
                x.shape[:-1], None if last_dim is None else last_dim + 1)
            tensorshape_util.set_shape(y, shape)

        return tf.math.softmax(y)
Example #26
  def _mean(self):
    shape = tensorshape_util.concatenate(self.batch_shape, self.event_shape)
    has_static_shape = tensorshape_util.is_fully_defined(shape)
    if not has_static_shape:
      shape = tf.concat([
          self.batch_shape_tensor(),
          self.event_shape_tensor(),
      ], 0)

    if self.loc is None:
      return tf.zeros(shape, self.dtype)

    return tf.broadcast_to(self.loc, shape)
Example #27
 def _event_shape(self):
     batch_shape = self.distribution.batch_shape
     if self._static_reinterpreted_batch_ndims is None:
         return tf.TensorShape(None)
     if tensorshape_util.rank(batch_shape) is not None:
         reinterpreted_batch_shape = batch_shape[
             tensorshape_util.rank(batch_shape) -
             self._static_reinterpreted_batch_ndims:]
     else:
         reinterpreted_batch_shape = tf.TensorShape(
             [None] * int(self._static_reinterpreted_batch_ndims))
     return tensorshape_util.concatenate(reinterpreted_batch_shape,
                                         self.distribution.event_shape)
Example #28
    def check_event_space_bijector_constrains(self, dist, data):
        event_space_bijector = dist.experimental_default_event_space_bijector()
        if event_space_bijector is None:
            return

        # Draw a sample shape
        sample_shape = data.draw(tfp_hps.shapes())
        inv_event_shape = event_space_bijector.inverse_event_shape(
            tensorshape_util.concatenate(dist.batch_shape, dist.event_shape))

        # Draw a shape that broadcasts with `[batch_shape, inverse_event_shape]`
        # where `inverse_event_shape` is the event shape in the bijector's
        # domain. This is the shape of `y` in R**n, such that
        # x = event_space_bijector(y) has the event shape of the distribution.

        # TODO(b/174778703): Actually draw broadcast compatible shapes.
        batch_inv_event_compat_shape = inv_event_shape
        # batch_inv_event_compat_shape = data.draw(
        #     tfp_hps.broadcast_compatible_shape(inv_event_shape))
        # batch_inv_event_compat_shape = tensorshape_util.concatenate(
        #     (1,) * (len(inv_event_shape) - len(batch_inv_event_compat_shape)),
        #     batch_inv_event_compat_shape)

        total_sample_shape = tensorshape_util.concatenate(
            sample_shape, batch_inv_event_compat_shape)
        # full_sample_batch_event_shape = tensorshape_util.concatenate(
        #     sample_shape, inv_event_shape)

        y = data.draw(
            tfp_hps.constrained_tensors(tfp_hps.identity_fn,
                                        total_sample_shape.as_list()))
        hp.note('Trying to constrain inputs {}'.format(y))
        with tfp_hps.no_tf_rank_errors():
            x = event_space_bijector(y)
            hp.note('Got constrained samples {}'.format(x))
            with tf.control_dependencies(dist._sample_control_dependencies(x)):
                self.evaluate(tensor_util.identity_as_tensor(x))
Example #29
  def _mean(self):
    shape = tensorshape_util.concatenate(self.batch_shape, self.event_shape)
    has_static_shape = tensorshape_util.is_fully_defined(shape)
    if not has_static_shape:
      shape = tf.concat([
          self.batch_shape_tensor(),
          self.event_shape_tensor(),
      ], 0)

    if self.loc is None:
      return tf.zeros(shape, self.dtype)

    if has_static_shape and shape == self.loc.shape:
      return tf.identity(self.loc)

    # Add dummy tensor of zeros to broadcast.  This is only necessary if shape
    # != self.loc.shape, but we could not determine if this is the case.
    return tf.identity(self.loc) + tf.zeros(shape, self.dtype)
Example #30
def _replace_event_shape_in_tensorshape(
    input_tensorshape, event_shape_in, event_shape_out):
  """Replaces the event shape dims of a `TensorShape`.

  Args:
    input_tensorshape: a `TensorShape` instance in which to attempt replacing
      event shape.
    event_shape_in: `Tensor` shape representing the event shape expected to
      be present in (rightmost dims of) `input_tensorshape`. Must be
      compatible with the rightmost dims of `input_tensorshape`.
    event_shape_out: `Tensor` shape representing the new event shape, i.e.,
      the replacement of `event_shape_in`.

  Returns:
    output_tensorshape: `TensorShape` with the rightmost `event_shape_in`
      replaced by `event_shape_out`. Might be partially defined, i.e.,
      `TensorShape(None)`.
    is_validated: Python `bool` indicating static validation happened.

  Raises:
    ValueError: if the event shape portion of `input_tensorshape` and
      `event_shape_in` can both be determined statically but are not
      compatible. "Compatible" here means that they are identical on any dims
      that are not -1 in `event_shape_in`.
  """
  event_shape_in_ndims = tensorshape_util.num_elements(event_shape_in.shape)
  if tensorshape_util.rank(
      input_tensorshape) is None or event_shape_in_ndims is None:
    return tf.TensorShape(None), False  # Not is_validated.

  input_non_event_ndims = tensorshape_util.rank(
      input_tensorshape) - event_shape_in_ndims
  if input_non_event_ndims < 0:
    raise ValueError(
        'Input has fewer ndims ({}) than event shape ndims ({}).'.format(
            tensorshape_util.rank(input_tensorshape), event_shape_in_ndims))

  input_non_event_tensorshape = input_tensorshape[:input_non_event_ndims]
  input_event_tensorshape = input_tensorshape[input_non_event_ndims:]

  # Check that `input_event_shape_` and `event_shape_in` are compatible in the
  # sense that they have equal entries in any position that isn't a `-1` in
  # `event_shape_in`. Note that our validations at construction time ensure
  # there is at most one such entry in `event_shape_in`.
  event_shape_in_ = tf.get_static_value(event_shape_in)
  is_validated = (
      tensorshape_util.is_fully_defined(input_event_tensorshape) and
      event_shape_in_ is not None)
  if is_validated:
    input_event_shape_ = np.int32(input_event_tensorshape)
    mask = event_shape_in_ >= 0
    explicit_input_event_shape_ = input_event_shape_[mask]
    explicit_event_shape_in_ = event_shape_in_[mask]
    if not all(explicit_input_event_shape_ == explicit_event_shape_in_):
      raise ValueError(
          'Input `event_shape` does not match `event_shape_in`. '
          '({} vs {}).'.format(input_event_shape_, event_shape_in_))

  event_tensorshape_out = tensorshape_util.constant_value_as_shape(
      event_shape_out)
  if tensorshape_util.rank(event_tensorshape_out) is None:
    output_tensorshape = tf.TensorShape(None)
  else:
    output_tensorshape = tensorshape_util.concatenate(
        input_non_event_tensorshape, event_tensorshape_out)

  return output_tensorshape, is_validated
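The same bookkeeping is visible through the public `Reshape` bijector, which a helper like this backs; a minimal sketch:

import tensorflow_probability as tfp

tfb = tfp.bijectors

# Replace the rightmost event dims [6] with [2, 3]; batch dims pass through.
reshape = tfb.Reshape(event_shape_out=[2, 3], event_shape_in=[6])
print(reshape.forward_event_shape([5, 6]))      # (5, 2, 3)
print(reshape.inverse_event_shape([5, 2, 3]))   # (5, 6)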