def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument(
        "--bert_model",
        default='bert-base-uncased',
        type=str,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
        "bert-base-multilingual-cased, bert-base-chinese.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )
    parser.add_argument(
        '--task',
        type=str,
        default=None,
        required=True,
        help="Task code in {hotpot_open, hotpot_distractor, squad, nq, ambigqa}"
    )

    ## Other parameters
    parser.add_argument(
        "--max_seq_length",
        default=378,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=1,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=5,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam. (def: 5e-5)")
    parser.add_argument("--num_train_epochs",
                        default=5.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )

    # RNN graph retriever-specific parameters
    parser.add_argument("--example_limit", default=None, type=int)

    parser.add_argument("--max_para_num", default=10, type=int)
    parser.add_argument(
        "--neg_chunk",
        default=8,
        type=int,
        help=
        "The chunk size of negative examples during training (to reduce GPU memory consumption with negative sampling)"
    )
    parser.add_argument(
        "--eval_chunk",
        default=100000,
        type=int,
        help=
        "The chunk size of evaluation examples (to reduce RAM consumption during evaluation)"
    )
    parser.add_argument(
        "--split_chunk",
        default=300,
        type=int,
        help=
        "The chunk size of BERT encoding during inference (to reduce GPU memory consumption)"
    )

    parser.add_argument('--train_file_path',
                        type=str,
                        default=None,
                        help="File path to the training data")
    parser.add_argument('--dev_file_path',
                        type=str,
                        default=None,
                        help="File path to the eval data")

    parser.add_argument('--beam', type=int, default=1, help="Beam size")
    parser.add_argument('--min_select_num',
                        type=int,
                        default=1,
                        help="Minimum number of selected paragraphs")
    parser.add_argument('--max_select_num',
                        type=int,
                        default=3,
                        help="Maximum number of selected paragraphs")
    parser.add_argument(
        "--use_redundant",
        action='store_true',
        help="Whether to use simulated seqs (only for training)")
    parser.add_argument(
        "--use_multiple_redundant",
        action='store_true',
        help="Whether to use multiple simulated seqs (only for training)")
    parser.add_argument(
        '--max_redundant_num',
        type=int,
        default=100000,
        help=
        "The maximum number of redundant (simulated) sequences to use per example (only for training)"
    )
    parser.add_argument(
        "--no_links",
        action='store_true',
        help=
        "Whether to omit any links (or in other words, only use TF-IDF-based paragraphs)"
    )
    parser.add_argument("--pruning_by_links",
                        action='store_true',
                        help="Whether to do pruning by links (and top 1)")
    parser.add_argument(
        "--expand_links",
        action='store_true',
        help=
        "Whether to expand links with paragraphs in the same article (for NQ)")
    parser.add_argument(
        '--tfidf_limit',
        type=int,
        default=None,
        help=
        "Whether to limit the number of the initial TF-IDF pool (only for open-domain eval)"
    )

    parser.add_argument("--pred_file",
                        default=None,
                        type=str,
                        help="File name to write paragraph selection results")
    parser.add_argument("--tagme",
                        action='store_true',
                        help="Whether to use tagme at inference")
    parser.add_argument(
        '--topk',
        type=int,
        default=2,
        help="Whether to use how many paragraphs from the previous steps")

    parser.add_argument(
        "--model_suffix",
        default=None,
        type=str,
        help="Suffix to load a model file ('pytorch_model_' + suffix +'.bin')")

    parser.add_argument("--db_save_path",
                        default=None,
                        type=str,
                        help="File path to DB")

    args = parser.parse_args()
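
    # A hypothetical invocation of this script (the script name and data paths
    # are illustrative, not taken from the repository; all flags are defined above):
    #
    #   python graph_retriever.py \
    #       --task hotpot_open \
    #       --bert_model bert-base-uncased --do_lower_case \
    #       --train_file_path data/hotpot_train.json \
    #       --output_dir out/graph_retriever \
    #       --max_para_num 10 --neg_chunk 8 \
    #       --train_batch_size 4 --gradient_accumulation_steps 4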

    cpu = torch.device('cpu')

    device = torch.device(
        "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    n_gpu = torch.cuda.device_count()

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = int(args.train_batch_size /
                                args.gradient_accumulation_steps)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if args.train_file_path is not None:
        do_train = True

        if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
            raise ValueError(
                "Output directory ({}) already exists and is not empty.".
                format(args.output_dir))
        os.makedirs(args.output_dir, exist_ok=True)

    elif args.dev_file_path is not None:
        do_train = False

    else:
        raise ValueError(
            'One of train_file_path: {} or dev_file_path: {} must be non-None'.
            format(args.train_file_path, args.dev_file_path))

    processor = DataProcessor()

    # Configurations of the graph retriever
    graph_retriever_config = GraphRetrieverConfig(
        example_limit=args.example_limit,
        task=args.task,
        max_seq_length=args.max_seq_length,
        max_select_num=args.max_select_num,
        max_para_num=args.max_para_num,
        tfidf_limit=args.tfidf_limit,
        train_file_path=args.train_file_path,
        use_redundant=args.use_redundant,
        use_multiple_redundant=args.use_multiple_redundant,
        max_redundant_num=args.max_redundant_num,
        dev_file_path=args.dev_file_path,
        beam=args.beam,
        min_select_num=args.min_select_num,
        no_links=args.no_links,
        pruning_by_links=args.pruning_by_links,
        expand_links=args.expand_links,
        eval_chunk=args.eval_chunk,
        tagme=args.tagme,
        topk=args.topk,
        db_save_path=args.db_save_path)

    logger.info(graph_retriever_config)

    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)

    ##############################
    # Training                   #
    ##############################
    if do_train:
        model = BertForGraphRetriever.from_pretrained(
            args.bert_model,
            cache_dir=PYTORCH_PRETRAINED_BERT_CACHE /
            'distributed_{}'.format(-1),
            graph_retriever_config=graph_retriever_config)

        model.to(device)

        if n_gpu > 1:
            print("Parallel Training.")
            model = torch.nn.DataParallel(model)

        global_step = 0
        nb_tr_steps = 0
        tr_loss = 0

        POSITIVE = 1.0
        NEGATIVE = 0.0

        # Load training examples
        train_examples = None
        num_train_steps = None
        train_examples = processor.get_train_examples(graph_retriever_config)
        train_features = convert_examples_to_features(train_examples,
                                                      args.max_seq_length,
                                                      args.max_para_num,
                                                      graph_retriever_config,
                                                      tokenizer,
                                                      train=True)
        # len(train_examples) and len(train_features) can be different, depending on the redundant setting
        num_train_steps = int(
            len(train_features) / args.train_batch_size /
            args.gradient_accumulation_steps * args.num_train_epochs)

        # Prepare optimizer
        param_optimizer = list(model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            0.01
        }, {
            'params':
            [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            'weight_decay':
            0.0
        }]
        t_total = num_train_steps

        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=t_total,
                             max_grad_norm=1.0)

        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_features))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_steps)

        model.train()
        epc = 0
        for _ in range(int(args.num_train_epochs)):
            logger.info('Epoch ' + str(epc + 1))

            TOTAL_NUM = len(train_features)
            train_start_index = 0
            CHUNK_NUM = 4  # this doesn't matter for performance
            train_chunk = TOTAL_NUM // CHUNK_NUM
            chunk_index = 0

            random.shuffle(train_features)

            save_retry = False

            while train_start_index < TOTAL_NUM:
                train_end_index = min(train_start_index + train_chunk - 1,
                                      TOTAL_NUM - 1)
                chunk_len = train_end_index - train_start_index + 1

                train_features_ = train_features[
                    train_start_index:train_start_index + chunk_len]

                all_input_ids = torch.tensor(
                    [f.input_ids for f in train_features_], dtype=torch.long)
                all_input_masks = torch.tensor(
                    [f.input_masks for f in train_features_], dtype=torch.long)
                all_segment_ids = torch.tensor(
                    [f.segment_ids for f in train_features_], dtype=torch.long)
                all_output_masks = torch.tensor(
                    [f.output_masks for f in train_features_],
                    dtype=torch.float)
                all_num_paragraphs = torch.tensor(
                    [f.num_paragraphs for f in train_features_],
                    dtype=torch.long)
                all_num_steps = torch.tensor(
                    [f.num_steps for f in train_features_], dtype=torch.long)
                train_data = TensorDataset(all_input_ids, all_input_masks,
                                           all_segment_ids, all_output_masks,
                                           all_num_paragraphs, all_num_steps)

                train_sampler = RandomSampler(train_data)
                train_dataloader = DataLoader(train_data,
                                              sampler=train_sampler,
                                              batch_size=args.train_batch_size)

                tr_loss = 0
                nb_tr_examples, nb_tr_steps = 0, 0
                logger.info('Examples from ' + str(train_start_index) +
                            ' to ' + str(train_end_index))
                for step, batch in enumerate(
                        tqdm(train_dataloader, desc="Iteration")):
                    input_masks = batch[1]
                    batch_max_len = input_masks.sum(dim=2).max().item()

                    num_paragraphs = batch[4]
                    batch_max_para_num = num_paragraphs.max().item()

                    num_steps = batch[5]
                    batch_max_steps = num_steps.max().item()

                    output_masks_cpu = (
                        batch[3])[:, :batch_max_steps, :batch_max_para_num + 1]

                    batch = tuple(t.to(device) for t in batch)
                    input_ids, input_masks, segment_ids, output_masks, _, __ = batch
                    B = input_ids.size(0)

                    input_ids = input_ids[:, :batch_max_para_num, :
                                          batch_max_len]
                    input_masks = input_masks[:, :batch_max_para_num, :
                                              batch_max_len]
                    segment_ids = segment_ids[:, :batch_max_para_num, :
                                              batch_max_len]
                    output_masks = output_masks[:, :batch_max_steps, :
                                                batch_max_para_num +
                                                1]  # 1 for EOE

                    target = torch.FloatTensor(output_masks.size()).fill_(
                        NEGATIVE)  # (B, NUM_STEPS, |P|+1) <- 1 for EOE
                    for i in range(B):
                        output_masks[i, :num_steps[i], -1] = 1.0  # for EOE

                        for j in range(num_steps[i].item() - 1):
                            target[i, j, j].fill_(
                                POSITIVE
                            )  # positive paragraphs are stored in order of the right path

                        target[i, num_steps[i] - 1, -1].fill_(POSITIVE)  # EOE
                    target = target.to(device)

                    neg_start = batch_max_steps - 1
                    while neg_start < batch_max_para_num:
                        neg_end = min(neg_start + args.neg_chunk - 1,
                                      batch_max_para_num - 1)
                        neg_len = (neg_end - neg_start + 1)

                        input_ids_ = torch.cat(
                            (input_ids[:, :batch_max_steps - 1, :],
                             input_ids[:, neg_start:neg_start + neg_len, :]),
                            dim=1)
                        input_masks_ = torch.cat(
                            (input_masks[:, :batch_max_steps - 1, :],
                             input_masks[:, neg_start:neg_start + neg_len, :]),
                            dim=1)
                        segment_ids_ = torch.cat(
                            (segment_ids[:, :batch_max_steps - 1, :],
                             segment_ids[:, neg_start:neg_start + neg_len, :]),
                            dim=1)
                        output_masks_ = torch.cat(
                            (output_masks[:, :, :batch_max_steps - 1],
                             output_masks[:, :, neg_start:neg_start + neg_len],
                             output_masks[:, :, batch_max_para_num:
                                          batch_max_para_num + 1]),
                            dim=2)
                        target_ = torch.cat(
                            (target[:, :, :batch_max_steps - 1],
                             target[:, :, neg_start:neg_start + neg_len],
                             target[:, :,
                                    batch_max_para_num:batch_max_para_num +
                                    1]),
                            dim=2)

                        if neg_start != batch_max_steps - 1:
                            output_masks_[:, :, :batch_max_steps - 1] = 0.0
                            output_masks_[:, :, -1] = 0.0

                        loss = model(input_ids_, segment_ids_, input_masks_,
                                     output_masks_, target_, batch_max_steps)

                        if n_gpu > 1:
                            loss = loss.mean(
                            )  # mean() to average on multi-gpu.
                        if args.gradient_accumulation_steps > 1:
                            loss = loss / args.gradient_accumulation_steps

                        loss.backward()
                        tr_loss += loss.item()
                        neg_start = neg_end + 1

                    nb_tr_examples += B
                    nb_tr_steps += 1
                    if (step + 1) % args.gradient_accumulation_steps == 0:
                        # modify learning rate with special warm up BERT uses
                        lr_this_step = args.learning_rate * warmup_linear(
                            global_step / t_total, args.warmup_proportion)
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                        optimizer.step()
                        optimizer.zero_grad()
                        global_step += 1

                chunk_index += 1
                train_start_index = train_end_index + 1

                # Save the model halfway through the epoch
                if (chunk_index == CHUNK_NUM // 2 or save_retry):
                    status = save(model, args.output_dir, str(epc + 0.5))
                    save_retry = (not status)

                del all_input_ids
                del all_input_masks
                del all_segment_ids
                del all_output_masks
                del all_num_paragraphs
                del all_num_steps
                del train_data

            # Save the model at the end of the epoch
            save(model, args.output_dir, str(epc + 1))

            epc += 1

    if do_train:
        return

    ##############################
    # Evaluation                 #
    ##############################
    assert args.model_suffix is not None

    if graph_retriever_config.db_save_path is not None:
        import sys
        sys.path.append('../')
        from pipeline.tfidf_retriever import TfidfRetriever
        tfidf_retriever = TfidfRetriever(graph_retriever_config.db_save_path,
                                         None)
    else:
        tfidf_retriever = None

    model_state_dict = load(args.output_dir, args.model_suffix)

    model = BertForGraphRetriever.from_pretrained(
        args.bert_model,
        state_dict=model_state_dict,
        graph_retriever_config=graph_retriever_config)
    model.to(device)

    model.eval()

    if args.pred_file is not None:
        pred_output = []

    eval_examples = processor.get_dev_examples(graph_retriever_config)

    logger.info("***** Running evaluation *****")
    logger.info("  Num examples = %d", len(eval_examples))
    logger.info("  Batch size = %d", args.eval_batch_size)

    TOTAL_NUM = len(eval_examples)
    eval_start_index = 0

    while eval_start_index < TOTAL_NUM:
        eval_end_index = min(
            eval_start_index + graph_retriever_config.eval_chunk - 1,
            TOTAL_NUM - 1)
        chunk_len = eval_end_index - eval_start_index + 1

        eval_features = convert_examples_to_features(
            eval_examples[eval_start_index:eval_start_index + chunk_len],
            args.max_seq_length, args.max_para_num, graph_retriever_config,
            tokenizer)

        all_input_ids = torch.tensor([f.input_ids for f in eval_features],
                                     dtype=torch.long)
        all_input_masks = torch.tensor([f.input_masks for f in eval_features],
                                       dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features],
                                       dtype=torch.long)
        all_output_masks = torch.tensor(
            [f.output_masks for f in eval_features], dtype=torch.float)
        all_num_paragraphs = torch.tensor(
            [f.num_paragraphs for f in eval_features], dtype=torch.long)
        all_num_steps = torch.tensor([f.num_steps for f in eval_features],
                                     dtype=torch.long)
        all_ex_indices = torch.tensor([f.ex_index for f in eval_features],
                                      dtype=torch.long)
        eval_data = TensorDataset(all_input_ids, all_input_masks,
                                  all_segment_ids, all_output_masks,
                                  all_num_paragraphs, all_num_steps,
                                  all_ex_indices)

        # Run prediction for full data
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data,
                                     sampler=eval_sampler,
                                     batch_size=args.eval_batch_size)

        for input_ids, input_masks, segment_ids, output_masks, num_paragraphs, num_steps, ex_indices in tqdm(
                eval_dataloader, desc="Evaluating"):
            batch_max_len = input_masks.sum(dim=2).max().item()
            batch_max_para_num = num_paragraphs.max().item()

            batch_max_steps = num_steps.max().item()

            input_ids = input_ids[:, :batch_max_para_num, :batch_max_len]
            input_masks = input_masks[:, :batch_max_para_num, :batch_max_len]
            segment_ids = segment_ids[:, :batch_max_para_num, :batch_max_len]
            output_masks = output_masks[:, :batch_max_para_num +
                                        2, :batch_max_para_num + 1]
            output_masks[:, 1:, -1] = 1.0  # Ignore EOE in the first step

            input_ids = input_ids.to(device)
            input_masks = input_masks.to(device)
            segment_ids = segment_ids.to(device)
            output_masks = output_masks.to(device)

            examples = [
                eval_examples[eval_start_index + ex_indices[i].item()]
                for i in range(input_ids.size(0))
            ]

            with torch.no_grad():
                pred, prob, topk_pred, topk_prob = model.beam_search(
                    input_ids,
                    segment_ids,
                    input_masks,
                    examples=examples,
                    tokenizer=tokenizer,
                    retriever=tfidf_retriever,
                    split_chunk=args.split_chunk)

            for i in range(len(pred)):
                e = examples[i]
                titles = [e.title_order[p] for p in pred[i]]

                # Output predictions to a file
                if args.pred_file is not None:
                    pred_output.append({})
                    pred_output[-1]['q_id'] = e.guid

                    pred_output[-1]['titles'] = titles
                    pred_output[-1]['probs'] = []
                    for prob_ in prob[i]:
                        entry = {'EOE': prob_[-1]}
                        for j in range(len(e.title_order)):
                            entry[e.title_order[j]] = prob_[j]
                        pred_output[-1]['probs'].append(entry)

                    topk_titles = [[e.title_order[p] for p in topk_pred[i][j]]
                                   for j in range(len(topk_pred[i]))]
                    pred_output[-1]['topk_titles'] = topk_titles

                    topk_probs = []
                    for k in range(len(topk_prob[i])):
                        topk_probs.append([])
                        for prob_ in topk_prob[i][k]:
                            entry = {'EOE': prob_[-1]}
                            for j in range(len(e.title_order)):
                                entry[e.title_order[j]] = prob_[j]
                            topk_probs[-1].append(entry)
                    pred_output[-1]['topk_probs'] = topk_probs

                    # Output the selected paragraphs
                    context = {}
                    for ts in topk_titles:
                        for t in ts:
                            context[t] = e.all_paras[t]
                    pred_output[-1]['context'] = context

        eval_start_index = eval_end_index + 1

        del eval_features
        del all_input_ids
        del all_input_masks
        del all_segment_ids
        del all_output_masks
        del all_num_paragraphs
        del all_num_steps
        del all_ex_indices
        del eval_data

    if args.pred_file is not None:
        json.dump(pred_output, open(args.pred_file, 'w'))
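
The example above calls a few helpers that are not shown in this snippet. Below is a minimal sketch of what they could look like, assuming the warmup_linear schedule from pytorch-pretrained-bert and checkpoint files named 'pytorch_model_' + suffix + '.bin' (as the --model_suffix help string suggests); the actual repository definitions may differ.

def warmup_linear(x, warmup=0.002):
    # Linear warmup followed by linear decay (pytorch-pretrained-bert style).
    if x < warmup:
        return x / warmup
    return 1.0 - x


def save(model, output_dir, suffix):
    # Save a (possibly DataParallel-wrapped) model; return True on success so the
    # caller can retry, matching how save_retry is used above.
    model_to_save = model.module if hasattr(model, 'module') else model
    try:
        torch.save(model_to_save.state_dict(),
                   os.path.join(output_dir, 'pytorch_model_' + suffix + '.bin'))
        return True
    except Exception:
        return False


def load(output_dir, suffix):
    # Load a state dict written by save() onto the CPU.
    return torch.load(
        os.path.join(output_dir, 'pytorch_model_' + suffix + '.bin'),
        map_location='cpu')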
Example #2
def train(bundles, model1, device, mode, model2, batch_size, num_epoch,
          gradient_accumulation_steps, lr1, lr2, alpha):
    '''Train Sys1 and Sys2 models.
    
    Train models on task #1 (tensors) and task #2 (bundle).
    
    Args:
        bundles (list): List of bundles.
        model1 (BertForMultiHopQuestionAnswering): System 1 model.
        device (torch.device): The device which models and data are on.
        mode (str): Defaults to 'tensors'. Task identifier ('tensors' or 'bundle').
        model2 (CognitiveGNN): System 2 model.
        batch_size (int): Defaults to 4.
        num_epoch (int): Defaults to 1.
        gradient_accumulation_steps (int): Defaults to 1. 
        lr1 (float): Defaults to 1e-4. Learning rate for Sys1.
        lr2 (float): Defaults to 1e-4. Learning rate for Sys2.
        alpha (float): Defaults to 0.2. Balance factor for loss of two systems.
    
    Returns:
        (BertForMultiHopQuestionAnswering, CognitiveGNN): The trained Sys1 and Sys2 models.
    '''

    # Prepare optimizer for Sys1
    param_optimizer = list(model1.named_parameters())
    # hack to remove pooler, which is not used.
    param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]
    num_batch, dataloader = homebrew_data_loader(bundles,
                                                 mode=mode,
                                                 batch_size=batch_size)
    num_steps = num_batch * num_epoch
    global_step = 0
    opt1 = BertAdam(optimizer_grouped_parameters,
                    lr=lr1,
                    warmup=0.1,
                    t_total=num_steps)
    model1.to(device)
    model1.train()

    # Prepare optimizer for Sys2
    if mode == 'bundle':
        opt2 = Adam(model2.parameters(), lr=lr2)
        model2.to(device)
        model2.train()
        warmed = False  # warm-up flag for joint training

    for epoch in trange(num_epoch, desc='Epoch'):
        ans_mean, hop_mean = WindowMean(), WindowMean()
        opt1.zero_grad()
        if mode == 'bundle':
            final_mean = WindowMean()
            opt2.zero_grad()
        tqdm_obj = tqdm(dataloader, total=num_batch)

        for step, batch in enumerate(tqdm_obj):
            try:
                if mode == 'tensors':
                    batch = tuple(t.to(device) for t in batch)
                    hop_loss, ans_loss, pooled_output = model1(*batch)
                    hop_loss, ans_loss = hop_loss.mean(), ans_loss.mean()
                    pooled_output.detach()
                    loss = ans_loss + hop_loss
                elif mode == 'bundle':
                    hop_loss, ans_loss, final_loss = model2(
                        batch, model1, device)
                    hop_loss, ans_loss = hop_loss.mean(), ans_loss.mean()
                    loss = ans_loss + hop_loss + alpha * final_loss
                loss.backward()

                if (step + 1) % gradient_accumulation_steps == 0:
                    # modify learning rate with special warm up BERT uses. From BERT pytorch examples
                    lr_this_step = lr1 * warmup_linear(global_step / num_steps,
                                                       warmup=0.1)
                    for param_group in opt1.param_groups:
                        param_group['lr'] = lr_this_step
                    global_step += 1
                    if mode == 'bundle':
                        opt2.step()
                        opt2.zero_grad()
                        final_mean_loss = final_mean.update(final_loss.item())
                        tqdm_obj.set_description(
                            'ans_loss: {:.2f}, hop_loss: {:.2f}, final_loss: {:.2f}'
                            .format(ans_mean.update(ans_loss.item()),
                                    hop_mean.update(hop_loss.item()),
                                    final_mean_loss))
                        # During warming period, model1 is frozen and model2 is trained to normal weights
                        if final_mean_loss < 0.9 and step > 100:  # ugly manual hyperparam
                            warmed = True
                        if warmed:
                            opt1.step()
                        opt1.zero_grad()
                    else:
                        opt1.step()
                        opt1.zero_grad()
                        tqdm_obj.set_description(
                            'ans_loss: {:.2f}, hop_loss: {:.2f}'.format(
                                ans_mean.update(ans_loss.item()),
                                hop_mean.update(hop_loss.item())))
                    if step % 1000 == 0:
                        output_model_file = './models/bert-base-uncased.bin.tmp'
                        saved_dict = {'params1': model1.module.state_dict()}
                        saved_dict['params2'] = model2.state_dict()
                        torch.save(saved_dict, output_model_file)
            except Exception as err:
                traceback.print_exc()
                if mode == 'bundle':
                    print(batch._id)
    return (model1, model2)
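
WindowMean is used here (and in the later examples) only to smooth the losses shown in the progress bar; it is not defined in this snippet. A minimal sketch consistent with how update() is called, assuming a fixed-size window:

from collections import deque

class WindowMean:
    def __init__(self, window_size=50):
        # window_size is an assumption; only the running-mean behaviour matters.
        self.values = deque(maxlen=window_size)

    def update(self, value):
        # Append the newest value and return the mean over the current window.
        self.values.append(value)
        return sum(self.values) / len(self.values)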
Example #3
def train(training_data_file, valid_data_file, super_batch_size, tokenizer, mode, kw, p_key, model1, device, model2, model3, \
            batch_size, num_epoch, gradient_accumulation_steps, lr1, lr2, lambda_, valid_critic, early_stop):
    '''Train three models
    
    Train models through bundles
    
    Args:
        training_data_file (str) : path to the training data json file (raw json used to load the data)
        valid_data_file (str) : path to the validation data json file
        super_batch_size (int) : how many samples will be loaded into memory at once
        tokenizer : SentencePiece tokenizer used to obtain the token ids
        mode (str): mode of the passage format, could be a list (processed) or a long string (unprocessed).
        kw (str) : the key that maps to the passage in each data dictionary. Defaults to 'abstract'
        p_key (str) : the key used to look up a specific passage. Defaults to 'title'
        model1 (nn.DataParallel) : local dependency encoder
        device (torch.device): The device which models and data are on.
        model2 (nn.Module): global coherence encoder
        model3 (nn.Module): attention decoder
        batch_size (int): Defaults to 4.
        num_epoch (int): Defaults to 1.
        gradient_accumulation_steps (int): Defaults to 1. 
        lr1, lr2 (float): Default to 1e-4. The initial learning rates for the two optimizers.
        lambda_ (float): Defaults to 0.01. Weight-decay factor for parameter regularization.
        valid_critic (str) : which metric to use for early-stop evaluation ('rouge-w', 'acc', 'ken-tau', or 'pmr').
        early_stop (int) : the early-stop patience in epochs. Defaults to 5.

    '''

    # Prepare optimizer for Sys1
    param_optimizer_bert = list(model1.named_parameters())
    param_optimizer_others = list(model2.named_parameters()) + list(
        model3.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    # We tend to fix the embedding. For now we have not located the embedding layer.
    optimizer_grouped_parameters_bert = [{
        'params': [
            p for n, p in param_optimizer_bert
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        lambda_
    }, {
        'params': [
            p for n, p in param_optimizer_bert
            if any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        0.0
    }]

    optimizer_grouped_parameters_others = [{
        'params': [
            p for n, p in param_optimizer_others
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        lambda_
    }, {
        'params': [
            p for n, p in param_optimizer_others
            if any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        0.0
    }]
    # We should add a module to count the number of parameters here
    critic = nn.NLLLoss(reduction='none')

    line_num = int(os.popen("wc -l " + training_data_file).read().split()[0])
    global_step = 0  # global step
    opt1 = BertAdam(optimizer_grouped_parameters_bert,
                    lr=lr1,
                    warmup=0.1,
                    t_total=line_num / batch_size * num_epoch)  # optimizer 1
    # opt = Adam(optimizer_grouped_parameter, lr=lr)
    opt2 = Adadelta(optimizer_grouped_parameters_others, lr=lr2, rho=0.95)
    model1.to(device)
    model1.train()
    model2.to(device)
    model2.train()
    model3.to(device)
    model3.train()
    warmed = True
    for epoch in trange(num_epoch, desc='Epoch'):

        smooth_mean = WindowMean()
        opt1.zero_grad()
        opt2.zero_grad()

        for superbatch, line_num in load_superbatch(training_data_file,
                                                    super_batch_size):
            bundles = []

            for data in superbatch:
                try:
                    bundles.append(
                        convert_passage_to_samples_bundle(
                            tokenizer, data, mode, kw, p_key))

                except:
                    print_exc()

            num_batch, dataloader = homebrew_data_loader(bundles,
                                                         batch_size=batch_size)

            tqdm_obj = tqdm(dataloader, total=num_batch)
            num_steps = line_num
            for step, batch in enumerate(tqdm_obj):
                try:
                    #batch[0] = batch[0].to(device)
                    #batch[1] = batch[1].to(device)
                    #batch[2] = batch[2].to(device)
                    batch = tuple(t for t in batch)
                    log_prob_loss, pointers_output, ground_truth = calculate_loss(
                        batch, model1, model2, model3, device, critic)
                    # here we compute rouge-w, acc, kendall's tau, and pmr for this batch
                    rouge_ws = []
                    accs = []
                    ken_taus = []
                    pmrs = []
                    for pred, true in zip(pointers_output, ground_truth):
                        rouge_ws.append(rouge_w(pred, true))
                        accs.append(acc(pred, true))
                        ken_taus.append(kendall_tau(pred, true))
                        pmrs.append(pmr(pred, true))

                    log_prob_loss.backward()

                    # ******** In the following code we will add early stopping ************

                    if (step + 1) % gradient_accumulation_steps == 0:
                        # modify learning rate with special warm up BERT uses. From BERT pytorch examples
                        lr_this_step = lr1 * warmup_linear(
                            global_step / num_steps, warmup=0.1)
                        for param_group in opt1.param_groups:
                            param_group['lr'] = lr_this_step
                        global_step += 1

                        opt2.step()
                        opt2.zero_grad()
                        smooth_mean_loss = smooth_mean.update(
                            log_prob_loss.item())
                        tqdm_obj.set_description(
                            '{}: {:.4f}, {}: {:.4f}, smooth_mean_loss: {:.4f}'.
                            format('accuracy', np.mean(accs), 'rouge-w',
                                   np.mean(rouge_ws), smooth_mean_loss))
                        # During warming period, model1 is frozen and model2 is trained to normal weights
                        if smooth_mean_loss < 1.0 and step > 100:  # ugly manual hyperparam
                            warmed = True
                        if warmed:
                            opt1.step()
                        opt1.zero_grad()
                        if step % 1000 == 0:
                            output_model_file = './models/bert-base-cased.bin.tmp'
                            saved_dict = {
                                'params1': model1.module.state_dict()
                            }
                            saved_dict['params2'] = model2.state_dict()
                            saved_dict['params3'] = model3.state_dict()
                            torch.save(saved_dict, output_model_file)

                except Exception as err:
                    traceback.print_exc()
                    exit()
                    # if mode == 'list':
                    #     print(batch._id)

        if epoch < 5:
            best_score = 0
            continue

        with torch.no_grad():
            print('valid..............')

            valid_critic_dict = {
                'rouge-w': rouge_w,
                'acc': acc,
                'ken-tau': kendall_tau,
                'pmr': pmr
            }

            for superbatch, _ in load_superbatch(valid_data_file,
                                                 super_batch_size):
                bundles = []

                for data in superbatch:
                    try:
                        bundles.append(
                            convert_passage_to_samples_bundle(
                                tokenizer, data, mode, kw, p_key))
                    except:
                        print_exc()

                num_batch, valid_dataloader = homebrew_data_loader(
                    bundles, batch_size=1)

                valid_value = []
                for step, batch in enumerate(valid_dataloader):
                    try:
                        batch = tuple(t for idx, t in enumerate(batch))
                        pointers_output, ground_truth \
                            = dev_test(batch, model1, model2, model3, device)
                        valid_value.append(valid_critic_dict[valid_critic](
                            pointers_output, ground_truth))

                    except Exception as err:
                        traceback.print_exc()
                        # if mode == 'list':
                        #     print(batch._id)

                score = np.mean(valid_value)
            print('epc:{}, {} : {:.2f} best : {:.2f}\n'.format(
                epoch, valid_critic, score, best_score))

            if score > best_score:
                best_score = score
                best_iter = epoch

                print('Saving model to {}'.format(
                    output_model_file))  # save model structure
                saved_dict = {
                    'params1': model1.module.state_dict()
                }  # save parameters
                saved_dict['params2'] = model2.state_dict()  # save parameters
                saved_dict['params3'] = model3.state_dict()
                torch.save(saved_dict, output_model_file)  #

                # print('save best model at epc={}'.format(epc))
                # checkpoint = {'model': model.state_dict(),
                #             'args': args,
                #             'loss': best_score}
                # torch.save(checkpoint, '{}/{}.best.pt'.format(args.model_path, args.model))

            if early_stop and (epoch - best_iter) >= early_stop:
                print('early stop at epc {}'.format(epoch))
                break
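
The ordering metrics used above (acc, kendall_tau, pmr, and rouge_w) compare a predicted sentence order against the ground-truth order. A sketch of plausible implementations for the first three is given below (the exact definitions, and rouge_w in particular, are assumptions and may differ from the repository):

from scipy.stats import kendalltau

def acc(pred, true):
    # Fraction of positions whose predicted index matches the ground truth.
    return sum(p == t for p, t in zip(pred, true)) / len(true)

def kendall_tau(pred, true):
    # Rank correlation between the predicted and the true order.
    if len(true) < 2:
        return 1.0
    tau, _ = kendalltau(pred, true)
    return tau

def pmr(pred, true):
    # Perfect match ratio: 1.0 only if the whole sequence is ordered correctly.
    return float(list(pred) == list(true))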
Example #4
def train(bundles,
          model,
          device,
          batch_size=4,
          num_epoch=1,
          mode='tensors',
          model_cg=None,
          gradient_accumulation_steps=1,
          lr=1e-4):
    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    # hack to remove pooler, which is not used
    # thus it produces None grads that break apex
    param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]

    num_batch, dataloader = homebrew_data_loader(bundles,
                                                 mode=mode,
                                                 batch_size=batch_size)
    num_steps = num_batch * num_epoch
    global_step = 0
    opt = BertAdam(optimizer_grouped_parameters,
                   lr=lr,
                   warmup=0.1,
                   t_total=num_steps)
    if mode == 'bundle':
        opt_cg = Adam(model_cg.parameters(), lr=1e-4)  # TODO hyperparam
        model_cg.to(device)
        model_cg.train()
        warmed = False

    model.to(device)
    model.train()
    for epoch in trange(num_epoch, desc='Epoch'):
        ans_mean, hop_mean = WindowMean(), WindowMean()
        if mode == 'bundle':
            final_mean = WindowMean()
            opt_cg.zero_grad()
        opt.zero_grad()
        tqdm_obj = tqdm(dataloader, total=num_batch)
        for step, batch in enumerate(tqdm_obj):
            # torch.cuda.empty_cache()
            # gpu_tracker.track()
            try:
                if mode == 'tensors':
                    batch = tuple(t.to(device) for t in batch)
                    hop_loss, ans_loss, pooled_output = model(*batch)
                    hop_loss, ans_loss = hop_loss.mean(), ans_loss.mean()
                    pooled_output.detach()
                    loss = ans_loss + hop_loss
                elif mode == 'bundle':
                    hop_loss, ans_loss, final_loss = model_cg(
                        batch, model, device)
                    hop_loss, ans_loss = hop_loss.mean(), ans_loss.mean()
                    loss = ans_loss + hop_loss + 0.2 * final_loss
                # torch.cuda.empty_cache()
                # gpu_tracker.track()
                loss.backward()
                if (step + 1) % gradient_accumulation_steps == 0:
                    # modify learning rate with special warm up BERT uses
                    lr_this_step = lr * warmup_linear(global_step / num_steps,
                                                      warmup=0.1)
                    for param_group in opt.param_groups:
                        param_group['lr'] = lr_this_step
                    global_step += 1
                    if mode == 'bundle':
                        opt_cg.step()
                        opt_cg.zero_grad()
                        final_mean_loss = final_mean.update(final_loss.item())
                        tqdm_obj.set_description(
                            'ans_loss: {:.2f}, hop_loss: {:.2f}, final_loss: {:.2f}'
                            .format(ans_mean.update(ans_loss.item()),
                                    hop_mean.update(hop_loss.item()),
                                    final_mean_loss))
                        if final_mean_loss < 0.9 and step > 100:
                            warmed = True
                        if warmed:
                            opt.step()
                        opt.zero_grad()
                    else:
                        opt.step()
                        opt.zero_grad()
                        tqdm_obj.set_description(
                            'ans_loss: {:.2f}, hop_loss: {:.2f}'.format(
                                ans_mean.update(ans_loss.item()),
                                hop_mean.update(hop_loss.item())))
                    if step % 1000 == 0:
                        print('')
                        output_model_file = './models/bert-base-uncased.bin.tmp'
                        saved_dict = {'bert-params': model.module.state_dict()}
                        saved_dict['cg-params'] = model_cg.state_dict()
                        torch.save(saved_dict, output_model_file)
            except Exception as err:
                traceback.print_exc()
                if mode == 'bundle':
                    print(batch._id)
    return (model, model_cg)
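
The periodic checkpoints written above pack both networks into one dictionary ('bert-params' / 'cg-params'). Reloading them later is plain PyTorch; a minimal sketch, assuming the models are not wrapped in DataParallel at load time (the file path is the one used above):

saved_dict = torch.load('./models/bert-base-uncased.bin.tmp', map_location='cpu')
model.load_state_dict(saved_dict['bert-params'])    # Sys1 (BERT) weights
model_cg.load_state_dict(saved_dict['cg-params'])   # Sys2 (cognitive graph) weights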
Example #5
def train_model(args,dataset,model,device,D_model = None):
    logger = logging.getLogger("D-QA")
    param_optimizer = list(model.named_parameters())
    param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    random.shuffle(dataset[0])
    num_batch, dataloader = data_generator(args, dataset[0])
    num_steps = num_batch * args.epochs
    opt = BertAdam(optimizer_grouped_parameters, lr=args.learning_rate, warmup=0.1, t_total=num_steps)
    count_step = 0
    if args.mode == 'D-graph':
        opt_dg = Adam(D_model.parameters(), lr=args.learning_rate)
        D_model.to(device)
        D_model.train()
        warmed = False
    model.to(device)
    model.train()
    f1_s_max, em_s_max = -1, -1
    for epoch in trange(args.epochs,desc='Epoch'):
        ans_loss_mean, sup_loss_mean = WindowMean(),WindowMean()
        if args.mode == 'D-graph':
            node_loss_mean = WindowMean()
            opt_dg.zero_grad()
        opt.zero_grad()
        tqdm_obj = tqdm(dataloader, total=num_batch)
        for step, batch in enumerate(tqdm_obj):
            if args.mode == 'fine-tune':
                batch = tuple(t.to(device) for t in batch)
                sup_loss, ans_loss,_ = model(args,*batch)
                loss = ans_loss + sup_loss
            else:
                ans_loss, sup_loss, node_loss = D_model(args,batch, model, device)
                loss = ans_loss + sup_loss + 0.2 * node_loss

            loss.backward()
            if (step+1) % args.gradient_accumulation_steps == 0:
                lr_cur = args.learning_rate * warmup_linear(count_step / num_steps, warmup=0.1)
                for param_group in opt.param_groups:
                    param_group['lr'] = lr_cur
                count_step += 1
                if args.mode == 'D-graph':
                    opt_dg.step()
                    opt_dg.zero_grad()
                    node_mean_loss = node_loss_mean.update(node_loss.item())
                    ans_mean_loss = ans_loss_mean.update(ans_loss.item())
                    sup_mean_loss = sup_loss_mean.update(sup_loss.item())
                    logger.info('ans_loss: {:.2f}, sup_loss: {:.2f}, node_loss: {:.2f}'.format(ans_mean_loss,sup_mean_loss,node_mean_loss))
                    if node_mean_loss < 0.9 and step > 100:
                        warmed = True
                    if warmed:
                        opt.step()
                    opt.zero_grad()
                else:
                    opt.step()
                    opt.zero_grad()
                    ans_mean_loss = ans_loss_mean.update(ans_loss.item())
                    sup_mean_loss = sup_loss_mean.update(sup_loss.item())
                    logger.info('ans_loss: {:.2f}, sup_loss: {:.2f}'.format(ans_mean_loss, sup_mean_loss))
            else:
                if args.mode == 'D-graph':
                    node_loss_mean.update(node_loss.item())
                    ans_loss_mean.update(ans_loss.item())
                    sup_loss_mean.update(sup_loss.item())
                else:
                    ans_loss_mean.update(ans_loss.item())
                    sup_loss_mean.update(sup_loss.item())
            if args.mode == 'fine-tune':
                if step % 1000 == 0:
                    output_model_file = os.path.join(args.model_dir, 'bert-base-uncased.bin.tmp')
                    saved_dict = {'bert-params': model.module.state_dict()}
                    torch.save(saved_dict, output_model_file)
            else:
                if step % 1000 == 0:
                    metrics = evaluate_rd(args,dataset[1],model,D_model,device)
                    if metrics['joint_f1'] > f1_s_max:
                        f1_s_max = metrics['joint_f1']
                        output_model_file = './models/DQA_model.bin.tmp'
                        saved_dict = {'bert-params': model.module.state_dict()}
                        saved_dict['dg-params'] = D_model.state_dict()
                        torch.save(saved_dict, output_model_file)

    return (model,D_model)
Example #6
def fit(num_epoch=args['num_train_epochs']):
    global_step = 0
    model.train()
    for i_ in tqdm(range(int(num_epoch)), desc="Epoch"):
        print('Current stage ******************************', i_)
        tr_loss, tr_accuracy = 0, 0
        nb_tr_examples, nb_tr_steps = 0, 0
        for index, batch in enumerate(tqdm(train_dataloader,
                                           desc="Iteration")):

            batch = tuple(t.to(device) for t in batch)
            input_ids, input_mask, segment_ids, label_ids = batch

            try:
                logits = model(input_ids, segment_ids, input_mask, label_ids)
                tmp_train_loss = loss_fct(logits.view(-1, num_labels),
                                          label_ids.squeeze())
                tmp_train_accuracy = accuracy(
                    logits.view(-1, num_labels).detach().cpu().numpy(),
                    label_ids.squeeze().detach().cpu().numpy())
                if n_gpu > 1:
                    tmp_train_loss = tmp_train_loss.mean(
                    )  # mean() to average on multi-gpu.

                if args["local_rank"] != -1:
                    tmp_train_loss = reduce_tensor(tmp_train_loss)
                    tmp_train_accuracy = reduce_tensor(
                        torch.tensor(tmp_train_accuracy).to(device))

                tmp_train_loss = tmp_train_loss / args[
                    'gradient_accumulation_steps']
                with amp.scale_loss(tmp_train_loss, optimizer) as scaled_loss:
                    scaled_loss.backward()

                # if args['fp16']:
                #     optimizer.backward(tmp_train_loss)
                # else:
                #     tmp_train_loss.backward()

                if (index + 1) % args['gradient_accumulation_steps'] == 0:
                    optimizer.step()
                    optimizer.zero_grad()

                tr_loss += tmp_train_loss.item()
                tr_accuracy += tmp_train_accuracy.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1
                global_step += 1
            except RuntimeError as e:
                if 'out of memory' in str(e):
                    print('| WARNING: ran out of memory')
                    if hasattr(torch.cuda, 'empty_cache'):
                        torch.cuda.empty_cache()
                else:
                    raise e

            # Tensorboard Logging
            eval_loss, eval_accuracy = 0, 0
            if global_step % 100 == 0:
                eval_loss, eval_accuracy = eval()

                logger.info('tr_loss:{} & tr_accuracy:{}'.format(
                    tr_loss / nb_tr_steps, tr_accuracy / nb_tr_examples))
                logger.info('eval_loss:{} & eval_accuracy:{}'.format(
                    eval_loss, eval_accuracy))
                info = {
                    'tr_loss': tr_loss / nb_tr_steps,
                    'tr_accuracy': tr_accuracy / nb_tr_examples
                }
                for tag, value in info.items():
                    loggers.scalar_summary(tag, value, global_step + 1)
                info = {'eval_loss': eval_loss, 'eval_accuracy': eval_accuracy}
                for tag, value in info.items():
                    loggers.scalar_summary(tag, value, global_step + 1)

            # Save the model
            if global_step % 200 == 0:
                params.append(eval_accuracy)
                if eval_accuracy >= max(params):
                    if args["local_rank"] == -1:
                        model_to_save = model.module if hasattr(
                            model,
                            'module') else model  # Only save the model it-self
                        output_model_file = os.path.join(
                            model_path, "finetuned_pytorch_model.bin")
                        torch.save(model_to_save.state_dict(),
                                   output_model_file)
                    elif args["local_rank"] == 0:
                        checkpoint = {
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict(),
                            'amp': amp.state_dict()
                        }
                        output_model_file = os.path.join(
                            model_path, "amp_checkpoint.pt")
                        torch.save(checkpoint, output_model_file)
                    # model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
                    # output_model_file = os.path.join(model_path, "checkpoint.pt")
                    # torch.save({
                    #     'model': model_to_save.state_dict()
                    # }, output_model_file)

        if args["fp16"]:
            #             scheduler.batch_step()
            # modify learning rate with special warm up BERT uses
            lr_this_step = args['learning_rate'] * warmup_linear(
                global_step / t_total, args['warmup_proportion'])
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr_this_step
        else:
            scheduler.step()
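
This last example assumes NVIDIA apex mixed precision has already been set up before fit() is called, since it uses amp.scale_loss and amp.state_dict. A minimal sketch of the usual initialization (the opt_level is an assumption):

from apex import amp

model, optimizer = amp.initialize(model, optimizer, opt_level='O1')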