def test_source_split(self):
        """Check ``SdkWorker.initial_source_split`` against a direct split.

        A ``RangeSource(0, 100)`` wrapped in a ``SourceBundle`` is registered
        with the worker under the id ``"src"``. The worker is then asked for
        an initial split with a desired bundle size of 30, and the response is
        compared (bundles and weights) with ``source.split(30)``.
        """
        range_source = RangeSource(0, 100)
        expected = list(range_source.split(30))

        worker = sdk_harness.SdkWorker(
            None, data_plane.GrpcClientDataChannelFactory())

        # Build the registration payload from the inside out.
        packed_source = sdk_harness.serialize_and_pack_py_fn(
            SourceBundle(1.0, range_source, None, None),
            sdk_harness.PYTHON_SOURCE_URN,
            id="src")
        transform = beam_fn_api_pb2.PrimitiveTransform(
            function_spec=packed_source)
        descriptor = beam_fn_api_pb2.ProcessBundleDescriptor(
            primitive_transform=[transform])
        worker.register(
            beam_fn_api_pb2.RegisterRequest(
                process_bundle_descriptor=[descriptor]))

        split_request = beam_fn_api_pb2.InitialSourceSplitRequest(
            desired_bundle_size_bytes=30, source_reference="src")
        split_response = worker.initial_source_split(split_request)

        # The returned splits must deserialize to the expected bundles.
        actual_bundles = [
            sdk_harness.unpack_and_deserialize_py_fn(split.source)
            for split in split_response.splits
        ]
        self.assertEqual(expected, actual_bundles)

        # Each reported relative size must match the bundle weight.
        expected_weights = [bundle.weight for bundle in expected]
        actual_sizes = [split.relative_size for split in split_response.splits]
        self.assertEqual(expected_weights, actual_sizes)
  def test_source_split(self):
    """Initial source split via ``SdkWorker`` matches ``RangeSource.split``.

    Registers a single primitive transform holding a ``SourceBundle`` around
    ``RangeSource(0, 100)`` (id ``"src"``), requests an initial split with
    ``desired_bundle_size_bytes=30`` and verifies both the deserialized
    bundles and their relative sizes against ``source.split(30)``.
    """
    src = RangeSource(0, 100)
    want_splits = list(src.split(30))

    sdk_worker = sdk_harness.SdkWorker(
        None, data_plane.GrpcClientDataChannelFactory())

    function_spec = sdk_harness.serialize_and_pack_py_fn(
        SourceBundle(1.0, src, None, None),
        sdk_harness.PYTHON_SOURCE_URN,
        id="src")
    register_request = beam_fn_api_pb2.RegisterRequest(
        process_bundle_descriptor=[
            beam_fn_api_pb2.ProcessBundleDescriptor(
                primitive_transform=[
                    beam_fn_api_pb2.PrimitiveTransform(
                        function_spec=function_spec)])])
    sdk_worker.register(register_request)

    response = sdk_worker.initial_source_split(
        beam_fn_api_pb2.InitialSourceSplitRequest(
            desired_bundle_size_bytes=30,
            source_reference="src"))

    # Compare the bundles themselves ...
    got_splits = [
        sdk_harness.unpack_and_deserialize_py_fn(item.source)
        for item in response.splits]
    self.assertEqual(want_splits, got_splits)

    # ... and then their weights against the reported relative sizes.
    want_weights = [item.weight for item in want_splits]
    got_sizes = [item.relative_size for item in response.splits]
    self.assertEqual(want_weights, got_sizes)
Exemple #3
0
 def setUp(self):
     """Create the fixtures shared by the tests of this case.

     Stores the range bounds (0 and 4), a ``RangeSource`` over those bounds
     and an ``_SDFBoundedSourceRestrictionProvider`` built with
     ``desired_chunk_size=2`` on the test instance.
     """
     self.initial_range_start, self.initial_range_stop = 0, 4
     self.initial_range_source = RangeSource(
         self.initial_range_start, self.initial_range_stop)
     provider = iobase._SDFBoundedSourceRestrictionProvider(
         desired_chunk_size=2)
     self.sdf_restriction_provider = provider
    def test_source_split_via_instruction(self):
        """Drive an initial source split through the control-plane gRPC API.

        A local ``BeamFnControlServicer`` is primed with two instructions: a
        register request carrying a ``SourceBundle`` around
        ``RangeSource(0, 100)`` (id ``"src"``) and an initial-source-split
        request with ``desired_bundle_size_bytes=30``. An ``SdkHarness`` is
        run against that server, and the response recorded for
        ``"split_request"`` is compared (bundles and weights) with
        ``source.split(30)``.

        The gRPC server is always stopped, even if the harness raises, so its
        threads and listening port do not leak into other tests.
        """
        source = RangeSource(0, 100)
        expected_splits = list(source.split(30))

        register_request = beam_fn_api_pb2.InstructionRequest(
            instruction_id="register_request",
            register=beam_fn_api_pb2.RegisterRequest(
                process_bundle_descriptor=[
                    beam_fn_api_pb2.ProcessBundleDescriptor(
                        primitive_transform=[
                            beam_fn_api_pb2.PrimitiveTransform(
                                function_spec=sdk_harness.
                                serialize_and_pack_py_fn(
                                    SourceBundle(1.0, source, None, None),
                                    sdk_harness.PYTHON_SOURCE_URN,
                                    id="src"))
                        ])
                ]))
        split_request = beam_fn_api_pb2.InstructionRequest(
            instruction_id="split_request",
            initial_source_split=beam_fn_api_pb2.InitialSourceSplitRequest(
                desired_bundle_size_bytes=30, source_reference="src"))
        test_controller = BeamFnControlServicer(
            [register_request, split_request])

        server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
        beam_fn_api_pb2.add_BeamFnControlServicer_to_server(
            test_controller, server)
        # Port 0 asks gRPC to choose a free port; the chosen port is returned.
        test_port = server.add_insecure_port("[::]:0")
        server.start()
        try:
            channel = grpc.insecure_channel("localhost:%s" % test_port)
            harness = sdk_harness.SdkHarness(channel)
            harness.run()
        finally:
            # Tear the server down immediately so a failing harness does not
            # leave the port bound and the executor threads running.
            server.stop(None)

        split_response = test_controller.responses[
            "split_request"].initial_source_split

        self.assertEqual(expected_splits, [
            sdk_harness.unpack_and_deserialize_py_fn(s.source)
            for s in split_response.splits
        ])

        self.assertEqual([s.weight for s in expected_splits],
                         [s.relative_size for s in split_response.splits])
  def test_source_split_via_instruction(self):
    """Run an initial source split end to end over the control gRPC channel.

    The test controller is given a register instruction (a ``SourceBundle``
    around ``RangeSource(0, 100)``, id ``"src"``) followed by an initial
    source split instruction with ``desired_bundle_size_bytes=30``. After an
    ``SdkHarness`` has run against the local server, the response stored
    under ``"split_request"`` must match ``source.split(30)`` in both bundles
    and weights.

    The server is stopped in a ``finally`` clause so that a failing harness
    does not leave the port bound or the executor threads alive.
    """
    source = RangeSource(0, 100)
    expected_splits = list(source.split(30))

    test_controller = BeamFnControlServicer([
        beam_fn_api_pb2.InstructionRequest(
            instruction_id="register_request",
            register=beam_fn_api_pb2.RegisterRequest(
                process_bundle_descriptor=[
                    beam_fn_api_pb2.ProcessBundleDescriptor(
                        primitive_transform=[beam_fn_api_pb2.PrimitiveTransform(
                            function_spec=sdk_harness.serialize_and_pack_py_fn(
                                SourceBundle(1.0, source, None, None),
                                sdk_harness.PYTHON_SOURCE_URN,
                                id="src"))])])),
        beam_fn_api_pb2.InstructionRequest(
            instruction_id="split_request",
            initial_source_split=beam_fn_api_pb2.InitialSourceSplitRequest(
                desired_bundle_size_bytes=30,
                source_reference="src"))
        ])

    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
    beam_fn_api_pb2.add_BeamFnControlServicer_to_server(test_controller, server)
    # Port 0 asks gRPC to choose a free port; the chosen port is returned.
    test_port = server.add_insecure_port("[::]:0")
    server.start()
    try:
      channel = grpc.insecure_channel("localhost:%s" % test_port)
      harness = sdk_harness.SdkHarness(channel)
      harness.run()
    finally:
      # Always release the server, whether or not the harness succeeded.
      server.stop(None)

    split_response = test_controller.responses[
        "split_request"].initial_source_split

    self.assertEqual(
        expected_splits,
        [sdk_harness.unpack_and_deserialize_py_fn(s.source)
         for s in split_response.splits])

    self.assertEqual(
        [s.weight for s in expected_splits],
        [s.relative_size for s in split_response.splits])
Exemple #6
0
 def test_try_split_with_any_exception(self):
     """``try_split(0.5)`` returns ``None`` for an infinite-offset bundle.

     The tracker is built over a ``SourceBundle`` whose weight and stop
     position are ``OffsetRangeTracker.OFFSET_INFINITY``; position 0 is
     claimed before the split is attempted.

     NOTE(review): the test name suggests the split fails internally with an
     exception that the tracker swallows — confirm against ``iobase``.
     """
     infinite_bundle = SourceBundle(
         range_trackers.OffsetRangeTracker.OFFSET_INFINITY,
         RangeSource(0, range_trackers.OffsetRangeTracker.OFFSET_INFINITY),
         0,
         range_trackers.OffsetRangeTracker.OFFSET_INFINITY)
     restriction = iobase._SDFBoundedSourceRestriction(infinite_bundle)
     self.sdf_restriction_tracker = (
         iobase._SDFBoundedSourceRestrictionTracker(restriction))

     self.sdf_restriction_tracker.try_claim(0)
     split_result = self.sdf_restriction_tracker.try_split(0.5)
     self.assertIsNone(split_result)
Exemple #7
0
 def setUp(self):
     """Build the restriction tracker used by the tests of this case.

     Records the start (0) and stop (4) positions on the instance, then
     wraps a ``RangeSource`` over that range in a ``SourceBundle`` whose
     weight is ``stop - start``, and stores an
     ``_SDFBoundedSourceRestrictionTracker`` for that bundle.
     """
     self.initial_start_pos, self.initial_stop_pos = 0, 4

     bundle_weight = self.initial_stop_pos - self.initial_start_pos
     range_source = RangeSource(
         self.initial_start_pos, self.initial_stop_pos)
     bundle = SourceBundle(
         bundle_weight,
         range_source,
         self.initial_start_pos,
         self.initial_stop_pos)

     restriction = iobase._SDFBoundedSourceRestriction(bundle)
     self.sdf_restriction_tracker = (
         iobase._SDFBoundedSourceRestrictionTracker(restriction))
Exemple #8
0
 def test_create_tracker(self):
     """The provider's ``create_tracker`` yields a tracker for the bundle.

     A ``SourceBundle`` covering positions 1 to 3 is wrapped in an
     ``_SDFBoundedSourceRestriction`` and handed to
     ``self.sdf_restriction_provider.create_tracker``. The result must be an
     ``_SDFBoundedSourceRestrictionTracker`` reporting the same start and
     stop positions.
     """
     expected_start = 1
     expected_stop = 3
     # Derive the source range from the expected bounds so that the bundle
     # and the assertions below cannot drift apart.
     source_bundle = SourceBundle(expected_stop - expected_start,
                                  RangeSource(expected_start, expected_stop),
                                  expected_start, expected_stop)
     restriction_tracker = self.sdf_restriction_provider.create_tracker(
         iobase._SDFBoundedSourceRestriction(source_bundle))
     # assertIsInstance reports the actual type on failure, unlike
     # assertTrue(isinstance(...)).
     self.assertIsInstance(restriction_tracker,
                           iobase._SDFBoundedSourceRestrictionTracker)
     self.assertEqual(expected_start, restriction_tracker.start_pos())
     self.assertEqual(expected_stop, restriction_tracker.stop_pos())
Exemple #9
0
 def test_initialization(self):
     """Building a ``ThreadsafeRestrictionTracker`` from a ``RangeSource``
     must raise ``ValueError``."""
     with self.assertRaises(ValueError):
         not_a_tracker = RangeSource(0, 1)
         iobase.ThreadsafeRestrictionTracker(not_a_tracker)
Exemple #10
0
 def test_sdf_wrap_range_source(self):
     """Run the SDF wrapper pipeline helper on ``RangeSource(0, 4)`` and
     expect the elements 0 through 3."""
     range_source = RangeSource(0, 4)
     expected_elements = [0, 1, 2, 3]
     self._run_sdf_wrapper_pipeline(range_source, expected_elements)
Exemple #11
0
    def test_sdf_wrapper_overrides_read(self, sdf_wrapper_mock_expand):
        """The pipeline output comes from the mocked wrapper ``expand``.

        Args:
            sdf_wrapper_mock_expand: mock whose ``side_effect`` is replaced
                here (presumably injected by a patch decorator that is not
                visible in this excerpt — confirm).
        """
        # Replace the wrapper's expansion with a fixed single-element Create.
        sdf_wrapper_mock_expand.side_effect = (
            lambda pbegin: (pbegin | beam.Create(['fake'])))

        expected_output = ['fake']
        self._run_sdf_wrapper_pipeline(RangeSource(0, 4), expected_output)