Example #1
    def testTooSmall(self):
        with self.test_session():
            with self.assertRaises(ValueError):
                param = array_ops.ones([1], dtype=np.float16)
                checked_param = distribution_util.embed_check_categorical_event_shape(
                    param)

            with self.assertRaisesOpError("must have at least 2 events"):
                param = array_ops.placeholder(dtype=dtypes.float16)
                checked_param = distribution_util.embed_check_categorical_event_shape(
                    param)
                checked_param.eval(feed_dict={param: np.ones([1])})
Example #2
    def testTooLarge(self):
        with self.test_session():
            with self.assertRaises(ValueError):
                param = array_ops.ones([int(2**11 + 1)], dtype=dtypes.float16)
                checked_param = distribution_util.embed_check_categorical_event_shape(
                    param)

            with self.assertRaisesOpError(
                    "Number of classes exceeds `dtype` precision"):
                param = array_ops.placeholder(dtype=dtypes.float16)
                checked_param = distribution_util.embed_check_categorical_event_shape(
                    param)
                checked_param.eval(
                    feed_dict={param: np.ones([int(2**11 + 1)])})
Example #3
def _maybe_assert_valid_concentration(self, concentration, validate_args):
  """Checks the validity of the concentration parameter."""
  if not validate_args:
    return concentration
  concentration = distribution_util.embed_check_categorical_event_shape(
      concentration)
  return control_flow_ops.with_dependencies([
      check_ops.assert_positive(
          concentration,
          message="Concentration parameter must be positive."),
  ], concentration)
Example #4
def _maybe_assert_valid_concentration(self, concentration, validate_args):
    """Checks the validity of the concentration parameter."""
    if not validate_args:
        return concentration
    concentration = distribution_util.embed_check_categorical_event_shape(
        concentration)
    return control_flow_ops.with_dependencies([
        check_ops.assert_positive(
            concentration,
            message="Concentration parameter must be positive."),
    ], concentration)
Example #5
    def __init__(self,
                 logits=None,
                 probs=None,
                 dtype=dtypes.int32,
                 validate_args=False,
                 allow_nan_stats=True,
                 name="Categorical"):
        """Initialize Categorical distributions using class log-probabilities.

    Args:
      logits: An N-D `Tensor`, `N >= 1`, representing the log probabilities
        of a set of Categorical distributions. The first `N - 1` dimensions
        index into a batch of independent distributions and the last dimension
        represents a vector of logits for each class. Only one of `logits` or
        `probs` should be passed in.
      probs: An N-D `Tensor`, `N >= 1`, representing the probabilities
        of a set of Categorical distributions. The first `N - 1` dimensions
        index into a batch of independent distributions and the last dimension
        represents a vector of probabilities for each class. Only one of
        `logits` or `probs` should be passed in.
      dtype: The type of the event samples (default: int32).
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
        parameters = dict(locals())
        with ops.name_scope(name, values=[logits, probs]) as name:
            self._logits, self._probs = distribution_util.get_logits_and_probs(
                logits=logits,
                probs=probs,
                validate_args=validate_args,
                multidimensional=True,
                name=name)

            if validate_args:
                self._logits = distribution_util.embed_check_categorical_event_shape(
                    self._logits)

            logits_shape_static = self._logits.get_shape().with_rank_at_least(
                1)
            if logits_shape_static.ndims is not None:
                self._batch_rank = ops.convert_to_tensor(
                    logits_shape_static.ndims - 1,
                    dtype=dtypes.int32,
                    name="batch_rank")
            else:
                with ops.name_scope(name="batch_rank"):
                    self._batch_rank = array_ops.rank(self._logits) - 1

            logits_shape = array_ops.shape(self._logits, name="logits_shape")
            if logits_shape_static[-1].value is not None:
                self._event_size = ops.convert_to_tensor(
                    logits_shape_static[-1].value,
                    dtype=dtypes.int32,
                    name="event_size")
            else:
                with ops.name_scope(name="event_size"):
                    self._event_size = logits_shape[self._batch_rank]

            if logits_shape_static[:-1].is_fully_defined():
                self._batch_shape_val = constant_op.constant(
                    logits_shape_static[:-1].as_list(),
                    dtype=dtypes.int32,
                    name="batch_shape")
            else:
                with ops.name_scope(name="batch_shape"):
                    self._batch_shape_val = logits_shape[:-1]
        super(Categorical, self).__init__(
            dtype=dtype,
            reparameterization_type=distribution.NOT_REPARAMETERIZED,
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            parameters=parameters,
            graph_parents=[self._logits, self._probs],
            name=name)
Example #6
  def __init__(
      self,
      logits=None,
      probs=None,
      dtype=dtypes.int32,
      validate_args=False,
      allow_nan_stats=True,
      name="Categorical"):
    """Initialize Categorical distributions using class log-probabilities.

    Args:
      logits: An N-D `Tensor`, `N >= 1`, representing the log probabilities
        of a set of Categorical distributions. The first `N - 1` dimensions
        index into a batch of independent distributions and the last dimension
        represents a vector of logits for each class. Only one of `logits` or
        `probs` should be passed in.
      probs: An N-D `Tensor`, `N >= 1`, representing the probabilities
        of a set of Categorical distributions. The first `N - 1` dimensions
        index into a batch of independent distributions and the last dimension
        represents a vector of probabilities for each class. Only one of
        `logits` or `probs` should be passed in.
      dtype: The type of the event samples (default: int32).
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
        (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
        result is undefined. When `False`, an exception is raised if one or
        more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    parameters = distribution_util.parent_frame_arguments()
    with ops.name_scope(name, values=[logits, probs]) as name:
      self._logits, self._probs = distribution_util.get_logits_and_probs(
          logits=logits,
          probs=probs,
          validate_args=validate_args,
          multidimensional=True,
          name=name)

      if validate_args:
        self._logits = distribution_util.embed_check_categorical_event_shape(
            self._logits)

      logits_shape_static = self._logits.get_shape().with_rank_at_least(1)
      if logits_shape_static.ndims is not None:
        self._batch_rank = ops.convert_to_tensor(
            logits_shape_static.ndims - 1,
            dtype=dtypes.int32,
            name="batch_rank")
      else:
        with ops.name_scope(name="batch_rank"):
          self._batch_rank = array_ops.rank(self._logits) - 1

      logits_shape = array_ops.shape(self._logits, name="logits_shape")
      if logits_shape_static[-1].value is not None:
        self._event_size = ops.convert_to_tensor(
            logits_shape_static[-1].value,
            dtype=dtypes.int32,
            name="event_size")
      else:
        with ops.name_scope(name="event_size"):
          self._event_size = logits_shape[self._batch_rank]

      if logits_shape_static[:-1].is_fully_defined():
        self._batch_shape_val = constant_op.constant(
            logits_shape_static[:-1].as_list(),
            dtype=dtypes.int32,
            name="batch_shape")
      else:
        with ops.name_scope(name="batch_shape"):
          self._batch_shape_val = logits_shape[:-1]
    super(Categorical, self).__init__(
        dtype=dtype,
        reparameterization_type=distribution.NOT_REPARAMETERIZED,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        parameters=parameters,
        graph_parents=[self._logits,
                       self._probs],
        name=name)
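The two constructor variants above differ mainly in how they capture `parameters`; in both, passing `validate_args=True` is what routes `self._logits` through `embed_check_categorical_event_shape`. Below is a minimal sketch of that effect, assuming a TF 1.x environment where this constructor is exposed as `tf.distributions.Categorical`; the logits value is illustrative, not taken from the original sources.

```python
import tensorflow as tf

# With validate_args=True, the constructor embeds the categorical event-shape
# check, so a parameter vector with fewer than 2 classes is rejected at
# graph-construction time instead of silently producing incorrect outputs.
try:
    tf.distributions.Categorical(logits=[0.0], validate_args=True)
except ValueError as e:
    print(e)  # expected to complain that the parameter must have at least 2 events
```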
Example #7
def testUnsupportedDtype(self):
    with self.test_session():
        with self.assertRaises(TypeError):
            param = array_ops.ones([int(2**11 + 1)], dtype=dtypes.qint16)
            distribution_util.embed_check_categorical_event_shape(param)
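The examples above all rely on the TF 1.x internal module layout. As a minimal, self-contained sketch under that assumption, the check can also be called directly; the tensor shapes and session setup here are illustrative, not taken from the original sources.

```python
import tensorflow as tf
from tensorflow.python.ops.distributions import util as distribution_util

# A valid parameter: 3 independent distributions, 4 classes each. float32 can
# represent 4 distinct class indices exactly, so the static checks pass and the
# input tensor is returned (possibly with runtime assertions attached).
logits = tf.ones([3, 4], dtype=tf.float32)
checked = distribution_util.embed_check_categorical_event_shape(logits)

with tf.Session() as sess:
    print(sess.run(checked))

# A single-class parameter fails statically, as in testTooSmall above.
try:
    distribution_util.embed_check_categorical_event_shape(
        tf.ones([1], dtype=tf.float16))
except ValueError as e:
    print(e)
```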