def update_payload_factory(sqla_model: TableProtocol) -> InputObjectType:
    """Build the ``Update<Type>Payload`` output type, e.g. ``UpdateAccountPayload``.

    The payload exposes ``clientMutationId``, ``nodeId`` and one field, named
    by ``Config.table_name_mapper(sqla_model)``, that resolves to the table's
    object type (built by ``table_factory``).

    Parameters
    ----------
    sqla_model
        The SQLAlchemy model targeted by the update mutation.

    Returns
    -------
    UpdatePayloadType
        Named ``Update<Type>Payload`` and carrying ``sqla_model``.
    """
    # Local import; presumably avoids a circular import with the table
    # converter module — confirm before hoisting to module level.
    from nebulo.gql.convert.table import table_factory

    relevant_type_name = Config.table_type_name_mapper(sqla_model)
    relevant_attr_name = Config.table_name_mapper(sqla_model)
    result_name = f"Update{relevant_type_name}Payload"

    attrs = {
        "clientMutationId": Field(String, resolve=default_resolver),
        "nodeId": ID,
        relevant_attr_name: Field(
            table_factory(sqla_model),
            resolve=default_resolver,
            # This payload belongs to the *update* mutation; the previous text
            # ("created") was copied from the create payload.
            description=f"The {relevant_type_name} that was updated by this mutation.",
        ),
    }

    return UpdatePayloadType(
        result_name,
        attrs,
        description=f"The output of our update {relevant_type_name} mutation",
        sqla_model=sqla_model,
    )
def build_attrs(): attrs = {} # Override id to relay standard attrs["nodeId"] = Field(NonNull(ID), resolve=default_resolver) for column in get_columns(sqla_model): if not Config.exclude_read(column): key = Config.column_name_mapper(column) attrs[key] = convert_column(column) for relationship in get_relationships(sqla_model): direction = relationship.direction to_sqla_model = relationship.mapper.class_ relationship_is_nullable = is_nullable(relationship) # Name of the attribute on the model attr_key = Config.relationship_name_mapper(relationship) # If this model has 1 counterpart, do not use a list if direction == interfaces.MANYTOONE: _type = table_factory(to_sqla_model) _type = NonNull(_type) if not relationship_is_nullable else _type attrs[attr_key] = Field(_type, resolve=default_resolver) # Otherwise, set it up as a connection elif direction in (interfaces.ONETOMANY, interfaces.MANYTOMANY): connection_field = connection_field_factory( to_sqla_model, resolver=default_resolver, not_null=relationship_is_nullable ) attrs[attr_key] = connection_field return attrs
def patch_type_factory(sqla_model: TableProtocol) -> InputObjectType:
    """Build the ``<Type>Patch`` input object, e.g. ``AccountPatch``.

    There is one input field per column that ``Config.exclude_update`` does
    not exclude. Each field is keyed by ``Config.column_name_mapper`` and
    typed by ``convert_column_to_input``.
    """
    type_name = Config.table_type_name_mapper(sqla_model)

    # TODO Unwrap not null here
    patch_fields = {
        Config.column_name_mapper(col): convert_column_to_input(col)
        for col in get_columns(sqla_model)
        if not Config.exclude_update(col)
    }

    return TableInputType(
        f"{type_name}Patch",
        patch_fields,
        description=f"An input for mutations affecting {type_name}.",
    )
def update_input_type_factory(sqla_model: TableProtocol) -> InputObjectType:
    """Build the ``Update<Type>Input`` input object, e.g. ``UpdateAccountInput``.

    Fields:

    * ``nodeId`` (required ID): the record to update.
    * ``clientMutationId`` (optional String).
    * ``<table name>`` (required): the patch built by ``patch_type_factory``.

    Callers wrap the result in ``NonNull`` when using it as an argument.
    """
    relevant_type_name = Config.table_type_name_mapper(sqla_model)
    result_name = f"Update{relevant_type_name}Input"
    input_object_name = Config.table_name_mapper(sqla_model)

    attrs = {
        "nodeId": NonNull(ID),
        "clientMutationId": String,
        input_object_name: NonNull(patch_type_factory(sqla_model)),
    }

    # This input belongs to the *update* mutation; the previous description
    # ("create") was copied from the create input factory.
    return UpdateInputType(
        result_name,
        attrs,
        description=f"All input for the update {relevant_type_name} mutation.",
    )
def input_type_factory(sqla_model: TableProtocol) -> TableInputType:
    """Build the ``<Type>Input`` input object used by create mutations, e.g. ``AccountInput``.

    Columns excluded by ``Config.exclude_create`` are skipped. Every other
    column becomes a field keyed by ``Config.column_name_mapper``.
    """
    type_name = Config.table_type_name_mapper(sqla_model)

    input_fields = {}
    for col in get_columns(sqla_model):
        if Config.exclude_create(col):
            continue
        input_fields[Config.column_name_mapper(col)] = convert_column_to_input(col)

    return TableInputType(
        f"{type_name}Input",
        input_fields,
        description=f"An input for mutations affecting {type_name}.",
    )
def mutable_function_entrypoint_factory(
        sql_function: SQLFunction,
        resolver: typing.Callable,
        jwt_secret: typing.Optional[str] = None) -> typing.Dict[str, Field]:
    """Expose a mutable SQL function as a mutation field, e.g. ``authenticate``.

    The field takes a single required ``input`` argument built by
    ``function_input_type_factory``. Its payload comes from
    ``jwt_function_payload_factory`` when ``jwt_secret`` is given, and from
    ``function_payload_factory`` otherwise.

    Returns a one-entry dict: ``{field name: Field}``.

    Raises ``Exception`` when ``sql_function.is_immutable`` is true.
    """
    # TODO(OR): need separate mapper
    field_name = Config.function_name_mapper(sql_function)

    if sql_function.is_immutable:
        raise Exception(
            f"SQLFunction {sql_function.name} is immutable, use immutable_function_entrypoint"
        )

    arguments = {"input": NonNull(function_input_type_factory(sql_function))}
    payload_type = (
        function_payload_factory(sql_function)
        if jwt_secret is None
        else jwt_function_payload_factory(sql_function, jwt_secret)
    )

    entry = Field(
        payload_type,
        args=arguments,
        resolve=resolver,
        description=f"Call the function {field_name}.",
    )
    return {field_name: entry}
def immutable_function_entrypoint_factory(
        sql_function: SQLFunction,
        resolver: typing.Callable) -> typing.Dict[str, Field]:
    """Expose an immutable SQL function as a query field.

    Each SQL argument becomes a required GraphQL argument. Its name is the SQL
    argument name, or ``param<index>`` when the name is empty or missing.

    Side effect: sets ``sql_function`` as an attribute on the object returned
    by ``convert_type`` for the function's return type.

    Returns a one-entry dict: ``{field name: Field}``.

    Raises ``Exception`` when ``sql_function.is_immutable`` is false.
    """
    # TODO(OR): need separate mapper
    field_name = Config.function_name_mapper(sql_function)

    if not sql_function.is_immutable:
        raise Exception(
            f"SQLFunction {sql_function.name} is not immutable, use mutable_function_entrypoint"
        )

    gql_args = {}
    named_types = zip(sql_function.arg_names, sql_function.arg_sqla_types)
    for position, (arg_name, arg_sqla_type) in enumerate(named_types):
        gql_args[arg_name or f"param{position}"] = Argument(NonNull(convert_type(arg_sqla_type)))

    output_type = convert_type(sql_function.return_sqla_type)
    output_type.sql_function = sql_function

    return {field_name: Field(output_type, args=gql_args, resolve=resolver, description="")}
def jwt_function_payload_factory(sql_function: SQLFunction, jwt_secret: str) -> FunctionPayloadType:
    """Build the ``<Function>Payload`` type for a function whose result is issued as a JWT.

    For example, this gives ``AuthenticatePayload`` for an ``authenticate``
    function; the exact name comes from ``Config.function_type_name_mapper``.
    The ``result`` field uses a ``JWT`` scalar. That scalar serializes the
    function's result, which must provide ``.items()``, into an HS256 token
    signed with ``jwt_secret``.

    Parameters
    ----------
    sql_function
        The SQL function whose return value is encoded as the token.
    jwt_secret
        HMAC secret used to sign the token.

    Returns
    -------
    FunctionPayloadType
        Payload type with ``sql_function`` attached as an attribute.
    """
    function_name = Config.function_type_name_mapper(sql_function)
    result_name = f"{function_name}Payload"

    def serialize_jwt(result) -> str:
        token = jwt.encode(dict(result.items()), jwt_secret, algorithm="HS256")
        # PyJWT < 2.0 returns bytes; PyJWT >= 2.0 returns str and has no
        # .decode(). Decode only when needed so both major versions work.
        if isinstance(token, bytes):
            return token.decode("utf-8")
        return token

    function_return_type = ScalarType("JWT", serialize=serialize_jwt)

    attrs = {
        "clientMutationId": Field(String, resolve=default_resolver),
        "result": Field(
            function_return_type,
            description=f"The {result_name} that was created by this mutation.",
            resolve=default_resolver,
        ),
    }

    payload = FunctionPayloadType(
        result_name, attrs, description=f"The output of our create {function_name} mutation"
    )
    payload.sql_function = sql_function
    return payload
def function_payload_factory(sql_function: SQLFunction) -> FunctionPayloadType:
    """Build the ``<Function>Payload`` output type for a mutable SQL function.

    The ``result`` field is typed by ``convert_type`` applied to the function's
    return type. Both that type object and the returned payload get
    ``sql_function`` attached as an attribute.
    """
    type_name = Config.function_type_name_mapper(sql_function)
    payload_name = f"{type_name}Payload"

    # TODO(OR): handle functions with no return
    result_type = convert_type(sql_function.return_sqla_type)
    result_type.sql_function = sql_function

    mutation_id_field = Field(String, resolve=default_resolver)
    result_field = Field(
        result_type,
        description=f"The {payload_name} that was created by this mutation.",
        resolve=default_resolver,
    )

    payload = FunctionPayloadType(
        payload_name,
        {"clientMutationId": mutation_id_field, "result": result_field},
        description=f"The output of our create {type_name} mutation",
    )
    payload.sql_function = sql_function
    return payload
def field_name_to_relationship(sqla_model: TableProtocol, gql_field_name: str) -> RelationshipProperty:
    """Return the first relationship whose GraphQL name equals ``gql_field_name``.

    The GraphQL name is computed with ``Config.relationship_name_mapper``.
    Raises ``Exception`` when no relationship matches.
    """
    not_found = object()
    match = next(
        (
            rel
            for rel in get_relationships(sqla_model)
            if Config.relationship_name_mapper(rel) == gql_field_name
        ),
        not_found,
    )
    if match is not_found:
        raise Exception(f"No relationship corresponding to field {gql_field_name}")
    return match
def condition_factory(sqla_model: TableProtocol) -> InputObjectType:
    """Build the ``<table name>Condition`` input object used to filter connections.

    Every column (with no exclusion check) becomes a field keyed by
    ``Config.column_name_mapper`` and typed by ``convert_column_to_input``.
    """
    type_name = f"{Config.table_name_mapper(sqla_model)}Condition"
    condition_fields = {
        Config.column_name_mapper(col): convert_column_to_input(col)
        for col in get_columns(sqla_model)
    }
    return InputObjectType(type_name, condition_fields, description="")
def table_field_factory(sqla_model: TableProtocol, resolver) -> Field:
    """Build the single-record query field for a table, e.g. ``account(nodeId: ID!)``.

    The field takes one required ``nodeId`` argument, returns the table's
    object type from ``table_factory``, and is resolved by ``resolver``.
    """
    type_name = Config.table_type_name_mapper(sqla_model)
    node_type = table_factory(sqla_model)
    lookup_args = {"nodeId": Argument(NonNull(ID))}
    return Field(
        node_type,
        args=lookup_args,
        resolve=resolver,
        description=f"Reads a single {type_name} using its globally unique ID",
    )
def edge_factory(sqla_model: TableProtocol) -> EdgeType:
    """Build the Relay ``<Type>Edge`` object type with ``cursor`` and ``node`` fields.

    The fields are supplied as a callable rather than a dict, so
    ``table_factory`` is only invoked when the fields are built.
    """
    # Local import; presumably avoids a circular import with the table module.
    from .table import table_factory

    def edge_fields():
        node_type = table_factory(sqla_model)
        return {"cursor": Field(Cursor), "node": Field(node_type)}

    edge_name = Config.table_type_name_mapper(sqla_model) + "Edge"
    return EdgeType(name=edge_name, fields=edge_fields, description="", sqla_model=sqla_model)
def table_factory(sqla_model: TableProtocol) -> TableType:
    """
    Reflects a SQLAlchemy table into a graphql-core GraphQLObjectType

    The ``build_attrs`` callable (not a dict) is passed as ``fields``. This
    presumably acts as a graphql-core thunk, deferring the recursive
    ``table_factory`` / ``connection_field_factory`` calls for related tables
    until the fields are resolved. Confirm this against TableType before
    making the field map eager.

    Parameters
    ----------
    sqla_model
        A SQLAlchemy ORM Table

    Returns
    -------
    TableType
        Named ``Config.table_type_name_mapper(sqla_model)``, implementing
        ``NodeInterface``, with ``sqla_model`` attached.
    """
    # Local import; presumably avoids a circular import between the table and
    # connection converters. Confirm before hoisting.
    from .connection import connection_field_factory

    name = Config.table_type_name_mapper(sqla_model)

    def build_attrs():
        # Field order: nodeId, then readable columns, then relationships.
        attrs = {}

        # Override id to relay standard
        attrs["nodeId"] = Field(NonNull(ID), resolve=default_resolver)

        for column in get_columns(sqla_model):
            if not Config.exclude_read(column):
                key = Config.column_name_mapper(column)
                attrs[key] = convert_column(column)

        for relationship in get_relationships(sqla_model):
            direction = relationship.direction
            to_sqla_model = relationship.mapper.class_
            relationship_is_nullable = is_nullable(relationship)

            # Name of the attribute on the model
            attr_key = Config.relationship_name_mapper(relationship)

            # If this model has 1 counterpart, do not use a list
            if direction == interfaces.MANYTOONE:
                _type = table_factory(to_sqla_model)
                _type = NonNull(_type) if not relationship_is_nullable else _type
                attrs[attr_key] = Field(_type, resolve=default_resolver)

            # Otherwise, set it up as a connection
            elif direction in (interfaces.ONETOMANY, interfaces.MANYTOMANY):
                # NOTE(review): `not_null` receives the nullability flag directly.
                # The MANYTOONE branch above applies NonNull only when the
                # relationship is *not* nullable, so this polarity looks inverted.
                # Confirm the intent before changing it, because it alters the schema.
                connection_field = connection_field_factory(
                    to_sqla_model, resolver=default_resolver, not_null=relationship_is_nullable
                )
                attrs[attr_key] = connection_field

        return attrs

    return_type = TableType(
        name=name, fields=build_attrs, interfaces=[NodeInterface], description="", sqla_model=sqla_model
    )

    return return_type
def create_entrypoint_factory(sqla_model: TableProtocol, resolver) -> Field:
    """Build the ``create<Type>`` mutation entry, e.g. ``createAccount``.

    Despite the ``Field`` annotation, this returns a one-entry dict
    ``{"create<Type>": Field}`` that is ready for ``dict.update``.
    """
    type_name = Config.table_type_name_mapper(sqla_model)
    mutation_name = f"create{type_name}"

    input_arg = NonNull(create_input_type_factory(sqla_model))
    payload_type = create_payload_factory(sqla_model)

    entry = Field(
        payload_type,
        args={"input": input_arg},
        resolve=resolver,
        description=f"Creates a single {type_name}.",
    )
    return {mutation_name: entry}
def connection_factory(sqla_model: TableProtocol) -> ConnectionType:
    """Build the Relay ``<Type>Connection`` object type.

    The fields are ``edges`` (non-null list of non-null edges), ``pageInfo``
    and ``totalCount``, all resolved by ``default_resolver``. They are supplied
    as a callable, so ``edge_factory`` only runs when the fields are built.
    """

    def connection_fields():
        edge_type = edge_factory(sqla_model)
        edges_field = Field(NonNull(List(NonNull(edge_type))), resolve=default_resolver)
        page_info_field = Field(NonNull(PageInfo), resolve=default_resolver)
        count_field = Field(NonNull(Int), resolve=default_resolver)
        return {"edges": edges_field, "pageInfo": page_info_field, "totalCount": count_field}

    connection_name = Config.table_type_name_mapper(sqla_model) + "Connection"
    return ConnectionType(name=connection_name, fields=connection_fields, description="", sqla_model=sqla_model)
def update_entrypoint_factory(sqla_model: TableProtocol, resolver) -> t.Dict[str, Field]:
    """Build the ``update<Type>`` mutation entry, e.g. ``updateAccount``.

    The field takes a required ``input`` argument from
    ``update_input_type_factory`` and returns the payload from
    ``update_payload_factory``. The result is a one-entry dict.
    """
    type_name = Config.table_type_name_mapper(sqla_model)
    mutation_name = f"update{type_name}"

    input_arg = NonNull(update_input_type_factory(sqla_model))
    payload_type = update_payload_factory(sqla_model)

    entry = Field(
        payload_type,
        args={"input": input_arg},
        resolve=resolver,
        description=f"Updates a single {type_name} using its globally unique id and a patch.",
    )
    return {mutation_name: entry}
def delete_payload_factory(sqla_model: TableProtocol) -> InputObjectType:
    """Build the ``Delete<Type>Payload`` output type, e.g. ``DeleteAccountPayload``.

    It exposes only ``clientMutationId`` and ``nodeId``; no record field is
    included.
    """
    type_name = Config.table_type_name_mapper(sqla_model)
    payload_fields = {
        "clientMutationId": Field(String, resolve=default_resolver),
        "nodeId": ID,
    }
    return DeletePayloadType(
        f"Delete{type_name}Payload",
        payload_fields,
        description=f"The output of our delete {type_name} mutation",
        sqla_model=sqla_model,
    )
def connection_field_factory(sqla_model: TableProtocol, resolver, not_null=False) -> Field:
    """Build a paginated connection field over a table.

    Arguments: ``first``, ``last`` (Int), ``before``, ``after`` (Cursor) and
    ``condition`` (from ``condition_factory``). When ``not_null`` is truthy,
    the connection type is wrapped in ``NonNull``.
    """
    type_name = Config.table_type_name_mapper(sqla_model)
    connection_type = connection_factory(sqla_model)
    condition_type = condition_factory(sqla_model)

    paging_args = {
        "first": Argument(Int, description="", out_name=None),
        "last": Argument(Int),
        "before": Argument(Cursor),
        "after": Argument(Cursor),
        "condition": Argument(condition_type),
    }

    if not_null:
        field_type = NonNull(connection_type)
    else:
        field_type = connection_type

    return Field(
        field_type,
        args=paging_args,
        resolve=resolver,
        description=f"Reads and enables pagination through a set of {type_name}",
    )
def function_input_type_factory(
        sql_function: SQLFunction) -> FunctionInputType:
    """Build the ``<Function>Input`` input object, e.g. ``AuthenticateInput``.

    ``clientMutationId`` comes first, followed by one required field per SQL
    argument. Each field is named after the argument, or ``param<index>`` when
    the name is empty or missing.
    """
    type_name = Config.function_type_name_mapper(sql_function)

    input_fields = {"clientMutationId": String}
    for position, (arg_name, arg_sqla_type) in enumerate(
            zip(sql_function.arg_names, sql_function.arg_sqla_types)):
        input_fields[arg_name or f"param{position}"] = NonNull(convert_input_type(arg_sqla_type))

    return FunctionInputType(
        f"{type_name}Input",
        input_fields,
        description=f"All input for the {type_name} mutation.",
    )
def field_name_to_column(sqla_model: TableProtocol, gql_field_name: str) -> Column:
    """Return the first column whose GraphQL name equals ``gql_field_name``.

    The GraphQL name is computed with ``Config.column_name_mapper``.
    Raises ``KeyError`` when no column matches.
    """
    not_found = object()
    match = next(
        (col for col in get_columns(sqla_model) if Config.column_name_mapper(col) == gql_field_name),
        not_found,
    )
    if match is not_found:
        raise KeyError(f"No column corresponding to field {gql_field_name}")
    return match
def sqla_models_to_graphql_schema(
    sqla_models,
    sql_functions: typing.Optional[typing.List[SQLFunction]] = None,
    jwt_identifier: typing.Optional[str] = None,
    jwt_secret: typing.Optional[str] = None,
) -> Schema:
    """Creates a GraphQL Schema from SQLA Models

    **Parameters**

    * **sqla_models**: _List[Type[SQLAModel]]_ = List of SQLAlchemy models to include in the GraphQL schema
    * **jwt_identifier**: _str_ = qualified path of SQL composite type to encode as a JWT e.g. 'public.jwt'
    * **jwt_secret**: _str_ = Secret key used to sign issued JWTs
    * **sql_functions** = **NOT PUBLIC API**

    The ``Query`` and ``Mutation`` root types are included only when they end
    up with at least one field.
    """
    functions = sql_functions if sql_functions is not None else []

    query_fields = {}
    mutation_fields = {}

    # Tables
    for sqla_model in sqla_models:
        if not Config.exclude_read_one(sqla_model):
            # e.g. account(nodeId: NodeID)
            lookup_name = snake_to_camel(get_table_name(sqla_model), upper=False)
            query_fields[lookup_name] = table_field_factory(sqla_model, resolver)

        if not Config.exclude_read_all(sqla_model):
            # e.g. allAccounts(first: Int, last: Int ....)
            listing_name = "all" + snake_to_camel(to_plural(get_table_name(sqla_model)), upper=True)
            query_fields[listing_name] = connection_field_factory(sqla_model, resolver)

        if not Config.exclude_create(sqla_model):
            # e.g. createAccount(input: CreateAccountInput)
            mutation_fields.update(create_entrypoint_factory(sqla_model, resolver=resolver))

        if not Config.exclude_update(sqla_model):
            # e.g. updateAccount(input: UpdateAccountInput)
            mutation_fields.update(update_entrypoint_factory(sqla_model, resolver=resolver))

        if not Config.exclude_delete(sqla_model):
            # e.g. deleteAccount(input: DeleteAccountInput)
            mutation_fields.update(delete_entrypoint_factory(sqla_model, resolver=resolver))

    # Functions
    for sql_function in functions:
        if is_jwt_function(sql_function, jwt_identifier):
            mutation_fields.update(
                mutable_function_entrypoint_factory(
                    sql_function=sql_function, resolver=resolver, jwt_secret=jwt_secret
                )
            )
        elif sql_function.is_immutable:
            # Immutable functions are queries
            query_fields.update(
                immutable_function_entrypoint_factory(sql_function=sql_function, resolver=resolver)
            )
        else:
            # Mutable functions are mutations
            mutation_fields.update(
                mutable_function_entrypoint_factory(sql_function=sql_function, resolver=resolver)
            )

    root_types = {
        "query": ObjectType(name="Query", fields=query_fields),
        "mutation": ObjectType(name="Mutation", fields=mutation_fields),
    }
    populated_roots = {kind: root for kind, root in root_types.items() if root.fields}
    return Schema(**populated_roots)
def enum_factory(sqla_enum: typing.Type[postgresql.base.ENUM]) -> EnumType:
    """Reflect a PostgreSQL ENUM into a GraphQL enum.

    Each label in ``sqla_enum.enums`` maps to itself. The type name comes
    from ``Config.enum_name_mapper``.
    """
    enum_name = Config.enum_name_mapper(sqla_enum)
    label_map = {}
    for label in sqla_enum.enums:
        label_map[label] = label
    return EnumType(name=enum_name, values=label_map)
async def async_resolver(_, info: ResolveInfo, **kwargs) -> typing.Any:
    """Awaitable GraphQL Entrypoint resolver

    Expects:
        info.context['engine'] to contain an sqlalchemy.ext.asyncio.AsyncEngine
        info.context['default_role'] and info.context['jwt_claims'] to be present
        (read with [] indexing, so a missing key raises KeyError)

    All SQL runs inside a single ``engine.begin()`` transaction. The root
    field is dispatched on ``tree.return_type``:

    * FunctionPayloadType: call the SQL function and wrap its result.
    * MutationPayloadType: run ``build_mutation`` and re-query the touched row.
    * ObjectType / ScalarType: build and run a read query.

    Any other return type raises ``Exception``. The final result is also
    stored in ``info.context["result"]``.
    """
    context = info.context
    engine = context["engine"]
    default_role = context["default_role"]
    jwt_claims = context["jwt_claims"]

    tree = parse_resolve_info(info)

    async with engine.begin() as trans:
        # Set claims for transaction
        if jwt_claims or default_role:
            claims_stmt = build_claims(jwt_claims, default_role)
            await trans.execute(claims_stmt)

        result: typing.Dict[str, typing.Any]

        if isinstance(tree.return_type, FunctionPayloadType):
            sql_function = tree.return_type.sql_function
            # Positional SQL args are the input values in input-dict order,
            # excluding clientMutationId.
            function_args = [
                val for key, val in tree.args["input"].items() if key != "clientMutationId"
            ]
            func_call = sql_function.to_executable(function_args)

            # Function returning table row
            if isinstance(sql_function.return_sqla_type, TableProtocol):
                # Unpack the table row to columns
                return_sqla_model = sql_function.return_sqla_type
                core_table = return_sqla_model.__table__
                func_alias = func_call.alias("named_alias")
                stmt = select([
                    literal_column(c.name).label(c.name) for c in core_table.c
                ]).select_from(func_alias)  # type: ignore
                stmt_alias = stmt.alias()
                node_id_stmt = select([
                    to_node_id_sql(return_sqla_model, stmt_alias).label("nodeId")
                ]).select_from(stmt_alias)  # type: ignore
                # Expects exactly one row with one column; otherwise unpacking raises ValueError.
                ((row, ), ) = await trans.execute(node_id_stmt)
                node_id = NodeIdStructure.from_dict(row)

                # Add nodeId to AST and query
                query_tree = next(
                    iter([x for x in tree.fields if x.name == "result"]), None)
                if query_tree is not None:
                    query_tree.args["nodeId"] = node_id
                    base_query = sql_builder(query_tree)
                    query = sql_finalize(query_tree.alias, base_query)
                    ((stmt_result, ), ) = await trans.execute(query)
                else:
                    stmt_result = {}
            else:
                stmt = select([func_call.label("result")])
                # NOTE(review): stmt_result is a result Row here and is later
                # spread with ** below. That relies on the Row supporting mapping
                # unpacking, which differs between SQLAlchemy versions; verify
                # against the version in use.
                (stmt_result, ) = await trans.execute(stmt)

            maybe_mutation_id = tree.args["input"].get("clientMutationId")
            # Echo clientMutationId back under whatever alias the client requested.
            mutation_id_alias = next(
                iter([
                    x.alias for x in tree.fields if x.name == "clientMutationId"
                ]),
                "clientMutationId",
            )
            result = {
                tree.alias: {
                    **stmt_result,
                    **{
                        mutation_id_alias: maybe_mutation_id
                    }
                }
            }

        elif isinstance(tree.return_type, MutationPayloadType):
            stmt = build_mutation(tree)
            # The mutation statement is expected to yield one row holding the node id.
            ((row, ), ) = await trans.execute(stmt)
            node_id = NodeIdStructure.from_dict(row)

            maybe_mutation_id = tree.args["input"].get("clientMutationId")
            mutation_id_alias = next(
                iter([
                    x.alias for x in tree.fields if x.name == "clientMutationId"
                ]),
                "clientMutationId",
            )
            node_id_alias = next(
                iter([x.alias for x in tree.fields if x.name == "nodeId"]),
                "nodeId")
            output_row_name: str = Config.table_name_mapper(
                tree.return_type.sqla_model)
            query_tree = next(
                iter([x for x in tree.fields if x.name == output_row_name]), None)
            sql_result = {}
            if query_tree:
                # Set the nodeid of the newly created record as an arg
                query_tree.args["nodeId"] = node_id
                base_query = sql_builder(query_tree)
                query = sql_finalize(query_tree.alias, base_query)
                ((sql_result, ), ) = await trans.execute(query)
            result = {
                tree.alias: {
                    **sql_result,
                    mutation_id_alias: maybe_mutation_id
                },
                mutation_id_alias: maybe_mutation_id,
                node_id_alias: node_id,
            }

        elif isinstance(tree.return_type, (ObjectType, ScalarType)):
            base_query = sql_builder(tree)
            # NOTE(review): this branch finalizes with tree.name, whereas the
            # other branches key results by .alias. Confirm this is intended
            # for aliased root queries.
            query = sql_finalize(tree.name, base_query)
            ((query_json_result, ), ) = await trans.execute(query)

            if isinstance(tree.return_type, ScalarType):
                # If its a scalar, unwrap the top level name
                result = flu(query_json_result.values()).first(None)
            else:
                result = query_json_result

        else:
            raise Exception("sql builder could not handle return type")

    # Stash result on context to enable dumb resolvers to not fail
    context["result"] = result
    return result