Example #1
0
    def test_healthy_tpuvm_cluster(self):
        """A healthy TPUVM pod resolves into client and service workers.

        The compute service is stubbed with no instances or instance groups,
        so every client worker is expected to be derived from the TPU pod's
        own network endpoints (hostname built from TPUVM_HOSTNAME_PREFIX,
        machine type equal to the accelerator type). Each endpoint is also
        expected as a ServiceWorker, and the first endpoint ('10.1.0.0') is
        expected to be the client master.
        """
        # Scope the TPUVM-flavored metadata to this test with a context
        # manager: the previous get_instance_metadata is restored on exit
        # even if an assertion below fails, so the TPUVM metadata cannot leak
        # into other tests and no un-stopped patchers accumulate.
        with mock.patch.object(ClusterResolver, 'get_instance_metadata',
                               mock_request_tpuvm_metadata):
            # Empty instance / instance-group maps: no GCE VMs are available
            # to act as clients.
            noop_compute_service = build_mock_compute_service({}, {})
            self.mock_discovery.side_effect = build_mock_services_fn(
                noop_compute_service)

            tpu_map = {
                'fake-pod': {
                    'state': 'READY',
                    'health': 'HEALTHY',
                    'runtime_version': 'v2-nightly',
                    'accelerator_type': 'v3-32',
                    'api_version': 'V2_ALPHA1',
                    'network_endpoints': [{
                        'ipAddress': f'10.1.0.{index}',
                        'port': '8470',
                    } for index in range(4)],
                }
            }
            self.mock_ctc.side_effect = build_mock_cloud_tpu_client_library(
                tpu_map)

            cr = ClusterResolver(list(tpu_map))
            cluster = cr.get_cluster()

            expected_client_workers = [
                ClientWorker(internal_ip=f'10.1.0.{index}',
                             machine_type='v3-32',
                             zone='fake-zone',
                             hostname=f'{TPUVM_HOSTNAME_PREFIX}{index}')
                for index in range(4)
            ]
            expected_service_workers = [
                ServiceWorker(internal_ip=f'10.1.0.{index}',
                              port='8470',
                              machine_type='v3-32',
                              zone='fake-zone',
                              runtime_version='v2-nightly',
                              tpu='fake-pod') for index in range(4)
            ]
            expected = Cluster(expected_client_workers,
                               expected_service_workers,
                               client_master_ip='10.1.0.0')
            self.assertEqual(expected, cluster)
Example #2
0
        type=str,
        help='The python command to launch training including model parameters.'
    )

    FLAGS = parser.parse_args()
    tpuvm_mode = False
    accel_type = ClusterResolver.get_instance_metadata(
        'instance/attributes/accelerator-type')
    if re.match(r'v[0-9]+-[0-9]+', accel_type):
        # Only TPUVM will carry the accelerator-type metadata
        tpuvm_mode = True

    if (FLAGS.docker_container or FLAGS.docker_image
            or FLAGS.docker_run_flag) and FLAGS.conda_env:
        raise ValueError('Docker Setup arguments and Conda Setup'
                         ' arguments are mutually exclusive.')

    # Resolve VM and TPU clusters.
    cluster_resolver = ClusterResolver(FLAGS.tpu,
                                       vms=FLAGS.vm,
                                       tpuvm_mode=tpuvm_mode)
    cluster = cluster_resolver.get_cluster()
    executor = DistributedExecutor(cluster,
                                   docker_container=FLAGS.docker_container,
                                   docker_image=FLAGS.docker_image,
                                   docker_run_flags=FLAGS.docker_run_flag,
                                   conda_env=FLAGS.conda_env,
                                   env_vars=FLAGS.env,
                                   tpuvm_mode=tpuvm_mode)
    executor.run(FLAGS.positional)
Example #3
0
    def test_healthy_cluster(self):
        """Instance-group VMs plus a TPU pod resolve into a matching Cluster.

        Four RUNNING VMs (fake-ig-a .. fake-ig-d) in instance group 'fake-ig'
        are expected as the client workers, and the four network endpoints of
        the READY/HEALTHY 'fake-pod' TPU are expected as the service workers.
        """
        vm_names = [f'fake-ig-{suffix}' for suffix in 'abcd']
        vm_ips = [f'10.0.0.{octet}' for octet in '0123']

        # Instance-group listing and per-instance details for the client VMs.
        list_instances_map = {
            'fake-ig': {
                'kind': 'compute#instanceGroupsListInstances',
                'items': [
                    gen_fake_ig_list_instances_entry(name, 'RUNNING')
                    for name in vm_names
                ],
            },
        }
        instance_resp_map = {
            name: gen_fake_instances_get_entry(name, 'n1-standard-16', vm_ip,
                                               'RUNNING')
            for name, vm_ip in zip(vm_names, vm_ips)
        }
        self.mock_discovery.side_effect = build_mock_services_fn(
            build_mock_compute_service(instance_resp_map, list_instances_map))

        # The TPU pod exposes one endpoint per 10.0.0.x address.
        pod_endpoints = [{'ipAddress': addr, 'port': '8470'} for addr in vm_ips]
        tpu_map = {
            'fake-pod': {
                'state': 'READY',
                'health': 'HEALTHY',
                'runtime_version': 'pytorch-nightly',
                'accelerator_type': 'v3-32',
                'network_endpoints': pod_endpoints,
            }
        }
        self.mock_ctc.side_effect = build_mock_cloud_tpu_client_library(
            tpu_map)

        resolver = ClusterResolver(list(tpu_map))
        actual_cluster = resolver.get_cluster()

        clients = []
        for name, vm_ip in zip(vm_names, vm_ips):
            clients.append(
                ClientWorker(internal_ip=vm_ip,
                             machine_type='n1-standard-16',
                             zone='fake-zone',
                             hostname=name))
        services = []
        for addr in vm_ips:
            services.append(
                ServiceWorker(internal_ip=addr,
                              port='8470',
                              machine_type='v3-32',
                              zone='fake-zone',
                              runtime_version='pytorch-nightly',
                              tpu='fake-pod'))
        self.assertEqual(Cluster(clients, services), actual_cluster)