Ejemplo n.º 1
0
        def _fn(*args, **kwargs):
            """Call ``fn`` with its type-promoting arguments cast to one dtype.

            The call is bound against ``sig``. Every name listed in
            ``self.type_promoting_arg_names`` that is present in the bound
            arguments takes part in ``utils.elementwise_dtypes`` and is then
            passed through ``_maybe_convert_to_dtype`` with the computation
            dtype. ``fn`` receives all bound arguments as keywords, and its
            result is passed through ``_maybe_convert_to_dtype`` with the
            result dtype.

            Raises:
                AssertionError: if ``fn`` returns something that is not a
                    ``TensorLike``.
            """
            bound = sig.bind(*args, **kwargs)

            # ``bind`` does not fill in defaults, so only names the caller
            # actually supplied are considered for promotion.
            supplied_names = [
                name
                for name in self.type_promoting_arg_names  # type: ignore[union-attr]
                if name in bound.arguments
            ]

            # Promotable values may be nested containers; flatten them so
            # each leaf is handed to the dtype computation individually.
            promoting_values = tuple(bound.arguments[name] for name in supplied_names)
            flat_promoting_values = tree_flatten(promoting_values)[0]
            compute_dtype, result_dtype = utils.elementwise_dtypes(
                *flat_promoting_values,
                type_promotion_kind=self.type_promotion_kind,
            )

            # Swap each supplied promotable argument for its converted form.
            for name in supplied_names:
                bound.arguments[name] = _maybe_convert_to_dtype(
                    bound.arguments[name], compute_dtype
                )

            result = fn(**bound.arguments)

            # FIXME?: assumes result is a single tensor
            assert isinstance(result, TensorLike)
            return _maybe_convert_to_dtype(result, result_dtype)
Ejemplo n.º 2
0
    def inner(*args, **kwargs):
        """Call ``f`` on inputs cast to the computation dtype, then cast back.

        Only ``Tensor`` leaves of ``(args, kwargs)`` take part in
        ``utils.elementwise_dtypes`` (with ``type_promotion`` as the kind).
        Every ``Tensor`` leaf of ``args`` and ``kwargs`` is converted with
        ``.to(computation_dtype)`` before calling ``f``, and every ``Tensor``
        leaf of the result is converted with ``.to(result_dtype)``. Leaves
        that are not ``Tensor`` instances are passed through untouched.
        """
        tensor_leaves = []
        for leaf in tree_flatten((args, kwargs))[0]:
            if isinstance(leaf, Tensor):
                tensor_leaves.append(leaf)

        computation_dtype, result_dtype = utils.elementwise_dtypes(
            *tensor_leaves, type_promotion_kind=type_promotion
        )

        # TODO: pretty sure this is not quite right
        def cast_tensors_to(dtype):
            # Build a per-leaf mapper for ``tree_map`` targeting ``dtype``.
            def cast(leaf):
                return leaf.to(dtype) if isinstance(leaf, Tensor) else leaf

            return cast

        to_computation = cast_tensors_to(computation_dtype)
        promoted_args = tree_map(to_computation, args)
        promoted_kwargs = tree_map(to_computation, kwargs)

        raw_result = f(*promoted_args, **promoted_kwargs)
        return tree_map(cast_tensors_to(result_dtype), raw_result)
Ejemplo n.º 3
0
def meta_angle(self):
    """Return an uninitialized tensor with the shape and dtype of ``angle(self)``.

    The output has the same size as ``self``. Its dtype is derived with
    ``elementwise_dtypes``:

    * complex inputs use ``COMPLEX_TO_FLOAT`` so the result is the matching
      real dtype, because the angle of a complex number is real;
    * all other inputs use ``INT_TO_FLOAT``, so integer and boolean inputs
      produce a floating-point result while floating inputs keep their dtype.

    Args:
        self: the input tensor.

    Returns:
        A tensor created with ``self.new_empty`` of size ``self.size()`` and
        the promoted result dtype.
    """
    # INT_TO_FLOAT alone would leave a complex input complex, which does not
    # match the real-valued result of angle(); pick the kind per input.
    if self.is_complex():
        promotion_kind = ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT
    else:
        promotion_kind = ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT

    _, result_dtype = elementwise_dtypes(self, type_promotion_kind=promotion_kind)
    return self.new_empty(self.size(), dtype=result_dtype)