def _log_prob(self, event):
  """Returns the log-probability of `event` under Bernoulli(`self.logits`).

  The result is `-sigmoid_cross_entropy_with_logits(labels=event,
  logits=self.logits)`, with `event` and the logits broadcast against each
  other first (the cross-entropy op itself does not broadcast).

  Args:
    event: Integer-like tensor of 0/1 outcomes; validated against the
      closed-under-casting-to-bool check when `self.validate_args` is set.

  Returns:
    Tensor of per-event log-probabilities, broadcast shape of `event` and
    `self.logits`.
  """
  if self.validate_args:
    event = distribution_util.embed_check_integer_casting_closed(
        event, target_dtype=dtypes.bool)

  # TODO(jaana): The current sigmoid_cross_entropy_with_logits has
  # inconsistent behavior for logits = inf/-inf.
  event = math_ops.cast(event, self.logits.dtype)
  logits = self.logits

  # sigmoid_cross_entropy_with_logits doesn't broadcast shape, so we
  # expand both tensors to their common broadcast shape by hand.
  def _expand_to_common_shape(lg, ev):
    return (array_ops.ones_like(ev) * lg, array_ops.ones_like(lg) * ev)

  statically_known = (event.get_shape().is_fully_defined() and
                      logits.get_shape().is_fully_defined())
  if statically_known:
    # Both shapes are known at graph-construction time: broadcast only
    # when they actually differ.
    if event.get_shape() != logits.get_shape():
      logits, event = _expand_to_common_shape(logits, event)
  else:
    # Shapes are only known at run time: defer the comparison to a cond.
    logits, event = control_flow_ops.cond(
        distribution_util.same_dynamic_shape(logits, event),
        lambda: (logits, event),
        lambda: _expand_to_common_shape(logits, event))
  return -nn.sigmoid_cross_entropy_with_logits(labels=event, logits=logits)
def testSameDynamicShape(self):
  """Exercises `same_dynamic_shape` on matching and mismatching pairs."""
  with self.test_session():
    scalar = constant_op.constant(2.0)
    scalar1 = array_ops.placeholder(dtype=dtypes.float32)
    vector = [0.3, 0.4, 0.5]
    vector1 = array_ops.placeholder(dtype=dtypes.float32, shape=[None])
    vector2 = array_ops.placeholder(dtype=dtypes.float32, shape=[None])
    multidimensional = [[0.3, 0.4], [0.2, 0.6]]
    multidimensional1 = array_ops.placeholder(
        dtype=dtypes.float32, shape=[None, None])
    multidimensional2 = array_ops.placeholder(
        dtype=dtypes.float32, shape=[None, None])

    # Each case is (lhs, rhs, feed_dict). The shapes of the fed values
    # determine whether same_dynamic_shape must evaluate to True.
    matching_cases = [
        # Scalar vs. scalar.
        (scalar, scalar1, {scalar1: 2.0}),
        # Vector vs. vector, same length.
        (vector, vector1, {vector1: [2.0, 3.0, 4.0]}),
        (vector1, vector2,
         {vector1: [2.0, 3.0, 4.0], vector2: [2.0, 3.5, 6.0]}),
        # Matrix vs. matrix, same shape.
        (multidimensional, multidimensional1,
         {multidimensional1: [[2.0, 3.0], [3.0, 4.0]]}),
        (multidimensional1, multidimensional2,
         {multidimensional1: [[2.0, 3.0], [3.0, 4.0]],
          multidimensional2: [[1.0, 3.5], [6.3, 2.3]]}),
    ]
    mismatching_cases = [
        # Scalar vs. higher-rank tensors.
        (scalar, vector1, {vector1: [2.0, 3.0, 4.0]}),
        (scalar1, vector1, {scalar1: 2.0, vector1: [2.0, 3.0, 4.0]}),
        (scalar, multidimensional1,
         {multidimensional1: [[2.0, 3.0], [3.0, 4.0]]}),
        (scalar1, multidimensional1,
         {scalar1: 2.0, multidimensional1: [[2.0, 3.0], [3.0, 4.0]]}),
        # Vectors of different lengths, and vector vs. matrix.
        (vector, vector1, {vector1: [2.0, 3.0]}),
        (vector1, vector2, {vector1: [2.0, 3.0, 4.0], vector2: [6.0]}),
        (vector, multidimensional1,
         {multidimensional1: [[2.0, 3.0], [3.0, 4.0]]}),
        (vector1, multidimensional1,
         {vector1: [2.0, 3.0, 4.0],
          multidimensional1: [[2.0, 3.0], [3.0, 4.0]]}),
        # Matrices with differing column counts.
        (multidimensional, multidimensional1,
         {multidimensional1: [[1.0, 3.5, 5.0], [6.3, 2.3, 7.1]]}),
        (multidimensional1, multidimensional2,
         {multidimensional1: [[2.0, 3.0], [3.0, 4.0]],
          multidimensional2: [[1.0, 3.5, 5.0], [6.3, 2.3, 7.1]]}),
    ]

    for lhs, rhs, feed_dict in matching_cases:
      self.assertTrue(
          distribution_util.same_dynamic_shape(lhs, rhs).eval(feed_dict))
    for lhs, rhs, feed_dict in mismatching_cases:
      self.assertFalse(
          distribution_util.same_dynamic_shape(lhs, rhs).eval(feed_dict))