示例#1
0
文件: checks.py 项目: zhyhou/PlasmaPy
    def _get_unit_checks(
        self, bound_args: inspect.BoundArguments
    ) -> Dict[str, Dict[str, Any]]:
        """
        Review :attr:`checks` and function bound arguments to build a complete 'checks'
        dictionary.  If a check key is omitted from the argument checks, then a default
        value is assumed (see `check units`_)

        Every parameter of the wrapped function, plus a synthetic
        ``"checks_on_return"`` entry standing in for the return value, is
        examined.  ``*args`` / ``**kwargs`` parameters are never checked, and
        parameters with neither decorator checks nor unit annotations are
        skipped.  Each entry of the returned dictionary holds the keys
        ``"units"``, ``"none_shall_pass"``, ``"equivalencies"`` and
        ``"pass_equivalent_units"``.

        Parameters
        ----------
        bound_args: :class:`inspect.BoundArguments`
            arguments passed into the function being wrapped

            .. code-block:: python

                bound_args = inspect.signature(f).bind(*args, **kwargs)

        Returns
        -------
        Dict[str, Dict[str, Any]]
            A complete 'checks' dictionary for checking function input arguments
            and return.

        Raises
        ------
        ValueError
            If checks are given for an argument (or the return value) but no
            units can be determined for it, or if the annotation units are not
            all included in the units given through the decorator.

        Warns
        -----
        PlasmaPyWarning
            If :attr:`checks` names parameters that did not end up in the
            returned dictionary.
        """
        # initialize validation dictionary
        out_checks = {}

        # Iterate through function bound arguments + return and build `out_checks`:
        #
        # artificially add "return" to parameters
        # (the return value is modeled as a pseudo-parameter so it flows through
        # the same logic; its "annotation" is the function's return annotation)
        things_to_check = bound_args.signature.parameters.copy()
        things_to_check["checks_on_return"] = inspect.Parameter(
            "checks_on_return",
            inspect.Parameter.POSITIONAL_ONLY,
            annotation=bound_args.signature.return_annotation,
        )
        for param in things_to_check.values():
            # variable arguments are NOT checked
            # e.g. in foo(x, y, *args, d=None, **kwargs) variable arguments
            #      *args and **kwargs will NOT be checked
            #
            if param.kind in (
                inspect.Parameter.VAR_KEYWORD,
                inspect.Parameter.VAR_POSITIONAL,
            ):
                continue

            # grab the checks dictionary for the desired parameter
            # (None means the decorator specified nothing for this parameter)
            try:
                param_checks = self.checks[param.name]
            except KeyError:
                param_checks = None

            # -- Determine target units `_units` --
            # target units can be defined in one of three ways (in
            # preferential order):
            #   1. direct keyword pass-through
            #      i.e. CheckUnits(x=u.cm)
            #           CheckUnits(x=[u.cm, u.s])
            #   2. keyword pass-through via dictionary definition
            #      i.e. CheckUnits(x={'units': u.cm})
            #           CheckUnits(x={'units': [u.cm, u.s]})
            #   3. function annotations
            #
            # * if option (3) is used simultaneously with option (1) or (2), then
            #   checks defined by (3) must be consistent with checks from (1) or (2)
            #   to avoid raising an error.
            # * if None is included in the units list, then None values are allowed
            #
            _none_shall_pass = False
            _units = None
            _units_are_from_anno = False
            if param_checks is not None:
                # checks for argument were defined with decorator
                try:
                    _units = param_checks["units"]
                except TypeError:
                    # if checks is NOT None and is NOT a dictionary, then assume
                    # only units were specified
                    #   e.g. CheckUnits(x=u.cm)
                    #
                    _units = param_checks
                except KeyError:
                    # if checks does NOT have 'units' but is still a dictionary,
                    # then other check conditions may have been specified and the
                    # user is relying on function annotations to define desired
                    # units
                    _units = None

            # If no units have been specified by decorator checks, then look for
            # function annotations.
            #
            # Reconcile units specified by decorator checks and function annotations
            _units_anno = None
            if param.annotation is not inspect.Parameter.empty:
                # unit annotations defined
                _units_anno = param.annotation

            if _units is None and _units_anno is None and param_checks is None:
                # no checks specified and no unit annotations defined
                continue
            elif _units is None and _units_anno is None:
                # checks specified, but NO unit checks
                msg = f"No astropy.units specified for "
                if param.name == "checks_on_return":
                    msg += f"return value "
                else:
                    msg += f"argument {param.name} "
                msg += f"of function {self.f.__name__}()."
                raise ValueError(msg)
            elif _units is None:
                # only annotations give units: promote them to the target units
                # and clear `_units_anno` so the consistency check below is a
                # no-op
                _units = _units_anno
                _units_are_from_anno = True
                _units_anno = None

            # Ensure `_units` is an iterable
            # NOTE(review): a str is itself Iterable, so a bare string spec
            # (e.g. "cm") would not be wrapped here and `None in "cm"` below
            # raises TypeError -- confirm whether bare-string unit specs are
            # meant to be supported.
            if not isinstance(_units, collections.abc.Iterable):
                _units = [_units]
            if not isinstance(_units_anno, collections.abc.Iterable):
                _units_anno = [_units_anno]

            # Is None allowed?
            # (either explicitly listed as a unit, or the parameter defaults to None)
            if None in _units or param.default is None:
                _none_shall_pass = True

            # Remove Nones
            if None in _units:
                _units = [t for t in _units if t is not None]
            if None in _units_anno:
                _units_anno = [t for t in _units_anno if t is not None]

            # ensure all _units are astropy.units.Unit or physical types &
            # define 'units' for unit checks &
            # define 'none_shall_pass' check
            _units = self._condition_target_units(
                _units, from_annotations=_units_are_from_anno
            )
            _units_anno = self._condition_target_units(
                _units_anno, from_annotations=True
            )
            # every annotation unit must also appear among the decorator units
            if not all(_u in _units for _u in _units_anno):
                raise ValueError(
                    f"For argument '{param.name}', "
                    f"annotation units ({_units_anno}) are not included in the units "
                    f"specified by decorator arguments ({_units}).  Use either "
                    f"decorator arguments or function annotations to defined unit "
                    f"types, or make sure annotation specifications match decorator "
                    f"argument specifications."
                )
            if len(_units) == 0 and len(_units_anno) == 0 and param_checks is None:
                # annotations did not specify units
                continue
            elif len(_units) == 0 and len(_units_anno) == 0:
                # checks specified, but NO unit checks
                msg = f"No astropy.units specified for "
                if param.name == "checks_on_return":
                    msg += f"return value "
                else:
                    msg += f"argument {param.name} "
                msg += f"of function {self.f.__name__}()."
                raise ValueError(msg)

            out_checks[param.name] = {
                "units": _units,
                "none_shall_pass": _none_shall_pass,
            }

            # -- Determine target equivalencies --
            # Unit equivalences can be defined by:
            # 1. keyword pass-through via dictionary definition
            #    e.g. CheckUnits(x={'units': u.C,
            #                       'equivalencies': u.temperature})
            #
            # initialize equivalencies
            # (TypeError covers param_checks being None or a non-dict unit spec;
            # `self.__check_defaults` is name-mangled to the enclosing class)
            try:
                _equivs = param_checks["equivalencies"]
            except (KeyError, TypeError):
                _equivs = self.__check_defaults["equivalencies"]

            # ensure equivalences are properly formatted
            if _equivs is None or _equivs == [None]:
                _equivs = None
            elif isinstance(_equivs, Equivalency):
                pass
            elif isinstance(_equivs, (list, tuple)):

                # flatten list to non-list elements
                # (a lone tuple is treated as a single equivalency and wrapped)
                if isinstance(_equivs, tuple):
                    _equivs = [_equivs]
                else:
                    _equivs = self._flatten_equivalencies_list(_equivs)

                # ensure passed equivalencies list is structured properly
                #   [(), ...]
                #   or [Equivalency(), ...]
                #
                # * All equivalencies must be a list of 2, 3, or 4 element tuples
                #   structured like...
                #     (from_unit, to_unit, forward_func, backward_func)
                #
                if all(isinstance(el, Equivalency) for el in _equivs):
                    # combine all Equivalency objects into one with `+`
                    _equivs = reduce(lambda x, y: x + y, _equivs)
                else:
                    _equivs = self._normalize_equivalencies(_equivs)

            out_checks[param.name]["equivalencies"] = _equivs

            # -- Determine if equivalent units pass --
            # (AttributeError/TypeError: param_checks is None or has no `.get`)
            try:
                peu = param_checks.get(
                    "pass_equivalent_units",
                    self.__check_defaults["pass_equivalent_units"],
                )
            except (AttributeError, TypeError):
                peu = self.__check_defaults["pass_equivalent_units"]

            out_checks[param.name]["pass_equivalent_units"] = peu

        # Does `self.checks` indicate arguments not used by f?
        # (built from a set difference, so the order of names in the warning is
        # arbitrary)
        missing_params = [
            param for param in set(self.checks.keys()) - set(out_checks.keys())
        ]
        if len(missing_params) > 0:
            params_str = ", ".join(missing_params)
            warnings.warn(
                PlasmaPyWarning(
                    f"Expected to unit check parameters {params_str} but they "
                    f"are missing from the call to {self.f.__name__}"
                )
            )

        return out_checks
示例#2
0
# Documentation record and signature parameter for the shared ``nan_policy``
# keyword; `_desc` is stored as a list of lines.
_name = 'nan_policy'
_type = "{'propagate', 'omit', 'raise'}"
_desc = """Defines how to handle input NaNs.

- ``propagate``: if a NaN is present in the axis slice (e.g. row) along
  which the  statistic is computed, the corresponding entry of the output
  will be NaN.
- ``omit``: NaNs will be omitted when performing the calculation.
  If insufficient data remains in the axis slice along which the
  statistic is computed, the corresponding entry of the output will be
  NaN.
- ``raise``: if a NaN is present, a ``ValueError`` will be raised.""".splitlines()
_nan_policy_parameter_doc = Parameter(_name, _type, _desc)
_nan_policy_parameter = inspect.Parameter(
    name=_name,
    kind=inspect.Parameter.KEYWORD_ONLY,
    default='propagate',
)

# Documentation record and signature parameter for the shared ``keepdims``
# keyword; `_desc` is stored as a list of lines.
_name = 'keepdims'
_type = "bool, default: False"
_desc = """If this is set to True, the axes which are reduced are left
in the result as dimensions with size one. With this option,
the result will broadcast correctly against the input array.""".splitlines()
_keepdims_parameter_doc = Parameter(_name, _type, _desc)
_keepdims_parameter = inspect.Parameter(
    name=_name,
    kind=inspect.Parameter.KEYWORD_ONLY,
    default=False,
)

_standard_note_addition = (
    """\nBeginning in SciPy 1.9, ``np.matrix`` inputs (not recommended for new
code) are converted to ``np.ndarray`` before the calculation is performed. In
示例#3
0
文件: checks.py 项目: zhyhou/PlasmaPy
    def _get_value_checks(
        self, bound_args: inspect.BoundArguments
    ) -> Dict[str, Dict[str, bool]]:
        """
        Assemble the full value-'checks' dictionary for the wrapped function.

        Each parameter of the wrapped function that has an entry in
        :attr:`checks` is included, as is a synthetic ``"checks_on_return"``
        entry representing the return value.  Every included entry receives
        all value-check keys.  Keys the user did not specify fall back to their
        defaults (see `check values`_).  ``*args`` / ``**kwargs`` parameters
        are never checked.

        Parameters
        ----------
        bound_args: :class:`inspect.BoundArguments`
            arguments passed into the function being wrapped

            .. code-block:: python

                bound_args = inspect.signature(f).bind(*args, **kwargs)

        Returns
        -------
        Dict[str, Dict[str, bool]]
            A complete 'checks' dictionary for checking function input arguments
            and return.
        """

        def _lookup(spec, key, fallback):
            # `spec` is normally a dict.  Anything without a `.get` contributes
            # nothing, so the default is used.  An example is a bare unit given
            # as CheckValues(x=u.cm), which can happen during subclassing.
            try:
                return spec.get(key, fallback)
            except AttributeError:
                return fallback

        # Signature parameters plus a pseudo-parameter standing in for the return.
        candidates = bound_args.signature.parameters.copy()
        candidates["checks_on_return"] = inspect.Parameter(
            "checks_on_return",
            inspect.Parameter.POSITIONAL_ONLY,
            annotation=bound_args.signature.return_annotation,
        )

        variadic_kinds = (
            inspect.Parameter.VAR_KEYWORD,
            inspect.Parameter.VAR_POSITIONAL,
        )
        out_checks = {}
        for candidate in candidates.values():
            if candidate.kind in variadic_kinds:
                continue

            try:
                spec = self.checks[candidate.name]
            except KeyError:
                # no value checks requested for this parameter
                continue

            out_checks[candidate.name] = {
                key: _lookup(spec, key, default)
                for key, default in self.__check_defaults.items()
            }

        # Warn about checks naming parameters the function never consumed.
        unused = list(set(self.checks.keys()) - set(out_checks.keys()))
        if unused:
            joined = ", ".join(unused)
            warnings.warn(
                PlasmaPyWarning(
                    f"Expected to value check parameters {joined} but they "
                    f"are missing from the call to {self.f.__name__}"
                )
            )

        return out_checks
示例#4
0
        if not isinstance(file, file_types):
            if PY2:
                csvfile = open(file, 'wb')
            else:
                csvfile = open(file, 'w', newline='')
            autoclose = True
        else:
            csvfile = file
            autoclose = False

        try:
            writer = csv.writer(csvfile, **fmtparams)

            for row in reader:
                if nonstringiter(row):
                    writer.writerow(row)
                else:
                    writer.writerow([row])
        finally:
            if autoclose:
                csvfile.close()


# Give BaseQuery.__init__ an explicit introspectable signature:
# (self, columns, /, **where).  inspect.Signature() only exists on Python 3.3+,
# so on older interpreters the resulting AttributeError is ignored.
try:
    BaseQuery.__init__.__signature__ = inspect.Signature([
        inspect.Parameter(param_name, param_kind)
        for param_name, param_kind in [
            ('self', inspect.Parameter.POSITIONAL_ONLY),
            ('columns', inspect.Parameter.POSITIONAL_ONLY),
            ('where', inspect.Parameter.VAR_KEYWORD),
        ]
    ])
except AttributeError:
    pass
示例#5
0
def angular_freq_to_hz(fn):
    """
    A decorator that adds to a function the ability to convert the function's return from
    angular frequency (rad/s) to frequency (Hz).

    A keyword-only kwarg `to_hz` is added to the function's signature, with a default
    value of `False`.  The keyword is also added to the function's docstring under the
    **"Other Parameters"** heading.  The decorated function itself is not modified; the
    extended signature is attached to the returned wrapper.

    Parameters
    ----------
    fn : function
        The function to be decorated

    Raises
    ------
    ValueError
        If `fn` has already defined a kwarg `to_hz`

    Returns
    -------
    callable
        The decorated function

    Notes
    -----
    * If `angular_freq_to_hz` is used with decorator
      :func:`~plasmapy.utils.decorators.validate_quantities`, then
      `angular_freq_to_hz` should be used inside
      :func:`~plasmapy.utils.decorators.validate_quantities` but special
      consideration is needed for setup.  The following is an example of an
      appropriate setup::

        import astropy.units as u
        from plasmapy.utils.decorators.converter import angular_freq_to_hz
        from plasmapy.utils.decorators.validators import validate_quantities

        @validate_quantities(validations_on_return={'units': [u.rad / u.s, u.Hz]})
        @angular_freq_to_hz
        def foo(x: u.rad / u.s) -> u.rad / u.s:
            return x

      Adding `u.Hz` to the allowed units allows the converted quantity to pass
      the validations.

    Examples
    --------

        >>> import astropy.units as u
        >>> from plasmapy.utils.decorators.converter import angular_freq_to_hz
        >>>
        >>> @angular_freq_to_hz
        ... def foo(x):
        ...     return x
        >>>
        >>> foo(5 * u.rad / u.s, to_hz=True)
        <Quantity 0.79577472 Hz>
        >>>
        >>> foo(-1 * u.rad / u.s, to_hz=True)
        <Quantity -0.15915494 Hz>

    Decoration also works with methods

        >>> class Foo:
        ...     def __init__(self, x):
        ...         self.x = x
        ...
        ...     @angular_freq_to_hz
        ...     def bar(self):
        ...         return self.x
        >>>
        >>> foo = Foo(0.5 * u.rad / u.s)
        >>> foo.bar(to_hz=True)
        <Quantity 0.07957747 Hz>

    """
    # raise exception if fn uses the 'to_hz' kwarg
    sig = inspect.signature(fn)
    if "to_hz" in sig.parameters:
        raise ValueError(
            f"Wrapped function '{fn.__name__}' can not use keyword 'to_hz'."
            f" Keyword reserved for decorator functionality.")

    # Build the signature advertised by the wrapper: fn's parameters plus a
    # keyword-only ``to_hz``.  It must be KEYWORD_ONLY because ``wrapper`` only
    # accepts it by keyword (a positional value would land in ``*args``).  It
    # must also precede any ``**kwargs`` parameter, or ``inspect.Signature``
    # rejects the parameter order with a ValueError.
    to_hz_param = inspect.Parameter(
        "to_hz", inspect.Parameter.KEYWORD_ONLY, default=False)
    new_params = list(sig.parameters.values())
    if new_params and new_params[-1].kind is inspect.Parameter.VAR_KEYWORD:
        new_params.insert(-1, to_hz_param)
    else:
        new_params.append(to_hz_param)
    new_sig = sig.replace(parameters=new_params)

    @preserve_signature
    @functools.wraps(fn)
    def wrapper(*args, to_hz=False, **kwargs):
        _result = fn(*args, **kwargs)
        if to_hz:
            # rad/s -> cycles/s happens through astropy's angle conversion;
            # the equivalency then identifies cycles/s with Hz
            return _result.to(u.Hz, equivalencies=[(u.cy / u.s, u.Hz)])
        return _result

    # Attach the extended signature to the returned wrapper instead of mutating
    # the caller's ``fn``, which does not itself accept ``to_hz``.
    wrapper.__signature__ = new_sig

    added_doc_bit = """
    Other Parameters
    ----------------
    to_hz: bool
        Set `True` to convert function output from angular frequency to Hz
    """
    if wrapper.__doc__ is not None:
        wrapper.__doc__ += added_doc_bit
    else:
        wrapper.__doc__ = added_doc_bit

    return wrapper
示例#6
0
    def decomposition_decorator(f: Callable) -> Callable:
        """Register ``f`` as the decomposition of the enclosing ``aten_op``.

        If ``f`` has an ``out`` annotation whose ``__origin__`` is ``tuple``,
        ``f`` is first wrapped so that each output is passed as its own
        keyword-only argument.  Those arguments are named after the
        ``_fields`` of ``f``'s return annotation.  The resulting function is
        stored in ``registry`` under every overload of ``aten_op``.
        ``registry`` defaults to ``decomposition_table`` when it is None.
        The function is also registered through ``meta_lib.impl`` when all
        of the following hold:

        * ``disable_meta`` is falsy;
        * the dispatcher has a kernel for the overload;
        * no Meta kernel is already computed for it.

        ``registry``, ``aten_op``, ``disable_meta`` and ``meta_lib`` come from
        the enclosing scope.

        Returns the registered function (the unpacking wrapper, if one was
        created).

        Raises
        ------
        RuntimeError
            If an overload is already present in ``registry``, or if a Meta
            kernel would be registered for an operator with a non-write
            aliasing argument (a view operator).
        """
        sig = inspect.signature(f)
        out_annotation = f.__annotations__.get("out")
        # Hack to detect when out is a Tuple. There seems to be no pretty way of doing this
        fn = f
        if out_annotation and getattr(out_annotation, "__origin__",
                                      None) is tuple:
            # names of the individual outputs, read from the return annotation's
            # `_fields` (presumably a NamedTuple type -- confirm)
            out_names = sig.return_annotation._fields
            # If out is a tuple, we need to register a function that unpacks all the out
            # elements as this is what native_functions.yaml expects

            @wraps(f)
            def _fn(*args, **kwargs):
                out_kwargs = tuple(kwargs.pop(o, None) for o in out_names)
                # Either all of the out kwargs are set or none of them
                is_none = out_kwargs[0] is None
                assert all((o is None) == is_none for o in out_kwargs)
                return f(*args, **kwargs, out=None if is_none else out_kwargs)

            # one keyword-only parameter per output, typed by the matching
            # element of the `out` Tuple annotation
            out_params = [
                inspect.Parameter(
                    o,
                    kind=inspect.Parameter.KEYWORD_ONLY,
                    default=None,
                    annotation=t,
                ) for o, t in zip(out_names, out_annotation.__args__)
            ]
            # Drop the out parameter and concatenate the new kwargs in the signature
            params = chain(
                (v for k, v in sig.parameters.items() if k != "out"),
                out_params)
            _fn.__signature__ = inspect.Signature(  # type: ignore[attr-defined]
                parameters=params,
                return_annotation=sig.
                return_annotation  # type: ignore[arg-type]
            )
            # Drop the out parameter and concatenate the new kwargs in the annotations
            _fn.__annotations__ = {
                k: v
                for k, v in f.__annotations__.items() if k != "out"
            }
            for o in out_params:
                _fn.__annotations__[o.name] = o.annotation

            fn = _fn

        nonlocal registry
        if registry is None:
            registry = decomposition_table

        # Register `fn` for a single OpOverload, or for every overload of an
        # OpOverloadPacket.
        def add_op_to_table(aten_op):
            overloads = []
            if isinstance(aten_op, torch._ops.OpOverload):
                overloads.append(aten_op)
            else:
                assert isinstance(aten_op, torch._ops.OpOverloadPacket)
                for ol in aten_op.overloads():
                    overloads.append(getattr(aten_op, ol))
            for op_overload in overloads:
                if op_overload in registry:
                    raise RuntimeError(
                        f"duplicate registrations for {op_overload}")
                registry[op_overload] = fn
                # TODO: factor this logic into OpOverload or Library API
                # fully qualified "name.overload" string used by the dispatcher
                name = op_overload._schema.name
                if op_overload._schema.overload_name:
                    name += "." + op_overload._schema.overload_name
                if (not disable_meta
                        # TorchScript dumps a bunch of extra nonsense overloads
                        # which don't have corresponding dispatcher entries, we need
                        # to filter those out
                        and torch._C._dispatch_has_kernel(name)
                        # Don't register a python meta kernel to any operator that
                        # should already work with meta tensors today.
                        # We can check that by seeing if the "computed table" for the operator
                        # has a registration to Meta;
                        # either through a direct registration, or an indirect one through
                        # an alias dispatch key (e.g. CompositeImplicitAutograd)
                        and not torch._C.
                        _dispatch_has_computed_kernel_for_dispatch_key(
                            name, "Meta")):
                    # an argument that aliases its input without being written
                    # to marks a view operator
                    if any(a.alias_info is not None
                           and not a.alias_info.is_write
                           for a in op_overload._schema.arguments):
                        raise RuntimeError(f"""
Attempting to register a python meta kernel for a view operator: {str(op_overload)}.
We shouldn't do this, because the output will report as not having aliased storages.
All view ops have meta kernels in C++ today, so we should use those instead.

If you're registering an operator through the `@register_decomposition` decorator,
Please set `disable_meta=True`.
                        """)
                    meta_lib.impl(op_overload, fn)

        # To handle allowing multiple aten_ops at once
        tree_map(add_op_to_table, aten_op)
        return fn
示例#7
0
 def match(self, hint: object, injectable: Injectable,
           container: Container):
     """Resolve candidates for the first type argument of a generic ``hint``.

     For a parameterized hint such as ``List[Foo]`` the first element of
     ``hint.__args__`` (``Foo``) is passed to ``get_candidates``.  When the
     hint carries no type arguments, ``None`` is passed instead.
     """
     # ``__args__`` may be missing, ``None`` (bare generics on some Python
     # versions) or empty.  All three mean "no type argument", so fall back to
     # None instead of raising TypeError or IndexError.
     type_args = getattr(hint, '__args__', None) or (None,)
     return get_candidates(type_args[0], container=container)
示例#8
0
    def decorator_maker(tested_scene_construct):
        """Turn a graphical-test function into a pytest test that compares frames.

        ``tested_scene_construct`` must take the parameter named by
        ``SCENE_PARAMETER_NAME``, and its module must define ``__module_test__``.
        The returned wrapper advertises the signature of
        ``functools.partial(tested_scene_construct, scene=None)``, extended with
        the ``request`` and ``tmp_path`` pytest fixtures.  It also carries the
        wrapped function's pytest marks.

        Raises
        ------
        Exception
            If the scene parameter or ``__module_test__`` is missing.
        """
        if (
            SCENE_PARAMETER_NAME
            not in inspect.getfullargspec(tested_scene_construct).args
        ):
            raise Exception(
                f"Invalid graphical test function: must have '{SCENE_PARAMETER_NAME}' as one of the parameters.",
            )

        # Exclude "scene" from the argument list of the signature.
        old_sig = inspect.signature(
            functools.partial(tested_scene_construct, scene=None),
        )

        if "__module_test__" not in tested_scene_construct.__globals__:
            raise Exception(
                "There is no module test name indicated for the graphical unit test. You have to declare __module_test__ in the test file.",
            )
        module_name = tested_scene_construct.__globals__.get("__module_test__")
        test_name = tested_scene_construct.__name__[len("test_") :]

        @functools.wraps(tested_scene_construct)
        # The "request" parameter is meant to be used as a fixture by pytest. See below.
        def wrapper(*args, request: FixtureRequest, tmp_path, **kwargs):
            # Wraps the test_function to a construct method, to "freeze" the eventual additional arguments (parametrizations fixtures).
            construct = functools.partial(tested_scene_construct, *args, **kwargs)

            # Kwargs contains the eventual parametrization arguments.
            # This modifies the test_name so that it is defined by the parametrization
            # arguments too.
            # Example: if "length" is parametrized from 0 to 20, the kwargs
            # will be once with {"length" : 1}, etc.
            # (the resulting name is part of the control-data path, so its exact
            # format must stay stable)
            test_name_with_param = test_name + "_".join(
                f"_{str(tup[0])}[{str(tup[1])}]" for tup in kwargs.items()
            )

            config_tests = _config_test(last_frame)

            config_tests["text_dir"] = tmp_path
            config_tests["tex_dir"] = tmp_path

            if last_frame:
                config_tests["frame_rate"] = 1
                config_tests["dry_run"] = True

            setting_test = request.config.getoption("--set_test")
            # `__globals__` is the defining module's namespace dict; fall back
            # to None when it has no `__file__`.
            test_file_path = tested_scene_construct.__globals__.get("__file__")
            real_test = _make_test_comparing_frames(
                file_path=_control_data_path(
                    test_file_path,
                    module_name,
                    test_name_with_param,
                    setting_test,
                ),
                base_scene=base_scene,
                construct=construct,
                renderer_class=renderer_class,
                is_set_test_data_test=setting_test,
                last_frame=last_frame,
                show_diff=request.config.getoption("--show_diff"),
                size_frame=(config_tests["pixel_height"], config_tests["pixel_width"]),
            )

            # Isolate the config used for the test, to avoid modifying the global config during the test run.
            with tempconfig({**config_tests, **custom_config}):
                real_test()

        parameters = list(old_sig.parameters.values())
        # Adds "request" param into the signature of the wrapper, to use the associated pytest fixture.
        # This fixture is needed to have access to flags value and pytest's config. See above.
        if "request" not in old_sig.parameters:
            parameters += [inspect.Parameter("request", inspect.Parameter.KEYWORD_ONLY)]
        if "tmp_path" not in old_sig.parameters:
            parameters += [
                inspect.Parameter("tmp_path", inspect.Parameter.KEYWORD_ONLY),
            ]
        new_sig = old_sig.replace(parameters=parameters)
        wrapper.__signature__ = new_sig

        # Reach a bit into pytest internals to hoist the marks from our wrapped
        # function.
        wrapper.pytestmark = getattr(tested_scene_construct, "pytestmark", [])
        return wrapper
示例#9
0
def post_init(cls: Type[U]) -> Type[U]:
    """
    Class decorator to automatically support __post_init__() on classes

    This is useful for @attr.s decorated classes, because __attr_post_init__() doesn't
    support additional arguments.

    This decorator replaces the class ``__init__`` with a generated function that
    accepts the merged arguments of ``__init__`` and every ``__post_init__()``
    found along the MRO.  It dispatches them first to the original ``__init__``
    and then to each ``__post_init__()``.

    Raises
    ------
    TypeError
        If ``cls`` is not a class, has no ``__post_init__()``, or if a
        ``__post_init__()`` signature cannot be merged with the signatures
        collected so far.
    """
    if not isinstance(cls, type):
        raise TypeError("Can only decorate classes")
    if not hasattr(cls, "__post_init__"):
        raise TypeError("The class must have a __post_init__() method")
    # Ignore the first argument which is the "self" argument
    sig = init_sig = _sig_without(inspect.signature(cls.__init__), 0)
    # (owner class, method name, signature without self) of every method the
    # generated __init__ dispatches to, in call order.
    previous = [(cls, "__init__", sig)]
    # NOTE(review): hasattr() is also true for an *inherited* __post_init__,
    # so a class that merely inherits one is visited here too.  Confirm that
    # is intended; checking `"__post_init__" in vars(parent)` would only pick
    # up classes that define the method themselves.
    for parent in reversed(cls.__mro__):
        if hasattr(parent, "__post_init__"):
            post_sig = _sig_without(
                inspect.signature(getattr(parent, "__post_init__")), 0
            )
            try:
                sig = _sig_merge(sig, post_sig)
            except Exception as err:
                # Find the first previously collected signature that clashes
                # with this __post_init__ (distinct names keep the outer loop
                # variable untouched).
                for culprit, culprit_method, culprit_sig in previous:
                    try:
                        _sig_merge(culprit_sig, post_sig)
                    except Exception:
                        break
                else:
                    raise TypeError(
                        "__post_init__ signature is incompatible with the class"
                    ) from err
                raise TypeError(
                    f"__post_init__() is incompatible with "
                    f"{culprit.__qualname__}.{culprit_method}()"
                ) from err
            # No exception
            previous.append((parent, "__post_init__", post_sig))
    # handles type annotations and defaults
    # inspired by the dataclasses modules
    # (annotations and defaults are passed into the generated code as named
    # factory arguments instead of being rendered into the source text)
    params = list(sig.parameters.values())
    localns = (
        {
            f"__type_{p.name}": p.annotation
            for p in params
            if p.annotation is not inspect.Parameter.empty
        }
        | {
            f"__default_{p.name}": p.default
            for p in params
            if p.default is not inspect.Parameter.empty
        }
        | cls.__dict__
    )
    for i, p in enumerate(params):
        if p.default is not inspect.Parameter.empty:
            p = p.replace(default=Variable(f"__default_{p.name}"))
        if p.annotation is not inspect.Parameter.empty:
            p = p.replace(annotation=f"__type_{p.name}")
        params[i] = p
    new_sig = inspect.Signature(params)
    # Build the new __init__ source code
    # (avoid a clash when the merged signature already has a "self" parameter)
    self_ = "self" if "self" not in sig.parameters else "__post_init_self"
    init_lines = [
        f"def __init__({self_}, {_sig_to_def(new_sig)}) -> None:",
        f"__original_init({self_}, {_sig_to_call(init_sig)})",
    ]
    # Every entry after the first was recorded above only because its owner
    # has a __post_init__, so no re-check is needed here.
    # NOTE(review): `super(P, self).m` resolves `m` on the class *after* P in
    # type(self).__mro__, not on P itself.  Confirm this dispatches to the
    # intended __post_init__.
    for parent, method, psig in previous[1:]:
        if parent is not cls:
            init_lines.append(
                f"super({parent.__qualname__}, {self_}).{method}({_sig_to_call(psig)})"
            )
        else:
            init_lines.append(f"{self_}.{method}({_sig_to_call(psig)})")
    init_src = "\n  ".join(init_lines)
    # Build the factory function source code
    local_vars = ", ".join(localns.keys())
    factory_src = (
        f"def __make_init__(__original_init, {local_vars}):\n"
        f" {init_src}\n"
        " return __init__"
    )
    # Create new __init__ with the factory
    # (exec'd with the class's module globals so names in the generated code,
    # such as class qualnames, resolve as they would in that module)
    globalns = inspect.getmodule(cls).__dict__
    ns: dict[str, Any] = {}
    exec(factory_src, globalns, ns)
    init = ns["__make_init__"](cls.__init__, **localns)
    self_param = inspect.Parameter(self_, inspect.Parameter.POSITIONAL_ONLY)
    init.__signature__ = inspect.Signature(
        parameters=[self_param] + list(sig.parameters.values()), return_annotation=None
    )
    setattr(cls, "__init__", init)
    return cls
示例#10
0
def test_conv_str_choices_valid():
    """Converting an allowed value with ``str_choices`` returns it unchanged."""
    foo_param = inspect.Parameter('foo', inspect.Parameter.POSITIONAL_ONLY)
    allowed = ['val1', 'val2']
    result = argparser.type_conv(foo_param, str, 'val1', str_choices=allowed)
    assert result == 'val1'
示例#11
0
def test_conv_str_choices_invalid():
    """A value outside ``str_choices`` is rejected with ArgumentTypeError."""
    foo_param = inspect.Parameter('foo', inspect.Parameter.POSITIONAL_ONLY)
    expected_msg = 'foo: Invalid value val3 - expected one of: val1, val2'
    with pytest.raises(cmdexc.ArgumentTypeError, match=expected_msg):
        argparser.type_conv(foo_param, str, 'val3',
                            str_choices=['val1', 'val2'])
示例#12
0
def test_multitype_conv_invalid_type():
    """A multitype converter given an unknown type raises ValueError."""
    foo_param = inspect.Parameter('foo', inspect.Parameter.POSITIONAL_ONLY)
    bogus_types = [None]
    with pytest.raises(ValueError, match="foo: Unknown type None!"):
        argparser.multitype_conv(foo_param, bogus_types, '')
示例#13
0
 def _get_sig(cls):
     """Return a Signature with one positional-or-keyword parameter per name in ``cls.field_names``."""
     kind = inspect.Parameter.POSITIONAL_OR_KEYWORD
     params = [inspect.Parameter(field, kind) for field in cls.field_names]
     return inspect.Signature(params)
示例#14
0
 def _fields_from_attrs(kind: ParameterKind, attrs: Tuple[str, ...]):
     """Map every attribute name in *attrs* to an ``inspect.Parameter`` of the given *kind*."""
     fields = {}
     for attr_name in attrs:
         fields[attr_name] = inspect.Parameter(attr_name, kind)
     return fields
示例#15
0
文件: op_utils.py 项目: zu3st/chainer
def _create_test_entry_function(cls, module, devices):
    """Create a pytest entry function for an OpTest template class.

    The generated function is named ``test_<Suffix>`` for a class named
    ``Test<Suffix>``. It is parametrized with
    ``pytest.mark.parametrize_device(devices)`` and installed on *module*
    so that pytest collection finds it.

    For every device the entry function runs the forward, backward and
    double-backward tests. Each phase uses a fresh ``cls()`` instance. For a
    ``NumpyOpTest`` whose forward test succeeded with accepted errors, the
    backward phases are skipped.

    Args:
        cls: OpTest implementation class. Its name must start with ``Test``.
        module: Module object on which the entry function is set.
        devices: Device specification forwarded to
            ``pytest.mark.parametrize_device``.

    Raises:
        TypeError: If ``cls.__name__`` does not start with ``'Test'``.
    """
    # Creates a test entry function from the template class, and places it in
    # the same module as the class.

    # We enforce 'Test' prefix in OpTest implementations so that they look like
    # unittest.TestCase implementations. OTOH generated entry function must
    # have a prefix 'test_' in order for it to be found in pytest test
    # collection.
    if not cls.__name__.startswith('Test'):
        raise TypeError(
            'OpTest class name must start with \'Test\'. Actual: {!r}'.format(
                cls.__name__))

    func_name = 'test_{}'.format(cls.__name__[len('Test'):])

    @pytest.mark.parametrize_device(devices)
    def entry_func(device, *args, **kwargs):
        backend_config = _make_backend_config(device.name)

        def run_phase(test_method_name):
            # Each phase runs on a fresh instance. teardown() always runs,
            # even when setup() or the test itself raises.
            obj = cls()
            try:
                obj.setup(*args, **kwargs)
                getattr(obj, test_method_name)(backend_config)
            finally:
                obj.teardown()
            return obj

        # Forward test
        obj = run_phase('run_test_forward')

        # If this is a NumpyOpTest instance, skip backward/double-backward
        # tests if the forward test succeeds with acceptable errors.
        if (isinstance(obj, NumpyOpTest)
                and obj.is_forward_successful_with_accept_errors):
            return  # success with expected errors

        # Backward test
        run_phase('run_test_backward')

        # Double-backward test
        run_phase('run_test_double_backward')

    entry_func.__name__ = func_name

    # Set the signature of the entry function
    sig = inspect.signature(cls.setup)
    params = list(sig.parameters.values())
    params = params[1:]  # Remove `self` argument
    device_param = inspect.Parameter('device',
                                     inspect.Parameter.POSITIONAL_OR_KEYWORD)
    params = [device_param] + params  # Prepend `device` argument
    entry_func.__signature__ = inspect.Signature(params)

    # Hoist the class-level pytest marks onto the entry function.
    # ``pytestmark`` may be a single mark or a list of marks. The entry
    # function may also lack the attribute. Neither case should raise or
    # silently drop the class marks.
    class_marks = getattr(cls, 'pytestmark', None)
    if class_marks is not None:
        if isinstance(class_marks, (list, tuple)):
            class_marks = list(class_marks)
        else:
            class_marks = [class_marks]
        entry_func.pytestmark = (
            list(getattr(entry_func, 'pytestmark', [])) + class_marks)

    # Place the entry function in the module of the class
    setattr(module, func_name, entry_func)
示例#16
0
    def add(
        self,
        instruction: Union[str, Instruction],
        qubits: Union[int, Iterable[int]],
        schedule: Union[Schedule, ScheduleBlock,
                        Callable[..., Union[Schedule, ScheduleBlock]]],
        arguments: Optional[List[str]] = None,
    ) -> None:
        """Add a new known instruction for the given qubits and its mapping to a pulse schedule.

        Args:
            instruction: The name of the instruction to add.
            qubits: The qubits which the instruction applies to.
            schedule: The Schedule that implements the given instruction.
            arguments: List of parameter names to create a parameter-bound schedule from the
                associated gate instruction. If :py:meth:`get` is called with arguments rather
                than keyword arguments, this parameter list is used to map the input arguments to
                parameter objects stored in the target schedule.

        Raises:
            PulseError: If the qubits are provided as an empty iterable, if ``arguments``
                does not name exactly the schedule's parameters, or if ``schedule`` is not a
                supported schedule type or callable.
        """
        instruction = _get_instruction_string(instruction)

        # validation of target qubit
        qubits = _to_tuple(qubits)
        if qubits == ():
            raise PulseError(
                f"Cannot add definition {instruction} with no target qubits.")

        def _signature_from(names, return_annotation):
            # One positional-or-keyword argument per schedule parameter name,
            # in the given order.
            return inspect.Signature(
                parameters=[
                    inspect.Parameter(
                        name=argname,
                        annotation=ParameterValueType,
                        kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
                    )
                    for argname in names
                ],
                return_annotation=return_annotation,
            )

        # generate signature
        if isinstance(schedule, (Schedule, ScheduleBlock)):
            # Default argument order: parameter names sorted alphabetically.
            ordered_names = sorted({par.name for par in schedule.parameters})
            if arguments:
                if set(arguments) != set(ordered_names):
                    # Report the name sets that were actually compared, not
                    # the Parameter objects themselves.
                    raise PulseError(
                        "Arguments does not match with schedule parameters. "
                        f"{set(arguments)} != {set(ordered_names)}.")
                ordered_names = arguments
            signature = _signature_from(ordered_names, type(schedule))

        elif isinstance(schedule, ParameterizedSchedule):
            # TODO remove this
            warnings.warn(
                "ParameterizedSchedule has been deprecated. "
                "Define Schedule with Parameter objects.",
                DeprecationWarning,
                stacklevel=2,
            )
            signature = _signature_from(schedule.parameters, Schedule)

        elif callable(schedule):
            if arguments:
                warnings.warn(
                    "Arguments are overridden by the callback function signature. "
                    "Input `arguments` are ignored.",
                    UserWarning,
                    stacklevel=2,
                )
            signature = inspect.signature(schedule)

        else:
            raise PulseError(
                "Supplied schedule must be one of the Schedule, ScheduleBlock or a "
                "callable that outputs a schedule.")

        self._map[instruction][qubits] = Generator(schedule, signature)
        self._qubit_instructions[qubits].add(instruction)
示例#17
0
    def the_decorator(func):
        if Version(current_library_version) >= version:
            return func
        new_signature = inspect.signature(func)

        new_arg_names = [name for name in new_signature.parameters]

        if previous_arg_order is None:
            old_nargs = len(new_signature.parameters)
            old_arg_names = new_arg_names[:old_nargs]
        else:
            old_nargs = len(previous_arg_order)
            old_arg_names = previous_arg_order
        # These arguments are still required as argument only keywords
        func_args = inspect.getfullargspec(func).args
        new_nargs = len(func_args)

        old_parameters = []
        all_params = {**new_signature.parameters}
        for key in old_arg_names:
            param = all_params.pop(key)
            if key in func_args:
                kind = param.kind
            else:
                kind = POSITIONAL_OR_KEYWORD

            old_parameters.append(
                inspect.Parameter(key,
                                  kind=kind,
                                  default=param.default,
                                  annotation=param.annotation))
        for key, param in all_params.items():
            old_parameters.append(
                inspect.Parameter(key,
                                  kind=param.kind,
                                  default=param.default,
                                  annotation=param.annotation))

        old_signature = new_signature.replace(parameters=old_parameters)

        @wraps(func)
        def wrapper(*args, **kwargs):

            if len(args) > old_nargs:
                # The warning should be issued here too!
                raise TypeError('{name}() takes {old_nargs} positional '
                                'arguments but {len_args} were given'
                                ''.format(name=func.__name__,
                                          old_nargs=old_nargs,
                                          len_args=len(args)))

            if len(args) > new_nargs:
                for key, value in zip(old_arg_names[new_nargs:len(args)],
                                      args[new_nargs:]):
                    if key in kwargs:
                        calling_function = inspect.stack()[1]

                        s = SyntaxError(
                            "In version {version} of {library_name}, the "
                            "argument ('{key}') has "
                            "will become a keyword only argument. You "
                            "specified it as both a positional argument and "
                            "a keyword argument."
                            "".format(version=version,
                                      library_name=library_name,
                                      key=key))
                        s.lineno = calling_function.lineno
                        s.filename = calling_function.filename
                        # Even the normal syntax errors suck at telling you
                        # the position of your error whe a statement spans
                        # multiple lines
                        # s.offset =  # is it worth it to find where?
                        raise s
                    kwargs[key] = value

                warn("In version {version} of {library_name}, the "
                     "argument(s): '{old_pos_args}' will become keyword-only "
                     "argument(s). To suppress this warning, specify all "
                     "listed argument(s) with keywords."
                     "".format(
                         version=version,
                         library_name=library_name,
                         old_pos_args=old_arg_names[new_nargs:len(args)]),
                     FutureWarning,
                     stacklevel=2)

                args = args[:new_nargs]

            return func(*args, **kwargs)

        if keep_old_signature:
            wrapper.__signature__ = old_signature

        # only add a docstring if they had one already
        if wrapper.__doc__ is None:
            return wrapper

        warnings_string = """
Warns
-----
FutureWarning
  In release {version} of {module}, the argument(s):

    `{args}`

  will become keyword-only arguments. To avoid this warning,
  provide all the above arguments as keyword arguments.

""".format(version=version,
           module=library_name,
           funcname=func.__name__,
           args=', '.join(old_arg_names[new_nargs:]))

        wrapper.__doc__ = merge_docstrings(wrapper, warnings_string)
        return wrapper
示例#18
0
def replace_parameter(
    param: inspect.Parameter,
    converter: Any,
    callback: Callable[..., Any],
    original: Parameter,
    mapping: Dict[str, inspect.Parameter],
) -> inspect.Parameter:
    """Return *param* with its annotation adapted for use as an app-command parameter.

    If ``app_commands`` already supports *converter*, *param* is returned
    untouched. Otherwise known converter shapes are translated:

    - ``Range`` becomes ``app_commands.Range``.
    - ``Greedy`` becomes a greedy transformer.
    - Plain converters and callables become transformers.
    - ``Optional[converter]`` becomes ``Optional[transformer]``.

    Flag converters are expanded: each flag becomes a pseudo-parameter
    stored in *mapping*, and its description/rename is applied to
    *callback*.

    Raises:
        TypeError: if the annotation cannot be adapted. This includes
            ``Greedy[discord.Attachment]``, a flag that shadows an existing
            parameter, unsupported ``typing`` generics, non-Optional
            unions, and callables not taking exactly one required
            positional argument.
    """
    try:
        # If it's a supported annotation (i.e. a transformer) just let it pass as-is.
        app_commands.transformers.get_supported_annotation(converter)
    except TypeError:
        # Not directly supported: work out whether the annotation can be adapted.
        # Every bare ``raise`` below re-raises this TypeError.
        origin = getattr(converter, '__origin__', None)
        type_args = getattr(converter, '__args__', [])

        if isinstance(converter, Range):
            rng = converter
            range_annotation = app_commands.Range[rng.annotation, rng.min,
                                                  rng.max]  # type: ignore
            param = param.replace(annotation=range_annotation)

        elif isinstance(converter, Greedy):
            # Greedy is "optional" in ext.commands
            # However, in here, it probably makes sense to make it required.
            # I'm unsure how to allow the user to choose right now.
            greedy_inner = converter.converter
            if greedy_inner is discord.Attachment:
                raise TypeError(
                    'discord.Attachment with Greedy is not supported in hybrid commands'
                )
            greedy_transformer = make_greedy_transformer(greedy_inner, original)
            param = param.replace(annotation=greedy_transformer)

        elif is_flag(converter):
            callback.__hybrid_command_flag__ = (param.name, converter)
            flag_descriptions = {}
            flag_renames = {}
            for flag in converter.__commands_flags__.values():
                attr_name = flag.attribute
                if flag.default is MISSING:
                    flag_default = inspect.Parameter.empty
                else:
                    flag_default = flag.default
                flag_param = inspect.Parameter(
                    name=attr_name,
                    kind=param.kind,
                    default=flag_default,
                    annotation=flag.annotation,
                )
                pseudo = replace_parameter(flag_param, flag.annotation,
                                           callback, original, mapping)
                if attr_name in mapping:
                    raise TypeError(
                        f'{attr_name!r} flag would shadow a pre-existing parameter')
                if flag.description is not MISSING:
                    flag_descriptions[attr_name] = flag.description
                if flag.name != flag.attribute:
                    flag_renames[attr_name] = flag.name

                mapping[attr_name] = pseudo

            # Apply the decorators by hand, since the callback is already defined.
            if flag_descriptions:
                app_commands.describe(**flag_descriptions)(callback)
            if flag_renames:
                app_commands.rename(**flag_renames)(callback)

        elif is_converter(converter):
            converter_transformer = make_converter_transformer(converter)
            param = param.replace(annotation=converter_transformer)

        elif origin is Union:
            # Only Optional[X] (i.e. Union[X, None]) is handled; any other union is rejected.
            if len(type_args) != 2 or type_args[-1] is not _NoneType:
                raise
            optional_inner = type_args[0]
            inner_is_transformer = is_transformer(optional_inner)
            if is_converter(optional_inner) and not inner_is_transformer:
                param = param.replace(annotation=Optional[
                    make_converter_transformer(optional_inner)])  # type: ignore

        elif origin:
            # Unsupported typing.X annotation e.g. typing.Dict, typing.Tuple, typing.List, etc.
            raise

        elif callable(converter) and not inspect.isclass(converter):
            if required_pos_arguments(converter) != 1:
                raise
            callable_transformer = make_callable_transformer(converter)
            param = param.replace(annotation=callable_transformer)

    return param
示例#19
0
文件: loa.py 项目: Mercy1/Cog-testing
def RaiseMissingArguement():
    """Raise ``commands.MissingRequiredArgument`` for a positional-only ``startdate`` parameter."""
    missing_param = inspect.Parameter("startdate",
                                      inspect.Parameter.POSITIONAL_ONLY)
    raise commands.MissingRequiredArgument(missing_param)
示例#20
0
    def decorator(obj):
        """Outer wrapper.

        The outer wrapper is used to create the decorating wrapper.
        The wrapper renames or drops deprecated keyword arguments listed in
        ``arg_pairs`` before calling *obj*. When ``__debug__`` is false, it
        returns *obj* unchanged.

        @param obj: function being wrapped
        @type obj: object
        @return: the wrapping function, or obj itself when not __debug__
        """
        def wrapper(*__args, **__kw):
            """Replacement function.

            For each (old_arg, new_arg) in arg_pairs whose old_arg was passed:
            if new_arg is a name, the value is moved to new_arg with a
            warning (or only a warning if both were passed); if new_arg is
            True or None a debug message is logged; if new_arg is False the
            argument is dropped silently. old_arg is always removed.

            @param __args: args passed to the decorated function
            @type __args: list
            @param __kw: kwargs passed to the decorated function
            @type __kw: dict
            @return: the value returned by the decorated function
            @rtype: any
            """
            # NOTE: the messages below are formatted with ``% locals()``, so
            # the local names ``name``, ``old_arg`` and ``new_arg`` are part of
            # the message templates and must not be renamed.
            name = obj.__full_name__
            for old_arg, new_arg in arg_pairs.items():
                if old_arg in __kw:
                    if new_arg not in [True, False, None]:
                        if new_arg in __kw:
                            # Both spellings given: keep the new one; the old
                            # one is discarded by the ``del`` below.
                            warning(u"%(new_arg)s argument of %(name)s "
                                    "replaces %(old_arg)s; cannot use both." %
                                    locals())
                        else:
                            # If the value is positionally given this will
                            # cause a TypeError, which is intentional
                            warning(u"%(old_arg)s argument of %(name)s "
                                    "is deprecated; use %(new_arg)s instead." %
                                    locals())
                            __kw[new_arg] = __kw[old_arg]
                    elif new_arg is not False:
                        # True/None: the argument is obsolete with no
                        # replacement; only log at debug level.
                        debug(
                            u"%(old_arg)s argument of %(name)s is "
                            "deprecated." % locals(), _logger)
                    del __kw[old_arg]
            return obj(*__args, **__kw)

        # With ``python -O`` the deprecation handling is skipped entirely.
        if not __debug__:
            return obj

        wrapper.__doc__ = obj.__doc__
        wrapper.__name__ = obj.__name__
        wrapper.__module__ = obj.__module__
        # ``signature`` comes from the enclosing scope. The truthiness check
        # below suggests it may return None in some environments (TODO confirm).
        wrapper.__signature__ = signature(obj)
        if wrapper.__signature__:
            # Build a new signature with deprecated args added.
            params = collections.OrderedDict()
            for param in wrapper.__signature__.parameters.values():
                params[param.name] = param.replace()
            for old_arg, new_arg in arg_pairs.items():
                params[old_arg] = inspect.Parameter(
                    old_arg,
                    kind=inspect._POSITIONAL_OR_KEYWORD,
                    default='[deprecated name of ' + new_arg + ']'
                    if new_arg not in [True, False, None] else NotImplemented)
            # NOTE(review): assigning the private ``_parameters`` bypasses the
            # parameter-order validation inspect.Signature(parameters=...)
            # performs. This is presumably deliberate, because the appended
            # POSITIONAL_OR_KEYWORD params may follow keyword-only/var params.
            wrapper.__signature__ = inspect.Signature()
            wrapper.__signature__._parameters = params
        if not hasattr(obj, '__full_name__'):
            add_decorated_full_name(obj)
        wrapper.__full_name__ = obj.__full_name__
        return wrapper
示例#21
0
    output = np.ones(output_shape) * np.nan
    return output


# Standard docstring / signature entries for `axis` and `nan_policy`
# Each option is described twice:
# - as a docstring record ``Parameter(name, type, desc)``, where ``Parameter``
#   is imported elsewhere (presumably numpydoc-style; confirm);
# - as an ``inspect.Parameter`` for signatures.
# ``_name``/``_type``/``_desc`` are scratch globals that are reassigned for
# each option.
_name = 'axis'
_type = "int or None, default: 0"
# The description is stored as a list of lines (split on '\n').
_desc = (
    """If an int, the axis of the input along which to compute the statistic.
The statistic of each axis-slice (e.g. row) of the input will appear in a
corresponding element of the output.
If ``None``, the input will be raveled before computing the statistic.""".
    split('\n'))
_axis_parameter_doc = Parameter(_name, _type, _desc)
# Keyword-only ``axis=0`` for use in generated signatures.
_axis_parameter = inspect.Parameter(_name,
                                    inspect.Parameter.KEYWORD_ONLY,
                                    default=0)

_name = 'nan_policy'
_type = "{'propagate', 'omit', 'raise'}"
# The description is stored as a list of lines (split on '\n').
_desc = ("""Defines how to handle input NaNs.

- ``propagate``: if a NaN is present in the axis slice (e.g. row) along
  which the  statistic is computed, the corresponding entry of the output
  will be NaN.
- ``omit``: NaNs will be omitted when performing the calculation.
  If insufficient data remains in the axis slice along which the
  statistic is computed, the corresponding entry of the output will be
  NaN.
- ``raise``: if a NaN is present, a ``ValueError`` will be raised.""".split(
    '\n'))
            else:
                raise TypeError

        return deviation or 0, expected or 0

    def call_predicate(self, item):
        """Accept *item* when its deviation lies within ``[self.lower, self.upper]``.

        *item* is a two-element ``(key, diff)`` pair; the key is ignored.
        If ``self._get_deviation_expected`` raises ``TypeError`` for the
        diff, the item is rejected.
        """
        _key, diff = item
        try:
            deviation, _expected = self._get_deviation_expected(diff)
        except TypeError:
            return False  # <- EXIT! (no deviation can be computed)
        else:
            return self.lower <= deviation <= self.upper

with contextlib.suppress(AttributeError):  # inspect.Signature() is new in 3.3
    # Advertise ``(self, tolerance, /, msg=None)`` as the signature of
    # AcceptedTolerance.__init__ for introspection and help().
    AcceptedTolerance.__init__.__signature__ = inspect.Signature(
        parameters=[
            inspect.Parameter('self', kind=inspect.Parameter.POSITIONAL_ONLY),
            inspect.Parameter('tolerance',
                              kind=inspect.Parameter.POSITIONAL_ONLY),
            inspect.Parameter('msg',
                              kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
                              default=None),
        ]
    )


class AcceptedPercent(AcceptedTolerance):
    """AcceptedPercent(tolerance, /, msg=None)
    AcceptedPercent(lower, upper, msg=None)

    Context manager that accepts Deviations within a given percent
    tolerance without triggering a test failure.

    See documentation for full details.
    """
    def call_predicate(self, item):
示例#23
0
    def decorator(func):
        """Turn *func* into a figure-comparison test parametrized over ``extensions``.

        *func* must accept ``fig_test`` and ``fig_ref``. The returned wrapper
        does the following:

        - creates both figures;
        - calls *func* to draw on them;
        - fails if *func* left extra figures open;
        - saves both figures with the parametrized extension;
        - compares the images via ``_raise_on_image_difference`` with
          ``tol``.

        Raises
        ------
        ValueError
            If *func* does not have both ``fig_test`` and ``fig_ref``
            parameters.
        """
        import pytest

        _, result_dir = _image_directories(func)
        old_sig = inspect.signature(func)

        if not {"fig_test", "fig_ref"}.issubset(old_sig.parameters):
            raise ValueError("The decorated function must have at least the "
                             "parameters 'fig_ref' and 'fig_test', but your "
                             f"function has the signature {old_sig}")

        @pytest.mark.parametrize("ext", extensions)
        def wrapper(*args, ext, request, **kwargs):
            # Forward ``ext``/``request`` only if *func* itself asks for them.
            if 'ext' in old_sig.parameters:
                kwargs['ext'] = ext
            if 'request' in old_sig.parameters:
                kwargs['request'] = request

            file_name = "".join(c for c in request.node.name
                                if c in ALLOWED_CHARS)
            # Pre-bind both figures so the ``finally`` clause is safe when
            # figure creation itself fails. Otherwise that failure would be
            # masked by an UnboundLocalError. The None checks below are
            # needed because ``plt.close(None)`` would close the *current*
            # figure instead of doing nothing.
            fig_test = fig_ref = None
            try:
                fig_test = plt.figure("test")
                fig_ref = plt.figure("reference")
                # Keep track of number of open figures, to make sure test
                # doesn't create any new ones
                n_figs = len(plt.get_fignums())
                func(*args, fig_test=fig_test, fig_ref=fig_ref, **kwargs)
                if len(plt.get_fignums()) > n_figs:
                    raise RuntimeError('Number of open figures changed during '
                                       'test. Make sure you are plotting to '
                                       'fig_test or fig_ref, or if this is '
                                       'deliberate explicitly close the '
                                       'new figure(s) inside the test.')
                test_image_path = result_dir / (file_name + "." + ext)
                ref_image_path = result_dir / (file_name + "-expected." + ext)
                fig_test.savefig(test_image_path)
                fig_ref.savefig(ref_image_path)
                _raise_on_image_difference(ref_image_path,
                                           test_image_path,
                                           tol=tol)
            finally:
                if fig_test is not None:
                    plt.close(fig_test)
                if fig_ref is not None:
                    plt.close(fig_ref)

        # Expose *func*'s signature minus the figures the wrapper supplies,
        # plus the ``ext``/``request`` arguments pytest must provide.
        parameters = [
            param for param in old_sig.parameters.values()
            if param.name not in {"fig_test", "fig_ref"}
        ]
        if 'ext' not in old_sig.parameters:
            parameters += [inspect.Parameter("ext", KEYWORD_ONLY)]
        if 'request' not in old_sig.parameters:
            parameters += [inspect.Parameter("request", KEYWORD_ONLY)]
        new_sig = old_sig.replace(parameters=parameters)
        wrapper.__signature__ = new_sig

        # reach a bit into pytest internals to hoist the marks from
        # our wrapped function
        new_marks = getattr(func, "pytestmark", []) + wrapper.pytestmark
        wrapper.pytestmark = new_marks

        return wrapper
示例#24
0
    def decomposition_decorator(f: Callable) -> Callable:
        """Register *f* (possibly wrapped) as the decomposition of ``aten_op``.

        The registered callable is stored in ``registry`` for every overload
        of every op in ``aten_op``. When the conditions checked below hold,
        it is also registered on ``meta_lib``. The registered callable is
        then returned.

        Raises:
            RuntimeError: if an overload already has an entry in ``registry``.
        """
        sig = inspect.signature(f)
        out_annotation = f.__annotations__.get("out")
        # Hack to detect when out is a Tuple. There seems to be no pretty way of doing this
        fn = f
        if out_annotation and getattr(out_annotation, "__origin__", None) is tuple:
            # Output names come from the return annotation's ``_fields``, so
            # it is presumably a NamedTuple type — confirm at call sites.
            out_names = sig.return_annotation._fields
            # If out is a tuple, we need to register a function that unpacks all the out
            # elements as this is what native_functions.yaml expects

            @wraps(f)
            def _fn(*args, **kwargs):
                # Collect the per-output keyword args and pass them to f as a
                # single ``out`` tuple (or ``out=None`` if none were given).
                out_kwargs = tuple(kwargs.pop(o, None) for o in out_names)
                # Either all of the out kwargs are set or none of them
                is_none = out_kwargs[0] is None
                # NOTE(review): validated with ``assert``, which is stripped
                # under ``python -O``.
                assert all((o is None) == is_none for o in out_kwargs)
                return f(*args, **kwargs, out=None if is_none else out_kwargs)

            # One keyword-only, default-None parameter per output, annotated
            # with the matching element type of the ``out`` Tuple annotation.
            out_params = [
                inspect.Parameter(
                    o,
                    kind=inspect.Parameter.KEYWORD_ONLY,
                    default=None,
                    annotation=t,
                )
                for o, t in zip(out_names, out_annotation.__args__)
            ]
            # Drop the out parameter and concatenate the new kwargs in the signature
            params = chain(
                (v for k, v in sig.parameters.items() if k != "out"), out_params
            )
            _fn.__signature__ = inspect.Signature(  # type: ignore[attr-defined]
                parameters=params, return_annotation=sig.return_annotation  # type: ignore[arg-type]
            )
            # Drop the out parameter and concatenate the new kwargs in the annotations
            _fn.__annotations__ = {
                k: v for k, v in f.__annotations__.items() if k != "out"
            }
            for o in out_params:
                _fn.__annotations__[o.name] = o.annotation

            fn = _fn

        # Rebinds the enclosing scope's ``registry``, so the default chosen
        # here is also visible to the enclosing function afterwards.
        nonlocal registry
        if registry is None:
            registry = decomposition_table

        def add_op_to_table(aten_op):
            # Expand an OpOverloadPacket into all of its overloads; a single
            # OpOverload is registered as-is.
            overloads = []
            if isinstance(aten_op, torch._ops.OpOverload):
                overloads.append(aten_op)
            else:
                assert isinstance(aten_op, torch._ops.OpOverloadPacket)
                for ol in aten_op.overloads():
                    overloads.append(getattr(aten_op, ol))
            for op_overload in overloads:
                if op_overload in registry:
                    raise RuntimeError(f"duplicate registrations for {op_overload}")
                registry[op_overload] = fn
                # TODO: factor this logic into OpOverload or Library API
                # Dispatcher name is "<schema name>" or "<schema name>.<overload>".
                name = op_overload._schema.name
                if op_overload._schema.overload_name:
                    name += "." + op_overload._schema.overload_name
                if (
                    not disable_meta
                    # TorchScript dumps a bunch of extra nonsense overloads
                    # which don't have corresponding dispatcher entries, we need
                    # to filter those out
                    and torch._C._dispatch_has_kernel(name)
                    # Don't register a meta kernel to any operator that has
                    # a CompositeImplicitAutograd kernel in core.
                    # Otherwise we won't be able to run autograd for that operator with the meta backend.
                    and "CompositeImplicitAutograd" not in torch._C._dispatch_dump(name)
                    and not torch._C._dispatch_has_kernel_for_dispatch_key(name, "Meta")
                ):
                    meta_lib.impl(op_overload, fn)

        # To handle allowing multiple aten_ops at once
        tree_map(add_op_to_table, aten_op)
        return fn