class BertForMultiLabelSequenceClassification(BertPreTrainedModel): """BERT model for classification. This module is composed of the BERT model with a linear layer on top of the pooled output. """ def __init__(self, config, num_labels=4): super(BertForMultiLabelSequenceClassification, self).__init__(config) self.num_labels = num_labels self.bert = BertModel(config) self.dropout = torch.nn.Dropout(config.hidden_dropout_prob) self.classifier = torch.nn.Linear(config.hidden_size, num_labels) self.apply(self.init_bert_weights) def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, position_ids=None, head_mask=None, output_all_encoded_layers=False) pooled_output = self.dropout(pooled_output) logits = self.classifier(pooled_output) return logits def freeze_bert_encoder(self): for param in self.bert.parameters(): param.requires_grad = False def unfreeze_bert_encoder(self): for param in self.bert.parameters(): param.requires_grad = True
class BertClassifier(BertPreTrainedModel):
    """ BERT multi-label classifier """

    def __init__(self, config):
        super(BertClassifier, self).__init__(config)
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.init_weights()

    def forward(self, input_ids, token_type_ids=None, attention_mask=None):
        """ Forward pass of the BERT classifier
        :param input_ids: the input IDs (bs, seq len)
        :param token_type_ids: (not used) a tensor of zeros indicating which sequence
            in sequence pairs (bs, seq len)
        :param attention_mask: tensor of ones for non-pad tokens and zeros for padding (bs, seq len)
        :return: logits corresponding to each output class (bs, num labels)
        """
        _, pooled_output = self.bert(input_ids=input_ids,
                                     token_type_ids=token_type_ids,
                                     attention_mask=attention_mask)
        pooled_output = self.dropout(pooled_output)
        return self.classifier(pooled_output)

    def freeze_bert_encoder(self):
        """ Prevents further backpropagation through BERT (used when testing) """
        for param in self.bert.parameters():
            param.requires_grad = False

    def unfreeze_bert_encoder(self):
        """ Re-enables backpropagation through BERT (used when training) """
        for param in self.bert.parameters():
            param.requires_grad = True

    def save(self, path: str):
        print('save model parameters to [%s]' % path, file=sys.stderr)
        # Only save the model and not the entire pretrained Bert
        model_to_save = self.module if hasattr(self, 'module') else self
        torch.save(model_to_save.state_dict(), path)

    @staticmethod
    def load(model_path: str, bert_pretrained_path: str, num_labels: int):
        """ Load a fine-tuned model from a file.
        @param model_path (str): path to the saved state dict
        @param bert_pretrained_path (str): path or name of the pretrained BERT model
        @param num_labels (int): number of output classes
        """
        state_dict = torch.load(model_path)
        model = BertClassifier.from_pretrained(bert_pretrained_path,
                                               state_dict=state_dict,
                                               num_labels=num_labels)
        return model
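# Usage sketch (not part of the original class): a minimal, hedged example of
# driving BertClassifier for multi-label training. The checkpoint name, label
# count, and file path are placeholders, and the tuple-style output unpacked in
# forward() assumes an older transformers release (or return_dict=False).
def _bert_classifier_usage_sketch():
    import torch
    import torch.nn as nn

    model = BertClassifier.from_pretrained('bert-base-uncased', num_labels=4)
    model.unfreeze_bert_encoder()

    criterion = nn.BCEWithLogitsLoss()             # multi-label: one sigmoid per label
    input_ids = torch.randint(0, 30522, (2, 16))   # (batch, seq_len) token ids
    attention_mask = torch.ones_like(input_ids)
    targets = torch.randint(0, 2, (2, 4)).float()  # (batch, num_labels) in {0, 1}

    logits = model(input_ids, attention_mask=attention_mask)
    loss = criterion(logits, targets)
    loss.backward()

    model.save('bert_multilabel.bin')
    return BertClassifier.load('bert_multilabel.bin', 'bert-base-uncased', num_labels=4)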
class Teacher(nn.Module):
    def __init__(self, pretrained_model, freeze_bert=True, lstm_dim=-1):
        super(Teacher, self).__init__()
        self.output_dim = len(punctuation_dict)
        self.config = BertConfig.from_pretrained(pretrained_model)
        self.bert_layer = BertModel(self.config)
        # Freeze bert layers
        if freeze_bert:
            for p in self.bert_layer.parameters():
                p.requires_grad = False
        bert_dim = self.config.hidden_size
        if lstm_dim == -1:
            hidden_size = bert_dim
        else:
            hidden_size = lstm_dim
        self.lstm = nn.LSTM(input_size=bert_dim, hidden_size=hidden_size,
                            num_layers=1, bidirectional=True)
        # projection over the bidirectional LSTM output (2 * hidden_size)
        self.linear = nn.Linear(hidden_size * 2, self.output_dim)

    def forward(self, input_ids, attention_mask):
        # if len(x.shape) == 1:
        #     x = x.view(1, x.shape[0])  # add dummy batch for single sample
        # (B, N) -> (B, N, E); request all hidden states so selected layers can
        # be returned alongside the logits (e.g. for distillation)
        out = self.bert_layer(input_ids, attention_mask=attention_mask,
                              output_hidden_states=True)
        x = out.last_hidden_state
        hs = out.hidden_states
        # (B, N, E) -> (N, B, E)
        x = torch.transpose(x, 0, 1)
        x, (_, _) = self.lstm(x)
        # (N, B, E) -> (B, N, E)
        x = torch.transpose(x, 0, 1)
        x = self.linear(x)
        return x, hs[0], hs[6], hs[12]
class BertForLinearSequenceToSequenceProbing(ProteinBertAbstractModel):
    """Bert head for token-level prediction tasks (secondary structure, binding sites)"""

    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config)
        self.classify = LinearSequenceToSequenceClassificationHead(
            config.hidden_size, config.num_labels, ignore_index=-1, dropout=0.5)
        for param in self.bert.parameters():
            param.requires_grad = False
        self.init_weights()

    def forward(self, input_ids, input_mask=None, targets=None):
        outputs = self.bert(input_ids, input_mask=input_mask)
        sequence_output, pooled_output = outputs[:2]
        outputs = self.classify(sequence_output, targets) + outputs[2:]
        return outputs
class RCNNModel(BertPreTrainedModel):
    def __init__(self, config, num_class):
        super(RCNNModel, self).__init__(config)
        self.bert = BertModel(config)
        for param in self.bert.parameters():
            param.requires_grad = True
        self.lstm = nn.LSTM(768, 256, 2, bidirectional=True,
                            batch_first=True, dropout=0.1)
        self.maxpool = nn.MaxPool1d(512)
        self.fc = nn.Linear(512 + 768, num_class)

    def forward(self, x, masks):
        encoder_out, text_cls = self.bert(input_ids=x, attention_mask=masks)
        out, _ = self.lstm(encoder_out)
        out = torch.cat((encoder_out, out), 2)
        out = F.relu(out)
        out = out.permute(0, 2, 1)
        # print(out.size())
        out = self.maxpool(out)
        out = out.squeeze()
        out = self.fc(out)
        return out
class BertForMultiLabelSequenceClassification(BertPreTrainedModel):
    def __init__(self, config, label_emb):
        super(BertForMultiLabelSequenceClassification, self).__init__(config)
        self.num_labels = config.num_labels
        self.hidden_size = config.hidden_size
        self.label_emb = label_emb
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(0.1)
        self.self_attn = SelfAttention(self.hidden_size, self.num_labels)
        self.label_attn = LabelAttention(self.hidden_size, self.num_labels, self.label_emb)
        self.linear = MLinear(self.hidden_size, self.num_labels)
        self.init_weights()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
    ):
        sequence, _ = self.bert(input_ids, attention_mask)
        sequence = self.dropout(sequence)  # [batch, sequence, hidden_size]
        masks = attention_mask != 0        # [batch, sequence]
        masks = torch.unsqueeze(masks, 1)  # [batch, 1, sequence]
        self_attn = self.self_attn(sequence, masks)
        label_attn = self.label_attn(sequence, masks)
        return self.linear(self_attn, label_attn)

    def freeze_bert_encoder(self):
        for param in self.bert.parameters():
            param.requires_grad = False

    def unfreeze_bert_encoder(self):
        for param in self.bert.parameters():
            param.requires_grad = True
class BertForMultiLabelSequenceClassification(BertPreTrainedModel): """BERT model for classification. This module is composed of the BERT model with a linear layer on top of the pooled output. """ def __init__(self, config): super(BertForMultiLabelSequenceClassification, self).__init__(config) self.bert = BertModel(config) self.dropout = torch.nn.Dropout(config.hidden_dropout_prob) self.classifier = torch.nn.Linear(config.hidden_size + 128, config.num_labels) self.init_weights() def forward(self, input_ids, node_vec, tfidf, token_type_ids=None, attention_mask=None, position_ids=None, head_mask=None, labels=None): outputs = self.bert(input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, position_ids=None, head_mask=None)[0] outputs = torch.sum(outputs * tfidf.unsqueeze(2), 1) outputs = torch.cat((outputs, node_vec), 1) outputs = self.dropout(outputs) logits = self.classifier(outputs) return logits def freeze_bert_encoder(self): for param in self.bert.parameters(): param.requires_grad = False def unfreeze_bert_encoder(self): for param in self.bert.parameters(): param.requires_grad = True
class BaseModel(BertPreTrainedModel):
    def __init__(self, config, num_class):
        super(BaseModel, self).__init__(config)
        self.bert = BertModel(config)
        for param in self.bert.parameters():
            param.requires_grad = True
        self.dropout = nn.Dropout(0.5)
        self.fc1 = nn.Linear(768, num_class)
        self.init_weights()

    def forward(self, x, masks):
        encoder_out, text_cls = self.bert(input_ids=x, attention_mask=masks)
        x = self.dropout(text_cls)
        x = self.fc1(x)
        return x
class DPCNNModel(BertPreTrainedModel):
    def __init__(self, config, num_class):
        super(DPCNNModel, self).__init__(config)
        self.bert = BertModel(config)
        for param in self.bert.parameters():
            param.requires_grad = True
        # self.fc = nn.Linear(config.hidden_size, config.num_classes)
        self.conv_region = nn.Conv2d(1, 250, (3, 768), stride=1)
        self.conv = nn.Conv2d(250, 250, (3, 1), stride=1)
        self.max_pool = nn.MaxPool2d(kernel_size=(3, 1), stride=2)
        self.padding1 = nn.ZeroPad2d((0, 0, 1, 1))  # top bottom
        self.padding2 = nn.ZeroPad2d((0, 0, 0, 1))  # bottom
        self.relu = nn.ReLU()
        self.fc = nn.Linear(250, num_class)

    def forward(self, x, masks):
        encoder_out, text_cls = self.bert(input_ids=x, attention_mask=masks)
        x = encoder_out.unsqueeze(1)  # [batch_size, 1, seq_len, embed]
        x = self.conv_region(x)       # [batch_size, 250, seq_len-3+1, 1]
        x = self.padding1(x)          # [batch_size, 250, seq_len, 1]
        x = self.relu(x)
        x = self.conv(x)              # [batch_size, 250, seq_len-3+1, 1]
        x = self.padding1(x)          # [batch_size, 250, seq_len, 1]
        x = self.relu(x)
        x = self.conv(x)              # [batch_size, 250, seq_len-3+1, 1]
        while x.size()[2] > 2:
            x = self._block(x)
        x = x.squeeze()               # [batch_size, num_filters(250)]
        x = self.fc(x)
        return x

    def _block(self, x):
        x = self.padding2(x)
        px = self.max_pool(x)
        x = self.padding1(px)
        x = F.relu(x)
        x = self.conv(x)
        x = self.padding1(x)
        x = F.relu(x)
        x = self.conv(x)
        x = x + px  # short cut
        return x
def add_enc_adapters(bert_model: BertModel, config: AdapterConfig) -> BertModel:
    # Replace specific layer with adapter-added layer
    bert_encoder = bert_model.encoder
    for i in range(len(bert_model.encoder.layer)):
        bert_encoder.layer[i].attention.output = adapt_bert_self_output(config)(
            bert_encoder.layer[i].attention.output)
        bert_encoder.layer[i].output = adapt_bert_output(config)(
            bert_encoder.layer[i].output)

    # Freeze all parameters
    for param in bert_model.parameters():
        param.requires_grad = False
    # Unfreeze trainable parts: layer norms and adapters
    for name, sub_module in bert_model.named_modules():
        if isinstance(sub_module, (Adapter_func, BertLayerNorm)):
            for param_name, param in sub_module.named_parameters():
                param.requires_grad = True
    return bert_model
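# Usage sketch (hedged): wrap a pre-trained encoder with adapters and verify
# that only adapter and LayerNorm parameters remain trainable. The AdapterConfig
# fields shown here are assumptions modelled on typical Houlsby-style adapter
# implementations, not a confirmed API of this code base.
def _adapter_usage_sketch():
    from transformers import BertModel

    bert = BertModel.from_pretrained('bert-base-uncased')
    adapter_config = AdapterConfig(hidden_size=768, adapter_size=64,
                                   adapter_act='relu',
                                   adapter_initializer_range=1e-2)
    bert = add_enc_adapters(bert, adapter_config)

    trainable = sum(p.numel() for p in bert.parameters() if p.requires_grad)
    total = sum(p.numel() for p in bert.parameters())
    print('trainable parameters: %d / %d (%.1f%%)'
          % (trainable, total, 100.0 * trainable / total))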
class BertABSATagger(BertPreTrainedModel):
    def __init__(self, bert_config):
        """
        :param bert_config: configuration for bert model
        """
        super(BertABSATagger, self).__init__(bert_config)
        self.num_labels = bert_config.num_labels
        self.tagger_config = TaggerConfig()
        self.tagger_config.absa_type = bert_config.absa_type.lower()
        if bert_config.tfm_mode == 'finetune':
            # initialized with pre-trained BERT and perform finetuning
            print("Fine-tuning the pre-trained BERT...")
            self.bert = BertModel(bert_config)
        else:
            raise Exception("Invalid transformer mode %s!!!" % bert_config.tfm_mode)
        self.bert_dropout = nn.Dropout(bert_config.hidden_dropout_prob)
        # fix the parameters in BERT and regard it as feature extractor
        if bert_config.fix_tfm:
            # fix the parameters of the (pre-trained or randomly initialized) transformers during fine-tuning
            for p in self.bert.parameters():
                p.requires_grad = False

        self.tagger = None
        if self.tagger_config.absa_type == 'linear':
            # hidden size at the penultimate layer
            penultimate_hidden_size = bert_config.hidden_size
        else:
            self.tagger_dropout = nn.Dropout(self.tagger_config.hidden_dropout_prob)
            if self.tagger_config.absa_type == 'lstm':
                self.tagger = LSTM(input_size=bert_config.hidden_size,
                                   hidden_size=self.tagger_config.hidden_size,
                                   bidirectional=self.tagger_config.bidirectional)
            elif self.tagger_config.absa_type == 'gru':
                self.tagger = GRU(input_size=bert_config.hidden_size,
                                  hidden_size=self.tagger_config.hidden_size,
                                  bidirectional=self.tagger_config.bidirectional)
            elif self.tagger_config.absa_type == 'tfm':
                # transformer encoder layer
                self.tagger = nn.TransformerEncoderLayer(
                    d_model=bert_config.hidden_size,
                    nhead=12,
                    dim_feedforward=4 * bert_config.hidden_size,
                    dropout=0.1)
            elif self.tagger_config.absa_type == 'san':
                # vanilla self attention networks
                self.tagger = SAN(d_model=bert_config.hidden_size, nhead=12, dropout=0.1)
            elif self.tagger_config.absa_type == 'crf':
                self.tagger = CRF(num_tags=self.num_labels)
            else:
                raise Exception('Unimplemented downstream tagger %s...'
                                % self.tagger_config.absa_type)
            penultimate_hidden_size = self.tagger_config.hidden_size
        self.classifier = nn.Linear(penultimate_hidden_size, bert_config.num_labels)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None,
                labels=None, position_ids=None, head_mask=None):
        outputs = self.bert(input_ids,
                            position_ids=position_ids,
                            token_type_ids=token_type_ids,
                            attention_mask=attention_mask,
                            head_mask=head_mask)
        # the hidden states of the last Bert Layer, shape: (bsz, seq_len, hsz)
        tagger_input = outputs[0]
        tagger_input = self.bert_dropout(tagger_input)
        # print("tagger_input.shape:", tagger_input.shape)
        if self.tagger is None or self.tagger_config.absa_type == 'crf':
            # regard classifier as the tagger
            logits = self.classifier(tagger_input)
        else:
            if self.tagger_config.absa_type == 'lstm':
                # customized LSTM
                classifier_input, _ = self.tagger(tagger_input)
            elif self.tagger_config.absa_type == 'gru':
                # customized GRU
                classifier_input, _ = self.tagger(tagger_input)
            elif self.tagger_config.absa_type == 'san' or self.tagger_config.absa_type == 'tfm':
                # vanilla self-attention networks or transformer
                # adapt the input format for the transformer or self attention networks
                tagger_input = tagger_input.transpose(0, 1)
                classifier_input = self.tagger(tagger_input)
                classifier_input = classifier_input.transpose(0, 1)
            else:
                raise Exception("Unimplemented downstream tagger %s..."
                                % self.tagger_config.absa_type)
            classifier_input = self.tagger_dropout(classifier_input)
            logits = self.classifier(classifier_input)
        outputs = (logits, ) + outputs[2:]

        if labels is not None:
            if self.tagger_config.absa_type != 'crf':
                loss_fct = CrossEntropyLoss()
                if attention_mask is not None:
                    active_loss = attention_mask.view(-1) == 1
                    active_logits = logits.view(-1, self.num_labels)[active_loss]
                    active_labels = labels.view(-1)[active_loss]
                    loss = loss_fct(active_logits, active_labels)
                else:
                    loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
                outputs = (loss, ) + outputs
            else:
                log_likelihood = self.tagger(inputs=logits, tags=labels, mask=attention_mask)
                loss = -log_likelihood
                outputs = (loss, ) + outputs
        return outputs
def main(): parser = argparse.ArgumentParser( description='Train the individual Transformer model') parser.add_argument('--dataset_folder', type=str, default='datasets') parser.add_argument('--dataset_name', type=str, default='zara1') parser.add_argument('--obs', type=int, default=8) parser.add_argument('--preds', type=int, default=12) parser.add_argument('--emb_size', type=int, default=1024) parser.add_argument('--heads', type=int, default=8) parser.add_argument('--layers', type=int, default=6) parser.add_argument('--dropout', type=float, default=0.1) parser.add_argument('--cpu', action='store_true') parser.add_argument('--output_folder', type=str, default='Output') parser.add_argument('--val_size', type=int, default=50) parser.add_argument('--gpu_device', type=str, default="0") parser.add_argument('--verbose', action='store_true') parser.add_argument('--max_epoch', type=int, default=100) parser.add_argument('--batch_size', type=int, default=256) parser.add_argument('--validation_epoch_start', type=int, default=30) parser.add_argument('--resume_train', action='store_true') parser.add_argument('--delim', type=str, default='\t') parser.add_argument('--name', type=str, default="zara1") args = parser.parse_args() model_name = args.name try: os.mkdir('models') except: pass try: os.mkdir('output') except: pass try: os.mkdir('output/BERT') except: pass try: os.mkdir(f'models/BERT') except: pass try: os.mkdir(f'output/BERT/{args.name}') except: pass try: os.mkdir(f'models/BERT/{args.name}') except: pass log = SummaryWriter('logs/BERT_%s' % model_name) log.add_scalar('eval/mad', 0, 0) log.add_scalar('eval/fad', 0, 0) try: os.mkdir(args.name) except: pass device = torch.device("cuda") if args.cpu or not torch.cuda.is_available(): device = torch.device("cpu") args.verbose = True ## creation of the dataloaders for train and validation train_dataset, _ = baselineUtils.create_dataset(args.dataset_folder, args.dataset_name, 0, args.obs, args.preds, delim=args.delim, train=True, verbose=args.verbose) val_dataset, _ = baselineUtils.create_dataset(args.dataset_folder, args.dataset_name, 0, args.obs, args.preds, delim=args.delim, train=False, verbose=args.verbose) test_dataset, _ = baselineUtils.create_dataset(args.dataset_folder, args.dataset_name, 0, args.obs, args.preds, delim=args.delim, train=False, eval=True, verbose=args.verbose) from transformers import BertTokenizer, BertModel, BertForMaskedLM, BertConfig, AdamW config = BertConfig(vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act='relu', hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12) model = BertModel(config).to(device) from individual_TF import LinearEmbedding as NewEmbed, Generator as GeneratorTS a = NewEmbed(3, 768).to(device) model.set_input_embeddings(a) generator = GeneratorTS(768, 2).to(device) #model.set_output_embeddings(GeneratorTS(1024,2)) tr_dl = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=0) val_dl = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=True, num_workers=0) test_dl = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0) #optim = SGD(list(a.parameters())+list(model.parameters())+list(generator.parameters()),lr=0.01) #sched=torch.optim.lr_scheduler.StepLR(optim,0.0005) optim = NoamOpt( 768, 0.1, len(tr_dl), 
torch.optim.Adam(list(a.parameters()) + list(model.parameters()) + list(generator.parameters()), lr=0, betas=(0.9, 0.98), eps=1e-9)) #optim=Adagrad(list(a.parameters())+list(model.parameters())+list(generator.parameters()),lr=0.01,lr_decay=0.001) epoch = 0 mean = train_dataset[:]['src'][:, :, 2:4].mean((0, 1)) * 0 std = train_dataset[:]['src'][:, :, 2:4].std((0, 1)) * 0 + 1 while epoch < args.max_epoch: epoch_loss = 0 model.train() for id_b, batch in enumerate(tr_dl): optim.optimizer.zero_grad() r = 0 rot_mat = np.array([[np.cos(r), np.sin(r)], [-np.sin(r), np.cos(r)]]) inp = ((batch['src'][:, :, 2:4] - mean) / std).to(device) inp = torch.matmul(inp, torch.from_numpy(rot_mat).float().to(device)) trg_masked = torch.zeros((inp.shape[0], args.preds, 2)).to(device) inp_cls = torch.ones(inp.shape[0], inp.shape[1], 1).to(device) trg_cls = torch.zeros(trg_masked.shape[0], trg_masked.shape[1], 1).to(device) inp_cat = torch.cat((inp, trg_masked), 1) cls_cat = torch.cat((inp_cls, trg_cls), 1) net_input = torch.cat((inp_cat, cls_cat), 2) position = torch.arange(0, net_input.shape[1]).repeat( inp.shape[0], 1).long().to(device) token = torch.zeros( (inp.shape[0], net_input.shape[1])).long().to(device) attention_mask = torch.ones( (inp.shape[0], net_input.shape[1])).long().to(device) out = model(input_ids=net_input, position_ids=position, token_type_ids=token, attention_mask=attention_mask) pred = generator(out[0]) loss = F.pairwise_distance( pred[:, :].contiguous().view(-1, 2), torch.matmul( torch.cat( (batch['src'][:, :, 2:4], batch['trg'][:, :, 2:4]), 1).contiguous().view(-1, 2).to(device), torch.from_numpy(rot_mat).float().to(device))).mean() loss.backward() optim.step() print("epoch %03i/%03i frame %04i / %04i loss: %7.4f" % (epoch, args.max_epoch, id_b, len(tr_dl), loss.item())) epoch_loss += loss.item() #sched.step() log.add_scalar('Loss/train', epoch_loss / len(tr_dl), epoch) with torch.no_grad(): model.eval() gt = [] pr = [] val_loss = 0 for batch in val_dl: inp = ((batch['src'][:, :, 2:4] - mean) / std).to(device) trg_masked = torch.zeros( (inp.shape[0], args.preds, 2)).to(device) inp_cls = torch.ones(inp.shape[0], inp.shape[1], 1).to(device) trg_cls = torch.zeros(trg_masked.shape[0], trg_masked.shape[1], 1).to(device) inp_cat = torch.cat((inp, trg_masked), 1) cls_cat = torch.cat((inp_cls, trg_cls), 1) net_input = torch.cat((inp_cat, cls_cat), 2) position = torch.arange(0, net_input.shape[1]).repeat( inp.shape[0], 1).long().to(device) token = torch.zeros( (inp.shape[0], net_input.shape[1])).long().to(device) attention_mask = torch.zeros( (inp.shape[0], net_input.shape[1])).long().to(device) out = model(input_ids=net_input, position_ids=position, token_type_ids=token, attention_mask=attention_mask) pred = generator(out[0]) loss = F.pairwise_distance( pred[:, :].contiguous().view(-1, 2), torch.cat( (batch['src'][:, :, 2:4], batch['trg'][:, :, 2:4]), 1).contiguous().view(-1, 2).to(device)).mean() val_loss += loss.item() gt_b = batch['trg'][:, :, 0:2] preds_tr_b = pred[:, args.obs:].cumsum(1).to( 'cpu').detach() + batch['src'][:, -1:, 0:2] gt.append(gt_b) pr.append(preds_tr_b) gt = np.concatenate(gt, 0) pr = np.concatenate(pr, 0) mad, fad, errs = baselineUtils.distance_metrics(gt, pr) log.add_scalar('validation/loss', val_loss / len(val_dl), epoch) log.add_scalar('validation/mad', mad, epoch) log.add_scalar('validation/fad', fad, epoch) model.eval() gt = [] pr = [] for batch in test_dl: inp = ((batch['src'][:, :, 2:4] - mean) / std).to(device) trg_masked = torch.zeros( (inp.shape[0], args.preds, 
2)).to(device) inp_cls = torch.ones(inp.shape[0], inp.shape[1], 1).to(device) trg_cls = torch.zeros(trg_masked.shape[0], trg_masked.shape[1], 1).to(device) inp_cat = torch.cat((inp, trg_masked), 1) cls_cat = torch.cat((inp_cls, trg_cls), 1) net_input = torch.cat((inp_cat, cls_cat), 2) position = torch.arange(0, net_input.shape[1]).repeat( inp.shape[0], 1).long().to(device) token = torch.zeros( (inp.shape[0], net_input.shape[1])).long().to(device) attention_mask = torch.zeros( (inp.shape[0], net_input.shape[1])).long().to(device) out = model(input_ids=net_input, position_ids=position, token_type_ids=token, attention_mask=attention_mask) pred = generator(out[0]) gt_b = batch['trg'][:, :, 0:2] preds_tr_b = pred[:, args.obs:].cumsum(1).to( 'cpu').detach() + batch['src'][:, -1:, 0:2] gt.append(gt_b) pr.append(preds_tr_b) gt = np.concatenate(gt, 0) pr = np.concatenate(pr, 0) mad, fad, errs = baselineUtils.distance_metrics(gt, pr) torch.save(model.state_dict(), "models/BERT/%s/ep_%03i.pth" % (args.name, epoch)) torch.save(generator.state_dict(), "models/BERT/%s/gen_%03i.pth" % (args.name, epoch)) torch.save(a.state_dict(), "models/BERT/%s/emb_%03i.pth" % (args.name, epoch)) log.add_scalar('eval/mad', mad, epoch) log.add_scalar('eval/fad', fad, epoch) epoch += 1 ab = 1
class BertEncoder(MetaModule): """BERT model as presented in Google's paper and using Hugging Face's code References: https://arxiv.org/abs/1810.04805 """ class Config(BaseConfig): model_name: Union[str, Path] = 'bert-base-multilingual-cased' """Pre-trained BERT model to use.""" use_mismatch_features: bool = False """Use Alibaba's mismatch features.""" use_predictor_features: bool = False """Use features originally proposed in the Predictor model.""" interleave_input: bool = False """Concatenate SOURCE and TARGET without internal padding (111222000 instead of 111002220)""" freeze: bool = False """Freeze BERT during training.""" use_mlp: bool = True """Apply a linear layer on top of BERT.""" hidden_size: int = 100 """Size of the linear layer on top of BERT.""" scalar_mix_dropout: confloat(ge=0.0, le=1.0) = 0.1 scalar_mix_layer_norm: bool = True @validator('model_name', pre=True) def fix_relative_path(cls, v): if ( v not in BERT_PRETRAINED_MODEL_ARCHIVE_LIST and v not in DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST ): v = Path(v) if not v.is_absolute(): v = Path.cwd().joinpath(v) return v @validator('use_mismatch_features', 'use_predictor_features', pre=True) def no_implementation(cls, v): if v: raise NotImplementedError('Not yet implemented') return False def __init__( self, vocabs: Dict[str, Vocabulary], config: Config, pre_load_model: bool = True ): super().__init__(config=config) if pre_load_model: self.bert = BertModel.from_pretrained( self.config.model_name, output_hidden_states=True ) else: bert_config = BertConfig.from_pretrained( self.config.model_name, output_hidden_states=True ) self.bert = BertModel(bert_config) self.vocabs = { const.TARGET: vocabs[const.TARGET], const.SOURCE: vocabs[const.SOURCE], } self.mlp = None if self.config.use_mlp: self.mlp = nn.Sequential( nn.Linear(self.bert.config.hidden_size, self.config.hidden_size), nn.Tanh(), ) output_size = self.config.hidden_size else: output_size = self.bert.config.hidden_size self.scalar_mix = ScalarMixWithDropout( mixture_size=self.bert.config.num_hidden_layers + 1, # +1 for embeddings do_layer_norm=self.config.scalar_mix_layer_norm, dropout=self.config.scalar_mix_dropout, ) self._sizes = { const.TARGET: output_size, const.TARGET_LOGITS: output_size, const.TARGET_SENTENCE: self.bert.config.hidden_size, const.SOURCE: output_size, } self.output_embeddings = self.bert.embeddings.word_embeddings if self.config.freeze: for param in self.bert.parameters(): param.requires_grad = False def load_state_dict( self, state_dict: Union[Dict[str, Tensor], Dict[str, Tensor]], strict: bool = True, ): try: keys = super().load_state_dict(state_dict, strict) except RuntimeError as e: if "position_ids" in str(e): # FIXME: hack to get around Transformers 3.1 breaking changes # https://github.com/huggingface/transformers/issues/6882 self.bert.embeddings._non_persistent_buffers_set.add('position_ids') keys = super().load_state_dict(state_dict, strict) self.bert.embeddings._non_persistent_buffers_set.discard('position_ids') else: raise e return keys @classmethod def input_data_encoders(cls, config: Config): return { const.SOURCE: TransformersTextEncoder( tokenizer_name=config.model_name, is_source=True ), const.TARGET: TransformersTextEncoder(tokenizer_name=config.model_name), } def size(self, field=None): if field: return self._sizes[field] return self._sizes def forward( self, batch_inputs, *args, include_target_logits=False, include_source_logits=False ): # BERT gets it's input as a concatenation of both embeddings # or as an interleave of inputs if 
self.config.interleave_input: merge_input_fn = self.interleave_input else: merge_input_fn = self.concat_input input_ids, token_type_ids, attention_mask = merge_input_fn( batch_inputs[const.SOURCE], batch_inputs[const.TARGET], pad_id=self.vocabs[const.TARGET].pad_id, ) # hidden_states also includes the embedding layer # hidden_states[-1] is the last layer last_hidden_state, pooler_output, hidden_states = self.bert( input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, ) # TODO: select one of these strategies via cli # TODO: get a BETTER strategy features = self.scalar_mix(hidden_states, attention_mask) # features = sum(hidden_states[-5:-1]) # features = hidden_states[-2] if self.config.use_mlp: features = self.mlp(features) # Build the feature dictionary to be returned to the system output_features = self.split_outputs( features, batch_inputs, interleaved=self.config.interleave_input ) # Convert pieces to tokens target_features = pieces_to_tokens( output_features[const.TARGET], batch_inputs[const.TARGET] ) source_features = pieces_to_tokens( output_features[const.SOURCE], batch_inputs[const.SOURCE] ) # sentence_features = pooler_output sentence_features = last_hidden_state.mean(dim=1) # Substitute CLS on target side # target_features[:, 0] = 0 output_features[const.TARGET] = target_features output_features[const.SOURCE] = source_features output_features[const.TARGET_SENTENCE] = sentence_features # Logits for multi-task fine-tuning if include_target_logits: output_features[const.TARGET_LOGITS] = torch.einsum( 'vh,bsh->bsv', self.output_embeddings.weight, output_features[const.TARGET], ) if include_source_logits: output_features[const.SOURCE_LOGITS] = torch.einsum( 'vh,bsh->bsv', self.output_embeddings.weight, output_features[const.SOURCE], ) # Additional features if self.config.use_mismatch_features: raise NotImplementedError return output_features @staticmethod def concat_input(source_batch, target_batch, pad_id): """Concatenate the target + source embeddings into one tensor. Return: concatenation of embeddings, mask of target (as ones) and source (as zeroes) and concatenation of attention_mask """ source_ids = source_batch.tensor target_ids = target_batch.tensor source_attention_mask = retrieve_tokens_mask(source_batch) target_attention_mask = retrieve_tokens_mask(target_batch) target_types = torch.zeros_like(target_ids) # zero denotes first sequence source_types = torch.ones_like(source_ids) input_ids = torch.cat((target_ids, source_ids), dim=1) token_type_ids = torch.cat((target_types, source_types), dim=1) attention_mask = torch.cat( (target_attention_mask, source_attention_mask), dim=1 ) return input_ids, token_type_ids, attention_mask @staticmethod def split_outputs( features: Tensor, batch_inputs: MultiFieldBatch, interleaved: bool = False ) -> Dict[str, Tensor]: """Split features back into sentences A and B. Args: features: BERT's output: ``[CLS] target [SEP] source [SEP]``. Shape of (bs, 1 + target_len + 1 + source_len + 1, 2) batch_inputs: the regular batch object, containing ``source`` and ``target`` batches interleaved: whether the concat strategy was interleaved Return: dict of tensors for ``source`` and ``target``. 
""" outputs = OrderedDict() target_lengths = batch_inputs[const.TARGET].lengths if interleaved: raise NotImplementedError('interleaving not supported.') # TODO: fix code below to use the lengths information and not bounds # if interleaved, shift each source sample by its correspondent length shift = target_lengths.unsqueeze(-1) range_vector = torch.arange( features.size(0), device=features.device ).unsqueeze(1) target_bounds = batch_inputs[const.TARGET].bounds target_features = features[range_vector, target_bounds] # Shift bounds by target length and preserve padding source_bounds = batch_inputs[const.SOURCE].bounds m = (source_bounds != -1).long() # for masking out padding (which is -1) shifted_bounds = (source_bounds + shift) * m + source_bounds * (1 - m) source_features = features[range_vector, shifted_bounds] else: # otherwise, shift all by max_length # if we'd like to maintain the word pieces we merely select all target_features = features[:, : target_lengths.max()] # ignore the target and get the rest source_features = features[:, target_lengths.max() :] outputs[const.TARGET] = target_features # Source doesn't have an init_token (like CLS) and we keep SEP outputs[const.SOURCE] = source_features return outputs # TODO this strategy is not being used, should we keep it? @staticmethod def interleave_input(source_batch, target_batch, pad_id): """Interleave the source + target embeddings into one tensor. This means making the input as [batch, target [SEP] source]. Return: interleave of embds, mask of target (as zeroes) and source (as ones) and concatenation of attention_mask. """ source_ids = source_batch.tensor target_ids = target_batch.tensor batch_size = source_ids.size(0) source_lengths = source_batch.lengths target_lengths = target_batch.lengths max_pair_length = source_ids.size(1) + target_ids.size(1) input_ids = torch.full( (batch_size, max_pair_length), pad_id, dtype=torch.long, device=source_ids.device, ) token_type_ids = torch.zeros_like(input_ids) attention_mask = torch.zeros_like(input_ids) for i in range(batch_size): # [CLS] and [SEP] are included in the mask (=1) # note: source does not have CLS t_len = target_lengths[i].item() s_len = source_lengths[i].item() input_ids[i, :t_len] = target_ids[i, :t_len] token_type_ids[i, :t_len] = 0 attention_mask[i, :t_len] = 1 input_ids[i, t_len : t_len + s_len] = source_ids[i, :s_len] token_type_ids[i, t_len : t_len + s_len] = 1 attention_mask[i, t_len : t_len + s_len] = 1 # TODO, why is attention mask 1 for all positions? return input_ids, token_type_ids, attention_mask @staticmethod def get_mismatch_features(logits, target, pred): # calculate mismatch features and concat them t_max = torch.gather(logits, -1, target.unsqueeze(-1)) p_max = torch.gather(logits, -1, pred.unsqueeze(-1)) diff_max = t_max - p_max diff_arg = (target != pred).float().unsqueeze(-1) mismatch = torch.cat((t_max, p_max, diff_max, diff_arg), dim=-1) return mismatch
class CSER(BertPreTrainedModel):
    """ Span-based model to extract entities """

    def __init__(self, config: BertConfig, cls_token: int, entity_types: int,
                 size_embedding: int, prop_drop: float, freeze_transformer: bool):  # noqa
        super(CSER, self).__init__(config)

        # BERT model
        self.bert = BertModel(config)

        # layers
        self.entity_classifier = nn.Linear(
            config.hidden_size * 2 + size_embedding, entity_types)
        self.size_embeddings = nn.Embedding(100, size_embedding)
        self.dropout = nn.Dropout(prop_drop)

        self._cls_token = cls_token
        self._entity_types = entity_types

        # weight initialization
        self.init_weights()

        if freeze_transformer:
            print("Freeze transformer weights")
            # freeze all transformer weights
            for param in self.bert.parameters():
                param.requires_grad = False

    def _forward_eval(self, encodings: torch.tensor, context_masks: torch.tensor,
                      entity_masks: torch.tensor, entity_sizes: torch.tensor,
                      entity_spans: torch.tensor, entity_sample_masks: torch.tensor):  # noqa
        # get contextualized token embeddings from last transformer layer
        context_masks = context_masks.float()
        h = self.bert(input_ids=encodings, attention_mask=context_masks)[0]

        # classify entities
        size_embeddings = self.size_embeddings(entity_sizes)  # embed entity candidate sizes
        entity_clf, entity_spans_pool = self._classify_entities(
            encodings, h, entity_masks, size_embeddings)

        # apply softmax
        entity_clf = torch.softmax(entity_clf, dim=2)
        return entity_clf

    def _classify_entities(self, encodings, h, entity_masks, size_embeddings):
        # max pool entity candidate spans
        m = (entity_masks.unsqueeze(-1) == 0).float() * (-1e30)
        entity_spans_pool = m + h.unsqueeze(1).repeat(1, entity_masks.shape[1], 1, 1)
        entity_spans_pool = entity_spans_pool.max(dim=2)[0]

        # get cls token as candidate context representation
        entity_ctx = get_token(h, encodings, self._cls_token)

        # create candidate representations including context, max pooled span and size embedding
        entity_repr = torch.cat([
            entity_ctx.unsqueeze(1).repeat(1, entity_spans_pool.shape[1], 1),
            entity_spans_pool, size_embeddings
        ], dim=2)
        entity_repr = self.dropout(entity_repr)

        # classify entity candidates
        entity_clf = self.entity_classifier(entity_repr)
        return entity_clf, entity_spans_pool

    def forward(self, *args, **kwargs):
        return self._forward_eval(*args, **kwargs)
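# Illustrative sketch (not from the original model) of the masked span
# max-pooling used in CSER._classify_entities: positions outside a candidate
# span are pushed to -1e30 so the max over the token dimension only sees tokens
# inside that span. All shapes below are arbitrary examples.
def _span_max_pooling_sketch():
    import torch

    h = torch.randn(2, 8, 16)            # (batch, seq_len, hidden)
    entity_masks = torch.zeros(2, 3, 8)  # (batch, num_spans, seq_len)
    entity_masks[:, 0, 1:3] = 1          # first candidate span covers tokens 1..2

    m = (entity_masks.unsqueeze(-1) == 0).float() * (-1e30)
    spans = m + h.unsqueeze(1).repeat(1, entity_masks.shape[1], 1, 1)
    span_pool = spans.max(dim=2)[0]      # (batch, num_spans, hidden)
    return span_pool                     # empty spans pool to ~ -1e30 placeholders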
class SpEER(BertPreTrainedModel): """ Span-based model to jointly extract entities and relations """ def __init__(self, config: BertConfig, cls_token: int, relation_types: int, entity_types: int, size_embedding: int, prop_drop: float, freeze_transformer: bool, max_pairs: int = 100, encoding_size: int = 200, feature_enhancer: str = "pass"): super(SpEER, self).__init__(config) # BERT model self.bert = BertModel(config) # layers self.encoding_size = encoding_size self.feature_enhancer = fe.get_feature_enhancer(feature_enhancer)( config.hidden_size, config.hidden_size) self.rel_encoder = nn.Linear( config.hidden_size * 3 + size_embedding * 2, encoding_size) self.entity_encoder = nn.Linear( config.hidden_size * 2 + size_embedding, encoding_size) self.size_embeddings = nn.Embedding(100, size_embedding) self.dropout = nn.Dropout(prop_drop) self._cls_token = cls_token self._relation_types = relation_types self._entity_types = entity_types self._max_pairs = max_pairs # weight initialization self.init_weights() if freeze_transformer or feature_enhancer not in { "pass", "transformer" }: print("Freeze transformer weights") # freeze all transformer weights for param in self.bert.parameters(): param.requires_grad = False def _forward_train(self, encodings: torch.tensor, context_mask: torch.tensor, entity_masks: torch.tensor, entity_sizes: torch.tensor, relations: torch.tensor, rel_masks: torch.tensor): # get contextualized token embeddings from last transformer layer context_mask = context_mask.float() h = self.bert(input_ids=encodings, attention_mask=context_mask)[0] # enhance hidden features orig_shape = h.shape h = self.feature_enhancer.prepare_input(h, context_mask) h = self.feature_enhancer(h) h = self.feature_enhancer.prepare_output(h, orig_shape) entity_masks = entity_masks.float() batch_size = encodings.shape[0] device = self.entity_encoder.weight.device # encode and classify entities size_embeddings = self.size_embeddings( entity_sizes) # embed entity candidate sizes entity_encoding, entity_spans_pool = self._encode_entities( encodings, h, entity_masks, size_embeddings) entity_clf = self._classify_entities(entity_encoding) # prepare relation encoding rel_masks = rel_masks.float().unsqueeze(-1) h_large = h.unsqueeze(1).repeat( 1, max(min(relations.shape[1], self._max_pairs), 1), 1, 1) rel_encoding = torch.zeros( [batch_size, relations.shape[1], self.encoding_size]).to(device) # obtain relation encodings # chunk processing to reduce memory usage for i in range(0, relations.shape[1], self._max_pairs): # classify relation candidates rel_encoding_chunk = self._encode_relations( entity_spans_pool, size_embeddings, relations, rel_masks, h_large, i) rel_encoding[:, i:i + self._max_pairs, :] = rel_encoding_chunk rel_clf = self._classify_relations(rel_encoding) return entity_clf, rel_clf def _forward_eval(self, entity_knn_module, rel_knn_module, entity_entries: List[List[Dict]], type_key: str, encodings: torch.tensor, context_mask: torch.tensor, entity_masks: torch.tensor, entity_sizes: torch.tensor, entity_spans: torch.tensor = None, entity_sample_mask: torch.tensor = None, verbose=False): # get contextualized token embeddings from last transformer layer context_mask = context_mask.float() h = self.bert(input_ids=encodings, attention_mask=context_mask)[0] # enhance hidden features orig_shape = h.shape h = self.feature_enhancer.prepare_input(h, context_mask) h = self.feature_enhancer(h) h = self.feature_enhancer.prepare_output(h, orig_shape) entity_masks = entity_masks.float() batch_size = 
encodings.shape[0] ctx_size = context_mask.shape[-1] device = self.entity_encoder.weight.device # encode and classify entities size_embeddings = self.size_embeddings( entity_sizes) # embed entity candidate sizes entity_encoding, entity_spans_pool = self._encode_entities( encodings, h, entity_masks, size_embeddings) entity_encoding_reshaped = entity_encoding.view( entity_encoding.shape[0] * entity_encoding.shape[1], -1).cpu() entity_types, entity_neighbors = entity_knn_module.infer_( entity_encoding_reshaped, int, type_key) # for i, neighbors in enumerate(entity_neighbors): # print(entity_types[i], neighbors) # print neighbor entities if verbose: print('*' * 50) print("entity neighbors:") entity_entries_flat = [] for entry in entity_entries: entity_entries_flat += entry for i, neighbors in enumerate(entity_neighbors): if entity_types[i] == 0: continue print("[ENT] {} >> {}".format( entity_entries_flat[i]["phrase"].encode('utf-8'), entity_types[i])) for j in range(min(len(neighbors), 5)): n = neighbors[j] print("\t", n["phrase"].encode('utf-8'), n["type_string"], n[type_key]) print() entity_types = torch.tensor(entity_types).view( entity_encoding.shape[0], entity_encoding.shape[1]).to(device) entity_clf = torch.zeros([ entity_encoding.shape[0], entity_encoding.shape[1], self._entity_types ], dtype=torch.long).to(device) entity_clf.scatter_(2, entity_types.unsqueeze(2), 1) # ignore entity candidates that do not constitute an actual entity for relations (based on classifier) relations, rel_masks, rel_sample_masks, rel_entries = self._filter_spans( entity_clf, entity_spans, entity_sample_mask, entity_entries, ctx_size, type_key) rel_masks = rel_masks.float() rel_sample_masks = rel_sample_masks.float() h_large = h.unsqueeze(1).repeat( 1, max(min(relations.shape[1], self._max_pairs), 1), 1, 1) rel_encoding = torch.zeros( [batch_size, relations.shape[1], self.encoding_size]).to(device) # obtain relation encodings # chunk processing to reduce memory usage for i in range(0, relations.shape[1], self._max_pairs): # classify relation candidates rel_encoding_chunk = self._encode_relations( entity_spans_pool, size_embeddings, relations, rel_masks, h_large, i) rel_encoding[:, i:i + self._max_pairs, :] = rel_encoding_chunk rel_encoding_reshaped = rel_encoding.view( rel_encoding.shape[0] * rel_encoding.shape[1], -1).cpu() # encode and classify relations rel_types, rel_neighbors = rel_knn_module.infer_( rel_encoding_reshaped, int, type_key) # print neighbor relations if verbose: print('*' * 50) rel_entries_flat = [] for entry in rel_entries: rel_entries_flat += entry for i, neighbors in enumerate(rel_neighbors): if rel_types[i] == 0: continue print("[REL] {} >> {}".format( rel_entries_flat[i]["phrase"].encode('utf-8'), rel_types[i])) for j in range(min(len(neighbors), 5)): n = neighbors[j] print("\t", n["phrase"].encode('utf-8'), n["type_string"], n[type_key]) print() rel_types = torch.LongTensor(rel_types).view( rel_encoding.shape[0], rel_encoding.shape[1]).to(device) rel_clf = torch.zeros([ rel_encoding.shape[0], rel_encoding.shape[1], self._relation_types ], dtype=torch.float32).to(device) rel_clf.scatter_(2, rel_types.unsqueeze(2), 1) rel_clf = rel_clf[:, :, 1:] # exclude 'none' prediction for multi-label prediction rel_clf = rel_clf * rel_sample_masks # mask return entity_clf, rel_clf, relations def _forward_encode(self, encodings: torch.tensor, context_mask: torch.tensor, entity_masks: torch.tensor, entity_sizes: torch.tensor, relations: torch.tensor, rel_masks: torch.tensor): # get contextualized 
token embeddings from last transformer layer context_mask = context_mask.float() h = self.bert(input_ids=encodings, attention_mask=context_mask)[0] entity_masks = entity_masks.float() batch_size = encodings.shape[0] device = self.entity_encoder.weight.device # encode and classify entities size_embeddings = self.size_embeddings( entity_sizes) # embed entity candidate sizes entity_encoding, entity_spans_pool = self._encode_entities( encodings, h, entity_masks, size_embeddings) # prepare relation encoding rel_masks = rel_masks.float().unsqueeze(-1) h_large = h.unsqueeze(1).repeat( 1, max(min(relations.shape[1], self._max_pairs), 1), 1, 1) rel_encoding = torch.zeros( [batch_size, relations.shape[1], self.encoding_size]).to(device) # obtain relation encodings # chunk processing to reduce memory usage for i in range(0, relations.shape[1], self._max_pairs): # classify relation candidates rel_encoding_chunk = self._encode_relations( entity_spans_pool, size_embeddings, relations, rel_masks, h_large, i) rel_encoding[:, i:i + self._max_pairs, :] = rel_encoding_chunk return entity_encoding, rel_encoding def _classify_entities(self, entity_encoding, verification=False): # cosine similarities of every possible entity encoding pair in the batch cosine_similarities = torch.einsum('abc, ijc -> abij', entity_encoding, entity_encoding) # einsum verification (at least each element has similarity 1 with itself) if verification: with torch.no_grad(): is_close_bools = cosine_similarities.isclose( torch.tensor([1.00], device=self.entity_encoder.weight.device)) is_close_sum = is_close_bools.int().sum().item() assert (is_close_sum >= entity_encoding.shape[0] * entity_encoding.shape[1]) # normalize cosine similarity from [-1, 1] to [0, 1], and clip float precision errors normalized_similarities = (cosine_similarities + 1) / 2 normalized_similarities = normalized_similarities.clamp(0, 1) return normalized_similarities def _encode_entities(self, encodings, h, entity_masks, size_embeddings): # max pool entity candidate spans entity_spans_pool = entity_masks.unsqueeze(-1) * h.unsqueeze(1) entity_spans_pool = entity_spans_pool.max(dim=2)[0] # get cls token as candidate context representation entity_ctx = get_token(h, encodings, self._cls_token) # create candidate representations including context, max pooled span and size embedding entity_repr = torch.cat([ entity_ctx.unsqueeze(1).repeat(1, entity_spans_pool.shape[1], 1), entity_spans_pool, size_embeddings ], dim=2) entity_repr = self.dropout(entity_repr) # encode entity candidates entity_encoding = self.entity_encoder(entity_repr) # normalize encoding to unit length for cosine similarity entity_encoding = f.normalize(entity_encoding, dim=2, p=2) return entity_encoding, entity_spans_pool def _classify_relations(self, rel_encoding, verification=False): # cosine similarity of every possible relation encoding pair in the batch cosine_similarities = torch.einsum('abc, ijc -> abij', rel_encoding, rel_encoding) # einsum verification (at least each element has similarity 1 with itself) if verification: with torch.no_grad(): is_close_bools = cosine_similarities.isclose( torch.tensor([1.00], device=self.rel_encoder.weight.device)) is_close_sum = is_close_bools.int().sum().item() assert (is_close_sum >= rel_encoding.shape[0] * rel_encoding.shape[1]) # normalize cosine similarity from [-1, 1] to [0, 1], and clip float precision errors normalized_similarities = (cosine_similarities + 1) / 2 normalized_similarities = normalized_similarities.clamp(0, 1) return normalized_similarities 
def _encode_relations(self, entity_spans, size_embeddings, relations, rel_masks, h, chunk_start): batch_size = relations.shape[0] # create chunks if necessary if relations.shape[1] > self._max_pairs: relations = relations[:, chunk_start:chunk_start + self._max_pairs] rel_masks = rel_masks[:, chunk_start:chunk_start + self._max_pairs] h = h[:, :relations.shape[1], :] # get pairs of entity candidate representations entity_pairs = util.batch_index(entity_spans, relations) entity_pairs = entity_pairs.view(batch_size, entity_pairs.shape[1], -1) # get corresponding size embeddings size_pair_embeddings = util.batch_index(size_embeddings, relations) size_pair_embeddings = size_pair_embeddings.view( batch_size, size_pair_embeddings.shape[1], -1) # relation context (context between entity candidate pair) rel_ctx = rel_masks * h rel_ctx = rel_ctx.max(dim=2)[0] # create relation candidate representations including context, max pooled entity candidate pairs # and corresponding size embeddings rel_repr = torch.cat([rel_ctx, entity_pairs, size_pair_embeddings], dim=2) rel_repr = self.dropout(rel_repr) # encode relation candidates rel_encoding = self.rel_encoder(rel_repr) # normalize encoding to unit length for cosine similarity rel_encoding = f.normalize(rel_encoding, dim=2, p=2) return rel_encoding #TODO: Needs checking of relation entries def _filter_spans(self, entity_clf, entity_spans, entity_sample_mask, entity_entries, ctx_size, type_key): batch_size = entity_clf.shape[0] entity_logits_max = entity_clf.argmax( dim=-1) * entity_sample_mask.long( ) # get entity type (including none) batch_relations = [] batch_rel_masks = [] batch_rel_sample_masks = [] batch_rel_entries = [] for i in range(batch_size): rels = [] rel_masks = [] sample_masks = [] rel_entries = [] # get spans classified as entities non_zero_indices = (entity_logits_max[i] != 0).nonzero().view(-1) non_zero_spans = entity_spans[i][non_zero_indices].tolist() non_zero_entries = [entity_entries[i][j] for j in non_zero_indices] non_zero_indices = non_zero_indices.tolist() # create relations and masks for n, (i1, s1) in enumerate(zip(non_zero_indices, non_zero_spans)): for m, (i2, s2) in enumerate(zip(non_zero_indices, non_zero_spans)): if i1 != i2: rels.append((i1, i2)) phrase = "|{}| <TBD> |{}|".format( non_zero_entries[n]["phrase"], non_zero_entries[m]["phrase"]) rel_entries.append({ "phrase": phrase, "type_string": "<TBD>", type_key: -1 }) rel_masks.append( sampling.create_rel_mask(s1, s2, ctx_size)) sample_masks.append(1) if not rels: # case: no more than two spans classified as entities batch_relations.append(torch.tensor([[0, 0]], dtype=torch.long)) batch_rel_masks.append( torch.tensor([[0] * ctx_size], dtype=torch.bool)) batch_rel_sample_masks.append( torch.tensor([0], dtype=torch.bool)) phrase = "" batch_rel_entries.append([{ "phrase": phrase, "type_string": "<TBD>", type_key: -1 }]) else: # case: more than two spans classified as entities batch_relations.append(torch.tensor(rels, dtype=torch.long)) batch_rel_masks.append(torch.stack(rel_masks)) batch_rel_sample_masks.append( torch.tensor(sample_masks, dtype=torch.bool)) batch_rel_entries.append(rel_entries) # stack device = self.rel_encoder.weight.device batch_relations = util.padded_stack(batch_relations).to(device) batch_rel_masks = util.padded_stack(batch_rel_masks).to( device).unsqueeze(-1) batch_rel_sample_masks = util.padded_stack(batch_rel_sample_masks).to( device).unsqueeze(-1) batch_rel_entries = util.padded_entries(batch_rel_entries) return batch_relations, batch_rel_masks, 
batch_rel_sample_masks, batch_rel_entries def forward(self, *args, mode="train", **kwargs): f_forward = { "train": self._forward_train, "eval": self._forward_eval, "encode": self._forward_encode }.get(mode) return f_forward(*args, **kwargs)
class TableF(BertPreTrainedModel): """ table filling model to jointly extract entities and relations """ def __init__(self, config: BertConfig, tokenizer: BertTokenizer, relation_labels: int, entity_labels: int, entity_label_embedding: int, att_hidden: int, prop_drop: float, freeze_transformer: bool, device): super(TableF, self).__init__(config) # BERT model self.bert = BertModel(config) self._tokenizer = tokenizer self._device = device # layers self.entity_label_embedding = nn.Embedding(entity_labels , entity_label_embedding) self.entity_classifier = nn.Linear(config.hidden_size * 2 + entity_label_embedding, entity_labels) self.rel_classifier = MultiHeadAttention(relation_labels, config.hidden_size + entity_label_embedding, att_hidden , device) self.dropout = nn.Dropout(prop_drop) self._relation_labels = relation_labels self._entity_labels = entity_labels # weight initialization self.init_weights() if freeze_transformer: print("Freeze transformer weights") # freeze all transformer weights for param in self.bert.parameters(): param.requires_grad = False def _forward_token(self, h: torch.tensor, token_mask: torch.tensor, gold_seq: torch.tensor, entity_mask: torch.tensor): num_steps = gold_seq.shape[-1] word_h = h.repeat(token_mask.shape[0], 1, 1) * token_mask.unsqueeze(-1) word_h_pooled = word_h.max(dim=1)[0] word_h_pooled = word_h_pooled[:num_steps+2].contiguous() word_h_pooled[0,:] = 0 # curr word repr. curr_word_repr = word_h_pooled[1:-1].contiguous() # prev entity repr. prev_entity = torch.tril(entity_mask, diagonal=0) prev_entity_h = word_h_pooled.repeat(prev_entity.shape[0], 1, 1) * prev_entity.unsqueeze(-1) prev_entity_pooled = prev_entity_h.max(dim=1)[0] prev_entity_pooled = prev_entity_pooled[:num_steps].contiguous() # prev_label_embedding. prev_seq = torch.cat([torch.tensor([0]).to(self._device), gold_seq]) prev_label = self.entity_label_embedding(prev_seq[:-1]) entity_repr = torch.cat([curr_word_repr - 1, prev_entity_pooled - 1, prev_label], dim=1).unsqueeze(0) entity_repr = self.dropout(entity_repr) curr_entity_logits = self.entity_classifier(entity_repr) return curr_word_repr, curr_entity_logits def _forward_relation(self, h: torch.tensor, entity_preds: torch.tensor, entity_mask: torch.tensor, is_eval: bool = False): entity_labels = entity_preds.unsqueeze(0) # entity repr. masks_no_cls_rep = entity_mask[1:-1, 1:-1] entity_repr = h.repeat(masks_no_cls_rep.shape[-1], 1, 1) * masks_no_cls_rep.unsqueeze(-1) entity_repr_pool = entity_repr.max(dim=1)[0] #entity_label repr. 
entity_label_embeddings = self.entity_label_embedding(entity_labels) # entity_label_embeddings = torch.matmul(entity_preds, self.entity_label_embedding.weight) entity_label_repr = entity_label_embeddings.repeat(masks_no_cls_rep.shape[-1], 1, 1) * masks_no_cls_rep.unsqueeze(-1) entity_label_pool = entity_label_repr.max(dim=1)[0] rel_embedding = torch.cat([entity_repr_pool.unsqueeze(0) - 1, entity_label_pool.unsqueeze(0)], dim=2) rel_embedding = self.dropout(rel_embedding) rel_logits = self.rel_classifier(rel_embedding, rel_embedding, rel_embedding) return rel_logits def _forward_train(self, encodings: torch.tensor, context_mask: torch.tensor, token_mask: torch.tensor, gold_entity: torch.tensor, entity_masks: List[torch.tensor], allow_rel: bool): # get contextualized token embeddings from last transformer layer context_mask = context_mask.float() h = self.bert(input_ids=encodings, attention_mask=context_mask)[0] + 1 batch_size = encodings.shape[0] all_entity_logits = [] all_rel_logits = [] for batch in range(batch_size): # every batch entity_mask = entity_masks[batch] word_h, curr_entity_logits = self._forward_token(h[batch], token_mask[batch], gold_entity[batch], entity_mask) entity_preds = torch.argmax(curr_entity_logits, dim=2) # entity_preds_soft = torch.softmax(curr_entity_logits, dim=2) diag_entity_mask = torch.zeros_like(entity_mask, dtype=torch.bool).to(self._device).fill_diagonal_(1) all_entity_logits.append(curr_entity_logits) # Relation classification. num_steps = gold_entity[batch].shape[-1] word_h = h[batch].repeat(token_mask[batch].shape[0], 1, 1) * token_mask[batch].unsqueeze(-1) word_h_pooled = word_h.max(dim=1)[0] word_h_pooled = word_h_pooled[:num_steps+2].contiguous() # curr word repr. curr_word_repr = word_h_pooled[1:-1].contiguous() # curr_rel_logits = self._forward_relation(curr_word_repr, entity_preds_soft , diag_entity_mask) # curr_rel_logits = self._forward_relation(curr_word_repr, entity_preds.squeeze(0) , diag_entity_mask) curr_rel_logits = self._forward_relation(curr_word_repr, gold_entity[batch] , entity_masks[batch]) all_rel_logits.append(curr_rel_logits) if allow_rel: return all_entity_logits, all_rel_logits else: return all_entity_logits, [] def _forward_eval(self, encodings: torch.tensor, context_mask: torch.tensor, token_mask: torch.tensor, gold_entity: List[torch.tensor], gold_entity_mask: List[torch.tensor]): context_mask = context_mask.float() h = self.bert(input_ids=encodings, attention_mask=context_mask)[0] + 1 batch_size = encodings.shape[0] all_entity_logits = [] all_entity_scores = [] all_entity_preds = [] all_rel_logits = [] for batch in range(batch_size): # every batch num_steps = token_mask[batch].sum(axis=1).nonzero().shape[0] - 2 word_h = h[batch].repeat(token_mask[batch].shape[0], 1, 1) * token_mask[batch].unsqueeze(-1) word_h_pooled = word_h.max(dim=1)[0] word_h_pooled = word_h_pooled[:num_steps+2].contiguous() word_h_pooled[0,:] = 0 # curr word repr. curr_word_reprs = word_h_pooled[1:-1].contiguous() entity_masks = torch.zeros((num_steps + 2, num_steps + 2), dtype = torch.bool).fill_diagonal_(1).to(self._device) # diag_entity_mask = torch.zeros((num_steps + 2, num_steps + 2), dtype = torch.bool).fill_diagonal_(1).to(self._device) entity_preds = torch.zeros((num_steps + 1, 1), dtype=torch.long).to(self._device) entity_logits = [] entity_scores = torch.zeros((num_steps, 1), dtype=torch.float).to(self._device) # Entity classification. for i in range(num_steps): # no [CLS], no [SEP] # curr word repr. 
curr_word_repr = curr_word_reprs[i].unsqueeze(0) # mask from previous entity token until current position. prev_mask = entity_masks[i, :] prev_label_repr = self.entity_label_embedding(entity_preds[i]) prev_entity = word_h_pooled.unsqueeze(0) * prev_mask.unsqueeze(-1) prev_entity_pooled = prev_entity.max(dim=1)[0] curr_entity_repr = torch.cat([curr_word_repr - 1, prev_entity_pooled - 1, prev_label_repr], dim=1).unsqueeze(0) curr_entity_logits = self.entity_classifier(curr_entity_repr) entity_logits.append(curr_entity_logits.squeeze(1)) curr_label = curr_entity_logits.argmax(dim=2).squeeze(0) # print(i, curr_entity_logits, torch.softmax(curr_entity_logits, dim=2)) entity_scores[i] += torch.softmax(curr_entity_logits, dim=2).max(dim=2)[0].squeeze(0) entity_preds[i+1] = curr_label istart = (curr_label % 4 == 1) | (curr_label % 4 == 2) | (curr_label == 0) # update entity mask for the next time step entity_masks[i+1] += (~istart) * prev_mask # update entity span info for all time-steps entity_masks[prev_mask.nonzero()[0].item():i+1, i+1] += (~istart).squeeze(0) all_entity_logits.append(torch.stack(entity_logits, dim=1)) all_entity_scores.append(torch.t(entity_scores.squeeze(-1))) all_entity_preds.append(torch.t(entity_preds[1:].squeeze(-1))) # print(entity_preds.shape) # print(gold_entity[batch]) # exit(0) # Relation classification. # curr_rel_logits = self._forward_relation(curr_word_reprs, torch.stack(entity_logits, dim=1), entity_masks , True) curr_rel_logits = self._forward_relation(curr_word_reprs, entity_preds[1:].squeeze(-1), entity_masks, True) # curr_rel_logits = self._forward_relation(curr_word_reprs, gold_entity[batch], gold_entity_mask[batch], True) all_rel_logits.append(curr_rel_logits) return all_entity_logits, all_entity_scores, all_entity_preds, all_rel_logits def forward(self, *args, evaluate=False, **kwargs): if not evaluate: return self._forward_train(*args, **kwargs) else: return self._forward_eval(*args, **kwargs)
class SynFueBERT(BertPreTrainedModel): """ Span-based model to jointly extract terms and relations """ VERSION = '1.1' def __init__(self, config: BertConfig, cls_token: int, relation_types: int, term_types: int, size_embedding: int, prop_drop: float, freeze_transformer: bool, args, max_pairs: int = 100, beta: float = 0.3, alpha: float = 1.0, sigma: float = 1.0): super(SynFueBERT, self).__init__(config) # BERT model self.bert = BertModel(config) self.SynFue = Encoder.SynFueEncoder(self.bert, opt=args) self.cc = cross_attn.CA_module(config.hidden_size, config.hidden_size, 1, dropout=1.0) # layers self.rel_classifier = nn.Linear( config.hidden_size * 6 + size_embedding * 2, relation_types) self.rel_classifier3 = nn.Linear( config.hidden_size * 6 + size_embedding * 3, relation_types) self.term_classifier = nn.Linear( config.hidden_size * 8 + size_embedding, term_types) self.dep_linear = nn.Linear(config.hidden_size, relation_types) self.size_embeddings = nn.Embedding(100, size_embedding) self.dropout = nn.Dropout(prop_drop) self._cls_token = cls_token self._relation_types = relation_types self._term_types = term_types self._max_pairs = max_pairs self._beta = beta self._alpha = alpha self._sigma = sigma # weight initialization self.init_weights() if freeze_transformer: print("Freeze transformer weights") # freeze all transformer weights for param in self.bert.parameters(): param.requires_grad = False def _forward_train(self, encodings: torch.tensor, context_masks: torch.tensor, term_masks: torch.tensor, term_sizes: torch.tensor, term_spans: torch.tensor, term_types: torch.tensor, relations: torch.tensor, rel_masks: torch.tensor, simple_graph: torch.tensor, graph: torch.tensor, relations3: torch.tensor, rel_masks3: torch.tensor, pair_mask: torch.tensor, pos: torch.tensor = None): # get contextualized token embeddings from last transformer layer context_masks = context_masks.float() h, dep_output = self.SynFue(input_ids=encodings, input_masks=context_masks, simple_graph=simple_graph, graph=graph, pos=pos) batch_size = encodings.shape[0] # classify terms size_embeddings = self.size_embeddings( term_sizes) # embed term candidate sizes term_clf, term_spans_pool = self._classify_terms( encodings, h, term_masks, size_embeddings) # classify relations h_large = h.unsqueeze(1).repeat( 1, max(min(relations.shape[1], self._max_pairs), 1), 1, 1) rel_clf = torch.zeros( [batch_size, relations.shape[1], self._relation_types]).to(self.rel_classifier.weight.device) # get span representation # dep_output = [batch size, seq_len, seq_len, feat_dim] -> [batch size, span num, span num, feat_dim] span_repr, mapping_list = self.get_span_repr(term_spans, term_types, dep_output) cross_attn_span = self.cc( span_repr) # batch size, seq_len, seq_len, feat_dim # obtain relation logits # chunk processing to reduce memory usage for i in range(0, relations.shape[1], self._max_pairs): # classify relation candidates chunk_rel_logits, chunk_rel_clf3, chunk_dep_score = self._classify_relations( cross_attn_span, term_spans_pool, size_embeddings, relations, rel_masks, h_large, i, relations3, rel_masks3, pair_mask, mapping_list) # apply sigmoid chunk_rel_clf = torch.sigmoid(chunk_rel_logits) chunk_rel_clf = self._alpha * chunk_rel_clf + self._beta * chunk_rel_clf3 + self._sigma * chunk_dep_score rel_clf[:, i:i + self._max_pairs, :] = chunk_rel_clf max_clf = torch.full_like(rel_clf, torch.max(rel_clf).item()) min_clf = torch.full_like(rel_clf, torch.min(rel_clf).item()) inifite = torch.full_like(rel_clf, 1e-18) rel_clf = 
torch.div(rel_clf - min_clf + inifite, max_clf - min_clf + inifite) return term_clf, rel_clf def _forward_eval(self, encodings: torch.tensor, context_masks: torch.tensor, term_masks: torch.tensor, term_sizes: torch.tensor, term_spans: torch.tensor, term_sample_masks: torch.tensor, simple_graph: torch.tensor, graph: torch.tensor, pos: torch.tensor = None): # get contextualized token embeddings from last transformer layer context_masks = context_masks.float() h, dep_output = self.SynFue(input_ids=encodings, input_masks=context_masks, simple_graph=simple_graph, graph=graph, pos=pos) batch_size = encodings.shape[0] ctx_size = context_masks.shape[-1] # classify terms size_embeddings = self.size_embeddings( term_sizes) # embed term candidate sizes term_clf, term_spans_pool = self._classify_terms( encodings, h, term_masks, size_embeddings) # ignore term candidates that do not constitute an actual term for relations (based on classifier) relations, rel_masks, rel_sample_masks, relations3, rel_masks3, \ rel_sample_masks3, pair_mask, span_repr, mapping_list = self._filter_spans(term_clf, term_spans, term_sample_masks, ctx_size, dep_output) rel_sample_masks = rel_sample_masks.float().unsqueeze(-1) # h = self.rel_bert(input_ids=encodings, attention_mask=context_masks)[0] h_large = h.unsqueeze(1).repeat( 1, max(min(relations.shape[1], self._max_pairs), 1), 1, 1) rel_clf = torch.zeros( [batch_size, relations.shape[1], self._relation_types]).to(self.rel_classifier.weight.device) # get span representation cross_attn_span = self.cc( span_repr) # batch size, seq_len, seq_len, feat_dim # obtain relation logits # chunk processing to reduce memory usage for i in range(0, relations.shape[1], self._max_pairs): # classify relation candidates chunk_rel_logits, chunk_rel_clf3, chunk_dep_score = self._classify_relations( cross_attn_span, term_spans_pool, size_embeddings, relations, rel_masks, h_large, i, relations3, rel_masks3, pair_mask, mapping_list) # apply sigmoid chunk_rel_clf = torch.sigmoid(chunk_rel_logits) chunk_rel_clf = self._alpha * chunk_rel_clf + self._beta * chunk_rel_clf3 + self._sigma * chunk_dep_score rel_clf[:, i:i + self._max_pairs, :] = chunk_rel_clf max_clf = torch.full_like(rel_clf, torch.max(rel_clf).item()) min_clf = torch.full_like(rel_clf, torch.min(rel_clf).item()) inifite = torch.full_like(rel_clf, 1e-18) rel_clf = torch.div(rel_clf - min_clf + inifite, max_clf - min_clf + inifite) rel_clf = rel_clf * rel_sample_masks # mask # apply softmax term_clf = torch.softmax(term_clf, dim=2) return term_clf, rel_clf, relations def _classify_terms(self, encodings, h, term_masks, size_embeddings): # max pool term candidate spans m = (term_masks.unsqueeze(-1) == 0).float() * (-1e30) term_spans_pool = m + h.unsqueeze(1).repeat(1, term_masks.shape[1], 1, 1) term_spans_pool = term_spans_pool.max(dim=2)[0] # get cls token as candidate context representation term_ctx = get_token(h, encodings, self._cls_token) # get head and tail token representation m = term_masks.to(dtype=torch.long) k = torch.tensor(np.arange(0, term_masks.size(-1)), dtype=torch.long) k = k.unsqueeze(0).unsqueeze(0).repeat(term_masks.size(0), term_masks.size(1), 1).to(m.device) mk = torch.mul(m, k) # element-wise multiply mk_max = torch.argmax(mk, dim=-1, keepdim=True) mk_min = torch.argmin(mk, dim=-1, keepdim=True) mk = torch.cat([mk_min, mk_max], dim=-1) head_tail_rep = get_head_tail_rep( h, mk) # [batch size, term_num, bert_dim*2) # create candidate representations including context, max pooled span and size embedding term_repr = 
torch.cat([
            term_ctx.unsqueeze(1).repeat(1, term_spans_pool.shape[1], 1),
            term_spans_pool, size_embeddings, head_tail_rep
        ], dim=2)
        term_repr = self.dropout(term_repr)

        # classify term candidates
        term_clf = self.term_classifier(term_repr)

        return term_clf, term_spans_pool

    def _classify_relations(self, spans_matrix, term_spans_repr, size_embeddings,
                            relations, rel_masks, h, chunk_start, relations3,
                            rel_masks3, pair_mask, rel_to_span):
        batch_size = relations.shape[0]
        feat_dim = spans_matrix.size(-1)

        # create chunks if necessary
        if relations.shape[1] > self._max_pairs:
            relations = relations[:, chunk_start:chunk_start + self._max_pairs]
            rel_masks = rel_masks[:, chunk_start:chunk_start + self._max_pairs]
            h = h[:, :relations.shape[1], :]

        def get_span_idx(mapping_list, idx1, idx2):
            for x in mapping_list:
                if idx1 == x[0][0] and idx2 == x[0][1]:
                    return x[1][0], x[1][1]

        # score every relation candidate with its dependency-based span feature
        batch_dep_score = []
        for i in range(batch_size):
            rela = relations[i]
            dep_score_list = []
            r_2_s = rel_to_span[i]
            for r in rela:
                i1, i2 = r[0].item(), r[1].item()
                idx1, idx2 = get_span_idx(r_2_s, i1, i2)
                try:
                    feat = spans_matrix[i][idx1][idx2]
                except (IndexError, TypeError):
                    print('Out of boundary', spans_matrix.size(), i, i1, i2)
                    feat = torch.zeros(feat_dim, device=spans_matrix.device)
                dep_score = self.dep_linear(feat).item()
                dep_score_list.append([dep_score])
            batch_dep_score.append(dep_score_list)
        batch_dep_score = torch.sigmoid(
            torch.tensor(batch_dep_score).to(device=torch.device(
                'cuda' if torch.cuda.is_available() else 'cpu')))

        # get pairs of term candidate representations
        term_pairs = util.batch_index(term_spans_repr, relations)
        term_pairs = term_pairs.view(batch_size, term_pairs.shape[1], -1)

        # get corresponding size embeddings
        size_pair_embeddings = util.batch_index(size_embeddings, relations)
        size_pair_embeddings = size_pair_embeddings.view(
            batch_size, size_pair_embeddings.shape[1], -1)

        # relation context (context between term candidate pair)
        # mask non term candidate tokens
        m = ((rel_masks == 0).float() * (-1e30)).unsqueeze(-1)
        rel_ctx = m + h
        # max pooling
        rel_ctx = rel_ctx.max(dim=2)[0]
        # set the context vector of neighboring or adjacent term candidates to zero
        rel_ctx[rel_masks.to(torch.uint8).any(-1) == 0] = 0

        # create relation candidate representations including context, max pooled
        # term candidate pairs and corresponding size embeddings
        rel_repr = torch.cat([rel_ctx, term_pairs, size_pair_embeddings], dim=2)
        rel_repr = self.dropout(rel_repr)

        # classify relation candidates
        chunk_rel_logits = self.rel_classifier(rel_repr)

        if relations3.shape[1] > self._max_pairs:
            relations3 = relations3[:, chunk_start:chunk_start + self._max_pairs]
            # rel_masks3 = rel_masks3[:, chunk_start:chunk_start + self._max_pairs]
        p_num = relations3.size(1)
        p_tris = relations3.size(2)
        relations3 = relations3.view(batch_size, -1, 3)

        # get triple (pair + third span) candidate representations
        term_pairs3 = util.batch_index(term_spans_repr, relations3)
        term_pairs3 = term_pairs3.view(batch_size, term_pairs3.shape[1], -1)
        size_pair_embeddings3 = util.batch_index(size_embeddings, relations3)
        size_pair_embeddings3 = size_pair_embeddings3.view(
            batch_size, size_pair_embeddings3.shape[1], -1)

        rel_repr = torch.cat([term_pairs3, size_pair_embeddings3], dim=2)
        rel_repr = self.dropout(rel_repr)

        # classify relation candidates
        chunk_rel_logits3 = self.rel_classifier3(rel_repr)
        chunk_rel_clf3 = chunk_rel_logits3.view(batch_size, p_num, p_tris, -1)
        chunk_rel_clf3 = torch.sigmoid(chunk_rel_clf3)
        chunk_rel_clf3 = torch.sum(chunk_rel_clf3, dim=2)
        chunk_rel_clf3 = torch.sigmoid(chunk_rel_clf3)

        return chunk_rel_logits, chunk_rel_clf3,
batch_dep_score def _filter_spans(self, term_clf, term_spans, term_sample_masks, ctx_size, token_repr): batch_size = term_clf.shape[0] feat_dim = token_repr.size(-1) term_logits_max = term_clf.argmax(dim=-1) * term_sample_masks.long( ) # get term type (including none) batch_relations = [] batch_rel_masks = [] batch_rel_sample_masks = [] batch_relations3 = [] batch_rel_masks3 = [] batch_rel_sample_masks3 = [] batch_pair_mask = [] batch_span_repr = [] batch_rel_to_span = [] for i in range(batch_size): rels = [] rel_masks = [] sample_masks = [] rels3 = [] rel_masks3 = [] sample_masks3 = [] span_repr = [] rel_to_span = [] # get spans classified as terms non_zero_indices = (term_logits_max[i] != 0).nonzero().view(-1) non_zero_spans = term_spans[i][non_zero_indices].tolist() non_zero_indices = non_zero_indices.tolist() # create relations and masks pair_mask = [] for idx1, (i1, s1) in enumerate(zip(non_zero_indices, non_zero_spans)): temp = [] for idx2, (i2, s2) in enumerate( zip(non_zero_indices, non_zero_spans)): if i1 != i2: rels.append((i1, i2)) rel_masks.append( sampling.create_rel_mask(s1, s2, ctx_size)) sample_masks.append(1) p_rels3 = [] p_masks3 = [] for i3, s3 in zip(non_zero_indices, non_zero_spans): if i1 != i2 and i1 != i3 and i2 != i3: p_rels3.append((i1, i2, i3)) p_masks3.append( sampling.create_rel_mask3( s1, s2, s3, ctx_size)) sample_masks3.append(1) if len(p_rels3) > 0: rels3.append(p_rels3) rel_masks3.append(p_masks3) pair_mask.append(1) else: rels3.append([(i1, i2, 0)]) rel_masks3.append([ sampling.create_rel_mask3( s1, s2, (0, 0), ctx_size) ]) pair_mask.append(0) rel_to_span.append([[i1, i2], [idx1, idx2]]) feat = torch.max( token_repr[i, s1[0]:s1[-1] + 1, s2[0]:s2[-1] + 1, :].contiguous().view( -1, feat_dim), dim=0)[0] temp.append(feat) span_repr.append(temp) if not rels: # case: no more than two spans classified as terms batch_relations.append(torch.tensor([[0, 0]], dtype=torch.long)) batch_rel_masks.append( torch.tensor([[0] * ctx_size], dtype=torch.bool)) batch_rel_sample_masks.append( torch.tensor([0], dtype=torch.bool)) batch_span_repr.append( torch.tensor([[[0] * feat_dim]], dtype=torch.float)) batch_rel_to_span.append([[[0, 0], [0, 0]]]) else: # case: more than two spans classified as terms batch_relations.append(torch.tensor(rels, dtype=torch.long)) batch_rel_masks.append(torch.stack(rel_masks)) batch_rel_sample_masks.append( torch.tensor(sample_masks, dtype=torch.bool)) batch_span_repr.append( torch.stack([torch.stack(x) for x in span_repr])) batch_rel_to_span.append(rel_to_span) if not rels3: batch_relations3.append( torch.tensor([[[0, 0, 0]]], dtype=torch.long)) batch_rel_masks3.append( torch.tensor([[0] * ctx_size], dtype=torch.bool)) batch_rel_sample_masks3.append( torch.tensor([0], dtype=torch.bool)) batch_pair_mask.append(torch.tensor([0], dtype=torch.bool)) else: max_tri = max([len(x) for x in rels3]) # print(max_tri) for idx, r in enumerate(rels3): r_len = len(r) if r_len < max_tri: rels3[idx].extend([rels3[idx][0]] * (max_tri - r_len)) rel_masks3[idx].extend([rel_masks3[idx][0]] * (max_tri - r_len)) batch_relations3.append(torch.tensor(rels3, dtype=torch.long)) batch_rel_masks3.append( torch.stack([torch.stack(x) for x in rel_masks3])) batch_rel_sample_masks3.append( torch.tensor(sample_masks3, dtype=torch.bool)) batch_pair_mask.append( torch.tensor(pair_mask, dtype=torch.bool)) # stack device = self.rel_classifier.weight.device batch_relations = util.padded_stack(batch_relations).to(device) batch_rel_masks = util.padded_stack(batch_rel_masks).to(device) 
batch_rel_sample_masks = util.padded_stack(batch_rel_sample_masks).to( device) batch_span_repr = util.padded_stack(batch_span_repr).to(device) batch_relations3 = util.padded_stack(batch_relations3).to(device) batch_rel_masks3 = util.padded_stack(batch_rel_masks3).to(device) batch_rel_sample_masks3 = util.padded_stack( batch_rel_sample_masks3).to(device) batch_pair_mask = util.padded_stack(batch_pair_mask).to(device) return batch_relations, batch_rel_masks, batch_rel_sample_masks, \ batch_relations3, batch_rel_masks3, batch_rel_sample_masks3, batch_pair_mask, batch_span_repr, batch_rel_to_span def get_span_repr(self, term_spans, term_types, token_repr): """ :param term_spans: [batch size, span_num, 2] :param term_types: [batch size, span_num] :param token_repr: [batch size, seq_len, seq_len, feat_dim] :return: [batch size, span_num, span_num, feat_dim] """ batch_size = term_spans.size(0) feat_dim = token_repr.size(-1) batch_span_repr = [] batch_mapping_list = [] for i in range(batch_size): span_repr = [] mapping_list = [] # get target spans as aspect term or opinion term non_zero_indices = (term_types[i] != 0).nonzero().view(-1) non_zero_spans = term_spans[i][non_zero_indices].tolist() non_zero_indices = non_zero_indices.tolist() for x1, (i1, s1) in enumerate(zip(non_zero_indices, non_zero_spans)): temp = [] for x2, (i2, s2) in enumerate(zip(non_zero_indices, non_zero_spans)): feat = torch.max( token_repr[i, s1[0]:s1[-1] + 1, s2[0]:s2[-1] + 1, :].contiguous().view( -1, feat_dim), dim=0)[0] temp.append(feat) mapping_list.append([[i1, i2], [x1, x2]]) span_repr.append(torch.stack(temp)) batch_span_repr.append(torch.stack(span_repr)) batch_mapping_list.append(mapping_list) device = self.rel_classifier.weight.device batch_span_repr = util.padded_stack(batch_span_repr).to(device) return batch_span_repr, batch_mapping_list def forward(self, *args, evaluate=False, **kwargs): if not evaluate: return self._forward_train(*args, **kwargs) else: return self._forward_eval(*args, **kwargs)
class DocumentBert(BertPreTrainedModel): def __init__(self, bert_model_config: BertConfig): super(DocumentBert, self).__init__(bert_model_config) self.bert_patent = BertModel(bert_model_config) self.bert_tsd = BertModel(bert_model_config) for param in self.bert_patent.parameters(): param.requires_grad = False for param in self.bert_tsd.parameters(): param.requires_grad = False self.bert_batch_size = self.bert_patent.config.bert_batch_size self.dropout_patent = torch.nn.Dropout( p=bert_model_config.hidden_dropout_prob) self.dropout_tsd = torch.nn.Dropout( p=bert_model_config.hidden_dropout_prob) self.lstm_patent = torch.nn.LSTM(bert_model_config.hidden_size, bert_model_config.hidden_size) self.lstm_tsd = torch.nn.LSTM(bert_model_config.hidden_size, bert_model_config.hidden_size) self.output = torch.nn.Linear(bert_model_config.hidden_size * 2, out_features=1) def forward(self, patent_batch: torch.Tensor, tsd_batch: torch.Tensor, device='cuda'): #patent bert_output_patent = torch.zeros( size=(patent_batch.shape[0], min(patent_batch.shape[1], self.bert_batch_size), self.bert_patent.config.hidden_size), dtype=torch.float, device=device) for doc_id in range(patent_batch.shape[0]): bert_output_patent[ doc_id][:self.bert_batch_size] = self.dropout_patent( self.bert_patent( patent_batch[doc_id][:self.bert_batch_size, 0], token_type_ids=patent_batch[doc_id] [:self.bert_batch_size, 1], attention_mask=patent_batch[doc_id] [:self.bert_batch_size, 2])[1]) output_patent, (_, _) = self.lstm_patent( bert_output_patent.permute(1, 0, 2)) last_layer_patent = output_patent[-1] #tsd bert_output_tsd = torch.zeros(size=(tsd_batch.shape[0], min(tsd_batch.shape[1], self.bert_batch_size), self.bert_tsd.config.hidden_size), dtype=torch.float, device=device) for doc_id in range(tsd_batch.shape[0]): bert_output_tsd[doc_id][:self.bert_batch_size] = self.dropout_tsd( self.bert_tsd( tsd_batch[doc_id][:self.bert_batch_size, 0], token_type_ids=tsd_batch[doc_id][:self.bert_batch_size, 1], attention_mask=tsd_batch[doc_id][:self.bert_batch_size, 2])[1]) output_tsd, (_, _) = self.lstm_tsd(bert_output_tsd.permute(1, 0, 2)) last_layer_tsd = output_tsd[-1] x = torch.cat([last_layer_patent, last_layer_tsd], dim=1) prediction = torch.nn.functional.sigmoid(self.output(x)) assert prediction.shape[0] == patent_batch.shape[0] return prediction def freeze_bert_encoder(self): for param in self.bert_patent.parameters(): param.requires_grad = False for param in self.bert_tsd.parameters(): param.requires_grad = False def unfreeze_bert_encoder(self): for param in self.bert_patent.parameters(): param.requires_grad = True for param in self.bert_tsd.parameters(): param.requires_grad = True def unfreeze_bert_encoder_last_layers(self): for name, param in self.bert_patent.named_parameters(): if "encoder.layer.11" in name or "pooler" in name: param.requires_grad = True for name, param in self.bert_tsd.named_parameters(): if "encoder.layer.11" in name or "pooler" in name: param.requires_grad = True def unfreeze_bert_encoder_pooler_layer(self): for name, param in self.bert_patent.named_parameters(): if "pooler" in name: param.requires_grad = True for name, param in self.bert_tsd.named_parameters(): if "pooler" in name: param.requires_grad = True
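# Hypothetical smoke test for DocumentBert (not part of the original code): the tiny
# config values, the extra `bert_batch_size` attribute and the all-zero inputs are
# assumptions, made only to illustrate the expected batch layout
# (num_docs, num_chunks, 3, seq_len), where dim 2 packs input_ids / token_type_ids /
# attention_mask.
def _document_bert_smoke_test():
    cfg = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2,
                     num_attention_heads=2, intermediate_size=64)
    cfg.bert_batch_size = 2                      # non-standard field read in __init__
    model = DocumentBert(cfg)
    patent = torch.zeros(1, 2, 3, 16, dtype=torch.long)
    tsd = torch.zeros(1, 2, 3, 16, dtype=torch.long)
    patent[:, :, 2] = 1                          # attention masks: all ones
    tsd[:, :, 2] = 1
    with torch.no_grad():
        prediction = model(patent, tsd, device='cpu')
    print(prediction.shape)                      # expected: torch.Size([1, 1])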
class Bert(nn.Module):
    def __init__(self, config, num=0):
        super(Bert, self).__init__()
        model_config = BertConfig()
        model_config.vocab_size = config.vocab_size

        # how the loss is computed
        self.loss_method = config.loss_method
        self.multi_drop = config.multi_drop
        # number of output classes
        self.num_labels = config.num_labels

        self.bert = BertModel(model_config)
        if config.requires_grad:
            for param in self.bert.parameters():
                param.requires_grad = True
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.hidden_size = config.hidden_size[num]

        if self.loss_method in ['binary', 'focal_loss', 'ghmc']:
            self.classifier = nn.Linear(self.hidden_size, 1)
        else:
            self.classifier = nn.Linear(self.hidden_size, self.num_labels)

        self.classifier.apply(self._init_weights)
        self.bert.apply(self._init_weights)

    def _init_weights(self, module):
        """ Initialize the weights """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version, which uses truncated_normal for initialization
            # cf. https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=0.02)

    def forward(self,
                input_ids=None,
                token_type_ids=None,
                attention_mask=None,
                labels=None):
        outputs = self.bert(input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids)
        pooled_output = outputs[1]

        # Multi-sample dropout: average the logits (and, when labels are given,
        # the losses) over `multi_drop` independent dropout passes.
        out = None
        loss = 0
        for _ in range(self.multi_drop):
            logits = self.classifier(self.dropout(pooled_output))
            out = logits if out is None else out + logits
            if labels is not None:
                loss = loss + compute_loss(logits, labels, loss_method=self.loss_method)

        out = out / self.multi_drop
        if labels is not None:
            loss = loss / self.multi_drop

        if self.loss_method in ['binary']:
            out = torch.sigmoid(out).flatten()
        return out, loss
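# `Bert` above delegates to an external `compute_loss` helper that is not shown in
# this file. The function below is only a sketch of what such a helper could look like
# for the 'binary' and plain cross-entropy cases (the focal_loss / ghmc branches are
# omitted); it is an assumption, not the original implementation.
def compute_loss_sketch(logits, labels, loss_method='binary'):
    if loss_method == 'binary':
        # single-logit head: labels are float 0/1 of shape (batch,)
        return nn.functional.binary_cross_entropy_with_logits(
            logits.view(-1), labels.float().view(-1))
    # multi-class head: labels are integer class ids of shape (batch,)
    return nn.functional.cross_entropy(logits, labels.long().view(-1))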
class MyBertForSequenceClassification(BertPreTrainedModel):
    num_labels = 4
    num_tasks = 20

    def __init__(self, config):
        super(MyBertForSequenceClassification, self).__init__(config)
        self.num_labels = MyBertForSequenceClassification.num_labels
        self.num_tasks = MyBertForSequenceClassification.num_tasks

        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Create 20 classification heads that share the same input: the final-layer
        # [CLS] pooler_output of BertModel.
        # Note (also stated upstream): the [CLS] pooler_output is usually *not* a good
        # summary of the semantic content of the input; averaging or pooling the
        # sequence of hidden states over the whole input often works better.
        self.classifier = nn.ModuleList([
            nn.Linear(config.hidden_size, self.num_labels)
            for _ in range(self.num_tasks)
        ])

        self.init_weights()

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
        """forward

        :param input_ids:
        :param labels: given as a tensor of shape [batch, num_tasks]
        """
        outputs = self.bert(input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids)
        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output)

        logits = [
            self.classifier[i](pooled_output) for i in range(self.num_tasks)
        ]

        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            # The loss accumulator must live on the GPU; this is easy to forget and
            # then loss.backward() fails.
            loss = torch.tensor([0.]).to(device)
            for i in range(self.num_tasks):
                loss += loss_fct(logits[i], labels[:, i])
            return loss
        else:
            # Used to predict labels for the dev and test sets;
            # the result has shape [num_tasks, batch, num_labels].
            logits = [logit.cpu().numpy() for logit in logits]
            return torch.tensor(logits)

    # Freezing the BertModel parameters is optional; for multi-label classification we
    # do not freeze them (simply never call freeze_bert_encoder). The helpers below
    # show how to freeze and unfreeze the encoder.
    def freeze_bert_encoder(self):
        for param in self.bert.parameters():
            param.requires_grad = False

    def unfreeze_bert_encoder(self):
        for param in self.bert.parameters():
            param.requires_grad = True
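# Hypothetical usage of MyBertForSequenceClassification with a tiny random config
# (all sizes below are assumptions). Only the label-free prediction path is exercised;
# the training path additionally needs the module-level `device` and labels of shape
# [batch, num_tasks] with class ids in [0, num_labels).
def _multi_task_prediction_example():
    tiny_cfg = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2,
                          num_attention_heads=2, intermediate_size=64)
    model = MyBertForSequenceClassification(tiny_cfg)
    input_ids = torch.randint(0, 100, (4, 12))
    preds = model(input_ids)
    print(preds.shape)    # expected: torch.Size([20, 4, 4]) = [num_tasks, batch, num_labels]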
class SpET(BertPreTrainedModel): """ Span-based model to extract entities """ def __init__(self, config: BertConfig, cls_token: int, relation_types: int, entity_types: int, size_embedding: int, prop_drop: float, freeze_transformer: bool, max_pairs: int = 100, feature_enhancer: str = "pass"): super(SpET, self).__init__(config) # BERT model self.bert = BertModel(config) # layers self.feature_enhancer = fe.get_feature_enhancer(feature_enhancer)( config.hidden_size, config.hidden_size) self.entity_classifier = nn.Linear( config.hidden_size * 2 + size_embedding, entity_types) self.size_embeddings = nn.Embedding(100, size_embedding) self.dropout = nn.Dropout(prop_drop) self._cls_token = cls_token self._entity_types = entity_types self._max_pairs = max_pairs # weight initialization self.init_weights() if freeze_transformer or feature_enhancer not in { "pass", "transformer" }: print("Freeze transformer weights") # freeze all transformer weights for param in self.bert.parameters(): param.requires_grad = False def _forward_train(self, encodings: torch.tensor, context_mask: torch.tensor, entity_masks: torch.tensor, entity_sizes: torch.tensor): # get contextualized token embeddings from last transformer layer context_mask = context_mask.float() h_bert = self.bert(input_ids=encodings, attention_mask=context_mask)[0] # enhance hidden features orig_shape = h_bert.shape h = self.feature_enhancer.prepare_input(h_bert, context_mask) h = self.feature_enhancer(h) h = self.feature_enhancer.prepare_output(h, orig_shape) entity_masks = entity_masks.float() batch_size = encodings.shape[0] # classify entities size_embeddings = self.size_embeddings( entity_sizes) # embed entity candidate sizes entity_clf, entity_spans_pool = self._classify_entities( encodings, h, entity_masks, size_embeddings) return entity_clf def _forward_eval(self, encodings: torch.tensor, context_mask: torch.tensor, entity_masks: torch.tensor, entity_sizes: torch.tensor, entity_spans: torch.tensor = None, entity_sample_mask: torch.tensor = None): # get contextualized token embeddings from last transformer layer context_mask = context_mask.float() h = self.bert(input_ids=encodings, attention_mask=context_mask)[0] # enhance hidden features orig_shape = h.shape h = self.feature_enhancer.prepare_input(h, context_mask) h = self.feature_enhancer(h) h = self.feature_enhancer.prepare_output(h, orig_shape) entity_masks = entity_masks.float() batch_size = encodings.shape[0] ctx_size = context_mask.shape[-1] # classify entities size_embeddings = self.size_embeddings( entity_sizes) # embed entity candidate sizes entity_clf, entity_spans_pool = self._classify_entities( encodings, h, entity_masks, size_embeddings) # apply softmax entity_clf = torch.softmax(entity_clf, dim=2) return entity_clf def _classify_entities(self, encodings, h, entity_masks, size_embeddings): # max pool entity candidate spans entity_spans_pool = entity_masks.unsqueeze(-1) * h.unsqueeze(1) entity_spans_pool = entity_spans_pool.max(dim=2)[0] # get cls token as candidate context representation entity_ctx = get_token(h, encodings, self._cls_token) # create candidate representations including context, max pooled span and size embedding entity_repr = torch.cat([ entity_ctx.unsqueeze(1).repeat(1, entity_spans_pool.shape[1], 1), entity_spans_pool, size_embeddings ], dim=2) entity_repr = self.dropout(entity_repr) # classify entity candidates entity_clf = self.entity_classifier(entity_repr) return entity_clf, entity_spans_pool def forward(self, *args, evaluate=False, **kwargs): if not evaluate: 
return self._forward_train(*args, **kwargs) else: return self._forward_eval(*args, **kwargs)
class ExampleIntentBertModel(torch.nn.Module): def __init__(self, model_name_or_path: str, dropout: float, num_intent_labels: int, use_observers: bool = False): super(ExampleIntentBertModel, self).__init__() #self.bert_model = BertModel.from_pretrained(model_name_or_path) self.bert_model = BertModel( BertConfig.from_pretrained(model_name_or_path, output_attentions=True)) self.dropout = Dropout(dropout) self.num_intent_labels = num_intent_labels self.use_observers = use_observers self.all_outputs = [] def encode(self, input_ids: torch.tensor, attention_mask: torch.tensor, token_type_ids: torch.tensor): extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze( 2).repeat(1, 1, input_ids.size(1), 1) extended_attention_mask = extended_attention_mask.to( dtype=next(self.bert_model.parameters()).dtype) # Combine attention maps padding = (input_ids.unsqueeze(1) == 0).unsqueeze(-1) padding = padding.repeat(1, 1, 1, padding.size(-2)) extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 embedding_output = self.bert_model.embeddings( input_ids, position_ids=None, token_type_ids=token_type_ids) encoder_outputs = self.bert_model.encoder( embedding_output, extended_attention_mask, head_mask=[None] * self.bert_model.config.num_hidden_layers) if encoder_outputs[0].size(0) == 1: pass #self.all_outputs.append(torch.cat(encoder_outputs[1], dim=0).cpu()) #self.all_outputs.append(encoder_outputs[0][:, -20:].cpu()) sequence_output = encoder_outputs[0] if self.use_observers: pooled_output = sequence_output[:, -20:].mean(dim=1) else: pooled_output = self.bert_model.pooler(sequence_output) return pooled_output def forward(self, input_ids: torch.tensor, attention_mask: torch.tensor, token_type_ids: torch.tensor, intent_label: torch.tensor, example_input: torch.tensor, example_mask: torch.tensor, example_token_types: torch.tensor, example_intents: torch.tensor): example_pooled_output = self.encode(input_ids=example_input, attention_mask=example_mask, token_type_ids=example_token_types) pooled_output = self.encode(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids) pooled_output = self.dropout(pooled_output) probs = torch.softmax(pooled_output.mm(example_pooled_output.t()), dim=-1) intent_probs = 1e-6 + torch.zeros( probs.size(0), self.num_intent_labels).cuda().scatter_add( -1, example_intents.unsqueeze(0).repeat(probs.size(0), 1), probs) # Compute losses if labels provided if intent_label is not None: loss_fct = NLLLoss() intent_lp = torch.log(intent_probs) intent_loss = loss_fct(intent_lp.view(-1, self.num_intent_labels), intent_label.type(torch.long)) else: intent_loss = torch.tensor(0) return intent_probs, intent_loss
class BertClassification(BertPreTrainedModel): def __init__(self, config, freeze_bert = False): super(BertClassification, self).__init__(config) self.hidden_size = config.hidden_size #self.lstm_hidden_size = 256 self.hidden_dropout_prob = config.hidden_dropout_prob self.num_labels = config.num_labels self.bert = BertModel(config) if freeze_bert: for p in self.bert.parameters(): p.requires_grad = False self.dropout = nn.Dropout(self.hidden_dropout_prob) #self.bilstm = nn.LSTM(self.hidden_size, self.lstm_hidden_size, bidirectional=True, batch_first=True) self.classifier = nn.Linear(self.hidden_size, self.num_labels) #self.classifier = nn.Linear(self.hidden_size, self.num_labels) def forward( self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, labels=None, output_attentions=None, output_hidden_states=None, ): r""" labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ..., config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ outputs = self.bert( input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, head_mask=head_mask, inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, ) #pooled_output = outputs[1] hidden_states = outputs[0] pooled_output = hidden_states.mean(-2) pooled_output = self.dropout(pooled_output) #hidden_states = self.dropout(hidden_states) #lstm_hidden_states, _ = self.bilstm(hidden_states) #lstm_hidden_states = self.dropout(lstm_hidden_states) #pooled_output = lstm_hidden_states.mean(-2) logits = self.classifier(pooled_output) loss = None if labels is not None: if self.num_labels == 1: # We are doing regression loss_fct = MSELoss() loss = loss_fct(logits.view(-1), labels.view(-1)) else: loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) output = (logits,) + outputs[2:] return ((loss,) + output) if loss is not None else output
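# Hypothetical smoke test for BertClassification with a tiny random config (the sizes
# are assumptions): the head mean-pools the last hidden states over the sequence
# dimension instead of using the pooler output.
def _bert_classification_smoke_test():
    tiny_cfg = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2,
                          num_attention_heads=2, intermediate_size=64, num_labels=3)
    model = BertClassification(tiny_cfg)
    input_ids = torch.randint(0, 100, (2, 10))
    outputs = model(input_ids=input_ids, attention_mask=torch.ones_like(input_ids))
    print(outputs[0].shape)    # expected: torch.Size([2, 3]) logits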
class BertSum(pl.LightningModule): def __init__(self, conf=None): super().__init__() # save conf, accessible in self.hparams.conf self.save_hyperparameters() # MODEL self.bert = BertModel.from_pretrained('bert-base-uncased') # change hidden layers from 12 to 10 (memory limit) bert_config = BertConfig(self.bert.config.vocab_size, num_hidden_layers=10) self.bert = BertModel(bert_config) # change embeddings to enable longer input sequences pos_embeddings = nn.Embedding(self.hparams.conf.dataset.setup.max_length_input_context, self.bert.config.hidden_size) pos_embeddings.weight.data[:512] = self.bert.embeddings.position_embeddings.weight.data pos_embeddings.weight.data[512:] = self.bert.embeddings.position_embeddings.weight.data[-1][None, :].repeat(self.hparams.conf.dataset.setup.max_length_input_context - 512, 1) self.bert.embeddings.position_embeddings = pos_embeddings # classification layers self.linear1 = nn.Linear(self.bert.config.hidden_size, 1) self.sigmoid = nn.Sigmoid() # TODO: model to encode answer (and fix) # metrics self.evaluation_metrics = torch.nn.ModuleDict({ 'train_metrics': self._init_metrics(mode='train'), 'val_metrics': self._init_metrics(mode='val'), 'test_metrics': self._init_metrics(mode='test'), }) # loss self.loss = nn.BCEWithLogitsLoss(reduction='sum') def forward(self, x, **kwargs): return self.bert(x, **kwargs) # optimizer def configure_optimizers(self): params = list(self.bert.parameters()) + list(self.linear1.parameters()) return torch.optim.Adam(params, lr=self.hparams.conf.training.lr) # TRAIN def training_step(self, batch, batch_idx): loss, metrics = self._get_loss(batch) self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) metrics = self._log_metrics(metrics, mode='train') return {'loss': loss, **metrics} def validation_step(self, batch, batch_idx): loss, metrics = self._get_loss(batch, mode='val') self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True) metrics = self._log_metrics(metrics, mode='val') return {'val_loss': loss, **metrics} def _get_loss(self, batch, mode='train'): src_ids, src_mask, seg_ids, seg_idx, seg_idx_mask, tgt_labels = batch['input_ids'], batch['input_attention_mask'], batch['segment_ids'], batch['segment_idx'], batch['segment_idx_mask'], batch['target_labels'] position_ids = torch.tensor(list(range(self.hparams.conf.dataset.setup.max_length_input_context))).to('cuda') # bert last_hidden_state, pooler_output = self.bert(src_ids, attention_mask=src_mask, token_type_ids=seg_ids, position_ids=position_ids) # select sentence representation embeddings seg_idx = seg_idx.unsqueeze(dim=2).repeat(1, 1, last_hidden_state.shape[-1]) sent_embeddings = last_hidden_state.gather(dim=1, index=seg_idx) # filter mask mask_idx = torch.nonzero(seg_idx_mask, as_tuple=True) sent_embeddings = sent_embeddings[mask_idx[0], mask_idx[1], :] tgt_labels = tgt_labels[mask_idx[0], mask_idx[1]] # classifier logits = self.linear1(sent_embeddings) logits = logits.squeeze().float() # loss tgt_labels = tgt_labels.float() loss = self.loss(logits, tgt_labels) loss = loss / torch.sum(seg_idx_mask) # metrics preds = self.sigmoid(logits) preds = torch.stack([1 - preds, preds], dim=1) metrics = self._get_metrics(preds, tgt_labels, mode) return loss, metrics def validation_epoch_end(self, outputs): metrics = self._compute_metrics(mode='val') self._log_precision_recall_curve(metrics) self._log_confusion_matrix(metrics) # TEST def test_step(self, batch, batch_idx): loss, metrics = self._get_loss(batch, mode='test') 
self.log('test_loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True) metrics = self._log_metrics(metrics, mode='test') return {'test_loss': loss, **metrics} def test_epoch_end(self, outputs): metrics = self._compute_metrics(mode='test') self._log_confusion_matrix(metrics) # progress bar def get_progress_bar_dict(self): tqdm_dict = super().get_progress_bar_dict() if 'v_num' in tqdm_dict: del tqdm_dict['v_num'] return tqdm_dict # METRICS def _log_metrics(self, metrics, mode): metrics = {key + (f'_{mode}' if mode != 'train' else ''): value for key, value in metrics.items() if key != 'confusion_matrix' and key != 'precision_recall_curve'} for k, m in metrics.items(): self.log(k, m, on_step=False, on_epoch=True, prog_bar=True, logger=True) return metrics def _log_precision_recall_curve(self, metrics): precision, recall, thresholds = metrics['precision_recall_curve'] plt.plot(recall.cpu().numpy(), precision.cpu().numpy(), 'ro') plt.xlabel('recall') plt.ylabel('precision') self.logger.experiment.log({f'precision_recall_curve_{self.current_epoch}': wandb.Image(plt)}) plt.clf() f1 = 2 * (precision * recall) / (precision + recall) data = { 'precision': precision.cpu().numpy().tolist(), 'recall': recall.cpu().numpy().tolist(), 'f1': f1.cpu().numpy().tolist(), 'thresholds': thresholds.cpu().numpy().tolist(), 'argmax': torch.argmax(f1).cpu().numpy().tolist() } with open(f'precision_recall_{self.current_epoch}', 'wb') as file: pickle.dump(data, file) def _log_confusion_matrix(self, metrics): confusion_matrix = metrics['confusion_matrix'] heatmap = sns.heatmap(confusion_matrix.cpu().numpy(), annot=True, fmt='g') figure = heatmap.get_figure() self.logger.experiment.log({f'confusion_matrix_{self.current_epoch}': wandb.Image(figure)}) plt.clf() def _get_metrics(self, prediction, target, mode='train'): metrics = {} for name, metric in self.evaluation_metrics[mode + '_metrics'].items(): metrics[name] = metric(prediction, target) return metrics def _compute_metrics(self, mode='train'): metrics = {} for name, metric in self.evaluation_metrics[mode + '_metrics'].items(): metrics[name] = metric.compute() return metrics @staticmethod def _init_metrics(mode): metrics = torch.nn.ModuleDict({ 'accuracy': pl.metrics.Accuracy(), 'f1': F1(), 'precision': Precision(), 'recall': Recall(), }) if mode != 'train': metrics['confusion_matrix'] = pl.metrics.ConfusionMatrix(num_classes=2) metrics['precision_recall_curve'] = pl.metrics.PrecisionRecallCurve(pos_label=1) return metrics
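# Minimal illustration (toy shapes, not part of BertSum) of the gather step used in
# _get_loss above: seg_idx holds the token positions of the sentence markers, and
# gather(dim=1) picks those rows out of the hidden states.
def _sentence_gather_example():
    hidden = torch.arange(24.).view(1, 6, 4)      # (batch=1, seq_len=6, hidden=4)
    seg_idx = torch.tensor([[0, 3]])              # sentence markers at positions 0 and 3
    idx = seg_idx.unsqueeze(dim=2).repeat(1, 1, hidden.shape[-1])
    sent_embeddings = hidden.gather(dim=1, index=idx)
    print(sent_embeddings.shape)                  # torch.Size([1, 2, 4])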
class HIBERT(BertPreTrainedModel): def __init__(self, config, n_classes, add_linear=None, attn_bias=False, freeze_layer_count=-1): super(HIBERT, self).__init__(config) self.n_classes = n_classes self.add_linear = add_linear self.attn_bias = attn_bias self.freeze_layer_count = freeze_layer_count self.attn_weights = None # Define model objects self.bert = BertModel(config, add_pooling_layer=False) self.fc_in_size = self.bert.config.hidden_size # Control layer freezing if freeze_layer_count == -1: # freeze all bert layers for param in self.bert.parameters(): param.requires_grad = False if freeze_layer_count == -2: # unfreeze all bert layers for param in self.bert.parameters(): param.requires_grad = True if freeze_layer_count > 0: # freeze embedding layer for param in self.bert.embeddings.parameters(): param.requires_grad = False # freeze the top `freeze_layer_count` encoder layers for layer in self.bert.encoder.layer[:freeze_layer_count]: for param in layer.parameters(): param.requires_grad = False # Attention pooling layer self.attention = Attention(dim=self.bert.config.hidden_size, attn_bias=self.attn_bias) # fully connected layers if self.add_linear is None: self.fc = nn.ModuleList([nn.Linear(self.fc_in_size, self.n_classes)]) else: self.fc_layers = [self.fc_in_size] + self.add_linear self.fc = nn.ModuleList([ LinearBlock(self.fc_layers[i], self.fc_layers[i+1]) for i in range(len(self.fc_layers) - 1) ]) # no relu after last dense (cannot use LinearBlock) self.fc.append(nn.Linear(self.fc_layers[-1], self.n_classes)) def forward(self, input_ids, attention_mask, n_chunks): # Bert transformer (take sequential output) output, _ = self.bert( input_ids = input_ids, attention_mask = attention_mask, return_dict=False ) # group chunks together chunks = output.split_with_sizes(n_chunks.tolist()) # loop through attention layer (need a loop as there are different sized chunks) # collect attention output and attention weights for each call of attention after_attn_list = [] self.attn_weights = [] for chunk in chunks: after_attn_list.append(self.attention(chunk.view(1, -1, self.bert.config.hidden_size))) self.attn_weights.append(self.attention.attn_weights) output = torch.cat(after_attn_list) # fully connected layers for fc in self.fc: output = fc(output) return output
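# Minimal illustration (toy shapes, not part of HIBERT) of how chunk encodings are
# regrouped per document before attention pooling: `n_chunks` lists how many rows of
# the flat chunk batch belong to each document.
def _chunk_regrouping_example():
    chunk_encodings = torch.randn(5, 7)           # 5 chunks in total, toy hidden size 7
    n_chunks = torch.tensor([2, 3])               # document 0 has 2 chunks, document 1 has 3
    groups = chunk_encodings.split_with_sizes(n_chunks.tolist())
    print([g.shape for g in groups])              # [torch.Size([2, 7]), torch.Size([3, 7])]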
class UnStructuredModel:
    def __init__(self, model_name, max_length, stride):
        self.model_name = model_name
        self.tokenizer = None
        self.model = None
        self.max_length = max_length
        self.stride = stride
        if model_name == 'bert-base-uncased':
            self.tokenizer = BertTokenizer.from_pretrained(self.model_name)
            self.model = BertModel.from_pretrained(self.model_name)
            self.model.to(device)
            self.model.eval()
            for param in self.model.parameters():
                param.requires_grad = False

    def padTokens(self, tokens):
        if len(tokens) < self.max_length:
            tokens = tokens + ["[PAD]" for _ in range(self.max_length - len(tokens))]
        return tokens

    def getEmbedding(self, text, if_pool=True, pooling_type="mean", batchsize=1):
        tokens = self.tokenizer.tokenize(text)
        tokenized_array = self.tokenizeText(tokens)
        embeddingTensorsList = []
        print(len(tokenized_array))
        tensor = torch.zeros([1, 768], device=device)
        count = 0
        if len(tokenized_array) > batchsize:
            for i in range(0, len(tokenized_array), batchsize):
                current_tokens = tokenized_array[i:min(i + batchsize, len(tokenized_array))]
                token_ids = torch.tensor(current_tokens).to(device)
                seg_ids = [[0 for _ in range(len(tokenized_array[0]))] for _ in range(len(current_tokens))]
                seg_ids = torch.tensor(seg_ids).to(device)
                # assumes a transformers version where the model returns a
                # (sequence_output, pooled_output) tuple
                hidden_reps, cls_head = self.model(token_ids, token_type_ids=seg_ids)
                cls_head = cls_head.detach()
                if if_pool and pooling_type == "mean":
                    tensor = tensor.add(torch.sum(cls_head, dim=0))
                    count += cls_head.shape[0]
                else:
                    embeddingTensorsList.append(cls_head)
                del cls_head, hidden_reps
            if if_pool and pooling_type == "mean" and count > 0:
                embedding = torch.div(tensor, count)
            elif not if_pool:
                embedding = torch.cat(embeddingTensorsList, dim=0)
            else:
                raise NotImplementedError()
        else:
            token_ids = torch.tensor(tokenized_array).to(device)
            seg_ids = [[0 for _ in range(len(tokenized_array[0]))] for _ in range(len(tokenized_array))]
            seg_ids = torch.tensor(seg_ids).to(device)
            hidden_reps, cls_head = self.model(token_ids, token_type_ids=seg_ids)
            cls_head = cls_head.detach()
            if if_pool and pooling_type == "mean":
                embedding = torch.div(torch.sum(cls_head, dim=0), cls_head.shape[0])
            elif not if_pool:
                embedding = cls_head
            else:
                raise NotImplementedError()
            del cls_head, hidden_reps
        return embedding

    def tokenizeText(self, tokens):
        tokens_array = []
        # slide a window of `max_length` tokens forward by `stride` tokens each step
        for i in range(0, len(tokens), self.stride):
            if i + self.max_length < len(tokens):
                curr_tokens = ["[CLS]"] + tokens[i:i + self.max_length] + ["[SEP]"]
            else:
                padded_tokens = self.padTokens(tokens[i:i + self.max_length])
                curr_tokens = ["[CLS]"] + padded_tokens + ["[SEP]"]
            curr_tokens = self.tokenizer.convert_tokens_to_ids(curr_tokens)
            tokens_array.append(curr_tokens)
        return tokens_array
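# Hypothetical usage of UnStructuredModel (downloading 'bert-base-uncased' and the
# module-level `device` are assumed, and the text is made up): documents longer than
# `max_length` tokens are split into overlapping windows and the pooled [CLS] vectors
# are averaged into a single 768-dimensional embedding.
def _unstructured_embedding_example():
    embedder = UnStructuredModel('bert-base-uncased', max_length=128, stride=64)
    doc_vec = embedder.getEmbedding("a long unstructured document " * 100)
    print(doc_vec.shape)    # expected: torch.Size([1, 768])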
def train(config, bert_config, train_path, dev_path, rel2id, id2rel, tokenizer): if os.path.exists(config.output_dir) is False: os.makedirs(config.output_dir, exist_ok=True) if os.path.exists('./data/train_file.pkl'): train_data = pickle.load(open("./data/train_file.pkl", mode='rb')) else: train_data = data.load_data(train_path, tokenizer, rel2id, num_rels) pickle.dump(train_data, open("./data/train_file.pkl", mode='wb')) dev_data = json.load(open(dev_path)) for sent in dev_data: data.to_tuple(sent) data_manager = data.SPO(train_data) train_sampler = RandomSampler(data_manager) train_data_loader = DataLoader(data_manager, sampler=train_sampler, batch_size=config.batch_size, drop_last=True) num_train_steps = int( len(data_manager) / config.batch_size) * config.max_epoch if config.bert_pretrained_model is not None: logger.info('load bert weight') Bert_model = BertModel.from_pretrained(config.bert_pretrained_model, config=bert_config) else: logger.info('random initialize bert model') Bert_model = BertModel(config=bert_config).init_weights() Bert_model.to(device) submodel = sub_model(config).to(device) objmodel = obj_model(config).to(device) loss_fuc = nn.BCELoss(reduction='none') params = list(Bert_model.parameters()) + list( submodel.parameters()) + list(objmodel.parameters()) optimizer = AdamW(params, lr=config.lr) # Train! logger.info("***** Running training *****") logger.info(" Num examples = %d", len(data_manager)) logger.info(" Num Epochs = %d", config.max_epoch) logger.info(" Total train batch size = %d", config.batch_size) logger.info(" Total optimization steps = %d", num_train_steps) logger.info(" Logging steps = %d", config.print_freq) logger.info(" Save steps = %d", config.save_freq) global_step = 0 Bert_model.train() submodel.train() objmodel.train() for _ in range(config.max_epoch): optimizer.zero_grad() epoch_itorator = tqdm(train_data_loader, disable=None) for step, batch in enumerate(epoch_itorator): batch = tuple(t.to(device) for t in batch) input_ids, segment_ids, input_masks, sub_positions, sub_heads, sub_tails, obj_heads, obj_tails = batch bert_output = Bert_model(input_ids, input_masks, segment_ids)[0] pred_sub_heads, pred_sub_tails = submodel( bert_output) # [batch_size, seq_len, 1] pred_obj_heads, pred_obj_tails = objmodel(bert_output, sub_positions) # 计算loss mask = input_masks.view(-1) # loss1 sub_heads = sub_heads.unsqueeze(-1) # [batch_szie, seq_len, 1] sub_tails = sub_tails.unsqueeze(-1) loss1_head = loss_fuc(pred_sub_heads, sub_heads).view(-1) loss1_head = torch.sum(loss1_head * mask) / torch.sum(mask) loss1_tail = loss_fuc(pred_sub_tails, sub_tails).view(-1) loss1_tail = torch.sum(loss1_tail * mask) / torch.sum(mask) loss1 = loss1_head + loss1_tail # loss2 loss2_head = loss_fuc(pred_obj_heads, obj_heads).view(-1, obj_heads.shape[-1]) loss2_head = torch.sum( loss2_head * mask.unsqueeze(-1)) / torch.sum(mask) loss2_tail = loss_fuc(pred_obj_tails, obj_tails).view(-1, obj_tails.shape[-1]) loss2_tail = torch.sum( loss2_tail * mask.unsqueeze(-1)) / torch.sum(mask) loss2 = loss2_head + loss2_tail # optimize loss = loss1 + loss2 loss.backward() optimizer.step() optimizer.zero_grad() global_step += 1 if (global_step + 1) % config.print_freq == 0: logger.info( "epoch : {} step: {} #### loss1: {} loss2: {}".format( _, global_step + 1, loss1.cpu().item(), loss2.cpu().item())) if (global_step + 1) % config.eval_freq == 0: logger.info("***** Running evaluating *****") with torch.no_grad(): Bert_model.eval() submodel.eval() objmodel.eval() P, R, F1 = utils.metric(Bert_model, 
submodel, objmodel, dev_data, id2rel, tokenizer) logger.info(f'precision:{P}\nrecall:{R}\nF1:{F1}') Bert_model.train() submodel.train() objmodel.train() if (global_step + 1) % config.save_freq == 0: # Save a trained model model_name = "pytorch_model_%d" % (global_step + 1) output_model_file = os.path.join(config.output_dir, model_name) state = { 'bert_state_dict': Bert_model.state_dict(), 'subject_state_dict': submodel.state_dict(), 'object_state_dict': objmodel.state_dict(), } torch.save(state, output_model_file) model_name = "pytorch_model_last" output_model_file = os.path.join(config.output_dir, model_name) state = { 'bert_state_dict': Bert_model.state_dict(), 'subject_state_dict': submodel.state_dict(), 'object_state_dict': objmodel.state_dict(), } torch.save(state, output_model_file)
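# Sketch of reloading a checkpoint written by train() above (the helper itself is not
# part of the original code; the key names match the ones used in the save step).
def load_joint_checkpoint(path, bert_model, submodel, objmodel, map_location='cpu'):
    state = torch.load(path, map_location=map_location)
    bert_model.load_state_dict(state['bert_state_dict'])
    submodel.load_state_dict(state['subject_state_dict'])
    objmodel.load_state_dict(state['object_state_dict'])
    return bert_model, submodel, objmodel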
class BertForMultiLabelSequenceClassification(BertPreTrainedModel): """ Bert model adapted for multi-label sequence classification. Note that for imbalance problems will also provide an extra parameter to add inside the loss function to integrate the classes distribution. """ def __init__(self, config): super(BertForMultiLabelSequenceClassification, self).__init__(config) self.num_labels = config.num_labels self.bert = BertModel(config) self.dropout = nn.Dropout(config.hidden_dropout_prob) self.classifier = nn.Linear(config.hidden_size, config.num_labels) self.pos_weight = torch.Tensor(config.pos_weight).to(device) if config.use_pos_weight else None self.init_weights() def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, labels=None): """ :param input_ids: sentence or sentences represented as tokens :param attention_mask: tells the model which tokens in the input_ids are words and which are padding. 1 indicates a token and 0 indicates padding. :param token_type_ids: used when there are two sentences that need to be part of the input. It indicate which tokens are part of sentence1 and which are part of sentence2. :param position_ids: indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, config.max_position_embeddings - 1] :param head_mask: mask to nullify selected heads of the self-attention modules :param labels: target for each input :return: """ outputs = self.bert( input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, head_mask=head_mask, ) pooled_output = outputs[1] pooled_output = self.dropout(pooled_output) logits = self.classifier(pooled_output) outputs = (logits,) + outputs[2:] if labels is not None: loss_fct = BCEWithLogitsLoss(pos_weight=self.pos_weight) labels = labels.float() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1, self.num_labels)) outputs = (loss,) + outputs return outputs def freeze_bert_encoder(self): """Freeze BERT layers""" for param in self.bert.parameters(): param.requires_grad = False def unfreeze_bert_encoder(self): """Unfreeze BERT layers""" for param in self.bert.parameters(): param.requires_grad = True
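# Hedged usage sketch for BertForMultiLabelSequenceClassification (not part of the
# original code): when no labels are passed, the first element of the output tuple is
# the logits, and multi-label predictions are usually obtained by thresholding the
# element-wise sigmoid. The 0.5 threshold is an assumption.
def multilabel_predict(model, input_ids, attention_mask, threshold=0.5):
    logits = model(input_ids, attention_mask=attention_mask)[0]
    return (torch.sigmoid(logits) > threshold).long()    # (batch, num_labels) 0/1 matrix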
class SpERT(BertPreTrainedModel): """ Span-based model to jointly extract entities and relations """ def __init__(self, config: BertConfig, cls_token: int, relation_types: int, entity_types: int, size_embedding: int, prop_drop: float, freeze_transformer: bool, max_pairs: int = 100): super(SpERT, self).__init__(config) # BERT model self.bert = BertModel(config) # layers self.rel_classifier = nn.Linear( config.hidden_size * 3 + size_embedding * 2, relation_types) self.entity_classifier = nn.Linear( config.hidden_size * 2 + size_embedding, entity_types) self.size_embeddings = nn.Embedding(100, size_embedding) self.dropout = nn.Dropout(prop_drop) self._cls_token = cls_token self._relation_types = relation_types self._entity_types = entity_types self._max_pairs = max_pairs # weight initialization self.init_weights() if freeze_transformer: print("Freeze transformer weights") # freeze all transformer weights for param in self.bert.parameters(): param.requires_grad = False def _forward_train(self, encodings: torch.tensor, context_masks: torch.tensor, entity_masks: torch.tensor, entity_sizes: torch.tensor, relations: torch.tensor, rel_masks: torch.tensor): # get contextualized token embeddings from last transformer layer context_masks = context_masks.float() h = self.bert(input_ids=encodings, attention_mask=context_masks)[0] entity_masks = entity_masks.float() batch_size = encodings.shape[0] # classify entities size_embeddings = self.size_embeddings( entity_sizes) # embed entity candidate sizes entity_clf, entity_spans_pool = self._classify_entities( encodings, h, entity_masks, size_embeddings) # classify relations rel_masks = rel_masks.float().unsqueeze(-1) h_large = h.unsqueeze(1).repeat( 1, max(min(relations.shape[1], self._max_pairs), 1), 1, 1) rel_clf = torch.zeros( [batch_size, relations.shape[1], self._relation_types]).to(self.rel_classifier.weight.device) # obtain relation logits # chunk processing to reduce memory usage for i in range(0, relations.shape[1], self._max_pairs): # classify relation candidates chunk_rel_logits = self._classify_relations( entity_spans_pool, size_embeddings, relations, rel_masks, h_large, i) rel_clf[:, i:i + self._max_pairs, :] = chunk_rel_logits return entity_clf, rel_clf def _forward_eval(self, encodings: torch.tensor, context_masks: torch.tensor, entity_masks: torch.tensor, entity_sizes: torch.tensor, entity_spans: torch.tensor, entity_sample_masks: torch.tensor): # get contextualized token embeddings from last transformer layer context_masks = context_masks.float() h = self.bert(input_ids=encodings, attention_mask=context_masks)[0] entity_masks = entity_masks.float() batch_size = encodings.shape[0] ctx_size = context_masks.shape[-1] # classify entities size_embeddings = self.size_embeddings( entity_sizes) # embed entity candidate sizes entity_clf, entity_spans_pool = self._classify_entities( encodings, h, entity_masks, size_embeddings) # ignore entity candidates that do not constitute an actual entity for relations (based on classifier) relations, rel_masks, rel_sample_masks = self._filter_spans( entity_clf, entity_spans, entity_sample_masks, ctx_size) rel_masks = rel_masks.float() rel_sample_masks = rel_sample_masks.float() h_large = h.unsqueeze(1).repeat( 1, max(min(relations.shape[1], self._max_pairs), 1), 1, 1) rel_clf = torch.zeros( [batch_size, relations.shape[1], self._relation_types]).to(self.rel_classifier.weight.device) # obtain relation logits # chunk processing to reduce memory usage for i in range(0, relations.shape[1], self._max_pairs): # 
classify relation candidates chunk_rel_logits = self._classify_relations( entity_spans_pool, size_embeddings, relations, rel_masks, h_large, i) # apply sigmoid chunk_rel_clf = torch.sigmoid(chunk_rel_logits) rel_clf[:, i:i + self._max_pairs, :] = chunk_rel_clf rel_clf = rel_clf * rel_sample_masks # mask # apply softmax entity_clf = torch.softmax(entity_clf, dim=2) return entity_clf, rel_clf, relations def _classify_entities(self, encodings, h, entity_masks, size_embeddings): # max pool entity candidate spans entity_spans_pool = entity_masks.unsqueeze(-1) * h.unsqueeze(1).repeat( 1, entity_masks.shape[1], 1, 1) entity_spans_pool = entity_spans_pool.max(dim=2)[0] # get cls token as candidate context representation entity_ctx = get_token(h, encodings, self._cls_token) # create candidate representations including context, max pooled span and size embedding entity_repr = torch.cat([ entity_ctx.unsqueeze(1).repeat(1, entity_spans_pool.shape[1], 1), entity_spans_pool, size_embeddings ], dim=2) entity_repr = self.dropout(entity_repr) # classify entity candidates entity_clf = self.entity_classifier(entity_repr) return entity_clf, entity_spans_pool def _classify_relations(self, entity_spans, size_embeddings, relations, rel_masks, h, chunk_start): batch_size = relations.shape[0] # create chunks if necessary if relations.shape[1] > self._max_pairs: relations = relations[:, chunk_start:chunk_start + self._max_pairs] rel_masks = rel_masks[:, chunk_start:chunk_start + self._max_pairs] h = h[:, :relations.shape[1], :] # get pairs of entity candidate representations entity_pairs = util.batch_index(entity_spans, relations) entity_pairs = entity_pairs.view(batch_size, entity_pairs.shape[1], -1) # get corresponding size embeddings size_pair_embeddings = util.batch_index(size_embeddings, relations) size_pair_embeddings = size_pair_embeddings.view( batch_size, size_pair_embeddings.shape[1], -1) # relation context (context between entity candidate pair) rel_ctx = rel_masks * h rel_ctx = rel_ctx.max(dim=2)[0] # create relation candidate representations including context, max pooled entity candidate pairs # and corresponding size embeddings rel_repr = torch.cat([rel_ctx, entity_pairs, size_pair_embeddings], dim=2) rel_repr = self.dropout(rel_repr) # classify relation candidates chunk_rel_logits = self.rel_classifier(rel_repr) return chunk_rel_logits def _filter_spans(self, entity_clf, entity_spans, entity_sample_masks, ctx_size): batch_size = entity_clf.shape[0] entity_logits_max = entity_clf.argmax( dim=-1) * entity_sample_masks.long( ) # get entity type (including none) batch_relations = [] batch_rel_masks = [] batch_rel_sample_masks = [] for i in range(batch_size): rels = [] rel_masks = [] sample_masks = [] # get spans classified as entities non_zero_indices = (entity_logits_max[i] != 0).nonzero().view(-1) non_zero_spans = entity_spans[i][non_zero_indices].tolist() non_zero_indices = non_zero_indices.tolist() # create relations and masks for i1, s1 in zip(non_zero_indices, non_zero_spans): for i2, s2 in zip(non_zero_indices, non_zero_spans): if i1 != i2: rels.append((i1, i2)) rel_masks.append( sampling.create_rel_mask(s1, s2, ctx_size)) sample_masks.append(1) if not rels: # case: no more than two spans classified as entities batch_relations.append(torch.tensor([[0, 0]], dtype=torch.long)) batch_rel_masks.append( torch.tensor([[0] * ctx_size], dtype=torch.bool)) batch_rel_sample_masks.append( torch.tensor([0], dtype=torch.bool)) else: # case: more than two spans classified as entities 
batch_relations.append(torch.tensor(rels, dtype=torch.long)) batch_rel_masks.append(torch.stack(rel_masks)) batch_rel_sample_masks.append( torch.tensor(sample_masks, dtype=torch.bool)) # stack device = self.rel_classifier.weight.device batch_relations = util.padded_stack(batch_relations).to(device) batch_rel_masks = util.padded_stack(batch_rel_masks).to( device).unsqueeze(-1) batch_rel_sample_masks = util.padded_stack(batch_rel_sample_masks).to( device).unsqueeze(-1) return batch_relations, batch_rel_masks, batch_rel_sample_masks def forward(self, *args, evaluate=False, **kwargs): if not evaluate: return self._forward_train(*args, **kwargs) else: return self._forward_eval(*args, **kwargs)
class BertLstmCrf(BertPreTrainedModel):
    def __init__(self, config, extra_config, ignore_ids):
        """
        num_labels : int, required
            Number of tags.
        idx2tag : ``Dict[int, str]``, required
            A mapping {label_id -> label}. Example: {0:"B-LOC", 1:"I-LOC", 2:"O"}
        label_encoding : ``str``, required
            Indicates which constraint to apply. Current choices are
            "BIO", "IOB1", "BIOUL", "BMES" and "BIOES".
            B = Beginning
            I/M = Inside / Middle
            L/E = Last / End
            O = Outside
            U/W/S = Unit / Whole / Single
        """
        super(BertLstmCrf, self).__init__(config)
        self.pretraind = BertModel(config)
        self.dropout = nn.Dropout(extra_config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.bilstm = nn.LSTM(input_size=config.hidden_size,
                              hidden_size=config.hidden_size // 2,
                              batch_first=True,
                              num_layers=extra_config.num_layers,
                              dropout=extra_config.lstm_dropout,
                              bidirectional=True)
        # register the LayerNorm once so its parameters are trained and follow the model's device
        self.layer_norm = nn.LayerNorm(config.hidden_size)
        self.crf = crf(config.num_labels, extra_config.label_encoding,
                       extra_config.idx2tag)
        self.init_weights()

        if extra_config.freez_prrtrained:
            for param in self.pretraind.parameters():
                param.requires_grad = False

        self.ignore_ids = ignore_ids

    def forward(self,
                input_ids,
                attention_mask=None,
                token_type_ids=None,
                position_ids=None,
                head_mask=None,
                labels=None):
        # `outputs` contains:
        #   last_hidden_state: hidden states at the output of the last layer,
        #       (batch_size, sequence_length, hidden_size)
        #   pooler_output: last-layer hidden state of the first token ([CLS]),
        #       processed by a Linear layer and a Tanh activation
        #   hidden_states: one entry for the embedding output plus one per layer,
        #       each (batch_size, sequence_length, hidden_size)
        #   attentions: attention weights after the attention softmax of each layer,
        #       each (batch_size, num_heads, sequence_length, sequence_length)
        outputs = self.pretraind(input_ids,
                                 attention_mask=attention_mask,
                                 token_type_ids=token_type_ids,
                                 position_ids=position_ids,
                                 head_mask=head_mask)
        last_hidden_state = outputs[0]

        seq_output = self.dropout(last_hidden_state)
        seq_output, _ = self.bilstm(seq_output)
        seq_output = self.layer_norm(seq_output)
        logits = self.classifier(seq_output)
        outputs = (logits, ) + outputs[2:]

        if labels is not None:
            masked_labels, masked_logits = self._get_masked_inputs(
                input_ids, labels, logits, attention_mask)
            # mask=None: all padding / ignored positions have already been removed
            loss = self.crf(masked_logits, masked_labels, mask=None)
            outputs = (loss, ) + outputs

        return outputs  # (loss), logits, (hidden_states), (attentions)

    def _get_masked_inputs(self, input_ids, label_ids, logits, attention_mask):
        ignore_ids = self.ignore_ids

        # drop padding and other ignored positions
        masked_ids = input_ids[(1 == attention_mask)]
        masked_labels = label_ids[(1 == attention_mask)]
        masked_logits = logits[(1 == attention_mask)]
        for id in ignore_ids:
            masked_labels = masked_labels[(id != masked_ids)]
            masked_logits = masked_logits[(id != masked_ids)]
            masked_ids = masked_ids[(id != masked_ids)]

        return masked_labels, masked_logits