def _param_ast(param: ServiceParameter) -> ast.arg:
    """Create the ``ast.arg`` node for a service parameter of a function.

    The argument is annotated when the parameter type is known:

    * placeholder parameters (``param.name`` starting with ``/``) become
      ``<snakeit(paramName)>: typing.Dict[str, str]``, where ``paramName`` is
      read from ``param.extra_data["(placeholderParam)"]``;
    * ``string`` -> ``OptionalListStr`` when ``param.multiple`` else ``str``;
    * ``number`` -> ``int``;
    * ``boolean`` -> ``bool``;
    * ``file`` -> ``typing.BinaryIO``;
    * any other type is left unannotated (``annotation=None``).
    """
    if param.name.startswith("/"):
        # Placeholder parameter: the argument name comes from the
        # "(placeholderParam)" annotation rather than from the raw name.
        annotation = ast.Subscript(
            value=ast.Attribute(value=ast.Name(id="typing"), attr="Dict"),
            slice=ast.Index(
                value=ast.Tuple(elts=[ast.Name(id="str"), ast.Name(id="str")])
            ),
        )
        return ast.arg(
            arg=snakeit(param.extra_data["(placeholderParam)"]["paramName"]),
            annotation=annotation,
        )

    annotation = None
    if param.type == "string":
        annotation = ast.Name(id="OptionalListStr" if param.multiple else "str")
    elif param.type == "number":
        annotation = ast.Name(id="int")
    elif param.type == "boolean":
        annotation = ast.Name(id="bool")
    elif param.type == "file":
        # A Name whose id contains a dot is not a valid identifier; build a
        # real attribute access (unparses to the same ``typing.BinaryIO``).
        annotation = ast.Attribute(value=ast.Name(id="typing"), attr="BinaryIO")
    return ast.arg(arg=snakeit(param.name), annotation=annotation)
def _get_resource_methods(self, service_domain, path, data):
    """Yield a ServiceMethod for every endpoint declared on a resource node.

    The ``post``, ``get`` and ``delete`` keys of *data* produce
    ``update``/``get``/``delete`` methods. Their names are built as
    ``<prefix><verb>_<snakeit(methodName)>``, with every ``with`` in the
    suffix replaced by ``by``. The prefix is the snake-cased context name
    plus ``_`` when the domain has a parent.

    Keys starting with ``/`` are nested action resources and are delegated
    to ``_get_action_method``; its result is yielded unchanged, even when it
    is ``None``. All other keys are ignored.
    """
    path_params = [*service_domain.path_parameters, *self._get_parameters(data)]
    suffix = f"_{snakeit(data['(methodName)'])}".replace("with", "by")
    item_type = _get_item_type(data)
    prefix = snakeit(service_domain.context_name) + "_" if service_domain.parent else ""

    # HTTP verb -> ServiceMethod.type
    verb_types = {"post": "update", "get": "get", "delete": "delete"}

    for key, endpoint in data.items():
        if key in verb_types:
            kind = verb_types[key]
            extra = {}
            if key == "post":
                # Only updates take a request body.
                update_type = data["type"][item_type]["resourceUpdateType"]
                extra["input_type"] = service_domain.resource_type + update_type
            service_method = ServiceMethod(
                name=prefix + kind + suffix,
                path=service_domain.path + path,
                path_params=list(path_params),
                type=kind,
                method=key,
                returns=_get_return_type(endpoint, service_domain.resource_type),
                **extra,
            )
            yield self._add_metadata(service_method, endpoint, data)
        elif key.startswith("/"):
            yield self._get_action_method(service_domain, key, endpoint, data)
def _get_action_method(self, service_domain, path, data, parent_data):
    """Build the "action" ServiceMethod for a nested resource path.

    The HTTP verb is ``post`` when only ``post`` is declared, or when both
    ``post`` and ``get`` are declared and ``post`` has truthy ``responses``.
    Otherwise it is ``get``. Returns ``None`` when neither verb is present.

    ``parent_data`` is accepted but not used; ``_add_metadata`` receives
    *data* (the action's own node) as its third argument.
    """
    has_get = "get" in data
    has_post = "post" in data
    if not (has_get or has_post):
        return None

    if has_post and (not has_get or data["post"].get("responses")):
        http_verb = "post"
    else:
        http_verb = "get"
    endpoint_data = data[http_verb]

    name = create_method_name(path)
    if service_domain.parent:
        name = "_".join((snakeit(service_domain.context_name), name))

    service_method = ServiceMethod(
        name=name,
        path=service_domain.path + path,
        path_params=[],
        query_params=[],
        type="action",
        method=http_verb,
        input_type=_get_input_type(endpoint_data),
        returns=_get_return_type(endpoint_data, service_domain.resource_type),
    )
    return self._add_metadata(service_method, endpoint_data, data)
def _get_domain_methods(self, service_domain, method, method_data, parent_data):
    """Build the collection-level ServiceMethod for an HTTP verb.

    ``get`` yields a ``query`` method returning ``resource_querytype``.
    ``post`` yields a ``create`` method taking ``resource_draft`` and
    returning ``resource_type``. Both are prefixed with the snake-cased
    context name plus ``_`` when the domain has a parent. Any other verb
    returns ``None``.
    """
    prefix = snakeit(service_domain.context_name) + "_" if service_domain.parent else ""

    if method == "get":
        kind = "query"
        return_type = service_domain.resource_querytype
        extra = {}
    elif method == "post":
        kind = "create"
        return_type = service_domain.resource_type
        extra = {"input_type": service_domain.resource_draft}
    else:
        return None

    service_method = ServiceMethod(
        name=prefix + kind,
        path=service_domain.path,
        path_params=list(service_domain.path_parameters),
        query_params=[],
        type=kind,
        method=method,
        returns=_get_return_type(method_data, return_type),
        **extra,
    )
    return self._add_metadata(service_method, method_data, parent_data)
def _create_endpoint_fstring(
        value: str) -> typing.Union[ast.Constant, ast.JoinedStr]:
    """Create the value for the `endpoint=..` parameter in the client.

    The leading ``/`` characters are stripped. If the remaining endpoint
    contains no ``{placeholder}``, a plain string constant is returned; this
    includes an empty endpoint. Otherwise an f-string (``ast.JoinedStr``) is
    built in which each placeholder becomes a reference to the variable
    ``snakeit(<placeholder>)``.

    Note that this assumes the parameters in the endpoint are kept 'intact':
    e.g. the RAML specifies ``/{container}/{key}``, so the variables
    ``container`` and ``key`` must be available in the scope where the
    generated f-string is evaluated.
    """
    value = value.lstrip("/")
    parts = []
    last = 0
    for match in re.finditer(r"{[^}]+}", value):
        # Skip empty literal fragments (leading or adjacent placeholders).
        if match.start() > last:
            parts.append(ast.Constant(value=value[last:match.start()]))
        identifier = snakeit(value[match.start() + 1:match.end() - 1])
        parts.append(
            ast.FormattedValue(value=ast.Name(id=identifier),
                               conversion=-1,
                               format_spec=None))
        last = match.end()
    if last < len(value):
        parts.append(ast.Constant(value=value[last:]))

    # Without any interpolated value a regular string is enough.
    if not any(isinstance(part, ast.FormattedValue) for part in parts):
        return ast.Constant(value=value, kind=None)
    return ast.JoinedStr(values=parts)
def _create_data_type(name: str, data: Dict[str, Any]):
    """Convert a raw RAML type definition into a ``DataType``.

    * Keys wrapped in parentheses, e.g. ``(package)``, are copied into
      ``obj.annotations`` with the parentheses removed.
    * ``package_name`` is ``"_" + snakeit(<package annotation or "base">)``.
    * ``enum``, ``discriminator`` and ``discriminatorValue`` are copied over.
    * Each entry in ``properties`` becomes a ``Property``:
      - a ``?`` suffix on the key marks it optional;
      - a ``[]`` suffix on the type marks it ``many``;
      - ``|``-separated type names each become an ``UnresolvedType``.
    """
    obj = DataType(name=name, type=data.get("type"))

    # Annotations: "(foo)" -> "foo"
    for raw_key, raw_value in data.items():
        if raw_key.startswith("("):
            obj.annotations[raw_key[1:-1]] = raw_value
    obj.package_name = "_" + snakeit(obj.annotations.get('package', 'base'))

    obj.enum = data.get("enum", [])
    obj.discriminator = data.get("discriminator")
    obj.discriminator_value = data.get("discriminatorValue")

    for prop_name, spec in (data.get("properties") or {}).items():
        if isinstance(spec, dict):
            type_expr = _determine_type(spec)
            raw_items = spec.get("items")
            # A falsy "items" value is kept as-is (e.g. None).
            item_names = [part.strip() for part in raw_items.split("|")] if raw_items else raw_items
        else:
            # Shorthand form: the value is the type expression itself.
            type_expr = spec
            item_names = []

        optional = prop_name.endswith("?")
        many = type_expr.endswith("[]")
        if many:
            type_expr = type_expr[:-2]

        obj.properties.append(
            Property(
                name=prop_name.rstrip("?"),
                optional=optional,
                types=[UnresolvedType(name=part.strip()) for part in type_expr.split("|")],
                many=many,
                items=item_names,
            )
        )
    return obj
def attribute_name(self) -> Optional[str]:
    """Return ``snakeit(self.name)`` if it is a non-empty identifier, else None.

    NOTE(review): ``str.isidentifier()`` is also True for keywords such as
    ``"class"``, so those are returned as well; confirm callers handle that.
    """
    candidate = snakeit(self.name)
    return candidate if candidate and candidate.isidentifier() else None
def add_service(self, service: ServiceDomain):
    """Build the class node for *service* and register it if one is produced.

    When ``create_class`` returns a falsy value, nothing is recorded.
    Otherwise the service is stored under its context name and the node is
    appended to the list for the service's snake-cased module.
    """
    module = snakeit(service.context_name)
    class_node = self.create_class(service, module)
    if not class_node:
        return
    self._services[service.context_name] = service
    self._service_nodes[module].append(class_node)
def _create_schema_field(param: ServiceParameter):
    """Generate a field assignment for a marshmallow schema.

    Returns a 3-tuple ``(node, methods, imports)``:

    * ``node`` -- an ``ast.Assign`` of the form ``<code_name> = <field call>``,
      or ``None`` for ``file`` parameters.
    * ``methods`` -- extra schema method definitions (only produced for
      placeholder parameters whose name starts with ``/``).
    * ``imports`` -- tuples the caller passes as
      ``add_import_statement(module_name, *tuple)``.

    Raises ``NotImplementedError`` for parameter types without a mapping.
    """
    keywords = []
    imports = []
    methods = []

    code_name = snakeit(param.name)
    if code_name != param.name:
        # The Python attribute name differs from the wire name; map it back.
        keywords.append(
            ast.keyword(arg="data_key",
                        value=ast.Constant(value=param.name, kind=None)))
    if not param.required:
        keywords.append(
            ast.keyword(arg="required",
                        value=ast.Constant(value=False, kind=None)))

    if param.name.startswith("/"):
        # Placeholder parameter: expose it as a single ``fields.Dict``
        # attribute and generate hooks that flatten it into prefixed keys on
        # dump and collect the prefixed keys back into a dict on load.
        # NOTE(review): the data_key/required keywords collected above are not
        # applied to this field; confirm that is intended.
        placeholder = param.extra_data["(placeholderParam)"]
        code_name = snakeit(placeholder["paramName"])
        imports.append(("marshmallow", "fields"))
        imports.append(("marshmallow", ))

        serialize_func = ast.Call(
            func=ast.Attribute(value=ast.Name(id="fields"), attr="Dict"),
            args=[],
            keywords=[],
        )

        # TODO: can break if there is a suffix
        key_name = placeholder["template"].replace(
            "<%s>" % placeholder["placeholder"], "")
        # pop() uses a default so the hook tolerates the key having been
        # omitted or already removed from the dumped data.
        code = ast.parse(
            textwrap.dedent(
                """
                @marshmallow.post_dump
                def _%(target_name)s_post_dump(self, data, **kwargs):
                    values = data.pop('%(target_name)s', None)
                    if not values:
                        return data
                    for key, val in values.items():
                        data[f"%(key_name)s{key}"] = val
                    return data

                @marshmallow.pre_load
                def _%(target_name)s_post_load(self, data, **kwargs):
                    items = {}
                    for key in list(data.keys()):
                        if key.startswith("%(key_name)s"):
                            items[key[%(key_len)d:]] = data[key]
                            del data[key]
                    data["%(target_name)s"] = items
                    return data
                """ % {
                    "target_name": code_name,
                    "key_name": key_name,
                    "key_len": len(key_name),
                }))
        methods.extend(code.body)

    elif param.type == "string":
        imports.append(("commercetools.helpers", "OptionalList"))
        imports.append(("marshmallow", "fields"))
        serialize_func = ast.Call(
            func=ast.Name(id="OptionalList"),
            args=[
                ast.Call(
                    func=ast.Attribute(value=ast.Name(id="fields"), attr="String"),
                    args=[],
                    keywords=[],
                )
            ],
            keywords=keywords,
        )

    elif param.type == "number":
        imports.append(("marshmallow", "fields"))
        serialize_func = ast.Call(
            func=ast.Attribute(value=ast.Name(id="fields"), attr="Int"),
            args=[],
            keywords=keywords,
        )

    elif param.type == "boolean":
        keywords.append(
            ast.keyword(arg="missing",
                        value=ast.Constant(value=False, kind=None)))
        imports.append(("marshmallow", "fields"))
        serialize_func = ast.Call(
            func=ast.Attribute(value=ast.Name(id="fields"), attr="Bool"),
            args=[],
            keywords=keywords,
        )

    elif param.type == "file":
        # Files are posted in the request body, not serialized by the schema.
        # Return the same 3-tuple shape as every other path so callers can
        # always unpack the result.
        return None, [], []

    else:
        raise NotImplementedError(param)

    node = ast.Assign(targets=[ast.Name(id=code_name)], value=serialize_func)
    return node, methods, imports
def _generate_init_file(services, modules):
    """Generate the __init__.py module which contains the ServicesMixin.

    This automates the addition of new services to the client. For every
    service, plus the manually written ``project`` service, a
    ``cached_property`` is generated. It imports the service class from its
    relative submodule and returns ``<ServiceClass>(self)``. The same classes
    are also imported under ``if typing.TYPE_CHECKING:`` so the string return
    annotations resolve for type checkers.

    :param services: mapping whose values expose ``context_name``.
    :param modules: mapping from ``snakeit(context_name)`` to a dict with at
        least a ``"name"`` key (the generated submodule name).
    :returns: the ``ast.Module`` for the generated file.
    """
    nodes = [
        ast.Import(names=[ast.alias(name="typing", asname=None)]),
        ast.ImportFrom(
            module="cached_property",
            names=[ast.alias(name="cached_property", asname=None)],
            level=0,
        ),
    ]

    # Collect all submodules, keyed by their relative import path.
    submodules = {}
    for service in services.values():
        module_name = snakeit(service.context_name)
        info = modules[module_name]
        submodules[".%s" % info["name"]] = {
            "module_name": info["name"],
            "class_name": service.context_name + "Service",
            "var_name": module_name,
        }

    # Add manual generated files (TODO)
    submodules[".project"] = {
        "module_name": "project",
        "class_name": "ProjectService",
        "var_name": "project",
    }

    # Generate TYPE_CHECKING import statements (these will be sorted by isort).
    # The relative path is carried in ``module`` with ``level=0``, which
    # unparses as ``from .x import X``.
    nodes.append(
        ast.If(
            test=ast.Attribute(value=ast.Name(id="typing"), attr="TYPE_CHECKING"),
            body=[
                ast.ImportFrom(
                    module=name,
                    names=[ast.alias(name=service["class_name"], asname=None)],
                    level=0,
                )
                for name, service in submodules.items()
            ],
            orelse=[],
        ))

    # Return the class + properties
    class_node = ast.ClassDef(name="ServicesMixin",
                              bases=[],
                              keywords=[],
                              decorator_list=[],
                              body=[])
    for name, service in submodules.items():
        node = ast.FunctionDef(
            name=service["module_name"],
            args=ast.arguments(
                posonlyargs=[],
                args=[ast.arg(arg="self", annotation=None)],
                vararg=None,
                kwonlyargs=[],
                kw_defaults=[],
                kwarg=None,
                defaults=[],
            ),
            body=[
                # The service class is imported inside the property body.
                ast.ImportFrom(
                    module=name,
                    names=[ast.alias(name=service["class_name"], asname=None)],
                    level=0,
                ),
                ast.Return(
                    ast.Call(
                        func=ast.Name(id=service["class_name"]),
                        args=[ast.Name(id="self")],
                        keywords=[],
                    )),
            ],
            decorator_list=[ast.Name(id="cached_property")],
            returns=ast.Constant(value=service["class_name"], kind=None),
        )
        class_node.body.append(node)

    nodes.append(class_node)
    return ast.Module(body=nodes, type_ignores=[])
def make_serialize_query_params(self, method: ServiceMethod, node,
                                module_name: str):
    """Append a ``params = ...`` statement for a client method to ``node.body``.

    Code to serialize optional parameters to the `params` dict passed to the
    client post/get call. With no (non-file) query parameters the statement
    is ``params = {}``. Otherwise it is
    ``params = self._serialize_params({...}, <SchemaName>)``.

    The schema is one of two things:

    * the single ``traits.<Name>Schema`` of the method's traits that declare
      params, when there is exactly one such trait and no ``extra_params``;
    * otherwise a custom marshmallow schema class generated here and
      registered through ``self.add_schema``. Its bases are the trait
      schemas, or ``marshmallow.Schema`` + ``RemoveEmptyValuesMixin`` when
      there are none.

    Required imports are registered via ``self.add_import_statement``.
    """
    # File parameters are excluded: they are not serialized as query params.
    query_params = {
        param.name: param
        for param in method.query_params if param.type != "file"
    }

    # TODO: This should be fixed in the raml specifications since version is
    # part of the body and not part of the query parmaters for update calls
    if method.type == "update" and "version" in query_params:
        del query_params["version"]

    # If this method doesn't accept parameters we just exit early with a
    # `params = {}` line.
    if not query_params:
        line = ast.Assign(targets=[ast.Name(id="params")],
                          value=ast.Dict(keys=[], values=[]))
        node.body.append(line)
        return

    # One base schema per trait that declares params.
    bases = []
    for trait in method.traits:
        if trait.params:
            bases.append(ast.Name(id="traits.%sSchema" % trait.class_name))

    # Generate a custom schema if required
    if method.extra_params or len(bases) != 1:
        # Actions are named after the method; other types after the type.
        if method.type != "action":
            schema_name = f"_{method.context_name}{method.type.title()}Schema"
        else:
            schema_name = f"_{method.context_name}{method.name.title()}Schema"

        if not bases:
            self.add_import_statement(module_name, "marshmallow", "fields")
            self.add_import_statement(module_name, "marshmallow")
            bases = [
                ast.Name(id="marshmallow.Schema"),
                ast.Name(id="RemoveEmptyValuesMixin"),
            ]

        schema_node = ast.ClassDef(name=schema_name,
                                   bases=bases,
                                   keywords=[],
                                   decorator_list=[],
                                   body=[])

        # Marshmallow field definitions
        schema_methods = []
        for param in method.extra_params:
            # We skip files since we post the value in the request body
            if param.type == "file":
                continue
            field_node, methods, imports = _create_schema_field(param)
            if field_node:
                schema_node.body.append(field_node)
            schema_methods.extend(methods)
            for import_ in imports:
                self.add_import_statement(module_name, *import_)

        # Helper methods (placeholder hooks) go after all field definitions.
        schema_node.body.extend(schema_methods)
        # A class body cannot be empty.
        if not schema_node.body:
            schema_node.body.append(ast.Pass())

        self.add_schema(method.context_name, schema_node)
    else:
        # Exactly one trait schema and no extra params: reuse it directly.
        schema_name = bases[0].id

    # params = self._serialize_params({}, schema)
    # Dict entries map "<key>" -> variable <snakeit(key)>. For placeholder
    # params ("/"-prefixed names) the key is snakeit(paramName) instead.
    # NOTE(review): for ordinary params the key is the raw RAML name, while
    # generated schema fields use the snake_case name with data_key=<raw
    # name>; confirm this matches how _serialize_params uses the schema.
    input_params = {}
    for key, param in query_params.items():
        if key.startswith("/"):
            key = snakeit(
                param.extra_data["(placeholderParam)"]["paramName"])
        input_params[key] = snakeit(key)

    line = ast.Assign(
        targets=[ast.Name(id="params")],
        value=ast.Call(
            func=ast.Attribute(value=ast.Name(id="self"),
                               attr="_serialize_params"),
            args=[
                ast.Dict(
                    keys=[
                        ast.Str(s=val, kind="")
                        for val in input_params.keys()
                    ],
                    values=[
                        ast.Name(id=val) for val in input_params.values()
                    ],
                ),
                ast.Name(id=schema_name),
            ],
            keywords=[],
        ),
    )
    node.body.append(line)
def add_schema(self, context_name: str, schema_node: ast.ClassDef):
    """Register *schema_node* under its class name for *context_name*'s module.

    The module key is ``snakeit(context_name)``. A schema with the same class
    name in that module is replaced.
    """
    module_schemas = self._schema_nodes[snakeit(context_name)]
    module_schemas[schema_node.name] = schema_node