Exemplo n.º 1
0
  def test_split_filtered(self, mock_client):
    """Splits a source whose filter keeps two documents (2 <= 'x' < 4).

    For every desired bundle size the number of splits is compared with the
    expectation for the active mode (``self.bucket_auto`` or not), and the
    splits are checked against a reference source bounded by the ids of
    ``self._docs[2]`` and ``self._docs[4]``.
    """
    query = {'x': {'$gte': 2, '$lt': 4}}
    filtered_mongo_source = self._create_source(
        filter=query, bucket_auto=self.bucket_auto)

    mock_client.return_value = _MockMongoClient(self._docs)
    # Each case: (bundle size in MB, (bucket_auto count, splitVector count)).
    cases = [(1, (2, 5)), (2, (1, 3)), (10, (1, 1))]
    for megabytes, expected_counts in cases:
      bucket_auto_count, split_vector_count = expected_counts
      bundles = list(
          filtered_mongo_source.split(
              start_position=None,
              stop_position=None,
              desired_bundle_size=megabytes * 1024 * 1024))

      # Note: splitVector mode does not respect filter
      expected = bucket_auto_count if self.bucket_auto else split_vector_count
      self.assertEqual(len(bundles), expected)

      reference_info = (
          filtered_mongo_source, self._docs[2]['_id'], self._docs[4]['_id'])
      sources_info = []
      for bundle in bundles:
        sources_info.append(
            (bundle.source, bundle.start_position, bundle.stop_position))
      source_test_utils.assert_sources_equal_reference_source(
          reference_info, sources_info)
Exemplo n.º 2
0
  def test_split_filtered_empty(self, mock_client):
    """Splits a source whose filter ('x' < 0) is meant to match no documents.

    For every desired bundle size the number of splits is compared with the
    expectation for the active mode, and the splits are checked against a
    reference source whose range starts and ends past the id of the last
    document in ``self._docs``.
    """
    filtered_mongo_source = self._create_source(
        filter={'x': {'$lt': 0}}, bucket_auto=self.bucket_auto)

    mock_client.return_value = _MockMongoClient(self._docs)
    # Each case: (bundle size in MB, (bucket_auto count, splitVector count)).
    cases = [(1, (1, 5)), (2, (1, 3)), (10, (1, 1))]
    for megabytes, (bucket_auto_count, split_vector_count) in cases:
      bundles = list(
          filtered_mongo_source.split(
              start_position=None,
              stop_position=None,
              desired_bundle_size=megabytes * 1024 * 1024))

      if not self.bucket_auto:
        # Note: splitVector mode does not respect filter
        expected = split_vector_count
      else:
        # Note: if filter matches no docs - one split covers entire range
        expected = bucket_auto_count
      self.assertEqual(len(bundles), expected)

      # Range chosen to match no documents: it begins after the last id.
      last_id = self._docs[-1]['_id']
      reference_info = (
          filtered_mongo_source,
          _ObjectIdHelper.increment_id(last_id, 1),
          _ObjectIdHelper.increment_id(last_id, 2),
      )
      sources_info = []
      for bundle in bundles:
        sources_info.append(
            (bundle.source, bundle.start_position, bundle.stop_position))
      source_test_utils.assert_sources_equal_reference_source(
          reference_info, sources_info)
Exemplo n.º 3
0
  def test_split(self):
    """Checks splits of ``self.source`` against the unsplit source.

    Runs once for each desired bundle size in 1, 3 and 10.
    """
    for bundle_size in (1, 3, 10):
      sources_info = []
      for bundle in list(self.source.split(desired_bundle_size=bundle_size)):
        sources_info.append(
            (bundle.source, bundle.start_position, bundle.stop_position))

      reference_info = (self.source, None, None)
      source_test_utils.assert_sources_equal_reference_source(
          reference_info, sources_info)
Exemplo n.º 4
0
 def test_synthetic_source_split_uneven(self):
     """Splits a 'zipf'-configured SyntheticSource and checks the result.

     The source is split with a desired bundle size of 100; exactly 10 splits
     are expected, and they are compared with the unsplit source.
     """
     spec = input_spec(1000, 1, 1, 'zipf', 3, 10)
     source = synthetic_pipeline.SyntheticSource(spec)

     sources_info = []
     for piece in source.split(100):
         sources_info.append(
             (piece.source, piece.start_position, piece.stop_position))

     self.assertEqual(10, len(sources_info))
     reference_info = (source, None, None)
     source_test_utils.assert_sources_equal_reference_source(
         reference_info, sources_info)
 def testSyntheticSourceSplitUneven(self):
   """Splits a 'zipf'-configured SyntheticSource and checks the result.

   The source is split with a desired bundle size of 100; exactly 10 splits
   are expected, and they are compared with the unsplit source.
   """
   source = synthetic_pipeline.SyntheticSource(
       input_spec(1000, 1, 1, 'zipf', 3, 10))
   splits = source.split(100)
   sources_info = [(split.source, split.start_position, split.stop_position)
                   for split in splits]
   # assertEquals is a deprecated alias that was removed in Python 3.12.
   self.assertEqual(10, len(sources_info))
   source_test_utils.assert_sources_equal_reference_source(
       (source, None, None), sources_info)
  def test_source_equals_reference_source(self):
    """Checks that splits of a source match the unsplit reference source.

    The source is built from ``self._create_data(100)`` and split with a
    desired bundle size of 50.

    Raises:
      ValueError: if fewer than two splits are produced, since the comparison
        would then be trivial.
    """
    data = self._create_data(100)
    reference_source = self._create_source(data)
    sources_info = [(split.source, split.start_position, split.stop_position)
                    for split in reference_source.split(desired_bundle_size=50)]
    if len(sources_info) < 2:
      # Adjacent literals are concatenated verbatim, so the space after '%d'
      # must be part of the literal itself.
      raise ValueError('Test is too trivial since splitting only generated %d '
                       'bundles. Please adjust the test so that at least '
                       'two splits get generated.' % len(sources_info))

    source_test_utils.assert_sources_equal_reference_source(
        (reference_source, None, None), sources_info)
Exemplo n.º 7
0
 def check_read_with_initial_splits(self, values, coder, num_splits):
     """Splits the source for `values` and compares it with the original.

     A test that splits the given source into `num_splits` and verifies that
     the data read from the original source is equal to the union of the data
     read from the split sources.
     """
     source = Create._create_source_from_iterable(values, coder)
     # NOTE(review): `/` is true division on Python 3, so this may be a float
     # - confirm that split() accepts a non-integer bundle size.
     desired_bundle_size = source._total_size / num_splits

     splits_info = []
     for piece in source.split(desired_bundle_size):
         splits_info.append(
             (piece.source, piece.start_position, piece.stop_position))

     reference_info = (source, None, None)
     source_test_utils.assert_sources_equal_reference_source(
         reference_info, splits_info)
Exemplo n.º 8
0
    def test_read_after_splitting(self):
        """Splits a 10-record text file and compares splits with the source.

        The file comes from ``write_data(10)`` and is split with a desired
        bundle size of 33.
        """
        file_name, expected_data = write_data(10)
        assert len(expected_data) == 10

        source = TextSource(
            file_name, 0, CompressionTypes.UNCOMPRESSED, True,
            coders.StrUtf8Coder())

        sources_info = []
        for piece in list(source.split(desired_bundle_size=33)):
            sources_info.append(
                (piece.source, piece.start_position, piece.stop_position))

        source_test_utils.assert_sources_equal_reference_source(
            (source, None, None), sources_info)
Exemplo n.º 9
0
  def test_source_equals_reference_source(self):
    """Checks that splits of a source match the unsplit reference source.

    The source is built from ``self._create_data(100)`` and split with a
    desired bundle size of 50.

    Raises:
      ValueError: if fewer than two splits are produced, since the comparison
        would then be trivial.
    """
    data = self._create_data(100)
    reference_source = self._create_source(data)
    sources_info = [(split.source, split.start_position, split.stop_position)
                    for split in reference_source.split(desired_bundle_size=50)]
    if len(sources_info) < 2:
      # Adjacent literals are concatenated verbatim, so the space after '%d'
      # must be part of the literal itself.
      raise ValueError('Test is too trivial since splitting only generated %d '
                       'bundles. Please adjust the test so that at least '
                       'two splits get generated.' % len(sources_info))

    source_test_utils.assert_sources_equal_reference_source(
        (reference_source, None, None), sources_info)
Exemplo n.º 10
0
  def test_read_after_splitting(self):
    """Splits a 10-record text file and compares splits with the source.

    The file comes from ``write_data(10)`` and is split with a desired bundle
    size of 33.
    """
    file_name, expected_data = write_data(10)
    assert len(expected_data) == 10

    source = TextSource(
        file_name, 0, CompressionTypes.UNCOMPRESSED, True,
        coders.StrUtf8Coder())

    sources_info = []
    for piece in list(source.split(desired_bundle_size=33)):
      sources_info.append(
          (piece.source, piece.start_position, piece.stop_position))

    source_test_utils.assert_sources_equal_reference_source(
        (source, None, None), sources_info)
Exemplo n.º 11
0
 def check_read_with_initial_splits(self, values, coder, num_splits):
   """Splits the source for `values` and compares it with the original.

   A test that splits the given source into `num_splits` and verifies that
   the data read from the original source is equal to the union of the data
   read from the split sources.
   """
   source = Create._create_source_from_iterable(values, coder)
   # NOTE(review): `/` is true division on Python 3, so this may be a float
   # - confirm that split() accepts a non-integer bundle size.
   desired_bundle_size = source._total_size / num_splits

   splits_info = []
   for piece in source.split(desired_bundle_size):
     splits_info.append(
         (piece.source, piece.start_position, piece.stop_position))

   reference_info = (source, None, None)
   source_test_utils.assert_sources_equal_reference_source(
       reference_info, splits_info)
Exemplo n.º 12
0
  def test_split(self, mock_client):
    """Checks splits of ``self.mongo_source`` against the unsplit source.

    The client mock returns a ``_MockMongoClient`` over ``self._docs``; the
    source is split with desired bundle sizes of 1, 2 and 10 MB.
    """
    mock_client.return_value = _MockMongoClient(self._docs)
    bundle_sizes = [megabytes * 1024 * 1024 for megabytes in (1, 2, 10)]
    for bundle_size in bundle_sizes:
      bundles = list(
          self.mongo_source.split(
              start_position=None,
              stop_position=None,
              desired_bundle_size=bundle_size))

      sources_info = []
      for bundle in bundles:
        sources_info.append(
            (bundle.source, bundle.start_position, bundle.stop_position))

      reference_info = (self.mongo_source, None, None)
      source_test_utils.assert_sources_equal_reference_source(
          reference_info, sources_info)
    def test_split(self, mock_client):
        """Checks split counts and contents for ``self.mongo_source``.

        The client mock returns a ``_MockMongoClient`` over ``self._docs``.
        Each case pairs a desired bundle size in MB with the number of splits
        expected for it; the splits are also compared with the unsplit source.
        """
        mock_client.return_value = _MockMongoClient(self._docs)
        # Each case: (bundle size in MB, expected number of splits).
        cases = [(0.5, 5), (1, 5), (2, 3), (10, 1)]
        for megabytes, expected_split_count in cases:
            bundles = list(
                self.mongo_source.split(
                    start_position=None,
                    stop_position=None,
                    desired_bundle_size=megabytes * 1024 * 1024))
            self.assertEqual(len(bundles), expected_split_count)

            sources_info = []
            for bundle in bundles:
                sources_info.append(
                    (bundle.source, bundle.start_position,
                     bundle.stop_position))

            reference_info = (self.mongo_source, None, None)
            source_test_utils.assert_sources_equal_reference_source(
                reference_info, sources_info)
Exemplo n.º 14
0
  def _run_parquet_test(self, pattern, columns, desired_bundle_size,
                        perform_splitting, expected_result):
    """Reads a parquet source either directly or through its splits.

    Args:
      pattern: passed to ``_create_parquet_source``.
      columns: passed to ``_create_parquet_source`` as ``columns``.
      desired_bundle_size: bundle size used when splitting; must be truthy
        when ``perform_splitting`` is set.
      perform_splitting: when true, the splits are compared with the unsplit
        source; otherwise the records read are compared with
        ``expected_result``.
      expected_result: records expected from a direct read.

    Raises:
      ValueError: if splitting yields fewer than two splits.
    """
    source = _create_parquet_source(pattern, columns=columns)

    if not perform_splitting:
      read_records = source_test_utils.read_from_source(source, None, None)
      self.assertCountEqual(expected_result, read_records)
      return

    assert desired_bundle_size
    sources_info = []
    for piece in source.split(desired_bundle_size=desired_bundle_size):
      sources_info.append(
          (piece.source, piece.start_position, piece.stop_position))

    if len(sources_info) < 2:
      raise ValueError('Test is trivial. Please adjust it so that at least '
                       'two splits get generated')

    reference_info = (source, None, None)
    source_test_utils.assert_sources_equal_reference_source(
        reference_info, sources_info)
Exemplo n.º 15
0
    def test_split(self, mock_client):
        """Checks splits of ``self.mongo_source`` against the unsplit source.

        ``find`` on the mocked collection returns five documents with 'x'
        from 1 to 5. The source is split over positions 0 to 5 with desired
        bundle sizes of 10, 20 and 100.
        """
        # Walk the mock chain: client context manager, then two item lookups.
        entered_client = mock_client.return_value.__enter__.return_value
        collection = (
            entered_client.__getitem__.return_value.__getitem__.return_value)
        collection.find.return_value = [{'x': value} for value in range(1, 6)]

        for bundle_size in (10, 20, 100):
            bundles = list(
                self.mongo_source.split(
                    start_position=0,
                    stop_position=5,
                    desired_bundle_size=bundle_size))

            sources_info = []
            for bundle in bundles:
                sources_info.append(
                    (bundle.source, bundle.start_position,
                     bundle.stop_position))

            reference_info = (self.mongo_source, None, None)
            source_test_utils.assert_sources_equal_reference_source(
                reference_info, sources_info)
Exemplo n.º 16
0
  def test_read_gzip_large_after_splitting(self):
    """Checks that a gzip-compressed text file yields at most one split.

    The lines returned by ``write_data(10000)`` are written newline-joined to
    a gzip file. The file is split with a desired bundle size of 1000 and the
    splits are compared with the unsplit source.

    Raises:
      ValueError: if more than one initial split is generated.
    """
    _, lines = write_data(10000)
    file_name = self._create_temp_file()
    payload = '\n'.join(lines)
    # GzipFile in 'wb' mode only accepts bytes on Python 3; encode text so
    # the write does not raise TypeError.
    if not isinstance(payload, bytes):
      payload = payload.encode('utf-8')
    with gzip.GzipFile(file_name, 'wb') as f:
      f.write(payload)

    source = TextSource(file_name, 0, CompressionTypes.GZIP, True,
                        coders.StrUtf8Coder())
    splits = list(source.split(desired_bundle_size=1000))

    if len(splits) > 1:
      raise ValueError('FileBasedSource generated more than one initial split '
                       'for a compressed file.')

    reference_source_info = (source, None, None)
    sources_info = [
        (split.source, split.start_position, split.stop_position)
        for split in splits
    ]
    source_test_utils.assert_sources_equal_reference_source(
        reference_source_info, sources_info)
Exemplo n.º 17
0
    def test_read_gzip_large_after_splitting(self):
        """Checks that a gzip-compressed text file yields at most one split.

        The lines returned by ``write_data(10000)`` are written newline-joined
        to a gzip file. The file is split with a desired bundle size of 1000
        and the splits are compared with the unsplit source.

        Raises:
            ValueError: if more than one initial split is generated.
        """
        _, lines = write_data(10000)
        file_name = self._create_temp_file()
        payload = '\n'.join(lines)
        # GzipFile in 'wb' mode only accepts bytes on Python 3; encode text
        # so the write does not raise TypeError.
        if not isinstance(payload, bytes):
            payload = payload.encode('utf-8')
        with gzip.GzipFile(file_name, 'wb') as f:
            f.write(payload)

        source = TextSource(file_name, 0, CompressionTypes.GZIP, True,
                            coders.StrUtf8Coder())
        splits = list(source.split(desired_bundle_size=1000))

        if len(splits) > 1:
            raise ValueError(
                'FileBasedSource generated more than one initial split '
                'for a compressed file.')

        reference_source_info = (source, None, None)
        sources_info = [(split.source, split.start_position,
                         split.stop_position) for split in splits]
        source_test_utils.assert_sources_equal_reference_source(
            reference_source_info, sources_info)
Exemplo n.º 18
0
    def _run_avro_test(self, pattern, desired_bundle_size, perform_splitting,
                       expected_result):
        """Reads an avro source either directly or through its splits.

        Args:
            pattern: passed to ``_create_avro_source``.
            desired_bundle_size: bundle size used when splitting; must be
                truthy when ``perform_splitting`` is set.
            perform_splitting: when true, the splits are compared with the
                unsplit source; otherwise the records read are compared with
                ``expected_result``.
            expected_result: records expected from a direct read.

        Raises:
            ValueError: if splitting yields fewer than two splits.
        """
        source = _create_avro_source(pattern, use_fastavro=self.use_fastavro)

        if perform_splitting:
            assert desired_bundle_size
            splits = list(
                source.split(desired_bundle_size=desired_bundle_size))
            if len(splits) < 2:
                raise ValueError(
                    'Test is trivial. Please adjust it so that at least '
                    'two splits get generated')

            sources_info = [(split.source, split.start_position,
                             split.stop_position) for split in splits]
            source_test_utils.assert_sources_equal_reference_source(
                (source, None, None), sources_info)
        else:
            read_records = source_test_utils.read_from_source(
                source, None, None)
            # assertItemsEqual only exists on Python 2; assertCountEqual is
            # its Python 3 name.
            self.assertCountEqual(expected_result, read_records)
Exemplo n.º 19
0
  def _run_avro_test(self, pattern, desired_bundle_size, perform_splitting,
                     expected_result):
    """Reads an avro source either directly or through its splits.

    Args:
      pattern: passed to ``AvroSource``.
      desired_bundle_size: bundle size used when splitting; must be truthy
        when ``perform_splitting`` is set.
      perform_splitting: when true, the splits are compared with the unsplit
        source; otherwise the records read are compared with
        ``expected_result``.
      expected_result: records expected from a direct read.

    Raises:
      ValueError: if splitting yields fewer than two splits.
    """
    source = AvroSource(pattern)

    if perform_splitting:
      assert desired_bundle_size
      splits = list(source.split(desired_bundle_size=desired_bundle_size))
      if len(splits) < 2:
        raise ValueError('Test is trivial. Please adjust it so that at least '
                         'two splits get generated')

      sources_info = [
          (split.source, split.start_position, split.stop_position)
          for split in splits
      ]
      source_test_utils.assert_sources_equal_reference_source(
          (source, None, None), sources_info)
    else:
      read_records = source_test_utils.read_from_source(source, None, None)
      # assertItemsEqual only exists on Python 2; assertCountEqual is its
      # Python 3 name.
      self.assertCountEqual(expected_result, read_records)