예제 #1
0
def create_graph(config: dict) -> Graph:
    """Instantiate a graph whose concrete type is chosen by ``config["class"]``.

    - ``"ConcurrentGraph"`` -> ``ConcurrentGraph()``
    - ``"DebugGraph"``      -> ``DebugGraph(Path(config.get("base_dir")))``
    - any other value, or no ``"class"`` key -> plain ``Graph()``

    Note: for ``"DebugGraph"`` a missing ``"base_dir"`` reaches ``Path(None)``,
    which raises ``TypeError``.
    """
    graph_kind = config.get("class")

    if graph_kind == "DebugGraph":
        base_dir = config.get("base_dir")
        return DebugGraph(Path(base_dir))

    if graph_kind == "ConcurrentGraph":
        return ConcurrentGraph()

    # Fallback for unknown or absent class names.
    return Graph()
예제 #2
0
def add_tasks_to_graph(graph: Graph, tasks: dict) -> dict:
    """Append every task in ``tasks`` to ``graph`` so dependencies come first.

    Args:
        graph: Receives each task via ``graph.append(instance, deps)``, where
            ``deps`` is the list of values ``graph.append`` previously returned
            for that task's dependencies.
        tasks: Mapping from task key to a dict holding ``"instance"`` (the task
            object to append) and ``"dependencies"`` (keys of other entries in
            ``tasks`` that must be appended before it).

    Returns:
        Mapping from task key to the value ``graph.append`` returned for it.

    Raises:
        ValueError: if some tasks can never be appended, because a dependency
            key is absent from ``tasks`` or the dependencies form a cycle. Each
            unresolved task is logged before raising.
    """
    added_tasks = {}

    def deps_ready(task):
        # True once every dependency key has already been appended to the graph.
        return all(d in added_tasks for d in task["dependencies"])

    while len(added_tasks) < len(tasks):
        for key, task in tasks.items():
            if key in added_tasks:
                continue  # already appended

            if deps_ready(task):
                deps = [added_tasks[d] for d in task["dependencies"]]
                added_tasks[key] = graph.append(task["instance"], deps)
                # Restart the scan from the beginning so each append picks the
                # earliest ready task in ``tasks`` iteration order.
                break
        else:
            # A full pass appended nothing, so the remaining tasks can never
            # have their dependencies resolved.
            for key, task in tasks.items():
                if key not in added_tasks:
                    logger.error(f"{task['instance'].__class__.__name__} can't resolve dependencies")
            raise ValueError("build graph failed")

    return added_tasks
예제 #3
0
    def test_hook(self):
        """pre_run_hook and post_run_hook together fire twice for one task run."""
        g = Graph()
        gt = g.append(DumbTask())

        # Count invocations in a closure variable rather than a module-level
        # global, so the test does not leak state into the module namespace.
        hook_called = 0

        def hook(ds):
            nonlocal hook_called
            hook_called += 1

        gt.pre_run_hook = hook
        gt.post_run_hook = hook
        g.run()

        self.assertEqual(2, hook_called)
예제 #4
0
    def test_error_handler(self):
        """A handler registered for ValueError is invoked when ErrorTask fails."""
        handled = []

        def value_error_handler(e, ds):
            # Record the call so the test can assert on it, instead of only
            # printing and leaving the handler path unverified.
            handled.append((e, ds))

        g = Graph()
        g.append(ErrorTask())
        g.add_error_handler(ValueError, value_error_handler)
        g.run()

        # The handler must actually have been called with a ValueError.
        self.assertTrue(handled)
        self.assertIsInstance(handled[0][0], ValueError)
예제 #5
0
    def test_auto_resolver_error(self):
        """autoresolve_dependencies() raises ValueError for TaskA..TaskD."""
        graph = Graph()
        for task_cls in (TaskA, TaskB, TaskC, TaskD):
            graph.append(task_cls())

        self.assertRaises(ValueError, graph.autoresolve_dependencies)
예제 #6
0
    def test_dynamic_graph(self):
        """Output keys that appear only after a task runs still unblock dependents.

        DynTaskA reports no output datakeys until its main() runs, after which it
        reports ["DynTaskA"]. DynTaskB consumes "DynTaskA" and must reach
        COMPLETED. DynTaskC consumes "DynTaskX", which no task here produces,
        so it must remain INIT.
        """
        class DynTaskA(Task):
            def __init__(self):
                # Run the Task base initialiser first; DynTaskB/DynTaskC get it
                # implicitly because they do not override __init__.
                super().__init__()
                self._output_datakeys = []

            def output_datakeys(self):
                return self._output_datakeys

            def main(self, ds: DataSet):
                # Only from this point does the task advertise its output key.
                self._output_datakeys = ["DynTaskA"]
                return DataSet()

        class DynTaskB(Task):
            def input_datakeys(self):
                return ["DynTaskA"]

            def main(self, ds: DataSet):
                return DataSet()

        class DynTaskC(Task):
            def input_datakeys(self):
                return ["DynTaskX"]

            def main(self, ds: DataSet):
                return DataSet()

        # No explicit dependencies are passed to append(); the expectations
        # below assume Graph matches input/output datakeys while running.
        g = Graph()
        g.append(DynTaskA())
        gtb = g.append(DynTaskB())
        gtc = g.append(DynTaskC())
        g.run()

        self.assertEqual(TaskStatus.COMPLETED, gtb.status)
        self.assertEqual(TaskStatus.INIT, gtc.status)
예제 #7
0
    def test_auto_resolver(self):
        """After auto-resolution with TaskA and TaskB present, TaskC has two dependencies."""
        graph = Graph()
        for task_cls in (TaskA, TaskB):
            graph.append(task_cls())
        task_c = graph.append(TaskC())

        graph.autoresolve_dependencies()

        self.assertEqual(2, len(task_c.dependencies))
예제 #8
0
    def test_default_ds(self):
        """A DataSet passed to run() reaches the task under the key "default"."""
        class TestTask(Task):
            def main(self, ds: DataSet):
                # Fail the run unless the default dataset entry was propagated.
                if "default" in ds.keys():
                    return DataSet()
                raise ValueError()

        default_ds = DataSet().put("default", JsonData({}))
        graph = Graph()
        graph.append(TestTask())
        graph.run(default_ds)
예제 #9
0
    def test_catalog_ds(self):
        """A catalog DataSet given to Graph() reaches the task under the key "catalog"."""
        class TestTask(Task):
            def main(self, ds: DataSet):
                # Fail the run unless the catalog entry was made visible.
                if "catalog" in ds.keys():
                    return DataSet()
                raise ValueError()

        catalog_ds = DataSet().put("catalog", JsonData({}))
        graph = Graph(catalog_ds=catalog_ds)
        graph.append(TestTask())

        graph.run()
예제 #10
0
        return ["DataA-2"]

    def main(self, ds):
        """Log the "execute TaskC" message and return an empty DataSet; ``ds`` is unused."""
        logger.info("execute TaskC")
        result = DataSet()
        return result


if __name__ == "__main__":
    basicConfig(level=DEBUG)

    parser = argparse.ArgumentParser()
    parser.add_argument("-g", "--graph", type=str, default="normal", help="Graph {normal/thread/process}")
    args = parser.parse_args()

    if args.graph == "normal":
        graph = Graph()
        logger.info("use Graph")
    elif args.graph == "thread":
        graph = ConcurrentGraph(ThreadPoolExecutor())
        logger.info("use ConcurrentGraph(ThreadPoolExecutor)")
    elif args.graph == "process":
        graph = ConcurrentGraph(ProcessPoolExecutor())
        logger.info("use ConcurrentGraph(ProcessPoolExecutor)")
    else:
        logger.error("unknown graph,")
        sys.exit(-1)

    # Graphで処理する
    graph.append(TaskA())
    graph.append(TaskB())
    graph.append(TaskC())
예제 #11
0
        rds = DataSet()
        rds.put("titanic", DataFrameData(df))
        return rds


# DAG definition
default_args = dict(
    owner="airflow",
    depends_on_past=False,
    start_date=datetime(2020, 6, 11),
    email=["*****@*****.**"],
    email_on_failure=False,
    email_on_retry=False,
    retries=1,
    retry_delay=timedelta(minutes=5),
)
dag = DAG(
    "graph_to_dag",
    default_args=default_args,
    schedule_interval=None,
)

# Build the Graph: ReadCsv feeds three preprocessing tasks, and Merge depends
# on ReadCsv plus all three of them.
graph = Graph()
readcsv = graph.append(ReadCsv())
fill_age = graph.append(FillNaMedian("Age"), [readcsv])
sex_to_code = graph.append(SexToCode(), [readcsv])
embarked_to_code = graph.append(EmbarkedToCode(), [readcsv])
graph.append(
    Merge(),
    [readcsv, fill_age, sex_to_code, embarked_to_code],
)

# Register the Graph's tasks with Airflow via AirFlow.to_dag.
af = AirFlow(graph)
af.to_dag(dag)
예제 #12
0
    def test_no_handler(self):
        """Without a registered error handler, run() raises ValueError for ErrorTask."""
        graph = Graph()
        graph.append(ErrorTask())

        self.assertRaises(ValueError, graph.run)
예제 #13
0
        ds = DataSet()
        ds.put("titanic", DataFrameData(df))
        return ds


if __name__ == "__main__":
    basicConfig(level=DEBUG)

    # データセットの読み込み
    ds = DataSet()
    repo = LocalFileRepository(
        Path(os.path.dirname(__file__)) / Path("../titanic.csv"))
    titanic_data = DataFrameData.load(repo)
    ds.put("titanic", titanic_data)

    #
    print("## Original data")
    print(ds.get("titanic").content)

    # Graphで処理する
    # Age欠損埋め -> 性別のコード化 -> 乗船した港 のコード化 の順で処理
    graph = Graph()
    fill_age = graph.append(FillNaMedian("Age"))
    sex_to_code = graph.append(SexToCode(), [fill_age])
    graph.append(EmbarkedToCode(), [sex_to_code])
    ds = graph.run(ds)

    print("## Processed data")
    print(ds.get("titanic").content)
예제 #14
0
        return rds


if __name__ == "__main__":
    basicConfig(level=DEBUG)

    parser = argparse.ArgumentParser()
    parser.add_argument("-g",
                        "--graph",
                        type=str,
                        default="normal",
                        help="Graph {normal/thread/process}")
    args = parser.parse_args()

    if args.graph == "normal":
        graph = Graph()
        logger.info("use Graph")
    elif args.graph == "thread":
        graph = ConcurrentGraph(ThreadPoolExecutor())
        logger.info("use ConcurrentGraph(ThreadPoolExecutor)")
    elif args.graph == "process":
        graph = ConcurrentGraph(ProcessPoolExecutor())
        logger.info("use ConcurrentGraph(ProcessPoolExecutor)")
    else:
        logger.error("unknown graph,")
        sys.exit(-1)

    # データセットの読み込み
    ds = DataSet()
    repo = LocalFileRepository(
        Path(os.path.dirname(__file__)) / Path("../titanic.csv"))