def test_search(self):
    """Check the (match, following) pair returned by Graph.search().

    The fixture graph has the edges a -> b, b -> c and b -> d.
    """
    graph = hl.Graph()
    a, b, c, d = (hl.Node(uid=label, name=label, op=label) for label in "abcd")
    for node in (a, b, c, d):
        graph.add_node(node)
    for source, target in ((a, b), (b, c), (b, d)):
        graph.add_edge(source, target)

    # Serial pattern: a and b are matched, and both nodes that b points
    # to are expected in the following list.
    matched, followers = graph.search(ge.GEParser("a > b").parse())
    self.assertCountEqual(matched, [a, b])
    self.assertCountEqual(followers, [c, d])

    # Serial pattern ending in a parallel group: c and d have no outgoing
    # edges in this graph, and the expected following list is empty.
    matched, followers = graph.search(ge.GEParser("b > (c | d)").parse())
    self.assertCountEqual(matched, [b, c, d])
    self.assertEqual(followers, [])

    # Parallel pattern on its own, written without spaces around "|".
    matched, followers = graph.search(ge.GEParser("c|d").parse())
    self.assertCountEqual(matched, [c, d])
    self.assertEqual(followers, [])
def test_parsing(self):
    """Check the pattern class that GEParser.parse() returns per expression.

    Uses assertIsInstance so a failure reports the object that was actually
    returned, and passes the expression text as the failure message so the
    offending case is identifiable.
    """
    cases = (
        ("Conv", ge.NodePattern),
        # Trailing whitespace in the expression is intentional.
        ("Conv | Conv[1x1] ", ge.ParallelPattern),
        ("Conv | (Conv[1x1] > Conv)", ge.ParallelPattern),
        # Same expression wrapped in an outer pair of parentheses.
        ("(Conv | (Conv[1x1] > Conv))", ge.ParallelPattern),
    )
    for text, expected_type in cases:
        self.assertIsInstance(ge.GEParser(text).parse(), expected_type, msg=text)
def test_parallel(self):
    """Check that GEParser.parallel() yields a ParallelPattern for "|" expressions.

    Uses assertIsInstance so a failure reports the object that was actually
    returned instead of a bare "False is not true".
    """
    # With and without whitespace around "|", and with a nested group.
    for text in (
        "Conv|Conv[1x1]",
        "Conv | Conv[1x1]",
        "Conv | (Conv[1x1] | Conv)",
    ):
        self.assertIsInstance(
            ge.GEParser(text).parallel(), ge.ParallelPattern, msg=text
        )

    # Three alternatives in a row; 23 is the length of the input string,
    # so the expected index is the end of the text.
    p = ge.GEParser("Conv | Conv[1x1] | Conv")
    self.assertIsInstance(p.parallel(), ge.ParallelPattern)
    self.assertEqual(p.index, 23)

    # A parenthesized parallel group is parsed through expression().
    p = ge.GEParser("(Conv | Conv[1x1])")
    self.assertIsInstance(p.expression(), ge.ParallelPattern)
def test_serial(self):
    """Check that GEParser.serial() yields a SerialPattern for ">" expressions.

    Uses assertIsInstance so a failure reports the object that was actually
    returned instead of a bare "False is not true".
    """
    # With and without whitespace around ">", and with a nested group.
    for text in (
        "Conv>Conv",
        "Conv > Conv[1x1]",
        "Conv > (Conv[1x1] > Conv)",
    ):
        self.assertIsInstance(
            ge.GEParser(text).serial(), ge.SerialPattern, msg=text
        )

    # Three steps in a row; 23 is the length of the input string,
    # so the expected index is the end of the text.
    p = ge.GEParser("Conv > Conv[1x1] > Conv")
    self.assertIsInstance(p.serial(), ge.SerialPattern)
    self.assertEqual(p.index, 23)

    # A parenthesized serial chain is parsed through expression().
    p = ge.GEParser("(Conv > Conv[1x1])")
    self.assertIsInstance(p.expression(), ge.SerialPattern)
def test_basics(self):
    """Exercise the low-level GEParser primitives one at a time.

    Each check is its own assertion (rather than one compound boolean), so
    a failure names the exact primitive or value that was wrong.
    """
    # Token and regex matching, with whitespace around the tokens.
    p = ge.GEParser(" (hello )")
    self.assertTrue(p.token("("))
    self.assertTrue(p.re(r"\w+"))
    self.assertTrue(p.token(")"))

    # A bracketed condition; the expected index equals the input length.
    p = ge.GEParser("[1x1]")
    self.assertEqual(p.condition(), "1x1")
    self.assertEqual(p.index, 5)

    # Same condition padded with whitespace inside and outside the brackets.
    p = ge.GEParser(" [ 1x1 ] ")
    self.assertEqual(p.condition(), "1x1")
    self.assertEqual(p.index, 9)

    # Missing closing bracket: a falsy result, and index is expected to
    # still be 0 afterwards.
    p = ge.GEParser("[1x1")
    self.assertFalse(p.condition())
    self.assertEqual(p.index, 0)

    # An op with a condition is a NodePattern, whether reached through
    # op() or expression(), with or without enclosing parentheses.
    p = ge.GEParser("Conv[1x1]")
    self.assertIsInstance(p.op(), ge.NodePattern)

    p = ge.GEParser("Conv[1x1]")
    self.assertIsInstance(p.expression(), ge.NodePattern)

    p = ge.GEParser("(Conv[1x1])")
    self.assertIsInstance(p.expression(), ge.NodePattern)
def test_combinations(self):
    """Check expressions that nest one pattern kind inside the other.

    Uses assertIsInstance so a failure reports the object that was actually
    returned instead of a bare "False is not true".
    """
    # A serial chain nested inside a parallel expression.
    p = ge.GEParser("Conv | (Conv[1x1] > Conv)")
    self.assertIsInstance(p.parallel(), ge.ParallelPattern)

    # A parallel group nested inside a serial expression; note the space
    # between the op name and its bracketed condition.
    p = ge.GEParser("Conv > (Conv [1x1] | Conv)")
    self.assertIsInstance(p.serial(), ge.SerialPattern)
def test_basics(self):
    """Check pattern.match() on a small diamond-shaped graph.

    The fixture graph has the edges a -> b, b -> c, b -> d, c -> e and
    d -> e.
    """
    graph = hl.Graph()
    a, b, c, d, e = (
        hl.Node(uid=label, name=label, op=label) for label in "abcde"
    )
    for node in (a, b, c, d, e):
        graph.add_node(node)
    for source, target in ((a, b), (b, c), (b, d), (c, e), (d, e)):
        graph.add_edge(source, target)

    # Serial pattern matched from its first node, then from the wrong node.
    pattern = ge.GEParser("a > b").parse()
    self.assertIsInstance(pattern, ge.SerialPattern)
    matched, followers = pattern.match(graph, a)
    self.assertTrue(matched)
    self.assertCountEqual(followers, [c, d])
    matched, followers = pattern.match(graph, b)
    self.assertFalse(matched)

    # "b > c" is expected not to match from b, which has edges to both
    # c and d.
    pattern = ge.GEParser("b > c").parse()
    self.assertIsInstance(pattern, ge.SerialPattern)
    matched, followers = pattern.match(graph, b)
    self.assertFalse(matched)

    # Parallel pattern: started from both branches, from a one-element
    # list, and from a bare node; each time e is expected to follow.
    pattern = ge.GEParser("c | d").parse()
    self.assertIsInstance(pattern, ge.ParallelPattern)
    for start in ([c, d], [c], d):
        matched, followers = pattern.match(graph, start)
        self.assertTrue(matched)
        self.assertEqual(followers, e)
    matched, followers = pattern.match(graph, b)
    self.assertFalse(matched)

    # Serial chain ending in a parallel group; the following value is
    # passed as the failure message.
    pattern = ge.GEParser("a > b > (c | d)").parse()
    self.assertIsInstance(pattern, ge.SerialPattern)
    matched, followers = pattern.match(graph, a)
    self.assertTrue(matched, msg=followers)

    # Same chain with the leading serial part explicitly grouped.
    pattern = ge.GEParser("(a > b) > (c | d)").parse()
    self.assertIsInstance(pattern, ge.SerialPattern)
    matched, followers = pattern.match(graph, a)
    self.assertTrue(matched)

    # Full path through the diamond.
    pattern = ge.GEParser("a > b > (c | d) > e").parse()
    self.assertIsInstance(pattern, ge.SerialPattern)
    matched, followers = pattern.match(graph, a)
    self.assertTrue(matched)

    # Serial pattern that begins with a parallel group.
    pattern = ge.GEParser("(c | d) > e").parse()
    self.assertIsInstance(pattern, ge.SerialPattern)
    matched, followers = pattern.match(graph, [c, d])
    self.assertTrue(matched)