def test_gcs_storage_cacheable_multi_threaded(self) -> None:
    dataset_id = "range-dataset"
    dataset_version = "0"
    num_threads = 20
    configurations = create_gcs_configuration(access_server_port=15032)

    access_server_handler = test_util.AccessServerHandler(hostname="localhost", port=15032)
    access_server_handler.run_server_in_thread()

    gcs_cache_filepath = get_gcs_filepath(
        configurations=configurations,
        dataset_id=dataset_id,
        dataset_version=dataset_version,
    )

    client = google_storage.Client()
    bucket = client.bucket(configurations.bucket)
    blob = bucket.blob(str(gcs_cache_filepath))
    if blob.exists():
        blob.delete()

    try:
        with thread.ThreadJoiner(10):
            for _ in range(num_threads):
                self.run_in_thread(
                    lambda: worker(configurations, dataset_id, dataset_version)
                )
    finally:
        access_server_handler.stop_server()
def test_s3_storage_cacheable_multi_threaded(self) -> None:
    dataset_id = "range-dataset"
    dataset_version = "0"
    num_threads = 20
    configurations = create_s3_configuration(access_server_port=15032)

    access_server_handler = test_util.AccessServerHandler(hostname="localhost", port=15032)
    access_server_handler.run_server_in_thread()

    s3_cache_filepath = get_s3_filepath(
        configurations=configurations,
        dataset_id=dataset_id,
        dataset_version=dataset_version,
    )

    # Remove any cached dataset left over from a previous run.
    client = boto3.client("s3")
    client.delete_object(Bucket=configurations.bucket, Key=str(s3_cache_filepath))

    try:
        with thread.ThreadJoiner(10):
            for _ in range(num_threads):
                self.run_in_thread(
                    lambda: worker(configurations, dataset_id, dataset_version)
                )
    finally:
        access_server_handler.stop_server()
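# The two multi-threaded tests above call a module-level `worker` helper that is
# not shown in this excerpt. The sketch below is an assumption about its shape,
# modeled on the single-threaded test further down: each thread builds (or reads
# back) the same cached dataset and checks the streamed values. Whether a single
# worker dispatches on the configuration type, or separate GCS/S3 workers exist,
# is not determined by this excerpt.
def worker(configurations, dataset_id: str, dataset_version: str) -> None:
    # Assumes GCS-style configurations; the S3 test would presumably construct
    # storage.S3Storage instead.
    gcs_storage = storage.GCSStorage(configurations=configurations)

    @gcs_storage.cacheable(dataset_id, dataset_version)
    def make_dataref(range_size: int) -> dataref.LMDBDataRef:
        return tf.data.Dataset.range(range_size)  # type: ignore

    # Only one thread should build the cache; the rest should read it back.
    stream = make_dataref(range_size=120).stream()
    for idx, value in enumerate(stream.iterator_fn()):
        assert idx == value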
def test_rw_coordinator(self) -> None:
    ip_address = "localhost"
    port = 10245
    bucket = "my_bucket"
    cache_path = pathlib.Path("/tmp.mdb")
    num_threads = 5
    shared_data = [0]

    access_server_handler = test_util.AccessServerHandler(hostname=ip_address, port=port)
    access_server_handler.run_server_in_thread()

    access_client = rw_coordinator.RwCoordinatorClient(url=f"ws://{ip_address}:{port}")

    try:
        with thread.ThreadJoiner(45):
            for i in range(num_threads):
                # Bind the current value of `i` via a default argument so each
                # thread keeps its own sleep time instead of all threads seeing
                # the final loop value.
                self.run_in_thread(
                    lambda i=i: read_and_sleep(
                        access_client=access_client,
                        sleep_time=i + 1,
                        bucket=bucket,
                        cache_path=cache_path,
                    )
                )
                self.run_in_thread(
                    lambda i=i: write_and_sleep(
                        shared_data=shared_data,
                        access_client=access_client,
                        sleep_time=i,
                        bucket=bucket,
                        cache_path=cache_path,
                    )
                )
    finally:
        access_server_handler.stop_server()

    # Each writer increments the shared counter exactly once.
    assert shared_data[0] == num_threads
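import time


# `read_and_sleep` and `write_and_sleep` are helpers defined elsewhere in this
# test module. The sketch below is an assumption about their intended behavior:
# it presumes RwCoordinatorClient exposes `read_lock` / `write_lock` context
# managers keyed by bucket and cache path, which is not confirmed by this
# excerpt.
def read_and_sleep(access_client, sleep_time: int, bucket: str, cache_path: pathlib.Path) -> None:
    # Hold a read lock while pretending to use the cache for a while.
    with access_client.read_lock(bucket=bucket, cache_path=cache_path):
        time.sleep(sleep_time)


def write_and_sleep(
    shared_data, access_client, sleep_time: int, bucket: str, cache_path: pathlib.Path
) -> None:
    # Hold a write lock while mutating shared state; the final assertion in
    # test_rw_coordinator checks that every writer's increment was observed.
    with access_client.write_lock(bucket=bucket, cache_path=cache_path):
        shared_data[0] += 1
        time.sleep(sleep_time)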
def test_gcs_storage_cacheable_single_threaded() -> None:
    original_range_size = 120
    updated_range_size = 55
    dataset_id = "range-dataset"
    dataset_version = "0"
    configurations = create_gcs_configuration(access_server_port=15032)

    access_server_handler = test_util.AccessServerHandler(hostname="localhost", port=15032)
    access_server_handler.run_server_in_thread()

    gcs_cache_filepath = get_gcs_filepath(
        configurations=configurations,
        dataset_id=dataset_id,
        dataset_version=dataset_version,
    )

    # Remove any cached dataset left over from a previous run.
    client = google_storage.Client()
    bucket = client.bucket(configurations.bucket)
    blob = bucket.blob(str(gcs_cache_filepath))
    if blob.exists():
        blob.delete()

    gcs_storage = storage.GCSStorage(configurations=configurations)

    @gcs_storage.cacheable(dataset_id, dataset_version)
    def make_dataref(range_size: int) -> dataref.LMDBDataRef:
        return tf.data.Dataset.range(range_size)  # type: ignore

    original_data_stream = make_dataref(range_size=original_range_size).stream()
    assert original_data_stream.length == original_range_size

    data_generator = original_data_stream.iterator_fn()
    generator_length = 0
    for idx, data in enumerate(data_generator):
        assert idx == data
        generator_length += 1
    assert generator_length == original_range_size

    # The second call hits the cache, so the original dataset is returned even
    # though a different range_size was requested.
    updated_data_stream = make_dataref(range_size=updated_range_size).stream()
    assert updated_data_stream.length == original_range_size

    access_server_handler.stop_server()
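# `create_gcs_configuration` and `get_gcs_filepath` are fixtures defined
# elsewhere in this test module. For orientation only, the filepath helper is
# presumably of the following shape; the attribute name `bucket_directory_path`
# and the `cache.mdb` filename are assumptions, not confirmed by this excerpt.
def get_gcs_filepath(configurations, dataset_id: str, dataset_version: str) -> pathlib.Path:
    # Mirrors where the storage backend is expected to place the LMDB cache
    # inside the bucket for a given dataset id and version.
    return (
        pathlib.Path(configurations.bucket_directory_path)
        / dataset_id
        / dataset_version
        / "cache.mdb"
    )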