コード例 #1
0
def _param_ast(param: ServiceParameter) -> ast.arg:
    """Create the ast node for the parameter of a function.

    The argument is annotated based on ``param.type``:

    * placeholder parameters (names starting with ``/``) become an argument
      named after the snake_cased ``(placeholderParam)`` ``paramName``,
      annotated as ``typing.Dict[str, str]``;
    * ``string`` -> ``str`` (``OptionalListStr`` when ``param.multiple``);
    * ``number`` -> ``int``;
    * ``boolean`` -> ``bool``;
    * ``file`` -> ``typing.BinaryIO``;
    * any other type is left unannotated.

    """
    if param.name.startswith("/"):
        # Placeholder params gather many dynamic keys into a single dict arg.
        annotation = ast.Subscript(
            value=ast.Attribute(value=ast.Name(id="typing"), attr="Dict"),
            slice=ast.Index(value=ast.Tuple(
                elts=[ast.Name(id="str"), ast.Name(id="str")])),
        )

        return ast.arg(
            arg=snakeit(param.extra_data["(placeholderParam)"]["paramName"]),
            annotation=annotation,
        )

    annotation = None
    if param.type == "string":
        if param.multiple:
            annotation = ast.Name(id="OptionalListStr")
        else:
            annotation = ast.Name(id="str")
    elif param.type == "number":
        annotation = ast.Name(id="int")
    elif param.type == "boolean":
        annotation = ast.Name(id="bool")
    elif param.type == "file":
        annotation = ast.Name(id="typing.BinaryIO")
    return ast.arg(arg=snakeit(param.name), annotation=annotation)
コード例 #2
0
    def _get_resource_methods(self, service_domain, path, data):
        """Yield the service methods for a single resource endpoint.

        ``data`` is the RAML node for ``path`` (relative to
        ``service_domain.path``). Its ``post`` / ``get`` / ``delete`` entries
        produce ``update`` / ``get`` / ``delete`` methods. Keys starting with
        ``/`` are nested action endpoints and are delegated to
        ``_get_action_method``. Actions for which that returns ``None`` (no
        usable verb) are skipped rather than yielded as ``None``.
        """
        params = list(service_domain.path_parameters)
        params.extend(self._get_parameters(data))
        method_name = "_%s" % snakeit(data["(methodName)"])
        # NOTE(review): this replaces every "with" substring, not only a
        # leading "with" -- confirm that is intended.
        method_name = method_name.replace("with", "by")
        type_name = _get_item_type(data)

        name_prefix = ""
        if service_domain.parent:
            name_prefix = snakeit(service_domain.context_name) + "_"

        # HTTP verb -> ServiceMethod type (also used as the name part).
        verb_types = {"post": "update", "get": "get", "delete": "delete"}

        for endpoint_path, endpoint_data in data.items():
            if endpoint_path in verb_types:
                method_type = verb_types[endpoint_path]
                extra = {}
                if endpoint_path == "post":
                    # Only updates take an input body.
                    update_type = data["type"][type_name]["resourceUpdateType"]
                    extra["input_type"] = (service_domain.resource_type +
                                           update_type)
                method = ServiceMethod(
                    name=name_prefix + method_type + method_name,
                    path=service_domain.path + path,
                    path_params=list(params),
                    type=method_type,
                    method=endpoint_path,
                    returns=_get_return_type(endpoint_data,
                                             service_domain.resource_type),
                    **extra,
                )
                yield self._add_metadata(method, endpoint_data, data)

            elif endpoint_path.startswith("/"):
                action = self._get_action_method(service_domain, endpoint_path,
                                                 endpoint_data, data)
                # _get_action_method returns None for endpoints without a
                # get/post verb; never hand None to consumers.
                if action is not None:
                    yield action
コード例 #3
0
    def _get_action_method(self, service_domain, path, data, parent_data):
        """Build the ``action`` ServiceMethod for the nested endpoint ``path``.

        ``data`` is the endpoint's RAML node. ``post`` is chosen when it is
        the only verb, or when ``get`` also exists but the ``post`` entry has
        truthy ``responses``. Otherwise ``get`` is chosen. Returns ``None``
        when the endpoint has neither verb. ``parent_data`` is accepted but
        not used.
        """
        has_get = "get" in data
        has_post = "post" in data
        if has_post and (not has_get or data["post"].get("responses")):
            http_verb = "post"
        elif has_get:
            http_verb = "get"
        else:
            return None
        endpoint_data = data[http_verb]

        method_name = create_method_name(path)
        if service_domain.parent:
            # Nested services prefix the method with their context name.
            prefix = snakeit(service_domain.context_name)
            method_name = prefix + "_" + method_name

        service_method = ServiceMethod(
            name=method_name,
            path=service_domain.path + path,
            path_params=[],
            query_params=[],
            type="action",
            method=http_verb,
            input_type=_get_input_type(endpoint_data),
            returns=_get_return_type(endpoint_data,
                                     service_domain.resource_type),
        )
        return self._add_metadata(service_method, endpoint_data, data)
コード例 #4
0
    def _get_domain_methods(self, service_domain, method, method_data,
                            parent_data):
        """Build the collection-level ServiceMethod for HTTP verb ``method``.

        ``get`` produces a ``query`` method. Its return type is derived from
        ``service_domain.resource_querytype``. ``post`` produces a ``create``
        method that takes ``service_domain.resource_draft`` as input. Any
        other verb yields ``None``.
        """
        prefix = (snakeit(service_domain.context_name) + "_"
                  if service_domain.parent else "")

        if method == "get":
            service_method = ServiceMethod(
                name=prefix + "query",
                path=service_domain.path,
                path_params=list(service_domain.path_parameters),
                query_params=[],
                type="query",
                method=method,
                returns=_get_return_type(method_data,
                                         service_domain.resource_querytype),
            )
        elif method == "post":
            service_method = ServiceMethod(
                name=prefix + "create",
                path=service_domain.path,
                path_params=list(service_domain.path_parameters),
                query_params=[],
                type="create",
                method=method,
                input_type=service_domain.resource_draft,
                returns=_get_return_type(method_data,
                                         service_domain.resource_type),
            )
        else:
            return None

        return self._add_metadata(service_method, method_data, parent_data)
コード例 #5
0
def _create_endpoint_fstring(
        value: str) -> typing.Union[ast.Constant, ast.JoinedStr]:
    """Create the value for the `endpoint=..` parameter in the client.

    Leading ``/`` characters are stripped. If the endpoint contains no
    ``{placeholder}``, a regular string constant is returned. This also
    applies to an empty endpoint. Otherwise an f-string is returned in which
    every placeholder is replaced by the snake_cased variable of the same
    name. Empty literal segments are omitted.

    Note that this method assumes that the parameters in the endpoint are kept
    'intact'. E.g. the raml specifies `/{container}/{key}` so we assume that
    those variables are available in the scope when we create the f-string.

    """
    value = value.lstrip("/")
    parts = []
    last = 0
    for match in re.finditer(r"\{[^}]+\}", value):
        # Literal text before the placeholder (skipped when empty, e.g. for
        # a leading placeholder or two adjacent placeholders).
        if match.start() > last:
            parts.append(ast.Constant(value=value[last:match.start()]))

        identifier = snakeit(match.group()[1:-1])
        parts.append(
            ast.FormattedValue(value=ast.Name(id=identifier),
                               conversion=-1,
                               format_spec=None))
        last = match.end()

    # No placeholders at all (including the empty endpoint): a plain string.
    if last == 0:
        return ast.Constant(value=value, kind=None)

    if last < len(value):
        parts.append(ast.Constant(value=value[last:]))

    return ast.JoinedStr(values=parts)
コード例 #6
0
def _create_data_type(name: str, data: Dict[str, Any]):
    """Build a ``DataType`` from a raw RAML type definition.

    Keys of the form ``"(name)"`` are copied into ``obj.annotations`` without
    the parentheses. The ``package`` annotation (default ``"base"``)
    determines ``obj.package_name``. Each entry of ``properties`` becomes a
    ``Property`` whose types are ``UnresolvedType`` placeholders (presumably
    resolved later by the caller).

    Property keys ending in ``?`` are optional. Type expressions ending in
    ``[]`` are marked ``many``. ``|`` separates union members, both in the
    type expression and in ``items``.
    """
    obj = DataType(name=name, type=data.get("type"))

    for raw_key, raw_value in data.items():
        if not raw_key.startswith("("):
            continue
        obj.annotations[raw_key[1:-1]] = raw_value

    obj.package_name = "_" + snakeit(obj.annotations.get('package', 'base'))

    # Enum / discriminator information is copied verbatim.
    obj.enum = data.get("enum", [])
    obj.discriminator = data.get("discriminator")
    obj.discriminator_value = data.get("discriminatorValue")

    for prop_key, prop_def in (data.get("properties") or {}).items():
        if isinstance(prop_def, dict):
            type_expr = _determine_type(prop_def)
            items = prop_def.get("items")
            if items:
                items = [part.strip() for part in items.split("|")]
        else:
            # Shorthand form: the value is the type expression itself.
            type_expr = prop_def
            items = []

        is_many = type_expr.endswith("[]")
        if is_many:
            type_expr = type_expr[:-2]

        obj.properties.append(
            Property(
                name=prop_key.rstrip("?"),
                optional=prop_key.endswith("?"),
                types=[
                    UnresolvedType(name=member.strip())
                    for member in type_expr.split("|")
                ],
                many=is_many,
                items=items,
            ))

    return obj
コード例 #7
0
 def attribute_name(self) -> Optional[str]:
     """Return the snake_cased ``self.name``, or None if it is unusable.

     "Unusable" means the snake_cased name is empty or not a valid Python
     identifier.

     NOTE(review): ``str.isidentifier()`` is True for keywords such as
     ``class``; if ``snakeit`` leaves such a name unchanged it is returned
     as-is -- confirm callers handle that.
     """
     candidate = snakeit(self.name)
     return candidate if candidate and candidate.isidentifier() else None
コード例 #8
0
 def add_service(self, service: ServiceDomain):
     """Create the class node for ``service`` and register both.

     The module name is the snake_cased ``context_name``. If
     ``create_class`` returns a falsy node, nothing is registered.
     """
     module_name = snakeit(service.context_name)
     class_node = self.create_class(service, module_name)
     if not class_node:
         return
     self._services[service.context_name] = service
     self._service_nodes[module_name].append(class_node)
コード例 #9
0
def _create_schema_field(param: ServiceParameter):
    """Generate a field assignment for a marshmallow schema.

    Returns a ``(node, methods, imports)`` tuple on every path:

    * ``node``: an ``ast.Assign`` binding the snake_cased parameter name to a
      marshmallow field, or ``None`` for ``file`` parameters;
    * ``methods``: extra schema methods (dump/load hooks) to add to the
      schema class; only produced for placeholder parameters;
    * ``imports``: tuples the caller passes on to ``add_import_statement``.

    Raises ``NotImplementedError`` for unsupported parameter types.

    """
    keywords = []
    imports = []
    methods = []
    code_name = snakeit(param.name)
    if code_name != param.name:
        # Keep serializing under the original (wire) name.
        keywords.append(
            ast.keyword(arg="data_key",
                        value=ast.Constant(value=param.name, kind=None)))
    if not param.required:
        keywords.append(
            ast.keyword(arg="required",
                        value=ast.Constant(value=False, kind=None)))

    if param.name.startswith("/"):
        # Placeholder parameter: all matching keys are gathered into a single
        # dict field. The generated hooks flatten it on dump and collect it
        # again on load. The keywords built above are not applied here.
        placeholder = param.extra_data["(placeholderParam)"]
        code_name = snakeit(placeholder["paramName"])
        imports.append(("marshmallow", "fields"))
        imports.append(("marshmallow", ))
        serialize_func = ast.Call(
            func=ast.Attribute(value=ast.Name(id="fields"), attr="Dict"),
            args=[],
            keywords=[],
        )

        # TODO: can break if there is a suffix
        key_name = placeholder["template"].replace(
            "<%s>" % placeholder["placeholder"], "")
        # pop(..., None): the dict key may already be absent (e.g. removed
        # as an empty value by another hook), which must not raise KeyError.
        code = ast.parse(
            textwrap.dedent(
                """
            @marshmallow.post_dump
            def _%(target_name)s_post_dump(self, data, **kwargs):
                values = data.pop('%(target_name)s', None)
                if not values:
                    return data
                for key, val in values.items():
                    data[f"%(key_name)s{key}"] = val
                return data

            @marshmallow.pre_load
            def _%(target_name)s_pre_load(self, data, **kwargs):
                items = {}
                for key in list(data.keys()):
                    if key.startswith("%(key_name)s"):
                        items[key[%(key_len)d:]] = data[key]
                        del data[key]
                data["%(target_name)s"] = items
                return data
        """ % {
                    "target_name": code_name,
                    "key_name": key_name,
                    "key_len": len(key_name),
                }))
        methods.extend(code.body)

    elif param.type == "string":
        imports.append(("commercetools.helpers", "OptionalList"))
        imports.append(("marshmallow", "fields"))
        serialize_func = ast.Call(
            func=ast.Name(id="OptionalList"),
            args=[
                ast.Call(
                    func=ast.Attribute(value=ast.Name(id="fields"),
                                       attr="String"),
                    args=[],
                    keywords=[],
                )
            ],
            keywords=keywords,
        )
    elif param.type == "number":
        imports.append(("marshmallow", "fields"))
        serialize_func = ast.Call(
            func=ast.Attribute(value=ast.Name(id="fields"), attr="Int"),
            args=[],
            keywords=keywords,
        )

    elif param.type == "boolean":
        keywords.append(
            ast.keyword(arg="missing",
                        value=ast.Constant(value=False, kind=None)))
        imports.append(("marshmallow", "fields"))
        serialize_func = ast.Call(
            func=ast.Attribute(value=ast.Name(id="fields"), attr="Bool"),
            args=[],
            keywords=keywords,
        )
    elif param.type == "file":
        # No schema field for files; keep the same 3-tuple shape as every
        # other path so callers can always unpack (node, methods, imports).
        return None, [], []
    else:
        raise NotImplementedError(param)

    node = ast.Assign(targets=[ast.Name(id=code_name)], value=serialize_func)
    return node, methods, imports
コード例 #10
0
def _generate_init_file(services, modules):
    """Generate the __init__.py file which contains the ServicesMixin for
    the client.

    This is mostly to automate the addition of new services.

    ``services`` is a mapping whose values have a ``context_name``
    attribute. ``modules`` maps the snake_cased context name to a dict whose
    ``"name"`` entry is the module name. A hand-written ``project`` module is
    always included.

    Returns an ``ast.Module`` with three parts:

    * the module imports;
    * a ``typing.TYPE_CHECKING`` block importing every service class;
    * a ``ServicesMixin`` class with one ``cached_property`` per service,
      which imports and instantiates that service.

    """
    nodes = [
        ast.Import(names=[ast.alias(name="typing", asname=None)]),
        ast.ImportFrom(
            module="cached_property",
            names=[ast.alias(name="cached_property", asname=None)],
            level=0,
        ),
    ]

    # Collect all submodules: module name -> service class name.
    submodules = {}
    for service in services.values():
        info = modules[snakeit(service.context_name)]
        submodules[info["name"]] = service.context_name + "Service"

    # Add manual generated files (TODO)
    submodules["project"] = "ProjectService"

    def relative_import(module_name, class_name):
        # Equivalent to ``from .<module_name> import <class_name>``.
        return ast.ImportFrom(
            module=module_name,
            names=[ast.alias(name=class_name, asname=None)],
            level=1,
        )

    # Generate TYPE_CHECKING import statements (these will be sorted by isort).
    nodes.append(
        ast.If(
            test=ast.Attribute(value=ast.Name(id="typing"),
                               attr="TYPE_CHECKING"),
            body=[
                relative_import(module_name, class_name)
                for module_name, class_name in submodules.items()
            ],
            orelse=[],
        ))

    # Return the class + properties
    class_node = ast.ClassDef(name="ServicesMixin",
                              bases=[],
                              keywords=[],
                              decorator_list=[],
                              body=[])
    for module_name, class_name in submodules.items():
        class_node.body.append(
            ast.FunctionDef(
                name=module_name,
                args=ast.arguments(
                    posonlyargs=[],
                    args=[ast.arg(arg="self", annotation=None)],
                    vararg=None,
                    kwonlyargs=[],
                    kw_defaults=[],
                    kwarg=None,
                    defaults=[],
                ),
                body=[
                    # The import is placed inside the property body, so the
                    # service module is only loaded when the property is
                    # first accessed.
                    relative_import(module_name, class_name),
                    ast.Return(
                        value=ast.Call(
                            func=ast.Name(id=class_name),
                            args=[ast.Name(id="self")],
                            keywords=[],
                        )),
                ],
                decorator_list=[ast.Name(id="cached_property")],
                returns=ast.Constant(value=class_name, kind=None),
            ))

    nodes.append(class_node)

    return ast.Module(body=nodes, type_ignores=[])
コード例 #11
0
    def make_serialize_query_params(self, method: ServiceMethod, node,
                                    module_name: str):
        """Append a ``params = ...`` statement to ``node.body``.

        Query parameters are serialized via ``self._serialize_params``
        together with a marshmallow schema. ``file`` parameters are excluded,
        and so is ``version`` for update methods.

        Schema selection:

        * Exactly one trait contributes parameters and the method has no
          extra params: that trait schema is reused.
        * Otherwise: a dedicated schema class is generated and registered via
          ``self.add_schema``. It is based on the trait schemas, or on
          ``marshmallow.Schema`` + ``RemoveEmptyValuesMixin`` when there are
          none.

        With no query parameters at all the statement is ``params = {}``.
        """
        query_params = {}
        for candidate in method.query_params:
            if candidate.type != "file":
                query_params[candidate.name] = candidate

        # TODO: This should be fixed in the raml specifications since version is
        # part of the body and not part of the query parameters for update calls
        if method.type == "update":
            query_params.pop("version", None)

        # Nothing to serialize: emit `params = {}` and stop.
        if not query_params:
            node.body.append(
                ast.Assign(targets=[ast.Name(id="params")],
                           value=ast.Dict(keys=[], values=[])))
            return

        bases = [
            ast.Name(id="traits.%sSchema" % trait.class_name)
            for trait in method.traits
            if trait.params
        ]

        if not method.extra_params and len(bases) == 1:
            # A single trait schema covers every parameter; reuse it directly.
            schema_name = bases[0].id
        else:
            if method.type == "action":
                suffix = method.name.title()
            else:
                suffix = method.type.title()
            schema_name = f"_{method.context_name}{suffix}Schema"

            if not bases:
                self.add_import_statement(module_name, "marshmallow", "fields")
                self.add_import_statement(module_name, "marshmallow")

                bases = [
                    ast.Name(id="marshmallow.Schema"),
                    ast.Name(id="RemoveEmptyValuesMixin"),
                ]

            schema_node = ast.ClassDef(name=schema_name,
                                       bases=bases,
                                       keywords=[],
                                       decorator_list=[],
                                       body=[])

            # Field assignments go first; hook methods are appended after.
            hook_methods = []
            for extra in method.extra_params:
                # Files are posted in the request body, not serialized here.
                if extra.type == "file":
                    continue

                field_node, field_methods, field_imports = (
                    _create_schema_field(extra))
                if not field_node:
                    continue
                schema_node.body.append(field_node)
                hook_methods.extend(field_methods)
                for import_args in field_imports:
                    self.add_import_statement(module_name, *import_args)

            schema_node.body.extend(hook_methods)
            if not schema_node.body:
                schema_node.body.append(ast.Pass())

            self.add_schema(method.context_name, schema_node)

        # Dict key (string) -> name of the local variable holding its value.
        arguments = {}
        for param_name, query_param in query_params.items():
            if param_name.startswith("/"):
                placeholder = query_param.extra_data["(placeholderParam)"]
                param_name = snakeit(placeholder["paramName"])
            arguments[param_name] = snakeit(param_name)

        # params = self._serialize_params({...}, schema)
        node.body.append(
            ast.Assign(
                targets=[ast.Name(id="params")],
                value=ast.Call(
                    func=ast.Attribute(value=ast.Name(id="self"),
                                       attr="_serialize_params"),
                    args=[
                        ast.Dict(
                            keys=[ast.Str(s=key, kind="") for key in arguments],
                            values=[
                                ast.Name(id=var) for var in arguments.values()
                            ],
                        ),
                        ast.Name(id=schema_name),
                    ],
                    keywords=[],
                ),
            ))
コード例 #12
0
 def add_schema(self, context_name: str, schema_node: ast.ClassDef):
     """Register ``schema_node`` for the module derived from ``context_name``.

     Schemas are keyed by class name within the module, so a later schema
     with the same name replaces an earlier one.
     """
     module_schemas = self._schema_nodes[snakeit(context_name)]
     module_schemas[schema_node.name] = schema_node