Example #1
# construct cognitive graph in training data
import copy
import re

from fuzzywuzzy import fuzz
from tqdm import tqdm

from utils import fuzzy_find, judge_question_type

# `db` (a Redis connection holding per-sentence hyperlink edges) and
# `train_set` (the raw HotpotQA training set) are assumed to be set up
# earlier in the original script.


def find_fact_content(bundle, title, sen_num):
    for x in bundle['context']:
        if x[0] == title:
            return x[1][sen_num]


test = copy.deepcopy(train_set)
for bundle in tqdm(test):
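    # Graph nodes are the titles of supporting-fact paragraphs; 'Q_edge'
    # links the question to those entities via fuzzy matching.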
    entities = set([title for title, sen_num in bundle['supporting_facts']])
    bundle['Q_edge'] = fuzzy_find(entities, bundle['question'])
    question_type = judge_question_type(bundle['question'])
    for fact in bundle['supporting_facts']:
        try:
            title, sen_num = fact
            pool = set()
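            # Collect hyperlink targets of the first sen_num + 1 sentences,
            # stored per paragraph in Redis under 'edges:###<i>###<title>'.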
            for i in range(sen_num + 1):
                name = 'edges:###{}###{}'.format(i, title)
                tmp = set([
                    x.decode().split('###')[0] for x in db.lrange(name, 0, -1)
                ])
                pool |= tmp
            pool &= entities
            stripped = [re.sub(r' \(.*?\)$', '', x)
                        for x in pool] + ['yes', 'no']
            if bundle['answer'] not in stripped:
                if fuzz.ratio(re.sub(r'\(.*?\)$', '', title),
Example #2
import random

import torch
from fuzzywuzzy import fuzz

# Helper locations are assumed from the original repo layout; `Bundle` and
# `FIELDS` come from the surrounding data-generation module.
from utils import (judge_question_type, fuzzy_retrieve,
                   find_start_end_after_tokenized)


def convert_question_to_samples_bundle(tokenizer, data: 'Json refined', neg=2):
    context = dict(data['context'])
    gold_sentences_set = dict([
        ((para, sen), edges) for para, sen, edges in data['supporting_facts']
    ])
    e2i = {}
    i2e = []
    for entity, sens in context.items():
        assert entity not in e2i
        e2i[entity] = len(i2e)
        i2e.append(entity)
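    # prev[i]: tokenized clue sentences passed in from predecessors of node i.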
    prev = [[]] * len(i2e)

    ids, segment_ids, sep_positions, additional_nodes = [], [], [], []
    hop_start_weights, hop_end_weights = [], []
    ans_start_weights, ans_end_weights = [], []
    tokenized_question = ['[CLS]'] + tokenizer.tokenize(
        data['question']) + ['[SEP]']
    for title_x, sen, edges in data[
            'supporting_facts']:  # TODO: match previous sentence
        for title_y, matched, l, r in edges:
            if title_y not in e2i:  # answer
                assert data['answer'] == title_y
                e2i[title_y] = len(i2e)
                i2e.append(title_y)
                prev.append([])
            if title_x != title_y:
                y = e2i[title_y]
                prev[y] = prev[y] + tokenizer.tokenize(
                    context[title_x][sen]) + ['[SEP]']
    question_type = judge_question_type(data['question'])

    # fix by answer:
    if data['answer'] == 'yes' or data['answer'] == 'no':
        question_type = 1
        answer_entity = data['answer']
    else:
        # find answer entity
        answer_entity = fuzzy_retrieve(data['answer'], e2i, 'distractor', 80)
        if answer_entity is None:
            raise ValueError('Cannot find answer: {}'.format(data['answer']))

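    # question_type: 0 = span extraction, 1 = yes/no comparison,
    # 2 = selecting one of the two entities in 'Q_edge'.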
    if question_type == 0:
        answer_id = e2i[answer_entity]
    elif len(data['Q_edge']) != 2:
        if question_type == 1:
            raise ValueError(
                'There must be 2 entities in "Q_edge" for type 1 question.')
        elif question_type == 2:
            # print('Convert type 2 question to 0.\n Question:{}'.format(data['question']))
            question_type = 0
            answer_id = e2i[answer_entity]
    else:
        answer_id = [e2i[data['Q_edge'][0][0]], e2i[data['Q_edge'][1][0]]]
        if question_type == 1:
            answer_id.append(int(data['answer'] == 'yes'))
        elif question_type == 2:
            if data['answer'] == data['Q_edge'][0][1]:
                answer_id.append(0)
            elif data['answer'] == data['Q_edge'][1][1]:
                answer_id.append(1)
            else:
                score = (fuzz.partial_ratio(data['answer'],
                                            data['Q_edge'][0][1]),
                         fuzz.partial_ratio(data['answer'],
                                            data['Q_edge'][1][1]))
                if score[0] < 50 and score[1] < 50:
                    raise ValueError(
                        'There is no exact match in selecting question. answer: {}'
                        .format(data['answer']))
                else:
                    # print('Resolve type 1 or 2 question: {}\n answer: {}'.format(data['question'], data['answer']))
                    answer_id.append(0 if score[0] > score[1] else 1)
        else:
            pass

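    # Build one System 1 training sample per paragraph node: question +
    # clue sentences, then the paragraph's sentences with span-label weights.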
    for entity, sens in context.items():
        num_hop, num_ans = 0, 0
        tokenized_all = tokenized_question + prev[e2i[entity]]
        if len(tokenized_all) > 512:
            tokenized_all = tokenized_all[:512]
            print('PREV TOO LONG, id: {}'.format(data['_id']))
        segment_id = [0] * len(tokenized_all)
        sep_position = []
        hop_start_weight = [0] * len(tokenized_all)
        hop_end_weight = [0] * len(tokenized_all)
        ans_start_weight = [0] * len(tokenized_all)
        ans_end_weight = [0] * len(tokenized_all)

        for sen_num, sen in enumerate(sens):
            tokenized_sen = tokenizer.tokenize(sen) + ['[SEP]']
            if len(tokenized_all) + len(tokenized_sen) > 512 or sen_num > 15:
                break
            # if sen_num > 10:
            #     raise ValueError('Too many sentences in context: {}'.format(sens))
            tokenized_all += tokenized_sen
            segment_id += [sen_num + 1] * len(tokenized_sen)
            sep_position.append(len(tokenized_all) - 1)
            hs_weight = [0] * len(tokenized_sen)
            he_weight = [0] * len(tokenized_sen)
            as_weight = [0] * len(tokenized_sen)
            ae_weight = [0] * len(tokenized_sen)
            if (entity, sen_num) in gold_sentences_set:
                tmp = gold_sentences_set[(entity, sen_num)]
                intervals = find_start_end_after_tokenized(
                    tokenizer, tokenized_sen,
                    [matched for _, matched, _, _ in tmp])
                for j, (l, r) in enumerate(intervals):
                    if tmp[j][0] == answer_entity or question_type > 0:
                        as_weight[l] = ae_weight[r] = 1
                        num_ans += 1
                    else:
                        hs_weight[l] = he_weight[r] = 1
                        num_hop += 1
            hop_start_weight += hs_weight
            hop_end_weight += he_weight
            ans_start_weight += as_weight
            ans_end_weight += ae_weight

        assert len(tokenized_all) <= 512
        # for i in range(len(start_weight)):
        #     start_weight[i] /= max(num_spans, 1)
        #     end_weight[i] /= max(num_spans, 1)
        if 1 not in hop_start_weight:
            hop_start_weight[0] = 0.1
        if 1 not in ans_start_weight:
            ans_start_weight[0] = 0.1

        ids.append(tokenizer.convert_tokens_to_ids(tokenized_all))
        sep_positions.append(sep_position)
        segment_ids.append(segment_id)
        hop_start_weights.append(hop_start_weight)
        hop_end_weights.append(hop_end_weight)
        ans_start_weights.append(ans_start_weight)
        ans_end_weights.append(ans_end_weight)

    n = len(context)
    edges_in_bundle = []
    if question_type == 0:
        # find all edges and prepare forbidden set(containing answer) for negative sampling
        forbidden = set([])
        for para, sen, edges in data['supporting_facts']:
            for x, matched, l, r in edges:
                edges_in_bundle.append((e2i[para], e2i[x]))
                if x == answer_entity:
                    forbidden.add((para, sen))
        if answer_entity not in context and answer_entity in e2i:
            n += 1
            tokenized_all = tokenized_question + prev[e2i[answer_entity]]
            if len(tokenized_all) > 512:
                tokenized_all = tokenized_all[:512]
                print('ANSWER TOO LONG! id: {}'.format(data['_id']))
            additional_nodes.append(
                tokenizer.convert_tokens_to_ids(tokenized_all))
        for _ in range(neg):
            # build a negative node from a random non-answer sentence
            father_para = random.choice(list(context.keys()))
            father_sen = random.randrange(len(context[father_para]))
            if (father_para, father_sen) in forbidden:
                father_para = random.choice(list(context.keys()))
                father_sen = random.randrange(len(context[father_para]))
            if (father_para, father_sen) in forbidden:
                neg -= 1
                continue
            tokenized_all = tokenized_question + tokenizer.tokenize(
                context[father_para][father_sen]) + ['[SEP]']
            if len(tokenized_all) > 512:
                tokenized_all = tokenized_all[:512]
                print('NEG TOO LONG! id: {}'.format(data['_id']))
            additional_nodes.append(
                tokenizer.convert_tokens_to_ids(tokenized_all))
            # index the new node by its actual position; skipped samples
            # would otherwise leave gaps in a running n + i index
            edges_in_bundle.append(
                (e2i[father_para], len(context) + len(additional_nodes) - 1))
        n += neg
    assert n == len(additional_nodes) + len(context)

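    # Column-normalized adjacency with self-loops weighted 2, used by the GCN.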
    adj = torch.eye(n) * 2
    for x, y in edges_in_bundle:
        adj[x, y] = 1
    adj /= torch.sum(adj, dim=0, keepdim=True)

    _id = data['_id']
    ret = Bundle()
    for field in FIELDS:
        setattr(ret, field, eval(field))
    return ret
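
A minimal usage sketch for the function above. The tokenizer call follows the pytorch-pretrained-bert API; `train_set_refined` (the output of the Example #1 preprocessing) is a hypothetical name:

from pytorch_pretrained_bert import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased',
                                          do_lower_case=True)
bundles = []
for data in train_set_refined:  # hypothetical: refined training examples
    try:
        bundles.append(convert_question_to_samples_bundle(tokenizer, data))
    except ValueError as e:
        print(e)  # skip questions whose answer entity cannot be located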
Example #3
import copy
import re

import numpy as np
import torch

# Helper locations are assumed from the original repo layout.
from utils import (judge_question_type, fuzzy_retrieve,
                   find_start_end_before_tokenized, get_context_fullwiki)


def cognitive_graph_propagate(tokenizer,
                              data: 'Json eval(Context as pool)',
                              model1,
                              model2,
                              device,
                              setting: 'distractor / fullwiki' = 'fullwiki',
                              max_new_nodes=5):
    """Answer the question in ``data'' by trained CogQA model.
    
    Args:
        tokenizer (Tokenizer): Word-Piece tokenizer.
        data (Json): Unrefined.
        model1 (nn.Module): System 1 model.
        model2 (nn.Module): System 2 model.
        device (torch.device): Selected device.
        setting (string, optional): 'distractor / fullwiki'. Defaults to 'fullwiki'.
        max_new_nodes (int, optional): Maximum number of new nodes in cognitive graph. Defaults to 5.
    
    Returns:
        tuple: (gold_ret, ans_ret, graph_ret, ans_nodes_ret)
    """
    context = dict(data['context'])
    e2i = dict([(entity, id) for id, entity in enumerate(context.keys())])
    n = len(context)
    i2e = [''] * n
    for k, v in e2i.items():
        i2e[v] = k
    prev = [[] for i in range(n)]  # elements: (title, sen_num)
    queue = range(n)
    semantics = [None] * n
    input_masks = [None] * n

    tokenized_question = ['[CLS]'] + tokenizer.tokenize(
        data['question']) + ['[SEP]']

    def construct_infer_batch(queue):
        """Construct next batch (frontier nodes to visit).
        
        Args:
            queue (list): A queue containing frontier nodes.
        
        Returns:
            tuple: A batch of inputs
        """
        ids, sep_positions, segment_ids, tokenized_alls, B_starts = [], [], [], [], []
        max_length, max_seps, num_samples = 0, 0, len(queue)
        for x in queue:
            tokenized_all = copy.copy(tokenized_question)
            for title, sen_num in prev[x]:
                tokenized_all += tokenizer.tokenize(
                    context[title][sen_num]) + ['[SEP]']
            if len(tokenized_all) > 512:
                tokenized_all = tokenized_all[:512]
                print('PREV TOO LONG, id: {}'.format(data['_id']))
            segment_id = [0] * len(tokenized_all)
            sep_position = []
            B_starts.append(len(tokenized_all))
            for sen_num, sen in enumerate(context[i2e[x]]):
                tokenized_sen = tokenizer.tokenize(sen) + ['[SEP]']
                if len(tokenized_all) + len(
                        tokenized_sen) > 512 or sen_num > 15:
                    break
                tokenized_all += tokenized_sen
                segment_id += [sen_num + 1] * len(tokenized_sen)
                sep_position.append(len(tokenized_all) - 1)
            max_length = max(max_length, len(tokenized_all))
            max_seps = max(max_seps, len(sep_position))
            tokenized_alls.append(tokenized_all)
            ids.append(tokenizer.convert_tokens_to_ids(tokenized_all))
            sep_positions.append(sep_position)
            segment_ids.append(segment_id)

        ids_tensor = torch.zeros((num_samples, max_length),
                                 dtype=torch.long,
                                 device=device)
        sep_positions_tensor = torch.zeros((num_samples, max_seps),
                                           dtype=torch.long,
                                           device=device)
        segment_ids_tensor = torch.zeros((num_samples, max_length),
                                         dtype=torch.long,
                                         device=device)
        input_mask = torch.zeros((num_samples, max_length),
                                 dtype=torch.long,
                                 device=device)
        B_starts = torch.tensor(B_starts, dtype=torch.long, device=device)
        for i in range(num_samples):
            length = len(ids[i])
            ids_tensor[i, :length] = torch.tensor(ids[i], dtype=torch.long)
            sep_positions_tensor[i, :len(sep_positions[i])] = torch.tensor(
                sep_positions[i], dtype=torch.long)
            segment_ids_tensor[i, :length] = torch.tensor(segment_ids[i],
                                                          dtype=torch.long)
            input_mask[i, :length] = 1
        return ids_tensor, segment_ids_tensor, input_mask, sep_positions_tensor, tokenized_alls, B_starts

    gold_ret, ans_nodes = set([]), set([])
    allow_limit = [0, 0]
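    # Iteratively expand the frontier: System 1 extracts hop and answer
    # spans, which add new nodes and edges until the queue empties.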
    while len(queue) > 0:
        # visit all nodes in the frontier queue
        ids, segment_ids, input_mask, sep_positions, tokenized_alls, B_starts = construct_infer_batch(
            queue)
        hop_preds, ans_preds, semantics_preds, no_ans_logits = model1(
            ids, segment_ids, input_mask, sep_positions, None, None, None,
            None, B_starts, allow_limit)
        new_queue = []
        assert len(queue) == input_mask.shape[0]
        for i, x in enumerate(queue):
            input_masks[x] = input_mask[i]
            semantics[x] = semantics_preds[i]
            # for hop spans
            for k in range(hop_preds.size()[1]):
                l, r, j = hop_preds[i, k]
                j = j.item()
                if l == 0:
                    break
                gold_ret.add((i2e[x], j))  # supporting facts
                orig_text = context[i2e[x]][j]
                pred_slice = tokenized_alls[i][l:r + 1]
                l, r = find_start_end_before_tokenized(orig_text,
                                                       [pred_slice])[0]
                if l == r == 0:
                    continue
                recovered_matched = orig_text[l:r]
                pool = context if setting == 'distractor' else (i2e[x], j)
                matched = fuzzy_retrieve(recovered_matched, pool, setting)
                if matched is not None:
                    if setting == 'fullwiki' and matched not in e2i and n < 10 + max_new_nodes:
                        context_new = get_context_fullwiki(matched)
                        if len(context_new) > 0:  # skip if redirection cannot be resolved
                            # create new nodes in the cognitive graph
                            context[matched] = context_new
                            prev.append([])
                            semantics.append(None)
                            input_masks.append(None)
                            e2i[matched] = n
                            i2e.append(matched)
                            n += 1
                    if matched in e2i and e2i[matched] != x:
                        y = e2i[matched]
                        if y not in new_queue and (i2e[x], j) not in prev[y]:
                            # a new edge brings new clues, so revisit the
                            # successor as a frontier node
                            new_queue.append(y)
                            prev[y].append((i2e[x], j))
            # for ans spans
            for k in range(ans_preds.size()[1]):
                l, r, j = ans_preds[i, k]
                j = j.item()
                if l == 0:
                    break
                gold_ret.add((i2e[x], j))
                orig_text = context[i2e[x]][j]
                pred_slice = tokenized_alls[i][l:r + 1]
                l, r = find_start_end_before_tokenized(orig_text,
                                                       [pred_slice])[0]
                if l == r == 0:
                    continue
                recovered_matched = orig_text[l:r]
                matched = fuzzy_retrieve(recovered_matched,
                                         context,
                                         'distractor',
                                         threshold=70)
                if matched is not None:
                    y = e2i[matched]
                    ans_nodes.add(y)
                    if (i2e[x], j) not in prev[y]:
                        prev[y].append((i2e[x], j))
                elif n < 10 + max_new_nodes:
                    context[recovered_matched] = []
                    e2i[recovered_matched] = n
                    i2e.append(recovered_matched)
                    new_queue.append(n)
                    ans_nodes.add(n)
                    prev.append([(i2e[x], j)])
                    semantics.append(None)
                    input_masks.append(None)
                    n += 1
        if len(new_queue) == 0 and len(ans_nodes) == 0 and allow_limit[
                1] < 0.1:  # must find one answer
            # ``allow_limit'' offsets the negative (no-answer) threshold.
            # If no answer span was accepted, re-queue the node with the
            # lowest no-answer logit and raise the allowance just above it.
            prob, pos_in_queue = torch.min(no_ans_logits, dim=0)
            new_queue.append(queue[pos_in_queue])
            allow_limit[1] = prob.item() + 0.1
        queue = new_queue

    question_type = judge_question_type(data['question'])

    if n == 0:
        return set([]), 'yes', [], []
    if n == 1 and question_type > 0:
        ans_ret = 'yes' if question_type == 1 else i2e[0]
        return [(i2e[0], 0)], ans_ret, [], []
    # GCN || CompareNets
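    # Pad per-node semantic tensors and input masks to a common sequence
    # length so they can be stacked into batch tensors.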
    seq_len = np.max([x.shape[0] for x in semantics])
    for idx in range(len(semantics)):
        if semantics[idx].shape[0] < seq_len:
            semantics[idx] = torch.cat(
                (semantics[idx],
                 torch.zeros(seq_len - semantics[idx].shape[0],
                             semantics[idx].shape[1]).to(device)),
                dim=0)
    seq_len = np.max([x.shape[0] for x in input_masks])
    for idx in range(len(input_masks)):
        input_masks[idx] = torch.cat(
            (input_masks[idx],
             torch.zeros(seq_len - input_masks[idx].shape[0],
                         dtype=torch.long).to(device)),
            dim=0)

    semantics = torch.stack(semantics)
    input_masks = torch.stack(input_masks)
    input_masks = input_masks.unsqueeze(1).unsqueeze(2)
    input_masks = (1.0 - input_masks) * -10000.0
    input_masks = input_masks.to(dtype=torch.float32)  # fp16 compatibility
    if question_type == 0:
        adj = torch.eye(n, device=device) * 2
        for x in range(n):
            for title, sen_num in prev[x]:
                adj[e2i[title], x] = 1
        adj /= torch.sum(adj, dim=0, keepdim=True)
        pred = model2.gcn(adj, semantics, input_masks)
        for x in range(n):
            if x not in ans_nodes:
                pred[x] -= 10000.
        ans_ret = i2e[torch.argmax(pred).item()]
    else:
        # Take the two paragraphs with the most supporting facts as x, y
        gold_num = torch.zeros(n)
        for title, sen_num in gold_ret:
            gold_num[e2i[title]] += 1
        x, y = gold_num.topk(2)[1].tolist()
        diff_sem = semantics[x][0] - semantics[y][0]
        classifier = model2.both_net if question_type == 1 else model2.select_net
        pred = int(torch.sigmoid(classifier(diff_sem)).item() > 0.5)
        if question_type == 1:
            ans_ret = ['no', 'yes'][pred]
        else:
            ans_ret = [i2e[x], i2e[y]][pred]

    ans_ret = re.sub(r' \(.*?\)$', '', ans_ret)

    graph_ret = []
    for x in range(n):
        for title, sen_num in prev[x]:
            graph_ret.append('({}, {}) --> {}'.format(title, sen_num, i2e[x]))

    ans_nodes_ret = [i2e[x] for x in ans_nodes]
    return gold_ret, ans_ret, graph_ret, ans_nodes_ret
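
A minimal inference sketch for the function above; `load_system1`, `load_system2`, and `dev_set` are hypothetical stand-ins for however the trained checkpoints and evaluation data are actually loaded (tokenizer as in the sketch after Example #2):

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model1 = load_system1(device)  # hypothetical: trained System 1 span extractor
model2 = load_system2(device)  # hypothetical: trained System 2 GCN/comparators
for data in dev_set:           # hypothetical: HotpotQA-style eval examples
    gold, answer, graph, ans_nodes = cognitive_graph_propagate(
        tokenizer, data, model1, model2, device, setting='fullwiki')
    print(data['question'], '->', answer)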
Example #4
# (Same imports as Example #3.)
def cognitive_graph_propagate(tokenizer,
                              data: 'Json eval(Context as pool)',
                              model,
                              model_cg,
                              device,
                              setting: 'distractor / fullwiki' = 'distractor',
                              max_new_nodes=5):
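    # Variant of Example #3: node semantics are single vectors (no input
    # masks) and System 2 is passed in as `model_cg`.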
    context = dict(data['context'])
    e2i = dict([(entity, id) for id, entity in enumerate(context.keys())])
    n = len(context)
    i2e = [''] * n
    for k, v in e2i.items():
        i2e[v] = k
    prev = [[] for i in range(n)]  # elements: (title, sen_num)
    queue = range(n)
    semantics = [None] * n

    tokenized_question = ['[CLS]'] + tokenizer.tokenize(
        data['question']) + ['[SEP]']

    def construct_infer_batch(queue):
        ids, sep_positions, segment_ids, tokenized_alls, B_starts = [], [], [], [], []
        max_length, max_seps, num_samples = 0, 0, len(queue)
        for x in queue:
            tokenized_all = copy.copy(tokenized_question)
            for title, sen_num in prev[x]:
                tokenized_all += tokenizer.tokenize(
                    context[title][sen_num]) + ['[SEP]']
            if len(tokenized_all) > 512:
                tokenized_all = tokenized_all[:512]
                print('PREV TOO LONG, id: {}'.format(data['_id']))
            segment_id = [0] * len(tokenized_all)
            sep_position = []
            B_starts.append(len(tokenized_all))
            for sen_num, sen in enumerate(context[i2e[x]]):
                tokenized_sen = tokenizer.tokenize(sen) + ['[SEP]']
                if len(tokenized_all) + len(
                        tokenized_sen) > 512 or sen_num > 15:
                    break
                tokenized_all += tokenized_sen
                segment_id += [sen_num + 1] * len(tokenized_sen)
                sep_position.append(len(tokenized_all) - 1)
            max_length = max(max_length, len(tokenized_all))
            max_seps = max(max_seps, len(sep_position))
            tokenized_alls.append(tokenized_all)
            ids.append(tokenizer.convert_tokens_to_ids(tokenized_all))
            sep_positions.append(sep_position)
            segment_ids.append(segment_id)

        ids_tensor = torch.zeros((num_samples, max_length),
                                 dtype=torch.long,
                                 device=device)
        sep_positions_tensor = torch.zeros((num_samples, max_seps),
                                           dtype=torch.long,
                                           device=device)
        segment_ids_tensor = torch.zeros((num_samples, max_length),
                                         dtype=torch.long,
                                         device=device)
        input_mask = torch.zeros((num_samples, max_length),
                                 dtype=torch.long,
                                 device=device)
        B_starts = torch.tensor(B_starts, dtype=torch.long, device=device)
        for i in range(num_samples):
            length = len(ids[i])
            ids_tensor[i, :length] = torch.tensor(ids[i], dtype=torch.long)
            sep_positions_tensor[i, :len(sep_positions[i])] = torch.tensor(
                sep_positions[i], dtype=torch.long)
            segment_ids_tensor[i, :length] = torch.tensor(segment_ids[i],
                                                          dtype=torch.long)
            input_mask[i, :length] = 1
        return ids_tensor, segment_ids_tensor, input_mask, sep_positions_tensor, tokenized_alls, B_starts

    gold_ret, ans_nodes = set([]), set([])
    allow_limit = [0, 0]
    while len(queue) > 0:
        ids, segment_ids, input_mask, sep_positions, tokenized_alls, B_starts = construct_infer_batch(
            queue)
        # pdb.set_trace()
        hop_preds, ans_preds, semantics_preds, no_ans_logits = model(
            ids, segment_ids, input_mask, sep_positions, None, None, None,
            None, B_starts, allow_limit)
        new_queue = []
        for i, x in enumerate(queue):
            semantics[x] = semantics_preds[i]
            # for hop spans
            for k in range(hop_preds.size()[1]):
                l, r, j = hop_preds[i, k]
                j = j.item()
                if l == 0:
                    break
                gold_ret.add((i2e[x], j))
                orig_text = context[i2e[x]][j]
                pred_slice = tokenized_alls[i][l:r + 1]
                l, r = find_start_end_before_tokenized(orig_text,
                                                       [pred_slice])[0]
                if l == r == 0:
                    continue
                recovered_matched = orig_text[l:r]
                pool = context if setting == 'distractor' else (i2e[x], j)
                matched = fuzzy_retrieve(recovered_matched, pool, setting)
                if matched is not None:
                    if setting == 'fullwiki' and matched not in e2i and n < 10 + max_new_nodes:
                        context_new = get_context_fullwiki(matched)
                        if len(context_new) > 0:  # skip if redirection cannot be resolved
                            context[matched] = context_new
                            prev.append([])
                            semantics.append(None)
                            e2i[matched] = n
                            i2e.append(matched)
                            n += 1
                    if matched in e2i and e2i[matched] != x:
                        y = e2i[matched]
                        if y not in new_queue and (i2e[x], j) not in prev[y]:
                            new_queue.append(y)
                            prev[y].append((i2e[x], j))
            # for ans spans
            for k in range(ans_preds.size()[1]):
                l, r, j = ans_preds[i, k]
                j = j.item()
                if l == 0:
                    break
                gold_ret.add((i2e[x], j))
                orig_text = context[i2e[x]][j]
                pred_slice = tokenized_alls[i][l:r + 1]
                l, r = find_start_end_before_tokenized(orig_text,
                                                       [pred_slice])[0]
                if l == r == 0:
                    continue
                recovered_matched = orig_text[l:r]
                # pool = context if setting == 'distractor' else (i2e[x], j)
                matched = fuzzy_retrieve(recovered_matched,
                                         context,
                                         'distractor',
                                         threshold=70)
                if matched is not None:
                    y = e2i[matched]
                    ans_nodes.add(y)
                    if (i2e[x], j) not in prev[y]:
                        prev[y].append((i2e[x], j))
                elif n < 10 + max_new_nodes:
                    context[recovered_matched] = []
                    e2i[recovered_matched] = n
                    i2e.append(recovered_matched)
                    new_queue.append(n)
                    ans_nodes.add(n)
                    prev.append([(i2e[x], j)])
                    semantics.append(None)
                    n += 1
        if len(new_queue) == 0 and len(ans_nodes) == 0 and allow_limit[
                1] < 0.1:  # must find one answer
            prob, pos_in_queue = torch.min(no_ans_logits, dim=0)
            new_queue.append(queue[pos_in_queue])
            allow_limit[1] = prob.item() + 0.1
        queue = new_queue

    question_type = judge_question_type(data['question'])

    if n == 0:
        return set([]), 'yes', [], []
    if n == 1 and question_type > 0:
        ans_ret = 'yes' if question_type == 1 else i2e[0]
        return [(i2e[0], 0)], ans_ret, [], []
    # GCN || CompareNets
    semantics = torch.stack(semantics)
    if question_type == 0:
        adj = torch.eye(n, device=device) * 2
        for x in range(n):
            for title, sen_num in prev[x]:
                adj[e2i[title], x] = 1
        adj /= torch.sum(adj, dim=0, keepdim=True)
        pred = model_cg.gcn(adj, semantics)
        for x in range(n):
            if x not in ans_nodes:
                pred[x] -= 10000.
        ans_ret = i2e[torch.argmax(pred).item()]
    else:
        # Take the two paragraphs with the most supporting facts as x, y
        gold_num = torch.zeros(n)
        for title, sen_num in gold_ret:
            gold_num[e2i[title]] += 1
        x, y = gold_num.topk(2)[1].tolist()
        diff_sem = semantics[x] - semantics[y]
        classifier = model_cg.both_net if question_type == 1 else model_cg.select_net
        pred = int(classifier(diff_sem).item() > 0.5)
        if question_type == 1:
            ans_ret = ['no', 'yes'][pred]
        else:
            ans_ret = [i2e[x], i2e[y]][pred]

    ans_ret = re.sub(r' \(.*?\)$', '', ans_ret)

    graph_ret = []
    for x in range(n):
        for title, sen_num in prev[x]:
            graph_ret.append('({}, {}) --> {}'.format(title, sen_num, i2e[x]))

    ans_nodes_ret = [i2e[x] for x in ans_nodes]
    return gold_ret, ans_ret, graph_ret, ans_nodes_ret