Пример #1
0
    def __init__(self,
                 cluster,
                 docker_container=None,
                 docker_image=None,
                 docker_run_flags=None,
                 conda_env=None,
                 env_vars=None):
        """Set up the executor for the given cluster and validate settings.

        Args:
          cluster: Cluster object; must provide ``get_client_workers()``, whose
            items provide ``get_internal_ip()``.
          docker_container: Docker container name; defaults to
            ``self.DEFAULT_CONTAINER_NAME`` when falsy.
          docker_image: Docker image to use, if any.
          docker_run_flags: Iterable of extra docker run flags; copied into a
            list (empty list when falsy).
          conda_env: Name of the conda environment, if any.
          env_vars: Iterable of ``NAME=VALUE`` strings to distribute; copied
            into a list (empty list when falsy).

        Raises:
          RuntimeError: If no client worker of ``cluster`` has the internal IP
            reported by the instance metadata for network interface 0.
          ValueError: If an entry of ``env_vars`` does not start with
            ``NAME=`` or assigns one of ``self.DIST_ENV_VARS``.
        """
        self._cluster = cluster
        self._initialize()
        client_master_ip = ClusterResolver.get_instance_metadata(
            'instance/network-interfaces/0/ip')
        # The client master is the client worker running on this instance.
        # Use a default with next() so a missing match produces a clear error
        # instead of a bare StopIteration escaping the constructor.
        self._client_master = next(
            (cw for cw in self._cluster.get_client_workers()
             if cw.get_internal_ip() == client_master_ip), None)
        if self._client_master is None:
            raise RuntimeError(
                ('No client worker in the cluster has internal IP {} '
                 '(the IP reported by instance metadata).').format(
                     client_master_ip))
        self.logger = self._get_logger()
        self.docker_container = docker_container or self.DEFAULT_CONTAINER_NAME
        self.docker_image = docker_image
        self.docker_run_flags = list(
            docker_run_flags) if docker_run_flags else []
        self.conda_env = conda_env
        self.env_vars = list(env_vars) if env_vars else []

        for env_var in self.env_vars:
            # Raw string avoids the invalid '\w' escape in a normal literal;
            # '\w+' (not '\w*') requires a non-empty variable name before '='.
            if re.match(r'\w+=', env_var) is None:
                raise ValueError(
                    ('Environment variable to distribute ({}) should follow '
                     'the form: X=Y').format(env_var))
            # Reject user-supplied values for variables reserved for the
            # distributed-training setup.
            for dist_var in self.DIST_ENV_VARS:
                if re.match('{}=.*'.format(dist_var), env_var):
                    raise ValueError((
                        '{} should not be in the training command provided as they'
                        ' will interfere with the values set for distributed'
                        ' training'.format(dist_var)))
Пример #2
0
        help='Name of the conda environment if running with conda.')

    parser.add_argument('--env',
                        action='append',
                        type=str,
                        help='List of environment variables to distribute.')
    parser.add_argument(
        'positional',
        nargs='+',
        type=str,
        help='The python command to launch training including model parameters.'
    )

    FLAGS = parser.parse_args()
    tpuvm_mode = False
    accel_type = ClusterResolver.get_instance_metadata(
        'instance/attributes/accelerator-type')
    if re.match(r'v[0-9]+-[0-9]+', accel_type):
        # Only TPUVM will carry the accelerator-type metadata
        tpuvm_mode = True

    if (FLAGS.docker_container or FLAGS.docker_image
            or FLAGS.docker_run_flag) and FLAGS.conda_env:
        raise ValueError('Docker Setup arguments and Conda Setup'
                         ' arguments are mutually exclusive.')

    # Resolve VM and TPU clusters.
    cluster_resolver = ClusterResolver(FLAGS.tpu,
                                       vms=FLAGS.vm,
                                       tpuvm_mode=tpuvm_mode)
    cluster = cluster_resolver.get_cluster()
    executor = DistributedExecutor(cluster,