def test_computed_table_with_autograd(self):
    """An Autograd alias kernel populates every Autograd<Backend> entry.

    Registers a schema-only def plus an impl on the ``Autograd`` alias key,
    then checks both the dumped registration state and a subset of the
    computed dispatch table.
    """
    # NOTE(review): global_m is never read below; it is kept so the library
    # handle lives for the duration of the test exactly as before.
    global_m = C._dispatch_library("IMPL", "_", "AutogradCPU")
    registrations = [
        # m.def("foo(Tensor x) -> Tensor")
        lambda m: m.def_("foo(Tensor x) -> Tensor"),
        # m.impl("foo", torch::kAutograd, [](const Tensor & x) { return x })
        lambda m: m.impl_t_t("foo", "Autograd"),
    ]
    result = self.commute("foo", registrations)

    self.assertExpectedInline(result.state, '''\
name: test::foo
schema: test::foo(Tensor x) -> (Tensor)
debug: registered at /dev/null:0
alias analysis kind: FROM_SCHEMA
Autograd[alias]: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
''')

    # The full computed table is too large to pin, so only the entries for
    # dispatch_keys_to_check are compared.
    self.assertExpectedInline(
        extract_dispatch_table_with_keys(result.table, dispatch_keys_to_check),
        '''\
AutogradOther: impl_t_t [autograd kernel]
AutogradCPU: impl_t_t [autograd kernel]
AutogradCUDA: impl_t_t [autograd kernel]
AutogradXLA: impl_t_t [autograd kernel]
''')
def test_computed_table_with_cpu_catchall(self):
    """A CPU kernel plus a catch-all def: CPU wins on CPU, catch-all elsewhere.

    The expected table pins that AutogradCPU resolves to the framework's
    backend-fallback fallthrough rather than the catch-all kernel.
    """
    # NOTE(review): global_m is never read below; it is kept so the library
    # handle lives for the duration of the test exactly as before.
    global_m = C._dispatch_library("IMPL", "_", "AutogradCPU")
    steps = [
        # m.def("foo", [](const Tensor & x) { return x })
        lambda m: m.def_name_t_t("foo"),
        # m.impl("foo", torch::kCPU, [](const Tensor & x) { return x })
        lambda m: m.impl_t_t("foo", "CPU"),
    ]
    outcome = self.commute("foo", steps)
    keys = dispatch_keys_to_check

    self.assertExpectedInline(
        outcome.state,
        '''\
name: test::foo
schema: test::foo(Tensor _0) -> (Tensor _0)
debug: registered at /dev/null:0
alias analysis kind: CONSERVATIVE
CPU: impl_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
catchall: default_def_name_t_t :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
''')

    # Only a handful of keys are checked; the full computed table is huge.
    subset = extract_dispatch_table_with_keys(outcome.table, keys)
    self.assertExpectedInline(
        subset,
        '''\
CPU: impl_t_t [kernel]
CUDA: default_def_name_t_t [catch all]
XLA: default_def_name_t_t [catch all]
AutogradOther: default_def_name_t_t [catch all]
AutogradCPU: fallthrough registered in pytorch framework [backend fallback]
AutogradCUDA: default_def_name_t_t [catch all]
AutogradXLA: default_def_name_t_t [catch all]
''')
def test_computed_table_with_cpu_math(self):
    """A CPU kernel plus a CompositeImplicitAutograd kernel.

    fn_cpu serves CPU; fn_math serves every other checked key except
    AutogradCPU, which resolves to the framework backend fallback.
    """
    # NOTE(review): global_m is never read below; it is kept so the library
    # handle lives for the duration of the test exactly as before.
    global_m = C._dispatch_library("IMPL", "_", "AutogradCPU")

    def define_schema(m):
        # m.def("foo(Tensor x) -> Tensor")
        return m.def_("foo(Tensor x) -> Tensor")

    def impl_cpu(m):
        # m.impl("foo", torch::kCPU, [](const Tensor & x) { return x })
        return m.impl_t_t("foo", "CPU", debug="fn_cpu")

    def impl_math(m):
        # m.impl("foo", torch::kCompositeImplicitAutograd, [](const Tensor & x) { return x })
        return m.impl_t_t("foo", "CompositeImplicitAutograd", debug="fn_math")

    res = self.commute("foo", [define_schema, impl_cpu, impl_math])

    self.assertExpectedInline(res.state, '''\
name: test::foo
schema: test::foo(Tensor x) -> (Tensor)
debug: registered at /dev/null:0
alias analysis kind: FROM_SCHEMA
CPU: fn_cpu :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
CompositeImplicitAutograd[alias]: fn_math :: (Tensor _0) -> (Tensor _0) [ boxed unboxed ]
''')

    # Only a handful of keys are checked; the full computed table is huge.
    picked = extract_dispatch_table_with_keys(res.table, dispatch_keys_to_check)
    self.assertExpectedInline(picked, '''\
Undefined: fn_math [math kernel]
CPU: fn_cpu [kernel]
CUDA: fn_math [math kernel]
XLA: fn_math [math kernel]
AutogradOther: fn_math [math kernel]
AutogradCPU: fallthrough registered in pytorch framework [backend fallback]
AutogradCUDA: fn_math [math kernel]
AutogradXLA: fn_math [math kernel]
''')
def test_multiple_fallback(self):
    """Registering a second backend fallback for the same key must fail.

    The first ``fallback_fallthrough()`` on XLA succeeds. The second one,
    made through the same library handle, is expected to raise a
    RuntimeError whose message mentions both the previous and the new
    registration.
    """
    global_m = C._dispatch_library("IMPL", "_", "XLA")
    global_m.fallback_fallthrough()
    # assertRaises fails the test with "RuntimeError not raised" if the
    # duplicate registration is silently accepted. This replaces the old
    # try/except/else + assertTrue(False) dance.
    with self.assertRaises(RuntimeError) as ctx:
        global_m.fallback_fallthrough()
    self.assertExpectedInline(
        str(ctx.exception),
        '''Tried to register multiple backend fallbacks for the same dispatch key XLA; previous registration registered at /dev/null:0, new registration registered at /dev/null:0'''  # noqa
    )
def run_ops(self, name, ops, ctor_order=None, dtor_order=None,
            results=None, expect_raises=False):
    """
    Given a list of operator registrations, run the registrations in the
    order specified by ctor_order, and then run the deregistrations in
    dtor_order.

    If results is specified, intermediate results are checked for consistency
    with results stored in results (and stored in results if this is the
    first time we've seen them).  Results are expected to be equivalent
    modulo commutativity and inverses (thus, results is keyed on a frozenset
    of in effect registrations from ops).  Results stores namedtuple
    Result[state, table, provenance], where state is a string that contains
    non-derived kernel registered or error message if it doesn't pass;
    table is a string that contains computed dispatch table entries;
    provenance is a string that describes how exactly we got this string.

    If expect_raises is True, it is not an error to raise an exception.  Instead,
    we'll store the exception string (instead of the dispatcher state)
    in results.  In principle we should flag these differently, but it's
    very obvious when you get an error in one case but not another.

    Returns the ``state`` field of the result recorded for the full set of
    registrations or, if a registration raised, for the set of registrations
    that was in effect when it raised.
    """
    # By allocating every test into a fresh namespace, this makes it less
    # likely that a bug in the testing framework will result in tests
    # interfering with each other
    self.__class__.namespace_index += 1
    if results is None:
        results = {}
    if ctor_order is None:
        ctor_order = list(range(len(ops)))
    if dtor_order is None:
        dtor_order = list(reversed(ctor_order))
    # Refs which retain the c10::Module object so we can explicitly control
    # when each deregistration happens (deregistration occurs when the
    # object gets deallocated).
    refs = [None] * len(ops)
    # Keep track of the set "in effect" registrations
    active_ops = set()

    # double underscore to make it less likely we conflict with something
    # else
    test_namespace = "__test{}__".format(self.namespace_index)

    def check_invariants(actual_provenance):
        """Compare the current dispatcher dump against results for active_ops."""
        C._dispatch_check_invariants(name)
        # Normalize the test namespace so that expected outputs are stable
        actual_state = C._dispatch_dump(
            "{}::{}".format(test_namespace, name)).replace(test_namespace, "test")
        actual_table = C._dispatch_dump_table(
            "{}::{}".format(test_namespace, name)).replace(test_namespace, "test")
        # First sighting of this registration set records the result;
        # later sightings must match it exactly.
        expected_state, expected_table, expected_provenance = results.setdefault(
            frozenset(active_ops),
            Result(actual_state, actual_table, actual_provenance)
        )
        self.assertMultiLineEqual(
            expected_state, actual_state,
            "expected from {}; actual from {}".format(
                expected_provenance, actual_provenance))
        self.assertMultiLineEqual(
            expected_table, actual_table,
            "expected from {}; actual from {}".format(
                expected_provenance, actual_provenance))

    results.setdefault(frozenset(), Result("", "", "hardcoded initial state"))
    check_invariants("initial state")
    # In the order specified by ctor_order, run registrations
    set_to_report = frozenset(range(len(ops)))
    # Position (in ctor_order) of the last ctor that ran successfully; -1
    # means none did.  It must be bound before the loop: if the very first
    # ctor raises, the loop breaks before reaching `last_ctor = i`, and the
    # dtor loop below would otherwise fail with UnboundLocalError.
    last_ctor = -1
    for i, op_ix in enumerate(ctor_order):
        # It would be better to DEF here, but because we manage
        # lifetime of multiple registrations with multiple Library
        # references (refs), we can't deal with the strict checking
        # from DEF.
        refs[op_ix] = C._dispatch_library("FRAGMENT", test_namespace, "")
        active_ops.add(op_ix)
        try:
            ops[op_ix](refs[op_ix])
            check_invariants("running ctors {}".format(ctor_order[:i + 1]))
        except RuntimeError as e:
            if not expect_raises:
                raise
            actual = str(e).replace(test_namespace, "test")
            # Drop the C++ backtrace suffix so messages are comparable.
            actual = actual.split("\nException raised from ")[0]
            expected, _, expected_provenance = results.setdefault(
                frozenset(active_ops),
                Result(actual, "", "error after running ctors {}".format(ctor_order[:i + 1]))
            )
            self.assertMultiLineEqual(expected, actual, expected_provenance)
            set_to_report = frozenset(active_ops)
            active_ops.remove(op_ix)
            # NB: this finally test asserts that if a registrations fails,
            # the dispatcher is left in the same state *that it was before*!
            check_invariants(
                "running ctors {} and then failing to run ctor {} "
                "(did this failure leave the dispatcher in a wedged state? "
                "it shouldn't!)"
                .format(ctor_order[:i], op_ix))
            break
        last_ctor = i
    if expect_raises and len(active_ops) == len(ops):
        # Destroy references first, as some test frameworks (like pytest)
        # will retain references in the exception raised by the failure! EW!
        refs = None
        self.fail(
            "expected exception to be raised, but nothing was raised "
            "(after running ctors {})".format(ctor_order))
    # In the order specified by dtor_order, run deregistrations
    for i, op_ix in enumerate(dtor_order):
        # Trigger a destruction
        refs[op_ix] = None
        # discard not remove, since we may not have actually deregistered
        # anything if there was an error raised
        if expect_raises:
            active_ops.discard(op_ix)
        else:
            active_ops.remove(op_ix)
        check_invariants(
            "running ctors {}, then running dtors {}"
            .format(ctor_order[:last_ctor + 1], dtor_order[:i + 1])
        )
    return results[set_to_report][0]
def __init__(self):
    # Run the dispatcher's invariant check for self.name before registering
    # anything.
    C._dispatch_check_invariants(self.name)  # type: ignore[attr-defined]
    schema = "foo(Tensor x) -> Tensor"
    # Open a FRAGMENT library on self.namespace (empty dispatch key) and keep
    # the handle on self.ref, then define the schema through it.
    self.ref = C._dispatch_library("FRAGMENT", self.namespace, "")  # type: ignore[attr-defined]
    self.ref.def_(schema)