def test_tf_formatter_sets_default_dtypes(cast_schema, arrow_table):
    """Compare default ``TFFormatter`` output with int64/float32 ragged constants.

    The expected values are built from the table's Python representation
    (``to_pydict``): ``col_int`` as ``tf.int64`` and ``col_float`` as
    ``tf.float32``. The row, column and batch outputs are each checked.

    Args:
        cast_schema: When truthy, passed to ``pa.schema`` and the table is
            cast to the result before formatting; when falsy, the table is
            used as received.
        arrow_table: Table exposing ``cast`` and ``to_pydict`` with the
            columns ``col_int`` and ``col_float`` (presumably a pyarrow
            table fixture -- confirm against the fixture definition).
    """
    import tensorflow as tf

    from datasets.formatting import TFFormatter

    table = arrow_table.cast(pa.schema(cast_schema)) if cast_schema else arrow_table

    py_columns = table.to_pydict()
    int_values = py_columns["col_int"]
    float_values = py_columns["col_float"]
    tf_formatter = TFFormatter()

    def ragged(values, dtype):
        # Built lazily at each comparison so the evaluation order matches
        # a plain sequence of inline assertions.
        return tf.ragged.constant(values, dtype=dtype)

    expectations = (
        ("col_int", int_values, tf.int64),
        ("col_float", float_values, tf.float32),
    )

    # Single row: compared with the first element of each ragged constant.
    formatted_row = tf_formatter.format_row(table)
    for name, values, dtype in expectations:
        tf.debugging.assert_equal(formatted_row[name], ragged(values, dtype)[0])

    # Single column: compared with the ``col_int`` values.
    formatted_column = tf_formatter.format_column(table)
    tf.debugging.assert_equal(formatted_column, ragged(int_values, tf.int64))

    # Whole batch: compared with the full ragged constants.
    formatted_batch = tf_formatter.format_batch(table)
    for name, values, dtype in expectations:
        tf.debugging.assert_equal(formatted_batch[name], ragged(values, dtype))
def test_tf_formatter(self):
    """Check ``TFFormatter`` row, column and batch output for the dummy table.

    Columns ``a``, ``b`` and ``c`` are compared with tensors built from
    ``_COL_A`` (``tf.int64``), ``_COL_B`` (``tf.string``) and ``_COL_C``
    (``tf.float32``). For the batch, column ``c`` is additionally checked
    for its type, dtype and shape before its values are compared.
    """
    import tensorflow as tf

    from datasets.formatting import TFFormatter

    table = self._create_dummy_table()
    tf_formatter = TFFormatter()

    def expected(values, dtype):
        # Built lazily at each comparison so the evaluation order matches
        # a plain sequence of inline assertions.
        return tf.convert_to_tensor(values, dtype=dtype)

    # Single row: compared with the first element of each expected tensor.
    formatted_row = tf_formatter.format_row(table)
    row_expectations = (
        ("a", _COL_A, tf.int64),
        ("b", _COL_B, tf.string),
        ("c", _COL_C, tf.float32),
    )
    for key, values, dtype in row_expectations:
        tf.debugging.assert_equal(formatted_row[key], expected(values, dtype)[0])

    # Single column: compared with a ragged constant of the ``a`` values.
    formatted_column = tf_formatter.format_column(table)
    tf.debugging.assert_equal(
        formatted_column, tf.ragged.constant(_COL_A, dtype=tf.int64)
    )

    # Whole batch: ``a`` and ``b`` are compared directly.
    formatted_batch = tf_formatter.format_batch(table)
    for key, values, dtype in (("a", _COL_A, tf.int64), ("b", _COL_B, tf.string)):
        tf.debugging.assert_equal(formatted_batch[key], expected(values, dtype))

    # ``c`` gets structural checks (type, dtype, shape) ahead of the
    # value comparison.
    batch_c = formatted_batch["c"]
    self.assertIsInstance(batch_c, tf.Tensor)
    self.assertEqual(batch_c.dtype, tf.float32)
    tf.debugging.assert_equal(
        batch_c.shape.as_list(),
        expected(_COL_C, tf.float32).shape.as_list(),
    )
    tf.debugging.assert_equal(
        tf.convert_to_tensor(batch_c), expected(_COL_C, tf.float32)
    )
def test_tf_formatter_np_array_kwargs(self):
    """Check that a ``dtype`` passed to ``TFFormatter`` appears on its outputs.

    The formatter is created with ``dtype=tf.float16``. The row entry ``c``,
    the formatted column, and the batch entries ``a`` and ``c`` must all
    report that dtype. Column ``b`` is dropped from the dummy table first
    (presumably because it cannot be represented as float16 -- confirm
    against ``_create_dummy_table``).
    """
    import tensorflow as tf

    from datasets.formatting import TFFormatter

    table_without_b = self._create_dummy_table().drop(["b"])
    requested_dtype = tf.float16
    tf_formatter = TFFormatter(dtype=requested_dtype)

    formatted_row = tf_formatter.format_row(table_without_b)
    self.assertEqual(formatted_row["c"].dtype, requested_dtype)

    formatted_column = tf_formatter.format_column(table_without_b)
    self.assertEqual(formatted_column.dtype, requested_dtype)

    formatted_batch = tf_formatter.format_batch(table_without_b)
    for key in ("a", "c"):
        self.assertEqual(formatted_batch[key].dtype, requested_dtype)