class Classifier(pl.LightningModule): """ Sample model to show how to use a Transformer model to classify sentences. :param hparams: ArgumentParser containing the hyperparameters. """ # ************************* class DataModule(pl.LightningDataModule): def __init__(self, classifier_instance): super().__init__() self.hparams = classifier_instance.hparams # if self.hparams.transformer_type == 'longformer': # self.hparams.batch_size = 1 # raise Exception self.classifier = classifier_instance self.transformer_type = self.hparams.transformer_type self.hparams.n_labels = 50 df = pd.read_csv(self.hparams.train_csv, converters={'CODES': eval}) a = pd.Series([item for sublist in df.CODES for item in sublist]) self.hparams.top_codes = a.value_counts()[:self.hparams. n_labels].index.tolist() # self.top_codes = ['I48', 'CIR007'] # self.hparams.n_labels = len(self.top_codes) logger.warning( f'Classifying against the top {self.hparams.n_labels} most frequent ICD codes: {self.hparams.top_codes}' ) self.mlb = MultiLabelBinarizer() self.mlb.fit([self.hparams.top_codes]) # Label Encoder # if self.hparams.single_label_encoding == 'default': # self.label_encoder = LabelEncoder( # np.unique(self.top_codes).tolist(), # reserved_labels=[] # ) # self.label_encoder.unknown_index = None # New code added for mlb def top_labeler(self, codes): # logger.info("codes are: {}".format(codes)) out = [label for label in codes if label in self.hparams.top_codes] # logger.info("out is: {}".format(out)) # if out == []: # out = [''] # return self.mlb.transform([out]) return out def get_mimic_data(self, path: str) -> list: #REWRITE """ Reads a comma separated value file. :param path: path to a csv file. :return: List of records as dictionaries """ if self.hparams.fast_dev_run: df = pd.read_csv(path, converters={'CODES': eval}, skiprows=range(100, 100000)) elif self.hparams.mid_dev_run: df = pd.read_csv(path, converters={'CODES': eval}, skiprows=range(1000, 100000)) else: df = pd.read_csv(path, converters={'CODES': eval}) df = df[["note_text", "CODES"]] df.rename(columns={ 'note_text': 'text', 'CODES': 'labels' }, inplace=True) # df = df[df['labels'].isin(self.top_codes)] # logger.info(df['labels']) df['labels'] = df['labels'].map(self.top_labeler) # logger.info(df['labels']) # df["text"] = df["text"].astype(str) # df["label"] = df["label"].astype(tr) df.to_csv(f'{path}_top_codes_filtered.csv') logger.warning(f'{path} dataframe has {len(df)} examples.') return df.to_dict("records") def train_dataloader(self) -> DataLoader: """ Function that loads the train set. """ logger.warning('Loading training data...') self._train_dataset = self.get_mimic_data(self.hparams.train_csv) return DataLoader( dataset=self._train_dataset, sampler=RandomSampler(self._train_dataset), batch_size=self.hparams.batch_size, collate_fn=self.classifier.prepare_sample, num_workers=self.hparams.loader_workers, ) def val_dataloader(self) -> DataLoader: logger.warning('Loading validation data...') """ Function that loads the validation set. """ self._dev_dataset = self.get_mimic_data(self.hparams.dev_csv) return DataLoader( dataset=self._dev_dataset, batch_size=self.hparams.batch_size, collate_fn=self.classifier.prepare_sample, num_workers=self.hparams.loader_workers, ) def test_dataloader(self) -> DataLoader: logger.warning('Loading testing data...') """ Function that loads the test set. 
""" self._test_dataset = self.get_mimic_data(self.hparams.test_csv) return DataLoader( dataset=self._test_dataset, batch_size=self.hparams.batch_size, collate_fn=self.classifier.prepare_sample, num_workers=self.hparams.loader_workers, ) # **************** def __init__(self, hparams: Namespace) -> None: super(Classifier, self).__init__() self.hparams = hparams self.batch_size = hparams.batch_size # Build Data module self.data = self.DataModule(self) # build model self.__build_model() # Loss criterion initialization. self.__build_loss() if hparams.nr_frozen_epochs > 0: self.freeze_encoder() else: self._frozen = False self.nr_frozen_epochs = hparams.nr_frozen_epochs self.test_conf_matrices = [] # Set up multi label binarizer: self.mlb = MultiLabelBinarizer() self.mlb.fit([self.hparams.top_codes]) self.acc = torchmetrics.Accuracy() self.f1 = torchmetrics.F1(num_classes=self.hparams.n_labels, average='micro') self.auroc = torchmetrics.AUROC(num_classes=self.hparams.n_labels, average='weighted') # NOTE could try 'global' instead of samplewise for mdmc reduce self.prec = torchmetrics.Precision(num_classes=self.hparams.n_labels, is_multiclass=False) self.recall = torchmetrics.Recall(num_classes=self.hparams.n_labels, is_multiclass=False) self.confusion_matrix = torchmetrics.ConfusionMatrix( num_classes=self.hparams.n_labels) self.test_predictions = None self.test_labels = None def __build_model(self) -> None: """ Init transformer model + tokenizer + classification head.""" if self.hparams.transformer_type == 'roberta-long': self.transformer = RobertaLongForMaskedLM.from_pretrained( self.hparams.encoder_model, output_hidden_states=True, gradient_checkpointing=True) elif self.hparams.transformer_type == 'longformer': self.transformer = AutoModel.from_pretrained( self.hparams.encoder_model, output_hidden_states=True, gradient_checkpointing=True, #critical for training speed. ) else: #BERT self.transformer = AutoModel.from_pretrained( self.hparams.encoder_model, output_hidden_states=True, ) logger.warning(f'model is {self.hparams.encoder_model}') if self.hparams.transformer_type == 'longformer': logger.warning('Turning ON gradient checkpointing...') self.transformer = AutoModel.from_pretrained( self.hparams.encoder_model, output_hidden_states=True, gradient_checkpointing=True, #critical for training speed. ) else: self.transformer = AutoModel.from_pretrained( self.hparams.encoder_model, output_hidden_states=True, ) # set the number of features our encoder model will return... 
self.encoder_features = 768 # Tokenizer if self.hparams.transformer_type == 'longformer' or self.hparams.transformer_type == 'roberta-long': self.tokenizer = Tokenizer( pretrained_model=self.hparams.encoder_model, max_tokens=self.hparams.max_tokens_longformer) self.tokenizer.max_len = 4096 else: self.tokenizer = Tokenizer( pretrained_model=self.hparams.encoder_model, max_tokens=512) #others: #'emilyalsentzer/Bio_ClinicalBERT' 'simonlevine/biomed_roberta_base-4096-speedfix' # Ben's new architecture if self.hparams.nn_arch == 'ben1': self.classification_head = nn.Sequential( nn.Linear(self.encoder_features, self.encoder_features * 3), nn.Dropout(0.1), nn.Linear(self.encoder_features * 3, self.encoder_features), nn.Linear(self.encoder_features, self.hparams.n_labels), ) elif self.hparams.nn_arch == 'ben2': self.classification_head = nn.Sequential( nn.Dropout(0.1), nn.Linear(self.encoder_features, self.encoder_features * 2), nn.Tanh(), nn.Linear(self.encoder_features * 2, self.encoder_features * 3), nn.ReLU(), nn.Linear(self.encoder_features * 3, self.encoder_features), nn.Sigmoid(), nn.Linear(self.encoder_features, self.hparams.n_labels), ) elif self.hparams.nn_arch == 'CNN': logger.critical('CNN not yet implemented') elif self.hparams.nn_arch == 'default': self.classification_head = nn.Sequential( nn.Linear(self.encoder_features, self.encoder_features * 2), nn.Tanh(), nn.Linear(self.encoder_features * 2, self.encoder_features), nn.Tanh(), nn.Linear(self.encoder_features, self.hparams.n_labels), ) # Classification head elif self.hparams.single_label_encoding == 'default': self.classification_head = nn.Sequential( nn.Linear(self.encoder_features, self.encoder_features * 2), nn.Tanh(), nn.Linear(self.encoder_features * 2, self.encoder_features), nn.Tanh(), nn.Linear(self.encoder_features, self.hparams.n_labels), ) elif self.hparams.single_label_encoding == 'graphical': logger.critical('Graphical embedding not yet implemented!') # self.classification_head = nn.Sequential( #TODO # ) def __build_loss(self): """ Initializes the loss function/s. """ #FOR SINGLE LABELS --> MSE (linear regression) LOSS (like a regression problem) # For multiple POSSIBLE discrete single labels, CELoss # for many possible categoricla labels, binary cross-entropy (logistic regression for all labels.) self._loss = nn.BCELoss() self._loss = nn.BCEWithLogitsLoss() # self._loss = nn.CrossEntropyLoss() # self._loss = nn.MSELoss() def unfreeze_encoder(self) -> None: """ un-freezes the encoder layer. """ if self._frozen: log.info(f"\n-- Encoder model fine-tuning") for param in self.transformer.parameters(): param.requires_grad = True self._frozen = False def freeze_encoder(self) -> None: """ freezes the encoder layer. """ for param in self.transformer.parameters(): param.requires_grad = False self._frozen = True def predict(self, sample: dict) -> dict: """ Predict function. :param sample: dictionary with the text we want to classify. Returns: Dictionary with the input text and the predicted label. """ if self.training: self.eval() with torch.no_grad(): model_input, _ = self.prepare_sample([sample], prepare_target=False) model_out = self.forward(**model_input) logits = model_out["logits"].numpy() predicted_labels = [ #TODO change this for no label encoder self.mlb.inverse_transform[prediction] for prediction in np.argmax(logits, axis=1) ] sample["predicted_label"] = predicted_labels return sample def forward(self, tokens, lengths): """ Usual pytorch forward function. 
:param tokens: text sequences [batch_size x src_seq_len] :param lengths: source lengths [batch_size] Returns: Dictionary with model outputs (e.g: logits) """ tokens = tokens[:, :lengths.max()] # When using just one GPU this should not change behavior # but when splitting batches across GPU the tokens have padding # from the entire original batch mask = lengths_to_mask(lengths, device=tokens.device) # Run BERT model. word_embeddings = self.transformer(tokens, mask)[0] # Average Pooling word_embeddings = mask_fill(0.0, tokens, word_embeddings, self.tokenizer.padding_index) sentemb = torch.sum(word_embeddings, 1) sum_mask = mask.unsqueeze(-1).expand( word_embeddings.size()).float().sum(1) sentemb = sentemb / sum_mask return {"logits": self.classification_head(sentemb)} def loss(self, predictions: dict, targets: dict) -> torch.tensor: """ Computes Loss value according to a loss function. :param predictions: model specific output. Must contain a key 'logits' with a tensor [batch_size x 1] with model predictions :param labels: Label values [batch_size] Returns: torch.tensor with loss value. """ return self._loss(predictions["logits"], targets["labels"].float()) def prepare_sample(self, sample: list, prepare_target: bool = True) -> (dict, dict): """ Function that prepares a sample to input the model. :param sample: list of dictionaries. Returns: - dictionary with the expected model inputs. - dictionary with the expe cted target labels. """ # logger.info("Sample label:{}".format([sample[i]['labels'] for i in range(6)])) # logger.info("Sample label:{}".format(sample)) # sample['text'] = collate_tensors(sample[i]['text']for i in range(6)) # sample = collate_tensors(sample) texts = [s['text'] for s in sample] labels = [s['labels'] for s in sample] # logger.info("sample text len is:{}".format(len(sample['text']))) # logger.info("text 1is:{}".format(sample['text'][0])) # logger.info("text2 is:{}".format(sample['text'][1])) # logger.info("text3 is:{}".format(sample['text'][2])) tokens, lengths = self.tokenizer.batch_encode(texts) # logger.info("sample text len is:{}".format(len(sample['text']))) # logger.info("Sample label after collate:{}".format(sample['labels'])) # logger.info("labels are lists: {}".format(type(sample['labels'][0]) == list )) # logger.info("lengths:{}".format(lengths)) inputs = {"tokens": tokens, "lengths": lengths} if not prepare_target: return inputs, {} # return inputs, self.data[self.hparams.targets] # Prepare target: # try: #NOTE WARNING torch.tensor is kinda bad maybe switch for a copier # logger.info(labels) #NOTE WARNING double check that mlb is working correct I think it is sample_labels = torch.tensor(self.mlb.transform(labels)) # logger.info(sample_labels) targets = {'labels': sample_labels} return inputs, targets # if not sample['labels']: # targets = {'labels': self.mlb.transform('')} # else: # logger.info('sample: {}'.format(sample['labels'])) # logger.info('type: {}'.format(type(sample['labels']))) # targets = {"labels": self.mlb.transform(sample["labels"])} # a = [self.mlb.transform([x]) if x else self.mlb.transform(['']) for x in sample["labels"]] # if sample['labels']==[]: # a = self.mlb.transform(' ') # else: # a = self.mlb.transform([sample['labels']]) # b = torch.tensor(a) # c = b.squeeze() targets = {"labels": c} # logger.info(targets['labels']) # targets = {"labels": torch.tensor([self.mlb.transform([x]) if x else self.mlb.transform(['']) for x in sample["labels"]])} # logger.info('targets: {}'.format(targets['labels'])) # logger.info("input len{} is 
{}".format(len(inputs['lengths']),inputs['lengths'])) # logger.info(targets['labels'].size()) return inputs, targets # except RuntimeError: # raise Exception("Label encoder found an unknown label.") def training_step(self, batch: tuple, batch_nb: int, *args, **kwargs) -> dict: """ Runs one training step. This usually consists in the forward function followed by the loss function. :param batch: The output of your dataloader. :param batch_nb: Integer displaying which batch this is Returns: - dictionary containing the loss and the metrics to be added to the lightning logger. """ inputs, targets = batch model_out = self.forward(**inputs) loss_val = self.loss(model_out, targets) # in DP mode (default) make sure if result is scalar, there's another dim in the beginning if self.trainer.use_dp or self.trainer.use_ddp2: loss_val = loss_val.unsqueeze(0) self.log('loss', loss_val) # can also return just a scalar instead of a dict (return loss_val) return loss_val def test_step(self, batch: tuple, batch_nb: int, *args, **kwargs) -> dict: """ Runs one training step. This usually consists in the forward function followed by the loss function. :param batch: The output of your dataloader. :param batch_nb: Integer displaying which batch this is Returns: - dictionary containing the loss and the metrics to be added to the lightning logger. """ inputs, targets = batch # logger.info(batch) model_out = self.forward(**inputs) loss_val = self.loss(model_out, targets) # in DP mode (default) make sure if result is scalar, there's another dim in the beginning # if self.trainer.use_dp or self.trainer.use_ddp2: # loss_val = loss_val.unsqueeze(0) self.log('test_loss', loss_val) y_hat = model_out['logits'] # labels_hat = torch.argmax(y_hat, dim=1) preds = torch.tensor([[1 if x > 0 else 0 for x in item] for item in y_hat], device=y_hat.device) y = targets['labels'] return {'loss': loss_val, 'preds': preds, 'target': y} def test_step_end(self, outputs): preds = outputs['preds'] target = outputs['target'] # logger.info("Pred shape is {}".format(preds.size())) # logger.info("Target shape is {}".format(target.size())) # f1 = metrics.f1_score(labels_hat,y, class_reduction='weighted') # prec =metrics.precision(labels_hat,y, class_reduction='weighted') # recall = metrics.recall(labels_hat,y, class_reduction='weighted') # acc = metrics.accuracy(labels_hat,y, class_reduction='weighted') # # auroc = metrics.multiclass_auroc(labels_hat, y) acc = self.acc(preds, target) f1 = self.f1(preds, target) # f1 = fnmetrics.f1(preds, target, num_classes=self.hparams.n_labels) prec = self.prec(preds, target) recall = self.recall(preds, target) try: auroc = self.auroc(preds.float(), target) except ValueError as v: print(v) auroc = 0 confusion_matrix = self.confusion_matrix(preds, target) self.log('test_batch_prec', prec) self.log('test_batch_f1', f1) self.log('test_batch_recall', recall) self.log('test_batch_weighted_acc', acc) self.log('test_batch auroc', auroc) # self.log('test_batch_auc_roc', auroc) from pytorch_lightning.metrics.functional import confusion_matrix # TODO CHANGE THIS # return (labels_hat, y) # cm = confusion_matrix(preds = preds,target=target,normalize=None, num_classes=50) # cm = confusion_matrix(preds = labels_hat,target=y,normalize=False, num_classes=len(y.unique())) # self.test_conf_matrices.append(cm) # logger.info(labels_hat) # logger.info(y) # logger.info(classification_report(preds.detach().cpu(), target.detach().cpu())) # logger.info(confusion_matrix) #update and log if self.test_predictions is None: 
self.test_predictions = preds self.test_labels = target else: self.test_predictions = torch.cat((self.test_predictions, preds), 0) self.test_labels = torch.cat((self.test_labels, target), 0) def validation_step(self, batch: tuple, batch_nb: int, *args, **kwargs) -> dict: """ Similar to the training step but with the model in eval mode. Returns: - dictionary passed to the validation_end function. """ inputs, targets = batch model_out = self.forward(**inputs) loss_val = self.loss(model_out, targets) y = targets["labels"] y_hat = model_out["logits"] # acc # logger.info(y_hat) # logger.info(y) # labels_hat = torch.argmax(y_hat, dim=0) # labels_hat labels_hat = torch.tensor([[1 if x > 0 else 0 for x in item] for item in y_hat], device=y_hat.device) # logger.info(labels_hat.device) # logger.info(torch.argmax(y_hat, dim=1)) val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0) val_acc = torch.tensor(val_acc) if self.on_gpu: val_acc = val_acc.cuda(loss_val.device.index) # in DP mode (default) make sure if result is scalar, there's another dim in the beginning if self.trainer.use_dp or self.trainer.use_ddp2: loss_val = loss_val.unsqueeze(0) val_acc = val_acc.unsqueeze(0) self.log('val_loss', loss_val) # f1 = metrics.f1_score(labels_hat, y,class_reduction='weighted') # prec =metrics.precision(labels_hat, y,class_reduction='weighted') # recall = metrics.recall(labels_hat, y,class_reduction='weighted') # acc = metrics.accuracy(labels_hat, y,class_reduction='weighted') # auroc = skm.AUROC # # auroc = metrics.multiclass_auroc(y_hat,y) # self.log('val_prec',prec) # self.log('val_f1',f1) # self.log('val_recall',recall) # self.log('val_acc_weighted', acc) # logger.info(classification_report(labels_hat.detach().cpu(), y.detach().cpu())) # self.log('val_cm',cm) def configure_optimizers(self): """ Sets different Learning rates for different parameter groups. """ parameters = [ { "params": self.classification_head.parameters() }, { "params": self.transformer.parameters(), "lr": self.hparams.encoder_learning_rate, }, ] optimizer = optim.Adam(parameters, lr=self.hparams.learning_rate) return [optimizer], [] def on_epoch_end(self): """ Pytorch lightning hook """ if self.current_epoch + 1 >= self.nr_frozen_epochs: self.unfreeze_encoder() @classmethod def add_model_specific_args(cls, parser: ArgumentParser) -> ArgumentParser: """ Parser for Estimator specific arguments/hyperparameters. :param parser: argparse.ArgumentParser Returns: - updated parser """ parser.add_argument( "--encoder_model", default= 'emilyalsentzer/Bio_ClinicalBERT', # 'allenai/biomed_roberta_base',#'simonlevine/biomed_roberta_base-4096-speedfix', # 'bert-base-uncased', type=str, help="Encoder model to be used.", ) parser.add_argument( "--transformer_type", default='bert', #'longformer', roberta-long type=str, help= "Encoder model /tokenizer to be used (has consequences for tokenization and encoding; default = longformer).", ) parser.add_argument( "--single_label_encoding", default='none', type=str, help= "How should labels be encoded? 
Default for torch-nlp label-encoder...",
        )
        parser.add_argument(
            "--max_tokens_longformer",
            default=4096,
            type=int,
            help="Max tokens to be considered per instance.",
        )
        parser.add_argument(
            "--max_tokens",
            default=512,
            type=int,
            help="Max tokens to be considered per instance.",
        )
        parser.add_argument(
            "--encoder_learning_rate",
            default=1e-05,
            type=float,
            help="Encoder-specific learning rate.",
        )
        parser.add_argument(
            "--learning_rate",
            default=3e-05,
            type=float,
            help="Classification head learning rate.",
        )
        parser.add_argument(
            "--nr_frozen_epochs",
            default=0,
            type=int,
            help="Number of epochs we want to keep the encoder model frozen.",
        )
        parser.add_argument(
            "--train_csv",
            default="data/intermediary-data/notes2diagnosis-icd-train.csv",
            type=str,
            help="Path to the file containing the train data.",
        )
        parser.add_argument(
            "--dev_csv",
            default="data/intermediary-data/notes2diagnosis-icd-validate.csv",
            type=str,
            help="Path to the file containing the dev data.",
        )
        parser.add_argument(
            "--test_csv",
            default="data/intermediary-data/notes2diagnosis-icd-test.csv",
            type=str,
            help="Path to the file containing the test data.",
        )
        parser.add_argument(
            "--loader_workers",
            default=8,
            type=int,
            help="How many subprocesses to use for data loading. 0 means that \
                the data will be loaded in the main process.",
        )
        return parser
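# *************************
# Hedged usage sketch (not part of the Classifier above): how the multi-label
# pieces fit together — MultiLabelBinarizer targets built from top_labeler-style
# filtered code lists, BCEWithLogitsLoss on raw logits, and the `logit > 0`
# decision rule used in test_step (sigmoid(0) == 0.5). The codes, dimensions,
# and the plain nn.Linear head below are illustrative stand-ins.
import torch
import torch.nn as nn
from sklearn.preprocessing import MultiLabelBinarizer

top_codes = ["I10", "E11", "I48"]                 # toy stand-ins for the top-N ICD codes
mlb = MultiLabelBinarizer().fit([top_codes])

# Per-record code lists, already filtered to the top codes (cf. top_labeler).
batch_labels = [["I10", "I48"], ["E11"], []]
targets = torch.tensor(mlb.transform(batch_labels), dtype=torch.float)   # [3, 3] multi-hot

head = nn.Linear(768, len(top_codes))             # stand-in for the classification head
sentemb = torch.randn(3, 768)                     # stand-in for pooled sentence embeddings
logits = head(sentemb)

loss = nn.BCEWithLogitsLoss()(logits, targets)    # expects raw logits, not sigmoid outputs
preds = (logits > 0).int()                        # logit > 0  <=>  sigmoid(logit) > 0.5
print(loss.item(), preds.shape)
# *************************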
class MedNLIClassifier(pl.LightningModule): """ Sample model to show how to use a Transformer model to classify sentences. :param hparams: ArgumentParser containing the hyperparameters. """ # ************************* class DataModule(pl.LightningDataModule): def __init__(self, classifier_instance): super().__init__() self.hparams = classifier_instance.hparams if self.hparams.transformer_type == 'longformer': self.hparams.batch_size = 1 self.classifier = classifier_instance self.transformer_type = self.hparams.transformer_type # Label Encoder self.label_encoder = LabelEncoder( ['contradiction', 'entailment', 'neutral'], reserved_labels=[]) self.label_encoder.unknown_index = None def get_mednli_data(self, raw_data: list) -> list: #REWRITE """ Reads a comma separated value file. :param path: path to a csv file. :return: List of records as dictionaries """ df = pd.DataFrame(raw_data) df = df.rename(columns={0: 'premise', 1: 'hypothesis', 2: 'label'}) df["text"] = df["premise"].astype(str) + df["hypothesis"].astype( str) df["label"] = df["label"].astype(str) df = df.drop(['premise', 'hypothesis'], axis=1) return df.to_dict("records") def setup(self, stage=None): self.mednli_train, self.mednli_dev, self.mednli_test = load_mednli( ) def train_dataloader(self) -> DataLoader: """ Function that loads the train set. """ logger.warning('Loading training data...') self._train_dataset = self.get_mednli_data(self.mednli_train) return DataLoader( dataset=self._train_dataset, sampler=RandomSampler(self._train_dataset), batch_size=self.hparams.batch_size, collate_fn=self.classifier.prepare_sample, num_workers=self.hparams.loader_workers, ) def val_dataloader(self) -> DataLoader: logger.warning('Loading validation data...') """ Function that loads the validation set. """ self._dev_dataset = self.get_mednli_data(self.mednli_dev) return DataLoader( dataset=self._dev_dataset, batch_size=self.hparams.batch_size, collate_fn=self.classifier.prepare_sample, num_workers=self.hparams.loader_workers, ) def test_dataloader(self) -> DataLoader: logger.warning('Loading testing data...') """ Function that loads the validation set. """ self._test_dataset = self.get_mednli_data(self.mednli_test) return DataLoader( dataset=self._test_dataset, batch_size=self.hparams.batch_size, collate_fn=self.classifier.prepare_sample, num_workers=self.hparams.loader_workers, ) # **************** def __init__(self, hparams: Namespace) -> None: super(MedNLIClassifier, self).__init__() self.hparams = hparams self.batch_size = hparams.batch_size # Build Data module self.data = self.DataModule(self) # build model self.__build_model() # Loss criterion initialization. self.__build_loss() if hparams.nr_frozen_epochs > 0: self.freeze_encoder() else: self._frozen = False self.nr_frozen_epochs = hparams.nr_frozen_epochs self.confusion_matrix = ConfusionMatrix(num_classes=3) def __build_model(self) -> None: """ Init transformer model + tokenizer + classification head.""" if self.hparams.transformer_type == 'roberta-long': self.transformer = RobertaLongForMaskedLM.from_pretrained( self.hparams.encoder_model, output_hidden_states=True, gradient_checkpointing=True) elif self.hparams.transformer_type == 'longformer': self.transformer = AutoModel.from_pretrained( self.hparams.encoder_model, output_hidden_states=True, gradient_checkpointing=True, #critical for training speed. 
) else: #BERT self.transformer = AutoModel.from_pretrained( self.hparams.encoder_model, output_hidden_states=True, ) logger.warning(f'model is {self.hparams.encoder_model}') if self.hparams.transformer_type == 'longformer': logger.warning('Turnin ON gradient checkpointing...') self.transformer = AutoModel.from_pretrained( self.hparams.encoder_model, output_hidden_states=True, gradient_checkpointing=True, #critical for training speed. ) else: self.transformer = AutoModel.from_pretrained( self.hparams.encoder_model, output_hidden_states=True, ) # set the number of features our encoder model will return... self.encoder_features = 768 # Tokenizer if self.hparams.transformer_type == 'longformer' or self.hparams.transformer_type == 'roberta-long': self.tokenizer = Tokenizer( pretrained_model=self.hparams.encoder_model, max_tokens=self.hparams.max_tokens_longformer) self.tokenizer.max_len = 4096 else: self.tokenizer = Tokenizer( pretrained_model=self.hparams.encoder_model, max_tokens=512) #others: #'emilyalsentzer/Bio_ClinicalBERT' 'simonlevine/biomed_roberta_base-4096-speedfix' # Classification head if self.hparams.single_label_encoding == 'default': self.classification_head = nn.Sequential( nn.Linear(self.encoder_features, self.encoder_features * 2), nn.Tanh(), nn.Linear(self.encoder_features * 2, self.encoder_features), nn.Tanh(), nn.Linear(self.encoder_features, self.data.label_encoder.vocab_size), ) elif self.hparams.single_label_encoding == 'graphical': logger.critical('Graphical embedding not yet implemented!') # self.classification_head = nn.Sequential( #TODO # ) def __build_loss(self): """ Initializes the loss function/s. """ #FOR SINGLE LABELS --> MSE (linear regression) LOSS (like a regression problem) # For multiple POSSIBLE discrete single labels, CELoss # for many possible categoricla labels, binary cross-entropy (logistic regression for all labels.) self._loss = nn.CrossEntropyLoss() # self._loss = nn.MSELoss() def unfreeze_encoder(self) -> None: """ un-freezes the encoder layer. """ if self._frozen: log.info(f"\n-- Encoder model fine-tuning") for param in self.transformer.parameters(): param.requires_grad = True self._frozen = False def freeze_encoder(self) -> None: """ freezes the encoder layer. """ for param in self.transformer.parameters(): param.requires_grad = False self._frozen = True def predict(self, sample: dict) -> dict: """ Predict function. :param sample: dictionary with the text we want to classify. Returns: Dictionary with the input text and the predicted label. """ if self.training: self.eval() with torch.no_grad(): model_input, _ = self.prepare_sample([sample], prepare_target=False) model_out = self.forward(**model_input) logits = model_out["logits"].numpy() predicted_labels = [ self.data.label_encoder.index_to_token[prediction] for prediction in np.argmax(logits, axis=1) ] sample["predicted_label"] = predicted_labels[0] return sample def forward(self, tokens, lengths): """ Usual pytorch forward function. :param tokens: text sequences [batch_size x src_seq_len] :param lengths: source lengths [batch_size] Returns: Dictionary with model outputs (e.g: logits) """ tokens = tokens[:, :lengths.max()] # When using just one GPU this should not change behavior # but when splitting batches across GPU the tokens have padding # from the entire original batch mask = lengths_to_mask(lengths, device=tokens.device) # Run BERT model. 
word_embeddings = self.transformer(tokens, mask)[0] # Average Pooling word_embeddings = mask_fill(0.0, tokens, word_embeddings, self.tokenizer.padding_index) sentemb = torch.sum(word_embeddings, 1) sum_mask = mask.unsqueeze(-1).expand( word_embeddings.size()).float().sum(1) sentemb = sentemb / sum_mask return {"logits": self.classification_head(sentemb)} def loss(self, predictions: dict, targets: dict) -> torch.tensor: """ Computes Loss value according to a loss function. :param predictions: model specific output. Must contain a key 'logits' with a tensor [batch_size x 1] with model predictions :param labels: Label values [batch_size] Returns: torch.tensor with loss value. """ return self._loss(predictions["logits"], targets["labels"]) def prepare_sample(self, sample: list, prepare_target: bool = True) -> (dict, dict): """ Function that prepares a sample to input the model. :param sample: list of dictionaries. Returns: - dictionary with the expected model inputs. - dictionary with the expected target labels. """ sample = collate_tensors(sample) tokens, lengths = self.tokenizer.batch_encode(sample["text"]) inputs = {"tokens": tokens, "lengths": lengths} if not prepare_target: return inputs, {} # Prepare target: try: targets = { "labels": self.data.label_encoder.batch_encode(sample["label"]) } return inputs, targets except RuntimeError: raise Exception("Label encoder found an unknown label.") def training_step(self, batch: tuple, batch_nb: int, *args, **kwargs) -> dict: """ Runs one training step. This usually consists in the forward function followed by the loss function. :param batch: The output of your dataloader. :param batch_nb: Integer displaying which batch this is Returns: - dictionary containing the loss and the metrics to be added to the lightning logger. """ inputs, targets = batch model_out = self.forward(**inputs) loss_val = self.loss(model_out, targets) # in DP mode (default) make sure if result is scalar, there's another dim in the beginning if self.trainer.use_dp or self.trainer.use_ddp2: loss_val = loss_val.unsqueeze(0) self.log('loss', loss_val) # can also return just a scalar instead of a dict (return loss_val) return loss_val def test_step(self, batch: tuple, batch_nb: int, *args, **kwargs) -> dict: """ Runs one training step. This usually consists in the forward function followed by the loss function. :param batch: The output of your dataloader. :param batch_nb: Integer displaying which batch this is Returns: - dictionary containing the loss and the metrics to be added to the lightning logger. 
""" inputs, targets = batch model_out = self.forward(**inputs) loss_val = self.loss(model_out, targets) # in DP mode (default) make sure if result is scalar, there's another dim in the beginning if self.trainer.use_dp or self.trainer.use_ddp2: loss_val = loss_val.unsqueeze(0) self.log('test_loss', loss_val) y_hat = model_out['logits'] labels_hat = torch.argmax(y_hat, dim=1) y = targets['labels'] f1 = metrics.f1(labels_hat, y, average='weighted', num_classes=3) prec = metrics.precision(labels_hat, y, class_reduction='weighted', num_classes=3) recall = metrics.recall(labels_hat, y, class_reduction='weighted', num_classes=3) acc = metrics.accuracy(labels_hat, y, class_reduction='weighted', num_classes=3) self.confusion_matrix.update(labels_hat, y) self.log('test_batch_prec', prec) self.log('test_batch_f1', f1) self.log('test_batch_recall', recall) self.log('test_batch_weighted_acc', acc) def validation_step(self, batch: tuple, batch_nb: int, *args, **kwargs) -> dict: """ Similar to the training step but with the model in eval mode. Returns: - dictionary passed to the validation_end function. """ inputs, targets = batch model_out = self.forward(**inputs) loss_val = self.loss(model_out, targets) y = targets["labels"] y_hat = model_out["logits"] # acc labels_hat = torch.argmax(y_hat, dim=1) val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0) val_acc = torch.tensor(val_acc) if self.on_gpu: val_acc = val_acc.cuda(loss_val.device.index) # in DP mode (default) make sure if result is scalar, there's another dim in the beginning if self.trainer.use_dp or self.trainer.use_ddp2: loss_val = loss_val.unsqueeze(0) val_acc = val_acc.unsqueeze(0) self.log('val_loss', loss_val) f1 = metrics.f1(labels_hat, y, average='weighted', num_classes=3) prec = metrics.precision(labels_hat, y, class_reduction='weighted', num_classes=3) recall = metrics.recall(labels_hat, y, class_reduction='weighted', num_classes=3) acc = metrics.accuracy(labels_hat, y, class_reduction='weighted', num_classes=3) self.log('val_prec', prec) self.log('val_f1', f1) self.log('val_recall', recall) self.log('val_acc_weighted', acc) def configure_optimizers(self): """ Sets different Learning rates for different parameter groups. """ parameters = [ { "params": self.classification_head.parameters() }, { "params": self.transformer.parameters(), "lr": self.hparams.encoder_learning_rate, }, ] optimizer = optim.Adam(parameters, lr=self.hparams.learning_rate) return [optimizer], [] def on_epoch_end(self): """ Pytorch lightning hook """ if self.current_epoch + 1 >= self.nr_frozen_epochs: self.unfreeze_encoder() @classmethod def add_model_specific_args(cls, parser: ArgumentParser) -> ArgumentParser: """ Parser for Estimator specific arguments/hyperparameters. :param parser: argparse.ArgumentParser Returns: - updated parser """ parser.add_argument( "--encoder_model", default= 'simonlevine/bioclinical-roberta-long', #'emilyalsentzer/Bio_ClinicalBERT',# 'allenai/biomed_roberta_base',', # 'bert-base-uncased', type=str, help="Encoder model to be used.", ) parser.add_argument( "--transformer_type", default='roberta-long', #'bert', roberta, or roberta-long type=str, help= "Encoder model /tokenizer to be used (has consequences for tokenization and encoding; default = longformer).", ) parser.add_argument( "--single_label_encoding", default='default', type=str, help= "How should labels be encoded? 
Default for torch-nlp label-encoder...",
        )
        parser.add_argument(
            "--max_tokens_longformer",
            default=4096,
            type=int,
            help="Max tokens to be considered per instance.",
        )
        parser.add_argument(
            "--max_tokens",
            default=512,
            type=int,
            help="Max tokens to be considered per instance.",
        )
        parser.add_argument(
            "--encoder_learning_rate",
            default=1e-05,
            type=float,
            help="Encoder-specific learning rate.",
        )
        parser.add_argument(
            "--learning_rate",
            default=3e-05,
            type=float,
            help="Classification head learning rate.",
        )
        parser.add_argument(
            "--nr_frozen_epochs",
            default=0,
            type=int,
            help="Number of epochs we want to keep the encoder model frozen.",
        )
        parser.add_argument(
            "--loader_workers",
            default=8,
            type=int,
            help="How many subprocesses to use for data loading. 0 means that \
                the data will be loaded in the main process.",
        )
        return parser
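# *************************
# Hedged sketch of the single-label MedNLI path above: integer-encoded labels,
# CrossEntropyLoss on logits, argmax predictions. A plain dict stands in for
# the torch-nlp LabelEncoder so the example is self-contained; the logits are
# random stand-ins for model_out["logits"].
import torch
import torch.nn as nn

classes = ["contradiction", "entailment", "neutral"]
to_index = {c: i for i, c in enumerate(classes)}

batch = ["entailment", "neutral", "contradiction", "entailment"]
targets = torch.tensor([to_index[c] for c in batch])       # shape [4], dtype long

logits = torch.randn(4, len(classes))
loss = nn.CrossEntropyLoss()(logits, targets)              # expects class indices, not one-hot

labels_hat = torch.argmax(logits, dim=1)                   # as in validation_step / test_step
acc = (labels_hat == targets).float().mean()
print(loss.item(), acc.item())
# *************************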
class Classifier(pl.LightningModule): """ Sample model to show how to use a Transformer model to classify sentences. :param hparams: ArgumentParser containing the hyperparameters. """ # ************************* class DataModule(pl.LightningDataModule): def __init__(self, classifier_instance): super().__init__() self.hparams = classifier_instance.hparams if self.hparams.transformer_type == 'longformer': self.hparams.batch_size = 1 self.classifier = classifier_instance self.transformer_type = self.hparams.transformer_type self.raw_data = self.get_mimic_data() msk = np.random.rand(len(self.raw_data)) < 0.8 self.train = self.raw_data[msk] self.test = self.raw_data[~msk] # self.label_encoder.unknown_index = None def get_mimic_data(self, path: str) -> list: """ Reads a comma separated value file. :param path: path to a csv file. :return: List of records as dictionaries """ df = pd.read_csv(path) df = df.drop('ROW_ID',axis=1) return df # df["text"] = df["text"].astype(str) # df["label"] = df["label"].astype(str) # return df.to_dict("records") def train_dataloader(self) -> DataLoader: """ Function that loads the train set. """ self._train_dataset = MimicAnnotDataset(self.train) return DataLoader( dataset=self._train_dataset, sampler=RandomSampler(self._train_dataset), batch_size=self.hparams.batch_size, collate_fn=self.classifier.prepare_sample, num_workers=self.hparams.loader_workers, ) # def val_dataloader(self) -> DataLoader: # """ Function that loads the validation set. """ # self._dev_dataset = self.get_mimic_data(self.hparams.dev_csv) # return DataLoader( # dataset=self._dev_dataset, # batch_size=self.hparams.batch_size, # collate_fn=self.classifier.prepare_sample, # num_workers=self.hparams.loader_workers, # ) def test_dataloader(self) -> DataLoader: """ Function that loads the validation set. """ self._test_dataset = self.MimicAnnotDataset(self.test) return DataLoader( dataset=self._test_dataset, batch_size=self.hparams.batch_size, collate_fn=self.classifier.prepare_sample, num_workers=self.hparams.loader_workers, ) # **************** def __init__(self, hparams: Namespace) -> None: super(Classifier,self).__init__() self.hparams = hparams self.batch_size = hparams.batch_size # Build Data module self.data = self.DataModule(self) # build model self.__build_model() # Loss criterion initialization. self.__build_loss() if hparams.nr_frozen_epochs > 0: self.freeze_encoder() else: self._frozen = False self.nr_frozen_epochs = hparams.nr_frozen_epochs def __build_model(self) -> None: """ Init transformer model + tokenizer + classification head.""" #simonlevine/biomed_roberta_base-4096-speedfix' self.transformer = AutoModel.from_pretrained( self.hparams.encoder_model, output_hidden_states=True, # gradient_checkpointing=True, #critical for training speed. ) if self.hparams.transformer_type == 'longformer': logger.warning('Turnin ON gradient checkpointing...') self.transformer = AutoModelForSequenceClassification.from_pretrained( self.hparams.encoder_model, output_hidden_states=True, gradient_checkpointing=True, #critical for training speed. ) else: self.transformer = AutoModelForSequenceClassification.from_pretrained( self.hparams.encoder_model, output_hidden_states=True, ) #others to try: # bert-base-uncased #'emilyalsentzer/Bio_ClinicalBERT' # allenai/biomed_roberta_base # simonlevine/biomed_roberta_base-4096-speedfix' # set the number of features our encoder model will return... 
self.encoder_features = 768 # Tokenizer if self.hparams.transformer_type == 'longformer': self.tokenizer = Tokenizer( pretrained_model=self.hparams.encoder_model, max_tokens = self.hparams.max_tokens_longformer) else: self.hparams.tokenizer = Tokenizer( pretrained_model=self.hparams.encoder_model, max_tokens = self.hparams.max_tokens) #others: #'emilyalsentzer/Bio_ClinicalBERT' 'simonlevine/biomed_roberta_base-4096-speedfix' def __build_loss(self): """ Initializes the loss function/s. """ #FOR SINGLE LABELS --> MSE (linear regression) LOSS (like a regression problem) # For multiple POSSIBLE discrete single labels, CELoss # for many possible categoricla labels, binary cross-entropy (logistic regression for all labels.) self._loss = nn.CrossEntropyLoss() # self._loss = nn.MSELoss() def unfreeze_encoder(self) -> None: """ un-freezes the encoder layer. """ if self._frozen: log.info(f"\n-- Encoder model fine-tuning") for param in self.transformer.parameters(): param.requires_grad = True self._frozen = False def freeze_encoder(self) -> None: """ freezes the encoder layer. """ for param in self.transformer.parameters(): param.requires_grad = False self._frozen = True def predict(self, sample: dict) -> dict: """ Predict function. :param sample: dictionary with the text we want to classify. Returns: Dictionary with the input text and the predicted label. """ if self.training: self.eval() with torch.no_grad(): model_input, _ = self.prepare_sample([sample], prepare_target=False) model_out = self.forward(**model_input) logits = model_out["logits"].numpy() predicted_labels = [ self.data.label_encoder.index_to_token[prediction] for prediction in np.argmax(logits, axis=1) ] sample["predicted_label"] = predicted_labels[0] return sample def forward(self, tokens, lengths): """ Usual pytorch forward function. :param tokens: text sequences [batch_size x src_seq_len] :param lengths: source lengths [batch_size] Returns: Dictionary with model outputs (e.g: logits) """ tokens = tokens[:, : lengths.max()] # When using just one GPU this should not change behavior # but when splitting batches across GPU the tokens have padding # from the entire original batch mask = lengths_to_mask(lengths, device=tokens.device) # Run transformer model. word_embeddings = self.transformer(tokens, mask)[0] # Average Pooling word_embeddings = mask_fill( 0.0, tokens, word_embeddings, self.tokenizer.padding_index ) sentemb = torch.sum(word_embeddings, 1) sum_mask = mask.unsqueeze(-1).expand(word_embeddings.size()).float().sum(1) sentemb = sentemb / sum_mask return {"logits": self.classification_head(sentemb)} def loss(self, predictions: dict, targets: dict) -> torch.tensor: """ Computes Loss value according to a loss function. :param predictions: model specific output. Must contain a key 'logits' with a tensor [batch_size x 1] with model predictions :param labels: Label values [batch_size] Returns: torch.tensor with loss value. """ return self._loss(predictions["logits"], targets["labels"]) def prepare_sample(self, sample: list, prepare_target: bool = True) -> (dict, dict): """ Function that prepares a sample to input the model. :param sample: list of dictionaries. Returns: - dictionary with the expected model inputs. - dictionary with the expected target labels. 
""" sample = collate_tensors(sample) tokens, lengths = self.tokenizer.batch_encode(sample["text"]) inputs = {"tokens": tokens, "lengths": lengths} if not prepare_target: return inputs, {} # Prepare target: try: targets = {"labels": self.data.batch_encode(sample["label"])} return inputs, targets except RuntimeError: raise Exception("Label encoder found an unknown label.") def training_step(self, batch: tuple, batch_nb: int, *args, **kwargs) -> dict: """ Runs one training step. This usually consists in the forward function followed by the loss function. :param batch: The output of your dataloader. :param batch_nb: Integer displaying which batch this is Returns: - dictionary containing the loss and the metrics to be added to the lightning logger. """ inputs, targets = batch model_out = self.forward(**inputs) loss_val = self.loss(model_out, targets) # in DP mode (default) make sure if result is scalar, there's another dim in the beginning if self.trainer.use_dp or self.trainer.use_ddp2: loss_val = loss_val.unsqueeze(0) self.log('loss',loss_val) # can also return just a scalar instead of a dict (return loss_val) return loss_val def test_step(self, batch: tuple, batch_nb: int, *args, **kwargs) -> dict: """ Runs one training step. This usually consists in the forward function followed by the loss function. :param batch: The output of your dataloader. :param batch_nb: Integer displaying which batch this is Returns: - dictionary containing the loss and the metrics to be added to the lightning logger. """ inputs, targets = batch model_out = self.forward(**inputs) loss_val = self.loss(model_out, targets) # in DP mode (default) make sure if result is scalar, there's another dim in the beginning if self.trainer.use_dp or self.trainer.use_ddp2: loss_val = loss_val.unsqueeze(0) self.log('test_loss',loss_val) # can also return just a scalar instead of a dict (return loss_val) return loss_val def configure_optimizers(self): """ Sets different Learning rates for different parameter groups. """ parameters = [ {"params": self.classification_head.parameters()}, { "params": self.transformer.parameters(), "lr": self.hparams.encoder_learning_rate, }, ] optimizer = optim.Adam(parameters, lr=self.hparams.learning_rate) return [optimizer], [] def on_epoch_end(self): """ Pytorch lightning hook """ if self.current_epoch + 1 >= self.nr_frozen_epochs: self.unfreeze_encoder() @classmethod def add_model_specific_args( cls, parser: ArgumentParser ) -> ArgumentParser: """ Parser for Estimator specific arguments/hyperparameters. :param parser: argparse.ArgumentParser Returns: - updated parser """ parser.add_argument( "--encoder_model", default='simonlevine/biomed_roberta_base-4096-speedfix', # 'bert-base-uncased', type=str, help="Encoder model to be used.", ) parser.add_argument( "--transformer_type", default='longformer', type=str, help="Encoder model /tokenizer to be used (has consequences for tokenization and encoding; default = longformer).", ) parser.add_argument( "--single_label_encoding", default='default', type=str, help="How should labels be encoded? 
Default for torch-nlp label-encoder...",
        )
        parser.add_argument(
            "--max_tokens_longformer",
            default=4096,
            type=int,
            help="Max tokens to be considered per instance.",
        )
        parser.add_argument(
            "--max_tokens",
            default=512,
            type=int,
            help="Max tokens to be considered per instance.",
        )
        parser.add_argument(
            "--encoder_learning_rate",
            default=1e-05,
            type=float,
            help="Encoder-specific learning rate.",
        )
        parser.add_argument(
            "--learning_rate",
            default=3e-05,
            type=float,
            help="Classification head learning rate.",
        )
        parser.add_argument(
            "--nr_frozen_epochs",
            default=1,
            type=int,
            help="Number of epochs we want to keep the encoder model frozen.",
        )
        parser.add_argument(
            "--train_csv",
            default="data/intermediary-data/notes2diagnosis-icd-train.csv",
            type=str,
            help="Path to the file containing the train data.",
        )
        parser.add_argument(
            "--dev_csv",
            default="data/intermediary-data/notes2diagnosis-icd-validate.csv",
            type=str,
            help="Path to the file containing the dev data.",
        )
        parser.add_argument(
            "--test_csv",
            default="data/intermediary-data/notes2diagnosis-icd-test.csv",
            type=str,
            help="Path to the file containing the test data.",
        )
        parser.add_argument(
            "--loader_workers",
            default=8,
            type=int,
            help="How many subprocesses to use for data loading. 0 means that \
                the data will be loaded in the main process.",
        )
        return parser
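# *************************
# Hedged sketch of the optimizer/freezing pattern shared by these classifiers:
# configure_optimizers gives the pretrained encoder a lower learning rate than
# the classification head, and the encoder stays frozen for the first
# nr_frozen_epochs. The two nn.Linear modules are illustrative stand-ins for
# the transformer and the head.
import torch.nn as nn
import torch.optim as optim

encoder = nn.Linear(768, 768)      # stand-in for the pretrained transformer
head = nn.Linear(768, 50)          # stand-in for the classification head

def set_requires_grad(module, flag):
    for param in module.parameters():
        param.requires_grad = flag

set_requires_grad(encoder, False)  # cf. freeze_encoder() when nr_frozen_epochs > 0

optimizer = optim.Adam(
    [
        {"params": head.parameters()},                    # uses the default lr below
        {"params": encoder.parameters(), "lr": 1e-05},    # --encoder_learning_rate
    ],
    lr=3e-05,                                             # --learning_rate (head)
)

nr_frozen_epochs = 1
for epoch in range(3):
    # ... run the training epoch here ...
    if epoch + 1 >= nr_frozen_epochs:                     # cf. the on_epoch_end hook
        set_requires_grad(encoder, True)                  # cf. unfreeze_encoder()
# *************************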
class Classifier(pl.LightningModule): """ Sample model to show how to use a Transformer model to classify sentences. :param hparams: ArgumentParser containing the hyperparameters. """ class DataModule(pl.LightningDataModule): def __init__(self, classifier_instance): super().__init__() self.hparams = classifier_instance.hparams self.classifier = classifier_instance # Label Encoder self.label_encoder = LabelEncoder(pd.read_csv( self.hparams.train_csv).label.unique().tolist(), reserved_labels=[]) self.label_encoder.unknown_index = None def read_csv(self, path: str) -> list: """ Reads a comma separated value file. :param path: path to a csv file. :return: List of records as dictionaries """ df = pd.read_csv(path) df = df[["text", "label"]] df["text"] = df["text"].astype(str) df["label"] = df["label"].astype(str) return df.to_dict("records") def train_dataloader(self) -> DataLoader: """ Function that loads the train set. """ self._train_dataset = self.read_csv(self.hparams.train_csv) return DataLoader( dataset=self._train_dataset, sampler=RandomSampler(self._train_dataset), batch_size=self.hparams.batch_size, collate_fn=self.classifier.prepare_sample, num_workers=self.hparams.loader_workers, ) def val_dataloader(self) -> DataLoader: """ Function that loads the validation set. """ self._dev_dataset = self.read_csv(self.hparams.dev_csv) return DataLoader( dataset=self._dev_dataset, batch_size=self.hparams.batch_size, collate_fn=self.classifier.prepare_sample, num_workers=self.hparams.loader_workers, ) def test_dataloader(self) -> DataLoader: """ Function that loads the validation set. """ self._test_dataset = self.read_csv(self.hparams.test_csv) return DataLoader( dataset=self._test_dataset, batch_size=self.hparams.batch_size, collate_fn=self.classifier.prepare_sample, num_workers=self.hparams.loader_workers, ) def __init__(self, hparams: Namespace) -> None: super(Classifier, self).__init__() self.hparams = hparams self.batch_size = hparams.batch_size # Build Data module self.data = self.DataModule(self) # build model self.__build_model() # Loss criterion initialization. self.__build_loss() if hparams.nr_frozen_epochs > 0: self.freeze_encoder() else: self._frozen = False self.nr_frozen_epochs = hparams.nr_frozen_epochs def __build_model(self) -> None: """ Init BERT model + tokenizer + classification head.""" self.bert = AutoModel.from_pretrained(self.hparams.encoder_model, output_hidden_states=True) # set the number of features our encoder model will return... if self.hparams.encoder_model == "google/bert_uncased_L-2_H-128_A-2": self.encoder_features = 128 else: self.encoder_features = 768 # Tokenizer self.tokenizer = Tokenizer("bert-base-uncased") # Classification head self.classification_head = nn.Sequential( nn.Linear(self.encoder_features, self.encoder_features * 2), nn.Tanh(), nn.Linear(self.encoder_features * 2, self.encoder_features), nn.Tanh(), nn.Linear(self.encoder_features, self.data.label_encoder.vocab_size), ) def __build_loss(self): """ Initializes the loss function/s. """ self._loss = nn.CrossEntropyLoss() def unfreeze_encoder(self) -> None: """ un-freezes the encoder layer. """ if self._frozen: log.info(f"\n-- Encoder model fine-tuning") for param in self.bert.parameters(): param.requires_grad = True self._frozen = False def freeze_encoder(self) -> None: """ freezes the encoder layer. """ for param in self.bert.parameters(): param.requires_grad = False self._frozen = True def predict(self, sample: dict) -> dict: """ Predict function. 
:param sample: dictionary with the text we want to classify. Returns: Dictionary with the input text and the predicted label. """ if self.training: self.eval() with torch.no_grad(): model_input, _ = self.prepare_sample([sample], prepare_target=False) model_out = self.forward(**model_input) logits = model_out["logits"].numpy() predicted_labels = [ self.data.label_encoder.index_to_token[prediction] for prediction in np.argmax(logits, axis=1) ] sample["predicted_label"] = predicted_labels[0] return sample def forward(self, tokens, lengths): """ Usual pytorch forward function. :param tokens: text sequences [batch_size x src_seq_len] :param lengths: source lengths [batch_size] Returns: Dictionary with model outputs (e.g: logits) """ tokens = tokens[:, :lengths.max()] # When using just one GPU this should not change behavior # but when splitting batches across GPU the tokens have padding # from the entire original batch mask = lengths_to_mask(lengths, device=tokens.device) # Run BERT model. word_embeddings = self.bert(tokens, mask)[0] # Average Pooling word_embeddings = mask_fill(0.0, tokens, word_embeddings, self.tokenizer.padding_index) sentemb = torch.sum(word_embeddings, 1) sum_mask = mask.unsqueeze(-1).expand( word_embeddings.size()).float().sum(1) sentemb = sentemb / sum_mask return {"logits": self.classification_head(sentemb)} def loss(self, predictions: dict, targets: dict) -> torch.tensor: """ Computes Loss value according to a loss function. :param predictions: model specific output. Must contain a key 'logits' with a tensor [batch_size x 1] with model predictions :param labels: Label values [batch_size] Returns: torch.tensor with loss value. """ return self._loss(predictions["logits"], targets["labels"]) def prepare_sample(self, sample: list, prepare_target: bool = True) -> (dict, dict): """ Function that prepares a sample to input the model. :param sample: list of dictionaries. Returns: - dictionary with the expected model inputs. - dictionary with the expected target labels. """ sample = collate_tensors(sample) tokens, lengths = self.tokenizer.batch_encode(sample["text"]) inputs = {"tokens": tokens, "lengths": lengths} if not prepare_target: return inputs, {} # Prepare target: try: targets = { "labels": self.data.label_encoder.batch_encode(sample["label"]) } return inputs, targets except RuntimeError: raise Exception("Label encoder found an unknown label.") def training_step(self, batch: tuple, batch_nb: int, *args, **kwargs) -> dict: """ Runs one training step. This usually consists in the forward function followed by the loss function. :param batch: The output of your dataloader. :param batch_nb: Integer displaying which batch this is Returns: - dictionary containing the loss and the metrics to be added to the lightning logger. """ inputs, targets = batch model_out = self.forward(**inputs) loss_val = self.loss(model_out, targets) # in DP mode (default) make sure if result is scalar, there's another dim in the beginning if self.trainer.use_dp or self.trainer.use_ddp2: loss_val = loss_val.unsqueeze(0) tqdm_dict = {"train_loss": loss_val} output = OrderedDict({ "loss": loss_val, "progress_bar": tqdm_dict, "log": tqdm_dict }) # can also return just a scalar instead of a dict (return loss_val) return output def validation_step(self, batch: tuple, batch_nb: int, *args, **kwargs) -> dict: """ Similar to the training step but with the model in eval mode. Returns: - dictionary passed to the validation_end function. 
""" inputs, targets = batch model_out = self.forward(**inputs) loss_val = self.loss(model_out, targets) y = targets["labels"] y_hat = model_out["logits"] # acc labels_hat = torch.argmax(y_hat, dim=1) val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0) val_acc = torch.tensor(val_acc) if self.on_gpu: val_acc = val_acc.cuda(loss_val.device.index) # in DP mode (default) make sure if result is scalar, there's another dim in the beginning if self.trainer.use_dp or self.trainer.use_ddp2: loss_val = loss_val.unsqueeze(0) val_acc = val_acc.unsqueeze(0) output = OrderedDict({ "val_loss": loss_val, "val_acc": val_acc, }) # can also return just a scalar instead of a dict (return loss_val) return output def validation_end(self, outputs: list) -> dict: """ Function that takes as input a list of dictionaries returned by the validation_step function and measures the model performance accross the entire validation set. Returns: - Dictionary with metrics to be added to the lightning logger. """ val_loss_mean = 0 val_acc_mean = 0 for output in outputs: val_loss = output["val_loss"] # reduce manually when using dp if self.trainer.use_dp or self.trainer.use_ddp2: val_loss = torch.mean(val_loss) val_loss_mean += val_loss # reduce manually when using dp val_acc = output["val_acc"] if self.trainer.use_dp or self.trainer.use_ddp2: val_acc = torch.mean(val_acc) val_acc_mean += val_acc val_loss_mean /= len(outputs) val_acc_mean /= len(outputs) tqdm_dict = {"val_loss": val_loss_mean, "val_acc": val_acc_mean} result = { "progress_bar": tqdm_dict, "log": tqdm_dict, "val_loss": val_loss_mean, } return result def configure_optimizers(self): """ Sets different Learning rates for different parameter groups. """ parameters = [ { "params": self.classification_head.parameters() }, { "params": self.bert.parameters(), "lr": self.hparams.encoder_learning_rate, }, ] optimizer = optim.Adam(parameters, lr=self.hparams.learning_rate) return [optimizer], [] def on_epoch_end(self): """ Pytorch lightning hook """ if self.current_epoch + 1 >= self.nr_frozen_epochs: self.unfreeze_encoder() @classmethod def add_model_specific_args(cls, parser: ArgumentParser) -> ArgumentParser: """ Parser for Estimator specific arguments/hyperparameters. :param parser: argparse.ArgumentParser Returns: - updated parser """ parser.add_argument( "--encoder_model", default="bert-base-uncased", type=str, help="Encoder model to be used.", ) parser.add_argument( "--encoder_learning_rate", default=1e-05, type=float, help="Encoder specific learning rate.", ) parser.add_argument( "--learning_rate", default=3e-05, type=float, help="Classification head learning rate.", ) parser.add_argument( "--nr_frozen_epochs", default=1, type=int, help="Number of epochs we want to keep the encoder model frozen.", ) parser.add_argument( "--train_csv", default="data/imdb_reviews_train.csv", type=str, help="Path to the file containing the train data.", ) parser.add_argument( "--dev_csv", default="data/imdb_reviews_test.csv", type=str, help="Path to the file containing the dev data.", ) parser.add_argument( "--test_csv", default="data/imdb_reviews_test.csv", type=str, help="Path to the file containing the dev data.", ) parser.add_argument( "--loader_workers", default=8, type=int, help="How many subprocesses to use for data loading. 0 means that \ the data will be loaded in the main process.", ) return parser