Example No. 1
  def __init__(self,
               scale,
               validate_args=False,
               allow_nan_stats=True,
               name="HalfNormal"):
    """Construct HalfNormals with scale `scale`.

    Args:
      scale: Floating point tensor; the scales of the distribution(s).
        Must contain only positive values.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    parameters = distribution_util.parent_frame_arguments()
    with ops.name_scope(name, values=[scale]) as name:
      with ops.control_dependencies([check_ops.assert_positive(scale)] if
                                    validate_args else []):
        self._scale = array_ops.identity(scale, name="scale")
    super(HalfNormal, self).__init__(
        dtype=self._scale.dtype,
        reparameterization_type=distribution.FULLY_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[self._scale],
        name=name)
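A minimal usage sketch, assuming the TensorFlow Probability counterpart `tfp.distributions.HalfNormal` keeps the constructor signature shown above:

# Hypothetical usage via TensorFlow Probability, which mirrors this constructor.
import tensorflow_probability as tfp

half_normal = tfp.distributions.HalfNormal(scale=[1.0, 2.0], validate_args=True)
samples = half_normal.sample(5)            # shape [5, 2]
log_probs = half_normal.log_prob(samples)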
Example No. 2
  def __init__(
      self,
      logits=None,
      probs=None,
      dtype=dtypes.int32,
      validate_args=False,
      allow_nan_stats=True,
      name="OneHotCategorical"):
    """Initialize OneHotCategorical distributions using class log-probabilities.

    Args:
      logits: An N-D `Tensor`, `N >= 1`, representing the log probabilities of a
        set of Categorical distributions. The first `N - 1` dimensions index
        into a batch of independent distributions and the last dimension
        represents a vector of logits for each class. Only one of `logits` or
        `probs` should be passed in.
      probs: An N-D `Tensor`, `N >= 1`, representing the probabilities of a set
        of Categorical distributions. The first `N - 1` dimensions index into a
        batch of independent distributions and the last dimension represents a
        vector of probabilities for each class. Only one of `logits` or `probs`
        should be passed in.
      dtype: The type of the event samples (default: int32).
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    parameters = distribution_util.parent_frame_arguments()
    with ops.name_scope(name, values=[logits, probs]) as name:
      self._logits, self._probs = distribution_util.get_logits_and_probs(
          name=name, logits=logits, probs=probs, validate_args=validate_args,
          multidimensional=True)

      logits_shape_static = self._logits.get_shape().with_rank_at_least(1)
      if logits_shape_static.ndims is not None:
        self._batch_rank = ops.convert_to_tensor(
            logits_shape_static.ndims - 1,
            dtype=dtypes.int32,
            name="batch_rank")
      else:
        with ops.name_scope(name="batch_rank"):
          self._batch_rank = array_ops.rank(self._logits) - 1

      with ops.name_scope(name="event_size"):
        self._event_size = array_ops.shape(self._logits)[-1]

    super(OneHotCategorical, self).__init__(
        dtype=dtype,
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[self._logits,
                       self._probs],
        name=name)
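A minimal usage sketch, assuming `tfp.distributions.OneHotCategorical` mirrors the constructor above (the last logits dimension indexes the classes):

# Hypothetical usage; one batch member with three classes.
import tensorflow_probability as tfp

ohc = tfp.distributions.OneHotCategorical(logits=[[-1.0, 0.0, 2.0]])
draw = ohc.sample()                 # one-hot int32 vector, shape [1, 3]
p = ohc.prob([[0.0, 0.0, 1.0]])     # probability that class 2 is drawn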
Example No. 3
    def __init__(self,
                 scale,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="HalfNormal"):
        """Construct HalfNormals with scale `scale`.

    Args:
      scale: Floating point tensor; the scales of the distribution(s).
        Must contain only positive values.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
        parameters = distribution_util.parent_frame_arguments()
        with ops.name_scope(name, values=[scale]) as name:
            with ops.control_dependencies(
                [check_ops.assert_positive(scale)] if validate_args else []):
                self._scale = array_ops.identity(scale, name="scale")
        super(HalfNormal, self).__init__(
            dtype=self._scale.dtype,
            reparameterization_type=distribution.FULLY_REPARAMETERIZED,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            parameters=parameters,
            graph_parents=[self._scale],
            name=name)
Example No. 4
  def __init__(self,
               distribution_fn,
               sample0=None,
               num_steps=None,
               validate_args=False,
               allow_nan_stats=True,
               name="Autoregressive"):
    """Construct an `Autoregressive` distribution.

    Args:
      distribution_fn: Python `callable` which constructs a
        `tf.distributions.Distribution`-like instance from a `Tensor` (e.g.,
        `sample0`). The function must respect the "autoregressive property",
        i.e., there exists a permutation of the event coordinates such that
        each coordinate is a diffeomorphic function of only preceding
        coordinates.
      sample0: Initial input to `distribution_fn`; used to
        build the distribution in `__init__` which in turn specifies this
        distribution's properties, e.g., `event_shape`, `batch_shape`, `dtype`.
        If unspecified, then `distribution_fn` should be default constructible.
      num_steps: Number of times `distribution_fn` is composed from samples,
        e.g., `num_steps=2` implies
        `distribution_fn(distribution_fn(sample0).sample(n)).sample()`.
      validate_args: Python `bool`.  Whether to validate input with asserts.
        If `validate_args` is `False`, and the inputs are invalid,
        correct behavior is not guaranteed.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
        Default value: "Autoregressive".

    Raises:
      ValueError: if `num_steps` and
        `distribution_fn(sample0).event_shape.num_elements()` are both `None`.
      ValueError: if `num_steps < 1`.
    """
    parameters = distribution_util.parent_frame_arguments()
    with ops.name_scope(name) as name:
      self._distribution_fn = distribution_fn
      self._sample0 = sample0
      self._distribution0 = (distribution_fn() if sample0 is None
                             else distribution_fn(sample0))
      if num_steps is None:
        num_steps = self._distribution0.event_shape.num_elements()
        if num_steps is None:
          raise ValueError("distribution_fn must generate a distribution "
                           "with fully known `event_shape`.")
      if num_steps < 1:
        raise ValueError("num_steps ({}) must be at least 1.".format(num_steps))
      self._num_steps = num_steps
    super(Autoregressive, self).__init__(
        dtype=self._distribution0.dtype,
        reparameterization_type=self._distribution0.reparameterization_type,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=self._distribution0._graph_parents,  # pylint: disable=protected-access
        name=name)
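A minimal sketch of the call pattern, assuming `tfp.distributions.Autoregressive`; the hypothetical `toy_fn` below only illustrates the `num_steps` composition described above; it does not enforce the autoregressive property:

# Hypothetical sketch; toy_fn is an illustrative stand-in, not a real
# autoregressive model.
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

def toy_fn(x):
  # Builds a fully factorized Normal conditioned on the previous sample.
  return tfd.Independent(tfd.Normal(loc=0.1 * x, scale=1.0),
                         reinterpreted_batch_ndims=1)

ar = tfd.Autoregressive(toy_fn, sample0=tf.zeros([3]), num_steps=3)
x = ar.sample()   # composes toy_fn num_steps times, as documented above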
Example No. 5
  def __init__(self,
               total_count,
               logits=None,
               probs=None,
               validate_args=False,
               allow_nan_stats=True,
               name="Multinomial"):
    """Initialize a batch of Multinomial distributions.

    Args:
      total_count: Non-negative floating point tensor with shape broadcastable
        to `[N1,..., Nm]` with `m >= 0`. Defines this as a batch of
        `N1 x ... x Nm` different Multinomial distributions. Its components
        should be equal to integer values.
      logits: Floating point tensor representing unnormalized log-probabilities
        of a positive event with shape broadcastable to
        `[N1,..., Nm, K]` `m >= 0`, and the same dtype as `total_count`. Defines
        this as a batch of `N1 x ... x Nm` different `K` class Multinomial
        distributions. Only one of `logits` or `probs` should be passed in.
      probs: Positive floating point tensor with shape broadcastable to
        `[N1,..., Nm, K]` `m >= 0` and same dtype as `total_count`. Defines
        this as a batch of `N1 x ... x Nm` different `K` class Multinomial
        distributions. `probs`'s components in the last portion of its shape
        should sum to `1`. Only one of `logits` or `probs` should be passed in.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    parameters = distribution_util.parent_frame_arguments()
    with ops.name_scope(name, values=[total_count, logits, probs]) as name:
      self._total_count = ops.convert_to_tensor(total_count, name="total_count")
      if validate_args:
        self._total_count = (
            distribution_util.embed_check_nonnegative_integer_form(
                self._total_count))
      self._logits, self._probs = distribution_util.get_logits_and_probs(
          logits=logits,
          probs=probs,
          multidimensional=True,
          validate_args=validate_args,
          name=name)
      self._mean_val = self._total_count[..., array_ops.newaxis] * self._probs
    super(Multinomial, self).__init__(
        dtype=self._probs.dtype,
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[self._total_count,
                       self._logits,
                       self._probs],
        name=name)
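A minimal usage sketch, assuming `tfp.distributions.Multinomial` mirrors this constructor (4 trials over 3 classes):

# Hypothetical usage; counts passed to prob() must sum to total_count.
import tensorflow_probability as tfp

multinomial = tfp.distributions.Multinomial(total_count=4.0,
                                            probs=[0.2, 0.3, 0.5])
counts = multinomial.sample(2)          # two draws, each a length-3 count vector
p = multinomial.prob([2.0, 1.0, 1.0])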
Example No. 6
    def __init__(self,
                 total_count,
                 logits=None,
                 probs=None,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="Multinomial"):
        """Initialize a batch of Multinomial distributions.

    Args:
      total_count: Non-negative floating point tensor with shape broadcastable
        to `[N1,..., Nm]` with `m >= 0`. Defines this as a batch of
        `N1 x ... x Nm` different Multinomial distributions. Its components
        should be equal to integer values.
      logits: Floating point tensor representing unnormalized log-probabilities
        of a positive event with shape broadcastable to
        `[N1,..., Nm, K]` `m >= 0`, and the same dtype as `total_count`. Defines
        this as a batch of `N1 x ... x Nm` different `K` class Multinomial
        distributions. Only one of `logits` or `probs` should be passed in.
      probs: Positive floating point tensor with shape broadcastable to
        `[N1,..., Nm, K]` `m >= 0` and same dtype as `total_count`. Defines
        this as a batch of `N1 x ... x Nm` different `K` class Multinomial
        distributions. `probs`'s components in the last portion of its shape
        should sum to `1`. Only one of `logits` or `probs` should be passed in.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
        parameters = distribution_util.parent_frame_arguments()
        with ops.name_scope(name, values=[total_count, logits, probs]) as name:
            self._total_count = ops.convert_to_tensor(total_count,
                                                      name="total_count")
            if validate_args:
                self._total_count = (
                    distribution_util.embed_check_nonnegative_integer_form(
                        self._total_count))
            self._logits, self._probs = distribution_util.get_logits_and_probs(
                logits=logits,
                probs=probs,
                multidimensional=True,
                validate_args=validate_args,
                name=name)
            self._mean_val = self._total_count[...,
                                               array_ops.newaxis] * self._probs
        super(Multinomial, self).__init__(
            dtype=self._probs.dtype,
            reparameterization_type=distribution.NOT_REPARAMETERIZED,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            parameters=parameters,
            graph_parents=[self._total_count, self._logits, self._probs],
            name=name)
Example No. 7
    def __init__(self,
                 total_count,
                 concentration,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="DirichletMultinomial"):
        """Initialize a batch of DirichletMultinomial distributions.

    Args:
      total_count:  Non-negative floating point tensor, whose dtype is the same
        as `concentration`. The shape is broadcastable to `[N1,..., Nm]` with
        `m >= 0`. Defines this as a batch of `N1 x ... x Nm` different
        Dirichlet multinomial distributions. Its components should be equal to
        integer values.
      concentration: Positive floating point tensor, whose dtype is the same
        as `total_count`, with shape broadcastable to `[N1,..., Nm, K]`, `m >= 0`.
        Defines this as a batch of `N1 x ... x Nm` different `K` class Dirichlet
        multinomial distributions.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
        parameters = distribution_util.parent_frame_arguments()
        with ops.name_scope(name, values=[total_count, concentration]) as name:
            # Broadcasting works because:
            # * The broadcasting convention is to prepend dimensions of size [1], and
            #   we use the last dimension for the distribution, whereas
            #   the batch dimensions are the leading dimensions, which forces the
            #   distribution dimension to be defined explicitly (i.e. it cannot be
            #   created automatically by prepending). This forces enough explicitness.
            # * All calls involving `counts` eventually require a broadcast between
            #   `counts` and `concentration`.
            self._total_count = ops.convert_to_tensor(total_count,
                                                      name="total_count")
            if validate_args:
                self._total_count = (
                    distribution_util.embed_check_nonnegative_integer_form(
                        self._total_count))
            self._concentration = self._maybe_assert_valid_concentration(
                ops.convert_to_tensor(concentration, name="concentration"),
                validate_args)
            self._total_concentration = math_ops.reduce_sum(
                self._concentration, -1)
        super(DirichletMultinomial, self).__init__(
            dtype=self._concentration.dtype,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            reparameterization_type=distribution.NOT_REPARAMETERIZED,
            parameters=parameters,
            graph_parents=[self._total_count, self._concentration],
            name=name)
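A minimal usage sketch, assuming `tfp.distributions.DirichletMultinomial` mirrors this constructor:

# Hypothetical usage; counts passed to prob() must sum to total_count.
import tensorflow_probability as tfp

dm = tfp.distributions.DirichletMultinomial(total_count=2.0,
                                            concentration=[1.0, 2.0, 3.0])
p = dm.prob([1.0, 1.0, 0.0])     # one draw in class 0 and one in class 1
mean_counts = dm.mean()          # total_count * concentration / total_concentration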
Example No. 8
    def __init__(self,
                 total_count,
                 logits=None,
                 probs=None,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="NegativeBinomial"):
        """Construct NegativeBinomial distributions.

    Args:
      total_count: Non-negative floating-point `Tensor` with shape
        broadcastable to `[B1,..., Bb]` with `b >= 0` and the same dtype as
        `probs` or `logits`. Defines this as a batch of `B1 x ... x Bb`
        different Negative Binomial distributions. In practice, this represents
        the number of negative Bernoulli trials to stop at (the `total_count`
        of failures), but this is still a valid distribution when
        `total_count` is a non-integer.
      logits: Floating-point `Tensor` with shape broadcastable to
        `[B1, ..., Bb]` where `b >= 0` indicates the number of batch dimensions.
        Each entry represents logits for the probability of success for
        independent Negative Binomial distributions and must be in the open
        interval `(-inf, inf)`. Only one of `logits` or `probs` should be
        specified.
      probs: Positive floating-point `Tensor` with shape broadcastable to
        `[B1, ..., Bb]` where `b >= 0` indicates the number of batch dimensions.
        Each entry represents the probability of success for independent
        Negative Binomial distributions and must be in the open interval
        `(0, 1)`. Only one of `logits` or `probs` should be specified.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """

        parameters = distribution_util.parent_frame_arguments()
        with ops.name_scope(name, values=[total_count, logits, probs]) as name:
            self._logits, self._probs = distribution_util.get_logits_and_probs(
                logits, probs, validate_args=validate_args, name=name)
            with ops.control_dependencies(
                [check_ops.assert_positive(total_count)]
                if validate_args else []):
                self._total_count = array_ops.identity(total_count)

        super(NegativeBinomial, self).__init__(
            dtype=self._probs.dtype,
            reparameterization_type=distribution.NOT_REPARAMETERIZED,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            parameters=parameters,
            graph_parents=[self._total_count, self._probs, self._logits],
            name=name)
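A minimal usage sketch, assuming `tfp.distributions.NegativeBinomial` mirrors this constructor (number of successes seen before `total_count` failures):

# Hypothetical usage.
import tensorflow_probability as tfp

nb = tfp.distributions.NegativeBinomial(total_count=5.0, probs=0.4)
mean = nb.mean()        # total_count * probs / (1 - probs)
samples = nb.sample(3)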
Example No. 9
  def __init__(self,
               rate=None,
               log_rate=None,
               validate_args=False,
               allow_nan_stats=True,
               name="Poisson"):
    """Initialize a batch of Poisson distributions.

    Args:
      rate: Floating point tensor, the rate parameter. `rate` must be positive.
        Must specify exactly one of `rate` and `log_rate`.
      log_rate: Floating point tensor, the log of the rate parameter.
        Must specify exactly one of `rate` and `log_rate`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: if none or both of `rate`, `log_rate` are specified.
      TypeError: if `rate` is not a float-type.
      TypeError: if `log_rate` is not a float-type.
    """
    parameters = distribution_util.parent_frame_arguments()
    with ops.name_scope(name, values=[rate]) as name:
      if (rate is None) == (log_rate is None):
        raise ValueError("Must specify exactly one of `rate` and `log_rate`.")
      elif log_rate is None:
        rate = ops.convert_to_tensor(rate, name="rate")
        if not rate.dtype.is_floating:
          raise TypeError("rate.dtype ({}) is a not a float-type.".format(
              rate.dtype.name))
        with ops.control_dependencies([check_ops.assert_positive(rate)] if
                                      validate_args else []):
          self._rate = array_ops.identity(rate, name="rate")
          self._log_rate = math_ops.log(rate, name="log_rate")
      else:
        log_rate = ops.convert_to_tensor(log_rate, name="log_rate")
        if not log_rate.dtype.is_floating:
          raise TypeError("log_rate.dtype ({}) is a not a float-type.".format(
              log_rate.dtype.name))
        self._rate = math_ops.exp(log_rate, name="rate")
        self._log_rate = ops.convert_to_tensor(log_rate, name="log_rate")
    super(Poisson, self).__init__(
        dtype=self._rate.dtype,
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[self._rate],
        name=name)
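A minimal usage sketch, assuming `tfp.distributions.Poisson` mirrors this constructor; exactly one of `rate` / `log_rate` may be given:

# Hypothetical usage; the two constructions below are equivalent.
import tensorflow as tf
import tensorflow_probability as tfp

poisson = tfp.distributions.Poisson(rate=3.0)
same_poisson = tfp.distributions.Poisson(log_rate=tf.math.log(3.0))
p = poisson.prob(2.0)    # exp(-3) * 3**2 / 2!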
Example No. 10
    def __init__(self,
                 distribution,
                 batch_shape,
                 validate_args=False,
                 allow_nan_stats=True,
                 name=None):
        """Construct BatchReshape distribution.

    Args:
      distribution: The base distribution instance to reshape. Typically an
        instance of `Distribution`.
      batch_shape: Positive `int`-like vector-shaped `Tensor` representing
        the new shape of the batch dimensions. Up to one dimension may contain
        `-1`, meaning the remainder of the batch size.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: The name to give Ops created by the initializer.
        Default value: `"BatchReshape" + distribution.name`.

    Raises:
      ValueError: if `batch_shape` is not a vector.
      ValueError: if `batch_shape` has non-positive elements.
      ValueError: if the size of `batch_shape` does not match the size of
        `distribution.batch_shape`.
    """
        parameters = distribution_util.parent_frame_arguments()
        name = name or "BatchReshape" + distribution.name
        with ops.name_scope(name, values=[batch_shape]) as name:
            # The unexpanded batch shape may contain up to one dimension of -1.
            self._batch_shape_unexpanded = ops.convert_to_tensor(
                batch_shape, dtype=dtypes.int32, name="batch_shape")
            validate_init_args_statically(distribution,
                                          self._batch_shape_unexpanded)
            batch_shape, batch_shape_static, runtime_assertions = calculate_reshape(
                distribution.batch_shape_tensor(),
                self._batch_shape_unexpanded, validate_args)
            self._distribution = distribution
            self._batch_shape_ = batch_shape
            self._batch_shape_static = batch_shape_static
            self._runtime_assertions = runtime_assertions
            super(BatchReshape, self).__init__(
                dtype=distribution.dtype,
                reparameterization_type=distribution.reparameterization_type,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                parameters=parameters,
                graph_parents=([self._batch_shape_unexpanded] +
                               distribution._graph_parents),  # pylint: disable=protected-access
                name=name)
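A minimal usage sketch, assuming `tfp.distributions.BatchReshape` mirrors this constructor; one batch dimension may be `-1`:

# Hypothetical usage; reshape a flat batch of 6 Poissons into a [2, 3] batch.
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

base = tfd.Poisson(rate=tf.ones([6]))                  # batch_shape [6]
reshaped = tfd.BatchReshape(base, batch_shape=[2, -1])
print(reshaped.batch_shape_tensor())                   # [2, 3]; the -1 resolves to 3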
Example No. 11
  def __init__(self,
               total_count,
               logits=None,
               probs=None,
               validate_args=False,
               allow_nan_stats=True,
               name="NegativeBinomial"):
    """Construct NegativeBinomial distributions.

    Args:
      total_count: Non-negative floating-point `Tensor` with shape
        broadcastable to `[B1,..., Bb]` with `b >= 0` and the same dtype as
        `probs` or `logits`. Defines this as a batch of `B1 x ... x Bb`
        different Negative Binomial distributions. In practice, this represents
        the number of negative Bernoulli trials to stop at (the `total_count`
        of failures), but this is still a valid distribution when
        `total_count` is a non-integer.
      logits: Floating-point `Tensor` with shape broadcastable to
        `[B1, ..., Bb]` where `b >= 0` indicates the number of batch dimensions.
        Each entry represents logits for the probability of success for
        independent Negative Binomial distributions and must be in the open
        interval `(-inf, inf)`. Only one of `logits` or `probs` should be
        specified.
      probs: Positive floating-point `Tensor` with shape broadcastable to
        `[B1, ..., Bb]` where `b >= 0` indicates the number of batch dimensions.
        Each entry represents the probability of success for independent
        Negative Binomial distributions and must be in the open interval
        `(0, 1)`. Only one of `logits` or `probs` should be specified.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """

    parameters = distribution_util.parent_frame_arguments()
    with ops.name_scope(name, values=[total_count, logits, probs]) as name:
      self._logits, self._probs = distribution_util.get_logits_and_probs(
          logits, probs, validate_args=validate_args, name=name)
      with ops.control_dependencies(
          [check_ops.assert_positive(total_count)] if validate_args else []):
        self._total_count = array_ops.identity(total_count)

    super(NegativeBinomial, self).__init__(
        dtype=self._probs.dtype,
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[self._total_count, self._probs, self._logits],
        name=name)
Example No. 12
    def __init__(self,
                 temperature,
                 logits=None,
                 probs=None,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="RelaxedBernoulli"):
        """Construct RelaxedBernoulli distributions.

    Args:
      temperature: A 0-D `Tensor`, representing the temperature
        of a set of RelaxedBernoulli distributions. The temperature should be
        positive.
      logits: An N-D `Tensor` representing the log-odds
        of a positive event. Each entry in the `Tensor` parametrizes
        an independent RelaxedBernoulli distribution where the probability of an
        event is sigmoid(logits). Only one of `logits` or `probs` should be
        passed in.
      probs: An N-D `Tensor` representing the probability of a positive event.
        Each entry in the `Tensor` parameterizes an independent Bernoulli
        distribution. Only one of `logits` or `probs` should be passed in.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: If both `probs` and `logits` are passed, or if neither.
    """
        parameters = distribution_util.parent_frame_arguments()
        with ops.name_scope(name, values=[logits, probs, temperature]) as name:
            with ops.control_dependencies(
                [check_ops.assert_positive(temperature)]
                if validate_args else []):
                self._temperature = array_ops.identity(
                    temperature, name="temperature")
            self._logits, self._probs = distribution_util.get_logits_and_probs(
                logits=logits, probs=probs, validate_args=validate_args)
            super(RelaxedBernoulli, self).__init__(
                distribution=logistic.Logistic(
                    self._logits / self._temperature,
                    1. / self._temperature,
                    validate_args=validate_args,
                    allow_nan_stats=allow_nan_stats,
                    name=name + "/Logistic"),
                bijector=Sigmoid(validate_args=validate_args),
                validate_args=validate_args,
                name=name)
        self._parameters = parameters
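A minimal usage sketch, assuming `tfp.distributions.RelaxedBernoulli` mirrors this constructor; low temperatures push samples toward 0 or 1:

# Hypothetical usage.
import tensorflow_probability as tfp

relaxed = tfp.distributions.RelaxedBernoulli(temperature=0.5, logits=[-2.0, 2.0])
soft_samples = relaxed.sample(4)   # values in the open interval (0, 1)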
Example No. 13
  def __init__(self,
               distribution,
               batch_shape,
               validate_args=False,
               allow_nan_stats=True,
               name=None):
    """Construct BatchReshape distribution.

    Args:
      distribution: The base distribution instance to reshape. Typically an
        instance of `Distribution`.
      batch_shape: Positive `int`-like vector-shaped `Tensor` representing
        the new shape of the batch dimensions. Up to one dimension may contain
        `-1`, meaning the remainder of the batch size.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: The name to give Ops created by the initializer.
        Default value: `"BatchReshape" + distribution.name`.

    Raises:
      ValueError: if `batch_shape` is not a vector.
      ValueError: if `batch_shape` has non-positive elements.
      ValueError: if the size of `batch_shape` does not match the size of
        `distribution.batch_shape`.
    """
    parameters = distribution_util.parent_frame_arguments()
    name = name or "BatchReshape" + distribution.name
    with ops.name_scope(name, values=[batch_shape]) as name:
      # The unexpanded batch shape may contain up to one dimension of -1.
      self._batch_shape_unexpanded = ops.convert_to_tensor(
          batch_shape, dtype=dtypes.int32, name="batch_shape")
      validate_init_args_statically(distribution, self._batch_shape_unexpanded)
      batch_shape, batch_shape_static, runtime_assertions = calculate_reshape(
          distribution.batch_shape_tensor(), self._batch_shape_unexpanded,
          validate_args)
      self._distribution = distribution
      self._batch_shape_ = batch_shape
      self._batch_shape_static = batch_shape_static
      self._runtime_assertions = runtime_assertions
      super(BatchReshape, self).__init__(
          dtype=distribution.dtype,
          reparameterization_type=distribution.reparameterization_type,
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          parameters=parameters,
          graph_parents=(
              [self._batch_shape_unexpanded] + distribution._graph_parents),  # pylint: disable=protected-access
          name=name)
Example No. 14
    def __init__(self,
                 df,
                 loc,
                 scale,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="StudentT"):
        """Construct Student's t distributions.

    The distributions have degrees of freedom `df`, mean `loc`, and scale
    `scale`.

    The parameters `df`, `loc`, and `scale` must be shaped in a way that
    supports broadcasting (e.g. `df + loc + scale` is a valid operation).

    Args:
      df: Floating-point `Tensor`. The degrees of freedom of the
        distribution(s). `df` must contain only positive values.
      loc: Floating-point `Tensor`. The mean(s) of the distribution(s).
      scale: Floating-point `Tensor`. The scaling factor(s) for the
        distribution(s). Note that `scale` is not technically the standard
        deviation of this distribution but has semantics more similar to
        standard deviation than variance.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      TypeError: if `loc` and `scale` have different dtypes.
    """
        parameters = distribution_util.parent_frame_arguments()
        with ops.name_scope(name, values=[df, loc, scale]) as name:
            with ops.control_dependencies(
                [check_ops.assert_positive(df)] if validate_args else []):
                self._df = array_ops.identity(df, name="df")
                self._loc = array_ops.identity(loc, name="loc")
                self._scale = array_ops.identity(scale, name="scale")
                check_ops.assert_same_float_dtype(
                    (self._df, self._loc, self._scale))
        super(StudentT, self).__init__(
            dtype=self._scale.dtype,
            reparameterization_type=distribution.NOT_REPARAMETERIZED,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            parameters=parameters,
            graph_parents=[self._df, self._loc, self._scale],
            name=name)
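A minimal usage sketch, assuming `tfp.distributions.StudentT` mirrors this constructor; `df`, `loc`, and `scale` broadcast against each other:

# Hypothetical usage.
import tensorflow_probability as tfp

student = tfp.distributions.StudentT(df=3.0, loc=[0.0, 1.0], scale=1.0)
samples = student.sample(10)    # shape [10, 2]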
Example No. 15
  def __init__(self,
               df,
               loc,
               scale,
               validate_args=False,
               allow_nan_stats=True,
               name="StudentT"):
    """Construct Student's t distributions.

    The distributions have degrees of freedom `df`, mean `loc`, and scale
    `scale`.

    The parameters `df`, `loc`, and `scale` must be shaped in a way that
    supports broadcasting (e.g. `df + loc + scale` is a valid operation).

    Args:
      df: Floating-point `Tensor`. The degrees of freedom of the
        distribution(s). `df` must contain only positive values.
      loc: Floating-point `Tensor`. The mean(s) of the distribution(s).
      scale: Floating-point `Tensor`. The scaling factor(s) for the
        distribution(s). Note that `scale` is not technically the standard
        deviation of this distribution but has semantics more similar to
        standard deviation than variance.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      TypeError: if `loc` and `scale` have different dtypes.
    """
    parameters = distribution_util.parent_frame_arguments()
    with ops.name_scope(name, values=[df, loc, scale]) as name:
      with ops.control_dependencies([check_ops.assert_positive(df)]
                                    if validate_args else []):
        self._df = array_ops.identity(df, name="df")
        self._loc = array_ops.identity(loc, name="loc")
        self._scale = array_ops.identity(scale, name="scale")
        check_ops.assert_same_float_dtype(
            (self._df, self._loc, self._scale))
    super(StudentT, self).__init__(
        dtype=self._scale.dtype,
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[self._df, self._loc, self._scale],
        name=name)
Example No. 16
  def __init__(self,
               total_count,
               logits=None,
               probs=None,
               validate_args=False,
               allow_nan_stats=True,
               name="Binomial"):
    """Initialize a batch of Binomial distributions.

    Args:
      total_count: Non-negative floating point tensor with shape broadcastable
        to `[N1,..., Nm]` with `m >= 0` and the same dtype as `probs` or
        `logits`. Defines this as a batch of `N1 x ...  x Nm` different Binomial
        distributions. Its components should be equal to integer values.
      logits: Floating point tensor representing the log-odds of a
        positive event with shape broadcastable to `[N1,..., Nm]` `m >= 0`, and
        the same dtype as `total_count`. Each entry represents logits for the
        probability of success for independent Binomial distributions. Only one
        of `logits` or `probs` should be passed in.
      probs: Positive floating point tensor with shape broadcastable to
        `[N1,..., Nm]` `m >= 0`, `probs in [0, 1]`. Each entry represents the
        probability of success for independent Binomial distributions. Only one
        of `logits` or `probs` should be passed in.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    parameters = distribution_util.parent_frame_arguments()
    with ops.name_scope(name, values=[total_count, logits, probs]) as name:
      self._total_count = self._maybe_assert_valid_total_count(
          ops.convert_to_tensor(total_count, name="total_count"),
          validate_args)
      self._logits, self._probs = distribution_util.get_logits_and_probs(
          logits=logits,
          probs=probs,
          validate_args=validate_args,
          name=name)
    super(Binomial, self).__init__(
        dtype=self._probs.dtype,
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[self._total_count,
                       self._logits,
                       self._probs],
        name=name)
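A minimal usage sketch, assuming `tfp.distributions.Binomial` mirrors this constructor (8 trials with success probability 0.3):

# Hypothetical usage.
import tensorflow_probability as tfp

binomial = tfp.distributions.Binomial(total_count=8.0, probs=0.3)
p2 = binomial.prob(2.0)     # P(exactly 2 successes)
mean = binomial.mean()      # total_count * probs = 2.4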
Example No. 17
  def __init__(self,
               temperature,
               logits=None,
               probs=None,
               validate_args=False,
               allow_nan_stats=True,
               name="RelaxedBernoulli"):
    """Construct RelaxedBernoulli distributions.

    Args:
      temperature: A 0-D `Tensor`, representing the temperature
        of a set of RelaxedBernoulli distributions. The temperature should be
        positive.
      logits: An N-D `Tensor` representing the log-odds
        of a positive event. Each entry in the `Tensor` parametrizes
        an independent RelaxedBernoulli distribution where the probability of an
        event is sigmoid(logits). Only one of `logits` or `probs` should be
        passed in.
      probs: An N-D `Tensor` representing the probability of a positive event.
        Each entry in the `Tensor` parameterizes an independent Bernoulli
        distribution. Only one of `logits` or `probs` should be passed in.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: If both `probs` and `logits` are passed, or if neither.
    """
    parameters = distribution_util.parent_frame_arguments()
    with ops.name_scope(name, values=[logits, probs, temperature]) as name:
      with ops.control_dependencies([check_ops.assert_positive(temperature)]
                                    if validate_args else []):
        self._temperature = array_ops.identity(temperature, name="temperature")
      self._logits, self._probs = distribution_util.get_logits_and_probs(
          logits=logits, probs=probs, validate_args=validate_args)
      super(RelaxedBernoulli, self).__init__(
          distribution=logistic.Logistic(
              self._logits / self._temperature,
              1. / self._temperature,
              validate_args=validate_args,
              allow_nan_stats=allow_nan_stats,
              name=name + "/Logistic"),
          bijector=Sigmoid(validate_args=validate_args),
          validate_args=validate_args,
          name=name)
    self._parameters = parameters
Example No. 18
    def __init__(self,
                 distribution,
                 reinterpreted_batch_ndims=None,
                 validate_args=False,
                 name=None):
        """Construct a `Independent` distribution.

    Args:
      distribution: The base distribution instance to transform. Typically an
        instance of `Distribution`.
      reinterpreted_batch_ndims: Scalar, integer number of rightmost batch dims
        which will be regarded as event dims. When `None` all but the first
        batch axis (batch axis 0) will be transferred to event dimensions
        (analogous to `tf.layers.flatten`).
      validate_args: Python `bool`.  Whether to validate input with asserts.
        If `validate_args` is `False`, and the inputs are invalid,
        correct behavior is not guaranteed.
      name: The name for ops managed by the distribution.
        Default value: `"Independent" + distribution.name`.

    Raises:
      ValueError: if `reinterpreted_batch_ndims` exceeds
        `distribution.batch_ndims`
    """
        parameters = distribution_util.parent_frame_arguments()
        name = name or "Independent" + distribution.name
        self._distribution = distribution
        with ops.name_scope(name) as name:
            if reinterpreted_batch_ndims is None:
                reinterpreted_batch_ndims = self._get_default_reinterpreted_batch_ndims(
                    distribution)
            reinterpreted_batch_ndims = ops.convert_to_tensor(
                reinterpreted_batch_ndims,
                dtype=dtypes.int32,
                name="reinterpreted_batch_ndims")
            self._reinterpreted_batch_ndims = reinterpreted_batch_ndims
            self._static_reinterpreted_batch_ndims = tensor_util.constant_value(
                reinterpreted_batch_ndims)
            if self._static_reinterpreted_batch_ndims is not None:
                self._reinterpreted_batch_ndims = self._static_reinterpreted_batch_ndims
            super(Independent, self).__init__(
                dtype=self._distribution.dtype,
                reparameterization_type=self._distribution.
                reparameterization_type,
                validate_args=validate_args,
                allow_nan_stats=self._distribution.allow_nan_stats,
                parameters=parameters,
                graph_parents=([reinterpreted_batch_ndims] +
                               distribution._graph_parents),  # pylint: disable=protected-access
                name=name)
            self._runtime_assertions = self._make_runtime_assertions(
                distribution, reinterpreted_batch_ndims, validate_args)
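A minimal usage sketch, assuming `tfp.distributions.Independent` mirrors this constructor; the rightmost batch dimension is reinterpreted as an event dimension:

# Hypothetical usage.
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

base = tfd.Normal(loc=tf.zeros([3, 2]), scale=1.0)          # batch_shape [3, 2]
joint = tfd.Independent(base, reinterpreted_batch_ndims=1)  # batch [3], event [2]
lp = joint.log_prob(tf.zeros([3, 2]))                       # shape [3]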
Example No. 19
 def __init__(self,
              rate,
              validate_args=False,
              allow_nan_stats=True,
              name="ExponentialWithSoftplusRate"):
   parameters = distribution_util.parent_frame_arguments()
   with ops.name_scope(name, values=[rate]) as name:
     super(ExponentialWithSoftplusRate, self).__init__(
         rate=nn.softplus(rate, name="softplus_rate"),
         validate_args=validate_args,
         allow_nan_stats=allow_nan_stats,
         name=name)
   self._parameters = parameters
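There is no direct TFP counterpart for this contrib wrapper; a hypothetical equivalent sketch is simply to softplus-transform an unconstrained rate before handing it to `Exponential`:

# Hypothetical equivalent; raw_rate is an illustrative unconstrained parameter.
import tensorflow as tf
import tensorflow_probability as tfp

raw_rate = tf.Variable([-1.0, 2.0])
exp_dist = tfp.distributions.Exponential(rate=tf.nn.softplus(raw_rate))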
Example No. 20
 def __init__(self,
              rate,
              validate_args=False,
              allow_nan_stats=True,
              name="ExponentialWithSoftplusRate"):
     parameters = distribution_util.parent_frame_arguments()
     with ops.name_scope(name, values=[rate]) as name:
          super(ExponentialWithSoftplusRate, self).__init__(
              rate=nn.softplus(rate, name="softplus_rate"),
              validate_args=validate_args,
              allow_nan_stats=allow_nan_stats,
              name=name)
     self._parameters = parameters
Example No. 21
    def __init__(self,
                 total_count,
                 logits=None,
                 probs=None,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="Binomial"):
        """Initialize a batch of Binomial distributions.

    Args:
      total_count: Non-negative floating point tensor with shape broadcastable
        to `[N1,..., Nm]` with `m >= 0` and the same dtype as `probs` or
        `logits`. Defines this as a batch of `N1 x ...  x Nm` different Binomial
        distributions. Its components should be equal to integer values.
      logits: Floating point tensor representing the log-odds of a
        positive event with shape broadcastable to `[N1,..., Nm]` `m >= 0`, and
        the same dtype as `total_count`. Each entry represents logits for the
        probability of success for independent Binomial distributions. Only one
        of `logits` or `probs` should be passed in.
      probs: Positive floating point tensor with shape broadcastable to
        `[N1,..., Nm]` `m >= 0`, `probs in [0, 1]`. Each entry represents the
        probability of success for independent Binomial distributions. Only one
        of `logits` or `probs` should be passed in.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
        parameters = distribution_util.parent_frame_arguments()
        with ops.name_scope(name, values=[total_count, logits, probs]) as name:
            self._total_count = self._maybe_assert_valid_total_count(
                ops.convert_to_tensor(total_count, name="total_count"),
                validate_args)
            self._logits, self._probs = distribution_util.get_logits_and_probs(
                logits=logits,
                probs=probs,
                validate_args=validate_args,
                name=name)
        super(Binomial, self).__init__(
            dtype=self._probs.dtype,
            reparameterization_type=distribution.NOT_REPARAMETERIZED,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            parameters=parameters,
            graph_parents=[self._total_count, self._logits, self._probs],
            name=name)
Example No. 22
  def __init__(
      self, distribution, reinterpreted_batch_ndims=None,
      validate_args=False, name=None):
    """Construct a `Independent` distribution.

    Args:
      distribution: The base distribution instance to transform. Typically an
        instance of `Distribution`.
      reinterpreted_batch_ndims: Scalar, integer number of rightmost batch dims
        which will be regarded as event dims. When `None` all but the first
        batch axis (batch axis 0) will be transferred to event dimensions
        (analogous to `tf.layers.flatten`).
      validate_args: Python `bool`.  Whether to validate input with asserts.
        If `validate_args` is `False`, and the inputs are invalid,
        correct behavior is not guaranteed.
      name: The name for ops managed by the distribution.
        Default value: `"Independent" + distribution.name`.

    Raises:
      ValueError: if `reinterpreted_batch_ndims` exceeds
        `distribution.batch_ndims`
    """
    parameters = distribution_util.parent_frame_arguments()
    name = name or "Independent" + distribution.name
    self._distribution = distribution
    with ops.name_scope(name) as name:
      if reinterpreted_batch_ndims is None:
        reinterpreted_batch_ndims = self._get_default_reinterpreted_batch_ndims(
            distribution)
      reinterpreted_batch_ndims = ops.convert_to_tensor(
          reinterpreted_batch_ndims,
          dtype=dtypes.int32,
          name="reinterpreted_batch_ndims")
      self._reinterpreted_batch_ndims = reinterpreted_batch_ndims
      self._static_reinterpreted_batch_ndims = tensor_util.constant_value(
          reinterpreted_batch_ndims)
      if self._static_reinterpreted_batch_ndims is not None:
        self._reinterpreted_batch_ndims = self._static_reinterpreted_batch_ndims
      super(Independent, self).__init__(
          dtype=self._distribution.dtype,
          reparameterization_type=self._distribution.reparameterization_type,
          validate_args=validate_args,
          allow_nan_stats=self._distribution.allow_nan_stats,
          parameters=parameters,
          graph_parents=(
              [reinterpreted_batch_ndims] +
              distribution._graph_parents),  # pylint: disable=protected-access
          name=name)
      self._runtime_assertions = self._make_runtime_assertions(
          distribution, reinterpreted_batch_ndims, validate_args)
Example No. 23
  def __init__(self,
               concentration,
               rate,
               validate_args=False,
               allow_nan_stats=True,
               name="InverseGamma"):
    """Construct InverseGamma with `concentration` and `rate` parameters.

    The parameters `concentration` and `rate` must be shaped in a way that
    supports broadcasting (e.g. `concentration + rate` is a valid operation).

    Args:
      concentration: Floating point tensor, the concentration params of the
        distribution(s). Must contain only positive values.
      rate: Floating point tensor, the inverse scale params of the
        distribution(s). Must contain only positive values.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.


    Raises:
      TypeError: if `concentration` and `rate` are different dtypes.
    """
    parameters = distribution_util.parent_frame_arguments()
    with ops.name_scope(name, values=[concentration, rate]) as name:
      with ops.control_dependencies([
          check_ops.assert_positive(concentration),
          check_ops.assert_positive(rate),
      ] if validate_args else []):
        self._concentration = array_ops.identity(
            concentration, name="concentration")
        self._rate = array_ops.identity(rate, name="rate")
        check_ops.assert_same_float_dtype(
            [self._concentration, self._rate])
    super(InverseGamma, self).__init__(
        dtype=self._concentration.dtype,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        parameters=parameters,
        graph_parents=[self._concentration,
                       self._rate],
        name=name)
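A minimal construction sketch, assuming TF 1.x and the `tf.contrib.distributions.InverseGamma` export path (both assumptions, not shown in the snippet):

import tensorflow as tf

tfd = tf.contrib.distributions  # assumed export path

ig = tfd.InverseGamma(concentration=3., rate=2.)
with tf.Session() as sess:
  # mean = rate / (concentration - 1) = 1.0
  print(sess.run([ig.mean(), ig.sample(4, seed=42)]))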
Example No. 24
    def __init__(self,
                 concentration1=None,
                 concentration0=None,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="Beta"):
        """Initialize a batch of Beta distributions.

    Args:
      concentration1: Positive floating-point `Tensor` indicating mean
        number of successes; aka "alpha". Implies `self.dtype` and
        `self.batch_shape`, i.e.,
        `concentration1.shape = [N1, N2, ..., Nm] = self.batch_shape`.
      concentration0: Positive floating-point `Tensor` indicating mean
        number of failures; aka "beta". Otherwise has same semantics as
        `concentration1`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
        parameters = distribution_util.parent_frame_arguments()
        with ops.name_scope(name, values=[concentration1,
                                          concentration0]) as name:
            self._concentration1 = self._maybe_assert_valid_concentration(
                ops.convert_to_tensor(concentration1, name="concentration1"),
                validate_args)
            self._concentration0 = self._maybe_assert_valid_concentration(
                ops.convert_to_tensor(concentration0, name="concentration0"),
                validate_args)
            check_ops.assert_same_float_dtype(
                [self._concentration1, self._concentration0])
            self._total_concentration = self._concentration1 + self._concentration0
        super(Beta, self).__init__(
            dtype=self._total_concentration.dtype,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            reparameterization_type=distribution.NOT_REPARAMETERIZED,
            parameters=parameters,
            graph_parents=[
                self._concentration1, self._concentration0,
                self._total_concentration
            ],
            name=name)
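A minimal usage sketch for the Beta constructor above, under the same TF 1.x / `tf.contrib.distributions` assumptions:

import tensorflow as tf

tfd = tf.contrib.distributions  # assumed export path

beta = tfd.Beta(concentration1=2., concentration0=5.)  # "alpha" = 2, "beta" = 5
with tf.Session() as sess:
  # mean = concentration1 / (concentration1 + concentration0) = 2 / 7
  print(sess.run([beta.mean(), beta.prob(0.3)]))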
Example No. 25
    def __init__(self,
                 concentration,
                 rate,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="InverseGamma"):
        """Construct InverseGamma with `concentration` and `rate` parameters.

    The parameters `concentration` and `rate` must be shaped in a way that
    supports broadcasting (e.g. `concentration + rate` is a valid operation).

    Args:
      concentration: Floating point tensor, the concentration params of the
        distribution(s). Must contain only positive values.
      rate: Floating point tensor, the inverse scale params of the
        distribution(s). Must contain only positive values.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.


    Raises:
      TypeError: if `concentration` and `rate` are different dtypes.
    """
        parameters = distribution_util.parent_frame_arguments()
        with ops.name_scope(name, values=[concentration, rate]) as name:
            with ops.control_dependencies([
                    check_ops.assert_positive(concentration),
                    check_ops.assert_positive(rate),
            ] if validate_args else []):
                self._concentration = array_ops.identity(concentration,
                                                         name="concentration")
                self._rate = array_ops.identity(rate, name="rate")
                check_ops.assert_same_float_dtype(
                    [self._concentration, self._rate])
        super(InverseGamma, self).__init__(
            dtype=self._concentration.dtype,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            reparameterization_type=distribution.NOT_REPARAMETERIZED,
            parameters=parameters,
            graph_parents=[self._concentration, self._rate],
            name=name)
Example No. 26
 def __init__(self,
              loc,
              scale,
              validate_args=False,
              allow_nan_stats=True,
              name="NormalWithSoftplusScale"):
   parameters = distribution_util.parent_frame_arguments()
   with ops.name_scope(name, values=[scale]) as name:
     super(NormalWithSoftplusScale, self).__init__(
         loc=loc,
         scale=nn.softplus(scale, name="softplus_scale"),
         validate_args=validate_args,
         allow_nan_stats=allow_nan_stats,
         name=name)
   self._parameters = parameters
Example No. 27
 def __init__(self,
              df,
              validate_args=False,
              allow_nan_stats=True,
              name="Chi2WithAbsDf"):
   parameters = distribution_util.parent_frame_arguments()
   with ops.name_scope(name, values=[df]) as name:
     super(Chi2WithAbsDf, self).__init__(
         df=math_ops.floor(
             math_ops.abs(df, name="abs_df"),
             name="floor_abs_df"),
         validate_args=validate_args,
         allow_nan_stats=allow_nan_stats,
         name=name)
   self._parameters = parameters
Example No. 28
 def __init__(self,
              df,
              validate_args=False,
              allow_nan_stats=True,
              name="Chi2WithAbsDf"):
     parameters = distribution_util.parent_frame_arguments()
     with ops.name_scope(name, values=[df]) as name:
         super(Chi2WithAbsDf,
               self).__init__(df=math_ops.floor(math_ops.abs(df,
                                                             name="abs_df"),
                                                name="floor_abs_df"),
                              validate_args=validate_args,
                              allow_nan_stats=allow_nan_stats,
                              name=name)
     self._parameters = parameters
Example No. 29
    def __init__(self,
                 logits=None,
                 probs=None,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="Geometric"):
        """Construct Geometric distributions.

    Args:
      logits: Floating-point `Tensor` with shape `[B1, ..., Bb]` where `b >= 0`
        indicates the number of batch dimensions. Each entry represents logits
        for the probability of success for independent Geometric distributions
        and must be in the range `(-inf, inf]`. Only one of `logits` or `probs`
        should be specified.
      probs: Positive floating-point `Tensor` with shape `[B1, ..., Bb]`
        where `b >= 0` indicates the number of batch dimensions. Each entry
        represents the probability of success for independent Geometric
        distributions and must be in the range `(0, 1]`. Only one of `logits`
        or `probs` should be specified.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """

        parameters = distribution_util.parent_frame_arguments()
        with ops.name_scope(name, values=[logits, probs]) as name:
            self._logits, self._probs = distribution_util.get_logits_and_probs(
                logits, probs, validate_args=validate_args, name=name)

            with ops.control_dependencies(
                [check_ops.assert_positive(self._probs
                                           )] if validate_args else []):
                self._probs = array_ops.identity(self._probs, name="probs")

        super(Geometric, self).__init__(
            dtype=self._probs.dtype,
            reparameterization_type=distribution.NOT_REPARAMETERIZED,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            parameters=parameters,
            graph_parents=[self._probs, self._logits],
            name=name)
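A minimal usage sketch, assuming TF 1.x with the class exported as `tf.contrib.distributions.Geometric`:

import tensorflow as tf

tfd = tf.contrib.distributions  # assumed export path

geom = tfd.Geometric(probs=0.25)  # pass exactly one of `probs` or `logits`
with tf.Session() as sess:
  # counts failures before the first success: mean = (1 - p) / p = 3
  print(sess.run([geom.mean(), geom.prob(3.)]))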
Example No. 30
 def __init__(self,
              concentration,
              rate,
              validate_args=False,
              allow_nan_stats=True,
              name="InverseGammaWithSoftplusConcentrationRate"):
     parameters = distribution_util.parent_frame_arguments()
     with ops.name_scope(name, values=[concentration, rate]) as name:
         super(InverseGammaWithSoftplusConcentrationRate, self).__init__(
             concentration=nn.softplus(concentration,
                                       name="softplus_concentration"),
             rate=nn.softplus(rate, name="softplus_rate"),
             validate_args=validate_args,
             allow_nan_stats=allow_nan_stats,
             name=name)
     self._parameters = parameters
Example No. 31
 def __init__(self,
              concentration,
              rate,
              validate_args=False,
              allow_nan_stats=True,
              name="GammaWithSoftplusConcentrationRate"):
   parameters = distribution_util.parent_frame_arguments()
   with ops.name_scope(name, values=[concentration, rate]) as name:
     super(GammaWithSoftplusConcentrationRate, self).__init__(
         concentration=nn.softplus(concentration,
                                   name="softplus_concentration"),
         rate=nn.softplus(rate, name="softplus_rate"),
         validate_args=validate_args,
         allow_nan_stats=allow_nan_stats,
         name=name)
   self._parameters = parameters
Example No. 32
 def __init__(self,
              loc,
              scale,
              validate_args=False,
              allow_nan_stats=True,
              name="NormalWithSoftplusScale"):
     parameters = distribution_util.parent_frame_arguments()
     with ops.name_scope(name, values=[scale]) as name:
         super(NormalWithSoftplusScale,
               self).__init__(loc=loc,
                              scale=nn.softplus(scale,
                                                name="softplus_scale"),
                              validate_args=validate_args,
                              allow_nan_stats=allow_nan_stats,
                              name=name)
     self._parameters = parameters
Example No. 33
  def __init__(self,
               logits=None,
               probs=None,
               validate_args=False,
               allow_nan_stats=True,
               name="Geometric"):
    """Construct Geometric distributions.

    Args:
      logits: Floating-point `Tensor` with shape `[B1, ..., Bb]` where `b >= 0`
        indicates the number of batch dimensions. Each entry represents logits
        for the probability of success for independent Geometric distributions
        and must be in the range `(-inf, inf]`. Only one of `logits` or `probs`
        should be specified.
      probs: Positive floating-point `Tensor` with shape `[B1, ..., Bb]`
        where `b >= 0` indicates the number of batch dimensions. Each entry
        represents the probability of success for independent Geometric
        distributions and must be in the range `(0, 1]`. Only one of `logits`
        or `probs` should be specified.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """

    parameters = distribution_util.parent_frame_arguments()
    with ops.name_scope(name, values=[logits, probs]) as name:
      self._logits, self._probs = distribution_util.get_logits_and_probs(
          logits, probs, validate_args=validate_args, name=name)

      with ops.control_dependencies(
          [check_ops.assert_positive(self._probs)] if validate_args else []):
        self._probs = array_ops.identity(self._probs, name="probs")

    super(Geometric, self).__init__(
        dtype=self._probs.dtype,
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[self._probs, self._logits],
        name=name)
Example No. 34
  def __init__(self,
               logits=None,
               probs=None,
               dtype=dtypes.int32,
               validate_args=False,
               allow_nan_stats=True,
               name="Bernoulli"):
    """Construct Bernoulli distributions.

    Args:
      logits: An N-D `Tensor` representing the log-odds of a `1` event. Each
        entry in the `Tensor` parametrizes an independent Bernoulli distribution
        where the probability of an event is sigmoid(logits). Only one of
        `logits` or `probs` should be passed in.
      probs: An N-D `Tensor` representing the probability of a `1`
        event. Each entry in the `Tensor` parameterizes an independent
        Bernoulli distribution. Only one of `logits` or `probs` should be passed
        in.
      dtype: The type of the event samples. Default: `int32`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: If `probs` and `logits` are both passed, or if neither is passed.
    """
    parameters = distribution_util.parent_frame_arguments()
    with ops.name_scope(name) as name:
      self._logits, self._probs = distribution_util.get_logits_and_probs(
          logits=logits,
          probs=probs,
          validate_args=validate_args,
          name=name)
    super(Bernoulli, self).__init__(
        dtype=dtype,
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[self._logits, self._probs],
        name=name)
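A minimal usage sketch, assuming TF 1.x with the class exported as `tf.contrib.distributions.Bernoulli`:

import tensorflow as tf

tfd = tf.contrib.distributions  # assumed export path

coin = tfd.Bernoulli(probs=0.7)  # pass exactly one of `probs` or `logits`
with tf.Session() as sess:
  # mean() returns `probs`; samples are int32 zeros and ones by default
  print(sess.run([coin.mean(), coin.sample(5, seed=1)]))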
Example No. 35
    def __init__(self,
                 logits=None,
                 probs=None,
                 dtype=dtypes.int32,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="Bernoulli"):
        """Construct Bernoulli distributions.

    Args:
      logits: An N-D `Tensor` representing the log-odds of a `1` event. Each
        entry in the `Tensor` parametrizes an independent Bernoulli distribution
        where the probability of an event is sigmoid(logits). Only one of
        `logits` or `probs` should be passed in.
      probs: An N-D `Tensor` representing the probability of a `1`
        event. Each entry in the `Tensor` parameterizes an independent
        Bernoulli distribution. Only one of `logits` or `probs` should be passed
        in.
      dtype: The type of the event samples. Default: `int32`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: If `probs` and `logits` are both passed, or if neither is passed.
    """
        parameters = distribution_util.parent_frame_arguments()
        with ops.name_scope(name) as name:
            self._logits, self._probs = distribution_util.get_logits_and_probs(
                logits=logits,
                probs=probs,
                validate_args=validate_args,
                name=name)
        super(Bernoulli, self).__init__(
            dtype=dtype,
            reparameterization_type=distribution.NOT_REPARAMETERIZED,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            parameters=parameters,
            graph_parents=[self._logits, self._probs],
            name=name)
Example No. 36
 def __init__(self,
              df,
              loc,
              scale,
              validate_args=False,
              allow_nan_stats=True,
              name="StudentTWithAbsDfSoftplusScale"):
   parameters = distribution_util.parent_frame_arguments()
   with ops.name_scope(name, values=[df, scale]) as name:
     super(StudentTWithAbsDfSoftplusScale, self).__init__(
         df=math_ops.floor(math_ops.abs(df)),
         loc=loc,
         scale=nn.softplus(scale, name="softplus_scale"),
         validate_args=validate_args,
         allow_nan_stats=allow_nan_stats,
         name=name)
   self._parameters = parameters
Example No. 37
 def __init__(self,
              df,
              loc,
              scale,
              validate_args=False,
              allow_nan_stats=True,
              name="StudentTWithAbsDfSoftplusScale"):
     parameters = distribution_util.parent_frame_arguments()
     with ops.name_scope(name, values=[df, scale]) as name:
         super(StudentTWithAbsDfSoftplusScale,
               self).__init__(df=math_ops.floor(math_ops.abs(df)),
                              loc=loc,
                              scale=nn.softplus(scale,
                                                name="softplus_scale"),
                              validate_args=validate_args,
                              allow_nan_stats=allow_nan_stats,
                              name=name)
     self._parameters = parameters
Example No. 38
 def __init__(self,
              concentration1,
              concentration0,
              validate_args=False,
              allow_nan_stats=True,
              name="BetaWithSoftplusConcentration"):
     parameters = distribution_util.parent_frame_arguments()
     with ops.name_scope(name, values=[concentration1,
                                       concentration0]) as name:
         super(BetaWithSoftplusConcentration, self).__init__(
             concentration1=nn.softplus(concentration1,
                                        name="softplus_concentration1"),
             concentration0=nn.softplus(concentration0,
                                        name="softplus_concentration0"),
             validate_args=validate_args,
             allow_nan_stats=allow_nan_stats,
             name=name)
     self._parameters = parameters
Example No. 39
    def __init__(self,
                 loc,
                 scale,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="Laplace"):
        """Construct Laplace distribution with parameters `loc` and `scale`.

    The parameters `loc` and `scale` must be shaped in a way that supports
    broadcasting (e.g., `loc / scale` is a valid operation).

    Args:
      loc: Floating point tensor which characterizes the location (center)
        of the distribution.
      scale: Positive floating point tensor which characterizes the spread of
        the distribution.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      TypeError: if `loc` and `scale` are of different dtype.
    """
        parameters = distribution_util.parent_frame_arguments()
        with ops.name_scope(name, values=[loc, scale]) as name:
            with ops.control_dependencies(
                [check_ops.assert_positive(scale)] if validate_args else []):
                self._loc = array_ops.identity(loc, name="loc")
                self._scale = array_ops.identity(scale, name="scale")
                check_ops.assert_same_float_dtype([self._loc, self._scale])
            super(Laplace, self).__init__(
                dtype=self._loc.dtype,
                reparameterization_type=distribution.FULLY_REPARAMETERIZED,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                parameters=parameters,
                graph_parents=[self._loc, self._scale],
                name=name)
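A minimal usage sketch under the same TF 1.x / `tf.contrib.distributions` assumptions:

import tensorflow as tf

tfd = tf.contrib.distributions  # assumed export path

lap = tfd.Laplace(loc=[0., 1.], scale=[1., 2.])  # a batch of two distributions
with tf.Session() as sess:
  # evaluated element-wise against the batch: [log(1/2), log(1/4)]
  print(sess.run(lap.log_prob([0., 1.])))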
Example No. 40
  def __init__(self,
               low=0.,
               high=1.,
               validate_args=False,
               allow_nan_stats=True,
               name="Uniform"):
    """Initialize a batch of Uniform distributions.

    Args:
      low: Floating point tensor, lower boundary of the output interval. Must
        have `low < high`.
      high: Floating point tensor, upper boundary of the output interval. Must
        have `low < high`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      InvalidArgumentError: if `low >= high` and `validate_args=True`.
    """
    parameters = distribution_util.parent_frame_arguments()
    with ops.name_scope(name, values=[low, high]) as name:
      with ops.control_dependencies([
          check_ops.assert_less(
              low, high, message="uniform not defined when low >= high.")
      ] if validate_args else []):
        self._low = array_ops.identity(low, name="low")
        self._high = array_ops.identity(high, name="high")
        check_ops.assert_same_float_dtype([self._low, self._high])
    super(Uniform, self).__init__(
        dtype=self._low.dtype,
        reparameterization_type=distribution.FULLY_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[self._low,
                       self._high],
        name=name)
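A minimal usage sketch, assuming TF 1.x with the class exported as `tf.contrib.distributions.Uniform`:

import tensorflow as tf

tfd = tf.contrib.distributions  # assumed export path

u = tfd.Uniform(low=-1., high=1., validate_args=True)  # the added assert checks low < high
with tf.Session() as sess:
  print(sess.run([u.mean(), u.cdf(0.5)]))  # 0.0 and 0.75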
Example No. 41
  def __init__(self,
               loc,
               scale,
               validate_args=False,
               allow_nan_stats=True,
               name="Laplace"):
    """Construct Laplace distribution with parameters `loc` and `scale`.

    The parameters `loc` and `scale` must be shaped in a way that supports
    broadcasting (e.g., `loc / scale` is a valid operation).

    Args:
      loc: Floating point tensor which characterizes the location (center)
        of the distribution.
      scale: Positive floating point tensor which characterizes the spread of
        the distribution.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      TypeError: if `loc` and `scale` are of different dtype.
    """
    parameters = distribution_util.parent_frame_arguments()
    with ops.name_scope(name, values=[loc, scale]) as name:
      with ops.control_dependencies([check_ops.assert_positive(scale)] if
                                    validate_args else []):
        self._loc = array_ops.identity(loc, name="loc")
        self._scale = array_ops.identity(scale, name="scale")
        check_ops.assert_same_float_dtype([self._loc, self._scale])
      super(Laplace, self).__init__(
          dtype=self._loc.dtype,
          reparameterization_type=distribution.FULLY_REPARAMETERIZED,
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          parameters=parameters,
          graph_parents=[self._loc, self._scale],
          name=name)
Example No. 42
    def __init__(self,
                 concentration,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="Dirichlet"):
        """Initialize a batch of Dirichlet distributions.

    Args:
      concentration: Positive floating-point `Tensor` indicating mean number
        of class occurrences; aka "alpha". Implies `self.dtype`, and
        `self.batch_shape`, `self.event_shape`, i.e., if
        `concentration.shape = [N1, N2, ..., Nm, k]` then
        `batch_shape = [N1, N2, ..., Nm]` and
        `event_shape = [k]`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
        parameters = distribution_util.parent_frame_arguments()
        with ops.name_scope(name, values=[concentration]) as name:
            self._concentration = self._maybe_assert_valid_concentration(
                ops.convert_to_tensor(concentration, name="concentration"),
                validate_args)
            self._total_concentration = math_ops.reduce_sum(
                self._concentration, -1)
        super(Dirichlet, self).__init__(
            dtype=self._concentration.dtype,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            reparameterization_type=distribution.NOT_REPARAMETERIZED,
            parameters=parameters,
            graph_parents=[self._concentration, self._total_concentration],
            name=name)
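A minimal usage sketch, assuming TF 1.x with the class exported as `tf.contrib.distributions.Dirichlet`:

import tensorflow as tf

tfd = tf.contrib.distributions  # assumed export path

dirichlet = tfd.Dirichlet(concentration=[1., 2., 3.])  # event_shape = [3]
with tf.Session() as sess:
  # mean = concentration / total_concentration = [1/6, 2/6, 3/6]
  print(sess.run([dirichlet.mean(), dirichlet.prob([0.2, 0.3, 0.5])]))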
Example No. 43
    def __init__(self,
                 df,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="Chi2"):
        """Construct Chi2 distributions with parameter `df`.

    Args:
      df: Floating point tensor, the degrees of freedom of the
        distribution(s). `df` must contain only positive values.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
        parameters = distribution_util.parent_frame_arguments()
        # Even though all stats of chi2 are defined for valid parameters, this
        # is not true of the parent class "Gamma". Therefore, passing
        # allow_nan_stats=False through to the parent class results in
        # unnecessary asserts.
        with ops.name_scope(name, values=[df]) as name:
            with ops.control_dependencies([
                    check_ops.assert_positive(df),
            ] if validate_args else []):
                self._df = array_ops.identity(df, name="df")

            super(Chi2, self).__init__(concentration=0.5 * self._df,
                                       rate=constant_op.constant(
                                           0.5, dtype=self._df.dtype),
                                       validate_args=validate_args,
                                       allow_nan_stats=allow_nan_stats,
                                       name=name)
        self._parameters = parameters
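A minimal usage sketch, assuming TF 1.x with the class exported as `tf.contrib.distributions.Chi2`:

import tensorflow as tf

tfd = tf.contrib.distributions  # assumed export path

chi2 = tfd.Chi2(df=4.)
with tf.Session() as sess:
  print(sess.run([chi2.mean(), chi2.variance()]))  # 4.0 and 8.0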
Example No. 44
    def __init__(self,
                 rate,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="Exponential"):
        """Construct Exponential distribution with parameter `rate`.

    Args:
      rate: Floating point tensor, equivalent to `1 / mean`. Must contain only
        positive values.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
        parameters = distribution_util.parent_frame_arguments()
        # Even though all statistics of the exponential are defined for valid
        # inputs, this is not true of the parent class "Gamma". Therefore,
        # passing allow_nan_stats=False through to the parent class results in
        # unnecessary asserts.
        with ops.name_scope(name, values=[rate]) as name:
            self._rate = ops.convert_to_tensor(rate, name="rate")
        super(Exponential, self).__init__(concentration=array_ops.ones(
            [], dtype=self._rate.dtype),
                                          rate=self._rate,
                                          allow_nan_stats=allow_nan_stats,
                                          validate_args=validate_args,
                                          name=name)
        # While the Gamma distribution is not reparameterizable, the exponential
        # distribution is.
        self._reparameterization_type = distribution.FULLY_REPARAMETERIZED
        self._parameters = parameters
        self._graph_parents += [self._rate]
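A minimal usage sketch, assuming TF 1.x with the class exported as `tf.contrib.distributions.Exponential`:

import tensorflow as tf

tfd = tf.contrib.distributions  # assumed export path

expo = tfd.Exponential(rate=2.)  # mean = 1 / rate = 0.5
with tf.Session() as sess:
  print(sess.run([expo.mean(), expo.cdf(1.)]))  # 0.5 and 1 - exp(-2)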
Example No. 45
  def __init__(self,
               rate,
               validate_args=False,
               allow_nan_stats=True,
               name="Exponential"):
    """Construct Exponential distribution with parameter `rate`.

    Args:
      rate: Floating point tensor, equivalent to `1 / mean`. Must contain only
        positive values.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    parameters = distribution_util.parent_frame_arguments()
    # Even though all statistics of the exponential are defined for valid
    # inputs, this is not true of the parent class "Gamma". Therefore, passing
    # allow_nan_stats=False through to the parent class results in unnecessary
    # asserts.
    with ops.name_scope(name, values=[rate]) as name:
      self._rate = ops.convert_to_tensor(rate, name="rate")
    super(Exponential, self).__init__(
        concentration=array_ops.ones([], dtype=self._rate.dtype),
        rate=self._rate,
        allow_nan_stats=allow_nan_stats,
        validate_args=validate_args,
        name=name)
    # While the Gamma distribution is not reparameterizable, the exponential
    # distribution is.
    self._reparameterization_type = distribution.FULLY_REPARAMETERIZED
    self._parameters = parameters
    self._graph_parents += [self._rate]
Example No. 46
  def __init__(self,
               df,
               validate_args=False,
               allow_nan_stats=True,
               name="Chi2"):
    """Construct Chi2 distributions with parameter `df`.

    Args:
      df: Floating point tensor, the degrees of freedom of the
        distribution(s). `df` must contain only positive values.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    parameters = distribution_util.parent_frame_arguments()
    # Even though all stats of chi2 are defined for valid parameters, this is
    # not true of the parent class "Gamma". Therefore, passing
    # allow_nan_stats=False through to the parent class results in unnecessary
    # asserts.
    with ops.name_scope(name, values=[df]) as name:
      with ops.control_dependencies([
          check_ops.assert_positive(df),
      ] if validate_args else []):
        self._df = array_ops.identity(df, name="df")

      super(Chi2, self).__init__(
          concentration=0.5 * self._df,
          rate=constant_op.constant(0.5, dtype=self._df.dtype),
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          name=name)
    self._parameters = parameters
Example No. 47
    def __init__(self,
                 distribution_fn,
                 sample0=None,
                 num_steps=None,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="Autoregressive"):
        """Construct an `Autoregressive` distribution.

    Args:
      distribution_fn: Python `callable` which constructs a
        `tf.distributions.Distribution`-like instance from a `Tensor` (e.g.,
        `sample0`). The function must respect the "autoregressive property",
        i.e., there exists a permutation of the event coordinates such that
        each coordinate is a diffeomorphic function of only the preceding
        coordinates.
      sample0: Initial input to `distribution_fn`; used to
        build the distribution in `__init__` which in turn specifies this
        distribution's properties, e.g., `event_shape`, `batch_shape`, `dtype`.
        If unspecified, then `distribution_fn` should be default constructable.
      num_steps: Number of times `distribution_fn` is composed from samples,
        e.g., `num_steps=2` implies
        `distribution_fn(distribution_fn(sample0).sample(n)).sample()`.
      validate_args: Python `bool`.  Whether to validate input with asserts.
        If `validate_args` is `False`, and the inputs are invalid,
        correct behavior is not guaranteed.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
        Default value: "Autoregressive".

    Raises:
      ValueError: if `num_steps` and
        `distribution_fn(sample0).event_shape.num_elements()` are both `None`.
      ValueError: if `num_steps < 1`.
    """
        parameters = distribution_util.parent_frame_arguments()
        with ops.name_scope(name) as name:
            self._distribution_fn = distribution_fn
            self._sample0 = sample0
            self._distribution0 = (distribution_fn() if sample0 is None else
                                   distribution_fn(sample0))
            if num_steps is None:
                num_steps = self._distribution0.event_shape.num_elements()
                if num_steps is None:
                    raise ValueError(
                        "distribution_fn must generate a distribution "
                        "with fully known `event_shape`.")
            if num_steps < 1:
                raise ValueError(
                    "num_steps ({}) must be at least 1.".format(num_steps))
            self._num_steps = num_steps
        super(Autoregressive, self).__init__(
            dtype=self._distribution0.dtype,
            reparameterization_type=self._distribution0.
            reparameterization_type,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            parameters=parameters,
            graph_parents=self._distribution0._graph_parents,  # pylint: disable=protected-access
            name=name)
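An illustrative sketch of how the constructor above might be driven. The helper `_toy_ar_fn` is hypothetical (each coordinate's distribution depends only on the preceding coordinate of the conditioning sample), and TF 1.x with the `tf.contrib.distributions` exports is an assumption:

import tensorflow as tf

tfd = tf.contrib.distributions  # assumed export path


def _toy_ar_fn(sample):
  # Hypothetical autoregressive link: each coordinate's mean is the previous
  # coordinate of `sample` (0 for the first), so coordinate i depends only on
  # coordinates < i.
  shifted = tf.concat([tf.zeros_like(sample[..., :1]), sample[..., :-1]], axis=-1)
  return tfd.Independent(tfd.Normal(loc=shifted, scale=tf.ones_like(sample)),
                         reinterpreted_batch_ndims=1)


ar = tfd.Autoregressive(_toy_ar_fn, sample0=tf.zeros([3]))  # num_steps inferred as 3
with tf.Session() as sess:
  print(sess.run(ar.sample(seed=0)))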
Example No. 48
  def __init__(self,
               mixture_distribution,
               components_distribution,
               validate_args=False,
               allow_nan_stats=True,
               name="MixtureSameFamily"):
    """Construct a `MixtureSameFamily` distribution.

    Args:
      mixture_distribution: `tf.distributions.Categorical`-like instance.
        Manages the probability of selecting components. The number of
        categories must match the rightmost batch dimension of the
        `components_distribution`. Must have either scalar `batch_shape` or
        `batch_shape` matching `components_distribution.batch_shape[:-1]`.
      components_distribution: `tf.distributions.Distribution`-like instance.
        Right-most batch dimension indexes components.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: if `mixture_distribution.dtype` is not an integer type.
      ValueError: if mixture_distribution does not have scalar `event_shape`.
      ValueError: if `mixture_distribution.batch_shape` and
        `components_distribution.batch_shape[:-1]` are both fully defined and
        the former is neither scalar nor equal to the latter.
      ValueError: if the number of `mixture_distribution` categories does not
        equal the rightmost batch dimension of `components_distribution`.
    """
    parameters = distribution_util.parent_frame_arguments()
    with ops.name_scope(name) as name:
      self._mixture_distribution = mixture_distribution
      self._components_distribution = components_distribution
      self._runtime_assertions = []

      s = components_distribution.event_shape_tensor()
      self._event_ndims = (s.shape[0].value
                           if s.shape.with_rank_at_least(1)[0].value is not None
                           else array_ops.shape(s)[0])

      if not mixture_distribution.dtype.is_integer:
        raise ValueError(
            "`mixture_distribution.dtype` ({}) is not over integers".format(
                mixture_distribution.dtype.name))

      if (mixture_distribution.event_shape.ndims is not None
          and mixture_distribution.event_shape.ndims != 0):
        raise ValueError("`mixture_distribution` must have scalar `event_dim`s")
      elif validate_args:
        self._runtime_assertions += [
            control_flow_ops.assert_has_rank(
                mixture_distribution.event_shape_tensor(), 0,
                message="`mixture_distribution` must have scalar `event_dim`s"),
        ]

      mdbs = mixture_distribution.batch_shape
      cdbs = components_distribution.batch_shape.with_rank_at_least(1)[:-1]
      if mdbs.is_fully_defined() and cdbs.is_fully_defined():
        if mdbs.ndims != 0 and mdbs != cdbs:
          raise ValueError(
              "`mixture_distribution.batch_shape` (`{}`) is not "
              "compatible with `components_distribution.batch_shape` "
              "(`{}`)".format(mdbs.as_list(), cdbs.as_list()))
      elif validate_args:
        mdbs = mixture_distribution.batch_shape_tensor()
        cdbs = components_distribution.batch_shape_tensor()[:-1]
        self._runtime_assertions += [
            control_flow_ops.assert_equal(
                distribution_util.pick_vector(
                    mixture_distribution.is_scalar_batch(), cdbs, mdbs),
                cdbs,
                message=(
                    "`mixture_distribution.batch_shape` is not "
                    "compatible with `components_distribution.batch_shape`"))]

      km = mixture_distribution.logits.shape.with_rank_at_least(1)[-1].value
      kc = components_distribution.batch_shape.with_rank_at_least(1)[-1].value
      if km is not None and kc is not None and km != kc:
        raise ValueError("`mixture_distribution components` ({}) does not "
                         "equal `components_distribution.batch_shape[-1]` "
                         "({})".format(km, kc))
      elif validate_args:
        km = array_ops.shape(mixture_distribution.logits)[-1]
        kc = components_distribution.batch_shape_tensor()[-1]
        self._runtime_assertions += [
            control_flow_ops.assert_equal(
                km, kc,
                message=("`mixture_distribution components` does not equal "
                         "`components_distribution.batch_shape[-1:]`")),
        ]
      elif km is None:
        km = array_ops.shape(mixture_distribution.logits)[-1]

      self._num_components = km

      super(MixtureSameFamily, self).__init__(
          dtype=self._components_distribution.dtype,
          reparameterization_type=distribution.NOT_REPARAMETERIZED,
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          parameters=parameters,
          graph_parents=(
              self._mixture_distribution._graph_parents  # pylint: disable=protected-access
              + self._components_distribution._graph_parents),  # pylint: disable=protected-access
          name=name)
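A minimal usage sketch, assuming TF 1.x with `MixtureSameFamily`, `Categorical`, and `Normal` exported under `tf.contrib.distributions`:

import tensorflow as tf

tfd = tf.contrib.distributions  # assumed export path

# Mixture of three scalar Normals; the rightmost batch dimension indexes components.
gm = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(probs=[0.3, 0.5, 0.2]),
    components_distribution=tfd.Normal(loc=[-1., 0., 1.], scale=[0.1, 0.5, 0.1]))
with tf.Session() as sess:
  print(sess.run([gm.mean(), gm.sample(5, seed=7)]))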
Example No. 49
  def __init__(self,
               distribution,
               bijector=None,
               batch_shape=None,
               event_shape=None,
               validate_args=False,
               name=None):
    """Construct a Transformed Distribution.

    Args:
      distribution: The base distribution instance to transform. Typically an
        instance of `Distribution`.
      bijector: The object responsible for calculating the transformation.
        Typically an instance of `Bijector`. `None` means `Identity()`.
      batch_shape: `integer` vector `Tensor` which overrides `distribution`
        `batch_shape`; valid only if `distribution.is_scalar_batch()`.
      event_shape: `integer` vector `Tensor` which overrides `distribution`
        `event_shape`; valid only if `distribution.is_scalar_event()`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      name: Python `str` name prefixed to Ops created by this class. Default:
        `bijector.name + distribution.name`.
    """
    parameters = distribution_util.parent_frame_arguments()
    name = name or (("" if bijector is None else bijector.name) +
                    distribution.name)
    with ops.name_scope(name, values=[event_shape, batch_shape]) as name:
      # For convenience we define some handy constants.
      self._zero = constant_op.constant(0, dtype=dtypes.int32, name="zero")
      self._empty = constant_op.constant([], dtype=dtypes.int32, name="empty")

      if bijector is None:
        bijector = identity_bijector.Identity(validate_args=validate_args)

      # We will keep track of a static and dynamic version of
      # self._is_{batch,event}_override. This way we can do more prior to graph
      # execution, including possibly raising Python exceptions.

      self._override_batch_shape = self._maybe_validate_shape_override(
          batch_shape, distribution.is_scalar_batch(), validate_args,
          "batch_shape")
      self._is_batch_override = _logical_not(_logical_equal(
          _ndims_from_shape(self._override_batch_shape), self._zero))
      self._is_maybe_batch_override = bool(
          tensor_util.constant_value(self._override_batch_shape) is None or
          tensor_util.constant_value(self._override_batch_shape).size != 0)

      self._override_event_shape = self._maybe_validate_shape_override(
          event_shape, distribution.is_scalar_event(), validate_args,
          "event_shape")
      self._is_event_override = _logical_not(_logical_equal(
          _ndims_from_shape(self._override_event_shape), self._zero))
      self._is_maybe_event_override = bool(
          tensor_util.constant_value(self._override_event_shape) is None or
          tensor_util.constant_value(self._override_event_shape).size != 0)

      # To convert a scalar distribution into a multivariate distribution we
      # will draw dims from the sample dims, which are otherwise iid. This is
      # easy to do except in the case that the base distribution has batch dims
      # and we're overriding event shape. When that case happens the event dims
      # will incorrectly be to the left of the batch dims. In this case we'll
      # cyclically permute left the new dims.
      self._needs_rotation = _logical_and(
          self._is_event_override,
          _logical_not(self._is_batch_override),
          _logical_not(distribution.is_scalar_batch()))
      override_event_ndims = _ndims_from_shape(self._override_event_shape)
      self._rotate_ndims = _pick_scalar_condition(
          self._needs_rotation, override_event_ndims, 0)
      # We'll be reducing the head dims (if at all), i.e., this will be []
      # if we don't need to reduce.
      self._reduce_event_indices = math_ops.range(
          self._rotate_ndims - override_event_ndims, self._rotate_ndims)

    self._distribution = distribution
    self._bijector = bijector
    super(TransformedDistribution, self).__init__(
        dtype=self._distribution.dtype,
        reparameterization_type=self._distribution.reparameterization_type,
        validate_args=validate_args,
        allow_nan_stats=self._distribution.allow_nan_stats,
        parameters=parameters,
        # We let TransformedDistribution access _graph_parents since this class
        # is more like a baseclass than derived.
        graph_parents=(distribution._graph_parents +  # pylint: disable=protected-access
                       bijector.graph_parents),
        name=name)
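A minimal usage sketch, assuming TF 1.x with the class exported as `tf.contrib.distributions.TransformedDistribution` and the `Exp` bijector under `tf.contrib.distributions.bijectors`:

import tensorflow as tf

tfd = tf.contrib.distributions            # assumed export path
tfb = tf.contrib.distributions.bijectors  # assumed export path

# A log-normal built by pushing a standard Normal through the Exp bijector.
log_normal = tfd.TransformedDistribution(
    distribution=tfd.Normal(loc=0., scale=1.),
    bijector=tfb.Exp(),
    name="LogNormalTransformed")
with tf.Session() as sess:
  # at x = 1: log N(log 1 | 0, 1), since the log-Jacobian correction vanishes there
  print(sess.run(log_normal.log_prob(1.)))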
Example No. 50
  def __init__(self,
               cat,
               components,
               validate_args=False,
               allow_nan_stats=True,
               use_static_graph=False,
               name="Mixture"):
    """Initialize a Mixture distribution.

    A `Mixture` is defined by a `Categorical` (`cat`, representing the
    mixture probabilities) and a list of `Distribution` objects
    all having matching dtype, batch shape, event shape, and continuity
    properties (the components).

    The `num_classes` of `cat` must be possible to infer at graph construction
    time and match `len(components)`.

    Args:
      cat: A `Categorical` distribution instance, representing the probabilities
          of `distributions`.
      components: A list or tuple of `Distribution` instances.
        Each instance must have the same type, be defined on the same domain,
        and have matching `event_shape` and `batch_shape`.
      validate_args: Python `bool`, default `False`. If `True`, raise a runtime
        error if batch or event ranks are inconsistent between cat and any of
        the distributions. This is only checked if the ranks cannot be
        determined statically at graph construction time.
      allow_nan_stats: Boolean, default `True`. If `False`, raise an
        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
        batch member. If `True`, batch members with valid parameters leading to
        undefined statistics will return NaN for this statistic.
      use_static_graph: Calls to `sample` will not rely on dynamic tensor
        indexing, allowing for some static graph compilation optimizations, but
        at the expense of sampling all underlying distributions in the mixture.
        (Possibly useful when running on TPUs).
        Default value: `False` (i.e., use dynamic indexing).
      name: A name for this distribution (optional).

    Raises:
      TypeError: If cat is not a `Categorical`, or `components` is not
        a list or tuple, or the elements of `components` are not
        instances of `Distribution`, or do not have matching `dtype`.
      ValueError: If `components` is an empty list or tuple, or its
        elements do not have a statically known event rank.
        If `cat.num_classes` cannot be inferred at graph creation time,
        or the constant value of `cat.num_classes` is not equal to
        `len(components)`, or all `components` and `cat` do not have
        matching static batch shapes, or all components do not
        have matching static event shapes.
    """
    parameters = distribution_util.parent_frame_arguments()
    if not isinstance(cat, categorical.Categorical):
      raise TypeError("cat must be a Categorical distribution, but saw: %s" %
                      cat)
    if not components:
      raise ValueError("components must be a non-empty list or tuple")
    if not isinstance(components, (list, tuple)):
      raise TypeError("components must be a list or tuple, but saw: %s" %
                      components)
    if not all(isinstance(c, distribution.Distribution) for c in components):
      raise TypeError(
          "all entries in components must be Distribution instances"
          " but saw: %s" % components)

    dtype = components[0].dtype
    if not all(d.dtype == dtype for d in components):
      raise TypeError("All components must have the same dtype, but saw "
                      "dtypes: %s" % [(d.name, d.dtype) for d in components])
    static_event_shape = components[0].event_shape
    static_batch_shape = cat.batch_shape
    for d in components:
      static_event_shape = static_event_shape.merge_with(d.event_shape)
      static_batch_shape = static_batch_shape.merge_with(d.batch_shape)
    if static_event_shape.ndims is None:
      raise ValueError(
          "Expected to know rank(event_shape) from components, but "
          "none of the components provide a static number of ndims")

    # Ensure that all batch and event ndims are consistent.
    with ops.name_scope(name, values=[cat.logits]) as name:
      num_components = cat.event_size
      static_num_components = tensor_util.constant_value(num_components)
      if static_num_components is None:
        raise ValueError(
            "Could not infer number of classes from cat and unable "
            "to compare this value to the number of components passed in.")
      # Possibly convert from numpy 0-D array.
      static_num_components = int(static_num_components)
      if static_num_components != len(components):
        raise ValueError("cat.num_classes != len(components): %d vs. %d" %
                         (static_num_components, len(components)))

      cat_batch_shape = cat.batch_shape_tensor()
      cat_batch_rank = array_ops.size(cat_batch_shape)
      if validate_args:
        batch_shapes = [d.batch_shape_tensor() for d in components]
        batch_ranks = [array_ops.size(bs) for bs in batch_shapes]
        check_message = ("components[%d] batch shape must match cat "
                         "batch shape")
        self._assertions = [
            check_ops.assert_equal(
                cat_batch_rank, batch_ranks[di], message=check_message % di)
            for di in range(len(components))
        ]
        self._assertions += [
            check_ops.assert_equal(
                cat_batch_shape, batch_shapes[di], message=check_message % di)
            for di in range(len(components))
        ]
      else:
        self._assertions = []

      self._cat = cat
      self._components = list(components)
      self._num_components = static_num_components
      self._static_event_shape = static_event_shape
      self._static_batch_shape = static_batch_shape

      self._use_static_graph = use_static_graph
      if use_static_graph and static_num_components is None:
        raise ValueError("Number of categories must be known statically when "
                         "`use_static_graph=True`.")
    # We let the Mixture distribution access _graph_parents since it's arguably
    # more like a baseclass.
    graph_parents = self._cat._graph_parents  # pylint: disable=protected-access
    for c in self._components:
      graph_parents += c._graph_parents  # pylint: disable=protected-access

    super(Mixture, self).__init__(
        dtype=dtype,
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=graph_parents,
        name=name)
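A minimal usage sketch under the same TF 1.x / `tf.contrib.distributions` assumptions:

import tensorflow as tf

tfd = tf.contrib.distributions  # assumed export path

# Two explicit components selected by a Categorical over mixture weights.
mix = tfd.Mixture(
    cat=tfd.Categorical(probs=[0.4, 0.6]),
    components=[tfd.Normal(loc=-1., scale=0.5),
                tfd.Normal(loc=2., scale=1.)])
with tf.Session() as sess:
  print(sess.run([mix.mean(), mix.sample(3, seed=3)]))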
Example No. 51
  def __init__(self,
               loc=None,
               covariance_matrix=None,
               validate_args=False,
               allow_nan_stats=True,
               name="MultivariateNormalFullCovariance"):
    """Construct Multivariate Normal distribution on `R^k`.

    The `batch_shape` is the broadcast shape between `loc` and
    `covariance_matrix` arguments.

    The `event_shape` is given by last dimension of the matrix implied by
    `covariance_matrix`. The last dimension of `loc` (if provided) must
    broadcast with this.

    A non-batch `covariance_matrix` is a `k x k` symmetric positive definite
    matrix. In other words, it is (real) symmetric with all eigenvalues
    strictly positive.

    Additional leading dimensions (if any) will index batches.

    Args:
      loc: Floating-point `Tensor`. If this is set to `None`, `loc` is
        implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where
        `b >= 0` and `k` is the event size.
      covariance_matrix: Floating-point, symmetric positive definite `Tensor` of
        same `dtype` as `loc`.  The strict upper triangle of `covariance_matrix`
        is ignored, so if `covariance_matrix` is not symmetric no error will be
        raised (unless `validate_args is True`).  `covariance_matrix` has shape
        `[B1, ..., Bb, k, k]` where `b >= 0` and `k` is the event size.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: if neither `loc` nor `covariance_matrix` are specified.
    """
    parameters = distribution_util.parent_frame_arguments()

    # Convert the covariance_matrix up to a scale_tril and call MVNTriL.
    with ops.name_scope(name) as name:
      with ops.name_scope("init", values=[loc, covariance_matrix]):
        if covariance_matrix is None:
          scale_tril = None
        else:
          covariance_matrix = ops.convert_to_tensor(
              covariance_matrix, name="covariance_matrix")
          if validate_args:
            covariance_matrix = control_flow_ops.with_dependencies([
                check_ops.assert_near(
                    covariance_matrix,
                    array_ops.matrix_transpose(covariance_matrix),
                    message="Matrix was not symmetric")], covariance_matrix)
          # No need to validate that covariance_matrix is non-singular.
          # LinearOperatorLowerTriangular has an assert_non_singular method that
          # is called by the Bijector.
          # However, cholesky() ignores the upper triangular part, so we do need
          # to separately assert symmetric.
          scale_tril = linalg_ops.cholesky(covariance_matrix)
        super(MultivariateNormalFullCovariance, self).__init__(
            loc=loc,
            scale_tril=scale_tril,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            name=name)
    self._parameters = parameters
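
A minimal usage sketch for this constructor, assuming the TF 1.x `tf.contrib.distributions` package exposes `MultivariateNormalFullCovariance`; the mean and covariance values are illustrative only:

import tensorflow as tf

tfd = tf.contrib.distributions

# Hypothetical 3-dimensional example with a symmetric positive definite covariance.
mu = [1.0, 2.0, 3.0]
cov = [[0.36, 0.12, 0.06],
       [0.12, 0.29, -0.13],
       [0.06, -0.13, 0.26]]
mvn = tfd.MultivariateNormalFullCovariance(loc=mu, covariance_matrix=cov)
# Internally the covariance is Cholesky-factored into a `scale_tril`.
x = mvn.sample(10)  # Shape [10, 3]: 10 draws of a length-3 event.
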
  def __init__(
      self,
      temperature,
      logits=None,
      probs=None,
      dtype=None,
      validate_args=False,
      allow_nan_stats=True,
      name="ExpRelaxedOneHotCategorical"):
    """Initialize ExpRelaxedOneHotCategorical using class log-probabilities.

    Args:
      temperature: A 0-D `Tensor`, representing the temperature
        of a set of ExpRelaxedCategorical distributions. The temperature should
        be positive.
      logits: An N-D `Tensor`, `N >= 1`, representing the log probabilities
        of a set of ExpRelaxedCategorical distributions. The first
        `N - 1` dimensions index into a batch of independent distributions and
        the last dimension represents a vector of logits for each class. Only
        one of `logits` or `probs` should be passed in.
      probs: An N-D `Tensor`, `N >= 1`, representing the probabilities
        of a set of ExpRelaxedCategorical distributions. The first
        `N - 1` dimensions index into a batch of independent distributions and
        the last dimension represents a vector of probabilities for each
        class. Only one of `logits` or `probs` should be passed in.
      dtype: The type of the event samples (default: inferred from
        logits/probs).
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    parameters = distribution_util.parent_frame_arguments()
    with ops.name_scope(name, values=[logits, probs, temperature]) as name:

      self._logits, self._probs = distribution_util.get_logits_and_probs(
          name=name, logits=logits, probs=probs, validate_args=validate_args,
          multidimensional=True)

      if dtype is None:
        dtype = self._logits.dtype
        if not validate_args:
          temperature = math_ops.cast(temperature, dtype)

      with ops.control_dependencies([check_ops.assert_positive(temperature)]
                                    if validate_args else []):
        self._temperature = array_ops.identity(temperature, name="temperature")
        self._temperature_2d = array_ops.reshape(temperature, [-1, 1],
                                                 name="temperature_2d")

      logits_shape_static = self._logits.get_shape().with_rank_at_least(1)
      if logits_shape_static.ndims is not None:
        self._batch_rank = ops.convert_to_tensor(
            logits_shape_static.ndims - 1,
            dtype=dtypes.int32,
            name="batch_rank")
      else:
        with ops.name_scope(name="batch_rank"):
          self._batch_rank = array_ops.rank(self._logits) - 1

      with ops.name_scope(name="event_size"):
        self._event_size = array_ops.shape(self._logits)[-1]

    super(ExpRelaxedOneHotCategorical, self).__init__(
        dtype=dtype,
        reparameterization_type=distribution.FULLY_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[self._logits,
                       self._probs,
                       self._temperature],
        name=name)
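
A minimal usage sketch, assuming TF 1.x `tf.contrib.distributions.ExpRelaxedOneHotCategorical`; the temperature and logits are illustrative only:

import tensorflow as tf

tfd = tf.contrib.distributions

# Hypothetical example: 3 classes and a single scalar temperature.
dist = tfd.ExpRelaxedOneHotCategorical(
    temperature=0.5, logits=[-0.2, 0.4, 1.1])
y = dist.sample()      # Lies in the log-simplex: tf.exp(y) sums to 1.
lp = dist.log_prob(y)  # Log-density of the relaxed sample.
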
Ejemplo n.º 53
0
  def __init__(self,
               loc=None,
               scale_tril=None,
               validate_args=False,
               allow_nan_stats=True,
               name="MultivariateNormalTriL"):
    """Construct Multivariate Normal distribution on `R^k`.

    The `batch_shape` is the broadcast shape between `loc` and `scale`
    arguments.

    The `event_shape` is given by last dimension of the matrix implied by
    `scale`. The last dimension of `loc` (if provided) must broadcast with this.

    Recall that `covariance = scale @ scale.T`. A (non-batch) `scale` matrix is:

    ```none
    scale = scale_tril
    ```

    where `scale_tril` is a lower-triangular `k x k` matrix with a non-zero
    diagonal, i.e., `tf.diag_part(scale_tril) != 0`.

    Additional leading dimensions (if any) will index batches.

    Args:
      loc: Floating-point `Tensor`. If this is set to `None`, `loc` is
        implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where
        `b >= 0` and `k` is the event size.
      scale_tril: Floating-point, lower-triangular `Tensor` with non-zero
        diagonal elements. `scale_tril` has shape `[B1, ..., Bb, k, k]` where
        `b >= 0` and `k` is the event size.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: if neither `loc` nor `scale_tril` are specified.
    """
    parameters = distribution_util.parent_frame_arguments()
    def _convert_to_tensor(x, name):
      return None if x is None else ops.convert_to_tensor(x, name=name)
    if loc is None and scale_tril is None:
      raise ValueError("Must specify one or both of `loc`, `scale_tril`.")
    with ops.name_scope(name) as name:
      with ops.name_scope("init", values=[loc, scale_tril]):
        loc = _convert_to_tensor(loc, name="loc")
        scale_tril = _convert_to_tensor(scale_tril, name="scale_tril")
        if scale_tril is None:
          scale = linalg.LinearOperatorIdentity(
              num_rows=distribution_util.dimension_size(loc, -1),
              dtype=loc.dtype,
              is_self_adjoint=True,
              is_positive_definite=True,
              assert_proper_shapes=validate_args)
        else:
          # No need to validate that scale_tril is non-singular.
          # LinearOperatorLowerTriangular has an assert_non_singular
          # method that is called by the Bijector.
          scale = linalg.LinearOperatorLowerTriangular(
              scale_tril,
              is_non_singular=True,
              is_self_adjoint=False,
              is_positive_definite=False)
    super(MultivariateNormalTriL, self).__init__(
        loc=loc,
        scale=scale,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        name=name)
    self._parameters = parameters
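
A minimal usage sketch for this constructor, assuming TF 1.x `tf.contrib.distributions.MultivariateNormalTriL`; the location and lower-triangular scale are illustrative only:

import tensorflow as tf

tfd = tf.contrib.distributions

# Hypothetical 2-dimensional example with an explicit lower-triangular scale.
mvn = tfd.MultivariateNormalTriL(
    loc=[1.0, -1.0],
    scale_tril=[[1.0, 0.0],
                [0.5, 2.0]])
x = mvn.sample(4)       # Shape [4, 2].
cov = mvn.covariance()  # Equals scale_tril @ scale_tril^T.
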
Ejemplo n.º 54
0
  def __init__(self,
               distribution,
               low=None,
               high=None,
               validate_args=False,
               name="QuantizedDistribution"):
    """Construct a Quantized Distribution representing `Y = ceiling(X)`.

    Some properties are inherited from the distribution defining `X`. Example:
    `allow_nan_stats` is determined for this `QuantizedDistribution` by reading
    the `distribution`.

    Args:
      distribution:  The base distribution class to transform. Typically an
        instance of `Distribution`.
      low: `Tensor` with the same `dtype` as this distribution and a shape
        that broadcasts with samples. Should be a whole number. Default `None`.
        If provided, the base distribution's `prob` should be defined at
        `low`.
      high: `Tensor` with the same `dtype` as this distribution and a shape
        that broadcasts with samples. Should be a whole number. Default `None`.
        If provided, the base distribution's `prob` should be defined at
        `high - 1`. `high` must be strictly greater than `low`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      TypeError: If `distribution` is not an instance of
          `Distribution`, or if it is not continuous.
      NotImplementedError:  If the base distribution does not implement `cdf`.
    """
    parameters = distribution_util.parent_frame_arguments()
    values = (
        list(distribution.parameters.values()) +
        [low, high])
    with ops.name_scope(name, values=values) as name:
      self._dist = distribution

      if low is not None:
        low = ops.convert_to_tensor(low, name="low")
      if high is not None:
        high = ops.convert_to_tensor(high, name="high")
      check_ops.assert_same_float_dtype(
          tensors=[self.distribution, low, high])

      # We let QuantizedDistribution access _graph_parents since this class is
      # more like a base class.
      graph_parents = self._dist._graph_parents  # pylint: disable=protected-access

      checks = []
      if validate_args and low is not None and high is not None:
        message = "low must be strictly less than high."
        checks.append(
            check_ops.assert_less(
                low, high, message=message))
      self._validate_args = validate_args  # self._check_integer uses this.
      with ops.control_dependencies(checks if validate_args else []):
        if low is not None:
          self._low = self._check_integer(low)
          graph_parents += [self._low]
        else:
          self._low = None
        if high is not None:
          self._high = self._check_integer(high)
          graph_parents += [self._high]
        else:
          self._high = None

    super(QuantizedDistribution, self).__init__(
        dtype=self._dist.dtype,
        reparameterization_type=distributions.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=self._dist.allow_nan_stats,
        parameters=parameters,
        graph_parents=graph_parents,
        name=name)
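
A minimal usage sketch, assuming TF 1.x `tf.contrib.distributions.QuantizedDistribution` wrapping a continuous base distribution; the cutoffs below are illustrative only:

import tensorflow as tf

tfd = tf.contrib.distributions

# Hypothetical example: Y = ceiling(X), X ~ Normal(0, 3), clipped to [-4, 4].
qdist = tfd.QuantizedDistribution(
    distribution=tfd.Normal(loc=0.0, scale=3.0),
    low=-4.0,
    high=4.0)
pmf = qdist.prob([-1.0, 0.0, 1.0])  # P(Y == y) at whole-number points.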