Example #1
0
  def __init__(self, alpha, beta, name="Gamma"):
    """Build a batch of Gamma distributions from `alpha` and `beta`.

    `alpha` and `beta` only need to be broadcast-compatible with each other
    (i.e. `alpha + beta` must be a valid operation); the static shape of
    `alpha + beta` becomes the batch shape, and the event shape is scalar.

    Args:
      alpha: `float` or `double` tensor holding the shape parameter(s) of the
        distribution(s). Every entry must be positive.
      beta: `float` or `double` tensor holding the inverse scale parameter(s)
        of the distribution(s). Every entry must be positive.
      name: The name to prepend to all ops created by this distribution.

    Raises:
      TypeError: if `alpha` and `beta` are different dtypes.
    """
    with ops.op_scope([alpha, beta], name):
      # Run-time positivity checks; every op built under the context below
      # (the parameter identities and their sum) depends on them.
      positivity_checks = [check_ops.assert_positive(alpha),
                           check_ops.assert_positive(beta)]
      with ops.control_dependencies(positivity_checks):
        self._alpha = array_ops.identity(alpha, name="alpha")
        self._beta = array_ops.identity(beta, name="beta")
        contrib_tensor_util.assert_same_float_dtype((self._alpha, self._beta))
        self._broadcast_tensor = self._alpha + self._beta

    self._get_batch_shape = self._broadcast_tensor.get_shape()
    self._get_event_shape = tensor_shape.TensorShape([])
    self._name = name
Example #2
0
  def __init__(self, df, mu, sigma, name="StudentT"):
    """Construct Student's t distributions.

    The distributions have degree of freedom `df`, mean `mu`, and scale `sigma`.

    The parameters `df`, `mu`, and `sigma` must be shaped in a way that supports
    broadcasting (e.g. `df + mu + sigma` is a valid operation).

    Args:
      df: `float` or `double` tensor, the degrees of freedom of the
        distribution(s). `df` must contain only positive values.
      mu: `float` or `double` tensor, the means of the distribution(s).
      sigma: `float` or `double` tensor, the scaling factor for the
        distribution(s). `sigma` must contain only positive values.
        Note that `sigma` is not the standard deviation of this distribution.
      name: The name to give Ops created by the initializer.

    Raises:
      TypeError: if `df`, `mu`, and `sigma` are not all the same dtype.
    """
    # Local import so the fix does not depend on module-level imports that may
    # not be present in this file.
    from tensorflow.python.ops import array_ops  # pylint: disable=g-import-not-at-top
    super(StudentT, self).__init__()
    with ops.op_scope([df, mu, sigma], name) as scope:
      with ops.control_dependencies([check_ops.assert_positive(df),
                                     check_ops.assert_positive(sigma)]):
        # Control dependencies attach only to ops *created* in this context.
        # `ops.convert_to_tensor` returns an existing Tensor unchanged (no new
        # op), which silently dropped the positivity assertions for tensor
        # inputs. `array_ops.identity` always creates an op, so the checks are
        # always wired in.
        self._df = array_ops.identity(df, name="df")
        self._mu = array_ops.identity(mu, name="mu")
        self._sigma = array_ops.identity(sigma, name="sigma")
        contrib_tensor_util.assert_same_float_dtype(
            (self._df, self._mu, self._sigma))
      self._name = scope
      # `_ones()` is defined elsewhere on the class; its static shape is used
      # as the batch shape.
      self._get_batch_shape = self._ones().get_shape()
      self._get_event_shape = tensor_shape.TensorShape([])
Example #3
0
  def __init__(self,
               df,
               mu,
               sigma,
               validate_args=False,
               allow_nan_stats=True,
               name="StudentT"):
    """Construct Student's t distributions.

    The distributions have degree of freedom `df`, mean `mu`, and scale `sigma`.

    The parameters `df`, `mu`, and `sigma` must be shaped in a way that supports
    broadcasting (e.g. `df + mu + sigma` is a valid operation).

    Args:
      df: Floating point tensor, the degrees of freedom of the
        distribution(s). `df` must contain only positive values.
      mu: Floating point tensor, the means of the distribution(s).
      sigma: Floating point tensor, the scaling factor for the
        distribution(s). `sigma` must contain only positive values.
        Note that `sigma` is not the standard deviation of this distribution.
      validate_args: `Boolean`, default `False`.  Whether to assert that
        `df > 0` and `sigma > 0`. If `validate_args` is `False` and inputs are
        invalid, correct behavior is not guaranteed.
      allow_nan_stats: `Boolean`, default `True`.  If `False`, raise an
        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
        batch member.  If `True`, batch members with valid parameters leading to
        undefined statistics will return NaN for this statistic.
      name: The name to give Ops created by the initializer.

    Raises:
      TypeError: if `df`, `mu`, and `sigma` are not all the same dtype.
    """
    # Copy the constructor arguments. In CPython before 3.13, `locals()` inside
    # a function returns the frame's own locals dict. That dict is re-synced
    # whenever `locals()` is called again or a tracer/debugger reads
    # `f_locals`, which re-adds `self` and adds later locals such as `ns`. The
    # docs also say that dict should not be modified. `dict(...)` detaches it.
    parameters = dict(locals())
    parameters.pop("self")
    with ops.name_scope(name, values=[df, mu, sigma]) as ns:
      with ops.control_dependencies([
          check_ops.assert_positive(df),
          check_ops.assert_positive(sigma),
      ] if validate_args else []):
        # `identity` always creates a new op, so the optional assertions above
        # are attached even when the inputs are already tensors.
        self._df = array_ops.identity(df, name="df")
        self._mu = array_ops.identity(mu, name="mu")
        self._sigma = array_ops.identity(sigma, name="sigma")
        contrib_tensor_util.assert_same_float_dtype(
            (self._df, self._mu, self._sigma))
    super(StudentT, self).__init__(
        dtype=self._sigma.dtype,
        is_continuous=True,
        is_reparameterized=True,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[self._df, self._mu, self._sigma],
        name=ns)
Example #4
0
  def __init__(self, a, b, validate_args=False, allow_nan_stats=True,
               name="Beta"):
    """Initialize a batch of Beta distributions.

    Args:
      a:  Positive floating point tensor with shape broadcastable to
        `[N1,..., Nm]` `m >= 0`.  Defines this as a batch of `N1 x ... x Nm`
         different Beta distributions. This also defines the
         dtype of the distribution.
      b:  Positive floating point tensor with shape broadcastable to
        `[N1,..., Nm]` `m >= 0`.  Defines this as a batch of `N1 x ... x Nm`
         different Beta distributions.
      validate_args: `Boolean`, default `False`.  Whether to assert valid
        values for parameters `a`, `b`, and `x` in `prob` and `log_prob`.
        If `False` and inputs are invalid, correct behavior is not guaranteed.
      allow_nan_stats: `Boolean`, default `True`.  If `False`, raise an
        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
        batch member.  If `True`, batch members with valid parameters leading to
        undefined statistics will return NaN for this statistic.
      name: The name to prefix Ops created by this distribution class.

    Examples:

    ```python
    # Define 1-batch.
    dist = Beta(1.1, 2.0)

    # Define a 2-batch.
    dist = Beta([1.0, 2.0], [4.0, 5.0])
    ```

    """
    # Copy the constructor arguments. In CPython before 3.13, `locals()` inside
    # a function returns the frame's own locals dict. That dict is re-synced
    # whenever `locals()` is called again or a tracer/debugger reads
    # `f_locals`, which re-adds `self` and adds later locals such as `ns`. The
    # docs also say that dict should not be modified. `dict(...)` detaches it.
    parameters = dict(locals())
    parameters.pop("self")
    with ops.name_scope(name, values=[a, b]) as ns:
      with ops.control_dependencies([
          check_ops.assert_positive(a),
          check_ops.assert_positive(b),
      ] if validate_args else []):
        self._a = array_ops.identity(a, name="a")
        self._b = array_ops.identity(b, name="b")
        contrib_tensor_util.assert_same_float_dtype((self._a, self._b))
        # Used for mean/mode/variance/entropy/sampling computations
        self._a_b_sum = self._a + self._b
    super(Beta, self).__init__(
        dtype=self._a_b_sum.dtype,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        is_continuous=True,
        is_reparameterized=False,
        parameters=parameters,
        graph_parents=[self._a, self._b, self._a_b_sum],
        name=ns)
Example #5
0
  def __init__(self,
               concentration,
               rate,
               validate_args=False,
               allow_nan_stats=True,
               name="InverseGamma"):
    """Construct InverseGamma with `concentration` and `rate` parameters.

    The parameters `concentration` and `rate` must be shaped in a way that
    supports broadcasting (e.g. `concentration + rate` is a valid operation).

    Args:
      concentration: Floating point tensor, the concentration params of the
        distribution(s). Must contain only positive values.
      rate: Floating point tensor, the inverse scale params of the
        distribution(s). Must contain only positive values.
      validate_args: Python `Boolean`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `Boolean`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined.  When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: `String` name prefixed to Ops created by this class.

    Raises:
      TypeError: if `concentration` and `rate` are different dtypes.
    """
    # Copy the constructor arguments. In CPython before 3.13, `locals()` inside
    # a function returns the frame's own locals dict. That dict is re-synced
    # (adding later locals such as `ns`, and `parameters` itself) whenever
    # `locals()` is called again or a tracer/debugger reads `f_locals`.
    # `dict(...)` detaches the snapshot from those refreshes.
    parameters = dict(locals())
    with ops.name_scope(name, values=[concentration, rate]) as ns:
      with ops.control_dependencies([
          check_ops.assert_positive(concentration),
          check_ops.assert_positive(rate),
      ] if validate_args else []):
        self._concentration = array_ops.identity(
            concentration, name="concentration")
        self._rate = array_ops.identity(rate, name="rate")
        contrib_tensor_util.assert_same_float_dtype(
            [self._concentration, self._rate])
    super(InverseGamma, self).__init__(
        dtype=self._concentration.dtype,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        is_continuous=True,
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        parameters=parameters,
        graph_parents=[self._concentration,
                       self._rate],
        name=ns)
Example #6
0
  def __init__(self,
               alpha,
               beta,
               validate_args=False,
               allow_nan_stats=True,
               name="Gamma"):
    """Construct Gamma distributions with parameters `alpha` and `beta`.

    The parameters `alpha` and `beta` must be shaped in a way that supports
    broadcasting (e.g. `alpha + beta` is a valid operation).

    Args:
      alpha: Floating point tensor, the shape params of the
        distribution(s).
        alpha must contain only positive values.
      beta: Floating point tensor, the inverse scale params of the
        distribution(s).
        beta must contain only positive values.
      validate_args: `Boolean`, default `False`.  Whether to assert that
        `alpha > 0`, `beta > 0`, and that `x > 0` in the methods `prob(x)` and
        `log_prob(x)`.  If `validate_args` is `False` and the inputs are
        invalid, correct behavior is not guaranteed.
      allow_nan_stats: `Boolean`, default `True`.  If `False`, raise an
        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
        batch member.  If `True`, batch members with valid parameters leading to
        undefined statistics will return NaN for this statistic.
      name: The name to prepend to all ops created by this distribution.

    Raises:
      TypeError: if `alpha` and `beta` are different dtypes.
    """
    # Copy the constructor arguments. In CPython before 3.13, `locals()` inside
    # a function returns the frame's own locals dict. That dict is re-synced
    # whenever `locals()` is called again or a tracer/debugger reads
    # `f_locals`, which re-adds `self` and adds later locals such as `ns`. The
    # docs also say that dict should not be modified. `dict(...)` detaches it.
    parameters = dict(locals())
    parameters.pop("self")
    with ops.name_scope(name, values=[alpha, beta]) as ns:
      with ops.control_dependencies([
          check_ops.assert_positive(alpha),
          check_ops.assert_positive(beta),
      ] if validate_args else []):
        self._alpha = array_ops.identity(alpha, name="alpha")
        self._beta = array_ops.identity(beta, name="beta")
        contrib_tensor_util.assert_same_float_dtype((self._alpha, self._beta))
    super(Gamma, self).__init__(
        dtype=self._alpha.dtype,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        is_continuous=True,
        is_reparameterized=False,
        parameters=parameters,
        graph_parents=[self._alpha, self._beta],
        name=ns)
Example #7
0
  def __init__(self, a, b, validate_args=True, allow_nan_stats=False,
               name="Beta"):
    """Create a batch of Beta distributions parameterized by `a` and `b`.

    Args:
      a:  Positive floating point tensor with shape broadcastable to
        `[N1,..., Nm]` `m >= 0`, making this a batch of `N1 x ... x Nm`
        Beta distributions. Its dtype is the distribution's dtype.
      b:  Positive floating point tensor with shape broadcastable to
        `[N1,..., Nm]` `m >= 0`, making this a batch of `N1 x ... x Nm`
        Beta distributions.
      validate_args: Whether to assert valid values for parameters `a` and
        `b`, and for `x` in `prob` and `log_prob`. When `False`, correct
        behavior is not guaranteed.
      allow_nan_stats:  Boolean, default `False`.  If `False`, raise an
        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
        batch member.  If `True`, batch members with valid parameters leading
        to undefined statistics will return NaN for this statistic.
      name: The name to prefix Ops created by this distribution class.

    Examples:

    ```python
    # Define 1-batch.
    dist = Beta(1.1, 2.0)

    # Define a 2-batch.
    dist = Beta([1.0, 2.0], [4.0, 5.0])
    ```

    """
    with ops.name_scope(name, values=[a, b]):
      # Positivity assertions are built only when validation is requested.
      if validate_args:
        positivity = [check_ops.assert_positive(a),
                      check_ops.assert_positive(b)]
      else:
        positivity = []
      with ops.control_dependencies(positivity):
        self._a = array_ops.identity(a, name="a")
        self._b = array_ops.identity(b, name="b")

      self._name = name

      # Shared by the mean/mode/variance/entropy/sampling computations.
      self._a_b_sum = self._a + self._b

      self._get_batch_shape = self._a_b_sum.get_shape()
      self._get_event_shape = tensor_shape.TensorShape([])
      self._validate_args = validate_args
      self._allow_nan_stats = allow_nan_stats
Example #8
0
  def __init__(self,
               df,
               mu,
               sigma,
               validate_args=True,
               allow_nan_stats=False,
               name="StudentT"):
    """Construct Student's t distributions.

    The distributions have degree of freedom `df`, mean `mu`, and scale `sigma`.

    The parameters `df`, `mu`, and `sigma` must be shaped in a way that supports
    broadcasting (e.g. `df + mu + sigma` is a valid operation).

    Args:
      df: Floating point tensor, the degrees of freedom of the
        distribution(s). `df` must contain only positive values.
      mu: Floating point tensor, the means of the distribution(s).
      sigma: Floating point tensor, the scaling factor for the
        distribution(s). `sigma` must contain only positive values.
        Note that `sigma` is not the standard deviation of this distribution.
      validate_args: Whether to assert that `df > 0, sigma > 0`. If
        `validate_args` is `False` and inputs are invalid, correct behavior is
        not guaranteed.
      allow_nan_stats:  Boolean, default `False`.  If `False`, raise an
        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
        batch member.  If `True`, batch members with valid parameters leading to
        undefined statistics will return NaN for this statistic.
      name: The name to give Ops created by the initializer.

    Raises:
      TypeError: if `df`, `mu`, and `sigma` are not all the same dtype.
    """
    # Local import so the fix does not depend on module-level imports that may
    # not be present in this file.
    from tensorflow.python.ops import array_ops  # pylint: disable=g-import-not-at-top
    self._allow_nan_stats = allow_nan_stats
    self._validate_args = validate_args
    with ops.name_scope(name, values=[df, mu, sigma]) as scope:
      with ops.control_dependencies([check_ops.assert_positive(
          df), check_ops.assert_positive(sigma)] if validate_args else []):
        # Control dependencies attach only to ops *created* in this context.
        # `ops.convert_to_tensor` returns an existing Tensor unchanged (no new
        # op), so with `validate_args=True` the assertions were silently
        # dropped for tensor inputs. `array_ops.identity` always creates an op.
        self._df = array_ops.identity(df, name="df")
        self._mu = array_ops.identity(mu, name="mu")
        self._sigma = array_ops.identity(sigma, name="sigma")
        contrib_tensor_util.assert_same_float_dtype(
            (self._df, self._mu, self._sigma))
      self._name = scope
      # Batch shape is the static broadcast of the three parameter shapes.
      self._get_batch_shape = common_shapes.broadcast_shape(
          self._sigma.get_shape(), common_shapes.broadcast_shape(
              self._df.get_shape(), self._mu.get_shape()))
      self._get_event_shape = tensor_shape.TensorShape([])
Example #9
0
    def __init__(self, mu, sigma, validate_args=True, allow_nan_stats=False, name="Normal"):
        """Construct Normal distributions with mean `mu` and stddev `sigma`.

        `mu` and `sigma` must be broadcast-compatible (e.g. `mu + sigma` must
        be a valid operation).

        Args:
          mu: Floating point tensor, the means of the distribution(s).
          sigma: Floating point tensor, the stddevs of the distribution(s).
            Must contain only positive values.
          validate_args: Whether to assert that `sigma > 0`. If `False`,
            correct output is not guaranteed when input is invalid.
          allow_nan_stats: Boolean, default `False`. If `False`, raise an
            exception if a statistic (e.g. mean/mode/etc...) is undefined for
            any batch member. If `True`, batch members with valid parameters
            leading to undefined statistics will return NaN for this statistic.
          name: The name to give Ops created by the initializer.

        Raises:
          TypeError: if mu and sigma are different dtypes.
        """
        with ops.name_scope(name, values=[mu, sigma]):
            sigma_checks = [check_ops.assert_positive(sigma)] if validate_args else []
            with ops.control_dependencies(sigma_checks):
                self._mu = array_ops.identity(mu, name="mu")
                self._sigma = array_ops.identity(sigma, name="sigma")
                contrib_tensor_util.assert_same_float_dtype((self._mu, self._sigma))
                param_map = {"mu": self._mu, "sigma": self._sigma}
                # The base class is initialized inside both scopes and receives
                # the raw `name` argument, not the opened scope string.
                super(Normal, self).__init__(
                    dtype=self._sigma.dtype,
                    parameters=param_map,
                    is_reparameterized=True,
                    validate_args=validate_args,
                    allow_nan_stats=allow_nan_stats,
                    name=name,
                )
Example #10
0
  def log_prob(self, x, name="log_prob"):
    """Log prob of observations in `x` under these Gamma distribution(s).

    Evaluates, elementwise,
      alpha * log(beta) + (alpha - 1) * log(x) - beta * x - lgamma(alpha),
    the Gamma log-density with shape `alpha` and inverse scale `beta`.

    Args:
      x: tensor of dtype `dtype`, must be broadcastable with `alpha` and `beta`.
        When `self.strict` is true, `x` is also asserted (at graph run time) to
        be positive.
      name: The name to give this op.

    Returns:
      log_prob: tensor of dtype `dtype`, the log-PDFs of `x`.

    Raises:
      TypeError: if `x` and `alpha` are different dtypes.
    """
    # Ops are nested under the distribution's own name scope, then `name`.
    with ops.name_scope(self.name):
      with ops.op_scope([self._alpha, self._beta, x], name):
        alpha = self._alpha
        beta = self._beta
        x = ops.convert_to_tensor(x)
        # In strict mode `x` is made to depend on a positivity assertion, so
        # the check runs whenever the result is evaluated; otherwise no
        # assertion is attached.
        x = control_flow_ops.with_dependencies(
            [check_ops.assert_positive(x)] if self.strict else [],
            x)
        # `x` must match this distribution's float dtype.
        contrib_tensor_util.assert_same_float_dtype(tensors=[x,],
                                                    dtype=self.dtype)

        # `self._alpha` is the same tensor as the local `alpha`.
        return (alpha * math_ops.log(beta) + (alpha - 1) * math_ops.log(x) -
                beta * x - math_ops.lgamma(self._alpha))
 def _check_alpha(self, alpha):
   """Convert `alpha` to a tensor, gating it on validity checks in strict mode.

   When `self.strict` is set, the returned tensor depends on assertions that
   `alpha` has rank >= 1 and is strictly positive. Otherwise the converted
   tensor is returned unchanged.
   """
   alpha = ops.convert_to_tensor(alpha, name='alpha')
   if self.strict:
     validity_checks = [
         check_ops.assert_rank_at_least(alpha, 1),
         check_ops.assert_positive(alpha),
     ]
     return control_flow_ops.with_dependencies(validity_checks, alpha)
   return alpha
 def _assert_valid_alpha(self, alpha, validate_args):
   """Convert `alpha` to a tensor, optionally gated on validity checks.

   When `validate_args` is true, the returned tensor depends on assertions
   that `alpha` has rank >= 1 and is strictly positive. Otherwise the
   converted tensor is returned unchanged.
   """
   alpha = ops.convert_to_tensor(alpha, name="alpha")
   if validate_args:
     validity_checks = [
         check_ops.assert_rank_at_least(alpha, 1),
         check_ops.assert_positive(alpha),
     ]
     return control_flow_ops.with_dependencies(validity_checks, alpha)
   return alpha
Example #13
0
  def __init__(self,
               lam,
               validate_args=True,
               allow_nan_stats=False,
               name="Poisson"):
    """Create Poisson distributions with rate `lam`.

    Args:
      lam: Floating point tensor, the rate parameter of the
        distribution(s). `lam` must be positive.
      validate_args: Whether to assert that `lam > 0` and that inputs to pmf
        computations are non-negative integers. If `False`, `pmf` may be
        evaluated at any real value and may return NaN.
      allow_nan_stats:  Boolean, default `False`.  If `False`, raise an
        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
        batch member.  If `True`, batch members with valid parameters leading
        to undefined statistics will return NaN for this statistic.
      name: A name for this distribution.
    """
    # Plain Python flags; setting them creates no graph ops.
    self._validate_args = validate_args
    self._allow_nan_stats = allow_nan_stats
    with ops.name_scope(name, values=[lam]) as scope:
      self._name = scope
      lam_checks = [check_ops.assert_positive(lam)] if validate_args else []
      with ops.control_dependencies(lam_checks):
        self._lam = array_ops.identity(lam, name="lam")
Example #14
0
  def __init__(self,
               lam,
               validate_args=False,
               allow_nan_stats=True,
               name="Poisson"):
    """Construct Poisson distributions.

    Args:
      lam: Floating point tensor, the rate parameter of the
        distribution(s). `lam` must be positive.
      validate_args: `Boolean`, default `False`.  Whether to assert that
        `lam > 0` as well as inputs to pmf computations are non-negative
        integers. If validate_args is `False`, then `pmf` computations might
        return `NaN`, but can be evaluated at any real value.
      allow_nan_stats: `Boolean`, default `True`.  If `False`, raise an
        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
        batch member.  If `True`, batch members with valid parameters leading to
        undefined statistics will return NaN for this statistic.
      name: A name for this distribution.
    """
    # Copy the constructor arguments. In CPython before 3.13, `locals()` inside
    # a function returns the frame's own locals dict. That dict is re-synced
    # whenever `locals()` is called again or a tracer/debugger reads
    # `f_locals`, which re-adds `self` and adds later locals such as `ns`. The
    # docs also say that dict should not be modified. `dict(...)` detaches it.
    parameters = dict(locals())
    parameters.pop("self")
    with ops.name_scope(name, values=[lam]) as ns:
      with ops.control_dependencies([check_ops.assert_positive(lam)] if
                                    validate_args else []):
        self._lam = array_ops.identity(lam, name="lam")
    super(Poisson, self).__init__(
        dtype=self._lam.dtype,
        is_continuous=False,
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[self._lam],
        name=ns)
Example #15
0
  def __init__(self,
               skewness=0.,
               tailweight=1.,
               event_ndims=0,
               validate_args=False,
               name="sinh_arcsinh"):
    """Create a `SinhArcsinh` bijector.

    Args:
      skewness:  Skewness parameter; a float-type `Tensor`.
      tailweight:  Tailweight parameter; a positive `Tensor` with the same
        `dtype` as `skewness` and a broadcastable `shape`.
      event_ndims: Python scalar giving the number of dimensions associated
        with a particular draw from the distribution.
      validate_args: Python `bool`; when true, `tailweight` is asserted to be
        positive.
      name: Python `str` name given to ops managed by this object.
    """
    # These must be set before `_name_scope` is entered below.
    self._graph_parents = []
    self._name = name
    self._validate_args = validate_args
    with self._name_scope("init", values=[skewness, tailweight]):
      self._skewness = ops.convert_to_tensor(skewness, name="skewness")
      self._tailweight = ops.convert_to_tensor(tailweight, name="tailweight")
      check_ops.assert_same_float_dtype([self._skewness, self._tailweight])
      if validate_args:
        tailweight_is_positive = check_ops.assert_positive(
            self._tailweight,
            message="Argument tailweight was not positive")
        self._tailweight = control_flow_ops.with_dependencies(
            [tailweight_is_positive], self._tailweight)
    super(SinhArcsinh, self).__init__(
        event_ndims=event_ndims, validate_args=validate_args, name=name)
 def _check_diag(self, diag):
   """Convert `diag` to a tensor; when `self.verify_pd`, assert it is positive.

   Returns the converted tensor, gated on an `assert_positive` check only when
   `self.verify_pd` is truthy.
   """
   diag = ops.convert_to_tensor(diag, name="diag")
   if self.verify_pd:
     return control_flow_ops.with_dependencies(
         [check_ops.assert_positive(diag)], diag)
   return diag
Example #17
0
 def _log_cdf(self, x):
   """Log-CDF at `x`, computed as `log(igamma(alpha, beta * x))`.

   `math_ops.igamma` returns the regularized incomplete gamma function, which
   is the CDF wanted here. When `validate_args` is set, `x` is first gated on
   a positivity assertion, and its dtype must match the distribution's.
   """
   positivity = [check_ops.assert_positive(x)] if self.validate_args else []
   x = control_flow_ops.with_dependencies(positivity, x)
   contrib_tensor_util.assert_same_float_dtype(tensors=[x], dtype=self.dtype)
   alpha = self.alpha
   scaled_x = self.beta * x
   return math_ops.log(math_ops.igamma(alpha, scaled_x))
Example #18
0
  def __init__(self, mu, sigma, name="Normal"):
    """Construct Normal distributions with mean and stddev `mu` and `sigma`.

    The parameters `mu` and `sigma` must be shaped in a way that supports
    broadcasting (e.g. `mu + sigma` is a valid operation).

    Args:
      mu: `float` or `double` tensor, the means of the distribution(s).
      sigma: `float` or `double` tensor, the stddevs of the distribution(s).
        sigma must contain only positive values.
      name: The name to give Ops created by the initializer.

    Raises:
      TypeError: if mu and sigma are different dtypes.
    """
    with ops.op_scope([mu, sigma], name):
      mu = ops.convert_to_tensor(mu)
      sigma = ops.convert_to_tensor(sigma)
      # Check dtypes right after conversion, as the sibling constructors do.
      # A mismatch then fails before any assertion/identity ops are added to
      # the graph, instead of after the whole distribution has been built.
      contrib_tensor_util.assert_same_float_dtype((mu, sigma))
      with ops.control_dependencies([check_ops.assert_positive(sigma)]):
        self._name = name
        # `identity` creates new ops, so the positivity assertion above is
        # attached to the stored parameters.
        self._mu = array_ops.identity(mu, name="mu")
        self._sigma = array_ops.identity(sigma, name="sigma")
        # `_ones()` is defined elsewhere on the class; its static shape is
        # used as the batch shape.
        self._batch_shape = self._ones().get_shape()
        self._event_shape = tensor_shape.TensorShape([])
Example #19
0
  def __init__(self,
               scale,
               validate_args=False,
               allow_nan_stats=True,
               name="HalfNormal"):
    """Construct HalfNormals with scale `scale`.

    Args:
      scale: Floating point tensor; the scales of the distribution(s).
        Must contain only positive values.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    # Copy the constructor arguments. In CPython before 3.13, `locals()` inside
    # a function returns the frame's own locals dict. That dict is re-synced
    # whenever `locals()` is called again or a tracer/debugger reads
    # `f_locals`. The old code also rebound `name` to the scope string, so
    # such a refresh could overwrite `parameters["name"]`. Copying, and using a
    # separate `ns` for the scope, keeps the recorded arguments intact.
    parameters = dict(locals())
    with ops.name_scope(name, values=[scale]) as ns:
      with ops.control_dependencies([check_ops.assert_positive(scale)] if
                                    validate_args else []):
        self._scale = array_ops.identity(scale, name="scale")
    super(HalfNormal, self).__init__(
        dtype=self._scale.dtype,
        reparameterization_type=distribution.FULLY_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[self._scale],
        name=ns)
Example #20
0
 def _maybe_assert_valid_sample(self, x):
   """Check `x`'s dtype and, when validating, gate it on positivity.

   The dtype of `x` must match `self.dtype`. With `validate_args` off, `x` is
   returned unchanged. Otherwise it is returned wrapped so that evaluating it
   first runs an `assert_positive` check.
   """
   check_ops.assert_same_float_dtype(tensors=[x], dtype=self.dtype)
   if self.validate_args:
     is_positive = check_ops.assert_positive(x)
     return control_flow_ops.with_dependencies([is_positive], x)
   return x
Example #21
0
 def _maybe_assert_valid_x(self, x):
   """Return `x`, gated on a strictly-increasing check when validating.

   Consecutive entries along the last axis are differenced, and the
   differences must all be positive. With `validate_args` off, `x` is returned
   untouched.
   """
   if self.validate_args:
     increments = x[..., 1:] - x[..., :-1]
     strictly_increasing = check_ops.assert_positive(
         increments,
         message="Forward transformation input must be strictly increasing.")
     return control_flow_ops.with_dependencies([strictly_increasing], x)
   return x
Example #22
0
def calculate_reshape(original_shape, new_shape, validate=False, name=None):
  """Calculates the reshaped dimensions (replacing up to one -1 in reshape).

  Args:
    original_shape: Shape being reshaped from. Only its product (the total
      element count) is used, and with `validate=True` it is asserted to be a
      vector.
    new_shape: Target shape. At most one entry may be `-1`, meaning "infer
      this dimension"; with `validate=True` it is asserted to be a vector.
    validate: Python `bool`. When true (and `new_shape` is not statically
      known), assertion ops checking the shapes are built and returned.
    name: Optional name for the enclosing name scope (default
      `"calculate_reshape"`).

  Returns:
    A tuple `(expanded_new_shape, batch_shape_static, validations)`:
      expanded_new_shape: the target shape with any `-1` replaced by the
        inferred size. If the static shape inferred from `new_shape` is fully
        defined, this is instead a numpy `int32` array of that static shape.
      batch_shape_static: the static shape inferred from `new_shape` via
        `tensor_util.constant_value_as_shape`.
      validations: list of assertion ops; empty when `validate` is false or
        when the static shape is fully defined.
  """
  batch_shape_static = tensor_util.constant_value_as_shape(new_shape)
  # Fast path: the target shape is fully known at graph-construction time.
  if batch_shape_static.is_fully_defined():
    return np.int32(batch_shape_static.as_list()), batch_shape_static, []
  with ops.name_scope(name, "calculate_reshape", [original_shape, new_shape]):
    original_size = math_ops.reduce_prod(original_shape)
    implicit_dim = math_ops.equal(new_shape, -1)
    # With exactly one -1 entry, reduce_prod(new_shape) is minus the product
    # of the known dims, so negating it yields that product. `maximum(1, ...)`
    # keeps the divisor at least 1.
    size_implicit_dim = (
        original_size // math_ops.maximum(1, -math_ops.reduce_prod(new_shape)))
    new_ndims = array_ops.shape(new_shape)
    expanded_new_shape = array_ops.where(  # Assumes exactly one `-1`.
        implicit_dim, array_ops.fill(new_ndims, size_implicit_dim), new_shape)
    validations = [] if not validate else [
        check_ops.assert_rank(
            original_shape, 1, message="Original shape must be a vector."),
        check_ops.assert_rank(
            new_shape, 1, message="New shape must be a vector."),
        check_ops.assert_less_equal(
            math_ops.count_nonzero(implicit_dim, dtype=dtypes.int32),
            1,
            message="At most one dimension can be unknown."),
        # NOTE(review): this runs on the *expanded* shape and also rejects 0
        # entries, so the ">=-1" wording understates the constraint. Confirm
        # the intended message.
        check_ops.assert_positive(
            expanded_new_shape, message="Shape elements must be >=-1."),
        check_ops.assert_equal(
            math_ops.reduce_prod(expanded_new_shape),
            original_size,
            message="Shape sizes do not match."),
    ]
    return expanded_new_shape, batch_shape_static, validations
Example #23
0
 def test_raises_when_zero(self):
   """A zero entry must make `assert_positive` fail when its dependent runs."""
   with self.test_session():
     zero_tensor = constant_op.constant([0], name="meechum")
     guard = check_ops.assert_positive(zero_tensor)
     with ops.control_dependencies([guard]):
       gated = array_ops.identity(zero_tensor)
     # The failure message is expected to mention the input's op name.
     with self.assertRaisesOpError("meechum"):
       gated.eval()
Example #24
0
def _Check3DImage(image, require_static=True):
  """Assert that we are working with a properly shaped image.

  Args:
    image: 3-D Tensor of shape [height, width, channels].
    require_static: If `True`, requires that all dimensions of `image` are
      known and non-zero.

  Raises:
    ValueError: if `image.shape` is not 3-D, if `require_static` is `True`
      and the shape is not fully defined, or if any statically known
      dimension of `image` is zero.

  Returns:
    An empty list, if `image` has fully defined dimensions. Otherwise, a list
    containing an assert op that checks at run time that every dimension of
    `image` is positive.
  """
  try:
    image_shape = image.get_shape().with_rank(3)
  except ValueError:
    raise ValueError("'image' must be three-dimensional.")
  if require_static and not image_shape.is_fully_defined():
    raise ValueError("'image' must be fully defined.")
  # Statically known zero-sized dimensions can be rejected immediately.
  if any(x == 0 for x in image_shape):
    raise ValueError("all dims of 'image.shape' must be > 0: %s" %
                     image_shape)
  if not image_shape.is_fully_defined():
    # Some dimensions are unknown: defer the positivity check to run time.
    return [check_ops.assert_positive(array_ops.shape(image),
                                      ["all dims of 'image.shape' "
                                       "must be > 0."])]
  else:
    return []
Example #25
0
  def __init__(self,
               rate,
               validate_args=False,
               allow_nan_stats=True,
               name="Poisson"):
    """Initialize a batch of Poisson distributions.

    Args:
      rate: Floating point tensor, the rate parameter of the
        distribution(s). `rate` must be positive.
      validate_args: Python `Boolean`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `Boolean`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined.  When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: `String` name prefixed to Ops created by this class.
    """
    # Copy the argument snapshot: before Python 3.13 `locals()` returns the
    # frame's live cache, which a tracer/debugger or a later `locals()` call
    # can refresh with names bound below (`ns`, `parameters` itself).
    parameters = dict(locals())
    with ops.name_scope(name, values=[rate]) as ns:
      with ops.control_dependencies([check_ops.assert_positive(rate)] if
                                    validate_args else []):
        self._rate = array_ops.identity(rate, name="rate")
    super(Poisson, self).__init__(
        dtype=self._rate.dtype,
        is_continuous=False,
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[self._rate],
        name=ns)
Example #26
0
  def __init__(self,
               loc=0.,
               scale=1.,
               validate_args=False,
               name="gumbel"):
    """Builds a `Gumbel` bijector, `Y = g(X) = exp(-exp(-(X - loc) / scale))`.

    Args:
      loc: Float-like `Tensor`, same dtype as and broadcastable with `scale`;
        the `loc` term of the formula above.
      scale: Positive float-like `Tensor`, same dtype as and broadcastable
        with `loc`; the `scale` term of the formula above.
      validate_args: Python `bool`. When `True`, `scale` is asserted to be
        positive; the flag is also forwarded to the base bijector.
      name: Python `str` name given to ops managed by this object.
    """
    self._graph_parents = []
    self._name = name
    self._validate_args = validate_args
    with self._name_scope("init", values=[loc, scale]):
      self._loc = ops.convert_to_tensor(loc, name="loc")
      self._scale = ops.convert_to_tensor(scale, name="scale")
      check_ops.assert_same_float_dtype([self._loc, self._scale])
      if validate_args:
        scale_is_positive = check_ops.assert_positive(
            self._scale, message="Argument scale was not positive")
        # Route every later use of `scale` through the positivity check.
        self._scale = control_flow_ops.with_dependencies(
            [scale_is_positive], self._scale)

    super(Gumbel, self).__init__(
        validate_args=validate_args,
        forward_min_event_ndims=0,
        name=name)
Example #27
0
  def __init__(self,
               alpha,
               beta,
               validate_args=True,
               allow_nan_stats=False,
               name="Gamma"):
    """Creates Gamma distributions from shape `alpha` and inverse scale `beta`.

    `alpha` and `beta` must broadcast against each other, i.e. `alpha + beta`
    must be a valid operation; the broadcast result fixes the batch shape.

    Args:
      alpha: Floating point tensor of shape parameters; every entry must be
        positive.
      beta: Floating point tensor of inverse scale parameters; every entry
        must be positive.
      validate_args: Python `bool`, default `True`. When `True`, `alpha > 0`
        and `beta > 0` are asserted before the parameters are used. The flag
        is also stored on the instance; it is documented as additionally
        governing `x > 0` checks in `prob(x)` and `log_prob(x)`. If it is
        `False` and the inputs are invalid, correct behavior is not guaranteed.
      allow_nan_stats: Python `bool`, default `False`. If `False`, raise an
        exception when a statistic (e.g. mean/mode) is undefined for any batch
        member; if `True`, such statistics are returned as NaN.
      name: The name to prepend to all ops created by this distribution.

    Raises:
      TypeError: if `alpha` and `beta` are different dtypes (reported by
        `assert_same_float_dtype`; NOTE(review): confirm the exact type).
    """
    self._allow_nan_stats = allow_nan_stats
    self._validate_args = validate_args
    with ops.op_scope([alpha, beta], name) as scope:
      self._name = scope
      if validate_args:
        positivity_checks = [check_ops.assert_positive(alpha),
                             check_ops.assert_positive(beta)]
      else:
        positivity_checks = []
      with ops.control_dependencies(positivity_checks):
        alpha = array_ops.identity(alpha, name="alpha")
        beta = array_ops.identity(beta, name="beta")

        contrib_tensor_util.assert_same_float_dtype((alpha, beta))
        # Adding the parameters materializes their broadcast shape.
        self._broadcast_tensor = alpha + beta

    self._get_batch_shape = self._broadcast_tensor.get_shape()
    self._get_event_shape = tensor_shape.TensorShape([])

    self._alpha = alpha
    self._beta = beta
Example #28
0
def _verify_input(tensor_list, labels, probs_list):
  """Verify that batched inputs are well-formed.

  Static shape properties are checked immediately. Properties that depend on
  runtime values are enforced by attaching assertion ops as control
  dependencies of the returned tensors.

  Args:
    tensor_list: List of data `Tensor`s, each with a leading batch dimension
      compatible with that of `labels`.
    labels: Rank-1 integer `Tensor` of class labels in `[0, num_classes)`,
      where `num_classes` is the length of the probability vectors.
    probs_list: Non-empty list of rank-1 probability `Tensor`s with fully
      defined, equal-length shapes. Each must be non-negative and sum to one
      within a tolerance of 1e-6.

  Returns:
    `(tensor_list, labels, checked_probs_list)`: the inputs, each wrapped so
    that evaluating it first runs the assertions that apply to it.

  Raises:
    ValueError: if `probs_list` is empty, if the probability vectors differ
      in length, or if a static shape check on the inputs fails.
  """
  # Tolerance for the "probabilities sum to one" check (loop invariant).
  tol = 1e-6
  checked_probs_list = []
  for probs in probs_list:
    # Since number of classes shouldn't change at runtime, probabilities shape
    # should be fully defined.
    probs.get_shape().assert_is_fully_defined()

    # Probabilities must be 1D.
    probs.get_shape().assert_has_rank(1)

    # Probabilities must be nonnegative and sum to one.
    prob_sum = math_ops.reduce_sum(probs)
    checked_probs = control_flow_ops.with_dependencies([
        check_ops.assert_non_negative(probs),
        check_ops.assert_less(prob_sum, 1.0 + tol),
        check_ops.assert_less(1.0 - tol, prob_sum)
    ], probs)
    checked_probs_list.append(checked_probs)

  if not checked_probs_list:
    # Fail with a clear message instead of an IndexError on `[0]` below.
    raise ValueError('`probs_list` must contain at least one probability '
                     'tensor.')

  # All probabilities should be the same length.
  prob_length = checked_probs_list[0].get_shape().num_elements()
  for checked_prob in checked_probs_list:
    if checked_prob.get_shape().num_elements() != prob_length:
      raise ValueError('Probability parameters must have the same length.')

  # Labels tensor should only have batch dimension.
  labels.get_shape().assert_has_rank(1)

  for tensor in tensor_list:
    # Data tensor should have a batch dimension.
    shape = tensor.get_shape().with_rank_at_least(1)

    # Data and label batch dimensions must be compatible.
    tensor_shape.dimension_at_index(shape, 0).assert_is_compatible_with(
        labels.get_shape()[0])

  # Data and labels must have the same, strictly positive batch size. Since we
  # can't assume we know the batch size at graph creation, add runtime checks.
  labels_batch_size = array_ops.shape(labels)[0]
  lbl_assert = check_ops.assert_positive(labels_batch_size)

  # Make each tensor depend on its own checks.
  labels = control_flow_ops.with_dependencies([lbl_assert], labels)
  tensor_list = [
      control_flow_ops.with_dependencies([
          lbl_assert,
          check_ops.assert_equal(array_ops.shape(x)[0], labels_batch_size)
      ], x) for x in tensor_list
  ]

  # Label's classes must be integers 0 <= x < num_classes.
  labels = control_flow_ops.with_dependencies([
      check_ops.assert_integer(labels), check_ops.assert_non_negative(labels),
      check_ops.assert_less(labels, math_ops.cast(prob_length, labels.dtype))
  ], labels)

  return tensor_list, labels, checked_probs_list
Example #29
0
def validate_init_args(
    distribution,
    batch_shape,
    validate_args,
    batch_shape_static):
  """Helper to __init__ which makes or raises assertions.

  The function checks three properties of `batch_shape`: that it is a vector,
  that its total size matches the size of `distribution.batch_shape`, and
  that its elements are positive. Each property is checked statically (and a
  `ValueError` raised) when enough is known at graph-construction time.
  Otherwise, if `validate_args` is `True`, a runtime assertion op is
  produced instead.

  Args:
    distribution: The wrapped distribution; its `batch_shape`,
      `batch_shape_tensor()` and `_graph_parents` are read.
    batch_shape: `Tensor` holding the requested batch shape.
    validate_args: Python `bool`; whether to build runtime assertions for
      properties that cannot be checked statically.
    batch_shape_static: Statically known value of `batch_shape`, or `None`
      (presumably a NumPy array, since it is compared elementwise with `< 1`).

  Returns:
    List of runtime assertion ops (possibly empty).

  Raises:
    ValueError: if a statically checkable property is violated.
  """
  scope_values = [batch_shape] + distribution._graph_parents  # pylint: disable=protected-access
  with ops.name_scope(name="validate_init_args", values=scope_values):
    assertions = []

    # 1) `batch_shape` must be rank 1.
    rank = batch_shape.shape.ndims
    if rank is not None:
      if rank != 1:
        raise ValueError("`batch_shape` must be a vector "
                         "(saw rank: {}).".format(rank))
    elif validate_args:
      assertions.append(
          check_ops.assert_rank(
              batch_shape,
              1,
              message="`batch_shape` must be a vector.",
              name="assert_batch_shape_is_vector"))

    # 2) Total sizes must agree with the wrapped distribution.
    requested_size = np.prod(batch_shape_static)
    if distribution.batch_shape.is_fully_defined():
      wrapped_size = np.prod(distribution.batch_shape).value
    else:
      wrapped_size = None

    if requested_size is not None and wrapped_size is not None:
      if requested_size != wrapped_size:
        raise ValueError("`batch_shape` size ({}) must match "
                         "`distribution.batch_shape` size ({}).".format(
                             requested_size,
                             wrapped_size))
    elif validate_args:
      assertions.append(
          check_ops.assert_equal(
              math_ops.reduce_prod(batch_shape),
              math_ops.reduce_prod(distribution.batch_shape_tensor()),
              message=("`batch_shape` size must match "
                       "`distributions.batch_shape` size."),
              name="assert_batch_size"))

    # 3) Every element of `batch_shape` must be positive.
    if batch_shape_static is not None:
      if np.any(batch_shape_static < 1):
        raise ValueError("`batch_shape` elements must be positive "
                         "(i.e., larger than zero).")
    elif validate_args:
      assertions.append(
          check_ops.assert_positive(
              batch_shape,
              message=("`batch_shape` elements must be positive "
                       "(i.e., larger than zero)."),
              name="assert_batch_shape_positive"))

    return assertions
Example #30
0
 def _check_x(self, x):
   """Converts `x` to a tensor; when validating, guards it with `0 < x < 1`.

   The result is always passed through `with_dependencies`, so an identity
   of `x` is returned even when `self.validate_args` is `False`.
   """
   x = ops.convert_to_tensor(x, name="x_before_deps")
   guards = []
   if self.validate_args:
     guards.append(check_ops.assert_positive(x))
     upper_bound = constant_op.constant(1, self.dtype)
     guards.append(check_ops.assert_less(x, upper_bound))
   return control_flow_ops.with_dependencies(guards, x)
Example #31
0
 def _maybe_assert_valid_sample(self, x):
     """Returns `x`, guarded by sample-validity assertions when validating.

     With `self.validate_args` set, `x` must be positive and its last
     dimension must sum to `1`; otherwise `x` is returned unchanged.
     """
     if self.validate_args:
         is_positive = check_ops.assert_positive(
             x, message="samples must be positive")
         sums_to_one = distribution_util.assert_close(
             array_ops.ones([], dtype=self.dtype),
             math_ops.reduce_sum(x, -1),
             message="sample last-dimension must sum to `1`")
         return control_flow_ops.with_dependencies(
             [is_positive, sums_to_one], x)
     return x
 def _maybe_assert_valid_concentration(self, concentration, validate_args):
     """Returns `concentration`, with validity checks attached if requested.

     When `validate_args` is `True`, the tensor is first passed through
     `distribution_util.embed_check_categorical_event_shape` and then made
     to depend on a positivity assertion. Otherwise it is returned as-is.
     """
     if validate_args:
         concentration = distribution_util.embed_check_categorical_event_shape(
             concentration)
         is_positive = check_ops.assert_positive(
             concentration,
             message="Concentration parameter must be positive.")
         concentration = control_flow_ops.with_dependencies(
             [is_positive], concentration)
     return concentration
    def __init__(self,
                 temperature,
                 logits=None,
                 probs=None,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="RelaxedBernoulli"):
        """Construct RelaxedBernoulli distributions.

        Args:
          temperature: An 0-D `Tensor`, representing the temperature
            of a set of RelaxedBernoulli distributions. The temperature should
            be positive.
          logits: An N-D `Tensor` representing the log-odds
            of a positive event. Each entry in the `Tensor` parametrizes
            an independent RelaxedBernoulli distribution where the probability
            of an event is sigmoid(logits). Only one of `logits` or `probs`
            should be passed in.
          probs: An N-D `Tensor` representing the probability of a positive
            event. Each entry in the `Tensor` parameterizes an independent
            Bernoulli distribution. Only one of `logits` or `probs` should be
            passed in.
          validate_args: Python `bool`, default `False`. When `True`
            distribution parameters are checked for validity despite possibly
            degrading runtime performance. When `False` invalid inputs may
            silently render incorrect outputs.
          allow_nan_stats: Python `bool`, default `True`. When `True`,
            statistics (e.g., mean, mode, variance) use the value "`NaN`" to
            indicate the result is undefined. When `False`, an exception is
            raised if one or more of the statistic's batch members are
            undefined.
          name: Python `str` name prefixed to Ops created by this class.

        Raises:
          ValueError: If both `probs` and `logits` are passed, or if neither.
        """
        # Copy the argument snapshot: before Python 3.13 `locals()` returns
        # the frame's live cache, which a tracer/debugger or a later
        # `locals()` call can refresh with names bound below (including
        # `parameters` itself).
        parameters = dict(locals())
        with ops.name_scope(name, values=[logits, probs, temperature]):
            with ops.control_dependencies(
                    [check_ops.assert_positive(temperature)]
                    if validate_args else []):
                self._temperature = array_ops.identity(temperature,
                                                       name="temperature")
            self._logits, self._probs = distribution_util.get_logits_and_probs(
                logits=logits, probs=probs, validate_args=validate_args)
            super(RelaxedBernoulli, self).__init__(
                distribution=logistic.Logistic(
                    self._logits / self._temperature,
                    1. / self._temperature,
                    validate_args=validate_args,
                    allow_nan_stats=allow_nan_stats,
                    name=name + "/Logistic"),
                bijector=Sigmoid(validate_args=validate_args),
                validate_args=validate_args,
                name=name)
        self._parameters = parameters
Example #34
0
 def _maybe_assert_valid_sample(self, x):
   """Returns `x`, guarded by `0 < x < 1` assertions when validating."""
   if self.validate_args:
     above_zero = check_ops.assert_positive(
         x, message="sample must be positive")
     below_one = check_ops.assert_less(
         x,
         array_ops.ones([], self.dtype),
         message="sample must be less than `1`.")
     return control_flow_ops.with_dependencies([above_zero, below_one], x)
   return x
Example #35
0
 def _assert_positive_definite(self):
     """Returns an op asserting that this operator is positive definite.

     The operator acts as `A x = F^H D F x`, where `D` is the diagonal matrix
     holding `self.spectrum`, so `<x, A x> = <F x, D F x>`. Because `F` is
     bijective, positive definiteness reduces to the diagonal case: the real
     part of every spectrum entry must be positive.
     """
     real_spectrum = math_ops.real(self.spectrum)
     return check_ops.assert_positive(
         real_spectrum,
         message=("Not positive definite:  "
                  "Real part of spectrum was not all positive."))
Example #36
0
 def _check_x(self, x):
     """Converts `x` to a tensor, attaching simplex checks when validating.

     With `self.validate_args` set, every entry of `x` must lie in `(0, 1)`,
     and the sum over the last axis must be close to one. The row sums and
     the constant `1` are built regardless of the flag.
     """
     x = ops.convert_to_tensor(x, name="x_before_deps")
     row_sums = math_ops.reduce_sum(x, reduction_indices=[-1])
     one = constant_op.constant(1., self.dtype)
     checks = []
     if self.validate_args:
         checks = [
             check_ops.assert_positive(x),
             check_ops.assert_less(x, one),
             _assert_close(one, row_sums),
         ]
     return control_flow_ops.with_dependencies(checks, x)
Example #37
0
    def __init__(self,
                 scale=1.,
                 concentration=1.,
                 validate_args=False,
                 name="weibull"):
        """Instantiates the `Weibull` bijector.

        Args:
          scale: Positive Float-type `Tensor` that is the same dtype and is
            broadcastable with `concentration`.
            This is `l` in `Y = g(X) = 1 - exp(-(X / l) ** k)`.
          concentration: Positive Float-type `Tensor` that is the same dtype
            and is broadcastable with `scale`.
            This is `k` in `Y = g(X) = 1 - exp(-(X / l) ** k)`.
          validate_args: Python `bool` indicating whether arguments should be
            checked for correctness. When `True`, `scale` and `concentration`
            are asserted to be positive.
          name: Python `str` name given to ops managed by this object.
        """
        self._graph_parents = []
        self._name = name
        self._validate_args = validate_args
        with self._name_scope("init", values=[scale, concentration]):
            self._scale = ops.convert_to_tensor(scale, name="scale")
            self._concentration = ops.convert_to_tensor(concentration,
                                                        name="concentration")
            check_ops.assert_same_float_dtype(
                [self._scale, self._concentration])
            if validate_args:
                self._scale = control_flow_ops.with_dependencies([
                    check_ops.assert_positive(
                        self._scale, message="Argument scale was not positive")
                ], self._scale)
                self._concentration = control_flow_ops.with_dependencies([
                    check_ops.assert_positive(
                        self._concentration,
                        message="Argument concentration was not positive")
                ], self._concentration)

        super(Weibull, self).__init__(forward_min_event_ndims=0,
                                      validate_args=validate_args,
                                      name=name)
Example #38
0
    def __init__(self,
                 alpha,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="Dirichlet"):
        """Initialize a batch of Dirichlet distributions.

        Args:
          alpha:  Positive floating point tensor with shape broadcastable to
            `[N1,..., Nm, k]` `m >= 0`.  Defines this as a batch of
            `N1 x ... x Nm` different `k` class Dirichlet distributions.
          validate_args: `Boolean`, default `False`.  Whether to assert valid
            values for parameters `alpha` and `x` in `prob` and `log_prob`.
            If `False`, correct behavior is not guaranteed.
          allow_nan_stats: `Boolean`, default `True`.  If `False`, raise an
            exception if a statistic (e.g. mean/mode/etc...) is undefined for
            any batch member.  If `True`, batch members with valid parameters
            leading to undefined statistics will return NaN for this statistic.
          name: The name to prefix Ops created by this distribution class.

        Examples:

        ```python
        # Define 1-batch of 2-class Dirichlet distributions,
        # also known as a Beta distribution.
        dist = Dirichlet([1.1, 2.0])

        # Define a 2-batch of 3-class distributions.
        dist = Dirichlet([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
        ```

        """
        # Copy the argument snapshot: before Python 3.13 `locals()` returns
        # the frame's live cache, which a tracer/debugger or a later
        # `locals()` call can refresh (re-adding `self`, adding `ns`, ...).
        parameters = dict(locals())
        parameters.pop("self")
        with ops.name_scope(name, values=[alpha]) as ns:
            alpha = ops.convert_to_tensor(alpha, name="alpha")
            with ops.control_dependencies([
                    check_ops.assert_positive(alpha),
                    check_ops.assert_rank_at_least(alpha, 1)
            ] if validate_args else []):
                self._alpha = array_ops.identity(alpha, name="alpha")
                # Sum of concentrations over the last (class) axis.
                self._alpha_sum = math_ops.reduce_sum(alpha,
                                                      reduction_indices=[-1],
                                                      keep_dims=False)
        super(Dirichlet,
              self).__init__(dtype=self._alpha.dtype,
                             validate_args=validate_args,
                             allow_nan_stats=allow_nan_stats,
                             is_continuous=True,
                             is_reparameterized=False,
                             parameters=parameters,
                             graph_parents=[self._alpha, self._alpha_sum],
                             name=ns)
Example #39
0
    def __init__(self,
                 df,
                 mu,
                 sigma,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="StudentT"):
        """Construct Student's t distributions.

        The distributions have degree of freedom `df`, mean `mu`, and scale
        `sigma`.

        The parameters `df`, `mu`, and `sigma` must be shaped in a way that
        supports broadcasting (e.g. `df + mu + sigma` is a valid operation).

        Args:
          df: Numeric `Tensor`. The degrees of freedom of the distribution(s).
            `df` must contain only positive values.
          mu: Numeric `Tensor`. The mean(s) of the distribution(s).
          sigma: Numeric `Tensor`. The scaling factor(s) for the
            distribution(s). Note that `sigma` is not technically the standard
            deviation of this distribution but has semantics more similar to
            std. deviation than variance.
          validate_args: `Boolean`, default `False`.  Whether to assert that
            `df > 0`. (`sigma` is not checked by this constructor.) If
            `validate_args` is `False` and inputs are invalid, correct
            behavior is not guaranteed.
          allow_nan_stats: `Boolean`, default `True`.  If `False`, raise an
            exception if a statistic (e.g. mean/mode/etc...) is undefined for
            any batch member.  If `True`, batch members with valid parameters
            leading to undefined statistics will return NaN for this statistic.
          name: The name to give Ops created by the initializer.

        Raises:
          TypeError: if `df`, `mu` and `sigma` are different dtypes.
        """
        # Copy the argument snapshot: before Python 3.13 `locals()` returns
        # the frame's live cache, which a tracer/debugger or a later
        # `locals()` call can refresh (re-adding `self`, adding `ns`, ...).
        parameters = dict(locals())
        parameters.pop("self")
        with ops.name_scope(name, values=[df, mu, sigma]) as ns:
            with ops.control_dependencies(
                    [check_ops.assert_positive(df)] if validate_args else []):
                self._df = array_ops.identity(df, name="df")
                self._mu = array_ops.identity(mu, name="mu")
                self._sigma = array_ops.identity(sigma, name="sigma")
                contrib_tensor_util.assert_same_float_dtype(
                    (self._df, self._mu, self._sigma))
        super(StudentT,
              self).__init__(dtype=self._sigma.dtype,
                             is_continuous=True,
                             is_reparameterized=True,
                             validate_args=validate_args,
                             allow_nan_stats=allow_nan_stats,
                             parameters=parameters,
                             graph_parents=[self._df, self._mu, self._sigma],
                             name=ns)
Example #40
0
    def __init__(self,
                 df,
                 loc,
                 scale,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="StudentT"):
        """Construct Student's t distributions.

        The distributions have degree of freedom `df`, mean `loc`, and scale
        `scale`.

        The parameters `df`, `loc`, and `scale` must be shaped in a way that
        supports broadcasting (e.g. `df + loc + scale` is a valid operation).

        Args:
          df: Floating-point `Tensor`. The degrees of freedom of the
            distribution(s). `df` must contain only positive values.
          loc: Floating-point `Tensor`. The mean(s) of the distribution(s).
          scale: Floating-point `Tensor`. The scaling factor(s) for the
            distribution(s). Note that `scale` is not technically the standard
            deviation of this distribution but has semantics more similar to
            standard deviation than variance.
          validate_args: Python `bool`, default `False`. When `True`
            distribution parameters are checked for validity despite possibly
            degrading runtime performance (this constructor asserts
            `df > 0`). When `False` invalid inputs may silently render
            incorrect outputs.
          allow_nan_stats: Python `bool`, default `True`. When `True`,
            statistics (e.g., mean, mode, variance) use the value "`NaN`" to
            indicate the result is undefined. When `False`, an exception is
            raised if one or more of the statistic's batch members are
            undefined.
          name: Python `str` name prefixed to Ops created by this class.

        Raises:
          TypeError: if loc and scale are different dtypes.
        """
        # Copy the argument snapshot: before Python 3.13 `locals()` returns
        # the frame's live cache, which a tracer/debugger or a later
        # `locals()` call can refresh with names bound below (including
        # `parameters` itself).
        parameters = dict(locals())
        with ops.name_scope(name, values=[df, loc, scale]):
            with ops.control_dependencies(
                    [check_ops.assert_positive(df)] if validate_args else []):
                self._df = array_ops.identity(df, name="df")
                self._loc = array_ops.identity(loc, name="loc")
                self._scale = array_ops.identity(scale, name="scale")
                check_ops.assert_same_float_dtype(
                    (self._df, self._loc, self._scale))
        super(StudentT, self).__init__(
            dtype=self._scale.dtype,
            reparameterization_type=distribution.NOT_REPARAMETERIZED,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            parameters=parameters,
            graph_parents=[self._df, self._loc, self._scale],
            name=name)
 def _initialize(self):
     """Runs one step of cluster-center initialization.

     The step first asserts that `self._num_remaining` is positive. It then
     adds centers via `_kmc2_multiple_centers` (when
     `self._initial_clusters == KMC2_INIT`) or via `_add_new_centers`. Once
     the count those helpers return reaches zero,
     `self._cluster_centers_initialized` is assigned `True`; otherwise a
     no-op runs.
     """
     has_remaining = check_ops.assert_positive(self._num_remaining)
     with ops.control_dependencies([has_remaining]):
         if self._initial_clusters == KMC2_INIT:
             still_remaining = self._kmc2_multiple_centers()
         else:
             still_remaining = self._add_new_centers()

         def _mark_initialized():
             return state_ops.assign(self._cluster_centers_initialized, True)

         return control_flow_ops.cond(
             math_ops.equal(still_remaining, 0), _mark_initialized,
             control_flow_ops.no_op)
    def _assert_positive_definite(self):
        """Returns an op asserting every diagonal entry has positive real part.

        For a diagonal operator this is exactly the positive-definiteness
        condition; the error message differs for complex and real dtypes.
        """
        message = (
            "Diagonal operator had diagonal entries with non-positive real part, "
            "thus was not positive definite." if self.dtype.is_complex else
            "Real diagonal operator had non-positive diagonal entries, "
            "thus was not positive definite.")
        real_diag = math_ops.real(self._diag)
        return check_ops.assert_positive(real_diag, message=message)
Example #43
0
    def __init__(self,
                 alpha,
                 validate_args=True,
                 allow_nan_stats=False,
                 name="Dirichlet"):
        """Creates a batch of Dirichlet distributions from concentrations.

        Args:
          alpha: Positive floating point tensor with shape broadcastable to
            `[N1,..., Nm, k]`, `m >= 0`; describes `N1 x ... x Nm` separate
            `k`-class Dirichlet distributions.
          validate_args: Python `bool`, default `True`. When `True`, `alpha` is
            asserted to be positive and of rank at least 1 here. The flag is
            also stored on the instance and is documented as covering `x` in
            `prob` / `log_prob`. With `False`, correct behavior is not
            guaranteed.
          allow_nan_stats: Python `bool`, default `False`. If `False`, raise
            an exception when a statistic (e.g. mean/mode) is undefined for
            any batch member; if `True`, such statistics are returned as NaN.
          name: The name to prefix Ops created by this distribution class.

        Examples:

        ```python
        # A single 2-class Dirichlet, i.e. a Beta distribution.
        dist = Dirichlet([1.1, 2.0])

        # A batch of two 3-class distributions.
        dist = Dirichlet([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
        ```
        """
        with ops.op_scope([alpha], name):
            raw_alpha = ops.convert_to_tensor(alpha, name="alpha_before_deps")
            if validate_args:
                alpha_checks = [check_ops.assert_positive(raw_alpha),
                                check_ops.assert_rank_at_least(raw_alpha, 1)]
            else:
                alpha_checks = []
            with ops.control_dependencies(alpha_checks):
                alpha = array_ops.identity(raw_alpha, name="alpha")

            self._alpha = alpha
            self._name = name

            # Sum of concentrations over the last axis; presumably consumed by
            # the mean/mode/variance/entropy methods (not visible here).
            self._alpha_0 = math_ops.reduce_sum(
                alpha, reduction_indices=[-1], keep_dims=False)

            self._get_batch_shape = self._alpha_0.get_shape()
            # Event shape is the (possibly unknown) size of the last axis.
            self._get_event_shape = (
                self._alpha.get_shape().with_rank_at_least(1)[-1:])
            self._validate_args = validate_args
            self._allow_nan_stats = allow_nan_stats
Example #44
0
    def __init__(self, alpha, beta, name="Gamma"):
        """Builds Gamma distributions from shape `alpha` and inverse scale `beta`.

        `alpha` and `beta` must broadcast together (`alpha + beta` must be a
        valid operation). Both are always asserted to be positive.

        Args:
          alpha: `float` or `double` tensor of shape parameters; all entries
            must be positive.
          beta: `float` or `double` tensor of inverse scale parameters; all
            entries must be positive.
          name: The name to prepend to all ops created by this distribution.

        Raises:
          TypeError: if `alpha` and `beta` are different dtypes (reported by
            `assert_same_float_dtype`; NOTE(review): confirm the exact type).
        """
        with ops.op_scope([alpha, beta], name):
            positivity_checks = [check_ops.assert_positive(alpha),
                                 check_ops.assert_positive(beta)]
            with ops.control_dependencies(positivity_checks):
                alpha = array_ops.identity(alpha, name="alpha")
                beta = array_ops.identity(beta, name="beta")

                contrib_tensor_util.assert_same_float_dtype((alpha, beta))

                # Moments are built once, each inside its own name scope.
                with ops.name_scope("mean"):
                    self._mean = alpha / beta
                with ops.name_scope("variance"):
                    self._variance = alpha / math_ops.square(beta)

        # Scalar events; the batch shape is the broadcast shape of the mean.
        self._get_event_shape = tensor_shape.TensorShape([])
        self._get_batch_shape = self._mean.get_shape()

        self._alpha = alpha
        self._beta = beta
        self._name = name
Example #45
0
    def __init__(self,
                 sigma,
                 alpha,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="Pareto"):
        """Construct Pareto distributions (Type I).

        Args:
          sigma: Floating point tensor, the scale parameter of the
            distribution(s). `sigma` must be strictly positive.
          alpha: Floating point tensor, the shape parameter of the
            distribution(s). `alpha` must be strictly positive.
          validate_args: `Boolean`, default `False`.  Whether to assert that
            `sigma > 0` and `alpha > 0`. If `validate_args` is `False` and the
            inputs are invalid, correct behavior is not guaranteed.
          allow_nan_stats: `Boolean`, default `True`.  If `False`, raise an
            exception if a statistic (e.g. mean/mode/etc...) is undefined for
            any batch member.  If `True`, batch members with valid parameters
            leading to undefined statistics will return NaN for this statistic.
          name: A name for this distribution.
        """
        # Copy the argument snapshot: before Python 3.13 `locals()` returns
        # the frame's live cache, which a tracer/debugger or a later
        # `locals()` call can refresh (re-adding `self`, adding `ns`, ...).
        parameters = dict(locals())
        parameters.pop("self")
        with ops.name_scope(name, values=[sigma, alpha]) as ns:
            with ops.control_dependencies([
                    check_ops.assert_positive(sigma),
                    check_ops.assert_positive(alpha)
            ] if validate_args else []):
                # Name the identity ops after the parameters they carry.
                self._sigma = array_ops.identity(sigma, name="sigma")
                self._alpha = array_ops.identity(alpha, name="alpha")
                contrib_tensor_util.assert_same_float_dtype(
                    (self._sigma, self._alpha))
        super(Pareto, self).__init__(dtype=self._sigma.dtype,
                                     is_continuous=True,
                                     is_reparameterized=False,
                                     validate_args=validate_args,
                                     allow_nan_stats=allow_nan_stats,
                                     parameters=parameters,
                                     graph_parents=[self._sigma, self._alpha],
                                     name=ns)
Example #46
0
    def __init__(self,
                 dist,
                 pi,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="ZeroInflated"):
        """Construct zero-inflated distributions.

        Args:
          dist: A `Distribution` instance; the distribution being
            zero-inflated.
          pi: Floating point tensor, the zero-inflation parameter of the
            distribution(s). `pi` must be in the interval [0, 1].
          validate_args: `Boolean`, default `False`.  Whether to assert that
            `0 <= pi <= 1`. If validate_args is `False`, `pi` is not checked
            and computations might return `NaN`.
          allow_nan_stats: `Boolean`, default `True`.  If `False`, raise an
            exception if a statistic (e.g. mean/mode/etc...) is undefined for
            any batch member.  If `True`, batch members with valid parameters
            leading to undefined statistics will return NaN for this statistic.
          name: A name for this distribution.

        Raises:
          TypeError: if `dist` is not a `Distribution` instance.
        """
        # Snapshot the constructor arguments before `pi` is rebound below.
        parameters = dict(locals())
        parameters.pop("self")

        if not isinstance(dist, distribution.Distribution):
            raise TypeError("dist must be Distribution instance"
                            " but saw: %s" % dist)
        is_continuous = dist.is_continuous

        static_event_shape = dist.get_event_shape()

        with ops.name_scope(name, values=[pi] + dist._graph_parents) as ns:
            # Convert first so Python scalars / numpy arrays are accepted and
            # so the bound check below can use `pi`'s own dtype.
            pi = ops.convert_to_tensor(pi)
            with ops.control_dependencies([
                    check_ops.assert_non_negative(pi),
                    check_ops.assert_less_equal(
                        pi, array_ops.ones([], dtype=pi.dtype)),
            ] if validate_args else []):
                self._dist = dist
                self._pi = array_ops.identity(pi, name="pi")
                self._static_event_shape = static_event_shape
                self._static_batch_shape = self._pi.get_shape()
        graph_parents = [self._pi] + self._dist._graph_parents

        super(ZeroInflated, self).__init__(dtype=self._pi.dtype,
                                           is_continuous=is_continuous,
                                           is_reparameterized=False,
                                           validate_args=validate_args,
                                           allow_nan_stats=allow_nan_stats,
                                           parameters=parameters,
                                           graph_parents=graph_parents,
                                           name=ns)
Example #47
0
    def __init__(self,
                 r,
                 p,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="NegativeBinomial"):
        """Construct negative binomial distributions.

        Args:
          r: Floating point tensor, the number of failures before stop
            parameter of the distribution(s). `r` must be positive.
          p: Floating point tensor, the success probability parameter of the
            distribution(s). `p` must be in the interval [0, 1].
          validate_args: `Boolean`, default `False`.  Whether to assert that
            `r > 0` and `p > 0` as well as inputs to pmf computations are
            non-negative integers. If validate_args is `False`, then `pmf`
            computations might return `NaN`, but can be evaluated at any real
            value.
          allow_nan_stats: `Boolean`, default `True`.  If `False`, raise an
            exception if a statistic (e.g. mean/mode/etc...) is undefined for
            any batch member.  If `True`, batch members with valid parameters
            leading to undefined statistics will return NaN for this statistic.
          name: A name for this distribution.

        Raises:
          ValueError: if `r` and `p` are not of the same floating point dtype
            (raised by `assert_same_float_dtype`).
        """
        parameters = dict(locals())
        parameters.pop("self")
        with ops.name_scope(name, values=[r, p]) as ns:
            with ops.control_dependencies(
                [check_ops.assert_positive(r),
                 check_ops.assert_positive(p)] if validate_args else []):
                self._r = array_ops.identity(r, name="r")
                self._p = array_ops.identity(p, name="p")
                contrib_tensor_util.assert_same_float_dtype((self._r, self._p))
        super(NegativeBinomial,
              self).__init__(dtype=self._r.dtype,
                             # Counts of failures: a discrete distribution.
                             is_continuous=False,
                             is_reparameterized=False,
                             validate_args=validate_args,
                             allow_nan_stats=allow_nan_stats,
                             parameters=parameters,
                             graph_parents=[self._r, self._p],
                             name=ns)
Example #48
0
 def _assert_valid_sample(self, x):
     """Check x for proper shape, values, then return tensor version."""
     if not self.validate_args: return x
     return control_flow_ops.with_dependencies([
         check_ops.assert_positive(
             x,
             message="Negative events lie outside Beta distribution support."
         ),
         check_ops.assert_less(
             x,
             array_ops.ones((), self.dtype),
             message="Event>=1 lies outside Beta distribution support."),
     ], x)
 def _maybe_attach_assertion(x):
   """Return `x`, guarded by a value assertion when validation is enabled.

   `validate_args` and `assert_positive` are free variables resolved from an
   enclosing scope that is not visible here.

   Args:
     x: tensor (e.g. a diagonal part) to guard.

   Returns:
     `x` unchanged if `validate_args` is falsy; otherwise `x` with a control
     dependency asserting positivity (if `assert_positive`) or non-zero-ness.
   """
   if not validate_args:
     return x
   if assert_positive:
     guard = check_ops.assert_positive(
         x, message="diagonal part must be positive")
   else:
     guard = check_ops.assert_none_equal(
         x,
         array_ops.zeros([], x.dtype),
         message="diagonal part must be non-zero")
   return control_flow_ops.with_dependencies([guard], x)
Example #50
0
def make_runtime_assertions(distribution, batch_shape, validate_args,
                            batch_shape_static):
    """Helper to __init__ which makes or raises assertions.

    Each property of `batch_shape` is checked statically when enough shape
    information is available (raising `ValueError` on violation). Otherwise,
    if `validate_args` is True, a runtime assertion op is emitted instead.

    Args:
      distribution: the wrapped distribution; its static `batch_shape` and
        `batch_shape_tensor()` are compared against `batch_shape`.
      batch_shape: `Tensor` holding the requested batch shape; must be a
        vector.
      validate_args: Python `bool`; whether to emit runtime assertions for
        properties that cannot be verified statically.
      batch_shape_static: the statically-known value of `batch_shape`
        (array-like of ints), or `None` if it is not known.

    Returns:
      runtime_assertions: list of assertion ops. It is empty when everything
      was verified statically or `validate_args` is False.

    Raises:
      ValueError: if `batch_shape` is statically known not to be a vector,
        to have a different number of elements than
        `distribution.batch_shape`, or to contain non-positive entries.
    """
    runtime_assertions = []

    batch_shape_ndims = batch_shape.shape.ndims
    if batch_shape_ndims is not None:
        if batch_shape_ndims != 1:
            raise ValueError("`batch_shape` must be a vector "
                             "(saw rank: {}).".format(batch_shape_ndims))
    elif validate_args:
        runtime_assertions += [
            check_ops.assert_rank(batch_shape,
                                  1,
                                  message="`batch_shape` must be a vector.",
                                  name="assert_batch_shape_is_vector"),
        ]

    if batch_shape_static is None:
        batch_size_static = None
    else:
        # Accept any array-like (e.g. a list) so `< 1` below is elementwise.
        batch_shape_static = np.asarray(batch_shape_static)
        batch_size_static = np.prod(batch_shape_static)
    # `num_elements()` is None unless the shape is fully defined. Unlike
    # `np.prod(shape).value`, it does not depend on `Dimension` objects.
    dist_batch_size_static = distribution.batch_shape.num_elements()

    if batch_size_static is not None and dist_batch_size_static is not None:
        if batch_size_static != dist_batch_size_static:
            raise ValueError("`batch_shape` size ({}) must match "
                             "`distribution.batch_shape` size ({}).".format(
                                 batch_size_static, dist_batch_size_static))
    elif validate_args:
        runtime_assertions += [
            check_ops.assert_equal(
                math_ops.reduce_prod(batch_shape),
                math_ops.reduce_prod(distribution.batch_shape_tensor()),
                message=("`batch_shape` size must match "
                         "`distributions.batch_shape` size."),
                name="assert_batch_size"),
        ]

    if batch_shape_static is not None:
        if np.any(batch_shape_static < 1):
            raise ValueError("`batch_shape` elements must be positive "
                             "(i.e., larger than zero).")
    elif validate_args:
        runtime_assertions += [
            check_ops.assert_positive(
                batch_shape,
                message=("`batch_shape` elements must be positive "
                         "(i.e., larger than zero)."),
                name="assert_batch_shape_positive")
        ]

    return runtime_assertions
Example #51
0
    def __init__(self,
                 df,
                 scale,
                 cholesky_input_output_matrices=False,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="WishartFull"):
        """Construct Wishart distributions from a full scale matrix.

        Args:
          df: `float` or `double` `Tensor`. Degrees of freedom, must be greater
            than or equal to dimension of the scale matrix.
          scale: `float` or `double` `Tensor`. The symmetric positive definite
            scale matrix of the distribution. It is Cholesky-factored here and
            passed on as a lower-triangular linear operator.
          cholesky_input_output_matrices: Python `bool`. Any function whose
            input or output is a matrix assumes the input is Cholesky and
            returns a Cholesky factored matrix. For example, `log_prob` takes a
            Cholesky factor and `sample_n` returns one when
            `cholesky_input_output_matrices=True`.
          validate_args: Python `bool`, default `False`. When `True`, `scale`
            is asserted symmetric and its Cholesky diagonal positive, possibly
            degrading runtime performance. When `False` invalid inputs may
            silently render incorrect outputs.
          allow_nan_stats: Python `bool`, default `True`. When `True`,
            statistics (e.g., mean, mode, variance) use the value "`NaN`" to
            indicate the result is undefined. When `False`, an exception is
            raised if one or more of the statistic's batch members are
            undefined.
          name: Python `str` name prefixed to Ops created by this class.
        """
        # Must stay the first statement so only constructor args are captured.
        parameters = dict(locals())
        with ops.name_scope(name) as name:
            with ops.name_scope("init", values=[scale]):
                scale = ops.convert_to_tensor(scale)
                if validate_args:
                    scale = distribution_util.assert_symmetric(scale)
                chol = linalg_ops.cholesky(scale)
                diag_checks = ([
                    check_ops.assert_positive(array_ops.matrix_diag_part(chol))
                ] if validate_args else [])
                chol = control_flow_ops.with_dependencies(diag_checks, chol)
        scale_operator = linalg.LinearOperatorLowerTriangular(
            tril=chol,
            is_non_singular=True,
            is_positive_definite=True,
            is_square=True)
        super(WishartFull, self).__init__(
            df=df,
            scale_operator=scale_operator,
            cholesky_input_output_matrices=cholesky_input_output_matrices,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            name=name)
        self._parameters = parameters
Example #52
0
  def __init__(self, alpha, beta, strict=True, name="Gamma"):
    """Construct Gamma distributions parameterized by `alpha` and `beta`.

    `alpha` and `beta` must be broadcast-compatible (e.g. `alpha + beta` is a
    valid operation); their broadcast sum fixes the batch shape.

    Args:
      alpha: `float` or `double` tensor, the shape params of the
        distribution(s). Must contain only positive values.
      beta: `float` or `double` tensor, the inverse scale params of the
        distribution(s). Must contain only positive values.
      strict: Whether to assert that `alpha > 0` and `beta > 0` (and, per the
        original contract, `x > 0` in `pdf(x)` / `log_pdf(x)`). If `strict`
        is False and the inputs are invalid, correct behavior is not
        guaranteed.
      name: The name to prepend to all ops created by this distribution.

    Raises:
      TypeError: if `alpha` and `beta` are different dtypes.
        NOTE(review): `assert_same_float_dtype` may raise `ValueError`
        instead; confirm.
    """
    self._strict = strict
    with ops.op_scope([alpha, beta], name) as scope:
      self._name = scope
      positivity_checks = ([check_ops.assert_positive(alpha),
                            check_ops.assert_positive(beta)]
                           if strict else [])
      with ops.control_dependencies(positivity_checks):
        alpha, beta = (array_ops.identity(alpha, name="alpha"),
                       array_ops.identity(beta, name="beta"))
        contrib_tensor_util.assert_same_float_dtype((alpha, beta))
        # Only used for its broadcast (batch) shape.
        self._broadcast_tensor = alpha + beta

    self._get_batch_shape = self._broadcast_tensor.get_shape()
    self._get_event_shape = tensor_shape.TensorShape([])

    self._alpha = alpha
    self._beta = beta
    def __init__(self,
                 loc,
                 scale,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="Hyperbolic-wrapped-norm"):
        """Construct hyperbolic wrapped normal distributions with mean `loc`
        and scale `scale`.

        `loc` holds points of an (n+1)-dimensional ambient space along its
        last axis, and `self._dim` is that size minus one. This is presumably
        the hyperboloid (Lorentz) model; TODO confirm against the rest of the
        class. Under that model, with `validate_args=True`, `loc` is checked to
        lie on the upper sheet `x_0 = sqrt(1 + x_1^2 + ... + x_n^2)`.

        Args:
          loc: Floating point tensor; the mean of the distribution(s), one
            ambient-space point per entry along the last axis.
          scale: Floating point tensor; the concentration of the
            distribution(s). Must contain only positive values.
          validate_args: Python `bool`, default `False`. When `True`
            distribution parameters are checked for validity despite possibly
            degrading runtime performance. When `False` invalid inputs may
            silently render incorrect outputs.
          allow_nan_stats: Python `bool`, default `True`. When `True`,
            statistics (e.g., mean, mode, variance) use the value "`NaN`" to
            indicate the result is undefined. When `False`, an exception is
            raised if one or more of the statistic's batch members are
            undefined.
          name: Python `str` name prefixed to Ops created by this class.

        Raises:
          ValueError: if `loc` and `scale` have different `dtype` (raised by
            `check_ops.assert_same_float_dtype`).
        """
        # Snapshot the constructor arguments before `loc` is rebound below.
        parameters = dict(locals())
        with ops.name_scope(name, values=[loc, scale]):
            loc = tf.convert_to_tensor(loc)
            if validate_args:
                # Upper sheet of the hyperboloid: <x, x>_L = -1 with x_0 > 0,
                # i.e. x_0 = sqrt(1 + ||x_{1:}||^2). The relative tolerance
                # keeps the check meaningful for points far from the origin.
                # (A unit *Euclidean* norm is the sphere constraint and does
                # not hold for hyperboloid points.)
                spatial_sq_norm = tf.reduce_sum(tf.square(loc[..., 1:]),
                                                axis=-1)
                assertions = [
                    check_ops.assert_positive(scale),
                    check_ops.assert_near(loc[..., 0],
                                          tf.sqrt(1. + spatial_sq_norm),
                                          atol=1e-5,
                                          rtol=1e-5),
                ]
            else:
                assertions = []
            with ops.control_dependencies(assertions):
                self._loc = array_ops.identity(loc, name="loc")
                self._scale = array_ops.identity(scale, name="scale")
                check_ops.assert_same_float_dtype([self._loc, self._scale])

        super(HyperbolicWrappedNorm, self).__init__(
            dtype=self._scale.dtype,
            reparameterization_type=distributions.FULLY_REPARAMETERIZED,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            parameters=parameters,
            graph_parents=[self._loc, self._scale],
            name=name)

        self._base_dist = tfp.distributions.Normal(
            loc=tf.zeros_like(self._scale), scale=self._scale)
        # Ambient coordinates live on the last axis (the same axis the
        # validation above reduces over).
        self._dim = tf.shape(self._loc)[-1] - 1
Example #54
0
def maybe_mask_score(score, memory_sequence_length, score_mask_value):
    """Replace `score` entries beyond each sequence's length with a mask value.

    Args:
      score: scores tensor. Axis 1 is treated as the time axis (it sizes the
        mask); presumably `(batch, max_time)`, TODO confirm.
      memory_sequence_length: per-example lengths, or `None` to skip masking.
        All values are asserted to be strictly positive.
      score_mask_value: value written where the mask is False.

    Returns:
      `score` unchanged if `memory_sequence_length` is None; otherwise
      `score` where the sequence mask is True and `score_mask_value` elsewhere.
    """
    if memory_sequence_length is not None:
        # (sic) the runtime message text is kept byte-for-byte.
        lengths_positive = check_ops.assert_positive(
            memory_sequence_length,
            message=(
                "All values in memory_sequence_length must greater than zero."))
        with ops.control_dependencies([lengths_positive]):
            max_time = array_ops.shape(score)[1]
            keep = array_ops.sequence_mask(memory_sequence_length,
                                           maxlen=max_time)
            fill = score_mask_value * array_ops.ones_like(score)
            return array_ops.where(keep, score, fill)
    return score
Example #55
0
def _verify_input(data, labels, probs_list):
    """Verify that batched inputs are well-formed.

    Args:
      data: `Tensor` whose first dimension is the batch dimension.
      labels: rank-1 `Tensor` of class labels, one per batch element.
      probs_list: non-empty iterable of probability vectors (array-like).
        Each must sum to one, and all must have the same shape. Their length
        is the number of classes.

    Returns:
      A tuple `(data, labels, checked_probs_list)`:
      - `data` with runtime checks that the batch size is positive and
        matches the labels' batch size.
      - `labels` with runtime checks that it is an integer in
        `[0, num_classes)`.
      - `checked_probs_list`, the probabilities as numpy arrays.

    Raises:
      ValueError: if `probs_list` is empty, a probability vector cannot be
        converted to a non-object numpy array or does not sum to one, or the
        vectors differ in shape. Static shape mismatches of `data`/`labels`
        raise from the `TensorShape` checks.
    """
    checked_probs_list = []
    for probs in probs_list:
        # Probabilities must be able to be converted to non-object numpy array.
        np_probs = np.asarray(probs)
        if np_probs.dtype == np.dtype('object'):
            raise ValueError(
                'Probabilities must be able to be converted to a numpy '
                'array.')
        checked_probs_list.append(np_probs)

        # Probabilities must sum to one.
        # TODO(joelshor): Investigate whether logits should be passed instead of
        # probs.
        if not np.isclose(np.sum(probs), 1.0):
            raise ValueError('Probabilities must sum to one.')

    # Previously an empty list surfaced as an IndexError below.
    if not checked_probs_list:
        raise ValueError('At least one probability vector must be provided.')

    # All probabilities should be the same length.
    if not np.array_equiv([p.shape for p in checked_probs_list],
                          checked_probs_list[0].shape):
        raise ValueError('Probability parameters must have the same length.')

    # All vectors share one shape, so the first one defines the class count
    # (rather than relying on the loop variable leaking out of the loop).
    num_classes = len(checked_probs_list[0])

    # Labels tensor should only have batch dimension.
    labels.get_shape().assert_has_rank(1)

    # Data tensor should have a batch dimension.
    data_shape = data.get_shape().with_rank_at_least(1)

    # Data and label batch dimensions must be compatible.
    data_shape[0].assert_is_compatible_with(labels.get_shape()[0])

    # Data and labels must have the same, strictly positive batch size. Since we
    # can't assume we know the batch size at graph creation, add runtime checks.
    data_batch_size = array_ops.shape(data)[0]
    labels_batch_size = array_ops.shape(labels)[0]

    data = control_flow_ops.with_dependencies([
        check_ops.assert_positive(data_batch_size),
        check_ops.assert_equal(data_batch_size, labels_batch_size)
    ], data)

    # Label's classes must be integers 0 <= x < num_classes.
    labels = control_flow_ops.with_dependencies([
        check_ops.assert_integer(labels),
        check_ops.assert_non_negative(labels),
        check_ops.assert_less(labels, math_ops.cast(num_classes, labels.dtype))
    ], labels)

    return data, labels, checked_probs_list
Example #56
0
    def _check_scale(self, scale, dtype):
        """Turn the init arg `scale` into a tensor, validating it if asked.

        Args:
          scale: the user-supplied scale, or `None`.
          dtype: dtype for the resulting tensor.

        Returns:
          A constant `1.0` of `dtype` when `scale` is `None`. Otherwise
          `scale` converted to a tensor. When `self._verify_pd` is truthy, that
          tensor is additionally asserted to be a positive scalar.
        """
        if scale is None:
            return constant_op.constant(1.0, dtype=dtype)

        scale_tensor = ops.convert_to_tensor(scale, dtype=dtype, name="scale")
        if self._verify_pd:
            # Require a rank-0, strictly positive tensor.
            scale_tensor = contrib_tensor_util.assert_scalar(scale_tensor)
            return control_flow_ops.with_dependencies(
                [check_ops.assert_positive(scale_tensor)], scale_tensor)
        return scale_tensor
 def _maybe_attach_assertion(x):
     """Return `x`, guarded by a value assertion when validation is enabled.

     `validate_args` and `assert_positive` are free variables resolved from an
     enclosing scope that is not visible here.

     Args:
       x: tensor (e.g. a diagonal part) to guard.

     Returns:
       `x` unchanged if `validate_args` is falsy; otherwise `x` with a control
       dependency asserting positivity (if `assert_positive`) or `|x| > 0`.
     """
     if not validate_args:
         return x
     if assert_positive:
         guard = check_ops.assert_positive(
             x, message="diagonal part must be positive")
     else:
         # TODO(b/35157376): Use `assert_none_equal` once it exists.
         # Non-zero is expressed as |x| > 0.
         guard = check_ops.assert_greater(
             math_ops.abs(x),
             array_ops.zeros([], x.dtype),
             message="diagonal part must be non-zero")
     return control_flow_ops.with_dependencies([guard], x)
Example #58
0
    def __init__(self, df, mu, sigma, strict=True, name="StudentT"):
        """Construct Student's t distributions.

        Each distribution has `df` degrees of freedom, mean `mu`, and scale
        `sigma`. The three parameters must be broadcast-compatible (e.g.
        `df + mu + sigma` is a valid operation).

        Args:
          df: `float` or `double` tensor, the degrees of freedom of the
            distribution(s). `df` must contain only positive values.
          mu: `float` or `double` tensor, the means of the distribution(s).
          sigma: `float` or `double` tensor, the scaling factor for the
            distribution(s). `sigma` must contain only positive values.
            Note that `sigma` is not the standard deviation of this
            distribution.
          strict: Whether to assert that `df > 0, sigma > 0`. If `strict` is
            False and inputs are invalid, correct behavior is not guaranteed.
          name: The name to give Ops created by the initializer.

        Raises:
          TypeError: if `df`, `mu` and `sigma` are not all the same float
            dtype. NOTE(review): `assert_same_float_dtype` may raise
            `ValueError` instead; confirm.
        """
        super(StudentT, self).__init__()
        self._strict = strict
        with ops.op_scope([df, mu, sigma], name) as scope:
            positivity_checks = ([check_ops.assert_positive(df),
                                  check_ops.assert_positive(sigma)]
                                 if strict else [])
            with ops.control_dependencies(positivity_checks):
                self._df, self._mu, self._sigma = (
                    ops.convert_to_tensor(df, name="df"),
                    ops.convert_to_tensor(mu, name="mu"),
                    ops.convert_to_tensor(sigma, name="sigma"))
                contrib_tensor_util.assert_same_float_dtype(
                    (self._df, self._mu, self._sigma))
            self._name = scope
            self._get_batch_shape = self._ones().get_shape()
            self._get_event_shape = tensor_shape.TensorShape([])
    def __init__(self,
                 logits=None,
                 probs=None,
                 validate_args=True,
                 allow_nan_stats=False,
                 name="Geometric"):
        """Construct Geometric distributions.

        Args:
          logits: Floating-point `Tensor` with shape `[B1, ..., Bb]` where
            `b >= 0` indicates the number of batch dimensions. Each entry
            represents logits for the probability of success for independent
            Geometric distributions and must be in the range `(-inf, inf]`.
            Only one of `logits` or `probs` should be specified.
          probs: Positive floating-point `Tensor` with shape `[B1, ..., Bb]`
            where `b >= 0` indicates the number of batch dimensions. Each entry
            represents the probability of success for independent Geometric
            distributions and must be in the range `(0, 1]`. Only one of
            `logits` or `probs` should be specified.
          validate_args: Python `bool`, default `True`. When `True`
            distribution parameters are checked for validity (including that
            `probs` is strictly positive) despite possibly degrading runtime
            performance. When `False` invalid inputs may silently render
            incorrect outputs.
          allow_nan_stats: Python `bool`, default `False`. When `True`,
            statistics (e.g., mean, mode, variance) use the value "`NaN`" to
            indicate the result is undefined. When `False`, an exception is
            raised if one or more of the statistic's batch members are
            undefined.
          name: Python `str` name prefixed to Ops created by this class.
        """

        parameters = locals()
        with ops.name_scope(name, values=[logits, probs]) as ns:
            self._logits, self._probs = distribution_util.get_logits_and_probs(
                logits, probs, validate_args=validate_args, name=name)

            # `probs == 0` is excluded: the documented range is (0, 1].
            with ops.control_dependencies(
                [check_ops.assert_positive(self._probs
                                           )] if validate_args else []):
                self._probs = array_ops.identity(self._probs, name="probs")

        super(Geometric, self).__init__(
            dtype=self._probs.dtype,
            is_continuous=False,
            reparameterization_type=distribution.NOT_REPARAMETERIZED,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            parameters=parameters,
            graph_parents=[self._probs, self._logits],
            name=ns)
Example #60
0
 def _maybe_assert_valid_concentration(self, concentration, validate_args):
   """Checks the validity of the concentration parameter."""
   if not validate_args:
     return concentration
   return control_flow_ops.with_dependencies([
       check_ops.assert_positive(
           concentration,
           message="Concentration parameter must be positive."),
       check_ops.assert_rank_at_least(
           concentration, 1,
           message="Concentration parameter must have >=1 dimensions."),
       check_ops.assert_less(
           1, array_ops.shape(concentration)[-1],
           message="Concentration parameter must have event_size >= 2."),
   ], concentration)