def create_examples(jsn, graph_retriever_config):

    task = graph_retriever_config.task

    examples = []
    '''
    Find the maximum size of the initial context (links are not included)
    '''
    graph_retriever_config.max_context_size = 0

    for data in jsn:

        guid = data['q_id']
        question = data['question']
        context = data['context']  # {context title: paragraph}
        all_linked_paras_dic = {}  # {context title: {linked title: paragraph}}
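        # Illustrative shape of one entry in `jsn` (values are made up, not
        # taken from any particular dataset file):
        #   {"q_id": "abc123",
        #    "question": "Who directed ...?",
        #    "context": {"Title A": "Paragraph text ...", ...},
        #    "tagged_context": {"Tagged title": "Paragraph text ...", ...}}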
        '''
        Use TagMe-based context at test time.
        '''
        if graph_retriever_config.tagme:
            assert 'tagged_context' in data
            '''
            Reformat "tagged_context" into a dict if it is given as a list of
            (title, paragraph) pairs
            '''
            if isinstance(data['tagged_context'], list):
                tagged_context = {c[0]: c[1] for c in data['tagged_context']}
                data['tagged_context'] = tagged_context
            '''
            Append valid paragraphs from "tagged_context" to "context"
            '''
            for tagged_title in data['tagged_context']:
                tagged_text = data['tagged_context'][tagged_title]
                if (tagged_title is not None and tagged_title.strip() != ''
                        and tagged_text is not None
                        and tagged_text.strip() != ''
                        and tagged_title not in context):
                    context[tagged_title] = tagged_text
        '''
        Clean "context" by removing invalid paragraphs
        '''
        removed_keys = []
        for title in context:
            if (title is None or title.strip() == ''
                    or context[title] is None or context[title].strip() == ''):
                removed_keys.append(title)
        for key in removed_keys:
            context.pop(key)

        all_paras = dict(context)  # shallow copy of the initial paragraphs

        if graph_retriever_config.expand_links:
            expand_links(context, all_linked_paras_dic, all_paras)

        graph_retriever_config.max_context_size = max(
            graph_retriever_config.max_context_size, len(context))

        examples.append(
            InputExample(guid=guid,
                         q=question,
                         c=context,
                         para_dic=all_linked_paras_dic,
                         s_g=None,
                         r_g=None,
                         all_r_g=None,
                         all_paras=all_paras))

    return examples
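
# A minimal usage sketch for create_examples (assuming `jsn` was loaded from a
# JSON eval file and `graph_retriever_config` exposes the attributes referenced
# above; the file name is illustrative):
#
#     import json
#     with open('eval_data.json') as f:
#         jsn = json.load(f)
#     examples = create_examples(jsn, graph_retriever_config)
#     print(len(examples), graph_retriever_config.max_context_size)
#
# (beam_search below is a method of the graph retriever model, hence the
# `self` parameter and the extra indentation.)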
    def beam_search(self, input_ids, token_type_ids, attention_mask, examples,
                    tokenizer, retriever, split_chunk):
        beam = self.graph_retriever_config.beam
        B = input_ids.size(0)
        N = self.graph_retriever_config.max_para_num

        pred = []
        prob = []

        topk_pred = []
        topk_prob = []

        eos_index = N

        init_paragraphs, state = self.encode(input_ids,
                                             token_type_ids,
                                             attention_mask,
                                             split_chunk=split_chunk)

        # Output matrix to be populated
        ps = torch.zeros(N + 1, self.s.size(0),
                         device=self.s.device)  # (N+1, D)
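        # Row n of `ps` holds the encoding of the n-th candidate paragraph in
        # examples[i].title_order; the final row is reserved for the EOS action.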

        for i in range(B):
            init_context_len = len(examples[i].context)

            # Populating the output matrix by the initial encoding
            ps[:init_context_len, :].copy_(
                init_paragraphs[i, :init_context_len, :])
            ps[-1, :].copy_(init_paragraphs[i, -1, :])
            encoded_titles = set(examples[i].title_order)

            pred_ = [
                [[], [], 1.0] for _ in range(beam)
            ]  # [hist_1, topk_1, score_1], [hist_2, topk_2, score_2], ...
            prob_ = [[] for _ in range(beam)]

            state_ = state[i:i + 1]  # (1, 1, D)
            state_ = state_.expand(beam, 1, state_.size(2))  # -> (beam, 1, D)
            state_tmp = torch.zeros_like(state_)  # (beam, 1, D)

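            # Each decoding step j: (1) fold the previously selected paragraph
            # into the recurrent state, (2) encode paragraphs newly reachable
            # via hyperlinks, (3) score all candidates with a sigmoid over
            # inner products, and (4) keep the top `beam` hypotheses by
            # cumulative score.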
            for j in range(self.graph_retriever_config.max_select_num):
                if j > 0:
                    # Embed the previous prediction of each hypothesis and
                    # fold it into the recurrent state
                    prev = [p[0][-1] for p in pred_]
                    prev = torch.LongTensor(prev).to(ps.device)
                    prev = ps[prev].unsqueeze(1)  # (beam, 1, D)
                    state_ = torch.cat((state_, prev), dim=2)  # (beam, 1, 2*D)
                    state_ = self.rw(state_)  # (beam, 1, D)
                    state_ = self.weight_norm(state_)

                # Opening new links from the previous predictions (populating the output matrix dynamically)
                if j > 0:
                    prev_title_size = len(examples[i].title_order)
                    new_titles = []
                    for b in range(beam):
                        prev_pred = pred_[b][0][-1]

                        if prev_pred == eos_index:
                            continue

                        prev_title = examples[i].title_order[prev_pred]

                        if prev_title not in examples[i].all_linked_paras_dic:

                            if retriever is None:
                                continue
                            linked_paras_dic = retriever.get_hyperlinked_abstract_paragraphs(
                                prev_title, examples[i].question)
                            examples[i].all_linked_paras_dic[prev_title] = dict(
                                linked_paras_dic)
                            examples[i].all_paras.update(linked_paras_dic)

                        for linked_title in examples[i].all_linked_paras_dic[prev_title]:
                            if (linked_title in encoded_titles
                                    or len(examples[i].title_order) == N):
                                continue

                            encoded_titles.add(linked_title)
                            new_titles.append(linked_title)
                            examples[i].title_order.append(linked_title)

                    if len(new_titles) > 0:

                        tokens_q = tokenize_question(examples[i].question,
                                                     tokenizer)
                        input_ids = []
                        input_masks = []
                        segment_ids = []
                        for linked_title in new_titles:
                            linked_para = examples[i].all_paras[linked_title]

                            input_ids_, input_masks_, segment_ids_ = tokenize_paragraph(
                                linked_para, tokens_q,
                                self.graph_retriever_config.max_seq_length,
                                tokenizer)
                            input_ids.append(input_ids_)
                            input_masks.append(input_masks_)
                            segment_ids.append(segment_ids_)

                        input_ids = torch.LongTensor([input_ids]).to(ps.device)
                        token_type_ids = torch.LongTensor([segment_ids]).to(ps.device)
                        attention_mask = torch.LongTensor([input_masks]).to(ps.device)

                        paragraphs, _ = self.encode(input_ids,
                                                    token_type_ids,
                                                    attention_mask,
                                                    split_chunk=split_chunk)
                        paragraphs = paragraphs.squeeze(0)
                        ps[prev_title_size:prev_title_size + len(new_titles)].copy_(
                            paragraphs[:len(new_titles), :])

                        if retriever is not None and self.graph_retriever_config.expand_links:
                            expand_links(examples[i].all_paras,
                                         examples[i].all_linked_paras_dic,
                                         examples[i].all_paras)

                output = torch.bmm(
                    state_,
                    ps.unsqueeze(0).expand(beam, ps.size(0),
                                           ps.size(1)).transpose(1, 2))  # (beam, 1, N+1)
                output = output + self.bias
                output = torch.sigmoid(output)
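                # Note: each candidate is scored by an independent sigmoid
                # rather than a softmax, so the per-step scores need not sum
                # to one; beam scores accumulate as products of these values.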

                output = output.to(self.cpu)

                if j == 0:
                    # Only the initial context paragraphs are valid candidates
                    # at the first step
                    output[:, :, len(examples[i].context):] = 0.0
                else:
                    if len(examples[i].title_order) < N:
                        output[:, :, len(examples[i].title_order):N] = 0.0
                    for b in range(beam):

                        # Omitting previous predictions
                        for k in range(len(pred_[b][0])):
                            output[b, :, pred_[b][0][k]] = 0.0

                        # Links & topK-based pruning
                        if self.graph_retriever_config.pruning_by_links:
                            if pred_[b][0][-1] == eos_index:
                                output[b, :, :eos_index] = 0.0
                                output[b, :, eos_index] = 1.0

                            elif (examples[i].title_order[pred_[b][0][-1]]
                                  not in examples[i].all_linked_paras_dic):
                                for k in range(len(examples[i].title_order)):
                                    if k not in pred_[b][1]:
                                        output[b, :, k] = 0.0

                            else:
                                prev_title = examples[i].title_order[pred_[b][0][-1]]
                                linked = examples[i].all_linked_paras_dic[prev_title]
                                for k in range(len(examples[i].title_order)):
                                    if (k not in pred_[b][1] and
                                            examples[i].title_order[k]
                                            not in linked):
                                        output[b, :, k] = 0.0

                # Disallow EOS until at least min_select_num paragraphs have
                # been selected
                if j <= self.graph_retriever_config.min_select_num - 1:
                    output[:, :, -1] = 0.0

                score = [p[2] for p in pred_]
                score = torch.FloatTensor(score)
                score = score.unsqueeze(1).unsqueeze(2)  # (beam, 1, 1)
                score = output * score

                output = output.squeeze(1)  # (beam, N+1)
                score = score.squeeze(1)  # (beam, N+1)
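
                # Beam expansion: repeatedly take the global argmax of the
                # (beam, N+1) score matrix, zeroing the picked cell so the
                # next iteration yields the next-best (hypothesis, action)
                # pair. At j == 0 all rows are identical, so the whole column
                # is zeroed to avoid duplicate hypotheses.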
                new_pred_ = []
                new_prob_ = []

                b = 0
                while b < beam:
                    s, p = torch.max(score.view(score.size(0) * score.size(1)),
                                     dim=0)
                    s = s.item()
                    p = p.item()
                    row = p // score.size(1)
                    col = p % score.size(1)

                    if j == 0:
                        score[:, col] = 0.0
                    else:
                        score[row, col] = 0.0

                    new_pred_.append(
                        [pred_[row][0] + [col],
                         output[row].topk(k=2, dim=0)[1].tolist(), s])

                    new_prob_.append(
                        [list(prb) for prb in prob_[row]] + [output[row].tolist()])

                    state_tmp[b].copy_(state_[row])
                    b += 1

                pred_ = new_pred_
                prob_ = new_prob_
                state_ = state_tmp.clone()

                # Stop decoding once the best hypothesis has selected EOS
                if pred_[0][0][-1] == eos_index:
                    break

            topk_pred.append([])
            topk_prob.append([])
            for b in range(beam):

                pred_tmp = []
                for index in pred_[b][0]:
                    if index == eos_index:
                        break
                    pred_tmp.append(index)

                if b == 0:
                    pred.append(pred_tmp)
                    prob.append(prob_[0])

                topk_pred[-1].append(pred_tmp)
                topk_prob[-1].append(prob_[b])

        return pred, prob, topk_pred, topk_prob
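
    # A minimal invocation sketch (illustrative names; assumes the batch
    # tensors were built with the same tokenizer and config used at training
    # time, and retriever=None disables on-the-fly hyperlink retrieval):
    #
    #     pred, prob, topk_pred, topk_prob = model.beam_search(
    #         input_ids, token_type_ids, attention_mask, examples,
    #         tokenizer, retriever=None, split_chunk=split_chunk)
    #
    #     # pred[i]      : best reasoning path (paragraph indices) for example i
    #     # topk_pred[i] : the `beam` highest-scoring paths for example i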