def test_array2d_nonspecific_shape(self):
    """Different extension columns in one row may carry different shapes.

    Writes a single generated example to an Arrow file, reads it back in
    numpy format, and checks that the "image" and "text" columns have
    shapes of equal rank but unequal dimensions.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        my_features = DEFAULT_FEATURES.copy()
        writer = ArrowWriter(features=my_features,
                             path=os.path.join(tmp_dir, "beta.arrow"))
        for key, record in generate_examples(
            features=my_features,
            num_examples=1,
        ):
            example = my_features.encode_example(record)
            writer.write(example)
        num_examples, num_bytes = writer.finalize()

        dataset = datasets.Dataset.from_file(
            os.path.join(tmp_dir, "beta.arrow"))
        dataset.set_format("numpy")
        row = dataset[0]
        first_shape = row["image"].shape
        second_shape = row["text"].shape
        # NOTE(review): numpy .shape is never None, so this guard is
        # effectively always true; kept for parity with the original.
        # Fixed the "atleast" typo in the failure message.
        self.assertTrue(
            first_shape is not None and second_shape is not None,
            "need at least 2 different shapes")
        self.assertEqual(len(first_shape), len(second_shape),
                         "both shapes are supposed to be equal length")
        self.assertNotEqual(first_shape, second_shape,
                            "shapes must not be the same")
        del dataset
def write(my_features, dummy_data, tmp_dir):
    """Encode and write dummy examples into "beta.arrow" under *tmp_dir*.

    Args:
        my_features: a ``datasets.Features`` mapping used to encode records.
        dummy_data: iterable of ``(key, record)`` pairs; keys are ignored.
        tmp_dir: directory in which the Arrow file is created.

    Returns:
        ``(num_examples, num_bytes)`` from ``ArrowWriter.finalize()``.
        The original computed this tuple but discarded it; returning it is
        backward-compatible and lets callers check what was written.
    """
    writer = ArrowWriter(features=my_features,
                         path=os.path.join(tmp_dir, "beta.arrow"))
    for _key, record in dummy_data:  # keys are not used by the writer
        example = my_features.encode_example(record)
        writer.write(example)
    return writer.finalize()
def save(self, path: str) -> None:
    """Persist this dataset to *path*: Arrow data, DatasetInfo, and split.

    Creates the target directory, streams every example through an
    ArrowWriter into ``data.arrow``, writes the dataset info next to it,
    and pickles the split object to ``split.p``.
    """
    # Make all the directories to the path
    os.makedirs(path, exist_ok=True)

    # Taken from Huggingface datasets.Dataset:
    # prepare output buffer and batched writer in memory or on file if we
    # update the table
    writer = ArrowWriter(
        features=self.features,
        path=os.path.join(path, "data.arrow"),
        writer_batch_size=1000,
    )
    # Loop over examples and write to file. The original wrapped this in
    # enumerate() but never used the index, so iterate directly.
    for example in tqdm(self):
        writer.write(example)
    writer.finalize()

    # Write DatasetInfo
    self.info.write_to_directory(path)

    # Write split to file
    with open(os.path.join(path, "split.p"), "wb") as f:
        pickle.dump(self.split, f)
def test_write_schema(self):
    """Rows written against an explicit pyarrow schema keep that schema."""
    fields = {"col_1": pa.string(), "col_2": pa.int64()}
    output = pa.BufferOutputStream()
    writer = ArrowWriter(stream=output, schema=pa.schema(fields))
    rows = (
        {"col_1": "foo", "col_2": 1},
        {"col_1": "bar", "col_2": 2},
    )
    for row in rows:
        writer.write(row)
    num_examples, num_bytes = writer.finalize()
    self.assertEqual(num_examples, 2)
    self.assertGreater(num_bytes, 0)
    expected_schema = pa.schema(fields, metadata=writer._schema.metadata)
    self.assertEqual(writer._schema, expected_schema)
    self._check_output(output.getvalue())
def test_compatability_with_string_values(self):
    """Extension features coexist with a plain string-valued column."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        features = DEFAULT_FEATURES.copy()
        features["image_id"] = datasets.Value("string")
        arrow_path = os.path.join(tmp_dir, "beta.arrow")
        writer = ArrowWriter(features=features, path=arrow_path)
        for _key, record in generate_examples(features=features,
                                              num_examples=1):
            writer.write(features.encode_example(record))
        num_examples, num_bytes = writer.finalize()
        dataset = datasets.Dataset.from_file(arrow_path)
        self.assertTrue(isinstance(dataset[0]["image_id"], str),
                        "image id must be of type string")
def test_extension_indexing(self):
    """Indexing an explicitly declared Array2D column yields an ndarray."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        features = DEFAULT_FEATURES.copy()
        features["explicit_ext"] = Array2D((3, 3), dtype="float32")
        arrow_path = os.path.join(tmp_dir, "beta.arrow")
        writer = ArrowWriter(features=features, path=arrow_path)
        for _key, record in generate_examples(features=features,
                                              num_examples=1):
            writer.write(features.encode_example(record))
        num_examples, num_bytes = writer.finalize()
        dataset = datasets.Dataset.from_file(arrow_path)
        dataset.set_format("numpy")
        data = dataset[0]["explicit_ext"]
        self.assertIsInstance(data, np.ndarray,
                              "indexed extension must return numpy.ndarray")
def test_write(self, array_feature, shape_1, shape_2):
    """Round-trip two encoded examples through an on-disk Arrow file."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        features = self.get_features(array_feature, shape_1, shape_2)
        arrow_path = os.path.join(tmp_dir, "beta.arrow")
        writer = ArrowWriter(features=features, path=arrow_path)
        examples = [
            (0, self.get_dict_example_0(shape_1, shape_2)),
            (1, self.get_dict_example_1(shape_1, shape_2)),
        ]
        for _key, record in examples:
            writer.write(features.encode_example(record))
        num_examples, num_bytes = writer.finalize()
        dataset = datasets.Dataset.from_file(arrow_path)
        self._check_getitem_output_type(dataset, shape_1, shape_2,
                                        examples[0][1]["matrix"])
def test_multiple_extensions_same_row(self):
    """Both extension columns of a single row must be rank-2 arrays."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        features = DEFAULT_FEATURES.copy()
        arrow_path = os.path.join(tmp_dir, "beta.arrow")
        writer = ArrowWriter(features=features, path=arrow_path)
        for _key, record in generate_examples(features=features,
                                              num_examples=1):
            writer.write(features.encode_example(record))
        num_examples, num_bytes = writer.finalize()
        dataset = datasets.Dataset.from_file(arrow_path)
        dataset.set_format("numpy")
        row = dataset[0]
        image_rank = len(row["image"].shape)
        text_rank = len(row["text"].shape)
        self.assertEqual(image_rank, 2, "use a sequence type if dim is < 2")
        self.assertEqual(text_rank, 2, "use a sequence type if dim is < 2")