def train(self, epochs: int = 20, validate_after_n_epochs: int = 5, save_model_after_n_epochs: int = 10, optimize_only_segmentation_head_after_epoch: int = 150) -> None: """ Training method :param epochs: (int) Number of epochs to perform :param validate_after_n_epochs: (int) Number of epochs after the validation is performed :param save_model_after_n_epochs: (int) Number epochs after the current models is saved :param optimize_only_segmentation_head_after_epoch: (int) Number of epochs after only the seg. head is trained """ # Model into train mode self.detr.train() # Model to device self.detr.to(self.device) # Init progress bar self.progress_bar = tqdm(total=epochs * len(self.training_dataset.dataset)) # Main trainings loop for epoch in range(epochs): for input, instance_labels, bounding_box_labels, class_labels in self.training_dataset: # Update progress bar self.progress_bar.update(n=input.shape[0]) # Data to device input = input.to(self.device) instance_labels = misc.iterable_to_device(instance_labels, device=self.device) bounding_box_labels = misc.iterable_to_device( bounding_box_labels, device=self.device) class_labels = misc.iterable_to_device(class_labels, device=self.device) # Reset gradients self.detr.zero_grad() # Get prediction class_predictions, bounding_box_predictions, instance_predictions = self.detr( input) # Calc loss loss_classification, loss_bounding_box, loss_segmentation = self.loss_function( class_predictions, bounding_box_predictions, instance_predictions, class_labels, bounding_box_labels, instance_labels) # Case if the whole network is optimized if epoch < optimize_only_segmentation_head_after_epoch: # Perform backward pass to compute the gradients (loss_classification + loss_bounding_box + loss_segmentation).backward() # Optimize detr self.detr_optimizer.step() else: # Perform backward pass to compute the gradients loss_segmentation.backward() # Optimize detr self.detr_segmentation_optimizer.step() # Show losses in progress bar self.progress_bar.set_description( "Epoch {}/{} Best val. mIoU={:.4f} Loss C.={:.4f} Loss BB.={:.4f} Loss Seg.={:.4f}" .format(epoch + 1, epochs, self.best_miou, loss_classification.item(), loss_bounding_box.item(), loss_segmentation.item())) # Log losses self.logger.log(metric_name="loss_classification", value=loss_classification.item()) self.logger.log(metric_name="loss_bounding_box", value=loss_bounding_box.item()) self.logger.log(metric_name="loss_segmentation", value=loss_segmentation.item()) # Learning rate schedule step if self.learning_rate_schedule is not None: self.learning_rate_schedule.step() # Validate if (epoch + 1) % validate_after_n_epochs == 0: self.validate(epoch=epoch, train=True) # Save model if (epoch + 1) % save_model_after_n_epochs == 0: torch.save( self.detr.module.state_dict() if isinstance( self.detr, nn.DataParallel) else self.detr.state_dict(), os.path.join(self.path_save_models, "detr_{}.pt".format(epoch))) # Final validation self.validate(epoch=epoch, number_of_plots=30) # Close progress bar self.progress_bar.close() # Load best model self.detr.state_dict( torch.load( os.path.join(self.path_save_models, "detr_best_model.pt")))
def validate(self, validation_metrics_classification: Tuple[nn.Module, ...] = ( validation_metric.ClassificationAccuracy(), ), validation_metrics_bounding_box: Tuple[nn.Module, ...] = ( nn.L1Loss(), nn.MSELoss(), validation_metric.BoundingBoxIoU(), validation_metric.BoundingBoxGIoU()), validation_metrics_segmentation: Tuple[nn.Module, ...] = ( validation_metric.Accuracy(), validation_metric.Precision(), validation_metric.Recall(), validation_metric.F1(), validation_metric.IoU(), validation_metric.MIoU(), validation_metric.Dice(), validation_metric.CellIoU(), validation_metric.MeanAveragePrecision(), validation_metric.InstancesAccuracy()), epoch: int = -1, number_of_plots: int = 5, train: bool = False) -> None: """ Validation method :param validation_metrics_classification: (Tuple[nn.Module, ...]) Validation modules for classification :param validation_metrics_bounding_box: (Tuple[nn.Module, ...]) Validation modules for bounding boxes :param validation_metrics_segmentation: (Tuple[nn.Module, ...]) Validation modules for segmentation :param epoch: (int) Current epoch :param number_of_plots: (int) Number of validation plot to be produced :param train: (bool) Train flag if set best model is saved based on val iou """ # DETR to device self.detr.to(self.device) # DETR into eval mode self.detr.eval() # Init dicts to store metrics metrics_classification = dict() metrics_bounding_box = dict() metrics_segmentation = dict() # Init indexes of elements to be plotted plot_indexes = np.random.choice(np.arange( 0, len(self.validation_dataset)), number_of_plots, replace=False) # Main loop over the validation set for index, batch in enumerate(self.validation_dataset): # Get data from batch input, instance_labels, bounding_box_labels, class_labels = batch # Data to device input = input.to(self.device) instance_labels = misc.iterable_to_device(instance_labels, device=self.device) bounding_box_labels = misc.iterable_to_device(bounding_box_labels, device=self.device) class_labels = misc.iterable_to_device(class_labels, device=self.device) # Get prediction class_predictions, bounding_box_predictions, instance_predictions = self.detr( input) # Perform matching matching_indexes = self.loss_function.matcher( class_predictions, bounding_box_predictions, class_labels, bounding_box_labels) # Apply permutation to labels and predictions class_predictions, class_labels = self.loss_function.apply_permutation( prediction=class_predictions, label=class_labels, indexes=matching_indexes) bounding_box_predictions, bounding_box_labels = self.loss_function.apply_permutation( prediction=bounding_box_predictions, label=bounding_box_labels, indexes=matching_indexes) instance_predictions, instance_labels = self.loss_function.apply_permutation( prediction=instance_predictions, label=instance_labels, indexes=matching_indexes) for batch_index in range(len(class_labels)): # Calc validation metrics for classification for validation_metric_classification in validation_metrics_classification: # Calc metric metric = validation_metric_classification( class_predictions[ batch_index, :class_labels[batch_index].shape[0]]. argmax(dim=-1), class_labels[batch_index].argmax(dim=-1)).item() # Save metric and name of metric if validation_metric_classification.__class__.__name__ in metrics_classification.keys( ): metrics_classification[ validation_metric_classification.__class__. __name__].append(metric) else: metrics_classification[validation_metric_classification .__class__.__name__] = [metric] # Calc validation metrics for bounding boxes for validation_metric_bounding_box in validation_metrics_bounding_box: # Calc metric metric = validation_metric_bounding_box( misc.bounding_box_xcycwh_to_x0y0x1y1( bounding_box_predictions[ batch_index, :bounding_box_labels[batch_index]. shape[0]]), misc.bounding_box_xcycwh_to_x0y0x1y1( bounding_box_labels[batch_index])).item() # Save metric and name of metric if validation_metric_bounding_box.__class__.__name__ in metrics_bounding_box.keys( ): metrics_bounding_box[validation_metric_bounding_box. __class__.__name__].append(metric) else: metrics_bounding_box[validation_metric_bounding_box. __class__.__name__] = [metric] # Calc validation metrics for bounding boxes for validation_metric_segmentation in validation_metrics_segmentation: # Calc metric metric = validation_metric_segmentation( instance_predictions[ batch_index, :instance_labels[batch_index]. shape[0]], instance_labels[batch_index], class_label=class_labels[batch_index].argmax( dim=-1)).item() # Save metric and name of metric if validation_metric_segmentation.__class__.__name__ in metrics_segmentation.keys( ): metrics_segmentation[validation_metric_segmentation. __class__.__name__].append(metric) else: metrics_segmentation[validation_metric_segmentation. __class__.__name__] = [metric] if index in plot_indexes: # Plot object_classes = class_predictions[0].argmax( dim=-1).cpu().detach() # Case the no objects are detected if object_classes.shape[0] > 0: object_indexes = torch.from_numpy( np.argwhere(object_classes.numpy() > 0)[:, 0]) bounding_box_predictions = misc.relative_bounding_box_to_absolute( misc.bounding_box_xcycwh_to_x0y0x1y1( bounding_box_predictions[ 0, object_indexes].cpu().clone().detach()), height=input.shape[-2], width=input.shape[-1]) misc.plot_instance_segmentation_overlay_instances_bb_classes( image=input[0], instances=(instance_predictions[0][object_indexes] > 0.5).float(), bounding_boxes=bounding_box_predictions, class_labels=object_classes[object_indexes], show=False, save=True, file_path=os.path.join( self.path_save_plots, "validation_plot_is_bb_c_{}_{}.png".format( epoch + 1, index))) # Average metrics and save them in logs for metric_name in metrics_classification: self.logger.log(metric_name=metric_name + "_classification_val", value=float( np.mean(metrics_classification[metric_name]))) for metric_name in metrics_bounding_box: self.logger.log(metric_name=metric_name + "_bounding_box_val", value=float( np.mean(metrics_bounding_box[metric_name]))) for metric_name in metrics_segmentation: metric_values = np.array(metrics_segmentation[metric_name]) # Save best mIoU model if training is utilized if train and "MIoU" in metric_name and float( np.mean( metrics_segmentation[metric_name])) > self.best_miou: # Save current mIoU self.best_miou = float( np.mean(metric_values[~np.isnan(metric_values)])) # Show best MIoU as process name setproctitle.setproctitle("Cell-DETR best MIoU={:.4f}".format( self.best_miou)) # Save model torch.save( self.detr.module.state_dict() if isinstance( self.detr, nn.DataParallel) else self.detr.state_dict(), os.path.join(self.path_save_models, "detr_best_model.pt")) self.logger.log( metric_name=metric_name + "_segmentation_val", value=float(np.mean(metric_values[~np.isnan(metric_values)]))) # Save metrics self.logger.save_metrics(path=self.path_save_metrics)
def test( self, test_metrics_classification: Tuple[nn.Module, ...] = ( validation_metric.ClassificationAccuracy(), ), test_metrics_bounding_box: Tuple[nn.Module, ...] = ( nn.L1Loss(), nn.MSELoss(), validation_metric.BoundingBoxIoU(), validation_metric.BoundingBoxGIoU()), test_metrics_segmentation: Tuple[nn.Module, ...] = ( validation_metric.Accuracy(), validation_metric.Precision(), validation_metric.Recall(), validation_metric.F1(), validation_metric.IoU(), validation_metric.MIoU(), validation_metric.Dice(), validation_metric.CellIoU(), validation_metric.MeanAveragePrecision(), validation_metric.InstancesAccuracy()) ) -> None: """ Test method :param test_metrics_classification: (Tuple[nn.Module, ...]) Test modules for classification :param test_metrics_bounding_box: (Tuple[nn.Module, ...]) Test modules for bounding boxes :param test_metrics_segmentation: (Tuple[nn.Module, ...]) Test modules for segmentation """ # DETR to device self.detr.to(self.device) # DETR into eval mode self.detr.eval() # Init dicts to store metrics metrics_classification = dict() metrics_bounding_box = dict() metrics_segmentation = dict() # Main loop over the test set for index, batch in enumerate(self.test_dataset): # Get data from batch input, instance_labels, bounding_box_labels, class_labels = batch # Data to device input = input.to(self.device) instance_labels = misc.iterable_to_device(instance_labels, device=self.device) bounding_box_labels = misc.iterable_to_device(bounding_box_labels, device=self.device) class_labels = misc.iterable_to_device(class_labels, device=self.device) # Get prediction class_predictions, bounding_box_predictions, instance_predictions = self.detr( input) # Perform matching matching_indexes = self.loss_function.matcher( class_predictions, bounding_box_predictions, class_labels, bounding_box_labels) # Apply permutation to labels and predictions class_predictions, class_labels = self.loss_function.apply_permutation( prediction=class_predictions, label=class_labels, indexes=matching_indexes) bounding_box_predictions, bounding_box_labels = self.loss_function.apply_permutation( prediction=bounding_box_predictions, label=bounding_box_labels, indexes=matching_indexes) instance_predictions, instance_labels = self.loss_function.apply_permutation( prediction=instance_predictions, label=instance_labels, indexes=matching_indexes) for batch_index in range(len(class_labels)): # Calc test metrics for classification for test_metric_classification in test_metrics_classification: # Calc metric metric = test_metric_classification( class_predictions[ batch_index, :class_labels[batch_index].shape[0]]. argmax(dim=-1), class_labels[batch_index].argmax(dim=-1)).item() # Save metric and name of metric if test_metric_classification.__class__.__name__ in metrics_classification.keys( ): metrics_classification[test_metric_classification. __class__.__name__].append( metric) else: metrics_classification[ test_metric_classification.__class__.__name__] = [ metric ] # Calc test metrics for bounding boxes for test_metric_bounding_box in test_metrics_bounding_box: # Calc metric metric = test_metric_bounding_box( misc.bounding_box_xcycwh_to_x0y0x1y1( bounding_box_predictions[ batch_index, :bounding_box_labels[batch_index]. shape[0]]), misc.bounding_box_xcycwh_to_x0y0x1y1( bounding_box_labels[batch_index])).item() # Save metric and name of metric if test_metric_bounding_box.__class__.__name__ in metrics_bounding_box.keys( ): metrics_bounding_box[test_metric_bounding_box. __class__.__name__].append(metric) else: metrics_bounding_box[ test_metric_bounding_box.__class__.__name__] = [ metric ] # Calc test metrics for bounding boxes for test_metric_segmentation in test_metrics_segmentation: # Calc metric metric = test_metric_segmentation( instance_predictions[ batch_index, :instance_labels[batch_index]. shape[0]], instance_labels[batch_index], class_label=class_labels[batch_index].argmax( dim=-1)).item() # Save metric and name of metric if test_metric_segmentation.__class__.__name__ in metrics_segmentation.keys( ): metrics_segmentation[test_metric_segmentation. __class__.__name__].append(metric) else: metrics_segmentation[ test_metric_segmentation.__class__.__name__] = [ metric ] # Plot object_classes = class_predictions[0].argmax(dim=-1).cpu().detach() # Case the no objects are detected if object_classes.shape[0] > 0: object_indexes = torch.from_numpy( np.argwhere(object_classes.numpy() > 0)[:, 0]) bounding_box_predictions = misc.relative_bounding_box_to_absolute( misc.bounding_box_xcycwh_to_x0y0x1y1( bounding_box_predictions[ 0, object_indexes].cpu().clone().detach()), height=input.shape[-2], width=input.shape[-1]) misc.plot_instance_segmentation_overlay_instances_bb_classes( image=input[0], instances=(instance_predictions[0][object_indexes] > 0.5).float(), bounding_boxes=bounding_box_predictions, class_labels=object_classes[object_indexes], show=False, save=True, file_path=os.path.join( self.path_save_plots, "test_plot_{}_is_bb_c.png".format(index))) misc.plot_instance_segmentation_overlay_instances_bb_classes( image=input[0], instances=(instance_predictions[0][object_indexes] > 0.5).float(), bounding_boxes=bounding_box_predictions, class_labels=object_classes[object_indexes], show=False, save=True, show_class_label=False, file_path=os.path.join( self.path_save_plots, "test_plot_{}_is_bb.png".format(index))) misc.plot_instance_segmentation_overlay_instances( image=input[0], instances=(instance_predictions[0][object_indexes] > 0.5).float(), class_labels=object_classes[object_indexes], show=False, save=True, file_path=os.path.join( self.path_save_plots, "test_plot_{}_is.png".format(index))) misc.plot_instance_segmentation_overlay_bb_classes( image=input[0], bounding_boxes=bounding_box_predictions, class_labels=object_classes[object_indexes], show=False, save=True, file_path=os.path.join( self.path_save_plots, "test_plot_{}_bb_c.png".format(index))) misc.plot_instance_segmentation_labels( instances=(instance_predictions[0][object_indexes] > 0.5).float(), bounding_boxes=bounding_box_predictions, class_labels=object_classes[object_indexes], show=False, save=True, file_path=os.path.join( self.path_save_plots, "test_plot_{}_bb_no_overlay_.png".format(index)), show_class_label=False, white_background=True) misc.plot_instance_segmentation_map_label( instances=(instance_predictions[0][object_indexes] > 0.5).float(), class_labels=object_classes[object_indexes], show=False, save=True, file_path=os.path.join( self.path_save_plots, "test_plot_{}_no_overlay.png".format(index)), white_background=True) # Average metrics and save them in logs for metric_name in metrics_classification: print(metric_name + "_classification_test=", float(np.mean(metrics_classification[metric_name]))) self.logger.log(metric_name=metric_name + "_classification_test", value=float( np.mean(metrics_classification[metric_name]))) for metric_name in metrics_bounding_box: print(metric_name + "_bounding_box_test=", float(np.mean(metrics_bounding_box[metric_name]))) self.logger.log(metric_name=metric_name + "_bounding_box_test", value=float( np.mean(metrics_bounding_box[metric_name]))) for metric_name in metrics_segmentation: metric_values = np.array(metrics_segmentation[metric_name]) print(metric_name + "_segmentation_test=", float(np.mean(metric_values[~np.isnan(metric_values)]))) self.logger.log( metric_name=metric_name + "_segmentation_test", value=float(np.mean(metric_values[~np.isnan(metric_values)]))) # Save metrics self.logger.save_metrics(path=self.path_save_metrics)
def inference(self) -> None: """ Inference method: plots segmentation mask and bounding boxes for visual inspections. """ # DETR to device self.detr.to(self.device) # DETR into eval mode self.detr.eval() # Init dicts to store metrics metrics_classification = dict() metrics_bounding_box = dict() metrics_segmentation = dict() # Main loop over the test set for index, batch in enumerate(self.test_dataset): # Get data from batch input, instance_labels, bounding_box_labels, class_labels = batch # Data to device input = input.to(self.device) instance_labels = misc.iterable_to_device(instance_labels, device=self.device) bounding_box_labels = misc.iterable_to_device(bounding_box_labels, device=self.device) class_labels = misc.iterable_to_device(class_labels, device=self.device) # Get prediction class_predictions, bounding_box_predictions, instance_predictions = self.detr( input) # Perform matching matching_indexes = self.loss_function.matcher( class_predictions, bounding_box_predictions, class_labels, bounding_box_labels) # Apply permutation to labels and predictions class_predictions, class_labels = self.loss_function.apply_permutation( prediction=class_predictions, label=class_labels, indexes=matching_indexes) bounding_box_predictions, bounding_box_labels = self.loss_function.apply_permutation( prediction=bounding_box_predictions, label=bounding_box_labels, indexes=matching_indexes) instance_predictions, instance_labels = self.loss_function.apply_permutation( prediction=instance_predictions, label=instance_labels, indexes=matching_indexes) # Plot object_classes = class_predictions[0].argmax(dim=-1).cpu().detach() # Case the no objects are detected if object_classes.shape[0] > 0: object_indexes = torch.from_numpy( np.argwhere(object_classes.numpy() > 0)[:, 0]) bounding_box_predictions = misc.relative_bounding_box_to_absolute( misc.bounding_box_xcycwh_to_x0y0x1y1( bounding_box_predictions[ 0, object_indexes].cpu().clone().detach()), height=input.shape[-2], width=input.shape[-1]) misc.plot_instance_segmentation_overlay_instances( image=input[0], instances=(instance_predictions[0][object_indexes] > 0.5).float(), class_labels=object_classes[object_indexes], colors=self.colors, show=True, save=False, file_path=os.path.join( self.path_save_plots, "test_plot_{}_is.png".format(index))) misc.plot_instance_segmentation_overlay_bb_classes( image=input[0], bounding_boxes=bounding_box_predictions, class_labels=object_classes[object_indexes], colors=self.colors, class_list=self.class_labels, show=True, save=False, file_path=os.path.join( self.path_save_plots, "test_plot_{}_bb_c.png".format(index)))