def create_graph(config: dict) -> Graph:
    """Build a Graph instance of the type named by ``config["class"]``.

    ``"ConcurrentGraph"`` and ``"DebugGraph"`` select those subclasses
    (DebugGraph additionally reads ``config["base_dir"]`` for its path);
    any other value — or a missing key — falls back to a plain Graph.
    """
    graph_type = config.get("class")
    if graph_type == "DebugGraph":
        return DebugGraph(Path(config.get("base_dir")))
    if graph_type == "ConcurrentGraph":
        return ConcurrentGraph()
    return Graph()
def add_tasks_to_graph(graph: "Graph", tasks: dict) -> None:
    """Append every task in *tasks* to *graph* in dependency order.

    Args:
        graph: Graph-like object whose ``append(instance, deps)`` returns a
            graph-task handle for the appended task.
        tasks: Mapping of task key -> ``{"instance": task object,
            "dependencies": [keys this task depends on]}``.  (The previous
            annotation said ``list``, but the code iterates ``tasks.items()``
            — it has always been a dict.)

    Raises:
        ValueError: When a full pass adds nothing, i.e. some task's
            dependencies can never be satisfied (e.g. a dependency key
            that is not present in *tasks*).
    """
    added = {}  # task key -> graph-task handle returned by graph.append()

    def _resolvable(task):
        # True when every dependency key has already been appended.
        return all(dep in added for dep in task["dependencies"])

    def _dep_handles(task):
        # Handles of the task's (already appended) dependencies.
        return [added[dep] for dep in task["dependencies"]]

    while len(added) < len(tasks):
        for key, task in tasks.items():
            if key in added:
                continue  # already appended
            if _resolvable(task):
                added[key] = graph.append(task["instance"], _dep_handles(task))
                break  # restart the scan: the new handle may unblock earlier keys
        else:
            # Nothing could be appended this pass: report every stuck task,
            # then fail (previously only the first one was logged).
            for key, task in tasks.items():
                if key not in added:
                    logger.error(f"{task['instance'].__class__.__name__} can't resolve dependencies")
            raise ValueError("build graph failed")
def test_hook(self):
    """pre_run_hook and post_run_hook each fire exactly once per run."""
    global hook_called
    hook_called = 0

    def bump(ds):
        # Count every invocation via the module-level counter.
        global hook_called
        hook_called += 1

    graph = Graph()
    node = graph.append(DumbTask())
    node.pre_run_hook = bump
    node.post_run_hook = bump
    graph.run()
    self.assertEqual(2, hook_called)
def test_error_handler(self):
    """A handler registered for ValueError absorbs the task's failure."""
    def on_value_error(err, ds):
        print(str(err))
        print(str(ds))

    graph = Graph()
    graph.append(ErrorTask())
    graph.add_error_handler(ValueError, on_value_error)
    graph.run()
def test_auto_resolver_error(self):
    """autoresolve_dependencies() raises when inputs cannot be satisfied."""
    graph = Graph()
    for task in (TaskA(), TaskB(), TaskC(), TaskD()):
        graph.append(task)
    with self.assertRaises(ValueError):
        graph.autoresolve_dependencies()
def test_dynamic_graph(self):
    """Output keys produced only at run time still satisfy dependents.

    DynTaskA advertises no output keys until its main() has run, after
    which it reports "DynTaskA".  DynTaskB (which needs "DynTaskA") is
    asserted COMPLETED, while DynTaskC (which needs "DynTaskX", produced
    by no task) is asserted to remain in INIT.
    """
    class DynTaskA(Task):
        def __init__(self):
            # Empty until main() executes — outputs are dynamic.
            self._output_datakeys = []

        def output_datakeys(self):
            return self._output_datakeys

        def main(self, ds: DataSet):
            # Advertise the output key only after the task has run.
            self._output_datakeys = ["DynTaskA"]
            return DataSet()

    class DynTaskB(Task):
        def input_datakeys(self):
            return ["DynTaskA"]

        def main(self, ds: DataSet):
            return DataSet()

    class DynTaskC(Task):
        def input_datakeys(self):
            # "DynTaskX" is produced by no task in this graph.
            return ["DynTaskX"]

        def main(self, ds: DataSet):
            return DataSet()

    #
    g = Graph()
    g.append(DynTaskA())
    gtb = g.append(DynTaskB())
    gtc = g.append(DynTaskC())
    g.run()
    self.assertEqual(TaskStatus.COMPLETED, gtb.status)
    self.assertEqual(TaskStatus.INIT, gtc.status)
def test_auto_resolver(self):
    """autoresolve_dependencies() wires TaskC to both of its producers."""
    graph = Graph()
    for task in (TaskA(), TaskB()):
        graph.append(task)
    task_c = graph.append(TaskC())
    graph.autoresolve_dependencies()
    self.assertEqual(2, len(task_c.dependencies))
def test_default_ds(self):
    """A DataSet passed to run() is visible to tasks as their input."""
    class TestTask(Task):
        def main(self, ds: DataSet):
            # Fails the run unless the "default" key was propagated.
            if "default" not in ds.keys():
                raise ValueError()
            return DataSet()

    graph = Graph()
    graph.append(TestTask())
    graph.run(DataSet().put("default", JsonData({})))
def test_catalog_ds(self):
    """A catalog DataSet given at Graph construction reaches the tasks."""
    class TestTask(Task):
        def main(self, ds: DataSet):
            # Fails the run unless the "catalog" key was propagated.
            if "catalog" not in ds.keys():
                raise ValueError()
            return DataSet()

    graph = Graph(catalog_ds=DataSet().put("catalog", JsonData({})))
    graph.append(TestTask())
    graph.run()
        # NOTE(review): fragment — the enclosing class/def header is outside
        # this chunk; presumably a datakeys method of the TaskC class whose
        # main() follows.  TODO confirm against the full file.
        return ["DataA-2"]

    def main(self, ds):
        logger.info("execute TaskC")
        return DataSet()


if __name__ == "__main__":
    basicConfig(level=DEBUG)

    # Pick the Graph implementation from the -g/--graph CLI option.
    parser = argparse.ArgumentParser()
    parser.add_argument("-g", "--graph", type=str, default="normal",
                        help="Graph {normal/thread/process}")
    args = parser.parse_args()
    if args.graph == "normal":
        graph = Graph()
        logger.info("use Graph")
    elif args.graph == "thread":
        graph = ConcurrentGraph(ThreadPoolExecutor())
        logger.info("use ConcurrentGraph(ThreadPoolExecutor)")
    elif args.graph == "process":
        graph = ConcurrentGraph(ProcessPoolExecutor())
        logger.info("use ConcurrentGraph(ProcessPoolExecutor)")
    else:
        logger.error("unknown graph,")
        sys.exit(-1)

    # Process with the Graph
    graph.append(TaskA())
    graph.append(TaskB())
    graph.append(TaskC())
    # NOTE(review): fragment — the enclosing function header is outside this
    # chunk; it wraps the DataFrame `df` into a DataSet and returns it.
    rds = DataSet()
    rds.put("titanic", DataFrameData(df))
    return rds


# DAG definition
default_args = {
    "owner": "airflow",
    "depends_on_past": False,
    "start_date": datetime(2020, 6, 11),
    "email": ["*****@*****.**"],
    "email_on_failure": False,
    "email_on_retry": False,
    "retries": 1,
    "retry_delay": timedelta(minutes=5),
}
# schedule_interval=None: the DAG is only triggered manually.
dag = DAG("graph_to_dag", default_args=default_args, schedule_interval=None)

# Build the Graph: ReadCsv feeds three parallel preprocessors, then Merge.
graph = Graph()
readcsv = graph.append(ReadCsv())
fill_age = graph.append(FillNaMedian("Age"), [readcsv])
sex_to_code = graph.append(SexToCode(), [readcsv])
embarked_to_code = graph.append(EmbarkedToCode(), [readcsv])
graph.append(Merge(), [readcsv, fill_age, sex_to_code, embarked_to_code])

# Register the Graph's tasks with Airflow
af = AirFlow(graph)
af.to_dag(dag)
def test_no_handler(self):
    """Without a registered handler, a task's ValueError propagates."""
    graph = Graph()
    graph.append(ErrorTask())
    with self.assertRaises(ValueError):
        graph.run()
    # NOTE(review): fragment — the enclosing function header is outside this
    # chunk; it wraps the DataFrame `df` into a DataSet and returns it.
    ds = DataSet()
    ds.put("titanic", DataFrameData(df))
    return ds


if __name__ == "__main__":
    basicConfig(level=DEBUG)

    # Load the dataset
    ds = DataSet()
    repo = LocalFileRepository(
        Path(os.path.dirname(__file__)) / Path("../titanic.csv"))
    titanic_data = DataFrameData.load(repo)
    ds.put("titanic", titanic_data)
    #
    print("## Original data")
    print(ds.get("titanic").content)

    # Process with the Graph:
    # fill missing Age -> encode Sex -> encode Embarked, in that order
    graph = Graph()
    fill_age = graph.append(FillNaMedian("Age"))
    sex_to_code = graph.append(SexToCode(), [fill_age])
    graph.append(EmbarkedToCode(), [sex_to_code])
    ds = graph.run(ds)
    print("## Processed data")
    print(ds.get("titanic").content)
    # NOTE(review): fragment — the enclosing function header and the
    # construction of `rds` are outside this chunk.
    return rds


if __name__ == "__main__":
    basicConfig(level=DEBUG)

    # Pick the Graph implementation from the -g/--graph CLI option.
    parser = argparse.ArgumentParser()
    parser.add_argument("-g", "--graph", type=str, default="normal",
                        help="Graph {normal/thread/process}")
    args = parser.parse_args()
    if args.graph == "normal":
        graph = Graph()
        logger.info("use Graph")
    elif args.graph == "thread":
        graph = ConcurrentGraph(ThreadPoolExecutor())
        logger.info("use ConcurrentGraph(ThreadPoolExecutor)")
    elif args.graph == "process":
        graph = ConcurrentGraph(ProcessPoolExecutor())
        logger.info("use ConcurrentGraph(ProcessPoolExecutor)")
    else:
        logger.error("unknown graph,")
        sys.exit(-1)

    # Load the dataset (this chunk ends here; the script continues below)
    ds = DataSet()
    repo = LocalFileRepository(
        Path(os.path.dirname(__file__)) / Path("../titanic.csv"))