def test_get_node():
    """Look up schedule nodes by ID in the linear DFG.

    Node 0 is expected to be the ``'source'`` node. Looking up ID 1000
    (assumed not to exist in this graph) must raise ``KeyError``.
    """
    base_path = "./test_dfgs"
    file_path = f"{base_path}/linear_dfg.json"
    test_sched = Schedule()
    test_sched.load_dfg(file_path)

    sched_node = test_sched.get_schedule_node(0)
    assert sched_node.op_name == 'source'

    # Only the raised exception matters; the lookup result is never used,
    # so it is not bound to a name.
    with pytest.raises(KeyError):
        test_sched.get_schedule_node(1000)
def test_data_node():
    """Node 2 of the linear DFG reports as a data node; node 11 does not."""
    dfg_dir = "./test_dfgs"
    dfg_path = f"{dfg_dir}/linear_dfg.json"

    schedule = Schedule()
    schedule.load_dfg(dfg_path)

    data_node = schedule.get_schedule_node(2)
    other_node = schedule.get_schedule_node(11)

    assert data_node.is_data_node()
    assert not other_node.is_data_node()
def test_node_depth():
    """Schedule the logistic DFG on the configured architecture and check depth.

    The architecture is built from ``config.json`` via ``TablaTemplate``.
    After ``schedule_graph``, node 15 (the sigmoid operation node) must
    have depth 5.
    """
    # Presumably resets global component ID counters so IDs assigned while
    # building the architecture don't depend on earlier tests; confirm.
    Component.reset_ids()
    base_path = "./test_dfgs"
    file_path = f"{base_path}/logistic_dfg.json"
    # JSON is UTF-8 by spec; don't rely on the platform's locale encoding.
    with open('config.json', encoding='utf-8') as config_file:
        data = json.load(config_file)
    new_arch = TablaTemplate(data)
    test_sched = Schedule()
    test_sched.load_dfg(file_path)
    test_sched.schedule_graph(new_arch)
    sched_node = test_sched.get_schedule_node(15)  # This is the sigmoid operation node
    assert sched_node.depth == 5