Ejemplo n.º 1
0
 def visit_Dict(self, node: ast.Dict):
     """Instrument each key and value of a dict literal with trace events.

     Every regular key is replaced by an emitter call for
     ``TraceEvent.dict_key`` and every value by one for
     ``TraceEvent.dict_value``.  Each call receives the visited original
     expression as ``ret``, plus the copy id of its sibling (value for a key,
     key for a value) and of the enclosing dict node.  Entries whose key is
     ``None`` (dict unpacking) keep the ``None`` key, and their value event
     gets a ``None`` constant as ``key_node_id``.

     The rewritten node is finally handed to ``visit_literal`` with inner
     visiting disabled, since the children were already visited here.
     """
     new_keys: List[Optional[ast.expr]] = []
     new_values: List[ast.expr] = []
     for key, value in zip(node.keys, node.values):
         unpacking = key is None
         if not unpacking:
             with fast.location_of(key):
                 key_event = fast.Call(
                     func=self._emitter_ast(),
                     args=[TraceEvent.dict_key.to_ast(), self._get_copy_id_ast(key)],
                     keywords=fast.kwargs(
                         ret=self.visit(key),
                         value_node_id=self._get_copy_id_ast(value),
                         dict_node_id=self._get_copy_id_ast(node),
                     ),
                 )
                 new_keys.append(key_event)
         else:
             # nothing to trace for the key side of an unpacking entry
             new_keys.append(None)
         with fast.location_of(value):
             key_node_id_ast = (
                 fast.NameConstant(None) if unpacking else self._get_copy_id_ast(key)
             )
             value_event = fast.Call(
                 func=self._emitter_ast(),
                 args=[TraceEvent.dict_value.to_ast(), self._get_copy_id_ast(value)],
                 keywords=fast.kwargs(
                     ret=self.visit(value),
                     key_node_id=key_node_id_ast,
                     dict_node_id=self._get_copy_id_ast(node),
                 ),
             )
             new_values.append(value_event)
     node.keys = new_keys
     node.values = new_values
     return self.visit_literal(node, should_inner_visit=False)
Ejemplo n.º 2
0
    def visit_Call(self, node: ast.Call):
        """Rewrite a call so that events fire before and after it runs.

        The callee is visited first (attribute / subscript callees are visited
        with ``call_context=True``), then positional and keyword arguments are
        replaced via ``_get_replacement_args``.  Afterwards the callee is
        wrapped in a ``TraceEvent.before_call`` emitter call and the whole call
        in a ``TraceEvent.after_call`` emitter call.  The result is passed to
        ``_maybe_wrap_symbol_in_before_after_tracing`` with
        ``call_context=True``.
        """
        call_id = id(node)

        with self.attrsub_context(node):
            callee = node.func
            if isinstance(callee, ast.Attribute):
                node.func = self.visit_Attribute(callee, call_context=True)
            elif isinstance(callee, ast.Subscript):
                node.func = self.visit_Subscript(callee, call_context=True)
            else:
                node.func = self.visit(callee)

        # TODO: need a way to rewrite ast of subscript args,
        #  and to process these separately from outer rewrite

        node.args = self._get_replacement_args(node.args, False)
        node.keywords = self._get_replacement_args(node.keywords, True)

        # Arguments must be evaluated with the appropriate active scope: when
        # the callee is an ast.Attribute, the active namespace scope must not
        # be applied to the arguments.
        #
        # To achieve this an "enter argument list" event is emitted whose
        # handler pushes the current active scope while each argument is
        # processed; the matching "end argument list" event restores it.
        #
        # In effect:
        # f(a, b, ..., c) -> trace(f, 'enter argument list')(a, b, ..., c)
        with fast.location_of(node.func):
            traced_callee = fast.Call(
                func=self._emitter_ast(),
                args=[TraceEvent.before_call.to_ast(), self._get_copy_id_ast(call_id)],
                keywords=fast.kwargs(
                    ret=node.func,
                    call_node_id=self._get_copy_id_ast(call_id),
                ),
            )
            node.func = traced_callee

        # f(a, b, ..., c) -> trace(f(a, b, ..., c), 'exit argument list')
        with fast.location_of(node):
            inner_call = node
            node = fast.Call(
                func=self._emitter_ast(),
                args=[TraceEvent.after_call.to_ast(), self._get_copy_id_ast(call_id)],
                keywords=fast.kwargs(
                    ret=inner_call,
                    call_node_id=self._get_copy_id_ast(call_id),
                ),
            )

        return self._maybe_wrap_symbol_in_before_after_tracing(
            node, call_context=True, orig_node_id=call_id
        )
Ejemplo n.º 3
0
 def visit_List_or_Tuple(self, node: Union[ast.List, ast.Tuple]):
     """Instrument the elements of a list or tuple literal.

     Starred elements are kept as they are (not visited, not traced).  When
     the container is not in a load context its elements are only visited.
     Otherwise each element is wrapped in an emitter call for
     ``TraceEvent.list_elt`` or ``TraceEvent.tuple_elt`` carrying the visited
     element as ``ret``, its index, and the container's copy id.  Once a
     starred element has been encountered, the index of every later element
     is emitted as a ``None`` constant instead of its enumeration position.

     The rewritten node is then handed to ``visit_literal`` with inner
     visiting disabled.
     """
     in_load_ctx = isinstance(getattr(node, 'ctx', ast.Load()), ast.Load)
     starred_seen = False
     new_elts: List[ast.expr] = []
     for position, element in enumerate(node.elts):
         if isinstance(element, ast.Starred):
             # TODO: trace starred elts too
             starred_seen = True
             new_elts.append(element)
         elif not in_load_ctx:
             new_elts.append(self.visit(element))
         else:
             with fast.location_of(element):
                 emitter = self._emitter_ast()
                 if isinstance(node, ast.List):
                     event_ast = TraceEvent.list_elt.to_ast()
                 else:
                     event_ast = TraceEvent.tuple_elt.to_ast()
                 event_args = [event_ast, self._get_copy_id_ast(element)]
                 event_kwargs = fast.kwargs(
                     ret=self.visit(element),
                     index=fast.NameConstant(None) if starred_seen else fast.Num(position),
                     container_node_id=self._get_copy_id_ast(node),
                 )
                 new_elts.append(
                     fast.Call(func=emitter, args=event_args, keywords=event_kwargs)
                 )
     node.elts = new_elts
     return self.visit_literal(node, should_inner_visit=False)
Ejemplo n.º 4
0
 def _get_replacement_args(self, args, keywords: bool):
     """Return ``args`` with every argument value wrapped in a trace call.

     :param args: the positional arguments or the keyword entries of a call.
     :param keywords: True when ``args`` holds keyword entries.
     :return: a new list.  Plain positional arguments are replaced by their
         traced expression; starred arguments and keyword entries are kept
         as the same node objects, with their ``value`` replaced in place.

     Each traced expression is an emitter call for ``TraceEvent.argument``
     that receives the visited value as ``ret`` plus ``is_starred`` /
     ``is_kwstarred`` flags.
     """
     rewritten = []
     for arg in args:
         is_starred = isinstance(arg, ast.Starred)
         is_kwstarred = keywords and arg.arg is None
         # keyword entries and starred args carry the real expression in .value
         wraps_value = keywords or is_starred
         inner_expr = arg.value if wraps_value else arg
         with fast.location_of(inner_expr):
             with self.attrsub_context(None):
                 visited_inner = self.visit(inner_expr)
             with self.attrsub_context(None):
                 traced_expr = cast(ast.expr, fast.Call(
                     func=self._emitter_ast(),
                     args=[TraceEvent.argument.to_ast(), self._get_copy_id_ast(inner_expr)],
                     keywords=fast.kwargs(
                         ret=visited_inner,
                         is_starred=fast.NameConstant(is_starred),
                         is_kwstarred=fast.NameConstant(is_kwstarred)
                     ),
                 ))
         if wraps_value:
             arg.value = traced_expr
             rewritten.append(arg)
         else:
             rewritten.append(traced_expr)
     return rewritten
Ejemplo n.º 5
0
    def visit_Attribute_or_Subscript(
        self,
        node: Union[ast.Attribute, ast.Subscript],
        attr_or_sub: ast.expr,
        call_context: bool = False
    ):
        """Wrap the base object of an attribute / subscript access in a trace call.

        ``node.value`` is replaced by an emitter call for
        ``TraceEvent.subscript`` or ``TraceEvent.attribute`` (chosen by the
        type of ``node``).  The call receives the visited original value as
        ``ret``, the given ``attr_or_sub`` expression, the class name of
        ``node.ctx``, the ``call_context`` flag and the copy id of
        ``self._top_level_node_for_symbol``.  When the original value is a
        plain name, its identifier is additionally passed as ``obj_name``.

        The node is then passed through
        ``_maybe_wrap_symbol_in_before_after_tracing`` using the id the node
        had on entry.
        """
        entry_node_id = id(node)
        with fast.location_of(node.value):
            base = node.value
            name_kwargs: List[ast.keyword] = (
                fast.kwargs(obj_name=fast.Str(base.id)) if isinstance(base, ast.Name) else []
            )

            with self.attrsub_context(node):
                emitter = self._emitter_ast()
                if isinstance(node, ast.Subscript):
                    event_ast = TraceEvent.subscript.to_ast()
                else:
                    event_ast = TraceEvent.attribute.to_ast()
                event_args = [event_ast, self._get_copy_id_ast(node.value)]
                event_kwargs = fast.kwargs(
                    ret=self.visit(node.value),
                    attr_or_subscript=attr_or_sub,
                    ctx=fast.Str(node.ctx.__class__.__name__),
                    call_context=fast.NameConstant(call_context),
                    top_level_node_id=self._get_copy_id_ast(self._top_level_node_for_symbol)
                ) + name_kwargs
                node.value = fast.Call(func=emitter, args=event_args, keywords=event_kwargs)
        # end fast.location_of(node.value)

        return self._maybe_wrap_symbol_in_before_after_tracing(node, orig_node_id=entry_node_id)
Ejemplo n.º 6
0
 def visit_Subscript(self, node: ast.Subscript, call_context=False):
     """Convert the slice of a subscript and delegate to the shared handler.

     The slice is passed through ``_maybe_convert_ast_subscript``.  If the
     result is still an ``ast.Slice`` / ``ast.ExtSlice``, its components are
     converted one by one and gathered into a tuple node.  The outcome is
     forwarded to ``visit_Attribute_or_Subscript`` as the subscript
     expression.
     """
     # fall back to the value's location when the slice node has no lineno
     anchor = node.slice if hasattr(node.slice, 'lineno') else node.value
     with fast.location_of(anchor):
         subscript_expr = _maybe_convert_ast_subscript(node.slice)
         if isinstance(subscript_expr, (ast.Slice, ast.ExtSlice)):
             # NOTE(review): a bare ast.Slice arriving here would be asked for
             # `.dims`; presumably the conversion helper never hands one back
             # -- confirm against _maybe_convert_ast_subscript.
             if isinstance(subscript_expr, ast.Tuple):
                 components = subscript_expr.elts
             else:
                 components = subscript_expr.dims  # type: ignore
             converted = [_maybe_convert_ast_subscript(part) for part in components]
             subscript_expr = fast.Tuple(converted)
     return self.visit_Attribute_or_Subscript(
         node, cast(ast.expr, subscript_expr), call_context=call_context
     )
Ejemplo n.º 7
0
    def _maybe_wrap_symbol_in_before_after_tracing(
        self, node, call_context=False, orig_node_id=None, begin_kwargs=None, end_kwargs=None
    ):
        if self._inside_attrsub_load_chain:
            return node
        orig_node = node
        orig_node_id = orig_node_id or id(orig_node)
        begin_kwargs = begin_kwargs or {}
        end_kwargs = end_kwargs or {}

        ctx = getattr(orig_node, 'ctx', ast.Load())
        is_load = isinstance(ctx, ast.Load)

        with fast.location_of(node):
            begin_kwargs['ret'] = self._get_copy_id_ast(orig_node_id)
            if is_load:
                end_ret = orig_node
            elif isinstance(orig_node, (ast.Attribute, ast.Subscript)):
                end_ret = orig_node.value
            else:
                raise TypeError('Unsupported node type for before / after symbol tracing: %s', type(orig_node))
            end_kwargs['ret'] = end_ret
            end_kwargs['ctx'] = fast.Str(ctx.__class__.__name__)
            end_kwargs['call_context'] = fast.NameConstant(call_context)
            node = fast.Call(
                func=self._emitter_ast(),
                args=[
                    TraceEvent.after_complex_symbol.to_ast(),
                    fast.Call(
                        # this will return the node id
                        func=self._emitter_ast(),
                        args=[TraceEvent.before_complex_symbol.to_ast(), self._get_copy_id_ast(orig_node_id)],
                        keywords=fast.kwargs(**begin_kwargs),
                    )
                ],
                keywords=fast.kwargs(**end_kwargs),
            )
            if not is_load:
                if isinstance(orig_node, ast.Attribute):
                    node = fast.Attribute(
                        value=node,
                        attr=orig_node.attr,
                    )
                elif isinstance(orig_node, ast.Subscript):
                    node = fast.Subscript(
                        value=node,
                        slice=orig_node.slice,
                    )
                else:
                    logger.error(
                        'Symbol tracing stores unsupported for node %s with type %s', orig_node, type(orig_node)
                    )
                    assert False
                node.ctx = ast.Store()
        # end location_of(node)
        return node
Ejemplo n.º 8
0
 def _make_tuple_event_for(self, node: ast.AST, event: TraceEvent, orig_node_id=None, **kwargs):
     """Build ``(emit(event, ...), node)[1]`` as an AST expression.

     The returned subscript selects element 1 of a two-element tuple whose
     first element is an emitter call for ``event`` and whose second element
     is ``node``; the emitter call therefore precedes ``node`` in the tuple
     and the expression as a whole yields ``node``'s value.

     :param node: the expression placed in the second tuple slot.
     :param event: the event passed as first positional argument of the
         emitter call.
     :param orig_node_id: passed to ``_get_copy_id_ast``; when falsy,
         ``node`` itself is passed instead.
     :param kwargs: forwarded as keyword arguments of the emitter call.
     """
     with fast.location_of(node):
         emit_call = fast.Call(
             func=self._emitter_ast(),
             args=[event.to_ast(), self._get_copy_id_ast(orig_node_id or node)],
             keywords=fast.kwargs(**kwargs) if kwargs else [],
         )
         pair = fast.Tuple([emit_call, node], ast.Load())
         index: Union[ast.Constant, ast.Num, ast.Index] = fast.Num(1)
         # before Python 3.9 the index is wrapped in an Index node
         if sys.version_info < (3, 9):
             index = fast.Index(index)
         return fast.Subscript(pair, index, ast.Load())
Ejemplo n.º 9
0
 def _get_parsed_append_stmt(self,
                             stmt: 'ast.stmt',
                             ret_expr: 'Optional[ast.expr]' = None) -> 'ast.stmt':
     """Build the statement to be appended after ``stmt``.

     The statement is parsed from ``self._append_stmt_template`` formatted
     with ``stmt_id=id(stmt)``; the first statement of the parsed body is
     used.

     :param stmt: the statement the generated one belongs to.
     :param ret_expr: when given, the keywords of the generated call are
         replaced by a single ``ret=ret_expr`` keyword.
     :return: the generated statement.  Its ``lineno`` is set to ``stmt``'s
         ``end_lineno`` when that is available.
     """
     with fast.location_of(stmt):
         source = self._append_stmt_template.format(stmt_id=id(stmt))
         ret = cast(ast.Expr, fast.parse(source).body[0])
         if ret_expr is not None:
             ret_value = cast(ast.Call, ret.value)
             ret_value.keywords = fast.kwargs(ret=ret_expr)
     # `end_lineno` may be absent or None; only override the line number when
     # a real one is known, so that `lineno` is never overwritten with None.
     end_lineno = getattr(stmt, 'end_lineno', None)
     if end_lineno is not None:
         ret.lineno = end_lineno
     return ret
Ejemplo n.º 10
0
 def visit_Return(self, node: ast.Return):
     """Wrap the returned expression in before / after return trace events.

     The return value is replaced by an emitter call for
     ``TraceEvent.after_return`` whose ``ret`` is the tuple-event expression
     built by ``_make_tuple_event_for`` for ``TraceEvent.before_return``
     around the visited original value.

     A bare ``return`` (``node.value is None``) has no value expression to
     visit or wrap, so the node is returned unchanged.
     """
     if node.value is None:
         # Nothing to instrument; passing None on to the visitor and the
         # event builders is not meaningful.
         return node
     with fast.location_of(node):
         node.value = fast.Call(
             func=self._emitter_ast(),
             args=[TraceEvent.after_return.to_ast(), self._get_copy_id_ast(node.value)],
             keywords=fast.kwargs(
                 ret=self._make_tuple_event_for(
                     self.visit(node.value),
                     TraceEvent.before_return,
                     orig_node_id=id(node.value),
                 ),
             ),
         )
     return node
Ejemplo n.º 11
0
 def visit_literal(self, node: Union[ast.Dict, ast.List, ast.Tuple], should_inner_visit=True):
     """Wrap a container literal in before / after literal trace events.

     :param node: the literal node.
     :param should_inner_visit: when true, the node's children are visited
         via ``generic_visit`` first; callers that already rewrote the
         children pass False.
     :return: for a node that is not in a load context, the (possibly
         visited) node without any wrapping; otherwise an emitter call for
         ``TraceEvent.after_literal`` whose ``ret`` is the tuple-event
         expression for ``TraceEvent.before_literal``.
     """
     # TODO: what about set literals?
     visited: ast.AST = self.generic_visit(node) if should_inner_visit else node
     ctx = getattr(node, 'ctx', ast.Load())
     if not isinstance(ctx, ast.Load):
         return visited
     with fast.location_of(node):
         before_expr = self._make_tuple_event_for(
             visited, TraceEvent.before_literal, orig_node_id=id(node)
         )
         after_call = fast.Call(
             func=self._emitter_ast(),
             args=[TraceEvent.after_literal.to_ast(), self._get_copy_id_ast(node)],
             keywords=fast.kwargs(ret=before_expr),
         )
         return after_call
Ejemplo n.º 12
0
 def visit_Assign(self, node: ast.Assign):
     """Visit the targets and wrap the right-hand side in trace events.

     Targets are visited first, in order.  The value is then visited and
     wrapped in the tuple-event expression for
     ``TraceEvent.before_assign_rhs``, which in turn becomes the ``ret`` of
     an emitter call for ``TraceEvent.after_assign_rhs``.  Both events are
     keyed by the id the value node had before rewriting.
     """
     node.targets = [self.visit(target) for target in node.targets]
     rhs_id = id(node.value)
     with fast.location_of(node.value):
         before_expr = self._make_tuple_event_for(
             self.visit(node.value), TraceEvent.before_assign_rhs, orig_node_id=rhs_id
         )
         node.value = before_expr
         node.value = fast.Call(
             func=self._emitter_ast(),
             args=[TraceEvent.after_assign_rhs.to_ast(), self._get_copy_id_ast(rhs_id)],
             keywords=fast.kwargs(ret=before_expr),
         )
     return node
Ejemplo n.º 13
0
 def visit_Attribute(self, node: ast.Attribute, call_context=False):
     """Delegate an attribute access to ``visit_Attribute_or_Subscript``.

     The attribute name is passed along as a string node created at the
     location of ``node.value``.
     """
     with fast.location_of(node.value):
         attr_name_ast = fast.Str(node.attr)
     return self.visit_Attribute_or_Subscript(node, attr_name_ast, call_context=call_context)
Ejemplo n.º 14
0
 def _get_parsed_prepend_stmt(self, stmt: 'ast.stmt') -> 'ast.stmt':
     """Build the statement to be prepended before ``stmt``.

     The statement is parsed from ``self._prepend_stmt_template`` formatted
     with ``stmt_id=id(stmt)``; the first statement of the parsed body is
     returned.
     """
     with fast.location_of(stmt):
         source = self._prepend_stmt_template.format(stmt_id=id(stmt))
         parsed = fast.parse(source)
         return parsed.body[0]