def get_tpu_strategy():
    """Return a TPUStrategy built on the resolver from get_tpu_cluster_resolver().

    Before the strategy is constructed, the current process is connected to the
    resolver's cluster and the TPU system is initialized with that resolver.
    """
    cluster_resolver = get_tpu_cluster_resolver()
    remote.connect_to_cluster(cluster_resolver)
    tpu_strategy_util.initialize_tpu_system(cluster_resolver)
    strategy = tpu_lib.TPUStrategy(cluster_resolver)
    return strategy
def _get_strategy(self):
    """Build a TPUStrategy from the ``--tpu``, ``--zone`` and ``--project`` flags.

    Side effect: the resolver is stored on ``self.resolver`` so later code in
    the test can reuse it. The process is connected to the resolver's cluster
    and the TPU system is initialized before the strategy is returned.
    """
    resolver = tpu_cluster_resolver.TPUClusterResolver(
        tpu=FLAGS.tpu,
        zone=FLAGS.zone,
        project=FLAGS.project,
    )
    self.resolver = resolver
    remote.connect_to_cluster(resolver)
    tpu_strategy_util.initialize_tpu_system(resolver)
    return tpu_strategy.TPUStrategy(resolver)
def testConnectToClusterInGraphModeWillFail(self):
    """connect_to_cluster must raise ValueError while eager execution is off.

    Eager execution is re-enabled in a ``finally`` clause. Without it, a
    failing assertion (no ValueError raised) or any unexpected exception would
    skip the re-enable call. Graph mode would then stay on for every test run
    afterwards in the same process.
    """
    ops.disable_eager_execution()
    try:
        with self.assertRaises(ValueError):
            remote.connect_to_cluster(self._cluster_resolver)
    finally:
        ops.enable_eager_execution()
def testConnectToClusterWithLocalMaster(self):
    """connect_to_cluster accepts a resolver with an empty spec and master 'local'."""
    empty_spec = ClusterSpec({})
    resolver = SimpleClusterResolver(empty_spec, master='local')
    remote.connect_to_cluster(resolver)
def testConnectToClusterTwiceOk(self):
    """Calling connect_to_cluster twice with the same resolver must not raise."""
    for _ in range(2):
        remote.connect_to_cluster(self._cluster_resolver)
def test_cluster_resolver_available(self, enable_packed_var):
    """A TPUStrategy exposes the exact resolver object it was built from.

    NOTE(review): ``enable_packed_var`` is accepted (presumably supplied by a
    parameterization decorator outside this view) but this body never reads
    it. Confirm whether it was meant to configure the strategy.
    """
    cluster_resolver = get_tpu_cluster_resolver()
    remote.connect_to_cluster(cluster_resolver)
    tpu_strategy_util.initialize_tpu_system(cluster_resolver)
    strategy = tpu_lib.TPUStrategy(cluster_resolver)
    # Identity, not equality: the strategy must hold the same object.
    self.assertIs(strategy.cluster_resolver, cluster_resolver)
def setUpClass(cls):
    """Create a 2-worker / 2-PS cluster once for the class and connect to it.

    The resolver is stored on ``cls.cluster_resolver``. Only its ClusterSpec
    (not the resolver itself) is passed to ``remote.connect_to_cluster``.
    """
    super(VariablePartitioningTest, cls).setUpClass()
    in_process_cluster = multi_worker_test_base.create_in_process_cluster(
        num_workers=2, num_ps=2)
    spec = ClusterSpec(in_process_cluster)
    resolver = SimpleClusterResolver(spec)
    cls.cluster_resolver = resolver
    remote.connect_to_cluster(resolver.cluster_spec())