Example #1
0
 def testShapeGetters(self):
     """Checks the static and Tensor-valued event-shape getters of a Chain.

     The chain is two `SoftmaxCentered(validate_args=True)` bijectors. The
     test expects an event of shape [1] to map forward to shape [3]
     (written as [2 + 1]) and to map back from [3] to [1] under the inverse.
     """
     with self.cached_session():
         first = SoftmaxCentered(validate_args=True)
         second = SoftmaxCentered(validate_args=True)
         bijector = Chain([first, second])

         shape_in = tensor_shape.TensorShape([1])
         shape_out = tensor_shape.TensorShape([2 + 1])

         # Forward direction: static TensorShape getter, then the Tensor
         # getter, whose result is evaluated with `.eval()` inside the
         # session context above.
         self.assertAllEqual(shape_out, bijector.forward_event_shape(shape_in))
         fwd_tensor = bijector.forward_event_shape_tensor(shape_in.as_list())
         self.assertAllEqual(shape_out.as_list(), fwd_tensor.eval())

         # Inverse direction: same two getters, mapping [3] back to [1].
         self.assertAllEqual(shape_in, bijector.inverse_event_shape(shape_out))
         inv_tensor = bijector.inverse_event_shape_tensor(shape_out.as_list())
         self.assertAllEqual(shape_in.as_list(), inv_tensor.eval())
Example #2
0
 def testShapeGetters(self):
   """Checks the static and Tensor-valued event-shape getters of a Chain.

   The chain holds two `SoftmaxCentered(validate_args=True)` bijectors; an
   event of shape [1] is expected to become [3] (i.e. [2 + 1]) going
   forward and to return to [1] going through the inverse.
   """
   with self.test_session():
     steps = [
         SoftmaxCentered(validate_args=True),
         SoftmaxCentered(validate_args=True),
     ]
     bijector = Chain(steps)

     small = tensor_shape.TensorShape([1])
     large = tensor_shape.TensorShape([2 + 1])

     # Forward: static getter, then the Tensor getter evaluated via `.eval()`.
     self.assertAllEqual(large, bijector.forward_event_shape(small))
     forward_shape = bijector.forward_event_shape_tensor(small.as_list())
     self.assertAllEqual(large.as_list(), forward_shape.eval())

     # Inverse: same pair of getters in the opposite direction.
     self.assertAllEqual(small, bijector.inverse_event_shape(large))
     inverse_shape = bijector.inverse_event_shape_tensor(large.as_list())
     self.assertAllEqual(small.as_list(), inverse_shape.eval())
Example #3
0
 def testShapeGetters(self):
     """Checks event-shape getters of a Chain mixing event_ndims 1 and 0.

     The chain lists `SoftmaxCentered(event_ndims=1)` followed by
     `SoftmaxCentered(event_ndims=0)`. The test expects a scalar event
     (shape []) to map forward to shape [3] (written [2 + 1]) and the
     inverse getters to map [3] back to [].
     """
     with self.test_session():
         vector_step = SoftmaxCentered(event_ndims=1, validate_args=True)
         scalar_step = SoftmaxCentered(event_ndims=0, validate_args=True)
         chain = Chain([vector_step, scalar_step])

         scalar_event = tensor_shape.TensorShape([])
         vector_event = tensor_shape.TensorShape([2 + 1])

         # Forward: static TensorShape getter, then the Tensor getter whose
         # value is obtained with `.eval()` inside the session context.
         self.assertAllEqual(vector_event,
                             chain.forward_event_shape(scalar_event))
         fwd = chain.forward_event_shape_tensor(scalar_event.as_list())
         self.assertAllEqual(vector_event.as_list(), fwd.eval())

         # Inverse: the same two getters, from [3] back to the scalar shape.
         self.assertAllEqual(scalar_event,
                             chain.inverse_event_shape(vector_event))
         inv = chain.inverse_event_shape_tensor(vector_event.as_list())
         self.assertAllEqual(scalar_event.as_list(), inv.eval())
Example #4
0
 def testShapeGetters(self):
   """Checks event-shape getters of a Chain mixing event_ndims 1 and 0.

   Bijectors listed: `SoftmaxCentered(event_ndims=1)` then
   `SoftmaxCentered(event_ndims=0)`. A scalar event (shape []) is expected
   to map forward to [3] (i.e. [2 + 1]); the inverse maps [3] back to [].
   """
   with self.test_session():
     outer = SoftmaxCentered(event_ndims=1, validate_args=True)
     inner = SoftmaxCentered(event_ndims=0, validate_args=True)
     chain = Chain([outer, inner])

     in_shape = tensor_shape.TensorShape([])
     out_shape = tensor_shape.TensorShape([2 + 1])

     # Forward direction: static getter, then Tensor getter via `.eval()`.
     self.assertAllEqual(out_shape, chain.forward_event_shape(in_shape))
     fwd_value = chain.forward_event_shape_tensor(in_shape.as_list())
     self.assertAllEqual(out_shape.as_list(), fwd_value.eval())

     # Inverse direction: mirror of the checks above.
     self.assertAllEqual(in_shape, chain.inverse_event_shape(out_shape))
     inv_value = chain.inverse_event_shape_tensor(out_shape.as_list())
     self.assertAllEqual(in_shape.as_list(), inv_value.eval())