Code example #1
class TextRNN(Classifier):
    """Implement TextRNN, contains LSTM,BiLSTM,GRU,BiGRU
    Reference: "Effective LSTMs for Target-Dependent Sentiment Classification"
               "Bidirectional LSTM-CRF Models for Sequence Tagging"
               "Generative and discriminative text classification
                with recurrent neural networks"
    """

    def __init__(self, dataset, config):
        super(TextRNN, self).__init__(dataset, config)
        self.rnn = RNN(
            config.embedding.dimension, config.TextRNN.hidden_dimension,
            num_layers=config.TextRNN.num_layers, batch_first=True, 
            bidirectional=config.TextRNN.bidirectional,
            rnn_type=config.TextRNN.rnn_type)
        hidden_dimension = config.TextRNN.hidden_dimension
        if config.TextRNN.bidirectional:
            hidden_dimension *= 2
        self.sum_attention = SumAttention(hidden_dimension,
                                          config.TextRNN.attention_dimension)
        self.linear = torch.nn.Linear(hidden_dimension, len(dataset.label_map))
        self.dropout = torch.nn.Dropout(p=config.train.hidden_layer_dropout)

    def get_parameter_optimizer_dict(self):
        params = super(TextRNN, self).get_parameter_optimizer_dict()
        params.append({'params': self.rnn.parameters()})
        params.append({'params': self.linear.parameters()})
        return params
    
    def update_lr(self, optimizer, epoch):
        if epoch > self.config.train.num_epochs_static_embedding:
            for param_group in optimizer.param_groups[:2]:
                param_group["lr"] = self.config.optimizer.learning_rate
        else:
            for param_group in optimizer.param_groups[:2]:
                param_group["lr"] = 0.0
               

    def forward(self, batch):
        if self.config.feature.feature_names[0] == "token":
            embedding = self.token_embedding(
                batch[cDataset.DOC_TOKEN].to(self.config.device))
            length = batch[cDataset.DOC_TOKEN_LEN].to(self.config.device)
        else:
            embedding = self.char_embedding(
                batch[cDataset.DOC_CHAR].to(self.config.device))
            length = batch[cDataset.DOC_CHAR_LEN].to(self.config.device)
        output, last_hidden = self.rnn(embedding, length)

        doc_embedding_type = self.config.TextRNN.doc_embedding_type
        if doc_embedding_type == DocEmbeddingType.AVG:
            doc_embedding = torch.sum(output, 1) / length.unsqueeze(1)
        elif doc_embedding_type == DocEmbeddingType.ATTENTION:
            doc_embedding = self.sum_attention(output)
        elif doc_embedding_type == DocEmbeddingType.LAST_HIDDEN:
            doc_embedding = last_hidden
        else:
            raise TypeError(
                "Unsupported rnn init type: %s. Supported rnn type is: %s" % (
                    doc_embedding_type, DocEmbeddingType.str()))
        
        return self.dropout(self.linear(doc_embedding))
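
The forward pass above supports three ways of turning the RNN output into a document embedding. The snippet below is a minimal, self-contained sketch of those three strategies (AVG, LAST_HIDDEN, ATTENTION) in plain PyTorch with made-up dimensions; the project's RNN and SumAttention wrappers are replaced by a stock LSTM and a simple additive attention, so it illustrates the idea rather than the exact implementation.

import torch

batch, seq_len, emb_dim, hidden = 4, 12, 50, 32
lstm = torch.nn.LSTM(emb_dim, hidden, batch_first=True, bidirectional=True)
x = torch.randn(batch, seq_len, emb_dim)
length = torch.full((batch,), seq_len, dtype=torch.float)

output, (h_n, _) = lstm(x)                      # output: [batch, seq_len, 2*hidden]

# AVG: sum over time divided by the sequence length
avg_embedding = torch.sum(output, 1) / length.unsqueeze(1)

# LAST_HIDDEN: concatenate the final forward/backward hidden states
last_hidden = torch.cat([h_n[-2], h_n[-1]], dim=1)         # [batch, 2*hidden]

# ATTENTION: a simple additive attention over the time axis,
# a stand-in for the project's SumAttention module
att = torch.nn.Linear(2 * hidden, 1)
weights = torch.softmax(att(torch.tanh(output)), dim=1)    # [batch, seq_len, 1]
att_embedding = torch.sum(weights * output, dim=1)         # [batch, 2*hidden]

print(avg_embedding.shape, last_hidden.shape, att_embedding.shape)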
Code example #2
class HMCN(Classifier):
    """ Implement HMCN(Hierarchical Multi-Label Classification Networks)
        Reference: "Hierarchical Multi-Label Classification Networks"
    """

    def __init__(self, dataset, config):
        super(HMCN, self).__init__(dataset, config)
        self.hierarchical_depth = config.HMCN.hierarchical_depth
        self.hierarchical_class = dataset.hierarchy_classes
        self.global2local = config.HMCN.global2local
        self.rnn = RNN(
            config.embedding.dimension, config.TextRNN.hidden_dimension, 
            num_layers=config.TextRNN.num_layers, batch_first=True,
            bidirectional=config.TextRNN.bidirectional,
            rnn_type=config.TextRNN.rnn_type)
        hidden_dimension = config.TextRNN.hidden_dimension
        if config.TextRNN.bidirectional:
            hidden_dimension *= 2
        
        self.local_layers = torch.nn.ModuleList()
        self.global_layers = torch.nn.ModuleList()
        for i in range(1, len(self.hierarchical_depth)):
            self.global_layers.append(
                torch.nn.Sequential(
                    torch.nn.Linear(hidden_dimension + self.hierarchical_depth[i-1], self.hierarchical_depth[i]),
                    torch.nn.ReLU(),
                    torch.nn.BatchNorm1d(self.hierarchical_depth[i]),
                    torch.nn.Dropout(p=0.5)
                ))
            self.local_layers.append(
                torch.nn.Sequential(
                    torch.nn.Linear(self.hierarchical_depth[i], self.global2local[i]),
                    torch.nn.ReLU(),
                    torch.nn.BatchNorm1d(self.global2local[i]),
                    torch.nn.Linear(self.global2local[i], self.hierarchical_class[i-1])
                ))

        self.global_layers.apply(self._init_weight)
        self.local_layers.apply(self._init_weight)
        self.linear = torch.nn.Linear(self.hierarchical_depth[-1], len(dataset.label_map))
        self.linear.apply(self._init_weight)
        self.dropout = torch.nn.Dropout(p=config.train.hidden_layer_dropout)
        
    def _init_weight(self, m):
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.normal_(m.weight, std=0.1) 

    def get_parameter_optimizer_dict(self):
        params = super(HMCN, self).get_parameter_optimizer_dict() 
        params.append({'params': self.rnn.parameters()})
        params.append({'params': self.local_layers.parameters()})
        params.append({'params': self.global_layers.parameters()})
        params.append({'params': self.linear.parameters()})
        return params 

    def update_lr(self, optimizer, epoch):
        """ Update lr
        """
        if epoch > self.config.train.num_epochs_static_embedding:
            for param_group in optimizer.param_groups[:2]:
                param_group["lr"] = self.config.optimizer.learning_rate
        else:
            for param_group in optimizer.param_groups[:2]:
                param_group["lr"] = 0

    def forward(self, batch):
        if self.config.feature.feature_names[0] == "token":
            embedding = self.token_embedding(
                    batch[cDataset.DOC_TOKEN].to(self.config.device))
            length = batch[cDataset.DOC_TOKEN_LEN].to(self.config.device)
        else:
            embedding = self.char_embedding(
                    batch[cDataset.DOC_CHAR].to(self.config.device))
            length = batch[cDataset.DOC_CHAR_LEN].to(self.config.device)
        
        output, last_hidden = self.rnn(embedding, length)
        doc_embedding = torch.sum(output, 1) / length.unsqueeze(1) 
        local_layer_outputs = []
        global_layer_activation = doc_embedding
        batch_size = doc_embedding.size()[0]
        for i, (local_layer, global_layer) in enumerate(zip(self.local_layers, self.global_layers)):
            local_layer_activation = global_layer(global_layer_activation)
            local_layer_outputs.append(local_layer(local_layer_activation))
            if i < len(self.global_layers)-1:
                global_layer_activation = torch.cat((local_layer_activation, doc_embedding), 1)
            else:
                global_layer_activation = local_layer_activation

        global_layer_output = self.linear(global_layer_activation)
        local_layer_output = torch.cat(local_layer_outputs, 1)
        return global_layer_output, local_layer_output, 0.5 * global_layer_output + 0.5 * local_layer_output   
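
For reference, here is a minimal, self-contained sketch of the global/local cascade that HMCN.forward runs: each level's global layer consumes the document embedding concatenated with the previous global activation, each local layer predicts that level's classes, and the final prediction averages the global and local outputs. The layer sizes and the two-level hierarchy below are made up for illustration, and BatchNorm/Dropout are omitted.

import torch

doc_dim = 64
hierarchical_depth = [0, 48, 48]      # global hidden sizes per level (made up)
global2local = [0, 32, 32]            # local bottleneck sizes per level (made up)
hierarchy_classes = [10, 30]          # number of classes at each hierarchy level (made up)

global_layers = torch.nn.ModuleList()
local_layers = torch.nn.ModuleList()
for i in range(1, len(hierarchical_depth)):
    global_layers.append(torch.nn.Sequential(
        torch.nn.Linear(doc_dim + hierarchical_depth[i - 1], hierarchical_depth[i]),
        torch.nn.ReLU()))
    local_layers.append(torch.nn.Sequential(
        torch.nn.Linear(hierarchical_depth[i], global2local[i]),
        torch.nn.ReLU(),
        torch.nn.Linear(global2local[i], hierarchy_classes[i - 1])))
final_linear = torch.nn.Linear(hierarchical_depth[-1], sum(hierarchy_classes))

doc_embedding = torch.randn(4, doc_dim)
local_outputs = []
activation = doc_embedding
for i, (local_layer, global_layer) in enumerate(zip(local_layers, global_layers)):
    hidden = global_layer(activation)
    local_outputs.append(local_layer(hidden))
    # feed the document embedding back in at every level except the last
    activation = torch.cat((hidden, doc_embedding), 1) if i < len(global_layers) - 1 else hidden

global_out = final_linear(activation)       # [4, sum(hierarchy_classes)]
local_out = torch.cat(local_outputs, 1)     # [4, sum(hierarchy_classes)]
combined = 0.5 * global_out + 0.5 * local_out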
Code example #3
File: textCRAN.py  Project: AiXia520/HybridRCNN
class TextCRAN(Classifier):
    def __init__(self, dataset, config):
        super(TextCRAN, self).__init__(dataset, config)
        self.doc_embedding_type = config.TextCRAN.doc_embedding_type
        self.kernel_sizes = config.TextCRAN.kernel_sizes
        self.convs = torch.nn.ModuleList()
        for kernel_size in self.kernel_sizes:
            self.convs.append(
                torch.nn.Conv1d(config.embedding.dimension,
                                config.TextCRAN.num_kernels,
                                kernel_size,
                                padding=kernel_size - 1))

        self.top_k = self.config.TextCRAN.top_k_max_pooling
        hidden_size = len(config.TextCRAN.kernel_sizes) * \
                      config.TextCRAN.num_kernels * self.top_k

        self.rnn = RNN(config.embedding.dimension,
                       config.TextCRAN.hidden_dimension,
                       num_layers=config.TextCRAN.num_layers,
                       batch_first=True,
                       bidirectional=config.TextCRAN.bidirectional,
                       rnn_type=config.TextCRAN.rnn_type)

        hidden_dimension = config.TextCRAN.hidden_dimension
        if config.TextCRAN.bidirectional:
            hidden_dimension *= 2

        self.sum_attention = SumAttention(
            config.TextCRAN.attention_input_dimension,
            config.TextCRAN.attention_dimension, config.device)
        self.linear = torch.nn.Linear(
            config.TextCRAN.attention_input_dimension, len(dataset.label_map))
        self.dropout = torch.nn.Dropout(p=config.train.hidden_layer_dropout)

    def get_parameter_optimizer_dict(self):
        params = list()
        params.append({'params': self.token_embedding.parameters()})
        params.append({'params': self.char_embedding.parameters()})
        params.append({'params': self.convs.parameters()})
        params.append({'params': self.rnn.parameters()})
        params.append({'params': self.linear.parameters()})

        return params

    def update_lr(self, optimizer, epoch):
        """Update lr
        """
        if epoch > self.config.train.num_epochs_static_embedding:
            for param_group in optimizer.param_groups[:2]:
                param_group["lr"] = self.config.optimizer.learning_rate
        else:
            for param_group in optimizer.param_groups[:2]:
                param_group["lr"] = 0

    def forward(self, batch):
        if self.config.feature.feature_names[0] == "token":
            embedding = self.token_embedding(batch[cDataset.DOC_TOKEN].to(
                self.config.device))

            seq_length = batch[cDataset.DOC_TOKEN_LEN].to(self.config.device)
        else:
            embedding = self.char_embedding(batch[cDataset.DOC_CHAR].to(
                self.config.device))
            seq_length = batch[cDataset.DOC_CHAR_LEN].to(self.config.device)

        # CNN layer
        embedding2 = embedding.transpose(1, 2)
        pooled_outputs = []
        for i, conv in enumerate(self.convs):
            #convolution = torch.nn.ReLU(conv(embedding))
            convolution = torch.nn.functional.relu(
                conv(embedding2))  #[batch,100,n]

            pooled = torch.topk(convolution,
                                self.top_k)[0].view(convolution.size(0), -1)
            pooled_outputs.append(pooled)

        cnn = torch.cat(pooled_outputs, 1)  ##[batch,300]
        cnn = cnn.expand(embedding.size(1), cnn.size(0),
                         cnn.size(1))  # [n,batch,300]
        cnn = cnn.transpose(0, 1)  # [batch,n,300]

        # RNN layer
        rnn, output1 = self.rnn(embedding, seq_length)  #[batch,n, 300]

        output = torch.cat((cnn, rnn), 2)  ##[batch,n,600]

        if self.doc_embedding_type == DocEmbeddingType.ATTENTION:
            doc_embedding = self.sum_attention(output)
        else:
            raise TypeError(
                "Unsupported doc embedding type: %s. Supported types are: %s" %
                (self.doc_embedding_type, DocEmbeddingType.str()))

        out = self.dropout(self.linear(doc_embedding))

        return out
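
The CNN branch above (channels-first Conv1d, ReLU, top-k max pooling, concatenation across kernel sizes) can be reproduced in isolation; the sketch below does that with hypothetical dimensions standing in for the config values.

import torch

batch, seq_len, emb_dim, num_kernels, top_k = 4, 20, 50, 100, 1
kernel_sizes = [2, 3, 4]
convs = torch.nn.ModuleList(
    torch.nn.Conv1d(emb_dim, num_kernels, k, padding=k - 1) for k in kernel_sizes)

embedding = torch.randn(batch, seq_len, emb_dim)
embedding2 = embedding.transpose(1, 2)            # [batch, emb_dim, seq_len]

pooled_outputs = []
for conv in convs:
    convolution = torch.nn.functional.relu(conv(embedding2))   # [batch, num_kernels, n]
    # keep the top_k largest activations per kernel and flatten them
    pooled = torch.topk(convolution, top_k)[0].view(convolution.size(0), -1)
    pooled_outputs.append(pooled)                               # [batch, num_kernels*top_k]

cnn = torch.cat(pooled_outputs, 1)   # [batch, len(kernel_sizes)*num_kernels*top_k]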
Code example #4
File: drnn.py  Project: AiXia520/HybridRCNN
class DRNN(Classifier):
    def __init__(self, dataset, config):
        super(DRNN, self).__init__(dataset, config)
        self.rnn_type = config.DRNN.rnn_type
        self.forward_rnn = RNN(config.embedding.dimension,
                               config.DRNN.hidden_dimension,
                               batch_first=True,
                               rnn_type=config.DRNN.rnn_type)
        if config.DRNN.bidirectional:
            self.backward_rnn = RNN(config.embedding.dimension,
                                    config.DRNN.hidden_dimension,
                                    batch_first=True,
                                    rnn_type=config.DRNN.rnn_type)
        self.window_size = config.DRNN.window_size
        self.dropout = torch.nn.Dropout(p=config.DRNN.cell_hidden_dropout)
        self.hidden_dimension = config.DRNN.hidden_dimension
        if config.DRNN.bidirectional:
            self.hidden_dimension *= 2
        self.batch_norm = torch.nn.BatchNorm1d(self.hidden_dimension)

        self.mlp = torch.nn.Linear(self.hidden_dimension,
                                   self.hidden_dimension)
        self.linear = torch.nn.Linear(self.hidden_dimension,
                                      len(dataset.label_map))

    def get_parameter_optimizer_dict(self):
        params = super(DRNN, self).get_parameter_optimizer_dict()
        params.append({'params': self.forward_rnn.parameters()})
        if self.config.DRNN.bidirectional:
            params.append({'params': self.backward_rnn.parameters()})
        params.append({'params': self.batch_norm.parameters()})
        params.append({'params': self.mlp.parameters()})
        params.append({'params': self.linear.parameters()})
        return params

    def forward(self, batch):
        front_pad_embedding, _, mask = self.get_embedding(
            batch, [self.window_size - 1, 0], cDataset.VOCAB_PADDING_LEARNABLE)
        if self.config.DRNN.bidirectional:
            tail_pad_embedding, _, _ = self.get_embedding(
                batch, [0, self.window_size - 1],
                cDataset.VOCAB_PADDING_LEARNABLE)
        batch_size = front_pad_embedding.size(0)
        mask = mask.unsqueeze(2)

        front_slice_embedding_list = \
            [front_pad_embedding[:, i:i + self.window_size, :] for i in
             range(front_pad_embedding.size(1) - self.window_size + 1)]

        front_slice_embedding = torch.cat(front_slice_embedding_list, dim=0)

        state = None
        for i in range(front_slice_embedding.size(1)):
            _, state = self.forward_rnn(front_slice_embedding[:, i:i + 1, :],
                                        init_state=state,
                                        ori_state=True)
            if self.rnn_type == RNNType.LSTM:
                state[0] = self.dropout(state[0])
            else:
                state = self.dropout(state)
        front_state = state[0] if self.rnn_type == RNNType.LSTM else state
        front_state = front_state.transpose(0, 1)
        front_hidden = torch.cat(front_state.split(batch_size, dim=0), dim=1)
        front_hidden = front_hidden * mask

        hidden = front_hidden
        if self.config.DRNN.bidirectional:
            tail_slice_embedding_list = list()
            for i in range(tail_pad_embedding.size(1) - self.window_size + 1):
                slice_embedding = \
                    tail_pad_embedding[:, i:i + self.window_size, :]
                tail_slice_embedding_list.append(slice_embedding)
            tail_slice_embedding = torch.cat(tail_slice_embedding_list, dim=0)

            state = None
            for i in range(tail_slice_embedding.size(1), 0, -1):
                _, state = self.backward_rnn(tail_slice_embedding[:,
                                                                  i - 1:i, :],
                                             init_state=state,
                                             ori_state=True)
                if i != tail_slice_embedding.size(1) - 1:
                    if self.rnn_type == RNNType.LSTM:
                        state[0] = self.dropout(state[0])
                    else:
                        state = self.dropout(state)
            tail_state = state[0] if self.rnn_type == RNNType.LSTM else state
            tail_state = tail_state.transpose(0, 1)
            tail_hidden = torch.cat(tail_state.split(batch_size, dim=0), dim=1)
            tail_hidden = tail_hidden * mask
            hidden = torch.cat([hidden, tail_hidden], dim=2)

        hidden = hidden.transpose(1, 2).contiguous()

        batch_normed = self.batch_norm(hidden).transpose(1, 2)
        batch_normed = batch_normed * mask
        mlp_hidden = self.mlp(batch_normed)
        mlp_hidden = mlp_hidden * mask
        neg_mask = (mask - 1) * 65500.0
        mlp_hidden = mlp_hidden + neg_mask
        max_pooling = torch.nn.functional.max_pool1d(
            mlp_hidden.transpose(1, 2), mlp_hidden.size(1)).squeeze()
        return self.linear(self.dropout(max_pooling))
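
The final pooling step in DRNN.forward relies on a masking trick: padded positions are pushed to a large negative value (-65500) so max pooling can never select them. A tiny self-contained sketch with made-up shapes:

import torch

batch, seq_len, hidden = 2, 5, 8
mlp_hidden = torch.randn(batch, seq_len, hidden)
mask = torch.tensor([[1.0] * 3 + [0.0] * 2,       # first example: 3 real tokens, 2 padding
                     [1.0] * 5]).unsqueeze(2)     # second example: no padding

mlp_hidden = mlp_hidden * mask
neg_mask = (mask - 1) * 65500.0                   # 0 for real tokens, -65500 for padding
mlp_hidden = mlp_hidden + neg_mask

max_pooling = torch.nn.functional.max_pool1d(
    mlp_hidden.transpose(1, 2), mlp_hidden.size(1)).squeeze()   # [batch, hidden]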
Code example #5
File: textGRA.py  Project: AiXia520/HybridRCNN
class TextGRA(Classifier):
    def __init__(self, dataset, config):
        super(TextGRA, self).__init__(dataset, config)
        self.doc_embedding_type = config.TextGRA.doc_embedding_type
        self.kernel_sizes = config.TextGRA.kernel_sizes
        self.convs = torch.nn.ModuleList()
        for kernel_size in self.kernel_sizes:
            self.convs.append(
                torch.nn.Conv1d(config.embedding.dimension,
                                config.TextGRA.num_kernels,
                                kernel_size,
                                padding=kernel_size - 1))

        self.top_k = self.config.TextGRA.top_k_max_pooling
        hidden_size = len(config.TextGRA.kernel_sizes) * \
                      config.TextGRA.num_kernels * self.top_k

        self.rnn = RNN(config.embedding.dimension,
                       config.TextGRA.hidden_dimension,
                       num_layers=config.TextGRA.num_layers,
                       batch_first=True,
                       bidirectional=config.TextGRA.bidirectional,
                       rnn_type=config.TextGRA.rnn_type)
        self.rnn2 = RNN(600,
                        config.TextGRA.hidden_dimension,
                        num_layers=config.TextGRA.num_layers,
                        batch_first=True,
                        bidirectional=config.TextGRA.bidirectional,
                        rnn_type=config.TextGRA.rnn_type)
        hidden_dimension = config.TextGRA.hidden_dimension
        if config.TextGRA.bidirectional:
            hidden_dimension *= 2

        self.linear_u = torch.nn.Linear(300, 300)

        self.linear = torch.nn.Linear(hidden_dimension, len(dataset.label_map))
        self.dropout = torch.nn.Dropout(p=config.train.hidden_layer_dropout)

    def get_parameter_optimizer_dict(self):
        params = list()
        params.append({'params': self.token_embedding.parameters()})
        params.append({'params': self.char_embedding.parameters()})
        params.append({'params': self.convs.parameters()})
        params.append({'params': self.rnn.parameters()})
        params.append({'params': self.rnn2.parameters()})
        params.append({'params': self.linear_u.parameters()})
        params.append({'params': self.linear.parameters()})

        return params

    def update_lr(self, optimizer, epoch):
        """Update lr
        """
        if epoch > self.config.train.num_epochs_static_embedding:
            for param_group in optimizer.param_groups[:2]:
                param_group["lr"] = self.config.optimizer.learning_rate
        else:
            for param_group in optimizer.param_groups[:2]:
                param_group["lr"] = 0

    def forward(self, batch):
        if self.config.feature.feature_names[0] == "token":
            embedding = self.token_embedding(batch[cDataset.DOC_TOKEN].to(
                self.config.device))
            seq_length = batch[cDataset.DOC_TOKEN_LEN].to(self.config.device)
        else:
            embedding = self.char_embedding(batch[cDataset.DOC_CHAR].to(
                self.config.device))
            seq_length = batch[cDataset.DOC_CHAR_LEN].to(self.config.device)

        # RNN sentence vector
        S, output1 = self.rnn(embedding, seq_length)  #[batch,n, 300]
        embedding2 = embedding.transpose(1, 2)

        # CNN phrase vector
        pooled_outputs = []
        for i, conv in enumerate(self.convs):
            #convolution = torch.nn.ReLU(conv(embedding))
            convolution = torch.nn.functional.relu(
                conv(embedding2))  #[batch,100,n]

            pooled = torch.topk(convolution,
                                self.top_k)[0].view(convolution.size(0), -1)
            pooled_outputs.append(pooled)

        P = torch.cat(pooled_outputs, 1)  ##[batch,300]

        P = P.expand(embedding.size(1), P.size(0), P.size(1))  # [n,batch,300]

        # attention scoring
        attention_score = self.linear_u(
            torch.tanh(torch.mul(S, P.transpose(0, 1))))  #[batch,n,300]
        # attention gate
        A, output2 = self.rnn(attention_score, seq_length)  # [batch,n, 450]
        # attention phrase vector
        C = torch.mul(A, P.transpose(0, 1))  #[batch,n,450]
        # combine [x,c] input into RNN
        input = torch.cat((C, embedding), 2)  ##[batch,n,600]

        output, last_hidden = self.rnn2(input, seq_length)  #[batch,n,450]
        # average output
        if self.doc_embedding_type == DocEmbeddingType.AVG:
            doc_embedding = torch.sum(output, 1) / seq_length.unsqueeze(1)
        else:
            raise TypeError(
                "Unsupported doc embedding type: %s. Supported types are: %s" %
                (self.doc_embedding_type, DocEmbeddingType.str()))

        out = self.dropout(self.linear(doc_embedding))

        return out
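
The attention step above combines the RNN sentence states S with the broadcast CNN phrase vector P element-wise, scores each position through tanh and a linear layer, and uses the result to gate P before feeding the combined input into the second RNN. Below is a minimal sketch of just that gating; the extra RNN pass over the score is left out, and all shapes are hypothetical.

import torch

batch, seq_len, dim = 4, 10, 300
S = torch.randn(batch, seq_len, dim)            # RNN sentence states
P = torch.randn(batch, seq_len, dim)            # CNN phrase vector, already broadcast over time
linear_u = torch.nn.Linear(dim, dim)

attention_score = linear_u(torch.tanh(S * P))   # [batch, seq_len, dim]
C = attention_score * P                         # gated phrase representation
x = torch.randn(batch, seq_len, dim)            # token embeddings (stand-in)
combined = torch.cat((C, x), 2)                 # [C, x] input for the second RNN: [batch, seq_len, 2*dim]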
Code example #6
class TextCRVariant(Classifier):
    def __init__(self, dataset, config):
        super(TextCRVariant, self).__init__(dataset, config)

        self.label_semantic_emb = config.TextCRVariant.label_semantic_emb

        self.doc_embedding_type = config.TextCRVariant.doc_embedding_type
        self.kernel_sizes = config.TextCRVariant.kernel_sizes
        self.convs = torch.nn.ModuleList()
        for kernel_size in self.kernel_sizes:
            self.convs.append(
                torch.nn.Conv1d(384,
                                config.TextCRVariant.num_kernels,
                                kernel_size,
                                padding=kernel_size - 1))

        self.top_k = self.config.TextCRVariant.top_k_max_pooling
        hidden_size = len(config.TextCRVariant.kernel_sizes) * \
                      config.TextCRVariant.num_kernels * self.top_k

        self.rnn3 = torch.nn.GRU(256,
                                 256,
                                 num_layers=config.TextCRVariant.num_layers,
                                 batch_first=True)
        self.rnn1 = RNN(config.embedding.dimension,
                        config.TextCRVariant.hidden_dimension,
                        num_layers=config.TextCRVariant.num_layers,
                        batch_first=True,
                        bidirectional=config.TextCRVariant.bidirectional,
                        rnn_type=config.TextCRVariant.rnn_type)

        self.rnn2 = RNN(384,
                        config.TextCRVariant.hidden_dimension,
                        num_layers=config.TextCRVariant.num_layers,
                        batch_first=True,
                        bidirectional=config.TextCRVariant.bidirectional,
                        rnn_type=config.TextCRVariant.rnn_type)

        hidden_dimension = config.TextCRVariant.hidden_dimension
        if config.TextCRVariant.bidirectional:
            hidden_dimension *= 2

        self.sum_attention = SumAttention(
            hidden_dimension, config.TextCRVariant.attention_dimension,
            config.device)

        self.linear = torch.nn.Linear(hidden_size, len(dataset.label_map))
        self.dropout = torch.nn.Dropout(p=config.train.hidden_layer_dropout)

    def get_parameter_optimizer_dict(self):
        params = list()
        params.append({'params': self.token_embedding.parameters()})
        params.append({'params': self.char_embedding.parameters()})
        params.append({'params': self.convs.parameters()})
        params.append({'params': self.rnn1.parameters()})
        params.append({'params': self.rnn2.parameters()})
        params.append({'params': self.linear.parameters()})

        return params

    def update_lr(self, optimizer, epoch):
        """Update lr
        """
        if epoch > self.config.train.num_epochs_static_embedding:
            for param_group in optimizer.param_groups[:2]:
                param_group["lr"] = self.config.optimizer.learning_rate
        else:
            for param_group in optimizer.param_groups[:2]:
                param_group["lr"] = 0

    def forward(self, batch):
        if self.config.feature.feature_names[0] == "token":
            embedding = self.token_embedding(batch[cDataset.DOC_TOKEN].to(
                self.config.device))
            seq_length = batch[cDataset.DOC_TOKEN_LEN].to(self.config.device)
        else:
            embedding = self.char_embedding(batch[cDataset.DOC_CHAR].to(
                self.config.device))
            seq_length = batch[cDataset.DOC_CHAR_LEN].to(self.config.device)

        label_semantic_emb = self.label_semantic_emb.cuda()

        label_semantic_emb = label_semantic_emb.expand(
            (embedding.size(0), label_semantic_emb.size(0),
             label_semantic_emb.size(1)))  # [batch,L,256]

        output1, last_hidden = self.rnn3(label_semantic_emb)  # [batch,n,256]

        last_hidden = last_hidden.squeeze(dim=0)

        last_hidden = last_hidden.expand(
            (embedding.size(1), last_hidden.size(0),
             last_hidden.size(1)))  # [batch,n,256]

        last_hidden = last_hidden.transpose(0, 1)

        doc_embedding, _ = self.rnn1(embedding, seq_length)  #[batch,n,256]

        input = torch.cat((doc_embedding, last_hidden), 2)  # [batch,n,512]

        doc_embedding = input.transpose(1, 2)
        pooled_outputs = []
        for _, conv in enumerate(self.convs):
            convolution = torch.nn.functional.relu(conv(doc_embedding))
            pooled = torch.topk(convolution,
                                self.top_k)[0].view(convolution.size(0), -1)
            pooled_outputs.append(pooled)

        doc_embedding = torch.cat(pooled_outputs, 1)

        out = self.dropout(self.linear(doc_embedding))

        return out
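
The label-semantic step above runs a GRU over the label embeddings, broadcasts its last hidden state over the token sequence, and concatenates it with the per-token RNN states before the convolutions. A minimal sketch of that step with hypothetical dimensions standing in for label_semantic_emb and rnn1:

import torch

batch, seq_len, num_labels, label_dim, hidden = 2, 8, 5, 256, 256
label_semantic_emb = torch.randn(num_labels, label_dim)   # stand-in label description embeddings
doc_states = torch.randn(batch, seq_len, hidden)          # stand-in for rnn1's per-token output

gru = torch.nn.GRU(label_dim, hidden, batch_first=True)
label_batch = label_semantic_emb.unsqueeze(0).expand(batch, -1, -1).contiguous()
_, last_hidden = gru(label_batch)                         # [1, batch, hidden]
last_hidden = last_hidden.squeeze(0)                      # [batch, hidden]

# broadcast the label summary over every token position and concatenate
label_context = last_hidden.unsqueeze(1).expand(-1, seq_len, -1)
combined = torch.cat((doc_states, label_context), dim=2)  # [batch, seq_len, 2*hidden]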
Code example #7
class TextRCNN(Classifier):
    """TextRNN + TextCNN
    """
    def __init__(self, dataset, config):
        super(TextRCNN, self).__init__(dataset, config)
        self.rnn = RNN(config.embedding.dimension,
                       config.TextRCNN.hidden_dimension,
                       num_layers=config.TextRCNN.num_layers,
                       batch_first=True,
                       bidirectional=config.TextRCNN.bidirectional,
                       rnn_type=config.TextRCNN.rnn_type)

        hidden_dimension = config.TextRCNN.hidden_dimension
        if config.TextRCNN.bidirectional:
            hidden_dimension *= 2
        self.kernel_sizes = config.TextRCNN.kernel_sizes
        self.convs = torch.nn.ModuleList()
        for kernel_size in self.kernel_sizes:
            self.convs.append(
                torch.nn.Conv1d(hidden_dimension,
                                config.TextRCNN.num_kernels,
                                kernel_size,
                                padding=kernel_size - 1))

        self.top_k = self.config.TextRCNN.top_k_max_pooling
        hidden_size = len(config.TextRCNN.kernel_sizes) * \
                      config.TextRCNN.num_kernels * self.top_k

        self.linear = torch.nn.Linear(hidden_size, len(dataset.label_map))
        self.dropout = torch.nn.Dropout(p=config.train.hidden_layer_dropout)

    def get_parameter_optimizer_dict(self):
        params = list()
        params.append({'params': self.token_embedding.parameters()})
        params.append({'params': self.char_embedding.parameters()})
        params.append({'params': self.rnn.parameters()})
        params.append({'params': self.convs.parameters()})
        params.append({'params': self.linear.parameters()})
        return params

    def update_lr(self, optimizer, epoch):
        """
        """
        if epoch > self.config.train.num_epochs_static_embedding:
            for param_group in optimizer.param_groups[:2]:
                param_group["lr"] = self.config.optimizer.learning_rate
        else:
            for param_group in optimizer.param_groups[:2]:
                param_group["lr"] = 0

    def forward(self, batch):
        if self.config.feature.feature_names[0] == "token":
            embedding = self.token_embedding(batch[cDataset.DOC_TOKEN].to(
                self.config.device))
            seq_length = batch[cDataset.DOC_TOKEN_LEN].to(self.config.device)
        else:
            embedding = self.char_embedding(batch[cDataset.DOC_CHAR].to(
                self.config.device))
            seq_length = batch[cDataset.DOC_CHAR_LEN].to(self.config.device)
        embedding = self.token_similarity_attention(embedding)
        output, _ = self.rnn(embedding, seq_length)

        doc_embedding = output.transpose(1, 2)
        pooled_outputs = []
        for _, conv in enumerate(self.convs):
            convolution = F.relu(conv(doc_embedding))
            pooled = torch.topk(convolution,
                                self.top_k)[0].view(convolution.size(0), -1)
            pooled_outputs.append(pooled)

        doc_embedding = torch.cat(pooled_outputs, 1)

        return self.dropout(self.linear(doc_embedding))

    def token_similarity_attention(self, output):
        # output: (batch, sentence length, embedding dim)
        symptom_id_list = [
            6, 134, 15, 78, 2616, 257, 402, 281, 14848, 71, 82, 96, 352, 60,
            227, 204, 178, 175, 233, 192, 416, 91, 232, 317, 17513, 628, 1047
        ]
        symptom_embedding = self.token_embedding(
            torch.LongTensor(symptom_id_list).cuda())
        # symptom_embedding: torch.tensor(symptom_num, embedding dim)
        batch_symptom_embedding = torch.cat(
            [symptom_embedding.view(1, symptom_embedding.shape[0], -1)] *
            output.shape[0],
            dim=0)
        similarity = torch.sigmoid(
            torch.bmm(
                torch.nn.functional.normalize(output, dim=2),
                torch.nn.functional.normalize(batch_symptom_embedding.permute(
                    0, 2, 1),
                                              dim=2)))
        #similarity = torch.bmm(torch.nn.functional.normalize(output, dim=2), torch.nn.functional.normalize(batch_symptom_embedding.permute(0, 2, 1), dim=2))
        #similarity = torch.sigmoid(torch.max(similarity, dim=2)[0])
        similarity = torch.max(similarity, dim=2)[0]
        #similarity = torch.sigmoid(torch.sum(similarity, dim=2))
        # similarity: torch.tensor(batch, sentence_len)
        similarity = torch.cat([similarity.view(similarity.shape[0], -1, 1)] *
                               output.shape[2],
                               dim=2)
        # similarity: torch.tensor(batch, sentence_len, embedding dim)
        #sentence_embedding = torch.sum(torch.mul(similarity, output), dim=1)
        # sentence_embedding: (batch, embedding)
        sentence_embedding = torch.mul(similarity, output)
        # sentence_embedding: (batch, sentence len, embedding)
        return sentence_embedding
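
token_similarity_attention rescales every token embedding by its best (sigmoid-squashed) cosine similarity to a fixed list of keyword ("symptom") embeddings. The sketch below reproduces that idea in isolation; the vocabulary, keyword ids and dimensions are made-up stand-ins, and both operands are normalized along the embedding axis so the bmm yields genuine cosine similarities.

import torch

vocab_size, emb_dim, batch, seq_len = 1000, 64, 2, 7
token_embedding = torch.nn.Embedding(vocab_size, emb_dim)
keyword_ids = torch.tensor([6, 15, 78, 96, 204])          # hypothetical keyword ids

tokens = torch.randint(0, vocab_size, (batch, seq_len))
output = token_embedding(tokens)                          # [batch, seq_len, emb_dim]
keyword_embedding = token_embedding(keyword_ids)          # [num_keywords, emb_dim]
batch_keywords = keyword_embedding.unsqueeze(0).expand(batch, -1, -1)

# normalize along the embedding axis, then bmm -> cosine similarity per (token, keyword)
similarity = torch.sigmoid(torch.bmm(
    torch.nn.functional.normalize(output, dim=2),
    torch.nn.functional.normalize(batch_keywords.transpose(1, 2), dim=1)))
# keep the best-matching keyword per token: [batch, seq_len]
similarity = torch.max(similarity, dim=2)[0]
# rescale every token embedding by its keyword-similarity score
sentence_embedding = output * similarity.unsqueeze(2)     # [batch, seq_len, emb_dim]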
Code example #8
class ZAGRNN(Classifier):

    def __init__(self, dataset, config):

        assert config.label_embedding.dimension == config.ZAGRNN.gcn_in_features, \
            "label embedding dimension should be same as gcn input feature dimension"

        super(ZAGRNN, self).__init__(dataset, config)

        self.rnn = RNN(
            config.embedding.dimension, config.ZAGRNN.hidden_dimension,
            num_layers=config.ZAGRNN.num_layers, batch_first=True,
            bidirectional=config.ZAGRNN.bidirectional,
            rnn_type=config.ZAGRNN.rnn_type)

        self.label_wise_attention = LabelWiseAttention(
            feat_dim=config.ZAGRNN.hidden_dimension*2 if config.ZAGRNN.bidirectional else config.ZAGRNN.hidden_dimension,
            label_emb_dim=config.label_embedding.dimension,
            store_attention_score=config.ZAGRNN.store_attention_score)

        if config.ZAGRNN.use_gcn:
            self.gcn = torch.nn.ModuleList([
                GraphConvolution(
                    in_features=config.ZAGRNN.gcn_in_features,
                    out_features=config.ZAGRNN.gcn_hidden_features,
                    bias=True,
                    act=torch.relu_,
                    featureless=False,
                    dropout=config.ZAGRNN.gcn_dropout),
                GraphConvolution(
                    in_features=config.ZAGRNN.gcn_hidden_features,
                    out_features=config.ZAGRNN.gcn_out_features,
                    bias=True,
                    act=torch.relu_,
                    featureless=False,
                    dropout=config.ZAGRNN.gcn_dropout)
            ])

            self.doc_out_transform = torch.nn.Sequential(
                torch.nn.Linear(
                    in_features=config.ZAGRNN.hidden_dimension*2 if config.ZAGRNN.bidirectional else config.ZAGRNN.hidden_dimension,
                    out_features=config.ZAGRNN.gcn_in_features + config.ZAGRNN.gcn_out_features
                ),
                torch.nn.ReLU()
            )

    def get_parameter_optimizer_dict(self):
        params = list()
        params.append({'params': self.token_embedding.parameters()})
        params.append({'params': self.rnn.parameters()})
        params.append({'params': self.label_wise_attention.parameters()})
        if self.config.ZAGRNN.use_gcn:
            params.append({'params': self.gcn.parameters()})
            params.append({'params': self.doc_out_transform.parameters()})
        return params

    def update_lr(self, optimizer, epoch):
        if epoch > self.config.train.num_epochs_static_embedding:
            for param_group in optimizer.param_groups[:2]:
                param_group["lr"] = self.config.optimizer.learning_rate
        else:
            for param_group in optimizer.param_groups[:2]:
                param_group["lr"] = 0

    def forward(self, batch):
        if self.config.feature.feature_names[0] == "token":
            embedding = self.token_embedding(
                batch[cDataset.DOC_TOKEN].to(self.config.device))
            length = batch[cDataset.DOC_TOKEN_LEN].to(self.config.device)
        else:
            raise NotImplementedError
        doc_embedding, _ = self.rnn(embedding, length)

        label_repr = self.label_embedding(
            batch[cDataset.DOC_LABEL_ID].to(self.config.device))
        attentive_doc_embedding = self.label_wise_attention(doc_embedding, label_repr)

        if self.config.ZAGRNN.use_gcn:
            label_repr_gcn = label_repr
            for gcn_layer in self.gcn:
                label_repr_gcn = gcn_layer(label_repr_gcn, batch[cDataset.DOC_LABEL_RELATION].to(self.config.device))
            label_repr = torch.cat((label_repr, label_repr_gcn), dim=1)

            return torch.sum(self.doc_out_transform(attentive_doc_embedding) * label_repr, dim=-1)

        return torch.sum(attentive_doc_embedding * label_repr, dim=-1)
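
The LabelWiseAttention module is not shown in this example. As a point of reference only, the sketch below shows one common formulation of label-wise attention (an assumption, not necessarily the project's exact implementation): each label embedding attends over the RNN states, producing one document vector per label, which is then scored against that label, mirroring the non-GCN branch of ZAGRNN.forward.

import torch

batch, seq_len, feat_dim, num_labels, label_dim = 2, 9, 64, 5, 64

doc_hidden = torch.randn(batch, seq_len, feat_dim)     # RNN outputs
label_repr = torch.randn(num_labels, label_dim)        # label embeddings (label_dim == feat_dim here)

# attention of every label over every position: [batch, num_labels, seq_len]
scores = torch.softmax(torch.einsum('bsf,lf->bls', doc_hidden, label_repr), dim=-1)
# label-specific document embeddings: [batch, num_labels, feat_dim]
attentive_doc = torch.einsum('bls,bsf->blf', scores, doc_hidden)
# per-label logits, as in the non-GCN branch above
logits = torch.sum(attentive_doc * label_repr.unsqueeze(0), dim=-1)   # [batch, num_labels]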
Code example #9
File: zsjlrnn.py  Project: MemoriesJ/KAMG
class ZSJLRNN(Classifier):
    def __init__(self, dataset, config):

        assert config.label_embedding.dimension == config.ZSJLRNN.gcn_in_features, \
            "label embedding dimension should be same as gcn input feature dimension"
        assert len(config.data.label_relation_files) >= 2, \
            "this model should utilize at least 2 different graphs' adjacency"

        super(ZSJLRNN, self).__init__(dataset, config)

        self.rnn = RNN(config.embedding.dimension,
                       config.ZSJLRNN.hidden_dimension,
                       num_layers=config.ZSJLRNN.num_layers,
                       batch_first=True,
                       bidirectional=config.ZSJLRNN.bidirectional,
                       rnn_type=config.ZSJLRNN.rnn_type)

        self.label_wise_attention = LabelWiseAttention(
            feat_dim=config.ZSJLRNN.hidden_dimension * 2 if
            config.ZSJLRNN.bidirectional else config.ZSJLRNN.hidden_dimension,
            label_emb_dim=config.label_embedding.dimension,
            store_attention_score=config.ZSJLRNN.store_attention_score)

        self.multi_gcn = torch.nn.ModuleList([
            MultiGraphConvolution(
                n_adj=len(config.data.label_relation_files),
                in_features=config.ZSJLRNN.gcn_in_features,
                out_features=config.ZSJLRNN.gcn_hidden_features,
                bias=True,
                act=torch.relu_,
                featureless=False,
                dropout=config.ZSJLRNN.gcn_dropout),
            MultiGraphConvolution(
                n_adj=len(config.data.label_relation_files),
                in_features=config.ZSJLRNN.gcn_hidden_features,
                out_features=config.ZSJLRNN.gcn_out_features,
                bias=True,
                act=torch.relu_,
                featureless=False,
                dropout=config.ZSJLRNN.gcn_dropout)
        ])

        self.multi_gcn_fuse = Fusion(config)

        if config.fusion.fusion_type == FusionType.CONCATENATION:
            out_tmp = config.ZSJLRNN.gcn_in_features + config.fusion.out_features
        elif config.fusion.fusion_type == FusionType.ATTACH:
            out_tmp = config.ZSJLRNN.gcn_in_features + \
                      config.ZSJLRNN.gcn_out_features * len(config.data.label_relation_files)
        else:
            raise NotImplementedError
        self.doc_out_transform = torch.nn.Sequential(
            torch.nn.Linear(in_features=config.ZSJLRNN.hidden_dimension *
                            2 if config.ZSJLRNN.bidirectional else
                            config.ZSJLRNN.hidden_dimension,
                            out_features=out_tmp), torch.nn.ReLU())

    def get_parameter_optimizer_dict(self):
        params = list()
        params.append({'params': self.token_embedding.parameters()})
        params.append({'params': self.rnn.parameters()})
        params.append({'params': self.label_wise_attention.parameters()})
        params.append({'params': self.multi_gcn.parameters()})
        params.append({'params': self.multi_gcn_fuse.parameters()})
        params.append({'params': self.doc_out_transform.parameters()})
        return params

    def update_lr(self, optimizer, epoch):
        if epoch > self.config.train.num_epochs_static_embedding:
            for param_group in optimizer.param_groups[:2]:
                param_group["lr"] = self.config.optimizer.learning_rate
        else:
            for param_group in optimizer.param_groups[:2]:
                param_group["lr"] = 0

    def forward(self, batch):
        if self.config.feature.feature_names[0] == "token":
            embedding = self.token_embedding(batch[cDataset.DOC_TOKEN].to(
                self.config.device))
            length = batch[cDataset.DOC_TOKEN_LEN].to(self.config.device)
        else:
            raise NotImplementedError
        doc_embedding, _ = self.rnn(embedding, length)

        label_repr = self.label_embedding(batch[cDataset.DOC_LABEL_ID].to(
            self.config.device))
        attentive_doc_embedding = self.label_wise_attention(
            doc_embedding, label_repr)

        label_repr_gcn = torch.unsqueeze(label_repr, dim=0).repeat(
            len(self.config.data.label_relation_files), 1, 1)
        for gcn_layer in self.multi_gcn:
            label_repr_gcn = gcn_layer(
                label_repr_gcn,
                batch[cDataset.DOC_LABEL_RELATION].to(self.config.device))
        # do some other fusion operations
        label_repr_gcn = self.multi_gcn_fuse(
            torch.unbind(label_repr_gcn, dim=0))
        label_repr = torch.cat((label_repr, label_repr_gcn), dim=1)

        return torch.sum(self.doc_out_transform(attentive_doc_embedding) *
                         label_repr,
                         dim=-1)
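
The GraphConvolution / MultiGraphConvolution layers used here are defined elsewhere in the project. For orientation only, a plain Kipf-style graph convolution over the label embeddings would look roughly like the sketch below; this is an assumption about the general technique, not the project's exact layer.

import torch

num_labels, in_features, out_features = 6, 32, 16
label_repr = torch.randn(num_labels, in_features)
adjacency = torch.rand(num_labels, num_labels)          # label-relation graph (stand-in)

# symmetric normalization: D^{-1/2} (A + I) D^{-1/2}
a_hat = adjacency + torch.eye(num_labels)
d_inv_sqrt = torch.diag(a_hat.sum(dim=1).pow(-0.5))
a_norm = d_inv_sqrt @ a_hat @ d_inv_sqrt

weight = torch.nn.Parameter(torch.randn(in_features, out_features) * 0.1)
label_repr_gcn = torch.relu(a_norm @ label_repr @ weight)   # [num_labels, out_features]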
Code example #10
class TextRCNN(Classifier):
    """TextRNN + TextCNN
    """
    def __init__(self, dataset, config):
        super(TextRCNN, self).__init__(dataset, config)
        self.rnn = RNN(
            config.embedding.dimension, config.TextRCNN.hidden_dimension,
            num_layers=config.TextRCNN.num_layers,
            batch_first=True, bidirectional=config.TextRCNN.bidirectional,
            rnn_type=config.TextRCNN.rnn_type)

        hidden_dimension = config.TextRCNN.hidden_dimension
        if config.TextRCNN.bidirectional:
            hidden_dimension *= 2
        self.kernel_sizes = config.TextRCNN.kernel_sizes
        self.convs = torch.nn.ModuleList()
        for kernel_size in self.kernel_sizes:
            self.convs.append(torch.nn.Conv1d(
                hidden_dimension, config.TextRCNN.num_kernels,
                kernel_size, padding=kernel_size - 1))

        self.top_k = self.config.TextRCNN.top_k_max_pooling
        hidden_size = len(config.TextRCNN.kernel_sizes) * \
                      config.TextRCNN.num_kernels * self.top_k

        self.linear = torch.nn.Linear(hidden_size, len(dataset.label_map))
        self.dropout = torch.nn.Dropout(p=config.train.hidden_layer_dropout)

    def get_parameter_optimizer_dict(self):
        params = list()
        params.append({'params': self.token_embedding.parameters()})
        params.append({'params': self.char_embedding.parameters()})
        params.append({'params': self.rnn.parameters()})
        params.append({'params': self.convs.parameters()})
        params.append({'params': self.linear.parameters()})
        return params

    def update_lr(self, optimizer, epoch):
        """
        """
        if epoch > self.config.train.num_epochs_static_embedding:
            for param_group in optimizer.param_groups[:2]:
                param_group["lr"] = self.config.optimizer.learning_rate
        else:
            for param_group in optimizer.param_groups[:2]:
                param_group["lr"] = 0

    def forward(self, batch):
        if self.config.feature.feature_names[0] == "token":
            embedding = self.token_embedding(
                batch[cDataset.DOC_TOKEN].to(self.config.device))
            length = batch[cDataset.DOC_TOKEN_LEN].to(self.config.device)
        else:
            embedding = self.char_embedding(
                batch[cDataset.DOC_CHAR].to(self.config.device))
            length = batch[cDataset.DOC_CHAR_LEN].to(self.config.device)
        output, _ = self.rnn(embedding, length)

        doc_embedding = output.transpose(1, 2)
        pooled_outputs = []
        for _, conv in enumerate(self.convs):
            convolution = F.relu(conv(doc_embedding))
            pooled = torch.topk(convolution, self.top_k)[0].view(
                convolution.size(0), -1)
            pooled_outputs.append(pooled)

        doc_embedding = torch.cat(pooled_outputs, 1)

        return self.dropout(self.linear(doc_embedding))
Code example #11
class DRNN(Classifier):
    def __init__(self, dataset, config):
        super(DRNN, self).__init__(dataset, config)
        self.rnn_type = config.DRNN.rnn_type
        self.forward_rnn = RNN(config.embedding.dimension,
                               config.DRNN.hidden_dimension,
                               batch_first=True,
                               rnn_type=config.DRNN.rnn_type)
        if config.DRNN.bidirectional:
            self.backward_rnn = RNN(config.embedding.dimension,
                                    config.DRNN.hidden_dimension,
                                    batch_first=True,
                                    rnn_type=config.DRNN.rnn_type)
        self.window_size = config.DRNN.window_size
        self.dropout = torch.nn.Dropout(p=config.DRNN.cell_hidden_dropout)
        self.hidden_dimension = config.DRNN.hidden_dimension
        if config.DRNN.bidirectional:
            self.hidden_dimension *= 2
        self.batch_norm = torch.nn.BatchNorm1d(self.hidden_dimension)

        self.mlp = torch.nn.Linear(self.hidden_dimension,
                                   self.hidden_dimension)
        self.linear = torch.nn.Linear(self.hidden_dimension,
                                      len(dataset.label_map))

    def get_parameter_optimizer_dict(self):
        params = super(DRNN, self).get_parameter_optimizer_dict()
        params.append({'params': self.forward_rnn.parameters()})
        if self.config.DRNN.bidirectional:
            params.append({'params': self.backward_rnn.parameters()})
        params.append({'params': self.batch_norm.parameters()})
        params.append({'params': self.mlp.parameters()})
        params.append({'params': self.linear.parameters()})
        return params

    def forward(self, batch):
        front_pad_embedding, _, mask = self.get_embedding(
            batch, [self.window_size - 1, 0], cDataset.VOCAB_PADDING_LEARNABLE)
        front_pad_embedding = self.token_similarity_attention(
            front_pad_embedding)
        if self.config.DRNN.bidirectional:
            tail_pad_embedding, _, _ = self.get_embedding(
                batch, [0, self.window_size - 1],
                cDataset.VOCAB_PADDING_LEARNABLE)
        batch_size = front_pad_embedding.size(0)
        mask = mask.unsqueeze(2)

        front_slice_embedding_list = \
            [front_pad_embedding[:, i:i + self.window_size, :] for i in
             range(front_pad_embedding.size(1) - self.window_size + 1)]

        front_slice_embedding = torch.cat(front_slice_embedding_list, dim=0)

        state = None
        for i in range(front_slice_embedding.size(1)):
            _, state = self.forward_rnn(front_slice_embedding[:, i:i + 1, :],
                                        init_state=state,
                                        ori_state=True)
            if self.rnn_type == RNNType.LSTM:
                state[0] = self.dropout(state[0])
            else:
                state = self.dropout(state)
        front_state = state[0] if self.rnn_type == RNNType.LSTM else state
        front_state = front_state.transpose(0, 1)
        front_hidden = torch.cat(front_state.split(batch_size, dim=0), dim=1)
        front_hidden = front_hidden * mask

        hidden = front_hidden
        if self.config.DRNN.bidirectional:
            tail_slice_embedding_list = list()
            for i in range(tail_pad_embedding.size(1) - self.window_size + 1):
                slice_embedding = \
                    tail_pad_embedding[:, i:i + self.window_size, :]
                tail_slice_embedding_list.append(slice_embedding)
            tail_slice_embedding = torch.cat(tail_slice_embedding_list, dim=0)

            state = None
            for i in range(tail_slice_embedding.size(1), 0, -1):
                _, state = self.backward_rnn(tail_slice_embedding[:,
                                                                  i - 1:i, :],
                                             init_state=state,
                                             ori_state=True)
                if i != tail_slice_embedding.size(1) - 1:
                    if self.rnn_type == RNNType.LSTM:
                        state[0] = self.dropout(state[0])
                    else:
                        state = self.dropout(state)
            tail_state = state[0] if self.rnn_type == RNNType.LSTM else state
            tail_state = tail_state.transpose(0, 1)
            tail_hidden = torch.cat(tail_state.split(batch_size, dim=0), dim=1)
            tail_hidden = tail_hidden * mask
            hidden = torch.cat([hidden, tail_hidden], dim=2)

        hidden = hidden.transpose(1, 2).contiguous()

        batch_normed = self.batch_norm(hidden).transpose(1, 2)
        batch_normed = batch_normed * mask
        mlp_hidden = self.mlp(batch_normed)
        mlp_hidden = mlp_hidden * mask
        neg_mask = (mask - 1) * 65500.0
        mlp_hidden = mlp_hidden + neg_mask
        max_pooling = torch.nn.functional.max_pool1d(
            mlp_hidden.transpose(1, 2), mlp_hidden.size(1)).squeeze()
        return self.linear(self.dropout(max_pooling))

    def token_similarity_attention(self, output):
        # output: (batch, sentence length, embedding dim)
        symptom_id_list = [
            6, 134, 15, 78, 2616, 257, 402, 281, 14848, 71, 82, 96, 352, 60,
            227, 204, 178, 175, 233, 192, 416, 91, 232, 317, 17513, 628, 1047
        ]
        symptom_embedding = self.token_embedding(
            torch.LongTensor(symptom_id_list).cuda())
        # symptom_embedding: torch.tensor(symptom_num, embedding dim)
        batch_symptom_embedding = torch.cat(
            [symptom_embedding.view(1, symptom_embedding.shape[0], -1)] *
            output.shape[0],
            dim=0)
        similarity = torch.sigmoid(
            torch.bmm(
                torch.nn.functional.normalize(output, dim=2),
                torch.nn.functional.normalize(batch_symptom_embedding.permute(
                    0, 2, 1),
                                              dim=2)))
        #similarity = torch.bmm(torch.nn.functional.normalize(output, dim=2), torch.nn.functional.normalize(batch_symptom_embedding.permute(0, 2, 1), dim=2))
        #similarity = torch.sigmoid(torch.max(similarity, dim=2)[0])
        similarity = torch.max(similarity, dim=2)[0]
        #similarity = torch.sigmoid(torch.sum(similarity, dim=2))
        # similarity: torch.tensor(batch, sentence_len)
        similarity = torch.cat([similarity.view(similarity.shape[0], -1, 1)] *
                               output.shape[2],
                               dim=2)
        # similarity: torch.tensor(batch, sentence_len, embedding dim)
        #sentence_embedding = torch.sum(torch.mul(similarity, output), dim=1)
        # sentence_embedding: (batch, embedding)
        sentence_embedding = torch.mul(similarity, output)
        # sentence_embedding: (batch, sentence len, embedding)
        return sentence_embedding
Code example #12
class TextRNN(Classifier):
    """Implement TextRNN, contains LSTM,BiLSTM,GRU,BiGRU
    Reference: "Effective LSTMs for Target-Dependent Sentiment Classification"
               "Bidirectional LSTM-CRF Models for Sequence Tagging"
               "Generative and discriminative text classification
                with recurrent neural networks"
    """
    def __init__(self, dataset, config):
        super(TextRNN, self).__init__(dataset, config)
        self.rnn = RNN(config.embedding.dimension,
                       config.TextRNN.hidden_dimension,
                       num_layers=config.TextRNN.num_layers,
                       batch_first=True,
                       bidirectional=config.TextRNN.bidirectional,
                       rnn_type=config.TextRNN.rnn_type)
        self.rnn2 = RNN(config.embedding.dimension,
                        config.TextRNN.hidden_dimension,
                        num_layers=config.TextRNN.num_layers,
                        batch_first=True,
                        bidirectional=config.TextRNN.bidirectional,
                        rnn_type=config.TextRNN.rnn_type)
        hidden_dimension = config.TextRNN.hidden_dimension
        if config.TextRNN.bidirectional:
            hidden_dimension *= 2
        self.sum_attention = SumAttention(hidden_dimension,
                                          config.TextRNN.attention_dimension,
                                          config.device)
        self.linear = torch.nn.Linear(hidden_dimension, len(dataset.label_map))
        self.dropout = torch.nn.Dropout(p=config.train.hidden_layer_dropout)

    def get_parameter_optimizer_dict(self):
        params = super(TextRNN, self).get_parameter_optimizer_dict()
        params.append({'params': self.rnn.parameters()})
        params.append({'params': self.linear.parameters()})
        return params

    def update_lr(self, optimizer, epoch):
        if epoch > self.config.train.num_epochs_static_embedding:
            for param_group in optimizer.param_groups[:2]:
                param_group["lr"] = self.config.optimizer.learning_rate
        else:
            for param_group in optimizer.param_groups[:2]:
                param_group["lr"] = 0.0

    def forward(self, batch):
        if self.config.feature.feature_names[0] == "token":
            embedding = self.token_embedding(batch[cDataset.DOC_TOKEN].to(
                self.config.device))
            length = batch[cDataset.DOC_TOKEN_LEN].to(self.config.device)
        else:
            embedding = self.char_embedding(batch[cDataset.DOC_CHAR].to(
                self.config.device))
            length = batch[cDataset.DOC_CHAR_LEN].to(self.config.device)
        #output1, _ = self.rnn2(embedding, length)
        #embedding = self.token_similarity_attention(embedding)
        output, last_hidden = self.rnn(embedding, length)

        doc_embedding_type = self.config.TextRNN.doc_embedding_type
        if doc_embedding_type == DocEmbeddingType.AVG:
            doc_embedding = torch.sum(output, 1) / length.unsqueeze(1)
        elif doc_embedding_type == DocEmbeddingType.ATTENTION:
            doc_embedding = self.sum_attention(output)
        elif doc_embedding_type == DocEmbeddingType.LAST_HIDDEN:
            doc_embedding = last_hidden
        else:
            raise TypeError(
                "Unsupported rnn init type: %s. Supported rnn type is: %s" %
                (doc_embedding_type, DocEmbeddingType.str()))

        return self.dropout(self.linear(doc_embedding))

    def token_similarity_attention(self, output):
        # output: (batch, sentence length, embedding dim)
        symptom_id_list = [
            6, 134, 15, 78, 2616, 257, 402, 281, 14848, 71, 82, 96, 352, 60,
            227, 204, 178, 175, 233, 192, 416, 91, 232, 317, 17513, 628, 1047
        ]
        symptom_embedding = self.token_embedding(
            torch.LongTensor(symptom_id_list).cuda())
        # symptom_embedding: torch.tensor(symptom_num, embedding dim)
        batch_symptom_embedding = torch.cat(
            [symptom_embedding.view(1, symptom_embedding.shape[0], -1)] *
            output.shape[0],
            dim=0)
        similarity = torch.sigmoid(
            torch.bmm(
                torch.nn.functional.normalize(output, dim=2),
                torch.nn.functional.normalize(batch_symptom_embedding.permute(
                    0, 2, 1),
                                              dim=2)))
        #similarity = torch.bmm(torch.nn.functional.normalize(output, dim=2), torch.nn.functional.normalize(batch_symptom_embedding.permute(0, 2, 1), dim=2))
        #similarity = torch.sigmoid(torch.max(similarity, dim=2)[0])
        similarity = torch.max(similarity, dim=2)[0]
        #similarity = torch.sigmoid(torch.sum(similarity, dim=2))
        # similarity: torch.tensor(batch, sentence_len)
        similarity = torch.cat([similarity.view(similarity.shape[0], -1, 1)] *
                               output.shape[2],
                               dim=2)
        # similarity: torch.tensor(batch, sentence_len, embedding dim)
        #sentence_embedding = torch.sum(torch.mul(similarity, output), dim=1)
        # sentence_embedding: (batch, embedding)
        sentence_embedding = torch.mul(similarity, output)
        # sentence_embedding: (batch, sentence len, embedding)
        return sentence_embedding