 def _sample_n(self, n, seed=None):
   sample_shape = _concat_vectors(
       distribution_util.pick_vector(self._needs_rotation, self._empty, [n]),
       self._override_batch_shape,
       self._override_event_shape,
       distribution_util.pick_vector(self._needs_rotation, [n], self._empty))
   x = self.distribution.sample(sample_shape=sample_shape, seed=seed)
   x = self._maybe_rotate_dims(x)
   # We'll apply the bijector in the `_call_sample_n` function.
   return x
 def _sample_n(self, n, seed=None, **distribution_kwargs):
   sample_shape = prefer_static.concat([
       distribution_util.pick_vector(self._needs_rotation, self._empty, [n]),
       self._override_batch_shape,
       self._override_event_shape,
       distribution_util.pick_vector(self._needs_rotation, [n], self._empty),
   ], axis=0)
   x = self.distribution.sample(sample_shape=sample_shape, seed=seed,
                                **distribution_kwargs)
   x = self._maybe_rotate_dims(x)
   # We'll apply the bijector in the `_call_sample_n` function.
   return x
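The two `pick_vector` calls above place the draw count `n` at either end of the sample shape, depending on whether the draw dimension must later be moved by `_maybe_rotate_dims`. A minimal sketch of that assembly, assuming TF 2.x and a statically known `needs_rotation` flag (the helper name is ours, for illustration only):

import tensorflow as tf

def build_sample_shape(needs_rotation, n, batch_override, event_override):
  # Mirrors the pick_vector pair above: `n` leads the shape normally,
  # but trails it when the draw dimension must be rotated afterwards.
  empty = tf.constant([], dtype=tf.int32)
  head, tail = (empty, [n]) if needs_rotation else ([n], empty)
  return tf.concat([head, batch_override, event_override, tail], axis=0)

# build_sample_shape(False, 7, [2], [3]) -> [7, 2, 3]
# build_sample_shape(True, 7, [2], [3])  -> [2, 3, 7]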
 def testCorrectlyPicksVector(self):
   x = np.arange(10, 12)
   y = np.arange(15, 18)
   self.assertAllEqual(
       x, self.evaluate(distribution_util.pick_vector(tf.less(0, 5), x, y)))
   self.assertAllEqual(
       y, self.evaluate(distribution_util.pick_vector(tf.less(5, 0), x, y)))
   self.assertAllEqual(x,
                       distribution_util.pick_vector(
                           tf.constant(True), x, y))  # No eval.
   self.assertAllEqual(y,
                       distribution_util.pick_vector(
                           tf.constant(False), x, y))  # No eval.
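As the test shows, `pick_vector(cond, true_vector, false_vector)` returns `true_vector` when `cond` holds and `false_vector` otherwise, and with a statically known condition it short-circuits and hands back one input unchanged (hence the `# No eval.` cases). A rough sketch of that contract, assuming TF 2.x; this is illustrative, not the library source:

import tensorflow as tf

def pick_vector_sketch(cond, true_vector, false_vector):
  # Static condition: no graph op is needed, one input passes through.
  static_cond = tf.get_static_value(tf.convert_to_tensor(cond))
  if static_cond is not None:
    return true_vector if static_cond else false_vector
  # Dynamic condition: concatenate both vectors and slice out the winner.
  n = tf.size(true_vector)
  joined = tf.concat([true_vector, false_vector], axis=0)
  begin = tf.where(cond, 0, n)
  size = tf.where(cond, n, -1)  # -1 means "to the end"
  return tf.slice(joined, [begin], [size])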
  def _make_columnar(self, x):
    """Ensures non-scalar input has at least one column.

    Example:
      If `x = [1, 2, 3]` then the output is `[[1, 2, 3]]`.

      If `x = [[1, 2, 3], [4, 5, 6]]` then the output is unchanged.

      If `x = 1` then the output is unchanged.

    Args:
      x: `Tensor`.

    Returns:
      columnar_x: `Tensor` with at least two dimensions.
    """
    if x.shape.ndims is not None:
      if x.shape.ndims == 1:
        x = x[tf.newaxis, :]
      return x
    shape = tf.shape(x)
    maybe_expanded_shape = tf.concat([
        shape[:-1],
        distribution_util.pick_vector(
            tf.equal(tf.rank(x), 1), [1], np.array([], dtype=np.int32)),
        shape[-1:],
    ], 0)
    return tf.reshape(x, maybe_expanded_shape)
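A quick eager-mode check of the three documented cases (TF 2.x assumed); the free function below just restates the static-rank branch for illustration:

import tensorflow as tf

def make_columnar(x):
  # Static-rank branch of `_make_columnar` above.
  if x.shape.ndims == 1:
    return x[tf.newaxis, :]
  return x

print(make_columnar(tf.constant([1, 2, 3])).shape)         # (1, 3)
print(make_columnar(tf.constant([[1, 2], [3, 4]])).shape)  # (2, 2)
print(make_columnar(tf.constant(1)).shape)                 # ()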
Example 6
    def _expand_sample_shape_to_vector(self, x, name):
        """Helper to `sample` which ensures input is 1D."""
        x_static_val = tensor_util.constant_value(x)
        if x_static_val is None:
            prod = tf.reduce_prod(x)
        else:
            prod = np.prod(x_static_val, dtype=x.dtype.as_numpy_dtype())

        ndims = x.get_shape().ndims  # != sample_ndims
        if ndims is None:
            # Maybe expand_dims.
            ndims = tf.rank(x)
            expanded_shape = util.pick_vector(tf.equal(ndims, 0),
                                              np.array([1], dtype=np.int32),
                                              tf.shape(x))
            x = tf.reshape(x, expanded_shape)
        elif ndims == 0:
            # Definitely expand_dims.
            if x_static_val is not None:
                x = tf.convert_to_tensor(np.array(
                    [x_static_val], dtype=x.dtype.as_numpy_dtype()),
                                         name=name)
            else:
                x = tf.reshape(x, [1])
        elif ndims != 1:
            raise ValueError("Input is neither scalar nor vector.")

        return x, prod
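In isolation, the helper turns a scalar sample shape into a length-1 vector and also returns the total sample count. A simplified eager sketch (TF 2.x assumed, static ranks only):

import tensorflow as tf

def expand_sample_shape_to_vector(x):
  x = tf.convert_to_tensor(x, dtype=tf.int32)
  prod = tf.reduce_prod(x)
  if x.shape.ndims == 0:
    x = tf.reshape(x, [1])  # scalar -> length-1 vector
  elif x.shape.ndims != 1:
    raise ValueError("Input is neither scalar nor vector.")
  return x, prod

# expand_sample_shape_to_vector(5)      -> ([5], 5)
# expand_sample_shape_to_vector([2, 3]) -> ([2, 3], 6)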
    def _batch_shape_tensor(self,
                            override_batch_shape=None,
                            base_batch_shape_tensor=None):
        override_batch_shape = (tf.convert_to_tensor(
            self._override_batch_shape) if override_batch_shape is None else
                                override_batch_shape)
        base_batch_shape_tensor = (self.distribution.batch_shape_tensor()
                                   if base_batch_shape_tensor is None else
                                   base_batch_shape_tensor)

        # The `batch_shape_tensor` of the transformed distribution is the same as
        # that of the base distribution in all cases except when the following are
        # both true:
        #   - the base distribution is joint with structured `batch_shape_tensor`
        #   - the transformed distribution is not joint.
        # In this case, the components of the base distribution's
        # `batch_shape_tensor` are broadcast to obtain the `batch_shape_tensor` of
        # the transformed distribution. Non-broadcasting components are not
        # supported. (Note that joint distributions may either have a single
        # `batch_shape_tensor` for all components, or a component-wise
        # `batch_shape_tensor` with the same nested structure as the distribution's
        # dtype.)
        if tf.nest.is_nested(base_batch_shape_tensor):
            if self._is_joint:
                return base_batch_shape_tensor

            base_batch_shape_tensor = functools.reduce(
                prefer_static.broadcast_shape,
                tf.nest.flatten(base_batch_shape_tensor))

        # If the batch shape has been overridden, return the override batch shape
        # instead.
        return distribution_util.pick_vector(
            self._has_nonzero_rank(override_batch_shape), override_batch_shape,
            base_batch_shape_tensor)
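The `functools.reduce` step above collapses a structured, per-component batch shape into a single broadcast shape. The same reduction using the public `tf.broadcast_dynamic_shape` in place of TFP's internal `prefer_static.broadcast_shape` (illustrative values):

import functools
import tensorflow as tf

component_batch_shapes = [
    tf.constant([3, 1]),
    tf.constant([1, 4]),
    tf.constant([4]),
]
broadcast = functools.reduce(tf.broadcast_dynamic_shape,
                             component_batch_shapes)
# broadcast == [3, 4]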
Example 8
  def _make_columnar(self, x):
    """Ensures non-scalar input has at least one column.

    Example:
      If `x = [1, 2, 3]` then the output is `[[1, 2, 3]]`.

      If `x = [[1, 2, 3], [4, 5, 6]]` then the output is unchanged.

      If `x = 1` then the output is unchanged.

    Args:
      x: `Tensor`.

    Returns:
      columnar_x: `Tensor` with at least two dimensions.
    """
    if tensorshape_util.rank(x.shape) is not None:
      if tensorshape_util.rank(x.shape) == 1:
        x = x[tf.newaxis, :]
      return x
    shape = tf.shape(x)
    maybe_expanded_shape = tf.concat([
        shape[:-1],
        distribution_util.pick_vector(
            tf.equal(tf.rank(x), 1), [1], np.array([], dtype=np.int32)),
        shape[-1:],
    ], 0)
    return tf.reshape(x, maybe_expanded_shape)
Example 9
  def _expand_sample_shape_to_vector(self, x, name):
    """Helper to `sample` which ensures input is 1D."""
    x_static_val = tf.contrib.util.constant_value(x)
    if x_static_val is None:
      prod = tf.reduce_prod(x)
    else:
      prod = np.prod(x_static_val, dtype=x.dtype.as_numpy_dtype())

    ndims = x.shape.ndims  # != sample_ndims
    if ndims is None:
      # Maybe expand_dims.
      ndims = tf.rank(x)
      expanded_shape = util.pick_vector(
          tf.equal(ndims, 0),
          np.array([1], dtype=np.int32), tf.shape(x))
      x = tf.reshape(x, expanded_shape)
    elif ndims == 0:
      # Definitely expand_dims.
      if x_static_val is not None:
        x = tf.convert_to_tensor(
            np.array([x_static_val], dtype=x.dtype.as_numpy_dtype()),
            name=name)
      else:
        x = tf.reshape(x, [1])
    elif ndims != 1:
      raise ValueError("Input is neither scalar nor vector.")

    return x, prod
Example 10
  def _sample_n(self, n, seed=None):
    # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[] then get
    # ids as a [n]-shaped vector.
    distributions = self.poisson_and_mixture_distributions()
    dist, mixture_dist = distributions
    batch_size = tensorshape_util.num_elements(self.batch_shape)
    if batch_size is None:
      batch_size = tf.reduce_prod(
          self._batch_shape_tensor(distributions=distributions))
    # We need to 'sample extra' from the mixture distribution if it doesn't
    # already specify a probs vector for each batch coordinate.
    # We only support this kind of reduced broadcasting, i.e., there is exactly
    # one probs vector for all batch dims or one for each.
    mixture_seed, poisson_seed = samplers.split_seed(
        seed, salt='PoissonLogNormalQuadratureCompound')
    ids = mixture_dist.sample(
        sample_shape=concat_vectors(
            [n],
            distribution_util.pick_vector(
                mixture_dist.is_scalar_batch(),
                [batch_size],
                np.int32([]))),
        seed=mixture_seed)
    # We need to flatten batch dims in case mixture_dist has its own
    # batch dims.
    ids = tf.reshape(
        ids,
        shape=concat_vectors([n],
                             distribution_util.pick_vector(
                                 self.is_scalar_batch(), np.int32([]),
                                 np.int32([-1]))))

    # Stride `quadrature_size` for `batch_size` number of times.
    offset = tf.range(
        start=0,
        limit=batch_size * self._quadrature_size,
        delta=self._quadrature_size,
        dtype=ids.dtype)
    ids = ids + offset
    rate = tf.gather(tf.reshape(dist.rate_parameter(), shape=[-1]), ids)
    rate = tf.reshape(
        rate, shape=concat_vectors([n], self._batch_shape_tensor(
            distributions=distributions)))
    return samplers.poisson(
        shape=[], lam=rate, dtype=self.dtype, seed=poisson_seed)
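The strided `offset` above converts per-batch quadrature ids into indices of a flattened rate table: batch member `b` owns the contiguous block `[b * quadrature_size, (b + 1) * quadrature_size)`. A toy check with made-up sizes (TF 2.x assumed):

import tensorflow as tf

quadrature_size, batch_size = 3, 2
rates = tf.range(6.)                # flat [batch_size * quadrature_size] table
ids = tf.constant([[2, 0]])         # shape [n=1, batch_size]
offset = tf.range(0, batch_size * quadrature_size, quadrature_size,
                  dtype=ids.dtype)  # [0, 3]
picked = tf.gather(rates, ids + offset)
# picked == [[2., 3.]]: rates[2] for batch 0, rates[3 + 0] for batch 1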
  def _batch_shape_tensor(self,
                          override_batch_shape=None,
                          base_batch_shape_tensor=None):
    override_batch_shape = (
        tf.convert_to_tensor(self._override_batch_shape)
        if override_batch_shape is None else override_batch_shape)
    base_batch_shape_tensor = (
        self.distribution.batch_shape_tensor()
        if base_batch_shape_tensor is None else base_batch_shape_tensor)
    return distribution_util.pick_vector(
        self._has_nonzero_rank(override_batch_shape), override_batch_shape,
        base_batch_shape_tensor)
Example 12
  def _sample_n(self, n, seed=None):
    # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[] then get
    # ids as a [n]-shaped vector.
    batch_size = self.batch_shape.num_elements()
    if batch_size is None:
      batch_size = tf.reduce_prod(self.batch_shape_tensor())
    # We need to "sample extra" from the mixture distribution if it doesn't
    # already specify a probs vector for each batch coordinate.
    # We only support this kind of reduced broadcasting, i.e., there is exactly
    # one probs vector for all batch dims or one for each.
    stream = seed_stream.SeedStream(
        seed, salt="PoissonLogNormalQuadratureCompound")
    ids = self._mixture_distribution.sample(
        sample_shape=concat_vectors(
            [n],
            distribution_util.pick_vector(
                self.mixture_distribution.is_scalar_batch(),
                [batch_size],
                np.int32([]))),
        seed=stream())
    # We need to flatten batch dims in case mixture_distribution has its own
    # batch dims.
    ids = tf.reshape(
        ids,
        shape=concat_vectors([n],
                             distribution_util.pick_vector(
                                 self.is_scalar_batch(), np.int32([]),
                                 np.int32([-1]))))

    # Stride `quadrature_size` for `batch_size` number of times.
    offset = tf.range(
        start=0,
        limit=batch_size * self._quadrature_size,
        delta=self._quadrature_size,
        dtype=ids.dtype)
    ids += offset
    rate = tf.gather(tf.reshape(self.distribution.rate, shape=[-1]), ids)
    rate = tf.reshape(
        rate, shape=concat_vectors([n], self.batch_shape_tensor()))
    return tf.random_poisson(lam=rate, shape=[], dtype=self.dtype, seed=seed)
Example 13
    def _sample_n(self, n, seed=None):
        # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[] then get
        # ids as a [n]-shaped vector.
        batch_size = self.batch_shape.num_elements()
        if batch_size is None:
            batch_size = tf.reduce_prod(input_tensor=self.batch_shape_tensor())
        # We need to "sample extra" from the mixture distribution if it doesn't
        # already specify a probs vector for each batch coordinate.
        # We only support this kind of reduced broadcasting, i.e., there is exactly
        # one probs vector for all batch dims or one for each.
        stream = seed_stream.SeedStream(
            seed, salt="PoissonLogNormalQuadratureCompound")
        ids = self._mixture_distribution.sample(sample_shape=concat_vectors(
            [n],
            distribution_util.pick_vector(
                self.mixture_distribution.is_scalar_batch(), [batch_size],
                np.int32([]))),
                                                seed=stream())
        # We need to flatten batch dims in case mixture_distribution has its own
        # batch dims.
        ids = tf.reshape(ids,
                         shape=concat_vectors([n],
                                              distribution_util.pick_vector(
                                                  self.is_scalar_batch(),
                                                  np.int32([]),
                                                  np.int32([-1]))))

        # Stride `quadrature_size` for `batch_size` number of times.
        offset = tf.range(start=0,
                          limit=batch_size * self._quadrature_size,
                          delta=self._quadrature_size,
                          dtype=ids.dtype)
        ids += offset
        rate = tf.gather(tf.reshape(self.distribution.rate, shape=[-1]), ids)
        rate = tf.reshape(rate,
                          shape=concat_vectors([n], self.batch_shape_tensor()))
        return tf.random.poisson(lam=rate,
                                 shape=[],
                                 dtype=self.dtype,
                                 seed=seed)
  def _event_shape_tensor(self,
                          override_event_shape=None,
                          base_event_shape_tensor=None):
    override_event_shape = (
        tf.convert_to_tensor(self._override_event_shape)
        if override_event_shape is None else override_event_shape)
    base_event_shape_tensor = (
        self.distribution.event_shape_tensor()
        if base_event_shape_tensor is None else base_event_shape_tensor)
    return self.bijector.forward_event_shape_tensor(
        distribution_util.pick_vector(
            self._has_nonzero_rank(override_event_shape),
            override_event_shape, base_event_shape_tensor))
    def _sample_n(self, n, seed=None, **distribution_kwargs):
        override_event_shape = tf.convert_to_tensor(self._override_event_shape)
        override_batch_shape = tf.convert_to_tensor(self._override_batch_shape)
        base_is_scalar_batch = self.distribution.is_scalar_batch()

        needs_rotation = self._needs_rotation(override_event_shape,
                                              override_batch_shape,
                                              base_is_scalar_batch)
        sample_shape = prefer_static.concat([
            distribution_util.pick_vector(needs_rotation, self._empty, [n]),
            override_batch_shape,
            override_event_shape,
            distribution_util.pick_vector(needs_rotation, [n], self._empty),
        ],
                                            axis=0)
        x = self.distribution.sample(sample_shape=sample_shape,
                                     seed=seed,
                                     **distribution_kwargs)
        x = self._maybe_rotate_dims(x, override_event_shape,
                                    override_batch_shape, base_is_scalar_batch)
        # We'll apply the bijector in the `_call_sample_n` function.
        return x
Example 16
    def make_batch_of_event_sample_matrices(
            self,
            x,
            expand_batch_dim=True,
            name="make_batch_of_event_sample_matrices"):
        """Reshapes/transposes `Distribution` `Tensor` from S+B+E to B_+E_+S_.

    Where:
      - `B_ = B if B or not expand_batch_dim else [1]`,
      - `E_ = E if E else [1]`,
      - `S_ = [tf.reduce_prod(S)]`.

    Args:
      x: `Tensor`.
      expand_batch_dim: Python `bool`. If `True` the batch dims will be expanded
        such that `batch_ndims >= 1`.
      name: Python `str`. The name to give this op.

    Returns:
      x: `Tensor`. Input transposed/reshaped to `B_+E_+S_`.
      sample_shape: `Tensor` (1D, `int32`).
    """
        with self._name_scope(name, values=[x]):
            x = tf.convert_to_tensor(x, name="x")
            # x.shape: S+B+E
            sample_shape, batch_shape, event_shape = self.get_shape(x)
            event_shape = distribution_util.pick_vector(
                self._event_ndims_is_0, [1], event_shape)
            if expand_batch_dim:
                batch_shape = distribution_util.pick_vector(
                    self._batch_ndims_is_0, [1], batch_shape)
            new_shape = tf.concat([[-1], batch_shape, event_shape], 0)
            x = tf.reshape(x, shape=new_shape)
            # x.shape: [prod(S)]+B_+E_
            x = distribution_util.rotate_transpose(x, shift=-1)
            # x.shape: B_+E_+[prod(S)]
            return x, sample_shape
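The reshape-then-rotate at the end is the heart of the S+B+E to B_+E_+S_ move. The same effect with a plain `tf.transpose` standing in for `distribution_util.rotate_transpose(x, shift=-1)` (illustrative shapes):

import tensorflow as tf

x = tf.zeros([5, 4, 3, 2])        # S = [5], B = [4], E = [3, 2]
x = tf.reshape(x, [-1, 4, 3, 2])  # flatten sample dims: [prod(S)] + B + E
perm = tf.concat([tf.range(1, tf.rank(x)), [0]], axis=0)
x = tf.transpose(x, perm)         # rotate left by one: B + E + [prod(S)]
# x.shape == [4, 3, 2, 5]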
    def _event_shape_tensor(self,
                            override_event_shape=None,
                            base_event_shape_tensor=None):
        override_event_shape = (tf.convert_to_tensor(
            self._override_event_shape) if override_event_shape is None else
                                override_event_shape)
        base_event_shape_tensor = (self.distribution.event_shape_tensor()
                                   if base_event_shape_tensor is None else
                                   base_event_shape_tensor)

        # If the base distribution is not joint, use the base event shape override,
        # if any.
        if not self._base_is_joint:
            base_event_shape_tensor = distribution_util.pick_vector(
                self._has_nonzero_rank(override_event_shape),
                override_event_shape, base_event_shape_tensor)
        return self.bijector.forward_event_shape_tensor(
            base_event_shape_tensor)
Example 18
  def __init__(self,
               df,
               loc=None,
               scale_identity_multiplier=None,
               scale_diag=None,
               scale_tril=None,
               scale_perturb_factor=None,
               scale_perturb_diag=None,
               validate_args=False,
               allow_nan_stats=True,
               name="VectorStudentT"):
    """Instantiates the vector Student's t-distributions on `R^k`.

    The `batch_shape` is the broadcast between `df.batch_shape` and
    `Affine.batch_shape` where `Affine` is constructed from `loc` and
    `scale_*` arguments.

    The `event_shape` is the event shape of `Affine.event_shape`.

    Args:
      df: Floating-point `Tensor`. The degrees of freedom of the
        distribution(s). `df` must contain only positive values. Must be
        scalar if `loc`, `scale_*` imply non-scalar batch_shape or must have the
        same `batch_shape` implied by `loc`, `scale_*`.
      loc: Floating-point `Tensor`. If this is set to `None`, no `loc` is
        applied.
      scale_identity_multiplier: Floating-point rank-0 `Tensor` representing a
        scaling done to the identity matrix. When `scale_identity_multiplier =
        scale_diag = scale_tril = None` then `scale += IdentityMatrix`.
        Otherwise no scaled-identity-matrix is added to `scale`.
      scale_diag: Floating-point `Tensor` representing the diagonal matrix.
        `scale_diag` has shape [N1, N2, ..., k], which represents a k x k
        diagonal matrix. When `None` no diagonal term is added to `scale`.
      scale_tril: Floating-point `Tensor` representing the lower triangular
        matrix. `scale_tril` has shape [N1, N2, ..., k, k], which represents a
        k x k lower triangular matrix. When `None` no `scale_tril` term is
        added to `scale`. The upper triangular elements above the diagonal are
        ignored.
      scale_perturb_factor: Floating-point `Tensor` representing a factor
        matrix with last two dimensions of shape `(k, r)`. When `None`, no
        rank-r update is added to `scale`.
      scale_perturb_diag: Floating-point `Tensor` representing the diagonal
        matrix. `scale_perturb_diag` has shape [N1, N2, ..., r], which
        represents an r x r diagonal matrix. When `None`, low rank updates
        will take the form `scale_perturb_factor * scale_perturb_factor.T`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    parameters = dict(locals())
    graph_parents = [df, loc, scale_identity_multiplier, scale_diag,
                     scale_tril, scale_perturb_factor, scale_perturb_diag]
    with tf.name_scope(name) as name:
      with tf.name_scope("init", values=graph_parents):
        # The shape of the _VectorStudentT distribution is governed by the
        # relationship between df.batch_shape and affine.batch_shape. In
        # pseudocode the basic procedure is:
        #   if df.batch_shape is scalar:
        #     if affine.batch_shape is not scalar:
        #       # broadcast distribution.sample so
        #       # it has affine.batch_shape.
        #     self.batch_shape = affine.batch_shape
        #   else:
        #     if affine.batch_shape is scalar:
        #       # let affine broadcasting do its thing.
        #     self.batch_shape = df.batch_shape
        # All of the above magic is actually handled by TransformedDistribution.
        # Here we really only need to collect the affine.batch_shape and decide
        # what we're going to pass in to TransformedDistribution's
        # (override) batch_shape arg.
        affine = bijectors.Affine(
            shift=loc,
            scale_identity_multiplier=scale_identity_multiplier,
            scale_diag=scale_diag,
            scale_tril=scale_tril,
            scale_perturb_factor=scale_perturb_factor,
            scale_perturb_diag=scale_perturb_diag,
            validate_args=validate_args)
        distribution = student_t.StudentT(
            df=df,
            loc=tf.zeros([], dtype=affine.dtype),
            scale=tf.ones([], dtype=affine.dtype))
        batch_shape, override_event_shape = (
            distribution_util.shapes_from_loc_and_scale(
                affine.shift, affine.scale))
        override_batch_shape = distribution_util.pick_vector(
            distribution.is_scalar_batch(), batch_shape,
            tf.constant([], dtype=tf.int32))
        super(_VectorStudentT, self).__init__(
            distribution=distribution,
            bijector=affine,
            batch_shape=override_batch_shape,
            event_shape=override_event_shape,
            validate_args=validate_args,
            name=name)
        self._parameters = parameters
  def __init__(self,
               mixture_distribution,
               components_distribution,
               validate_args=False,
               allow_nan_stats=True,
               name="MixtureSameFamily"):
    """Construct a `MixtureSameFamily` distribution.

    Args:
      mixture_distribution: `tfp.distributions.Categorical`-like instance.
        Manages the probability of selecting components. The number of
        categories must match the rightmost batch dimension of the
        `components_distribution`. Must have either scalar `batch_shape` or
        `batch_shape` matching `components_distribution.batch_shape[:-1]`.
      components_distribution: `tfp.distributions.Distribution`-like instance.
        Right-most batch dimension indexes components.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: `if not mixture_distribution.dtype.is_integer`.
      ValueError: if mixture_distribution does not have scalar `event_shape`.
      ValueError: if `mixture_distribution.batch_shape` and
        `components_distribution.batch_shape[:-1]` are both fully defined and
        the former is neither scalar nor equal to the latter.
      ValueError: if `mixture_distribution` categories does not equal
        `components_distribution` rightmost batch shape.
    """
    parameters = dict(locals())
    with tf.name_scope(name) as name:
      self._mixture_distribution = mixture_distribution
      self._components_distribution = components_distribution
      self._runtime_assertions = []

      s = components_distribution.event_shape_tensor()
      self._event_ndims = (
          s.shape[0].value if s.shape.with_rank_at_least(1)[0].value is not None
          else tf.shape(s)[0])

      if not mixture_distribution.dtype.is_integer:
        raise ValueError(
            "`mixture_distribution.dtype` ({}) is not over integers".format(
                mixture_distribution.dtype.name))

      if (mixture_distribution.event_shape.ndims is not None
          and mixture_distribution.event_shape.ndims != 0):
        raise ValueError("`mixture_distribution` must have scalar `event_dim`s")
      elif validate_args:
        self._runtime_assertions += [
            tf.assert_rank(
                mixture_distribution.event_shape_tensor(), 0,
                message="`mixture_distribution` must have scalar `event_dim`s"),
        ]

      mdbs = mixture_distribution.batch_shape
      cdbs = components_distribution.batch_shape.with_rank_at_least(1)[:-1]
      if mdbs.is_fully_defined() and cdbs.is_fully_defined():
        if mdbs.ndims != 0 and mdbs != cdbs:
          raise ValueError(
              "`mixture_distribution.batch_shape` (`{}`) is not "
              "compatible with `components_distribution.batch_shape` "
              "(`{}`)".format(mdbs.as_list(), cdbs.as_list()))
      elif validate_args:
        mdbs = mixture_distribution.batch_shape_tensor()
        cdbs = components_distribution.batch_shape_tensor()[:-1]
        self._runtime_assertions += [
            tf.assert_equal(
                distribution_utils.pick_vector(
                    mixture_distribution.is_scalar_batch(), cdbs, mdbs),
                cdbs,
                message=(
                    "`mixture_distribution.batch_shape` is not "
                    "compatible with `components_distribution.batch_shape`"))]

      km = mixture_distribution.logits.shape.with_rank_at_least(1)[-1].value
      kc = components_distribution.batch_shape.with_rank_at_least(1)[-1].value
      if km is not None and kc is not None and km != kc:
        raise ValueError("`mixture_distribution components` ({}) does not "
                         "equal `components_distribution.batch_shape[-1]` "
                         "({})".format(km, kc))
      elif validate_args:
        km = tf.shape(mixture_distribution.logits)[-1]
        kc = components_distribution.batch_shape_tensor()[-1]
        self._runtime_assertions += [
            tf.assert_equal(
                km, kc,
                message=("`mixture_distribution components` does not equal "
                         "`components_distribution.batch_shape[-1:]`")),
        ]
      elif km is None:
        km = tf.shape(mixture_distribution.logits)[-1]

      self._num_components = km

      super(MixtureSameFamily, self).__init__(
          dtype=self._components_distribution.dtype,
          reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          parameters=parameters,
          graph_parents=(
              self._mixture_distribution._graph_parents  # pylint: disable=protected-access
              + self._components_distribution._graph_parents),  # pylint: disable=protected-access
          name=name)
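One detail worth noting in the `validate_args` branch above: `pick_vector` turns the batch-shape assertion into a conditional no-op. When the mixture distribution has scalar batch shape, the assert compares `cdbs` with itself and trivially passes; otherwise it compares `mdbs` against `cdbs`. In plain Python the predicate amounts to (helper name is ours):

def batch_shapes_compatible(mixture_is_scalar_batch, mdbs, cdbs):
  # Scalar mixture batch: nothing to check, so compare cdbs with itself.
  return cdbs == (cdbs if mixture_is_scalar_batch else mdbs)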
  def _sample_n(self, n, seed=None):
    stream = seed_stream.SeedStream(seed, salt="VectorDiffeomixture")
    x = self.distribution.sample(
        sample_shape=concat_vectors(
            [n],
            self.batch_shape_tensor(),
            self.event_shape_tensor()),
        seed=stream())   # shape: [n, B, e]
    x = [aff.forward(x) for aff in self.endpoint_affine]

    # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[] then get
    # ids as a [n]-shaped vector.
    batch_size = self.batch_shape.num_elements()
    if batch_size is None:
      batch_size = tf.reduce_prod(self.batch_shape_tensor())
    mix_batch_size = self.mixture_distribution.batch_shape.num_elements()
    if mix_batch_size is None:
      mix_batch_size = tf.reduce_prod(
          self.mixture_distribution.batch_shape_tensor())
    ids = self.mixture_distribution.sample(
        sample_shape=concat_vectors(
            [n],
            distribution_util.pick_vector(
                self.is_scalar_batch(),
                np.int32([]),
                [batch_size // mix_batch_size])),
        seed=stream())
    # We need to flatten batch dims in case mixture_distribution has its own
    # batch dims.
    ids = tf.reshape(
        ids,
        shape=concat_vectors([n],
                             distribution_util.pick_vector(
                                 self.is_scalar_batch(), np.int32([]),
                                 np.int32([-1]))))

    # Stride `components * quadrature_size` for `batch_size` number of times.
    stride = self.grid.shape.with_rank_at_least(
        2)[-2:].num_elements()
    if stride is None:
      stride = tf.reduce_prod(tf.shape(self.grid)[-2:])
    offset = tf.range(
        start=0, limit=batch_size * stride, delta=stride, dtype=ids.dtype)

    weight = tf.gather(tf.reshape(self.grid, shape=[-1]), ids + offset)
    # At this point, weight flattened all batch dims into one.
    # We also need to append a singleton to broadcast with event dims.
    if self.batch_shape.is_fully_defined():
      new_shape = [-1] + self.batch_shape.as_list() + [1]
    else:
      new_shape = tf.concat(([-1], self.batch_shape_tensor(), [1]), axis=0)
    weight = tf.reshape(weight, shape=new_shape)

    if len(x) != 2:
      # We actually should have already triggered this exception. However as a
      # policy we're putting this exception wherever we exploit the bimixture
      # assumption.
      raise NotImplementedError("Currently only bimixtures are supported; "
                                "len(scale)={} is not 2.".format(len(x)))

    # Alternatively:
    # x = weight * x[0] + (1. - weight) * x[1]
    x = weight * (x[0] - x[1]) + x[1]

    return x
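The final line applies the algebraic identity noted in the comment, `w * x0 + (1 - w) * x1 == w * (x0 - x1) + x1`, which saves one broadcasted multiply against the `weight` tensor. A quick NumPy check with made-up values:

import numpy as np

w = 0.25
x0, x1 = np.array([1., 2.]), np.array([3., 5.])
assert np.allclose(w * (x0 - x1) + x1, w * x0 + (1. - w) * x1)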
Example 21
    def _sample_n(self, n, seed=None):
        stream = seed_stream.SeedStream(seed, salt="VectorDiffeomixture")
        x = self.distribution.sample(sample_shape=concat_vectors(
            [n], self.batch_shape_tensor(), self.event_shape_tensor()),
                                     seed=stream())  # shape: [n, B, e]
        x = [aff.forward(x) for aff in self.endpoint_affine]

        # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[] then get
        # ids as a [n]-shaped vector.
        batch_size = tensorshape_util.num_elements(self.batch_shape)
        if batch_size is None:
            batch_size = tf.reduce_prod(input_tensor=self.batch_shape_tensor())
        mix_batch_size = tensorshape_util.num_elements(
            self.mixture_distribution.batch_shape)
        if mix_batch_size is None:
            mix_batch_size = tf.reduce_prod(
                input_tensor=self.mixture_distribution.batch_shape_tensor())
        ids = self.mixture_distribution.sample(sample_shape=concat_vectors(
            [n],
            distribution_util.pick_vector(self.is_scalar_batch(), np.int32([]),
                                          [batch_size // mix_batch_size])),
                                               seed=stream())
        # We need to flatten batch dims in case mixture_distribution has its own
        # batch dims.
        ids = tf.reshape(ids,
                         shape=concat_vectors([n],
                                              distribution_util.pick_vector(
                                                  self.is_scalar_batch(),
                                                  np.int32([]),
                                                  np.int32([-1]))))

        # Stride `components * quadrature_size` for `batch_size` number of times.
        stride = tensorshape_util.num_elements(
            tensorshape_util.with_rank_at_least(self.grid.shape, 2)[-2:])
        if stride is None:
            stride = tf.reduce_prod(input_tensor=tf.shape(
                input=self.grid)[-2:])
        offset = tf.range(start=0,
                          limit=batch_size * stride,
                          delta=stride,
                          dtype=ids.dtype)

        weight = tf.gather(tf.reshape(self.grid, shape=[-1]), ids + offset)
        # At this point, weight flattened all batch dims into one.
        # We also need to append a singleton to broadcast with event dims.
        if tensorshape_util.is_fully_defined(self.batch_shape):
            new_shape = [-1] + tensorshape_util.as_list(self.batch_shape) + [1]
        else:
            new_shape = tf.concat(([-1], self.batch_shape_tensor(), [1]),
                                  axis=0)
        weight = tf.reshape(weight, shape=new_shape)

        if len(x) != 2:
            # We actually should have already triggered this exception. However as a
            # policy we're putting this exception wherever we exploit the bimixture
            # assumption.
            raise NotImplementedError(
                "Currently only bimixtures are supported; "
                "len(scale)={} is not 2.".format(len(x)))

        # Alternatively:
        # x = weight * x[0] + (1. - weight) * x[1]
        x = weight * (x[0] - x[1]) + x[1]

        return x
Example 22
    def __init__(self,
                 mixture_distribution,
                 components_distribution,
                 reparameterize=False,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="MixtureSameFamily"):
        """Construct a `MixtureSameFamily` distribution.

    Args:
      mixture_distribution: `tfp.distributions.Categorical`-like instance.
        Manages the probability of selecting components. The number of
        categories must match the rightmost batch dimension of the
        `components_distribution`. Must have either scalar `batch_shape` or
        `batch_shape` matching `components_distribution.batch_shape[:-1]`.
      components_distribution: `tfp.distributions.Distribution`-like instance.
        Right-most batch dimension indexes components.
      reparameterize: Python `bool`, default `False`. Whether to reparameterize
        samples of the distribution using implicit reparameterization gradients
        [(Figurnov et al., 2018)][1]. The gradients for the mixture logits are
        equivalent to the ones described by [(Graves, 2016)][2]. The gradients
        for the components parameters are also computed using implicit
        reparameterization (as opposed to ancestral sampling), meaning that
        all components are updated every step.
        Only works when:
          (1) components_distribution is fully reparameterized;
          (2) components_distribution is either a scalar distribution or
          fully factorized (tfd.Independent applied to a scalar distribution);
          (3) batch shape has a known rank.
        Experimental, may be slow and produce infs/NaNs.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: `if not mixture_distribution.dtype.is_integer`.
      ValueError: if mixture_distribution does not have scalar `event_shape`.
      ValueError: if `mixture_distribution.batch_shape` and
        `components_distribution.batch_shape[:-1]` are both fully defined and
        the former is neither scalar nor equal to the latter.
      ValueError: if `mixture_distribution` categories does not equal
        `components_distribution` rightmost batch shape.

    #### References

    [1]: Michael Figurnov, Shakir Mohamed and Andriy Mnih. Implicit
         reparameterization gradients. In _Neural Information Processing
         Systems_, 2018. https://arxiv.org/abs/1805.08498

    [2]: Alex Graves. Stochastic Backpropagation through Mixture Density
         Distributions. _arXiv_, 2016. https://arxiv.org/abs/1607.05690
    """
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            self._mixture_distribution = mixture_distribution
            self._components_distribution = components_distribution
            self._runtime_assertions = []

            s = components_distribution.event_shape_tensor()
            self._event_ndims = tf.compat.dimension_value(s.shape[0])
            if self._event_ndims is None:
                self._event_ndims = tf.size(input=s)
            self._event_size = tf.reduce_prod(input_tensor=s)

            if not mixture_distribution.dtype.is_integer:
                raise ValueError(
                    "`mixture_distribution.dtype` ({}) is not over integers".
                    format(mixture_distribution.dtype.name))

            if (mixture_distribution.event_shape.ndims is not None
                    and mixture_distribution.event_shape.ndims != 0):
                raise ValueError(
                    "`mixture_distribution` must have scalar `event_dim`s")
            elif validate_args:
                self._runtime_assertions += [
                    assert_util.assert_equal(
                        tf.size(
                            input=mixture_distribution.event_shape_tensor()),
                        0,
                        message=
                        "`mixture_distribution` must have scalar `event_dim`s"
                    ),
                ]

            mdbs = mixture_distribution.batch_shape
            cdbs = components_distribution.batch_shape.with_rank_at_least(
                1)[:-1]
            if mdbs.is_fully_defined() and cdbs.is_fully_defined():
                if mdbs.ndims != 0 and mdbs != cdbs:
                    raise ValueError(
                        "`mixture_distribution.batch_shape` (`{}`) is not "
                        "compatible with `components_distribution.batch_shape` "
                        "(`{}`)".format(mdbs.as_list(), cdbs.as_list()))
            elif validate_args:
                mdbs = mixture_distribution.batch_shape_tensor()
                cdbs = components_distribution.batch_shape_tensor()[:-1]
                self._runtime_assertions += [
                    assert_util.assert_equal(
                        distribution_utils.pick_vector(
                            mixture_distribution.is_scalar_batch(), cdbs,
                            mdbs),
                        cdbs,
                        message=
                        ("`mixture_distribution.batch_shape` is not "
                         "compatible with `components_distribution.batch_shape`"
                         ))
                ]

            km = tf.compat.dimension_value(
                mixture_distribution.logits.shape.with_rank_at_least(1)[-1])
            kc = tf.compat.dimension_value(
                components_distribution.batch_shape.with_rank_at_least(1)[-1])
            if km is not None and kc is not None and km != kc:
                raise ValueError(
                    "`mixture_distribution components` ({}) does not "
                    "equal `components_distribution.batch_shape[-1]` "
                    "({})".format(km, kc))
            elif validate_args:
                km = tf.shape(input=mixture_distribution.logits)[-1]
                kc = components_distribution.batch_shape_tensor()[-1]
                self._runtime_assertions += [
                    assert_util.assert_equal(
                        km,
                        kc,
                        message=(
                            "`mixture_distribution components` does not equal "
                            "`components_distribution.batch_shape[-1:]`")),
                ]
            elif km is None:
                km = tf.shape(input=mixture_distribution.logits)[-1]

            self._num_components = km

            self._reparameterize = reparameterize
            if reparameterize:
                # Note: tfd.Independent passes through the reparameterization type hence
                # we do not need separate logic for Independent.
                if (self._components_distribution.reparameterization_type !=
                        reparameterization.FULLY_REPARAMETERIZED):
                    raise ValueError("Cannot reparameterize a mixture of "
                                     "non-reparameterized components.")
                reparameterization_type = reparameterization.FULLY_REPARAMETERIZED
            else:
                reparameterization_type = reparameterization.NOT_REPARAMETERIZED

            super(MixtureSameFamily, self).__init__(
                dtype=self._components_distribution.dtype,
                reparameterization_type=reparameterization_type,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                parameters=parameters,
                graph_parents=(
                    self._mixture_distribution._graph_parents  # pylint: disable=protected-access
                    + self._components_distribution._graph_parents),  # pylint: disable=protected-access
                name=name)
  def _event_shape_tensor(self):
    return self.bijector.forward_event_shape_tensor(
        distribution_util.pick_vector(
            self._is_event_override,
            self._override_event_shape,
            self.distribution.event_shape_tensor()))
Example 24
  def _forward_log_det_jacobian(self, x):
    # Let Y be a symmetric, positive definite matrix and write:
    #   Y = X X.T
    # where X is lower-triangular.
    #
    # Observe that,
    #   dY[i,j]/dX[a,b]
    #   = d/dX[a,b] { X[i,:] X[j,:] }
    #   = sum_{d=1}^p { I[i=a] I[d=b] X[j,d] + I[j=a] I[d=b] X[i,d] }
    #
    # To compute the Jacobian dX/dY we must represent X,Y as vectors. Since Y is
    # symmetric and X is lower-triangular, we need vectors of dimension:
    #   d = p (p + 1) / 2
    # where X, Y are p x p matrices, p > 0. We use a row-major mapping, i.e.,
    #   k = { i (i + 1) / 2 + j   i>=j
    #       { undef               i<j
    # and assume zero-based indexes. When k is undef, the element is dropped.
    # Example:
    #           j      k
    #        0 1 2 3  /
    #    0 [ 0 . . . ]
    # i  1 [ 1 2 . . ]
    #    2 [ 3 4 5 . ]
    #    3 [ 6 7 8 9 ]
    # Write vec[.] to indicate transforming a matrix to vector via k(i,j). (With
    # slight abuse: k(i,j)=undef means the element is dropped.)
    #
    # We now show d vec[Y] / d vec[X] is lower triangular. Assuming both are
    # defined, observe that k(i,j) < k(a,b) iff (1) i<a or (2) i=a and j<b.
    # In both cases dvec[Y]/dvec[X]@[k(i,j),k(a,b)] = 0 since:
    # (1) j<=i<a thus i,j!=a.
    # (2) i=a>j  thus i,j!=a.
    #
    # Since the Jacobian is lower-triangular, we need only compute the product
    # of diagonal elements:
    #   d vec[Y] / d vec[X] @[k(i,j), k(i,j)]
    #   = X[j,j] + I[i=j] X[i,j]
    #   = 2 X[j,j].
    # Since there is a 2 X[j,j] term for every lower-triangular element of X we
    # conclude:
    #   |Jac(d vec[Y]/d vec[X])| = 2^p prod_{j=0}^{p-1} X[j,j]^{p-j}.
    diag = tf.linalg.diag_part(x)

    # We now ensure diag is columnar. Eg, if `diag = [1, 2, 3]` then the output
    # is `[[1], [2], [3]]` and if `diag = [[1, 2, 3], [4, 5, 6]]` then the
    # output is unchanged.
    diag = self._make_columnar(diag)

    with tf.control_dependencies(self._assertions(x)):
      # Create a vector equal to: [p, p-1, ..., 2, 1].
      if tf.compat.dimension_value(x.shape[-1]) is None:
        p_int = tf.shape(x)[-1]
        p_float = tf.cast(p_int, dtype=x.dtype)
      else:
        p_int = tf.compat.dimension_value(x.shape[-1])
        p_float = dtype_util.as_numpy_dtype(x.dtype)(p_int)
      exponents = tf.linspace(p_float, 1., p_int)

      sum_weighted_log_diag = tf.squeeze(
          tf.matmul(tf.math.log(diag), exponents[..., tf.newaxis]), axis=-1)
      fldj = p_float * np.log(2.) + sum_weighted_log_diag

      # We finally need to undo adding an extra column in non-scalar cases
      # where there is a single matrix as input.
      if tensorshape_util.rank(x.shape) is not None:
        if tensorshape_util.rank(x.shape) == 2:
          fldj = tf.squeeze(fldj, axis=-1)
        return fldj

      shape = tf.shape(fldj)
      maybe_squeeze_shape = tf.concat([
          shape[:-1],
          distribution_util.pick_vector(
              tf.equal(tf.rank(x), 2),
              np.array([], dtype=np.int32), shape[-1:])], 0)
      return tf.reshape(fldj, maybe_squeeze_shape)
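The closed form derived in the comment, log|Jac| = p log(2) + sum_j (p - j) log X[j, j], can be spot-checked against a finite-difference Jacobian of the vectorized map vec[X] -> vec[Y]. A NumPy sketch with an arbitrary 3 x 3 example (illustrative, not part of the library):

import numpy as np

p = 3
idx = np.tril_indices(p)  # row-major order over the lower triangle

def f(v):
  # vec[X] -> vec[Y] for Y = X @ X.T with lower-triangular X.
  x = np.zeros((p, p))
  x[idx] = v
  return (x @ x.T)[idx]

rng = np.random.default_rng(0)
x0 = np.tril(rng.random((p, p))) + np.eye(p)  # positive diagonal
v0 = x0[idx]

eps = 1e-6
jac = np.stack([(f(v0 + eps * e) - f(v0 - eps * e)) / (2 * eps)
                for e in np.eye(v0.size)], axis=-1)
lhs = np.log(np.abs(np.linalg.det(jac)))
rhs = p * np.log(2.) + sum((p - j) * np.log(x0[j, j]) for j in range(p))
assert np.allclose(lhs, rhs, atol=1e-4)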
    def _parameter_control_dependencies(self, is_init):
        assertions = []

        if is_init and not dtype_util.is_integer(
                self.mixture_distribution.dtype):
            raise ValueError(
                '`mixture_distribution.dtype` ({}) is not over integers'.
                format(dtype_util.name(self.mixture_distribution.dtype)))

        if tensorshape_util.rank(
                self.mixture_distribution.event_shape) is not None:
            if tensorshape_util.rank(
                    self.mixture_distribution.event_shape) != 0:
                raise ValueError(
                    '`mixture_distribution` must have scalar `event_dim`s')
        elif self.validate_args:
            assertions += [
                assert_util.assert_equal(
                    tf.size(self.mixture_distribution.event_shape_tensor()),
                    0,
                    message=
                    '`mixture_distribution` must have scalar `event_dim`s'),
            ]

        # pylint: disable=protected-access
        mixture_dist_param = (self.mixture_distribution._probs
                              if self.mixture_distribution._logits is None else
                              self.mixture_distribution._logits)
        km = tf.compat.dimension_value(
            tensorshape_util.with_rank_at_least(mixture_dist_param.shape,
                                                1)[-1])
        kc = tf.compat.dimension_value(
            tensorshape_util.with_rank_at_least(
                self.components_distribution.batch_shape, 1)[-1])
        component_bst = None
        if km is not None and kc is not None:
            if km != kc:
                raise ValueError(
                    '`mixture_distribution` components ({}) does not '
                    'equal `components_distribution.batch_shape[-1]` '
                    '({})'.format(km, kc))
        elif self.validate_args:
            if km is None:
                mixture_dist_param = tf.convert_to_tensor(mixture_dist_param)
                km = tf.shape(mixture_dist_param)[-1]
            if kc is None:
                component_bst = self.components_distribution.batch_shape_tensor(
                )
                kc = component_bst[-1]
            assertions += [
                assert_util.assert_equal(
                    km,
                    kc,
                    message=(
                        '`mixture_distribution` components does not equal '
                        '`components_distribution.batch_shape[-1]`')),
            ]

        mdbs = self.mixture_distribution.batch_shape
        cdbs = tensorshape_util.with_rank_at_least(
            self.components_distribution.batch_shape, 1)[:-1]
        if (tensorshape_util.is_fully_defined(mdbs)
                and tensorshape_util.is_fully_defined(cdbs)):
            if tensorshape_util.rank(mdbs) != 0 and mdbs != cdbs:
                raise ValueError(
                    '`mixture_distribution.batch_shape` (`{}`) is not '
                    'compatible with `components_distribution.batch_shape` '
                    '(`{}`)'.format(tensorshape_util.as_list(mdbs),
                                    tensorshape_util.as_list(cdbs)))
        elif self.validate_args:
            if not tensorshape_util.is_fully_defined(mdbs):
                mixture_dist_param = tf.convert_to_tensor(mixture_dist_param)
                mdbs = tf.shape(mixture_dist_param)[:-1]
            if not tensorshape_util.is_fully_defined(cdbs):
                if component_bst is None:
                    component_bst = self.components_distribution.batch_shape_tensor(
                    )
                cdbs = component_bst[:-1]
            assertions += [
                assert_util.assert_equal(
                    distribution_utils.pick_vector(
                        tf.equal(tf.shape(mdbs)[0], 0), cdbs, mdbs),
                    cdbs,
                    message=(
                        '`mixture_distribution.batch_shape` is not '
                        'compatible with `components_distribution.batch_shape`'
                    ))
            ]

        return assertions
Example 26
    def __init__(self,
                 mixture_distribution,
                 components_distribution,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="MixtureSameFamily"):
        """Construct a `MixtureSameFamily` distribution.

    Args:
      mixture_distribution: `tfp.distributions.Categorical`-like instance.
        Manages the probability of selecting components. The number of
        categories must match the rightmost batch dimension of the
        `components_distribution`. Must have either scalar `batch_shape` or
        `batch_shape` matching `components_distribution.batch_shape[:-1]`.
      components_distribution: `tfp.distributions.Distribution`-like instance.
        Right-most batch dimension indexes components.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: `if not mixture_distribution.dtype.is_integer`.
      ValueError: if mixture_distribution does not have scalar `event_shape`.
      ValueError: if `mixture_distribution.batch_shape` and
        `components_distribution.batch_shape[:-1]` are both fully defined and
        the former is neither scalar nor equal to the latter.
      ValueError: if `mixture_distribution` categories does not equal
        `components_distribution` rightmost batch shape.
    """
        parameters = dict(locals())
        with tf.name_scope(name) as name:
            self._mixture_distribution = mixture_distribution
            self._components_distribution = components_distribution
            self._runtime_assertions = []

            s = components_distribution.event_shape_tensor()
            self._event_ndims = (s.shape[0].value
                                 if s.shape.with_rank_at_least(1)[0].value
                                 is not None else tf.shape(s)[0])

            if not mixture_distribution.dtype.is_integer:
                raise ValueError(
                    "`mixture_distribution.dtype` ({}) is not over integers".
                    format(mixture_distribution.dtype.name))

            if (mixture_distribution.event_shape.ndims is not None
                    and mixture_distribution.event_shape.ndims != 0):
                raise ValueError(
                    "`mixture_distribution` must have scalar `event_dim`s")
            elif validate_args:
                self._runtime_assertions += [
                    control_flow_ops.assert_has_rank(
                        mixture_distribution.event_shape_tensor(),
                        0,
                        message=
                        "`mixture_distribution` must have scalar `event_dim`s"
                    ),
                ]

            mdbs = mixture_distribution.batch_shape
            cdbs = components_distribution.batch_shape.with_rank_at_least(
                1)[:-1]
            if mdbs.is_fully_defined() and cdbs.is_fully_defined():
                if mdbs.ndims != 0 and mdbs != cdbs:
                    raise ValueError(
                        "`mixture_distribution.batch_shape` (`{}`) is not "
                        "compatible with `components_distribution.batch_shape` "
                        "(`{}`)".format(mdbs.as_list(), cdbs.as_list()))
            elif validate_args:
                mdbs = mixture_distribution.batch_shape_tensor()
                cdbs = components_distribution.batch_shape_tensor()[:-1]
                self._runtime_assertions += [
                    control_flow_ops.assert_equal(
                        distribution_utils.pick_vector(
                            mixture_distribution.is_scalar_batch(), cdbs,
                            mdbs),
                        cdbs,
                        message=
                        ("`mixture_distribution.batch_shape` is not "
                         "compatible with `components_distribution.batch_shape`"
                         ))
                ]

            km = mixture_distribution.logits.shape.with_rank_at_least(
                1)[-1].value
            kc = components_distribution.batch_shape.with_rank_at_least(
                1)[-1].value
            if km is not None and kc is not None and km != kc:
                raise ValueError(
                    "`mixture_distribution components` ({}) does not "
                    "equal `components_distribution.batch_shape[-1]` "
                    "({})".format(km, kc))
            elif validate_args:
                km = tf.shape(mixture_distribution.logits)[-1]
                kc = components_distribution.batch_shape_tensor()[-1]
                self._runtime_assertions += [
                    control_flow_ops.assert_equal(
                        km,
                        kc,
                        message=(
                            "`mixture_distribution components` does not equal "
                            "`components_distribution.batch_shape[-1:]`")),
                ]
            elif km is None:
                km = tf.shape(mixture_distribution.logits)[-1]

            self._num_components = km

            super(MixtureSameFamily, self).__init__(
                dtype=self._components_distribution.dtype,
                reparameterization_type=reparameterization.NOT_REPARAMETERIZED,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                parameters=parameters,
                graph_parents=(
                    self._mixture_distribution._graph_parents  # pylint: disable=protected-access
                    + self._components_distribution._graph_parents),  # pylint: disable=protected-access
                name=name)
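
# A hedged usage sketch (an addition, not part of the snippet above; assumes
# TensorFlow Probability is available as `tfp`): it illustrates the shape
# contract the constructor validates -- `mixture_distribution.batch_shape`
# must match `components_distribution.batch_shape[:-1]`, and the number of
# mixture categories must equal `components_distribution.batch_shape[-1]`.
import tensorflow_probability as tfp
tfd = tfp.distributions

gm = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(probs=[0.3, 0.7]),  # 2 components.
    components_distribution=tfd.Normal(loc=[-1., 1.],   # batch_shape [2]; the
                                       scale=[0.1, 0.5]))  # last axis indexes
                                                           # the components.
# gm.batch_shape == [] (scalar batch): the trailing component dimension of the
# Normal's batch shape is consumed by the mixture.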
  def _forward_log_det_jacobian(self, x):
    # Let Y be a symmetric, positive definite matrix and write:
    #   Y = X X.T
    # where X is lower-triangular.
    #
    # Observe that,
    #   dY[i,j]/dX[a,b]
    #   = d/dX[a,b] { X[i,:] X[j,:] }
    #   = sum_{d=1}^p { I[i=a] I[d=b] X[j,d] + I[j=a] I[d=b] X[i,d] }
    #
    # To compute the Jacobian dX/dY we must represent X,Y as vectors. Since Y is
    # symmetric and X is lower-triangular, we need vectors of dimension:
    #   d = p (p + 1) / 2
    # where X, Y are p x p matrices, p > 0. We use a row-major mapping, i.e.,
    #   k = { i (i + 1) / 2 + j   i>=j
    #       { undef               i<j
    # and assume zero-based indexes. When k is undef, the element is dropped.
    # Example:
    #           j      k
    #        0 1 2 3  /
    #    0 [ 0 . . . ]
    # i  1 [ 1 2 . . ]
    #    2 [ 3 4 5 . ]
    #    3 [ 6 7 8 9 ]
    # Write vec[.] to indicate transforming a matrix to vector via k(i,j). (With
    # slight abuse: k(i,j)=undef means the element is dropped.)
    #
    # We now show d vec[Y] / d vec[X] is lower triangular. Assuming both are
    # defined, observe that k(i,j) < k(a,b) iff (1) i<a or (2) i=a and j<b.
    # In both cases dvec[Y]/dvec[X]@[k(i,j),k(a,b)] = 0 since:
    # (1) j<=i<a thus i,j!=a.
    # (2) i=a>j  thus i,j!=a.
    #
    # Since the Jacobian is lower-triangular, we need only compute the product
    # of diagonal elements:
    #   d vec[Y] / d vec[X] @[k(i,j), k(i,j)]
    #   = X[j,j] + I[i=j] X[i,j]
    #   = 2 X[j,j].
    # Since there is a 2 X[j,j] term for every lower-triangular element of X we
    # conclude:
    #   |Jac(d vec[Y]/d vec[X])| = 2^p prod_{j=0}^{p-1} X[j,j]^{p-j}.
    diag = tf.matrix_diag_part(x)

    # We now ensure diag is columnar. Eg, if `diag = [1, 2, 3]` then the output
    # is `[[1], [2], [3]]` and if `diag = [[1, 2, 3], [4, 5, 6]]` then the
    # output is unchanged.
    diag = self._make_columnar(diag)

    if self.validate_args:
      is_matrix = tf.assert_rank_at_least(
          x, 2, message="Input must be a (batch of) matrix.")
      shape = tf.shape(x)
      is_square = tf.assert_equal(
          shape[-2],
          shape[-1],
          message="Input must be a (batch of) square matrix.")
      # Assuming lower-triangular means we only need check diag>0.
      is_positive_definite = tf.assert_positive(
          diag, message="Input must be positive definite.")
      x = control_flow_ops.with_dependencies(
          [is_matrix, is_square, is_positive_definite], x)

    # Create a vector equal to: [p, p-1, ..., 2, 1].
    if x.shape.ndims is None or x.shape[-1].value is None:
      p_int = tf.shape(x)[-1]
      p_float = tf.cast(p_int, dtype=x.dtype)
    else:
      p_int = x.shape[-1].value
      p_float = np.array(p_int, dtype=x.dtype.as_numpy_dtype)
    exponents = tf.linspace(p_float, 1., p_int)

    sum_weighted_log_diag = tf.squeeze(
        tf.matmul(tf.log(diag), exponents[..., tf.newaxis]), axis=-1)
    fldj = p_float * np.log(2.) + sum_weighted_log_diag

    # We finally need to undo adding an extra column in non-scalar cases
    # where there is a single matrix as input.
    if x.shape.ndims is not None:
      if x.shape.ndims == 2:
        fldj = tf.squeeze(fldj, axis=-1)
      return fldj

    shape = tf.shape(fldj)
    maybe_squeeze_shape = tf.concat([
        shape[:-1],
        distribution_util.pick_vector(
            tf.equal(tf.rank(x), 2),
            np.array([], dtype=np.int32), shape[-1:])], 0)
    return tf.reshape(fldj, maybe_squeeze_shape)
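
# A minimal NumPy check (an added sketch, not library code): it verifies the
# closed form derived above, log|Jac| = p*log(2) + sum_j (p - j)*log(X[j, j]),
# against a finite-difference Jacobian of vec[Y] = vec[X @ X.T] taken over the
# lower-triangular entries, using the same row-major vec[.] ordering.
import numpy as np

def fldj_closed_form(x):
  """Closed-form log|Jacobian| for Y = X @ X.T with X lower-triangular."""
  p = x.shape[-1]
  exponents = np.arange(p, 0, -1)  # [p, p-1, ..., 1], as in the tf.linspace.
  return p * np.log(2.) + np.sum(exponents * np.log(np.diag(x)))

def fldj_numerical(x, eps=1e-6):
  """Finite-difference log|Jacobian| of vec[Y] with respect to vec[X]."""
  p = x.shape[-1]
  rows, cols = np.tril_indices(p)  # Row-major k(i, j) ordering.

  def vec_y(x_flat):
    x_mat = np.zeros((p, p))
    x_mat[rows, cols] = x_flat
    y = x_mat @ x_mat.T
    return y[rows, cols]

  x_flat = x[rows, cols]
  d = x_flat.size
  jac = np.zeros((d, d))
  for k in range(d):
    dx = np.zeros(d)
    dx[k] = eps
    jac[:, k] = (vec_y(x_flat + dx) - vec_y(x_flat - dx)) / (2. * eps)
  return np.log(np.abs(np.linalg.det(jac)))

x = np.linalg.cholesky([[4., 2.], [2., 3.]])
print(fldj_closed_form(x), fldj_numerical(x))  # The two values should agree.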
Example 29
    def __init__(self,
                 df,
                 loc=None,
                 scale_identity_multiplier=None,
                 scale_diag=None,
                 scale_tril=None,
                 scale_perturb_factor=None,
                 scale_perturb_diag=None,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="VectorStudentT"):
        """Instantiates the vector Student's t-distributions on `R^k`.

    The `batch_shape` is the broadcast between `df.batch_shape` and
    `Affine.batch_shape` where `Affine` is constructed from `loc` and
    `scale_*` arguments.

    The `event_shape` is given by `Affine.event_shape`.

    Args:
      df: Floating-point `Tensor`. The degrees of freedom of the
        distribution(s). `df` must contain only positive values. Must be
        scalar if `loc`, `scale_*` imply non-scalar batch_shape or must have the
        same `batch_shape` implied by `loc`, `scale_*`.
      loc: Floating-point `Tensor`. If this is set to `None`, no `loc` is
        applied.
      scale_identity_multiplier: Floating-point rank-0 `Tensor` representing a
        scaling done to the identity matrix. When `scale_identity_multiplier =
        scale_diag = scale_tril = None` then `scale += IdentityMatrix`.
        Otherwise no scaled-identity-matrix is added to `scale`.
      scale_diag: Floating-point `Tensor` representing the diagonal matrix.
        `scale_diag` has shape [N1, N2, ..., k], which represents a k x k
        diagonal matrix. When `None` no diagonal term is added to `scale`.
      scale_tril: Floating-point `Tensor` representing the lower triangular
        matrix. `scale_tril` has shape [N1, N2, ..., k, k], which represents a
        k x k lower triangular matrix. When `None` no `scale_tril` term is
        added to `scale`. The upper triangular elements above the diagonal are
        ignored.
      scale_perturb_factor: Floating-point `Tensor` representing factor matrix
        with last two dimensions of shape `(k, r)`. When `None`, no rank-r
        update is added to `scale`.
      scale_perturb_diag: Floating-point `Tensor` representing the diagonal
        matrix. `scale_perturb_diag` has shape [N1, N2, ..., r], which
        represents an r x r diagonal matrix. When `None` low rank updates will
        take the form `scale_perturb_factor * scale_perturb_factor.T`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
        parameters = dict(locals())
        graph_parents = [
            df, loc, scale_identity_multiplier, scale_diag, scale_tril,
            scale_perturb_factor, scale_perturb_diag
        ]
        with tf.name_scope(name) as name:
            with tf.name_scope("init"):
                dtype = dtype_util.common_dtype(graph_parents, tf.float32)
                df = tf.convert_to_tensor(value=df, name="df", dtype=dtype)
                # The shape of the _VectorStudentT distribution is governed by the
                # relationship between df.batch_shape and affine.batch_shape. In
                # pseudocode the basic procedure is:
                #   if df.batch_shape is scalar:
                #     if affine.batch_shape is not scalar:
                #       # broadcast distribution.sample so
                #       # it has affine.batch_shape.
                #     self.batch_shape = affine.batch_shape
                #   else:
                #     if affine.batch_shape is scalar:
                #       # let affine broadcasting do its thing.
                #     self.batch_shape = df.batch_shape
                # All of the above magic is actually handled by TransformedDistribution.
                # Here we really only need to collect the affine.batch_shape and decide
                # what we're going to pass in to TransformedDistribution's
                # (override) batch_shape arg.
                affine = affine_bijector.Affine(
                    shift=loc,
                    scale_identity_multiplier=scale_identity_multiplier,
                    scale_diag=scale_diag,
                    scale_tril=scale_tril,
                    scale_perturb_factor=scale_perturb_factor,
                    scale_perturb_diag=scale_perturb_diag,
                    validate_args=validate_args,
                    dtype=dtype)
                distribution = student_t.StudentT(
                    df=df,
                    loc=tf.zeros([], dtype=affine.dtype),
                    scale=tf.ones([], dtype=affine.dtype))
                batch_shape, override_event_shape = (
                    distribution_util.shapes_from_loc_and_scale(
                        affine.shift, affine.scale))
                override_batch_shape = distribution_util.pick_vector(
                    distribution.is_scalar_batch(), batch_shape,
                    tf.constant([], dtype=tf.int32))
                super(_VectorStudentT,
                      self).__init__(distribution=distribution,
                                     bijector=affine,
                                     batch_shape=override_batch_shape,
                                     event_shape=override_event_shape,
                                     validate_args=validate_args,
                                     name=name)
                self._parameters = parameters
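
# A hedged usage sketch (added; assumes the `_VectorStudentT` class above in a
# TF 1.x-style environment): a scalar `df` broadcasts against the batch shape
# implied by the affine parameters, here [5], while the event shape [2] comes
# from the trailing dimension of `loc` / `scale_diag`.
vt = _VectorStudentT(df=3.,
                     loc=tf.zeros([5, 2]),
                     scale_diag=tf.ones([5, 2]))
# Expected: vt.batch_shape == [5], vt.event_shape == [2].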
 def _event_shape_tensor(self):
   return self.bijector.forward_event_shape_tensor(
       distribution_util.pick_vector(
           self._is_event_override,
           self._override_event_shape,
           self.distribution.event_shape_tensor()))
 def _batch_shape_tensor(self):
   return distribution_util.pick_vector(
       self._is_batch_override,
       self._override_batch_shape,
       self.distribution.batch_shape_tensor())