Exemple #1
0
def update_payload_factory(sqla_model: TableProtocol) -> "UpdatePayloadType":
    """Build the ``Update<Type>Payload`` output type, e.g. ``UpdateAccountPayload``.

    The payload exposes ``clientMutationId``, ``nodeId`` and a field named after
    the table (via ``Config.table_name_mapper``) that holds the updated row.

    Parameters
    ----------
    sqla_model
        The SQLAlchemy model whose update mutation this payload describes.

    Returns
    -------
    UpdatePayloadType
        The payload (output) type, with ``sqla_model`` attached. The annotation
        is quoted so the name need not be bound when the ``def`` executes.
    """
    # Imported locally (as before), presumably to avoid a circular import.
    from nebulo.gql.convert.table import table_factory

    relevant_type_name = Config.table_type_name_mapper(sqla_model)
    relevant_attr_name = Config.table_name_mapper(sqla_model)
    result_name = f"Update{relevant_type_name}Payload"

    attrs = {
        "clientMutationId": Field(String, resolve=default_resolver),
        "nodeId": ID,
        relevant_attr_name: Field(
            table_factory(sqla_model),
            resolve=default_resolver,
            # This payload belongs to the *update* mutation, so describe it as such.
            description=f"The {relevant_type_name} that was updated by this mutation.",
        ),
    }

    return UpdatePayloadType(
        result_name,
        attrs,
        description=f"The output of our update {relevant_type_name} mutation",
        sqla_model=sqla_model,
    )
Exemple #2
0
    def build_attrs():
        """Assemble the object type's field map.

        Field order is ``nodeId``, then readable columns, then relationships.
        """
        # Relay-style global identifier in place of the raw primary key.
        fields = {"nodeId": Field(NonNull(ID), resolve=default_resolver)}

        for col in get_columns(sqla_model):
            if Config.exclude_read(col):
                continue
            col_name = Config.column_name_mapper(col)
            fields[col_name] = convert_column(col)

        for rel in get_relationships(sqla_model):
            rel_direction = rel.direction
            target_model = rel.mapper.class_
            nullable = is_nullable(rel)
            field_name = Config.relationship_name_mapper(rel)

            if rel_direction == interfaces.MANYTOONE:
                # Exactly one counterpart row: expose it directly, not as a list.
                target_type = table_factory(target_model)
                if not nullable:
                    target_type = NonNull(target_type)
                fields[field_name] = Field(target_type, resolve=default_resolver)
            elif rel_direction in (interfaces.ONETOMANY, interfaces.MANYTOMANY):
                # Many counterparts: expose a paginated connection.
                # NOTE(review): ``not_null`` receives the *nullable* flag, which looks
                # inverted relative to the MANYTOONE branch — confirm intent.
                fields[field_name] = connection_field_factory(
                    target_model, resolver=default_resolver, not_null=nullable
                )

        return fields
Exemple #3
0
def patch_type_factory(sqla_model: TableProtocol) -> InputObjectType:
    """Build the ``<Type>Patch`` input type, e.g. ``AccountPatch``.

    One input field is produced for every column not excluded by
    ``Config.exclude_update``.
    """
    type_name = Config.table_type_name_mapper(sqla_model)

    patch_fields = {}
    for col in get_columns(sqla_model):
        if Config.exclude_update(col):
            continue
        key = Config.column_name_mapper(col)
        # TODO Unwrap not null here (a patch should presumably let every field be omitted)
        patch_fields[key] = convert_column_to_input(col)

    return TableInputType(
        f"{type_name}Patch",
        patch_fields,
        description=f"An input for mutations affecting {type_name}.",
    )
Exemple #4
0
def update_input_type_factory(sqla_model: TableProtocol) -> InputObjectType:
    """Build the ``Update<Type>Input`` input type, e.g. ``UpdateAccountInput``.

    Fields: a required ``nodeId`` identifying the row, an optional
    ``clientMutationId``, and a required patch object keyed by the table's
    attribute name (``Config.table_name_mapper``).

    Parameters
    ----------
    sqla_model
        The SQLAlchemy model being updated.

    Returns
    -------
    UpdateInputType
        The input type for the update mutation.
    """
    relevant_type_name = Config.table_type_name_mapper(sqla_model)
    result_name = f"Update{relevant_type_name}Input"

    input_object_name = Config.table_name_mapper(sqla_model)

    attrs = {
        "nodeId": NonNull(ID),
        "clientMutationId": String,
        input_object_name: NonNull(patch_type_factory(sqla_model)),
    }
    # This is the *update* mutation's input; the description must say so.
    return UpdateInputType(
        result_name,
        attrs,
        description=f"All input for the update {relevant_type_name} mutation.",
    )
Exemple #5
0
def input_type_factory(sqla_model: TableProtocol) -> TableInputType:
    """Build the ``<Type>Input`` input type, e.g. ``AccountInput``.

    Includes one input field per column not excluded by ``Config.exclude_create``.
    """
    type_name = Config.table_type_name_mapper(sqla_model)

    input_fields = {}
    for col in get_columns(sqla_model):
        if Config.exclude_create(col):
            continue
        key = Config.column_name_mapper(col)
        input_fields[key] = convert_column_to_input(col)

    return TableInputType(
        f"{type_name}Input",
        input_fields,
        description=f"An input for mutations affecting {type_name}.",
    )
Exemple #6
0
def mutable_function_entrypoint_factory(
        sql_function: SQLFunction,
        resolver: typing.Callable,
        jwt_secret: typing.Optional[str] = None) -> typing.Dict[str, Field]:
    """Build the mutation entrypoint for a non-immutable SQL function, e.g. ``authenticate``.

    Parameters
    ----------
    sql_function
        The SQL function to expose; must not be immutable.
    resolver
        Resolver attached to the generated field.
    jwt_secret
        When given, the payload is built by ``jwt_function_payload_factory``
        with this secret; otherwise by ``function_payload_factory``.

    Returns
    -------
    dict
        A single-entry mapping of the GraphQL field name to its ``Field``.

    Raises
    ------
    Exception
        If ``sql_function`` is immutable.
    """
    # TODO(OR): need separate mapper
    field_name = Config.function_name_mapper(sql_function)

    if sql_function.is_immutable:
        raise Exception(
            f"SQLFunction {sql_function.name} is immutable, use immutable_function_entrypoint"
        )

    input_type = function_input_type_factory(sql_function)
    payload = (
        function_payload_factory(sql_function)
        if jwt_secret is None
        else jwt_function_payload_factory(sql_function, jwt_secret)
    )

    entry = Field(
        payload,
        args={"input": NonNull(input_type)},
        resolve=resolver,
        description=f"Call the function {field_name}.",
    )
    return {field_name: entry}
Exemple #7
0
def immutable_function_entrypoint_factory(
        sql_function: SQLFunction,
        resolver: typing.Callable) -> typing.Dict[str, Field]:
    """Build the query entrypoint for an immutable SQL function.

    Each SQL argument becomes a required GraphQL argument; unnamed arguments
    are called ``param<index>``.

    Returns
    -------
    dict
        A single-entry mapping of the GraphQL field name to its ``Field``.

    Raises
    ------
    Exception
        If ``sql_function`` is not immutable.
    """
    # TODO(OR): need separate mapper
    field_name = Config.function_name_mapper(sql_function)

    if not sql_function.is_immutable:
        raise Exception(
            f"SQLFunction {sql_function.name} is not immutable, use mutable_function_entrypoint"
        )

    gql_args = {}
    named_types = zip(sql_function.arg_names, sql_function.arg_sqla_types)
    for position, (arg_name, arg_type) in enumerate(named_types):
        key = arg_name if arg_name else f"param{position}"
        gql_args[key] = Argument(NonNull(convert_type(arg_type)))

    # NOTE(review): if convert_type hands back a shared type instance, this
    # attribute is overwritten by every function with the same return type.
    output_type = convert_type(sql_function.return_sqla_type)
    output_type.sql_function = sql_function

    return {
        field_name: Field(output_type, args=gql_args, resolve=resolver, description="")
    }
Exemple #8
0
def jwt_function_payload_factory(sql_function: SQLFunction,
                                 jwt_secret: str) -> FunctionPayloadType:
    """Build the ``<Function>Payload`` type whose ``result`` is serialized as a JWT.

    The function's result mapping is serialized by encoding it as an HS256 JWT
    signed with ``jwt_secret``.

    Parameters
    ----------
    sql_function
        The SQL function whose payload is being built.
    jwt_secret
        Key used to sign the token.

    Returns
    -------
    FunctionPayloadType
        Payload type with ``clientMutationId`` and ``result`` fields and
        ``sql_function`` attached.
    """
    function_name = Config.function_type_name_mapper(sql_function)
    result_name = f"{function_name}Payload"

    def serialize_jwt(result) -> str:
        """Encode the result mapping as an HS256 JWT and return it as text."""
        token = jwt.encode(dict(result.items()), jwt_secret, algorithm="HS256")
        # PyJWT < 2.0 returns bytes here while PyJWT >= 2.0 returns str, so an
        # unconditional ``.decode`` raises AttributeError on 2.x.
        if isinstance(token, bytes):
            token = token.decode("utf-8")
        return token

    function_return_type = ScalarType("JWT", serialize=serialize_jwt)

    attrs = {
        "clientMutationId": Field(String, resolve=default_resolver),
        "result": Field(
            function_return_type,
            description=f"The {result_name} that was created by this mutation.",
            resolve=default_resolver,
        ),
    }

    payload = FunctionPayloadType(
        result_name,
        attrs,
        description=f"The output of our create {function_name} mutation")
    payload.sql_function = sql_function
    return payload
Exemple #9
0
def function_payload_factory(sql_function: SQLFunction) -> FunctionPayloadType:
    """Build the ``<Function>Payload`` output type for a SQL function.

    The ``result`` field uses the GraphQL type converted from the function's SQL
    return type; ``sql_function`` is attached to both that type and the payload.
    """
    type_name = Config.function_type_name_mapper(sql_function)
    payload_name = f"{type_name}Payload"

    # TODO(OR): handle functions with no return
    # NOTE(review): if convert_type returns a shared type instance, this
    # assignment is visible to every other user of that type — confirm.
    result_type = convert_type(sql_function.return_sqla_type)
    result_type.sql_function = sql_function

    fields = {
        "clientMutationId": Field(String, resolve=default_resolver),
        "result": Field(
            result_type,
            description=f"The {payload_name} that was created by this mutation.",
            resolve=default_resolver,
        ),
    }

    payload = FunctionPayloadType(
        payload_name,
        fields,
        description=f"The output of our create {type_name} mutation",
    )
    payload.sql_function = sql_function
    return payload
Exemple #10
0
def field_name_to_relationship(sqla_model: TableProtocol,
                               gql_field_name: str) -> RelationshipProperty:
    """Return the relationship of ``sqla_model`` exposed as ``gql_field_name``.

    Parameters
    ----------
    sqla_model
        The model whose relationships are searched.
    gql_field_name
        The GraphQL field name, compared against ``Config.relationship_name_mapper``.

    Raises
    ------
    KeyError
        If no relationship maps to that field name. This matches
        ``field_name_to_column``; ``KeyError`` is a subclass of ``Exception``, so
        callers catching the broader type keep working.
    """
    for relationship in get_relationships(sqla_model):
        if Config.relationship_name_mapper(relationship) == gql_field_name:
            return relationship
    raise KeyError(f"No relationship corresponding to field {gql_field_name}")
Exemple #11
0
def condition_factory(sqla_model: TableProtocol) -> InputObjectType:
    """Build the ``<table>Condition`` input type used as a connection's ``condition`` argument.

    Every column of ``sqla_model`` gets an input field keyed by
    ``Config.column_name_mapper``.
    """
    # NOTE(review): unlike the other factories this names the type with
    # ``table_name_mapper`` instead of ``table_type_name_mapper``; left as-is
    # because renaming would change the public schema.
    type_name = f"{Config.table_name_mapper(sqla_model)}Condition"

    condition_fields = {}
    for col in get_columns(sqla_model):
        key = Config.column_name_mapper(col)
        condition_fields[key] = convert_column_to_input(col)

    return InputObjectType(type_name, condition_fields, description="")
Exemple #12
0
def table_field_factory(sqla_model: TableProtocol, resolver) -> Field:
    """Build a field that reads one row of ``sqla_model`` by its required ``nodeId``."""
    type_name = Config.table_type_name_mapper(sqla_model)
    node_type = table_factory(sqla_model)
    lookup_args = {"nodeId": Argument(NonNull(ID))}
    return Field(
        node_type,
        args=lookup_args,
        resolve=resolver,
        description=f"Reads a single {type_name} using its globally unique ID",
    )
Exemple #13
0
def edge_factory(sqla_model: TableProtocol) -> EdgeType:
    """Build the ``<Type>Edge`` type with ``cursor`` and ``node`` fields.

    The field map is passed as a callable, so ``table_factory`` runs only when
    that callable is invoked.
    """
    from .table import table_factory

    edge_name = Config.table_type_name_mapper(sqla_model) + "Edge"

    def edge_fields():
        node_type = table_factory(sqla_model)
        return {"cursor": Field(Cursor), "node": Field(node_type)}

    return EdgeType(name=edge_name, fields=edge_fields, description="", sqla_model=sqla_model)
Exemple #14
0
def table_factory(sqla_model: TableProtocol) -> TableType:
    """
    Reflect a SQLAlchemy ORM model into a graphql-core object type.

    Parameters
    ----------
    sqla_model
        A SQLAlchemy ORM Table

    Returns
    -------
    TableType
        An object type implementing ``NodeInterface``. Its fields are supplied
        as a callable, presumably so the recursive ``table_factory`` calls for
        related models are deferred until fields are needed.
    """
    from .connection import connection_field_factory

    type_name = Config.table_type_name_mapper(sqla_model)

    def build_attrs():
        # Relay-style global id comes first, in place of the raw primary key.
        attrs = {"nodeId": Field(NonNull(ID), resolve=default_resolver)}

        for column in get_columns(sqla_model):
            if Config.exclude_read(column):
                continue
            column_key = Config.column_name_mapper(column)
            attrs[column_key] = convert_column(column)

        for relationship in get_relationships(sqla_model):
            direction = relationship.direction
            related_model = relationship.mapper.class_
            nullable = is_nullable(relationship)
            attr_key = Config.relationship_name_mapper(relationship)

            if direction == interfaces.MANYTOONE:
                # Single counterpart: a plain (possibly non-null) object field.
                related_type = table_factory(related_model)
                if not nullable:
                    related_type = NonNull(related_type)
                attrs[attr_key] = Field(related_type, resolve=default_resolver)
            elif direction in (interfaces.ONETOMANY, interfaces.MANYTOMANY):
                # Many counterparts: a paginated connection field.
                # NOTE(review): ``not_null`` is fed the *nullable* flag, which reads as
                # inverted compared with the MANYTOONE branch — confirm intent.
                attrs[attr_key] = connection_field_factory(
                    related_model, resolver=default_resolver, not_null=nullable
                )

        return attrs

    return TableType(
        name=type_name,
        fields=build_attrs,
        interfaces=[NodeInterface],
        description="",
        sqla_model=sqla_model,
    )
Exemple #15
0
def create_entrypoint_factory(sqla_model: TableProtocol, resolver) -> dict:
    """Build the ``create<Type>`` mutation entrypoint, e.g. ``createAccount``.

    Parameters
    ----------
    sqla_model
        The SQLAlchemy model rows are created for.
    resolver
        Resolver attached to the generated field.

    Returns
    -------
    dict
        A single-entry mapping from the mutation name to its ``Field`` (not a
        bare ``Field``); it takes a required ``input`` argument and returns the
        create payload.
    """
    relevant_type_name = Config.table_type_name_mapper(sqla_model)
    name = f"create{relevant_type_name}"
    args = {"input": NonNull(create_input_type_factory(sqla_model))}
    payload = create_payload_factory(sqla_model)
    return {
        name: Field(
            payload,
            args=args,
            resolve=resolver,
            description=f"Creates a single {relevant_type_name}.",
        )
    }
Exemple #16
0
def connection_factory(sqla_model: TableProtocol) -> ConnectionType:
    """Build the ``<Type>Connection`` type for paginating rows of ``sqla_model``.

    Exposes non-null ``edges``, ``pageInfo`` and ``totalCount`` fields. The field
    map is passed as a callable, so ``edge_factory`` runs only when it is invoked.
    """
    connection_name = Config.table_type_name_mapper(sqla_model) + "Connection"

    def connection_fields():
        edge_type = edge_factory(sqla_model)
        edge_list = NonNull(List(NonNull(edge_type)))
        return {
            "edges": Field(edge_list, resolve=default_resolver),
            "pageInfo": Field(NonNull(PageInfo), resolve=default_resolver),
            "totalCount": Field(NonNull(Int), resolve=default_resolver),
        }

    return ConnectionType(
        name=connection_name,
        fields=connection_fields,
        description="",
        sqla_model=sqla_model,
    )
Exemple #17
0
def update_entrypoint_factory(sqla_model: TableProtocol, resolver) -> t.Dict[str, Field]:
    """Build the ``update<Type>`` mutation entrypoint, e.g. ``updateAccount``.

    Returns a single-entry mapping from the mutation name to a field that takes
    a required ``input`` argument and returns the update payload.
    """
    type_name = Config.table_type_name_mapper(sqla_model)
    input_type = NonNull(update_input_type_factory(sqla_model))
    payload_type = update_payload_factory(sqla_model)

    mutation_field = Field(
        payload_type,
        args={"input": input_type},
        resolve=resolver,
        description=f"Updates a single {type_name} using its globally unique id and a patch.",
    )
    return {f"update{type_name}": mutation_field}
Exemple #18
0
def delete_payload_factory(sqla_model: TableProtocol) -> "DeletePayloadType":
    """Build the ``Delete<Type>Payload`` output type, e.g. ``DeleteAccountPayload``.

    Exposes ``clientMutationId`` and ``nodeId``, with ``sqla_model`` attached
    to the resulting type.

    Returns
    -------
    DeletePayloadType
        The payload (output) type actually returned. The annotation is quoted so
        the name need not be bound when the ``def`` executes.
    """
    relevant_type_name = Config.table_type_name_mapper(sqla_model)
    result_name = f"Delete{relevant_type_name}Payload"

    attrs = {
        "clientMutationId": Field(String, resolve=default_resolver),
        "nodeId": ID,
    }

    return DeletePayloadType(
        result_name,
        attrs,
        description=f"The output of our delete {relevant_type_name} mutation",
        sqla_model=sqla_model,
    )
Exemple #19
0
def connection_field_factory(sqla_model: TableProtocol, resolver, not_null=False) -> Field:
    """Build a paginated connection field over ``sqla_model``.

    Arguments: ``first``/``last`` (Int), ``before``/``after`` (Cursor) and a
    ``condition`` filter. When ``not_null`` is true the connection type is
    wrapped in ``NonNull``.
    """
    type_name = Config.table_type_name_mapper(sqla_model)
    connection_type = connection_factory(sqla_model)
    condition_type = condition_factory(sqla_model)

    field_args = dict(
        first=Argument(Int, description="", out_name=None),
        last=Argument(Int),
        before=Argument(Cursor),
        after=Argument(Cursor),
        condition=Argument(condition_type),
    )

    if not_null:
        field_type = NonNull(connection_type)
    else:
        field_type = connection_type

    return Field(
        field_type,
        args=field_args,
        resolve=resolver,
        description=f"Reads and enables pagination through a set of {type_name}",
    )
Exemple #20
0
def function_input_type_factory(
        sql_function: SQLFunction) -> FunctionInputType:
    """Build the ``<Function>Input`` input type, e.g. ``AuthenticateInput``.

    Contains an optional ``clientMutationId`` followed by one required field per
    SQL argument; unnamed arguments are called ``param<index>``.
    """
    type_name = Config.function_type_name_mapper(sql_function)

    arg_fields = {}
    named_types = zip(sql_function.arg_names, sql_function.arg_sqla_types)
    for position, (arg_name, arg_type) in enumerate(named_types):
        field_key = arg_name if arg_name else f"param{position}"
        arg_fields[field_key] = NonNull(convert_input_type(arg_type))

    return FunctionInputType(
        f"{type_name}Input",
        {"clientMutationId": String, **arg_fields},
        description=f"All input for the {type_name} mutation.",
    )
Exemple #21
0
def field_name_to_column(sqla_model: TableProtocol,
                         gql_field_name: str) -> Column:
    """Return the column of ``sqla_model`` exposed as ``gql_field_name``.

    Raises
    ------
    KeyError
        If no column maps to that GraphQL field name.
    """
    for candidate in get_columns(sqla_model):
        if Config.column_name_mapper(candidate) != gql_field_name:
            continue
        return candidate
    raise KeyError(f"No column corresponding to field {gql_field_name}")
Exemple #22
0
def sqla_models_to_graphql_schema(
    sqla_models,
    sql_functions: typing.Optional[typing.List[SQLFunction]] = None,
    jwt_identifier: typing.Optional[str] = None,
    jwt_secret: typing.Optional[str] = None,
) -> Schema:
    """Creates a GraphQL Schema from SQLA Models

    **Parameters**

    * **sqla_models**: _List[Type[SQLAModel]]_ = SQLAlchemy models to expose in the GraphQL schema
    * **jwt_identifier**: _str_ = qualified path of the SQL composite type to encode as a JWT, e.g. 'public.jwt'
    * **jwt_secret**: _str_ = Secret key used to sign JWT contents
    * **sql_functions** = **NOT PUBLIC API**

    Root types (``Query`` / ``Mutation``) that end up with no fields are left
    out of the schema.
    """
    functions = [] if sql_functions is None else sql_functions

    queries = {}
    mutations = {}

    # Tables
    for model in sqla_models:
        if not Config.exclude_read_one(model):
            # e.g. account(nodeId: NodeID)
            single_name = snake_to_camel(get_table_name(model), upper=False)
            queries[single_name] = table_field_factory(model, resolver)

        if not Config.exclude_read_all(model):
            # e.g. allAccounts(first: Int, last: Int ....)
            plural_name = "all" + snake_to_camel(to_plural(get_table_name(model)), upper=True)
            queries[plural_name] = connection_field_factory(model, resolver)

        if not Config.exclude_create(model):
            # e.g. createAccount(input: CreateAccountInput)
            mutations.update(create_entrypoint_factory(model, resolver=resolver))

        if not Config.exclude_update(model):
            # e.g. updateAccount(input: UpdateAccountInput)
            mutations.update(update_entrypoint_factory(model, resolver=resolver))

        if not Config.exclude_delete(model):
            # e.g. deleteAccount(input: DeleteAccountInput)
            mutations.update(delete_entrypoint_factory(model, resolver=resolver))

    # Functions: the JWT function and mutable functions become mutations,
    # immutable functions become queries.
    for sql_function in functions:
        if is_jwt_function(sql_function, jwt_identifier):
            mutations.update(
                mutable_function_entrypoint_factory(
                    sql_function=sql_function, resolver=resolver, jwt_secret=jwt_secret
                )
            )
        elif sql_function.is_immutable:
            queries.update(
                immutable_function_entrypoint_factory(sql_function=sql_function, resolver=resolver)
            )
        else:
            mutations.update(
                mutable_function_entrypoint_factory(sql_function=sql_function, resolver=resolver)
            )

    root_types = {
        "query": ObjectType(name="Query", fields=queries),
        "mutation": ObjectType(name="Mutation", fields=mutations),
    }
    populated = {kind: root for kind, root in root_types.items() if root.fields}
    return Schema(**populated)
Exemple #23
0
def enum_factory(sqla_enum: typing.Type[postgresql.base.ENUM]) -> EnumType:
    """Build a GraphQL enum whose values mirror the labels in ``sqla_enum.enums``."""
    enum_name = Config.enum_name_mapper(sqla_enum)
    label_map = {label: label for label in sqla_enum.enums}
    return EnumType(name=enum_name, values=label_map)
Exemple #24
0
async def async_resolver(_, info: ResolveInfo, **kwargs) -> typing.Any:
    """Awaitable GraphQL Entrypoint resolver

    Parses the resolve info into a query tree and then, inside one transaction
    opened with ``engine.begin()``:

    * applies JWT claims / the default role when either is truthy;
    * for a ``FunctionPayloadType`` return type, calls the SQL function (and,
      for table-returning functions, re-queries the row by its nodeId);
    * for a ``MutationPayloadType`` return type, runs the mutation and
      re-queries the affected row if the client selected it;
    * for ``ObjectType`` / ``ScalarType`` return types, builds and runs a read
      query (unwrapping the top-level name for scalars).

    The result is also stored on ``info.context["result"]``. ``kwargs`` is
    accepted but not read here; arguments come from the parsed tree.

    Expects:
        info.context['engine'] to contain an sqlalchemy.ext.asyncio.AsyncEngine
        info.context['default_role'] and info.context['jwt_claims'] to be present
        (they are read with ``[]``, so missing keys raise KeyError)

    Raises:
        Exception: if the return type is none of the handled kinds.
    """
    context = info.context
    engine = context["engine"]
    default_role = context["default_role"]
    jwt_claims = context["jwt_claims"]

    tree = parse_resolve_info(info)

    async with engine.begin() as trans:
        # Set claims for transaction
        if jwt_claims or default_role:
            claims_stmt = build_claims(jwt_claims, default_role)
            await trans.execute(claims_stmt)

        # Not always a dict: the scalar branch below stores an unwrapped value.
        result: typing.Any

        if isinstance(tree.return_type, FunctionPayloadType):
            sql_function = tree.return_type.sql_function
            # Positional function args = every input value except clientMutationId,
            # in the input mapping's iteration order.
            function_args = [
                val for key, val in tree.args["input"].items()
                if key != "clientMutationId"
            ]
            func_call = sql_function.to_executable(function_args)

            # Function returning table row
            if isinstance(sql_function.return_sqla_type, TableProtocol):
                # Unpack the table row to columns
                return_sqla_model = sql_function.return_sqla_type
                core_table = return_sqla_model.__table__
                func_alias = func_call.alias("named_alias")
                stmt = select([
                    literal_column(c.name).label(c.name) for c in core_table.c
                ]).select_from(func_alias)  # type: ignore
                stmt_alias = stmt.alias()
                node_id_stmt = select([
                    to_node_id_sql(return_sqla_model,
                                   stmt_alias).label("nodeId")
                ]).select_from(stmt_alias)  # type: ignore
                # Unpacking requires exactly one row with exactly one column;
                # otherwise this raises ValueError.
                ((row, ), ) = await trans.execute(node_id_stmt)
                node_id = NodeIdStructure.from_dict(row)

                # Add nodeId to AST and query
                query_tree = next(
                    iter([x for x in tree.fields if x.name == "result"]), None)
                if query_tree is not None:
                    # Mutates the parsed tree so the follow-up query is built
                    # with this nodeId argument.
                    query_tree.args["nodeId"] = node_id
                    base_query = sql_builder(query_tree)
                    query = sql_finalize(query_tree.alias, base_query)
                    ((stmt_result, ), ) = await trans.execute(query)
                else:
                    # Client did not select ``result``; nothing to fetch.
                    stmt_result = {}
            else:
                stmt = select([func_call.label("result")])
                # NOTE(review): ``stmt_result`` is a row here and is later
                # splatted with ``**``; that assumes the row type supports the
                # mapping protocol in the SQLAlchemy version in use — confirm.
                (stmt_result, ) = await trans.execute(stmt)

            maybe_mutation_id = tree.args["input"].get("clientMutationId")
            # Echo clientMutationId under whatever alias the client requested.
            mutation_id_alias = next(
                iter([
                    x.alias for x in tree.fields
                    if x.name == "clientMutationId"
                ]),
                "clientMutationId",
            )
            result = {
                tree.alias: {
                    **stmt_result,
                    **{
                        mutation_id_alias: maybe_mutation_id
                    }
                }
            }

        elif isinstance(tree.return_type, MutationPayloadType):
            stmt = build_mutation(tree)
            # Exactly one row with one column is expected from the mutation.
            ((row, ), ) = await trans.execute(stmt)
            node_id = NodeIdStructure.from_dict(row)

            maybe_mutation_id = tree.args["input"].get("clientMutationId")
            mutation_id_alias = next(
                iter([
                    x.alias for x in tree.fields
                    if x.name == "clientMutationId"
                ]),
                "clientMutationId",
            )
            node_id_alias = next(
                iter([x.alias for x in tree.fields if x.name == "nodeId"]),
                "nodeId")
            # The payload field holding the row is named via table_name_mapper.
            output_row_name: str = Config.table_name_mapper(
                tree.return_type.sqla_model)
            query_tree = next(
                iter([x for x in tree.fields if x.name == output_row_name]),
                None)
            sql_result = {}
            if query_tree:
                # Set the nodeid of the newly created record as an arg
                query_tree.args["nodeId"] = node_id
                base_query = sql_builder(query_tree)
                query = sql_finalize(query_tree.alias, base_query)
                ((sql_result, ), ) = await trans.execute(query)
            # clientMutationId appears both inside the aliased payload and at the
            # top level (alongside nodeId); presumably so default resolvers can
            # find it at either level — confirm before simplifying.
            result = {
                tree.alias: {
                    **sql_result, mutation_id_alias: maybe_mutation_id
                },
                mutation_id_alias: maybe_mutation_id,
                node_id_alias: node_id,
            }

        elif isinstance(tree.return_type, (ObjectType, ScalarType)):
            base_query = sql_builder(tree)
            query = sql_finalize(tree.name, base_query)
            ((query_json_result, ), ) = await trans.execute(query)

            if isinstance(tree.return_type, ScalarType):
                # If its a scalar, unwrap the top level name
                result = flu(query_json_result.values()).first(None)
            else:
                result = query_json_result

        else:
            raise Exception("sql builder could not handle return type")

    # Stash result on context to enable dumb resolvers to not fail
    context["result"] = result
    return result