Ejemplo n.º 1
0
def _apply_type_to_mapped_statement(
    api: SemanticAnalyzerPluginInterface,
    stmt: AssignmentStmt,
    lvalue: NameExpr,
    left_hand_explicit_type: Optional[Union[Instance, UnionType]],
    python_type_for_type: Union[Instance, UnionType],
) -> None:
    """Rewrite a declarative assignment so Mypy sees a ``Mapped[<type>]``.

    A declarative class attribute written as::

        class User(Base):
            # ...

            attrname = Column(Integer)

    is presented to Mypy as though it had been written::

        class User(Base):
            # ...

            attrname : Mapped[Optional[int]] = <meaningless temp node>

    :param api: semantic analyzer interface; used to look up the ``Mapped``
     symbol and to register the ``_sa_Mapped`` alias.
    :param stmt: the assignment statement; its ``rvalue`` is replaced in
     place.
    :param lvalue: the name on the left side of the assignment; the ``type``
     of its ``node`` is overwritten.
    :param left_hand_explicit_type: type to place inside ``Mapped[]`` when
     given; takes precedence over ``python_type_for_type``.
    :param python_type_for_type: type to place inside ``Mapped[]`` when
     ``left_hand_explicit_type`` is ``None``.

    """
    mapped_sym = api.modules["sqlalchemy.orm.attributes"].names["Mapped"]
    target_var = lvalue.node

    # Mapped[<python type>]; only assigned when no explicit type was passed
    inferred_mapped = Instance(mapped_sym.node, [python_type_for_type])

    if left_hand_explicit_type is None:
        lvalue.is_inferred_def = False
        target_var.type = inferred_mapped
    else:
        target_var.type = Instance(mapped_sym.node, [left_hand_explicit_type])

    # The right side could be skipped entirely with:
    #
    #   stmt.rvalue = TempNode(AnyType(TypeOfAny.special_form))
    #
    # Instead the statement is rewritten as:
    #
    #   <attr> : Mapped[<typ>] =
    #       _sa_Mapped._empty_constructor(<original rvalue>)
    #
    # Wrapping rather than discarding the original expression means the
    # Column(), relationship(), etc. call is still type checked internally.
    api.add_symbol_table_node("_sa_Mapped", mapped_sym)
    mapped_name = nodes.NameExpr("_sa_Mapped")
    mapped_name.fullname = "sqlalchemy.orm.Mapped"

    stmt.rvalue = CallExpr(
        nodes.MemberExpr(mapped_name, "_empty_constructor"),
        [stmt.rvalue],
        [nodes.ARG_POS],
        ["arg1"],
    )
Ejemplo n.º 2
0
def _scan_declarative_decorator_stmt(
    cls: ClassDef,
    api: SemanticAnalyzerPluginInterface,
    stmt: Decorator,
    cls_metadata: DeclClassApplied,
) -> None:
    """Extract mapping information from a @declared_attr in a declarative
    class.

    E.g.::

        @reg.mapped
        class MyClass:
            # ...

            @declared_attr
            def updated_at(cls) -> Column[DateTime]:
                return Column(DateTime)

    Will resolve in mypy as::

        @reg.mapped
        class MyClass:
            # ...

            updated_at: Mapped[Optional[datetime.datetime]]

    :param cls: the class definition being scanned; the entry for ``stmt``
     in ``cls.defs.body`` is replaced in place by a new assignment statement.
    :param api: semantic analyzer interface, used for symbol lookup and
     error reporting.
    :param stmt: the decorated function.  Nothing is done unless one of its
     decorators resolves to ``names.DECLARED_ATTR``.
    :param cls_metadata: receives a ``(name, type)`` tuple appended to
     ``mapped_attr_names``.

    """
    if not any(
        names._type_id_for_named_node(dec) is names.DECLARED_ATTR
        for dec in stmt.decorators
    ):
        return

    dec_index = cls.defs.body.index(stmt)

    left_hand_explicit_type = None

    if stmt.func.type is not None:
        func_type = stmt.func.type.ret_type
        if isinstance(func_type, UnboundType):
            type_id = names._type_id_for_unbound_type(func_type, cls, api)
        else:
            # this does not seem to occur unless the type argument is
            # incorrect
            return

        if (
            type_id
            in {
                names.MAPPED,
                names.RELATIONSHIP,
                names.COMPOSITE_PROPERTY,
                names.MAPPER_PROPERTY,
                names.SYNONYM_PROPERTY,
                names.COLUMN_PROPERTY,
            }
            and func_type.args
        ):
            left_hand_explicit_type = func_type.args[0]
        elif type_id is names.COLUMN and func_type.args:
            typeengine_arg = func_type.args[0]
            if isinstance(typeengine_arg, UnboundType):
                sym = api.lookup(typeengine_arg.name, typeengine_arg)

                # the lookup can come back empty, and the node it finds is
                # not necessarily a class; none of those cases may be
                # dereferenced blindly, including when building the error
                # message below.
                type_node = sym.node if sym is not None else None
                mro = getattr(type_node, "mro", None)

                if mro is not None and names._mro_has_id(
                    mro, names.TYPEENGINE
                ):
                    left_hand_explicit_type = UnionType(
                        [
                            _extract_python_type_from_typeengine(type_node),
                            NoneType(),
                        ]
                    )
                else:
                    # when nothing was resolved, report the name as it was
                    # written in the annotation
                    offending_name = (
                        type_node.fullname
                        if type_node is not None
                        else typeengine_arg.name
                    )
                    util.fail(
                        api,
                        "Column type should be a TypeEngine "
                        "subclass not '{}'".format(offending_name),
                        func_type,
                    )

    if left_hand_explicit_type is None:
        # no usable type on the decorated function.  our option here is to
        # dig into the function body and get the return type, but they
        # should just have an annotation.
        msg = (
            "Can't infer type from @declared_attr on function '{}';  "
            "please specify a return type from this function that is "
            "one of: Mapped[<python type>], relationship[<target class>], "
            "Column[<TypeEngine>], MapperProperty[<python type>]"
        )
        util.fail(api, msg.format(stmt.var.name), stmt)

        # fall back to Any so the attribute is still rewritten below
        left_hand_explicit_type = AnyType(TypeOfAny.special_form)

    descriptor = api.modules["sqlalchemy.orm.attributes"].names["Mapped"]

    left_node = NameExpr(stmt.var.name)
    left_node.node = stmt.var

    # totally feeling around in the dark here as I don't totally understand
    # the significance of UnboundType.  It seems to be something that is
    # not going to do what's expected when it is applied as the type of
    # an AssignmentStatement.  So do a feeling-around-in-the-dark version
    # of converting it to the regular Instance/TypeInfo/UnionType structures
    # we see everywhere else.
    if isinstance(left_hand_explicit_type, UnboundType):
        left_hand_explicit_type = util._unbound_to_instance(
            api, left_hand_explicit_type
        )

    left_node.node.type = Instance(descriptor.node, [left_hand_explicit_type])

    # this will ignore the rvalue entirely
    # rvalue = TempNode(AnyType(TypeOfAny.special_form))

    # rewrite the node as:
    # <attr> : Mapped[<typ>] =
    # _sa_Mapped._empty_constructor(lambda: <function body>)
    # the function body is maintained so it gets type checked internally
    api.add_symbol_table_node("_sa_Mapped", descriptor)
    column_descriptor = nodes.NameExpr("_sa_Mapped")
    column_descriptor.fullname = "sqlalchemy.orm.Mapped"
    mm = nodes.MemberExpr(column_descriptor, "_empty_constructor")

    arg = nodes.LambdaExpr(stmt.func.arguments, stmt.func.body)
    rvalue = CallExpr(
        mm,
        [arg],
        [nodes.ARG_POS],
        ["arg1"],
    )

    new_stmt = AssignmentStmt([left_node], rvalue)
    new_stmt.type = left_node.node.type

    cls_metadata.mapped_attr_names.append(
        (left_node.name, left_hand_explicit_type)
    )
    # swap the decorated function for the synthesized assignment
    cls.defs.body[dec_index] = new_stmt