def configure_lr_scheduler(args, optimizer):
    lr_scheduler = None
    with logger.LoggingBlock("Learning Rate Scheduler", emph=True):
        logging.info("class: %s" % args.lr_scheduler)
        if args.lr_scheduler is not None:
            # ----------------------------------------------
            # Figure out lr_scheduler arguments
            # ----------------------------------------------
            kwargs = tools.kwargs_from_args(args, "lr_scheduler")

            # -------------------------------------------
            # Print arguments
            # -------------------------------------------
            for param, default in sorted(kwargs.items()):
                logging.info("%s: %s" % (param, default))

            # -------------------------------------------
            # Add optimizer
            # -------------------------------------------
            kwargs["optimizer"] = optimizer

            # -------------------------------------------
            # Create lr_scheduler instance
            # -------------------------------------------
            lr_scheduler = tools.instance_from_kwargs(args.lr_scheduler_class, kwargs)

    return lr_scheduler

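# NOTE: `tools.kwargs_from_args` and `tools.instance_from_kwargs` are not shown in
# this section. The sketch below illustrates their assumed behaviour: collect every
# "--<prefix>_<param>" commandline value into a dict, then filter that dict against
# the target constructor before instantiating. Names, the example argument, and all
# details are hypothetical, not the project's actual implementation.
import inspect


def _kwargs_from_args_sketch(args, prefix):
    # e.g. --lr_scheduler_milestones="[54, 72]" -> {"milestones": [54, 72]}
    tag = prefix + "_"
    return {key[len(tag):]: value
            for key, value in vars(args).items()
            if key.startswith(tag) and not key.endswith("_class")}


def _instance_from_kwargs_sketch(class_, kwargs):
    # keep only the kwargs accepted by the constructor, then instantiate
    accepted = inspect.signature(class_).parameters
    return class_(**{k: v for k, v in kwargs.items() if k in accepted})
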
def configure_attack(args):
    attack = None
    with logger.LoggingBlock("Adversarial Attack:", emph=True):
        if args.attack is not None:
            # ----------------------------------------------
            # Figure out keyword arguments
            # ----------------------------------------------
            kwargs = typeinf.kwargs_from_args(args, "attack")

            # -------------------------------------------
            # Log arguments
            # -------------------------------------------
            logging.info("%s" % args.attack)
            for param, default in sorted(kwargs.items()):
                logging.info("%s: %s" % (param, default))

            # -------------------------------------------
            # Create instance
            # -------------------------------------------
            kwargs["args"] = args
            attack = typeinf.instance_from_kwargs(args.attack_class, kwargs)
        else:
            logging.info("None")

    return attack

def setup_logging_and_parse_arguments(blocktitle):
    # ----------------------------------------------------------------------------
    # Get parsed commandline and default arguments
    # ----------------------------------------------------------------------------
    args, defaults = _parse_arguments()

    # ----------------------------------------------------------------------------
    # Setup logbook before everything else
    # ----------------------------------------------------------------------------
    logger.configure_logging(os.path.join(args.save, 'logbook.txt'))

    # ----------------------------------------------------------------------------
    # Write arguments to file, as txt
    # ----------------------------------------------------------------------------
    tools.write_dictionary_to_file(
        sorted(vars(args).items()),
        filename=os.path.join(args.save, 'args.txt'))

    # ----------------------------------------------------------------------------
    # Log arguments
    # ----------------------------------------------------------------------------
    with logger.LoggingBlock(blocktitle, emph=True):
        for argument, value in sorted(vars(args).items()):
            reset = colorama.Style.RESET_ALL
            color = reset if value == defaults[argument] else colorama.Fore.CYAN
            logging.info('{}{}: {}{}'.format(color, argument, value, reset))

    # ----------------------------------------------------------------------------
    # Postprocess
    # ----------------------------------------------------------------------------
    args = postprocess_args(args)
    return args

def configure_model_and_loss(args):
    # ----------------------------------------------------
    # Dynamically load model and loss class with parameters
    # passed in via "--model_[param]=[value]" or "--loss_[param]=[value]" arguments
    # ----------------------------------------------------
    with logger.LoggingBlock("Model and Loss", emph=True):
        # ----------------------------------------------------
        # Model
        # ----------------------------------------------------
        kwargs = tools.kwargs_from_args(args, "model")
        kwargs["args"] = args
        model = tools.instance_from_kwargs(args.model_class, kwargs)

        # ----------------------------------------------------
        # Training loss
        # ----------------------------------------------------
        training_loss = None
        if args.training_loss is not None:
            kwargs = tools.kwargs_from_args(args, "training_loss")
            kwargs["args"] = args
            training_loss = tools.instance_from_kwargs(args.training_loss_class, kwargs)

        # ----------------------------------------------------
        # Validation loss
        # ----------------------------------------------------
        validation_loss = None
        if args.validation_loss is not None:
            kwargs = tools.kwargs_from_args(args, "validation_loss")
            kwargs["args"] = args
            validation_loss = tools.instance_from_kwargs(args.validation_loss_class, kwargs)

        # ----------------------------------------------------
        # Model and loss
        # ----------------------------------------------------
        model_and_loss = ModelAndLoss(args, model, training_loss, validation_loss)

        # -----------------------------------------------------------
        # If Cuda, transfer model to Cuda and wrap with DataParallel.
        # -----------------------------------------------------------
        if args.cuda:
            model_and_loss = model_and_loss.cuda()

        # ---------------------------------------------------------------
        # Report some network statistics
        # ---------------------------------------------------------------
        logging.info("Batch Size: %i" % args.batch_size)
        logging.info("GPGPU: Cuda" if args.cuda else "GPGPU: off")
        logging.info("Network: %s" % args.model)
        logging.info("Number of parameters: %i" % tools.x2module(model_and_loss).num_parameters())
        if training_loss is not None:
            logging.info("Training Key: %s" % args.training_key)
            logging.info("Training Loss: %s" % args.training_loss)
        if validation_loss is not None:
            logging.info("Validation Key: %s" % args.validation_key)
            logging.info("Validation Loss: %s" % args.validation_loss)

    return model_and_loss

def _log_statistics(dataset, prefix, name):
    with logger.LoggingBlock("%s Dataset: %s" % (prefix, name)):
        example_dict = dataset[0]  # get sizes from first dataset example
        for key, value in sorted(example_dict.items()):
            if key in ["index", "basename"]:  # no need to display these
                continue
            if isinstance(value, str):
                logging.info("{}: {}".format(key, value))
            else:
                logging.info("%s: %s" % (key, _sizes_to_str(value)))
        logging.info("num_examples: %i" % len(dataset))

def configure_checkpoint_saver(args, model_and_loss):
    with logger.LoggingBlock("Checkpoint", emph=True):
        checkpoint_saver = CheckpointSaver()
        checkpoint_stats = None

        if args.checkpoint is None:
            logging.info("No checkpoint given.")
            logging.info("Starting from scratch with random initialization.")

        elif os.path.isfile(args.checkpoint):
            logging.info("Loading checkpoint file %s" % args.checkpoint)
            checkpoint_stats, filename = checkpoint_saver.restore(
                filename=args.checkpoint,
                model_and_loss=model_and_loss,
                include_params=args.checkpoint_include_params,
                exclude_params=args.checkpoint_exclude_params)
            # load epoch number from the statistics file next to the checkpoint
            no_extension = filename.split('.')[0]
            statistics_filename = no_extension + ".json"
            statistics = tools.read_json(statistics_filename)
            args.start_epoch = statistics['epoch'] + 1

        elif os.path.isdir(args.checkpoint):
            if args.checkpoint_mode in ["resume_from_best"]:
                logging.info("Loading best checkpoint in %s" % args.checkpoint)
                checkpoint_stats, filename = checkpoint_saver.restore_best(
                    directory=args.checkpoint,
                    model_and_loss=model_and_loss,
                    include_params=args.checkpoint_include_params,
                    exclude_params=args.checkpoint_exclude_params)
            elif args.checkpoint_mode in ["resume_from_latest"]:
                logging.info("Loading latest checkpoint in %s" % args.checkpoint)
                checkpoint_stats, filename = checkpoint_saver.restore_latest(
                    directory=args.checkpoint,
                    model_and_loss=model_and_loss,
                    include_params=args.checkpoint_include_params,
                    exclude_params=args.checkpoint_exclude_params)
            else:
                logging.info("Unknown checkpoint_mode '%s' given!" % args.checkpoint_mode)
                quit()

        else:
            logging.info("Could not find checkpoint file or directory '%s'. "
                         "Starting with random initialization." % args.checkpoint)

    return checkpoint_saver, checkpoint_stats

def configure_random_seed(args):
    with logger.LoggingBlock("Random Seeds", emph=True):
        # python
        seed = args.seed
        random.seed(seed)
        logging.info("Python seed: %i" % seed)
        # numpy
        seed += 1
        np.random.seed(seed)
        logging.info("Numpy seed: %i" % seed)
        # torch
        seed += 1
        torch.manual_seed(seed)
        logging.info("Torch CPU seed: %i" % seed)
        # torch cuda
        seed += 1
        torch.cuda.manual_seed(seed)
        logging.info("Torch CUDA seed: %i" % seed)

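# NOTE (assumption, not part of the original code): the offset seeds above make the
# Python/NumPy/Torch streams reproducible, but cuDNN may still select
# non-deterministic kernels. If bitwise reproducibility were required, one would
# additionally set the flags below -- at the cost of speed, and in conflict with the
# `torch.backends.cudnn.benchmark = True` optimization enabled later in main().
def make_cudnn_deterministic():
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
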
def setup_logging_and_parse_arguments(blocktitle):
    # ----------------------------------------------------------------------------
    # Get parsed commandline and default arguments
    # ----------------------------------------------------------------------------
    args, defaults = _parse_arguments()

    # ----------------------------------------------------------------------------
    # Setup logbook before everything else
    # ----------------------------------------------------------------------------
    logger.configure_logging(os.path.join(args.save, "logbook.txt"))

    # ----------------------------------------------------------------------------
    # Write arguments to file, as json and txt
    # ----------------------------------------------------------------------------
    json.write_dictionary_to_file(
        vars(args), filename=os.path.join(args.save, "args.json"), sortkeys=True)
    json.write_dictionary_to_file(
        vars(args), filename=os.path.join(args.save, "args.txt"), sortkeys=True)

    # ----------------------------------------------------------------------------
    # Log arguments
    # ----------------------------------------------------------------------------
    with logger.LoggingBlock(blocktitle, emph=True):
        for argument, value in sorted(vars(args).items()):
            reset = colorama.Style.RESET_ALL
            color = reset if value == defaults[argument] else colorama.Fore.CYAN
            if isinstance(value, dict):
                for sub_argument, sub_value in collections.OrderedDict(value).items():
                    logging.info("{}{}_{}: {}{}".format(
                        color, argument, sub_argument, sub_value, reset))
            else:
                logging.info("{}{}: {}{}".format(color, argument, value, reset))

    # ----------------------------------------------------------------------------
    # Postprocess
    # ----------------------------------------------------------------------------
    args = postprocess_args(args)
    return args

def configure_runtime_augmentations(args):
    with logger.LoggingBlock("Runtime Augmentations", emph=True):
        training_augmentation = None
        validation_augmentation = None

        # ----------------------------------------------------
        # Training Augmentation
        # ----------------------------------------------------
        if args.training_augmentation is not None:
            kwargs = tools.kwargs_from_args(args, "training_augmentation")
            logging.info("training_augmentation: %s" % args.training_augmentation)
            for param, default in sorted(kwargs.items()):
                logging.info("  %s: %s" % (param, default))
            kwargs["args"] = args
            training_augmentation = tools.instance_from_kwargs(
                args.training_augmentation_class, kwargs)
            if args.cuda:
                training_augmentation = training_augmentation.cuda()
        else:
            logging.info("training_augmentation: None")

        # ----------------------------------------------------
        # Validation Augmentation
        # ----------------------------------------------------
        if args.validation_augmentation is not None:
            kwargs = tools.kwargs_from_args(args, "validation_augmentation")
            logging.info("validation_augmentation: %s" % args.validation_augmentation)
            for param, default in sorted(kwargs.items()):
                logging.info("  %s: %s" % (param, default))
            kwargs["args"] = args
            validation_augmentation = tools.instance_from_kwargs(
                args.validation_augmentation_class, kwargs)
            if args.cuda:
                validation_augmentation = validation_augmentation.cuda()
        else:
            logging.info("validation_augmentation: None")

    return training_augmentation, validation_augmentation

def configure_checkpoint_saver(args, model_and_loss):
    with logger.LoggingBlock("Checkpoint", emph=True):
        checkpoint_saver = CheckpointSaver()
        checkpoint_stats = None

        if args.checkpoint is None:
            logging.info("No checkpoint given.")
            logging.info("Starting from scratch with random initialization.")

        elif os.path.isfile(args.checkpoint):
            checkpoint_stats, filename = checkpoint_saver.restore(
                filename=args.checkpoint,
                model_and_loss=model_and_loss,
                include_params=args.checkpoint_include_params,
                exclude_params=args.checkpoint_exclude_params)

        elif os.path.isdir(args.checkpoint):
            if args.checkpoint_mode in ["resume_from_best"]:
                logging.info("Loading best checkpoint in %s" % args.checkpoint)
                checkpoint_stats, filename = checkpoint_saver.restore_best(
                    directory=args.checkpoint,
                    model_and_loss=model_and_loss,
                    include_params=args.checkpoint_include_params,
                    exclude_params=args.checkpoint_exclude_params)
            elif args.checkpoint_mode in ["resume_from_latest"]:
                logging.info("Loading latest checkpoint in %s" % args.checkpoint)
                checkpoint_stats, filename = checkpoint_saver.restore_latest(
                    directory=args.checkpoint,
                    model_and_loss=model_and_loss,
                    include_params=args.checkpoint_include_params,
                    exclude_params=args.checkpoint_exclude_params)
            else:
                logging.info("Unknown checkpoint_mode '%s' given!" % args.checkpoint_mode)
                quit()

        else:
            logging.info("Could not find checkpoint file or directory '%s'" % args.checkpoint)
            quit()

    return checkpoint_saver, checkpoint_stats

def configure_model_and_loss(args):
    # ----------------------------------------------------
    # Dynamically load model and loss class with parameters
    # passed in via "--model_[param]=[value]" or "--loss_[param]=[value]" arguments
    # ----------------------------------------------------
    with logger.LoggingBlock("Model and Loss", emph=True):
        # ----------------------------------------------------
        # Model
        # ----------------------------------------------------
        kwargs = typeinf.kwargs_from_args(args, "model")
        kwargs["args"] = args
        model = typeinf.instance_from_kwargs(args.model_class, kwargs)

        # ----------------------------------------------------
        # Training loss
        # ----------------------------------------------------
        loss = None
        if args.loss is not None:
            kwargs = typeinf.kwargs_from_args(args, "loss")
            kwargs["args"] = args
            loss = typeinf.instance_from_kwargs(args.loss_class, kwargs)

        # ----------------------------------------------------
        # Model and loss
        # ----------------------------------------------------
        model_and_loss = ModelAndLoss(args, model, loss)

        # ---------------------------------------------------------------
        # Report some network statistics
        # ---------------------------------------------------------------
        logging.info("Batch Size: %i" % args.batch_size)
        logging.info("Network: %s" % args.model)
        logging.info("Number of parameters: %i" % model_and_loss.num_parameters())
        if loss is not None:
            logging.info("Training Key: %s" % args.training_key)
            logging.info("Training Loss: %s" % args.loss)
        logging.info("Validation Keys: %s" % args.validation_keys)
        logging.info("Validation Keys Minimize: %s" % args.validation_keys_minimize)

    return model_and_loss

def exec_runtime(args,
                 device,
                 checkpoint_saver,
                 model_and_loss,
                 optimizer,
                 attack,
                 lr_scheduler,
                 train_loader,
                 validation_loader,
                 inference_loader,
                 training_augmentation,
                 validation_augmentation):
    # ----------------------------------------------------------------------------------------------
    # Validation schedulers are a bit special:
    # They want to be called with a validation loss..
    # ----------------------------------------------------------------------------------------------
    validation_scheduler = (lr_scheduler is not None and args.lr_scheduler == "ReduceLROnPlateau")

    # --------------------------------------------------------
    # Log some runtime info
    # --------------------------------------------------------
    with logger.LoggingBlock("Runtime", emph=True):
        logging.info("start_epoch: %i" % args.start_epoch)
        logging.info("total_epochs: %i" % args.total_epochs)

    # ---------------------------------------
    # Total progress bar arguments
    # ---------------------------------------
    progressbar_args = {
        "desc": "Progress",
        "initial": args.start_epoch - 1,
        "invert_iterations": True,
        "iterable": range(1, args.total_epochs + 1),
        "logging_on_close": True,
        "logging_on_update": True,
        "postfix": False,
        "unit": "ep"
    }

    # --------------------------------------------------------
    # Total progress bar
    # --------------------------------------------------------
    print(''), logging.logbook('')
    total_progress = create_progressbar(**progressbar_args)
    print("\n")

    # --------------------------------------------------------
    # Remember validation losses
    # --------------------------------------------------------
    num_validation_losses = len(args.validation_keys)
    best_validation_losses = [
        float("inf") if args.validation_keys_minimize[i] else -float("inf")
        for i in range(num_validation_losses)
    ]
    store_as_best = [False for i in range(num_validation_losses)]

    # --------------------------------------------------------
    # Transfer model to device once before training/evaluation
    # --------------------------------------------------------
    model_and_loss = model_and_loss.to(device)

    avg_loss_dict = {}
    for epoch in range(args.start_epoch, args.total_epochs + 1):
        with logger.LoggingBlock("Epoch %i/%i" % (epoch, args.total_epochs), emph=True):

            # --------------------------------------------------------
            # Update standard learning scheduler
            # --------------------------------------------------------
            if lr_scheduler is not None and not validation_scheduler:
                lr_scheduler.step(epoch)

            # --------------------------------------------------------
            # Always report learning rate and model
            # --------------------------------------------------------
            if lr_scheduler is None:
                logging.info("model: %s  lr: %s" %
                             (args.model, format_learning_rate(args.optimizer_lr)))
            else:
                logging.info("model: %s  lr: %s" %
                             (args.model, format_learning_rate(lr_scheduler.get_lr())))

            # -------------------------------------------
            # Create and run a training epoch
            # -------------------------------------------
            if train_loader is not None:
                avg_loss_dict, _ = TrainingEpoch(
                    args,
                    desc="   Train",
                    device=device,
                    model_and_loss=model_and_loss,
                    optimizer=optimizer,
                    loader=train_loader,
                    augmentation=training_augmentation).run()

            # -------------------------------------------
            # Create and run a validation epoch
            # -------------------------------------------
            if validation_loader is not None:
                # ---------------------------------------------------
                # Construct holistic recorder for epoch
                # ---------------------------------------------------
                epoch_recorder = configure_holistic_epoch_recorder(
                    args, epoch=epoch, loader=validation_loader)

                with torch.no_grad():
                    avg_loss_dict, output_dict = EvaluationEpoch(
                        args,
                        desc="Validate",
                        device=device,
                        model_and_loss=model_and_loss,
                        attack=attack,
                        loader=validation_loader,
                        recorder=epoch_recorder,
                        augmentation=validation_augmentation).run()

                # ----------------------------------------------------------------
                # Evaluate validation losses
                # ----------------------------------------------------------------
                validation_losses = [avg_loss_dict[vkey] for vkey in args.validation_keys]
                for i, (vkey, vminimize) in enumerate(
                        zip(args.validation_keys, args.validation_keys_minimize)):
                    if vminimize:
                        store_as_best[i] = validation_losses[i] < best_validation_losses[i]
                    else:
                        store_as_best[i] = validation_losses[i] > best_validation_losses[i]
                    if store_as_best[i]:
                        best_validation_losses[i] = validation_losses[i]

                # ----------------------------------------------------------------
                # Update validation scheduler, if one is in place
                # We use the first key in validation keys as the relevant one
                # ----------------------------------------------------------------
                if lr_scheduler is not None and validation_scheduler:
                    lr_scheduler.step(validation_losses[0], epoch=epoch)

                # ----------------------------------------------------------------
                # Also show best loss on total_progress
                # ----------------------------------------------------------------
                total_progress_stats = {
                    "best_" + vkey + "_avg": "%1.4f" % best_validation_losses[i]
                    for i, vkey in enumerate(args.validation_keys)
                }
                total_progress.set_postfix(total_progress_stats)

            # ----------------------------------------------------------------
            # Bump total progress
            # ----------------------------------------------------------------
            total_progress.update()
            print('')

            # ----------------------------------------------------------------
            # Store checkpoint
            # ----------------------------------------------------------------
            if checkpoint_saver is not None:
                checkpoint_saver.save_latest(
                    directory=args.save,
                    model_and_loss=model_and_loss,
                    stats_dict=dict(avg_loss_dict, epoch=epoch),
                    store_as_best=store_as_best,
                    store_prefixes=args.validation_keys)

            # ----------------------------------------------------------------
            # Vertical space between epochs
            # ----------------------------------------------------------------
            print(''), logging.logbook('')

    # ----------------------------------------------------------------
    # Finish
    # ----------------------------------------------------------------
    total_progress.close()
    logging.info("Finished.")

def configure_optimizer(args, model_and_loss):
    optimizer = None
    with logger.LoggingBlock("Optimizer", emph=True):
        if args.optimizer is not None:
            if model_and_loss.num_parameters() == 0:
                logging.info("No trainable parameters detected.")
                logging.info("Setting optimizer to None.")
            else:
                logging.info(args.optimizer)

                # -------------------------------------------
                # Figure out all optimizer arguments
                # -------------------------------------------
                all_kwargs = tools.kwargs_from_args(args, "optimizer")

                # -------------------------------------------
                # Get the split of param groups
                # -------------------------------------------
                kwargs_without_groups = {
                    key: value for key, value in all_kwargs.items() if key != "group"
                }
                param_groups = all_kwargs["group"]

                # ----------------------------------------------------------------------
                # Print arguments (without groups)
                # ----------------------------------------------------------------------
                for param, default in sorted(kwargs_without_groups.items()):
                    logging.info("%s: %s" % (param, default))

                # ----------------------------------------------------------------------
                # Construct actual optimizer params
                # ----------------------------------------------------------------------
                kwargs = dict(kwargs_without_groups)
                if param_groups is None:
                    # ---------------------------------------------------------
                    # Add all trainable parameters if there are no param groups
                    # ---------------------------------------------------------
                    all_trainable_parameters = _generate_trainable_params(model_and_loss)
                    kwargs["params"] = all_trainable_parameters
                else:
                    # -------------------------------------------
                    # Add list of parameter groups instead
                    # -------------------------------------------
                    trainable_parameter_groups = []
                    dnames, dparams = _param_names_and_trainable_generator(model_and_loss)
                    dnames = set(dnames)
                    dparams = set(list(dparams))

                    with logger.LoggingBlock("parameter_groups:"):
                        for group in param_groups:
                            # log group settings
                            group_match = group["params"]
                            group_args = {
                                key: value for key, value in group.items() if key != "params"
                            }
                            with logger.LoggingBlock("%s: %s" % (group_match, group_args)):
                                # retrieve parameters by matching name
                                gnames, gparams = _param_names_and_trainable_generator(
                                    model_and_loss, match=group_match)
                                # log all names affected
                                for n in sorted(gnames):
                                    logging.info(n)
                                # set generator for group
                                group_args["params"] = gparams
                                # append parameter group
                                trainable_parameter_groups.append(group_args)
                                # update remaining trainable parameters
                                dnames -= set(gnames)
                                dparams -= set(list(gparams))

                        # append default parameter group
                        trainable_parameter_groups.append({"params": list(dparams)})
                        # and log its parameter names
                        with logger.LoggingBlock("default:"):
                            for dname in sorted(dnames):
                                logging.info(dname)

                    # set params in optimizer kwargs
                    kwargs["params"] = trainable_parameter_groups

                # -------------------------------------------
                # Create optimizer instance
                # -------------------------------------------
                optimizer = tools.instance_from_kwargs(args.optimizer_class, kwargs)

    return optimizer

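# For reference: the parameter-group branch above ultimately hands PyTorch a list of
# {"params": ..., <overrides>} dicts. A minimal standalone sketch of the equivalent
# direct construction; the "feature" name match and the learning rates are
# illustrative assumptions, not the project's actual CLI defaults.
import torch


def build_grouped_optimizer_sketch(model, base_lr=1e-4, matched_lr=1e-5, match="feature"):
    matched = [p for n, p in model.named_parameters() if match in n and p.requires_grad]
    default = [p for n, p in model.named_parameters() if match not in n and p.requires_grad]
    return torch.optim.Adam(
        [{"params": matched, "lr": matched_lr},  # matched group with per-group override
         {"params": default}],                   # default group falls back to base_lr
        lr=base_lr)
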
def configure_data_loaders(args):
    with logger.LoggingBlock("Datasets", emph=True):

        def _sizes_to_str(value):
            if np.isscalar(value):
                return '[1L]'
            else:
                return ' '.join([str([d for d in value.size()])])

        def _log_statistics(dataset, prefix, name):
            with logger.LoggingBlock("%s Dataset: %s" % (prefix, name)):
                example_dict = dataset[0]  # get sizes from first dataset example
                for key, value in sorted(example_dict.items()):
                    if key in ["index", "basename"]:  # no need to display these
                        continue
                    if isinstance(value, str):
                        logging.info("{}: {}".format(key, value))
                    else:
                        logging.info("%s: %s" % (key, _sizes_to_str(value)))
                logging.info("num_examples: %i" % len(dataset))

        # -----------------------------------------------------------------------------------------
        # GPU parameters -- turning off pin_memory? for resolving the deadlock?
        # -----------------------------------------------------------------------------------------
        gpuargs = {
            "num_workers": args.num_workers,
            "pin_memory": False
        } if args.cuda else {}

        train_loader = None
        validation_loader = None
        inference_loader = None

        # -----------------------------------------------------------------------------------------
        # Training dataset
        # -----------------------------------------------------------------------------------------
        if args.training_dataset is not None:
            # ----------------------------------------------
            # Figure out training_dataset arguments
            # ----------------------------------------------
            kwargs = tools.kwargs_from_args(args, "training_dataset")
            kwargs["is_cropped"] = True
            kwargs["args"] = args

            # ----------------------------------------------
            # Create training dataset
            # ----------------------------------------------
            train_dataset = tools.instance_from_kwargs(args.training_dataset_class, kwargs)

            # ----------------------------------------------
            # Create training loader
            # ----------------------------------------------
            train_loader = DataLoader(
                train_dataset,
                batch_size=args.batch_size,
                shuffle=True,
                drop_last=False,
                **gpuargs)

            _log_statistics(train_dataset, prefix="Training", name=args.training_dataset)

        # -----------------------------------------------------------------------------------------
        # Validation dataset
        # -----------------------------------------------------------------------------------------
        if args.validation_dataset is not None:
            # ----------------------------------------------
            # Figure out validation_dataset arguments
            # ----------------------------------------------
            kwargs = tools.kwargs_from_args(args, "validation_dataset")
            kwargs["is_cropped"] = True
            kwargs["args"] = args

            # ----------------------------------------------
            # Create validation dataset
            # ----------------------------------------------
            validation_dataset = tools.instance_from_kwargs(args.validation_dataset_class, kwargs)

            # ----------------------------------------------
            # Create validation loader
            # ----------------------------------------------
            validation_loader = DataLoader(
                validation_dataset,
                batch_size=args.batch_size_val,
                shuffle=False,
                drop_last=False,
                **gpuargs)

            _log_statistics(validation_dataset, prefix="Validation", name=args.validation_dataset)

    return train_loader, validation_loader, inference_loader

def main():
    # ----------------------------------------------------
    # Change working directory
    # ----------------------------------------------------
    os.chdir(os.path.dirname(os.path.realpath(__file__)))

    # ----------------------------------------------------
    # Parse commandline arguments
    # ----------------------------------------------------
    args = commandline.setup_logging_and_parse_arguments(blocktitle="Commandline Arguments")

    with logger.LoggingBlock("Source Code", emph=True):
        # ----------------------------------------------------
        # Also archive source code
        # ----------------------------------------------------
        dst = os.path.join(args.save, "src.zip")
        zipsource.create_zip(filename=dst, directory=os.getcwd())
        logging.info("Archived code: %s" % dst)

    # ----------------------------------------------------
    # Set random seed, possibly on Cuda
    # ----------------------------------------------------
    config.configure_random_seed(args)

    # ----------------------------------------------------
    # Change process title for `top` and `pkill` commands
    # This is more informative in `nvidia-smi`
    # ----------------------------------------------------
    setproctitle.setproctitle(args.proctitle)

    # ------------------------------------------------------
    # Fetch data loaders. Quit if no data loader is present
    # ------------------------------------------------------
    train_loader, validation_loader, inference_loader = config.configure_data_loaders(args)

    # -------------------------------------------------------------------------
    # Check whether any dataset could be found
    # -------------------------------------------------------------------------
    success = any(loader is not None
                  for loader in [train_loader, validation_loader, inference_loader])
    if not success:
        logging.info("No dataset could be loaded successfully. Please check dataset paths!")
        quit()

    # -------------------------------------------------------------------------
    # Configure runtime augmentations
    # -------------------------------------------------------------------------
    training_augmentation, validation_augmentation = config.configure_runtime_augmentations(args)

    # ----------------------------------------------------------
    # Configure model and loss.
    # ----------------------------------------------------------
    model_and_loss = config.configure_model_and_loss(args)

    # -----------------------------------------------------------
    # Cuda
    # -----------------------------------------------------------
    with logger.LoggingBlock("Device", emph=True):
        if args.cuda:
            device = torch.device("cuda")
            logging.info("GPU")
        else:
            device = torch.device("cpu")
            logging.info("CPU")

    # ----------------------------------------------------------
    # Configure adversarial attack
    # ----------------------------------------------------------
    attack = config.configure_attack(args)

    # --------------------------------------------------------
    # Print model visualization
    # --------------------------------------------------------
    if args.logging_model_graph:
        with logger.LoggingBlock("Model Graph", emph=True):
            logger.log_module_info(model_and_loss.model)
    if args.logging_loss_graph:
        with logger.LoggingBlock("Loss Graph", emph=True):
            logger.log_module_info(model_and_loss.loss)

    # -------------------------------------------------------------------------
    # Possibly resume from checkpoint
    # -------------------------------------------------------------------------
    checkpoint_saver, checkpoint_stats = config.configure_checkpoint_saver(args, model_and_loss)
    if checkpoint_stats is not None:
        logging.info("  Checkpoint Statistics:")
        for key, value in checkpoint_stats.items():
            logging.info("    {}: {}".format(key, value))

        # ---------------------------------------------------------------------
        # Set checkpoint stats
        # ---------------------------------------------------------------------
        if args.checkpoint_mode in ["resume_from_best", "resume_from_latest"]:
            args.start_epoch = checkpoint_stats["epoch"]

    # ---------------------------------------------------------------------
    # Checkpoint and save directory
    # ---------------------------------------------------------------------
    with logger.LoggingBlock("Save Directory", emph=True):
        logging.info("Save directory: %s" % args.save)
        if not os.path.exists(args.save):
            os.makedirs(args.save)

    # ----------------------------------------------------------
    # Configure optimizer
    # ----------------------------------------------------------
    optimizer = config.configure_optimizer(args, model_and_loss)

    # ----------------------------------------------------------
    # Configure learning rate
    # ----------------------------------------------------------
    lr_scheduler = config.configure_lr_scheduler(args, optimizer)

    # ------------------------------------------------------------
    # If this is just an evaluation: overwrite savers and epochs
    # ------------------------------------------------------------
    if args.evaluation:
        args.start_epoch = 1
        args.total_epochs = 1
        train_loader = None
        checkpoint_saver = None
        optimizer = None
        lr_scheduler = None

    # ----------------------------------------------------------
    # Cuda optimization
    # ----------------------------------------------------------
    if args.cuda:
        torch.backends.cudnn.benchmark = True

    # ----------------------------------------------------------
    # Kickoff training, validation and/or testing
    # ----------------------------------------------------------
    return runtime.exec_runtime(
        args,
        device=device,
        checkpoint_saver=checkpoint_saver,
        model_and_loss=model_and_loss,
        optimizer=optimizer,
        attack=attack,
        lr_scheduler=lr_scheduler,
        train_loader=train_loader,
        validation_loader=validation_loader,
        inference_loader=inference_loader,
        training_augmentation=training_augmentation,
        validation_augmentation=validation_augmentation)

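# The module entry point is not included in this section; presumably the script ends
# with the standard guard so that main() runs when the file is invoked directly:
if __name__ == "__main__":
    main()
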
def exec_runtime(args,
                 checkpoint_saver,
                 model_and_loss,
                 optimizer,
                 lr_scheduler,
                 param_scheduler,
                 train_loader,
                 validation_loader,
                 training_augmentation,
                 validation_augmentation,
                 visualizer):
    # --------------------------------------------------------------------------------
    # Validation schedulers are a bit special:
    # They need special treatment as they want to be called with a validation loss..
    # --------------------------------------------------------------------------------
    validation_scheduler = (lr_scheduler is not None and args.lr_scheduler == "ReduceLROnPlateau")

    # --------------------------------------------------------
    # Log some runtime info
    # --------------------------------------------------------
    with logging.block("Runtime", emph=True):
        logging.value("start_epoch: ", args.start_epoch)
        logging.value("total_epochs: ", args.total_epochs)

    # ---------------------------------------
    # Total progress bar arguments
    # ---------------------------------------
    progressbar_args = {
        "desc": "Total",
        "initial": args.start_epoch - 1,
        "invert_iterations": True,
        "iterable": range(1, args.total_epochs + 1),
        "logging_on_close": True,
        "logging_on_update": True,
        "unit": "ep",
        "track_eta": True
    }

    # --------------------------------------------------------
    # Total progress bar
    # --------------------------------------------------------
    print(''), logging.logbook('')
    total_progress = create_progressbar(**progressbar_args)
    total_progress_stats = {}
    print("\n")

    # --------------------------------------------------------
    # Remember validation losses
    # --------------------------------------------------------
    best_validation_losses = None
    store_as_best = None
    if validation_loader is not None:
        num_validation_losses = len(args.validation_keys)
        best_validation_losses = [
            float("inf") if args.validation_modes[i] == 'min' else -float("inf")
            for i in range(num_validation_losses)
        ]
        store_as_best = [False for _ in range(num_validation_losses)]

    # ----------------------------------------------------------------
    # Send Telegram message
    # ----------------------------------------------------------------
    logging.telegram(format_telegram_status_update(args, epoch=0))

    avg_loss_dict = {}
    for epoch in range(args.start_epoch, args.total_epochs + 1):

        # --------------------------------
        # Make Epoch %i/%i header message
        # --------------------------------
        epoch_header = "Epoch {}/{}{}{}".format(
            epoch, args.total_epochs, " " * 24, format_epoch_header_machine_stats(args))

        with logger.LoggingBlock(epoch_header, emph=True):

            # -------------------------------------------------------------------------------
            # Let TensorBoard know where we are..
            # -------------------------------------------------------------------------------
            summary.set_global_step(epoch)

            # -----------------------------------------------------------------
            # Update standard learning scheduler and get current learning rate
            # -----------------------------------------------------------------
            # Starting with PyTorch 1.1 the expected validation order is:
            #     optimize(...)
            #     validate(...)
            #     scheduler.step()

            # ---------------------------------------------------------------------
            # Update parameter schedule before the epoch
            # Note: Parameter schedulers are tuples of (optimizer, schedule)
            # ---------------------------------------------------------------------
            if param_scheduler is not None:
                param_scheduler.step(epoch=epoch)

            # -----------------------------------------------------------------
            # Get current learning rate from either optimizer or scheduler
            # -----------------------------------------------------------------
            lr = args.optimizer_lr if args.optimizer is not None else "None"
            if lr_scheduler is not None:
                lr = [group['lr'] for group in optimizer.param_groups] \
                    if args.optimizer is not None else "None"

            # --------------------------------------------------------
            # Current Epoch header stats
            # --------------------------------------------------------
            logging.info(format_epoch_header_stats(args, lr))

            # -------------------------------------------
            # Create and run a training epoch
            # -------------------------------------------
            if train_loader is not None:
                if visualizer is not None:
                    visualizer.on_epoch_init(lr, train=True, epoch=epoch,
                                             total_epochs=args.total_epochs)

                ema_loss_dict = RuntimeEpoch(
                    args,
                    desc="Train",
                    augmentation=training_augmentation,
                    loader=train_loader,
                    model_and_loss=model_and_loss,
                    optimizer=optimizer,
                    visualizer=visualizer).run(train=True)

                if visualizer is not None:
                    visualizer.on_epoch_finished(ema_loss_dict, train=True, epoch=epoch,
                                                 total_epochs=args.total_epochs)

            # -------------------------------------------
            # Create and run a validation epoch
            # -------------------------------------------
            if validation_loader is not None:
                if visualizer is not None:
                    visualizer.on_epoch_init(lr, train=False, epoch=epoch,
                                             total_epochs=args.total_epochs)

                # ---------------------------------------------------
                # Construct holistic recorder for epoch
                # ---------------------------------------------------
                epoch_recorder = configure_holistic_epoch_recorder(
                    args, epoch=epoch, loader=validation_loader)

                with torch.no_grad():
                    avg_loss_dict = RuntimeEpoch(
                        args,
                        desc="Valid",
                        augmentation=validation_augmentation,
                        loader=validation_loader,
                        model_and_loss=model_and_loss,
                        recorder=epoch_recorder,
                        visualizer=visualizer).run(train=False)

                    try:
                        epoch_recorder.add_scalars("evaluation_losses", avg_loss_dict)
                    except Exception:
                        pass

                    if visualizer is not None:
                        visualizer.on_epoch_finished(avg_loss_dict, train=False, epoch=epoch,
                                                     total_epochs=args.total_epochs)

                # ----------------------------------------------------------------
                # Evaluate validation losses
                # ----------------------------------------------------------------
                validation_losses = [avg_loss_dict[vkey] for vkey in args.validation_keys]
                for i, (vkey, vmode) in enumerate(
                        zip(args.validation_keys, args.validation_modes)):
                    if vmode == 'min':
                        store_as_best[i] = validation_losses[i] < best_validation_losses[i]
                    else:
                        store_as_best[i] = validation_losses[i] > best_validation_losses[i]
                    if store_as_best[i]:
                        best_validation_losses[i] = validation_losses[i]

                # ----------------------------------------------------------------
                # Update validation scheduler, if one is in place
                # We use the first key in validation keys as the relevant one
                # ----------------------------------------------------------------
                if lr_scheduler is not None:
                    if validation_scheduler:
                        lr_scheduler.step(validation_losses[0], epoch=epoch)
                    else:
                        lr_scheduler.step(epoch=epoch)

                # ----------------------------------------------------------------
                # Also show best loss on total_progress
                # ----------------------------------------------------------------
                total_progress_stats = {
                    "best_" + vkey + "_avg": "%1.4f" % best_validation_losses[i]
                    for i, vkey in enumerate(args.validation_keys)
                }
                total_progress.set_postfix(total_progress_stats)

            # ----------------------------------------------------------------
            # Bump total progress
            # ----------------------------------------------------------------
            total_progress.update()
            print('')

            # ----------------------------------------------------------------
            # Get ETA string for display in loggers
            # ----------------------------------------------------------------
            eta_str = total_progress.eta_str()

            # ----------------------------------------------------------------
            # Send Telegram status update
            # ----------------------------------------------------------------
            total_progress_stats['lr'] = format_learning_rate(lr)
            logging.telegram(format_telegram_status_update(
                args,
                eta_str=eta_str,
                epoch=epoch,
                total_progress_stats=total_progress_stats))

            # ----------------------------------------------------------------
            # Update ETA in process title
            # ----------------------------------------------------------------
            eta_proctitle = "{} finishes in {}".format(args.proctitle, eta_str)
            proctitles.setproctitle(eta_proctitle)

            # ----------------------------------------------------------------
            # Store checkpoint
            # ----------------------------------------------------------------
            if checkpoint_saver is not None and validation_loader is not None:
                checkpoint_saver.save_latest(
                    directory=args.save,
                    model_and_loss=model_and_loss,
                    stats_dict=dict(avg_loss_dict, epoch=epoch),
                    store_as_best=store_as_best,
                    store_prefixes=args.validation_keys)

            # ----------------------------------------------------------------
            # Vertical space between epochs
            # ----------------------------------------------------------------
            print(''), logging.logbook('')

    # ----------------------------------------------------------------
    # Finish up
    # ----------------------------------------------------------------
    logging.telegram_flush()
    total_progress.close()
    logging.info("Finished.")

def main():
    # Change working directory
    os.chdir(os.path.dirname(os.path.realpath(__file__)))

    # Parse commandline arguments
    args = commandline.setup_logging_and_parse_arguments(blocktitle="Commandline Arguments")

    # Set random seed, possibly on Cuda
    config.configure_random_seed(args)

    # DataLoader
    train_loader, validation_loader, inference_loader = config.configure_data_loaders(args)
    success = any(loader is not None
                  for loader in [train_loader, validation_loader, inference_loader])
    if not success:
        logging.info("No dataset could be loaded successfully. Please check dataset paths!")
        quit()

    # Configure data augmentation
    training_augmentation, validation_augmentation = config.configure_runtime_augmentations(args)

    # Configure model and loss
    model_and_loss = config.configure_model_and_loss(args)

    # Multi-GPU automation
    with logger.LoggingBlock("Multi GPU", emph=True):
        logging.info("Let's use %d GPUs!" % torch.cuda.device_count())
        model_and_loss._model = torch.nn.DataParallel(model_and_loss._model)

    # Resume from checkpoint if available
    checkpoint_saver, checkpoint_stats = config.configure_checkpoint_saver(args, model_and_loss)

    # Checkpoint and save directory
    with logger.LoggingBlock("Save Directory", emph=True):
        logging.info("Save directory: %s" % args.save)
        if not os.path.exists(args.save):
            os.makedirs(args.save)

    # Configure optimizer
    optimizer = config.configure_optimizer(args, model_and_loss)

    # Configure learning rate
    lr_scheduler = config.configure_lr_scheduler(args, optimizer)

    # If this is just an evaluation: overwrite savers and epochs
    if args.evaluation:
        args.start_epoch = 1
        args.total_epochs = 1
        train_loader = None
        checkpoint_saver = None
        optimizer = None
        lr_scheduler = None

    # Cuda optimization
    if args.cuda:
        torch.backends.cudnn.benchmark = True

    # Kickoff training, validation and/or testing
    return runtime.exec_runtime(
        args,
        checkpoint_saver=checkpoint_saver,
        model_and_loss=model_and_loss,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        train_loader=train_loader,
        validation_loader=validation_loader,
        inference_loader=inference_loader,
        training_augmentation=training_augmentation,
        validation_augmentation=validation_augmentation)

def exec_runtime(args,
                 checkpoint_saver,
                 model_and_loss,
                 optimizer,
                 lr_scheduler,
                 train_loader,
                 validation_loader,
                 inference_loader,
                 training_augmentation,
                 validation_augmentation):
    # ----------------------------------------------------------------------------------------------
    # Validation schedulers are a bit special:
    # They want to be called with a validation loss..
    # ----------------------------------------------------------------------------------------------
    validation_scheduler = (lr_scheduler is not None and args.lr_scheduler == "ReduceLROnPlateau")

    # --------------------------------------------------------
    # Log some runtime info
    # --------------------------------------------------------
    with logger.LoggingBlock("Runtime", emph=True):
        logging.info("start_epoch: %i" % args.start_epoch)
        logging.info("total_epochs: %i" % args.total_epochs)

    # ---------------------------------------
    # Total progress bar arguments
    # ---------------------------------------
    progressbar_args = {
        "desc": "Progress",
        "initial": args.start_epoch - 1,
        "invert_iterations": True,
        "iterable": range(1, args.total_epochs + 1),
        "logging_on_close": True,
        "logging_on_update": True,
        "postfix": False,
        "unit": "ep"
    }

    # --------------------------------------------------------
    # Total progress bar
    # --------------------------------------------------------
    print(''), logging.logbook('')
    total_progress = create_progressbar(**progressbar_args)
    print("\n")

    # --------------------------------------------------------
    # Remember validation loss
    # --------------------------------------------------------
    best_validation_loss = float("inf") if args.validation_key_minimize else -float("inf")
    store_as_best = False

    for epoch in range(args.start_epoch, args.total_epochs + 1):
        with logger.LoggingBlock("Epoch %i/%i" % (epoch, args.total_epochs), emph=True):

            # --------------------------------------------------------
            # Update standard learning scheduler
            # --------------------------------------------------------
            if lr_scheduler is not None and not validation_scheduler:
                lr_scheduler.step(epoch)

            # --------------------------------------------------------
            # Always report learning rate
            # --------------------------------------------------------
            if lr_scheduler is None:
                logging.info("lr: %s" % format_learning_rate(args.optimizer_lr))
            else:
                logging.info("lr: %s" % format_learning_rate(lr_scheduler.get_lr()))

            # -------------------------------------------
            # Create and run a training epoch
            # -------------------------------------------
            if train_loader is not None:
                avg_loss_dict = TrainingEpoch(
                    args,
                    desc="   Train",
                    model_and_loss=model_and_loss,
                    optimizer=optimizer,
                    loader=train_loader,
                    augmentation=training_augmentation,
                    checkpoint_saver=checkpoint_saver,
                    checkpoint_args={
                        'directory': args.save,
                        'model_and_loss': model_and_loss,
                        'stats_dict': dict({'epe': 0, 'F1': 0}, epoch=epoch),
                        'store_as_best': False
                    }).run()

            # -------------------------------------------
            # Create and run a validation epoch
            # -------------------------------------------
            if validation_loader is not None:
                # ---------------------------------------------------
                # Construct holistic recorder for epoch
                # ---------------------------------------------------
                avg_loss_dict = EvaluationEpoch(
                    args,
                    desc="Validate",
                    model_and_loss=model_and_loss,
                    loader=validation_loader,
                    augmentation=validation_augmentation).run()

                # ----------------------------------------------------------------
                # Evaluate whether this is the best validation_loss
                # ----------------------------------------------------------------
                validation_loss = avg_loss_dict[args.validation_key]
                if args.validation_key_minimize:
                    store_as_best = validation_loss < best_validation_loss
                else:
                    store_as_best = validation_loss > best_validation_loss
                if store_as_best:
                    best_validation_loss = validation_loss

                # ----------------------------------------------------------------
                # Update validation scheduler, if one is in place
                # ----------------------------------------------------------------
                if lr_scheduler is not None and validation_scheduler:
                    lr_scheduler.step(validation_loss, epoch=epoch)

                # ----------------------------------------------------------------
                # Also show best loss on total_progress
                # ----------------------------------------------------------------
                total_progress_stats = {
                    "best_" + args.validation_key + "_avg": "%1.4f" % best_validation_loss
                }
                total_progress.set_postfix(total_progress_stats)

            # ----------------------------------------------------------------
            # Bump total progress
            # ----------------------------------------------------------------
            total_progress.update()
            print('')

            # ----------------------------------------------------------------
            # Store checkpoint
            # ----------------------------------------------------------------
            if checkpoint_saver is not None:
                checkpoint_saver.save_latest(
                    directory=args.save,
                    model_and_loss=model_and_loss,
                    stats_dict=dict(avg_loss_dict, epoch=epoch),
                    store_as_best=store_as_best)

            # ----------------------------------------------------------------
            # Vertical space between epochs
            # ----------------------------------------------------------------
            print(''), logging.logbook('')

            # quit after completing epoch
            quit()

    # ----------------------------------------------------------------
    # Finish
    # ----------------------------------------------------------------
    total_progress.close()
    logging.info("Finished.")
