def _resource_creator_scope(self):
  """Returns resource creator scopes that wrap lookup tables for this strategy.

  Before building the scopes, makes sure the container strategy has a cluster
  coordinator. If the strategy's `_cluster_coordinator` is not yet set, one is
  constructed while holding `self._coordinator_creation_lock`.

  Returns:
    A list of two `ops.resource_creator_scope(...)` objects (not yet entered):
    - one for "StaticHashTable", which wraps creation in
      `ps_values.RestoredDistributedTable` while inside a SavedModel load
      context and in `ps_values.DistributedTable` otherwise;
    - one for "RestoredStaticHashTable", which always wraps creation in
      `ps_values.RestoredDistributedTable`.
  """
  with self._coordinator_creation_lock:
    # Holding the lock across both the check and the construction serializes
    # concurrent callers. NOTE(review): the ClusterCoordinator constructor
    # presumably attaches itself to the strategy as `_cluster_coordinator`,
    # since its return value is discarded here; confirm.
    if not self._container_strategy()._cluster_coordinator:  # pylint: disable=protected-access
      cluster_coordinator.ClusterCoordinator(
          strategy=self._container_strategy())

  # TODO(wxinyi): We should warn the user of the inefficiency of creating
  # `StaticHashTable` inside a `@tf.function`-wrapped `dataset_fn` to be
  # distributed with `distribute_datasets_from_function` and
  # `create_per_worker_dataset`. This is because the `dataset_fn` does not
  # use the same `default_graph` as `scope` to which the
  # `resource_creator_stack` belongs. Thus, `StaticHashTable` creation inside
  # `dataset_fn` is not intercepted. And since its resource creation under a
  # `tf.function` is lifted out, all workers will share the same resource on
  # the coordinator which incurs worker-coordinator communication overhead.

  def lookup_creator(next_creator, *args, **kwargs):
    """Wraps "StaticHashTable" creation; the table is built lazily."""
    # `in_load_context` must be *called*. Testing the bare function object is
    # always truthy, which made the `DistributedTable` branch unreachable.
    if load_context.in_load_context():
      return ps_values.RestoredDistributedTable(  # pylint: disable=protected-access
          self._container_strategy(), lambda: next_creator(*args, **kwargs))
    return ps_values.DistributedTable(  # pylint: disable=protected-access
        self._container_strategy(), lambda: next_creator(*args, **kwargs))

  def restored_lookup_creator(next_creator, *args, **kwargs):
    """Wraps "RestoredStaticHashTable" creation; the table is built lazily."""
    return ps_values.RestoredDistributedTable(  # pylint: disable=protected-access
        self._container_strategy(), lambda: next_creator(*args, **kwargs))

  return [
      ops.resource_creator_scope("StaticHashTable", lookup_creator),
      ops.resource_creator_scope("RestoredStaticHashTable",
                                 restored_lookup_creator),
  ]
def testResourceCreatorNestingError(self):
  """Exiting an outer resource creator scope out of order raises RuntimeError.

  The outer scope is entered manually and then exited while an inner scope for
  the same resource type is still active. That `__exit__` call is expected to
  raise `RuntimeError`.
  """

  def creator(next_creator, *a, **kwargs):
    return next_creator(*a, **kwargs)

  # Save the state so we can clean up at the end.
  graph = ops.get_default_graph()
  old_creator_stack = graph._resource_creator_stack["_DummyResource"]

  try:
    # Pass (resource_type, creator), the same order as every other
    # `resource_creator_scope` call. With the arguments swapped, the scope did
    # not register under "_DummyResource", so the `finally` clause below did
    # not undo what the scope changed.
    scope = ops.resource_creator_scope("_DummyResource", creator)
    scope.__enter__()
    with ops.resource_creator_scope("_DummyResource", creator):
      with self.assertRaises(RuntimeError):
        scope.__exit__(None, None, None)
  finally:
    # The outer scope's __exit__ raised instead of completing, so restore the
    # saved "_DummyResource" entry by hand.
    graph._resource_creator_stack["_DummyResource"] = old_creator_stack
def testResourceCreatorNesting(self):
  """Creators from nested scopes for one resource type are both applied.

  The outer creator stamps `_value = 1` onto the built resource. The inner
  creator overrides the requested `handle_name`. The resulting resource must
  reflect both.
  """

  def stamp_value_creator(next_creator, *args, **kwargs):
    # Build the resource first, then attach the marker attribute.
    built = next_creator(*args, **kwargs)
    built._value = 1
    return built

  def force_name_creator(next_creator, *args, **kwargs):
    # Discard whatever handle_name the caller asked for.
    kwargs["handle_name"] = "forced_name1"
    return next_creator(*args, **kwargs)

  outer_scope = ops.resource_creator_scope(["_DummyResource1"],
                                           stamp_value_creator)
  inner_scope = ops.resource_creator_scope(["_DummyResource1"],
                                           force_name_creator)
  # A multi-manager `with` enters left to right, exactly like nesting.
  with outer_scope, inner_scope:
    resource = _DummyResource1(handle_name="fake_name")
    self.assertEqual(resource._handle_name, "forced_name1")
    self.assertEqual(resource._value, 1)
def testResourceCreator(self):
  """A single creator registered for two resource types applies to both."""

  def rename_creator(next_creator, *args, **kwargs):
    # Replace the caller's handle_name with a fixed one.
    kwargs["handle_name"] = "forced_name"
    return next_creator(*args, **kwargs)

  # Register one creator function for two resource classes at once.
  shared_types = ["_DummyResource", "_DummyResource1"]
  with ops.resource_creator_scope(shared_types, rename_creator):
    created = [
        _DummyResource(handle_name="fake_name_0"),
        _DummyResource1(handle_name="fake_name_1"),
    ]
    for resource in created:
      self.assertEqual(resource._handle_name, "forced_name")