Example #1
  def _log_prob(self, event):
    if self.validate_args:
      event = distribution_util.embed_check_integer_casting_closed(
          event, target_dtype=dtypes.bool)

    # TODO(jaana): The current sigmoid_cross_entropy_with_logits has
    # inconsistent behavior for logits = inf/-inf.
    event = math_ops.cast(event, self.logits.dtype)
    logits = self.logits
    # sigmoid_cross_entropy_with_logits doesn't broadcast shape,
    # so we do this here.

    def _broadcast(logits, event):
      return (array_ops.ones_like(event) * logits,
              array_ops.ones_like(logits) * event)

    # First check static shape.
    if (event.get_shape().is_fully_defined() and
        logits.get_shape().is_fully_defined()):
      if event.get_shape() != logits.get_shape():
        logits, event = _broadcast(logits, event)
    else:
      logits, event = control_flow_ops.cond(
          distribution_util.same_dynamic_shape(logits, event),
          lambda: (logits, event),
          lambda: _broadcast(logits, event))
    return -nn.sigmoid_cross_entropy_with_logits(labels=event, logits=logits)
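
A minimal usage sketch of the path above, assuming TF 1.x graph mode and the public tf.distributions.Bernoulli wrapper (the exact public alias varies by release); the constants are illustrative.

import tensorflow as tf

# With validate_args=True, integer events are routed through
# embed_check_integer_casting_closed(event, target_dtype=bool) before the
# cast, so only values castable to bool (0 or 1) are accepted.
dist = tf.distributions.Bernoulli(logits=[-1.0, 0.0, 2.0], validate_args=True)
log_p = dist.log_prob([0, 1, 1])

with tf.Session() as sess:
  print(sess.run(log_p))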
Example #2
    def _log_prob(self, k):
        k = ops.convert_to_tensor(k, name="k")
        if self.validate_args:
            k = distribution_util.embed_check_integer_casting_closed(
                k, target_dtype=dtypes.int32)

        if self.logits.get_shape()[:-1] == k.get_shape():
            logits = self.logits
        else:
            logits = self.logits * array_ops.ones_like(
                array_ops.expand_dims(k, -1), dtype=self.logits.dtype)
            logits_shape = array_ops.shape(logits)[:-1]
            k *= array_ops.ones(logits_shape, dtype=k.dtype)
            k.set_shape(tensor_shape.TensorShape(logits.get_shape()[:-1]))
            if k.dtype.is_integer:
                pass
            elif k.dtype.is_floating:
                # When `validate_args=True` we've already ensured int/float casting
                # is closed.
                k = math_ops.cast(k, dtype=dtypes.int32)
            else:
                raise TypeError("`value` should have integer `dtype` or "
                                "`self.dtype` ({})".format(
                                    self.dtype.base_dtype))
        return -nn_ops.sparse_softmax_cross_entropy_with_logits(labels=k,
                                                                logits=logits)
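
A comparable sketch for this Categorical log-probability path, again assuming TF 1.x and the public tf.distributions.Categorical wrapper; values are illustrative.

import tensorflow as tf

# One distribution per row of logits; validate_args=True adds the integer
# casting check on k before it reaches the sparse softmax cross entropy.
dist = tf.distributions.Categorical(
    logits=[[0.0, 1.0, 2.0], [2.0, 1.0, 0.0]], validate_args=True)
log_p = dist.log_prob([2, 0])  # one class index per batch member

with tf.Session() as sess:
  print(sess.run(log_p))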
Example #3
    def _log_prob(self, event):
        if self.validate_args:
            event = distribution_util.embed_check_integer_casting_closed(
                event, target_dtype=dtypes.bool)

        # TODO(jaana): The current sigmoid_cross_entropy_with_logits has
        # inconsistent behavior for logits = inf/-inf.
        event = math_ops.cast(event, self.logits.dtype)
        logits = self.logits

        # sigmoid_cross_entropy_with_logits doesn't broadcast shape,
        # so we do this here.

        def _broadcast(logits, event):
            return (array_ops.ones_like(event) * logits,
                    array_ops.ones_like(logits) * event)

        # First check static shape.
        if (event.get_shape().is_fully_defined()
                and logits.get_shape().is_fully_defined()):
            if event.get_shape() != logits.get_shape():
                logits, event = _broadcast(logits, event)
        else:
            logits, event = control_flow_ops.cond(
                distribution_util.same_dynamic_shape(logits, event),
                lambda: (logits, event),
                lambda: _broadcast(logits, event))
        return -nn.sigmoid_cross_entropy_with_logits(labels=event,
                                                     logits=logits)
Example #4
  def _log_prob(self, k):
    k = ops.convert_to_tensor(k, name="k")
    if self.validate_args:
      k = distribution_util.embed_check_integer_casting_closed(
          k, target_dtype=dtypes.int32)

    if self.logits.get_shape()[:-1] == k.get_shape():
      logits = self.logits
    else:
      logits = self.logits * array_ops.ones_like(
          array_ops.expand_dims(k, -1), dtype=self.logits.dtype)
      logits_shape = array_ops.shape(logits)[:-1]
      k *= array_ops.ones(logits_shape, dtype=k.dtype)
      k.set_shape(tensor_shape.TensorShape(logits.get_shape()[:-1]))
      if k.dtype.is_integer:
        pass
      elif k.dtype.is_floating:
        # When `validate_args=True` we've already ensured int/float casting
        # is closed.
        k = math_ops.cast(k, dtype=dtypes.int32)
      else:
        raise TypeError("`value` should have integer `dtype` or "
                        "`self.dtype` ({})".format(self.dtype.base_dtype))
    return -nn_ops.sparse_softmax_cross_entropy_with_logits(labels=k,
                                                            logits=logits)
Example #5
  def testCorrectlyAssertsLargestPossibleInteger(self):
    with self.test_session():
      with self.assertRaisesOpError("Elements cannot exceed 32767."):
        x = array_ops.placeholder(dtype=dtypes.int32)
        x_checked = distribution_util.embed_check_integer_casting_closed(
            x, target_dtype=dtypes.int16)
        x_checked.eval(
            feed_dict={x: np.array([1, 2**15], dtype=np.int32)})
Example #6
  def testCorrectlyAssertsNonnegative(self):
    with self.test_session():
      with self.assertRaisesOpError("Elements must be non-negative"):
        x = array_ops.placeholder(dtype=dtypes.float16)
        x_checked = distribution_util.embed_check_integer_casting_closed(
            x, target_dtype=dtypes.int16)
        x_checked.eval(
            feed_dict={x: np.array([1, -1], dtype=np.float16)})
Example #7
  def testCorrectlyAssertsSmallestPossibleInteger(self):
    with self.test_session():
      with self.assertRaisesOpError("Elements cannot be smaller than 0."):
        x = array_ops.placeholder(dtype=dtypes.int32)
        x_checked = distribution_util.embed_check_integer_casting_closed(
            x, target_dtype=dtypes.uint16, assert_nonnegative=False)
        x_checked.eval(
            feed_dict={x: np.array([1, -1], dtype=np.int32)})
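
The three tests above exercise the failure paths. The sketch below calls the helper directly on values that pass; the import path tensorflow.python.ops.distributions.util is an assumption based on where these snippets import distribution_util from, and it has moved between releases.

import numpy as np
import tensorflow as tf
from tensorflow.python.framework import dtypes
from tensorflow.python.ops.distributions import util as distribution_util

# All values fit in int16 and are non-negative, so the embedded assertions
# pass and x_checked evaluates to the original values.
x = tf.constant(np.array([1, 2, 3], dtype=np.int32))
x_checked = distribution_util.embed_check_integer_casting_closed(
    x, target_dtype=dtypes.int16)

with tf.Session() as sess:
  print(sess.run(x_checked))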
Example #8
    def _log_prob(self, k):
        k = ops.convert_to_tensor(k, name="k")
        if self.validate_args:
            k = distribution_util.embed_check_integer_casting_closed(
                k, target_dtype=dtypes.int32)
        k, logits = _broadcast_cat_event_and_params(
            k, self.logits, base_dtype=self.dtype.base_dtype)

        return -nn_ops.sparse_softmax_cross_entropy_with_logits(labels=k,
                                                                logits=logits)
Example #9
  def _log_prob(self, k):
    k = ops.convert_to_tensor(k, name="k")
    if self.validate_args:
      k = distribution_util.embed_check_integer_casting_closed(
          k, target_dtype=dtypes.int32)
    k, logits = _broadcast_cat_event_and_params(
        k, self.logits, base_dtype=self.dtype.base_dtype)

    return -nn_ops.sparse_softmax_cross_entropy_with_logits(labels=k,
                                                            logits=logits)
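
_broadcast_cat_event_and_params is private, but its effect is visible through the public API: a scalar event index is scored against every batch member. A sketch under the same TF 1.x assumptions as above.

import tensorflow as tf

dist = tf.distributions.Categorical(
    logits=[[0.0, 1.0, 2.0], [2.0, 1.0, 0.0]], validate_args=True)
# The scalar event 1 is broadcast against the [2]-batch of logits, so the
# result has shape [2]: log P(X = 1) under each row of logits.
log_p = dist.log_prob(1)

with tf.Session() as sess:
  print(sess.run(log_p))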
Example #10
  def _cdf(self, k):
    k = ops.convert_to_tensor(k, name="k")
    if self.validate_args:
      k = distribution_util.embed_check_integer_casting_closed(
          k, target_dtype=dtypes.int32)

    # If there are multiple batch dimensions, flatten them into one.
    batch_flattened_probs = array_ops.reshape(self._probs,
                                              [-1, self._event_size])
    batch_flattened_k = array_ops.reshape(k, [-1])

    # Form a tensor to sum over.
    # We don't need to cast k to integer since `sequence_mask` does this for us.
    mask_tensor = array_ops.sequence_mask(batch_flattened_k, self._event_size)
    to_sum_over = array_ops.where(mask_tensor,
                                  batch_flattened_probs,
                                  array_ops.zeros_like(batch_flattened_probs))
    batch_flat_cdf = math_ops.reduce_sum(to_sum_over, axis=-1)
    return array_ops.reshape(batch_flat_cdf, self._batch_shape())
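
A usage sketch of the CDF path, assuming a TF 1.x release whose Categorical._cdf matches the snippet above. Given that masking, cdf(k) sums the class probabilities with index strictly below k.

import tensorflow as tf

dist = tf.distributions.Categorical(probs=[0.2, 0.3, 0.5], validate_args=True)
# sequence_mask(2, 3) keeps probs[0] and probs[1], so this evaluates to 0.5.
cdf_at_2 = dist.cdf(2)

with tf.Session() as sess:
  print(sess.run(cdf_at_2))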
Example #11
    def _cdf(self, k):
        k = ops.convert_to_tensor(k, name="k")
        if self.validate_args:
            k = distribution_util.embed_check_integer_casting_closed(
                k, target_dtype=dtypes.int32)

        k, probs = _broadcast_cat_event_and_params(
            k, self.probs, base_dtype=self.dtype.base_dtype)

        # batch-flatten everything in order to use `sequence_mask()`.
        batch_flattened_probs = array_ops.reshape(probs,
                                                  (-1, self._event_size))
        batch_flattened_k = array_ops.reshape(k, [-1])

        to_sum_over = array_ops.where(
            array_ops.sequence_mask(batch_flattened_k, self._event_size),
            batch_flattened_probs, array_ops.zeros_like(batch_flattened_probs))
        batch_flattened_cdf = math_ops.reduce_sum(to_sum_over, axis=-1)
        # Reshape back to the shape of the argument.
        return array_ops.reshape(batch_flattened_cdf, array_ops.shape(k))
Example #12
    def _cdf(self, k):
        k = ops.convert_to_tensor(k, name="k")
        if self.validate_args:
            k = distribution_util.embed_check_integer_casting_closed(
                k, target_dtype=dtypes.int32)

        # If there are multiple batch dimensions, flatten them into one.
        batch_flattened_probs = array_ops.reshape(self._probs,
                                                  [-1, self._event_size])
        batch_flattened_k = array_ops.reshape(k, [-1])

        # Form a tensor to sum over.
        # We don't need to cast k to integer since `sequence_mask` does this for us.
        mask_tensor = array_ops.sequence_mask(batch_flattened_k,
                                              self._event_size)
        to_sum_over = array_ops.where(
            mask_tensor, batch_flattened_probs,
            array_ops.zeros_like(batch_flattened_probs))
        batch_flat_cdf = math_ops.reduce_sum(to_sum_over, axis=-1)
        return array_ops.reshape(batch_flat_cdf, self._batch_shape())
Example #13
  def _cdf(self, k):
    k = ops.convert_to_tensor(k, name="k")
    if self.validate_args:
      k = distribution_util.embed_check_integer_casting_closed(
          k, target_dtype=dtypes.int32)

    k, probs = _broadcast_cat_event_and_params(
        k, self.probs, base_dtype=self.dtype.base_dtype)

    # batch-flatten everything in order to use `sequence_mask()`.
    batch_flattened_probs = array_ops.reshape(probs,
                                              (-1, self._event_size))
    batch_flattened_k = array_ops.reshape(k, [-1])

    to_sum_over = array_ops.where(
        array_ops.sequence_mask(batch_flattened_k, self._event_size),
        batch_flattened_probs,
        array_ops.zeros_like(batch_flattened_probs))
    batch_flattened_cdf = math_ops.reduce_sum(to_sum_over, axis=-1)
    # Reshape back to the shape of the argument.
    return array_ops.reshape(batch_flattened_cdf, array_ops.shape(k))
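
Finally, a sketch of the batched case under the same assumptions, illustrating the reshape back to the shape of the argument.

import tensorflow as tf

dist = tf.distributions.Categorical(
    probs=[[0.2, 0.3, 0.5], [0.5, 0.25, 0.25]], validate_args=True)
# One k per batch member; the flattened result is reshaped back to k's shape,
# giving [0.2, 0.75] here.
cdf_vals = dist.cdf([1, 2])

with tf.Session() as sess:
  print(sess.run(cdf_vals))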