def test_superclasses(self):
  """Checks that ExtractSuperClasses reports each class's direct base names."""
  src = textwrap.dedent("""
      class object:
          pass
      class A():
          pass
      class B():
          pass
      class C(A):
          pass
      class D(A,B):
          pass
      class E(C,D,A):
          pass
  """)
  ast = visitors.LookupClasses(self.Parse(src))
  superclasses = ast.Visit(pytd_visitors.ExtractSuperClasses())
  # Expected base-class names per class, checked in declaration order.
  # A and B declare no explicit bases, and the expectation is `object`.
  expected = [
      ("A", ["object"]),
      ("B", ["object"]),
      ("C", ["A"]),
      ("D", ["A", "B"]),
      ("E", ["C", "D", "A"]),
  ]
  for class_name, base_names in expected:
    actual = [t.name for t in superclasses[ast.Lookup(class_name)]]
    six.assertCountEqual(self, base_names, actual)
def GetAllSubClasses(ast):
  """Build a mapping from each type to the classes that derive from it.

  Args:
    ast: Parsed PYTD.

  Returns:
    A dictionary, mapping instances of pytd.Type (types) to lists of
    pytd.Class (the derived classes).
  """
  superclass_map = ast.Visit(pytd_visitors.ExtractSuperClasses())
  # Copy each superclass collection into a list before handing the
  # class->superclasses mapping to utils.invert_dict.
  as_lists = dict(
      (cls, list(bases)) for cls, bases in superclass_map.items())
  return utils.invert_dict(as_lists)