def get_args_kwargs(cls, loader: yaml.Loader, node: yaml.Node,
                    signature: Signature = free_signature) -> tp.Optional[BoundArguments]:
    """Bind the contents of a YAML ``node`` to ``signature``.

    * Scalar nodes are resolved and constructed, then bound as a single
      positional argument; ``None`` is returned when the value is ``None``.
    * Sequence nodes bind their items positionally.
    * Mapping nodes bind their entries as keyword arguments. An optional
      ``__args`` entry supplies positional arguments: either a plain YAML
      sequence, or any other node whose constructed value is unpacked.

    Sub-nodes are constructed in YAML document order after binding, and the
    constructed values replace the nodes in the returned ``BoundArguments``.

    :raises ValueError: if ``node`` is not a scalar, sequence or mapping node.
    :raises TypeError: if the node contents do not match ``signature``.
    """
    if isinstance(node, yaml.ScalarNode):
        try:
            # This sometimes fails because of ruamel/yaml/resolver.py line 370
            node.tag = loader.resolver.resolve(yaml.ScalarNode, node.value, [True, False])
        except IndexError:
            node.tag = loader.DEFAULT_SCALAR_TAG
        val = loader.construct_object(node)
        if val is None:
            val = loader.yaml_constructors[node.tag](loader, node)
        return None if val is None else signature.bind(val)
    elif isinstance(node, yaml.SequenceNode):
        bound = signature.bind(*node.value)
        args = []
        subnodes = node.value
    elif isinstance(node, yaml.MappingNode):
        # Construct the keys
        kwargs = {loader.construct_object(key, deep=True): val for key, val in node.value}
        args = kwargs.setdefault('__args', yaml.SequenceNode('tag:yaml.org,2002:seq', []))
        args_is_seq = isinstance(args, yaml.SequenceNode) and args.tag == 'tag:yaml.org,2002:seq'
        if args_is_seq:
            kwargs['__args'] = args.value
        # Extract nodes in order (nodes are not iterable, so only "flattens" __args)
        subnodes = list(flatten(kwargs.values()))
        __args = kwargs.pop('__args')
        bound = signature.bind_partial(*(__args if args_is_seq else ()), **kwargs)
    else:
        raise ValueError(f'Invalid node type, {node}')

    # Experimental
    cls.fix_signature(bound)

    # Construct nodes in yaml order
    subnode_values = {n: loader.construct_object(n, deep=True) for n in subnodes}
    for key, val in bound.arguments.items():
        kind = signature.parameters[key].kind
        if kind == Parameter.VAR_POSITIONAL:
            # Materialize as a tuple: a generator here would be exhausted the
            # first time BoundArguments.args (or repr) iterated it.
            bound.arguments[key] = tuple(subnode_values[n] for n in val)
        elif kind == Parameter.VAR_KEYWORD:
            # Explicit branch: an empty dict must not fall through to the
            # scalar lookup below (a dict is not a valid key).
            bound.arguments[key] = {name: subnode_values[n] for name, n in val.items()}
        else:
            bound.arguments[key] = subnode_values[val]

    # A non-sequence ``__args`` node was constructed like any other sub-node;
    # its value supplies the positional arguments.
    if args and args in subnode_values:
        return bound.signature.bind(*subnode_values[args], **bound.kwargs)
    return bound
def __init__(self, *args, **kwargs):
    """Populate one attribute per entry of ``self._fields``.

    Every name in ``_fields`` becomes a required positional-or-keyword
    parameter. The call arguments are bound to those parameters and each
    bound value is stored on ``self`` under the parameter's name.

    :raises TypeError: if the arguments do not match the fields.
    """
    kind = Parameter.POSITIONAL_OR_KEYWORD
    signature = Signature([Parameter(field_name, kind) for field_name in self._fields])
    bound = signature.bind(*args, **kwargs).arguments
    for attr_name in bound:
        setattr(self, attr_name, bound[attr_name])
def _args_kwargs_to_normalized_kwargs(self, sig: inspect.Signature, args: Tuple[Argument, ...],
                                      kwargs: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    """
    Given a call target, args, and kwargs, return the arguments normalized into
    a single kwargs dict, or None if the type signature is not supported by
    this normalization.

    Args:
        sig (inspect.Signature): Signature object for the call target
        args (Tuple): Positional arguments that appear at the callsite
        kwargs (Dict): Keyword arguments that appear at the callsite

    Returns:
        Optional[Dict]: A dict mapping every parameter of ``sig``, in signature
        order and with defaults filled in, to its value. Returns ``None`` when
        ``sig`` has a positional-only, ``*args`` or ``**kwargs`` parameter.

    Raises:
        TypeError: if ``args``/``kwargs`` cannot be bound to ``sig``.
    """
    # Only plain named parameters can be expressed purely as keywords.
    allowed_kinds = (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
    parameters = sig.parameters
    if not all(p.kind in allowed_kinds for p in parameters.values()):
        return None

    bound = sig.bind(*args, **kwargs)
    # Fill in defaults so every parameter of the signature has a value.
    bound.apply_defaults()
    values = bound.arguments
    return {name: values[name] for name in parameters}
def _get_arg_by_name(call: nodes.CallExpr, name: str, sig: inspect.Signature) -> Optional[nodes.Expression]:
    """
    Get value of argument from a call.

    Arguments whose entry in ``call.arg_names`` is ``None`` are treated as
    positional and the rest as keyword arguments. Both are then bound to ``sig``.

    :return: The argument value, or ``None`` if the call does not bind to
        ``sig`` or ``name`` was not bound.

    .. warning::

        This probably doesn't yet work for calls with ``*args`` and/or ``**kwargs``.
    """
    pairs = list(zip(call.arg_names, call.args))
    positional = [value for arg_name, value in pairs if arg_name is None]
    by_keyword = {arg_name: value for arg_name, value in pairs if arg_name is not None}

    try:
        bound = sig.bind(*positional, **by_keyword)
    except TypeError:
        return None
    return bound.arguments.get(name)
class MethodProxy:
    """Callable proxy that forwards calls of ``method`` to ``bridge``.

    The proxy exposes a ``__signature__`` built from ``method.arguments``, with
    a leading positional-only ``self``, so ``inspect.signature`` reports the
    proxied parameters. It is also a descriptor: attribute access through an
    instance yields a callable bound to that instance.
    """

    def __init__(self, interface, method, bridge, name_converter):
        self.interface = interface
        self.method = method
        self.bridge = bridge
        params = [Parameter(name_converter.parameter_name(name),
                            Parameter.POSITIONAL_OR_KEYWORD,
                            annotation=TypeWrapper(tp, name_converter))
                  for tp, name in method.arguments]
        params.insert(0, Parameter('self', Parameter.POSITIONAL_ONLY,
                                   annotation=TypeWrapper(self.interface, name_converter)))
        return_type = TypeWrapper(method.return_type, name_converter)
        self.__signature__ = Signature(parameters=params, return_annotation=return_type)

    def __call__(self, this, *args, **kwargs):
        # Validate the call against the proxied signature before forwarding;
        # ``bind`` raises TypeError on a mismatch.
        bound_values = self.__signature__.bind(this, *args, **kwargs)
        return self.bridge.call_method(self.method, this, bound_values.arguments)

    def __get__(self, instance, cls):
        # Compare against None explicitly: a falsy instance (one defining
        # ``__bool__`` or ``__len__``) must still produce a bound method.
        if instance is None:
            return self

        def bound_method_proxy(*args, **kwargs):
            return self(instance, *args, **kwargs)

        bound_method_proxy.__signature__ = self.__signature__
        return bound_method_proxy
def freeze_arguments(signature: inspect.Signature, *args, **kwargs) -> inspect.BoundArguments:
    """Bind ``args``/``kwargs`` to ``signature`` and fill in omitted defaults.

    Returns the resulting ``BoundArguments``. Callers that only use this
    function to validate a call can keep ignoring the return value.

    :raises TypeError: if the arguments do not match ``signature``.
    """
    # TODO : review usage...
    boundargs = signature.bind(*args, **kwargs)
    boundargs.apply_defaults()  # in case some arguments were omitted during the bind
    return boundargs
class LibFunction(object):
    """Python-callable wrapper around a foreign (C) function ``c_func``.

    ``sig`` supplies three things: the Python/C argument descriptions used in
    the docstring, the input handlers that define the Python-level signature,
    and the conversions applied before and after the C call.
    """

    def __init__(self, name, c_name, sig, c_func):
        self.sig = sig
        self.name = name
        self.c_name = c_name
        self.c_func = c_func
        py_sig_text = '{}({}) -> {}'.format(name, sig.args_py_str(), sig.rets_py_str())
        c_sig_text = '{}({})'.format(c_name, sig.args_c_str())
        self.__doc__ = '\n\nOriginal C Function: '.join((py_sig_text, c_sig_text))
        # inspect.Signature is only available from Python 3.3 onwards.
        if sys.version_info >= (3,3):
            self._assign_signature()

    def _assign_signature(self):
        """Expose one positional-or-keyword parameter per input handler."""
        from inspect import Parameter, Signature
        kind = Parameter.POSITIONAL_OR_KEYWORD
        self.__signature__ = Signature(
            [Parameter(handler.c_argname, kind) for handler in self.sig.in_handlers])

    def __get__(self, instance, owner):
        # Class access returns the function itself; instance access binds it.
        return self if instance is None else LibMethod(instance, self)

    def __call__(self, *args, **kwds):
        return self._call(args, kwds)

    def _call(self, args, kwds, niceobj=None):
        """Convert arguments to C, call ``c_func`` and extract the outputs."""
        if sys.version_info < (3,3):
            if kwds:
                raise TypeError('Keyword args in LibFunctions are not supported for Python versions '
                                'before 3.3')
        else:
            # Resolve keywords into positional order via the signature.
            args = self.__signature__.bind(*args, **kwds).args

        n_given = len(args)
        check_num_args(self.name, n_given, self.sig.num_inargs, self.sig.variadic)
        if n_given != self.sig.num_inargs:
            raise TypeError("{}() takes {} arguments ({} given)"
                            "".format(self.name, self.sig.num_inargs, n_given))

        c_args = self.sig.make_c_args(args)

        # TODO: Add a custom object for logging these messages, to allow both filtering
        # and to avoid this text formatting unless it's needed
        log.info('Calling {}({})'.format(self.name, ', '.join(map(repr, args))))
        retval = self.c_func(*c_args)

        return self.sig.extract_outputs(c_args, retval, {
            'niceobj': niceobj,
            'funcname': self.name,
            'funcargs': c_args,
        })
def func(*args, **kwargs):
    """Bind the call to the signature ``(x, y=24, z=None)`` and print each bound argument.

    Only explicitly supplied arguments are printed, because defaults are not
    applied. Raises TypeError if the call does not fit the signature.
    """
    kind = Parameter.POSITIONAL_OR_KEYWORD
    sig = Signature([
        Parameter('x', kind),
        Parameter('y', kind, default=24),
        Parameter('z', kind, default=None),
    ])
    bound = sig.bind(*args, **kwargs)
    for arg_name in bound.arguments:
        print(arg_name, bound.arguments[arg_name])
def bind_args(sig: Signature, call: ast.Call) -> BoundArguments:
    """Binds the arguments of a function call to that function's signature.

    Positional arguments are bound as their AST nodes. Named keywords are bound
    by name. Keywords passed through ``**mapping`` have an ``arg`` of ``None``;
    they are unsupported and skipped.

    @raise TypeError: If the arguments do not match the signature.
    """
    named = {}
    for keyword in call.keywords:
        if keyword.arg is None:
            # '**kwargs' expansion: not supported, so it is ignored.
            continue
        named[keyword.arg] = keyword.value
    return sig.bind(*call.args, **named)
def _get_bound_arguments(sig: Signature, *args, **kwargs) -> OrderedDict:
    """
    Bind the arguments based on a method signature, with defaults applied.

    :param sig: A method signature
    :param args: positional arguments
    :param kwargs: keyword arguments
    :return: An OrderedDict containing a mapping from variable names to the
        values they will have if the method is called with the passed arguments
    :raises TypeError: if the arguments do not match ``sig``
    """
    import collections

    bound = sig.bind(*args, **kwargs)
    bound.apply_defaults()
    # Since Python 3.9 ``BoundArguments.arguments`` is a plain dict. Build a
    # real OrderedDict so the declared return type holds at runtime.
    return collections.OrderedDict(bound.arguments)
def _extract_labels(sig: inspect.Signature, labels: _LabelSpec, *args, **kwargs):
    """Compute label values for a call to a function with signature ``sig``.

    The call arguments are bound to ``sig`` with defaults applied.

    * When ``labels`` is a mapping, each value is either a callable, which is
      called with the ``BoundArguments``, or the name of a parameter whose
      bound value is used.
    * Otherwise ``labels`` is an iterable of parameter names, each mapped to
      its own bound value.

    Raises TypeError if the arguments do not match ``sig``, and KeyError if a
    referenced parameter name is not in the signature.
    """
    bound = sig.bind(*args, **kwargs)
    bound.apply_defaults()
    values = bound.arguments

    if not isinstance(labels, collections.abc.Mapping):
        return {label: values[label] for label in labels}

    return {
        label: spec(bound) if callable(spec) else values[spec]
        for label, spec in labels.items()
    }
def __init__(self, *l, **k):
    """Bind constructor arguments to ``_fields`` (required) and ``_options`` (optional).

    Each option defaults to ``self._defaults.get(name)``, which is ``None``
    when no default is declared. The ``BoundArguments`` (defaults applied) is
    kept in ``self._args``, and every bound value is also stored as an
    attribute of the same name.
    """
    kind = Parameter.POSITIONAL_OR_KEYWORD
    parameters = [Parameter(name, kind) for name in self._fields]
    for name in self._options:
        parameters.append(Parameter(name, kind, default=self._defaults.get(name, None)))

    bound = self._args = Signature(parameters).bind(*l, **k)
    bound.apply_defaults()
    for key in bound.arguments:
        setattr(self, key, bound.arguments[key])
def __init__(self, input_infos: List[ModelInputInfo], fwd_signature: Signature,
             module_ref_for_device: torch.nn.Module = None):
    """Map the parameters of ``fwd_signature`` to the given input infos.

    Input infos without a ``keyword`` are bound positionally, in list order.
    Those with a ``keyword`` are bound by that keyword. Raises TypeError if the
    infos do not fit ``fwd_signature``.
    """
    self._module_ref_for_device = module_ref_for_device
    positional_infos = tuple(info for info in input_infos if info.keyword is None)
    keyword_infos = OrderedDict(
        (info.keyword, info) for info in input_infos if info.keyword is not None)
    bound_params = fwd_signature.bind(*positional_infos, **keyword_infos)
    self._fwd_params_to_input_infos_odict = bound_params.arguments
    self._fwd_signature = fwd_signature  # type: Signature
def is_signature_compatible_with_types(signature: inspect.Signature, *args,
                                       **kwargs) -> bool:
  """Determines if functions matching signature accept `args` and `kwargs`.

  Args:
    signature: An instance of `inspect.Signature` to verify against the
      arguments.
    *args: Zero or more positional arguments, all of which must be instances of
      computation_types.Type or something convertible to it by
      computation_types.to_type().
    **kwargs: Zero or more keyword arguments, all of which must be instances of
      computation_types.Type or something convertible to it by
      computation_types.to_type().

  Returns:
    `True` or `False`, depending on the outcome of the test.

  Raises:
    TypeError: if the arguments are of the wrong computation_types.
  """
  try:
    bound = signature.bind(*args, **kwargs)
  except TypeError:
    return False

  # `bind` has already rejected arity and name mismatches. What remains: for
  # every parameter with a non-None default that was explicitly given an
  # argument, the argument's type must accept the default's inferred type.
  # Signatures without defaults skip every iteration and yield True.
  for param in signature.parameters.values():
    default = param.default
    if default is inspect.Parameter.empty or default is None:
      # No default value or optional.
      continue
    value = bound.arguments.get(param.name, default)
    if value is default:
      continue
    if not computation_types.to_type(value).is_assignable_from(
        type_conversions.infer_type(default)):
      return False
  return True
def bind(self, signature: inspect.Signature, params: Optional[Union[list, dict]]) -> inspect.BoundArguments:
    """
    Binds parameters to method.

    Lists and tuples are bound positionally and dicts by keyword. Any other
    value, including ``None``, binds no arguments at all.

    :param signature: method to bind parameters to
    :param params: parameters to be bound
    :raises: ValidationError if parameters binding failed
    :returns: bound parameters
    """
    if isinstance(params, (list, tuple)):
        positional, named = params, {}
    elif isinstance(params, dict):
        positional, named = (), params
    else:
        positional, named = (), {}

    try:
        return signature.bind(*positional, **named)
    except TypeError as e:
        raise ValidationError(str(e)) from e
def __init__(self, *l, **k):
    """Bind constructor arguments to this node's declared fields and options.

    * Each entry of ``self._fields`` is a positional-or-keyword parameter. An
      entry prefixed with ``"*"`` becomes a ``*args``-style parameter, stored
      as a tuple under the name without the star.
    * Each entry of ``self._options`` is an optional parameter defaulting to ``None``.
    * ``src`` is an extra keyword-only parameter defaulting to ``None``.

    The ``BoundArguments`` (defaults applied) is kept in ``self._args``. Each
    bound value is also set as an attribute of the same name, and ``self._at``
    starts as an empty set.
    """
    params = []
    plain_fields = []
    starred = False
    for name in self._fields:
        if name.startswith("*"):
            starred = True
            name = name[1:]
            params.append(Parameter(name, Parameter.VAR_POSITIONAL))
        else:
            params.append(Parameter(name, Parameter.POSITIONAL_OR_KEYWORD))
        plain_fields.append(name)
    if starred:
        # Store the star-stripped names on the instance only; the class-level
        # declaration keeps its markers. Building a fresh list once (instead of
        # copying and assigning by index) also works when ``_fields`` is a tuple.
        self._fields = plain_fields

    params.extend(Parameter(name, Parameter.POSITIONAL_OR_KEYWORD, default=None)
                  for name in self._options)
    params.append(Parameter("src", Parameter.KEYWORD_ONLY, default=None))

    args = self._args = Signature(params).bind(*l, **k)
    args.apply_defaults()
    for key, val in args.arguments.items():
        setattr(self, key, val)
    self._at = set()
def _args_kwargs_to_normalized_args_kwargs(
        sig: inspect.Signature, args: Tuple[Any, ...], kwargs: Dict[str, Any],
        normalize_to_only_use_kwargs: bool) -> Optional[ArgsKwargsPair]:
    """
    Given a call target, args, and kwargs, return the arguments normalized into
    an ArgsKwargsPair, or None if the type signature is not supported by this
    normalization.

    Args:
        sig (inspect.Signature): Signature object for the call target
        args (Tuple): Positional arguments that appear at the callsite
        kwargs (Dict): Keyword arguments that appear at the callsite
        normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs.

    Returns:
        Optional[ArgsKwargsPair]: Normalized args and kwargs (defaults filled
        in), or `None` if ``sig`` has a positional-only, ``*args`` or
        ``**kwargs`` parameter.

    Raises:
        TypeError: if ``args``/``kwargs`` cannot be bound to ``sig``.
    """
    allowed_kinds = (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
    if not all(p.kind in allowed_kinds for p in sig.parameters.values()):
        return None

    bound = sig.bind(*args, **kwargs)
    bound.apply_defaults()
    values = bound.arguments

    # Unless everything must become a keyword, the first len(args) parameters
    # stay positional, mirroring the call site.
    n_positional = 0 if normalize_to_only_use_kwargs else len(args)
    names = list(sig.parameters)
    new_args = tuple(values[name] for name in names[:n_positional])
    new_kwargs = {name: values[name] for name in names[n_positional:]}
    return ArgsKwargsPair(new_args, new_kwargs)
def _bound_and_values(
    signature: inspect.Signature,
    args: typing.Tuple,
    kwargs: Kwargs,
) -> typing.Tuple[typing.Tuple, Kwargs, Kwargs]:
    """Normalize a call and split off its variadic parts.

    Returns ``(var_args, var_kwargs, values)``:

    * ``var_args``: the value bound to the ``*args`` parameter, or ``()`` if
      the signature has none.
    * ``var_kwargs``: the value bound to the ``**kwargs`` parameter, or ``{}``
      if the signature has none.
    * ``values``: a copy of the remaining bound arguments, with defaults applied.

    Raises TypeError if the call does not match ``signature``.
    """
    # The signature lets us regularize the call and apply any defaults
    bound_arguments = signature.bind(*args, **kwargs)
    bound_arguments.apply_defaults()

    # *args and **kwargs are never used in the matching, just passed to the
    # underlying function, so they are removed from ``values``.
    values = bound_arguments.arguments.copy()
    var_args = ()
    var_kwargs = {}
    for param in signature.parameters.values():
        if param.kind is inspect.Parameter.VAR_POSITIONAL:
            var_args = values.pop(param.name)
        elif param.kind is inspect.Parameter.VAR_KEYWORD:
            var_kwargs = values.pop(param.name)
    return var_args, var_kwargs, values
def bind_args_kwargs(sig: inspect.Signature, *args: typing.Any, **kwargs: typing.Any) -> typing.List[BoundParameter]:
    """Bind *args and **kwargs to signature and get Bound Parameters.

    Every parameter of ``sig`` appears in the result, in signature order.
    Parameters that were not bound get their ``default``, which is
    ``inspect.Parameter.empty`` when none is declared.

    :param sig: source signature
    :type sig: inspect.Signature
    :param args: positional arguments
    :type args: typing.Any
    :param kwargs: keyword arguments
    :type kwargs: typing.Any
    :return: List of bound parameters with all information about them
    :rtype: typing.List[BoundParameter]

    .. versionadded:: 3.3.0
    .. versionchanged:: 5.3.1 return list
    """
    bound: typing.Mapping[str, typing.Any] = sig.bind(*args, **kwargs).arguments
    return [
        BoundParameter(parameter=param, value=bound.get(param.name, param.default))
        for param in sig.parameters.values()
    ]
def _log_wrapper(func: typing.Callable, type_: str, sig: inspect.Signature,
                 args: typing.Tuple[typing.Any, ...], kwargs: typing.Dict[str, typing.Any],
                 list_: bool) -> typing.ContextManager[typing.List[str]]:
    """Generator that logs the BEGIN/END/exception of a call to ``func``.

    When ``sig`` declares ``log_opts`` or ``log`` and ``kwargs`` has no truthy
    value for it, a value is filled in. This mutates the ``kwargs`` dict passed
    in. The arguments are then bound to ``sig`` and ``(results, arguments)``
    is yielded; ``results`` is a list the caller presumably appends the call's
    return value(s) to. On normal exit a single collected result is logged
    bare, several as a tuple. An exception raised while suspended at the
    ``yield`` is logged and re-raised. Nothing is logged unless the logger is
    enabled for ``const.LogLevel.LOGGED_FUNC``.

    NOTE(review): the return annotation says ContextManager, but this is a
    generator function; presumably it is wrapped with
    ``contextlib.contextmanager`` where it is used. Confirm.
    NOTE(review): ``list_`` is not used in this body.
    """
    log = logging.getLogger(func.__module__)
    # Inject defaults only when the wrapped function accepts them and the
    # caller did not supply a truthy value.
    if 'log_opts' in sig.parameters and not kwargs.get('log_opts', None):
        kwargs['log_opts'] = _log_opts()
    if 'log' in sig.parameters and not kwargs.get('log', None):
        kwargs['log'] = log
    bargs = sig.bind(*args, **kwargs)
    logged_func_enabled = log.isEnabledFor(const.LogLevel.LOGGED_FUNC)
    results = []
    if logged_func_enabled:
        log.logged_func(f'{type_} {func.__name__} [BEGIN]: {bargs!r}')
        start = arrow.get()
    try:
        yield results, bargs.arguments
    except Exception as e:
        if logged_func_enabled:
            log.logged_func(
                f'{type_} {func.__name__} [{e.__class__.__name__}]: {e!s}')
        raise
    else:
        if logged_func_enabled:
            # Log a lone result unwrapped; otherwise log the collected results as a tuple.
            if len(results) == 1:
                results = results[0]
            else:
                results = tuple(results)
            end_msg = f'returns: {results!r}'
            interval = arrow.get() - start
            log.logged_func(
                f'{type_} {func.__name__} [END] ({interval}): {end_msg}')
def to_bound_arguments(
        self,
        signature: inspect.Signature,
        partial: bool = False,
) -> inspect.BoundArguments:
    """
    Generates an instance of :class:`inspect.BoundArguments` for a given
    :class:`inspect.Signature`.

    :param signature: an instance of :class:`inspect.Signature` to which
        :paramref:`.CallArguments.args` and :paramref:`.CallArguments.kwargs`
        will be bound.
    :param partial: when true, :meth:`inspect.Signature.bind_partial` is used,
        so required parameters may be left unbound. Otherwise
        :meth:`inspect.Signature.bind` is used and every required parameter
        must be supplied.
    :raises TypeError: if the arguments cannot be bound. With ``partial`` this
        still covers unexpected or duplicated arguments.
    :returns: an instance of :class:`inspect.BoundArguments` to which
        :paramref:`.CallArguments.args` and :paramref:`.CallArguments.kwargs`
        are bound.
    """
    binder = signature.bind_partial if partial else signature.bind
    return binder(*self.args, **self.kwargs)
def _ctor_bind_args(self, args, kw):
    """Bind constructor arguments to the single required parameter ``doinit``.

    Returns the bound-argument mapping. Raises TypeError unless exactly the
    ``doinit`` argument is supplied, positionally or by keyword.
    """
    doinit = Parameter("doinit", Parameter.POSITIONAL_OR_KEYWORD)
    bound = Signature([doinit]).bind(*args, **kw)
    return bound.arguments
class _UnifiAPICall:
    # pylint: disable=too-many-instance-attributes, too-many-arguments
    # pylint: disable=too-few-public-methods, protected-access
    "A representation of a single API call in a specific site"

    def __init__(self, doc, endpoint,
                 path_arg_name=None, path_arg_optional=True,
                 json_args=None, json_body_name=None, json_fix=None,
                 rest_command=None, method=None,
                 need_login=True):
        # Builds the call signature ``(self, [path_arg], *json_args, [json_body])``
        # and records how bound arguments map onto the request.
        if json_args:
            # Plain names (required) must precede ``(name, default)`` tuples in
            # the signature. Sort a copy (stable) instead of reordering the
            # caller's list in place; this also accepts a tuple.
            json_args = sorted(json_args, key=lambda x: isinstance(x, tuple))
        self._endpoint = endpoint
        self._path_arg_name = path_arg_name
        self._json_args = json_args
        self._json_body_name = json_body_name
        self._rest = rest_command
        self._need_login = need_login
        if not isinstance(json_fix, (list, tuple, type(None))):
            json_fix = [json_fix]
        self._fixes = json_fix
        self.__doc__ = doc

        args = [Parameter('self', Parameter.POSITIONAL_ONLY)]
        if path_arg_name:
            args.append(
                Parameter(
                    path_arg_name,
                    Parameter.POSITIONAL_OR_KEYWORD,
                    default=None if path_arg_optional else Parameter.empty))
        for arg_name in json_args or ():
            if isinstance(arg_name, tuple):
                arg_name, default = arg_name
            else:
                default = Parameter.empty
            args.append(
                Parameter(arg_name, Parameter.POSITIONAL_OR_KEYWORD, default=default))
        if json_body_name:
            args.append(
                Parameter(json_body_name, Parameter.POSITIONAL_OR_KEYWORD, default=None))
        self.call_sig = Signature(args)

        if method is None:
            method = "POST" if (json_args or json_body_name or rest_command) else "GET"
        self._method = method

    def _build_url(self, client, path_arg):
        """Return the site-scoped URL for this endpoint, optionally suffixed by ``path_arg``."""
        if not client.site:
            raise UnifiAPIError("No site specified for site-specific call")
        return "https://{host}:{port}/proxy/network/api/s/{site}/{endpoint}{path}".format(
            host=client.host,
            port=client.port,
            site=client.site,
            endpoint=self._endpoint,
            path="/" + path_arg if path_arg else "")

    def __call__(self, *args, **kwargs):
        bound = self.call_sig.bind(*args, **kwargs)
        bound.apply_defaults()
        # The first parameter is the 'self' of the API class to which it is attached
        client = bound.arguments["self"]
        path_arg = bound.arguments[self._path_arg_name] if self._path_arg_name else None
        body = bound.arguments[self._json_body_name] if self._json_body_name else None
        if body is None:
            # No JSON body supplied (the parameter defaults to None): start empty
            # so the command and JSON arguments below can be added.
            rest_dict = {}
        elif isinstance(body, dict):
            # Copy so the entries added below do not leak into the caller's dict.
            rest_dict = dict(body)
        else:
            rest_dict = body
        if self._rest:
            rest_dict["cmd"] = self._rest
        for arg_name in self._json_args or ():
            if isinstance(arg_name, tuple):
                arg_name, _ = arg_name
            if arg_name not in bound.arguments:
                raise TypeError("Argument {} is required".format(arg_name))
            # Arguments left at None are not sent.
            if bound.arguments[arg_name] is not None:
                rest_dict[arg_name] = bound.arguments[arg_name]
        for fix in self._fixes or ():
            rest_dict = fix(rest_dict)
        url = self._build_url(client, path_arg)
        return client._execute(url, self._method, rest_dict, need_login=self._need_login)
def _prepare_call_args_from_dict(
        signature: Signature,
        params: t.Dict[str, t.Any]) -> t.Tuple[list, t.Dict[str, t.Any]]:
    """
    Creates populated ``args`` and ``kwargs`` from params bound to signature.

    Processes every parameter in ``Signature.parameters`` and takes the
    corresponding value from ``params`` (a deep copy is used, so the caller's
    dict is left untouched).

    NOTE: We cannot simply call ``signature.bind(**params)``. It cannot take a
    VAR_POSITIONAL value from the dict by its key name, and we need that
    because there is no other way to pass ``*args`` to an rpc call with dict
    params. It also cannot unpack a VAR_KEYWORD value passed under a key in
    ``params``.

    Links:
    - https://www.python.org/dev/peps/pep-0457/#id14

    :param signature: a signature of a callable
    :param params: a dict with callable arguments
    :return: a tuple with args and kwargs
    :raises TypeError: for a missing positional-only argument, a non-list
        ``*args`` value, a non-dict ``**kwargs`` value, or a final bind failure
    """
    missing = object()
    remaining = deepcopy(params)
    args, kwargs = [], {}

    for name, parameter in signature.parameters.items():
        value = remaining.pop(name, missing)
        if value is missing:
            value = parameter.default
        absent = value is Parameter.empty
        kind = parameter.kind

        if kind is Parameter.POSITIONAL_ONLY:
            # Positional-only parameters cannot be supplied by keyword, so one
            # with neither a value nor a default cannot be skipped here.
            if absent:
                raise TypeError(f'You must specify `{name}` argument')
            args.append(value)
        elif kind is Parameter.POSITIONAL_OR_KEYWORD:
            # A missing value is skipped rather than rejected. For example, with
            #     def srem(self, key, member, *members):
            # a user may specify only ``members``, which is fine for this signature.
            # If something is actually wrong, Signature.bind performs the final check.
            if not absent:
                args.append(value)
        elif kind is Parameter.VAR_POSITIONAL:
            if absent:
                # user may not pass *args
                continue
            if not isinstance(value, list):
                raise TypeError(f'`{name}` must be a list')
            args.extend(value)
        elif kind is Parameter.KEYWORD_ONLY:
            if not absent:
                kwargs[name] = value
        elif kind is Parameter.VAR_KEYWORD:
            if absent:
                continue
            if not isinstance(value, dict):
                raise TypeError(
                    f'Keyword arguments passed in the variable `{name}` must be a dict'
                )
            kwargs.update(value)
        else:
            raise TypeError(
                f'Unknown type `{kind.name}` for parameter {name}'
            )

    # Let Signature.bind validate the result (leftover keys are passed through
    # as keywords) and apply the remaining defaults.
    ba: BoundArguments = signature.bind(*args, **kwargs, **remaining)
    ba.apply_defaults()
    return ba.args, ba.kwargs
def _get_bound_args(cls, args, kwargs):
    """Bind ``args``/``kwargs`` to ``cls.__signature__`` and return the
    ``(name, value)`` items view of the bound arguments.

    ``Signature.bind`` is called through the class, so the base
    implementation is used even if ``cls.__signature__`` is a subclass
    instance. Raises TypeError if the arguments do not match.
    """
    target_signature = cls.__signature__
    bound = Signature.bind(target_signature, *args, **kwargs)
    return bound.arguments.items()
print(foo_params) # 创建一个函数参数列表,列表内的元素由类Parameter的实例组成 # Parameter实例化时,依次接受参数名、参数类型、默认值和参数注解 # 默认值和参数类型默认为空,这里的空值不是None,而是Parameter.empty,代表没有值 parms = [Parameter('x', Parameter.POSITIONAL_OR_KEYWORD), Parameter('y', Parameter.POSITIONAL_OR_KEYWORD), Parameter('z', Parameter.KEYWORD_ONLY, default=9)] # 使用Signature类实例化出一个函数签名实例 sig = Signature(parms) if __name__ == '__main__': bound_args_01 = sig.bind(1, 2, z=3) # <BoundArguments (x=1, y=2, z=3)> bound_args_02 = sig.bind(1, 2) # <BoundArguments (x=1, y=2)> # 引发异常 try: bound_args_03 = sig.bind(1) except TypeError as ex: print(ex) # missing a required argument: 'y' # 获取函数参数的内容 for name, value in bound_args_01.arguments.items(): print(name, value) # x 1 # y 2 # z 3
def __init__(self, *args, **kwargs):
    """Set one attribute per name in ``self.fields`` from the call arguments.

    Each field is a required positional-or-keyword parameter. TypeError is
    raised when the arguments do not match.
    """
    parameters = []
    for field_name in self.fields:
        parameters.append(Parameter(field_name, Parameter.POSITIONAL_OR_KEYWORD))
    bound = Signature(parameters).bind(*args, **kwargs)
    for attr, value in bound.arguments.items():
        setattr(self, attr, value)
def __init__(self, *args, **kwargs):
    """Assign the call arguments to the attributes named in ``self.fields``.

    Every field is a mandatory positional-or-keyword parameter. Binding
    failures surface as TypeError.
    """
    signature = Signature(tuple(
        Parameter(name, Parameter.POSITIONAL_OR_KEYWORD) for name in self.fields))
    bound_values = signature.bind(*args, **kwargs).arguments
    for name in bound_values:
        setattr(self, name, bound_values[name])
def _get_key(self, signature: Signature, *args, **kwargs) -> str:
    """Build a key by evaluating ``self.key_expr`` against the call arguments.

    The arguments are bound to ``signature`` with defaults applied. The
    resulting name-to-value mapping is passed to
    ``self._get_parsed_expression`` together with ``self.key_expr``.
    """
    bound = signature.bind(*args, **kwargs)
    bound.apply_defaults()
    return self._get_parsed_expression(self.key_expr, bound.arguments)
class PureTensorMethod:
    """A function taking specific tensor arguments.

    Built from a non-mutating ``assignment`` whose expression variables must
    exactly match the keys of ``input_formats``, with matching orders. It also
    checks that ``output_format`` has the target's order, then compiles a taco
    kernel. Calling the instance validates the tensor arguments, allocates the
    output and runs the kernel.
    """

    def __init__(self, assignment: Assignment, input_formats: Dict[str, Format], output_format: Format):
        if assignment.is_mutating():
            raise ValueError(f'{assignment} mutates its target and is so is not a pure function')

        variable_orders = assignment.expression.variable_orders()

        # Ensure that all parameters are defined
        for variable_name in variable_orders.keys():
            if variable_name not in input_formats:
                raise ValueError(f'Variable {variable_name} in {assignment} not listed in parameters')

        # Ensure that no extraneous parameters are defined
        for parameter_name in input_formats.keys():
            if parameter_name not in variable_orders:
                raise ValueError(f'Parameter {parameter_name} not in {assignment} variables')

        # Verify that parameters have the correct order
        # (note: the loop variable ``format`` shadows the builtin within this method)
        for parameter_name, format in input_formats.items():
            if format.order != variable_orders[parameter_name]:
                raise ValueError(
                    f'Parameter {parameter_name} has order {format.order}, but this variable in the '
                    f'assignment has order {variable_orders[parameter_name]}')

        if output_format.order != assignment.target.order:
            raise ValueError(
                f'Output parameter has order {output_format.order}, but the output variable in the '
                f'assignment has order {assignment.target.order}')

        # Store validated attributes
        self.assignment = assignment
        self.input_formats = input_formats
        self.output_format = output_format

        # Create Python signature of the function: one required parameter per
        # input, in the key order of ``input_formats``.
        self.signature = Signature([Parameter(parameter_name, Parameter.POSITIONAL_OR_KEYWORD)
                                    for parameter_name in input_formats.keys()])

        # Compile taco function
        all_formats = {self.assignment.target.name: output_format, **input_formats}
        format_strings = frozenset((parameter_name, format_to_taco_format(format))
                                   for parameter_name, format in all_formats.items()
                                   if format.order != 0)  # Taco does not like formats for scalars
        # ``parameter_order`` is later used to order the kernel's arguments by name.
        self.parameter_order, self.cffi_lib = taco_kernel(assignment.deparse(), format_strings)

    def __call__(self, *args, **kwargs):
        """Validate the tensor arguments, run the compiled kernel and return the output tensor.

        Raises ValueError on an order, mode, mode-ordering or shared-index
        dimension mismatch, and RuntimeError if the kernel returns non-zero.
        """
        # Handle arguments like normal Python function
        bound_arguments = self.signature.bind(*args, **kwargs).arguments

        # Validate tensor arguments.
        # Arguments are paired with formats by position. This relies on
        # ``self.signature`` having been built from ``input_formats`` in the same
        # key order, with every parameter required (so all are bound).
        for name, argument, format in zip(bound_arguments.keys(), bound_arguments.values(),
                                          self.input_formats.values()):
            if argument.order != format.order:
                raise ValueError(f'Argument {name} must have order {format.order} not {argument.order}')
            if tuple(argument.modes) != tuple(format.modes):
                raise ValueError(f'Argument {name} must have modes '
                                 f'{tuple(mode.name for mode in format.modes)} not '
                                 f'{tuple(mode.name for mode in argument.modes)}')
            if tuple(argument.mode_ordering) != tuple(format.ordering):
                raise ValueError(f'Argument {name} must have mode ordering '
                                 f'{format.ordering} not {argument.mode_ordering}')

        # Validate dimensions: every tensor dimension indexed by the same index
        # must have the same size. That size is recorded for the output.
        index_participants = self.assignment.expression.index_participants()
        index_sizes = {}
        for index, participants in index_participants.items():
            # Extract the size of dimension referenced by this index on each tensor that uses it; record the variable
            # name and dimension for a better error
            actual_sizes = [(variable, dimension, bound_arguments[variable].dimensions[dimension])
                            for variable, dimension in participants]
            reference_size = actual_sizes[0][2]
            index_sizes[index] = reference_size

            for variable, dimension, size in actual_sizes[1:]:
                if size != reference_size:
                    expected = ', '.join(f'{variable}.dimensions[{dimension}] == {size}'
                                         for variable, dimension, size in actual_sizes)
                    raise ValueError(
                        f'{self.assignment} expected all these dimensions of these tensors to be the same '
                        f'because they share the index {index}: {expected}')

        # Determine output dimensions
        output_dimensions = tuple(index_sizes[index] for index in self.assignment.target.indexes)

        cffi_output = allocate_taco_structure(tuple(mode.c_int for mode in self.output_format.modes),
                                              output_dimensions, self.output_format.ordering)

        output = Tensor(cffi_output)

        # Inputs and the freshly allocated output, looked up by name in the
        # kernel's parameter order.
        all_arguments = {self.assignment.target.name: output, **bound_arguments}

        cffi_args = [all_arguments[name].cffi_tensor for name in self.parameter_order]

        return_value = self.cffi_lib.evaluate(*cffi_args)

        # Called before the error check, so it runs on both paths.
        # NOTE(review): presumably transfers ownership of the output's buffers to
        # Python-managed memory; confirm against take_ownership_of_arrays.
        take_ownership_of_arrays(cffi_output)

        if return_value != 0:
            raise RuntimeError(f'Taco function failed with error code {return_value}')

        return output
def _get_bargs(sig: inspect.Signature,
               args: typing.Tuple[typing.Any, ...],
               kwargs: typing.Dict[str, typing.Any]) -> inspect.BoundArguments:
    """Bind ``args``/``kwargs`` to ``sig`` and fill in default values.

    Returns the completed ``BoundArguments``. Raises TypeError if the
    arguments do not match ``sig``.
    """
    bound_arguments = sig.bind(*args, **kwargs)
    # Make omitted parameters (including empty *args/**kwargs) explicit.
    bound_arguments.apply_defaults()
    return bound_arguments
def _prepare_call_args_from_list(
        signature: Signature,
        params: list) -> t.Tuple[list, t.Dict[str, t.Any]]:
    """Bind a list of positional ``params`` to ``signature``, with defaults applied.

    Returns the normalized ``(args, kwargs)`` pair of the resulting
    ``BoundArguments``. Raises TypeError if ``params`` do not fit the signature.
    """
    bound: BoundArguments = signature.bind(*params)
    bound.apply_defaults()
    normalized_args = bound.args
    normalized_kwargs = bound.kwargs
    return normalized_args, normalized_kwargs