Esempio n. 1
0
def update_payload_factory(sqla_model: TableProtocol) -> InputObjectType:
    """Build the GraphQL payload type returned by a table's update mutation.

    The generated type is named ``Update<TypeName>Payload`` (for example
    ``UpdateAccountPayload``) and exposes three fields:

    * ``clientMutationId`` - string field resolved with ``default_resolver``.
    * ``nodeId`` - the ``ID`` of the affected row.
    * ``<table name>`` - the row itself, typed via ``table_factory``.

    Args:
        sqla_model: The table model the payload is generated for.

    Returns:
        An ``UpdatePayloadType`` bound to ``sqla_model``.
    """
    # NOTE(review): the return annotation says InputObjectType but the function
    # returns an UpdatePayloadType; confirm the class hierarchy and adjust the
    # annotation if UpdatePayloadType is not an InputObjectType.

    # Imported inside the function rather than at module level, as in the
    # original - presumably to avoid a circular import; verify before moving.
    from nebulo.gql.convert.table import table_factory

    relevant_type_name = Config.table_type_name_mapper(sqla_model)
    relevant_attr_name = Config.table_name_mapper(sqla_model)
    result_name = f"Update{relevant_type_name}Payload"

    attrs = {
        "clientMutationId": Field(String, resolve=default_resolver),
        "nodeId": ID,
        relevant_attr_name: Field(
            table_factory(sqla_model),
            resolve=default_resolver,
            # This is the *update* payload: the row was updated, not created.
            description=f"The {relevant_type_name} that was updated by this mutation.",
        ),
    }

    return UpdatePayloadType(
        result_name,
        attrs,
        description=f"The output of our update {relevant_type_name} mutation",
        sqla_model=sqla_model,
    )
Esempio n. 2
0
def update_input_type_factory(sqla_model: TableProtocol) -> InputObjectType:
    """Build the GraphQL input type accepted by a table's update mutation.

    The generated type is named ``Update<TypeName>Input`` (for example
    ``UpdateAccountInput``) and exposes three fields:

    * ``nodeId`` - required ``ID`` of the row to update.
    * ``clientMutationId`` - optional string.
    * ``<table name>`` - required patch object built by ``patch_type_factory``.

    Args:
        sqla_model: The table model the input type is generated for.

    Returns:
        An ``UpdateInputType`` describing the mutation's ``input`` argument.
    """
    relevant_type_name = Config.table_type_name_mapper(sqla_model)
    result_name = f"Update{relevant_type_name}Input"

    input_object_name = Config.table_name_mapper(sqla_model)

    attrs = {
        "nodeId": NonNull(ID),
        "clientMutationId": String,
        input_object_name: NonNull(patch_type_factory(sqla_model)),
    }
    return UpdateInputType(
        result_name,
        attrs,
        # This is the *update* input type; the text previously said "create".
        description=f"All input for the update {relevant_type_name} mutation.",
    )
Esempio n. 3
0
async def async_resolver(_, info: ResolveInfo, **kwargs) -> typing.Any:
    """Resolve a top-level GraphQL field by running SQL inside one transaction.

    The request context (``info.context``) must provide the keys ``engine``,
    ``default_role`` and ``jwt_claims``. ``engine`` is expected to be an
    ``sqlalchemy.ext.asyncio.AsyncEngine``; it is used through
    ``async with engine.begin()`` and awaited ``execute`` calls.

    The parsed request is dispatched on its return type:

    * ``FunctionPayloadType`` - call the SQL function, then (for functions
      returning a table row) re-query the row through the regular builder.
    * ``MutationPayloadType`` - run the mutation, then re-query the affected
      row if the client selected it.
    * ``ObjectType`` / ``ScalarType`` - build and run a plain query.

    The computed result is also stored under ``info.context["result"]``.

    Raises:
        Exception: if the return type is none of the above.
    """
    ctx = info.context
    engine = ctx["engine"]
    default_role = ctx["default_role"]
    jwt_claims = ctx["jwt_claims"]

    ast = parse_resolve_info(info)

    def selected_field(field_name):
        # First selected child field with the given name, or None.
        return next((f for f in ast.fields if f.name == field_name), None)

    def alias_for(field_name):
        # Alias the client gave the named child field; the name itself if the
        # field was not selected.
        return next((f.alias for f in ast.fields if f.name == field_name), field_name)

    async with engine.begin() as conn:
        # Apply JWT claims / role to this transaction before any other SQL
        if jwt_claims or default_role:
            await conn.execute(build_claims(jwt_claims, default_role))

        return_type = ast.return_type

        if isinstance(return_type, FunctionPayloadType):
            sql_function = return_type.sql_function
            # Every input value except the client mutation id is a function argument
            call_args = [
                value
                for name, value in ast.args["input"].items()
                if name != "clientMutationId"
            ]
            func_call = sql_function.to_executable(call_args)

            if isinstance(sql_function.return_sqla_type, TableProtocol):
                # The function returns a table row: expand it into its columns
                row_model = sql_function.return_sqla_type
                columns = [
                    literal_column(col.name).label(col.name)
                    for col in row_model.__table__.c
                ]
                row_stmt = select(columns).select_from(func_call.alias("named_alias"))  # type: ignore
                row_alias = row_stmt.alias()
                node_id_column = to_node_id_sql(row_model, row_alias).label("nodeId")
                node_id_stmt = select([node_id_column]).select_from(row_alias)  # type: ignore
                node_id_rows = await conn.execute(node_id_stmt)
                ((raw_node_id, ), ) = node_id_rows
                node_id = NodeIdStructure.from_dict(raw_node_id)

                # Inject the nodeId into the "result" selection and query it
                result_field = selected_field("result")
                if result_field is None:
                    function_output = {}
                else:
                    result_field.args["nodeId"] = node_id
                    finalized = sql_finalize(result_field.alias, sql_builder(result_field))
                    ((function_output, ), ) = await conn.execute(finalized)
            else:
                scalar_stmt = select([func_call.label("result")])
                (function_output, ) = await conn.execute(scalar_stmt)

            client_mutation_id = ast.args["input"].get("clientMutationId")
            client_mutation_key = alias_for("clientMutationId")
            payload = {
                ast.alias: {**function_output, client_mutation_key: client_mutation_id}
            }

        elif isinstance(return_type, MutationPayloadType):
            mutation_rows = await conn.execute(build_mutation(ast))
            ((raw_node_id, ), ) = mutation_rows
            node_id = NodeIdStructure.from_dict(raw_node_id)

            client_mutation_id = ast.args["input"].get("clientMutationId")
            client_mutation_key = alias_for("clientMutationId")
            node_id_key = alias_for("nodeId")
            row_field_name: str = Config.table_name_mapper(return_type.sqla_model)
            row_field = selected_field(row_field_name)

            row_output = {}
            if row_field:
                # Pass the nodeId of the affected record as an argument
                row_field.args["nodeId"] = node_id
                finalized = sql_finalize(row_field.alias, sql_builder(row_field))
                ((row_output, ), ) = await conn.execute(finalized)

            payload = {
                ast.alias: {**row_output, client_mutation_key: client_mutation_id},
                client_mutation_key: client_mutation_id,
                node_id_key: node_id,
            }

        elif isinstance(return_type, (ObjectType, ScalarType)):
            finalized = sql_finalize(ast.name, sql_builder(ast))
            ((json_output, ), ) = await conn.execute(finalized)

            payload = json_output
            if isinstance(return_type, ScalarType):
                # Scalars come back wrapped under a top-level name; unwrap it
                payload = flu(json_output.values()).first(None)

        else:
            raise Exception("sql builder could not handle return type")

    # Keep the result on the context so the default child resolvers can find it
    ctx["result"] = payload
    return payload