def test_healthy_tpuvm_cluster(self):
    """Resolve a healthy TPUVM pod and compare it with the expected cluster.

    A READY/HEALTHY 'fake-pod' (v3-32, runtime 'v2-nightly', API
    'V2_ALPHA1') with four network endpoints 10.1.0.0 - 10.1.0.3 on port
    8470 is served through the mocked Cloud TPU client library. The
    resolved cluster must contain one client worker and one service worker
    per endpoint, with 10.1.0.0 as the client master IP.

    The TPUVM metadata mock is installed through a context manager, so the
    previous `ClusterResolver.get_instance_metadata` is put back on every
    exit path, including a failing assertion. The earlier approach of
    starting a second patch as the last statement was skipped on failure.
    """
    endpoint_count = 4
    tpu_map = {
        'fake-pod': {
            'state': 'READY',
            'health': 'HEALTHY',
            'runtime_version': 'v2-nightly',
            'accelerator_type': 'v3-32',
            'api_version': 'V2_ALPHA1',
            'network_endpoints': [{
                'ipAddress': f'10.1.0.{index}',
                'port': '8470',
            } for index in range(endpoint_count)],
        }
    }

    # Using TPUVM flavor of metadata, for the duration of this test only.
    # NOTE(review): on exit the attribute reverts to whatever was installed
    # before the test (presumably the fixture's mock_request_metadata, which
    # the old code re-patched by hand) -- confirm against setUp.
    with mock.patch.object(ClusterResolver, 'get_instance_metadata',
                           mock_request_tpuvm_metadata):
        # The compute service is mocked with empty response maps.
        noop_compute_service = build_mock_compute_service({}, {})
        self.mock_discovery.side_effect = build_mock_services_fn(
            noop_compute_service)
        self.mock_ctc.side_effect = build_mock_cloud_tpu_client_library(
            tpu_map)

        tpus = list(tpu_map.keys())
        cr = ClusterResolver(tpus)
        cluster = cr.get_cluster()

        expected_client_workers = [
            ClientWorker(internal_ip=f'10.1.0.{index}',
                         machine_type='v3-32',
                         zone='fake-zone',
                         hostname=f'{TPUVM_HOSTNAME_PREFIX}{index}')
            for index in range(endpoint_count)
        ]
        expected_service_workers = [
            ServiceWorker(internal_ip=f'10.1.0.{index}',
                          port='8470',
                          machine_type='v3-32',
                          zone='fake-zone',
                          runtime_version='v2-nightly',
                          tpu='fake-pod')
            for index in range(endpoint_count)
        ]
        expected = Cluster(expected_client_workers,
                           expected_service_workers,
                           client_master_ip='10.1.0.0')
        self.assertEqual(expected, cluster)
type=str, help='The python command to launch training including model parameters.' ) FLAGS = parser.parse_args() tpuvm_mode = False accel_type = ClusterResolver.get_instance_metadata( 'instance/attributes/accelerator-type') if re.match(r'v[0-9]+-[0-9]+', accel_type): # Only TPUVM will carry the accelerator-type metadata tpuvm_mode = True if (FLAGS.docker_container or FLAGS.docker_image or FLAGS.docker_run_flag) and FLAGS.conda_env: raise ValueError('Docker Setup arguments and Conda Setup' ' arguments are mutually exclusive.') # Resolve VM and TPU clusters. cluster_resolver = ClusterResolver(FLAGS.tpu, vms=FLAGS.vm, tpuvm_mode=tpuvm_mode) cluster = cluster_resolver.get_cluster() executor = DistributedExecutor(cluster, docker_container=FLAGS.docker_container, docker_image=FLAGS.docker_image, docker_run_flags=FLAGS.docker_run_flag, conda_env=FLAGS.conda_env, env_vars=FLAGS.env, tpuvm_mode=tpuvm_mode) executor.run(FLAGS.positional)
def test_healthy_cluster(self):
    """Resolve a healthy instance group plus TPU pod into a Cluster.

    The mocked compute service reports four RUNNING 'n1-standard-16'
    instances, 'fake-ig-a' through 'fake-ig-d', at 10.0.0.0 - 10.0.0.3.
    The mocked Cloud TPU client reports a READY/HEALTHY 'fake-pod' (v3-32,
    runtime 'pytorch-nightly') whose four endpoints use the same addresses
    on port 8470. The resolver output must equal a Cluster built from the
    matching client and service workers.
    """
    vm_names = [f'fake-ig-{suffix}' for suffix in 'abcd']
    worker_ips = [f'10.0.0.{octet}' for octet in range(4)]

    # Compute API fixtures: the instance-group listing first, then the
    # per-instance lookups keyed by hostname.
    group_listing = {
        'fake-ig': {
            'kind': 'compute#instanceGroupsListInstances',
            'items': [
                gen_fake_ig_list_instances_entry(name, 'RUNNING')
                for name in vm_names
            ],
        },
    }
    instance_lookup = {}
    for name, address in zip(vm_names, worker_ips):
        instance_lookup[name] = gen_fake_instances_get_entry(
            name, 'n1-standard-16', address, 'RUNNING')
    self.mock_discovery.side_effect = build_mock_services_fn(
        build_mock_compute_service(instance_lookup, group_listing))

    # Cloud TPU fixture: one healthy pod with an endpoint per worker IP.
    endpoints = []
    for address in worker_ips:
        endpoints.append({'ipAddress': address, 'port': '8470'})
    tpu_map = {
        'fake-pod': {
            'state': 'READY',
            'health': 'HEALTHY',
            'runtime_version': 'pytorch-nightly',
            'accelerator_type': 'v3-32',
            'network_endpoints': endpoints,
        }
    }
    self.mock_ctc.side_effect = build_mock_cloud_tpu_client_library(tpu_map)

    resolved = ClusterResolver(list(tpu_map)).get_cluster()

    client_side = [
        ClientWorker(internal_ip=address,
                     machine_type='n1-standard-16',
                     zone='fake-zone',
                     hostname=name)
        for name, address in zip(vm_names, worker_ips)
    ]
    service_side = [
        ServiceWorker(internal_ip=address,
                      port='8470',
                      machine_type='v3-32',
                      zone='fake-zone',
                      runtime_version='pytorch-nightly',
                      tpu='fake-pod')
        for address in worker_ips
    ]
    self.assertEqual(Cluster(client_side, service_side), resolved)