예제 #1
0
    def test_source_split(self):
        """Initial source split through ``SdkWorker`` agrees with a direct split.

        Registers ``RangeSource(0, 100)`` with the worker under the reference
        ``"src"``, requests an initial split with a 30-byte desired bundle
        size, and checks that both the returned sources and their relative
        sizes match what ``RangeSource.split(30)`` produces directly.
        """
        range_source = RangeSource(0, 100)
        direct_splits = list(range_source.split(30))

        channel_factory = data_plane.GrpcClientDataChannelFactory()
        worker = sdk_harness.SdkWorker(None, channel_factory)

        # Pack the source bundle as the function spec of a single transform.
        packed_source = sdk_harness.serialize_and_pack_py_fn(
            SourceBundle(1.0, range_source, None, None),
            sdk_harness.PYTHON_SOURCE_URN,
            id="src")
        descriptor = beam_fn_api_pb2.ProcessBundleDescriptor(
            primitive_transform=[
                beam_fn_api_pb2.PrimitiveTransform(function_spec=packed_source)
            ])
        worker.register(
            beam_fn_api_pb2.RegisterRequest(
                process_bundle_descriptor=[descriptor]))

        split_request = beam_fn_api_pb2.InitialSourceSplitRequest(
            desired_bundle_size_bytes=30, source_reference="src")
        split_response = worker.initial_source_split(split_request)

        returned_sources = [
            sdk_harness.unpack_and_deserialize_py_fn(returned.source)
            for returned in split_response.splits
        ]
        self.assertEqual(direct_splits, returned_sources)

        direct_weights = [bundle.weight for bundle in direct_splits]
        returned_sizes = [
            returned.relative_size for returned in split_response.splits
        ]
        self.assertEqual(direct_weights, returned_sizes)
예제 #2
0
 def test_try_split_with_any_exception(self):
     """``try_split`` returns ``None`` for an unbounded restriction.

     The restriction wraps ``RangeSource(0, OFFSET_INFINITY)`` with weight and
     stop position ``OFFSET_INFINITY``. After claiming position 0, a split at
     fraction 0.5 must yield ``None``. Judging by the test name, this presumably
     exercises the path where splitting raises internally; confirm against
     ``_SDFBoundedSourceRestrictionTracker.try_split``.
     """
     infinity = range_trackers.OffsetRangeTracker.OFFSET_INFINITY
     unbounded_bundle = SourceBundle(
         infinity, RangeSource(0, infinity), 0, infinity)
     restriction = iobase._SDFBoundedSourceRestriction(unbounded_bundle)
     self.sdf_restriction_tracker = (
         iobase._SDFBoundedSourceRestrictionTracker(restriction))
     self.sdf_restriction_tracker.try_claim(0)
     split_result = self.sdf_restriction_tracker.try_split(0.5)
     self.assertIsNone(split_result)
예제 #3
0
 def setUp(self):
     """Create a restriction tracker over ``RangeSource(0, 4)``.

     Stores the bounds in ``initial_start_pos`` / ``initial_stop_pos`` and the
     tracker in ``sdf_restriction_tracker`` for use by the tests.
     """
     self.initial_start_pos, self.initial_stop_pos = 0, 4
     start, stop = self.initial_start_pos, self.initial_stop_pos
     # The bundle weight is the width of the range.
     bundle = SourceBundle(stop - start, RangeSource(start, stop), start, stop)
     restriction = iobase._SDFBoundedSourceRestriction(bundle)
     self.sdf_restriction_tracker = (
         iobase._SDFBoundedSourceRestrictionTracker(restriction))
예제 #4
0
    def split(self, desired_bundle_size, start_position=0, stop_position=None):
        """
        Implements method `apache_beam.io.iobase.BoundedSource.split`

        `BillboardSource` is unsplittable, so only a single source is returned.

        :param desired_bundle_size: ignored, since the source cannot be split.
        :param start_position: start position of the single bundle; ``None``
            (the ``BoundedSource.split`` default) is treated as ``0``.
        :param stop_position: ignored; the bundle always ends at
            ``OffsetRangeTracker.OFFSET_INFINITY``.
        :return: a generator yielding exactly one ``SourceBundle`` of weight 1.
        """
        # Callers following the BoundedSource.split contract may pass None
        # explicitly; normalise it so the bundle never carries a None start.
        if start_position is None:
            start_position = 0
        # Any requested stop position is deliberately discarded: an
        # unsplittable source yields a single open-ended bundle (presumably so
        # that all of its data is read).
        stop_position = range_trackers.OffsetRangeTracker.OFFSET_INFINITY

        yield SourceBundle(weight=1,
                           source=self,
                           start_position=start_position,
                           stop_position=stop_position)
예제 #5
0
 def test_create_tracker(self):
     """``create_tracker`` returns a tracker reporting the restriction bounds.

     Builds a restriction over ``RangeSource(1, 3)`` and checks that the
     provider wraps it in an ``_SDFBoundedSourceRestrictionTracker`` whose
     ``start_pos()`` and ``stop_pos()`` match the restriction's bounds.
     """
     expected_start = 1
     expected_stop = 3
     # Derive the source range from the expected bounds so they cannot drift.
     source_bundle = SourceBundle(expected_stop - expected_start,
                                  RangeSource(expected_start, expected_stop),
                                  expected_start, expected_stop)
     restriction_tracker = self.sdf_restriction_provider.create_tracker(
         iobase._SDFBoundedSourceRestriction(source_bundle))
     self.assertIsInstance(restriction_tracker,
                           iobase._SDFBoundedSourceRestrictionTracker)
     self.assertEqual(expected_start, restriction_tracker.start_pos())
     self.assertEqual(expected_stop, restriction_tracker.stop_pos())
예제 #6
0
    def split(self,
              desired_bundle_size=None,
              start_position=None,
              stop_position=None):
        """Split the whole table into bundles delimited by its sample row keys.

        At this point only splitting an entire table on its sample row keys is
        supported; passing any argument raises ``NotImplementedError``.

        :param desired_bundle_size: the desired size (in bytes) of the returned
            bundles. Not supported yet; must be ``None``.
        :param start_position: if specified, the position that must be used as
            the starting position of the first bundle. Not supported yet; must
            be ``None``.
        :param stop_position: if specified, the position that must be used as
            the ending position of the last bundle. Not supported yet; must be
            ``None``.
        :returns: a list of ``SourceBundle`` objects, in random order,
            describing the generated bundles.
        :raises NotImplementedError: if any argument is not ``None``.
        """
        if (desired_bundle_size is not None or start_position is not None
                or stop_position is not None):
            raise NotImplementedError

        # TODO: Use the desired bundle size to split accordingly
        # TODO: Allow users to provide their own row sets

        sample_row_keys = list(self.get_sample_row_keys())
        bundles = []
        # Add a leading bundle from the empty key (used here as the table
        # start) up to the first sample key, unless the first sample key is
        # itself empty. The sample's row_key must be compared, not the sample
        # object, which never equals b''.
        if sample_row_keys and sample_row_keys[0].row_key != b'':
            first = sample_row_keys[0]
            bundles.append(
                SourceBundle(first.offset_bytes, self, b'', first.row_key))
        # One bundle between each pair of consecutive sample keys, weighted by
        # the difference of their byte offsets.
        for previous, current in zip(sample_row_keys, sample_row_keys[1:]):
            bundles.append(
                SourceBundle(current.offset_bytes - previous.offset_bytes,
                             self, previous.row_key, current.row_key))

        # Shuffle is needed to allow reading from different locations of the table for better efficiency
        shuffle(bundles)
        return bundles
예제 #7
0
    def split(self, desired_bundle_size, start_offset=None, stop_offset=None):
        """Yield ``SourceBundle`` objects covering this file's offset range.

        Args:
          desired_bundle_size: target size passed to ``OffsetRange.split`` when
            the source is splittable.
          start_offset: first offset to cover; defaults to ``self._start_offset``.
          stop_offset: offset to stop at; defaults to ``self._stop_offset``.

        Yields:
          One bundle per sub-range when splittable. Otherwise a single bundle
          whose end offset is ``OffsetRangeTracker.OFFSET_INFINITY``.
        """
        begin = self._start_offset if start_offset is None else start_offset
        end = self._stop_offset if stop_offset is None else stop_offset

        if not self._splittable:
            # Returning a single sub-source with end offset set to OFFSET_INFINITY (so
            # that all data of the source gets read) since this source is
            # unsplittable. Choosing size of the file as end offset will be wrong for
            # certain unsplittable source, e.g., compressed sources.
            whole_file = _SingleFileSource(
                self._file_based_source,
                self._file_name,
                begin,
                OffsetRangeTracker.OFFSET_INFINITY,
                min_bundle_size=self._min_bundle_size,
                splittable=self._splittable)
            yield SourceBundle(end - begin, whole_file, begin,
                               OffsetRangeTracker.OFFSET_INFINITY)
            return

        sub_ranges = OffsetRange(begin, end).split(desired_bundle_size,
                                                   self._min_bundle_size)
        for sub_range in sub_ranges:
            # Round-trip through the pickler so every sub-source owns a fresh
            # copy of the file-based source.
            fresh_copy = pickler.loads(pickler.dumps(self._file_based_source))
            piece = _SingleFileSource(
                fresh_copy,
                self._file_name,
                sub_range.start,
                sub_range.stop,
                min_bundle_size=self._min_bundle_size,
                splittable=self._splittable)
            yield SourceBundle(sub_range.stop - sub_range.start, piece,
                               sub_range.start, sub_range.stop)
예제 #8
0
    def test_source_split_via_instruction(self):
        """Initial source split requested over the control API.

        A ``BeamFnControlServicer`` is given two instructions up front: a
        register instruction for ``RangeSource(0, 100)`` (reference ``"src"``)
        and an initial-split instruction with a 30-byte desired bundle size.
        The split response it records must match splitting the source
        directly.
        """

        source = RangeSource(0, 100)
        expected_splits = list(source.split(30))

        test_controller = BeamFnControlServicer([
            beam_fn_api_pb2.InstructionRequest(
                instruction_id="register_request",
                register=beam_fn_api_pb2.RegisterRequest(
                    process_bundle_descriptor=[
                        beam_fn_api_pb2.ProcessBundleDescriptor(
                            primitive_transform=[
                                beam_fn_api_pb2.PrimitiveTransform(
                                    function_spec=sdk_harness.
                                    serialize_and_pack_py_fn(
                                        SourceBundle(1.0, source, None, None),
                                        sdk_harness.PYTHON_SOURCE_URN,
                                        id="src"))
                            ])
                    ])),
            beam_fn_api_pb2.InstructionRequest(
                instruction_id="split_request",
                initial_source_split=beam_fn_api_pb2.InitialSourceSplitRequest(
                    desired_bundle_size_bytes=30, source_reference="src"))
        ])

        server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
        beam_fn_api_pb2.add_BeamFnControlServicer_to_server(
            test_controller, server)
        test_port = server.add_insecure_port("[::]:0")
        server.start()
        try:
            channel = grpc.insecure_channel("localhost:%s" % test_port)
            harness = sdk_harness.SdkHarness(channel)
            harness.run()
        finally:
            # Stop the server immediately so its port and thread pool are
            # released even if the harness fails; the recorded responses stay
            # available on test_controller.
            server.stop(0)

        split_response = test_controller.responses[
            "split_request"].initial_source_split

        self.assertEqual(expected_splits, [
            sdk_harness.unpack_and_deserialize_py_fn(s.source)
            for s in split_response.splits
        ])

        self.assertEqual([s.weight for s in expected_splits],
                         [s.relative_size for s in split_response.splits])
예제 #9
0
    def split(self,
              desired_bundle_size,
              start_position=None,
              stop_position=None):
        """Export the BigQuery data to files and yield one bundle per file.

        The export runs only on the first call. The resulting ``TextSource``
        objects are cached in ``self.split_result`` and reused by later calls.
        When ``self.query`` is set, the query runs first in a temporary dataset,
        which is cleaned up afterwards, even if the query or export fails.

        ``desired_bundle_size``, ``start_position`` and ``stop_position`` are
        accepted for interface compatibility but are not used.

        Yields:
          ``SourceBundle(0, source, None, None)`` for each exported file's
          ``TextSource``.
        """
        if self.split_result is None:
            bq = bigquery_tools.BigQueryWrapper()

            if self.query is not None:
                self._setup_temporary_dataset(bq)
            try:
                if self.query is not None:
                    self.table_reference = self._execute_query(bq)

                schema, metadata_list = self._export_files(bq)
                self.split_result = [
                    TextSource(metadata.path, 0,
                               CompressionTypes.UNCOMPRESSED, True,
                               self.coder(schema)) for metadata in metadata_list
                ]
            finally:
                # Always remove the temporary dataset once it has been created,
                # so a failing query/export does not leak it.
                if self.query is not None:
                    bq.clean_up_temporary_dataset(self.project.get())

        for source in self.split_result:
            yield SourceBundle(0, source, None, None)