Example #1
    def parse(
        self,
        query: str,
        full_fragments: str = "",
        should_validate: bool = True,
        is_fragment: bool = False,
    ) -> ParsedQuery:
        # Validate against the query plus any fragments it depends on, but only
        # traverse the query itself when matching fields to types below.
        query_document_ast = parse("".join([full_fragments, query]))
        document_ast = parse(query)
        if not is_fragment:
            operation = get_operation_ast(document_ast)
            if not operation.name:
                raise AnonymousQueryError()

        if should_validate:
            # NoUnusedFragmentsRule is excluded because full_fragments may define
            # fragments that this particular query does not use.
            errors = validate(
                self.schema,
                query_document_ast,
                [rule for rule in specified_rules if rule is not NoUnusedFragmentsRule],
            )
            if errors:
                raise InvalidQueryError(errors)

        type_info = TypeInfo(self.schema)
        visitor = FieldToTypeMatcherVisitor(self.schema, type_info, query)
        visit(document_ast, TypeInfoVisitor(type_info, visitor))
        result = visitor.parsed
        return result
Example #2
def parse_result_recursive(
    schema: GraphQLSchema,
    document: DocumentNode,
    node: Node,
    result: Optional[Dict[str, Any]],
    initial_type: Optional[GraphQLType] = None,
    inside_list_level: int = 0,
    visit_fragment: bool = False,
    operation_name: Optional[str] = None,
) -> Any:

    if result is None:
        return None

    type_info = TypeInfo(schema, initial_type=initial_type)

    visited = visit(
        node,
        TypeInfoVisitor(
            type_info,
            ParseResultVisitor(
                schema,
                document,
                node,
                result,
                type_info=type_info,
                inside_list_level=inside_list_level,
                visit_fragment=visit_fragment,
                operation_name=operation_name,
            ),
        ),
        visitor_keys=RESULT_DOCUMENT_KEYS,
    )

    return visited
Example #3
def _calculate_cost(ast_document, max_complexity=100) -> int:
    context = ValidationContext(schema=post_schema,
                                ast=ast_document,
                                type_info=TypeInfo(post_schema),
                                on_error=on_error_stub)
    visitor = CostAnalysisVisitor(context=context,
                                  max_complexity=max_complexity)
    visit(ast_document, visitor)
    return visitor.total_complexity
Example #4
    def parse(self, query: str, should_validate: bool = True) -> ParsedQuery:
        document_ast = parse(query)
        operation = get_operation_ast(document_ast)

        if not operation.name:
            raise AnonymousQueryError()

        if should_validate:
            errors = validate(self.schema, document_ast)
            if errors:
                raise InvalidQueryError(errors)

        type_info = TypeInfo(self.schema)
        visitor = FieldToTypeMatcherVisitor(self.schema, type_info, query)
        visit(document_ast, TypeInfoVisitor(type_info, visitor))
        result = visitor.parsed
        return result
Example #5
def validate_depth(schema: GraphQLSchema, document: DocumentNode,
                   context_value: Any, errors: List[GraphQLError]):
    def on_error(error: GraphQLError) -> None:
        errors.append(error)
        raise ValidationAbortedError

    depth_analysis = context_value.get("depth_analysis")

    if depth_analysis:
        max_depth = depth_analysis.get("max_depth", 10)
        context = ValidationContext(schema=schema,
                                    ast=document,
                                    type_info=TypeInfo(schema),
                                    on_error=on_error)

        visit(document,
              DepthAnalysisVisitor(context=context, max_depth=max_depth))
Example #6
def validate_cost(schema: GraphQLSchema, document: DocumentNode,
                  context_value: Any, errors: List[GraphQLError]):
    def on_error(error: GraphQLError) -> None:
        errors.append(error)
        raise ValidationAbortedError

    cost_analysis = context_value.get("cost_analysis")

    if cost_analysis:
        max_complexity = cost_analysis.get("max_complexity", 500)
        context = ValidationContext(schema=schema,
                                    ast=document,
                                    type_info=TypeInfo(schema),
                                    on_error=on_error)

        visit(
            document,
            CostAnalysisVisitor(context=context,
                                max_complexity=max_complexity))

    return []
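
DepthAnalysisVisitor and CostAnalysisVisitor above are application-specific classes rather than part of graphql-core. As a hedged illustration of what such a context-driven visitor might look like, the sketch below defines a hypothetical depth-limiting visitor that reports errors through the same ValidationContext/on_error machinery; the MaxDepthVisitor and check_depth names are invented.

from graphql import GraphQLError, GraphQLSchema, TypeInfo, Visitor, visit
from graphql.language import DocumentNode, FieldNode
from graphql.validation import ValidationContext

class MaxDepthVisitor(Visitor):
    """Hypothetical sketch: track field nesting depth and report an error through
    the ValidationContext once max_depth is exceeded."""

    def __init__(self, context: ValidationContext, max_depth: int = 10) -> None:
        super().__init__()
        self.context = context
        self.max_depth = max_depth
        self.current_depth = 0

    def enter_field(self, node: FieldNode, *_args) -> None:
        self.current_depth += 1
        if self.current_depth > self.max_depth:
            self.context.report_error(
                GraphQLError(f"Query exceeds the maximum depth of {self.max_depth}.", node))

    def leave_field(self, _node: FieldNode, *_args) -> None:
        self.current_depth -= 1

def check_depth(schema: GraphQLSchema, document: DocumentNode, max_depth: int = 10):
    """Collect depth errors into a list instead of aborting on the first one."""
    errors = []
    context = ValidationContext(schema=schema, ast=document,
                                type_info=TypeInfo(schema), on_error=errors.append)
    visit(document, MaxDepthVisitor(context=context, max_depth=max_depth))
    return errors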
def _split_query_ast_one_level_recursive_normal_fields(
    query_node: SubQueryNode,
    selections: List[SelectionNode],
    type_info: TypeInfo,
    edge_to_stitch_fields: Dict[Tuple[str, str], Tuple[str, str]],
    name_assigner: IntermediateOutNameAssigner,
) -> Sequence[SelectionNode]:
    """One case of splitting query, selections contains a number of fields, no inline fragments.

    The input selections will be divided into three sets: property fields, intra-schema vertex
    fields, and cross-schema vertex fields.

    Each cross-schema vertex field will not be included in the output selections. The AST
    branch that each cross-schema vertex field leads to will be made into its own separate query
    AST. The parent and child property fields used in the stitch will be added to the parent and
    child ASTs, if not already present. @output directives will be added to these parent and
    child property fields, if not already present. @filter directives will not be added to
    child property fields in this step. This is because one may choose to rearrange and reroot
    the tree of SubQueryNodes to achieve an execution order with better performance. @filter
    directives should be added only once the tree's structure is fixed.

    _split_query_ast_one_level_recursive will be called recursively on each intra-schema vertex
    field.

    Args:
        query_node: SubQueryNode, whose list of child query connections may be modified to
                    include new children.
        selections: List of SelectionNodes, containing a number of property fields and
                    vertex fields.
        type_info: Used to get information about the types of fields while traversing the query AST.
        edge_to_stitch_fields: Mapping (type name, vertex field name) to
                               (source field name, sink field name) used in the @stitch directive
                               for each cross schema edge.
        name_assigner: Used to generate and keep track of names of newly created @output directives.

    Returns:
        List of SelectionNodes to replace the list of selections in the SelectionSet one level
        above. All cross schema edges in the input list will be removed, and in their place,
        property fields added or modified. If no changes were made, the exact input list object
        will be returned.
    """
    parent_type = type_info.get_parent_type()
    if parent_type is None:
        raise AssertionError("parent_type cannot be None.")
    parent_type_name = parent_type.name

    made_changes = False

    # First, collect all property fields, but don't make any changes to them yet
    property_fields_map, vertex_fields = _split_selections_property_and_vertex(
        selections)

    # Second, process cross schema fields. This will modify our record of property fields, and
    # create child SubQueryNodes attached to the input SubQueryNode
    intra_schema_fields, cross_schema_fields = _split_vertex_fields_intra_and_cross_schema(
        vertex_fields, parent_type_name, edge_to_stitch_fields)
    for cross_schema_field in cross_schema_fields:
        type_info.enter(cross_schema_field)
        child_type = type_info.get_type()
        if child_type is not None:
            child_type_name = strip_non_null_and_list_from_type(
                child_type).name
        else:
            raise AssertionError(
                "The query may be invalid against the schema, causing TypeInfo to lose track "
                'of the types of fields. This occurs at the cross schema field "{}", while '
                'splitting the AST "{}"'.format(cross_schema_field,
                                                query_node.query_ast))
        stitch_data_key = (parent_type_name, cross_schema_field.name.value)
        parent_field_name, child_field_name = edge_to_stitch_fields[
            stitch_data_key]
        _process_cross_schema_field(
            query_node,
            cross_schema_field,
            property_fields_map,
            child_type_name,
            parent_field_name,
            child_field_name,
            name_assigner,
        )
        made_changes = True  # Cross schema edges are removed from the output, causing changes
        type_info.leave(cross_schema_field)

    # Third, process intra schema edges by recursing on them
    new_intra_schema_fields: List[SelectionNode] = []
    for intra_schema_field in intra_schema_fields:
        type_info.enter(intra_schema_field)
        new_intra_schema_field = _split_query_ast_one_level_recursive(
            query_node, intra_schema_field, type_info, edge_to_stitch_fields,
            name_assigner)
        if new_intra_schema_field is not intra_schema_field:
            made_changes = True
        new_intra_schema_fields.append(new_intra_schema_field)
        type_info.leave(intra_schema_field)

    # Return input, or make copy
    if made_changes:
        new_selections: Sequence[
            SelectionNode] = _get_selections_from_property_and_vertex_fields(
                property_fields_map, new_intra_schema_fields)
        return new_selections
    else:
        return selections
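
The splitter above walks the AST by hand instead of going through visit(), so it must keep the shared TypeInfo in sync by explicitly entering and leaving each node it descends into (the type_info.enter()/leave() calls around the cross-schema and intra-schema fields). A minimal sketch of that manual bookkeeping, with a throwaway schema and query invented for illustration:

from graphql import TypeInfo, build_schema, parse

schema = build_schema("type Query { me: User } type User { name: String }")
type_info = TypeInfo(schema)

document = parse("{ me { name } }")
operation = document.definitions[0]

# Each enter() pushes onto TypeInfo's internal stacks and the matching leave()
# pops, so get_type()/get_parent_type() always describe the current AST position.
type_info.enter(operation)
type_info.enter(operation.selection_set)
me_field = operation.selection_set.selections[0]
type_info.enter(me_field)
print(type_info.get_parent_type())  # Query
print(type_info.get_type())         # User
type_info.leave(me_field)
type_info.leave(operation.selection_set)
type_info.leave(operation)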
def _split_query_ast_one_level_recursive(
    query_node: SubQueryNode,
    ast: AstType,
    type_info: TypeInfo,
    edge_to_stitch_fields: Dict[Tuple[str, str], Tuple[str, str]],
    name_assigner: IntermediateOutNameAssigner,
) -> AstType:
    """Return an AST node with which to replace the input AST in the selections that contain it.

    This function examines the selections of the input AST, and recursively calls either
    _split_query_ast_one_level_recursive or _split_query_ast_one_level_recursive_normal_fields
    depending on whether the selections contain a single InlineFragmentNode or a number of normal
    fields.

    Args:
        query_node: SubQueryNode, whose list of child query connections may be modified to
                    include new children.
        ast: The AST that we are trying to split into child components.
             It is not modified by this function.
        type_info: Used to get information about the types of fields while traversing the query AST.
        edge_to_stitch_fields: Mapping (type name, vertex field name) to
                               (source field name, sink field name) used in the @stitch directive
                               for each cross schema edge.
        name_assigner: Object used to generate and keep track of names of newly created
                       @output directives.

    Returns:
        The AST with which to replace the input AST in the selections that contain it.
    """
    if ast.selection_set is None:
        raise AssertionError("AST's selection_set cannot be None.")
    type_info.enter(ast.selection_set)
    selections = ast.selection_set.selections

    type_coercion = try_get_inline_fragment(selections)
    if type_coercion is not None:
        # Case 1: type coercion
        type_info.enter(type_coercion)
        new_type_coercion = _split_query_ast_one_level_recursive(
            query_node, type_coercion, type_info, edge_to_stitch_fields,
            name_assigner)
        type_info.leave(type_coercion)

        if new_type_coercion is type_coercion:
            new_selections: Sequence[SelectionNode] = selections
        else:
            new_selections = [new_type_coercion]
    else:
        # Case 2: normal fields
        new_selections = _split_query_ast_one_level_recursive_normal_fields(
            query_node, selections, type_info, edge_to_stitch_fields,
            name_assigner)
    type_info.leave(ast.selection_set)

    # Return input, or make copy
    if new_selections is not selections:
        new_ast = copy(ast)
        new_ast.selection_set = SelectionSetNode(selections=new_selections)
        return new_ast
    else:
        return ast
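
Both splitting helpers follow a copy-on-write convention: they return the exact input object when nothing changed, and otherwise a shallow copy whose selection set is replaced, so callers can rely on identity checks ("is not") to detect rewrites. A small sketch of that convention, using an invented query:

from copy import copy
from graphql import parse
from graphql.language import SelectionSetNode

document = parse("{ me { name } }")
operation = document.definitions[0]

unchanged = operation.selection_set.selections
rewritten = list(unchanged)  # stand-in for a selections list the splitter modified

if rewritten is not unchanged:
    # Shallow-copy the node and swap in a new SelectionSetNode; untouched
    # subtrees remain shared between the old and new ASTs.
    new_operation = copy(operation)
    new_operation.selection_set = SelectionSetNode(selections=rewritten)
else:
    new_operation = operation

assert new_operation is not operation  # identity signals that a rewrite happened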
def _split_query_one_level(
    query_node: SubQueryNode,
    merged_schema_descriptor: MergedSchemaDescriptor,
    edge_to_stitch_fields: Dict[Tuple[str, str], Tuple[str, str]],
    name_assigner: IntermediateOutNameAssigner,
) -> None:
    """Split the query node, creating children out of all branches across cross schema edges.

    The input query_node will be modified. Its query_ast will be replaced by a new AST with
    branches leading out of cross schema edges removed, and new property fields and @output
    directives added as necessary. Its child_query_connections will be modified by tacking
    on SubQueryNodes created from these cut-off branches.

    Args:
        query_node: Query to be split into its child components. Its query_ast
                    will be replaced (but the original AST will not be modified) and its
                    child_query_connections will be modified.
        merged_schema_descriptor: The schema that the query AST contained in the input
                                  query_node targets.
        edge_to_stitch_fields: Mapping (type name, vertex field name) to
                               (source field name, sink field name) used in the @stitch directive
                               for each cross schema edge.
        name_assigner: Object used to generate and keep track of names of newly created
                       @output directives.

    Raises:
        - GraphQLValidationError if the query AST contained in the input query_node is invalid,
          for example, having an @output directive on a cross schema edge
        - SchemaStructureError if the merged_schema_descriptor provided appears to be invalid
          or inconsistent
    """
    type_info = TypeInfo(merged_schema_descriptor.schema)

    operation_definition = get_only_query_definition(query_node.query_ast,
                                                     GraphQLValidationError)
    if not isinstance(operation_definition, OperationDefinitionNode):
        raise AssertionError(
            f"Expected operation_definition to be an OperationDefinitionNode, but it was of"
            f"type {type(operation_definition)}. This should be impossible.")

    type_info.enter(operation_definition)
    new_operation_definition = _split_query_ast_one_level_recursive(
        query_node, operation_definition, type_info, edge_to_stitch_fields,
        name_assigner)
    type_info.leave(operation_definition)

    if new_operation_definition is not operation_definition:
        query_node.query_ast = DocumentNode(
            definitions=[new_operation_definition])

    # Check resulting AST is valid
    validation_errors = validate(merged_schema_descriptor.schema,
                                 query_node.query_ast)
    if len(validation_errors) > 0:
        raise AssertionError(
            'The resulting split query "{}" is invalid, with the following error messages: {}'
            "".format(query_node.query_ast, validation_errors))

    # Set schema id, check for consistency
    visitor = TypeInfoVisitor(
        type_info,
        SchemaIdSetterVisitor(type_info, query_node,
                              merged_schema_descriptor.type_name_to_schema_id),
    )
    visit(query_node.query_ast, visitor)

    if query_node.schema_id is None:
        raise AssertionError(
            'Unreachable code reached. The schema id of query piece "{}" has not been '
            "determined.".format(query_node.query_ast))