def _get_unit_checks(
    self, bound_args: inspect.BoundArguments
) -> Dict[str, Dict[str, Any]]:
    """
    Review :attr:`checks` and function bound arguments to build a complete
    'checks' dictionary.  If a check key is omitted from the argument checks,
    then a default value is assumed (see `check units`_).

    Parameters
    ----------
    bound_args: :class:`inspect.BoundArguments`
        arguments passed into the function being wrapped

        .. code-block:: python

            bound_args = inspect.signature(f).bind(*args, **kwargs)

    Returns
    -------
    Dict[str, Dict[str, Any]]
        A complete 'checks' dictionary for checking function input arguments
        and return.  Keys are parameter names (``"checks_on_return"`` for the
        return value); each value is a dictionary with the keys ``"units"``,
        ``"none_shall_pass"``, ``"equivalencies"`` and
        ``"pass_equivalent_units"``.  ``*args``/``**kwargs`` parameters and
        parameters with neither checks nor unit annotations are omitted.

    Raises
    ------
    ValueError
        If checks are given for an argument (or the return) but no units can
        be determined from the decorator checks or the annotations, or if the
        annotation units are not all included in the decorator-specified units.

    Warns
    -----
    PlasmaPyWarning
        If :attr:`checks` names parameters that did not end up in the returned
        dictionary.
    """
    # initialize validation dictionary
    out_checks = {}

    # Iterate through function bound arguments + return and build `out_checks`:
    #
    # artificially add "return" to parameters
    things_to_check = bound_args.signature.parameters.copy()
    things_to_check["checks_on_return"] = inspect.Parameter(
        "checks_on_return",
        inspect.Parameter.POSITIONAL_ONLY,
        annotation=bound_args.signature.return_annotation,
    )
    for param in things_to_check.values():
        # variable arguments are NOT checked
        # e.g. in foo(x, y, *args, d=None, **kwargs) variable arguments
        #      *args and **kwargs will NOT be checked
        #
        if param.kind in (
            inspect.Parameter.VAR_KEYWORD,
            inspect.Parameter.VAR_POSITIONAL,
        ):
            continue

        # grab the checks dictionary for the desired parameter
        try:
            param_checks = self.checks[param.name]
        except KeyError:
            param_checks = None

        # -- Determine target units `_units` --
        # target units can be defined in one of three ways (in
        # preferential order):
        #   1. direct keyword pass-through
        #      i.e. CheckUnits(x=u.cm)
        #           CheckUnits(x=[u.cm, u.s])
        #   2. keyword pass-through via dictionary definition
        #      i.e. CheckUnits(x={'units': u.cm})
        #           CheckUnits(x={'units': [u.cm, u.s]})
        #   3. function annotations
        #
        # * if option (3) is used simultaneously with option (1) or (2), then
        #   checks defined by (3) must be consistent with checks from (1) or (2)
        #   to avoid raising an error.
        # * if None is included in the units list, then None values are allowed
        #
        _none_shall_pass = False
        _units = None
        _units_are_from_anno = False
        if param_checks is not None:
            # checks for argument were defined with decorator
            try:
                _units = param_checks["units"]
            except TypeError:
                # if checks is NOT None and is NOT a dictionary, then assume
                # only units were specified
                #   e.g. CheckUnits(x=u.cm)
                #
                _units = param_checks
            except KeyError:
                # if checks does NOT have 'units' but is still a dictionary,
                # then other check conditions may have been specified and the
                # user is relying on function annotations to define desired
                # units
                _units = None

        # If no units have been specified by decorator checks, then look for
        # function annotations.
        #
        # Reconcile units specified by decorator checks and function annotations
        _units_anno = None
        if param.annotation is not inspect.Parameter.empty:
            # unit annotations defined
            _units_anno = param.annotation

        if _units is None and _units_anno is None and param_checks is None:
            # no checks specified and no unit annotations defined
            continue
        elif _units is None and _units_anno is None:
            # checks specified, but NO unit checks
            msg = f"No astropy.units specified for "
            if param.name == "checks_on_return":
                msg += f"return value "
            else:
                msg += f"argument {param.name} "
            msg += f"of function {self.f.__name__}()."
            raise ValueError(msg)
        elif _units is None:
            # annotations become the target units; clear `_units_anno` so the
            # consistency check below compares against nothing
            _units = _units_anno
            _units_are_from_anno = True
            _units_anno = None

        # Ensure `_units` is an iterable
        # NOTE(review): a plain ``str`` target is Iterable and would NOT be
        # wrapped; ``None in <str>`` below then raises TypeError -- confirm
        # string targets never reach this point.
        if not isinstance(_units, collections.abc.Iterable):
            _units = [_units]
        if not isinstance(_units_anno, collections.abc.Iterable):
            _units_anno = [_units_anno]

        # Is None allowed?  (either listed explicitly, or the argument's
        # default value is None)
        if None in _units or param.default is None:
            _none_shall_pass = True

        # Remove Nones
        if None in _units:
            _units = [t for t in _units if t is not None]
        if None in _units_anno:
            _units_anno = [t for t in _units_anno if t is not None]

        # ensure all _units are astropy.units.Unit or physical types &
        # define 'units' for unit checks &
        # define 'none_shall_pass' check
        _units = self._condition_target_units(
            _units, from_annotations=_units_are_from_anno
        )
        _units_anno = self._condition_target_units(
            _units_anno, from_annotations=True
        )
        if not all(_u in _units for _u in _units_anno):
            raise ValueError(
                f"For argument '{param.name}', "
                f"annotation units ({_units_anno}) are not included in the units "
                f"specified by decorator arguments ({_units}). Use either "
                f"decorator arguments or function annotations to defined unit "
                f"types, or make sure annotation specifications match decorator "
                f"argument specifications."
            )
        if len(_units) == 0 and len(_units_anno) == 0 and param_checks is None:
            # annotations did not specify units
            continue
        elif len(_units) == 0 and len(_units_anno) == 0:
            # checks specified, but NO unit checks
            msg = f"No astropy.units specified for "
            if param.name == "checks_on_return":
                msg += f"return value "
            else:
                msg += f"argument {param.name} "
            msg += f"of function {self.f.__name__}()."
            raise ValueError(msg)

        out_checks[param.name] = {
            "units": _units,
            "none_shall_pass": _none_shall_pass,
        }

        # -- Determine target equivalencies --
        # Unit equivalences can be defined by:
        #   1. keyword pass-through via dictionary definition
        #      e.g. CheckUnits(x={'units': u.C,
        #                         'equivalencies': u.temperature})
        #
        # initialize equivalencies
        try:
            _equivs = param_checks["equivalencies"]
        except (KeyError, TypeError):
            _equivs = self.__check_defaults["equivalencies"]

        # ensure equivalences are properly formatted
        if _equivs is None or _equivs == [None]:
            _equivs = None
        elif isinstance(_equivs, Equivalency):
            pass
        elif isinstance(_equivs, (list, tuple)):

            # flatten list to non-list elements
            if isinstance(_equivs, tuple):
                # a single bare tuple is one equivalency
                _equivs = [_equivs]
            else:
                _equivs = self._flatten_equivalencies_list(_equivs)

            # ensure passed equivalencies list is structured properly
            #   [(), ...]
            #   or [Equivalency(), ...]
            #
            # * All equivalencies must be a list of 2, 3, or 4 element tuples
            #   structured like...
            #     (from_unit, to_unit, forward_func, backward_func)
            #
            if all(isinstance(el, Equivalency) for el in _equivs):
                # concatenate all Equivalency objects into one
                _equivs = reduce(lambda x, y: x + y, _equivs)
            else:
                _equivs = self._normalize_equivalencies(_equivs)

        out_checks[param.name]["equivalencies"] = _equivs

        # -- Determine if equivalent units pass --
        try:
            peu = param_checks.get(
                "pass_equivalent_units",
                self.__check_defaults["pass_equivalent_units"],
            )
        except (AttributeError, TypeError):
            # `param_checks` is None or a bare unit (not a dict)
            peu = self.__check_defaults["pass_equivalent_units"]

        out_checks[param.name]["pass_equivalent_units"] = peu

    # Does `self.checks` indicate arguments not used by f?
    # NOTE(review): keys naming *args/**kwargs parameters are also reported
    # here, since those parameters are skipped above.
    missing_params = [
        param for param in set(self.checks.keys()) - set(out_checks.keys())
    ]
    if len(missing_params) > 0:
        params_str = ", ".join(missing_params)
        warnings.warn(
            PlasmaPyWarning(
                f"Expected to unit check parameters {params_str} but they "
                f"are missing from the call to {self.f.__name__}"
            )
        )

    return out_checks
_name = 'nan_policy' _type = "{'propagate', 'omit', 'raise'}" _desc = ("""Defines how to handle input NaNs. - ``propagate``: if a NaN is present in the axis slice (e.g. row) along which the statistic is computed, the corresponding entry of the output will be NaN. - ``omit``: NaNs will be omitted when performing the calculation. If insufficient data remains in the axis slice along which the statistic is computed, the corresponding entry of the output will be NaN. - ``raise``: if a NaN is present, a ``ValueError`` will be raised.""".split( '\n')) _nan_policy_parameter_doc = Parameter(_name, _type, _desc) _nan_policy_parameter = inspect.Parameter(_name, inspect.Parameter.KEYWORD_ONLY, default='propagate') _name = 'keepdims' _type = "bool, default: False" _desc = ("""If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, the result will broadcast correctly against the input array.""".split('\n')) _keepdims_parameter_doc = Parameter(_name, _type, _desc) _keepdims_parameter = inspect.Parameter(_name, inspect.Parameter.KEYWORD_ONLY, default=False) _standard_note_addition = ( """\nBeginning in SciPy 1.9, ``np.matrix`` inputs (not recommended for new code) are converted to ``np.ndarray`` before the calculation is performed. In
def _get_value_checks( self, bound_args: inspect.BoundArguments ) -> Dict[str, Dict[str, bool]]: """ Review :attr:`checks` and function bound arguments to build a complete 'checks' dictionary. If a check key is omitted from the argument checks, then a default value is assumed (see `check values`_). Parameters ---------- bound_args: :class:`inspect.BoundArguments` arguments passed into the function being wrapped .. code-block:: python bound_args = inspect.signature(f).bind(*args, **kwargs) Returns ------- Dict[str, Dict[str, bool]] A complete 'checks' dictionary for checking function input arguments and return. """ # initialize validation dictionary out_checks = {} # Iterate through function bound arguments + return and build `out_checks: # # artificially add "return" to parameters things_to_check = bound_args.signature.parameters.copy() things_to_check["checks_on_return"] = inspect.Parameter( "checks_on_return", inspect.Parameter.POSITIONAL_ONLY, annotation=bound_args.signature.return_annotation, ) for param in things_to_check.values(): # variable arguments are NOT checked # e.g. in foo(x, y, *args, d=None, **kwargs) variable arguments # *args and **kwargs will NOT be checked # if param.kind in ( inspect.Parameter.VAR_KEYWORD, inspect.Parameter.VAR_POSITIONAL, ): continue # grab the checks dictionary for the desired parameter try: param_in_checks = self.checks[param.name] except KeyError: # checks for parameter not specified continue # build `out_checks` # read checks and/or apply defaults values out_checks[param.name] = {} for v_name, v_default in self.__check_defaults.items(): try: out_checks[param.name][v_name] = param_in_checks.get( v_name, v_default ) except AttributeError: # for the case that checks are defined for an argument, # but is NOT a dictionary # (e.g. CheckValues(x=u.cm) ... this scenario could happen # during subclassing) out_checks[param.name][v_name] = v_default # Does `self.checks` indicate arguments not used by f? 
missing_params = [ param for param in set(self.checks.keys()) - set(out_checks.keys()) ] if len(missing_params) > 0: params_str = ", ".join(missing_params) warnings.warn( PlasmaPyWarning( f"Expected to value check parameters {params_str} but they " f"are missing from the call to {self.f.__name__}" ) ) return out_checks
if not isinstance(file, file_types): if PY2: csvfile = open(file, 'wb') else: csvfile = open(file, 'w', newline='') autoclose = True else: csvfile = file autoclose = False try: writer = csv.writer(csvfile, **fmtparams) for row in reader: if nonstringiter(row): writer.writerow(row) else: writer.writerow([row]) finally: if autoclose: csvfile.close() with contextlib.suppress(AttributeError): # inspect.Signature() is new in 3.3 BaseQuery.__init__.__signature__ = inspect.Signature([ inspect.Parameter('self', inspect.Parameter.POSITIONAL_ONLY), inspect.Parameter('columns', inspect.Parameter.POSITIONAL_ONLY), inspect.Parameter('where', inspect.Parameter.VAR_KEYWORD), ])
def angular_freq_to_hz(fn):
    """
    A decorator that adds to a function the ability to convert the function's return
    from angular frequency (rad/s) to frequency (Hz).

    A keyword-only argument `to_hz` is added to the function's signature, with a
    default value of `False`.  The keyword is also added to the function's docstring
    under the **"Other Parameters"** heading.

    Parameters
    ----------
    fn : function
        The function to be decorated

    Raises
    ------
    ValueError
        If `fn` has already defined a kwarg `to_hz`

    Returns
    -------
    callable
        The decorated function

    Notes
    -----
        * If `angular_freq_to_hz` is used with decorator
          :func:`~plasmapy.utils.decorators.validate_quantities`, then
          `angular_freq_to_hz` should be used inside
          :func:`~plasmapy.utils.decorators.validate_quantities` but special
          consideration is needed for setup.  The following is an example of an
          appropriate setup::

            import astropy.units as u
            from plasmapy.utils.decorators.converter import angular_freq_to_hz
            from plasmapy.utils.decorators.validators import validate_quantities

            @validate_quantities(validations_on_return={'units': [u.rad / u.s, u.Hz]})
            @angular_freq_to_hz
            def foo(x: u.rad / u.s) -> u.rad / u.s:
                return x

          Adding `u.Hz` to the allowed units allows the converted quantity to pass
          the validations.

    Examples
    --------
        >>> import astropy.units as u
        >>> from plasmapy.utils.decorators.converter import angular_freq_to_hz
        >>>
        >>> @angular_freq_to_hz
        ... def foo(x):
        ...     return x
        >>>
        >>> foo(5 * u.rad / u.s, to_hz=True)
        <Quantity 0.79577472 Hz>
        >>>
        >>> foo(-1 * u.rad / u.s, to_hz=True)
        <Quantity -0.15915494 Hz>

        Decoration also works with methods

        >>> class Foo:
        ...     def __init__(self, x):
        ...         self.x = x
        ...
        ...     @angular_freq_to_hz
        ...     def bar(self):
        ...         return self.x
        >>>
        >>> foo = Foo(0.5 * u.rad / u.s)
        >>> foo.bar(to_hz=True)
        <Quantity 0.07957747 Hz>
    """
    # raise exception if fn uses the 'to_hz' kwarg
    sig = inspect.signature(fn)
    if "to_hz" in sig.parameters:
        raise ValueError(
            f"Wrapped function '{fn.__name__}' can not use keyword 'to_hz'."
            f" Keyword reserved for decorator functionality."
        )

    # Build the signature advertised by the wrapper.  `wrapper` only accepts
    # `to_hz` as a keyword (it follows `*args`), so declare it KEYWORD_ONLY.  It
    # must also precede a trailing `**kwargs` parameter: `inspect.Signature`
    # rejects any parameter after VAR_KEYWORD (and a POSITIONAL_OR_KEYWORD one
    # after keyword-only parameters), which would otherwise make decorating
    # such functions raise ValueError.
    to_hz_param = inspect.Parameter(
        "to_hz", inspect.Parameter.KEYWORD_ONLY, default=False
    )
    params = list(sig.parameters.values())
    if params and params[-1].kind is inspect.Parameter.VAR_KEYWORD:
        params.insert(len(params) - 1, to_hz_param)
    else:
        params.append(to_hz_param)
    new_sig = sig.replace(parameters=params)

    @preserve_signature
    @functools.wraps(fn)
    def wrapper(*args, to_hz=False, **kwargs):
        _result = fn(*args, **kwargs)
        if to_hz:
            # rad/s -> cy/s is a plain unit conversion (the doctests show the
            # 1 / (2*pi) factor); the equivalency then maps cy/s onto Hz.
            return _result.to(u.Hz, equivalencies=[(u.cy / u.s, u.Hz)])
        return _result

    # Attach the extended signature to the wrapper only; the caller's `fn` is
    # left unmodified.
    wrapper.__signature__ = new_sig

    added_doc_bit = """
    Other Parameters
    ----------------
    to_hz : bool
        Set `True` to convert function output from angular frequency to Hz
    """
    if wrapper.__doc__ is not None:
        wrapper.__doc__ += added_doc_bit
    else:
        wrapper.__doc__ = added_doc_bit

    return wrapper
def decomposition_decorator(f: Callable) -> Callable:
    """
    Register ``f`` as the decomposition for the enclosing ``aten_op`` (a
    closure variable of the outer decorator factory).

    If ``f`` declares a tuple-typed ``out`` parameter, it is wrapped in a
    function taking one keyword-only argument per output name (taken from the
    return annotation's ``_fields``), and that wrapper is registered instead.
    Each overload of ``aten_op`` is stored in ``registry`` (defaulting to
    ``decomposition_table``).  Unless ``disable_meta`` is set, it is also
    registered as a Python meta kernel via ``meta_lib.impl`` when the
    dispatcher has a kernel for the overload but no computed Meta kernel.

    Returns the registered callable (``f`` or its tuple-``out`` wrapper).

    Raises ``RuntimeError`` on a duplicate registration, or when a meta kernel
    would be registered for an operator with non-writing aliased (view)
    arguments.
    """
    sig = inspect.signature(f)
    out_annotation = f.__annotations__.get("out")
    # Hack to detect when out is a Tuple. There seems to be no pretty way of doing this
    fn = f
    if out_annotation and getattr(out_annotation, "__origin__", None) is tuple:
        # NOTE(review): assumes the return annotation is a NamedTuple-like type
        # exposing `_fields` -- confirm for all decompositions using tuple outs.
        out_names = sig.return_annotation._fields
        # If out is a tuple, we need to register a function that unpacks all the out
        # elements as this is what native_functions.yaml expects

        @wraps(f)
        def _fn(*args, **kwargs):
            out_kwargs = tuple(kwargs.pop(o, None) for o in out_names)
            # Either all of the out kwargs are set or none of them
            is_none = out_kwargs[0] is None
            assert all((o is None) == is_none for o in out_kwargs)
            return f(*args, **kwargs, out=None if is_none else out_kwargs)

        out_params = [
            inspect.Parameter(
                o,
                kind=inspect.Parameter.KEYWORD_ONLY,
                default=None,
                annotation=t,
            )
            for o, t in zip(out_names, out_annotation.__args__)
        ]
        # Drop the out parameter and concatenate the new kwargs in the signature
        params = chain(
            (v for k, v in sig.parameters.items() if k != "out"), out_params)
        _fn.__signature__ = inspect.Signature(  # type: ignore[attr-defined]
            parameters=params,
            return_annotation=sig.return_annotation  # type: ignore[arg-type]
        )
        # Drop the out parameter and concatenate the new kwargs in the annotations
        _fn.__annotations__ = {
            k: v for k, v in f.__annotations__.items() if k != "out"
        }
        for o in out_params:
            _fn.__annotations__[o.name] = o.annotation

        fn = _fn

    nonlocal registry
    if registry is None:
        registry = decomposition_table

    def add_op_to_table(aten_op):
        # Register `fn` for a single OpOverload, or for every overload of an
        # OpOverloadPacket.
        overloads = []
        if isinstance(aten_op, torch._ops.OpOverload):
            overloads.append(aten_op)
        else:
            assert isinstance(aten_op, torch._ops.OpOverloadPacket)
            for ol in aten_op.overloads():
                overloads.append(getattr(aten_op, ol))
        for op_overload in overloads:
            if op_overload in registry:
                raise RuntimeError(
                    f"duplicate registrations for {op_overload}")
            registry[op_overload] = fn
            # TODO: factor this logic into OpOverload or Library API
            name = op_overload._schema.name
            if op_overload._schema.overload_name:
                name += "." + op_overload._schema.overload_name
            if (not disable_meta
                    # TorchScript dumps a bunch of extra nonsense overloads
                    # which don't have corresponding dispatcher entries, we need
                    # to filter those out
                    and torch._C._dispatch_has_kernel(name)
                    # Don't register a python meta kernel to any operator that
                    # should already work with meta tensors today.
                    # We can check that by seeing if the "computed table" for the operator
                    # has a registration to Meta;
                    # either through a direct registration, or an indirect one through
                    # an alias dispatch key (e.g. CompositeImplicitAutograd)
                    and not torch._C._dispatch_has_computed_kernel_for_dispatch_key(
                        name, "Meta")):
                # Arguments with non-write alias info mark a view operator.
                if any(a.alias_info is not None and not a.alias_info.is_write
                       for a in op_overload._schema.arguments):
                    raise RuntimeError(f"""
Attempting to register a python meta kernel for a view operator: {str(op_overload)}.
We shouldn't do this, because the output will report as not having aliased storages.
All view ops have meta kernels in C++ today, so we should use those instead.

If you're registering an operator through the `@register_decomposition` decorator,
Please set `disable_meta=True`.
""")
                meta_lib.impl(op_overload, fn)

    # To handle allowing multiple aten_ops at once
    tree_map(add_op_to_table, aten_op)
    return fn
def match(self, hint: object, injectable: Injectable, container: Container):
    """Return the candidates for the first type argument of ``hint``.

    For a parametrized hint such as ``List[X]`` the lookup is done for ``X``;
    a hint without ``__args__`` is looked up as ``None``.
    """
    type_args = getattr(hint, '__args__', [None])
    first_type_arg = type_args[0]
    return get_candidates(first_type_arg, container=container)
def decorator_maker(tested_scene_construct):
    """
    Turn a graphical test function into a pytest test comparing rendered frames.

    ``tested_scene_construct`` must accept a ``SCENE_PARAMETER_NAME`` argument,
    and its module must define ``__module_test__``.  The returned wrapper
    additionally requests pytest's ``request`` and ``tmp_path`` fixtures.
    Any other keyword arguments (e.g. parametrizations) are frozen into the
    construct call and folded into the control-data file name.

    Raises
    ------
    Exception
        If the scene parameter or ``__module_test__`` is missing.
    """
    if (
        SCENE_PARAMETER_NAME
        not in inspect.getfullargspec(tested_scene_construct).args
    ):
        raise Exception(
            f"Invalid graphical test function: must have '{SCENE_PARAMETER_NAME}' as one of the parameters.",
        )

    # Give the scene parameter a default so it is not exposed to pytest as a
    # required argument (presumably so pytest does not look for a fixture).
    # Uses the same constant as the check above rather than a literal name.
    old_sig = inspect.signature(
        functools.partial(tested_scene_construct, **{SCENE_PARAMETER_NAME: None}),
    )

    if "__module_test__" not in tested_scene_construct.__globals__:
        raise Exception(
            "There is no module test name indicated for the graphical unit test. You have to declare __module_test__ in the test file.",
        )
    module_name = tested_scene_construct.__globals__.get("__module_test__")
    test_name = tested_scene_construct.__name__[len("test_") :]

    @functools.wraps(tested_scene_construct)
    # The "request" parameter is meant to be used as a fixture by pytest. See below.
    def wrapper(*args, request: FixtureRequest, tmp_path, **kwargs):
        # Wraps the test_function to a construct method, to "freeze" the eventual additional arguments (parametrizations fixtures).
        construct = functools.partial(tested_scene_construct, *args, **kwargs)

        # Kwargs contains the eventual parametrization arguments.
        # This modifies the test_name so that it is defined by the parametrization
        # arguments too.
        # Example: if "length" is parametrized from 0 to 20, the kwargs
        # will be once with {"length" : 1}, etc.
        test_name_with_param = test_name + "_".join(
            f"_{str(tup[0])}[{str(tup[1])}]" for tup in kwargs.items()
        )

        config_tests = _config_test(last_frame)

        config_tests["text_dir"] = tmp_path
        config_tests["tex_dir"] = tmp_path

        if last_frame:
            config_tests["frame_rate"] = 1
            config_tests["dry_run"] = True

        setting_test = request.config.getoption("--set_test")
        # `__globals__` is a dict, so the only possible failure is a missing key.
        test_file_path = tested_scene_construct.__globals__.get("__file__")
        real_test = _make_test_comparing_frames(
            file_path=_control_data_path(
                test_file_path,
                module_name,
                test_name_with_param,
                setting_test,
            ),
            base_scene=base_scene,
            construct=construct,
            renderer_class=renderer_class,
            is_set_test_data_test=setting_test,
            last_frame=last_frame,
            show_diff=request.config.getoption("--show_diff"),
            size_frame=(config_tests["pixel_height"], config_tests["pixel_width"]),
        )

        # Isolate the config used for the test, to avoid modifying the global config during the test run.
        with tempconfig({**config_tests, **custom_config}):
            real_test()

    parameters = list(old_sig.parameters.values())
    # Adds "request" param into the signature of the wrapper, to use the associated pytest fixture.
    # This fixture is needed to have access to flags value and pytest's config. See above.
    if "request" not in old_sig.parameters:
        parameters += [inspect.Parameter("request", inspect.Parameter.KEYWORD_ONLY)]
    if "tmp_path" not in old_sig.parameters:
        parameters += [
            inspect.Parameter("tmp_path", inspect.Parameter.KEYWORD_ONLY),
        ]
    new_sig = old_sig.replace(parameters=parameters)
    wrapper.__signature__ = new_sig

    # Reach a bit into pytest internals to hoist the marks from our wrapped
    # function.
    wrapper.pytestmark = getattr(tested_scene_construct, "pytestmark", [])
    return wrapper
def post_init(cls: Type[U]) -> Type[U]:
    """
    Class decorator to automatically support __post_init__() on classes

    This is useful for @attr.s decorated classes, because __attr_post_init__() doesn't
    support additional arguments.

    This decorator wraps the class __init__ in a new function that accepts merged
    arguments, and dispatches them to __init__ and then __post_init__().

    Raises
    ------
    TypeError
        If ``cls`` is not a class, has no ``__post_init__`` method, or a
        ``__post_init__`` signature cannot be merged with the signatures
        collected so far.
    """
    if not isinstance(cls, type):
        raise TypeError("Can only decorate classes")
    if not hasattr(cls, "__post_init__"):
        raise TypeError("The class must have a __post_init__() method")

    # Ignore the first argument which is the "self" argument
    sig = init_sig = _sig_without(inspect.signature(cls.__init__), 0)
    # (class, method name, signature without self) of every method whose
    # arguments are folded into the merged signature; __init__ comes first.
    previous = [(cls, "__init__", sig)]
    # NOTE(review): hasattr() is also true for classes that merely inherit
    # __post_init__, and the generated `super(parent, self).__post_init__`
    # resolves to the class *after* `parent` in the MRO -- confirm this is the
    # intended dispatch.
    for parent in reversed(cls.__mro__):
        if hasattr(parent, "__post_init__"):
            post_sig = _sig_without(inspect.signature(parent.__post_init__), 0)
            try:
                sig = _sig_merge(sig, post_sig)
            except Exception as err:
                # find the incompatibility
                for other_cls, other_method, other_sig in previous:
                    try:
                        _sig_merge(other_sig, post_sig)
                    except Exception:
                        break
                else:
                    raise TypeError(
                        "__post_init__ signature is incompatible with the class"
                    ) from err
                raise TypeError(
                    f"__post_init__() is incompatible with "
                    f"{other_cls.__qualname__}.{other_method}()"
                ) from err
            # No exception
            previous.append((parent, "__post_init__", post_sig))

    # handles type annotations and defaults
    # inspired by the dataclasses modules
    params = list(sig.parameters.values())
    # Annotation/default objects are handed to the generated factory as
    # parameters named `__type_<name>` / `__default_<name>`; the class
    # namespace is merged in too (presumably so class-level names resolve in
    # the generated code -- confirm).
    localns = (
        {
            f"__type_{p.name}": p.annotation
            for p in params
            if p.annotation is not inspect.Parameter.empty
        }
        | {
            f"__default_{p.name}": p.default
            for p in params
            if p.default is not inspect.Parameter.empty
        }
        | cls.__dict__
    )
    for i, p in enumerate(params):
        if p.default is not inspect.Parameter.empty:
            p = p.replace(default=Variable(f"__default_{p.name}"))
        if p.annotation is not inspect.Parameter.empty:
            p = p.replace(annotation=f"__type_{p.name}")
        params[i] = p
    new_sig = inspect.Signature(params)

    # Build the new __init__ source code
    self_ = "self" if "self" not in sig.parameters else "__post_init_self"
    init_lines = [
        f"def __init__({self_}, {_sig_to_def(new_sig)}) -> None:",
        f"__original_init({self_}, {_sig_to_call(init_sig)})",
    ]
    for parent, method, psig in previous[1:]:
        if hasattr(parent, "__post_init__"):
            if parent is not cls:
                init_lines.append(
                    f"super({parent.__qualname__}, {self_}).{method}({_sig_to_call(psig)})"
                )
            else:
                init_lines.append(f"{self_}.{method}({_sig_to_call(psig)})")
    # The factory template indents the first line (`def __init__`) by four
    # spaces; every following line is the body of __init__, indented by eight.
    init_src = "\n        ".join(init_lines)

    # Build the factory function source code
    local_vars = ", ".join(localns.keys())
    factory_src = (
        f"def __make_init__(__original_init, {local_vars}):\n"
        f"    {init_src}\n"
        "    return __init__"
    )

    # Create new __init__ with the factory.  The source is generated only from
    # signatures and names of `cls`, not from external input.
    globalns = inspect.getmodule(cls).__dict__
    ns: dict[str, Any] = {}
    exec(factory_src, globalns, ns)
    init = ns["__make_init__"](cls.__init__, **localns)
    self_param = inspect.Parameter(self_, inspect.Parameter.POSITIONAL_ONLY)
    init.__signature__ = inspect.Signature(
        parameters=[self_param] + list(sig.parameters.values()), return_annotation=None
    )
    setattr(cls, "__init__", init)
    return cls
def test_conv_str_choices_valid():
    """A str conversion restricted by str_choices accepts a listed value."""
    allowed = ['val1', 'val2']
    foo_param = inspect.Parameter(name='foo', kind=inspect.Parameter.POSITIONAL_ONLY)
    result = argparser.type_conv(foo_param, str, 'val1', str_choices=allowed)
    assert result == 'val1'
def test_conv_str_choices_invalid():
    """A str conversion restricted by str_choices rejects an unlisted value."""
    allowed = ['val1', 'val2']
    foo_param = inspect.Parameter(name='foo', kind=inspect.Parameter.POSITIONAL_ONLY)
    expected_msg = 'foo: Invalid value val3 - expected one of: val1, val2'
    with pytest.raises(cmdexc.ArgumentTypeError, match=expected_msg):
        argparser.type_conv(foo_param, str, 'val3', str_choices=allowed)
def test_multitype_conv_invalid_type():
    """A multitype converter given an unsupported type (None) raises ValueError."""
    foo_param = inspect.Parameter(name='foo', kind=inspect.Parameter.POSITIONAL_ONLY)
    bogus_types = [None]
    with pytest.raises(ValueError, match="foo: Unknown type None!"):
        argparser.multitype_conv(foo_param, bogus_types, '')
def _get_sig(cls): return inspect.Signature([inspect.Parameter(n, inspect.Parameter.POSITIONAL_OR_KEYWORD) for n in cls.field_names])
def _fields_from_attrs(kind: ParameterKind, attrs: Tuple[str, ...]): return {x: inspect.Parameter(x, kind) for x in attrs}
def _create_test_entry_function(cls, module, devices):
    """
    Generate a pytest entry function for the OpTest class ``cls`` and attach it
    to ``module`` as ``test_<Name>`` (for a class named ``Test<Name>``).

    The entry function is parametrized over ``devices``. It runs the forward,
    backward and double-backward tests, each on a fresh instance with
    ``setup``/``teardown``. For ``NumpyOpTest`` instances whose forward test
    succeeded with accepted errors, the backward phases are skipped. Its
    signature is ``device`` followed by the parameters of ``cls.setup``
    (without ``self``). Marks from ``cls.pytestmark`` are appended when present.
    """
    # OpTest implementations use a 'Test' prefix (like unittest.TestCase),
    # while pytest only collects functions whose names start with 'test_'.
    if not cls.__name__.startswith('Test'):
        raise TypeError(
            'OpTest class name must start with \'Test\'. Actual: {!r}'.format(
                cls.__name__))
    func_name = 'test_{}'.format(cls.__name__[len('Test'):])

    def run_phase(runner_name, backend_config, args, kwargs):
        # One phase on a fresh instance; teardown runs even if the phase fails.
        test_obj = cls()
        try:
            test_obj.setup(*args, **kwargs)
            getattr(test_obj, runner_name)(backend_config)
        finally:
            test_obj.teardown()
        return test_obj

    @pytest.mark.parametrize_device(devices)
    def entry_func(device, *args, **kwargs):
        backend_config = _make_backend_config(device.name)

        forward_obj = run_phase('run_test_forward', backend_config, args, kwargs)

        # Forward passed only within accepted errors: backward checks are moot.
        if (isinstance(forward_obj, NumpyOpTest)
                and forward_obj.is_forward_successful_with_accept_errors):
            return

        run_phase('run_test_backward', backend_config, args, kwargs)
        run_phase('run_test_double_backward', backend_config, args, kwargs)

    entry_func.__name__ = func_name

    # Signature: `device` followed by the parameters of `setup` minus `self`.
    setup_params = list(inspect.signature(cls.setup).parameters.values())[1:]
    device_param = inspect.Parameter(
        'device', inspect.Parameter.POSITIONAL_OR_KEYWORD)
    entry_func.__signature__ = inspect.Signature([device_param] + setup_params)

    # Carry over pytest marks declared on the class, if any.
    try:
        pytestmark = cls.pytestmark
        entry_func.pytestmark += pytestmark
    except AttributeError:
        pass

    # Expose the entry function in the module of the class.
    setattr(module, func_name, entry_func)
def add(
    self,
    instruction: Union[str, Instruction],
    qubits: Union[int, Iterable[int]],
    schedule: Union[Schedule, ScheduleBlock, Callable[..., Union[Schedule, ScheduleBlock]]],
    arguments: Optional[List[str]] = None,
) -> None:
    """Add a new known instruction for the given qubits and its mapping to a pulse schedule.

    Args:
        instruction: The name of the instruction to add.
        qubits: The qubits which the instruction applies to.
        schedule: The Schedule that implements the given instruction.
        arguments: List of parameter names to create a parameter-bound schedule from the
            associated gate instruction. If :py:meth:`get` is called with arguments rather
            than keyword arguments, this parameter list is used to map the input arguments to
            parameter objects stored in the target schedule.

    Raises:
        PulseError: If the qubits are provided as an empty iterable, if ``arguments`` does
            not name exactly the schedule's parameters, or if ``schedule`` is not a
            Schedule, ScheduleBlock, ParameterizedSchedule or callable.
    """
    instruction = _get_instruction_string(instruction)

    # validation of target qubit
    qubits = _to_tuple(qubits)
    if qubits == ():
        raise PulseError(
            f"Cannot add definition {instruction} with no target qubits.")

    # generate signature
    if isinstance(schedule, (Schedule, ScheduleBlock)):
        # default parameter order: parameter names sorted alphabetically
        ordered_names = sorted({par.name for par in schedule.parameters})
        if arguments:
            if set(arguments) != set(ordered_names):
                # compare like with like: both sides are parameter *names*
                raise PulseError(
                    "Arguments does not match with schedule parameters. "
                    f"{set(arguments)} != {set(ordered_names)}.")
            ordered_names = arguments

        parameters = [
            inspect.Parameter(
                name=argname,
                annotation=ParameterValueType,
                kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
            )
            for argname in ordered_names
        ]
        signature = inspect.Signature(parameters=parameters,
                                      return_annotation=type(schedule))
    elif isinstance(schedule, ParameterizedSchedule):
        # TODO remove this
        warnings.warn(
            "ParameterizedSchedule has been deprecated. "
            "Define Schedule with Parameter objects.",
            DeprecationWarning,
        )
        parameters = [
            inspect.Parameter(
                name=argname,
                annotation=ParameterValueType,
                kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
            )
            for argname in schedule.parameters
        ]
        signature = inspect.Signature(parameters=parameters,
                                      return_annotation=Schedule)
    elif callable(schedule):
        if arguments:
            warnings.warn(
                "Arguments are overridden by the callback function signature. "
                "Input `arguments` are ignored.",
                UserWarning,
            )
        signature = inspect.signature(schedule)
    else:
        raise PulseError(
            "Supplied schedule must be one of the Schedule, ScheduleBlock or a "
            "callable that outputs a schedule.")

    self._map[instruction][qubits] = Generator(schedule, signature)
    self._qubit_instructions[qubits].add(instruction)
def the_decorator(func):
    """
    Allow ``func``'s soon-to-be keyword-only arguments to still be passed
    positionally, with a FutureWarning.

    ``func`` is returned unchanged once the current library version reaches
    ``version``.  Otherwise, positional arguments beyond ``func``'s current
    positional parameters are mapped onto ``old_arg_names`` and forwarded as
    keywords.  If a docstring exists, a "Warns" section is appended to it.

    Raises (from the wrapper)
    -------------------------
    TypeError
        If more positional arguments are given than the old signature allowed.
    SyntaxError
        If an argument is given both positionally and as a keyword.
    """
    if Version(current_library_version) >= version:
        return func

    new_signature = inspect.signature(func)
    new_arg_names = [name for name in new_signature.parameters]
    if previous_arg_order is None:
        old_nargs = len(new_signature.parameters)
        old_arg_names = new_arg_names[:old_nargs]
    else:
        old_nargs = len(previous_arg_order)
        old_arg_names = previous_arg_order
    # These arguments are still required as argument only keywords
    func_args = inspect.getfullargspec(func).args
    new_nargs = len(func_args)

    # Rebuild the pre-deprecation signature: old argument names first (made
    # positional-or-keyword unless they still are positional), then the rest.
    old_parameters = []
    all_params = {**new_signature.parameters}
    for key in old_arg_names:
        param = all_params.pop(key)
        if key in func_args:
            kind = param.kind
        else:
            kind = POSITIONAL_OR_KEYWORD
        old_parameters.append(
            inspect.Parameter(key,
                              kind=kind,
                              default=param.default,
                              annotation=param.annotation))
    for key, param in all_params.items():
        old_parameters.append(
            inspect.Parameter(key,
                              kind=param.kind,
                              default=param.default,
                              annotation=param.annotation))

    old_signature = new_signature.replace(parameters=old_parameters)

    @wraps(func)
    def wrapper(*args, **kwargs):
        if len(args) > old_nargs:
            # The warning should be issued here too!
            raise TypeError('{name}() takes {old_nargs} positional '
                            'arguments but {len_args} were given'
                            ''.format(name=func.__name__,
                                      old_nargs=old_nargs,
                                      len_args=len(args)))
        if len(args) > new_nargs:
            # Move the deprecated positional arguments into kwargs.
            for key, value in zip(old_arg_names[new_nargs:len(args)],
                                  args[new_nargs:]):
                if key in kwargs:
                    calling_function = inspect.stack()[1]
                    s = SyntaxError(
                        "In version {version} of {library_name}, the "
                        "argument ('{key}') will become a keyword only "
                        "argument. You specified it as both a positional "
                        "argument and a keyword argument."
                        "".format(version=version,
                                  library_name=library_name,
                                  key=key))
                    s.lineno = calling_function.lineno
                    s.filename = calling_function.filename
                    # Even the normal syntax errors suck at telling you
                    # the position of your error when a statement spans
                    # multiple lines
                    # s.offset =  # is it worth it to find where?
                    raise s
                kwargs[key] = value
            warn("In version {version} of {library_name}, the "
                 "argument(s): '{old_pos_args}' will become keyword-only "
                 "argument(s). To suppress this warning, specify all "
                 "listed argument(s) with keywords."
                 "".format(
                     version=version,
                     library_name=library_name,
                     old_pos_args=old_arg_names[new_nargs:len(args)]),
                 FutureWarning,
                 stacklevel=2)
            args = args[:new_nargs]
        return func(*args, **kwargs)

    if keep_old_signature:
        wrapper.__signature__ = old_signature

    # only add a docstring if they had one already
    if wrapper.__doc__ is None:
        return wrapper

    warnings_string = """
    Warns
    -----
    FutureWarning
        In release {version} of {module}, the argument(s):
        `{args}` will become keyword-only arguments. To avoid this warning,
        provide all the above arguments as keyword arguments.
    """.format(version=version,
               module=library_name,
               args=', '.join(old_arg_names[new_nargs:]))
    wrapper.__doc__ = merge_docstrings(wrapper, warnings_string)

    return wrapper
def replace_parameter(
    param: inspect.Parameter,
    converter: Any,
    callback: Callable[..., Any],
    original: Parameter,
    mapping: Dict[str, inspect.Parameter],
) -> inspect.Parameter:
    """Adapt ``param`` so that the app-command layer can consume it.

    If ``converter`` is already an annotation that
    ``app_commands.transformers.get_supported_annotation`` accepts, ``param``
    is returned untouched. Otherwise its annotation is rewritten for Range,
    Greedy, converter, ``Optional[converter]`` and one-argument callable
    converters. Where no rewrite applies, the ``TypeError`` raised by
    ``get_supported_annotation`` is re-raised.

    A flag converter is expanded instead: every flag becomes a pseudo
    parameter stored in ``mapping``. ``app_commands.describe`` and
    ``app_commands.rename`` are then applied to ``callback`` if any flag has a
    description or a renamed attribute.

    Raises:
        TypeError: unsupported annotation, ``Greedy[discord.Attachment]``, or a
            flag whose name is already present in ``mapping``.
    """
    try:
        # Supported annotations (i.e. transformers) pass through as-is.
        app_commands.transformers.get_supported_annotation(converter)
    except TypeError:
        # Fallback: inspect the annotation to see whether it can be adapted.
        type_origin = getattr(converter, '__origin__', None)
        type_args = getattr(converter, '__args__', [])

        if isinstance(converter, Range):
            param = param.replace(
                annotation=app_commands.Range[converter.annotation, converter.min, converter.max])  # type: ignore
        elif isinstance(converter, Greedy):
            # Greedy is "optional" in ext.commands, but here it is treated as
            # required; there is currently no way for the user to choose.
            wrapped = converter.converter
            if wrapped is discord.Attachment:
                raise TypeError(
                    'discord.Attachment with Greedy is not supported in hybrid commands'
                )
            param = param.replace(
                annotation=make_greedy_transformer(wrapped, original))
        elif is_flag(converter):
            callback.__hybrid_command_flag__ = (param.name, converter)
            flag_descriptions = {}
            flag_renames = {}
            for flag in converter.__commands_flags__.values():
                attr_name = flag.attribute
                if flag.default is MISSING:
                    flag_default = inspect.Parameter.empty
                else:
                    flag_default = flag.default
                flag_param = inspect.Parameter(
                    name=attr_name,
                    kind=param.kind,
                    default=flag_default,
                    annotation=flag.annotation,
                )
                pseudo = replace_parameter(flag_param, flag.annotation,
                                           callback, original, mapping)
                if attr_name in mapping:
                    raise TypeError(
                        f'{attr_name!r} flag would shadow a pre-existing parameter')
                if flag.description is not MISSING:
                    flag_descriptions[attr_name] = flag.description
                if flag.name != flag.attribute:
                    flag_renames[attr_name] = flag.name
                mapping[attr_name] = pseudo

            # Apply the app_commands decorators by hand.
            if flag_descriptions:
                app_commands.describe(**flag_descriptions)(callback)
            if flag_renames:
                app_commands.rename(**flag_renames)(callback)
        elif is_converter(converter):
            param = param.replace(
                annotation=make_converter_transformer(converter))
        elif type_origin is Union:
            is_optional = len(type_args) == 2 and type_args[-1] is _NoneType
            if not is_optional:
                raise
            # Optional[X] where X is a single type that may be a converter.
            inner = type_args[0]
            inner_is_transformer = is_transformer(inner)
            if is_converter(inner) and not inner_is_transformer:
                param = param.replace(annotation=Optional[
                    make_converter_transformer(inner)])  # type: ignore
        elif type_origin:
            # Unsupported typing.X annotation e.g. typing.Dict, typing.Tuple, typing.List, etc.
            raise
        elif callable(converter) and not inspect.isclass(converter):
            if required_pos_arguments(converter) != 1:
                raise
            param = param.replace(
                annotation=make_callable_transformer(converter))
    return param
def RaiseMissingArguement():
    """Unconditionally raise ``commands.MissingRequiredArgument``.

    The exception wraps a synthetic positional-only ``inspect.Parameter``
    named ``"startdate"``. This function never returns normally.
    """
    startdate_param = inspect.Parameter(
        name="startdate",
        kind=inspect.Parameter.POSITIONAL_ONLY,
    )
    raise commands.MissingRequiredArgument(startdate_param)
def decorator(obj):
    """Outer wrapper.

    The outer wrapper is used to create the decorating wrapper, which
    translates or drops the deprecated keyword arguments listed in
    ``arg_pairs`` before calling ``obj``. When ``__debug__`` is false,
    ``obj`` is returned undecorated.

    @param obj: function being wrapped
    @type obj: object
    @return: the decorating wrapper (or obj itself when not __debug__)
    @rtype: callable
    """
    def wrapper(*__args, **__kw):
        """Replacement function.

        For every deprecated name present in the keyword arguments:
        - if it maps to a new argument name, it is renamed (with a warning),
          unless the new name was also given, in which case the old value
          is discarded (also with a warning);
        - if it maps to True/None, it is dropped with a debug message;
        - if it maps to False, it is dropped silently.

        @param __args: args passed to the decorated function
        @type __args: list
        @param __kw: kwargs passed to the decorated function
        @type __kw: dict
        @return: the value returned by the decorated function
        @rtype: any
        """
        name = obj.__full_name__
        for old_arg, new_arg in arg_pairs.items():
            if old_arg not in __kw:
                continue
            # Explicit mapping instead of ``% locals()`` so the messages do
            # not silently depend on local variable names.
            fmt_args = {'name': name, 'old_arg': old_arg, 'new_arg': new_arg}
            if new_arg not in (True, False, None):
                if new_arg in __kw:
                    warning(u"%(new_arg)s argument of %(name)s "
                            u"replaces %(old_arg)s; cannot use both."
                            % fmt_args)
                else:
                    # If the value is positionally given this will
                    # cause a TypeError, which is intentional
                    warning(u"%(old_arg)s argument of %(name)s "
                            u"is deprecated; use %(new_arg)s instead."
                            % fmt_args)
                    __kw[new_arg] = __kw[old_arg]
            elif new_arg is not False:
                debug(u"%(old_arg)s argument of %(name)s is "
                      u"deprecated." % fmt_args, _logger)
            del __kw[old_arg]
        return obj(*__args, **__kw)

    if not __debug__:
        return obj

    wrapper.__doc__ = obj.__doc__
    wrapper.__name__ = obj.__name__
    wrapper.__module__ = obj.__module__
    wrapper.__signature__ = signature(obj)
    if wrapper.__signature__:
        # Build a new signature with deprecated args added.
        params = collections.OrderedDict()
        for param in wrapper.__signature__.parameters.values():
            params[param.name] = param.replace()
        for old_arg, new_arg in arg_pairs.items():
            if new_arg not in (True, False, None):
                default = '[deprecated name of ' + new_arg + ']'
            else:
                default = NotImplemented
            params[old_arg] = inspect.Parameter(
                old_arg,
                kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
                default=default)
        # The deprecated args are appended after all existing parameters,
        # which can produce an ordering that the inspect.Signature
        # constructor rejects (e.g. after **kwargs); presumably that is why
        # the private ``_parameters`` attribute is assigned directly.
        wrapper.__signature__ = inspect.Signature()
        wrapper.__signature__._parameters = params
    if not hasattr(obj, '__full_name__'):
        add_decorated_full_name(obj)
    wrapper.__full_name__ = obj.__full_name__
    return wrapper
output = np.ones(output_shape) * np.nan return output # Standard docstring / signature entries for `axis` and `nan_policy` _name = 'axis' _type = "int or None, default: 0" _desc = ( """If an int, the axis of the input along which to compute the statistic. The statistic of each axis-slice (e.g. row) of the input will appear in a corresponding element of the output. If ``None``, the input will be raveled before computing the statistic.""". split('\n')) _axis_parameter_doc = Parameter(_name, _type, _desc) _axis_parameter = inspect.Parameter(_name, inspect.Parameter.KEYWORD_ONLY, default=0) _name = 'nan_policy' _type = "{'propagate', 'omit', 'raise'}" _desc = ("""Defines how to handle input NaNs. - ``propagate``: if a NaN is present in the axis slice (e.g. row) along which the statistic is computed, the corresponding entry of the output will be NaN. - ``omit``: NaNs will be omitted when performing the calculation. If insufficient data remains in the axis slice along which the statistic is computed, the corresponding entry of the output will be NaN. - ``raise``: if a NaN is present, a ``ValueError`` will be raised.""".split( '\n'))
else: raise TypeError return deviation or 0, expected or 0 def call_predicate(self, item): _, diff = item # Unpack item (discarding key). try: deviation, _ = self._get_deviation_expected(diff) except TypeError: return False # <- EXIT! return self.lower <= deviation <= self.upper with contextlib.suppress(AttributeError): # inspect.Signature() is new in 3.3 AcceptedTolerance.__init__.__signature__ = inspect.Signature([ inspect.Parameter('self', inspect.Parameter.POSITIONAL_ONLY), inspect.Parameter('tolerance', inspect.Parameter.POSITIONAL_ONLY), inspect.Parameter('msg', inspect.Parameter.POSITIONAL_OR_KEYWORD, default=None), ]) class AcceptedPercent(AcceptedTolerance): """AcceptedPercent(tolerance, /, msg=None) AcceptedPercent(lower, upper, msg=None) Context manager that accepts Deviations within a given percent tolerance without triggering a test failure. See documentation for full details. """ def call_predicate(self, item):
def decorator(func):
    """Wrap ``func`` as a pytest test comparing ``fig_test`` with ``fig_ref``.

    ``func`` must accept ``fig_test`` and ``fig_ref`` parameters. The
    returned wrapper is parametrized over ``extensions`` through ``ext``. For
    each extension it:
    - creates the two figures and lets ``func`` draw on them;
    - fails if ``func`` left additional figures open;
    - saves both figures into the result directory;
    - compares them with ``_raise_on_image_difference`` using ``tol``.

    The wrapper's signature drops ``fig_test``/``fig_ref`` and adds
    keyword-only ``ext`` and ``request`` when ``func`` does not declare them.
    ``func``'s pytest marks are hoisted onto the wrapper.

    Raises
    ------
    ValueError
        If ``func`` lacks a ``fig_test`` or ``fig_ref`` parameter.
    """
    import pytest

    _, result_dir = _image_directories(func)
    old_sig = inspect.signature(func)

    if not {"fig_test", "fig_ref"}.issubset(old_sig.parameters):
        raise ValueError("The decorated function must have at least the "
                         "parameters 'fig_ref' and 'fig_test', but your "
                         f"function has the signature {old_sig}")

    @pytest.mark.parametrize("ext", extensions)
    def wrapper(*args, ext, request, **kwargs):
        # Forward the fixtures only if ``func`` asked for them.
        if 'ext' in old_sig.parameters:
            kwargs['ext'] = ext
        if 'request' in old_sig.parameters:
            kwargs['request'] = request

        file_name = "".join(c for c in request.node.name
                            if c in ALLOWED_CHARS)
        # Pre-bind both names so the ``finally`` clause cannot raise
        # UnboundLocalError (masking the real exception) when creating a
        # figure fails.
        fig_test = fig_ref = None
        try:
            fig_test = plt.figure("test")
            fig_ref = plt.figure("reference")
            # Keep track of number of open figures, to make sure test
            # doesn't create any new ones
            n_figs = len(plt.get_fignums())
            func(*args, fig_test=fig_test, fig_ref=fig_ref, **kwargs)
            if len(plt.get_fignums()) > n_figs:
                raise RuntimeError('Number of open figures changed during '
                                   'test. Make sure you are plotting to '
                                   'fig_test or fig_ref, or if this is '
                                   'deliberate explicitly close the '
                                   'new figure(s) inside the test.')
            test_image_path = result_dir / (file_name + "." + ext)
            ref_image_path = result_dir / (file_name + "-expected." + ext)
            fig_test.savefig(test_image_path)
            fig_ref.savefig(ref_image_path)
            _raise_on_image_difference(
                ref_image_path, test_image_path, tol=tol
            )
        finally:
            # ``plt.close(None)`` would close whatever figure is current, so
            # only close the figures that were actually created.
            for fig in (fig_test, fig_ref):
                if fig is not None:
                    plt.close(fig)

    parameters = [
        param
        for param in old_sig.parameters.values()
        if param.name not in {"fig_test", "fig_ref"}
    ]
    if 'ext' not in old_sig.parameters:
        parameters += [inspect.Parameter("ext", KEYWORD_ONLY)]
    if 'request' not in old_sig.parameters:
        parameters += [inspect.Parameter("request", KEYWORD_ONLY)]
    new_sig = old_sig.replace(parameters=parameters)
    wrapper.__signature__ = new_sig

    # reach a bit into pytest internals to hoist the marks from
    # our wrapped function
    new_marks = getattr(func, "pytestmark", []) + wrapper.pytestmark
    wrapper.pytestmark = new_marks

    return wrapper
def decomposition_decorator(f: Callable) -> Callable: sig = inspect.signature(f) out_annotation = f.__annotations__.get("out") # Hack to detect when out is a Tuple. There seems to be no pretty way of doing this fn = f if out_annotation and getattr(out_annotation, "__origin__", None) is tuple: out_names = sig.return_annotation._fields # If out is a tuple, we need to register a function that unpacks all the out # elements as this is what native_functions.yaml expects @wraps(f) def _fn(*args, **kwargs): out_kwargs = tuple(kwargs.pop(o, None) for o in out_names) # Either all of the out kwargs are set or none of them is_none = out_kwargs[0] is None assert all((o is None) == is_none for o in out_kwargs) return f(*args, **kwargs, out=None if is_none else out_kwargs) out_params = [ inspect.Parameter( o, kind=inspect.Parameter.KEYWORD_ONLY, default=None, annotation=t, ) for o, t in zip(out_names, out_annotation.__args__) ] # Drop the out parameter and concatenate the new kwargs in the signature params = chain( (v for k, v in sig.parameters.items() if k != "out"), out_params ) _fn.__signature__ = inspect.Signature( # type: ignore[attr-defined] parameters=params, return_annotation=sig.return_annotation # type: ignore[arg-type] ) # Drop the out parameter and concatenate the new kwargs in the annotations _fn.__annotations__ = { k: v for k, v in f.__annotations__.items() if k != "out" } for o in out_params: _fn.__annotations__[o.name] = o.annotation fn = _fn nonlocal registry if registry is None: registry = decomposition_table def add_op_to_table(aten_op): overloads = [] if isinstance(aten_op, torch._ops.OpOverload): overloads.append(aten_op) else: assert isinstance(aten_op, torch._ops.OpOverloadPacket) for ol in aten_op.overloads(): overloads.append(getattr(aten_op, ol)) for op_overload in overloads: if op_overload in registry: raise RuntimeError(f"duplicate registrations for {op_overload}") registry[op_overload] = fn # TODO: factor this logic into OpOverload or Library API 
name = op_overload._schema.name if op_overload._schema.overload_name: name += "." + op_overload._schema.overload_name if ( not disable_meta # TorchScript dumps a bunch of extra nonsense overloads # which don't have corresponding dispatcher entries, we need # to filter those out and torch._C._dispatch_has_kernel(name) # Don't register a meta kernel to any operator that has # a CompositeImplicitAutograd kernel in core. # Otherwise we won't be able to run autograd for that operator with the meta backend. and "CompositeImplicitAutograd" not in torch._C._dispatch_dump(name) and not torch._C._dispatch_has_kernel_for_dispatch_key(name, "Meta") ): meta_lib.impl(op_overload, fn) # To handle allowing multiple aten_ops at once tree_map(add_op_to_table, aten_op) return fn