Example #1
 def get_pair_mentions_features(self, m1, m2):
     ''' Features for a pair of mentions (same speaker, speaker mentioned, string match) '''
     features_ = {"00_SameSpeaker": 1 if self.consider_speakers and m1.speaker == m2.speaker else 0,
                  "01_AntMatchMentionSpeaker": 1 if self.consider_speakers and m2.speaker_match_mention(m1) else 0,
                  "02_MentionMatchSpeaker": 1 if self.consider_speakers and m1.speaker_match_mention(m2) else 0,
                  "03_HeadsAgree": 1 if m1.heads_agree(m2) else 0,
                  "04_ExactStringMatch": 1 if m1.exact_match(m2) else 0,
                  "05_RelaxedStringMatch": 1 if m1.relaxed_match(m2) else 0,
                  "06_SentenceDistance": m2.utterances_sent - m1.utterances_sent,
                  "07_MentionDistance": m2.index - m1.index - 1,
                  "08_Overlapping": 1 if (m1.utterances_sent == m2.utterances_sent and m1.end > m2.start) else 0,
                  "09_M1Features": m1.features_,
                  "10_M2Features": m2.features_,
                  "11_DocGenre": self.genre_}
     pairwise_features = [np.array([features_["00_SameSpeaker"],
                                    features_["01_AntMatchMentionSpeaker"],
                                    features_["02_MentionMatchSpeaker"],
                                    features_["03_HeadsAgree"],
                                    features_["04_ExactStringMatch"],
                                    features_["05_RelaxedStringMatch"]]),
                          encode_distance(features_["06_SentenceDistance"]),
                          encode_distance(features_["07_MentionDistance"]),
                          np.array(features_["08_Overlapping"], ndmin=1),
                          m1.features,
                          m2.features,
                          self.genre]
     return (features_, np.concatenate(pairwise_features, axis=0))
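All of these examples lean on an encode_distance helper. Below is a minimal sketch consistent with the 11-slot distance blocks used throughout (features[4:15], pairs_features[:, 6:17]); the exact bucket boundaries are an assumption modeled on neuralcoref's feature layout, not a verbatim copy:

    import numpy as np

    def encode_distance(d):
        '''Bucketed one-hot of a distance plus a capped, normalized scalar (11 dims).

        Slots 0-4: exact distances 0..4; slots 5-9: buckets 5-7, 8-15, 16-31,
        32-63, 64+; slot 10: min(d, 64) / 64. Vectorizes over 1-D input,
        returning one row per distance.
        '''
        d = np.asarray(d)
        if d.ndim > 0:
            return np.stack([encode_distance(x) for x in d])
        vect = np.zeros(11)
        if d < 5:
            vect[int(d)] = 1
        elif d < 8:
            vect[5] = 1
        elif d < 16:
            vect[6] = 1
        elif d < 32:
            vect[7] = 1
        elif d < 64:
            vect[8] = 1
        else:
            vect[9] = 1
        vect[10] = min(float(d), 64.0) / 64.0
        return vect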
Example #2
 def set_mentions_features(self):
     '''
     Compute features for the extracted mentions
     '''
     doc_embedding = self.embed_extractor.get_document_embedding(self.utterances) if self.embed_extractor is not None else None
     for mention in self.mentions:
         one_hot_type = np.zeros((4,))
         one_hot_type[mention.mention_type] = 1
         features_ = {"01_MentionType": mention.mention_type,
                      "02_MentionLength": len(mention)-1,
                      "03_MentionNormLocation": (mention.index)/len(self.mentions),
                      "04_IsMentionNested": 1 if any((m is not mention
                                                       and m.utterances_sent == mention.utterances_sent
                                                       and m.start <= mention.start
                                                       and mention.end <= m.end)
                                                      for m in self.mentions) else 0}
         features = np.concatenate([one_hot_type,
                                    encode_distance(features_["02_MentionLength"]),
                                    np.array(features_["03_MentionNormLocation"], ndmin=1, copy=False),
                                    np.array(features_["04_IsMentionNested"], ndmin=1, copy=False)
                                   ], axis=0)
         (spans_embeddings_, words_embeddings_,
          spans_embeddings, words_embeddings) = self.embed_extractor.get_mention_embeddings(mention, doc_embedding)
         mention.features_ = features_
         mention.features = features
         mention.spans_embeddings = spans_embeddings
         mention.spans_embeddings_ = spans_embeddings_
         mention.words_embeddings = words_embeddings
         mention.words_embeddings_ = words_embeddings_
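The vector assembled above decomposes as 4 (type one-hot) + 11 (encoded length) + 1 (normalized location) + 1 (nested flag) = 17 floats, which matches the SIZE_FS - SIZE_GENRE antecedent slice written in the later examples. A toy check of the nesting predicate (the Toy stand-in is hypothetical, exposing only the fields the test reads):

    from collections import namedtuple

    # Hypothetical stand-in with just the fields the nesting test reads.
    Toy = namedtuple("Toy", "start end utterances_sent")
    mentions = [Toy(0, 5, 0), Toy(2, 4, 0), Toy(6, 8, 1)]

    def is_nested(mention, mentions):
        # Same predicate as "04_IsMentionNested": another mention in the same
        # sentence whose span contains this one's.
        return 1 if any(m is not mention
                        and m.utterances_sent == mention.utterances_sent
                        and m.start <= mention.start
                        and mention.end <= m.end
                        for m in mentions) else 0

    print([is_nested(m, mentions) for m in mentions])  # -> [0, 1, 0]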
Example #3
    def get_pair_mentions_features(self, m1, m2, doc_id, fileCsv, fileId,
                                   fileLabel, fileFeature, setID):
        ''' Features for a pair of mentions (same speaker, speaker mentioned, string match) '''
        features_ = {"00_SameSpeaker": 1 if self.consider_speakers and m1.speaker == m2.speaker else 0,
                     "01_AntMatchMentionSpeaker": 1 if self.consider_speakers and m2.speaker_match_mention(m1) else 0,
                     "02_MentionMatchSpeaker": 1 if self.consider_speakers and m1.speaker_match_mention(m2) else 0,
                     "03_HeadsAgree": 1 if m1.heads_agree(m2) else 0,
                     "04_ExactStringMatch": 1 if m1.exact_match(m2) else 0,
                     "05_RelaxedStringMatch": 1 if m1.relaxed_match(m2) else 0,
                     "06_SentenceDistance": m2.utterances_sent - m1.utterances_sent,
                     "07_MentionDistance": m2.index - m1.index - 1,
                     "08_Overlapping": 1 if (m1.utterances_sent == m2.utterances_sent and m1.end > m2.start) else 0,
                     "09_M1Features": m1.features_,
                     "10_M2Features": m2.features_,
                     "11_DocGenre": self.genre_}
        pairwise_features = [np.array([features_["00_SameSpeaker"],
                                       features_["01_AntMatchMentionSpeaker"],
                                       features_["02_MentionMatchSpeaker"],
                                       features_["03_HeadsAgree"],
                                       features_["04_ExactStringMatch"],
                                       features_["05_RelaxedStringMatch"]]),
                             encode_distance(features_["06_SentenceDistance"]),
                             encode_distance(features_["07_MentionDistance"]),
                             np.array(features_["08_Overlapping"], ndmin=1),
                             m1.features,
                             m2.features,
                             self.genre]
        folderIndexMention = "/home/hung/git/neuralcoref/unite/index/"
        strFileM1 = "".join([folderIndexMention, str(doc_id), "_", str(m1.index), ".txt"])
        strFileM2 = "".join([folderIndexMention, str(doc_id), "_", str(m2.index), ".txt"])

        if not os.path.isfile(strFileM1):
            # str(m1) is already text; calling .encode('utf8') here would mix
            # bytes into the join and raise a TypeError under Python 3.
            strWriteM1 = u"".join([
                str(doc_id), "\n",
                str(m1.index), "\n",
                str(m1.gold_label), "\n",
                str(m1.start), "\n",
                str(m1.end), "\n",
                str(m1), "\n"
            ])
            with open(strFileM1, "w") as text_file:
                text_file.write(strWriteM1)

        if not os.path.isfile(strFileM2):
            strWriteM2 = u"".join([
                str(doc_id), "\n",
                str(m2.index), "\n",
                str(m2.gold_label), "\n",
                str(m2.start), "\n",
                str(m2.end), "\n",
                str(m2), "\n"
            ])
            with open(strFileM2, "w") as text_file:
                text_file.write(strWriteM2)

        finalResult = np.concatenate(pairwise_features, axis=0)

        strId = "".join([str(doc_id), "_", str(m1.index), "_",
                         str(m2.index), ".txt"]).replace("\n", " ").strip()
        if strId not in setID:
            setID.add(strId)
            strPairLabel = "1" if m1.gold_label == m2.gold_label else "0"

            # One row per pair: doc id, pair id, feature values, then the label.
            strFeat = ",".join([str(doc_id), strId]
                               + [str(v) for v in finalResult]
                               + [strPairLabel])

            fileId.write(strId + "\n")
            fileLabel.write(strPairLabel + "\n")
            fileFeature.write(strFeat + "\n")
            fileCsv.write(strFeat + "\n")

        return (features_, finalResult)
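A note on the manual row building above: joining values with bare commas works here because the features are numeric, but the standard csv module is shorter and handles quoting for you. A sketch of the same row written through csv.writer (write_pair_row is a hypothetical helper; field order as in the example):

    import csv

    def write_pair_row(file_csv, doc_id, str_id, final_result, pair_label):
        # One row per mention pair: doc id, pair id, feature values, label.
        writer = csv.writer(file_csv)
        writer.writerow([doc_id, str_id]
                        + [float(v) for v in final_result]
                        + [pair_label])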
Example #4
    def __getitem__(self, mention_idx, debug=False):
        """
        Return:
            Definitions:
                P is the number of antecedents per mention (number of pairs for the mention)
                S = 250 is the size of the span vector (averaged word embeddings)
                W = 8 is the number of words in a mention (tuned embeddings)
                Fp = 70 is the number of features for a pair of mentions
                Fs = 24 is the number of features of a single mention

            if there are some pairs:
                inputs = (spans, words, features, ant_spans, ant_words, ana_spans, ana_words, pairs_features)
                targets = (labels, costs, true_ants, false_ants)
            else:
                inputs = (spans, words, features)
                targets = (labels, costs, true_ants)

            inputs: Tuple of
                spans => (S,)   250
                words => (W,)   8
                features => (Fs,)   24
                + if there are potential antecedents (P > 0):
                    ant_spans => (P, S) or nothing if no pairs  (P,250)
                    ant_words => (P, W) or nothing if no pairs  (P,8)
                    ana_spans => (P, S) or nothing if no pairs
                    ana_words => (P, W) or nothing if no pairs
                    pair_features => (P, Fp) or nothing if no pairs

            targets: Tuple of
                labels => (P+1,)
                costs => (P+1,)
                true_ant => (P+1,)
                + if there are potential antecedents (P > 0):
                    false_ant => (P+1,)

        """
        features_raw, label, pairs_length, pairs_start_index, spans, words = \
            self.mentions[mention_idx]
        pairs_start_index = np.asscalar(pairs_start_index)
        pairs_length = np.asscalar(pairs_length)

        # Build features array (float) from raw features (int)
        # Fs = 24 features per mention; the compressed raw vector
        # (SIZE_FS_COMPRESSED entries, as stored in the numpy arrays) expands
        # into blocks of 4, 11, 1, 1 and the genre one-hot.
        assert features_raw.shape[0] == SIZE_FS_COMPRESSED
        features = np.zeros((SIZE_FS, ))
        features[features_raw[0]] = 1
        # bucketed one-hot of the mention length (features_raw[1])
        features[4:15] = encode_distance(features_raw[1])
        features[15] = features_raw[2].astype(float) / features_raw[3].astype(float)
        features[16] = features_raw[4]
        features[features_raw[5] + 17] = 1

        if pairs_length == 0:
            spans = torch.from_numpy(spans).float()
            words = torch.from_numpy(words)
            features = torch.from_numpy(features).float()
            inputs = (spans, words, features)
            if self.no_targets:
                return inputs
            true_ant = torch.zeros(1).long()  # zeros = indices of true ant
            costs = torch.from_numpy((1 - label) * self.costs['FN']).float()
            label = torch.from_numpy(label).float()
            targets = (label, costs, true_ant)
            if debug:
                print("inputs shapes: ", [a.size() for a in inputs])
                print("targets shapes: ", [a.size() for a in targets])
            return inputs, targets

        # Build ant_spans, ant_words, ana_spans, ana_words, pairs_features
        start = pairs_start_index
        end = pairs_start_index + pairs_length
        pairs = self.pairs[start:end]
        assert len(pairs) == pairs_length
        assert len(pairs[0]) == 3  # pair[i] = (pairs_ant_index, pairs_features, pairs_labels)
        pairs_ant_index, pairs_features_raw, pairs_labels = list(zip(*pairs))

        pairs_features_raw = np.stack(pairs_features_raw)
        pairs_labels = np.squeeze(np.stack(pairs_labels), axis=1)

        # Build pair features array (float) from raw features (int)
        assert pairs_features_raw[0, :].shape[0] == SIZE_FP_COMPRESSED
        pairs_features = np.zeros((len(pairs_ant_index), SIZE_FP))
        pairs_features[:, 0:6] = pairs_features_raw[:, 0:6]
        pairs_features[:, 6:17] = encode_distance(pairs_features_raw[:, 6])
        pairs_features[:, 17:28] = encode_distance(pairs_features_raw[:, 7])
        pairs_features[:, 28] = pairs_features_raw[:, 8]
        # prepare antecedent features
        ant_features_raw = np.concatenate([
            self.mentions[np.asscalar(idx)][0][np.newaxis, :]
            for idx in pairs_ant_index
        ])
        ant_features = np.zeros((pairs_length, SIZE_FS - SIZE_GENRE))
        # per-row one-hot of the antecedent mention type (fancy indexing on the
        # column alone would set the union of columns for every row)
        ant_features[np.arange(pairs_length), ant_features_raw[:, 0]] = 1
        ant_features[:, 4:15] = encode_distance(ant_features_raw[:, 1])
        ant_features[:, 15] = ant_features_raw[:, 2].astype(float) / ant_features_raw[:, 3].astype(float)
        ant_features[:, 16] = ant_features_raw[:, 4]
        pairs_features[:, 29:46] = ant_features
        # Here we keep the genre
        ana_features = np.tile(features, (pairs_length, 1))
        pairs_features[:, 46:] = ana_features

        ant_spans = np.concatenate([
            self.mentions[np.asscalar(idx)][4][np.newaxis, :]
            for idx in pairs_ant_index
        ])
        ant_words = np.concatenate([
            self.mentions[np.asscalar(idx)][5][np.newaxis, :]
            for idx in pairs_ant_index
        ])
        # np.tile(A, reps) repeats A; reps can be a tuple/list
        ana_spans = np.tile(spans, (pairs_length, 1))
        ana_words = np.tile(words, (pairs_length, 1))
        # convert the ndarrays to torch tensors
        ant_spans = torch.from_numpy(ant_spans).float()
        ant_words = torch.from_numpy(ant_words)
        ana_spans = torch.from_numpy(ana_spans).float()
        ana_words = torch.from_numpy(ana_words)
        pairs_features = torch.from_numpy(pairs_features).float()

        labels_stack = np.concatenate((pairs_labels, label), axis=0)
        assert labels_stack.shape == (pairs_length + 1, )
        labels = torch.from_numpy(labels_stack).float()

        spans = torch.from_numpy(spans).float()
        words = torch.from_numpy(words)
        features = torch.from_numpy(features).float()

        inputs = (spans, words, features, ant_spans, ant_words, ana_spans,
                  ana_words, pairs_features)

        if self.no_targets:
            return inputs

        if label == 0:
            costs = np.concatenate(
                (self.costs['WL'] * (1 - pairs_labels),
                 [self.costs['FN']]))  # Inverse labels: 1=>0, 0=>1
        else:
            costs = np.concatenate(
                (self.costs['FL'] * np.ones_like(pairs_labels), [0]))
        assert costs.shape == (pairs_length + 1, )
        costs = torch.from_numpy(costs).float()

        true_ants_unpad = np.flatnonzero(labels_stack)
        if len(true_ants_unpad) == 0:
            raise ValueError("Error: no True antecedent for mention")
        true_ants = np.pad(true_ants_unpad,
                           (0, len(pairs_labels) + 1 - len(true_ants_unpad)),
                           'edge')
        assert true_ants.shape == (pairs_length + 1, )
        true_ants = torch.from_numpy(true_ants).long()

        false_ants_unpad = np.flatnonzero(1 - labels_stack)
        assert len(false_ants_unpad) != 0
        false_ants = np.pad(false_ants_unpad,
                            (0, len(pairs_labels) + 1 - len(false_ants_unpad)),
                            'edge')
        assert false_ants.shape == (pairs_length + 1, )
        false_ants = torch.from_numpy(false_ants).long()

        targets = (labels, costs, true_ants, false_ants)
        if debug:
            print("Mention", mention_idx)
            print("inputs shapes: ", [a.size() for a in inputs])
            print("targets shapes: ", [a.size() for a in targets])
        return inputs, targets
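One portability caveat for this example (and the following ones): np.asscalar was deprecated in NumPy 1.16 and removed in 1.23, so on a current NumPy these __getitem__ calls fail. A small shim if you want to run the code unchanged; otherwise replace each np.asscalar(a) with a.item():

    import numpy as np

    if not hasattr(np, "asscalar"):
        # np.asscalar(a) was equivalent to a.item(); restore it for legacy code.
        np.asscalar = lambda a: a.item()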
Example #5
    def __getitem__(self, mention_idx, debug=False):
        # add lru cache
        """
        Return:
            Definitions:
                P is the number of antecedents per mention (number of pairs for the mention)
                S = 250 is the size of the span vector (averaged word embeddings)
                W = 8 is the number of words in a mention (tuned embeddings)
                Fp = 70 is the number of features for a pair of mentions
                Fs = 24 is the number of features of a single mention

            if there are some pairs:
                inputs = (spans, words, features, ant_spans, ant_words, ana_spans, ana_words, pairs_features)
                targets = (labels, costs, true_ants, false_ants)
            else:
                inputs = (spans, words, features)
                targets = (labels, costs, true_ants)

            inputs: Tuple of
                spans => (S,)
                words => (W,)
                features => (Fs,)
                + if there are potential antecedents (P > 0):
                    ant_spans => (P, S) or nothing if no pairs
                    ant_words => (P, W) or nothing if no pairs
                    ana_spans => (P, S) or nothing if no pairs
                    ana_words => (P, W) or nothing if no pairs
                    pair_features => (P, Fp) or nothing if no pairs

            targets: Tuple of
                labels => (P+1,)
                costs => (P+1,)
                true_ant => (P+1,)
                + if there are potential antecedents (P > 0):
                    false_ant => (P+1,)

        """
        #print("PRINTING ALL ARR AT MENTIONS IDX")
        #for key, arr in sorted(self.datas.items()) :
        #    print("<----------------->")
        #    print(key)
        #    #print((arr))
        #    print(type(arr))
        #    print(arr[mention_idx])
        #    #print(arr[mention_idx].shape)
        #    print("<---------------->")
        #nangool = input("ALL ARR PRINTED")
        #mentions_tuples = [arr[mention_idx,:] for key, arr in sorted(self.datas.items()) if key.startswith("mentions")]
        #print("+++++++++++++++++++++++++++++++++++++++++++")
        #for arr in mentions_idx_tuples :
        #    print(arr.shape)
        #print("+++++++++++++++++++++++++++++++++++++++++++")
        #print(">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>")
        #print(mentions_idx_tuples)
        #print("<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<")
        #mentions_tuples = list(zip(*(arr[mention_idx,:] for key, arr in sorted(self.datas.items()) if key.startswith(u"mentions"))))
        #print("`````````````````````````````````````````````````````````")
        #for tup in mentions_tuples :
        #    print(tup)
        #    print(tup.shape)
        #    print(tup.shape[0])
        #print("``````````````````````````````````````````````````````````")
        #mentions_tuples = list(zip(*mentions_idx_tuples))
        #print("//////////////////////////////////////////////////////")
        #print(mentions_tuples)
        #print(mentions_tuples[0])
        #print("///////////////////////////////////////////////////////")
        #assert [arr.shape[0] for arr in mentions_tuples[0]] == [6, 1, 1, 1, 250, 8] # Cf order of FEATURES_NAMES in conllparser.py
        #assert [arr.shape[0] for arr in mentions_tuples] == [6, 1, 1, 1, 250, 8] # Cf order of FEATURES_NAMES in conllparser.py
        #features_raw, label, pairs_length, pairs_start_index, spans, words = mentions_tuples#self.mentions[mention_idx]
        features_raw, label, pairs_length, pairs_start_index, spans, words = \
            self.mentions[mention_idx]
        pairs_start_index = np.asscalar(pairs_start_index)
        pairs_length = np.asscalar(pairs_length)

        # Build features array (float) from raw features (int)
        assert features_raw.shape[0] == SIZE_FS_COMPRESSED
        features = np.zeros((SIZE_FS, ))
        features[features_raw[0]] = 1
        features[4:15] = encode_distance(features_raw[1])
        features[15] = features_raw[2].astype(float) / features_raw[3].astype(
            float)
        features[16] = features_raw[4]
        features[features_raw[5] + 17] = 1

        if pairs_length == 0:
            spans = torch.from_numpy(spans).float()
            words = torch.from_numpy(words)
            features = torch.from_numpy(features).float()
            inputs = (spans, words, features)
            if self.no_targets:
                return inputs
            true_ant = torch.zeros(1).long()  # zeros = indices of true ant
            costs = torch.from_numpy((1 - label) * self.costs['FN']).float()
            label = torch.from_numpy(label).float()
            targets = (label, costs, true_ant)
            if debug:
                print("inputs shapes: ", [a.size() for a in inputs])
                print("targets shapes: ", [a.size() for a in targets])
            return inputs, targets

        start = pairs_start_index
        end = pairs_start_index + pairs_length
        pairs = self.pairs[start:end]
        assert len(pairs) == pairs_length
        assert len(pairs[0]) == 3  # pair[i] = (pairs_ant_index, pairs_features, pairs_labels)
        pairs_ant_index, pairs_features_raw, pairs_labels = list(zip(*pairs))

        pairs_features_raw = np.stack(pairs_features_raw)
        pairs_labels = np.squeeze(np.stack(pairs_labels), axis=1)

        # Build pair features array (float) from raw features (int)
        assert pairs_features_raw[0, :].shape[0] == SIZE_FP_COMPRESSED
        pairs_features = np.zeros((len(pairs_ant_index), SIZE_FP))
        pairs_features[:, 0:6] = pairs_features_raw[:, 0:6]
        pairs_features[:, 6:17] = encode_distance(pairs_features_raw[:, 6])
        pairs_features[:, 17:28] = encode_distance(pairs_features_raw[:, 7])
        pairs_features[:, 28] = pairs_features_raw[:, 8]
        # prepare antecedent features
        ant_features_raw = np.concatenate([
            self.mentions[np.asscalar(idx)][0][np.newaxis, :]
            for idx in pairs_ant_index
        ])
        ant_features = np.zeros((pairs_length, SIZE_FS - SIZE_GENRE))
        # per-row one-hot of the antecedent mention type
        ant_features[np.arange(pairs_length), ant_features_raw[:, 0]] = 1
        ant_features[:, 4:15] = encode_distance(ant_features_raw[:, 1])
        ant_features[:, 15] = ant_features_raw[:, 2].astype(float) / ant_features_raw[:, 3].astype(float)
        ant_features[:, 16] = ant_features_raw[:, 4]
        pairs_features[:, 29:46] = ant_features
        # Here we keep the genre
        ana_features = np.tile(features, (pairs_length, 1))
        pairs_features[:, 46:] = ana_features

        ant_spans = np.concatenate([
            self.mentions[np.asscalar(idx)][4][np.newaxis, :]
            for idx in pairs_ant_index
        ])
        ant_words = np.concatenate([
            self.mentions[np.asscalar(idx)][5][np.newaxis, :]
            for idx in pairs_ant_index
        ])
        ana_spans = np.tile(spans, (pairs_length, 1))
        ana_words = np.tile(words, (pairs_length, 1))
        ant_spans = torch.from_numpy(ant_spans).float()
        ant_words = torch.from_numpy(ant_words)
        ana_spans = torch.from_numpy(ana_spans).float()
        ana_words = torch.from_numpy(ana_words)
        pairs_features = torch.from_numpy(pairs_features).float()

        labels_stack = np.concatenate((pairs_labels, label), axis=0)
        assert labels_stack.shape == (pairs_length + 1, )
        labels = torch.from_numpy(labels_stack).float()

        spans = torch.from_numpy(spans).float()
        words = torch.from_numpy(words)
        features = torch.from_numpy(features).float()

        inputs = (spans, words, features, ant_spans, ant_words, ana_spans,
                  ana_words, pairs_features)
        del spans, words, features, ant_spans, ant_words, ana_spans, ana_words, pairs_features

        if self.no_targets:
            return inputs

        if label == 0:
            costs = np.concatenate(
                (self.costs['WL'] * (1 - pairs_labels),
                 [self.costs['FN']]))  # Inverse labels: 1=>0, 0=>1
        else:
            costs = np.concatenate(
                (self.costs['FL'] * np.ones_like(pairs_labels), [0]))
        assert costs.shape == (pairs_length + 1, )
        costs = torch.from_numpy(costs).float()

        true_ants_unpad = np.flatnonzero(labels_stack)
        if len(true_ants_unpad) == 0:
            raise ValueError("Error: no True antecedent for mention")
        true_ants = np.pad(true_ants_unpad,
                           (0, len(pairs_labels) + 1 - len(true_ants_unpad)),
                           'edge')
        assert true_ants.shape == (pairs_length + 1, )
        true_ants = torch.from_numpy(true_ants).long()

        false_ants_unpad = np.flatnonzero(1 - labels_stack)
        assert len(false_ants_unpad) != 0
        false_ants = np.pad(false_ants_unpad,
                            (0, len(pairs_labels) + 1 - len(false_ants_unpad)),
                            'edge')
        assert false_ants.shape == (pairs_length + 1, )
        false_ants = torch.from_numpy(false_ants).long()

        targets = (labels, costs, true_ants, false_ants)
        del labels, costs, true_ants, false_ants
        if debug:
            print("Mention", mention_idx)
            print("inputs shapes: ", [a.size() for a in inputs])
            print("targets shapes: ", [a.size() for a in targets])
        return inputs, targets
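On the "# add lru cache" note at the top of this example: since __getitem__ is deterministic in mention_idx, a bounded cache is a one-decorator change. A sketch, assuming callers never mutate the returned tensors (CoreferenceDataset is a hypothetical name for the class defining __getitem__ above; note lru_cache also keeps self alive while entries remain cached):

    from functools import lru_cache

    class CachedDataset(CoreferenceDataset):
        @lru_cache(maxsize=2048)
        def __getitem__(self, mention_idx, debug=False):
            # Safe only if callers treat the cached tensors as read-only.
            return super().__getitem__(mention_idx, debug)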
Example #6
    def run_coref_on_mentions(self, mentions):
        '''
        Run the coreference model on a list of mentions
        '''
        best_ant = {}
        best_score = {}
        n_ant = 0

        mention_idx_list = []
        mentions_spans = []
        mentions_words = []
        mentions_features = []
        pairs_ant_idx = []
        pairs_features = []
        pairs_labels = []
        mentions_labels = []
        mentions_pairs_start = []
        mentions_pairs_length = []
        mentions_location = []

        mentions_stories = []
        n_mentions = 0
        total_pairs = 0

        # TODO: create two loops, one for single mentions and one for pairs

        for mention_idx, antecedents_idx in list(
                self.data.get_candidate_pairs(mentions, self.max_dist,
                                              self.max_dist_match)):
            n_mentions += 1
            doc_id = 1
            mention = self.data[mention_idx]

            # let's create the story
            story_embeds = []
            raw_utterances = self.get_utterances()
            for utt_index in range(mention.utterance_index):
                utt_dealt = raw_utterances[utt_index]
                for token in utt_dealt:
                    # look up each token's index in the embedding vocabulary
                    token_word_idx = word_idx_finder(self.embed_extractor,
                                                     token.text)
                    story_embeds.append(token_word_idx)
            final_utt_dealt = raw_utterances[mention.utterance_index]
            for token_index in range(mention.start):
                token_word_idx = word_idx_finder(
                    self.embed_extractor, final_utt_dealt[token_index].text)
                story_embeds.append(token_word_idx)

            mentions_stories.append(story_embeds)
            mention_idx_list.append(mention_idx)
            mentions_spans.append(mention.spans_embeddings)
            w_idx = mention_words_idx(self.embed_extractor, mention)

            if w_idx is None:
                print("error in", self.name, self.part,
                      mention.utterance_index)
            mentions_words.append(w_idx)
            mentions_features.append(
                self.get_single_mention_features_conll(mention))
            mentions_location.append([
                mention.start, mention.end, mention.utterance_index,
                mention_idx, doc_id
            ])
            ants = [self.data.mentions[ant_idx] for ant_idx in antecedents_idx]

            no_antecedent = not any(
                ant.gold_label == mention.gold_label
                for ant in ants) or mention.gold_label is None
            if antecedents_idx:
                pairs_ant_idx += [idx for idx in antecedents_idx]
                pairs_features += [
                    self.get_pair_mentions_features_conll(ant, mention)
                    for ant in ants
                ]
                ant_labels = [0 for ant in ants] if no_antecedent else [
                    1 if ant.gold_label == mention.gold_label else 0
                    for ant in ants
                ]
                pairs_labels += ant_labels
            mentions_labels.append(1 if no_antecedent else 0)
            mentions_pairs_start.append(total_pairs)
            total_pairs += len(ants)
            mentions_pairs_length.append(len(ants))

        out_dict = {
            FEATURES_NAMES[0]: mentions_features,
            FEATURES_NAMES[1]: mentions_labels,
            FEATURES_NAMES[2]: mentions_pairs_length,
            FEATURES_NAMES[3]: mentions_pairs_start,
            FEATURES_NAMES[4]: mentions_spans,
            FEATURES_NAMES[5]: mentions_words,
            FEATURES_NAMES[6]: pairs_ant_idx if pairs_ant_idx else list(),
            FEATURES_NAMES[7]: pairs_features if pairs_features else list(),
            FEATURES_NAMES[8]: pairs_labels if pairs_labels else list(),
            FEATURES_NAMES[9]: mentions_stories
        }
        gathering_dict = dict((feat, None) for feat in FEATURES_NAMES)
        n_mentions_list = []
        pairs_ant_index = 0
        pairs_start_index = 0
        for n, p, arrays_dict in tqdm([(n_mentions, total_pairs, out_dict)]):
            for f in FEATURES_NAMES:
                if gathering_dict[f] is None:
                    gathering_dict[f] = arrays_dict[f]
                else:
                    if f == FEATURES_NAMES[6]:
                        array = [a + pairs_ant_index for a in arrays_dict[f]]
                    elif f == FEATURES_NAMES[3]:
                        array = [a + pairs_start_index for a in arrays_dict[f]]
                    else:
                        array = arrays_dict[f]
                    gathering_dict[f] += array
            pairs_ant_index += n
            pairs_start_index += p
            n_mentions_list.append(n)

        mention_feature_dict = dict()
        pairs_feature_dict = dict()
        train_phase = True

        for feature in FEATURES_NAMES[:10]:
            print("Building numpy array for", feature, "length",
                  len(gathering_dict[feature]))
            if feature != "mentions_spans":
                #array = np.array(gathering_dict[feature])
                # check if we are dealing with length of memories
                if feature == "mentions_stories" or feature == "pairs_stories":
                    gathering_array = []
                    max_story_len = 200
                    for story in gathering_dict[feature]:
                        #print(len(story[0]))
                        #print(len(story[1]))
                        #random_pause = input()
                        if len(story) > 200:
                            final_story = story[-200:]
                        else:
                            number_to_append = max(0,
                                                   max_story_len - len(story))
                            #number_to_append = min(number_to_append,50)
                            final_story = story + number_to_append * [0]
                            #print(final_story)
                            #print(len(final_story))
                            #random_pause = input()
                        gathering_array.append(final_story)
                    array = np.array(gathering_array)
                    print(array.shape)
                else:
                    array = np.array(gathering_dict[feature])

                if array.ndim == 1:
                    print("expand_dims for feature, ", feature)
                    array = np.expand_dims(array, axis=1)
            else:
                array = np.stack(gathering_dict[feature])
            # check_numpy_array(feature, array, n_mentions_list)
            print("Saving numpy", feature, "size", array.shape)
            if feature.startswith("mentions"):
                mention_feature_dict[feature] = array
            if feature.startswith("pairs"):
                pairs_feature_dict[feature] = array

        # zip it with pairs dict
        self.mentions = list(
            zip(*(arr for key, arr in sorted(mention_feature_dict.items()))))
        self.pairs = list(
            zip(*(arr for key, arr in sorted(pairs_feature_dict.items()))))
        #print("LEN OF PAIRS IS,",len(self.pairs))

        #jsk = input("PRINTING THE PAIRS")
        #print(self.pairs)
        #sdghr = input("ALL PAIRS PRINTED")
        #print("MENTION PAIRS LENGTH IS,",mention_feature_dict['mentions_pairs_length'])
        #victoria = input()

        for i in range(len(mention_feature_dict[FEATURES_NAMES[0]])):
            mention_idx = mention_idx_list[i]
            features_raw = mention_feature_dict['mentions_features'][i, :]
            #print("FEATUERES_RAW_PRINTED_is,",features_raw)
            label = mention_feature_dict['mentions_labels'][i, :]
            pairs_length = mention_feature_dict['mentions_pairs_length'][i, :]
            pairs_start_index = mention_feature_dict[
                'mentions_pairs_start_index'][i]
            mentions_stories = mention_feature_dict['mentions_stories'][i]

            spans = mention_feature_dict['mentions_spans'][i, :]
            words = mention_feature_dict['mentions_words'][i, :]

            pairs_start_index = np.asscalar(pairs_start_index)
            pairs_length = np.asscalar(pairs_length)

            # Build features array (float) from raw features (int)
            assert features_raw.shape[0] == SIZE_FS_COMPRESSED
            features = np.zeros((SIZE_FS, ))
            features[features_raw[0]] = 1
            features[4:15] = encode_distance(features_raw[1])
            features[15] = features_raw[2].astype(
                float) / features_raw[3].astype(float)
            features[16] = features_raw[4]
            features[features_raw[5] + 17] = 1

            #print("====================================<>============================================")
            #print("TYPE OF SPANS,",type(spans))
            #print("TYPE OF WORDS,",type(words))
            #print("TYPE OF FEATURES,",type(features))
            #print("====================================<>============================================")

            spans = spans[np.newaxis, :]
            print("PRINTING SHAPE OF WORDS")
            print(words.shape)
            words = words[np.newaxis, :]
            features = features[np.newaxis, :]
            mentions_stories = mentions_stories[np.newaxis, :]

            spans = torch.from_numpy(spans).float()
            words = torch.from_numpy(words)
            features = torch.from_numpy(features).float()
            mentions_stories = torch.from_numpy(mentions_stories)

            # inputs for the single mentions

            #print("SINGLE SCORES COMPUTING")
            single_inputs = (spans, words, features, mentions_stories)
            score = self.coref_model.get_multiple_single_score(
                single_inputs).tolist()[0][0]

            #print("PRINTING SINGLE SCORE")
            #print(score)
            #sgbet = input("SINGLE SCORE PRINTED")
            self.mentions_single_scores[mention_idx] = score
            best_score[mention_idx] = score - 50 * (self.greedyness - 0.5)
            #print("SINGLE SCORES COMPUTED")

            if pairs_length == 0:
                continue

            start = pairs_start_index
            end = pairs_start_index + pairs_length
            pairs = self.pairs[start:end]
            #print("START IS,",start)
            #print("END IS,",end)
            #print("PAIRS LENGTH,",pairs_length)
            #print("LEN OF PAIRS IS,",len(pairs))
            assert len(pairs) == pairs_length
            assert len(pairs[0]) == 3  # pair[i] = (pairs_ant_index, pairs_features, pairs_labels)
            pairs_ant_index, pairs_features_raw, pairs_labels = list(zip(*pairs))

            pairs_features_raw = np.stack(pairs_features_raw)
            pairs_labels = np.squeeze(np.stack(pairs_labels), axis=1)

            # Build pair features array (float) from raw features (int)
            assert pairs_features_raw[0, :].shape[0] == SIZE_FP_COMPRESSED
            pairs_features = np.zeros((len(pairs_ant_index), SIZE_FP))
            pairs_features[:, 0:6] = pairs_features_raw[:, 0:6]
            pairs_features[:, 6:17] = encode_distance(pairs_features_raw[:, 6])
            pairs_features[:, 17:28] = encode_distance(pairs_features_raw[:, 7])
            pairs_features[:, 28] = pairs_features_raw[:, 8]
            # prepare antecedent features
            ant_features_raw = np.concatenate([
                self.mentions[np.asscalar(idx)][0][np.newaxis, :]
                for idx in pairs_ant_index
            ])
            ant_features = np.zeros((pairs_length, SIZE_FS - SIZE_GENRE))
            # per-row one-hot of the antecedent mention type
            ant_features[np.arange(pairs_length), ant_features_raw[:, 0]] = 1
            ant_features[:, 4:15] = encode_distance(ant_features_raw[:, 1])
            ant_features[:, 15] = ant_features_raw[:, 2].astype(float) / ant_features_raw[:, 3].astype(float)
            ant_features[:, 16] = ant_features_raw[:, 4]
            pairs_features[:, 29:46] = ant_features
            # Here we keep the genre
            ana_features = np.tile(features, (pairs_length, 1))
            pairs_features[:, 46:] = ana_features

            ant_spans = np.concatenate([
                self.mentions[np.asscalar(idx)][4][np.newaxis, :]
                for idx in pairs_ant_index
            ])
            ant_words = np.concatenate([
                self.mentions[np.asscalar(idx)][6][np.newaxis, :]
                for idx in pairs_ant_index
            ])
            ana_spans = np.tile(spans, (pairs_length, 1))
            ana_words = np.tile(words, (pairs_length, 1))

            ant_spans = ant_spans[np.newaxis, :]
            ant_words = ant_words[np.newaxis, :]
            ana_spans = ana_spans[np.newaxis, :]
            ana_words = ana_words[np.newaxis, :]
            pairs_features = pairs_features[np.newaxis, :]

            ant_spans = torch.from_numpy(ant_spans).float()
            ant_words = torch.from_numpy(ant_words)
            ana_spans = torch.from_numpy(ana_spans).float()
            ana_words = torch.from_numpy(ana_words)
            pairs_features = torch.from_numpy(pairs_features).float()

            labels_stack = np.concatenate((pairs_labels, label), axis=0)
            assert labels_stack.shape == (pairs_length + 1, )
            labels = torch.from_numpy(labels_stack).float()

            # inputs for the pairs of mentions
            pairs_inputs = (spans, words, features, ant_spans, ant_words,
                            ana_spans, ana_words, pairs_features,
                            mentions_stories)

            #print("PAIRS INPUT CREATED")

            score = self.coref_model.get_multiple_pair_score(
                pairs_inputs).tolist()[0][:-1]
            #print("SCORES GOT IS ")
            #print(score)
            #hutys = input("SCORES PRINTED")
            for ant_idx, s in zip(pairs_ant_idx, score):
                self.mentions_pairs_scores[mention_idx][ant_idx] = s
                if s > best_score[mention_idx]:
                    best_score[mention_idx] = s
                    best_ant[mention_idx] = ant_idx
            if mention_idx in best_ant:
                n_ant += 1
                self._merge_coreference_clusters(best_ant[mention_idx],
                                                 mention_idx)


        # find a way to access the mention_idx
        return (n_ant, best_ant)
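The selection loop at the end implements neuralcoref's greedy decoding: each mention starts with its single-mention score shifted by the greedyness term, and it links to the highest-scoring antecedent only if some pair score beats that threshold. A toy run of the same rule (all scores made up):

    greedyness = 0.5
    single_score = 1.2                        # "no antecedent" score for one mention
    pair_scores = {0: 0.8, 3: 2.5, 7: 1.9}    # antecedent index -> pair score

    best_score = single_score - 50 * (greedyness - 0.5)
    best_ant = None
    for ant_idx, s in pair_scores.items():
        if s > best_score:
            best_score, best_ant = s, ant_idx

    print(best_ant, best_score)  # -> 3 2.5: the mention links to antecedent 3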