Example #1
    def _prepare_doc(self, curr_doc: Document) -> Dict:
        """ Returns a cache dictionary with preprocessed data. This should only be called once per document, since
        data inside same document does not get shuffled. """
        ret = {}

        preprocessed_sents, max_len = [], 0
        for curr_sent in curr_doc.raw_sentences():
            # TODO: uncased/cased option
            # Lowercase and strip each token, then append a "<PAD>" sentinel so that
            # position -1 always resolves to padding
            curr_processed_sent = [token.lower().strip() for token in curr_sent] + ["<PAD>"]
            preprocessed_sents.append(curr_processed_sent)
            if len(curr_processed_sent) > max_len:
                max_len = len(curr_processed_sent)

        # Right-pad all sentences to a common length so positions can be indexed uniformly
        for curr_sent in preprocessed_sents:
            curr_sent.extend(["<PAD>"] * (max_len - len(curr_sent)))

        cluster_sets = []
        mention_to_cluster_id = {}
        for i, curr_cluster in enumerate(curr_doc.clusters):
            cluster_sets.append(set(curr_cluster))
            for mid in curr_cluster:
                mention_to_cluster_id[mid] = i

        all_candidate_data = []
        for idx_head, (head_id, head_mention) in enumerate(curr_doc.mentions.items(), start=1):
            gt_antecedent_ids = cluster_sets[mention_to_cluster_id[head_id]]

            # Note: no data for dummy antecedent (len(`features`) is one less than `candidates`)
            candidates, candidate_data = [None], []
            starts, ends = [], []
            candidate_attention = []
            correct_antecedents = []

            # Maps head tokens to positions inside the document (idx_sent, idx_inside_sent)
            curr_head_data = [[], []]
            for curr_token in head_mention.tokens:
                curr_head_data[0].append(curr_token.sentence_index)
                curr_head_data[1].append(curr_token.position_in_sentence)

            num_tokens = len(head_mention.tokens)
            if num_tokens > self.max_span_size:
                curr_head_data[0] = curr_head_data[0][:self.max_span_size]
                curr_head_data[1] = curr_head_data[1][:self.max_span_size]
            else:
                # Padding positions point at the "<PAD>" token (position -1) of the mention's first sentence
                curr_head_data[0] += [head_mention.tokens[0].sentence_index] * (self.max_span_size - num_tokens)
                curr_head_data[1] += [-1] * (self.max_span_size - num_tokens)

            head_start = 0
            head_end = min(num_tokens, self.max_span_size)  # clamp so the end matches the (possibly truncated) span
            head_attention = torch.ones((1, self.max_span_size), dtype=torch.bool)
            head_attention[0, num_tokens:] = False

            for idx_candidate, (cand_id, cand_mention) in enumerate(curr_doc.mentions.items(), start=1):
                # Only mentions appearing before the current head are valid antecedent candidates
                if idx_candidate >= idx_head:
                    break

                candidates.append(cand_id)

                # Maps tokens to positions inside document (idx_sent, idx_inside_sent) for efficient indexing later
                curr_candidate_data = [[], []]
                for curr_token in cand_mention.tokens:
                    curr_candidate_data[0].append(curr_token.sentence_index)
                    curr_candidate_data[1].append(curr_token.position_in_sentence)

                num_tokens = len(cand_mention.tokens)
                if num_tokens > self.max_span_size:
                    curr_candidate_data[0] = curr_candidate_data[0][:self.max_span_size]
                    curr_candidate_data[1] = curr_candidate_data[1][:self.max_span_size]
                else:
                    curr_candidate_data[0] += [cand_mention.tokens[0].sentence_index] * (self.max_span_size - num_tokens)
                    curr_candidate_data[1] += [-1] * (self.max_span_size - num_tokens)

                candidate_data.append(curr_candidate_data)
                starts.append(0)
                ends.append(min(num_tokens, self.max_span_size))

                curr_attention = torch.ones((1, self.max_span_size), dtype=torch.bool)
                curr_attention[0, num_tokens:] = False
                candidate_attention.append(curr_attention)

                is_coreferent = cand_id in gt_antecedent_ids
                if is_coreferent:
                    correct_antecedents.append(idx_candidate)

            # If no preceding mention is coreferent with the head, the dummy antecedent (index 0) is correct
            if len(correct_antecedents) == 0:
                correct_antecedents.append(0)

            candidate_attention = torch.cat(candidate_attention) if len(candidate_attention) > 0 else []

            all_candidate_data.append({
                "head_id": head_id,
                "head_data": torch.tensor([curr_head_data]),
                "head_attention": head_attention,
                "head_start": head_start,
                "head_end": head_end,
                "candidates": candidates,
                "candidate_data": torch.tensor(candidate_data),
                "candidate_attention": candidate_attention,
                "correct_antecedents": correct_antecedents
            })

        ret["preprocessed_sents"] = preprocessed_sents
        ret["steps"] = all_candidate_data

        return ret
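
The subtle part of the example above is the padding convention: every sentence gets a trailing
"<PAD>" token, so a padded span position of -1 always resolves to padding no matter which
sentence it points into. Below is a minimal, self-contained sketch of that convention with toy
data and illustrative names (nothing here is part of the original class):

    sentences = [["John", "saw", "Mary"], ["She", "waved"]]

    preprocessed_sents, max_len = [], 0
    for sent in sentences:
        processed = [tok.lower().strip() for tok in sent] + ["<PAD>"]
        preprocessed_sents.append(processed)
        max_len = max(max_len, len(processed))

    # Right-pad all sentences to a common length, as in _prepare_doc
    for sent in preprocessed_sents:
        sent.extend(["<PAD>"] * (max_len - len(sent)))

    # A mention span padded to max_span_size = 4, stored as (sentence_index, position_in_sentence)
    # pairs; -1 marks padding and indexes the trailing "<PAD>" token of that sentence
    span_positions = [(1, 0), (1, 1), (1, -1), (1, -1)]  # "She waved" + 2 padding slots
    print([preprocessed_sents[s][p] for s, p in span_positions])  # ['she', 'waved', '<PAD>', '<PAD>']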
Example #2
    def _prepare_doc(self, curr_doc: Document) -> Dict:
        """ Returns a cache dictionary with preprocessed data. This should only be called once per document, since
        data inside same document does not get shuffled. """
        ret = {}

        # By default, each sentence is its own segment, meaning sentences are processed independently
        if self.max_segment_size is None:

            def get_position(t):
                return t.sentence_index, t.position_in_sentence

            _encoded_segments = batch_to_ids(curr_doc.raw_sentences())
        # Optionally, one can specify max_segment_size, in which case segments of tokens are processed independently
        else:

            def get_position(t):
                doc_position = t.position_in_document
                return doc_position // self.max_segment_size, doc_position % self.max_segment_size

            flattened_doc = list(chain(*curr_doc.raw_sentences()))
            # Ceiling division: number of fixed-size segments needed to cover the whole document
            num_segments = (len(flattened_doc) + self.max_segment_size - 1) // self.max_segment_size
            _encoded_segments = batch_to_ids([
                flattened_doc[idx_seg * self.max_segment_size: (idx_seg + 1) * self.max_segment_size]
                for idx_seg in range(num_segments)
            ])

        encoded_segments = []
        # Convention: add a PAD word (an all-zero vector of ELMoCharacterMapper.max_word_length
        # character ids) at the end of each segment, so that position -1 safely indexes padding
        for curr_segment in _encoded_segments:
            pad_word = torch.zeros((1, ELMoCharacterMapper.max_word_length), dtype=torch.long)
            encoded_segments.append(torch.cat((curr_segment, pad_word)))
        encoded_segments = torch.stack(encoded_segments)

        cluster_sets = []
        mention_to_cluster_id = {}
        for i, curr_cluster in enumerate(curr_doc.clusters):
            cluster_sets.append(set(curr_cluster))
            for mid in curr_cluster:
                mention_to_cluster_id[mid] = i

        all_candidate_data = []
        for idx_head, (head_id, head_mention) in enumerate(curr_doc.mentions.items(), start=1):
            gt_antecedent_ids = cluster_sets[mention_to_cluster_id[head_id]]

            # Note: no data for dummy antecedent (len(`features`) is one less than `candidates`)
            candidates, candidate_data = [None], []
            candidate_attention = []
            correct_antecedents = []

            # Maps head tokens to positions inside segments (idx_seg, idx_inside_seg)
            curr_head_data = [[], []]
            num_head_words = 0
            for curr_token in head_mention.tokens:
                idx_segment, idx_inside_segment = get_position(curr_token)
                curr_head_data[0].append(idx_segment)
                curr_head_data[1].append(idx_inside_segment)
                num_head_words += 1

            if num_head_words > self.max_span_size:
                curr_head_data[0] = curr_head_data[0][:self.max_span_size]
                curr_head_data[1] = curr_head_data[1][:self.max_span_size]
            else:
                # Padding positions index into the PAD word of the head's last segment
                curr_head_data[0] += [curr_head_data[0][-1]] * (self.max_span_size - num_head_words)
                curr_head_data[1] += [-1] * (self.max_span_size - num_head_words)

            head_attention = torch.ones((1, self.max_span_size), dtype=torch.bool)
            head_attention[0, num_head_words:] = False

            for idx_candidate, (cand_id, cand_mention) in enumerate(curr_doc.mentions.items(), start=1):
                if idx_candidate >= idx_head:
                    break

                candidates.append(cand_id)

                # Maps tokens to positions inside segments (idx_seg, idx_inside_seg) for efficient indexing later
                curr_candidate_data = [[], []]
                num_candidate_words = 0
                for curr_token in cand_mention.tokens:
                    idx_segment, idx_inside_segment = get_position(curr_token)
                    curr_candidate_data[0].append(idx_segment)
                    curr_candidate_data[1].append(idx_inside_segment)
                    num_candidate_words += 1

                if num_candidate_words > self.max_span_size:
                    curr_candidate_data[0] = curr_candidate_data[0][:self.max_span_size]
                    curr_candidate_data[1] = curr_candidate_data[1][:self.max_span_size]
                else:
                    # Padding positions index into the PAD word of the candidate's last segment
                    curr_candidate_data[0] += [curr_candidate_data[0][-1]] * (self.max_span_size - num_candidate_words)
                    curr_candidate_data[1] += [-1] * (self.max_span_size - num_candidate_words)

                candidate_data.append(curr_candidate_data)
                curr_attention = torch.ones((1, self.max_span_size), dtype=torch.bool)
                curr_attention[0, num_candidate_words:] = False
                candidate_attention.append(curr_attention)

                is_coreferent = cand_id in gt_antecedent_ids
                if is_coreferent:
                    correct_antecedents.append(idx_candidate)

            # If no preceding mention is coreferent with the head, the dummy antecedent (index 0) is correct
            if len(correct_antecedents) == 0:
                correct_antecedents.append(0)

            candidate_attention = torch.cat(candidate_attention) if len(candidate_attention) > 0 else []
            all_candidate_data.append({
                "head_id": head_id,
                "head_data": torch.tensor([curr_head_data]),
                "head_attention": head_attention,
                "candidates": candidates,
                "candidate_data": torch.tensor(candidate_data),
                "candidate_attention": candidate_attention,
                "correct_antecedents": correct_antecedents
            })

        ret["preprocessed_segments"] = encoded_segments
        ret["steps"] = all_candidate_data

        return ret
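
When max_segment_size is set, the example above replaces sentence-based positions with
segment-based ones. Below is a minimal sketch of that arithmetic, using toy values that are not
part of the original class: a token's flat document position maps to (segment index, offset
inside segment) via floor division and modulo, and the segment count is a ceiling division.

    max_segment_size = 4
    doc_tokens = ["a", "b", "c", "d", "e", "f", "g", "h", "i"]  # 9 tokens

    # Ceiling division, as in _prepare_doc above
    num_segments = (len(doc_tokens) + max_segment_size - 1) // max_segment_size
    print(num_segments)  # 3

    for position_in_document in [0, 3, 4, 8]:
        idx_segment = position_in_document // max_segment_size
        idx_inside_segment = position_in_document % max_segment_size
        print(position_in_document, "->", (idx_segment, idx_inside_segment))
    # 0 -> (0, 0), 3 -> (0, 3), 4 -> (1, 0), 8 -> (2, 0)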