Пример #1
0
 def test_before_and_after_alias_set_larger_than_1_raises_exception(self) -> None:
     # Both sides of "->" list two aliases; parse is expected to reject that
     # with an AssertionError whose message matches the pattern below.
     expected_message = (
         r"before alias set and after alias set cannot be larger than 1 at the same time"
     )
     with self.assertRaisesRegex(AssertionError, expected_message):
         Annotation.parse("a|b -> c|d")
def generate_out_args_from_schema(
    func: FunctionSchema,
) -> Tuple[List[Return], List[Argument]]:
    """Build the out= arguments (and matching returns) for ``func``.

    Each tensor-like return of ``func`` is turned into a new ``Argument`` with
    a write annotation (``"<letter>!"``) whose letter is not used by any
    existing argument annotation of ``func``.

    Args:
        func: The schema whose returns are converted into out= arguments.

    Returns:
        A ``(new_returns, new_out_args)`` pair:

        * If every return is a plain tensor, ``new_returns`` mirrors the old
          returns with the out= alias annotations attached.
        * Otherwise ``new_returns`` holds only the non-tensor-like returns,
          and none of the out arguments are returned.

    Raises:
        AssertionError: If a return already carries a write annotation, or if
            ``func`` has no tensor-like return at all.
    """
    # More of a sanity check - our existing restrictions on schemas should enforce that
    # mutable schema kinds never return their mutable arguments.
    assert not any(
        r.annotation is not None and r.annotation.is_write for r in func.returns
    )

    tensorlike_rets = [r for r in func.returns if r.type.is_tensor_like()]
    assert len(tensorlike_rets) > 0

    # Materialize the used alias names exactly once. The membership test below
    # runs for every candidate letter; if concatMap hands back a one-shot
    # iterator, each `in` check would consume it and later letters could be
    # wrongly reported as unused depending on argument order.
    used_annotations = set(
        concatMap(
            lambda a: [] if a.annotation is None else a.annotation.alias_set,
            func.arguments.flat_all,
        )
    )
    valid_annotations = [
        x for x in "abcdefghijklmnopqrstuvwxyz" if x not in used_annotations
    ]

    all_rets_are_tensors = all(r.type == BaseType(BaseTy.Tensor) for r in func.returns)

    new_out_args: List[Argument] = []
    # The end result of new_returns is that:
    # - If every return is a plain tensor, then the new returns == the old returns, but with the out= alias annotations added.
    # - Otherwise, none of the out arguments show up in the returns (and we're only left with non-tensor-like returns, if any).
    new_returns: List[Return] = []
    for (i, r) in enumerate(func.returns):
        if r.type.is_tensor_like():
            new_out = Argument(
                # A single return is simply called "out"; several returns are
                # numbered by their position in the original return list.
                name="out" if len(func.returns) == 1 else f"out{i}",
                type=r.type,
                default=None,
                # The alias letter is picked by return position from the
                # letters not already used by the arguments.
                annotation=Annotation.parse(f"{valid_annotations[i]}!"),
            )
            new_out_args.append(new_out)
            if all_rets_are_tensors:
                # The convention for out= schemas is that they only return their out arguments
                # if the return is a plain Tensor (or if it's a tuple of plain Tensors)
                new_ret = Return(
                    name=None, type=new_out.type, annotation=new_out.annotation
                )
                new_returns.append(new_ret)
        else:
            new_returns.append(r)
    return new_returns, new_out_args
Пример #3
0
 def test_single_alias_is_write_to_alias_set(self) -> None:
     # A written single alias "a" that becomes the alias set a|b afterwards.
     annotation = Annotation.parse("a! -> a|b")
     self.assertEqual(annotation.alias_set, ("a",))
     self.assertTrue(annotation.is_write)
     expected_after = ("a", "b")
     self.assertEqual(annotation.alias_set_after, expected_after)
Пример #4
0
 def test_alias_set_is_write_raises_exception(self) -> None:
     # "a|b!" puts a write marker on a two-member alias set; parse is expected
     # to fail with an AssertionError matching the pattern below.
     expected_message = r"alias set larger than 1 is not mutable"
     with self.assertRaisesRegex(AssertionError, expected_message):
         Annotation.parse("a|b!")
Пример #5
0
 def test_alias_set(self) -> None:
     # "a|b" should parse into the two-member alias set, in written order.
     annotation = Annotation.parse("a|b")
     expected_aliases = ("a", "b")
     self.assertEqual(annotation.alias_set, expected_aliases)
Пример #6
0
 def test_single_alias_is_write_to_wildcard(self) -> None:
     # A written single alias "a" that becomes the wildcard "*" afterwards.
     annotation = Annotation.parse("a! -> *")
     self.assertEqual(annotation.alias_set, ("a",))
     self.assertTrue(annotation.is_write)
     self.assertEqual(annotation.alias_set_after, ("*",))
Пример #7
0
 def test_single_alias_no_write(self) -> None:
     # A bare alias "a": one-member alias set, not written, no "after" set.
     annotation = Annotation.parse("a")
     self.assertEqual(annotation.alias_set, ("a",))
     self.assertFalse(annotation.is_write)
     self.assertEqual(annotation.alias_set_after, ())
Пример #8
0
def mutable_to_out_signature(func: FunctionSchema) -> FunctionSchema:
    """Generate an out= schema from a mutable schema.

    Args:
        func: A schema whose ``kind()`` is ``SchemaKind.mutable`` and which
            has at least one tensor-like return.

    Returns:
        A new ``FunctionSchema`` whose name has the inplace marker removed and
        an ``"out"`` (or ``"<overload>_out"``) overload name, whose arguments
        are the originals plus the generated out= arguments, and whose returns
        follow the out= convention described in the comments below.

    Raises:
        AssertionError: If ``func`` is not a mutable schema, if one of its
            returns already carries a write annotation, or if it has no
            tensor-like return.
    """
    # Generating an out= schema from a mutable schema.
    assert func.kind() == SchemaKind.mutable
    # The new out= schema has:
    # - Any non-aliased tensor-like returns are converted to mutable, aliased out= arguments
    #   (if the argument is a tensor then we also return it for method chaining,
    #   otherwise we return nothing)
    # - an "out" overload name
    #
    # Note that:
    # (1) This also means that we can *only* generate an out= variant from a mutable schema
    #     if the mutable schema has at least one tensor-like non-aliasing return.
    # (2) The generated out= variant still has mutable positional arguments,
    #     but if necessary we could probably add another out= variant that also
    #     functionalizes the mutable arguments (a functional_out variant)

    # More of a sanity check - our existing restrictions on schemas should enforce that
    # mutable schema kinds never return their mutable arguments.
    assert not any(r.annotation is not None and r.annotation.is_write
                   for r in func.returns)

    tensorlike_rets = [r for r in func.returns if r.type.is_tensor_like()]
    assert len(tensorlike_rets) > 0

    # Materialize the used alias names exactly once. The membership test below
    # runs for every candidate letter; if concatMap hands back a one-shot
    # iterator, each `in` check would consume it and later letters could be
    # wrongly reported as unused depending on argument order.
    used_annotations = set(
        concatMap(
            lambda a: [] if a.annotation is None else a.annotation.alias_set,
            func.arguments.flat_all,
        )
    )
    valid_annotations = [
        x for x in "abcdefghijklmnopqrstuvwxyz" if x not in used_annotations
    ]

    all_rets_are_tensors = all(r.type == BaseType(BaseTy.Tensor)
                               for r in func.returns)

    new_out_args: List[Argument] = []
    # The end result of new_returns is that:
    # - If every return is a plain tensor, then the new returns == the old returns, but with the out= alias annotations added.
    # - Otherwise, none of the out arguments show up in the returns (and we're only left with non-tensor-like returns, if any).
    new_returns: List[Return] = []
    for (i, r) in enumerate(func.returns):
        if r.type.is_tensor_like():
            new_out = Argument(
                # Out arguments are numbered by the position of the return
                # they replace in the original return list.
                name=f"out{i}",
                type=r.type,
                default=None,
                # The alias letter is picked by return position from the
                # letters not already used by the arguments.
                annotation=Annotation.parse(f"{valid_annotations[i]}!"),
            )
            new_out_args.append(new_out)
            if all_rets_are_tensors:
                # The convention for out= schemas is that they only return their out arguments
                # if the return is a plain Tensor (or if it's a tuple of plain Tensors)
                new_ret = Return(name=None,
                                 type=new_out.type,
                                 annotation=new_out.annotation)
                new_returns.append(new_ret)
        else:
            new_returns.append(r)

    return FunctionSchema(
        # Keep an existing overload name as a prefix so distinct overloads
        # still map to distinct out= overloads.
        name=func.name.remove_inplace().with_overload(
            "out" if not func.name.overload_name else
            f"{func.name.overload_name}_out"),
        arguments=func.arguments.with_out_args(new_out_args),
        returns=tuple(new_returns),
    )