Esempio n. 1
0
            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                """Check that SchemaInfo aliasing analysis uses argument values.

                Two ``torch._C._SchemaInfo`` objects are built from ``func``'s
                schema. Before any values are supplied, neither may report that
                inputs 0 and 1 may alias. One is then fed the normalized
                arguments one at a time (``add_argument_value``), the other all
                at once (``add_argument_values``). Afterwards both must report
                that inputs 0 and 1 may alias. Finally ``func`` runs unchanged
                and its result is returned.
                """
                # ``kwargs`` defaults to None, but ``func(*args, **kwargs)``
                # below needs a mapping.
                if kwargs is None:
                    kwargs = {}

                named_arg_list = normalize_function(
                    func,
                    args,
                    kwargs,
                    normalize_to_only_use_kwargs=True
                ).kwargs
                schema_info_value_test = torch._C._SchemaInfo(func._schema)
                schema_info_values_test = torch._C._SchemaInfo(func._schema)

                def inputs_0_and_1_may_alias(schema_info):
                    return schema_info.may_alias(
                        torch._C._SchemaArgument(torch._C._SchemaArgType.input, 0),
                        torch._C._SchemaArgument(torch._C._SchemaArgType.input, 1))

                self.test_self.assertFalse(inputs_0_and_1_may_alias(schema_info_value_test))
                self.test_self.assertFalse(inputs_0_and_1_may_alias(schema_info_values_test))
                for arg_name, arg_value in named_arg_list.items():
                    schema_info_value_test.add_argument_value(arg_name, arg_value)
                schema_info_values_test.add_argument_values(named_arg_list)
                self.test_self.assertTrue(inputs_0_and_1_may_alias(schema_info_value_test))
                self.test_self.assertTrue(inputs_0_and_1_may_alias(schema_info_values_test))

                return func(*args, **kwargs)
Esempio n. 2
0
    def normalized_arguments(
            self, root : torch.nn.Module, arg_types : Optional[Tuple[Any]] = None,
            kwarg_types : Optional[Dict[str, Any]] = None,
            normalize_to_only_use_kwargs : bool = False) -> Optional[ArgsKwargsPair]:
        """
        Returns normalized arguments to Python targets. This means that
        `args/kwargs` will be matched up to the module/functional's
        signature and return exclusively kwargs in positional order
        if `normalize_to_only_use_kwargs` is true.
        Also populates default values. Does not support positional-only
        parameters or varargs parameters.

        Supports module calls.

        May require `arg_types` and `kwarg_types` in order to disambiguate overloads.

        Args:
            root (torch.nn.Module): Module upon which to resolve module targets.
            arg_types (Optional[Tuple[Any]]): Tuple of arg types for the args
            kwarg_types (Optional[Dict[str, Any]]): Dict of arg types for the kwargs
            normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs.
                Forwarded to both the function and the module normalizer.

        Returns:

            Returns NamedTuple ArgsKwargsPair, or `None` if not successful
            or if `self.op` is neither 'call_function' nor 'call_module'.
        """
        if self.op == 'call_function':
            assert callable(self.target)
            # The flag must be forwarded; previously it was accepted but ignored.
            return normalize_function(self.target, self.args, self.kwargs, arg_types, kwarg_types,  # type: ignore[arg-type]
                                      normalize_to_only_use_kwargs=normalize_to_only_use_kwargs)
        elif self.op == 'call_module':
            assert isinstance(self.target, str)
            return normalize_module(root, self.target, self.args, self.kwargs,  # type: ignore[arg-type]
                                    normalize_to_only_use_kwargs=normalize_to_only_use_kwargs)

        # Any other node op is not normalized.
        return None
Esempio n. 3
0
def index_put_(fake_mode, func, *args, **kwargs):
    """Fake-tensor handler that runs ``func`` and returns its ``input`` argument.

    ``func`` runs inside ``in_kernel_invocation_manager(fake_mode)`` and its
    return value is discarded. The value returned is the ``input`` entry of
    the normalized keyword arguments, i.e. the tensor the caller passed in.
    """
    with in_kernel_invocation_manager(fake_mode):
        func(*args, **kwargs)

    _unused_args, normalized = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    return normalized["input"]
Esempio n. 4
0
def non_kwarg_to(fake_mode, func, *args, **kwargs):
    """Fake-tensor handler for ops whose arguments include ``input`` and ``device``.

    The op is re-run with ``device`` replaced by the meta device. The result
    is converted with ``fake_mode.fake_tensor_converter``, using the requested
    device, or the input's device when the requested one is falsy.
    """
    _, new_kwargs = normalize_function(
        func, args, kwargs, normalize_to_only_use_kwargs=True
    )
    input_device = new_kwargs["device"]
    out_device = input_device if input_device else new_kwargs["input"].device
    new_kwargs["device"] = torch.device("meta")
    # ``new_kwargs`` already holds every argument, with the input under the
    # normalized name "input". Passing ``*args`` as well would supply the input
    # and device twice. So pass the input positionally and the rest by keyword.
    inp = new_kwargs.pop("input")
    r = func(inp, **new_kwargs)
    return fake_mode.fake_tensor_converter(fake_mode, r, out_device)
Esempio n. 5
0
def run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs):
    """Run ``func`` and wrap its result as a FakeTensor on the input's device.

    The device is read from the normalized ``input`` argument before the op
    runs. The op itself runs inside ``in_kernel_invocation_manager(fake_mode)``.
    """
    _ignored, normalized = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    device_of_input = normalized["input"].device

    with in_kernel_invocation_manager(fake_mode):
        result = func(*args, **kwargs)
    return FakeTensor(fake_mode, result, device_of_input)
Esempio n. 6
0
def to_copy(fake_mode, func, *args, **kwargs):
    """Fake-tensor handler for ``_to_copy``.

    The output device is the requested ``device`` when it is truthy, otherwise
    the input's device. The copy runs under ``no_dispatch()`` on a meta copy
    of the input, and the result is wrapped as a FakeTensor on that device.
    """
    _ignored, normalized = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    requested_device = normalized.pop("device", None)
    if requested_device:
        out_device = requested_device
    else:
        out_device = normalized["input"].device
    with no_dispatch():
        meta_input = normalized.pop("input").to("meta")
        copied = aten._to_copy(meta_input, **normalized)
        return FakeTensor(fake_mode, copied, out_device)
Esempio n. 7
0
 def call_function(
         self, target : Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any],
         arg_types: Optional[Tuple[Any, ...]] = None, kwarg_types : Optional[Dict[str, Any]] = None):
     """Emit a ``call_function`` proxy with normalized arguments when possible.

     Normalization uses ``self.normalize_to_only_use_kwargs``. If
     ``normalize_function`` returns a falsy result, the call falls back to
     ``super().call_function`` with the original arguments.
     """
     assert callable(target)
     normalized = normalize_function(target, args, kwargs, arg_types, kwarg_types,  # type: ignore[arg-type]
                                     self.normalize_to_only_use_kwargs)
     if not normalized:
         return super().call_function(target, args, kwargs)
     norm_args, norm_kwargs = normalized
     return self.tracer.create_proxy('call_function', target, norm_args, norm_kwargs)
Esempio n. 8
0
def index_tensor(fake_mode, func, *args, **kwargs):
    """Fake-tensor handler for tensor indexing.

    Rejects bool/uint8 index tensors first, because they make the output shape
    data-dependent. Then runs the op inside
    ``in_kernel_invocation_manager(fake_mode)`` and wraps the result as a
    FakeTensor on the input's device.
    """
    # dynamic shape op if indices are bool/uint8
    check_no_bool_index_tensors(func, *args, **kwargs)

    _ignored, normalized = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    device_of_input = normalized["input"].device

    with in_kernel_invocation_manager(fake_mode):
        result = func(*args, **kwargs)
    return FakeTensor(fake_mode, result, device_of_input)
Esempio n. 9
0
 def call_function(self,
                   target: Target,
                   args: Tuple[Argument, ...],
                   kwargs: Dict[str, Any],
                   arg_types: Optional[Tuple[Any, ...]] = None,
                   kwarg_types: Optional[Dict[str, Any]] = None):
     """Emit a ``call_function`` proxy whose arguments are all keywords.

     The result of ``normalize_function`` (presumably a kwargs mapping in the
     version this targets; verify) is passed as the proxy's kwargs with an
     empty positional tuple. A falsy result falls back to
     ``super().call_function`` with the original arguments.
     """
     assert callable(target)
     normalized_kwargs = normalize_function(target, args, kwargs, arg_types,
                                            kwarg_types)  # type: ignore
     if not normalized_kwargs:
         return super().call_function(target, args, kwargs)
     return self.tracer.create_proxy('call_function', target, (),
                                     normalized_kwargs)
Esempio n. 10
0
def contructors(fake_mode, func, *args, **kwargs):
    """Fake-tensor handler for tensor constructors.

    For ``*_like`` constructors (members of ``_like_tensor_constructors``), the
    default device is the input's device and the input is passed positionally.
    Otherwise the default device is cpu and no positional arguments are
    passed. An explicit non-None ``device`` overrides the default. The
    constructor itself runs on the meta device, and the result is wrapped as a
    FakeTensor on the chosen device.
    """
    assert func not in _non_kwarg_device_constructors
    _ignored, normalized = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    if func in _like_tensor_constructors:
        default_device = normalized["input"].device
        # TODO: file issue
        positional = (normalized.pop("input"),)
    else:
        # cpu is default device if none is specified
        default_device = torch.device("cpu")
        positional = ()
    requested_device = normalized.pop("device", None)
    out_device = default_device if requested_device is None else requested_device
    normalized["device"] = torch.device("meta")
    meta_result = func(*positional, **normalized)
    return FakeTensor(fake_mode, meta_result, out_device)
Esempio n. 11
0
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        """Run ``func`` and check observed mutation/aliasing against its schema.

        The op's schema name is appended to ``self.ops``. Arguments are
        normalized to keyword form and cloned before ``func`` runs. Afterwards
        each argument is compared with its clone and checked for aliasing with
        the outputs. Arguments that alias each other share their schema alias
        information, so aliasing arguments may cause one another to be mutated
        or aliased.

        Raises:
            RuntimeError: if an argument was mutated although neither it nor
                any argument aliasing it is marked as written in the schema; or
                if an argument aliases output ``j`` but output ``j``'s alias set
                shares nothing with that argument (or an argument aliasing it).
        """
        # ``kwargs`` defaults to None, but ``func(*args, **kwargs)`` below
        # needs a mapping.
        if kwargs is None:
            kwargs = {}

        def has_mutated(before, after):
            return not torch.equal(before, after) if isinstance(
                before, torch.Tensor) and isinstance(after,
                                                     torch.Tensor) else False

        def has_aliased(lhs, rhs):
            return torch._C._is_alias_of(lhs, rhs) if isinstance(
                lhs, torch.Tensor) and isinstance(rhs, torch.Tensor) else False

        def is_mutable(arg_alias_pairs):
            # True if any of the given schema arguments is marked as written.
            for arg in arg_alias_pairs:
                if arg.alias_info is not None and arg.alias_info.is_write:
                    return True
            return False

        def is_aliasing(output_alias_info, arg_alias_pairs):
            # True if the output's alias set intersects the alias set of any of
            # the given schema arguments.
            if output_alias_info is None:
                return False
            for arg in arg_alias_pairs:
                if arg.alias_info is not None and bool(
                        len(output_alias_info.after_set
                            & arg.alias_info.after_set)):
                    return True
            return False

        def are_args_aliasing(lhs, rhs):
            for lhs_value in lhs:
                for rhs_value in rhs:
                    if (has_aliased(lhs_value, rhs_value)):
                        return True
            return False

        def standardize_name(name):
            # The normalized kwargs name the schema argument "self" as "input".
            return name if name != "self" else "input"

        def unwrap(e):
            # Tensor subclasses exposing ``.elem`` are unwrapped; everything
            # else is returned unchanged.
            if isinstance(e, torch.Tensor) and not type(e) == torch.Tensor:
                try:
                    return e.elem
                except AttributeError:
                    return e
            else:
                return e

        self.ops.append(func._schema.name)
        arguments = normalize_function(
            func, args, kwargs, normalize_to_only_use_kwargs=True).kwargs

        # Snapshot argument values so mutation can be detected after the call.
        cloned_arguments = dict(
            zip(arguments.keys(), clone_inputs(arguments.values())))
        out = func(*args, **kwargs)

        # Construct an aliasing map between op arguments for verifying aliasing pairs
        # between op arguments and op outputs. This is used to allow cases where two aliasing arguments
        # cause a non-mutable/non-aliasing argument to mutate or alias.
        arg_alias_pairs_map = {
            arg.name: [arg]
            for arg in func._schema.arguments
        }
        for i_arg, j_arg in combinations(func._schema.arguments, 2):
            i_values = tree_map(
                unwrap,
                tree_flatten(arguments.get(standardize_name(i_arg.name)))[0])
            j_values = tree_map(
                unwrap,
                tree_flatten(arguments.get(standardize_name(j_arg.name)))[0])
            if are_args_aliasing(i_values, j_values):
                arg_alias_pairs_map[i_arg.name].append(j_arg)
                arg_alias_pairs_map[j_arg.name].append(i_arg)

        # Unwrapped outputs as a tuple, indexed in step with
        # ``func._schema.returns``; loop-invariant, so computed once.
        u_out = tree_map(unwrap, out)
        u_out = u_out if isinstance(u_out, tuple) else (u_out, )

        for arg in func._schema.arguments:
            name = standardize_name(arg.name)
            if arguments.get(name) is not None:
                arg_alias_pairs = arg_alias_pairs_map[arg.name]
                before = tree_flatten(cloned_arguments.get(name))[0]
                after = tree_flatten(arguments.get(name))[0]
                u_values = tree_map(unwrap, after)
                if any(has_mutated(i, j) for i, j in zip(before, after)
                       ) and not is_mutable(arg_alias_pairs):
                    raise RuntimeError(
                        f"Argument {name} is not defined as mutable but was mutated"
                    )
                for v in u_values:
                    for j in range(len(u_out)):
                        if has_aliased(v, u_out[j]) and not is_aliasing(
                                func._schema.returns[j].alias_info,
                                arg_alias_pairs):
                            raise RuntimeError(
                                f'Argument {name} is not defined to alias output but was aliasing'
                            )

        return out
Esempio n. 12
0
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        """Run ``func`` and verify observed aliasing/mutation against its schema.

        The op's schema name is appended to ``self.ops``. Arguments are
        normalized to keyword form, and their values, strides and storage ids
        are snapshotted before ``func`` runs. A ``SchemaInfo`` populated with
        the argument values then decides which observed behaviours are allowed.
        Allowed argument/output aliasing is recorded in ``self.aliasing``, and
        allowed mutations in ``self.mutated``.

        Raises:
            RuntimeError: if an argument aliases an output, an argument is
                mutated, or two outputs alias each other, without the schema
                permitting it.
        """
        # ``kwargs`` defaults to None, but ``func(*args, **kwargs)`` below
        # needs a mapping.
        if kwargs is None:
            kwargs = {}

        def has_mutated(before, after, md):
            # A plain tensor counts as mutated if its values, strides or
            # storage changed; ``md`` is the (stride, storage id) snapshot.
            if type(before) == torch.Tensor and type(after) == torch.Tensor:
                return not (torch.equal(before, after)
                            and md[0] == after.stride()
                            and md[1] == after.storage()._cdata)
            return False

        def has_aliased(lhs, rhs):
            try:
                return torch._C._overlaps(lhs, rhs)
            except Exception as exception:
                # Values _overlaps cannot inspect (presumably non-tensors) are
                # treated as non-aliasing; any other error propagates.
                if str(exception).startswith("Cannot inspect value of type "):
                    return False
                else:
                    raise

        def standardize_name(name):
            # The normalized kwargs name the schema argument "self" as "input".
            return name if name != "self" else "input"

        def unwrap(e):
            # Tensor subclasses exposing ``.elem`` are unwrapped; everything
            # else is returned unchanged.
            if isinstance(e, torch.Tensor) and not type(e) == torch.Tensor:
                try:
                    return e.elem
                except AttributeError:
                    return e
            return e

        def parse_metadata(e):
            # (stride, storage id) of the (unwrapped) tensor, or None.
            if isinstance(e, torch.Tensor):
                if not type(e) == torch.Tensor:
                    try:
                        current = e.elem
                        return (deepcopy(current.stride()),
                                current.storage()._cdata)
                    except AttributeError:
                        return None
                else:
                    return (deepcopy(e.stride()), e.storage()._cdata)
            return None

        self.ops.append(func._schema.name)

        # Clone and process arguments and outputs
        pre_arguments = normalize_function(
            func, args, kwargs, normalize_to_only_use_kwargs=True).kwargs

        c_p_args = dict(
            zip(pre_arguments.keys(), clone_inputs(pre_arguments.values())))
        cloned_arguments = {
            name: tree_map(unwrap, c_p_args.get(name))
            for name in c_p_args
        }
        cloned_metadata = {
            name: tree_map(parse_metadata,
                           tree_flatten(pre_arguments.get(name))[0])
            for name in pre_arguments
        }

        out = func(*args, **kwargs)
        arguments = {
            name: tree_map(unwrap, pre_arguments.get(name))
            for name in pre_arguments
        }
        tuple_out = out if isinstance(out, tuple) else (out, )
        tuple_out = tree_map(unwrap, tuple_out)

        schema_info = SchemaInfo(func._schema)
        schema_info.add_argument_values(pre_arguments)

        # Process arguments with outputs
        for i, arg in enumerate(func._schema.arguments):
            name = standardize_name(arg.name)
            if arguments.get(name) is not None:
                before = cloned_arguments.get(name)
                md = cloned_metadata.get(name)
                after = arguments.get(name)
                for j in range(len(tuple_out)):
                    if has_aliased(tuple_out[j], after):
                        if not schema_info.may_contain_alias(
                                SchemaArgument(SchemaArgType.output, j),
                                SchemaArgument(SchemaArgType.input, i)):
                            raise RuntimeError(
                                f'Argument {name} is not defined to alias output but was aliasing'
                            )
                        else:
                            self.aliasing.append(
                                Aliasing(func._schema.name, name,
                                         f"output_{j}"))
                if any(
                        has_mutated(a, b, c) for a, b, c in zip(
                            tree_flatten(before)[0],
                            tree_flatten(after)[0], md)):
                    if not schema_info.is_mutable(
                            SchemaArgument(SchemaArgType.input, i)):
                        raise RuntimeError(
                            f"Argument {name} is not defined as mutable but was mutated"
                        )
                    else:
                        self.mutated.append(Mutation(func._schema.name, name))

        # Aliasing between outputs
        for i, j in combinations(range(len(func._schema.returns)), 2):
            if has_aliased(tuple_out[i], tuple_out[j]):
                if not schema_info.may_contain_alias(
                        SchemaArgument(SchemaArgType.output, i),
                        SchemaArgument(SchemaArgType.output, j)):
                    raise RuntimeError(
                        f'Outputs {i} and {j} alias unexpectedly')

        return out
Esempio n. 13
0
def torch_dispatch_impl(cls_or_mode_instance, func, types, args, kwargs,
                        run_function):
    """Shared dispatch logic that produces FakeTensor results for ``func``.

    Args:
        cls_or_mode_instance: a ``FakeTensorMode`` instance, whose converter is
            reused; otherwise a fresh ``FakeTensorConverter`` is created.
            NOTE(review): presumably the ``FakeTensor`` class in the non-mode
            case; confirm with callers.
        func: the operator overload being dispatched.
        types: forwarded unchanged to ``run_function``.
        args, kwargs: the operator's arguments; a falsy ``kwargs`` becomes ``{}``.
        run_function: callable ``(func, types, args, kwargs)`` that executes
            the operator.

    Returns:
        The fake device for ``prim.device``; otherwise the operator's result,
        with real tensors wrapped as FakeTensors on the chosen device.

    Raises:
        Exception: in fake mode, if any input is a non-Fake tensor.
    """
    kwargs = kwargs if kwargs else {}
    in_fake_mode = isinstance(cls_or_mode_instance, FakeTensorMode)
    converter = cls_or_mode_instance.fake_tensor_converter if in_fake_mode else FakeTensorConverter(
    )

    # This classes virtualizes .device() calls, need to short-circuit
    # it instead of calling device again or we would keep on recurring
    if func == torch.ops.prim.device.default:
        assert len(args) == 1 and isinstance(args[0], FakeTensor)
        return args[0].fake_device

    def wrap(e, device=None):
        # Convert real tensors to FakeTensors; leave everything else alone.
        if isinstance(e, torch.Tensor) and not isinstance(e, FakeTensor):
            return converter(e, device)
        else:
            return e

    # if we are in the dispatch mode, we will enter this function even if the inputs
    # are not FakeTensors. For now, throw if any non-Fake Tensor inputs
    # and just support constructors. TODO: extend more broadly
    if in_fake_mode:
        conversion_made = False

        def check_non_fake_tensor(x):
            nonlocal conversion_made
            conversion_made = conversion_made or (isinstance(
                x, torch.Tensor) and not isinstance(x, FakeTensor))

        tree_map(check_non_fake_tensor, args)
        tree_map(check_non_fake_tensor, kwargs)

        if conversion_made:
            raise Exception(
                "Invoking operators with non-Fake Tensor inputs in FakeTensorMode is not yet supported. "
                f"Please convert all Tensors to FakeTensors first. Found in {func}"
            )

    # _to_copy fails when run with FakeTensors to cuda device
    # TODO: debug
    if func == torch.ops.aten._to_copy.default:
        _, new_kwargs = normalize_function(func,
                                           args=args,
                                           kwargs=kwargs,
                                           normalize_to_only_use_kwargs=True)
        # normalize_function populates defaults, so "device" is present (and
        # None) when the caller did not pass one. A pop() default would never
        # apply, so fall back to the input's device explicitly.
        input_device = new_kwargs.pop("device", None)
        out_device = input_device if input_device else new_kwargs["input"].device
        with no_dispatch():
            meta_input = new_kwargs.pop("input").to("meta")
            return FakeTensor(torch.ops.aten._to_copy(meta_input, **new_kwargs),
                              out_device)

    if _is_tensor_constructor(func):
        assert func not in _non_kwarg_device_constructors
        _, new_kwargs = normalize_function(func,
                                           args=args,
                                           kwargs=kwargs,
                                           normalize_to_only_use_kwargs=True)
        # "device" is populated with its default (None) when unspecified.
        out_device = new_kwargs.pop("device", None)
        if out_device is None:
            # cpu is default device if none is specified
            out_device = torch.device("cpu")
        new_kwargs["device"] = torch.device("meta")
        r = run_function(func, types, (), new_kwargs)
        return FakeTensor(r, out_device)

    r = run_function(func, types, args, kwargs)

    # TODO: handle non-kwarg devices
    assert func not in _device_not_kwarg_ops, f"NYI: {func}"

    # if device is specified, use that
    if kwargs.get("device", None):
        return tree_map(partial(wrap, device=kwargs["device"]), r)

    # operators which copy size from another tensor do not
    # also take device from the size tensor
    # other size_as operators are not builtin operators
    if func == aten.resize_as_.default:
        _, new_kwargs = normalize_function(func,
                                           args=args,
                                           kwargs=kwargs,
                                           normalize_to_only_use_kwargs=True)
        # device of the input is returned
        return tree_map(partial(wrap, device=new_kwargs["input"].device), r)

    common_device = FakeTensor._find_common_device(func, args, kwargs)

    return tree_map(partial(wrap, device=common_device), r)
Esempio n. 14
0
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        """Run ``func`` and check observed mutation/aliasing against its schema.

        The op's schema name is appended to ``self.ops``. Arguments are
        normalized to keyword form, and their values, strides and storage ids
        are snapshotted before ``func`` runs. Afterwards, aliasing between
        arguments, between arguments and outputs, and between outputs is
        compared with the schema's alias information. Allowed argument/output
        aliasing is recorded in ``self.aliasing``, and allowed mutations in
        ``self.mutated``.

        Raises:
            RuntimeError: if an argument is mutated without it (or an argument
                aliasing it) being marked as written; if an argument aliases an
                output the schema does not permit; or if two outputs alias each
                other without sharing an aliasing input.
        """
        # ``kwargs`` defaults to None, but ``func(*args, **kwargs)`` below
        # needs a mapping.
        if kwargs is None:
            kwargs = {}

        def has_mutated(before, after, md):
            # A plain tensor counts as mutated if its values, strides or
            # storage changed; ``md`` is the (stride, storage id) snapshot.
            if type(before) == torch.Tensor and type(after) == torch.Tensor:
                return not (torch.equal(before, after)
                            and md[0] == after.stride()
                            and md[1] == after.storage()._cdata)
            return False

        def has_aliased(lhs, rhs):
            if type(lhs) == torch.Tensor and type(rhs) == torch.Tensor:
                return torch._C._is_alias_of(lhs, rhs)
            return False

        def is_mutable(arg_alias_pairs):
            # True if any of the given schema arguments is marked as written.
            for arg in arg_alias_pairs:
                if arg.alias_info is not None and arg.alias_info.is_write:
                    return True
            return False

        def is_aliasing(output, arg_alias_pairs):
            # An argument with a wildcard ('*') alias set may alias an output of
            # the same type, or a list output whose element type matches.
            # Otherwise the output and argument alias sets must intersect.
            for arg in arg_alias_pairs:
                if arg.alias_info is not None:
                    if '*' in arg.alias_info.after_set:
                        same_types = output.type == arg.type
                        elems_same_types = (
                            isinstance(output.type, torch._C.ListType)
                            and output.type.getElementType() == arg.type)
                        if same_types or elems_same_types:
                            return True
                    elif output.alias_info is not None:
                        share_aliasing_sets = bool(
                            len(output.alias_info.after_set
                                & arg.alias_info.after_set))
                        if share_aliasing_sets:
                            return True
            return False

        def are_args_aliasing(lhs, rhs):
            for lhs_value in lhs:
                for rhs_value in rhs:
                    if (has_aliased(lhs_value, rhs_value)):
                        return True
            return False

        def standardize_name(name):
            # The normalized kwargs name the schema argument "self" as "input".
            return name if name != "self" else "input"

        def unwrap(e):
            # Tensor subclasses exposing ``.elem`` are unwrapped; everything
            # else is returned unchanged.
            if isinstance(e, torch.Tensor) and not type(e) == torch.Tensor:
                try:
                    return e.elem
                except AttributeError:
                    return e
            return e

        def parse_metadata(e):
            # (stride, storage id) of the (unwrapped) tensor, or None.
            if isinstance(e, torch.Tensor):
                if not type(e) == torch.Tensor:
                    try:
                        current = e.elem
                        return (deepcopy(current.stride()),
                                current.storage()._cdata)
                    except AttributeError:
                        return None
                else:
                    return (deepcopy(e.stride()), e.storage()._cdata)
            return None

        self.ops.append(func._schema.name)

        # Clone and process arguments and outputs
        pre_arguments = normalize_function(
            func, args, kwargs, normalize_to_only_use_kwargs=True).kwargs

        c_p_args = dict(
            zip(pre_arguments.keys(), clone_inputs(pre_arguments.values())))
        cloned_arguments = {
            name: tree_map(unwrap,
                           tree_flatten(c_p_args.get(name))[0])
            for name in c_p_args
        }
        cloned_metadata = {
            name: tree_map(parse_metadata,
                           tree_flatten(pre_arguments.get(name))[0])
            for name in pre_arguments
        }

        out = func(*args, **kwargs)

        arguments = {
            name: tree_map(unwrap,
                           tree_flatten(pre_arguments.get(name))[0])
            for name in pre_arguments
        }
        tuple_out = out if isinstance(out, tuple) else (out, )
        u_out = [tree_map(unwrap, tree_flatten(i)[0]) for i in tuple_out]

        # Construct an aliasing map between op arguments for verifying aliasing pairs
        # between op arguments and op outputs. This is used to allow cases where two aliasing arguments
        # cause a non-mutable/non-aliasing argument to mutate or alias.
        arg_alias_pairs_map = {
            standardize_name(arg.name): [arg]
            for arg in func._schema.arguments
        }

        # Construct an aliasing set for each output for verifying aliasing pairs
        # between op outputs
        out_alias_pairs_map = [set() for _ in func._schema.returns]

        # Aliasing between arguments
        for i_arg, j_arg in combinations(func._schema.arguments, 2):
            i_values = arguments.get(standardize_name(i_arg.name))
            j_values = arguments.get(standardize_name(j_arg.name))
            if are_args_aliasing(i_values, j_values):
                arg_alias_pairs_map[standardize_name(i_arg.name)].append(j_arg)
                arg_alias_pairs_map[standardize_name(j_arg.name)].append(i_arg)

        # Process arguments with outputs
        for arg in func._schema.arguments:
            name = standardize_name(arg.name)
            if arguments.get(name) is not None:
                arg_alias_pairs = arg_alias_pairs_map[name]
                before = cloned_arguments.get(name)
                md = cloned_metadata.get(name)
                after = arguments.get(name)
                for v in after:
                    for i in range(len(u_out)):
                        for j in range(len(u_out[i])):
                            if has_aliased(v, u_out[i][j]):
                                if not is_aliasing(func._schema.returns[i],
                                                   arg_alias_pairs):
                                    raise RuntimeError(
                                        f'Argument {name} is not defined to alias output but was aliasing'
                                    )
                                else:
                                    self.aliasing.append(
                                        Aliasing(func._schema.name, name,
                                                 f"output_{i}"))
                                    out_alias_pairs_map[i].add(name)
                                    # Record each output at most once per value.
                                    break
                if any(
                        has_mutated(i, j, k)
                        for i, j, k in zip(before, after, md)):
                    if not is_mutable(arg_alias_pairs):
                        raise RuntimeError(
                            f"Argument {name} is not defined as mutable but was mutated"
                        )
                    else:
                        self.mutated.append(Mutation(func._schema.name, name))

        # Aliasing between outputs: allowed only if both outputs alias some
        # common input argument.
        for i, j in combinations(range(len(func._schema.returns)), 2):
            if are_args_aliasing(u_out[i], u_out[j]):
                share_aliasing_inputs = bool(
                    len(out_alias_pairs_map[i] & out_alias_pairs_map[j]))
                if (not share_aliasing_inputs):
                    raise RuntimeError(
                        f'Outputs {i} and {j} alias unexpectedly')

        return out