def eval_ner_step(self, x, model):
        """Evaluate one batch and accumulate the validation metrics.

        Args:
            x: Batch mapping, indexed with ``self.word_id_field``,
                ``self.segment_id_field``, ``self.mask_field`` and
                ``self.tag_id_field`` (the latter presumably integer tag
                ids — confirm against the input pipeline).
            model: Called as ``model(word_ids, segment_ids, mask)``; it must
                return exactly two values, of which the second holds per-tag
                scores along the last axis. The first value is ignored.

        Side effects:
            Updates ``self.val_loss``, ``self.val_accuracy`` (weighted by the
            batch mask) and ``self.val_accuracy_no_other`` (weighted by
            ``get_accuracy_no_other_mask`` of the gold tags). No optimizer is
            involved and nothing is returned.

        NOTE(review): no ``training`` flag is passed to ``model``; if the model
        contains dropout or similar layers, confirm they are inactive here.
        """
        gold_tags = x[self.tag_id_field]
        mask = x[self.mask_field]

        _, scores = model(x[self.word_id_field], x[self.segment_id_field], mask)

        batch_loss = masked_sparse_categorical_crossentropy(gold_tags, scores)
        self.val_loss(batch_loss)

        # Highest-scoring tag per position is the prediction.
        predicted_tags = tf.argmax(scores, axis=-1)
        self.val_accuracy(gold_tags, predicted_tags, sample_weight=mask)

        # NOTE(review): these weights come only from the gold tags, not from
        # ``mask``; presumably they exclude the "other" tag (and padding) —
        # confirm in get_accuracy_no_other_mask.
        self.val_accuracy_no_other(
            gold_tags,
            predicted_tags,
            sample_weight=get_accuracy_no_other_mask(gold_tags),
        )

    def train_step(self, x, model, ner_optimizer):
        """Run one optimisation step on a batch and accumulate training metrics.

        The forward pass and loss are recorded on a ``tf.GradientTape``. The
        gradients of the loss with respect to ``model.trainable_variables``
        are then applied through ``ner_optimizer``. The metrics are updated
        afterwards, using the predictions from the forward pass, so they
        reflect the weights from *before* this step's update.

        Args:
            x: Batch mapping, indexed with ``self.word_id_field``,
                ``self.segment_id_field``, ``self.mask_field`` and
                ``self.tag_id_field``.
            model: Called as ``model(word_ids, segment_ids, mask)``; it must
                return exactly two values, of which the second holds per-tag
                scores along the last axis. It must expose
                ``trainable_variables``.
            ner_optimizer: Object with an ``apply_gradients`` method accepting
                an iterable of ``(gradient, variable)`` pairs.

        Side effects:
            Mutates the model weights via ``ner_optimizer`` and updates
            ``self.train_loss``, ``self.train_accuracy`` (weighted by the batch
            mask) and ``self.train_accuracy_no_other``. Returns nothing.
        """
        gold_tags = x[self.tag_id_field]
        mask = x[self.mask_field]

        # Only the forward pass and the loss need to be taped.
        with tf.GradientTape() as tape:
            _, scores = model(x[self.word_id_field], x[self.segment_id_field], mask)
            batch_loss = masked_sparse_categorical_crossentropy(gold_tags, scores)

        weights = model.trainable_variables
        grads = tape.gradient(batch_loss, weights)
        ner_optimizer.apply_gradients(zip(grads, weights))

        self.train_loss(batch_loss)

        # Highest-scoring tag per position is the prediction.
        predicted_tags = tf.argmax(scores, axis=-1)
        self.train_accuracy(gold_tags, predicted_tags, sample_weight=mask)
        self.train_accuracy_no_other(
            gold_tags,
            predicted_tags,
            sample_weight=get_accuracy_no_other_mask(gold_tags),
        )