Beispiel #1
0
def rewrite_refs(expr, callback):
    """Rewrite class references in EdgeQL expression.

    The fragment *expr* is parsed, and every ``qlast.ObjectRef`` node in
    it is converted to an ``sn.Name`` and passed to *callback*.  Whenever
    the callback returns a name that differs from the one it was given,
    the node's ``name`` and ``module`` are overwritten in place with the
    returned name's parts.  The resulting tree is rendered back to EdgeQL
    source with ``pretty=False``.
    """
    def _rename(node):
        # Invoked only for its side effect on matching nodes; it always
        # returns None.
        if not isinstance(node, qlast.ObjectRef):
            return
        original = sn.Name(name=node.name, module=node.module)
        replacement = callback(original)
        if original != replacement:
            node.name = replacement.name
            node.module = replacement.module

    parsed = qlparser.parse_fragment(expr)
    ast.find_children(parsed, _rename)
    return qlcodegen.generate_source(parsed, pretty=False)
Beispiel #2
0
def is_aggregated_expr(ir):
    """Return True if *ir* contains an aggregate function call.

    The check returns the ``func.aggregate`` flag of each
    ``irast.FunctionCall`` it meets.  On reaching an ``irast.Stmt`` node
    it raises ``ast.SkipNode`` so that subqueries are not inspected.
    """
    def _aggregate_call(node):
        if isinstance(node, irast.FunctionCall):
            return node.func.aggregate
        if isinstance(node, irast.Stmt):
            # Don't dip into subqueries.
            raise ast.SkipNode()
        return None

    matches = set(ast.find_children(ir, _aggregate_call))
    return len(matches) > 0
Beispiel #3
0
def get_source_references(ir):
    """Return the set of ``scls`` values of the expression-less Sets in *ir*.

    An ``irast.Set`` whose ``expr`` is None counts as a direct reference.
    Its ``scls`` attribute is collected.
    """
    def _is_bare_set(node):
        return isinstance(node, irast.Set) and node.expr is None

    return {ir_set.scls for ir_set in ast.find_children(ir, _is_bare_set)}
Beispiel #4
0
def get_terminal_references(ir):
    """Return the expression-less Sets in *ir* that are not a pointer source.

    Every ``irast.Set`` with ``expr is None`` is collected.  Any such set
    that appears as ``rptr.source`` of another collected set is then
    removed, which leaves only the ends of the reference paths.
    """
    bare_sets = list(ast.find_children(
        ir, lambda node: isinstance(node, irast.Set) and node.expr is None))

    # Sets reached through a pointer make their source set non-terminal.
    sources = {s.rptr.source for s in bare_sets if s.rptr}

    return set(bare_sets) - sources
Beispiel #5
0
def is_const(ir):
    """Return True if *ir* has no expression-less Sets and no parameters.

    Both checks are always evaluated: the Set search runs first, then
    ``get_variables``.
    """
    has_sets = bool(ast.find_children(
        ir, lambda node: isinstance(node, irast.Set) and node.expr is None))
    has_params = bool(get_variables(ir))
    return not (has_sets or has_params)
Beispiel #6
0
def get_variables(ir):
    """Return the set of all ``irast.Parameter`` nodes found in *ir*."""
    def _is_param(node):
        return isinstance(node, irast.Parameter)

    return set(ast.find_children(ir, _is_param))
Beispiel #7
0
    def _edgeql_ref_to_pg_constr(cls, subject, tree, schema, link_bias):
        """Compile a constraint expression IR into Postgres SQL source text.

        *tree* is compiled to a SQL AST with ``singleton_mode=True``.  The
        resulting SQL expression is rendered in several variants: plain,
        split per column, and with its unqualified column references
        prefixed by ``NEW`` and by ``OLD``.

        :param subject: The constrained schema object.  When it is an
            ``s_scalars.ScalarType`` (a domain constraint), every
            unqualified column reference must name the scalar itself and
            is rewritten to ``VALUE``.
        :param tree: The constraint expression IR.
        :param schema: Schema handed to the IR-to-SQL compiler.
        :param link_bias: Not used by this method.
        :returns: ``dict(exprdata=..., is_multicol=..., is_trivial=...)``,
            where ``exprdata`` has the keys ``plain``, ``plain_chunks``,
            ``new`` and ``old``.
        :raises ValueError: For a scalar-type subject, if an unqualified
            column reference names anything other than the subject.

        Note: the column-reference nodes inside the compiled SQL tree are
        mutated in place (renamed and prefixed) while the variants are
        being rendered.
        """
        sql_tree = compiler.compile_ir_to_sql_tree(tree,
                                                   schema=schema,
                                                   singleton_mode=True)

        # When compilation produces a full SELECT, take the first target of
        # the subquery in its first FROM item as the constraint expression;
        # otherwise the compiled tree is the expression itself.
        if isinstance(sql_tree, pg_ast.SelectStmt):
            # XXX: use ast pattern matcher for this
            sql_expr = sql_tree.from_clause[0].relation\
                .query.target_list[0].val
        else:
            sql_expr = sql_tree

        # Unwrap the IR down to the node whose ``expr`` is the constraint's
        # result expression, so its shape can be inspected below.
        if isinstance(tree, irast.Statement):
            tree = tree.expr

        if isinstance(tree.expr, irast.SelectStmt):
            tree = tree.expr.result

        # A tuple result means the constraint spans multiple columns.
        is_multicol = isinstance(tree.expr, irast.Tuple)

        # Determine if the sequence of references are all simple refs, not
        # expressions.  This influences the type of Postgres constraint used.
        # (True for a bare ColumnRef, or an ImplicitRowExpr whose args are
        # all bare ColumnRefs.)
        #
        is_trivial = (isinstance(sql_expr, pg_ast.ColumnRef)
                      or (isinstance(sql_expr, pg_ast.ImplicitRowExpr) and all(
                          isinstance(el, pg_ast.ColumnRef)
                          for el in sql_expr.args)))

        # Find all field references: column refs with a single-component
        # (unqualified) name.
        #
        flt = lambda n: isinstance(n, pg_ast.ColumnRef) and len(n.name) == 1
        refs = set(ast.find_children(sql_expr, flt))

        if isinstance(subject, s_scalars.ScalarType):
            # Domain constraint, replace <scalar_name> with VALUE

            subject_pg_name = common.edgedb_name_to_pg_name(subject.name)

            for ref in refs:
                if ref.name != [subject_pg_name]:
                    raise ValueError(
                        f'unexpected node reference in '
                        f'ScalarType constraint: {".".join(ref.name)}')

                # work around the immutability check
                object.__setattr__(ref, 'name', ['VALUE'])

        plain_expr = codegen.SQLSourceGenerator.to_source(sql_expr)

        # For multi-column constraints, also render each column expression
        # separately.  NOTE(review): this assumes sql_expr has ``args``
        # (presumably an ImplicitRowExpr) whenever the IR is a Tuple --
        # confirm.
        if is_multicol:
            chunks = []

            for elem in sql_expr.args:
                chunks.append(codegen.SQLSourceGenerator.to_source(elem))
        else:
            chunks = [plain_expr]

        # Make sure a top-level bare column reference is prefixed too, even
        # if it was not picked up by the single-name filter above.
        if isinstance(sql_expr, pg_ast.ColumnRef):
            refs.add(sql_expr)

        # Render the NEW.<col> variant by prefixing each reference in place,
        # then swap that prefix for OLD and render again.  The nodes are left
        # prefixed with OLD afterwards.
        # NOTE(review): for a domain constraint the refs were renamed to
        # VALUE above, so these variants read NEW.VALUE / OLD.VALUE --
        # presumably only ``plain`` is consumed in that case; confirm.
        for ref in refs:
            ref.name.insert(0, 'NEW')
        new_expr = codegen.SQLSourceGenerator.to_source(sql_expr)

        for ref in refs:
            ref.name[0] = 'OLD'
        old_expr = codegen.SQLSourceGenerator.to_source(sql_expr)

        exprdata = dict(plain=plain_expr,
                        plain_chunks=chunks,
                        new=new_expr,
                        old=old_expr)

        return dict(exprdata=exprdata,
                    is_multicol=is_multicol,
                    is_trivial=is_trivial)
Beispiel #8
0
    def _normalize_ptr_default(self, expr, source, ptr, ptrdecl):
        """Normalize a pointer's default expression and validate its type.

        *expr* is normalized against ``self._schema``.  The module of
        ``source.name`` serves as the default module alias, and *source* is
        bound to the ``qlast.Source`` anchor.  The normalized text is stored
        in ``ptr.default`` and ``ptr.normalize_defaults()`` is called.

        For pure computables, the pointer target is taken from the inferred
        result type.  For links, the ``std::target`` property target is
        also set.  Cardinality is inferred, with the anchored source set
        treated as a singleton.

        :param expr: The default expression AST; its ``context`` is used
            for error reporting.
        :param source: The pointer's source.  Its name's module is the
            default module alias, and it is bound as the ``qlast.Source``
            anchor.
        :param ptr: The pointer (link or property); mutated in place.
        :param ptrdecl: The pointer declaration.  It is consulted for an
            explicit ``cardinality`` attribute and for error contexts.
        :raises s_err.SchemaError: If any of the following holds:

            * the result type cannot be inferred;
            * a computable declares an explicit cardinality;
            * the result type is not a subtype of the pointer target;
            * a non-scalar target lacks a "*1"/"**" cardinality.
        :raises s_err.SchemaDefinitionError: If a computable's inferred
            type is not a valid link/property target.
        """
        module_aliases = {None: source.name.module}

        ir, _, expr_text = edgeql.utils.normalize_tree(
            expr,
            self._schema,
            modaliases=module_aliases,
            anchors={qlast.Source: source})

        # The IR node bound to the source anchor (or None if the
        # expression never refers to it).
        self_set = ast.find_children(
            ir,
            lambda n: getattr(n, 'anchor', None) == qlast.Source,
            terminate_early=True)

        try:
            expr_type = ir_utils.infer_type(ir, self._schema)
        except edgeql.EdgeQLError as e:
            raise s_err.SchemaError(
                'could not determine the result type of the default '
                'expression on {!s}.{!s}'.format(source.name, ptr.shortname),
                context=expr.context) from e

        ptr.default = expr_text
        ptr.normalize_defaults()

        if ptr.is_pure_computable():
            # Pure computable without explicit target.
            # Fixup pointer target and target property.
            ptr.target = expr_type

            if isinstance(ptr, s_links.Link):
                if not isinstance(expr_type, s_objtypes.ObjectType):
                    raise s_err.SchemaDefinitionError(
                        f'invalid link target, expected object type, got '
                        f'{expr_type.__class__.__name__}',
                        context=ptrdecl.expr.context)
            else:
                if not isinstance(expr_type,
                                  (s_scalars.ScalarType, s_types.Collection)):
                    raise s_err.SchemaDefinitionError(
                        f'invalid property target, expected primitive type, '
                        f'got {expr_type.__class__.__name__}',
                        context=ptrdecl.expr.context)

            if isinstance(ptr, s_links.Link):
                pname = s_name.Name('std::target')
                tgt_prop = ptr.pointers[pname]
                tgt_prop.target = expr_type

            cardinality = self._get_literal_attribute(ptrdecl, 'cardinality')
            if cardinality is not None:
                raise s_err.SchemaError(
                    'computable links must not define explicit cardinality',
                    context=expr.context)

            # The anchored source is a single object, so paths rooted at it
            # must not be treated as sets when inferring cardinality.
            singletons = set()
            if self_set is not None:
                singletons.add(self_set.path_id)

            cardinality = \
                ir_inference.infer_cardinality(ir, singletons, self._schema)

            if cardinality == qlast.Cardinality.MANY:
                ptr.cardinality = s_pointers.PointerCardinality.ManyToMany
            else:
                ptr.cardinality = s_pointers.PointerCardinality.ManyToOne

        if (not isinstance(expr_type, s_types.Type) or
                (ptr.target is not None and
                 not expr_type.issubclass(ptr.target))):
            # ptr.target can be None on this branch (the subtype check above
            # is skipped in that case); don't let formatting the message
            # raise AttributeError and mask the SchemaError.
            target_name = ptr.target.name if ptr.target is not None else None
            raise s_err.SchemaError(
                'default value query must yield a single result of '
                'type {!r}'.format(target_name),
                context=expr.context)

        if not isinstance(ptr.target, s_scalars.ScalarType):
            # Non-scalar targets with a query default must be "*1"
            # (ManyToOne) or "**" (ManyToMany).
            many_mapping = (s_pointers.PointerCardinality.ManyToOne,
                            s_pointers.PointerCardinality.ManyToMany)
            if ptr.cardinality not in many_mapping:
                raise s_err.SchemaError(
                    'type links with query defaults '
                    'must have either a "*1" or "**" cardinality',
                    context=expr.context)