def set_figtitle(figtitle, subtitle='', forcefignum=True, incanvas=True,
                 size=None, fontfamily=None, fontweight=None, fig=None):
    r"""
    Set a title on a matplotlib figure, both inside the canvas and on the
    backend window title bar.

    Args:
        figtitle (str): main title text (None is treated as '')
        subtitle (str): optional second line drawn under the title inside
            the canvas (default = '')
        forcefignum (bool): unused in this implementation; kept for
            backwards compatibility (default = True)
        incanvas (bool): if True, draw the title inside the canvas via
            ``fig.suptitle``; otherwise clear any existing suptitle
            (default = True)
        size (None | int): suptitle font size (default = None)
        fontfamily (None | str): suptitle font family (default = None)
        fontweight (None | str): suptitle font weight (default = None)
        fig (None | Figure): figure to modify; defaults to ``plt.gcf()``

    CommandLine:
        python -m .custom_figure set_figtitle --show

    Example:
        >>> # DISABLE_DOCTEST
        >>> autompl()
        >>> fig = figure(fnum=1, doclf=True)
        >>> result = set_figtitle(figtitle='figtitle', fig=fig)
        >>> # xdoc: +REQUIRES(--show)
        >>> show_if_requested()
    """
    from matplotlib import pyplot as plt
    if figtitle is None:
        figtitle = ''
    if fig is None:
        fig = plt.gcf()
    figtitle = ub.ensure_unicode(figtitle)
    subtitle = ub.ensure_unicode(subtitle)
    if incanvas:
        if subtitle != '':
            subtitle = '\n' + subtitle
        prop = {
            'family': fontfamily,
            'weight': fontweight,
            'size': size,
        }
        # Only forward font properties the caller actually specified
        prop = {k: v for k, v in prop.items() if v is not None}
        sup = fig.suptitle(figtitle + subtitle)
        if prop:
            fontproperties = sup.get_fontproperties().copy()
            for key, val in prop.items():
                getattr(fontproperties, 'set_' + key)(val)
            sup.set_fontproperties(fontproperties)
    else:
        fig.suptitle('')
    # Set title in the window (newlines are not allowed in window titles)
    window_figtitle = ('fig(%d) ' % fig.number) + figtitle
    window_figtitle = window_figtitle.replace('\n', ' ')
    # BUGFIX: FigureCanvasBase.set_window_title was deprecated in
    # matplotlib 3.4 and removed in 3.6; the method now lives on the
    # canvas manager. Fall back to the old location for older matplotlibs
    # and headless backends where manager is None.
    manager = getattr(fig.canvas, 'manager', None)
    if manager is not None:
        manager.set_window_title(window_figtitle)
    else:
        fig.canvas.set_window_title(window_figtitle)
def parse(ImportVisitor, source=None, modpath=None, modname=None, module=None):
    """
    Construct an ImportVisitor from a live module, a module path, or raw
    source text, deriving whichever of the inputs are missing.

    Args:
        ImportVisitor (type): the visitor class being constructed (bound
            as the first argument, classmethod-style)
        source (str | None): python source text
        modpath (str | None): path to a module file or package directory
        modname (str | None): dotted module name
        module (ModuleType | None): an already-imported module object

    Returns:
        ImportVisitor: a visitor that has already walked the parse tree

    Raises:
        NotAPythonFile: if modpath does not point at python source
        ValueError: if no source code can be derived from the inputs
    """
    if module is not None:
        if source is None:
            source = inspect.getsource(module)
        # BUGFIX: the two conditions and assignments below were swapped
        # (``modname`` was set from ``module.__name__`` when *modpath* was
        # missing, and vice versa), so neither value was filled correctly.
        if modpath is None:
            modpath = module.__file__
        if modname is None:
            modname = module.__name__
    if modpath is not None:
        if isdir(modpath):
            # Packages are parsed via their __init__ file
            modpath = join(modpath, '__init__.py')
        if modname is None:
            modname = ub.modpath_to_modname(modpath)
    if modpath is not None:
        if source is None:
            if not modpath.endswith(('.py', '>')):
                raise NotAPythonFile(
                    'can only parse python files, not {}'.format(modpath))
            # Use a context manager so the file handle is not leaked
            with open(modpath, 'r') as file:
                source = file.read()
    if source is None:
        raise ValueError('unable to derive source code')
    source = ub.ensure_unicode(source)
    pt = ast.parse(source)
    visitor = ImportVisitor(modpath, modname, module, pt=pt)
    visitor.visit(pt)
    return visitor
def preprocess_research(input_str):
    """
    Normalize LaTeX-flavored research text for text-to-speech.

    Emphasis markup is unwrapped, dashes become SSML pauses, citations are
    dropped, common abbreviations are spelled out, and the document is
    split so each sentence ends a line.

    test of an em --- dash
    test of an em — dash

    Args:
        input_str (str): LaTeX-flavored input text

    Returns:
        str: cleaned, speakable text (one sentence per line)
    """
    import utool as ut
    import ubelt as ub
    inside = ut.named_field('ref', '.*?')
    # Unwrap \emph{...} to its contents
    input_str = re.sub(r'\\emph{' + inside + '}', ut.bref_field('ref'),
                       input_str)
    input_str = ub.ensure_unicode(input_str)
    pause = re.escape(' <break time="300ms"/> ')
    emdash = u'\u2014'
    # BUGFIX: regex patterns are raw strings now; '\s' inside a plain
    # string literal is an invalid escape sequence (SyntaxWarning today,
    # SyntaxError in a future Python). Patterns are unchanged otherwise.
    input_str = re.sub(r'\s?' + re.escape('---') + r'\s?', pause, input_str)
    input_str = re.sub(r'\s?' + emdash + r'\s?', pause, input_str)
    # Drop citations entirely
    input_str = re.sub(r'\\cite{[^}]*}', '', input_str)
    input_str = re.sub('et al.', 'et all', input_str)  # Let rob say et al.
    input_str = re.sub(r' i\.e\.', ' i e ' + pause, input_str)
    input_str = re.sub(r'\\r', '', input_str)
    # Split the document at periods
    # NOTE(review): '0-1' looks like it was meant to be '0-9' — original
    # behavior kept; confirm before changing.
    input_str = re.sub(r'\.[^a-zA-Z0-1]+', '.\n', input_str)
    input_str = re.sub('\r\n', '\n', input_str)
    # Collapse whitespace-only and repeated blank lines
    input_str = re.sub('^ *$\n', '', input_str)
    input_str = re.sub('\n\n*', '\n', input_str)
    return input_str
def find_pyfunc_above_row(line_list, row, orclass=False):
    """
    Search upward from ``row`` for the function (or method) definition that
    encloses it.

    originally part of the vim plugin

    Args:
        line_list (list): buffer lines to search
        row (int): index into line_list to start scanning upward from
        orclass (bool): if True, a ``class`` definition may also be
            returned as the found name

    Returns:
        tuple: (funcname, searchlines, func_pos, foundline) — funcname is
            the callable name (dotted with its class for methods), or None
            if nothing was found within 200 lines.

    CommandLine:
        python ~/local/vim/rc/pyvim_funcs.py find_pyfunc_above_row

    Example:
        >>> import ubelt as ub
        >>> import six
        >>> func = find_pyfunc_above_row
        >>> fpath = six.get_function_globals(func)['__file__'].replace('.pyc', '.py')
        >>> line_list = ub.readfrom(fpath, aslines=True)
        >>> row = six.get_function_code(func).co_firstlineno + 1
        >>> funcname, searchlines, func_pos, foundline = find_pyfunc_above_row(line_list, row)
        >>> print(funcname)
        find_pyfunc_above_row
    """
    import ubelt as ub
    searchlines = []  # for debugging
    funcname = None
    # Janky way to find function name
    func_sentinal = 'def '
    method_sentinal = '    def '
    class_sentinal = 'class '
    # Scan at most 200 lines upward.
    # NOTE(review): func_pos = row - ix can go negative, which would wrap
    # around to the end of the buffer via negative indexing — presumably
    # never hit for realistic buffers, but verify for short files.
    for ix in range(200):
        func_pos = row - ix
        searchline = line_list[func_pos]
        searchline = ub.ensure_unicode(searchline)
        cleanline = searchline.strip(' ')
        searchlines.append(cleanline)
        if searchline.startswith(func_sentinal):  # and cleanline.endswith(':'):
            # Found a valid function name
            funcname = parse_callname(searchline, func_sentinal)
            if funcname is not None:
                break
        if orclass and searchline.startswith(class_sentinal):
            # Found a valid class name (as funcname)
            funcname = parse_callname(searchline, class_sentinal)
            if funcname is not None:
                break
        if searchline.startswith(method_sentinal):  # and cleanline.endswith(':'):
            # Found a valid function name
            funcname = parse_callname(searchline, method_sentinal)
            if funcname is not None:
                # Qualify a method with its enclosing class name
                classline, classpos = find_pyclass_above_row(line_list, func_pos)
                classname = parse_callname(classline, class_sentinal)
                if classname is not None:
                    funcname = '.'.join([classname, funcname])
                    break
                else:
                    funcname = None
    foundline = searchline
    return funcname, searchlines, func_pos, foundline
def is_paragraph_end(line_):
    """
    Check if a (LaTeX-flavored) text line terminates a paragraph.

    A blank line, or a line beginning with one of a few hard-coded
    paragraph markers, counts as a paragraph end.

    Args:
        line_ (str): a single line of text

    Returns:
        bool: True if the line ends a paragraph
    """
    # Hack: the marker list should really be an argument
    import ubelt as ub
    stripped = ub.ensure_unicode(line_.strip())
    if not stripped:
        return True
    markers = (
        # '\\noindent',
        '\\begin{equation}',
        '\\end{equation}',
        '% ---',
    )
    # startswith accepts a tuple of alternatives
    return stripped.startswith(markers)
def parse(ImportVisitor, source=None, modpath=None, modname=None, module=None):
    """
    Construct an ImportVisitor from a live module, a module path, or raw
    source text, deriving whichever of the inputs are missing.

    Args:
        ImportVisitor (type): the visitor class being constructed (bound
            as the first argument, classmethod-style)
        source (str | None): python source text
        modpath (str | None): path to a module file or package directory
        modname (str | None): dotted module name
        module (ModuleType | None): an already-imported module object

    Returns:
        ImportVisitor: a visitor that has already walked the parse tree

    Raises:
        NotAPythonFile: if modpath does not point at python source
        ValueError: if no source code can be derived from the inputs
    """
    if module is not None:
        if source is None:
            source = inspect.getsource(module)
        if modpath is None:
            # BUGFIX: previously this assigned ``modname = module.__file__``,
            # which left modpath unset and put a file path in modname.
            modpath = module.__file__
        if modname is None:
            modname = module.__name__
    if modpath is not None:
        if modpath.endswith('.pyc'):
            modpath = modpath.replace('.pyc', '.py')  # python 2 hack
        if isdir(modpath):
            # Packages are parsed via their __init__ file
            modpath = join(modpath, '__init__.py')
        if modname is None:
            modname = ub.modpath_to_modname(modpath)
    if modpath is not None:
        if source is None:
            if not modpath.endswith(('.py', '>')):
                raise NotAPythonFile(
                    'can only parse python files, not {}'.format(modpath))
            # Use a context manager so the file handle is not leaked
            with open(modpath, 'r') as file:
                source = file.read()
    if source is None:
        raise ValueError('unable to derive source code')
    source = ub.ensure_unicode(source)
    if six.PY2:
        try:
            pt = ast.parse(source)
        except SyntaxError as ex:
            # py2 ast cannot parse unicode text containing an encoding
            # declaration; retry with encoded bytes
            if 'encoding declaration in Unicode string' in ex.args[0]:
                pt = ast.parse(source.encode())
            else:
                raise
    else:
        pt = ast.parse(source)
    visitor = ImportVisitor(modpath, modname, module, pt=pt)
    visitor.visit(pt)
    return visitor
def hash_code(sourcecode):
    r"""
    Hash source code text while normalizing incidental formatting.

    Comments, docstrings, and whitespace are stripped before hashing, so
    very minor cosmetic edits do not change the hash.

    Args:
        source (str): uft8 text of source code

    Returns:
        str: hashid: 128 character (512 byte) hash of the normalized input

    Notes:
        The return value of this function is based on the AST parse tree,
        which might change between different version of Python. However,
        within the same version of Python, the results should be
        consistent.

    CommandLine:
        xdoctest -m /home/joncrall/code/torch_liberator/torch_liberator/export/exporter.py hash_code

    Example:
        >>> hashid1 = (hash_code('x = 1')[0:8])
        >>> hashid2 = (hash_code('x=1 # comments and spaces dont matter')[0:8])
        >>> hashid3 = (hash_code('\nx=1')[0:8])
        >>> assert ub.allsame([hashid1, hashid2, hashid3])
        >>> hashid4 = hash_code('x=2')[0:8]
        >>> assert hashid1 != hashid4
    """
    # Normalize: drop comments and docstrings before parsing
    text = ub.ensure_unicode(sourcecode)
    normalized = remove_comments_and_docstrings(text)
    # Also remove pytorch_export version info (not sure if correct?)
    normalized = re.sub('__pt_export_version__ = .*', '', normalized)
    # The AST dump is insensitive to most surface-level formatting changes
    tree = ast.parse(normalized)
    hasher = hashlib.sha512(ast.dump(tree).encode('utf8'))
    return hasher.hexdigest()
def text_between_lines(lnum1, lnum2, col1=0, col2=sys.maxsize - 1):
    """
    Extract a column-bounded block of text from the current vim buffer.

    Args:
        lnum1 (int): first line number (1-based, inclusive)
        lnum2 (int): last line number (1-based, inclusive)
        col1 (int): first column to keep (0-based, inclusive)
        col2 (int): last column to keep (0-based, inclusive; defaults to
            effectively unbounded)

    Returns:
        str: the selected lines, clipped to the column window and joined
            with newlines
    """
    import vim
    # vim buffers are 0-indexed while line numbers are 1-indexed
    raw_lines = vim.current.buffer[lnum1 - 1:lnum2]
    decoded = [ub.ensure_unicode(line) for line in raw_lines]
    try:
        # Every selected line is clipped to the same column window
        # (blockwise-visual style); an empty selection joins to ''
        clipped = [line[col1:col2 + 1] for line in decoded]
        text = '\n'.join(clipped)
    except Exception:
        print(ub.repr2(decoded))
        raise
    return text
def hash_code(sourcecode):
    r"""
    Hash source code text, normalizing away whitespace and comments so
    very minor changes do not change the hash.

    Args:
        source (str): uft8 text of source code

    Returns:
        str: hashid: 128 character (512 byte) hash of the normalized input

    Example:
        >>> print(hash_code('x = 1')[0:8])
        93d321be
        >>> print(hash_code('x=1 # comments and spaces dont matter')[0:8])
        93d321be
        >>> print(hash_code('\nx=1')[0:8])
        93d321be
        >>> print(hash_code('x=2')[0:8])
        6949c223
    """
    # Normalize: drop comments and docstrings before building a parse tree
    text = ub.ensure_unicode(sourcecode)
    normalized = remove_comments_and_docstrings(text)
    # Also remove pytorch_export version info (not sure if correct?)
    normalized = re.sub('__pt_export_version__ = .*', '', normalized)
    # Hashing the AST dump normalizes over many possible small changes
    ast_dump = ast.dump(ast.parse(normalized))
    digest = hashlib.sha512(ast_dump.encode('utf8'))
    return digest.hexdigest()
def difftext(text1, text2, context_lines=0, ignore_whitespace=False,
             colored=False):
    r"""
    Uses difflib to return a difference string between two similar texts

    Args:
        text1 (str): old text
        text2 (str): new text
        context_lines (int): number of lines of unchanged context kept
            around each changed line; None keeps the entire diff
        ignore_whitespace (bool): if True, trailing whitespace is stripped
            and difflib's junk heuristics are enabled
        colored (bool): if true highlight the diff

    Returns:
        str: formatted difference text message

    References:
        http://www.java2s.com/Code/Python/Utility/IntelligentdiffbetweentextfilesTimPeters.htm

    Example:
        >>> # build test data
        >>> text1 = 'one\ntwo\nthree'
        >>> text2 = 'one\ntwo\nfive'
        >>> # execute function
        >>> result = difftext(text1, text2)
        >>> # verify results
        >>> print(result)
        - three
        + five

    Example:
        >>> # build test data
        >>> text1 = 'one\ntwo\nthree\n3.1\n3.14\n3.1415\npi\n3.4\n3.5\n4'
        >>> text2 = 'one\ntwo\nfive\n3.1\n3.14\n3.1415\npi\n3.4\n4'
        >>> # execute function
        >>> context_lines = 1
        >>> result = difftext(text1, text2, context_lines, colored=True)
        >>> # verify results
        >>> print(result)
    """
    import ubelt as ub
    import difflib
    text1 = ub.ensure_unicode(text1)
    text2 = ub.ensure_unicode(text2)
    text1_lines = text1.splitlines()
    text2_lines = text2.splitlines()
    if ignore_whitespace:
        text1_lines = [t.rstrip() for t in text1_lines]
        text2_lines = [t.rstrip() for t in text2_lines]
        ndiff_kw = dict(linejunk=difflib.IS_LINE_JUNK,
                        charjunk=difflib.IS_CHARACTER_JUNK)
    else:
        ndiff_kw = {}
    all_diff_lines = list(difflib.ndiff(text1_lines, text2_lines, **ndiff_kw))
    if context_lines is None:
        diff_lines = all_diff_lines
    else:
        # boolean for every line if it is marked as part of the diff or not
        ismarked_list = [
            len(line) > 0 and line[0] in '+-?'
            for line in all_diff_lines
        ]
        # flag lines that are within context_lines away from a diff line
        isvalid_list = ismarked_list[:]
        for i in range(1, context_lines + 1):
            isvalid_list[:-i] = list(
                map(any, zip(isvalid_list[:-i], ismarked_list[i:])))
            isvalid_list[i:] = list(
                map(any, zip(isvalid_list[i:], ismarked_list[:-i])))
        # CLEANUP: removed a dead "visual break" branch here — it was gated
        # behind a hard-coded ``if False`` and never emitted anything, so
        # this compress is behaviorally identical to the old loop.
        diff_lines = list(ub.compress(all_diff_lines, isvalid_list))
    text = '\n'.join(diff_lines)
    if colored:
        text = ub.highlight_code(text, lexer_name='diff')
    return text
def parse_import_names(sourcecode, top_level=True, fpath=None, branch=False):
    """
    Finds the names imported by a source file without importing it.

    Args:
        sourcecode (str): python source text
        top_level (bool): if True, imports inside functions and classes
            are ignored
        fpath (str | None): path of the source file; used to resolve
            relative (``from . import x``) imports to absolute prefixes
        branch (bool): controls traversal of ``if`` blocks
            NOTE(review): when branch=True the children of If nodes are
            never visited, so conditional imports are skipped entirely —
            confirm this inversion is intended.

    Returns:
        tuple: (import_names, modules) — locally bound names and the
            dotted module names they came from

    References:
        https://stackoverflow.com/questions/20445733/how-to-tell-which-modules-have-been-imported-in-some-source-code

    Example:
        >>> from vimtk import pyinspect
        >>> fpath = pyinspect.__file__.replace('.pyc', '.py')
        >>> sourcecode = ub.readfrom(fpath)
        >>> func_names = parse_import_names(sourcecode)
        >>> result = ('func_names = %s' % (ub.repr2(func_names),))
        >>> print(result)
    """
    import_names = []
    if six.PY2:
        # py2 ast.parse wants encoded bytes
        sourcecode = ub.ensure_unicode(sourcecode)
        encoded = sourcecode.encode('utf8')
        pt = ast.parse(encoded)
    else:
        pt = ast.parse(sourcecode)
    modules = []

    class ImportVisitor(ast.NodeVisitor):

        def _parse_alias_list(self, aliases):
            # Record the name each import binds locally (asname wins;
            # dotted names without an alias bind only the top package)
            for alias in aliases:
                if alias.asname is not None:
                    import_names.append(alias.asname)
                else:
                    if '.' not in alias.name:
                        import_names.append(alias.name)

        def visit_Import(self, node):
            self._parse_alias_list(node.names)
            self.generic_visit(node)
            for alias in node.names:
                modules.append(alias.name)

        def visit_ImportFrom(self, node):
            self._parse_alias_list(node.names)
            self.generic_visit(node)
            # One module entry is appended per imported alias
            for alias in node.names:
                prefix = ''
                if node.level:
                    # Relative import: resolve the leading dots
                    if fpath is not None:
                        # Derive the absolute package prefix from fpath
                        modparts = ub.split_modpath(
                            os.path.abspath(fpath))[1].replace(
                                '\\', '/').split('/')
                        parts = modparts[:-node.level]
                        prefix = '.'.join(parts) + '.'
                    else:
                        prefix = '.' * node.level
                modules.append(prefix + node.module)

        def visit_FunctionDef(self, node):
            # Ignore modules imported in functions
            if not top_level:
                self.generic_visit(node)

        def visit_ClassDef(self, node):
            if not top_level:
                self.generic_visit(node)

        def visit_If(self, node):
            if not branch:
                # TODO: determine how to figure out if a name is in all branches
                if not _node_is_main_if(node):
                    # Ignore the main statement
                    self.generic_visit(node)
    try:
        ImportVisitor().visit(pt)
    except Exception:
        # Best effort: a malformed tree yields whatever was found so far
        pass
    return import_names, modules
def parse_function_names(sourcecode, top_level=True, ignore_condition=1):
    """
    Finds all function names in a file without importing it

    Args:
        sourcecode (str): python source text
        top_level (bool): if True, functions nested inside other functions
            or classes are not reported
        ignore_condition (int): if truthy, skip functions defined inside
            any ``if`` block entirely (default = 1)

    Returns:
        list: func_names

    Example:
        >>> from vimtk import pyinspect
        >>> fpath = pyinspect.__file__.replace('.pyc', '.py')
        >>> sourcecode = ub.readfrom(fpath)
        >>> func_names = parse_function_names(sourcecode)
        >>> result = ('func_names = %s' % (ub.repr2(func_names),))
        >>> print(result)
    """
    func_names = []
    if sys.version_info[0] == 2:
        # py2 ast.parse wants encoded bytes
        sourcecode = ub.ensure_unicode(sourcecode)
        pt = ast.parse(sourcecode.encode('utf8'))
    else:
        pt = ast.parse(sourcecode)

    # CLEANUP: removed a large amount of dead, commented-out machinery for
    # tracking names across elif/else chains; it was never active, and the
    # trailing ``try/except: raise`` around the visit was a no-op.
    class FuncVisitor(ast.NodeVisitor):

        def visit_If(self, node):
            if ignore_condition:
                # Skip anything defined under a conditional entirely
                return
            if _node_is_main_if(node):
                # Never descend into ``if __name__ == '__main__'`` blocks
                return
            # TODO: a name is only safe if defined in all branches
            ast.NodeVisitor.generic_visit(self, node)

        def visit_FunctionDef(self, node):
            func_names.append(node.name)
            if not top_level:
                ast.NodeVisitor.generic_visit(self, node)

        def visit_ClassDef(self, node):
            if not top_level:
                ast.NodeVisitor.generic_visit(self, node)

    FuncVisitor().visit(pt)
    return func_names
def remove_comments_and_docstrings(source):
    r"""
    Strip comments and docstrings from python source text.

    Args:
        source (str): uft8 text of source code

    Returns:
        str: out: the source with comments and docstrings removed.

    References:
        https://stackoverflow.com/questions/1769332/remove-comments-docstrings

    Example:
        >>> source = ub.codeblock(
            '''
            def foo():
                'The spaces before this docstring are tokenize.INDENT'
                test = [
                    'The spaces before this string do not get a token'
                ]
            ''')
        >>> out = remove_comments_and_docstrings(source)
        >>> want = ub.codeblock(
            '''
            def foo():

                test = [
                    'The spaces before this string do not get a token'
                ]''').splitlines()
        >>> got = [o.rstrip() for o in out.splitlines()]
        >>> assert got == want
    """
    source = ub.ensure_unicode(source)
    readline = io.StringIO(source).readline
    parts = []
    prev_type = tokenize.INDENT
    prev_end_row = -1
    prev_end_col = 0
    for tok_type, tok_text, (row0, col0), (row1, col1), _ in \
            tokenize.generate_tokens(readline):
        # Re-pad the gap between tokens by hand. tokenize.untokenize is
        # avoided because it scatters whitespace in odd places, so the
        # original spacing is reconstructed from token coordinates here.
        if row0 > prev_end_row:
            prev_end_col = 0
        if col0 > prev_end_col:
            parts.append(' ' * (col0 - prev_end_col))
        if tok_type == tokenize.COMMENT:
            # Comments are dropped entirely
            pass
        elif tok_type == tokenize.STRING:
            # A string right after INDENT is a docstring; one right after
            # NEWLINE at column 0 is a module-level docstring. (NEWLINE
            # ends a statement; NL is a newline inside brackets, so
            # strings inside operators are safe.) Everything else is a
            # real string expression and is kept.
            if (prev_type != tokenize.INDENT and
                    prev_type != tokenize.NEWLINE and col0 > 0):
                parts.append(tok_text)
        else:
            parts.append(tok_text)
        prev_type = tok_type
        prev_end_col = col1
        prev_end_row = row1
    return ''.join(parts)
def none_or_unicode(text):
    """Coerce ``text`` to unicode, passing None through unchanged."""
    if text is None:
        return None
    return ub.ensure_unicode(text)
def sedfile(fpath, regexpr, repl, dry=False, verbose=1):
    r"""
    Execute a search and replace on a particular file

    Args:
        fpath (str | PathLike): file to modify
        regexpr (str | Pattern): pattern whose matches will be replaced
        repl (str): replacement text
        dry (bool): if True, report what would change without writing
            (default = False)
        verbose (int): verbosity level (default = 1)

    Returns:
        list: ``(new_line, old_line)`` pairs for every line that changed;
            an empty list if nothing changed

    TODO:
        - [ ] Store "SedResult" class, with lazy execution

    Example:
        >>> from xdev.search_replace import *  # NOQA
        >>> from xdev.search_replace import _create_test_filesystem
        >>> fpath = _create_test_filesystem()['contents'][1]
        >>> changed_lines1 = sedfile(fpath, 'a', 'x', dry=True, verbose=1)
        >>> changed_lines2 = sedfile(fpath, 'a', 'x', dry=False, verbose=0)
        >>> assert changed_lines2 == changed_lines1
        >>> changed_lines3 = sedfile(fpath, 'a', 'x', dry=False, verbose=0)
        >>> assert changed_lines3 != changed_lines2
    """
    import xdev
    # CLEANUP: removed an unused ``path, name = split(fpath)`` pair, a
    # pointless pre-initialization of new_file_lines, and a no-op
    # ``except Exception: raise`` clause; behavior is unchanged.
    mode_text = '(dry-run)' if dry else '(real-run)'
    pattern = Pattern.coerce(regexpr, hint='regex')
    try:
        with open(fpath, 'r') as file:
            file_lines = file.readlines()
        # Search each line for the desired regexpr
        new_file_lines = [pattern.sub(repl, line) for line in file_lines]
    except UnicodeDecodeError as ex:
        # Add the file name into the exception
        new_last_arg = ex.args[-1] + ' in fpath={!r}'.format(fpath)
        new_args = ex.args[:-1] + (new_last_arg, )
        raise UnicodeDecodeError(*new_args) from ex

    changed_lines = [(newline, line)
                     for newline, line in zip(new_file_lines, file_lines)
                     if newline != line]
    nChanged = len(changed_lines)
    if nChanged == 0:
        return []

    try:
        rel_fpath = relpath(fpath, os.getcwd())
    except ValueError:
        # windows issues: relpath fails across drive letters
        rel_fpath = abspath(fpath)
    if verbose:
        print(' * {} changed {} lines in {!r} '.format(
            mode_text, nChanged, rel_fpath))
        print(' * --------------------')
    new_file = ''.join(new_file_lines)
    old_file = ub.ensure_unicode(''.join(
        list(map(ub.ensure_unicode, file_lines))))
    if verbose:
        print(xdev.difftext(old_file, new_file, colored=True))
    if not dry:
        if verbose:
            print(' ! WRITING CHANGES')
        with open(fpath, 'w') as file:
            file.write(new_file)
    return changed_lines
def 文本_转unicode(str):
    """Convert text to unicode (thin wrapper around ``ub.ensure_unicode``)."""
    # NOTE(review): the parameter name shadows the builtin ``str``; renaming
    # it would be cleaner but could break keyword callers, so it is kept.
    return ub.ensure_unicode(str)
def get_func_sourcecode(func, strip_def=False, strip_ret=False,
                        strip_docstr=False, strip_comments=False,
                        remove_linenums=None, strip_decor=False):
    """
    wrapper around inspect.getsource but takes into account utool decorators
    strip flags are very hacky as of now

    Args:
        func (function): live function object to extract source for
        strip_def (bool): if True, remove the ``def`` line (and any
            decorators) via a regex hack
        strip_ret (bool): comment out ``return`` statements (default = False)
        strip_docstr (bool): remove docstrings (default = False)
        strip_comments (bool): remove comments (default = False)
        remove_linenums (None | list): indices of lines to delete from the
            result (default = None)
        strip_decor (bool): remove decorators using redbaron

    Example:
        >>> # build test data
        >>> func = get_func_sourcecode
        >>> strip_def = True
        >>> strip_ret = True
        >>> sourcecode = get_func_sourcecode(func, strip_def)
        >>> print('sourcecode = {}'.format(sourcecode))
    """
    inspect.linecache.clearcache()  # HACK: fix inspect bug
    sourcefile = inspect.getsourcefile(func)
    if hasattr(func, '_utinfo'):
        # DEPRICATE
        # utool-decorated functions store the wrapped original here
        func2 = func._utinfo['orig_func']
        sourcecode = get_func_sourcecode(func2)
    elif sourcefile is not None and (sourcefile != '<string>'):
        # Retry once; clearing linecache can fix stale-source failures
        try_limit = 2
        for num_tries in range(try_limit):
            try:
                sourcecode = inspect.getsource(func)
                if not isinstance(sourcecode, six.text_type):
                    sourcecode = sourcecode.decode('utf-8')
            except (IndexError, OSError, SyntaxError):
                print('WARNING: Error getting source')
                inspect.linecache.clearcache()
                if num_tries + 1 != try_limit:
                    tries_left = try_limit - num_tries - 1
                    print('Attempting %d more time(s)' % (tries_left))
                else:
                    raise
    else:
        # No retrievable source (e.g. dynamically created function)
        sourcecode = None
    if strip_def:
        # hacky
        # TODO: use redbaron or something like that for a more robust appraoch
        sourcecode = textwrap.dedent(sourcecode)
        regex_decor = '^@.' + REGEX_NONGREEDY
        regex_defline = '^def [^:]*\\):\n'
        patern = '(' + regex_decor + ')?' + regex_defline
        RE_FLAGS = re.MULTILINE | re.DOTALL
        RE_KWARGS = {'flags': RE_FLAGS}
        nodef_source = re.sub(patern, '', sourcecode, **RE_KWARGS)
        sourcecode = textwrap.dedent(nodef_source)
        pass
    if strip_ret:
        r"""
        \s is a whitespace char
        """
        # Replace each return with ``pass #`` so indentation stays valid
        return_ = named_field('return', 'return .*$')
        prereturn = named_field('prereturn', r'^\s*')
        return_bref = bref_field('return')
        prereturn_bref = bref_field('prereturn')
        regex = prereturn + return_
        repl = prereturn_bref + 'pass # ' + return_bref
        sourcecode_ = re.sub(regex, repl, sourcecode, flags=re.MULTILINE)
        sourcecode = sourcecode_
        pass
    if strip_docstr or strip_comments:
        # pip install pyminifier
        # References: http://code.activestate.com/recipes/576704/
        def remove_docstrings_or_comments(source):
            """
            TODO: commit clean version to pyminifier
            """
            import tokenize
            from six.moves import StringIO
            io_obj = StringIO(source)
            out = ''
            prev_toktype = tokenize.INDENT
            last_lineno = -1
            last_col = 0
            for tok in tokenize.generate_tokens(io_obj.readline):
                token_type = tok[0]
                token_string = tok[1]
                start_line, start_col = tok[2]
                end_line, end_col = tok[3]
                # Re-pad the gap between tokens to preserve spacing
                if start_line > last_lineno:
                    last_col = 0
                if start_col > last_col:
                    out += (' ' * (start_col - last_col))
                # Remove comments:
                if strip_comments and token_type == tokenize.COMMENT:
                    pass
                elif strip_docstr and token_type == tokenize.STRING:
                    if prev_toktype != tokenize.INDENT:
                        # This is likely a docstring; double-check we're not inside an operator:
                        if prev_toktype != tokenize.NEWLINE:
                            if start_col > 0:
                                out += token_string
                else:
                    out += token_string
                prev_toktype = token_type
                last_col = end_col
                last_lineno = end_line
            return out
        sourcecode = remove_docstrings_or_comments(sourcecode)
    if strip_decor:
        try:
            import redbaron
            red = redbaron.RedBaron(ub.codeblock(sourcecode))
        except Exception:
            # Fall back to an ascii-mangled copy if redbaron chokes on
            # non-ascii text
            hack_text = ub.ensure_unicode(ub.codeblock(sourcecode)).encode(
                'ascii', 'replace')
            red = redbaron.RedBaron(hack_text)
            pass
        if len(red) == 1:
            redfunc = red[0]
            if redfunc.type == 'def':
                # Remove decorators
                del redfunc.decorators[:]
                sourcecode = redfunc.dumps()
    if remove_linenums is not None:
        # Delete specific lines from the final result
        source_lines = sourcecode.strip('\n').split('\n')
        delete_items_by_index(source_lines, remove_linenums)
        sourcecode = '\n'.join(source_lines)
    return sourcecode
def source_closure(model_class):
    """
    Hacky way to pull just the minimum amount of code needed to define a
    model_class.

    Iteratively resolves undefined names one at a time (imports,
    assignments, and same-module functions/classes) until the generated
    module has no undefined names left.

    Args:
        model_class (type): class used to define the model_class

    Returns:
        str: closed_sourcecode: text defining a new python module.

    Example:
        >>> from torchvision import models
        >>> model_class = models.AlexNet
        >>> text = source_closure(model_class)
        >>> assert not undefined_names(text)
        >>> print(hash_code(text))
        18a043fc0563bcf8f97b2ee76d...
        >>> model_class = models.DenseNet
        >>> text = source_closure(model_class)
        >>> assert not undefined_names(text)
        >>> print(hash_code(text))
        d52175ef0d52ec5ca155bdb1037...
        >>> model_class = models.resnet50
        >>> text = source_closure(model_class)
        >>> assert not undefined_names(text)
        >>> print(hash_code(text))
        ad683af44142b58c85b6c2314...
        >>> model_class = models.Inception3
        >>> text = source_closure(model_class)
        >>> assert not undefined_names(text)
        >>> print(hash_code(text))
        bd7c67c37e292ffad6beb8532324d3...
    """
    module_name = model_class.__module__
    module = sys.modules[module_name]
    sourcecode = inspect.getsource(model_class)
    sourcecode = ub.ensure_unicode(sourcecode)
    names = undefined_names(sourcecode)
    module_source = inspect.getsource(module)
    module_source = ub.ensure_unicode(module_source)
    pt = ast.parse(module_source)
    # Collect the defining module's imports and assignments for lookup
    visitor = ImportVisitor(module.__file__)
    try:
        visitor.visit(pt)
    except Exception:
        # Best effort: use whatever the visitor managed to collect
        pass

    def closure_(obj, name):
        # Resolve one undefined name to either an import line or a chunk
        # of source code defined in the same module.
        # TODO: handle assignments
        if name in visitor.import_lines:
            # Check and see if the name was imported from elsewhere
            return 'import', visitor.import_lines[name]
        elif name in visitor.assignments:
            type_, value = visitor.assignments[name]
            if type_ == 'node':
                # TODO, need to handle non-simple expressions
                return type_, '{} = {}'.format(name, value.value.id)
            else:
                # when value is a dict we need to be sure it is
                # extracted in the same order as we see it
                return type_, '{} = {}'.format(name, ub.repr2(value))
        elif isinstance(obj, types.FunctionType):
            if obj.__module__ == module_name:
                sourcecode = inspect.getsource(obj)
                return 'code', sourcecode
        elif isinstance(obj, type):
            if obj.__module__ == module_name:
                sourcecode = inspect.getsource(obj)
                return 'code', sourcecode
        raise NotImplementedError(str(obj) + ' ' + str(name))

    import_lines = []
    lines = [sourcecode]
    # Fixpoint loop: resolve exactly one undefined name per iteration (the
    # inner ``break``), rebuild the module text, and re-scan it until no
    # undefined names remain.
    # NOTE(review): if ``names`` starts empty, closed_sourcecode is never
    # assigned and the final return raises NameError — presumably the
    # class source always references at least one external name; verify.
    while names:
        # Make sure we process names in the same order for hashability
        names = sorted(set(names))
        for name in names:
            obj = getattr(module, name)
            type_, text = closure_(obj, name)
            if type_ == 'import':
                import_lines.append(text)
            else:
                lines.append(text)
            if text is None:
                raise NotImplementedError(str(obj) + ' ' + str(name))
            break
        import_lines = sorted(import_lines)
        # Dependencies are prepended (lines reversed) so definitions come
        # before their first use in the generated module
        closed_sourcecode = ('\n'.join(import_lines) + '\n\n\n' +
                             '\n\n'.join(lines[::-1]))
        names = sorted(undefined_names(closed_sourcecode))
    return closed_sourcecode