def test_reflect_fkey_comment(engine):
    """The comment on the ``fk_person`` constraint of ``address`` is reflected."""
    engine.execute(SQL_UP)
    models, _functions = reflect_sqla_models(engine, schema="public")

    # Locate the reflected model backing the "address" table
    address_matches = [model for model in models if get_table_name(model) == "address"]
    address = address_matches[0]

    # Locate the foreign key constraint by its name
    fkey_matches = [item for item in get_constraints(address) if item.name == "fk_person"]
    fk_person = fkey_matches[0]

    assert get_comment(fk_person) == "@name Person Addresses"
Example #2
0
def to_pkey_clause(field: ASTNode, pkey_eq: typing.List[str]) -> typing.List[BinaryExpression]:
    """Build one equality expression per primary key column of the field's model.

    The primary key columns of ``field.return_type.sqla_model`` are paired
    positionally with the values in ``pkey_eq``; pairing stops at the shorter
    of the two sequences. Each column is referenced through a literal
    ``<table name>.<column name>`` column.
    """
    sqla_model = field.return_type.sqla_model
    table_name = get_table_name(sqla_model)
    key_columns = get_primary_key_columns(sqla_model)

    return [
        literal_column(f"{table_name}.{key_col.name}") == key_val
        for key_col, key_val in zip(key_columns, pkey_eq)
    ]
Example #3
0
def to_conditions_clause(field: ASTNode) -> typing.List[BinaryExpression]:
    """Translate the field's ``condition`` argument into equality filters.

    Returns ``[True]`` when ``field.args`` holds no ``condition`` entry (or
    the entry is ``None``). Otherwise returns one ``column == value``
    expression per item of the ``condition`` mapping, where the column is
    resolved from the field name on ``field.return_type.sqla_model``. An
    empty mapping therefore yields an empty list.
    """
    conditions = field.args.get("condition")
    if conditions is None:
        # A lone True is a neutral element when the list is spread into and_()
        return [True]

    sqla_model = field.return_type.sqla_model
    return [
        field_name_to_column(sqla_model, field_name) == val
        for field_name, val in conditions.items()
    ]
Example #4
0
def to_join_clause(field: ASTNode, parent_block_name: str) -> typing.List[BinaryExpression]:
    """Build the equality expressions tying a nested field to its parent block.

    The relationship is looked up on the parent field's model under
    ``field.name``. For every pair in the relationship's
    ``local_remote_pairs`` the result holds
    ``<parent_block_name>.<first column> == <local table name>.<second column>``,
    with both sides written as literal column references.
    """
    parent = field.parent
    assert parent is not None
    relationship = field_name_to_relationship(parent.return_type.sqla_model, field.name)
    table_name = get_table_name(field.return_type.sqla_model)

    return [
        literal_column(f"{parent_block_name}.{parent_side.name}")
        == literal_column(f"{table_name}.{local_side.name}")
        for parent_side, local_side in relationship.local_remote_pairs
    ]
Example #5
0
def to_cursor_sql(sqla_model, query_elem: Alias):
    """Build a ``jsonb_build_object`` expression describing a row's cursor.

    The object has two keys: ``table_name`` (the model's table name) and
    ``values`` (a nested object mapping each primary key column name to the
    same-named column of ``query_elem``).
    """
    table_name = get_table_name(sqla_model)
    key_names = [str(key_col.name) for key_col in get_primary_key_columns(sqla_model)]

    # One (name, value) pair per primary key column, read from the query element
    pairs = [(literal_string(name), query_elem.c[name]) for name in key_names]
    # Flatten to the alternating key/value argument list of jsonb_build_object
    flattened = [part for pair in pairs for part in pair]

    values_object = func.jsonb_build_object(*flattened)
    return func.jsonb_build_object(
        literal_string("table_name"),
        literal_string(table_name),
        literal_string("values"),
        values_object,
    )
Example #6
0
    def relationship_name_mapper(relationship: RelationshipProperty) -> str:
        """Derive a camelCase field name for a relationship.

        The name is ``<target>By<Local>To<Remote>``, where ``<target>`` is the
        referred table's name (pluralised when ``relationship.uselist`` is
        truthy) and one ``<Local>To<Remote>`` part is produced per pair in
        ``relationship.local_remote_pairs``; multiple parts are joined
        with ``And``.
        """
        # Resolve the relationship argument: unwrap ``class_`` when present,
        # otherwise call it when it is callable
        target = relationship.argument
        if hasattr(target, "class_"):
            target = target.class_
        elif callable(target):
            target = target()

        target_name = get_table_name(target)
        if relationship.uselist:
            target_name = to_plural(target_name)
        prefix = snake_to_camel(target_name, upper=False)

        column_parts = [
            snake_to_camel(local_col.name, upper=True)
            + "To"
            + snake_to_camel(remote_col.name, upper=True)
            for local_col, remote_col in relationship.local_remote_pairs
        ]
        return prefix + "By" + "And".join(column_parts)
Example #7
0
def populate_constraint_comment(engine: Engine, constraint: Constraint) -> None:
    """Store the SQL comment of *constraint* under ``constraint.info["comment"]``.

    Nothing is stored when the constraint's table has no schema or the
    constraint has no name. Otherwise the stored value is the comment found
    for (schema, table name, constraint name) in the reflected comment map,
    or ``None`` when the map has no such entry.
    """
    table = constraint.table
    schema: Optional[str] = table.schema
    table_name: str = get_table_name(table)
    constraint_name: Optional[str] = constraint.name

    if schema is not None and constraint_name is not None:
        all_comments: CommentMap = reflect_all_constraint_comments(engine=engine, schema=schema)

        # Walk schema -> table -> constraint, tolerating missing levels
        by_table = all_comments.get(schema, {})
        by_constraint = by_table.get(table_name, {})
        comment: Optional[str] = by_constraint.get(constraint_name)

        # constraint.info is "Optional[Mapping[str, Any]]"
        if hasattr(constraint, "info"):
            constraint.info["comment"] = comment  # type: ignore
        else:
            constraint.info = {"comment": comment}
Example #8
0
def connection_block(field: ASTNode,
                     parent_name: typing.Optional[str]) -> Alias:
    """Build the aliased SELECT that renders a connection field as one JSONB value.

    The returned selectable exposes a single column, ``ret_json``: a
    ``jsonb_build_object`` keyed by the aliases obtained from
    ``field.get_subfield_alias`` for ``totalCount``, ``pageInfo``
    (``hasNextPage``, ``hasPreviousPage``, ``startCursor``, ``endCursor``)
    and ``edges`` (``cursor``, ``node``).

    The statement is assembled in stages:

    1. ``core_model_ref`` - the model's table restricted by the join
       conditions and the ``condition`` filters.
    2. ``p1`` - the requested columns and relationships plus the internal
       ``_nodeId``, ``_cursor`` and ``_row_num`` columns, restricted by the
       pagination clause and limited to ``limit + 1`` rows.
    3. ``p2`` - ``p1`` cut back to ``limit`` rows, adding ``_has_next_page``.
    4. ``p3`` - ``p2`` ordered by ``_row_num`` (descending when ``before``
       or ``last`` was supplied, ascending otherwise).
    5. The final aggregate over ``p3``, cross-joined with the total-count
       block when ``check_has_total(field)`` is truthy.

    **Parameters**

    * **field**: AST node of the connection field being resolved
    * **parent_name**: name of the enclosing block to join against; when
      ``None`` no join condition is applied

    **Returns**

    An anonymous alias of the final SELECT.

    **Raises**

    * **ValueError**: when both ``first`` and ``last`` are given; or, when a
      cursor is supplied, when both ``before`` and ``after`` are given,
      ``after`` is combined with ``last``, ``before`` is combined with
      ``first``, or the cursor's table name differs from the model's.
    """
    return_type = field.return_type
    sqla_model = return_type.sqla_model

    block_name = slugify_path(field.path)
    if parent_name is None:
        # No parent block: a lone True leaves the and_() below unrestricted
        join_conditions = [True]
    else:
        join_conditions = to_join_clause(field, parent_name)

    filter_conditions = to_conditions_clause(field)
    limit = to_limit(field)
    has_total = check_has_total(field)

    # Only the presence of the argument key is checked here
    is_page_after = "after" in field.args
    is_page_before = "before" in field.args

    # Keys used in the output JSON, as aliased by the incoming query
    totalCount_alias = field.get_subfield_alias(["totalCount"])

    edges_alias = field.get_subfield_alias(["edges"])
    node_alias = field.get_subfield_alias(["edges", "node"])
    cursor_alias = field.get_subfield_alias(["edges", "cursor"])

    pageInfo_alias = field.get_subfield_alias(["pageInfo"])
    hasNextPage_alias = field.get_subfield_alias(["pageInfo", "hasNextPage"])
    hasPreviousPage_alias = field.get_subfield_alias(
        ["pageInfo", "hasPreviousPage"])
    startCursor_alias = field.get_subfield_alias(["pageInfo", "startCursor"])
    endCursor_alias = field.get_subfield_alias(["pageInfo", "endCursor"])

    # Apply Filters
    # All table columns, restricted by join + condition filters, aliased to
    # block_name
    core_model = sqla_model.__table__
    core_model_ref = (
        select(core_model.c).select_from(core_model).where(
            and_(
                # Join clause
                *join_conditions,
                # Conditions
                *filter_conditions,
            ))).alias(block_name)

    new_edge_node_selects = []
    new_relation_selects = []

    # Sort the requested node subfields into plain column selects and
    # relationship selects
    for subfield in get_edge_node_fields(field):
        # Does anything other than NodeID go here?
        if subfield.return_type == ID:
            # elem = select([to_node_id_sql(sqla_model, core_model_ref)]).label(subfield.alias)
            elem = to_node_id_sql(sqla_model,
                                  core_model_ref).label(subfield.alias)
            new_edge_node_selects.append(elem)
        elif isinstance(subfield.return_type,
                        (ScalarType, CompositeType, EnumType)):
            # Map the GraphQL field name to its column, labelled with the
            # alias used in the query
            col_name = field_name_to_column(sqla_model, subfield.name).name
            elem = core_model_ref.c[col_name].label(subfield.alias)
            new_edge_node_selects.append(elem)
        else:
            # Every other return type is delegated to build_relationship
            elem = build_relationship(subfield, block_name)
            new_relation_selects.append(elem)

    # Setup Pagination
    args = field.args
    after_cursor = args.get("after", None)
    before_cursor = args.get("before", None)
    first = args.get("first", None)
    last = args.get("last", None)

    if first is not None and last is not None:
        raise ValueError('only one of "first" and "last" may be provided')

    pkey_cols = get_primary_key_columns(sqla_model)

    if after_cursor or before_cursor:
        local_table_name = get_table_name(field.return_type.sqla_model)
        # When both cursors are present the "before" cursor is read here; the
        # combination is rejected just below
        cursor_table_name = before_cursor.table_name if before_cursor else after_cursor.table_name
        cursor_values = before_cursor.values if before_cursor else after_cursor.values

        if after_cursor is not None and before_cursor is not None:
            raise ValueError(
                'only one of "before" and "after" may be provided')

        if after_cursor is not None and last is not None:
            raise ValueError(
                '"after" is not compatible with "last". Use "first"')

        if before_cursor is not None and first is not None:
            raise ValueError(
                '"before" is not compatible with "first". Use "last"')

        # A cursor is only valid for the table it was issued for
        if cursor_table_name != local_table_name:
            raise ValueError("Invalid cursor for entity type")

        # Tuple comparison on the primary key: rows after the cursor use ">",
        # rows before it use "<"
        pagination_clause = tuple_(
            *[core_model_ref.c[col.name] for col in pkey_cols]).op(
                ">" if after_cursor is not None else "<")(
                    tuple_(*[cursor_values[col.name] for col in pkey_cols]))
    else:
        pagination_clause = True

    # Primary key ordering in both directions
    order_clause = [
        asc(core_model_ref.c[col.name])
        for col in get_primary_key_columns(sqla_model)
    ]
    reverse_order_clause = [
        desc(core_model_ref.c[col.name])
        for col in get_primary_key_columns(sqla_model)
    ]

    # Row count over the filtered rows; has_total is applied as its WHERE
    # clause
    total_block = (select([func.count(ONE).label("total_count")]).select_from(
        core_model_ref.alias()).where(has_total)).alias(block_name + "_total")

    node_id_sql = to_node_id_sql(sqla_model, core_model_ref)
    cursor_sql = to_cursor_sql(sqla_model, core_model_ref)

    # Select the right stuff
    # p1 fetches limit + 1 rows; the extra row is how p2 detects a next page.
    # The chosen ordering is followed by order_clause a second time in the
    # same order_by.
    # NOTE(review): row_number() uses an empty OVER(), so the numbering is
    # presumably expected to follow this statement's ORDER BY - confirm.
    p1_block = (
        select([
            *new_edge_node_selects,
            *new_relation_selects,
            # For internal Use
            node_id_sql.label("_nodeId"),
            cursor_sql.label("_cursor"),
            # For internal Use
            func.row_number().over().label("_row_num"),
        ]).select_from(core_model_ref).where(pagination_clause).order_by(
            *(reverse_order_clause if
              (is_page_before or last is not None) else order_clause),
            *order_clause).limit(cast(limit + 1,
                                      Integer()))).alias(block_name + "_p1")

    # Drop maybe extra row
    # _has_next_page is true when p1's highest row number exceeds limit
    p2_block = (select([
        *p1_block.c,
        (func.max(p1_block.c._row_num).over() > limit).label("_has_next_page")
    ]).select_from(p1_block).limit(limit)).alias(block_name + "_p2")

    # Backwards paging sorts by row number descending, otherwise ascending
    ordering = (desc(literal_column("_row_num")) if
                (is_page_before or last is not None) else asc(
                    literal_column("_row_num")))

    p3_block = (select(p2_block.c).select_from(p2_block).order_by(ordering)
                ).alias(block_name + "_p3")

    # Aggregate p3 into the single ret_json object.
    # - totalCount is the Python value None when has_total is falsy.
    # - hasPreviousPage is TRUE exactly when the "after" argument was passed.
    # - startCursor / endCursor are the element at index ONE and the element
    #   at the array's upper bound of the aggregated _nodeId values.
    # - each edge's node is row_to_json of the whole p3 row, which also
    #   carries the internal underscore-prefixed columns.
    # NOTE(review): cursors are taken from _nodeId; the _cursor column
    # selected in p1 is not referenced here - confirm this is intended.
    # When no total is needed, an anonymous one-row select stands in for
    # total_block in the FROM list.
    final = (select([
        func.jsonb_build_object(
            literal_string(totalCount_alias),
            func.coalesce(func.min(total_block.c.total_count), ZERO)
            if has_total else None,
            literal_string(pageInfo_alias),
            func.jsonb_build_object(
                literal_string(hasNextPage_alias),
                func.coalesce(
                    func.array_agg(p3_block.c._has_next_page)[ONE], FALSE),
                literal_string(hasPreviousPage_alias),
                TRUE if is_page_after else FALSE,
                literal_string(startCursor_alias),
                func.array_agg(p3_block.c._nodeId)[ONE],
                literal_string(endCursor_alias),
                func.array_agg(p3_block.c._nodeId)[func.array_upper(
                    func.array_agg(p3_block.c._nodeId), ONE)],
            ),
            literal_string(edges_alias),
            func.coalesce(
                func.jsonb_agg(
                    func.jsonb_build_object(
                        literal_string(cursor_alias),
                        p3_block.c._nodeId,
                        literal_string(node_alias),
                        func.cast(
                            func.row_to_json(literal_column(p3_block.name)),
                            JSONB()),
                    )),
                func.cast(literal("[]"), JSONB()),
            ),
        ).label("ret_json")
    ]).select_from(p3_block).select_from(
        total_block if has_total else select([1]).alias())).alias()

    return final
Example #9
0
def sqla_models_to_graphql_schema(
    sqla_models,
    sql_functions: typing.Optional[typing.List[SQLFunction]] = None,
    jwt_identifier: typing.Optional[str] = None,
    jwt_secret: typing.Optional[str] = None,
) -> Schema:
    """Creates a GraphQL Schema from SQLA Models

    Each model contributes a single-record query, an ``all...`` connection
    query and create/update/delete mutations, unless the matching
    ``Config.exclude_*`` check opts it out. Each SQL function becomes a
    query when it is immutable and a mutation otherwise; functions matched
    by ``jwt_identifier`` always become mutations. A root type with no
    fields is left out of the schema.

    **Parameters**

    * **sqla_models**: _List[Type[SQLAModel]]_ = List of SQLAlchemy models to include in the GraphQL schema
    * **jwt_identifier**: _str_ = qualified path of the SQL composite type to encode as a JWT e.g. 'public.jwt'
    * **jwt_secret**: _str_ = secret forwarded when building entrypoints for functions matched by `jwt_identifier`
    * **sql_functions** = **NOT PUBLIC API**
    """
    functions = [] if sql_functions is None else sql_functions

    queries = {}
    mutations = {}

    # (opt-out check, entrypoint factory) in the order the mutations are added:
    # createAccount(input: ...), updateAccount(input: ...), deleteAccount(input: ...)
    mutation_builders = (
        (Config.exclude_create, create_entrypoint_factory),
        (Config.exclude_update, update_entrypoint_factory),
        (Config.exclude_delete, delete_entrypoint_factory),
    )

    # Tables
    for model in sqla_models:
        if not Config.exclude_read_one(model):
            # e.g. account(nodeId: NodeID)
            record_field = snake_to_camel(get_table_name(model), upper=False)
            queries[record_field] = table_field_factory(model, resolver)

        if not Config.exclude_read_all(model):
            # e.g. allAccounts(first: Int, last: Int ....)
            plural_name = to_plural(get_table_name(model))
            collection_field = "all" + snake_to_camel(plural_name, upper=True)
            queries[collection_field] = connection_field_factory(model, resolver)

        for is_excluded, build_entrypoints in mutation_builders:
            if not is_excluded(model):
                mutations.update(build_entrypoints(model, resolver=resolver))

    # Functions
    for function in functions:
        if is_jwt_function(function, jwt_identifier):
            jwt_entrypoints = mutable_function_entrypoint_factory(
                sql_function=function,
                resolver=resolver,
                jwt_secret=jwt_secret,
            )
            mutations.update(jwt_entrypoints)
        elif function.is_immutable:
            # Immutable functions are queries
            queries.update(
                immutable_function_entrypoint_factory(
                    sql_function=function, resolver=resolver))
        else:
            # Mutable functions are mutations
            mutations.update(
                mutable_function_entrypoint_factory(
                    sql_function=function, resolver=resolver))

    root_types = {
        "query": ObjectType(name="Query", fields=queries),
        "mutation": ObjectType(name="Mutation", fields=mutations),
    }
    # Only pass root types that actually have fields
    populated = {key: root for key, root in root_types.items() if root.fields}
    return Schema(**populated)
Example #10
0
 def table_type_name_mapper(sqla_table: TableProtocol) -> str:
     """Return the table's name passed through ``snake_to_camel`` with default options."""
     return snake_to_camel(get_table_name(sqla_table))