def dist_train(opt, node_list):
    # We can determine the init method automatically for Slurm.
    try:
        # Figure out the main host, and which rank we are.
        hostnames = subprocess.check_output(
            ['scontrol', 'show', 'hostnames', node_list]
        )
        main_host = hostnames.split()[0].decode('utf-8')
        distributed_rank = int(os.environ['SLURM_PROCID'])
        if opt.get('model_parallel'):
            # -1 signals to multiprocessing_train to use all GPUs available.
            # (A value of None signals to multiprocessing_train to use the GPU
            # corresponding to the rank.)
            device_id = -1
        else:
            device_id = int(os.environ['SLURM_LOCALID'])
        port = opt['port']
        logging.info(
            f'Initializing host {socket.gethostname()} as rank {distributed_rank}, '
            f'main is {main_host}'
        )
        # Begin distributed training
        multiprocess_train(distributed_rank, opt, port, 0, device_id, main_host)
    except subprocess.CalledProcessError as e:
        # scontrol failed
        raise e
    except FileNotFoundError:
        # Slurm is not installed
        raise RuntimeError('SLURM does not appear to be installed.')
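
# Illustrative sketch of the Slurm environment dist_train relies on (the concrete
# values below are hypothetical examples, not taken from this file). For instance,
# in a 2-node job with 8 tasks per node, each task would see something like:
#
#   SLURM_JOB_NODELIST = 'node[001-002]'  # expanded to hostnames via `scontrol show hostnames`
#   SLURM_PROCID       = '0' .. '15'      # this task's global rank
#   SLURM_LOCALID      = '0' .. '7'       # this task's index on its node, used as the
#                                         # GPU id unless opt['model_parallel'] is set
#
# The first hostname in the expanded list becomes main_host, the rendezvous address
# handed to multiprocess_train together with the rank, port, and device id.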


def main():
    # double check we're using SLURM
    node_list = os.environ.get('SLURM_JOB_NODELIST')
    if node_list is None:
        raise RuntimeError(
            'Does not appear to be in a SLURM environment. '
            'You should not call this script directly; see launch_distributed.py'
        )

    parser = single_train.setup_args()
    parser.add_distributed_training_args()
    parser.add_argument('--port', type=int, default=61337, help='TCP port number')
    opt = parser.parse_args(print_args=(os.environ['SLURM_PROCID'] == '0'))

    # All of the Slurm-specific setup (rank discovery, device selection, and the
    # call into multiprocess_train) lives in dist_train above.
    dist_train(opt, node_list)
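

# Entry point -- a minimal sketch (an assumption, not shown in this excerpt):
# each Slurm task is expected to execute this module, presumably via the job that
# launch_distributed.py submits, so simply hand off to main() when run as a script.
if __name__ == '__main__':
    main()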