Example #1
    def __init__(self, config):
        super(NNCRF, self).__init__()
        self.config = config
        self.label_size = config.label_size
        self.device = config.device
        self.use_char = config.use_char_rnn
        self.context_emb = config.context_emb
        '''
        task specific:
        '''
        self.is_base = config.is_base

        self.label2idx = config.label2idx
        self.labels = config.idx2labels
        self.start_idx = self.label2idx[START]
        self.end_idx = self.label2idx[STOP]
        self.pad_idx = self.label2idx[PAD]

        self.input_size = config.embedding_dim

        if config.is_base:
            if self.context_emb != ContextEmb.none:
                self.input_size += config.context_emb_size
            if self.use_char:
                self.char_feature = CharBiLSTM(config)
                self.input_size += config.charlstm_hidden_dim

            vocab_size = len(config.word2idx)
            self.word_embedding = nn.Embedding.from_pretrained(
                torch.FloatTensor(config.word_embedding),
                freeze=False).to(self.device)
            self.word_drop = nn.Dropout(config.dropout).to(self.device)
        else:
            self.input_size = 350

        print("[Model Info] Input size to LSTM: {}".format(self.input_size))
        print("[Model Info] LSTM Hidden Size: {}".format(config.hidden_dim))

        self.lstm = nn.LSTM(self.input_size,
                            config.hidden_dim // 2,
                            num_layers=1,
                            batch_first=True,
                            bidirectional=True).to(self.device)

        self.drop_lstm = nn.Dropout(config.dropout).to(self.device)

        final_hidden_dim = config.hidden_dim

        print("[Model Info] Final Hidden Size: {}".format(final_hidden_dim))
        self.hidden2tag = nn.Linear(final_hidden_dim,
                                    self.label_size).to(self.device)

        init_transition = torch.randn(self.label_size,
                                      self.label_size).to(self.device)
        init_transition[:, self.start_idx] = -10000.0
        init_transition[self.end_idx, :] = -10000.0
        init_transition[:, self.pad_idx] = -10000.0
        init_transition[self.pad_idx, :] = -10000.0

        self.transition = nn.Parameter(init_transition)
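
The last few lines above encode hard CRF constraints by giving impossible transitions a very large negative score. A minimal, self-contained sketch of that idea, assuming only torch and a hypothetical five-label tag set:

import torch
import torch.nn as nn

label_size, start_idx, end_idx, pad_idx = 5, 3, 4, 0    # hypothetical label layout
init_transition = torch.randn(label_size, label_size)   # transition[i, j]: score of moving from label i to label j
init_transition[:, start_idx] = -10000.0                # nothing may transition into START
init_transition[end_idx, :] = -10000.0                  # nothing may leave STOP
init_transition[:, pad_idx] = -10000.0                  # nothing may transition into PAD
init_transition[pad_idx, :] = -10000.0                  # ... or out of PAD
transition = nn.Parameter(init_transition)              # learned jointly with the emission scores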
Example #2
    def __init__(self, config: Config):
        super(BertCharEncoder, self).__init__()
        self.config = config
        self.device = config.device
        self.use_char = config.use_char_rnn
        self.context_emb = config.context_emb
        self.input_size = 0
        self.input_size += config.bert_embedding_size
        if self.context_emb != ContextEmb.none:
            self.input_size += config.context_emb_size

        if self.use_char:
            self.char_feature = CharBiLSTM(config)
            self.input_size += config.charlstm_hidden_dim
        self.bert = AutoModel.from_pretrained(config.bert_path).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(config.bert_path)

        self.word_drop = nn.Dropout(config.dropout).to(self.device)
        if config.rnn == 'lstm':
            self.rnn = nn.LSTM(self.input_size,
                               config.hidden_dim,
                               num_layers=1,
                               batch_first=True,
                               bidirectional=True).to(self.device)
        else:
            self.rnn = nn.GRU(self.input_size,
                              config.hidden_dim,
                              num_layers=1,
                              batch_first=True,
                              bidirectional=True).to(self.device)
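
Note that, unlike Examples #1, #3 and #5, this encoder passes config.hidden_dim (not hidden_dim // 2) as the per-direction size, so the bidirectional RNN emits 2 * hidden_dim features per token. A small sketch with made-up sizes illustrates the resulting shape:

import torch
import torch.nn as nn

input_size, hidden_dim = 768, 200            # hypothetical sizes
rnn = nn.LSTM(input_size, hidden_dim, num_layers=1,
              batch_first=True, bidirectional=True)
out, _ = rnn(torch.randn(2, 7, input_size))  # (batch=2, seq_len=7, input_size)
print(out.shape)                             # torch.Size([2, 7, 400]) == (batch, seq_len, 2 * hidden_dim)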
Example #3
    def __init__(self, config, encoder=None):
        super(SoftEncoder, self).__init__()
        self.config = config
        self.device = config.device
        self.use_char = config.use_char_rnn
        self.context_emb = config.context_emb
        self.input_size = config.embedding_dim

        if self.context_emb != ContextEmb.none:
            self.input_size += config.context_emb_size
        if self.use_char:
            self.char_feature = CharBiLSTM(config)
            self.input_size += config.charlstm_hidden_dim

        self.word_embedding = nn.Embedding.from_pretrained(
            torch.FloatTensor(config.word_embedding),
            freeze=False).to(self.device)
        self.word_drop = nn.Dropout(config.dropout).to(self.device)
        self.lstm = nn.LSTM(self.input_size,
                            config.hidden_dim // 2,
                            num_layers=1,
                            batch_first=True,
                            bidirectional=True).to(self.device)

        if encoder is not None:
            if self.use_char:
                self.char_feature = encoder.char_feature
            self.word_embedding = encoder.word_embedding
            self.word_drop = encoder.word_drop
            self.lstm = encoder.lstm
Example #4
    def __init__(self, data):
        super(BiLSTM, self).__init__()
        print ("build batched bilstm...")
        self.gpu = data.HP_gpu
        self.use_gloss = data.HP_use_gloss
        self.use_entity = data.HP_use_entity
        self.use_gaz = data.HP_use_gaz
        self.batch_size = data.HP_batch_size
        self.gloss_hidden_dim = 0
        self.embedding_dim = data.word_emb_dim
        self.gloss_hidden_dim = data.gloss_hidden_dim
        self.gloss_drop = data.HP_dropout
        self.drop = nn.Dropout(data.HP_dropout)
        self.word_embeddings = nn.Embedding(data.word_alphabet.size(), self.embedding_dim)
        if data.pretrain_word_embedding is not None:
            self.word_embeddings.weight.data.copy_(torch.from_numpy(data.pretrain_word_embedding))
        else:
            self.word_embeddings.weight.data.copy_(torch.from_numpy(self.random_embedding(data.word_alphabet.size(), self.embedding_dim)))
        if self.use_entity:
            self.entity_embeddings = nn.Embedding(data.entity_alphabet.size(), data.entity_emb_dim)
            self.entity_embeddings.weight.data.copy_(torch.from_numpy(self.random_embedding(data.entity_alphabet.size(), data.entity_emb_dim)))
        if self.use_gloss:
            self.gloss_hidden_dim = data.gloss_hidden_dim
            self.gloss_embedding_dim = data.gloss_emb_dim
            if data.gloss_features == "CNN":
                self.gloss_feature = CNN(data, input_dim=data.gloss_emb_dim, hidden_dim=self.gloss_hidden_dim, dropout=self.gloss_drop)
                # self.gloss_feature = CharCNN(data)#data.gloss_alphabet.size(), self.gloss_embedding_dim, self.gloss_hidden_dim, data.HP_dropout, self.gpu)
            elif data.gloss_features == "LSTM":
                self.gloss_feature = CharBiLSTM(data.gloss_alphabet.size(), self.gloss_embedding_dim, self.gloss_hidden_dim, data.HP_dropout, self.gpu)
            else:
                print ("Error gloss feature selection, please check parameter data.gloss_features (either CNN or LSTM).")
                exit(0)
        # The LSTM takes word embeddings as inputs, and outputs hidden states
        # with dimensionality hidden_dim.
        self.bilstm_flag = data.HP_bilstm
        self.lstm_layer = data.HP_lstm_layer
        self.droplstm = nn.Dropout(data.HP_dropout)
        if self.bilstm_flag:
            lstm_hidden_dim = data.HP_lstm_hidden_dim // 2
        else:
            lstm_hidden_dim = data.HP_lstm_hidden_dim
        lstm_input_dim = self.embedding_dim + self.gloss_hidden_dim
        self.forward_lstm = LatticeLSTM(lstm_input_dim, lstm_hidden_dim, data.gaz_dropout, data.gaz_alphabet.size(), data.gaz_emb_dim, data.pretrain_gaz_embedding, left2right=True, fix_word_emb=data.HP_fix_gaz_emb, gpu=self.gpu)
        if self.bilstm_flag:
            self.backward_lstm = LatticeLSTM(lstm_input_dim, lstm_hidden_dim, data.gaz_dropout, data.gaz_alphabet.size(), data.gaz_emb_dim, data.pretrain_gaz_embedding, left2right=False, fix_word_emb=data.HP_fix_gaz_emb, gpu=self.gpu)
        # self.lstm = nn.LSTM(lstm_input_dim, lstm_hidden_dim, num_layers=self.lstm_layer, batch_first=True, bidirectional=self.bilstm_flag)

        # The linear layer that maps from hidden state space to tag space
        self.hidden2tag = nn.Linear(data.HP_lstm_hidden_dim, data.label_alphabet_size)

        if self.gpu:
            self.drop = self.drop.cuda()
            self.droplstm = self.droplstm.cuda()
            self.word_embeddings = self.word_embeddings.cuda()
            if self.use_entity:
                self.entity_embeddings = self.entity_embeddings.cuda()
            self.forward_lstm = self.forward_lstm.cuda()
            if self.bilstm_flag:
                self.backward_lstm = self.backward_lstm.cuda()
            self.hidden2tag = self.hidden2tag.cuda()
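
self.random_embedding is not shown in this example; Example #13 below defines it as uniform initialisation with scale sqrt(3 / embedding_dim). A standalone version for reference:

import numpy as np

def random_embedding(vocab_size, embedding_dim):
    # Uniform initialisation in [-sqrt(3/dim), +sqrt(3/dim)] per word, as in Example #13.
    pretrain_emb = np.empty([vocab_size, embedding_dim])
    scale = np.sqrt(3.0 / embedding_dim)
    for index in range(vocab_size):
        pretrain_emb[index, :] = np.random.uniform(-scale, scale, [1, embedding_dim])
    return pretrain_emb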
Example #5
    def __init__(self, config, print_info: bool = True):
        super(BiLSTMEncoder, self).__init__()

        self.label_size = config.label_size
        self.device = config.device
        self.use_char = config.use_char_rnn
        self.context_emb = config.context_emb
        self.num_layers = config.num_layers
        self.label2idx = config.label2idx
        self.labels = config.idx2labels

        self.input_size = config.embedding_dim
        if self.context_emb != ContextEmb.none:
            self.input_size += config.context_emb_size
        if self.use_char:
            self.char_feature = CharBiLSTM(config, print_info=print_info)
            self.input_size += config.charlstm_hidden_dim

        self.word_embedding = nn.Embedding.from_pretrained(torch.FloatTensor(config.word_embedding), freeze=False).to(self.device)
        self.word_drop = nn.Dropout(config.dropout).to(self.device)

        if print_info:
            print("[Model Info] Input size to LSTM: {}".format(self.input_size))
            print("[Model Info] LSTM Hidden Size: {}".format(config.hidden_dim))

        self.lstm = nn.LSTM(self.input_size, config.hidden_dim // 2, num_layers=self.num_layers, batch_first=True, bidirectional=True).to(self.device)

        self.drop_lstm = nn.Dropout(config.dropout).to(self.device)

        final_hidden_dim = config.hidden_dim

        if print_info:
            print("[Model Info] Final Hidden Size: {}".format(final_hidden_dim))

        self.hidden2tag = nn.Linear(final_hidden_dim, self.label_size).to(self.device)
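
A minimal sketch of the pretrained-embedding setup used above (nn.Embedding.from_pretrained with freeze=False, followed by dropout); the tiny embedding matrix and word indices here are made up:

import torch
import torch.nn as nn

pretrained = torch.FloatTensor([[0.1, 0.2, 0.3],     # hypothetical matrix: vocab of 4, dim 3
                                [0.0, 0.0, 0.0],
                                [0.5, 0.4, 0.3],
                                [0.9, 0.8, 0.7]])
word_embedding = nn.Embedding.from_pretrained(pretrained, freeze=False)  # fine-tuned during training
word_drop = nn.Dropout(0.5)

word_ids = torch.tensor([[0, 2, 3]])                 # (batch=1, sent_len=3)
word_rep = word_drop(word_embedding(word_ids))
print(word_rep.shape)                                # torch.Size([1, 3, 3])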
Example #6
class AttenLSTM(nn.Module):
    def __init__(self, vocab, num_classes, char_alphabet):
        super(AttenLSTM, self).__init__()
        self.embed_size = opt.word_emb_size
        self.embedding = vocab.init_embed_layer()
        self.hidden_size = opt.hidden_size
        self.char_hidden_dim = 10
        self.char_embedding_dim = 20

        self.char_feature = CharBiLSTM(len(char_alphabet), None,
                                       self.char_embedding_dim,
                                       self.char_hidden_dim, opt.dropout,
                                       opt.gpu)
        self.input_size = self.embed_size + self.char_hidden_dim

        self.W = nn.Linear(self.input_size, 1, bias=False)

        self.hidden = nn.Linear(self.input_size, self.hidden_size)
        self.out = nn.Linear(self.hidden_size, num_classes)
        self.dropout = nn.Dropout(opt.dropout)

    def forward(self, input, char_inputs):
        """
		inputs: (unpacked_padded_output: batch_size x seq_len x hidden_size, lengths: batch_size)
		"""

        entity_words, entity_lengths, entity_seq_recover = input
        entity_words = autograd.Variable(entity_words)
        entity_words_embeds = self.embedding(entity_words)
        batch_size, max_len, _ = entity_words_embeds.size()

        char_inputs, char_seq_lengths, char_seq_recover = char_inputs
        char_features = self.char_feature.get_last_hiddens(
            char_inputs,
            char_seq_lengths.cpu().numpy())
        char_features = char_features[char_seq_recover]
        char_features = char_features.view(batch_size, max_len, -1)

        input_embeds = torch.cat((entity_words_embeds, char_features), 2)

        flat_input = input_embeds.contiguous().view(-1, self.input_size)
        logits = self.W(flat_input).view(batch_size, max_len)
        alphas = functional.softmax(logits, dim=1)

        # computing mask
        idxes = torch.arange(0, max_len,
                             out=torch.LongTensor(max_len)).unsqueeze(0).cuda(
                                 opt.gpu)
        mask = autograd.Variable((idxes < entity_lengths.unsqueeze(1)).float())

        alphas = alphas * mask
        alphas = alphas / torch.sum(alphas, 1).view(-1, 1)
        atten_input = torch.bmm(alphas.unsqueeze(1), input_embeds).squeeze(1)
        atten_input = self.dropout(atten_input)

        hidden = self.hidden(atten_input)
        output = self.out(hidden)
        return output
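
The forward pass above softmaxes one score per token and then zeroes out the padded positions before renormalising. A minimal sketch of that masked-attention step with plain tensors (autograd.Variable is no longer needed in current PyTorch); the logits and lengths are made up:

import torch
import torch.nn.functional as F

logits = torch.randn(2, 4)                          # (batch, max_len): one score per token
lengths = torch.tensor([4, 2])                      # true (unpadded) lengths

alphas = F.softmax(logits, dim=1)
idxes = torch.arange(4).unsqueeze(0)                # (1, max_len)
mask = (idxes < lengths.unsqueeze(1)).float()       # 1 for real tokens, 0 for padding

alphas = alphas * mask
alphas = alphas / alphas.sum(dim=1, keepdim=True)   # renormalise over the real tokens
print(alphas.sum(dim=1))                            # both rows sum to 1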
Example #7
    def __init__(self, vocab, num_classes, char_alphabet):
        super(AttenLSTM, self).__init__()
        self.embed_size = opt.word_emb_size
        self.embedding = vocab.init_embed_layer()
        self.hidden_size = opt.hidden_size
        self.char_hidden_dim = 10
        self.char_embedding_dim = 20

        self.char_feature = CharBiLSTM(len(char_alphabet), None,
                                       self.char_embedding_dim,
                                       self.char_hidden_dim, opt.dropout,
                                       opt.gpu)
        self.input_size = self.embed_size + self.char_hidden_dim

        self.W = nn.Linear(self.input_size, 1, bias=False)

        self.hidden = nn.Linear(self.input_size, self.hidden_size)
        self.out = nn.Linear(self.hidden_size, num_classes)
        self.dropout = nn.Dropout(opt.dropout)
Example #8
    def __init__(self, vocab, num_classes, char_alphabet):
        super(BiLSTM_Attn, self).__init__()

        self.embed_size = opt.word_emb_size
        self.embedding = vocab.init_embed_layer()
        self.hidden_size = opt.hidden_size
        self.char_hidden_dim = 10
        self.char_embedding_dim = 20
        self.char_feature = CharBiLSTM(len(char_alphabet), None, self.char_embedding_dim, self.char_hidden_dim,
                                       opt.dropout, opt.gpu)

        self.input_size = self.embed_size + self.char_hidden_dim
        self.lstm_hidden = self.hidden_size // 2

        self.lstm = nn.LSTM(self.input_size, self.lstm_hidden, num_layers=1, batch_first=True, bidirectional=True)

        self.attn = DotAttentionLayer(self.hidden_size)

        self.hidden = nn.Linear(self.hidden_size, self.hidden_size)
        self.out = nn.Linear(self.hidden_size, num_classes)
        self.dropout = nn.Dropout(opt.dropout)
Example #9
class BiLSTM_Attn(nn.Module):
    def __init__(self, vocab, num_classes, char_alphabet):
        super(BiLSTM_Attn, self).__init__()

        self.embed_size = opt.word_emb_size
        self.embedding = vocab.init_embed_layer()
        self.hidden_size = opt.hidden_size
        self.char_hidden_dim = 10
        self.char_embedding_dim = 20
        self.char_feature = CharBiLSTM(len(char_alphabet), None, self.char_embedding_dim, self.char_hidden_dim,
                                       opt.dropout, opt.gpu)

        self.input_size = self.embed_size + self.char_hidden_dim
        self.lstm_hidden = self.hidden_size // 2

        self.lstm = nn.LSTM(self.input_size, self.lstm_hidden, num_layers=1, batch_first=True, bidirectional=True)

        self.attn = DotAttentionLayer(self.hidden_size)

        self.hidden = nn.Linear(self.hidden_size, self.hidden_size)
        self.out = nn.Linear(self.hidden_size, num_classes)
        self.dropout = nn.Dropout(opt.dropout)

    def forward(self, input, char_inputs):
        entity_words, entity_lengths, entity_seq_recover = input

        entity_words_embeds = self.embedding(entity_words)
        batch_size, max_len, _ = entity_words_embeds.size()

        char_inputs, char_seq_lengths, char_seq_recover = char_inputs
        char_features = self.char_feature.get_last_hiddens(char_inputs, char_seq_lengths)
        char_features = char_features[char_seq_recover]
        char_features = char_features.view(batch_size, max_len, -1)

        input_embeds = torch.cat((entity_words_embeds, char_features), 2)

        packed_words = pack_padded_sequence(input_embeds, entity_lengths.cpu().numpy(), True)
        hidden = None
        lstm_out, hidden = self.lstm(packed_words, hidden)
        lstm_out, _ = pad_packed_sequence(lstm_out)

        output = self.attn(lstm_out.transpose(1, 0), entity_lengths)

        output = self.dropout(output)

        output = self.hidden(output)
        output = self.out(output)
        return output
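
The forward pass above relies on pack_padded_sequence / pad_packed_sequence to run the LSTM over variable-length batches. A minimal sketch of that pattern with made-up sizes (sequences sorted by descending length, as the packing utilities expect by default):

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

lstm = nn.LSTM(input_size=8, hidden_size=16, num_layers=1,
               batch_first=True, bidirectional=True)
embeds = torch.randn(3, 5, 8)                  # (batch, max_len, input_size), already padded
lengths = torch.tensor([5, 4, 2])              # descending order

packed = pack_padded_sequence(embeds, lengths, batch_first=True)
packed_out, _ = lstm(packed, None)
out, out_lens = pad_packed_sequence(packed_out, batch_first=True)
print(out.shape)                               # torch.Size([3, 5, 32]): 2 directions x hidden_size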
Example #10
    def __init__(self, config, pretrained_dep_model: nn.Module = None):
        super(NNCRF, self).__init__()

        self.label_size = config.label_size
        self.device = config.device
        self.use_char = config.use_char_rnn
        self.dep_model = config.dep_model
        self.context_emb = config.context_emb
        self.interaction_func = config.interaction_func

        self.label2idx = config.label2idx
        self.labels = config.idx2labels
        self.start_idx = self.label2idx[START]
        self.end_idx = self.label2idx[STOP]
        self.pad_idx = self.label2idx[PAD]

        self.input_size = config.embedding_dim

        if self.use_char:
            self.char_feature = CharBiLSTM(config)
            self.input_size += config.charlstm_hidden_dim

        vocab_size = len(config.word2idx)
        self.word_embedding = nn.Embedding.from_pretrained(
            torch.FloatTensor(config.word_embedding),
            freeze=False).to(self.device)
        self.word_drop = nn.Dropout(config.dropout).to(self.device)

        if self.dep_model == DepModelType.dglstm and self.interaction_func == InteractionFunction.mlp:
            self.mlp_layers = nn.ModuleList()
            for i in range(config.num_lstm_layer - 1):
                self.mlp_layers.append(
                    nn.Linear(config.hidden_dim,
                              2 * config.hidden_dim).to(self.device))
            self.mlp_head_linears = nn.ModuleList()
            for i in range(config.num_lstm_layer - 1):
                self.mlp_head_linears.append(
                    nn.Linear(config.hidden_dim,
                              2 * config.hidden_dim).to(self.device))
        """
            Input size to LSTM description
        """
        self.charlstm_dim = config.charlstm_hidden_dim
        if self.dep_model == DepModelType.dglstm:
            self.input_size += config.embedding_dim + config.dep_emb_size
            if self.use_char:
                self.input_size += config.charlstm_hidden_dim

        if self.context_emb != ContextEmb.none:
            self.input_size += config.context_emb_size

        print("[Model Info] Input size to LSTM: {}".format(self.input_size))
        print("[Model Info] LSTM Hidden Size: {}".format(config.hidden_dim))

        num_layers = 1
        if config.num_lstm_layer > 1 and self.dep_model != DepModelType.dglstm:
            num_layers = config.num_lstm_layer
        if config.num_lstm_layer > 0:
            self.lstm = nn.LSTM(self.input_size,
                                config.hidden_dim // 2,
                                num_layers=num_layers,
                                batch_first=True,
                                bidirectional=True).to(self.device)

        self.num_lstm_layer = config.num_lstm_layer
        self.lstm_hidden_dim = config.hidden_dim
        self.embedding_dim = config.embedding_dim
        if config.num_lstm_layer > 1 and self.dep_model == DepModelType.dglstm:
            self.add_lstms = nn.ModuleList()
            if self.interaction_func == InteractionFunction.concatenation or \
                    self.interaction_func == InteractionFunction.mlp:
                hidden_size = 2 * config.hidden_dim
            elif self.interaction_func == InteractionFunction.addition:
                hidden_size = config.hidden_dim

            print(
                "[Model Info] Building {} more LSTMs, with size: {} x {} (without dep label highway connection)"
                .format(config.num_lstm_layer - 1, hidden_size,
                        config.hidden_dim))
            for i in range(config.num_lstm_layer - 1):
                self.add_lstms.append(
                    nn.LSTM(hidden_size,
                            config.hidden_dim // 2,
                            num_layers=1,
                            batch_first=True,
                            bidirectional=True).to(self.device))

        self.drop_lstm = nn.Dropout(config.dropout).to(self.device)

        final_hidden_dim = config.hidden_dim if self.num_lstm_layer > 0 else self.input_size
        """
        Model description
        """
        print("[Model Info] Dep Method: {}, hidden size: {}".format(
            self.dep_model.name, config.dep_hidden_dim))
        if self.dep_model != DepModelType.none:
            print("Initializing the dependency label embedding")
            self.pretrain_dep = config.pretrain_dep
            if config.pretrain_dep:
                self.dep_label_embedding = pretrained_dep_model.to(self.device)
                self.dep_label_embedding.train()
                if config.freeze:
                    for param in self.dep_label_embedding.parameters():
                        param.requires_grad = False
            else:
                self.dep_label_embedding = nn.Embedding(
                    len(config.deplabel2idx),
                    config.dep_emb_size).to(self.device)
                self.dep_label_embedding.weight.requires_grad = True if config.freeze else False
            if self.dep_model == DepModelType.dggcn:
                self.gcn = DepLabeledGCN(
                    config, config.hidden_dim)  ### lstm hidden size
                final_hidden_dim = config.dep_hidden_dim

        print("[Model Info] Final Hidden Size: {}".format(final_hidden_dim))
        self.hidden2tag = nn.Linear(final_hidden_dim,
                                    self.label_size).to(self.device)

        init_transition = torch.randn(self.label_size,
                                      self.label_size).to(self.device)
        init_transition[:, self.start_idx] = -10000.0
        init_transition[self.end_idx, :] = -10000.0
        init_transition[:, self.pad_idx] = -10000.0
        init_transition[self.pad_idx, :] = -10000.0

        self.transition = nn.Parameter(init_transition)
Example #11
class NNCRF(nn.Module):
    def __init__(self, config, pretrained_dep_model: nn.Module = None):
        super(NNCRF, self).__init__()

        self.label_size = config.label_size
        self.device = config.device
        self.use_char = config.use_char_rnn
        self.dep_model = config.dep_model
        self.context_emb = config.context_emb
        self.interaction_func = config.interaction_func

        self.label2idx = config.label2idx
        self.labels = config.idx2labels
        self.start_idx = self.label2idx[START]
        self.end_idx = self.label2idx[STOP]
        self.pad_idx = self.label2idx[PAD]

        self.input_size = config.embedding_dim

        if self.use_char:
            self.char_feature = CharBiLSTM(config)
            self.input_size += config.charlstm_hidden_dim

        vocab_size = len(config.word2idx)
        self.word_embedding = nn.Embedding.from_pretrained(
            torch.FloatTensor(config.word_embedding),
            freeze=False).to(self.device)
        self.word_drop = nn.Dropout(config.dropout).to(self.device)

        if self.dep_model == DepModelType.dglstm and self.interaction_func == InteractionFunction.mlp:
            self.mlp_layers = nn.ModuleList()
            for i in range(config.num_lstm_layer - 1):
                self.mlp_layers.append(
                    nn.Linear(config.hidden_dim,
                              2 * config.hidden_dim).to(self.device))
            self.mlp_head_linears = nn.ModuleList()
            for i in range(config.num_lstm_layer - 1):
                self.mlp_head_linears.append(
                    nn.Linear(config.hidden_dim,
                              2 * config.hidden_dim).to(self.device))
        """
            Input size to LSTM description
        """
        self.charlstm_dim = config.charlstm_hidden_dim
        if self.dep_model == DepModelType.dglstm:
            self.input_size += config.embedding_dim + config.dep_emb_size
            if self.use_char:
                self.input_size += config.charlstm_hidden_dim

        if self.context_emb != ContextEmb.none:
            self.input_size += config.context_emb_size

        print("[Model Info] Input size to LSTM: {}".format(self.input_size))
        print("[Model Info] LSTM Hidden Size: {}".format(config.hidden_dim))

        num_layers = 1
        if config.num_lstm_layer > 1 and self.dep_model != DepModelType.dglstm:
            num_layers = config.num_lstm_layer
        if config.num_lstm_layer > 0:
            self.lstm = nn.LSTM(self.input_size,
                                config.hidden_dim // 2,
                                num_layers=num_layers,
                                batch_first=True,
                                bidirectional=True).to(self.device)

        self.num_lstm_layer = config.num_lstm_layer
        self.lstm_hidden_dim = config.hidden_dim
        self.embedding_dim = config.embedding_dim
        if config.num_lstm_layer > 1 and self.dep_model == DepModelType.dglstm:
            self.add_lstms = nn.ModuleList()
            if self.interaction_func == InteractionFunction.concatenation or \
                    self.interaction_func == InteractionFunction.mlp:
                hidden_size = 2 * config.hidden_dim
            elif self.interaction_func == InteractionFunction.addition:
                hidden_size = config.hidden_dim

            print(
                "[Model Info] Building {} more LSTMs, with size: {} x {} (without dep label highway connection)"
                .format(config.num_lstm_layer - 1, hidden_size,
                        config.hidden_dim))
            for i in range(config.num_lstm_layer - 1):
                self.add_lstms.append(
                    nn.LSTM(hidden_size,
                            config.hidden_dim // 2,
                            num_layers=1,
                            batch_first=True,
                            bidirectional=True).to(self.device))

        self.drop_lstm = nn.Dropout(config.dropout).to(self.device)

        final_hidden_dim = config.hidden_dim if self.num_lstm_layer > 0 else self.input_size
        """
        Model description
        """
        print("[Model Info] Dep Method: {}, hidden size: {}".format(
            self.dep_model.name, config.dep_hidden_dim))
        if self.dep_model != DepModelType.none:
            print("Initializing the dependency label embedding")
            self.pretrain_dep = config.pretrain_dep
            if config.pretrain_dep:
                self.dep_label_embedding = pretrained_dep_model.to(self.device)
                self.dep_label_embedding.train()
                if config.freeze:
                    for param in self.dep_label_embedding.parameters():
                        param.requires_grad = False
            else:
                self.dep_label_embedding = nn.Embedding(
                    len(config.deplabel2idx),
                    config.dep_emb_size).to(self.device)
                self.dep_label_embedding.weight.requires_grad = True if config.freeze else False
            if self.dep_model == DepModelType.dggcn:
                self.gcn = DepLabeledGCN(
                    config, config.hidden_dim)  ### lstm hidden size
                final_hidden_dim = config.dep_hidden_dim

        print("[Model Info] Final Hidden Size: {}".format(final_hidden_dim))
        self.hidden2tag = nn.Linear(final_hidden_dim,
                                    self.label_size).to(self.device)

        init_transition = torch.randn(self.label_size,
                                      self.label_size).to(self.device)
        init_transition[:, self.start_idx] = -10000.0
        init_transition[self.end_idx, :] = -10000.0
        init_transition[:, self.pad_idx] = -10000.0
        init_transition[self.pad_idx, :] = -10000.0

        self.transition = nn.Parameter(init_transition)

    def neural_scoring(self,
                       word_seq_tensor,
                       word_seq_lens,
                       batch_context_emb,
                       char_inputs,
                       char_seq_lens,
                       adj_matrixs,
                       adjs_in,
                       adjs_out,
                       graphs,
                       dep_label_adj,
                       dep_head_tensor,
                       dep_label_tensor,
                       trees=None):
        """
        :param word_seq_tensor: (batch_size, sent_len)   NOTE: the word sequences are already sorted before they reach this method.
        :param word_seq_lens: (batch_size, 1)
        :param chars: (batch_size * sent_len * word_length)
        :param char_seq_lens: numpy (batch_size * sent_len , 1)
        :param dep_label_tensor: (batch_size, max_sent_len)
        :return: emission scores (batch_size, sent_len, hidden_dim)
        """
        batch_size = word_seq_tensor.size(0)
        sent_len = word_seq_tensor.size(1)

        word_emb = self.word_embedding(word_seq_tensor)
        if self.use_char:
            if self.dep_model == DepModelType.dglstm:
                char_features = self.char_feature.get_last_hiddens(
                    char_inputs, char_seq_lens)
                word_emb = torch.cat((word_emb, char_features), 2)
        if self.dep_model == DepModelType.dglstm:
            size = self.embedding_dim if not self.use_char else (
                self.embedding_dim + self.charlstm_dim)
            dep_head_emb = torch.gather(
                word_emb, 1,
                dep_head_tensor.view(batch_size, sent_len,
                                     1).expand(batch_size, sent_len, size))

        if self.context_emb != ContextEmb.none:
            word_emb = torch.cat((word_emb, batch_context_emb.to(self.device)),
                                 2)

        if self.use_char:
            if self.dep_model != DepModelType.dglstm:
                char_features = self.char_feature.get_last_hiddens(
                    char_inputs, char_seq_lens)
                word_emb = torch.cat((word_emb, char_features), 2)
        """
          Word Representation
        """
        if self.dep_model == DepModelType.dglstm:
            if self.pretrain_dep:
                adj_matrixs = adj_matrixs.to(self.device)
                dep_emb = self.dep_label_embedding.inference(
                    adj_matrixs, dep_label_tensor)
            else:
                dep_emb = self.dep_label_embedding(dep_label_tensor)
            word_emb = torch.cat((word_emb, dep_head_emb, dep_emb), 2)

        word_rep = self.word_drop(word_emb)

        sorted_seq_len, permIdx = word_seq_lens.sort(0, descending=True)
        _, recover_idx = permIdx.sort(0, descending=False)
        sorted_seq_tensor = word_rep[permIdx]

        if self.num_lstm_layer > 0:
            packed_words = pack_padded_sequence(sorted_seq_tensor,
                                                sorted_seq_len, True)
            lstm_out, _ = self.lstm(packed_words, None)
            lstm_out, _ = pad_packed_sequence(
                lstm_out, batch_first=True
            )  ## CARE: the output is batch_first here; otherwise it would need to be transposed.
            feature_out = self.drop_lstm(lstm_out)
        else:
            feature_out = sorted_seq_tensor
        """
        Higher order interactions
        """
        if self.num_lstm_layer > 1 and (self.dep_model == DepModelType.dglstm):
            for l in range(self.num_lstm_layer - 1):
                dep_head_emb = torch.gather(
                    feature_out, 1, dep_head_tensor[permIdx].view(
                        batch_size, sent_len,
                        1).expand(batch_size, sent_len, self.lstm_hidden_dim))
                if self.interaction_func == InteractionFunction.concatenation:
                    feature_out = torch.cat((feature_out, dep_head_emb), 2)
                elif self.interaction_func == InteractionFunction.addition:
                    feature_out = feature_out + dep_head_emb
                elif self.interaction_func == InteractionFunction.mlp:
                    feature_out = F.relu(self.mlp_layers[l](feature_out) +
                                         self.mlp_head_linears[l]
                                         (dep_head_emb))

                packed_words = pack_padded_sequence(feature_out,
                                                    sorted_seq_len, True)
                lstm_out, _ = self.add_lstms[l](packed_words, None)
                lstm_out, _ = pad_packed_sequence(
                    lstm_out, batch_first=True
                )  ## CARE: the output is batch_first here; otherwise it would need to be transposed.
                feature_out = self.drop_lstm(lstm_out)
        """
        Model forward if we have GCN
        """
        if self.dep_model == DepModelType.dggcn:
            feature_out = self.gcn(feature_out, sorted_seq_len,
                                   adj_matrixs[permIdx],
                                   dep_label_adj[permIdx])

        outputs = self.hidden2tag(feature_out)

        return outputs[recover_idx]

    def calculate_all_scores(self, features):
        batch_size = features.size(0)
        seq_len = features.size(1)
        scores = self.transition.view(1, 1, self.label_size, self.label_size).expand(batch_size, seq_len, self.label_size, self.label_size) + \
                    features.view(batch_size, seq_len, 1, self.label_size).expand(batch_size,seq_len,self.label_size, self.label_size)
        return scores

    def forward_unlabeled(self, all_scores, word_seq_lens, masks):
        batch_size = all_scores.size(0)
        seq_len = all_scores.size(1)
        alpha = torch.zeros(batch_size, seq_len,
                            self.label_size).to(self.device)

        ## the first position of all labels = (the transition from START -> all labels) + the current emission.
        alpha[:, 0, :] = all_scores[:, 0, self.start_idx, :]

        for word_idx in range(1, seq_len):
            ## batch_size, self.label_size, self.label_size
            before_log_sum_exp = alpha[:, word_idx - 1, :].view(
                batch_size, self.label_size,
                1).expand(batch_size, self.label_size,
                          self.label_size) + all_scores[:, word_idx, :, :]
            alpha[:, word_idx, :] = log_sum_exp_pytorch(before_log_sum_exp)

        ### batch_size x label_size
        last_alpha = torch.gather(
            alpha, 1,
            word_seq_lens.view(batch_size, 1, 1).expand(
                batch_size, 1, self.label_size) - 1).view(
                    batch_size, self.label_size)
        last_alpha += self.transition[:, self.end_idx].view(
            1, self.label_size).expand(batch_size, self.label_size)
        last_alpha = log_sum_exp_pytorch(
            last_alpha.view(batch_size, self.label_size, 1)).view(batch_size)

        return torch.sum(last_alpha)

    def forward_labeled(self, all_scores, word_seq_lens, tags, masks):
        '''
        :param all_scores: (batch, seq_len, label_size, label_size)
        :param word_seq_lens: (batch, seq_len)
        :param tags: (batch, seq_len)
        :param masks: batch, seq_len
        :return: sum of score for the gold sequences
        '''
        batchSize = all_scores.shape[0]
        sentLength = all_scores.shape[1]

        ## all the scores to current labels: batch, seq_len, all_from_label?
        currentTagScores = torch.gather(
            all_scores, 3,
            tags.view(batchSize, sentLength, 1,
                      1).expand(batchSize, sentLength, self.label_size,
                                1)).view(batchSize, -1, self.label_size)
        if sentLength != 1:
            tagTransScoresMiddle = torch.gather(
                currentTagScores[:, 1:, :], 2,
                tags[:, :sentLength - 1].view(batchSize, sentLength - 1,
                                              1)).view(batchSize, -1)
        tagTransScoresBegin = currentTagScores[:, 0, self.start_idx]
        endTagIds = torch.gather(tags, 1, word_seq_lens.view(batchSize, 1) - 1)
        tagTransScoresEnd = torch.gather(
            self.transition[:, self.end_idx].view(1, self.label_size).expand(
                batchSize, self.label_size), 1, endTagIds).view(batchSize)
        score = torch.sum(tagTransScoresBegin) + torch.sum(tagTransScoresEnd)
        if sentLength != 1:
            score += torch.sum(tagTransScoresMiddle.masked_select(masks[:,
                                                                        1:]))
        return score

    def neg_log_obj(self,
                    words,
                    word_seq_lens,
                    batch_context_emb,
                    chars,
                    char_seq_lens,
                    adj_matrixs,
                    adjs_in,
                    adjs_out,
                    graphs,
                    dep_label_adj,
                    batch_dep_heads,
                    tags,
                    batch_dep_label,
                    trees=None):
        features = self.neural_scoring(words, word_seq_lens, batch_context_emb,
                                       chars, char_seq_lens, adj_matrixs,
                                       adjs_in, adjs_out, graphs,
                                       dep_label_adj, batch_dep_heads,
                                       batch_dep_label, trees)

        all_scores = self.calculate_all_scores(features)

        batch_size = words.size(0)
        sent_len = words.size(1)

        maskTemp = torch.arange(1, sent_len + 1, dtype=torch.long).view(
            1, sent_len).expand(batch_size, sent_len).to(self.device)
        mask = torch.le(
            maskTemp,
            word_seq_lens.view(batch_size, 1).expand(batch_size,
                                                     sent_len)).to(self.device)

        unlabed_score = self.forward_unlabeled(all_scores, word_seq_lens, mask)
        labeled_score = self.forward_labeled(all_scores, word_seq_lens, tags,
                                             mask)
        return unlabed_score - labeled_score

    def viterbiDecode(self, all_scores, word_seq_lens):
        batchSize = all_scores.shape[0]
        sentLength = all_scores.shape[1]
        # sent_len =
        scoresRecord = torch.zeros([batchSize, sentLength,
                                    self.label_size]).to(self.device)
        idxRecord = torch.zeros([batchSize, sentLength, self.label_size],
                                dtype=torch.int64).to(self.device)
        mask = torch.ones_like(word_seq_lens,
                               dtype=torch.int64).to(self.device)
        startIds = torch.full((batchSize, self.label_size),
                              self.start_idx,
                              dtype=torch.int64).to(self.device)
        decodeIdx = torch.LongTensor(batchSize, sentLength).to(self.device)

        scores = all_scores
        # scoresRecord[:, 0, :] = self.getInitAlphaWithBatchSize(batchSize).view(batchSize, self.label_size)
        ## the best score at position 0 for each label: the transition from START plus the emission at position 0
        scoresRecord[:, 0, :] = scores[:, 0, self.start_idx, :]
        idxRecord[:, 0, :] = startIds
        for wordIdx in range(1, sentLength):
            ### scoresIdx: batch x from_label x to_label at current index.
            scoresIdx = scoresRecord[:, wordIdx - 1, :].view(
                batchSize, self.label_size,
                1).expand(batchSize, self.label_size,
                          self.label_size) + scores[:, wordIdx, :, :]
            idxRecord[:, wordIdx, :] = torch.argmax(
                scoresIdx, 1)  ## the best previous label idx for the current labels
            scoresRecord[:, wordIdx, :] = torch.gather(
                scoresIdx, 1, idxRecord[:, wordIdx, :].view(
                    batchSize, 1,
                    self.label_size)).view(batchSize, self.label_size)

        lastScores = torch.gather(
            scoresRecord, 1,
            word_seq_lens.view(batchSize, 1, 1).expand(
                batchSize, 1, self.label_size) - 1).view(
                    batchSize, self.label_size)  ##select position
        lastScores += self.transition[:, self.end_idx].view(
            1, self.label_size).expand(batchSize, self.label_size)
        decodeIdx[:, 0] = torch.argmax(lastScores, 1)
        bestScores = torch.gather(lastScores, 1,
                                  decodeIdx[:, 0].view(batchSize, 1))

        for distance2Last in range(sentLength - 1):
            lastNIdxRecord = torch.gather(
                idxRecord, 1,
                torch.where(word_seq_lens - distance2Last - 1 > 0,
                            word_seq_lens - distance2Last - 1,
                            mask).view(batchSize, 1, 1).expand(
                                batchSize, 1,
                                self.label_size)).view(batchSize,
                                                       self.label_size)
            decodeIdx[:, distance2Last + 1] = torch.gather(
                lastNIdxRecord, 1,
                decodeIdx[:, distance2Last].view(batchSize, 1)).view(batchSize)

        return bestScores, decodeIdx

    def decode(self, batchInput):
        wordSeqTensor, wordSeqLengths, batch_context_emb, charSeqTensor, charSeqLengths, adj_matrixs, adjs_in, adjs_out, graphs, dep_label_adj, batch_dep_heads, trees, tagSeqTensor, batch_dep_label = batchInput
        features = self.neural_scoring(wordSeqTensor, wordSeqLengths,
                                       batch_context_emb, charSeqTensor,
                                       charSeqLengths, adj_matrixs, adjs_in,
                                       adjs_out, graphs, dep_label_adj,
                                       batch_dep_heads, batch_dep_label, trees)
        all_scores = self.calculate_all_scores(features)
        bestScores, decodeIdx = self.viterbiDecode(all_scores, wordSeqLengths)
        return bestScores, decodeIdx
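
forward_unlabeled above calls a log_sum_exp_pytorch helper that is not included in the snippet. Under the assumption that it performs a numerically stable log-sum-exp over the from-label dimension of a (batch, label_size, label_size) tensor, a sketch could look like this:

import torch

def log_sum_exp_pytorch(vec: torch.Tensor) -> torch.Tensor:
    # vec: (batch_size, from_label, to_label) -> (batch_size, to_label)
    # Subtracting the per-column maximum keeps exp() from overflowing.
    max_scores, _ = vec.max(dim=1, keepdim=True)
    return (vec - max_scores).exp().sum(dim=1).log() + max_scores.squeeze(1)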
Example #12
    def __init__(self, data):
        super(WordRep, self).__init__()
        print("build word representation...")
        self.gpu = data.HP_gpu
        self.use_char = data.use_char
        self.use_trans = data.use_trans
        self.batch_size = data.HP_batch_size
        self.char_hidden_dim = 0
        self.char_all_feature = False
        self.w = nn.Linear(data.word_emb_dim, data.HP_trans_hidden_dim)

        if self.use_trans:
            self.trans_hidden_dim = data.HP_trans_hidden_dim
            self.trans_embedding_dim = data.trans_emb_dim
            self.trans_feature = TransBiLSTM(data.translation_alphabet.size(),
                                             self.trans_embedding_dim,
                                             self.trans_hidden_dim,
                                             data.HP_dropout,
                                             data.pretrain_trans_embedding,
                                             self.gpu)

        if self.use_char:
            self.char_hidden_dim = data.HP_char_hidden_dim
            self.char_embedding_dim = data.char_emb_dim
            if data.char_seq_feature == "CNN":
                self.char_feature = CharCNN(data.char_alphabet.size(),
                                            self.char_embedding_dim,
                                            self.char_hidden_dim,
                                            data.HP_dropout, self.gpu)
            elif data.char_seq_feature == "LSTM":
                self.char_feature = CharBiLSTM(data.char_alphabet.size(),
                                               self.char_embedding_dim,
                                               self.char_hidden_dim,
                                               data.HP_dropout,
                                               data.pretrain_char_embedding,
                                               self.gpu)
            elif data.char_seq_feature == "GRU":
                self.char_feature = CharBiGRU(data.char_alphabet.size(),
                                              self.char_embedding_dim,
                                              self.char_hidden_dim,
                                              data.HP_dropout, self.gpu)
            elif data.char_seq_feature == "ALL":
                self.char_all_feature = True
                self.char_feature = CharCNN(data.char_alphabet.size(),
                                            self.char_embedding_dim,
                                            self.char_hidden_dim,
                                            data.HP_dropout, self.gpu)
                self.char_feature_extra = CharBiLSTM(data.char_alphabet.size(),
                                                     self.char_embedding_dim,
                                                     self.char_hidden_dim,
                                                     data.HP_dropout, self.gpu)
            else:
                print(
                    "Error char feature selection, please check parameter data.char_seq_feature (CNN/LSTM/GRU/ALL)."
                )
                exit(0)
        self.embedding_dim = data.word_emb_dim
        self.drop = nn.Dropout(data.HP_dropout)
        self.word_embedding = nn.Embedding(data.word_alphabet.size(),
                                           self.embedding_dim)
        if data.pretrain_word_embedding is not None:
            self.word_embedding.weight.data.copy_(
                torch.from_numpy(data.pretrain_word_embedding))
        else:
            self.word_embedding.weight.data.copy_(
                torch.from_numpy(
                    self.random_embedding(data.word_alphabet.size(),
                                          self.embedding_dim)))

        self.feature_num = data.feature_num
        self.feature_embedding_dims = data.feature_emb_dims
        self.feature_embeddings = nn.ModuleList()
        for idx in range(self.feature_num):
            self.feature_embeddings.append(
                nn.Embedding(data.feature_alphabets[idx].size(),
                             self.feature_embedding_dims[idx]))
        for idx in range(self.feature_num):
            if data.pretrain_feature_embeddings[idx] is not None:
                self.feature_embeddings[idx].weight.data.copy_(
                    torch.from_numpy(data.pretrain_feature_embeddings[idx]))
            else:
                self.feature_embeddings[idx].weight.data.copy_(
                    torch.from_numpy(
                        self.random_embedding(
                            data.feature_alphabets[idx].size(),
                            self.feature_embedding_dims[idx])))

        if self.gpu:
            self.drop = self.drop.cuda()
            self.word_embedding = self.word_embedding.cuda()
            for idx in range(self.feature_num):
                self.feature_embeddings[idx] = self.feature_embeddings[
                    idx].cuda()
Example #13
class WordRep(nn.Module):
    def __init__(self, data):
        super(WordRep, self).__init__()
        print("build word representation...")
        self.gpu = data.HP_gpu
        self.use_char = data.use_char
        self.use_trans = data.use_trans
        self.batch_size = data.HP_batch_size
        self.char_hidden_dim = 0
        self.char_all_feature = False
        self.w = nn.Linear(data.word_emb_dim, data.HP_trans_hidden_dim)

        if self.use_trans:
            self.trans_hidden_dim = data.HP_trans_hidden_dim
            self.trans_embedding_dim = data.trans_emb_dim
            self.trans_feature = TransBiLSTM(data.translation_alphabet.size(),
                                             self.trans_embedding_dim,
                                             self.trans_hidden_dim,
                                             data.HP_dropout,
                                             data.pretrain_trans_embedding,
                                             self.gpu)

        if self.use_char:
            self.char_hidden_dim = data.HP_char_hidden_dim
            self.char_embedding_dim = data.char_emb_dim
            if data.char_seq_feature == "CNN":
                self.char_feature = CharCNN(data.char_alphabet.size(),
                                            self.char_embedding_dim,
                                            self.char_hidden_dim,
                                            data.HP_dropout, self.gpu)
            elif data.char_seq_feature == "LSTM":
                self.char_feature = CharBiLSTM(data.char_alphabet.size(),
                                               self.char_embedding_dim,
                                               self.char_hidden_dim,
                                               data.HP_dropout,
                                               data.pretrain_char_embedding,
                                               self.gpu)
            elif data.char_seq_feature == "GRU":
                self.char_feature = CharBiGRU(data.char_alphabet.size(),
                                              self.char_embedding_dim,
                                              self.char_hidden_dim,
                                              data.HP_dropout, self.gpu)
            elif data.char_seq_feature == "ALL":
                self.char_all_feature = True
                self.char_feature = CharCNN(data.char_alphabet.size(),
                                            self.char_embedding_dim,
                                            self.char_hidden_dim,
                                            data.HP_dropout, self.gpu)
                self.char_feature_extra = CharBiLSTM(data.char_alphabet.size(),
                                                     self.char_embedding_dim,
                                                     self.char_hidden_dim,
                                                     data.HP_dropout, self.gpu)
            else:
                print(
                    "Error char feature selection, please check parameter data.char_seq_feature (CNN/LSTM/GRU/ALL)."
                )
                exit(0)
        self.embedding_dim = data.word_emb_dim
        self.drop = nn.Dropout(data.HP_dropout)
        self.word_embedding = nn.Embedding(data.word_alphabet.size(),
                                           self.embedding_dim)
        if data.pretrain_word_embedding is not None:
            self.word_embedding.weight.data.copy_(
                torch.from_numpy(data.pretrain_word_embedding))
        else:
            self.word_embedding.weight.data.copy_(
                torch.from_numpy(
                    self.random_embedding(data.word_alphabet.size(),
                                          self.embedding_dim)))

        self.feature_num = data.feature_num
        self.feature_embedding_dims = data.feature_emb_dims
        self.feature_embeddings = nn.ModuleList()
        for idx in range(self.feature_num):
            self.feature_embeddings.append(
                nn.Embedding(data.feature_alphabets[idx].size(),
                             self.feature_embedding_dims[idx]))
        for idx in range(self.feature_num):
            if data.pretrain_feature_embeddings[idx] is not None:
                self.feature_embeddings[idx].weight.data.copy_(
                    torch.from_numpy(data.pretrain_feature_embeddings[idx]))
            else:
                self.feature_embeddings[idx].weight.data.copy_(
                    torch.from_numpy(
                        self.random_embedding(
                            data.feature_alphabets[idx].size(),
                            self.feature_embedding_dims[idx])))

        if self.gpu:
            self.drop = self.drop.cuda()
            self.word_embedding = self.word_embedding.cuda()
            for idx in range(self.feature_num):
                self.feature_embeddings[idx] = self.feature_embeddings[
                    idx].cuda()

    def random_embedding(self, vocab_size, embedding_dim):
        pretrain_emb = np.empty([vocab_size, embedding_dim])
        scale = np.sqrt(3.0 / embedding_dim)
        for index in range(vocab_size):
            pretrain_emb[index, :] = np.random.uniform(-scale, scale,
                                                       [1, embedding_dim])
        return pretrain_emb

    def forward(self, word_inputs, feature_inputs, word_seq_lengths,
                char_inputs, char_seq_lengths, char_seq_recover, trans_inputs,
                trans_seq_length, trans_seq_recover):
        """
            input:
                word_inputs: (batch_size, sent_len)
                features: list [(batch_size, sent_len), (batch_len, sent_len),...]
                word_seq_lengths: list of batch_size, (batch_size,1)
                char_inputs: (batch_size*sent_len, word_length)
                char_seq_lengths: list of whole batch_size for char, (batch_size*sent_len, 1)
                char_seq_recover: variable which records the char order information, used to recover char order
            output: 
                Variable(batch_size, sent_len, hidden_dim)
        """
        batch_size = word_inputs.size(0)
        sent_len = word_inputs.size(1)
        word_embs = self.word_embedding(word_inputs)
        word_list = [word_embs]

        for idx in range(self.feature_num):
            word_list.append(self.feature_embeddings[idx](feature_inputs[idx]))

        if self.use_char:
            # calculate char lstm last hidden
            char_features, _ = self.char_feature.get_last_hiddens(
                char_inputs,
                char_seq_lengths.cpu().numpy())
            char_features = char_features[char_seq_recover]
            char_features = char_features.view(batch_size, sent_len, -1)
            # concat word and char together
            word_list.append(char_features)
            # word_embs = torch.cat([word_embs, char_features], 2)
            if self.char_all_feature:
                char_features_extra, _ = self.char_feature_extra.get_last_hiddens(
                    char_inputs,
                    char_seq_lengths.cpu().numpy())
                char_features_extra = char_features_extra[char_seq_recover]
                char_features_extra = char_features_extra.view(
                    batch_size, sent_len, -1)
                # concat word and char together
                word_list.append(char_features_extra)

        if self.use_trans:
            trans_features, trans_rnn_length = self.trans_feature.get_last_hiddens(
                trans_inputs,
                trans_seq_length.cpu().numpy())

            trans_features_wc = trans_features
            if self.gpu:
                trans_features_wc = trans_features_wc.cuda()  # .cuda() is not in-place; keep the returned tensor
            trans_features_wc = trans_features_wc[trans_seq_recover]
            trans_inputs = trans_inputs[trans_seq_recover]
            word_embs_temp = word_embs.view(batch_size * sent_len, -1)
            for index, line in enumerate(trans_inputs):
                if line[0].data.cpu().numpy()[0] == 0:
                    trans_features_wc[index] = self.w(word_embs_temp[index])

            trans_features_wc_temp = trans_features_wc
            trans_features_wc = trans_features_wc.view(batch_size, sent_len,
                                                       -1)

            word_list.append(trans_features_wc)

        word_embs = torch.cat(word_list, 2)
        word_represent = self.drop(word_embs)
        return word_represent, self.w(word_embs_temp), trans_features_wc_temp
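
The char branch above computes features on the flattened (batch_size * sent_len) word list, restores the original word order with char_seq_recover, and reshapes back to the sentence layout. A minimal sketch of that sort/recover/reshape bookkeeping with made-up lengths:

import torch

batch_size, sent_len, char_hidden = 2, 3, 4
num_words = batch_size * sent_len
word_lengths = torch.tensor([3, 5, 2, 4, 1, 6])      # hypothetical character lengths per word

# Words are sorted by length (descending) before the char encoder ...
sorted_lengths, perm_idx = word_lengths.sort(0, descending=True)
_, recover_idx = perm_idx.sort(0)                    # inverse permutation, i.e. char_seq_recover

char_features = torch.randn(num_words, char_hidden)  # char encoder output, in sorted order
char_features = char_features[recover_idx]           # back to the original word order
char_features = char_features.view(batch_size, sent_len, -1)
print(char_features.shape)                           # torch.Size([2, 3, 4])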
Example #14
    def __init__(self, config: Config, print_info: bool = True):
        super(BiLSTMEncoder, self).__init__()

        self.label_size = config.label_size
        self.fined_label_size = config.fined_label_size if config.use_fined_labels or config.latent_base else 0
        self.device = config.device
        self.use_char = config.use_char_rnn
        self.context_emb = config.context_emb

        self.label2idx = config.label2idx
        self.labels = config.idx2labels

        self.use_fined_labels = config.use_fined_labels
        self.use_neg_labels = config.use_neg_labels
        self.use_boundary = config.use_boundary
        self.fined_label2idx = config.fined_label2idx
        self.fined_labels = config.idx2fined_labels

        self.latent_base = config.latent_base
        self.latent_base_labels = config.latent_labels

        self.input_size = config.embedding_dim
        if self.context_emb != ContextEmb.none:
            self.input_size += config.context_emb_size
        if self.use_char:
            self.char_feature = CharBiLSTM(config, print_info=print_info)
            self.input_size += config.charlstm_hidden_dim

        self.word_embedding = nn.Embedding.from_pretrained(
            torch.FloatTensor(config.word_embedding),
            freeze=False).to(self.device)
        self.word_drop = nn.Dropout(config.dropout).to(self.device)

        if print_info:
            print("[Model Info] Input size to LSTM: {}".format(
                self.input_size))
            print("[Model Info] LSTM Hidden Size: {}".format(
                config.hidden_dim))

        self.lstm = nn.LSTM(self.input_size,
                            config.hidden_dim // 2,
                            num_layers=1,
                            batch_first=True,
                            bidirectional=True).to(self.device)

        self.drop_lstm = nn.Dropout(config.dropout).to(self.device)

        final_hidden_dim = config.hidden_dim

        if print_info:
            print(
                "[Model Info] Final Hidden Size: {}".format(final_hidden_dim))

        tag_size = self.fined_label_size if self.use_fined_labels or self.latent_base else self.label_size
        self.hidden2tag = nn.Linear(final_hidden_dim, tag_size).to(self.device)
        if self.use_fined_labels or self.latent_base:
            """
            The idea here is to build many mapping weights (where each with size: self.label_size x self.fined_label_size)
            for each type of the hyperedge. 
            (NOTE: we exclude the "O" label from connecting hyperedges because there would be too many combinations and the
            computation cost would be extremely high. For example, "O" would connect to combinations of (B-per-not, I-per-not, B-org-not, ...).)
            
            First step:
               Build a dictionary `self.coarse_label2comb` where key is the coarse label, value is the list of possible combinations 
               of hyperedges for this label. For example, if coarse label = "B-PER"
               self.coarse_label2comb['B-PER'] = [(), ('B-org-not',), ('B-loc-not',), ('B-',), ('B-org-not', 'B-loc-not'),
                                                  ('B-org-not', 'B-'), ('B-loc-not', 'B-'), ('B-org-not', 'B-loc-not', 'B-')]
               (Note that the actual values are indices; strings are used here only for illustration.)
               Thus, each coarse label (excluding 'START', 'END', 'PAD' and 'O') has 8 elements in the list.
               We have a maximum of 8 combinations in the end.
               
               Because the labels 'START', 'END', 'PAD' and 'O' are excluded, the combination list for these labels
               contains only an empty tuple: [()]. (A toy sketch of this first step with itertools.combinations follows this example.)
            Second step (where the magic happens): check out the `init_dense_label_mapping_weight` function.
                   
            """
            auxilary_labels = set(
                [config.START_TAG, config.STOP_TAG, config.PAD,
                 "O"])  ##excluded O labels
            self.coarse_label2comb = {}
            self.max_num_combinations = 0
            start = config.start_num
            for coarse_label in self.label2idx:
                if coarse_label not in auxilary_labels:
                    combs = []
                    valid_indexs = self.find_other_fined_idx(
                        coarse_label, is_latent_base=self.latent_base)
                    ## this commented code is used to test the equivalence with previous implementation. (Also need to remove O in auxilary labels)
                    if config.use_hypergraph and coarse_label != config.O:
                        if config.heuristic:
                            combs = self.find_heuristic_combination(
                                coarse_label=coarse_label,
                                is_latent_base=self.latent_base)
                        else:
                            for num in range(start, len(valid_indexs) + 1):
                                combs += list(
                                    combinations(
                                        self.find_other_fined_idx(
                                            coarse_label), num))
                    else:
                        combs += list(
                            combinations(
                                self.find_other_fined_idx(coarse_label),
                                len(valid_indexs)))
                    self.coarse_label2comb[coarse_label] = combs
                    self.max_num_combinations = max(self.max_num_combinations,
                                                    len(combs))
                else:
                    self.coarse_label2comb[coarse_label] = [()]  ## only the empty combination for the 'START', 'STOP', 'PAD' and 'O' labels
            print(
                colored(
                    f"[Model Info] Maximum number of combination: {self.max_num_combinations}",
                    'red'))
            """
            Second step
            """
            label_mapping_weight, self.weight_mask = self.init_dense_label_mapping_weight()
            ## row should be equal to `self.max_num_combinations` x `self.label_size`
            row = label_mapping_weight.shape[0]
            self.fined2labels = nn.Linear(self.fined_label_size,
                                          row,
                                          bias=False).to(self.device)
            self.inference_method = IF[config.inference_method]

            self.fined2labels.weight.data.copy_(
                torch.from_numpy(label_mapping_weight))
            self.fined2labels.weight.requires_grad = False  # not updating the weight.
            self.fined2labels.zero_grad()
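
A toy illustration of the "first step" described in the docstring above (hypothetical label strings; the real code uses label indices and starts from `config.start_num`), showing how `itertools.combinations` produces the 8 combinations mentioned for a coarse label with three attachable fined labels:

from itertools import combinations

## hypothetical fined labels that may attach to the coarse label 'B-PER'
other_fined = ['B-org-not', 'B-loc-not', 'B-']

combs = []
for num in range(0, len(other_fined) + 1):
    combs += list(combinations(other_fined, num))

print(len(combs))  # 8 = 2**3, matching the "maximum of 8 combinations" above
print(combs[:4])   # [(), ('B-org-not',), ('B-loc-not',), ('B-',)]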
Beispiel #15
0
class NNCRF(nn.Module):
    def __init__(self, config):
        super(NNCRF, self).__init__()
        self.config = config
        self.label_size = config.label_size
        self.device = config.device
        self.use_char = config.use_char_rnn
        self.context_emb = config.context_emb
        '''
        task specific:
        '''
        self.is_base = config.is_base

        self.label2idx = config.label2idx
        self.labels = config.idx2labels
        self.start_idx = self.label2idx[START]
        self.end_idx = self.label2idx[STOP]
        self.pad_idx = self.label2idx[PAD]

        self.input_size = config.embedding_dim

        if config.is_base:
            if self.context_emb != ContextEmb.none:
                self.input_size += config.context_emb_size
            if self.use_char:
                self.char_feature = CharBiLSTM(config)
                self.input_size += config.charlstm_hidden_dim

            vocab_size = len(config.word2idx)
            self.word_embedding = nn.Embedding.from_pretrained(
                torch.FloatTensor(config.word_embedding),
                freeze=False).to(self.device)
            self.word_drop = nn.Dropout(config.dropout).to(self.device)
        else:
            self.input_size = 350

        print("[Model Info] Input size to LSTM: {}".format(self.input_size))
        print("[Model Info] LSTM Hidden Size: {}".format(config.hidden_dim))

        self.lstm = nn.LSTM(self.input_size,
                            config.hidden_dim // 2,
                            num_layers=1,
                            batch_first=True,
                            bidirectional=True).to(self.device)

        self.drop_lstm = nn.Dropout(config.dropout).to(self.device)

        final_hidden_dim = config.hidden_dim

        print("[Model Info] Final Hidden Size: {}".format(final_hidden_dim))
        self.hidden2tag = nn.Linear(final_hidden_dim,
                                    self.label_size).to(self.device)

        init_transition = torch.randn(self.label_size,
                                      self.label_size).to(self.device)
        init_transition[:, self.start_idx] = -10000.0
        init_transition[self.end_idx, :] = -10000.0
        init_transition[:, self.pad_idx] = -10000.0
        init_transition[self.pad_idx, :] = -10000.0

        self.transition = nn.Parameter(init_transition)

    def neural_scoring(self, word_seq_tensor, word_seq_lens, batch_context_emb,
                       char_inputs, char_seq_lens, base_hiddens):
        """
        :param word_seq_tensor: (batch_size, sent_len)   NOTE: the word sequences are already sorted by length before reaching here.
        :param word_seq_lens: (batch_size, 1)
        :param batch_context_emb: (batch_size, sent_len, context_emb_size), only used when context_emb != ContextEmb.none
        :param char_inputs: (batch_size * sent_len, word_length)
        :param char_seq_lens: numpy (batch_size * sent_len, 1)
        :param base_hiddens: hidden representation from the base model, used when is_base is False
        :return: emission scores (batch_size, sent_len, label_size) and the concatenated [word_rep, lstm_out] representation
        """
        if self.is_base:
            batch_size = word_seq_tensor.size(0)
            sent_len = word_seq_tensor.size(1)

            word_emb = self.word_embedding(word_seq_tensor)
            if self.context_emb != ContextEmb.none:
                word_emb = torch.cat(
                    [word_emb, batch_context_emb.to(self.device)], 2)
            if self.use_char:
                char_features = self.char_feature.get_last_hiddens(
                    char_inputs, char_seq_lens)
                word_emb = torch.cat([word_emb, char_features], 2)

            word_rep = self.word_drop(word_emb)
        else:
            word_rep = base_hiddens

        packed_words = pack_padded_sequence(word_rep, word_seq_lens.cpu(), True)  ## recent PyTorch versions require the lengths tensor on CPU
        lstm_out, _ = self.lstm(packed_words, None)
        lstm_out, _ = pad_packed_sequence(
            lstm_out, batch_first=True
        )  ## CARE: batch_first must be set here as well; otherwise the output would need to be transposed.
        feature_out = self.drop_lstm(lstm_out)
        ### TODO: decide whether to apply dropout to this LSTM output; the ABB code does apply dropout.

        outputs = self.hidden2tag(feature_out)
        # return outputs, feature_out
        return outputs, torch.cat([word_rep, feature_out], 2)

    def calculate_all_scores(self, features):
        batch_size = features.size(0)
        seq_len = features.size(1)
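        ## scores[b, t, i, j] = transition score (label i -> label j) + emission score of label j at position t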
        scores = self.transition.view(1, 1, self.label_size, self.label_size).expand(batch_size, seq_len, self.label_size, self.label_size) + \
                    features.view(batch_size, seq_len, 1, self.label_size).expand(batch_size,seq_len,self.label_size, self.label_size)
        return scores

    def forward_unlabeled(self, all_scores, word_seq_lens, masks,
                          mask_datasets):
        batch_size = all_scores.size(0)
        seq_len = all_scores.size(1)
        alpha = torch.zeros(batch_size, seq_len,
                            self.label_size).to(self.device)

        ## first position: alpha[:, 0, label] = transition(START -> label) + emission of that label at position 0
        alpha[:, 0, :] = all_scores[:, 0, self.start_idx, :]
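        ## Forward recursion: alpha[:, t, j] = logsumexp_i( alpha[:, t-1, i] + all_scores[:, t, i, j] )
        ## (see the log_sum_exp sketch after this example for an assumed implementation of log_sum_exp_pytorch)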

        for word_idx in range(1, seq_len):
            ## batch_size, self.label_size, self.label_size
            before_log_sum_exp = alpha[:, word_idx - 1, :].view(
                batch_size, self.label_size,
                1).expand(batch_size, self.label_size,
                          self.label_size) + all_scores[:, word_idx, :, :]
            alpha[:, word_idx, :] = log_sum_exp_pytorch(before_log_sum_exp)

        ### batch_size x label_size
        last_alpha = torch.gather(
            alpha, 1,
            word_seq_lens.view(batch_size, 1, 1).expand(
                batch_size, 1, self.label_size) - 1).view(
                    batch_size, self.label_size)
        last_alpha += self.transition[:, self.end_idx].view(
            1, self.label_size).expand(batch_size, self.label_size)
        last_alpha = log_sum_exp_pytorch(
            last_alpha.view(batch_size, self.label_size, 1)).view(batch_size)
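        ## last_alpha is the log partition function log Z(x) of each sentence; mask_datasets selects
        ## which sentences in the batch contribute to this objective.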

        return torch.sum(last_alpha * mask_datasets)

    def forward_labeled(self, all_scores, word_seq_lens, tags, masks,
                        mask_dataset):
        '''
        :param all_scores: (batch, seq_len, label_size, label_size)
        :param word_seq_lens: (batch, seq_len)
        :param tags: (batch, seq_len)
        :param masks: batch, seq_len
        :return: sum of score for the gold sequences
        '''
        batchSize = all_scores.shape[0]
        sentLength = all_scores.shape[1]

        ## all the scores to current labels: batch, seq_len, all_from_label?
        currentTagScores = torch.gather(
            all_scores, 3,
            tags.view(batchSize, sentLength, 1,
                      1).expand(batchSize, sentLength, self.label_size,
                                1)).view(batchSize, -1, self.label_size)
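        ## currentTagScores[b, t, i] = transition(label i -> gold tag at t) + emission score of the gold tag at t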
        if sentLength != 1:
            tagTransScoresMiddle = torch.gather(
                currentTagScores[:, 1:, :], 2,
                tags[:, :sentLength - 1].view(batchSize, sentLength - 1,
                                              1)).view(batchSize, -1)
        tagTransScoresBegin = currentTagScores[:, 0, self.start_idx]
        endTagIds = torch.gather(tags, 1, word_seq_lens.view(batchSize, 1) - 1)
        tagTransScoresEnd = torch.gather(
            self.transition[:, self.end_idx].view(1, self.label_size).expand(
                batchSize, self.label_size), 1, endTagIds).view(batchSize)
        score = torch.sum(tagTransScoresBegin * mask_dataset) + torch.sum(
            tagTransScoresEnd * mask_dataset)
        if sentLength != 1:
            score += torch.sum(
                (tagTransScoresMiddle *
                 mask_dataset.view(batchSize, 1)).masked_select(masks[:, 1:]))
        return score

    def neg_log_obj(self,
                    words,
                    word_seq_lens,
                    batch_context_emb,
                    chars,
                    char_seq_lens,
                    tags,
                    mask_dataset,
                    base_hiddens=None):
        words = words.to(self.config.device)
        word_seq_lens = word_seq_lens.to(self.config.device)
        chars = chars.to(self.config.device)
        char_seq_lens = char_seq_lens.to(self.config.device)
        tags = tags.to(self.config.device)
        mask_dataset = mask_dataset.to(self.config.device)

        features, feature_out = self.neural_scoring(words, word_seq_lens,
                                                    batch_context_emb, chars,
                                                    char_seq_lens,
                                                    base_hiddens)

        all_scores = self.calculate_all_scores(features)

        batch_size = words.size(0)
        sent_len = words.size(1)

        maskTemp = torch.arange(1, sent_len + 1, dtype=torch.long).view(
            1, sent_len).expand(batch_size, sent_len).to(self.device)
        mask = torch.le(
            maskTemp,
            word_seq_lens.view(batch_size, 1).expand(batch_size,
                                                     sent_len)).to(self.device)

        unlabeled_score = self.forward_unlabeled(all_scores, word_seq_lens,
                                                 mask, mask_dataset)
        labeled_score = self.forward_labeled(all_scores, word_seq_lens, tags,
                                             mask, mask_dataset)
        return unlabeled_score - labeled_score, feature_out

    def viterbiDecode(self, all_scores, word_seq_lens):
        batchSize = all_scores.shape[0]
        sentLength = all_scores.shape[1]
        # sent_len =
        scoresRecord = torch.zeros([batchSize, sentLength,
                                    self.label_size]).to(self.device)
        idxRecord = torch.zeros([batchSize, sentLength, self.label_size],
                                dtype=torch.int64).to(self.device)
        mask = torch.ones_like(word_seq_lens,
                               dtype=torch.int64).to(self.device)
        startIds = torch.full((batchSize, self.label_size),
                              self.start_idx,
                              dtype=torch.int64).to(self.device)
        decodeIdx = torch.LongTensor(batchSize, sentLength).to(self.device)

        scores = all_scores
        # scoresRecord[:, 0, :] = self.getInitAlphaWithBatchSize(batchSize).view(batchSize, self.label_size)
        ## scoresRecord[:, 0, :] = best score of reaching each label at position 0, coming directly from START
        scoresRecord[:, 0, :] = scores[:, 0, self.start_idx, :]
        idxRecord[:, 0, :] = startIds
        for wordIdx in range(1, sentLength):
            ### scoresIdx: batch x from_label x to_label at current index.
            scoresIdx = scoresRecord[:, wordIdx - 1, :].view(
                batchSize, self.label_size,
                1).expand(batchSize, self.label_size,
                          self.label_size) + scores[:, wordIdx, :, :]
            idxRecord[:, wordIdx, :] = torch.argmax(
                scoresIdx, 1)  ## the best previous label index for each current label
            scoresRecord[:, wordIdx, :] = torch.gather(
                scoresIdx, 1, idxRecord[:, wordIdx, :].view(
                    batchSize, 1,
                    self.label_size)).view(batchSize, self.label_size)

        lastScores = torch.gather(
            scoresRecord, 1,
            word_seq_lens.view(batchSize, 1, 1).expand(
                batchSize, 1, self.label_size) - 1).view(
                    batchSize, self.label_size)  ## select the scores at each sentence's last valid position
        lastScores += self.transition[:, self.end_idx].view(
            1, self.label_size).expand(batchSize, self.label_size)
        decodeIdx[:, 0] = torch.argmax(lastScores, 1)
        bestScores = torch.gather(lastScores, 1,
                                  decodeIdx[:, 0].view(batchSize, 1))
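        ## Backtrack: decodeIdx[:, 0] holds the best label of the LAST valid token; each iteration below
        ## recovers the best label one position earlier, so decodeIdx stores the predicted path in reverse order.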

        for distance2Last in range(sentLength - 1):
            lastNIdxRecord = torch.gather(
                idxRecord, 1,
                torch.where(word_seq_lens - distance2Last - 1 > 0,
                            word_seq_lens - distance2Last - 1,
                            mask).view(batchSize, 1, 1).expand(
                                batchSize, 1,
                                self.label_size)).view(batchSize,
                                                       self.label_size)
            decodeIdx[:, distance2Last + 1] = torch.gather(
                lastNIdxRecord, 1,
                decodeIdx[:, distance2Last].view(batchSize, 1)).view(batchSize)

        return bestScores, decodeIdx

    def decode(self, batchInput, hidden_base):
        wordSeqTensor, wordSeqLengths, batch_context_emb, charSeqTensor, charSeqLengths, prefix_label, conll_label, notes_label, mask_base, mask_conll, mask_ontonotes = batchInput
        wordSeqTensor = wordSeqTensor.to(self.config.device)
        wordSeqLengths = wordSeqLengths.to(self.config.device)
        charSeqTensor = charSeqTensor.to(self.config.device)
        charSeqLengths = charSeqLengths.to(self.config.device)
        # prefix_label = prefix_label.to(self.config.device)
        # conll_label = conll_label.to(self.config.device)
        # notes_label = notes_label.to(self.config.device)
        # mask_base = mask_base.to(self.config.device)
        # mask_conll = mask_conll.to(self.config.device)
        # mask_ontonotes = mask_ontonotes.to(self.config.device)
        features, _ = self.neural_scoring(wordSeqTensor, wordSeqLengths,
                                          batch_context_emb, charSeqTensor,
                                          charSeqLengths, hidden_base)
        all_scores = self.calculate_all_scores(features)
        bestScores, decodeIdx = self.viterbiDecode(all_scores, wordSeqLengths)
        return bestScores, decodeIdx
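
`log_sum_exp_pytorch` is imported from elsewhere in the repository and is not shown here. Judging from how it is called above (input of shape (batch_size, label_size, label_size) or (batch_size, label_size, 1), reduced over dim 1), a numerically stable version might look like the following sketch; this is an assumption about its behaviour, not the repository's actual implementation.

import torch

def log_sum_exp(vec: torch.Tensor) -> torch.Tensor:
    """Numerically stable log-sum-exp over dim 1.

    vec: (batch_size, from_label_size, to_label_size)
    returns: (batch_size, to_label_size)
    """
    max_score, _ = torch.max(vec, dim=1)                   # (batch_size, to_label_size)
    max_broadcast = max_score.unsqueeze(1).expand_as(vec)  # subtract the max for stability
    return max_score + torch.log(torch.sum(torch.exp(vec - max_broadcast), dim=1))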
Beispiel #16
0
    def __init__(self, data):
        super(BiLSTM, self).__init__()
        print( "build batched bilstm...")
        self.use_bigram = data.use_bigram
        self.gpu = data.HP_gpu
        self.use_char = data.HP_use_char
        self.use_gaz = data.HP_use_gaz
        self.batch_size = data.HP_batch_size
        self.char_hidden_dim = 0
        if self.use_char:
            self.char_hidden_dim = data.HP_char_hidden_dim
            self.char_embedding_dim = data.char_emb_dim
            if data.char_features == "CNN":
                self.char_feature = CharCNN(data.char_alphabet.size(), self.char_embedding_dim, self.char_hidden_dim, data.HP_dropout, self.gpu)
            elif data.char_features == "LSTM":
                self.char_feature = CharBiLSTM(data.char_alphabet.size(), self.char_embedding_dim, self.char_hidden_dim, data.HP_dropout, self.gpu)
            else:
                print( "Error char feature selection, please check parameter data.char_features (either CNN or LSTM).")
                exit(0)
        self.embedding_dim = data.word_emb_dim
        self.hidden_dim = data.HP_hidden_dim
        self.drop = nn.Dropout(data.HP_dropout)
        self.droplstm = nn.Dropout(data.HP_dropout)
        self.word_embeddings = nn.Embedding(data.word_alphabet.size(), self.embedding_dim)
        self.biword_embeddings = nn.Embedding(data.biword_alphabet.size(), data.biword_emb_dim)
        self.bilstm_flag = data.HP_bilstm
        # self.bilstm_flag = False
        self.lstm_layer = data.HP_lstm_layer
        if data.pretrain_word_embedding is not None:
            self.word_embeddings.weight.data.copy_(torch.from_numpy(data.pretrain_word_embedding))
        else:
            self.word_embeddings.weight.data.copy_(torch.from_numpy(self.random_embedding(data.word_alphabet.size(), self.embedding_dim)))
            
        if data.pretrain_biword_embedding is not None:
            self.biword_embeddings.weight.data.copy_(torch.from_numpy(data.pretrain_biword_embedding))
        else:
            self.biword_embeddings.weight.data.copy_(torch.from_numpy(self.random_embedding(data.biword_alphabet.size(), data.biword_emb_dim)))
        # The LSTM takes word embeddings as inputs, and outputs hidden states
        # with dimensionality hidden_dim.
        
        if self.bilstm_flag:
            lstm_hidden = data.HP_hidden_dim // 2
        else:
            lstm_hidden = data.HP_hidden_dim
        lstm_input = self.embedding_dim + self.char_hidden_dim
        if self.use_bigram:
            lstm_input += data.biword_emb_dim
        print("********************use_lattice",self.use_gaz)
        if self.use_gaz:
            self.forward_lstm = LatticeLSTM(lstm_input, lstm_hidden, data.gaz_dropout, data.gaz_alphabet.size(), data.gaz_emb_dim, data.pretrain_gaz_embedding, True, data.HP_fix_gaz_emb, self.gpu)
            if self.bilstm_flag:
                self.backward_lstm = LatticeLSTM(lstm_input, lstm_hidden, data.gaz_dropout, data.gaz_alphabet.size(), data.gaz_emb_dim, data.pretrain_gaz_embedding, False, data.HP_fix_gaz_emb, self.gpu)
        else:
            self.lstm = nn.LSTM(lstm_input, lstm_hidden, num_layers=self.lstm_layer, batch_first=True, bidirectional=self.bilstm_flag)

        # The linear layer that maps from hidden state space to tag space
        self.hidden2tag = nn.Linear(data.HP_hidden_dim, data.label_alphabet_size)
        self.hidden2tag_ner = nn.Linear(data.HP_hidden_dim, data.label_alphabet_size_ner)
        self.hidden2tag_general = nn.Linear(data.HP_hidden_dim, data.label_alphabet_size_general)

        if self.gpu:
            self.drop = self.drop.cuda()
            self.droplstm = self.droplstm.cuda()
            self.word_embeddings = self.word_embeddings.cuda()
            self.biword_embeddings = self.biword_embeddings.cuda()
            if self.use_gaz:
                self.forward_lstm = self.forward_lstm.cuda()
                if self.bilstm_flag:
                    self.backward_lstm = self.backward_lstm.cuda()
            else:
                self.lstm = self.lstm.cuda()
            self.hidden2tag = self.hidden2tag.cuda()
            self.hidden2tag_ner = self.hidden2tag_ner.cuda()
            self.hidden2tag_general = self.hidden2tag_general.cuda()
Beispiel #17
0
class BiLSTM(nn.Module):
    def __init__(self, data):
        super(BiLSTM, self).__init__()
        print( "build batched bilstm...")
        self.use_bigram = data.use_bigram
        self.gpu = data.HP_gpu
        self.use_char = data.HP_use_char
        self.use_gaz = data.HP_use_gaz
        self.batch_size = data.HP_batch_size
        self.char_hidden_dim = 0
        if self.use_char:
            self.char_hidden_dim = data.HP_char_hidden_dim
            self.char_embedding_dim = data.char_emb_dim
            if data.char_features == "CNN":
                self.char_feature = CharCNN(data.char_alphabet.size(), self.char_embedding_dim, self.char_hidden_dim, data.HP_dropout, self.gpu)
            elif data.char_features == "LSTM":
                self.char_feature = CharBiLSTM(data.char_alphabet.size(), self.char_embedding_dim, self.char_hidden_dim, data.HP_dropout, self.gpu)
            else:
                print( "Error char feature selection, please check parameter data.char_features (either CNN or LSTM).")
                exit(0)
        self.embedding_dim = data.word_emb_dim
        self.hidden_dim = data.HP_hidden_dim
        self.drop = nn.Dropout(data.HP_dropout)
        self.droplstm = nn.Dropout(data.HP_dropout)
        self.word_embeddings = nn.Embedding(data.word_alphabet.size(), self.embedding_dim)
        self.biword_embeddings = nn.Embedding(data.biword_alphabet.size(), data.biword_emb_dim)
        self.bilstm_flag = data.HP_bilstm
        # self.bilstm_flag = False
        self.lstm_layer = data.HP_lstm_layer
        if data.pretrain_word_embedding is not None:
            self.word_embeddings.weight.data.copy_(torch.from_numpy(data.pretrain_word_embedding))
        else:
            self.word_embeddings.weight.data.copy_(torch.from_numpy(self.random_embedding(data.word_alphabet.size(), self.embedding_dim)))
            
        if data.pretrain_biword_embedding is not None:
            self.biword_embeddings.weight.data.copy_(torch.from_numpy(data.pretrain_biword_embedding))
        else:
            self.biword_embeddings.weight.data.copy_(torch.from_numpy(self.random_embedding(data.biword_alphabet.size(), data.biword_emb_dim)))
        # The LSTM takes word embeddings as inputs, and outputs hidden states
        # with dimensionality hidden_dim.
        
        if self.bilstm_flag:
            lstm_hidden = data.HP_hidden_dim // 2
        else:
            lstm_hidden = data.HP_hidden_dim
        lstm_input = self.embedding_dim + self.char_hidden_dim
        if self.use_bigram:
            lstm_input += data.biword_emb_dim
        print("********************use_lattice",self.use_gaz)
        if self.use_gaz:
            self.forward_lstm = LatticeLSTM(lstm_input, lstm_hidden, data.gaz_dropout, data.gaz_alphabet.size(), data.gaz_emb_dim, data.pretrain_gaz_embedding, True, data.HP_fix_gaz_emb, self.gpu)
            if self.bilstm_flag:
                self.backward_lstm = LatticeLSTM(lstm_input, lstm_hidden, data.gaz_dropout, data.gaz_alphabet.size(), data.gaz_emb_dim, data.pretrain_gaz_embedding, False, data.HP_fix_gaz_emb, self.gpu)
        else:
            self.lstm = nn.LSTM(lstm_input, lstm_hidden, num_layers=self.lstm_layer, batch_first=True, bidirectional=self.bilstm_flag)

        # The linear layer that maps from hidden state space to tag space
        self.hidden2tag = nn.Linear(data.HP_hidden_dim, data.label_alphabet_size)
        self.hidden2tag_ner = nn.Linear(data.HP_hidden_dim, data.label_alphabet_size_ner)
        self.hidden2tag_general = nn.Linear(data.HP_hidden_dim, data.label_alphabet_size_general)

        if self.gpu:
            self.drop = self.drop.cuda()
            self.droplstm = self.droplstm.cuda()
            self.word_embeddings = self.word_embeddings.cuda()
            self.biword_embeddings = self.biword_embeddings.cuda()
            if self.use_gaz:
                self.forward_lstm = self.forward_lstm.cuda()
                if self.bilstm_flag:
                    self.backward_lstm = self.backward_lstm.cuda()
            else:
                self.lstm = self.lstm.cuda()
            self.hidden2tag = self.hidden2tag.cuda()
            self.hidden2tag_ner = self.hidden2tag_ner.cuda()
            self.hidden2tag_general = self.hidden2tag_general.cuda()


    def random_embedding(self, vocab_size, embedding_dim):
        pretrain_emb = np.empty([vocab_size, embedding_dim])
        scale = np.sqrt(3.0 / embedding_dim)
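        ## uniform init in [-scale, scale] with scale = sqrt(3 / embedding_dim) gives each dimension variance 1/embedding_dim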
        for index in range(vocab_size):
            pretrain_emb[index,:] = np.random.uniform(-scale, scale, [1, embedding_dim])
        return pretrain_emb


    def get_lstm_features(self, gaz_list, word_inputs, biword_inputs, word_seq_lengths, char_inputs, char_seq_lengths, char_seq_recover):
        """
            input:
                word_inputs: (batch_size, sent_len)
                gaz_list:
                word_seq_lengths: length of each sentence in the batch, (batch_size, 1)
                char_inputs: (batch_size*sent_len, word_length)
                char_seq_lengths: length of every word in the whole batch, (batch_size*sent_len, 1)
                char_seq_recover: indices recording the original word order, used to restore the char features
            output:
                Variable(batch_size, sent_len, hidden_dim)
        """
        batch_size = word_inputs.size(0)
        sent_len = word_inputs.size(1)
        word_embs =  self.word_embeddings(word_inputs)
        if self.use_bigram:
            biword_embs = self.biword_embeddings(biword_inputs)
            word_embs = torch.cat([word_embs, biword_embs],2)
        if self.use_char:
            ## calculate char lstm last hidden
            char_features = self.char_feature.get_last_hiddens(char_inputs, char_seq_lengths.cpu().numpy())
            char_features = char_features[char_seq_recover]
            char_features = char_features.view(batch_size,sent_len,-1)
            ## concat word and char together
            word_embs = torch.cat([word_embs, char_features], 2)
        word_embs = self.drop(word_embs)
        # packed_words = pack_padded_sequence(word_embs, word_seq_lengths.cpu().numpy(), True)
        hidden = None
        if self.use_gaz:
            lstm_out, hidden = self.forward_lstm(word_embs, gaz_list, hidden)
            if self.bilstm_flag:
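                ## run the backward lattice pass as well and concatenate it with the forward pass along the hidden dimension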
                backward_hidden = None 
                backward_lstm_out, backward_hidden = self.backward_lstm(word_embs, gaz_list, backward_hidden)
                lstm_out = torch.cat([lstm_out, backward_lstm_out],2)
        else:
            lstm_out, hidden = self.lstm(word_embs, hidden)
        # lstm_out, _ = pad_packed_sequence(lstm_out)
        lstm_out = self.droplstm(lstm_out)
        return lstm_out



    def get_output_score(self, gaz_list,  word_inputs, biword_inputs, word_seq_lengths, char_inputs, char_seq_lengths, char_seq_recover):
        lstm_out = self.get_lstm_features(gaz_list, word_inputs,biword_inputs, word_seq_lengths, char_inputs, char_seq_lengths, char_seq_recover)
        ## lstm_out (batch_size, sent_len, hidden_dim)
        outputs = self.hidden2tag(lstm_out)
        return outputs
    
    def get_output_score_ner(self, gaz_list,  word_inputs, biword_inputs, word_seq_lengths, char_inputs, char_seq_lengths, char_seq_recover):
        lstm_out = self.get_lstm_features(gaz_list, word_inputs,biword_inputs, word_seq_lengths, char_inputs, char_seq_lengths, char_seq_recover)
        ## lstm_out (batch_size, sent_len, hidden_dim)
        outputs = self.hidden2tag_ner(lstm_out)
        return outputs

    def get_output_score_general(self, gaz_list,  word_inputs, biword_inputs, word_seq_lengths, char_inputs, char_seq_lengths, char_seq_recover):
        lstm_out = self.get_lstm_features(gaz_list, word_inputs,biword_inputs, word_seq_lengths, char_inputs, char_seq_lengths, char_seq_recover)
        ## lstm_out (batch_size, sent_len, hidden_dim)
        outputs = self.hidden2tag_general(lstm_out)
        return outputs
class BiLSTM(nn.Module):
    def __init__(self, data):
        super(BiLSTM, self).__init__()
        print ("build batched bilstm...")
        self.use_bigram = data.use_bigram
        self.gpu = data.HP_gpu
        self.use_char = data.HP_use_char
        self.use_gaz = data.HP_use_gaz
        self.batch_size = data.HP_batch_size
        self.char_hidden_dim = 0
        if self.use_char:
            self.char_hidden_dim = data.HP_char_hidden_dim
            self.char_embedding_dim = data.char_emb_dim
            if data.char_features == "CNN":
                self.char_feature = CharCNN(data.char_alphabet.size(), self.char_embedding_dim, self.char_hidden_dim, data.HP_dropout, self.gpu)
            elif data.char_features == "LSTM":
                self.char_feature = CharBiLSTM(data.char_alphabet.size(), self.char_embedding_dim, self.char_hidden_dim, data.HP_dropout, self.gpu)
            else:
                print ("Error char feature selection, please check parameter data.char_features (either CNN or LSTM).")
                exit(0)
        self.embedding_dim = data.word_emb_dim
        self.hidden_dim = data.HP_hidden_dim
        self.drop = nn.Dropout(data.HP_dropout)
        self.droplstm = nn.Dropout(data.HP_dropout)
        self.word_embeddings = nn.Embedding(data.word_alphabet.size(), self.embedding_dim)
        self.biword_embeddings = nn.Embedding(data.biword_alphabet.size(), data.biword_emb_dim)
        self.bilstm_flag = data.HP_bilstm
        # self.bilstm_flag = False
        self.lstm_layer = data.HP_lstm_layer
        if data.pretrain_word_embedding is not None:
            self.word_embeddings.weight.data.copy_(torch.from_numpy(data.pretrain_word_embedding))
        else:
            self.word_embeddings.weight.data.copy_(torch.from_numpy(self.random_embedding(data.word_alphabet.size(), self.embedding_dim)))
            
        if data.pretrain_biword_embedding is not None:
            self.biword_embeddings.weight.data.copy_(torch.from_numpy(data.pretrain_biword_embedding))
        else:
            self.biword_embeddings.weight.data.copy_(torch.from_numpy(self.random_embedding(data.biword_alphabet.size(), data.biword_emb_dim)))
        # The LSTM takes word embeddings as inputs, and outputs hidden states
        # with dimensionality hidden_dim.
        
        if self.bilstm_flag:
            lstm_hidden = data.HP_hidden_dim // 2
        else:
            lstm_hidden = data.HP_hidden_dim

        lstm_input = self.embedding_dim + self.char_hidden_dim
        if self.use_bigram:
            lstm_input += data.biword_emb_dim

        self.forward_lstm = LatticeLSTM(lstm_input, lstm_hidden, data.gaz_dropout, data.gaz_alphabet.size(), data.gaz_emb_dim, data.pretrain_gaz_embedding, True, data.HP_fix_gaz_emb, self.gpu)

        if self.bilstm_flag:
            self.backward_lstm = LatticeLSTM(lstm_input, lstm_hidden, data.gaz_dropout, data.gaz_alphabet.size(), data.gaz_emb_dim, data.pretrain_gaz_embedding, False, data.HP_fix_gaz_emb, self.gpu)
        # self.lstm = nn.LSTM(lstm_input, lstm_hidden, num_layers=self.lstm_layer, batch_first=True, bidirectional=self.bilstm_flag)

        # The linear layer that maps from hidden state space to tag space
        self.hidden2tag = nn.Linear(data.HP_hidden_dim, data.label_alphabet_size)

        if self.gpu:
            self.drop = self.drop.cuda()
            self.droplstm = self.droplstm.cuda()
            self.word_embeddings = self.word_embeddings.cuda()
            self.biword_embeddings = self.biword_embeddings.cuda()
            self.forward_lstm = self.forward_lstm.cuda()
            if self.bilstm_flag:
                self.backward_lstm = self.backward_lstm.cuda()
            self.hidden2tag = self.hidden2tag.cuda()


    def random_embedding(self, vocab_size, embedding_dim):
        pretrain_emb = np.empty([vocab_size, embedding_dim])
        scale = np.sqrt(3.0 / embedding_dim)
        for index in range(vocab_size):
            pretrain_emb[index,:] = np.random.uniform(-scale, scale, [1, embedding_dim])
        return pretrain_emb


    def get_lstm_features(self, gaz_list, word_inputs, biword_inputs, word_seq_lengths, char_inputs, char_seq_lengths, char_seq_recover):
        """
            input:
                word_inputs: (batch_size, sent_len)
                gaz_list:
                word_seq_lengths: length of each sentence in the batch, (batch_size, 1)
                char_inputs: (batch_size*sent_len, word_length)
                char_seq_lengths: length of every word in the whole batch, (batch_size*sent_len, 1)
                char_seq_recover: indices recording the original word order, used to restore the char features
            output:
                Variable(batch_size, sent_len, hidden_dim)
        """
        batch_size = word_inputs.size(0)
        sent_len = word_inputs.size(1)
        word_embs =  self.word_embeddings(word_inputs)
        if self.use_bigram:
            biword_embs = self.biword_embeddings(biword_inputs)
            word_embs = torch.cat([word_embs, biword_embs],2)
        if self.use_char:
            ## calculate char lstm last hidden
            char_features = self.char_feature.get_last_hiddens(char_inputs, char_seq_lengths.cpu().numpy())
            char_features = char_features[char_seq_recover]
            char_features = char_features.view(batch_size,sent_len,-1)
            ## concat word and char together
            word_embs = torch.cat([word_embs, char_features], 2)
        word_embs = self.drop(word_embs)
        # packed_words = pack_padded_sequence(word_embs, word_seq_lengths.cpu().numpy(), True)
        hidden = None
        lstm_out, hidden = self.forward_lstm(word_embs, gaz_list, hidden)
        if self.bilstm_flag:
            backward_hidden = None 
            backward_lstm_out, backward_hidden = self.backward_lstm(word_embs, gaz_list, backward_hidden)
            lstm_out = torch.cat([lstm_out, backward_lstm_out],2)
        # lstm_out, _ = pad_packed_sequence(lstm_out)
        lstm_out = self.droplstm(lstm_out)
        return lstm_out



    def get_output_score(self, gaz_list,  word_inputs, biword_inputs, word_seq_lengths, char_inputs, char_seq_lengths, char_seq_recover):
        lstm_out = self.get_lstm_features(gaz_list, word_inputs,biword_inputs, word_seq_lengths, char_inputs, char_seq_lengths, char_seq_recover)
        ## lstm_out (batch_size, sent_len, hidden_dim)
        outputs = self.hidden2tag(lstm_out)
        return outputs
    

    def neg_log_likelihood_loss(self, gaz_list, word_inputs, biword_inputs, word_seq_lengths, char_inputs, char_seq_lengths, char_seq_recover, batch_label, mask):
        ## mask is not used
        batch_size = word_inputs.size(0)
        seq_len = word_inputs.size(1)
        total_word = batch_size * seq_len
        loss_function = nn.NLLLoss(ignore_index=0, reduction='sum')  ## size_average=False is deprecated; reduction='sum' is equivalent
        outs = self.get_output_score(gaz_list, word_inputs, biword_inputs, word_seq_lengths, char_inputs, char_seq_lengths, char_seq_recover)
        # outs (batch_size, seq_len, label_vocab)
        outs = outs.view(total_word, -1)
        score = F.log_softmax(outs, 1)
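        ## log_softmax + NLLLoss is equivalent to a summed cross-entropy over all batch_size * seq_len tokens,
        ## with label index 0 (padding) ignored.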
        loss = loss_function(score, batch_label.view(total_word))
        _, tag_seq  = torch.max(score, 1)
        tag_seq = tag_seq.view(batch_size, seq_len)
        return loss, tag_seq


    def forward(self, gaz_list, word_inputs, biword_inputs, word_seq_lengths,  char_inputs, char_seq_lengths, char_seq_recover, mask):
        
        batch_size = word_inputs.size(0)
        seq_len = word_inputs.size(1)
        total_word = batch_size * seq_len
        outs = self.get_output_score(gaz_list,  word_inputs, biword_inputs, word_seq_lengths, char_inputs, char_seq_lengths, char_seq_recover)
        outs = outs.view(total_word, -1)
        _, tag_seq  = torch.max(outs, 1)
        tag_seq = tag_seq.view(batch_size, seq_len)
        ## filter padded position with zero
        decode_seq = mask.long() * tag_seq
        return decode_seq
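
The final masking step in forward zeroes out the greedy predictions at padded positions. A tiny toy check of that behaviour (hypothetical label values):

import torch

tag_seq = torch.tensor([[3, 1, 2, 2],
                        [4, 4, 1, 5]])                     # argmax labels, (batch_size, seq_len)
mask = torch.tensor([[1, 1, 1, 0],
                     [1, 1, 0, 0]], dtype=torch.bool)      # True for real tokens, False for padding

decode_seq = mask.long() * tag_seq                         # padded positions become label 0
print(decode_seq)
# tensor([[3, 1, 2, 0],
#         [4, 4, 0, 0]])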