Пример #1
0
def get_tpu_strategy():
    """Connect to the TPU cluster, initialize it, and return a TPUStrategy.

    The resolver returned by ``get_tpu_cluster_resolver()`` is used for all
    three steps: connecting, TPU system initialization, and building the
    strategy.
    """
    cluster_resolver = get_tpu_cluster_resolver()
    # The cluster is connected and the TPU system initialized before the
    # strategy is constructed (presumably required by TPUStrategy — confirm).
    remote.connect_to_cluster(cluster_resolver)
    tpu_strategy_util.initialize_tpu_system(cluster_resolver)
    strategy = tpu_lib.TPUStrategy(cluster_resolver)
    return strategy
Пример #2
0
 def _get_strategy(self):
     """Build a TPUStrategy from ``FLAGS.tpu``, ``FLAGS.zone``, ``FLAGS.project``.

     Side effect: the created resolver is stored on ``self.resolver``.

     Returns:
       A ``tpu_strategy.TPUStrategy`` constructed from that resolver, after
       the cluster has been connected and the TPU system initialized.
     """
     resolver = tpu_cluster_resolver.TPUClusterResolver(
         tpu=FLAGS.tpu, zone=FLAGS.zone, project=FLAGS.project)
     self.resolver = resolver
     remote.connect_to_cluster(resolver)
     tpu_strategy_util.initialize_tpu_system(resolver)
     strategy = tpu_strategy.TPUStrategy(resolver)
     return strategy
Пример #3
0
 def testConnectToClusterInGraphModeWillFail(self):
   """Connecting to a cluster with eager execution disabled raises ValueError.

   Eager execution is re-enabled in a ``finally`` clause. This way a failure
   of this test (e.g. no ValueError raised) does not leave the process with
   eager execution disabled for subsequent tests.
   """
   ops.disable_eager_execution()
   try:
     with self.assertRaises(ValueError):
       remote.connect_to_cluster(self._cluster_resolver)
   finally:
     # assertRaises raises AssertionError when no ValueError occurs. Without
     # the finally clause, that would skip restoring eager mode.
     ops.enable_eager_execution()
Пример #4
0
 def testConnectToClusterWithLocalMaster(self):
   """Connecting with an empty ClusterSpec and master='local' must not raise."""
   remote.connect_to_cluster(
       SimpleClusterResolver(ClusterSpec({}), master='local'))
Пример #5
0
 def testConnectToClusterTwiceOk(self):
   """Connecting twice with the same resolver must not raise."""
   for _ in range(2):
     remote.connect_to_cluster(self._cluster_resolver)
Пример #6
0
 def test_cluster_resolver_available(self, enable_packed_var):
     """The strategy exposes the exact resolver object it was built with."""
     # NOTE(review): `enable_packed_var` is not used in this body. If the test
     # is parameterized over it, every variant runs identical code — confirm.
     tpu_resolver = get_tpu_cluster_resolver()
     remote.connect_to_cluster(tpu_resolver)
     tpu_strategy_util.initialize_tpu_system(tpu_resolver)
     strategy = tpu_lib.TPUStrategy(tpu_resolver)
     # Identity (not equality) check: the very same object must be returned.
     self.assertIs(strategy.cluster_resolver, tpu_resolver)
Пример #7
0
 def setUpClass(cls):
     """Create an in-process cluster (num_workers=2, num_ps=2) and connect to it.

     Side effect: stores a SimpleClusterResolver for that cluster on
     ``cls.cluster_resolver``.
     """
     # NOTE(review): no @classmethod decorator is visible here. unittest calls
     # setUpClass on the class itself, so confirm the decorator exists upstream.
     super(VariablePartitioningTest, cls).setUpClass()
     in_process_cluster = multi_worker_test_base.create_in_process_cluster(
         num_workers=2, num_ps=2)
     spec = ClusterSpec(in_process_cluster)
     cls.cluster_resolver = SimpleClusterResolver(spec)
     # The cluster spec obtained from the resolver (not the resolver itself)
     # is what gets passed to connect_to_cluster.
     remote.connect_to_cluster(cls.cluster_resolver.cluster_spec())