# ----------------------------------------------------------------------------
# Module-level imports assumed by the functions below. Project-local helpers
# (logger, tools, utils, networks_3d, CheckpointSaver, ModelAndLoss,
# TrainingEpoch, EvaluationEpoch, create_progressbar, format_learning_rate,
# kl_criterion, calc_gradient_penalty, _parse_arguments, postprocess_args,
# _generate_trainable_params, _param_names_and_trainable_generator) come from
# the surrounding package and are not re-imported here.
# ----------------------------------------------------------------------------
import logging
import os
import random

import colorama
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

# Colorama shortcuts used by the logging statements below (assumed definitions)
clear = colorama.Style.RESET_ALL
green = colorama.Fore.GREEN
magenta = colorama.Fore.MAGENTA


def configure_lr_scheduler(args, optimizer):
    lr_scheduler = None
    with logger.LoggingBlock("Learning Rate Scheduler", emph=True):
        logging.info("class: %s" % args.lr_scheduler)
        if args.lr_scheduler is not None:
            # ----------------------------------------------
            # Figure out lr_scheduler arguments
            # ----------------------------------------------
            kwargs = tools.kwargs_from_args(args, "lr_scheduler")

            # -------------------------------------------
            # Print arguments
            # -------------------------------------------
            for param, default in sorted(kwargs.items()):
                logging.info("%s: %s" % (param, default))

            # -------------------------------------------
            # Add optimizer
            # -------------------------------------------
            kwargs["optimizer"] = optimizer

            # -------------------------------------------
            # Create lr_scheduler instance
            # -------------------------------------------
            lr_scheduler = tools.instance_from_kwargs(args.lr_scheduler_class, kwargs)

    return lr_scheduler

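# ----------------------------------------------------------------------------
# Usage sketch (an assumption about the flag grammar, for illustration only):
# the "--lr_scheduler_[param]=[value]" convention consumed by
# kwargs_from_args() above would resolve a commandline such as
#
#   --lr_scheduler=MultiStepLR --lr_scheduler_milestones=[54,72,90] \
#   --lr_scheduler_gamma=0.5
#
# into the equivalent of the direct construction:
#
#   lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
#       optimizer, milestones=[54, 72, 90], gamma=0.5)
# ----------------------------------------------------------------------------
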
def setup_logging_and_parse_arguments(blocktitle, yaml_conf=None):
    # ----------------------------------------------------------------------------
    # Get parsed commandline and default arguments
    # ----------------------------------------------------------------------------
    args, defaults = _parse_arguments(yaml_conf=yaml_conf)

    # ----------------------------------------------------------------------------
    # Setup logbook before everything else
    # ----------------------------------------------------------------------------
    logger.configure_logging(os.path.join(args.save, 'logbook.txt'))

    # ----------------------------------------------------------------------------
    # Write arguments to file, as txt
    # ----------------------------------------------------------------------------
    tools.write_dictionary_to_file(
        sorted(vars(args).items()),
        filename=os.path.join(args.save, 'args.txt'))

    # ----------------------------------------------------------------------------
    # Log arguments; values differing from the defaults are highlighted in cyan
    # ----------------------------------------------------------------------------
    with logger.LoggingBlock(blocktitle, emph=True):
        for argument, value in sorted(vars(args).items()):
            reset = colorama.Style.RESET_ALL
            color = reset if (argument in defaults and value == defaults[argument]) else colorama.Fore.CYAN
            logging.info('{}{}: {}{}'.format(color, argument, value, reset))

    # ----------------------------------------------------------------------------
    # Postprocess
    # ----------------------------------------------------------------------------
    args = postprocess_args(args)

    return args

def configure_random_seed(args):
    with logger.LoggingBlock("Random Seeds", emph=True):
        # python
        seed = args.seed
        random.seed(seed)
        logging.info("Python seed: %i" % seed)
        # numpy
        seed += 1
        np.random.seed(seed)
        logging.info("Numpy seed: %i" % seed)
        # torch
        seed += 1
        torch.manual_seed(seed)
        logging.info("Torch CPU seed: %i" % seed)
        # torch cuda
        seed += 1
        torch.cuda.manual_seed(seed)
        logging.info("Torch CUDA seed: %i" % seed)

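# ----------------------------------------------------------------------------
# Aside (an addition, not in the original): seeding alone does not make CUDA
# runs bit-exact. For fully deterministic cuDNN behaviour one would also set
#
#   torch.backends.cudnn.deterministic = True
#   torch.backends.cudnn.benchmark = False
#
# at some cost in speed.
# ----------------------------------------------------------------------------
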
def load_checkpoint_saver(args, model_and_loss):
    with logger.LoggingBlock("Checkpoint", emph=True):
        checkpoint_saver = CheckpointSaver()
        checkpoint_stats = None

        if args.checkpoint is None:
            logging.info("No checkpoint given.")
            logging.info("Starting from scratch with random initialization.")

        elif os.path.isfile(args.checkpoint):
            logging.info("Loading checkpoint in %s" % args.checkpoint)
            checkpoint_stats, filename = checkpoint_saver.restore(
                filename=args.checkpoint,
                model_and_loss=model_and_loss,
                include_params=args.checkpoint_include_params,
                exclude_params=args.checkpoint_exclude_params)

        elif os.path.isdir(args.checkpoint):
            if args.checkpoint_mode in ["resume_from_best"]:
                logging.info("Loading best checkpoint in %s" % args.checkpoint)
                checkpoint_stats, filename = checkpoint_saver.restore_best(
                    directory=args.checkpoint,
                    model_and_loss=model_and_loss,
                    include_params=args.checkpoint_include_params,
                    exclude_params=args.checkpoint_exclude_params)

            elif args.checkpoint_mode in ["resume_from_latest"]:
                logging.info("Loading latest checkpoint in %s" % args.checkpoint)
                checkpoint_stats, filename = checkpoint_saver.restore_latest(
                    directory=args.checkpoint,
                    model_and_loss=model_and_loss,
                    include_params=args.checkpoint_include_params,
                    exclude_params=args.checkpoint_exclude_params)
            else:
                logging.info("Unknown checkpoint_mode '%s' given!" % args.checkpoint_mode)
                quit()

        else:
            logging.info("Could not find checkpoint file or directory '%s'" % args.checkpoint)
            quit()

    return checkpoint_saver, checkpoint_stats

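# ----------------------------------------------------------------------------
# Usage sketch (paths are hypothetical, flag names come from the code above):
# args.checkpoint may point at a single snapshot file or at a directory of
# snapshots, e.g.
#
#   --checkpoint=experiments/run1/checkpoint_best.ckpt          # -> restore()
#   --checkpoint=experiments/run1 --checkpoint_mode=resume_from_best
#   --checkpoint=experiments/run1 --checkpoint_mode=resume_from_latest
# ----------------------------------------------------------------------------
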
def configure_runtime_augmentations(args):
    with logger.LoggingBlock("Runtime Augmentations", emph=True):
        training_augmentation = None
        validation_augmentation = None

        # ----------------------------------------------------
        # Training Augmentation
        # ----------------------------------------------------
        if args.training_augmentation is not None:
            kwargs = tools.kwargs_from_args(args, "training_augmentation")
            logging.info("training_augmentation: %s" % args.training_augmentation)
            for param, default in sorted(kwargs.items()):
                logging.info("  %s: %s" % (param, default))
            kwargs["args"] = args
            training_augmentation = tools.instance_from_kwargs(
                args.training_augmentation_class, kwargs)
            if args.cuda:
                training_augmentation = training_augmentation.cuda()
        else:
            logging.info("training_augmentation: None")

        # ----------------------------------------------------
        # Validation Augmentation
        # ----------------------------------------------------
        if args.validation_augmentation is not None:
            kwargs = tools.kwargs_from_args(args, "validation_augmentation")
            logging.info("validation_augmentation: %s" % args.validation_augmentation)
            for param, default in sorted(kwargs.items()):
                logging.info("  %s: %s" % (param, default))
            kwargs["args"] = args
            validation_augmentation = tools.instance_from_kwargs(
                args.validation_augmentation_class, kwargs)
            if args.cuda:
                validation_augmentation = validation_augmentation.cuda()
        else:
            logging.info("validation_augmentation: None")

    return training_augmentation, validation_augmentation

def configure_optimizer(args, model_and_loss):
    optimizer = None
    with logger.LoggingBlock("Optimizer", emph=True):
        if args.optimizer is not None:
            if model_and_loss.num_parameters() == 0:
                logging.info("No trainable parameters detected.")
                logging.info("Setting optimizer to None.")
            else:
                logging.info(args.optimizer)

                # -------------------------------------------
                # Figure out all optimizer arguments
                # -------------------------------------------
                all_kwargs = tools.kwargs_from_args(args, "optimizer")

                # -------------------------------------------
                # Get the split of param groups
                # -------------------------------------------
                kwargs_without_groups = {
                    key: value
                    for key, value in all_kwargs.items() if key != "group"
                }
                param_groups = all_kwargs["group"]

                # ----------------------------------------------------------------------
                # Print arguments (without groups)
                # ----------------------------------------------------------------------
                for param, default in sorted(kwargs_without_groups.items()):
                    logging.info("%s: %s" % (param, default))

                # ----------------------------------------------------------------------
                # Construct actual optimizer params
                # ----------------------------------------------------------------------
                kwargs = dict(kwargs_without_groups)
                if param_groups is None:
                    # ---------------------------------------------------------
                    # Add all trainable parameters if there are no param groups
                    # ---------------------------------------------------------
                    all_trainable_parameters = _generate_trainable_params(model_and_loss)
                    kwargs["params"] = all_trainable_parameters
                else:
                    # -------------------------------------------
                    # Add list of parameter groups instead
                    # -------------------------------------------
                    trainable_parameter_groups = []
                    dnames, dparams = _param_names_and_trainable_generator(model_and_loss)
                    dnames = set(dnames)
                    dparams = set(list(dparams))
                    with logger.LoggingBlock("parameter_groups:"):
                        for group in param_groups:
                            # log group settings
                            group_match = group["params"]
                            group_args = {
                                key: value
                                for key, value in group.items() if key != "params"
                            }
                            with logger.LoggingBlock("%s: %s" % (group_match, group_args)):
                                # retrieve parameters by matching name
                                gnames, gparams = _param_names_and_trainable_generator(
                                    model_and_loss, match=group_match)
                                # log all names affected
                                for n in sorted(gnames):
                                    logging.info(n)
                                # set generator for group
                                group_args["params"] = gparams
                                # append parameter group
                                trainable_parameter_groups.append(group_args)
                                # update remaining trainable parameters
                                dnames -= set(gnames)
                                dparams -= set(list(gparams))

                        # append default parameter group
                        trainable_parameter_groups.append({"params": list(dparams)})
                        # and log its parameter names
                        with logger.LoggingBlock("default:"):
                            for dname in sorted(dnames):
                                logging.info(dname)

                    # set params in optimizer kwargs
                    kwargs["params"] = trainable_parameter_groups

                # -------------------------------------------
                # Create optimizer instance
                # -------------------------------------------
                optimizer = tools.instance_from_kwargs(args.optimizer_class, kwargs)

    return optimizer

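# ----------------------------------------------------------------------------
# Illustration (an assumption about the group-spec grammar, for the reader):
# a commandline such as
#
#   --optimizer=Adam --optimizer_lr=1e-4 \
#   --optimizer_group=[{"params": "*refinement*", "lr": 1e-5}]
#
# would make the code above hand torch.optim.Adam a params list of the form
#
#   [{"params": <parameters whose names match "*refinement*">, "lr": 1e-5},
#    {"params": <all remaining trainable parameters>}]
#
# with lr=1e-4 applying to any group that does not override it.
# ----------------------------------------------------------------------------
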
def configure_data_loaders(args):
    with logger.LoggingBlock("Datasets", emph=True):

        def _sizes_to_str(value):
            if np.isscalar(value):
                return '[1L]'
            else:
                return ' '.join([str([d for d in value.size()])])

        def _log_statistics(dataset, prefix, name):
            with logger.LoggingBlock("%s Dataset: %s" % (prefix, name)):
                example_dict = dataset[0]  # get sizes from first dataset example
                for key, value in sorted(example_dict.items()):
                    if key in ["index", "basename"]:  # no need to display these
                        continue
                    if isinstance(value, str):
                        logging.info("{}: {}".format(key, value))
                    else:
                        logging.info("%s: %s" % (key, _sizes_to_str(value)))
                logging.info("num_examples: %i" % len(dataset))

        # -----------------------------------------------------------------------------------------
        # GPU loader parameters; pin_memory stays off here, apparently to work
        # around a DataLoader deadlock
        # -----------------------------------------------------------------------------------------
        gpuargs = {
            "num_workers": args.num_workers,
            "pin_memory": False
        } if args.cuda else {}

        train_loader = None
        validation_loader = None
        inference_loader = None

        # -----------------------------------------------------------------------------------------
        # Training dataset
        # -----------------------------------------------------------------------------------------
        if args.training_dataset is not None:
            # ----------------------------------------------
            # Figure out training_dataset arguments
            # ----------------------------------------------
            kwargs = tools.kwargs_from_args(args, "training_dataset")
            kwargs["is_cropped"] = True
            kwargs["args"] = args

            # ----------------------------------------------
            # Create training dataset
            # ----------------------------------------------
            train_dataset = tools.instance_from_kwargs(args.training_dataset_class, kwargs)

            # ----------------------------------------------
            # Create training loader
            # ----------------------------------------------
            train_loader = DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      drop_last=False,
                                      **gpuargs)

            _log_statistics(train_dataset, prefix="Training", name=args.training_dataset)

        # -----------------------------------------------------------------------------------------
        # Validation dataset
        # -----------------------------------------------------------------------------------------
        if args.validation_dataset is not None:
            # ----------------------------------------------
            # Figure out validation_dataset arguments
            # ----------------------------------------------
            kwargs = tools.kwargs_from_args(args, "validation_dataset")
            kwargs["is_cropped"] = True
            kwargs["args"] = args

            # ----------------------------------------------
            # Create validation dataset
            # ----------------------------------------------
            validation_dataset = tools.instance_from_kwargs(args.validation_dataset_class, kwargs)

            # ----------------------------------------------
            # Create validation loader
            # ----------------------------------------------
            validation_loader = DataLoader(validation_dataset,
                                           batch_size=args.batch_size_val,
                                           shuffle=False,
                                           drop_last=False,
                                           **gpuargs)

            _log_statistics(validation_dataset, prefix="Validation", name=args.validation_dataset)

    return train_loader, validation_loader, inference_loader

def configure_model_and_loss(args):
    # ----------------------------------------------------
    # Dynamically load model and loss class with parameters
    # passed in via "--model_[param]=[value]" or "--loss_[param]=[value]" arguments
    # ----------------------------------------------------
    with logger.LoggingBlock("Model and Loss", emph=True):

        # ----------------------------------------------------
        # Model
        # ----------------------------------------------------
        kwargs = tools.kwargs_from_args(args, "model")
        kwargs["args"] = args
        if isinstance(args.checkpoint, list) and len(args.checkpoint) > 1:
            models = nn.ModuleList([
                tools.instance_from_kwargs(args.model_class, kwargs)
                for _ in args.checkpoint
            ])
        else:
            models = tools.instance_from_kwargs(args.model_class, kwargs)

        if hasattr(args, 'avoid_list') and args.avoid_list:
            # handle both the single-model and the ModuleList case
            model_list = models if isinstance(models, nn.ModuleList) else [models]
            for model in model_list:
                model.avoid_list = args.avoid_list.split(',')

        # ----------------------------------------------------
        # Training loss
        # ----------------------------------------------------
        training_loss = None
        if args.training_loss is not None:
            kwargs = tools.kwargs_from_args(args, "training_loss")
            kwargs["args"] = args
            training_loss = tools.instance_from_kwargs(args.training_loss_class, kwargs)

        # ----------------------------------------------------
        # Validation loss
        # ----------------------------------------------------
        validation_loss = None
        if args.validation_loss is not None:
            kwargs = tools.kwargs_from_args(args, "validation_loss")
            kwargs["args"] = args
            validation_loss = tools.instance_from_kwargs(args.validation_loss_class, kwargs)

        # ----------------------------------------------------
        # Model and loss
        # ----------------------------------------------------
        model_and_loss = ModelAndLoss(args, models, training_loss, validation_loss)

        # -----------------------------------------------------------
        # If Cuda, transfer model to Cuda and wrap with DataParallel.
        # -----------------------------------------------------------
        if args.cuda:
            model_and_loss = model_and_loss.cuda()

        # ---------------------------------------------------------------
        # Report some network statistics
        # ---------------------------------------------------------------
        logging.info("Batch Size: %i" % args.batch_size)
        logging.info("GPGPU: Cuda" if args.cuda else "GPGPU: off")
        logging.info("Network: %s" % args.model)
        logging.info("Number of parameters: %i" %
                     tools.x2module(model_and_loss).num_parameters())
        if training_loss is not None:
            logging.info("Training Key: %s" % args.training_key)
            logging.info("Training Loss: %s" % args.training_loss)
        if validation_loss is not None:
            logging.info("Validation Key: %s" % args.validation_key)
            logging.info("Validation Loss: %s" % args.validation_loss)

    return model_and_loss

def train(opt, netG):
    # Re-generate dataset frames
    fps, td, fps_index = utils.get_fps_td_by_index(opt.scale_idx, opt)
    opt.fps = fps
    opt.td = td
    opt.fps_index = fps_index

    with logger.LoggingBlock("Updating dataset", emph=True):
        logging.info("{}FPS :{} {}{}".format(green, clear, opt.fps, clear))
        logging.info("{}Time-Depth :{} {}{}".format(green, clear, opt.td, clear))
        logging.info("{}Sampling-Ratio :{} {}{}".format(
            green, clear, opt.sampling_rates[opt.fps_index], clear))
        opt.dataset.generate_frames(opt.scale_idx)

    # Initialize noise
    if not hasattr(opt, 'Z_init_size'):
        initial_size = utils.get_scales_by_index(0, opt.scale_factor,
                                                 opt.stop_scale, opt.img_size)
        initial_size = [int(initial_size * opt.ar), initial_size]
        opt.Z_init_size = [opt.batch_size, opt.latent_dim, opt.td, *initial_size]

    # A discriminator is only needed once training has moved past the VAE levels
    if opt.vae_levels < opt.scale_idx + 1:
        D_curr = getattr(networks_3d, opt.discriminator)(opt).to(opt.device)

        if (opt.netG != '') and (opt.resumed_idx == opt.scale_idx):
            D_curr.load_state_dict(
                torch.load('{}/netD_{}.pth'.format(
                    opt.resume_dir, opt.scale_idx - 1))['state_dict'])
        elif opt.vae_levels < opt.scale_idx:
            D_curr.load_state_dict(
                torch.load('{}/netD_{}.pth'.format(
                    opt.saver.experiment_dir, opt.scale_idx - 1))['state_dict'])

        # Current optimizers
        optimizerD = optim.Adam(D_curr.parameters(),
                                lr=opt.lr_d,
                                betas=(opt.beta1, 0.999))

    parameter_list = []
    # Generator Adversary
    if not opt.train_all:
        if opt.vae_levels < opt.scale_idx + 1:
            train_depth = min(opt.train_depth, len(netG.body) - opt.vae_levels + 1)
            parameter_list += [
                {"params": block.parameters(),
                 "lr": opt.lr_g * (opt.lr_scale ** (len(netG.body[-train_depth:]) - 1 - idx))}
                for idx, block in enumerate(netG.body[-train_depth:])
            ]
        else:
            # VAE
            parameter_list += [
                {"params": netG.encode.parameters(),
                 "lr": opt.lr_g * (opt.lr_scale ** opt.scale_idx)},
                {"params": netG.decoder.parameters(),
                 "lr": opt.lr_g * (opt.lr_scale ** opt.scale_idx)}
            ]
            parameter_list += [
                {"params": block.parameters(),
                 "lr": opt.lr_g * (opt.lr_scale ** (len(netG.body[-opt.train_depth:]) - 1 - idx))}
                for idx, block in enumerate(netG.body[-opt.train_depth:])
            ]
    else:
        if len(netG.body) < opt.train_depth:
            parameter_list += [
                {"params": netG.encode.parameters(),
                 "lr": opt.lr_g * (opt.lr_scale ** opt.scale_idx)},
                {"params": netG.decoder.parameters(),
                 "lr": opt.lr_g * (opt.lr_scale ** opt.scale_idx)}
            ]
            parameter_list += [
                {"params": block.parameters(),
                 "lr": opt.lr_g * (opt.lr_scale ** (len(netG.body) - 1 - idx))}
                for idx, block in enumerate(netG.body)
            ]
        else:
            parameter_list += [
                {"params": block.parameters(),
                 "lr": opt.lr_g * (opt.lr_scale ** (len(netG.body[-opt.train_depth:]) - 1 - idx))}
                for idx, block in enumerate(netG.body[-opt.train_depth:])
            ]

    optimizerG = optim.Adam(parameter_list, lr=opt.lr_g, betas=(opt.beta1, 0.999))

    # Parallel
    if opt.device == 'cuda':
        G_curr = torch.nn.DataParallel(netG)
        if opt.vae_levels < opt.scale_idx + 1:
            D_curr = torch.nn.DataParallel(D_curr)
    else:
        G_curr = netG

    progressbar_args = {
        "iterable": range(opt.niter),
        "desc": "Training scale [{}/{}]".format(opt.scale_idx + 1, opt.stop_scale + 1),
        "train": True,
        "offset": 0,
        "logging_on_update": False,
        "logging_on_close": True,
        "postfix": True
    }
    epoch_iterator = tools.create_progressbar(**progressbar_args)

    iterator = iter(opt.data_loader)

    for iteration in epoch_iterator:
        try:
            data = next(iterator)
        except StopIteration:
            iterator = iter(opt.data_loader)
            data = next(iterator)

        if opt.scale_idx > 0:
            real, real_zero = data
            real = real.to(opt.device)
            real_zero = real_zero.to(opt.device)
        else:
            real = data.to(opt.device)
            real_zero = real

        noise_init = utils.generate_noise(size=opt.Z_init_size, device=opt.device)

        ############################
        # calculate noise_amp
        ###########################
        if iteration == 0:
            if opt.const_amp:
                opt.Noise_Amps.append(1)
            else:
                with torch.no_grad():
                    if opt.scale_idx == 0:
                        opt.noise_amp = 1
                        opt.Noise_Amps.append(opt.noise_amp)
                    else:
                        opt.Noise_Amps.append(0)
                        z_reconstruction, _, _ = G_curr(real_zero, opt.Noise_Amps, mode="rec")

                        RMSE = torch.sqrt(F.mse_loss(real, z_reconstruction))
                        opt.noise_amp = opt.noise_amp_init * RMSE.item() / opt.batch_size
                        opt.Noise_Amps[-1] = opt.noise_amp

        ############################
        # (1) Update VAE network
        ###########################
        total_loss = 0

        generated, generated_vae, (mu, logvar) = G_curr(real_zero, opt.Noise_Amps, mode="rec")

        if opt.vae_levels >= opt.scale_idx + 1:
            rec_vae_loss = opt.rec_loss(generated, real) + opt.rec_loss(generated_vae, real_zero)
            kl_loss = kl_criterion(mu, logvar)
            vae_loss = opt.rec_weight * rec_vae_loss + opt.kl_weight * kl_loss

            total_loss += vae_loss
        else:
            ############################
            # (2) Update D network: maximize D(x) + D(G(z))
            ###########################
            # train with real
            #################
            # Train 3D Discriminator
            D_curr.zero_grad()
            output = D_curr(real)
            errD_real = -output.mean()

            # train with fake
            #################
            fake, _ = G_curr(noise_init, opt.Noise_Amps, noise_init=noise_init, mode="rand")

            # Train 3D Discriminator
            output = D_curr(fake.detach())
            errD_fake = output.mean()

            gradient_penalty = calc_gradient_penalty(D_curr, real, fake,
                                                     opt.lambda_grad, opt.device)

            errD_total = errD_real + errD_fake + gradient_penalty
            errD_total.backward()
            optimizerD.step()

            ############################
            # (3) Update G network: maximize D(G(z))
            ###########################
            errG_total = 0
            rec_loss = opt.rec_loss(generated, real)
            errG_total += opt.rec_weight * rec_loss

            # Train with 3D Discriminator
            output = D_curr(fake)
            errG = -output.mean() * opt.disc_loss_weight
            errG_total += errG

            total_loss += errG_total

        G_curr.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(G_curr.parameters(), opt.grad_clip)
        optimizerG.step()

        # Update progress bar
        epoch_iterator.set_description('Scale [{}/{}], Iteration [{}/{}]'.format(
            opt.scale_idx + 1, opt.stop_scale + 1, iteration + 1, opt.niter))

        if opt.visualize:
            # Tensorboard
            opt.summary.add_scalar('Video/Scale {}/noise_amp'.format(opt.scale_idx),
                                   opt.noise_amp, iteration)
            if opt.vae_levels >= opt.scale_idx + 1:
                opt.summary.add_scalar('Video/Scale {}/KLD'.format(opt.scale_idx),
                                       kl_loss.item(), iteration)
            else:
                opt.summary.add_scalar('Video/Scale {}/rec loss'.format(opt.scale_idx),
                                       rec_loss.item(), iteration)
            if opt.vae_levels < opt.scale_idx + 1:
                opt.summary.add_scalar('Video/Scale {}/errG'.format(opt.scale_idx),
                                       errG.item(), iteration)
                opt.summary.add_scalar('Video/Scale {}/errD_fake'.format(opt.scale_idx),
                                       errD_fake.item(), iteration)
                opt.summary.add_scalar('Video/Scale {}/errD_real'.format(opt.scale_idx),
                                       errD_real.item(), iteration)
            else:
                opt.summary.add_scalar('Video/Scale {}/Rec VAE'.format(opt.scale_idx),
                                       rec_vae_loss.item(), iteration)

            if iteration % opt.print_interval == 0:
                with torch.no_grad():
                    fake_var = []
                    fake_vae_var = []
                    for _ in range(3):
                        noise_init = utils.generate_noise(ref=noise_init)
                        fake, fake_vae = G_curr(noise_init, opt.Noise_Amps,
                                                noise_init=noise_init, mode="rand")
                        fake_var.append(fake)
                        fake_vae_var.append(fake_vae)
                    fake_var = torch.cat(fake_var, dim=0)
                    fake_vae_var = torch.cat(fake_vae_var, dim=0)

                opt.summary.visualize_video(opt, iteration, real, 'Real')
                opt.summary.visualize_video(opt, iteration, generated, 'Generated')
                opt.summary.visualize_video(opt, iteration, generated_vae, 'Generated VAE')
                opt.summary.visualize_video(opt, iteration, fake_var, 'Fake var')
                opt.summary.visualize_video(opt, iteration, fake_vae_var, 'Fake VAE var')

    epoch_iterator.close()

    # Save data
    opt.saver.save_checkpoint({'data': opt.Noise_Amps}, 'Noise_Amps.pth')
    opt.saver.save_checkpoint({
        'scale': opt.scale_idx,
        'state_dict': netG.state_dict(),
        'optimizer': optimizerG.state_dict(),
        'noise_amps': opt.Noise_Amps,
    }, 'netG.pth')
    if opt.vae_levels < opt.scale_idx + 1:
        opt.saver.save_checkpoint({
            'scale': opt.scale_idx,
            'state_dict': D_curr.module.state_dict()
            if opt.device == 'cuda' else D_curr.state_dict(),
            'optimizer': optimizerD.state_dict(),
        }, 'netD_{}.pth'.format(opt.scale_idx))

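# ----------------------------------------------------------------------------
# For reference, a hedged sketch of the WGAN-GP term that calc_gradient_penalty
# above presumably computes (Gulrajani et al., 2017), written for the 5-D video
# tensors (batch, channels, time, height, width) used in this file. The
# project's real implementation lives elsewhere in the repo; this stand-in is
# illustrative only.
# ----------------------------------------------------------------------------
def _gradient_penalty_sketch(netD, real, fake, lambda_grad, device):
    # random interpolation factor per sample, broadcast over C, T, H, W
    alpha = torch.rand(real.size(0), 1, 1, 1, 1, device=device)
    interpolates = (alpha * real + (1 - alpha) * fake.detach()).requires_grad_(True)
    disc_out = netD(interpolates)
    # gradients of the critic output w.r.t. the interpolated inputs
    grads = torch.autograd.grad(outputs=disc_out,
                                inputs=interpolates,
                                grad_outputs=torch.ones_like(disc_out),
                                create_graph=True,
                                retain_graph=True)[0]
    # penalize deviation of the per-sample gradient norm from 1
    grad_norm = grads.view(grads.size(0), -1).norm(2, dim=1)
    return lambda_grad * ((grad_norm - 1) ** 2).mean()
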
# (fragment of the experiment setup script; its beginning is missing here and
# the DataLoader construction below is completed minimally from context)
data_loader = DataLoader(dataset,
                         batch_size=opt.batch_size,
                         num_workers=4)

if opt.stop_scale_time == -1:
    opt.stop_scale_time = opt.stop_scale

opt.dataset = dataset
opt.data_loader = data_loader

with open(os.path.join(opt.saver.experiment_dir, 'args.txt'), 'w') as args_file:
    for argument, value in sorted(vars(opt).items()):
        if type(value) in (str, int, float, tuple, list, bool):
            args_file.write('{}: {}\n'.format(argument, value))

with logger.LoggingBlock("Commandline Arguments", emph=True):
    for argument, value in sorted(vars(opt).items()):
        if type(value) in (str, int, float, tuple, list):
            logging.info('{}: {}'.format(argument, value))

with logger.LoggingBlock("Experiment Summary", emph=True):
    video_file_name, checkname, experiment = opt.saver.experiment_dir.split('/')[-3:]
    logging.info("{}Video file :{} {}{}".format(magenta, clear, video_file_name, clear))
    logging.info("{}Checkname  :{} {}{}".format(magenta, clear, checkname, clear))
    logging.info("{}Experiment :{} {}{}".format(magenta, clear, experiment, clear))

    with logger.LoggingBlock("Commandline Summary", emph=True):

def exec_runtime(args, checkpoint_saver, model_and_loss, optimizer,
                 lr_scheduler, train_loader, validation_loader,
                 inference_loader, training_augmentation,
                 validation_augmentation):
    # ----------------------------------------------------------------------------------------------
    # Validation schedulers are a bit special:
    # they want to be stepped with a validation loss.
    # ----------------------------------------------------------------------------------------------
    validation_scheduler = (lr_scheduler is not None
                            and args.lr_scheduler == "ReduceLROnPlateau")

    # --------------------------------------------------------
    # Log some runtime info
    # --------------------------------------------------------
    with logger.LoggingBlock("Runtime", emph=True):
        logging.info("start_epoch: %i" % args.start_epoch)
        logging.info("total_epochs: %i" % args.total_epochs)

    # ---------------------------------------
    # Total progress bar arguments
    # ---------------------------------------
    progressbar_args = {
        "desc": "Progress",
        "initial": args.start_epoch - 1,
        "invert_iterations": True,
        "iterable": range(1, args.total_epochs + 1),
        "logging_on_close": True,
        "logging_on_update": True,
        "postfix": False,
        "unit": "ep"
    }

    # --------------------------------------------------------
    # Total progress bar
    # --------------------------------------------------------
    print('')
    logging.logbook('')
    total_progress = create_progressbar(**progressbar_args)
    print("\n")

    # --------------------------------------------------------
    # Remember validation loss
    # --------------------------------------------------------
    best_validation_loss = float("inf") if args.validation_key_minimize else -float("inf")
    store_as_best = False

    for epoch in range(args.start_epoch, args.total_epochs + 1):
        with logger.LoggingBlock("Epoch %i/%i" % (epoch, args.total_epochs), emph=True):

            # --------------------------------------------------------
            # Update standard learning scheduler
            # --------------------------------------------------------
            if lr_scheduler is not None and not validation_scheduler:
                lr_scheduler.step(epoch)

            # --------------------------------------------------------
            # Always report learning rate
            # --------------------------------------------------------
            if lr_scheduler is None:
                logging.info("lr: %s" % format_learning_rate(args.optimizer_lr))
            else:
                logging.info("lr: %s" % format_learning_rate(lr_scheduler.get_lr()))

            # -------------------------------------------
            # Create and run a training epoch
            # -------------------------------------------
            if train_loader is not None:
                avg_loss_dict = TrainingEpoch(
                    args,
                    desc="   Train",
                    model_and_loss=model_and_loss,
                    optimizer=optimizer,
                    loader=train_loader,
                    augmentation=training_augmentation).run()

            # -------------------------------------------
            # Create and run a validation epoch
            # -------------------------------------------
            if validation_loader is not None:
                # ---------------------------------------------------
                # Construct holistic recorder for epoch
                # ---------------------------------------------------
                avg_loss_dict = EvaluationEpoch(
                    args,
                    desc="Validate",
                    model_and_loss=model_and_loss,
                    loader=validation_loader,
                    augmentation=validation_augmentation).run()

                # ----------------------------------------------------------------
                # Evaluate whether this is the best validation_loss
                # ----------------------------------------------------------------
                validation_loss = avg_loss_dict[args.validation_key]
                if args.validation_key_minimize:
                    store_as_best = validation_loss < best_validation_loss
                else:
                    store_as_best = validation_loss > best_validation_loss
                if store_as_best:
                    best_validation_loss = validation_loss

                # ----------------------------------------------------------------
                # Update validation scheduler, if one is in place
                # ----------------------------------------------------------------
                if lr_scheduler is not None and validation_scheduler:
                    lr_scheduler.step(validation_loss, epoch=epoch)

                # ----------------------------------------------------------------
                # Also show best loss on total_progress
                # ----------------------------------------------------------------
                total_progress_stats = {
                    "best_" + args.validation_key + "_avg": "%1.4f" % best_validation_loss
                }
                total_progress.set_postfix(total_progress_stats)

            # ----------------------------------------------------------------
            # Bump total progress
            # ----------------------------------------------------------------
            total_progress.update()
            print('')

            # ----------------------------------------------------------------
            # Store checkpoint
            # ----------------------------------------------------------------
            if checkpoint_saver is not None:
                if args.max_save and (args.min_save <= epoch <= args.max_save):
                    checkpoint_saver.save_latest(
                        directory=args.save,
                        model_and_loss=model_and_loss,
                        stats_dict=dict(avg_loss_dict, epoch=epoch),
                        suffix="_epoch_{}".format(epoch))
                checkpoint_saver.save_latest(
                    directory=args.save,
                    model_and_loss=model_and_loss,
                    stats_dict=dict(avg_loss_dict, epoch=epoch),
                    store_as_best=store_as_best)

            # ----------------------------------------------------------------
            # Vertical space between epochs
            # ----------------------------------------------------------------
            print('')
            logging.logbook('')

    # ----------------------------------------------------------------
    # Finish
    # ----------------------------------------------------------------
    total_progress.close()
    logging.info("Finished.")

    return avg_loss_dict

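# ----------------------------------------------------------------------------
# Compatibility note (an aside, not in the original): on newer PyTorch the
# canonical pattern is optimizer.step() before lr_scheduler.step(), the epoch
# argument to step() is deprecated, and get_lr() is superseded by
# get_last_lr(); ReduceLROnPlateau keeps the step(metric) signature used above.
# ----------------------------------------------------------------------------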