forked from ericdill/depfinder
-
Notifications
You must be signed in to change notification settings - Fork 0
/
depfinder.py
239 lines (204 loc) · 8.23 KB
/
depfinder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
# depfinder
# Copyright (C) 2015 Eric Dill
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import print_function, division, absolute_import
import ast
import os
from collections import defaultdict
from stdlib_list import stdlib_list
import sys
# Standard-library module names for the interpreter running this code,
# keyed on its "major.minor" version. ImportCatcher files any import whose
# top-level name appears here under 'builtin' instead of as a dependency.
builtin_modules = stdlib_list('{0}.{1}'.format(*sys.version_info[:2]))
del sys
class ImportCatcher(ast.NodeVisitor):
    """Find all imports in an Abstract Syntax Tree (AST).

    Every imported module is reduced to its top-level name (``import a.b``
    and ``from a.b import c`` are both recorded as ``a``) before being
    filed into one of the sets below.

    Attributes
    ----------
    required_modules : set
        Non-standard-library modules imported outside of any try statement
    sketchy_modules : set
        Non-standard-library modules imported anywhere inside a try
        statement (its body, except handlers, else or finally clause)
    builtin_modules : set
        Imported modules whose top-level name is in the module-level
        `builtin_modules` standard-library list
    relative_modules : set
        First component of the module in relative imports such as
        ``from .foo import bar`` (``foo``); ``from . import bar`` is not
        recorded here
    imports : list
        All ast.Import nodes visited, in visit order
    import_froms : list
        All ast.ImportFrom nodes visited, in visit order
    trys : dict
        The try statements enclosing the node currently being visited,
        each mapped to itself; non-empty exactly while inside a try
    """

    # try-statement node classes that exist in the running interpreter:
    # ast.Try (Python 3), ast.TryStar (``try/except*``, Python 3.11+) and
    # ast.TryExcept / ast.TryFinally (Python 2). Resolved dynamically so the
    # visitor neither crashes on Python 2 nor misses ``except*`` blocks.
    _try_node_types = tuple(
        getattr(ast, _name)
        for _name in ('Try', 'TryStar', 'TryExcept', 'TryFinally')
        if hasattr(ast, _name)
    )

    def __init__(self):
        self.required_modules = set()
        self.sketchy_modules = set()
        self.builtin_modules = set()
        self.relative_modules = set()
        self.imports = []
        self.import_froms = []
        self.trys = {}

    def generic_visit(self, node):
        """Called if no explicit visitor function exists for a node.

        Overridden from the ast.NodeVisitor base class so that the visitor
        can track whether it is inside a try statement. Every child node
        (whether stored directly in a field or inside a list field) is
        visited through `_visit_child`, which registers try statements in
        the `trys` instance attribute for the duration of their recursion.
        """
        for _field, value in ast.iter_fields(node):
            if isinstance(value, list):
                for item in value:
                    # lists may also hold non-node values (e.g. the
                    # identifier strings of ast.Global), skip those
                    if isinstance(item, ast.AST):
                        self._visit_child(item)
            elif isinstance(value, ast.AST):
                self._visit_child(value)

    def _visit_child(self, child):
        """Visit `child`, marking it in `trys` while inside it if it is a
        try statement, so imports below it are treated as questionable."""
        if not isinstance(child, self._try_node_types):
            self.visit(child)
            return
        self.trys[child] = child
        try:
            self.visit(child)
        finally:
            # unwind even if a visitor raised, so `trys` never reports a
            # try block that the traversal has already left
            self.trys.pop(child, None)

    def visit_Import(self, node):
        """Executes when an ast.Import node is encountered.

        An ast.Import node is something like 'import bar'. Each distinct
        top-level module name is classified by `_add_import_node`.
        """
        self.imports.append(node)
        # 'import a.b.c' depends on the top-level package 'a'
        for mod in {alias.name.split('.')[0] for alias in node.names}:
            self._add_import_node(mod)

    def visit_ImportFrom(self, node):
        """Executes when an ast.ImportFrom node is encountered.

        An ast.ImportFrom node is something like 'from foo import bar'.
        Relative imports are recorded in `relative_modules` (or ignored
        when no module is named); absolute imports are classified by
        `_add_import_node`.
        """
        self.import_froms.append(node)
        if node.module is None:
            # this is a relative import like 'from . import bar'
            # so do nothing
            return
        node_name = node.module.split('.')[0]
        if node.level > 0:
            # this is a relative import like 'from .foo import bar'
            self.relative_modules.add(node_name)
            return
        # this is a non-relative import like 'from foo import bar'
        self._add_import_node(node_name)

    def _add_import_node(self, node_name):
        """File the top-level module `node_name` into exactly one of
        `builtin_modules`, `sketchy_modules` or `required_modules`."""
        # standard-library modules take precedence, even inside a try
        if node_name in builtin_modules:
            self.builtin_modules.add(node_name)
            return
        # any enclosing try statement makes the dependency questionable
        if self.trys:
            self.sketchy_modules.add(node_name)
            return
        # otherwise the import (from an ast.Import or ast.ImportFrom node)
        # is not guarded by a try statement, so it is a hard requirement
        self.required_modules.add(node_name)

    def describe(self):
        """Return the found imports.

        Returns
        -------
        dict :
            'required': The modules that were encountered outside of a
                try/except block
            'questionable': The modules that were encountered inside of a
                try/except block
            'relative': The modules that were imported via relative import
                syntax
            'builtin' : The modules that are part of the standard library
            Keys whose set is empty are omitted. The values are this
            catcher's own sets, not copies.
        """
        desc = {
            'required': self.required_modules,
            'relative': self.relative_modules,
            'questionable': self.sketchy_modules,
            'builtin': self.builtin_modules,
        }
        return {k: v for k, v in desc.items() if v}

    def __repr__(self):
        return 'ImportCatcher: %s' % repr(self.describe())
def get_imported_libs(code):
    """Parse a code snippet and collect the libraries it imports.

    Parameters
    ----------
    code : str
        The source code to parse and search for imports. Anything that
        ``ast.parse`` accepts (such as bytes) works too.

    Returns
    -------
    ImportCatcher
        The visitor after it has walked the parsed tree; call its
        ``describe()`` method for a dict of the imports found.

    Raises
    ------
    SyntaxError
        If `code` cannot be parsed by the running Python interpreter.

    Examples
    --------
    >>> import depfinder
    >>> depfinder.get_imported_libs('from foo import bar')
    ImportCatcher: {'required': {'foo'}}
    >>> with open('depfinder.py') as f:
    ...     code = f.read()
    >>> imports = depfinder.get_imported_libs(code)
    >>> print(imports.describe())  # doctest: +SKIP
    """
    tree = ast.parse(code)
    catcher = ImportCatcher()
    catcher.visit(tree)
    return catcher
def iterate_over_library(path_to_source_code):
    """Recurse into a directory tree and find the imports in each .py file.

    A '.' is printed (without a newline) for every .py file found, as a
    simple progress indicator.

    Parameters
    ----------
    path_to_source_code : str
        Root directory to walk with os.walk

    Yields
    ------
    tuple
        ``(module_name, full_path_to_module, ImportCatcher)`` where
        ``module_name`` is the file name without its ``.py`` suffix

    Raises
    ------
    SyntaxError
        Propagated from get_imported_libs if a .py file cannot be parsed
        by the running interpreter.
    """
    for parent, _folders, filenames in os.walk(path_to_source_code):
        for filename in filenames:
            if not filename.endswith('.py'):
                continue
            print('.', end='')
            full_file_path = os.path.join(parent, filename)
            # Read raw bytes and let ast.parse decode them: it honours a
            # PEP 263 coding cookie or UTF-8 BOM and otherwise defaults to
            # UTF-8, rather than depending on the locale's text encoding.
            with open(full_file_path, 'rb') as f:
                code = f.read()
            catcher = get_imported_libs(code)
            yield (filename[:-3], full_file_path, catcher)
def simple_import_search(path_to_source_code):
    """Return all imported modules in all .py files in `path_to_source_code`.

    Parameters
    ----------
    path_to_source_code : str
        Root directory to search recursively

    Returns
    -------
    dict
        Maps each category key of ImportCatcher.describe() ('required',
        'questionable', 'relative', 'builtin') to a sorted list of the
        module names in that category, merged across every file found.
        Categories with no modules are omitted.
    """
    found = defaultdict(set)
    for _mod_name, _path, catcher in iterate_over_library(path_to_source_code):
        for category, modules in catcher.describe().items():
            found[category].update(modules)
    return {category: sorted(modules)
            for category, modules in found.items() if modules}