Example #1
class Classifier(pl.LightningModule):
    """
    Sample model to show how to use a Transformer model to classify sentences.
    
    :param hparams: ArgumentParser containing the hyperparameters.
    """

    # *************************
    class DataModule(pl.LightningDataModule):
        def __init__(self, classifier_instance):
            super().__init__()
            self.hparams = classifier_instance.hparams

            # if self.hparams.transformer_type == 'longformer':
            # self.hparams.batch_size = 1
            # raise Exception
            self.classifier = classifier_instance

            self.transformer_type = self.hparams.transformer_type

            self.hparams.n_labels = 50
            df = pd.read_csv(self.hparams.train_csv,
                             converters={'CODES': eval})
            a = pd.Series([item for sublist in df.CODES for item in sublist])
            self.hparams.top_codes = a.value_counts()[:self.hparams.
                                                      n_labels].index.tolist()
            # self.top_codes = ['I48', 'CIR007']
            # self.hparams.n_labels = len(self.top_codes)
            logger.warning(
                f'Classifying against the top {self.hparams.n_labels} most frequent ICD codes: {self.hparams.top_codes}'
            )

            self.mlb = MultiLabelBinarizer()
            self.mlb.fit([self.hparams.top_codes])

            # Label Encoder
            # if self.hparams.single_label_encoding == 'default':
            #     self.label_encoder = LabelEncoder(
            #         np.unique(self.top_codes).tolist(),
            #         reserved_labels=[]
            #     )
            #     self.label_encoder.unknown_index = None

            # New code added for mlb

        def top_labeler(self, codes):
            # logger.info("codes are: {}".format(codes))
            out = [label for label in codes if label in self.hparams.top_codes]
            # logger.info("out is: {}".format(out))
            # if out == []:
            #    out = ['']
            # return self.mlb.transform([out])
            return out

        def get_mimic_data(self, path: str) -> list:  #REWRITE
            """ Reads a comma separated value file and filters each note's code
            list down to the configured top codes.

            :param path: path to a csv file.

            :return: List of records as dictionaries
            """

            if self.hparams.fast_dev_run:
                df = pd.read_csv(path,
                                 converters={'CODES': eval},
                                 skiprows=range(100, 100000))
            elif self.hparams.mid_dev_run:
                df = pd.read_csv(path,
                                 converters={'CODES': eval},
                                 skiprows=range(1000, 100000))
            else:
                df = pd.read_csv(path, converters={'CODES': eval})
            df = df[["note_text", "CODES"]]
            df.rename(columns={
                'note_text': 'text',
                'CODES': 'labels'
            },
                      inplace=True)

            # df = df[df['labels'].isin(self.top_codes)]
            # logger.info(df['labels'])
            df['labels'] = df['labels'].map(self.top_labeler)
            # logger.info(df['labels'])
            # df["text"] = df["text"].astype(str)
            # df["label"] = df["label"].astype(tr)

            df.to_csv(f'{path}_top_codes_filtered.csv')

            logger.warning(f'{path} dataframe has {len(df)} examples.')
            return df.to_dict("records")

        def train_dataloader(self) -> DataLoader:
            """ Function that loads the train set. """
            logger.warning('Loading training data...')
            self._train_dataset = self.get_mimic_data(self.hparams.train_csv)
            return DataLoader(
                dataset=self._train_dataset,
                sampler=RandomSampler(self._train_dataset),
                batch_size=self.hparams.batch_size,
                collate_fn=self.classifier.prepare_sample,
                num_workers=self.hparams.loader_workers,
            )

        def val_dataloader(self) -> DataLoader:
            """ Function that loads the validation set. """
            logger.warning('Loading validation data...')
            self._dev_dataset = self.get_mimic_data(self.hparams.dev_csv)
            return DataLoader(
                dataset=self._dev_dataset,
                batch_size=self.hparams.batch_size,
                collate_fn=self.classifier.prepare_sample,
                num_workers=self.hparams.loader_workers,
            )

        def test_dataloader(self) -> DataLoader:
            """ Function that loads the test set. """
            logger.warning('Loading testing data...')
            self._test_dataset = self.get_mimic_data(self.hparams.test_csv)

            return DataLoader(
                dataset=self._test_dataset,
                batch_size=self.hparams.batch_size,
                collate_fn=self.classifier.prepare_sample,
                num_workers=self.hparams.loader_workers,
            )

#    ****************

    def __init__(self, hparams: Namespace) -> None:
        super(Classifier, self).__init__()

        self.hparams = hparams
        self.batch_size = hparams.batch_size

        # Build Data module
        self.data = self.DataModule(self)

        # build model
        self.__build_model()

        # Loss criterion initialization.
        self.__build_loss()

        if hparams.nr_frozen_epochs > 0:
            self.freeze_encoder()
        else:
            self._frozen = False
        self.nr_frozen_epochs = hparams.nr_frozen_epochs

        self.test_conf_matrices = []

        # Set up multi label binarizer:
        self.mlb = MultiLabelBinarizer()
        self.mlb.fit([self.hparams.top_codes])

        self.acc = torchmetrics.Accuracy()
        self.f1 = torchmetrics.F1(num_classes=self.hparams.n_labels,
                                  average='micro')
        self.auroc = torchmetrics.AUROC(num_classes=self.hparams.n_labels,
                                        average='weighted')
        # NOTE could try 'global' instead of samplewise for mdmc reduce
        self.prec = torchmetrics.Precision(num_classes=self.hparams.n_labels,
                                           is_multiclass=False)
        self.recall = torchmetrics.Recall(num_classes=self.hparams.n_labels,
                                          is_multiclass=False)
        self.confusion_matrix = torchmetrics.ConfusionMatrix(
            num_classes=self.hparams.n_labels)

        self.test_predictions = None
        self.test_labels = None

    def __build_model(self) -> None:
        """ Init transformer model + tokenizer + classification head."""

        if self.hparams.transformer_type == 'roberta-long':
            self.transformer = RobertaLongForMaskedLM.from_pretrained(
                self.hparams.encoder_model,
                output_hidden_states=True,
                gradient_checkpointing=True)

        elif self.hparams.transformer_type == 'longformer':
            self.transformer = AutoModel.from_pretrained(
                self.hparams.encoder_model,
                output_hidden_states=True,
                gradient_checkpointing=True,  #critical for training speed.
            )

        else:  #BERT
            self.transformer = AutoModel.from_pretrained(
                self.hparams.encoder_model,
                output_hidden_states=True,
            )

        logger.warning(f'model is {self.hparams.encoder_model}')

        if self.hparams.transformer_type in ('longformer', 'roberta-long'):
            logger.warning(
                'Gradient checkpointing is ON (critical for training speed).')

        # set the number of features our encoder model will return...
        self.encoder_features = 768

        # Tokenizer
        if self.hparams.transformer_type == 'longformer' or self.hparams.transformer_type == 'roberta-long':
            self.tokenizer = Tokenizer(
                pretrained_model=self.hparams.encoder_model,
                max_tokens=self.hparams.max_tokens_longformer)
            self.tokenizer.max_len = 4096

        else:
            self.tokenizer = Tokenizer(
                pretrained_model=self.hparams.encoder_model, max_tokens=512)

        #others:
        #'emilyalsentzer/Bio_ClinicalBERT' 'simonlevine/biomed_roberta_base-4096-speedfix'
        # Ben's new architecture
        if self.hparams.nn_arch == 'ben1':
            self.classification_head = nn.Sequential(
                nn.Linear(self.encoder_features, self.encoder_features * 3),
                nn.Dropout(0.1),
                nn.Linear(self.encoder_features * 3, self.encoder_features),
                nn.Linear(self.encoder_features, self.hparams.n_labels),
            )

        elif self.hparams.nn_arch == 'ben2':
            self.classification_head = nn.Sequential(
                nn.Dropout(0.1),
                nn.Linear(self.encoder_features, self.encoder_features * 2),
                nn.Tanh(),
                nn.Linear(self.encoder_features * 2,
                          self.encoder_features * 3),
                nn.ReLU(),
                nn.Linear(self.encoder_features * 3, self.encoder_features),
                nn.Sigmoid(),
                nn.Linear(self.encoder_features, self.hparams.n_labels),
            )

        elif self.hparams.nn_arch == 'CNN':
            logger.critical('CNN not yet implemented')

        elif self.hparams.nn_arch == 'default':
            self.classification_head = nn.Sequential(
                nn.Linear(self.encoder_features, self.encoder_features * 2),
                nn.Tanh(),
                nn.Linear(self.encoder_features * 2, self.encoder_features),
                nn.Tanh(),
                nn.Linear(self.encoder_features, self.hparams.n_labels),
            )

        # Classification head
        elif self.hparams.single_label_encoding == 'default':
            self.classification_head = nn.Sequential(
                nn.Linear(self.encoder_features, self.encoder_features * 2),
                nn.Tanh(),
                nn.Linear(self.encoder_features * 2, self.encoder_features),
                nn.Tanh(),
                nn.Linear(self.encoder_features, self.hparams.n_labels),
            )

        elif self.hparams.single_label_encoding == 'graphical':
            logger.critical('Graphical embedding not yet implemented!')
            # self.classification_head = nn.Sequential(
            #TODO
            # )

    def __build_loss(self):
        """ Initializes the loss function/s. """
        # For a single continuous label --> MSE loss (a regression problem).
        # For mutually exclusive discrete labels --> CrossEntropyLoss.
        # For many possible categorical labels (multi-label) --> binary cross-entropy,
        #   i.e. an independent logistic regression per label, applied to raw logits.
        self._loss = nn.BCEWithLogitsLoss()
        # self._loss = nn.CrossEntropyLoss()

        # self._loss = nn.MSELoss()

    def unfreeze_encoder(self) -> None:
        """ un-freezes the encoder layer. """
        if self._frozen:
            logger.info("\n-- Encoder model fine-tuning")
            for param in self.transformer.parameters():
                param.requires_grad = True
            self._frozen = False

    def freeze_encoder(self) -> None:
        """ freezes the encoder layer. """
        for param in self.transformer.parameters():
            param.requires_grad = False
        self._frozen = True

    def predict(self, sample: dict) -> dict:
        """ Predict function.
        :param sample: dictionary with the text we want to classify.

        Returns:
            Dictionary with the input text and the predicted label.
        """
        if self.training:
            self.eval()

        with torch.no_grad():
            model_input, _ = self.prepare_sample([sample],
                                                 prepare_target=False)
            model_out = self.forward(**model_input)
            logits = model_out["logits"].numpy()
            # Threshold raw logits at 0 (sigmoid > 0.5) and decode the multi-hot
            # rows back into tuples of ICD codes.
            predicted_labels = self.mlb.inverse_transform((logits > 0).astype(int))
            sample["predicted_label"] = predicted_labels

        return sample

    def forward(self, tokens, lengths):
        """ Usual pytorch forward function.
        :param tokens: text sequences [batch_size x src_seq_len]
        :param lengths: source lengths [batch_size]

        Returns:
            Dictionary with model outputs (e.g: logits)
        """
        tokens = tokens[:, :lengths.max()]
        # When using just one GPU this should not change behavior
        # but when splitting batches across GPU the tokens have padding
        # from the entire original batch
        mask = lengths_to_mask(lengths, device=tokens.device)

        # Run BERT model.
        word_embeddings = self.transformer(tokens, mask)[0]

        # Average Pooling
        word_embeddings = mask_fill(0.0, tokens, word_embeddings,
                                    self.tokenizer.padding_index)
        sentemb = torch.sum(word_embeddings, 1)
        sum_mask = mask.unsqueeze(-1).expand(
            word_embeddings.size()).float().sum(1)
        sentemb = sentemb / sum_mask

        return {"logits": self.classification_head(sentemb)}

    def loss(self, predictions: dict, targets: dict) -> torch.tensor:
        """
        Computes Loss value according to a loss function.
        :param predictions: model specific output. Must contain a key 'logits' with
            a tensor [batch_size x n_labels] with model predictions
        :param targets: dictionary with a key 'labels' holding the target values [batch_size x n_labels]

        Returns:
            torch.tensor with loss value.
        """
        return self._loss(predictions["logits"], targets["labels"].float())

    def prepare_sample(self,
                       sample: list,
                       prepare_target: bool = True) -> (dict, dict):
        """
        Function that prepares a sample to input the model.
        :param sample: list of dictionaries.
        
        Returns:
            - dictionary with the expected model inputs.
            - dictionary with the expected target labels.
        """
        # logger.info("Sample label:{}".format([sample[i]['labels'] for i in range(6)]))
        # logger.info("Sample label:{}".format(sample))
        # sample['text'] = collate_tensors(sample[i]['text']for i in range(6))
        # sample = collate_tensors(sample)

        texts = [s['text'] for s in sample]
        labels = [s['labels'] for s in sample]

        # logger.info("sample text len is:{}".format(len(sample['text'])))
        # logger.info("text 1is:{}".format(sample['text'][0]))
        # logger.info("text2 is:{}".format(sample['text'][1]))
        # logger.info("text3 is:{}".format(sample['text'][2]))
        tokens, lengths = self.tokenizer.batch_encode(texts)
        # logger.info("sample text len is:{}".format(len(sample['text'])))
        # logger.info("Sample label after collate:{}".format(sample['labels']))
        # logger.info("labels are lists: {}".format(type(sample['labels'][0]) == list ))
        # logger.info("lengths:{}".format(lengths))

        inputs = {"tokens": tokens, "lengths": lengths}

        if not prepare_target:
            return inputs, {}

        # Prepare target: binarize each example's code list into a multi-hot vector.
        # NOTE: torch.tensor copies the numpy output of mlb.transform; fine here,
        # but torch.from_numpy would avoid the copy.
        sample_labels = torch.tensor(self.mlb.transform(labels))
        targets = {'labels': sample_labels}
        return inputs, targets

    def training_step(self, batch: tuple, batch_nb: int, *args,
                      **kwargs) -> dict:
        """ 
        Runs one training step. This usually consists of the forward function followed
            by the loss function.
        
        :param batch: The output of your dataloader. 
        :param batch_nb: Integer displaying which batch this is

        Returns:
            - dictionary containing the loss and the metrics to be added to the lightning logger.
        """
        inputs, targets = batch
        model_out = self.forward(**inputs)
        loss_val = self.loss(model_out, targets)

        # in DP mode (default) make sure if result is scalar, there's another dim in the beginning
        if self.trainer.use_dp or self.trainer.use_ddp2:
            loss_val = loss_val.unsqueeze(0)

        self.log('loss', loss_val)

        # can also return just a scalar instead of a dict (return loss_val)
        return loss_val

    def test_step(self, batch: tuple, batch_nb: int, *args, **kwargs) -> dict:
        """ 
        Runs one test step. This usually consists of the forward function followed
            by the loss function.
        
        :param batch: The output of your dataloader. 
        :param batch_nb: Integer displaying which batch this is

        Returns:
            - dictionary containing the loss and the metrics to be added to the lightning logger.
        """
        inputs, targets = batch
        # logger.info(batch)
        model_out = self.forward(**inputs)
        loss_val = self.loss(model_out, targets)

        # in DP mode (default) make sure if result is scalar, there's another dim in the beginning
        # if self.trainer.use_dp or self.trainer.use_ddp2:
        #     loss_val = loss_val.unsqueeze(0)

        self.log('test_loss', loss_val)

        y_hat = model_out['logits']
        # labels_hat = torch.argmax(y_hat, dim=1)
        preds = (y_hat > 0).long()  # threshold raw logits at 0 (i.e. sigmoid > 0.5)
        y = targets['labels']
        return {'loss': loss_val, 'preds': preds, 'target': y}

    def test_step_end(self, outputs):

        preds = outputs['preds']
        target = outputs['target']

        # logger.info("Pred shape is {}".format(preds.size()))
        # logger.info("Target shape is {}".format(target.size()))
        # f1 = metrics.f1_score(labels_hat,y,    class_reduction='weighted')
        # prec =metrics.precision(labels_hat,y,  class_reduction='weighted')
        # recall = metrics.recall(labels_hat,y,  class_reduction='weighted')
        # acc = metrics.accuracy(labels_hat,y,   class_reduction='weighted')
        # # auroc = metrics.multiclass_auroc(labels_hat, y)

        acc = self.acc(preds, target)
        f1 = self.f1(preds, target)
        # f1 = fnmetrics.f1(preds, target, num_classes=self.hparams.n_labels)
        prec = self.prec(preds, target)
        recall = self.recall(preds, target)
        try:
            auroc = self.auroc(preds.float(), target)
        except ValueError as v:
            print(v)
            auroc = 0
        confusion_matrix = self.confusion_matrix(preds, target)

        self.log('test_batch_prec', prec)
        self.log('test_batch_f1', f1)
        self.log('test_batch_recall', recall)
        self.log('test_batch_weighted_acc', acc)
        self.log('test_batch_auroc', auroc)
        # TODO CHANGE THIS
        # return (labels_hat, y)
        # cm = confusion_matrix(preds = preds,target=target,normalize=None, num_classes=50)
        # cm = confusion_matrix(preds = labels_hat,target=y,normalize=False, num_classes=len(y.unique()))
        # self.test_conf_matrices.append(cm)
        # logger.info(labels_hat)
        # logger.info(y)
        # logger.info(classification_report(preds.detach().cpu(), target.detach().cpu()))
        # logger.info(confusion_matrix)
        #update and log
        if self.test_predictions is None:
            self.test_predictions = preds
            self.test_labels = target
        else:

            self.test_predictions = torch.cat((self.test_predictions, preds),
                                              0)
            self.test_labels = torch.cat((self.test_labels, target), 0)

    def validation_step(self, batch: tuple, batch_nb: int, *args,
                        **kwargs) -> dict:
        """ Similar to the training step but with the model in eval mode.

        Returns:
            - dictionary passed to the validation_end function.
        """
        inputs, targets = batch
        model_out = self.forward(**inputs)
        loss_val = self.loss(model_out, targets)

        y = targets["labels"]
        y_hat = model_out["logits"]

        # acc
        # logger.info(y_hat)
        # logger.info(y)
        # labels_hat = torch.argmax(y_hat, dim=0)
        # labels_hat
        labels_hat = (y_hat > 0).long()  # threshold raw logits at 0 (i.e. sigmoid > 0.5)

        # logger.info(labels_hat.device)
        # logger.info(torch.argmax(y_hat, dim=1))

        val_acc = torch.sum(y == labels_hat).item() / y.numel()  # element-wise accuracy over batch x labels
        val_acc = torch.tensor(val_acc)

        if self.on_gpu:
            val_acc = val_acc.cuda(loss_val.device.index)

        # in DP mode (default) make sure if result is scalar, there's another dim in the beginning
        if self.trainer.use_dp or self.trainer.use_ddp2:
            loss_val = loss_val.unsqueeze(0)
            val_acc = val_acc.unsqueeze(0)

        self.log('val_loss', loss_val)

        # f1 = metrics.f1_score(labels_hat, y,class_reduction='weighted')
        # prec =metrics.precision(labels_hat, y,class_reduction='weighted')
        # recall = metrics.recall(labels_hat, y,class_reduction='weighted')
        # acc = metrics.accuracy(labels_hat, y,class_reduction='weighted')
        # auroc = skm.AUROC
        # # auroc = metrics.multiclass_auroc(y_hat,y)

        # self.log('val_prec',prec)
        # self.log('val_f1',f1)
        # self.log('val_recall',recall)
        # self.log('val_acc_weighted', acc)
        # logger.info(classification_report(labels_hat.detach().cpu(), y.detach().cpu()))
        # self.log('val_cm',cm)

    def configure_optimizers(self):
        """ Sets different Learning rates for different parameter groups. """
        parameters = [
            {
                "params": self.classification_head.parameters()
            },
            {
                "params": self.transformer.parameters(),
                "lr": self.hparams.encoder_learning_rate,
            },
        ]
        optimizer = optim.Adam(parameters, lr=self.hparams.learning_rate)
        return [optimizer], []

    def on_epoch_end(self):
        """ Pytorch lightning hook """
        if self.current_epoch + 1 >= self.nr_frozen_epochs:
            self.unfreeze_encoder()

    @classmethod
    def add_model_specific_args(cls, parser: ArgumentParser) -> ArgumentParser:
        """ Parser for Estimator specific arguments/hyperparameters. 
        :param parser: argparse.ArgumentParser

        Returns:
            - updated parser
        """
        parser.add_argument(
            "--encoder_model",
            default=
            'emilyalsentzer/Bio_ClinicalBERT',  # 'allenai/biomed_roberta_base',#'simonlevine/biomed_roberta_base-4096-speedfix', # 'bert-base-uncased',
            type=str,
            help="Encoder model to be used.",
        )

        parser.add_argument(
            "--transformer_type",
            default='bert',  # 'longformer', 'roberta-long'
            type=str,
            help=
            "Encoder model/tokenizer to be used (has consequences for tokenization and encoding; default = bert).",
        )

        parser.add_argument(
            "--single_label_encoding",
            default='none',
            type=str,
            help=
            "How should labels be encoded? Default for torch-nlp label-encoder...",
        )

        parser.add_argument(
            "--max_tokens_longformer",
            default=4096,
            type=int,
            help="Max tokens to be considered per instance.",
        )
        parser.add_argument(
            "--max_tokens",
            default=512,
            type=int,
            help="Max tokens to be considered per instance.",
        )
        parser.add_argument(
            "--encoder_learning_rate",
            default=1e-05,
            type=float,
            help="Encoder specific learning rate.",
        )
        parser.add_argument(
            "--learning_rate",
            default=3e-05,
            type=float,
            help="Classification head learning rate.",
        )
        parser.add_argument(
            "--nr_frozen_epochs",
            default=0,
            type=int,
            help="Number of epochs we want to keep the encoder model frozen.",
        )
        parser.add_argument(
            "--train_csv",
            default="data/intermediary-data/notes2diagnosis-icd-train.csv",
            type=str,
            help="Path to the file containing the train data.",
        )
        parser.add_argument(
            "--dev_csv",
            default="data/intermediary-data/notes2diagnosis-icd-validate.csv",
            type=str,
            help="Path to the file containing the dev data.",
        )
        parser.add_argument(
            "--test_csv",
            default="data/intermediary-data/notes2diagnosis-icd-test.csv",
            type=str,
            help="Path to the file containing the test data.",
        )
        parser.add_argument(
            "--loader_workers",
            default=8,
            type=int,
            help="How many subprocesses to use for data loading. 0 means that \
                the data will be loaded in the main process.",
        )
        return parser
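
As a quick, self-contained illustration of the label handling in the classifier above (a sketch, not part of the pipeline), the snippet below shows how MultiLabelBinarizer turns variable-length ICD code lists into the multi-hot targets consumed by BCEWithLogitsLoss, and how thresholding raw logits at 0 (sigmoid > 0.5, the same rule as in test_step) recovers code tuples via inverse_transform. The code lists and logits are made up for illustration.

import torch
import torch.nn as nn
from sklearn.preprocessing import MultiLabelBinarizer

top_codes = ['I48', 'E11', 'N17']  # hypothetical "top" ICD codes
mlb = MultiLabelBinarizer()
mlb.fit([top_codes])

batch_labels = [['I48', 'N17'], ['E11'], []]  # one code list per note
targets = torch.tensor(mlb.transform(batch_labels), dtype=torch.float)
# columns follow mlb.classes_ (sorted): E11, I48, N17
# tensor([[0., 1., 1.],
#         [1., 0., 0.],
#         [0., 0., 0.]])

logits = torch.tensor([[-1.2, 3.0, 0.7],   # pretend classification-head outputs
                       [2.1, -0.5, -2.0],
                       [-0.3, -0.1, -4.0]])

loss = nn.BCEWithLogitsLoss()(logits, targets)  # loss is computed on raw logits
preds = (logits > 0).long()                     # same thresholding as test_step
print(loss.item())
print(mlb.inverse_transform(preds.numpy()))     # [('I48', 'N17'), ('E11',), ()]
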
Example #2
class MedNLIClassifier(pl.LightningModule):
    """
    Sample model to show how to use a Transformer model to classify sentences.
    
    :param hparams: ArgumentParser containing the hyperparameters.
    """

    # *************************
    class DataModule(pl.LightningDataModule):
        def __init__(self, classifier_instance):
            super().__init__()
            self.hparams = classifier_instance.hparams

            if self.hparams.transformer_type == 'longformer':
                self.hparams.batch_size = 1
            self.classifier = classifier_instance

            self.transformer_type = self.hparams.transformer_type

            # Label Encoder
            self.label_encoder = LabelEncoder(
                ['contradiction', 'entailment', 'neutral'], reserved_labels=[])

            self.label_encoder.unknown_index = None

        def get_mednli_data(self, raw_data: list) -> list:  #REWRITE
            """ Converts raw MedNLI (premise, hypothesis, label) triples into
            model-ready records.

            :param raw_data: list of (premise, hypothesis, label) triples.

            :return: List of records as dictionaries
            """

            df = pd.DataFrame(raw_data)
            df = df.rename(columns={0: 'premise', 1: 'hypothesis', 2: 'label'})

            df["text"] = df["premise"].astype(str) + df["hypothesis"].astype(
                str)
            df["label"] = df["label"].astype(str)
            df = df.drop(['premise', 'hypothesis'], axis=1)
            return df.to_dict("records")

        def setup(self, stage=None):
            self.mednli_train, self.mednli_dev, self.mednli_test = load_mednli(
            )

        def train_dataloader(self) -> DataLoader:
            """ Function that loads the train set. """
            logger.warning('Loading training data...')
            self._train_dataset = self.get_mednli_data(self.mednli_train)
            return DataLoader(
                dataset=self._train_dataset,
                sampler=RandomSampler(self._train_dataset),
                batch_size=self.hparams.batch_size,
                collate_fn=self.classifier.prepare_sample,
                num_workers=self.hparams.loader_workers,
            )

        def val_dataloader(self) -> DataLoader:
            """ Function that loads the validation set. """
            logger.warning('Loading validation data...')
            self._dev_dataset = self.get_mednli_data(self.mednli_dev)
            return DataLoader(
                dataset=self._dev_dataset,
                batch_size=self.hparams.batch_size,
                collate_fn=self.classifier.prepare_sample,
                num_workers=self.hparams.loader_workers,
            )

        def test_dataloader(self) -> DataLoader:
            """ Function that loads the test set. """
            logger.warning('Loading testing data...')
            self._test_dataset = self.get_mednli_data(self.mednli_test)

            return DataLoader(
                dataset=self._test_dataset,
                batch_size=self.hparams.batch_size,
                collate_fn=self.classifier.prepare_sample,
                num_workers=self.hparams.loader_workers,
            )

#    ****************

    def __init__(self, hparams: Namespace) -> None:
        super(MedNLIClassifier, self).__init__()

        self.hparams = hparams
        self.batch_size = hparams.batch_size

        # Build Data module
        self.data = self.DataModule(self)

        # build model
        self.__build_model()

        # Loss criterion initialization.
        self.__build_loss()

        if hparams.nr_frozen_epochs > 0:
            self.freeze_encoder()
        else:
            self._frozen = False
        self.nr_frozen_epochs = hparams.nr_frozen_epochs

        self.confusion_matrix = ConfusionMatrix(num_classes=3)

    def __build_model(self) -> None:
        """ Init transformer model + tokenizer + classification head."""

        if self.hparams.transformer_type == 'roberta-long':
            self.transformer = RobertaLongForMaskedLM.from_pretrained(
                self.hparams.encoder_model,
                output_hidden_states=True,
                gradient_checkpointing=True)

        elif self.hparams.transformer_type == 'longformer':
            self.transformer = AutoModel.from_pretrained(
                self.hparams.encoder_model,
                output_hidden_states=True,
                gradient_checkpointing=True,  #critical for training speed.
            )

        else:  #BERT
            self.transformer = AutoModel.from_pretrained(
                self.hparams.encoder_model,
                output_hidden_states=True,
            )

        logger.warning(f'model is {self.hparams.encoder_model}')

        if self.hparams.transformer_type in ('longformer', 'roberta-long'):
            logger.warning(
                'Gradient checkpointing is ON (critical for training speed).')

        # set the number of features our encoder model will return...
        self.encoder_features = 768

        # Tokenizer
        if self.hparams.transformer_type == 'longformer' or self.hparams.transformer_type == 'roberta-long':
            self.tokenizer = Tokenizer(
                pretrained_model=self.hparams.encoder_model,
                max_tokens=self.hparams.max_tokens_longformer)
            self.tokenizer.max_len = 4096

        else:
            self.tokenizer = Tokenizer(
                pretrained_model=self.hparams.encoder_model, max_tokens=512)

        #others:
        #'emilyalsentzer/Bio_ClinicalBERT' 'simonlevine/biomed_roberta_base-4096-speedfix'

        # Classification head
        if self.hparams.single_label_encoding == 'default':
            self.classification_head = nn.Sequential(
                nn.Linear(self.encoder_features, self.encoder_features * 2),
                nn.Tanh(),
                nn.Linear(self.encoder_features * 2, self.encoder_features),
                nn.Tanh(),
                nn.Linear(self.encoder_features,
                          self.data.label_encoder.vocab_size),
            )

        elif self.hparams.single_label_encoding == 'graphical':
            logger.critical('Graphical embedding not yet implemented!')
            # self.classification_head = nn.Sequential(
            #TODO
            # )

    def __build_loss(self):
        """ Initializes the loss function/s. """
        # For a single continuous label --> MSE loss (a regression problem).
        # For mutually exclusive discrete labels --> CrossEntropyLoss.
        # For many possible categorical labels (multi-label) --> binary cross-entropy,
        #   i.e. an independent logistic regression per label.
        self._loss = nn.CrossEntropyLoss()

        # self._loss = nn.MSELoss()

    def unfreeze_encoder(self) -> None:
        """ un-freezes the encoder layer. """
        if self._frozen:
            logger.info("\n-- Encoder model fine-tuning")
            for param in self.transformer.parameters():
                param.requires_grad = True
            self._frozen = False

    def freeze_encoder(self) -> None:
        """ freezes the encoder layer. """
        for param in self.transformer.parameters():
            param.requires_grad = False
        self._frozen = True

    def predict(self, sample: dict) -> dict:
        """ Predict function.
        :param sample: dictionary with the text we want to classify.

        Returns:
            Dictionary with the input text and the predicted label.
        """
        if self.training:
            self.eval()

        with torch.no_grad():
            model_input, _ = self.prepare_sample([sample],
                                                 prepare_target=False)
            model_out = self.forward(**model_input)
            logits = model_out["logits"].numpy()
            predicted_labels = [
                self.data.label_encoder.index_to_token[prediction]
                for prediction in np.argmax(logits, axis=1)
            ]
            sample["predicted_label"] = predicted_labels[0]

        return sample

    def forward(self, tokens, lengths):
        """ Usual pytorch forward function.
        :param tokens: text sequences [batch_size x src_seq_len]
        :param lengths: source lengths [batch_size]

        Returns:
            Dictionary with model outputs (e.g: logits)
        """
        tokens = tokens[:, :lengths.max()]
        # When using just one GPU this should not change behavior
        # but when splitting batches across GPU the tokens have padding
        # from the entire original batch
        mask = lengths_to_mask(lengths, device=tokens.device)

        # Run BERT model.
        word_embeddings = self.transformer(tokens, mask)[0]

        # Average Pooling
        word_embeddings = mask_fill(0.0, tokens, word_embeddings,
                                    self.tokenizer.padding_index)
        sentemb = torch.sum(word_embeddings, 1)
        sum_mask = mask.unsqueeze(-1).expand(
            word_embeddings.size()).float().sum(1)
        sentemb = sentemb / sum_mask

        return {"logits": self.classification_head(sentemb)}

    def loss(self, predictions: dict, targets: dict) -> torch.tensor:
        """
        Computes Loss value according to a loss function.
        :param predictions: model specific output. Must contain a key 'logits' with
            a tensor [batch_size x num_classes] with model predictions
        :param targets: dictionary with a key 'labels' holding the label values [batch_size]

        Returns:
            torch.tensor with loss value.
        """
        return self._loss(predictions["logits"], targets["labels"])

    def prepare_sample(self,
                       sample: list,
                       prepare_target: bool = True) -> (dict, dict):
        """
        Function that prepares a sample to input the model.
        :param sample: list of dictionaries.
        
        Returns:
            - dictionary with the expected model inputs.
            - dictionary with the expected target labels.
        """
        sample = collate_tensors(sample)

        tokens, lengths = self.tokenizer.batch_encode(sample["text"])

        inputs = {"tokens": tokens, "lengths": lengths}

        if not prepare_target:
            return inputs, {}

        # Prepare target:
        try:
            targets = {
                "labels": self.data.label_encoder.batch_encode(sample["label"])
            }
            return inputs, targets
        except RuntimeError:
            raise Exception("Label encoder found an unknown label.")

    def training_step(self, batch: tuple, batch_nb: int, *args,
                      **kwargs) -> dict:
        """ 
        Runs one training step. This usually consists of the forward function followed
            by the loss function.
        
        :param batch: The output of your dataloader. 
        :param batch_nb: Integer displaying which batch this is

        Returns:
            - dictionary containing the loss and the metrics to be added to the lightning logger.
        """
        inputs, targets = batch
        model_out = self.forward(**inputs)
        loss_val = self.loss(model_out, targets)

        # in DP mode (default) make sure if result is scalar, there's another dim in the beginning
        if self.trainer.use_dp or self.trainer.use_ddp2:
            loss_val = loss_val.unsqueeze(0)

        self.log('loss', loss_val)

        # can also return just a scalar instead of a dict (return loss_val)
        return loss_val

    def test_step(self, batch: tuple, batch_nb: int, *args, **kwargs) -> dict:
        """ 
        Runs one test step. This usually consists of the forward function followed
            by the loss function.
        
        :param batch: The output of your dataloader. 
        :param batch_nb: Integer displaying which batch this is

        Returns:
            - dictionary containing the loss and the metrics to be added to the lightning logger.
        """
        inputs, targets = batch
        model_out = self.forward(**inputs)
        loss_val = self.loss(model_out, targets)

        # in DP mode (default) make sure if result is scalar, there's another dim in the beginning
        if self.trainer.use_dp or self.trainer.use_ddp2:
            loss_val = loss_val.unsqueeze(0)

        self.log('test_loss', loss_val)

        y_hat = model_out['logits']
        labels_hat = torch.argmax(y_hat, dim=1)
        y = targets['labels']

        f1 = metrics.f1(labels_hat, y, average='weighted', num_classes=3)
        prec = metrics.precision(labels_hat,
                                 y,
                                 class_reduction='weighted',
                                 num_classes=3)
        recall = metrics.recall(labels_hat,
                                y,
                                class_reduction='weighted',
                                num_classes=3)
        acc = metrics.accuracy(labels_hat,
                               y,
                               class_reduction='weighted',
                               num_classes=3)

        self.confusion_matrix.update(labels_hat, y)
        self.log('test_batch_prec', prec)
        self.log('test_batch_f1', f1)
        self.log('test_batch_recall', recall)
        self.log('test_batch_weighted_acc', acc)

    def validation_step(self, batch: tuple, batch_nb: int, *args,
                        **kwargs) -> dict:
        """ Similar to the training step but with the model in eval mode.

        Returns:
            - dictionary passed to the validation_end function.
        """
        inputs, targets = batch
        model_out = self.forward(**inputs)
        loss_val = self.loss(model_out, targets)

        y = targets["labels"]
        y_hat = model_out["logits"]

        # acc
        labels_hat = torch.argmax(y_hat, dim=1)
        val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
        val_acc = torch.tensor(val_acc)

        if self.on_gpu:
            val_acc = val_acc.cuda(loss_val.device.index)

        # in DP mode (default) make sure if result is scalar, there's another dim in the beginning
        if self.trainer.use_dp or self.trainer.use_ddp2:
            loss_val = loss_val.unsqueeze(0)
            val_acc = val_acc.unsqueeze(0)

        self.log('val_loss', loss_val)

        f1 = metrics.f1(labels_hat, y, average='weighted', num_classes=3)
        prec = metrics.precision(labels_hat,
                                 y,
                                 class_reduction='weighted',
                                 num_classes=3)
        recall = metrics.recall(labels_hat,
                                y,
                                class_reduction='weighted',
                                num_classes=3)
        acc = metrics.accuracy(labels_hat,
                               y,
                               class_reduction='weighted',
                               num_classes=3)

        self.log('val_prec', prec)
        self.log('val_f1', f1)
        self.log('val_recall', recall)
        self.log('val_acc_weighted', acc)

    def configure_optimizers(self):
        """ Sets different Learning rates for different parameter groups. """
        parameters = [
            {
                "params": self.classification_head.parameters()
            },
            {
                "params": self.transformer.parameters(),
                "lr": self.hparams.encoder_learning_rate,
            },
        ]
        optimizer = optim.Adam(parameters, lr=self.hparams.learning_rate)
        return [optimizer], []

    def on_epoch_end(self):
        """ Pytorch lightning hook """
        if self.current_epoch + 1 >= self.nr_frozen_epochs:
            self.unfreeze_encoder()

    @classmethod
    def add_model_specific_args(cls, parser: ArgumentParser) -> ArgumentParser:
        """ Parser for Estimator specific arguments/hyperparameters. 
        :param parser: argparse.ArgumentParser

        Returns:
            - updated parser
        """

        parser.add_argument(
            "--encoder_model",
            default=
            'simonlevine/bioclinical-roberta-long',  #'emilyalsentzer/Bio_ClinicalBERT',# 'allenai/biomed_roberta_base',', # 'bert-base-uncased',
            type=str,
            help="Encoder model to be used.",
        )

        parser.add_argument(
            "--transformer_type",
            default='roberta-long',  #'bert', 'roberta', or 'longformer'
            type=str,
            help=
            "Encoder model/tokenizer to be used (has consequences for tokenization and encoding; default = roberta-long).",
        )

        parser.add_argument(
            "--single_label_encoding",
            default='default',
            type=str,
            help=
            "How should labels be encoded? Default for torch-nlp label-encoder...",
        )

        parser.add_argument(
            "--max_tokens_longformer",
            default=4096,
            type=int,
            help="Max tokens to be considered per instance.",
        )
        parser.add_argument(
            "--max_tokens",
            default=512,
            type=int,
            help="Max tokens to be considered per instance.",
        )
        parser.add_argument(
            "--encoder_learning_rate",
            default=1e-05,
            type=float,
            help="Encoder specific learning rate.",
        )
        parser.add_argument(
            "--learning_rate",
            default=3e-05,
            type=float,
            help="Classification head learning rate.",
        )
        parser.add_argument(
            "--nr_frozen_epochs",
            default=0,
            type=int,
            help="Number of epochs we want to keep the encoder model frozen.",
        )

        parser.add_argument(
            "--loader_workers",
            default=8,
            type=int,
            help="How many subprocesses to use for data loading. 0 means that \
                the data will be loaded in the main process.",
        )
        return parser
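
Both classifiers above pool token embeddings into a sentence embedding with a masked average in forward(). The toy sketch below (random numbers, and a local stand-in for the lengths_to_mask helper assumed by the model code) shows what that pooling computes: padded positions are zeroed and the sum is divided by the true token count, not the padded length.

import torch

def lengths_to_mask(lengths, device=None):
    # Stand-in for the helper used in forward(): True for real tokens, False for padding.
    max_len = int(lengths.max())
    return torch.arange(max_len, device=device)[None, :] < lengths[:, None]

batch_size, max_len, hidden = 2, 4, 3
lengths = torch.tensor([4, 2])                      # second sequence has 2 pad tokens
word_embeddings = torch.randn(batch_size, max_len, hidden)

mask = lengths_to_mask(lengths)                     # [batch_size x max_len], bool
word_embeddings = word_embeddings * mask.unsqueeze(-1)  # zero out padded positions
sentemb = word_embeddings.sum(1)                    # [batch_size x hidden]
sum_mask = mask.unsqueeze(-1).expand(word_embeddings.size()).float().sum(1)
sentemb = sentemb / sum_mask                        # mean over real tokens only
print(sentemb.shape)                                # torch.Size([2, 3])
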
Example #3
class Classifier(pl.LightningModule):
    """
    Sample model to show how to use a Transformer model to classify sentences.
    
    :param hparams: ArgumentParser containing the hyperparameters.
    """

    # *************************
    class DataModule(pl.LightningDataModule):
        def __init__(self, classifier_instance):
            super().__init__()
            self.hparams = classifier_instance.hparams
            if self.hparams.transformer_type == 'longformer':
                self.hparams.batch_size = 1
            self.classifier = classifier_instance

            self.transformer_type = self.hparams.transformer_type
            # Assuming the raw notes live in the train_csv path; the frame is split 80/20 below.
            self.raw_data = self.get_mimic_data(self.hparams.train_csv)

            msk = np.random.rand(len(self.raw_data)) < 0.8
            self.train = self.raw_data[msk]
            self.test = self.raw_data[~msk]

            # self.label_encoder.unknown_index = None

        def get_mimic_data(self, path: str) -> pd.DataFrame:
            """ Reads a comma separated value file.

            :param path: path to a csv file.

            :return: DataFrame of notes with the ROW_ID column dropped.
            """
            df = pd.read_csv(path)
            df = df.drop('ROW_ID', axis=1)
            return df
            # df["text"] = df["text"].astype(str)
            # df["label"] = df["label"].astype(str)
            # return df.to_dict("records")

        def train_dataloader(self) -> DataLoader:
            """ Function that loads the train set. """
            self._train_dataset = MimicAnnotDataset(self.train)
            return DataLoader(
                dataset=self._train_dataset,
                sampler=RandomSampler(self._train_dataset),
                batch_size=self.hparams.batch_size,
                collate_fn=self.classifier.prepare_sample,
                num_workers=self.hparams.loader_workers,
            )

        # def val_dataloader(self) -> DataLoader:
        #     """ Function that loads the validation set. """
        #     self._dev_dataset = self.get_mimic_data(self.hparams.dev_csv)
        #     return DataLoader(
        #         dataset=self._dev_dataset,
        #         batch_size=self.hparams.batch_size,
        #         collate_fn=self.classifier.prepare_sample,
        #         num_workers=self.hparams.loader_workers,
        #     )

        def test_dataloader(self) -> DataLoader:
            """ Function that loads the test set. """
            self._test_dataset = MimicAnnotDataset(self.test)

            return DataLoader(
                dataset=self._test_dataset,
                batch_size=self.hparams.batch_size,
                collate_fn=self.classifier.prepare_sample,
                num_workers=self.hparams.loader_workers,
            )

#    ****************

    def __init__(self, hparams: Namespace) -> None:
        super(Classifier,self).__init__()

        self.hparams = hparams
        self.batch_size = hparams.batch_size

        # Build Data module
        self.data = self.DataModule(self)
        
        # build model
        self.__build_model()

        # Loss criterion initialization.
        self.__build_loss()

        if hparams.nr_frozen_epochs > 0:
            self.freeze_encoder()
        else:
            self._frozen = False
        self.nr_frozen_epochs = hparams.nr_frozen_epochs


    def __build_model(self) -> None:
        """ Init transformer model + tokenizer + classification head."""

        #simonlevine/biomed_roberta_base-4096-speedfix'

        if self.hparams.transformer_type == 'longformer':
            logger.warning('Turning ON gradient checkpointing...')

            self.transformer = AutoModelForSequenceClassification.from_pretrained(
                self.hparams.encoder_model,
                output_hidden_states=True,
                gradient_checkpointing=True,  #critical for training speed.
            )

        else:
            self.transformer = AutoModelForSequenceClassification.from_pretrained(
                self.hparams.encoder_model,
                output_hidden_states=True,
            )
            
           #others to try:
            # bert-base-uncased
            #'emilyalsentzer/Bio_ClinicalBERT'
            # allenai/biomed_roberta_base
            # simonlevine/biomed_roberta_base-4096-speedfix'
        
        # set the number of features our encoder model will return...
        self.encoder_features = 768

        # Tokenizer
        if self.hparams.transformer_type == 'longformer':
            self.tokenizer = Tokenizer(
                pretrained_model=self.hparams.encoder_model,
                max_tokens=self.hparams.max_tokens_longformer)

        else:
            self.tokenizer = Tokenizer(
                pretrained_model=self.hparams.encoder_model,
                max_tokens=self.hparams.max_tokens)

           #others:
             #'emilyalsentzer/Bio_ClinicalBERT' 'simonlevine/biomed_roberta_base-4096-speedfix'

 
    def __build_loss(self):
        """ Initializes the loss function/s. """
        # For a single continuous label --> MSE loss (a regression problem).
        # For mutually exclusive discrete labels --> CrossEntropyLoss.
        # For many possible categorical labels (multi-label) --> binary cross-entropy,
        #   i.e. an independent logistic regression per label.
        self._loss = nn.CrossEntropyLoss()

        # self._loss = nn.MSELoss()

    def unfreeze_encoder(self) -> None:
        """ un-freezes the encoder layer. """
        if self._frozen:
            logger.info("\n-- Encoder model fine-tuning")
            for param in self.transformer.parameters():
                param.requires_grad = True
            self._frozen = False

    def freeze_encoder(self) -> None:
        """ freezes the encoder layer. """
        for param in self.transformer.parameters():
            param.requires_grad = False
        self._frozen = True

    def predict(self, sample: dict) -> dict:
        """ Predict function.
        :param sample: dictionary with the text we want to classify.

        Returns:
            Dictionary with the input text and the predicted label.
        """
        if self.training:
            self.eval()

        with torch.no_grad():
            model_input, _ = self.prepare_sample([sample], prepare_target=False)
            model_out = self.forward(**model_input)
            logits = model_out["logits"].numpy()
            predicted_labels = [
                self.data.label_encoder.index_to_token[prediction]
                for prediction in np.argmax(logits, axis=1)
            ]
            sample["predicted_label"] = predicted_labels[0]

        return sample

    def forward(self, tokens, lengths):
        """ Usual pytorch forward function.
        :param tokens: text sequences [batch_size x src_seq_len]
        :param lengths: source lengths [batch_size]

        Returns:
            Dictionary with model outputs (e.g: logits)
        """
        tokens = tokens[:, : lengths.max()]
        # When using just one GPU this should not change behavior
        # but when splitting batches across GPU the tokens have padding
        # from the entire original batch
        mask = lengths_to_mask(lengths, device=tokens.device)

        # Run transformer model.
        word_embeddings = self.transformer(tokens, mask)[0]

        # Average Pooling
        word_embeddings = mask_fill(
            0.0, tokens, word_embeddings, self.tokenizer.padding_index
        )
        sentemb = torch.sum(word_embeddings, 1)
        sum_mask = mask.unsqueeze(-1).expand(word_embeddings.size()).float().sum(1)
        sentemb = sentemb / sum_mask

        return {"logits": self.classification_head(sentemb)}

    def loss(self, predictions: dict, targets: dict) -> torch.tensor:
        """
        Computes Loss value according to a loss function.
        :param predictions: model specific output. Must contain a key 'logits' with
            a tensor [batch_size x num_classes] with model predictions
        :param targets: dictionary with a key 'labels' holding the label values [batch_size]

        Returns:
            torch.tensor with loss value.
        """
        return self._loss(predictions["logits"], targets["labels"])

    def prepare_sample(self, sample: list, prepare_target: bool = True) -> (dict, dict):
        """
        Function that prepares a sample to input the model.
        :param sample: list of dictionaries.
        
        Returns:
            - dictionary with the expected model inputs.
            - dictionary with the expected target labels.
        """
        sample = collate_tensors(sample)
        tokens, lengths = self.tokenizer.batch_encode(sample["text"])

        inputs = {"tokens": tokens, "lengths": lengths}

        if not prepare_target:
            return inputs, {}

        # Prepare target:
        try:
            targets = {"labels": self.data.batch_encode(sample["label"])}
            return inputs, targets
        except RuntimeError:
            raise Exception("Label encoder found an unknown label.")

    def training_step(self, batch: tuple, batch_nb: int, *args, **kwargs) -> dict:
        """ 
        Runs one training step. This usually consists of the forward function followed
            by the loss function.
        
        :param batch: The output of your dataloader. 
        :param batch_nb: Integer displaying which batch this is

        Returns:
            - dictionary containing the loss and the metrics to be added to the lightning logger.
        """
        inputs, targets = batch
        model_out = self.forward(**inputs)
        loss_val = self.loss(model_out, targets)

        # in DP mode (default) make sure if result is scalar, there's another dim in the beginning
        if self.trainer.use_dp or self.trainer.use_ddp2:
            loss_val = loss_val.unsqueeze(0)

        self.log('loss',loss_val)

        # can also return just a scalar instead of a dict (return loss_val)
        return loss_val


    
    def test_step(self, batch: tuple, batch_nb: int, *args, **kwargs) -> dict:
        """ 
        Runs one test step. This usually consists of the forward function followed
            by the loss function.
        
        :param batch: The output of your dataloader. 
        :param batch_nb: Integer displaying which batch this is

        Returns:
            - dictionary containing the loss and the metrics to be added to the lightning logger.
        """
        inputs, targets = batch
        model_out = self.forward(**inputs)
        loss_val = self.loss(model_out, targets)

        # in DP mode (default) make sure if result is scalar, there's another dim in the beginning
        if self.trainer.use_dp or self.trainer.use_ddp2:
            loss_val = loss_val.unsqueeze(0)
            
        self.log('test_loss',loss_val)

        # Lightning also accepts a dict here; this example returns the loss tensor directly.
        return loss_val

    def configure_optimizers(self):
        """ Sets different Learning rates for different parameter groups. """
        parameters = [
            {"params": self.classification_head.parameters()},
            {
                "params": self.transformer.parameters(),
                "lr": self.hparams.encoder_learning_rate,
            },
        ]
        optimizer = optim.Adam(parameters, lr=self.hparams.learning_rate)
        return [optimizer], []

    def on_epoch_end(self):
        """ Pytorch lightning hook """
        if self.current_epoch + 1 >= self.nr_frozen_epochs:
            self.unfreeze_encoder()
    
    @classmethod
    def add_model_specific_args(
        cls, parser: ArgumentParser
    ) -> ArgumentParser:
        """ Parser for Estimator specific arguments/hyperparameters. 
        :param parser: argparse.ArgumentParser

        Returns:
            - updated parser
        """
        parser.add_argument(
            "--encoder_model",
            default='simonlevine/biomed_roberta_base-4096-speedfix', # 'bert-base-uncased',
            type=str,
            help="Encoder model to be used.",
        )

        parser.add_argument(
            "--transformer_type",
            default='longformer',
            type=str,
            help="Encoder model /tokenizer to be used (has consequences for tokenization and encoding; default = longformer).",
        )

        parser.add_argument(
            "--single_label_encoding",
            default='default',
            type=str,
            help="How should labels be encoded? Default for torch-nlp label-encoder...",
        )
        parser.add_argument(
            "--max_tokens_longformer",
            default=4096,
            type=int,
            help="Max tokens to be considered per instance..",
        )
        parser.add_argument(
            "--max_tokens",
            default=512,
            type=int,
            help="Max tokens to be considered per instance..",
        )
        parser.add_argument(
            "--encoder_learning_rate",
            default=1e-05,
            type=float,
            help="Encoder specific learning rate.",
        )
        parser.add_argument(
            "--learning_rate",
            default=3e-05,
            type=float,
            help="Classification head learning rate.",
        )
        parser.add_argument(
            "--nr_frozen_epochs",
            default=1,
            type=int,
            help="Number of epochs we want to keep the encoder model frozen.",
        )
        parser.add_argument(
            "--train_csv",
            default="data/intermediary-data/notes2diagnosis-icd-train.csv",
            type=str,
            help="Path to the file containing the train data.",
        )
        parser.add_argument(
            "--dev_csv",
            default="data/intermediary-data/notes2diagnosis-icd-validate.csv",
            type=str,
            help="Path to the file containing the dev data.",
        )
        parser.add_argument(
            "--test_csv",
            default="data/intermediary-data/notes2diagnosis-icd-test.csv",
            type=str,
            help="Path to the file containing the dev data.",
        )
        parser.add_argument(
            "--loader_workers",
            default=8,
            type=int,
            help="How many subprocesses to use for data loading. 0 means that \
                the data will be loaded in the main process.",
        )
        return parser
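
A minimal wiring sketch for the example above, not part of the original snippet. It assumes the classifier exposes its DataModule as model.data (as in the companion example below) and that the surrounding script registers the extra flags the DataModule reads (batch_size, fast_dev_run, mid_dev_run); the Trainer options are placeholders for the pre-2.0 Lightning API this code targets.

from argparse import ArgumentParser

import pytorch_lightning as pl
import torch

if __name__ == "__main__":
    parser = ArgumentParser(description="Longformer ICD-code classifier")
    # Flags read by the DataModule / get_mimic_data but not registered by
    # add_model_specific_args (assumed here for illustration):
    parser.add_argument("--batch_size", default=1, type=int)
    parser.add_argument("--fast_dev_run", action="store_true")
    parser.add_argument("--mid_dev_run", action="store_true")
    parser = Classifier.add_model_specific_args(parser)
    hparams = parser.parse_args()

    model = Classifier(hparams)
    trainer = pl.Trainer(
        max_epochs=3,
        gpus=1 if torch.cuda.is_available() else 0,  # pre-2.0 Trainer argument
    )
    trainer.fit(model, datamodule=model.data)   # assumes model.data is the DataModule
    trainer.test(model, datamodule=model.data)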
Example #4
class Classifier(pl.LightningModule):
    """
    Sample model to show how to use a Transformer model to classify sentences.
    
    :param hparams: ArgumentParser containing the hyperparameters.
    """
    class DataModule(pl.LightningDataModule):
        def __init__(self, classifier_instance):
            super().__init__()
            self.hparams = classifier_instance.hparams
            self.classifier = classifier_instance
            # Label Encoder
            self.label_encoder = LabelEncoder(pd.read_csv(
                self.hparams.train_csv).label.unique().tolist(),
                                              reserved_labels=[])
            self.label_encoder.unknown_index = None

        def read_csv(self, path: str) -> list:
            """ Reads a comma separated value file.

            :param path: path to a csv file.
            
            :return: List of records as dictionaries
            """
            df = pd.read_csv(path)
            df = df[["text", "label"]]
            df["text"] = df["text"].astype(str)
            df["label"] = df["label"].astype(str)
            return df.to_dict("records")

        def train_dataloader(self) -> DataLoader:
            """ Function that loads the train set. """
            self._train_dataset = self.read_csv(self.hparams.train_csv)
            return DataLoader(
                dataset=self._train_dataset,
                sampler=RandomSampler(self._train_dataset),
                batch_size=self.hparams.batch_size,
                collate_fn=self.classifier.prepare_sample,
                num_workers=self.hparams.loader_workers,
            )

        def val_dataloader(self) -> DataLoader:
            """ Function that loads the validation set. """
            self._dev_dataset = self.read_csv(self.hparams.dev_csv)
            return DataLoader(
                dataset=self._dev_dataset,
                batch_size=self.hparams.batch_size,
                collate_fn=self.classifier.prepare_sample,
                num_workers=self.hparams.loader_workers,
            )

        def test_dataloader(self) -> DataLoader:
            """ Function that loads the validation set. """
            self._test_dataset = self.read_csv(self.hparams.test_csv)
            return DataLoader(
                dataset=self._test_dataset,
                batch_size=self.hparams.batch_size,
                collate_fn=self.classifier.prepare_sample,
                num_workers=self.hparams.loader_workers,
            )

    def __init__(self, hparams: Namespace) -> None:
        super(Classifier, self).__init__()
        self.hparams = hparams
        self.batch_size = hparams.batch_size

        # Build Data module
        self.data = self.DataModule(self)

        # build model
        self.__build_model()

        # Loss criterion initialization.
        self.__build_loss()

        if hparams.nr_frozen_epochs > 0:
            self.freeze_encoder()
        else:
            self._frozen = False
        self.nr_frozen_epochs = hparams.nr_frozen_epochs

    def __build_model(self) -> None:
        """ Init BERT model + tokenizer + classification head."""
        self.bert = AutoModel.from_pretrained(self.hparams.encoder_model,
                                              output_hidden_states=True)

        # set the number of features our encoder model will return...
        if self.hparams.encoder_model == "google/bert_uncased_L-2_H-128_A-2":
            self.encoder_features = 128
        else:
            self.encoder_features = 768

        # Tokenizer
        self.tokenizer = Tokenizer("bert-base-uncased")

        # Classification head
        self.classification_head = nn.Sequential(
            nn.Linear(self.encoder_features, self.encoder_features * 2),
            nn.Tanh(),
            nn.Linear(self.encoder_features * 2, self.encoder_features),
            nn.Tanh(),
            nn.Linear(self.encoder_features,
                      self.data.label_encoder.vocab_size),
        )

    def __build_loss(self):
        """ Initializes the loss function/s. """
        self._loss = nn.CrossEntropyLoss()

    def unfreeze_encoder(self) -> None:
        """ un-freezes the encoder layer. """
        if self._frozen:
            log.info("\n-- Encoder model fine-tuning")
            for param in self.bert.parameters():
                param.requires_grad = True
            self._frozen = False

    def freeze_encoder(self) -> None:
        """ freezes the encoder layer. """
        for param in self.bert.parameters():
            param.requires_grad = False
        self._frozen = True

    def predict(self, sample: dict) -> dict:
        """ Predict function.
        :param sample: dictionary with the text we want to classify.

        Returns:
            Dictionary with the input text and the predicted label.
        """
        if self.training:
            self.eval()

        with torch.no_grad():
            model_input, _ = self.prepare_sample([sample],
                                                 prepare_target=False)
            model_out = self.forward(**model_input)
            logits = model_out["logits"].numpy()
            predicted_labels = [
                self.data.label_encoder.index_to_token[prediction]
                for prediction in np.argmax(logits, axis=1)
            ]
            sample["predicted_label"] = predicted_labels[0]

        return sample
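
    # Usage sketch (illustrative, not part of the original example): predict()
    # expects a dict with a "text" key and returns the same dict with a
    # "predicted_label" key added, e.g.
    #   Classifier(hparams).predict({"text": "A wonderful, heartfelt film."})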

    def forward(self, tokens, lengths):
        """ Usual pytorch forward function.
        :param tokens: text sequences [batch_size x src_seq_len]
        :param lengths: source lengths [batch_size]

        Returns:
            Dictionary with model outputs (e.g: logits)
        """
        tokens = tokens[:, :lengths.max()]
        # When using just one GPU this should not change behavior
        # but when splitting batches across GPU the tokens have padding
        # from the entire original batch
        mask = lengths_to_mask(lengths, device=tokens.device)

        # Run BERT model.
        word_embeddings = self.bert(tokens, mask)[0]

        # Average Pooling
        word_embeddings = mask_fill(0.0, tokens, word_embeddings,
                                    self.tokenizer.padding_index)
        sentemb = torch.sum(word_embeddings, 1)
        sum_mask = mask.unsqueeze(-1).expand(
            word_embeddings.size()).float().sum(1)
        sentemb = sentemb / sum_mask

        return {"logits": self.classification_head(sentemb)}

    def loss(self, predictions: dict, targets: dict) -> torch.tensor:
        """
        Computes the loss value according to the configured loss function.
        :param predictions: model-specific output. Must contain a key 'logits' with
            a tensor [batch_size x num_classes] of model predictions.
        :param targets: dictionary with a key 'labels' holding the label values [batch_size].

        Returns:
            torch.tensor with loss value.
        """
        return self._loss(predictions["logits"], targets["labels"])

    def prepare_sample(self,
                       sample: list,
                       prepare_target: bool = True) -> (dict, dict):
        """
        Function that prepares a sample to input the model.
        :param sample: list of dictionaries.
        
        Returns:
            - dictionary with the expected model inputs.
            - dictionary with the expected target labels.
        """
        sample = collate_tensors(sample)
        tokens, lengths = self.tokenizer.batch_encode(sample["text"])

        inputs = {"tokens": tokens, "lengths": lengths}

        if not prepare_target:
            return inputs, {}

        # Prepare target:
        try:
            targets = {
                "labels": self.data.label_encoder.batch_encode(sample["label"])
            }
            return inputs, targets
        except RuntimeError:
            raise Exception("Label encoder found an unknown label.")

    def training_step(self, batch: tuple, batch_nb: int, *args,
                      **kwargs) -> dict:
        """ 
        Runs one training step. This usually consists in the forward function followed
            by the loss function.
        
        :param batch: The output of your dataloader. 
        :param batch_nb: Integer displaying which batch this is

        Returns:
            - dictionary containing the loss and the metrics to be added to the lightning logger.
        """
        inputs, targets = batch
        model_out = self.forward(**inputs)
        loss_val = self.loss(model_out, targets)

        # in DP mode (default) make sure if result is scalar, there's another dim in the beginning
        if self.trainer.use_dp or self.trainer.use_ddp2:
            loss_val = loss_val.unsqueeze(0)

        tqdm_dict = {"train_loss": loss_val}
        output = OrderedDict({
            "loss": loss_val,
            "progress_bar": tqdm_dict,
            "log": tqdm_dict
        })

        # can also return just a scalar instead of a dict (return loss_val)
        return output

    def validation_step(self, batch: tuple, batch_nb: int, *args,
                        **kwargs) -> dict:
        """ Similar to the training step but with the model in eval mode.

        Returns:
            - dictionary passed to the validation_end function.
        """
        inputs, targets = batch
        model_out = self.forward(**inputs)
        loss_val = self.loss(model_out, targets)

        y = targets["labels"]
        y_hat = model_out["logits"]

        # acc
        labels_hat = torch.argmax(y_hat, dim=1)
        val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
        val_acc = torch.tensor(val_acc)

        if self.on_gpu:
            val_acc = val_acc.cuda(loss_val.device.index)

        # in DP mode (default) make sure if result is scalar, there's another dim in the beginning
        if self.trainer.use_dp or self.trainer.use_ddp2:
            loss_val = loss_val.unsqueeze(0)
            val_acc = val_acc.unsqueeze(0)

        output = OrderedDict({
            "val_loss": loss_val,
            "val_acc": val_acc,
        })

        # can also return just a scalar instead of a dict (return loss_val)
        return output

    def validation_end(self, outputs: list) -> dict:
        """ Function that takes as input a list of dictionaries returned by the validation_step
        function and measures the model performance accross the entire validation set.
        
        Returns:
            - Dictionary with metrics to be added to the lightning logger.  
        """
        val_loss_mean = 0
        val_acc_mean = 0
        for output in outputs:
            val_loss = output["val_loss"]

            # reduce manually when using dp
            if self.trainer.use_dp or self.trainer.use_ddp2:
                val_loss = torch.mean(val_loss)
            val_loss_mean += val_loss

            # reduce manually when using dp
            val_acc = output["val_acc"]
            if self.trainer.use_dp or self.trainer.use_ddp2:
                val_acc = torch.mean(val_acc)

            val_acc_mean += val_acc

        val_loss_mean /= len(outputs)
        val_acc_mean /= len(outputs)
        tqdm_dict = {"val_loss": val_loss_mean, "val_acc": val_acc_mean}
        result = {
            "progress_bar": tqdm_dict,
            "log": tqdm_dict,
            "val_loss": val_loss_mean,
        }
        return result

    def configure_optimizers(self):
        """ Sets different Learning rates for different parameter groups. """
        parameters = [
            {
                "params": self.classification_head.parameters()
            },
            {
                "params": self.bert.parameters(),
                "lr": self.hparams.encoder_learning_rate,
            },
        ]
        optimizer = optim.Adam(parameters, lr=self.hparams.learning_rate)
        return [optimizer], []

    def on_epoch_end(self):
        """ Pytorch lightning hook """
        if self.current_epoch + 1 >= self.nr_frozen_epochs:
            self.unfreeze_encoder()

    @classmethod
    def add_model_specific_args(cls, parser: ArgumentParser) -> ArgumentParser:
        """ Parser for Estimator specific arguments/hyperparameters. 
        :param parser: argparse.ArgumentParser

        Returns:
            - updated parser
        """
        parser.add_argument(
            "--encoder_model",
            default="bert-base-uncased",
            type=str,
            help="Encoder model to be used.",
        )
        parser.add_argument(
            "--encoder_learning_rate",
            default=1e-05,
            type=float,
            help="Encoder specific learning rate.",
        )
        parser.add_argument(
            "--learning_rate",
            default=3e-05,
            type=float,
            help="Classification head learning rate.",
        )
        parser.add_argument(
            "--nr_frozen_epochs",
            default=1,
            type=int,
            help="Number of epochs we want to keep the encoder model frozen.",
        )
        parser.add_argument(
            "--train_csv",
            default="data/imdb_reviews_train.csv",
            type=str,
            help="Path to the file containing the train data.",
        )
        parser.add_argument(
            "--dev_csv",
            default="data/imdb_reviews_test.csv",
            type=str,
            help="Path to the file containing the dev data.",
        )
        parser.add_argument(
            "--test_csv",
            default="data/imdb_reviews_test.csv",
            type=str,
            help="Path to the file containing the dev data.",
        )
        parser.add_argument(
            "--loader_workers",
            default=8,
            type=int,
            help="How many subprocesses to use for data loading. 0 means that \
                the data will be loaded in the main process.",
        )
        return parser
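
A minimal end-to-end sketch for Example #4, not part of the original snippet. The batch_size flag is assumed here because the DataModule reads it while add_model_specific_args does not register it, and the Trainer settings are placeholders for the pre-2.0 Lightning API this code targets.

from argparse import ArgumentParser

import pytorch_lightning as pl

if __name__ == "__main__":
    parser = ArgumentParser(description="BERT sentence classifier")
    parser.add_argument("--batch_size", default=8, type=int)  # read by the DataModule
    parser = Classifier.add_model_specific_args(parser)
    hparams = parser.parse_args()

    model = Classifier(hparams)
    trainer = pl.Trainer(max_epochs=1)
    trainer.fit(model, datamodule=model.data)

    # Single-sample inference with the predict() helper defined above.
    print(model.predict({"text": "A wonderful, heartfelt film."}))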