Esempio n. 1
0
 def testDuplicate(self):
     """Registering under an already-used name raises a descriptive KeyError.

     The message is expected to name the registry ('testbar') and the key
     ('Bar'), and to point at the file/line of the earlier registration.
     """
     myreg = registry.Registry('testbar')
     myreg.register(bar, 'Bar')
     # assertRaisesRegexp is a deprecated alias that was removed in Python
     # 3.12; assertRaisesRegex is the supported spelling. The '.' before 'py'
     # is escaped so it matches the literal file extension only.
     with self.assertRaisesRegex(
             KeyError, r"Registering two testbar with name 'Bar'! "
             r'\(Previous registration was in [^ ]+ .*\.py:[0-9]+\)'):
         myreg.register(bar, 'Bar')
Esempio n. 2
0
 def testRegistryBasics(self, candidate):
     """Checks lookup/register/list on a fresh registry for `candidate`.

     Lookups fail on an empty registry. After registering without a name the
     candidate is expected under its `__name__`; after registering with an
     explicit name it is also found under 'testKey', and `list()` reports
     both keys.
     """
     reg = registry.Registry('testRegistry')
     # Nothing registered yet, so the lookup must fail.
     with self.assertRaises(LookupError):
         reg.lookup('testKey')

     implicit_key = candidate.__name__
     reg.register(candidate)
     self.assertEqual(reg.lookup(implicit_key), candidate)

     reg.register(candidate, 'testKey')
     self.assertEqual(reg.lookup('testKey'), candidate)

     # Compare sorted, since list() ordering is not asserted by this test.
     expected = sorted([implicit_key, 'testKey'])
     self.assertEqual(sorted(reg.list()), expected)
Esempio n. 3
0
 def testDuplicate(self):
     """A second registration under the name 'Bar' must raise KeyError."""
     reg = registry.Registry('testbar')
     reg.register(bar, 'Bar')
     # Callable form of assertRaises: invokes reg.register(bar, 'Bar').
     self.assertRaises(KeyError, reg.register, bar, 'Bar')
Esempio n. 4
0
 def testRegisterFunction(self):
     """A function registered as 'Bar' is returned by lookup('Bar').

     Before registration the lookup must raise LookupError.
     """
     myreg = registry.Registry('testbar')
     with self.assertRaises(LookupError):
         myreg.lookup('Bar')
     myreg.register(bar, 'Bar')
     # Use the TestCase assertion rather than a bare `assert`: bare asserts
     # are stripped under `python -O` and give no diff on failure.
     self.assertEqual(myreg.lookup('Bar'), bar)
Esempio n. 5
0
 def testRegisterClass(self):
     """A class registered as 'Foo' is returned by lookup('Foo').

     Before registration the lookup must raise LookupError.
     """
     myreg = registry.Registry('testfoo')
     with self.assertRaises(LookupError):
         myreg.lookup('Foo')
     myreg.register(RegistryTest.Foo, 'Foo')
     # Use the TestCase assertion rather than a bare `assert`: bare asserts
     # are stripped under `python -O` and give no diff on failure.
     self.assertEqual(myreg.lookup('Foo'), RegistryTest.Foo)
Esempio n. 6
0
                              expand_composites=True):
            if self.ops_which_must_run:
                updated_ops_which_must_run = []
                if r.graph.building_function:
                    updated_ops_which_must_run = self.ops_which_must_run
                else:
                    updated_ops_which_must_run = [
                        o for o in self.ops_which_must_run if
                        o._control_flow_context is r.op._control_flow_context
                    ]
                r.op._add_control_inputs(updated_ops_which_must_run)

        self.collective_manager_ids_used = collective_manager_scopes_used


# Module-level registry named "acd_resource_resolvers". Presumably populated
# by `register_acd_resource_resolver` (its body is not visible here) with
# functions that resolve the resources an op touches -- TODO confirm.
_acd_resource_resolvers_registry = registry.Registry("acd_resource_resolvers")


def register_acd_resource_resolver(f):
    """Register a function for resolving resources touched by an op.

  `f` is called for every Operation added in the ACD context with the op's
  original resource reads and writes. `f` is expected to update the sets of
  resource reads and writes in-place and return True if it updated either of the
  sets, False otherwise.

  Example:
  @register_acd_resource_resolver
  def ResolveIdentity(op, resource_reads, resource_writes):
    # op: The `Operation` being processed by ACD currently.
    # resource_reads: An `ObjectIdentitySet` of read-only resources.