def train(self) -> bool:
    """ main training function """

    # setup data output directories:
    setup_directories()
    save_codebase_of_run(self.arguments)

    # data gathering
    progress = []

    epoch = 0

    try:
        print(
            f"{PRINTCOLOR_BOLD}Started training with the following config:{PRINTCOLOR_END}\n{self.arguments}\n\n"
        )
        print(self._log_header)

        best_metrics = (math.inf, 0)
        patience = self._patience

        # run
        for epoch in range(self.arguments.epochs):
            # do epoch
            epoch_progress, best_metrics, patience = self._epoch_iteration(
                epoch, best_metrics, patience)

            # add this epoch's progress to the global progress list
            progress += epoch_progress

            # write progress to pickle file (overwrite: there is no point in keeping separate versions)
            DATA_MANAGER.save_python_obj(
                progress,
                os.path.join(RESULTS_DIR, DATA_MANAGER.stamp, PROGRESS_DIR,
                             "progress_list"),
                print_success=False)

            # flush prints
            sys.stdout.flush()

            # stop early once patience is exhausted
            if patience == 0:
                break

    except KeyboardInterrupt as e:
        print(f"Killed by user: {e}")
        save_models([self.model], f"KILLED_at_epoch_{epoch}")
        return False
    except Exception as e:
        print(e)
        save_models([self.model], f"CRASH_at_epoch_{epoch}")
        raise e

    # flush prints
    sys.stdout.flush()

    # final save
    save_models([self.model], "finished")
    return True
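# A minimal, self-contained sketch of the interrupt-safe checkpointing
# pattern all three train() variants in this file follow: KeyboardInterrupt
# is a soft stop (save and return False), any other exception is saved and
# re-raised, and a final checkpoint marks normal completion. `run_epoch`
# and `save_checkpoint` are illustrative stand-ins, not repo functions.


def run_epoch(epoch: int) -> None:
    pass  # stand-in for one epoch of training


def save_checkpoint(tag: str) -> None:
    print(f"checkpoint: {tag}")  # stand-in for save_models(...)


def safe_training_loop(epochs: int) -> bool:
    epoch = 0  # defined up front so the handlers below can reference it
    try:
        for epoch in range(epochs):
            run_epoch(epoch)
    except KeyboardInterrupt:
        save_checkpoint(f"KILLED_at_epoch_{epoch}")
        return False
    except Exception:
        save_checkpoint(f"CRASH_at_epoch_{epoch}")
        raise
    save_checkpoint("finished")
    return True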
def report_error(error, model, episode, metrics):
    print(error)

    from utils.constants import RESULTS_DIR

    with open(
            os.path.join(DATA_MANAGER.directory, RESULTS_DIR,
                         DATA_MANAGER.stamp, OUTPUT_DIR, "error_report.txt"),
            "w") as f:
        # dump the exception and its traceback
        f.write(str(error))
        f.write("\n")
        summary = traceback.extract_tb(error.__traceback__)
        for x in summary:
            f.write(repr(x))
        f.write("\n\n")

        # dump every variable in scope: shape if it has one, value otherwise
        variables = {**locals(), **globals()}
        for key, variable in variables.items():
            try:
                shape = "absent"
                try:
                    shape = variable.shape
                    f.write(str(key) + " shape " + str(shape))
                except AttributeError:
                    f.write(
                        str(key) + " " + str(variable) + " shape " +
                        str(shape))
            except Exception as error2:
                f.write(str(error2))

    DATA_MANAGER.write_to_file(
        os.path.join(RESULTS_DIR, DATA_MANAGER.stamp, OUTPUT_DIR, "log.txt"),
        metrics.log)

    save_models([model, metrics], f"CRASH_at_epoch_{episode}")

    raise error
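# A self-contained sketch of the shape-dumping idea in report_error above:
# prefer an object's .shape when it has one, otherwise fall back to its
# value. `dump_variable_shapes` is an illustrative name, not a repo helper.
import io


def dump_variable_shapes(variables: dict, stream) -> None:
    """Write one line per variable: its shape if available, else its repr."""
    for key, variable in variables.items():
        shape = getattr(variable, "shape", None)
        if shape is not None:
            stream.write(f"{key} shape {shape}\n")
        else:
            stream.write(f"{key} {variable!r} shape absent\n")


if __name__ == "__main__":
    buffer = io.StringIO()
    dump_variable_shapes({"lr": 1e-4, "name": "unet"}, buffer)
    print(buffer.getvalue(), end="")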
def _epoch_iteration(self, epoch_num: int, best_metrics: Tuple[float, float],
                     patience: int) -> Tuple[List, Tuple, int]:
    """ one epoch implementation """

    if not self.arguments.train_classifier:
        self.loss_function.reset()

    progress = []
    train_accuracy = 0
    train_loss = 0

    data_loader_length = len(self.data_loader_train)

    for i, (batch, targets, lengths) in enumerate(self.data_loader_train):
        print(f'Train: {i}/{data_loader_length} \r', end='')

        # do forward pass and compute loss/accuracy on the batch
        loss_batch, accuracy_batch = self._batch_iteration(
            batch, targets, lengths, i)
        train_loss += loss_batch
        train_accuracy += accuracy_batch

        # store per-batch statistics
        progress.append({"loss": loss_batch, "acc": accuracy_batch})

        # calculate number of batches and walltime passed
        batches_passed = i + (epoch_num * data_loader_length)
        time_passed = datetime.now() - DATA_MANAGER.actual_date

        # run on validation set and print progress to terminal
        # every eval_freq batches and at the end of the epoch
        if (batches_passed % self.arguments.eval_freq) == 0 or (
                i + 1 == data_loader_length):
            loss_validation, acc_validation = self._evaluate()

            new_best = False
            if self.model.compare_metric(best_metrics, loss_validation,
                                         acc_validation):
                save_models([self.model], 'model_best')
                best_metrics = (loss_validation, acc_validation)
                new_best = True
                patience = self._patience
            else:
                patience -= 1

            self._log(loss_validation, acc_validation,
                      (train_loss / (i + 1)), (train_accuracy / (i + 1)),
                      batches_passed, float(time_passed.microseconds),
                      epoch_num, i, data_loader_length, new_best)

        # check if the runtime budget has expired
        if (time_passed.total_seconds() >
                (self.arguments.max_training_minutes * 60)) \
                and self.arguments.max_training_minutes > 0:
            raise KeyboardInterrupt(
                f"Process killed because {self.arguments.max_training_minutes} minutes passed "
                f"since {DATA_MANAGER.actual_date}. Time now is {datetime.now()}"
            )

        if patience == 0:
            break

    return progress, best_metrics, patience
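# A minimal sketch of the patience-based early stopping that
# _epoch_iteration implements, reduced to plain Python so the control flow
# is easy to follow. `validation_losses` is made-up data; in the trainer the
# comparison is delegated to model.compare_metric and the model is
# checkpointed on every new best.
import math


def early_stopping_demo(validation_losses, initial_patience=3):
    best_loss = math.inf
    patience = initial_patience
    for step, loss in enumerate(validation_losses):
        if loss < best_loss:  # new best: checkpoint and reset patience
            best_loss = loss
            patience = initial_patience
        else:  # no improvement: burn one unit of patience
            patience -= 1
        if patience == 0:  # stop once patience is exhausted
            return step
    return len(validation_losses) - 1


# stops at index 5: three non-improving evaluations after the best at index 2
print(early_stopping_demo([1.0, 0.8, 0.5, 0.6, 0.7, 0.9]))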
def train(self) -> bool:
    """
    main training function

    :return: True if training ran to completion, False if killed by the user
    """

    # setup data output directories:
    setup_directories()
    save_codebase_of_run(self.arguments)

    # data gathering
    progress = []

    epoch = 0  # defined up front so the except-blocks below can reference it

    try:
        print(
            f"{PRINTCOLOR_BOLD}Started training with the following config:{PRINTCOLOR_END}\n{self.arguments}"
        )

        # run
        for epoch in range(self.arguments.epochs):
            print(
                f"\n\n{PRINTCOLOR_BOLD}Starting epoch{PRINTCOLOR_END} {epoch}/{self.arguments.epochs} at {str(datetime.now())}"
            )

            # do epoch
            epoch_progress = self.epoch_iteration(epoch)

            # add progress
            progress += epoch_progress

            # write progress to pickle file (overwrite: there is no point in keeping separate versions)
            DATA_MANAGER.save_python_obj(
                progress,
                f"{DATA_MANAGER.stamp}/{PROGRESS_DIR}/progress_list",
                print_success=False)

            # write models if needed (don't save the first one)
            if ((epoch + 1) % self.arguments.saving_freq) == 0:
                save_models(self.discriminator, self.generator, self.embedder,
                            f"Models_at_epoch_{epoch}")

            # flush prints
            sys.stdout.flush()

    except KeyboardInterrupt as e:
        print(f"Killed by user: {e}")
        save_models(self.discriminator, self.generator, self.embedder,
                    f"KILLED_at_epoch_{epoch}")
        return False
    except Exception as e:
        print(e)
        save_models(self.discriminator, self.generator, self.embedder,
                    f"CRASH_at_epoch_{epoch}")
        raise e

    # flush prints
    sys.stdout.flush()

    # final save
    save_models(self.discriminator, self.generator, self.embedder, "finished")
    return True
def train(self):
    # setup data output directories:
    if self.args.debug:
        print("\n \033[1;32m Note: DEBUG mode active!!!! \033[0m \n")
    else:
        setup_directories()

    # data gathering
    progress = []

    epoch = 0  # defined up front so the except-blocks below can reference it

    try:
        print(
            f"{PRINTCOLOR_BOLD}Started training with the following config:{PRINTCOLOR_END}\n{self.args}"
        )

        time_per_epoch = []
        avg_time_per_batch = []

        # run
        for epoch in range(self.args.epochs):
            if self.args.timing:
                epoch_start = time.process_time()

            print(
                f"\n\n{PRINTCOLOR_BOLD}Starting epoch{PRINTCOLOR_END} {epoch}/{self.args.epochs} at {str(datetime.now())}"
            )

            # do epoch
            epoch_progress, time_per_batch = self.epoch_iteration(epoch)

            # update learning rates
            self.gen_lr_sched.step()
            self.dis_lr_sched.step()
            self.calnet_lr_sched.step()

            # add progress
            progress += epoch_progress

            if not self.args.debug:
                # write models if needed (don't save the first one)
                if ((epoch + 1) % self.args.saving_freq) == 0:
                    save_models(self.discriminator, self.generator,
                                self.calibration_net,
                                f"Models_at_epoch_{epoch}")

            # flush prints
            sys.stdout.flush()

            if self.args.timing:
                epoch_end = time.process_time() - epoch_start
                time_per_epoch.append(epoch_end)
                avg_time_per_batch.append(np.mean(time_per_batch))

    except KeyboardInterrupt as e:
        print(f"Killed by user: {e}")
        if not self.args.debug:
            save_models(self.discriminator, self.generator,
                        self.calibration_net, f"KILLED_at_epoch_{epoch}")
        return False
    except Exception as e:
        print(e)
        if not self.args.debug:
            save_models(self.discriminator, self.generator,
                        self.calibration_net, f"CRASH_at_epoch_{epoch}")
        raise e

    # flush prints
    sys.stdout.flush()

    if not self.args.debug:
        save_models(self.discriminator, self.generator, self.calibration_net,
                    "finished")
    return True
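# A minimal sketch of the per-epoch scheduler stepping done above, using a
# single optimizer/scheduler pair instead of the three (generator,
# discriminator, calibration net) the trainer manages. StepLR is an
# assumption for illustration; the actual scheduler type is not shown in
# this snippet.
import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

for _epoch in range(30):
    # ... one full pass over the training data would go here ...
    optimizer.step()  # placeholder for the per-batch updates
    scheduler.step()  # decay the learning rate once per epoch

print(optimizer.param_groups[0]["lr"])  # 1e-3 * 0.5**3 = 1.25e-4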
def validation_plots(batch, generator, calibration_net, discriminator, args,
                     batch_idx=0):
    assert instance_checker(generator, GeneralGenerator)
    assert instance_checker(calibration_net, GeneralGenerator)
    assert instance_checker(discriminator, GeneralDiscriminator)

    if args.dataset == "LIDC":
        images, labels, gt_dist = unpack_batch(batch)
        gt_labels = None
    else:
        images, labels = unpack_batch(batch)
        gt_dist = None
        gt_labels = None

    if args.dataset == "CITYSCAPES19":
        bb_preds = batch["bb_preds"].to(DEVICE).float()
        bb_preds = torch.eye(LABELS_CHANNELS)[
            bb_preds[:, 1, :, :].long()].permute(0, 3, 1, 2).to(DEVICE)
        one_hot_labels = torch.eye(LABELS_CHANNELS)[
            labels[:, 1, :, :].long()].permute(0, 3, 1, 2).to(DEVICE)
        overlapped_mask = get_cs_ignore_mask(bb_preds, one_hot_labels)
    else:
        bb_preds = None
        overlapped_mask = None

    if args.dataset == "CITYSCAPES19" and args.class_flip:
        gt_labels = labels.clone()
        labels = torch.eye(LABELS_CHANNELS)[labels[:, 1, :, :].long()].permute(
            0, 3, 1, 2)

    calnet_preds, calnet_labelled_imgs, fake_labels, pred_dist, al_maps, gan_al_maps = test_forward_pass(
        images, labels, bb_preds, generator, calibration_net, discriminator,
        args)

    # save best calibration net
    if args.dataset == "LIDC":
        lab_dist = torch.eye(LABELS_CHANNELS)[gt_dist.long()].permute(
            1, 0, 4, 2, 3).to(DEVICE).mean(0)

        # pixelwise KL divergence KL(p || q), clamped for numerical stability
        eps = 1e-7
        kl = lambda p, q: (-p.clamp(min=eps, max=1 - eps) * torch.log(
            q.clamp(min=eps, max=1 - eps)) + p.clamp(min=eps, max=1 - eps) *
                           torch.log(p.clamp(min=eps, max=1 - eps))).sum(1)

        calnet_score = kl(calnet_preds.detach(), lab_dist).mean()

        if args.generator == "EmptyGenerator" and (not args.debug):
            wandb.log({"Calnet score": calnet_score})

        global BEST_CALNET_SCORE
        if args.mode == "train" and args.generator == "EmptyGenerator" and (
                not args.debug
        ) and calnet_score is not None and calnet_score < BEST_CALNET_SCORE:
            BEST_CALNET_SCORE = calnet_score
            print(
                f"{PRINTCOLOR_GREEN} Saved New Best Calibration Net! {PRINTCOLOR_END}"
            )
            save_models(discriminator, generator, calibration_net,
                        "Best_Model")

    # log stats
    ged = compute_stats(args,
                        generator,
                        images,
                        calnet_preds,
                        calnet_labelled_imgs,
                        fake_labels,
                        pred_dist,
                        gan_al_maps,
                        labels,
                        gt_dist,
                        gt_labels,
                        overlapped_mask,
                        b_index=batch_idx)

    global BEST_GED
    if args.mode == "train" and (
            not args.debug) and ged is not None and ged < BEST_GED:
        BEST_GED = ged
        print(f"{PRINTCOLOR_GREEN} Saved New Best Model! {PRINTCOLOR_END}")
        save_models(discriminator, generator, calibration_net, "Best_Model")

    # Plots
    comparison_figure = plot_comparison_figure(batch, calnet_preds,
                                               fake_labels, al_maps,
                                               gan_al_maps, generator,
                                               calibration_net, discriminator,
                                               args)

    if args.dataset == "CAMVID" or args.dataset == "CITYSCAPES19":
        calibration_figure = plot_calibration_figure(labels, calnet_preds,
                                                     pred_dist,
                                                     overlapped_mask, args)

    if instance_checker(generator, GeneralVAE):
        plotted_samples = generator.plot_sample_preds(
            images,
            labels,
            calnet_preds,
            pred_dist,
            gt_dist,
            n_preds=args.n_generator_samples_test,
            dataset=args.dataset)

    if not args.debug:
        # save and log
        comparison_figure = torch.from_numpy(
            np.moveaxis(comparison_figure, -1, 0)).float()
        save_example_images(comparison_figure, batch_idx, "comparison", "png")
        wandb.log({
            "Results":
            wandb.Image(vutils.make_grid(comparison_figure, normalize=True))
        })

        if args.dataset == "CITYSCAPES19":
            calibration_figure = torch.from_numpy(
                np.moveaxis(calibration_figure, -1, 0)).float()
            wandb.log({
                "Calibration":
                wandb.Image(
                    vutils.make_grid(calibration_figure, normalize=True))
            })

        if instance_checker(generator, GeneralVAE):
            plotted_samples = torch.from_numpy(
                np.moveaxis(plotted_samples, -1, 0)).float()
            wandb.log({
                "Plotted samples":
                wandb.Image(vutils.make_grid(plotted_samples, normalize=True))
            })
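# A minimal check of the clamped pixelwise KL divergence used for the
# calibration score above: KL(p || q) = sum_c p_c * (log p_c - log q_c),
# summed over the channel dimension. The tensors here are illustrative
# single-pixel "images"; in the function above p is the calibration net's
# prediction and q is the label distribution averaged over annotators.
import torch

eps = 1e-7
kl = lambda p, q: (-p.clamp(min=eps, max=1 - eps) * torch.log(
    q.clamp(min=eps, max=1 - eps)) + p.clamp(min=eps, max=1 - eps) * torch.log(
        p.clamp(min=eps, max=1 - eps))).sum(1)

# two 3-channel distributions over a single pixel, shape (1, 3, 1, 1)
p = torch.tensor([[[[0.7]], [[0.2]], [[0.1]]]])
q = torch.tensor([[[[0.5]], [[0.3]], [[0.2]]]])

print(kl(p, p).item())  # ~0.0: identical distributions have zero divergence
print(kl(p, q).item())  # > 0: divergence grows as q moves away from p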