def test_confusion_matrix():
    target = (torch.arange(120) % 3).view(-1, 1)
    pred = target.clone()

    # perfect predictions give the identity matrix after row-normalization
    cm = confusion_matrix(pred, target, normalize=True)
    assert torch.allclose(
        cm, torch.tensor([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]))

    # predicting class 0 everywhere collapses every row onto the first column
    pred = torch.zeros_like(pred)
    cm = confusion_matrix(pred, target, normalize=True)
    assert torch.allclose(
        cm, torch.tensor([[1., 0., 0.], [1., 0., 0.], [1., 0., 0.]]))
def validation_epoch_end(self, validation_step_outputs):
    if self.logger and self.current_epoch > 0 and self.current_epoch % 5 == 0:
        epoch_preds = torch.cat([x["preds"] for x in validation_step_outputs])
        epoch_targets = torch.cat([x["labels"] for x in validation_step_outputs])
        cm = confusion_matrix(epoch_preds, epoch_targets,
                              num_classes=self.hparams.num_classes).cpu().numpy()
        class_names = getattr(self.train_dataloader().dataset, "classes", None)
        fig = generate_confusion_matrix(cm, class_names=class_names)
        self.logger.experiment.log({"confusion_matrix": fig})
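# The generate_confusion_matrix() helper called above is not shown in this
# snippet. Below is a minimal sketch of what such a helper might look like,
# assuming it only needs to return a matplotlib Figure that the logger's
# experiment.log() (e.g. a WandbLogger) can serialize. The function name and
# signature mirror the call site; the body is an assumption, not the original
# implementation.
import matplotlib.pyplot as plt
import numpy as np


def generate_confusion_matrix(cm: np.ndarray, class_names=None):
    """Render a confusion matrix as a matplotlib Figure."""
    fig, ax = plt.subplots(figsize=(6, 6))
    im = ax.imshow(cm, interpolation="nearest", cmap="Blues")
    fig.colorbar(im, ax=ax)

    n = cm.shape[0]
    ticks = np.arange(n)
    labels = class_names if class_names is not None else [str(i) for i in ticks]
    ax.set_xticks(ticks)
    ax.set_yticks(ticks)
    ax.set_xticklabels(labels, rotation=45, ha="right")
    ax.set_yticklabels(labels)
    ax.set_xlabel("Predicted label")
    ax.set_ylabel("True label")

    # annotate each cell with its value
    for i in range(n):
        for j in range(n):
            ax.text(j, i, format(cm[i, j], "g"), ha="center", va="center")

    fig.tight_layout()
    return fig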
def on_test_epoch_end(self):
    # confusion matrix over the accumulated (pred, target) pairs
    cm = confusion_matrix(self.test_pred[:, 0], self.test_pred[:, 1],
                          normalize=True, num_classes=10)
    plot_confusion_matrix(cm.cpu().numpy(),
                          target_names=[str(i) for i in range(10)],
                          title='Confusion Matrix',
                          normalize=False,  # cm is already normalized above
                          cmap=plt.get_cmap('bwr'),
                          args=self.args)
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """
    Actual metric computation

    Args:
        pred: predicted labels
        target: ground truth labels

    Returns:
        A Tensor with the confusion matrix.
    """
    return confusion_matrix(pred=pred, target=target, normalize=self.normalize)
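# A minimal usage sketch for the forward() above, assuming it belongs to a
# small metric-style wrapper that stores a `normalize` flag at construction
# time and uses the older pytorch_lightning.metrics functional API seen in the
# other snippets. The class name and __init__ below are assumptions made for
# illustration; only forward() comes from the original code.
import torch
from pytorch_lightning.metrics.functional.classification import confusion_matrix


class ConfusionMatrixMetric(torch.nn.Module):
    def __init__(self, normalize: bool = False):
        super().__init__()
        self.normalize = normalize

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        return confusion_matrix(pred=pred, target=target, normalize=self.normalize)


# Example: three classes, row-normalized output (rows are true classes,
# columns are predictions, matching the behaviour asserted in the test above).
metric = ConfusionMatrixMetric(normalize=True)
pred = torch.tensor([0, 1, 2, 2])
target = torch.tensor([0, 1, 1, 2])
print(metric(pred, target))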
def validation_epoch_end(self, outputs):
    avg_loss = torch.stack([b['loss'] for b in outputs]).mean()
    conf_mtx = confusion_matrix(torch.cat([b['preds'] for b in outputs]),
                                torch.cat([b['labels'] for b in outputs]),
                                normalize=False,
                                num_classes=5)
    avg_acc = torch.diag(conf_mtx).sum() / conf_mtx.sum()
    self.logger.experiment.add_scalar('val_avg_loss', avg_loss, self.current_epoch)
    self.logger.experiment.add_scalar('val_avg_acc', avg_acc, self.current_epoch)
def training_epoch_end(self, outputs):
    ##### Print Progress ###############
    if self.trainer.current_epoch < self.hparams.burn_in_epochs:
        print('[PROGRESS] Burning in classifier epoch %i' % self.current_epoch)

    ##### Calculate Metrics ############
    avg_loss = torch.stack([b['loss'] for b in outputs]).mean()
    conf_mtx = confusion_matrix(torch.cat([b['preds'] for b in outputs]),
                                torch.cat([b['labels'] for b in outputs]),
                                normalize=False,
                                num_classes=5)
    avg_acc = torch.diag(conf_mtx).sum() / conf_mtx.sum()
    self.logger.experiment.add_scalar('train_avg_loss', avg_loss, self.current_epoch)
    self.logger.experiment.add_scalar('train_avg_acc', avg_acc, self.current_epoch)
def test_step(self, batch: tuple, batch_nb: int, *args, **kwargs) -> dict:
    """
    Runs one test step. This usually consists of the forward function
    followed by the loss function.

    :param batch: The output of your dataloader.
    :param batch_nb: Index of the current batch.

    Returns:
        - dictionary containing the loss and the metrics to be added to the
          lightning logger.
    """
    inputs, targets = batch
    model_out = self.forward(**inputs)
    loss_val = self.loss(model_out, targets)

    # in DP/DDP2 mode, if the result is a scalar, add a leading dimension
    if self.trainer.use_dp or self.trainer.use_ddp2:
        loss_val = loss_val.unsqueeze(0)
    self.log('test_loss', loss_val)

    y_hat = model_out['logits']
    labels_hat = torch.argmax(y_hat, dim=1)
    y = targets['labels']

    f1 = metrics.f1_score(labels_hat, y, class_reduction='weighted')
    prec = metrics.precision(labels_hat, y, class_reduction='weighted')
    recall = metrics.recall(labels_hat, y, class_reduction='weighted')
    acc = metrics.accuracy(labels_hat, y, class_reduction='weighted')

    self.log('test_batch_prec', prec)
    self.log('test_batch_f1', f1)
    self.log('test_batch_recall', recall)
    self.log('test_batch_weighted_acc', acc)

    cm = metrics.confusion_matrix(pred=labels_hat, target=y, normalize=False)
    self.test_conf_matrices.append(cm)
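# The per-batch matrices collected in self.test_conf_matrices above are not
# aggregated in this snippet. One way to reduce them at the end of testing is
# sketched below, assuming every batch contains the same set of classes so the
# matrices share a shape (since num_classes is not fixed at the call site).
# This hook is an illustrative addition to the same LightningModule, not part
# of the original code.
def test_epoch_end(self, outputs):
    total_cm = torch.stack(self.test_conf_matrices).sum(dim=0)
    # row-normalize so each true class sums to 1
    norm_cm = total_cm / total_cm.sum(dim=1, keepdim=True).clamp(min=1)
    self.log('test_acc_from_cm', torch.diag(total_cm).sum() / total_cm.sum())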
def test(args, dm, net):
    from pytorch_lightning.metrics.functional.classification import confusion_matrix

    model = net.load_from_checkpoint(checkpoint_path=args.ckpt_path,
                                     num_channel=dm.n_channels,
                                     num_class=dm.n_classes)
    model.eval()

    dm.config["batch_size"] = 1
    dm.prepare_data()
    dm.setup()

    trainer = pl.Trainer()
    trainer.test(model, datamodule=dm)

    cmat = confusion_matrix(torch.stack(model.predictions),
                            torch.stack(model.targets),
                            num_classes=dm.n_classes)
    utils.plot_confusion_matrix(matrix=cmat.numpy(),
                                classes=dm.raw_Y.columns,
                                figure_name="./figures/cmat.jpg")
def test_epoch_end(self, outputs):
    if self.incorrect_type != 'boundary':
        ##### Confusion Matrix #####
        conf_mtx = confusion_matrix(
            torch.cat([b['preds'] for b in outputs]),
            torch.cat([b['labels'] for b in outputs]),
            normalize=False,
            num_classes=5)

        ##### Normalized Confusion Matrix #####
        conf_mtx_normalized = confusion_matrix(
            torch.cat([b['preds'] for b in outputs]),
            torch.cat([b['labels'] for b in outputs]),
            normalize=True,
            num_classes=5)

        ##### Weighted Confusion Matrix #####
        conf_mtx_weighted = conf_mtx.clone()
        for c, w in enumerate(self.weights):
            conf_mtx_weighted[c, :] *= w

        ##### ACCURACY #####
        accuracy = torch.diag(conf_mtx).sum() / conf_mtx.sum()
        accuracy_weighted = torch.diag(conf_mtx_weighted).sum() / conf_mtx_weighted.sum()

        ##### AUC_SCORE #####
        roc_results = multiclass_roc(
            torch.cat([b['logits'] for b in outputs]),
            torch.cat([b['labels'] for b in outputs]),
            num_classes=5)
        AUROC_str = ''
        AUROC_list = {}
        for cls, roc_cls in enumerate(roc_results):
            fpr, tpr, threshold = roc_cls
            self.logger.experiment.add_scalar(f'val_AUC[{cls}]', auc(fpr, tpr),
                                              self.current_epoch)
            AUROC_str += '\tAUC_SCORE[CLS %d]: \t%.4f\n' % (cls, auc(fpr, tpr))
            AUROC_list['AUC_SCORE[CLS %d]' % cls] = auc(fpr, tpr)

        ##### F1 #####
        f1_score = f1(torch.cat([b['preds'] for b in outputs]),
                      torch.cat([b['labels'] for b in outputs]),
                      num_classes=5)

        ##### Average Precision #####
        # TODO

        ##### PRINT RESULTS #####
        print('=' * 100)
        print(f'[MODEL NAME]: {self.model_name} \t [INCORRECT TYPE]: {self.incorrect_type}')
        print('RESULTS:')
        print('\tAccuracy: \t\t%.4f' % accuracy)
        print('\tWeighted Accuracy: \t%.4f' % accuracy_weighted)
        print('\tF1 Score: \t\t%.4f' % f1_score)
        print(AUROC_str)

        self.metrics_result[self.incorrect_type][self.model_name] = {
            'Accuracy': round(float(accuracy), 4),
            'Weighted Accuracy': round(float(accuracy_weighted), 4),
            'F1_score': round(float(f1_score), 4)
        }
        for key, val in AUROC_list.items():
            self.metrics_result[self.incorrect_type][self.model_name].update(
                {key: round(float(val), 4)})

        print('Confusion Matrix')
        fig, ax = plt.subplots(figsize=(4, 4))
        sn.heatmap(conf_mtx.cpu(), annot=True, cbar=False,
                   annot_kws={"size": 15}, fmt='g', cmap='mako')
        plt.show()

        fig, ax = plt.subplots(figsize=(4, 4))
        sn.heatmap(conf_mtx_normalized.cpu(), annot=True, cbar=False,
                   annot_kws={"size": 12}, fmt='.2f', cmap='mako')
        plt.show()
        print('=' * 100)
    else:
        tol_correct = 0
        tol_samples = 0
        tol_drop = 0
        for batch in outputs:
            preds = batch['preds']
            labels = batch['labels']
            slope_id = batch['doc_ids']

            ##### Compare predictions against the user-verified slope labels #####
            for idx, slop_idx in enumerate(slope_id):
                agree_by_user = bool(
                    slope_df[slope_df['slope_id'] == slop_idx.item()]
                    ['sentiment_correct'].values[0])
                possible_classes = slope_df[
                    slope_df['slope_id'] == slop_idx.item()]['label_from_score'].values[0]
                pred_class = preds[idx]
                # difference between predicted and true label
                diff = torch.abs(pred_class - possible_classes)

                if agree_by_user:    # the stored label is correct
                    if diff == 0:    # correct prediction
                        tol_correct += 1
                        tol_samples += 1
                    elif diff == 1:  # discard
                        tol_drop += 1
                    else:            # wrong prediction
                        tol_samples += 1
                else:                # the stored label is incorrect
                    if diff == 0:    # wrong prediction
                        tol_samples += 1
                    elif diff == 1:  # discard
                        tol_drop += 1
                    else:            # correct prediction
                        tol_correct += 1
                        tol_samples += 1

        boundary_accuracy = round(tol_correct / tol_samples, 4)
        self.metrics_result[self.incorrect_type][self.model_name] = {}
        self.metrics_result[self.incorrect_type][
            self.model_name]['boundary_acc'] = boundary_accuracy
        self.metrics_result[self.incorrect_type][
            self.model_name]['total_drop_sample'] = tol_drop

        print('=' * 100)
        print(f'[MODEL NAME]: {self.model_name} \t [INCORRECT TYPE]: {self.incorrect_type}')
        print('\tBoundary Accuracy: \t\t%.4f' % boundary_accuracy)
        print('\tDrop Total Sample: \t\t%d' % tol_drop)
def run_epoch(model, dataloader, criterion, optimizer=None, epoch=0,
              scheduler=None, device='cpu'):
    import pytorch_lightning.metrics.functional.classification as clmetrics
    from pytorch_lightning.metrics import Precision, Accuracy, Recall
    from sklearn.metrics import roc_auc_score, average_precision_score

    metrics = Accumulator()
    cnt = 0
    total_steps = len(dataloader)
    steps = 0
    running_corrects = 0
    accuracy = Accuracy()
    precision = Precision(num_classes=2)
    recall = Recall(num_classes=2)
    preds_epoch = []
    labels_epoch = []

    for inputs, labels in dataloader:
        steps += 1
        inputs = inputs.to(device)                       # torch.Size([2, 1, 224, 224])
        labels = labels.to(device).unsqueeze(1).float()  # torch.Size([2, 1])
        outputs = model(inputs)                          # [batch_size, nb_classes]
        loss = criterion(outputs, labels)

        if optimizer:
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        preds_epoch.extend(torch.sigmoid(outputs).tolist())
        labels_epoch.extend(labels.tolist())

        threshold = 0.5
        prob = (torch.sigmoid(outputs) > threshold).long()
        conf = torch.flatten(
            clmetrics.confusion_matrix(prob, labels, num_classes=2))
        tn, fp, fn, tp = conf

        metrics.add_dict({
            'data_count': len(inputs),
            'loss': loss.item() * len(inputs),
            'tp': tp.item(),
            'tn': tn.item(),
            'fp': fp.item(),
            'fn': fn.item(),
        })
        cnt += len(inputs)

        if scheduler:
            scheduler.step()

        del outputs, loss, inputs, labels, prob

    logger.info(f'cnt = {cnt}')
    metrics['loss'] /= cnt

    def safe_div(x, y):
        if y == 0:
            return 0
        return x / y

    _TP, _TN, _FP, _FN = metrics['tp'], metrics['tn'], metrics['fp'], metrics['fn']
    acc = (_TP + _TN) / cnt
    sen = safe_div(_TP, (_TP + _FN))
    spe = safe_div(_TN, (_FP + _TN))
    prec = safe_div(_TP, (_TP + _FP))

    metrics.add('accuracy', acc)
    metrics.add('sensitivity', sen)
    metrics.add('specificity', spe)
    metrics.add('precision', prec)

    auc = roc_auc_score(labels_epoch, preds_epoch)
    aupr = average_precision_score(labels_epoch, preds_epoch)
    metrics.add('auroc', auc)
    metrics.add('aupr', aupr)

    logger.info(metrics)
    return metrics, preds_epoch, labels_epoch
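# Accumulator is used above but not defined in this snippet. Below is a
# minimal, dict-backed sketch that supports exactly the calls run_epoch makes
# (add, add_dict, item access/assignment, and a printable summary). The
# interface is inferred from the call sites; the implementation is an
# assumption, not the original helper.
from collections import defaultdict


class Accumulator:
    def __init__(self):
        self.metrics = defaultdict(float)

    def add(self, key, value):
        # accumulate a single scalar under `key`
        self.metrics[key] += value

    def add_dict(self, values):
        # accumulate every key/value pair in `values`
        for key, value in values.items():
            self.add(key, value)

    def __getitem__(self, key):
        return self.metrics[key]

    def __setitem__(self, key, value):
        self.metrics[key] = value

    def __repr__(self):
        return ', '.join(f'{k}={v:.4f}' for k, v in self.metrics.items())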