Esempio n. 1
0
def test_torch_formatter_sets_default_dtypes(cast_schema, arrow_table):
    """Check that a default TorchFormatter yields int64 / float32 tensors.

    When ``cast_schema`` is truthy, ``arrow_table`` is first cast to
    ``pa.schema(cast_schema)``. The same expectations are then checked
    against differently-typed Arrow inputs.

    Comparisons use ``torch.testing.assert_close``, which checks dtypes by
    default. The deprecated ``torch.testing.assert_allclose`` does not compare
    dtypes, so it could not enforce the default-dtype contract this test is
    named after.
    """
    import torch

    from datasets.formatting import TorchFormatter

    if cast_schema:
        arrow_table = arrow_table.cast(pa.schema(cast_schema))
    arrow_table_dict = arrow_table.to_pydict()
    # Expected tensors, built with the dtypes the formatter should default to
    # regardless of the (possibly cast) Arrow column types.
    expected_int = torch.tensor(arrow_table_dict["col_int"], dtype=torch.int64)
    expected_float = torch.tensor(arrow_table_dict["col_float"], dtype=torch.float32)
    formatter = TorchFormatter()

    row = formatter.format_row(arrow_table)
    torch.testing.assert_close(row["col_int"], expected_int[0])
    torch.testing.assert_close(row["col_float"], expected_float[0])

    # The column result is expected to match "col_int" (presumably the first
    # column of arrow_table — confirm against the fixture).
    col = formatter.format_column(arrow_table)
    torch.testing.assert_close(col, expected_int)

    batch = formatter.format_batch(arrow_table)
    torch.testing.assert_close(batch["col_int"], expected_int)
    torch.testing.assert_close(batch["col_float"], expected_float)
Esempio n. 2
0
    def test_torch_formatter_np_array_kwargs(self):
        """Check that a ``dtype`` passed to TorchFormatter reaches every output.

        The formatter is built with ``dtype=torch.float16``. The row, column and
        batch outputs of the dummy table (with column "b" dropped) must all
        carry that dtype.
        """
        import torch

        from datasets.formatting import TorchFormatter

        wanted_dtype = torch.float16
        table = self._create_dummy_table().drop(["b"])
        half_formatter = TorchFormatter(dtype=wanted_dtype)

        formatted_row = half_formatter.format_row(table)
        self.assertEqual(formatted_row["c"].dtype, wanted_dtype)

        formatted_col = half_formatter.format_column(table)
        self.assertEqual(formatted_col.dtype, wanted_dtype)

        formatted_batch = half_formatter.format_batch(table)
        for column_name in ("a", "c"):
            self.assertEqual(formatted_batch[column_name].dtype, wanted_dtype)
Esempio n. 3
0
    def test_torch_formatter(self):
        """Check row, column and batch outputs of a default TorchFormatter.

        Column "b" is dropped from the dummy table before formatting. Values
        are compared against ``_COL_A`` (as int64) and ``_COL_C`` (as float32)
        using ``torch.testing.assert_close``. Unlike the deprecated
        ``assert_allclose``, it also compares dtypes, so the formatter's
        default int64/float32 output is actually verified.
        """
        import torch

        from datasets.formatting import TorchFormatter

        pa_table = self._create_dummy_table().drop(["b"])
        expected_a = torch.tensor(_COL_A, dtype=torch.int64)
        expected_c = torch.tensor(_COL_C, dtype=torch.float32)
        formatter = TorchFormatter()

        row = formatter.format_row(pa_table)
        torch.testing.assert_close(row["a"], expected_a[0])
        torch.testing.assert_close(row["c"], expected_c[0])

        # The column result is expected to match column "a".
        col = formatter.format_column(pa_table)
        torch.testing.assert_close(col, expected_a)

        batch = formatter.format_batch(pa_table)
        torch.testing.assert_close(batch["a"], expected_a)
        torch.testing.assert_close(batch["c"], expected_c)