예제 #1
0
async def execute(server, query, variables):
    """Compile *query* and run it as a JSON query on the system database.

    :param server: server object providing ``acquire_pgcon`` /
        ``release_pgcon``; also passed to ``compile``.
    :param query: query text to compile.
    :param variables: mapping from query parameter name to value, or None
        when no parameters are supplied.
    :returns: the data returned by ``pgcon.parse_execute_json``.
    :raises errors.QueryError: if a parameter has no value, or a required
        parameter is given None.
    :raises errors.InternalServerError: if the backend returned no data.

    A compiled query that needs capabilities outside
    ``ALLOWED_CAPABILITIES`` is rejected with the error built by
    ``capabilities.make_error`` (given ``errors.UnsupportedCapabilityError``).
    """
    query_unit = await compile(server, query)

    disallowed = query_unit.capabilities & ~ALLOWED_CAPABILITIES
    if disallowed:
        raise query_unit.capabilities.make_error(
            ALLOWED_CAPABILITIES,
            errors.UnsupportedCapabilityError,
        )

    # Collect positional arguments in the order the compiled query
    # declares its input parameters.
    args = []
    for param in query_unit.in_type_args or ():
        if variables is None or param.name not in variables:
            raise errors.QueryError(
                f'no value for the ${param.name} query parameter')
        value = variables[param.name]
        if value is None and param.required:
            raise errors.QueryError(
                f'parameter ${param.name} is required')
        args.append(value)

    pgcon = await server.acquire_pgcon(edbdef.EDGEDB_SYSTEM_DB)
    try:
        data = await pgcon.parse_execute_json(
            query_unit.sql[0],
            query_unit.sql_hash,
            1,
            True,
            args,
        )
    finally:
        # Always hand the connection back, even if execution failed.
        server.release_pgcon(edbdef.EDGEDB_SYSTEM_DB, pgcon)

    if data is not None:
        return data

    raise errors.InternalServerError(
        f'no data received for a JSON query {query_unit.sql[0]!r}')
예제 #2
0
    def _cmd_tree_from_ast(
        cls,
        schema: s_schema.Schema,
        astnode: qlast.DDLOperation,
        context: sd.CommandContext,
    ) -> sd.Command:
        """Build the command tree for a scalar type DDL node.

        After delegating to the base implementation, this locates the
        create command: *cmd* itself, or, when a command group is returned,
        the first subcommand that is an instance of ``cls``. It then:

        * when the only base is an ``AnonymousEnumTypeRef``, stores the
          enum elements as ``enum_values`` and sets ``is_final`` to True;
        * wraps the value of every ``default`` property in a list, or
          rejects the default outright for enumerated types.

        Returns the original (possibly grouped) command.

        Raises:
            errors.InternalServerError: if a returned command group has no
                subcommand of type ``cls``.
            errors.UnsupportedFeatureError: if an enumerated type declares
                a default.
        """
        cmd = super()._cmd_tree_from_ast(schema, astnode, context)

        if isinstance(cmd, sd.CommandGroup):
            for subcmd in cmd.get_subcommands():
                if isinstance(subcmd, cls):
                    create_cmd: sd.Command = subcmd
                    break
            else:
                raise errors.InternalServerError(
                    'scalar alias definition did not return CreateScalarType')
        else:
            create_cmd = cmd

        bases = create_cmd.get_attribute_value('bases')
        is_enum = False
        if len(bases) == 1 and isinstance(bases._ids[0], AnonymousEnumTypeRef):
            # The "type: ignore" below is needed because ``elements`` is
            # set directly on the instance __dict__, so the type checker
            # does not know about it.
            elements = bases._ids[0].elements  # type: ignore
            create_cmd.set_attribute_value('enum_values', elements)
            create_cmd.set_attribute_value('is_final', True)
            is_enum = True

        for sub in create_cmd.get_subcommands(type=sd.AlterObjectProperty):
            if sub.property == 'default':
                if is_enum:
                    raise errors.UnsupportedFeatureError(
                        'enumerated types do not support defaults')
                sub.new_value = [sub.new_value]
        assert isinstance(cmd, (CreateScalarType, sd.CommandGroup))
        return cmd
예제 #3
0
파일: stmtctx.py 프로젝트: joe2hpimn/edgedb
def pend_pointer_cardinality_inference(
        *,
        ptrcls: s_pointers.Pointer,
        specified_card: typing.Optional[qltypes.Cardinality] = None,
        from_parent: bool = False,
        source_ctx: typing.Optional[parsing.ParserContext] = None,
        ctx: context.ContextLevel) -> None:
    """Schedule cardinality inference for *ptrcls*.

    Records a ``context.PendingCardinality`` for the pointer in
    ``ctx.pending_cardinality``. Scheduling the same pointer again is a
    no-op as long as the specified cardinality and ``from_parent`` flag
    agree with the existing entry.

    Raises:
        errors.InternalServerError: if the pointer is already scheduled
            with a different specified cardinality or ``from_parent``.
    """
    pending = ctx.pending_cardinality
    already = pending.get(ptrcls)

    if already is None:
        pending[ptrcls] = context.PendingCardinality(
            specified_cardinality=specified_card,
            source_ctx=source_ctx,
            from_parent=from_parent,
            callbacks=[],
        )
        return

    mismatch = (already.specified_cardinality != specified_card
                or already.from_parent != from_parent)
    if mismatch:
        raise errors.InternalServerError(
            f'cardinality inference for {ptrcls.get_name(ctx.env.schema)} '
            'is scheduled multiple times with different context'
        )
예제 #4
0
def resolve_special_anchor(anchor: qlast.SpecialAnchor, *,
                           ctx: context.ContextLevel) -> irast.Set:
    """Return the set bound to a special anchor such as ``__source__``.

    ``__source__`` and ``__subject__`` can syntactically only appear as
    the first label of a path, and their values must be pre-populated in
    ``ctx.anchors`` by the caller of compile().

    Raises:
        errors.InternalServerError: if *anchor* is not a
            ``qlast.SpecialAnchor``.
        errors.InvalidReferenceError: if no set is bound to the anchor's
            name in this context.
    """
    if not isinstance(anchor, qlast.SpecialAnchor):
        raise errors.InternalServerError(
            f'unexpected special anchor kind: {anchor!r}')

    token = anchor.name
    bound_set = ctx.anchors.get(token)
    if bound_set is None:
        raise errors.InvalidReferenceError(
            f'{token} cannot be used in this expression',
            context=anchor.context,
        )
    return bound_set
예제 #5
0
def expression_set(expr: irast.Expr,
                   path_id: Optional[irast.PathId] = None,
                   *,
                   type_override: Optional[s_types.Type] = None,
                   ctx: context.ContextLevel) -> irast.Set:
    """Wrap the non-Set IR expression *expr* in a new ``irast.Set``.

    The set's type is *type_override* when given; otherwise it is the
    type inferred for *expr*. The path id is chosen in this order:
    the *path_id* argument, then ``expr.path_id`` (if present and not
    None), then an expression path id derived from the type.

    Raises:
        errors.InternalServerError: if *expr* is already an ``irast.Set``.
    """
    if isinstance(expr, irast.Set):  # pragma: no cover
        raise errors.InternalServerError(f'{expr!r} is already a Set')

    stype = (type_override if type_override is not None
             else inference.infer_type(expr, ctx.env))

    if path_id is None:
        path_id = getattr(expr, 'path_id', None)
    if path_id is None:
        path_id = pathctx.get_expression_path_id(stype, ctx=ctx)

    return new_set(
        path_id=path_id,
        stype=stype,
        expr=expr,
        context=expr.context,
        ctx=ctx,
    )
예제 #6
0
def preload(
    allow_rebuild: bool = True,
    paralellize: bool = False,
    parsers: Optional[List[qlparser.EdgeQLParserBase]] = None,
) -> None:
    """Load the parser specs for *parsers*, rebuilding them if permitted.

    :param allow_rebuild:
        Whether a parser spec that is incompatible with the current
        grammar may be rebuilt. When False, an incompatible spec is
        reported as ``errors.InternalServerError``, regardless of
        *paralellize*.
    :param paralellize:
        When rebuilding is allowed and more than one spec needs it,
        rebuild them in a ``multiprocessing.Pool``.
    :param parsers:
        Parsers to preload. Defaults to the EdgeQL block, EdgeQL
        expression and SDL parsers.
    :raises errors.InternalServerError:
        If a spec is incompatible and cannot be rebuilt.
    """
    if parsers is None:
        parsers = [
            qlparser.EdgeQLBlockParser(),
            qlparser.EdgeQLExpressionParser(),
            qlparser.EdgeSDLParser(),
        ]

    # The parallel path below always rebuilds incompatible specs, so use
    # the sequential path whenever rebuilding is disallowed; this keeps
    # ``allow_rebuild=False`` meaningful even with ``paralellize=True``.
    if not paralellize or not allow_rebuild:
        try:
            for parser in parsers:
                parser.get_parser_spec(allow_rebuild)
        except parsing.ParserSpecIncompatibleError as e:
            raise errors.InternalServerError(e.args[0]) from None
        return

    # Find the specs that need rebuilding without rebuilding any yet.
    parsers_to_rebuild = []
    for parser in parsers:
        try:
            parser.get_parser_spec(allow_rebuild=False)
        except parsing.ParserSpecIncompatibleError:
            parsers_to_rebuild.append(parser)

    if not parsers_to_rebuild:
        return

    if len(parsers_to_rebuild) == 1:
        # A single rebuild is done in-process; no pool is needed.
        parsers_to_rebuild[0].get_parser_spec(allow_rebuild=True)
        return

    with multiprocessing.Pool(len(parsers_to_rebuild)) as pool:
        pool.map(_load_parser, parsers_to_rebuild)

    # The rebuilds ran in worker processes. Presumably the workers persist
    # the new specs; load everything here without rebuilding, so that any
    # spec that is still incompatible is reported as an error.
    preload(parsers=parsers, allow_rebuild=False)
예제 #7
0
    def _describe_type(self, t, view_shapes, view_shapes_metadata,
                       protocol_version,
                       follow_links: bool = True):
        """Append a type descriptor for *t* to ``self.buffer``.

        Descriptors of component types (tuple elements, the array element
        type, shape pointer targets) are emitted first through recursive
        calls. The new descriptor refers to them by their position in
        ``self.uuid_to_pos``. If the computed type id is already present
        in ``self.uuid_to_pos``, nothing is written and the id is returned.

        :param t: the type to describe. The branches here handle tuples,
            arrays and object types (views).
        :param view_shapes: mapping from a type (or, for link properties,
            from the pointer leading to a view) to the pointers in its
            shape.
        :param view_shapes_metadata: mapping from a type to shape
            metadata; only ``has_implicit_id`` is read here.
        :param protocol_version: client protocol version tuple; selects
            the encoding of shape element flags.
        :param follow_links: when False, singular links are described as
            ``std::uuid`` instead of their target type, and multi links
            raise an error.
        :returns: the type id of the descriptor.
        :raises errors.SchemaError: for unsupported collection types.
        :raises errors.InternalServerError: for multi links when
            ``follow_links`` is False, or for malformed implicit fields.
        """
        # The encoding format is documented in edb/api/types.txt.

        buf = self.buffer

        if isinstance(t, s_types.Tuple):
            # NOTE: follow_links is not passed to these recursive calls, so
            # element types are described with its default (True).
            subtypes = [self._describe_type(st, view_shapes,
                                            view_shapes_metadata,
                                            protocol_version)
                        for st in t.get_subtypes(self.schema)]

            if t.is_named(self.schema):
                element_names = list(t.get_element_names(self.schema))
                assert len(element_names) == len(subtypes)

                type_id = self._get_collection_type_id(
                    t.schema_name, subtypes, element_names)

                # An id already in uuid_to_pos needs no new descriptor.
                if type_id in self.uuid_to_pos:
                    return type_id

                # Layout: tag, type id bytes, uint16 element count, then
                # for each element a uint32 name length, the UTF-8 name
                # and the uint16 position of the element's descriptor.
                buf.append(CTYPE_NAMEDTUPLE)
                buf.append(type_id.bytes)
                buf.append(_uint16_packer(len(subtypes)))
                for el_name, el_type in zip(element_names, subtypes):
                    el_name_bytes = el_name.encode('utf-8')
                    buf.append(_uint32_packer(len(el_name_bytes)))
                    buf.append(el_name_bytes)
                    buf.append(_uint16_packer(self.uuid_to_pos[el_type]))

            else:
                type_id = self._get_collection_type_id(t.schema_name, subtypes)

                if type_id in self.uuid_to_pos:
                    return type_id

                # Layout: tag, type id bytes, uint16 element count, then
                # the uint16 descriptor position of each element type.
                buf.append(CTYPE_TUPLE)
                buf.append(type_id.bytes)
                buf.append(_uint16_packer(len(subtypes)))
                for el_type in subtypes:
                    buf.append(_uint16_packer(self.uuid_to_pos[el_type]))

            self._register_type_id(type_id)
            return type_id

        elif isinstance(t, s_types.Array):
            # As for tuples, follow_links is not propagated here.
            subtypes = [self._describe_type(st, view_shapes,
                                            view_shapes_metadata,
                                            protocol_version)
                        for st in t.get_subtypes(self.schema)]

            assert len(subtypes) == 1
            type_id = self._get_collection_type_id(t.schema_name, subtypes)

            if type_id in self.uuid_to_pos:
                return type_id

            # Layout: tag, type id bytes, uint16 element descriptor
            # position, uint16 dimension count, int32 dimension length.
            buf.append(CTYPE_ARRAY)
            buf.append(type_id.bytes)
            buf.append(_uint16_packer(self.uuid_to_pos[subtypes[0]]))
            # Number of dimensions (currently always 1)
            buf.append(_uint16_packer(1))
            # Dimension cardinality (currently always unbound)
            buf.append(_int32_packer(-1))

            self._register_type_id(type_id)
            return type_id

        elif isinstance(t, s_types.Collection):
            raise errors.SchemaError(f'unsupported collection type {t!r}')

        elif isinstance(t, s_objtypes.ObjectType):
            # This is a view
            # material_type() returns an updated schema along with the
            # material type; the schema is stored back on self.
            self.schema, mt = t.material_type(self.schema)
            base_type_id = mt.id

            # Parallel per-element lists, one entry per shape element.
            subtypes = []
            element_names = []
            link_props = []
            links = []
            cardinalities = []

            metadata = view_shapes_metadata.get(t)
            implicit_id = metadata is not None and metadata.has_implicit_id

            for ptr in view_shapes.get(t, ()):
                if ptr.singular(self.schema):
                    if isinstance(ptr, s_links.Link) and not follow_links:
                        # Without following links, a singular link is
                        # encoded as a std::uuid value.
                        subtype_id = self._describe_type(
                            self.schema.get('std::uuid'), view_shapes,
                            view_shapes_metadata, protocol_version,
                        )
                    else:
                        subtype_id = self._describe_type(
                            ptr.get_target(self.schema), view_shapes,
                            view_shapes_metadata, protocol_version)
                else:
                    if isinstance(ptr, s_links.Link) and not follow_links:
                        raise errors.InternalServerError(
                            'cannot describe multi links when '
                            'follow_links=False'
                        )
                    else:
                        # Non-singular pointers are described as sets of
                        # their target type.
                        subtype_id = self._describe_set(
                            ptr.get_target(self.schema), view_shapes,
                            view_shapes_metadata, protocol_version)
                subtypes.append(subtype_id)
                element_names.append(ptr.get_shortname(self.schema).name)
                link_props.append(False)
                links.append(not ptr.is_property(self.schema))
                cardinalities.append(
                    cardinality_from_ptr(ptr, self.schema).value)

            # Pointers in the shape of the pointer that leads to this view
            # are link properties; they are appended as extra elements
            # flagged as link properties (and never as links).
            t_rptr = t.get_rptr(self.schema)
            if t_rptr is not None and (rptr_ptrs := view_shapes.get(t_rptr)):
                # There are link properties in the mix
                for ptr in rptr_ptrs:
                    if ptr.singular(self.schema):
                        subtype_id = self._describe_type(
                            ptr.get_target(self.schema), view_shapes,
                            view_shapes_metadata, protocol_version)
                    else:
                        subtype_id = self._describe_set(
                            ptr.get_target(self.schema), view_shapes,
                            view_shapes_metadata, protocol_version)
                    subtypes.append(subtype_id)
                    element_names.append(
                        ptr.get_shortname(self.schema).name)
                    link_props.append(True)
                    links.append(False)
                    cardinalities.append(
                        cardinality_from_ptr(ptr, self.schema).value)

            type_id = self._get_object_type_id(
                base_type_id, subtypes, element_names,
                links_props=link_props, links=links,
                has_implicit_fields=implicit_id)

            if type_id in self.uuid_to_pos:
                return type_id

            buf.append(CTYPE_SHAPE)
            buf.append(type_id.bytes)

            assert len(subtypes) == len(element_names)
            buf.append(_uint16_packer(len(subtypes)))

            zipped_parts = list(zip(element_names, subtypes, link_props, links,
                                    cardinalities))
            for el_name, el_type, el_lp, el_l, el_c in zipped_parts:
                # Build the per-element flag bits.
                flags = 0
                if el_lp:
                    flags |= self.EDGE_POINTER_IS_LINKPROP
                # Implicit fields: 'id' (only when the shape metadata
                # says the id is implicit) and '__tid__' must be uuid;
                # '__tname__' must be str.
                if (implicit_id and el_name == 'id') or el_name == '__tid__':
                    if el_type != UUID_TYPE_ID:
                        raise errors.InternalServerError(
                            f"{el_name!r} is expected to be a 'std::uuid' "
                            f"singleton")
                    flags |= self.EDGE_POINTER_IS_IMPLICIT
                elif el_name == '__tname__':
                    if el_type != STR_TYPE_ID:
                        raise errors.InternalServerError(
                            f"{el_name!r} is expected to be a 'std::str' "
                            f"singleton")
                    flags |= self.EDGE_POINTER_IS_IMPLICIT
                if el_l:
                    flags |= self.EDGE_POINTER_IS_LINK

                # Protocol >= 0.11: uint32 flags followed by a uint8
                # cardinality. Older protocols: uint8 flags only, and
                # the cardinality is not sent.
                if protocol_version >= (0, 11):
                    buf.append(_uint32_packer(flags))
                    buf.append(_uint8_packer(el_c))
                else:
                    buf.append(_uint8_packer(flags))

                # Then a uint32 name length, the UTF-8 name and the uint16
                # position of the element's type descriptor.
                el_name_bytes = el_name.encode('utf-8')
                buf.append(_uint32_packer(len(el_name_bytes)))
                buf.append(el_name_bytes)
                buf.append(_uint16_packer(self.uuid_to_pos[el_type]))

            self._register_type_id(type_id)
            return type_id
예제 #8
0
def range_for_ptrref(
    ptrref: irast.BasePointerRef,
    *,
    dml_source: Optional[irast.MutatingStmt] = None,
    for_mutation: bool = False,
    only_self: bool = False,
    ctx: context.CompilerContextLevel,
) -> pgast.PathRangeVar:
    """Return a Range subclass corresponding to a given ptr step.

    The return value may potentially be a UNION of all tables
    corresponding to a set of specialized links computed from the given
    `ptrref` taking source inheritance into account.

    For every component pointer ref, a SELECT of the ``source`` and target
    columns from its table is added to the set operations. Unless
    *for_mutation* is set, the DML overlays returned by
    ``get_ptr_rel_overlays`` for that component are added as well, each
    using the operation the overlay specifies.

    :raises errors.InternalServerError:
        If *only_self* is set and *ptrref* is a union of more than one
        component.
    """
    tgt_col = pg_types.get_ptrref_storage_info(
        ptrref, resolve_type=False, link_bias=True).column_name

    cols = ['source', tgt_col]

    set_ops = []

    if ptrref.union_components:
        refs = ptrref.union_components
        if only_self and len(refs) > 1:
            raise errors.InternalServerError('unexpected union link')
    else:
        refs = {ptrref}

    for src_ptrref in refs:
        assert isinstance(src_ptrref, irast.PointerRef), \
            "expected regular PointerRef"
        table = table_from_ptrref(
            src_ptrref,
            include_descendants=not ptrref.union_is_concrete,
            for_mutation=for_mutation,
            ctx=ctx,
        )

        qry = pgast.SelectStmt()
        qry.from_clause.append(table)

        # Make sure all property references are pulled up properly
        for colname in cols:
            selexpr = pgast.ColumnRef(name=[table.alias.aliasname, colname])
            qry.target_list.append(pgast.ResTarget(val=selexpr, name=colname))

        set_ops.append(('union', qry))

        overlays = get_ptr_rel_overlays(src_ptrref,
                                        dml_source=dml_source,
                                        ctx=ctx)
        if overlays and not for_mutation:
            for op, cte in overlays:
                rvar = pgast.RelRangeVar(
                    relation=cte,
                    alias=pgast.Alias(aliasname=ctx.env.aliases.get(cte.name)))

                qry = pgast.SelectStmt(
                    target_list=[
                        pgast.ResTarget(val=pgast.ColumnRef(name=[col]))
                        for col in cols
                    ],
                    from_clause=[rvar],
                )
                set_ops.append((op, qry))

    return range_from_queryset(set_ops, ptrref.shortname, ctx=ctx)
예제 #9
0
                bt_id = self._describe_type(
                    base_type, view_shapes, view_shapes_metadata,
                    protocol_version)

                buf.append(CTYPE_SCALAR)
                buf.append(type_id.bytes)
                buf.append(_uint16_packer(self.uuid_to_pos[bt_id]))

                if self.inline_typenames:
                    self._add_annotation(mt)

            self._register_type_id(type_id)
            return type_id

        else:
            raise errors.InternalServerError(
                f'cannot describe type {t.get_name(self.schema)}')

    def _add_annotation(self, t: s_types.Type):
        """Append a type-name annotation for *t* to ``self.anno_buffer``.

        The annotation is the ``CTYPE_ANNO_TYPENAME`` tag, the bytes of
        ``t.id``, the uint32 byte length of the display name, and the
        UTF-8 encoded display name itself.
        """
        out = self.anno_buffer
        out.append(CTYPE_ANNO_TYPENAME)
        out.append(t.id.bytes)

        encoded_name = t.get_displayname(self.schema).encode('utf-8')
        out.append(_uint32_packer(len(encoded_name)))
        out.append(encoded_name)

    @classmethod
    def describe(
        cls, schema, typ, view_shapes, view_shapes_metadata,
예제 #10
0
def static_interpret_backend_error(fields):
    """Translate a backend error without consulting a schema, if possible.

    Returns an error instance (not raised). When a translation needs
    schema information, the ``SchemaRequired`` marker is returned instead.
    Presumably the caller then falls back to a schema-aware translation;
    confirm with callers.
    """
    err_details = get_error_details(fields)

    # handle some generic errors if possible
    generic = get_generic_exception_from_err_details(err_details)
    if generic is not None:
        return generic

    code = err_details.code

    if code == PGErrorCode.NotNullViolationError:
        if err_details.schema_name and err_details.table_name:
            return SchemaRequired
        return errors.InternalServerError(err_details.message)

    if code in constraint_errors:
        return _static_interpret_constraint_error(err_details)

    if code in SCHEMA_CODES:
        return SchemaRequired

    if code == PGErrorCode.InvalidParameterValue:
        return errors.InvalidValueError(
            err_details.message,
            details=err_details.detail or None)

    if code == PGErrorCode.DivisionByZeroError:
        return errors.DivisionByZeroError(err_details.message)

    if code == PGErrorCode.ReadOnlySQLTransactionError:
        return errors.TransactionError(
            'cannot execute query in a read-only transaction')

    if code == PGErrorCode.TransactionSerializationFailure:
        return errors.TransactionSerializationError(err_details.message)

    if code == PGErrorCode.TransactionDeadlockDetected:
        return errors.TransactionDeadlockError(err_details.message)

    return errors.InternalServerError(err_details.message)


def _static_interpret_constraint_error(err_details):
    """Schema-free translation of a constraint-violation backend error.

    The kind of violation is identified by the first pattern in
    ``constraint_res`` that matches the error message. Unrecognized
    messages and unhandled kinds become InternalServerError.
    """
    for errtype, pattern in constraint_res.items():
        if pattern.match(err_details.message):
            break
    else:
        return errors.InternalServerError(err_details.message)

    if errtype == 'cardinality':
        return errors.CardinalityViolationError('cardinality violation',
                                                source=None,
                                                pointer=None)

    if errtype == 'link_target':
        info = err_details.detail_json
        if not info:
            return errors.UnknownLinkError('invalid target for link')

        srcname = info.get('source')
        ptrname = info.get('pointer')
        target = info.get('target')
        expected = info.get('expected')

        if srcname and ptrname:
            lname = '{}.{}'.format(sn.Name(srcname), sn.Name(ptrname).name)
        else:
            lname = ''

        return errors.UnknownLinkError(
            f'invalid target for link {lname!r}: {target!r} '
            f'(expecting {expected!r})')

    if errtype == 'link_target_del':
        return errors.ConstraintViolationError(err_details.message,
                                               details=err_details.detail)

    if errtype == 'constraint':
        if err_details.constraint_name is None:
            return errors.InternalServerError(err_details.message)
        raw_id = err_details.constraint_name.rpartition(';')[0]
        # Only validate that the id parses; resolving it needs a schema.
        try:
            uuid.UUID(raw_id)
        except ValueError:
            return errors.InternalServerError(err_details.message)
        return SchemaRequired

    if errtype == 'id':
        return errors.ConstraintViolationError(
            'unique link constraint violation')

    return errors.InternalServerError(err_details.message)
예제 #11
0
def interpret_backend_error(schema, fields):
    """Translate a backend error into an EdgeDB error using *schema*.

    *schema* is used to resolve backend table names and constraint ids to
    schema objects. Returns an error instance (not raised). Cases not
    handled here become InternalServerError.
    """
    err_details = get_error_details(fields)
    # all generic errors are static and have been handled by this point
    code = err_details.code

    if code == PGErrorCode.NotNullViolationError:
        source_name = pointer_name = None

        if err_details.schema_name and err_details.table_name:
            source = common.get_object_from_backend_name(
                schema, s_objtypes.ObjectType,
                (err_details.schema_name, err_details.table_name))
            source_name = source.get_displayname(schema)
            if err_details.column_name:
                pointer_name = err_details.column_name

        if pointer_name is None:
            return errors.InternalServerError(err_details.message)

        return errors.MissingRequiredError(
            'missing value for required property '
            f'{source_name}.{pointer_name}')

    if code in constraint_errors:
        matched = None
        for errtype, pattern in constraint_res.items():
            if pattern.match(err_details.message):
                matched = errtype
                break
        # An unmatched message never reaches this point: the static
        # version has already handled it.

        # 'constraint' is currently the only kind expected here.
        if matched == 'constraint':
            # A parseable constraint id was already checked by the static
            # version.
            constraint_id = uuid.UUID(
                err_details.constraint_name.rpartition(';')[0])
            constraint = schema.get_by_id(constraint_id)
            return errors.ConstraintViolationError(
                constraint.format_error_message(schema))

        return errors.InternalServerError(err_details.message)

    if code == PGErrorCode.InvalidTextRepresentation:
        return errors.InvalidValueError(
            translate_pgtype(schema, err_details.message))

    if code == PGErrorCode.NumericValueOutOfRange:
        return errors.NumericOutOfRangeError(
            translate_pgtype(schema, err_details.message))

    if code in {PGErrorCode.InvalidDatetimeFormatError,
                PGErrorCode.DatetimeError}:
        return errors.InvalidValueError(
            translate_pgtype(schema, err_details.message))

    return errors.InternalServerError(err_details.message)
예제 #12
0
def interpret_backend_error(schema, fields):
    """Translate a PostgreSQL error message into an EdgeDB error.

    :param schema: schema used to resolve backend table names and
        constraint ids to schema objects.
    :param fields: mapping from PostgreSQL error field codes (e.g. ``'C'``
        for the SQLSTATE code, ``'M'`` for the message) to their values.
    :returns: an error instance (not raised). Errors that cannot be
        translated are returned as errors.InternalServerError.
    """
    # See https://www.postgresql.org/docs/current/protocol-error-fields.html
    # for the full list of PostgreSQL error message fields.
    message = fields.get('M')

    try:
        code = PGError(fields['C'])
    except ValueError:
        return errors.InternalServerError(message)

    schema_name = fields.get('s')
    table_name = fields.get('t')
    column_name = fields.get('c')
    detail = fields.get('D')
    constraint_name = fields.get('n')

    if code == PGError.NotNullViolationError:
        source_name = pointer_name = None

        if schema_name and table_name:
            tabname = (schema_name, table_name)

            source = common.get_object_from_backend_name(
                schema, s_objtypes.ObjectType, tabname)
            source_name = source.get_displayname(schema)

            if column_name:
                pointer_name = column_name

        if pointer_name is not None:
            pname = f'{source_name}.{pointer_name}'

            return errors.MissingRequiredError(
                f'missing value for required property {pname}')

        else:
            return errors.InternalServerError(message)

    elif code in constraint_errors:
        source = pointer = None

        # Identify the kind of violation by the first matching pattern.
        for errtype, ere in constraint_res.items():
            if ere.match(message):
                error_type = errtype
                break
        else:
            return errors.InternalServerError(message)

        if error_type == 'cardinality':
            return errors.CardinalityViolationError('cardinality violation',
                                                    source=source,
                                                    pointer=pointer)

        elif error_type == 'link_target':
            link_info = None
            if detail:
                try:
                    link_info = json.loads(detail)
                except ValueError:
                    link_info = None

            # Only a JSON object carries the fields used below. An empty
            # or non-object detail previously crashed on ``.get``.
            if isinstance(link_info, dict):
                srcname = link_info.get('source')
                ptrname = link_info.get('pointer')
                target = link_info.get('target')
                expected = link_info.get('expected')

                if srcname and ptrname:
                    srcname = sn.Name(srcname)
                    ptrname = sn.Name(ptrname)
                    lname = '{}.{}'.format(srcname, ptrname.name)
                else:
                    lname = ''

                msg = (f'invalid target for link {lname!r}: {target!r} '
                       f'(expecting {expected!r})')

            else:
                msg = 'invalid target for link'

            return errors.UnknownLinkError(msg)

        elif error_type == 'link_target_del':
            return errors.ConstraintViolationError(message, details=detail)

        elif error_type == 'constraint':
            if constraint_name is None:
                return errors.InternalServerError(message)

            # The constraint id is the part of the backend constraint name
            # before the last ';'.
            constraint_id, _, _ = constraint_name.rpartition(';')

            try:
                constraint_id = uuid.UUID(constraint_id)
            except ValueError:
                return errors.InternalServerError(message)

            constraint = schema.get_by_id(constraint_id)

            return errors.ConstraintViolationError(
                constraint.format_error_message(schema))

        elif error_type == 'id':
            return errors.ConstraintViolationError(
                'unique link constraint violation')

    elif code == PGError.NumericValueOutOfRange:
        return errors.NumericOutOfRangeError(message)

    return errors.InternalServerError(message)
예제 #13
0
def interpret_backend_error(schema, fields):
    """Translate a backend error into an EdgeDB error using *schema*.

    *schema* resolves backend column/table names and constraint ids to
    schema objects. A ``hint`` from the error's JSON detail is forwarded
    to MissingRequiredError and to the datetime InvalidValueError.
    Returns an error instance (not raised). Cases not handled here become
    InternalServerError.
    """
    err_details = get_error_details(fields)
    detail_json = err_details.detail_json
    hint = detail_json.get('hint') if detail_json else None

    # all generic errors are static and have been handled by this point
    code = err_details.code

    if code == PGErrorCode.NotNullViolationError:
        colname = err_details.column_name
        if not colname:
            return errors.InternalServerError(err_details.message)

        # Column names prefixed with '??' carry the pointer id before the
        # first '_' after the prefix; others are the pointer id itself.
        if colname.startswith('??'):
            ptr_id = colname[2:].partition('_')[0]
        else:
            ptr_id = colname
        pointer = common.get_object_from_backend_name(
            schema, s_pointers.Pointer, ptr_id)
        pname = pointer.get_verbosename(schema, with_parent=True)
        if pname is None:
            return errors.InternalServerError(err_details.message)

        details = None
        if detail_json:
            object_id = detail_json.get('object_id')
            if object_id is not None:
                details = f'Failing object id is {str(object_id)!r}.'

        return errors.MissingRequiredError(
            f'missing value for required {pname}',
            details=details,
            hint=hint,
        )

    if code in constraint_errors:
        error_type = match = None
        for errtype, pattern in constraint_res.items():
            found = pattern.match(err_details.message)
            if found:
                error_type, match = errtype, found
                break
        # An unmatched message never reaches this point: the static
        # version has already handled it.

        if error_type == 'constraint':
            # A parseable constraint id was already checked by the static
            # version.
            constraint_id = uuidgen.UUID(
                err_details.constraint_name.rpartition(';')[0])
            constraint = schema.get_by_id(constraint_id)
            return errors.ConstraintViolationError(
                constraint.format_error_message(schema))

        if error_type == 'newconstraint':
            # If we're here, it means that we already validated that
            # schema_name, table_name and column_name all exist.
            source = common.get_object_from_backend_name(
                schema, s_objtypes.ObjectType,
                (err_details.schema_name, err_details.table_name))
            source_name = source.get_displayname(schema)
            pointer = common.get_object_from_backend_name(
                schema, s_pointers.Pointer, err_details.column_name)
            pointer_name = pointer.get_shortname(schema).name
            return errors.ConstraintViolationError(
                f'Existing {source_name}.{pointer_name} '
                f'values violate the new constraint')

        if error_type == 'scalar':
            stype_name = types.base_type_name_map_r.get(match.group(1))
            if stype_name:
                msg = f'invalid value for scalar type {str(stype_name)!r}'
            else:
                msg = translate_pgtype(schema, err_details.message)
            return errors.InvalidValueError(msg)

        return errors.InternalServerError(err_details.message)

    if code == PGErrorCode.InvalidTextRepresentation:
        return errors.InvalidValueError(
            translate_pgtype(schema, err_details.message))

    if code == PGErrorCode.NumericValueOutOfRange:
        return errors.NumericOutOfRangeError(
            translate_pgtype(schema, err_details.message))

    if code in {PGErrorCode.InvalidDatetimeFormatError,
                PGErrorCode.DatetimeError}:
        return errors.InvalidValueError(
            translate_pgtype(schema, err_details.message),
            hint=hint,
        )

    return errors.InternalServerError(err_details.message)
예제 #14
0
def process_link_update(
    *,
    ir_stmt: irast.MutatingStmt,
    ir_set: irast.Set,
    props_only: bool,
    is_insert: bool,
    shape_op: qlast.ShapeOp = qlast.ShapeOp.ASSIGN,
    source_typeref: irast.TypeRef,
    wrapper: pgast.Query,
    dml_cte: pgast.CommonTableExpr,
    iterator_cte: Optional[pgast.CommonTableExpr],
    ctx: context.CompilerContextLevel,
) -> pgast.CommonTableExpr:
    """Perform updates to a link relation as part of a DML statement.

    Appends to ``ctx.toplevel_stmt.ctes`` a CTE producing the new link
    rows and, depending on the operation, a DELETE CTE removing existing
    link rows and/or an INSERT CTE adding the new ones.  The DELETE and
    INSERT CTEs are also registered as 'except' / 'union' pointer
    relation overlays so that references to the link elsewhere in the
    statement observe their effect.

    :param ir_stmt:
        IR of the statement.
    :param ir_set:
        IR of the INSERT/UPDATE body element.
    :param props_only:
        Whether this link update only touches link properties.
    :param is_insert:
        Whether the enclosing statement is an INSERT.  When true no
        existing link rows are deleted and the link INSERT is emitted
        without an ON CONFLICT clause.
    :param shape_op:
        Shape operation of the element.  APPEND never deletes existing
        rows; SUBTRACT deletes only the rows matching the new data and
        inserts nothing; any other operation (outside of INSERT) first
        deletes all link rows of the source.
    :param source_typeref:
        Type of the object being modified; used to select the pointer
        descriptor specialized for that type.
    :param wrapper:
        Top-level SQL query.  (Not referenced in this function body.)
    :param dml_cte:
        CTE representing the SQL INSERT or UPDATE to the main
        relation of the Object.
    :param iterator_cte:
        CTE representing the iterator range in the FOR clause of the
        EdgeQL DML statement.
    :param ctx:
        Compiler context; generated CTEs go to ``ctx.toplevel_stmt``.
    :return:
        The CTE holding the link data rows.
    :raises errors.InternalServerError:
        If neither the material pointer nor any of its descendants has
        ``source_typeref`` as its source.
    """
    toplevel = ctx.toplevel_stmt

    rptr = ir_set.rptr
    ptrref = rptr.ptrref
    assert isinstance(ptrref, irast.PointerRef)
    target_is_scalar = irtyputils.is_scalar(ir_set.typeref)
    path_id = ir_set.path_id

    # The links in the dml class shape have been derived,
    # but we must use the correct specialized link class for the
    # base material type.
    if ptrref.material_ptr is not None:
        mptrref = ptrref.material_ptr
    else:
        mptrref = ptrref

    if mptrref.out_source.id != source_typeref.id:
        # Find the descendant pointer whose source is exactly the type
        # being modified.
        for descendant in mptrref.descendants:
            if descendant.out_source.id == source_typeref.id:
                mptrref = descendant
                break
        else:
            raise errors.InternalServerError(
                'missing PointerRef descriptor for source typeref')

    assert isinstance(mptrref, irast.PointerRef)

    target_rvar = relctx.range_for_ptrref(mptrref,
                                          for_mutation=True,
                                          only_self=True,
                                          ctx=ctx)
    assert isinstance(target_rvar, pgast.RelRangeVar)
    assert isinstance(target_rvar.relation, pgast.Relation)
    target_alias = target_rvar.alias.aliasname

    target_tab_name = (target_rvar.relation.schemaname,
                       target_rvar.relation.name)

    dml_cte_rvar = pgast.RelRangeVar(
        relation=dml_cte,
        alias=pgast.Alias(aliasname=ctx.env.aliases.get('m')))

    # Column values common to every produced link row: the id of the
    # (specialized) pointer cast to uuid, and the identity of the
    # subject object as exposed by the main DML CTE.
    col_data = {
        'ptr_item_id':
        pgast.TypeCast(arg=pgast.StringConstant(val=str(mptrref.id)),
                       type_name=pgast.TypeName(name=('uuid', ))),
        'source':
        pathctx.get_rvar_path_identity_var(dml_cte_rvar,
                                           ir_stmt.subject.path_id,
                                           env=ctx.env)
    }

    # Turn the IR of the expression on the right side of :=
    # into a subquery returning records for the link table.
    data_cte, specified_cols = process_link_values(
        ir_stmt=ir_stmt,
        ir_expr=ir_set,
        target_tab=target_tab_name,
        col_data=col_data,
        dml_rvar=dml_cte_rvar,
        sources=[],
        props_only=props_only,
        target_is_scalar=target_is_scalar,
        iterator_cte=iterator_cte,
        ctx=ctx,
    )

    toplevel.ctes.append(data_cte)

    delqry: Optional[pgast.DeleteStmt]

    # SELECT <data_cte>.* FROM <data_cte>: the row source for the link
    # INSERT and, for SUBTRACT, for the DELETE's USING clause.
    data_select = pgast.SelectStmt(
        target_list=[
            pgast.ResTarget(val=pgast.ColumnRef(
                name=[data_cte.name, pgast.Star()]), ),
        ],
        from_clause=[
            pgast.RelRangeVar(relation=data_cte),
        ],
    )

    if not is_insert and shape_op is not qlast.ShapeOp.APPEND:
        if shape_op is qlast.ShapeOp.SUBTRACT:
            data_rvar = relctx.rvar_for_rel(data_select, ctx=ctx)

            # Drop requested link records.
            delqry = pgast.DeleteStmt(
                relation=target_rvar,
                where_clause=astutils.new_binop(
                    lexpr=astutils.new_binop(
                        lexpr=col_data['source'],
                        op='=',
                        rexpr=pgast.ColumnRef(name=[target_alias, 'source'], ),
                    ),
                    op='AND',
                    rexpr=astutils.new_binop(
                        lexpr=pgast.ColumnRef(name=[target_alias, 'target'], ),
                        op='=',
                        rexpr=pgast.ColumnRef(
                            name=[data_rvar.alias.aliasname, 'target'], ),
                    ),
                ),
                using_clause=[
                    dml_cte_rvar,
                    data_rvar,
                ],
                returning_list=[
                    pgast.ResTarget(val=pgast.ColumnRef(
                        name=[target_alias, pgast.Star()], ), )
                ])
        else:
            # Drop all previous link records for this source.
            delqry = pgast.DeleteStmt(
                relation=target_rvar,
                where_clause=astutils.new_binop(
                    lexpr=col_data['source'],
                    op='=',
                    rexpr=pgast.ColumnRef(name=[target_alias, 'source'], ),
                ),
                using_clause=[dml_cte_rvar],
                returning_list=[
                    pgast.ResTarget(val=pgast.ColumnRef(
                        name=[target_alias, pgast.Star()], ), )
                ])

        delcte = pgast.CommonTableExpr(
            name=ctx.env.aliases.get(hint='d'),
            query=delqry,
        )

        pathctx.put_path_value_rvar(delcte.query,
                                    path_id.ptr_path(),
                                    target_rvar,
                                    env=ctx.env)

        # Record the effect of this removal in the relation overlay
        # context to ensure that references to the link in the result
        # of this DML statement yield the expected results.
        dml_stack = get_dml_stmt_stack(ir_stmt, ctx=ctx)
        relctx.add_ptr_rel_overlay(ptrref,
                                   'except',
                                   delcte,
                                   dml_stmts=dml_stack,
                                   ctx=ctx)
        toplevel.ctes.append(delcte)
    else:
        delqry = None

    # SUBTRACT only removes rows; there is nothing to insert.
    if shape_op is qlast.ShapeOp.SUBTRACT:
        return data_cte

    cols = [pgast.ColumnRef(name=[col]) for col in specified_cols]
    conflict_cols = ['source', 'target', 'ptr_item_id']

    if is_insert:
        conflict_clause = None
    elif len(cols) == len(conflict_cols) and delqry is not None:
        # There are no link properties, so we can optimize the
        # link replacement operation by omitting the overlapping
        # link rows from deletion.
        filter_select = pgast.SelectStmt(
            target_list=[
                pgast.ResTarget(val=pgast.ColumnRef(name=['source']), ),
                pgast.ResTarget(val=pgast.ColumnRef(name=['target']), ),
            ],
            from_clause=[pgast.RelRangeVar(relation=data_cte)],
        )

        # Extend the DELETE with (source, target) != ALL (<new rows>),
        # i.e. keep existing rows that the INSERT below would re-add;
        # the INSERT then skips them via ON CONFLICT DO NOTHING.
        delqry.where_clause = astutils.extend_binop(
            delqry.where_clause,
            astutils.new_binop(
                lexpr=pgast.ImplicitRowExpr(args=[
                    pgast.ColumnRef(name=['source']),
                    pgast.ColumnRef(name=['target']),
                ], ),
                rexpr=pgast.SubLink(
                    type=pgast.SubLinkType.ALL,
                    expr=filter_select,
                ),
                op='!=',
            ))

        conflict_clause = pgast.OnConflictClause(
            action='nothing',
            infer=pgast.InferClause(index_elems=[
                pgast.ColumnRef(name=[col]) for col in conflict_cols
            ]),
        )
    else:
        # Inserting rows into the link table may produce cardinality
        # constraint violations, since the INSERT into the link table
        # is executed in the snapshot where the above DELETE from
        # the link table is not visible.  Hence, we need to use
        # the ON CONFLICT clause to resolve this.
        conflict_inference = []
        conflict_exc_row = []

        for col in conflict_cols:
            conflict_inference.append(pgast.ColumnRef(name=[col]))
            conflict_exc_row.append(pgast.ColumnRef(name=['excluded', col]))

        # On conflict, re-read the full data row matching the conflicting
        # (excluded) key and assign all specified columns from it.
        conflict_data = pgast.SelectStmt(
            target_list=[
                pgast.ResTarget(val=pgast.ColumnRef(
                    name=[data_cte.name, pgast.Star()]))
            ],
            from_clause=[pgast.RelRangeVar(relation=data_cte)],
            where_clause=astutils.new_binop(
                lexpr=pgast.ImplicitRowExpr(args=conflict_inference),
                rexpr=pgast.ImplicitRowExpr(args=conflict_exc_row),
                op='='))

        conflict_clause = pgast.OnConflictClause(
            action='update',
            infer=pgast.InferClause(index_elems=conflict_inference),
            target_list=[
                pgast.MultiAssignRef(columns=cols, source=conflict_data)
            ])

    updcte = pgast.CommonTableExpr(
        name=ctx.env.aliases.get(hint='i'),
        query=pgast.InsertStmt(
            relation=target_rvar,
            select_stmt=data_select,
            cols=cols,
            on_conflict=conflict_clause,
            returning_list=[
                pgast.ResTarget(val=pgast.ColumnRef(name=[pgast.Star()]))
            ]))

    pathctx.put_path_value_rvar(updcte.query,
                                path_id.ptr_path(),
                                target_rvar,
                                env=ctx.env)

    # Record the effect of this insertion in the relation overlay
    # context to ensure that references to the link in the result
    # of this DML statement yield the expected results.
    dml_stack = get_dml_stmt_stack(ir_stmt, ctx=ctx)
    relctx.add_ptr_rel_overlay(ptrref,
                               'union',
                               updcte,
                               dml_stmts=dml_stack,
                               ctx=ctx)
    toplevel.ctes.append(updcte)

    return data_cte
예제 #15
0
def compile_DescribeStmt(
        ql: qlast.DescribeStmt, *, ctx: context.ContextLevel) -> irast.Set:
    """Compile a DESCRIBE statement into a single ``std::str`` set.

    Without an object only ``DESCRIBE SCHEMA`` as DDL is accepted.  With
    an object, the target is either a whole module or a named schema
    item.  When the item class is unspecified or FUNCTION, all function
    overloads matching the name are included; if none are found the
    name is resolved as a regular schema object.  The output is rendered
    as DDL, SDL or descriptive TEXT; non-VERBOSE TEXT output passes
    links and properties as the referenced classes.

    Raises:
        errors.QueryError: the full schema was requested in a language
            other than DDL.
        errors.InternalServerError: the describe language is not one
            of DDL, SDL or TEXT.
    """
    with ctx.subquery() as subctx:
        result_stmt = irast.SelectStmt()
        init_stmt(result_stmt, ql, ctx=subctx, parent_ctx=ctx)

        if ql.object:
            wanted_modules = []
            wanted_items: List[str] = []
            ref_classes: List[s_obj.ObjectMeta] = []

            objref = ql.object
            itemclass = objref.itemclass

            if itemclass is qltypes.SchemaObjectClass.MODULE:
                wanted_modules.append(objref.name)
            else:
                qualname: str = (
                    s_name.Name(module=objref.module, name=objref.name)
                    if objref.module
                    else objref.name
                )

                itemtype: Optional[Type[s_obj.Object]] = (
                    s_obj.ObjectMeta.get_schema_metaclass_for_ql_class(
                        itemclass)
                    if itemclass is not None
                    else None
                )

                # Functions may be overloaded, so they are collected
                # separately from the single-object lookup below.
                resolved_as_function = False
                if (itemclass is None
                        or itemclass is qltypes.SchemaObjectClass.FUNCTION):
                    try:
                        funcs: Tuple[s_func.Function, ...] = (
                            subctx.env.schema.get_functions(
                                qualname,
                                module_aliases=subctx.modaliases)
                        )
                    except errors.InvalidReferenceError:
                        pass
                    else:
                        wanted_items.extend(
                            fn.get_name(subctx.env.schema) for fn in funcs)
                        resolved_as_function = True

                if not resolved_as_function:
                    found_obj = schemactx.get_schema_object(
                        objref,
                        item_type=itemtype,
                        ctx=subctx,
                    )
                    wanted_items.append(found_obj.get_name(subctx.env.schema))

            verbose = ql.options.get_flag('VERBOSE')
            language = ql.language

            if language is qltypes.DescribeLanguage.DDL:
                render = s_ddl.ddl_text_from_schema
            elif language is qltypes.DescribeLanguage.SDL:
                render = s_ddl.sdl_text_from_schema
            elif language is qltypes.DescribeLanguage.TEXT:
                render = s_ddl.descriptive_text_from_schema
                if not verbose.val:
                    ref_classes = [s_links.Link, s_lprops.Property]
            else:
                raise errors.InternalServerError(
                    f'cannot handle describe language {ql.language}'
                )

            text = render(
                ctx.env.schema,
                included_modules=wanted_modules,
                included_items=wanted_items,
                included_ref_classes=ref_classes,
                include_module_ddl=False,
                include_std_ddl=True,
            )
        elif ql.language is qltypes.DescribeLanguage.DDL:
            # DESCRIBE SCHEMA
            text = s_ddl.ddl_text_from_schema(ctx.env.schema)
        else:
            raise errors.QueryError(
                f'cannot describe full schema as {ql.language}')

        str_typeref = typegen.type_to_typeref(
            ctx.env.get_track_schema_type('std::str'),
            env=ctx.env,
        )

        result_stmt.result = setgen.ensure_set(
            irast.StringConstant(value=text, typeref=str_typeref),
            ctx=subctx,
        )

        compiled = fini_stmt(result_stmt, ql, ctx=subctx, parent_ctx=ctx)

    return compiled
예제 #16
0
    def _cmd_tree_from_ast(
        cls,
        schema: s_schema.Schema,
        astnode: qlast.DDLOperation,
        context: sd.CommandContext,
    ) -> sd.Command:
        """Build the command tree for a scalar type and handle enums.

        If the parent implementation returns a command group, the first
        subcommand that is an instance of ``cls`` is used as the create
        command.  For ``CREATE SCALAR TYPE`` nodes whose bases include an
        anonymous enum, the enum definition is validated (single base,
        no default, unique elements) and rewritten so the create command
        carries ``enum_values`` and has ``std::anyenum`` as its base.
        FINAL is rejected for non-enum scalar types.

        Returns the command tree produced by the parent implementation.
        """
        cmd = super()._cmd_tree_from_ast(schema, astnode, context)

        create_cmd: sd.Command = cmd
        if isinstance(cmd, sd.CommandGroup):
            for subcmd in cmd.get_subcommands():
                if isinstance(subcmd, cls):
                    create_cmd = subcmd
                    break
            else:
                raise errors.InternalServerError(
                    'scalar alias definition did not return CreateScalarType')

        if not isinstance(astnode, qlast.CreateScalarType):
            return cmd

        base_shells = [
            s_utils.ast_to_type_shell(
                base_ast,
                metaclass=ScalarType,
                modaliases=context.modaliases,
                schema=schema,
            )
            for base_ast in (astnode.bases or [])
        ]
        is_enum = any(
            isinstance(shell, AnonymousEnumTypeShell) for shell in base_shells)

        # FINAL is unsupported, but old dumps and migrations put it on
        # enum CREATE SCALAR TYPEs, so it is tolerated only there.
        if not is_enum and astnode.final:
            raise errors.UnsupportedFeatureError(
                'FINAL is not supported',
                context=astnode.context,
            )

        if not is_enum:
            return cmd

        # From here on this is an enumerated type.
        if len(base_shells) > 1:
            assert isinstance(astnode, qlast.BasesMixin)
            raise errors.SchemaError(
                'invalid scalar type definition, enumeration must be'
                ' the only supertype specified',
                context=astnode.bases[0].context,
            )

        if create_cmd.has_attribute_value('default'):
            raise errors.UnsupportedFeatureError(
                'enumerated types do not support defaults',
                context=create_cmd.get_attribute_source_context('default'),
            )

        enum_shell = base_shells[0]
        assert isinstance(enum_shell, AnonymousEnumTypeShell)
        elements = enum_shell.elements
        if len(set(elements)) != len(elements):
            raise errors.SchemaDefinitionError(
                'enums cannot contain duplicate values',
                context=astnode.bases[0].context,
            )

        create_cmd.set_attribute_value('enum_values', elements)

        # Replace the anonymous enum base with std::anyenum.
        anyenum_shell = s_utils.ast_objref_to_object_shell(
            s_utils.name_to_ast_ref(s_name.QualName('std', 'anyenum')),
            schema=schema,
            metaclass=ScalarType,
            modaliases={},
        )
        create_cmd.set_attribute_value(
            'bases',
            so.ObjectCollectionShell(
                [anyenum_shell],
                collection_type=so.ObjectList,
            ),
        )

        return cmd
예제 #17
0
def range_for_ptrref(
    ptrref: irast.BasePointerRef,
    *,
    dml_source: Optional[irast.MutatingStmt] = None,
    for_mutation: bool = False,
    only_self: bool = False,
    ctx: context.CompilerContextLevel,
) -> pgast.PathRangeVar:
    """Return a Range subclass corresponding to a given ptr step.

    The return value may potentially be a UNION of all tables
    corresponding to a set of specialized links computed from the given
    `ptrref` taking source inheritance into account.

    :param ptrref:
        The pointer to build a range for.  If it has union components,
        one subquery is built per component.
    :param dml_source:
        Forwarded to ``get_ptr_rel_overlays`` to select the overlays
        that apply.
    :param for_mutation:
        Forwarded to ``table_from_ptrref``; when true, relation
        overlays are not merged into the range.
    :param only_self:
        When true, a union pointer with more than one component is
        rejected.
    :raises errors.InternalServerError:
        If ``only_self`` is set and ``ptrref`` is a union of several
        pointers.
    """

    output_cols = ('source', 'target')

    set_ops = []

    if ptrref.union_components:
        refs = ptrref.union_components
        if only_self and len(refs) > 1:
            raise errors.InternalServerError('unexpected union link')
    else:
        # The overlays for ``ptrref`` itself are looked up in the loop
        # below, since it is the only element of ``refs``.
        refs = {ptrref}
        assert isinstance(ptrref, irast.PointerRef), \
            "expected regular PointerRef"

    for src_ptrref in refs:
        assert isinstance(src_ptrref, irast.PointerRef), \
            "expected regular PointerRef"

        # Most references to inline links are dispatched to a separate
        # code path (_new_inline_pointer_rvar) by new_pointer_rvar,
        # but when we have union pointers, some might be inline.  We
        # always use the link table if it exists (because this range
        # needs to contain any link properties, for one reason.)
        ptr_info = pg_types.get_ptrref_storage_info(
            src_ptrref,
            resolve_type=False,
            link_bias=True,
        )
        if not ptr_info:
            assert ptrref.union_components
            ptr_info = pg_types.get_ptrref_storage_info(
                src_ptrref,
                resolve_type=False,
                link_bias=False,
            )

        # Link tables key rows by 'source'; inline pointers live in the
        # object table, keyed by 'id'.
        cols = [
            'source' if ptr_info.table_type == 'link' else 'id',
            ptr_info.column_name,
        ]

        table = table_from_ptrref(
            src_ptrref,
            ptr_info,
            include_descendants=not ptrref.union_is_concrete,
            for_mutation=for_mutation,
            ctx=ctx,
        )

        qry = pgast.SelectStmt()
        qry.from_clause.append(table)

        # Make sure all property references are pulled up properly
        for colname, output_colname in zip(cols, output_cols):
            selexpr = pgast.ColumnRef(name=[table.alias.aliasname, colname])
            qry.target_list.append(
                pgast.ResTarget(val=selexpr, name=output_colname))

        set_ops.append(('union', qry))

        overlays = get_ptr_rel_overlays(src_ptrref,
                                        dml_source=dml_source,
                                        ctx=ctx)
        if overlays and not for_mutation:
            for op, cte in overlays:
                rvar = pgast.RelRangeVar(
                    relation=cte,
                    alias=pgast.Alias(aliasname=ctx.env.aliases.get(cte.name)))

                qry = pgast.SelectStmt(
                    target_list=[
                        pgast.ResTarget(val=pgast.ColumnRef(name=[col]))
                        for col in cols
                    ],
                    from_clause=[rvar],
                )
                set_ops.append((op, qry))

    return range_from_queryset(set_ops, ptrref.shortname, ctx=ctx)
예제 #18
0
def compile_DescribeStmt(ql: qlast.DescribeStmt, *,
                         ctx: context.ContextLevel) -> irast.Set:
    """Compile a DESCRIBE statement into a set yielding its text.

    Supported forms:

    * ``DescribeGlobal.Schema`` -- only as DDL; rendered directly.
    * ``DescribeGlobal.SystemConfig`` / ``DescribeGlobal.Roles`` -- only
      as DDL; compiled into calls to
      ``cfg::_describe_system_config_as_ddl`` /
      ``sys::_describe_roles_as_ddl``.
    * A specific module or schema object -- rendered as DDL, SDL or
      TEXT.  The name is resolved against the current module aliases
      and, when the default module is not ``std``, also against ``std``,
      so that builtins masked by an object of the same name in the
      default module are appended as a commented-out section.

    Raises:
        errors.QueryError: unsupported language for the global forms.
        errors.InternalServerError: unknown describe language.
        errors.InvalidReferenceError: the name did not resolve to
            anything in any searched namespace.
    """
    with ctx.subquery() as ictx:
        stmt = irast.SelectStmt()
        init_stmt(stmt, ql, ctx=ictx, parent_ctx=ctx)

        if ql.object == qlast.DescribeGlobal.Schema:
            if ql.language is qltypes.DescribeLanguage.DDL:
                # DESCRIBE SCHEMA
                text = s_ddl.ddl_text_from_schema(ctx.env.schema, )
            else:
                raise errors.QueryError(
                    f'cannot describe full schema as {ql.language}')

            ct = typegen.type_to_typeref(
                ctx.env.get_track_schema_type('std::str'),
                env=ctx.env,
            )

            stmt.result = setgen.ensure_set(
                irast.StringConstant(value=text, typeref=ct),
                ctx=ictx,
            )

        elif ql.object == qlast.DescribeGlobal.SystemConfig:
            if ql.language is qltypes.DescribeLanguage.DDL:
                function_call = dispatch.compile(qlast.FunctionCall(
                    func=('cfg', '_describe_system_config_as_ddl'), ),
                                                 ctx=ictx)
                assert isinstance(function_call, irast.Set), function_call
                stmt.result = function_call
            else:
                raise errors.QueryError(
                    f'cannot describe config as {ql.language}')
        elif ql.object == qlast.DescribeGlobal.Roles:
            if ql.language is qltypes.DescribeLanguage.DDL:
                function_call = dispatch.compile(qlast.FunctionCall(
                    func=('sys', '_describe_roles_as_ddl'), ),
                                                 ctx=ictx)
                assert isinstance(function_call, irast.Set), function_call
                stmt.result = function_call
            else:
                raise errors.QueryError(
                    f'cannot describe roles as {ql.language}')
        else:
            assert isinstance(ql.object, qlast.ObjectRef), ql.object
            modules = []
            items: DefaultDict[str, List[str]] = defaultdict(list)
            referenced_classes: List[s_obj.ObjectMeta] = []

            objref = ql.object
            itemclass = objref.itemclass

            if itemclass is qltypes.SchemaObjectClass.MODULE:
                modules.append(objref.name)
            else:
                itemtype: Optional[Type[s_obj.Object]] = None

                name: str
                if objref.module:
                    name = s_name.Name(module=objref.module, name=objref.name)
                else:
                    name = objref.name

                if itemclass is not None:
                    if itemclass is qltypes.SchemaObjectClass.ALIAS:
                        # Look for underlying derived type.
                        itemtype = s_types.Type
                    else:
                        itemtype = (
                            s_obj.ObjectMeta.get_schema_metaclass_for_ql_class(
                                itemclass))

                # Whether we are looking specifically for functions.
                # Functions are handled separately because they allow
                # multiple matches (overloads) for the same name.
                is_function = (itemclass is
                               qltypes.SchemaObjectClass.FUNCTION)

                last_exc = None
                # Search in the current namespace AND in std. We do
                # this to avoid masking a `std` object/function by one
                # in a default module.
                search_ns = [ictx.modaliases]
                # Only check 'std' separately if the current
                # modaliases don't already include it.
                if ictx.modaliases.get(None, 'std') != 'std':
                    search_ns.append({None: 'std'})

                for aliases in search_ns:
                    # Use the specific modaliases instead of the
                    # context ones.
                    with ictx.subquery() as newctx:
                        newctx.modaliases = aliases
                        # Default module name of this namespace; a missing
                        # None entry means 'std', as in the check above.
                        modname = aliases.get(None, 'std')

                        # We need to check functions if we're looking for
                        # them specifically or if this is a broad search.
                        if itemclass is None or is_function:
                            try:
                                funcs: Tuple[s_func.Function, ...] = (
                                    newctx.env.schema.get_functions(
                                        name, module_aliases=aliases))
                            except errors.InvalidReferenceError as exc:
                                # Record the exception so that a search
                                # for a missing function reports an error
                                # instead of producing empty output.
                                last_exc = exc
                            else:
                                for func in funcs:
                                    items[f'function_{modname}'].append(
                                        func.get_name(newctx.env.schema))

                        # Also find an object matching the name as long as
                        # it's not a function we're looking for specifically.
                        if not is_function:
                            try:
                                if itemclass is not \
                                        qltypes.SchemaObjectClass.ALIAS:
                                    condition = None
                                    label = None
                                else:
                                    condition = (lambda obj: obj.
                                                 get_alias_is_persistent(
                                                     ctx.env.schema))
                                    label = 'alias'
                                obj = schemactx.get_schema_object(
                                    objref,
                                    item_type=itemtype,
                                    condition=condition,
                                    label=label,
                                    ctx=newctx,
                                )
                                items[f'other_{modname}'].append(
                                    obj.get_name(newctx.env.schema))
                            except errors.InvalidReferenceError as exc:
                                # Record the exception to be possibly
                                # raised if no matches are found
                                last_exc = exc

                # If we already have some results, suppress the exception,
                # otherwise raise the recorded exception.
                if not items and last_exc:
                    raise last_exc

            verbose = ql.options.get_flag('VERBOSE')

            if ql.language is qltypes.DescribeLanguage.DDL:
                method = s_ddl.ddl_text_from_schema
            elif ql.language is qltypes.DescribeLanguage.SDL:
                method = s_ddl.sdl_text_from_schema
            elif ql.language is qltypes.DescribeLanguage.TEXT:
                method = s_ddl.descriptive_text_from_schema
                if not verbose.val:
                    referenced_classes = [s_links.Link, s_lprops.Property]
            else:
                raise errors.InternalServerError(
                    f'cannot handle describe language {ql.language}')

            # Based on the items found generate main text and a
            # potential comment about masked items.
            defmod = ictx.modaliases.get(None, 'std')
            default_items = []
            masked_items = set()
            for objtype in ['function', 'other']:
                defkey = f'{objtype}_{defmod}'
                mskkey = f'{objtype}_std'

                default_items += items.get(defkey, [])
                if defkey in items and mskkey in items:
                    # We have a match in default module and some masked.
                    masked_items.update(items.get(mskkey, []))
                else:
                    default_items += items.get(mskkey, [])

            # Throw out anything in the masked set that's already in
            # the default.
            masked_items.difference_update(default_items)

            text = method(
                ctx.env.schema,
                included_modules=modules,
                included_items=default_items,
                included_ref_classes=referenced_classes,
                include_module_ddl=False,
                include_std_ddl=True,
            )
            if masked_items:
                text += ('\n\n'
                         '# The following builtins are masked by the above:'
                         '\n\n')
                masked = method(
                    ctx.env.schema,
                    included_modules=modules,
                    included_items=masked_items,
                    included_ref_classes=referenced_classes,
                    include_module_ddl=False,
                    include_std_ddl=True,
                )
                masked = textwrap.indent(masked, '# ')
                text += masked

            ct = typegen.type_to_typeref(
                ctx.env.get_track_schema_type('std::str'),
                env=ctx.env,
            )

            stmt.result = setgen.ensure_set(
                irast.StringConstant(value=text, typeref=ct),
                ctx=ictx,
            )

        result = fini_stmt(stmt, ql, ctx=ictx, parent_ctx=ctx)

    return result
예제 #19
0
def _process_view(*,
                  stype: s_objtypes.ObjectType,
                  path_id: irast.PathId,
                  path_id_namespace: Optional[irast.WeakNamespace] = None,
                  elements: List[qlast.ShapeElement],
                  view_rptr: Optional[context.ViewRPtr] = None,
                  view_name: Optional[sn.SchemaName] = None,
                  is_insert: bool = False,
                  is_update: bool = False,
                  ctx: context.ContextLevel) -> s_objtypes.ObjectType:
    """Derive a view of *stype* whose shape is given by *elements*.

    Each shape element is normalized into a pointer on the derived view
    type.  For INSERT views, every non-computable pointer of *stype* that
    was not set explicitly and has a default gets an implicit shape
    element; a required pointer that is neither set nor defaulted is an
    error, except for properties whose target is a ``std::sequence``
    subtype.  When the shape is exposed in the output or the view is a
    mutation, the resulting pointers are recorded in
    ``ctx.env.view_shapes``.

    Args:
        stype: The object type the view is derived from.
        path_id: Path id passed through to pointer normalization.
        path_id_namespace: Optional namespace passed through to pointer
            normalization.
        elements: Shape elements of the view.
        view_rptr: Pointer through which the view is reached, if any.
        view_name: Explicit name for the derived view.  In schema view
            mode a readable name is synthesized from *view_rptr* when
            this is ``None``.
        is_insert: Whether the view is the subject of an INSERT.
        is_update: Whether the view is the subject of an UPDATE.
        ctx: Current compiler context; ``ctx.env.schema`` may be updated.

    Returns:
        The derived view object type.

    Raises:
        errors.InternalServerError: in schema view mode, if *view_rptr*
            carries neither a pointer name nor a pointer class.
        errors.MissingRequiredError: for an INSERT that omits a required
            pointer that has no default.
    """

    if (view_name is None and ctx.env.schema_view_mode
            and view_rptr is not None):
        # Make sure persistent schema expression aliases have properly formed
        # names as opposed to the usual mangled form of the ephemeral
        # aliases.  This is needed for introspection readability, as well
        # as helps in maintaining proper type names for schema
        # representations that require alphanumeric names, such as
        # GraphQL.
        #
        # We use the name of the source together with the name
        # of the inbound link to form the name, so in e.g.
        #    CREATE ALIAS V := (SELECT Foo { bar: { baz: { ... } })
        # The name of the innermost alias would be "__V__bar__baz".
        source_name = view_rptr.source.get_name(ctx.env.schema).name
        if not source_name.startswith('__'):
            source_name = f'__{source_name}'
        if view_rptr.ptrcls_name is not None:
            ptr_name = view_rptr.ptrcls_name.name
        elif view_rptr.ptrcls is not None:
            ptr_name = view_rptr.ptrcls.get_shortname(ctx.env.schema).name
        else:
            raise errors.InternalServerError(
                '_process_view in schema mode received view_rptr with '
                'neither ptrcls_name nor ptrcls')

        name = f'{source_name}__{ptr_name}'
        view_name = sn.Name(
            module=ctx.derived_target_module or '__derived__',
            name=name,
        )

    view_scls = schemactx.derive_view(stype,
                                      is_insert=is_insert,
                                      is_update=is_update,
                                      derived_name=view_name,
                                      ctx=ctx)
    assert isinstance(view_scls, s_objtypes.ObjectType)
    is_mutation = is_insert or is_update
    # Shapes are only recorded when they are visible in the output or
    # drive a mutation.
    is_defining_shape = ctx.expr_exposed or is_mutation

    if view_rptr is not None and view_rptr.ptrcls is None:
        derive_ptrcls(view_rptr,
                      target_scls=view_scls,
                      transparent=True,
                      ctx=ctx)

    pointers = []

    for shape_el in elements:
        with ctx.newscope(fenced=True) as scopectx:
            pointers.append(
                _normalize_view_ptr_expr(shape_el,
                                         view_scls,
                                         path_id=path_id,
                                         path_id_namespace=path_id_namespace,
                                         is_insert=is_insert,
                                         is_update=is_update,
                                         view_rptr=view_rptr,
                                         ctx=scopectx))

    if is_insert:
        assert isinstance(stype, s_objtypes.ObjectType)
        explicit_ptrs = {
            ptrcls.get_shortname(ctx.env.schema).name
            for ptrcls in pointers
        }

        scls_pointers = stype.get_pointers(ctx.env.schema)
        for pn, ptrcls in scls_pointers.items(ctx.env.schema):
            if (pn in explicit_ptrs
                    or ptrcls.is_pure_computable(ctx.env.schema)):
                continue

            if not ptrcls.get_default(ctx.env.schema):
                if ptrcls.get_required(ctx.env.schema):
                    if ptrcls.is_property(ctx.env.schema):
                        # If the target is a sequence, there's no need
                        # for an explicit value.
                        ptrcls_target = ptrcls.get_target(ctx.env.schema)
                        assert ptrcls_target is not None
                        if ptrcls_target.issubclass(
                                ctx.env.schema,
                                ctx.env.schema.get('std::sequence')):
                            continue

                        what = 'property'
                    else:
                        what = 'link'
                    raise errors.MissingRequiredError(
                        f'missing value for required {what} '
                        f'{stype.get_displayname(ctx.env.schema)}.'
                        f'{ptrcls.get_displayname(ctx.env.schema)}')
                else:
                    continue

            # The pointer has a default: add an implicit shape element
            # for it so that it is part of the INSERT shape.
            ptrcls_sn = ptrcls.get_shortname(ctx.env.schema)
            default_ql = qlast.ShapeElement(expr=qlast.Path(steps=[
                qlast.Ptr(ptr=qlast.ObjectRef(name=ptrcls_sn.name,
                                              module=ptrcls_sn.module))
            ]))

            with ctx.newscope(fenced=True) as scopectx:
                pointers.append(
                    _normalize_view_ptr_expr(
                        default_ql,
                        view_scls,
                        path_id=path_id,
                        path_id_namespace=path_id_namespace,
                        is_insert=is_insert,
                        is_update=is_update,
                        view_rptr=view_rptr,
                        ctx=scopectx))

    for ptrcls in pointers:
        source: Union[s_types.Type, s_pointers.PointerLike]

        # Link properties belong to the inbound pointer, everything else
        # to the view type itself.
        if ptrcls.is_link_property(ctx.env.schema):
            assert view_rptr is not None and view_rptr.ptrcls is not None
            source = view_rptr.ptrcls
        else:
            source = view_scls

        if is_defining_shape:
            ctx.env.view_shapes[source].append(ptrcls)

    if (view_rptr is not None and view_rptr.ptrcls is not None
            and view_scls is not stype):
        ctx.env.schema = view_scls.set_field_value(ctx.env.schema, 'rptr',
                                                   view_rptr.ptrcls)

    return view_scls
예제 #20
0
파일: scalars.py 프로젝트: backwardn/edgedb
    def _cmd_tree_from_ast(
        cls,
        schema: s_schema.Schema,
        astnode: qlast.DDLOperation,
        context: sd.CommandContext,
    ) -> sd.Command:
        """Build the command tree for a scalar type DDL node.

        The tree produced by the parent implementation is returned
        unchanged in shape.  If it is a ``CommandGroup``, the first
        subcommand that is an instance of ``cls`` is used as the create
        command.  For ``CREATE SCALAR TYPE`` with an anonymous enum base,
        the enum base must be the only base and no default may be set.
        In that case the create command gets ``enum_values`` and
        ``is_final=True``, and its bases are replaced with ``std::anyenum``.

        Raises:
            errors.InternalServerError: if a command group contains no
                subcommand of type ``cls``.
            errors.SchemaError: if an enumeration is combined with other
                bases.
            errors.UnsupportedFeatureError: if an enumerated type has a
                default.
        """
        cmd = super()._cmd_tree_from_ast(schema, astnode, context)

        if isinstance(cmd, sd.CommandGroup):
            for subcmd in cmd.get_subcommands():
                if isinstance(subcmd, cls):
                    create_cmd: sd.Command = subcmd
                    break
            else:
                raise errors.InternalServerError(
                    'scalar alias definition did not return CreateScalarType'
                )
        else:
            create_cmd = cmd

        if isinstance(astnode, qlast.CreateScalarType):
            bases = [
                s_utils.ast_to_type_shell(
                    b,
                    modaliases=context.modaliases,
                    schema=schema,
                )
                for b in astnode.bases
            ]

            if any(isinstance(br, AnonymousEnumTypeShell) for br in bases):
                # This is an enumerated type.
                if len(bases) > 1:
                    assert isinstance(astnode, qlast.BasesMixin)
                    raise errors.SchemaError(
                        'invalid scalar type definition, enumeration must be'
                        ' the only supertype specified',
                        context=astnode.bases[0].context,
                    )
                deflt = create_cmd.get_attribute_set_cmd('default')
                if deflt is not None:
                    raise errors.UnsupportedFeatureError(
                        'enumerated types do not support defaults',
                        context=deflt.source_context,
                    )

                shell = bases[0]
                assert isinstance(shell, AnonymousEnumTypeShell)
                create_cmd.set_attribute_value('enum_values', shell.elements)
                create_cmd.set_attribute_value('is_final', True)
                # The anonymous enum shell is replaced by the std::anyenum
                # base; the enum members live in 'enum_values'.
                create_cmd.set_attribute_value('bases', [
                    s_utils.ast_objref_to_object_shell(
                        s_utils.name_to_ast_ref(
                            s_name.Name('std::anyenum'),
                        ),
                        schema=schema,
                        metaclass=ScalarType,
                        modaliases={},
                    )
                ])

        return cmd
예제 #21
0
파일: viewgen.py 프로젝트: sbdchd/edgedb
def _process_view(
    *,
    stype: s_objtypes.ObjectType,
    path_id: irast.PathId,
    path_id_namespace: Optional[irast.WeakNamespace] = None,
    elements: List[qlast.ShapeElement],
    view_rptr: Optional[context.ViewRPtr] = None,
    view_name: Optional[sn.QualName] = None,
    is_insert: bool = False,
    is_update: bool = False,
    is_delete: bool = False,
    parser_context: pctx.ParserContext,
    ctx: context.ContextLevel,
) -> s_objtypes.ObjectType:
    """Derive a view of *stype* whose shape is given by *elements*.

    Each shape element is normalized into a pointer on the derived view
    type.  Specifying the same pointer twice is an error.

    For INSERT views, every non-computable pointer of *stype* that was not
    set explicitly and has a default gets an implicit shape element whose
    computed expression is the (detached) default expression.  A required
    pointer that is neither set nor defaulted is an error, except for
    ``__type__`` and properties whose target is a ``std::sequence``
    subtype.

    For non-INSERT views of types in the ``schema`` module, when query
    rewrites are enabled, pointers with a schema reflection default are
    registered as ``.ptr ?? <default>`` computables, without adding them
    to the shape.

    When the shape is exposed in the output or the view is an INSERT or
    UPDATE, the pointers are recorded in ``ctx.env.view_shapes`` together
    with their shape operation.

    Args:
        stype: The object type the view is derived from.
        path_id: Path id passed through to pointer normalization.
        path_id_namespace: Optional namespace passed through to pointer
            normalization.
        elements: Shape elements of the view.
        view_rptr: Pointer through which the view is reached, if any.
        view_name: Explicit name for the derived view.  In schema view
            mode a readable name is synthesized from *view_rptr* when
            this is ``None``.
        is_insert: Whether the view is the subject of an INSERT.
        is_update: Whether the view is the subject of an UPDATE.
        is_delete: Whether the view is the subject of a DELETE.  It is
            only forwarded to ``derive_view`` here.
        parser_context: Source context; not used in this body.
        ctx: Current compiler context; ``ctx.env.schema`` may be updated.

    Returns:
        The derived view object type.

    Raises:
        errors.InternalServerError: in schema view mode, if *view_rptr*
            carries neither a pointer name nor a pointer class.
        errors.QueryError: if a pointer is defined more than once in the
            shape.
        errors.MissingRequiredError: for an INSERT that omits a required
            pointer that has no default.
    """

    if (view_name is None and ctx.env.options.schema_view_mode
            and view_rptr is not None):
        # Make sure persistent schema expression aliases have properly formed
        # names as opposed to the usual mangled form of the ephemeral
        # aliases.  This is needed for introspection readability, as well
        # as helps in maintaining proper type names for schema
        # representations that require alphanumeric names, such as
        # GraphQL.
        #
        # We use the name of the source together with the name
        # of the inbound link to form the name, so in e.g.
        #    CREATE ALIAS V := (SELECT Foo { bar: { baz: { ... } })
        # The name of the innermost alias would be "__V__bar__baz".
        source_name = view_rptr.source.get_name(ctx.env.schema).name
        if not source_name.startswith('__'):
            source_name = f'__{source_name}'
        if view_rptr.ptrcls_name is not None:
            ptr_name = view_rptr.ptrcls_name.name
        elif view_rptr.ptrcls is not None:
            ptr_name = view_rptr.ptrcls.get_shortname(ctx.env.schema).name
        else:
            raise errors.InternalServerError(
                '_process_view in schema mode received view_rptr with '
                'neither ptrcls_name nor ptrcls'
            )

        name = f'{source_name}__{ptr_name}'
        view_name = sn.QualName(
            module=ctx.derived_target_module or '__derived__',
            name=name,
        )

    view_scls = schemactx.derive_view(
        stype,
        is_insert=is_insert,
        is_update=is_update,
        is_delete=is_delete,
        derived_name=view_name,
        ctx=ctx,
    )
    assert isinstance(view_scls, s_objtypes.ObjectType), view_scls
    is_mutation = is_insert or is_update
    # Shapes are only recorded when they are visible in the output or
    # drive an INSERT/UPDATE.
    is_defining_shape = ctx.expr_exposed or is_mutation

    if view_rptr is not None and view_rptr.ptrcls is None:
        derive_ptrcls(
            view_rptr, target_scls=view_scls,
            transparent=True, ctx=ctx)

    pointers = []

    for shape_el in elements:
        with ctx.newscope(fenced=True) as scopectx:
            pointer = _normalize_view_ptr_expr(
                shape_el, view_scls, path_id=path_id,
                path_id_namespace=path_id_namespace,
                is_insert=is_insert, is_update=is_update,
                view_rptr=view_rptr,
                ctx=scopectx)

            if pointer in pointers:
                schema = ctx.env.schema
                vnp = pointer.get_verbosename(schema, with_parent=True)

                raise errors.QueryError(
                    f'duplicate definition of {vnp}',
                    context=shape_el.context)

            pointers.append(pointer)

    if is_insert:
        explicit_ptrs = {
            ptrcls.get_local_name(ctx.env.schema)
            for ptrcls in pointers
        }
        scls_pointers = stype.get_pointers(ctx.env.schema)
        for pn, ptrcls in scls_pointers.items(ctx.env.schema):
            if (pn in explicit_ptrs or
                    ptrcls.is_pure_computable(ctx.env.schema)):
                continue

            default_expr = ptrcls.get_default(ctx.env.schema)
            if not default_expr:
                # `__type__` is exempt from the missing-value check.
                if (
                    ptrcls.get_required(ctx.env.schema)
                    and pn != sn.UnqualName('__type__')
                ):
                    if ptrcls.is_property(ctx.env.schema):
                        # If the target is a sequence, there's no need
                        # for an explicit value.
                        ptrcls_target = ptrcls.get_target(ctx.env.schema)
                        assert ptrcls_target is not None
                        if ptrcls_target.issubclass(
                                ctx.env.schema,
                                ctx.env.schema.get(
                                    'std::sequence',
                                    type=s_objects.SubclassableObject)):
                            continue
                    vn = ptrcls.get_verbosename(
                        ctx.env.schema, with_parent=True)
                    raise errors.MissingRequiredError(
                        f'missing value for required {vn}')
                else:
                    continue

            # The pointer has a default: add an implicit shape element
            # computing it from the (detached) default expression.
            ptrcls_sn = ptrcls.get_shortname(ctx.env.schema)
            default_ql = qlast.ShapeElement(
                expr=qlast.Path(
                    steps=[
                        qlast.Ptr(
                            ptr=qlast.ObjectRef(
                                name=ptrcls_sn.name,
                                module=ptrcls_sn.module,
                            ),
                        ),
                    ],
                ),
                compexpr=qlast.DetachedExpr(
                    expr=default_expr.qlast,
                ),
            )

            with ctx.newscope(fenced=True) as scopectx:
                pointers.append(
                    _normalize_view_ptr_expr(
                        default_ql,
                        view_scls,
                        path_id=path_id,
                        path_id_namespace=path_id_namespace,
                        is_insert=is_insert,
                        is_update=is_update,
                        from_default=True,
                        view_rptr=view_rptr,
                        ctx=scopectx,
                    ),
                )

    elif (
        stype.get_name(ctx.env.schema).module == 'schema'
        and ctx.env.options.apply_query_rewrites
    ):
        explicit_ptrs = {
            ptrcls.get_local_name(ctx.env.schema)
            for ptrcls in pointers
        }
        scls_pointers = stype.get_pointers(ctx.env.schema)
        for pn, ptrcls in scls_pointers.items(ctx.env.schema):
            if (
                pn in explicit_ptrs
                or ptrcls.is_pure_computable(ctx.env.schema)
            ):
                continue

            schema_deflt = ptrcls.get_schema_reflection_default(ctx.env.schema)
            if schema_deflt is None:
                continue

            with ctx.newscope(fenced=True) as scopectx:
                # Build `ptr := .ptr ?? <schema reflection default>`.
                ptr_ref = s_utils.name_to_ast_ref(pn)
                implicit_ql = qlast.ShapeElement(
                    expr=qlast.Path(steps=[qlast.Ptr(ptr=ptr_ref)]),
                    compexpr=qlast.BinOp(
                        left=qlast.Path(
                            partial=True,
                            steps=[
                                qlast.Ptr(
                                    ptr=ptr_ref,
                                    direction=(
                                        s_pointers.PointerDirection.Outbound
                                    ),
                                )
                            ],
                        ),
                        right=qlparser.parse_fragment(schema_deflt),
                        op='??',
                    ),
                )

                # Note: we only need to record the schema default
                # as a computable, but not include it in the type
                # shape, so we ignore the return value.
                _normalize_view_ptr_expr(
                    implicit_ql,
                    view_scls,
                    path_id=path_id,
                    path_id_namespace=path_id_namespace,
                    is_insert=is_insert,
                    is_update=is_update,
                    view_rptr=view_rptr,
                    ctx=scopectx,
                )

    for ptrcls in pointers:
        source: Union[s_types.Type, s_pointers.PointerLike]

        # Link properties belong to the inbound pointer, everything else
        # to the view type itself.
        if ptrcls.is_link_property(ctx.env.schema):
            assert view_rptr is not None and view_rptr.ptrcls is not None
            source = view_rptr.ptrcls
        else:
            source = view_scls

        if is_defining_shape:
            cinfo = ctx.source_map.get(ptrcls)
            if cinfo is not None:
                shape_op = cinfo.shape_op
            else:
                shape_op = qlast.ShapeOp.ASSIGN

            ctx.env.view_shapes[source].append((ptrcls, shape_op))

    if (view_rptr is not None and view_rptr.ptrcls is not None and
            view_scls != stype):
        ctx.env.schema = view_scls.set_field_value(
            ctx.env.schema, 'rptr', view_rptr.ptrcls)

    return view_scls
예제 #22
0
def computable_ptr_set(
    rptr: irast.Pointer,
    *,
    unnest_fence: bool = False,
    same_computable_scope: bool = False,
    srcctx: Optional[parsing.ParserContext] = None,
    ctx: context.ContextLevel,
) -> irast.Set:
    """Return ir.Set for a pointer defined as a computable.

    If the pointer is registered in ``ctx.source_map``, the recorded
    expression is compiled in a context built by ``_get_computable_ctx``.

    Otherwise the expression comes from the schema and is compiled in a
    detached context.  It is the pointer's own expression or, when query
    rewrites are enabled and the pointer has no expression,
    ``.ptr ?? <schema reflection default>``.  Such schema-level
    expressions with a non-object target are wrapped in a cast to the
    target type.

    The compiled set is re-labelled with the pointer's path id and
    installed as ``rptr.target``.

    Args:
        rptr: The pointer whose target is computed; its ``target`` is
            replaced by the result.
        unnest_fence: Recorded in the statement metadata of the compiled
            expression when it is a statement.
        same_computable_scope: Forwarded to ``_get_computable_ctx``.
        srcctx: Source context attached to the resulting set.
        ctx: Current compiler context.

    Returns:
        The compiled set for the computable.

    Raises:
        errors.InternalServerError: if the pointer is not registered in
            ``ctx.source_map`` and has neither a schema expression nor an
            applicable schema reflection default.
    """
    ptrcls = typegen.ptrcls_from_ptrref(rptr.ptrref, ctx=ctx)
    source_set = rptr.source
    source_scls = get_set_type(source_set, ctx=ctx)
    # process_view() may generate computable pointer expressions
    # in the form "self.linkname".  To prevent infinite recursion,
    # self must resolve to the parent type of the view NOT the view
    # type itself.  Similarly, when resolving computable link properties
    # make sure that we use the parent of derived ptrcls.
    if source_scls.is_view(ctx.env.schema):
        source_set_stype = source_scls.peel_view(ctx.env.schema)
        source_set = new_set_from_set(source_set,
                                      stype=source_set_stype,
                                      preserve_scope_ns=True,
                                      ctx=ctx)
        source_set.shape = []
        if source_set.rptr is not None:
            source_rptrref = source_set.rptr.ptrref
            if source_rptrref.base_ptr is not None:
                source_rptrref = source_rptrref.base_ptr
            source_set.rptr = irast.Pointer(
                source=source_set.rptr.source,
                target=source_set,
                ptrref=source_rptrref,
                direction=source_set.rptr.direction,
            )

    qlctx: Optional[context.ContextLevel]
    inner_source_path_id: Optional[irast.PathId]

    # Only the lookup decides which branch applies; a single .get() keeps
    # the rest of the work out of any KeyError handler.
    comp_info = ctx.source_map.get(ptrcls)
    if comp_info is not None:
        qlexpr = comp_info.qlexpr
        assert isinstance(comp_info.context, context.ContextLevel)
        qlctx = comp_info.context
        inner_source_path_id = comp_info.path_id
        path_id_ns = comp_info.path_id_ns
    else:
        comp_expr = ptrcls.get_expr(ctx.env.schema)
        schema_qlexpr: Optional[qlast.Expr] = None
        if comp_expr is None and ctx.env.options.apply_query_rewrites:
            schema_deflt = ptrcls.get_schema_reflection_default(ctx.env.schema)
            if schema_deflt is not None:
                assert isinstance(ptrcls, s_pointers.Pointer)
                ptrcls_n = ptrcls.get_shortname(ctx.env.schema).name
                # Build `__source__.ptr ?? <schema reflection default>`.
                schema_qlexpr = qlast.BinOp(
                    left=qlast.Path(steps=[
                        qlast.Source(),
                        qlast.Ptr(
                            ptr=qlast.ObjectRef(name=ptrcls_n),
                            direction=s_pointers.PointerDirection.Outbound,
                            type=('property' if ptrcls.is_link_property(
                                ctx.env.schema) else None))
                    ]),
                    right=qlparser.parse_fragment(schema_deflt),
                    op='??',
                )

        if schema_qlexpr is None:
            if comp_expr is None:
                ptrcls_sn = ptrcls.get_shortname(ctx.env.schema)
                raise errors.InternalServerError(
                    f'{ptrcls_sn!r} is not a computable pointer')

            comp_qlexpr = qlparser.parse(comp_expr.text)
            assert isinstance(comp_qlexpr, qlast.Expr), 'expected qlast.Expr'
            schema_qlexpr = comp_qlexpr

        # NOTE: Validation of the expression type is not the concern
        # of this function. For any non-object pointer target type,
        # the default expression must be assignment-cast into that
        # type.
        target_scls = ptrcls.get_target(ctx.env.schema)
        assert target_scls is not None
        if not target_scls.is_object_type():
            schema_qlexpr = qlast.TypeCast(
                type=typegen.type_to_ql_typeref(target_scls, ctx=ctx),
                expr=schema_qlexpr,
            )
        qlexpr = astutils.ensure_qlstmt(schema_qlexpr)
        qlctx = None
        inner_source_path_id = None
        path_id_ns = None

    newctx: Callable[[], ContextManager[context.ContextLevel]]

    if qlctx is None:
        # Schema-level computable, completely detached context
        newctx = ctx.detached
    else:
        newctx = _get_computable_ctx(rptr=rptr,
                                     source=source_set,
                                     source_scls=source_scls,
                                     inner_source_path_id=inner_source_path_id,
                                     path_id_ns=path_id_ns,
                                     same_scope=same_computable_scope,
                                     qlctx=qlctx,
                                     ctx=ctx)

    # Link properties hang off the pointer path; regular pointers off
    # the source path of the target.
    if ptrcls.is_link_property(ctx.env.schema):
        source_path_id = rptr.source.path_id.ptr_path()
    else:
        src_path = rptr.target.path_id.src_path()
        assert src_path is not None
        source_path_id = src_path

    result_path_id = pathctx.extend_path_id(
        source_path_id,
        ptrcls=ptrcls,
        ns=ctx.path_id_namespace,
        ctx=ctx,
    )

    result_stype = ptrcls.get_target(ctx.env.schema)
    base_object = ctx.env.schema.get('std::BaseObject', type=s_types.Type)
    with newctx() as subctx:
        subctx.disable_shadowing.add(ptrcls)
        if result_stype != base_object:
            subctx.view_scls = result_stype
        subctx.view_rptr = context.ViewRPtr(source_scls,
                                            ptrcls=ptrcls,
                                            rptr=rptr)  # type: ignore
        subctx.anchors[qlast.Source().name] = source_set
        subctx.empty_result_type_hint = ptrcls.get_target(ctx.env.schema)
        subctx.partial_path_prefix = source_set
        # On a mutation, make the expr_exposed. This corresponds with
        # a similar check on is_mutation in _normalize_view_ptr_expr.
        if (source_scls.get_expr_type(ctx.env.schema) !=
                s_types.ExprType.Select):
            subctx.expr_exposed = True

        if isinstance(qlexpr, qlast.Statement):
            subctx.stmt_metadata[qlexpr] = context.StatementMetadata(
                is_unnest_fence=unnest_fence,
                iterator_target=True,
            )

        comp_ir_set = ensure_set(dispatch.compile(qlexpr, ctx=subctx),
                                 ctx=subctx)

    comp_ir_set = new_set_from_set(comp_ir_set,
                                   path_id=result_path_id,
                                   rptr=rptr,
                                   context=srcctx,
                                   ctx=ctx)

    rptr.target = comp_ir_set

    return comp_ir_set
예제 #23
0
    def apply(self, spec: spec.Spec,
              storage: typing.Mapping) -> typing.Mapping:
        """Apply this configuration operation to *storage*.

        *storage* is never mutated in place; the updated mapping is
        returned.  It must provide ``set``/``delete`` methods that return
        a new mapping (presumably an immutable map type — confirm against
        callers).

        * ``CONFIG_SET`` stores the coerced value of a primitive setting.
        * ``CONFIG_RESET`` removes a primitive setting (a no-op if unset).
        * ``CONFIG_ADD`` adds a value to the set of a non-primitive setting.
        * ``CONFIG_REM`` removes a value from the set of a non-primitive
          setting.

        Any other opcode leaves *storage* unchanged.

        Raises:
            errors.InternalServerError: if the opcode does not match the
                kind (primitive / ``ConfigType``) of the setting.
            errors.ConstraintViolationError: if ``CONFIG_ADD`` would add a
                value already present in the setting.
        """
        setting = self.get_setting(spec)
        # Removal operations may be issued without a value.
        allow_missing = (self.opcode is OpCode.CONFIG_REM
                         or self.opcode is OpCode.CONFIG_RESET)

        value = self.coerce_value(setting, allow_missing=allow_missing)

        if self.opcode is OpCode.CONFIG_SET:
            if issubclass(setting.type, types.ConfigType):
                raise errors.InternalServerError(
                    f'unexpected CONFIGURE SET on a non-primitive '
                    f'configuration parameter: {self.setting_name}')

            storage = storage.set(self.setting_name, value)

        elif self.opcode is OpCode.CONFIG_RESET:
            if issubclass(setting.type, types.ConfigType):
                raise errors.InternalServerError(
                    f'unexpected CONFIGURE RESET on a non-primitive '
                    f'configuration parameter: {self.setting_name}')

            try:
                storage = storage.delete(self.setting_name)
            except KeyError:
                # Resetting a setting that was never set is a no-op.
                pass

        elif self.opcode is OpCode.CONFIG_ADD:
            if not issubclass(setting.type, types.ConfigType):
                raise errors.InternalServerError(
                    f'unexpected CONFIGURE SET += on a primitive '
                    f'configuration parameter: {self.setting_name}')

            exist_value = storage.get(self.setting_name, setting.default)
            if value in exist_value:
                # Name the fields that take part in equality, since those
                # are what make the new value a duplicate.
                props = [
                    f.name for f in dataclasses.fields(setting.type)
                    if f.compare
                ]

                if len(props) > 1:
                    violation = f' ({", ".join(props)}) violate'
                elif props:
                    violation = f'.{props[0]} violates'
                else:
                    violation = ' violates'

                raise errors.ConstraintViolationError(
                    f'{setting.type.__name__}{violation} '
                    f'exclusivity constraint')

            new_value = exist_value | {value}
            storage = storage.set(self.setting_name, new_value)

        elif self.opcode is OpCode.CONFIG_REM:
            if not issubclass(setting.type, types.ConfigType):
                raise errors.InternalServerError(
                    f'unexpected CONFIGURE SET -= on a primitive '
                    f'configuration parameter: {self.setting_name}')

            exist_value = storage.get(self.setting_name, setting.default)
            new_value = exist_value - {value}
            storage = storage.set(self.setting_name, new_value)

        return storage
예제 #24
0
파일: func.py 프로젝트: fantix/edgedb
def compile_operator(
        qlexpr: qlast.Base, op_name: str, qlargs: List[qlast.Base], *,
        ctx: context.ContextLevel) -> irast.Set:
    """Compile an operator call expression into an IR set.

    Looks up the operators named *op_name* (honoring the current module
    aliases), compiles every operand in its own fenced scope, selects the
    matching overload with ``polyres.find_callable`` (with extra handling
    for derived operators and for binary operators on tuples/arrays) and
    returns the resulting ``irast.OperatorCall`` wrapped in a set.

    :param qlexpr: AST node of the whole operator expression; its context
        is used when reporting errors.
    :param op_name: Operator name as written (e.g. ``'+'``, ``'std::IF'``).
    :param qlargs: Operand AST nodes in positional order.
    :param ctx: Current compiler context level.
    :raises errors.QueryError: If no operator matches the name or operand
        types, if an operand type cannot be inferred, or if the match is
        ambiguous.
    :raises errors.InternalServerError: If derived-operator metadata is
        inconsistent.
    """

    env = ctx.env
    schema = env.schema
    opers = schema.get_operators(op_name, module_aliases=ctx.modaliases)

    # Treat an empty candidate list the same as a missing one: the code
    # below unconditionally reads the first candidate.
    if not opers:
        raise errors.QueryError(
            f'no operator matches the given name and argument types',
            context=qlexpr.context)

    fq_op_name = next(iter(opers)).get_shortname(ctx.env.schema)
    conditional_args = CONDITIONAL_OPS.get(fq_op_name)

    arg_ctxs = {}
    args = []
    for ai, qlarg in enumerate(qlargs):
        with ctx.newscope(fenced=True) as fencectx:
            fencectx.path_log = []
            # We put on a SET OF fence preemptively in case this is
            # a SET OF arg, which we don't know yet due to polymorphic
            # matching.  We will remove it if necessary in `finalize_args()`.
            if conditional_args and ai in conditional_args:
                fencectx.in_conditional = qlexpr.context

            arg_ir = setgen.ensure_set(
                dispatch.compile(qlarg, ctx=fencectx),
                ctx=fencectx)

            arg_ir = setgen.scoped_set(
                setgen.ensure_stmt(arg_ir, ctx=fencectx),
                ctx=fencectx)

            arg_ctxs[arg_ir] = fencectx

        arg_type = inference.infer_type(arg_ir, ctx.env)
        if arg_type is None:
            raise errors.QueryError(
                f'could not resolve the type of operand '
                f'#{ai} of {op_name}',
                context=qlarg.context)

        args.append((arg_type, arg_ir))

    # Check if the operator is a derived operator, and if so,
    # find the origins.
    origin_op = opers[0].get_derivative_of(env.schema)
    derivative_op: Optional[s_oper.Operator]
    if origin_op is not None:
        # If this is a derived operator, there should be
        # exactly one form of it.  This is enforced at the DDL
        # level, but check again to be sure.
        # Errors here are reported against the whole expression: the
        # loop variable `qlarg` would point at an arbitrary operand and
        # is unbound when there are no operands at all.
        if len(opers) > 1:
            raise errors.InternalServerError(
                f'more than one derived operator of the same name: {op_name}',
                context=qlexpr.context)

        derivative_op = opers[0]
        opers = schema.get_operators(origin_op)
        if not opers:
            raise errors.InternalServerError(
                f'cannot find the origin operator for {op_name}',
                context=qlexpr.context)
        actual_typemods = [
            param.get_typemod(schema)
            for param in derivative_op.get_params(schema).objects(schema)
        ]
    else:
        derivative_op = None
        actual_typemods = []

    matched = None
    # Some 2-operand operators are special when their operands are
    # arrays or tuples.
    if len(args) == 2:
        coll_opers = None
        # If both of the args are arrays or tuples, potentially
        # compile the operator for them differently than for other
        # combinations.
        if args[0][0].is_tuple(env.schema) and args[1][0].is_tuple(env.schema):
            # Out of the candidate operators, find the ones that
            # correspond to tuples.
            coll_opers = [
                op for op in opers
                if all(
                    param.get_type(schema).is_tuple(schema)
                    for param in op.get_params(schema).objects(schema)
                )
            ]

        elif args[0][0].is_array() and args[1][0].is_array():
            # Out of the candidate operators, find the ones that
            # correspond to arrays.
            coll_opers = [
                op for op in opers
                if all(
                    param.get_type(schema).is_array()
                    for param in op.get_params(schema).objects(schema)
                )
            ]

        # Proceed only if we have a special case of collection operators.
        if coll_opers:
            # Then check if they are recursive (i.e. validation must be
            # done recursively for the subtypes). We rely on the fact that
            # it is forbidden to define an operator that has both
            # recursive and non-recursive versions.
            if not coll_opers[0].get_recursive(schema):
                # The operator is non-recursive, so regular processing
                # is needed.
                matched = polyres.find_callable(
                    coll_opers, args=args, kwargs={}, ctx=ctx)

            else:
                # The recursive operators are usually defined as
                # being polymorphic on all parameters, and so this has
                # a side-effect of forcing both operands to be of
                # the same type (via casting) before the operator is
                # applied.  This might seem suboptmial, since there might
                # be a more specific operator for the types of the
                # elements, but the current version of Postgres
                # actually requires tuples and arrays to be of the
                # same type in comparison, so this behavior is actually
                # what we want.
                matched = polyres.find_callable(
                    coll_opers,
                    args=args,
                    kwargs={},
                    ctx=ctx,
                )

                # Now that we have an operator, we need to validate that it
                # can be applied to the tuple or array elements.
                submatched = validate_recursive_operator(
                    opers, args[0], args[1], ctx=ctx)

                if len(submatched) != 1:
                    # This is an error. We want the error message to
                    # reflect whether no matches were found or too
                    # many, so we preserve the submatches found for
                    # this purpose.
                    matched = submatched

    # No special handling match was necessary, find a normal match.
    if matched is None:
        matched = polyres.find_callable(opers, args=args, kwargs={}, ctx=ctx)

    in_polymorphic_func = (
        ctx.env.options.func_params is not None and
        ctx.env.options.func_params.has_polymorphic(env.schema)
    )

    in_abstract_constraint = (
        in_polymorphic_func and
        ctx.env.options.schema_object_context is s_constr.Constraint
    )

    if not in_polymorphic_func:
        matched = [call for call in matched
                   if not call.func.get_abstract(env.schema)]

    if len(matched) == 1:
        matched_call = matched[0]
    else:
        if len(args) == 2:
            ltype = schemactx.get_material_type(args[0][0], ctx=ctx)
            rtype = schemactx.get_material_type(args[1][0], ctx=ctx)

            types = (
                f'{ltype.get_displayname(env.schema)!r} and '
                f'{rtype.get_displayname(env.schema)!r}')
        else:
            types = ', '.join(
                repr(
                    schemactx.get_material_type(
                        a[0], ctx=ctx).get_displayname(env.schema)
                ) for a in args
            )

        if not matched:
            hint = ('Consider using an explicit type cast or a conversion '
                    'function.')

            if op_name == 'std::IF':
                hint = (f"The IF and ELSE result clauses must be of "
                        f"compatible types, while the condition clause must "
                        f"be 'std::bool'. {hint}")
            elif op_name == '+' and len(args) == 2:
                # `ltype`/`rtype` are only computed for binary operators
                # above; unary '+' must not reach this branch.
                str_t = cast(s_scalars.ScalarType,
                             env.schema.get('std::str'))
                bytes_t = cast(s_scalars.ScalarType,
                               env.schema.get('std::bytes'))
                if (
                    (ltype.issubclass(env.schema, str_t) and
                        rtype.issubclass(env.schema, str_t)) or
                    (ltype.issubclass(env.schema, bytes_t) and
                        rtype.issubclass(env.schema, bytes_t)) or
                    (ltype.is_array() and rtype.is_array())
                ):
                    hint = 'Consider using the "++" operator for concatenation'

            raise errors.QueryError(
                f'operator {str(op_name)!r} cannot be applied to '
                f'operands of type {types}',
                hint=hint,
                context=qlexpr.context)
        elif len(matched) > 1:
            if in_abstract_constraint:
                matched_call = matched[0]
            else:
                detail = ', '.join(
                    f'`{m.func.get_verbosename(ctx.env.schema)}`'
                    for m in matched
                )
                raise errors.QueryError(
                    f'operator {str(op_name)!r} is ambiguous for '
                    f'operands of type {types}',
                    hint=f'Possible variants: {detail}.',
                    context=qlexpr.context)

    oper = matched_call.func
    assert isinstance(oper, s_oper.Operator)
    env.add_schema_ref(oper, expr=qlexpr)
    oper_name = oper.get_shortname(env.schema)
    str_oper_name = str(oper_name)

    matched_params = oper.get_params(env.schema)
    rtype = matched_call.return_type

    is_polymorphic = (
        any(p.get_type(env.schema).is_polymorphic(env.schema)
            for p in matched_params.objects(env.schema)) and
        rtype.is_polymorphic(env.schema)
    )

    final_args, params_typemods = finalize_args(
        matched_call,
        arg_ctxs=arg_ctxs,
        actual_typemods=actual_typemods,
        is_polymorphic=is_polymorphic,
        ctx=ctx,
    )

    if str_oper_name in {'std::UNION', 'std::IF'} and rtype.is_object_type():
        # Special case for the UNION and IF operators, instead of common
        # parent type, we return a union type.
        if str_oper_name == 'std::UNION':
            larg, rarg = (a.expr for a in final_args)
        else:
            larg, _, rarg = (a.expr for a in final_args)

        left_type = schemactx.get_material_type(
            setgen.get_set_type(larg, ctx=ctx),
            ctx=ctx,
        )
        right_type = schemactx.get_material_type(
            setgen.get_set_type(rarg, ctx=ctx),
            ctx=ctx,
        )

        if left_type.issubclass(env.schema, right_type):
            rtype = right_type
        elif right_type.issubclass(env.schema, left_type):
            rtype = left_type
        else:
            assert isinstance(left_type, s_types.InheritingType)
            assert isinstance(right_type, s_types.InheritingType)
            rtype = schemactx.get_union_type([left_type, right_type], ctx=ctx)

    from_op = oper.get_from_operator(env.schema)
    sql_operator = None
    if (from_op is not None and oper.get_code(env.schema) is None and
            oper.get_from_function(env.schema) is None and
            not in_polymorphic_func):
        sql_operator = tuple(from_op)

    origin_name: Optional[sn.QualName]
    origin_module_id: Optional[uuid.UUID]
    if derivative_op is not None:
        origin_name = oper_name
        origin_module_id = env.schema.get_global(
            s_mod.Module, origin_name.module).id
        oper_name = derivative_op.get_shortname(env.schema)
    else:
        origin_name = None
        origin_module_id = None

    node = irast.OperatorCall(
        args=final_args,
        func_shortname=oper_name,
        func_polymorphic=is_polymorphic,
        origin_name=origin_name,
        origin_module_id=origin_module_id,
        func_sql_function=oper.get_from_function(env.schema),
        sql_operator=sql_operator,
        force_return_cast=oper.get_force_return_cast(env.schema),
        volatility=oper.get_volatility(env.schema),
        operator_kind=oper.get_operator_kind(env.schema),
        params_typemods=params_typemods,
        context=qlexpr.context,
        typeref=typegen.type_to_typeref(rtype, env=env),
        typemod=oper.get_return_typemod(env.schema),
    )

    return setgen.ensure_set(node, typehint=rtype, ctx=ctx)
예제 #25
0
    def _describe_type(self, t, view_shapes, view_shapes_metadata,
                       follow_links: bool = True):
        """Append a binary descriptor for type *t* to ``self.buffer``.

        Component types (tuple/array elements, shape pointer targets and
        non-base scalar bases) are described first through recursive
        calls. Their positions are therefore already recorded in
        ``self.uuid_to_pos`` when the parent descriptor refers to them.
        If the computed type id is already in ``self.uuid_to_pos``, the
        descriptor for it is not emitted again.

        :param t: The schema type to describe.
        :param view_shapes: Mapping from a type (or, for link properties,
            from the type's rptr) to the pointers included in its shape.
        :param view_shapes_metadata: Mapping from a type to shape
            metadata; only ``has_implicit_id`` is read here.
        :param follow_links: When False, singular links are described as
            ``std::uuid`` and multi links raise an error.
            NOTE(review): this flag is not forwarded to any of the
            recursive calls; confirm that is intended.
        :returns: The type id of the described type.
        :raises errors.SchemaError: For unsupported collection types.
        :raises errors.InternalServerError: For multi links when
            ``follow_links`` is False, for an implicit ``id``/``__tid__``
            field that is not described as ``std::uuid``, or for a type
            that cannot be described.
        """
        # The encoding format is documented in edb/api/types.txt.

        buf = self.buffer

        if isinstance(t, s_types.Tuple):
            # Element types are described first so that their positions
            # can be referenced below.
            subtypes = [self._describe_type(st, view_shapes,
                                            view_shapes_metadata)
                        for st in t.get_subtypes(self.schema)]

            if t.named:
                element_names = list(t.get_element_names(self.schema))
                assert len(element_names) == len(subtypes)

                type_id = self._get_collection_type_id(
                    t.schema_name, subtypes, element_names)

                if type_id in self.uuid_to_pos:
                    return type_id

                # Layout: tag, type id, element count, then for each
                # element a length-prefixed UTF-8 name followed by the
                # position of the element type's descriptor.
                buf.append(CTYPE_NAMEDTUPLE)
                buf.append(type_id.bytes)
                buf.append(_uint16_packer(len(subtypes)))
                for el_name, el_type in zip(element_names, subtypes):
                    el_name_bytes = el_name.encode('utf-8')
                    buf.append(_uint32_packer(len(el_name_bytes)))
                    buf.append(el_name_bytes)
                    buf.append(_uint16_packer(self.uuid_to_pos[el_type]))

            else:
                type_id = self._get_collection_type_id(t.schema_name, subtypes)

                if type_id in self.uuid_to_pos:
                    return type_id

                # Layout: tag, type id, element count, then the position
                # of each element type's descriptor.
                buf.append(CTYPE_TUPLE)
                buf.append(type_id.bytes)
                buf.append(_uint16_packer(len(subtypes)))
                for el_type in subtypes:
                    buf.append(_uint16_packer(self.uuid_to_pos[el_type]))

            self._register_type_id(type_id)
            return type_id

        elif isinstance(t, s_types.Array):
            subtypes = [self._describe_type(st, view_shapes,
                                            view_shapes_metadata)
                        for st in t.get_subtypes(self.schema)]

            assert len(subtypes) == 1
            type_id = self._get_collection_type_id(t.schema_name, subtypes)

            if type_id in self.uuid_to_pos:
                return type_id

            # Layout: tag, type id, element type position, dimension
            # count, then the cardinality of each dimension.
            buf.append(CTYPE_ARRAY)
            buf.append(type_id.bytes)
            buf.append(_uint16_packer(self.uuid_to_pos[subtypes[0]]))
            # Number of dimensions (currently always 1)
            buf.append(_uint16_packer(1))
            # Dimension cardinality (currently always unbound)
            buf.append(_int32_packer(-1))

            self._register_type_id(type_id)
            return type_id

        elif isinstance(t, s_types.Collection):
            raise errors.SchemaError(f'unsupported collection type {t!r}')

        elif view_shapes.get(t):
            # This is a view
            mt = t.material_type(self.schema)
            base_type_id = mt.id

            # Parallel lists, one entry per shape element.
            subtypes = []
            element_names = []
            link_props = []
            links = []

            metadata = view_shapes_metadata.get(t)
            implicit_id = metadata is not None and metadata.has_implicit_id

            # Regular shape pointers. Singular pointers reference the
            # target's descriptor; multi pointers reference a set
            # descriptor built by _describe_set().
            for ptr in view_shapes[t]:
                if ptr.singular(self.schema):
                    if isinstance(ptr, s_links.Link) and not follow_links:
                        # Singular links are described as std::uuid
                        # (presumably the target id) instead of the
                        # target object type.
                        subtype_id = self._describe_type(
                            self.schema.get('std::uuid'), view_shapes,
                            view_shapes_metadata,
                        )
                    else:
                        subtype_id = self._describe_type(
                            ptr.get_target(self.schema), view_shapes,
                            view_shapes_metadata)
                else:
                    if isinstance(ptr, s_links.Link) and not follow_links:
                        raise errors.InternalServerError(
                            'cannot describe multi links when '
                            'follow_links=False'
                        )
                    else:
                        subtype_id = self._describe_set(
                            ptr.get_target(self.schema), view_shapes,
                            view_shapes_metadata)
                subtypes.append(subtype_id)
                element_names.append(ptr.get_shortname(self.schema).name)
                link_props.append(False)
                links.append(not ptr.is_property(self.schema))

            t_rptr = t.get_rptr(self.schema)
            if t_rptr is not None:
                # There are link properties in the mix
                for ptr in view_shapes[t_rptr]:
                    if ptr.singular(self.schema):
                        subtype_id = self._describe_type(
                            ptr.get_target(self.schema), view_shapes,
                            view_shapes_metadata)
                    else:
                        subtype_id = self._describe_set(
                            ptr.get_target(self.schema), view_shapes,
                            view_shapes_metadata)
                    subtypes.append(subtype_id)
                    element_names.append(
                        ptr.get_shortname(self.schema).name)
                    link_props.append(True)
                    links.append(False)

            type_id = self._get_object_type_id(
                base_type_id, subtypes, element_names,
                links_props=link_props, links=links,
                has_implicit_fields=implicit_id)

            if type_id in self.uuid_to_pos:
                return type_id

            buf.append(CTYPE_SHAPE)
            buf.append(type_id.bytes)

            assert len(subtypes) == len(element_names)
            buf.append(_uint16_packer(len(subtypes)))

            # Per element: flags, length-prefixed UTF-8 name, and the
            # position of the element type's descriptor.
            for el_name, el_type, el_lp, el_l in zip(element_names,
                                                     subtypes, link_props,
                                                     links):
                flags = 0
                if el_lp:
                    flags |= self.EDGE_POINTER_IS_LINKPROP
                # An implicit 'id' (per shape metadata) and '__tid__'
                # are flagged implicit and must be described as uuid.
                if (implicit_id and el_name == 'id') or el_name == '__tid__':
                    if el_type != UUID_TYPE_ID:
                        raise errors.InternalServerError(
                            f"{el_name!r} is expected to be a 'std::uuid' "
                            f"singleton")
                    flags |= self.EDGE_POINTER_IS_IMPLICIT
                if el_l:
                    flags |= self.EDGE_POINTER_IS_LINK
                buf.append(_uint8_packer(flags))

                el_name_bytes = el_name.encode('utf-8')
                buf.append(_uint32_packer(len(el_name_bytes)))
                buf.append(el_name_bytes)
                buf.append(_uint16_packer(self.uuid_to_pos[el_type]))

            self._register_type_id(type_id)
            return type_id

        elif t.is_scalar():
            # This is a scalar type

            mt = t.material_type(self.schema)
            type_id = mt.id
            if type_id in self.uuid_to_pos:
                # already described
                return type_id

            base_type = mt.get_topmost_concrete_base(self.schema)
            enum_values = mt.get_enum_values(self.schema)

            if enum_values:
                # Enums list their labels as length-prefixed UTF-8.
                buf.append(CTYPE_ENUM)
                buf.append(type_id.bytes)
                buf.append(_uint16_packer(len(enum_values)))
                for enum_val in enum_values:
                    enum_val_bytes = enum_val.encode('utf-8')
                    buf.append(_uint32_packer(len(enum_val_bytes)))
                    buf.append(enum_val_bytes)

            elif mt is base_type:
                # The type is its own topmost concrete base.
                buf.append(CTYPE_BASE_SCALAR)
                buf.append(type_id.bytes)

            else:
                # A derived scalar refers to its base's descriptor, which
                # is described (if needed) first.
                bt_id = self._describe_type(
                    base_type, view_shapes, view_shapes_metadata)

                buf.append(CTYPE_SCALAR)
                buf.append(type_id.bytes)
                buf.append(_uint16_packer(self.uuid_to_pos[bt_id]))

            self._register_type_id(type_id)
            return type_id

        else:
            raise errors.InternalServerError(
                f'cannot describe type {t.get_name(self.schema)}')
예제 #26
0
def static_interpret_backend_error(fields):
    """Translate backend (Postgres) error *fields* without using a schema.

    Returns an EdgeDB error instance where the translation needs no schema
    information. Where it does need it (for example, to name a pointer or
    a constraint), the ``SchemaRequired`` marker is returned instead;
    presumably the caller then re-interprets the error with a schema.

    :param fields: Raw backend error fields, parsed via
        ``get_error_details()``.
    :returns: An error instance (not raised) or ``SchemaRequired``.
    """
    details = get_error_details(fields)

    # Errors with a code-independent translation take precedence.
    generic = get_generic_exception_from_err_details(details)
    if generic is not None:
        return generic

    code = details.code
    message = details.message

    if code == PGErrorCode.NotNullViolationError:
        if not (details.table_name or details.column_name):
            return errors.InternalServerError(message)
        return SchemaRequired

    if code in constraint_errors:
        # Classify the violation by the first message pattern it matches.
        for error_type, pattern in constraint_res.items():
            if pattern.match(message):
                break
        else:
            return errors.InternalServerError(message)

        if error_type == 'cardinality':
            return errors.CardinalityViolationError(
                'cardinality violation', source=None, pointer=None)

        if error_type == 'link_target':
            info = details.detail_json
            if not info:
                return errors.UnknownLinkError('invalid target for link')

            src_str = info.get('source')
            ptr_str = info.get('pointer')
            target = info.get('target')
            expected = info.get('expected')

            if src_str and ptr_str:
                src_name = sn.QualName.from_string(src_str)
                ptr_name = sn.QualName.from_string(ptr_str)
                link_name = f'{src_name}.{ptr_name.name}'
            else:
                link_name = ''

            return errors.UnknownLinkError(
                f'invalid target for link {link_name!r}: {target!r} '
                f'(expecting {expected!r})')

        if error_type == 'link_target_del':
            return errors.ConstraintViolationError(
                message, details=details.detail)

        if error_type == 'constraint':
            constraint_name = details.constraint_name
            if constraint_name is None:
                return errors.InternalServerError(message)
            # Only check that the name carries a well-formed id; the
            # constraint itself can only be resolved with a schema.
            id_part = constraint_name.rpartition(';')[0]
            try:
                uuidgen.UUID(id_part)
            except ValueError:
                return errors.InternalServerError(message)
            return SchemaRequired

        if error_type == 'newconstraint':
            # The schema, table and column names identify what went wrong;
            # constraint_name is not expected to be present here.
            has_location = (details.schema_name and details.table_name
                            and details.column_name)
            if not has_location:
                return errors.InternalServerError(message)
            return SchemaRequired

        if error_type == 'scalar':
            return SchemaRequired

        if error_type == 'id':
            return errors.ConstraintViolationError(
                'unique link constraint violation')

        # A recognized pattern without a dedicated translation.
        return errors.InternalServerError(message)

    if code in SCHEMA_CODES:
        if code == PGErrorCode.InvalidDatetimeFormatError:
            json_detail = details.detail_json
            hint = json_detail.get('hint') if json_detail else None
            if message.startswith(('missing required time zone',
                                   'unexpected time zone')):
                return errors.InvalidValueError(message, hint=hint)
        return SchemaRequired

    if code in (PGErrorCode.InvalidParameterValue,
                PGErrorCode.WrongObjectType):
        return errors.InvalidValueError(
            message, details=details.detail or None)

    if code == PGErrorCode.DivisionByZeroError:
        return errors.DivisionByZeroError(message)

    if code == PGErrorCode.ReadOnlySQLTransactionError:
        return errors.TransactionError(
            'cannot execute query in a read-only transaction')

    if code == PGErrorCode.TransactionSerializationFailure:
        return errors.TransactionSerializationError(message)

    if code == PGErrorCode.TransactionDeadlockDetected:
        return errors.TransactionDeadlockError(message)

    if code == PGErrorCode.InvalidCatalogNameError:
        return errors.AuthenticationError(message)

    if code == PGErrorCode.ObjectInUse:
        return errors.ExecutionError(message)

    return errors.InternalServerError(message)
예제 #27
0
def interpret_backend_error(schema, fields):
    """Translate backend (Postgres) error *fields* into an EdgeDB error.

    If the error detail is a JSON object carrying a known EdgeDB error
    ``code``, that error class is used directly. Otherwise the SQLSTATE
    in ``fields['C']`` selects the translation. *schema* is used to turn
    backend table names and constraint ids into user-facing names.

    :param schema: Schema used to resolve backend object names and ids.
    :param fields: Mapping of single-letter Postgres error field codes to
        their values.
    :returns: An EdgeDB error instance (returned, not raised).
    """
    # See https://www.postgresql.org/docs/current/protocol-error-fields.html
    # for the full list of PostgreSQL error message fields.
    message = fields.get('M')

    detail = fields.get('D')
    detail_json = None
    if detail and detail.startswith('{'):
        # A JSON-object detail carries structured error info (the keys
        # read below). A detail that only looks like JSON must not make
        # the error translator itself fail, so keep it as plain text then.
        try:
            detail_json = json.loads(detail)
        except ValueError:
            pass
        else:
            detail = None

    if detail_json:
        errcode = detail_json.get('code')
        if errcode:
            try:
                errcls = type(
                    errors.EdgeDBError).get_error_class_from_code(errcode)
            except LookupError:
                pass
            else:
                err = errcls(message)
                err.set_linecol(detail_json.get('line', -1),
                                detail_json.get('column', -1))
                return err

    try:
        code = PGError(fields['C'])
    except ValueError:
        return errors.InternalServerError(message)

    schema_name = fields.get('s')
    table_name = fields.get('t')
    column_name = fields.get('c')
    constraint_name = fields.get('n')

    if code == PGError.NotNullViolationError:
        source_name = pointer_name = None

        if schema_name and table_name:
            tabname = (schema_name, table_name)

            source = common.get_object_from_backend_name(
                schema, s_objtypes.ObjectType, tabname)
            source_name = source.get_displayname(schema)

            if column_name:
                pointer_name = column_name

        if pointer_name is not None:
            pname = f'{source_name}.{pointer_name}'

            return errors.MissingRequiredError(
                f'missing value for required property {pname}')

        else:
            return errors.InternalServerError(message)

    elif code in constraint_errors:
        source = pointer = None

        for errtype, ere in constraint_res.items():
            m = ere.match(message)
            if m:
                error_type = errtype
                break
        else:
            return errors.InternalServerError(message)

        if error_type == 'cardinality':
            return errors.CardinalityViolationError('cardinality violation',
                                                    source=source,
                                                    pointer=pointer)

        elif error_type == 'link_target':
            if detail_json:
                srcname = detail_json.get('source')
                ptrname = detail_json.get('pointer')
                target = detail_json.get('target')
                expected = detail_json.get('expected')

                if srcname and ptrname:
                    srcname = sn.Name(srcname)
                    ptrname = sn.Name(ptrname)
                    lname = '{}.{}'.format(srcname, ptrname.name)
                else:
                    lname = ''

                msg = (f'invalid target for link {lname!r}: {target!r} '
                       f'(expecting {expected!r})')

            else:
                msg = 'invalid target for link'

            return errors.UnknownLinkError(msg)

        elif error_type == 'link_target_del':
            return errors.ConstraintViolationError(message, details=detail)

        elif error_type == 'constraint':
            if constraint_name is None:
                return errors.InternalServerError(message)

            # The backend constraint name is "<uuid>;<suffix>"-shaped; the
            # id part identifies the schema constraint.
            constraint_id, _, _ = constraint_name.rpartition(';')

            try:
                constraint_id = uuid.UUID(constraint_id)
            except ValueError:
                return errors.InternalServerError(message)

            constraint = schema.get_by_id(constraint_id)

            return errors.ConstraintViolationError(
                constraint.format_error_message(schema))

        elif error_type == 'id':
            return errors.ConstraintViolationError(
                'unique link constraint violation')

    elif code == PGError.InvalidParameterValue:
        return errors.InvalidValueError(message,
                                        details=detail if detail else None)

    elif code == PGError.InvalidTextRepresentation:
        return errors.InvalidValueError(translate_pgtype(schema, message))

    elif code == PGError.NumericValueOutOfRange:
        return errors.NumericOutOfRangeError(translate_pgtype(schema, message))

    elif code == PGError.DivisionByZeroError:
        return errors.DivisionByZeroError(message)

    elif code == PGError.ReadOnlySQLTransactionError:
        return errors.TransactionError(
            'cannot execute query in a read-only transaction')

    elif code in {PGError.InvalidDatetimeFormatError, PGError.DatetimeError}:
        return errors.InvalidValueError(translate_pgtype(schema, message))

    elif code == PGError.TransactionSerializationFailure:
        return errors.TransactionSerializationError(message)

    elif code == PGError.TransactionDeadlockDetected:
        return errors.TransactionDeadlockError(message)

    return errors.InternalServerError(message)
예제 #28
0
def type_op_ast_to_type_shell(
    node: qlast.TypeOp,
    *,
    metaclass: Type[s_types.TypeT],
    module: Optional[str] = None,
    modaliases: Mapping[Optional[str], str],
    schema: s_schema.Schema,
) -> s_types.TypeExprShell[s_types.TypeT]:
    """Build a union type shell from a binary type-expression AST node.

    Only the ``|`` operator is accepted.  Each operand is converted with
    ``ast_to_type_shell`` (left first, then right).  An operand that is
    itself a ``UnionTypeShell`` contributes its components rather than
    itself, so the result is always one flat ``UnionTypeShell``.

    If *module* is not given, the default module alias (the ``None`` key
    of *modaliases*) is used.

    Raises:
        errors.UnsupportedFeatureError: the operator is not ``|``.
        errors.InternalServerError: no module could be determined.
    """
    from . import types as s_types

    if node.op != '|':
        raise errors.UnsupportedFeatureError(
            f'unsupported type expression operator: {node.op}',
            context=node.context,
        )

    if module is None:
        module = modaliases.get(None)
    if module is None:
        raise errors.InternalServerError(
            'cannot determine module for derived compound type',
            context=node.context,
        )

    # Convert both operands, strictly left before right.
    left, right = (
        ast_to_type_shell(
            operand,
            metaclass=metaclass,
            module=module,
            modaliases=modaliases,
            schema=schema,
        )
        for operand in (node.left, node.right)
    )

    # Flatten nested unions: a union operand is replaced by its components.
    left_parts, right_parts = (
        shell.components
        if isinstance(shell, s_types.UnionTypeShell)
        else (shell, )
        for shell in (left, right)
    )

    return s_types.UnionTypeShell(
        components=left_parts + right_parts,
        module=module,
        schemaclass=metaclass,
    )
예제 #29
0
    def compile_expr_field(
        self,
        schema: s_schema.Schema,
        context: sd.CommandContext,
        field: so.Field[Any],
        value: s_expr.Expression,
    ) -> s_expr.Expression:
        """Compile *value*, the expression stored in constraint *field*.

        Abstract constraint (no referrer in *context*): ``expr`` and
        ``subjectexpr`` are compiled with the constraint's parameters as
        anchors and function parameters.

        Concrete constraint (a referrer exists): ``expr`` is never
        recompiled and must already be compiled; ``subjectexpr`` and
        ``finalexpr`` are compiled with ``__subject__`` bound to the
        referring object.

        Every other field is delegated to the parent implementation.

        Raises:
            errors.InternalServerError: a concrete constraint's ``expr``
                value is not compiled.
        """
        referrer_ctx = self.get_referrer_context(context)

        if referrer_ctx is None:
            # Abstract constraint: only the check and subject expressions
            # are handled here, anchored on the declared parameters.
            if field.name not in ('expr', 'subjectexpr'):
                return super().compile_expr_field(
                    schema, context, field, value)

            params = self._get_params(schema, context)
            param_symtable = s_func.get_params_symtable(
                params,
                schema,
                inlined_defaults=False,
            )
            abstract_options = qlcompiler.CompilerOptions(
                modaliases=context.modaliases,
                anchors=param_symtable,
                func_params=params,
                allow_generic_type_output=True,
                schema_object_context=self.get_schema_metaclass(),
            )
            return s_expr.Expression.compiled(
                value, schema=schema, options=abstract_options)

        # Concrete constraint.
        if field.name == 'expr':
            # A concrete constraint cannot redefine the base check
            # expression.  It can only get here through field inheritance,
            # so the value must already be compiled.
            if value.is_compiled():
                return value
            display_name = (
                self.get_schema_metaclass().get_schema_class_displayname())
            raise errors.InternalServerError(
                f'uncompiled expression in the {field.name!r} field of'
                f' {display_name} {self.classname!r}')

        if field.name not in {'subjectexpr', 'finalexpr'}:
            return super().compile_expr_field(schema, context, field, value)

        # ``__subject__`` is bound to the referring object's schema class.
        subject_anchors = {'__subject__': referrer_ctx.op.scls}
        subject_options = qlcompiler.CompilerOptions(
            modaliases=context.modaliases,
            anchors=subject_anchors,
            allow_generic_type_output=True,
            schema_object_context=self.get_schema_metaclass(),
        )
        return s_expr.Expression.compiled(
            value, schema=schema, options=subject_options)
예제 #30
0
def compile_ir_to_sql_tree(
    ir_expr: irast.Base,
    *,
    output_format: Optional[OutputFormat] = None,
    ignore_shapes: bool = False,
    explicit_top_cast: Optional[irast.TypeRef] = None,
    singleton_mode: bool = False,
    use_named_params: bool = False,
    expected_cardinality_one: bool = False,
    external_rvars: Optional[
        Mapping[Tuple[irast.PathId, str], pgast.PathRangeVar]
    ] = None,
) -> pgast.Base:
    """Translate an IR expression into a PostgreSQL AST.

    An ``irast.Statement`` is unwrapped before compiling its inner
    ``expr``.  Its scope tree, query parameters and type rewrites seed the
    compiler environment.  An ``irast.ConfigCommand`` contributes only its
    scope tree, which must be set.  Any other node is compiled under a
    fresh scope tree.

    The remaining keyword arguments are forwarded to
    ``context.Environment``; ``singleton_mode`` is also set on the root
    compiler context level.

    Raises:
        errors.EdgeDBError: typed EdgeDB errors raised during compilation
            are propagated unchanged.
        errors.InternalServerError: any other exception is wrapped (keeping
            its first argument, if any) and chained to the original.
    """
    try:
        # Transform to sql tree
        query_params = []
        type_rewrites = {}

        if isinstance(ir_expr, irast.Statement):
            scope_tree = ir_expr.scope_tree
            query_params = list(ir_expr.params)
            type_rewrites = ir_expr.type_rewrites
            ir_expr = ir_expr.expr
        elif isinstance(ir_expr, irast.ConfigCommand):
            assert ir_expr.scope_tree
            scope_tree = ir_expr.scope_tree
        else:
            scope_tree = irast.new_scope_tree()

        # Map each scope node that has a unique id to the node itself.
        scope_tree_nodes = {
            node.unique_id: node for node in scope_tree.descendants
            if node.unique_id is not None
        }

        env = context.Environment(
            output_format=output_format,
            expected_cardinality_one=expected_cardinality_one,
            use_named_params=use_named_params,
            query_params=query_params,
            type_rewrites=type_rewrites,
            ignore_object_shapes=ignore_shapes,
            explicit_top_cast=explicit_top_cast,
            singleton_mode=singleton_mode,
            scope_tree_nodes=scope_tree_nodes,
            external_rvars=external_rvars,
        )

        ctx = context.CompilerContextLevel(
            None,
            context.ContextSwitchMode.TRANSPARENT,
            env=env,
            scope_tree=scope_tree,
        )

        # Keep the context stack object referenced for the duration of
        # the compilation.
        _ = context.CompilerContext(initial=ctx)
        ctx.singleton_mode = singleton_mode
        ctx.expr_exposed = True
        qtree = dispatch.compile(ir_expr, ctx=ctx)

    except errors.EdgeDBError:
        # Properly typed EdgeDB errors (e.g. UnsupportedFeatureError) are
        # meaningful to the client; do not mask them as internal errors.
        raise

    except Exception as e:  # pragma: no cover
        try:
            args = [e.args[0]]
        except (AttributeError, IndexError):
            args = []
        raise errors.InternalServerError(*args) from e

    return qtree