def testDispatchCacheOrderingDeterminism(self):
  """Insertion order decides dispatch between equally viable targets.

  Three tables receive the same three targets in different orders. Their
  target sets are equal, yet dispatching the same query to each one returns
  the first-inserted target that has no more specific viable target.
  """
  x_dims = (1, None, None)
  y_dims = (None, 2, None)
  z_dims = (None, None, 3)
  insertion_orders = [
      (x_dims, y_dims, z_dims),
      (y_dims, x_dims, z_dims),
      (z_dims, x_dims, y_dims),
  ]

  tables = []
  for order in insertion_orders:
    table = type_dispatch.TypeDispatchTable()
    for dims in order:
      table.add_target(MockShape(*dims))
    tables.append(table)
  first, second, third = tables

  # Regardless of insertion order, every table holds the same targets.
  self.assertEqual(set(first.all_targets()), set(second.all_targets()))
  self.assertEqual(set(second.all_targets()), set(third.all_targets()))

  # The winner, however, is whichever viable target each table saw first.
  query = MockShape(1, 2, 3)
  for table, winner in zip(tables, (x_dims, y_dims, z_dims)):
    self.assertEqual(table.dispatch(query), MockShape(*winner))
def testDeletion(self):
  """Deleting a target removes it; deleting a missing target is a no-op."""
  table = type_dispatch.TypeDispatchTable()
  for dims in [(None, None), (None, 1), (None, 2)]:
    table.add_target(MockShape(*dims))

  self.assertEqual(
      table.all_targets(),
      [MockShape(None, None), MockShape(None, 1), MockShape(None, 2)])

  survivors = [MockShape(None, None), MockShape(None, 1)]

  # First deletion drops the target from the table.
  table.delete(MockShape(None, 2))
  self.assertEqual(table.all_targets(), survivors)

  # Deleting the same, now-absent target leaves the table unchanged.
  table.delete(MockShape(None, 2))
  self.assertEqual(table.all_targets(), survivors)
def __init__(self):
  # Functions whose lookups missed; each entry is an ExecutionContext.
  self._missed = set()

  # Primary cache: FunctionCacheKey -> concrete function.
  self._primary = collections.OrderedDict()

  # Maps a FunctionCacheKey K to a FunctionCacheKey V such that dispatching K
  # to V's concrete function (stored in `_primary`) is safe. Used to look up
  # possible concrete functions when K itself is not in `_primary`.
  self._dispatch_table = type_dispatch.TypeDispatchTable()

  # TODO(b/202430155): Incorporate relaxation logic inside FunctionCache.
  # Relaxed-spec lookup: a cache key generated without shape info -> a flat
  # list of `TypeSpec`s with relaxed shapes, one per flattened argument.
  # Arguments that are not Tensors or `CompositeTensor`s get `None` in their
  # slot.
  self.arg_relaxed_specs = collections.OrderedDict()

  # Secondary cache: a cache key generated without shape info -> function.
  self.arg_relaxed = collections.OrderedDict()

  # Each OrderedDict above needs manual garbage collection; keep one
  # collector per dict, in this fixed order.
  self._garbage_collectors = [
      _FunctionGarbageCollector(cache)
      for cache in (self._primary, self.arg_relaxed, self.arg_relaxed_specs)
  ]
def testGeneralizedNovel(self):
  """Generalizing a shape unlike the targets widens it to all-None."""
  table = type_dispatch.TypeDispatchTable()
  for dims in [(None, 1, None), (None, 1, 2)]:
    table.add_target(MockShape(*dims))

  generalized = table.try_generalizing_trace_type(MockShape(None, 2, 3))
  self.assertEqual(generalized, MockShape(None, None, None))
def __init__(self):
  # Maps a FunctionCacheKey K to a FunctionCacheKey V such that dispatching K
  # to V's concrete function (stored in `_primary`) is safe. Used to look up
  # possible concrete functions when K itself is not in `_primary`.
  self._dispatch_table = type_dispatch.TypeDispatchTable()

  # Primary cache: FunctionCacheKey -> concrete function.
  self._primary = collections.OrderedDict()
def testDispatchNoMatches(self):
  """Dispatch returns None when no registered target serves the query."""
  table = type_dispatch.TypeDispatchTable()
  for dims in [(None, 1, None), (None, 1, 2), (None, 2, 2)]:
    table.add_target(MockShape(*dims))

  for query_dims in [(1, 2), (1, 2, 3), (1, 2, 3, 4)]:
    self.assertIsNone(table.dispatch(MockShape(*query_dims)))
def testHorizontal(self):
  """Targets of different lengths are all stored, in insertion order."""
  all_dims = [(1,), (1, 2), (1, 2, 3)]

  table = type_dispatch.TypeDispatchTable()
  for dims in all_dims:
    table.add_target(MockShape(*dims))

  self.assertEqual(list(table.targets), [MockShape(*d) for d in all_dims])
def testDuplicateNodes(self):
  """Re-adding an existing target does not create a second entry."""
  table = type_dispatch.TypeDispatchTable()
  for dims in [(None, None), (1, None), (None, 2), (None, None)]:
    table.add_target(MockShape(*dims))

  unique_targets = [
      MockShape(None, None), MockShape(1, None), MockShape(None, 2)
  ]
  self.assertEqual(list(table.targets), unique_targets)
def testVertical(self):
  """A chain of targets fixing one more dimension each is kept in order."""
  chain = [(None, None, None), (None, None, 1), (None, 1, 1), (1, 1, 1)]

  table = type_dispatch.TypeDispatchTable()
  for dims in chain:
    table.add_target(MockShape(*dims))

  self.assertEqual(list(table.targets), [MockShape(*d) for d in chain])
def testDispatchMoreSpecific(self):
  """Dispatch picks the most specific target that still matches a query."""
  table = type_dispatch.TypeDispatchTable()
  for dims in [(None, None, None), (None, 1, None), (None, 1, 2),
               (None, 2, 2)]:
    table.add_target(MockShape(*dims))

  # (query dims, expected target dims)
  cases = [
      ((1, 1, 2), (None, 1, 2)),
      ((1, 1, 3), (None, 1, None)),
      ((1, 3, 3), (None, None, None)),
      ((1, 2, 2), (None, 2, 2)),
  ]
  for query_dims, target_dims in cases:
    self.assertEqual(
        table.dispatch(MockShape(*query_dims)), MockShape(*target_dims))
def testDispatchExactMatches(self):
  """A query equal to a registered target dispatches to that target."""
  table = type_dispatch.TypeDispatchTable()
  for dims in [(None, None, None), (None, 1, None), (None, 1, 2),
               (None, 2, 2)]:
    table.add_target(MockShape(*dims))

  # Queries are listed in a different order than the insertions above.
  for dims in [(None, 1, 2), (None, 1, None), (None, None, None),
               (None, 2, 2)]:
    self.assertEqual(table.dispatch(MockShape(*dims)), MockShape(*dims))
def testDispatchCachedAddUpdates(self):
  """Newly added, more specific targets take over a repeated query.

  The same query is dispatched after every insertion. Each insertion is more
  specific than the previous one, so the freshly added target must become
  the dispatch result every time, even though the query was seen before.
  """
  table = type_dispatch.TypeDispatchTable()
  for dims in [(None, None, None), (None, 1, None), (None, 1, 2), (1, 1, 2)]:
    table.add_target(MockShape(*dims))
    self.assertEqual(table.dispatch(MockShape(1, 1, 2)), MockShape(*dims))
def testContains(self):
  """`contains` is true only for shapes that were added as targets."""
  # NOTE(review): another method named `testContains` exists alongside this
  # one; if both live in the same TestCase, the later definition silently
  # shadows the earlier. Confirm and rename one of them if so.
  added = [(None, None, None), (None, 1), (1, 1), (None, 2, 1)]
  never_added = [(None, None, 1), (1, None), (1, 2), (None, 2, None)]

  table = type_dispatch.TypeDispatchTable()
  for dims in added:
    table.add_target(MockShape(*dims))

  for dims in added:
    self.assertTrue(table.contains(MockShape(*dims)))
  for dims in never_added:
    self.assertFalse(table.contains(MockShape(*dims)))
def testContains(self):
  """`targets` contains exactly the shapes that were added."""
  # NOTE(review): another method named `testContains` exists alongside this
  # one; if both live in the same TestCase, the later definition silently
  # shadows the earlier. Confirm and rename one of them if so.
  added = [(None, None, None), (None, 1), (1, 1), (None, 2, 1)]
  never_added = [(None, None, 1), (1, None), (1, 2), (None, 2, None)]

  table = type_dispatch.TypeDispatchTable()
  for dims in added:
    table.add_target(MockShape(*dims))

  for dims in added:
    self.assertIn(MockShape(*dims), table.targets)
  for dims in never_added:
    self.assertNotIn(MockShape(*dims), table.targets)