def testGetLogitsAndProbsImproperArguments(self):
    """Rejects calls that pass neither, or both, of `logits` and `probs`."""
    invalid_kwargs = (
        {"logits": None, "probs": None},    # neither argument supplied
        {"logits": [0.1], "probs": [0.1]},  # both arguments supplied
    )
    with self.test_session():
        for kwargs in invalid_kwargs:
            with self.assertRaises(ValueError):
                distribution_util.get_logits_and_probs(**kwargs)
  def testGetLogitsAndProbsProbabilityValidateArgs(self):
    """`validate_args` decides whether out-of-range `probs` raise on eval."""
    valid_probs = [0.01, 0.2, 0.5, 0.7, .99]
    negative_probs = [-1, 0.2, 0.5, 0.3, .2]   # first component is below 0
    too_large_probs = [2, 0.2, 0.5, 0.3, .2]   # first component is above 1

    def _build_and_eval(probs, validate):
      # Build the probs tensor and force evaluation so any runtime
      # assertion ops attached by validation actually execute.
      _, prob = distribution_util.get_logits_and_probs(
          probs=probs, validate_args=validate)
      prob.eval()

    with self.test_session():
      _build_and_eval(valid_probs, True)

      with self.assertRaisesOpError("Condition x >= 0"):
        _build_and_eval(negative_probs, True)
      _build_and_eval(negative_probs, False)

      with self.assertRaisesOpError("probs has components greater than 1"):
        _build_and_eval(too_large_probs, True)
      _build_and_eval(too_large_probs, False)
    def testGetLogitsAndProbsProbabilityValidateArgs(self):
        """Out-of-range `probs` only fail at eval time when validated."""
        in_range_probs = [0.01, 0.2, 0.5, 0.7, .99]
        # Each entry: (probs with one bad component, expected op error text).
        out_of_range_cases = (
            ([-1, 0.2, 0.5, 0.3, .2], "Condition x >= 0"),
            ([2, 0.2, 0.5, 0.3, .2], "probs has components greater than 1"),
        )

        with self.test_session():
            _, prob = distribution_util.get_logits_and_probs(
                probs=in_range_probs, validate_args=True)
            prob.eval()

            for bad_probs, error_text in out_of_range_cases:
                # With validation enabled, evaluating must raise ...
                with self.assertRaisesOpError(error_text):
                    _, prob = distribution_util.get_logits_and_probs(
                        probs=bad_probs, validate_args=True)
                    prob.eval()
                # ... while without it the same input evaluates silently.
                _, prob = distribution_util.get_logits_and_probs(
                    probs=bad_probs, validate_args=False)
                prob.eval()
    def testGetLogitsAndProbsImproperArguments(self):
        """get_logits_and_probs requires exactly one of `logits`/`probs`."""
        with self.test_session():
            neither_given = {"logits": None, "probs": None}
            with self.assertRaises(ValueError):
                distribution_util.get_logits_and_probs(**neither_given)

            both_given = {"logits": [0.1], "probs": [0.1]}
            with self.assertRaises(ValueError):
                distribution_util.get_logits_and_probs(**both_given)
  def __init__(self,
               temperature,
               logits=None,
               probs=None,
               validate_args=False,
               allow_nan_stats=True,
               name="RelaxedBernoulli"):
    """Construct RelaxedBernoulli distributions.

    The distribution is built as a Logistic distribution pushed through an
    elementwise sigmoid bijector.

    Args:
      temperature: An 0-D `Tensor`, representing the temperature
        of a set of RelaxedBernoulli distributions. The temperature should be
        positive; this is only asserted at run time when `validate_args=True`.
      logits: An N-D `Tensor` representing the log-odds
        of a positive event. Each entry in the `Tensor` parametrizes
        an independent RelaxedBernoulli distribution where the probability of an
        event is sigmoid(logits). Only one of `logits` or `probs` should be
        passed in.
      probs: An N-D `Tensor` representing the probability of a positive event.
        Each entry in the `Tensor` parameterizes an independent Bernoulli
        distribution. Only one of `logits` or `probs` should be passed in.
      validate_args: Python `Boolean`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `Boolean`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined.  When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: `String` name prefixed to Ops created by this class.

    Raises:
      ValueError: If both `probs` and `logits` are passed, or if neither.
    """
    parameters = locals()
    with ops.name_scope(name, values=[logits, probs, temperature]) as ns:
      # The positivity assertion is only created when validation is requested,
      # so the default path adds no runtime check ops.
      with ops.control_dependencies([check_ops.assert_positive(temperature)]
                                    if validate_args else []):
        self._temperature = array_ops.identity(temperature, name="temperature")

      self._logits, self._probs = distribution_util.get_logits_and_probs(
          logits=logits, probs=probs, validate_args=validate_args)
      # Base distribution; the positional arguments are presumably
      # (loc, scale) of the Logistic -- confirm against logistic._Logistic.
      dist = logistic._Logistic(self._logits / self._temperature,
                                1. / self._temperature,
                                validate_args=validate_args,
                                allow_nan_stats=allow_nan_stats,
                                name=ns)
      self._parameters = parameters

    def inverse_log_det_jacobian_fn(y):
      # The inverse of sigmoid is logit(y) = log(y) - log(1 - y), applied
      # independently to every entry, so its Jacobian is diagonal with
      # d/dy logit(y) = 1 / (y * (1 - y)). The log-determinant therefore stays
      # elementwise; reducing over the last axis would merge a batch
      # dimension and corrupt log_prob for every batch member.
      return -math_ops.log(y) - math_ops.log1p(-y)

    sigmoid_bijector = bijector.Inline(
        forward_fn=math_ops.sigmoid,
        inverse_fn=(lambda y: math_ops.log(y) - math_ops.log1p(-y)),
        inverse_log_det_jacobian_fn=inverse_log_det_jacobian_fn,
        name="sigmoid")
    super(_RelaxedBernoulli, self).__init__(dist,
                                            sigmoid_bijector,
                                            name=name)
Beispiel #6
0
  def __init__(
      self,
      logits=None,
      probs=None,
      dtype=dtypes.int32,
      validate_args=False,
      allow_nan_stats=True,
      name="OneHotCategorical"):
    """Initialize OneHotCategorical distributions from class log-probabilities.

    Args:
      logits: An N-D `Tensor` (`N >= 1`) of log probabilities for a set of
        Categorical distributions. The leading `N - 1` dimensions index a
        batch of independent distributions; the last dimension holds one
        logit per class. Supply only one of `logits` or `probs`.
      probs: An N-D `Tensor` (`N >= 1`) of probabilities for a set of
        Categorical distributions. The leading `N - 1` dimensions index a
        batch of independent distributions; the last dimension holds one
        probability per class. Supply only one of `logits` or `probs`.
      dtype: The type of the event samples (default: int32).
      validate_args: Python `bool`, default `False`. If `True`, parameters
        are checked for validity at a possible runtime cost. If `False`,
        invalid inputs may silently produce incorrect outputs.
      allow_nan_stats: Python `bool`, default `True`. If `True`, statistics
        (e.g., mean, mode, variance) report "`NaN`" when undefined; if
        `False`, an exception is raised when any batch member's statistic is
        undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    parameters = locals()
    with ops.name_scope(name, values=[logits, probs]) as ns:
      self._logits, self._probs = distribution_util.get_logits_and_probs(
          name=name, logits=logits, probs=probs, validate_args=validate_args,
          multidimensional=True)

      # `with_rank_at_least(1)` raises if logits are known to be scalar.
      static_rank = self._logits.get_shape().with_rank_at_least(1).ndims
      if static_rank is None:
        # Rank is unknown while building the graph: derive it at run time.
        with ops.name_scope(name="batch_rank"):
          self._batch_rank = array_ops.rank(self._logits) - 1
      else:
        # Rank is known statically: the batch rank is a constant tensor.
        self._batch_rank = ops.convert_to_tensor(
            static_rank - 1,
            dtype=dtypes.int32,
            name="batch_rank")

      with ops.name_scope(name="event_size"):
        # The trailing axis of `logits` enumerates the classes.
        self._event_size = array_ops.shape(self._logits)[-1]

    super(OneHotCategorical, self).__init__(
        dtype=dtype,
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[self._logits, self._probs],
        name=ns)
  def __init__(self,
               temperature,
               logits=None,
               probs=None,
               validate_args=False,
               allow_nan_stats=True,
               name="RelaxedBernoulli"):
    """Construct RelaxedBernoulli distributions.

    Args:
      temperature: A 0-D `Tensor` giving the temperature of the set of
        RelaxedBernoulli distributions. It should be positive; this is only
        asserted at run time when `validate_args=True`.
      logits: An N-D `Tensor` of log-odds of a positive event. Each entry
        parametrizes an independent RelaxedBernoulli distribution whose event
        probability is sigmoid(logits). Supply only one of `logits` or
        `probs`.
      probs: An N-D `Tensor` of positive-event probabilities. Each entry
        parameterizes an independent distribution. Supply only one of
        `logits` or `probs`.
      validate_args: Python `Boolean`, default `False`. If `True`, parameters
        are checked for validity at a possible runtime cost. If `False`,
        invalid inputs may silently produce incorrect outputs.
      allow_nan_stats: Python `Boolean`, default `True`. If `True`, statistics
        (e.g., mean, mode, variance) report "`NaN`" when undefined; if
        `False`, an exception is raised when any batch member's statistic is
        undefined.
      name: `String` name prefixed to Ops created by this class.

    Raises:
      ValueError: If both `probs` and `logits` are passed, or if neither.
    """
    parameters = locals()
    with ops.name_scope(name, values=[logits, probs, temperature]) as ns:
      temperature_checks = (
          [check_ops.assert_positive(temperature)] if validate_args else [])
      with ops.control_dependencies(temperature_checks):
        self._temperature = array_ops.identity(temperature, name="temperature")

      self._logits, self._probs = distribution_util.get_logits_and_probs(
          logits=logits, probs=probs, validate_args=validate_args)
      # Built in the same order as before: scaled logits first, then the
      # reciprocal temperature.
      scaled_logits = self._logits / self._temperature
      inverse_temperature = 1. / self._temperature
      dist = logistic.Logistic(scaled_logits,
                               inverse_temperature,
                               validate_args=validate_args,
                               allow_nan_stats=allow_nan_stats,
                               name=ns)
      self._parameters = parameters

    def _logit(y):
      # Inverse of the sigmoid: log(y) - log(1 - y).
      return math_ops.log(y) - math_ops.log1p(-y)

    def _inverse_log_det_jacobian(y):
      # Elementwise log of d/dy logit(y) = 1 / (y * (1 - y)).
      return -math_ops.log(y) - math_ops.log1p(-y)

    sigmoid_bijector = bijector.Inline(
        forward_fn=math_ops.sigmoid,
        inverse_fn=_logit,
        inverse_log_det_jacobian_fn=_inverse_log_det_jacobian,
        name="sigmoid")
    super(RelaxedBernoulli, self).__init__(dist, sigmoid_bijector, name=name)
    def testGetLogitsAndProbsProbability(self):
        """1-D `probs` come back unchanged alongside logit(probs)."""
        probs = np.array([0.01, 0.2, 0.5, 0.7, .99], dtype=np.float32)
        expected_logits = special.logit(probs)

        with self.test_session():
            logits_t, probs_t = distribution_util.get_logits_and_probs(
                probs=probs, validate_args=True)
            self.assertAllClose(expected_logits, logits_t.eval())
            self.assertAllClose(probs, probs_t.eval())
  def __init__(self,
               total_count,
               logits=None,
               probs=None,
               validate_args=False,
               allow_nan_stats=True,
               name="Multinomial"):
    """Initialize a batch of Multinomial distributions.

    Args:
      total_count: Non-negative floating-point tensor broadcastable to
        `[N1,..., Nm]` (`m >= 0`), defining a batch of `N1 x ... x Nm`
        Multinomial distributions. Its components should be integer-valued.
      logits: Floating-point tensor of log-odds of a positive event,
        broadcastable to `[N1,..., Nm, k]` (`m >= 0`) with the same dtype as
        `total_count`; defines a batch of `k`-class Multinomial distributions.
        Supply only one of `logits` or `probs`.
      probs: Positive floating-point tensor broadcastable to `[N1,..., Nm, k]`
        (`m >= 0`) with the same dtype as `total_count`; defines a batch of
        `k`-class Multinomial distributions. Entries along the last axis
        should sum to `1`. Supply only one of `logits` or `probs`.
      validate_args: Python `bool`, default `False`. If `True`, parameters
        are checked for validity at a possible runtime cost. If `False`,
        invalid inputs may silently produce incorrect outputs.
      allow_nan_stats: Python `bool`, default `True`. If `True`, statistics
        (e.g., mean, mode, variance) report "`NaN`" when undefined; if
        `False`, an exception is raised when any batch member's statistic is
        undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    parameters = locals()
    with ops.name_scope(name, values=[total_count, logits, probs]) as ns:
      count_tensor = ops.convert_to_tensor(total_count, name="total_count")
      self._total_count = self._maybe_assert_valid_total_count(
          count_tensor, validate_args)
      self._logits, self._probs = distribution_util.get_logits_and_probs(
          logits=logits, probs=probs, multidimensional=True,
          validate_args=validate_args, name=name)
      # Give the counts a trailing class axis so they broadcast against
      # `probs`; the product is the per-class mean total_count * probs.
      self._mean_val = self._total_count[..., array_ops.newaxis] * self._probs
    super(Multinomial, self).__init__(
        dtype=self._probs.dtype,
        is_continuous=False,
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[self._total_count, self._logits, self._probs],
        name=ns)
Beispiel #10
0
  def __init__(self,
               total_count,
               logits=None,
               probs=None,
               validate_args=False,
               allow_nan_stats=True,
               name="Multinomial"):
    """Initialize a batch of Multinomial distributions.

    Args:
      total_count: Non-negative floating-point tensor whose shape broadcasts
        to `[N1,..., Nm]` with `m >= 0`; it defines `N1 x ... x Nm`
        Multinomial distributions and should hold integer values.
      logits: Floating-point log-odds tensor whose shape broadcasts to
        `[N1,..., Nm, k]` (`m >= 0`) and whose dtype matches `total_count`.
        Only one of `logits` or `probs` may be given.
      probs: Positive floating-point tensor whose shape broadcasts to
        `[N1,..., Nm, k]` (`m >= 0`) and whose dtype matches `total_count`.
        Values along the last axis should sum to `1`. Only one of `logits`
        or `probs` may be given.
      validate_args: Python `Boolean`, default `False`. If `True`, parameters
        are checked for validity at a possible runtime cost. If `False`,
        invalid inputs may silently produce incorrect outputs.
      allow_nan_stats: Python `Boolean`, default `True`. If `True`, statistics
        (e.g., mean, mode, variance) report "`NaN`" when undefined; if
        `False`, an exception is raised when any batch member's statistic is
        undefined.
      name: `String` name prefixed to Ops created by this class.
    """
    parameters = locals()
    with ops.name_scope(name, values=[total_count, logits, probs]) as ns:
      self._total_count = self._maybe_assert_valid_total_count(
          ops.convert_to_tensor(total_count, name="total_count"),
          validate_args)
      self._logits, self._probs = distribution_util.get_logits_and_probs(
          name=name,
          logits=logits,
          probs=probs,
          validate_args=validate_args,
          multidimensional=True)
      # `[..., None]` appends a class axis to the counts for broadcasting.
      self._mean_val = self._total_count[..., None] * self._probs

    parents = [self._total_count, self._logits, self._probs]
    super(Multinomial, self).__init__(
        name=ns,
        dtype=self._probs.dtype,
        is_continuous=False,
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=parents)
    def __init__(self,
                 total_count,
                 logits=None,
                 probs=None,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="NegativeBinomial"):
        """Construct NegativeBinomial distributions.

        Args:
          total_count: Floating-point `Tensor` broadcastable to `[B1,..., Bb]`
            (`b >= 0`) with the same dtype as `probs` or `logits`; defines a
            batch of Negative Binomial distributions. It represents the
            number of negative Bernoulli trials (failures) to stop at, though
            non-integer values still define a valid distribution. Note: with
            `validate_args=True` a strict positivity assertion is attached,
            so a zero count is rejected at run time.
          logits: Floating-point `Tensor` broadcastable to `[B1, ..., Bb]`
            (`b` is the number of batch dimensions). Each entry is the logit
            of the success probability of an independent Negative Binomial
            distribution, in `(-inf, inf)`. Specify only one of `logits` or
            `probs`.
          probs: Positive floating-point `Tensor` broadcastable to
            `[B1, ..., Bb]`. Each entry is the success probability of an
            independent Negative Binomial distribution, in the open interval
            `(0, 1)`. Specify only one of `logits` or `probs`.
          validate_args: Python `bool`, default `False`. If `True`,
            parameters are checked for validity at a possible runtime cost.
            If `False`, invalid inputs may silently produce incorrect outputs.
          allow_nan_stats: Python `bool`, default `True`. If `True`,
            statistics (e.g., mean, mode, variance) report "`NaN`" when
            undefined; if `False`, an exception is raised when any batch
            member's statistic is undefined.
          name: Python `str` name prefixed to Ops created by this class.
        """

        parameters = locals()
        with ops.name_scope(name, values=[total_count, logits, probs]) as ns:
            self._logits, self._probs = distribution_util.get_logits_and_probs(
                logits, probs, validate_args=validate_args, name=name)

            # Gather the optional run-time check on `total_count`; with
            # validation off, no assertion op is created at all.
            count_checks = []
            if validate_args:
                count_checks.append(check_ops.assert_positive(total_count))
            with ops.control_dependencies(count_checks):
                self._total_count = array_ops.identity(total_count)

        super(NegativeBinomial, self).__init__(
            dtype=self._probs.dtype,
            is_continuous=False,
            reparameterization_type=distribution.NOT_REPARAMETERIZED,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            parameters=parameters,
            graph_parents=[self._total_count, self._probs, self._logits],
            name=ns)
  def testGetLogitsAndProbsProbabilityMultidimensional(self):
    """Multidimensional `probs` yield `log(probs)` as logits."""
    probs = np.array([[0.3, 0.4, 0.3], [0.1, 0.5, 0.4]], dtype=np.float32)

    with self.test_session():
      logits_t, probs_t = distribution_util.get_logits_and_probs(
          probs=probs, multidimensional=True, validate_args=True)

      self.assertAllClose(np.log(probs), logits_t.eval())
      self.assertAllClose(probs, probs_t.eval())
    def testGetLogitsAndProbsProbabilityMultidimensional(self):
        """Row-wise `probs` round-trip, with logits equal to `log(probs)`."""
        probs = np.array([[0.3, 0.4, 0.3], [0.1, 0.5, 0.4]], dtype=np.float32)
        expected_logits = np.log(probs)

        with self.test_session():
            result = distribution_util.get_logits_and_probs(
                probs=probs, multidimensional=True, validate_args=True)
            out_logits, out_probs = result

            self.assertAllClose(expected_logits, out_logits.eval())
            self.assertAllClose(probs, out_probs.eval())
  def testGetLogitsAndProbsProbability(self):
    """Scalar-event `probs` produce logit(probs) and are returned as-is."""
    probs = np.array([0.01, 0.2, 0.5, 0.7, .99], dtype=np.float32)

    with self.test_session():
      result = distribution_util.get_logits_and_probs(
          probs=probs, validate_args=True)
      out_logits, out_probs = result

      self.assertAllClose(special.logit(probs), out_logits.eval())
      self.assertAllClose(probs, out_probs.eval())
  def __init__(self,
               total_count,
               logits=None,
               probs=None,
               validate_args=False,
               allow_nan_stats=True,
               name="NegativeBinomial"):
    """Construct NegativeBinomial distributions.

    Args:
      total_count: Floating-point `Tensor` whose shape broadcasts to
        `[B1,..., Bb]` (`b >= 0`) and whose dtype matches `probs` or
        `logits`. It is the number of failures (negative Bernoulli trials)
        at which to stop; non-integer values still give a valid
        distribution. When `validate_args=True`, a strict positivity
        assertion is attached, so zero is rejected at run time.
      logits: Floating-point `Tensor` whose shape broadcasts to
        `[B1, ..., Bb]`, where `b` counts batch dimensions. Each entry is the
        logit of an independent success probability, in `(-inf, inf)`. Give
        only one of `logits` or `probs`.
      probs: Positive floating-point `Tensor` whose shape broadcasts to
        `[B1, ..., Bb]`. Each entry is an independent success probability in
        the open interval `(0, 1)`. Give only one of `logits` or `probs`.
      validate_args: Python `bool`, default `False`. If `True`, parameters
        are checked for validity at a possible runtime cost. If `False`,
        invalid inputs may silently produce incorrect outputs.
      allow_nan_stats: Python `bool`, default `True`. If `True`, statistics
        (e.g., mean, mode, variance) report "`NaN`" when undefined; if
        `False`, an exception is raised when any batch member's statistic is
        undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """

    parameters = locals()
    with ops.name_scope(name, values=[total_count, logits, probs]) as ns:
      self._logits, self._probs = distribution_util.get_logits_and_probs(
          logits, probs, validate_args=validate_args, name=name)
      positivity_deps = (
          [check_ops.assert_positive(total_count)] if validate_args else [])
      with ops.control_dependencies(positivity_deps):
        self._total_count = array_ops.identity(total_count)

    parents = [self._total_count, self._probs, self._logits]
    super(NegativeBinomial, self).__init__(
        name=ns,
        dtype=self._probs.dtype,
        is_continuous=False,
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=parents)
    def testGetLogitsAndProbsLogitsMultidimensional(self):
        """Vector logits `log(p)` recover `p` as probs and pass through."""
        expected_probs = np.array([0.2, 0.3, 0.5], dtype=np.float32)
        input_logits = np.log(expected_probs)

        with self.test_session():
            logits_t, probs_t = distribution_util.get_logits_and_probs(
                logits=input_logits, multidimensional=True,
                validate_args=True)

            self.assertAllClose(probs_t.eval(), expected_probs)
            self.assertAllClose(logits_t.eval(), input_logits)
  def testGetLogitsAndProbsLogitsMultidimensional(self):
    """Multidimensional logits equal to `log(p)` map back to `p`."""
    target = np.array([0.2, 0.3, 0.5], dtype=np.float32)
    log_target = np.log(target)

    with self.test_session():
      result = distribution_util.get_logits_and_probs(
          logits=log_target, multidimensional=True, validate_args=True)
      out_logits, out_probs = result

      self.assertAllClose(out_probs.eval(), target)
      self.assertAllClose(out_logits.eval(), log_target)
Beispiel #18
0
    def __init__(self,
                 logits=None,
                 probs=None,
                 dtype=dtypes.int32,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="Bernoulli"):
        """Construct Bernoulli distributions.

        Args:
          logits: An N-D `Tensor` of log-odds of a `1` event. Each entry
            parametrizes an independent Bernoulli distribution whose event
            probability is sigmoid(logits). Supply only one of `logits` or
            `probs`.
          probs: An N-D `Tensor` of probabilities of a `1` event. Each entry
            parameterizes an independent Bernoulli distribution. Supply only
            one of `logits` or `probs`.
          dtype: The type of the event samples. Default: `int32`.
          validate_args: Python `Boolean`, default `False`. If `True`,
            parameters are checked for validity at a possible runtime cost.
            If `False`, invalid inputs may silently produce incorrect outputs.
          allow_nan_stats: Python `Boolean`, default `True`. If `True`,
            statistics (e.g., mean, mode, variance) report "`NaN`" when
            undefined; if `False`, an exception is raised when any batch
            member's statistic is undefined.
          name: `String` name prefixed to Ops created by this class.

        Raises:
          ValueError: If p and logits are passed, or if neither are passed.
        """
        parameters = locals()
        # `self` is not a constructor parameter; drop it from the record.
        parameters.pop("self")
        with ops.name_scope(name) as ns:
            self._logits, self._probs = distribution_util.get_logits_and_probs(
                name=name,
                validate_args=validate_args,
                logits=logits,
                probs=probs)
        super(Bernoulli, self).__init__(
            name=ns,
            graph_parents=[self._logits, self._probs],
            parameters=parameters,
            dtype=dtype,
            is_continuous=False,
            reparameterization_type=distribution.NOT_REPARAMETERIZED,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats)
    def __init__(self,
                 logits=None,
                 probs=None,
                 validate_args=True,
                 allow_nan_stats=False,
                 name="Geometric"):
        """Construct Geometric distributions.

        Args:
          logits: Floating-point `Tensor` with shape `[B1, ..., Bb]` where
            `b >= 0` indicates the number of batch dimensions. Each entry
            represents logits for the probability of success for independent
            Geometric distributions and must be in the range `(-inf, inf]`.
            Only one of `logits` or `probs` should be specified.
          probs: Positive floating-point `Tensor` with shape `[B1, ..., Bb]`
            where `b >= 0` indicates the number of batch dimensions. Each
            entry represents the probability of success for independent
            Geometric distributions and must be in the range `(0, 1]`. Only
            one of `logits` or `probs` should be specified.
          validate_args: Python `bool`, default `True` (note: unlike most
            distributions in this package, validation is on by default). When
            `True` distribution parameters are checked for validity despite
            possibly degrading runtime performance; this includes a run-time
            assertion that `probs` is strictly positive. When `False` invalid
            inputs may silently render incorrect outputs.
          allow_nan_stats: Python `bool`, default `False` (note: unlike most
            distributions in this package). When `True`, statistics (e.g.,
            mean, mode, variance) use the value "`NaN`" to indicate the result
            is undefined. When `False`, an exception is raised if one or more
            of the statistic's batch members are undefined.
          name: Python `str` name prefixed to Ops created by this class.
        """

        parameters = locals()
        with ops.name_scope(name, values=[logits, probs]) as ns:
            self._logits, self._probs = distribution_util.get_logits_and_probs(
                logits, probs, validate_args=validate_args, name=name)

            # probs must lie in (0, 1]; the zero end is excluded here by an
            # explicit positivity assertion, created only when validating.
            with ops.control_dependencies(
                [check_ops.assert_positive(self._probs
                                           )] if validate_args else []):
                self._probs = array_ops.identity(self._probs, name="probs")

        super(Geometric, self).__init__(
            dtype=self._probs.dtype,
            is_continuous=False,
            reparameterization_type=distribution.NOT_REPARAMETERIZED,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            parameters=parameters,
            graph_parents=[self._probs, self._logits],
            name=ns)
Beispiel #20
0
  def __init__(self,
               logits=None,
               probs=None,
               dtype=dtypes.int32,
               validate_args=False,
               allow_nan_stats=True,
               name="Bernoulli"):
    """Construct Bernoulli distributions.

    Args:
      logits: An N-D `Tensor` holding the log-odds of a `1` event; each entry
        defines an independent Bernoulli distribution with event probability
        sigmoid(logits). Give only one of `logits` or `probs`.
      probs: An N-D `Tensor` holding the probability of a `1` event; each
        entry defines an independent Bernoulli distribution. Give only one of
        `logits` or `probs`.
      dtype: The type of the event samples. Default: `int32`.
      validate_args: Python `Boolean`, default `False`. If `True`, parameters
        are checked for validity at a possible runtime cost. If `False`,
        invalid inputs may silently produce incorrect outputs.
      allow_nan_stats: Python `Boolean`, default `True`. If `True`, statistics
        (e.g., mean, mode, variance) report "`NaN`" when undefined; if
        `False`, an exception is raised when any batch member's statistic is
        undefined.
      name: `String` name prefixed to Ops created by this class.

    Raises:
      ValueError: If p and logits are passed, or if neither are passed.
    """
    parameters = locals()
    with ops.name_scope(name) as ns:
      logits_and_probs = distribution_util.get_logits_and_probs(
          logits=logits, probs=probs, validate_args=validate_args, name=name)
      self._logits, self._probs = logits_and_probs
    parents = [self._logits, self._probs]
    super(Bernoulli, self).__init__(
        dtype=dtype,
        is_continuous=False,
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=parents,
        name=ns)
Beispiel #21
0
  def __init__(self,
               logits=None,
               probs=None,
               validate_args=True,
               allow_nan_stats=False,
               name="Geometric"):
    """Construct Geometric distributions.

    Args:
      logits: Floating-point `Tensor` with shape `[B1, ..., Bb]` where `b >= 0`
        indicates the number of batch dimensions. Each entry represents logits
        for the probability of success for independent Geometric distributions
        and must be in the range `(-inf, inf]`. Only one of `logits` or `probs`
        should be specified.
      probs: Positive floating-point `Tensor` with shape `[B1, ..., Bb]`
        where `b >= 0` indicates the number of batch dimensions. Each entry
        represents the probability of success for independent Geometric
        distributions and must be in the range `(0, 1]`. Only one of `logits`
        or `probs` should be specified.
      validate_args: Python `bool`, default `True`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `False`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    # Snapshot the constructor arguments. In CPython < 3.13 `locals()` inside a
    # function returns the frame's own locals dict, which is refreshed (and can
    # pick up later locals such as `ns`) whenever locals()/frame.f_locals is
    # read again; copying it decouples `parameters` from the frame.
    parameters = dict(locals())
    with ops.name_scope(name, values=[logits, probs]) as ns:
      self._logits, self._probs = distribution_util.get_logits_and_probs(
          logits, probs, validate_args=validate_args, name=name)

      # `probs` must lie in (0, 1]; when validating, gate `probs` on an extra
      # strict-positivity check by routing it through an identity op that
      # carries the assertion as a control dependency.
      with ops.control_dependencies(
          [check_ops.assert_positive(self._probs)] if validate_args else []):
        self._probs = array_ops.identity(self._probs, name="probs")

    super(Geometric, self).__init__(
        dtype=self._probs.dtype,
        is_continuous=False,
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[self._probs, self._logits],
        name=ns)
    def testGetLogitsAndProbsProbabilityValidateArgsMultidimensional(self):
        """Checks `validate_args` for multidimensional `probs`.

        Every invalid input must raise the expected op error when
        `validate_args=True` and must evaluate cleanly when
        `validate_args=False`. The valid input must pass with validation on.
        """
        # Valid: non-negative and each row sums to 1.
        p = np.array([[0.3, 0.4, 0.3], [0.1, 0.5, 0.4]], dtype=np.float32)
        # Component less than 0. Still sums to 1.
        p2 = np.array([[-.3, 0.4, 0.9], [0.1, 0.5, 0.4]], dtype=np.float32)
        # Component greater than 1. Does not sum to 1 (first row sums to 1.3).
        p3 = np.array([[1.3, 0.0, 0.0], [0.1, 0.5, 0.4]], dtype=np.float32)
        # Does not sum to 1 (first row sums to 1.8). Its 1.1 entry also exceeds
        # 1, but the assertion below expects the sum-check message.
        p4 = np.array([[1.1, 0.3, 0.4], [0.1, 0.5, 0.4]], dtype=np.float32)

        with self.test_session():
            # Valid probabilities must survive the checks when they are enabled.
            _, prob = distribution_util.get_logits_and_probs(
                probs=p, multidimensional=True, validate_args=True)
            prob.eval()

            with self.assertRaisesOpError("Condition x >= 0"):
                _, prob = distribution_util.get_logits_and_probs(
                    probs=p2, multidimensional=True, validate_args=True)
                prob.eval()

            _, prob = distribution_util.get_logits_and_probs(
                probs=p2, multidimensional=True, validate_args=False)
            prob.eval()

            with self.assertRaisesOpError(
                    "(probs has components greater than 1|probs does not sum to 1)"
            ):
                _, prob = distribution_util.get_logits_and_probs(
                    probs=p3, multidimensional=True, validate_args=True)
                prob.eval()

            _, prob = distribution_util.get_logits_and_probs(
                probs=p3, multidimensional=True, validate_args=False)
            prob.eval()

            with self.assertRaisesOpError("probs does not sum to 1"):
                _, prob = distribution_util.get_logits_and_probs(
                    probs=p4, multidimensional=True, validate_args=True)
                prob.eval()

            _, prob = distribution_util.get_logits_and_probs(
                probs=p4, multidimensional=True, validate_args=False)
            prob.eval()
  def testGetLogitsAndProbsProbabilityValidateArgsMultidimensional(self):
    """Checks `validate_args` for multidimensional `probs`.

    Every invalid input must raise the expected op error when
    `validate_args=True` and must evaluate cleanly when `validate_args=False`.
    The valid input must pass with validation on.
    """
    # Valid: non-negative and each row sums to 1.
    p = np.array([[0.3, 0.4, 0.3], [0.1, 0.5, 0.4]], dtype=np.float32)
    # Component less than 0. Still sums to 1.
    p2 = np.array([[-.3, 0.4, 0.9], [0.1, 0.5, 0.4]], dtype=np.float32)
    # Component greater than 1. Does not sum to 1 (first row sums to 1.3).
    p3 = np.array([[1.3, 0.0, 0.0], [0.1, 0.5, 0.4]], dtype=np.float32)
    # Does not sum to 1 (first row sums to 1.8). Its 1.1 entry also exceeds 1,
    # but the assertion below expects the sum-check message.
    p4 = np.array([[1.1, 0.3, 0.4], [0.1, 0.5, 0.4]], dtype=np.float32)

    with self.test_session():
      # Valid probabilities must survive the checks when they are enabled.
      _, prob = distribution_util.get_logits_and_probs(
          probs=p, multidimensional=True, validate_args=True)
      prob.eval()

      with self.assertRaisesOpError("Condition x >= 0"):
        _, prob = distribution_util.get_logits_and_probs(
            probs=p2, multidimensional=True, validate_args=True)
        prob.eval()

      _, prob = distribution_util.get_logits_and_probs(
          probs=p2, multidimensional=True, validate_args=False)
      prob.eval()

      with self.assertRaisesOpError(
          "(probs has components greater than 1|probs does not sum to 1)"):
        _, prob = distribution_util.get_logits_and_probs(
            probs=p3, multidimensional=True, validate_args=True)
        prob.eval()

      _, prob = distribution_util.get_logits_and_probs(
          probs=p3, multidimensional=True, validate_args=False)
      prob.eval()

      with self.assertRaisesOpError("probs does not sum to 1"):
        _, prob = distribution_util.get_logits_and_probs(
            probs=p4, multidimensional=True, validate_args=True)
        prob.eval()

      _, prob = distribution_util.get_logits_and_probs(
          probs=p4, multidimensional=True, validate_args=False)
      prob.eval()
    def __init__(self,
                 logits=None,
                 probs=None,
                 dtype=dtypes.int32,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="OneHotCategorical"):
        """Initialize OneHotCategorical distributions using class log-probabilities.

        Args:
          logits: An N-D `Tensor`, `N >= 1`, representing the log probabilities
            of a set of Categorical distributions. The first `N - 1` dimensions
            index into a batch of independent distributions and the last
            dimension represents a vector of logits for each class. Only one of
            `logits` or `probs` should be passed in.
          probs: An N-D `Tensor`, `N >= 1`, representing the probabilities of a
            set of Categorical distributions. The first `N - 1` dimensions index
            into a batch of independent distributions and the last dimension
            represents a vector of probabilities for each class. Only one of
            `logits` or `probs` should be passed in.
          dtype: The type of the event samples (default: int32).
          validate_args: Python `bool`, default `False`. When `True` distribution
            parameters are checked for validity despite possibly degrading
            runtime performance. When `False` invalid inputs may silently render
            incorrect outputs.
          allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
            (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
            result is undefined. When `False`, an exception is raised if one or
            more of the statistic's batch members are undefined.
          name: Python `str` name prefixed to Ops created by this class.
        """
        # Snapshot the constructor arguments. In CPython < 3.13 `locals()`
        # inside a function returns the frame's own locals dict, which is
        # refreshed (and can pick up later locals such as `ns`) whenever
        # locals()/frame.f_locals is read again; copying it decouples
        # `parameters` from the frame.
        parameters = dict(locals())
        with ops.name_scope(name, values=[logits, probs]) as ns:
            # `multidimensional=True`: the last axis holds the class vector.
            self._logits, self._probs = distribution_util.get_logits_and_probs(
                name=name,
                logits=logits,
                probs=probs,
                validate_args=validate_args,
                multidimensional=True)

            # Everything but the last (class) axis is batch. Use the static
            # rank when it is known at graph-construction time; otherwise fall
            # back to a dynamic rank op.
            logits_shape_static = self._logits.get_shape().with_rank_at_least(
                1)
            if logits_shape_static.ndims is not None:
                self._batch_rank = ops.convert_to_tensor(
                    logits_shape_static.ndims - 1,
                    dtype=dtypes.int32,
                    name="batch_rank")
            else:
                with ops.name_scope(name="batch_rank"):
                    self._batch_rank = array_ops.rank(self._logits) - 1

            # Number of classes: the (dynamic) size of the last logits axis.
            with ops.name_scope(name="event_size"):
                self._event_size = array_ops.shape(self._logits)[-1]

        super(OneHotCategorical, self).__init__(
            dtype=dtype,
            reparameterization_type=distribution.NOT_REPARAMETERIZED,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            parameters=parameters,
            graph_parents=[self._logits, self._probs],
            name=ns)