Exemple #1
0
    def test_healthy_pod_service_cluster(self):
        """A READY, HEALTHY v3-32 pod yields one ServiceWorker per endpoint."""
        # Arrange: a single pod exposing four network endpoints.
        endpoints = [
            {'ipAddress': f'10.0.0.{host}', 'port': '8470'}
            for host in range(4)
        ]
        fake_tpus = {
            'fake-pod': {
                'state': 'READY',
                'health': 'HEALTHY',
                'runtime_version': 'pytorch-nightly',
                'accelerator_type': 'v3-32',
                'network_endpoints': endpoints,
            },
        }
        self.mock_ctc.side_effect = build_mock_cloud_tpu_client_library(
            fake_tpus)

        # Act
        resolver = ClusterResolver(list(fake_tpus))
        resolved = resolver.get_service_workers()

        # Assert: worker order is not significant, so compare as multisets.
        wanted = []
        for host in range(4):
            wanted.append(
                ServiceWorker(internal_ip=f'10.0.0.{host}',
                              port='8470',
                              machine_type='v3-32',
                              zone='fake-zone',
                              runtime_version='pytorch-nightly',
                              tpu='fake-pod'))
        self.assertCountEqual(wanted, resolved)
Exemple #2
0
    def test_healthy_vm_list_client_cluster(self):
        """Explicitly listed RUNNING VMs resolve to matching ClientWorkers."""
        # Arrange: instances fake-ig-a..d at 10.0.0.0..3 and no instance
        # group listings at all.
        suffix_and_octet = list(zip('abcd', '0123'))
        instances_by_name = {}
        for suffix, octet in suffix_and_octet:
            name = 'fake-ig-' + suffix
            instances_by_name[name] = gen_fake_instances_get_entry(
                name, 'n1-standard-16', '10.0.0.' + octet, 'RUNNING')
        fake_compute = build_mock_compute_service(instances_by_name, {})
        self.mock_discovery.side_effect = build_mock_services_fn(fake_compute)

        # Act: the VM names are handed to the resolver directly.
        vm_names = ['fake-ig-' + suffix for suffix in 'abcd']
        resolver = ClusterResolver(['fake-tpu'], vms=vm_names)
        resolved = resolver.get_client_workers()

        # Assert: worker order is not significant.
        wanted = []
        for suffix, octet in suffix_and_octet:
            wanted.append(
                ClientWorker(internal_ip='10.0.0.' + octet,
                             machine_type='n1-standard-16',
                             zone='fake-zone',
                             hostname='fake-ig-' + suffix))
        self.assertCountEqual(wanted, resolved)
Exemple #3
0
    def test_healthy_tpuvm_cluster(self):
        """A healthy TPUVM pod resolves to a full Cluster.

        Expects one ClientWorker and one ServiceWorker per network endpoint,
        with the first endpoint's address as ``client_master_ip``.
        """
        # Use the TPUVM flavor of metadata for the duration of this test only.
        # The context manager restores whatever `get_instance_metadata` was
        # before, even when an assertion below fails, so the TPUVM mock
        # cannot leak into other tests.
        with mock.patch.object(ClusterResolver, 'get_instance_metadata',
                               mock_request_tpuvm_metadata):
            noop_compute_service = build_mock_compute_service({}, {})
            self.mock_discovery.side_effect = build_mock_services_fn(
                noop_compute_service)

            tpu_map = {
                'fake-pod': {
                    'state': 'READY',
                    'health': 'HEALTHY',
                    'runtime_version': 'v2-nightly',
                    'accelerator_type': 'v3-32',
                    'api_version': 'V2_ALPHA1',
                    'network_endpoints': [{
                        'ipAddress': f'10.1.0.{index}',
                        'port': '8470',
                    } for index in range(4)],
                }
            }
            self.mock_ctc.side_effect = build_mock_cloud_tpu_client_library(
                tpu_map)

            tpus = list(tpu_map.keys())
            cr = ClusterResolver(tpus)
            cluster = cr.get_cluster()

            # Client workers share the endpoint addresses and carry the
            # accelerator type as their machine type.
            expected_client_workers = [
                ClientWorker(internal_ip=f'10.1.0.{index}',
                             machine_type='v3-32',
                             zone='fake-zone',
                             hostname=f'{TPUVM_HOSTNAME_PREFIX}{index}')
                for index in range(4)
            ]
            expected_service_workers = [
                ServiceWorker(internal_ip=f'10.1.0.{ip}',
                              port='8470',
                              machine_type='v3-32',
                              zone='fake-zone',
                              runtime_version='v2-nightly',
                              tpu='fake-pod') for ip in range(4)
            ]
            expected = Cluster(expected_client_workers,
                               expected_service_workers,
                               client_master_ip='10.1.0.0')
            self.assertEqual(expected, cluster)
Exemple #4
0
    def test_unknown_health_pod_service_cluster(self):
        """A READY pod without a 'health' entry is rejected as not HEALTHY."""
        self.mock_discovery.side_effect = build_mock_services_fn(
            build_mock_compute_service({}, {}))

        # The pod description below has no 'health' key.
        pod_spec = {
            'state': 'READY',
            'runtime_version': 'pytorch-nightly',
            'accelerator_type': 'v3-32',
            'network_endpoints': [
                {'ipAddress': f'10.0.0.{host}', 'port': '8470'}
                for host in range(4)
            ],
        }
        fake_tpus = {'fake-pod': pod_spec}
        self.mock_ctc.side_effect = build_mock_cloud_tpu_client_library(
            fake_tpus)

        resolver = ClusterResolver(list(fake_tpus))
        with self.assertRaisesRegex(RuntimeError,
                                    'TPU fake-pod is not HEALTHY yet.*'):
            resolver.get_service_workers()
Exemple #5
0
    def test_non_ready_sea_service_cluster(self):
        """One TPU still in CREATING state makes service resolution fail."""
        self.mock_discovery.side_effect = build_mock_services_fn(
            build_mock_compute_service({}, {}))

        # Three ready, healthy single-endpoint v3-8 TPUs ...
        fake_tpus = {}
        for idx in range(3):
            fake_tpus[f'fake-tpu-{idx}'] = {
                'state': 'READY',
                'health': 'HEALTHY',
                'runtime_version': 'pytorch-nightly',
                'accelerator_type': 'v3-8',
                'network_endpoints': [
                    {'ipAddress': f'10.0.0.{idx}', 'port': '8470'},
                ],
            }
        # ... plus a fourth one that is still being created.
        fake_tpus['fake-tpu-3'] = {
            'state': 'CREATING',
            'runtime_version': 'pytorch-nightly',
            'accelerator_type': 'v3-8',
        }
        self.mock_ctc.side_effect = build_mock_cloud_tpu_client_library(
            fake_tpus)

        resolver = ClusterResolver(list(fake_tpus))
        with self.assertRaisesRegex(RuntimeError,
                                    'TPU fake-tpu-3 is not READY yet.*'):
            resolver.get_service_workers()
Exemple #6
0
    def test_unhealthy_pod_service_cluster(self):
        """A READY pod reporting UNHEALTHY_TENSORFLOW is rejected."""
        # A v3-128 pod with sixteen endpoints and a non-HEALTHY health value.
        pod_spec = {
            'state': 'READY',
            'health': 'UNHEALTHY_TENSORFLOW',
            'runtime_version': 'pytorch-nightly',
            'accelerator_type': 'v3-128',
            'network_endpoints': [
                {'ipAddress': f'10.0.0.{host}', 'port': '8470'}
                for host in range(16)
            ],
        }
        fake_tpus = {'fake-pod': pod_spec}
        self.mock_ctc.side_effect = build_mock_cloud_tpu_client_library(
            fake_tpus)

        resolver = ClusterResolver(list(fake_tpus))
        with self.assertRaisesRegex(RuntimeError,
                                    'TPU fake-pod is not HEALTHY yet.*'):
            resolver.get_service_workers()
Exemple #7
0
    def __init__(self,
                 cluster,
                 docker_container=None,
                 docker_image=None,
                 docker_run_flags=None,
                 conda_env=None,
                 env_vars=None):
        """Store the run configuration and validate the env vars to distribute.

        Args:
          cluster: Object exposing ``get_client_workers()``; the worker whose
            internal IP equals this instance's IP becomes the client master.
          docker_container: Container name; ``DEFAULT_CONTAINER_NAME`` is used
            when this is falsy.
          docker_image: Stored unchanged.
          docker_run_flags: Optional iterable, copied into a new list.
          conda_env: Stored unchanged.
          env_vars: Optional iterable of ``X=Y`` strings, copied into a new
            list.

        Raises:
          ValueError: If an entry of ``env_vars`` does not start with the
            ``X=Y`` form, or assigns one of the names in ``DIST_ENV_VARS``.
          StopIteration: If no client worker of ``cluster`` has an internal IP
            equal to the one reported by the instance metadata.
        """
        self._cluster = cluster
        self._initialize()
        client_master_ip = ClusterResolver.get_instance_metadata(
            'instance/network-interfaces/0/ip')
        # NOTE(review): next() without a default lets StopIteration escape when
        # this instance is not among the client workers; kept as-is so callers
        # see the same exception type as before.
        self._client_master = next(
            filter(lambda cw: cw.get_internal_ip() == client_master_ip,
                   self._cluster.get_client_workers()))
        self.logger = self._get_logger()
        self.docker_container = docker_container or self.DEFAULT_CONTAINER_NAME
        self.docker_image = docker_image
        # Copy caller-supplied iterables so later changes on either side do
        # not affect the other.
        self.docker_run_flags = list(
            docker_run_flags) if docker_run_flags else []
        self.conda_env = conda_env
        self.env_vars = list(env_vars) if env_vars else []

        for env_var in self.env_vars:
            # Raw string: '\w' in a plain literal is an invalid escape
            # sequence and triggers a warning on recent Python versions.
            if re.match(r'\w*=\w*', env_var) is None:
                raise ValueError(
                    ('Environment variable to distribute ({}) should follow '
                     'the form: X=Y').format(env_var))
            # Reject variables that this class sets itself for distributed
            # training.
            for dist_var in self.DIST_ENV_VARS:
                if re.match('{}=.*'.format(dist_var), env_var):
                    raise ValueError((
                        '{} should not be in the training command provided as they'
                        ' will interfere with the values set for distributed'
                        ' training'.format(dist_var)))
Exemple #8
0
    def test_healthy_sea_service_cluster(self):
        """256 healthy single-endpoint TPUs resolve to 256 ServiceWorkers."""
        self.mock_discovery.side_effect = build_mock_services_fn(
            build_mock_compute_service({}, {}))

        # Arrange: fake-tpu-0 .. fake-tpu-255, each a v3-8 with one endpoint.
        tpu_count = 256
        fake_tpus = {}
        for idx in range(tpu_count):
            fake_tpus[f'fake-tpu-{idx}'] = {
                'state': 'READY',
                'health': 'HEALTHY',
                'runtime_version': 'pytorch-nightly',
                'accelerator_type': 'v3-8',
                'network_endpoints': [
                    {'ipAddress': f'10.0.0.{idx}', 'port': '8470'},
                ],
            }
        self.mock_ctc.side_effect = build_mock_cloud_tpu_client_library(
            fake_tpus)

        # Act
        resolver = ClusterResolver(list(fake_tpus))
        resolved = resolver.get_service_workers()

        # Assert: one worker per TPU, order not significant.
        wanted = []
        for idx in range(tpu_count):
            wanted.append(
                ServiceWorker(internal_ip=f'10.0.0.{idx}',
                              port='8470',
                              machine_type='v3-8',
                              zone='fake-zone',
                              runtime_version='pytorch-nightly',
                              tpu=f'fake-tpu-{idx}'))
        self.assertCountEqual(wanted, resolved)
Exemple #9
0
    def test_bad_cluster(self):
        """Three listed client VMs against four TPU endpoints is an error."""
        # The instance group lists only fake-ig-a..c ...
        listed_items = [
            gen_fake_ig_list_instances_entry('fake-ig-' + suffix, 'RUNNING')
            for suffix in 'abc'
        ]
        group_listing = {
            'fake-ig': {
                'kind': 'compute#instanceGroupsListInstances',
                'items': listed_items,
            },
        }
        # ... while instance lookups exist for fake-ig-a..d.
        instances_by_name = {}
        for suffix, octet in zip('abcd', '0123'):
            name = 'fake-ig-' + suffix
            instances_by_name[name] = gen_fake_instances_get_entry(
                name, 'n1-standard-16', '10.0.0.' + octet, 'RUNNING')
        fake_compute = build_mock_compute_service(instances_by_name,
                                                  group_listing)
        self.mock_discovery.side_effect = build_mock_services_fn(fake_compute)

        # The pod itself is healthy and exposes four endpoints.
        fake_tpus = {
            'fake-pod': {
                'state': 'READY',
                'health': 'HEALTHY',
                'runtime_version': 'pytorch-nightly',
                'accelerator_type': 'v3-32',
                'network_endpoints': [
                    {'ipAddress': f'10.0.0.{host}', 'port': '8470'}
                    for host in range(4)
                ],
            },
        }
        self.mock_ctc.side_effect = build_mock_cloud_tpu_client_library(
            fake_tpus)

        resolver = ClusterResolver(list(fake_tpus))
        with self.assertRaisesRegex(
                RuntimeError,
                'The client_workers and service_workers must have a 1:1 mapping'
        ):
            resolver.get_cluster()
Exemple #10
0
    def test_unhealthy_client_cluster(self):
        """Client resolution fails while one group member is PROVISIONING."""
        # Arrange: (instance name, internal IP, status) for each group member.
        members = [
            ('fake-ig-a', '10.0.0.0', 'RUNNING'),
            ('fake-ig-b', '10.0.0.1', 'PROVISIONING'),
            ('fake-ig-c', '10.0.0.2', 'RUNNING'),
            ('fake-ig-d', '10.0.0.3', 'RUNNING'),
        ]
        group_listing = {
            'fake-ig': {
                'kind': 'compute#instanceGroupsListInstances',
                'items': [
                    gen_fake_ig_list_instances_entry(name, status)
                    for name, _, status in members
                ],
            },
        }
        instances_by_name = {
            name: gen_fake_instances_get_entry(name, 'n1-standard-16', ip,
                                               status)
            for name, ip, status in members
        }
        fake_compute = build_mock_compute_service(instances_by_name,
                                                  group_listing)
        self.mock_discovery.side_effect = build_mock_services_fn(fake_compute)

        # Act
        resolver = ClusterResolver(['fake-tpu'])

        # Assert
        with self.assertRaisesRegex(RuntimeError,
                                    'Instance fake-ig-b is not running yet.*'):
            resolver.get_client_workers()
Exemple #11
0
    def test_empty_instance_group_client_cluster(self):
        """An instance group that lists no members is reported as an error."""
        # Arrange: the group listing has no items even though an instance
        # lookup for fake-ig-a is available.
        empty_group = {
            'kind': 'compute#instanceGroupsListInstances',
            'items': [],
        }
        group_listing = {'fake-ig': empty_group}
        instances_by_name = {
            'fake-ig-a': gen_fake_instances_get_entry(
                'fake-ig-a', 'n1-standard-16', '10.0.0.0', 'RUNNING'),
        }
        fake_compute = build_mock_compute_service(instances_by_name,
                                                  group_listing)
        self.mock_discovery.side_effect = build_mock_services_fn(fake_compute)

        # Act
        resolver = ClusterResolver(['fake-tpu'])

        # Assert
        with self.assertRaisesRegex(RuntimeError,
                                    '.*vms is empty in instance group.*'):
            resolver.get_client_workers()
Exemple #12
0
        help='Name of the conda environment if running with conda.')

    parser.add_argument('--env',
                        action='append',
                        type=str,
                        help='List of environment variables to distribute.')
    parser.add_argument(
        'positional',
        nargs='+',
        type=str,
        help='The python command to launch training including model parameters.'
    )

    FLAGS = parser.parse_args()
    tpuvm_mode = False
    accel_type = ClusterResolver.get_instance_metadata(
        'instance/attributes/accelerator-type')
    if re.match(r'v[0-9]+-[0-9]+', accel_type):
        # Only TPUVM will carry the accelerator-type metadata
        tpuvm_mode = True

    if (FLAGS.docker_container or FLAGS.docker_image
            or FLAGS.docker_run_flag) and FLAGS.conda_env:
        raise ValueError('Docker Setup arguments and Conda Setup'
                         ' arguments are mutually exclusive.')

    # Resolve VM and TPU clusters.
    cluster_resolver = ClusterResolver(FLAGS.tpu,
                                       vms=FLAGS.vm,
                                       tpuvm_mode=tpuvm_mode)
    cluster = cluster_resolver.get_cluster()
    executor = DistributedExecutor(cluster,
Exemple #13
0
    def test_healthy_cluster(self):
        """Four running VMs and a healthy four-endpoint pod form a Cluster."""
        # Arrange: compute side, an instance group listing fake-ig-a..d with
        # a matching instance lookup for each member.
        suffix_and_octet = list(zip('abcd', '0123'))
        listed_items = [
            gen_fake_ig_list_instances_entry('fake-ig-' + suffix, 'RUNNING')
            for suffix in 'abcd'
        ]
        group_listing = {
            'fake-ig': {
                'kind': 'compute#instanceGroupsListInstances',
                'items': listed_items,
            },
        }
        instances_by_name = {}
        for suffix, octet in suffix_and_octet:
            name = 'fake-ig-' + suffix
            instances_by_name[name] = gen_fake_instances_get_entry(
                name, 'n1-standard-16', '10.0.0.' + octet, 'RUNNING')
        fake_compute = build_mock_compute_service(instances_by_name,
                                                  group_listing)
        self.mock_discovery.side_effect = build_mock_services_fn(fake_compute)

        # Arrange: TPU side, one healthy v3-32 pod with four endpoints.
        fake_tpus = {
            'fake-pod': {
                'state': 'READY',
                'health': 'HEALTHY',
                'runtime_version': 'pytorch-nightly',
                'accelerator_type': 'v3-32',
                'network_endpoints': [
                    {'ipAddress': f'10.0.0.{host}', 'port': '8470'}
                    for host in range(4)
                ],
            },
        }
        self.mock_ctc.side_effect = build_mock_cloud_tpu_client_library(
            fake_tpus)

        # Act
        resolver = ClusterResolver(list(fake_tpus))
        resolved = resolver.get_cluster()

        # Assert
        wanted_clients = []
        for suffix, octet in suffix_and_octet:
            wanted_clients.append(
                ClientWorker(internal_ip='10.0.0.' + octet,
                             machine_type='n1-standard-16',
                             zone='fake-zone',
                             hostname='fake-ig-' + suffix))
        wanted_services = []
        for host in range(4):
            wanted_services.append(
                ServiceWorker(internal_ip=f'10.0.0.{host}',
                              port='8470',
                              machine_type='v3-32',
                              zone='fake-zone',
                              runtime_version='pytorch-nightly',
                              tpu='fake-pod'))
        self.assertEqual(Cluster(wanted_clients, wanted_services), resolved)
Exemple #14
0
        help='Name of the conda environment if running with conda.')

    parser.add_argument('--env',
                        action='append',
                        type=str,
                        help='List of environment variables to distribute.')
    parser.add_argument(
        'positional',
        nargs='+',
        type=str,
        help='The python command to launch training including model parameters.'
    )

    FLAGS = parser.parse_args()

    if (FLAGS.docker_container or FLAGS.docker_image
            or FLAGS.docker_run_flag) and FLAGS.conda_env:
        raise ValueError('Docker Setup arguments and Conda Setup'
                         ' arguments are mutually exclusive.')

    # Resolve VM and TPU clusters.
    cluster_resolver = ClusterResolver(FLAGS.tpu, vms=FLAGS.vm)
    cluster = cluster_resolver.get_cluster()
    executor = DistributedExecutor(cluster,
                                   docker_container=FLAGS.docker_container,
                                   docker_image=FLAGS.docker_image,
                                   docker_run_flags=FLAGS.docker_run_flag,
                                   conda_env=FLAGS.conda_env,
                                   env_vars=FLAGS.env)
    executor.run(FLAGS.positional)
Exemple #15
0
                        help='Port that XRT local service will be start on.')
    parser.add_argument(
        'positional',
        nargs='+',
        type=str,
        help='The python command to launch training including model parameters.'
    )

    FLAGS = parser.parse_args()

    if (FLAGS.docker_container or FLAGS.docker_image
            or FLAGS.docker_run_flag) and FLAGS.conda_env:
        raise ValueError('Docker Setup arguments and Conda Setup'
                         ' arguments are mutually exclusive.')

    # Resolve VM and TPU clusters.
    cluster_resolver = ClusterResolver(FLAGS.tpu, vms=FLAGS.vm)
    cluster = cluster_resolver.get_cluster()
    tpuvm_mode = cluster_resolver.get_tpuvm_mode()
    executor = DistributedExecutor(
        cluster,
        docker_container=FLAGS.docker_container,
        docker_image=FLAGS.docker_image,
        docker_run_flags=FLAGS.docker_run_flag,
        conda_env=FLAGS.conda_env,
        env_vars=FLAGS.env,
        restart_server=FLAGS.restart_tpuvm_pod_server,
        tpuvm_mode=tpuvm_mode,
        tpuvm_server_port=FLAGS.tpuvm_server_port)
    executor.run(FLAGS.positional)