def split_matplotlib_cells(nb):
    """
    Move everything after the first matplotlib import of a code cell into
    a new code cell inserted directly after it.

    Keeping the import in a cell of its own works around a known bug in
    the Jupyter backend: a plot created in the same cell that imports
    matplotlib for the first time is shown as a string instead of a canvas
    (https://github.com/jupyter/notebook/issues/3523).

    The notebook is modified in place: cells may be inserted into
    ``nb['cells']`` and the ``'source'`` of split cells is rewritten.
    Only cells whose ``'cell_type'`` is ``'code'`` and whose source
    mentions ``matplotlib`` are inspected. Nothing is returned.
    """
    cells = nb['cells']
    # Visit cells from last to first. Inserting after index ``idx`` never
    # shifts the positions of cells that have not been visited yet.
    for idx in reversed(range(len(cells))):
        cell = cells[idx]
        if cell['cell_type'] != 'code' or 'matplotlib' not in cell['source']:
            continue
        # NOTE(review): presumably hides IPython magics so that the text can
        # be handled by ``ast.parse``; it is undone by
        # ``deprotect_ipython_magics`` before the source is written back.
        protected = iw.protect_ipython_magics(cell['source'])
        # statement start line -> statement end line (1-based)
        stmt_end = iw.delimit_statements(protected)
        tree = ast.parse(protected)
        finder = iw.GetMatplotlibPyplot()
        finder.visit(tree)
        first = finder.matplotlib_first
        if not first:
            continue
        src_lines = iw.deprotect_ipython_magics(protected).split('\n')
        # cut right after the last line of the statement holding the import
        cut = stmt_end[first]
        tail = '\n'.join(src_lines[cut:]).lstrip('\n')
        if tail:
            cells.insert(idx + 1, nbformat.v4.new_code_cell(source=tail))
        cell['source'] = '\n'.join(src_lines[:cut]).rstrip('\n')
def test_delimit_statements(self):
    # ``iw.delimit_statements`` should map the first line of every statement
    # to its last line. The snippet covers a trailing comment, a
    # triple-quoted string spanning several lines, blank lines, backslash
    # and bracket continuations, and a compound statement with its body.
    snippet = '\n'.join((
        'a = 1 # NEWLINE becomes NL after a comment',
        'print("""',
        '',
        '""")',
        '',
        'b = 1 +\\',
        '3 + (',
        '4)',
        'if True:',
        ' c = 1',
    ))
    # statement start line -> statement end line (1-based); blank lines
    # 3 and 5 are not statement starts and must not appear as keys
    expected = {1: 1, 2: 4, 6: 8, 9: 9, 10: 10}
    self.assertEqual(iw.delimit_statements(snippet), expected)
if cell['cell_type'] == 'code' and 'matplotlib' in cell['source']: cell['source'] = re.sub('^%matplotlib +notebook', '%matplotlib inline', cell['source'], flags=re.M) # if matplotlib is used in this script, split cell to keep the import # statement separate and avoid a known bug in the Jupyter backend which # causes the plot object to be represented as a string instead of a # canvas when created in the cell where matplotlib is imported for the # first time (https://github.com/jupyter/notebook/issues/3523) for i in range(len(nb['cells'])): cell = nb['cells'][i] if cell['cell_type'] == 'code' and 'matplotlib' in cell['source']: code = iw.protect_ipython_magics(cell['source']) # split cells after matplotlib imports mapping = iw.delimit_statements(code) tree = ast.parse(code) visitor = iw.GetMatplotlibPyplot() visitor.visit(tree) if visitor.matplotlib_first: code = iw.deprotect_ipython_magics(code) lines = code.split('\n') lineno_end = mapping[visitor.matplotlib_first] split_code = '\n'.join(lines[lineno_end:]).lstrip('\n') if split_code: new_cell = nbformat.v4.new_code_cell(source=split_code) nb['cells'].insert(i + 1, new_cell) lines = lines[:lineno_end] nb['cells'][i]['source'] = '\n'.join(lines).rstrip('\n') break