Пример #1
0
def evaluate(args, device, model):
    """Run ``model`` over the evaluation folds and compute evaluation metrics.

    Args:
        args: Run options. This function reads ``data_split_path``,
            ``training_folds``, ``evaluation_folds``, ``image_dir``,
            ``image_format``, ``batch_size``, ``model_architecture``,
            ``num_train_epochs`` and ``checkpoint_path``.
        device: Device every tensor of a batch is moved to before the
            forward pass.
        model: Callable as ``model(image)`` returning a 3-tuple whose first
            two items are the prediction output and the embedding.

    Returns:
        A tuple ``(eval_results, embeddings, labels_raw)``:
        ``eval_results`` is a dict with the keys ``'ordinal_aucs'``,
        ``'pairwise_auc'``, ``'multiclass_aucs'``, ``'mse'`` plus whatever
        ``eval_metrics.compute_acc_f1_metrics`` returns as its first item;
        ``embeddings`` is a list with one embedding per evaluated sample;
        ``labels_raw`` is ``np.argmax`` of the collected labels along axis 1.
    """
    # Create a logger for logging model evaluation results
    logger = logging.getLogger(__name__)

    # Create an instance of evaluation data loader
    xray_transform = CenterCrop(2048)
    _, _, eval_labels, eval_dicom_ids, _ = split_tr_eval(
        args.data_split_path, args.training_folds, args.evaluation_folds)
    cxr_dataset = CXRImageDataset(eval_dicom_ids,
                                  eval_labels,
                                  args.image_dir,
                                  transform=xray_transform,
                                  image_format=args.image_format)
    data_loader = DataLoader(cxr_dataset,
                             batch_size=args.batch_size,
                             num_workers=8,
                             pin_memory=True)
    print('Total number of evaluation images: ', len(cxr_dataset))

    # Log evaluation info (arguments are passed lazily to the logger so the
    # message is only formatted when the record is actually emitted)
    logger.info("***** Evaluation info *****")
    logger.info("  Model architecture: %s", args.model_architecture)
    logger.info("  Data split file: %s", args.data_split_path)
    logger.info("  Training folds: %s\t Evaluation folds: %s",
                args.training_folds, args.evaluation_folds)
    logger.info("  Number of evaluation examples: %d", len(cxr_dataset))
    logger.info("  Number of epochs: %d", args.num_train_epochs)
    logger.info("  Batch size: %d", args.batch_size)
    logger.info("  Model checkpoint: %s", args.checkpoint_path)

    # Evaluate the model
    logger.info("***** Evaluating the model *****")

    # For storing labels and model predictions (one entry per sample)
    preds = []
    labels = []
    embeddings = []

    model.eval()
    epoch_iterator = tqdm(data_loader, desc="Iteration")
    for batch in epoch_iterator:
        # Every element of the batch is moved to the device; the third
        # element is not used here.
        batch = tuple(t.to(device, non_blocking=True) for t in batch)
        image, label, _ = batch
        with torch.no_grad():
            output, embedding, _ = model(image)
        # Iterating a numpy array yields its rows, i.e. one item per sample.
        preds.extend(output.detach().cpu().numpy())
        labels.extend(label.detach().cpu().numpy())
        embeddings.extend(embedding.detach().cpu().numpy())

    # NOTE(review): argmax over axis 1 suggests the labels are one-hot
    # (or otherwise vector-valued) per sample -- confirm against the dataset.
    labels_raw = np.argmax(labels, axis=1)
    eval_results = {}

    ordinal_aucs = eval_metrics.compute_ordinal_auc(labels, preds)
    eval_results['ordinal_aucs'] = ordinal_aucs

    pairwise_aucs = eval_metrics.compute_pairwise_auc(labels, preds)
    eval_results['pairwise_auc'] = pairwise_aucs

    multiclass_aucs = eval_metrics.compute_multiclass_auc(labels, preds)
    eval_results['multiclass_aucs'] = multiclass_aucs

    eval_results['mse'] = eval_metrics.compute_mse(labels_raw, preds)

    results_acc_f1, _, _ = eval_metrics.compute_acc_f1_metrics(
        labels_raw, preds)
    eval_results.update(results_acc_f1)

    # "%.4f" prints four decimals; the previous "%4f" only set a minimum
    # field width of 4 and still printed the default six decimals.
    for idx, name in enumerate(('0v123', '01v23', '012v3')):
        logger.info("  AUC(%s) = %.4f", name,
                    eval_results['ordinal_aucs'][idx])

    for name in ('0v1', '0v2', '0v3', '1v2', '1v3', '2v3'):
        logger.info("  AUC(%s) = %.4f", name,
                    eval_results['pairwise_auc'][name])

    for idx, name in enumerate(('0v123', '1v023', '2v013', '3v012')):
        logger.info("  AUC(%s) = %.4f", name,
                    eval_results['multiclass_aucs'][idx])

    logger.info("  MSE = %.4f", eval_results['mse'])

    logger.info("  Macro_F1 = %.4f", eval_results['macro_f1'])
    logger.info("  Accuracy = %.4f", eval_results['accuracy'])

    return eval_results, embeddings, labels_raw
Пример #2
0
    def eval(self, device, args, checkpoint_path):
        """Load a checkpoint, run it over the evaluation set and score it.

        Args:
            device: Device the model and the image batches are moved to.
            args: Run options. This method reads ``data_dir``,
                ``dataset_metadata``, ``label_key`` and ``batch_size``.
            checkpoint_path: Passed to ``build_model`` as the checkpoint to
                load; the resulting model replaces ``self.model``.

        Returns:
            A tuple ``(inference_results, eval_results)``.
            ``inference_results`` is a dict with the per-sample lists
            ``'all_preds_prob'`` (first model output), ``'all_preds_logit'``
            (last model output), ``'all_labels'`` and the array
            ``'all_preds_class'`` (argmax of ``'all_preds_prob'`` along
            axis 1). ``eval_results`` always holds ``'aucs'``; when
            ``args.label_key == '14d_hf'`` it is additionally updated with
            the first item returned by
            ``eval_metrics.compute_acc_f1_metrics``.
        """
        # Load the checkpoint (essentially create a "different" model)
        self.model = build_model(model_name=self.model_name,
                                 output_channels=self.output_channels,
                                 checkpoint_path=checkpoint_path)

        # Create an instance of evaluation data loader
        print('***** Instantiate a data loader *****')
        dataset = build_evaluation_dataset(
            data_dir=args.data_dir,
            img_size=self.img_size,
            dataset_metadata=args.dataset_metadata,
            label_key=args.label_key)
        # Evaluation does not need shuffling: the metrics below do not depend
        # on sample order, and a fixed order keeps the returned per-sample
        # lists reproducible between runs.
        data_loader = DataLoader(dataset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=8,
                                 pin_memory=True)
        print(f'Total number of evaluation images: {len(dataset)}')

        # Evaluate the model
        print('***** Evaluate the model *****')
        self.model = self.model.to(device)
        self.model.eval()

        # For storing labels and model predictions (one entry per sample)
        all_preds_prob = []
        all_preds_logit = []
        all_labels = []

        epoch_iterator = tqdm(data_loader, desc="Iteration")
        for images, labels, _image_ids in epoch_iterator:
            images = images.to(device, non_blocking=True)

            with torch.no_grad():
                outputs = self.model(images)

            # The first output is taken as probabilities and the last one as
            # logits.
            preds_prob = outputs[0].detach().cpu().numpy()
            preds_logit = outputs[-1].detach().cpu().numpy()
            # Labels are only consumed as numpy arrays, so they are never
            # transferred to the device.
            labels = labels.detach().cpu().numpy()

            # Iterating a numpy array yields its rows, i.e. one per sample.
            all_preds_prob.extend(preds_prob)
            all_preds_logit.extend(preds_logit)
            all_labels.extend(labels)

        all_preds_class = np.argmax(all_preds_prob, axis=1)
        inference_results = {
            'all_preds_prob': all_preds_prob,
            'all_preds_class': all_preds_class,
            'all_preds_logit': all_preds_logit,
            'all_labels': all_labels
        }
        eval_results = {}

        if args.label_key == '14d_hf':
            eval_results['aucs'] = eval_metrics.compute_binary_auc(
                all_labels, all_preds_prob)

            results_acc_f1, _, _ = eval_metrics.compute_acc_f1_metrics(
                all_labels, all_preds_prob)
            eval_results.update(results_acc_f1)
        else:
            # For every other label key the AUCs are computed on an
            # element-wise sigmoid of the logits.
            # NOTE(review): ``inference_results`` still carries the model's
            # first output as 'all_preds_prob', not these sigmoid values --
            # confirm that this is intended.
            sigmoid_probs = [
                1 / (1 + np.exp(-logit)) for logit in all_preds_logit
            ]
            eval_results['aucs'] = eval_metrics.compute_multiclass_auc(
                all_labels, sigmoid_probs)

        return inference_results, eval_results