Ejemplo n.º 1
0
    def on_Explain(self, explain):
        """Print the parsed statement, its compiled form and its targets.

        This is a debugging aid; all output is written to ``self.outfile``.
        If the statement does not compile, only the parsed statement and the
        error message (with trailing periods removed) are printed.

        Args:
          explain: The directive to process; its ``statement`` attribute is
            the statement that gets compiled.
        """
        def emit(*args):
            # Every line of output goes to the configured output file.
            print(*args, file=self.outfile)

        emit("Parsed statement:")
        emit("  {}".format(explain.statement))
        emit()

        # Compile the statement; on failure report the error and stop here.
        try:
            query = query_compile.compile(explain.statement,
                                          self.env_targets,
                                          self.env_postings,
                                          self.env_entries)
        except query_compile.CompilationError as exc:
            emit(str(exc).rstrip('.'))
            return

        emit("Compiled query:")
        emit("  {}".format(query))
        emit()

        # One line per target: its name, an aggregate marker and its data type.
        emit("Targets:")
        for c_target in query.c_targets:
            label = c_target.name or '(invisible)'
            if query_compile.is_aggregate(c_target.c_expr):
                marker = ' (aggregate)'
            else:
                marker = ''
            type_name = c_target.c_expr.dtype.__name__
            emit("  '{}'{}: {}".format(label, marker, type_name))
        emit()
Ejemplo n.º 2
0
    def assertIndexes(self, query, expected_simple_indexes,
                      expected_aggregate_indexes, expected_group_indexes,
                      expected_order_indexes):
        """Compare the index lists of a compiled query with expected values.

        All comparisons are made on sets, so neither ordering nor duplicates
        matter. For the group and order indexes, None is only equal to None;
        it is not treated as an empty list.

        Args:
          query: An instance of EvalQuery, a compiled query statement.
          expected_simple_indexes: The expected visible non-aggregate indexes.
          expected_aggregate_indexes: The expected visible aggregate indexes.
          expected_group_indexes: The expected group_indexes, or None.
          expected_order_indexes: The expected order_indexes, or None.
        Raises:
          AssertionError: if the check fails.
        """
        def as_index_set(indexes):
            # Keep None distinct from an empty collection of indexes.
            return None if indexes is None else set(indexes)

        # Split the _visible_ targets (those with a name) into aggregates and
        # non-aggregates; unnamed targets belong to neither set.
        # NOTE(review): this reads `c_target.expression`; confirm that this is
        # the attribute name exposed by the target type in use.
        visible_simple = set()
        visible_aggregate = set()
        for index, c_target in enumerate(query.c_targets):
            if not c_target.name:
                continue
            if qc.is_aggregate(c_target.expression):
                visible_aggregate.add(index)
            else:
                visible_simple.add(index)

        self.assertEqual(set(expected_simple_indexes), visible_simple)
        self.assertEqual(set(expected_aggregate_indexes), visible_aggregate)
        self.assertEqual(as_index_set(expected_group_indexes),
                         as_index_set(query.group_indexes))
        self.assertEqual(as_index_set(expected_order_indexes),
                         as_index_set(query.order_indexes))
Ejemplo n.º 3
0
    def assertSelectInvariants(self, query):
        """Assert the invariants on the query.

        When the query has group indexes, the set of indexes of its
        non-aggregate targets must be equal to the set of group indexes.
        Queries whose group_indexes is None are not checked.

        Args:
          query: An instance of EvalQuery, a compiled query statement.
        Raises:
          AssertionError: if the check fails.
        """
        grouped = query.group_indexes
        if grouped is None:
            # No grouping, so there is nothing to verify.
            return

        # Indexes of all targets (visible or not) that are not aggregates.
        simple = {
            index
            for index, c_target in enumerate(query.c_targets)
            if not qc.is_aggregate(c_target.c_expr)
        }
        self.assertEqual(simple, set(grouped),
                         "Invalid indexes: {}".format(query))
Ejemplo n.º 4
0
    def test_get_columns_and_aggregates(self):
        # Each case checks how many columns and aggregates are extracted from
        # an expression, and whether the expression is reported as aggregate.
        def check(c_query, expected_counts, expect_aggregate):
            columns, aggregates = qc.get_columns_and_aggregates(c_query)
            self.assertEqual(expected_counts, (len(columns), len(aggregates)))
            verify = self.assertTrue if expect_aggregate else self.assertFalse
            verify(qc.is_aggregate(c_query))

        # Simple column.
        check(qe.PositionColumn(), (1, 0), False)

        # Multiple columns.
        check(qc.EvalAnd(qe.PositionColumn(), qe.DateColumn()), (2, 0), False)

        # Simple aggregate.
        check(qe.SumPosition([qe.PositionColumn()]), (0, 1), True)

        # Multiple aggregates.
        check(qc.EvalAnd(qe.First([qe.AccountColumn()]),
                         qe.Last([qe.AccountColumn()])),
              (0, 2), True)

        # Simple non-aggregate function.
        check(qe.Length([qe.AccountColumn()]), (1, 0), False)

        # Mix of column and aggregates (this is used to detect this illegal case).
        check(qc.EvalAnd(qe.Length([qe.AccountColumn()]),
                         qe.SumPosition([qe.PositionColumn()])),
              (1, 1), True)