Ejemplo n.º 1
0
  def testGetLogitsAndProbProbabilityValidateArgs(self):
    # Out-of-range `p` must raise only when `validate_args=True`.
    in_range = [0.01, 0.2, 0.5, 0.7, .99]
    below_zero = [-1, 0.2, 0.5, 0.3, .2]  # First component is negative.
    above_one = [2, 0.2, 0.5, 0.3, .2]  # First component exceeds 1.

    def evaluate_prob(probs, validate):
      # Build the (logits, prob) pair and evaluate `prob` so that any op
      # error attached to it surfaces here.
      _, prob_tensor = distribution_util.get_logits_and_prob(
          p=probs, validate_args=validate)
      prob_tensor.eval()

    with self.test_session():
      evaluate_prob(in_range, True)

      with self.assertRaisesOpError("Condition x >= 0"):
        evaluate_prob(below_zero, True)
      evaluate_prob(below_zero, False)

      with self.assertRaisesOpError("p has components greater than 1"):
        evaluate_prob(above_one, True)
      evaluate_prob(above_one, False)
    def testGetLogitsAndProbProbabilityValidateArgs(self):
        good = [0.01, 0.2, 0.5, 0.7, .99]
        # Out-of-range inputs paired with the op error each one should
        # trigger when validation is enabled.
        bad_cases = [
            ([-1, 0.2, 0.5, 0.3, .2], "Condition x >= 0"),  # component < 0
            ([2, 0.2, 0.5, 0.3, .2],
             "p has components greater than 1"),  # component > 1
        ]

        with self.test_session():
            distribution_util.get_logits_and_prob(
                p=good, validate_args=True)[1].eval()

            for bad, expected_msg in bad_cases:
                with self.assertRaisesOpError(expected_msg):
                    distribution_util.get_logits_and_prob(
                        p=bad, validate_args=True)[1].eval()
                # Without validation the same input evaluates without error.
                distribution_util.get_logits_and_prob(
                    p=bad, validate_args=False)[1].eval()
Ejemplo n.º 3
0
  def testGetLogitsAndProbImproperArguments(self):
    # Supplying neither, or both, of `logits` / `p` must raise ValueError.
    invalid_kwargs = (
        dict(logits=None, p=None),
        dict(logits=[0.1], p=[0.1]),
    )
    with self.test_session():
      for kwargs in invalid_kwargs:
        with self.assertRaises(ValueError):
          distribution_util.get_logits_and_prob(**kwargs)
    def testGetLogitsAndProbImproperArguments(self):
        with self.test_session():
            # Neither argument supplied.
            self.assertRaises(ValueError,
                              distribution_util.get_logits_and_prob,
                              logits=None, p=None)
            # Both arguments supplied.
            self.assertRaises(ValueError,
                              distribution_util.get_logits_and_prob,
                              logits=[0.1], p=[0.1])
Ejemplo n.º 5
0
  def __init__(self,
               n,
               logits=None,
               p=None,
               validate_args=False,
               allow_nan_stats=True,
               name="Binomial"):
    """Create a batch of Binomial distributions.

    Args:
      n: Non-negative floating point tensor with shape broadcastable to
        `[N1,..., Nm]` with `m >= 0` and the same dtype as `p` or `logits`.
        Describes a batch of `N1 x ... x Nm` Binomial distributions; every
        entry should hold an integer value.
      logits: Floating point tensor of log-odds of a positive event, with
        shape broadcastable to `[N1,..., Nm]`, `m >= 0`, and the same dtype
        as `n`. Each entry gives the success logit of one independent
        Binomial distribution.
      p: Floating point tensor of success probabilities, `p in [0, 1]`, with
        shape broadcastable to `[N1,..., Nm]`, `m >= 0`. Each entry gives the
        success probability of one independent Binomial distribution.
      validate_args: `Boolean`, default `False`. Whether to assert valid values
        for parameters `n`, `p`, and `x` in `prob` and `log_prob`.
        If `False` and inputs are invalid, correct behavior is not guaranteed.
      allow_nan_stats: `Boolean`, default `True`. If `False`, raise an
        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
        batch member. If `True`, batch members with valid parameters leading to
        undefined statistics will return NaN for this statistic.
      name: The name to prefix Ops created by this distribution class.

    Examples:

    ```python
    # Define 1-batch of a binomial distribution.
    dist = Binomial(n=2., p=.9)

    # Define a 2-batch.
    dist = Binomial(n=[4., 5], p=[.1, .3])
    ```

    """
    self._logits, self._p = distribution_util.get_logits_and_prob(
        name=name, logits=logits, p=p, validate_args=validate_args)
    with ops.name_scope(name, values=[n]) as ns:
      # Checks on `n` are built (inside this name scope) only when
      # validation is requested.
      n_checks = []
      if validate_args:
        n_checks = [
            check_ops.assert_non_negative(
                n, message="n has negative components."),
            distribution_util.assert_integer_form(
                n, message="n has non-integer components."),
        ]
      with ops.control_dependencies(n_checks):
        # The identity op is created under the control dependencies above,
        # so reading `self._n` runs the checks first.
        self._n = array_ops.identity(n, name="n")
        super(Binomial, self).__init__(
            name=ns,
            dtype=self._p.dtype,
            parameters={"n": self._n, "p": self._p, "logits": self._logits},
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            is_continuous=False,
            is_reparameterized=False)
Ejemplo n.º 6
0
  def testGetLogitsAndProbProbabilityMultidimensional(self):
    probs = np.array([[0.3, 0.4, 0.3], [0.1, 0.5, 0.4]], dtype=np.float32)
    # NOTE(review): a same-named variant of this test expects np.log(p) in
    # the multidimensional case; which expectation is right depends on the
    # get_logits_and_prob version in use -- confirm.
    expected_logits = special.logit(probs)

    with self.test_session():
      logits_t, probs_t = distribution_util.get_logits_and_prob(
          p=probs, multidimensional=True, validate_args=True)

      self.assertAllClose(expected_logits, logits_t.eval())
      self.assertAllClose(probs, probs_t.eval())
Ejemplo n.º 7
0
  def testGetLogitsAndProbProbability(self):
    probs = np.array([0.01, 0.2, 0.5, 0.7, .99], dtype=np.float32)

    with self.test_session():
      logits_t, probs_t = distribution_util.get_logits_and_prob(
          p=probs, validate_args=True)
      # Logits should be the log-odds of p; p itself passes through.
      for expected, actual in ((special.logit(probs), logits_t),
                               (probs, probs_t)):
        self.assertAllClose(expected, actual.eval())
    def testGetLogitsAndProbProbabilityMultidimensional(self):
        probs = np.array([[0.3, 0.4, 0.3], [0.1, 0.5, 0.4]],
                         dtype=np.float32)

        with self.test_session():
            result = distribution_util.get_logits_and_prob(
                p=probs, multidimensional=True, validate_args=True)
            logits_t, probs_t = result

            # In this variant multidimensional logits are expected to be log(p).
            self.assertAllClose(np.log(probs), logits_t.eval())
            self.assertAllClose(probs, probs_t.eval())
    def testGetLogitsAndProbProbability(self):
        probs = np.array([0.01, 0.2, 0.5, 0.7, .99], dtype=np.float32)
        expected_logits = special.logit(probs)  # log(p / (1 - p))

        with self.test_session():
            logits_t, probs_t = distribution_util.get_logits_and_prob(
                p=probs, validate_args=True)
            actual_logits = logits_t.eval()
            self.assertAllClose(expected_logits, actual_logits)
            actual_probs = probs_t.eval()
            self.assertAllClose(probs, actual_probs)
    def testGetLogitsAndProbLogitsMultidimensional(self):
        probs = np.array([0.2, 0.3, 0.5], dtype=np.float32)
        logits = np.log(probs)

        with self.test_session():
            logits_t, probs_t = distribution_util.get_logits_and_prob(
                logits=logits, multidimensional=True, validate_args=True)

            # p should be recovered from log(p); logits are returned as given.
            for actual, expected in ((probs_t, probs), (logits_t, logits)):
                self.assertAllClose(actual.eval(), expected)
  def testGetLogitsAndProbLogitsMultidimensional(self):
    expected_p = np.array([0.2, 0.3, 0.5], dtype=np.float32)
    input_logits = np.log(expected_p)

    with self.test_session():
      out_logits, out_p = distribution_util.get_logits_and_prob(
          logits=input_logits, multidimensional=True, validate_args=True)

      # Feeding in log(p) must give back p, and the logits unchanged.
      self.assertAllClose(out_p.eval(), expected_p)
      self.assertAllClose(out_logits.eval(), input_logits)
Ejemplo n.º 12
0
    def __init__(self,
                 logits=None,
                 p=None,
                 dtype=dtypes.int32,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="Bernoulli"):
        """Construct Bernoulli distributions.

        Args:
          logits: An N-D `Tensor` of log-odds of a positive event. Each entry
            parametrizes an independent Bernoulli distribution whose event
            probability is sigmoid(logits). Only one of `logits` or `p` should
            be passed in.
          p: An N-D `Tensor` of probabilities of a positive event. Each entry
            parameterizes an independent Bernoulli distribution. Only one of
            `logits` or `p` should be passed in.
          dtype: dtype for samples.
          validate_args: `Boolean`, default `False`. Whether to validate that
            `0 <= p <= 1`. If `validate_args` is `False`, and the inputs are
            invalid, methods like `log_pmf` may return `NaN` values.
          allow_nan_stats: `Boolean`, default `True`. If `False`, raise an
            exception if a statistic (e.g. mean/mode/etc...) is undefined for
            any batch member. If `True`, batch members with valid parameters
            leading to undefined statistics will return NaN for this statistic.
          name: A name for this distribution.

        Raises:
          ValueError: If p and logits are passed, or if neither are passed.
        """
        # Snapshot the constructor arguments (minus `self`) before any other
        # local is bound, so only caller-supplied arguments are recorded.
        parameters = locals()
        del parameters["self"]
        with ops.name_scope(name) as ns:
            self._logits, self._p = distribution_util.get_logits_and_prob(
                logits=logits, p=p, validate_args=validate_args)
            # q is the probability of a negative event.
            with ops.name_scope("q"):
                self._q = 1. - self._p
        super(Bernoulli, self).__init__(
            dtype=dtype,
            parameters=parameters,
            graph_parents=[self._p, self._q, self._logits],
            is_continuous=False,
            is_reparameterized=False,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            name=ns)
Ejemplo n.º 13
0
  def __init__(self,
               logits=None,
               p=None,
               dtype=dtypes.int32,
               validate_args=False,
               allow_nan_stats=True,
               name="Bernoulli"):
    """Construct Bernoulli distributions.

    Both `logits` and `p` are stored on the instance; whichever one is not
    supplied is obtained from `distribution_util.get_logits_and_prob`.

    Args:
      logits: An N-D `Tensor` representing the log-odds of a positive event.
        Each entry parametrizes an independent Bernoulli distribution where
        the probability of an event is sigmoid(logits). Only one of `logits`
        or `p` should be passed in.
      p: An N-D `Tensor` representing the probability of a positive event.
        Each entry parameterizes an independent Bernoulli distribution. Only
        one of `logits` or `p` should be passed in.
      dtype: dtype for samples.
      validate_args: `Boolean`, default `False`. Whether to validate that
        `0 <= p <= 1`. If `validate_args` is `False`, and the inputs are
        invalid, methods like `log_pmf` may return `NaN` values.
      allow_nan_stats: `Boolean`, default `True`. If `False`, raise an
        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
        batch member. If `True`, batch members with valid parameters leading to
        undefined statistics will return NaN for this statistic.
      name: A name for this distribution.

    Raises:
      ValueError: If p and logits are passed, or if neither are passed.
    """
    # Record the constructor arguments (excluding `self`) for the base class.
    parameters = locals()
    del parameters["self"]
    with ops.name_scope(name) as ns:
      self._logits, self._p = distribution_util.get_logits_and_prob(
          logits=logits, p=p, validate_args=validate_args)
      with ops.name_scope("q"):
        self._q = 1. - self._p  # Probability of a negative event.
    super(Bernoulli, self).__init__(
        name=ns,
        parameters=parameters,
        graph_parents=[self._p, self._q, self._logits],
        dtype=dtype,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        is_continuous=False,
        is_reparameterized=False)
Ejemplo n.º 14
0
    def __init__(self,
                 logits=None,
                 p=None,
                 dtype=dtypes.int32,
                 validate_args=True,
                 allow_nan_stats=False,
                 name="Bernoulli"):
        """Construct Bernoulli distributions.

        Args:
          logits: An N-D `Tensor` representing the log-odds
            of a positive event. Each entry in the `Tensor` parametrizes
            an independent Bernoulli distribution where the probability of an
            event is sigmoid(logits).
          p: An N-D `Tensor` representing the probability of a positive
            event. Each entry in the `Tensor` parameterizes an independent
            Bernoulli distribution.
          dtype: dtype for samples.
          validate_args: Whether to assert that `0 <= p <= 1`. If not
            validate_args, `log_pmf` may return nans.
          allow_nan_stats: Boolean, default `False`. If `False`, raise an
            exception if a statistic (e.g. mean/mode/etc...) is undefined for
            any batch member. If `True`, batch members with valid parameters
            leading to undefined statistics will return NaN for this statistic.
          name: A name for this distribution.

        Raises:
          ValueError: If p and logits are passed, or if neither are passed.
        """
        self._logits, self._p = distribution_util.get_logits_and_prob(
            name=name, logits=logits, p=p, validate_args=validate_args)
        with ops.name_scope(name):
            # q is the probability of a negative event.
            with ops.name_scope("q"):
                self._q = 1. - self._p
        # The base-class constructor is not part of computing `q`; calling it
        # here keeps anything it creates out of the "<name>/q" name scope.
        super(Bernoulli, self).__init__(
            dtype=dtype,
            parameters={"p": self._p, "q": self._q, "logits": self._logits},
            is_continuous=False,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            name=name)
Ejemplo n.º 15
0
  def __init__(self,
               logits=None,
               p=None,
               dtype=dtypes.int32,
               validate_args=True,
               allow_nan_stats=False,
               name="Bernoulli"):
    """Construct Bernoulli distributions.

    Args:
      logits: An N-D `Tensor` representing the log-odds
        of a positive event. Each entry in the `Tensor` parametrizes
        an independent Bernoulli distribution where the probability of an event
        is sigmoid(logits).
      p: An N-D `Tensor` representing the probability of a positive
          event. Each entry in the `Tensor` parameterizes an independent
          Bernoulli distribution.
      dtype: dtype for samples.
      validate_args: Whether to assert that `0 <= p <= 1`. If not validate_args,
       `log_pmf` may return nans.
      allow_nan_stats:  Boolean, default `False`.  If `False`, raise an
        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
        batch member.  If `True`, batch members with valid parameters leading to
        undefined statistics will return NaN for this statistic.
      name: A name for this distribution.

    Raises:
      ValueError: If p and logits are passed, or if neither are passed.
    """
    self._logits, self._p = distribution_util.get_logits_and_prob(
        name=name, logits=logits, p=p, validate_args=validate_args)
    with ops.name_scope(name):
      # q is the probability of a negative event.
      with ops.name_scope("q"):
        self._q = 1. - self._p
    # The base-class constructor is not part of computing `q`; calling it here
    # keeps anything it creates out of the "<name>/q" name scope.
    super(Bernoulli, self).__init__(
        dtype=dtype,
        parameters={"p": self._p, "q": self._q, "logits": self._logits},
        is_continuous=False,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        name=name)
Ejemplo n.º 16
0
  def testGetLogitsAndProbProbabilityValidateArgsMultidimensional(self):
    valid = np.array([[0.3, 0.4, 0.3], [0.1, 0.5, 0.4]], dtype=np.float32)
    # Each invalid input is paired with the op-error pattern expected when
    # validation is enabled.
    invalid_cases = [
        # Component less than 0. Still sums to 1.
        (np.array([[-.3, 0.4, 0.9], [0.1, 0.5, 0.4]], dtype=np.float32),
         "Condition x >= 0"),
        # Component greater than 1. Does not sum to 1.
        (np.array([[1.3, 0.0, 0.0], [0.1, 0.5, 0.4]], dtype=np.float32),
         "(p has components greater than 1|p does not sum to 1)"),
        # Does not sum to 1.
        (np.array([[1.1, 0.3, 0.4], [0.1, 0.5, 0.4]], dtype=np.float32),
         "p does not sum to 1"),
    ]

    with self.test_session():
      _, prob = distribution_util.get_logits_and_prob(
          p=valid, multidimensional=True)
      prob.eval()

      for probs, error_pattern in invalid_cases:
        with self.assertRaisesOpError(error_pattern):
          _, prob = distribution_util.get_logits_and_prob(
              p=probs, multidimensional=True, validate_args=True)
          prob.eval()
        # With validation disabled the same input must not raise.
        _, prob = distribution_util.get_logits_and_prob(
            p=probs, multidimensional=True, validate_args=False)
        prob.eval()
    def testGetLogitsAndProbProbabilityValidateArgsMultidimensional(self):
        f32 = np.float32
        good = np.array([[0.3, 0.4, 0.3], [0.1, 0.5, 0.4]], dtype=f32)
        # Component less than 0. Still sums to 1.
        negative = np.array([[-.3, 0.4, 0.9], [0.1, 0.5, 0.4]], dtype=f32)
        # Component greater than 1. Does not sum to 1.
        too_big = np.array([[1.3, 0.0, 0.0], [0.1, 0.5, 0.4]], dtype=f32)
        # Does not sum to 1.
        bad_sum = np.array([[1.1, 0.3, 0.4], [0.1, 0.5, 0.4]], dtype=f32)

        def run(probs, **kwargs):
            # Build the multidimensional (logits, p) pair and evaluate p so
            # any op error surfaces at this call.
            distribution_util.get_logits_and_prob(
                p=probs, multidimensional=True, **kwargs)[1].eval()

        with self.test_session():
            run(good)

            with self.assertRaisesOpError("Condition x >= 0"):
                run(negative, validate_args=True)
            run(negative, validate_args=False)

            with self.assertRaisesOpError(
                    "(p has components greater than 1|p does not sum to 1)"):
                run(too_big, validate_args=True)
            run(too_big, validate_args=False)

            with self.assertRaisesOpError("p does not sum to 1"):
                run(bad_sum, validate_args=True)
            run(bad_sum, validate_args=False)
Ejemplo n.º 18
0
  def __init__(self,
               temperature,
               logits=None,
               p=None,
               validate_args=False,
               allow_nan_stats=True,
               name="RelaxedBernoulli"):
    """Construct RelaxedBernoulli distributions.

    Args:
      temperature: An 0-D `Tensor`, representing the temperature
        of a set of RelaxedBernoulli distributions. The temperature should be
        positive.
      logits: An N-D `Tensor` representing the log-odds
        of a positive event. Each entry in the `Tensor` parametrizes
        an independent RelaxedBernoulli distribution where the probability of an
        event is sigmoid(logits). Only one of `logits` or `p` should be passed
        in.
      p: An N-D `Tensor` representing the probability of a positive
        event. Each entry in the `Tensor` parameterizes an independent
        Bernoulli distribution. Only one of `logits` or `p` should be passed
        in.
      validate_args: `Boolean`, default `False`.  Whether to validate that
        `0 <= p <= 1` and that `temperature` is positive. If `validate_args`
        is `False`, and the inputs are invalid, methods like `log_pmf` may
        return `NaN` values.
      allow_nan_stats: `Boolean`, default `True`.  If `False`, raise an
        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
        batch member.  If `True`, batch members with valid parameters leading to
        undefined statistics will return NaN for this statistic.
      name: A name for this distribution.

    Raises:
      ValueError: If p and logits are passed, or if neither are passed.
    """
    parameters = locals()
    parameters.pop("self")
    # NOTE(review): `parameters` is computed but never forwarded to the base
    # class or stored on the instance -- confirm whether it should be.
    with ops.name_scope(name, values=[logits, p, temperature]) as ns:
      # The identity op carries the optional positivity check on temperature.
      with ops.control_dependencies([check_ops.assert_positive(temperature)]
                                    if validate_args else []):
        self._temperature = array_ops.identity(temperature, name="temperature")

      self._logits, self._p = distribution_util.get_logits_and_prob(
          logits=logits, p=p, validate_args=validate_args)
      with ops.name_scope("q"):
        self._q = 1. - self._p
      # Base distribution on the real line; the sigmoid bijector below maps
      # its samples into (0, 1).
      dist = logistic._Logistic(self._logits / self._temperature,
                                1./self._temperature,
                                validate_args=validate_args,
                                allow_nan_stats=allow_nan_stats,
                                name=ns)

    def inverse_log_det_jacobian_fn(y):
      # log|d/dy logit(y)| = -log(y) - log(1 - y). The event of this
      # distribution is scalar, so the term is elementwise: reducing over the
      # last axis would collapse a batch dimension and add together the
      # log-det terms of independent batch members.
      return -(math_ops.log(y) + math_ops.log(1-y))

    sigmoidbijector = bijector.Inline(
        forward_fn=math_ops.sigmoid,
        inverse_fn=(lambda y: math_ops.log(y) - math_ops.log(1-y)),
        inverse_log_det_jacobian_fn=inverse_log_det_jacobian_fn,
        name="sigmoid")
    super(_RelaxedBernoulli, self).__init__(dist,
                                            sigmoidbijector,
                                            name=name)
Ejemplo n.º 19
0
  def __init__(self,
               n,
               logits=None,
               p=None,
               validate_args=False,
               allow_nan_stats=True,
               name="Multinomial"):
    """Initialize a batch of Multinomial distributions.

    Args:
      n:  Non-negative floating point tensor with shape broadcastable to
        `[N1,..., Nm]` with `m >= 0`. Defines this as a batch of
        `N1 x ... x Nm` different Multinomial distributions.  Its components
        should be equal to integer values.
      logits: Floating point tensor representing the log-odds of a
        positive event with shape broadcastable to `[N1,..., Nm, k], m >= 0`,
        and the same dtype as `n`. Defines this as a batch of `N1 x ... x Nm`
        different `k` class Multinomial distributions.
      p:  Positive floating point tensor with shape broadcastable to
        `[N1,..., Nm, k]` `m >= 0` and same dtype as `n`.  Defines this as
        a batch of `N1 x ... x Nm` different `k` class Multinomial
        distributions. `p`'s components in the last portion of its shape should
        sum up to 1.
      validate_args: `Boolean`, default `False`.  Whether to assert valid
        values for parameters `n` and `p`, and `x` in `prob` and `log_prob`.
        If `False`, correct behavior is not guaranteed.
      allow_nan_stats: `Boolean`, default `True`.  If `False`, raise an
        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
        batch member.  If `True`, batch members with valid parameters leading to
        undefined statistics will return NaN for this statistic.
      name: The name to prefix Ops created by this distribution class.

    Examples:

    ```python
    # Define 1-batch of 2-class multinomial distribution,
    # also known as a Binomial distribution.
    dist = Multinomial(n=2., p=[.1, .9])

    # Define a 2-batch of 3-class distributions.
    dist = Multinomial(n=[4., 5], p=[[.1, .3, .6], [.4, .05, .55]])
    ```

    """

    self._logits, self._p = distribution_util.get_logits_and_prob(
        name=name, logits=logits, p=p, validate_args=validate_args,
        multidimensional=True)
    with ops.name_scope(name, values=[n, self._p]) as ns:
      with ops.control_dependencies([
          check_ops.assert_non_negative(
              n, message="n has negative components."),
          distribution_util.assert_integer_form(
              n, message="n has non-integer components.")
      ] if validate_args else []):
        self._n = array_ops.identity(n, name="convert_n")
        # Per-class mean n * p: `n` gains a trailing axis so it broadcasts
        # against the class (last) axis of `p`.
        self._mean_val = array_ops.expand_dims(n, -1) * self._p
        # Summing out the class axis leaves a tensor whose shape is the
        # broadcast batch shape of `n` and `p`.
        self._broadcast_shape = math_ops.reduce_sum(
            self._mean_val, reduction_indices=[-1], keep_dims=False)
        super(Multinomial, self).__init__(
            dtype=self._p.dtype,
            parameters={"p": self._p,
                        "n": self._n,
                        # The mean tensor computed above.
                        "mean": self._mean_val,
                        "logits": self._logits,
                        "broadcast_shape": self._broadcast_shape},
            is_continuous=False,
            is_reparameterized=False,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            name=ns)
Ejemplo n.º 20
0
  def __init__(
      self,
      logits=None,
      p=None,
      dtype=dtypes.int32,
      validate_args=False,
      allow_nan_stats=True,
      name="OneHotCategorical"):
    """Initialize OneHotCategorical distributions using class log-probabilities.

    Args:
      logits: An N-D `Tensor`, `N >= 1`, representing the log probabilities
        of a set of Categorical distributions. The first `N - 1` dimensions
        index into a batch of independent distributions and the last dimension
        represents a vector of logits for each class. Only one of `logits` or
        `p` should be passed in.
      p: An N-D `Tensor`, `N >= 1`, representing the probabilities
        of a set of Categorical distributions. The first `N - 1` dimensions
        index into a batch of independent distributions and the last dimension
        represents a vector of probabilities for each class. Only one of
        `logits` or `p` should be passed in.
      dtype: The type of the event samples (default: int32).
      validate_args: `Boolean`, default `False`. Forwarded to
        `distribution_util.get_logits_and_prob`, which may add runtime checks
        on `p` when `p` is given.
      allow_nan_stats: `Boolean`, default `True`.  If `False`, raise an
        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
        batch member.  If `True`, batch members with valid parameters leading to
        undefined statistics will return NaN for this statistic.
      name: A name for this distribution (optional).
    """
    parameters = locals()
    parameters.pop("self")
    # Include `p` so the scope's graph is taken from whichever input was
    # actually supplied.
    with ops.name_scope(name, values=[logits, p]) as ns:
      self._logits, self._p = distribution_util.get_logits_and_prob(
          name=name, logits=logits, p=p, validate_args=validate_args,
          multidimensional=True)

      # Prefer static shape information; fall back to graph ops when the
      # static shape is not known.
      logits_shape_static = self._logits.get_shape().with_rank_at_least(1)
      if logits_shape_static.ndims is not None:
        self._batch_rank = ops.convert_to_tensor(
            logits_shape_static.ndims - 1,
            dtype=dtypes.int32,
            name="batch_rank")
      else:
        with ops.name_scope(name="batch_rank"):
          self._batch_rank = array_ops.rank(self._logits) - 1

      logits_shape = array_ops.shape(self._logits, name="logits_shape")
      if logits_shape_static[-1].value is not None:
        self._num_classes = ops.convert_to_tensor(
            logits_shape_static[-1].value,
            dtype=dtypes.int32,
            name="num_classes")
      else:
        # The class axis is the last one, i.e. index `batch_rank`.
        self._num_classes = array_ops.gather(logits_shape,
                                             self._batch_rank,
                                             name="num_classes")

      if logits_shape_static[:-1].is_fully_defined():
        self._batch_shape_val = constant_op.constant(
            logits_shape_static[:-1].as_list(),
            dtype=dtypes.int32,
            name="batch_shape")
      else:
        with ops.name_scope(name="batch_shape"):
          self._batch_shape_val = logits_shape[:-1]
    super(_OneHotCategorical, self).__init__(
        dtype=dtype,
        is_continuous=False,
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[self._logits, self._num_classes],
        name=ns)
    def __init__(self,
                 temperature,
                 logits=None,
                 p=None,
                 dtype=dtypes.float32,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="ExpRelaxedOneHotCategorical"):
        """Initialize ExpRelaxedOneHotCategorical using class log-probabilities.

        Args:
          temperature: An 0-D `Tensor`, representing the temperature
            of a set of ExpRelaxedCategorical distributions. The temperature
            should be positive.
          logits: An N-D `Tensor`, `N >= 1`, representing the log
            probabilities of a set of ExpRelaxedCategorical distributions. The
            first `N - 1` dimensions index into a batch of independent
            distributions and the last dimension represents a vector of logits
            for each class. Only one of `logits` or `p` should be passed in.
          p: An N-D `Tensor`, `N >= 1`, representing the probabilities
            of a set of ExpRelaxedCategorical distributions. The first
            `N - 1` dimensions index into a batch of independent distributions
            and the last dimension represents a vector of probabilities for
            each class. Only one of `logits` or `p` should be passed in.
          dtype: The type of the event samples (default: float32).
          validate_args: `Boolean`, default `False`. If `True`, assert that
            `temperature` is positive; the flag is also forwarded to
            `distribution_util.get_logits_and_prob`, which may check `p`.
          allow_nan_stats: `Boolean`, default `True`.  If `False`, raise an
            exception if a statistic (e.g. mean/mode/etc...) is undefined for
            any batch member.  If `True`, batch members with valid parameters
            leading to undefined statistics will return NaN for this
            statistic.
          name: A name for this distribution (optional).
        """
        parameters = locals()
        parameters.pop("self")
        with ops.name_scope(name, values=[logits, p, temperature]) as ns:
            # The identity op carries the optional positivity check.
            with ops.control_dependencies(
                [check_ops.assert_positive(temperature
                                           )] if validate_args else []):
                self._temperature = array_ops.identity(temperature,
                                                       name="temperature")
            self._logits, self._p = distribution_util.get_logits_and_prob(
                name=name,
                logits=logits,
                p=p,
                validate_args=validate_args,
                multidimensional=True)

            # Prefer static shape information; fall back to graph ops when
            # the static shape is not known.
            logits_shape_static = self._logits.get_shape().with_rank_at_least(
                1)
            if logits_shape_static.ndims is not None:
                self._batch_rank = ops.convert_to_tensor(
                    logits_shape_static.ndims - 1,
                    dtype=dtypes.int32,
                    name="batch_rank")
            else:
                with ops.name_scope(name="batch_rank"):
                    self._batch_rank = array_ops.rank(self._logits) - 1

            logits_shape = array_ops.shape(self._logits, name="logits_shape")
            if logits_shape_static[-1].value is not None:
                self._num_classes = ops.convert_to_tensor(
                    logits_shape_static[-1].value,
                    dtype=dtypes.int32,
                    name="num_classes")
            else:
                # The class axis is the last one, i.e. index `batch_rank`.
                self._num_classes = array_ops.gather(logits_shape,
                                                     self._batch_rank,
                                                     name="num_classes")

            if logits_shape_static[:-1].is_fully_defined():
                self._batch_shape_val = constant_op.constant(
                    logits_shape_static[:-1].as_list(),
                    dtype=dtypes.int32,
                    name="batch_shape")
            else:
                with ops.name_scope(name="batch_shape"):
                    self._batch_shape_val = logits_shape[:-1]
        super(_ExpRelaxedOneHotCategorical, self).__init__(
            dtype=dtype,
            is_continuous=True,
            is_reparameterized=True,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            parameters=parameters,
            graph_parents=[self._logits, self._temperature, self._num_classes],
            name=ns)
Ejemplo n.º 22
0
    def __init__(self,
                 temperature,
                 logits=None,
                 p=None,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="RelaxedBernoulli"):
        """Construct RelaxedBernoulli distributions.

        Samples are produced by drawing from a Logistic distribution with
        parameters `(logits / temperature, 1 / temperature)` and pushing the
        result through an elementwise sigmoid.

        Args:
          temperature: A 0-D `Tensor`, representing the temperature
            of a set of RelaxedBernoulli distributions. The temperature should
            be positive.
          logits: An N-D `Tensor` representing the log-odds of a positive
            event. Each entry in the `Tensor` parameterizes an independent
            RelaxedBernoulli distribution where the probability of an event is
            sigmoid(logits). Only one of `logits` or `p` should be passed in.
          p: An N-D `Tensor` representing the probability of a positive
            event. Each entry in the `Tensor` parameterizes an independent
            RelaxedBernoulli distribution. Only one of `logits` or `p` should
            be passed in.
          validate_args: `Boolean`, default `False`.  Whether to validate that
            `temperature > 0` and `0 <= p <= 1`. If `validate_args` is `False`,
            and the inputs are invalid, methods like `log_prob` may return
            `NaN` values.
          allow_nan_stats: `Boolean`, default `True`.  If `False`, raise an
            exception if a statistic (e.g. mean/mode/etc...) is undefined for
            any batch member.  If `True`, batch members with valid parameters
            leading to undefined statistics will return NaN for this statistic.
          name: A name for this distribution.

        Raises:
          ValueError: If p and logits are passed, or if neither are passed.
        """
        with ops.name_scope(name, values=[logits, p, temperature]) as ns:
            # The positivity assertion is only wired in when validation is on.
            with ops.control_dependencies(
                [check_ops.assert_positive(temperature)]
                if validate_args else []):
                self._temperature = array_ops.identity(temperature,
                                                       name="temperature")

            self._logits, self._p = distribution_util.get_logits_and_prob(
                logits=logits, p=p, validate_args=validate_args)
            with ops.name_scope("q"):
                self._q = 1. - self._p
            # Base distribution: a Logistic parameterized by
            # (logits / temperature, 1 / temperature); the sigmoid bijector
            # below maps its samples into (0, 1).
            dist = logistic._Logistic(self._logits / self._temperature,
                                      1. / self._temperature,
                                      validate_args=validate_args,
                                      allow_nan_stats=allow_nan_stats,
                                      name=ns)

        def inverse_fn(y):
            # logit(y). log1p(-y) is more accurate than log(1 - y) for small y.
            return math_ops.log(y) - math_ops.log1p(-y)

        def inverse_log_det_jacobian_fn(y):
            # d/dy logit(y) = 1 / (y * (1 - y)). The sigmoid acts elementwise
            # on scalar events, and every entry is an independent distribution.
            # The log-Jacobian must therefore stay elementwise; summing over
            # the last axis would mix independent batch members and fails on
            # rank-0 inputs.
            return -math_ops.log(y) - math_ops.log1p(-y)

        sigmoidbijector = bijector.Inline(
            forward_fn=math_ops.sigmoid,
            inverse_fn=inverse_fn,
            inverse_log_det_jacobian_fn=inverse_log_det_jacobian_fn,
            name="sigmoid")
        super(_RelaxedBernoulli, self).__init__(dist,
                                                sigmoidbijector,
                                                name=name)