def test(self,
         test_metrics_classification: Tuple[nn.Module, ...] = (
                 validation_metric.ClassificationAccuracy(),),
         test_metrics_bounding_box: Tuple[nn.Module, ...] = (
                 nn.L1Loss(), nn.MSELoss(), validation_metric.BoundingBoxIoU(),
                 validation_metric.BoundingBoxGIoU()),
         test_metrics_segmentation: Tuple[nn.Module, ...] = (
                 validation_metric.Accuracy(), validation_metric.Precision(),
                 validation_metric.Recall(), validation_metric.F1(), validation_metric.IoU(),
                 validation_metric.MIoU(), validation_metric.Dice(), validation_metric.CellIoU(),
                 validation_metric.MeanAveragePrecision(), validation_metric.InstancesAccuracy())
         ) -> None:
    """
    Test method: evaluates DETR on the test set, plots the predictions of every test batch and prints/logs
    the averaged metrics.

    Predictions are matched to labels with self.loss_function.matcher and permuted accordingly. Every metric
    is then computed per sample on the predictions truncated to that sample's number of labels. Metric
    values are grouped by the metric module's class name, so two modules of the same class share one list.
    Segmentation averages ignore NaN entries. For the first sample of each batch, six plot files
    "test_plot_<batch index>_*.png" are written to self.path_save_plots.
    :param test_metrics_classification: (Tuple[nn.Module, ...]) Test modules for classification
    :param test_metrics_bounding_box: (Tuple[nn.Module, ...]) Test modules for bounding boxes
    :param test_metrics_segmentation: (Tuple[nn.Module, ...]) Test modules for segmentation
    """
    # DETR to device
    self.detr.to(self.device)
    # DETR into eval mode
    self.detr.eval()
    # Per-metric lists of values, keyed by the metric module's class name
    metrics_classification = dict()
    metrics_bounding_box = dict()
    metrics_segmentation = dict()
    # Main loop over the test set
    for index, batch in enumerate(self.test_dataset):
        # Get data from batch
        images, instance_labels, bounding_box_labels, class_labels = batch
        # Data to device
        images = images.to(self.device)
        instance_labels = misc.iterable_to_device(instance_labels, device=self.device)
        bounding_box_labels = misc.iterable_to_device(bounding_box_labels, device=self.device)
        class_labels = misc.iterable_to_device(class_labels, device=self.device)
        # Get prediction
        class_predictions, bounding_box_predictions, instance_predictions = self.detr(images)
        # Perform matching
        matching_indexes = self.loss_function.matcher(
            class_predictions, bounding_box_predictions, class_labels, bounding_box_labels)
        # Apply permutation to labels and predictions
        class_predictions, class_labels = self.loss_function.apply_permutation(
            prediction=class_predictions, label=class_labels, indexes=matching_indexes)
        bounding_box_predictions, bounding_box_labels = self.loss_function.apply_permutation(
            prediction=bounding_box_predictions, label=bounding_box_labels, indexes=matching_indexes)
        instance_predictions, instance_labels = self.loss_function.apply_permutation(
            prediction=instance_predictions, label=instance_labels, indexes=matching_indexes)
        for batch_index in range(len(class_labels)):
            # Calc test metrics for classification (predictions truncated to the number of labels)
            for test_metric_classification in test_metrics_classification:
                metric = test_metric_classification(
                    class_predictions[batch_index, :class_labels[batch_index].shape[0]].argmax(dim=-1),
                    class_labels[batch_index].argmax(dim=-1)).item()
                metrics_classification.setdefault(
                    test_metric_classification.__class__.__name__, []).append(metric)
            # Calc test metrics for bounding boxes (both converted from xcycwh to x0y0x1y1 format)
            for test_metric_bounding_box in test_metrics_bounding_box:
                metric = test_metric_bounding_box(
                    misc.bounding_box_xcycwh_to_x0y0x1y1(
                        bounding_box_predictions[batch_index, :bounding_box_labels[batch_index].shape[0]]),
                    misc.bounding_box_xcycwh_to_x0y0x1y1(bounding_box_labels[batch_index])).item()
                metrics_bounding_box.setdefault(
                    test_metric_bounding_box.__class__.__name__, []).append(metric)
            # Calc test metrics for segmentation
            for test_metric_segmentation in test_metrics_segmentation:
                metric = test_metric_segmentation(
                    instance_predictions[batch_index, :instance_labels[batch_index].shape[0]],
                    instance_labels[batch_index],
                    class_label=class_labels[batch_index].argmax(dim=-1)).item()
                metrics_segmentation.setdefault(
                    test_metric_segmentation.__class__.__name__, []).append(metric)
        # Plot the predictions of the first sample of the batch
        object_classes = class_predictions[0].argmax(dim=-1).cpu().detach()
        # Skip plotting when the first sample has no predictions at all
        if object_classes.shape[0] > 0:
            # Keep predictions whose argmax class is > 0 (presumably class 0 means "no object" -- TODO confirm)
            object_indexes = torch.from_numpy(np.argwhere(object_classes.numpy() > 0)[:, 0])
            object_class_labels = object_classes[object_indexes]
            # Binarize the predicted instance maps with a fixed 0.5 threshold
            instances = (instance_predictions[0][object_indexes] > 0.5).float()
            # Convert kept boxes from relative xcycwh to absolute x0y0x1y1 pixel coordinates
            bounding_boxes = misc.relative_bounding_box_to_absolute(
                misc.bounding_box_xcycwh_to_x0y0x1y1(
                    bounding_box_predictions[0, object_indexes].cpu().clone().detach()),
                height=images.shape[-2], width=images.shape[-1])
            misc.plot_instance_segmentation_overlay_instances_bb_classes(
                image=images[0], instances=instances, bounding_boxes=bounding_boxes,
                class_labels=object_class_labels, show=False, save=True,
                file_path=os.path.join(self.path_save_plots, "test_plot_{}_is_bb_c.png".format(index)))
            misc.plot_instance_segmentation_overlay_instances_bb_classes(
                image=images[0], instances=instances, bounding_boxes=bounding_boxes,
                class_labels=object_class_labels, show=False, save=True, show_class_label=False,
                file_path=os.path.join(self.path_save_plots, "test_plot_{}_is_bb.png".format(index)))
            misc.plot_instance_segmentation_overlay_instances(
                image=images[0], instances=instances, class_labels=object_class_labels, show=False,
                save=True,
                file_path=os.path.join(self.path_save_plots, "test_plot_{}_is.png".format(index)))
            misc.plot_instance_segmentation_overlay_bb_classes(
                image=images[0], bounding_boxes=bounding_boxes, class_labels=object_class_labels,
                show=False, save=True,
                file_path=os.path.join(self.path_save_plots, "test_plot_{}_bb_c.png".format(index)))
            misc.plot_instance_segmentation_labels(
                instances=instances, bounding_boxes=bounding_boxes, class_labels=object_class_labels,
                show=False, save=True,
                file_path=os.path.join(self.path_save_plots, "test_plot_{}_bb_no_overlay_.png".format(index)),
                show_class_label=False, white_background=True)
            misc.plot_instance_segmentation_map_label(
                instances=instances, class_labels=object_class_labels, show=False, save=True,
                file_path=os.path.join(self.path_save_plots, "test_plot_{}_no_overlay.png".format(index)),
                white_background=True)
    # Average metrics, print them and save them in logs
    for metric_name, values in metrics_classification.items():
        mean_value = float(np.mean(values))
        print(metric_name + "_classification_test=", mean_value)
        self.logger.log(metric_name=metric_name + "_classification_test", value=mean_value)
    for metric_name, values in metrics_bounding_box.items():
        mean_value = float(np.mean(values))
        print(metric_name + "_bounding_box_test=", mean_value)
        self.logger.log(metric_name=metric_name + "_bounding_box_test", value=mean_value)
    for metric_name, values in metrics_segmentation.items():
        metric_values = np.array(values)
        # NaN entries (metric undefined for a sample) are excluded from the average
        mean_value = float(np.mean(metric_values[~np.isnan(metric_values)]))
        print(metric_name + "_segmentation_test=", mean_value)
        self.logger.log(metric_name=metric_name + "_segmentation_test", value=mean_value)
    # Save metrics
    self.logger.save_metrics(path=self.path_save_metrics)
def validate(self,
             validation_metrics_classification: Tuple[nn.Module, ...] = (
                     validation_metric.ClassificationAccuracy(),),
             validation_metrics_bounding_box: Tuple[nn.Module, ...] = (
                     nn.L1Loss(), nn.MSELoss(), validation_metric.BoundingBoxIoU(),
                     validation_metric.BoundingBoxGIoU()),
             validation_metrics_segmentation: Tuple[nn.Module, ...] = (
                     validation_metric.Accuracy(), validation_metric.Precision(),
                     validation_metric.Recall(), validation_metric.F1(), validation_metric.IoU(),
                     validation_metric.MIoU(), validation_metric.Dice(), validation_metric.CellIoU(),
                     validation_metric.MeanAveragePrecision(), validation_metric.InstancesAccuracy()),
             epoch: int = -1,
             number_of_plots: int = 5,
             train: bool = False) -> None:
    """
    Validation method: evaluates DETR on the validation set, logs the averaged metrics and, when training,
    saves the model with the best validation mIoU.

    Predictions are matched to labels with self.loss_function.matcher and permuted accordingly. Every metric
    is then computed per sample on the predictions truncated to that sample's number of labels. Metric
    values are grouped by the metric module's class name. Segmentation averages ignore NaN entries, and
    the same NaN-free average is used for the best-model check. Plots are written for up to
    number_of_plots randomly chosen batch indices, never more than the number of validation batches.
    :param validation_metrics_classification: (Tuple[nn.Module, ...]) Validation modules for classification
    :param validation_metrics_bounding_box: (Tuple[nn.Module, ...]) Validation modules for bounding boxes
    :param validation_metrics_segmentation: (Tuple[nn.Module, ...]) Validation modules for segmentation
    :param epoch: (int) Current epoch (epoch + 1 is used in the plot file names)
    :param number_of_plots: (int) Number of validation plots to be produced
    :param train: (bool) If set, the model is saved whenever the averaged validation MIoU beats
                  self.best_miou
    """
    # DETR to device
    self.detr.to(self.device)
    # DETR into eval mode
    self.detr.eval()
    # Per-metric lists of values, keyed by the metric module's class name
    metrics_classification = dict()
    metrics_bounding_box = dict()
    metrics_segmentation = dict()
    # Init indexes of batches to be plotted; clamp the sample size because choice(replace=False)
    # raises a ValueError when asked for more elements than exist
    number_of_batches = len(self.validation_dataset)
    plot_indexes = np.random.choice(np.arange(0, number_of_batches),
                                    min(number_of_plots, number_of_batches), replace=False)
    # Main loop over the validation set
    for index, batch in enumerate(self.validation_dataset):
        # Get data from batch
        images, instance_labels, bounding_box_labels, class_labels = batch
        # Data to device
        images = images.to(self.device)
        instance_labels = misc.iterable_to_device(instance_labels, device=self.device)
        bounding_box_labels = misc.iterable_to_device(bounding_box_labels, device=self.device)
        class_labels = misc.iterable_to_device(class_labels, device=self.device)
        # Get prediction
        class_predictions, bounding_box_predictions, instance_predictions = self.detr(images)
        # Perform matching
        matching_indexes = self.loss_function.matcher(
            class_predictions, bounding_box_predictions, class_labels, bounding_box_labels)
        # Apply permutation to labels and predictions
        class_predictions, class_labels = self.loss_function.apply_permutation(
            prediction=class_predictions, label=class_labels, indexes=matching_indexes)
        bounding_box_predictions, bounding_box_labels = self.loss_function.apply_permutation(
            prediction=bounding_box_predictions, label=bounding_box_labels, indexes=matching_indexes)
        instance_predictions, instance_labels = self.loss_function.apply_permutation(
            prediction=instance_predictions, label=instance_labels, indexes=matching_indexes)
        for batch_index in range(len(class_labels)):
            # Calc validation metrics for classification (predictions truncated to the number of labels)
            for validation_metric_classification in validation_metrics_classification:
                metric = validation_metric_classification(
                    class_predictions[batch_index, :class_labels[batch_index].shape[0]].argmax(dim=-1),
                    class_labels[batch_index].argmax(dim=-1)).item()
                metrics_classification.setdefault(
                    validation_metric_classification.__class__.__name__, []).append(metric)
            # Calc validation metrics for bounding boxes (both converted from xcycwh to x0y0x1y1 format)
            for validation_metric_bounding_box in validation_metrics_bounding_box:
                metric = validation_metric_bounding_box(
                    misc.bounding_box_xcycwh_to_x0y0x1y1(
                        bounding_box_predictions[batch_index, :bounding_box_labels[batch_index].shape[0]]),
                    misc.bounding_box_xcycwh_to_x0y0x1y1(bounding_box_labels[batch_index])).item()
                metrics_bounding_box.setdefault(
                    validation_metric_bounding_box.__class__.__name__, []).append(metric)
            # Calc validation metrics for segmentation
            for validation_metric_segmentation in validation_metrics_segmentation:
                metric = validation_metric_segmentation(
                    instance_predictions[batch_index, :instance_labels[batch_index].shape[0]],
                    instance_labels[batch_index],
                    class_label=class_labels[batch_index].argmax(dim=-1)).item()
                metrics_segmentation.setdefault(
                    validation_metric_segmentation.__class__.__name__, []).append(metric)
        if index in plot_indexes:
            # Plot the predictions of the first sample of the batch
            object_classes = class_predictions[0].argmax(dim=-1).cpu().detach()
            # Skip plotting when the first sample has no predictions at all
            if object_classes.shape[0] > 0:
                # Keep predictions whose argmax class is > 0 (presumably class 0 means "no object" -- TODO confirm)
                object_indexes = torch.from_numpy(np.argwhere(object_classes.numpy() > 0)[:, 0])
                # Convert kept boxes from relative xcycwh to absolute x0y0x1y1 pixel coordinates
                bounding_boxes = misc.relative_bounding_box_to_absolute(
                    misc.bounding_box_xcycwh_to_x0y0x1y1(
                        bounding_box_predictions[0, object_indexes].cpu().clone().detach()),
                    height=images.shape[-2], width=images.shape[-1])
                misc.plot_instance_segmentation_overlay_instances_bb_classes(
                    image=images[0],
                    instances=(instance_predictions[0][object_indexes] > 0.5).float(),
                    bounding_boxes=bounding_boxes,
                    class_labels=object_classes[object_indexes], show=False, save=True,
                    file_path=os.path.join(self.path_save_plots,
                                           "validation_plot_is_bb_c_{}_{}.png".format(epoch + 1, index)))
    # Average metrics and save them in logs
    for metric_name, values in metrics_classification.items():
        self.logger.log(metric_name=metric_name + "_classification_val", value=float(np.mean(values)))
    for metric_name, values in metrics_bounding_box.items():
        self.logger.log(metric_name=metric_name + "_bounding_box_val", value=float(np.mean(values)))
    for metric_name, values in metrics_segmentation.items():
        metric_values = np.array(values)
        # NaN entries (metric undefined for a sample) are excluded from the average. The same value drives
        # the best-model check: a raw mean containing NaN would make the comparison always False.
        mean_value = float(np.mean(metric_values[~np.isnan(metric_values)]))
        # Save best mIoU model if training is utilized
        if train and "MIoU" in metric_name and mean_value > self.best_miou:
            # Save current mIoU
            self.best_miou = mean_value
            # Show best MIoU as process name
            setproctitle.setproctitle("Cell-DETR best MIoU={:.4f}".format(self.best_miou))
            # Save model (unwrap DataParallel so the state dict keys carry no "module." prefix)
            torch.save(
                self.detr.module.state_dict() if isinstance(self.detr, nn.DataParallel)
                else self.detr.state_dict(),
                os.path.join(self.path_save_models, "detr_best_model.pt"))
        self.logger.log(metric_name=metric_name + "_segmentation_val", value=mean_value)
    # Save metrics
    self.logger.save_metrics(path=self.path_save_metrics)