예제 #1
0
def test_parses_multi_byte_characters():
    # type: () -> None
    """Parse a query holding U+0A0A in a comment and in a string argument.

    Locations and source are disabled on the parser, so the expected tree is
    built without any ``loc`` values.
    """
    query = u"""
        # This comment has a \u0A0A multi-byte character.
        { field(arg: "Has a \u0A0A multi-byte character.") }
    """
    parsed = parse(query, no_location=True, no_source=True)

    argument_node = ast.Argument(
        name=ast.Name(value=u"arg"),
        value=ast.StringValue(value=u"Has a \u0a0a multi-byte character."),
    )
    field_node = ast.Field(
        alias=None,
        name=ast.Name(value=u"field"),
        arguments=[argument_node],
        directives=[],
        selection_set=None,
    )
    operation_node = ast.OperationDefinition(
        operation="query",
        name=None,
        variable_definitions=None,
        directives=[],
        selection_set=ast.SelectionSet(selections=[field_node]),
    )

    assert parsed == ast.Document(definitions=[operation_node])
예제 #2
0
def test_parses_simple_interface():
    # type: () -> None
    """Check the AST and offsets produced for an interface with one field."""
    body = """
interface Hello {
  world: String
}
"""
    loc = create_loc_fn(body)
    doc = parse(body)

    # Build the expected tree from the leaves up.
    string_type = ast.NamedType(
        name=ast.Name(value="String", loc=loc(28, 34)),
        loc=loc(28, 34),
    )
    world_field = ast.FieldDefinition(
        name=ast.Name(value="world", loc=loc(21, 26)),
        arguments=[],
        type=string_type,
        directives=[],
        loc=loc(21, 34),
    )
    hello_interface = ast.InterfaceTypeDefinition(
        name=ast.Name(value="Hello", loc=loc(11, 16)),
        directives=[],
        fields=[world_field],
        loc=loc(1, 36),
    )
    expected = ast.Document(definitions=[hello_interface], loc=loc(1, 37))

    assert doc == expected
예제 #3
0
def test_parses_simple_type_inheriting_multiple_interfaces():
    # type: () -> None
    """Parse an object type implementing two interfaces joined by ``&``.

    Every ``loc(start, end)`` is a pair of 0-based character offsets into
    ``body`` with an exclusive end, as the other expected values here show
    (``Hello`` occupies offsets 5-9 and is given ``loc(5, 10)``).
    """
    body = "type Hello implements Wo & rld { }"
    loc = create_loc_fn(body)
    doc = parse(body)
    expected = ast.Document(
        definitions=[
            ast.ObjectTypeDefinition(
                name=ast.Name(value="Hello", loc=loc(5, 10)),
                interfaces=[
                    ast.NamedType(
                        name=ast.Name(value="Wo", loc=loc(22, 24)), loc=loc(22, 24)
                    ),
                    # "rld" occupies offsets 27-29 of body (after "Wo & ").
                    ast.NamedType(
                        name=ast.Name(value="rld", loc=loc(27, 30)), loc=loc(27, 30)
                    ),
                ],
                directives=[],
                fields=[],
                # body is 34 characters long; the closing brace is at offset 33.
                loc=loc(0, 34),
            )
        ],
        loc=loc(0, 34),
    )
    assert doc == expected
예제 #4
0
def test_parses_double_value_enum():
    # type: () -> None
    """Check the AST and offsets produced for an enum with two values."""
    body = "enum Hello { WO, RLD }"
    loc = create_loc_fn(body)
    doc = parse(body)

    wo_value = ast.EnumValueDefinition(
        name=ast.Name(value="WO", loc=loc(13, 15)),
        directives=[],
        loc=loc(13, 15),
    )
    rld_value = ast.EnumValueDefinition(
        name=ast.Name(value="RLD", loc=loc(17, 20)),
        directives=[],
        loc=loc(17, 20),
    )
    hello_enum = ast.EnumTypeDefinition(
        name=ast.Name(value="Hello", loc=loc(5, 10)),
        directives=[],
        values=[wo_value, rld_value],
        loc=loc(0, 22),
    )
    expected = ast.Document(definitions=[hello_enum], loc=loc(0, 22))

    assert doc == expected
예제 #5
0
def test_parses_simple_interface():
    """Parse a one-field interface and compare it with a hand-built AST."""
    body = '''
interface Hello {
  world: String
}
'''
    loc = create_loc_fn(body)
    doc = parse(body)

    field_type = ast.NamedType(
        name=ast.Name(value='String', loc=loc(28, 34)),
        loc=loc(28, 34),
    )
    field_definition = ast.FieldDefinition(
        name=ast.Name(value='world', loc=loc(21, 26)),
        arguments=[],
        type=field_type,
        directives=[],
        loc=loc(21, 34),
    )
    interface_definition = ast.InterfaceTypeDefinition(
        name=ast.Name(value='Hello', loc=loc(11, 16)),
        directives=[],
        fields=[field_definition],
        loc=loc(1, 36),
    )
    expected = ast.Document(
        definitions=[interface_definition],
        loc=loc(1, 37),
    )

    assert doc == expected
예제 #6
0
def test_parses_simple_type():
    # type: () -> None
    """Check the AST and offsets produced for an object type with one field."""
    body = """
type Hello {
  world: String
}"""

    doc = parse(body)
    loc = create_loc_fn(body)

    string_type = ast.NamedType(
        name=ast.Name(value="String", loc=loc(23, 29)),
        loc=loc(23, 29),
    )
    world_field = ast.FieldDefinition(
        name=ast.Name(value="world", loc=loc(16, 21)),
        arguments=[],
        type=string_type,
        directives=[],
        loc=loc(16, 29),
    )
    hello_type = ast.ObjectTypeDefinition(
        name=ast.Name(value="Hello", loc=loc(6, 11)),
        interfaces=[],
        directives=[],
        fields=[world_field],
        loc=loc(1, 31),
    )
    expected = ast.Document(definitions=[hello_type], loc=loc(1, 31))

    assert doc == expected
예제 #7
0
def test_parses_simple_input_object():
    """Parse an input object with one field and compare it with a hand-built AST."""
    body = '''
input Hello {
  world: String
}'''
    loc = create_loc_fn(body)
    doc = parse(body)

    world_type = ast.NamedType(
        name=ast.Name(value='String', loc=loc(24, 30)),
        loc=loc(24, 30),
    )
    world_input = ast.InputValueDefinition(
        name=ast.Name(value='world', loc=loc(17, 22)),
        type=world_type,
        default_value=None,
        directives=[],
        loc=loc(17, 30),
    )
    hello_input = ast.InputObjectTypeDefinition(
        name=ast.Name(value='Hello', loc=loc(7, 12)),
        directives=[],
        fields=[world_input],
        loc=loc(1, 32),
    )
    expected = ast.Document(definitions=[hello_input], loc=loc(1, 32))

    assert doc == expected
예제 #8
0
def test_parses_simple_extension():
    """Parse an ``extend type`` block and compare it with a hand-built AST."""
    body = '''
extend type Hello {
  world: String
}'''
    doc = parse(body)
    loc = create_loc_fn(body)

    world_field = ast.FieldDefinition(
        name=ast.Name(value='world', loc=loc(23, 28)),
        arguments=[],
        type=ast.NamedType(
            name=ast.Name(value='String', loc=loc(30, 36)),
            loc=loc(30, 36),
        ),
        directives=[],
        loc=loc(23, 36),
    )
    # The wrapped object definition starts at the "type" keyword (offset 8),
    # while the extension itself starts at "extend" (offset 1).
    extended_type = ast.ObjectTypeDefinition(
        name=ast.Name(value='Hello', loc=loc(13, 18)),
        interfaces=[],
        directives=[],
        fields=[world_field],
        loc=loc(8, 38),
    )
    extension = ast.TypeExtensionDefinition(
        definition=extended_type,
        loc=loc(1, 38),
    )
    expected = ast.Document(definitions=[extension], loc=loc(1, 38))

    assert doc == expected
예제 #9
0
def test_parses_simple_field_with_arg_with_default_value():
    """Parse a field whose single argument carries a boolean default value."""
    body = '''
type Hello {
  world(flag: Boolean = true): String
}'''
    loc = create_loc_fn(body)
    doc = parse(body)

    def named_type(type_name, start, end):
        # The type node and its name node cover the same span.
        return ast.NamedType(
            name=ast.Name(value=type_name, loc=loc(start, end)),
            loc=loc(start, end),
        )

    flag_argument = ast.InputValueDefinition(
        name=ast.Name(value='flag', loc=loc(22, 26)),
        type=named_type('Boolean', 28, 35),
        default_value=ast.BooleanValue(value=True, loc=loc(38, 42)),
        loc=loc(22, 42),
    )
    world_field = ast.FieldDefinition(
        name=ast.Name(value='world', loc=loc(16, 21)),
        arguments=[flag_argument],
        type=named_type('String', 45, 51),
        loc=loc(16, 51),
    )
    hello_type = ast.ObjectTypeDefinition(
        name=ast.Name(value='Hello', loc=loc(6, 11)),
        interfaces=[],
        fields=[world_field],
        loc=loc(1, 53),
    )
    expected = ast.Document(definitions=[hello_type], loc=loc(1, 53))

    assert doc == expected
예제 #10
0
def test_parses_simple_field_with_list_arg():
    """Parse a field whose single argument has a list type."""
    body = '''
type Hello {
  world(things: [String]): String
}'''
    loc = create_loc_fn(body)
    doc = parse(body)

    def named_type(type_name, start, end):
        # The type node and its name node cover the same span.
        return ast.NamedType(
            name=ast.Name(value=type_name, loc=loc(start, end)),
            loc=loc(start, end),
        )

    # The list type spans the brackets (30-38); the inner name spans 31-37.
    things_type = ast.ListType(
        type=named_type('String', 31, 37),
        loc=loc(30, 38),
    )
    things_argument = ast.InputValueDefinition(
        name=ast.Name(value='things', loc=loc(22, 28)),
        type=things_type,
        default_value=None,
        loc=loc(22, 38),
    )
    world_field = ast.FieldDefinition(
        name=ast.Name(value='world', loc=loc(16, 21)),
        arguments=[things_argument],
        type=named_type('String', 41, 47),
        loc=loc(16, 47),
    )
    hello_type = ast.ObjectTypeDefinition(
        name=ast.Name(value='Hello', loc=loc(6, 11)),
        interfaces=[],
        fields=[world_field],
        loc=loc(1, 49),
    )
    expected = ast.Document(definitions=[hello_type], loc=loc(1, 49))

    assert doc == expected
예제 #11
0
def test_simple_non_null_type():
    """Parse a field declared with a non-null (``!``) type."""
    body = '''
type Hello {
  world: String!
}'''

    doc = parse(body)
    loc = create_loc_fn(body)

    # The named type covers "String" (23-29); the non-null wrapper also
    # covers the trailing "!" (23-30).
    inner_type = ast.NamedType(
        name=ast.Name(value='String', loc=loc(23, 29)),
        loc=loc(23, 29),
    )
    required_type = ast.NonNullType(type=inner_type, loc=loc(23, 30))
    world_field = ast.FieldDefinition(
        name=ast.Name(value='world', loc=loc(16, 21)),
        arguments=[],
        type=required_type,
        directives=[],
        loc=loc(16, 30),
    )
    hello_type = ast.ObjectTypeDefinition(
        name=ast.Name(value='Hello', loc=loc(6, 11)),
        interfaces=[],
        directives=[],
        fields=[world_field],
        loc=loc(1, 32),
    )
    expected = ast.Document(definitions=[hello_type], loc=loc(1, 32))

    assert doc == expected
예제 #12
0
def _get_basic_schema_ast(query_type):
    """Build a minimal schema AST Document around a single, field-less query type.

    The returned Document has exactly two definitions, in this order:
    the schema definition (whose only operation type is ``query`` pointing at
    ``query_type``), followed by the object type definition named
    ``query_type`` with no fields, interfaces or directives.

    Args:
        query_type: str, name of the query type for the schema

    Returns:
        Document, representing a nearly blank schema
    """
    query_operation = ast_types.OperationTypeDefinition(
        operation='query',
        type=ast_types.NamedType(name=ast_types.Name(value=query_type)),
    )
    schema_definition = ast_types.SchemaDefinition(
        operation_types=[query_operation],
        directives=[],
    )
    query_type_definition = ast_types.ObjectTypeDefinition(
        name=ast_types.Name(value=query_type),
        fields=[],
        interfaces=[],
        directives=[],
    )
    return ast_types.Document(
        definitions=[schema_definition, query_type_definition])
예제 #13
0
def test_parses_simple_input_object():
    # type: () -> None
    """Check the AST and offsets produced for an input object with one field."""
    body = """
input Hello {
  world: String
}"""
    loc = create_loc_fn(body)
    doc = parse(body)

    string_type = ast.NamedType(
        name=ast.Name(value="String", loc=loc(24, 30)),
        loc=loc(24, 30),
    )
    world_value = ast.InputValueDefinition(
        name=ast.Name(value="world", loc=loc(17, 22)),
        type=string_type,
        default_value=None,
        directives=[],
        loc=loc(17, 30),
    )
    input_definition = ast.InputObjectTypeDefinition(
        name=ast.Name(value="Hello", loc=loc(7, 12)),
        directives=[],
        fields=[world_value],
        loc=loc(1, 32),
    )
    expected = ast.Document(definitions=[input_definition], loc=loc(1, 32))

    assert doc == expected
예제 #14
0
def test_parse_creates_ast():
    """Parse a small nested query and compare it with a fully located AST."""
    source = Source("""{
  node(id: 4) {
    id,
    name
  }
}
""")
    result = parse(source)

    def span(start, end):
        # Every expected location refers to the same Source object.
        return Loc(start=start, end=end, source=source)

    id_argument = ast.Argument(
        name=ast.Name(loc=span(9, 11), value='id'),
        value=ast.IntValue(loc=span(13, 14), value='4'),
        loc=span(9, 14),
    )
    id_field = ast.Field(
        loc=span(22, 24),
        alias=None,
        name=ast.Name(loc=span(22, 24), value='id'),
        arguments=[],
        directives=[],
        selection_set=None,
    )
    name_field = ast.Field(
        loc=span(30, 34),
        alias=None,
        name=ast.Name(loc=span(30, 34), value='name'),
        arguments=[],
        directives=[],
        selection_set=None,
    )
    node_field = ast.Field(
        loc=span(4, 38),
        alias=None,
        name=ast.Name(loc=span(4, 8), value='node'),
        arguments=[id_argument],
        directives=[],
        selection_set=ast.SelectionSet(
            loc=span(16, 38),
            selections=[id_field, name_field],
        ),
    )
    operation = ast.OperationDefinition(
        loc=span(0, 40),
        operation='query',
        name=None,
        variable_definitions=None,
        directives=[],
        selection_set=ast.SelectionSet(
            loc=span(0, 40),
            selections=[node_field],
        ),
    )

    assert result == ast.Document(loc=span(0, 41), definitions=[operation])
예제 #15
0
 def _insert_selection_set_parent_nodes_from_deep_field_name(
         self, selection, deep_field_name):
     """Wrap ``selection``'s selections in one Field per segment of the name.

     ``deep_field_name`` is cut at the first ``LOOKUP_SEP``. The leading
     segment (camel-cased) names the Field returned here; the remainder, if
     any, is handled recursively and becomes that Field's only child. The
     last segment's Field receives ``selection.selection_set.selections``.
     """
     head, separator, remainder = deep_field_name.partition(LOOKUP_SEP)
     field_name = ast.Name(value=to_camel_case(head))
     if separator:
         # More segments follow: nest the rest of the path below this field.
         children = [
             self._insert_selection_set_parent_nodes_from_deep_field_name(
                 selection, remainder)
         ]
     else:
         children = selection.selection_set.selections
     return ast.Field(
         name=field_name,
         selection_set=ast.SelectionSet(selections=children))
예제 #16
0
def _build_stitch_directive(source_field_name, sink_field_name):
    """Return a ``stitch`` Directive node with source_field and sink_field arguments."""
    argument_values = (
        ('source_field', source_field_name),
        ('sink_field', sink_field_name),
    )
    stitch_arguments = [
        ast_types.Argument(
            name=ast_types.Name(value=argument_name),
            value=ast_types.StringValue(value=argument_value),
        )
        for argument_name, argument_value in argument_values
    ]
    return ast_types.Directive(
        name=ast_types.Name(value='stitch'),
        arguments=stitch_arguments,
    )
예제 #17
0
def test_parses_simple_union():
    """Parse a union with a single member type."""
    body = 'union Hello = World'
    loc = create_loc_fn(body)
    doc = parse(body)

    world_type = ast.NamedType(
        name=ast.Name(value='World', loc=loc(14, 19)),
        loc=loc(14, 19),
    )
    hello_union = ast.UnionTypeDefinition(
        name=ast.Name(value='Hello', loc=loc(6, 11)),
        types=[world_type],
        loc=loc(0, 19),
    )
    expected = ast.Document(definitions=[hello_union], loc=loc(0, 19))

    assert doc == expected
예제 #18
0
def test_parses_single_value_enum():
    """Parse an enum that declares exactly one value."""
    body = 'enum Hello { WORLD }'
    loc = create_loc_fn(body)
    doc = parse(body)

    world_value = ast.EnumValueDefinition(
        name=ast.Name(value='WORLD', loc=loc(13, 18)),
        loc=loc(13, 18),
    )
    hello_enum = ast.EnumTypeDefinition(
        name=ast.Name(value='Hello', loc=loc(5, 10)),
        values=[world_value],
        loc=loc(0, 20),
    )
    expected = ast.Document(definitions=[hello_enum], loc=loc(0, 20))

    assert doc == expected
예제 #19
0
    def _get_field_selection_set(
            self,
            field: GraphQLField,
            include_node: bool = True) -> Optional[graphql_ast.SelectionSet]:
        """Build the selection set that selects every sub-field of ``field``.

        Returns ``None`` when the field's return type (as resolved by
        ``self.get_return_type``) is a scalar or enum. For object and
        interface types, returns a SelectionSet with one Field per sub-field,
        each carrying its own recursively built selection set.

        For node types (per ``self.is_node_type``): when ``include_node`` is
        ``False`` only the ``id`` sub-field is selected; otherwise all
        sub-fields are selected and the recursion continues with
        ``include_node=False``.

        Raises:
            NotImplementedError: for any other return type.
        """
        resolved_type = self.get_return_type(field.type)

        if isinstance(resolved_type, (GraphQLScalarType, GraphQLEnumType)):
            return None
        if not isinstance(resolved_type,
                          (GraphQLObjectType, GrapheneInterfaceType)):
            raise NotImplementedError

        children = resolved_type.fields.items()

        if self.is_node_type(resolved_type):
            if include_node is False:
                children = [('id', resolved_type.fields['id'])]
            else:
                # disable full rendering of nested nodes to avoid recursion
                include_node = False

        return graphql_ast.SelectionSet(selections=[
            graphql_ast.Field(
                name=graphql_ast.Name(value=child_name),
                selection_set=self._get_field_selection_set(
                    child_field, include_node=include_node))
            for child_name, child_field in children
        ])
예제 #20
0
def test_parses_multi_byte_characters():
    """Parse a query holding U+0A0A in a comment and in a string argument."""
    query_text = u'''
        # This comment has a \u0A0A multi-byte character.
        { field(arg: "Has a \u0A0A multi-byte character.") }
    '''
    result = parse(query_text, no_location=True, no_source=True)

    expected_argument = ast.Argument(
        name=ast.Name(value=u'arg'),
        value=ast.StringValue(value=u'Has a \u0a0a multi-byte character.'),
    )
    expected_field = ast.Field(
        alias=None,
        name=ast.Name(value=u'field'),
        arguments=[expected_argument],
        directives=[],
        selection_set=None,
    )
    expected_operation = ast.OperationDefinition(
        operation='query',
        name=None,
        variable_definitions=None,
        directives=[],
        selection_set=ast.SelectionSet(selections=[expected_field]),
    )

    assert result == ast.Document(definitions=[expected_operation])
예제 #21
0
def test_parses_simple_field_with_arg_with_default_value():
    # type: () -> None
    """Check the AST for a field whose argument has a boolean default value."""
    body = """
type Hello {
  world(flag: Boolean = true): String
}"""
    loc = create_loc_fn(body)
    doc = parse(body)

    boolean_type = ast.NamedType(
        name=ast.Name(value="Boolean", loc=loc(28, 35)),
        loc=loc(28, 35),
    )
    flag_argument = ast.InputValueDefinition(
        name=ast.Name(value="flag", loc=loc(22, 26)),
        type=boolean_type,
        default_value=ast.BooleanValue(value=True, loc=loc(38, 42)),
        directives=[],
        loc=loc(22, 42),
    )
    string_type = ast.NamedType(
        name=ast.Name(value="String", loc=loc(45, 51)),
        loc=loc(45, 51),
    )
    world_field = ast.FieldDefinition(
        name=ast.Name(value="world", loc=loc(16, 21)),
        arguments=[flag_argument],
        type=string_type,
        directives=[],
        loc=loc(16, 51),
    )
    hello_type = ast.ObjectTypeDefinition(
        name=ast.Name(value="Hello", loc=loc(6, 11)),
        interfaces=[],
        directives=[],
        fields=[world_field],
        loc=loc(1, 53),
    )
    expected = ast.Document(definitions=[hello_type], loc=loc(1, 53))

    assert doc == expected
예제 #22
0
def _add_edge_field(source_type_node, sink_type_name, source_field_name,
                    sink_field_name, edge_name, direction):
    """Append a field for one direction of an edge to the source type, in place.

    The new field is named ``<direction>_<edge_name>``, has a list-of-sink-type
    type, takes no arguments and carries a stitch directive built from the two
    stitched field names.

    Args:
        source_type_node: (Interface/Object)TypeDefinition whose ``fields`` list receives
                          the new field. It is modified by this function
        sink_type_name: str, name of the type that the edge leads to
        source_field_name: str, name of the source side field that will be stitched
        sink_field_name: str, name of the sink side field that will be stitched
        edge_name: str, name of the edge that will be used to name the new field
        direction: str, either OUTBOUND_EDGE_DIRECTION or INBOUND_EDGE_DIRECTION

    Raises:
        - AssertionError if direction is neither of the two allowed values
        - SchemaNameConflictError if the new field name is already used by a field of
          the source type
    """
    type_fields = source_type_node.fields

    allowed_directions = (OUTBOUND_EDGE_DIRECTION, INBOUND_EDGE_DIRECTION)
    if direction not in allowed_directions:
        raise AssertionError(
            u'Input "direction" must be either "{}" or "{}".'.format(
                *allowed_directions))
    new_edge_field_name = direction + '_' + edge_name

    # Refuse to shadow a field that already exists on the source type.
    for existing_field in type_fields:
        if existing_field.name.value == new_edge_field_name:
            raise SchemaNameConflictError(
                u'New field "{}" under type "{}" created by the {}bound field of edge named '
                u'"{}" clashes with an existing field of the same name. Consider changing the '
                u'name of your edge to avoid name conflicts.'.format(
                    new_edge_field_name, source_type_node.name.value,
                    direction, edge_name))

    sink_list_type = ast_types.ListType(
        type=ast_types.NamedType(name=ast_types.Name(value=sink_type_name)))
    stitch_directive = _build_stitch_directive(source_field_name,
                                               sink_field_name)

    type_fields.append(
        ast_types.FieldDefinition(
            name=ast_types.Name(value=new_edge_field_name),
            arguments=[],
            type=sink_list_type,
            directives=[stitch_directive],
        ))
예제 #23
0
def test_parses_simple_field_with_list_arg():
    # type: () -> None
    """Check the AST for a field whose argument has a list type."""
    body = """
type Hello {
  world(things: [String]): String
}"""
    loc = create_loc_fn(body)
    doc = parse(body)

    # The list type spans the brackets (30-38); its element type spans 31-37.
    element_type = ast.NamedType(
        name=ast.Name(value="String", loc=loc(31, 37)),
        loc=loc(31, 37),
    )
    things_argument = ast.InputValueDefinition(
        name=ast.Name(value="things", loc=loc(22, 28)),
        type=ast.ListType(type=element_type, loc=loc(30, 38)),
        default_value=None,
        directives=[],
        loc=loc(22, 38),
    )
    return_type = ast.NamedType(
        name=ast.Name(value="String", loc=loc(41, 47)),
        loc=loc(41, 47),
    )
    world_field = ast.FieldDefinition(
        name=ast.Name(value="world", loc=loc(16, 21)),
        arguments=[things_argument],
        type=return_type,
        directives=[],
        loc=loc(16, 47),
    )
    hello_type = ast.ObjectTypeDefinition(
        name=ast.Name(value="Hello", loc=loc(6, 11)),
        interfaces=[],
        directives=[],
        fields=[world_field],
        loc=loc(1, 49),
    )
    expected = ast.Document(definitions=[hello_type], loc=loc(1, 49))

    assert doc == expected
예제 #24
0
def test_parses_union_with_two_types():
    """Parse a union with two member types separated by ``|``."""
    body = 'union Hello = Wo | Rld'
    loc = create_loc_fn(body)
    doc = parse(body)

    wo_type = ast.NamedType(
        name=ast.Name(value='Wo', loc=loc(14, 16)),
        loc=loc(14, 16),
    )
    rld_type = ast.NamedType(
        name=ast.Name(value='Rld', loc=loc(19, 22)),
        loc=loc(19, 22),
    )
    hello_union = ast.UnionTypeDefinition(
        name=ast.Name(value='Hello', loc=loc(6, 11)),
        types=[wo_type, rld_type],
        loc=loc(0, 22),
    )
    expected = ast.Document(definitions=[hello_union], loc=loc(0, 22))

    assert doc == expected
예제 #25
0
def test_parses_simple_type_inheriting_multiple_interfaces():
    """Parse an object type implementing two comma-separated interfaces."""
    body = 'type Hello implements Wo, rld { }'
    loc = create_loc_fn(body)
    doc = parse(body)

    wo_interface = ast.NamedType(
        name=ast.Name(value='Wo', loc=loc(22, 24)),
        loc=loc(22, 24),
    )
    rld_interface = ast.NamedType(
        name=ast.Name(value='rld', loc=loc(26, 29)),
        loc=loc(26, 29),
    )
    hello_type = ast.ObjectTypeDefinition(
        name=ast.Name(value='Hello', loc=loc(5, 10)),
        interfaces=[wo_interface, rld_interface],
        fields=[],
        loc=loc(0, 33),
    )
    expected = ast.Document(definitions=[hello_type], loc=loc(0, 33))

    assert doc == expected
예제 #26
0
 def args(self, **args):
     """Append the given keyword arguments to this field's AST arguments.

     Each keyword is looked up in ``self.field.args``. Its value is run
     through the serializer that ``get_arg_serializer`` returns for the
     argument's type, converted with ``get_ast_value`` and appended to
     ``self.ast_field.arguments`` as an ``ast.Argument``.

     Returns:
         ``self``, so calls can be chained.

     Raises:
         KeyError: if a keyword is not a known argument of the field.
             Arguments processed before the unknown one stay appended.
     """
     for name, value in args.items():
         arg = self.field.args.get(name)
         if arg is None:
             # Fail with a clear message instead of an AttributeError on
             # ``None.type`` further down.
             raise KeyError(
                 'Argument "{}" does not exist on this field.'.format(name))
         arg_type_serializer = get_arg_serializer(arg.type)
         serialized_value = arg_type_serializer(value)
         self.ast_field.arguments.append(
             ast.Argument(name=ast.Name(value=name),
                          value=get_ast_value(serialized_value)))
     return self
def test_converts_input_objects():
    """Convert an ordered mapping to an ObjectValue, with and without a type.

    Without a type the values become Int/String nodes; with the input object
    type they become Float/Enum nodes instead.
    """
    value = OrderedDict([('foo', 3), ('bar', 'HELLO')])

    untyped_result = ast_from_value(value)
    untyped_expected = ast.ObjectValue(fields=[
        ast.ObjectField(name=ast.Name('foo'), value=ast.IntValue('3')),
        ast.ObjectField(name=ast.Name('bar'), value=ast.StringValue('HELLO')),
    ])
    assert untyped_result == untyped_expected

    input_fields = {
        'foo': GraphQLInputObjectField(GraphQLFloat),
        'bar': GraphQLInputObjectField(my_enum),
    }
    input_obj = GraphQLInputObjectType('MyInputObj', input_fields)

    typed_result = ast_from_value(value, input_obj)
    typed_expected = ast.ObjectValue(fields=[
        ast.ObjectField(name=ast.Name('foo'), value=ast.FloatValue('3.0')),
        ast.ObjectField(name=ast.Name('bar'), value=ast.EnumValue('HELLO')),
    ])
    assert typed_result == typed_expected
예제 #28
0
def test_parses_scalar():
    """Parse a bare scalar type definition."""
    body = 'scalar Hello'
    loc = create_loc_fn(body)
    doc = parse(body)

    hello_scalar = ast.ScalarTypeDefinition(
        name=ast.Name(value='Hello', loc=loc(7, 12)),
        loc=loc(0, 12),
    )
    expected = ast.Document(definitions=[hello_scalar], loc=loc(0, 12))

    assert doc == expected
예제 #29
0
 def get_variable_type(self, return_type: GraphQLType):
     """Translate a GraphQL type into the matching AST type node.

     Non-null and list wrappers become ``NonNullType`` / ``ListType`` nodes
     around the recursively translated ``of_type``; anything else becomes a
     ``NamedType`` carrying ``return_type.name``.
     """
     if isinstance(return_type, GraphQLNonNull):
         wrapper_node = graphql_ast.NonNullType
     elif isinstance(return_type, GraphQLList):
         wrapper_node = graphql_ast.ListType
     else:
         type_name = graphql_ast.Name(value=return_type.name)
         return graphql_ast.NamedType(name=type_name)

     return wrapper_node(type=self.get_variable_type(return_type.of_type))
예제 #30
0
        def view_func():
            """Execute the operation for the enclosing field and return JSON.

            Builds a one-operation Document that selects ``field_name`` with
            the enclosing ``arguments`` and a generated selection set, runs it
            against ``schema`` with the current variable values, and wraps the
            encoded result in a ``Response`` with content type
            ``application/json``.

            When the field's type exposes a ``graphene_type`` that subclasses
            ``Node``, an inline fragment for the concrete type decoded from
            the ``id`` variable is appended to the selection set.
            """
            variable_values = self.get_variable_values()

            field_selection_set = self._get_field_selection_set(
                field, include_node=True)

            returns_node = hasattr(field.type, 'graphene_type') and issubclass(
                field.type.graphene_type, Node)
            if returns_node:
                concrete_type_name, _unused_id = Node.from_global_id(
                    variable_values['id'])
                concrete_type = schema.get_type(concrete_type_name)
                concrete_selection = self._get_field_selection_set(
                    GraphQLField(concrete_type), include_node=True)
                fragment = graphql_ast.InlineFragment(
                    type_condition=graphql_ast.NamedType(
                        name=graphql_ast.Name(value=concrete_type_name)),
                    selection_set=concrete_selection)
                field_selection_set.selections.append(fragment)

            root_field = graphql_ast.Field(
                name=graphql_ast.Name(value=field_name),
                arguments=arguments,
                selection_set=field_selection_set)
            operation_node = graphql_ast.OperationDefinition(
                operation=operation,
                variable_definitions=variable_definitions,
                selection_set=graphql_ast.SelectionSet(
                    selections=[root_field]))
            document_ast = graphql_ast.Document([operation_node])

            execution_results = schema.execute(
                document_ast, variable_values=variable_values)

            # TODO custom encoder that positions data[field_name] at data
            result, status_code = encode_execution_results(
                [execution_results],
                is_batch=False,
                format_error=default_format_error,
                encode=json_encode)

            return Response(
                result,
                status=status_code,
                content_type='application/json')