def _create_tpu_strategy():
  FLAGS = flags.FLAGS  # pylint: disable=invalid-name
  global _did_connect_to_cluster
  global _topology

  try:
    # Attempt to locally discover the TPU. This will fail for Cloud TPU, in
    # which case we fall back to the values passed as flags.
    resolver = tpu_cluster_resolver.TPUClusterResolver()
    did_automatically_resolve = True
  except ValueError:
    did_automatically_resolve = False

    # These flags will be defined by tpu_test_wrapper.py.
    resolver = tpu_cluster_resolver.TPUClusterResolver(
        tpu=hasattr(FLAGS, "tpu") and FLAGS.tpu or "",
        zone=hasattr(FLAGS, "zone") and FLAGS.zone or None,
        project=hasattr(FLAGS, "project") and FLAGS.project or None,
    )

  # Only connect once per process, rather than per test method.
  if not _did_connect_to_cluster:
    if getattr(FLAGS, "tpu", "") or did_automatically_resolve:
      remote.connect_to_cluster(resolver)
      _did_connect_to_cluster = True
    _topology = tpu_strategy_util.initialize_tpu_system(resolver)

  device_assignment = None
  if use_single_core:
    device_assignment = device_assignment_lib.DeviceAssignment(
        _topology,
        core_assignment=device_assignment_lib.SINGLE_CORE_ASSIGNMENT)

  # Steps per run is only supported in TF 1.x.
  if tf2.enabled():
    strategy = tpu_lib.TPUStrategyV2(
        resolver,
        device_assignment,
        experimental_spmd_xla_partitioning=enable_spmd_xla_paritioning,
        **kwargs)
  else:
    strategy = tpu_lib.TPUStrategyV1(resolver, steps_per_run,
                                     device_assignment, **kwargs)

  if enable_packed_variable and enable_spmd_xla_paritioning:
    raise ValueError("Packed Variable is not compatible with SPMD mode")
  strategy._enable_packed_variable_in_eager_mode = enable_packed_variable  # pylint: disable=protected-access
  return strategy
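# A hedged sketch of how a factory like _create_tpu_strategy() is typically
# registered for test parameterization. `combinations.NamedDistribution` and
# the `required_tpu` keyword mirror the tf.distribute test-combinations
# machinery and are assumptions here, not shown above. Note that
# `use_single_core`, `steps_per_run`, `enable_packed_variable`,
# `enable_spmd_xla_paritioning`, and `kwargs` are closure variables that an
# enclosing builder function must supply.
tpu_combination = combinations.NamedDistribution(
    "TPU", _create_tpu_strategy, required_tpu=True)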
def testEnvironmentAndRpcDetectionForGoogleNamedPort(self):
  cluster_resolver = resolver.TPUClusterResolver(tpu='/bns/ab/cd/ef:port')
  self.assertEqual(cluster_resolver.environment, 'google')
  self.assertEqual(cluster_resolver.rpc_layer, None)
  self.assertEqual(cluster_resolver._tpu,
                   compat.as_bytes('/bns/ab/cd/ef:port'))
def testRetrieveProjectAndZoneFromMetadata(self):
  tpu_map = {
      'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
          'ipAddress': '10.1.2.3',
          'port': '8470',
          'health': 'HEALTHY'
      }
  }

  cluster_resolver = resolver.TPUClusterResolver(
      project=None,
      zone=None,
      tpu=['test-tpu-1'],
      credentials=None,
      service=self.mock_service_client(tpu_map=tpu_map),
      coordinator_name='coordinator')

  actual_cluster_spec = cluster_resolver.cluster_spec()
  expected_proto = """
  job {
    name: 'coordinator'
    tasks { key: 0 value: '10.128.1.2:%s' }
  }
  job {
    name: 'worker'
    tasks { key: 0 value: '10.1.2.3:8470' }
  }
  """ % cluster_resolver._coordinator_port
  self._verifyClusterSpecEquality(actual_cluster_spec, str(expected_proto))

  self.assertEqual(cluster_resolver.master(), 'grpc://10.1.2.3:8470')
def get_tpu_cluster_resolver():
  resolver = tpu_cluster_resolver.TPUClusterResolver(
      tpu=FLAGS.tpu,
      zone=FLAGS.zone,
      project=FLAGS.project,
  )
  return resolver
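# A minimal usage sketch for get_tpu_cluster_resolver(), mirroring the
# connect-then-initialize pattern used in the other snippets here; nothing
# beyond that pattern is assumed.
tpu_resolver = get_tpu_cluster_resolver()
remote.connect_to_cluster(tpu_resolver)
tpu_strategy_util.initialize_tpu_system(tpu_resolver)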
def testGkeEnvironmentForPod(self):
  os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] = ('grpc://10.120.27.5:8470,'
                                                   'grpc://10.120.27.6:8470,'
                                                   'grpc://10.120.27.7:8470,'
                                                   'grpc://10.120.27.8:8470')

  self.assertIn('KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS', os.environ)
  self.assertTrue(resolver.TPUClusterResolver._in_gke())
  self.assertEqual(
      compat.as_bytes('grpc://10.120.27.5:8470,'
                      'grpc://10.120.27.6:8470,'
                      'grpc://10.120.27.7:8470,'
                      'grpc://10.120.27.8:8470'),
      compat.as_bytes(resolver.TPUClusterResolver._gke_endpoints()))

  cluster_resolver = resolver.TPUClusterResolver()
  self.assertEqual(
      compat.as_bytes('grpc://10.120.27.5:8470'),
      compat.as_bytes(cluster_resolver.master()))
  actual_cluster_spec = cluster_resolver.cluster_spec()
  expected_proto = """
  job {
    name: 'worker'
    tasks { key: 0 value: '10.120.27.5:8470' }
    tasks { key: 1 value: '10.120.27.6:8470' }
    tasks { key: 2 value: '10.120.27.7:8470' }
    tasks { key: 3 value: '10.120.27.8:8470' }
  }
  """
  self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)

  del os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS']
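# A hedged sketch of driving the same GKE code path outside a test: export
# the comma-separated endpoint list before constructing the resolver. The
# endpoint addresses below are placeholders.
os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] = (
    'grpc://10.0.0.1:8470,grpc://10.0.0.2:8470')
gke_resolver = resolver.TPUClusterResolver()  # picks up the env var
print(gke_resolver.master())  # expected: grpc://10.0.0.1:8470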
def setUp(self):
  super(TPUEmbeddingCheckpointTest, self).setUp()
  self.resolver = tpu_cluster_resolver.TPUClusterResolver(
      tpu=FLAGS.tpu, zone=FLAGS.zone, project=FLAGS.project)
  remote.connect_to_cluster(self.resolver)
  tpu_strategy_util.initialize_tpu_system(self.resolver)
  self.strategy = tpu_strategy.TPUStrategy(self.resolver)
  self.num_rows = self.strategy.num_replicas_in_sync

  # These tests use two mid level API objects, initialized with different
  # values. These have the same sizes.
  with self.strategy.scope():
    self.first_mid_level_contents = np.ones((self.num_rows, 4))
    self.first_mid_level_optimizer = tpu_embedding_v2_utils.SGD(
        learning_rate=0.1)
    self.first_mid_level = self.build_mid_level(
        self.first_mid_level_contents, self.first_mid_level_optimizer)

    self.second_mid_level_contents = np.ones((self.num_rows, 4)) * 2
    self.second_mid_level_optimizer = tpu_embedding_v2_utils.SGD(
        learning_rate=0.1)
    self.second_mid_level = self.build_mid_level(
        self.second_mid_level_contents,
        self.second_mid_level_optimizer,
        initialize_tpu_embedding=False)

  self.cpu_mid_level_optimizer = tpu_embedding_v2_utils.SGD(
      learning_rate=0.1)
  self.cpu_mid_level = self.build_mid_level(
      self.second_mid_level_contents, self.cpu_mid_level_optimizer)
def tpu_init_tf2(self, tpu_name=None):
  tpu_name = tpu_name or os.environ.get('TPU_NAME', None)
  tpu_cluster_resolver = resolver.TPUClusterResolver(tpu_name)
  service_addr = tpu_cluster_resolver.get_master()
  # The profiler service listens on 8466 rather than the gRPC port 8470.
  self.service_addr = service_addr.replace('grpc://', '').replace(
      ':8470', ':8466')
  self.workers_list = get_workers_list(tpu_cluster_resolver)
  self.monitoring_level = 2
  self.duration_ms = 1000

  util = self.tpu_utilization(self.service_addr, self.duration_ms,
                              self.monitoring_level)
  util = util.split('\n')

  mesh_type = {'v': 'v2', 'cores': 8}
  for stat in util:
    if 'TPU type' in stat:
      tpu = stat.replace('TPU type: TPU', '').strip()
      mesh_type['v'] = tpu
    elif 'Number of TPU cores' in stat:
      idx = stat.find('(')
      mesh_cores = stat[:idx].strip()
      # Use [0-9]+ so multi-digit core counts (e.g. 32) are captured whole;
      # the original [0-9] pattern only returned the first digit.
      mesh_type['cores'] = re.search(r'[0-9]+', mesh_cores).group()

  self.mesh = f'{mesh_type["v"]}-{mesh_type["cores"]}'
  self.tpu_max_mem = _mesh_memory[self.mesh]
  self.profiler_ver = 'v2'
  self.tpu_profiler = self.tpu_util
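# An illustrative input for the parsing loop in tpu_init_tf2 above; the line
# format is inferred from the substring checks in that loop and is an
# assumption, not captured output.
sample_util = ('TPU type: TPU v3\n'
               'Number of TPU cores: 8 (Replica count = 8, ...)')
# Feeding lines like these through the loop yields mesh_type
# {'v': 'v3', 'cores': '8'} and hence self.mesh == 'v3-8'.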
def testSummaryWithCustomTrainingLoop(self):
  resolver = tpu_cluster_resolver.TPUClusterResolver('')
  tpu_strategy_util.initialize_tpu_system(resolver)
  strategy = tpu_strategy_lib.TPUStrategy(resolver)

  with strategy.scope():
    model = distribute_strategy_test.get_model()
    model.compile('sgd', 'mse')
  writer = summary_ops_v2.create_file_writer_v2(self.summary_dir)

  @def_function.function
  def custom_function(dataset):

    def _custom_step(features, labels):
      del labels
      logits = model(features)
      with summary_ops_v2.always_record_summaries(), writer.as_default():
        summary_ops_v2.scalar(
            'logits', logits, step=model.optimizer.iterations)
      return logits

    iterator = iter(dataset)
    output = strategy.unwrap(
        strategy.run(_custom_step, args=(next(iterator))))
    return output

  dataset = strategy.experimental_distribute_dataset(
      distribute_strategy_test.get_dataset(strategy))
  custom_function(dataset)
def _create_tpu_strategy():
  global _did_connect_to_cluster

  # These flags will be defined by tpu_test_wrapper.py.
  resolver = tpu_cluster_resolver.TPUClusterResolver(
      tpu=hasattr(FLAGS, "tpu") and FLAGS.tpu or "",
      zone=hasattr(FLAGS, "zone") and FLAGS.zone or None,
      project=hasattr(FLAGS, "project") and FLAGS.project or None,
  )
  # Only connect once per process, rather than per test method.
  if hasattr(FLAGS, "tpu") and FLAGS.tpu and not _did_connect_to_cluster:
    remote.connect_to_cluster(resolver)
    _did_connect_to_cluster = True

  topology = tpu_strategy_util.initialize_tpu_system(resolver)
  device_assignment = None
  if use_single_core:
    device_assignment = device_assignment_lib.DeviceAssignment(
        topology,
        core_assignment=device_assignment_lib.SINGLE_CORE_ASSIGNMENT)

  # Steps per run is only supported in TF 1.x.
  if tf2.enabled():
    return tpu_lib.TPUStrategy(resolver, device_assignment, **kwargs)
  else:
    return tpu_lib.TPUStrategyV1(resolver, steps_per_run,
                                 device_assignment, **kwargs)
def test_connect(self):
  # Log full diff on failure.
  self.maxDiff = None  # pylint:disable=invalid-name

  self.assertCountEqual(
      EXPECTED_DEVICES_PRE_CONNECT,
      [device.name for device in config.list_logical_devices()])

  resolver = tpu_cluster_resolver.TPUClusterResolver(
      tpu=FLAGS.tpu, zone=FLAGS.zone, project=FLAGS.project)
  remote.connect_to_cluster(resolver)

  expected_devices = EXPECTED_DEVICES_PRE_CONNECT
  for task in range(FLAGS.num_tpu_devices // DEVICES_PER_TASK):
    expected_devices.extend([
        template.format(task=task)
        for template in EXPECTED_NEW_DEVICES_AFTER_CONNECT_TEMPLATES
    ])
  self.assertCountEqual(
      expected_devices,
      [device.name for device in config.list_logical_devices()])

  tpu_strategy_util.initialize_tpu_system(resolver)
def testSimpleSuccessfulRetrieval(self):
  tpu_map = {
      'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
          'ipAddress': '10.1.2.3',
          'port': '8470',
          'health': 'HEALTHY'
      }
  }

  cluster_resolver = resolver.TPUClusterResolver(
      project='test-project',
      zone='us-central1-c',
      tpu=['test-tpu-1'],
      coordinator_name='coordinator',
      coordinator_address='10.128.1.5:10203',
      credentials=None,
      service=self.mock_service_client(tpu_map=tpu_map))

  actual_cluster_spec = cluster_resolver.cluster_spec()
  expected_proto = """
  job {
    name: 'coordinator'
    tasks { key: 0 value: '10.128.1.5:10203' }
  }
  job {
    name: 'worker'
    tasks { key: 0 value: '10.1.2.3:8470' }
  }
  """
  self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
  self.assertEqual(cluster_resolver.master(), 'grpc://10.1.2.3:8470')
def testNewNetworkEndpointFormat(self):
  tpu_map = {
      'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
          'health': 'HEALTHY',
          'networkEndpoints': [{
              'ipAddress': '10.2.3.4',
              'port': 8470,
          }]
      }
  }

  cluster_resolver = resolver.TPUClusterResolver(
      project='test-project',
      zone='us-central1-c',
      tpu='test-tpu-1',
      coordinator_name='coordinator',
      coordinator_address='10.128.1.5:10203',
      credentials=None,
      service=self.mock_service_client(tpu_map=tpu_map))

  actual_cluster_spec = cluster_resolver.cluster_spec()
  expected_proto = """
  job {
    name: 'coordinator'
    tasks { key: 0 value: '10.128.1.5:10203' }
  }
  job {
    name: 'worker'
    tasks { key: 0 value: '10.2.3.4:8470' }
  }
  """
  self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
  self.assertEqual('grpc://10.2.3.4:8470', cluster_resolver.master())
def main(unused_argv=None):
  logging.set_verbosity(logging.INFO)
  tf_version = versions.__version__
  print('TensorFlow version %s detected' % tf_version)
  print('Welcome to the Cloud TPU Profiler v%s' % profiler_version.__version__)

  if LooseVersion(tf_version) < LooseVersion('1.14.0'):
    sys.exit('You must install tensorflow >= 1.14.0 to use this plugin.')

  if not FLAGS.service_addr and not FLAGS.tpu:
    sys.exit('You must specify either --service_addr or --tpu.')

  tpu_cluster_resolver = None
  if FLAGS.service_addr:
    if FLAGS.tpu:
      logging.warning('Both --service_addr and --tpu are set. Ignoring '
                      '--tpu and using --service_addr.')
    service_addr = FLAGS.service_addr
  else:
    try:
      tpu_cluster_resolver = (
          resolver.TPUClusterResolver([FLAGS.tpu],
                                      zone=FLAGS.tpu_zone,
                                      project=FLAGS.gcp_project))
      service_addr = tpu_cluster_resolver.get_master()
    except (ValueError, TypeError):
      sys.exit('Failed to find TPU %s in zone %s project %s. You may use '
               '--tpu_zone and --gcp_project to specify the zone and project '
               'of your TPU.' % (FLAGS.tpu, FLAGS.tpu_zone, FLAGS.gcp_project))
  service_addr = service_addr.replace('grpc://', '').replace(':8470', ':8466')

  workers_list = ''
  if FLAGS.workers_list is not None:
    workers_list = FLAGS.workers_list
  elif tpu_cluster_resolver is not None:
    workers_list = get_workers_list(tpu_cluster_resolver)

  # If the profiling duration was not set by the user or was set to a
  # non-positive value, default it to 1000 ms.
  duration_ms = FLAGS.duration_ms if FLAGS.duration_ms > 0 else 1000

  if FLAGS.monitoring_level > 0:
    print('Since monitoring level is provided, profile', service_addr, 'for',
          duration_ms, 'ms and show metrics for', FLAGS.num_queries,
          'time(s).')
    monitoring_helper(service_addr, duration_ms, FLAGS.monitoring_level,
                      FLAGS.display_timestamp, FLAGS.num_queries)
  else:
    if not FLAGS.logdir:
      sys.exit('You must specify either --logdir or --monitoring_level.')

    try:
      profiler_client.start_tracing(service_addr,
                                    os.path.expanduser(FLAGS.logdir),
                                    duration_ms, workers_list,
                                    FLAGS.include_dataset_ops,
                                    FLAGS.num_tracing_attempts)
    except errors.UnavailableError:
      sys.exit(0)
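# A hypothetical command line for the entry point above. The flag names come
# from the FLAGS referenced in main(); the binary name is an assumption:
#
#   capture_tpu_profile --tpu=my-tpu --tpu_zone=us-central1-c \
#       --gcp_project=my-project --logdir=gs://my-bucket/profiles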
def __init__(self,
             container_strategy,
             tpu_cluster_resolver=None,
             steps_per_run=None,
             device_assignment=None):
  super(TPUExtended, self).__init__(container_strategy)

  if tpu_cluster_resolver is None:
    tpu_cluster_resolver = resolver_lib.TPUClusterResolver("")

  if steps_per_run is None:
    # TODO(frankchn): Warn when we are being used by DS/Keras and this is
    # not specified.
    steps_per_run = 1

  self._tpu_cluster_resolver = tpu_cluster_resolver
  self._tpu_metadata = get_tpu_system_metadata(self._tpu_cluster_resolver)
  self._device_assignment = device_assignment

  # Device assignment is currently only supported for the single-core case.
  if self._device_assignment:
    assert isinstance(self._device_assignment,
                      device_assignment_lib.DeviceAssignment)
    if self._device_assignment.num_replicas != 1:
      raise ValueError("Device assignment is only supported for a single "
                       "core single replica case currently.")
    if self._device_assignment.num_cores_per_replica != 1:
      raise ValueError("Device assignment is only supported for a single "
                       "core single replica case currently.")
    if not all(self._device_assignment.core_assignment[0][0] == [0, 0, 0]):
      raise ValueError("Device assignment is only supported for a single "
                       "core single replica case currently.")

  # TODO(jhseu): Switch to DeviceAssignment to support pods and model
  # parallelism.
  self._device_index = {
      d.name: i
      for i, d in enumerate(self._tpu_metadata.devices)
      if "device:TPU:" in d.name
  }
  self._host_device = self.get_host_cpu_device(0)
  self._tpu_devices = tuple(sorted(self._device_index.keys()))
  # Only create variables for the number of replicas we're running.
  self._tpu_devices = self._tpu_devices[:self._num_replicas_in_sync]
  self._device_map = values.ReplicaDeviceMap(self._tpu_devices)

  # For input:
  input_device_map = values.ReplicaDeviceMap(
      tuple(self.get_host_cpu_device(hid) for hid in range(self.num_hosts)))
  worker_devices = [
      (self.get_host(hid), [self.get_host_cpu_device(hid)])
      for hid in range(self.num_hosts)
  ]
  self._input_workers = input_lib.InputWorkers(input_device_map,
                                               worker_devices)

  # TODO(sourabhbajaj): Remove this once performance of running one step
  # at a time is comparable to multiple steps.
  self.steps_per_run = steps_per_run
  self._require_static_shapes = True
def testNoCallComputeMetadata(self):
  cluster_resolver = resolver.TPUClusterResolver(tpu='grpc://10.1.2.3:8470')
  self.assertEqual('grpc://10.1.2.3:8470', cluster_resolver.master())
  self.assertEqual(
      server_lib.ClusterSpec({
          'worker': ['10.1.2.3:8470']
      }).as_dict(),
      cluster_resolver.cluster_spec().as_dict())
def _get_strategy(self):
  self.resolver = tpu_cluster_resolver.TPUClusterResolver(
      tpu=FLAGS.tpu, zone=FLAGS.zone, project=FLAGS.project)
  remote.connect_to_cluster(self.resolver)
  tpu_strategy_util.initialize_tpu_system(self.resolver)
  strategy = tpu_strategy.TPUStrategy(self.resolver)
  self.num_replicas = strategy.num_replicas_in_sync
  return strategy
def get_tpu_strategy():
  resolver = tpu_cluster_resolver.TPUClusterResolver(
      tpu=FLAGS.tpu,
      zone=FLAGS.zone,
      project=FLAGS.project,
  )
  remote.connect_to_cluster(resolver)
  tpu_strategy_util.initialize_tpu_system(resolver)
  return tpu_lib.TPUStrategy(resolver)
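# A minimal usage sketch for get_tpu_strategy() above: variables are created
# under the strategy scope and per-replica results are reduced back on the
# host. Only public tf.distribute APIs are used; the tiny computation is
# illustrative, not taken from the original code.
import tensorflow as tf

strategy = get_tpu_strategy()
with strategy.scope():
  v = tf.Variable(1.0)

@tf.function
def add_one():
  # Runs on every replica and returns a per-replica value.
  return strategy.run(lambda: v + 1.0)

# Sum the per-replica results; expected: 2.0 * strategy.num_replicas_in_sync.
total = strategy.reduce(tf.distribute.ReduceOp.SUM, add_one(), axis=None)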
def _get_strategy(self):
  self.resolver = tpu_cluster_resolver.TPUClusterResolver(
      tpu=FLAGS.tpu, zone=FLAGS.zone, project=FLAGS.project)
  if hasattr(self.resolver, '_cloud_tpu_client'):
    self.resolver._cloud_tpu_client.configure_tpu_version(
        version='nightly', restart_type='always')
  remote.connect_to_cluster(self.resolver)
  tpu_strategy_util.initialize_tpu_system(self.resolver)
  return tpu_strategy.TPUStrategy(self.resolver)
def testNumAcceleratorsRetryFailure(self, mock_list_devices,
                                    mock_eager_list_devices):
  # The two mock arguments are injected by mock.patch decorators on this
  # test (not shown in this excerpt).
  cluster_resolver = resolver.TPUClusterResolver(tpu='')
  mock_list_devices.side_effect = errors.DeadlineExceededError(
      None, None, 'timeout')
  mock_eager_list_devices.side_effect = errors.DeadlineExceededError(
      None, None, 'timeout')
  with self.assertRaises(RuntimeError):
    cluster_resolver.num_accelerators()
def verifyShouldResolve(self, tpu, should_resolve):
  cluster_resolver = resolver.TPUClusterResolver(
      project='test-project',
      zone='us-central1-c',
      tpu=tpu,
      coordinator_name=None,
      credentials=None,
      service=self.mock_service_client(tpu_map={}))
  self.assertEqual(should_resolve, cluster_resolver._should_resolve(),
                   "TPU: '%s'" % tpu)
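# Hypothetical callers of the helper above, with expected outcomes assumed
# from the resolver behavior shown elsewhere in these tests (literal grpc://
# addresses and /bns/ paths bypass the Cloud TPU API, bare names resolve):
#   self.verifyShouldResolve('test-tpu', True)
#   self.verifyShouldResolve('grpc://10.1.2.3:8470', False)
#   self.verifyShouldResolve('/bns/foo/bar', False)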
def get_strategy():
  resolver = tpu_cluster_resolver.TPUClusterResolver(
      tpu="grpc://" + os.environ["TPU_IP"])
  remote.connect_to_cluster(resolver)
  topology = tpu_strategy_util.initialize_tpu_system(resolver)
  print("Device coordinates: ", topology.device_coordinates)
  # `tf.python` is not part of the public API; use the device_assignment
  # module directly, as the other snippets here do via device_assignment_lib.
  device_assignment = device_assignment_lib.DeviceAssignment.build(
      topology, computation_shape=[1, 1, 1, 1], num_replicas=1)
  return tpu_strategy.TPUStrategy(resolver,
                                  device_assignment=device_assignment)
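# A short, hypothetical check for get_strategy() above: with a one-replica,
# one-core DeviceAssignment the strategy should report a single replica in
# sync (the expected value follows from num_replicas=1 in the build call).
strategy = get_strategy()
print(strategy.num_replicas_in_sync)  # expected: 1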
def testNoCallComputeMetadata(self):
  cluster_resolver = resolver.TPUClusterResolver(tpu='/bns/foo/bar')
  self.assertEqual('/bns/foo/bar', cluster_resolver.master())
  if ops.executing_eagerly_outside_functions():
    self.assertEqual(
        server_lib.ClusterSpec({
            'worker': ['/bns/foo/bar']
        }).as_dict(),
        cluster_resolver.cluster_spec().as_dict())
  else:
    self.assertEqual(None, cluster_resolver.cluster_spec())
def testGetMasterNoEntries(self):
  tpu_map = {}

  with self.assertRaises(ValueError):
    resolver.TPUClusterResolver(
        project='test-project',
        zone='us-central1-c',
        tpu=[],
        coordinator_name=None,
        credentials=None,
        service=self.mock_service_client(tpu_map=tpu_map))
def test_connect(self):
  self.assertCountEqual(EXPECTED_DEVICES_PRE_CONNECT,
                        config.list_logical_devices())

  resolver = tpu_cluster_resolver.TPUClusterResolver(
      tpu=FLAGS.tpu, zone=FLAGS.zone, project=FLAGS.project)
  remote.connect_to_cluster(resolver)
  self.assertCountEqual(EXPECTED_DEVICES_AFTER_CONNECT,
                        config.list_logical_devices())

  tpu_strategy_util.initialize_tpu_system(resolver)
def test_multiple_initialize_system(self):
  resolver = tpu_cluster_resolver.TPUClusterResolver(
      tpu=FLAGS.tpu,
      zone=FLAGS.zone,
      project=FLAGS.project,
  )
  tpu_strategy_util.initialize_tpu_system(resolver)

  with test.mock.patch.object(logging, "warning") as mock_log:
    tpu_strategy_util.initialize_tpu_system(resolver)
    self.assertRegex(str(mock_log.call_args), "already been initialized")
def testNumAcceleratorsSuccess(self, mock_list_devices,
                               mock_eager_list_devices):
  devices = [
      LogicalDevice('/job:tpu_worker/task:0/device:TPU:0', 'TPU'),
      LogicalDevice('/job:tpu_worker/task:1/device:TPU:1', 'TPU'),
      LogicalDevice('/job:tpu_worker/task:2/device:TPU:0', 'TPU'),
      LogicalDevice('/job:tpu_worker/task:3/device:TPU:1', 'TPU'),
      LogicalDevice('/job:tpu_worker/task:0/device:TPU:4', 'TPU'),
      LogicalDevice('/job:tpu_worker/task:1/device:TPU:5', 'TPU'),
      LogicalDevice('/job:tpu_worker/task:2/device:TPU:4', 'TPU'),
      LogicalDevice('/job:tpu_worker/task:3/device:TPU:5', 'TPU'),
  ]
  device_list = [
      session._DeviceAttributes(d.name, d.device_type, 1024, 0)
      for d in devices
  ]
  mock_eager_list_devices.return_value = devices
  mock_list_devices.return_value = device_list

  tpu_map = {
      'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
          'state': 'READY',
          'health': 'HEALTHY',
          'networkEndpoints': [
              {'ipAddress': '10.2.3.4', 'port': 8470},
              {'ipAddress': '10.2.3.5', 'port': 8470},
              {'ipAddress': '10.2.3.6', 'port': 8470},
              {'ipAddress': '10.2.3.7', 'port': 8470},
          ]
      }
  }
  cluster_resolver = resolver.TPUClusterResolver(
      project='test-project',
      zone='us-central1-c',
      tpu='test-tpu-1',
      service=self.mock_service_client(tpu_map=tpu_map))
  # Each worker task exposes two TPU devices above, so the per-worker count
  # is 2.
  self.assertEqual(cluster_resolver.num_accelerators(), {'TPU': 2})
def _create_tpu_strategy():
  resolver = tpu_cluster_resolver.TPUClusterResolver("")
  topology = tpu_strategy_util.initialize_tpu_system(resolver)
  device_assignment = None
  if use_single_core:
    device_assignment = device_assignment_lib.DeviceAssignment(
        topology,
        core_assignment=device_assignment_lib.SINGLE_CORE_ASSIGNMENT)
  strategy = tpu_lib.TPUStrategy(
      resolver,
      steps_per_run=steps_per_run,
      device_assignment=device_assignment,
      **kwargs)
  return strategy
def testPodResolution(self):
  tpu_map = {
      'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
          'state': 'READY',
          'health': 'HEALTHY',
          'networkEndpoints': [
              {'ipAddress': '10.2.3.4', 'port': 8470},
              {'ipAddress': '10.2.3.5', 'port': 8470},
              {'ipAddress': '10.2.3.6', 'port': 8470},
              {'ipAddress': '10.2.3.7', 'port': 8470},
          ]
      }
  }

  cluster_resolver = resolver.TPUClusterResolver(
      tpu='test-tpu-1',
      credentials=None,
      service=self.mock_service_client(tpu_map=tpu_map),
      coordinator_name='coordinator')

  actual_cluster_spec = cluster_resolver.cluster_spec()
  expected_proto = """
  job {
    name: 'coordinator'
    tasks { key: 0 value: '10.128.1.2:%s' }
  }
  job {
    name: 'worker'
    tasks { key: 0 value: '10.2.3.4:8470' }
    tasks { key: 1 value: '10.2.3.5:8470' }
    tasks { key: 2 value: '10.2.3.6:8470' }
    tasks { key: 3 value: '10.2.3.7:8470' }
  }
  """ % cluster_resolver._coordinator_port
  self._verifyClusterSpecEquality(actual_cluster_spec, str(expected_proto))
  self.assertEqual(cluster_resolver.master(), 'grpc://10.2.3.4:8470')
def __init__(self,
             container_strategy,
             tpu_cluster_resolver=None,
             steps_per_run=None,
             num_cores=None):
  super(TPUExtended, self).__init__(container_strategy)

  if tpu_cluster_resolver is None:
    tpu_cluster_resolver = resolver_lib.TPUClusterResolver("")

  if steps_per_run is None:
    # TODO(frankchn): Warn when we are being used by DS/Keras and this is
    # not specified.
    steps_per_run = 1

  self._tpu_cluster_resolver = tpu_cluster_resolver
  self._tpu_metadata = get_tpu_system_metadata(self._tpu_cluster_resolver)

  # TODO(sourabhbajaj): Change this from num_cores to metadata_override
  self._num_cores_override = num_cores

  # TODO(jhseu): Switch to DeviceAssignment to support pods and model
  # parallelism.
  self._device_index = {
      d.name: i
      for i, d in enumerate(self._tpu_metadata.devices)
      if "device:TPU:" in d.name
  }
  self._host_device = self.get_host_cpu_device(0)
  self._tpu_devices = tuple(sorted(self._device_index.keys()))
  # Only create variables for the number of replicas we're running.
  self._tpu_devices = self._tpu_devices[:self._num_replicas_in_sync]
  self._device_map = values.ReplicaDeviceMap(self._tpu_devices)

  # For input:
  input_device_map = values.ReplicaDeviceMap(
      tuple(self.get_host_cpu_device(hid) for hid in range(self.num_hosts)))
  worker_devices = [
      (self.get_host(hid), [self.get_host_cpu_device(hid)])
      for hid in range(self.num_hosts)
  ]
  self._input_workers = values.InputWorkers(input_device_map, worker_devices)

  # TODO(sourabhbajaj): Remove this once performance of running one step
  # at a time is comparable to multiple steps.
  self.steps_per_run = steps_per_run
  self._require_static_shapes = True

  # Initialize the TPU devices.
  self._initialize_tpu()
def _create_tpu_strategy():
  global _did_connect_to_cluster

  try:
    # Attempt to locally discover the TPU. This will fail for Cloud TPU, in
    # which case we fall back to the values passed as flags.
    resolver = tpu_cluster_resolver.TPUClusterResolver()
    did_automatically_resolve = True
  except ValueError:
    did_automatically_resolve = False

    # These flags will be defined by tpu_test_wrapper.py.
    resolver = tpu_cluster_resolver.TPUClusterResolver(
        tpu=hasattr(FLAGS, "tpu") and FLAGS.tpu or "",
        zone=hasattr(FLAGS, "zone") and FLAGS.zone or None,
        project=hasattr(FLAGS, "project") and FLAGS.project or None,
    )

  # Only connect once per process, rather than per test method.
  if getattr(FLAGS, "tpu", "") or did_automatically_resolve:
    if not _did_connect_to_cluster:
      remote.connect_to_cluster(resolver)
      _did_connect_to_cluster = True

  topology = tpu_strategy_util.initialize_tpu_system(resolver)
  device_assignment = None
  if use_single_core:
    device_assignment = device_assignment_lib.DeviceAssignment(
        topology,
        core_assignment=device_assignment_lib.SINGLE_CORE_ASSIGNMENT)

  # Steps per run is only supported in TF 1.x.
  if tf2.enabled():
    return tpu_lib.TPUStrategy(resolver, device_assignment, **kwargs)
  else:
    return tpu_lib.TPUStrategyV1(resolver, steps_per_run,
                                 device_assignment, **kwargs)