def __init__(self,
             num_classes: int,
             bert_weights: str,
             dropout: float = .10):
    """Create the BERT encoder, the two classification heads, attention and CRF.

    Args:
        num_classes: number of span tags; sizes both the span head and the CRF.
        bert_weights: model name or path handed to ``BertModel.from_pretrained``.
        dropout: probability used by the ``nn.Dropout`` module.
    """
    super(BertCRFClassifier, self).__init__()

    self.bert = BertModel.from_pretrained(bert_weights)

    # Only the final five parameter tensors of BERT remain trainable.
    bert_params = list(self.bert.parameters())
    for frozen in bert_params[:-5]:
        frozen.requires_grad = False

    width = self.bert.config.hidden_size

    # Span-tag head (num_classes outputs) and a two-way binary head.
    self.span_clf_head = nn.Linear(width, num_classes)
    self.binary_clf_head = nn.Linear(width, 2)

    self.attention = SelfAttention(width, batch_first=True)

    self.dropout = nn.Dropout(p=dropout)
    self.crf = CRF(num_tags=num_classes)
# Example #2
0
def get_convo_nn2(no_word=200, n_gram=21, no_char=178):
    """Build and compile a two-input convolutional model topped by a CRF layer.

    Args:
        no_word: third argument handed to ``conv_unit`` (presumably its filter
            count): used as-is for windows 1-8, minus 50 for windows 9-11 and
            minus 100 for window 12.
        n_gram: length of both integer input sequences.
        no_char: vocabulary size of the first input's embedding (presumably
            character ids - confirm).

    Returns:
        A Keras ``Model`` compiled with ``"adam"`` and the CRF's loss function.

    NOTE(review): ``Flatten`` + ``Dense(100)`` yield a 2-D tensor that is fed
    straight into the CRF layer. keras_contrib's CRF decodes along a time axis
    and expects 3-D ``(batch, timesteps, features)`` input - confirm this builds
    and trains as intended (a per-position ``TimeDistributed(Dense(...))``
    without ``Flatten`` may have been meant; that would change the target shape).
    """
    input1 = Input(shape=(n_gram, ))
    input2 = Input(shape=(n_gram, ))

    # Branch A: embedding of the first input, regularised and normalised.
    a = Embedding(no_char, 32, input_length=n_gram)(input1)
    a = SpatialDropout1D(0.15)(a)
    a = BatchNormalization()(a)

    # One conv_unit per window size 1..12; outputs are merged element-wise by max.
    a_concat = []
    for i in range(1, 9):
        a_concat.append(conv_unit(a, n_gram, no_word, window=i))
    for i in range(9, 12):
        a_concat.append(conv_unit(a, n_gram, no_word - 50, window=i))
    a_concat.append(conv_unit(a, n_gram, no_word - 100, window=12))
    a_sum = Maximum()(a_concat)

    # Branch B: 12-symbol embedding of the second input (meaning of ids not shown here).
    b = Embedding(12, 12, input_length=n_gram)(input2)
    b = SpatialDropout1D(0.15)(b)

    # Assumes conv_unit keeps the n_gram time axis so the three tensors align - TODO confirm.
    x = Concatenate(axis=-1)([a, a_sum, b])
    #x = Concatenate(axis=-1)([a_sum, b])
    x = BatchNormalization()(x)

    x = Flatten()(x)
    x = Dense(100, activation='relu')(x)

    ########################
    # Earlier alternative output head (single sigmoid unit):
    #out = Dense(1, activation='sigmoid')(x)

    # crf = CRF(n_gram)  ## ???????
    # out = crf(x)

    # out = x.add(CRF(100))

    # Two-label CRF with one-hot (non-sparse) targets; see NOTE(review) above on its 2-D input.
    crf = CRF(2, sparse_target=False)  # num_label
    loss = crf.loss_function
    out = crf(x)
    ##########################

    model = Model(inputs=[input1, input2], outputs=out)

    ####################
    # model.compile(optimizer=Adam(),
    #   loss='binary_crossentropy', metrics=['acc'])
    model.compile(optimizer="adam", loss=loss)
    #####################
    return model
# Example #3
0
    def build(self):
        """Assemble the BiRNN sequence-labelling graph.

        Word ids - plus per-word character ids when ``self._use_char`` is set -
        are embedded, encoded by a bidirectional RNN and projected by a tanh
        ``Dense`` layer. The output layer is a CRF when ``self._use_crf`` is set,
        otherwise a softmax ``Dense`` layer.

        Returns:
            tuple: ``(model, loss)`` - the uncompiled Keras ``Model`` and the
            loss to compile it with (the CRF's ``loss_function`` or
            ``'categorical_crossentropy'``).
        """
        word_ids = Input(batch_shape=(None, None), dtype='int32', name='word_input')
        model_inputs = [word_ids]

        if self._embeddings is not None:
            # Pre-trained matrix: it sizes the table and seeds its weights.
            word_table = Embedding(input_dim=self._embeddings.shape[0],
                                   output_dim=self._embeddings.shape[1],
                                   mask_zero=True,
                                   weights=[self._embeddings],
                                   name='word_embedding')
        else:
            word_table = Embedding(input_dim=self._word_vocab_size,
                                   output_dim=self._word_embedding_dim,
                                   mask_zero=True,
                                   name='word_embedding')
        features = word_table(word_ids)

        if self._use_char:
            # One vector per word from a BiRNN run over that word's characters.
            char_ids = Input(batch_shape=(None, None, None), dtype='int32', name='char_input')
            model_inputs.append(char_ids)
            char_table = Embedding(input_dim=self._char_vocab_size,
                                   output_dim=self._char_embedding_dim,
                                   mask_zero=True,
                                   name='char_embedding')
            char_vectors = char_table(char_ids)
            char_encoder = Bidirectional(rnn.get_rnn_layer(
                self._cell_type, self._char_lstm_size, return_sequences=False))
            char_vectors = TimeDistributed(char_encoder)(char_vectors)
            features = Concatenate()([features, char_vectors])

        features = Dropout(self._dropout)(features)
        word_encoder = Bidirectional(rnn.get_rnn_layer(
            self._cell_type, self._word_lstm_size, return_sequences=True))
        encoded = word_encoder(features)
        hidden = Dense(self._fc_dim, activation='tanh')(encoded)

        if not self._use_crf:
            loss = 'categorical_crossentropy'
            pred = Dense(self._num_labels, activation='softmax')(hidden)
        else:
            crf = CRF(self._num_labels, sparse_target=False)
            loss = crf.loss_function
            pred = crf(hidden)

        return Model(inputs=model_inputs, outputs=pred), loss
# Example #4
0
    def build(self):
        """Assemble the word + character + ELMo BiRNN-CRF tagger.

        Per token, three representations are concatenated: a frozen word
        embedding, a character-level BiRNN encoding of the word, and a
        1024-dimensional ELMo vector supplied as a separate model input. A
        bidirectional RNN, a tanh ``Dense`` layer and a CRF follow.

        Returns:
            tuple: ``(model, loss)`` - the uncompiled Keras ``Model`` with inputs
            ``[word_ids, char_ids, elmo_embeddings]`` and the CRF's
            ``loss_function``.
        """
        # build word embedding (trainable=False in both branches)
        word_ids = Input(batch_shape=(None, None), dtype='int32', name='word_input')
        if self._embeddings is None:
            # NOTE(review): no pre-trained matrix here, so this table is randomly
            # initialised AND frozen - confirm that is intended.
            word_embeddings = Embedding(input_dim=self._word_vocab_size,
                                        output_dim=self._word_embedding_dim,
                                        mask_zero=True,
                                        name='word_embedding', trainable=False)(word_ids)
        else:
            word_embeddings = Embedding(input_dim=self._embeddings.shape[0],
                                        output_dim=self._embeddings.shape[1],
                                        mask_zero=True,
                                        weights=[self._embeddings],
                                        name='word_embedding', trainable=False)(word_ids)

        # build character based word embedding: one BiRNN encoding per word
        char_ids = Input(batch_shape=(None, None, None), dtype='int32', name='char_input')
        char_embeddings = Embedding(input_dim=self._char_vocab_size,
                                    output_dim=self._char_embedding_dim,
                                    mask_zero=True,
                                    name='char_embedding')(char_ids)
        char_embeddings = TimeDistributed(Bidirectional(rnn.get_rnn_layer(
            self._cell_type, self._char_lstm_size, return_sequences=False)))(char_embeddings)

        # ELMo vectors are computed outside the graph and fed in directly.
        elmo_embeddings = Input(shape=(None, 1024), dtype='float32')

        word_embeddings = Concatenate()([word_embeddings, char_embeddings, elmo_embeddings])

        word_embeddings = Dropout(self._dropout)(word_embeddings)
        z = Bidirectional(rnn.get_rnn_layer(
            self._cell_type, self._word_lstm_size, return_sequences=True))(word_embeddings)
        z = Dense(self._fc_dim, activation='tanh')(z)

        crf = CRF(self._num_labels, sparse_target=False)
        loss = crf.loss_function
        pred = crf(z)

        model = Model(inputs=[word_ids, char_ids, elmo_embeddings], outputs=pred)

        return model, loss
# Example #5
0
def build_model(config_path,
                checkpoint_path,
                max_seq_length,
                label_num,
                bert_trainable=False):
    """Build and compile a BERT -> BiLSTM -> Dense -> CRF sequence tagger.

    Args:
        config_path: BERT config file handed to ``load_pretrained_model``.
        checkpoint_path: BERT checkpoint handed to ``load_pretrained_model``.
        max_seq_length: fixed length of the token-id and segment-id inputs.
        label_num: number of tags emitted by the CRF.
        bert_trainable: whether BERT's layers are updated during training.

    Returns:
        The compiled Keras ``Model``; a summary is printed as a side effect.
    """
    token_ids = Input(shape=(max_seq_length, ), name="input_ids", dtype="int32")
    segment_ids = Input(shape=(max_seq_length, ),
                        name="segment_ids",
                        dtype="int32")

    # Build the BERT model and load its pre-trained weights.
    bert = load_pretrained_model(config_path, checkpoint_path)

    trainable = bool(bert_trainable)
    for layer in bert.layers:
        layer.trainable = trainable

    encoded = bert([token_ids, segment_ids])
    encoded = Bidirectional(CuDNNLSTM(128, return_sequences=True))(encoded)

    layer_dense = Dense(64, activation='tanh', name='layer_dense')
    layer_crf_dense = Dense(label_num, name='layer_crf_dense')
    layer_crf = CRF(label_num, name='layer_crf')

    # Project to one emission score per tag, then decode with the CRF.
    emissions = layer_crf_dense(layer_dense(encoded))
    pred = layer_crf(emissions)

    model = Model(inputs=[token_ids, segment_ids], outputs=pred)
    model.compile(loss=layer_crf.loss,
                  optimizer=Adam(lr=1e-5),
                  metrics=[layer_crf.viterbi_accuracy])

    model.summary(line_length=150)
    return model
    def __init__(self, labels, n_words, n_chars):
        """Build and compile a word + character BiLSTM-CRF tagger.

        Args:
            labels: tag collection; only ``len(labels)`` is used.
            n_words: word vocabulary size (the embedding has ``n_words + 1`` rows).
            n_chars: character vocabulary size (the embedding has ``n_chars + 2`` rows).

        Sets ``self.model`` (compiled with ``'adam'``), plus ``self.loss`` and
        ``self.accuracy`` taken from the CRF layer.
        """
        self.n_labels = len(labels)
        self.n_words = n_words
        self.n_chars = n_chars

        # Word embedding
        word_in = Input(shape=(None,))
        word_emb = Embedding(input_dim=self.n_words + 1, output_dim=128)(word_in)

        # Character embedding: one sequence of character ids per word
        char_in = Input(shape=(None, None,))
        char_emb = TimeDistributed(Embedding(input_dim=self.n_chars + 2, output_dim=16,
                                             mask_zero=True))(char_in)

        # Character LSTM: one fixed-size encoding per word
        char_enc = TimeDistributed(LSTM(units=28, return_sequences=False,
                                        recurrent_dropout=0.5))(char_emb)

        concat = concatenate([word_emb, char_enc])
        bi_lstm = SpatialDropout1D(0.3)(concat)

        # Two stacked bidirectional LSTM layers
        for _ in range(2):
            bi_lstm = Bidirectional(
                LSTM(
                    units=256,
                    return_sequences=True,
                    recurrent_dropout=0.3
                )
            )(bi_lstm)

        # Linear emission scores for the CRF. A ReLU here would clamp every
        # negative score to zero, so the CRF could never penalise a tag; the CRF
        # turns these unconstrained scores into sequence probabilities itself.
        linear = TimeDistributed(Dense(self.n_labels, activation='linear'))(bi_lstm)

        crf = CRF(self.n_labels, sparse_target=False)
        pred = crf(linear)

        self.model = Model(inputs=[word_in, char_in], outputs=pred)
        self.loss = crf.loss_function
        self.accuracy = crf.accuracy
        self.model.compile(loss=self.loss, optimizer='adam', metrics=[self.accuracy])
# Example #7
0
# Output layers get random initialisation exactly when the CRF is enabled.
randomInit = bool(doCRF)

# Main classifier, then the "ET" classifier (presumably entity types).
outputLayer = LogisticRegression(n_in=hiddenUnits, n_out=numClasses,
                                 rng=rng, randomInit=randomInit)
layers.append(outputLayer)
outputLayerET = LogisticRegression(n_in=hiddenUnitsET, n_out=numClassesET,
                                   rng=rng, randomInit=randomInit)
layers.append(outputLayerET)

if doCRF:
    # Length-3 CRF whose tag set is the union of both classifiers' classes.
    crfLayer = CRF(numClasses=numClasses + numClassesET, rng=rng,
                   batchsizeVar=batchsizeVar, sequenceLength=3)
    layers.append(crfLayer)


def _embed_context(x):
    """Return ``(flat_ids, emb)`` for one context index matrix ``x``.

    ``flat_ids`` is ``x`` reshaped to ``(batchsizeVar * numPerBag, contextsize)``;
    ``emb`` is the embedding lookup, axes swapped by ``dimshuffle(1, 0, 2)`` and
    reshaped to 4-D with a singleton second axis (presumably a single input
    channel for a convolution layer).
    """
    flat_ids = x.reshape((batchsizeVar * numPerBag, contextsize))
    looked_up = embeddings[:, flat_ids].dimshuffle(1, 0, 2)
    rows, width, depth = looked_up.shape[0], looked_up.shape[1], looked_up.shape[2]
    return flat_ids, looked_up.reshape((rows, 1, width, depth))


x1_resh, x1_emb = _embed_context(x1)
x2_resh, x2_emb = _embed_context(x2)
x3_resh, x3_emb = _embed_context(x3)
x4_resh, x4_emb = _embed_context(x4)
    def __init__(self, num_classes: int, bert_weights: str):
        """Create a BERT token classifier whose logits feed a CRF.

        Args:
            num_classes: number of tags, used for both BERT's output layer and
                the CRF so the emission width matches the CRF's tag count.
            bert_weights: model name or path handed to ``from_pretrained``.
        """
        super(BertCRFClassifier, self).__init__()

        # num_labels must equal the CRF's num_tags: BERT's logits are its emissions.
        self.bert = BertForTokenClassification.from_pretrained(
            bert_weights, num_labels=num_classes, output_hidden_states=True)
        self.crf = CRF(num_tags=num_classes)
class BertCRFClassifier(pl.LightningModule):
    """BERT token classifier whose per-token logits are passed through a CRF.

    Written against the older PyTorch Lightning hook API
    (``validation_end``, ``@pl.data_loader``).
    """

    def __init__(self, num_classes: int, bert_weights: str):
        """Create the BERT token classifier and the CRF.

        Args:
            num_classes: number of tags, used for both BERT's output layer and
                the CRF so the emission width matches the CRF's tag count.
            bert_weights: model name or path handed to ``from_pretrained``.
        """
        super(BertCRFClassifier, self).__init__()

        # num_labels must equal the CRF's num_tags: BERT's logits are its emissions.
        self.bert = BertForTokenClassification.from_pretrained(
            bert_weights, num_labels=num_classes, output_hidden_states=True)
        self.crf = CRF(num_tags=num_classes)

    def forward(self,
                input_ids: torch.Tensor,
                attention_mask: torch.Tensor = None):
        """Return ``(crf_out, bert_out)``.

        ``bert_out`` is the first element of BERT's output tuple (presumably
        the per-token logits); ``crf_out`` is ``self.crf`` applied to it.
        """
        bert_out, _ = self.bert(input_ids, attention_mask=attention_mask)
        crf_out = self.crf(bert_out)
        return crf_out, bert_out

    def crf_loss(self, pred_logits: torch.Tensor,
                 labels: torch.Tensor) -> torch.Tensor:
        """Delegate to the CRF's ``loss`` for the given emissions and labels."""
        return self.crf.loss(pred_logits, labels)

    def training_step(self, batch, batch_idx):
        """Compute the CRF loss for one ``(input_ids, mask, labels)`` batch."""
        input_ids, mask, labels = batch
        preds, logits = self.forward(input_ids, attention_mask=mask)
        loss = self.crf_loss(logits, labels)

        tensorboard_logs = {'train_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        """Return the CRF loss plus macro recall/precision for one batch."""
        input_ids, mask, labels = batch

        preds, logits = self.forward(input_ids, attention_mask=mask)
        loss = self.crf_loss(logits, labels)

        # Flattened over the whole batch; no mask is applied, so any padding
        # positions are counted in the metrics too.
        labels = labels.detach().cpu().numpy().flatten()
        preds = preds.detach().cpu().numpy().flatten()

        recall = recall_score(labels, preds, average="macro")
        recall = torch.tensor(recall)

        precision = precision_score(labels, preds, average="macro")
        precision = torch.tensor(precision)
        return {'val_loss': loss, "recall": recall, "precision": precision}

    def validation_end(self, outputs):
        """Average the per-batch validation loss, recall and precision."""
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
        avg_recall = torch.stack([x["recall"] for x in outputs]).mean()
        avg_precision = torch.stack([x["precision"] for x in outputs]).mean()

        tensorboard_logs = {
            "val_loss": avg_loss,
            'avg_val_recall': avg_recall,
            'avg_val_precision': avg_precision
        }
        return {
            'avg_val_loss': avg_loss,
            'avg_val_recall': avg_recall,
            'avg_val_precision': avg_precision,
            'progress_bar': tensorboard_logs
        }

    def configure_optimizers(self):
        """Adam over every trainable parameter of this module (BERT and CRF).

        Parameters whose names contain a ``no_decay`` marker (biases and
        LayerNorm-style weights) get weight decay 0.0; all others get 0.01.
        Each parameter lands in exactly one group.
        """
        # 'gamma'/'beta' match old-style LayerNorm names, 'LayerNorm.weight' the newer one.
        no_decay = ['bias', 'gamma', 'beta', 'LayerNorm.weight']
        named = [(n, p) for n, p in self.named_parameters() if p.requires_grad]
        optimizer_grouped_parameters = [{
            "params": [
                p for n, p in named
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.01,
        }, {
            "params": [
                p for n, p in named
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.0,
        }]
        optimizer = Adam(optimizer_grouped_parameters, lr=2e-5)
        return optimizer

    @pl.data_loader
    def train_dataloader(self):
        """Return the module-level ``train_dataloader_``."""
        return train_dataloader_

    @pl.data_loader
    def val_dataloader(self):
        """Return the module-level ``val_dataloader_``."""
        return val_dataloader_
    def __init__(self, configfile, train=False):
        """Build the Theano CNN + CRF model for joint relation / entity typing.

        Three context matrices (left ``xa``, middle ``xb``, right ``xc``) pass
        through one shared convolution + k-max pooling layer. The relation
        (slot-filling) scores use all three contexts plus one additional
        feature; the first entity uses ``xa``+``xb`` and the second ``xb``+``xc``
        through a shared hidden layer. A CRF over the length-3 sequence
        [entity 1, relation, entity 2] combines the scores.

        Args:
            configfile: path read by ``readConfig``. Keys used here:
                ``wordvectors``, ``net``, ``hidden``, ``lrate``, ``batchsize``
                (training only), ``filtersize``, ``nkerns``, ``kmax``,
                ``contextsize`` and optionally ``hiddenunitsNER`` /
                ``representationsizeNER``.
            train: when False the batch size is 1 and parameters are loaded
                from the ``net`` file; ``self.gotNetwork`` records whether that
                file existed.
        """

        self.slotList = [
            "N", "per:age", "per:alternate_names", "per:children",
            "per:cause_of_death", "per:date_of_birth", "per:date_of_death",
            "per:employee_or_member_of", "per:location_of_birth",
            "per:location_of_death", "per:locations_of_residence",
            "per:origin", "per:schools_attended", "per:siblings", "per:spouse",
            "per:title", "org:alternate_names", "org:date_founded",
            "org:founded_by", "org:location_of_headquarters", "org:members",
            "org:parents", "org:top_members_employees"
        ]

        self.config = readConfig(configfile)

        # One extra scalar feature is appended to the relation-classifier input.
        self.addInputSize = 1
        logger.info("additional mlp input")

        wordvectorfile = self.config["wordvectors"]
        logger.info("wordvectorfile " + wordvectorfile)
        networkfile = self.config["net"]
        logger.info("networkfile " + networkfile)
        hiddenunits = int(self.config["hidden"])
        logger.info("hidden units " + str(hiddenunits))
        hiddenunitsNer = hiddenunits
        if "hiddenunitsNER" in self.config:
            hiddenunitsNer = int(self.config["hiddenunitsNER"])
        # Parsed (and validated as an int) but not used further in this method.
        representationsizeNER = 50
        if "representationsizeNER" in self.config:
            representationsizeNER = int(self.config["representationsizeNER"])
        learning_rate = float(self.config["lrate"])
        logger.info("learning rate " + str(learning_rate))
        if train:
            self.batch_size = int(self.config["batchsize"])
        else:
            self.batch_size = 1
        logger.info("batch size " + str(self.batch_size))
        self.filtersize = [1, int(self.config["filtersize"])]
        nkerns = [int(self.config["nkerns"])]
        logger.info("nkerns " + str(nkerns))
        pool = [1, int(self.config["kmax"])]

        self.contextsize = int(self.config["contextsize"])
        logger.info("contextsize " + str(self.contextsize))

        # The filter cannot be wider than the context.
        if self.contextsize < self.filtersize[1]:
            logger.info("setting filtersize to " + str(self.contextsize))
            self.filtersize[1] = self.contextsize
        logger.info("filtersize " + str(self.filtersize))

        # Width of a 'valid' convolution output; k for k-max pooling is capped by it.
        sizeAfterConv = self.contextsize - self.filtersize[1] + 1

        if sizeAfterConv < pool[1]:
            logger.info("setting poolsize to " + str(sizeAfterConv))
            pool[1] = sizeAfterConv
        sizeAfterPooling = pool[1]
        logger.info("kmax pooling: k = " + str(pool[1]))

        # reading word vectors
        self.wordvectors, self.vectorsize = readWordvectors(wordvectorfile)

        # Context matrices have one extra row on top of the word-vector size.
        self.representationsize = self.vectorsize + 1

        rng = numpy.random.RandomState(
            23455
        )  # not relevant, parameters will be overwritten by stored model anyways
        if train:
            seed = rng.get_state()[1][0]
            logger.info("seed: " + str(seed))

        numSFclasses = 23
        numNERclasses = 6

        # allocate symbolic variables for the data
        self.index = T.lscalar()  # index to a [mini]batch
        self.xa = T.matrix('xa')  # left context
        self.xb = T.matrix('xb')  # middle context
        self.xc = T.matrix('xc')  # right context
        self.y = T.imatrix('y')  # label (only present in training)
        self.yNER1 = T.imatrix(
            'yNER1')  # label for first entity (only present in training)
        self.yNER2 = T.imatrix(
            'yNER2')  # label for second entity (only present in training)
        ishape = [self.representationsize,
                  self.contextsize]  # this is the size of the context matrices

        ######################
        # BUILD ACTUAL MODEL #
        ######################
        logger.info('... building the model')

        # Reshape input matrix to be compatible with LeNetConvPoolLayer
        layer0a_input = self.xa.reshape(
            (self.batch_size, 1, ishape[0], ishape[1]))
        layer0b_input = self.xb.reshape(
            (self.batch_size, 1, ishape[0], ishape[1]))
        layer0c_input = self.xc.reshape(
            (self.batch_size, 1, ishape[0], ishape[1]))

        y_reshaped = self.y.reshape((self.batch_size, 1))
        yNER1reshaped = self.yNER1.reshape((self.batch_size, 1))
        yNER2reshaped = self.yNER2.reshape((self.batch_size, 1))

        # Construct convolutional pooling layer:
        filter_shape = (nkerns[0], 1, self.representationsize,
                        self.filtersize[1])
        poolsize = (pool[0], pool[1])
        fan_in = numpy.prod(filter_shape[1:])
        fan_out = (filter_shape[0] * numpy.prod(filter_shape[2:]) /
                   numpy.prod(poolsize))
        W_bound = numpy.sqrt(6. / (fan_in + fan_out))
        # the convolution weight matrix
        convW = theano.shared(numpy.asarray(rng.uniform(low=-W_bound,
                                                        high=W_bound,
                                                        size=filter_shape),
                                            dtype=theano.config.floatX),
                              borrow=True)
        # the bias is a 1D tensor -- one bias per output feature map
        b_values = numpy.zeros((filter_shape[0], ), dtype=theano.config.floatX)
        convB = theano.shared(value=b_values, borrow=True)

        # All three context layers share convW / convB, i.e. one convolution.
        self.layer0a = LeNetConvPoolLayer(rng,
                                          W=convW,
                                          b=convB,
                                          input=layer0a_input,
                                          image_shape=(self.batch_size, 1,
                                                       ishape[0], ishape[1]),
                                          filter_shape=filter_shape,
                                          poolsize=poolsize)
        self.layer0b = LeNetConvPoolLayer(rng,
                                          W=convW,
                                          b=convB,
                                          input=layer0b_input,
                                          image_shape=(self.batch_size, 1,
                                                       ishape[0], ishape[1]),
                                          filter_shape=filter_shape,
                                          poolsize=poolsize)
        self.layer0c = LeNetConvPoolLayer(rng,
                                          W=convW,
                                          b=convB,
                                          input=layer0c_input,
                                          image_shape=(self.batch_size, 1,
                                                       ishape[0], ishape[1]),
                                          filter_shape=filter_shape,
                                          poolsize=poolsize)

        layer0aflattened = self.layer0a.output.flatten(2).reshape(
            (self.batch_size, nkerns[0] * sizeAfterPooling))
        layer0bflattened = self.layer0b.output.flatten(2).reshape(
            (self.batch_size, nkerns[0] * sizeAfterPooling))
        layer0cflattened = self.layer0c.output.flatten(2).reshape(
            (self.batch_size, nkerns[0] * sizeAfterPooling))
        # Relation input: left + middle + right context features.
        layer0outputSF = T.concatenate(
            [layer0aflattened, layer0bflattened, layer0cflattened], axis=1)
        layer0outputSFsize = 3 * (nkerns[0] * sizeAfterPooling)

        # Entity inputs: (left, middle) for entity 1, (middle, right) for entity 2.
        layer0outputNER1 = T.concatenate([layer0aflattened, layer0bflattened],
                                         axis=1)
        layer0outputNER2 = T.concatenate([layer0bflattened, layer0cflattened],
                                         axis=1)
        layer0outputNERsize = 2 * (nkerns[0] * sizeAfterPooling)

        layer2ner1 = HiddenLayer(rng,
                                 input=layer0outputNER1,
                                 n_in=layer0outputNERsize,
                                 n_out=hiddenunitsNer,
                                 activation=T.tanh)
        # Shares W / b with layer2ner1: one entity encoder applied twice.
        layer2ner2 = HiddenLayer(rng,
                                 input=layer0outputNER2,
                                 n_in=layer0outputNERsize,
                                 n_out=hiddenunitsNer,
                                 activation=T.tanh,
                                 W=layer2ner1.W,
                                 b=layer2ner1.b)

        # concatenate additional features to sentence representation
        self.additionalFeatures = T.matrix('additionalFeatures')
        self.additionalFeatsShaped = self.additionalFeatures.reshape(
            (self.batch_size, 1))

        layer2SFinput = T.concatenate(
            [layer0outputSF, self.additionalFeatsShaped], axis=1)
        layer2SFinputSize = layer0outputSFsize + self.addInputSize

        layer2SF = HiddenLayer(rng,
                               input=layer2SFinput,
                               n_in=layer2SFinputSize,
                               n_out=hiddenunits,
                               activation=T.tanh)

        # classify the values of the fully-connected hidden (tanh) layers
        layer3rel = LogisticRegression(input=layer2SF.output,
                                       n_in=hiddenunits,
                                       n_out=numSFclasses)
        layer3et = LogisticRegression(input=layer2ner1.output,
                                      n_in=hiddenunitsNer,
                                      n_out=numNERclasses)

        scoresForR1 = layer3rel.getScores(layer2SF.output)
        scoresForE1 = layer3et.getScores(layer2ner1.output)
        scoresForE2 = layer3et.getScores(layer2ner2.output)

        self.crfLayer = CRF(numClasses=numSFclasses + numNERclasses,
                            rng=rng,
                            batchsizeVar=self.batch_size,
                            sequenceLength=3)

        # Joint score tensor (batch, 3, SF+NER): position 0 = entity 1 (NER
        # columns), 1 = relation (SF columns), 2 = entity 2 (NER columns);
        # all other entries stay 0.
        scores = T.zeros((self.batch_size, 3, numSFclasses + numNERclasses))
        scores = T.set_subtensor(scores[:, 0, numSFclasses:], scoresForE1)
        scores = T.set_subtensor(scores[:, 1, :numSFclasses], scoresForR1)
        scores = T.set_subtensor(scores[:, 2, numSFclasses:], scoresForE2)
        self.scores = scores

        # Entity labels are offset by numSFclasses to index the joint tag set.
        self.y_conc = T.concatenate([
            yNER1reshaped + numSFclasses, y_reshaped,
            yNER2reshaped + numSFclasses
        ],
                                    axis=1)

        # create a list of all model parameters (order = order in the saved file)
        self.paramList = [
            self.crfLayer.params, layer3rel.params, layer3et.params,
            layer2SF.params, layer2ner1.params, self.layer0a.params
        ]
        self.params = []
        for p in self.paramList:
            self.params += p
            logger.info(p)

        if not train:
            self.gotNetwork = 1
            # load parameters
            if not os.path.isfile(networkfile):
                logger.error("network file does not exist")
                self.gotNetwork = 0
            else:
                # One pickled value per parameter, in self.params order. The
                # context manager closes the file even if unpickling fails.
                # NOTE: unpickling can execute arbitrary code - load trusted files only.
                with open(networkfile, 'rb') as save_file:
                    for p in self.params:
                        p.set_value(cPickle.load(save_file), borrow=False)

        self.relation_scores_global = self.crfLayer.getProbForClass(
            self.scores, numSFclasses)
        self.predictions_global = self.crfLayer.getPrediction(self.scores)
class BertCRFClassifier(pl.LightningModule):
    """Multi-task BERT model: CRF span tagging plus binary sequence classification.

    A partially frozen ``BertModel`` feeds a shared ``SelfAttention`` module.
    A span head (``num_classes`` outputs, scored by a CRF) and a binary head
    (2 outputs, cross-entropy) are trained on the summed loss. Written
    against the older PyTorch Lightning hook API (``validation_end``,
    ``@pl.data_loader``).
    """

    def __init__(self,
                 num_classes: int,
                 bert_weights: str,
                 dropout: float = .10):
        """Create the encoder, heads, attention, dropout and CRF.

        Args:
            num_classes: number of span tags (span head width and CRF tag count).
            bert_weights: model name or path handed to ``BertModel.from_pretrained``.
            dropout: dropout probability applied to each head's input features.
        """
        super(BertCRFClassifier, self).__init__()

        self.bert = BertModel.from_pretrained(bert_weights)

        # Freeze all but the last five parameter tensors of BERT.
        for param in list(self.bert.parameters())[:-5]:
            param.requires_grad = False

        hidden_size = self.bert.config.hidden_size

        self.span_clf_head = nn.Linear(hidden_size, num_classes)
        self.binary_clf_head = nn.Linear(hidden_size, 2)

        self.attention = SelfAttention(hidden_size, batch_first=True)

        self.dropout = nn.Dropout(p=dropout)
        self.crf = CRF(num_tags=num_classes)

    def forward(self,
                input_ids: torch.Tensor,
                attention_mask: torch.Tensor,
                sent_lens: torch.Tensor):
        """Return ``(crf_out, span_logits, bin_logits)`` for one batch.

        NOTE(review): the second element of BertModel's tuple output is
        presumably the pooled vector - confirm ``SelfAttention`` handles it
        together with ``sent_lens``.
        """
        bert_last, bert_hidden = self.bert(input_ids, attention_mask=attention_mask)

        span_attention = self.attention(bert_last, sent_lens)
        bin_attention = self.attention(bert_hidden, sent_lens)

        # Dropout regularises each head's input features (as in
        # BertForTokenClassification) instead of perturbing the output scores
        # that the CRF and the cross-entropy consume.
        span_clf = self.span_clf_head(self.dropout(span_attention))
        bin_clf = self.binary_clf_head(self.dropout(bin_attention))
        crf_out = self.crf(span_clf)

        return crf_out, span_clf, bin_clf

    def crf_loss(self,
                 pred_logits: torch.Tensor,
                 labels: torch.Tensor) -> torch.Tensor:
        """Delegate to the CRF's ``loss`` for the given emissions and labels."""
        return self.crf.loss(pred_logits, labels)

    def _shared_step(self, batch):
        """Run one batch; return ``(combined_loss, preds, bin_logits, labels, bin_labels)``."""
        input_ids, mask, labels = batch
        device = input_ids.device

        # Sequence lengths are the number of attended tokens per row of the
        # attention mask (summing the labels would count tagged tokens instead).
        sent_lengths = torch.sum(mask, dim=1).long().to(device)

        # Binary target: 1 when a sequence's label sum is positive.
        bin_labels = (torch.sum(labels, dim=1) > 0).long().to(device)

        preds, span_logits, bin_logits = self.forward(input_ids,
                                                      attention_mask=mask,
                                                      sent_lens=sent_lengths)

        span_loss = self.crf_loss(span_logits, labels)
        bin_loss = F.cross_entropy(bin_logits, bin_labels)
        return span_loss + bin_loss, preds, bin_logits, labels, bin_labels

    def training_step(self, batch, batch_idx):
        """Return the summed CRF span loss and binary cross-entropy."""
        combined_loss = self._shared_step(batch)[0]

        tensorboard_logs = {'train_loss': combined_loss}
        return {'loss': combined_loss, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        """Return the combined loss plus macro recall/precision for both tasks."""
        combined_loss, preds, bin_logits, labels, bin_labels = self._shared_step(batch)

        # Span metrics are over the flattened batch; no mask is applied, so
        # any padding positions are counted too.
        labels = labels.detach().cpu().numpy().flatten()
        bin_labels = bin_labels.detach().cpu().numpy().flatten()
        span_preds = preds.detach().cpu().numpy().flatten()

        bin_preds = torch.argmax(bin_logits, dim=1).detach().cpu().numpy().flatten()

        span_recall = torch.tensor(recall_score(labels, span_preds,
                                                average="macro"))
        bin_recall = torch.tensor(recall_score(bin_labels, bin_preds,
                                               average="macro"))

        span_precision = torch.tensor(precision_score(labels, span_preds,
                                                      average="macro"))
        bin_precision = torch.tensor(precision_score(bin_labels, bin_preds,
                                                     average="macro"))

        return {'val_loss': combined_loss,
                "span_recall": span_recall,
                "bin_recall": bin_recall,
                "span_precision": span_precision,
                "bin_precision": bin_precision}

    def validation_end(self, outputs):
        """Average the per-batch validation loss and metrics."""
        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()

        avg_span_recall = torch.stack([x["span_recall"] for x in outputs]).mean()
        avg_bin_recall = torch.stack([x["bin_recall"] for x in outputs]).mean()

        avg_span_precision = torch.stack([x["span_precision"] for x in outputs]).mean()
        avg_bin_precision = torch.stack([x["bin_precision"] for x in outputs]).mean()

        tensorboard_logs = {"val_loss": avg_loss}
        return {'avg_val_loss': avg_loss,
                'avg_val_span_recall': avg_span_recall,
                'avg_val_bin_recall': avg_bin_recall,
                'avg_val_span_precision': avg_span_precision,
                'avg_val_bin_precision': avg_bin_precision,
                'progress_bar': tensorboard_logs}

    def configure_optimizers(self):
        """AdamW over every trainable parameter (unfrozen BERT, heads, attention, CRF).

        Parameters whose names contain ``bias`` or ``LayerNorm.weight`` get
        weight decay 0.0; all others get 0.01. Frozen BERT parameters are left
        out, and each remaining parameter lands in exactly one group.
        """
        no_decay = ["bias", "LayerNorm.weight"]
        named = [(n, p) for n, p in self.named_parameters() if p.requires_grad]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in named if not any(nd in n for nd in no_decay)],
                "weight_decay": 0.01,
            },
            {
                "params": [p for n, p in named if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=2e-5,
                          eps=1e-8)
        return optimizer

    @pl.data_loader
    def train_dataloader(self):
        """Return the module-level ``train_dataloader_``."""
        return train_dataloader_

    @pl.data_loader
    def val_dataloader(self):
        """Return the module-level ``val_dataloader_``."""
        return val_dataloader_