def testDispatchCacheOrderingDeterminism(self):
    """Insertion order decides dispatch among equally specific targets.

    The three tables hold the same targets, added in different orders. For a
    query matched by every target (none more specific than another), each
    table dispatches to the target it received first.
    """

    def make_table(*dims_list):
      # Builds a fresh table, adding one new MockShape per dims tuple in order.
      table = type_dispatch.TypeDispatchTable()
      for dims in dims_list:
        table.add_target(MockShape(*dims))
      return table

    only_first = (1, None, None)
    only_second = (None, 2, None)
    only_third = (None, None, 3)

    table_1 = make_table(only_first, only_second, only_third)
    table_2 = make_table(only_second, only_first, only_third)
    table_3 = make_table(only_third, only_first, only_second)

    # Same target set everywhere (checked pairwise: 1 vs 2, then 2 vs 3)...
    for left, right in ((table_1, table_2), (table_2, table_3)):
      self.assertEqual(set(left.all_targets()), set(right.all_targets()))

    # ...but each table resolves the tie in favour of its first-added target.
    query = MockShape(1, 2, 3)
    expectations = (
        (table_1, only_first),
        (table_2, only_second),
        (table_3, only_third),
    )
    for table, expected_dims in expectations:
      self.assertEqual(table.dispatch(query), MockShape(*expected_dims))
  def testDeletion(self):
    """Deleting a present target removes it; deleting it again is a no-op."""
    table = type_dispatch.TypeDispatchTable()
    for dims in ((None, None), (None, 1), (None, 2)):
      table.add_target(MockShape(*dims))

    self.assertEqual(
        table.all_targets(),
        [MockShape(None, None), MockShape(None, 1), MockShape(None, 2)])

    # Round one removes the target; round two targets a shape that is no
    # longer present, so the remaining targets must be left untouched.
    for _ in range(2):
      table.delete(MockShape(None, 2))
      self.assertEqual(
          table.all_targets(),
          [MockShape(None, None), MockShape(None, 1)])
Esempio n. 3
0
  def __init__(self):
    """Creates empty caches, the dispatch table, and their garbage collectors."""
    # Functions that have been missed; each entry is an ExecutionContext.
    self._missed = set()
    # Primary cache: FunctionCacheKey -> concrete function.
    self._primary = collections.OrderedDict()

    # For a FunctionCacheKey K that is not in _primary, yields a key V whose
    # concrete function in _primary can safely be used for K. This is how
    # possible concrete functions are looked up on a primary-cache miss.
    self._dispatch_table = type_dispatch.TypeDispatchTable()

    # TODO(b/202430155): Incorporate relaxation logic inside FunctionCache.
    # Cache key built without shape info -> flat list of `TypeSpec`s with
    # relaxed shapes, one per flattened argument. Arguments that are neither
    # Tensors nor `CompositeTensor`s hold `None` in their slot.
    self.arg_relaxed_specs = collections.OrderedDict()
    # Secondary cache: cache key built without shape info -> function.
    self.arg_relaxed = collections.OrderedDict()

    # Every OrderedDict cache needs manual garbage collection, so each one is
    # paired with its own collector (same order as before: primary, relaxed,
    # relaxed specs).
    collected_caches = (self._primary, self.arg_relaxed, self.arg_relaxed_specs)
    self._garbage_collectors = [
        _FunctionGarbageCollector(cache) for cache in collected_caches
    ]
Esempio n. 4
0
 def testGeneralizedNovel(self):
     """A shape not covered by any target generalizes to a relaxed shape.

     With targets (None, 1, None) and (None, 1, 2), generalizing
     (None, 2, 3) is expected to yield (None, None, None).
     """
     table = type_dispatch.TypeDispatchTable()
     for dims in ((None, 1, None), (None, 1, 2)):
         table.add_target(MockShape(*dims))

     generalized = table.try_generalizing_trace_type(MockShape(None, 2, 3))
     self.assertEqual(generalized, MockShape(None, None, None))
Esempio n. 5
0
    def __init__(self):
        """Initializes the primary cache and its companion dispatch table."""
        # Primary cache: FunctionCacheKey -> concrete function.
        self._primary = collections.OrderedDict()

        # When a FunctionCacheKey K is missing from _primary, this table maps
        # K to a key V whose concrete function (stored in _primary) can safely
        # be dispatched to for K, so possible concrete functions can still be
        # found on a primary-cache miss.
        self._dispatch_table = type_dispatch.TypeDispatchTable()
  def testDispatchNoMatches(self):
    """dispatch() returns None for queries that no target can serve."""
    table = type_dispatch.TypeDispatchTable()
    for dims in ((None, 1, None), (None, 1, 2), (None, 2, 2)):
      table.add_target(MockShape(*dims))

    # Every query below is expected to find no compatible target.
    for query_dims in ((1, 2), (1, 2, 3), (1, 2, 3, 4)):
      self.assertIsNone(table.dispatch(MockShape(*query_dims)))
Esempio n. 7
0
 def testHorizontal(self):
     """Targets of different lengths are all kept, in insertion order."""
     added = [(1,), (1, 2), (1, 2, 3)]

     table = type_dispatch.TypeDispatchTable()
     for dims in added:
         table.add_target(MockShape(*dims))

     self.assertEqual(list(table.targets),
                      [MockShape(*dims) for dims in added])
Esempio n. 8
0
 def testDuplicateNodes(self):
     """Re-adding an existing target does not create a second entry."""
     table = type_dispatch.TypeDispatchTable()
     # The last entry repeats the first one.
     for dims in ((None, None), (1, None), (None, 2), (None, None)):
         table.add_target(MockShape(*dims))

     # Only the three distinct shapes remain, in first-insertion order.
     distinct = [(None, None), (1, None), (None, 2)]
     self.assertEqual(list(table.targets),
                      [MockShape(*dims) for dims in distinct])
Esempio n. 9
0
 def testVertical(self):
     """A chain of increasingly specific targets is kept in insertion order."""
     chain = [
         (None, None, None),
         (None, None, 1),
         (None, 1, 1),
         (1, 1, 1),
     ]

     table = type_dispatch.TypeDispatchTable()
     for dims in chain:
         table.add_target(MockShape(*dims))

     self.assertEqual(list(table.targets),
                      [MockShape(*dims) for dims in chain])
  def testDispatchMoreSpecific(self):
    """dispatch() prefers the most specific target compatible with a query."""
    table = type_dispatch.TypeDispatchTable()
    for dims in ((None, None, None), (None, 1, None), (None, 1, 2),
                 (None, 2, 2)):
      table.add_target(MockShape(*dims))

    # (query dims, dims of the target it should dispatch to)
    cases = [
        ((1, 1, 2), (None, 1, 2)),
        ((1, 1, 3), (None, 1, None)),
        ((1, 3, 3), (None, None, None)),
        ((1, 2, 2), (None, 2, 2)),
    ]
    for query_dims, expected_dims in cases:
      self.assertEqual(
          table.dispatch(MockShape(*query_dims)), MockShape(*expected_dims))
Esempio n. 11
0
    def testDispatchExactMatches(self):
        """A query equal to an existing target dispatches to that target."""
        table = type_dispatch.TypeDispatchTable()
        for dims in ((None, None, None), (None, 1, None), (None, 1, 2),
                     (None, 2, 2)):
            table.add_target(MockShape(*dims))

        # Probed in a different order from insertion.
        for dims in ((None, 1, 2), (None, 1, None), (None, None, None),
                     (None, 2, 2)):
            self.assertEqual(table.dispatch(MockShape(*dims)),
                             MockShape(*dims))
  def testDispatchCachedAddUpdates(self):
    """Adding a target changes the result of later dispatches.

    The test name suggests dispatch results are cached; if so, adding a
    target must not leave a stale answer behind.
    """
    table = type_dispatch.TypeDispatchTable()

    # Each new target still matches (1, 1, 2) and is more specific than the
    # previous one, so it should immediately become the dispatch result.
    for dims in ((None, None, None), (None, 1, None), (None, 1, 2),
                 (1, 1, 2)):
      table.add_target(MockShape(*dims))
      self.assertEqual(table.dispatch(MockShape(1, 1, 2)), MockShape(*dims))
  def testContains(self):
    """contains() is true exactly for shapes that were added as targets."""
    added = [(None, None, None), (None, 1), (1, 1), (None, 2, 1)]
    never_added = [(None, None, 1), (1, None), (1, 2), (None, 2, None)]

    table = type_dispatch.TypeDispatchTable()
    for dims in added:
      table.add_target(MockShape(*dims))

    for dims in added:
      self.assertTrue(table.contains(MockShape(*dims)))

    for dims in never_added:
      self.assertFalse(table.contains(MockShape(*dims)))
Esempio n. 14
0
    def testContains(self):
        """`targets` holds exactly the shapes that were added to the table."""
        added = [(None, None, None), (None, 1), (1, 1), (None, 2, 1)]
        never_added = [(None, None, 1), (1, None), (1, 2), (None, 2, None)]

        table = type_dispatch.TypeDispatchTable()
        for dims in added:
            table.add_target(MockShape(*dims))

        for dims in added:
            self.assertIn(MockShape(*dims), table.targets)

        for dims in never_added:
            self.assertNotIn(MockShape(*dims), table.targets)