def __init__(self, cfg, vocab):
    """Build the BERT encoder used by `BertEmbedModel`.

    Arguments:
        cfg {dict} -- config parameters; this constructor reads the
            attributes `bert_model_name`, `fine_tune`, `bert_output_size`
            and `bert_dropout`
        vocab {Vocabulary} -- vocabulary (accepted but not used here)
    """

    super().__init__()

    gelu_activation = nn.GELU()
    self.activation = gelu_activation

    # The same activation module instance is handed to the encoder.
    encoder = BertEncoder(bert_model_name=cfg.bert_model_name,
                          trainable=cfg.fine_tune,
                          output_size=cfg.bert_output_size,
                          activation=gelu_activation,
                          dropout=cfg.bert_dropout)
    self.bert_encoder = encoder

    # Cached so that the output width can be reported without touching
    # the encoder again.
    self.encoder_output_size = encoder.get_output_dims()
class BertEmbedModel(nn.Module):
    """Embedding layer backed by a BERT encoder.

    `forward` does not return anything; it stores the encoder outputs in the
    batch dictionary it receives.
    """

    def __init__(self, cfg, vocab):
        """Build the BERT encoder used by `BertEmbedModel`.

        Arguments:
            cfg {dict} -- config parameters; this constructor reads the
                attributes `bert_model_name`, `fine_tune`,
                `bert_output_size` and `bert_dropout`
            vocab {Vocabulary} -- vocabulary (accepted but not used here)
        """

        super().__init__()

        gelu_activation = nn.GELU()
        self.activation = gelu_activation

        # The same activation module instance is handed to the encoder.
        encoder = BertEncoder(bert_model_name=cfg.bert_model_name,
                              trainable=cfg.fine_tune,
                              output_size=cfg.bert_output_size,
                              activation=gelu_activation,
                              dropout=cfg.bert_dropout)
        self.bert_encoder = encoder
        self.encoder_output_size = encoder.get_output_dims()

    def forward(self, batch_inputs):
        """Encode a batch and write the representations back into it.

        Reads 'wordpiece_tokens', 'wordpiece_tokens_index' and, when present,
        'wordpiece_segment_ids'. Adds the keys 'seq_encoder_reprs' and
        'seq_cls_repr' to `batch_inputs`.

        Arguments:
            batch_inputs {dict} -- batch input data (modified in place)
        """

        # Segment ids are forwarded to the encoder only when the batch
        # actually carries them.
        encoder_args = [batch_inputs['wordpiece_tokens']]
        if 'wordpiece_segment_ids' in batch_inputs:
            encoder_args.append(batch_inputs['wordpiece_segment_ids'])

        wordpiece_reprs, cls_repr = self.bert_encoder(*encoder_args)

        # Select the wordpiece representations addressed by the index tensor.
        token_reprs = batched_index_select(
            wordpiece_reprs, batch_inputs['wordpiece_tokens_index'])

        batch_inputs['seq_encoder_reprs'] = token_reprs
        batch_inputs['seq_cls_repr'] = cls_repr

    def get_hidden_size(self):
        """Return the embedding dimensions.

        Returns:
            int -- output size reported by the BERT encoder at construction
        """

        return self.encoder_output_size
def __init__(self, cfg, momentum=False):
    """Construct the `PretrainedSpanEncoder` components.

    Sub-modules are created in a fixed order; keep that order when editing,
    since it determines parameter registration order and the order in which
    random initialisation is drawn.

    Arguments:
        cfg {dict} -- config parameters; attributes read here are
            `ent_output_size`, `span_batch_size`, `position_embedding_dims`,
            `att_size`, `device`, `bert_model_name`, `fine_tune`,
            `bert_output_size`, `entity_cnn_output_channels`,
            `entity_cnn_kernel_sizes` and `dropout`

    Keyword Arguments:
        momentum {bool} -- whether this encoder is momentum encoder
            (default: {False})
    """

    super().__init__()
    self.ent_output_size = cfg.ent_output_size
    self.span_batch_size = cfg.span_batch_size
    self.position_embedding_dims = cfg.position_embedding_dims
    self.att_size = cfg.att_size
    self.momentum = momentum
    self.activation = nn.GELU()
    # A value greater than -1 is used as a CUDA device index elsewhere.
    self.device = cfg.device

    self.bert_encoder = BertEncoder(bert_model_name=cfg.bert_model_name,
                                    trainable=cfg.fine_tune,
                                    output_size=cfg.bert_output_size,
                                    activation=self.activation)

    self.entity_span_extractor = CNNSpanExtractor(
        input_size=self.bert_encoder.get_output_dims(),
        num_filters=cfg.entity_cnn_output_channels,
        ngram_filter_sizes=cfg.entity_cnn_kernel_sizes,
        dropout=cfg.dropout)

    if self.ent_output_size > 0:
        self.ent2hidden = BertLinear(
            input_size=self.entity_span_extractor.get_output_dims(),
            output_size=self.ent_output_size,
            activation=self.activation,
            dropout=cfg.dropout)
    else:
        # No projection requested: keep the extractor width and pass the
        # features through unchanged. `nn.Identity` is used instead of a
        # lambda so that the module stays picklable; it has no parameters,
        # so state dicts are unaffected.
        self.ent_output_size = self.entity_span_extractor.get_output_dims()
        self.ent2hidden = nn.Identity()

    self.entity_span_mlp = BertLinear(input_size=self.ent_output_size,
                                      output_size=self.ent_output_size,
                                      activation=self.activation,
                                      dropout=cfg.dropout)
    self.entity_span_decoder = VanillaSoftmaxDecoder(
        hidden_size=self.ent_output_size, label_size=6)

    # 150 position ids, each embedded into 200 dimensions; the `+ 200`
    # below matches this embedding width.
    self.global_position_embedding = nn.Embedding(150, 200)
    self.global_position_embedding.weight.data.normal_(mean=0.0, std=0.02)
    self.masked_token_mlp = BertLinear(
        input_size=self.bert_encoder.get_output_dims() + 200,
        output_size=self.bert_encoder.get_output_dims(),
        activation=self.activation,
        dropout=cfg.dropout)
    # NOTE(review): 28996 is hard-coded as the decoder output size --
    # presumably the wordpiece vocabulary size of the BERT checkpoint in
    # use; confirm it matches `cfg.bert_model_name`.
    self.masked_token_decoder = nn.Linear(
        self.bert_encoder.get_output_dims(), 28996, bias=False)
    self.masked_token_decoder.weight.data.normal_(mean=0.0, std=0.02)
    # The bias is kept as a separate parameter rather than inside the
    # linear layer.
    self.masked_token_decoder_bias = nn.Parameter(torch.zeros(28996))

    # Embedding table with 7 entries for position ids.
    self.position_embedding = nn.Embedding(7, self.position_embedding_dims)
    self.position_embedding.weight.data.normal_(mean=0.0, std=0.02)
    self.attention_encoder = PosAwareAttEncoder(
        self.ent_output_size,
        self.bert_encoder.get_output_dims(),
        2 * self.position_embedding_dims,
        self.att_size,
        activation=self.activation,
        dropout=cfg.dropout)

    self.mlp_head1 = BertLinear(self.ent_output_size,
                                self.bert_encoder.get_output_dims(),
                                activation=self.activation,
                                dropout=cfg.dropout)
    self.mlp_head2 = BertLinear(self.bert_encoder.get_output_dims(),
                                self.bert_encoder.get_output_dims(),
                                activation=self.activation,
                                dropout=cfg.dropout)

    self.masked_token_loss = nn.CrossEntropyLoss()
class PretrainedSpanEncoder(nn.Module):
    """PretrainedSpanEncoder encodes span into vector.

    Besides the span mention representation, `forward` computes a masked
    token loss and a span classification loss unless the instance was
    created as a momentum encoder.
    """

    def __init__(self, cfg, momentum=False):
        """Construct the `PretrainedSpanEncoder` components.

        Sub-modules are created in a fixed order; keep that order when
        editing, since it determines parameter registration order and the
        order in which random initialisation is drawn.

        Arguments:
            cfg {dict} -- config parameters; attributes read here are
                `ent_output_size`, `span_batch_size`,
                `position_embedding_dims`, `att_size`, `device`,
                `bert_model_name`, `fine_tune`, `bert_output_size`,
                `entity_cnn_output_channels`, `entity_cnn_kernel_sizes`
                and `dropout`

        Keyword Arguments:
            momentum {bool} -- whether this encoder is momentum encoder
                (default: {False})
        """

        super().__init__()
        self.ent_output_size = cfg.ent_output_size
        self.span_batch_size = cfg.span_batch_size
        self.position_embedding_dims = cfg.position_embedding_dims
        self.att_size = cfg.att_size
        self.momentum = momentum
        self.activation = nn.GELU()
        # A value greater than -1 is used as a CUDA device index.
        self.device = cfg.device

        self.bert_encoder = BertEncoder(bert_model_name=cfg.bert_model_name,
                                        trainable=cfg.fine_tune,
                                        output_size=cfg.bert_output_size,
                                        activation=self.activation)

        self.entity_span_extractor = CNNSpanExtractor(
            input_size=self.bert_encoder.get_output_dims(),
            num_filters=cfg.entity_cnn_output_channels,
            ngram_filter_sizes=cfg.entity_cnn_kernel_sizes,
            dropout=cfg.dropout)

        if self.ent_output_size > 0:
            self.ent2hidden = BertLinear(
                input_size=self.entity_span_extractor.get_output_dims(),
                output_size=self.ent_output_size,
                activation=self.activation,
                dropout=cfg.dropout)
        else:
            # No projection requested: keep the extractor width and pass
            # the features through unchanged. `nn.Identity` is used instead
            # of a lambda so that the module stays picklable; it has no
            # parameters, so state dicts are unaffected.
            self.ent_output_size = self.entity_span_extractor.get_output_dims()
            self.ent2hidden = nn.Identity()

        self.entity_span_mlp = BertLinear(input_size=self.ent_output_size,
                                          output_size=self.ent_output_size,
                                          activation=self.activation,
                                          dropout=cfg.dropout)
        self.entity_span_decoder = VanillaSoftmaxDecoder(
            hidden_size=self.ent_output_size, label_size=6)

        # 150 position ids, each embedded into 200 dimensions; the `+ 200`
        # below matches this embedding width.
        self.global_position_embedding = nn.Embedding(150, 200)
        self.global_position_embedding.weight.data.normal_(mean=0.0, std=0.02)
        self.masked_token_mlp = BertLinear(
            input_size=self.bert_encoder.get_output_dims() + 200,
            output_size=self.bert_encoder.get_output_dims(),
            activation=self.activation,
            dropout=cfg.dropout)
        # NOTE(review): 28996 is hard-coded as the decoder output size --
        # presumably the wordpiece vocabulary size of the BERT checkpoint
        # in use; confirm it matches `cfg.bert_model_name`.
        self.masked_token_decoder = nn.Linear(
            self.bert_encoder.get_output_dims(), 28996, bias=False)
        self.masked_token_decoder.weight.data.normal_(mean=0.0, std=0.02)
        # The bias is kept as a separate parameter and added in `forward`.
        self.masked_token_decoder_bias = nn.Parameter(torch.zeros(28996))

        # `forward` looks up position ids in the range 0..6.
        self.position_embedding = nn.Embedding(7, self.position_embedding_dims)
        self.position_embedding.weight.data.normal_(mean=0.0, std=0.02)
        self.attention_encoder = PosAwareAttEncoder(
            self.ent_output_size,
            self.bert_encoder.get_output_dims(),
            2 * self.position_embedding_dims,
            self.att_size,
            activation=self.activation,
            dropout=cfg.dropout)

        self.mlp_head1 = BertLinear(self.ent_output_size,
                                    self.bert_encoder.get_output_dims(),
                                    activation=self.activation,
                                    dropout=cfg.dropout)
        self.mlp_head2 = BertLinear(self.bert_encoder.get_output_dims(),
                                    self.bert_encoder.get_output_dims(),
                                    activation=self.activation,
                                    dropout=cfg.dropout)

        self.masked_token_loss = nn.CrossEntropyLoss()

    def _to_device(self, tensor):
        """Return `tensor` on CUDA device `self.device`, or unchanged when `self.device` is not greater than -1."""

        if self.device > -1:
            return tensor.cuda(device=self.device, non_blocking=True)
        return tensor

    def _zero_loss(self):
        """Return a one-element zero tensor that requires grad, placed on the configured device."""

        zero_loss = torch.Tensor([0])
        zero_loss.requires_grad = True
        return self._to_device(zero_loss)

    def _compute_masked_token_loss(self, batch_inputs,
                                   batch_seq_wordpiece_tokens_repr, zero_loss):
        """Return the cross-entropy loss over masked wordpieces, or `zero_loss` when nothing is masked."""

        if not any(len(masked_index) for masked_index in batch_inputs['masked_index']):
            return zero_loss

        masked_wordpiece_tokens_repr = []
        all_masked_label = []
        for masked_index, masked_position, masked_label, seq_repr in zip(
                batch_inputs['masked_index'], batch_inputs['masked_position'],
                batch_inputs['masked_label'], batch_seq_wordpiece_tokens_repr):
            masked_index_tensor = self._to_device(torch.LongTensor(masked_index))
            masked_position_tensor = self._to_device(torch.LongTensor(masked_position))
            # Each masked wordpiece is represented by its encoder output
            # concatenated with a global position embedding.
            masked_wordpiece_tokens_repr.append(
                torch.cat([
                    seq_repr[masked_index_tensor],
                    self.global_position_embedding(masked_position_tensor)
                ], dim=1))
            all_masked_label.extend(masked_label)

        masked_input = torch.cat(masked_wordpiece_tokens_repr, dim=0)
        masked_output = self.masked_token_decoder(
            self.masked_token_mlp(masked_input)) + self.masked_token_decoder_bias
        all_masked_label_tensor = self._to_device(torch.LongTensor(all_masked_label))
        return self.masked_token_loss(masked_output, all_masked_label_tensor)

    def _span_chunk_loss(self, spans, spans_label, seq_tokens_reprs):
        """Return the span decoder loss for one chunk of (span, label, sequence representation) triples."""

        spans_tensor = self._to_device(torch.LongTensor(spans).unsqueeze(1))
        stacked_reprs = torch.stack(seq_tokens_reprs)
        spans_feature = self.ent2hidden(
            self.entity_span_extractor(stacked_reprs, spans_tensor).squeeze(1))
        labels_tensor = self._to_device(torch.LongTensor(spans_label))
        span_outputs = self.entity_span_decoder(
            self.entity_span_mlp(spans_feature), labels_tensor)
        return span_outputs['loss']

    def _compute_span_loss(self, batch_inputs, batch_seq_tokens_repr, zero_loss):
        """Return the span classification loss over all spans in the batch, or `zero_loss` when there are none."""

        all_spans = []
        all_spans_label = []
        all_seq_tokens_reprs = []
        for spans, spans_label, seq_tokens_repr in zip(batch_inputs['spans'],
                                                       batch_inputs['spans_label'],
                                                       batch_seq_tokens_repr):
            all_spans.extend(spans)
            all_spans_label.extend(spans_label)
            # Every span is paired with the representation of the sequence
            # it belongs to.
            all_seq_tokens_reprs.extend([seq_tokens_repr] * len(spans))

        assert len(all_spans) == len(all_seq_tokens_reprs) and len(all_spans) == len(
            all_spans_label)

        if len(all_spans) == 0:
            return zero_loss

        if self.span_batch_size > 0:
            # Process spans in chunks and average the per-chunk losses.
            step = self.span_batch_size
            chunk_losses = [
                self._span_chunk_loss(all_spans[idx:idx + step],
                                      all_spans_label[idx:idx + step],
                                      all_seq_tokens_reprs[idx:idx + step])
                for idx in range(0, len(all_spans), step)
            ]
            return sum(chunk_losses) / len(chunk_losses)

        return self._span_chunk_loss(all_spans, all_spans_label, all_seq_tokens_reprs)

    def forward(self, batch_inputs):
        """Propagate forward.

        Reads the keys 'wordpiece_tokens', 'wordpiece_tokens_index' and
        'span_mention'; when this is not a momentum encoder it also reads
        'masked_index', 'masked_position', 'masked_label', 'spans' and
        'spans_label'.

        Arguments:
            batch_inputs {dict} -- batch inputs

        Returns:
            dict -- results; always contains 'span_mention_repr', and
                additionally 'masked_token_loss' and 'span_loss' when
                `self.momentum` is false
        """

        batch_seq_wordpiece_tokens_repr, batch_seq_cls_repr = self.bert_encoder(
            batch_inputs['wordpiece_tokens'])
        batch_seq_tokens_repr = batched_index_select(
            batch_seq_wordpiece_tokens_repr, batch_inputs['wordpiece_tokens_index'])

        results = {}

        entity_feature = self.entity_span_extractor(batch_seq_tokens_repr,
                                                    batch_inputs['span_mention'])
        entity_feature = self.ent2hidden(entity_feature)

        # Fixed position ids, shifted by 3 so they index the 7-entry
        # position embedding table (subject: 2..6, object: 0..4).
        subj_pos = self._to_device(torch.LongTensor([-1, 0, 1, 2, 3]) + 3)
        obj_pos = self._to_device(torch.LongTensor([-3, -2, -1, 0, 1]) + 3)

        subj_pos_emb = self.position_embedding(subj_pos)
        obj_pos_emb = self.position_embedding(obj_pos)
        # The same position feature is repeated for every item in the batch.
        pos_emb = torch.cat([subj_pos_emb, obj_pos_emb], dim=1).unsqueeze(0).repeat(
            batch_inputs['wordpiece_tokens_index'].size()[0], 1, 1)

        span_mention_attention_repr = self.attention_encoder(inputs=entity_feature,
                                                             query=batch_seq_cls_repr,
                                                             feature=pos_emb)
        results['span_mention_repr'] = self.mlp_head2(
            self.mlp_head1(span_mention_attention_repr))

        # A momentum encoder only produces representations, no losses.
        if self.momentum:
            return results

        # One zero tensor is shared by both losses when neither has data.
        zero_loss = self._zero_loss()

        results['masked_token_loss'] = self._compute_masked_token_loss(
            batch_inputs, batch_seq_wordpiece_tokens_repr, zero_loss)
        results['span_loss'] = self._compute_span_loss(
            batch_inputs, batch_seq_tokens_repr, zero_loss)

        return results
def __init__(self, cfg):
    """Construct the `JointREPretrainedModel` components.

    Sub-modules are created in a fixed order; keep that order when editing,
    since it determines parameter registration order and the order in which
    random initialisation is drawn.

    Arguments:
        cfg {dict} -- config parameters; attributes read here are
            `ent_output_size`, `context_output_size`,
            `ent_mention_output_size`, `ent_batch_size`,
            `permutation_batch_size`, `permutation_samples_num`,
            `confused_batch_size`, `confused_samples_num`, `device`,
            `bert_model_name`, `fine_tune`, `bert_output_size`,
            `entity_cnn_output_channels`, `entity_cnn_kernel_sizes`,
            `context_cnn_output_channels`, `context_cnn_kernel_sizes`
            and `dropout`
    """

    super().__init__()
    self.ent_output_size = cfg.ent_output_size
    self.context_output_size = cfg.context_output_size
    self.output_size = cfg.ent_mention_output_size
    self.ent_batch_size = cfg.ent_batch_size
    self.permutation_batch_size = cfg.permutation_batch_size
    self.permutation_samples_num = cfg.permutation_samples_num
    self.confused_batch_size = cfg.confused_batch_size
    self.confused_samples_num = cfg.confused_samples_num
    self.activation = gelu
    # A value greater than -1 is used as a CUDA device index.
    self.device = cfg.device

    self.bert_encoder = BertEncoder(bert_model_name=cfg.bert_model_name,
                                    trainable=cfg.fine_tune,
                                    output_size=cfg.bert_output_size,
                                    activation=self.activation)

    self.entity_span_extractor = CNNSpanExtractor(
        input_size=self.bert_encoder.get_output_dims(),
        num_filters=cfg.entity_cnn_output_channels,
        ngram_filter_sizes=cfg.entity_cnn_kernel_sizes,
        dropout=cfg.dropout)

    # For each of the three optional projections below, a non-positive
    # configured size means "no projection": the input width is kept and
    # an `nn.Identity` is installed. `nn.Identity` replaces a lambda so
    # that the module stays picklable; it has no parameters, so state
    # dicts are unaffected.
    if self.ent_output_size > 0:
        self.ent2hidden = BertLinear(
            input_size=self.entity_span_extractor.get_output_dims(),
            output_size=self.ent_output_size,
            activation=self.activation,
            dropout=cfg.dropout)
    else:
        self.ent_output_size = self.entity_span_extractor.get_output_dims()
        self.ent2hidden = nn.Identity()

    self.context_span_extractor = CNNSpanExtractor(
        input_size=self.bert_encoder.get_output_dims(),
        num_filters=cfg.context_cnn_output_channels,
        ngram_filter_sizes=cfg.context_cnn_kernel_sizes,
        dropout=cfg.dropout)

    if self.context_output_size > 0:
        self.context2hidden = BertLinear(
            input_size=self.context_span_extractor.get_output_dims(),
            output_size=self.context_output_size,
            activation=self.activation,
            dropout=cfg.dropout)
    else:
        self.context_output_size = self.context_span_extractor.get_output_dims()
        self.context2hidden = nn.Identity()

    # An entity mention feature concatenates 2 entity features and
    # 3 context features.
    if self.output_size > 0:
        self.mlp = BertLinear(
            input_size=2 * self.ent_output_size + 3 * self.context_output_size,
            output_size=self.output_size,
            activation=self.activation,
            dropout=cfg.dropout)
    else:
        self.output_size = 2 * self.ent_output_size + 3 * self.context_output_size
        self.mlp = nn.Identity()

    self.entity_pretrained_decoder = VanillaSoftmaxDecoder(
        hidden_size=self.ent_output_size, label_size=18)

    self.masked_token_mlp = BertLinear(
        input_size=self.bert_encoder.get_output_dims(),
        output_size=self.bert_encoder.get_output_dims(),
        activation=self.activation)
    # The vocabulary size is taken from the first dimension of BERT's
    # word embedding matrix.
    self.token_vocab_size = (
        self.bert_encoder.bert_model.embeddings.word_embeddings.weight.size()[0])
    self.masked_token_decoder = nn.Linear(
        self.bert_encoder.get_output_dims(), self.token_vocab_size, bias=False)
    self.masked_token_decoder.weight.data.normal_(mean=0.0, std=0.02)
    # The bias is kept as a separate parameter rather than inside the
    # linear layer.
    self.masked_token_decoder_bias = nn.Parameter(
        torch.zeros(self.token_vocab_size))
    # NOTE(review): `clone_weights` is defined elsewhere; presumably it
    # copies or ties the decoder weights to BERT's word embeddings --
    # confirm against its definition.
    clone_weights(self.masked_token_decoder,
                  self.bert_encoder.bert_model.embeddings.word_embeddings)
    self.masked_token_loss = nn.CrossEntropyLoss()

    self.permutation_decoder = VanillaSoftmaxDecoder(
        hidden_size=self.output_size, label_size=120)

    self.confused_context_decoder = nn.Linear(self.output_size, 1)
    self.confused_context_decoder.weight.data.normal_(mean=0.0, std=0.02)
    self.confused_context_decoder.bias.data.zero_()

    # Stored as a plain attribute (not a registered buffer), so it is moved
    # to the CUDA device explicitly here.
    self.entity_mention_index_tensor = torch.LongTensor([2, 0, 3, 1, 4])
    if self.device > -1:
        self.entity_mention_index_tensor = self.entity_mention_index_tensor.cuda(
            device=self.device, non_blocking=True)
class JointREPretrainedModel(nn.Module):
    """This class contains entity typing, masked token prediction,
    entity mention permutation prediction, confused entity mention
    context rank loss, four pretrained tasks in total.
    """

    def __init__(self, cfg):
        """Construct the `JointREPretrainedModel` components.

        Sub-modules are created in a fixed order; keep that order when
        editing, since it determines parameter registration order and the
        order in which random initialisation is drawn.

        Arguments:
            cfg {dict} -- config parameters; attributes read here are
                `ent_output_size`, `context_output_size`,
                `ent_mention_output_size`, `ent_batch_size`,
                `permutation_batch_size`, `permutation_samples_num`,
                `confused_batch_size`, `confused_samples_num`, `device`,
                `bert_model_name`, `fine_tune`, `bert_output_size`,
                `entity_cnn_output_channels`, `entity_cnn_kernel_sizes`,
                `context_cnn_output_channels`, `context_cnn_kernel_sizes`
                and `dropout`
        """

        super().__init__()
        self.ent_output_size = cfg.ent_output_size
        self.context_output_size = cfg.context_output_size
        self.output_size = cfg.ent_mention_output_size
        self.ent_batch_size = cfg.ent_batch_size
        self.permutation_batch_size = cfg.permutation_batch_size
        self.permutation_samples_num = cfg.permutation_samples_num
        self.confused_batch_size = cfg.confused_batch_size
        self.confused_samples_num = cfg.confused_samples_num
        self.activation = gelu
        # A value greater than -1 is used as a CUDA device index.
        self.device = cfg.device

        self.bert_encoder = BertEncoder(bert_model_name=cfg.bert_model_name,
                                        trainable=cfg.fine_tune,
                                        output_size=cfg.bert_output_size,
                                        activation=self.activation)

        self.entity_span_extractor = CNNSpanExtractor(
            input_size=self.bert_encoder.get_output_dims(),
            num_filters=cfg.entity_cnn_output_channels,
            ngram_filter_sizes=cfg.entity_cnn_kernel_sizes,
            dropout=cfg.dropout)

        # For each of the three optional projections below, a non-positive
        # configured size means "no projection": the input width is kept
        # and an `nn.Identity` is installed. `nn.Identity` replaces a
        # lambda so that the module stays picklable; it has no parameters,
        # so state dicts are unaffected.
        if self.ent_output_size > 0:
            self.ent2hidden = BertLinear(
                input_size=self.entity_span_extractor.get_output_dims(),
                output_size=self.ent_output_size,
                activation=self.activation,
                dropout=cfg.dropout)
        else:
            self.ent_output_size = self.entity_span_extractor.get_output_dims()
            self.ent2hidden = nn.Identity()

        self.context_span_extractor = CNNSpanExtractor(
            input_size=self.bert_encoder.get_output_dims(),
            num_filters=cfg.context_cnn_output_channels,
            ngram_filter_sizes=cfg.context_cnn_kernel_sizes,
            dropout=cfg.dropout)

        if self.context_output_size > 0:
            self.context2hidden = BertLinear(
                input_size=self.context_span_extractor.get_output_dims(),
                output_size=self.context_output_size,
                activation=self.activation,
                dropout=cfg.dropout)
        else:
            self.context_output_size = self.context_span_extractor.get_output_dims()
            self.context2hidden = nn.Identity()

        # An entity mention feature concatenates 2 entity features and
        # 3 context features (see `get_entity_mention_feature`).
        if self.output_size > 0:
            self.mlp = BertLinear(
                input_size=2 * self.ent_output_size + 3 * self.context_output_size,
                output_size=self.output_size,
                activation=self.activation,
                dropout=cfg.dropout)
        else:
            self.output_size = 2 * self.ent_output_size + 3 * self.context_output_size
            self.mlp = nn.Identity()

        self.entity_pretrained_decoder = VanillaSoftmaxDecoder(
            hidden_size=self.ent_output_size, label_size=18)

        self.masked_token_mlp = BertLinear(
            input_size=self.bert_encoder.get_output_dims(),
            output_size=self.bert_encoder.get_output_dims(),
            activation=self.activation)
        # The vocabulary size is taken from the first dimension of BERT's
        # word embedding matrix.
        self.token_vocab_size = (
            self.bert_encoder.bert_model.embeddings.word_embeddings.weight.size()[0])
        self.masked_token_decoder = nn.Linear(
            self.bert_encoder.get_output_dims(), self.token_vocab_size, bias=False)
        self.masked_token_decoder.weight.data.normal_(mean=0.0, std=0.02)
        # The bias is kept as a separate parameter and added in
        # `seq_decoder`.
        self.masked_token_decoder_bias = nn.Parameter(
            torch.zeros(self.token_vocab_size))
        # NOTE(review): `clone_weights` is defined elsewhere; presumably it
        # copies or ties the decoder weights to BERT's word embeddings --
        # confirm against its definition.
        clone_weights(self.masked_token_decoder,
                      self.bert_encoder.bert_model.embeddings.word_embeddings)
        self.masked_token_loss = nn.CrossEntropyLoss()

        self.permutation_decoder = VanillaSoftmaxDecoder(
            hidden_size=self.output_size, label_size=120)

        self.confused_context_decoder = nn.Linear(self.output_size, 1)
        self.confused_context_decoder.weight.data.normal_(mean=0.0, std=0.02)
        self.confused_context_decoder.bias.data.zero_()

        # Stored as a plain attribute (not a registered buffer), so it is
        # moved to the CUDA device explicitly here.
        self.entity_mention_index_tensor = self._to_device(
            torch.LongTensor([2, 0, 3, 1, 4]))

    def _to_device(self, tensor):
        """Return `tensor` on CUDA device `self.device`, or unchanged when `self.device` is not greater than -1."""

        if self.device > -1:
            return tensor.cuda(device=self.device, non_blocking=True)
        return tensor

    def _zero_loss(self):
        """Return a one-element zero tensor that requires grad, placed on the configured device."""

        zero_loss = torch.Tensor([0])
        zero_loss.requires_grad = True
        return self._to_device(zero_loss)

    def _entity_typing_chunk_loss(self, ents, ents_labels, seq_tokens_reprs):
        """Return the entity typing loss for one chunk of (span, label, sequence representation) triples."""

        ents_tensor = self._to_device(torch.LongTensor(ents).unsqueeze(1))
        stacked_reprs = torch.stack(seq_tokens_reprs)
        ents_feature = self.ent2hidden(
            self.entity_span_extractor(stacked_reprs, ents_tensor).squeeze(1))
        labels_tensor = self._to_device(torch.LongTensor(ents_labels))
        entity_typing_outputs = self.entity_pretrained_decoder(
            ents_feature, labels_tensor)
        return entity_typing_outputs['loss']

    def forward(self, batch_inputs, pretrain_task=''):
        """Dispatch the batch to the requested pretraining task.

        Arguments:
            batch_inputs {dict} -- batch inputs

        Keyword Arguments:
            pretrain_task {str} -- pretraining task, one of
                'masked_entity_typing', 'masked_entity_token_prediction',
                'entity_mention_permutation' or 'confused_context'
                (default: {''})

        Returns:
            dict -- results of the selected task, containing 'loss'

        Raises:
            ValueError -- if `pretrain_task` is not a known task name
                (including the default empty string)
        """

        if pretrain_task == 'masked_entity_typing':
            return self.masked_entity_typing(batch_inputs)
        if pretrain_task == 'masked_entity_token_prediction':
            return self.masked_entity_token_prediction(batch_inputs)
        if pretrain_task == 'entity_mention_permutation':
            return self.permutation_prediction(batch_inputs)
        if pretrain_task == 'confused_context':
            return self.confused_context_prediction(batch_inputs)

        # Fail loudly instead of silently returning None, which would only
        # surface later as an unrelated error at the call site.
        raise ValueError(
            'Unknown pretrain_task: {!r}'.format(pretrain_task))

    def seq_decoder(self, seq_inputs, seq_mask=None, seq_labels=None):
        """Decode token representations into vocabulary predictions.

        Arguments:
            seq_inputs {tensor} -- token representations; the vocabulary
                logits are produced along dimension 2

        Keyword Arguments:
            seq_mask {tensor} -- positions equal to 1 contribute to the
                loss; when None all positions contribute (default: {None})
            seq_labels {tensor} -- target token ids; when None no loss is
                computed (default: {None})

        Returns:
            dict -- 'predict' holds the argmax token ids; 'loss' is present
                only when `seq_labels` is given
        """

        results = {}
        seq_outputs = self.masked_token_decoder(
            seq_inputs) + self.masked_token_decoder_bias
        seq_log_probs = F.log_softmax(seq_outputs, dim=2)
        results['predict'] = seq_log_probs.argmax(dim=2)

        if seq_labels is None:
            return results

        flat_outputs = seq_outputs.view(-1, self.token_vocab_size)
        flat_labels = seq_labels.view(-1)
        if seq_mask is not None:
            # Only positions whose mask value is 1 take part in the loss.
            active_loss = seq_mask.view(-1) == 1
            flat_outputs = flat_outputs[active_loss]
            flat_labels = flat_labels[active_loss]
        results['loss'] = self.masked_token_loss(flat_outputs, flat_labels)

        return results

    def masked_entity_typing(self, batch_inputs):
        """This function pretrains masked entity typing task.

        Reads 'tokens_id', 'tokens_index', 'ent_spans' and 'ent_labels';
        also stores 'seq_wordpiece_tokens_reprs' and 'seq_tokens_reprs'
        into `batch_inputs`.

        Arguments:
            batch_inputs {dict} -- batch inputs (modified in place)

        Returns:
            dict -- 'loss' holds the entity typing loss, or a zero loss
                when the batch contains no entity spans
        """

        seq_wordpiece_tokens_reprs, _ = self.bert_encoder(
            batch_inputs['tokens_id'])
        batch_inputs['seq_wordpiece_tokens_reprs'] = seq_wordpiece_tokens_reprs
        batch_inputs['seq_tokens_reprs'] = batched_index_select(
            seq_wordpiece_tokens_reprs, batch_inputs['tokens_index'])

        all_ents = []
        all_ents_labels = []
        all_seq_tokens_reprs = []
        for ent_spans, ent_labels, seq_tokens_reprs in zip(
                batch_inputs['ent_spans'], batch_inputs['ent_labels'],
                batch_inputs['seq_tokens_reprs']):
            # The second span coordinate is decremented by one before being
            # handed to the span extractor.
            all_ents.extend([span[0], span[1] - 1] for span in ent_spans)
            all_ents_labels.extend(ent_labels)
            # Every span is paired with the representation of the sequence
            # it belongs to.
            all_seq_tokens_reprs.extend([seq_tokens_reprs] * len(ent_spans))

        # Without any entity span there is nothing to stack or classify
        # (`torch.stack` rejects an empty list), so return a zero loss for
        # both the chunked and the unchunked configuration.
        if not all_ents:
            return {'loss': self._zero_loss()}

        if self.ent_batch_size > 0:
            # Process entities in chunks and average the per-chunk losses.
            step = self.ent_batch_size
            chunk_losses = [
                self._entity_typing_chunk_loss(all_ents[idx:idx + step],
                                               all_ents_labels[idx:idx + step],
                                               all_seq_tokens_reprs[idx:idx + step])
                for idx in range(0, len(all_ents), step)
            ]
            entity_typing_loss = sum(chunk_losses) / len(chunk_losses)
        else:
            entity_typing_loss = self._entity_typing_chunk_loss(
                all_ents, all_ents_labels, all_seq_tokens_reprs)

        return {'loss': entity_typing_loss}

    def masked_entity_token_prediction(self, batch_inputs):
        """This function pretrains masked entity tokens prediction task.

        Reads 'tokens_id', 'masked_index' and 'tokens_label'.

        Arguments:
            batch_inputs {dict} -- batch inputs

        Returns:
            dict -- 'loss' holds the masked token loss, or a zero loss when
                'masked_index' sums to zero
        """

        masked_seq_wordpiece_tokens_reprs, _ = self.bert_encoder(
            batch_inputs['tokens_id'])
        masked_seq_wordpiece_tokens_reprs = self.masked_token_mlp(
            masked_seq_wordpiece_tokens_reprs)

        if batch_inputs['masked_index'].sum() != 0:
            masked_entity_token_outputs = self.seq_decoder(
                seq_inputs=masked_seq_wordpiece_tokens_reprs,
                seq_mask=batch_inputs['masked_index'],
                seq_labels=batch_inputs['tokens_label'])
            masked_entity_token_loss = masked_entity_token_outputs['loss']
        else:
            # Nothing is masked in this batch.
            masked_entity_token_loss = self._zero_loss()

        return {'loss': masked_entity_token_loss}

    def permutation_prediction(self, batch_inputs):
        """This function pretrains entity mention permutation prediction task.

        Reads 'tokens_id', 'tokens_index', 'ent_mention',
        'tokens_index_lens' and 'ent_mention_label'.

        Arguments:
            batch_inputs {dict} -- batch inputs

        Returns:
            dict -- 'loss' holds the permutation decoder loss
        """

        all_permutation_feature = self.get_entity_mention_feature(
            batch_inputs['tokens_id'], batch_inputs['tokens_index'],
            batch_inputs['ent_mention'], batch_inputs['tokens_index_lens'])

        permutation_outputs = self.permutation_decoder(
            all_permutation_feature, batch_inputs['ent_mention_label'])

        return {'loss': permutation_outputs['loss']}

    def confused_context_prediction(self, batch_inputs):
        """This function pretrains confused context prediction task.

        Reads the 'confused_*' and 'origin_*' variants of 'tokens_id',
        'tokens_index', 'ent_mention' and 'tokens_index_lens'.

        Arguments:
            batch_inputs {dict} -- batch inputs

        Returns:
            dict -- 'loss' holds the rank loss
        """

        all_confused_context_feature = self.get_entity_mention_feature(
            batch_inputs['confused_tokens_id'],
            batch_inputs['confused_tokens_index'],
            batch_inputs['confused_ent_mention'],
            batch_inputs['confused_tokens_index_lens'])

        all_truth_context_feature = self.get_entity_mention_feature(
            batch_inputs['origin_tokens_id'],
            batch_inputs['origin_tokens_index'],
            batch_inputs['origin_ent_mention'],
            batch_inputs['origin_tokens_index_lens'])

        confused_context_score = self.confused_context_decoder(
            all_confused_context_feature)
        truth_context_score = self.confused_context_decoder(
            all_truth_context_feature)

        # Hinge with margin 5.0 on the absolute score difference.
        # NOTE(review): the absolute value makes the loss indifferent to
        # which of the two scores is higher -- confirm this is intended.
        rank_loss = torch.mean(
            torch.relu(5.0 -
                       torch.abs(confused_context_score - truth_context_score)))

        return {'loss': rank_loss}

    def get_entity_mention_feature(self, batch_wordpiece_tokens,
                                   batch_wordpiece_tokens_index,
                                   batch_entity_mentions, batch_seq_lens):
        """This function extracts entity mention feature using CNN.

        Arguments:
            batch_wordpiece_tokens {tensor} -- batch wordpiece tokens
            batch_wordpiece_tokens_index {tensor} -- batch wordpiece tokens index
            batch_entity_mentions {list} -- batch entity mentions; each item
                is indexed as two spans, each with two coordinates
            batch_seq_lens {list} -- batch sequence length list

        Returns:
            tensor -- entity mention feature
        """

        batch_seq_reprs, _ = self.bert_encoder(batch_wordpiece_tokens)
        batch_seq_reprs = batched_index_select(batch_seq_reprs,
                                               batch_wordpiece_tokens_index)

        entity_spans = []
        context_spans = []
        for entity_mention, seq_len in zip(batch_entity_mentions, batch_seq_lens):
            first, second = entity_mention[0], entity_mention[1]
            entity_spans.append([[first[0], first[1]], [second[0], second[1]]])
            # Three context spans: before the first entity, between the
            # two entities, and after the second entity.
            context_spans.append([[0, first[0]],
                                  [first[1], second[0]],
                                  [second[1], seq_len]])

        entity_spans_tensor = self._to_device(torch.LongTensor(entity_spans))
        context_spans_tensor = self._to_device(torch.LongTensor(context_spans))

        entity_feature = self.entity_span_extractor(batch_seq_reprs,
                                                    entity_spans_tensor)
        context_feature = self.context_span_extractor(batch_seq_reprs,
                                                      context_spans_tensor)

        entity_feature = self.ent2hidden(entity_feature)
        context_feature = self.context2hidden(context_feature)

        # Interleave context and entity features in order:
        # context, entity, context, entity, context.
        entity_mention_feature = torch.cat([
            context_feature[:, 0, :], entity_feature[:, 0, :],
            context_feature[:, 1, :], entity_feature[:, 1, :],
            context_feature[:, 2, :]
        ], dim=-1).view(len(batch_wordpiece_tokens), -1)

        entity_mention_feature = self.mlp(entity_mention_feature)

        return entity_mention_feature