Exemplo n.º 1
0
def main_cli():
    parser = argparse.ArgumentParser('CEDR model training and validation')

    model_init_utils.add_model_init_basic_args(parser, True)

    parser.add_argument('--datafiles', metavar='data files', help='data files: docs & queries',
                        type=argparse.FileType('rt'), nargs='+', required=True)

    parser.add_argument('--qrels', metavar='QREL file', help='QREL file',
                        type=argparse.FileType('rt'), required=True)

    parser.add_argument('--train_pairs', metavar='paired train data', help='paired train data',
                        type=argparse.FileType('rt'), required=True)

    parser.add_argument('--valid_run', metavar='validation file', help='validation file',
                        type=argparse.FileType('rt'), required=True)

    parser.add_argument('--model_out_dir',
                        metavar='model out dir', help='an output directory for the trained model',
                        required=True)

    parser.add_argument('--epoch_qty', metavar='# of epochs', help='# of epochs',
                        type=int, default=10)

    parser.add_argument('--no_cuda', action='store_true')

    parser.add_argument('--warmup_pct', metavar='warm-up fraction',
                        default=None, type=float,
                        help='use a warm-up/cool-down learning-reate schedule')

    parser.add_argument('--device_qty', type=int, metavar='# of device for multi-GPU training',
                        default=1, help='# of GPUs for multi-GPU training')

    parser.add_argument('--batch_sync_qty', metavar='# of batches before model sync',
                        type=int, default=4, help='Model syncronization frequency for multi-GPU trainig in the # of batche')

    parser.add_argument('--master_port', type=int, metavar='pytorch master port',
                        default=None, help='pytorch master port for multi-GPU training')

    parser.add_argument('--print_grads', action='store_true',
                        help='print gradient norms of parameters')

    parser.add_argument('--save_snapshots', action='store_true',
                        help='save model after each epoch')

    parser.add_argument('--seed', metavar='random seed', help='random seed',
                        type=int, default=42)

    parser.add_argument('--loss_margin', metavar='loss margin', help='Margin in the margin loss',
                        type=float, default=1)

    parser.add_argument('--init_lr', metavar='init learn. rate',
                        type=float, default=0.001, help='Initial learning rate for BERT-unrelated parameters')

    parser.add_argument('--init_bert_lr', metavar='init BERT learn. rate',
                        type=float, default=0.00005, help='Initial learning rate for BERT parameters')

    parser.add_argument('--epoch_lr_decay', metavar='epoch LR decay',
                        type=float, default=1.0, help='Per-epoch learning rate decay')

    parser.add_argument('--weight_decay', metavar='weight decay',
                        type=float, default=0.0, help='optimizer weight decay')

    parser.add_argument('--batch_size', metavar='batch size',
                        type=int, default=32, help='batch size')

    parser.add_argument('--batch_size_val', metavar='val batch size',
                        type=int, default=32, help='validation batch size')

    parser.add_argument('--backprop_batch_size', metavar='backprop batch size',
                        type=int, default=12,
                        help='batch size for each backprop step')

    parser.add_argument('--batches_per_train_epoch', metavar='# of rand. batches per epoch',
                        type=int, default=0,
                        help='# of random batches per epoch: 0 tells to use all data')

    parser.add_argument('--max_query_val', metavar='max # of val queries',
                        type=int, default=0,
                        help='max # of validation queries: 0 tells to use all data')

    parser.add_argument('--no_shuffle_train', action='store_true',
                        help='disabling shuffling of training data')

    parser.add_argument('--use_external_eval', action='store_true',
                        help='use external eval tools: gdeval or trec_eval')

    parser.add_argument('--eval_metric', choices=METRIC_LIST, default=METRIC_LIST[0],
                        help='Metric list: ' +  ','.join(METRIC_LIST), 
                        metavar='eval metric')

    parser.add_argument('--loss_func', choices=LOSS_FUNC_LIST,
                        default=PairwiseSoftmaxLoss.name(),
                        help='Loss functions: ' + ','.join(LOSS_FUNC_LIST))

    args = parser.parse_args()

    utils.set_all_seeds(args.seed)

    loss_name = args.loss_func
    if loss_name == PairwiseSoftmaxLoss.name():
        loss_obj = PairwiseSoftmaxLoss()
    elif loss_name == MarginRankingLossWrapper.name():
        loss_obj = MarginRankingLossWrapper(margin = args.loss_margin)
    else:
        print('Unsupported loss: ' + loss_name)
        sys.exit(1)

    # If we have the complete model, we just load it,
    # otherwise we first create a model and load *SOME* of its weights.
    # For example, if we start from an original BERT model, which has
    # no extra heads, it we will load only the respective weights and
    # initialize the weights of the head randomly.
    if args.init_model is not None:
        print('Loading a complete model from:', args.init_model.name)
        model = torch.load(args.init_model.name, map_location='cpu')
    elif args.init_model_weights is not None:
        model = model_init_utils.create_model_from_args(args)
        print('Loading model weights from:', args.init_model_weights.name)
        model.load_state_dict(torch.load(args.init_model_weights.name, map_location='cpu'), strict=False)
    else:
        print('Creating the model from scratch!')
        model = model_init_utils.create_model_from_args(args)

    os.makedirs(args.model_out_dir, exist_ok=True)
    print(model)
    model.set_grad_checkpoint_param(args.grad_checkpoint_param)

    dataset = data.read_datafiles(args.datafiles)
    qrelf = args.qrels.name
    qrels = readQrelsDict(qrelf)
    train_pairs_all = data.read_pairs_dict(args.train_pairs)
    valid_run = readRunDict(args.valid_run.name)
    max_query_val = args.max_query_val
    query_ids = list(valid_run.keys())
    if max_query_val > 0:
        query_ids = query_ids[0:max_query_val]
        valid_run = {k: valid_run[k] for k in query_ids}

    print('# of eval. queries:', len(query_ids), ' in the file', args.valid_run.name)


    device_qty = args.device_qty
    master_port = args.master_port
    if device_qty > 1:
        if master_port is None:
            print('Specify a master port for distributed training!')
            sys.exit(1)

    processes = []

    is_distr_train = device_qty > 1

    qids = []

    if is_distr_train:
        qids = list(train_pairs_all.keys())

    # We must go in the reverse direction, b/c
    # rank == 0 trainer is in the same process and
    # we call the function do_train in the same process,
    # i.e., this call is blocking processing and
    # prevents other processes from starting.
    for rank in range(device_qty - 1, -1, -1):
        if is_distr_train:
            device_name = f'cuda:{rank}'
        else:
            device_name = args.device_name
            if args.no_cuda:
                device_name = DEVICE_CPU

        # When we have only a single GPP, the main process is its own master
        is_master_proc = rank == 0

        train_params = TrainParams(init_lr=args.init_lr, init_bert_lr=args.init_bert_lr,
                                    warmup_pct=args.warmup_pct, batch_sync_qty=args.batch_sync_qty,
                                    epoch_lr_decay=args.epoch_lr_decay, weight_decay=args.weight_decay,
                                    backprop_batch_size=args.backprop_batch_size,
                                    batches_per_train_epoch=args.batches_per_train_epoch,
                                    save_snapshots=args.save_snapshots,
                                    batch_size=args.batch_size, batch_size_val=args.batch_size_val,
                                    max_query_len=args.max_query_len, max_doc_len=args.max_doc_len,
                                    epoch_qty=args.epoch_qty, device_name=device_name,
                                    use_external_eval=args.use_external_eval, eval_metric=args.eval_metric.lower(),
                                    print_grads=args.print_grads,
                                    shuffle_train=not args.no_shuffle_train)

        train_pair_qty = len(train_pairs_all)
        if is_distr_train or train_pair_qty < device_qty:
            tpart_qty = int((train_pair_qty + device_qty - 1) / device_qty)
            train_start = rank * tpart_qty
            train_end = min(train_start + tpart_qty, len(qids))
            train_pairs = { k : train_pairs_all[k] for k in qids[train_start : train_end] }
        else:
            train_pairs = train_pairs_all
        print('Process rank %d device %s using %d training pairs out of %d' %
              (rank, device_name, len(train_pairs), train_pair_qty))

        param_dict = {
            'device_qty' : device_qty, 'master_port' : master_port,
             'rank' : rank, 'is_master_proc' : is_master_proc,
             'dataset' : dataset,
             'qrels' : qrels, 'qrel_file_name' : qrelf,
             'train_pairs' : train_pairs, 'valid_run' : valid_run,
             'model_out_dir' : args.model_out_dir,
             'model' : model, 'loss_obj' : loss_obj, 'train_params' : train_params
        }

        if is_distr_train and not is_master_proc:
            p = Process(target=do_train, kwargs=param_dict)
            p.start()
            processes.append(p)
        else:
            do_train(**param_dict)

    for p in processes:
        utils.join_and_check_stat(p)

    if device_qty > 1:
        dist.destroy_process_group()
Exemplo n.º 2
0
def main_cli():
    parser = argparse.ArgumentParser('CEDR model training and validation')

    model_init_utils.add_model_init_basic_args(parser, True)

    parser.add_argument('--datafiles',
                        metavar='data files',
                        help='data files: docs & queries',
                        type=argparse.FileType('rt'),
                        nargs='+',
                        required=True)

    parser.add_argument('--qrels',
                        metavar='QREL file',
                        help='QREL file',
                        type=argparse.FileType('rt'),
                        required=True)

    parser.add_argument('--train_pairs',
                        metavar='paired train data',
                        help='paired train data',
                        type=argparse.FileType('rt'),
                        required=True)

    parser.add_argument('--valid_run',
                        metavar='validation file',
                        help='validation file',
                        type=argparse.FileType('rt'),
                        required=True)

    parser.add_argument('--model_out_dir',
                        metavar='model out dir',
                        help='an output directory for the trained model',
                        required=True)

    parser.add_argument('--epoch_qty',
                        metavar='# of epochs',
                        help='# of epochs',
                        type=int,
                        default=10)

    parser.add_argument('--no_cuda', action='store_true', help='Use no CUDA')

    parser.add_argument('--valid_type',
                        default=VALID_ALWAYS,
                        choices=[VALID_ALWAYS, VALID_LAST, VALID_NONE],
                        help='validation type')

    parser.add_argument('--warmup_pct',
                        metavar='warm-up fraction',
                        default=None,
                        type=float,
                        help='use a warm-up/cool-down learning-reate schedule')

    parser.add_argument('--device_qty',
                        type=int,
                        metavar='# of device for multi-GPU training',
                        default=1,
                        help='# of GPUs for multi-GPU training')

    parser.add_argument(
        '--batch_sync_qty',
        metavar='# of batches before model sync',
        type=int,
        default=4,
        help=
        'model syncronization frequency for multi-GPU trainig in the # of batche'
    )

    parser.add_argument('--master_port',
                        type=int,
                        metavar='pytorch master port',
                        default=None,
                        help='pytorch master port for multi-GPU training')

    parser.add_argument('--print_grads',
                        action='store_true',
                        help='print gradient norms of parameters')

    parser.add_argument('--save_epoch_snapshots',
                        action='store_true',
                        help='save model after each epoch')

    parser.add_argument(
        '--save_last_snapshot_every_k_batch',
        metavar='debug: save latest snapshot every k batch',
        type=int,
        default=None,
        help='debug option: save latest snapshot every k batch')

    parser.add_argument('--seed',
                        metavar='random seed',
                        help='random seed',
                        type=int,
                        default=42)

    parser.add_argument('--optim',
                        metavar='optimizer',
                        choices=[OPT_SGD, OPT_ADAMW],
                        default=OPT_ADAMW,
                        help='Optimizer')

    parser.add_argument('--loss_margin',
                        metavar='loss margin',
                        help='Margin in the margin loss',
                        type=float,
                        default=1)

    parser.add_argument(
        '--init_lr',
        metavar='init learn. rate',
        type=float,
        default=0.001,
        help='initial learning rate for BERT-unrelated parameters')

    parser.add_argument('--momentum',
                        metavar='SGD momentum',
                        type=float,
                        default=0.9,
                        help='SGD momentum')

    parser.add_argument('--init_bert_lr',
                        metavar='init BERT learn. rate',
                        type=float,
                        default=0.00005,
                        help='initial learning rate for BERT parameters')

    parser.add_argument('--epoch_lr_decay',
                        metavar='epoch LR decay',
                        type=float,
                        default=1.0,
                        help='per-epoch learning rate decay')

    parser.add_argument('--weight_decay',
                        metavar='weight decay',
                        type=float,
                        default=0.0,
                        help='optimizer weight decay')

    parser.add_argument('--batch_size',
                        metavar='batch size',
                        type=int,
                        default=32,
                        help='batch size')

    parser.add_argument('--batch_size_val',
                        metavar='val batch size',
                        type=int,
                        default=32,
                        help='validation batch size')

    parser.add_argument('--backprop_batch_size',
                        metavar='backprop batch size',
                        type=int,
                        default=1,
                        help='batch size for each backprop step')

    parser.add_argument(
        '--batches_per_train_epoch',
        metavar='# of rand. batches per epoch',
        type=int,
        default=None,
        help='# of random batches per epoch: 0 tells to use all data')

    parser.add_argument(
        '--max_query_val',
        metavar='max # of val queries',
        type=int,
        default=0,
        help='max # of validation queries: 0 tells to use all data')

    parser.add_argument('--no_shuffle_train',
                        action='store_true',
                        help='disabling shuffling of training data')

    parser.add_argument('--use_external_eval',
                        action='store_true',
                        help='use external eval tools: gdeval or trec_eval')

    parser.add_argument('--eval_metric',
                        choices=METRIC_LIST,
                        default=METRIC_LIST[0],
                        help='Metric list: ' + ','.join(METRIC_LIST),
                        metavar='eval metric')

    parser.add_argument('--loss_func',
                        choices=LOSS_FUNC_LIST,
                        default=PairwiseSoftmaxLoss.name(),
                        help='Loss functions: ' + ','.join(LOSS_FUNC_LIST))

    parser.add_argument(
        '--json_conf',
        metavar='JSON config',
        type=str,
        default=None,
        help=
        'a JSON config (simple-dictionary): keys are the same as args, takes precedence over command line args'
    )

    parser.add_argument(
        '--valid_run_dir',
        metavar='',
        type=str,
        default=None,
        help='directory to store predictions on validation set')
    parser.add_argument('--valid_checkpoints',
                        metavar='',
                        type=str,
                        default=None,
                        help='validation checkpoints (in # of batches)')

    args = parser.parse_args()

    print(args)
    utils.sync_out_streams()

    all_arg_names = vars(args).keys()

    if args.json_conf is not None:
        conf_file = args.json_conf
        print(f'Reading configuration variables from {conf_file}')
        add_conf = utils.read_json(conf_file)
        for arg_name, arg_val in add_conf.items():
            if arg_name not in all_arg_names:
                print(f'Invalid option in the configuration file: {arg_name}')
                sys.exit(1)
            arg_default = getattr(args, arg_name)
            exp_type = type(arg_default)
            if arg_default is not None and type(arg_val) != exp_type:
                print(
                    f'Invalid type in the configuration file: {arg_name} expected type: '
                    + str(type(exp_type)) + f' default {arg_default}')
                sys.exit(1)
            print(f'Using {arg_name} from the config')
            setattr(args, arg_name, arg_val)

    # This hack copies max query and document length parameters to the model space parameters
    # maybe some other approach is more elegant, but this one should at least work
    setattr(args, f'{MODEL_PARAM_PREF}max_query_len', args.max_query_len)
    setattr(args, f'{MODEL_PARAM_PREF}max_doc_len', args.max_doc_len)

    if args.save_last_snapshot_every_k_batch is not None and args.save_last_snapshot_every_k_batch < 2:
        print('--save_last_snapshot_every_k_batch should be > 1')
        sys.exit(1)

    utils.set_all_seeds(args.seed)

    loss_name = args.loss_func
    if loss_name == PairwiseSoftmaxLoss.name():
        loss_obj = PairwiseSoftmaxLoss()
    elif loss_name == MarginRankingLossWrapper.name():
        loss_obj = MarginRankingLossWrapper(margin=args.loss_margin)
    else:
        print('Unsupported loss: ' + loss_name)
        sys.exit(1)

    # If we have the complete model, we just load it,
    # otherwise we first create a model and load *SOME* of its weights.
    # For example, if we start from an original BERT model, which has
    # no extra heads, it we will load only the respective weights and
    # initialize the weights of the head randomly.
    if args.init_model is not None:
        print('Loading a complete model from:', args.init_model.name)
        model = torch.load(args.init_model.name, map_location='cpu')
    elif args.init_model_weights is not None:
        model = model_init_utils.create_model_from_args(args)
        print('Loading model weights from:', args.init_model_weights.name)
        model.load_state_dict(torch.load(args.init_model_weights.name,
                                         map_location='cpu'),
                              strict=False)
    else:
        print('Creating the model from scratch!')
        model = model_init_utils.create_model_from_args(args)

    os.makedirs(args.model_out_dir, exist_ok=True)
    print(model)
    utils.sync_out_streams()
    model.set_grad_checkpoint_param(args.grad_checkpoint_param)

    dataset = data.read_datafiles(args.datafiles)
    qrelf = args.qrels.name
    qrels = read_qrels_dict(qrelf)
    train_pairs_all = data.read_pairs_dict(args.train_pairs)
    valid_run = read_run_dict(args.valid_run.name)
    max_query_val = args.max_query_val
    query_ids = list(valid_run.keys())
    if max_query_val > 0:
        query_ids = query_ids[0:max_query_val]
        valid_run = {k: valid_run[k] for k in query_ids}

    print('# of eval. queries:', len(query_ids), ' in the file',
          args.valid_run.name)

    device_qty = args.device_qty
    master_port = args.master_port
    if device_qty > 1:
        if master_port is None:
            print('Specify a master port for distributed training!')
            sys.exit(1)

    processes = []

    is_distr_train = device_qty > 1

    qids = []

    if is_distr_train:
        qids = list(train_pairs_all.keys())

    sync_barrier = Barrier(device_qty)

    # We must go in the reverse direction, b/c
    # rank == 0 trainer is in the same process and
    # we call the function do_train in the same process,
    # i.e., this call is blocking processing and
    # prevents other processes from starting.
    for rank in range(device_qty - 1, -1, -1):
        if is_distr_train:
            device_name = f'cuda:{rank}'
        else:
            device_name = args.device_name
            if args.no_cuda:
                device_name = DEVICE_CPU

        # When we have only a single GPP, the main process is its own master
        is_master_proc = rank == 0

        train_params = TrainParams(
            init_lr=args.init_lr,
            init_bert_lr=args.init_bert_lr,
            momentum=args.momentum,
            warmup_pct=args.warmup_pct,
            batch_sync_qty=args.batch_sync_qty,
            epoch_lr_decay=args.epoch_lr_decay,
            weight_decay=args.weight_decay,
            backprop_batch_size=args.backprop_batch_size,
            batches_per_train_epoch=args.batches_per_train_epoch,
            save_epoch_snapshots=args.save_epoch_snapshots,
            save_last_snapshot_every_k_batch=args.
            save_last_snapshot_every_k_batch,
            batch_size=args.batch_size,
            batch_size_val=args.batch_size_val,
            max_query_len=args.max_query_len,
            max_doc_len=args.max_doc_len,
            epoch_qty=args.epoch_qty,
            device_name=device_name,
            use_external_eval=args.use_external_eval,
            eval_metric=args.eval_metric.lower(),
            print_grads=args.print_grads,
            shuffle_train=not args.no_shuffle_train,
            valid_type=args.valid_type,
            optim=args.optim)

        train_pair_qty = len(train_pairs_all)
        if is_distr_train or train_pair_qty < device_qty:
            tpart_qty = int((train_pair_qty + device_qty - 1) / device_qty)
            train_start = rank * tpart_qty
            train_end = min(train_start + tpart_qty, len(qids))
            train_pairs = {
                k: train_pairs_all[k]
                for k in qids[train_start:train_end]
            }
        else:
            train_pairs = train_pairs_all
        print('Process rank %d device %s using %d training pairs out of %d' %
              (rank, device_name, len(train_pairs), train_pair_qty))

        valid_checkpoints = [] if args.valid_checkpoints is None \
                            else list(map(int, args.valid_checkpoints.split(',')))
        param_dict = {
            'sync_barrier': sync_barrier,
            'device_qty': device_qty,
            'master_port': master_port,
            'rank': rank,
            'is_master_proc': is_master_proc,
            'dataset': dataset,
            'qrels': qrels,
            'qrel_file_name': qrelf,
            'train_pairs': train_pairs,
            'valid_run': valid_run,
            'valid_run_dir': args.valid_run_dir,
            'valid_checkpoints': valid_checkpoints,
            'model_out_dir': args.model_out_dir,
            'model': model,
            'loss_obj': loss_obj,
            'train_params': train_params
        }

        if is_distr_train and not is_master_proc:
            p = Process(target=do_train, kwargs=param_dict)
            p.start()
            processes.append(p)
        else:
            do_train(**param_dict)

    for p in processes:
        utils.join_and_check_stat(p)

    if device_qty > 1:
        dist.destroy_process_group()