def __init__(self, args, scale_factor=1.0, div_flow=0.05, num_scales=5, num_highres_scales=2, coarsest_resolution_loss_weight=0.32):
    """Multi-scale EPE loss: per-scale weights decay from the coarsest scale.

    Low-resolution scales shrink by a factor of 4 per step, high-resolution
    scales by a factor of 2, e.g. [0.005, 0.01, 0.02, 0.08, 0.32]
    (finest scale first after the final reversal).
    """
    super().__init__()
    self.args = args
    self.div_flow = div_flow
    self.num_scales = num_scales
    self.scale_factor = scale_factor

    # Build coarsest-first, then flip so the list runs fine-to-coarse.
    weights = [coarsest_resolution_loss_weight]
    num_lowres_scales = num_scales - num_highres_scales
    for _ in range(num_lowres_scales - 1):
        weights.append(weights[-1] / 4)
    for _ in range(num_highres_scales):
        weights.append(weights[-1] / 2)
    weights.reverse()
    self.weights = weights

    logging.value('MultiScaleEPE Weights: ', str(self.weights))
    assert (len(self.weights) == num_scales)  # sanity check
def load_state_dict_into_module(state_dict, module, strict=True):
    """Copy parameters/buffers from ``state_dict`` into ``module``.

    Unlike ``module.load_state_dict``, each destination tensor is resized to
    the checkpoint shape before copying. With ``strict=True``, unexpected or
    missing keys raise ``KeyError``.

    Args:
        state_dict: mapping of parameter/buffer names to tensors.
        module: target ``nn.Module``.
        strict: when True, fail on keys present in only one of the two sides.

    Raises:
        RuntimeError: when a tensor cannot be copied (e.g. shape mismatch).
        KeyError: on unexpected/missing keys when ``strict`` is True.
    """
    own_state = module.state_dict()
    for name, param in state_dict.items():
        if name in own_state:
            if isinstance(param, nn.Parameter):
                # backwards compatibility for serialized parameters
                param = param.data
            try:
                own_state[name].resize_as_(param)
                own_state[name].copy_(param)
            except Exception as err:
                # Chain the original error so the real cause (dtype/shape/
                # device) is preserved in the traceback.
                raise RuntimeError(
                    'While copying the parameter named {}, '
                    'whose dimensions in the model are {} and '
                    'whose dimensions in the checkpoint are {}.'.format(
                        name, own_state[name].size(), param.size())) from err
        elif strict:
            logging.info('Unexpected key "{}" in state_dict'.format(name))
            raise KeyError('')
    if strict:
        missing = set(own_state.keys()) - set(state_dict.keys())
        if len(missing) > 0:
            logging.info('Missing keys in state_dict: ')
            logging.value("{}".format(missing))
            raise KeyError('')
def configure_tensorboard_summaries(save):
    """Create the TensorBoard writer under ``<save>/tb`` and install it
    globally on the summary module."""
    logdir = os.path.join(save, 'tb')
    tb_writer = summary.SummaryWriter(logdir, flush_secs=constants.TENSORBOARD_FLUSH_SECS)
    # Make the writer reachable for everyone through the summary module.
    summary._summary_writer = tb_writer
    with LoggingBlock("Tensorboard", emph=True):
        logging.value(' flush_secs:', constants.TENSORBOARD_FLUSH_SECS)
        logging.value(' logdir: ', logdir)
def try_register(name, module_class, registry, calling_frame):
    """Register ``module_class`` under ``name`` in ``registry``.

    When the name already exists, a warning pointing at the offending
    registration site (taken from ``calling_frame``) is logged before the
    entry is overwritten.
    """
    if name in registry:
        # Duplicate registration: show file, line and source text of the caller.
        header = "Warning in {}[{}]:".format(calling_frame.filename, calling_frame.lineno)
        with logging.block(header):
            offending_line = calling_frame.code_context[0][0:-1]
            logging.value("{} yields duplicate factory entry!".format(offending_line))
    registry[name] = module_class
def configure_lr_scheduler(args, optimizer):
    """Instantiate the learning-rate scheduler named by ``args.lr_scheduler``.

    Returns the scheduler bound to ``optimizer``, or None when no scheduler
    is configured.
    """
    with logging.block("Learning Rate Scheduler", emph=True):
        logging.value("Scheduler: ",
                      args.lr_scheduler if args.lr_scheduler is not None else "None")
        lr_scheduler = None
        if args.lr_scheduler is not None:
            # Collect all 'lr_scheduler_*' commandline arguments, log them,
            # then hand the optimizer over to the constructor.
            kwargs = typeinf.kwargs_from_args(args, "lr_scheduler")
            with logging.block():
                logging.values(kwargs)
            kwargs["optimizer"] = optimizer
            lr_scheduler = typeinf.instance_from_kwargs(args.lr_scheduler_class,
                                                        kwargs=kwargs)
    return lr_scheduler
def restore_module_from_filename(module, filename, key='state_dict', include_params='*', exclude_params=(), translations=(), fuzzy_translation_keys=()):
    """Load a checkpoint file and restore ``module`` from its ``key`` entry.

    Returns ``(checkpoint_dict, restore_keys, actual_translations)``.
    Exits the process when the file is missing or the checkpoint keys cannot
    be matched against the module's state dict.
    """
    # Normalize filter/translation arguments into concrete containers.
    include_params = list(include_params)
    exclude_params = list(exclude_params)
    fuzzy_translation_keys = list(fuzzy_translation_keys)
    translations = dict(translations)

    # ------------------------------------------------------------------------------
    # Make sure file exists
    # ------------------------------------------------------------------------------
    if not os.path.isfile(filename):
        logging.info("Could not find checkpoint file '%s'!" % filename)
        quit()

    # ------------------------------------------------------------------------------
    # Load checkpoint from file including the state_dict.
    # Deserialize onto the CPU; moving to the target device happens elsewhere.
    # ------------------------------------------------------------------------------
    checkpoint_dict = torch.load(filename, map_location=torch.device('cpu'))
    checkpoint_state_dict = checkpoint_dict[key]

    try:
        restore_keys, actual_translations = restore_module_from_state_dict(
            module,
            checkpoint_state_dict,
            include_params=include_params,
            exclude_params=exclude_params,
            translations=translations,
            fuzzy_translation_keys=fuzzy_translation_keys)
    except KeyError:
        # Dump both key sets so mismatched translations are easy to diagnose.
        with logging.block('Checkpoint keys:'):
            logging.value(checkpoint_state_dict.keys())
        with logging.block('Module keys:'):
            logging.value(module.state_dict().keys())
        logging.info(
            "Could not load checkpoint because of key errors. Checkpoint translations gone wrong?"
        )
        quit()

    return checkpoint_dict, restore_keys, actual_translations
def msra_(mod_or_modules, mode='fan_out', nonlinearity='relu'):
    """Apply MSRA (Kaiming-normal) initialization.

    Conv/deconv weights get ``kaiming_normal_`` with the given fan mode and
    nonlinearity, their biases are zeroed, batchnorm layers get identity
    initialization, and any remaining parameterized layer is reported.
    """
    logging.value("Initializing MSRA")
    uninitialized_modules = {}
    for m in _2modules(mod_or_modules):
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.kaiming_normal_(m.weight, mode=mode, nonlinearity=nonlinearity)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.0)
        elif isinstance(m, nn.BatchNorm2d):
            identity_(m)
        else:
            # Not handled here; remember it so it can be reported below.
            uninitialized_modules = _check_uninitialized(uninitialized_modules, m)
    for name, num_params in uninitialized_modules.items():
        logging.value("Found uninitialized layer of type '{}' [{} params]".format(name, num_params))
def configure_visualizers(args, model_and_loss, optimizer, param_scheduler, lr_scheduler, train_loader, validation_loader):
    """Instantiate the runtime visualizer named by ``args.visualizer``.

    The visualizer receives every runtime object it may want to observe.
    Returns None when no visualizer is configured.
    """
    with logging.block("Runtime Visualizers", emph=True):
        logging.value("Visualizer: ",
                      args.visualizer if args.visualizer is not None else "None")
        visualizer = None
        if args.visualizer is not None:
            # Collect all 'visualizer_*' commandline arguments and hand over
            # the full set of runtime objects.
            kwargs = typeinf.kwargs_from_args(args, "visualizer")
            logging.values(kwargs)
            kwargs.update(
                args=args,
                model_and_loss=model_and_loss,
                optimizer=optimizer,
                param_scheduler=param_scheduler,
                lr_scheduler=lr_scheduler,
                train_loader=train_loader,
                validation_loader=validation_loader)
            visualizer = typeinf.instance_from_kwargs(args.visualizer_class,
                                                      kwargs=kwargs)
    return visualizer
def _log_statistics(loader, dataset):
    """Log per-key sizes of one example plus the total dataset length."""
    # Sizes are read from the first example of the dataset.
    example_dict = loader.first_item()
    for key, value in sorted(example_dict.items()):
        # Indices and name-like keys carry no size information worth showing.
        if key == "index" or "name" in key:
            continue
        if isinstance(value, str):
            displayed = value
        elif isinstance(value, (list, tuple)):
            displayed = _sizes_to_str(value[0])
        else:
            displayed = _sizes_to_str(value)
        logging.value("%s: " % key, displayed)
    logging.value("num_examples: ", len(dataset))
def configure_runtime_augmentations(args):
    """Build the training and validation augmentation modules.

    Returns ``(training_augmentation, validation_augmentation)``; each is
    instantiated from its ``*_class`` factory and moved to ``args.device``
    when configured, or None otherwise.
    """
    with logging.block("Runtime Augmentations", emph=True):
        training_augmentation = None
        validation_augmentation = None

        # ----------------------------------------------------
        # Training Augmentation
        # ----------------------------------------------------
        if args.training_augmentation is not None:
            kwargs = typeinf.kwargs_from_args(args, "training_augmentation")
            logging.value("training_augmentation: ", args.training_augmentation)
            with logging.block():
                logging.values(kwargs)
            kwargs["args"] = args
            training_augmentation = typeinf.instance_from_kwargs(
                args.training_augmentation_class, kwargs=kwargs)
            training_augmentation = training_augmentation.to(args.device)
        else:
            logging.info("training_augmentation: None")

        # ----------------------------------------------------
        # Validation Augmentation
        # ----------------------------------------------------
        if args.validation_augmentation is not None:
            kwargs = typeinf.kwargs_from_args(args, "validation_augmentation")
            # BUGFIX: previously this logged args.training_augmentation under
            # the "validation_augmentation: " label.
            logging.value("validation_augmentation: ", args.validation_augmentation)
            with logging.block():
                logging.values(kwargs)
            kwargs["args"] = args
            validation_augmentation = typeinf.instance_from_kwargs(
                args.validation_augmentation_class, kwargs=kwargs)
            validation_augmentation = validation_augmentation.to(args.device)
        else:
            logging.info("validation_augmentation: None")

    return training_augmentation, validation_augmentation
def fanmax_(mod_or_modules):
    """Initialize conv layers with N(0, 2 / max(fan_in, fan_out)).

    Here fan = channels * prod(kernel_size), following He et al. (2015) but
    using the larger of the two fans instead of a fixed fan mode. Batchnorm
    layers get identity initialization; any remaining parameterized layer is
    reported as uninitialized.
    """
    logging.value("Initializing FAN_MAX")
    modules = _2modules(mod_or_modules)
    uninitialized_modules = {}
    for layer in modules:
        if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.ConvTranspose2d):
            m = layer.in_channels
            n = layer.out_channels
            k = layer.kernel_size
            # BUGFIX: the kernel size must multiply the fan INSIDE the
            # division. The previous expression 2.0 / max(m, n) * prod(k)
            # multiplied the variance by the receptive-field size instead of
            # dividing by it, producing far-too-large initial weights.
            stddev = np.sqrt(2.0 / (np.maximum(m, n) * np.prod(k)))
            nn.init.normal_(layer.weight, mean=0.0, std=stddev)
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0.0)
        elif isinstance(layer, nn.BatchNorm2d):
            identity_(layer)
        else:
            uninitialized_modules = _check_uninitialized(uninitialized_modules, layer)
    for name, num_params in uninitialized_modules.items():
        logging.value("Found uninitialized layer of type '{}' [{} params]".format(name, num_params))
def initialize(self, filename):
    """Load Telegram bot configuration (chat id + per-host token) from JSON.

    Sets ``self.chat_id`` and ``self.token`` when ``filename`` contains a
    working token for the current hostname; otherwise logs why the bot
    stays unconfigured.
    """
    if not os.path.isfile(filename):
        logging.info("Could not find {}".format(filename))
    else:
        logging.info("Loading Telegram tokens from {}".format(filename))
        bots = None
        with open(filename, "r") as f:
            lines = f.readlines()
        try:
            # BUGFIX: json.loads() lost its 'encoding' keyword in Python 3.9
            # (it had been ignored since 3.1); passing it raises TypeError.
            bots = json.loads(''.join(lines))
        except Exception:
            raise ValueError('Could not read %s. %s' % (filename, sys.exc_info()[1]))
        if bots is not None:
            hostname = system.hostname()
            logging.value("Found Host: ", hostname)
            # Chat id is global; per-host bot tokens live under "machines".
            if "chat_id" in bots.keys():
                self.chat_id = bots["chat_id"]
            if "machines" in bots.keys():
                machines = bots["machines"]
                if hostname in machines.keys():
                    self.token = machines[hostname]
            if self.token is not None:
                try:
                    # try out once, to verify the token actually works
                    from telegrambotapiwrapper import Api
                    Api(token=self.token)
                except Exception:
                    # Invalid token or no network: fall back to unconfigured.
                    bots = None
                    self.chat_id = None
                    self.token = None
                    logging.info(
                        "Token seems to be invalid (or internet access is restricted)!"
                    )
        if bots is None or self.chat_id is None or self.token is None:
            logging.info("Could not set up telegram bot for some reason !")
def configure_parameter_scheduler(args, model_and_loss):
    """Build a container of per-group parameter schedulers.

    Groups come from ``args.param_scheduler_group``; each group's ``params``
    entry is a name pattern matched against the trainable parameters.
    Returns None when no groups are configured.
    """
    param_groups = args.param_scheduler_group
    with logging.block("Parameter Scheduler", emph=True):
        if param_groups is None:
            logging.info("None")
        else:
            logging.value("Info: ", "Please set lr=0 for scheduled parameters!")
            scheduled_parameter_groups = []
            with logging.block("parameter_groups:"):
                for group_kwargs in param_groups:
                    # 'params' holds the match pattern; every other key is a
                    # scheduler argument for this group.
                    match_pattern = group_kwargs["params"]
                    group_args = {
                        k: v
                        for k, v in group_kwargs.items() if k != "params"
                    }
                    with logging.block("%s: %s" % (match_pattern, group_args)):
                        gnames, gparams = _param_names_and_trainable_generator(
                            model_and_loss, match=match_pattern)
                        # Log every parameter name captured by this group.
                        for pname in sorted(gnames):
                            logging.info(pname)
                        group_args['params'] = gparams
                        scheduled_parameter_groups.append(group_args)
            # One scheduler per group, wrapped into a single container.
            schedulers = [
                _configure_parameter_scheduler_group(group)
                for group in scheduled_parameter_groups
            ]
            return facade.ParameterSchedulerContainer(schedulers)
    return None
def configure_random_seed(args):
    """Seed the python, numpy and torch RNGs from ``args.seed``.

    Each library receives a consecutive seed (seed, seed+1, ...) so their
    random streams differ. Does nothing when ``args.seed`` is None.
    """
    with logging.block("Random Seeds", emph=True):
        seed = args.seed
        if seed is None:
            logging.info("None")
        else:
            # python
            random.seed(seed)
            logging.value("Python seed: ", seed)

            # numpy
            seed += 1
            np.random.seed(seed)
            logging.value("Numpy seed: ", seed)

            # torch
            seed += 1
            torch.manual_seed(seed)
            logging.value("Torch CPU seed: ", seed)

            # torch cuda
            seed += 1
            torch.cuda.manual_seed(seed)
            logging.value("Torch CUDA seed: ", seed)
def configure_model_and_loss(args):
    """Instantiate model and loss and bundle them into a ModelAndLoss facade.

    Exits the process when no loss is configured; logs a summary of the
    resulting configuration.
    """
    with logging.block("Model and Loss", emph=True):
        # Build the model from its 'model_*' arguments.
        model_kwargs = typeinf.kwargs_from_args(args, "model")
        model_kwargs["args"] = args
        model = typeinf.instance_from_kwargs(args.model_class, kwargs=model_kwargs)

        # A loss is mandatory — without one there is nothing to optimize.
        loss = None
        if args.loss is None:
            logging.info("Loss is None; you need to pick a loss!")
            quit()
        else:
            loss_kwargs = typeinf.kwargs_from_args(args, "loss")
            loss_kwargs["args"] = args
            loss = typeinf.instance_from_kwargs(args.loss_class, kwargs=loss_kwargs)

        model_and_loss = facade.ModelAndLoss(args, model, loss)

        # Summarize the configuration.
        logging.value("Batch Size: ", args.batch_size)
        if loss is not None:
            logging.value("Loss: ", args.loss)
        logging.value("Network: ", args.model)
        logging.value("Number of parameters: ", model_and_loss.num_parameters())
        if loss is not None:
            logging.value("Training Key: ", args.training_key)
        if args.validation_dataset is not None:
            logging.value("Validation Keys: ", args.validation_keys)
            logging.value("Validation Modes: ", args.validation_modes)

    return model_and_loss
def configure_data_loaders(args):
    """Create the training and validation data loaders.

    Returns ``(train_loader_and_collation, validation_loader_and_collation)``;
    either entry is None when the corresponding dataset is not configured.
    Each DataLoader is wrapped in ``facade.LoaderAndCollation`` together with
    the batch collation determined from the args. Exits the process when a
    batch size exceeds the dataset size.
    """
    with logging.block("Datasets", emph=True):

        def _sizes_to_str(value):
            # Scalars display as '1L'; tensors as their size list with the
            # batch dimension replaced by '#'.
            if np.isscalar(value):
                return '1L'
            else:
                sizes = str([d for d in value.size()])
                return ' '.join([strings.replace_index(sizes, 1, '#')])

        def _log_statistics(loader, dataset):
            # Log per-key example sizes plus the dataset length.
            example_dict = loader.first_item()  # get sizes from first dataset example
            for key, value in sorted(example_dict.items()):
                if key == "index" or "name" in key:  # no need to display these
                    continue
                if isinstance(value, str):
                    logging.value("%s: " % key, value)
                elif isinstance(value, list) or isinstance(value, tuple):
                    logging.value("%s: " % key, _sizes_to_str(value[0]))
                else:
                    logging.value("%s: " % key, _sizes_to_str(value))
            logging.value("num_examples: ", len(dataset))

        # -----------------------------------------------------------------------------------------
        # GPU parameters
        # -----------------------------------------------------------------------------------------
        gpuargs = {
            "pin_memory": constants.DATALOADER_PIN_MEMORY
        } if args.cuda else {}

        train_loader_and_collation = None
        validation_loader_and_collation = None

        # -----------------------------------------------------------------
        # This figures out from the args alone, whether we need batch collation
        # -----------------------------------------------------------------
        train_collation, validation_collation = configure_collation(args)

        # -----------------------------------------------------------------------------------------
        # Training dataset
        # -----------------------------------------------------------------------------------------
        if args.training_dataset is not None:
            # ----------------------------------------------
            # Figure out training_dataset arguments
            # ----------------------------------------------
            kwargs = typeinf.kwargs_from_args(args, "training_dataset")
            kwargs["args"] = args
            # ----------------------------------------------
            # Create training dataset and loader
            # ----------------------------------------------
            logging.value("Training Dataset: ", args.training_dataset)
            with logging.block():
                train_dataset = typeinf.instance_from_kwargs(
                    args.training_dataset_class, kwargs=kwargs)
                if args.batch_size > len(train_dataset):
                    logging.info(
                        "Problem: batch_size bigger than number of training dataset examples!"
                    )
                    quit()
                train_loader = DataLoader(
                    train_dataset,
                    batch_size=args.batch_size,
                    shuffle=constants.TRAINING_DATALOADER_SHUFFLE,
                    drop_last=constants.TRAINING_DATALOADER_DROP_LAST,
                    num_workers=args.training_dataset_num_workers,
                    **gpuargs)
                train_loader_and_collation = facade.LoaderAndCollation(
                    args, loader=train_loader, collation=train_collation)
                _log_statistics(train_loader_and_collation, train_dataset)

        # -----------------------------------------------------------------------------------------
        # Validation dataset
        # -----------------------------------------------------------------------------------------
        if args.validation_dataset is not None:
            # ----------------------------------------------
            # Figure out validation_dataset arguments
            # ----------------------------------------------
            kwargs = typeinf.kwargs_from_args(args, "validation_dataset")
            kwargs["args"] = args

            # ------------------------------------------------------
            # per default batch_size is the same as for training,
            # unless a validation_batch_size is specified.
            # -----------------------------------------------------
            validation_batch_size = args.batch_size
            if args.validation_batch_size > 0:
                validation_batch_size = args.validation_batch_size

            # ----------------------------------------------
            # Create validation dataset and loader
            # ----------------------------------------------
            logging.value("Validation Dataset: ", args.validation_dataset)
            with logging.block():
                validation_dataset = typeinf.instance_from_kwargs(
                    args.validation_dataset_class, kwargs=kwargs)
                if validation_batch_size > len(validation_dataset):
                    logging.info(
                        "Problem: validation_batch_size bigger than number of validation dataset examples!"
                    )
                    quit()
                validation_loader = DataLoader(
                    validation_dataset,
                    batch_size=validation_batch_size,
                    shuffle=constants.VALIDATION_DATALOADER_SHUFFLE,
                    drop_last=constants.VALIDATION_DATALOADER_DROP_LAST,
                    num_workers=args.validation_dataset_num_workers,
                    **gpuargs)
                validation_loader_and_collation = facade.LoaderAndCollation(
                    args, loader=validation_loader, collation=validation_collation)
                _log_statistics(validation_loader_and_collation, validation_dataset)

    return train_loader_and_collation, validation_loader_and_collation
def configure_optimizer(args, model_and_loss):
    """Build the optimizer named by ``args.optimizer``.

    All trainable parameters of ``model_and_loss`` are handed to the
    optimizer, either as one flat collection or split into the parameter
    groups configured via the optimizer's ``group`` argument (remaining
    parameters form a default group). Returns None when no optimizer is
    configured or nothing is trainable.
    """
    optimizer = None
    with logging.block("Optimizer", emph=True):
        logging.value("Algorithm: ",
                      args.optimizer if args.optimizer is not None else "None")
        if args.optimizer is not None:
            if model_and_loss.num_parameters() == 0:
                logging.info("No trainable parameters detected.")
                logging.info("Setting optimizer to None.")
            else:
                with logging.block():
                    # -------------------------------------------
                    # Figure out all optimizer arguments
                    # -------------------------------------------
                    all_kwargs = typeinf.kwargs_from_args(args, "optimizer")

                    # -------------------------------------------
                    # Get the split of param groups
                    # -------------------------------------------
                    kwargs_without_groups = {
                        key: value
                        for key, value in all_kwargs.items() if key != "group"
                    }
                    param_groups = all_kwargs["group"]

                    # ----------------------------------------------------------------------
                    # Print arguments (without groups)
                    # ----------------------------------------------------------------------
                    logging.values(kwargs_without_groups)

                    # ----------------------------------------------------------------------
                    # Construct actual optimizer params
                    # ----------------------------------------------------------------------
                    kwargs = dict(kwargs_without_groups)
                    if param_groups is None:
                        # ---------------------------------------------------------
                        # Add all trainable parameters if there is no param groups
                        # ---------------------------------------------------------
                        all_trainable_parameters = _generate_trainable_params(
                            model_and_loss)
                        kwargs["params"] = all_trainable_parameters
                    else:
                        # -------------------------------------------
                        # Add list of parameter groups instead
                        # -------------------------------------------
                        trainable_parameter_groups = []
                        # dnames/dparams track parameters not yet claimed by
                        # any explicit group; they end up in a default group.
                        dnames, dparams = _param_names_and_trainable_generator(
                            model_and_loss)
                        dnames = set(dnames)
                        dparams = set(list(dparams))
                        with logging.block("parameter_groups:"):
                            for group in param_groups:
                                # log group settings
                                # ('params' holds the name match pattern;
                                # everything else is optimizer args)
                                group_match = group["params"]
                                group_args = {
                                    key: value
                                    for key, value in group.items()
                                    if key != "params"
                                }
                                with logging.block("%s: %s" %
                                                   (group_match, group_args)):
                                    # retrieve parameters by matching name
                                    gnames, gparams = _param_names_and_trainable_generator(
                                        model_and_loss, match=group_match)
                                    # log all names affected
                                    for n in sorted(gnames):
                                        logging.info(n)
                                    # set generator for group
                                    group_args["params"] = gparams
                                    # append parameter group
                                    trainable_parameter_groups.append(group_args)
                                    # update remaining trainable parameters
                                    dnames -= set(gnames)
                                    dparams -= set(list(gparams))
                            # append default parameter group
                            trainable_parameter_groups.append(
                                {"params": list(dparams)})
                            # and log its parameter names
                            with logging.block("default:"):
                                for dname in sorted(dnames):
                                    logging.info(dname)
                        # set params in optimizer kwargs
                        kwargs["params"] = trainable_parameter_groups

                    # -------------------------------------------
                    # Create optimizer instance
                    # -------------------------------------------
                    optimizer = typeinf.instance_from_kwargs(
                        args.optimizer_class, kwargs=kwargs)
    return optimizer
def exec_runtime(args, checkpoint_saver, model_and_loss, optimizer, lr_scheduler, param_scheduler, train_loader, validation_loader, training_augmentation, validation_augmentation, visualizer):
    """Run the main training/validation loop.

    Iterates epochs from ``args.start_epoch`` to ``args.total_epochs``:
    runs a training epoch (when ``train_loader`` is given) and a validation
    epoch (when ``validation_loader`` is given), steps the parameter and LR
    schedulers, tracks best validation losses, sends Telegram status
    updates, and stores checkpoints via ``checkpoint_saver``.
    """
    # --------------------------------------------------------------------------------
    # Validation schedulers are a bit special:
    # They need special treatment as they want to be called with a validation loss..
    # --------------------------------------------------------------------------------
    validation_scheduler = (lr_scheduler is not None
                            and args.lr_scheduler == "ReduceLROnPlateau")

    # --------------------------------------------------------
    # Log some runtime info
    # --------------------------------------------------------
    with logging.block("Runtime", emph=True):
        logging.value("start_epoch: ", args.start_epoch)
        logging.value("total_epochs: ", args.total_epochs)

    # ---------------------------------------
    # Total progress bar arguments
    # ---------------------------------------
    progressbar_args = {
        "desc": "Total",
        "initial": args.start_epoch - 1,
        "invert_iterations": True,
        "iterable": range(1, args.total_epochs + 1),
        "logging_on_close": True,
        "logging_on_update": True,
        "unit": "ep",
        "track_eta": True
    }

    # --------------------------------------------------------
    # Total progress bar
    # --------------------------------------------------------
    print(''), logging.logbook('')
    total_progress = create_progressbar(**progressbar_args)
    total_progress_stats = {}
    print("\n")

    # --------------------------------------------------------
    # Remember validation losses
    # --------------------------------------------------------
    best_validation_losses = None
    store_as_best = None
    if validation_loader is not None:
        num_validation_losses = len(args.validation_keys)
        # 'min' keys start at +inf, 'max' keys at -inf.
        best_validation_losses = [
            float("inf") if args.validation_modes[i] == 'min' else -float("inf")
            for i in range(num_validation_losses)
        ]
        store_as_best = [False for _ in range(num_validation_losses)]

    # ----------------------------------------------------------------
    # Send Telegram message
    # ----------------------------------------------------------------
    logging.telegram(format_telegram_status_update(args, epoch=0))

    avg_loss_dict = {}
    for epoch in range(args.start_epoch, args.total_epochs + 1):
        # --------------------------------
        # Make Epoch %i/%i header message
        # --------------------------------
        epoch_header = "Epoch {}/{}{}{}".format(
            epoch, args.total_epochs, " " * 24,
            format_epoch_header_machine_stats(args))

        with logger.LoggingBlock(epoch_header, emph=True):
            # -------------------------------------------------------------------------------
            # Let TensorBoard know where we are..
            # -------------------------------------------------------------------------------
            summary.set_global_step(epoch)

            # -----------------------------------------------------------------
            # Update standard learning scheduler and get current learning rate
            # -----------------------------------------------------------------
            # Starting with PyTorch 1.1 the expected validation order is:
            #     optimize(...) -> validate(...) -> scheduler.step()..

            # ---------------------------------------------------------------------
            # Update parameter schedule before the epoch
            # Note: Parameter schedulers are tuples of (optimizer, schedule)
            # ---------------------------------------------------------------------
            if param_scheduler is not None:
                param_scheduler.step(epoch=epoch)

            # -----------------------------------------------------------------
            # Get current learning rate from either optimizer or scheduler
            # -----------------------------------------------------------------
            lr = args.optimizer_lr if args.optimizer is not None else "None"
            if lr_scheduler is not None:
                lr = [group['lr'] for group in optimizer.param_groups] \
                    if args.optimizer is not None else "None"

            # --------------------------------------------------------
            # Current Epoch header stats
            # --------------------------------------------------------
            logging.info(format_epoch_header_stats(args, lr))

            # -------------------------------------------
            # Create and run a training epoch
            # -------------------------------------------
            if train_loader is not None:
                if visualizer is not None:
                    visualizer.on_epoch_init(lr,
                                             train=True,
                                             epoch=epoch,
                                             total_epochs=args.total_epochs)

                ema_loss_dict = RuntimeEpoch(
                    args,
                    desc="Train",
                    augmentation=training_augmentation,
                    loader=train_loader,
                    model_and_loss=model_and_loss,
                    optimizer=optimizer,
                    visualizer=visualizer).run(train=True)

                if visualizer is not None:
                    visualizer.on_epoch_finished(
                        ema_loss_dict,
                        train=True,
                        epoch=epoch,
                        total_epochs=args.total_epochs)

            # -------------------------------------------
            # Create and run a validation epoch
            # -------------------------------------------
            if validation_loader is not None:
                if visualizer is not None:
                    visualizer.on_epoch_init(lr,
                                             train=False,
                                             epoch=epoch,
                                             total_epochs=args.total_epochs)

                # ---------------------------------------------------
                # Construct holistic recorder for epoch
                # ---------------------------------------------------
                epoch_recorder = configure_holistic_epoch_recorder(
                    args, epoch=epoch, loader=validation_loader)

                with torch.no_grad():
                    avg_loss_dict = RuntimeEpoch(
                        args,
                        desc="Valid",
                        augmentation=validation_augmentation,
                        loader=validation_loader,
                        model_and_loss=model_and_loss,
                        recorder=epoch_recorder,
                        visualizer=visualizer).run(train=False)

                    try:
                        epoch_recorder.add_scalars("evaluation_losses",
                                                   avg_loss_dict)
                    except Exception:
                        # Best effort: recording evaluation losses must not
                        # break the run.
                        pass

                    if visualizer is not None:
                        visualizer.on_epoch_finished(
                            avg_loss_dict,
                            train=False,
                            epoch=epoch,
                            total_epochs=args.total_epochs)

                # ----------------------------------------------------------------
                # Evaluate validation losses
                # ----------------------------------------------------------------
                validation_losses = [
                    avg_loss_dict[vkey] for vkey in args.validation_keys
                ]
                for i, (vkey, vmode) in enumerate(
                        zip(args.validation_keys, args.validation_modes)):
                    if vmode == 'min':
                        store_as_best[i] = validation_losses[
                            i] < best_validation_losses[i]
                    else:
                        store_as_best[i] = validation_losses[
                            i] > best_validation_losses[i]
                    if store_as_best[i]:
                        best_validation_losses[i] = validation_losses[i]

                # ----------------------------------------------------------------
                # Update validation scheduler, if one is in place
                # We use the first key in validation keys as the relevant one
                # ----------------------------------------------------------------
                if lr_scheduler is not None:
                    if validation_scheduler:
                        lr_scheduler.step(validation_losses[0], epoch=epoch)
                    else:
                        lr_scheduler.step(epoch=epoch)

                # ----------------------------------------------------------------
                # Also show best loss on total_progress
                # ----------------------------------------------------------------
                total_progress_stats = {
                    "best_" + vkey + "_avg":
                    "%1.4f" % best_validation_losses[i]
                    for i, vkey in enumerate(args.validation_keys)
                }
                total_progress.set_postfix(total_progress_stats)

            # ----------------------------------------------------------------
            # Bump total progress
            # ----------------------------------------------------------------
            total_progress.update()
            print('')

            # ----------------------------------------------------------------
            # Get ETA string for display in loggers
            # ----------------------------------------------------------------
            eta_str = total_progress.eta_str()

            # ----------------------------------------------------------------
            # Send Telegram status update
            # ----------------------------------------------------------------
            total_progress_stats['lr'] = format_learning_rate(lr)
            logging.telegram(
                format_telegram_status_update(
                    args,
                    eta_str=eta_str,
                    epoch=epoch,
                    total_progress_stats=total_progress_stats))

            # ----------------------------------------------------------------
            # Update ETA in progress title
            # ----------------------------------------------------------------
            eta_proctitle = "{} finishes in {}".format(args.proctitle, eta_str)
            proctitles.setproctitle(eta_proctitle)

            # ----------------------------------------------------------------
            # Store checkpoint
            # ----------------------------------------------------------------
            if checkpoint_saver is not None and validation_loader is not None:
                checkpoint_saver.save_latest(
                    directory=args.save,
                    model_and_loss=model_and_loss,
                    stats_dict=dict(avg_loss_dict, epoch=epoch),
                    store_as_best=store_as_best,
                    store_prefixes=args.validation_keys)

            # ----------------------------------------------------------------
            # Vertical space between epochs
            # ----------------------------------------------------------------
            print(''), logging.logbook('')

    # ----------------------------------------------------------------
    # Finish up
    # ----------------------------------------------------------------
    logging.telegram_flush()
    total_progress.close()
    logging.info("Finished.")
def main():
    """Entry point: configure the environment, build all training components
    from the commandline, and hand off to the runtime loop.

    Returns whatever ``runtime.exec_runtime`` returns.  Calls ``quit()`` if no
    dataset can be loaded or no save directory is specified.
    """
    # ---------------------------------------------------
    # Set working directory to folder containing main.py
    # ---------------------------------------------------
    os.chdir(os.path.dirname(os.path.realpath(__file__)))

    # ----------------------------------------------------------------
    # Activate syntax highlighting in tracebacks for better debugging
    # ----------------------------------------------------------------
    colored_traceback.add_hook()

    # -----------------------------------------------------------
    # Configure logging
    # -----------------------------------------------------------
    logging_filename = os.path.join(commandline.parse_save_dir(),
                                    constants.LOGGING_LOGBOOK_FILENAME)
    logger.configure_logging(logging_filename)

    # ----------------------------------------------------------------
    # Register type factories before parsing the commandline.
    # NOTE: We decided to explicitly call these init() functions, to
    #       have more precise control over the timeline
    # ----------------------------------------------------------------
    with logging.block("Registering factories", emph=True):
        augmentations.init()
        datasets.init()
        losses.init()
        models.init()
        optim.init()
        visualizers.init()
        logging.info('Done!')

    # -----------------------------------------------------------
    # Parse commandline after factories have been filled
    # -----------------------------------------------------------
    args = commandline.parse_arguments(blocktitle="Commandline Arguments")

    # -----------------------
    # Telegram configuration
    # -----------------------
    with logging.block("Telegram", emph=True):
        logger.configure_telegram(constants.LOGGING_TELEGRAM_MACHINES_FILENAME)

    # ----------------------------------------------------------------------
    # Log git repository hash and make a compressed copy of the source code
    # ----------------------------------------------------------------------
    with logging.block("Source Code", emph=True):
        logging.value("Git Hash: ", system.git_hash())
        # Zip source code and copy to save folder
        filename = os.path.join(args.save,
                                constants.LOGGING_ZIPSOURCE_FILENAME)
        zipsource.create_zip(filename=filename, directory=os.getcwd())
        logging.value("Archived code: ", filename)

    # ----------------------------------------------------
    # Change process title for `top` and `pkill` commands
    # This is more "informative" in `nvidia-smi` ;-)
    # ----------------------------------------------------
    args = config.configure_proctitle(args)

    # -------------------------------------------------
    # Set random seed for python, numpy, torch, cuda..
    # -------------------------------------------------
    config.configure_random_seed(args)

    # -----------------------------------------------------------
    # Machine stats
    # -----------------------------------------------------------
    with logging.block("Machine Statistics", emph=True):
        if args.cuda:
            args.device = torch.device("cuda:0")
            logging.value("Cuda: ", torch.version.cuda)
            logging.value("Cuda device count: ", torch.cuda.device_count())
            logging.value("Cuda device name: ", torch.cuda.get_device_name(0))
            logging.value("CuDNN: ", torch.backends.cudnn.version())
            device_no = 0
            if 'CUDA_VISIBLE_DEVICES' in os.environ.keys():
                device_no = os.environ['CUDA_VISIBLE_DEVICES']
            args.actual_device = "gpu:%s" % device_no
        else:
            args.device = torch.device("cpu")
            args.actual_device = "cpu"
        logging.value("Hostname: ", system.hostname())
        logging.value("PyTorch: ", torch.__version__)
        logging.value("PyTorch device: ", args.actual_device)

    # ------------------------------------------------------
    # Fetch data loaders. Quit if no data loader is present
    # ------------------------------------------------------
    train_loader, validation_loader = config.configure_data_loaders(args)

    # -------------------------------------------------------------------------
    # Check whether any dataset could be found
    # -------------------------------------------------------------------------
    success = any(loader is not None
                  for loader in [train_loader, validation_loader])
    if not success:
        logging.info(
            "No dataset could be loaded successfully. Please check dataset paths!"
        )
        quit()

    # -------------------------------------------------------------------------
    # Configure runtime augmentations
    # -------------------------------------------------------------------------
    training_augmentation, validation_augmentation = config.configure_runtime_augmentations(
        args)

    # ----------------------------------------------------------
    # Configure model and loss.
    # ----------------------------------------------------------
    model_and_loss = config.configure_model_and_loss(args)

    # --------------------------------------------------------
    # Print model visualization
    # --------------------------------------------------------
    if args.logging_model_graph:
        with logging.block("Model Graph", emph=True):
            logger.log_module_info(model_and_loss.model)
    if args.logging_loss_graph:
        with logging.block("Loss Graph", emph=True):
            logger.log_module_info(model_and_loss.loss)

    # -------------------------------------------------------------------------
    # Possibly resume from checkpoint
    # -------------------------------------------------------------------------
    checkpoint_saver, checkpoint_stats = config.configure_checkpoint_saver(
        args, model_and_loss)
    if checkpoint_stats is not None:
        with logging.block():
            logging.info("Checkpoint Statistics:")
            with logging.block():
                logging.values(checkpoint_stats)
        # ---------------------------------------------------------------------
        # Set checkpoint stats
        # NOTE: guarded by checkpoint_stats is not None, since we index
        #       checkpoint_stats["epoch"] below.
        # ---------------------------------------------------------------------
        if args.checkpoint_mode in ["resume_from_best", "resume_from_latest"]:
            args.start_epoch = checkpoint_stats["epoch"]

    # ---------------------------------------------------------------------
    # Checkpoint and save directory
    # ---------------------------------------------------------------------
    with logging.block("Save Directory", emph=True):
        if args.save is None:
            logging.info("No 'save' directory specified!")
            quit()
        logging.value("Save directory: ", args.save)
        if not os.path.exists(args.save):
            os.makedirs(args.save)

    # ------------------------------------------------------------
    # If this is just an evaluation: overwrite savers and epochs
    # ------------------------------------------------------------
    if args.training_dataset is None and args.validation_dataset is not None:
        args.start_epoch = 1
        args.total_epochs = 1
        train_loader = None
        checkpoint_saver = None
        args.optimizer = None
        args.lr_scheduler = None

    # ----------------------------------------------------
    # Tensorboard summaries
    # ----------------------------------------------------
    logger.configure_tensorboard_summaries(args.save)

    # -------------------------------------------------------------------
    # From PyTorch API:
    # If you need to move a model to GPU via .cuda(), please do so before
    # constructing optimizers for it. Parameters of a model after .cuda()
    # will be different objects with those before the call.
    # In general, you should make sure that optimized parameters live in
    # consistent locations when optimizers are constructed and used.
    # -------------------------------------------------------------------
    model_and_loss = model_and_loss.to(args.device)

    # ----------------------------------------------------------
    # Configure optimizer
    # ----------------------------------------------------------
    optimizer = config.configure_optimizer(args, model_and_loss)

    # ----------------------------------------------------------
    # Configure learning rate
    # ----------------------------------------------------------
    lr_scheduler = config.configure_lr_scheduler(args, optimizer)

    # --------------------------------------------------------------------------
    # Configure parameter scheduling
    # --------------------------------------------------------------------------
    param_scheduler = config.configure_parameter_scheduler(
        args, model_and_loss)

    # ----------------------------------------------------------
    # Cuda optimization
    # ----------------------------------------------------------
    if args.cuda:
        torch.backends.cudnn.benchmark = constants.CUDNN_BENCHMARK

    # ----------------------------------------------------------
    # Configure runtime visualization
    # ----------------------------------------------------------
    visualizer = config.configure_visualizers(
        args,
        model_and_loss=model_and_loss,
        optimizer=optimizer,
        param_scheduler=param_scheduler,
        lr_scheduler=lr_scheduler,
        train_loader=train_loader,
        validation_loader=validation_loader)
    if visualizer is not None:
        visualizer = visualizer.to(args.device)

    # ----------------------------------------------------------
    # Kickoff training, validation and/or testing
    # ----------------------------------------------------------
    return runtime.exec_runtime(
        args,
        checkpoint_saver=checkpoint_saver,
        lr_scheduler=lr_scheduler,
        param_scheduler=param_scheduler,
        model_and_loss=model_and_loss,
        optimizer=optimizer,
        train_loader=train_loader,
        training_augmentation=training_augmentation,
        validation_augmentation=validation_augmentation,
        validation_loader=validation_loader,
        visualizer=visualizer)