def test_data_set_with_enums():
    json = '{"src":"http://foo.com","type":"FILE","status":"DONE","path":"/mount/path"}'
    dataset = DownloadableContent.from_json(json)
    assert dataset == DownloadableContent(
        src="http://foo.com", path="/mount/path", type=FetchedType.FILE, status=FetcherStatus.DONE
    )
def update(self, data_set: DownloadableContent):
    data_set.status = self.status
    data_set.type = self.type
    if not self.status.success:
        data_set.message = self.message
        data_set.dst = None
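# Illustrative sketch (not from the source): how the update() method above could be
# exercised. It assumes FetcherResult(status, type, message), as constructed elsewhere
# in this code, and that FetcherStatus.FAILED.success is falsy, so dst gets cleared.
def _example_update_on_failure():
    content = DownloadableContent(src="http://foo.com", path="/mount/path", dst="s3://bucket/key")
    FetcherResult(FetcherStatus.FAILED, None, "download timed out").update(content)
    assert content.status == FetcherStatus.FAILED
    assert content.dst is None
    assert content.message == "download timed out"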
def test_data_set_optional_fields_just_src():
    json = '{"src":"http://foo.com","path":"/mount/path"}'
    with pytest.warns(None) as record:
        dataset = DownloadableContent.from_json(json)
    assert dataset == DownloadableContent(src="http://foo.com", path="/mount/path")
    assert not record.list
def on_content_locked(content: DownloadableContent, lock: RWLock):
    def _on_done_and_unlock(content: DownloadableContent):
        on_done(content)
        self._download_dispatcher.cleanup(content, event)
        lock.release()

    try:
        content.size_info = self._size_estimator(content.src)
    except Exception as e:
        msg = f"Failed to estimate the size of content {content.src}: {str(e)}"
        logger.exception(msg)
        FetcherResult(FetcherStatus.FAILED, None, msg).update(content)
        on_done(content)
        lock.release()
        return

    # This node will be killed if I die
    zk_node_path = self._get_node_path(event.client_id, event.action_id, content)
    self._zk.create(zk_node_path, DownloadManager.INITIAL_DATA, ephemeral=True, makepath=True)

    self.__handle_node_state(zk_node_path, _on_done_and_unlock, content)
    self._download_dispatcher.dispatch_fetch(content, event, zk_node_path)
def test_data_set_with_enums_serialize():
    dataset = DownloadableContent(
        src="http://foo.com", path="/mount/path", type=FetchedType.FILE, status=FetcherStatus.DONE
    )
    jsons = dataset.to_json()
    json_dict = json.loads(jsons)
    assert json_dict["type"] == "FILE"
    assert json_dict["status"] == "DONE"
def get_fetcher_benchmark_event(template_event: BenchmarkEvent, dataset_src: str, model_src: str):
    doc = BenchmarkDoc({"var": "val"}, "var = val", "")
    datasets = (
        [] if not dataset_src else [DownloadableContent(src=get_salted_src(dataset_src), path="/mount/path")]
    )
    models = [] if not model_src else [DownloadableContent(src=get_salted_src(model_src), path="/mount/path")]
    fetch_payload = FetcherPayload(toml=doc, datasets=datasets, models=models)
    return dataclasses.replace(template_event, payload=fetch_payload)
def test_collect_status(fetch_statuses, expected_status):
    assert expected_status == FetcherEventHandler._collect_status(
        [
            DownloadableContent(src="some/path", status=fetch_status, path="/mount/path")
            for fetch_status in fetch_statuses
        ]
    )
def test_fetch(
    repeat: int,
    download_manager,
    fetcher_service_config: FetcherServiceConfig,
    benchmark_event_dummy_payload: BenchmarkEvent,
):
    data_sets_with_events = [
        DataSetWithEvent(
            DownloadableContent(
                src=EXISTING_DATASET,
                dst=f"s3://{fetcher_service_config.s3_download_bucket}/it/test.file",
                md5=None,
                path="/mount/path",
            ),
            threading.Event(),
        )
    ]

    def on_done_test(content: DownloadableContent, completed: threading.Event):
        assert content.src
        assert content.type == FetchedType.FILE
        assert content.dst
        assert content.status == FetcherStatus.DONE
        completed.set()

    for data_sets_with_event in data_sets_with_events:
        download_manager.fetch(
            data_sets_with_event.content,
            benchmark_event_dummy_payload,
            # Bind the event eagerly so the callback doesn't capture the loop variable late.
            lambda d, evt=data_sets_with_event.event: on_done_test(d, evt),
        )

    for _, event in data_sets_with_events:
        event.wait(WAIT_TIMEOUT)
def fetcher_event(descriptor_as_adict) -> FetcherBenchmarkEvent:
    return FetcherBenchmarkEvent(
        action_id=ACTION_ID,
        message_id="MESSAGE_ID",
        client_id="CLIENT_ID",
        client_version="CLIENT_VERSION",
        client_username="******",
        authenticated=False,
        tstamp=42,
        visited=[],
        type="PRODUCER_TOPIC",
        payload=FetcherPayload(
            toml=BenchmarkDoc(contents=descriptor_as_adict.to_dict(), doc="", sha1="SHA"),
            scripts=SCRIPTS,
            datasets=[
                DownloadableContent(
                    src="http://someserver.com/somedata.zip",
                    dst=DATASET_S3_URI,
                    path="/mount/path",
                    id=DATASET_ID,
                    size_info=ContentSizeInfo(total_size=42, file_count=1, max_size=42),
                    type=FetchedType.FILE,
                )
            ],
        ),
    )
def fetched_data_sources(base_data_sources):
    sources = []
    for source in base_data_sources:
        sources.append(
            DownloadableContent(src=source["src"], path="/mount/path", md5=source["md5"], dst=source["puller_uri"])
        )
    return sources
def test_kubernetes_client(
    k8s_dispatcher: KubernetesDispatcher,
    benchmark_event_dummy_payload: BenchmarkEvent,
    k8s_test_client: KubernetesTestUtilsClient,
    fetcher_job_config: FetcherJobConfig,
    size_info: ContentSizeInfo,
):
    data_set = DownloadableContent(src=SOMEDATA_BIG, path="/mount/path", dst=S3_DST, md5=None, size_info=size_info)
    k8s_dispatcher.dispatch_fetch(data_set, benchmark_event_dummy_payload, "/data/sets/fake")
    _wait_for_k8s_objects_exist(benchmark_event_dummy_payload, fetcher_job_config, k8s_test_client, size_info)
def generate_fetched_models(descriptor_data) -> List[DownloadableContent]:
    data_sources = descriptor_data.get("server", {}).get("models", [])
    if data_sources:
        return [
            DownloadableContent(
                src=source["src"],
                md5="md5",
                path=source["path"],
                dst=PULLER_S3_URI + str(inx),  # Fake different destinations
                type=FetchedType.DIRECTORY,
            )
            for inx, source in enumerate(data_sources)
        ]
    else:
        return []
def generate_fetched_data_sources(descriptor_data, http_puller_uri: bool = False) -> List[DownloadableContent]:
    data_sources = descriptor_data.get("data", {}).get("sources", [])
    puller_uri = PULLER_HTTP_URI if http_puller_uri else PULLER_S3_URI
    if data_sources:
        return [
            DownloadableContent(
                src=source["src"],
                md5="md5",
                path=source["path"],
                dst=puller_uri + str(inx),  # Fake different destinations
                type=FetchedType.DIRECTORY,
            )
            for inx, source in enumerate(data_sources)
        ]
    else:
        return []
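# Illustrative sketch (not from the source): feeding generate_fetched_data_sources a
# minimal descriptor fragment. PULLER_S3_URI is the constant referenced above; the
# descriptor layout mirrors the .get("data", {}).get("sources", []) lookup the helper does.
def _example_generate_fetched_data_sources():
    descriptor = {"data": {"sources": [{"src": "s3://bucket/imagenet", "path": "~/data/imagenet"}]}}
    [content] = generate_fetched_data_sources(descriptor)
    assert content.dst == PULLER_S3_URI + "0"
    assert content.type == FetchedType.DIRECTORY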
def test_cancel(
    download_manager,
    fetcher_service_config: FetcherServiceConfig,
    benchmark_event_dummy_payload: BenchmarkEvent,
):
    data_set = DownloadableContent(
        src=VERY_LARGE_DATASET,
        dst=f"s3://{fetcher_service_config.s3_download_bucket}/it/test.file",
        md5=None,
        path="/mount/path",
    )
    completed = threading.Event()

    def on_done_test(content: DownloadableContent):
        assert content.src
        assert content.status == FetcherStatus.CANCELED
        assert not content.dst
        completed.set()

    download_manager.fetch(data_set, benchmark_event_dummy_payload, on_done_test)
    download_manager.cancel(benchmark_event_dummy_payload.client_id, benchmark_event_dummy_payload.action_id)
    assert completed.wait(WAIT_TIMEOUT)
doc="IyBCZW5jaG1hcYS90Zi1pbWFnZW5ldC8iCg==", sha1="be60cb85620fa041c1bfabd9a9b1c8c1d6be1c78", contents=EXPECTED_FETCHER_CONTENTS, verified=True, descriptor_filename="example_descriptor2.toml", ) EXPECTED_FETCHER_VISITED = [ VisitedService(svc="baictl-client", tstamp="@@TSTAMP@@", version="0.1.0-481dad2"), VisitedService(svc="bai-bff", tstamp=1556814924121, version="0.0.2"), ] EXPECTED_FETCHER_DATASETS = [ DownloadableContent("s3://bucket/imagenet/train", path="~/data/tf-imagenet/"), DownloadableContent("s3://bucket/imagenet/validation", path="~/data/tf-imagenet/"), ] EXPECTED_FETCHER_SCRIPTS = [ FileSystemObject(dst="s3://script-exchange/foo.tar") ] EXPECTED_FETCHER_MODELS = [ DownloadableContent("s3://bucket/model/inception", path="/models/inception", md5="5d41402abc4b2a76b9719d911017c592"), DownloadableContent("s3://bucket/models/mnist", path="/models/mnist"), ]
import pytest

from bai_kafka_utils.events import (
    BenchmarkEvent,
    StatusMessageBenchmarkEvent,
    StatusMessageBenchmarkEventPayload,
    VisitedService,
    DownloadableContent,
    FetcherPayload,
    FetcherBenchmarkEvent,
    Status,
)

FETCHER_PAYLOAD = FetcherPayload(datasets=[DownloadableContent(src="SRC", path="/mount/path")], toml=None)

FETCHER_EVENT = FetcherBenchmarkEvent(
    action_id="OTHER_ACTION_ID",
    parent_action_id="PARENT_ACTION_ID",
    message_id="OTHER_MESSAGE_ID",
    client_id="OTHER_CLIENT_ID",
    client_version="0.1.0-481dad2",
    client_username="******",
    authenticated=False,
    tstamp=1556814924121,
    visited=[
        VisitedService(svc="some", tstamp=1556814924121, version="1.0", node=None)
def datasets():
    return [DownloadableContent(src="src1"), DownloadableContent(src="src2")]
def test_data_set_dont_fail_unknown_fields():
    json = '{"src":"http://foo.com","foo":"bar","path":"/mount/path"}'
    dataset = DownloadableContent.from_json(json)
    assert not hasattr(dataset, "foo")
def models():
    return [DownloadableContent(src="model1"), DownloadableContent(src="model2")]
def test_data_set_optional_missing_src():
    json = '{"dst":"http://foo.com", "md5":"42"}'
    with pytest.raises(KeyError):
        DownloadableContent.from_json(json)
def some_data_set() -> DownloadableContent:
    return DownloadableContent("http://imagenet.org/bigdata.zip", size_info=SOME_SIZE_INFO, path="/mount/path")
def dataset():
    return DownloadableContent("http://foo.com")
def some_data_set():
    return DownloadableContent(src="http://something.com/dataset.zip", path="/mount/path", dst="s3://something/dataset.zip")
def _set_failed(content: DownloadableContent, message: str):
    content.message = message
    content.status = FetcherStatus.FAILED
    content.dst = None
import pytest

from bai_kafka_utils.events import DownloadableContent
from fetcher_dispatcher.content_pull import get_content_dst

S3_BUCKET = "datasets_bucket"


@pytest.mark.parametrize(
    "data_set, expected",
    [
        # Simplest case
        (
            DownloadableContent(src="http://some-server.org/datasets/plenty/bigfile.zip", path="/mount/path"),
            f"s3://{S3_BUCKET}/data_sets/390c2fe19f6061e4520964a1a968cede/datasets/plenty/bigfile.zip",
        ),
        # Same with query args - we ignore them
        (
            DownloadableContent(src="http://some-server.org/datasets/plenty/bigfile.zip?foo=bar", path="/mount/path"),
            f"s3://{S3_BUCKET}/data_sets/5fddff4d49df672934851f436de903f3/datasets/plenty/bigfile.zip",
        ),
        # md5 matters
        (
            DownloadableContent(src="http://some-server.org/datasets/plenty/bigfile.zip", md5="42", path="/mount/path"),
            f"s3://{S3_BUCKET}/data_sets/390c2fe19f6061e4520964a1a968cede/42/datasets/plenty/bigfile.zip",
        ),
        # Hardly possible, but who knows?
        (
            DownloadableContent(src="http://some-server.org", path="/mount/path"),
            f"s3://{S3_BUCKET}/data_sets/a05fe609e976847b1543a2f3cd25d22c",
        ),
import dataclasses

from pytest import fixture, mark

from bai_k8s_utils.service_labels import ServiceLabels
from bai_kafka_utils.events import DownloadableContent, BenchmarkEvent, ContentSizeInfo
from fetcher_dispatcher import kubernetes_dispatcher, SERVICE_NAME
from fetcher_dispatcher.args import FetcherJobConfig, FetcherVolumeConfig
from fetcher_dispatcher.kubernetes_dispatcher import KubernetesDispatcher

MB = 1024 * 1024

CLIENT_ID = "CLIENT_ID"
ACTION_ID = "ACTION_ID"

DATA_SET = DownloadableContent(src="http://some.com/src", dst="s3://bucket/dst/", path="/mount/path")
DATA_SET_WITH_MD5 = dataclasses.replace(DATA_SET, md5="42")

BENCHMARK_EVENT = BenchmarkEvent(
    action_id=ACTION_ID,
    message_id="DONTCARE",
    client_id=CLIENT_ID,
    client_version="DONTCARE",
    client_username="******",
    authenticated=False,
    tstamp=42,
    visited=[],
    type="BAI_APP_BFF",
    payload="DONTCARE",
)