Example #1
from birgitta import context  # assumed import path
from birgitta import spark as bspark  # assumed import path


def test_conf():
    # Prevent bspark.is_local() from returning True
    orig_ctx_session_type = context.get("BIRGITTA_SPARK_SESSION_TYPE")
    context.set("BIRGITTA_SPARK_SESSION_TYPE", "NONLOCAL")
    session = bspark.session()
    assert session.conf.get("spark.sql.session.timeZone") == "UTC"
    # Restore the original value so other tests are unaffected
    context.set("BIRGITTA_SPARK_SESSION_TYPE", orig_ctx_session_type)
Example #2
from birgitta import context  # assumed import path


def get():
    """Return the dataframe source, deriving and caching it on first use."""
    source = context.get('BIRGITTA_DATAFRAMESOURCE')
    if source:
        return source
    # derive_source() and set() are module-level helpers (not shown here);
    # set() stores the derived source in the context.
    source = derive_source()
    set(source)
    return source
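
This is memoization keyed on the context: the first call derives and caches the source, later calls return the cached object, and a test can pre-seed the key with context.set() to inject its own source. A usage sketch (the dataframesource module name is an assumption suggested by the context key):

source_a = dataframesource.get()  # derives and caches on the first call
source_b = dataframesource.get()  # returns the cached source
assert source_a is source_b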
Example #3
from birgitta import context  # assumed import path


def log_transform(var, line_no, line_str):
    """Helper function to log a transformation.

    localtest will insert log_transform() calls after assignments.
    The call contains the line_no and the assignment code. The purpose
    is to log the transformation, including the dataframe count if the
    variable is a dataframe, in order to verify that more than zero
    rows pass through each transform.

    Args:
        var: The variable being logged
        line_no (int): The line number of the original code line
        line_str (str): The original code line
    """
    coverage = context.get("BIRGITTA_TEST_COVERAGE")
    report_file = coverage["cov_report_file"]
    test_case = coverage["test_case"]
    type_name = type(var).__name__
    metrics = {"var_type": type_name}
    if type_name == "DataFrame":
        metrics["count"] = var.count()
    # log_entry() is a module-level helper (not shown here)
    log_entry(test_case, line_no, line_str, report_file, metrics)
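
Per the docstring, localtest inserts a log_transform() call after each assignment. A hypothetical before/after sketch of that instrumentation (the recipe line and its line number are invented for illustration):

# Original recipe line (say, line 12):
orders = orders.filter(orders.amount > 0)

# After localtest instrumentation:
orders = orders.filter(orders.amount > 0)
log_transform(orders, 12, "orders = orders.filter(orders.amount > 0)")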
Example #4
from birgitta import context  # assumed import path

def is_local():
    return context.get("BIRGITTA_SPARK_SESSION_TYPE") == "LOCAL"
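
A minimal sketch of how the flag flips, using only the context.set() call already shown in Example #1:

context.set("BIRGITTA_SPARK_SESSION_TYPE", "LOCAL")
assert is_local()
context.set("BIRGITTA_SPARK_SESSION_TYPE", "NONLOCAL")
assert not is_local()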
Example #5
from birgitta import context  # assumed import path

def stored_in(t):
    """Return True if the configured dataset storage type equals t."""
    storage_type = context.get("BIRGITTA_DATASET_STORAGE")
    return t == storage_type
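
A usage sketch. "S3" and "HDFS" are assumed to be valid storage type values because they appear as dataset types in Example #6, and the paths are invented for illustration:

if stored_in("S3"):
    bucket = context.get('BIRGITTA_S3_BUCKET')  # key used in Example #6
    export_path = "s3://" + bucket + "/exports/report"
elif stored_in("HDFS"):
    export_path = "/exports/report"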
Example #6
from birgitta import context  # assumed import path
from birgitta.dataiku import dataset_manage  # assumed import path


def clone(client,
          src_project_key,
          dst_project_key,
          src_name,
          dst_name,
          dst_dataset_type,
          copy_data=True):
    """Utility function for cloning Dataiku datasets, including schema."""
    src_project = client.get_project(src_project_key)
    dst_project = client.get_project(dst_project_key)
    dataset_manage.delete_if_exists(dst_project, dst_name)

    src_dataset = src_project.get_dataset(src_name)
    if dst_dataset_type == "HDFS":
        # Params are hardcoded for the work connection rather than
        # copied from src_dataset.get_definition().
        dst_dataset_params = {
            'metastoreSynchronizationEnabled': True,
            'hiveDatabase': '${hive_table_work}',  # Use work
            'hiveTableName': '${projectKey}_' + dst_name,
            'connection': 'hdfs_work',
            'path': '/${projectKey}/' + dst_name,
            'notReadyIfEmpty': False,
            'importProjectKey': dst_project_key
        }
        dst_format_type = "parquet"
    elif dst_dataset_type == "S3":
        s3_bucket = context.get('BIRGITTA_S3_BUCKET')
        dst_dataset_params = {
            'bucket': s3_bucket,
            'connection': 'S3',
            'path': '/${projectKey}/' + dst_name,
            'notReadyIfEmpty': False
        }
        dst_format_type = "avro"
    else:  # Inline
        dst_dataset_params = {
            'keepTrackOfChanges': False,
            'notReadyIfEmpty': False,
            'importSourceType': 'NONE',
            'importProjectKey': dst_project_key
        }
        dst_format_type = "json"

    dst_dataset = dst_project.create_dataset(
        dst_name,
        dst_dataset_type,
        params=dst_dataset_params,
        formatType=dst_format_type  # Format params are left to defaults
    )

    dst_dataset.set_schema(src_dataset.get_schema())
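
A usage sketch. The client is assumed to be a dataikuapi.DSSClient (the standard Dataiku API entry point); the host, API key, project keys and dataset names are placeholders:

import dataikuapi

client = dataikuapi.DSSClient("https://dss.example.com", "my_api_key")
clone(client,
      src_project_key="CUSTOMER_PROD",
      dst_project_key="CUSTOMER_TEST",
      src_name="transactions",
      dst_name="transactions",
      dst_dataset_type="HDFS")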
Example #7
from birgitta import context  # assumed import path

def today():
    """Return a fixed value for "today" to enable consistent tests."""
    return context.get("TODAY")
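
A sketch of the intended test usage, assuming the test setup stores a date object under the TODAY key (the stored type is whatever the caller sets):

import datetime

context.set("TODAY", datetime.date(2019, 1, 1))
assert today() == datetime.date(2019, 1, 1)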