def test_torch_formatter_sets_default_dtypes(cast_schema, arrow_table):
    """Check that TorchFormatter applies its default dtypes when none is given.

    The table is optionally re-typed with ``cast_schema`` first. Regardless of
    the Arrow types of ``col_int`` / ``col_float``, the formatter's row, column
    and batch outputs are expected to be ``torch.int64`` and ``torch.float32``
    tensors respectively.

    ``torch.testing.assert_close`` is used instead of the deprecated
    ``torch.testing.assert_allclose``: the latter compares with
    ``check_dtype=False``, so it could never detect a wrong default dtype,
    which is exactly the property under test here.
    """
    import torch

    from datasets.formatting import TorchFormatter

    if cast_schema:
        # NOTE(review): cast_schema presumably comes from a parametrize
        # decorator outside this view; it re-types the fixture columns.
        arrow_table = arrow_table.cast(pa.schema(cast_schema))
    arrow_table_dict = arrow_table.to_pydict()
    list_int = arrow_table_dict["col_int"]
    list_float = arrow_table_dict["col_float"]
    # Expected values carry the dtypes the formatter should default to.
    expected_int = torch.tensor(list_int, dtype=torch.int64)
    expected_float = torch.tensor(list_float, dtype=torch.float32)
    formatter = TorchFormatter()

    row = formatter.format_row(arrow_table)
    torch.testing.assert_close(row["col_int"], expected_int[0])
    torch.testing.assert_close(row["col_float"], expected_float[0])

    # The formatted column is compared with col_int, so col_int is assumed to
    # be the column format_column picks (the first one) -- confirm vs fixture.
    col = formatter.format_column(arrow_table)
    torch.testing.assert_close(col, expected_int)

    batch = formatter.format_batch(arrow_table)
    torch.testing.assert_close(batch["col_int"], expected_int)
    torch.testing.assert_close(batch["col_float"], expected_float)
def test_torch_formatter_np_array_kwargs(self):
    """Check that a ``dtype`` kwarg given to TorchFormatter is applied to every output.

    Column "b" is dropped from the dummy table. The formatter is built with
    ``dtype=torch.float16``, and the tensors for both remaining columns ("a"
    and "c") must come out as float16 at row, column and batch level.
    """
    import torch

    from datasets.formatting import TorchFormatter

    pa_table = self._create_dummy_table().drop(["b"])
    formatter = TorchFormatter(dtype=torch.float16)

    row = formatter.format_row(pa_table)
    # Check both columns at row level, mirroring the batch-level checks below.
    self.assertEqual(row["a"].dtype, torch.float16)
    self.assertEqual(row["c"].dtype, torch.float16)

    col = formatter.format_column(pa_table)
    self.assertEqual(col.dtype, torch.float16)

    batch = formatter.format_batch(pa_table)
    self.assertEqual(batch["a"].dtype, torch.float16)
    self.assertEqual(batch["c"].dtype, torch.float16)
def test_torch_formatter(self):
    """Check TorchFormatter's row/column/batch output against ``_COL_A`` / ``_COL_C``.

    Column "b" is dropped before formatting. Column "a" is expected as an
    int64 tensor and column "c" as a float32 tensor.
    ``torch.testing.assert_close`` is used instead of the deprecated
    ``assert_allclose`` so those dtypes are actually compared, not just the values.
    """
    import torch

    from datasets.formatting import TorchFormatter

    pa_table = self._create_dummy_table().drop(["b"])
    expected_a = torch.tensor(_COL_A, dtype=torch.int64)
    expected_c = torch.tensor(_COL_C, dtype=torch.float32)
    formatter = TorchFormatter()

    row = formatter.format_row(pa_table)
    torch.testing.assert_close(row["a"], expected_a[0])
    torch.testing.assert_close(row["c"], expected_c[0])

    col = formatter.format_column(pa_table)
    torch.testing.assert_close(col, expected_a)

    batch = formatter.format_batch(pa_table)
    torch.testing.assert_close(batch["a"], expected_a)
    torch.testing.assert_close(batch["c"], expected_c)