# The three __getitem__ methods below belong to PyTorch Dataset classes in the
# surrounding project; the helpers context_random_selection,
# context_merge_longer, test_context_merge_longer, and mask_generation are
# defined elsewhere in that project.
import numpy as np
import torch

    def __getitem__(self, idx):
        example = self.data.iloc[idx]
        example_id = torch.LongTensor([example['e_id']])  # for ddp validation alignment
        ####
        query_encode = example['ques_encode']
        ctx_encode_list = example['ctx_encode_list']
        norm_answer = example['norm_answer']
        # select contexts: presumably the gold documents plus up to neg_doc_num sampled negatives
        ctx_encode = context_random_selection(context_tuple_list=ctx_encode_list,
                                              neg_num=self.neg_doc_num,
                                              shuffle=self.shuffle)
        # answer-type labels: span = 0, yes = 1, no/noanswer = 2
        norm_answer = norm_answer.strip()
        if norm_answer in ['yes', 'no', 'noanswer']:
            answer_type_label = np.array([1]) if norm_answer == 'yes' else np.array([2])
        else:
            answer_type_label = np.array([0])
        span_flag = example['span_flag']
        doc_infor, sent_infor, seq_infor, answer_infor = context_merge_longer(
            query_encode_ids=query_encode,
            context_tuple_list=ctx_encode,
            span_flag=span_flag)
        (doc_labels, doc_ans_labels, doc_num, doc_len_array, doc_start_position,
         doc_end_position, doc_head_idx, doc_tail_idx) = doc_infor
        (sent_labels, sent_ans_labels, sent_num, sent_len_array, sent_start_position,
         sent_end_position, sent2doc_map_array, abs_sentIndoc_array, doc_sent_nums) = sent_infor
        (concat_ctx_array, token_num, global_attn_marker, token2sentID_map,
         answer_mask_idxs) = seq_infor
        answer_pos_start, answer_pos_end, _ = answer_infor
        ####
        concat_sent_num = sent_num
        concat_len = token_num
        ####
        cat_doc_encodes = concat_ctx_array.tolist()
        cat_doc_attention_mask = [1] * concat_len
        cat_doc_global_attn_mask = global_attn_marker.tolist()
        ctx_marker_mask = answer_mask_idxs.tolist()
        ctx_token2sent_map = token2sentID_map.tolist()
        assert concat_len == len(cat_doc_encodes)

        # pad all token-level fields to max_token_num
        if concat_len < self.max_token_num:
            token_pad_num = self.max_token_num - concat_len
            cat_doc_encodes = cat_doc_encodes + [self.pad_token_id] * token_pad_num
            cat_doc_attention_mask = cat_doc_attention_mask + [0] * token_pad_num
            ctx_token2sent_map = ctx_token2sent_map + [0] * token_pad_num
            cat_doc_global_attn_mask = cat_doc_global_attn_mask + [0] * token_pad_num
            ctx_marker_mask = ctx_marker_mask + [0] * token_pad_num
        cat_doc_encodes = torch.LongTensor(cat_doc_encodes)
        cat_doc_attention_mask = torch.LongTensor(cat_doc_attention_mask)
        ctx_token2sent_map = torch.LongTensor(ctx_token2sent_map)
        cat_doc_global_attn_mask = torch.LongTensor(cat_doc_global_attn_mask)
        ctx_marker_mask = torch.BoolTensor(ctx_marker_mask)
        ################################################################################################################
        doc_start_idxes = doc_start_position.tolist()
        doc_end_idxes = doc_end_position.tolist()
        doc_lens = doc_len_array.tolist()
        doc_labels = doc_ans_labels.tolist()  # answer-doc labels replace the doc_labels unpacked above
        doc_sent_nums = doc_sent_nums.tolist()
        # pad document-level fields to max_doc_num
        if doc_num < self.max_doc_num:
            doc_pad_num = self.max_doc_num - doc_num
            doc_start_idxes = doc_start_idxes + [0] * doc_pad_num
            doc_end_idxes = doc_end_idxes + [0] * doc_pad_num
            doc_lens = doc_lens + [0] * doc_pad_num
            doc_labels = doc_labels + [0] * doc_pad_num
            doc_sent_nums = doc_sent_nums + [0] * doc_pad_num
        doc_start_idxes = torch.LongTensor(doc_start_idxes)
        doc_end_idxes = torch.LongTensor(doc_end_idxes)
        doc_lens = torch.LongTensor(doc_lens)
        doc_labels = torch.LongTensor(doc_labels)
        ################################################################################################################
        sent_start_idxes = sent_start_position.tolist()
        sent_end_idxes = sent_end_position.tolist()
        ctx_sent_lens = sent_len_array.tolist()
        supp_sent_labels = sent_ans_labels.tolist()
        ctx_sent2doc_map = sent2doc_map_array.tolist()
        ctx_sentIndoc_idx = abs_sentIndoc_array.tolist()
        # pad sentence-level fields to max_sent_num
        if concat_sent_num < self.max_sent_num:
            sent_pad_num = self.max_sent_num - concat_sent_num
            sent_start_idxes = sent_start_idxes + [0] * sent_pad_num
            sent_end_idxes = sent_end_idxes + [0] * sent_pad_num
            ctx_sent_lens = ctx_sent_lens + [0] * sent_pad_num
            supp_sent_labels = supp_sent_labels + [0] * sent_pad_num
            ctx_sent2doc_map = ctx_sent2doc_map + [0] * sent_pad_num
            ctx_sentIndoc_idx = ctx_sentIndoc_idx + [0] * sent_pad_num
        sent_start_idxes = torch.LongTensor(sent_start_idxes)
        sent_end_idxes = torch.LongTensor(sent_end_idxes)
        ctx_sent_lens = torch.LongTensor(ctx_sent_lens)
        supp_sent_labels = torch.LongTensor(supp_sent_labels)
        ctx_sent2doc_map = torch.LongTensor(ctx_sent2doc_map)
        ctx_sentIndoc_idx = torch.LongTensor(ctx_sentIndoc_idx)
        ################################################################################################################
        answer_start_idx = torch.LongTensor(answer_pos_start)
        answer_end_idx = torch.LongTensor(answer_pos_end)
        ################################################################################################################
        yes_no_label = torch.LongTensor(answer_type_label)
        ################################################################################################################
        head_doc_idx = torch.LongTensor(doc_head_idx)
        tail_doc_idx = torch.LongTensor(doc_tail_idx)
        ################################################################################################################
        # attention masks over sentence-sentence and sentence-document pairs
        ss_attn_mask, sd_attn_mask = mask_generation(
            sent_num_docs=doc_sent_nums, max_sent_num=self.max_sent_num)
        ################################################################################################################
        assert concat_len <= self.max_token_num
        ##++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        return (cat_doc_encodes, cat_doc_attention_mask, cat_doc_global_attn_mask,
                doc_start_idxes, sent_start_idxes, answer_start_idx, answer_end_idx,
                doc_lens, doc_labels, ctx_sent_lens, supp_sent_labels, yes_no_label,
                head_doc_idx, tail_doc_idx, ss_attn_mask, sd_attn_mask,
                ctx_sent2doc_map, ctx_sentIndoc_idx, ctx_token2sent_map,
                ctx_marker_mask, doc_end_idxes, sent_end_idxes, concat_len,
                concat_sent_num, example_id)
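
# ----------------------------------------------------------------------------
# Hedged sketch, not from the original source: context_random_selection is
# called above but not defined in this excerpt. A minimal version, assuming
# each context tuple carries a gold/supporting flag at index 2, keeps every
# gold document and samples at most neg_num negatives:
import random

def context_random_selection(context_tuple_list, neg_num, shuffle=False):
    gold_docs = [ctx for ctx in context_tuple_list if ctx[2] > 0]   # assumed gold flag at index 2
    neg_docs = [ctx for ctx in context_tuple_list if ctx[2] == 0]
    selected = gold_docs + random.sample(neg_docs, min(neg_num, len(neg_docs)))
    if shuffle:
        random.shuffle(selected)  # randomize document order during training
    return selected
# ----------------------------------------------------------------------------
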
    def __getitem__(self, idx):
        example = self.data.iloc[idx]
        example_id = torch.LongTensor([example['e_id']])  # for ddp validation alignment
        ####
        query_encode = example['ques_encode']
        ctx_encode = example['ctx_encode_list']
        doc_infor, sent_infor, seq_infor = test_context_merge_longer(
            query_encode_ids=query_encode, context_tuple_list=ctx_encode)
        doc_num, doc_len_array, doc_start_position, doc_end_position = doc_infor
        (sent_num, sent_len_array, sent_start_position, sent_end_position,
         sent2doc_map_array, abs_sentIndoc_array, doc_sent_nums) = sent_infor
        (concat_ctx_array, token_num, global_attn_marker, token2sentID_map,
         answer_mask_idxs) = seq_infor
        ####
        concat_sent_num = sent_num
        concat_len = token_num
        ####
        cat_doc_encodes = concat_ctx_array.tolist()
        cat_doc_attention_mask = [1] * concat_len
        cat_doc_global_attn_mask = global_attn_marker.tolist()
        ctx_marker_mask = answer_mask_idxs.tolist()
        ctx_token2sent_map = token2sentID_map.tolist()
        assert concat_len == len(cat_doc_encodes)

        # pad all token-level fields to max_token_num
        if concat_len < self.max_token_num:
            token_pad_num = self.max_token_num - concat_len
            cat_doc_encodes = cat_doc_encodes + [self.pad_token_id] * token_pad_num
            cat_doc_attention_mask = cat_doc_attention_mask + [0] * token_pad_num
            ctx_token2sent_map = ctx_token2sent_map + [0] * token_pad_num
            cat_doc_global_attn_mask = cat_doc_global_attn_mask + [0] * token_pad_num
            ctx_marker_mask = ctx_marker_mask + [0] * token_pad_num
        cat_doc_encodes = torch.LongTensor(cat_doc_encodes)
        cat_doc_attention_mask = torch.LongTensor(cat_doc_attention_mask)
        ctx_token2sent_map = torch.LongTensor(ctx_token2sent_map)
        cat_doc_global_attn_mask = torch.LongTensor(cat_doc_global_attn_mask)
        ctx_marker_mask = torch.BoolTensor(ctx_marker_mask)
        ################################################################################################################
        doc_start_idxes = doc_start_position.tolist()
        doc_end_idxes = doc_end_position.tolist()
        doc_lens = doc_len_array.tolist()
        doc_sent_nums = doc_sent_nums.tolist()
        # pad document-level fields to max_doc_num
        if doc_num < self.max_doc_num:
            doc_pad_num = self.max_doc_num - doc_num
            doc_start_idxes = doc_start_idxes + [0] * doc_pad_num
            doc_end_idxes = doc_end_idxes + [0] * doc_pad_num
            doc_lens = doc_lens + [0] * doc_pad_num
            doc_sent_nums = doc_sent_nums + [0] * doc_pad_num
        doc_start_idxes = torch.LongTensor(doc_start_idxes)
        doc_end_idxes = torch.LongTensor(doc_end_idxes)
        doc_lens = torch.LongTensor(doc_lens)
        ################################################################################################################
        sent_start_idxes = sent_start_position.tolist()
        sent_end_idxes = sent_end_position.tolist()
        ctx_sent_lens = sent_len_array.tolist()
        ctx_sent2doc_map = sent2doc_map_array.tolist()
        ctx_sentIndoc_idx = abs_sentIndoc_array.tolist()
        # pad sentence-level fields to max_sent_num
        if concat_sent_num < self.max_sent_num:
            sent_pad_num = self.max_sent_num - concat_sent_num
            sent_start_idxes = sent_start_idxes + [0] * sent_pad_num
            sent_end_idxes = sent_end_idxes + [0] * sent_pad_num
            ctx_sent_lens = ctx_sent_lens + [0] * sent_pad_num
            ctx_sent2doc_map = ctx_sent2doc_map + [0] * sent_pad_num
            ctx_sentIndoc_idx = ctx_sentIndoc_idx + [0] * sent_pad_num
        sent_start_idxes = torch.LongTensor(sent_start_idxes)
        sent_end_idxes = torch.LongTensor(sent_end_idxes)
        ctx_sent_lens = torch.LongTensor(ctx_sent_lens)
        ctx_sent2doc_map = torch.LongTensor(ctx_sent2doc_map)
        ctx_sentIndoc_idx = torch.LongTensor(ctx_sentIndoc_idx)
        ################################################################################################################
        ss_attn_mask, sd_attn_mask = mask_generation(
            sent_num_docs=doc_sent_nums, max_sent_num=self.max_sent_num)
        ################################################################################################################
        assert concat_len <= self.max_token_num
        ################################################################################################################
        return (cat_doc_encodes, cat_doc_attention_mask, cat_doc_global_attn_mask,
                doc_start_idxes, sent_start_idxes, doc_lens, ctx_sent_lens,
                ss_attn_mask, sd_attn_mask, ctx_sent2doc_map, ctx_sentIndoc_idx,
                ctx_token2sent_map, ctx_marker_mask, doc_end_idxes, sent_end_idxes,
                concat_len, concat_sent_num, example_id)
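
# ----------------------------------------------------------------------------
# Hedged sketch, not from the original source: mask_generation is called in
# every __getitem__ here but not defined in this excerpt. A plausible version,
# assuming ss_attn_mask marks sentence pairs that share a document and
# sd_attn_mask marks which sentence belongs to which document:
import torch

def mask_generation(sent_num_docs, max_sent_num):
    doc_num = len(sent_num_docs)
    ss_attn_mask = torch.zeros(max_sent_num, max_sent_num, dtype=torch.long)
    sd_attn_mask = torch.zeros(doc_num, max_sent_num, dtype=torch.long)
    offset = 0
    for doc_idx, sent_num in enumerate(sent_num_docs):
        # sentences in the same document may attend to each other
        ss_attn_mask[offset:offset + sent_num, offset:offset + sent_num] = 1
        # each document attends only to its own sentences
        sd_attn_mask[doc_idx, offset:offset + sent_num] = 1
        offset += sent_num
    return ss_attn_mask, sd_attn_mask
# ----------------------------------------------------------------------------
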
Example #3
    def __getitem__(self, idx):
        example = self.data.iloc[idx]
        example_id = torch.LongTensor([example['e_id']])  # for ddp validation alignment
        ####
        doc_num = example['doc_num']
        concat_sent_num = example['sent_num']
        concat_len = example['ctx_len']
        ####
        cat_doc_encodes = example['ctx_encode']
        cat_doc_attention_mask = [1] * concat_len
        cat_doc_global_attn_mask = example['global_attn']
        ctx_marker_mask = example['ans_mask']

        ctx_token2sent_map = example['token2sent']
        assert concat_len == len(cat_doc_encodes)
        # pad all token-level fields to max_token_num
        if concat_len < self.max_token_num:
            token_pad_num = self.max_token_num - concat_len
            cat_doc_encodes = cat_doc_encodes + [self.pad_token_id] * token_pad_num
            cat_doc_global_attn_mask = cat_doc_global_attn_mask + [0] * token_pad_num
            ctx_marker_mask = ctx_marker_mask + [0] * token_pad_num
            cat_doc_attention_mask = cat_doc_attention_mask + [0] * token_pad_num
            ctx_token2sent_map = ctx_token2sent_map + [0] * token_pad_num
        cat_doc_encodes = torch.LongTensor(cat_doc_encodes)
        cat_doc_attention_mask = torch.LongTensor(cat_doc_attention_mask)
        ctx_token2sent_map = torch.LongTensor(ctx_token2sent_map)
        cat_doc_global_attn_mask = torch.LongTensor(cat_doc_global_attn_mask)
        ctx_marker_mask = torch.BoolTensor(ctx_marker_mask)
        ################################################################################################################
        doc_start_idxes = example['doc_start']
        doc_end_idxes = example['doc_end']
        doc_lens = example['doc_len']
        doc_sent_nums = example['doc_sent_num']
        # pad document-level fields to max_doc_num
        if doc_num < self.max_doc_num:
            doc_pad_num = self.max_doc_num - doc_num
            doc_start_idxes = doc_start_idxes + [0] * doc_pad_num
            doc_end_idxes = doc_end_idxes + [0] * doc_pad_num
            doc_lens = doc_lens + [0] * doc_pad_num
            doc_sent_nums = doc_sent_nums + [0] * doc_pad_num
        doc_start_idxes = torch.LongTensor(doc_start_idxes)
        doc_end_idxes = torch.LongTensor(doc_end_idxes)
        doc_lens = torch.LongTensor(doc_lens)
        ################################################################################################################
        sent_start_idxes = example['sent_start']
        sent_end_idxes = example['sent_end']
        ctx_sent_lens = example['sent_len']
        ctx_sent2doc_map = example['sent2doc']
        ctx_sentIndoc_idx = example['sentIndoc']
        # pad sentence-level fields to max_sent_num
        if concat_sent_num < self.max_sent_num:
            sent_pad_num = self.max_sent_num - concat_sent_num
            sent_start_idxes = sent_start_idxes + [0] * sent_pad_num
            sent_end_idxes = sent_end_idxes + [0] * sent_pad_num
            ctx_sent_lens = ctx_sent_lens + [0] * sent_pad_num
            ctx_sent2doc_map = ctx_sent2doc_map + [0] * sent_pad_num
            ctx_sentIndoc_idx = ctx_sentIndoc_idx + [0] * sent_pad_num
        sent_start_idxes = torch.LongTensor(sent_start_idxes)
        sent_end_idxes = torch.LongTensor(sent_end_idxes)
        ctx_sent_lens = torch.LongTensor(ctx_sent_lens)
        ctx_sent2doc_map = torch.LongTensor(ctx_sent2doc_map)
        ctx_sentIndoc_idx = torch.LongTensor(ctx_sentIndoc_idx)
        ################################################################################################################
        ss_attn_mask, sd_attn_mask = mask_generation(sent_num_docs=doc_sent_nums, max_sent_num=self.max_sent_num)
        ################################################################################################################
        assert concat_len <= self.max_token_num
        ################################################################################################################
        return (cat_doc_encodes, cat_doc_attention_mask, cat_doc_global_attn_mask,
                doc_start_idxes, sent_start_idxes, doc_lens, ctx_sent_lens,
                ss_attn_mask, sd_attn_mask, ctx_sent2doc_map, ctx_sentIndoc_idx,
                ctx_token2sent_map, ctx_marker_mask, doc_end_idxes, sent_end_idxes,
                concat_len, concat_sent_num, example_id)
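
# ----------------------------------------------------------------------------
# Hedged usage sketch, not from the original source: every field above is
# padded to a fixed size, so the default collate_fn can stack examples into a
# batch directly. DevDataset, dev_data, and all constructor arguments are
# hypothetical stand-ins for whichever class defines the second __getitem__.
from torch.utils.data import DataLoader

dev_dataset = DevDataset(data=dev_data, max_token_num=4096,   # illustrative sizes
                         max_doc_num=10, max_sent_num=150)
dev_loader = DataLoader(dev_dataset, batch_size=8, shuffle=False, num_workers=4)
for batch in dev_loader:
    example_ids = batch[-1]  # e_id per example, used for DDP validation alignment
# ----------------------------------------------------------------------------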