def test_validate(self):
    """Validate a model fed with ones and check the per-output losses.

    The model input is all ones, which gives 18 ones after concatenation, so
    every output layer produces 18. The labels are 0 (mc_A) and 1 (mc_B).
    The per-output losses are therefore (18 - 0)**2 and (18 - 1)**2, and the
    total loss is their sum.
    """
    history = validate_model(self.orga, self.model)

    output_value = 18
    loss_a = (output_value - 0) ** 2
    loss_b = (output_value - 1) ** 2
    expected = {
        "loss": loss_a + loss_b,
        "mc_A_loss": loss_a,
        "mc_B_loss": loss_b,
    }

    # Printed before the assertion so the architecture appears in the
    # captured output if the comparison fails.
    self.model.summary()
    self.assertDictEqual(history, expected)
def validate(self):
    """
    Validate the most recently saved model on all validation files.

    If a model is already held in memory (``self._stored_model``), that one
    is used. Otherwise the model of the latest saved epoch is loaded.
    Validation start, results and elapsed time are logged. A line is written
    through the summary logger, and the summary plot is updated. Finally,
    ``self.cleanup_models()`` is called when ``self.cfg.cleanup_models`` is
    set.

    Returns
    -------
    history : dict
        The history of the validation on all files: a mapping from each
        loss/metric name to its value.

    Raises
    ------
    ValueError
        If no saved model exists, or if the last entry of the training
        history is already marked as validated.
    """
    latest_epoch = self.io.get_latest_epoch()
    if latest_epoch is None:
        raise ValueError("Can not validate: No saved model found")

    last_state = self.history.get_state()[-1]
    if last_state["is_validated"] is True:
        raise ValueError(
            "Can not validate in epoch {} file {}: "
            "Has already been validated".format(*latest_epoch))

    # Prefer the in-memory model; only fall back to loading it from disk.
    model = (self.load_saved_model(*latest_epoch)
             if self._stored_model is None
             else self._stored_model)
    self._set_up(model, logging=True)

    epoch_float = self.io.get_epoch_float(*latest_epoch)
    summary_logger = logging.SummaryLogger(self, model)
    logging.log_start_validation(self)

    t_start = time.time()
    history = backend.validate_model(self, model)
    # Truncate to whole seconds so the printed duration has no fraction.
    elapsed = timedelta(seconds=int(time.time() - t_start))

    self.io.print_log('Validation results:')
    for metric, value in history.items():
        self.io.print_log(f" {metric}: \t{value}")
    self.io.print_log(f"Elapsed time: {elapsed}\n")

    summary_logger.write_line(epoch_float, "n/a", history_val=history)
    update_summary_plot(self)

    if self.cfg.cleanup_models:
        self.cleanup_models()
    return history