示例#1
0
    def test_ret_dict(self):
        """A dict returned by the task is readable through the matching named output."""

        @task(result=(output(name="o_a").csv[List[str]], "o_b"))
        def t_f(a=5):
            return dict(o_a=[str(a)], o_b=["2"])

        executed = assert_run_task(t_f.t(a=6))
        # Only ``o_a`` is inspected; it should hold the stringified ``a`` on one line.
        assert executed.o_a.read() == "6\n"
示例#2
0
    def test_return_tuple(self, pandas_data_frame):
        """A task returning a two-element tuple exposes both frames as outputs.

        The function is checked twice: once through a plain call and once
        through a task run, where the outputs are read back both from
        ``result`` and from the named attributes.
        """
        source_target = self.target("file.parquet")
        source_target.write_df(pandas_data_frame)

        now = datetime.datetime.now()

        @task(
            result=(output(name="task_res1").parquet, output.parquet(name="task_res2"))
        )
        def t_f(a_str, b_datetime, c_timedelta, d_int):
            # type: (str, datetime.datetime, datetime.timedelta, int) -> (DataFrame, DataFrame)
            assert a_str == "strX"
            assert b_datetime == now
            assert c_timedelta == datetime.timedelta(seconds=1)
            assert d_int == 1
            return pandas_data_frame, pandas_data_frame

        call_args = ["strX", now, datetime.timedelta(seconds=1), 1]

        # Plain call: both returned frames must equal the fixture frame.
        direct_first, direct_second = t_f(*call_args)
        for direct_frame in (direct_first, direct_second):
            assert_frame_equal(pandas_data_frame, direct_frame)

        executed = assert_run_task(t_f.t(*call_args))

        # Outputs reached through the ``result`` tuple.
        first_output, second_output = executed.result
        for produced in (first_output, second_output):
            assert_frame_equal(pandas_data_frame, produced.read_df())

        # The same outputs reached through their declared names.
        for produced in (executed.task_res1, executed.task_res2):
            assert_frame_equal(pandas_data_frame, produced.read_df())
示例#3
0
    def test_fails_on_same_names(self):
        """Declaring a task must fail when an output name clashes with a parameter.

        ``same_name`` is used both as a named output and as a function
        parameter, so merely applying ``@task`` inside the ``raises`` block is
        expected to raise ``DatabandBuildError``.
        """
        # NOTE: the ``message=`` keyword of ``pytest.raises`` was deprecated in
        # pytest 4.1 and removed in 5.0 (it now raises TypeError). It never
        # checked the exception text -- it only customised the "did not raise"
        # failure message -- so it is dropped here.
        # Intended error text (not verified against the raised exception):
        # "have same keys in result schema". Switch to ``match=`` once confirmed.
        with pytest.raises(DatabandBuildError):

            @task(result=(output(name="same_name").csv[List[str]], "o_b"))
            def t_f(same_name=3, a=5):
                return {"o_a": [str(a)], "o_b": same_name}
示例#4
0
    def test_builder(self):
        """The output builder keeps its name and gains the feather config on request."""
        builder = output(name="ttt")
        assert builder.parameter.name == "ttt"
        # Before ``feather()`` is called the config must differ from ``file.feather``.
        assert builder.parameter.target_config != file.feather

        feather_builder = builder.feather()

        assert feather_builder.parameter.target_config == file.feather
        # Only checks that a (truthy) name is still present after the switch.
        assert feather_builder.parameter.name
示例#5
0
    def test_prepare_data_overwrite_previous_result(self):
        """Run a pass-through task whose CSV result is declared with ``overwrite``."""
        # NOTE(review): the DOC START / DOC END markers are presumably consumed by
        # documentation tooling -- keep them exactly as they are.
        #### DOC START
        @task(
            result=output(default=DEFAULT_OUTPUT_HDFS_PATH).overwrite.csv[pd.DataFrame]
        )
        def prepare_data(data=parameter[pd.DataFrame]):
            return data

        #### DOC END
        prepare_data.dbnd_run(data=data_repo.wines)
示例#6
0
 # Task subclass declaring a single output, ``t_output``, whose default is the
 # path "/path/to/target/file".
 class RequiredTask(Task):
     t_output = output(default="/path/to/target/file")
            def run(self):
                """Check ``a``, write the text "test" to ``output_file`` and return ``a``."""
                assert self.a == 6
                # ``mkdir_parent()`` is invoked on the output target before the file is opened.
                target(self.output_file).mkdir_parent()
                with open(self.output_file, "w") as out_handle:
                    out_handle.write("test")
                return self.a

        TFCls_output.task(a=6).dbnd_run()

    def test_user_decorated_class_serializable(self):
        """A ``ClsAsTask`` instance keeps its ``a`` attribute across a pickle round trip."""
        original = ClsAsTask()
        restored = pickle.loads(pickle.dumps(original))
        assert original.a == restored.a


@task(result=(output(name="datasets")[List[str]]))
class InlineCallClsDecoratedTask(object):
    # Class-based task with one declared result, "datasets" (a list of strings),
    # which ``run`` fills in from the keys of ``param_dict``.

    def __init__(self, param_dict=parameter[Dict], param_str=parameter[str]):
        self.param_str = param_str
        self.param_dict = param_dict

    def run(self):
        """Validate both parameters and store the dict keys in ``self.datasets``.

        Raises ``TError`` when ``param_str`` equals "error".
        """
        assert self.param_dict
        assert self.param_str
        if self.param_str == "error":
            raise TError("Raising as requested")
        self.datasets = list(self.param_dict)


@task(result=(output(name="datasets")[List[str]]))
class ParentCallClsDecoratedTask:
示例#8
0
        raise TaskValidationError("Salt level is too high!")

    mix = [
        "%s->%s  (%s salt)" % (dressing, v.rstrip(), salt_amount)
        for v in chopped_vegetables
    ]

    log_metric("dressed", len(mix))
    log_metric("dressing", dressing)
    log_metric("dict", {"a": datetime.datetime.utcnow(), "tuple": ("a", 1)})

    logging.info("Dressing result %s", mix)
    return [x + "\n" for x in mix]


@task(result=output().data_list_str)
def cut(vegetables):
    # type: (DataList[str]) -> DataList[str]
    # cuts vegetables
    chopped = []

    logging.info("Got {}. Start Chopping.".format(
        ",".join(vegetables)).replace("\n", ""))

    for line in vegetables:
        chopped.extend(list(line.rstrip()))

    shuffle(chopped)

    logging.info("Chopped vegetables:" + " ".join(chopped))
示例#9
0
    try:
        output = (subprocess.check_output(
            cmd, stderr=subprocess.STDOUT, shell=True,
            env=os.environ).decode("utf-8").strip())
    except subprocess.CalledProcessError as ex:
        logger.error(
            "Failed to run %s :\n\n\n -= Output =-\n%s\n\n\n -= See output above =-",
            cmd,
            ex.output.decode("utf-8", errors="ignore"),
        )
        raise ex
    return output


@task(
    result=output(default=DEFAULT_OUTPUT_HDFS_PATH).overwrite.csv[pd.DataFrame],
)
def overwrite_target_task(df=parameter[pd.DataFrame]):
    # Pass-through task: the incoming frame is returned unchanged, so the only
    # work is writing it to the CSV output declared with ``overwrite`` above.
    return df


@pipeline
def overwrite_target_pipeline():
    # Single-step pipeline: feed the frame from ``pandas_data_frame()`` into
    # ``overwrite_target_task`` and return that task's result.
    return overwrite_target_task(pandas_data_frame())


class TestOverwritingTargetHdfs(object):
    @pytest.fixture(autouse=True, scope="module")
    def wait_for_namenode(self):
        # Wait for namenode to leave safe mode
示例#10
0
from dbnd._core.utils.timezone import utcnow
from targets import Target
from targets.types import Path

logger = logging.getLogger(__name__)


class DbSyncConfig(config.Config):
    """(Advanced) Databand's db sync behaviour"""

    # Task-family name for this config; presumably also the config section
    # name it is looked up under -- TODO confirm.
    _conf__task_family = "db_sync"

    # Parameter typed as ``Target``. NOTE(review): the name suggests the root
    # location used for exports -- confirm against the callers.
    export_root = parameter[Target]


@task(archive=output(output_ext=".tar.gz")[Path])
def export_db(
        archive,
        include_db=True,
        include_logs=True,
        task_version=utcnow().strftime("%Y%m%d_%H%M%S"),
):
    # type: (Path, bool, bool, str)-> None

    from dbnd._core.current import get_databand_context

    logger.info("Compressing files to %s..." % archive)
    with tarfile.open(str(archive), "w:gz") as tar:

        if include_db:
示例#11
0
    def test_build_parameter(self):
        """``_p`` of an output builder is truthy and carries the requested name."""
        built = output(name="ttt")._p

        assert built
        assert built.name == "ttt"
示例#12
0
import pandas as pd

from dbnd import output, task
from targets.marshalling import register_marshaller
from targets.marshalling.pandas import _PandasMarshaller
from targets.target_config import register_file_extension

# 1. Register "xlsx" as a file extension; the returned handle is reused below
#    for the marshaller registration and for the task's output extension.
excel_file_ext = register_file_extension("xlsx")


class DataFrameToExcel(_PandasMarshaller):
    """Pandas marshaller that reads and writes DataFrames in Excel format."""

    def _pd_read(self, *args, **kwargs):
        """Forward all arguments to ``pd.read_excel`` and return its result."""
        loaded = pd.read_excel(*args, **kwargs)
        return loaded

    def _pd_to(self, value, *args, **kwargs):
        """Call ``value.to_excel`` with the given arguments and return its result."""
        write_result = value.to_excel(*args, **kwargs)
        return write_result


# 2. Register the type-to-extension mapping: pd.DataFrame values paired with
#    the Excel extension are handled by a DataFrameToExcel instance.
register_marshaller(pd.DataFrame, excel_file_ext, DataFrameToExcel())


@task(result=output(output_ext=excel_file_ext))
def dump_as_excel_table():
    # type: ()-> pd.DataFrame
    # Two sample (name, births) records; the result is declared with the Excel
    # output extension in the decorator above.
    rows = [("Bob", 968), ("Jessica", 155)]
    return pd.DataFrame(data=rows, columns=["Names", "Births"])