Example #1
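# Enable the prod-only tasks via dbnd_config, then run ProductionIdsAndData
# against a clone of the current environment with production=True.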
def test_prod_immutable_output_example(self):
    with dbnd_config({
            FetchIds.task_enabled_in_prod: True,
            FetchData.task_enabled_in_prod: True
    }):
        task = ProductionIdsAndData(
            task_env=get_databand_context().env.clone(production=True))
        assert_run_task(task)
Example #2
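# The async tracking store keeps accepting events after a flush:
# heartbeat and flush twice in a row against the same store.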
def test_tracking_after_flush(self, fake_api_request):
    ctx = get_databand_context()
    async_store = TrackingStoreThroughChannel.build_with_async_web_channel(
        ctx)
    async_store.heartbeat(get_uuid())
    async_store.flush()
    async_store.heartbeat(get_uuid())
    async_store.flush()
Example #3
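# flush() is safe to call even while the background worker thread
# has not been started, and the store stays ready afterwards.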
def test_flush_without_worker(self, fake_api_request):
    ctx = get_databand_context()
    async_store = TrackingStoreThroughChannel.build_with_async_web_channel(
        ctx)
    assert not async_store.channel._background_worker.is_alive
    assert async_store.is_ready()
    async_store.flush()
    assert async_store.is_ready()
    async_store.flush()
Example #4
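# The background worker thread is started lazily: it is not alive after the
# store is built, and only starts once the first event (a heartbeat) is sent.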
def test_thread_not_started_immideately(self, fake_api_request):
    ctx = get_databand_context()
    async_store = TrackingStoreThroughChannel.build_with_async_web_channel(
        ctx)
    assert async_store.is_ready()
    assert not async_store.channel._background_worker.is_alive

    async_store.heartbeat(get_uuid())
    assert async_store.channel._background_worker.is_alive
Example #5
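# Launch greetings_pipeline as a nested dbnd run from inside an already-running task.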
def greetings_pipeline_subrun(num_of_greetings):
    # this is a task, so num_of_greetings is already evaluated to a concrete value
    assert isinstance(num_of_greetings, int)

    dc = get_databand_context()  # type: DatabandContext
    # if you are running in "submission" mode, you should cancel that for the new run.
    dc.settings.run.submit_driver = False

    greetings_pipeline.task(num_of_greetings=num_of_greetings).dbnd_run()
    return "OK"
Example #6
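# Build a TrackingServiceConfig from the core settings of the current DatabandContext.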
def get_tracking_service_config_from_dbnd() -> TrackingServiceConfig:
    from dbnd import get_databand_context

    conf = get_databand_context().settings.core
    config = TrackingServiceConfig(
        url=conf.databand_url,
        access_token=conf.databand_access_token,
        user=conf.dbnd_user,
        password=conf.dbnd_password,
    )
    return config
Example #7
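# Extend dbnd's base logging config with an extra FileHandler and attach it to the root logger.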
def custom_logging_mod():
    base_config = (
        get_databand_context().settings.log.get_dbnd_logging_config_base()
    )

    base_config["handlers"]["my_file"] = {
        "class": "logging.FileHandler",
        "formatter": "formatter",
        "filename": "/tmp/my_custom_file",
        "encoding": "utf-8",
    }
    base_config["root"]["handlers"].append("my_file")

    return base_config
Example #8
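# Resolve a build output directory under dbnd's local root: a stable per-context
# folder when use_cached is set, otherwise a folder with a random suffix.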
def _get_output_dir(use_cached):
    dc = get_databand_context()
    working_dir = dc.env.dbnd_local_root__build
    if use_cached:
        working_dir = working_dir.folder("bdist_zip_build_%s" %
                                         dc.current_context_uid)
    else:
        import random

        random_number = random.randrange(0, 1000)
        working_dir = working_dir.folder(
            "bdist_zip_build_%s_%d" % (dc.current_context_uid, random_number))
    return working_dir
Example #9
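# MLflow tracking-store entry point: track into dbnd's tracking store and, when the
# composite URI names a second backend, duplicate tracking into that MLflow store.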
def get_dbnd_store(store_uri=None, artifact_uri=None):
    dbnd_store_url, duplicate_tracking_to = parse_composite_uri(store_uri)

    logger.info("MLFlow DBND Tracking Store url: {}".format(dbnd_store_url))
    logger.info(
        "MLFlow DBND Tracking Store duplication to: {}".format(duplicate_tracking_to)
    )

    duplication_store = None
    if duplicate_tracking_to is not None:
        # avoid cyclic imports during `_tracking_store_registry.register_entrypoints()`
        from mlflow.tracking import _get_store

        duplication_store = _get_store(duplicate_tracking_to, artifact_uri)

    dbnd_store = get_databand_context().tracking_store

    return DatabandStore(dbnd_store, duplication_store)
Example #10
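# Wrap the current context's databand_api_client in a DatabandClient.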
def build_databand_client(cls):
    api_client = get_databand_context().databand_api_client
    return DatabandClient(api_client, verbose=True)