示例#1
0
    def __init__(self, database, database_name, repository_directory: str = ""):
        """Create the database handle and reset all exploration state.

        Parameters:
            database: passed to DB() as its first argument
            database_name: passed to DB() as its second argument
            repository_directory: str    directory stored for later use; an
                                         empty value falls back to os.getcwd()
        """
        self.db = DB(database, database_name, assertions=True)
        self.repository_directory = repository_directory or os.getcwd()

        # Issue selection state; -1 means nothing has been selected yet.
        self.current_issue_id: int = -1
        self.sources: Set[str] = set()
        self.sinks: Set[str] = set()

        # Trace state. trace_tuples holds the trace of the current issue and
        # trace_tuples_id records the issue id 'trace' was last generated for.
        self.trace_tuples: List[TraceTuple] = []
        self.trace_tuples_id: int = -1
        # Cursor (active frame) and root positions inside trace_tuples.
        self.current_trace_frame_index: int = -1
        self.root_trace_frame_index: int = -1

        # Command name -> bound method; "n" and "p" alias "next" and "prev".
        self.scope_vars: Dict[str, Callable] = {
            "commands": self.help,
            "runs": self.runs,
            "issues": self.issues,
            "set_run": self.set_run,
            "set_issue": self.set_issue,
            "show": self.show,
            "trace": self.trace,
            "next": self.next_cursor_location,
            "n": self.next_cursor_location,
            "prev": self.prev_cursor_location,
            "p": self.prev_cursor_location,
            "expand": self.expand,
            "branch": self.branch,
            "list": self.list_source_code,
        }
 def setUp(self) -> None:
     """Create an in-memory database plus tool, and capture stdout/stderr."""
     self.db = DB("memory")
     self.interactive = Interactive("memory", "")
     # The tool has to use the very same database object as the test.
     self.interactive.db = self.db
     # Replace both standard streams with buffers the test can inspect.
     self.stdout, self.stderr = StringIO(), StringIO()
     sys.stdout, sys.stderr = self.stdout, self.stderr
示例#3
0
def cli(ctx: click.Context, repository: str, database_name: str,
        database_engine: DBType):
    # Build the shared Context (repository, database handle, parser class) and
    # attach it to the click context object so later code can read ctx.obj.
    database = DB(database_engine, database_name, assertions=True)
    ctx.obj = Context(repository=repository,
                      database=database,
                      parser_class=Parser)
    logger.debug(f"Context: {ctx.obj}")
示例#4
0
    def _save(self, database: DB, cls, pk_gen: PrimaryKeyGenerator):
        """Bulk-insert the pending items stored under cls.__name__.

        The items are prepared and sorted inside one session, then inserted in
        batches of self.BATCH_SIZE, each batch in its own committed session.
        """
        with database.make_session() as session:
            prepared = cls.prepare(
                session, pk_gen, consume(self.saving[cls.__name__]))
            # Bulk insert uses executemany, which can only group consecutive
            # items sharing the same keys. Sorting by key list keeps equal
            # shapes adjacent; scattered shapes would cause far more
            # executemany calls and kill performance.
            items = sorted(prepared, key=lambda mapping: list(mapping.keys()))

        # bulk_insert_mappings should only be used for new objects.
        # To update an existing object, just modify its attribute(s)
        # and call session.commit()
        for batch in split_every(self.BATCH_SIZE, items):
            with database.make_session() as session:
                session.bulk_insert_mappings(cls, batch, render_nulls=True)
                session.commit()
示例#5
0
    def __init__(self, database, database_name, repository_directory: str = ""):
        """Create the database handle and reset all exploration state.

        Parameters:
            database: passed to DB() as its first argument
            database_name: passed to DB() as its second argument
            repository_directory: str    directory stored for later use; an
                                         empty value falls back to os.getcwd()
        """
        self.db = DB(database, database_name, assertions=True)
        self.repository_directory = repository_directory or os.getcwd()

        # Selection state; every id starts at -1.
        self.current_run_id: int = -1
        # Trace exploration relies on either of these
        self.current_issue_id: int = -1
        self.current_frame_id: int = -1

        self.sources: Set[str] = set()
        self.sinks: Set[str] = set()
        # Tuples representing the trace of the current issue
        self.trace_tuples: List[TraceTuple] = []
        # Active trace frame of the current trace
        self.current_trace_frame_index: int = -1

        # Names made available to the interactive scope: the two TraceKind
        # constants plus one bound method per command ("n"/"p" are aliases).
        self.scope_vars: Dict[str, Union[Callable, TraceKind]] = {
            "precondition": TraceKind.PRECONDITION,
            "postcondition": TraceKind.POSTCONDITION,
            "commands": self.help,
            "state": self.state,
            "runs": self.runs,
            "issues": self.issues,
            "set_run": self.set_run,
            "set_issue": self.set_issue,
            "show": self.show,
            "trace": self.trace,
            "next": self.next_cursor_location,
            "n": self.next_cursor_location,
            "prev": self.prev_cursor_location,
            "p": self.prev_cursor_location,
            "expand": self.expand,
            "branch": self.branch,
            "list": self.list_source_code,
            "frames": self.frames,
            "set_frame": self.set_frame,
        }
示例#6
0
def analyze(
    ctx: Context,
    run_kind,
    branch,
    commit_hash,
    job_id,
    differential_id,
    previous_issue_handles,
    previous_input,
    linemap,
    store_unused_models,
    input_file,
):
    # Collects the options into a summary dict, loads the input file(s) and
    # runs them through the parse -> model -> trim -> save pipeline.

    # Without an explicit job id, derive one from the differential id.
    if job_id is None and differential_id is not None:
        job_id = "user_input_" + str(differential_id)

    # Store all options in the right places
    summary_blob = {
        "run_kind": run_kind,
        "compress": lambda x: x,
        "repository": ctx.repository,
        "branch": branch,
        "commit_hash": commit_hash,
        "old_linemap_file": linemap,
        "store_unused_models": store_unused_models,
        "job_id": job_id,
    }

    # previous_issue_handles takes precedence: previous_input is only loaded
    # from file when no handles were given.
    if previous_issue_handles:
        handles = AnalysisOutput.from_file(previous_issue_handles)
        summary_blob["previous_issue_handles"] = handles
    elif previous_input:
        previous_input = AnalysisOutput.from_file(previous_input)

    # Construct pipeline
    current_input = AnalysisOutput.from_file(input_file)
    steps = [
        Parser(),
        ModelGenerator(),
        TrimTraceGraph(),
        DatabaseSaver(
            DB(ctx.database_engine, ctx.database_name, assertions=True),
            PrimaryKeyGenerator(),
        ),
    ]
    Pipeline(steps).run((current_input, previous_input), summary_blob)
示例#7
0
    def save_all(self, database: DB, use_lock=False, dbname=""):
        """Save every class that has pending items, in SAVING_CLASSES_ORDER.

        Primary keys for all classes are reserved up front in one session;
        each class is then written through self._save.

        Parameters:
            database: DB    database to write to
            use_lock, dbname: accepted for interface compatibility; this
                              method does not read them
        """
        # Compare the count by value. The previous `is not 0` tested object
        # identity, which only works by accident of CPython's small-int cache
        # and triggers a SyntaxWarning on Python 3.8+.
        saving_classes = [
            cls
            for cls in self.SAVING_CLASSES_ORDER
            if len(self.saving[cls.__name__]) != 0
        ]

        item_counts = {
            cls.__name__: len(self.get_items_to_add(cls)) for cls in saving_classes
        }

        with database.make_session() as session:
            pk_gen = self.primary_key_generator.reserve(
                session, saving_classes, item_counts
            )

        for cls in saving_classes:
            log.info("Saving %s...", cls.__name__)
            self._save(database, cls, pk_gen)
示例#8
0
class Interactive:
    """Interactive commands for browsing static analysis runs, issues and
    traces stored in a database."""

    # @lint-ignore FBPYTHON2
    list_string = "list()"
    # Text printed by commands(); the typo "COMAMND" was corrected to "COMMAND".
    help_message = f"""
Commands =======================================================================

commands()      show this message
help(COMMAND)   more info about a command

runs()          list all completed static analysis runs
set_run(ID)     select a specific run for browsing issues
issues()        list all issues for the selected run
set_issue(ID)   select a specific issue for browsing a trace
show()          show info about selected issue
trace()         show a trace of the selected issue
prev()/p()      move backward within the trace
next()/n()      move forward within the trace
expand()        show alternative trace branches
branch(INDEX)   select a trace branch
{list_string}          show source code at the current trace frame
"""
    welcome_message = "Interactive issue exploration. Type 'commands()' for help."

    # Callee ports that mark the end of a trace.
    LEAF_NAMES = {"source", "sink", "leaf"}

    def __init__(self,
                 database,
                 database_name,
                 repository_directory: str = ""):
        """Create the database handle and reset all exploration state.

        Parameters:
            database: passed to DB() as its first argument
            database_name: passed to DB() as its second argument
            repository_directory: str    directory source files are resolved
                                         against; an empty value falls back to
                                         os.getcwd()
        """
        self.db = DB(database, database_name, assertions=True)
        # Command name -> bound method; "n" and "p" alias "next" and "prev".
        self.scope_vars: Dict[str, Callable] = {
            "commands": self.help,
            "runs": self.runs,
            "issues": self.issues,
            "set_run": self.set_run,
            "set_issue": self.set_issue,
            "show": self.show,
            "trace": self.trace,
            "next": self.next_cursor_location,
            "n": self.next_cursor_location,
            "prev": self.prev_cursor_location,
            "p": self.prev_cursor_location,
            "expand": self.expand,
            "branch": self.branch,
            "list": self.list_source_code,
        }
        self.repository_directory = repository_directory or os.getcwd()

        # -1 until setup() or set_run() picks a run. Initialised here because
        # help() and issues() read it and would otherwise raise AttributeError
        # when called before setup().
        self.current_run_id: int = -1
        # -1 means no issue has been selected (see _verify_issue_selected).
        self.current_issue_id: int = -1
        self.sources: Set[str] = set()
        self.sinks: Set[str] = set()
        # Tuples representing the trace of the current issue
        self.trace_tuples: List[TraceTuple] = []
        # Active trace frame of the current trace
        self.current_trace_frame_index: int = -1
        self.root_trace_frame_index: int = -1
        # The current issue id when 'trace' was last run
        self.trace_tuples_id: int = -1

    def setup(self) -> Dict[str, Callable]:
        """Select the newest finished run and print the welcome banner.

        A warning is emitted when no finished run exists. Returns the command
        table (self.scope_vars).
        """
        with self.db.make_session() as session:
            newest_finished = session.query(func.max(Run.id)).filter(
                Run.status == RunStatus.FINISHED)
            latest_run_id = newest_finished.scalar()

        if latest_run_id.resolved() is None:
            program = os.path.basename(sys.argv[0])
            self.warning(
                f"No runs found. Try running '{program} analyze' first.")

        self.current_run_id = latest_run_id

        banner = "=" * len(self.welcome_message)
        for text in (banner, self.welcome_message, banner):
            print(text)
        return self.scope_vars

    def help(self):
        print(self.help_message)
        print(f"State {'=' * 74}\n")
        print(f"     Database: {self.db.dbtype}:{self.db.dbname}")
        print(f"  Current run: {self.current_run_id}")
        print(f"Current issue: {self.current_issue_id}")

    def runs(self, use_pager=None):
        """List all finished runs, then print how many were found.

        Parameters:
            use_pager: bool    passed to _resolve_pager to choose the output
                               function (None lets _resolve_pager decide)
        """
        pager = self._resolve_pager(use_pager)

        with self.db.make_session() as session:
            finished = session.query(Run).filter(
                Run.status == RunStatus.FINISHED)
            runs = finished.all()

        separator = "-" * 80
        run_output = "\n".join(
            f"Run {run.id}\nDate: {run.date}\n{separator}" for run in runs)

        pager(run_output)
        print(f"Found {len(runs)} runs.")

    def set_run(self, run_id):
        """Select a finished run for browsing issues.

        Parameters:
            run_id: int    id of the run, as listed by runs()
        """
        with self.db.make_session() as session:
            finished = session.query(Run).filter(
                Run.status == RunStatus.FINISHED)
            selected_run = finished.filter(Run.id == run_id).scalar()

        if selected_run is not None:
            self.current_run_id = selected_run.id
            print(f"Set run to {run_id}.")
            return

        self.warning(f"Run {run_id} doesn't exist or is not finished. "
                     "Type 'runs()' for available runs.")

    def set_issue(self, issue_id):
        """Select an issue for trace browsing and show its details.

        Parameters:
            issue_id: int    id of the issue, as listed by issues()
        """
        with self.db.make_session() as session:
            lookup = session.query(IssueInstance).filter(
                IssueInstance.id == issue_id)
            selected_issue = lookup.scalar()

            if selected_issue is None:
                self.warning(f"Issue {issue_id} doesn't exist. "
                             "Type 'issues()' for available issues.")
                return

            # Remember the issue's leaves; trace navigation filters on them.
            source_leaves = self._get_leaves(session, selected_issue,
                                             SharedTextKind.SOURCE)
            self.sources = set(source_leaves)
            sink_leaves = self._get_leaves(session, selected_issue,
                                           SharedTextKind.SINK)
            self.sinks = set(sink_leaves)

        self.current_issue_id = selected_issue.id
        self.current_trace_frame_index = 1  # first one after the source
        print(f"Set issue to {issue_id}.")
        self.show()

    def show(self):
        """ More details about the selected issue.
        """
        if not self._verify_issue_selected():
            return

        with self.db.make_session() as session:
            issue_instance, issue = self._get_current_issue(session)
            sources = self._get_leaves(session, issue_instance,
                                       SharedTextKind.SOURCE)
            sinks = self._get_leaves(session, issue_instance,
                                     SharedTextKind.SINK)

        page.display_page(
            self._create_issue_output_string(issue_instance, issue, sources,
                                             sinks))

    def issues(
        self,
        use_pager: bool = None,
        *,
        codes: Optional[List[int]] = None,
        callables: Optional[List[str]] = None,
        filenames: Optional[List[str]] = None,
    ):
        """Lists issues for the selected run.

        Parameters (all optional):
            use_pager: bool         use a unix style pager for output
            codes: list[int]        issue codes to filter on
            callables: list[str]    callables to filter on (supports wildcards)
            filenames: list[str]    filenames to filter on (supports wildcards)

        String filters support LIKE wildcards (%, _) from SQL:
            % matches anything (like .* in regex)
            _ matches 1 character (like . in regex)

        For example:
            issues(callables=[
                "%occurs.anywhere%",
                "%at.end",
                "at.start%",
                "etc.",
            ])
        """
        pager = self._resolve_pager(use_pager)

        with self.db.make_session() as session:
            query = session.query(IssueInstance, Issue)
            query = query.filter(IssueInstance.run_id == self.current_run_id)
            query = query.join(Issue, IssueInstance.issue_id == Issue.id)

            # Process filters. Each one must be a list; the first filter of
            # the wrong type produces a warning and aborts the listing.

            if codes is not None:
                if not isinstance(codes, list):
                    self.warning("'codes' should be a list.")
                    return
                query = query.filter(Issue.code.in_(codes))

            # The two string filters work the same way: any pattern may match.
            like_filters = (
                ("callables", callables, Issue.callable),
                ("filenames", filenames, Issue.filename),
            )
            for filter_name, patterns, column in like_filters:
                if patterns is None:
                    continue
                if not isinstance(patterns, list):
                    self.warning(f"'{filter_name}' should be a list.")
                    return
                query = query.filter(
                    or_(*[column.like(pattern) for pattern in patterns]))

            issues = query.options(joinedload(IssueInstance.message)).all()
            sources_list = [
                self._get_leaves(session, instance, SharedTextKind.SOURCE)
                for instance, _ in issues
            ]
            sinks_list = [
                self._get_leaves(session, instance, SharedTextKind.SINK)
                for instance, _ in issues
            ]

        issue_strings = []
        for (instance, issue), sources, sinks in zip(issues, sources_list,
                                                     sinks_list):
            issue_strings.append(
                self._create_issue_output_string(instance, issue, sources,
                                                 sinks))

        separator = f"\n{'-' * 80}\n"
        pager(separator.join(issue_strings))
        print(f"Found {len(issues)} issues with run_id {self.current_run_id}.")

    def trace(self):
        """Show a trace for the selected issue.

        The '-->' token points to the currently active trace frame within the
        trace.

        Trace output has 4 columns:
        - branches: the number of siblings a node has (including itself)
          [indicates that the trace branches into multiple paths]
        - callable: the name of the object that was called
        - port/condition: a description of the type of trace frame
          - source: where data originally comes from
          - root: the main callable through which the data propagates
          - sink: where data eventually flows to
        - location: the relative location of the trace frame's source code

        Example output:
             [branches] [callable]            [port]    [location]
             + 2        leaf                  source    module/main.py:26|4|8
         -->            module.main           root      module/helper.py:76|5|10
                        module.helper.process root      module/helper.py:76|5|10
             + 3        leaf                  sink      module/main.py:74|1|9
        """
        if not self._verify_issue_selected():
            return

        self._generate_trace()

        self._output_trace_tuples(self.trace_tuples)

    def _generate_trace(self):
        if self.trace_tuples_id == self.current_issue_id:
            return  # already generated

        with self.db.make_session() as session:
            issue_instance, issue = self._get_current_issue(session)

            postcondition_navigation = self._navigate_trace_frames(
                session,
                self._initial_trace_frames(session, issue_instance.id,
                                           TraceKind.POSTCONDITION),
            )
            precondition_navigation = self._navigate_trace_frames(
                session,
                self._initial_trace_frames(session, issue_instance.id,
                                           TraceKind.PRECONDITION),
            )

        self.trace_tuples = (
            self._create_trace_tuples(reversed(postcondition_navigation)) + [
                TraceTuple(trace_frame=TraceFrame(
                    callee=issue.callable,
                    callee_port="root",
                    filename=issue_instance.filename,
                    callee_location=issue_instance.location,
                ))
            ] + self._create_trace_tuples(precondition_navigation))
        self.trace_tuples_id = self.current_issue_id
        self.root_trace_frame_index = len(postcondition_navigation)
        self.current_trace_frame_index = self.root_trace_frame_index

    def next_cursor_location(self):
        """Move cursor to the next trace frame.
        """
        if not self._verify_issue_selected():
            return

        self._generate_trace()  # make sure self.trace_tuples exists
        self.current_trace_frame_index = min(
            self.current_trace_frame_index + 1,
            len(self.trace_tuples) - 1)
        self.trace()

    def prev_cursor_location(self):
        """Move cursor to the previous trace frame.
        """
        if not self._verify_issue_selected():
            return

        self._generate_trace()  # make sure self.trace_tuples exists
        self.current_trace_frame_index = max(
            self.current_trace_frame_index - 1, 0)
        self.trace()

    def expand(self):
        """Show and select branches for a branched trace.
        - [*] signifies the current branch that is selected

        Example output:

        Suppose we have the trace output:
             [branches] [callable]            [port]    [location]
         --> + 2        leaf                  source    module/main.py:26|4|8
                        module.main           root      module/helper.py:76|5|10
                        module.helper.process root      module/helper.py:76|5|10
             + 3        leaf                  sink      module/main.py:74|1|9

        Calling expand will result in the output:
        [*] leaf
                [0 hops: source]
                [module/main.py:26|4|8]
        [1] module.helper.preprocess
                [1 hops: source]
                [module/main.py:21|4|8]
        """
        if not self._verify_issue_selected(
        ) or not self._verify_multiple_branches():
            return

        current_trace_tuple = self.trace_tuples[self.current_trace_frame_index]
        filter_leaves = (self.sources if current_trace_tuple.trace_frame.kind
                         == TraceKind.POSTCONDITION else self.sinks)

        with self.db.make_session() as session:
            branches = self._get_trace_frame_branches(session)
            leaves_strings = [
                ", ".join([
                    leaf.contents for leaf in frame.leaves
                    if leaf.contents in filter_leaves
                ]) for frame in branches
            ]
            self._output_trace_expansion(branches, leaves_strings)

    def branch(self, selected_index: int) -> None:
        """Selects a branch when there are multiple possible traces to follow.

        The trace output that follows includes the new branch and its children
        frames.

        Parameters:
            selected_index: int    branch index from expand() output
        """
        if not self._verify_issue_selected(
        ) or not self._verify_multiple_branches():
            return

        with self.db.make_session() as session:
            branches = self._get_trace_frame_branches(session)

            if selected_index < 0 or selected_index >= len(branches):
                self.warning(
                    "Branch index out of bounds "
                    f"(expected 0-{len(branches) - 1} but got {selected_index})."
                )
                return

            new_navigation = self._navigate_trace_frames(
                session, branches, selected_index)

        new_trace_tuples = self._create_trace_tuples(new_navigation)

        if self._is_before_root():
            new_trace_tuples.reverse()
            self.trace_tuples = (
                new_trace_tuples +
                self.trace_tuples[self.current_trace_frame_index + 1:])

            # If length of prefix changes, it will change some indices
            trace_frame_index_delta = (len(new_navigation) -
                                       self.current_trace_frame_index - 1)
            self.current_trace_frame_index += trace_frame_index_delta
            self.root_trace_frame_index += trace_frame_index_delta
        else:
            self.trace_tuples = (
                self.trace_tuples[:self.current_trace_frame_index] +
                new_trace_tuples)

        self.trace()

    def list_source_code(self, context: int = 5) -> None:
        """Show source code around the current trace frame location.

        Parameters:
            context: int    number of lines to show above and below trace location
                            (default: 5)
        """
        if not self._verify_issue_selected():
            return

        self._generate_trace()

        current_trace_frame = self.trace_tuples[
            self.current_trace_frame_index].trace_frame

        filename = os.path.join(self.repository_directory,
                                current_trace_frame.filename)
        file_lines: List[str] = []

        try:
            # Use readlines instead of enumerate(file) because mock_open
            # doesn't support __iter__ until python 3.7.1.
            with open(filename, "r") as file:
                file_lines = file.readlines()
        except FileNotFoundError:
            self.warning(f"Couldn't open {filename}.")
            return

        self._output_file_lines(current_trace_frame, file_lines, context)

    def warning(self, message: str) -> None:
        print(message, file=sys.stderr)

    def _get_trace_frame_branches(self, session: Session) -> List[TraceFrame]:
        """Return the alternative frames available at the cursor position.

        The alternatives are looked up through the cursor's parent, which is
        the neighbouring frame on the side of the root.
        """
        before_root = self._is_before_root()
        parent_index = self.current_trace_frame_index + (
            1 if before_root else -1)

        if parent_index != self.root_trace_frame_index:
            parent_trace_frame = self.trace_tuples[parent_index].trace_frame
            return self._next_trace_frames(session, parent_trace_frame)

        # The parent is the root itself: the alternatives are the issue's
        # initial frames of the matching kind.
        if before_root:
            kind = TraceKind.POSTCONDITION
        else:
            kind = TraceKind.PRECONDITION
        return self._initial_trace_frames(session, self.current_issue_id,
                                          kind)

    def _is_before_root(self) -> bool:
        return self.current_trace_frame_index < self.root_trace_frame_index

    def _current_branch_index(self, branches: List[TraceFrame]) -> int:
        selected_branch_id = int(
            self.trace_tuples[self.current_trace_frame_index].trace_frame.id)
        for i, branch in enumerate(branches):
            if selected_branch_id == int(branch.id):
                return i
        return -1

    def _output_file_lines(self, trace_frame: TraceFrame,
                           file_lines: List[str], context: int) -> None:
        print(f"{trace_frame.filename}:{trace_frame.callee_location}")
        center_line_number = trace_frame.callee_location.line_no
        line_number_width = len(str(center_line_number + context))

        for i in range(
                max(center_line_number - context, 1),
                min(center_line_number + context, len(file_lines)) + 1,
        ):
            line = file_lines[i - 1]

            prefix = " --> " if i == center_line_number else " " * 5
            prefix += f"{i:<{line_number_width}} "
            if sys.stdout.isatty():
                line = highlight(
                    line,
                    get_lexer_for_filename(trace_frame.filename),
                    TerminalFormatter(),
                )
            print(f"{prefix} {line}", end="")

    def _output_trace_expansion(self, trace_frames: List[TraceFrame],
                                leaves_strings: List[str]) -> None:
        for i, (frame, leaves) in enumerate(zip(trace_frames, leaves_strings)):
            prefix = ("[*]" if i == self._current_branch_index(trace_frames)
                      else f"[{i}]")
            print(f"{prefix} {frame.callee} : {frame.callee_port}")
            print(
                f"{' ' * 8}[{frame.leaf_assoc[0].trace_length} hops: {leaves}]"
            )
            print(f"{' ' * 8}[{frame.filename}:{frame.callee_location}]")

    def _output_trace_tuples(self, trace_tuples):
        expand = "+ "
        max_length_callable = max(
            max(
                len(trace_tuple.trace_frame.callee)
                for trace_tuple in trace_tuples),
            len("[callable]"),
        )
        max_length_condition = max(
            max(
                len(trace_tuple.trace_frame.callee_port)
                for trace_tuple in trace_tuples),
            len("[port]"),
        )
        max_length_branches = max(
            max(
                len(str(trace_tuple.branches)) + len(expand)
                for trace_tuple in trace_tuples),
            len("[branches]"),
        )

        print(  # table header
            f"{' ' * 5}"
            f"{'[branches]':{max_length_branches}}"
            f" {'[callable]':{max_length_callable}}"
            f" {'[port]':{max_length_condition}}"
            f" [location]")

        for i, trace_tuple in enumerate(trace_tuples):
            prefix = "-->" if i == self.current_trace_frame_index else " " * 3

            if trace_tuple.missing:
                output_string = (
                    f" {prefix}"
                    f" [Missing trace frame: {trace_tuple.trace_frame.callee}:"
                    f"{trace_tuple.trace_frame.callee_port}]")
            else:
                branches_string = (
                    f"{expand}"
                    f"{str(trace_tuple.branches):{max_length_branches - len(expand)}}"
                    if trace_tuple.branches > 1 else " " * max_length_branches)
                output_string = (
                    f" {prefix}"
                    f" {branches_string}"
                    f" {trace_tuple.trace_frame.callee:{max_length_callable}}"
                    f" {trace_tuple.trace_frame.callee_port:{max_length_condition}}"
                    f" {trace_tuple.trace_frame.filename}"
                    f":{trace_tuple.trace_frame.callee_location}")

            print(output_string)

    def _create_trace_tuples(self, navigation):
        return [
            TraceTuple(
                trace_frame=trace_frame,
                branches=branches,
                missing=trace_frame.caller is None,
            ) for trace_frame, branches in navigation
        ]

    def _initial_trace_frames(self, session, issue_instance_id, kind):
        """Return the trace frames of `kind` attached to an issue instance.

        Results are ordered by trace length, then by callee location.
        """
        query = session.query(TraceFrame)
        query = query.filter(
            TraceFrame.issue_instances.any(id=issue_instance_id))
        query = query.filter(TraceFrame.kind == kind)
        query = query.join(TraceFrame.leaf_assoc).group_by(TraceFrame.id)
        query = query.order_by(TraceFrameLeafAssoc.trace_length,
                               TraceFrame.callee_location)
        return query.all()

    def _navigate_trace_frames(self,
                               session: Session,
                               initial_trace_frames: List[TraceFrame],
                               index: int = 0) -> List[Tuple[TraceFrame, int]]:
        if not initial_trace_frames:
            return []

        trace_frames = [(initial_trace_frames[index],
                         len(initial_trace_frames))]
        while not self._is_leaf(trace_frames[-1]):
            trace_frame, branches = trace_frames[-1]
            next_nodes = self._next_trace_frames(session, trace_frame)

            if len(next_nodes) == 0:
                # Denote a missing frame by setting caller to None
                trace_frames.append((
                    TraceFrame(  # pyre-ignore: T41318465
                        callee=trace_frame.callee,
                        callee_port=trace_frame.callee_port,
                        caller=None,
                    ),
                    0,
                ))
                return trace_frames

            trace_frames.append((next_nodes[0], len(next_nodes)))
        return trace_frames

    def _is_leaf(self, node: Tuple[TraceFrame, int]) -> bool:
        trace_frame, branches = node
        return trace_frame.callee_port in self.LEAF_NAMES

    def _next_trace_frames(self, session, trace_frame):
        """Return the frames of the current run that continue `trace_frame`.

        A frame continues the trace when its caller and caller port equal the
        given frame's callee and callee port. Only frames sharing at least one
        leaf with the selected issue are kept (sources for postcondition
        frames, sinks otherwise).
        """
        query = session.query(TraceFrame)
        query = query.filter(TraceFrame.run_id == self.current_run_id)
        # skip recursive calls for now
        query = query.filter(TraceFrame.caller != TraceFrame.callee)
        query = query.filter(TraceFrame.caller == trace_frame.callee)
        query = query.filter(
            TraceFrame.caller_port == trace_frame.callee_port)
        query = query.join(TraceFrame.leaf_assoc).group_by(TraceFrame.id)
        query = query.order_by(TraceFrameLeafAssoc.trace_length,
                               TraceFrame.callee_location)
        results = query.all()

        if trace_frame.kind == TraceKind.POSTCONDITION:
            filter_leaves = self.sources
        else:
            filter_leaves = self.sinks

        filtered_results = []
        for frame in results:
            frame_leaves = {leaf.contents for leaf in frame.leaves}
            if filter_leaves.intersection(frame_leaves):
                filtered_results.append(frame)
        return filtered_results

    def _create_issue_output_string(self, issue_instance, issue, sources,
                                    sinks):
        """Format one issue as a multi-line, label-aligned description.

        Multiple sources or sinks are listed one per line; an empty list is
        shown as "No sources" / "No sinks".
        """
        # Continuation lines are indented to line up behind the label.
        continuation = "\n" + " " * 10
        sources_output = continuation.join(sources) or "No sources"
        sinks_output = continuation.join(sinks) or "No sinks"
        location = SourceLocation.to_string(issue_instance.location)

        lines = [
            f"Issue {issue_instance.id}",
            f"    Code: {issue.code}",
            f" Message: {issue_instance.message.contents}",
            f"Callable: {issue.callable}",
            f" Sources: {sources_output}",
            f"   Sinks: {sinks_output}",
            f"Location: {issue_instance.filename}:{location}",
        ]
        return "\n".join(lines)

    def _resolve_pager(self, use_pager):
        """Pick the output function: page.page or page.display_page.

        With use_pager=None the pager is used only when stdout is a terminal.
        """
        if use_pager is None:
            use_pager = sys.stdout.isatty()
        if use_pager:
            return page.page
        return page.display_page

    def _get_current_issue(self, session):
        """Fetch the selected (IssueInstance, Issue) pair.

        The instance's message is eagerly loaded. Returns the first match, as
        given by Query.first().
        """
        query = session.query(IssueInstance, Issue)
        query = query.filter(IssueInstance.id == self.current_issue_id)
        query = query.join(Issue, IssueInstance.issue_id == Issue.id)
        query = query.options(joinedload(IssueInstance.message))
        return query.first()

    def _get_leaves(self, session: Session, issue_instance: IssueInstance,
                    kind: SharedTextKind) -> List[str]:
        """Return the distinct shared-text contents of `kind` for an issue."""
        query = session.query(distinct(SharedText.contents))
        query = query.join(SharedText.shared_text_issue_instance)
        query = query.filter(
            SharedText.issue_instances.any(id=issue_instance.id))
        query = query.filter(SharedText.kind == kind)
        # Each row holds exactly one column: the contents string.
        return [contents for (contents,) in query.all()]

    def _verify_issue_selected(self) -> bool:
        if self.current_issue_id == -1:
            self.warning("Use 'set_issue(ID)' to select an issue first.")
            return False
        return True

    def _verify_multiple_branches(self) -> bool:
        self._generate_trace()  # make sure self.trace_tuples exists
        current_trace_tuple = self.trace_tuples[self.current_trace_frame_index]
        if current_trace_tuple.branches < 2:
            self.warning("This trace frame has no alternate branches to take.")
            return False
        return True
class InteractiveTest(TestCase):
    def setUp(self) -> None:
        """Create an in-memory DB and tool, and capture stdout/stderr.

        The ``Interactive`` instance is pointed at the same ``DB`` object the
        test populates.  ``sys.stdout`` and ``sys.stderr`` are replaced with
        ``StringIO`` buffers exposed as ``self.stdout`` / ``self.stderr``.
        """
        self.db = DB("memory")
        self.interactive = Interactive("memory", "")
        self.interactive.db = self.db  # we need the tool to refer to the same db
        self.stdout = StringIO()
        self.stderr = StringIO()
        # Put back whatever streams were installed before this test (for
        # example a test runner's capture objects).  Cleanups run after
        # tearDown, so this wins over any reset to sys.__stdout__/__stderr__.
        self.addCleanup(setattr, sys, "stdout", sys.stdout)
        self.addCleanup(setattr, sys, "stderr", sys.stderr)
        sys.stdout = self.stdout  # redirect output
        sys.stderr = self.stderr  # redirect output

    def tearDown(self) -> None:
        """Point ``sys.stdout`` and ``sys.stderr`` back at the interpreter's originals."""
        sys.stderr = sys.__stderr__  # reset redirect
        sys.stdout = sys.__stdout__  # reset redirect

    def _clear_stdout(self):
        self.stdout = StringIO()
        sys.stdout = self.stdout

    def _add_to_session(self, session, data):
        if not isinstance(data, list):
            session.add(data)
            return

        for row in data:
            session.add(row)

    def _generic_issue(
        self, id: int = 1, callable: str = "call1", filename: str = "file.py"
    ) -> Issue:
        """Build an ``Issue`` fixture.

        The handle is ``str(id)`` and the code is ``1000 + id - 1`` (so id 1
        gets code 1000, id 2 gets 1001, ...); ``first_seen`` is the current
        time.
        """
        issue_code = 1000 + id - 1
        issue_handle = str(id)
        return Issue(  # pyre-ignore: T41318465
            id=id,
            handle=issue_handle,
            first_seen=datetime.now(),
            code=issue_code,
            callable=callable,
            filename=filename,
        )

    def _generic_issue_instance(
        self, id: int = 1, run_id: int = 1, issue_id: int = 1
    ) -> IssueInstance:
        """Build an ``IssueInstance`` fixture.

        Every instance uses message id 1, file ``file.py`` and location
        ``SourceLocation(1, 2, 3)``; only the ids vary.
        """
        fixed_location = SourceLocation(1, 2, 3)
        return IssueInstance(  # pyre-ignore: T41318465
            id=id,
            run_id=run_id,
            message_id=1,
            filename="file.py",
            location=fixed_location,
            issue_id=issue_id,
        )

    def testListIssuesBasic(self):
        """``issues()`` prints the issue that has an instance and omits the other."""
        first_issue = self._generic_issue(id=1, callable="module.function1")
        second_issue = self._generic_issue(id=2, callable="module.function2")
        message_text = SharedText(id=1, contents="message1")
        only_run = Run(id=1, date=datetime.now())
        instance = self._generic_issue_instance()

        with self.db.make_session() as session:
            self._add_to_session(session, [first_issue, second_issue])
            session.add(message_text)
            session.add(only_run)
            session.add(instance)
            session.commit()

        self.interactive.setup()
        self.interactive.issues()
        printed = self.stdout.getvalue().strip()

        expected_fragments = (
            "Issue 1",
            "Code: 1000",
            "Message: message1",
            "Callable: module.function1",
            "Location: file.py:1|2|3",
        )
        for fragment in expected_fragments:
            self.assertIn(fragment, printed)
        # Issue 2 has no instance, so its callable must not show up.
        self.assertNotIn("module.function2", printed)

    def testListIssuesFromLatestRun(self):
        """With two finished runs, ``issues()`` is expected to show only run 2's instance."""
        shared_issue = self._generic_issue()
        message_text = SharedText(id=1, contents="message1")
        finished_runs = [
            Run(id=run_id, date=datetime.now(), status=RunStatus.FINISHED)
            for run_id in (1, 2)
        ]
        # Instance N belongs to run N.
        instances = [
            self._generic_issue_instance(id=n, run_id=n) for n in (1, 2)
        ]

        with self.db.make_session() as session:
            session.add(shared_issue)
            session.add(message_text)
            self._add_to_session(session, finished_runs)
            self._add_to_session(session, instances)
            session.commit()

        self.interactive.setup()
        self.interactive.issues()
        printed = self.stdout.getvalue().strip()

        self.assertNotIn("Issue 1", printed)
        self.assertIn("Issue 2", printed)

    def _list_issues_filter_setup(self):
        """Store one finished run, three issues and one instance per issue.

        Issues 1 and 2 use ``module/sub.py``; issue 3 uses
        ``module/__init__.py``.  Instance N points at issue N.
        """
        finished_run = Run(id=1, date=datetime.now(), status=RunStatus.FINISHED)
        issue_specs = (
            (1, "module.sub.function1", "module/sub.py"),
            (2, "module.sub.function2", "module/sub.py"),
            (3, "module.function3", "module/__init__.py"),
        )
        issues = [
            self._generic_issue(id=issue_id, callable=callable_name, filename=path)
            for issue_id, callable_name, path in issue_specs
        ]
        instances = [
            self._generic_issue_instance(id=n, issue_id=n) for n in (1, 2, 3)
        ]

        with self.db.make_session() as session:
            session.add(finished_run)
            self._add_to_session(session, issues)
            self._add_to_session(session, instances)
            session.commit()

    def testListIssuesFilterCodes(self):
        """``issues(codes=...)`` complains about a non-list and filters by code."""
        self._list_issues_filter_setup()
        self.interactive.setup()

        # A bare int is expected to produce a complaint on stderr.
        self.interactive.issues(codes=1000)
        self.assertIn("'codes' should be a list", self.stderr.getvalue().strip())

        self.interactive.issues(codes=[1000])
        printed = self.stdout.getvalue().strip()
        self.assertIn("Issue 1", printed)
        for hidden in ("Issue 2", "Issue 3"):
            self.assertNotIn(hidden, printed)

        self._clear_stdout()
        self.interactive.issues(codes=[1001, 1002])
        printed = self.stdout.getvalue().strip()
        self.assertNotIn("Issue 1", printed)
        for shown in ("Issue 2", "Issue 3"):
            self.assertIn(shown, printed)

    def testListIssuesFilterCallables(self):
        """``issues(callables=...)`` complains about a non-list and filters by pattern."""
        self._list_issues_filter_setup()
        self.interactive.setup()

        # A bare string is expected to produce a complaint on stderr.
        self.interactive.issues(callables="function3")
        self.assertIn(
            "'callables' should be a list", self.stderr.getvalue().strip()
        )

        self.interactive.issues(callables=["%sub%"])
        printed = self.stdout.getvalue().strip()
        for shown in ("Issue 1", "Issue 2"):
            self.assertIn(shown, printed)
        self.assertNotIn("Issue 3", printed)

        self._clear_stdout()
        self.interactive.issues(callables=["%function3"])
        printed = self.stdout.getvalue().strip()
        for hidden in ("Issue 1", "Issue 2"):
            self.assertNotIn(hidden, printed)
        self.assertIn("Issue 3", printed)

    def testListIssuesFilterFilenames(self):
        """``issues(filenames=...)`` complains about a non-list and filters by pattern."""
        self._list_issues_filter_setup()
        self.interactive.setup()

        # A bare string is expected to produce a complaint on stderr.
        self.interactive.issues(filenames="hello.py")
        self.assertIn(
            "'filenames' should be a list", self.stderr.getvalue().strip()
        )

        self.interactive.issues(filenames=["module/s%"])
        printed = self.stdout.getvalue().strip()
        for shown in ("Issue 1", "Issue 2"):
            self.assertIn(shown, printed)
        self.assertNotIn("Issue 3", printed)

        self._clear_stdout()
        self.interactive.issues(filenames=["%__init__.py"])
        printed = self.stdout.getvalue().strip()
        for hidden in ("Issue 1", "Issue 2"):
            self.assertNotIn(hidden, printed)
        self.assertIn("Issue 3", printed)

    def testNoRunsFound(self):
        """``setup()`` on an empty database reports ``No runs found.`` on stderr."""
        self.interactive.setup()
        captured_errors = self.stderr.getvalue().strip()
        self.assertIn("No runs found.", captured_errors)

    def testListRuns(self):
        """``runs()`` lists the two finished runs and leaves out the incomplete one."""
        statuses = (RunStatus.FINISHED, RunStatus.INCOMPLETE, RunStatus.FINISHED)
        all_runs = [
            Run(id=run_id, date=datetime.now(), status=status)
            for run_id, status in enumerate(statuses, start=1)
        ]

        with self.db.make_session() as session:
            self._add_to_session(session, all_runs)
            session.commit()

        self.interactive.setup()
        self.interactive.runs()
        printed = self.stdout.getvalue().strip()

        self.assertIn("Run 1", printed)
        self.assertNotIn("Run 2", printed)
        self.assertIn("Run 3", printed)

    def testSetRun(self):
        """After ``set_run(1)``, ``issues()`` shows run 1's instance and not run 2's."""
        finished_runs = [
            Run(id=run_id, date=datetime.now(), status=RunStatus.FINISHED)
            for run_id in (1, 2)
        ]
        shared_issue = self._generic_issue()
        # Instance N belongs to run N.
        instances = [
            self._generic_issue_instance(id=n, run_id=n) for n in (1, 2)
        ]

        with self.db.make_session() as session:
            self._add_to_session(session, finished_runs)
            self._add_to_session(session, instances)
            session.add(shared_issue)
            session.commit()

        self.interactive.setup()
        self.interactive.set_run(1)
        self.interactive.issues()
        printed = self.stdout.getvalue().strip()

        self.assertIn("Issue 1", printed)
        self.assertNotIn("Issue 2", printed)

    def testSetRunNonExistent(self):
        """``set_run`` reports both an incomplete run (2) and a missing run (3) as nonexistent."""
        stored_runs = [
            Run(id=1, date=datetime.now(), status=RunStatus.FINISHED),
            Run(id=2, date=datetime.now(), status=RunStatus.INCOMPLETE),
        ]

        with self.db.make_session() as session:
            self._add_to_session(session, stored_runs)
            session.commit()

        self.interactive.setup()
        for rejected_run_id in (2, 3):
            self.interactive.set_run(rejected_run_id)
        captured_errors = self.stderr.getvalue().strip()

        for complaint in ("Run 2 doesn't exist", "Run 3 doesn't exist"):
            self.assertIn(complaint, captured_errors)

    def testSetIssue(self):
        """``set_issue`` followed by ``show`` prints the selected issue instance."""
        finished_run = Run(id=1, date=datetime.now(), status=RunStatus.FINISHED)
        shared_issue = self._generic_issue()
        instances = [
            self._generic_issue_instance(id=n, run_id=n) for n in (1, 2, 3)
        ]

        with self.db.make_session() as session:
            session.add(finished_run)
            session.add(shared_issue)
            self._add_to_session(session, instances)
            session.commit()

        self.interactive.setup()

        self.interactive.set_issue(2)
        self.interactive.show()
        shown = self.stdout.getvalue().strip()
        self.assertNotIn("Issue 1", shown)
        self.assertIn("Issue 2", shown)
        self.assertNotIn("Issue 3", shown)

        # stdout is not cleared in between, so the buffer read below still
        # contains the earlier output as well.
        self.interactive.set_issue(1)
        self.interactive.show()
        shown = self.stdout.getvalue().strip()
        self.assertIn("Issue 1", shown)
        self.assertNotIn("Issue 3", shown)

    def testSetIssueNonExistent(self):
        """``set_issue`` on an id with no stored instance reports it as nonexistent."""
        finished_run = Run(id=1, date=datetime.now(), status=RunStatus.FINISHED)

        with self.db.make_session() as session:
            session.add(finished_run)
            session.commit()

        self.interactive.setup()
        self.interactive.set_issue(1)

        self.assertIn("Issue 1 doesn't exist", self.stderr.getvalue().strip())

    def testGetSources(self):
        """``_get_leaves`` returns only the SOURCE texts linked to the instance.

        Three source texts exist but only the first two are associated with
        issue instance 1.
        """
        instance = self._generic_issue_instance()
        source_texts = [
            SharedText(id=1, contents="source1", kind=SharedTextKind.SOURCE),
            SharedText(id=2, contents="source2", kind=SharedTextKind.SOURCE),
            SharedText(id=3, contents="source3", kind=SharedTextKind.SOURCE),
        ]
        links = [
            IssueInstanceSharedTextAssoc(shared_text_id=text_id, issue_instance_id=1)
            for text_id in (1, 2)
        ]

        with self.db.make_session() as session:
            session.add(instance)
            self._add_to_session(session, source_texts)
            self._add_to_session(session, links)
            session.commit()

            found = self.interactive._get_leaves(
                session, instance, SharedTextKind.SOURCE
            )

        self.assertEqual(len(found), 2)
        for expected in ("source1", "source2"):
            self.assertIn(expected, found)

    def testGetSinks(self):
        """``_get_leaves`` returns only the SINK texts linked to the instance.

        Three sink texts exist but only the first two are associated with
        issue instance 1.  (A leading bare ``return`` previously made this
        whole test unreachable, so it passed without checking anything.)
        """
        issue_instance = self._generic_issue_instance()
        sink_texts = [
            SharedText(id=1, contents="sink1", kind=SharedTextKind.SINK),
            SharedText(id=2, contents="sink2", kind=SharedTextKind.SINK),
            SharedText(id=3, contents="sink3", kind=SharedTextKind.SINK),
        ]
        assocs = [
            IssueInstanceSharedTextAssoc(shared_text_id=1, issue_instance_id=1),
            IssueInstanceSharedTextAssoc(shared_text_id=2, issue_instance_id=1),
        ]

        with self.db.make_session() as session:
            session.add(issue_instance)
            self._add_to_session(session, sink_texts)
            self._add_to_session(session, assocs)
            session.commit()

            sinks = self.interactive._get_leaves(
                session, issue_instance, SharedTextKind.SINK
            )

        self.assertEqual(len(sinks), 2)
        self.assertIn("sink1", sinks)
        self.assertIn("sink2", sinks)

    def _basic_trace_frames(self):
        """Return two chained PRECONDITION frames for run 1.

        Frame 1 goes from ``call1:root`` to ``call2:param0``; frame 2 goes from
        ``call2:param0`` to ``leaf:sink``.  Both are in ``file.py``.
        """
        shared_fields = dict(
            kind=TraceKind.PRECONDITION, filename="file.py", run_id=1
        )
        root_to_call2 = TraceFrame(
            id=1,
            caller="call1",
            caller_port="root",
            callee="call2",
            callee_port="param0",
            callee_location=SourceLocation(1, 1),
            **shared_fields,
        )
        call2_to_leaf = TraceFrame(
            id=2,
            caller="call2",
            caller_port="param0",
            callee="leaf",
            callee_port="sink",
            callee_location=SourceLocation(1, 2),
            **shared_fields,
        )
        return [root_to_call2, call2_to_leaf]

    def testNextTraceFrames(self):
        """``_next_trace_frames`` on frame 1 yields exactly frame 2."""
        finished_run = Run(id=1, date=datetime.now(), status=RunStatus.FINISHED)
        frames = self._basic_trace_frames()
        sink_text = SharedText(id=1, contents="sink1", kind=SharedTextKind.SINK)
        leaf_link = TraceFrameLeafAssoc(trace_frame_id=2, leaf_id=1, trace_length=1)
        with self.db.make_session() as session:
            self._add_to_session(session, frames)
            session.add(finished_run)
            session.add(sink_text)
            session.add(leaf_link)
            session.commit()

            self.interactive.setup()
            self.interactive.sinks = {"sink1"}
            first_frame, second_frame = frames
            successors = self.interactive._next_trace_frames(session, first_frame)
            self.assertEqual(len(successors), 1)
            self.assertEqual(int(successors[0].id), int(second_frame.id))

    def testNextTraceFramesMultipleRuns(self):
        """With identical frames in two runs, the successor found is run 2's own frame."""
        finished_runs = [
            Run(id=run_id, date=datetime.now(), status=RunStatus.FINISHED)
            for run_id in (1, 2)
        ]
        run1_frames = self._basic_trace_frames()
        run2_frames = self._basic_trace_frames()
        # Re-home the second copy of the frames under run 2 with ids 3 and 4.
        for frame, new_id in zip(run2_frames, (3, 4)):
            frame.id = new_id
            frame.run_id = 2

        sink_text = SharedText(id=1, contents="sink1", kind=SharedTextKind.SINK)
        leaf_links = [
            TraceFrameLeafAssoc(trace_frame_id=frame_id, leaf_id=1, trace_length=0)
            for frame_id in (2, 4)
        ]
        with self.db.make_session() as session:
            self._add_to_session(session, run1_frames)
            self._add_to_session(session, run2_frames)
            self._add_to_session(session, finished_runs)
            self._add_to_session(session, leaf_links)
            session.add(sink_text)
            session.commit()

            self.interactive.setup()
            self.interactive.sinks = {"sink1"}
            run2_first, run2_second = run2_frames
            successors = self.interactive._next_trace_frames(session, run2_first)
            self.assertEqual(len(successors), 1)
            self.assertEqual(int(successors[0].id), int(run2_second.id))

    def testNavigateTraceFrames(self):
        """``_navigate_trace_frames`` from frame 1 yields entries for frames 1 then 2."""
        finished_run = Run(id=1, date=datetime.now(), status=RunStatus.FINISHED)
        frames = self._basic_trace_frames()
        sink_text = SharedText(id=1, contents="sink1", kind=SharedTextKind.SINK)
        leaf_link = TraceFrameLeafAssoc(trace_frame_id=2, leaf_id=1, trace_length=1)
        with self.db.make_session() as session:
            self._add_to_session(session, frames)
            session.add(finished_run)
            session.add(sink_text)
            session.add(leaf_link)
            session.commit()

            self.interactive.setup()
            self.interactive.sinks = {"sink1"}
            walked = self.interactive._navigate_trace_frames(session, [frames[0]])
            self.assertEqual(len(walked), 2)
            # Each entry's first element is the trace frame itself.
            for position, expected_id in enumerate((1, 2)):
                self.assertEqual(int(walked[position][0].id), expected_id)

    def testCreateTraceTuples(self):
        """``_create_trace_tuples`` maps each ``(frame, n)`` pair to ``TraceTuple(frame, n)`` in order."""
        call3_frame = TraceFrame(
            callee="call3",
            callee_port="result",
            filename="file3.py",
            callee_location=SourceLocation(1, 1, 3),
            caller="main",
            caller_port="root",
        )
        call2_frame = TraceFrame(
            callee="call2",
            callee_port="result",
            caller="dummy caller",
            filename="file2.py",
            callee_location=SourceLocation(1, 1, 2),
        )
        leaf_frame = TraceFrame(
            callee="leaf",
            callee_port="source",
            caller="dummy caller",
            filename="file1.py",
            callee_location=SourceLocation(1, 1, 1),
        )
        # reverse order
        postcondition_traces = [(call3_frame, 1), (call2_frame, 2), (leaf_frame, 3)]

        trace_tuples = self.interactive._create_trace_tuples(postcondition_traces)

        self.assertEqual(len(trace_tuples), 3)
        expected_tuples = [
            TraceTuple(frame, number) for frame, number in postcondition_traces
        ]
        self.assertEqual(trace_tuples, expected_tuples)

    def testOutputTraceTuples(self):
        """``_output_trace_tuples`` prints one row per tuple and marks the cursor row with ``-->``."""
        # (callee, callee_port, filename, third SourceLocation argument)
        frame_specs = [
            ("leaf", "source", "file1.py", 1),
            ("call2", "result", "file2.py", 2),
            ("call3", "result", "file3.py", 3),
            ("main", "root", "file4.py", 4),
            ("call4", "param0", "file4.py", 4),
            ("call5", "param1", "file5.py", 5),
            ("leaf", "sink", "file6.py", 6),
        ]
        trace_tuples = [
            TraceTuple(
                trace_frame=TraceFrame(
                    callee=callee,
                    callee_port=port,
                    filename=filename,
                    callee_location=SourceLocation(1, 1, last_field),
                )
            )
            for callee, port, filename, last_field in frame_specs
        ]

        # Cursor on the second row.
        self.interactive.current_trace_frame_index = 1
        self.interactive._output_trace_tuples(trace_tuples)
        printed_lines = self.stdout.getvalue().split("\n")
        expected_cursor_on_call2 = [
            "     [branches] [callable] [port] [location]",
            "                leaf       source file1.py:1|1|1",
            " -->            call2      result file2.py:1|1|2",
            "                call3      result file3.py:1|1|3",
            "                main       root   file4.py:1|1|4",
            "                call4      param0 file4.py:1|1|4",
            "                call5      param1 file5.py:1|1|5",
            "                leaf       sink   file6.py:1|1|6",
            "",
        ]
        self.assertEqual(printed_lines, expected_cursor_on_call2)

        # Cursor on the fifth row.
        self._clear_stdout()
        self.interactive.current_trace_frame_index = 4
        self.interactive._output_trace_tuples(trace_tuples)
        printed_lines = self.stdout.getvalue().split("\n")
        expected_cursor_on_call4 = [
            "     [branches] [callable] [port] [location]",
            "                leaf       source file1.py:1|1|1",
            "                call2      result file2.py:1|1|2",
            "                call3      result file3.py:1|1|3",
            "                main       root   file4.py:1|1|4",
            " -->            call4      param0 file4.py:1|1|4",
            "                call5      param1 file5.py:1|1|5",
            "                leaf       sink   file6.py:1|1|6",
            "",
        ]
        self.assertEqual(printed_lines, expected_cursor_on_call4)

    def testTrace(self):
        """``trace()`` needs a selected issue, then prints source, root and sink rows."""
        finished_run = Run(id=1, date=datetime.now(), status=RunStatus.FINISHED)
        shared_issue = self._generic_issue()
        instance = self._generic_issue_instance()
        # Both frames start at call1:root and end at a leaf in file.py.
        shared_fields = dict(
            caller="call1",
            caller_port="root",
            callee="leaf",
            filename="file.py",
            run_id=1,
        )
        source_frame = TraceFrame(
            id=1,
            kind=TraceKind.POSTCONDITION,
            callee_port="source",
            callee_location=SourceLocation(1, 1, 1),
            **shared_fields,
        )
        sink_frame = TraceFrame(
            id=2,
            kind=TraceKind.PRECONDITION,
            callee_port="sink",
            callee_location=SourceLocation(1, 1, 2),
            **shared_fields,
        )
        instance_links = [
            IssueInstanceTraceFrameAssoc(trace_frame_id=frame_id, issue_instance_id=1)
            for frame_id in (1, 2)
        ]
        leaf_links = [
            TraceFrameLeafAssoc(trace_frame_id=frame_id, leaf_id=1)
            for frame_id in (1, 2)
        ]

        with self.db.make_session() as session:
            session.add(finished_run)
            session.add(shared_issue)
            session.add(instance)
            self._add_to_session(session, [source_frame, sink_frame])
            self._add_to_session(session, instance_links + leaf_links)
            session.commit()

        self.interactive.setup()
        self.interactive.trace()
        self.assertIn(
            "Use 'set_issue(ID)' to select an issue first.",
            self.stderr.getvalue().strip(),
        )

        self.interactive.set_issue(1)
        self.interactive.trace()
        printed = self.stdout.getvalue().strip()
        expected_rows = (
            "                leaf       source file.py:1|1|1",
            " -->            call1      root   file.py:1|2|3",
            "                leaf       sink   file.py:1|1|2",
        )
        for row in expected_rows:
            self.assertIn(row, printed)

    def testTraceMissingFrames(self):
        """``trace()`` reports a missing frame when nothing continues from ``call2:param0``."""
        finished_run = Run(id=1, date=datetime.now(), status=RunStatus.FINISHED)
        shared_issue = self._generic_issue()
        instance = self._generic_issue_instance()
        shared_fields = dict(
            caller="call1",
            caller_port="root",
            callee_location=SourceLocation(1, 1, 1),
            filename="file.py",
            run_id=1,
        )
        source_frame = TraceFrame(
            id=1,
            kind=TraceKind.POSTCONDITION,
            callee="leaf",
            callee_port="source",
            **shared_fields,
        )
        # No stored frame has call2:param0 as its caller.
        dangling_frame = TraceFrame(
            id=2,
            kind=TraceKind.PRECONDITION,
            callee="call2",
            callee_port="param0",
            **shared_fields,
        )
        instance_links = [
            IssueInstanceTraceFrameAssoc(trace_frame_id=frame_id, issue_instance_id=1)
            for frame_id in (1, 2)
        ]
        leaf_links = [
            TraceFrameLeafAssoc(trace_frame_id=frame_id, leaf_id=1)
            for frame_id in (1, 2)
        ]

        with self.db.make_session() as session:
            session.add(finished_run)
            session.add(shared_issue)
            session.add(instance)
            self._add_to_session(session, [source_frame, dangling_frame])
            self._add_to_session(session, instance_links + leaf_links)
            session.commit()

        self.interactive.setup()
        self.interactive.set_issue(1)
        self.interactive.trace()
        printed = self.stdout.getvalue().strip()
        self.assertIn("Missing trace frame: call2:param0", printed)

    def testTraceCursorLocation(self):
        """The cursor starts at index 1 and next/prev moves stay within 0..2."""
        finished_run = Run(id=1, date=datetime.now(), status=RunStatus.FINISHED)
        shared_issue = self._generic_issue()
        instance = self._generic_issue_instance()
        shared_fields = dict(
            caller="call1",
            caller_port="root",
            callee="leaf",
            filename="file.py",
            run_id=1,
        )
        source_frame = TraceFrame(
            id=1,
            kind=TraceKind.POSTCONDITION,
            callee_port="source",
            callee_location=SourceLocation(1, 1),
            **shared_fields,
        )
        sink_frame = TraceFrame(
            id=2,
            kind=TraceKind.PRECONDITION,
            callee_port="sink",
            callee_location=SourceLocation(1, 2),
            **shared_fields,
        )
        instance_links = [
            IssueInstanceTraceFrameAssoc(trace_frame_id=frame_id, issue_instance_id=1)
            for frame_id in (1, 2)
        ]
        leaf_links = [
            TraceFrameLeafAssoc(trace_frame_id=frame_id, leaf_id=1)
            for frame_id in (1, 2)
        ]
        with self.db.make_session() as session:
            session.add(finished_run)
            session.add(shared_issue)
            session.add(instance)
            self._add_to_session(session, [source_frame, sink_frame])
            self._add_to_session(session, instance_links + leaf_links)
            session.commit()

        self.interactive.setup()
        self.interactive.set_issue(1)
        self.assertEqual(self.interactive.current_trace_frame_index, 1)

        step_forward = self.interactive.next_cursor_location
        step_back = self.interactive.prev_cursor_location
        # (move, index expected afterwards); repeated moves at either end
        # must leave the index unchanged.
        moves = (
            (step_forward, 2),
            (step_forward, 2),
            (step_back, 1),
            (step_back, 0),
            (step_back, 0),
        )
        for move, expected_index in moves:
            move()
            self.assertEqual(
                self.interactive.current_trace_frame_index, expected_index
            )

    def testTraceNoSinks(self):
        """With only a source-side frame, ``trace()`` prints the source row and the root row."""
        finished_run = Run(id=1, date=datetime.now(), status=RunStatus.FINISHED)
        shared_issue = self._generic_issue()
        instance = self._generic_issue_instance()
        source_frame = TraceFrame(
            id=1,
            kind=TraceKind.POSTCONDITION,
            caller="call1",
            caller_port="root",
            callee="leaf",
            callee_port="source",
            callee_location=SourceLocation(1, 1),
            filename="file.py",
            run_id=1,
        )
        source_text = SharedText(
            id=1, contents="source1", kind=SharedTextKind.SOURCE
        )
        links = [
            IssueInstanceTraceFrameAssoc(trace_frame_id=1, issue_instance_id=1),
            TraceFrameLeafAssoc(trace_frame_id=1, leaf_id=1),
        ]
        with self.db.make_session() as session:
            session.add(finished_run)
            session.add(shared_issue)
            session.add(instance)
            session.add(source_frame)
            session.add(source_text)
            self._add_to_session(session, links)
            session.commit()

        self.interactive.setup()
        self.interactive.sources = {"source1"}
        self.interactive.set_issue(1)
        # Discard anything printed so far; only trace() output is compared.
        self._clear_stdout()
        self.interactive.trace()
        expected_lines = [
            "     [branches] [callable] [port] [location]",
            "                leaf       source file.py:1|1|1",
            " -->            call1      root   file.py:1|2|3",
            "",
        ]
        self.assertEqual(self.stdout.getvalue().split("\n"), expected_lines)

    def _set_up_branched_trace(self):
        """Populate the DB with an issue whose trace has two options at every step.

        Creates run 1, a generic issue and issue instance, a source text
        (``source1``, id 1) and a sink text (``sink1``, id 2), both linked to
        issue instance 1, plus six trace frames with ids 1-6 whose
        ``callee_location`` is ``SourceLocation(i, i, i)`` for ``i`` in 0..5:

        * frames 1-2: POSTCONDITION, ``call1:root`` -> ``leaf:source``, leaf 1
          with trace length 0, linked to the issue instance;
        * frames 3-4: PRECONDITION, ``call1:root`` -> ``call2:param2``, leaf 2
          with trace length 1, linked to the issue instance;
        * frames 5-6: PRECONDITION, ``call2:param2`` -> ``leaf:sink``, leaf 2
          with trace length 0, not linked to the issue instance.
        """
        run = Run(id=1, date=datetime.now(), status=RunStatus.FINISHED)
        issue = self._generic_issue()
        issue_instance = self._generic_issue_instance()
        messages = [
            SharedText(id=1, contents="source1", kind=SharedTextKind.SOURCE),
            SharedText(id=2, contents="sink1", kind=SharedTextKind.SINK),
        ]
        trace_frames = []
        # Start with the instance <-> shared text links; the loop below appends
        # the per-frame association rows to the same list.
        assocs = [
            IssueInstanceSharedTextAssoc(issue_instance_id=1, shared_text_id=1),
            IssueInstanceSharedTextAssoc(issue_instance_id=1, shared_text_id=2),
        ]
        for i in range(6):
            # Every frame starts out as call1:root; kind, callee and (for the
            # last two) caller are filled in by the branches below.
            trace_frames.append(
                TraceFrame(
                    id=i + 1,
                    caller="call1",
                    caller_port="root",
                    filename="file.py",
                    callee_location=SourceLocation(i, i, i),
                    run_id=1,
                )
            )
            if i < 2:  # 2 postconditions
                trace_frames[i].kind = TraceKind.POSTCONDITION
                trace_frames[i].callee = "leaf"
                trace_frames[i].callee_port = "source"
                assocs.append(
                    TraceFrameLeafAssoc(trace_frame_id=i + 1, leaf_id=1, trace_length=0)
                )
                assocs.append(
                    IssueInstanceTraceFrameAssoc(
                        trace_frame_id=i + 1, issue_instance_id=1
                    )
                )
            elif i < 4:
                # 2 preconditions leading from the root into call2:param2
                trace_frames[i].kind = TraceKind.PRECONDITION
                trace_frames[i].callee = "call2"
                trace_frames[i].callee_port = "param2"
                assocs.append(
                    TraceFrameLeafAssoc(trace_frame_id=i + 1, leaf_id=2, trace_length=1)
                )
                assocs.append(
                    IssueInstanceTraceFrameAssoc(
                        trace_frame_id=i + 1, issue_instance_id=1
                    )
                )
            else:
                # 2 preconditions from call2:param2 to the sink leaf; these get
                # a leaf association only, no issue-instance association.
                trace_frames[i].kind = TraceKind.PRECONDITION
                trace_frames[i].caller = "call2"
                trace_frames[i].caller_port = "param2"
                trace_frames[i].callee = "leaf"
                trace_frames[i].callee_port = "sink"
                assocs.append(
                    TraceFrameLeafAssoc(trace_frame_id=i + 1, leaf_id=2, trace_length=0)
                )

        with self.db.make_session() as session:
            session.add(run)
            session.add(issue)
            session.add(issue_instance)
            self._add_to_session(session, messages)
            self._add_to_session(session, trace_frames)
            self._add_to_session(session, assocs)
            session.commit()

    def testTraceBranchNumber(self):
        """Check the rows printed by `trace` for the branched-trace fixture.

        After selecting issue 1, the source/sink sets must be populated and
        the trace listing must contain the expected rows, including the
        ``+ 2`` markers on the frames that have branches.
        """
        self._set_up_branched_trace()
        interactive = self.interactive

        interactive.setup()
        interactive.set_issue(1)

        self.assertEqual(interactive.sources, {"source1"})
        self.assertEqual(interactive.sinks, {"sink1"})

        interactive.trace()
        rendered = self.stdout.getvalue().strip()
        expected_rows = (
            "     + 2        leaf       source file.py:0|0|0",
            " -->            call1      root   file.py:1|2|3",
            "     + 2        call2      param2 file.py:2|2|2",
            "     + 2        leaf       sink   file.py:4|4|4",
        )
        for row in expected_rows:
            self.assertIn(row, rendered)

    def testExpand(self):
        """Check `expand` output at three cursor positions of the branched trace.

        At every position the listing must contain the currently selected
        branch (marked ``[*]``) and the alternative branch (marked ``[1]``).
        """
        self._set_up_branched_trace()
        interactive = self.interactive

        interactive.setup()
        interactive.set_issue(1)

        # Parent at root
        interactive.prev_cursor_location()
        interactive.expand()
        shown = self.stdout.getvalue().strip()
        for entry in (
            "[*] leaf : source\n        [0 hops: source1]\n        [file.py:0|0|0]",
            "[1] leaf : source\n        [0 hops: source1]\n        [file.py:1|1|1]",
        ):
            self.assertIn(entry, shown)

        # Move to call2:param2
        self._clear_stdout()
        interactive.next_cursor_location()
        interactive.next_cursor_location()
        interactive.expand()
        shown = self.stdout.getvalue().strip()
        for entry in (
            "[*] call2 : param2\n        [1 hops: sink1]\n        [file.py:2|2|2]",
            "[1] call2 : param2\n        [1 hops: sink1]\n        [file.py:3|3|3]",
        ):
            self.assertIn(entry, shown)

        # Move to leaf:sink
        self._clear_stdout()
        interactive.next_cursor_location()
        interactive.expand()
        shown = self.stdout.getvalue().strip()
        for entry in (
            "[*] leaf : sink\n        [0 hops: sink1]\n        [file.py:4|4|4]",
            "[1] leaf : sink\n        [0 hops: sink1]\n        [file.py:5|5|5]",
        ):
            self.assertIn(entry, shown)

    def testGetTraceFrameBranches(self):
        """Check which frame ids `_get_trace_frame_branches` returns.

        With the cursor moved back once from the initial position the
        branches must be frames 1 and 2; after moving forward three times
        they must be frames 5 and 6.
        """
        self._set_up_branched_trace()
        interactive = self.interactive

        interactive.setup()
        interactive.set_issue(1)
        # Parent at root
        interactive.prev_cursor_location()

        with self.db.make_session() as session:
            branches = interactive._get_trace_frame_branches(session)
            self.assertEqual(len(branches), 2)
            for position, expected_id in enumerate((1, 2)):
                self.assertEqual(int(branches[position].id), expected_id)

            # Parent is no longer root
            for _ in range(3):
                interactive.next_cursor_location()

            branches = interactive._get_trace_frame_branches(session)
            self.assertEqual(len(branches), 2)
            for position, expected_id in enumerate((5, 6)):
                self.assertEqual(int(branches[position].id), expected_id)

    def testBranch(self):
        """Check that `branch` switches the frame under the cursor.

        The branches of a frame differ only in source location, so each
        assertion looks for the location of the newly selected branch. A
        final call with index 2 must report "out of bounds" on stderr.
        """
        self._set_up_branched_trace()
        interactive = self.interactive

        def switch_to(branch_index):
            # Start from an empty buffer so only this call's output is read.
            self._clear_stdout()
            interactive.branch(branch_index)
            return self.stdout.getvalue().strip()

        interactive.setup()
        interactive.set_issue(1)
        interactive.prev_cursor_location()

        # location 0|0|0 -> 1|1|1
        listing = switch_to(1)
        self.assertIn(" --> + 2        leaf       source file.py:1|1|1", listing)

        # location 1|1|1 -> 0|0|0
        listing = switch_to(0)
        self.assertIn(" --> + 2        leaf       source file.py:0|0|0", listing)

        interactive.next_cursor_location()
        interactive.next_cursor_location()

        # location 2|2|2 -> 3|3|3
        listing = switch_to(1)
        self.assertIn(" --> + 2        call2      param2 file.py:3|3|3", listing)

        interactive.next_cursor_location()

        # location 4|4|4 -> 5|5|5; the previously chosen call2 branch stays.
        listing = switch_to(1)
        self.assertIn("     + 2        call2      param2 file.py:3|3|3", listing)
        self.assertIn(" --> + 2        leaf       sink   file.py:5|5|5", listing)

        # Index 2 is expected to be rejected with an error on stderr.
        interactive.branch(2)
        errors = self.stderr.getvalue().strip()
        self.assertIn("out of bounds", errors)

    def testBranchPrefixLengthChanges(self):
        """Check the trace listing when the selected branch changes its length.

        The issue has two postcondition branches out of ``call1:root``: one
        straight to ``leaf:source`` and one through ``prev_call:result``.
        Selecting the second one must add a row to the rendered trace, and
        `expand` must then describe the ``prev_call`` branch.
        """

        def make_frame(
            frame_id, kind, caller, caller_port, callee, callee_port, location
        ):
            # All frames of this fixture live in the same file and run.
            return TraceFrame(
                id=frame_id,
                kind=kind,
                caller=caller,
                caller_port=caller_port,
                callee=callee,
                callee_port=callee_port,
                callee_location=location,
                filename="file.py",
                run_id=1,
            )

        run = Run(id=1, date=datetime.now(), status=RunStatus.FINISHED)
        issue = self._generic_issue()
        issue_instance = self._generic_issue_instance()
        messages = [
            SharedText(id=1, contents="source1", kind=SharedTextKind.SOURCE),
            SharedText(id=2, contents="sink1", kind=SharedTextKind.SINK),
        ]
        trace_frames = [
            make_frame(
                1,
                TraceKind.POSTCONDITION,
                "call1",
                "root",
                "leaf",
                "source",
                SourceLocation(1, 1),
            ),
            make_frame(
                2,
                TraceKind.POSTCONDITION,
                "call1",
                "root",
                "prev_call",
                "result",
                SourceLocation(1, 1),
            ),
            make_frame(
                3,
                TraceKind.POSTCONDITION,
                "prev_call",
                "result",
                "leaf",
                "source",
                SourceLocation(1, 1),
            ),
            make_frame(
                4,
                TraceKind.PRECONDITION,
                "call1",
                "root",
                "leaf",
                "sink",
                SourceLocation(1, 2),
            ),
        ]
        assocs = [
            *(
                IssueInstanceSharedTextAssoc(issue_instance_id=1, shared_text_id=text_id)
                for text_id in (1, 2)
            ),
            # Frame 3 is only reachable through frame 2, so it is not linked
            # to the issue instance directly.
            *(
                IssueInstanceTraceFrameAssoc(issue_instance_id=1, trace_frame_id=frame_id)
                for frame_id in (1, 2, 4)
            ),
            *(
                TraceFrameLeafAssoc(
                    trace_frame_id=frame_id, leaf_id=leaf_id, trace_length=length
                )
                for frame_id, leaf_id, length in (
                    (1, 1, 0),
                    (2, 1, 1),
                    (3, 1, 0),
                    (4, 2, 0),
                )
            ),
        ]
        with self.db.make_session() as session:
            for record in (run, issue, issue_instance):
                session.add(record)
            for batch in (messages, trace_frames, assocs):
                self._add_to_session(session, batch)
            session.commit()

        interactive = self.interactive
        interactive.setup()
        interactive.set_issue(1)

        header = "     [branches] [callable] [port] [location]"
        root_row = "                call1      root   file.py:1|2|3"
        sink_row = "                leaf       sink   file.py:1|2|2"

        self._clear_stdout()
        interactive.prev_cursor_location()
        printed_lines = self.stdout.getvalue().split("\n")
        self.assertEqual(
            printed_lines,
            [
                header,
                " --> + 2        leaf       source file.py:1|1|1",
                root_row,
                sink_row,
                "",
            ],
        )

        self._clear_stdout()
        interactive.branch(1)
        printed_lines = self.stdout.getvalue().split("\n")
        self.assertEqual(
            printed_lines,
            [
                header,
                "                leaf       source file.py:1|1|1",
                " --> + 2        prev_call  result file.py:1|1|1",
                root_row,
                sink_row,
                "",
            ],
        )

        self._clear_stdout()
        interactive.expand()
        expanded = self.stdout.getvalue().strip()
        self.assertIn("[*] prev_call : result", expanded)
        self.assertIn("        [1 hops: source1]", expanded)

    def testCurrentBranchIndex(self):
        """Check `_current_branch_index` against a list of candidate frames.

        The index of the candidate whose id equals the id of the frame under
        the cursor is expected; -1 when no candidate has that id.
        """
        candidates = [TraceFrame(id=frame_id) for frame_id in (1, 2, 3)]
        interactive = self.interactive

        interactive.current_trace_frame_index = 0
        interactive.trace_tuples = [TraceTuple(trace_frame=TraceFrame(id=1))]

        # The cursor frame starts out with id 1, i.e. the first candidate.
        self.assertEqual(0, interactive._current_branch_index(candidates))

        # Re-point the cursor frame at other ids; id 4 matches no candidate.
        for cursor_id, expected_index in ((2, 1), (3, 2), (4, -1)):
            interactive.trace_tuples[0].trace_frame.id = cursor_id
            self.assertEqual(
                expected_index, interactive._current_branch_index(candidates)
            )

    def testVerifyIssueSelected(self):
        """`_verify_issue_selected` is falsy for issue id -1 and truthy for id 1."""
        cases = (
            (-1, self.assertFalse),
            (1, self.assertTrue),
        )
        for issue_id, check in cases:
            self.interactive.current_issue_id = issue_id
            check(self.interactive._verify_issue_selected())

    def testVerifyMultipleBranches(self):
        """Check `_verify_multiple_branches` for frames with 1 and 2 branches.

        The check must be falsy with the cursor on the one-branch frame and
        truthy with the cursor on the two-branch frame.
        """
        interactive = self.interactive

        # Leads to no-op on _generate_trace
        interactive.trace_tuples_id = 1
        interactive.current_issue_id = 1

        interactive.current_trace_frame_index = 0
        interactive.trace_tuples = [
            TraceTuple(trace_frame=TraceFrame(id=frame_id), branches=branch_count)
            for frame_id, branch_count in ((1, 1), (2, 2))
        ]
        self.assertFalse(interactive._verify_multiple_branches())

        interactive.current_trace_frame_index = 1
        self.assertTrue(interactive._verify_multiple_branches())

    def testCreateIssueOutputStringNoSourcesNoSinks(self):
        """Check the placeholder text for an empty source or sink list.

        An empty source list must render as "No sources" and an empty sink
        list as "No sinks", while the non-empty side still shows its first
        entry.
        """
        issue = Issue(code=1000, callable="module.function1")
        issue_instance = IssueInstance(
            id=1,
            message=SharedText(contents="leaf"),
            filename="module.py",
            location=SourceLocation(1, 2, 3),
        )

        # No sources, two sinks.
        rendered = self.interactive._create_issue_output_string(
            issue_instance, issue, [], ["sink1", "sink2"]
        )
        self.assertIn("Sources: No sources", rendered)
        self.assertIn("Sinks: sink1", rendered)

        # Two sources, no sinks.
        rendered = self.interactive._create_issue_output_string(
            issue_instance, issue, ["source1", "source2"], []
        )
        self.assertIn("Sources: source1", rendered)
        self.assertIn("Sinks: No sinks", rendered)

    def testListSourceCode(self):
        """Check the listing printed by `list_source_code` for two context sizes.

        `open` is patched to serve a small fixture file. The frame under the
        cursor points at line 2, so that line must carry the ``-->`` marker,
        and the number of surrounding lines must follow the argument.
        """
        # The fixture ends with a whitespace-only fifth line.
        mock_data = (
            "if this_is_true:\n"
            '    print("This was true")\n'
            "else:\n"
            '    print("This was false")\n'
            "        "
        )
        interactive = self.interactive
        interactive.setup()
        interactive.current_issue_id = 1
        interactive.trace_tuples_id = 1

        interactive.current_trace_frame_index = 0
        cursor_frame = TraceFrame(
            filename="file.py", callee_location=SourceLocation(2, 1, 1)
        )
        interactive.trace_tuples = [TraceTuple(trace_frame=cursor_frame)]

        # Rows shared by both listings: location header plus lines 1-3.
        leading_rows = [
            "file.py:2|1|1",
            "     1  if this_is_true:",
            ' --> 2      print("This was true")',
            "     3  else:",
        ]

        with patch("builtins.open", mock_open(read_data=mock_data)) as mock_file:
            # Two lines of context around line 2.
            self._clear_stdout()
            interactive.list_source_code(2)
            mock_file.assert_called_once_with(f"{os.getcwd()}/file.py", "r")
            printed_lines = self.stdout.getvalue().split("\n")
            self.assertEqual(
                printed_lines,
                leading_rows + ['     4      print("This was false")', ""],
            )

            # One line of context around line 2.
            mock_file.reset_mock()
            self._clear_stdout()
            interactive.list_source_code(1)
            mock_file.assert_called_once_with(f"{os.getcwd()}/file.py", "r")
            printed_lines = self.stdout.getvalue().split("\n")
            self.assertEqual(printed_lines, leading_rows + [""])

    def testListSourceCodeFileNotFound(self):
        """Check `list_source_code` when opening the source file fails.

        With `open` raising `FileNotFoundError`, a "Couldn't open" message
        must go to stderr and the file name must not appear on stdout.
        """
        interactive = self.interactive
        interactive.setup()
        interactive.current_issue_id = 1
        interactive.trace_tuples_id = 1

        interactive.current_trace_frame_index = 0
        cursor_frame = TraceFrame(
            filename="file.py", callee_location=SourceLocation(2, 1, 1)
        )
        interactive.trace_tuples = [TraceTuple(trace_frame=cursor_frame)]

        # Every attempt to open a file raises instead of returning data.
        failing_open = mock_open(read_data="not read")
        failing_open.side_effect = FileNotFoundError()
        with patch("builtins.open", failing_open):
            interactive.list_source_code()
            self.assertIn("Couldn't open", self.stderr.getvalue())
            self.assertNotIn("file.py", self.stdout.getvalue())

    def mock_pager(self, output_string):
        """Stand-in pager: ignore the text and count the invocation."""
        self.pager_calls = self.pager_calls + 1

    def testPager(self):
        """Check that `issues` and `runs` use the pager only when asked to.

        The IPython pager is replaced by `mock_pager`, which counts calls:
        none are expected by default, and one per command is expected with
        ``use_pager=True``.
        """
        run = Run(id=1, date=datetime.now(), status=RunStatus.FINISHED)
        issue = self._generic_issue()
        issue_instance = self._generic_issue_instance()

        with self.db.make_session() as session:
            for record in (run, issue, issue_instance):
                session.add(record)
            session.commit()

        scenarios = (
            ({}, 0),  # Default is no pager in tests
            ({"use_pager": True}, 2),
        )
        for pager_options, expected_calls in scenarios:
            self.pager_calls = 0
            with patch("IPython.core.page.page", self.mock_pager):
                self.interactive.setup()
                self.interactive.issues(**pager_options)
                self.interactive.runs(**pager_options)
            self.assertEqual(self.pager_calls, expected_calls)