  def testCreate(self):
    """Tests the creation of a dataset with valid parameters."""
    parquet._RawParquetDataset(
        filenames=self._test_filenames,
        value_paths=["DocId"],
        value_dtypes=(tf.int64,),
        parent_index_paths=["DocId"],
        path_index=[0],
        batch_size=1)

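  # Argument semantics, as exercised by these tests (informal notes, not a
  # contract): value_paths names the dotted columns to read, value_dtypes
  # gives one dtype per column, parent_index_paths repeats a column once per
  # requested parent-index level, and path_index selects which step of the
  # dotted path (0 = outermost) each of those entries refers to.
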
  def testMultipleColumns(self):
    """Tests that the dataset supports multiple columns."""
    pq_ds = parquet._RawParquetDataset(
        filenames=self._test_filenames,
        value_paths=["DocId", "Name.Language.Code", "Name.Language.Country"],
        value_dtypes=(tf.int64, tf.string, tf.string),
        parent_index_paths=[
            "DocId", "Name.Language.Code", "Name.Language.Code",
            "Name.Language.Code", "Name.Language.Country"
        ],
        path_index=[0, 0, 1, 2, 2],
        batch_size=1)
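    # The flat output tuple, as these expectations show: element 0 is the
    # number of messages consumed for the batch, followed, for each value
    # path in order, by its requested parent-index vectors (one per
    # path_index entry) and then its values.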
    self.assertDatasetProduces(
        pq_ds,
        expected_output=[
            (1, [0], [10], [0, 0, 0], [0, 0, 2], [0, 1, 2],
             [b"en-us", b"en", b"en-gb"], [0, 2], [b"us", b"gb"]),
            (1, [0], [20], [0], [], [], [], [], []),
        ])

  def testMultipleColumns_TwoRowGroupsAndEqualBatchSize(self):
    """Tests multiple columns with batch size equal to the row group size."""
    pq_ds = parquet._RawParquetDataset(
        filenames=self._rowgroup_test_filenames,
        value_paths=["DocId", "Name.Language.Code", "Name.Language.Country"],
        value_dtypes=(tf.int64, tf.string, tf.string),
        parent_index_paths=[
            "DocId", "Name.Language.Code", "Name.Language.Code",
            "Name.Language.Code", "Name.Language.Country"
        ],
        path_index=[0, 0, 1, 2, 2],
        batch_size=2)
    self.assertDatasetProduces(
        pq_ds,
        expected_output=[
            (2, [0, 1], [10, 20], [0, 0, 0, 1], [0, 0, 2], [0, 1, 2],
             [b"en-us", b"en", b"en-gb"], [0, 2], [b"us", b"gb"]),
            (2, [0, 1], [30, 40], [0, 0, 0, 1], [0, 0, 2], [0, 1, 2],
             [b"en-us2", b"en2", b"en-gb2"], [0, 2], [b"us2", b"gb2"]),
        ])

  def testTwoRowGroupsAndLargerBatchSize(self):
    """Tests batch size > row group size, with two row groups.

    Input:
    RowGroup0:
      Document
        DocId: 10
      Document
        DocId: 20
    RowGroup1:
      Document
        DocId: 30
      Document
        DocId: 40
    """
    pq_ds = parquet._RawParquetDataset(
        filenames=self._rowgroup_test_filenames,
        value_paths=["DocId"],
        value_dtypes=(tf.int64,),
        parent_index_paths=["DocId"],
        path_index=[0],
        batch_size=3)
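    # A batch of 3 spans the row group boundary: 10 and 20 come from
    # RowGroup0, 30 from RowGroup1, leaving 40 for a final short batch.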
    self.assertDatasetProduces(
        pq_ds,
        expected_output=[(3, [0, 1, 2], [10, 20, 30]), (1, [0], [40])])

  def testTwoRowGroupsAndLargerBatchSizeFirstValueIsNone(self):
    """Tests batch size > row group size, where the first value is None.

    This tests that the buffer is still used properly, even if the cached
    value is None.

    Input:
    RowGroup0:
      Document
        Links
      Document
        Links
          Backward: 10
          Backward: 30
    RowGroup1:
      Document
        Links
      Document
        Links
          Backward: 100
          Backward: 300
    """
    pq_ds = parquet._RawParquetDataset(
        filenames=self._rowgroup_test_filenames,
        value_paths=["Links.Backward"],
        value_dtypes=(tf.int64,),
        parent_index_paths=["Links.Backward", "Links.Backward"],
        path_index=[0, 1],
        batch_size=3)
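    # path_index=[0, 1] requests two parent-index vectors: one for the Links
    # step under each Document, one for the Backward step under each Links.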
    self.assertDatasetProduces(
        pq_ds,
        expected_output=[(3, [0, 1, 2], [1, 1], [10, 30]),
                         (1, [0], [0, 0], [100, 300])])

  def testBatchEvenlyDivisible_ReadsTwoMessages(self):
    """Tests batch size that evenly divides the total number of messages."""
    pq_ds = parquet._RawParquetDataset(
        filenames=self._test_filenames,
        value_paths=["DocId"],
        value_dtypes=(tf.int64,),
        parent_index_paths=["DocId"],
        path_index=[0],
        batch_size=2)
    self.assertDatasetProduces(pq_ds, expected_output=[(2, [0, 1], [10, 20])])

  def testGetNext_HandlesDouble(self):
    """Tests that double is translated from parquet file to tensor."""
    pq_ds = parquet._RawParquetDataset(
        filenames=self._datatype_test_filenames,
        value_paths=["double"],
        value_dtypes=(tf.double,),
        parent_index_paths=["double"],
        path_index=[0],
        batch_size=1)
    self.assertDatasetProduces(pq_ds, expected_output=[(1, [0], [40.0])])

  def testGetNext_HandlesString(self):
    """Tests that string is translated from parquet file to tensor."""
    pq_ds = parquet._RawParquetDataset(
        filenames=self._datatype_test_filenames,
        value_paths=["byte_array"],
        value_dtypes=(tf.string,),
        parent_index_paths=["byte_array"],
        path_index=[0],
        batch_size=1)
    self.assertDatasetProduces(pq_ds, expected_output=[(1, [0], [b"fifty"])])

  def testBatchLargerThanTotal_ReadsTwoMessages(self):
    """Tests batch size that is larger than the total number of messages.

    Since the batch size is larger, the emitted batch contains fewer
    messages than the batch size.
    """
    pq_ds = parquet._RawParquetDataset(
        filenames=self._test_filenames,
        value_paths=["DocId"],
        value_dtypes=(tf.int64,),
        parent_index_paths=["DocId"],
        path_index=[0],
        batch_size=5)
    self.assertDatasetProduces(pq_ds, expected_output=[(2, [0, 1], [10, 20])])

  def testDatasetShuffle(self):
    """Tests that the dataset supports shuffling."""
    pq_ds = parquet._RawParquetDataset(
        filenames=self._test_filenames,
        value_paths=["DocId"],
        value_dtypes=(tf.int64,),
        parent_index_paths=["DocId"],
        path_index=[0],
        batch_size=1)
    pq_ds = pq_ds.shuffle(10)
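    # shuffle(10) is the standard tf.data transformation with a buffer of 10
    # elements; only the batch order may change, so assert_items_equal=True
    # compares the outputs without regard to ordering.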
    self.assertDatasetProduces(
        pq_ds,
        expected_output=[(1, [0], [10]), (1, [0], [20])],
        assert_items_equal=True)

  def testTwoRowGroupsAndDefaultBatchSizeContainsNones(self):
    """Tests default batch size with two row groups containing None values.

    Input:
    RowGroup0:
      Document
        Name
          Language
            Code: 'en-us'
          Language
            Code: 'en'
        Name
        Name
          Language
            Code: 'en-gb'
      Document
        Name
    RowGroup1:
      Document
        Name
          Language
            Code: 'en-us2'
          Language
            Code: 'en2'
        Name
        Name
          Language
            Code: 'en-gb2'
      Document
        Name
    """
    pq_ds = parquet._RawParquetDataset(
        filenames=self._rowgroup_test_filenames,
        value_paths=["Name.Language.Code"],
        value_dtypes=(tf.string,),
        parent_index_paths=[
            "Name.Language.Code", "Name.Language.Code", "Name.Language.Code"
        ],
        path_index=[0, 1, 2],
        batch_size=1)
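    # The second Document in each row group has a Name but no Language, so
    # its batch still carries the outermost parent index while the deeper
    # parent indices and values are empty.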
    self.assertDatasetProduces(
        pq_ds,
        expected_output=[
            (1, [0, 0, 0], [0, 0, 2], [0, 1, 2],
             [b"en-us", b"en", b"en-gb"]),
            (1, [0], [], [], []),
            (1, [0, 0, 0], [0, 0, 2], [0, 1, 2],
             [b"en-us2", b"en2", b"en-gb2"]),
            (1, [0], [], [], []),
        ])

  def testMultipleFilesLargeBatchSize(self):
    """Tests that the dataset supports multiple files and a large batch size."""
    # TODO(andylou) We don't want this behavior in the future.
    # We eventually want the dataset to grab batches across files.
    pq_ds = parquet._RawParquetDataset(
        filenames=[self._test_filenames[0], self._rowgroup_test_filenames[0]],
        value_paths=["DocId"],
        value_dtypes=(tf.int64,),
        parent_index_paths=["DocId"],
        path_index=[0],
        batch_size=5)
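    # Today each file is batched independently, so the 5-message batch splits
    # into 2 + 4 at the file boundary instead of filling across files.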
    self.assertDatasetProduces(
        pq_ds,
        expected_output=[(2, [0, 1], [10, 20]),
                         (4, [0, 1, 2, 3], [10, 20, 30, 40])])

  def testMultipleFiles(self):
    """Tests that the dataset supports multiple files."""
    pq_ds = parquet._RawParquetDataset(
        filenames=[self._test_filenames[0], self._rowgroup_test_filenames[0]],
        value_paths=["DocId"],
        value_dtypes=(tf.int64,),
        parent_index_paths=["DocId"],
        path_index=[0],
        batch_size=1)
    self.assertDatasetProduces(
        pq_ds,
        expected_output=[(1, [0], [10]), (1, [0], [20]), (1, [0], [10]),
                         (1, [0], [20]), (1, [0], [30]), (1, [0], [40])])

  def testBatchLargerThanTotalContainsNones_ReadsTwoMessages(self):
    """Tests batch size that is larger than the total number of messages.

    The messages also contain None values.
    """
    pq_ds = parquet._RawParquetDataset(
        filenames=self._test_filenames,
        value_paths=["Name.Language.Country"],
        value_dtypes=(tf.string,),
        parent_index_paths=[
            "Name.Language.Country", "Name.Language.Country",
            "Name.Language.Country"
        ],
        path_index=[0, 1, 2],
        batch_size=5)
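    # Both messages fit in one batch. The parent-index vectors walk the path
    # down: [0, 0, 0, 1] places four Names across the two Documents,
    # [0, 0, 2] places three Languages under those Names, and [0, 2] places
    # the two Country values under Languages 0 and 2.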
    self.assertDatasetProduces(
        pq_ds,
        expected_output=[(2, [0, 0, 0, 1], [0, 0, 2], [0, 2],
                          [b"us", b"gb"])])

  def testFirstParentIndexOnly(self):
    """Tests requesting only one parent index.

    Even when the path contains more than one, i.e. Name.Language.Code
    should have 4 parent indices (including the root).
    """
    pq_ds = parquet._RawParquetDataset(
        filenames=self._test_filenames,
        value_paths=["Name.Language.Code"],
        value_dtypes=(tf.string,),
        parent_index_paths=["Name.Language.Code"],
        path_index=[0],
        batch_size=1)
    self.assertDatasetProduces(
        pq_ds,
        expected_output=[(1, [0, 0, 0], [b"en-us", b"en", b"en-gb"]),
                         (1, [0], [])])

  def testInvalidParentIndexPaths(self):
    """Tests that a wrong parent_index_paths order raises an error."""
    with self.assertRaisesRegex(
        tf.errors.InvalidArgumentError,
        "parent_index_paths is not aligned with value_paths"):
      pq_ds = parquet._RawParquetDataset(
          filenames=self._test_filenames,
          value_paths=[
              "DocId", "Name.Language.Code", "Name.Language.Country"
          ],
          value_dtypes=(tf.int64, tf.string, tf.string),
          parent_index_paths=[
              "Name.Language.Code", "Name.Language.Code",
              "Name.Language.Code", "DocId", "Name.Language.Country"
          ],
          path_index=[0, 1, 2, 0, 2],
          batch_size=1)
      get_next = self._getNext(pq_ds, True)
      self.evaluate(get_next())

  def testTwoRowGroupsAndEqualBatchSizeLargeFirstMessage(self):
    """Tests batch size == row group size, where the first message is large.

    Input:
    RowGroup0:
      Document
        Links
          Forward: 20
          Forward: 40
          Forward: 60
      Document
        Links
          Forward: 80
    RowGroup1:
      Document
        Links
          Forward: 200
          Forward: 400
          Forward: 600
      Document
        Links
          Forward: 800
    """
    pq_ds = parquet._RawParquetDataset(
        filenames=self._rowgroup_test_filenames,
        value_paths=["Links.Forward"],
        value_dtypes=(tf.int64,),
        parent_index_paths=["Links.Forward", "Links.Forward"],
        path_index=[0, 1],
        batch_size=2)
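    # Each batch covers one row group; parent indices [0, 0, 0, 1] attribute
    # three Forward values to the first Links and one to the second.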
    self.assertDatasetProduces(
        pq_ds,
        expected_output=[
            (2, [0, 1], [0, 0, 0, 1], [20, 40, 60, 80]),
            (2, [0, 1], [0, 0, 0, 1], [200, 400, 600, 800]),
        ])