예제 #1
0
    def test_simple_read_tf(self):
        """Check the TF tensors built from ``self.reader`` and evaluate them in a session.

        For every field of ``TestSchema``, the tensor returned by ``tf_tensors`` must have
        the TF dtype that ``_numpy_to_tf_dtypes`` maps the field's numpy dtype to. It must
        also have the same rank as the field's declared shape. The tensors are then fetched
        10 times through a ``tf.Session`` to make sure they can actually be evaluated. The
        fetched values are not compared against reference data.

        The reader is stopped and joined in a ``finally`` clause. This way it is shut down
        even when an assertion or a session run above fails, not only on success.
        """
        try:
            reader_tensors = tf_tensors(self.reader)._asdict()

            for schema_field in TestSchema.fields.values():
                tensor = reader_tensors[schema_field.name]
                self.assertEqual(tensor.dtype,
                                 _numpy_to_tf_dtypes(schema_field.numpy_dtype))
                self.assertEqual(len(tensor.shape), len(schema_field.shape))

            # Evaluate the tensors several times to make sure data can be pulled through them.
            with tf.Session() as sess:
                for _ in range(10):
                    sess.run(reader_tensors)
        finally:
            self.reader.stop()
            self.reader.join()
예제 #2
0
 def test_unknown_type(self):
     """Converting np.uint64, which has no supported TF mapping, must raise ValueError."""
     self.assertRaises(ValueError, _numpy_to_tf_dtypes, np.uint64)
예제 #3
0
 def test_uint16_promotion_to_int32(self):
     """np.uint16 is expected to be promoted to tf.int32."""
     converted = _numpy_to_tf_dtypes(np.uint16)
     self.assertEqual(converted, tf.int32)
예제 #4
0
 def test_decimal_conversion(self):
     """``Decimal`` is expected to map to tf.string."""
     converted = _numpy_to_tf_dtypes(Decimal)
     self.assertEqual(converted, tf.string)
예제 #5
0
def test_unknown_type():
    """Converting np.uint64, which has no supported TF mapping, must raise ValueError."""
    pytest.raises(ValueError, _numpy_to_tf_dtypes, np.uint64)
예제 #6
0
def test_uint16_promotion_to_int32():
    """np.uint16 is expected to be promoted to tf.int32."""
    converted = _numpy_to_tf_dtypes(np.uint16)
    assert converted == tf.int32
예제 #7
0
def test_decimal_conversion():
    """``Decimal`` is expected to map to tf.string."""
    converted = _numpy_to_tf_dtypes(Decimal)
    assert converted == tf.string