def test(
     self,
     test_metrics_classification: Tuple[nn.Module, ...] = (
         validation_metric.ClassificationAccuracy(), ),
     test_metrics_bounding_box: Tuple[nn.Module, ...] = (
         nn.L1Loss(), nn.MSELoss(), validation_metric.BoundingBoxIoU(),
         validation_metric.BoundingBoxGIoU()),
     test_metrics_segmentation: Tuple[nn.Module, ...] = (
         validation_metric.Accuracy(), validation_metric.Precision(),
         validation_metric.Recall(), validation_metric.F1(),
         validation_metric.IoU(), validation_metric.MIoU(),
         validation_metric.Dice(), validation_metric.CellIoU(),
         validation_metric.MeanAveragePrecision(),
         validation_metric.InstancesAccuracy())
 ) -> None:
     """
     Test method. Runs DETR over the whole test set and computes, for every sample, the
     classification, bounding box and segmentation metrics after Hungarian-style matching
     (via self.loss_function.matcher). Prediction plots of the first sample of every batch
     are saved to self.path_save_plots. Finally the averaged metrics are printed, logged and
     saved. Metrics are keyed by the class name of the respective metric module.
     :param test_metrics_classification: (Tuple[nn.Module, ...]) Test modules for classification
     :param test_metrics_bounding_box: (Tuple[nn.Module, ...]) Test modules for bounding boxes
     :param test_metrics_segmentation: (Tuple[nn.Module, ...]) Test modules for segmentation
     """
     # DETR to device
     self.detr.to(self.device)
     # DETR into eval mode
     self.detr.eval()
     # Init dicts mapping metric (class) name -> list of per-sample values
     metrics_classification = dict()
     metrics_bounding_box = dict()
     metrics_segmentation = dict()
     # No backward pass is performed, so building the autograd graph would only waste memory
     with torch.no_grad():
         # Main loop over the test set
         for index, batch in enumerate(self.test_dataset):
             # Get data from batch
             inputs, instance_labels, bounding_box_labels, class_labels = batch
             # Data to device
             inputs = inputs.to(self.device)
             instance_labels = misc.iterable_to_device(instance_labels,
                                                       device=self.device)
             bounding_box_labels = misc.iterable_to_device(bounding_box_labels,
                                                           device=self.device)
             class_labels = misc.iterable_to_device(class_labels,
                                                    device=self.device)
             # Get prediction
             class_predictions, bounding_box_predictions, instance_predictions = self.detr(
                 inputs)
             # Perform matching
             matching_indexes = self.loss_function.matcher(
                 class_predictions, bounding_box_predictions, class_labels,
                 bounding_box_labels)
             # Apply permutation to labels and predictions
             class_predictions, class_labels = self.loss_function.apply_permutation(
                 prediction=class_predictions,
                 label=class_labels,
                 indexes=matching_indexes)
             bounding_box_predictions, bounding_box_labels = self.loss_function.apply_permutation(
                 prediction=bounding_box_predictions,
                 label=bounding_box_labels,
                 indexes=matching_indexes)
             instance_predictions, instance_labels = self.loss_function.apply_permutation(
                 prediction=instance_predictions,
                 label=instance_labels,
                 indexes=matching_indexes)
             for batch_index in range(len(class_labels)):
                 # Prepare metric inputs once per sample; predictions are cut to the
                 # number of labelled objects of this sample
                 class_label = class_labels[batch_index].argmax(dim=-1)
                 class_prediction = class_predictions[
                     batch_index, :class_labels[batch_index].shape[0]].argmax(dim=-1)
                 bounding_box_label = misc.bounding_box_xcycwh_to_x0y0x1y1(
                     bounding_box_labels[batch_index])
                 bounding_box_prediction = misc.bounding_box_xcycwh_to_x0y0x1y1(
                     bounding_box_predictions[
                         batch_index, :bounding_box_labels[batch_index].shape[0]])
                 instance_label = instance_labels[batch_index]
                 instance_prediction = instance_predictions[
                     batch_index, :instance_labels[batch_index].shape[0]]
                 # Calc test metrics for classification
                 for test_metric in test_metrics_classification:
                     value = test_metric(class_prediction, class_label).item()
                     metrics_classification.setdefault(
                         test_metric.__class__.__name__, []).append(value)
                 # Calc test metrics for bounding boxes
                 for test_metric in test_metrics_bounding_box:
                     value = test_metric(bounding_box_prediction,
                                         bounding_box_label).item()
                     metrics_bounding_box.setdefault(
                         test_metric.__class__.__name__, []).append(value)
                 # Calc test metrics for segmentation
                 for test_metric in test_metrics_segmentation:
                     value = test_metric(instance_prediction,
                                         instance_label,
                                         class_label=class_label).item()
                     metrics_segmentation.setdefault(
                         test_metric.__class__.__name__, []).append(value)
             # Plot predictions of the first sample of the batch
             object_classes = class_predictions[0].argmax(dim=-1).cpu().detach()
             # NOTE(review): object_classes has one entry per prediction slot, so this guard
             # only skips plotting when there are no slots at all; a prediction with no
             # detected object (empty object_indexes) is still plotted -- confirm intended.
             if object_classes.shape[0] > 0:
                 # Keep predictions whose class index is > 0 (presumably 0 is "no object")
                 object_indexes = torch.from_numpy(
                     np.argwhere(object_classes.numpy() > 0)[:, 0])
                 absolute_bounding_boxes = misc.relative_bounding_box_to_absolute(
                     misc.bounding_box_xcycwh_to_x0y0x1y1(
                         bounding_box_predictions[
                             0, object_indexes].cpu().clone().detach()),
                     height=inputs.shape[-2],
                     width=inputs.shape[-1])
                 # Binarize instance maps at 0.5 once and reuse them for every plot
                 instance_masks = (instance_predictions[0][object_indexes] >
                                   0.5).float()
                 object_class_labels = object_classes[object_indexes]
                 misc.plot_instance_segmentation_overlay_instances_bb_classes(
                     image=inputs[0],
                     instances=instance_masks,
                     bounding_boxes=absolute_bounding_boxes,
                     class_labels=object_class_labels,
                     show=False,
                     save=True,
                     file_path=os.path.join(
                         self.path_save_plots,
                         "test_plot_{}_is_bb_c.png".format(index)))
                 misc.plot_instance_segmentation_overlay_instances_bb_classes(
                     image=inputs[0],
                     instances=instance_masks,
                     bounding_boxes=absolute_bounding_boxes,
                     class_labels=object_class_labels,
                     show=False,
                     save=True,
                     show_class_label=False,
                     file_path=os.path.join(
                         self.path_save_plots,
                         "test_plot_{}_is_bb.png".format(index)))
                 misc.plot_instance_segmentation_overlay_instances(
                     image=inputs[0],
                     instances=instance_masks,
                     class_labels=object_class_labels,
                     show=False,
                     save=True,
                     file_path=os.path.join(
                         self.path_save_plots,
                         "test_plot_{}_is.png".format(index)))
                 misc.plot_instance_segmentation_overlay_bb_classes(
                     image=inputs[0],
                     bounding_boxes=absolute_bounding_boxes,
                     class_labels=object_class_labels,
                     show=False,
                     save=True,
                     file_path=os.path.join(
                         self.path_save_plots,
                         "test_plot_{}_bb_c.png".format(index)))
                 misc.plot_instance_segmentation_labels(
                     instances=instance_masks,
                     bounding_boxes=absolute_bounding_boxes,
                     class_labels=object_class_labels,
                     show=False,
                     save=True,
                     file_path=os.path.join(
                         self.path_save_plots,
                         "test_plot_{}_bb_no_overlay_.png".format(index)),
                     show_class_label=False,
                     white_background=True)
                 misc.plot_instance_segmentation_map_label(
                     instances=instance_masks,
                     class_labels=object_class_labels,
                     show=False,
                     save=True,
                     file_path=os.path.join(
                         self.path_save_plots,
                         "test_plot_{}_no_overlay.png".format(index)),
                     white_background=True)
     # Average metrics, print them and save them in logs
     for metric_name, metric_values in metrics_classification.items():
         metric_value = float(np.mean(metric_values))
         print(metric_name + "_classification_test=", metric_value)
         self.logger.log(metric_name=metric_name + "_classification_test",
                         value=metric_value)
     for metric_name, metric_values in metrics_bounding_box.items():
         metric_value = float(np.mean(metric_values))
         print(metric_name + "_bounding_box_test=", metric_value)
         self.logger.log(metric_name=metric_name + "_bounding_box_test",
                         value=metric_value)
     for metric_name, metric_values in metrics_segmentation.items():
         # NaN entries are excluded from the segmentation averages
         metric_value = float(np.nanmean(metric_values))
         print(metric_name + "_segmentation_test=", metric_value)
         self.logger.log(metric_name=metric_name + "_segmentation_test",
                         value=metric_value)
     # Save metrics
     self.logger.save_metrics(path=self.path_save_metrics)
 def validate(self,
              validation_metrics_classification: Tuple[nn.Module, ...] = (
                  validation_metric.ClassificationAccuracy(), ),
              validation_metrics_bounding_box: Tuple[nn.Module, ...] = (
                  nn.L1Loss(), nn.MSELoss(),
                  validation_metric.BoundingBoxIoU(),
                  validation_metric.BoundingBoxGIoU()),
              validation_metrics_segmentation: Tuple[nn.Module, ...] = (
                  validation_metric.Accuracy(),
                  validation_metric.Precision(), validation_metric.Recall(),
                  validation_metric.F1(), validation_metric.IoU(),
                  validation_metric.MIoU(), validation_metric.Dice(),
                  validation_metric.CellIoU(),
                  validation_metric.MeanAveragePrecision(),
                  validation_metric.InstancesAccuracy()),
              epoch: int = -1,
              number_of_plots: int = 5,
              train: bool = False) -> None:
     """
     Validation method. Runs DETR over the whole validation set and computes, for every
     sample, the classification, bounding box and segmentation metrics after matching
     (via self.loss_function.matcher). The first sample of randomly chosen batches is
     plotted, and the averaged metrics are logged and saved. If train is set, the model
     state dict is saved whenever the averaged (NaN-free) value of a segmentation metric
     whose name contains "MIoU" exceeds self.best_miou.
     :param validation_metrics_classification: (Tuple[nn.Module, ...]) Validation modules for classification
     :param validation_metrics_bounding_box: (Tuple[nn.Module, ...]) Validation modules for bounding boxes
     :param validation_metrics_segmentation: (Tuple[nn.Module, ...]) Validation modules for segmentation
     :param epoch: (int) Current epoch (only used to name the plot files)
     :param number_of_plots: (int) Number of validation plots to be produced (at most one per batch)
     :param train: (bool) Train flag, if set the best model is saved based on the validation MIoU
     """
     # DETR to device
     self.detr.to(self.device)
     # DETR into eval mode
     self.detr.eval()
     # Init dicts mapping metric (class) name -> list of per-sample values
     metrics_classification = dict()
     metrics_bounding_box = dict()
     metrics_segmentation = dict()
     # Init indexes of batches to be plotted; sampling without replacement cannot draw
     # more elements than exist, so the number of plots is clamped to the number of batches
     number_of_batches = len(self.validation_dataset)
     plot_indexes = np.random.choice(np.arange(0, number_of_batches),
                                     min(number_of_plots, number_of_batches),
                                     replace=False)
     # No backward pass is performed, so building the autograd graph would only waste memory
     with torch.no_grad():
         # Main loop over the validation set
         for index, batch in enumerate(self.validation_dataset):
             # Get data from batch
             inputs, instance_labels, bounding_box_labels, class_labels = batch
             # Data to device
             inputs = inputs.to(self.device)
             instance_labels = misc.iterable_to_device(instance_labels,
                                                       device=self.device)
             bounding_box_labels = misc.iterable_to_device(bounding_box_labels,
                                                           device=self.device)
             class_labels = misc.iterable_to_device(class_labels,
                                                    device=self.device)
             # Get prediction
             class_predictions, bounding_box_predictions, instance_predictions = self.detr(
                 inputs)
             # Perform matching
             matching_indexes = self.loss_function.matcher(
                 class_predictions, bounding_box_predictions, class_labels,
                 bounding_box_labels)
             # Apply permutation to labels and predictions
             class_predictions, class_labels = self.loss_function.apply_permutation(
                 prediction=class_predictions,
                 label=class_labels,
                 indexes=matching_indexes)
             bounding_box_predictions, bounding_box_labels = self.loss_function.apply_permutation(
                 prediction=bounding_box_predictions,
                 label=bounding_box_labels,
                 indexes=matching_indexes)
             instance_predictions, instance_labels = self.loss_function.apply_permutation(
                 prediction=instance_predictions,
                 label=instance_labels,
                 indexes=matching_indexes)
             for batch_index in range(len(class_labels)):
                 # Prepare metric inputs once per sample; predictions are cut to the
                 # number of labelled objects of this sample
                 class_label = class_labels[batch_index].argmax(dim=-1)
                 class_prediction = class_predictions[
                     batch_index, :class_labels[batch_index].shape[0]].argmax(dim=-1)
                 bounding_box_label = misc.bounding_box_xcycwh_to_x0y0x1y1(
                     bounding_box_labels[batch_index])
                 bounding_box_prediction = misc.bounding_box_xcycwh_to_x0y0x1y1(
                     bounding_box_predictions[
                         batch_index, :bounding_box_labels[batch_index].shape[0]])
                 instance_label = instance_labels[batch_index]
                 instance_prediction = instance_predictions[
                     batch_index, :instance_labels[batch_index].shape[0]]
                 # Calc validation metrics for classification
                 for validation_metric_module in validation_metrics_classification:
                     value = validation_metric_module(class_prediction,
                                                      class_label).item()
                     metrics_classification.setdefault(
                         validation_metric_module.__class__.__name__, []).append(value)
                 # Calc validation metrics for bounding boxes
                 for validation_metric_module in validation_metrics_bounding_box:
                     value = validation_metric_module(bounding_box_prediction,
                                                      bounding_box_label).item()
                     metrics_bounding_box.setdefault(
                         validation_metric_module.__class__.__name__, []).append(value)
                 # Calc validation metrics for segmentation
                 for validation_metric_module in validation_metrics_segmentation:
                     value = validation_metric_module(instance_prediction,
                                                      instance_label,
                                                      class_label=class_label).item()
                     metrics_segmentation.setdefault(
                         validation_metric_module.__class__.__name__, []).append(value)
             if index in plot_indexes:
                 # Plot predictions of the first sample of the batch
                 object_classes = class_predictions[0].argmax(
                     dim=-1).cpu().detach()
                 # NOTE(review): object_classes has one entry per prediction slot, so this
                 # guard only skips plotting when there are no slots at all -- confirm intended.
                 if object_classes.shape[0] > 0:
                     # Keep predictions whose class index is > 0 (presumably 0 is "no object")
                     object_indexes = torch.from_numpy(
                         np.argwhere(object_classes.numpy() > 0)[:, 0])
                     absolute_bounding_boxes = misc.relative_bounding_box_to_absolute(
                         misc.bounding_box_xcycwh_to_x0y0x1y1(
                             bounding_box_predictions[
                                 0, object_indexes].cpu().clone().detach()),
                         height=inputs.shape[-2],
                         width=inputs.shape[-1])
                     misc.plot_instance_segmentation_overlay_instances_bb_classes(
                         image=inputs[0],
                         instances=(instance_predictions[0][object_indexes] >
                                    0.5).float(),
                         bounding_boxes=absolute_bounding_boxes,
                         class_labels=object_classes[object_indexes],
                         show=False,
                         save=True,
                         file_path=os.path.join(
                             self.path_save_plots,
                             "validation_plot_is_bb_c_{}_{}.png".format(
                                 epoch + 1, index)))
     # Average metrics and save them in logs
     for metric_name, metric_values in metrics_classification.items():
         self.logger.log(metric_name=metric_name + "_classification_val",
                         value=float(np.mean(metric_values)))
     for metric_name, metric_values in metrics_bounding_box.items():
         self.logger.log(metric_name=metric_name + "_bounding_box_val",
                         value=float(np.mean(metric_values)))
     for metric_name, metric_values in metrics_segmentation.items():
         # NaN entries are excluded from the segmentation averages
         metric_value = float(np.nanmean(metric_values))
         # Save best mIoU model if training is utilized. The comparison uses the same
         # NaN-free average that is stored and logged; a plain mean turns NaN as soon as
         # a single sample is NaN, which would silently prevent saving the best model.
         if train and "MIoU" in metric_name and metric_value > self.best_miou:
             # Save current mIoU
             self.best_miou = metric_value
             # Show best MIoU as process name
             setproctitle.setproctitle("Cell-DETR best MIoU={:.4f}".format(
                 self.best_miou))
             # Save model (unwrap DataParallel so the state dict keys carry no prefix)
             torch.save(
                 self.detr.module.state_dict() if isinstance(
                     self.detr, nn.DataParallel) else
                 self.detr.state_dict(),
                 os.path.join(self.path_save_models, "detr_best_model.pt"))
         self.logger.log(metric_name=metric_name + "_segmentation_val",
                         value=metric_value)
     # Save metrics
     self.logger.save_metrics(path=self.path_save_metrics)