Example #1
0
def dbnd_on_pre_init_context(ctx):
    """Redirect MLflow tracking to a composite URI that includes Databand.

    Does nothing unless ``mlflow_tracking.databand_tracking`` is enabled and
    ``core.databand_url`` is set.  The composite URI is built from the
    Databand URL plus a secondary URI: ``mlflow_tracking.duplicate_tracking_to``
    when configured, otherwise whatever MLflow currently reports as its
    tracking URI.  The tracking URI that was active before the switch is
    stored in the module-level ``_original_mlflow_tracking_uri``.

    ``ctx`` is accepted but not used here.

    Raises:
        DatabandConfigError: no ``duplicate_tracking_to`` is configured and
            the current MLflow tracking URI is already a composite URI.
    """
    global _original_mlflow_tracking_uri

    from mlflow import get_tracking_uri, set_tracking_uri

    tracking_enabled = config.getboolean("mlflow_tracking", "databand_tracking")
    if not tracking_enabled:
        return

    databand_url = config.get("core", "databand_url")
    if not databand_url:
        logger.info(
            "Although 'databand_tracking' was set in 'mlflow_tracking', "
            "dbnd will not use it since 'core.databand_url' was not set."
        )
        return

    secondary_uri = config.get("mlflow_tracking", "duplicate_tracking_to")
    if not secondary_uri:
        # Fall back to the URI MLflow is already using.
        secondary_uri = get_tracking_uri()

        # A composite URI here means the dbnd store was already configured
        # on the MLflow side, so the two configs conflict.
        if is_composite_uri(secondary_uri):
            raise DatabandConfigError(
                "Config conflict: MLFlow and DBND configs both define dbnd store uri"
            )

    combined_uri = build_composite_uri(databand_url, secondary_uri)

    # Remember the pre-switch URI before overwriting it.
    _original_mlflow_tracking_uri = get_tracking_uri()
    set_tracking_uri(combined_uri)
Example #2
0
 def set_azure_client(self):
     """Build an Azure blob client from the ``azure_tests`` config and yield it.

     Loads the system configs, copies ``credentials_file``,
     ``storage_account`` and ``container_name`` from the ``azure_tests``
     section onto ``self``, and creates ``self.client`` from
     ``self._get_credentials_dict()``.  While the module-level
     ``ATTEMPTED_CONTAINER_CREATE`` flag is falsy, an empty ``create_marker``
     blob is written to the container and the flag is set.
     """
     global ATTEMPTED_CONTAINER_CREATE

     config.load_system_configs()

     section = "azure_tests"
     self.credentials_filename = config.get(section, "credentials_file")
     self.storage_account = config.get(section, "storage_account")
     self.container_name = config.get(section, "container_name")

     self.client = AzureBlobStorageClient(**self._get_credentials_dict())

     # Write the marker only while the flag has not been set yet
     # (presumably so the container exists before tests use it).
     if not ATTEMPTED_CONTAINER_CREATE:
         self.client.put_string(b"", self.container_name, "create_marker")
         ATTEMPTED_CONTAINER_CREATE = True

     yield self.client
Example #3
0
def s3_path():
    """Return a unique ``s3://<bucket>/<YYYY-MM-DD>/<uuid4>`` path.

    The bucket comes from ``aws_tests.bucket_name`` in the config; the date
    is today's local date and the last segment is a fresh random UUID.
    """
    bucket = str(config.get("aws_tests", "bucket_name"))
    day = datetime.datetime.today().strftime("%Y-%m-%d")
    unique_id = str(uuid.uuid4())
    return "s3://" + "/".join((bucket, day, unique_id))
Example #4
0
 def test_read_(self):
     """Values read from ``test_config_reader.cfg`` are visible via ``config``."""
     cfg_file = scenario_path("config_files", "test_config_reader.cfg")
     loaded = read_from_config_files([cfg_file])

     with config(loaded):
         value = config.get("test_config_reader", "test_config_reader")
         assert value == "test_value"
Example #5
0
 def set_gcs_client(self):
     """Create ``self.client`` as a ``gcs.GCSClient`` using test credentials.

     When ``GOOGLE_APPLICATION_CREDENTIALS`` is not set, the base64 value of
     ``gcp_tests.credentials_json`` is decoded into a temporary file (kept on
     disk, ``delete=False``) and the variable is pointed at that file.
     """
     env_key = "GOOGLE_APPLICATION_CREDENTIALS"

     if env_key not in os.environ:
         with tempfile.NamedTemporaryFile(delete=False) as key_file:
             encoded_key = config.get("gcp_tests", "credentials_json")
             key_file.write(base64.b64decode(encoded_key))
             os.environ.setdefault(env_key, key_file.name)

     creds = self._get_credentials()
     self.client = gcs.GCSClient(creds)
Example #6
0
 def test_word_spark_with_error(self):
     """Running ``WordCountThatFails`` must raise ``DatabandRunError``."""
     source_text = config.get("livy_tests", "text")
     # Random version so every run gets a distinct task_version.
     version = str(random.random())

     failing_task = WordCountThatFails(
         text=source_text, task_version=version, override=conf_override
     )

     with pytest.raises(DatabandRunError):
         failing_task.dbnd_run()
Example #7
0
    def test_setting_dbnd_config_from_valid_connection(self, af_session):
        """A valid Airflow connection configures dbnd and sets ``core.databand_url``."""
        self.set_dbnd_airflow_connection(
            af_session,
            json_for_connection=VALID_EXTRA_JSON_FOR_CONNECTION,
        )

        configured = set_dbnd_config_from_airflow_connections()
        assert configured

        configured_url = config.get("core", "databand_url")
        assert configured_url == DATABAND_URL
Example #8
0
 def test_word_count_pyspark(self):
     """Run ``WordCountPySparkTask`` and print the first output part file."""
     logging.info("Running %s", WordCountPySparkTask)

     source_text = config.get("livy_tests", "text")
     # Random version so every run gets a distinct task_version.
     version = str(random.random())
     word_count = WordCountPySparkTask(
         text=source_text, task_version=version, override=conf_override
     )
     word_count.dbnd_run()

     first_part = target(word_count.counters.path, "part-00000")
     print(first_part.read())
Example #9
0
    def test_word_count_inline(self):
        """Build the inline word-count task and run it via ``assert_run_task``."""
        from dbnd_test_scenarios.spark.spark_tasks_inline import word_count_inline

        source_text = config.get("livy_tests", "text")
        # Random version so every run gets a distinct task_version.
        version = str(random.random())

        inline_task = word_count_inline.t(
            text=source_text, task_version=version, override=conf_override
        )
        assert_run_task(inline_task)
Example #10
0
 def test_override_config_values_from_context(self):
     """A value set through a ``config(...)`` context is seen by config and task."""
     expected = "from_context_override"
     context_override = {"task_from_config": {"parameter_from_config": expected}}

     with config(context_override):
         current = config.get("task_from_config", "parameter_from_config")
         assert current == expected
         task_from_config.dbnd_run(expected=expected)
Example #11
0
from dbnd._core.settings import EnvConfig
from dbnd.testing.helpers_pytest import assert_run_task
from dbnd_aws.emr.emr_config import EmrConfig
from dbnd_spark.spark_config import SparkConfig
from dbnd_test_scenarios.spark.spark_tasks import (
    WordCountPySparkTask,
    WordCountTask,
    WordCountThatFails,
)
from targets import target


# Overrides passed as ``override=`` to the tasks under test: the task env is
# set to AWS, the spark config to the EMR cluster type, and the EMR cluster is
# read from ``aws_tests.cluster`` at import time.  Both jar settings are set
# to empty strings.
conf_override = {
    "task": {"task_env": CloudType.aws},
    EnvConfig.spark_config: SparkClusters.emr,
    EmrConfig.cluster: config.get("aws_tests", "cluster"),
    SparkConfig.jars: "",
    SparkConfig.main_jar: "",
}

# Input passed as ``text=`` to the word-count tasks; read from
# ``aws_tests.text_file`` at import time.
TEXT_FILE = config.get("aws_tests", "text_file")


@pytest.mark.emr
class TestEmrSparkTasks(object):
    # add back java code test
    @pytest.mark.skip
    def test_word_count_spark(self):
        logging.info("Running %s", WordCountPySparkTask)
        actual = WordCountTask(
            text=TEXT_FILE, task_version=str(random.random()), override=conf_override
Example #12
0
 def test_read_environ_config(self):
     """``DBND__<SECTION>__<KEY>`` environment variables show up in the config.

     Sets ``DBND__TEST_SECTION__TEST_KEY``, reads the environment config and
     checks that ``test_section.test_key`` holds the value.  The variable is
     restored to its previous state afterwards so the test does not leak
     process-wide environment changes into other tests.
     """
     env_key = "DBND__TEST_SECTION__TEST_KEY"
     previous = os.environ.get(env_key)
     os.environ[env_key] = "TEST_VALUE"
     try:
         actual = read_environ_config()
         with config(actual):
             assert config.get("test_section", "test_key") == "TEST_VALUE"
     finally:
         # Undo the mutation even when the assertion fails.
         if previous is None:
             os.environ.pop(env_key, None)
         else:
             os.environ[env_key] = previous
Example #13
0
def hdfs_path():
    """Return a unique ``hdfs://<folder>_<YYYY-MM-DD>_<uuid4>`` path.

    The folder comes from ``integration_tests.hdfs_folder`` in the config;
    the date is today's local date and the last part is a fresh random UUID.
    """
    folder = str(config.get("integration_tests", "hdfs_folder"))
    day = datetime.datetime.today().strftime("%Y-%m-%d")
    unique_id = str(uuid.uuid4())
    return "hdfs://" + "_".join((folder, day, unique_id))
Example #14
0
 def bucket_url(self, suffix):
     """
     Return ``gs://<bucket>/<TEST_FOLDER>/<suffix>``.

     The bucket name is read from ``gcp_tests.bucket_name``, so the result
     points inside the test folder of that bucket rather than at its root.
     """
     bucket = config.get("gcp_tests", "bucket_name")
     return "gs://{bucket}/{folder}/{name}".format(
         bucket=bucket, folder=TEST_FOLDER, name=suffix
     )
Example #15
0
 def _get_credentials(self):
     """Return default Google credentials limited to the ``gcp_tests.scope`` scope."""
     requested_scopes = [config.get("gcp_tests", "scope")]
     # Only the first element of the returned tuple is used.
     auth_result = google.auth.default(scopes=requested_scopes)
     return auth_result[0]