def batch_metrics_step(self, dto: UnetDto, epoch):
    """Score one ``UnetDto`` batch with binary measures.

    Starts from ``MetricMeasuresDtoInit.init_dto()`` and fills in the ``core``
    and ``penu`` fields.  Each field holds
    ``metrics.binary_measures_torch(prediction, target, self.is_cuda)``, with
    the prediction read from ``dto.outputs`` and the target read from
    ``dto.given_variables``.

    :param dto: batch container exposing ``outputs`` and ``given_variables``.
    :param epoch: current epoch; accepted but not used by this implementation.
    :return: the metrics DTO with ``core`` and ``penu`` populated.
    """
    result = MetricMeasuresDtoInit.init_dto()
    predicted = dto.outputs
    expected = dto.given_variables
    use_cuda = self.is_cuda
    result.core = metrics.binary_measures_torch(predicted.core, expected.core, use_cuda)
    result.penu = metrics.binary_measures_torch(predicted.penu, expected.penu, use_cuda)
    return result
def batch_metrics_step(self, dto: CaeDto, epoch):
    """Score one ``CaeDto`` batch of ground-truth reconstructions.

    Starts from ``MetricMeasuresDtoInit.init_dto()`` and fills in, in this
    order, the ``lesion``, ``core`` and ``penu`` fields.  Each field holds
    ``metrics.binary_measures_torch(reconstruction, target, self.is_cuda)``,
    with reconstructions read from ``dto.reconstructions.gtruth`` and targets
    read from ``dto.given_variables.gtruth``.

    :param dto: batch container exposing ``reconstructions.gtruth`` and
        ``given_variables.gtruth``.
    :param epoch: current epoch; accepted but not used by this implementation.
    :return: the metrics DTO with ``lesion``, ``core`` and ``penu`` populated.
    """
    scores = MetricMeasuresDtoInit.init_dto()
    recon = dto.reconstructions.gtruth
    truth = dto.given_variables.gtruth
    # NOTE(review): the lesion score compares the reconstructed
    # ``interpolation`` against the ground-truth ``lesion`` (not a
    # reconstructed ``lesion``) — presumably intentional; confirm.
    scores.lesion = metrics.binary_measures_torch(recon.interpolation, truth.lesion, self.is_cuda)
    scores.core = metrics.binary_measures_torch(recon.core, truth.core, self.is_cuda)
    scores.penu = metrics.binary_measures_torch(recon.penu, truth.penu, self.is_cuda)
    return scores
def batch_metrics_step(self, dto: Dto, epoch=None):
    """Return a fresh, default-initialised metrics DTO for one batch.

    This default implementation computes no measures; it only returns
    ``MetricMeasuresDtoInit.init_dto()``.

    The specialised ``batch_metrics_step`` implementations take
    ``(dto, epoch)``.  The optional ``epoch`` parameter lets callers use that
    same call form with this default as well; existing ``(dto)`` calls keep
    working.

    :param dto: the batch container (unused here).
    :param epoch: current epoch (unused here); optional for compatibility.
    :return: a new metrics DTO from ``MetricMeasuresDtoInit.init_dto()``.
    """
    return MetricMeasuresDtoInit.init_dto()
def run_training(self):
    """Run the training loop from the start epoch up to ``self._n_epochs``.

    For every epoch:
      (1) train on each batch of ``self._dataloader_training`` and append the
          batch-averaged metrics to ``self._metric_dtos['training']``;
      (2) validate on ``self._dataloader_validation`` and append the averaged
          metrics to ``self._metric_dtos['validate']``.  When there is no
          validation loader, an all-zero record is appended instead;
      (3) when the validation loss is below the best seen so far, save the
          model and the training state, then visualize the epoch.  Every 50th
          epoch is also visualized, at most once per epoch;
      (4) from the second epoch on, plot the metric history to
          ``self._path_outputs_base + self.FN_VIS_BASE + 'plots.png'``.
    Afterwards the model is saved with suffix ``'_final'``.  The last epoch is
    visualized if at least one epoch ran.
    """
    min_loss = self.get_start_min_loss()
    epoch = None  # stays None when the epoch range is empty
    for epoch in range(self.get_start_epoch(), self._n_epochs):
        self.adapt_lr(epoch)
        self.adapt_betas(epoch)

        # ---------------------------- (1) TRAINING ---------------------------- #
        self._model.train()
        epoch_metrics = self._average_batch_metrics(
            self._dataloader_training, self.train_batch, epoch)
        self.print_epoch(epoch, 'training', epoch_metrics)
        self._metric_dtos['training'].append(epoch_metrics)

        # ---------------------------- (2) VALIDATE ---------------------------- #
        self._model.eval()
        if self._dataloader_validation is None:
            # NOTE(review): with these placeholder values the validation loss is
            # presumably 0.0, so (3) can only report a new optimum once —
            # confirm this is intended when training without validation data.
            epoch_metrics = MetricMeasuresDtoInit.init_dto(
                0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
        else:
            epoch_metrics = self._average_batch_metrics(
                self._dataloader_validation, self.validate_batch, epoch)
        self.print_epoch(epoch, 'validate', epoch_metrics)
        self._metric_dtos['validate'].append(epoch_metrics)

        # ------------ (3) SAVE MODEL / VISUALIZE (if new optimum) ------------ #
        visualized = False
        if epoch_metrics.loss < min_loss:
            min_loss = epoch_metrics.loss
            self.save_model()
            self.save_training()  # allows to continue if training has been interrupted
            print('(New optimum: Training saved)', end=' ')
            self.visualize_epoch(epoch)
            visualized = True
        if epoch % 50 == 0 and not visualized:
            # Periodic visualization; skipped when the optimum branch already did it.
            self.visualize_epoch(epoch)

        # ----------------- (4) PLOT / SAVE EVALUATION METRICS ---------------- #
        if epoch > 0:
            self._save_metric_plot(epoch)

    # ------------ (5) SAVE FINAL MODEL / VISUALIZE ------------- #
    self.save_model('_final')
    if epoch is not None:
        self.visualize_epoch(epoch)


def _average_batch_metrics(self, dataloader, step_fn, epoch):
    """Apply ``step_fn(batch, epoch)`` to every batch and return the mean metrics.

    The per-batch metric DTOs are summed with ``add`` and then divided by
    ``len(dataloader)``.  The division is skipped for an empty loader, which
    would otherwise divide by zero.
    """
    epoch_metrics = MetricMeasuresDtoInit.init_dto()
    for batch in dataloader:
        epoch_metrics.add(step_fn(batch, epoch))
    n_batches = len(dataloader)
    if n_batches > 0:
        epoch_metrics.div(n_batches)
    return epoch_metrics


def _save_metric_plot(self, epoch):
    """Plot the metric history for x = 1 .. epoch + 1 and save it as a PNG."""
    fig, plot = plt.subplots()
    try:
        self.plot_epoch(plot, range(1, epoch + 2))
        fig.savefig(self._path_outputs_base + self.FN_VIS_BASE + 'plots.png',
                    bbox_inches='tight', dpi=300)
    finally:
        # ``del fig`` is not enough: pyplot keeps every figure it creates
        # alive until it is explicitly closed, so close it to avoid a leak.
        plt.close(fig)