def __init__(self,
             cluster,
             docker_container=None,
             docker_image=None,
             docker_run_flags=None,
             conda_env=None,
             env_vars=None):
    """Set up the executor and validate the environment variables.

    Args:
        cluster: Cluster object; must provide ``get_client_workers()``,
            whose items provide ``get_internal_ip()``.
        docker_container: Name of the docker container. Falls back to
            ``self.DEFAULT_CONTAINER_NAME`` when falsy.
        docker_image: Docker image, stored as given.
        docker_run_flags: Optional iterable of docker run flags; copied
            into a new list (empty list when falsy).
        conda_env: Name of the conda environment, stored as given.
        env_vars: Optional iterable of ``X=Y`` strings to distribute;
            copied into a new list (empty list when falsy).

    Raises:
        ValueError: If an entry of ``env_vars`` does not match the
            ``X=Y`` pattern, or if it assigns one of the names listed in
            ``self.DIST_ENV_VARS``.
        StopIteration: If no client worker of ``cluster`` has an internal
            IP equal to the IP returned by the instance metadata lookup.
    """
    self._cluster = cluster
    self._initialize()

    # Pick the client worker whose internal IP equals the IP reported by
    # the instance metadata; it is kept as the "client master".
    client_master_ip = ClusterResolver.get_instance_metadata(
        'instance/network-interfaces/0/ip')
    self._client_master = next(
        filter(lambda cw: cw.get_internal_ip() == client_master_ip,
               self._cluster.get_client_workers()))

    self.logger = self._get_logger()
    self.docker_container = docker_container or self.DEFAULT_CONTAINER_NAME
    self.docker_image = docker_image
    self.docker_run_flags = list(
        docker_run_flags) if docker_run_flags else []
    self.conda_env = conda_env
    self.env_vars = list(env_vars) if env_vars else []

    for env_var in self.env_vars:
        # Raw string: '\w' in a plain literal is an invalid escape sequence
        # (DeprecationWarning/SyntaxWarning on modern Python). The compiled
        # pattern is unchanged. re.match only anchors at the start.
        if re.match(r'\w*=\w*', env_var) is None:
            raise ValueError(
                ('Environment variable to distribute ({}) should follow '
                 'the form: X=Y').format(env_var))
        # Reject user-supplied values for variables listed in
        # DIST_ENV_VARS.
        for dist_var in self.DIST_ENV_VARS:
            if re.match('{}=.*'.format(dist_var), env_var):
                raise ValueError((
                    '{} should not be in the training command provided as they'
                    ' will interfere with the values set for distributed'
                    ' training'.format(dist_var)))
help='Name of the conda environment if running with conda.') parser.add_argument('--env', action='append', type=str, help='List of environment variables to distribute.') parser.add_argument( 'positional', nargs='+', type=str, help='The python command to launch training including model parameters.' ) FLAGS = parser.parse_args() tpuvm_mode = False accel_type = ClusterResolver.get_instance_metadata( 'instance/attributes/accelerator-type') if re.match(r'v[0-9]+-[0-9]+', accel_type): # Only TPUVM will carry the accelerator-type metadata tpuvm_mode = True if (FLAGS.docker_container or FLAGS.docker_image or FLAGS.docker_run_flag) and FLAGS.conda_env: raise ValueError('Docker Setup arguments and Conda Setup' ' arguments are mutually exclusive.') # Resolve VM and TPU clusters. cluster_resolver = ClusterResolver(FLAGS.tpu, vms=FLAGS.vm, tpuvm_mode=tpuvm_mode) cluster = cluster_resolver.get_cluster() executor = DistributedExecutor(cluster,