def request_from_graphql_ast(ast, root, context, variables, field, fragments):
    """Translate a GraphQL AST node into a ``Request``.

    Args:
        ast: The AST node being translated. Only ``ast_types.Field`` nodes
            get a ``key`` (computed by ``field_key``); any other node gets
            ``None``.
        root: Passed through to ``_graphql_selections``.
        context: Stored on the request and passed to ``_graphql_selections``.
        variables: Passed to ``get_argument_values`` and
            ``_graphql_selections``.
        field: The field definition the request targets, or ``None``. When
            ``None``, the request carries no arguments.
        fragments: Passed through to ``_graphql_selections``.

    Returns:
        A ``Request`` whose ``join_selections`` is always an empty tuple.
    """
    key = field_key(ast) if isinstance(ast, ast_types.Field) else None

    args = {}
    if field is not None:
        # Nodes lacking an ``arguments`` attribute are treated as having none.
        argument_nodes = getattr(ast, "arguments", [])
        args = get_argument_values(field.args, argument_nodes, variables=variables)

    return Request(
        key=key,
        field=field,
        args=args,
        selections=_graphql_selections(
            ast, root, context=context, variables=variables, fragments=fragments
        ),
        join_selections=(),
        context=context,
    )
def args_selection_search(self, selection_set, variables, parent_type, arg_name, arg_value):
    """Recursively search field/fragment selections for a matching argument.

    Walks every selection in ``selection_set``. It descends into named
    fragment spreads, inline fragments and nested selection sets. It returns
    ``True`` as soon as a field is found whose argument ``arg_name``
    resolves (using ``variables``) to a value equal to ``arg_value``.

    Args:
        selection_set: Selection set AST node to search.
        variables: Variable values used to resolve argument values.
        parent_type: Schema type the selections in ``selection_set`` belong to.
        arg_name: Name of the argument to look for.
        arg_value: Value the argument must equal for a match.

    Returns:
        ``True`` if a matching field argument is found, otherwise ``False``.

    Named fragments are searched at most once. Their names are recorded in
    ``self.visited_fragments``, which lives on the instance. Presumably the
    caller resets it between independent searches; verify against callers.
    """
    for field in selection_set.selections:
        if isinstance(field, ast.FragmentSpread):
            frag_name = field.name.value
            if frag_name in self.visited_fragments:
                continue
            # Record the *name string* (the same form the membership test
            # above uses) and do it before descending. This guarantees each
            # fragment is searched once and cannot recurse through itself.
            self.visited_fragments.add(frag_name)
            frag_def = self.fragment_defs[frag_name]
            frag_type = type_from_ast(self.schema, frag_def.type_condition)
            if self.args_selection_search(frag_def.selection_set, variables, frag_type, arg_name, arg_value):
                return True
            continue

        if isinstance(field, ast.InlineFragment):
            # Inline fragments have no name. Search their selections against
            # the type condition, or the enclosing type when it is omitted.
            if field.type_condition:
                frag_type = type_from_ast(self.schema, field.type_condition)
            else:
                frag_type = parent_type
            if self.args_selection_search(field.selection_set, variables, frag_type, arg_name, arg_value):
                return True
            continue

        field_def = get_field_def(self.schema, parent_type, field.name.value)
        if field_def is None:
            continue
        arg_vals = get_argument_values(field_def.args, field.arguments, variables)
        if arg_vals.get(arg_name) == arg_value:
            return True
        if field.selection_set is None:
            continue
        if self.args_selection_search(field.selection_set, variables, get_named_type(field_def.type), arg_name, arg_value):
            return True
    return False
def _should_include_node(node, variables): for directive in node.directives: name = directive.name.value if name == "skip": args = get_argument_values(GraphQLSkipDirective.args, directive.arguments, variables) if args.get("if") is True: return False elif name == "include": args = get_argument_values(GraphQLIncludeDirective.args, directive.arguments, variables) if args.get("if") is False: return False else: raise Exception("Unknown directive: {}".format(name)) return True
def _visitor_selector(
    type_: VisitableSchemaType, method_name: str
) -> List["SchemaDirectiveVisitor"]:
    """Instantiate the directive visitors that apply to ``type_``.

    A visitor is created for each directive on ``type_.ast_node`` that meets
    both conditions:
    - it has a class registered in ``directive_visitors``;
    - that class implements ``method_name``.

    Each visitor is built with the directive's argument values. Every visitor
    created here is also appended to ``created_visitors[visitor.name]``.

    Returns:
        The new visitors in directive order. The list is empty when
        ``type_`` has no AST node or its directive list is ``None``.
    """
    visitors: List["SchemaDirectiveVisitor"] = []
    ast_node = type_.ast_node
    directive_nodes = ast_node.directives if ast_node else None
    if directive_nodes is None:
        return visitors

    for directive_node in directive_nodes:
        directive_name = directive_node.name.value
        if directive_name not in directive_visitors:
            continue
        visitor_class = directive_visitors[directive_name]
        # Don't build a visitor object whose class doesn't override the
        # visitor method named by ``method_name``.
        if not visitor_class.implements_visitor_method(method_name):
            continue

        decl = declared_directives[directive_name]
        # For an explicitly declared directive, get_argument_values uses the
        # declared argument types to check/coerce values and supply defaults.
        # Otherwise the raw argument nodes are just converted to plain values.
        args: Dict[str, Any] = (
            get_argument_values(decl, directive_node)
            if decl
            else {
                arg.name.value: value_from_ast_untyped(arg.value)
                for arg in directive_node.arguments
            }
        )
        # This is the single place visitor instances are created and named.
        # The constructor is treated as protected, so these are the only
        # arguments a visitor ever receives.
        visitors.append(
            visitor_class(directive_name, args, type_, schema, context))

    for new_visitor in visitors:
        created_visitors[new_visitor.name].append(new_visitor)
    return visitors
def resolve_field_value_or_error(
    self,
    field_def: GraphQLField,
    field_nodes: Sequence[FieldNode],
    resolve_fn: GraphQLFieldResolver,
    source: Any,
    info: GraphQLResolveInfo,
) -> Union[Exception, Any]:
    """Resolve a field, returning either the resolver's result or the error.

    Arguments are read from the first node in ``field_nodes`` and resolved
    against ``self.variable_values``. They are converted with ``to_snake``
    when both conditions hold:
    - the schema has a truthy ``camelcase`` attribute;
    - ``info.parent_type`` is not an introspection type.

    Any ``Exception`` raised while building arguments or running
    ``resolve_fn`` is returned instead of raised. This includes
    ``GraphQLError``, which is an ``Exception`` subclass.
    """
    try:
        is_introspection = is_introspection_type(info.parent_type)
        camelcase = getattr(info.schema, "camelcase", False)
        arguments = get_argument_values(field_def, field_nodes[0], self.variable_values)
        if camelcase and not is_introspection:
            arguments = to_snake(arguments=arguments)
        return resolve_fn(source, info, **arguments)
    except Exception as error:
        # A separate ``except GraphQLError`` clause would be redundant.
        # GraphQLError subclasses Exception and is handled identically.
        return error
def resolve_field_value_or_error(self, field_def, field_nodes, resolve_fn, source, info):
    """Run the field resolver, turning a raised ``GraphQLError`` into a return value.

    Keeping this try/except in its own method isolates the "ReturnOrAbrupt"
    behaviour so ``resolve_field()`` is not de-optimised. The result is one
    of two things:
    - whatever ``resolve_fn`` returned, or the ``GraphQLError`` it raised;
    - if the result is awaitable (per ``self.is_awaitable``), a coroutine
      that awaits it and applies the same ``GraphQLError`` handling.

    NOTE: only ``GraphQLError`` is caught; other exceptions propagate to the
    caller. This intentionally differs from the upstream implementation,
    which wrapped them as ``GraphQLError(str(error), original_error=error)``.

    For internal use only.
    """
    try:
        # Argument values come from the first field node's AST. Variable
        # references are filled in from ``self.variable_values``.
        resolved_args = get_argument_values(field_def, field_nodes[0], self.variable_values)
        # Unlike the JavaScript reference implementation, the context value
        # travels inside ``info`` rather than as a separate argument.
        outcome = resolve_fn(source, info, **resolved_args)
        if not self.is_awaitable(outcome):
            return outcome

        async def settle_outcome():
            try:
                return await outcome
            except GraphQLError as failure:
                return failure

        return settle_outcome()
    except GraphQLError as failure:
        return failure
def compute_node_cost(self, node, type_definition, parent_multiplier=None, parent_complexity=None):
    """Recursively estimate the cost of ``node``'s selection set.

    Args:
        node: AST node whose ``selection_set`` is costed. A node without a
            selection set costs 0.
        type_definition: Schema type the selections are resolved against.
            Field lookups happen only for object and interface types.
        parent_multiplier: Multipliers inherited from enclosing fields. Each
            selection starts from a copy stored in
            ``self.operation_multipliers``.
        parent_complexity: Forwarded to child computations. It is replaced by
            the field type's directive complexity when a field's own
            directives do not override it.

    Returns:
        The sum of the non-negative per-field costs, plus the cost of the
        single most expensive fragment when the selection set contains
        fragment spreads or inline fragments.
    """
    if not parent_multiplier:
        parent_multiplier = []
    if not node.selection_set:
        return 0

    fields = {}
    if isinstance(type_definition, (GraphQLObjectType, GraphQLInterfaceType)):
        fields = type_definition.fields

    total_cost = 0
    fragment_costs = []
    variables = {}  # TODO get variables from operation
    for selection in node.selection_set.selections:
        self.operation_multipliers = [*parent_multiplier]
        node_cost = 0
        if selection.kind == 'field':
            # Calculate cost for FieldNode
            field: GraphQLField = fields.get(selection.name.value)
            if not field:
                # Skip only this selection when it is absent from the type's
                # field map (e.g. the ``__typename`` meta field). The previous
                # ``break`` silently dropped the cost of every later sibling.
                continue
            field_type = get_named_type(field.type)
            field_args = get_argument_values(field, selection, variables)

            use_field_type_complexity = False
            cost_is_computed = False
            if field.ast_node and field.ast_node.directives:
                # NOTE(review): directive_args is indexed with [-1] below.
                # Confirm get_args_from_directives never returns None for a
                # field that has directives.
                directive_args: Union[Tuple[int, List, bool], None] = self.get_args_from_directives(
                    directives=field.ast_node.directives, field_args=field_args)
                override_complexity = directive_args[-1]
                # Check ast_node: types built without SDL have none, and
                # reading ``.directives`` from None raises AttributeError.
                if (not override_complexity
                        and isinstance(field_type, GraphQLObjectType)
                        and field_type.ast_node):
                    use_field_type_complexity = True
                    # NOTE(review): this rebinds the parameter, so the value
                    # is also passed to later sibling selections. Confirm
                    # that is intended.
                    parent_complexity, _, _ = self.get_args_from_directives(
                        directives=field_type.ast_node.directives, field_args=field_args)
                node_cost = self.compute_cost(directive_args)
                if directive_args:
                    cost_is_computed = True

            # Fall back to (or, when not overridden, use) the return type's
            # own directives.
            if (field_type and field_type.ast_node
                    and field_type.ast_node.directives
                    and isinstance(field_type, GraphQLObjectType)
                    and (not cost_is_computed or use_field_type_complexity)):
                directive_args = self.get_args_from_directives(
                    directives=field_type.ast_node.directives, field_args=field_args)
                node_cost = self.compute_cost(directive_args)

            child_cost = self.compute_node_cost(
                node=selection,
                type_definition=field_type,
                parent_multiplier=self.operation_multipliers,
                parent_complexity=parent_complexity) or 0
            node_cost += child_cost
        elif selection.kind == 'fragment_spread':
            # Fragment costs are not summed directly. Only the largest one is
            # added to the total at the end.
            fragment = self.context.get_fragment(selection.name.value)
            fragment_type = fragment and self.context.schema.get_type(
                fragment.type_condition.name.value)
            fragment_node_cost = self.compute_node_cost(
                fragment, fragment_type, self.operation_multipliers) if fragment else 0
            fragment_costs.append(fragment_node_cost)
        elif selection.kind == 'inline_fragment':
            if selection.type_condition and selection.type_condition.name:
                inline_fragment_type = self.context.schema.get_type(
                    selection.type_condition.name.value)
            else:
                inline_fragment_type = type_definition
            fragment_costs.append(self.compute_node_cost(
                selection, inline_fragment_type, self.operation_multipliers))
        else:
            node_cost = self.compute_node_cost(
                node=selection, type_definition=type_definition)
        total_cost += max(node_cost, 0)

    if fragment_costs:
        return total_cost + max(fragment_costs)
    return total_cost
def compute_node_cost(self, node: CostAwareNode, type_def, parent_multipliers=None):
    """Recursively estimate the cost of ``node``'s selection set from the cost map.

    Args:
        node: AST node whose selections are costed. A fragment spread, or a
            node without a selection set, costs 0.
        type_def: Schema type the selections are resolved against. Field
            lookups happen only for object and interface types.
        parent_multipliers: Multipliers inherited from enclosing fields. Each
            child selection starts from a copy stored in
            ``self.operation_multipliers``.

    Returns:
        The summed cost of all child selections. Each child starts at
        ``self.default_cost``. Returns 0 when ``self.cost_map`` is empty.

    Errors raised while reading argument values or computing a cost are
    reported through ``report_error(self.context, ...)`` rather than raised.
    """
    if parent_multipliers is None:
        parent_multipliers = []
    if isinstance(node, FragmentSpread) or not node.selection_set:
        return 0
    # Without a cost map nothing can be priced. This check used to sit inside
    # the loop, where its ``return 0`` also discarded costs already
    # accumulated from earlier sibling selections.
    if not self.cost_map:
        return 0

    fields: GraphQLFieldMap = {}
    if isinstance(type_def, (GraphQLObjectType, GraphQLInterfaceType)):
        fields = type_def.fields

    total = 0
    for child_node in node.selection_set.selections:
        self.operation_multipliers = parent_multipliers[:]
        node_cost = self.default_cost
        if isinstance(child_node, Field):
            field = fields.get(child_node.name.value)
            if not field:
                continue
            field_type = get_named_type(field.type)
            try:
                field_args: Dict[str, Any] = get_argument_values(
                    field.args,
                    child_node.arguments,
                    self.variables,
                )
            except Exception as e:
                report_error(self.context, e)
                field_args = {}

            cost_map_args = (
                self.get_args_from_cost_map(child_node, type_def.name, field_args)
                if type_def and type_def.name
                else None
            )
            if cost_map_args is not None:
                try:
                    node_cost = self.compute_cost(**cost_map_args)
                except (TypeError, ValueError) as e:
                    report_error(self.context, e)

            child_cost = self.compute_node_cost(
                child_node, field_type, self.operation_multipliers
            )
            node_cost += child_cost
        elif isinstance(child_node, FragmentSpread):
            fragment = self.context.get_fragment(child_node.name.value)
            if fragment:
                fragment_type = self.context.get_schema().get_type(
                    fragment.type_condition.name.value
                )
                # Forward the current multipliers. Otherwise fields selected
                # through a fragment escape the enclosing list multipliers.
                node_cost = self.compute_node_cost(
                    fragment, fragment_type, self.operation_multipliers
                )
        elif isinstance(child_node, InlineFragment):
            inline_fragment_type = type_def
            if child_node.type_condition and child_node.type_condition.name:
                inline_fragment_type = self.context.get_schema().get_type(
                    child_node.type_condition.name.value
                )
            node_cost = self.compute_node_cost(
                child_node, inline_fragment_type, self.operation_multipliers
            )
        total += node_cost
    return total
def __init__(
    self,
    field_node: FieldNode,
    field_def: ObjectType,
    schema: Schema,
    parent: typing.Optional["ASTNode"],
    variable_values,
    parent_type,
    fragments: typing.Dict[str, FragmentDefinitionNode],
):
    """Build a tree node for ``field_node`` and, recursively, for its sub-selections.

    Sets these attributes:
    - ``name``;
    - ``alias`` (falls back to the field name);
    - ``return_type``, ``parent`` and ``parent_type``;
    - ``args``: argument values resolved against ``variable_values``;
    - ``path``: the parent's path plus this name, or ``["root"]`` when
      there is no parent;
    - ``fields``: child ``ASTNode`` objects. Named and inline fragments are
      flattened into this list, and fields whose names start with ``__``
      are skipped.

    NOTE(review): the ``field_def`` argument is never read; the definition
    is looked up again from ``parent_type`` by name. Selections inside
    fragments are resolved against this node's return type, not the
    fragment's type condition. Confirm this suffices for interface/union
    types.
    """
    self.name = field_node.name.value
    # Re-resolve the field definition on the parent type (presumably a
    # connection/edge/etc. class); this supersedes the ``field_def`` argument.
    definition = get_field_def(schema, parent_type, self.name)
    resolved_args = get_argument_values(
        type_def=definition, node=field_node, variable_values=variable_values
    )
    sub_selection_set = field_node.selection_set
    return_type = field_to_type(definition)

    alias_node = field_node.alias
    self.alias = (alias_node.value if alias_node else None) or field_node.name.value
    self.return_type = return_type
    self.parent: typing.Optional[ASTNode] = parent
    self.parent_type = parent_type
    self.args: typing.Dict[str, typing.Any] = resolved_args
    node_path = ["root"] if parent is None else parent.path + [self.name]
    self.path: typing.List[str] = node_path

    def iter_children(selections):
        # Flatten fragments so their fields become direct children of this node.
        for selection in selections.selections:
            if isinstance(selection, FragmentSpreadNode):
                yield from iter_children(fragments[selection.name.value].selection_set)
                continue
            if isinstance(selection, InlineFragmentNode):
                yield from iter_children(selection.selection_set)
                continue
            child_name = selection.name.value
            if child_name.startswith("__"):
                # Reserved introspection fields are handled by the framework.
                continue
            yield ASTNode(
                field_node=selection,
                field_def=return_type.fields[child_name],
                schema=schema,
                parent=self,
                variable_values=variable_values,
                parent_type=return_type,
                fragments=fragments,
            )

    self.fields = list(iter_children(sub_selection_set)) if sub_selection_set else []