def test_array2d_nonspecific_shape(self):
    """Columns holding variably-shaped arrays must round-trip with distinct shapes."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        feats = DEFAULT_FEATURES.copy()
        arrow_file = os.path.join(tmp_dir, "beta.arrow")
        writer = ArrowWriter(features=feats, path=arrow_file)
        # Encode and persist a single generated example.
        for _key, raw in generate_examples(features=feats, num_examples=1):
            writer.write(feats.encode_example(raw))
        num_examples, num_bytes = writer.finalize()
        dataset = datasets.Dataset.from_file(arrow_file)
        dataset.set_format("numpy")
        row = dataset[0]
        shape_a = row["image"].shape
        shape_b = row["text"].shape
        self.assertTrue(
            shape_a is not None and shape_b is not None,
            "need atleast 2 different shapes")
        # Same rank, but the concrete shapes must differ.
        self.assertEqual(len(shape_a), len(shape_b),
                         "both shapes are supposed to be equal length")
        self.assertNotEqual(shape_a, shape_b,
                            "shapes must not be the same")
        del dataset
# Example #2
def write(my_features, dummy_data, tmp_dir):
    """Encode ``dummy_data`` with ``my_features`` and write it to ``tmp_dir``/beta.arrow.

    Args:
        my_features: a ``datasets.Features`` mapping used to encode each record.
        dummy_data: iterable of ``(key, record)`` pairs; keys are ignored.
        tmp_dir: directory in which the Arrow file is created.

    Returns:
        Tuple ``(num_examples, num_bytes)`` as reported by ``ArrowWriter.finalize()``.
    """
    writer = ArrowWriter(features=my_features,
                         path=os.path.join(tmp_dir, "beta.arrow"))
    for key, record in dummy_data:
        example = my_features.encode_example(record)
        writer.write(example)
    # The finalize() stats were previously computed and discarded; return them
    # so callers can assert on what was written. Returning a value is
    # backward-compatible with callers that ignore it.
    num_examples, num_bytes = writer.finalize()
    return num_examples, num_bytes
    def save(self, path: str) -> None:
        """Persist this dataset under *path*: Arrow data, DatasetInfo, and split.

        Writes ``data.arrow`` (all examples), the dataset info metadata, and a
        pickled split marker (``split.p``) into the target directory.
        """
        # Create the full directory chain before writing any files.
        os.makedirs(path, exist_ok=True)

        # Adapted from Huggingface datasets.Dataset: stream every example
        # through a batched ArrowWriter backed by a file on disk.
        arrow_writer = ArrowWriter(
            features=self.features,
            path=os.path.join(path, "data.arrow"),
            writer_batch_size=1000,
        )
        for example in tqdm(self):
            arrow_writer.write(example)
        arrow_writer.finalize()

        # Persist the DatasetInfo metadata next to the data file.
        self.info.write_to_directory(path)

        # Record which split this dataset represents.
        with open(os.path.join(path, "split.p"), "wb") as f:
            pickle.dump(self.split, f)
 def test_write_schema(self):
     """An explicitly supplied schema must survive write + finalize unchanged."""
     fields = {"col_1": pa.string(), "col_2": pa.int64()}
     output = pa.BufferOutputStream()
     writer = ArrowWriter(stream=output, schema=pa.schema(fields))
     for row in ({"col_1": "foo", "col_2": 1}, {"col_1": "bar", "col_2": 2}):
         writer.write(row)
     num_examples, num_bytes = writer.finalize()
     self.assertEqual(num_examples, 2)
     self.assertGreater(num_bytes, 0)
     # Compare against the same fields carrying whatever metadata the
     # writer attached during finalize.
     expected = pa.schema(fields, metadata=writer._schema.metadata)
     self.assertEqual(writer._schema, expected)
     self._check_output(output.getvalue())
 def test_compatability_with_string_values(self):
     """A string-valued column must decode back as a Python ``str``."""
     with tempfile.TemporaryDirectory() as tmp_dir:
         feats = DEFAULT_FEATURES.copy()
         feats["image_id"] = datasets.Value("string")
         arrow_file = os.path.join(tmp_dir, "beta.arrow")
         writer = ArrowWriter(features=feats, path=arrow_file)
         for _key, raw in generate_examples(features=feats, num_examples=1):
             writer.write(feats.encode_example(raw))
         num_examples, num_bytes = writer.finalize()
         dataset = datasets.Dataset.from_file(arrow_file)
         self.assertTrue(isinstance(dataset[0]["image_id"], str),
                         "image id must be of type string")
 def test_extension_indexing(self):
     """Indexing an Array2D extension column must yield a ``numpy.ndarray``."""
     with tempfile.TemporaryDirectory() as tmp_dir:
         feats = DEFAULT_FEATURES.copy()
         feats["explicit_ext"] = Array2D((3, 3), dtype="float32")
         arrow_file = os.path.join(tmp_dir, "beta.arrow")
         writer = ArrowWriter(features=feats, path=arrow_file)
         for _key, raw in generate_examples(features=feats, num_examples=1):
             writer.write(feats.encode_example(raw))
         num_examples, num_bytes = writer.finalize()
         dataset = datasets.Dataset.from_file(arrow_file)
         dataset.set_format("numpy")
         data = dataset[0]["explicit_ext"]
         self.assertIsInstance(
             data, np.ndarray,
             "indexed extension must return numpy.ndarray")
    def test_write(self, array_feature, shape_1, shape_2):
        """Round-trip two parameterized examples and verify __getitem__ typing."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            feats = self.get_features(array_feature, shape_1, shape_2)
            arrow_file = os.path.join(tmp_dir, "beta.arrow")
            writer = ArrowWriter(features=feats, path=arrow_file)
            examples = [
                (0, self.get_dict_example_0(shape_1, shape_2)),
                (1, self.get_dict_example_1(shape_1, shape_2)),
            ]
            for _key, raw in examples:
                writer.write(feats.encode_example(raw))
            num_examples, num_bytes = writer.finalize()
            dataset = datasets.Dataset.from_file(arrow_file)
            self._check_getitem_output_type(dataset, shape_1, shape_2,
                                            examples[0][1]["matrix"])
 def test_multiple_extensions_same_row(self):
     """Both extension columns in a single row must decode as 2-D arrays."""
     with tempfile.TemporaryDirectory() as tmp_dir:
         feats = DEFAULT_FEATURES.copy()
         arrow_file = os.path.join(tmp_dir, "beta.arrow")
         writer = ArrowWriter(features=feats, path=arrow_file)
         for _key, raw in generate_examples(features=feats, num_examples=1):
             writer.write(feats.encode_example(raw))
         num_examples, num_bytes = writer.finalize()
         dataset = datasets.Dataset.from_file(arrow_file)
         dataset.set_format("numpy")
         row = dataset[0]
         # Both columns carry the same rank requirement.
         for column in ("image", "text"):
             self.assertEqual(len(row[column].shape), 2,
                              "use a sequence type if dim is  < 2")