class BertForMultiLabelSequenceClassification(BertPreTrainedModel):
    """BERT model for classification.
    This module is composed of the BERT model with a linear layer on top of
    the pooled output.
    """

    def __init__(self, config, num_labels=4):
        super(BertForMultiLabelSequenceClassification, self).__init__(config)
        self.num_labels = num_labels
        self.bert = BertModel(config)
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        self.classifier = torch.nn.Linear(config.hidden_size, num_labels)
        self.apply(self.init_bert_weights)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
        _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask,
                                     output_all_encoded_layers=False)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        return logits

    def freeze_bert_encoder(self):
        for param in self.bert.parameters():
            param.requires_grad = False

    def unfreeze_bert_encoder(self):
        for param in self.bert.parameters():
            param.requires_grad = True
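
Since the forward above returns raw logits and leaves the loss to the caller, multi-label training typically pairs it with BCEWithLogitsLoss so each label is an independent binary decision. A minimal sketch under that assumption (checkpoint name, shapes and all tensors are illustrative, and the pytorch-pretrained-bert-style from_pretrained is assumed):

import torch

# Hypothetical usage; the checkpoint name and dummy tensors are placeholders.
model = BertForMultiLabelSequenceClassification.from_pretrained(
    'bert-base-uncased', num_labels=4)
model.train()

input_ids = torch.randint(0, 30522, (8, 128))      # dummy token ids
attention_mask = torch.ones_like(input_ids)
labels = torch.randint(0, 2, (8, 4)).float()       # multi-hot targets

logits = model(input_ids, attention_mask=attention_mask)
# one independent sigmoid/BCE term per label -> multi-label behaviour
loss = torch.nn.BCEWithLogitsLoss()(logits, labels)
loss.backward()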
Example #2
class BertClassifier(BertPreTrainedModel):
    """
    BERT multi-label classifier
    """
    def __init__(self, config):
        super(BertClassifier, self).__init__(config)
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.init_weights()

    def forward(self, input_ids, token_type_ids=None, attention_mask=None):
        """
        Forward pass of the BERT classifier
        :param input_ids: the input IDs (bs, seq len)
        :param token_type_ids: (not used) a tensor of zeros indicating which sequence in sequence pairs (bs, seq len)
        :param attention_mask: tensor of one if not pad token, zero otherwise (bs, seq len)
        :return: logits corresponding to each output class (bs, )
        """
        _, pooled_output = self.bert(input_ids=input_ids,
                                     token_type_ids=token_type_ids,
                                     attention_mask=attention_mask)
        pooled_output = self.dropout(pooled_output)
        return self.classifier(pooled_output)

    def freeze_bert_encoder(self):
        """
        Prevents further backpropagation (used when testing)
        """
        for param in self.bert.parameters():
            param.requires_grad = False

    def unfreeze_bert_encoder(self):
        """
        Re-enables backpropagation (used when training)
        """
        for param in self.bert.parameters():
            param.requires_grad = True

    def save(self, path: str):
        print('save model parameters to [%s]' % path, file=sys.stderr)

        # Only save the model and not the entire pretrained Bert
        model_to_save = self.module if hasattr(self, 'module') else self
        torch.save(model_to_save.state_dict(), path)

    @staticmethod
    def load(model_path: str, bert_pretrained_path: str, num_labels: int):
        """ Load a fine-tuned model from a file.
        @param model_path (str): path to model
        """

        state_dict = torch.load(model_path)
        model = BertClassifier.from_pretrained(bert_pretrained_path,
                                               state_dict=state_dict,
                                               num_labels=num_labels)
        return model
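
A possible save/load round trip with the two helpers above (paths, checkpoint name and label count are placeholders; num_labels is assumed to be forwarded into the config by from_pretrained):

# Sketch only; not taken from the original repository.
model = BertClassifier.from_pretrained('bert-base-uncased', num_labels=6)
model.unfreeze_bert_encoder()          # fine-tune BERT together with the head
# ... training loop ...
model.freeze_bert_encoder()            # evaluation only
model.save('bert_classifier.pt')

restored = BertClassifier.load(model_path='bert_classifier.pt',
                               bert_pretrained_path='bert-base-uncased',
                               num_labels=6)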
Example #3
class Teacher(nn.Module):
    def __init__(self, pretrained_model, freeze_bert=True, lstm_dim=-1):
        super(Teacher, self).__init__()
        self.output_dim = len(punctuation_dict)
        self.config = BertConfig.from_pretrained(pretrained_model)
        self.bert_layer = BertModel(self.config)
        # Freeze bert layers if requested
        if freeze_bert:
            for p in self.bert_layer.parameters():
                p.requires_grad = False
        bert_dim = self.config.hidden_size
        if lstm_dim == -1:
            hidden_size = bert_dim
        else:
            hidden_size = lstm_dim
        self.lstm = nn.LSTM(input_size=bert_dim,
                            hidden_size=hidden_size,
                            num_layers=1,
                            bidirectional=True)
        # bidirectional LSTM doubles the feature dimension
        self.linear = nn.Linear(hidden_size * 2, self.output_dim)

    def forward(self, input_ids, attention_mask):
        # (B, N) -> (B, N, E); also request all hidden states so the
        # teacher can expose intermediate layers for distillation
        out = self.bert_layer(input_ids,
                              attention_mask=attention_mask,
                              output_hidden_states=True)
        x = out.last_hidden_state
        hs = out.hidden_states  # embeddings + one entry per encoder layer
        # (B, N, E) -> (N, B, E)
        x = torch.transpose(x, 0, 1)
        x, (_, _) = self.lstm(x)
        # (N, B, E) -> (B, N, E)
        x = torch.transpose(x, 0, 1)
        x = self.linear(x)

        return x, hs[0], hs[6], hs[12]
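
A quick shape check for the Teacher above (the pretrained name and sizes are placeholders, punctuation_dict must be provided by the surrounding code, and the weights are randomly initialized because the constructor only loads the config):

import torch

teacher = Teacher('bert-base-uncased', freeze_bert=True, lstm_dim=-1)
input_ids = torch.randint(0, teacher.config.vocab_size, (2, 32))
attention_mask = torch.ones_like(input_ids)
logits, emb_hs, mid_hs, last_hs = teacher(input_ids, attention_mask)
# logits: (2, 32, len(punctuation_dict)); each hidden state: (2, 32, 768)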
Example #4
class BertForLinearSequenceToSequenceProbing(ProteinBertAbstractModel):
    """Bert head for token-level prediction tasks (secondary structure, binding sites)"""

    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config)

        self.classify = LinearSequenceToSequenceClassificationHead(
            config.hidden_size,
            config.num_labels,
            ignore_index=-1,
            dropout=0.5)

        for param in self.bert.parameters():
            param.requires_grad = False

        self.init_weights()

    def forward(self, input_ids, input_mask=None, targets=None):
        outputs = self.bert(input_ids, input_mask=input_mask)

        sequence_output, pooled_output = outputs[:2]
        outputs = self.classify(sequence_output, targets) + outputs[2:]
        return outputs
Example #5
class RCNNModel(BertPreTrainedModel):
    def __init__(self, config, num_class):
        super(RCNNModel, self).__init__(config)
        self.bert = BertModel(config)
        for param in self.bert.parameters():
            param.requires_grad = True
        self.lstm = nn.LSTM(768,
                            256,
                            2,
                            bidirectional=True,
                            batch_first=True,
                            dropout=0.1)
        self.maxpool = nn.MaxPool1d(512)
        self.fc = nn.Linear(512 + 768, num_class)

    def forward(self, x, masks):
        encoder_out, text_cls = self.bert(input_ids=x, attention_mask=masks)
        out, _ = self.lstm(encoder_out)
        out = torch.cat((encoder_out, out), 2)
        out = F.relu(out)
        out = out.permute(0, 2, 1)
        # print(out.size())
        out = self.maxpool(out)
        out = out.squeeze()
        out = self.fc(out)
        return out
class BertForMultiLabelSequenceClassification(BertPreTrainedModel):
    def __init__(self, config, label_emb):
        super(BertForMultiLabelSequenceClassification, self).__init__(config)
        self.num_labels = config.num_labels
        self.hidden_size = config.hidden_size
        self.label_emb = label_emb

        self.bert = BertModel(config)
        self.dropout = nn.Dropout(0.1)

        self.self_attn = SelfAttention(self.hidden_size, self.num_labels)
        self.label_attn = LabelAttention(self.hidden_size, self.num_labels,
                                         self.label_emb)
        self.linear = MLinear(self.hidden_size, self.num_labels)

        self.init_weights()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
    ):
        sequence, _ = self.bert(input_ids, attention_mask)
        sequence = self.dropout(sequence)  # [batch, sequence, hidden_size]

        masks = attention_mask != 0  # [batch, sequence]
        masks = torch.unsqueeze(masks, 1)  # [batch, 1, sequence]

        self_attn = self.self_attn(sequence, masks)
        label_attn = self.label_attn(sequence, masks)

        return self.linear(self_attn, label_attn)

    def freeze_bert_encoder(self):
        for param in self.bert.parameters():
            param.requires_grad = False

    def unfreeze_bert_encoder(self):
        for param in self.bert.parameters():
            param.requires_grad = True
Example #7
class BertForMultiLabelSequenceClassification(BertPreTrainedModel):
    """BERT model for classification.
    This module is composed of the BERT model with a linear layer on top of
    the pooled output.
    """
    def __init__(self, config):
        super(BertForMultiLabelSequenceClassification, self).__init__(config)

        self.bert = BertModel(config)
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        self.classifier = torch.nn.Linear(config.hidden_size + 128,
                                          config.num_labels)

        self.init_weights()

    def forward(self,
                input_ids,
                node_vec,
                tfidf,
                token_type_ids=None,
                attention_mask=None,
                position_ids=None,
                head_mask=None,
                labels=None):

        outputs = self.bert(input_ids,
                            token_type_ids=token_type_ids,
                            attention_mask=attention_mask,
                            position_ids=position_ids,
                            head_mask=head_mask)[0]
        outputs = torch.sum(outputs * tfidf.unsqueeze(2), 1)
        outputs = torch.cat((outputs, node_vec), 1)

        outputs = self.dropout(outputs)
        logits = self.classifier(outputs)

        return logits

    def freeze_bert_encoder(self):
        for param in self.bert.parameters():
            param.requires_grad = False

    def unfreeze_bert_encoder(self):
        for param in self.bert.parameters():
            param.requires_grad = True
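
The forward above pools the token representations with per-token TF-IDF weights and then appends an external 128-dimensional node vector, which is why the classifier takes hidden_size + 128 features. A standalone sketch of just that pooling step with dummy tensors:

import torch

bs, seq_len, hidden = 2, 5, 768
token_states = torch.randn(bs, seq_len, hidden)   # BERT last hidden states
tfidf = torch.rand(bs, seq_len)                   # one weight per token
node_vec = torch.randn(bs, 128)                   # external graph embedding

pooled = torch.sum(token_states * tfidf.unsqueeze(2), 1)  # (bs, hidden)
features = torch.cat((pooled, node_vec), 1)               # (bs, hidden + 128)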
Example #8
class BaseModel(BertPreTrainedModel):
    def __init__(self, config, num_class):
        super(BaseModel, self).__init__(config)
        self.bert = BertModel(config)
        for param in self.bert.parameters():
            param.requires_grad = True
        self.dropout = nn.Dropout(0.5)
        self.fc1 = nn.Linear(768, num_class)
        self.init_weights()

    def forward(self, x, masks):
        encoder_out, text_cls = self.bert(input_ids=x, attention_mask=masks)
        x = self.dropout(text_cls)
        x = self.fc1(x)
        return x
Example #9
class DPCNNModel(BertPreTrainedModel):
    def __init__(self, config, num_class):
        super(DPCNNModel, self).__init__(config)
        self.bert = BertModel(config)
        for param in self.bert.parameters():
            param.requires_grad = True
        # self.fc = nn.Linear(config.hidden_size, config.num_classes)
        self.conv_region = nn.Conv2d(1, 250, (3, 768), stride=1)
        self.conv = nn.Conv2d(250, 250, (3, 1), stride=1)
        self.max_pool = nn.MaxPool2d(kernel_size=(3, 1), stride=2)
        self.padding1 = nn.ZeroPad2d((0, 0, 1, 1))  # top bottom
        self.padding2 = nn.ZeroPad2d((0, 0, 0, 1))  # bottom
        self.relu = nn.ReLU()
        self.fc = nn.Linear(250, num_class)

    def forward(self, x, masks):
        encoder_out, text_cls = self.bert(input_ids=x, attention_mask=masks)
        x = encoder_out.unsqueeze(1)  # [batch_size, 1, seq_len, embed]
        x = self.conv_region(x)  # [batch_size, 250, seq_len-3+1, 1]

        x = self.padding1(x)  # [batch_size, 250, seq_len, 1]
        x = self.relu(x)
        x = self.conv(x)  # [batch_size, 250, seq_len-3+1, 1]
        x = self.padding1(x)  # [batch_size, 250, seq_len, 1]
        x = self.relu(x)
        x = self.conv(x)  # [batch_size, 250, seq_len-3+1, 1]
        while x.size()[2] > 2:
            x = self._block(x)
        x = x.squeeze()  # [batch_size, num_filters(250)]
        x = self.fc(x)
        return x

    def _block(self, x):
        x = self.padding2(x)
        px = self.max_pool(x)
        x = self.padding1(px)
        x = F.relu(x)
        x = self.conv(x)
        x = self.padding1(x)
        x = F.relu(x)
        x = self.conv(x)
        x = x + px  # short cut
        return x
Example #10
def add_enc_adapters(bert_model: BertModel,
                     config: AdapterConfig) -> BertModel:

    # Replace specific layer with adapter-added layer
    bert_encoder = bert_model.encoder
    for i in range(len(bert_model.encoder.layer)):
        bert_encoder.layer[i].attention.output = adapt_bert_self_output(
            config)(bert_encoder.layer[i].attention.output)
        bert_encoder.layer[i].output = adapt_bert_output(config)(
            bert_encoder.layer[i].output)

    # Freeze all parameters
    for param in bert_model.parameters():
        param.requires_grad = False
    # Unfreeze trainable parts — layer norms and adapters
    for name, sub_module in bert_model.named_modules():
        if isinstance(sub_module, (Adapter_func, BertLayerNorm)):
            for param_name, param in sub_module.named_parameters():
                param.requires_grad = True
    return bert_model
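
One way to sanity-check the freezing logic above is to count trainable parameters after the call; if only the adapters and LayerNorms remain trainable, the number should be a small fraction of the total. A sketch, where adapter_config stands for whatever AdapterConfig instance the surrounding code builds:

# Hypothetical check; adapter_config is a placeholder.
adapted = add_enc_adapters(BertModel.from_pretrained('bert-base-uncased'),
                           config=adapter_config)
trainable = sum(p.numel() for p in adapted.parameters() if p.requires_grad)
total = sum(p.numel() for p in adapted.parameters())
print('trainable parameters: %d / %d' % (trainable, total))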
Example #11
class BertABSATagger(BertPreTrainedModel):
    def __init__(self, bert_config):
        """

        :param bert_config: configuration for bert model
        """
        super(BertABSATagger, self).__init__(bert_config)
        self.num_labels = bert_config.num_labels
        self.tagger_config = TaggerConfig()
        self.tagger_config.absa_type = bert_config.absa_type.lower()
        if bert_config.tfm_mode == 'finetune':
            # initialized with pre-trained BERT and perform finetuning
            print("Fine-tuning the pre-trained BERT...")
            self.bert = BertModel(bert_config)
        else:
            raise Exception("Invalid transformer mode %s!!!" %
                            bert_config.tfm_mode)
        self.bert_dropout = nn.Dropout(bert_config.hidden_dropout_prob)
        # fix the parameters in BERT and regard it as feature extractor
        if bert_config.fix_tfm:
            # fix the parameters of the (pre-trained or randomly initialized) transformers during fine-tuning
            for p in self.bert.parameters():
                p.requires_grad = False

        self.tagger = None
        if self.tagger_config.absa_type == 'linear':
            # hidden size at the penultimate layer
            penultimate_hidden_size = bert_config.hidden_size
        else:
            self.tagger_dropout = nn.Dropout(
                self.tagger_config.hidden_dropout_prob)
            if self.tagger_config.absa_type == 'lstm':
                self.tagger = LSTM(
                    input_size=bert_config.hidden_size,
                    hidden_size=self.tagger_config.hidden_size,
                    bidirectional=self.tagger_config.bidirectional)
            elif self.tagger_config.absa_type == 'gru':
                self.tagger = GRU(
                    input_size=bert_config.hidden_size,
                    hidden_size=self.tagger_config.hidden_size,
                    bidirectional=self.tagger_config.bidirectional)
            elif self.tagger_config.absa_type == 'tfm':
                # transformer encoder layer
                self.tagger = nn.TransformerEncoderLayer(
                    d_model=bert_config.hidden_size,
                    nhead=12,
                    dim_feedforward=4 * bert_config.hidden_size,
                    dropout=0.1)
            elif self.tagger_config.absa_type == 'san':
                # vanilla self attention networks
                self.tagger = SAN(d_model=bert_config.hidden_size,
                                  nhead=12,
                                  dropout=0.1)
            elif self.tagger_config.absa_type == 'crf':
                self.tagger = CRF(num_tags=self.num_labels)
            else:
                raise Exception('Unimplemented downstream tagger %s...' %
                                self.tagger_config.absa_type)
            penultimate_hidden_size = self.tagger_config.hidden_size
        self.classifier = nn.Linear(penultimate_hidden_size,
                                    bert_config.num_labels)

    def forward(self,
                input_ids,
                token_type_ids=None,
                attention_mask=None,
                labels=None,
                position_ids=None,
                head_mask=None):
        outputs = self.bert(input_ids,
                            position_ids=position_ids,
                            token_type_ids=token_type_ids,
                            attention_mask=attention_mask,
                            head_mask=head_mask)
        # the hidden states of the last Bert Layer, shape: (bsz, seq_len, hsz)
        tagger_input = outputs[0]
        tagger_input = self.bert_dropout(tagger_input)
        #print("tagger_input.shape:", tagger_input.shape)
        if self.tagger is None or self.tagger_config.absa_type == 'crf':
            # regard classifier as the tagger
            logits = self.classifier(tagger_input)
        else:
            if self.tagger_config.absa_type == 'lstm':
                # customized LSTM
                classifier_input, _ = self.tagger(tagger_input)
            elif self.tagger_config.absa_type == 'gru':
                # customized GRU
                classifier_input, _ = self.tagger(tagger_input)
            elif self.tagger_config.absa_type == 'san' or self.tagger_config.absa_type == 'tfm':
                # vanilla self-attention networks or transformer
                # adapt the input format for the transformer or self attention networks
                tagger_input = tagger_input.transpose(0, 1)
                classifier_input = self.tagger(tagger_input)
                classifier_input = classifier_input.transpose(0, 1)
            else:
                raise Exception("Unimplemented downstream tagger %s..." %
                                self.tagger_config.absa_type)
            classifier_input = self.tagger_dropout(classifier_input)
            logits = self.classifier(classifier_input)
        outputs = (logits, ) + outputs[2:]

        if labels is not None:
            if self.tagger_config.absa_type != 'crf':
                loss_fct = CrossEntropyLoss()
                if attention_mask is not None:
                    active_loss = attention_mask.view(-1) == 1
                    active_logits = logits.view(-1,
                                                self.num_labels)[active_loss]
                    active_labels = labels.view(-1)[active_loss]
                    loss = loss_fct(active_logits, active_labels)
                else:
                    loss = loss_fct(logits.view(-1, self.num_labels),
                                    labels.view(-1))
                outputs = (loss, ) + outputs
            else:
                log_likelihood = self.tagger(inputs=logits,
                                             tags=labels,
                                             mask=attention_mask)
                loss = -log_likelihood
                outputs = (loss, ) + outputs
        return outputs
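
The loss branch above only scores positions where attention_mask is 1, which keeps padding tokens out of the cross-entropy. The same masking trick in isolation, with dummy tensors:

import torch
from torch.nn import CrossEntropyLoss

logits = torch.randn(2, 4, 3)                       # (bsz, seq_len, num_labels)
labels = torch.randint(0, 3, (2, 4))
attention_mask = torch.tensor([[1, 1, 0, 0],
                               [1, 1, 1, 0]])

active = attention_mask.view(-1) == 1
loss = CrossEntropyLoss()(logits.view(-1, 3)[active],
                          labels.view(-1)[active])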
Example #12
def main():
    parser = argparse.ArgumentParser(
        description='Train the individual Transformer model')
    parser.add_argument('--dataset_folder', type=str, default='datasets')
    parser.add_argument('--dataset_name', type=str, default='zara1')
    parser.add_argument('--obs', type=int, default=8)
    parser.add_argument('--preds', type=int, default=12)
    parser.add_argument('--emb_size', type=int, default=1024)
    parser.add_argument('--heads', type=int, default=8)
    parser.add_argument('--layers', type=int, default=6)
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--cpu', action='store_true')
    parser.add_argument('--output_folder', type=str, default='Output')
    parser.add_argument('--val_size', type=int, default=50)
    parser.add_argument('--gpu_device', type=str, default="0")
    parser.add_argument('--verbose', action='store_true')
    parser.add_argument('--max_epoch', type=int, default=100)
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--validation_epoch_start', type=int, default=30)
    parser.add_argument('--resume_train', action='store_true')
    parser.add_argument('--delim', type=str, default='\t')
    parser.add_argument('--name', type=str, default="zara1")

    args = parser.parse_args()
    model_name = args.name

    os.makedirs(f'models/BERT/{args.name}', exist_ok=True)
    os.makedirs(f'output/BERT/{args.name}', exist_ok=True)

    log = SummaryWriter('logs/BERT_%s' % model_name)

    log.add_scalar('eval/mad', 0, 0)
    log.add_scalar('eval/fad', 0, 0)

    os.makedirs(args.name, exist_ok=True)

    device = torch.device("cuda")
    if args.cpu or not torch.cuda.is_available():
        device = torch.device("cpu")

    args.verbose = True

    ## creation of the dataloaders for train and validation
    train_dataset, _ = baselineUtils.create_dataset(args.dataset_folder,
                                                    args.dataset_name,
                                                    0,
                                                    args.obs,
                                                    args.preds,
                                                    delim=args.delim,
                                                    train=True,
                                                    verbose=args.verbose)
    val_dataset, _ = baselineUtils.create_dataset(args.dataset_folder,
                                                  args.dataset_name,
                                                  0,
                                                  args.obs,
                                                  args.preds,
                                                  delim=args.delim,
                                                  train=False,
                                                  verbose=args.verbose)
    test_dataset, _ = baselineUtils.create_dataset(args.dataset_folder,
                                                   args.dataset_name,
                                                   0,
                                                   args.obs,
                                                   args.preds,
                                                   delim=args.delim,
                                                   train=False,
                                                   eval=True,
                                                   verbose=args.verbose)

    from transformers import BertTokenizer, BertModel, BertForMaskedLM, BertConfig, AdamW

    config = BertConfig(vocab_size=30522,
                        hidden_size=768,
                        num_hidden_layers=12,
                        num_attention_heads=12,
                        intermediate_size=3072,
                        hidden_act='relu',
                        hidden_dropout_prob=0.1,
                        attention_probs_dropout_prob=0.1,
                        max_position_embeddings=512,
                        type_vocab_size=2,
                        initializer_range=0.02,
                        layer_norm_eps=1e-12)
    model = BertModel(config).to(device)

    from individual_TF import LinearEmbedding as NewEmbed, Generator as GeneratorTS
    a = NewEmbed(3, 768).to(device)
    model.set_input_embeddings(a)
    generator = GeneratorTS(768, 2).to(device)
    #model.set_output_embeddings(GeneratorTS(1024,2))

    tr_dl = torch.utils.data.DataLoader(train_dataset,
                                        batch_size=args.batch_size,
                                        shuffle=True,
                                        num_workers=0)
    val_dl = torch.utils.data.DataLoader(val_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=0)
    test_dl = torch.utils.data.DataLoader(test_dataset,
                                          batch_size=args.batch_size,
                                          shuffle=False,
                                          num_workers=0)

    #optim = SGD(list(a.parameters())+list(model.parameters())+list(generator.parameters()),lr=0.01)
    #sched=torch.optim.lr_scheduler.StepLR(optim,0.0005)
    optim = NoamOpt(
        768, 0.1, len(tr_dl),
        torch.optim.Adam(list(a.parameters()) + list(model.parameters()) +
                         list(generator.parameters()),
                         lr=0,
                         betas=(0.9, 0.98),
                         eps=1e-9))
    #optim=Adagrad(list(a.parameters())+list(model.parameters())+list(generator.parameters()),lr=0.01,lr_decay=0.001)
    epoch = 0

    mean = train_dataset[:]['src'][:, :, 2:4].mean((0, 1)) * 0
    std = train_dataset[:]['src'][:, :, 2:4].std((0, 1)) * 0 + 1

    while epoch < args.max_epoch:
        epoch_loss = 0
        model.train()

        for id_b, batch in enumerate(tr_dl):

            optim.optimizer.zero_grad()
            r = 0
            rot_mat = np.array([[np.cos(r), np.sin(r)],
                                [-np.sin(r), np.cos(r)]])

            inp = ((batch['src'][:, :, 2:4] - mean) / std).to(device)
            inp = torch.matmul(inp,
                               torch.from_numpy(rot_mat).float().to(device))
            trg_masked = torch.zeros((inp.shape[0], args.preds, 2)).to(device)
            inp_cls = torch.ones(inp.shape[0], inp.shape[1], 1).to(device)
            trg_cls = torch.zeros(trg_masked.shape[0], trg_masked.shape[1],
                                  1).to(device)
            inp_cat = torch.cat((inp, trg_masked), 1)
            cls_cat = torch.cat((inp_cls, trg_cls), 1)
            net_input = torch.cat((inp_cat, cls_cat), 2)

            position = torch.arange(0, net_input.shape[1]).repeat(
                inp.shape[0], 1).long().to(device)
            token = torch.zeros(
                (inp.shape[0], net_input.shape[1])).long().to(device)
            attention_mask = torch.ones(
                (inp.shape[0], net_input.shape[1])).long().to(device)

            out = model(input_ids=net_input,
                        position_ids=position,
                        token_type_ids=token,
                        attention_mask=attention_mask)

            pred = generator(out[0])

            loss = F.pairwise_distance(
                pred[:, :].contiguous().view(-1, 2),
                torch.matmul(
                    torch.cat(
                        (batch['src'][:, :, 2:4], batch['trg'][:, :, 2:4]),
                        1).contiguous().view(-1, 2).to(device),
                    torch.from_numpy(rot_mat).float().to(device))).mean()
            loss.backward()
            optim.step()
            print("epoch %03i/%03i  frame %04i / %04i loss: %7.4f" %
                  (epoch, args.max_epoch, id_b, len(tr_dl), loss.item()))
            epoch_loss += loss.item()
        #sched.step()
        log.add_scalar('Loss/train', epoch_loss / len(tr_dl), epoch)
        with torch.no_grad():
            model.eval()

            gt = []
            pr = []
            val_loss = 0
            for batch in val_dl:
                inp = ((batch['src'][:, :, 2:4] - mean) / std).to(device)
                trg_masked = torch.zeros(
                    (inp.shape[0], args.preds, 2)).to(device)
                inp_cls = torch.ones(inp.shape[0], inp.shape[1], 1).to(device)
                trg_cls = torch.zeros(trg_masked.shape[0], trg_masked.shape[1],
                                      1).to(device)
                inp_cat = torch.cat((inp, trg_masked), 1)
                cls_cat = torch.cat((inp_cls, trg_cls), 1)
                net_input = torch.cat((inp_cat, cls_cat), 2)

                position = torch.arange(0, net_input.shape[1]).repeat(
                    inp.shape[0], 1).long().to(device)
                token = torch.zeros(
                    (inp.shape[0], net_input.shape[1])).long().to(device)
                attention_mask = torch.zeros(
                    (inp.shape[0], net_input.shape[1])).long().to(device)

                out = model(input_ids=net_input,
                            position_ids=position,
                            token_type_ids=token,
                            attention_mask=attention_mask)

                pred = generator(out[0])

                loss = F.pairwise_distance(
                    pred[:, :].contiguous().view(-1, 2),
                    torch.cat(
                        (batch['src'][:, :, 2:4], batch['trg'][:, :, 2:4]),
                        1).contiguous().view(-1, 2).to(device)).mean()
                val_loss += loss.item()

                gt_b = batch['trg'][:, :, 0:2]
                preds_tr_b = pred[:, args.obs:].cumsum(1).to(
                    'cpu').detach() + batch['src'][:, -1:, 0:2]
                gt.append(gt_b)
                pr.append(preds_tr_b)

            gt = np.concatenate(gt, 0)
            pr = np.concatenate(pr, 0)
            mad, fad, errs = baselineUtils.distance_metrics(gt, pr)
            log.add_scalar('validation/loss', val_loss / len(val_dl), epoch)
            log.add_scalar('validation/mad', mad, epoch)
            log.add_scalar('validation/fad', fad, epoch)

            model.eval()

            gt = []
            pr = []
            for batch in test_dl:
                inp = ((batch['src'][:, :, 2:4] - mean) / std).to(device)
                trg_masked = torch.zeros(
                    (inp.shape[0], args.preds, 2)).to(device)
                inp_cls = torch.ones(inp.shape[0], inp.shape[1], 1).to(device)
                trg_cls = torch.zeros(trg_masked.shape[0], trg_masked.shape[1],
                                      1).to(device)
                inp_cat = torch.cat((inp, trg_masked), 1)
                cls_cat = torch.cat((inp_cls, trg_cls), 1)
                net_input = torch.cat((inp_cat, cls_cat), 2)

                position = torch.arange(0, net_input.shape[1]).repeat(
                    inp.shape[0], 1).long().to(device)
                token = torch.zeros(
                    (inp.shape[0], net_input.shape[1])).long().to(device)
                attention_mask = torch.zeros(
                    (inp.shape[0], net_input.shape[1])).long().to(device)

                out = model(input_ids=net_input,
                            position_ids=position,
                            token_type_ids=token,
                            attention_mask=attention_mask)

                pred = generator(out[0])

                gt_b = batch['trg'][:, :, 0:2]
                preds_tr_b = pred[:, args.obs:].cumsum(1).to(
                    'cpu').detach() + batch['src'][:, -1:, 0:2]
                gt.append(gt_b)
                pr.append(preds_tr_b)

            gt = np.concatenate(gt, 0)
            pr = np.concatenate(pr, 0)
            mad, fad, errs = baselineUtils.distance_metrics(gt, pr)

            torch.save(model.state_dict(),
                       "models/BERT/%s/ep_%03i.pth" % (args.name, epoch))
            torch.save(generator.state_dict(),
                       "models/BERT/%s/gen_%03i.pth" % (args.name, epoch))
            torch.save(a.state_dict(),
                       "models/BERT/%s/emb_%03i.pth" % (args.name, epoch))

            log.add_scalar('eval/mad', mad, epoch)
            log.add_scalar('eval/fad', fad, epoch)

        epoch += 1

    ab = 1
Example #13
class BertEncoder(MetaModule):
    """BERT model as presented in Google's paper and using Hugging Face's code

    References:
        https://arxiv.org/abs/1810.04805
    """

    class Config(BaseConfig):
        model_name: Union[str, Path] = 'bert-base-multilingual-cased'
        """Pre-trained BERT model to use."""

        use_mismatch_features: bool = False
        """Use Alibaba's mismatch features."""

        use_predictor_features: bool = False
        """Use features originally proposed in the Predictor model."""

        interleave_input: bool = False
        """Concatenate SOURCE and TARGET without internal padding
        (111222000 instead of 111002220)"""

        freeze: bool = False
        """Freeze BERT during training."""

        use_mlp: bool = True
        """Apply a linear layer on top of BERT."""

        hidden_size: int = 100
        """Size of the linear layer on top of BERT."""

        scalar_mix_dropout: confloat(ge=0.0, le=1.0) = 0.1
        scalar_mix_layer_norm: bool = True

        @validator('model_name', pre=True)
        def fix_relative_path(cls, v):
            if (
                v not in BERT_PRETRAINED_MODEL_ARCHIVE_LIST
                and v not in DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST
            ):
                v = Path(v)
                if not v.is_absolute():
                    v = Path.cwd().joinpath(v)
            return v

        @validator('use_mismatch_features', 'use_predictor_features', pre=True)
        def no_implementation(cls, v):
            if v:
                raise NotImplementedError('Not yet implemented')
            return False

    def __init__(
        self, vocabs: Dict[str, Vocabulary], config: Config, pre_load_model: bool = True
    ):
        super().__init__(config=config)

        if pre_load_model:
            self.bert = BertModel.from_pretrained(
                self.config.model_name, output_hidden_states=True
            )
        else:
            bert_config = BertConfig.from_pretrained(
                self.config.model_name, output_hidden_states=True
            )
            self.bert = BertModel(bert_config)

        self.vocabs = {
            const.TARGET: vocabs[const.TARGET],
            const.SOURCE: vocabs[const.SOURCE],
        }

        self.mlp = None
        if self.config.use_mlp:
            self.mlp = nn.Sequential(
                nn.Linear(self.bert.config.hidden_size, self.config.hidden_size),
                nn.Tanh(),
            )
            output_size = self.config.hidden_size
        else:
            output_size = self.bert.config.hidden_size

        self.scalar_mix = ScalarMixWithDropout(
            mixture_size=self.bert.config.num_hidden_layers + 1,  # +1 for embeddings
            do_layer_norm=self.config.scalar_mix_layer_norm,
            dropout=self.config.scalar_mix_dropout,
        )

        self._sizes = {
            const.TARGET: output_size,
            const.TARGET_LOGITS: output_size,
            const.TARGET_SENTENCE: self.bert.config.hidden_size,
            const.SOURCE: output_size,
        }

        self.output_embeddings = self.bert.embeddings.word_embeddings

        if self.config.freeze:
            for param in self.bert.parameters():
                param.requires_grad = False

    def load_state_dict(
        self,
        state_dict: Union[Dict[str, Tensor], Dict[str, Tensor]],
        strict: bool = True,
    ):
        try:
            keys = super().load_state_dict(state_dict, strict)
        except RuntimeError as e:
            if "position_ids" in str(e):
                # FIXME: hack to get around Transformers 3.1 breaking changes
                # https://github.com/huggingface/transformers/issues/6882
                self.bert.embeddings._non_persistent_buffers_set.add('position_ids')
                keys = super().load_state_dict(state_dict, strict)
                self.bert.embeddings._non_persistent_buffers_set.discard('position_ids')
            else:
                raise e
        return keys

    @classmethod
    def input_data_encoders(cls, config: Config):
        return {
            const.SOURCE: TransformersTextEncoder(
                tokenizer_name=config.model_name, is_source=True
            ),
            const.TARGET: TransformersTextEncoder(tokenizer_name=config.model_name),
        }

    def size(self, field=None):
        if field:
            return self._sizes[field]
        return self._sizes

    def forward(
        self,
        batch_inputs,
        *args,
        include_target_logits=False,
        include_source_logits=False
    ):
        # BERT gets its input either as a concatenation of both sequences
        # or as an interleaving of them
        if self.config.interleave_input:
            merge_input_fn = self.interleave_input
        else:
            merge_input_fn = self.concat_input

        input_ids, token_type_ids, attention_mask = merge_input_fn(
            batch_inputs[const.SOURCE],
            batch_inputs[const.TARGET],
            pad_id=self.vocabs[const.TARGET].pad_id,
        )

        # hidden_states also includes the embedding layer
        # hidden_states[-1] is the last layer
        last_hidden_state, pooler_output, hidden_states = self.bert(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
        )

        # TODO: select one of these strategies via cli
        # TODO: get a BETTER strategy
        features = self.scalar_mix(hidden_states, attention_mask)

        # features = sum(hidden_states[-5:-1])
        # features = hidden_states[-2]

        if self.config.use_mlp:
            features = self.mlp(features)

        # Build the feature dictionary to be returned to the system
        output_features = self.split_outputs(
            features, batch_inputs, interleaved=self.config.interleave_input
        )

        # Convert pieces to tokens
        target_features = pieces_to_tokens(
            output_features[const.TARGET], batch_inputs[const.TARGET]
        )
        source_features = pieces_to_tokens(
            output_features[const.SOURCE], batch_inputs[const.SOURCE]
        )

        # sentence_features = pooler_output
        sentence_features = last_hidden_state.mean(dim=1)

        # Substitute CLS on target side
        # target_features[:, 0] = 0

        output_features[const.TARGET] = target_features
        output_features[const.SOURCE] = source_features
        output_features[const.TARGET_SENTENCE] = sentence_features

        # Logits for multi-task fine-tuning
        if include_target_logits:
            output_features[const.TARGET_LOGITS] = torch.einsum(
                'vh,bsh->bsv',
                self.output_embeddings.weight,
                output_features[const.TARGET],
            )
        if include_source_logits:
            output_features[const.SOURCE_LOGITS] = torch.einsum(
                'vh,bsh->bsv',
                self.output_embeddings.weight,
                output_features[const.SOURCE],
            )

        # Additional features
        if self.config.use_mismatch_features:
            raise NotImplementedError

        return output_features

    @staticmethod
    def concat_input(source_batch, target_batch, pad_id):
        """Concatenate the target + source embeddings into one tensor.

        Return:
             concatenation of embeddings, mask of target (as ones) and source
                 (as zeroes) and concatenation of attention_mask
        """
        source_ids = source_batch.tensor
        target_ids = target_batch.tensor

        source_attention_mask = retrieve_tokens_mask(source_batch)
        target_attention_mask = retrieve_tokens_mask(target_batch)

        target_types = torch.zeros_like(target_ids)
        # zero denotes first sequence
        source_types = torch.ones_like(source_ids)
        input_ids = torch.cat((target_ids, source_ids), dim=1)
        token_type_ids = torch.cat((target_types, source_types), dim=1)
        attention_mask = torch.cat(
            (target_attention_mask, source_attention_mask), dim=1
        )
        return input_ids, token_type_ids, attention_mask

    @staticmethod
    def split_outputs(
        features: Tensor, batch_inputs: MultiFieldBatch, interleaved: bool = False
    ) -> Dict[str, Tensor]:
        """Split features back into sentences A and B.

        Args:
            features: BERT's output: ``[CLS] target [SEP] source [SEP]``.
                Shape of (bs, 1 + target_len + 1 + source_len + 1, 2)
            batch_inputs: the regular batch object, containing ``source`` and ``target``
                batches
            interleaved: whether the concat strategy was interleaved

        Return:
            dict of tensors for ``source`` and ``target``.
        """
        outputs = OrderedDict()

        target_lengths = batch_inputs[const.TARGET].lengths

        if interleaved:
            raise NotImplementedError('interleaving not supported.')
            # TODO: fix code below to use the lengths information and not bounds
            # if interleaved, shift each source sample by its correspondent length
            shift = target_lengths.unsqueeze(-1)

            range_vector = torch.arange(
                features.size(0), device=features.device
            ).unsqueeze(1)

            target_bounds = batch_inputs[const.TARGET].bounds
            target_features = features[range_vector, target_bounds]
            # Shift bounds by target length and preserve padding
            source_bounds = batch_inputs[const.SOURCE].bounds
            m = (source_bounds != -1).long()  # for masking out padding (which is -1)
            shifted_bounds = (source_bounds + shift) * m + source_bounds * (1 - m)
            source_features = features[range_vector, shifted_bounds]
        else:
            # otherwise, shift all by max_length
            # if we'd like to maintain the word pieces we merely select all
            target_features = features[:, : target_lengths.max()]
            # ignore the target and get the rest
            source_features = features[:, target_lengths.max() :]

        outputs[const.TARGET] = target_features

        # Source doesn't have an init_token (like CLS) and we keep SEP
        outputs[const.SOURCE] = source_features

        return outputs

    # TODO this strategy is not being used, should we keep it?
    @staticmethod
    def interleave_input(source_batch, target_batch, pad_id):
        """Interleave the source + target embeddings into one tensor.

        This means making the input as [batch, target [SEP] source].

        Return:
            interleave of embds, mask of target (as zeroes) and source (as ones)
                and concatenation of attention_mask.
        """
        source_ids = source_batch.tensor
        target_ids = target_batch.tensor

        batch_size = source_ids.size(0)

        source_lengths = source_batch.lengths
        target_lengths = target_batch.lengths

        max_pair_length = source_ids.size(1) + target_ids.size(1)

        input_ids = torch.full(
            (batch_size, max_pair_length),
            pad_id,
            dtype=torch.long,
            device=source_ids.device,
        )
        token_type_ids = torch.zeros_like(input_ids)
        attention_mask = torch.zeros_like(input_ids)

        for i in range(batch_size):
            # [CLS] and [SEP] are included in the mask (=1)
            # note: source does not have CLS
            t_len = target_lengths[i].item()
            s_len = source_lengths[i].item()

            input_ids[i, :t_len] = target_ids[i, :t_len]
            token_type_ids[i, :t_len] = 0
            attention_mask[i, :t_len] = 1

            input_ids[i, t_len : t_len + s_len] = source_ids[i, :s_len]
            token_type_ids[i, t_len : t_len + s_len] = 1
            attention_mask[i, t_len : t_len + s_len] = 1

        # TODO, why is attention mask 1 for all positions?
        return input_ids, token_type_ids, attention_mask

    @staticmethod
    def get_mismatch_features(logits, target, pred):
        # calculate mismatch features and concat them
        t_max = torch.gather(logits, -1, target.unsqueeze(-1))
        p_max = torch.gather(logits, -1, pred.unsqueeze(-1))
        diff_max = t_max - p_max
        diff_arg = (target != pred).float().unsqueeze(-1)
        mismatch = torch.cat((t_max, p_max, diff_max, diff_arg), dim=-1)
        return mismatch
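
For reference, concat_input above lays the pair out as target first (token type 0) followed by source (token type 1). A tiny illustration with plain tensors standing in for the Batch objects used in the real code:

import torch

target_ids = torch.tensor([[101, 7592, 102]])   # [CLS] token [SEP]
source_ids = torch.tensor([[2023, 102]])        # token [SEP] (no CLS)

input_ids = torch.cat((target_ids, source_ids), dim=1)
token_type_ids = torch.cat((torch.zeros_like(target_ids),
                            torch.ones_like(source_ids)), dim=1)
# input_ids      -> [[ 101, 7592, 102, 2023, 102]]
# token_type_ids -> [[   0,    0,   0,    1,   1]]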
Example #14
class CSER(BertPreTrainedModel):
    """ Span-based model to extract entities """
    def __init__(self, config: BertConfig, cls_token: int, entity_types: int,
                 size_embedding: int, prop_drop: float,
                 freeze_transformer: bool):  # noqa

        super(CSER, self).__init__(config)

        # BERT model
        self.bert = BertModel(config)

        # layers
        self.entity_classifier = nn.Linear(
            config.hidden_size * 2 + size_embedding, entity_types)
        self.size_embeddings = nn.Embedding(100, size_embedding)
        self.dropout = nn.Dropout(prop_drop)

        self._cls_token = cls_token
        self._entity_types = entity_types

        # weight initialization
        self.init_weights()

        if freeze_transformer:
            print("Freeze transformer weights")

            # freeze all transformer weights
            for param in self.bert.parameters():
                param.requires_grad = False

    def _forward_eval(self, encodings: torch.tensor,
                      context_masks: torch.tensor, entity_masks: torch.tensor,
                      entity_sizes: torch.tensor, entity_spans: torch.tensor,
                      entity_sample_masks: torch.tensor):  # noqa
        # get contextualized token embeddings from last transformer layer
        context_masks = context_masks.float()
        h = self.bert(input_ids=encodings, attention_mask=context_masks)[0]

        # classify entities
        size_embeddings = self.size_embeddings(
            entity_sizes)  # embed entity candidate sizes
        entity_clf, entity_spans_pool = self._classify_entities(
            encodings, h, entity_masks, size_embeddings)

        # apply softmax
        entity_clf = torch.softmax(entity_clf, dim=2)

        return entity_clf

    def _classify_entities(self, encodings, h, entity_masks, size_embeddings):
        # max pool entity candidate spans
        m = (entity_masks.unsqueeze(-1) == 0).float() * (-1e30)
        entity_spans_pool = m + h.unsqueeze(1).repeat(1, entity_masks.shape[1],
                                                      1, 1)
        entity_spans_pool = entity_spans_pool.max(dim=2)[0]

        # get cls token as candidate context representation
        entity_ctx = get_token(h, encodings, self._cls_token)

        # create candidate representations including context, max pooled span and size embedding
        entity_repr = torch.cat([
            entity_ctx.unsqueeze(1).repeat(1, entity_spans_pool.shape[1], 1),
            entity_spans_pool, size_embeddings
        ],
                                dim=2)
        entity_repr = self.dropout(entity_repr)

        # classify entity candidates
        entity_clf = self.entity_classifier(entity_repr)

        return entity_clf, entity_spans_pool

    def forward(self, *args, **kwargs):
        return self._forward_eval(*args, **kwargs)
class SpEER(BertPreTrainedModel):
    """ Span-based model to jointly extract entities and relations """
    def __init__(self,
                 config: BertConfig,
                 cls_token: int,
                 relation_types: int,
                 entity_types: int,
                 size_embedding: int,
                 prop_drop: float,
                 freeze_transformer: bool,
                 max_pairs: int = 100,
                 encoding_size: int = 200,
                 feature_enhancer: str = "pass"):
        super(SpEER, self).__init__(config)

        # BERT model
        self.bert = BertModel(config)

        # layers
        self.encoding_size = encoding_size
        self.feature_enhancer = fe.get_feature_enhancer(feature_enhancer)(
            config.hidden_size, config.hidden_size)
        self.rel_encoder = nn.Linear(
            config.hidden_size * 3 + size_embedding * 2, encoding_size)
        self.entity_encoder = nn.Linear(
            config.hidden_size * 2 + size_embedding, encoding_size)
        self.size_embeddings = nn.Embedding(100, size_embedding)
        self.dropout = nn.Dropout(prop_drop)

        self._cls_token = cls_token
        self._relation_types = relation_types
        self._entity_types = entity_types
        self._max_pairs = max_pairs

        # weight initialization
        self.init_weights()

        if freeze_transformer or feature_enhancer not in {
                "pass", "transformer"
        }:
            print("Freeze transformer weights")

            # freeze all transformer weights
            for param in self.bert.parameters():
                param.requires_grad = False

    def _forward_train(self, encodings: torch.tensor,
                       context_mask: torch.tensor, entity_masks: torch.tensor,
                       entity_sizes: torch.tensor, relations: torch.tensor,
                       rel_masks: torch.tensor):
        # get contextualized token embeddings from last transformer layer
        context_mask = context_mask.float()
        h = self.bert(input_ids=encodings, attention_mask=context_mask)[0]

        # enhance hidden features
        orig_shape = h.shape
        h = self.feature_enhancer.prepare_input(h, context_mask)
        h = self.feature_enhancer(h)
        h = self.feature_enhancer.prepare_output(h, orig_shape)

        entity_masks = entity_masks.float()
        batch_size = encodings.shape[0]
        device = self.entity_encoder.weight.device

        # encode and classify entities
        size_embeddings = self.size_embeddings(
            entity_sizes)  # embed entity candidate sizes
        entity_encoding, entity_spans_pool = self._encode_entities(
            encodings, h, entity_masks, size_embeddings)
        entity_clf = self._classify_entities(entity_encoding)

        # prepare relation encoding
        rel_masks = rel_masks.float().unsqueeze(-1)
        h_large = h.unsqueeze(1).repeat(
            1, max(min(relations.shape[1], self._max_pairs), 1), 1, 1)
        rel_encoding = torch.zeros(
            [batch_size, relations.shape[1], self.encoding_size]).to(device)

        # obtain relation encodings
        # chunk processing to reduce memory usage
        for i in range(0, relations.shape[1], self._max_pairs):
            # classify relation candidates
            rel_encoding_chunk = self._encode_relations(
                entity_spans_pool, size_embeddings, relations, rel_masks,
                h_large, i)
            rel_encoding[:, i:i + self._max_pairs, :] = rel_encoding_chunk

        rel_clf = self._classify_relations(rel_encoding)

        return entity_clf, rel_clf

    def _forward_eval(self,
                      entity_knn_module,
                      rel_knn_module,
                      entity_entries: List[List[Dict]],
                      type_key: str,
                      encodings: torch.tensor,
                      context_mask: torch.tensor,
                      entity_masks: torch.tensor,
                      entity_sizes: torch.tensor,
                      entity_spans: torch.tensor = None,
                      entity_sample_mask: torch.tensor = None,
                      verbose=False):
        # get contextualized token embeddings from last transformer layer
        context_mask = context_mask.float()
        h = self.bert(input_ids=encodings, attention_mask=context_mask)[0]

        # enhance hidden features
        orig_shape = h.shape
        h = self.feature_enhancer.prepare_input(h, context_mask)
        h = self.feature_enhancer(h)
        h = self.feature_enhancer.prepare_output(h, orig_shape)

        entity_masks = entity_masks.float()
        batch_size = encodings.shape[0]
        ctx_size = context_mask.shape[-1]
        device = self.entity_encoder.weight.device

        # encode and classify entities
        size_embeddings = self.size_embeddings(
            entity_sizes)  # embed entity candidate sizes
        entity_encoding, entity_spans_pool = self._encode_entities(
            encodings, h, entity_masks, size_embeddings)
        entity_encoding_reshaped = entity_encoding.view(
            entity_encoding.shape[0] * entity_encoding.shape[1], -1).cpu()
        entity_types, entity_neighbors = entity_knn_module.infer_(
            entity_encoding_reshaped, int, type_key)

        # for i, neighbors in enumerate(entity_neighbors):
        #     print(entity_types[i], neighbors)

        # print neighbor entities
        if verbose:
            print('*' * 50)
            print("entity neighbors:")
            entity_entries_flat = []
            for entry in entity_entries:
                entity_entries_flat += entry
            for i, neighbors in enumerate(entity_neighbors):
                if entity_types[i] == 0:
                    continue
                print("[ENT] {} >> {}".format(
                    entity_entries_flat[i]["phrase"].encode('utf-8'),
                    entity_types[i]))
                for j in range(min(len(neighbors), 5)):
                    n = neighbors[j]
                    print("\t", n["phrase"].encode('utf-8'), n["type_string"],
                          n[type_key])
                print()

        entity_types = torch.tensor(entity_types).view(
            entity_encoding.shape[0], entity_encoding.shape[1]).to(device)
        entity_clf = torch.zeros([
            entity_encoding.shape[0], entity_encoding.shape[1],
            self._entity_types
        ],
                                 dtype=torch.long).to(device)
        entity_clf.scatter_(2, entity_types.unsqueeze(2), 1)

        # ignore entity candidates that do not constitute an actual entity for relations (based on classifier)
        relations, rel_masks, rel_sample_masks, rel_entries = self._filter_spans(
            entity_clf, entity_spans, entity_sample_mask, entity_entries,
            ctx_size, type_key)
        rel_masks = rel_masks.float()
        rel_sample_masks = rel_sample_masks.float()
        h_large = h.unsqueeze(1).repeat(
            1, max(min(relations.shape[1], self._max_pairs), 1), 1, 1)
        rel_encoding = torch.zeros(
            [batch_size, relations.shape[1], self.encoding_size]).to(device)

        # obtain relation encodings
        # chunk processing to reduce memory usage
        for i in range(0, relations.shape[1], self._max_pairs):
            # classify relation candidates
            rel_encoding_chunk = self._encode_relations(
                entity_spans_pool, size_embeddings, relations, rel_masks,
                h_large, i)
            rel_encoding[:, i:i + self._max_pairs, :] = rel_encoding_chunk
        rel_encoding_reshaped = rel_encoding.view(
            rel_encoding.shape[0] * rel_encoding.shape[1], -1).cpu()

        # encode and classify relations
        rel_types, rel_neighbors = rel_knn_module.infer_(
            rel_encoding_reshaped, int, type_key)

        # print neighbor relations
        if verbose:
            print('*' * 50)
            rel_entries_flat = []
            for entry in rel_entries:
                rel_entries_flat += entry
            for i, neighbors in enumerate(rel_neighbors):
                if rel_types[i] == 0:
                    continue
                print("[REL] {} >> {}".format(
                    rel_entries_flat[i]["phrase"].encode('utf-8'),
                    rel_types[i]))
                for j in range(min(len(neighbors), 5)):
                    n = neighbors[j]
                    print("\t", n["phrase"].encode('utf-8'), n["type_string"],
                          n[type_key])
                print()

        rel_types = torch.LongTensor(rel_types).view(
            rel_encoding.shape[0], rel_encoding.shape[1]).to(device)
        rel_clf = torch.zeros([
            rel_encoding.shape[0], rel_encoding.shape[1], self._relation_types
        ],
                              dtype=torch.float32).to(device)

        rel_clf.scatter_(2, rel_types.unsqueeze(2), 1)
        rel_clf = rel_clf[:, :,
                          1:]  # exclude 'none' prediction for multi-label prediction

        rel_clf = rel_clf * rel_sample_masks  # mask

        return entity_clf, rel_clf, relations

    def _forward_encode(self, encodings: torch.tensor,
                        context_mask: torch.tensor, entity_masks: torch.tensor,
                        entity_sizes: torch.tensor, relations: torch.tensor,
                        rel_masks: torch.tensor):
        # get contextualized token embeddings from last transformer layer
        context_mask = context_mask.float()
        h = self.bert(input_ids=encodings, attention_mask=context_mask)[0]

        entity_masks = entity_masks.float()
        batch_size = encodings.shape[0]
        device = self.entity_encoder.weight.device

        # encode and classify entities
        size_embeddings = self.size_embeddings(
            entity_sizes)  # embed entity candidate sizes
        entity_encoding, entity_spans_pool = self._encode_entities(
            encodings, h, entity_masks, size_embeddings)

        # prepare relation encoding
        rel_masks = rel_masks.float().unsqueeze(-1)
        h_large = h.unsqueeze(1).repeat(
            1, max(min(relations.shape[1], self._max_pairs), 1), 1, 1)
        rel_encoding = torch.zeros(
            [batch_size, relations.shape[1], self.encoding_size]).to(device)

        # obtain relation encodings
        # chunk processing to reduce memory usage
        for i in range(0, relations.shape[1], self._max_pairs):
            # classify relation candidates
            rel_encoding_chunk = self._encode_relations(
                entity_spans_pool, size_embeddings, relations, rel_masks,
                h_large, i)
            rel_encoding[:, i:i + self._max_pairs, :] = rel_encoding_chunk

        return entity_encoding, rel_encoding

    def _classify_entities(self, entity_encoding, verification=False):
        # cosine similarities of every possible entity encoding pair in the batch
        cosine_similarities = torch.einsum('abc, ijc -> abij', entity_encoding,
                                           entity_encoding)

        # einsum verification (at least each element has similarity 1 with itself)
        if verification:
            with torch.no_grad():
                is_close_bools = cosine_similarities.isclose(
                    torch.tensor([1.00],
                                 device=self.entity_encoder.weight.device))
                is_close_sum = is_close_bools.int().sum().item()
                assert (is_close_sum >=
                        entity_encoding.shape[0] * entity_encoding.shape[1])

        # normalize cosine similarity from [-1, 1] to [0, 1], and clip float precision errors
        normalized_similarities = (cosine_similarities + 1) / 2
        normalized_similarities = normalized_similarities.clamp(0, 1)

        return normalized_similarities

    def _encode_entities(self, encodings, h, entity_masks, size_embeddings):
        # max pool entity candidate spans
        entity_spans_pool = entity_masks.unsqueeze(-1) * h.unsqueeze(1)
        entity_spans_pool = entity_spans_pool.max(dim=2)[0]

        # get cls token as candidate context representation
        entity_ctx = get_token(h, encodings, self._cls_token)

        # create candidate representations including context, max pooled span and size embedding
        entity_repr = torch.cat([
            entity_ctx.unsqueeze(1).repeat(1, entity_spans_pool.shape[1], 1),
            entity_spans_pool, size_embeddings
        ],
                                dim=2)
        entity_repr = self.dropout(entity_repr)

        # encode entity candidates
        entity_encoding = self.entity_encoder(entity_repr)

        # normalize encoding to unit length for cosine similarity
        entity_encoding = f.normalize(entity_encoding, dim=2, p=2)

        return entity_encoding, entity_spans_pool

    def _classify_relations(self, rel_encoding, verification=False):
        # cosine similarity of every possible relation encoding pair in the batch
        cosine_similarities = torch.einsum('abc, ijc -> abij', rel_encoding,
                                           rel_encoding)

        # einsum verification (at least each element has similarity 1 with itself)
        if verification:
            with torch.no_grad():
                is_close_bools = cosine_similarities.isclose(
                    torch.tensor([1.00],
                                 device=self.rel_encoder.weight.device))
                is_close_sum = is_close_bools.int().sum().item()
                assert (is_close_sum >=
                        rel_encoding.shape[0] * rel_encoding.shape[1])

        # normalize cosine similarity from [-1, 1] to [0, 1], and clip float precision errors
        normalized_similarities = (cosine_similarities + 1) / 2
        normalized_similarities = normalized_similarities.clamp(0, 1)

        return normalized_similarities

    def _encode_relations(self, entity_spans, size_embeddings, relations,
                          rel_masks, h, chunk_start):
        batch_size = relations.shape[0]

        # create chunks if necessary
        if relations.shape[1] > self._max_pairs:
            relations = relations[:, chunk_start:chunk_start + self._max_pairs]
            rel_masks = rel_masks[:, chunk_start:chunk_start + self._max_pairs]
            h = h[:, :relations.shape[1], :]

        # get pairs of entity candidate representations
        entity_pairs = util.batch_index(entity_spans, relations)
        entity_pairs = entity_pairs.view(batch_size, entity_pairs.shape[1], -1)

        # get corresponding size embeddings
        size_pair_embeddings = util.batch_index(size_embeddings, relations)
        size_pair_embeddings = size_pair_embeddings.view(
            batch_size, size_pair_embeddings.shape[1], -1)

        # relation context (context between entity candidate pair)
        rel_ctx = rel_masks * h
        rel_ctx = rel_ctx.max(dim=2)[0]

        # create relation candidate representations including context, max pooled entity candidate pairs
        # and corresponding size embeddings
        rel_repr = torch.cat([rel_ctx, entity_pairs, size_pair_embeddings],
                             dim=2)
        rel_repr = self.dropout(rel_repr)

        # encode relation candidates
        rel_encoding = self.rel_encoder(rel_repr)

        # normalize encoding to unit length for cosine similarity
        rel_encoding = f.normalize(rel_encoding, dim=2, p=2)

        return rel_encoding

    #TODO: Needs checking of relation entries
    def _filter_spans(self, entity_clf, entity_spans, entity_sample_mask,
                      entity_entries, ctx_size, type_key):
        batch_size = entity_clf.shape[0]
        # get entity type (including none)
        entity_logits_max = entity_clf.argmax(dim=-1) * entity_sample_mask.long()
        batch_relations = []
        batch_rel_masks = []
        batch_rel_sample_masks = []
        batch_rel_entries = []

        for i in range(batch_size):
            rels = []
            rel_masks = []
            sample_masks = []
            rel_entries = []

            # get spans classified as entities
            non_zero_indices = (entity_logits_max[i] != 0).nonzero().view(-1)
            non_zero_spans = entity_spans[i][non_zero_indices].tolist()
            non_zero_entries = [entity_entries[i][j] for j in non_zero_indices]
            non_zero_indices = non_zero_indices.tolist()

            # create relations and masks
            for n, (i1, s1) in enumerate(zip(non_zero_indices,
                                             non_zero_spans)):
                for m, (i2,
                        s2) in enumerate(zip(non_zero_indices,
                                             non_zero_spans)):
                    if i1 != i2:
                        rels.append((i1, i2))
                        phrase = "|{}| <TBD> |{}|".format(
                            non_zero_entries[n]["phrase"],
                            non_zero_entries[m]["phrase"])
                        rel_entries.append({
                            "phrase": phrase,
                            "type_string": "<TBD>",
                            type_key: -1
                        })
                        rel_masks.append(
                            sampling.create_rel_mask(s1, s2, ctx_size))
                        sample_masks.append(1)

            if not rels:
                # case: no more than two spans classified as entities
                batch_relations.append(torch.tensor([[0, 0]],
                                                    dtype=torch.long))
                batch_rel_masks.append(
                    torch.tensor([[0] * ctx_size], dtype=torch.bool))
                batch_rel_sample_masks.append(
                    torch.tensor([0], dtype=torch.bool))
                phrase = ""
                batch_rel_entries.append([{
                    "phrase": phrase,
                    "type_string": "<TBD>",
                    type_key: -1
                }])
            else:
                # case: more than two spans classified as entities
                batch_relations.append(torch.tensor(rels, dtype=torch.long))
                batch_rel_masks.append(torch.stack(rel_masks))
                batch_rel_sample_masks.append(
                    torch.tensor(sample_masks, dtype=torch.bool))
                batch_rel_entries.append(rel_entries)

        # stack
        device = self.rel_encoder.weight.device
        batch_relations = util.padded_stack(batch_relations).to(device)
        batch_rel_masks = util.padded_stack(batch_rel_masks).to(
            device).unsqueeze(-1)
        batch_rel_sample_masks = util.padded_stack(batch_rel_sample_masks).to(
            device).unsqueeze(-1)
        batch_rel_entries = util.padded_entries(batch_rel_entries)

        return batch_relations, batch_rel_masks, batch_rel_sample_masks, batch_rel_entries

    def forward(self, *args, mode="train", **kwargs):
        f_forward = {
            "train": self._forward_train,
            "eval": self._forward_eval,
            "encode": self._forward_encode
        }.get(mode)
        return f_forward(*args, **kwargs)
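
# A minimal, self-contained sketch (illustrative shapes only) of the pairwise cosine
# similarity computed in _classify_entities/_classify_relations above: with unit-length
# encodings, the einsum 'abc, ijc -> abij' is a batched dot product, which is then
# rescaled from [-1, 1] to [0, 1] and clipped against float precision errors.
import torch
import torch.nn.functional as F

enc = F.normalize(torch.randn(2, 3, 8), dim=2, p=2)        # (batch=2, spans=3, dim=8), unit length
cos = torch.einsum('abc, ijc -> abij', enc, enc)           # (2, 3, 2, 3) pairwise cosines
sim = ((cos + 1) / 2).clamp(0, 1)                          # map to [0, 1]
assert torch.allclose(sim[0, 0, 0, 0], torch.tensor(1.0))  # each vector matches itself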
Example #16
class TableF(BertPreTrainedModel):
    """ table filling model to jointly extract entities and relations """

    def __init__(self, config: BertConfig, tokenizer: BertTokenizer,
                 relation_labels: int, entity_labels: int,
                 entity_label_embedding: int,  att_hidden: int,
                 prop_drop: float, freeze_transformer: bool, device):
        super(TableF, self).__init__(config)

        # BERT model
        self.bert = BertModel(config)
        self._tokenizer = tokenizer
        self._device = device
        # layers
        self.entity_label_embedding = nn.Embedding(entity_labels , entity_label_embedding)
        self.entity_classifier = nn.Linear(config.hidden_size * 2 + entity_label_embedding, entity_labels)
        self.rel_classifier = MultiHeadAttention(relation_labels, config.hidden_size + entity_label_embedding, att_hidden , device)
        self.dropout = nn.Dropout(prop_drop)

        self._relation_labels = relation_labels
        self._entity_labels = entity_labels

        # weight initialization
        self.init_weights()

        if freeze_transformer:
            print("Freeze transformer weights")

            # freeze all transformer weights
            for param in self.bert.parameters():
                param.requires_grad = False


    def _forward_token(self, h: torch.tensor, token_mask: torch.tensor, 
                           gold_seq: torch.tensor, entity_mask: torch.tensor):

        num_steps = gold_seq.shape[-1]
        word_h = h.repeat(token_mask.shape[0], 1, 1) * token_mask.unsqueeze(-1)
        word_h_pooled = word_h.max(dim=1)[0]
        word_h_pooled = word_h_pooled[:num_steps+2].contiguous()
        word_h_pooled[0,:] = 0

        # curr word repr.
        curr_word_repr = word_h_pooled[1:-1].contiguous()

        # prev entity repr.
        prev_entity = torch.tril(entity_mask, diagonal=0)

        prev_entity_h = word_h_pooled.repeat(prev_entity.shape[0], 1, 1) * prev_entity.unsqueeze(-1)
        prev_entity_pooled = prev_entity_h.max(dim=1)[0]
        prev_entity_pooled = prev_entity_pooled[:num_steps].contiguous()

        # prev_label_embedding.
        prev_seq = torch.cat([torch.tensor([0]).to(self._device), gold_seq])
        prev_label = self.entity_label_embedding(prev_seq[:-1])

        entity_repr = torch.cat([curr_word_repr - 1, prev_entity_pooled - 1, prev_label], dim=1).unsqueeze(0)

        entity_repr = self.dropout(entity_repr)
        curr_entity_logits = self.entity_classifier(entity_repr)

        return curr_word_repr, curr_entity_logits


    def _forward_relation(self, h: torch.tensor,  entity_preds: torch.tensor, 
                          entity_mask: torch.tensor, is_eval: bool = False):


        entity_labels = entity_preds.unsqueeze(0)
        
        # entity repr.
        masks_no_cls_rep = entity_mask[1:-1, 1:-1]
        entity_repr = h.repeat(masks_no_cls_rep.shape[-1], 1, 1) * masks_no_cls_rep.unsqueeze(-1)
        entity_repr_pool = entity_repr.max(dim=1)[0]

        #entity_label repr.
        entity_label_embeddings = self.entity_label_embedding(entity_labels)        
#         entity_label_embeddings = torch.matmul(entity_preds, self.entity_label_embedding.weight)
       
        entity_label_repr = entity_label_embeddings.repeat(masks_no_cls_rep.shape[-1], 1, 1) * masks_no_cls_rep.unsqueeze(-1)
        entity_label_pool = entity_label_repr.max(dim=1)[0]

        
        
        rel_embedding = torch.cat([entity_repr_pool.unsqueeze(0) - 1, entity_label_pool.unsqueeze(0)], dim=2)
        rel_embedding = self.dropout(rel_embedding)
        rel_logits = self.rel_classifier(rel_embedding, rel_embedding, rel_embedding)

        return rel_logits

    def _forward_train(self, encodings: torch.tensor, context_mask: torch.tensor, 
                        token_mask: torch.tensor, gold_entity: torch.tensor, entity_masks: List[torch.tensor],
                      allow_rel: bool):  
        # get contextualized token embeddings from last transformer layer
        context_mask = context_mask.float()
        h = self.bert(input_ids=encodings, attention_mask=context_mask)[0] + 1

        batch_size = encodings.shape[0]
        all_entity_logits = []
        all_rel_logits = []

        for batch in range(batch_size): # every batch
            

            entity_mask = entity_masks[batch]

            word_h, curr_entity_logits = self._forward_token(h[batch], token_mask[batch], gold_entity[batch], entity_mask)

            entity_preds = torch.argmax(curr_entity_logits, dim=2)
#             entity_preds_soft = torch.softmax(curr_entity_logits, dim=2)
            
            diag_entity_mask = torch.zeros_like(entity_mask, dtype=torch.bool).to(self._device).fill_diagonal_(1)

            all_entity_logits.append(curr_entity_logits)
            # Relation classification.

            num_steps = gold_entity[batch].shape[-1]
            word_h = h[batch].repeat(token_mask[batch].shape[0], 1, 1) * token_mask[batch].unsqueeze(-1)
            word_h_pooled = word_h.max(dim=1)[0]
            word_h_pooled = word_h_pooled[:num_steps+2].contiguous()

            # curr word repr.
            curr_word_repr = word_h_pooled[1:-1].contiguous()
            
#             curr_rel_logits = self._forward_relation(curr_word_repr, entity_preds_soft , diag_entity_mask)
#             curr_rel_logits = self._forward_relation(curr_word_repr, entity_preds.squeeze(0) , diag_entity_mask)
            curr_rel_logits = self._forward_relation(curr_word_repr, gold_entity[batch] , entity_masks[batch])
            all_rel_logits.append(curr_rel_logits)

        if allow_rel:
            return all_entity_logits, all_rel_logits
        else:
            return all_entity_logits, []

    
    def _forward_eval(self, encodings: torch.tensor, context_mask: torch.tensor, token_mask: torch.tensor,
                     gold_entity: List[torch.tensor], gold_entity_mask: List[torch.tensor]):
                
        context_mask = context_mask.float()
        h = self.bert(input_ids=encodings, attention_mask=context_mask)[0] + 1

        batch_size = encodings.shape[0]
        all_entity_logits = []
        all_entity_scores = []
        all_entity_preds = []
        all_rel_logits = []

        for batch in range(batch_size): # every batch


            num_steps = token_mask[batch].sum(axis=1).nonzero().shape[0] - 2


            word_h = h[batch].repeat(token_mask[batch].shape[0], 1, 1) * token_mask[batch].unsqueeze(-1)
            word_h_pooled = word_h.max(dim=1)[0]
            word_h_pooled = word_h_pooled[:num_steps+2].contiguous()
            word_h_pooled[0,:] = 0

            # curr word repr.
            curr_word_reprs = word_h_pooled[1:-1].contiguous()

            entity_masks = torch.zeros((num_steps + 2, num_steps + 2), dtype = torch.bool).fill_diagonal_(1).to(self._device)
#             diag_entity_mask = torch.zeros((num_steps + 2, num_steps + 2), dtype = torch.bool).fill_diagonal_(1).to(self._device)

            entity_preds = torch.zeros((num_steps + 1, 1), dtype=torch.long).to(self._device)
            entity_logits = []
            entity_scores = torch.zeros((num_steps, 1), dtype=torch.float).to(self._device)

           # Entity classification.
            for i in range(num_steps): # no [CLS], no [SEP] 

                # curr word repr.
                curr_word_repr = curr_word_reprs[i].unsqueeze(0)
                # mask from previous entity token until current position.

                prev_mask = entity_masks[i, :]
     
                prev_label_repr = self.entity_label_embedding(entity_preds[i])
                
                prev_entity = word_h_pooled.unsqueeze(0) * prev_mask.unsqueeze(-1)
                prev_entity_pooled = prev_entity.max(dim=1)[0]

                curr_entity_repr = torch.cat([curr_word_repr - 1, prev_entity_pooled - 1, prev_label_repr], dim=1).unsqueeze(0)
                curr_entity_logits = self.entity_classifier(curr_entity_repr)
                entity_logits.append(curr_entity_logits.squeeze(1))

                curr_label = curr_entity_logits.argmax(dim=2).squeeze(0)
#                 print(i, curr_entity_logits, torch.softmax(curr_entity_logits, dim=2))
                entity_scores[i] += torch.softmax(curr_entity_logits, dim=2).max(dim=2)[0].squeeze(0)
                entity_preds[i+1] = curr_label

                # istart marks predictions that open a new span (or the null tag) under the
                # assumed tag-id layout; continuation tags extend the previous entity mask
                istart = (curr_label % 4 == 1) | (curr_label % 4 == 2) | (curr_label == 0)

                # update entity mask for the next time step
                entity_masks[i+1] += (~istart) * prev_mask

                # update entity span info for all time-steps
                entity_masks[prev_mask.nonzero()[0].item():i+1, i+1] += (~istart).squeeze(0)

            all_entity_logits.append(torch.stack(entity_logits, dim=1))
            all_entity_scores.append(torch.t(entity_scores.squeeze(-1)))
            all_entity_preds.append(torch.t(entity_preds[1:].squeeze(-1)))
#             print(entity_preds.shape)
#             print(gold_entity[batch])
#             exit(0)
            # Relation classification.
        
#             curr_rel_logits = self._forward_relation(curr_word_reprs, torch.stack(entity_logits, dim=1), entity_masks , True)
            curr_rel_logits = self._forward_relation(curr_word_reprs, entity_preds[1:].squeeze(-1), entity_masks, True)
#             curr_rel_logits = self._forward_relation(curr_word_reprs, gold_entity[batch], gold_entity_mask[batch], True)
            all_rel_logits.append(curr_rel_logits)


        return all_entity_logits, all_entity_scores, all_entity_preds, all_rel_logits


    def forward(self, *args, evaluate=False, **kwargs):
        if not evaluate:
            return self._forward_train(*args, **kwargs)
        else:
            return self._forward_eval(*args, **kwargs)
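
# A small sketch (shapes illustrative) of the masked max-pooling used in _forward_token
# above: each word's 0/1 mask over sub-tokens zeroes out all other positions before the
# max. The surrounding model adds 1 to the BERT output and subtracts it again after
# pooling, presumably to keep the zeroed positions from winning the max.
import torch

h = torch.randn(5, 8) + 1                            # 5 sub-tokens, hidden size 8 (shifted by +1)
token_mask = torch.tensor([[1, 1, 0, 0, 0],          # word 0 covers sub-tokens 0-1
                           [0, 0, 1, 1, 1]])         # word 1 covers sub-tokens 2-4
word_h = h.unsqueeze(0) * token_mask.unsqueeze(-1)   # (2, 5, 8): zero out other sub-tokens
word_h_pooled = word_h.max(dim=1)[0]                 # (2, 8): one pooled vector per word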
Example #17
class SynFueBERT(BertPreTrainedModel):
    """ Span-based model to jointly extract terms and relations """

    VERSION = '1.1'

    def __init__(self,
                 config: BertConfig,
                 cls_token: int,
                 relation_types: int,
                 term_types: int,
                 size_embedding: int,
                 prop_drop: float,
                 freeze_transformer: bool,
                 args,
                 max_pairs: int = 100,
                 beta: float = 0.3,
                 alpha: float = 1.0,
                 sigma: float = 1.0):
        super(SynFueBERT, self).__init__(config)

        # BERT model
        self.bert = BertModel(config)
        self.SynFue = Encoder.SynFueEncoder(self.bert, opt=args)
        self.cc = cross_attn.CA_module(config.hidden_size,
                                       config.hidden_size,
                                       1,
                                       dropout=1.0)

        # layers
        self.rel_classifier = nn.Linear(
            config.hidden_size * 6 + size_embedding * 2, relation_types)
        self.rel_classifier3 = nn.Linear(
            config.hidden_size * 6 + size_embedding * 3, relation_types)
        self.term_classifier = nn.Linear(
            config.hidden_size * 8 + size_embedding, term_types)
        self.dep_linear = nn.Linear(config.hidden_size, relation_types)
        self.size_embeddings = nn.Embedding(100, size_embedding)
        self.dropout = nn.Dropout(prop_drop)

        self._cls_token = cls_token
        self._relation_types = relation_types
        self._term_types = term_types
        self._max_pairs = max_pairs
        self._beta = beta
        self._alpha = alpha
        self._sigma = sigma

        # weight initialization
        self.init_weights()

        if freeze_transformer:
            print("Freeze transformer weights")

            # freeze all transformer weights
            for param in self.bert.parameters():
                param.requires_grad = False

    def _forward_train(self,
                       encodings: torch.tensor,
                       context_masks: torch.tensor,
                       term_masks: torch.tensor,
                       term_sizes: torch.tensor,
                       term_spans: torch.tensor,
                       term_types: torch.tensor,
                       relations: torch.tensor,
                       rel_masks: torch.tensor,
                       simple_graph: torch.tensor,
                       graph: torch.tensor,
                       relations3: torch.tensor,
                       rel_masks3: torch.tensor,
                       pair_mask: torch.tensor,
                       pos: torch.tensor = None):
        # get contextualized token embeddings from last transformer layer
        context_masks = context_masks.float()

        h, dep_output = self.SynFue(input_ids=encodings,
                                    input_masks=context_masks,
                                    simple_graph=simple_graph,
                                    graph=graph,
                                    pos=pos)

        batch_size = encodings.shape[0]

        # classify terms
        size_embeddings = self.size_embeddings(
            term_sizes)  # embed term candidate sizes
        term_clf, term_spans_pool = self._classify_terms(
            encodings, h, term_masks, size_embeddings)

        # classify relations
        h_large = h.unsqueeze(1).repeat(
            1, max(min(relations.shape[1], self._max_pairs), 1), 1, 1)
        rel_clf = torch.zeros(
            [batch_size, relations.shape[1],
             self._relation_types]).to(self.rel_classifier.weight.device)

        # get span representation
        # dep_output = [batch size, seq_len, seq_len, feat_dim] -> [batch size, span num, span num, feat_dim]
        span_repr, mapping_list = self.get_span_repr(term_spans, term_types,
                                                     dep_output)
        cross_attn_span = self.cc(
            span_repr)  # batch size, seq_len, seq_len, feat_dim

        # obtain relation logits
        # chunk processing to reduce memory usage
        for i in range(0, relations.shape[1], self._max_pairs):
            # classify relation candidates
            chunk_rel_logits, chunk_rel_clf3, chunk_dep_score = self._classify_relations(
                cross_attn_span, term_spans_pool, size_embeddings, relations,
                rel_masks, h_large, i, relations3, rel_masks3, pair_mask,
                mapping_list)
            # apply sigmoid
            chunk_rel_clf = torch.sigmoid(chunk_rel_logits)
            chunk_rel_clf = self._alpha * chunk_rel_clf + self._beta * chunk_rel_clf3 + self._sigma * chunk_dep_score
            rel_clf[:, i:i + self._max_pairs, :] = chunk_rel_clf

        # min-max rescale rel_clf to [0, 1]; eps guards against a zero denominator
        max_clf = torch.full_like(rel_clf, torch.max(rel_clf).item())
        min_clf = torch.full_like(rel_clf, torch.min(rel_clf).item())
        eps = torch.full_like(rel_clf, 1e-18)
        rel_clf = torch.div(rel_clf - min_clf + eps,
                            max_clf - min_clf + eps)

        return term_clf, rel_clf

    def _forward_eval(self,
                      encodings: torch.tensor,
                      context_masks: torch.tensor,
                      term_masks: torch.tensor,
                      term_sizes: torch.tensor,
                      term_spans: torch.tensor,
                      term_sample_masks: torch.tensor,
                      simple_graph: torch.tensor,
                      graph: torch.tensor,
                      pos: torch.tensor = None):
        # get contextualized token embeddings from last transformer layer
        context_masks = context_masks.float()
        h, dep_output = self.SynFue(input_ids=encodings,
                                    input_masks=context_masks,
                                    simple_graph=simple_graph,
                                    graph=graph,
                                    pos=pos)

        batch_size = encodings.shape[0]
        ctx_size = context_masks.shape[-1]

        # classify terms
        size_embeddings = self.size_embeddings(
            term_sizes)  # embed term candidate sizes
        term_clf, term_spans_pool = self._classify_terms(
            encodings, h, term_masks, size_embeddings)

        # ignore term candidates that do not constitute an actual term for relations (based on classifier)
        relations, rel_masks, rel_sample_masks, relations3, rel_masks3, \
        rel_sample_masks3, pair_mask, span_repr, mapping_list = self._filter_spans(term_clf, term_spans,
                                                                                   term_sample_masks,
                                                                                   ctx_size, dep_output)

        rel_sample_masks = rel_sample_masks.float().unsqueeze(-1)
        # h = self.rel_bert(input_ids=encodings, attention_mask=context_masks)[0]
        h_large = h.unsqueeze(1).repeat(
            1, max(min(relations.shape[1], self._max_pairs), 1), 1, 1)
        rel_clf = torch.zeros(
            [batch_size, relations.shape[1],
             self._relation_types]).to(self.rel_classifier.weight.device)

        # get span representation
        cross_attn_span = self.cc(
            span_repr)  # batch size, seq_len, seq_len, feat_dim

        # obtain relation logits
        # chunk processing to reduce memory usage
        for i in range(0, relations.shape[1], self._max_pairs):
            # classify relation candidates
            chunk_rel_logits, chunk_rel_clf3, chunk_dep_score = self._classify_relations(
                cross_attn_span, term_spans_pool, size_embeddings, relations,
                rel_masks, h_large, i, relations3, rel_masks3, pair_mask,
                mapping_list)
            # apply sigmoid
            chunk_rel_clf = torch.sigmoid(chunk_rel_logits)
            chunk_rel_clf = self._alpha * chunk_rel_clf + self._beta * chunk_rel_clf3 + self._sigma * chunk_dep_score
            rel_clf[:, i:i + self._max_pairs, :] = chunk_rel_clf

        # min-max rescale rel_clf to [0, 1]; eps guards against a zero denominator
        max_clf = torch.full_like(rel_clf, torch.max(rel_clf).item())
        min_clf = torch.full_like(rel_clf, torch.min(rel_clf).item())
        eps = torch.full_like(rel_clf, 1e-18)
        rel_clf = torch.div(rel_clf - min_clf + eps,
                            max_clf - min_clf + eps)

        rel_clf = rel_clf * rel_sample_masks  # mask

        # apply softmax
        term_clf = torch.softmax(term_clf, dim=2)

        return term_clf, rel_clf, relations

    def _classify_terms(self, encodings, h, term_masks, size_embeddings):
        # max pool term candidate spans
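        # positions outside a span receive -1e30 so they can never win the max pooling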
        m = (term_masks.unsqueeze(-1) == 0).float() * (-1e30)
        term_spans_pool = m + h.unsqueeze(1).repeat(1, term_masks.shape[1], 1,
                                                    1)
        term_spans_pool = term_spans_pool.max(dim=2)[0]

        # get cls token as candidate context representation
        term_ctx = get_token(h, encodings, self._cls_token)

        # get head and tail token representation
        m = term_masks.to(dtype=torch.long)
        k = torch.tensor(np.arange(0, term_masks.size(-1)), dtype=torch.long)
        k = k.unsqueeze(0).unsqueeze(0).repeat(term_masks.size(0),
                                               term_masks.size(1),
                                               1).to(m.device)
        mk = torch.mul(m, k)  # element-wise multiply
        mk_max = torch.argmax(mk, dim=-1, keepdim=True)
        mk_min = torch.argmin(mk, dim=-1, keepdim=True)
        mk = torch.cat([mk_min, mk_max], dim=-1)
        head_tail_rep = get_head_tail_rep(
            h, mk)  # [batch size, term_num, bert_dim*2)

        # create candidate representations including context, max pooled span and size embedding
        term_repr = torch.cat([
            term_ctx.unsqueeze(1).repeat(1, term_spans_pool.shape[1], 1),
            term_spans_pool, size_embeddings, head_tail_rep
        ],
                              dim=2)
        term_repr = self.dropout(term_repr)

        # classify term candidates
        term_clf = self.term_classifier(term_repr)

        return term_clf, term_spans_pool

    def _classify_relations(self, spans_matrix, term_spans_repr,
                            size_embeddings, relations, rel_masks, h,
                            chunk_start, relations3, rel_masks3, pair_mask,
                            rel_to_span):
        batch_size = relations.shape[0]
        feat_dim = spans_matrix.size(-1)

        # create chunks if necessary
        if relations.shape[1] > self._max_pairs:
            relations = relations[:, chunk_start:chunk_start + self._max_pairs]
            rel_masks = rel_masks[:, chunk_start:chunk_start + self._max_pairs]
            h = h[:, :relations.shape[1], :]

        def get_span_idx(mapping_list, idx1, idx2):
            for x in mapping_list:
                if idx1 == x[0][0] and idx2 == x[0][1]:
                    return x[1][0], x[1][1]

        batch_dep_score = []
        for i in range(batch_size):
            rela = relations[i]
            dep_score_list = []
            r_2_s = rel_to_span[i]
            for r in rela:
                i1, i2 = r[0].item(), r[1].item()
                idx1, idx2 = get_span_idx(r_2_s, i1, i2)
                try:
                    feat = spans_matrix[i][idx1][idx2]
                except Exception:
                    print('Out of boundary', spans_matrix.size(), i, i1, i2)
                    feat = torch.zeros(feat_dim, device=spans_matrix.device)
                dep_score = self.dep_linear(feat).item()
                dep_score_list.append([dep_score])
            batch_dep_score.append(dep_score_list)

        batch_dep_score = torch.sigmoid(
            torch.tensor(batch_dep_score).to(device=torch.device(
                'cuda' if torch.cuda.is_available() else 'cpu')))

        # get pairs of term candidate representations
        term_pairs = util.batch_index(term_spans_repr, relations)
        term_pairs = term_pairs.view(batch_size, term_pairs.shape[1], -1)

        # get corresponding size embeddings
        size_pair_embeddings = util.batch_index(size_embeddings, relations)
        size_pair_embeddings = size_pair_embeddings.view(
            batch_size, size_pair_embeddings.shape[1], -1)

        # relation context (context between term candidate pair)
        # mask non term candidate tokens
        m = ((rel_masks == 0).float() * (-1e30)).unsqueeze(-1)
        rel_ctx = m + h
        # max pooling
        rel_ctx = rel_ctx.max(dim=2)[0]
        # set the context vector of neighboring or adjacent term candidates to zero
        rel_ctx[rel_masks.to(torch.uint8).any(-1) == 0] = 0

        # create relation candidate representations including context, max pooled term candidate pairs
        # and corresponding size embeddings
        rel_repr = torch.cat([rel_ctx, term_pairs, size_pair_embeddings],
                             dim=2)
        rel_repr = self.dropout(rel_repr)
        # classify relation candidates
        chunk_rel_logits = self.rel_classifier(rel_repr)

        if relations3.shape[1] > self._max_pairs:
            relations3 = relations3[:,
                                    chunk_start:chunk_start + self._max_pairs]
            # rel_masks3 = rel_masks3[:, chunk_start:chunk_start + self._max_pairs]

        p_num = relations3.size(1)
        p_tris = relations3.size(2)

        relations3 = relations3.view(batch_size, -1, 3)

        # get three pairs candidata representations
        term_pairs3 = util.batch_index(term_spans_repr, relations3)
        term_pairs3 = term_pairs3.view(batch_size, term_pairs3.shape[1], -1)

        size_pair_embeddings3 = util.batch_index(size_embeddings, relations3)
        size_pair_embeddings3 = size_pair_embeddings3.view(
            batch_size, size_pair_embeddings3.shape[1], -1)

        rel_repr = torch.cat([term_pairs3, size_pair_embeddings3], dim=2)
        rel_repr = self.dropout(rel_repr)
        # classify relation candidates
        chunk_rel_logits3 = self.rel_classifier3(rel_repr)

        chunk_rel_clf3 = chunk_rel_logits3.view(batch_size, p_num, p_tris, -1)
        chunk_rel_clf3 = torch.sigmoid(chunk_rel_clf3)

        chunk_rel_clf3 = torch.sum(chunk_rel_clf3, dim=2)
        chunk_rel_clf3 = torch.sigmoid(chunk_rel_clf3)

        return chunk_rel_logits, chunk_rel_clf3, batch_dep_score

    def _filter_spans(self, term_clf, term_spans, term_sample_masks, ctx_size,
                      token_repr):
        batch_size = term_clf.shape[0]
        feat_dim = token_repr.size(-1)
        # get term type (including none)
        term_logits_max = term_clf.argmax(dim=-1) * term_sample_masks.long()
        batch_relations = []
        batch_rel_masks = []
        batch_rel_sample_masks = []

        batch_relations3 = []
        batch_rel_masks3 = []
        batch_rel_sample_masks3 = []
        batch_pair_mask = []

        batch_span_repr = []
        batch_rel_to_span = []

        for i in range(batch_size):
            rels = []
            rel_masks = []
            sample_masks = []
            rels3 = []
            rel_masks3 = []
            sample_masks3 = []

            span_repr = []
            rel_to_span = []

            # get spans classified as terms
            non_zero_indices = (term_logits_max[i] != 0).nonzero().view(-1)
            non_zero_spans = term_spans[i][non_zero_indices].tolist()
            non_zero_indices = non_zero_indices.tolist()

            # create relations and masks
            pair_mask = []
            for idx1, (i1,
                       s1) in enumerate(zip(non_zero_indices, non_zero_spans)):
                temp = []
                for idx2, (i2, s2) in enumerate(
                        zip(non_zero_indices, non_zero_spans)):
                    if i1 != i2:
                        rels.append((i1, i2))
                        rel_masks.append(
                            sampling.create_rel_mask(s1, s2, ctx_size))
                        sample_masks.append(1)
                        p_rels3 = []
                        p_masks3 = []
                        for i3, s3 in zip(non_zero_indices, non_zero_spans):
                            if i1 != i2 and i1 != i3 and i2 != i3:
                                p_rels3.append((i1, i2, i3))
                                p_masks3.append(
                                    sampling.create_rel_mask3(
                                        s1, s2, s3, ctx_size))
                                sample_masks3.append(1)
                        if len(p_rels3) > 0:
                            rels3.append(p_rels3)
                            rel_masks3.append(p_masks3)
                            pair_mask.append(1)
                        else:
                            rels3.append([(i1, i2, 0)])
                            rel_masks3.append([
                                sampling.create_rel_mask3(
                                    s1, s2, (0, 0), ctx_size)
                            ])
                            pair_mask.append(0)
                        rel_to_span.append([[i1, i2], [idx1, idx2]])
                    feat = torch.max(
                        token_repr[i, s1[0]:s1[-1] + 1,
                                   s2[0]:s2[-1] + 1, :].contiguous().view(
                                       -1, feat_dim),
                        dim=0)[0]
                    temp.append(feat)
                span_repr.append(temp)

            if not rels:
                # case: no more than two spans classified as terms
                batch_relations.append(torch.tensor([[0, 0]],
                                                    dtype=torch.long))
                batch_rel_masks.append(
                    torch.tensor([[0] * ctx_size], dtype=torch.bool))
                batch_rel_sample_masks.append(
                    torch.tensor([0], dtype=torch.bool))
                batch_span_repr.append(
                    torch.tensor([[[0] * feat_dim]], dtype=torch.float))
                batch_rel_to_span.append([[[0, 0], [0, 0]]])
            else:
                # case: more than two spans classified as terms
                batch_relations.append(torch.tensor(rels, dtype=torch.long))
                batch_rel_masks.append(torch.stack(rel_masks))
                batch_rel_sample_masks.append(
                    torch.tensor(sample_masks, dtype=torch.bool))
                batch_span_repr.append(
                    torch.stack([torch.stack(x) for x in span_repr]))
                batch_rel_to_span.append(rel_to_span)

            if not rels3:
                batch_relations3.append(
                    torch.tensor([[[0, 0, 0]]], dtype=torch.long))
                batch_rel_masks3.append(
                    torch.tensor([[0] * ctx_size], dtype=torch.bool))
                batch_rel_sample_masks3.append(
                    torch.tensor([0], dtype=torch.bool))
                batch_pair_mask.append(torch.tensor([0], dtype=torch.bool))

            else:
                max_tri = max([len(x) for x in rels3])
                # print(max_tri)
                for idx, r in enumerate(rels3):
                    r_len = len(r)
                    if r_len < max_tri:
                        rels3[idx].extend([rels3[idx][0]] * (max_tri - r_len))
                        rel_masks3[idx].extend([rel_masks3[idx][0]] *
                                               (max_tri - r_len))
                batch_relations3.append(torch.tensor(rels3, dtype=torch.long))
                batch_rel_masks3.append(
                    torch.stack([torch.stack(x) for x in rel_masks3]))
                batch_rel_sample_masks3.append(
                    torch.tensor(sample_masks3, dtype=torch.bool))
                batch_pair_mask.append(
                    torch.tensor(pair_mask, dtype=torch.bool))

        # stack
        device = self.rel_classifier.weight.device
        batch_relations = util.padded_stack(batch_relations).to(device)
        batch_rel_masks = util.padded_stack(batch_rel_masks).to(device)
        batch_rel_sample_masks = util.padded_stack(batch_rel_sample_masks).to(
            device)
        batch_span_repr = util.padded_stack(batch_span_repr).to(device)

        batch_relations3 = util.padded_stack(batch_relations3).to(device)
        batch_rel_masks3 = util.padded_stack(batch_rel_masks3).to(device)
        batch_rel_sample_masks3 = util.padded_stack(
            batch_rel_sample_masks3).to(device)
        batch_pair_mask = util.padded_stack(batch_pair_mask).to(device)

        return batch_relations, batch_rel_masks, batch_rel_sample_masks, \
               batch_relations3, batch_rel_masks3, batch_rel_sample_masks3, batch_pair_mask, batch_span_repr, batch_rel_to_span

    def get_span_repr(self, term_spans, term_types, token_repr):
        """

        :param term_spans: [batch size, span_num, 2]
        :param term_types: [batch size, span_num]
        :param token_repr: [batch size, seq_len, seq_len, feat_dim]
        :return: [batch size, span_num, span_num, feat_dim]
        """
        batch_size = term_spans.size(0)
        feat_dim = token_repr.size(-1)
        batch_span_repr = []
        batch_mapping_list = []
        for i in range(batch_size):
            span_repr = []
            mapping_list = []
            # get target spans  as aspect term or opinion term
            non_zero_indices = (term_types[i] != 0).nonzero().view(-1)
            non_zero_spans = term_spans[i][non_zero_indices].tolist()
            non_zero_indices = non_zero_indices.tolist()
            for x1, (i1, s1) in enumerate(zip(non_zero_indices,
                                              non_zero_spans)):
                temp = []
                for x2, (i2,
                         s2) in enumerate(zip(non_zero_indices,
                                              non_zero_spans)):
                    feat = torch.max(
                        token_repr[i, s1[0]:s1[-1] + 1,
                                   s2[0]:s2[-1] + 1, :].contiguous().view(
                                       -1, feat_dim),
                        dim=0)[0]
                    temp.append(feat)
                    mapping_list.append([[i1, i2], [x1, x2]])

                span_repr.append(torch.stack(temp))
            batch_span_repr.append(torch.stack(span_repr))
            batch_mapping_list.append(mapping_list)

        device = self.rel_classifier.weight.device
        batch_span_repr = util.padded_stack(batch_span_repr).to(device)

        return batch_span_repr, batch_mapping_list

    def forward(self, *args, evaluate=False, **kwargs):
        if not evaluate:
            return self._forward_train(*args, **kwargs)
        else:
            return self._forward_eval(*args, **kwargs)
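
# A brief sketch of the min-max rescaling applied to rel_clf in _forward_train/_forward_eval
# above: scores are shifted and divided so they fall in [0, 1], with a tiny constant guarding
# against a zero denominator when all scores are equal (values here are made up).
import torch

scores = torch.tensor([[0.2, 1.7, 0.9]])
eps = torch.full_like(scores, 1e-18)
lo = torch.full_like(scores, scores.min().item())
hi = torch.full_like(scores, scores.max().item())
rescaled = (scores - lo + eps) / (hi - lo + eps)     # max maps to ~1.0, min to ~0.0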
Example #18
class DocumentBert(BertPreTrainedModel):
    def __init__(self, bert_model_config: BertConfig):
        super(DocumentBert, self).__init__(bert_model_config)
        self.bert_patent = BertModel(bert_model_config)
        self.bert_tsd = BertModel(bert_model_config)

        for param in self.bert_patent.parameters():
            param.requires_grad = False

        for param in self.bert_tsd.parameters():
            param.requires_grad = False

        self.bert_batch_size = self.bert_patent.config.bert_batch_size
        self.dropout_patent = torch.nn.Dropout(
            p=bert_model_config.hidden_dropout_prob)
        self.dropout_tsd = torch.nn.Dropout(
            p=bert_model_config.hidden_dropout_prob)

        self.lstm_patent = torch.nn.LSTM(bert_model_config.hidden_size,
                                         bert_model_config.hidden_size)
        self.lstm_tsd = torch.nn.LSTM(bert_model_config.hidden_size,
                                      bert_model_config.hidden_size)

        self.output = torch.nn.Linear(bert_model_config.hidden_size * 2,
                                      out_features=1)

    def forward(self,
                patent_batch: torch.Tensor,
                tsd_batch: torch.Tensor,
                device='cuda'):

        #patent
        bert_output_patent = torch.zeros(
            size=(patent_batch.shape[0],
                  min(patent_batch.shape[1], self.bert_batch_size),
                  self.bert_patent.config.hidden_size),
            dtype=torch.float,
            device=device)
        for doc_id in range(patent_batch.shape[0]):
            bert_output_patent[
                doc_id][:self.bert_batch_size] = self.dropout_patent(
                    self.bert_patent(
                        patent_batch[doc_id][:self.bert_batch_size, 0],
                        token_type_ids=patent_batch[doc_id]
                        [:self.bert_batch_size, 1],
                        attention_mask=patent_batch[doc_id]
                        [:self.bert_batch_size, 2])[1])
        output_patent, (_, _) = self.lstm_patent(
            bert_output_patent.permute(1, 0, 2))
        last_layer_patent = output_patent[-1]

        #tsd

        bert_output_tsd = torch.zeros(size=(tsd_batch.shape[0],
                                            min(tsd_batch.shape[1],
                                                self.bert_batch_size),
                                            self.bert_tsd.config.hidden_size),
                                      dtype=torch.float,
                                      device=device)
        for doc_id in range(tsd_batch.shape[0]):
            bert_output_tsd[doc_id][:self.bert_batch_size] = self.dropout_tsd(
                self.bert_tsd(
                    tsd_batch[doc_id][:self.bert_batch_size, 0],
                    token_type_ids=tsd_batch[doc_id][:self.bert_batch_size, 1],
                    attention_mask=tsd_batch[doc_id][:self.bert_batch_size,
                                                     2])[1])
        output_tsd, (_, _) = self.lstm_tsd(bert_output_tsd.permute(1, 0, 2))
        last_layer_tsd = output_tsd[-1]

        x = torch.cat([last_layer_patent, last_layer_tsd], dim=1)
        prediction = torch.sigmoid(self.output(x))

        assert prediction.shape[0] == patent_batch.shape[0]
        return prediction

    def freeze_bert_encoder(self):
        for param in self.bert_patent.parameters():
            param.requires_grad = False
        for param in self.bert_tsd.parameters():
            param.requires_grad = False

    def unfreeze_bert_encoder(self):
        for param in self.bert_patent.parameters():
            param.requires_grad = True
        for param in self.bert_tsd.parameters():
            param.requires_grad = True

    def unfreeze_bert_encoder_last_layers(self):
        for name, param in self.bert_patent.named_parameters():
            if "encoder.layer.11" in name or "pooler" in name:
                param.requires_grad = True
        for name, param in self.bert_tsd.named_parameters():
            if "encoder.layer.11" in name or "pooler" in name:
                param.requires_grad = True

    def unfreeze_bert_encoder_pooler_layer(self):
        for name, param in self.bert_patent.named_parameters():
            if "pooler" in name:
                param.requires_grad = True
        for name, param in self.bert_tsd.named_parameters():
            if "pooler" in name:
                param.requires_grad = True
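
# A compact sketch (hypothetical sizes) of the pattern used by DocumentBert above: one
# pooled BERT vector per chunk of the document, run through an nn.LSTM in
# (seq, batch, hidden) layout, keeping the last time step as the document representation.
import torch

batch, chunks, hidden = 2, 4, 16
chunk_vecs = torch.randn(batch, chunks, hidden)      # stand-in for per-chunk pooled outputs
lstm = torch.nn.LSTM(hidden, hidden)
out, _ = lstm(chunk_vecs.permute(1, 0, 2))           # (chunks, batch, hidden)
doc_repr = out[-1]                                   # (batch, hidden): last chunk's state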
Example #19
class Bert(nn.Module):
    def __init__(self, config, num=0):
        super(Bert, self).__init__()
        model_config = BertConfig()
        model_config.vocab_size = config.vocab_size
        # how the loss is computed
        self.loss_method = config.loss_method
        self.multi_drop = config.multi_drop

        self.bert = BertModel(model_config)
        if config.requires_grad:
            for param in self.bert.parameters():
                param.requires_grad = True
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.hidden_size = config.hidden_size[num]
        if self.loss_method in ['binary', 'focal_loss', 'ghmc']:
            self.classifier = nn.Linear(self.hidden_size, 1)
        else:
            # multi-class head; num_labels is assumed to be provided by the custom config
            self.classifier = nn.Linear(self.hidden_size, config.num_labels)

        self.classifier.apply(self._init_weights)
        self.bert.apply(self._init_weights)

    def _init_weights(self, module):
        """ Initialize the weights """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=0.02)

    def forward(self,
                input_ids=None,
                token_type_ids=None,
                attention_mask=None,
                labels=None):
        outputs = self.bert(input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids)
        pooled_output = outputs[1]
        out = None
        loss = 0
        for i in range(self.multi_drop):
            output = self.dropout(pooled_output)
            if labels is not None:
                if i == 0:
                    out = self.classifier(output)
                    loss = compute_loss(out,
                                        labels,
                                        loss_method=self.loss_method)
                else:
                    temp_out = self.classifier(output)
                    temp_loss = compute_loss(temp_out,
                                             labels,
                                             loss_method=self.loss_method)
                    out = out + temp_out
                    loss = loss + temp_loss

        loss = loss / self.multi_drop
        out = out / self.multi_drop

        if self.loss_method in ['binary']:
            out = torch.sigmoid(out).flatten()

        return out, loss
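
# A minimal sketch of the multi-sample dropout used in Bert.forward above: the same pooled
# vector is passed through dropout several times, each sample is classified, and the
# logits (and losses, when labels exist) are averaged. Sizes here are illustrative.
import torch
import torch.nn as nn

pooled = torch.randn(3, 16)                          # stand-in for BERT pooled output
dropout, classifier = nn.Dropout(0.3), nn.Linear(16, 2)
multi_drop = 4
out = sum(classifier(dropout(pooled)) for _ in range(multi_drop)) / multi_drop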
class MyBertForSequenceClassification(BertPreTrainedModel):

    num_labels = 4
    num_tasks = 20

    def __init__(self, config):
        super(MyBertForSequenceClassification, self).__init__(config)
        self.num_labels = MyBertForSequenceClassification.num_labels
        self.num_tasks = MyBertForSequenceClassification.num_tasks

        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # Create 20 classification tasks that all share the same input: the pooler_output
        # of the [CLS] token from the last BertModel layer.
        # As the source also notes, the [CLS] pooler_output is usually *not* a good summary
        # of the semantic content of the input; you're often better off averaging or pooling
        # the sequence of hidden-states for the whole input sequence.
        # module_list = []
        # for _ in range(self.num_tasks):
        # module_list.append(nn.Linear(config.hidden_size, self.num_labels))
        # self.classifier = nn.ModuleList(module_list)
        self.classifier = nn.ModuleList([
            nn.Linear(config.hidden_size, self.num_labels)
            for _ in range(self.num_tasks)
        ])

        self.init_weights()

    def forward(self,
                input_ids,
                token_type_ids=None,
                attention_mask=None,
                labels=None):
        """forward

        :param input_ids:
        :param labels: given as a tensor of shape [batch, num_tasks]
        """

        outputs = self.bert(input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids)

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)

        #         logits = []
        # for i in range(self.num_tasks):
        #             logits.append(self.classifier[i](pooled_output))

        logits = [
            self.classifier[i](pooled_output) for i in range(self.num_tasks)
        ]

        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            # this must live on the GPU; it is easy to forget, and loss.backward() then fails
            loss = torch.tensor([0.]).to(device)
            for i in range(self.num_tasks):
                loss += loss_fct(logits[i], labels[:, i])
            return loss
        else:
            # used to predict labels for the validation and test sets; shape is [num_tasks, batch, num_labels]
            logits = [logit.cpu().numpy() for logit in logits]
            return torch.tensor(logits)

    # The BertModel parameters can be frozen or left unfrozen. For multi-label
    # classification they are not frozen (simply do not call this function);
    # this shows how freezing would be done.
    def freeze_bert_encoder(self):
        for param in self.bert.parameters():
            param.requires_grad = False

    def unfreeze_bert_encoder(self):
        for param in self.bert.parameters():
            param.requires_grad = True
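
# A short sketch of the shared-encoder, per-task-head setup in MyBertForSequenceClassification
# above: one linear head per task over the same pooled vector, with the cross-entropy losses
# summed across tasks (all sizes illustrative).
import torch
import torch.nn as nn

num_tasks, num_labels, hidden = 3, 4, 16
heads = nn.ModuleList([nn.Linear(hidden, num_labels) for _ in range(num_tasks)])
pooled = torch.randn(2, hidden)                      # stand-in for BERT pooled output
labels = torch.randint(0, num_labels, (2, num_tasks))
loss_fct = nn.CrossEntropyLoss()
loss = sum(loss_fct(heads[i](pooled), labels[:, i]) for i in range(num_tasks))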
class SpET(BertPreTrainedModel):
    """ Span-based model to extract entities """
    def __init__(self,
                 config: BertConfig,
                 cls_token: int,
                 relation_types: int,
                 entity_types: int,
                 size_embedding: int,
                 prop_drop: float,
                 freeze_transformer: bool,
                 max_pairs: int = 100,
                 feature_enhancer: str = "pass"):
        super(SpET, self).__init__(config)

        # BERT model
        self.bert = BertModel(config)

        # layers
        self.feature_enhancer = fe.get_feature_enhancer(feature_enhancer)(
            config.hidden_size, config.hidden_size)
        self.entity_classifier = nn.Linear(
            config.hidden_size * 2 + size_embedding, entity_types)
        self.size_embeddings = nn.Embedding(100, size_embedding)
        self.dropout = nn.Dropout(prop_drop)

        self._cls_token = cls_token
        self._entity_types = entity_types
        self._max_pairs = max_pairs

        # weight initialization
        self.init_weights()

        if freeze_transformer or feature_enhancer not in {
                "pass", "transformer"
        }:
            print("Freeze transformer weights")

            # freeze all transformer weights
            for param in self.bert.parameters():
                param.requires_grad = False

    def _forward_train(self, encodings: torch.tensor,
                       context_mask: torch.tensor, entity_masks: torch.tensor,
                       entity_sizes: torch.tensor):
        # get contextualized token embeddings from last transformer layer
        context_mask = context_mask.float()
        h_bert = self.bert(input_ids=encodings, attention_mask=context_mask)[0]

        # enhance hidden features
        orig_shape = h_bert.shape
        h = self.feature_enhancer.prepare_input(h_bert, context_mask)
        h = self.feature_enhancer(h)
        h = self.feature_enhancer.prepare_output(h, orig_shape)

        entity_masks = entity_masks.float()
        batch_size = encodings.shape[0]

        # classify entities
        size_embeddings = self.size_embeddings(
            entity_sizes)  # embed entity candidate sizes
        entity_clf, entity_spans_pool = self._classify_entities(
            encodings, h, entity_masks, size_embeddings)

        return entity_clf

    def _forward_eval(self,
                      encodings: torch.tensor,
                      context_mask: torch.tensor,
                      entity_masks: torch.tensor,
                      entity_sizes: torch.tensor,
                      entity_spans: torch.tensor = None,
                      entity_sample_mask: torch.tensor = None):
        # get contextualized token embeddings from last transformer layer
        context_mask = context_mask.float()
        h = self.bert(input_ids=encodings, attention_mask=context_mask)[0]

        # enhance hidden features
        orig_shape = h.shape
        h = self.feature_enhancer.prepare_input(h, context_mask)
        h = self.feature_enhancer(h)
        h = self.feature_enhancer.prepare_output(h, orig_shape)

        entity_masks = entity_masks.float()
        batch_size = encodings.shape[0]
        ctx_size = context_mask.shape[-1]

        # classify entities
        size_embeddings = self.size_embeddings(
            entity_sizes)  # embed entity candidate sizes
        entity_clf, entity_spans_pool = self._classify_entities(
            encodings, h, entity_masks, size_embeddings)

        # apply softmax
        entity_clf = torch.softmax(entity_clf, dim=2)

        return entity_clf

    def _classify_entities(self, encodings, h, entity_masks, size_embeddings):
        # max pool entity candidate spans
        entity_spans_pool = entity_masks.unsqueeze(-1) * h.unsqueeze(1)
        entity_spans_pool = entity_spans_pool.max(dim=2)[0]

        # get cls token as candidate context representation
        entity_ctx = get_token(h, encodings, self._cls_token)

        # create candidate representations including context, max pooled span and size embedding
        entity_repr = torch.cat([
            entity_ctx.unsqueeze(1).repeat(1, entity_spans_pool.shape[1], 1),
            entity_spans_pool, size_embeddings
        ],
                                dim=2)
        entity_repr = self.dropout(entity_repr)

        # classify entity candidates
        entity_clf = self.entity_classifier(entity_repr)

        return entity_clf, entity_spans_pool

    def forward(self, *args, evaluate=False, **kwargs):
        if not evaluate:
            return self._forward_train(*args, **kwargs)
        else:
            return self._forward_eval(*args, **kwargs)
Example #22
class ExampleIntentBertModel(torch.nn.Module):
    def __init__(self,
                 model_name_or_path: str,
                 dropout: float,
                 num_intent_labels: int,
                 use_observers: bool = False):
        super(ExampleIntentBertModel, self).__init__()
        #self.bert_model = BertModel.from_pretrained(model_name_or_path)
        self.bert_model = BertModel(
            BertConfig.from_pretrained(model_name_or_path,
                                       output_attentions=True))

        self.dropout = Dropout(dropout)
        self.num_intent_labels = num_intent_labels
        self.use_observers = use_observers
        self.all_outputs = []

    def encode(self, input_ids: torch.tensor, attention_mask: torch.tensor,
               token_type_ids: torch.tensor):
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(
            2).repeat(1, 1, input_ids.size(1), 1)
        extended_attention_mask = extended_attention_mask.to(
            dtype=next(self.bert_model.parameters()).dtype)

        # Combine attention maps
        padding = (input_ids.unsqueeze(1) == 0).unsqueeze(-1)
        padding = padding.repeat(1, 1, 1, padding.size(-2))

        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        embedding_output = self.bert_model.embeddings(
            input_ids, position_ids=None, token_type_ids=token_type_ids)
        encoder_outputs = self.bert_model.encoder(
            embedding_output,
            extended_attention_mask,
            head_mask=[None] * self.bert_model.config.num_hidden_layers)

        if encoder_outputs[0].size(0) == 1:
            pass
            #self.all_outputs.append(torch.cat(encoder_outputs[1], dim=0).cpu())
            #self.all_outputs.append(encoder_outputs[0][:, -20:].cpu())
        sequence_output = encoder_outputs[0]

        if self.use_observers:
            pooled_output = sequence_output[:, -20:].mean(dim=1)
        else:
            pooled_output = self.bert_model.pooler(sequence_output)

        return pooled_output

    def forward(self, input_ids: torch.tensor, attention_mask: torch.tensor,
                token_type_ids: torch.tensor, intent_label: torch.tensor,
                example_input: torch.tensor, example_mask: torch.tensor,
                example_token_types: torch.tensor,
                example_intents: torch.tensor):
        example_pooled_output = self.encode(input_ids=example_input,
                                            attention_mask=example_mask,
                                            token_type_ids=example_token_types)

        pooled_output = self.encode(input_ids=input_ids,
                                    attention_mask=attention_mask,
                                    token_type_ids=token_type_ids)

        pooled_output = self.dropout(pooled_output)
        probs = torch.softmax(pooled_output.mm(example_pooled_output.t()),
                              dim=-1)

        intent_probs = 1e-6 + torch.zeros(
            probs.size(0), self.num_intent_labels).cuda().scatter_add(
                -1,
                example_intents.unsqueeze(0).repeat(probs.size(0), 1), probs)

        # Compute losses if labels provided
        if intent_label is not None:
            loss_fct = NLLLoss()
            intent_lp = torch.log(intent_probs)
            intent_loss = loss_fct(intent_lp.view(-1, self.num_intent_labels),
                                   intent_label.type(torch.long))
        else:
            intent_loss = torch.tensor(0)

        return intent_probs, intent_loss
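
# A minimal usage sketch for the example-driven intent classifier above (a sketch
# only: the batch sizes and labels are hypothetical, the BERT weights stay randomly
# initialised because only the config is fetched from the hub, and a CUDA device is
# required since forward() calls .cuda() internally).
import torch

intent_model = ExampleIntentBertModel('bert-base-uncased', dropout=0.1,
                                      num_intent_labels=3).cuda()

batch, n_examples, seq_len = 4, 3, 16
input_ids = torch.randint(1, 1000, (batch, seq_len)).cuda()
attention_mask = torch.ones(batch, seq_len).cuda()
token_type_ids = torch.zeros(batch, seq_len, dtype=torch.long).cuda()
intent_label = torch.randint(0, 3, (batch,)).cuda()

# one support "example" utterance per intent class
example_input = torch.randint(1, 1000, (n_examples, seq_len)).cuda()
example_mask = torch.ones(n_examples, seq_len).cuda()
example_token_types = torch.zeros(n_examples, seq_len, dtype=torch.long).cuda()
example_intents = torch.arange(n_examples).cuda()

intent_probs, intent_loss = intent_model(input_ids, attention_mask, token_type_ids,
                                         intent_label, example_input, example_mask,
                                         example_token_types, example_intents)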
Example #23
class BertClassification(BertPreTrainedModel):
    def __init__(self, config, freeze_bert = False):
        super(BertClassification, self).__init__(config)
        self.hidden_size = config.hidden_size
        #self.lstm_hidden_size = 256
        self.hidden_dropout_prob = config.hidden_dropout_prob
        self.num_labels = config.num_labels
        self.bert = BertModel(config)
        if freeze_bert:
            for p in self.bert.parameters():
                p.requires_grad = False
        self.dropout = nn.Dropout(self.hidden_dropout_prob)
        #self.bilstm = nn.LSTM(self.hidden_size, self.lstm_hidden_size, bidirectional=True, batch_first=True)
        self.classifier = nn.Linear(self.hidden_size, self.num_labels)
        #self.classifier = nn.Linear(self.hidden_size, self.num_labels)

    def forward(
            self,
            input_ids=None,
            attention_mask=None,
            token_type_ids=None,
            position_ids=None,
            head_mask=None,
            inputs_embeds=None,
            labels=None,
            output_attentions=None,
            output_hidden_states=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
            config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
            If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        #pooled_output = outputs[1]
        hidden_states = outputs[0]
        pooled_output = hidden_states.mean(-2)
        pooled_output = self.dropout(pooled_output)
        #hidden_states = self.dropout(hidden_states)
        #lstm_hidden_states, _ = self.bilstm(hidden_states)
        #lstm_hidden_states = self.dropout(lstm_hidden_states)
        #pooled_output = lstm_hidden_states.mean(-2)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.num_labels == 1:
                #  We are doing regression
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1), labels.view(-1))
            else:
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        output = (logits,) + outputs[2:]
        return ((loss,) + output) if loss is not None else output
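
# A minimal usage sketch with a small, randomly initialised config (the sizes and
# label count below are hypothetical; in practice one would typically load weights
# with BertClassification.from_pretrained(...)).
import torch
from transformers import BertConfig

tiny_config = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2,
                         num_attention_heads=2, intermediate_size=64, num_labels=3)
clf = BertClassification(tiny_config)

input_ids = torch.randint(1, 100, (4, 12))
labels = torch.randint(0, 3, (4,))
loss, logits = clf(input_ids=input_ids, labels=labels)[:2]
print(loss.item(), logits.shape)  # scalar loss and logits of shape (4, 3)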
Example #24
class BertSum(pl.LightningModule):
    def __init__(self, conf=None):
        super().__init__()
        # save conf, accessible in self.hparams.conf
        self.save_hyperparameters()

        # MODEL
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        # change hidden layers from 12 to 10 (memory limit)
        bert_config = BertConfig(self.bert.config.vocab_size, num_hidden_layers=10)
        self.bert = BertModel(bert_config)
        # change embeddings to enable longer input sequences
        pos_embeddings = nn.Embedding(self.hparams.conf.dataset.setup.max_length_input_context, self.bert.config.hidden_size)
        pos_embeddings.weight.data[:512] = self.bert.embeddings.position_embeddings.weight.data
        pos_embeddings.weight.data[512:] = self.bert.embeddings.position_embeddings.weight.data[-1][None, :].repeat(self.hparams.conf.dataset.setup.max_length_input_context - 512, 1)
        self.bert.embeddings.position_embeddings = pos_embeddings
        # classification layers
        self.linear1 = nn.Linear(self.bert.config.hidden_size, 1)
        self.sigmoid = nn.Sigmoid()

        # TODO: model to encode answer (and fix)

        # metrics
        self.evaluation_metrics = torch.nn.ModuleDict({
            'train_metrics': self._init_metrics(mode='train'),
            'val_metrics': self._init_metrics(mode='val'),
            'test_metrics': self._init_metrics(mode='test'), })
        # loss
        self.loss = nn.BCEWithLogitsLoss(reduction='sum')

    def forward(self, x, **kwargs):
        return self.bert(x, **kwargs)

    # optimizer
    def configure_optimizers(self):
        params = list(self.bert.parameters()) + list(self.linear1.parameters())
        return torch.optim.Adam(params, lr=self.hparams.conf.training.lr)

    # TRAIN
    def training_step(self, batch, batch_idx):
        loss, metrics = self._get_loss(batch)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        metrics = self._log_metrics(metrics, mode='train')
        return {'loss': loss, **metrics}

    def validation_step(self, batch, batch_idx):
        loss, metrics = self._get_loss(batch, mode='val')
        self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        metrics = self._log_metrics(metrics, mode='val')
        return {'val_loss': loss, **metrics}

    def _get_loss(self, batch, mode='train'):
        src_ids, src_mask, seg_ids, seg_idx, seg_idx_mask, tgt_labels = (
            batch['input_ids'], batch['input_attention_mask'], batch['segment_ids'],
            batch['segment_idx'], batch['segment_idx_mask'], batch['target_labels'])
        position_ids = torch.tensor(list(range(self.hparams.conf.dataset.setup.max_length_input_context))).to('cuda')
        # bert
        last_hidden_state, pooler_output = self.bert(src_ids, attention_mask=src_mask, token_type_ids=seg_ids, position_ids=position_ids)

        # select sentence representation embeddings
        seg_idx = seg_idx.unsqueeze(dim=2).repeat(1, 1, last_hidden_state.shape[-1])
        sent_embeddings = last_hidden_state.gather(dim=1, index=seg_idx)
        # filter mask
        mask_idx = torch.nonzero(seg_idx_mask, as_tuple=True)
        sent_embeddings = sent_embeddings[mask_idx[0], mask_idx[1], :]
        tgt_labels = tgt_labels[mask_idx[0], mask_idx[1]]

        # classifier
        logits = self.linear1(sent_embeddings)
        logits = logits.squeeze().float()

        # loss
        tgt_labels = tgt_labels.float()
        loss = self.loss(logits, tgt_labels)
        loss = loss / torch.sum(seg_idx_mask)

        # metrics
        preds = self.sigmoid(logits)
        preds = torch.stack([1 - preds, preds], dim=1)
        metrics = self._get_metrics(preds, tgt_labels, mode)

        return loss, metrics

    def validation_epoch_end(self, outputs):
        metrics = self._compute_metrics(mode='val')
        self._log_precision_recall_curve(metrics)
        self._log_confusion_matrix(metrics)

    # TEST
    def test_step(self, batch, batch_idx):
        loss, metrics = self._get_loss(batch, mode='test')
        self.log('test_loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        metrics = self._log_metrics(metrics, mode='test')
        return {'test_loss': loss, **metrics}

    def test_epoch_end(self, outputs):
        metrics = self._compute_metrics(mode='test')
        self._log_confusion_matrix(metrics)

    # progress bar
    def get_progress_bar_dict(self):
        tqdm_dict = super().get_progress_bar_dict()
        if 'v_num' in tqdm_dict:
            del tqdm_dict['v_num']
        return tqdm_dict

    # METRICS
    def _log_metrics(self, metrics, mode):
        metrics = {key + (f'_{mode}' if mode != 'train' else ''): value for key, value in metrics.items() if key != 'confusion_matrix' and key != 'precision_recall_curve'}
        for k, m in metrics.items():
            self.log(k, m, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return metrics

    def _log_precision_recall_curve(self, metrics):
        precision, recall, thresholds = metrics['precision_recall_curve']
        plt.plot(recall.cpu().numpy(), precision.cpu().numpy(), 'ro')
        plt.xlabel('recall')
        plt.ylabel('precision')
        self.logger.experiment.log({f'precision_recall_curve_{self.current_epoch}': wandb.Image(plt)})
        plt.clf()
        f1 = 2 * (precision * recall) / (precision + recall)
        data = {
            'precision': precision.cpu().numpy().tolist(),
            'recall': recall.cpu().numpy().tolist(),
            'f1': f1.cpu().numpy().tolist(),
            'thresholds': thresholds.cpu().numpy().tolist(),
            'argmax': torch.argmax(f1).cpu().numpy().tolist()
        }

        with open(f'precision_recall_{self.current_epoch}', 'wb') as file:
            pickle.dump(data, file)

    def _log_confusion_matrix(self, metrics):
        confusion_matrix = metrics['confusion_matrix']
        heatmap = sns.heatmap(confusion_matrix.cpu().numpy(), annot=True, fmt='g')
        figure = heatmap.get_figure()
        self.logger.experiment.log({f'confusion_matrix_{self.current_epoch}': wandb.Image(figure)})
        plt.clf()

    def _get_metrics(self, prediction, target, mode='train'):
        metrics = {}
        for name, metric in self.evaluation_metrics[mode + '_metrics'].items():
            metrics[name] = metric(prediction, target)
        return metrics

    def _compute_metrics(self, mode='train'):
        metrics = {}
        for name, metric in self.evaluation_metrics[mode + '_metrics'].items():
            metrics[name] = metric.compute()
        return metrics

    @staticmethod
    def _init_metrics(mode):
        metrics = torch.nn.ModuleDict({
            'accuracy': pl.metrics.Accuracy(),
            'f1': F1(),
            'precision': Precision(),
            'recall': Recall(),
        })
        if mode != 'train':
            metrics['confusion_matrix'] = pl.metrics.ConfusionMatrix(num_classes=2)
            metrics['precision_recall_curve'] = pl.metrics.PrecisionRecallCurve(pos_label=1)

        return metrics
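
# A minimal construction sketch for the Lightning module above. Assumptions: an
# older pytorch-lightning release that still ships pl.metrics (with F1, Precision
# and Recall imported at module level, as in the original file), network access to
# download 'bert-base-uncased', and a hypothetical hydra-style config mimicked with
# SimpleNamespace (only the fields the module actually reads are filled in).
import torch
from types import SimpleNamespace

conf = SimpleNamespace(
    dataset=SimpleNamespace(setup=SimpleNamespace(max_length_input_context=1024)),
    training=SimpleNamespace(lr=2e-5),
)
summarizer = BertSum(conf=conf)

# forward() simply delegates to the (10-layer, randomly re-initialised) BERT encoder
token_ids = torch.randint(1, summarizer.bert.config.vocab_size, (2, 16))
encoder_out = summarizer(token_ids)
print(encoder_out[0].shape)  # (2, 16, hidden_size)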
Example #25
class HIBERT(BertPreTrainedModel):

    def __init__(self, 
                    config, 
                    n_classes, 
                    add_linear=None, 
                    attn_bias=False, 
                    freeze_layer_count=-1):

        super(HIBERT, self).__init__(config)

        self.n_classes = n_classes
        self.add_linear = add_linear
        self.attn_bias = attn_bias
        self.freeze_layer_count = freeze_layer_count
        self.attn_weights = None

        # Define model objects
        self.bert = BertModel(config, add_pooling_layer=False)
        self.fc_in_size = self.bert.config.hidden_size

        # Control layer freezing
        if freeze_layer_count == -1:
            # freeze all bert layers
            for param in self.bert.parameters():
                param.requires_grad = False

        if freeze_layer_count == -2:
            # unfreeze all bert layers
            for param in self.bert.parameters():
                param.requires_grad = True

        if freeze_layer_count > 0:
            # freeze embedding layer
            for param in self.bert.embeddings.parameters():
                param.requires_grad = False

            # freeze the top `freeze_layer_count` encoder layers
            for layer in self.bert.encoder.layer[:freeze_layer_count]:
                for param in layer.parameters():
                    param.requires_grad = False

        # Attention pooling layer
        self.attention = Attention(dim=self.bert.config.hidden_size, attn_bias=self.attn_bias)

        # fully connected layers
        if self.add_linear is None:
            self.fc = nn.ModuleList([nn.Linear(self.fc_in_size, self.n_classes)])
        
        else: 
            self.fc_layers = [self.fc_in_size] + self.add_linear
            self.fc = nn.ModuleList([
                LinearBlock(self.fc_layers[i], self.fc_layers[i+1])
                for i in range(len(self.fc_layers) - 1)
                ])
            
            # no relu after last dense (cannot use LinearBlock)
            self.fc.append(nn.Linear(self.fc_layers[-1], self.n_classes))

    def forward(self, input_ids, attention_mask, n_chunks):

        # Bert transformer (take sequential output)
        output, _ = self.bert(
            input_ids = input_ids, 
            attention_mask = attention_mask, 
            return_dict=False
            )

        # group chunks together
        chunks = output.split_with_sizes(n_chunks.tolist())

        # loop through attention layer (need a loop as there are different sized chunks)
        # collect attention output and attention weights for each call of attention
        after_attn_list = []
        self.attn_weights = []

        for chunk in chunks:
            after_attn_list.append(self.attention(chunk.view(1, -1, self.bert.config.hidden_size)))
            self.attn_weights.append(self.attention.attn_weights)

        output = torch.cat(after_attn_list)

        # fully connected layers
        for fc in self.fc:
            output = fc(output)

        return output
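
# A minimal shape sketch for the hierarchical chunked input expected by HIBERT.
# Assumptions: the Attention and LinearBlock helpers referenced above are importable
# from the original project, and the hypothetical batch below holds two documents
# split into 3 and 2 chunks of 128 tokens each (all chunks stacked on the batch axis).
import torch
from transformers import BertConfig

hibert = HIBERT(BertConfig(), n_classes=4, freeze_layer_count=6)

n_chunks = torch.tensor([3, 2])                # chunks per document
chunk_ids = torch.randint(1, 1000, (5, 128))   # 3 + 2 chunks stacked together
chunk_mask = torch.ones(5, 128, dtype=torch.long)

doc_logits = hibert(chunk_ids, chunk_mask, n_chunks)  # one row of class logits per document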
class UnStructuredModel:

    def __init__(self, model_name, max_length, stride):
        self.model_name = model_name
        self.tokenizer = None
        self.model = None
        self.max_length = max_length
        self.stride = stride
        if model_name == 'bert-base-uncased':
            configuration = BertConfig()
            self.tokenizer = BertTokenizer.from_pretrained(self.model_name)
            self.model = BertModel.from_pretrained(self.model_name)
            self.model.to(device)
            self.model.eval()
            for param in self.model.parameters():
                param.requires_grad = False
            #self.model.bert.embeddings.requires_grad = False


    def padTokens(self, tokens):
        if len(tokens)<self.max_length:
            tokens = tokens + ["[PAD]" for i in range(self.max_length - len(tokens))]
        return tokens

    def getEmbedding(self, text, if_pool=True, pooling_type="mean", batchsize = 1):
        tokens = self.tokenizer.tokenize(text)
        tokenized_array = self.tokenizeText(tokens)
        embeddingTensorsList = []
        print(len(tokenized_array))
        tensor = torch.zeros([1, 768], device=device)
        count = 0
        if len(tokenized_array)>batchsize:
            for i in range(0, len(tokenized_array), batchsize):
                current_tokens = tokenized_array[i:min(i+batchsize,len(tokenized_array))]
                token_ids = torch.tensor(current_tokens).to(device)
                seg_ids=[[0 for _ in range(len(tokenized_array[0]))] for _ in range(len(current_tokens))]
                seg_ids   = torch.tensor(seg_ids).to(device)
                hidden_reps, cls_head = self.model(token_ids, token_type_ids=seg_ids)
                cls_head = cls_head.to(device).detach()
                if if_pool and pooling_type=="mean":
                    tensor = tensor.add(torch.sum(cls_head, dim=0))
                    count +=cls_head.shape[0]
                else:
                    embeddingTensorsList.append(cls_head)
                del cls_head, hidden_reps
            if if_pool and pooling_type=="mean" and count>0:
                embedding = torch.div(tensor, count)
            elif not if_pool:
                embedding = torch.cat(embeddingTensorsList, dim=0)
            else:
                raise NotImplementedError()

        else:
            token_ids = torch.tensor(tokenized_array).to(device)
            seg_ids = [[0 for _ in range(len(tokenized_array[0]))] for _ in range(len(tokenized_array))]
            seg_ids = torch.tensor(seg_ids).to(device)
            hidden_reps, cls_head = self.model(token_ids, token_type_ids=seg_ids)
            cls_head = cls_head.to(device).detach()
            if if_pool and pooling_type=="mean":
                embedding = torch.div(torch.sum(cls_head, dim=0), cls_head.shape[0])
            elif not if_pool:
                embedding = cls_head
            else:
                raise NotImplementedError()
            del cls_head, hidden_reps
        return embedding

    def tokenizeText(self, tokens):
        tokens_array = []
        #window_movement_tokens =  max_length - stride
        for i in range(0, len(tokens), self.stride):
            if i+self.max_length<len(tokens):
                curr_tokens = ["[CLS]"] + tokens[i:i+self.max_length] + ["[SEP]"]
            else:
                padded_tokens = self.padTokens(tokens[i:i+self.max_length])
                curr_tokens = ["[CLS]"] + padded_tokens + ["[SEP]"]
            curr_tokens = self.tokenizer.convert_tokens_to_ids(curr_tokens)
            tokens_array.append(curr_tokens)
        return tokens_array
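
# A minimal usage sketch for the sliding-window embedder above. Assumptions: the
# module-level `device` the class relies on is defined as below, and a transformers
# version is installed whose models still return tuples by default (the class
# unpacks `hidden_reps, cls_head` from the BERT call).
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

embedder = UnStructuredModel('bert-base-uncased', max_length=400, stride=200)
doc_embedding = embedder.getEmbedding("a long document to be embedded ...",
                                      if_pool=True, pooling_type="mean")
print(doc_embedding.shape)  # a single mean-pooled [CLS] vector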
def train(config, bert_config, train_path, dev_path, rel2id, id2rel,
          tokenizer):
    if os.path.exists(config.output_dir) is False:
        os.makedirs(config.output_dir, exist_ok=True)
    if os.path.exists('./data/train_file.pkl'):
        train_data = pickle.load(open("./data/train_file.pkl", mode='rb'))
    else:
        train_data = data.load_data(train_path, tokenizer, rel2id, num_rels)
        pickle.dump(train_data, open("./data/train_file.pkl", mode='wb'))
    dev_data = json.load(open(dev_path))
    for sent in dev_data:
        data.to_tuple(sent)
    data_manager = data.SPO(train_data)
    train_sampler = RandomSampler(data_manager)
    train_data_loader = DataLoader(data_manager,
                                   sampler=train_sampler,
                                   batch_size=config.batch_size,
                                   drop_last=True)
    num_train_steps = int(
        len(data_manager) / config.batch_size) * config.max_epoch

    if config.bert_pretrained_model is not None:
        logger.info('load bert weight')
        Bert_model = BertModel.from_pretrained(config.bert_pretrained_model,
                                               config=bert_config)
    else:
        logger.info('random initialize bert model')
        Bert_model = BertModel(config=bert_config)
    Bert_model.to(device)
    submodel = sub_model(config).to(device)
    objmodel = obj_model(config).to(device)

    loss_func = nn.BCELoss(reduction='none')
    params = list(Bert_model.parameters()) + list(
        submodel.parameters()) + list(objmodel.parameters())
    optimizer = AdamW(params, lr=config.lr)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(data_manager))
    logger.info("  Num Epochs = %d", config.max_epoch)
    logger.info("  Total train batch size = %d", config.batch_size)
    logger.info("  Total optimization steps = %d", num_train_steps)
    logger.info("  Logging steps = %d", config.print_freq)
    logger.info("  Save steps = %d", config.save_freq)

    global_step = 0
    Bert_model.train()
    submodel.train()
    objmodel.train()

    for _ in range(config.max_epoch):
        optimizer.zero_grad()
        epoch_iterator = tqdm(train_data_loader, disable=None)
        for step, batch in enumerate(epoch_iterator):
            batch = tuple(t.to(device) for t in batch)
            input_ids, segment_ids, input_masks, sub_positions, sub_heads, sub_tails, obj_heads, obj_tails = batch

            bert_output = Bert_model(input_ids, input_masks, segment_ids)[0]
            pred_sub_heads, pred_sub_tails = submodel(
                bert_output)  # [batch_size, seq_len, 1]
            pred_obj_heads, pred_obj_tails = objmodel(bert_output,
                                                      sub_positions)

            # compute the losses
            mask = input_masks.view(-1)

            # loss1
            sub_heads = sub_heads.unsqueeze(-1)  # [batch_size, seq_len, 1]
            sub_tails = sub_tails.unsqueeze(-1)

            loss1_head = loss_func(pred_sub_heads, sub_heads).view(-1)
            loss1_head = torch.sum(loss1_head * mask) / torch.sum(mask)

            loss1_tail = loss_func(pred_sub_tails, sub_tails).view(-1)
            loss1_tail = torch.sum(loss1_tail * mask) / torch.sum(mask)

            loss1 = loss1_head + loss1_tail

            # loss2
            loss2_head = loss_func(pred_obj_heads,
                                   obj_heads).view(-1, obj_heads.shape[-1])
            loss2_head = torch.sum(
                loss2_head * mask.unsqueeze(-1)) / torch.sum(mask)

            loss2_tail = loss_func(pred_obj_tails,
                                   obj_tails).view(-1, obj_tails.shape[-1])
            loss2_tail = torch.sum(
                loss2_tail * mask.unsqueeze(-1)) / torch.sum(mask)

            loss2 = loss2_head + loss2_tail

            # optimize
            loss = loss1 + loss2
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            global_step += 1
            if (global_step + 1) % config.print_freq == 0:
                logger.info(
                    "epoch : {} step: {} #### loss1: {}  loss2: {}".format(
                        _, global_step + 1,
                        loss1.cpu().item(),
                        loss2.cpu().item()))

            if (global_step + 1) % config.eval_freq == 0:
                logger.info("***** Running evaluating *****")
                with torch.no_grad():
                    Bert_model.eval()
                    submodel.eval()
                    objmodel.eval()
                    P, R, F1 = utils.metric(Bert_model, submodel, objmodel,
                                            dev_data, id2rel, tokenizer)
                    logger.info(f'precision:{P}\nrecall:{R}\nF1:{F1}')
                Bert_model.train()
                submodel.train()
                objmodel.train()

            if (global_step + 1) % config.save_freq == 0:
                # Save a trained model
                model_name = "pytorch_model_%d" % (global_step + 1)
                output_model_file = os.path.join(config.output_dir, model_name)
                state = {
                    'bert_state_dict': Bert_model.state_dict(),
                    'subject_state_dict': submodel.state_dict(),
                    'object_state_dict': objmodel.state_dict(),
                }
                torch.save(state, output_model_file)

    model_name = "pytorch_model_last"
    output_model_file = os.path.join(config.output_dir, model_name)
    state = {
        'bert_state_dict': Bert_model.state_dict(),
        'subject_state_dict': submodel.state_dict(),
        'object_state_dict': objmodel.state_dict(),
    }
    torch.save(state, output_model_file)
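
# A minimal restore sketch for the checkpoints written by train() above. Assumptions:
# `config`, `bert_config`, `device` and the sub_model / obj_model classes come from
# the original project, and a "pytorch_model_last" checkpoint exists in config.output_dir.
import os
import torch

state = torch.load(os.path.join(config.output_dir, "pytorch_model_last"),
                   map_location=device)
restored_bert = BertModel(config=bert_config)
restored_bert.load_state_dict(state['bert_state_dict'])
restored_sub = sub_model(config)
restored_sub.load_state_dict(state['subject_state_dict'])
restored_obj = obj_model(config)
restored_obj.load_state_dict(state['object_state_dict'])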
class BertForMultiLabelSequenceClassification(BertPreTrainedModel):
    """
    Bert model adapted for multi-label sequence classification.
    Note that for imbalance problems will also provide an extra parameter to add inside
    the loss function to integrate the classes distribution.
    """

    def __init__(self, config):
        super(BertForMultiLabelSequenceClassification, self).__init__(config)
        self.num_labels = config.num_labels
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.pos_weight = torch.Tensor(config.pos_weight).to(device) if config.use_pos_weight else None

        self.init_weights()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None,
                head_mask=None, labels=None):
        """
        :param input_ids: sentence or sentences represented as tokens
        :param attention_mask: tells the model which tokens in the input_ids are words and which are padding.
                               1 indicates a token and 0 indicates padding.
        :param token_type_ids: used when there are two sentences that need to be part of the input. It indicate which
                               tokens are part of sentence1 and which are part of sentence2.
        :param position_ids: indices of positions of each input sequence tokens in the position embeddings. Selected
                             in the range ``[0, config.max_position_embeddings - 1]
        :param head_mask: mask to nullify selected heads of the self-attention modules
        :param labels: target for each input
        :return:
        """
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
        )
        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        outputs = (logits,) + outputs[2:]

        if labels is not None:
            loss_fct = BCEWithLogitsLoss(pos_weight=self.pos_weight)
            labels = labels.float()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1, self.num_labels))
            outputs = (loss,) + outputs

        return outputs

    def freeze_bert_encoder(self):
        """Freeze BERT layers"""
        for param in self.bert.parameters():
            param.requires_grad = False

    def unfreeze_bert_encoder(self):
        """Unfreeze BERT layers"""
        for param in self.bert.parameters():
            param.requires_grad = True
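
# A minimal usage sketch for the multi-label classifier above. Assumptions: the
# module-level `device` the class reads is defined as below, and the extra
# `use_pos_weight` / `pos_weight` fields (not standard BertConfig attributes)
# are attached to the config before construction.
import torch
from transformers import BertConfig

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

ml_config = BertConfig(num_labels=5)
ml_config.use_pos_weight = True
ml_config.pos_weight = [1.0, 2.0, 1.0, 4.0, 1.0]  # per-label positive-class weights

ml_model = BertForMultiLabelSequenceClassification(ml_config).to(device)
input_ids = torch.randint(1, ml_config.vocab_size, (2, 16)).to(device)
labels = torch.randint(0, 2, (2, 5)).to(device)    # multi-hot targets
loss, logits = ml_model(input_ids, labels=labels)[:2]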
Example #29
class SpERT(BertPreTrainedModel):
    """ Span-based model to jointly extract entities and relations """
    def __init__(self,
                 config: BertConfig,
                 cls_token: int,
                 relation_types: int,
                 entity_types: int,
                 size_embedding: int,
                 prop_drop: float,
                 freeze_transformer: bool,
                 max_pairs: int = 100):
        super(SpERT, self).__init__(config)

        # BERT model
        self.bert = BertModel(config)

        # layers
        self.rel_classifier = nn.Linear(
            config.hidden_size * 3 + size_embedding * 2, relation_types)
        self.entity_classifier = nn.Linear(
            config.hidden_size * 2 + size_embedding, entity_types)
        self.size_embeddings = nn.Embedding(100, size_embedding)
        self.dropout = nn.Dropout(prop_drop)

        self._cls_token = cls_token
        self._relation_types = relation_types
        self._entity_types = entity_types
        self._max_pairs = max_pairs

        # weight initialization
        self.init_weights()

        if freeze_transformer:
            print("Freeze transformer weights")

            # freeze all transformer weights
            for param in self.bert.parameters():
                param.requires_grad = False

    def _forward_train(self, encodings: torch.tensor,
                       context_masks: torch.tensor, entity_masks: torch.tensor,
                       entity_sizes: torch.tensor, relations: torch.tensor,
                       rel_masks: torch.tensor):
        # get contextualized token embeddings from last transformer layer
        context_masks = context_masks.float()
        h = self.bert(input_ids=encodings, attention_mask=context_masks)[0]

        entity_masks = entity_masks.float()
        batch_size = encodings.shape[0]

        # classify entities
        size_embeddings = self.size_embeddings(
            entity_sizes)  # embed entity candidate sizes
        entity_clf, entity_spans_pool = self._classify_entities(
            encodings, h, entity_masks, size_embeddings)

        # classify relations
        rel_masks = rel_masks.float().unsqueeze(-1)
        h_large = h.unsqueeze(1).repeat(
            1, max(min(relations.shape[1], self._max_pairs), 1), 1, 1)
        rel_clf = torch.zeros(
            [batch_size, relations.shape[1],
             self._relation_types]).to(self.rel_classifier.weight.device)

        # obtain relation logits
        # chunk processing to reduce memory usage
        for i in range(0, relations.shape[1], self._max_pairs):
            # classify relation candidates
            chunk_rel_logits = self._classify_relations(
                entity_spans_pool, size_embeddings, relations, rel_masks,
                h_large, i)
            rel_clf[:, i:i + self._max_pairs, :] = chunk_rel_logits

        return entity_clf, rel_clf

    def _forward_eval(self, encodings: torch.tensor,
                      context_masks: torch.tensor, entity_masks: torch.tensor,
                      entity_sizes: torch.tensor, entity_spans: torch.tensor,
                      entity_sample_masks: torch.tensor):
        # get contextualized token embeddings from last transformer layer
        context_masks = context_masks.float()
        h = self.bert(input_ids=encodings, attention_mask=context_masks)[0]

        entity_masks = entity_masks.float()
        batch_size = encodings.shape[0]
        ctx_size = context_masks.shape[-1]

        # classify entities
        size_embeddings = self.size_embeddings(
            entity_sizes)  # embed entity candidate sizes
        entity_clf, entity_spans_pool = self._classify_entities(
            encodings, h, entity_masks, size_embeddings)

        # ignore entity candidates that do not constitute an actual entity for relations (based on classifier)
        relations, rel_masks, rel_sample_masks = self._filter_spans(
            entity_clf, entity_spans, entity_sample_masks, ctx_size)
        rel_masks = rel_masks.float()
        rel_sample_masks = rel_sample_masks.float()
        h_large = h.unsqueeze(1).repeat(
            1, max(min(relations.shape[1], self._max_pairs), 1), 1, 1)
        rel_clf = torch.zeros(
            [batch_size, relations.shape[1],
             self._relation_types]).to(self.rel_classifier.weight.device)

        # obtain relation logits
        # chunk processing to reduce memory usage
        for i in range(0, relations.shape[1], self._max_pairs):
            # classify relation candidates
            chunk_rel_logits = self._classify_relations(
                entity_spans_pool, size_embeddings, relations, rel_masks,
                h_large, i)
            # apply sigmoid
            chunk_rel_clf = torch.sigmoid(chunk_rel_logits)
            rel_clf[:, i:i + self._max_pairs, :] = chunk_rel_clf

        rel_clf = rel_clf * rel_sample_masks  # mask

        # apply softmax
        entity_clf = torch.softmax(entity_clf, dim=2)

        return entity_clf, rel_clf, relations

    def _classify_entities(self, encodings, h, entity_masks, size_embeddings):
        # max pool entity candidate spans
        entity_spans_pool = entity_masks.unsqueeze(-1) * h.unsqueeze(1).repeat(
            1, entity_masks.shape[1], 1, 1)
        entity_spans_pool = entity_spans_pool.max(dim=2)[0]

        # get cls token as candidate context representation
        entity_ctx = get_token(h, encodings, self._cls_token)

        # create candidate representations including context, max pooled span and size embedding
        entity_repr = torch.cat([
            entity_ctx.unsqueeze(1).repeat(1, entity_spans_pool.shape[1], 1),
            entity_spans_pool, size_embeddings
        ],
                                dim=2)
        entity_repr = self.dropout(entity_repr)

        # classify entity candidates
        entity_clf = self.entity_classifier(entity_repr)

        return entity_clf, entity_spans_pool

    def _classify_relations(self, entity_spans, size_embeddings, relations,
                            rel_masks, h, chunk_start):
        batch_size = relations.shape[0]

        # create chunks if necessary
        if relations.shape[1] > self._max_pairs:
            relations = relations[:, chunk_start:chunk_start + self._max_pairs]
            rel_masks = rel_masks[:, chunk_start:chunk_start + self._max_pairs]
            h = h[:, :relations.shape[1], :]

        # get pairs of entity candidate representations
        entity_pairs = util.batch_index(entity_spans, relations)
        entity_pairs = entity_pairs.view(batch_size, entity_pairs.shape[1], -1)

        # get corresponding size embeddings
        size_pair_embeddings = util.batch_index(size_embeddings, relations)
        size_pair_embeddings = size_pair_embeddings.view(
            batch_size, size_pair_embeddings.shape[1], -1)

        # relation context (context between entity candidate pair)
        rel_ctx = rel_masks * h
        rel_ctx = rel_ctx.max(dim=2)[0]

        # create relation candidate representations including context, max pooled entity candidate pairs
        # and corresponding size embeddings
        rel_repr = torch.cat([rel_ctx, entity_pairs, size_pair_embeddings],
                             dim=2)
        rel_repr = self.dropout(rel_repr)

        # classify relation candidates
        chunk_rel_logits = self.rel_classifier(rel_repr)
        return chunk_rel_logits

    def _filter_spans(self, entity_clf, entity_spans, entity_sample_masks,
                      ctx_size):
        batch_size = entity_clf.shape[0]
        entity_logits_max = entity_clf.argmax(
            dim=-1) * entity_sample_masks.long(
            )  # get entity type (including none)
        batch_relations = []
        batch_rel_masks = []
        batch_rel_sample_masks = []

        for i in range(batch_size):
            rels = []
            rel_masks = []
            sample_masks = []

            # get spans classified as entities
            non_zero_indices = (entity_logits_max[i] != 0).nonzero().view(-1)
            non_zero_spans = entity_spans[i][non_zero_indices].tolist()
            non_zero_indices = non_zero_indices.tolist()

            # create relations and masks
            for i1, s1 in zip(non_zero_indices, non_zero_spans):
                for i2, s2 in zip(non_zero_indices, non_zero_spans):
                    if i1 != i2:
                        rels.append((i1, i2))
                        rel_masks.append(
                            sampling.create_rel_mask(s1, s2, ctx_size))
                        sample_masks.append(1)

            if not rels:
                # case: no more than two spans classified as entities
                batch_relations.append(torch.tensor([[0, 0]],
                                                    dtype=torch.long))
                batch_rel_masks.append(
                    torch.tensor([[0] * ctx_size], dtype=torch.bool))
                batch_rel_sample_masks.append(
                    torch.tensor([0], dtype=torch.bool))
            else:
                # case: more than two spans classified as entities
                batch_relations.append(torch.tensor(rels, dtype=torch.long))
                batch_rel_masks.append(torch.stack(rel_masks))
                batch_rel_sample_masks.append(
                    torch.tensor(sample_masks, dtype=torch.bool))

        # stack
        device = self.rel_classifier.weight.device
        batch_relations = util.padded_stack(batch_relations).to(device)
        batch_rel_masks = util.padded_stack(batch_rel_masks).to(
            device).unsqueeze(-1)
        batch_rel_sample_masks = util.padded_stack(batch_rel_sample_masks).to(
            device).unsqueeze(-1)

        return batch_relations, batch_rel_masks, batch_rel_sample_masks

    def forward(self, *args, evaluate=False, **kwargs):
        if not evaluate:
            return self._forward_train(*args, **kwargs)
        else:
            return self._forward_eval(*args, **kwargs)
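
# A minimal shape sketch for a training-mode forward pass of SpERT. Assumptions:
# the helper modules referenced above (get_token, util.batch_index, sampling) are
# importable from the original SpERT code base, and the tiny batch below
# (1 sentence, 3 candidate spans, 2 candidate relations) is purely hypothetical.
import torch
from transformers import BertConfig

spert = SpERT(BertConfig(), cls_token=101, relation_types=4, entity_types=5,
              size_embedding=25, prop_drop=0.1, freeze_transformer=False)

seq_len = 16
encodings = torch.randint(1, 1000, (1, seq_len))
encodings[0, 0] = 101                                   # [CLS] token id
context_masks = torch.ones(1, seq_len, dtype=torch.bool)

entity_masks = torch.zeros(1, 3, seq_len, dtype=torch.bool)  # 3 candidate spans
entity_masks[0, 0, 1:3] = True
entity_masks[0, 1, 4:6] = True
entity_masks[0, 2, 7:9] = True
entity_sizes = torch.tensor([[2, 2, 2]])

relations = torch.tensor([[[0, 1], [1, 2]]])            # index pairs into the span list
rel_masks = torch.zeros(1, 2, seq_len, dtype=torch.bool)
rel_masks[0, 0, 3:4] = True                             # context between span 0 and span 1
rel_masks[0, 1, 6:7] = True                             # context between span 1 and span 2

entity_clf, rel_clf = spert(encodings, context_masks, entity_masks, entity_sizes,
                            relations, rel_masks, evaluate=False)
print(entity_clf.shape, rel_clf.shape)                  # (1, 3, 5) and (1, 2, 4)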
Example #30
class BertLstmCrf(BertPreTrainedModel):
    def __init__(self, config, extra_config, ignore_ids):
        """
        num_labels : int, required
            Number of tags.
        idx2tag : ``Dict[int, str]``, required
            A mapping {label_id -> label}. Example: {0:"B-LOC", 1:"I-LOC", 2:"O"}
        label_encoding : ``str``, required
            Indicates which constraint to apply. Current choices are
            "BIO", "IOB1", "BIOUL", "BMES" and "BIOES",.
                B = Beginning
                I/M = Inside / Middle
                L/E = Last / End
                O = Outside
                U/W/S = Unit / Whole / Single
        """
        super(BertLstmCrf, self).__init__(config)
        self.pretraind = BertModel(config)
        self.dropout = nn.Dropout(extra_config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.bilstm = nn.LSTM(input_size=config.hidden_size,
                              hidden_size=config.hidden_size // 2,
                              batch_first=True,
                              num_layers=extra_config.num_layers,
                              dropout=extra_config.lstm_dropout,
                              bidirectional=True)
        self.crf = crf(config.num_labels, extra_config.label_encoding,
                       extra_config.idx2tag)
        self.init_weights()
        if extra_config.freez_prrtrained:
            for param in self.pretraind.parameters():
                param.requires_grad = False

        self.ignore_ids = ignore_ids

    def forward(self,
                input_ids,
                attention_mask=None,
                token_type_ids=None,
                position_ids=None,
                head_mask=None,
                labels=None):
        # outputs consists of:
        # last_hidden_state: Sequence of hidden-states at the output of the last layer of the model.
        #                     (batch_size, sequence_length, hidden_size)
        # pooler_output:      Last layer hidden-state of the first token of the sequence (classification token)
        #                     processed by a Linear layer and a Tanh activation function.
        # hidden_states:     one for the output of the embeddings + one for the output of each layer.
        #                     each is (batch_size, sequence_length, hidden_size)
        # attentions:         Attentions weights after the attention softmax of each layer.
        #                     each is (batch_size, num_heads, sequence_length, sequence_length)
        outputs = self.pretraind(input_ids,
                                 attention_mask=attention_mask,
                                 token_type_ids=token_type_ids,
                                 position_ids=position_ids,
                                 head_mask=head_mask)
        last_hidden_state = outputs[0]

        seq_output = self.dropout(last_hidden_state)
        seq_output, _ = self.bilstm(seq_output)
        # functional layer_norm: avoids constructing a fresh (CPU-initialised) LayerNorm module on every forward pass
        seq_output = nn.functional.layer_norm(seq_output, seq_output.size()[-1:])
        logits = self.classifier(seq_output)

        outputs = (logits, ) + outputs[2:]

        if labels is not None:
            masked_labels, masked_logits = self._get_masked_inputs(
                input_ids, labels, logits, attention_mask)
            loss = self.crf(masked_logits, masked_labels,
                            mask=None)  # mask=None: the useless positions have already been removed
            outputs = (loss, ) + outputs

        # (loss), logits, (hidden_states), (attentions)
        return outputs

    def _get_masked_inputs(self, input_ids, label_ids, logits, attention_mask):
        ignore_ids = self.ignore_ids

        # Remove unuseful positions
        masked_ids = input_ids[(1 == attention_mask)]
        masked_labels = label_ids[(1 == attention_mask)]
        masked_logits = logits[(1 == attention_mask)]
        for id in ignore_ids:
            masked_labels = masked_labels[(id != masked_ids)]
            masked_logits = masked_logits[(id != masked_ids)]
            masked_ids = masked_ids[(id != masked_ids)]

        return masked_labels, masked_logits
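
# A minimal construction sketch for the BERT + BiLSTM + CRF tagger above. Assumptions:
# the `crf` wrapper class it instantiates is importable from the original project and
# returns a scalar loss when called; the tag set and extra_config fields below are
# hypothetical.
import torch
from types import SimpleNamespace
from transformers import BertConfig

idx2tag = {0: 'O', 1: 'B-LOC', 2: 'I-LOC'}
ner_config = BertConfig(num_labels=len(idx2tag))
extra_config = SimpleNamespace(hidden_dropout_prob=0.1, num_layers=2, lstm_dropout=0.2,
                               label_encoding='BIO', idx2tag=idx2tag,
                               freez_prrtrained=True)
ignore_ids = [101, 102, 0]  # [CLS], [SEP] and [PAD] ids dropped before the CRF

tagger = BertLstmCrf(ner_config, extra_config, ignore_ids)

input_ids = torch.randint(1, ner_config.vocab_size, (2, 16))
attention_mask = torch.ones(2, 16, dtype=torch.long)
labels = torch.randint(0, len(idx2tag), (2, 16))
loss, logits = tagger(input_ids, attention_mask=attention_mask, labels=labels)[:2]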