Example #1
0
def get_dml_range(ir_stmt: irast.MutatingStmt, dml_stmt: pgast.DML, *,
                  ctx: context.CompilerContextLevel) -> pgast.CommonTableExpr:
    """Create a range CTE for the given DML statement.

    The range query compiles the statement subject (joined with the
    enclosing FOR iterator, if any, and filtered by the statement's
    WHERE clause).  It outputs the identity of the subject path, i.e.
    the set of objects the DML statement operates on.

    :param ir_stmt:
        IR of the statement.
    :param dml_stmt:
        SQL DML node instance.  NOTE(review): not referenced in this
        function body.
    :param ctx:
        Current compiler context level.

    :return:
        A CommonTableExpr node representing the range affected
        by the DML statement.  The CTE is only constructed here; it is
        not appended to any statement's CTE list by this function.
    """
    target_ir_set = ir_stmt.subject
    ir_qual_expr = ir_stmt.where

    with ctx.newscope() as scopectx, scopectx.newrel() as subctx:
        subctx.expr_exposed = False
        range_stmt = subctx.rel

        # init_stmt() has associated all top-level paths with
        # the main query, which is at the very bottom.
        # Hoist that scope to the modification range statement
        # instead.
        for path_id, stmt in ctx.path_scope.items():
            if stmt is ctx.rel or path_id == ir_stmt.subject.path_id:
                scopectx.path_scope[path_id] = range_stmt

        # A FOR statement wrapping this DML exposes its iterator
        # through the parent statement.
        if ir_stmt.parent_stmt is not None:
            iterator_set = ir_stmt.parent_stmt.iterator_stmt
        else:
            iterator_set = None

        if iterator_set is not None:
            # Bind the iterator path to the range query and join the
            # compiled iterator expression into it.
            scopectx.path_scope[iterator_set.path_id] = range_stmt
            relctx.update_scope(iterator_set, range_stmt, ctx=subctx)
            iterator_rvar = clauses.compile_iterator_expr(range_stmt,
                                                          iterator_set,
                                                          ctx=subctx)
            relctx.include_rvar(range_stmt,
                                iterator_rvar,
                                path_id=iterator_set.path_id,
                                ctx=subctx)

        dispatch.compile(target_ir_set, ctx=subctx)

        # Make the subject identity an output column of the range
        # query so the DML statement can join against it.
        pathctx.get_path_identity_output(range_stmt,
                                         target_ir_set.path_id,
                                         env=subctx.env)

        if ir_qual_expr is not None:
            range_stmt.where_clause = astutils.extend_binop(
                range_stmt.where_clause,
                clauses.compile_filter_clause(ir_qual_expr, ctx=subctx))

        range_cte = pgast.CommonTableExpr(query=range_stmt,
                                          name=ctx.env.aliases.get('range'))

        return range_cte
Example #2
0
def cte_for_query(
        rel: pgast.Query, *,
        env: context.Environment) -> pgast.CommonTableExpr:
    """Wrap *rel* in a new common table expression.

    The CTE alias is a fresh name obtained from ``env.aliases``, using
    ``rel.name`` as the naming hint.

    :param rel:
        The query to wrap.
    :param env:
        Compilation environment that supplies the alias generator.

    :return:
        A CommonTableExpr whose query is *rel*.
    """
    fresh_name = env.aliases.get(rel.name)
    cte_alias = pgast.Alias(aliasname=fresh_name)
    return pgast.CommonTableExpr(query=rel, alias=cte_alias)
Example #3
0
def process_linkprop_update(ir_stmt: irast.MutatingStmt, ir_expr: irast.Base,
                            wrapper: pgast.Query,
                            dml_cte: pgast.CommonTableExpr, *,
                            ctx: context.CompilerContextLevel) -> None:
    """Perform link property updates to a link relation.

    Builds ``UPDATE <link table> SET <prop> = <expr>, ...
    FROM <dml_cte> WHERE <subject identity> = <link table>.source``
    and appends it to the top-level statement as a CTE.

    :param ir_stmt:
        IR of the statement.
    :param ir_expr:
        IR of the UPDATE body element; each element of its shape
        becomes one SET target.
    :param wrapper:
        Top-level SQL query (not referenced here).
    :param dml_cte:
        CTE representing the SQL UPDATE to the main relation of the Object.
    :param ctx:
        Current compiler context level.
    """
    top_stmt = ctx.toplevel_stmt
    link_ptrcls = ir_expr.rptr.ptrcls

    link_tab = dbobj.range_for_ptrcls(
        link_ptrcls, '>', include_overlays=False, env=ctx.env)

    dml_alias = pgast.Alias(aliasname=ctx.env.aliases.get('m'))
    dml_rvar = pgast.RangeVar(relation=dml_cte, alias=dml_alias)

    # Join condition: the updated object's identity must match the
    # link row's source column.
    subject_identity = pathctx.get_rvar_path_identity_var(
        dml_rvar, ir_stmt.subject.path_id, env=ctx.env)
    source_column = dbobj.get_column(link_tab, 'source', nullable=False)
    join_cond = astutils.new_binop(subject_identity, source_column, op='=')

    set_targets = []
    for element in ir_expr.shape:
        prop_shortname = element.rptr.ptrcls.get_shortname(ctx.env.schema)
        with ctx.new() as elem_ctx:
            elem_ctx.expr_exposed = False
            compiled_val = dispatch.compile(element.expr, ctx=elem_ctx)
            set_targets.append(
                pgast.UpdateTarget(name=prop_shortname.name,
                                   val=compiled_val))

    update_query = pgast.UpdateStmt(relation=link_tab,
                                    where_clause=join_cond,
                                    targets=set_targets,
                                    from_clause=[dml_rvar])

    link_shortname = link_ptrcls.get_shortname(ctx.env.schema).name
    update_cte = pgast.CommonTableExpr(
        query=update_query,
        name=ctx.env.aliases.get(link_shortname))

    top_stmt.ctes.append(update_cte)
Example #4
0
def process_link_values(
        ir_stmt, ir_expr, target_tab, tab_cols, col_data,
        dml_rvar, sources, props_only, target_is_scalar, iterator_cte, *,
        ctx: context.CompilerContextLevel) -> \
        typing.Tuple[pgast.CommonTableExpr, typing.List[str]]:
    """Unpack data from an update expression into a series of selects.

    Compiles *ir_expr* into a lateral subquery and builds a row query
    that joins *dml_rvar* (and *iterator_cte*, if given) with that
    subquery.  The row query outputs one column for every column in
    *tab_cols* for which a value is available.

    :param ir_stmt:
        IR of the DML statement; its subject path is bound to
        *dml_rvar* in the row query.
    :param ir_expr:
        IR of the INSERT/UPDATE body element.
    :param target_tab:
        The link table being updated.  Not referenced in this body.
    :param tab_cols:
        A sequence of columns in the table being updated.
    :param col_data:
        Expressions used to populate well-known columns of the link
        table such as std::source and std::__type__.  These take
        precedence over values derived from *ir_expr*.
    :param dml_rvar:
        Range var over the main DML CTE.
    :param sources:
        A list of relations which must be joined into the data query
        to resolve expressions in *col_data*.
    :param props_only:
        Whether this link update only touches link properties.
        Not referenced in this body.
    :param target_is_scalar:
        Whether the link target is a ScalarType.
    :param iterator_cte:
        CTE representing the iterator range in the FOR clause of the
        EdgeQL DML statement, or None.
    :param ctx:
        Current compiler context level.  (This was previously written
        as a default value, ``ctx=context.CompilerContext``, which
        would have passed the class object itself when omitted.)

    :return:
        A ``(link_rows, specified_cols)`` tuple: the CTE producing the
        link rows, and the names of the columns it outputs, in
        *tab_cols* order.
    """
    with ctx.newscope() as newscope, newscope.newrel() as subrelctx:
        row_query = subrelctx.rel

        relctx.include_rvar(row_query,
                            dml_rvar,
                            path_id=ir_stmt.subject.path_id,
                            ctx=subrelctx)
        subrelctx.path_scope[ir_stmt.subject.path_id] = row_query

        if iterator_cte is not None:
            iterator_rvar = dbobj.rvar_for_rel(iterator_cte,
                                               lateral=True,
                                               env=subrelctx.env)
            relctx.include_rvar(row_query,
                                iterator_rvar,
                                path_id=iterator_cte.query.path_id,
                                ctx=subrelctx)

        with subrelctx.newscope() as sctx, sctx.subrel() as input_rel_ctx:
            input_rel = input_rel_ctx.rel
            if iterator_cte is not None:
                input_rel_ctx.path_scope[iterator_cte.query.path_id] = \
                    row_query
            input_rel_ctx.expr_exposed = False
            input_rel_ctx.shape_format = context.ShapeFormat.FLAT
            # Tie volatile expressions to the identity of each row
            # of the DML subject.
            input_rel_ctx.volatility_ref = pathctx.get_path_identity_var(
                row_query, ir_stmt.subject.path_id, env=input_rel_ctx.env)
            dispatch.compile(ir_expr, ctx=input_rel_ctx)

    input_stmt = input_rel

    input_rvar = pgast.RangeSubselect(
        subquery=input_rel,
        lateral=True,
        alias=pgast.Alias(aliasname=ctx.env.aliases.get('val')))

    source_data = {}

    if input_stmt.op is not None:
        # UNION: path outputs are looked up on the right-hand arm.
        input_stmt = input_stmt.rarg

    path_id = ir_expr.path_id

    output = pathctx.get_path_value_output(input_stmt, path_id, env=ctx.env)

    if isinstance(output, pgast.TupleVar):
        # A tuple value: map each element onto the link-table column
        # named after its pointer (or the last path step).
        for element in output.elements:
            name = element.path_id.rptr_name()
            if name is None:
                name = element.path_id[-1].name
            colname = common.edgedb_name_to_pg_name(name)
            source_data.setdefault(colname,
                                   dbobj.get_column(input_rvar, element.name))
    else:
        if target_is_scalar:
            target_ref = pathctx.get_rvar_path_value_var(input_rvar,
                                                         path_id,
                                                         env=ctx.env)
        else:
            target_ref = pathctx.get_rvar_path_identity_var(input_rvar,
                                                            path_id,
                                                            env=ctx.env)

        source_data['std::target'] = target_ref

    # Object links always need a target column, even when the value
    # was unpacked from a tuple above.
    if not target_is_scalar and 'std::target' not in source_data:
        target_ref = pathctx.get_rvar_path_identity_var(input_rvar,
                                                        path_id,
                                                        env=ctx.env)
        source_data['std::target'] = target_ref

    specified_cols = []
    for col in tab_cols:
        expr = col_data.get(col)
        if expr is None:
            expr = source_data.get(col)

        if expr is not None:
            row_query.target_list.append(pgast.ResTarget(val=expr, name=col))
            specified_cols.append(col)

    row_query.from_clause += list(sources) + [input_rvar]

    link_rows = pgast.CommonTableExpr(query=row_query,
                                      name=ctx.env.aliases.get(hint='r'))

    return link_rows, specified_cols
Example #5
0
def process_link_update(
        ir_stmt: irast.MutatingStmt, ir_expr: irast.Base, props_only: bool,
        wrapper: pgast.Query, dml_cte: pgast.CommonTableExpr,
        iterator_cte: pgast.CommonTableExpr, *,
        ctx: context.CompilerContextLevel) -> pgast.CommonTableExpr:
    """Perform updates to a link relation as part of a DML statement.

    Appends the following CTEs, in order, to ``ctx.toplevel_stmt.ctes``:

    * a lookup of the pointer id by name in ``edgedb.pointer``;
    * a DELETE of existing link rows whose source is in *dml_cte*;
    * the CTE producing the new link rows (see process_link_values);
    * an INSERT ... ON CONFLICT of those rows into the link table.

    The DELETE and INSERT CTEs are also registered as 'except' and
    'union' relation overlays for the link.

    :param ir_stmt:
        IR of the statement.
    :param ir_expr:
        IR of the INSERT/UPDATE body element.
    :param props_only:
        Whether this link update only touches link properties.
    :param wrapper:
        Top-level SQL query.  NOTE(review): not referenced in this body.
    :param dml_cte:
        CTE representing the SQL INSERT or UPDATE to the main
        relation of the Object.
    :param iterator_cte:
        CTE representing the iterator range in the FOR clause of the
        EdgeQL DML statement, or None.
    :param ctx:
        Current compiler context level.

    :return:
        The CTE producing the new link rows.
    """
    toplevel = ctx.toplevel_stmt

    edgedb_ptr_tab = pgast.RangeVar(
        relation=pgast.Relation(schemaname='edgedb', name='pointer'),
        alias=pgast.Alias(aliasname=ctx.env.aliases.get(hint='ptr')))

    ltab_alias = edgedb_ptr_tab.alias.aliasname

    rptr = ir_expr.rptr
    ptrcls = rptr.ptrcls
    target_is_scalar = isinstance(ptrcls.target, s_scalars.ScalarType)

    path_id = rptr.source.path_id.extend(ptrcls, rptr.direction,
                                         rptr.target.scls)

    # The links in the dml class shape have been derived,
    # but we must use the correct specialized link class for the
    # base material type.
    mptrcls = ptrcls.material_type()

    # Lookup link class id by link name.
    lname_to_id = pgast.CommonTableExpr(query=pgast.SelectStmt(
        from_clause=[edgedb_ptr_tab],
        target_list=[
            pgast.ResTarget(val=pgast.ColumnRef(name=[ltab_alias, 'id']))
        ],
        where_clause=astutils.new_binop(
            lexpr=pgast.ColumnRef(name=[ltab_alias, 'name']),
            rexpr=pgast.Constant(val=mptrcls.name),
            op=ast.ops.EQ)),
                                        name=ctx.env.aliases.get(hint='lid'))

    lname_to_id_rvar = pgast.RangeVar(relation=lname_to_id)
    toplevel.ctes.append(lname_to_id)

    target_rvar = dbobj.range_for_ptrcls(mptrcls,
                                         '>',
                                         include_overlays=False,
                                         env=ctx.env)
    target_alias = target_rvar.alias.aliasname

    # (schemaname, name) pair of the link table.
    target_tab_name = (target_rvar.relation.schemaname,
                       target_rvar.relation.name)

    tab_cols = dbobj.cols_for_pointer(mptrcls, env=ctx.env)

    dml_cte_rvar = pgast.RangeVar(
        relation=dml_cte,
        alias=pgast.Alias(aliasname=ctx.env.aliases.get('m')))

    # Values for link-table columns that do not come from the
    # right-hand side expression.
    col_data = {
        'ptr_item_id':
        pgast.ColumnRef(name=[lname_to_id.name, 'id']),
        'std::source':
        pathctx.get_rvar_path_identity_var(dml_cte_rvar,
                                           ir_stmt.subject.path_id,
                                           env=ctx.env)
    }

    # Drop all previous link records for this source.
    delcte = pgast.CommonTableExpr(query=pgast.DeleteStmt(
        relation=target_rvar,
        where_clause=astutils.new_binop(
            lexpr=col_data['std::source'],
            op=ast.ops.EQ,
            rexpr=pgast.ColumnRef(name=[target_alias, 'std::source'])),
        using_clause=[dml_cte_rvar],
        returning_list=[
            pgast.ResTarget(val=pgast.ColumnRef(
                name=[target_alias, pgast.Star()]))
        ]),
                                   name=ctx.env.aliases.get(hint='d'))

    pathctx.put_path_value_rvar(delcte.query,
                                path_id.ptr_path(),
                                target_rvar,
                                env=ctx.env)

    # Record the effect of this removal in the relation overlay
    # context to ensure that the RETURNING clause potentially
    # referencing this link yields the expected results.
    overlays = ctx.env.rel_overlays[ptrcls.shortname]
    overlays.append(('except', delcte))
    toplevel.ctes.append(delcte)

    # Turn the IR of the expression on the right side of :=
    # into a subquery returning records for the link table.
    data_cte, specified_cols = process_link_values(ir_stmt,
                                                   ir_expr,
                                                   target_tab_name,
                                                   tab_cols,
                                                   col_data,
                                                   dml_cte_rvar,
                                                   [lname_to_id_rvar],
                                                   props_only,
                                                   target_is_scalar,
                                                   iterator_cte,
                                                   ctx=ctx)

    toplevel.ctes.append(data_cte)

    data_select = pgast.SelectStmt(
        target_list=[
            pgast.ResTarget(val=pgast.ColumnRef(
                name=[data_cte.name, pgast.Star()]))
        ],
        from_clause=[pgast.RangeVar(relation=data_cte)])

    # Inserting rows into the link table may produce cardinality
    # constraint violations, since the INSERT into the link table
    # is executed in the snapshot where the above DELETE from
    # the link table is not visible.  Hence, we need to use
    # the ON CONFLICT clause to resolve this.
    conflict_cols = ['std::source', 'std::target', 'ptr_item_id']
    conflict_inference = []
    conflict_exc_row = []

    for col in conflict_cols:
        conflict_inference.append(pgast.ColumnRef(name=[col]))
        conflict_exc_row.append(pgast.ColumnRef(name=['excluded', col]))

    # On conflict, re-select the new row from data_cte whose key
    # columns match the conflicting ("excluded") row.
    conflict_data = pgast.SelectStmt(
        target_list=[
            pgast.ResTarget(val=pgast.ColumnRef(
                name=[data_cte.name, pgast.Star()]))
        ],
        from_clause=[pgast.RangeVar(relation=data_cte)],
        where_clause=astutils.new_binop(
            lexpr=pgast.ImplicitRowExpr(args=conflict_inference),
            rexpr=pgast.ImplicitRowExpr(args=conflict_exc_row),
            op='='))

    cols = [pgast.ColumnRef(name=[col]) for col in specified_cols]
    updcte = pgast.CommonTableExpr(
        name=ctx.env.aliases.get(hint='i'),
        query=pgast.InsertStmt(
            relation=target_rvar,
            select_stmt=data_select,
            cols=cols,
            on_conflict=pgast.OnConflictClause(
                action='update',
                infer=pgast.InferClause(index_elems=conflict_inference),
                target_list=[
                    pgast.MultiAssignRef(columns=cols, source=conflict_data)
                ]),
            returning_list=[
                pgast.ResTarget(val=pgast.ColumnRef(name=[pgast.Star()]))
            ]))

    pathctx.put_path_value_rvar(updcte.query,
                                path_id.ptr_path(),
                                target_rvar,
                                env=ctx.env)

    # Record the effect of this insertion in the relation overlay
    # context to ensure that the RETURNING clause potentially
    # referencing this link yields the expected results.
    overlays = ctx.env.rel_overlays[ptrcls.shortname]
    overlays.append(('union', updcte))

    toplevel.ctes.append(updcte)

    return data_cte
Example #6
0
def init_dml_stmt(
        ir_stmt: irast.MutatingStmt, dml_stmt: pgast.DML, *,
        ctx: context.CompilerContextLevel,
        parent_ctx: context.CompilerContextLevel) \
        -> typing.Tuple[pgast.Query, pgast.CommonTableExpr,
                        pgast.RangeVar,
                        typing.Optional[pgast.CommonTableExpr]]:
    """Prepare the common structure of the query representing a DML stmt.

    :param ir_stmt:
        IR of the statement.
    :param dml_stmt:
        SQL DML node instance.  Its ``relation`` (and, for UPDATE and
        DELETE, its WHERE and FROM/USING clauses) is filled in here.
    :param ctx:
        Compiler context level of the DML statement.
    :param parent_ctx:
        Enclosing compiler context level.

    :return:
        A (*wrapper*, *dml_cte*, *dml_rvar*, *range_cte*) tuple, where
        *wrapper* is the wrapping SQL statement, *dml_cte* is the CTE
        representing the SQL DML operation in the main relation of the
        Object, *dml_rvar* is a range var over *dml_cte* that has been
        included into *wrapper*, and *range_cte* is the CTE for the
        subset affected by the statement.  *range_cte* is None unless
        the statement is an UPDATE or DELETE.
    """
    wrapper = ctx.rel

    clauses.init_stmt(ir_stmt, ctx, parent_ctx)

    target_ir_set = ir_stmt.subject

    dml_stmt.relation = dbobj.range_for_set(ir_stmt.subject,
                                            include_overlays=False,
                                            env=ctx.env)
    pathctx.put_path_value_rvar(dml_stmt,
                                target_ir_set.path_id,
                                dml_stmt.relation,
                                env=ctx.env)
    pathctx.put_path_source_rvar(dml_stmt,
                                 target_ir_set.path_id,
                                 dml_stmt.relation,
                                 env=ctx.env)
    dml_stmt.path_scope.add(target_ir_set.path_id)

    dml_cte = pgast.CommonTableExpr(query=dml_stmt,
                                    name=ctx.env.aliases.get(hint='m'))

    if isinstance(ir_stmt, (irast.UpdateStmt, irast.DeleteStmt)):
        # UPDATE and DELETE operate over a range, so generate
        # the corresponding CTE and connect it to the DML query.
        range_cte = get_dml_range(ir_stmt, dml_stmt, ctx=ctx)

        range_rvar = pgast.RangeVar(
            relation=range_cte,
            alias=pgast.Alias(aliasname=ctx.env.aliases.get(hint='range')))

        relctx.pull_path_namespace(target=dml_stmt, source=range_rvar, ctx=ctx)

        # Auxiliary relations are always joined via the WHERE
        # clause due to the structure of the UPDATE/DELETE SQL statements.
        id_col = common.edgedb_name_to_pg_name('std::id')
        dml_stmt.where_clause = astutils.new_binop(
            lexpr=pgast.ColumnRef(
                name=[dml_stmt.relation.alias.aliasname, id_col]),
            op=ast.ops.EQ,
            rexpr=pathctx.get_rvar_path_identity_var(range_rvar,
                                                     target_ir_set.path_id,
                                                     env=ctx.env))

        # UPDATE has "FROM", while DELETE has "USING".
        if hasattr(dml_stmt, 'from_clause'):
            dml_stmt.from_clause.append(range_rvar)
        else:
            dml_stmt.using_clause.append(range_rvar)

    else:
        range_cte = None

    # Due to the fact that DML statements are structured
    # as a flat list of CTEs instead of nested range vars,
    # the top level path scope must be empty.  The necessary
    # range vars will be injected explicitly in all rels that
    # need them.
    ctx.path_scope.clear()

    # Re-assert the subject's value/source rvars on the DML statement
    # after the range namespace has been pulled in above.
    # NOTE(review): presumably this guards against pull_path_namespace
    # having remapped the subject path — confirm.
    pathctx.put_path_value_rvar(dml_stmt,
                                ir_stmt.subject.path_id,
                                dml_stmt.relation,
                                env=ctx.env)

    pathctx.put_path_source_rvar(dml_stmt,
                                 ir_stmt.subject.path_id,
                                 dml_stmt.relation,
                                 env=ctx.env)

    # Expose the DML CTE to the wrapper under the subject path.
    dml_rvar = pgast.RangeVar(
        relation=dml_cte,
        alias=pgast.Alias(aliasname=parent_ctx.env.aliases.get('d')))

    relctx.include_rvar(wrapper, dml_rvar, ir_stmt.subject.path_id, ctx=ctx)

    pathctx.put_path_bond(wrapper, ir_stmt.subject.path_id)

    return wrapper, dml_cte, dml_rvar, range_cte
Example #7
0
def process_insert_body(ir_stmt: irast.MutatingStmt, wrapper: pgast.Query,
                        insert_cte: pgast.CommonTableExpr,
                        insert_rvar: pgast.BaseRangeVar, *,
                        ctx: context.CompilerContextLevel) -> None:
    """Generate SQL DML CTEs from an InsertStmt IR.

    Fills in the column list and VALUES-producing SELECT of the main
    INSERT and appends *insert_cte* to the top-level statement.  It
    then emits link-table updates for pointers stored in link tables.
    Finally, it compiles link properties that belong to the parent
    link of the subject into a tuple value on *wrapper*.

    :param ir_stmt:
        IR of the statement.
    :param wrapper:
        Top-level SQL query.
    :param insert_cte:
        CTE representing the SQL INSERT to the main relation of the Object.
    :param insert_rvar:
        Range var over *insert_cte*; used as the source rvar when
        compiling parent link properties.
    :param ctx:
        Current compiler context level.
    """
    cols = [pgast.ColumnRef(name=['std::__type__'])]
    select = pgast.SelectStmt(target_list=[])
    values = select.target_list

    # The main INSERT query of this statement will always be
    # present to insert at least the std::id and std::__type__
    # links.
    insert_stmt = insert_cte.query

    insert_stmt.cols = cols
    insert_stmt.select_stmt = select

    if ir_stmt.parent_stmt is not None:
        iterator_set = ir_stmt.parent_stmt.iterator_stmt
    else:
        iterator_set = None

    if iterator_set is not None:
        # Compile the FOR iterator into its own top-level CTE and
        # join it into the VALUES select.
        with ctx.substmt() as ictx:
            ictx.path_scope = ictx.path_scope.new_child()
            ictx.path_scope[iterator_set.path_id] = ictx.rel
            clauses.compile_iterator_expr(ictx.rel, iterator_set, ctx=ictx)
            ictx.rel.path_id = iterator_set.path_id
            pathctx.put_path_bond(ictx.rel, iterator_set.path_id)
            iterator_cte = pgast.CommonTableExpr(
                query=ictx.rel, name=ctx.env.aliases.get('iter'))
            ictx.toplevel_stmt.ctes.append(iterator_cte)
        iterator_rvar = dbobj.rvar_for_rel(iterator_cte, env=ctx.env)
        relctx.include_rvar(select,
                            iterator_rvar,
                            path_id=ictx.rel.path_id,
                            ctx=ctx)
        iterator_id = pathctx.get_path_identity_var(select,
                                                    iterator_set.path_id,
                                                    env=ctx.env)
    else:
        iterator_cte = None
        iterator_id = None

    # std::__type__ value: id of the subject's object type, looked up
    # by name in edgedb.objecttype.
    values.append(
        pgast.ResTarget(val=pgast.SelectStmt(
            target_list=[pgast.ResTarget(val=pgast.ColumnRef(name=['id']))],
            from_clause=[
                pgast.RangeVar(relation=pgast.Relation(name='objecttype',
                                                       schemaname='edgedb'))
            ],
            where_clause=astutils.new_binop(
                op=ast.ops.EQ,
                lexpr=pgast.ColumnRef(name=['name']),
                rexpr=pgast.Constant(val=ir_stmt.subject.scls.shortname)))))

    # (shape element, props_only) pairs to be written to link tables.
    external_inserts = []
    tuple_elements = []
    # Shape elements that are link properties of the subject's parent
    # link rather than pointers of the subject itself.
    parent_link_props = []

    with ctx.newrel() as subctx:
        subctx.rel = select
        subctx.rel_hierarchy[select] = insert_stmt

        subctx.expr_exposed = False
        subctx.shape_format = context.ShapeFormat.FLAT

        if iterator_cte is not None:
            subctx.path_scope = ctx.path_scope.new_child()
            subctx.path_scope[iterator_cte.query.path_id] = select

        # Process the Insert IR and separate links that go
        # into the main table from links that are inserted into
        # a separate link table.
        for shape_el in ir_stmt.subject.shape:
            rptr = shape_el.rptr
            ptrcls = rptr.ptrcls.material_type()

            if (ptrcls.is_link_property()
                    and rptr.source.path_id != ir_stmt.subject.path_id):
                parent_link_props.append(shape_el)
                continue

            ptr_info = pg_types.get_pointer_storage_info(
                ptrcls,
                schema=subctx.env.schema,
                resolve_type=True,
                link_bias=False)

            props_only = False

            # First, process all local link inserts.
            if ptr_info.table_type == 'ObjectType':
                # The pointer is stored as a column of the object
                # table: add it to the INSERT column list and VALUES.
                props_only = True
                field = pgast.ColumnRef(name=[ptr_info.column_name])
                cols.append(field)

                insvalue = insert_value_for_shape_element(insert_stmt,
                                                          wrapper,
                                                          ir_stmt,
                                                          shape_el,
                                                          iterator_id,
                                                          ptr_info=ptr_info,
                                                          ctx=subctx)

                tuple_el = astutils.tuple_element_for_shape_el(shape_el, field)
                tuple_elements.append(tuple_el)
                values.append(pgast.ResTarget(val=insvalue))

            # Second lookup, biased towards link storage, decides
            # whether a link-table row is also needed.
            # NOTE(review): unlike the lookup above, this call does not
            # pass schema=; confirm that is intended.
            ptr_info = pg_types.get_pointer_storage_info(ptrcls,
                                                         resolve_type=False,
                                                         link_bias=True)

            if ptr_info and ptr_info.table_type == 'link':
                external_inserts.append((shape_el, props_only))

        if iterator_cte is not None:
            # Record the iterator identity in the __edb_token column so
            # the inserted row can be matched to its iteration.
            cols.append(pgast.ColumnRef(name=['__edb_token']))

            values.append(pgast.ResTarget(val=iterator_id))

            pathctx.put_path_identity_var(insert_stmt,
                                          iterator_set.path_id,
                                          cols[-1],
                                          force=True,
                                          env=subctx.env)

            pathctx.put_path_bond(insert_stmt, iterator_set.path_id)

    toplevel = ctx.toplevel_stmt
    toplevel.ctes.append(insert_cte)

    # Process necessary updates to the link tables.
    for shape_el, props_only in external_inserts:
        process_link_update(ir_stmt,
                            shape_el,
                            props_only,
                            wrapper,
                            insert_cte,
                            iterator_cte,
                            ctx=ctx)

    if parent_link_props:
        prop_elements = []

        with ctx.newscope() as scopectx:
            scopectx.rel = wrapper

            for shape_el in parent_link_props:
                rptr = shape_el.rptr
                scopectx.path_scope[rptr.source.path_id] = wrapper
                pathctx.put_path_rvar_if_not_exists(wrapper,
                                                    rptr.source.path_id,
                                                    insert_rvar,
                                                    aspect='value',
                                                    env=scopectx.env)
                dispatch.compile(shape_el, ctx=scopectx)
                tuple_el = astutils.tuple_element_for_shape_el(shape_el, None)
                prop_elements.append(tuple_el)

        # Expose the compiled parent link properties as a named tuple
        # value of the subject path on the wrapper.
        valtuple = pgast.TupleVar(elements=prop_elements, named=True)
        pathctx.put_path_value_var(wrapper,
                                   ir_stmt.subject.path_id,
                                   valtuple,
                                   force=True,
                                   env=ctx.env)
Example #8
0
def compile_GroupStmt(stmt: irast.GroupStmt, *,
                      ctx: context.CompilerContextLevel) -> pgast.Query:
    """Compile a GROUP statement into a SQL query.

    The statement is compiled in three stages:

    1. A grouping relation, materialized as a CTE named via
       ``ctx.env.aliases.get('g')``.  It evaluates the GROUP subject and
       the BY expressions.  It assigns every row a group identity computed
       with the ``first_value()`` window function, partitioned by the BY
       expressions.
    2. A group-values relation, materialized as a CTE named via
       ``ctx.env.aliases.get('gv')``.  It reads the grouping CTE, carries
       the BY path values and is made distinct on the group identity.  The
       distinct clause is dropped again if nothing ends up being selected
       from it.
    3. A result subquery.  It evaluates the result expression, its filter
       and its ORDER BY expressions over the group-values relation.  It is
       included laterally into the statement's main relation, which
       receives the sort, OFFSET and LIMIT clauses.

    :param stmt:
        IR of the GROUP statement.
    :param ctx:
        Compiler context of the caller; a sub-statement context is
        derived from it for the duration of the compilation.

    :return:
        The main relation (``ctx.rel``) of the sub-statement context,
        with the compiled statement attached to it.
    """

    # Keep a handle on the caller's context: ``ctx`` is rebound below to
    # the sub-statement context.
    parent_ctx = ctx
    with parent_ctx.substmt() as ctx:
        clauses.init_stmt(stmt, ctx=ctx, parent_ctx=parent_ctx)

        group_path_id = stmt.group_path_id

        # Process the GROUP .. BY part into a subquery.
        with ctx.subrel() as gctx:
            gctx.expr_exposed = False
            gquery = gctx.rel
            pathctx.put_path_bond(gquery, group_path_id)
            if stmt.path_scope:
                # Associate every path in the statement's own scope with
                # the grouping subquery.
                ctx.path_scope.update(
                    {path_id: gquery
                     for path_id in stmt.path_scope.paths})
            relctx.update_scope(stmt.subject, gquery, ctx=gctx)
            # NOTE(review): the subject's scope is cleared right after
            # update_scope() consumed it, presumably so that
            # compile_output() does not apply it a second time -- confirm.
            stmt.subject.path_scope = None
            clauses.compile_output(stmt.subject, ctx=gctx)
            subj_rvar = pathctx.get_path_rvar(gquery,
                                              stmt.subject.path_id,
                                              aspect='value',
                                              env=gctx.env)
            relctx.ensure_bond_for_expr(stmt.subject,
                                        subj_rvar.query,
                                        ctx=gctx)

            # Path ids of the BY expressions; used later to decide which
            # paths are served by the group-values relation.
            group_paths = set()

            # Compiled BY expressions; used as the window PARTITION BY.
            part_clause = []

            for expr in stmt.groupby:
                with gctx.new() as subctx:
                    partexpr = dispatch.compile(expr, ctx=subctx)

                part_clause.append(partexpr)
                group_paths.add(expr.path_id)

            # Since we will be computing arbitrary expressions
            # based on the grouped sets, it is more efficient
            # to compute the "group bond" as a small unique
            # value than it is to use GROUP BY and aggregate
            # actual id values into an array.
            #
            # To achieve this we use the first_value() window
            # function while using the GROUP BY clause as
            # a partition clause.  We use the id of the first
            # object in each partition if GROUP BY input is
            # a ObjectType, otherwise we generate the id using
            # row_number().
            #
            # NOTE(review): no row_number() call is visible in this
            # function; for non-object subjects the identity comes from
            # get_rvar_path_identity_var() on the wrapped subquery below.
            # Confirm the sentence above is still accurate.
            if isinstance(stmt.subject.scls, s_objtypes.ObjectType):
                first_val = pathctx.get_path_identity_var(gquery,
                                                          stmt.subject.path_id,
                                                          env=ctx.env)
            else:
                # Non-object subject: wrap the grouping query in an outer
                # query and re-express the partition expressions as
                # columns output by the (now inner) grouping query.
                with ctx.subrel() as subctx:
                    wrapper = subctx.rel

                    gquery_rvar = dbobj.rvar_for_rel(gquery, env=ctx.env)
                    wrapper.from_clause = [gquery_rvar]
                    relctx.pull_path_namespace(target=wrapper,
                                               source=gquery_rvar,
                                               ctx=subctx)

                    new_part_clause = []

                    for i, expr in enumerate(part_clause):
                        # Bind the compiled BY expression as the value of
                        # its path in the inner query, then reference the
                        # resulting output column through gquery_rvar.
                        path_id = stmt.groupby[i].path_id
                        pathctx.put_path_value_var(gquery,
                                                   path_id,
                                                   expr,
                                                   force=True,
                                                   env=ctx.env)
                        output_ref = pathctx.get_path_value_output(gquery,
                                                                   path_id,
                                                                   env=ctx.env)
                        new_part_clause.append(
                            dbobj.get_column(gquery_rvar, output_ref))

                    part_clause = new_part_clause

                    first_val = pathctx.get_rvar_path_identity_var(
                        gquery_rvar, stmt.subject.path_id, env=ctx.env)

                    # From here on the wrapper is the grouping relation.
                    gquery = wrapper
                    pathctx.put_path_bond(gquery, group_path_id)

            # Group identity: the first ``first_val`` in each partition
            # of equal BY values.
            group_id = pgast.FuncCall(
                name=('first_value', ),
                args=[first_val],
                over=pgast.WindowDef(partition_clause=part_clause))

            # The group path uses the same expression for both its
            # identity and its value.
            pathctx.put_path_identity_var(gquery,
                                          group_path_id,
                                          group_id,
                                          env=ctx.env)

            pathctx.put_path_value_var(gquery,
                                       group_path_id,
                                       group_id,
                                       env=ctx.env)

        group_cte = pgast.CommonTableExpr(query=gquery,
                                          name=ctx.env.aliases.get('g'))

        group_cte_rvar = dbobj.rvar_for_rel(group_cte, env=ctx.env)

        # Generate another subquery containing distinct values of
        # path expressions in BY.
        with ctx.subrel() as gvctx:
            gvquery = gvctx.rel
            relctx.include_rvar(gvquery,
                                group_cte_rvar,
                                path_id=group_path_id,
                                ctx=gvctx)

            pathctx.put_path_bond(gvquery, group_path_id)

            for group_set in stmt.groupby:
                dispatch.visit(group_set, ctx=gvctx)
                path_id = group_set.path_id
                if path_id.is_objtype_path():
                    pathctx.put_path_bond(gvquery, path_id)

            # One row per group identity.
            gvquery.distinct_clause = [
                pathctx.get_path_identity_var(gvquery,
                                              group_path_id,
                                              env=ctx.env)
            ]

            # Keep only the BY paths and the group path itself in the
            # group-values relation's rvar map.  Iterate over a list
            # snapshot because entries are popped during the loop.
            for path_id, aspect in list(gvquery.path_rvar_map):
                if path_id not in group_paths and path_id != group_path_id:
                    gvquery.path_rvar_map.pop((path_id, aspect))

            # Conversely, forget the BY paths in the grouping relation.
            # Presumably this makes them resolve through the group-values
            # relation, which the result subquery below maps them to.
            for path_id, aspect in list(gquery.path_rvar_map):
                if path_id in group_paths:
                    gquery.path_rvar_map.pop((path_id, aspect))
                    gquery.path_namespace.pop((path_id, aspect), None)
                    gquery.path_outputs.pop((path_id, aspect), None)

        groupval_cte = pgast.CommonTableExpr(query=gvquery,
                                             name=ctx.env.aliases.get('gv'))

        groupval_cte_rvar = dbobj.rvar_for_rel(groupval_cte, env=ctx.env)

        # The SELECT-like statement nested in the GROUP result.
        o_stmt = stmt.result.expr

        # process the result expression;
        with ctx.subrel() as selctx:
            selquery = selctx.rel
            outer_id = stmt.result.path_id
            inner_id = o_stmt.result.path_id

            relctx.include_specific_rvar(selquery,
                                         groupval_cte_rvar,
                                         group_path_id,
                                         aspect='identity',
                                         ctx=ctx)

            # BY path values are read from the group-values CTE.
            for path_id in group_paths:
                selctx.path_scope[path_id] = selquery
                pathctx.put_path_rvar(selquery,
                                      path_id,
                                      groupval_cte_rvar,
                                      aspect='value',
                                      env=ctx.env)

            # Copy first so the new entry does not mutate the mapping
            # selctx started with (it may be shared with outer contexts).
            selctx.group_by_rels = selctx.group_by_rels.copy()
            selctx.group_by_rels[group_path_id, stmt.subject.path_id] = \
                group_cte

            # The GROUP result path is an alias of the nested statement's
            # result path.
            selquery.view_path_id_map = {outer_id: inner_id}

            selquery.ctes.append(group_cte)

            # Aliases of the ORDER BY expressions projected by selquery.
            # The outer query sorts by them and does not re-project them.
            sortoutputs = []

            selquery.ctes.append(groupval_cte)

            clauses.compile_output(o_stmt.result, ctx=selctx)

            # The WHERE clause
            selquery.where_clause = astutils.extend_binop(
                selquery.where_clause,
                clauses.compile_filter_clause(o_stmt.where, ctx=selctx))

            # Project each ORDER BY expression under a fresh alias.
            for ir_sortexpr in o_stmt.orderby:
                alias = ctx.env.aliases.get('s')
                sexpr = dispatch.compile(ir_sortexpr.expr, ctx=selctx)
                selquery.target_list.append(
                    pgast.ResTarget(val=sexpr, name=alias))
                sortoutputs.append(alias)

        if not gvquery.target_list:
            # No values were pulled from the group-values rel,
            # we must remove the DISTINCT clause to prevent
            # a syntax error.
            gvquery.distinct_clause[:] = []

        # Include the result subquery laterally into the statement's main
        # relation.
        query = ctx.rel
        result_rvar = dbobj.rvar_for_rel(selquery, lateral=True, env=ctx.env)
        relctx.include_rvar(query, result_rvar, path_id=outer_id, ctx=ctx)

        # Re-project every non-sort column of the result subquery.
        # Unnamed targets get a generated alias so they can be referenced.
        for rt in selquery.target_list:
            if rt.name is None:
                rt.name = ctx.env.aliases.get('v')
            if rt.name not in sortoutputs:
                query.target_list.append(
                    pgast.ResTarget(val=dbobj.get_column(result_rvar, rt.name),
                                    name=rt.name))

        # The ORDER BY clause, referring to the sort columns projected by
        # the result subquery (sortoutputs is parallel to o_stmt.orderby).
        for i, expr in enumerate(o_stmt.orderby):
            sort_ref = dbobj.get_column(result_rvar, sortoutputs[i])
            sortexpr = pgast.SortBy(node=sort_ref,
                                    dir=expr.direction,
                                    nulls=expr.nones_order)
            query.sort_clause.append(sortexpr)

        # The OFFSET clause
        if o_stmt.offset:
            with ctx.new() as ctx1:
                ctx1.clause = 'offsetlimit'
                ctx1.expr_exposed = False
                query.limit_offset = dispatch.compile(o_stmt.offset, ctx=ctx1)

        # The LIMIT clause
        if o_stmt.limit:
            with ctx.new() as ctx1:
                ctx1.clause = 'offsetlimit'
                ctx1.expr_exposed = False
                query.limit_count = dispatch.compile(o_stmt.limit, ctx=ctx1)

        clauses.fini_stmt(query, ctx, parent_ctx)

    return query