Example #1
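These snippets come from different projects and are shown without their import headers. A minimal common preamble along the following lines is assumed throughout (project-local helpers such as FeedForward, Encoder, or VarMaskedFastLSTM are not covered by it):

import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from allennlp.modules.elmo import Elmo, batch_to_ids
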
class ElmoEmbeddings(nn.Module):
    """
    Embedding layer that returns deep contextualized word representations (ELMo; see https://arxiv.org/abs/1802.05365).
    """
    def __init__(self, word2ix, layers=3, dropout=0.):
        """
        :param word2ix: Dictionary that maps token words to indexes.
        :param layers: (int) Number of ELMo layers to extract. (3 is recommended)
        :param dropout: (float) Dropout to be applied to the ELMo model.
        """
        super(ElmoEmbeddings, self).__init__()
        self.elmo = Elmo(options_file, weight_file, layers, dropout=dropout)
        self.task_weights = nn.Parameter(
            torch.tensor(np.random.random(layers), dtype=torch.float32))
        self.task_scalar = nn.Parameter(
            torch.tensor(np.random.random(1), dtype=torch.float32))
        self.ix2word = {v: k for k, v in word2ix.items()}
        self.size = 1024

    def forward(self, inputs):
        """
        :param inputs: torch.tensor (Batch x maxlength) with vectorized and padded sequences.
        """
        character_ids = batch_to_ids(
            [[self.ix2word[ix.item()] for ix in sequence if ix != 0]
             for sequence in inputs])
        embeddings = self.elmo(character_ids.cuda())
        # Scale each layer according to task-specific softmax-normalized weights.
        softmax_norm_weights = self.task_weights.softmax(0)
        weighted_embeddings = softmax_norm_weights[0] * embeddings[
            'elmo_representations'][0]
        for i in range(1, len(embeddings['elmo_representations'])):
            weighted_embeddings += softmax_norm_weights[i] * embeddings[
                'elmo_representations'][i]
        # Finally, scale the combined embedding by a task-specific scalar.
        return weighted_embeddings * self.task_scalar[0], embeddings['mask']

    def freeze(self):
        """
        Freezes ELMo.
        NOTE: this does not freeze the task-specific weights used to fine-tune the embeddings.
        """
        for param in self.elmo.parameters():
            param.requires_grad = False

    def unfreeze(self):
        """
        Unfreezes ELMo for fine-tuning.
        """
        for param in self.elmo.parameters():
            param.requires_grad = True
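A minimal usage sketch for ElmoEmbeddings, under stated assumptions: options_file and weight_file are module-level names that __init__ reads from the enclosing scope (placeholder paths here), the vocabulary is a toy one, and a CUDA device is available because forward() hard-codes .cuda():

# Hypothetical module-level paths read by ElmoEmbeddings.__init__.
options_file = 'elmo_options.json'
weight_file = 'elmo_weights.hdf5'

word2ix = {'<pad>': 0, 'the': 1, 'cat': 2, 'sat': 3}
embedder = ElmoEmbeddings(word2ix, layers=3, dropout=0.1).cuda()
embedder.freeze()  # train only the task-specific mixing weights

batch = torch.tensor([[1, 2, 3], [1, 2, 0]])  # 0 marks padding
weighted, mask = embedder(batch)  # weighted: (batch, max_len, 1024)
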
Example #2

class ElmoLayer(nn.Module):
    def __init__(self, char_table, conf):
        super(ElmoLayer, self).__init__()
        self.conf = conf
        lookup, length = char_table
        self.lookup = nn.Embedding(lookup.size(0), lookup.size(1))
        self.lookup.weight.data.copy_(lookup)
        self.lookup.weight.requires_grad = False
        self.elmo = Elmo(os.path.expanduser(self.conf.elmo_options),
                         os.path.expanduser(self.conf.elmo_weights),
                         num_output_representations=2,
                         do_layer_norm=False,
                         dropout=self.conf.embed_dropout)
        for p in self.elmo.parameters():
            p.requires_grad = False
        self.w = nn.Parameter(torch.Tensor([0.5, 0.5]))
        self.gamma = nn.Parameter(torch.ones(1))
        self.conv = nn.Conv1d(1024, self.conf.elmo_dim, 1)
        nn.init.xavier_uniform_(self.conv.weight)
        self.conv.bias.data.fill_(0)

    def forward(self, input):
        charseq = self.lookup(input).long()
        res = self.elmo(charseq)['elmo_representations']
        w = F.softmax(self.w, dim=0)
        res = self.gamma * (w[0] * res[0] + w[1] * res[1])
        res = self.conv(res.transpose(1, 2)).transpose(1, 2)
        return res
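A wiring sketch for ElmoLayer, with every value illustrative: conf mirrors the fields the class reads, and the character table is built here from batch_to_ids for a toy two-word vocabulary (in the source project it would come from preprocessing):

from types import SimpleNamespace

# Hypothetical configuration; field names follow the class above.
conf = SimpleNamespace(elmo_options='~/elmo/options.json',
                       elmo_weights='~/elmo/weights.hdf5',
                       embed_dropout=0.1,
                       elmo_dim=300)

# (vocab_size, 50) table of ELMo character IDs, one row per word.
char_table = batch_to_ids([['hello'], ['world']]).squeeze(1).float()
layer = ElmoLayer((char_table, None), conf)

word_ids = torch.tensor([[0, 1]])  # indices into the character table
out = layer(word_ids)              # (1, 2, conf.elmo_dim)
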
Example #3

class QuestionFocusModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        elmo_path = config['elmo']
        elmo_option_file = os.path.join(
            elmo_path, "elmo_2x4096_512_2048cnn_2xhighway_5.5B_options.json")
        elmo_weight_file = os.path.join(
            elmo_path, "elmo_2x4096_512_2048cnn_2xhighway_5.5B_weights.hdf5")
        self.elmo = Elmo(elmo_option_file, elmo_weight_file, 2)
        for p in self.elmo.parameters():
            p.requires_grad = False

        self.elmo_reduction = FeedForward(1024 * 2, 300)

        self.v_cls = nn.Parameter(torch.randn([300, 1]).float())
        self.k = FeedForward(300, 300)
        self.v = FeedForward(300, 300)
        self.softmax_simi = nn.Softmax(dim=1)

    def forward(self, tensor_input):
        elmo_embed = self.elmo(tensor_input)['elmo_representations']
        cat_embd = torch.cat(elmo_embed, 2)
        redu_embd = self.elmo_reduction(cat_embd)
        linear_key = self.k(redu_embd)
        linear_value = self.v(redu_embd)
        dotProductSimi = linear_key.matmul(self.v_cls)
        normedSimi = self.softmax_simi(dotProductSimi)
        attVector = linear_value.mul(normedSimi)
        weightedSum = torch.sum(attVector, dim=1)
        return weightedSum, dotProductSimi
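The pooling above is single-query attention: each token's key is scored against the learned vector v_cls, the scores are softmax-normalized over the token axis, and the values are summed. A standalone sketch of the same computation with illustrative shapes:

B, T, D = 2, 5, 300
x = torch.randn(B, T, D)           # reduced ELMo embeddings
v_cls = torch.randn(D, 1)          # learned query vector

scores = x.matmul(v_cls)           # (B, T, 1) dot-product similarity
weights = scores.softmax(dim=1)    # normalize over tokens
pooled = (x * weights).sum(dim=1)  # (B, D) sentence representation
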
Example #4

class ElmoEmbedding(nn.Module):
    def __init__(self, config):
        super().__init__()
        elmo_path = config['elmo']
        elmo_option_file = os.path.join(
            elmo_path, "elmo_2x4096_512_2048cnn_2xhighway_5.5B_options.json")
        elmo_weight_file = os.path.join(
            elmo_path, "elmo_2x4096_512_2048cnn_2xhighway_5.5B_weights.hdf5")
        self.elmo = Elmo(elmo_option_file, elmo_weight_file, 2)
        for p in self.elmo.parameters():
            p.requires_grad = False

    def forward(self, x):
        elmo_embed = self.elmo(x)['elmo_representations']
        cat_embd = torch.cat(elmo_embed, 2)
        return cat_embd
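Because this wrapper requests two output representations and concatenates them along the feature axis, the output is 2 * 1024 = 2048-dimensional. A hedged usage sketch (the config key matches what the class reads; the path is a placeholder):

config = {'elmo': './models/elmo_5.5B'}  # directory holding the two files
embedder = ElmoEmbedding(config)

character_ids = batch_to_ids([['The', 'cat', 'sat']])
embedding = embedder(character_ids)  # (1, 3, 2048)
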
Example #5
class Model(nn.Module):
    def __init__(self, opt):

        super(Model, self).__init__()
        self.opt = opt
        self.use_gpu = self.opt.use_gpu

        if opt.emb_method == 'elmo':
            self.init_elmo()
        elif self.opt.emb_method == 'glove':
            self.init_glove()
        elif self.opt.emb_method == 'bert':
            self.init_bert()

        self.encoder = Encoder(opt.enc_method, self.word_dim, opt.hidden_size,
                               opt.out_size)
        self.cls = nn.Linear(opt.out_size, opt.num_labels)
        nn.init.uniform_(self.cls.weight, -0.1, 0.1)
        nn.init.uniform_(self.cls.bias, -0.1, 0.1)
        self.dropout = nn.Dropout(self.opt.dropout)

    def forward(self, x):
        if self.opt.emb_method == 'elmo':
            word_embs = self.get_elmo(x)
        elif self.opt.emb_method == 'glove':
            word_embs = self.get_glove(x)
        elif self.opt.emb_method == 'bert':
            word_embs = self.get_bert(x)

        x = self.encoder(word_embs)
        x = self.dropout(x)
        x = self.cls(x)  # batch_size * num_label
        return x

    def init_bert(self):
        '''
        initialize the BERT model
        '''
        self.bert_tokenizer = AutoTokenizer.from_pretrained(self.opt.bert_path)
        self.bert = AutoModel.from_pretrained(self.opt.bert_path)
        for param in self.bert.parameters():
            param.requires_grad = False
        self.word_dim = self.opt.bert_dim

    def init_elmo(self):
        '''
        initialize the ELMo model
        '''
        self.elmo = Elmo(self.opt.elmo_options_file, self.opt.elmo_weight_file,
                         1)
        for param in self.elmo.parameters():
            param.requires_grad = False
        self.word_dim = self.opt.elmo_dim

    def init_glove(self):
        '''
        load the GloVe embeddings
        '''
        self.word2id = np.load(self.opt.word2id_file,
                               allow_pickle=True).tolist()
        self.glove = nn.Embedding(self.opt.vocab_size, self.opt.glove_dim)
        emb = torch.from_numpy(np.load(self.opt.glove_file, allow_pickle=True))
        if self.use_gpu:
            emb = emb.to(self.opt.device)
        self.glove.weight.data.copy_(emb)
        self.word_dim = self.opt.glove_dim

    def get_bert(self, sentence_lists):
        '''
        get the BERT word embedding vectors for a batch of sentences
        '''
        sentence_lists = [' '.join(x) for x in sentence_lists]
        ids = self.bert_tokenizer(sentence_lists,
                                  padding=True,
                                  return_tensors="pt")
        inputs = ids['input_ids']
        if self.opt.use_gpu:
            inputs = inputs.to(self.opt.device)

        embeddings = self.bert(inputs)
        return embeddings[0]

    def get_elmo(self, sentence_lists):
        '''
        get the ELMo word embedding vectors for a batch of sentences
        '''
        character_ids = batch_to_ids(sentence_lists)
        if self.opt.use_gpu:
            character_ids = character_ids.to(self.opt.device)
        embeddings = self.elmo(character_ids)
        return embeddings['elmo_representations'][0]

    def get_glove(self, sentence_lists):
        '''
        get the GloVe word embedding vectors for a batch of sentences
        '''
        max_len = max(map(lambda x: len(x), sentence_lists))
        sentence_lists = list(
            map(lambda x: list(map(lambda w: self.word2id.get(w, 0), x)),
                sentence_lists))
        sentence_lists = list(
            map(lambda x: x + [self.opt.vocab_size - 1] * (max_len - len(x)),
                sentence_lists))
        sentence_lists = torch.LongTensor(sentence_lists)
        if self.use_gpu:
            sentence_lists = sentence_lists.to(self.opt.device)
        embeddings = self.glove(sentence_lists)

        return embeddings
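A sketch of driving Model through an options namespace, here selecting the ELMo path. The field names mirror what the class reads, but the values are placeholders, and since Encoder is a project-local class, 'lstm' as enc_method is an assumption:

from argparse import Namespace

# Hypothetical options; only the fields the ELMo path touches are meaningful.
opt = Namespace(emb_method='elmo', enc_method='lstm',
                elmo_options_file='elmo_options.json',
                elmo_weight_file='elmo_weights.hdf5',
                elmo_dim=1024, hidden_size=256, out_size=128,
                num_labels=2, dropout=0.5, use_gpu=False, device='cpu')

model = Model(opt)
logits = model([['a', 'sample', 'sentence']])  # (1, num_labels)
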
Example #6
class ElmoGP(nn.Module):
    def __init__(self,
                 word_dim,
                 num_words,
                 char_dim,
                 num_chars,
                 pos_dim,
                 num_pos,
                 num_xpos,
                 num_filters,
                 kernel_size,
                 hidden_size,
                 num_layers,
                 num_labels,
                 arc_space,
                 type_space,
                 embed_word=None,
                 p_in=0.33,
                 p_out=0.33,
                 p_rnn=(0.33, 0.33),
                 biaffine=True,
                 pos=True,
                 char=True,
                 init_emb=False,
                 prelstm_args=None,
                 elmo=False,
                 lattice=None,
                 delex=False,
                 word_dictionary=None,
                 char_dictionary=None,
                 use_gpu=False):
        super(ElmoGP, self).__init__()

        self.word_embed = nn.Embedding(num_words,
                                       word_dim) if delex == False else None
        self.pos_embed = nn.Embedding(num_pos, pos_dim) if pos else None
        self.xpos_embed = nn.Embedding(num_xpos, pos_dim) if pos else None
        self.char_embed = nn.Embedding(
            num_words, char_dim) if delex == False and char else None
        self.word_dictionary = word_dictionary

        self.conv1d = nn.Conv1d(
            char_dim, num_filters, kernel_size, padding=kernel_size -
            1) if delex == False and char else None
        self.dropout_in = nn.Dropout2d(p=p_in)
        self.dropout_out = nn.Dropout2d(p=p_out)
        self.num_labels = num_labels
        self.pos = pos
        self.char = char
        self.elmo_use = elmo
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.lattice = lattice
        self.delex = delex
        self.init_lstm = prelstm_args is not None and os.path.exists(
            prelstm_args)
        self.use_gpu = use_gpu

        dim_enc = 0
        if delex == False:
            dim_enc = word_dim
        if self.pos:
            dim_enc += pos_dim
        if delex == False and self.char:
            dim_enc += num_filters
        if self.elmo_use:
            self.prelstm_args = None
            if prelstm_args.endswith('hdf5'):
                self.prelstm_args = {}
                self.prelstm_args['hidden_size'] = 512
                self.prelstm_args['num_layers'] = 1
                self.prelstm_args['type'] = 'allen'
            else:
                self.prelstm_args = json.load(open(prelstm_args))
                self.prelstm_args['type'] = 'inhouse'
            dim_enc += 2 * self.prelstm_args['hidden_size']
        if lattice:
            if lattice[1]:
                dim_enc += lattice[0].morph_value_dictionary.size()
            elif lattice[2] != -1:
                dim_enc += lattice[2]
        print('word feature size = %d' % (dim_enc))

        self.rnn = VarMaskedFastLSTM(dim_enc,
                                     hidden_size,
                                     num_layers=num_layers,
                                     batch_first=True,
                                     bidirectional=True,
                                     dropout=p_rnn)

        out_dim = hidden_size * 2
        self.arc_h = nn.Linear(out_dim, arc_space)
        self.arc_c = nn.Linear(out_dim, arc_space)
        self.attention = BiAttention(arc_space,
                                     arc_space,
                                     1,
                                     biaffine=biaffine)

        self.type_h = nn.Linear(out_dim, type_space)
        self.type_c = nn.Linear(out_dim, type_space)
        self.bilinear = BiLinear(type_space, type_space, self.num_labels)

        if init_emb:
            self.initialize_embed(embed_word, pos_dim, char_dim)

        if lattice:
            if lattice[2] != -1:
                self.morpho_specific_embed = nn.Embedding(
                    lattice[0].morph_value_dictionary.size(), lattice[2])
                if lattice[3] == 'Group':
                    self.morpho_group_scalar = nn.Embedding(
                        lattice[0].morph_class_dictionary.size(), 1)
                elif lattice[3] == 'Specific':
                    self.morpho_specific_scalar = nn.Embedding(
                        lattice[0].morph_value_dictionary.size(), 1)
        #if self.init_lstm:
        #  self.initialize_lstm(prelstm_args[0:prelstm_args.rfind('/')])
        if self.elmo_use:
            print('launching elmo layers from %s' % prelstm_args)
            if prelstm_args.endswith('hdf5'):
                options_file = prelstm_args[0:prelstm_args.rfind("/") +
                                            1] + "options.json"
                self.elmo = Elmo(
                    options_file, prelstm_args, 1, dropout=p_in
                )  # add scalar_mix_parameters=[1,1,1] to do unweighted average of different layers
            else:
                self.elmo = Elmo(prelstm_args, p_rnn, kernel_size, p_in,
                                 word_dictionary, char_dictionary)
            if not prelstm_args.endswith('hdf5'):
                self.scalar_parameters = ParameterList([
                    Parameter(torch.FloatTensor([0.0]))
                    for _ in range(self.num_layers + 1)
                ])
                self.gamma = Parameter(torch.FloatTensor([1.0]))
                for param in self.elmo.parameters():
                    param.requires_grad = False

    def initialize_embed(self, embed_word, pos_dim, char_dim):
        if self.delex == False:
            self.word_embed.weight.data = embed_word
        if self.pos:
            self.pos_embed.weight.data.uniform_(-(3.0 / pos_dim),
                                                (3.0 / pos_dim))
            self.xpos_embed.weight.data.uniform_(-(3.0 / pos_dim),
                                                 (3.0 / pos_dim))
        if self.delex == False and self.char:
            self.char_embed.weight.data.uniform_(-(3.0 / char_dim),
                                                 (3.0 / char_dim))

    def initialize_lstm(self, lstm_dir):
        # load pretrained lstm model
        print('initializing pre-trained lstm...')
        model_id = max([
            int(file.split('/')[-1].split('_')[-1].split('.')[0])
            for file in glob.glob(os.path.join(lstm_dir, 'model_*.pt'))
        ])
        model_id = str(model_id)
        model_path = os.path.join(lstm_dir,
                                  'model_epoch_' + str(model_id) + '.pt')
        pretrained_dict = torch.load(model_path,
                                     map_location=lambda storage, loc: storage)

        # initialize the rnn parts
        rnn_dict = self.rnn.state_dict()
        new_dict = {}
        for k, v in pretrained_dict.items():
            if 'rnn' in k:
                k = k[len('rnn.'):]
                assert (k in rnn_dict)
                new_dict[k] = v
        rnn_dict.update(new_dict)
        self.rnn.load_state_dict(new_dict)

    def _get_rnn_output(self,
                        input_word,
                        input_char,
                        input_pos,
                        input_xpos,
                        mask=None,
                        length=None,
                        hx=None,
                        input_morph=None,
                        vocab_expand=None):
        features = []
        if self.delex == False:
            # [batch, length, word_dim]
            word = self.word_embed(input_word)
            if vocab_expand is not None and vocab_expand[0] is not None:
                # expand word vocab
                oov_embed_dict, raw_words = vocab_expand
                for bi in range(input_word.size(0)):
                    for wi in range(input_word.size(1)):
                        if input_word[bi][wi] == 0:
                            assert (mask[bi][wi] == 1)
                            assert (wi < len(raw_words[bi]))
                            if raw_words[bi][wi] in oov_embed_dict:
                                word[bi][wi] = oov_embed_dict[raw_words[bi]
                                                              [wi]]
            # apply dropout on input
            word = self.dropout_in(word)
            features.append(word)

        if self.delex == False and self.char:
            # [batch, length, char_length, char_dim]
            char = self.char_embed(input_char)
            char_size = char.size()
            # first transform to [batch * length, char_length, char_dim]
            # then transpose to [batch * length, char_dim, char_length]
            char = char.view(char_size[0] * char_size[1], char_size[2],
                             char_size[3]).transpose(1, 2)
            # put into cnn [batch*length, char_filters, char_length]
            # then put into maxpooling [batch * length, char_filters]
            char, _ = self.conv1d(char).max(dim=2)
            # reshape to [batch, length, char_filters]
            char = torch.tanh(char).view(char_size[0], char_size[1], -1)
            # apply dropout on input
            char = self.dropout_in(char)
            # concatenate word and char [batch, length, word_dim+char_filter]
            features.append(char)

        if self.pos:
            # [batch, length, pos_dim]
            pos = self.pos_embed(input_pos)
            xpos = self.xpos_embed(input_xpos)
            # apply dropout on input
            pos = self.dropout_in(pos)
            xpos = self.dropout_in(xpos)
            pos = pos + xpos
            features.append(pos)

        if self.elmo_use:
            if self.prelstm_args['type'] == 'allen':
                sentences, elmo_embed = [], None
                max_v, max_id = length.max(0)
                for bi in range(input_word.size(0)):
                    sent = []
                    for wi in range(input_word.size(1)):
                        if input_word[bi][wi] == 0:
                            sent.append("<<unk>>")
                        elif input_word[bi][wi] == 2:
                            sent.append(".")
                        elif input_word[bi][wi] > 2:
                            sent.append(
                                self.word_dictionary.get_instance(
                                    input_word[bi][wi]).replace(' ', '_'))
                    if max_id.item() == bi:
                        num_pads = input_word.size(1) - len(sent)
                        for pi in range(num_pads):
                            sent.append(
                                "."
                            )  # adding to account for 1 extra pad per every instance
                    sentences.append(sent)
                character_ids = batch_to_ids(sentences)
                if self.use_gpu:
                    character_ids = character_ids.cuda()
                elmo_embed = self.elmo(
                    character_ids)["elmo_representations"][0]
                features.append(elmo_embed)
            else:
                hid, word = self.elmo(input_word, input_char, input_pos, mask,
                                      length, hx)
                hid = torch.stack(hid)
                hid = hid.view(-1, word.size(1),
                               self.prelstm_args['hidden_size'] * 2,
                               self.prelstm_args['num_layers'])
                elmo_embed = torch.cat([hid, word.unsqueeze(3)], dim=3)
                normed_weights = torch.nn.functional.softmax(torch.cat(
                    [parameter for parameter in self.scalar_parameters]),
                                                             dim=0)
                layer_sum = torch.matmul(
                    elmo_embed, normed_weights.unsqueeze(1)) * self.gamma
                layer_sum = layer_sum.squeeze(3)
                layer_sum = self.dropout_in(layer_sum)
                features.append(layer_sum)

        if self.lattice:
            if self.lattice[1]:
                features.append(input_morph)
            elif self.lattice[2] != -1:
                morph_specific, morph_group, morph_offsets = input_morph
                bsize, max_sent_len, max_morph_len = morph_specific.size()
                morph_master = []
                if self.lattice[3] == 'None':
                    for bi in range(bsize):
                        for si in range(max_sent_len):
                            morph_offset = morph_offsets[bi][si]
                            morph_feats = self.morpho_specific_embed(
                                morph_specific[bi][si]
                                [0:morph_offset.item()]).mean(0)
                            morph_master.append(morph_feats)
                elif self.lattice[3] == 'Specific':
                    for bi in range(bsize):
                        for si in range(max_sent_len):
                            morph_offset = morph_offsets[bi][si]
                            morph_feats = self.morpho_specific_embed(
                                morph_specific[bi][si][0:morph_offset.item()])
                            morph_scalars = self.morpho_specific_scalar(
                                morph_specific[bi][si]
                                [0:morph_offset.item()]).view(
                                    morph_offset.item())
                            morph_scalars = nn.Softmax(dim=0)(morph_scalars)
                            morph_master.append(
                                torch.matmul(morph_scalars, morph_feats))
                elif self.lattice[3] == 'Group':
                    for bi in range(bsize):
                        for si in range(max_sent_len):
                            morph_offset = morph_offsets[bi][si]
                            morph_feats = self.morpho_specific_embed(
                                morph_specific[bi][si][0:morph_offset.item()])
                            morph_scalars = self.morpho_group_scalar(
                                morph_group[bi][si]
                                [0:morph_offset.item()]).view(
                                    morph_offset.item())
                            morph_scalars = nn.Softmax(dim=0)(morph_scalars)
                            morph_master.append(
                                torch.matmul(morph_scalars, morph_feats))
                morph_master = torch.stack(morph_master).view(
                    bsize, max_sent_len, -1)
                morph_master = self.dropout_in(morph_master)
                features.append(morph_master)

        input = features[0]
        for i in range(len(features) - 1):
            input = torch.cat([input, features[i + 1]], dim=2)

        # output from rnn [batch, length, hidden_size]
        output, hn, _ = self.rnn(input, mask, hx=hx)

        # apply dropout for output
        # [batch, length, hidden_size] --> [batch, hidden_size, length] --> [batch, length, hidden_size]
        output = self.dropout_out(output.transpose(1, 2)).transpose(1, 2)

        # output size [batch, length, arc_space]
        arc_h = F.elu(self.arc_h(output))
        arc_c = F.elu(self.arc_c(output))

        # output size [batch, length, type_space]
        type_h = F.elu(self.type_h(output))
        type_c = F.elu(self.type_c(output))

        # apply dropout
        # [batch, length, dim] --> [batch, 2 * length, dim]
        arc = torch.cat([arc_h, arc_c], dim=1)
        type = torch.cat([type_h, type_c], dim=1)

        arc = self.dropout_out(arc.transpose(1, 2)).transpose(1, 2)
        arc_h, arc_c = arc.chunk(2, 1)

        type = self.dropout_out(type.transpose(1, 2)).transpose(1, 2)
        type_h, type_c = type.chunk(2, 1)
        type_h = type_h.contiguous()
        type_c = type_c.contiguous()

        return (arc_h, arc_c), (type_h, type_c), hn, mask, length

    def forward(self,
                input_word,
                input_char,
                input_pos,
                input_xpos,
                mask=None,
                length=None,
                hx=None,
                input_morph=None,
                vocab_expand=None):
        # output from rnn [batch, length, tag_space]
        arc, type, _, mask, length = self._get_rnn_output(
            input_word,
            input_char,
            input_pos,
            input_xpos,
            mask=mask,
            length=length,
            hx=hx,
            input_morph=input_morph,
            vocab_expand=vocab_expand)
        # [batch, length, length]
        out_arc = self.attention(arc[0], arc[1], mask_d=mask,
                                 mask_e=mask).squeeze(dim=1)
        return out_arc, type, mask, length

    def loss(self,
             input_word,
             input_char,
             input_pos,
             input_xpos,
             heads,
             types,
             mask=None,
             length=None,
             hx=None,
             input_morph=None):
        # out_arc shape [batch, length, length]
        out_arc, out_type, mask, length = self.forward(input_word,
                                                       input_char,
                                                       input_pos,
                                                       input_xpos,
                                                       mask=mask,
                                                       length=length,
                                                       hx=hx,
                                                       input_morph=input_morph)
        batch, max_len, _ = out_arc.size()

        if length is not None and heads.size(1) != mask.size(1):
            heads = heads[:, :max_len]
            types = types[:, :max_len]

        # out_type shape [batch, length, type_space]
        type_h, type_c = out_type

        # create batch index [batch]
        batch_index = torch.arange(0, batch).type_as(out_arc.data).long()
        # get vector for heads [batch, length, type_space]
        type_h = type_h[batch_index, heads.data.t()].transpose(0,
                                                               1).contiguous()
        # compute output for type [batch, length, num_labels]
        out_type = self.bilinear(type_h, type_c)

        # mask invalid position to -inf for log_softmax
        if mask is not None:
            minus_inf = -1e8
            minus_mask = (1 - mask) * minus_inf
            out_arc = out_arc + minus_mask.unsqueeze(2) + minus_mask.unsqueeze(
                1)

        # loss_arc shape [batch, length, length]
        loss_arc = F.log_softmax(out_arc, dim=1)
        # loss_type shape [batch, length, num_labels]
        loss_type = F.log_softmax(out_type, dim=2)

        # mask invalid position to 0 for sum loss
        if mask is not None:
            loss_arc = loss_arc * mask.unsqueeze(2) * mask.unsqueeze(1)
            loss_type = loss_type * mask.unsqueeze(2)
            # number of valid positions which contribute to loss (remove the symbolic head for each sentence).
            num = mask.sum() - batch
        else:
            # number of valid positions which contribute to loss (remove the symbolic head for each sentence).
            num = float(max_len - 1) * batch

        # first create index matrix [length, batch]
        child_index = torch.arange(0, max_len).view(max_len,
                                                    1).expand(max_len, batch)
        child_index = child_index.type_as(out_arc.data).long()
        # [length-1, batch]
        loss_arc = loss_arc[batch_index, heads.data.t(), child_index][1:]
        loss_type = loss_type[batch_index, child_index, types.data.t()][1:]

        return -loss_arc.sum() / num, -loss_type.sum() / num

    def decode(self,
               input_word,
               input_char,
               input_pos,
               input_xpos,
               mask=None,
               length=None,
               hx=None,
               leading_symbolic=0,
               decode='mst',
               input_morph=None,
               vocab_expand=None):
        if decode == 'mst':
            return self.decode_mst(input_word, input_char, input_pos,
                                   input_xpos, mask, length, hx,
                                   leading_symbolic, input_morph, vocab_expand)
        return self.decode_greedy(input_word, input_char, input_pos,
                                  input_xpos, mask, length, hx,
                                  leading_symbolic, input_morph)

    def _decode_types(self, out_type, heads, leading_symbolic):
        # out_type shape [batch, length, type_space]
        type_h, type_c = out_type
        batch, max_len, _ = type_h.size()
        # create batch index [batch]
        batch_index = torch.arange(0, batch).type_as(type_h.data).long()
        # get vector for heads [batch, length, type_space],
        type_h = type_h[batch_index, heads.t()].transpose(0, 1).contiguous()
        # compute output for type [batch, length, num_labels]
        out_type = self.bilinear(type_h, type_c)
        # remove the first #leading_symbolic types.
        out_type = out_type[:, :, leading_symbolic:]
        # compute the prediction of types [batch, length]
        _, types = out_type.max(dim=2)
        return types + leading_symbolic

    def decode_greedy(self,
                      input_word,
                      input_char,
                      input_pos,
                      input_xpos,
                      mask=None,
                      length=None,
                      hx=None,
                      leading_symbolic=0,
                      input_morph=None):
        # out_arc shape [batch, length, length]
        out_arc, out_type, mask, length = self.forward(input_word,
                                                       input_char,
                                                       input_pos,
                                                       input_xpos,
                                                       mask=mask,
                                                       length=length,
                                                       hx=hx,
                                                       input_morph=input_morph)
        out_arc = out_arc.data
        batch, max_len, _ = out_arc.size()
        # set diagonal elements to -inf
        out_arc = out_arc + torch.diag(out_arc.new(max_len).fill_(-np.inf))
        # set invalid positions to -inf
        if mask is not None:
            # minus_mask = (1 - mask.data).byte().view(batch, max_len, 1)
            minus_mask = (1 - mask.data).bool().unsqueeze(2)
            out_arc.masked_fill_(minus_mask, -np.inf)

        # compute naive predictions.
        # prediction shape = [batch, length]
        _, heads = out_arc.max(dim=1)

        types = self._decode_types(out_type, heads, leading_symbolic)

        return heads.cpu().numpy(), types.data.cpu().numpy()

    def decode_mst(self,
                   input_word,
                   input_char,
                   input_pos,
                   input_xpos,
                   mask=None,
                   length=None,
                   hx=None,
                   leading_symbolic=0,
                   input_morph=None,
                   vocab_expand=None):
        '''
    Args:
        input_word: Tensor
            the word input tensor with shape = [batch, length]
        input_char: Tensor
            the character input tensor with shape = [batch, length, char_length]
        input_pos: Tensor
            the pos input tensor with shape = [batch, length]
        mask: Tensor or None
            the mask tensor with shape = [batch, length]
        length: Tensor or None
            the length tensor with shape = [batch]
        hx: Tensor or None
            the initial states of RNN
        leading_symbolic: int
            number of symbolic labels leading in type alphabets (set it to 0 if you are not sure)

    Returns: (Tensor, Tensor)
            predicted heads and types.

    '''
        # out_arc shape [batch, length, length]
        out_arc, out_type, mask, length = self.forward(
            input_word,
            input_char,
            input_pos,
            input_xpos,
            mask=mask,
            length=length,
            hx=hx,
            input_morph=input_morph,
            vocab_expand=vocab_expand)

        # out_type shape [batch, length, type_space]
        type_h, type_c = out_type
        batch, max_len, type_space = type_h.size()

        # compute lengths
        if length is None:
            if mask is None:
                length = [max_len for _ in range(batch)]
            else:
                length = mask.data.sum(dim=1).long().cpu().numpy()
        type_h = type_h.unsqueeze(2).expand(batch, max_len, max_len,
                                            type_space).contiguous()
        type_c = type_c.unsqueeze(1).expand(batch, max_len, max_len,
                                            type_space).contiguous()
        # compute output for type [batch, length, length, num_labels]
        out_type = self.bilinear(type_h, type_c)

        # mask invalid position to -inf for log_softmax
        if mask is not None:
            minus_inf = -1e8
            minus_mask = (1 - mask) * minus_inf
            out_arc = out_arc + minus_mask.unsqueeze(2) + minus_mask.unsqueeze(
                1)

        # loss_arc shape [batch, length, length]
        loss_arc = F.log_softmax(out_arc, dim=1)
        # loss_type shape [batch, length, length, num_labels]
        loss_type = F.log_softmax(out_type, dim=3).permute(0, 3, 1, 2)
        # [batch, num_labels, length, length]
        energy = torch.exp(loss_arc.unsqueeze(1) + loss_type)

        return self._decode_MST(energy.data.cpu().numpy(),
                                length,
                                leading_symbolic=leading_symbolic,
                                labeled=True)

    def _decode_MST(self, energies, lengths, leading_symbolic=0, labeled=True):
        """
    decode best parsing tree with MST algorithm.
    :param energies: numpy 4D tensor
        energies of each edge. the shape is [batch_size, num_labels, n_steps, n_steps],
        where the dummy root is at index 0.
    :param masks: numpy 2D tensor
        masks in the shape [batch_size, 1].
    :param leading_symbolic: int
        number of symbolic dependency types leading in type alphabets
    :return:
    """
        input_shape = energies.shape
        batch_size = input_shape[0]
        max_length = input_shape[2]

        pars = np.full((batch_size, max_length), -1, dtype=np.int32)
        types = np.full((batch_size, max_length), -1, dtype=np.int32)
        for i in range(batch_size):
            energy = energies[i]
            length = lengths[i]

            # calc real energy matrix shape = [length, length, num_labels - #symbolic] (remove the label for symbolic types).
            energy = energy[leading_symbolic:, :length, :length]
            label_id_matrix = energy.argmax(axis=0) + leading_symbolic
            energy = energy.max(axis=0)
            score_matrix = np.array(energy, copy=True)
            np.fill_diagonal(score_matrix, -1.)
            score_matrix[:, 0] = -1.
            head_pred = mst(score_matrix.T)
            pars[i][1:length] = head_pred[1:]
            for j in range(1, length):
                types[i][j] = label_id_matrix[head_pred[j]][j]
        return pars, types

    def __decode_MST(self,
                     energies,
                     lengths,
                     leading_symbolic=0,
                     labeled=True):
        """
    decode best parsing tree with MST algorithm.
    :param energies: numpy 4D tensor
        energies of each edge. the shape is [batch_size, num_labels, n_steps, n_steps],
        where the dummy root is at index 0.
    :param masks: numpy 2D tensor
        masks in the shape [batch_size, n_steps].
    :param leading_symbolic: int
        number of symbolic dependency types leading in type alphabets
    :return:
    """
        def find_cycle(par):
            added = np.zeros([length], dtype=bool)
            added[0] = True
            cycle = set()
            findcycle = False
            for i in range(1, length):
                if findcycle:
                    break

                if added[i] or not curr_nodes[i]:
                    continue

                # init cycle
                tmp_cycle = set()
                tmp_cycle.add(i)
                added[i] = True
                findcycle = True
                l = i

                while par[l] not in tmp_cycle:
                    l = par[l]
                    if added[l]:
                        findcycle = False
                        break
                    added[l] = True
                    tmp_cycle.add(l)

                if findcycle:
                    lorg = l
                    cycle.add(lorg)
                    l = par[lorg]
                    while l != lorg:
                        cycle.add(l)
                        l = par[l]
                    break

            return findcycle, cycle

        def chuLiuEdmonds():
            par = np.zeros([length], dtype=np.int32)
            # create best graph
            par[0] = -1
            for i in range(1, length):
                # only interested at current nodes
                if curr_nodes[i]:
                    max_score = score_matrix[0, i]
                    par[i] = 0
                    for j in range(1, length):
                        if j == i or not curr_nodes[j]:
                            continue
                        new_score = score_matrix[j, i]
                        if new_score > max_score:
                            max_score = new_score
                            par[i] = j

            # find a cycle
            findcycle, cycle = find_cycle(par)
            # no cycles, get all edges and return them.
            if not findcycle:
                final_edges[0] = -1
                for i in range(1, length):
                    if not curr_nodes[i]:
                        continue
                    pr = oldI[par[i], i]
                    ch = oldO[par[i], i]
                    final_edges[ch] = pr
                return

            cyc_len = len(cycle)
            cyc_weight = 0.0
            cyc_nodes = np.zeros([cyc_len], dtype=np.int32)
            id = 0
            for cyc_node in cycle:
                cyc_nodes[id] = cyc_node
                id += 1
                cyc_weight += score_matrix[par[cyc_node], cyc_node]

            rep = cyc_nodes[0]
            for i in range(length):
                if not curr_nodes[i] or i in cycle:
                    continue

                max1 = float("-inf")
                wh1 = -1
                max2 = float("-inf")
                wh2 = -1

                for j in range(cyc_len):
                    j1 = cyc_nodes[j]
                    if score_matrix[j1, i] > max1:
                        max1 = score_matrix[j1, i]
                        wh1 = j1
                    scr = cyc_weight + score_matrix[i,
                                                    j1] - score_matrix[par[j1],
                                                                       j1]
                    if scr > max2:
                        max2 = scr
                        wh2 = j1

                score_matrix[rep, i] = max1
                oldI[rep, i] = oldI[wh1, i]
                oldO[rep, i] = oldO[wh1, i]
                score_matrix[i, rep] = max2
                oldO[i, rep] = oldO[i, wh2]
                oldI[i, rep] = oldI[i, wh2]

            rep_cons = []
            for i in range(cyc_len):
                rep_cons.append(set())
                cyc_node = cyc_nodes[i]
                for cc in reps[cyc_node]:
                    rep_cons[i].add(cc)

            for i in range(1, cyc_len):
                cyc_node = cyc_nodes[i]
                curr_nodes[cyc_node] = False
                for cc in reps[cyc_node]:
                    reps[rep].add(cc)

            chuLiuEdmonds()

            # check each node in cycle, if one of its representatives is a key in the final_edges, it is the one.
            found = False
            wh = -1
            for i in range(cyc_len):
                for repc in rep_cons[i]:
                    if repc in final_edges:
                        wh = cyc_nodes[i]
                        found = True
                        break
                if found:
                    break

            l = par[wh]
            while l != wh:
                ch = oldO[par[l], l]
                pr = oldI[par[l], l]
                final_edges[ch] = pr
                l = par[l]

        if labeled:
            assert energies.ndim == 4, 'dimension of energies is not equal to 4'
        else:
            assert energies.ndim == 3, 'dimension of energies is not equal to 3'
        input_shape = energies.shape
        batch_size = input_shape[0]
        max_length = input_shape[2]

        pars = np.zeros([batch_size, max_length], dtype=np.int32)
        types = np.zeros([batch_size, max_length],
                         dtype=np.int32) if labeled else None
        for i in range(batch_size):
            energy = energies[i]

            # calc the real length of this instance
            length = lengths[i]

            # calc real energy matrix shape = [length, length, num_labels - #symbolic] (remove the label for symbolic types).
            if labeled:
                energy = energy[leading_symbolic:, :length, :length]
                # get best label for each edge.
                label_id_matrix = energy.argmax(axis=0) + leading_symbolic
                energy = energy.max(axis=0)
            else:
                energy = energy[:length, :length]
                label_id_matrix = None
            # get original score matrix
            orig_score_matrix = energy
            # initialize score matrix to original score matrix
            score_matrix = np.array(orig_score_matrix, copy=True)

            oldI = np.zeros([length, length], dtype=np.int32)
            oldO = np.zeros([length, length], dtype=np.int32)
            curr_nodes = np.zeros([length], dtype=bool)
            reps = []

            for s in range(length):
                orig_score_matrix[s, s] = 0.0
                score_matrix[s, s] = 0.0
                curr_nodes[s] = True
                reps.append(set())
                reps[s].add(s)
                for t in range(s + 1, length):
                    oldI[s, t] = s
                    oldO[s, t] = t

                    oldI[t, s] = t
                    oldO[t, s] = s

            final_edges = dict()
            chuLiuEdmonds()
            par = np.zeros([max_length], np.int32)
            if labeled:
                type = np.ones([max_length], np.int32)
                type[0] = 0
            else:
                type = None

            for ch, pr in final_edges.items():
                par[ch] = pr
                if labeled and ch != 0:
                    type[ch] = label_id_matrix[pr, ch]

            par[0] = 0
            pars[i] = par
            if labeled:
                types[i] = type
        return pars, types

    def compute_las(self,
                    words,
                    postags,
                    heads_pred,
                    types_pred,
                    heads,
                    types,
                    word_dictionary,
                    pos_dictionary,
                    lengths,
                    punct_set=None,
                    symbolic_root=False,
                    symbolic_end=False):
        batch_size, _ = words.shape
        ucorr = 0.
        lcorr = 0.
        total = 0.
        ucomplete_match = 0.
        lcomplete_match = 0.

        corr_root = 0.
        total_root = 0.
        start = 1 if symbolic_root else 0
        end = 1 if symbolic_end else 0

        for i in range(batch_size):
            ucm = 1.
            lcm = 1.
            for j in range(start, lengths[i] - end):
                word = word_dictionary.get_instance(words[i, j])
                pos = pos_dictionary.get_instance(postags[i, j])

                total += 1
                if heads[i, j] == heads_pred[i, j]:
                    ucorr += 1
                    if types[i, j] == types_pred[i, j]:
                        lcorr += 1
                    else:
                        lcm = 0
                else:
                    ucm = 0
                    lcm = 0

            ucomplete_match += ucm
            lcomplete_match += lcm

        return (ucorr, lcorr, total, ucomplete_match,
                lcomplete_match), batch_size
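The in-house ELMo branch of _get_rnn_output above applies the standard ELMo scalar mix: the per-layer hidden states are combined with softmax-normalized scalar weights and rescaled by a global gamma. A standalone sketch of that combination with illustrative shapes:

L_layers, B, T, D = 3, 2, 5, 1024
stacked = torch.randn(B, T, D, L_layers)           # stacked layer outputs
scalars = torch.zeros(L_layers, requires_grad=True)  # learned mixing scalars
gamma = torch.ones(1, requires_grad=True)            # learned global scale

w = F.softmax(scalars, dim=0)                        # (L_layers,)
mixed = torch.matmul(stacked, w.unsqueeze(1)).squeeze(3) * gamma  # (B, T, D)
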
Example #7
class EBMNLPTagger(pl.LightningModule):
    def __init__(self, hparams):
        """
        input:
            hparams: namespace with the following items:
                'data_dir' (str): Data Directory. default: './official/ebm_nlp_1_00'
                'bioelmo_dir' (str): BioELMo Directory. default: './models/bioelmo'
                'max_length' (int): Max Length. default: 1024
                'lr' (float): Learning Rate. default: 1e-2
                'fine_tune_bioelmo' (bool): Whether to Fine Tune BioELMo. default: False
                'lr_bioelmo' (float): Learning Rate in BioELMo Fine-tuning. default: 1e-4
        """
        super().__init__()
        self.hparams = hparams
        self.itol = ID_TO_LABEL
        self.ltoi = {v: k for k, v in self.itol.items()}

        # Load Pretrained BioELMo
        DIR_ELMo = Path(str(self.hparams.bioelmo_dir))
        self.bioelmo = Elmo(DIR_ELMo / 'biomed_elmo_options.json',
                            DIR_ELMo / 'biomed_elmo_weights.hdf5',
                            1,
                            requires_grad=bool(self.hparams.fine_tune_bioelmo),
                            dropout=0)
        self.bioelmo_output_dim = self.bioelmo.get_output_dim()

        # ELMo padding token (in ELMo, the token with ID 0 is used for padding)
        VOCAB_FILE_PATH = DIR_ELMo / 'vocab.txt'
        command = shlex.split(f"head -n 1 {VOCAB_FILE_PATH}")
        res = subprocess.Popen(command, stdout=subprocess.PIPE)
        self.bioelmo_pad_token = res.communicate()[0].decode('utf-8').strip()

        # Initialize Intermediate Affine Layer
        self.hidden_to_tag = nn.Linear(int(self.bioelmo_output_dim),
                                       len(self.itol))

        # Initialize CRF
        TRANSITIONS = conditional_random_field.allowed_transitions(
            constraint_type='BIO', labels=self.itol)
        self.crf = conditional_random_field.ConditionalRandomField(
            # set to 7 because here "tags" means ['O', 'B-P', 'I-P', 'B-I', 'I-I', 'B-O', 'I-O']
            # no need to include 'BOS' and 'EOS' in "tags"
            num_tags=len(self.itol),
            constraints=TRANSITIONS,
            include_start_end_transitions=False)
        self.crf.reset_parameters()

    def get_device(self):
        return self.crf.state_dict()['transitions'].device

    def _forward_crf(self, hidden, gold_tags_padded, crf_mask):
        """
        input:
            hidden (torch.tensor) (n_batch, seq_length, hidden_dim)
            gold_tags_padded (torch.tensor) (n_batch, seq_length)
            crf_mask (torch.bool) (n_batch, seq_length)
        output:
            result (dict)
                'log_likelihood' : torch.tensor
                'pred_tags_packed' : torch.nn.utils.rnn.PackedSequence
                'gold_tags_padded' : torch.tensor
        """
        result = {}

        if gold_tags_padded is not None:
            # Log likelihood
            log_prob = self.crf.forward(hidden, gold_tags_padded, crf_mask)

            # top k=1 tagging
            Y = [
                torch.tensor(viterbi_out[0])
                for viterbi_out in self.crf.viterbi_tags(logits=hidden,
                                                         mask=crf_mask)
            ]
            Y = rnn.pack_sequence(Y, enforce_sorted=False)

            result['log_likelihood'] = log_prob
            result['pred_tags_packed'] = Y
            result['gold_tags_padded'] = gold_tags_padded
            return result

        else:
            # top k=1 tagging
            Y = [
                torch.tensor(viterbi_out[0])
                for viterbi_out in self.crf.viterbi_tags(logits=hidden,
                                                         mask=crf_mask)
            ]
            Y = rnn.pack_sequence(Y, enforce_sorted=False)
            result['pred_tags_packed'] = Y
            return result

    def forward(self, tokens, gold_tags=None):
        """
        input:
            tokens (list(list(str))): tokenized sentences
            gold_tags (list(list(int)) or None): gold tag ID sequences
        output:
            result (dict)
                'log_likelihood' : torch.tensor
                'pred_tags_packed' : torch.nn.utils.rnn.PackedSequence
                'gold_tags_padded' : torch.tensor
        """
        # character_ids: torch.tensor (n_batch, max_len, 50)
        character_ids = batch_to_ids(tokens)
        character_ids = character_ids[:, :self.hparams.max_length, :]
        character_ids = character_ids.to(self.get_device())

        # character_ids -> BioELMo hidden state of the last layer & mask
        out = self.bioelmo(character_ids)

        # Turn on gradient tracking
        # Affine transformation (Hidden_dim -> N_tag)
        hidden = out['elmo_representations'][-1]
        hidden.requires_grad_()
        hidden = self.hidden_to_tag(hidden)

        crf_mask = out['mask'].to(torch.bool).to(self.get_device())

        if gold_tags is not None:
            gold_tags = [torch.tensor(seq) for seq in gold_tags]
            gold_tags_padded = rnn.pad_sequence(gold_tags,
                                                batch_first=True,
                                                padding_value=self.ltoi['O'])
            gold_tags_padded = gold_tags_padded[:, :self.hparams.max_length]
            gold_tags_padded = gold_tags_padded.to(self.get_device())
        else:
            gold_tags_padded = None

        result = self._forward_crf(hidden, gold_tags_padded, crf_mask)
        return result

    def step(self, batch, batch_nb, *optimizer_idx):
        tokens_nopad = batch['tokens']
        tags_nopad = batch['tags']

        # Negative Log Likelihood
        result = self.forward(tokens_nopad, tags_nopad)
        returns = {
            'loss': result['log_likelihood'] * (-1.0),
            'T': result['gold_tags_padded'],
            'Y': result['pred_tags_packed'],
            'I': batch['pmid']
        }
        return returns

    def unpack_pred_tags(self, Y_packed):
        """
        input:
            Y_packed: torch.nn.utils.rnn.PackedSequence
        output:
            Y: list(list(str))
                Predicted NER tagging sequence.
        """
        Y_padded, Y_len = rnn.pad_packed_sequence(Y_packed,
                                                  batch_first=True,
                                                  padding_value=-1)
        Y_padded = Y_padded.numpy().tolist()
        Y_len = Y_len.numpy().tolist()

        # Replace B- tag with I- tag because the original paper defines the NER task as identification of spans, not entities
        Y = [[self.itol[ix].replace('B-', 'I-') for ix in ids[:length]]
             for ids, length in zip(Y_padded, Y_len)]

        return Y

    def unpack_gold_and_pred_tags(self, T_padded, Y_packed):
        """
        input:
            T_padded: torch.tensor
            Y_packed: torch.nn.utils.rnn.PackedSequence
        output:
            T: list(list(str))
                Gold NER tagging sequence.
            Y: list(list(str))
                Predicted NER tagging sequence.
        """
        Y = self.unpack_pred_tags(Y_packed)
        Y_len = [len(seq) for seq in Y]

        T_padded = T_padded.numpy().tolist()

        # Replace B- tag with I- tag because the original paper defines the NER task as identification of spans, not entities
        T = [[self.itol[ix].replace('B-', 'I-') for ix in ids[:length]]
             for ids, length in zip(T_padded, Y_len)]

        return T, Y

    def gather_outputs(self, outputs):
        if len(outputs) > 1:
            loss = torch.mean(
                torch.tensor([output['loss'] for output in outputs]))
        else:
            loss = torch.mean(outputs[0]['loss'])

        I = []
        Y = []
        T = []

        for output in outputs:
            T_batch, Y_batch = self.unpack_gold_and_pred_tags(
                output['T'].cpu(), output['Y'].cpu())
            T += T_batch
            Y += Y_batch
            I += output['I'].cpu().numpy().tolist()

        returns = {'loss': loss, 'T': T, 'Y': Y, 'I': I}

        return returns

    def training_step(self, batch, batch_nb, *optimizer_idx):
        # Process on individual mini-batches
        """
        (batch) -> (dict or OrderedDict)
        # Caution: key for loss function must exactly be 'loss'.
        """
        return self.step(batch, batch_nb, *optimizer_idx)

    def training_step_end(self, outputs):
        """
        outputs(dict) -> loss(dict or OrderedDict)
        # Caution: key must exactly be 'loss'.
        """
        loss = torch.mean(outputs['loss'])

        progress_bar = {'train_loss': loss}
        returns = {
            'loss': loss,
            'T': outputs['T'],
            'Y': outputs['Y'],
            'I': outputs['I'],
            'progress_bar': progress_bar
        }
        return returns

    def training_epoch_end(self, outputs):
        """
        outputs(list of dict) -> loss(dict or OrderedDict)
        # Caution: key must exactly be 'loss'.
        """
        outs = self.gather_outputs(outputs)
        loss = outs['loss']
        I = outs['I']
        Y = outs['Y']
        T = outs['T']

        get_logger(self.hparams.version).info(
            f'========== Training Epoch {self.current_epoch} ==========')
        get_logger(self.hparams.version).info(f'Loss: {loss.item()}')
        get_logger(self.hparams.version).info(
            f'Entity-wise classification report\n{seq_classification_report(T, Y, 4)}'
        )
        get_logger(self.hparams.version).info(
            f'Token-wise classification report\n{span_classification_report(T, Y, 4)}'
        )

        progress_bar = {'train_loss': loss}
        returns = {'loss': loss, 'progress_bar': progress_bar}
        return returns

    def validation_step(self, batch, batch_nb):
        # Process on individual mini-batches
        """
        (batch) -> (dict or OrderedDict)
        """
        return self.step(batch, batch_nb)

    def validation_end(self, outputs):
        """
        For single dataloader:
            outputs(list of dict) -> (dict or OrderedDict)
        For multiple dataloaders:
            outputs(list of (list of dict)) -> (dict or OrderedDict)
        """
        outs = self.gather_outputs(outputs)
        loss = outs['loss']
        I = outs['I']
        Y = outs['Y']
        T = outs['T']

        logger = get_logger(self.hparams.version)
        logger.info(f'========== Validation Epoch {self.current_epoch} ==========')
        logger.info(f'Loss: {loss.item()}')
        logger.info(f'Entity-wise classification report\n{seq_classification_report(T, Y, 4)}')
        logger.info(f'Token-wise classification report\n{span_classification_report(T, Y, 4)}')

        progress_bar = {'val_loss': loss}
        returns = {'loss': loss, 'progress_bar': progress_bar}
        return returns

    def test_step(self, batch, batch_nb):
        # Process on individual mini-batches
        """
        (batch) -> (dict or OrderedDict)
        """
        return self.step(batch, batch_nb)

    def test_epoch_end(self, outputs):
        """
        For single dataloader:
            outputs(list of dict) -> (dict or OrderedDict)
        For multiple dataloaders:
            outputs(list of (list of dict)) -> (dict or OrderedDict)
        """
        outs = self.gather_outputs(outputs)
        loss = outs['loss']
        I = outs['I']
        Y = outs['Y']
        T = outs['T']

        logger = get_logger(self.hparams.version)
        logger.info('========== Test ==========')
        logger.info(f'Loss: {loss.item()}')
        logger.info(f'Entity-wise classification report\n{seq_classification_report(T, Y, 4)}')
        logger.info(f'Token-wise classification report\n{span_classification_report(T, Y, 4)}')

        progress_bar = {'test_loss': loss}
        returns = {'loss': loss, 'progress_bar': progress_bar}
        return returns

    def configure_optimizers(self):
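        # Note: returning a list of optimizers makes PyTorch Lightning invoke
        # training_step once per optimizer and pass that optimizer's index along,
        # which is why training_step above accepts *optimizer_idx.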
        if self.hparams.fine_tune_bioelmo:
            optimizer_bioelmo_1 = optim.Adam(self.bioelmo.parameters(),
                                             lr=float(self.hparams.lr_bioelmo))
            optimizer_bioelmo_2 = optim.Adam(self.hidden_to_tag.parameters(),
                                             lr=float(self.hparams.lr_bioelmo))
            optimizer_crf = optim.Adam(self.crf.parameters(),
                                       lr=float(self.hparams.lr))
            return [optimizer_bioelmo_1, optimizer_bioelmo_2, optimizer_crf]
        else:
            optimizer = optim.Adam(self.parameters(),
                                   lr=float(self.hparams.lr))
            return optimizer

    def train_dataloader(self):
        ds_train_val = EBMNLPDataset(
            *path_finder(self.hparams.data_dir)['train'])

        ds_train, _ = train_test_split(ds_train_val,
                                       train_size=0.8,
                                       random_state=self.hparams.random_state)
        dl_train = EBMNLPDataLoader(ds_train,
                                    batch_size=self.hparams.batch_size,
                                    shuffle=True)
        return dl_train

    def val_dataloader(self):
        ds_train_val = EBMNLPDataset(
            *path_finder(self.hparams.data_dir)['train'])

        _, ds_val = train_test_split(ds_train_val,
                                     train_size=0.8,
                                     random_state=self.hparams.random_state)
        dl_val = EBMNLPDataLoader(ds_val,
                                  batch_size=self.hparams.batch_size,
                                  shuffle=False)
        return dl_val

    def test_dataloader(self):
        ds_test = EBMNLPDataset(*path_finder(self.hparams.data_dir)['test'])
        dl_test = EBMNLPDataLoader(ds_test,
                                   batch_size=self.hparams.batch_size,
                                   shuffle=False)
        return dl_test
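
# A minimal training sketch for the Lightning module above. The class name
# `EBMNLPTagger` and the exact hparams fields are assumptions for illustration
# (only their usages appear in the code above), so adapt them as needed.
import pytorch_lightning as pl
from argparse import Namespace

hparams = Namespace(data_dir='./ebmnlp', batch_size=16, lr='1e-3',
                    lr_bioelmo='1e-4', fine_tune_bioelmo=False,
                    random_state=42, version='v0')
model = EBMNLPTagger(hparams)  # hypothetical class name
trainer = pl.Trainer(max_epochs=10)
trainer.fit(model)   # picks up train_dataloader() / val_dataloader() defined above
trainer.test(model)  # picks up test_dataloader()
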
Example #8
class ContextualControllerELMo(ControllerBase):
    def __init__(
            self,
            hidden_size,
            dropout,
            pretrained_embeddings_dir,
            dataset_name,
            fc_hidden_size=150,
            freeze_pretrained=True,
            learning_rate=0.001,
            layer_learning_rate: Optional[Dict[str, float]] = None,
            max_segment_size=None,  # if None, process sentences independently
            max_span_size=10,
            model_name=None):
        self.hidden_size = hidden_size
        self.dropout = dropout
        self.freeze_pretrained = freeze_pretrained
        self.fc_hidden_size = fc_hidden_size
        self.max_span_size = max_span_size
        self.max_segment_size = max_segment_size
        self.learning_rate = learning_rate
        self.layer_learning_rate = layer_learning_rate if layer_learning_rate is not None else {}

        self.pretrained_embeddings_dir = pretrained_embeddings_dir
        self.embedder = Elmo(
            options_file=os.path.join(pretrained_embeddings_dir,
                                      "options.json"),
            weight_file=os.path.join(pretrained_embeddings_dir,
                                     "slovenian-elmo-weights.hdf5"),
            dropout=(0.0 if freeze_pretrained else dropout),
            num_output_representations=1,
            requires_grad=(not freeze_pretrained)).to(DEVICE)
        embedding_size = self.embedder.get_output_dim()

        self.context_encoder = nn.LSTM(input_size=embedding_size,
                                       hidden_size=hidden_size,
                                       batch_first=True,
                                       bidirectional=True).to(DEVICE)
        self.scorer = NeuralCoreferencePairScorer(num_features=2 * hidden_size,
                                                  hidden_size=fc_hidden_size,
                                                  dropout=dropout).to(DEVICE)
        params_to_update = [
            {"params": self.scorer.parameters(),
             "lr": self.layer_learning_rate.get("lr_scorer", self.learning_rate)},
            {"params": self.context_encoder.parameters(),
             "lr": self.layer_learning_rate.get("lr_context_encoder", self.learning_rate)}
        ]
        if not freeze_pretrained:
            params_to_update.append(
                {"params": self.embedder.parameters(),
                 "lr": self.layer_learning_rate.get("lr_embedder", self.learning_rate)})

        self.optimizer = optim.Adam(params_to_update, lr=self.learning_rate)

        super().__init__(learning_rate=learning_rate,
                         dataset_name=dataset_name,
                         model_name=model_name)
        logging.info(
            f"Initialized contextual ELMo-based model with name {self.model_name}."
        )

    @property
    def model_base_dir(self):
        return "contextual_model_elmo"

    def train_mode(self):
        if not self.freeze_pretrained:
            self.embedder.train()
        self.context_encoder.train()
        self.scorer.train()

    def eval_mode(self):
        self.embedder.eval()
        self.context_encoder.eval()
        self.scorer.eval()

    def load_checkpoint(self):
        self.loaded_from_file = True
        self.context_encoder.load_state_dict(
            torch.load(os.path.join(self.path_model_dir, "context_encoder.th"),
                       map_location=DEVICE))
        self.scorer.load_state_dict(
            torch.load(os.path.join(self.path_model_dir, "scorer.th"),
                       map_location=DEVICE))

        path_to_embeddings = os.path.join(self.path_model_dir, "embeddings.th")
        if os.path.isfile(path_to_embeddings):
            logging.info(
                f"Loading fine-tuned ELMo weights from '{path_to_embeddings}'")
            self.embedder.load_state_dict(
                torch.load(path_to_embeddings, map_location=DEVICE))

    @staticmethod
    def from_pretrained(model_dir):
        controller_config_path = os.path.join(model_dir,
                                              "controller_config.json")
        with open(controller_config_path, "r", encoding="utf-8") as f_config:
            pre_config = json.load(f_config)

        instance = ContextualControllerELMo(**pre_config)
        instance.load_checkpoint()

        return instance

    def save_pretrained(self, model_dir):
        if not os.path.exists(model_dir):
            os.makedirs(model_dir)

        # Write controller config (used for instantiation)
        controller_config_path = os.path.join(model_dir,
                                              "controller_config.json")
        with open(controller_config_path, "w", encoding="utf-8") as f_config:
            json.dump({
                "hidden_size": self.hidden_size,
                "dropout": self.dropout,
                "pretrained_embeddings_dir": self.pretrained_embeddings_dir,
                "dataset_name": self.dataset_name,
                "fc_hidden_size": self.fc_hidden_size,
                "freeze_pretrained": self.freeze_pretrained,
                "learning_rate": self.learning_rate,
                "layer_learning_rate": self.layer_learning_rate,
                "max_segment_size": self.max_segment_size,
                "max_span_size": self.max_span_size,
                "model_name": self.model_name
            }, fp=f_config, indent=4)

        # Save state dicts next to the config (the original mixed model_dir and
        # self.path_model_dir here), so everything lands in one directory
        torch.save(self.context_encoder.state_dict(),
                   os.path.join(model_dir, "context_encoder.th"))
        torch.save(self.scorer.state_dict(),
                   os.path.join(model_dir, "scorer.th"))

        # Save fine-tuned ELMo embeddings only if they're not frozen
        if not self.freeze_pretrained:
            torch.save(self.embedder.state_dict(),
                       os.path.join(model_dir, "embeddings.th"))

    def save_checkpoint(self):
        logging.warning(
            "save_checkpoint() is deprecated. Use save_pretrained() instead")
        self.save_pretrained(self.path_model_dir)
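
    # Typical round-trip (hypothetical path): controller.save_pretrained("models/coref_elmo")
    # writes the config and state dicts, and ContextualControllerELMo.from_pretrained(...)
    # re-instantiates the controller from them.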

    def _prepare_doc(self, curr_doc: Document) -> Dict:
        """ Returns a cache dictionary with preprocessed data. This should only be called once per document, since
        data inside same document does not get shuffled. """
        ret = {}

        # By default, each sentence is its own segment, meaning sentences are processed independently
        if self.max_segment_size is None:
            def get_position(t):
                return t.sentence_index, t.position_in_sentence

            _encoded_segments = batch_to_ids(curr_doc.raw_sentences())
        # Optionally, max_segment_size can be specified, in which case fixed-size segments of tokens are processed independently
        else:
            def get_position(t):
                doc_position = t.position_in_document
                return doc_position // self.max_segment_size, doc_position % self.max_segment_size

            flattened_doc = list(chain(*curr_doc.raw_sentences()))
            num_segments = (len(flattened_doc) + self.max_segment_size - 1) // self.max_segment_size
            _encoded_segments = batch_to_ids(
                [flattened_doc[idx_seg * self.max_segment_size:(idx_seg + 1) * self.max_segment_size]
                 for idx_seg in range(num_segments)])

        encoded_segments = []
        # Convention: add a PAD word (an all-zero character-id vector) at the end of each
        # segment, used for padding mentions. batch_to_ids has already padded all segments
        # to the same length, so the results can be stacked afterwards.
        for curr_sent in _encoded_segments:
            encoded_segments.append(torch.cat((
                curr_sent,
                torch.zeros((1, ELMoCharacterMapper.max_word_length), dtype=torch.long))))
        encoded_segments = torch.stack(encoded_segments)

        cluster_sets = []
        mention_to_cluster_id = {}
        for i, curr_cluster in enumerate(curr_doc.clusters):
            cluster_sets.append(set(curr_cluster))
            for mid in curr_cluster:
                mention_to_cluster_id[mid] = i

        all_candidate_data = []
        for idx_head, (head_id, head_mention) in enumerate(curr_doc.mentions.items(), start=1):
            gt_antecedent_ids = cluster_sets[mention_to_cluster_id[head_id]]

            # Note: no data for dummy antecedent (len(`features`) is one less than `candidates`)
            candidates, candidate_data = [None], []
            candidate_attention = []
            correct_antecedents = []

            curr_head_data = [[], []]
            num_head_words = 0
            for curr_token in head_mention.tokens:
                idx_segment, idx_inside_segment = get_position(curr_token)
                curr_head_data[0].append(idx_segment)
                curr_head_data[1].append(idx_inside_segment)
                num_head_words += 1

            if num_head_words > self.max_span_size:
                curr_head_data[0] = curr_head_data[0][:self.max_span_size]
                curr_head_data[1] = curr_head_data[1][:self.max_span_size]
            else:
                # Pad by repeating the last segment index; -1 indexes the PAD word appended to each segment
                curr_head_data[0] += [curr_head_data[0][-1]] * (self.max_span_size - num_head_words)
                curr_head_data[1] += [-1] * (self.max_span_size - num_head_words)

            # Boolean mask over the fixed-size span: True for real tokens, False for padding
            head_attention = torch.ones((1, self.max_span_size),
                                        dtype=torch.bool)
            head_attention[0, num_head_words:] = False

            for idx_candidate, (cand_id, cand_mention) in enumerate(
                    curr_doc.mentions.items(), start=1):
                if idx_candidate >= idx_head:
                    break

                candidates.append(cand_id)

                # Maps tokens to positions inside segments (idx_seg, idx_inside_seg) for efficient indexing later
                curr_candidate_data = [[], []]
                num_candidate_words = 0
                for curr_token in cand_mention.tokens:
                    idx_segment, idx_inside_segment = get_position(curr_token)
                    curr_candidate_data[0].append(idx_segment)
                    curr_candidate_data[1].append(idx_inside_segment)
                    num_candidate_words += 1

                if num_candidate_words > self.max_span_size:
                    curr_candidate_data[0] = curr_candidate_data[0][:self.max_span_size]
                    curr_candidate_data[1] = curr_candidate_data[1][:self.max_span_size]
                else:
                    # padding tokens index into the PAD token of the last segment
                    curr_candidate_data[0] += [curr_candidate_data[0][-1]] * (self.max_span_size - num_candidate_words)
                    curr_candidate_data[1] += [-1] * (self.max_span_size - num_candidate_words)

                candidate_data.append(curr_candidate_data)
                curr_attention = torch.ones((1, self.max_span_size),
                                            dtype=torch.bool)
                curr_attention[0, num_candidate_words:] = False
                candidate_attention.append(curr_attention)

                is_coreferent = cand_id in gt_antecedent_ids
                if is_coreferent:
                    correct_antecedents.append(idx_candidate)

            if len(correct_antecedents) == 0:
                correct_antecedents.append(0)

            candidate_attention = (torch.cat(candidate_attention)
                                   if len(candidate_attention) > 0 else [])
            all_candidate_data.append({
                "head_id": head_id,
                "head_data": torch.tensor([curr_head_data]),
                "head_attention": head_attention,
                "candidates": candidates,
                "candidate_data": torch.tensor(candidate_data),
                "candidate_attention": candidate_attention,
                "correct_antecedents": correct_antecedents
            })

        ret["preprocessed_segments"] = encoded_segments
        ret["steps"] = all_candidate_data

        return ret

    def _train_doc(self, curr_doc, eval_mode=False):
        """ Trains/evaluates (if `eval_mode` is True) model on specific document.
            Returns predictions, loss and number of examples evaluated. """

        if len(curr_doc.mentions) == 0:
            return {}, (0.0, 0)

        # Preprocessing is cached on the document: data inside a document never gets
        # shuffled, so _prepare_doc() only needs to run once per document
        if not hasattr(curr_doc, "_cache_elmo"):
            curr_doc._cache_elmo = self._prepare_doc(curr_doc)
        cache = curr_doc._cache_elmo  # type: Dict

        encoded_segments = cache["preprocessed_segments"]
        if self.freeze_pretrained:
            with torch.no_grad():
                res = self.embedder(encoded_segments.to(DEVICE))
        else:
            res = self.embedder(encoded_segments.to(DEVICE))

        # Note: max_segment_size is either specified at instantiation or (the length of longest sentence + 1)
        # [num_segments, max_segment_size, embedding_size]
        embedded_segments = res["elmo_representations"][0]
        # [num_segments, max_segment_size, 2 * hidden_size]
        lstm_segments, _ = self.context_encoder(embedded_segments)

        doc_loss, n_examples = 0.0, len(cache["steps"])
        preds = {}

        for curr_step in cache["steps"]:
            head_id = curr_step["head_id"]
            head_data = curr_step["head_data"]

            candidates = curr_step["candidates"]
            candidate_data = curr_step["candidate_data"]
            correct_antecedents = curr_step["correct_antecedents"]

            # Note: num_candidates includes dummy antecedent + actual candidates
            num_candidates = len(candidates)
            if num_candidates == 1:
                curr_pred = 0
            else:
                idx_segment = candidate_data[:, 0, :]
                idx_in_segment = candidate_data[:, 1, :]

                # [num_candidates, max_span_size, embedding_size]
                candidate_data = lstm_segments[idx_segment, idx_in_segment]
                # [1, head_size, embedding_size]
                head_data = lstm_segments[head_data[:, 0, :], head_data[:, 1, :]]
                head_data = head_data.repeat((num_candidates - 1, 1, 1))

                candidate_scores = self.scorer(
                    candidate_data, head_data,
                    curr_step["candidate_attention"],
                    curr_step["head_attention"].repeat(
                        (num_candidates - 1, 1)))

                # Prepend a fixed 0.0 score for the dummy antecedent -> [1, num_candidates]
                candidate_scores = torch.cat(
                    (torch.tensor([0.0], device=DEVICE),
                     candidate_scores.flatten())).unsqueeze(0)

                curr_pred = torch.argmax(candidate_scores)
                doc_loss += self.loss(
                    candidate_scores.repeat((len(correct_antecedents), 1)),
                    torch.tensor(correct_antecedents, device=DEVICE))

            # { antecedent: [mention(s)] } pair
            existing_refs = preds.get(candidates[int(curr_pred)], [])
            existing_refs.append(head_id)
            preds[candidates[int(curr_pred)]] = existing_refs

        if not eval_mode:
            # doc_loss remains a plain float if no step had real candidates
            # (e.g. a document with a single mention); nothing to backprop then
            if torch.is_tensor(doc_loss):
                doc_loss.backward()
                self.optimizer.step()
                self.optimizer.zero_grad()

        return preds, (float(doc_loss), n_examples)
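
# The examples above all share the same allennlp pattern: request several output
# representations from Elmo, mix them with softmax-normalised learned weights and
# scale by a task-specific gamma (the scalar mix from the ELMo paper). Below is a
# self-contained sketch of that pattern; the options/weights URLs are the small
# public English ELMo model and are an assumption, so substitute your own pair.
import torch
import torch.nn as nn
import torch.nn.functional as F
from allennlp.modules.elmo import Elmo, batch_to_ids

options_file = ("https://allennlp.s3.amazonaws.com/models/elmo/2x1024_128_2048cnn_1xhighway/"
                "elmo_2x1024_128_2048cnn_1xhighway_options.json")
weight_file = ("https://allennlp.s3.amazonaws.com/models/elmo/2x1024_128_2048cnn_1xhighway/"
               "elmo_2x1024_128_2048cnn_1xhighway_weights.hdf5")

elmo = Elmo(options_file, weight_file, num_output_representations=2, dropout=0.0)
w = nn.Parameter(torch.zeros(2))     # per-layer mixture weights
gamma = nn.Parameter(torch.ones(1))  # task-specific scale

character_ids = batch_to_ids([["The", "cat", "sat"], ["Hello", "world"]])
out = elmo(character_ids)            # dict with 'elmo_representations' and 'mask'
reps = out["elmo_representations"]   # list of [batch, seq_len, dim] tensors
mix = F.softmax(w, dim=0)
embedded = gamma * (mix[0] * reps[0] + mix[1] * reps[1])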