def get_post_call_scope(self):
    """Return the scope that should be active once this statement's call starts.

    For a ``ClassDef`` a child namespace of the frame's original scope is made
    immediately and appended to the tracer's pending class namespaces. For any
    other statement the function symbol recorded for it in
    ``nbs().statement_to_func_cell`` supplies the scope (its ``call_scope``);
    when there is no such symbol, or it is not a function, the frame's original
    scope is returned instead.
    """
    outer_scope = tracer().cur_frame_original_scope
    stmt = self.stmt_node
    if isinstance(stmt, ast.ClassDef):
        # the class body runs before the ClassDef statement has finished
        # executing, so its namespace has to exist right away
        class_ns = Namespace.make_child_namespace(outer_scope, stmt.name)
        tracer().pending_class_namespaces.append(class_ns)
        return class_ns

    is_func_def = isinstance(stmt, (ast.FunctionDef, ast.AsyncFunctionDef))
    func_name = stmt.name if is_func_def else None
    func_sym = nbs().statement_to_func_cell.get(id(stmt), None)
    if func_sym is None:
        # TODO: brittle; assumes any user-defined and traceable function will
        #  always be present; is this safe?
        return outer_scope

    if not func_sym.is_function:
        msg = "got non-function symbol %s for name %s" % (func_sym.full_path, func_name)
        if nbs().is_develop:
            raise TypeError(msg)
        logger.warning(msg)
        return outer_scope

    if not self.finished:
        func_sym.create_symbols_for_call_args()
    return func_sym.call_scope
def test_updated_namespace_after_subscript_dep_removed():
    """Removing x from d's definition should stop x's reassignment from affecting d."""
    cells = {
        0: 'x = 5',
        1: 'd = {x: 5}',
        2: 'logging.info(d[5])',
        3: 'x = 9',
    }

    def relink():
        # always link against the (possibly edited) cells dict above
        return nbs().check_and_link_multiple_cells(cells)

    for cell_id, code in cells.items():
        run_cell(code, cell_id)
    resp = relink()
    assert resp['stale_cells'] == [2]
    assert resp['fresh_cells'] == [1]

    # rewrite cell 1 so that d no longer references x
    cells[1] = 'd = {5: 5}'
    run_cell(cells[1], 1)
    resp = relink()
    assert resp['stale_cells'] == []
    assert resp['fresh_cells'] == [2]

    run_cell(cells[2], 2)
    resp = relink()
    assert resp['stale_cells'] == []
    assert resp['fresh_cells'] == []

    # re-running cell 0 must leave nothing stale or fresh
    run_cell(cells[0], 0)
    resp = relink()
    assert resp['stale_cells'] == []
    assert resp['fresh_cells'] == [], 'got %s' % resp['fresh_cells']
def _handle_aliases(self):
    """Move this symbol from the alias set of its cached object id to that of its current id.

    The alias set for the previously cached object id loses this symbol and is
    removed from ``nbs().aliases`` once it becomes empty.
    """
    alias_map = nbs().aliases
    previous = alias_map.get(self.cached_obj_id, None)
    if previous is not None:
        previous.discard(self)
        if not previous:
            del alias_map[self.cached_obj_id]
    # assumes alias_map creates missing entries on access (e.g. a defaultdict) — confirm
    alias_map[self.obj_id].add(self)
def should_preserve_timestamp(self, prev_obj: Optional[Any]) -> bool:
    """Decide whether rebinding this symbol may keep its current timestamp.

    Never preserves (returns False) when the execution mode is REACTIVE, when the
    execution schedule is DAG_BASED, when ``prev_obj`` is None, or when
    ``nbs().blocked_reactive_timestamps_by_symbol`` maps this symbol to the cell
    number of its current timestamp. Preserves (returns True) when the cached
    object is not out of sync or has the same id as the current object.
    Otherwise the current and previous objects must have the same type, each be
    no larger than ``sizing.MAX_SIZE`` according to ``sizing.sizeof``, report the
    same size, and compare equal.

    Args:
        prev_obj: the object previously bound to this symbol (may be None).

    Returns:
        True if the timestamp should be preserved, False if it should be bumped.
    """
    if nbs().mut_settings.exec_mode == ExecutionMode.REACTIVE:
        # always bump timestamps for reactive mode
        return False
    if nbs().mut_settings.exec_schedule == ExecutionSchedule.DAG_BASED:
        # always bump timestamps for dag schedule
        return False
    if prev_obj is None:
        return False
    if (
        nbs().blocked_reactive_timestamps_by_symbol.get(self, -1)
        == self.timestamp.cell_num
    ):
        return False
    if not self._cached_out_of_sync or self.obj_id == self.cached_obj_id:
        return True
    if self.obj is None or prev_obj is DataSymbol.NULL:
        # NOTE(review): only the (obj is None, prev_obj is NULL) pairing preserves;
        #  every other None/NULL combination bumps -- confirm this is intended
        return self.obj is None and prev_obj is DataSymbol.NULL
    obj_type = type(self.obj)
    prev_type = type(prev_obj)
    if obj_type != prev_type:
        return False
    # presumably guards the == comparison below against very large objects
    obj_size_ubound = sizing.sizeof(self.obj)
    if obj_size_ubound > sizing.MAX_SIZE:
        return False
    cached_obj_size_ubound = sizing.sizeof(prev_obj)
    if cached_obj_size_ubound > sizing.MAX_SIZE:
        return False
    # NOTE(review): for types whose == is elementwise (e.g. numpy/pandas) this
    #  returns a non-bool object despite the -> bool annotation -- confirm callers cope
    return (obj_size_ubound == cached_obj_size_ubound) and self.obj == prev_obj
def update_obj_ref(self, obj, refresh_cached=True):
    """Rebind this symbol to ``obj`` and update the bookkeeping that depends on it.

    Clears the tombstone flag and marks the cached object as out of sync. When
    ``nbs().settings.mark_typecheck_failures_unsafe`` is set and ``type(obj)``
    differs from the cached type, typecheck results of every cell in
    ``cells_where_live`` are invalidated. Shallow/deep liveness sets are cleared.
    If the object id differs from the cached one, the old object's namespace may
    be carried over to the new object. Alias bookkeeping is then updated and,
    when requested, the cached object info is refreshed.

    Args:
        obj: the new object this symbol refers to.
        refresh_cached: whether to call ``_refresh_cached_obj`` at the end.
    """
    logger.info("%s update obj ref to %s", self, obj)
    self._tombstone = False
    self._cached_out_of_sync = True
    if (
        nbs().settings.mark_typecheck_failures_unsafe
        and self.cached_obj_type != type(obj)
    ):
        # the bound type changed, so earlier typecheck results of cells using
        # this symbol are invalidated
        for cell in self.cells_where_live:
            cell.invalidate_typecheck_result()
    self.cells_where_shallow_live.clear()
    self.cells_where_deep_live.clear()
    self.obj = obj
    # obj_id presumably reflects the obj just assigned -- confirm
    if self.cached_obj_id is not None and self.cached_obj_id != self.obj_id:
        new_ns = nbs().namespaces.get(self.obj_id, None)
        # don't overwrite existing namespace for this obj
        old_ns = nbs().namespaces.get(self.cached_obj_id, None)
        # carry the old namespace over only when the new object has no namespace
        # (or one whose max_descendent_timestamp was never initialized) and the
        # old namespace lives at this symbol's namespace path
        if (
            old_ns is not None
            and (
                new_ns is None
                or not new_ns.max_descendent_timestamp.is_initialized
            )
            and old_ns.full_namespace_path == self.full_namespace_path
        ):
            if new_ns is None:
                logger.info("create fresh copy of namespace %s", old_ns)
                new_ns = old_ns.fresh_copy(obj)
            else:
                new_ns.scope_name = old_ns.scope_name
                new_ns.parent_scope = old_ns.parent_scope
            old_ns.transfer_symbols_to(new_ns)
    self._handle_aliases()
    if refresh_cached:
        self._refresh_cached_obj()
def _make_lval_data_symbols_old(self):
    """Create or update data symbols for every target this statement stores to.

    For each ``(target, dep_node)`` pair from ``get_symbol_edges(self.stmt_node)``,
    the rval symbols of ``dep_node`` are resolved, the store target is resolved via
    ``tracer().resolve_store_data_for_target``, and a data symbol is upserted in
    the resulting scope with those deps (minus the excluded deps). AugAssign
    statements upsert with ``overwrite=False``, ``for`` statements with
    ``propagate=False``. For class definitions the class object is also attached
    to ``self.class_scope`` and registered in ``nbs().namespaces``.

    KeyErrors during resolution/upsert are swallowed (a warning is logged only in
    develop mode); any other exception is logged and re-raised in test mode.
    """
    symbol_edges = get_symbol_edges(self.stmt_node)
    should_overwrite = not isinstance(self.stmt_node, ast.AugAssign)
    is_function_def = isinstance(self.stmt_node, (ast.FunctionDef, ast.AsyncFunctionDef))
    is_class_def = isinstance(self.stmt_node, ast.ClassDef)
    is_import = isinstance(self.stmt_node, (ast.Import, ast.ImportFrom))
    if is_function_def or is_class_def:
        # def/class statements are expected to produce exactly one edge
        assert len(symbol_edges) == 1
        # assert not lval_symbol_refs.issubset(rval_symbol_refs)
    for target, dep_node in symbol_edges:
        rval_deps = resolve_rval_symbols(dep_node)
        logger.info("create edges from %s to %s", rval_deps, target)
        if is_class_def:
            # the class object now exists in the frame's locals; bind it to
            # self.class_scope (presumably the namespace made for the class body)
            assert self.class_scope is not None
            class_ref = self.frame.f_locals[self.stmt_node.name]
            self.class_scope.obj = class_ref
            nbs().namespaces[id(class_ref)] = self.class_scope
        try:
            (
                scope,
                name,
                obj,
                is_subscript,
                excluded_deps,
            ) = tracer().resolve_store_data_for_target(target, self.frame)
            scope.upsert_data_symbol_for_name(
                name,
                obj,
                rval_deps - excluded_deps,
                self.stmt_node,
                overwrite=should_overwrite,
                is_subscript=is_subscript,
                is_function_def=is_function_def,
                is_import=is_import,
                class_scope=self.class_scope,
                propagate=not isinstance(self.stmt_node, ast.For),
            )
            if isinstance(
                self.stmt_node,
                (ast.FunctionDef, ast.ClassDef, ast.AsyncFunctionDef),
            ):
                self._handle_reactive_store(self.stmt_node)
        except KeyError as ke:
            # e.g., slices aren't implemented yet
            # put logging behind flag to avoid noise to user
            if nbs().is_develop:
                logger.warning(
                    "keyerror for %s",
                    ast.dump(target) if isinstance(target, ast.AST) else target,
                )
            # if nbs().is_test:
            #     raise ke
        except Exception as e:
            logger.warning("exception while handling store: %s", e)
            if nbs().is_test:
                raise e
def finished_execution_hook(self):
    """Post-statement bookkeeping; does nothing if the statement is already finished.

    Adds the statement id to the tracer's seen statements, handles its
    dependencies, runs the tracer's after-statement reset hook and finally
    garbage-collects namespaces.
    """
    if not self.finished:
        tracer().seen_stmts.add(self.stmt_id)
        self.handle_dependencies()
        tracer().after_stmt_reset_hook()
        nbs()._namespace_gc()
def run_cell(cell, cell_id=None, **kwargs):
    """Run ``cell`` via ``run_cell_``, optionally announcing its id first.

    When ``cell_id`` is given, a ``change_active_cell`` message carrying that id is
    sent to ``nbs().handle`` beforehand, mimicking that part of the comm protocol.
    """
    if cell_id is not None:
        change_msg = {'type': 'change_active_cell', 'active_cell_id': cell_id}
        nbs().handle(change_msg)
    run_cell_(cell, **kwargs)
def test_update_list_elem():
    """Mutating list elements through leftover loop variables freshens dependent cells."""
    force_subscript_symbol_creation = True
    cells = {
        0: (
            """
class Foo:
    def __init__(self):
        self.counter = 0
        self.dummy = 0

    def inc(self):
        self.counter += 1
"""
        ),
        1: (
            """
lst = []
for i in range(5):
    x = Foo()
    lst.append(x)
"""
        ),
        2: (
            """
for foo in lst:
    foo.inc()
"""
        ),
        3: "logging.info(lst)",
    }
    run_all_cells(cells)
    result = nbs().check_and_link_multiple_cells()
    assert result.stale_cells == set()
    assert result.fresh_cells == set(), "got %s" % result.fresh_cells

    # `x` still refers to the last Foo appended in cell 1
    cells[4] = "x.inc()"
    run_cell(cells[4], 4)
    result = nbs().check_and_link_multiple_cells()
    assert result.stale_cells == set(), "got %s" % result.stale_cells
    assert result.fresh_cells == {2, 3}, "got %s" % result.fresh_cells

    # `foo` still refers to the last element visited by cell 2's loop
    cells[5] = "foo.inc()"
    run_cell(cells[5], 5)
    result = nbs().check_and_link_multiple_cells()
    assert result.stale_cells == set(), "got %s" % result.stale_cells
    assert result.fresh_cells == {2, 3, 4}, "got %s" % result.fresh_cells

    if force_subscript_symbol_creation:
        cells[6] = "lst[-1]"
        run_cell(cells[6], 6)
        result = nbs().check_and_link_multiple_cells()
        assert result.stale_cells == set()
        assert result.fresh_cells == {2, 3, 4}

    run_cell(cells[4], 4)
    result = nbs().check_and_link_multiple_cells()
    assert result.stale_cells == set()
    expected_fresh = {2, 3, 5}
    if force_subscript_symbol_creation:
        expected_fresh.add(6)
    assert result.fresh_cells == expected_fresh
def get_ref_count(self):
    """Return the reference count of this symbol's object, adjusted for tracker-held refs.

    Starts from ``sys.getrefcount(self.obj)`` minus the temporary reference created
    by that call. It then subtracts one per alias symbol recorded for the object in
    ``nbs().aliases``, and one more if ``nbs().namespaces`` has a namespace for the
    object whose ``obj`` is set (not None / NULL).

    Returns:
        The adjusted count, or -1 when ``self.obj`` is None or ``DataSymbol.NULL``.
    """
    if self.obj is None or self.obj is DataSymbol.NULL:
        return -1
    # getrefcount's own argument contributes one temporary reference
    total = sys.getrefcount(self.obj) - 1
    # .get avoids inserting an empty alias set for objects without aliases
    # (other code removes empty alias sets, and a plain dict would raise here)
    total -= len(nbs().aliases.get(self.obj_id, ()))
    ns = nbs().namespaces.get(self.obj_id, None)
    if ns is not None and ns.obj is not None and ns.obj is not DataSymbol.NULL:
        total -= 1
    return total
def test_update_list_elem():
    """List-element mutations via leftover loop variables (dict-style response API)."""
    force_subscript_symbol_creation = True
    cells = {
        0: """
class Foo:
    def __init__(self):
        self.counter = 0
        self.dummy = 0

    def inc(self):
        self.counter += 1""",
        1: """
lst = []
for i in range(5):
    x = Foo()
    lst.append(x)""",
        2: """
for foo in lst:
    foo.inc()""",
        3: 'logging.info(lst)',
    }
    for cell_id, code in cells.items():
        run_cell(code, cell_id)
    resp = nbs().check_and_link_multiple_cells(cells)
    assert resp['stale_cells'] == []
    assert resp['fresh_cells'] == []

    # `x` still refers to the last Foo appended in cell 1
    cells[4] = 'x.inc()'
    run_cell(cells[4], 4)
    resp = nbs().check_and_link_multiple_cells(cells)
    assert resp['stale_cells'] == [], 'got %s' % resp['stale_cells']
    assert resp['fresh_cells'] == [2, 3]

    # `foo` still refers to the last element visited by cell 2's loop
    cells[5] = 'foo.inc()'
    run_cell(cells[5], 5)
    resp = nbs().check_and_link_multiple_cells(cells)
    assert resp['stale_cells'] == []
    assert resp['fresh_cells'] == [2, 3, 4]

    if force_subscript_symbol_creation:
        cells[6] = 'lst[-1]'
        run_cell(cells[6], 6)
        resp = nbs().check_and_link_multiple_cells(cells)
        assert resp['stale_cells'] == []
        assert resp['fresh_cells'] == [2, 3, 4]

    run_cell(cells[4], 4)
    resp = nbs().check_and_link_multiple_cells(cells)
    assert resp['stale_cells'] == []
    expected_fresh = [2, 3, 5]
    if force_subscript_symbol_creation:
        expected_fresh.append(6)
    assert resp['fresh_cells'] == expected_fresh
def override_settings(**kwargs):
    """Temporarily install NotebookSafetySettings with the given fields overridden.

    Generator meant to be wrapped as a context manager (presumably by a
    ``contextlib.contextmanager`` decorator above -- confirm). It yields once with
    the overridden settings set on ``nbs()``. The previous settings object is
    restored afterwards, even if the managed body raises.
    """
    saved = nbs().settings
    overridden = NotebookSafetySettings(**{**saved._asdict(), **kwargs})
    try:
        nbs().settings = overridden
        yield
    finally:
        nbs().settings = saved
def visit(self, node: ast.AST):
    """Visit ``node`` via the parent visitor, then refresh the current cell's AST.

    The current cell's AST is overridden with the module copy mapped to ``node``
    in ``self.orig_to_copy_mapping``. Any exception is recorded via
    ``nbs().set_exception_raised_during_execution``, its traceback printed, and
    then re-raised.
    """
    try:
        result = super().visit(node)
        current = cells().current_cell()
        current.to_ast(
            override=cast(ast.Module, self.orig_to_copy_mapping[id(node)])
        )
    except Exception as e:
        nbs().set_exception_raised_during_execution(e)
        traceback.print_exc()
        raise e
    return result
def test_int_change_to_str_triggers_typecheck():
    """Rebinding an int to a str flags the cell computing ``a + b`` for typechecking."""
    setup = ["a = 1", "b = 2", "logging.info(a + b)"]
    for cell_id, code in enumerate(setup, start=1):
        run_cell(code, cell_id)
        assert not get_cell_ids_needing_typecheck()

    run_cell('b = "b"', 4)
    assert get_cell_ids_needing_typecheck() == {3}

    nbs().check_and_link_multiple_cells()
    assert not get_cell_ids_needing_typecheck()
    assert cells().from_id(3)._cached_typecheck_result is False
def handle_dependencies(self):
    """Update dependency bookkeeping after this statement has executed.

    Does nothing unless ``nbs().dependency_tracking_enabled``. For each mutation in
    ``tracer().mutations``, usage info is first updated for the argument symbols.
    Then, by event:

    * ``arg_mutate``: each non-None argument symbol is marked mutated with no new
      deps, and the remaining handling is skipped.
    * ``list_append`` with exactly one argument object: a subscript symbol for the
      appended element is upserted into the list's namespace, creating a
      NamespaceScope if none exists.

    In all non-``arg_mutate`` cases, every alias of the mutated object then gets
    the argument symbols added as deps with ``mutated=True``.

    Finally, statements containing an lval go through ``_make_lval_data_symbols``,
    ``del`` statements through ``_handle_delete``, and anything else has its rval
    symbols resolved.
    """
    if not nbs().dependency_tracking_enabled:
        return
    for mutated_obj_id, mutation_event, mutation_arg_dsyms, mutation_arg_objs in tracer().mutations:
        logger.info("mutation %s %s %s %s", mutated_obj_id, mutation_event, mutation_arg_dsyms, mutation_arg_objs)
        update_usage_info(mutation_arg_dsyms)
        if mutation_event == MutationEvent.arg_mutate:
            for mutated_sym in mutation_arg_dsyms:
                if mutated_sym is None:
                    # TODO: happens when module mutates args
                    #  should we add module as a dep in this case?
                    continue
                mutated_sym.update_deps(set(), overwrite=False, mutated=True)
            continue

        # NOTE: this next block is necessary to ensure that we add the argument as a namespace child
        # of the mutated symbol. This helps to avoid propagating through to dependency children that are
        # themselves namespace children.
        if mutation_event == MutationEvent.list_append and len(mutation_arg_objs) == 1:
            namespace_scope = nbs().namespaces.get(mutated_obj_id, None)
            mutated_sym = nbs().get_first_full_symbol(mutated_obj_id)
            if mutated_sym is not None:
                mutated_obj = mutated_sym.get_obj()
                mutation_arg_obj = next(iter(mutation_arg_objs))
                # TODO: replace int check w/ more general "immutable" check
                # NOTE(review): the check below is `is not None`, not an int check;
                #  the TODO above looks stale -- confirm
                if mutation_arg_obj is not None:
                    if namespace_scope is None:
                        namespace_scope = NamespaceScope(
                            mutated_obj, mutated_sym.name, parent_scope=mutated_sym.containing_scope
                        )
                    # presumably the append has already happened, so len - 1 is the
                    # index of the appended element -- confirm
                    logger.info("upsert %s to %s", len(mutated_obj) - 1, namespace_scope)
                    namespace_scope.upsert_data_symbol_for_name(
                        len(mutated_obj) - 1, mutation_arg_obj, set(), self.stmt_node, overwrite=False,
                        is_subscript=True, propagate=False
                    )
        # TODO: add mechanism for skipping namespace children in case of list append
        update_usage_info(nbs().aliases[mutated_obj_id])
        for mutated_sym in nbs().aliases[mutated_obj_id]:
            mutated_sym.update_deps(mutation_arg_dsyms, overwrite=False, mutated=True)
    if self._contains_lval():
        self._make_lval_data_symbols()
    elif isinstance(self.stmt_node, ast.Delete):
        self._handle_delete()
    else:
        # make sure usage timestamps get bumped
        resolve_rval_symbols(self.stmt_node)
def init_metadata(self, parent):
    """
    Leave the request metadata untouched; only pull the cell id and tags out of
    the execution request before deferring to the parent implementation.
    """
    request_md = parent.get("metadata", {})
    active_cell = request_md.get("cellId", None)
    if active_cell is not None:
        nbs().set_active_cell(active_cell)
    nbs().set_tags(tuple(request_md.get("tags", ())))
    return super().init_metadata(parent)
def trace_messages(line_: str):
    """Handle ``%safety trace_messages [enable|disable]``.

    Exactly one argument is accepted (case-insensitive): "on" or anything starting
    with "enable" turns trace messages on; "off" or anything starting with
    "disable" turns them off. Any other input prints the usage string.
    """
    usage = 'Usage: %safety trace_messages [enable|disable]'
    args = line_.split()
    if len(args) != 1:
        print(usage)
        return
    choice = args[0].lower()
    wants_on = choice == 'on' or choice.startswith('enable')
    wants_off = choice == 'off' or choice.startswith('disable')
    if wants_on:
        nbs().trace_messages_enabled = True
    elif wants_off:
        nbs().trace_messages_enabled = False
    else:
        print(usage)
def visit(self, node: 'ast.AST'):
    """Map statements to their cached copies, then run the instrumentation passes.

    A StatementMapper built from the statement cache for the current cell counter
    and ``nbs().ast_node_by_id`` produces the original-to-copy mapping. That mapping
    is then handed to AstEavesdropper and StatementInserter, applied in that order.
    Failures are recorded via ``nbs().set_ast_transformer_raised``, their traceback
    printed, and re-raised.
    """
    try:
        cell_stmt_cache = nbs().statement_cache[nbs().cell_counter()]
        orig_to_copy_mapping = StatementMapper(cell_stmt_cache, nbs().ast_node_by_id)(node)
        # very important that the eavesdropper does not create new ast nodes for ast.stmt (but just
        # modifies existing ones), since StatementInserter relies on being able to map these
        for transformer_cls in (AstEavesdropper, StatementInserter):
            node = transformer_cls(orig_to_copy_mapping).visit(node)
    except Exception as e:
        nbs().set_ast_transformer_raised(e)
        traceback.print_exc()
        raise e
    return node
def _propagate_staleness_to_deps(self, dsym: DataSymbol, skip_seen_check=False):
    """Propagate staleness w.r.t. ``self.updated_sym`` to ``dsym`` and, recursively, its children.

    ``dsym`` is skipped if already in ``self.seen`` (unless ``skip_seen_check``).
    Otherwise it is added to ``self.seen``. If it is not among
    ``nbs().updated_symbols`` and ``dsym.should_mark_stale(self.updated_sym)``
    holds, then ``self.updated_sym`` is recorded as a fresher ancestor,
    ``required_cell_num`` is set to the current cell counter, and propagation
    continues to its namespace parents and children. Recursion then proceeds into
    every child returned by ``_non_class_to_instance_children``.

    Args:
        dsym: the symbol to process.
        skip_seen_check: process ``dsym`` even if it is already in ``self.seen``.
    """
    if not skip_seen_check and dsym in self.seen:
        return
    self.seen.add(dsym)
    if dsym not in nbs().updated_symbols:
        if dsym.should_mark_stale(self.updated_sym):
            dsym.fresher_ancestors.add(self.updated_sym)
            dsym.required_cell_num = nbs().cell_counter()
            self._propagate_staleness_to_namespace_parents(dsym, skip_seen_check=True)
            self._propagate_staleness_to_namespace_children(dsym, skip_seen_check=True)
    for child in self._non_class_to_instance_children(dsym):
        # per-edge trace output: logged at info like the rest of the tracing,
        # not warning, so it is not surfaced to users on every propagation
        logger.info('propagate %s %s to %s', dsym, dsym.obj_id, child)
        self._propagate_staleness_to_deps(child)
def trace_messages(line_: str) -> None:
    """Handle ``%safety trace_messages [enable|disable]``.

    Takes exactly one case-insensitive argument: "on"/"enable..." enables trace
    messages and "off"/"disable..." disables them. Anything else logs the usage
    string as a warning.
    """
    usage = "Usage: %safety trace_messages [enable|disable]"
    tokens = line_.split()
    if len(tokens) != 1:
        logger.warning(usage)
        return
    (setting,) = tokens
    setting = setting.lower()
    if setting == "on" or setting.startswith("enable"):
        nbs().trace_messages_enabled = True
        return
    if setting == "off" or setting.startswith("disable"):
        nbs().trace_messages_enabled = False
        return
    logger.warning(usage)
def set_highlights(cmd: str, rest: str) -> None:
    """Toggle highlights via ``%safety hls`` / ``%safety nohls`` or an on/off argument.

    ``cmd`` "hls"/"nohls" enables/disables directly. Otherwise ``rest`` is read
    case-insensitively: "on"/"enable..." enables, "off"/"disable..." disables, and
    anything else logs the usage string as a warning.
    """
    usage = "Usage: %safety [hls|nohls]"
    if cmd in ("hls", "nohls"):
        nbs().mut_settings.highlights_enabled = cmd == "hls"
        return
    choice = rest.lower()
    if choice == "on" or choice.startswith("enable"):
        nbs().mut_settings.highlights_enabled = True
    elif choice == "off" or choice.startswith("disable"):
        nbs().mut_settings.highlights_enabled = False
    else:
        logger.warning(usage)
def set_highlights(cmd: str, rest: str):
    """Toggle highlights via ``%safety hls`` / ``%safety nohls`` or an on/off argument.

    ``cmd`` "hls"/"nohls" enables/disables directly. Otherwise ``rest`` is read
    case-insensitively: "on"/"enable..." enables, "off"/"disable..." disables, and
    anything else prints the usage string without changing the setting.
    """
    usage = 'Usage: %safety [hls|nohls]'
    if cmd == 'hls':
        enabled = True
    elif cmd == 'nohls':
        enabled = False
    else:
        choice = rest.lower()
        if choice == 'on' or choice.startswith('enable'):
            enabled = True
        elif choice == 'off' or choice.startswith('disable'):
            enabled = False
        else:
            print(usage)
            return
    nbs().mut_settings.highlights_enabled = enabled
def test_dict_clear():
    """``d.clear()`` freshens only the cell that logs the whole dict."""
    run_all_cells({
        0: "d = {0: 0}",
        1: "logging.info(d[0])",
        2: "logging.info(d)",
    })
    initial = nbs().check_and_link_multiple_cells()
    assert initial.stale_cells == set()
    assert initial.fresh_cells == set()

    run_cell("d.clear()", 3)
    after_clear = nbs().check_and_link_multiple_cells()
    assert after_clear.stale_cells == set()
    assert after_clear.fresh_cells == {2}
def test_list_clear():
    """``lst.clear()`` freshens only the cell that logs the whole list."""
    run_all_cells({
        0: "lst = [0]",
        1: "logging.info(lst[0])",
        2: "logging.info(lst)",
    })
    initial = nbs().check_and_link_multiple_cells()
    assert initial.stale_cells == set()
    assert initial.fresh_cells == set()

    run_cell("lst.clear()", 3)
    after_clear = nbs().check_and_link_multiple_cells()
    assert after_clear.stale_cells == set()
    assert after_clear.fresh_cells == {2}
def test_increment_by_same_amount():
    """Changing x makes y's cell fresh and the cell logging y stale."""
    run_all_cells({
        0: "x = 2",
        1: "y = x + 1",
        2: "logging.info(y)",
    })
    initial = nbs().check_and_link_multiple_cells()
    assert initial.stale_cells == set()
    assert initial.fresh_cells == set()

    run_cell("x = 3", 0)
    after_update = nbs().check_and_link_multiple_cells()
    assert after_update.stale_cells == {2}
    assert after_update.fresh_cells == {1}
def test_liveness_skipped_for_simple_assignment_involving_aliases():
    """Rebinding ``lst`` freshens both the aliasing cell and the mutating cell."""
    run_all_cells({
        0: "lst = [1, 2, 3]",
        1: "lst2 = lst",
        2: "lst.append(4)",
    })
    initial = nbs().check_and_link_multiple_cells()
    assert initial.stale_cells == set()
    assert initial.fresh_cells == set()

    run_cell("lst = [1, 2, 3, 4]", 3)
    after_rebind = nbs().check_and_link_multiple_cells()
    assert after_rebind.stale_cells == set()
    assert after_rebind.fresh_cells == {1, 2}, "got %s" % after_rebind.fresh_cells
def test_adhoc_pandas_series_update():
    """Reassigning a DataFrame column freshens the cell that mutated that column in place."""
    run_all_cells({
        0: "import pandas as pd",
        1: "df = pd.DataFrame({1: [2, 3], 4: [5, 7]})",
        2: 'df["foo"] = [8, 9]',
        3: "df.foo.dropna(inplace=True)",
    })
    initial = nbs().check_and_link_multiple_cells()
    assert initial.stale_cells == set()
    assert initial.fresh_cells == set()

    run_cell('df["foo"] = [8, 9]', 4)
    after_reassign = nbs().check_and_link_multiple_cells()
    assert after_reassign.stale_cells == set()
    assert after_reassign.fresh_cells == {3}
def test_list_extend():
    """Updating an element in place after ``extend`` still tracks the subscript dependency."""
    run_all_cells({
        0: "lst = [0, 1]",
        1: "x = lst[1] + 1",
        2: "logging.info(x)",
        3: "lst.extend([2, 3, 4])",
    })
    initial = nbs().check_and_link_multiple_cells()
    assert initial.stale_cells == set()
    assert initial.fresh_cells == set()

    run_cell("lst[1] += 42", 4)
    after_update = nbs().check_and_link_multiple_cells()
    assert after_update.stale_cells == {2}
    assert after_update.fresh_cells == {1}
def test_equal_dict_update_does_induce_fresh_cell():
    """Rebuilding y as an equal dict still freshens the cell that logs y."""
    run_all_cells({
        0: 'x = {"foo": 42, "bar": 43}',
        1: 'y = dict(set(x.items()) | set({"baz": 44}.items()))',
        2: "logging.info(y)",
        3: "y = dict(y.items())",
    })
    initial = nbs().check_and_link_multiple_cells()
    assert initial.stale_cells == set()
    assert initial.fresh_cells == {2}, "got %s" % initial.fresh_cells

    run_cell('y = {"foo": 99}', 4)
    after_rebind = nbs().check_and_link_multiple_cells()
    assert after_rebind.stale_cells == set()
    assert after_rebind.fresh_cells == {2, 3}
def test_equal_list_update_does_induce_fresh_cell():
    """Rebuilding y as an equal list still freshens the cell that logs y."""
    run_all_cells({
        0: 'x = ["f"] + ["o"] * 10',
        1: 'y = x + list("bar")',
        2: "logging.info(y)",
        3: 'y = list("".join(y))',
    })
    initial = nbs().check_and_link_multiple_cells()
    assert initial.stale_cells == set()
    assert initial.fresh_cells == {2}

    run_cell('y = ("f",)', 4)
    after_rebind = nbs().check_and_link_multiple_cells()
    assert after_rebind.stale_cells == set()
    assert after_rebind.fresh_cells == {2, 3}