Example #1
    def __init__(self, cfg, vocab, seq_encoder_output_size):
        """This function constructs `CNNEntModel` components and
        sets `CNNEntModel` parameters

        Arguments:
            cfg {dict} -- config parameters for constructing multiple models
            vocab {Vocabulary} -- vocabulary
            seq_encoder_output_size {int} -- sequence encoder output size
        """

        super().__init__()
        self.vocab = vocab
        self.span_batch_size = cfg.span_batch_size
        self.ent_output_size = cfg.ent_output_size
        self.activation = nn.GELU()
        self.schedule_k = cfg.schedule_k
        self.device = cfg.device
        self.seq_encoder_output_size = seq_encoder_output_size
        self.pretrain_epoches = cfg.pretrain_epoches

        self.entity_span_extractor = CNNSpanExtractor(
            input_size=self.seq_encoder_output_size,
            num_filters=cfg.entity_cnn_output_channels,
            ngram_filter_sizes=cfg.entity_cnn_kernel_sizes,
            dropout=cfg.dropout)

        if self.ent_output_size > 0:
            self.ent2hidden = BertLinear(
                input_size=self.entity_span_extractor.get_output_dims(),
                output_size=self.ent_output_size,
                activation=self.activation,
                dropout=cfg.dropout)
        else:
            self.ent_output_size = self.entity_span_extractor.get_output_dims()
            self.ent2hidden = lambda x: x

        self.entity_decoder = VanillaSoftmaxDecoder(
            hidden_size=self.ent_output_size,
            label_size=self.vocab.get_vocab_size('span2ent'))
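
The `ent_output_size > 0` branch above is a project-or-identity pattern that recurs throughout these examples: when the configured size is non-positive, the span extractor's output passes through unchanged and `self.ent_output_size` is overwritten with the extractor's true output dims, so downstream modules read the effective size back after construction. A minimal sketch of the same idea, with an illustrative helper name (not from the source):

import torch.nn as nn

def projection_or_identity(in_dim, out_dim, activation, dropout):
    # Project when out_dim > 0; otherwise pass features through unchanged
    # and report in_dim as the effective output size.
    if out_dim > 0:
        layer = nn.Sequential(nn.Linear(in_dim, out_dim), activation,
                              nn.Dropout(dropout))
        return layer, out_dim
    return (lambda x: x), in_dim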
Example #2
    def __init__(self, cfg, momentum=False):
        """This funciton constructs `PretrainedSpanEncoder` components

        Arguments:
            cfg {dict} -- config parameters for constructing multiple models

        Keyword Arguments:
            momentum {bool} -- whether this encoder is a momentum encoder (default: {False})
        """

        super().__init__()
        self.ent_output_size = cfg.ent_output_size
        self.span_batch_size = cfg.span_batch_size
        self.position_embedding_dims = cfg.position_embedding_dims
        self.att_size = cfg.att_size
        self.momentum = momentum
        self.activation = nn.GELU()
        self.device = cfg.device

        self.bert_encoder = BertEncoder(bert_model_name=cfg.bert_model_name,
                                        trainable=cfg.fine_tune,
                                        output_size=cfg.bert_output_size,
                                        activation=self.activation)

        self.entity_span_extractor = CNNSpanExtractor(
            input_size=self.bert_encoder.get_output_dims(),
            num_filters=cfg.entity_cnn_output_channels,
            ngram_filter_sizes=cfg.entity_cnn_kernel_sizes,
            dropout=cfg.dropout)

        if self.ent_output_size > 0:
            self.ent2hidden = BertLinear(input_size=self.entity_span_extractor.get_output_dims(),
                                         output_size=self.ent_output_size,
                                         activation=self.activation,
                                         dropout=cfg.dropout)
        else:
            self.ent_output_size = self.entity_span_extractor.get_output_dims()
            self.ent2hidden = lambda x: x

        self.entity_span_mlp = BertLinear(input_size=self.ent_output_size,
                                          output_size=self.ent_output_size,
                                          activation=self.activation,
                                          dropout=cfg.dropout)
        self.entity_span_decoder = VanillaSoftmaxDecoder(hidden_size=self.ent_output_size,
                                                         label_size=6)

        self.global_position_embedding = nn.Embedding(150, 200)
        self.global_position_embedding.weight.data.normal_(mean=0.0, std=0.02)
        self.masked_token_mlp = BertLinear(input_size=self.bert_encoder.get_output_dims() + 200,
                                           output_size=self.bert_encoder.get_output_dims(),
                                           activation=self.activation,
                                           dropout=cfg.dropout)
        self.masked_token_decoder = nn.Linear(self.bert_encoder.get_output_dims(),
                                              28996,
                                              bias=False)
        self.masked_token_decoder.weight.data.normal_(mean=0.0, std=0.02)
        self.masked_token_decoder_bias = nn.Parameter(torch.zeros(28996))

        self.position_embedding = nn.Embedding(7, self.position_embedding_dims)
        self.position_embedding.weight.data.normal_(mean=0.0, std=0.02)
        self.attention_encoder = PosAwareAttEncoder(self.ent_output_size,
                                                    self.bert_encoder.get_output_dims(),
                                                    2 * self.position_embedding_dims,
                                                    self.att_size,
                                                    activation=self.activation,
                                                    dropout=cfg.dropout)

        self.mlp_head1 = BertLinear(self.ent_output_size,
                                    self.bert_encoder.get_output_dims(),
                                    activation=self.activation,
                                    dropout=cfg.dropout)
        self.mlp_head2 = BertLinear(self.bert_encoder.get_output_dims(),
                                    self.bert_encoder.get_output_dims(),
                                    activation=self.activation,
                                    dropout=cfg.dropout)

        self.masked_token_loss = nn.CrossEntropyLoss()
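
Two hard-coded constants above are worth noting: 28996 is the WordPiece vocabulary size of `bert-base-cased`, and the decoder's bias lives in a separate `nn.Parameter` because the `nn.Linear` is created with `bias=False`. A small sketch of how the masked-token head combines them (the hidden size of 768 is an assumption matching `bert-base-cased`, not taken from the source):

import torch
import torch.nn as nn

hidden_size, vocab_size = 768, 28996  # assumed bert-base-cased dims
decoder = nn.Linear(hidden_size, vocab_size, bias=False)
decoder_bias = nn.Parameter(torch.zeros(vocab_size))

masked_repr = torch.randn(5, hidden_size)     # e.g. 5 masked positions
logits = decoder(masked_repr) + decoder_bias  # (5, 28996) token scores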
Example #3
class CNNEntModel(nn.Module):
    """This class predicts entities using CNN.
    """
    def __init__(self, cfg, vocab, seq_encoder_output_size):
        """This function constructs `CNNEntModel` components and
        sets `CNNEntModel` parameters

        Arguments:
            cfg {dict} -- config parameters for constructing multiple models
            vocab {Vocabulary} -- vocabulary
            seq_encoder_output_size {int} -- sequence encoder output size
        """

        super().__init__()
        self.vocab = vocab
        self.span_batch_size = cfg.span_batch_size
        self.ent_output_size = cfg.ent_output_size
        self.activation = nn.GELU()
        self.schedule_k = cfg.schedule_k
        self.device = cfg.device
        self.seq_encoder_output_size = seq_encoder_output_size
        self.pretrain_epoches = cfg.pretrain_epoches

        self.entity_span_extractor = CNNSpanExtractor(
            input_size=self.seq_encoder_output_size,
            num_filters=cfg.entity_cnn_output_channels,
            ngram_filter_sizes=cfg.entity_cnn_kernel_sizes,
            dropout=cfg.dropout)

        if self.ent_output_size > 0:
            self.ent2hidden = BertLinear(
                input_size=self.entity_span_extractor.get_output_dims(),
                output_size=self.ent_output_size,
                activation=self.activation,
                dropout=cfg.dropout)
        else:
            self.ent_output_size = self.entity_span_extractor.get_output_dims()
            self.ent2hidden = lambda x: x

        self.entity_decoder = VanillaSoftmaxDecoder(
            hidden_size=self.ent_output_size,
            label_size=self.vocab.get_vocab_size('span2ent'))

    def forward(self, batch_inputs):
        """This function propagetes forwardly

        Arguments:
            batch_inputs {dict} -- batch input data

        Returns:
            dict -- results: ent_loss, ent_preds
        """

        batch_seq_encoder_reprs = batch_inputs['seq_encoder_reprs']
        batch_ent_span_label_inputs = batch_inputs['entity_span_labels']
        batch_ent_span_preds = batch_inputs['ent_span_preds']
        seq_lens = batch_inputs['tokens_lens']

        results = {}

        if self.training and self.schedule_k > 0 and 'epoch' in batch_inputs:
            if batch_inputs['epoch'] > self.pretrain_epoches:
                schedule_p = self.schedule_k / (self.schedule_k + np.exp(
                    (batch_inputs['epoch'] - self.pretrain_epoches) /
                    self.schedule_k))
                ent_span_preds = [
                    gold if np.random.random() < schedule_p else pred
                    for gold, pred in zip(batch_ent_span_label_inputs,
                                          batch_ent_span_preds)
                ]
            else:
                ent_span_preds = [gold for gold in batch_ent_span_label_inputs]
            ent_span_preds = torch.stack(ent_span_preds)
        else:
            ent_span_preds = batch_ent_span_preds

        all_candi_ents, all_candi_ent_labels = self.generate_all_candi_ents(
            batch_inputs, ent_span_preds)

        batch_inputs['all_candi_ents'] = all_candi_ents
        batch_inputs['all_candi_ent_labels'] = all_candi_ent_labels

        if sum(len(candi_ents) for candi_ents in all_candi_ents) == 0:
            zero_loss = torch.Tensor([0])
            zero_loss.requires_grad = True
            if self.device > -1:
                zero_loss = zero_loss.cuda(device=self.device,
                                           non_blocking=True)
            batch_size = len(batch_inputs['tokens'])
            ent_labels = [[] for _ in range(batch_size)]
            results['ent_loss'] = zero_loss
            results['ent_preds'] = ent_labels
            return results

        batch_ent_spans_feature = self.cache_ent_spans_feature(
            all_candi_ents, seq_lens, batch_seq_encoder_reprs,
            self.entity_span_extractor, self.span_batch_size, self.device)

        batch_inputs['ent_spans_feature'] = batch_ent_spans_feature

        batch_ents = self.create_batch_ents(batch_inputs,
                                            batch_ent_spans_feature)

        entity_outputs = self.entity_decoder(
            batch_ents['ent_inputs'], batch_ents['all_candi_ent_labels'])

        results['ent_loss'] = entity_outputs['loss']
        results['ent_preds'] = entity_outputs['predict']

        return results

    def get_ent_span_feature_size(self):
        """This funtitoin returns entity span feature size
        
        Returns:
            int -- entity span feature size
        """
        return self.ent_output_size

    def create_batch_ents(self, batch_inputs, batch_ent_spans_feature):
        """This function creates batch entity inputs

        Arguments:
            batch_inputs {dict} -- batch inputs
            batch_ent_spans_feature {list} -- entity spans feature list

        Returns:
            dict -- batch entity inputs
        """

        batch_ents = defaultdict(list)

        for idx, _ in enumerate(batch_inputs['tokens_lens']):
            batch_ents['ent_inputs'].extend(
                batch_ent_spans_feature[idx][ent_span]
                for ent_span in batch_inputs['all_candi_ents'][idx])
            batch_ents['all_candi_ent_labels'].extend(
                batch_inputs['all_candi_ent_labels'][idx])

        batch_ents['ent_inputs'] = torch.stack(batch_ents['ent_inputs'])

        batch_ents['all_candi_ent_labels'] = torch.LongTensor(
            batch_ents['all_candi_ent_labels'])
        if self.device > -1:
            batch_ents['all_candi_ent_labels'] = batch_ents[
                'all_candi_ent_labels'].cuda(device=self.device,
                                             non_blocking=True)

        return batch_ents

    def cache_ent_spans_feature(self, batch_ent_spans, seq_lens,
                                batch_seq_encoder_reprs, ent_span_extractor,
                                ent_batch_size, device):
        """This function extracts all entity spans feature for caching

        Arguments:
            batch_ent_spans {list} -- batch entity spans
            seq_lens {list} -- batch sequence length
            batch_seq_encoder_reprs {list} -- batch sequence encoder representations
            ent_span_extractor {nn.Module} -- entity span extractor model
            ent_batch_size {int} -- entity batch size
            device {int} -- device id: cpu: -1, gpu: >= 0 (default: {-1})

        Returns:
            list -- batch caching spans feature
        """

        assert len(batch_ent_spans) == len(
            batch_seq_encoder_reprs), "batch spans' size is not correct."

        all_spans = []
        all_seq_encoder_reprs = []
        for ent_spans, seq_encoder_reprs in zip(batch_ent_spans,
                                                batch_seq_encoder_reprs):
            all_spans.extend(ent_spans)
            all_seq_encoder_reprs.extend(
                [seq_encoder_reprs for _ in range(len(ent_spans))])

        ent_spans_feature = [{} for _ in range(len(seq_lens))]

        if len(all_spans) == 0:
            return ent_spans_feature

        if ent_batch_size > 0:
            all_spans_feature = []
            for idx in range(0, len(all_spans), ent_batch_size):
                batch_spans_tensor = torch.LongTensor(
                    all_spans[idx:idx + ent_batch_size]).unsqueeze(1)
                if self.device > -1:
                    batch_spans_tensor = batch_spans_tensor.cuda(
                        device=device, non_blocking=True)
                batch_seq_encoder_reprs = torch.stack(
                    all_seq_encoder_reprs[idx:idx + ent_batch_size])

                all_spans_feature.append(
                    ent_span_extractor(batch_seq_encoder_reprs,
                                       batch_spans_tensor).squeeze(1))
            all_spans_feature = torch.cat(all_spans_feature, dim=0)
        else:
            all_spans_tensor = torch.LongTensor(all_spans).unsqueeze(1)
            if self.device > -1:
                all_spans_tensor = all_spans_tensor.cuda(device=device,
                                                         non_blocking=True)
            all_seq_encoder_reprs = torch.stack(all_seq_encoder_reprs)
            all_spans_feature = ent_span_extractor(all_seq_encoder_reprs,
                                                   all_spans_tensor).squeeze(1)

        all_spans_feature = self.ent2hidden(all_spans_feature)

        idx = 0
        for i, ent_spans in enumerate(batch_ent_spans):
            for ent_span in ent_spans:
                ent_spans_feature[i][ent_span] = all_spans_feature[idx]
                idx += 1

        return ent_spans_feature

    def generate_all_candi_ents(self, batch_inputs, ent_span_labels):
        """This funtion generate all candidate entities
        
        Arguments:
            batch_inputs {dict} -- batch input data
            ent_span_labels {list} -- entity span labels list
        
        Returns:
            tuple -- all candidate entities, all candidate entities label
        """

        all_candi_ents = []
        all_candi_ent_labels = []
        for idx, seq_len in enumerate(batch_inputs['tokens_lens']):
            ent_span_label = [
                self.vocab.get_token_from_index(label.item(),
                                                'entity_span_labels')
                for label in ent_span_labels[idx][:seq_len]
            ]
            ent_span_label = [
                item if item == 'O' else item + '-ENT'
                for item in ent_span_label
            ]
            span2ent = get_entity_span(ent_span_label)
            candi_ents = set(span2ent.keys())
            if self.training:
                candi_ents.update(batch_inputs['span2ent'][idx].keys())

            candi_ents = list(candi_ents)
            candi_ent_labels = []
            for ent in candi_ents:
                if ent in batch_inputs['span2ent'][idx]:
                    candi_ent_labels.append(batch_inputs['span2ent'][idx][ent])
                else:
                    candi_ent_labels.append(
                        self.vocab.get_token_index('None', 'span2ent'))

            all_candi_ents.append(candi_ents)
            all_candi_ent_labels.append(candi_ent_labels)

        return all_candi_ents, all_candi_ent_labels
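
The `schedule_p` computed in `forward` is an inverse-sigmoid decay, p = k / (k + exp((epoch - pretrain_epoches) / k)): early in training the model is fed gold span labels almost always, and its own predictions take over as epochs pass, as in scheduled sampling. A standalone sketch of the schedule (the values of `k` and `pretrain_epoches` are illustrative only):

import numpy as np

k, pretrain_epoches = 10, 5  # illustrative values, not from the source
for epoch in range(6, 31, 6):
    p = k / (k + np.exp((epoch - pretrain_epoches) / k))
    print(f"epoch {epoch:2d}: use gold spans with probability {p:.3f}")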
Example #4
    def __init__(self, cfg, vocab, input_size, reduction='mean'):
        """This function sets `EntConRelAttModel` parameters

        Arguments:
            cfg {dict} -- config parameters for constructing multiple models
            vocab {dict} -- vocabulary
            input_size {int} -- input size

        Keyword Arguments:
            reduction {str} -- cross-entropy loss reduction (default: {'mean'})
        """

        super().__init__()

        self.span_batch_size = cfg.span_batch_size
        self.ent_output_size = cfg.ent_output_size
        self.context_output_size = cfg.context_output_size
        self.output_size = cfg.ent_mention_output_size
        self.att_size = cfg.att_size
        self.position_embedding_dims = cfg.position_embedding_dims
        self.activation = nn.GELU()
        self.dropout = cfg.dropout
        self.device = cfg.device

        self.entity_span_extractor = CNNSpanExtractor(
            input_size=input_size,
            num_filters=cfg.entity_cnn_output_channels,
            ngram_filter_sizes=cfg.entity_cnn_kernel_sizes,
            dropout=cfg.dropout)

        if self.ent_output_size > 0:
            self.ent2hidden = BertLinear(
                input_size=self.entity_span_extractor.get_output_dims(),
                output_size=self.ent_output_size,
                activation=self.activation,
                dropout=cfg.dropout)
        else:
            self.ent_output_size = self.entity_span_extractor.get_output_dims()
            self.ent2hidden = lambda x: x

        self.context_span_extractor = CNNSpanExtractor(
            input_size=input_size,
            num_filters=cfg.context_cnn_output_channels,
            ngram_filter_sizes=cfg.context_cnn_kernel_sizes,
            dropout=cfg.dropout)

        if self.context_output_size > 0:
            self.context2hidden = BertLinear(
                input_size=self.context_span_extractor.get_output_dims(),
                output_size=self.context_output_size,
                activation=self.activation,
                dropout=self.dropout)
        else:
            self.context_output_size = self.context_span_extractor.get_output_dims(
            )
            self.context2hidden = lambda x: x

        assert self.ent_output_size == self.context_output_size

        self.position_embedding = nn.Embedding(7, self.position_embedding_dims)
        self.position_embedding.weight.data.normal_(mean=0.0, std=0.02)
        self.attention_encoder = PosAwareAttEncoder(
            self.ent_output_size,
            input_size,
            2 * self.position_embedding_dims,
            self.att_size,
            activation=self.activation,
            dropout=cfg.dropout)
        self.ent_mention_mlp = BertLinear(input_size=self.ent_output_size,
                                          output_size=self.ent_output_size,
                                          activation=self.activation,
                                          dropout=self.dropout)

        if self.output_size > 0:
            self.mlp = BertLinear(input_size=self.ent_output_size,
                                  output_size=self.output_size,
                                  activation=self.activation,
                                  dropout=self.dropout)
        else:
            self.output_size = self.ent_output_size
            self.mlp = lambda x: x

        self.relation_decoder = VanillaSoftmaxDecoder(
            hidden_size=self.output_size,
            label_size=vocab.get_vocab_size('span2rel'),
            reduction=reduction)

        self.subj_pos = torch.LongTensor([-1, 0, 1, 2, 3]) + 3
        self.obj_pos = torch.LongTensor([-3, -2, -1, 0, 1]) + 3
        self.context_zero_feat = torch.zeros(
            self.entity_span_extractor.get_output_dims())
        if self.device > -1:
            self.subj_pos = self.subj_pos.cuda(device=self.device,
                                               non_blocking=True)
            self.obj_pos = self.obj_pos.cuda(device=self.device,
                                             non_blocking=True)
            self.context_zero_feat = self.context_zero_feat.cuda(
                device=self.device, non_blocking=True)
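
The `subj_pos` and `obj_pos` buffers appear to encode the relative offsets of the five relation segments (L, E1, M, E2, R) from the subject entity and the object entity respectively; the +3 shift maps offsets in [-3, 3] onto valid rows of the 7-entry position embedding. A quick sketch of what the indices work out to:

import torch

segments = ['L', 'E1', 'M', 'E2', 'R']
subj_pos = torch.LongTensor([-1, 0, 1, 2, 3]) + 3   # offsets from E1
obj_pos = torch.LongTensor([-3, -2, -1, 0, 1]) + 3  # offsets from E2
for name, s, o in zip(segments, subj_pos.tolist(), obj_pos.tolist()):
    print(f"{name}: subj index {s}, obj index {o}")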
Example #5
class PretrainedSpanEncoder(nn.Module):
    """PretrainedSpanEncoder encodes span into vector.
    """
    def __init__(self, cfg, momentum=False):
        """This funciton constructs `PretrainedSpanEncoder` components

        Arguments:
            cfg {dict} -- config parameters for constructing multiple models

        Keyword Arguments:
            momentum {bool} -- whether this encoder is a momentum encoder (default: {False})
        """

        super().__init__()
        self.ent_output_size = cfg.ent_output_size
        self.span_batch_size = cfg.span_batch_size
        self.position_embedding_dims = cfg.position_embedding_dims
        self.att_size = cfg.att_size
        self.momentum = momentum
        self.activation = nn.GELU()
        self.device = cfg.device

        self.bert_encoder = BertEncoder(bert_model_name=cfg.bert_model_name,
                                        trainable=cfg.fine_tune,
                                        output_size=cfg.bert_output_size,
                                        activation=self.activation)

        self.entity_span_extractor = CNNSpanExtractor(
            input_size=self.bert_encoder.get_output_dims(),
            num_filters=cfg.entity_cnn_output_channels,
            ngram_filter_sizes=cfg.entity_cnn_kernel_sizes,
            dropout=cfg.dropout)

        if self.ent_output_size > 0:
            self.ent2hidden = BertLinear(input_size=self.entity_span_extractor.get_output_dims(),
                                         output_size=self.ent_output_size,
                                         activation=self.activation,
                                         dropout=cfg.dropout)
        else:
            self.ent_output_size = self.entity_span_extractor.get_output_dims()
            self.ent2hidden = lambda x: x

        self.entity_span_mlp = BertLinear(input_size=self.ent_output_size,
                                          output_size=self.ent_output_size,
                                          activation=self.activation,
                                          dropout=cfg.dropout)
        self.entity_span_decoder = VanillaSoftmaxDecoder(hidden_size=self.ent_output_size,
                                                         label_size=6)

        self.global_position_embedding = nn.Embedding(150, 200)
        self.global_position_embedding.weight.data.normal_(mean=0.0, std=0.02)
        self.masked_token_mlp = BertLinear(input_size=self.bert_encoder.get_output_dims() + 200,
                                           output_size=self.bert_encoder.get_output_dims(),
                                           activation=self.activation,
                                           dropout=cfg.dropout)
        self.masked_token_decoder = nn.Linear(self.bert_encoder.get_output_dims(),
                                              28996,
                                              bias=False)
        self.masked_token_decoder.weight.data.normal_(mean=0.0, std=0.02)
        self.masked_token_decoder_bias = nn.Parameter(torch.zeros(28996))

        self.position_embedding = nn.Embedding(7, self.position_embedding_dims)
        self.position_embedding.weight.data.normal_(mean=0.0, std=0.02)
        self.attention_encoder = PosAwareAttEncoder(self.ent_output_size,
                                                    self.bert_encoder.get_output_dims(),
                                                    2 * self.position_embedding_dims,
                                                    self.att_size,
                                                    activation=self.activation,
                                                    dropout=cfg.dropout)

        self.mlp_head1 = BertLinear(self.ent_output_size,
                                    self.bert_encoder.get_output_dims(),
                                    activation=self.activation,
                                    dropout=cfg.dropout)
        self.mlp_head2 = BertLinear(self.bert_encoder.get_output_dims(),
                                    self.bert_encoder.get_output_dims(),
                                    activation=self.activation,
                                    dropout=cfg.dropout)

        self.masked_token_loss = nn.CrossEntropyLoss()

    def forward(self, batch_inputs):
        """This function propagetes forwardly

        Arguments:
            batch_inputs {dict} -- batch inputs
        
        Returns:
            dict -- results
        """

        batch_seq_wordpiece_tokens_repr, batch_seq_cls_repr = self.bert_encoder(
            batch_inputs['wordpiece_tokens'])
        batch_seq_tokens_repr = batched_index_select(batch_seq_wordpiece_tokens_repr,
                                                     batch_inputs['wordpiece_tokens_index'])

        results = {}

        entity_feature = self.entity_span_extractor(batch_seq_tokens_repr,
                                                    batch_inputs['span_mention'])
        entity_feature = self.ent2hidden(entity_feature)

        subj_pos = torch.LongTensor([-1, 0, 1, 2, 3]) + 3
        obj_pos = torch.LongTensor([-3, -2, -1, 0, 1]) + 3

        if self.device > -1:
            subj_pos = subj_pos.cuda(device=self.device, non_blocking=True)
            obj_pos = obj_pos.cuda(device=self.device, non_blocking=True)

        subj_pos_emb = self.position_embedding(subj_pos)
        obj_pos_emb = self.position_embedding(obj_pos)
        pos_emb = torch.cat([subj_pos_emb, obj_pos_emb], dim=1).unsqueeze(0).repeat(
            batch_inputs['wordpiece_tokens_index'].size()[0], 1, 1)

        span_mention_attention_repr = self.attention_encoder(inputs=entity_feature,
                                                             query=batch_seq_cls_repr,
                                                             feature=pos_emb)
        results['span_mention_repr'] = self.mlp_head2(self.mlp_head1(span_mention_attention_repr))

        if self.momentum:
            return results

        zero_loss = torch.Tensor([0])
        zero_loss.requires_grad = True
        if self.device > -1:
            zero_loss = zero_loss.cuda(device=self.device, non_blocking=True)

        if sum([len(masked_index) for masked_index in batch_inputs['masked_index']]) == 0:
            results['masked_token_loss'] = zero_loss
        else:
            masked_wordpiece_tokens_repr = []
            all_masked_label = []
            for masked_index, masked_position, masked_label, seq_wordpiece_tokens_repr in zip(
                    batch_inputs['masked_index'], batch_inputs['masked_position'],
                    batch_inputs['masked_label'], batch_seq_wordpiece_tokens_repr):
                masked_index_tensor = torch.LongTensor(masked_index)
                masked_position_tensor = torch.LongTensor(masked_position)

                if self.device > -1:
                    masked_index_tensor = masked_index_tensor.cuda(device=self.device,
                                                                   non_blocking=True)
                    masked_position_tensor = masked_position_tensor.cuda(device=self.device,
                                                                         non_blocking=True)

                masked_wordpiece_tokens_repr.append(
                    torch.cat([
                        seq_wordpiece_tokens_repr[masked_index_tensor],
                        self.global_position_embedding(masked_position_tensor)
                    ],
                              dim=1))
                all_masked_label.extend(masked_label)

            masked_wordpiece_tokens_input = torch.cat(masked_wordpiece_tokens_repr, dim=0)
            masked_wordpiece_tokens_output = self.masked_token_decoder(
                self.masked_token_mlp(
                    masked_wordpiece_tokens_input)) + self.masked_token_decoder_bias

            all_masked_label_tensor = torch.LongTensor(all_masked_label)
            if self.device > -1:
                all_masked_label_tensor = all_masked_label_tensor.cuda(device=self.device,
                                                                       non_blocking=True)
            results['masked_token_loss'] = self.masked_token_loss(masked_wordpiece_tokens_output,
                                                                  all_masked_label_tensor)

        all_spans = []
        all_spans_label = []
        all_seq_tokens_reprs = []
        for spans, spans_label, seq_tokens_repr in zip(batch_inputs['spans'],
                                                       batch_inputs['spans_label'],
                                                       batch_seq_tokens_repr):
            all_spans.extend(spans)
            all_spans_label.extend(spans_label)
            all_seq_tokens_reprs.extend(seq_tokens_repr for _ in range(len(spans)))

        assert len(all_spans) == len(all_seq_tokens_reprs) and len(all_spans) == len(
            all_spans_label)

        if len(all_spans) == 0:
            results['span_loss'] = zero_loss
        else:
            if self.span_batch_size > 0:
                all_span_loss = []
                for idx in range(0, len(all_spans), self.span_batch_size):
                    batch_ents_tensor = torch.LongTensor(
                        all_spans[idx:idx + self.span_batch_size]).unsqueeze(1)
                    if self.device > -1:
                        batch_ents_tensor = batch_ents_tensor.cuda(device=self.device,
                                                                   non_blocking=True)

                    batch_seq_tokens_reprs = torch.stack(all_seq_tokens_reprs[idx:idx +
                                                                              self.span_batch_size])

                    batch_spans_feature = self.ent2hidden(
                        self.entity_span_extractor(batch_seq_tokens_reprs,
                                                   batch_ents_tensor).squeeze(1))

                    batch_spans_label = torch.LongTensor(all_spans_label[idx:idx +
                                                                         self.span_batch_size])
                    if self.device > -1:
                        batch_spans_label = batch_spans_label.cuda(device=self.device,
                                                                   non_blocking=True)

                    span_outputs = self.entity_span_decoder(
                        self.entity_span_mlp(batch_spans_feature), batch_spans_label)

                    all_span_loss.append(span_outputs['loss'])
                results['span_loss'] = sum(all_span_loss) / len(all_span_loss)
            else:
                all_spans_tensor = torch.LongTensor(all_spans).unsqueeze(1)
                if self.device > -1:
                    all_spans_tensor = all_spans_tensor.cuda(device=self.device, non_blocking=True)
                all_seq_tokens_reprs = torch.stack(all_seq_tokens_reprs)
                all_spans_feature = self.entity_span_extractor(all_seq_tokens_reprs,
                                                               all_spans_tensor).squeeze(1)

                all_spans_feature = self.ent2hidden(all_spans_feature)

                all_spans_label = torch.LongTensor(all_spans_label)
                if self.device > -1:
                    all_spans_label = all_spans_label.cuda(device=self.device, non_blocking=True)

                entity_typing_outputs = self.entity_span_decoder(
                    self.entity_span_mlp(all_spans_feature), all_spans_label)

                results['span_loss'] = entity_typing_outputs['loss']

        return results
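
`batched_index_select` is imported from elsewhere in the repository; judging from its use above (mapping wordpiece-level representations back to token level via per-batch index lists), a minimal sketch of the expected behavior could look like this, assuming `indices` has shape (batch, num_tokens):

import torch

def batched_index_select(target, indices):
    # Sketch: out[b, t] = target[b, indices[b, t]] for each batch element.
    # target: (batch, seq_len, dim); indices: (batch, num_tokens) LongTensor.
    dim = target.size(-1)
    expanded = indices.unsqueeze(-1).expand(*indices.size(), dim)
    return target.gather(1, expanded)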
Example #6
class EntConRelAttModel(nn.Module):
    """This class predicts relation between two candidate entity with attention,
    and extracts entity span representations.
    """
    def __init__(self, cfg, vocab, input_size, reduction='mean'):
        """This function sets `EntConRelAttModel` parameters

        Arguments:
            cfg {dict} -- config parameters for constructing multiple models
            vocab {dict} -- vocabulary
            input_size {int} -- input size

        Keyword Arguments:
            reduction {str} -- cross-entropy loss reduction (default: {'mean'})
        """

        super().__init__()

        self.span_batch_size = cfg.span_batch_size
        self.ent_output_size = cfg.ent_output_size
        self.context_output_size = cfg.context_output_size
        self.output_size = cfg.ent_mention_output_size
        self.att_size = cfg.att_size
        self.position_embedding_dims = cfg.position_embedding_dims
        self.activation = nn.GELU()
        self.dropout = cfg.dropout
        self.device = cfg.device

        self.entity_span_extractor = CNNSpanExtractor(
            input_size=input_size,
            num_filters=cfg.entity_cnn_output_channels,
            ngram_filter_sizes=cfg.entity_cnn_kernel_sizes,
            dropout=cfg.dropout)

        if self.ent_output_size > 0:
            self.ent2hidden = BertLinear(
                input_size=self.entity_span_extractor.get_output_dims(),
                output_size=self.ent_output_size,
                activation=self.activation,
                dropout=cfg.dropout)
        else:
            self.ent_output_size = self.entity_span_extractor.get_output_dims()
            self.ent2hidden = lambda x: x

        self.context_span_extractor = CNNSpanExtractor(
            input_size=input_size,
            num_filters=cfg.context_cnn_output_channels,
            ngram_filter_sizes=cfg.context_cnn_kernel_sizes,
            dropout=cfg.dropout)

        if self.context_output_size > 0:
            self.context2hidden = BertLinear(
                input_size=self.context_span_extractor.get_output_dims(),
                output_size=self.context_output_size,
                activation=self.activation,
                dropout=self.dropout)
        else:
            self.context_output_size = self.context_span_extractor.get_output_dims(
            )
            self.context2hidden = lambda x: x

        assert self.ent_output_size == self.context_output_size

        self.position_embedding = nn.Embedding(7, self.position_embedding_dims)
        self.position_embedding.weight.data.normal_(mean=0.0, std=0.02)
        self.attention_encoder = PosAwareAttEncoder(
            self.ent_output_size,
            input_size,
            2 * self.position_embedding_dims,
            self.att_size,
            activation=self.activation,
            dropout=cfg.dropout)
        self.ent_mention_mlp = BertLinear(input_size=self.ent_output_size,
                                          output_size=self.ent_output_size,
                                          activation=self.activation,
                                          dropout=self.dropout)

        if self.output_size > 0:
            self.mlp = BertLinear(input_size=self.ent_output_size,
                                  output_size=self.output_size,
                                  activation=self.activation,
                                  dropout=self.dropout)
        else:
            self.output_size = self.ent_output_size
            self.mlp = lambda x: x

        self.relation_decoder = VanillaSoftmaxDecoder(
            hidden_size=self.output_size,
            label_size=vocab.get_vocab_size('span2rel'),
            reduction=reduction)

        self.subj_pos = torch.LongTensor([-1, 0, 1, 2, 3]) + 3
        self.obj_pos = torch.LongTensor([-3, -2, -1, 0, 1]) + 3
        self.context_zero_feat = torch.zeros(
            self.entity_span_extractor.get_output_dims())
        if self.device > -1:
            self.subj_pos = self.subj_pos.cuda(device=self.device,
                                               non_blocking=True)
            self.obj_pos = self.obj_pos.cuda(device=self.device,
                                             non_blocking=True)
            self.context_zero_feat = self.context_zero_feat.cuda(
                device=self.device, non_blocking=True)

    def forward(self, batch_inputs):
        """This function propagetes forwardly

        Arguments:
            batch_inputs {dict} -- batch input data

        Returns:
            dict -- results: rel_loss, rel_preds
        """

        all_candi_rels = batch_inputs['all_candi_rels']
        seq_lens = batch_inputs['tokens_lens']
        batch_seq_encoder_reprs = batch_inputs['seq_encoder_reprs']

        batch_entity_spans, batch_context_spans = self.generate_all_spans(
            all_candi_rels, seq_lens)

        batch_entity_spans_feature = self.cache_spans_feature(
            batch_entity_spans, batch_seq_encoder_reprs,
            self.entity_span_extractor, self.span_batch_size, self.device)
        batch_context_spans_feature = self.cache_spans_feature(
            batch_context_spans, batch_seq_encoder_reprs,
            self.context_span_extractor, self.span_batch_size, self.device)

        batch_rels = self.create_batch_rels(batch_inputs,
                                            batch_entity_spans_feature,
                                            batch_context_spans_feature)

        relation_outputs = self.relation_decoder(
            batch_rels['rel_inputs'], batch_rels['all_candi_rel_labels'])

        results = {}
        results['rel_loss'] = relation_outputs['loss']
        results['rel_preds'] = relation_outputs['predict']

        return results

    def create_batch_rels(self, batch_inputs, batch_entity_spans_feature,
                          batch_context_spans_feature):
        """This function creates batch relation inputs

        Arguments:
            batch_inputs {dict} -- batch inputs
            batch_entity_spans_feature {list} -- entity span feature list
            batch_context_spans_feature {list} -- context spans feature list

        Returns:
            dict -- batch relation inputs
        """

        batch_rels = defaultdict(list)

        for idx, seq_len in enumerate(batch_inputs['tokens_lens']):
            batch_rels['all_candi_rel_labels'].extend(
                batch_inputs['all_candi_rel_labels'][idx])
            for e1, e2 in batch_inputs['all_candi_rels'][idx]:
                L = (0, e1[0])
                E1 = (e1[0], e1[1])
                M = (e1[1], e2[0])
                E2 = (e2[0], e2[1])
                R = (e2[1], seq_len)

                if L[0] >= L[1]:
                    batch_rels['L'].append(self.context_zero_feat)
                else:
                    batch_rels['L'].append(batch_context_spans_feature[idx][L])

                if M[0] >= M[1]:
                    batch_rels['M'].append(self.context_zero_feat)
                else:
                    batch_rels['M'].append(batch_context_spans_feature[idx][M])

                if R[0] >= R[1]:
                    batch_rels['R'].append(self.context_zero_feat)
                else:
                    batch_rels['R'].append(batch_context_spans_feature[idx][R])

                batch_rels['E1'].append(batch_entity_spans_feature[idx][E1])
                batch_rels['E2'].append(batch_entity_spans_feature[idx][E2])

                batch_rels['cls'].append(batch_inputs['seq_cls_repr'][idx])

        batch_rels['E1'] = self.ent2hidden(torch.stack(
            batch_rels['E1'])).unsqueeze(1)
        batch_rels['E2'] = self.ent2hidden(torch.stack(
            batch_rels['E2'])).unsqueeze(1)

        batch_rels['L'] = self.ent2hidden(torch.stack(
            batch_rels['L'])).unsqueeze(1)
        batch_rels['M'] = self.ent2hidden(torch.stack(
            batch_rels['M'])).unsqueeze(1)
        batch_rels['R'] = self.ent2hidden(torch.stack(
            batch_rels['R'])).unsqueeze(1)

        batch_rels['cls'] = torch.stack(batch_rels['cls'])

        rel_feature = torch.cat([
            batch_rels['L'], batch_rels['E1'], batch_rels['M'],
            batch_rels['E2'], batch_rels['R']
        ],
                                dim=1)

        subj_pos_emb = self.position_embedding(self.subj_pos)
        obj_pos_emb = self.position_embedding(self.obj_pos)
        pos_emb = torch.cat([subj_pos_emb, obj_pos_emb],
                            dim=1).unsqueeze(0).repeat(rel_feature.size()[0],
                                                       1, 1)

        rel_feature = self.ent_mention_mlp(
            self.attention_encoder(inputs=rel_feature,
                                   query=batch_rels['cls'],
                                   feature=pos_emb))

        batch_rels['rel_inputs'] = self.mlp(rel_feature)

        batch_rels['all_candi_rel_labels'] = torch.LongTensor(
            batch_rels['all_candi_rel_labels'])
        if self.device > -1:
            batch_rels['all_candi_rel_labels'] = batch_rels[
                'all_candi_rel_labels'].cuda(device=self.device,
                                             non_blocking=True)

        return batch_rels

    def cache_spans_feature(self, batch_spans, batch_seq_encoder_reprs,
                            span_extractor, span_batch_size, device):
        """This function calculates spans feature for caching

        Arguments:
            batch_spans {list} -- batch spans
            batch_seq_encoder_reprs {list} -- batch sequence encoder representations
            span_extractor {nn.Module} -- span extractor model
            span_batch_size {int} -- span batch size
            device {int} -- device id: cpu: -1, gpu: >= 0 (default: {-1})

        Returns:
            list -- batch caching spans feature
        """

        assert len(batch_spans) == len(
            batch_seq_encoder_reprs), "batch spans' size is not correct."

        all_spans = []
        all_seq_encoder_reprs = []
        for spans, seq_encoder_reprs in zip(batch_spans,
                                            batch_seq_encoder_reprs):
            all_spans.extend((span[0], span[1]) for span in spans)
            all_seq_encoder_reprs.extend(
                [seq_encoder_reprs for _ in range(len(spans))])

        batch_spans_feature = [{} for _ in range(len(batch_spans))]

        if len(all_spans) == 0:
            return batch_spans_feature

        if span_batch_size > 0:
            all_spans_feature = []
            for idx in range(0, len(all_spans), span_batch_size):
                batch_spans_tensor = torch.LongTensor(
                    all_spans[idx:idx + span_batch_size]).unsqueeze(1)
                if self.device > -1:
                    batch_spans_tensor = batch_spans_tensor.cuda(
                        device=device, non_blocking=True)
                batch_seq_encoder_reprs = torch.stack(
                    all_seq_encoder_reprs[idx:idx + span_batch_size])

                all_spans_feature.append(
                    span_extractor(batch_seq_encoder_reprs,
                                   batch_spans_tensor).squeeze(1))
            all_spans_feature = torch.cat(all_spans_feature, dim=0)
        else:
            all_spans_tensor = torch.LongTensor(all_spans).unsqueeze(1)
            if self.device > -1:
                all_spans_tensor = all_spans_tensor.cuda(device=device,
                                                         non_blocking=True)
            all_seq_encoder_reprs = torch.stack(all_seq_encoder_reprs)
            all_spans_feature = span_extractor(all_seq_encoder_reprs,
                                               all_spans_tensor).squeeze(1)

        idx = 0
        for i, spans in enumerate(batch_spans):
            for span in spans:
                batch_spans_feature[i][span] = all_spans_feature[idx]
                idx += 1

        return batch_spans_feature

    def generate_all_spans(self, all_candi_rels, seq_lens):
        """This function generates all entity and context spans

        Arguments:
            all_candi_rels {list} -- all candidate relation list
            seq_lens {list} -- batch sequence length

        Returns:
            tuple -- batch entity spans, batch context spans
        """

        assert len(all_candi_rels) == len(
            seq_lens), "candidate relations' size is not correct."

        batch_entity_spans = []
        batch_context_spans = []
        for candi_rels, seq_len in zip(all_candi_rels, seq_lens):
            entity_spans = set()
            context_spans = set()
            for e1, e2 in candi_rels:
                if e1[0] < e1[1]:
                    entity_spans.add(e1)
                if e2[0] < e2[1]:
                    entity_spans.add(e2)

                L = (0, e1[0])
                M = (e1[1], e2[0])
                R = (e2[1], seq_len)

                # L, M, R can be empty
                for span in [L, M, R]:
                    if span[0] < span[1]:
                        context_spans.add(span)

            batch_entity_spans.append(list(entity_spans))
            batch_context_spans.append(list(context_spans))

        return batch_entity_spans, batch_context_spans
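
`create_batch_rels` and `generate_all_spans` both rely on the same segment arithmetic: an ordered entity pair splits the sentence into left context L, first entity E1, middle context M, second entity E2, and right context R, where any context segment may be empty. A toy illustration with made-up token indices:

seq_len = 10
e1, e2 = (2, 4), (6, 8)  # half-open token spans, e1 before e2
L, E1, M, E2, R = (0, e1[0]), e1, (e1[1], e2[0]), e2, (e2[1], seq_len)
for name, span in zip(['L', 'E1', 'M', 'E2', 'R'], [L, E1, M, E2, R]):
    print(name, span, '(empty)' if span[0] >= span[1] else '')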
Example #7
    def __init__(self,
                 args,
                 vocab,
                 input_size,
                 ent_span_feature_size,
                 reduction='mean'):
        """Sets `ConRelModel` parameters

        Arguments:
            args {dict} -- config parameters for constructing multiple models
            vocab {dict} -- vocabulary
            input_size {int} -- input size
            ent_span_feature_size {int} -- entity span feature size

        Keyword Arguments:
            reduction {str} -- cross-entropy loss reduction (default: {'mean'})
        """

        super().__init__()

        self.span_batch_size = args.span_batch_size
        self.context_output_size = args.context_output_size
        self.output_size = args.ent_mention_output_size
        self.activation = gelu
        self.dropout = args.dropout
        self.device = args.device

        self.context_span_extractor = CNNSpanExtractor(
            input_size=input_size,
            num_filters=args.context_cnn_output_channels,
            ngram_filter_sizes=args.context_cnn_kernel_sizes,
            dropout=args.dropout)

        if self.context_output_size > 0:
            self.context2hidden = BertLinear(
                input_size=self.context_span_extractor.get_output_dims(),
                output_size=self.context_output_size,
                activation=self.activation,
                dropout=self.dropout)
            self.real_feature_size = 2 * ent_span_feature_size + 3 * self.context_output_size
        else:
            self.context2hidden = lambda x: x
            self.real_feature_size = 2 * ent_span_feature_size + 3 * self.context_span_extractor.get_output_dims(
            )

        if self.output_size > 0:
            self.mlp = BertLinear(input_size=self.real_feature_size,
                                  output_size=self.output_size,
                                  activation=self.activation,
                                  dropout=self.dropout)
        else:
            self.output_size = self.real_feature_size
            self.mlp = lambda x: x

        self.relation_decoder = VanillaSoftmaxDecoder(
            hidden_size=self.output_size,
            label_size=vocab.get_vocab_size('span2rel'),
            reduction=reduction)

        self.context_zero_feat = torch.zeros(
            self.context_span_extractor.get_output_dims())
        if self.device > -1:
            self.context_zero_feat = self.context_zero_feat.cuda(
                device=self.device, non_blocking=True)
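
`real_feature_size` comes from concatenating five segment features: two entity spans (E1, E2) of `ent_span_feature_size` each and three context spans (L, M, R) of `context_output_size` each. With illustrative sizes (not from any real config):

ent_span_feature_size, context_output_size = 200, 100  # illustrative
real_feature_size = 2 * ent_span_feature_size + 3 * context_output_size
print(real_feature_size)  # 700: E1 + E2 + L + M + R, concatenated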
Example #8
class ConRelModel(nn.Module):
    """Predicts relation type between two candidate entity.
    """
    def __init__(self,
                 args,
                 vocab,
                 input_size,
                 ent_span_feature_size,
                 reduction='mean'):
        """Sets `ConRelModel` parameters

        Arguments:
            args {dict} -- config parameters for constructing multiple models
            vocab {dict} -- vocabulary
            input_size {int} -- input size
            ent_span_feature_size {int} -- entity span feature size

        Keyword Arguments:
            reduction {str} -- cross-entropy loss reduction (default: {'mean'})
        """

        super().__init__()

        self.span_batch_size = args.span_batch_size
        self.context_output_size = args.context_output_size
        self.output_size = args.ent_mention_output_size
        self.activation = gelu
        self.dropout = args.dropout
        self.device = args.device

        self.context_span_extractor = CNNSpanExtractor(
            input_size=input_size,
            num_filters=args.context_cnn_output_channels,
            ngram_filter_sizes=args.context_cnn_kernel_sizes,
            dropout=args.dropout)

        if self.context_output_size > 0:
            self.context2hidden = BertLinear(
                input_size=self.context_span_extractor.get_output_dims(),
                output_size=self.context_output_size,
                activation=self.activation,
                dropout=self.dropout)
            self.real_feature_size = 2 * ent_span_feature_size + 3 * self.context_output_size
        else:
            self.context2hidden = lambda x: x
            self.real_feature_size = 2 * ent_span_feature_size + 3 * self.context_span_extractor.get_output_dims(
            )

        if self.output_size > 0:
            self.mlp = BertLinear(input_size=self.real_feature_size,
                                  output_size=self.output_size,
                                  activation=self.activation,
                                  dropout=self.dropout)
        else:
            self.output_size = self.real_feature_size
            self.mlp = lambda x: x

        self.relation_decoder = VanillaSoftmaxDecoder(
            hidden_size=self.output_size,
            label_size=vocab.get_vocab_size('span2rel'),
            reduction=reduction)

        self.context_zero_feat = torch.zeros(
            self.context_span_extractor.get_output_dims())
        if self.device > -1:
            self.context_zero_feat = self.context_zero_feat.cuda(
                device=self.device, non_blocking=True)

    def forward(self, batch_inputs):
        """Propagates forwardly

        Arguments:
            batch_inputs {dict} -- batch input data

        Returns:
            dict -- results: rel_loss, rel_preds
        """

        all_candi_rels = batch_inputs['all_candi_rels']
        seq_lens = batch_inputs['tokens_lens']
        batch_seq_encoder_reprs = batch_inputs['seq_encoder_reprs']

        batch_context_spans = self.generate_all_context_spans(
            all_candi_rels, seq_lens)

        batch_context_spans_feature = self.cache_context_spans_feature(
            batch_context_spans, batch_seq_encoder_reprs,
            self.context_span_extractor, self.span_batch_size, self.device)
        batch_rels = self.create_batch_rels(batch_inputs,
                                            batch_context_spans_feature)

        relation_outputs = self.relation_decoder(
            batch_rels['rel_inputs'], batch_rels['all_candi_rel_labels'])

        results = {}
        results['rel_loss'] = relation_outputs['loss']
        results['rel_preds'] = relation_outputs['predict']

        return results

    def create_batch_rels(self, batch_inputs, batch_context_spans_feature):
        """Creates batch relation inputs

        Arguments:
            batch_inputs {dict} -- batch inputs
            batch_context_spans_feature {list} -- context spans feature list

        Returns:
            dict -- batch relation inputs
        """

        batch_rels = defaultdict(list)

        for idx, seq_len in enumerate(batch_inputs['tokens_lens']):
            batch_rels['all_candi_rel_labels'].extend(
                batch_inputs['all_candi_rel_labels'][idx])
            for e1, e2 in batch_inputs['all_candi_rels'][idx]:
                L = (0, e1[0])
                E1 = (e1[0], e1[1])
                M = (e1[1], e2[0])
                E2 = (e2[0], e2[1])
                R = (e2[1], seq_len)

                if L[0] >= L[1]:
                    batch_rels['L'].append(self.context_zero_feat)
                else:
                    batch_rels['L'].append(batch_context_spans_feature[idx][L])

                if M[0] >= M[1]:
                    batch_rels['M'].append(self.context_zero_feat)
                else:
                    batch_rels['M'].append(batch_context_spans_feature[idx][M])

                if R[0] >= R[1]:
                    batch_rels['R'].append(self.context_zero_feat)
                else:
                    batch_rels['R'].append(batch_context_spans_feature[idx][R])

                batch_rels['E1'].append(
                    batch_inputs['ent_spans_feature'][idx][E1])
                batch_rels['E2'].append(
                    batch_inputs['ent_spans_feature'][idx][E2])

        batch_rels['E1'] = torch.stack(batch_rels['E1'])
        batch_rels['E2'] = torch.stack(batch_rels['E2'])

        batch_rels['L'] = self.context2hidden(torch.stack(batch_rels['L']))
        batch_rels['M'] = self.context2hidden(torch.stack(batch_rels['M']))
        batch_rels['R'] = self.context2hidden(torch.stack(batch_rels['R']))

        rel_feature = torch.cat([
            batch_rels['L'], batch_rels['E1'], batch_rels['M'],
            batch_rels['E2'], batch_rels['R']
        ],
                                dim=1)

        batch_rels['rel_inputs'] = self.mlp(rel_feature)

        batch_rels['all_candi_rel_labels'] = torch.LongTensor(
            batch_rels['all_candi_rel_labels'])
        if self.device > -1:
            batch_rels['all_candi_rel_labels'] = batch_rels[
                'all_candi_rel_labels'].cuda(device=self.device,
                                             non_blocking=True)

        return batch_rels

    def cache_context_spans_feature(self, context_spans,
                                    batch_seq_encoder_reprs,
                                    context_span_extractor, span_batch_size,
                                    device):
        """Calculates all context spans feature for caching

        Arguments:
            context_spans {list} -- context spans
            batch_seq_encoder_reprs {list} -- batch sequence encoder representations
            context_span_extractor {nn.Module} -- context span extractor model
            span_batch_size {int} -- span batch size
            device {int} -- device id: cpu: -1, gpu: >= 0 (default: {-1})

        Returns:
            list -- batch caching spans feature
        """

        assert len(context_spans) == len(
            batch_seq_encoder_reprs), "batch spans' size is not correct."

        all_spans = []
        all_seq_encoder_reprs = []
        for spans, seq_encoder_reprs in zip(context_spans,
                                            batch_seq_encoder_reprs):
            all_spans.extend((span[0], span[1]) for span in spans)
            all_seq_encoder_reprs.extend(
                [seq_encoder_reprs for _ in range(len(spans))])

        batch_spans_feature = [{} for _ in range(len(context_spans))]

        if len(all_spans) == 0:
            return batch_spans_feature

        if span_batch_size > 0:
            all_spans_feature = []
            for idx in range(0, len(all_spans), span_batch_size):
                batch_spans_tensor = torch.LongTensor(
                    all_spans[idx:idx + span_batch_size]).unsqueeze(1)
                if self.device > -1:
                    batch_spans_tensor = batch_spans_tensor.cuda(
                        device=device, non_blocking=True)
                batch_seq_encoder_reprs = torch.stack(
                    all_seq_encoder_reprs[idx:idx + span_batch_size])

                all_spans_feature.append(
                    context_span_extractor(batch_seq_encoder_reprs,
                                           batch_spans_tensor).squeeze(1))
            all_spans_feature = torch.cat(all_spans_feature, dim=0)
        else:
            all_spans_tensor = torch.LongTensor(all_spans).unsqueeze(1)
            if self.device > -1:
                all_spans_tensor = all_spans_tensor.cuda(device=device,
                                                         non_blocking=True)
            all_seq_encoder_reprs = torch.stack(all_seq_encoder_reprs)
            all_spans_feature = context_span_extractor(
                all_seq_encoder_reprs, all_spans_tensor).squeeze(1)

        idx = 0
        for i, spans in enumerate(context_spans):
            for span in spans:
                batch_spans_feature[i][span] = all_spans_feature[idx]
                idx += 1

        return batch_spans_feature
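
    # Chunking sketch (sizes are assumptions for illustration): with
    # span_batch_size=128 and 300 collected spans, the extractor above is
    # called on slices [0:128], [128:256], [256:300], e.g.
    #
    #     for idx in range(0, 300, 128):
    #         chunk = all_spans[idx:idx + 128]   # 128, 128, then 44 spans
    #
    # and the chunk outputs are re-joined with torch.cat(dim=0), bounding
    # peak memory at the cost of extra extractor calls.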

    def generate_all_context_spans(self, all_candi_rels, seq_lens):
        """Generates all context spans

        Arguments:
            all_candi_rels {list} -- all candidate relation list
            seq_lens {list} -- batch sequence length

        Returns:
            list -- all context spans
        """

        assert len(all_candi_rels) == len(seq_lens), \
            "all_candi_rels and seq_lens must have equal length."

        batch_context_spans = []
        for candi_rels, seq_len in zip(all_candi_rels, seq_lens):
            context_spans = set()
            for e1, e2 in candi_rels:
                L = (0, e1[0])
                M = (e1[1], e2[0])
                R = (e2[1], seq_len)

                # L, M, R can be empty
                for span in [L, M, R]:
                    if span[0] >= span[1]:
                        continue
                    context_spans.add(span)

            batch_context_spans.append(list(context_spans))

        return batch_context_spans
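
For clarity, a small standalone sketch of the left/middle/right decomposition implemented above (the function name and values are illustrative, not from the source): for a pair e1=(2, 4), e2=(7, 9) in a sequence of length 12, the contexts are L=(0, 2), M=(4, 7), R=(9, 12), and empty spans are filtered out.

def context_spans_for_pair(e1, e2, seq_len):
    """Return the non-empty (L, M, R) context spans for one candidate pair."""
    candidates = [(0, e1[0]), (e1[1], e2[0]), (e2[1], seq_len)]
    return [span for span in candidates if span[0] < span[1]]

assert context_spans_for_pair((2, 4), (7, 9), 12) == [(0, 2), (4, 7), (9, 12)]
# Adjacent entities covering the whole sequence leave no context at all:
assert context_spans_for_pair((0, 4), (4, 9), 9) == []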
Exemple #10
0
class JointREPretrainedModel(nn.Module):
    """This class contains entity typing, masked token prediction, entity mention permutation prediction, confused entity mention context rank loss, four pretrained tasks in total.
    """
    def __init__(self, cfg):
        """This function decides `JointREPretrainedModel` components

        Arguments:
            cfg {dict} -- config parameters for constructing multiple models
        """

        super().__init__()
        self.ent_output_size = cfg.ent_output_size
        self.context_output_size = cfg.context_output_size
        self.output_size = cfg.ent_mention_output_size
        self.ent_batch_size = cfg.ent_batch_size
        self.permutation_batch_size = cfg.permutation_batch_size
        self.permutation_samples_num = cfg.permutation_samples_num
        self.confused_batch_size = cfg.confused_batch_size
        self.confused_samples_num = cfg.confused_samples_num
        self.activation = gelu
        self.device = cfg.device

        self.bert_encoder = BertEncoder(bert_model_name=cfg.bert_model_name,
                                        trainable=cfg.fine_tune,
                                        output_size=cfg.bert_output_size,
                                        activation=self.activation)

        self.entity_span_extractor = CNNSpanExtractor(
            input_size=self.bert_encoder.get_output_dims(),
            num_filters=cfg.entity_cnn_output_channels,
            ngram_filter_sizes=cfg.entity_cnn_kernel_sizes,
            dropout=cfg.dropout)

        if self.ent_output_size > 0:
            self.ent2hidden = BertLinear(
                input_size=self.entity_span_extractor.get_output_dims(),
                output_size=self.ent_output_size,
                activation=self.activation,
                dropout=cfg.dropout)
        else:
            self.ent_output_size = self.entity_span_extractor.get_output_dims()
            self.ent2hidden = lambda x: x

        self.context_span_extractor = CNNSpanExtractor(
            input_size=self.bert_encoder.get_output_dims(),
            num_filters=cfg.context_cnn_output_channels,
            ngram_filter_sizes=cfg.context_cnn_kernel_sizes,
            dropout=cfg.dropout)

        if self.context_output_size > 0:
            self.context2hidden = BertLinear(
                input_size=self.context_span_extractor.get_output_dims(),
                output_size=self.context_output_size,
                activation=self.activation,
                dropout=cfg.dropout)
        else:
            self.context_output_size = self.context_span_extractor.get_output_dims()
            self.context2hidden = lambda x: x

        if self.output_size > 0:
            self.mlp = BertLinear(
                input_size=2 * self.ent_output_size + 3 * self.context_output_size,
                output_size=self.output_size,
                activation=self.activation,
                dropout=cfg.dropout)
        else:
            self.output_size = 2 * self.ent_output_size + 3 * self.context_output_size
            self.mlp = lambda x: x

        self.entity_pretrained_decoder = VanillaSoftmaxDecoder(
            hidden_size=self.ent_output_size, label_size=18)

        self.masked_token_mlp = BertLinear(
            input_size=self.bert_encoder.get_output_dims(),
            output_size=self.bert_encoder.get_output_dims(),
            activation=self.activation)

        self.token_vocab_size = self.bert_encoder.bert_model.embeddings.word_embeddings.weight.size(0)
        self.masked_token_decoder = nn.Linear(
            self.bert_encoder.get_output_dims(),
            self.token_vocab_size,
            bias=False)
        self.masked_token_decoder.weight.data.normal_(mean=0.0, std=0.02)
        self.masked_token_decoder_bias = nn.Parameter(
            torch.zeros(self.token_vocab_size))

        clone_weights(self.masked_token_decoder,
                      self.bert_encoder.bert_model.embeddings.word_embeddings)

        self.masked_token_loss = nn.CrossEntropyLoss()

        self.permutation_decoder = VanillaSoftmaxDecoder(
            hidden_size=self.output_size, label_size=120)

        self.confused_context_decoder = nn.Linear(self.output_size, 1)
        self.confused_context_decoder.weight.data.normal_(mean=0.0, std=0.02)
        self.confused_context_decoder.bias.data.zero_()

        self.entity_mention_index_tensor = torch.LongTensor([2, 0, 3, 1, 4])
        if self.device > -1:
            self.entity_mention_index_tensor = self.entity_mention_index_tensor.cuda(
                device=self.device, non_blocking=True)

    def forward(self, batch_inputs, pretrain_task=''):
        """This function propagates forwardly

        Arguments:
            batch_inputs {dict} -- batch inputs

        Keyword Arguments:
            pretrain_task {str} -- pretraining task (default: {''})
        
        Returns:
            dict -- results
        """

        if pretrain_task == 'masked_entity_typing':
            return self.masked_entity_typing(batch_inputs)
        elif pretrain_task == 'masked_entity_token_prediction':
            return self.masked_entity_token_prediction(batch_inputs)
        elif pretrain_task == 'entity_mention_permutation':
            return self.permutation_prediction(batch_inputs)
        elif pretrain_task == 'confused_context':
            return self.confused_context_prediction(batch_inputs)
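
    # Hypothetical usage (the trainer code is not part of this file): each
    # pretraining step picks one task string and receives a dict with a
    # scalar 'loss', e.g.
    #
    #     outputs = model(batch_inputs, pretrain_task='masked_entity_typing')
    #     outputs['loss'].backward()
    #
    # An unrecognized task string falls through all branches and returns None.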

    def seq_decoder(self, seq_inputs, seq_mask=None, seq_labels=None):
        """This function decodes masked tokens over the wordpiece vocabulary.

        Arguments:
            seq_inputs {tensor} -- sequence representations

        Keyword Arguments:
            seq_mask {tensor} -- masked-position indicator (default: {None})
            seq_labels {tensor} -- gold token ids (default: {None})

        Returns:
            dict -- token predictions and, if labels are given, the loss
        """

        results = {}
        seq_outputs = self.masked_token_decoder(
            seq_inputs) + self.masked_token_decoder_bias
        seq_log_probs = F.log_softmax(seq_outputs, dim=2)
        seq_preds = seq_log_probs.argmax(dim=2)
        results['predict'] = seq_preds

        if seq_labels is not None:
            if seq_mask is not None:
                active_loss = seq_mask.view(-1) == 1
                active_outputs = seq_outputs.view(
                    -1, self.token_vocab_size)[active_loss]
                active_labels = seq_labels.view(-1)[active_loss]
                no_pad_avg_loss = self.masked_token_loss(
                    active_outputs, active_labels)
                results['loss'] = no_pad_avg_loss
            else:
                avg_loss = self.masked_token_loss(
                    seq_outputs.view(-1, self.token_vocab_size),
                    seq_labels.view(-1))
                results['loss'] = avg_loss

        return results
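
    # Masking sketch (illustrative only): selecting rows where seq_mask == 1
    # before CrossEntropyLoss averages over masked positions only, so padded
    # and unmasked tokens contribute nothing. With logits of shape (B, T, V)
    # and a 0/1 mask of shape (B, T), the computation above is equivalent to:
    #
    #     active = mask.view(-1) == 1                    # (B*T,) bool
    #     loss = F.cross_entropy(logits.view(-1, V)[active],
    #                            labels.view(-1)[active])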

    def masked_entity_typing(self, batch_inputs):
        """This function pretrains masked entity typing task.
        
        Arguments:
            batch_inputs {dict} -- batch inputs
        """

        seq_wordpiece_tokens_reprs, _ = self.bert_encoder(
            batch_inputs['tokens_id'])
        batch_inputs['seq_wordpiece_tokens_reprs'] = seq_wordpiece_tokens_reprs
        batch_inputs['seq_tokens_reprs'] = batched_index_select(
            seq_wordpiece_tokens_reprs, batch_inputs['tokens_index'])

        all_ents = []
        all_ents_labels = []
        all_seq_tokens_reprs = []
        for ent_spans, ent_labels, seq_tokens_reprs in zip(
                batch_inputs['ent_spans'], batch_inputs['ent_labels'],
                batch_inputs['seq_tokens_reprs']):
            all_ents.extend([span[0], span[1] - 1] for span in ent_spans)
            all_ents_labels.extend(ent_label for ent_label in ent_labels)
            all_seq_tokens_reprs.extend(seq_tokens_reprs
                                        for _ in range(len(ent_spans)))

        if self.ent_batch_size > 0:
            all_entity_typing_loss = []
            for idx in range(0, len(all_ents), self.ent_batch_size):
                batch_ents_tensor = torch.LongTensor(
                    all_ents[idx:idx + self.ent_batch_size]).unsqueeze(1)
                if self.device > -1:
                    batch_ents_tensor = batch_ents_tensor.cuda(
                        device=self.device, non_blocking=True)

                batch_seq_tokens_reprs = torch.stack(
                    all_seq_tokens_reprs[idx:idx + self.ent_batch_size])

                batch_ents_feature = self.ent2hidden(
                    self.entity_span_extractor(batch_seq_tokens_reprs,
                                               batch_ents_tensor).squeeze(1))

                batch_ents_labels = torch.LongTensor(
                    all_ents_labels[idx:idx + self.ent_batch_size])
                if self.device > -1:
                    batch_ents_labels = batch_ents_labels.cuda(
                        device=self.device, non_blocking=True)

                entity_typing_outputs = self.entity_pretrained_decoder(
                    batch_ents_feature, batch_ents_labels)

                all_entity_typing_loss.append(entity_typing_outputs['loss'])

            if len(all_entity_typing_loss) != 0:
                entity_typing_loss = sum(all_entity_typing_loss) / len(
                    all_entity_typing_loss)
            else:
                zero_loss = torch.Tensor([0])
                zero_loss.requires_grad = True
                if self.device > -1:
                    zero_loss = zero_loss.cuda(device=self.device,
                                               non_blocking=True)
                entity_typing_loss = zero_loss
        else:
            all_ents_tensor = torch.LongTensor(all_ents).unsqueeze(1)
            if self.device > -1:
                all_ents_tensor = all_ents_tensor.cuda(device=self.device,
                                                       non_blocking=True)
            all_seq_tokens_reprs = torch.stack(all_seq_tokens_reprs)
            all_ents_feature = self.entity_span_extractor(
                all_seq_tokens_reprs, all_ents_tensor).squeeze(1)

            all_ents_feature = self.ent2hidden(all_ents_feature)

            all_ents_labels = torch.LongTensor(all_ents_labels)
            if self.device > -1:
                all_ents_labels = all_ents_labels.cuda(device=self.device,
                                                       non_blocking=True)

            entity_typing_outputs = self.entity_pretrained_decoder(
                all_ents_feature, all_ents_labels)

            entity_typing_loss = entity_typing_outputs['loss']

        outputs = {}
        outputs['loss'] = entity_typing_loss

        return outputs
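
    # `batched_index_select` is defined elsewhere in this codebase; a common
    # implementation (an assumption, shown only for orientation) gathers one
    # row per word-level index from the wordpiece-level tensor:
    #
    #     # reprs: (B, T_wp, H), index: (B, T_word)  ->  (B, T_word, H)
    #     idx = index.unsqueeze(-1).expand(-1, -1, reprs.size(-1))
    #     out = reprs.gather(1, idx)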

    def masked_entity_token_prediction(self, batch_inputs):
        """This function pretrains masked entity tokens prediction task.
        
        Arguments:
            batch_inputs {dict} -- batch inputs
        """

        masked_seq_wordpiece_tokens_reprs, _ = self.bert_encoder(
            batch_inputs['tokens_id'])
        masked_seq_wordpiece_tokens_reprs = self.masked_token_mlp(
            masked_seq_wordpiece_tokens_reprs)

        if batch_inputs['masked_index'].sum() != 0:
            masked_entity_token_outputs = self.seq_decoder(
                seq_inputs=masked_seq_wordpiece_tokens_reprs,
                seq_mask=batch_inputs['masked_index'],
                seq_labels=batch_inputs['tokens_label'])
            masked_entity_token_loss = masked_entity_token_outputs['loss']
        else:
            zero_loss = torch.Tensor([0])
            zero_loss.requires_grad = True
            if self.device > -1:
                zero_loss = zero_loss.cuda(device=self.device,
                                           non_blocking=True)
            masked_entity_token_loss = zero_loss

        outputs = {}
        outputs['loss'] = masked_entity_token_loss
        return outputs

    def permutation_prediction(self, batch_inputs):
        """This function pretrains entity mention permutaiton prediction task.
        
        Arguments:
            batch_inputs {dict} -- batch inputs
        """

        all_permutation_feature = self.get_entity_mention_feature(
            batch_inputs['tokens_id'], batch_inputs['tokens_index'],
            batch_inputs['ent_mention'], batch_inputs['tokens_index_lens'])
        permutation_outputs = self.permutation_decoder(
            all_permutation_feature, batch_inputs['ent_mention_label'])
        permutation_loss = permutation_outputs['loss']

        outputs = {}
        outputs['loss'] = permutation_loss
        return outputs
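
    # Why `permutation_decoder` uses label_size=120 (see __init__): a mention
    # instance is split into the five segments [L, E1, M, E2, R], and
    # 5! = 120 orderings of those segments form the label space. A label map
    # could be built as below (an assumption; the actual mapping code is not
    # shown in this file):
    #
    #     import itertools
    #     perm2label = {p: i for i, p in
    #                   enumerate(itertools.permutations(range(5)))}
    #     assert len(perm2label) == 120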

    def confused_context_prediction(self, batch_inputs):
        """This function pretrains confused context prediction task.
        
        Arguments:
            batch_inputs {dict} -- batch inputs
        """

        all_confused_context_feature = self.get_entity_mention_feature(
            batch_inputs['confused_tokens_id'],
            batch_inputs['confused_tokens_index'],
            batch_inputs['confused_ent_mention'],
            batch_inputs['confused_tokens_index_lens'])
        all_truth_context_feature = self.get_entity_mention_feature(
            batch_inputs['origin_tokens_id'],
            batch_inputs['origin_tokens_index'],
            batch_inputs['origin_ent_mention'],
            batch_inputs['origin_tokens_index_lens'])
        confused_context_score = self.confused_context_decoder(
            all_confused_context_feature)
        truth_context_score = self.confused_context_decoder(
            all_truth_context_feature)
        rank_loss = torch.mean(
            torch.relu(5.0 - torch.abs(confused_context_score -
                                       truth_context_score)))

        outputs = {}
        outputs['loss'] = rank_loss
        return outputs
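
    # The objective above is a symmetric hinge with margin m = 5.0:
    #     loss = mean(relu(m - |s_confused - s_truth|))
    # It only pushes the two context scores at least m apart rather than
    # enforcing a fixed ordering. Numeric check (illustrative): scores 1.0
    # and 3.5 give relu(5.0 - 2.5) = 2.5.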

    def get_entity_mention_feature(self, batch_wordpiece_tokens,
                                   batch_wordpiece_tokens_index,
                                   batch_entity_mentions, batch_seq_lens):
        """This function extracts entity mention feature using CNN.
        
        Arguments:
            batch_wordpiece_tokens {tensor} -- batch wordpiece tokens
            batch_wordpiece_tokens_index {tensor} -- batch wordpiece tokens index
            batch_entity_mentions {list} -- batch entity mentions
            batch_seq_lens {list} -- batch sequence length list
        
        Returns:
            tensor -- entity mention feature
        """

        batch_seq_reprs, _ = self.bert_encoder(batch_wordpiece_tokens)
        batch_seq_reprs = batched_index_select(batch_seq_reprs,
                                               batch_wordpiece_tokens_index)

        entity_spans = []
        context_spans = []
        for entity_mention, seq_len in zip(batch_entity_mentions,
                                           batch_seq_lens):
            entity_spans.append([[entity_mention[0][0], entity_mention[0][1]],
                                 [entity_mention[1][0], entity_mention[1][1]]])
            context_spans.append([[0, entity_mention[0][0]],
                                  [entity_mention[0][1], entity_mention[1][0]],
                                  [entity_mention[1][1], seq_len]])

        entity_spans_tensor = torch.LongTensor(entity_spans)
        if self.device > -1:
            entity_spans_tensor = entity_spans_tensor.cuda(device=self.device,
                                                           non_blocking=True)

        context_spans_tensor = torch.LongTensor(context_spans)
        if self.device > -1:
            context_spans_tensor = context_spans_tensor.cuda(
                device=self.device, non_blocking=True)

        entity_feature = self.entity_span_extractor(batch_seq_reprs,
                                                    entity_spans_tensor)
        context_feature = self.context_span_extractor(batch_seq_reprs,
                                                      context_spans_tensor)

        entity_feature = self.ent2hidden(entity_feature)
        context_feature = self.context2hidden(context_feature)

        entity_mention_feature = torch.cat([
            context_feature[:, 0, :], entity_feature[:, 0, :],
            context_feature[:, 1, :], entity_feature[:, 1, :],
            context_feature[:, 2, :]
        ], dim=-1).view(len(batch_wordpiece_tokens), -1)

        entity_mention_feature = self.mlp(entity_mention_feature)

        return entity_mention_feature
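
To make the interleaving at the end of `get_entity_mention_feature` concrete, here is a minimal standalone sketch (batch size and dimensions are illustrative assumptions): entity features of shape (B, 2, d_e) and context features of shape (B, 3, d_c) are woven into [L, E1, M, E2, R] order before the final MLP.

import torch

B, d_e, d_c = 4, 200, 150
entity_feature = torch.randn(B, 2, d_e)    # rows: E1, E2
context_feature = torch.randn(B, 3, d_c)   # rows: L, M, R

feature = torch.cat([
    context_feature[:, 0, :], entity_feature[:, 0, :],
    context_feature[:, 1, :], entity_feature[:, 1, :],
    context_feature[:, 2, :]
], dim=-1)
assert feature.shape == (B, 2 * d_e + 3 * d_c)   # (4, 850)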