コード例 #1
0
def dist_train(opt, node_list):
    """Run this process's part of a distributed training job under SLURM.

    The main host is taken to be the first hostname that
    ``scontrol show hostnames`` reports for ``node_list``; the rank and the
    local device come from the ``SLURM_PROCID`` and ``SLURM_LOCALID``
    environment variables.

    Args:
        opt: Options mapping. ``opt.get('model_parallel')`` and
            ``opt['port']`` are read here; the whole object is forwarded to
            ``multiprocess_train``.
        node_list: SLURM node list expression handed to ``scontrol``.

    Raises:
        RuntimeError: If the ``scontrol`` executable cannot be found.
        subprocess.CalledProcessError: If ``scontrol`` exits with an error.
        KeyError: If a required ``SLURM_*`` environment variable is missing.
    """
    # Only the scontrol invocation is guarded. Wrapping the training call as
    # well would misreport any FileNotFoundError raised during training
    # (e.g. a missing data file) as "SLURM is not installed".
    # A CalledProcessError (scontrol failed) is deliberately left to
    # propagate unchanged.
    try:
        hostnames = subprocess.check_output(
            ['scontrol', 'show', 'hostnames', node_list])
    except FileNotFoundError as err:
        # The scontrol executable could not be launched.
        raise RuntimeError('SLURM does not appear to be installed.') from err

    # Figure out the main host, and which rank we are.
    main_host = hostnames.split()[0].decode('utf-8')
    distributed_rank = int(os.environ['SLURM_PROCID'])
    if opt.get('model_parallel'):
        # -1 signals to multiprocess_train to use all GPUs available.
        # (A value of None signals to multiprocess_train to use the GPU
        # corresponding to the rank.)
        device_id = -1
    else:
        device_id = int(os.environ['SLURM_LOCALID'])
    port = opt['port']
    logging.info(
        f'Initializing host {socket.gethostname()} as rank {distributed_rank}, '
        f'main is {main_host}')
    # Begin distributed training.
    # NOTE(review): the literal 0 is presumably a rank offset -- confirm
    # against multiprocess_train's signature.
    multiprocess_train(distributed_rank, opt, port, 0, device_id,
                       main_host)
コード例 #2
0
ファイル: distributed_train.py プロジェクト: shatu/ParlAI
def main():
    """Entry point for one process of a SLURM-launched distributed training run.

    Verifies that ``SLURM_JOB_NODELIST`` is set, parses the command-line
    options (adding a ``--port`` argument), resolves the main host via
    ``scontrol show hostnames`` and then calls ``multiprocess_train``.

    Raises:
        RuntimeError: If ``SLURM_JOB_NODELIST`` is not set, or if the
            ``scontrol`` executable cannot be found.
        subprocess.CalledProcessError: If ``scontrol`` exits with an error.
        KeyError: If ``SLURM_PROCID`` (or ``SLURM_LOCALID`` when it is
            needed) is missing from the environment.
    """
    # double check we're using SLURM
    node_list = os.environ.get('SLURM_JOB_NODELIST')
    if node_list is None:
        raise RuntimeError(
            'Does not appear to be in a SLURM environment. '
            'You should not call this script directly; see launch_distributed.py'
        )

    parser = single_train.setup_args()
    parser.add_distributed_training_args()
    parser.add_argument('--port',
                        type=int,
                        default=61337,
                        help='TCP port number')
    # Only the process whose SLURM_PROCID is '0' asks the parser to print
    # the arguments.
    opt = parser.parse_args(print_args=(os.environ['SLURM_PROCID'] == '0'))

    # We can determine the init method automatically for Slurm.
    # Only the scontrol invocation is guarded. Wrapping the training call as
    # well would misreport any FileNotFoundError raised during training
    # (e.g. a missing data file) as "SLURM is not installed".
    # A CalledProcessError (scontrol failed) is deliberately left to
    # propagate unchanged.
    try:
        hostnames = subprocess.check_output(
            ['scontrol', 'show', 'hostnames', node_list])
    except FileNotFoundError as err:
        # The scontrol executable could not be launched.
        raise RuntimeError('SLURM does not appear to be installed.') from err

    # Figure out the main host, and which rank we are.
    main_host = hostnames.split()[0].decode('utf-8')
    distributed_rank = int(os.environ['SLURM_PROCID'])
    if opt.get('model_parallel'):
        # -1 signals to multiprocess_train to use all GPUs available.
        # (A value of None signals to multiprocess_train to use the GPU
        # corresponding to the rank.)
        device_id = -1
    else:
        device_id = int(os.environ['SLURM_LOCALID'])
    port = opt['port']
    logging.info(
        f'Initializing host {socket.gethostname()} as rank {distributed_rank}, '
        f'main is {main_host}')
    # Begin distributed training.
    # NOTE(review): the literal 0 is presumably a rank offset -- confirm
    # against multiprocess_train's signature.
    multiprocess_train(distributed_rank, opt, port, 0, device_id,
                       main_host)