Example #1
0
class ELMO(nn.Module):
    """ELMo encoder that learns a softmax-weighted mix of ELMo's three layers.

    The mixed embeddings go through optional dropout and an optional
    ReLU-activated linear projection.
    """

    def __init__(self, projection_dim=None, dropout=0.5):
        """
        :param projection_dim: when truthy, output size of a ``Linear(1024, projection_dim)``
            followed by ReLU; when falsy, no projection is applied.
        :param dropout: dropout probability for the mixed embeddings; a falsy
            value disables dropout.
        """
        super(ELMO, self).__init__()
        self.dropout = dropout
        # NOTE: GPU 0 is hard-wired whenever CUDA is available, otherwise CPU.
        self.elmo = ElmoEmbedder(
            cuda_device=0 if torch.cuda.is_available() else -1)
        # one mixing logit per ELMo layer; equal weights at start
        self.layer_weights = nn.Parameter(torch.ones(3))
        self.projection = (nn.Linear(1024, projection_dim)
                           if projection_dim else None)

    def forward(self, batch_sentences):
        """Embed a batch of tokenized sentences.

        :param batch_sentences: sentences (token lists) handed to
            ``ElmoEmbedder.batch_to_embeddings``.
        :return: mixed embeddings of shape ``(batch, timesteps, hidden)``, where
            ``hidden`` is the projection size when a projection is configured.
        """
        layer_stack = self.elmo.batch_to_embeddings(batch_sentences)[0]

        # Normalise the per-layer logits and broadcast them across the
        # layer axis (dim 1) before summing the layers away.
        mixing = F.softmax(self.layer_weights, dim=-1).view(1, 3, 1, 1)
        mixed = (mixing.expand_as(layer_stack) * layer_stack).sum(1)

        if self.dropout:
            mixed = F.dropout(mixed, self.dropout, self.training)

        if self.projection is None:
            return mixed

        rows = mixed.view(-1, mixed.size(-1))
        out = F.relu(self.projection(rows))
        return out.view(mixed.size(0), mixed.size(1), -1)
Example #2
0
class WordEmbeddings:
    """
        ELMo word embeddings, see https://allennlp.org/elmo

        Wraps an allennlp ``ElmoEmbedder`` and hands its outputs back on the CPU.
    """

    def __init__(self,
                 options_file='https://exawizardsallenlp.blob.core.windows.net/data/options.json',
                 weight_file='/content/drive/My Drive/SIFRank_ja/auxiliary_data/weights.hdf5',
                 cuda_device=0):
        """
        :param options_file: ELMo options JSON (path or URL).
        :param weight_file: ELMo HDF5 weights file.
        :param cuda_device: GPU index for ELMo, or -1 for the CPU.
        """
        self.cuda_device = cuda_device
        self.elmo = ElmoEmbedder(options_file, weight_file,
                                 cuda_device=self.cuda_device)

    def get_tokenized_words_embeddings(self, sents_tokened):
        """Embed a batch of already tokenized sentences.

        See also EmbeddingDistributor.

        :param sents_tokened: list of sentences, each a list of token strings.
        :return: tuple ``(embeddings, mask)`` of tensors as returned by
            ``ElmoEmbedder.batch_to_embeddings``, moved to the CPU when a GPU
            is in use.
        """
        embedding, mask = self.elmo.batch_to_embeddings(sents_tokened)
        used_gpu = self.cuda_device > -1
        if not used_gpu:
            return embedding, mask
        return embedding.cpu(), mask.cpu()
Example #3
0
class WordEmbeddings:
    """
        ELMo word embeddings, see https://allennlp.org/elmo

        Wraps an allennlp ``ElmoEmbedder`` that embeds one tokenized sentence
        at a time and hands the result back on the CPU.
    """

    def __init__(self,
                 options_file="../auxiliary_data/elmo_2x4096_512_2048cnn_2xhighway_options.json",
                 weight_file="../auxiliary_data/elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5",
                 cuda_device=0):
        """
        :param options_file: ELMo options JSON file.
        :param weight_file: ELMo HDF5 weights file.
        :param cuda_device: GPU index for ELMo, or -1 for the CPU.
        """
        self.cuda_device = cuda_device
        self.elmo = ElmoEmbedder(options_file, weight_file,
                                 cuda_device=self.cuda_device)

    def get_tokenized_words_embeddings(self, sents_tokened):
        """Embed a single tokenized sentence as a batch of one.

        See also EmbeddingDistributor.

        :param sents_tokened: one sentence given as a list of token strings.
        :return: tuple ``(embeddings, mask)`` of tensors as returned by
            ``ElmoEmbedder.batch_to_embeddings`` (batch dimension of size 1),
            moved to the CPU when a GPU is in use.
        """
        single_batch = [sents_tokened]
        embedding, mask = self.elmo.batch_to_embeddings(single_batch)
        if self.cuda_device <= -1:
            return embedding, mask
        return embedding.cpu(), mask.cpu()
def get_elmo(sentences, batch_size, tokens):
    '''
        Returns numpy array of reduced elmo representations. Old function

        Requires a CUDA device: ELMo and the token tensors are placed on GPU 0.

        :param sentences: tokenized sentences to embed (each a list of tokens).
        :param batch_size: number of sentences sent to ELMo at a time.
        :param tokens: entries aligned one-to-one with ``sentences``
            (presumably one integer per sentence); each batch becomes a
            LongTensor with an extra trailing dimension and is passed as
            ``lengths`` to ``choose_tokens`` -- confirm its meaning there.
        :return: ndarray: the per-batch outputs of ``choose_tokens``
            concatenated along axis 0, covering every sentence.
        :raises ValueError: from ``np.concatenate`` when ``sentences`` is empty.
    '''
    x = []
    options_file = "/data/models/pytorch/elmo/options/elmo_2x4096_512_2048cnn_2xhighway_5.5B_options.json"
    weight_file = "/data/models/pytorch/elmo/weights/elmo_2x4096_512_2048cnn_2xhighway_5.5B_weights.hdf5"
    # Convert x data to embeddings
    embeddings = ElmoEmbedder(options_file, weight_file, cuda_device=0)
    sentence_batches = [
        sentences[j:j + batch_size]
        for j in range(0, len(sentences), batch_size)
    ]
    token_batches = [
        tokens[j:j + batch_size] for j in range(0, len(tokens), batch_size)
    ]
    for toks, words in zip(token_batches, sentence_batches):
        toks = torch.tensor(toks,
                            dtype=torch.long,
                            device=torch.device('cuda:0')).unsqueeze(1)
        raw_embeds, _ = embeddings.batch_to_embeddings(words)
        # Concatenate the three ELMo layers along the feature axis (dim 2).
        raw_embeds = torch.cat((raw_embeds[:, 0, :, :], raw_embeds[:, 1, :, :],
                                raw_embeds[:, 2, :, :]),
                               dim=2)
        root_embeds = choose_tokens(batch=raw_embeds, lengths=toks)
        x.append(root_embeds.detach().cpu().numpy())
    # This return used to be indented inside the loop, so only the first
    # batch was ever embedded and the rest of the input was silently dropped.
    return np.concatenate(x, axis=0)
Example #5
0
def main():
    """Precompute ELMo features for the delexicalised MultiWOZ data and pickle them.

    Steps: load the ELMo model, warm it up on up to 500 training dialogs,
    featurize the train/test/dev splits, write ``<split>_elmo_full.pkl`` for
    each split under ``base_path``, then build, featurize and write the
    slot-to-values structure as ``s2v_elmo_full.pkl``.
    """
    base_path = '../data/multiwoz/delex/'
    splits = ['train', 'test', 'dev']

    # Init ELMO model
    elmo_emb = ElmoEmbedder(weight_file='res/elmo/elmo_2x1024_128_2048cnn_1xhighway_weights.hdf5',
                            options_file='res/elmo/elmo_2x1024_128_2048cnn_1xhighway_options.json')

    # "Warm up" ELMo embedder (https://github.com/allenai/allennlp/blob/master/tutorials/how_to/elmo.md)
    warmup_data, _, _, _ = util.load_dataset(splits=['train'], base_path=base_path)
    warmup_data = [dg.to_dict() for dg in warmup_data['train'].iter_dialogs()][:500]

    print('Warming up ELMo embedder on train dialogs')
    for d in tqdm(warmup_data):
        # Outputs are discarded; the call only warms the embedder up.
        elmo_emb.batch_to_embeddings([t['transcript'] for t in d['turns']])

    # Load dialogs
    print('Creating elmo embeddings for annotated data')
    utterance_featurizer = ElmoFeaturizer(elmo_emb, 'utterance')
    sys_act_featurizer = ElmoFeaturizer(elmo_emb, 'act')

    elmo = Elmo(utterance_featurizer, sys_act_featurizer)

    dia_data, ontology = util.generate_dataset_elmo(elmo, splits=splits, base_path=base_path)

    # Save dataset; the context manager guarantees each pickle is flushed and
    # the file handle closed (previously the handles were left open).
    for split in splits:
        with open('{}_elmo_full.pkl'.format(base_path + split), 'wb') as f:
            pickle.dump(dia_data[split], f)
        # Workaround for s2v featurization
        dia_data[split] = [dg.to_dict() for dg in dia_data[split].iter_dialogs()]

    ## Create s2v embedding
    s2v = ontology.values
    if DELEX:
        s2v = util.delexicalize(s2v)
    s2v = util.fix_s2v(s2v, dia_data, splits=splits)

    slot_featurizer = ElmoFeaturizer(elmo_emb, "slot")
    value_featurizer = ElmoFeaturizer(elmo_emb, "value")

    s2v = util.featurize_s2v(s2v, slot_featurizer, value_featurizer, elmo=True, elmo_pool=False)

    # Save s2v
    with open('{}_elmo_full.pkl'.format(base_path + 's2v'), 'wb') as f:
        pickle.dump(s2v, f)
Example #6
0
class ElmoData(DataGenerator):
    """Data generator producing layer-averaged ELMo embeddings for questions.

    ``load_file`` reads a CSV with ``question_text`` and ``target`` columns and
    Moses-tokenizes the lower-cased questions; ``prepare_batch`` embeds a
    tokenized batch with ELMo and orders it by decreasing sentence length.
    """

    def __init__(self, device, cache_dir, state):
        """
        :param device: ``torch.device`` the batch tensors are placed on. For a
            CUDA device ELMo runs on that same GPU (GPU 0 for a bare
            ``'cuda'``); otherwise ELMo runs on the CPU.
        :param cache_dir: directory containing the ELMo options/weights files.
        :param state: seed for the ``RandomState`` kept in ``self.state``.
        """
        # tokenize sents
        self.tokenizer = MosesTokenizer()
        self.preprocess = lambda sent: self.tokenizer.tokenize(sent.lower(),
                                                               escape=False)
        if device.type == 'cuda':
            # Honour the requested GPU index. Previously every CUDA device was
            # mapped to GPU 0, so e.g. cuda:1 produced embeddings on a
            # different device than the labels.
            elmo_cuda_device = 0 if device.index is None else device.index
        else:
            elmo_cuda_device = -1
        self.elmo = ElmoEmbedder(
            options_file=os.path.join(
                cache_dir, 'elmo_2x4096_512_2048cnn_2xhighway_options.json'),
            weight_file=os.path.join(
                cache_dir, 'elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5'),
            cuda_device=elmo_cuda_device)
        self.device = device
        self.state = RandomState(state)
        self.name = 'ELMo'
        # every token id is treated as known
        self.is_unk = lambda tok_id: False

    def load_file(self, filename):
        """Read a CSV file into tokenized questions and their labels.

        :param filename: CSV path with ``question_text`` and ``target`` columns.
        :return: ``(questions, labels)``: token lists and a list of targets.
        """
        data = pd.read_csv(filename)
        questions = [self.preprocess(sent) for sent in data.question_text]
        labels = data.target.tolist()
        return questions, labels

    def prepare_batch(self, batch):
        """Embed one batch and order it by decreasing sentence length.

        :param batch: pair ``(sentences, labels)``; sentences must already be
            tokenized (lists of tokens).
        :return: ``(embeddings, labels, lengths)`` ordered by decreasing
            length. ``embeddings`` holds the per-token mean over ELMo layers
            and, like ``labels``, is on ``self.device``; ``lengths`` stays on
            the CPU.
        """
        batch, labels = batch

        # assumes tokenized batch
        lengths = torch.tensor([len(sent) for sent in batch])
        lengths, perm_index = lengths.sort(0, descending=True)

        # allennlp returns activations shaped (batch, layers, timesteps, dim);
        # average over the layer axis for every token in a single reduction.
        full_embeddings, _ = self.elmo.batch_to_embeddings(batch)
        padded_embeddings = full_embeddings.float().mean(1).to(self.device)

        # resort sentences and labels
        labels = torch.tensor(labels, device=self.device)
        labels = labels[perm_index]
        padded_embeddings = padded_embeddings[perm_index]

        return padded_embeddings, labels, lengths
Example #7
0
class Model(BaseModel):
    def __init__(self, vocab, config):
        """Build the embedding, encoder and scoring layers of the coreference model.

        :param vocab: vocabulary; ``vocab.word2idx`` sizes the word-embedding
            table and ``vocab`` is passed to ``EmbedLoader.load_with_vocab``.
        :param config: configuration object. Attributes read here include
            ``cuda`` (GPU index as a string), ``glove``/``turian`` (pretrained
            embedding files), ``use_elmo``, ``use_CNN``, ``use_metadata``,
            ``char_emb_size``, ``span_width``, ``feature_size`` and the
            ``atten``/``mention``/``sa`` hidden sizes.
        """
        word2id = vocab.word2idx
        super(Model, self).__init__()
        vocab_num = len(word2id)
        self.word2id = word2id
        self.config = config
        self.char_dict = preprocess.get_char_dict(
            'data/char_vocab.english.txt')
        self.genres = {
            g: i
            for i, g in enumerate(["bc", "bn", "mz", "nw", "pt", "tc", "wb"])
        }
        self.device = torch.device("cuda:" + config.cuda)

        self.emb = nn.Embedding(vocab_num, 350)

        # Pretrained vectors: the two tables are concatenated per word and each
        # row is L2-normalised (1e-12 avoids dividing all-zero rows by zero).
        emb1 = EmbedLoader().load_with_vocab(config.glove,
                                             vocab,
                                             normalize=False)
        emb2 = EmbedLoader().load_with_vocab(config.turian,
                                             vocab,
                                             normalize=False)
        pre_emb = np.concatenate((emb1, emb2), axis=1)
        pre_emb /= (np.linalg.norm(pre_emb, axis=1, keepdims=True) + 1e-12)

        # np.concatenate always returns an array, so the former
        # ``if pre_emb is not None`` guard was dead. The table is frozen.
        self.emb.weight = nn.Parameter(torch.from_numpy(pre_emb).float())
        for param in self.emb.parameters():
            param.requires_grad = False
        self.emb_dropout = nn.Dropout(inplace=True)

        if config.use_elmo:
            self.elmo = ElmoEmbedder(
                options_file=
                'data/elmo/elmo_2x4096_512_2048cnn_2xhighway_options.json',
                weight_file=
                'data/elmo/elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5',
                cuda_device=int(config.cuda))
            print("elmo load over.")
            # Weights for mixing the three ELMo layers in forward(). They must
            # be an nn.Parameter so model.parameters() (and thus the optimizer)
            # sees them. The old ``torch.randn(3, requires_grad=True).to(device)``
            # was never registered on the module (and on GPU it was a non-leaf
            # copy), so these weights were never trained. The initial draw still
            # comes from the CPU RNG, as before.
            # NOTE: this adds an ``elmo_args`` entry to state_dict(); load
            # checkpoints saved before this change with strict=False.
            self.elmo_args = nn.Parameter(torch.randn(3).to(self.device))

        self.char_emb = nn.Embedding(len(self.char_dict), config.char_emb_size)
        # character CNNs with kernel widths 3, 4 and 5 (50 filters each)
        self.conv1 = nn.Conv1d(config.char_emb_size, 50, 3)
        self.conv2 = nn.Conv1d(config.char_emb_size, 50, 4)
        self.conv3 = nn.Conv1d(config.char_emb_size, 50, 5)

        self.feature_emb = nn.Embedding(config.span_width, config.feature_size)
        self.feature_emb_dropout = nn.Dropout(p=0.2, inplace=True)

        self.mention_distance_emb = nn.Embedding(10, config.feature_size)
        self.distance_drop = nn.Dropout(p=0.2, inplace=True)

        self.genre_emb = nn.Embedding(7, config.feature_size)
        self.speaker_emb = nn.Embedding(2, config.feature_size)

        self.bilstm = VarLSTM(input_size=350 + 150 * config.use_CNN +
                              config.use_elmo * 1024,
                              hidden_size=200,
                              bidirectional=True,
                              batch_first=True,
                              hidden_dropout=0.2)
        # Fixed (not trained) orthogonal initial LSTM states.
        self.h0 = nn.init.orthogonal_(torch.empty(2, 1, 200)).to(self.device)
        self.c0 = nn.init.orthogonal_(torch.empty(2, 1, 200)).to(self.device)
        self.bilstm_drop = nn.Dropout(p=0.2, inplace=True)

        self.atten = ffnn(input_size=400,
                          hidden_size=config.atten_hidden_size,
                          output_size=1)
        self.mention_score = ffnn(input_size=1320,
                                  hidden_size=config.mention_hidden_size,
                                  output_size=1)
        self.sa = ffnn(input_size=3980 + 40 * config.use_metadata,
                       hidden_size=config.sa_hidden_size,
                       output_size=1)
        self.mention_start_np = None
        self.mention_end_np = None

    def _reorder_lstm(self, word_emb, seq_lens):
        """Run ``self.bilstm`` over a padded batch given in arbitrary length order.

        ``pack_padded_sequence`` needs lengths in decreasing order, so the batch
        is sorted by length, packed, run with the fixed initial states
        ``self.h0``/``self.c0`` tiled over the batch, unpacked, and finally
        restored to the caller's row order.

        :param word_emb: batch-first padded input.
        :param seq_lens: true length of every row of ``word_emb``.
        :return: batch-first padded LSTM output in the original row order.
        """
        n_rows = len(seq_lens)
        by_length = sorted(range(n_rows),
                           key=lambda row: seq_lens[row],
                           reverse=True)
        sorted_lens = [seq_lens[row] for row in by_length]
        packed_in = nn.utils.rnn.pack_padded_sequence(
            self.reorder_sequence(word_emb, by_length, batch_first=True),
            sorted_lens,
            batch_first=True)

        initial_state = (self.h0.repeat(1, n_rows, 1),
                         self.c0.repeat(1, n_rows, 1))
        packed_out, _ = self.bilstm(packed_in, initial_state)
        padded_out, _ = nn.utils.rnn.pad_packed_sequence(packed_out,
                                                         batch_first=True)

        # Inverse permutation: where original row i ended up after sorting.
        restore = [0] * n_rows
        for position, row in enumerate(by_length):
            restore[row] = position
        return self.reorder_sequence(padded_out, restore, batch_first=True)

    def reorder_sequence(self, sequence_emb, order, batch_first=True):
        """Permute ``sequence_emb`` along its batch dimension.

        :param sequence_emb: ``[B, T, D]`` if ``batch_first`` else ``[T, B, D]``.
        :param order: list of batch indices giving the new row order; its
            length must equal the batch size.
        :param batch_first: whether the batch dimension is 0 (True) or 1 (False).
        :return: the permuted tensor (a new tensor produced by ``index_select``).
        """
        batch_dim = 0 if batch_first else 1
        assert len(order) == sequence_emb.size()[batch_dim]

        # Build the index directly as int64 on the tensor's device. The old
        # ``LongTensor(order).to(sequence_emb).long()`` round-tripped the
        # indices through the embedding's floating dtype, which is not exact
        # for large indices under half/bfloat16 precision.
        index = torch.as_tensor(order, dtype=torch.long,
                                device=sequence_emb.device)
        return sequence_emb.index_select(index=index, dim=batch_dim)

    def flat_lstm(self, lstm_out, seq_lens):
        """Drop padding and concatenate the valid time steps of every sequence.

        :param lstm_out: padded batch-first tensor ``[batch, max_len, dim]``.
        :param seq_lens: true length of each sequence.
        :return: tensor ``[sum(seq_lens), dim]`` with the sequences stacked in
            batch order.
        """
        batch_size, max_len, dim = lstm_out.shape[:3]
        keep = []
        for row, length in enumerate(seq_lens):
            offset = row * max_len  # first flat position of this row
            keep.extend(range(offset, offset + length))
        index = torch.LongTensor(keep).to(self.device)
        return torch.index_select(lstm_out.view(batch_size * max_len, dim), 0,
                                  index)

    def potential_mention_index(self, word_index, max_sent_len):
        """Enumerate candidate spans ``[i, j]`` of at most ``max_sent_len`` words.

        ``word_index[k]`` is the sentence id of word ``k``; a span is kept only
        when both of its ends carry the same sentence id.

        Example with ``max_sent_len=2``: ``[0, 0, 0, 1, 1]`` ->
        ``[[0, 0], [0, 1], [1, 1], [1, 2], [2, 2], [3, 3], [3, 4], [4, 4]]``.
        """
        n_words = len(word_index)
        return [[start, end]
                for start in range(n_words)
                for end in range(start, min(start + max_sent_len, n_words))
                if word_index[start] == word_index[end]]

    def get_mention_start_end(self, seq_lens):
        """Return the start and end word positions of every candidate mention.

        Sentence lengths are first expanded into a per-word sentence id, e.g.
        ``[3, 2] -> [0, 0, 0, 1, 1]``; then every within-sentence span of at
        most ``config.span_width`` words is enumerated.

        :param seq_lens: length of every sentence in the document.
        :return: ``(mention_start, mention_end)``, two int ndarrays of equal length.
        """
        word_index = [sent_id
                      for sent_id, length in enumerate(seq_lens)
                      for _ in range(length)]
        spans = np.array(
            self.potential_mention_index(word_index, self.config.span_width),
            dtype=int)
        return spans[:, 0], spans[:, 1]

    def get_mention_emb(self, flatten_lstm, mention_start, mention_end):
        """Look up the encoder rows at each mention's first and last word.

        :param flatten_lstm: tensor ``[word_num, emb]``.
        :param mention_start: int ndarray of start word positions.
        :param mention_end: int ndarray of end word positions.
        :return: ``(emb_start, emb_end)``, each ``[mention_num, emb]``.
        """
        def rows_at(positions):
            index = torch.from_numpy(positions).to(self.device)
            return flatten_lstm.index_select(dim=0, index=index)

        return rows_at(mention_start), rows_at(mention_end)

    def get_mask(self, mention_start, mention_end):
        """Build the log-space attention mask over the words of each mention.

        :param mention_start: int ndarray of start positions.
        :param mention_end: int ndarray of end positions.
        :return: float64 tensor ``[mention_num, span_width]`` holding 0 in the
            first ``end - start + 1`` columns of each row and ``-inf`` elsewhere.
        """
        widths = mention_end - mention_start + 1  # span width of each mention
        mask = np.zeros((mention_start.shape[0], self.config.span_width))
        for row, width in enumerate(widths):
            mask[row, np.arange(width)] = 1
        # log maps 1 -> 0 (attend) and 0 -> -inf (masked out)
        return torch.log(torch.from_numpy(mask))

    def get_mention_index(self, mention_start, max_mention):
        """List the word positions ``start .. start + max_mention - 1`` of each mention.

        Positions beyond the last mention start (``mention_start[-1]``) are
        clamped to it so every entry stays a valid word index.

        :param mention_start: 1-D int ndarray of start positions.
        :param max_mention: number of positions per mention.
        :return: LongTensor ``[num_mention, max_mention]`` on ``self.device``.
        """
        # TODO: may need to change later.
        assert len(mention_start.shape) == 1
        starts = torch.from_numpy(mention_start)
        num_mention = starts.shape[0]
        offsets = torch.arange(0, max_mention).long()
        mention_index = (starts.unsqueeze(1).expand(num_mention, max_mention) +
                         offsets.unsqueeze(0).expand(num_mention, max_mention))
        assert mention_index.shape[0] == num_mention
        assert mention_index.shape[1] == max_mention
        ceiling = torch.LongTensor([mention_start[-1]]).expand(num_mention,
                                                                max_mention)
        return torch.min(mention_index, ceiling).to(self.device)

    def sort_mention(self, mention_start, mention_end, candidate_mention_emb,
                     candidate_mention_score, seq_lens):
        """Keep the top-scoring candidate mentions and return them by start position.

        The number kept is ``int(config.mention_ratio * sum(seq_lens))``.

        :param mention_start: int ndarray of candidate start positions.
        :param mention_end: int ndarray of candidate end positions.
        :param candidate_mention_emb: tensor ``[num_candidates, emb]``.
        :param candidate_mention_score: candidate scores (presumably 1-D,
            ``[num_candidates]`` -- confirm against the caller).
        :param seq_lens: sentence lengths of the document.
        :return: ``(starts, ends, scores, embeddings)`` of the kept mentions,
            ordered by start position.
        """
        keep = int(self.config.mention_ratio * sum(seq_lens))
        # rank candidates by score, best first, and keep the first ``keep``
        ranked_scores, ranked_ids = torch.sort(candidate_mention_score,
                                               descending=True)
        top_ids = ranked_ids[0:keep]
        top_scores = ranked_scores[0:keep]

        def take(positions):
            return torch.from_numpy(positions).to(self.device).index_select(
                dim=0, index=top_ids)

        top_starts = take(mention_start)
        top_ends = take(mention_end)
        top_emb = candidate_mention_emb.index_select(index=top_ids, dim=0)
        for selected in (top_scores, top_starts, top_ends, top_emb):
            assert selected.shape[0] == keep
        # TODO: crossing (overlapping) spans are not filtered out.

        # Re-sort the survivors by start so they appear in document order.
        # TODO: only the start is used as the key; the end is ignored.
        top_starts, order = torch.sort(top_starts)
        return (top_starts,
                top_ends.index_select(dim=0, index=order),
                top_scores.index_select(dim=0, index=order),
                top_emb.index_select(dim=0, index=order))

    def get_antecedents(self, mention_starts, max_antecedents):
        """For each mention, list the indices of up to ``max_antecedents`` preceding mentions.

        :param mention_starts: per-mention values; only ``shape[0]`` is used.
        :param max_antecedents: upper bound on candidates per mention (also
            capped by the number of mentions).
        :return: ``(antecedents, antecedents_len)``: an int ndarray
            ``[num_mention, width]`` whose row ``i`` starts with the indices
            ``max(0, i - width) .. i - 1`` and is zero-padded, and a list with
            the number of real candidates per row.
        """
        num_mention = mention_starts.shape[0]
        width = min(max_antecedents, num_mention)
        antecedents = np.zeros((num_mention, width), dtype=int)
        antecedents_len = []
        for mention in range(num_mention):
            previous = np.arange(max(0, mention - width), mention)
            antecedents[mention, :len(previous)] = previous  # rest stays 0 (padding)
            antecedents_len.append(len(previous))
        assert antecedents.shape[1] == width
        return antecedents, antecedents_len

    def get_antecedents_score(self, span_represent, mention_score, antecedents,
                              antecedents_len, mention_speakers_ids, genre):
        """Score every candidate antecedent of every mention.

        Each pair score is ``self.sa(pair features) + log(valid mask) +
        mention_score[i] + mention_score[antecedent]``; a constant 0 column is
        prepended for the dummy "no antecedent" choice.

        :param span_represent: mention representations fed to ``get_pair_emb``.
        :param mention_score: tensor ``[num_mention]`` of mention scores.
        :param antecedents: int ndarray ``[num_mention, max_ant]``.
        :param antecedents_len: number of real (non-padding) antecedents per row.
        :param mention_speakers_ids: speaker id per mention.
        :param genre: genre index of the document.
        :return: tensor ``[num_mention, max_ant + 1]``.
        """
        num_mention = mention_score.shape[0]
        max_antecedent = antecedents.shape[1]

        pair_emb = self.get_pair_emb(span_represent, antecedents,
                                     mention_speakers_ids, genre)
        pair_scores = self.sa(pair_emb)
        # 0 for real antecedents, -inf for padding slots
        log_valid = torch.log(
            self.sequence_mask(antecedents_len, max_antecedent)).to(self.device)
        assert log_valid.shape[1] <= max_antecedent
        assert pair_scores.shape[0] == num_mention
        pair_scores = pair_scores + log_valid

        ante_idx = torch.from_numpy(antecedents).to(self.device)
        antecedent_mention_scores = torch.gather(
            mention_score.unsqueeze(0).expand(num_mention, num_mention),
            dim=1,
            index=ante_idx)
        pair_scores += mention_score.unsqueeze(1) + antecedent_mention_scores

        dummy_column = torch.zeros([mention_score.shape[0], 1]).to(self.device)
        return torch.cat([dummy_column, pair_scores], 1)

    ##############################
    def distance_bin(self, mention_distance):
        """Bucket mention distances into ids for ``self.mention_distance_emb``.

        Buckets: 0 for distances <= 0; 1, 2, 3, 4 for distances 1-4; then
        5-7 -> 5, 8-15 -> 6, 16-31 -> 7, 32-63 -> 8 and 64 or more -> 9.

        :param mention_distance: integer tensor of distances on ``self.device``.
        :return: LongTensor of bucket ids with the same shape, on ``self.device``.
        """
        # The last bucket used to be [64, 300], so any distance above 300 fell
        # back to bucket 0 and shared an embedding with distance 0. It is now
        # open-ended; results for distances <= 300 are unchanged.
        bounded = [(1, 1), (2, 2), (3, 3), (4, 4), (5, 7), (8, 15), (16, 31),
                   (32, 63)]
        bins = torch.zeros(mention_distance.size(), dtype=torch.long,
                           device=self.device)
        for bucket, (low, high) in enumerate(bounded, start=1):
            bins[(mention_distance >= low) & (mention_distance <= high)] = bucket
        bins[mention_distance >= 64] = len(bounded) + 1
        return bins

    def get_distance_emb(self, antecedents_tensor):
        """Embed the bucketed distance between each mention and its candidates.

        :param antecedents_tensor: tensor ``[num_mention, max_ant]`` of
            antecedent indices on ``self.device``.
        :return: dropout-applied distance embeddings ``[num_mention, max_ant, feature]``.
        """
        num_mention = antecedents_tensor.shape[0]
        max_ant = antecedents_tensor.shape[1]
        assert max_ant <= self.config.max_antecedents

        # row i is filled with the mention's own index i
        mention_ids = torch.arange(0, num_mention).unsqueeze(1).expand(
            num_mention, max_ant).to(self.device)
        distance = mention_ids - antecedents_tensor
        embedded = self.mention_distance_emb(self.distance_bin(distance))
        return self.distance_drop(embedded)

    def get_pair_emb(self, span_emb, antecedents, mention_speakers_ids, genre):
        """Build the feature vector of every (mention, candidate antecedent) pair.

        :param span_emb: tensor ``[num_span, emb]`` of mention representations.
        :param antecedents: int ndarray ``[num_span, max_ant]`` of indices into
            ``span_emb``.
        :param mention_speakers_ids: tensor ``[num_span]`` of speaker ids (used
            only when ``config.use_metadata``).
        :param genre: genre index (used only when ``config.use_metadata``).
        :return: tensor ``[num_span, max_ant, features]``: the mention, the
            antecedent and their element-wise product, followed by same-speaker
            and genre embeddings (``use_metadata``) and the distance embedding
            (``use_distance``).
        """
        num_span = span_emb.shape[0]
        emb_dim = span_emb.shape[1]
        max_ant = antecedents.shape[1]
        assert num_span == antecedents.shape[0]
        ante_idx = torch.from_numpy(antecedents).to(self.device)

        # [num_span, max_ant, emb]: each mention repeated for its candidates
        mention_side = span_emb.unsqueeze(1).expand(num_span, max_ant, emb_dim)
        # [num_span, max_ant, emb]: the candidate antecedents' embeddings
        antecedent_side = torch.gather(
            span_emb.unsqueeze(0).expand(num_span, num_span, emb_dim),
            dim=1,
            index=ante_idx.unsqueeze(2).expand(num_span, max_ant, emb_dim))
        features = [mention_side, antecedent_side,
                    antecedent_side * mention_side]

        if self.config.use_metadata:
            antecedent_speakers = mention_speakers_ids.unsqueeze(0).expand(
                num_span, num_span).gather(dim=1, index=ante_idx)
            same_speaker = mention_speakers_ids.unsqueeze(1).expand(
                num_span, max_ant) == antecedent_speakers
            features.append(
                self.speaker_emb(same_speaker.long().to(self.device)))
            genre_ids = torch.LongTensor([genre]).expand(
                num_span, max_ant).to(self.device)
            features.append(self.genre_emb(genre_ids))

        if self.config.use_distance:
            features.append(self.get_distance_emb(ante_idx))

        return torch.cat(features, 2)

    def sequence_mask(self, len_list, max_len):
        """Build a 0/1 mask whose row ``i`` has ones in its first ``len_list[i]`` columns.

        :param len_list: length per row.
        :param max_len: number of columns.
        :return: float32 tensor ``[len(len_list), max_len]``.
        """
        mask = np.zeros((len(len_list), max_len))
        for row, length in enumerate(len_list):
            mask[row, np.arange(length)] = 1
        return torch.from_numpy(mask).float()

    def logsumexp(self, value, dim=None, keepdim=False):
        """Numerically stable implementation of the operation

        value.exp().sum(dim, keepdim).log()

        :param value: input tensor.
        :param dim: dimension to reduce; ``None`` reduces over all elements
            (``keepdim`` is then ignored, as before).
        :param keepdim: keep the reduced dimension (only when ``dim`` is given).
        :return: the reduced tensor.

        Delegates to ``torch.logsumexp``. The previous hand-written max-shift
        returned NaN for any slice whose maximum is infinite (for example a
        fully masked row of ``-inf``), because ``-inf - (-inf)`` is NaN.
        """
        if dim is None:
            return torch.logsumexp(value.reshape(-1), dim=0)
        return torch.logsumexp(value, dim=dim, keepdim=keepdim)

    def softmax_loss(self, antecedent_scores, antecedent_labels):
        """Negative marginal log-likelihood of the gold antecedents.

        :param antecedent_scores: tensor ``[num_mentions, max_ant + 1]``.
        :param antecedent_labels: boolean / 0-1 ndarray of the same shape
            marking the gold antecedents.
        :return: scalar tensor: the sum over mentions of
            ``logsumexp(all scores) - logsumexp(gold scores)``.
        """
        labels = torch.from_numpy(antecedent_labels * 1).to(self.device)
        # log(1) = 0 keeps gold scores; log(0) = -inf removes the others
        gold_scores = antecedent_scores + torch.log(labels.float())
        log_norm = self.logsumexp(antecedent_scores, 1)
        log_gold = self.logsumexp(gold_scores, 1)
        return torch.sum(log_norm - log_gold)

    def get_predicted_antecedents(self, antecedents, antecedent_scores):
        """Pick the highest-scoring antecedent of every mention.

        :param antecedents: ``[num_mentions, max_ant]`` antecedent indices;
            score column ``k`` corresponds to ``antecedents[:, k - 1]``.
        :param antecedent_scores: tensor ``[num_mentions, max_ant + 1]`` whose
            column 0 is the dummy "no antecedent" choice.
        :return: list holding, per mention, the chosen antecedent index or -1
            when the dummy column wins.
        """
        # Take the argmax in torch and only then move to the host:
        # np.argmax cannot convert a CUDA tensor. Like np.argmax, torch.argmax
        # returns the first index on ties.
        best = (torch.argmax(antecedent_scores.detach(), dim=1) - 1).cpu().tolist()
        return [-1 if column < 0 else antecedents[i, column]
                for i, column in enumerate(best)]

    def get_predicted_clusters(self, mention_starts, mention_ends,
                               predicted_antecedents):
        """Link each mention to its predicted antecedent's cluster.

        :param mention_starts: start position per mention.
        :param mention_ends: end position per mention.
        :param predicted_antecedents: per mention, the antecedent's mention
            index or a negative value for "no antecedent".
        :return: ``(clusters, mention_to_cluster)``: a list of clusters (tuples
            of ``(start, end)`` spans, in creation order) and a dict mapping
            every clustered span to its cluster tuple.
        """
        cluster_of = {}
        clusters = []
        for mention_idx, ante_idx in enumerate(predicted_antecedents):
            if ante_idx < 0:
                continue
            assert mention_idx > ante_idx
            antecedent = (int(mention_starts[ante_idx]),
                          int(mention_ends[ante_idx]))
            cluster_id = cluster_of.get(antecedent)
            if cluster_id is None:
                # first time this antecedent is seen: open a new cluster
                cluster_id = len(clusters)
                clusters.append([antecedent])
                cluster_of[antecedent] = cluster_id

            span = (int(mention_starts[mention_idx]),
                    int(mention_ends[mention_idx]))
            clusters[cluster_id].append(span)
            cluster_of[span] = cluster_id

        clusters = [tuple(members) for members in clusters]
        return clusters, {span: clusters[cid] for span, cid in cluster_of.items()}

    def evaluate_coref(self, mention_starts, mention_ends,
                       predicted_antecedents, gold_clusters, evaluator):
        """Feed this document's predicted and gold clusters to ``evaluator``.

        :param gold_clusters: clusters, each an iterable of ``(start, end)`` spans.
        :param evaluator: object exposing
            ``update(predicted, gold, mention_to_predicted, mention_to_gold)``.
        :return: the predicted clusters.
        """
        gold = [tuple(tuple(span) for span in cluster)
                for cluster in gold_clusters]
        mention_to_gold = {span: cluster for cluster in gold for span in cluster}
        predicted, mention_to_predicted = self.get_predicted_clusters(
            mention_starts, mention_ends, predicted_antecedents)
        evaluator.update(predicted, gold, mention_to_predicted,
                         mention_to_gold)
        return predicted

    def forward(self, words1, words2, words3, words4, chars, seq_len):
        """
        实际输入都是tensor
        :param sentences: 句子,被fastNLP转化成了numpy,
        :param doc_np: 被fastNLP转化成了Tensor
        :param speaker_ids_np: 被fastNLP转化成了Tensor
        :param genre: 被fastNLP转化成了Tensor
        :param char_index: 被fastNLP转化成了Tensor
        :param seq_len: 被fastNLP转化成了Tensor
        :return:
        """

        sentences = words3
        doc_np = words4
        speaker_ids_np = words2
        genre = words1
        char_index = chars

        # change for fastNLP
        sentences = sentences[0].tolist()
        doc_tensor = doc_np[0]
        speakers_tensor = speaker_ids_np[0]
        genre = genre[0].item()
        char_index = char_index[0]
        seq_len = seq_len[0].cpu().numpy()

        # 类型

        # doc_tensor = torch.from_numpy(doc_np).to(self.device)
        # speakers_tensor = torch.from_numpy(speaker_ids_np).to(self.device)
        mention_emb_list = []

        word_emb = self.emb(doc_tensor)
        word_emb_list = [word_emb]
        if self.config.use_CNN:
            # [batch, length, char_length, char_dim]
            char = self.char_emb(char_index)
            char_size = char.size()
            # first transform to [batch *length, char_length, char_dim]
            # then transpose to [batch * length, char_dim, char_length]
            char = char.view(char_size[0] * char_size[1], char_size[2],
                             char_size[3]).transpose(1, 2)

            # put into cnn [batch*length, char_filters, char_length]
            # then put into maxpooling [batch * length, char_filters]
            char_over_cnn, _ = self.conv1(char).max(dim=2)
            # reshape to [batch, length, char_filters]
            char_over_cnn = torch.tanh(char_over_cnn).view(
                char_size[0], char_size[1], -1)
            word_emb_list.append(char_over_cnn)

            char_over_cnn, _ = self.conv2(char).max(dim=2)
            char_over_cnn = torch.tanh(char_over_cnn).view(
                char_size[0], char_size[1], -1)
            word_emb_list.append(char_over_cnn)

            char_over_cnn, _ = self.conv3(char).max(dim=2)
            char_over_cnn = torch.tanh(char_over_cnn).view(
                char_size[0], char_size[1], -1)
            word_emb_list.append(char_over_cnn)

        # word_emb = torch.cat(word_emb_list, dim=2)

        # use elmo or not
        if self.config.use_elmo:
            # 如果确实被截断了
            if doc_tensor.shape[0] == 50 and len(sentences) > 50:
                sentences = sentences[0:50]
            elmo_embedding, elmo_mask = self.elmo.batch_to_embeddings(
                sentences)
            elmo_embedding = elmo_embedding.to(
                self.device
            )  # [sentence_num,max_sent_len,3,1024]--[sentence_num,max_sent,1024]
            elmo_embedding = elmo_embedding[:, 0, :, :] * self.elmo_args[0] + elmo_embedding[:, 1, :, :] * \
                             self.elmo_args[1] + elmo_embedding[:, 2, :, :] * self.elmo_args[2]
            word_emb_list.append(elmo_embedding)
        # print(word_emb_list[0].shape)
        # print(word_emb_list[1].shape)
        # print(word_emb_list[2].shape)
        # print(word_emb_list[3].shape)
        # print(word_emb_list[4].shape)

        word_emb = torch.cat(word_emb_list, dim=2)

        word_emb = self.emb_dropout(word_emb)
        # word_emb_elmo = self.emb_dropout(word_emb_elmo)
        lstm_out = self._reorder_lstm(word_emb, seq_len)
        flatten_lstm = self.flat_lstm(lstm_out, seq_len)  # [word_num,emb]
        flatten_lstm = self.bilstm_drop(flatten_lstm)
        # TODO 没有按照论文写
        flatten_word_emb = self.flat_lstm(word_emb, seq_len)  # [word_num,emb]

        mention_start, mention_end = self.get_mention_start_end(
            seq_len)  # [mention_num]
        self.mention_start_np = mention_start  # [mention_num] np
        self.mention_end_np = mention_end
        mention_num = mention_start.shape[0]
        emb_start, emb_end = self.get_mention_emb(
            flatten_lstm, mention_start, mention_end)  # [mention_num,emb]

        # list
        mention_emb_list.append(emb_start)
        mention_emb_list.append(emb_end)

        if self.config.use_width:
            mention_width_index = mention_end - mention_start
            mention_width_tensor = torch.from_numpy(mention_width_index).to(
                self.device)  # [mention_num]
            mention_width_emb = self.feature_emb(mention_width_tensor)
            mention_width_emb = self.feature_emb_dropout(mention_width_emb)
            mention_emb_list.append(mention_width_emb)

        if self.config.model_heads:
            mention_index = self.get_mention_index(
                mention_start,
                self.config.span_width)  # [mention_num,max_mention]
            log_mask_tensor = self.get_mask(
                mention_start, mention_end).float().to(
                    self.device)  # [mention_num,max_mention]
            alpha = self.atten(flatten_lstm).to(self.device)  # [word_num]

            # 得到attention
            mention_head_score = torch.gather(
                alpha.expand(mention_num, -1), 1, mention_index).float().to(
                    self.device)  # [mention_num,max_mention]
            mention_attention = F.softmax(mention_head_score + log_mask_tensor,
                                          dim=1)  # [mention_num,max_mention]

            # TODO flatte lstm
            word_num = flatten_lstm.shape[0]
            lstm_emb = flatten_lstm.shape[1]
            emb_num = flatten_word_emb.shape[1]

            # [num_mentions, max_mention_width, emb]
            mention_text_emb = torch.gather(
                flatten_word_emb.unsqueeze(1).expand(word_num,
                                                     self.config.span_width,
                                                     emb_num), 0,
                mention_index.unsqueeze(2).expand(mention_num,
                                                  self.config.span_width,
                                                  emb_num))
            # [mention_num,emb]
            mention_head_emb = torch.sum(mention_attention.unsqueeze(2).expand(
                mention_num, self.config.span_width, emb_num) *
                                         mention_text_emb,
                                         dim=1)
            mention_emb_list.append(mention_head_emb)

        candidate_mention_emb = torch.cat(mention_emb_list,
                                          1)  # [candidate_mention_num,emb]
        candidate_mention_score = self.mention_score(
            candidate_mention_emb)  # [candidate_mention_num]

        antecedent_scores, antecedents, mention_start_tensor, mention_end_tensor = (
            None, None, None, None)
        mention_start_tensor, mention_end_tensor, mention_score, mention_emb = \
            self.sort_mention(mention_start, mention_end, candidate_mention_emb, candidate_mention_score, seq_len)
        mention_speakers_ids = speakers_tensor.index_select(
            dim=0, index=mention_start_tensor)  # num_mention

        antecedents, antecedents_len = self.get_antecedents(
            mention_start_tensor, self.config.max_antecedents)
        antecedent_scores = self.get_antecedents_score(
            mention_emb, mention_score, antecedents, antecedents_len,
            mention_speakers_ids, genre)

        ans = {
            "candidate_mention_score": candidate_mention_score,
            "antecedent_scores": antecedent_scores,
            "antecedents": antecedents,
            "mention_start_tensor": mention_start_tensor,
            "mention_end_tensor": mention_end_tensor
        }

        return ans

    def predict(self, words1, words2, words3, words4, chars, seq_len):
        """Run the model on one document and decode its coreference clusters.

        The generic positional names are forwarded unchanged to ``self(...)``
        in this order: sentences, document tensor, speaker ids, genre,
        character indices, sequence lengths. Per the original notes, fastNLP
        hands ``words1`` over as numpy data and the rest as tensors.

        :param words1: sentences of the document.
        :param words2: document tensor.
        :param words3: speaker ids.
        :param words4: genre.
        :param chars: character indices.
        :param seq_len: sentence lengths.
        :return: dict with ``'predicted'`` (clusters from
            ``get_predicted_clusters``) and ``"mention_to_predicted"`` (its
            second return value).
        """
        outputs = self(words1, words2, words3, words4, chars, seq_len)

        # Decoding helpers operate on CPU tensors.
        antecedent_scores = outputs["antecedent_scores"].cpu()
        predicted_antecedents = self.get_predicted_antecedents(
            outputs["antecedents"], antecedent_scores)

        starts = outputs["mention_start_tensor"].cpu()
        ends = outputs["mention_end_tensor"].cpu()
        clusters, mention_map = self.get_predicted_clusters(
            starts, ends, predicted_antecedents)

        return {
            'predicted': clusters,
            "mention_to_predicted": mention_map
        }
Exemple #8
0
class NCETModel(nn.Module):
    """Entity state and location tracker over a single paragraph.

    Words are embedded (learned mix of the 3 ELMo layers, or a lookup table),
    tagged with a per-token verb indicator and encoded by ``word_lstm``.
    Per-sentence entity/verb/location summaries then feed sentence-level
    encoders: ``state_lstm`` or ``state_rnn`` followed by ``state_mlp`` score
    entity states; ``location_lstm`` followed by ``location_mlp`` scores each
    location candidate.
    """

    def __init__(self, config):
        super(NCETModel, self).__init__()
        self.use_elmo = config.use_elmo
        self.embedding_dim = config.embedding_dim
        self.hidden_dim = config.hidden_dim
        self.vocab_size = config.vocab_size
        self.label_size = config.label_size
        self.embed_dropout_rate = config.embed_dropout_rate
        # Backward-compatible alias for the previously misspelled attribute.
        self.embed_fropout_rate = self.embed_dropout_rate
        self.dropout_rate = config.dropout_rate
        self.batch_size = config.batch_size
        self.max_word_length = config.max_word_length
        self.max_sent_length = config.max_sent_length
        self.use_state_graph = config.use_state_graph

        # NOTE(review): any falsy gpu_id (including 0) runs ELMo on the CPU.
        # forward() moves ELMo output to the model device, so this works, but
        # confirm whether gpu_id == 0 was meant to select GPU 0.
        self.elmo = ElmoEmbedder(options_file=config.path_elmo_options,
                                 weight_file=config.path_elmo_weights,
                                 cuda_device=-1 if not config.gpu_id else 0)
        # Learnable mixing weights over the 3 ELMo layers.
        self.alpha = nn.Parameter(torch.randn(3, 1))
        if config.path_word_vector:
            self.embedding = nn.Embedding.from_pretrained(
                torch.FloatTensor(config.embeddings))
        else:
            self.embedding = nn.Embedding(self.vocab_size,
                                          self.embedding_dim,
                                          padding_idx=0)
        # Embedding dropout uses the dedicated embedding rate from config.
        self.embed_drop = nn.Dropout(self.embed_dropout_rate)
        # +1 input channel for the per-token verb indicator.
        self.word_lstm = LSTMEncoder(self.embedding_dim + 1, self.hidden_dim,
                                     self.batch_size, self.max_word_length)

        # entity state
        self.state_lstm = LSTMEncoder(self.hidden_dim * 4, self.hidden_dim,
                                      self.batch_size, self.max_sent_length)
        self.state_rnn = CustomRNN(cell_class=StateGraphGRUCell,
                                   input_dim=self.hidden_dim * 4,
                                   hidden_dim=self.hidden_dim,
                                   batch_first=True,
                                   bidirectional=True)

        self.state_mlp = nn.Sequential(
            nn.Dropout(self.dropout_rate),
            nn.Linear(self.hidden_dim * 2, self.hidden_dim),
            nn.LayerNorm(self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, self.label_size),
        )
        self.crf = CRF(self.label_size)

        # entity location
        self.location_lstm = LSTMEncoder(self.hidden_dim * 4, self.hidden_dim,
                                         self.batch_size, self.max_sent_length)
        self.location_mlp = nn.Sequential(
            nn.Dropout(self.dropout_rate),
            nn.Linear(self.hidden_dim * 2, self.hidden_dim),
            nn.LayerNorm(self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, 1),
        )

    # ------------------------------------------------------------------ #
    # helpers
    # ------------------------------------------------------------------ #

    def _embed_words(self, words, words_idxs, device):
        """Return token embeddings from mixed ELMo layers or the lookup table."""
        if not self.use_elmo:
            return self.embedding(words_idxs)
        # Only the frozen ELMo forward pass runs without autograd.
        with torch.no_grad():
            tokens = [[word[0] for word in words]]
            elmo_layers, _ = self.elmo.batch_to_embeddings(tokens)
        # Mix the layers OUTSIDE no_grad so that self.alpha gets gradients, and
        # move to the model device in case ELMo runs on a different one.
        elmo_layers = elmo_layers.to(device)
        # presumably [1, 3, seq, 1024] -> [1, seq, 1024, 3] @ [3, 1]
        mixed = elmo_layers.permute(0, 2, 3, 1) @ self.alpha
        return mixed.view(-1, 1024)

    @staticmethod
    def _span_mean(encoding, idxs):
        """Mean of ``encoding`` rows ``idxs[0]``..``idxs[1]`` (inclusive)."""
        start, end = idxs[0].item(), idxs[1].item()
        return torch.mean(encoding[start:end + 1], dim=0).view(-1)

    @staticmethod
    def _pad_and_stack(rows):
        """Stack rows between zero start/end rows -> ``[1, len(rows) + 2, dim]``."""
        pad = torch.zeros_like(rows[-1])
        return torch.stack([pad] + rows + [pad]).unsqueeze(dim=0)

    def _verb_encodings(self, verbs_idxs_sents, word_lstm_encoding, device):
        """Per-sentence mean verb encoding, framed by zero start/end rows.

        Sentences without verbs get a zero row of width ``hidden_dim * 2``.
        Returns ``[num_sents + 2, width]``; row ``i + 1`` is sentence ``i``.
        """
        rows = []
        for idxs in verbs_idxs_sents:
            if not idxs:
                rows.append(torch.zeros(self.hidden_dim * 2).to(device))
            else:
                idxs = torch.stack(idxs).view(1, -1).to(device)
                rows.append(
                    torch.mean(F.embedding(idxs, word_lstm_encoding),
                               dim=1).view(-1))
        # Size the padding from a real row (or hidden_dim * 2 when there are no
        # sentences) instead of from an unrelated leftover variable.
        if rows:
            pad = torch.zeros_like(rows[-1])
        else:
            pad = torch.zeros(self.hidden_dim * 2).to(device)
        return torch.stack([pad] + rows + [pad])

    def _entity_state_sequence(self, entity_idxs, word_lstm_encoding,
                               verbs_sents, device):
        """Entity span mean + same-sentence verb encoding, padded and stacked.

        Returns ``[1, num_sents + 2, hidden_dim * 4]``; sentences where the
        entity is absent contribute a zero row.
        """
        rows = []
        for i, idxs in enumerate(entity_idxs):
            if not idxs:
                rows.append(torch.zeros(self.hidden_dim * 4).to(device))
            else:
                span = self._span_mean(word_lstm_encoding, idxs)
                # verbs_sents[0] is the start pad, so sentence i is at i + 1.
                rows.append(torch.cat([span, verbs_sents[i + 1]], dim=-1))
        return self._pad_and_stack(rows)

    def _span_sequence(self, idxs_sents, word_lstm_encoding, device):
        """Per-sentence span mean (zeros when absent), padded and stacked.

        Returns ``[1, num_sents + 2, hidden_dim * 2]``.
        """
        rows = []
        for idxs in idxs_sents:
            if not idxs:
                rows.append(torch.zeros(self.hidden_dim * 2).to(device))
            else:
                rows.append(self._span_mean(word_lstm_encoding, idxs))
        return self._pad_and_stack(rows)

    # ------------------------------------------------------------------ #
    # forward
    # ------------------------------------------------------------------ #

    def forward(self, words, words_idxs, verbs, words_length, sents_length,
                verbs_idxs_sents, entity_idxs_sents, entity_to_idx,
                idx_to_entity, location_candidate_idxs_sents,
                location_candidate_to_idx, idx_to_location_candidate):
        """Score entity states and locations for one paragraph (batch of 1).

        ``*_idxs_sents`` hold one entry per sentence: empty when the
        verb/entity/candidate is absent, otherwise index tensors (entities and
        candidates: an inclusive ``(start, end)`` pair into the word sequence).

        :return: ``(entity_states_logit, entity_locations_logit)``, two dicts
            keyed by entity. Location logits concatenate, along the last axis,
            one ``location_mlp`` output per candidate, ordered by
            ``location_candidate_to_idx``. ``idx_to_entity`` and
            ``idx_to_location_candidate`` are unused.
        """
        device = self.embedding.weight.device
        words_idxs = torch.LongTensor(words_idxs).to(
            device)  # [max_word_length, 1]
        verbs = torch.FloatTensor(verbs).to(device)  # [max_word_length]

        # ---- word encoding ----
        word_embeddings = self.embed_drop(
            self._embed_words(words, words_idxs, device))
        word_verb_embeddings = torch.cat(
            [word_embeddings, verbs.unsqueeze(dim=-1)],
            dim=-1).unsqueeze(dim=0)
        word_lstm_encoding = self.word_lstm(
            word_verb_embeddings,
            [words_length]).squeeze(dim=0)  # [max_length, 2*hidden_dim]

        # ---- entity state tracking ----
        verbs_sents = self._verb_encodings(verbs_idxs_sents,
                                           word_lstm_encoding, device)
        if self.use_state_graph:
            entity_states_input = [None] * len(entity_to_idx)
            for entity, entity_idxs in entity_idxs_sents.items():
                entity_sents = self._entity_state_sequence(
                    entity_idxs, word_lstm_encoding, verbs_sents, device)
                entity_states_input[entity_to_idx[entity]] = \
                    entity_sents.unsqueeze(dim=-2)
            # [batch_size, seq_len, entity_size, hidden_dim*4]
            entity_states_input = torch.cat(entity_states_input, dim=-2)
            # [batch_size, seq_len, entity_size, hidden_dim*2]
            entity_states_rnn_encoding, _ = self.state_rnn(entity_states_input)
            states_logit = self.state_mlp(entity_states_rnn_encoding)
            entity_states_logit = {
                entity: states_logit[:, :, idx, :].squeeze(dim=-2)
                for entity, idx in entity_to_idx.items()
            }
        else:
            entity_states_logit = {}
            for entity, entity_idxs in entity_idxs_sents.items():
                entity_sents = self._entity_state_sequence(
                    entity_idxs, word_lstm_encoding, verbs_sents, device)
                states_lstm_encoding = self.state_lstm(
                    entity_sents, [sents_length])[:, :sents_length, :]
                entity_states_logit[entity] = self.state_mlp(
                    states_lstm_encoding)

        # ---- entity location tracking ----
        # Candidate encodings do not depend on the entity: build them once.
        candidate_sents = {}
        if entity_idxs_sents:
            candidate_sents = {
                candidate: self._span_sequence(candidate_idxs,
                                               word_lstm_encoding, device)
                for candidate, candidate_idxs in
                location_candidate_idxs_sents.items()
            }
        entity_locations_logit = {}
        for entity, entity_idxs in entity_idxs_sents.items():
            entity_sents = self._span_sequence(entity_idxs,
                                               word_lstm_encoding, device)
            location_logits = [None] * len(location_candidate_to_idx)
            for candidate, location_sents in candidate_sents.items():
                # [batch_size, sents_len, hidden_dim * 4]
                entity_location_sents = torch.cat(
                    [entity_sents, location_sents], dim=-1)
                location_lstm_encoding = self.location_lstm(
                    entity_location_sents, [sents_length])[:, :sents_length, :]
                location_logits[location_candidate_to_idx[candidate]] = \
                    self.location_mlp(location_lstm_encoding)
            entity_locations_logit[entity] = torch.cat(location_logits, dim=-1)

        return entity_states_logit, entity_locations_logit
Exemple #9
0
class ELMOCRFSegModel(LSTMCRFSegModel):
    """LSTM-CRF segmenter whose inputs are word embeddings plus ELMo vectors.

    ELMo runs in PyTorch (allennlp ``ElmoEmbedder``) on each batch's raw
    words. The activations are fed into the TensorFlow graph through the
    ``elmo_vectors`` placeholder and mixed there with learned softmax weights
    and a learned scale.
    """

    def __init__(self, args, word_vocab):
        super().__init__(args, word_vocab)

        # import ElmoEmbedder here so that CUDA_VISIBLE_DEVICES can take effect
        from allennlp.commands.elmo import ElmoEmbedder
        self.elmo = ElmoEmbedder(cuda_device=0 if args.gpu is not None else -1)

    def _setup_placeholders(self):
        """Declare graph inputs; ``elmo_vectors`` is [batch, 3, seq, 1024]."""
        self.placeholders = {
            'input_words': tf.placeholder(tf.int32, shape=[None, None]),
            'input_length': tf.placeholder(tf.int32, shape=[None]),
            'elmo_vectors': tf.placeholder(tf.float32,
                                           shape=[None, 3, None, 1024]),
            'seg_labels': tf.placeholder(tf.float32, shape=[None, None]),
            'dropout_keep_prob': tf.placeholder(tf.float32)
        }

    def _embed(self):
        """Concatenate frozen word embeddings with a weighted ELMo layer mix."""
        with tf.device('/cpu:0'):
            if self.word_vocab.embeddings is None:
                word_emb_init = tf.random_normal_initializer()
            else:
                word_emb_init = tf.constant_initializer(
                    self.word_vocab.embeddings)
            self.word_embeddings = tf.get_variable(
                'word_embeddings',
                shape=(self.word_vocab.size(), self.word_vocab.embed_dim),
                initializer=word_emb_init,
                trainable=False)
            self.embedded_words = tf.nn.embedding_lookup(
                self.word_embeddings, self.placeholders['input_words'])
        # Softmax-normalised per-layer weights and a global scale.
        self.elmo_weights = tf.nn.softmax(
            tf.get_variable('elmo_weights', [3],
                            dtype=tf.float32,
                            trainable=True))
        self.scale_para = tf.get_variable('scale_para', [1],
                                          dtype=tf.float32,
                                          trainable=True)
        self.elmo_vectors = self.scale_para * (
            self.elmo_weights[0] *
            self.placeholders['elmo_vectors'][:, 0, :, :] +
            self.elmo_weights[1] *
            self.placeholders['elmo_vectors'][:, 1, :, :] +
            self.elmo_weights[2] *
            self.placeholders['elmo_vectors'][:, 2, :, :])
        self.embedded_inputs = tf.concat(
            [self.embedded_words, self.elmo_vectors], -1)
        self.embedded_inputs = tf.nn.dropout(
            self.embedded_inputs, self.placeholders['dropout_keep_prob'])

    def _compute_loss(self):
        """Mean negative CRF log-likelihood plus optional L2 on non-bias vars."""
        self.loss = tf.reduce_mean(-self.log_likelyhood, 0)
        if self.weight_decay > 0:
            with tf.variable_scope('l2_loss'):
                l2_loss = tf.add_n([
                    tf.nn.l2_loss(v) for v in tf.trainable_variables()
                    if 'bias' not in v.name
                ])
            self.loss += self.weight_decay * l2_loss

    def _elmo_vectors(self, batch):
        """Run ELMo on the batch's raw words and return a host numpy array.

        ``batch_to_embeddings`` returns a torch tensor that lives on the GPU
        when ``cuda_device >= 0``; it must be copied to host memory before
        numpy conversion (``np.asarray`` on a CUDA tensor raises).
        """
        elmo_vectors, _ = self.elmo.batch_to_embeddings(
            [sample['words'] for sample in batch['raw_data']])
        return elmo_vectors.detach().cpu().numpy()

    def _train_epoch(self, train_batches, print_every_n_batch):
        """Train for one pass over ``train_batches``; return the mean loss.

        Logs loss and gradient norm every ``print_every_n_batch`` batches
        (never for batch 0) when ``print_every_n_batch`` > 0.
        """
        total_loss, total_batch_num = 0, 0
        for bitx, batch in enumerate(train_batches):
            feed_dict = {
                self.placeholders['input_words']: batch['word_ids'],
                self.placeholders['input_length']: batch['length'],
                self.placeholders['seg_labels']: batch['seg_labels'],
                self.placeholders['elmo_vectors']: self._elmo_vectors(batch),
                self.placeholders['dropout_keep_prob']:
                self.dropout_keep_prob,
            }

            _, loss, grad_norm = self.sess.run(
                [self.train_op, self.loss, self.grad_norm], feed_dict)

            if (bitx != 0 and print_every_n_batch > 0
                    and bitx % print_every_n_batch == 0):
                self.logger.info('bitx: {}, loss: {}, grad: {}'.format(
                    bitx, loss, grad_norm))
            total_loss += loss
            total_batch_num += 1
        return total_loss / total_batch_num

    def segment(self, batch):
        """Viterbi-decode a batch; return, per sample, word indices labelled 1."""
        feed_dict = {
            self.placeholders['input_words']: batch['word_ids'],
            self.placeholders['input_length']: batch['length'],
            self.placeholders['elmo_vectors']: self._elmo_vectors(batch),
            # No dropout at inference time.
            self.placeholders['dropout_keep_prob']: 1.0,
        }

        scores, trans_params = self.sess.run([self.scores, self.trans_params],
                                             feed_dict)

        batch_pred_segs = []
        for sample_idx in range(len(batch['raw_data'])):
            length = batch['length'][sample_idx]
            viterbi_seq, _ = tc.crf.viterbi_decode(
                scores[sample_idx][:length], trans_params)
            batch_pred_segs.append([
                word_idx for word_idx, label in enumerate(viterbi_seq)
                if label == 1
            ])
        return batch_pred_segs
Exemple #10
0
# -*- coding: utf-8 -*-
"""
Minimal ELMo demo with allennlp's ElmoEmbedder.

Background on ELMo ("Deep contextualized word representations", NAACL 2018),
its principles and experiments: https://cstsunfu.github.io/2018/06/ELMo/
"""
from allennlp.commands.elmo import ElmoEmbedder

# Local directory holding the pretrained ELMo options and weights.
ELMO_DIR = ('/Users/tony/myfiles/spark/share/python-projects/deep_trading/'
            'dataset/elmo_embedder/')
elmo = ElmoEmbedder(
    options_file=ELMO_DIR + 'elmo_options.json',
    weight_file=ELMO_DIR + 'elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5',
    cuda_device=-1,  # run on the CPU
)

# Two pre-tokenized sentences of different lengths.
context_tokens = [
    ['I', 'love', 'you', '.'],
    ['Sorry', ',', 'I', 'don', "'t", 'love', 'you', '.'],
]
# batch_to_embeddings returns (activations, mask); presumably the mask marks
# real tokens versus padding of the shorter sentence.
elmo_embedding, elmo_mask = elmo.batch_to_embeddings(context_tokens)
print(elmo_embedding)
print(elmo_mask)
Exemple #11
0
class MLPRegression(Module):
    def __init__(self,
                 embed_params,
                 attention_type,
                 all_attributes,
                 output_size,
                 layers,
                 hand_feat_dim,
                 device="cpu",
                 embedding_dim=1024,
                 turn_on_hand_feats=False,
                 turn_on_embeddings=False):
        '''
            Super class for training
        '''
        super(MLPRegression, self).__init__()

        # Set model constants and embeddings
        self.device = device
        self.layers = layers
        self.embedding_dim = embedding_dim
        self.output_size = output_size
        self.attention_type = attention_type
        self.all_attributes = all_attributes
        self.is_hand_feats_on = turn_on_hand_feats
        self.is_embeds_on = turn_on_embeddings

        # Initialise embeddings
        if self.is_embeds_on:
            self._init_embeddings(embed_params)
        else:
            self.reduced_embedding_dim = 0
        if self.is_hand_feats_on:
            self.hand_feat_dim = hand_feat_dim
        else:
            self.hand_feat_dim = 0

        # Initialise regression layers and parameters
        self._init_regression()

        # Initialise attention parameters
        self._init_attention()

    def _init_embeddings(self, embedding_params):
        '''
            Initialise embeddings
        '''
        if type(embedding_params[0]) is str:
            self.vocab = None
            options_file = embedding_params[0]
            weight_file = embedding_params[1]
            self.embeddings = ElmoEmbedder(options_file,
                                           weight_file,
                                           cuda_device=0)
            # self.embeddings = Elmo(options_file, weight_file, 3, dropout=0)
            self.reduced_embedding_dim = 256

            # ELMO tuning parameters
            self.embed_linmap_argpred_lower = Linear(
                self.embedding_dim, self.reduced_embedding_dim)
            self.embed_linmap_argpred_mid = Linear(self.embedding_dim,
                                                   self.reduced_embedding_dim,
                                                   bias=False)
            self.embed_linmap_argpred_top = Linear(self.embedding_dim,
                                                   self.reduced_embedding_dim,
                                                   bias=False)

        else:
            # GloVe embeddings
            glove_embeds = embedding_params[0]
            self.vocab = embedding_params[1]
            self.num_embeddings = len(self.vocab)
            self.embeddings = torch.nn.Embedding(self.num_embeddings,
                                                 self.embedding_dim,
                                                 max_norm=None,
                                                 norm_type=2,
                                                 scale_grad_by_freq=False,
                                                 sparse=False)
            self.reduced_embedding_dim = 300

            self.embeddings.weight.data.copy_(
                torch.from_numpy(glove_embeds.values))
            self.embeddings.weight.requires_grad = False
            self.vocab_hash = {w: i for i, w in enumerate(self.vocab)}
            # self.embed_linmap = Linear(self.embedding_dim, self.reduced_embedding_dim)

    def _init_regression(self):
        '''
            Define the linear maps
        '''

        # Output regression parameters
        self.linmaps = ModuleDict(
            {prot: ModuleList([])
             for prot in self.all_attributes.keys()})

        for prot in self.all_attributes.keys():
            last_size = self.reduced_embedding_dim
            # Handle varying size of dimension depending on representation
            if self.attention_type[prot]['repr'] == "root":
                if self.attention_type[prot]['context'] != "none":
                    last_size *= 2
            else:
                if self.attention_type[prot]['context'] == "none":
                    last_size *= 2
                else:
                    last_size *= 3
            # self.layer_norm[prot] = torch.nn.LayerNorm(last_size)
            last_size += self.hand_feat_dim
            for out_size in self.layers:
                linmap = Linear(last_size, out_size)
                self.linmaps[prot].append(linmap)
                last_size = out_size
            final_linmap = Linear(last_size, self.output_size)
            self.linmaps[prot].append(final_linmap)

        # Dropout layer
        self.dropout = Dropout()

    def _regression_nonlinearity(self, x):
        return F.relu(x)

    def _init_attention(self):
        '''
            Initialises the attention map vector/matrix

            Takes attention_type-Span, Sentence, Span-param, Sentence-param
            as a parameter to decide the size of the attention matrix
        '''

        self.att_map_repr = ModuleDict({})
        self.att_map_W = ModuleDict({})
        self.att_map_V = ModuleDict({})
        self.att_map_context = ModuleDict({})
        for prot in self.attention_type.keys():
            # Token representation
            if self.attention_type[prot]['repr'] == "span":
                repr_dim = 2 * self.reduced_embedding_dim
                self.att_map_repr[prot] = Linear(self.reduced_embedding_dim,
                                                 1,
                                                 bias=False)
                self.att_map_W[prot] = Linear(self.reduced_embedding_dim,
                                              self.reduced_embedding_dim)
                self.att_map_V[prot] = Linear(self.reduced_embedding_dim,
                                              1,
                                              bias=False)
            elif self.attention_type[prot]['repr'] == "param":
                repr_dim = 2 * self.reduced_embedding_dim
                self.att_map_repr[prot] = Linear(self.reduced_embedding_dim,
                                                 self.reduced_embedding_dim,
                                                 bias=False)
                self.att_map_W[prot] = Linear(2 * self.reduced_embedding_dim,
                                              self.reduced_embedding_dim)
                self.att_map_V[prot] = Linear(self.reduced_embedding_dim,
                                              1,
                                              bias=False)
            else:
                repr_dim = self.reduced_embedding_dim

            # Context representation
            # There is no attention for argument davidsonian
            if self.attention_type[prot]['context'] == 'param':
                self.att_map_context[prot] = Linear(repr_dim,
                                                    self.reduced_embedding_dim,
                                                    bias=False)
            elif self.attention_type[prot][
                    'context'] == 'david' and prot == 'arg':
                self.att_map_context[prot] = Linear(repr_dim,
                                                    self.reduced_embedding_dim,
                                                    bias=False)

    def _choose_tokens(self, batch, lengths):
        '''
            Extracts tokens from a batch at specified position(lengths)
            batch - batch_size x max_sent_length x embed_dim
            lengths - batch_size x max_span_length x embed_dim
        '''
        idx = (lengths).unsqueeze(2).expand(-1, -1, batch.shape[2])
        return batch.gather(1, idx).squeeze()

    def _get_inputs(self, words):
        '''
           Return ELMO embeddings as root, span or param span
        '''
        if not self.vocab:
            raw_embeds, masks = self.embeddings.batch_to_embeddings(words)
            # raw_ = self.embeddings(batch_to_ids(words).to(self.device))
            # raw_embeds, masks = torch.cat([x.unsqueeze(1) for x in raw_['elmo_representations']], dim=1), raw_['mask']
            masks = masks.unsqueeze(2).repeat(
                1, 1, self.reduced_embedding_dim).byte()
            embedded_inputs = (self.embed_linmap_argpred_lower(
                raw_embeds[:, 0, :, :].squeeze()) +
                               self.embed_linmap_argpred_mid(
                                   raw_embeds[:, 1, :, :].squeeze()) +
                               self.embed_linmap_argpred_top(
                                   raw_embeds[:, 2, :, :].squeeze()))
            masked_embedded_inputs = embedded_inputs * masks.float()
            return masked_embedded_inputs, masks
        else:
            # Glove embeddings
            indices = [[self.vocab_hash[word] for word in sent]
                       for sent in words]
            indices = torch.tensor(indices,
                                   dtype=torch.long,
                                   device=self.device)
            embeddings = self.embeddings(indices)
            masks = (embeddings != 0)[:, :, :self.reduced_embedding_dim].byte()
            # reduced_embeddings = self.embed_linmap(embeddings) * masks.float()
            return embeddings, masks

    def _get_representation(self,
                            prot,
                            embeddings,
                            roots,
                            spans,
                            context=False):
        '''
            returns the representation required from arguments passed by
            running attention based on arguments passed
        '''

        # Get token(pred/arg) representation
        rep_type = self.attention_type[prot]['repr']

        roots_rep_raw = self._choose_tokens(embeddings, roots)
        if len(roots_rep_raw.shape) == 1:
            roots_rep_raw = roots_rep_raw.unsqueeze(0)

        if rep_type == "root":
            token_rep = roots_rep_raw
        else:
            masks_spans = (spans == -1)
            spans[spans == -1] = 0
            spans_rep_raw = self._choose_tokens(embeddings, spans)

            if len(spans_rep_raw.shape) == 1:
                spans_rep_raw = spans_rep_raw.unsqueeze(0).unsqueeze(1)
            elif len(spans_rep_raw.shape) == 2:
                if spans.shape[0] == 1:
                    spans_rep_raw = spans_rep_raw.unsqueeze(0)
                elif spans.shape[1] == 1:
                    spans_rep_raw = spans_rep_raw.unsqueeze(1)

            if rep_type == "span":
                att_raw = self.att_map_repr[prot](spans_rep_raw).squeeze()
                # additive attention
                # att_raw_w = torch.relu(self.att_map_W[prot](for_att))
                # att_raw = self.att_map_V[prot](att_raw_w).squeeze()
            elif rep_type == "param":
                # att_param = torch.relu(self.att_map_repr[prot](roots_rep_raw)).unsqueeze(2)
                # att_raw = torch.matmul(spans_rep_raw, att_param).squeeze()
                # additive attention
                for_att = torch.cat(
                    (spans_rep_raw, roots_rep_raw.unsqueeze(1).repeat(
                        1, spans_rep_raw.shape[1], 1)),
                    dim=2)
                att_raw_w = torch.relu(self.att_map_W[prot](for_att))
                att_raw = self.att_map_V[prot](att_raw_w).squeeze()

            att_raw = att_raw.masked_fill(masks_spans, -1e9)
            att = F.softmax(att_raw, dim=1)
            att = self.dropout(att)
            pure_token_rep = torch.matmul(
                att.unsqueeze(2).permute(0, 2, 1), spans_rep_raw).squeeze()
            if not context:
                token_rep = torch.cat((roots_rep_raw, pure_token_rep), dim=1)
            else:
                token_rep = pure_token_rep

        return token_rep

    def _run_attention(self, prot, embeddings, roots, spans, context_roots,
                       context_spans, masks):
        '''
            Various attention mechanisms implemented
        '''

        # Get the required representation for pred/arg
        token_rep = self._get_representation(prot=prot,
                                             embeddings=embeddings,
                                             roots=roots,
                                             spans=spans)

        # Get the required representation for context of pred/arg
        context_type = self.attention_type[prot]['context']

        if context_type == "none":
            context_rep = None

        elif context_type == "param":
            # Sentence level attention
            att_param = torch.relu(
                self.att_map_context[prot](token_rep)).unsqueeze(1)
            att_raw = torch.matmul(embeddings, att_param.permute(0, 2, 1))
            att_raw = att_raw.masked_fill(masks[:, :, 0:1] == 0, -1e9)
            att = F.softmax(att_raw, dim=1)
            att = self.dropout(att)
            context_rep = torch.matmul(att.permute(0, 2, 1),
                                       embeddings).squeeze()

        elif context_type == "david":
            if prot == "arg":
                prot_context = 'pred'
                context_roots = torch.tensor(context_roots,
                                             dtype=torch.long,
                                             device=self.device).unsqueeze(1)
                max_span = max([len(a) for a in context_spans])
                context_spans = torch.tensor([
                    a + [-1 for i in range(max_span - len(a))]
                    for a in context_spans
                ],
                                             dtype=torch.long,
                                             device=self.device)
                context_rep = self._get_representation(context=True,
                                                       prot=prot_context,
                                                       embeddings=embeddings,
                                                       roots=context_roots,
                                                       spans=context_spans)
            else:
                prot_context = 'arg'
                context_rep = None
                for i, ctx_root in enumerate(context_roots):
                    ctx_root = torch.tensor(ctx_root,
                                            dtype=torch.long,
                                            device=self.device).unsqueeze(1)
                    max_span = max([len(a) for a in context_spans[i]])
                    ctx_span = torch.tensor([
                        a + [-1 for i in range(max_span - len(a))]
                        for a in context_spans[i]
                    ],
                                            dtype=torch.long,
                                            device=self.device)
                    sentence = embeddings[i, :, :].unsqueeze(0).repeat(
                        len(ctx_span), 1, 1)
                    ctx_reps = self._get_representation(context=True,
                                                        prot=prot_context,
                                                        embeddings=sentence,
                                                        roots=ctx_root,
                                                        spans=ctx_span)

                    if len(ctx_reps.shape) == 1:
                        ctx_reps = ctx_reps.unsqueeze(0)
                    # Attention over arguments
                    att_nd_param = torch.relu(self.att_map_context[prot](
                        token_rep[i, :].unsqueeze(0)))
                    att_raw = torch.matmul(att_nd_param,
                                           ctx_reps.permute(1, 0))
                    att = F.softmax(att_raw, dim=1)
                    ctx_rep_final = torch.matmul(att, ctx_reps)
                    if i:
                        context_rep = torch.cat((context_rep, ctx_rep_final),
                                                dim=0).squeeze()
                    else:
                        context_rep = ctx_rep_final

        if context_rep is not None:
            inputs_for_regression = torch.cat((token_rep, context_rep), dim=1)
        else:
            inputs_for_regression = token_rep

        return inputs_for_regression

    def _run_regression(self, prot, x):
        '''
            Run regression to get 3 attribute vector
        '''
        for i, lin_map in enumerate(self.linmaps[prot]):
            if i:
                x = self._regression_nonlinearity(x)
                x = self.dropout(x)

            x = lin_map(x)

        return torch.sigmoid(x)

    def forward(self, prot, words, roots, spans, context_roots, context_spans,
                hand_feats):
        """
            Predict attributes for `prot`.

            Features are the attention-based word representation (when
            self.is_embeds_on), the hand-crafted features `hand_feats` (when
            self.is_hand_feats_on), or both concatenated along dim 1. The
            features go through self._run_regression. Exits the process when
            both feature sources are disabled.
        """
        if not (self.is_embeds_on or self.is_hand_feats_on):
            sys.exit('You need some word representation!!')

        if not self.is_embeds_on:
            features = hand_feats
        else:
            embedded, masks = self._get_inputs(words)
            features = self._run_attention(prot=prot,
                                           embeddings=embedded,
                                           roots=roots,
                                           spans=spans,
                                           context_roots=context_roots,
                                           context_spans=context_spans,
                                           masks=masks)
            if self.is_hand_feats_on:
                features = torch.cat((features, hand_feats), dim=1)

        return self._run_regression(prot=prot, x=features)
Exemple #12
0
class ELMOCRFSegModel(LSTMCRFSegModel):
    """LSTM-CRF segmentation model whose inputs are augmented with ELMo.

    ELMo activations are computed outside the TensorFlow graph by AllenNLP's
    ``ElmoEmbedder`` and fed through the ``elmo_vectors`` placeholder
    (``[batch, 3, seq_len, 1024]``). ``_embed`` mixes the three ELMo layers
    with learned softmax weights and a learned scale, then concatenates the
    result to the (non-trainable) word embeddings.
    """

    def __init__(self, args, word_vocab):
        super().__init__(args, word_vocab)

        # import ElmoEmbedder here so that the cuda_visible_devices can work
        from allennlp.commands.elmo import ElmoEmbedder

        elmo_dir = os.path.join(args.base_data_dir, args.elmo_dir)
        weights_file = os.path.join(elmo_dir, args.weights_file)
        options_file = os.path.join(elmo_dir, args.options_file)
        # A local model needs both files; otherwise fall back to AllenNLP's
        # default model, which is downloaded.
        if os.path.exists(weights_file) and os.path.exists(options_file):
            self.elmo = ElmoEmbedder(options_file=options_file,
                                     weight_file=weights_file,
                                     cuda_device=int(args.gpu))
        else:
            print(
                "Elmo vectors NOT found at designated path. Downloading from online..."
            )
            self.elmo = ElmoEmbedder(cuda_device=int(args.gpu))

    def _setup_placeholders(self):
        self.placeholders = {
            'input_words': tf.placeholder(tf.int32, shape=[None, None]),
            'input_length': tf.placeholder(tf.int32, shape=[None]),
            'elmo_vectors': tf.placeholder(tf.float32,
                                           shape=[None, 3, None, 1024]),
            'seg_labels': tf.placeholder(tf.float32, shape=[None, None]),
            'dropout_keep_prob': tf.placeholder(tf.float32)
        }

    def _embed(self):
        """Build the input layer: frozen word embeddings + scaled ELMo mix."""
        with tf.device('/cpu:0'):
            if self.word_vocab.embeddings is not None:
                word_emb_init = tf.constant_initializer(
                    self.word_vocab.embeddings)
            else:
                word_emb_init = tf.random_normal_initializer()
            self.word_embeddings = tf.get_variable(
                'word_embeddings',
                shape=(self.word_vocab.size(), self.word_vocab.embed_dim),
                initializer=word_emb_init,
                trainable=False)
            self.embedded_words = tf.nn.embedding_lookup(
                self.word_embeddings, self.placeholders['input_words'])
        # Softmax-normalised weights over the 3 ELMo layers plus a scale
        self.elmo_weights = tf.nn.softmax(
            tf.get_variable('elmo_weights', [3],
                            dtype=tf.float32,
                            trainable=True))
        self.scale_para = tf.get_variable('scale_para', [1],
                                          dtype=tf.float32,
                                          trainable=True)
        elmo_input = self.placeholders['elmo_vectors']
        self.elmo_vectors = self.scale_para * (
            self.elmo_weights[0] * elmo_input[:, 0, :, :] +
            self.elmo_weights[1] * elmo_input[:, 1, :, :] +
            self.elmo_weights[2] * elmo_input[:, 2, :, :])
        self.embedded_inputs = tf.concat(
            [self.embedded_words, self.elmo_vectors], -1)
        self.embedded_inputs = tf.nn.dropout(
            self.embedded_inputs, self.placeholders['dropout_keep_prob'])

    def _compute_loss(self):
        """Mean negative CRF log-likelihood plus optional L2 on non-biases."""
        self.loss = tf.reduce_mean(-self.log_likelyhood, 0)
        if self.weight_decay > 0:
            with tf.variable_scope('l2_loss'):
                l2_loss = tf.add_n([
                    tf.nn.l2_loss(v) for v in tf.trainable_variables()
                    if 'bias' not in v.name
                ])
            self.loss += self.weight_decay * l2_loss

    def _elmo_vectors_for(self, batch):
        """Run ELMo on the raw words of `batch`.

        Returns a NumPy array to feed the ``elmo_vectors`` placeholder.
        """
        elmo_vectors, _ = self.elmo.batch_to_embeddings(
            [sample['words'] for sample in batch['raw_data']])
        return elmo_vectors.detach().cpu().numpy()

    def _train_epoch(self, train_batches, print_every_n_batch):
        """Run one training epoch and return the mean batch loss.

        Returns 0.0 (with a warning) when `train_batches` yields nothing,
        instead of dividing by zero.
        """
        total_loss, total_batch_num = 0, 0
        for bitx, batch in enumerate(train_batches):
            feed_dict = {
                self.placeholders['input_words']: batch['word_ids'],
                self.placeholders['input_length']: batch['length'],
                self.placeholders['seg_labels']: batch['seg_labels'],
                self.placeholders['elmo_vectors']:
                self._elmo_vectors_for(batch),
                self.placeholders['dropout_keep_prob']: self.dropout_keep_prob,
            }

            _, loss, grad_norm = self.sess.run(
                [self.train_op, self.loss, self.grad_norm], feed_dict)

            if (print_every_n_batch > 0 and bitx != 0
                    and bitx % print_every_n_batch == 0):
                self.logger.info('bitx: {}, loss: {}, grad: {}'.format(
                    bitx, loss, grad_norm))
            total_loss += loss
            total_batch_num += 1

        if total_batch_num == 0:
            self.logger.warning('No training batches in this epoch.')
            return 0.0
        return total_loss / total_batch_num

    def segment(self, batch):
        """Predict segments for every sample in `batch`.

        Returns, per sample, the word indices whose Viterbi label is 1.
        """
        feed_dict = {
            self.placeholders['input_words']: batch['word_ids'],
            self.placeholders['input_length']: batch['length'],
            self.placeholders['elmo_vectors']: self._elmo_vectors_for(batch),
            self.placeholders['dropout_keep_prob']: 1.0,
        }

        scores, trans_params = self.sess.run([self.scores, self.trans_params],
                                             feed_dict)

        batch_pred_segs = []
        for sample_idx in range(len(batch['raw_data'])):
            length = batch['length'][sample_idx]
            # Decode only the real (unpadded) positions of this sample
            viterbi_seq, _ = tc.crf.viterbi_decode(
                scores[sample_idx][:length], trans_params)
            batch_pred_segs.append([
                word_idx for word_idx, label in enumerate(viterbi_seq)
                if label == 1
            ])
        return batch_pred_segs
Exemple #13
0
class TextEmbedding:
    def __init__(self):
        # self.w2v_embdding_size = 100
        # self.w2v = Word2Vec.load("./w2v/w2v_model")
        # self.vocabulary = set(open("./w2v/text8_vocabulary.txt").read().split("\n"))
        # self.default_word = "a"

        self.bert_tokenizer = None
        self.bert = None

        # self.gpt2_tokenizer = None
        # self.gpt2 = None
        #
        # self.transformer_tokenizer = None
        # self.transformer = None
        #
        # self.elmo = None
        #
        # self.bert_map = {}

    def Get_Word2Vec_Representation(self, examples):
        """Attach a word2vec matrix to each example as ``example.word2vec_mat``.

        Each space-separated word of ``example.fgt_channels[0]`` maps to
        ``self.w2v[word]``, or to ``self.w2v[self.default_word]`` when the word
        is not in ``self.vocabulary``. The rows are zero-padded with
        ``self.w2v_embdding_size``-dim vectors and cut to
        ``pb.fgt_maxlength``.
        """
        for example in examples:
            vectors = [
                self.w2v[word if word in self.vocabulary else self.default_word]
                for word in example.fgt_channels[0].split(" ")
            ]
            shortfall = pb.fgt_maxlength - len(vectors)
            vectors.extend(
                np.zeros(self.w2v_embdding_size) for _ in range(shortfall))
            example.word2vec_mat = vectors[:pb.fgt_maxlength]

    def Get_RNN_Representation(self, examples):
        """Attach the word-vector sequence ``example.rnn_mat`` to each example.

        The rows are built the same way as in Get_Word2Vec_Representation:
        ``self.w2v`` lookups with a ``self.default_word`` fallback, zero
        padding, and truncation to ``pb.fgt_maxlength`` rows.
        """
        for example in examples:
            rows = []
            for token in example.fgt_channels[0].split(" "):
                key = token if token in self.vocabulary else self.default_word
                rows.append(self.w2v[key])
            while len(rows) < pb.fgt_maxlength:
                rows.append(np.zeros(self.w2v_embdding_size))
            example.rnn_mat = rows[:pb.fgt_maxlength]

    def Get_Char_Representation(self, examples):
        """Attach a per-character matrix to each example as ``example.char_mat``.

        Characters in ``self.vocabulary`` use ``self.w2v[char]``. Any other
        character gets a zero vector as wide as the default word's vector,
        with a 1 at its index in ``pb.label_histogram_x`` if it occurs there.
        The rows are zero-padded and cut to
        ``pb.fgt_maxlength * pb.word_maxlength``.
        """
        for example in examples:
            rows = []
            for ch in example.fgt_channels[0]:
                if ch in self.vocabulary:
                    rows.append(self.w2v[ch])
                    continue
                one_hot = [0.0] * len(self.w2v[self.default_word])
                if ch in pb.label_histogram_x:
                    one_hot[pb.label_histogram_x.index(ch)] = 1
                rows.append(one_hot)
            limit = pb.fgt_maxlength * pb.word_maxlength
            while len(rows) < limit:
                rows.append(np.zeros(self.w2v_embdding_size))
            example.char_mat = rows[:limit]

    def Get_Bert_Representation(self, examples_train, examples_test):
        """Attach a BERT token matrix to every example as ``example.bert_mat``.

        If both pickle caches ``./data/<pb.dataset>_train_bert`` and
        ``./data/<pb.dataset>_test_bert`` exist, the matrices are loaded from
        them. Otherwise each ``example.fgt_channels[0]`` is encoded with
        bert-base-uncased. The token vectors are averaged across the returned
        encoder layers, zero-padded or truncated to ``pb.fgt_maxlength`` rows,
        and written to both caches for later runs.
        """
        train_rep_file = "./data/" + pb.dataset + "_train_" + "bert"
        test_rep_file = "./data/" + pb.dataset + "_test_" + "bert"

        if os.path.exists(train_rep_file) and os.path.exists(test_rep_file):
            # NOTE: pickle.load can execute code; only load caches produced
            # locally by this method.
            with open(train_rep_file, 'rb') as handle:
                examples_train_rep = pickle.load(handle)
            for i, example in enumerate(examples_train):
                example.bert_mat = examples_train_rep[i]
            with open(test_rep_file, 'rb') as handle:
                examples_test_rep = pickle.load(handle)
            for i, example in enumerate(examples_test):
                example.bert_mat = examples_test_rep[i]
            return

        # Materialise both inputs so they can be written to the caches later.
        train_examples = list(examples_train)
        test_examples = list(examples_test)
        examples = train_examples + test_examples

        if self.bert_tokenizer is None:
            self.bert_tokenizer = BertTokenizer.from_pretrained(
                'bert-base-uncased')
        if self.bert is None:
            self.bert = BertModel.from_pretrained('bert-base-uncased')
            self.bert.eval()

        for i, example in enumerate(examples):
            text = "[CLS] " + example.fgt_channels[0] + " [SEP]"
            text = text.replace("  ", " ")
            tokenized_text = self.bert_tokenizer.tokenize(text)
            indexed_tokens = self.bert_tokenizer.convert_tokens_to_ids(
                tokenized_text)
            segments_ids = [0] * len(tokenized_text)

            tokens_tensor = torch.tensor([indexed_tokens])
            segments_tensors = torch.tensor([segments_ids])

            with torch.no_grad():
                encoded_layers, _ = self.bert(tokens_tensor, segments_tensors)

            # encoded_layers is iterated as (1, seq_len, hidden) tensors
            # (presumably one per encoder layer); average them token-wise.
            layer_stack = np.stack(
                [layer[0].numpy() for layer in encoded_layers])
            token_matrix = layer_stack.mean(axis=0, dtype=np.float64)
            hidden = token_matrix.shape[1]

            representation = list(token_matrix)
            representation.extend(
                np.zeros(hidden)
                for _ in range(pb.fgt_maxlength - len(representation)))
            example.bert_mat = representation[:pb.fgt_maxlength]

            print("{:.2%}".format(i * 1.0 / len(examples)))

        # Write the caches that the branch above reads.
        os.makedirs(os.path.dirname(train_rep_file), exist_ok=True)
        with open(train_rep_file, 'wb') as handle:
            pickle.dump([example.bert_mat for example in train_examples],
                        handle)
        with open(test_rep_file, 'wb') as handle:
            pickle.dump([example.bert_mat for example in test_examples],
                        handle)

    def _Get_Bert_Representation(self):
        """Pre-compute BERT sentence vectors for every file under ./data/test.

        Each line has its trailing newline and its last space-separated token
        removed (presumably a label -- TODO confirm the file format). Each
        distinct remaining text is encoded with bert-base-uncased. Its vector
        is the mean of the last layer's token vectors, excluding the first
        ([CLS]) and last ([SEP]) positions. The {text: vector} dict is
        pickled to ./bert_map.
        """
        count = 0
        bert_map = {}
        for root, _dirs, files in os.walk("./data/test"):
            for file_name in files:
                file_path = os.path.join(root, file_name)
                print(file_path)

                # `with` guarantees each data file is closed after reading.
                with open(file_path, "r") as handle:
                    for raw_line in handle:
                        line = " ".join(raw_line.rstrip("\n").split(" ")[:-1])

                        if line in bert_map:
                            continue

                        if self.bert_tokenizer is None:
                            self.bert_tokenizer = BertTokenizer.from_pretrained(
                                'bert-base-uncased')

                        text = "[CLS] " + line + " [SEP]"
                        text = text.replace("  ", " ")
                        tokenized_text = self.bert_tokenizer.tokenize(text)
                        indexed_tokens = self.bert_tokenizer.convert_tokens_to_ids(
                            tokenized_text)
                        segments_ids = [0] * len(tokenized_text)

                        tokens_tensor = torch.tensor([indexed_tokens])
                        segments_tensors = torch.tensor([segments_ids])

                        if self.bert is None:
                            self.bert = BertModel.from_pretrained(
                                'bert-base-uncased')
                            self.bert.eval()

                        with torch.no_grad():
                            encoded_layers, _ = self.bert(
                                tokens_tensor, segments_tensors)

                        last_layer = encoded_layers[-1].numpy()[0]
                        # Skip the [CLS]/[SEP] rows; an empty slice yields NaNs.
                        bert_map[line] = last_layer[1:-1].mean(
                            axis=0, dtype=np.float64)

                        count += 1
                        if count % 100 == 0:
                            print(count)

        with open("./bert_map", 'wb') as handle:
            pickle.dump(bert_map, handle)

    def _Get_Word_Bert_Representation(self, word):

        if (word not in self.bert_map.keys()):
            if (self.bert_tokenizer == None):
                self.bert_tokenizer = BertTokenizer.from_pretrained(
                    'bert-base-uncased')

            text = "[CLS] " + word + " [SEP]"
            text = text.replace("  ", " ")
            tokenized_text = self.bert_tokenizer.tokenize(text)

            indexed_tokens = self.bert_tokenizer.convert_tokens_to_ids(
                tokenized_text)
            segments_ids = [0 for _ in tokenized_text]

            tokens_tensor = torch.tensor([indexed_tokens])
            segments_tensors = torch.tensor([segments_ids])

            if (self.bert == None):
                self.bert = BertModel.from_pretrained('bert-base-uncased')
                self.bert.eval()

            with torch.no_grad():
                representation, sum = [], 0

                encoded_layers, _ = self.bert(tokens_tensor, segments_tensors)

                Len = len(encoded_layers[-1].numpy()[0])
                representation = np.zeros(768)
                for i in range(1, Len - 1):
                    representation += encoded_layers[-1].numpy()[0][i]
                    sum += 1
                representation = representation * 1.0 / sum

                self.bert_map[word] = representation

        return self.bert_map[word]

    def Get_BOW_Representation(self, examples):
        """Attach a bag-of-words count vector to each example as ``bow_vec``.

        The vector has one float slot per entry of ``list(pb.vocabulary)``.
        Each space-separated word of ``example.fgt_channels[0]`` found in the
        vocabulary increments its slot.
        """
        volist = list(pb.vocabulary)
        # word -> first index in volist (same as list.index), giving O(1)
        # lookups instead of scanning the vocabulary for every word.
        position = {}
        for index, word in enumerate(volist):
            position.setdefault(word, index)

        for example in examples:
            x = [0.0] * len(volist)
            for word in example.fgt_channels[0].split(" "):
                index = position.get(word)
                if index is not None:
                    x[index] += 1
            example.bow_vec = x

    def Get_GPT2_Representation(self, examples):
        """Attach a GPT-2 token matrix to each example as ``example.gpt2_mat``.

        ``example.fgt_channels[0]`` is encoded with the 'gpt2' model. Each
        token's output vector becomes one row. The rows are zero-padded or
        truncated to ``pb.fgt_maxlength``.
        """
        for i, example in enumerate(examples):
            # Created lazily; __init__ may not define these attributes.
            if getattr(self, "gpt2_tokenizer", None) is None:
                self.gpt2_tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

            indexed_tokens = self.gpt2_tokenizer.encode(
                example.fgt_channels[0])
            tokens_tensor = torch.tensor([indexed_tokens])

            if getattr(self, "gpt2", None) is None:
                self.gpt2 = GPT2Model.from_pretrained('gpt2')
                self.gpt2.eval()

            with torch.no_grad():
                hidden_states, _ = self.gpt2(tokens_tensor)  # (1, 5, 768)

            # hidden_states is a single (1, seq_len, hidden) tensor rather than
            # a list of layers, so take the per-token rows of its only batch
            # element directly. Averaging "per layer" here would average over
            # tokens and give every row the same sentence mean.
            token_matrix = hidden_states[0].numpy().astype(np.float64)
            hidden = token_matrix.shape[1]

            representation = list(token_matrix)
            representation.extend(
                np.zeros(hidden)
                for _ in range(pb.fgt_maxlength - len(representation)))
            example.gpt2_mat = representation[:pb.fgt_maxlength]

            print("{:.2%}".format(i * 1.0 / len(examples)))

    def Get_Transformer_Representation(self, examples_train, examples_test):
        """Attach a Transformer-XL token matrix as ``example.transformerXL_mat``.

        If both pickle caches ``./data/<pb.dataset>_train_transformerXL`` and
        ``./data/<pb.dataset>_test_transformerXL`` exist, the matrices are
        loaded from them. Otherwise each ``example.fgt_channels[0]`` is
        encoded with 'transfo-xl-wt103' and each token's output vector
        becomes one row. The rows are zero-padded or truncated to
        ``pb.fgt_maxlength``, and the results are written to both caches.
        """
        train_rep_file = "./data/" + pb.dataset + "_train_" + "transformerXL"
        test_rep_file = "./data/" + pb.dataset + "_test_" + "transformerXL"

        if os.path.exists(train_rep_file) and os.path.exists(test_rep_file):
            # NOTE: pickle.load can execute code; only load caches produced
            # locally by this method.
            with open(train_rep_file, 'rb') as handle:
                examples_train_rep = pickle.load(handle)
            for i, example in enumerate(examples_train):
                example.transformerXL_mat = examples_train_rep[i]
            with open(test_rep_file, 'rb') as handle:
                examples_test_rep = pickle.load(handle)
            for i, example in enumerate(examples_test):
                example.transformerXL_mat = examples_test_rep[i]
            return

        # Materialise both inputs so they can be written to the caches later.
        train_examples = list(examples_train)
        test_examples = list(examples_test)
        examples = train_examples + test_examples

        for i, example in enumerate(examples):
            # Created lazily; __init__ may not define these attributes.
            if getattr(self, "transformer_tokenizer", None) is None:
                self.transformer_tokenizer = TransfoXLTokenizer.from_pretrained(
                    'transfo-xl-wt103')

            tokenized_text = self.transformer_tokenizer.tokenize(
                example.fgt_channels[0])
            indexed_tokens = self.transformer_tokenizer.convert_tokens_to_ids(
                tokenized_text)
            tokens_tensor = torch.tensor([indexed_tokens])

            if getattr(self, "transformer", None) is None:
                self.transformer = TransfoXLModel.from_pretrained(
                    'transfo-xl-wt103')
                self.transformer.eval()

            with torch.no_grad():
                hidden_states, _ = self.transformer(
                    tokens_tensor)  # (1, 3, 1024)

            # hidden_states is a single (1, seq_len, hidden) tensor, so take
            # the per-token rows of its only batch element (iterating it as
            # if it were a list of layers would average over tokens instead).
            token_matrix = hidden_states[0].numpy().astype(np.float64)
            hidden = token_matrix.shape[1]

            representation = list(token_matrix)
            representation.extend(
                np.zeros(hidden)
                for _ in range(pb.fgt_maxlength - len(representation)))
            example.transformerXL_mat = representation[:pb.fgt_maxlength]

            print("{:.2%}".format(i * 1.0 / len(examples)))

        # Write the caches that the branch above reads.
        os.makedirs(os.path.dirname(train_rep_file), exist_ok=True)
        with open(train_rep_file, 'wb') as handle:
            pickle.dump(
                [example.transformerXL_mat for example in train_examples],
                handle)
        with open(test_rep_file, 'wb') as handle:
            pickle.dump(
                [example.transformerXL_mat for example in test_examples],
                handle)

    def Get_ELMo_Representation(self, examples):
        """Attach an ELMo token matrix to each example as ``example.elmo_mat``.

        ``example.fgt_channels[0]`` is split on spaces and embedded with the
        small ELMo model under ./sources/. The ELMo layers are averaged
        token-wise, and the rows are zero-padded or truncated to
        ``pb.fgt_maxlength``.
        """
        for i, example in enumerate(examples):
            # Created lazily; __init__ may not define this attribute.
            if getattr(self, "elmo", None) is None:
                options_file = "./sources/elmo_2x1024_128_2048cnn_1xhighway_options.json"
                weight_file = "./sources/elmo_2x1024_128_2048cnn_1xhighway_weights.hdf5"
                self.elmo = ElmoEmbedder(options_file, weight_file)

            context_tokens = [example.fgt_channels[0].split(" ")]
            elmo_embedding, _ = self.elmo.batch_to_embeddings(context_tokens)

            # elmo_embedding[0] is (layers, seq_len, dim) for the single
            # sentence; average over the layer axis.
            layers = elmo_embedding[0].detach().cpu().numpy()
            token_matrix = layers.mean(axis=0, dtype=np.float64)
            hidden = token_matrix.shape[1]

            representation = list(token_matrix)
            representation.extend(
                np.zeros(hidden)
                for _ in range(pb.fgt_maxlength - len(representation)))
            example.elmo_mat = representation[:pb.fgt_maxlength]

            print("{:.2%}".format(i * 1.0 / len(examples)))

    def Get_Glove_Representation(self, examples):
        """Attach a GloVe matrix to each example as ``example.glove_mat``.

        Vectors are read from ./sources/glove_300d.dat, keeping only lines
        with a word followed by exactly 300 values. Words of
        ``example.fgt_channels[0]`` missing from it use the vector of
        ``self.default_word``. The rows are zero-padded or truncated to
        ``pb.fgt_maxlength``.
        """
        dim = 300
        glove_vocabulary = {}
        # Stream the file line by line and close it when done.
        with open("./sources/glove_" + str(dim) + "d.dat") as glove_file:
            for line in glove_file:
                eles = line.rstrip("\n").split(" ")
                if len(eles) == dim + 1:
                    glove_vocabulary[eles[0]] = [float(ele) for ele in eles[1:]]

        for example in examples:
            representation = []
            for word in example.fgt_channels[0].split(" "):
                if word in glove_vocabulary:
                    representation.append(glove_vocabulary[word])
                else:
                    representation.append(glove_vocabulary[self.default_word])
            while len(representation) < pb.fgt_maxlength:
                representation.append(np.zeros(dim))
            example.glove_mat = representation[:pb.fgt_maxlength]

    def _Get_Glove_Representation(self, examples):
        """Attach a zero-padded GloVe matrix to every example.

        Same logic as ``Get_Glove_Representation``:
        - Loads 300-d vectors from ``./sources/glove_300d.dat`` and skips
          lines whose field count is not 301.
        - Sets ``example.glove_mat`` to one vector per space-separated token
          of ``example.fgt_channels[0]``.
        - Out-of-vocabulary tokens fall back to ``self.default_word``.
        - Rows are zero-padded or truncated to ``pb.fgt_maxlength``.

        Raises:
            KeyError: if an out-of-vocabulary token is met and
                ``self.default_word`` is itself absent from the file.
        """
        Len = 300
        glove_vocabulary = {}

        # Close the file deterministically and decode it independently of the locale.
        with open("./sources/glove_" + str(Len) + "d.dat",
                  encoding="utf-8") as glove_file:
            for line in glove_file:
                eles = line.rstrip("\n").split(" ")
                if len(eles) == Len + 1:
                    glove_vocabulary[eles[0]] = [float(ele) for ele in eles[1:]]

        for example in examples:
            representation = []
            for word in example.fgt_channels[0].split(" "):
                if word in glove_vocabulary:
                    representation.append(glove_vocabulary[word])
                else:
                    representation.append(glove_vocabulary[self.default_word])
            while len(representation) < pb.fgt_maxlength:
                representation.append(np.zeros(Len))
            example.glove_mat = representation[0:pb.fgt_maxlength]

    def Get_Tag_Representation(self, examples_train, examples_test):
        """One-hot encode the POS tags of train and test examples.

        The tag vocabulary is built from both splits together, so train and
        test share the same column order (tags sorted lexicographically).

        For every example:
        - ``example.tags`` is set to ``pb.Get_POS(example.fgt_channels[0])``.
        - ``example.tag_mat`` is set to one one-hot row per tag, zero-padded
          or truncated to ``pb.fgt_maxlength`` rows.
        """
        examples = list(examples_train) + list(examples_test)

        for example in examples:
            example.tags = pb.Get_POS(example.fgt_channels[0])

        tag_vocabulary = sorted({tag for example in examples for tag in example.tags})
        # A dict lookup replaces tag_vocabulary.index(tag): O(1) per tag
        # instead of a linear scan of the vocabulary.
        tag_index = {tag: idx for idx, tag in enumerate(tag_vocabulary)}
        width = len(tag_vocabulary)

        for example in examples:
            representation = []
            for tag in example.tags:
                x = np.zeros(width)
                x[tag_index[tag]] = 1.0
                representation.append(x)
            while len(representation) < pb.fgt_maxlength:
                representation.append(np.zeros(width))
            example.tag_mat = representation[0:pb.fgt_maxlength]

    def Get_Bert_Representation_Tmp(self, sentences):
        """Compute a layer-averaged BERT matrix for each raw sentence.

        For each sentence:
        - It is wrapped as ``[CLS] <sentence> [SEP]`` and tokenized with the
          ``bert-base-uncased`` tokenizer.
        - It is run through ``BertModel`` in eval mode under
          ``torch.no_grad``. The tokenizer and model are loaded lazily and
          cached on ``self``.
        - The token vectors of every element of ``encoded_layers`` are
          averaged.
        - The resulting token rows are zero-padded or truncated to
          ``pb.fgt_maxlength``.

        The original version wrote to the undefined names ``example`` and
        ``examples`` (NameError) and never filled ``sentences_bert``. The
        matrices are now collected and returned.

        Args:
            sentences: sequence of raw sentence strings.

        Returns:
            list: one padded/truncated matrix (list of row vectors) per
            sentence, in input order.
        """
        sentences_bert = []

        for i, sentence in enumerate(sentences):
            print(sentence)

            if self.bert_tokenizer is None:
                self.bert_tokenizer = BertTokenizer.from_pretrained(
                    'bert-base-uncased')

            text = "[CLS] " + sentence + " [SEP]"
            text = text.replace("  ", " ")
            tokenized_text = self.bert_tokenizer.tokenize(text)

            indexed_tokens = self.bert_tokenizer.convert_tokens_to_ids(
                tokenized_text)
            segments_ids = [0 for _ in tokenized_text]

            tokens_tensor = torch.tensor([indexed_tokens])
            segments_tensors = torch.tensor([segments_ids])

            if self.bert is None:
                self.bert = BertModel.from_pretrained('bert-base-uncased')
                self.bert.eval()

            with torch.no_grad():
                encoded_layers, _ = self.bert(tokens_tensor, segments_tensors)

            # Assumption to confirm: encoded_layers is a list of per-layer
            # tensors shaped (1, num_tokens, hidden).
            first_shape = encoded_layers[0].numpy().shape
            a, b = first_shape[1], first_shape[2]

            # Average the token matrices over all layers (and batch rows).
            representation = np.zeros((a, b))
            count = 0
            for layer in encoded_layers:
                for words in layer.numpy():
                    representation += words
                    count += 1
            if count > 0:
                representation = representation * 1.0 / count

            representation = list(representation)
            while len(representation) < pb.fgt_maxlength:
                representation.append(np.zeros(b))

            sentences_bert.append(representation[0:pb.fgt_maxlength])

            print("{:.2%}".format(i * 1.0 / len(sentences)))

        return sentences_bert

    def Get_Glove_Representation_Tmp(self, examples):
        """Map every word appearing in ``examples`` to a 300-d GloVe vector.

        Vectors come from ``./sources/glove_300d.dat``. Each usable line is
        ``word v1 ... v300`` separated by single spaces; other lines are
        skipped.

        Args:
            examples: iterable of sentences. Each sentence is iterated
                directly, so it is presumably a list of word tokens; a plain
                string would be iterated character by character. Confirm
                against callers.

        Returns:
            dict: word -> list of 300 floats. Words absent from the file are
            mapped to the vector of ``self.default_word``.

        Raises:
            KeyError: if an out-of-vocabulary word is seen and
                ``self.default_word`` is itself absent from the file.
        """
        Len = 300
        glove_vocabulary = {}

        # `with` closes the file; explicit UTF-8 avoids locale-dependent decoding.
        with open("./sources/glove_" + str(Len) + "d.dat",
                  encoding="utf-8") as glove_file:
            for line in glove_file:
                eles = line.rstrip("\n").split(" ")
                if len(eles) == Len + 1:
                    glove_vocabulary[eles[0]] = [float(ele) for ele in eles[1:]]

        word_glove_map = {}
        for sentence in examples:
            for word in sentence:
                if word in glove_vocabulary:
                    word_glove_map[word] = glove_vocabulary[word]
                else:
                    word_glove_map[word] = glove_vocabulary[self.default_word]

        return word_glove_map
class ElmoEmbeddings(nn.Module):
    """ELMo feature extractor producing per-EDU word embeddings.

    EDU word lists are re-joined into whole sentences before being passed to
    ``ElmoEmbedder``, so every token gets sentence-level context. The
    per-token vectors are then sliced back into EDU spans and zero-padded.
    Each token vector concatenates the three ELMo layer outputs;
    ``get_embed_size`` reports that width as ``3 * 1024``.
    """

    def __init__(self, device_id, dropout, batch_size=128):
        super().__init__()
        self.batch_size = batch_size
        self.elmo = ElmoEmbedder(cuda_device=device_id)
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def _sentence_ends(edu_starts_sentence):
        # EDU k closes a sentence iff EDU k+1 opens one; the final EDU always closes.
        return edu_starts_sentence[1:] + [True]

    def forward(self, indices, raw_text, starts_sentence):
        """Embed every EDU of every document with sentence-level ELMo context.

        Args:
            indices: tensor whose sizes on dims 0/1/2 must match the output's
                document / EDU / word axes (checked by the final asserts).
                Only its sizes are read.
            raw_text: per document, a list of EDUs; each EDU is presumably a
                list of word tokens.
            starts_sentence: per document, a list with one flag per EDU; True
                when that EDU begins a new sentence.

        Returns:
            A 4-D tensor (documents, EDUs, words, embed_size), zero-padded on
            the EDU and word axes, with dropout applied.
        """
        assert len(raw_text) == len(starts_sentence)

        # Stitch EDUs back into full sentences, across all documents in order.
        all_sentences = []
        for edus, edu_starts_sentence in zip(raw_text, starts_sentence):
            pending = []
            for edu_words, closes_sentence in zip(
                    edus, self._sentence_ends(edu_starts_sentence)):
                pending.extend(edu_words)
                if closes_sentence:
                    all_sentences.append(pending)
                    pending = []

        # Embed in chunks of `self.batch_size`. Iterating each B x L x 3E
        # result yields one L x 3E tensor per sentence.
        sentence_embeddings = [
            per_sentence
            for chunk in self.batch_iter(all_sentences, self.batch_size)
            for per_sentence in self._forward(chunk)
        ]

        # Slice the sentence vectors back into EDU spans, document by document.
        sentence_idx = 0
        batch_edu_embeddings = []
        for edus, edu_starts_sentence in zip(raw_text, starts_sentence):
            offset = 0
            spans = []
            for edu_words, closes_sentence in zip(
                    edus, self._sentence_ends(edu_starts_sentence)):
                width = len(edu_words)
                spans.append(
                    sentence_embeddings[sentence_idx][offset:offset + width])
                if closes_sentence:
                    sentence_idx += 1
                    offset = 0
                else:
                    offset += width

            # spans -> Num_edus x longest_edu x embed_size. Then pad the word
            # axis to indices.size(2); a negative amount crops instead.
            doc_embeddings = pad_sequence(spans, batch_first=True, padding_value=0)
            shortfall = indices.size(2) - doc_embeddings.size(1)
            batch_edu_embeddings.append(
                torch.nn.functional.pad(doc_embeddings, (0, 0, 0, shortfall)))

        embeddings = pad_sequence(batch_edu_embeddings,
                                  batch_first=True,
                                  padding_value=0)

        B, E, W, _ = embeddings.size()
        _B, _E, _W = indices.size()
        assert (B, E, W) == (_B, _E, _W)
        return self.dropout(embeddings)

    def forward_for_sentences(self, sentences):
        """Embed tokenized sentences into one tensor padded to the longest one.

        Each mini-batch output is zero-padded on dim 1 up to the token count
        of the longest sentence. The batches are then concatenated on dim 0.

        Raises:
            ValueError: if ``sentences`` is empty (``max`` of an empty sequence).
        """
        longest = max(len(sentence) for sentence in sentences)
        padded_chunks = []
        for chunk in self.batch_iter(sentences, self.batch_size):
            chunk_vectors = self._forward(chunk)
            padded_chunks.append(torch.nn.functional.pad(
                chunk_vectors, (0, 0, 0, longest - chunk_vectors.size(1))))
        return torch.cat(padded_chunks, dim=0)

    def batch_iter(self, iterable, batch_size=1):
        """Yield consecutive slices of ``iterable`` holding at most ``batch_size`` items."""
        total = len(iterable)
        for start in range(0, total, batch_size):
            stop = min(start + batch_size, total)
            yield iterable[start:stop]

    def _forward(self, raw_text):
        """Run ELMo on a batch of tokenized sentences.

        Returns a (batch, length, 3 * dim) tensor in which each token's three
        layer outputs are concatenated.
        """
        layer_outputs, _ = self.elmo.batch_to_embeddings(raw_text)
        batch, _, length, _ = layer_outputs.size()
        # B x 3 x L x E  ->  B x L x 3 x E  ->  B x L x 3E
        return layer_outputs.transpose(1, 2).contiguous().view(batch, length, -1)

    def get_embed_size(self):
        # Three concatenated ELMo layers of 1024 dims each.
        return 3 * 1024