def convert_examples_to_features(examples, max_seq_length, max_para_num,
                                 graph_retriever_config, tokenizer):
    """Convert evaluation examples into a list of `InputFeatures`.

    Every context paragraph of an example is tokenized together with the
    example's question, and the result is padded to a fixed number of
    paragraph slots.

    Args:
        examples: iterable of example objects. Each must expose
            ``question``, ``context`` (a mapping from title to paragraph)
            and a ``title_order`` list.
        max_seq_length: length of one all-zero padding row; also forwarded to
            ``tokenize_paragraph`` (presumably the length each encoded pair is
            padded to -- TODO confirm).
        max_para_num: NOT used -- it is overridden by
            ``graph_retriever_config.max_context_size``. Kept so existing
            callers keep working.
        graph_retriever_config: config object. ``max_context_size`` sets the
            number of paragraph slots, and ``max_para_num`` is raised to at
            least that value as a side effect.
        tokenizer: passed through to ``tokenize_question`` and
            ``tokenize_paragraph``.

    Returns:
        A list with one ``InputFeatures`` per example, in input order, each
        built with ``num_steps=-1``.

    Raises:
        AssertionError: if an example's context has more paragraphs than the
            available slots, so not all of them could be encoded.

    Side effects:
        Appends every encoded title to ``example.title_order``. Calling this
        twice on the same examples therefore duplicates those titles.
    """
    # The caller's value is replaced on purpose: the number of paragraph slots
    # comes from the config's context size.
    max_para_num = graph_retriever_config.max_context_size
    graph_retriever_config.max_para_num = max(
        graph_retriever_config.max_para_num, max_para_num)

    features = []

    for ex_index, example in enumerate(examples):
        tokens_q = tokenize_question(example.question, tokenizer)

        seen_titles = set()
        input_ids = []
        input_masks = []
        segment_ids = []

        for title in list(example.context.keys()):
            if len(input_ids) == max_para_num:
                break
            # Defensive: mapping keys are normally unique already.
            if title in seen_titles:
                continue

            seen_titles.add(title)
            example.title_order.append(title)

            input_ids_, input_masks_, segment_ids_ = tokenize_paragraph(
                example.context[title], tokens_q, max_seq_length, tokenizer)
            input_ids.append(input_ids_)
            input_masks.append(input_masks_)
            segment_ids.append(segment_ids_)

        num_paragraphs = len(input_ids)
        assert num_paragraphs <= max_para_num
        # Every context paragraph must have fit into the available slots.
        assert len(example.context) == num_paragraphs

        # The mask has (max_para_num + 2) rows and (max_para_num + 1) columns.
        # Columns of real paragraphs are 1.0; padding slots and the final
        # extra column are 0.0. Row i + 1 additionally masks paragraph i.
        output_masks = [[1.0] * num_paragraphs +
                        [0.0] * (max_para_num - num_paragraphs + 1)
                        for _ in range(max_para_num + 2)]
        for i in range(num_paragraphs):
            output_masks[i + 1][i] = 0.0

        # Pad up to max_para_num slots with all-zero rows. Each row is a fresh
        # list, so padding rows never alias one another (or rows of other
        # features). Mutating one row therefore cannot silently change others.
        num_padding = max_para_num - num_paragraphs
        input_ids.extend([0] * max_seq_length for _ in range(num_padding))
        input_masks.extend([0] * max_seq_length for _ in range(num_padding))
        segment_ids.extend([0] * max_seq_length for _ in range(num_padding))

        features.append(
            InputFeatures(input_ids=input_ids,
                          input_masks=input_masks,
                          segment_ids=segment_ids,
                          output_masks=output_masks,
                          num_paragraphs=num_paragraphs,
                          num_steps=-1,
                          ex_index=ex_index))

    return features
    def beam_search(self, input_ids, token_type_ids, attention_mask, examples,
                    tokenizer, retriever, split_chunk):
        """Run beam search over paragraph slots for every example in a batch.

        At each step, every beam's recurrent state is scored against the
        paragraph matrix ``ps`` (N paragraph slots plus one EOS row), and the
        ``beam`` best continuations are kept. Paragraphs hyperlinked from the
        previous step's selections are tokenized, encoded and written into
        free slots of ``ps`` on the fly.

        Args:
            input_ids, token_type_ids, attention_mask: batch tensors passed to
                ``self.encode`` for the initial context paragraphs.
                ``input_ids.size(0)`` is used as the batch size.
            examples: one example per batch row. Reads ``context``,
                ``question``, ``title_order``, ``all_linked_paras_dic`` and
                ``all_paras``. MUTATED: newly linked titles are appended to
                ``title_order``, and links fetched through ``retriever`` are
                stored in ``all_linked_paras_dic`` and ``all_paras``.
            tokenizer: used with ``tokenize_question`` / ``tokenize_paragraph``
                to encode newly linked paragraphs.
            retriever: optional. When not None, its
                ``get_hyperlinked_abstract_paragraphs(title, question)`` is
                called for selected titles missing from
                ``all_linked_paras_dic``.
            split_chunk: forwarded to every ``self.encode`` call.

        Returns:
            ``(pred, prob, topk_pred, topk_prob)``:

            - ``pred[i]`` is the best beam's list of selected indices for
              example ``i`` (indices into ``examples[i].title_order``), cut at
              the first EOS.
            - ``prob[i]`` is that beam's per-step list of masked sigmoid scores
              (one list of N+1 values per step; not normalized).
            - ``topk_pred[i]`` and ``topk_prob[i]`` hold the same data for
              every beam, highest score first.
        """
        beam = self.graph_retriever_config.beam
        B = input_ids.size(0)  # batch size
        # Number of paragraph slots; index N itself is reserved for EOS.
        N = self.graph_retriever_config.max_para_num

        pred = []
        prob = []

        topk_pred = []
        topk_prob = []

        # Selecting index N means "stop" (end of the reasoning path).
        eos_index = N

        # Encode the initial context paragraphs of the whole batch once.
        # init_paragraphs is indexed as [example, slot, :] below, and its last
        # slot is copied into the EOS row. Shape assumed (B, N+1, D); TODO confirm.
        init_paragraphs, state = self.encode(input_ids,
                                             token_type_ids,
                                             attention_mask,
                                             split_chunk=split_chunk)

        # Output matrix to be populated: one row per paragraph slot plus the
        # EOS row. It is reused across examples without re-zeroing. Rows past
        # the current example's title_order are masked out of the scores below.
        ps = torch.FloatTensor(N + 1, self.s.size(0)).zero_().to(
            self.s.device)  # (N+1, D)

        for i in range(B):
            init_context_len = len(examples[i].context)

            # Populating the output matrix by the initial encoding
            ps[:init_context_len, :].copy_(
                init_paragraphs[i, :init_context_len, :])
            ps[-1, :].copy_(init_paragraphs[i, -1, :])
            # Titles already present in ps; linked titles in this set are not
            # encoded again.
            encoded_titles = set(examples[i].title_order)

            # Each beam holds: [selected indices, top-2 indices of the last
            # step's output (used for topK pruning), cumulative score].
            pred_ = [
                [[], [], 1.0] for _ in range(beam)
            ]  # [hist_1, topk_1, score_1], [hist_2, topk_2, score_2], ...
            # Per-beam history of each step's masked output scores.
            prob_ = [[] for _ in range(beam)]

            state_ = state[i:i + 1]  # (1, 1, D)
            state_ = state_.expand(beam, 1, state_.size(2))  # -> (beam, 1, D)
            # Scratch buffer used to reorder states after beam selection.
            state_tmp = torch.FloatTensor(state_.size()).zero_().to(
                state_.device)

            # At most max_select_num selection steps.
            for j in range(self.graph_retriever_config.max_select_num):
                if j > 0:
                    # Update each beam's recurrent state with the encoding of
                    # the row (paragraph or EOS) it selected in the last step.
                    # NOTE(review): `input` shadows the builtin of that name.
                    input = [p[0][-1] for p in pred_]
                    input = torch.LongTensor(input).to(ps.device)
                    input = ps[input].unsqueeze(1)  # (beam, 1, D)
                    state_ = torch.cat((state_, input),
                                       dim=2)  # (beam, 1, 2*D)
                    state_ = self.rw(state_)  # (beam, 1, D)
                    state_ = self.weight_norm(state_)

                # Opening new links from the previous predictions (populating the output matrix dynamically)
                if j > 0:
                    # Newly linked paragraphs are written to slots from here on.
                    prev_title_size = len(examples[i].title_order)
                    new_titles = []
                    for b in range(beam):
                        prev_pred = pred_[b][0][-1]

                        # Finished beams open no links.
                        if prev_pred == eos_index:
                            continue

                        prev_title = examples[i].title_order[prev_pred]

                        if prev_title not in examples[i].all_linked_paras_dic:

                            # Without a retriever, a title with no cached links
                            # opens nothing.
                            if retriever is None:
                                continue
                            else:
                                # Fetch the hyperlinked paragraphs and cache them
                                # on the example.
                                linked_paras_dic = retriever.get_hyperlinked_abstract_paragraphs(
                                    prev_title, examples[i].question)
                                examples[i].all_linked_paras_dic[
                                    prev_title] = {}
                                examples[i].all_linked_paras_dic[
                                    prev_title].update(linked_paras_dic)
                                examples[i].all_paras.update(linked_paras_dic)

                        for linked_title in examples[i].all_linked_paras_dic[
                                prev_title]:
                            # Skip titles already in ps, and skip everything
                            # once all N slots are in use.
                            if linked_title in encoded_titles or len(
                                    examples[i].title_order) == N:
                                continue

                            encoded_titles.add(linked_title)
                            new_titles.append(linked_title)
                            examples[i].title_order.append(linked_title)

                    if len(new_titles) > 0:

                        # Tokenize every newly linked paragraph against the
                        # question and encode them together as one batch row.
                        tokens_q = tokenize_question(examples[i].question,
                                                     tokenizer)
                        # NOTE(review): the next assignments rebind the method
                        # parameters input_ids / token_type_ids /
                        # attention_mask. This is harmless only because the
                        # originals are read before the example loop.
                        input_ids = []
                        input_masks = []
                        segment_ids = []
                        for linked_title in new_titles:
                            linked_para = examples[i].all_paras[linked_title]

                            input_ids_, input_masks_, segment_ids_ = tokenize_paragraph(
                                linked_para, tokens_q,
                                self.graph_retriever_config.max_seq_length,
                                tokenizer)
                            input_ids.append(input_ids_)
                            input_masks.append(input_masks_)
                            segment_ids.append(segment_ids_)

                        input_ids = torch.LongTensor([input_ids]).to(ps.device)
                        token_type_ids = torch.LongTensor([segment_ids
                                                           ]).to(ps.device)
                        attention_mask = torch.LongTensor([input_masks
                                                           ]).to(ps.device)

                        paragraphs, _ = self.encode(input_ids,
                                                    token_type_ids,
                                                    attention_mask,
                                                    split_chunk=split_chunk)
                        paragraphs = paragraphs.squeeze(0)
                        # Write the new encodings into the slots reserved for
                        # them. The titles were appended to title_order in the
                        # same order, so slot index == title_order index.
                        ps[prev_title_size:prev_title_size +
                           len(new_titles)].copy_(
                               paragraphs[:len(new_titles), :])

                        # When enabled, call expand_links (defined elsewhere)
                        # on this example's paragraph / link dictionaries.
                        if retriever is not None and self.graph_retriever_config.expand_links:
                            expand_links(examples[i].all_paras,
                                         examples[i].all_linked_paras_dic,
                                         examples[i].all_paras)

                # Score every slot for every beam: sigmoid(state . ps^T + bias).
                output = torch.bmm(state_,
                                   ps.unsqueeze(0).expand(
                                       beam, ps.size(0), ps.size(1)).transpose(
                                           1, 2))  # (beam, 1, N+1)
                output = output + self.bias
                output = torch.sigmoid(output)

                # Masking and beam bookkeeping below run on the CPU copy.
                output = output.to(self.cpu)

                if j == 0:
                    # First step: only the initial context paragraphs may be
                    # chosen. Every slot from len(context) onward is zeroed,
                    # which also covers EOS when len(context) <= N.
                    output[:, :, len(examples[i].context):] = 0.0
                else:
                    # Slots not yet holding a paragraph for this example; their
                    # ps rows are zero or stale from an earlier example.
                    if len(examples[i].title_order) < N:
                        output[:, :, len(examples[i].title_order):N] = 0.0
                    for b in range(beam):

                        # Omitting previous predictions
                        for k in range(len(pred_[b][0])):
                            output[b, :, pred_[b][0][k]] = 0.0

                        # Links & topK-based pruning
                        if self.graph_retriever_config.pruning_by_links:
                            if pred_[b][0][-1] == eos_index:
                                # A finished beam can only emit EOS again; 1.0
                                # keeps its cumulative score unchanged.
                                output[b, :, :eos_index] = 0.0
                                output[b, :, eos_index] = 1.0

                            elif examples[i].title_order[
                                    pred_[b][0]
                                [-1]] not in examples[i].all_linked_paras_dic:
                                # No known links from the previous paragraph:
                                # only the previous step's top-2 candidates
                                # (plus EOS, which is outside this range) survive.
                                for k in range(len(examples[i].title_order)):
                                    if k not in pred_[b][1]:
                                        output[b, :, k] = 0.0

                            else:
                                # Keep only the previous top-2 candidates and
                                # titles linked from the previous paragraph.
                                for k in range(len(examples[i].title_order)):
                                    if k not in pred_[b][1] and examples[
                                            i].title_order[k] not in examples[
                                                i].all_linked_paras_dic[
                                                    examples[i].title_order[
                                                        pred_[b][0][-1]]]:
                                        output[b, :, k] = 0.0

                # Forbid EOS until at least min_select_num paragraphs have been
                # selected.
                if j <= self.graph_retriever_config.min_select_num - 1:
                    output[:, :, -1] = 0.0

                # Candidate score = parent beam's cumulative score * this
                # step's output.
                score = [p[2] for p in pred_]
                score = torch.FloatTensor(score)
                score = score.unsqueeze(1).unsqueeze(2)  # (beam, 1, 1)
                score = output * score

                output = output.squeeze(1)  # (beam, N+1)
                score = score.squeeze(1)  # (beam, N+1)
                new_pred_ = []
                new_prob_ = []

                # Repeatedly take the best remaining (beam, slot) pair from the
                # flattened score matrix, zeroing it once taken, until `beam`
                # new beams exist (highest score first).
                b = 0
                while b < beam:
                    s, p = torch.max(score.view(score.size(0) * score.size(1)),
                                     dim=0)
                    s = s.item()
                    p = p.item()
                    row = p // score.size(1)  # parent beam
                    col = p % score.size(1)  # selected slot

                    if j == 0:
                        # All beams are identical at step 0; zero the whole
                        # column so beams start from different paragraphs.
                        score[:, col] = 0.0
                    else:
                        score[row, col] = 0.0

                    # Extend the parent's history. Also keep this row's top-2
                    # indices for topK pruning at the next step.
                    p = [[index for index in pred_[row][0]] + [col],
                         output[row].topk(k=2, dim=0)[1].tolist(), s]
                    new_pred_.append(p)

                    p = [[p_ for p_ in prb]
                         for prb in prob_[row]] + [output[row].tolist()]
                    new_prob_.append(p)

                    # New beam b inherits its parent's recurrent state.
                    state_tmp[b].copy_(state_[row])
                    b += 1

                pred_ = new_pred_
                prob_ = new_prob_
                # clone() before the in-place copy. Presumably this is needed
                # because at step 0 state_ is an expanded view of `state`,
                # which must not be written in place.
                state_ = state_.clone()
                state_.copy_(state_tmp)

                # Stop early once the best beam has emitted EOS.
                if pred_[0][0][-1] == eos_index:
                    break

            # Collect every beam's path for this example, truncated at the
            # first EOS. Beam 0 (highest score) also goes into pred / prob.
            topk_pred.append([])
            topk_prob.append([])
            for index__ in range(beam):

                pred_tmp = []
                for index in pred_[index__][0]:
                    if index == eos_index:
                        break
                    pred_tmp.append(index)

                if index__ == 0:
                    pred.append(pred_tmp)
                    prob.append(prob_[0])

                topk_pred[-1].append(pred_tmp)
                topk_prob[-1].append(prob_[index__])

        return pred, prob, topk_pred, topk_prob