Esempio n. 1
0
def test_tf_formatter_sets_default_dtypes(cast_schema, arrow_table):
    """Check that TFFormatter produces int64 / float32 tensors by default.

    When ``cast_schema`` is truthy, ``arrow_table`` is first cast to
    ``pa.schema(cast_schema)``. The row, column and batch outputs of a default
    ``TFFormatter`` are then compared against ragged constants built from the
    table's Python values with dtypes ``tf.int64`` (``col_int``) and
    ``tf.float32`` (``col_float``).
    """
    import tensorflow as tf

    from datasets.formatting import TFFormatter

    table = (arrow_table.cast(pa.schema(cast_schema))
             if cast_schema else arrow_table)
    as_python = table.to_pydict()
    int_values = as_python["col_int"]
    float_values = as_python["col_float"]
    tf_formatter = TFFormatter()

    # Expected values, built once and reused for row/column/batch checks.
    expected_ints = tf.ragged.constant(int_values, dtype=tf.int64)
    expected_floats = tf.ragged.constant(float_values, dtype=tf.float32)

    # A formatted row must match the first entry of each column.
    first_row = tf_formatter.format_row(table)
    tf.debugging.assert_equal(first_row["col_int"], expected_ints[0])
    tf.debugging.assert_equal(first_row["col_float"], expected_floats[0])

    # format_column is compared against col_int.
    column = tf_formatter.format_column(table)
    tf.debugging.assert_equal(column, expected_ints)

    # A formatted batch must match the whole of each column.
    whole_batch = tf_formatter.format_batch(table)
    tf.debugging.assert_equal(whole_batch["col_int"], expected_ints)
    tf.debugging.assert_equal(whole_batch["col_float"], expected_floats)
Esempio n. 2
0
    def test_tf_formatter(self):
        """Check TFFormatter's row, column and batch outputs on the dummy table.

        - The row output is compared against element ``[0]`` of ``_COL_A``,
          ``_COL_B`` and ``_COL_C``.
        - The column output is compared against ``_COL_A``.
        - For the batch output, columns ``a`` and ``b`` are compared by value.
          Column ``c`` must be a ``tf.Tensor`` of dtype float32 whose shape and
          values match ``_COL_C``.
        """
        import tensorflow as tf

        from datasets.formatting import TFFormatter

        pa_table = self._create_dummy_table()
        formatter = TFFormatter()

        row = formatter.format_row(pa_table)
        tf.debugging.assert_equal(
            row["a"],
            tf.convert_to_tensor(_COL_A, dtype=tf.int64)[0])
        tf.debugging.assert_equal(
            row["b"],
            tf.convert_to_tensor(_COL_B, dtype=tf.string)[0])
        tf.debugging.assert_equal(
            row["c"],
            tf.convert_to_tensor(_COL_C, dtype=tf.float32)[0])

        col = formatter.format_column(pa_table)
        tf.debugging.assert_equal(col,
                                  tf.ragged.constant(_COL_A, dtype=tf.int64))

        batch = formatter.format_batch(pa_table)
        tf.debugging.assert_equal(batch["a"],
                                  tf.convert_to_tensor(_COL_A, dtype=tf.int64))
        tf.debugging.assert_equal(
            batch["b"], tf.convert_to_tensor(_COL_B, dtype=tf.string))
        self.assertIsInstance(batch["c"], tf.Tensor)
        self.assertEqual(batch["c"].dtype, tf.float32)
        expected_c = tf.convert_to_tensor(_COL_C, dtype=tf.float32)
        # Compare shapes as plain Python lists. tf.debugging.assert_equal
        # would turn them into tensors and broadcast them, so shapes of
        # different rank (e.g. [2] vs [2, 2]) could falsely pass.
        self.assertEqual(batch["c"].shape.as_list(),
                         expected_c.shape.as_list())
        # batch["c"] is already asserted to be a tf.Tensor above, so it is
        # compared directly without another convert_to_tensor.
        tf.debugging.assert_equal(batch["c"], expected_c)
Esempio n. 3
0
    def test_tf_formatter_np_array_kwargs(self):
        """Check that a ``dtype`` given to TFFormatter reaches its outputs.

        The formatter is built with ``dtype=tf.float16`` and applied to the
        dummy table with column ``b`` dropped. The checked outputs must all
        have dtype float16:
        - row column ``c``;
        - the column output;
        - batch columns ``a`` and ``c``.
        """
        import tensorflow as tf

        from datasets.formatting import TFFormatter

        wanted = tf.float16
        table = self._create_dummy_table().drop(["b"])
        fmt = TFFormatter(dtype=wanted)

        self.assertEqual(fmt.format_row(table)["c"].dtype, wanted)
        self.assertEqual(fmt.format_column(table).dtype, wanted)

        formatted_batch = fmt.format_batch(table)
        for column_name in ("a", "c"):
            self.assertEqual(formatted_batch[column_name].dtype, wanted)