コード例 #1
0
    def get_post_call_scope(self):
        """Return the scope to use after the call associated with this statement.

        * For a class definition, a child namespace of the current frame's
          original scope is created right away, appended to the tracer's
          pending class namespaces, and returned.
        * Otherwise the function symbol registered for this statement node is
          looked up; when none is registered the original scope is returned.
        * When the registered symbol is not a function, ``TypeError`` is raised
          in develop mode; otherwise a warning is logged and the original scope
          is returned.
        * For a function symbol, call-argument symbols are created (unless this
          statement already finished) and the symbol's call scope is returned.
        """
        enclosing_scope = tracer().cur_frame_original_scope
        node = self.stmt_node
        if isinstance(node, ast.ClassDef):
            # a class body needs its scope before the ClassDef has finished
            # executing, so the namespace is created eagerly
            class_ns = Namespace.make_child_namespace(enclosing_scope, node.name)
            tracer().pending_class_namespaces.append(class_ns)
            return class_ns

        is_func_def = isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
        func_name = node.name if is_func_def else None
        func_sym = nbs().statement_to_func_cell.get(id(node), None)
        if func_sym is None:
            # TODO: brittle; assumes any user-defined and traceable function will always be present; is this safe?
            return enclosing_scope

        if func_sym.is_function:
            if not self.finished:
                func_sym.create_symbols_for_call_args()
            return func_sym.call_scope

        # only reached when the registered symbol is not a function
        msg = "got non-function symbol %s for name %s" % (func_sym.full_path, func_name)
        if nbs().is_develop:
            raise TypeError(msg)
        logger.warning(msg)
        return enclosing_scope
コード例 #2
0
def test_updated_namespace_after_subscript_dep_removed():
    """Check stale/fresh reporting before and after ``x`` stops being a key of ``d``.

    Cell 1 first builds ``d`` with ``x`` as a key, then is rewritten to use
    the literal ``5``; afterwards re-running the cell that assigns ``x`` is
    expected to leave no cell fresh or stale.
    """
    cells = dict(enumerate([
        'x = 5',
        'd = {x: 5}',
        'logging.info(d[5])',
        'x = 9',
    ]))
    for cell_id, source in cells.items():
        run_cell(source, cell_id)
    result = nbs().check_and_link_multiple_cells(cells)
    assert result['stale_cells'] == [2]
    assert result['fresh_cells'] == [1]

    # rewrite cell 1 so the dict key is a literal instead of `x`
    cells[1] = 'd = {5: 5}'
    run_cell(cells[1], 1)
    result = nbs().check_and_link_multiple_cells(cells)
    assert result['stale_cells'] == []
    assert result['fresh_cells'] == [2]

    # re-running the reader cell clears its fresh status
    run_cell(cells[2], 2)
    result = nbs().check_and_link_multiple_cells(cells)
    assert result['stale_cells'] == []
    assert result['fresh_cells'] == []

    # re-running the cell that assigns `x` should no longer affect anything
    run_cell(cells[0], 0)
    result = nbs().check_and_link_multiple_cells(cells)
    assert result['stale_cells'] == []
    assert result['fresh_cells'] == [], 'got %s' % result['fresh_cells']
コード例 #3
0
 def _handle_aliases(self):
     """Move this symbol from the alias set of its cached object id to that of its current one.

     The alias entry for the cached id is deleted once it becomes empty.
     """
     stale_key = self.cached_obj_id
     previous_aliases = nbs().aliases.get(stale_key, None)
     if previous_aliases is not None:
         previous_aliases.discard(self)
         if not previous_aliases:
             # nothing refers to the old object id any more; drop the empty entry
             del nbs().aliases[stale_key]
     nbs().aliases[self.obj_id].add(self)
コード例 #4
0
 def should_preserve_timestamp(self, prev_obj: Optional[Any]) -> bool:
     """Decide whether this symbol's timestamp should be kept instead of bumped.

     The checks run in order and the first one that matches decides the
     result; ``False`` means "bump the timestamp".

     Args:
         prev_obj: the object this symbol referred to previously, ``None``,
             or ``DataSymbol.NULL``.

     Returns:
         ``True`` when the cached object is still in sync (or has the same id),
         when both old and new values are "empty" (``self.obj is None`` and
         ``prev_obj is DataSymbol.NULL``), or when old and new objects have the
         same type, the same size bound (within ``sizing.MAX_SIZE``) and
         compare equal. ``False`` in every other case.
     """
     if nbs().mut_settings.exec_mode == ExecutionMode.REACTIVE:
         # always bump timestamps for reactive mode
         return False
     if nbs().mut_settings.exec_schedule == ExecutionSchedule.DAG_BASED:
         # always bump timestamps for dag schedule
         return False
     if prev_obj is None:
         # no previous object to compare against
         return False
     if (
         nbs().blocked_reactive_timestamps_by_symbol.get(self, -1)
         == self.timestamp.cell_num
     ):
         # the cell number recorded for this symbol in the blocked-reactive map
         # equals the symbol's current timestamp cell number
         # NOTE(review): exact meaning of "blocked reactive" is defined elsewhere -- confirm
         return False
     if not self._cached_out_of_sync or self.obj_id == self.cached_obj_id:
         # cached object is in sync, or it is the very same object id
         return True
     if self.obj is None or prev_obj is DataSymbol.NULL:
         # preserve only when *both* sides are "empty"
         return self.obj is None and prev_obj is DataSymbol.NULL
     obj_type = type(self.obj)
     prev_type = type(prev_obj)
     if obj_type != prev_type:
         return False
     # objects whose size bound exceeds MAX_SIZE are never compared for equality
     # (presumably to avoid an expensive comparison -- TODO confirm)
     obj_size_ubound = sizing.sizeof(self.obj)
     if obj_size_ubound > sizing.MAX_SIZE:
         return False
     cached_obj_size_ubound = sizing.sizeof(prev_obj)
     if cached_obj_size_ubound > sizing.MAX_SIZE:
         return False
     # NOTE(review): `self.obj == prev_obj` may yield a non-bool for types that
     # overload `==` element-wise -- confirm callers only use truthiness
     return (obj_size_ubound == cached_obj_size_ubound) and self.obj == prev_obj
コード例 #5
0
 def update_obj_ref(self, obj, refresh_cached: bool = True) -> None:
     """Point this symbol at ``obj`` and reconcile liveness, namespace and alias bookkeeping.

     Args:
         obj: the new object this symbol should reference.
         refresh_cached: when true, ``_refresh_cached_obj()`` is called last.
     """
     logger.info("%s update obj ref to %s", self, obj)
     self._tombstone = False
     self._cached_out_of_sync = True
     if (
         nbs().settings.mark_typecheck_failures_unsafe
         and self.cached_obj_type != type(obj)
     ):
         # the new object's type differs from the cached type: invalidate
         # typecheck results of every cell where this symbol is live
         for cell in self.cells_where_live:
             cell.invalidate_typecheck_result()
     self.cells_where_shallow_live.clear()
     self.cells_where_deep_live.clear()
     self.obj = obj
     # only when the symbol previously had a cached object id and it now differs
     if self.cached_obj_id is not None and self.cached_obj_id != self.obj_id:
         new_ns = nbs().namespaces.get(self.obj_id, None)
         # don't overwrite existing namespace for this obj
         old_ns = nbs().namespaces.get(self.cached_obj_id, None)
         if (
             old_ns is not None
             and (
                 new_ns is None or not new_ns.max_descendent_timestamp.is_initialized
             )
             and old_ns.full_namespace_path == self.full_namespace_path
         ):
             # the old namespace lives at this symbol's path and the new object
             # has no namespace yet (or one with an uninitialized max descendent
             # timestamp): carry the old namespace's symbols over
             if new_ns is None:
                 logger.info("create fresh copy of namespace %s", old_ns)
                 new_ns = old_ns.fresh_copy(obj)
             else:
                 new_ns.scope_name = old_ns.scope_name
                 new_ns.parent_scope = old_ns.parent_scope
             old_ns.transfer_symbols_to(new_ns)
         # re-register this symbol under the alias set of its new object id
         self._handle_aliases()
     if refresh_cached:
         self._refresh_cached_obj()
コード例 #6
0
    def _make_lval_data_symbols_old(self) -> None:
        """Create or update data symbols for every assignment target of this statement.

        For each ``(target, dep_node)`` edge returned by ``get_symbol_edges``
        the rvalue symbols are resolved, the store location for the target is
        resolved via the tracer, and the symbol is upserted in that scope.
        ``KeyError`` during a store is swallowed (logged only in develop mode);
        any other exception is logged and re-raised only in test mode.
        """
        symbol_edges = get_symbol_edges(self.stmt_node)
        # augmented assignment (e.g. `x += 1`) must not overwrite existing deps
        should_overwrite = not isinstance(self.stmt_node, ast.AugAssign)
        is_function_def = isinstance(self.stmt_node,
                                     (ast.FunctionDef, ast.AsyncFunctionDef))
        is_class_def = isinstance(self.stmt_node, ast.ClassDef)
        is_import = isinstance(self.stmt_node, (ast.Import, ast.ImportFrom))
        if is_function_def or is_class_def:
            # a def / class statement is expected to yield exactly one edge
            assert len(symbol_edges) == 1
            # assert not lval_symbol_refs.issubset(rval_symbol_refs)

        for target, dep_node in symbol_edges:
            rval_deps = resolve_rval_symbols(dep_node)
            logger.info("create edges from %s to %s", rval_deps, target)
            if is_class_def:
                # bind the pending class scope to the class object that now
                # exists in the frame locals and register it by object id
                assert self.class_scope is not None
                class_ref = self.frame.f_locals[self.stmt_node.name]
                self.class_scope.obj = class_ref
                nbs().namespaces[id(class_ref)] = self.class_scope
            try:
                (
                    scope,
                    name,
                    obj,
                    is_subscript,
                    excluded_deps,
                ) = tracer().resolve_store_data_for_target(target, self.frame)
                scope.upsert_data_symbol_for_name(
                    name,
                    obj,
                    rval_deps - excluded_deps,
                    self.stmt_node,
                    overwrite=should_overwrite,
                    is_subscript=is_subscript,
                    is_function_def=is_function_def,
                    is_import=is_import,
                    class_scope=self.class_scope,
                    # stores made by a `for` statement do not propagate
                    propagate=not isinstance(self.stmt_node, ast.For),
                )
                if isinstance(
                        self.stmt_node,
                    (ast.FunctionDef, ast.ClassDef, ast.AsyncFunctionDef),
                ):
                    self._handle_reactive_store(self.stmt_node)
            except KeyError as ke:
                # e.g., slices aren't implemented yet
                # put logging behind flag to avoid noise to user
                if nbs().is_develop:
                    logger.warning(
                        "keyerror for %s",
                        ast.dump(target)
                        if isinstance(target, ast.AST) else target,
                    )
                # if nbs().is_test:
                #     raise ke
            except Exception as e:
                # best effort: log, and only fail loudly when running tests
                logger.warning("exception while handling store: %s", e)
                if nbs().is_test:
                    raise e
コード例 #7
0
 def finished_execution_hook(self):
     """Run end-of-statement bookkeeping; does nothing when ``self.finished`` is already set.

     Records the statement id as seen on the tracer, handles dependencies,
     runs the tracer's after-statement reset hook, and triggers namespace
     garbage collection.
     """
     if not self.finished:
         tracer().seen_stmts.add(self.stmt_id)
         self.handle_dependencies()
         tracer().after_stmt_reset_hook()
         nbs()._namespace_gc()
コード例 #8
0
def run_cell(cell, cell_id=None, **kwargs):
    """Run ``cell``, first announcing ``cell_id`` as the active cell when one is given.

    Mocks the `change active cell` portion of the comm protocol; remaining
    keyword arguments are forwarded to ``run_cell_``.
    """
    if cell_id is not None:
        request = {'type': 'change_active_cell', 'active_cell_id': cell_id}
        nbs().handle(request)
    run_cell_(cell, **kwargs)
コード例 #9
0
def test_update_list_elem():
    """Mutating one list element (via ``x`` or ``foo``) marks the expected cells fresh.

    ``x`` and ``foo`` both end up referring to the last ``Foo`` appended to
    ``lst``; the test calls ``inc()`` through each name and checks which cells
    are reported fresh, with no cell ever reported stale.
    """
    force_subscript_symbol_creation = True
    cells = {
        0: """
            class Foo:
                def __init__(self):
                    self.counter = 0
                    self.dummy = 0
                    
                def inc(self):
                    self.counter += 1
            """,
        1: """
            lst = []
            for i in range(5):
                x = Foo()
                lst.append(x)
            """,
        2: """
            for foo in lst:
                foo.inc()
            """,
        3: "logging.info(lst)",
    }

    run_all_cells(cells)

    result = nbs().check_and_link_multiple_cells()
    assert result.stale_cells == set()
    assert result.fresh_cells == set(), "got %s" % result.fresh_cells

    # mutate the last element through `x`
    cells[4] = "x.inc()"
    run_cell(cells[4], 4)

    result = nbs().check_and_link_multiple_cells()
    assert result.stale_cells == set(), "got %s" % result.stale_cells
    assert result.fresh_cells == {2, 3}, "got %s" % result.fresh_cells

    # mutate the same element through `foo`
    cells[5] = "foo.inc()"
    run_cell(cells[5], 5)
    result = nbs().check_and_link_multiple_cells()
    assert result.stale_cells == set(), "got %s" % result.stale_cells
    assert result.fresh_cells == {2, 3, 4}, "got %s" % result.fresh_cells

    if force_subscript_symbol_creation:
        # reading lst[-1] must not change what is fresh
        cells[6] = "lst[-1]"
        run_cell(cells[6], 6)
        result = nbs().check_and_link_multiple_cells()
        assert result.stale_cells == set()
        assert result.fresh_cells == {2, 3, 4}

    run_cell(cells[4], 4)
    result = nbs().check_and_link_multiple_cells()
    assert result.stale_cells == set()
    expected_fresh = {2, 3, 5} | ({6} if force_subscript_symbol_creation else set())
    assert result.fresh_cells == expected_fresh
コード例 #10
0
 def get_ref_count(self) -> int:
     """Return an adjusted reference count for this symbol's object.

     Returns ``-1`` when the symbol holds ``None`` or ``DataSymbol.NULL``.
     Otherwise starts from ``sys.getrefcount(self.obj) - 1`` and subtracts
     one per alias registered for the object id, plus one more when a
     namespace holding a real object is registered for that id.
     """
     if self.obj is None or self.obj is DataSymbol.NULL:
         return -1
     # NOTE: getrefcount counts every live reference, so binding `self.obj`
     # to a local before this call would change the result; the `- 1`
     # presumably discounts the temporary argument reference -- TODO confirm
     total = sys.getrefcount(self.obj) - 1
     # discount one reference per alias symbol registered for this object id
     total -= len(nbs().aliases[self.obj_id])
     ns = nbs().namespaces.get(self.obj_id, None)
     if ns is not None and ns.obj is not None and ns.obj is not DataSymbol.NULL:
         # the namespace registered for this id also holds the object
         total -= 1
     return total
コード例 #11
0
def test_update_list_elem():
    """Mutating one list element (via ``x`` or ``foo``) marks the expected cells fresh.

    Same scenario as the set-based variant, but the response is a dict whose
    ``'stale_cells'`` / ``'fresh_cells'`` entries are compared against lists.
    """
    force_subscript_symbol_creation = True
    cells = {
        0: """
class Foo:
    def __init__(self):
        self.counter = 0
        self.dummy = 0
        
    def inc(self):
        self.counter += 1""",
        1: """
lst = []
for i in range(5):
    x = Foo()
    lst.append(x)""",
        2: """
for foo in lst:
    foo.inc()""",
        3: 'logging.info(lst)',
    }

    for cell_id, source in cells.items():
        run_cell(source, cell_id)

    result = nbs().check_and_link_multiple_cells(cells)
    assert result['stale_cells'] == []
    assert result['fresh_cells'] == []

    # mutate the last element through `x`
    cells[4] = 'x.inc()'
    run_cell(cells[4], 4)

    result = nbs().check_and_link_multiple_cells(cells)
    assert result['stale_cells'] == [], 'got %s' % result['stale_cells']
    assert result['fresh_cells'] == [2, 3]

    # mutate the same element through `foo`
    cells[5] = 'foo.inc()'
    run_cell(cells[5], 5)
    result = nbs().check_and_link_multiple_cells(cells)
    assert result['stale_cells'] == []
    assert result['fresh_cells'] == [2, 3, 4]

    if force_subscript_symbol_creation:
        # reading lst[-1] must not change what is fresh
        cells[6] = 'lst[-1]'
        run_cell(cells[6], 6)
        result = nbs().check_and_link_multiple_cells(cells)
        assert result['stale_cells'] == []
        assert result['fresh_cells'] == [2, 3, 4]

    run_cell(cells[4], 4)
    result = nbs().check_and_link_multiple_cells(cells)
    assert result['stale_cells'] == []
    expected_fresh = [2, 3, 5]
    if force_subscript_symbol_creation:
        expected_fresh.append(6)
    assert result['fresh_cells'] == expected_fresh
コード例 #12
0
def override_settings(**kwargs):
    """Temporarily replace the notebook settings with a copy updated by ``kwargs``.

    Generator that yields exactly once; the previous settings object is put
    back when the generator is resumed or closed, even on error.
    NOTE(review): presumably wrapped by ``contextlib.contextmanager`` outside
    this view -- confirm.
    """
    original = nbs().settings
    merged = {**original._asdict(), **kwargs}
    overridden = NotebookSafetySettings(**merged)
    try:
        nbs().settings = overridden
        yield
    finally:
        nbs().settings = original
コード例 #13
0
 def visit(self, node: ast.AST):
     """Visit ``node`` via the parent class, then hand the copied AST to the current cell.

     Any exception is reported to the notebook state, its traceback is
     printed, and it is re-raised.
     """
     try:
         result = super().visit(node)
         record_ast = cells().current_cell().to_ast
         copied_module = cast(ast.Module, self.orig_to_copy_mapping[id(node)])
         record_ast(override=copied_module)
     except Exception as exc:
         nbs().set_exception_raised_during_execution(exc)
         traceback.print_exc()
         raise exc
     return result
コード例 #14
0
def test_int_change_to_str_triggers_typecheck():
    """Rebinding ``b`` from an int to a str flags the cell that computes ``a + b`` for typechecking."""
    setup_sources = ("a = 1", "b = 2", "logging.info(a + b)")
    for cell_id, source in enumerate(setup_sources, start=1):
        run_cell(source, cell_id)
        assert not get_cell_ids_needing_typecheck()

    # change the type of `b`; cell 3 uses it in an addition
    run_cell('b = "b"', 4)
    assert get_cell_ids_needing_typecheck() == {3}

    # after the check pass nothing is pending and cell 3's cached result is False
    nbs().check_and_link_multiple_cells()
    assert not get_cell_ids_needing_typecheck()
    assert cells().from_id(3)._cached_typecheck_result is False
コード例 #15
0
    def handle_dependencies(self) -> None:
        """Apply the dependency effects of the statement that just ran.

        Does nothing when dependency tracking is disabled. Otherwise every
        mutation recorded by the tracer is replayed onto the affected symbols,
        and then the statement itself is handled: lvalue symbols are created
        for stores, ``del`` statements are handled, and for anything else the
        rvalue symbols are resolved so their usage timestamps are bumped.
        """
        if not nbs().dependency_tracking_enabled:
            return
        # each recorded mutation is a 4-tuple: (id of mutated object, event kind,
        # symbols of the call arguments, objects of the call arguments)
        for mutated_obj_id, mutation_event, mutation_arg_dsyms, mutation_arg_objs in tracer().mutations:
            logger.info("mutation %s %s %s %s", mutated_obj_id, mutation_event, mutation_arg_dsyms, mutation_arg_objs)
            update_usage_info(mutation_arg_dsyms)
            if mutation_event == MutationEvent.arg_mutate:
                # the arguments themselves were mutated: mark each known
                # argument symbol as mutated, with no new dependencies
                for mutated_sym in mutation_arg_dsyms:
                    if mutated_sym is None:
                        continue
                    # TODO: happens when module mutates args
                    #  should we add module as a dep in this case?
                    mutated_sym.update_deps(set(), overwrite=False, mutated=True)
                continue

            # NOTE: this next block is necessary to ensure that we add the argument as a namespace child
            # of the mutated symbol. This helps to avoid propagating through to dependency children that are
            # themselves namespace children.
            if mutation_event == MutationEvent.list_append and len(mutation_arg_objs) == 1:
                namespace_scope = nbs().namespaces.get(mutated_obj_id, None)
                mutated_sym = nbs().get_first_full_symbol(mutated_obj_id)
                if mutated_sym is not None:
                    mutated_obj = mutated_sym.get_obj()
                    mutation_arg_obj = next(iter(mutation_arg_objs))
                    # TODO: replace int check w/ more general "immutable" check
                    # NOTE(review): the check below is a None check, not an int check;
                    # the TODO text above may be stale
                    if mutation_arg_obj is not None:
                        if namespace_scope is None:
                            # no namespace registered for the list yet: create one
                            namespace_scope = NamespaceScope(
                                mutated_obj,
                                mutated_sym.name,
                                parent_scope=mutated_sym.containing_scope
                            )
                        # the appended element lives at the last index of the list
                        logger.info("upsert %s to %s", len(mutated_obj) - 1, namespace_scope)
                        namespace_scope.upsert_data_symbol_for_name(
                            len(mutated_obj) - 1,
                            mutation_arg_obj,
                            set(),
                            self.stmt_node,
                            overwrite=False,
                            is_subscript=True,
                            propagate=False
                        )
            # TODO: add mechanism for skipping namespace children in case of list append
            # every alias of the mutated object is marked mutated and picks up
            # the argument symbols as dependencies
            update_usage_info(nbs().aliases[mutated_obj_id])
            for mutated_sym in nbs().aliases[mutated_obj_id]:
                mutated_sym.update_deps(mutation_arg_dsyms, overwrite=False, mutated=True)
        if self._contains_lval():
            self._make_lval_data_symbols()
        elif isinstance(self.stmt_node, ast.Delete):
            self._handle_delete()
        else:
            # make sure usage timestamps get bumped
            resolve_rval_symbols(self.stmt_node)
コード例 #16
0
 def init_metadata(self, parent):
     """
     Read the cell id and tags from the execution request's metadata and
     forward them to the notebook state; the metadata itself is not changed
     and the parent implementation still produces the return value.
     """
     request_meta = parent.get("metadata", {})
     active_cell_id = request_meta.get("cellId", None)
     if active_cell_id is not None:
         nbs().set_active_cell(active_cell_id)
     # tags are always forwarded, as an empty tuple when the request has none
     cell_tags = tuple(request_meta.get("tags", ()))
     nbs().set_tags(cell_tags)
     return super().init_metadata(parent)
コード例 #17
0
def trace_messages(line_: str):
    """Enable or disable trace messages from a one-word magic argument.

    ``on`` / ``enable*`` turns them on and ``off`` / ``disable*`` turns them
    off (case-insensitive). Anything else, or an argument count other than
    one, prints the usage string and changes nothing.
    """
    usage = 'Usage: %safety trace_messages [enable|disable]'
    tokens = line_.split()
    if len(tokens) != 1:
        print(usage)
        return
    setting = tokens[0].lower()
    if setting == 'on' or setting.startswith('enable'):
        enabled = True
    elif setting == 'off' or setting.startswith('disable'):
        enabled = False
    else:
        print(usage)
        return
    nbs().trace_messages_enabled = enabled
コード例 #18
0
 def visit(self, node: 'ast.AST'):
     """Run the statement mapper, the eavesdropper and the statement inserter over ``node``.

     Returns the transformed node. Any exception is reported to the
     notebook state, its traceback is printed, and it is re-raised.
     """
     try:
         cell_statement_cache = nbs().statement_cache[nbs().cell_counter()]
         build_mapping = StatementMapper(cell_statement_cache, nbs().ast_node_by_id)
         orig_to_copy = build_mapping(node)
         # very important that the eavesdropper does not create new ast nodes for ast.stmt (but just
         # modifies existing ones), since StatementInserter relies on being able to map these
         eavesdropped = AstEavesdropper(orig_to_copy).visit(node)
         transformed = StatementInserter(orig_to_copy).visit(eavesdropped)
     except Exception as exc:
         nbs().set_ast_transformer_raised(exc)
         traceback.print_exc()
         raise exc
     return transformed
コード例 #19
0
 def _propagate_staleness_to_deps(self, dsym: DataSymbol, skip_seen_check=False):
     """Mark ``dsym`` stale if warranted, then recurse into its non-class-to-instance children.

     A symbol already in ``self.seen`` is skipped unless ``skip_seen_check``
     is set. Symbols listed in the notebook's updated symbols are never
     marked, but their children are still visited.
     """
     already_seen = not skip_seen_check and dsym in self.seen
     if already_seen:
         return
     self.seen.add(dsym)
     needs_marking = (
         dsym not in nbs().updated_symbols
         and dsym.should_mark_stale(self.updated_sym)
     )
     if needs_marking:
         dsym.fresher_ancestors.add(self.updated_sym)
         dsym.required_cell_num = nbs().cell_counter()
         # dsym was just added to `seen`, so the namespace walks must bypass the seen check
         self._propagate_staleness_to_namespace_parents(dsym, skip_seen_check=True)
         self._propagate_staleness_to_namespace_children(dsym, skip_seen_check=True)
     for dependent in self._non_class_to_instance_children(dsym):
         logger.warning('propagate %s %s to %s', dsym, dsym.obj_id, dependent)
         self._propagate_staleness_to_deps(dependent)
コード例 #20
0
def trace_messages(line_: str) -> None:
    """Enable or disable trace messages from a one-word magic argument.

    ``on`` / ``enable*`` turns them on and ``off`` / ``disable*`` turns them
    off (case-insensitive). Anything else, or an argument count other than
    one, logs the usage string as a warning and changes nothing.
    """
    usage = "Usage: %safety trace_messages [enable|disable]"
    tokens = line_.split()
    if len(tokens) != 1:
        logger.warning(usage)
        return
    setting = tokens[0].lower()
    if setting == "on" or setting.startswith("enable"):
        enabled = True
    elif setting == "off" or setting.startswith("disable"):
        enabled = False
    else:
        logger.warning(usage)
        return
    nbs().trace_messages_enabled = enabled
コード例 #21
0
def set_highlights(cmd: str, rest: str) -> None:
    """Turn highlights on or off.

    ``cmd`` of ``hls`` / ``nohls`` decides directly. For any other ``cmd``
    the lower-cased ``rest`` decides: ``on`` / ``enable*`` enables and
    ``off`` / ``disable*`` disables. Anything else logs the usage string
    as a warning and changes nothing.
    """
    if cmd == "hls":
        enabled = True
    elif cmd == "nohls":
        enabled = False
    else:
        # `rest` is only consulted when the command itself is not decisive
        rest = rest.lower()
        if rest == "on" or rest.startswith("enable"):
            enabled = True
        elif rest == "off" or rest.startswith("disable"):
            enabled = False
        else:
            logger.warning("Usage: %safety [hls|nohls]")
            return
    nbs().mut_settings.highlights_enabled = enabled
コード例 #22
0
def set_highlights(cmd: str, rest: str):
    """Turn highlights on or off.

    ``cmd`` of ``hls`` / ``nohls`` decides directly. For any other ``cmd``
    the lower-cased ``rest`` decides: ``on`` / ``enable*`` enables and
    ``off`` / ``disable*`` disables. Anything else prints the usage string
    and changes nothing.
    """
    usage = 'Usage: %safety [hls|nohls]'
    if cmd in ('hls', 'nohls'):
        enabled = cmd == 'hls'
    else:
        # `rest` is only consulted when the command itself is not decisive
        rest = rest.lower()
        if rest == 'on' or rest.startswith('enable'):
            enabled = True
        elif rest == 'off' or rest.startswith('disable'):
            enabled = False
        else:
            print(usage)
            return
    nbs().mut_settings.highlights_enabled = enabled
コード例 #23
0
def test_dict_clear():
    """After ``d.clear()`` only the cell that reads the whole dict is expected to be fresh."""
    cells = dict(enumerate([
        "d = {0: 0}",
        "logging.info(d[0])",
        "logging.info(d)",
    ]))
    run_all_cells(cells)
    result = nbs().check_and_link_multiple_cells()
    assert result.stale_cells == set()
    assert result.fresh_cells == set()

    run_cell("d.clear()", 3)
    result = nbs().check_and_link_multiple_cells()
    assert result.stale_cells == set()
    assert result.fresh_cells == {2}
コード例 #24
0
def test_list_clear():
    """After ``lst.clear()`` only the cell that reads the whole list is expected to be fresh."""
    cells = dict(enumerate([
        "lst = [0]",
        "logging.info(lst[0])",
        "logging.info(lst)",
    ]))
    run_all_cells(cells)
    result = nbs().check_and_link_multiple_cells()
    assert result.stale_cells == set()
    assert result.fresh_cells == set()

    run_cell("lst.clear()", 3)
    result = nbs().check_and_link_multiple_cells()
    assert result.stale_cells == set()
    assert result.fresh_cells == {2}
コード例 #25
0
def test_increment_by_same_amount():
    """Re-running cell 0 with a new ``x`` makes the ``y`` cell fresh and the cell reading ``y`` stale."""
    cells = dict(enumerate([
        "x = 2",
        "y = x + 1",
        "logging.info(y)",
    ]))
    run_all_cells(cells)
    result = nbs().check_and_link_multiple_cells()
    assert result.stale_cells == set()
    assert result.fresh_cells == set()

    # overwrite cell 0 with a different value for `x`
    run_cell("x = 3", 0)
    result = nbs().check_and_link_multiple_cells()
    assert result.stale_cells == {2}
    assert result.fresh_cells == {1}
コード例 #26
0
def test_liveness_skipped_for_simple_assignment_involving_aliases():
    """Rebinding ``lst`` in a new cell makes both the alias cell and the append cell fresh."""
    cells = dict(enumerate([
        "lst = [1, 2, 3]",
        "lst2 = lst",
        "lst.append(4)",
    ]))
    run_all_cells(cells)
    result = nbs().check_and_link_multiple_cells()
    assert result.stale_cells == set()
    assert result.fresh_cells == set()

    run_cell("lst = [1, 2, 3, 4]", 3)
    result = nbs().check_and_link_multiple_cells()
    assert result.stale_cells == set()
    assert result.fresh_cells == {1, 2}, "got %s" % result.fresh_cells
コード例 #27
0
def test_adhoc_pandas_series_update():
    """Re-assigning the ``foo`` column makes the cell that mutates ``df.foo`` fresh."""
    cells = dict(enumerate([
        "import pandas as pd",
        "df = pd.DataFrame({1: [2, 3], 4: [5, 7]})",
        'df["foo"] = [8, 9]',
        "df.foo.dropna(inplace=True)",
    ]))
    run_all_cells(cells)
    result = nbs().check_and_link_multiple_cells()
    assert result.stale_cells == set()
    assert result.fresh_cells == set()

    run_cell('df["foo"] = [8, 9]', 4)
    result = nbs().check_and_link_multiple_cells()
    assert result.stale_cells == set()
    assert result.fresh_cells == {3}
コード例 #28
0
def test_list_extend():
    """Updating ``lst[1]`` after an ``extend`` makes the ``x`` cell fresh and the cell reading ``x`` stale."""
    cells = dict(enumerate([
        "lst = [0, 1]",
        "x = lst[1] + 1",
        "logging.info(x)",
        "lst.extend([2, 3, 4])",
    ]))
    run_all_cells(cells)
    result = nbs().check_and_link_multiple_cells()
    assert result.stale_cells == set()
    assert result.fresh_cells == set()

    run_cell("lst[1] += 42", 4)
    result = nbs().check_and_link_multiple_cells()
    assert result.stale_cells == {2}
    assert result.fresh_cells == {1}
コード例 #29
0
def test_equal_dict_update_does_induce_fresh_cell():
    """Rebinding ``y`` to an equal dict still leaves the cell that reads ``y`` fresh."""
    cells = dict(enumerate([
        'x = {"foo": 42, "bar": 43}',
        'y = dict(set(x.items()) | set({"baz": 44}.items()))',
        "logging.info(y)",
        "y = dict(y.items())",
    ]))
    run_all_cells(cells)
    result = nbs().check_and_link_multiple_cells()
    assert result.stale_cells == set()
    assert result.fresh_cells == {2}, "got %s" % result.fresh_cells

    # now rebind `y` to a different dict
    run_cell('y = {"foo": 99}', 4)
    result = nbs().check_and_link_multiple_cells()
    assert result.stale_cells == set()
    assert result.fresh_cells == {2, 3}
コード例 #30
0
def test_equal_list_update_does_induce_fresh_cell():
    """Rebinding ``y`` to an equal list still leaves the cell that reads ``y`` fresh."""
    cells = dict(enumerate([
        'x = ["f"] + ["o"] * 10',
        'y = x + list("bar")',
        "logging.info(y)",
        'y = list("".join(y))',
    ]))
    run_all_cells(cells)
    result = nbs().check_and_link_multiple_cells()
    assert result.stale_cells == set()
    assert result.fresh_cells == {2}

    # now rebind `y` to a value of a different type
    run_cell('y = ("f",)', 4)
    result = nbs().check_and_link_multiple_cells()
    assert result.stale_cells == set()
    assert result.fresh_cells == {2, 3}