示例#1
0
    def __init__(self, cfg, vocab):
        """Construct the model's sub-modules and record their output sizes.

        Arguments:
            cfg {dict} -- config parameters for constructing multiple models
            vocab {Vocabulary} -- vocabulary
        """

        super().__init__()
        self.vocab = vocab
        self.lstm_layers = cfg.lstm_layers

        # Token embedding: word+char embeddings or a BERT-based encoder.
        embed_cls = (WordCharEmbedModel
                     if cfg.embedding_model == 'word_char' else BertEmbedModel)
        self.embedding_model = embed_cls(cfg, vocab)

        self.encoder_output_size = self.embedding_model.get_hidden_size()

        # Optional BiLSTM on top of the embeddings; it changes the
        # representation width, so refresh `encoder_output_size`.
        if self.lstm_layers > 0:
            self.seq_encoder = BiLSTMEncoder(
                input_size=self.encoder_output_size,
                hidden_size=cfg.lstm_hidden_unit_dims,
                num_layers=cfg.lstm_layers,
                dropout=cfg.dropout)
            self.encoder_output_size = self.seq_encoder.get_output_dims()

        # Softmax decoder that tags each token with a span label.
        self.ent_span_decoder = SeqSoftmaxDecoder(
            hidden_size=self.encoder_output_size,
            label_size=self.vocab.get_vocab_size('entity_span_labels'))

        # CNN classifier that assigns a type to each candidate span.
        self.cnn_ent_model = CNNEntModel(cfg, vocab, self.encoder_output_size)
示例#2
0
class PipelineEntModel(nn.Module):
    """This class utilizes pipeline method to handle
    entity recognition task, firstly detecting entity
    spans with sequence labeling then using CNN for entity spans typing
    """
    def __init__(self, cfg, vocab):
        """This function constructs `PipelineEntModel` components and
        sets `PipelineEntModel` parameters

        Arguments:
            cfg {dict} -- config parameters for constructing multiple models
            vocab {Vocabulary} -- vocabulary
        """

        super().__init__()
        self.vocab = vocab
        self.lstm_layers = cfg.lstm_layers

        # Token embedding: word+char embeddings or a BERT-based encoder.
        if cfg.embedding_model == 'word_char':
            self.embedding_model = WordCharEmbedModel(cfg, vocab)
        else:
            self.embedding_model = BertEmbedModel(cfg, vocab)

        self.encoder_output_size = self.embedding_model.get_hidden_size()

        # Optional BiLSTM on top of the embeddings; it changes the
        # representation width, so refresh `encoder_output_size`.
        if self.lstm_layers > 0:
            self.seq_encoder = BiLSTMEncoder(
                input_size=self.encoder_output_size,
                hidden_size=cfg.lstm_hidden_unit_dims,
                num_layers=cfg.lstm_layers,
                dropout=cfg.dropout)
            self.encoder_output_size = self.seq_encoder.get_output_dims()

        # Softmax decoder that tags each token with a span label (stage 1).
        self.ent_span_decoder = SeqSoftmaxDecoder(
            hidden_size=self.encoder_output_size,
            label_size=self.vocab.get_vocab_size('entity_span_labels'))

        # CNN classifier that assigns a type to each candidate span (stage 2).
        self.cnn_ent_model = CNNEntModel(cfg, vocab, self.encoder_output_size)

    def forward(self, batch_inputs):
        """This function propagates forwardly

        Arguments:
            batch_inputs {dict} -- batch input data

        Returns:
            dict -- results: ent_loss, ent_pred
        """

        results = {}

        batch_seq_entity_span_labels = batch_inputs['entity_span_labels']
        batch_seq_tokens_lens = batch_inputs['tokens_lens']
        batch_seq_tokens_mask = batch_inputs['tokens_mask']

        # The embedding model writes its representations into `batch_inputs`.
        self.embedding_model(batch_inputs)
        batch_seq_tokens_encoder_repr = batch_inputs['seq_encoder_reprs']

        if self.lstm_layers > 0:
            batch_seq_encoder_repr = self.seq_encoder(
                batch_seq_tokens_encoder_repr,
                batch_seq_tokens_lens).contiguous()
        else:
            batch_seq_encoder_repr = batch_seq_tokens_encoder_repr

        batch_inputs['seq_encoder_reprs'] = batch_seq_encoder_repr

        # Stage 1: sequence labeling detects candidate entity spans.
        entity_span_outputs = self.ent_span_decoder(
            batch_seq_encoder_repr, batch_seq_tokens_mask,
            batch_seq_entity_span_labels)

        batch_inputs['ent_span_preds'] = entity_span_outputs['predict']

        results['ent_span_loss'] = entity_span_outputs['loss']
        results['sequence_label_preds'] = entity_span_outputs['predict']

        # Stage 2: the CNN model types each candidate span.
        ent_model_outputs = self.cnn_ent_model(batch_inputs)
        ent_preds = self.get_ent_preds(batch_inputs, ent_model_outputs)

        results['all_ent_span_preds'] = batch_inputs['all_candi_ents']
        # Joint loss: span detection loss + span typing loss.
        results['ent_loss'] = (entity_span_outputs['loss'] +
                               ent_model_outputs['ent_loss'])
        results['all_ent_preds'] = ent_preds

        return results

    def get_ent_preds(self, batch_inputs, ent_model_outputs):
        """This function gets entity predictions from entity model outputs

        Arguments:
            batch_inputs {dict} -- batch input data
            ent_model_outputs {dict} -- entity model outputs

        Returns:
            list -- entity predictions, one span->label dict per sentence
        """

        ent_preds = []
        candi_ent_cnt = 0
        for ents in batch_inputs['all_candi_ents']:
            cur_ents_num = len(ents)
            # Predictions are flat across the batch; slice out this
            # sentence's portion before pairing spans with labels.
            cur_preds = ent_model_outputs['ent_preds'][
                candi_ent_cnt:candi_ent_cnt + cur_ents_num]
            ent_pred = {}
            for ent, pred in zip(ents, cur_preds):
                ent_pred_label = self.vocab.get_token_from_index(
                    pred.item(), 'span2ent')
                # Keep only spans predicted as a real entity type.
                if ent_pred_label != 'None':
                    ent_pred[ent] = ent_pred_label
            ent_preds.append(ent_pred)
            candi_ent_cnt += cur_ents_num

        return ent_preds

    def get_hidden_size(self):
        """This function returns sentence encoder representation tensor size

        Returns:
            int -- sequence encoder output size
        """

        return self.encoder_output_size

    def get_ent_span_feature_size(self):
        """This function returns entity span feature size

        Returns:
            int -- entity span feature size
        """
        return self.cnn_ent_model.get_ent_span_feature_size()
示例#3
0
class JointEntModel(nn.Module):
    """This class regards entity recognition task as a sequence labeling
    task, and utilizes a BiLSTM model to handle it
    """
    def __init__(self, cfg, vocab):
        """This function constructs `JointEntModel` components and
        sets `JointEntModel` parameters

        Arguments:
            cfg {dict} -- config parameters for constructing multiple models
            vocab {Vocabulary} -- vocabulary
        """

        super().__init__()
        self.vocab = vocab
        self.lstm_layers = cfg.lstm_layers

        # Token embedding: word+char embeddings or a BERT-based encoder.
        if cfg.embedding_model == 'word_char':
            self.embedding_model = WordCharEmbedModel(cfg, vocab)
        else:
            self.embedding_model = BertEmbedModel(cfg, vocab)

        self.encoder_output_size = self.embedding_model.get_hidden_size()

        # Optional BiLSTM on top of the embeddings; it changes the
        # representation width, so refresh `encoder_output_size`.
        if self.lstm_layers > 0:
            self.seq_encoder = BiLSTMEncoder(
                input_size=self.encoder_output_size,
                hidden_size=cfg.lstm_hidden_unit_dims,
                num_layers=cfg.lstm_layers,
                dropout=cfg.dropout)
            self.encoder_output_size = self.seq_encoder.get_output_dims()

        # Single softmax decoder over full entity labels (joint detection
        # and typing, unlike the two-stage pipeline variant).
        self.ent_decoder = SeqSoftmaxDecoder(
            hidden_size=self.encoder_output_size,
            label_size=self.vocab.get_vocab_size('entity_labels'))

    def forward(self, batch_inputs):
        """This function propagates forwardly

        Arguments:
            batch_inputs {dict} -- batch input data

        Returns:
            dict -- results: ent_loss, ent_pred
        """

        results = {}

        batch_seq_entity_labels = batch_inputs['entity_labels']
        batch_seq_tokens_lens = batch_inputs['tokens_lens']
        batch_seq_tokens_mask = batch_inputs['tokens_mask']

        # The embedding model writes its representations into `batch_inputs`.
        self.embedding_model(batch_inputs)
        batch_seq_tokens_encoder_repr = batch_inputs['seq_encoder_reprs']

        if self.lstm_layers > 0:
            batch_seq_encoder_repr = self.seq_encoder(
                batch_seq_tokens_encoder_repr,
                batch_seq_tokens_lens).contiguous()
        else:
            batch_seq_encoder_repr = batch_seq_tokens_encoder_repr

        batch_inputs['seq_encoder_reprs'] = batch_seq_encoder_repr

        ent_outputs = self.ent_decoder(batch_seq_encoder_repr,
                                       batch_seq_tokens_mask,
                                       batch_seq_entity_labels)

        batch_inputs['ent_label_preds'] = ent_outputs['predict']

        ent_preds = self.get_ent_preds(batch_inputs)

        results['sequence_label_preds'] = ent_outputs['predict']
        results['ent_loss'] = ent_outputs['loss']
        results['all_ent_preds'] = ent_preds

        return results

    def get_ent_preds(self, batch_inputs):
        """This function gets entity predictions from entity decoder outputs

        Arguments:
            batch_inputs {dict} -- batch input data

        Returns:
            list -- entity predictions, one span->entity dict per sentence
        """

        ent_preds = []
        for idx, seq_len in enumerate(batch_inputs['tokens_lens']):
            # Map label indices back to tag strings, truncated to the
            # real (unpadded) sentence length.
            ent_span_label = [
                self.vocab.get_token_from_index(label.item(), 'entity_labels')
                for label in batch_inputs['ent_label_preds'][idx][:seq_len]
            ]
            # Decode the tag sequence into typed entity spans.
            span2ent = get_entity_span(ent_span_label)
            ent_preds.append(span2ent)

        return ent_preds

    def get_hidden_size(self):
        """This function returns sentence encoder representation tensor size

        Returns:
            int -- sequence encoder output size
        """

        return self.encoder_output_size