def get_post_call_scope(self):
    """Return the scope to use after the call made by this statement starts.

    * For a ``ClassDef`` a fresh child namespace of the current frame's
      original scope is created, queued on the tracer as pending, and returned.
    * Otherwise the function symbol registered for this statement node is
      looked up; when there is none the original scope is returned unchanged.
    * When the registered symbol is not a function, a ``TypeError`` is raised
      in develop mode; outside develop mode a warning is logged and the
      original scope is returned.
    * For a proper function symbol, call-argument symbols are created (unless
      the statement already finished) and the symbol's call scope is returned.
    """
    caller_scope = tracer().cur_frame_original_scope
    node = self.stmt_node

    if isinstance(node, ast.ClassDef):
        # classes need a new scope before the ClassDef has finished executing,
        # so we make it immediately
        class_ns = Namespace.make_child_namespace(caller_scope, node.name)
        tracer().pending_class_namespaces.append(class_ns)
        return class_ns

    func_sym = nbs().statement_to_func_cell.get(id(node), None)
    if func_sym is None:
        # TODO: brittle; assumes any user-defined and traceable function will
        #  always be present; is this safe?
        return caller_scope

    if func_sym.is_function:
        if not self.finished:
            func_sym.create_symbols_for_call_args()
        return func_sym.call_scope

    # Only (async) function definitions carry a name worth reporting.
    if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
        func_name = node.name
    else:
        func_name = None
    msg = "got non-function symbol %s for name %s" % (
        func_sym.full_path,
        func_name,
    )
    if nbs().is_develop:
        raise TypeError(msg)
    logger.warning(msg)
    return caller_scope
def _match_call_args_with_definition_args(self):
    """Generate ``(parameter name, resolved symbols)`` pairs for the active call.

    Pairs are produced in three passes:

    1. parameters without defaults (minus a leading ``self``) zipped against
       the positional arguments of the calling ``ast.Call``; the pass stops at
       the first starred argument;
    2. every keyword argument of the call;
    3. every defaulted parameter that was not named by a keyword argument,
       paired with the symbols resolved from its default expression.

    Nothing is generated when no calling ``ast.Call`` node can be determined.
    """
    assert self.is_function and isinstance(
        self.stmt_node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.Lambda)
    )
    call_node = self._get_calling_ast_node()
    if not isinstance(call_node, ast.Call):
        return []

    signature = self.stmt_node.args
    num_defaults = len(signature.defaults)

    positional_params = signature.args
    if num_defaults > 0:
        positional_params = positional_params[:-num_defaults]
    if len(positional_params) > 0 and positional_params[0].arg == "self":
        # FIXME: this is bad and I should feel bad
        positional_params = positional_params[1:]

    for param, call_arg in zip(positional_params, call_node.args):
        if isinstance(call_arg, ast.Starred):
            # give up
            # TODO: handle this case
            break
        yield param.arg, tracer().resolve_loaded_symbols(call_arg)

    keyword_names = set()
    for kw in call_node.keywords:
        # NOTE(review): for ``**kwargs`` it is ``kw.arg`` that is None, not
        # ``kw.value`` -- confirm which one this guard was meant to test.
        if kw.value is None:
            continue
        keyword_names.add(kw.arg)
        yield kw.arg, tracer().resolve_loaded_symbols(kw.value)

    # With no defaults the slice below is the full parameter list, but zip()
    # against the empty defaults list then produces nothing.
    defaulted_params = signature.args[-num_defaults:]
    for param, default_expr in zip(defaulted_params, signature.defaults):
        if param.arg not in keyword_names:
            yield param.arg, tracer().resolve_loaded_symbols(default_expr)
def finished_execution_hook(self):
    """Run end-of-statement bookkeeping unless the statement already finished.

    Records the statement id as seen on the tracer, processes this statement's
    dependencies, asks the tracer to reset its per-statement state, and finally
    triggers namespace garbage collection.
    """
    if not self.finished:
        tracer().seen_stmts.add(self.stmt_id)
        self.handle_dependencies()
        tracer().after_stmt_reset_hook()
        nbs()._namespace_gc()
def _append_atom(
    self,
    node: Union[ast.Name, ast.Attribute, ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef],
    val: str,
    **kwargs,
) -> None:
    """Append an ``Atom`` for ``val`` to the symbol chain.

    The atom is flagged reactive / blocking according to whether ``id(node)``
    is present in the tracer's reactive / blocking node-id collections. Extra
    keyword arguments are forwarded to ``Atom``.
    """
    node_id = id(node)
    atom = Atom(
        val,
        is_reactive=node_id in tracer().reactive_node_ids,
        is_blocking=node_id in tracer().blocking_node_ids,
        **kwargs,
    )
    self.symbol_chain.append(atom)
def _handle_starred_store_target(self, target: ast.Starred, inner_deps: List[Optional[DataSymbol]]):
    """Record symbols for a starred store target (e.g. ``*rest`` in an unpack).

    Resolves where ``target`` is stored, upserts one subscript symbol per entry
    of ``inner_deps`` into the namespace of the stored object (creating that
    namespace when none is registered), upserts the symbol for the target
    itself with no dependencies, and finally processes reactive-store handling
    for the starred expression's inner value.

    Args:
        target: the starred AST node being stored to.
        inner_deps: one entry per element captured by the starred target; an
            entry may be ``None`` when no symbol is known for that element.
    """
    try:
        scope, name, obj, is_subscript, _ = tracer().resolve_store_data_for_target(
            target, self.frame
        )
    except KeyError as e:
        # e.g., slices aren't implemented yet
        # use suppressed log level to avoid noise to user
        logger.info("Exception: %s", e)
        return

    ns = nbs().namespaces.get(id(obj), None)
    if ns is None:
        ns = Namespace(obj, str(name), scope)

    for i, inner_dep in enumerate(inner_deps):
        if inner_dep is None:
            # No symbol for this element: previously this dereferenced
            # ``inner_dep.obj`` and raised AttributeError. Register the slot
            # with no dependencies and no known object instead.
            inner_obj, deps = None, set()
        else:
            inner_obj, deps = inner_dep.obj, {inner_dep}
        ns.upsert_data_symbol_for_name(
            i, inner_obj, deps, self.stmt_node, is_subscript=True
        )

    scope.upsert_data_symbol_for_name(
        name,
        obj,
        set(),
        self.stmt_node,
        is_subscript=is_subscript,
    )
    self._handle_reactive_store(target.value)
def get_post_call_scope(self):
    """Return the scope to use after the call made by this statement starts.

    A ``ClassDef`` gets a new child scope of the current frame's original
    scope (created with ``obj_id=-1``). Statements that are not (async)
    function definitions, functions with no registered cell, and -- outside
    develop mode -- cells that are not functions all fall back to the original
    scope. In develop mode a non-function cell raises ``TypeError``.
    Otherwise call-argument symbols are created (unless the statement already
    finished) and the function's call scope is returned.
    """
    caller_scope = tracer().cur_frame_original_scope
    node = self.stmt_node

    if isinstance(node, ast.ClassDef):
        # classes need a new scope before the ClassDef has finished executing,
        # so we make it immediately
        return caller_scope.make_child_scope(node.name, obj_id=-1)

    if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
        # TODO: probably the right thing is to check is whether a lambda
        #  appears somewhere inside the ast node
        return caller_scope

    func_cell = nbs().statement_to_func_cell.get(id(node), None)
    if func_cell is None:
        # TODO: brittle; assumes any user-defined and traceable function will
        #  always be present; is this safe?
        return caller_scope

    if func_cell.is_function:
        if not self.finished:
            func_cell.create_symbols_for_call_args()
        return func_cell.call_scope

    if nbs().is_develop:
        raise TypeError(
            'got non-function symbol %s for name %s' % (func_cell.full_path, node.name)
        )
    # TODO: log an error to a file
    return caller_scope
def _handle_reactive_store(target: ast.AST) -> None:
    """Update reactive / blocking bookkeeping for a store to ``target``.

    Walks the symbols resolved from ``target`` (all intermediate symbols, in
    reverse order, without inherited reactivity) and, for each one:

    * notes whether a blocking symbol has been encountered so far;
    * adds reactive symbols seen before any blocking one to the set of updated
      deep-reactive symbols;
    * once a reactive symbol was seen (and no blocking one), adds the symbol to
      the set of updated reactive symbols;
    * once a blocking symbol was seen, stamps symbols that are not in the
      updated-symbols set with the current cell counter.

    A ``TypeError`` raised anywhere during this process ends it silently.
    """
    # NOTE(review): takes no ``self`` although it is invoked through ``self``
    # elsewhere; presumably declared as a staticmethod -- confirm.
    try:
        ref = SymbolRef(target)
        saw_reactive = False
        saw_blocking = False
        resolved_symbols = ref.gen_resolved_symbols(
            tracer().cur_frame_original_scope,
            only_yield_final_symbol=False,
            yield_all_intermediate_symbols=True,
            inherit_reactivity=False,
            yield_in_reverse=True,
        )
        for resolved in resolved_symbols:
            if resolved.is_blocking:
                saw_blocking = True
            if resolved.is_reactive and not saw_blocking:
                nbs().updated_deep_reactive_symbols.add(resolved.dsym)
                saw_reactive = True
            if saw_reactive and not saw_blocking:
                nbs().updated_reactive_symbols.add(resolved.dsym)
            if saw_blocking and resolved.dsym not in nbs().updated_symbols:
                blocked_at = nbs().blocked_reactive_timestamps_by_symbol
                blocked_at[resolved.dsym] = nbs().cell_counter()
    except TypeError:
        return
def _make_lval_data_symbols_old(self):
    """Upsert a data symbol for every store target of this statement.

    Each ``(target, dependency node)`` edge of the statement is resolved to its
    right-hand-side symbols and upserted into the scope that the tracer reports
    for the target. For a class definition the class scope is first bound to
    the id of the freshly created class object and registered as a namespace.
    Failures while resolving or upserting a target are logged, not raised.
    """
    stmt = self.stmt_node
    edges = get_symbol_edges(stmt)
    is_function_def = isinstance(stmt, (ast.FunctionDef, ast.AsyncFunctionDef))
    is_class_def = isinstance(stmt, ast.ClassDef)
    is_import = isinstance(stmt, (ast.Import, ast.ImportFrom))
    overwrite = not isinstance(stmt, ast.AugAssign)
    propagate = not isinstance(stmt, ast.For)

    if is_function_def or is_class_def:
        # a definition introduces exactly one name
        assert len(edges) == 1

    for target, dep_node in edges:
        rval_deps = resolve_rval_symbols(dep_node)
        logger.info('create edges from %s to %s', rval_deps, target)
        if is_class_def:
            assert self.class_scope is not None
            class_obj_id = id(self.frame.f_locals[stmt.name])
            self.class_scope.obj_id = class_obj_id
            nbs().namespaces[class_obj_id] = self.class_scope
        try:
            resolved = tracer().resolve_store_or_del_data_for_target(
                target, self.frame, ctx=ast.Store()
            )
            scope, name, obj, is_subscript = resolved
            scope.upsert_data_symbol_for_name(
                name,
                obj,
                rval_deps,
                stmt,
                overwrite=overwrite,
                is_subscript=is_subscript,
                is_function_def=is_function_def,
                is_import=is_import,
                class_scope=self.class_scope,
                propagate=propagate,
            )
        except KeyError:
            logger.warning('keyerror for %s', target)
        except Exception as e:
            logger.warning('exception while handling store: %s', e)
def visit_Subscript(self, node: ast.Subscript):
    """Collect the symbols loaded by a subscript expression.

    The symbols the tracer resolved for ``node`` plus those found while
    visiting the slice are used when there are any. Otherwise a fallback tries
    to resolve the slice to a constant and look the element up in the
    namespace of the subscripted value, retrying a negative integer index as
    ``len(namespace) + index``. Any exception in the fallback is logged.
    """
    if isinstance(node.value, ast.Call):
        self.visit(node.value)

    loaded = tracer().resolve_loaded_symbols(node)
    with self._push_symbols():
        # add slice to RHS to avoid propagating to it
        self.visit(node.slice)
        loaded.extend(self.symbols)
    if len(loaded) > 0:
        self.symbols.extend(loaded)
        return

    # TODO: this path lacks coverage
    try:
        index = resolve_slice_to_constant(node)
        if index is None or isinstance(index, ast.Name):
            return
        with self._push_symbols():
            self.visit(node.value)
            value_symbols = self.symbols
        if len(value_symbols) != 1 or value_symbols[0] is None:
            return
        ns = self._get_attr_or_subscript_namespace(node)
        if ns is None:
            return
        dsym = ns.lookup_data_symbol_by_name_this_indentation(index, is_subscript=True)
        if dsym is None and isinstance(index, int) and index < 0:
            # retry a negative index as its non-negative equivalent
            try:
                dsym = ns.lookup_data_symbol_by_name_this_indentation(
                    len(ns) + index, is_subscript=True
                )
            except TypeError:
                dsym = None
        if dsym is not None:
            self.symbols.append(dsym)
    except Exception as e:
        logger.warning("Exception occurred while resolving node %s: %s", ast.dump(node), e)
def _make_lval_data_symbols_old(self):
    """Upsert a data symbol for every store target of this statement.

    For each ``(target, dependency node)`` edge the right-hand-side symbols are
    resolved, the tracer is asked where the target is stored, and the symbol is
    upserted there with the resolved dependencies minus those the tracer
    excluded. Function and class definitions additionally go through
    reactive-store handling; for a class definition the class scope is first
    bound to the new class object and registered as its namespace.

    ``KeyError`` is only logged in develop mode; any other exception is logged
    and re-raised only in test mode.
    """
    stmt = self.stmt_node
    edges = get_symbol_edges(stmt)
    is_function_def = isinstance(stmt, (ast.FunctionDef, ast.AsyncFunctionDef))
    is_class_def = isinstance(stmt, ast.ClassDef)
    is_definition = is_function_def or is_class_def
    is_import = isinstance(stmt, (ast.Import, ast.ImportFrom))
    overwrite = not isinstance(stmt, ast.AugAssign)
    propagate = not isinstance(stmt, ast.For)

    if is_definition:
        # a definition introduces exactly one name
        assert len(edges) == 1

    for target, dep_node in edges:
        rval_deps = resolve_rval_symbols(dep_node)
        logger.info("create edges from %s to %s", rval_deps, target)
        if is_class_def:
            assert self.class_scope is not None
            class_ref = self.frame.f_locals[stmt.name]
            self.class_scope.obj = class_ref
            nbs().namespaces[id(class_ref)] = self.class_scope
        try:
            resolved = tracer().resolve_store_data_for_target(target, self.frame)
            scope, name, obj, is_subscript, excluded_deps = resolved
            scope.upsert_data_symbol_for_name(
                name,
                obj,
                rval_deps - excluded_deps,
                stmt,
                overwrite=overwrite,
                is_subscript=is_subscript,
                is_function_def=is_function_def,
                is_import=is_import,
                class_scope=self.class_scope,
                propagate=propagate,
            )
            if is_definition:
                self._handle_reactive_store(stmt)
        except KeyError:
            # e.g., slices aren't implemented yet
            # put logging behind flag to avoid noise to user
            if nbs().is_develop:
                described = ast.dump(target) if isinstance(target, ast.AST) else target
                logger.warning("keyerror for %s", described)
        except Exception as e:
            logger.warning("exception while handling store: %s", e)
            if nbs().is_test:
                raise e
def visit_List_or_Tuple(self, node: Union[ast.List, ast.Tuple]):
    """Collect symbols for a list or tuple literal.

    Uses the symbols the tracer resolved for the literal when there are any;
    otherwise descends into the literal's elements.
    """
    literal_symbols = tracer().resolve_loaded_symbols(node)
    if literal_symbols:
        self.symbols.extend(literal_symbols)
    else:
        # only descend if tracer failed to create literal symbol
        self.generic_visit(node.elts)
def visit_Name(self, node: ast.Name) -> None:
    """Classify a name reference as dead or live.

    In a kill context the reference is added to the dead set. Otherwise,
    unless simple names are being skipped or the reference is already dead,
    it is recorded as live at the current module statement counter; its first
    chain element is flagged reactive when the tracer lists the node's id as
    reactive.
    """
    ref = SymbolRef(node)
    if self._in_kill_context:
        self.dead.add(ref)
        return
    if self._skip_simple_names or ref in self.dead:
        return
    if id(node) in tracer().reactive_node_ids:
        ref.chain[0].is_reactive = True
    self.live.add(LiveSymbolRef(ref, self._module_stmt_counter))
def visit_Dict(self, node: ast.Dict):
    """Collect symbols for a dict literal.

    Uses the symbols the tracer resolved for the literal when there are any;
    otherwise descends into the literal's keys and then its values.
    """
    literal_symbols = tracer().resolve_loaded_symbols(node)
    if literal_symbols:
        self.symbols.extend(literal_symbols)
        return
    # only descend if tracer failed to create literal symbol
    for children in (node.keys, node.values):
        self.generic_visit(children)
def _handle_delete(self):
    """Remove the data symbol for every target of a ``del`` statement.

    Each target is resolved through the tracer in a ``Del`` context and the
    matching symbol is deleted from the resolved scope. A ``KeyError`` for a
    target is logged at info level and the remaining targets are still
    processed.
    """
    assert isinstance(self.stmt_node, ast.Delete)
    for target in self.stmt_node.targets:
        try:
            resolved = tracer().resolve_store_or_del_data_for_target(
                target, self.frame, ctx=ast.Del()
            )
            scope, name, _, is_subscript = resolved
            scope.delete_data_symbol_for_name(name, is_subscript=is_subscript)
        except KeyError as e:
            # this will happen if, e.g., a __delitem__ triggered a call
            logger.info("got key error: %s", e)
def upsert_data_symbol_for_name(
    self,
    name: SupportedIndexType,
    obj: Any,
    deps: Iterable[DataSymbol],
    stmt_node: ast.AST,
    overwrite: bool = True,
    is_subscript: bool = False,
    is_function_def: bool = False,
    is_import: bool = False,
    is_anonymous: bool = False,
    class_scope: Optional["Scope"] = None,
    symbol_type: Optional[DataSymbolType] = None,
    propagate: bool = True,
    implicit: bool = False,
) -> DataSymbol:
    """Create or update the data symbol ``name`` in this scope and return it.

    The symbol type is derived from the flag arguments unless ``symbol_type``
    is given. After the symbol has been upserted its dependencies are updated
    (refreshing it unless ``implicit``) and it is recorded among the symbols
    updated by the current statement.
    """
    if not symbol_type:
        symbol_type = self._resolve_symbol_type(
            overwrite=overwrite,
            is_subscript=is_subscript,
            is_function_def=is_function_def,
            is_import=is_import,
            is_anonymous=is_anonymous,
            class_scope=class_scope,
        )
    # Work on a private copy: the inner upsert mutates the set it is given.
    # FIXME: relying on that mutation is super super hacky
    working_deps = set(deps)
    upserted, _, prev_obj = self._upsert_data_symbol_for_name_inner(
        name,
        obj,
        working_deps,
        symbol_type,
        stmt_node,
        implicit=implicit,
    )
    upserted.update_deps(
        working_deps,
        prev_obj=prev_obj,
        overwrite=overwrite,
        propagate=propagate,
        refresh=not implicit,
    )
    tracer().this_stmt_updated_symbols.add(upserted)
    return upserted
def __call__(
    self,
    new_deps: Set["DataSymbol"],
    mutated: bool,
    propagate_to_namespace_descendents: bool,
    refresh: bool,
) -> None:
    """Propagate an update of ``self.updated_sym`` to the symbols it affects.

    Args:
        new_deps: symbols on the right-hand side of the update; they are added
            to ``self.seen`` before staleness propagation so that nothing is
            propagated to them.
        mutated: when true, every alias of the updated object counts as
            directly updated; otherwise only ``self.updated_sym`` does.
        propagate_to_namespace_descendents: forwarded to the collection step.
        refresh: when true, directly updated symbols other than
            ``self.updated_sym`` that are not stale are refreshed.
    """
    # in most cases, mutated implies that we should propagate to namespace descendents, since we
    # do not know how the mutation affects the namespace members. The exception is for specific
    # known events such as 'list.append()' or 'list.extend()' since we know these do not update
    # the namespace members.
    logger.warning(
        "updated sym %s (containing scope %s) with children %s",
        self.updated_sym,
        self.updated_sym.containing_scope,
        self.updated_sym.children,
    )
    if mutated:
        # Copy the alias set: the in-place ``|=`` below would otherwise add
        # the ad-hoc pandas symbols to the set stored in ``nbs().aliases``.
        directly_updated_symbols = set(nbs().aliases[self.updated_sym.obj_id])
    else:
        directly_updated_symbols = {self.updated_sym}
    directly_updated_symbols |= self._maybe_get_adhoc_pandas_updated_syms()
    self._collect_updated_symbols_and_refresh_namespaces(
        directly_updated_symbols, propagate_to_namespace_descendents
    )
    logger.warning(
        "for symbol %s: mutated=%s; updated_symbols=%s",
        self.updated_sym,
        mutated,
        directly_updated_symbols,
    )
    # Snapshot taken before ``new_deps`` are merged into ``self.seen`` below.
    updated_symbols_with_ancestors = set(self.seen)
    logger.warning(
        "all updated symbols for symbol %s: %s",
        self.updated_sym,
        updated_symbols_with_ancestors,
    )
    tracer().this_stmt_updated_symbols |= self.seen
    if refresh:
        for updated_sym in directly_updated_symbols:
            if not updated_sym.is_stale and updated_sym is not self.updated_sym:
                updated_sym.refresh()
    self.seen |= new_deps  # don't propagate to stuff on RHS
    for dsym in updated_symbols_with_ancestors:
        self._propagate_staleness_to_deps(dsym, skip_seen_check=True)
def _handle_assign_target(self, target: ast.AST, value: ast.AST):
    """Dispatch handling of one assignment target.

    A list or tuple target is unpacked from the namespace registered for the
    saved right-hand-side object id when there is one, and from the symbols
    resolved for ``value`` otherwise. Any other target is handled as a plain
    assignment with literal-namespace fixup enabled.
    """
    if not isinstance(target, (ast.List, ast.Tuple)):
        self._handle_assign_target_for_deps(
            target, resolve_rval_symbols(value), maybe_fixup_literal_namespace=True
        )
        return

    rhs_namespace = nbs().namespaces.get(tracer().saved_assign_rhs_obj_id, None)
    if rhs_namespace is not None:
        self._handle_assign_target_tuple_unpack_from_namespace(target, rhs_namespace)
    else:
        self._handle_assign_target_tuple_unpack_from_deps(target, resolve_rval_symbols(value))
def _get_calling_ast_node(self) -> Optional[ast.AST]:
    """Return the ``ast.Call`` node that invoked this function, if known.

    ``None`` is returned when this symbol is a (possibly async) function named
    ``__getitem__`` / ``__setitem__`` / ``__delitem__`` or decorated with a
    bare ``property``, when the tracer's lexical call stack is empty, or when
    the node recorded for the previous position in the current frame is not an
    ``ast.Call``.
    """
    if isinstance(self.stmt_node, (ast.FunctionDef, ast.AsyncFunctionDef)):
        if self.name in ("__getitem__", "__setitem__", "__delitem__"):
            # TODO: handle case where we're looking for a subscript for the calling node
            return None
        is_property = any(
            isinstance(decorator, ast.Name) and decorator.id == "property"
            for decorator in self.stmt_node.decorator_list
        )
        if is_property:
            # TODO: handle case where we're looking for an attribute for the calling node
            return None

    call_stack = tracer().lexical_call_stack
    if len(call_stack) == 0:
        return None

    caller_node_id = call_stack.get_field("prev_node_id_in_cur_frame_lexical")
    candidate = tracer().ast_node_by_id.get(caller_node_id, None)
    # isinstance() is already False for None, so one check covers both cases
    return candidate if isinstance(candidate, ast.Call) else None
def handle_dependencies(self):
    """Apply the dataflow effects of the statement that just executed.

    Does nothing when dependency tracking is disabled. Otherwise every
    mutation recorded by the tracer is processed first, and then the statement
    itself is handled: store targets get data symbols, ``del`` statements
    remove them, and for anything else the statement's symbols are resolved so
    that their usage timestamps are bumped.
    """
    if not nbs().dependency_tracking_enabled:
        return
    for mutated_obj_id, mutation_event, mutation_arg_dsyms, mutation_arg_objs in tracer().mutations:
        logger.info("mutation %s %s %s %s", mutated_obj_id, mutation_event, mutation_arg_dsyms, mutation_arg_objs)
        update_usage_info(mutation_arg_dsyms)
        if mutation_event == MutationEvent.arg_mutate:
            # the argument symbols themselves are what got mutated
            for mutated_sym in mutation_arg_dsyms:
                if mutated_sym is None:
                    continue
                # TODO: happens when module mutates args
                #  should we add module as a dep in this case?
                mutated_sym.update_deps(set(), overwrite=False, mutated=True)
            # argument mutations need none of the alias handling below
            continue

        # NOTE: this next block is necessary to ensure that we add the argument as a namespace child
        # of the mutated symbol. This helps to avoid propagating through to dependency children that are
        # themselves namespace children.
        if mutation_event == MutationEvent.list_append and len(mutation_arg_objs) == 1:
            namespace_scope = nbs().namespaces.get(mutated_obj_id, None)
            mutated_sym = nbs().get_first_full_symbol(mutated_obj_id)
            if mutated_sym is not None:
                mutated_obj = mutated_sym.get_obj()
                mutation_arg_obj = next(iter(mutation_arg_objs))
                # TODO: replace int check w/ more general "immutable" check
                if mutation_arg_obj is not None:
                    if namespace_scope is None:
                        namespace_scope = NamespaceScope(
                            mutated_obj, mutated_sym.name, parent_scope=mutated_sym.containing_scope
                        )
                    logger.info("upsert %s to %s", len(mutated_obj) - 1, namespace_scope)
                    # the appended element is registered under the last index of the mutated object
                    namespace_scope.upsert_data_symbol_for_name(
                        len(mutated_obj) - 1, mutation_arg_obj, set(), self.stmt_node,
                        overwrite=False, is_subscript=True, propagate=False
                    )
        # TODO: add mechanism for skipping namespace children in case of list append
        update_usage_info(nbs().aliases[mutated_obj_id])
        # every alias of the mutated object picks up the mutation's argument symbols as dependencies
        for mutated_sym in nbs().aliases[mutated_obj_id]:
            mutated_sym.update_deps(mutation_arg_dsyms, overwrite=False, mutated=True)
    if self._contains_lval():
        self._make_lval_data_symbols()
    elif isinstance(self.stmt_node, ast.Delete):
        self._handle_delete()
    else:
        # make sure usage timestamps get bumped
        resolve_rval_symbols(self.stmt_node)
def make_child_scope(self, scope_name) -> "Scope":
    """Create a ``Scope`` named ``scope_name`` whose parent is this scope.

    The child receives the symbol table of ``scope_name`` when that name can
    be looked up in the relevant table (the tracer's current cell table for
    the global scope, this scope's own table otherwise) and is a namespace;
    in every other case the child's symbol table is ``None``.
    """
    if self.is_global:
        parent_symtab = tracer().cur_cell_symtab
    else:
        parent_symtab = self.symtab

    child_symtab = None
    if parent_symtab is not None:
        try:
            entry = parent_symtab.lookup(scope_name)
            if entry.is_namespace():
                child_symtab = entry.get_namespace()
        except (KeyError, ValueError):
            # name unknown or namespace not retrievable: no child table
            pass
    return Scope(scope_name, parent_scope=self, symtab=child_symtab)
def visit_Attribute(self, node: ast.Attribute):
    """Collect the symbols loaded by an attribute access.

    Uses the symbols the tracer resolved for ``node`` when there are any.
    Otherwise the attribute name is looked up in the namespace of the accessed
    value; any exception raised during that fallback is logged.
    """
    if isinstance(node.value, ast.Call):
        self.visit(node.value)

    loaded = tracer().resolve_loaded_symbols(node)
    if len(loaded) > 0:
        self.symbols.extend(loaded)
        return

    # TODO: this path lacks coverage
    try:
        ns = self._get_attr_or_subscript_namespace(node)
        if ns is not None:
            dsym = ns.lookup_data_symbol_by_name_this_indentation(
                node.attr, is_subscript=False
            )
            if dsym is not None:
                self.symbols.append(dsym)
    except Exception as e:
        logger.warning("Exception occurred while resolving node %s: %s", ast.dump(node), e)
def _handle_assign_target_for_deps(
    self,
    target: ast.AST,
    deps: Set[DataSymbol],
    maybe_fixup_literal_namespace=False,
) -> None:
    """Upsert the symbol for a single assignment target with the given deps.

    The target is resolved through the tracer and upserted into the resolved
    scope with ``deps`` minus the dependencies the tracer excluded, followed
    by reactive-store handling. When the target cannot be resolved
    (``KeyError``) nothing is upserted; the failure is logged in develop mode
    only. With ``maybe_fixup_literal_namespace`` an anonymous namespace
    registered for the stored object is renamed after the target and
    re-parented to the resolved scope.
    """
    try:
        resolved = tracer().resolve_store_data_for_target(target, self.frame)
        scope, name, obj, is_subscript, excluded_deps = resolved
    except KeyError:
        # e.g., slices aren't implemented yet
        # use suppressed log level to avoid noise to user
        if nbs().is_develop:
            described = ast.dump(target) if isinstance(target, ast.AST) else target
            logger.warning("keyerror for %s", described)
        return

    upserted = scope.upsert_data_symbol_for_name(
        name,
        obj,
        deps - excluded_deps,
        self.stmt_node,
        is_subscript=is_subscript,
    )
    self._handle_reactive_store(target)
    logger.info(
        "sym %s upserted to scope %s has parents %s",
        upserted,
        scope,
        upserted.parents,
    )

    if not maybe_fixup_literal_namespace:
        return
    literal_ns = nbs().namespaces.get(id(obj), None)
    if literal_ns is not None and literal_ns.is_anonymous:
        literal_ns.scope_name = str(name)
        literal_ns.parent_scope = scope
def _propagate_staleness_to_deps(self, dsym: "DataSymbol", skip_seen_check: bool = False) -> None:
    """Recursively propagate staleness from ``self.updated_sym`` through ``dsym``.

    Args:
        dsym: the symbol currently being visited.
        skip_seen_check: when true, ``dsym`` is processed even if it is
            already in ``self.seen``.
    """
    if not skip_seen_check and dsym in self.seen:
        return
    # mark as visited before recursing so that cycles terminate
    self.seen.add(dsym)
    # symbols updated earlier (per nbs()) or by the current statement are not marked stale
    if (dsym not in nbs().updated_symbols
            and dsym not in tracer().this_stmt_updated_symbols):
        if dsym.should_mark_stale(self.updated_sym):
            # record which ancestor is fresher, and as of which timestamp
            dsym.fresher_ancestors.add(self.updated_sym)
            dsym.fresher_ancestor_timestamps.add(
                self.updated_sym.timestamp)
            dsym.required_timestamp = Timestamp.current()
            self._propagate_staleness_to_namespace_parents(
                dsym, skip_seen_check=True)
            self._propagate_staleness_to_namespace_children(
                dsym, skip_seen_check=True)
    # recurse into the children; these calls do honor the seen check
    for child in self._non_class_to_instance_children(dsym):
        logger.warning("propagate %s %s to %s", dsym, dsym.obj_id, child)
        self._propagate_staleness_to_deps(child)
def _handle_store_target(self, target: ast.AST, value: ast.AST, skip_namespace_check: bool = False):
    """Dispatch handling of one store target.

    A list or tuple target is unpacked from the namespace registered for the
    saved right-hand-side object when there is one and the namespace check is
    not skipped; otherwise it is unpacked from the symbols resolved for
    ``value``. Any other target is handled as a plain assignment with
    literal-namespace fixup enabled.
    """
    if not isinstance(target, (ast.List, ast.Tuple)):
        self._handle_assign_target_for_deps(
            target,
            resolve_rval_symbols(value),
            maybe_fixup_literal_namespace=True,
        )
        return

    if skip_namespace_check:
        # the lookup is skipped entirely rather than having its result ignored
        rhs_namespace = None
    else:
        rhs_namespace = nbs().namespaces.get(id(tracer().saved_assign_rhs_obj), None)

    if rhs_namespace is not None:
        self._handle_store_target_tuple_unpack_from_namespace(target, rhs_namespace)
    else:
        self._handle_store_target_tuple_unpack_from_deps(
            target, resolve_rval_symbols(value)
        )
def _handle_assign_target_for_deps(
    self,
    target: ast.AST,
    deps: Set[DataSymbol],
    maybe_fixup_literal_namespace=False,
) -> None:
    """Upsert the symbol for a single assignment target with the given deps.

    The target is resolved through the tracer and upserted into the resolved
    scope. When resolution raises ``KeyError`` the failure is logged at info
    level and nothing is upserted. With ``maybe_fixup_literal_namespace`` a
    namespace registered for the stored object that still carries the
    anonymous scope name is renamed after the target and re-parented to the
    resolved scope.
    """
    try:
        resolved = tracer().resolve_store_or_del_data_for_target(target, self.frame)
        scope, name, obj, is_subscript = resolved
    except KeyError as e:
        # e.g., slices aren't implemented yet
        # use suppressed log level to avoid noise to user
        logger.info("Exception: %s", e)
        return

    upserted = scope.upsert_data_symbol_for_name(
        name,
        obj,
        deps,
        self.stmt_node,
        is_subscript=is_subscript,
    )
    logger.info("sym %s upserted to scope %s has parents %s", upserted, scope, upserted.parents)

    if not maybe_fixup_literal_namespace:
        return
    literal_ns = nbs().namespaces.get(id(obj), None)
    if literal_ns is not None and literal_ns.scope_name == NamespaceScope.ANONYMOUS:
        literal_ns.scope_name = str(name)
        literal_ns.parent_scope = scope
def current(cls) -> "Timestamp":
    """Build a timestamp from the current cell counter and module statement counter."""
    # TODO: shouldn't have to go through nbs() singleton to get the cell counter,
    #  but the dependency structure prevents us from importing from nbsafety.data_model.code_cell
    cell_num = nbs().cell_counter()
    stmt_num = tracer().module_stmt_counter()
    return cls(cell_num, stmt_num)
def visit_Call(self, node):
    """Collect the symbols loaded by a call expression.

    An attribute or subscript callee is visited first. Then the symbols the
    tracer resolved for the callee and for the call itself are added, and
    finally the call's positional and keyword arguments are visited.
    """
    callee = node.func
    if isinstance(callee, (ast.Attribute, ast.Subscript)):
        self.visit(callee)
    for loaded_node in (callee, node):
        self.symbols.extend(tracer().resolve_loaded_symbols(loaded_node))
    self.generic_visit([node.args, node.keywords])
def visit_Name(self, node: ast.Name):
    """Add the symbols the tracer resolved for a name node."""
    resolved = tracer().resolve_loaded_symbols(node)
    self.symbols.extend(resolved)
def visit_arg(self, node: ast.arg):
    """Add the symbols the tracer resolved for an argument's name.

    The lookup is keyed on ``node.arg`` (the name string), not on the node.
    """
    resolved = tracer().resolve_loaded_symbols(node.arg)
    self.symbols.extend(resolved)
def visit_Starred(self, node: ast.Starred):
    """Add the symbols the tracer resolved for a starred expression."""
    resolved = tracer().resolve_loaded_symbols(node)
    self.symbols.extend(resolved)