예제 #1
0
  def _log_prob(self, event):
    """Returns the log-probability of `event` given `self.logits`.

    The result is the negation of `nn.sigmoid_cross_entropy_with_logits`
    evaluated with `event` as the labels and `self.logits` as the logits.

    Args:
      event: Outcome(s) to score. When `self.validate_args` is set it is
        first routed through `embed_check_integer_casting_closed` with a
        `bool` target dtype; it is then cast to the dtype of `self.logits`.

    Returns:
      The negated sigmoid cross-entropy of `event` against the logits, with
      both operands expanded to a common shape when their shapes differ.
    """
    if self.validate_args:
      event = distribution_util.embed_check_integer_casting_closed(
          event, target_dtype=dtypes.bool)

    # TODO(jaana): The current sigmoid_cross_entropy_with_logits has
    # inconsistent behavior for logits = inf/-inf.
    event = math_ops.cast(event, self.logits.dtype)
    logits = self.logits

    def _expand_to_common_shape(logits, event):
      # sigmoid_cross_entropy_with_logits doesn't broadcast shape, so each
      # operand is multiplied by ones shaped like the other one.
      expanded_logits = array_ops.ones_like(event) * logits
      expanded_event = array_ops.ones_like(logits) * event
      return expanded_logits, expanded_event

    event_shape = event.get_shape()
    logits_shape = logits.get_shape()
    static_shapes_known = (event_shape.is_fully_defined() and
                           logits_shape.is_fully_defined())

    if not static_shapes_known:
      # The static shapes cannot settle it, so `cond` picks a branch based on
      # the result of `same_dynamic_shape`.
      logits, event = control_flow_ops.cond(
          distribution_util.same_dynamic_shape(logits, event),
          lambda: (logits, event),
          lambda: _expand_to_common_shape(logits, event))
    elif event_shape != logits_shape:
      # Both static shapes are known and they disagree.
      logits, event = _expand_to_common_shape(logits, event)

    return -nn.sigmoid_cross_entropy_with_logits(labels=event, logits=logits)
예제 #2
0
    def _log_prob(self, event):
        """Returns the log-probability of `event` given `self.logits`.

        The result is the negation of `nn.sigmoid_cross_entropy_with_logits`
        evaluated with `event` as the labels and `self.logits` as the logits.

        Args:
          event: Outcome(s) to score. When `self.validate_args` is set it is
            first routed through `embed_check_integer_casting_closed` with a
            `bool` target dtype; it is then cast to the dtype of
            `self.logits`.

        Returns:
          The negated sigmoid cross-entropy of `event` against the logits,
          with both operands expanded to a common shape when their shapes
          differ.
        """
        if self.validate_args:
            event = distribution_util.embed_check_integer_casting_closed(
                event, target_dtype=dtypes.bool)

        # TODO(jaana): The current sigmoid_cross_entropy_with_logits has
        # inconsistent behavior for logits = inf/-inf.
        event = math_ops.cast(event, self.logits.dtype)
        logits = self.logits

        def _expand_to_common_shape(logits, event):
            # sigmoid_cross_entropy_with_logits doesn't broadcast shape, so
            # each operand is multiplied by ones shaped like the other one.
            expanded_logits = array_ops.ones_like(event) * logits
            expanded_event = array_ops.ones_like(logits) * event
            return expanded_logits, expanded_event

        event_shape = event.get_shape()
        logits_shape = logits.get_shape()
        static_shapes_known = (event_shape.is_fully_defined()
                               and logits_shape.is_fully_defined())

        if not static_shapes_known:
            # The static shapes cannot settle it, so `cond` picks a branch
            # based on the result of `same_dynamic_shape`.
            logits, event = control_flow_ops.cond(
                distribution_util.same_dynamic_shape(logits, event),
                lambda: (logits, event),
                lambda: _expand_to_common_shape(logits, event))
        elif event_shape != logits_shape:
            # Both static shapes are known and they disagree.
            logits, event = _expand_to_common_shape(logits, event)

        return -nn.sigmoid_cross_entropy_with_logits(labels=event,
                                                     logits=logits)
예제 #3
0
    def testSameDynamicShape(self):
        # Exercises distribution_util.same_dynamic_shape on pairs of scalar,
        # vector and matrix operands. Operands are a constant, Python lists,
        # or placeholders whose concrete shape only arrives via the feed dict.
        # Pairs with equal shapes must evaluate truthy, all others falsy.
        with self.test_session():
            scalar_const = constant_op.constant(2.0)
            scalar_ph = array_ops.placeholder(dtype=dtypes.float32)

            vector_list = [0.3, 0.4, 0.5]
            vector_ph_a = array_ops.placeholder(dtype=dtypes.float32,
                                                shape=[None])
            vector_ph_b = array_ops.placeholder(dtype=dtypes.float32,
                                                shape=[None])

            matrix_list = [[0.3, 0.4], [0.2, 0.6]]
            matrix_ph_a = array_ops.placeholder(dtype=dtypes.float32,
                                                shape=[None, None])
            matrix_ph_b = array_ops.placeholder(dtype=dtypes.float32,
                                                shape=[None, None])

            def shapes_agree(lhs, rhs, feed):
                # Builds the comparison op and evaluates it with `feed`.
                return distribution_util.same_dynamic_shape(lhs,
                                                            rhs).eval(feed)

            # Each case is (first operand, second operand, feed dict).
            equal_shape_cases = [
                # Scalar
                (scalar_const, scalar_ph, {scalar_ph: 2.0}),
                # Vector
                (vector_list, vector_ph_a, {vector_ph_a: [2.0, 3.0, 4.0]}),
                (vector_ph_a, vector_ph_b, {
                    vector_ph_a: [2.0, 3.0, 4.0],
                    vector_ph_b: [2.0, 3.5, 6.0],
                }),
                # Multidimensional
                (matrix_list, matrix_ph_a, {
                    matrix_ph_a: [[2.0, 3.0], [3.0, 4.0]],
                }),
                (matrix_ph_a, matrix_ph_b, {
                    matrix_ph_a: [[2.0, 3.0], [3.0, 4.0]],
                    matrix_ph_b: [[1.0, 3.5], [6.3, 2.3]],
                }),
            ]
            for lhs, rhs, feed in equal_shape_cases:
                self.assertTrue(shapes_agree(lhs, rhs, feed))

            different_shape_cases = [
                # Scalar, X
                (scalar_const, vector_ph_a, {vector_ph_a: [2.0, 3.0, 4.0]}),
                (scalar_ph, vector_ph_a, {
                    scalar_ph: 2.0,
                    vector_ph_a: [2.0, 3.0, 4.0],
                }),
                (scalar_const, matrix_ph_a, {
                    matrix_ph_a: [[2.0, 3.0], [3.0, 4.0]],
                }),
                (scalar_ph, matrix_ph_a, {
                    scalar_ph: 2.0,
                    matrix_ph_a: [[2.0, 3.0], [3.0, 4.0]],
                }),
                # Vector, X
                (vector_list, vector_ph_a, {vector_ph_a: [2.0, 3.0]}),
                (vector_ph_a, vector_ph_b, {
                    vector_ph_a: [2.0, 3.0, 4.0],
                    vector_ph_b: [6.0],
                }),
                (vector_list, matrix_ph_a, {
                    matrix_ph_a: [[2.0, 3.0], [3.0, 4.0]],
                }),
                (vector_ph_a, matrix_ph_a, {
                    vector_ph_a: [2.0, 3.0, 4.0],
                    matrix_ph_a: [[2.0, 3.0], [3.0, 4.0]],
                }),
                # Multidimensional, X
                (matrix_list, matrix_ph_a, {
                    matrix_ph_a: [[1.0, 3.5, 5.0], [6.3, 2.3, 7.1]],
                }),
                (matrix_ph_a, matrix_ph_b, {
                    matrix_ph_a: [[2.0, 3.0], [3.0, 4.0]],
                    matrix_ph_b: [[1.0, 3.5, 5.0], [6.3, 2.3, 7.1]],
                }),
            ]
            for lhs, rhs, feed in different_shape_cases:
                self.assertFalse(shapes_agree(lhs, rhs, feed))