예제 #1
0
def generate_validator_constructor(ns, data_type):
    """
    Render the Python source expression that builds a validator for
    ``data_type``.

    The result is source text (a ``bv.*`` constructor call or the name of a
    module-level ``*_validator``) to be emitted into generated code, not an
    object to be used here.

    :param ns: Namespace the generated code lives in. Validators of
        user-defined types and aliases owned by a different namespace are
        qualified with that namespace's name.
    :param data_type: Stone data type to build a validator for. Any nullable
        wrapper is removed first and re-applied at the end as
        ``bv.Nullable(...)``.
    :return: A string of Python source code.
    :raises AssertionError: If the (unwrapped) type is of an unsupported kind.
    """
    base, wrapped_in_nullable = unwrap_nullable(data_type)

    if is_list_type(base):
        item_validator = generate_validator_constructor(ns, base.data_type)
        validator = generate_func_call(
            'bv.List',
            args=[item_validator],
            kwargs=[('min_items', base.min_items),
                    ('max_items', base.max_items)],
        )
    elif is_map_type(base):
        key_validator = generate_validator_constructor(ns, base.key_data_type)
        value_validator = generate_validator_constructor(
            ns, base.value_data_type)
        validator = generate_func_call(
            'bv.Map', args=[key_validator, value_validator])
    elif is_numeric_type(base):
        validator = generate_func_call(
            'bv.{}'.format(base.name),
            kwargs=[('min_value', base.min_value),
                    ('max_value', base.max_value)],
        )
    elif is_string_type(base):
        # The pattern is embedded in the generated code as a string literal.
        pattern_literal = (None if base.pattern is None
                           else repr(base.pattern))
        validator = generate_func_call(
            'bv.String',
            kwargs=[('min_length', base.min_length),
                    ('max_length', base.max_length),
                    ('pattern', pattern_literal)],
        )
    elif is_timestamp_type(base):
        validator = generate_func_call(
            'bv.Timestamp', args=[repr(base.format)])
    elif is_user_defined_type(base) or is_alias(base):
        # Both kinds are referenced by a ``<Class>_validator`` name that is
        # assumed to be declared elsewhere; qualify it with the owning
        # namespace when that is not ``ns``.
        validator = fmt_class(base.name) + '_validator'
        if ns.name != base.namespace.name:
            validator = '{}.{}'.format(base.namespace.name, validator)
    elif is_boolean_type(base) or is_bytes_type(base) or is_void_type(base):
        validator = generate_func_call('bv.{}'.format(base.name))
    else:
        raise AssertionError('Unsupported data type: %r' % base)

    if not wrapped_in_nullable:
        return validator
    return generate_func_call('bv.Nullable', args=[validator])
예제 #2
0
 def _generate_union_class_variant_creators(self, ns, data_type):
     """
     Emit a ``@classmethod`` constructor for every union variant that carries
     a value, i.e. every field whose type is not void.

     The method name comes from ``fmt_func(name, True)`` (presumably the
     reserved-word-safe spelling). The tag passed to ``cls(...)`` and shown in
     the docstring comes from plain ``fmt_func(name)``.
     """
     for field in data_type.fields:
         if is_void_type(field.data_type):
             # Symbol (void) variants get no creator method.
             continue

         tag = fmt_func(field.name)
         method_name = fmt_func(field.name, True)
         # For nullable fields, document the wrapped inner type.
         value_type = (field.data_type.data_type
                       if is_nullable_type(field.data_type)
                       else field.data_type)

         self.emit('@classmethod')
         self.emit('def {}(cls, val):'.format(method_name))
         with self.indent():
             self.emit('"""')
             self.emit_wrapped_text(
                 'Create an instance of this class set to the ``%s`` '
                 'tag with value ``val``.' % tag)
             self.emit()
             self.emit(':param {} val:'.format(
                 self._python_type_mapping(ns, value_type)))
             self.emit(':rtype: {}'.format(fmt_class(data_type.name)))
             self.emit('"""')
             self.emit("return cls('{}', val)".format(tag))
         self.emit()
예제 #3
0
    def _generate_union_class_reflection_attributes(self, ns, data_type):
        """
        Emit the per-field validator attributes and the tag map of a union.

        Each field gets ``<Class>._<field>_validator = <validator expr>``.
        ``<Class>._tagmap`` then maps every tag name to that attribute. When
        the union has a parent type, the map is extended with the parent's
        ``_tagmap``.
        """
        class_name = fmt_class(data_type.name)
        var_names = [fmt_var(field.name) for field in data_type.fields]

        def attr_name(var_name):
            # Fully qualified name of a field's validator attribute.
            return '{}._{}_validator'.format(class_name, var_name)

        for field, var_name in zip(data_type.fields, var_names):
            self.emit('{} = {}'.format(
                attr_name(var_name),
                generate_validator_constructor(ns, field.data_type)))

        with self.block('{}._tagmap ='.format(class_name)):
            for var_name in var_names:
                self.emit("'{}': {},".format(var_name, attr_name(var_name)))

        parent = data_type.parent_type
        if parent:
            self.emit('{0}._tagmap.update({1}._tagmap)'.format(
                class_name, class_name_for_data_type(parent, ns)))

        self.emit()
예제 #4
0
 def _generate_union_class_variant_creators(self, ns, data_type):
     """
     Emit a ``@classmethod`` constructor for every union variant that carries
     a value, i.e. every field whose type is not void.

     The method name comes from ``fmt_func(name, True)`` (presumably the
     reserved-word-safe spelling). The tag passed to ``cls(...)`` and shown in
     the docstring comes from plain ``fmt_func(name)``.
     """
     for field in data_type.fields:
         if is_void_type(field.data_type):
             # Symbol (void) variants get no creator method.
             continue

         tag = fmt_func(field.name)
         method_name = fmt_func(field.name, True)
         # For nullable fields, document the wrapped inner type.
         value_type = (field.data_type.data_type
                       if is_nullable_type(field.data_type)
                       else field.data_type)

         self.emit('@classmethod')
         self.emit('def {}(cls, val):'.format(method_name))
         with self.indent():
             self.emit('"""')
             self.emit_wrapped_text(
                 'Create an instance of this class set to the ``%s`` '
                 'tag with value ``val``.' % tag)
             self.emit()
             self.emit(':param {} val:'.format(
                 self._python_type_mapping(ns, value_type)))
             self.emit(':rtype: {}'.format(fmt_class(data_type.name)))
             self.emit('"""')
             self.emit("return cls('{}', val)".format(tag))
         self.emit()
예제 #5
0
    def _generate_union_class_reflection_attributes(self, ns, data_type):
        """
        Emit the per-field validator attributes and the tag map of a union.

        Each field gets ``<Class>._<field>_validator = <validator expr>``.
        ``<Class>._tagmap`` then maps every tag name to that attribute. When
        the union has a parent type, the map is extended with the parent's
        ``_tagmap``.
        """
        class_name = fmt_class(data_type.name)
        var_names = [fmt_var(field.name) for field in data_type.fields]

        def attr_name(var_name):
            # Fully qualified name of a field's validator attribute.
            return '{}._{}_validator'.format(class_name, var_name)

        for field, var_name in zip(data_type.fields, var_names):
            self.emit('{} = {}'.format(
                attr_name(var_name),
                generate_validator_constructor(ns, field.data_type)))

        with self.block('{}._tagmap ='.format(class_name)):
            for var_name in var_names:
                self.emit("'{}': {},".format(var_name, attr_name(var_name)))

        parent = data_type.parent_type
        if parent:
            self.emit('{0}._tagmap.update({1}._tagmap)'.format(
                class_name, class_name_for_data_type(parent, ns)))

        self.emit()
예제 #6
0
    def _generate_union_class(self, ns, data_type):
        # type: (ApiNamespace, Union) -> None
        """
        Defines a Python class that represents a union in Stone.

        Emits the class declaration and a docstring made of:
        - the union's own doc,
        - a fixed note on how to use the tagged union,
        - one ``:ivar`` line per documented field.

        The class body is produced by the ``_generate_union_class_*`` helpers.
        ``<Class>_validator = bv.Union(<Class>)`` is emitted after the class.
        """
        self.emit(self._class_declaration_for_type(ns, data_type))
        with self.indent():
            self.emit('"""')
            if data_type.doc:
                self.emit_wrapped_text(
                    self.process_doc(data_type.doc, self._docf))
                self.emit()

            self.emit_wrapped_text(
                'This class acts as a tagged union. Only one of the ``is_*`` '
                'methods will return true. To get the associated value of a '
                'tag (if one exists), use the corresponding ``get_*`` method.')

            if data_type.has_documented_fields():
                self.emit()

            for field in data_type.fields:
                if not field.doc:
                    continue
                # Every branch renders the field doc through process_doc so
                # doc references are formatted the same way for all fields
                # (previously the generic branch emitted field.doc verbatim).
                field_doc = self.process_doc(field.doc, self._docf)
                if is_void_type(field.data_type):
                    ivar_doc = ':ivar {}: {}'.format(
                        fmt_namespaced_var(ns.name, data_type.name,
                                           field.name),
                        field_doc)
                elif is_user_defined_type(field.data_type):
                    # NOTE(review): unlike the other branches, this one uses
                    # the raw ``data_type.name`` and only namespaces the var
                    # when the union's namespace differs from ``ns``; confirm
                    # that is intended.
                    if data_type.namespace.name != ns.name:
                        formatted_var = fmt_namespaced_var(
                            ns.name, data_type.name, field.name)
                    else:
                        formatted_var = '{}.{}'.format(data_type.name,
                                                       fmt_var(field.name))
                    ivar_doc = ':ivar {} {}: {}'.format(
                        fmt_class(field.data_type.name), formatted_var,
                        field_doc)
                else:
                    ivar_doc = ':ivar {} {}: {}'.format(
                        self._python_type_mapping(ns, field.data_type),
                        fmt_namespaced_var(ns.name, data_type.name,
                                           field.name),
                        field_doc)
                self.emit_wrapped_text(ivar_doc, subsequent_prefix='    ')
            self.emit('"""')
            self.emit()

            self._generate_union_class_vars(data_type)
            self._generate_union_class_variant_creators(ns, data_type)
            self._generate_union_class_is_set(data_type)
            self._generate_union_class_get_helpers(ns, data_type)
            self._generate_union_class_custom_annotations(ns, data_type)
        self.emit('{0}_validator = bv.Union({0})'.format(
            class_name_for_data_type(data_type)))
        self.emit()
예제 #7
0
 def _generate_union_class_symbol_creators(self, data_type):
     """
     Emit, after the union class definition, one class attribute per symbol
     (void) variant, bound to an instance tagged with that variant:
     ``<Class>.<name> = <Class>('<name>')``.

     A trailing blank line is emitted only if at least one line was written.
     """
     class_name = fmt_class(data_type.name)
     symbol_names = [fmt_func(field.name) for field in data_type.fields
                     if is_void_type(field.data_type)]
     start_lineno = self.lineno
     for symbol in symbol_names:
         self.emit("{0}.{1} = {0}('{1}')".format(class_name, symbol))
     if self.lineno != start_lineno:
         self.emit()
예제 #8
0
 def _generate_union_class_symbol_creators(self, data_type):
     """
     Emit, after the union class definition, one class attribute per symbol
     (void) variant, bound to an instance tagged with that variant:
     ``<Class>.<name> = <Class>('<name>')``.

     A trailing blank line is emitted only if at least one line was written.
     """
     class_name = fmt_class(data_type.name)
     symbol_names = [fmt_func(field.name) for field in data_type.fields
                     if is_void_type(field.data_type)]
     start_lineno = self.lineno
     for symbol in symbol_names:
         self.emit("{0}.{1} = {0}('{1}')".format(class_name, symbol))
     if self.lineno != start_lineno:
         self.emit()
예제 #9
0
    def _generate_union_class_reflection_attributes(self, ns, data_type):
        """
        Adds a class attribute for each union member assigned to a validator.
        Also adds an attribute that is a map from tag names to validators.

        One tag map is emitted per audience:
        - The public ``_tagmap`` holds fields whose ``omitted_caller`` is None.
        - Each ``_<caller>_tagmap`` holds the fields restricted to that caller.

        A map is extended with the parent union's map of the same name when
        the parent declares that caller (or, for ``_tagmap``, whenever there
        is a parent). Fields with a redactor also get ``_generate_redactor``
        called for their validator attribute.

        Omitted callers are visited in a fixed order (public first, then
        sorted). This keeps the generated code identical between runs;
        iterating or printing a ``set`` of strings directly would follow
        per-process hash randomization.
        """
        class_name = fmt_class(data_type.name)
        parent_type = data_type.parent_type

        for field in data_type.fields:
            field_name = fmt_var(field.name)
            validator_name = generate_validator_constructor(
                ns, field.data_type)
            full_validator_name = '{}._{}_validator'.format(
                class_name, field_name)
            self.emit('{} = {}'.format(full_validator_name, validator_name))

            if field.redactor:
                self._generate_redactor(full_validator_name, field.redactor)

        # Generate a tag map for each omitted caller (and the public one).
        child_omitted_callers = data_type.get_all_omitted_callers()
        parent_omitted_callers = (
            parent_type.get_all_omitted_callers() if parent_type else set())

        all_omitted_callers = child_omitted_callers | parent_omitted_callers
        # key=str keeps the sort well-defined whatever the caller values are.
        sorted_callers = sorted(all_omitted_callers - {None}, key=str)
        if all_omitted_callers:
            # Emit the set literal with a stable element order.
            self.emit('{}._permissioned_tagmaps = {{{}}}'.format(
                class_name,
                ', '.join(repr(caller) for caller in
                          sorted(all_omitted_callers, key=str))))

        for omitted_caller in [None] + sorted_callers:
            is_public = omitted_caller is None
            tagmap_name = '_tagmap' if is_public else '_{}_tagmap'.format(
                omitted_caller)
            caller_in_parent = parent_type and (
                is_public or omitted_caller in parent_omitted_callers)

            with self.block('{}.{} ='.format(class_name, tagmap_name)):
                for field in data_type.fields:
                    if field.omitted_caller != omitted_caller:
                        continue
                    var_name = fmt_var(field.name)
                    validator_name = '{}._{}_validator'.format(
                        class_name, var_name)
                    self.emit("'{}': {},".format(var_name, validator_name))

            if caller_in_parent:
                self.emit('{0}.{1}.update({2}.{1})'.format(
                    class_name, tagmap_name,
                    class_name_for_data_type(parent_type, ns)))

        self.emit()
예제 #10
0
    def _generate_union_class(self, ns, data_type):
        # type: (ApiNamespace, Union) -> None
        """
        Defines a Python class that represents a union in Stone.

        Emits the class declaration and a docstring made of:
        - the union's own doc,
        - a fixed note on how to use the tagged union,
        - one ``:ivar`` line per documented field.

        The class body is produced by the ``_generate_union_class_*`` helpers.
        ``<Class>_validator = bv.Union(<Class>)`` is emitted after the class.
        """
        self.emit(self._class_declaration_for_type(ns, data_type))
        with self.indent():
            self.emit('"""')
            if data_type.doc:
                self.emit_wrapped_text(
                    self.process_doc(data_type.doc, self._docf))
                self.emit()

            self.emit_wrapped_text(
                'This class acts as a tagged union. Only one of the ``is_*`` '
                'methods will return true. To get the associated value of a '
                'tag (if one exists), use the corresponding ``get_*`` method.')

            if data_type.has_documented_fields():
                self.emit()

            for field in data_type.fields:
                if not field.doc:
                    continue
                # Every branch renders the field doc through process_doc so
                # doc references are formatted the same way for all fields
                # (previously the generic branch emitted field.doc verbatim).
                field_doc = self.process_doc(field.doc, self._docf)
                var_name = fmt_var(field.name)
                if is_void_type(field.data_type):
                    ivar_doc = ':ivar {}: {}'.format(var_name, field_doc)
                elif is_user_defined_type(field.data_type):
                    ivar_doc = ':ivar {} {}: {}'.format(
                        fmt_class(field.data_type.name), var_name, field_doc)
                else:
                    ivar_doc = ':ivar {} {}: {}'.format(
                        self._python_type_mapping(ns, field.data_type),
                        var_name, field_doc)
                self.emit_wrapped_text(ivar_doc, subsequent_prefix='    ')
            self.emit('"""')
            self.emit()

            self._generate_union_class_vars(data_type)
            self._generate_union_class_variant_creators(ns, data_type)
            self._generate_union_class_is_set(data_type)
            self._generate_union_class_get_helpers(ns, data_type)
            self._generate_union_class_repr(data_type)
        self.emit('{0}_validator = bv.Union({0})'.format(
            class_name_for_data_type(data_type)
        ))
        self.emit()
예제 #11
0
    def _generate_union_class_reflection_attributes(self, ns, data_type):
        """
        Adds a class attribute for each union member assigned to a validator.
        Also adds an attribute that is a map from tag names to validators.

        One tag map is emitted per audience:
        - The public ``_tagmap`` holds fields whose ``omitted_caller`` is None.
        - Each ``_<caller>_tagmap`` holds the fields restricted to that caller.

        A map is extended with the parent union's map of the same name when
        the parent declares that caller (or, for ``_tagmap``, whenever there
        is a parent). Fields with a redactor also get ``_generate_redactor``
        called for their validator attribute.

        Omitted callers are visited in a fixed order (public first, then
        sorted). This keeps the generated code identical between runs;
        iterating or printing a ``set`` of strings directly would follow
        per-process hash randomization.
        """
        class_name = fmt_class(data_type.name)
        parent_type = data_type.parent_type

        for field in data_type.fields:
            field_name = fmt_var(field.name)
            validator_name = generate_validator_constructor(
                ns, field.data_type)
            full_validator_name = '{}._{}_validator'.format(class_name, field_name)
            self.emit('{} = {}'.format(full_validator_name, validator_name))

            if field.redactor:
                self._generate_redactor(full_validator_name, field.redactor)

        # Generate a tag map for each omitted caller (and the public one).
        child_omitted_callers = data_type.get_all_omitted_callers()
        parent_omitted_callers = (
            parent_type.get_all_omitted_callers() if parent_type else set())

        all_omitted_callers = child_omitted_callers | parent_omitted_callers
        # key=str keeps the sort well-defined whatever the caller values are.
        sorted_callers = sorted(all_omitted_callers - {None}, key=str)
        if all_omitted_callers:
            # Emit the set literal with a stable element order.
            self.emit('{}._permissioned_tagmaps = {{{}}}'.format(
                class_name,
                ', '.join(repr(caller) for caller in
                          sorted(all_omitted_callers, key=str))))

        for omitted_caller in [None] + sorted_callers:
            is_public = omitted_caller is None
            tagmap_name = '_tagmap' if is_public else '_{}_tagmap'.format(omitted_caller)
            caller_in_parent = parent_type and (
                is_public or omitted_caller in parent_omitted_callers)

            with self.block('{}.{} ='.format(class_name, tagmap_name)):
                for field in data_type.fields:
                    if field.omitted_caller != omitted_caller:
                        continue
                    var_name = fmt_var(field.name)
                    validator_name = '{}._{}_validator'.format(class_name, var_name)
                    self.emit("'{}': {},".format(var_name, validator_name))

            if caller_in_parent:
                self.emit('{0}.{1}.update({2}.{1})'.format(
                    class_name, tagmap_name,
                    class_name_for_data_type(parent_type, ns))
                )

        self.emit()
예제 #12
0
    def _generate_enumerated_subtypes_tag_mapping(self, ns, data_type):
        """
        Generates attributes needed for serializing and deserializing structs
        with enumerated subtypes. These assignments are made after all the
        Python class definitions to ensure that all references exist.

        Emits ``_tag_to_subtype_``, ``_pytype_to_tag_and_subtype_`` and
        ``_is_catch_all_`` on the class named ``fmt_class(data_type.name)``.
        That is the same formatting used for the subtype class names below.
        Previously the raw ``data_type.name`` was used, which names a
        non-existent class whenever ``fmt_class`` normalizes the name.

        :raises AssertionError: If ``data_type`` has no enumerated subtypes.
        """
        assert data_type.has_enumerated_subtypes()

        class_name = fmt_class(data_type.name)
        # Materialize once: both maps are built from the same
        # (tags, subtype, validator expression) triples.
        subtype_entries = [
            (tags, subtype, generate_validator_constructor(ns, subtype))
            for tags, subtype in data_type.get_all_subtypes_with_tags()]

        # _tag_to_subtype_: map from type tag to the validator of the
        # referenced subtype. Used on deserialization to look up the subtype
        # for a given tag.
        self.generate_multiline_list(
            ['{}: {}'.format(tags, validator)
             for tags, _, validator in subtype_entries],
            before='{}._tag_to_subtype_ = '.format(class_name),
            delim=('{', '}'),
            compact=False)

        # _pytype_to_tag_and_subtype_: map from Python class to a tuple of
        # (type tag, subtype validator). Used on serialization to look up how
        # a class should be encoded based on the root struct's enumerated
        # subtypes.
        self.generate_multiline_list(
            ['{0}: ({1}, {2})'.format(fmt_class(subtype.name), tags, validator)
             for tags, subtype, validator in subtype_entries],
            before='{}._pytype_to_tag_and_subtype_ = '.format(class_name),
            delim=('{', '}'),
            compact=False)

        # _is_catch_all_: rendered with repr() so it is a Python literal.
        self.emit('{}._is_catch_all_ = {!r}'.format(
            class_name, data_type.is_catch_all()))

        self.emit()
예제 #13
0
    def _generate_enumerated_subtypes_tag_mapping(self, ns, data_type):
        """
        Generates attributes needed for serializing and deserializing structs
        with enumerated subtypes. These assignments are made after all the
        Python class definitions to ensure that all references exist.

        Emits ``_tag_to_subtype_``, ``_pytype_to_tag_and_subtype_`` and
        ``_is_catch_all_`` on the class named ``fmt_class(data_type.name)``.
        That is the same formatting used for the subtype class names below.
        Previously the raw ``data_type.name`` was used, which names a
        non-existent class whenever ``fmt_class`` normalizes the name.

        :raises AssertionError: If ``data_type`` has no enumerated subtypes.
        """
        assert data_type.has_enumerated_subtypes()

        class_name = fmt_class(data_type.name)
        # Materialize once: both maps are built from the same
        # (tags, subtype, validator expression) triples.
        subtype_entries = [
            (tags, subtype, generate_validator_constructor(ns, subtype))
            for tags, subtype in data_type.get_all_subtypes_with_tags()]

        # _tag_to_subtype_: map from type tag to the validator of the
        # referenced subtype. Used on deserialization to look up the subtype
        # for a given tag.
        self.generate_multiline_list(
            ['{}: {}'.format(tags, validator)
             for tags, _, validator in subtype_entries],
            before='{}._tag_to_subtype_ = '.format(class_name),
            delim=('{', '}'),
            compact=False)

        # _pytype_to_tag_and_subtype_: map from Python class to a tuple of
        # (type tag, subtype validator). Used on serialization to look up how
        # a class should be encoded based on the root struct's enumerated
        # subtypes.
        self.generate_multiline_list(
            ['{0}: ({1}, {2})'.format(fmt_class(subtype.name), tags, validator)
             for tags, subtype, validator in subtype_entries],
            before='{}._pytype_to_tag_and_subtype_ = '.format(class_name),
            delim=('{', '}'),
            compact=False)

        # _is_catch_all_: rendered with repr() so it is a Python literal.
        self.emit('{}._is_catch_all_ = {!r}'.format(
            class_name, data_type.is_catch_all()))

        self.emit()
예제 #14
0
    def _generate_namespace_module(self, namespace):
        """
        Emit a plain Python class for every struct in ``namespace``.

        Each class gets the struct's doc as its docstring (if any) and an
        ``__init__`` taking one argument per field and storing it on
        ``self``. Field docs become ``#`` comments above the assignments.
        Data types that are not structs (unions, primitives) are skipped.
        """
        structs = [data_type for data_type in namespace.linearize_data_types()
                   if is_struct_type(data_type)]

        for struct in structs:
            self.emit('class {}(object):'.format(fmt_class(struct.name)))
            with self.indent():
                if struct.doc:
                    self.emit('"""')
                    self.emit_wrapped_text(struct.doc)
                    self.emit('"""')
                self.emit()

                member_names = [fmt_var(field.name) for field in struct.fields]
                # Constructor signature: ``self`` followed by every field.
                self.generate_multiline_list(
                    ['self'] + member_names, 'def __init__', ':')

                with self.indent():
                    if not struct.fields:
                        self.emit('pass')
                    else:
                        self.emit()
                        for field, member in zip(struct.fields, member_names):
                            if field.doc:
                                self.emit_wrapped_text(field.doc, '# ', '# ')
                            self.emit('self.{0} = {0}'.format(member))
            self.emit()
예제 #15
0
    def _generate_namespace_module(self, namespace):
        """
        Emit a plain Python class for every struct in ``namespace``.

        Each class gets the struct's doc as its docstring (if any) and an
        ``__init__`` taking one argument per field and storing it on
        ``self``. Field docs become ``#`` comments above the assignments.
        Data types that are not structs (unions, primitives) are skipped.
        """
        structs = [data_type for data_type in namespace.linearize_data_types()
                   if is_struct_type(data_type)]

        for struct in structs:
            self.emit('class {}(object):'.format(fmt_class(struct.name)))
            with self.indent():
                if struct.doc:
                    self.emit('"""')
                    self.emit_wrapped_text(struct.doc)
                    self.emit('"""')
                self.emit()

                member_names = [fmt_var(field.name) for field in struct.fields]
                # Constructor signature: ``self`` followed by every field.
                self.generate_multiline_list(
                    ['self'] + member_names, 'def __init__', ':')

                with self.indent():
                    if not struct.fields:
                        self.emit('pass')
                    else:
                        self.emit()
                        for field, member in zip(struct.fields, member_names):
                            if field.doc:
                                self.emit_wrapped_text(field.doc, '# ', '# ')
                            self.emit('self.{0} = {0}'.format(member))
            self.emit()
예제 #16
0
def generate_validator_constructor(ns, data_type):
    """
    Render the Python source expression that builds a validator for
    ``data_type``.

    The result is source text (a ``bv.*`` constructor call or the name of a
    module-level ``*_validator``) to be emitted into generated code, not an
    object to be used here.

    :param ns: Namespace the generated code lives in. Validators of
        user-defined types and aliases owned by a different namespace are
        qualified with that namespace's name.
    :param data_type: Stone data type to build a validator for. Any nullable
        wrapper is removed first and re-applied at the end as
        ``bv.Nullable(...)``.
    :return: A string of Python source code.
    :raises AssertionError: If the (unwrapped) type is of an unsupported kind.
    """
    base, wrapped_in_nullable = unwrap_nullable(data_type)

    if is_list_type(base):
        item_validator = generate_validator_constructor(ns, base.data_type)
        validator = generate_func_call(
            'bv.List',
            args=[item_validator],
            kwargs=[('min_items', base.min_items),
                    ('max_items', base.max_items)],
        )
    elif is_map_type(base):
        key_validator = generate_validator_constructor(ns, base.key_data_type)
        value_validator = generate_validator_constructor(
            ns, base.value_data_type)
        validator = generate_func_call(
            'bv.Map', args=[key_validator, value_validator])
    elif is_numeric_type(base):
        validator = generate_func_call(
            'bv.{}'.format(base.name),
            kwargs=[('min_value', base.min_value),
                    ('max_value', base.max_value)],
        )
    elif is_string_type(base):
        # The pattern is embedded in the generated code as a string literal.
        pattern_literal = (None if base.pattern is None
                           else repr(base.pattern))
        validator = generate_func_call(
            'bv.String',
            kwargs=[('min_length', base.min_length),
                    ('max_length', base.max_length),
                    ('pattern', pattern_literal)],
        )
    elif is_timestamp_type(base):
        validator = generate_func_call(
            'bv.Timestamp', args=[repr(base.format)])
    elif is_user_defined_type(base) or is_alias(base):
        # Both kinds are referenced by a ``<Class>_validator`` name that is
        # assumed to be declared elsewhere; qualify it with the owning
        # namespace when that is not ``ns``.
        validator = fmt_class(base.name) + '_validator'
        if ns.name != base.namespace.name:
            validator = '{}.{}'.format(base.namespace.name, validator)
    elif is_boolean_type(base) or is_bytes_type(base) or is_void_type(base):
        validator = generate_func_call('bv.{}'.format(base.name))
    else:
        raise AssertionError('Unsupported data type: %r' % base)

    if not wrapped_in_nullable:
        return validator
    return generate_func_call('bv.Nullable', args=[validator])
예제 #17
0
    def _generate_route_helper(self, namespace, route, download_to_file=False):
        """Emit the Python client method that calls ``route``.

        The generated method:
        - builds ``arg``: ``None`` for a void argument type, or an instance
          of the struct from the method's parameters; for unions nothing is
          emitted, so ``arg`` is presumably a method parameter;
        - calls ``self.request(...)`` with ``f`` as the body for upload
          routes and ``None`` otherwise;
        - returns the result (``None`` when the result type is void).

        :param namespace: Namespace that the route belongs to.
        :param stone.ir.ApiRoute route: IR node for the route.
        :param bool download_to_file: Emit the ``_to_file`` variant, which
            takes ``download_path`` and writes ``r[1]`` there via
            ``self._save_body_to_file``, returning ``r[0]``. Only valid for
            download-style routes (checked with an ``assert``).
        """
        arg_dt = route.arg_data_type
        result_dt = route.result_data_type
        is_upload = route.attrs.get('style') == 'upload'
        is_download = route.attrs.get('style') == 'download'

        decl_options = {}
        if download_to_file:
            assert is_download, 'download_to_file can only be set ' \
                'for download-style routes.'
            decl_options = {'method_name_suffix': '_to_file',
                            'extra_args': ['download_path']}
        self._generate_route_method_decl(
            namespace, route, arg_dt, is_upload, **decl_options)

        with self.indent():
            # Extra parameter documented in the generated docstring.
            if is_upload:
                extra_request_args = [('f', 'bytes', 'Contents to upload.')]
            elif download_to_file:
                extra_request_args = [('download_path', 'str',
                                       'Path on local machine to save file.')]
            else:
                extra_request_args = None

            # Plain download routes also document the HTTP response object.
            returns_response = is_download and not download_to_file
            extra_return_arg = (
                ':class:`requests.models.Response`' if returns_response
                else None)
            footer = DOCSTRING_CLOSE_RESPONSE if returns_response else None
            overview = (self.process_doc(route.doc, self._docf)
                        if route.doc else None)

            self._generate_docstring_for_func(
                namespace,
                arg_dt,
                result_dt,
                route.error_data_type,
                overview=overview,
                extra_request_args=extra_request_args,
                extra_return_arg=extra_return_arg,
                footer=footer,
            )

            self._maybe_generate_deprecation_warning(route)

            if is_void_type(arg_dt):
                self.emit('arg = None')
            elif is_struct_type(arg_dt):
                struct_ctor = '{}.{}'.format(
                    fmt_namespace(arg_dt.namespace.name),
                    fmt_class(arg_dt.name))
                self.generate_multiline_list(
                    [f.name for f in arg_dt.all_fields],
                    before='arg = ' + struct_ctor,
                )
            elif not is_union_type(arg_dt):
                raise AssertionError('Unhandled request type %r' % arg_dt)

            route_ref = '{}.{}'.format(
                fmt_namespace(namespace.name),
                fmt_func(route.name, version=route.version))
            request_args = [
                route_ref,
                "'{}'".format(namespace.name),
                'arg',
                'f' if is_upload else 'None',
            ]
            self.generate_multiline_list(
                request_args, 'r = self.request', compact=False)

            if download_to_file:
                self.emit('self._save_body_to_file(download_path, r[1])')
            if is_void_type(result_dt):
                self.emit('return None')
            elif download_to_file:
                self.emit('return r[0]')
            else:
                self.emit('return r')
        self.emit()
예제 #18
0
    def _generate_route_helper(self, namespace, route, download_to_file=False):
        """Emit the Python client method that calls ``route``.

        The generated method:
        - builds ``arg``: ``None`` for a void argument type, or an instance
          of the struct from the method's parameters; for unions nothing is
          emitted, so ``arg`` is presumably a method parameter;
        - calls ``self.request(...)`` with ``f`` as the body for upload
          routes and ``None`` otherwise;
        - returns the result (``None`` when the result type is void).

        :param namespace: Namespace that the route belongs to.
        :param stone.ir.ApiRoute route: IR node for the route.
        :param bool download_to_file: Emit the ``_to_file`` variant, which
            takes ``download_path`` and writes ``r[1]`` there via
            ``self._save_body_to_file``, returning ``r[0]``. Only valid for
            download-style routes (checked with an ``assert``).
        """
        arg_dt = route.arg_data_type
        result_dt = route.result_data_type
        is_upload = route.attrs.get('style') == 'upload'
        is_download = route.attrs.get('style') == 'download'

        decl_options = {}
        if download_to_file:
            assert is_download, 'download_to_file can only be set ' \
                'for download-style routes.'
            decl_options = {'method_name_suffix': '_to_file',
                            'extra_args': ['download_path']}
        self._generate_route_method_decl(
            namespace, route, arg_dt, is_upload, **decl_options)

        with self.indent():
            # Extra parameter documented in the generated docstring.
            if is_upload:
                extra_request_args = [('f', 'bytes', 'Contents to upload.')]
            elif download_to_file:
                extra_request_args = [('download_path', 'str',
                                       'Path on local machine to save file.')]
            else:
                extra_request_args = None

            # Plain download routes also document the HTTP response object.
            returns_response = is_download and not download_to_file
            extra_return_arg = (
                ':class:`requests.models.Response`' if returns_response
                else None)
            footer = DOCSTRING_CLOSE_RESPONSE if returns_response else None
            overview = (self.process_doc(route.doc, self._docf)
                        if route.doc else None)

            self._generate_docstring_for_func(
                namespace,
                arg_dt,
                result_dt,
                route.error_data_type,
                overview=overview,
                extra_request_args=extra_request_args,
                extra_return_arg=extra_return_arg,
                footer=footer,
            )

            self._maybe_generate_deprecation_warning(route)

            if is_void_type(arg_dt):
                self.emit('arg = None')
            elif is_struct_type(arg_dt):
                struct_ctor = '{}.{}'.format(
                    fmt_namespace(arg_dt.namespace.name),
                    fmt_class(arg_dt.name))
                self.generate_multiline_list(
                    [f.name for f in arg_dt.all_fields],
                    before='arg = ' + struct_ctor,
                )
            elif not is_union_type(arg_dt):
                raise AssertionError('Unhandled request type %r' % arg_dt)

            route_ref = '{}.{}'.format(
                fmt_namespace(namespace.name),
                fmt_func(route.name, version=route.version))
            request_args = [
                route_ref,
                "'{}'".format(namespace.name),
                'arg',
                'f' if is_upload else 'None',
            ]
            self.generate_multiline_list(
                request_args, 'r = self.request', compact=False)

            if download_to_file:
                self.emit('self._save_body_to_file(download_path, r[1])')
            if is_void_type(result_dt):
                self.emit('return None')
            elif download_to_file:
                self.emit('return r[0]')
            else:
                self.emit('return r')
        self.emit()