Example #1
def main():
    parser = argparse.ArgumentParser()
    BERT_DIR = "./model/uncased_L-12_H-768_A-12/"
    ## Required parameters
    parser.add_argument("--bert_config_file", default=BERT_DIR+"bert_config.json", \
                        type=str, help="The config json file corresponding to the pre-trained BERT model. "
                             "This specifies the model architecture.")
    parser.add_argument("--vocab_file", default=BERT_DIR+"vocab.txt", type=str, \
                        help="The vocabulary file that the BERT model was trained on.")
    parser.add_argument("--output_dir", default="out", type=str, \
                        help="The output directory where the model checkpoints will be written.")

    ## Other parameters
    parser.add_argument("--train_file", type=str, \
                        help="SQuAD json for training. E.g., train-v1.1.json", \
                        default="")
    parser.add_argument("--predict_file", type=str,
                        help="SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json", \
                        default="")
    parser.add_argument("--init_checkpoint", type=str,
                        help="Initial checkpoint (usually from a pre-trained BERT model).", \
                        default=BERT_DIR+"pytorch_model.bin")
    parser.add_argument(
        "--do_lower_case",
        default=True,
        action='store_true',
        help="Whether to lower case the input text. Should be True for uncased "
        "models and False for cased models.")
    parser.add_argument(
        "--max_seq_length",
        default=300,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. Sequences "
        "longer than this will be truncated, and sequences shorter than this will be padded."
    )
    parser.add_argument(
        "--doc_stride",
        default=128,
        type=int,
        help=
        "When splitting up a long document into chunks, how much stride to take between chunks."
    )
    parser.add_argument(
        "--max_query_length",
        default=64,
        type=int,
        help=
        "The maximum number of tokens for the question. Questions longer than this will "
        "be truncated to this length.")
    parser.add_argument("--do_train",
                        default=False,
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_predict",
                        default=False,
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--predict_batch_size",
                        default=128,
                        type=int,
                        help="Total batch size for predictions.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=10.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10% "
        "of training.")
    parser.add_argument("--save_checkpoints_steps",
                        default=1000,
                        type=int,
                        help="How often to save the model checkpoint.")
    parser.add_argument("--iterations_per_loop",
                        default=1000,
                        type=int,
                        help="How many steps to make in each estimator call.")
    parser.add_argument(
        "--n_best_size",
        default=3,
        type=int,
        help=
        "The total number of n-best predictions to generate in the nbest_predictions.json "
        "output file.")
    parser.add_argument(
        "--max_answer_length",
        default=30,
        type=int,
        help=
        "The maximum length of an answer that can be generated. This is needed because the start "
        "and end predictions are not conditioned on one another.")

    parser.add_argument(
        "--verbose_logging",
        default=False,
        action='store_true',
        help=
        "If true, all of the warnings related to data processing will be printed. "
        "A number of warnings are expected for a normal SQuAD evaluation.")
    parser.add_argument("--no_cuda",
                        default=False,
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument(
        "--accumulate_gradients",
        type=int,
        default=1,
        help=
        "Number of steps to accumulate gradient on (divide the batch_size and accumulate)"
    )
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumualte before performing a backward/update pass."
    )
    parser.add_argument('--eval_period', type=int, default=2000)

    parser.add_argument('--max_n_answers', type=int, default=5)
    parser.add_argument('--merge_query', type=int, default=-1)
    parser.add_argument('--reduce_layers', type=int, default=-1)
    parser.add_argument('--reduce_layers_to_tune', type=int, default=-1)

    parser.add_argument('--only_comp', action="store_true", default=False)

    parser.add_argument('--train_subqueries_file', type=str, default="")  #500
    parser.add_argument('--predict_subqueries_file', type=str,
                        default="")  #500
    parser.add_argument('--prefix', type=str, default="")  #500

    parser.add_argument('--model', type=str, default="qa")  #500
    parser.add_argument('--pooling', type=str, default="max")
    parser.add_argument('--debug', action="store_true", default=False)
    parser.add_argument('--output_dropout_prob', type=float, default=0)
    parser.add_argument('--wait_step', type=int, default=30)
    parser.add_argument('--with_key', action="store_true", default=False)
    parser.add_argument('--add_noise', action="store_true", default=False)

    args = parser.parse_args()

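    # Device setup: with --local_rank >= 0, use one GPU per process and the NCCL
    # distributed backend; otherwise use all visible GPUs (or CPU with --no_cuda).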
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info("device %s n_gpu %d distributed training %r", device, n_gpu,
                bool(args.local_rank != -1))

    if args.accumulate_gradients < 1:
        raise ValueError(
            "Invalid accumulate_gradients parameter: {}, should be >= 1".
            format(args.accumulate_gradients))

    args.train_batch_size = int(args.train_batch_size /
                                args.accumulate_gradients)

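    # Seed all RNGs so runs are reproducible across random, numpy and torch.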
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_predict:
        raise ValueError(
            "At least one of `do_train` or `do_predict` must be True.")

    if args.do_train:
        if not args.train_file:
            raise ValueError(
                "If `do_train` is True, then `train_file` must be specified.")
        if not args.predict_file:
            raise ValueError(
                "If `do_train` is True, then `predict_file` must be specified."
            )

    if args.do_predict:
        if not args.predict_file:
            raise ValueError(
                "If `do_predict` is True, then `predict_file` must be specified."
            )

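    # Load the BERT config; when training, make sure --max_seq_length fits the
    # model's learned position embeddings.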
    bert_config = BertConfig.from_json_file(args.bert_config_file)

    if args.do_train and args.max_seq_length > bert_config.max_position_embeddings:
        raise ValueError(
            "Cannot use sequence length %d because the BERT model "
            "was only trained up to sequence length %d" %
            (args.max_seq_length, bert_config.max_position_embeddings))

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
        logger.info("Output directory () already exists and is not empty.")
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir, exist_ok=True)

    tokenizer = tokenization.FullTokenizer(vocab_file=args.vocab_file,
                                           do_lower_case=args.do_lower_case)

    train_examples = None
    num_train_steps = None

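    # The eval dataloader is always built; the train dataloader only with --do_train.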
    eval_dataloader, eval_examples, eval_features, _ = get_dataloader(
        logger=logger,
        args=args,
        input_file=args.predict_file,
        subqueries_file=args.predict_subqueries_file,
        is_training=False,
        batch_size=args.predict_batch_size,
        num_epochs=1,
        tokenizer=tokenizer)
    if args.do_train:
        train_dataloader, train_examples, _, num_train_steps = get_dataloader(
                logger=logger, args=args, \
                input_file=args.train_file, \
                subqueries_file=args.train_subqueries_file, \
                is_training=True,
                batch_size=args.train_batch_size,
                num_epochs=args.num_train_epochs,
                tokenizer=tokenizer)

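    # Pick the model variant: 'qa' (span extraction), 'classifier', or
    # 'span-predictor' (optionally keyword-aware via --with_key).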
    if args.model == 'qa':
        model = BertForQuestionAnswering(bert_config, 4)
        metric_name = "F1"
    elif args.model == 'classifier':
        if args.reduce_layers != -1:
            bert_config.num_hidden_layers = args.reduce_layers
        model = BertClassifier(bert_config, 2, args.pooling)
        metric_name = "F1"
    elif args.model == "span-predictor":
        if args.reduce_layers != -1:
            bert_config.num_hidden_layers = args.reduce_layers
        if args.with_key:
            Model = BertForQuestionAnsweringWithKeyword
        else:
            Model = BertForQuestionAnswering
        model = Model(bert_config, 2)
        metric_name = "Accuracy"
    else:
        raise NotImplementedError()

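    # Load weights: a comma-separated --init_checkpoint at prediction time is
    # treated as an ensemble (a list of QA models); otherwise a single checkpoint
    # is loaded and, optionally, lower layers are frozen for fine-tuning.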
    if args.init_checkpoint is not None and args.do_predict and \
                len(args.init_checkpoint.split(','))>1:
        assert args.model == "qa"
        model = [model]
        for i, checkpoint in enumerate(args.init_checkpoint.split(',')):
            if i > 0:
                model.append(BertForQuestionAnswering(bert_config, 4))
            print("Loading from", checkpoint)
            state_dict = torch.load(checkpoint, map_location='cpu')
            filter = lambda x: x[7:] if x.startswith('module.') else x
            state_dict = {filter(k): v for (k, v) in state_dict.items()}
            model[-1].load_state_dict(state_dict)
            model[-1].to(device)

    else:
        if args.init_checkpoint is not None:
            print("Loading from", args.init_checkpoint)
            state_dict = torch.load(args.init_checkpoint, map_location='cpu')
            if args.reduce_layers != -1:
                state_dict = {k:v for k, v in state_dict.items() \
                    if not '.'.join(k.split('.')[:3]) in \
                    ['encoder.layer.{}'.format(i) for i in range(args.reduce_layers, 12)]}
            if args.do_predict:
                filter = lambda x: x[7:] if x.startswith('module.') else x
                state_dict = {filter(k): v for (k, v) in state_dict.items()}
                model.load_state_dict(state_dict)
            else:
                model.bert.load_state_dict(state_dict)
                if args.reduce_layers_to_tune != -1:
                    # freeze the embeddings and all but the top `reduce_layers_to_tune` encoder layers
                    for p in model.bert.embeddings.parameters():
                        p.requires_grad = False
                    n_layers = 12 if args.reduce_layers == -1 else args.reduce_layers
                    for i in range(n_layers - args.reduce_layers_to_tune):
                        for p in model.bert.encoder.layer[i].parameters():
                            p.requires_grad = False

        model.to(device)

        if args.local_rank != -1:
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[args.local_rank],
                output_device=args.local_rank)
        elif n_gpu > 1:
            model = torch.nn.DataParallel(model)

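    # Fine-tuning: group parameters for weight decay handling and train with
    # BERTAdam using linear learning-rate warmup.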
    if args.do_train:
        # bias and LayerNorm (gamma/beta) parameters are excluded from weight decay
        no_decay = ['bias', 'gamma', 'beta']
        optimizer_parameters = [{
            'params': [p for n, p in model.named_parameters()
                       if not any(nd in n for nd in no_decay)],
            'weight_decay_rate': 0.01
        }, {
            'params': [p for n, p in model.named_parameters()
                       if any(nd in n for nd in no_decay)],
            'weight_decay_rate': 0.0
        }]

        optimizer = BERTAdam(optimizer_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=num_train_steps)

        global_step = 0

        best_f1 = 0
        wait_step = 0
        model.train()
        global_step = 0
        stop_training = False

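        # Training loop with gradient accumulation: evaluate every --eval_period
        # steps, keep the best checkpoint, and stop early after --wait_step
        # evaluations without improvement.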
        for epoch in range(int(args.num_train_epochs)):
            for step, batch in tqdm(enumerate(train_dataloader)):
                global_step += 1
                batch = [t.to(device) for t in batch]
                loss = model(batch, global_step)
                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps
                loss.backward()
                if global_step % args.gradient_accumulation_steps == 0:
                    optimizer.step()  # We have accumulated enough gradients
                    model.zero_grad()
                if global_step % args.eval_period == 0:
                    model.eval()
                    f1 =  predict(args, model, eval_dataloader, eval_examples, eval_features, \
                                  device, write_prediction=False)
                    logger.info("%s: %.3f on epoch=%d" %
                                (metric_name, f1 * 100.0, epoch))
                    if best_f1 < f1:
                        logger.info("Saving model with best %s: %.3f -> %.3f on epoch=%d" % \
                                (metric_name, best_f1*100.0, f1*100.0, epoch))
                        model_state_dict = {
                            k: v.cpu()
                            for (k, v) in model.state_dict().items()
                        }
                        torch.save(
                            model_state_dict,
                            os.path.join(args.output_dir, "best-model.pt"))
                        model.to(device)
                        best_f1 = f1
                        wait_step = 0
                        stop_training = False
                    else:
                        wait_step += 1
                        if best_f1 > 0.1 and wait_step == args.wait_step:
                            stop_training = True
                    model.train()
            if stop_training:
                break

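    # Prediction only: run the (possibly ensembled) model once over the eval set.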
    elif args.do_predict:
        if type(model) == list:
            model = [m.eval() for m in model]
        else:
            model.eval()
        f1 = predict(args, model, eval_dataloader, eval_examples,
                     eval_features, device)
        logger.info("Final %s score: %.3f%%" % (metric_name, f1 * 100.0))
Example #2
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir",
                        default='./data',
                        type=str,
                        help="Data dir containing model dir, etc.")
    parser.add_argument(
        "--tfidf_file",
        default=
        'wiki_first_paras-tfidf-ngram=2-hash=16777216-tokenizer=spacy.npz',
        type=str,
        help="td-idf .npz file placed inside the data_dir.")
    parser.add_argument(
        "--wiki_jsonl",
        default='wiki_firstpara_sents.jsonl',
        type=str,
        help="Processed wikipedia .jsonl placed inside data_dir.")
    parser.add_argument("--qdmr_jsonl",
                        default='./data/qdmr_data/qdmrs_hotpotqa_gold.jsonl',
                        type=str,
                        help="Path to processed qdmr .jsonl file.")
    parser.add_argument("--predict_batch_size",
                        default=128,
                        type=int,
                        help="Batch size for predictions in eval mode.")
    parser.add_argument("--tasks",
                        default='break_rc,ques_ir,break_ir',
                        type=str,
                        help="The IR, RC tasks to perform.")
    parser.add_argument("--suffix",
                        default='gold',
                        type=str,
                        help="Suffix to add to the output files.")
    parser.add_argument("--debug",
                        action='store_true',
                        help="If on, only keep a small number of qdmrs.")
    parser.add_argument(
        "--input_results_file",
        default='',
        type=str,
        help="File containing results of the task to be reused.")
    args = parser.parse_args()

    # We use an already fine-tuned single-hop RC ensemble from
    # Min et al. (https://github.com/shmsw25/DecompRC/tree/master/DecompRC)
    rc_args = {
        'bert_config_file':
        'data/onehop_rc/uncased_L-12_H-768_A-12/bert_config.json',
        'do_lower_case': True,
        'doc_stride': 128,
        'init_checkpoint':
        f'{args.data_dir}/onehop_rc/uncased_L-12_H-768_A-12/model1.pt,{args.data_dir}/onehop_rc/uncased_L-12_H-768_A-12/model2.pt,{args.data_dir}/onehop_rc/uncased_L-12_H-768_A-12/model3.pt',
        'iterations_per_loop': 1000,
        'local_rank': -1,
        'max_answer_length': 30,
        'max_n_answers': 5,
        'max_query_length': 64,
        'max_seq_length': 300,
        'model': 'qa',
        'n_best_size': 4,
        'no_cuda': False,
        'output_dropout_prob': 0,
        'pooling': 'max',
        'seed': 42,
        'verbose_logging': False,
        'vocab_file': 'data/onehop_rc/uncased_L-12_H-768_A-12/vocab.txt',
        'with_key': False
    }
    rc_args = SimpleNamespace(**rc_args)

    # load hotpotQA
    logging.info(f'loading datasets from {args.data_dir}/hotpot_data/ ...')
    data = read_file(f'{args.data_dir}/hotpot_data/hotpot_train_v1.json')
    #data += read_file(f'{args.data_dir}/hotpot_data/hotpot_dev_distractor_v1.json')
    data += read_file(
        f'{args.data_dir}/hotpot_data/hotpot_dev_fullwiki_v1.json')
    for d in data:
        d['gold_titles'] = {x[0] for x in d['supporting_facts']}
    hotpot = {d['_id']: d for d in data}

    # load qdmr data processed using prepare_break.jsonl
    qdmr_path = args.qdmr_jsonl
    logging.info(f'loading processed qdmr data from {qdmr_path} ...')
    qdmrs = read_file(qdmr_path)

    # load spacy
    nlp = en_core_web_sm.load()  # spacy
    tokenize = lambda s: [x.text for x in nlp.tokenizer(s)]

    # load IR
    logging.info('loading IR ...')
    ranker = IR(tfidf_path=f'{args.data_dir}/{args.tfidf_file}')

    # load wikipedia
    wiki_path = f'{args.data_dir}/{args.wiki_jsonl}'
    logging.info(f'loading wikipedia from {wiki_path} ...')
    with jsonlines.open(wiki_path, 'r') as reader:
        wiki = {d['title']: d['para'] for d in tqdm(reader.iter())}

    # prepare and load the RC for inference
    device = torch.device("cuda")
    n_gpu = torch.cuda.device_count()
    logging.info(f'{n_gpu} cuda devices available.')
    logging.info('loading 1-hop RC ensemble ...')

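    # Seed RNGs so the single-hop RC predictions are reproducible.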
    random.seed(rc_args.seed)
    np.random.seed(rc_args.seed)
    torch.manual_seed(rc_args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(rc_args.seed)

    tokenizer = tokenization.FullTokenizer(vocab_file=rc_args.vocab_file,
                                           do_lower_case=rc_args.do_lower_case)

    bert_config = BertConfig.from_json_file(rc_args.bert_config_file)
    model = BertForQuestionAnswering(bert_config, 4)

    if rc_args.init_checkpoint is not None:
        model = [model]
        for i, checkpoint in enumerate(rc_args.init_checkpoint.split(',')):
            if i > 0:
                model.append(BertForQuestionAnswering(bert_config, 4))
            logging.info(f"Loading from {checkpoint}")
            state_dict = torch.load(checkpoint, map_location='cpu')
            filter = lambda x: x[7:] if x.startswith('module.') else x
            state_dict = {filter(k): v for (k, v) in state_dict.items()}
            model[-1].load_state_dict(state_dict)
            model[-1].to(device)

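    # Switch every ensemble member to eval mode for inference.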
    if type(model) == list:
        model = [m.eval() for m in model]
    else:
        model.eval()

    # 1hop RC wrapper
    simpQA = partial(_simpQA,
                     args=args,
                     rc_args=rc_args,
                     tokenizer=tokenizer,
                     model=model,
                     device=device)

    if args.input_results_file:
        logging.info(
            f'Reading the supplied results file {args.input_results_file} ...')
        all_results = read_file(args.input_results_file)
    else:
        all_results = {}
        for i_d, d in enumerate(qdmrs):
            [_, data_split, _id] = d['question_id'].split('_')
            #assert data_split == 'dev'
            assert _id in hotpot
            all_results[_id] = d
            assert d['steps'] and d['op_types'], \
                'QDMRs must be pre-processed and non-empty.'

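    # With --debug, keep roughly 1% of the samples for a quick end-to-end run.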
    if args.debug:
        all_results = {
            key: val
            for key, val in all_results.items() if random.random() < 0.01
        }
        logging.info(f'\nTruncating to only {len(all_results)} samples!!!\n')

    tasks = [x.strip() for x in args.tasks.split(',')]

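    # break_rc: answer each QDMR step by retrieving paragraphs with TF-IDF and
    # querying the 1-hop RC ensemble, resolving #k placeholders with the answers
    # of earlier steps.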
    if 'break_rc' in tasks:
        logging.info(f'Running BREAK IR+RC on {len(all_results)} samples ...')
        max_n_parts = max([len(v['steps']) for v in all_results.values()])

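        # Process the i_p-th step of every sample, then query the RC once over
        # the whole batch of step questions.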
        for i_p in range(max_n_parts):
            logging.info(f'Processing qdmr step #{i_p} ...')
            # process the i_p'th part of all samples
            articles = []  # hotpot articles corresponding to queries to the RC
            for _id, v in tqdm(all_results.items()):
                parts = v['steps']
                if i_p >= (len(parts) - int(
                        v['op_types'][-1] in ['COMPARISON', 'INTERSECTION'])):
                    # the last discrete comparison, intersection step is processed later
                    continue
                rc_outputs = v.get('rc_outputs', {})
                nbest_outputs = v.get('nbest_outputs', {})
                l_top = v.get('titles', [])
                part = parts[i_p]
                # replace placeholders with the respective RC outputs of previous parts
                for j in range(i_p):
                    ph = '#' + str(j + 1)  # 1...i_p
                    if ph in part:
                        part = part.replace(ph, rc_outputs[ph])
                # get top 10 titles from IR
                top_titles = ranker.closest_docs(part, k=10)[0]
                l_top.append(top_titles)
                v.update({
                    'titles': l_top,
                    'rc_outputs': rc_outputs,
                    'nbest_outputs': nbest_outputs
                })
                context = []
                # use all retrieved para for the sample instead of just current 10 & sort them acc to similarity wrt part
                set_l_top = set(sum(l_top, []))
                scores = ranker.rank_titles(part, set_l_top)
                sorted_l_top = sorted(scores.keys(),
                                      key=lambda title: scores[title],
                                      reverse=True)
                for title in sorted_l_top:
                    context.append([title, wiki[title]['sents']])  # get para from wiki
                if not sorted_l_top:
                    # rare case of no valid titles
                    context = [['Random Title 1', 'Random Text 1'],
                               ['Random Title 2', 'Random Text 2']]
                d, article = hotpot[_id], {}
                article['question'] = part + ' ?'  # appending '?' to the part query
                article['context'] = context
                article.update({
                    k: d[k]
                    for k in ['_id', 'type', 'answer']
                })  # '_id', 'type', 'context', 'question', 'answer'
                articles.append(article)
            if not articles:
                continue
            # querying the 1-hop RC
            all_nbest_out = simpQA([to_squad(article)
                                    for article in articles])[1]
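            # Write the n-best RC outputs back into each sample; FILTER steps merge
            # their logits with the n-best list of the step they refer to.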
            for _id, v in all_results.items():
                if _id not in all_nbest_out:
                    continue
                nbest_i_p = all_nbest_out[_id]
                op = v['op_types'][i_p]
                nbest_id = v['nbest_outputs']
                # handle filter steps
                if 'FILTER' in op:
                    ref_ph = op.split('_')[1]
                    nbest_ref = Counter(nbest_id[ref_ph])
                    # accumulating the logits of nbest of the part and the ref part
                    nbest_ref.update(nbest_i_p)
                    nbest_i_p = dict(nbest_ref)
                rc_out = max(nbest_i_p.keys(), key=lambda key: nbest_i_p[key])
                v['rc_outputs'][f'#{i_p+1}'] = rc_out
                v['nbest_outputs'][f'#{i_p+1}'] = nbest_i_p

        # discrete processing of the last comparison step
        logging.info(
            f'Discrete processing of the last comparison/intersection steps ...'
        )
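        # COMPARISON steps are resolved by comparing the two RC answers (or their
        # entities); INTERSECTION steps take the argmax over the summed n-best
        # logits of the referenced parts.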
        for _id, v in all_results.items():
            if v['op_types'][-1] not in ['COMPARISON', 'INTERSECTION']:
                continue
            question, answer, gold_titles = (hotpot[_id]['question'],
                                             hotpot[_id]['answer'],
                                             hotpot[_id]['gold_titles'])
            parts, rc_outputs = v['steps'], v['rc_outputs']
            if v['op_types'][-1] == 'COMPARISON':
                ents, rc_outs = [], []
                for i_p, part in enumerate(parts[:-1]):
                    # get named entity in the part
                    part_without_phs = part
                    for x in ['#' + str(j) for j in range(1, 8)]:
                        part_without_phs = part_without_phs.replace(x, '')
                    ent = get_ent(part_without_phs, nlp, only_longest=True)
                    ent = '' if ent is None else ent
                    ents.append(ent)
                    rc_outs.append(
                        normalize_answer(rc_outputs['#' + str(i_p + 1)]))
                if 'same as' in parts[-1]:
                    pred_ans = 'yes' if rc_outs[-2] == rc_outs[-1] else 'no'
                else:
                    pred_ans = ents[compare(parts[-1], rc_outs[-2],
                                            rc_outs[-1])]
                v['rc_outputs'][f'#{len(parts)}'] = pred_ans
            elif v['op_types'][-1] == 'INTERSECTION':
                part = parts[-1]
                phs = [
                    '#' + str(j) for j in range(1, 10) if '#' + str(j) in part
                ]
                phs = list(set(phs))
                # accumulate logits of the parts and take the argmax
                nbest_id = v['nbest_outputs']
                nbest = Counter(nbest_id[phs[0]])
                # accumulate logits
                for ph in phs[1:]:
                    if ph in nbest_id:
                        nbest.update(nbest_id[ph])
                nbest = dict(nbest)
                pred_ans = max(nbest.keys(), key=lambda key: nbest[key])
                v['rc_outputs'][f'#{len(parts)}'] = pred_ans
                v['nbest_outputs'][f'#{len(parts)}'] = nbest

        for v in all_results.values():
            assert len(v['rc_outputs']) == len(v['steps'])

    if 'break_ir' in tasks:
        # this can only be run after break_rc task & requires all_results dict
        logging.info(
            f'Forming context using the titles used by Break RC for {len(all_results)} samples ...'
        )
        # prepare hotpot-like data for Bert RC
        new_hotpot = []
        for _id, v in tqdm(all_results.items()):
            d = hotpot[_id]
            d_new = deepcopy(d)
            used_titles = sum(v['titles'], [])
            # sort wrt similarity to ques
            scores = ranker.rank_titles(d['question'], set(used_titles))
            titles = sorted(scores.keys(),
                            key=lambda title: scores[title],
                            reverse=True)
            context = []
            for title in titles:
                context.append([title, wiki[title]['sents']])
            d_new['context'] = context
            if 'gold_titles' in d_new:
                del d_new['gold_titles']
            new_hotpot.append(d_new)
        out_break_ir_file = f'{args.data_dir}/hotpot_data/hotpot_after_break_ir_{args.suffix}.json'
        logging.info(
            f'Writing hotpot version with the Break IR context to {out_break_ir_file} ...'
        )
        write_file(new_hotpot, out_break_ir_file)

        # store the retrieved titles
        for d in new_hotpot:
            all_results[d['_id']]['titles_found_by_break_rc'] = list(
                set([x[0] for x in d['context']]))

    if 'ques_ir' in tasks:
        # this can only be run after break_rc task & requires all_results dict formed
        # to determine the number of titles to be retrieved for each sample
        logging.info(
            f'Running baseline IR using the whole question for {len(all_results)} samples ...'
        )
        # prepare hotpot-like data for Bert RC
        new_hotpot = []
        for _id in tqdm(all_results.keys()):
            d = hotpot[_id]
            d_new = deepcopy(d)
            # for fair comparison with Break RC retrieve the same number of titles
            n_titles = len(sum(all_results[_id]['titles'], []))
            titles = ranker.closest_docs(d['question'], k=n_titles)[0]
            context = []
            for title in titles:
                context.append([title, wiki[title]['sents']])
            d_new['context'] = context
            if 'gold_titles' in d_new:
                del d_new['gold_titles']
            new_hotpot.append(d_new)
        out_ques_ir_file = f'{args.data_dir}/hotpot_data/hotpot_after_ques_ir_{args.suffix}.json'
        logging.info(
            f'Writing hotpot version with the baseline IR context to {out_ques_ir_file} ...'
        )
        write_file(new_hotpot, out_ques_ir_file)

        # store the retrieved titles
        for d in new_hotpot:
            all_results[d['_id']]['titles_found_using_whole_ques'] = list(
                set([x[0] for x in d['context']]))

    # save the Break RC outputs
    out_break_rc_file = f'{args.data_dir}/predictions/break_rc_results_{args.suffix}.json'
    logging.info(f'Writing the break RC results to {out_break_rc_file}...')
    os.makedirs(dirname(out_break_rc_file), exist_ok=True)
    write_file(all_results, out_break_rc_file)