def test_same_shape(self):
    """Exercise `tensor_utils.same_shape` on matching and mismatching shapes.

    Each table entry holds the two arguments handed to `tf.TensorShape`;
    a bare `None` is passed through as-is, lists give per-dimension sizes.
    """
    # Pairs for which `same_shape` is expected to return a truthy result.
    matching_specs = [
        (None, None),
        ([None], [None]),
        ([1], [1]),
        ([None, 1], [None, 1]),
        ([1, 2, 3], [1, 2, 3]),
    ]
    for first_spec, second_spec in matching_specs:
        first_shape = tf.TensorShape(first_spec)
        second_shape = tf.TensorShape(second_spec)
        self.assertTrue(tensor_utils.same_shape(first_shape, second_shape))

    # Pairs for which `same_shape` is expected to return a falsy result,
    # including a bare `None` spec vs. a known rank (in both argument
    # orders) and a known vs. unknown dimension.
    mismatching_specs = [
        (None, [1]),
        ([1], None),
        ([1], [None]),
        ([1], [2]),
        ([1, 2], [2, 1]),
    ]
    for first_spec, second_spec in mismatching_specs:
        first_shape = tf.TensorShape(first_spec)
        second_shape = tf.TensorShape(second_spec)
        self.assertFalse(tensor_utils.same_shape(first_shape, second_shape))
Example #2
0
 def __eq__(self, other):
     """Compare this object with `other` for equality.

     Returns False when `other` is not a `TensorType` or when the dtypes
     compare unequal; otherwise returns whatever
     `tensor_utils.same_shape` reports for the two shapes.
     """
     # Only touch `other.dtype` / `other.shape` once the type is confirmed.
     if not isinstance(other, TensorType):
         return False
     dtypes_match = self._dtype == other.dtype
     if not dtypes_match:
         return False
     return tensor_utils.same_shape(self._shape, other.shape)