def configure_lr_scheduler(args, optimizer):
    with logging.block("Learning Rate Scheduler", emph=True):
        logging.value("Scheduler: ",
                      args.lr_scheduler if args.lr_scheduler is not None else "None")
        lr_scheduler = None
        if args.lr_scheduler is not None:
            kwargs = typeinf.kwargs_from_args(args, "lr_scheduler")
            with logging.block():
                logging.values(kwargs)
            kwargs["optimizer"] = optimizer
            lr_scheduler = typeinf.instance_from_kwargs(
                args.lr_scheduler_class, kwargs=kwargs)
    return lr_scheduler
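# A hedged sketch (not the actual `typeinf` implementation, which is not part of
# this excerpt): `kwargs_from_args` presumably collects every `args.<prefix>_<name>`
# attribute into a kwargs dict, and `instance_from_kwargs` instantiates the given
# class with it. The helper names below are hypothetical stand-ins.
def _sketch_kwargs_from_args(args, prefix):
    # e.g. prefix="lr_scheduler" picks up args.lr_scheduler_milestones -> "milestones"
    prefix = prefix + "_"
    return {
        key[len(prefix):]: value
        for key, value in vars(args).items()
        if key.startswith(prefix) and not key.endswith("_class")
    }


def _sketch_instance_from_kwargs(class_constructor, kwargs):
    # e.g. MultiStepLR(optimizer=..., milestones=[10, 20])
    return class_constructor(**kwargs)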
def configure_model_and_loss(args):
    with logging.block("Model and Loss", emph=True):

        kwargs = typeinf.kwargs_from_args(args, "model")
        kwargs["args"] = args
        model = typeinf.instance_from_kwargs(args.model_class, kwargs=kwargs)

        loss = None
        if args.loss is not None:
            kwargs = typeinf.kwargs_from_args(args, "loss")
            kwargs["args"] = args
            loss = typeinf.instance_from_kwargs(args.loss_class, kwargs=kwargs)
        else:
            logging.info("Loss is None; you need to pick a loss!")
            quit()

        model_and_loss = facade.ModelAndLoss(args, model, loss)

        logging.value("Batch Size: ", args.batch_size)
        if loss is not None:
            logging.value("Loss: ", args.loss)
        logging.value("Network: ", args.model)
        logging.value("Number of parameters: ", model_and_loss.num_parameters())
        if loss is not None:
            logging.value("Training Key: ", args.training_key)
        if args.validation_dataset is not None:
            logging.value("Validation Keys: ", args.validation_keys)
            logging.value("Validation Modes: ", args.validation_modes)

    return model_and_loss
def restore_module_from_filename(module,
                                 filename,
                                 key='state_dict',
                                 include_params='*',
                                 exclude_params=(),
                                 translations=(),
                                 fuzzy_translation_keys=()):
    include_params = list(include_params)
    exclude_params = list(exclude_params)
    fuzzy_translation_keys = list(fuzzy_translation_keys)
    translations = dict(translations)

    # Make sure the checkpoint file exists
    if not os.path.isfile(filename):
        logging.info("Could not find checkpoint file '%s'!" % filename)
        quit()

    # Load checkpoint from file, including the state_dict
    cpu_device = torch.device('cpu')
    checkpoint_dict = torch.load(filename, map_location=cpu_device)
    checkpoint_state_dict = checkpoint_dict[key]

    try:
        restore_keys, actual_translations = restore_module_from_state_dict(
            module,
            checkpoint_state_dict,
            include_params=include_params,
            exclude_params=exclude_params,
            translations=translations,
            fuzzy_translation_keys=fuzzy_translation_keys)
    except KeyError:
        with logging.block('Checkpoint keys:'):
            logging.value(checkpoint_state_dict.keys())
        with logging.block('Module keys:'):
            logging.value(module.state_dict().keys())
        logging.info("Could not load checkpoint because of key errors. "
                     "Checkpoint translations gone wrong?")
        quit()

    return checkpoint_dict, restore_keys, actual_translations
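# A hedged sketch of what the key translations above are typically needed for
# (the real `restore_module_from_state_dict` is not part of this excerpt):
# checkpoints saved from a wrapped model often carry prefixes such as "module."
# or "_model.", which must be renamed before `load_state_dict` accepts them.
# The helper name is a hypothetical stand-in.
def _sketch_translate_state_dict_keys(state_dict, translations):
    translated = {}
    for key, value in state_dict.items():
        for old, new in translations.items():
            if key.startswith(old):
                key = new + key[len(old):]
                break
        translated[key] = value
    return translated

# Usage sketch:
#   sd = _sketch_translate_state_dict_keys(checkpoint["state_dict"], {"module.": ""})
#   module.load_state_dict(sd)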
def try_register(name, module_class, registry, calling_frame):
    if name in registry:
        block_info = "Warning in {}[{}]:".format(calling_frame.filename,
                                                 calling_frame.lineno)
        with logging.block(block_info):
            code_info = "{} yields duplicate factory entry!".format(
                calling_frame.code_context[0][0:-1])
            logging.value(code_info)
    registry[name] = module_class
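# A hedged sketch of how `try_register` is presumably called (the factory
# registration code is not part of this excerpt): the caller passes one entry
# of `inspect.stack()`, whose `filename`, `lineno` and `code_context` fields
# match the attributes used above. `_SKETCH_REGISTRY` and `register` are
# hypothetical names.
import inspect

_SKETCH_REGISTRY = {}


def register(name, module_class):
    calling_frame = inspect.stack()[1]  # frame of whoever called register()
    try_register(name, module_class, _SKETCH_REGISTRY, calling_frame)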
def import_submodules(package_name):
    with logging.block(package_name + '...'):
        content = _package_contents(package_name)
        for name in content:
            if name != "__init__":
                import_target = "%s.%s" % (package_name, name)
                try:
                    __import__(import_target)
                except Exception as err:
                    logging.info("ImportError in {}: {}".format(
                        import_target, str(err)))
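# A hedged sketch of what `_package_contents` presumably returns (it is not
# shown in this excerpt): the names of the modules directly inside a package,
# which is what the loop above iterates over. The helper name is a hypothetical
# stand-in; only documented stdlib APIs are used.
import importlib
import pkgutil


def _sketch_package_contents(package_name):
    package = importlib.import_module(package_name)
    return [name for _, name, _ in pkgutil.iter_modules(package.__path__)]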
def configure_runtime_augmentations(args):
    with logging.block("Runtime Augmentations", emph=True):

        training_augmentation = None
        validation_augmentation = None

        # Training Augmentation
        if args.training_augmentation is not None:
            kwargs = typeinf.kwargs_from_args(args, "training_augmentation")
            logging.value("training_augmentation: ", args.training_augmentation)
            with logging.block():
                logging.values(kwargs)
            kwargs["args"] = args
            training_augmentation = typeinf.instance_from_kwargs(
                args.training_augmentation_class, kwargs=kwargs)
            training_augmentation = training_augmentation.to(args.device)
        else:
            logging.info("training_augmentation: None")

        # Validation Augmentation
        if args.validation_augmentation is not None:
            kwargs = typeinf.kwargs_from_args(args, "validation_augmentation")
            logging.value("validation_augmentation: ", args.validation_augmentation)
            with logging.block():
                logging.values(kwargs)
            kwargs["args"] = args
            validation_augmentation = typeinf.instance_from_kwargs(
                args.validation_augmentation_class, kwargs=kwargs)
            validation_augmentation = validation_augmentation.to(args.device)
        else:
            logging.info("validation_augmentation: None")

    return training_augmentation, validation_augmentation
def configure_checkpoint_saver(args, model_and_loss):
    with logging.block('Checkpoint', emph=True):
        checkpoint_saver = checkpoints.CheckpointSaver()
        checkpoint_stats = None

        if args.checkpoint is None:
            logging.info('No checkpoint given.')
            logging.info('Starting from scratch with random initialization.')

        elif os.path.isfile(args.checkpoint):
            checkpoint_stats, filename = checkpoint_saver.restore(
                filename=args.checkpoint,
                model_and_loss=model_and_loss,
                include_params=args.checkpoint_include_params,
                exclude_params=args.checkpoint_exclude_params,
                translations=args.checkpoint_translations,
                fuzzy_translation_keys=args.checkpoint_fuzzy_translation_keys)

        elif os.path.isdir(args.checkpoint):
            if args.checkpoint_mode == 'best':
                logging.info('Loading best checkpoint in %s' % args.checkpoint)
                checkpoint_stats, filename = checkpoint_saver.restore_best(
                    directory=args.checkpoint,
                    model_and_loss=model_and_loss,
                    include_params=args.checkpoint_include_params,
                    exclude_params=args.checkpoint_exclude_params,
                    translations=args.checkpoint_translations,
                    fuzzy_translation_keys=args.checkpoint_fuzzy_translation_keys)
            elif args.checkpoint_mode == 'latest':
                logging.info('Loading latest checkpoint in %s' % args.checkpoint)
                checkpoint_stats, filename = checkpoint_saver.restore_latest(
                    directory=args.checkpoint,
                    model_and_loss=model_and_loss,
                    include_params=args.checkpoint_include_params,
                    exclude_params=args.checkpoint_exclude_params,
                    translations=args.checkpoint_translations,
                    fuzzy_translation_keys=args.checkpoint_fuzzy_translation_keys)
            else:
                logging.info("Unknown checkpoint_mode '%s' given!" % args.checkpoint_mode)
                quit()

        else:
            logging.info("Could not find checkpoint file or directory '%s'" % args.checkpoint)
            quit()

    return checkpoint_saver, checkpoint_stats
def parse_arguments(blocktitle):
    # Parse commandline and default arguments
    args, defaults = _parse_arguments()

    # Write arguments to file, as json and txt
    json.write_dictionary_to_file(
        vars(args), filename=os.path.join(args.save, "args.json"), sortkeys=True)
    json.write_dictionary_to_file(
        vars(args), filename=os.path.join(args.save, "args.txt"), sortkeys=True)

    # Log arguments
    non_default_args = []
    with logging.block(blocktitle, emph=True):
        for argument, value in sorted(vars(args).items()):
            reset = constants.COLOR_RESET
            if value == defaults[argument]:
                color = reset
            else:
                non_default_args.append((argument, value))
                color = constants.COLOR_NON_DEFAULT_ARGUMENT
            if isinstance(value, dict):
                dict_string = strings.dict_as_string(value)
                logging.info("{}{}: {}{}".format(color, argument, dict_string, reset))
            else:
                logging.info("{}{}: {}{}".format(color, argument, value, reset))

    # Remember non-default arguments
    args.non_default_args = dict((pair[0], pair[1]) for pair in non_default_args)

    # Postprocess
    args = postprocess_args(args)

    return args
def configure_parameter_scheduler(args, model_and_loss):
    param_groups = args.param_scheduler_group
    with logging.block("Parameter Scheduler", emph=True):
        if param_groups is None:
            logging.info("None")
        else:
            logging.value("Info: ", "Please set lr=0 for scheduled parameters!")
            scheduled_parameter_groups = []
            with logging.block("parameter_groups:"):
                for group_kwargs in param_groups:
                    group_match = group_kwargs["params"]
                    group_args = {
                        key: value
                        for key, value in group_kwargs.items() if key != "params"
                    }
                    with logging.block("%s: %s" % (group_match, group_args)):
                        gnames, gparams = _param_names_and_trainable_generator(
                            model_and_loss, match=group_match)
                        for n in sorted(gnames):
                            logging.info(n)
                        group_args['params'] = gparams
                        scheduled_parameter_groups.append(group_args)

            # create schedulers for every parameter group
            schedulers = [
                _configure_parameter_scheduler_group(kwargs)
                for kwargs in scheduled_parameter_groups
            ]

            # create container of parameter schedulers
            scheduler = facade.ParameterSchedulerContainer(schedulers)
            return scheduler

    return None
def configure_visualizers(args, model_and_loss, optimizer, param_scheduler,
                          lr_scheduler, train_loader, validation_loader):
    with logging.block("Runtime Visualizers", emph=True):
        logging.value("Visualizer: ",
                      args.visualizer if args.visualizer is not None else "None")
        visualizer = None
        if args.visualizer is not None:
            kwargs = typeinf.kwargs_from_args(args, "visualizer")
            logging.values(kwargs)
            kwargs["args"] = args
            kwargs["model_and_loss"] = model_and_loss
            kwargs["optimizer"] = optimizer
            kwargs["param_scheduler"] = param_scheduler
            kwargs["lr_scheduler"] = lr_scheduler
            kwargs["train_loader"] = train_loader
            kwargs["validation_loader"] = validation_loader
            visualizer = typeinf.instance_from_kwargs(args.visualizer_class,
                                                      kwargs=kwargs)
    return visualizer
def configure_random_seed(args):
    with logging.block("Random Seeds", emph=True):
        seed = args.seed
        if seed is not None:
            # python
            random.seed(seed)
            logging.value("Python seed: ", seed)
            # numpy
            seed += 1
            np.random.seed(seed)
            logging.value("Numpy seed: ", seed)
            # torch
            seed += 1
            torch.manual_seed(seed)
            logging.value("Torch CPU seed: ", seed)
            # torch cuda
            seed += 1
            torch.cuda.manual_seed(seed)
            logging.value("Torch CUDA seed: ", seed)
        else:
            logging.info("None")
def configure_optimizer(args, model_and_loss):
    optimizer = None
    with logging.block("Optimizer", emph=True):
        logging.value("Algorithm: ",
                      args.optimizer if args.optimizer is not None else "None")
        if args.optimizer is not None:
            if model_and_loss.num_parameters() == 0:
                logging.info("No trainable parameters detected.")
                logging.info("Setting optimizer to None.")
            else:
                with logging.block():
                    # Figure out all optimizer arguments
                    all_kwargs = typeinf.kwargs_from_args(args, "optimizer")

                    # Get the split of param groups
                    kwargs_without_groups = {
                        key: value
                        for key, value in all_kwargs.items() if key != "group"
                    }
                    param_groups = all_kwargs["group"]

                    # Print arguments (without groups)
                    logging.values(kwargs_without_groups)

                    # Construct actual optimizer params
                    kwargs = dict(kwargs_without_groups)
                    if param_groups is None:
                        # Add all trainable parameters if there are no param groups
                        all_trainable_parameters = _generate_trainable_params(
                            model_and_loss)
                        kwargs["params"] = all_trainable_parameters
                    else:
                        # Add a list of parameter groups instead
                        trainable_parameter_groups = []
                        dnames, dparams = _param_names_and_trainable_generator(
                            model_and_loss)
                        dnames = set(dnames)
                        dparams = set(list(dparams))

                        with logging.block("parameter_groups:"):
                            for group in param_groups:
                                # log group settings
                                group_match = group["params"]
                                group_args = {
                                    key: value
                                    for key, value in group.items()
                                    if key != "params"
                                }

                                with logging.block("%s: %s" % (group_match, group_args)):
                                    # retrieve parameters by matching name
                                    gnames, gparams = _param_names_and_trainable_generator(
                                        model_and_loss, match=group_match)
                                    # log all names affected
                                    for n in sorted(gnames):
                                        logging.info(n)
                                    # set generator for group
                                    group_args["params"] = gparams
                                    # append parameter group
                                    trainable_parameter_groups.append(group_args)
                                    # update remaining trainable parameters
                                    dnames -= set(gnames)
                                    dparams -= set(list(gparams))

                            # append the default parameter group
                            trainable_parameter_groups.append({"params": list(dparams)})
                            # ... and log its parameter names
                            with logging.block("default:"):
                                for dname in sorted(dnames):
                                    logging.info(dname)

                        # set params in optimizer kwargs
                        kwargs["params"] = trainable_parameter_groups

                    # Create optimizer instance
                    optimizer = typeinf.instance_from_kwargs(
                        args.optimizer_class, kwargs=kwargs)

    return optimizer
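# For reference, a hedged sketch of the structure that ends up in
# `kwargs["params"]` when parameter groups are used: PyTorch optimizers accept
# a list of dicts, each carrying its own hyperparameters (standard torch.optim
# behaviour). The tiny model below is only an illustration and assumes `torch`
# is imported as in the rest of this file.
def _sketch_param_group_optimizer():
    model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.Linear(8, 2))
    return torch.optim.Adam(
        [
            {"params": model[0].parameters(), "lr": 0.0},  # e.g. scheduled elsewhere
            {"params": model[1].parameters()},             # falls back to default lr
        ],
        lr=1e-4)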
def _handle_server(self):
    server = self._server_socket
    data, r_addr = server.recvfrom(BUF_SIZE)
    key = None
    iv = None
    if not data:
        logging.debug('UDP handle_server: data is empty')
    if self._stat_callback:
        self._stat_callback(self._listen_port, len(data))
    if self._is_local:
        if self._is_tunnel:
            # add ss header to data
            tunnel_remote = self.tunnel_remote
            tunnel_remote_port = self.tunnel_remote_port
            data = common.add_header(tunnel_remote, tunnel_remote_port, data)
        else:
            frag = common.ord(data[2])
            if frag != 0:
                logging.warn('UDP drop a message since frag is not 0')
                return
            else:
                data = data[3:]
    else:
        # decrypt data
        try:
            data, key, iv = cryptor.decrypt_all(self._password, self._method,
                                                data, self._crypto_path)
        except Exception:
            logging.debug('UDP handle_server: decrypt data failed')
            return
        if not data:
            logging.debug('UDP handle_server: data is empty after decrypt')
            return
    header_result = parse_header(data)
    if header_result is None:
        return
    addrtype, dest_addr, dest_port, header_length = header_result
    # logging.info("udp data to %s:%d from %s:%d"
    #              % (dest_addr, dest_port, r_addr[0], r_addr[1]))

    if 1:
        global trust_ip_list
        if r_addr[0] not in trust_ip_list:
            import redis
            client = redis.Redis(host='127.0.0.1', port=6379, db=0)
            trust_ip_list = client.get('trust_ip_list')
        if r_addr[0] not in trust_ip_list:
            logging.warn("udp block data to %s:%d from %s:%d"
                         % (dest_addr, dest_port, r_addr[0], r_addr[1]))
            return

    if self._is_local:
        server_addr, server_port = self._get_a_server()
    else:
        server_addr, server_port = dest_addr, dest_port
        # spec https://shadowsocks.org/en/spec/one-time-auth.html
        self._ota_enable_session = addrtype & ADDRTYPE_AUTH
        if self._ota_enable and not self._ota_enable_session:
            logging.warn('client one time auth is required')
            return
        if self._ota_enable_session:
            if len(data) < header_length + ONETIMEAUTH_BYTES:
                logging.warn('UDP one time auth header is too short')
                return
            _hash = data[-ONETIMEAUTH_BYTES:]
            data = data[:-ONETIMEAUTH_BYTES]
            _key = iv + key
            if onetimeauth_verify(_hash, data, _key) is False:
                logging.warn('UDP one time auth fail')
                return

    addrs = self._dns_cache.get(server_addr, None)
    if addrs is None:
        addrs = socket.getaddrinfo(server_addr, server_port, 0,
                                   socket.SOCK_DGRAM, socket.SOL_UDP)
        if not addrs:
            # drop
            return
        else:
            self._dns_cache[server_addr] = addrs

    af, socktype, proto, canonname, sa = addrs[0]
    key = client_key(r_addr, af)
    client = self._cache.get(key, None)
    if not client:
        # TODO async getaddrinfo
        if self._forbidden_iplist:
            if common.to_str(sa[0]) in self._forbidden_iplist:
                logging.debug('IP %s is in forbidden list, drop'
                              % common.to_str(sa[0]))
                # drop
                return
        client = socket.socket(af, socktype, proto)
        client.setblocking(False)
        self._cache[key] = client
        self._client_fd_to_server_addr[client.fileno()] = r_addr
        self._sockets.add(client.fileno())
        self._eventloop.add(client, eventloop.POLL_IN, self)

    if self._is_local:
        key, iv, m = cryptor.gen_key_iv(self._password, self._method)
        # spec https://shadowsocks.org/en/spec/one-time-auth.html
        if self._ota_enable_session:
            data = self._ota_chunk_data_gen(key, iv, data)
        try:
            data = cryptor.encrypt_all_m(key, iv, m, self._method, data,
                                         self._crypto_path)
        except Exception:
            logging.debug("UDP handle_server: encrypt data failed")
            return
        if not data:
            return
    else:
        data = data[header_length:]
    if not data:
        return
    try:
        client.sendto(data, (server_addr, server_port))
    except IOError as e:
        err = eventloop.errno_from_exception(e)
        if err in (errno.EINPROGRESS, errno.EAGAIN):
            pass
        else:
            shell.print_exception(e)
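# A hedged sketch of the address header that `parse_header` consumes above,
# following the SOCKS5-style layout used by shadowsocks: one address-type byte,
# then the address, then a 2-byte big-endian port. This is an illustration,
# not the project's own parser; only the IPv4 case (address type 1) is handled.
import socket
import struct


def _sketch_parse_ipv4_header(data):
    addrtype = data[0]
    if addrtype & 0x0f != 1:  # low bits: 1 = IPv4, 3 = hostname, 4 = IPv6
        return None
    dest_addr = socket.inet_ntoa(data[1:5])
    dest_port = struct.unpack('>H', data[5:7])[0]
    header_length = 7  # 1 byte type + 4 byte address + 2 byte port
    return addrtype, dest_addr, dest_port, header_length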
def exec_runtime(args,
                 checkpoint_saver,
                 model_and_loss,
                 optimizer,
                 lr_scheduler,
                 param_scheduler,
                 train_loader,
                 validation_loader,
                 training_augmentation,
                 validation_augmentation,
                 visualizer):

    # Validation schedulers are a bit special:
    # they need special treatment because they want to be called with a validation loss.
    validation_scheduler = (lr_scheduler is not None
                            and args.lr_scheduler == "ReduceLROnPlateau")

    # Log some runtime info
    with logging.block("Runtime", emph=True):
        logging.value("start_epoch: ", args.start_epoch)
        logging.value("total_epochs: ", args.total_epochs)

    # Total progress bar arguments
    progressbar_args = {
        "desc": "Total",
        "initial": args.start_epoch - 1,
        "invert_iterations": True,
        "iterable": range(1, args.total_epochs + 1),
        "logging_on_close": True,
        "logging_on_update": True,
        "unit": "ep",
        "track_eta": True
    }

    # Total progress bar
    print(''), logging.logbook('')
    total_progress = create_progressbar(**progressbar_args)
    total_progress_stats = {}
    print("\n")

    # Remember validation losses
    best_validation_losses = None
    store_as_best = None
    if validation_loader is not None:
        num_validation_losses = len(args.validation_keys)
        best_validation_losses = [
            float("inf") if args.validation_modes[i] == 'min' else -float("inf")
            for i in range(num_validation_losses)
        ]
        store_as_best = [False for _ in range(num_validation_losses)]

    # Send Telegram message
    logging.telegram(format_telegram_status_update(args, epoch=0))

    avg_loss_dict = {}
    for epoch in range(args.start_epoch, args.total_epochs + 1):

        # Make "Epoch %i/%i" header message
        epoch_header = "Epoch {}/{}{}{}".format(
            epoch, args.total_epochs, " " * 24,
            format_epoch_header_machine_stats(args))

        with logger.LoggingBlock(epoch_header, emph=True):

            # Let TensorBoard know where we are..
            summary.set_global_step(epoch)

            # Update the standard learning rate scheduler and get the current
            # learning rate. Starting with PyTorch 1.1 the expected order is:
            #   optimize(...); validate(...); scheduler.step()

            # Update parameter schedule before the epoch
            # Note: Parameter schedulers are tuples of (optimizer, schedule)
            if param_scheduler is not None:
                param_scheduler.step(epoch=epoch)

            # Get the current learning rate from either optimizer or scheduler
            lr = args.optimizer_lr if args.optimizer is not None else "None"
            if lr_scheduler is not None:
                lr = ([group['lr'] for group in optimizer.param_groups]
                      if args.optimizer is not None else "None")

            # Current epoch header stats
            logging.info(format_epoch_header_stats(args, lr))

            # Create and run a training epoch
            if train_loader is not None:
                if visualizer is not None:
                    visualizer.on_epoch_init(lr,
                                             train=True,
                                             epoch=epoch,
                                             total_epochs=args.total_epochs)

                ema_loss_dict = RuntimeEpoch(
                    args,
                    desc="Train",
                    augmentation=training_augmentation,
                    loader=train_loader,
                    model_and_loss=model_and_loss,
                    optimizer=optimizer,
                    visualizer=visualizer).run(train=True)

                if visualizer is not None:
                    visualizer.on_epoch_finished(ema_loss_dict,
                                                 train=True,
                                                 epoch=epoch,
                                                 total_epochs=args.total_epochs)

            # Create and run a validation epoch
            if validation_loader is not None:
                if visualizer is not None:
                    visualizer.on_epoch_init(lr,
                                             train=False,
                                             epoch=epoch,
                                             total_epochs=args.total_epochs)

                # Construct holistic recorder for epoch
                epoch_recorder = configure_holistic_epoch_recorder(
                    args, epoch=epoch, loader=validation_loader)

                with torch.no_grad():
                    avg_loss_dict = RuntimeEpoch(
                        args,
                        desc="Valid",
                        augmentation=validation_augmentation,
                        loader=validation_loader,
                        model_and_loss=model_and_loss,
                        recorder=epoch_recorder,
                        visualizer=visualizer).run(train=False)

                    try:
                        epoch_recorder.add_scalars("evaluation_losses", avg_loss_dict)
                    except Exception:
                        pass

                    if visualizer is not None:
                        visualizer.on_epoch_finished(avg_loss_dict,
                                                     train=False,
                                                     epoch=epoch,
                                                     total_epochs=args.total_epochs)

                # Evaluate validation losses
                validation_losses = [avg_loss_dict[vkey] for vkey in args.validation_keys]
                for i, (vkey, vmode) in enumerate(zip(args.validation_keys,
                                                      args.validation_modes)):
                    if vmode == 'min':
                        store_as_best[i] = validation_losses[i] < best_validation_losses[i]
                    else:
                        store_as_best[i] = validation_losses[i] > best_validation_losses[i]
                    if store_as_best[i]:
                        best_validation_losses[i] = validation_losses[i]

                # Update validation scheduler, if one is in place.
                # We use the first key in validation keys as the relevant one.
                if lr_scheduler is not None:
                    if validation_scheduler:
                        lr_scheduler.step(validation_losses[0], epoch=epoch)
                    else:
                        lr_scheduler.step(epoch=epoch)

                # Also show best loss on total_progress
                total_progress_stats = {
                    "best_" + vkey + "_avg": "%1.4f" % best_validation_losses[i]
                    for i, vkey in enumerate(args.validation_keys)
                }
                total_progress.set_postfix(total_progress_stats)

            # Bump total progress
            total_progress.update()
            print('')

            # Get ETA string for display in loggers
            eta_str = total_progress.eta_str()

            # Send Telegram status update
            total_progress_stats['lr'] = format_learning_rate(lr)
            logging.telegram(
                format_telegram_status_update(
                    args,
                    eta_str=eta_str,
                    epoch=epoch,
                    total_progress_stats=total_progress_stats))

            # Update ETA in process title
            eta_proctitle = "{} finishes in {}".format(args.proctitle, eta_str)
            proctitles.setproctitle(eta_proctitle)

            # Store checkpoint
            if checkpoint_saver is not None and validation_loader is not None:
                checkpoint_saver.save_latest(
                    directory=args.save,
                    model_and_loss=model_and_loss,
                    stats_dict=dict(avg_loss_dict, epoch=epoch),
                    store_as_best=store_as_best,
                    store_prefixes=args.validation_keys)

            # Vertical space between epochs
            print(''), logging.logbook('')

    # Finish up
    logging.telegram_flush()
    total_progress.close()
    logging.info("Finished.")
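# A hedged reference for the `validation_scheduler` special case above:
# `torch.optim.lr_scheduler.ReduceLROnPlateau` is the one standard scheduler
# whose `step()` expects a metric rather than an epoch index, which is why the
# first validation loss is passed to `lr_scheduler.step(...)`. A minimal
# self-contained illustration (the toy model and loss values are made up,
# and `torch` is assumed to be imported as in the rest of this file):
def _sketch_reduce_lr_on_plateau():
    model = torch.nn.Linear(2, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=2)
    for fake_validation_loss in [1.0, 0.9, 0.9, 0.9, 0.9, 0.9]:
        scheduler.step(fake_validation_loss)  # metric, not an epoch index
    return optimizer.param_groups[0]['lr']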
def main():
    # Set working directory to the folder containing main.py
    os.chdir(os.path.dirname(os.path.realpath(__file__)))

    # Activate syntax highlighting in tracebacks for better debugging
    colored_traceback.add_hook()

    # Configure logging
    logging_filename = os.path.join(commandline.parse_save_dir(),
                                    constants.LOGGING_LOGBOOK_FILENAME)
    logger.configure_logging(logging_filename)

    # Register type factories before parsing the commandline.
    # NOTE: We decided to explicitly call these init() functions to
    # have more precise control over the timeline.
    with logging.block("Registering factories", emph=True):
        augmentations.init()
        datasets.init()
        losses.init()
        models.init()
        optim.init()
        visualizers.init()
        logging.info('Done!')

    # Parse commandline after factories have been filled
    args = commandline.parse_arguments(blocktitle="Commandline Arguments")

    # Telegram configuration
    with logging.block("Telegram", emph=True):
        logger.configure_telegram(constants.LOGGING_TELEGRAM_MACHINES_FILENAME)

    # Log git repository hash and make a compressed copy of the source code
    with logging.block("Source Code", emph=True):
        logging.value("Git Hash: ", system.git_hash())
        # Zip source code and copy to save folder
        filename = os.path.join(args.save, constants.LOGGING_ZIPSOURCE_FILENAME)
        zipsource.create_zip(filename=filename, directory=os.getcwd())
        logging.value("Archived code: ", filename)

    # Change process title for `top` and `pkill` commands.
    # This is more "informative" in `nvidia-smi` ;-)
    args = config.configure_proctitle(args)

    # Set random seed for python, numpy, torch, cuda..
    config.configure_random_seed(args)

    # Machine stats
    with logging.block("Machine Statistics", emph=True):
        if args.cuda:
            args.device = torch.device("cuda:0")
            logging.value("Cuda: ", torch.version.cuda)
            logging.value("Cuda device count: ", torch.cuda.device_count())
            logging.value("Cuda device name: ", torch.cuda.get_device_name(0))
            logging.value("CuDNN: ", torch.backends.cudnn.version())
            device_no = 0
            if 'CUDA_VISIBLE_DEVICES' in os.environ.keys():
                device_no = os.environ['CUDA_VISIBLE_DEVICES']
            args.actual_device = "gpu:%s" % device_no
        else:
            args.device = torch.device("cpu")
            args.actual_device = "cpu"
        logging.value("Hostname: ", system.hostname())
        logging.value("PyTorch: ", torch.__version__)
        logging.value("PyTorch device: ", args.actual_device)

    # Fetch data loaders. Quit if no data loader is present.
    train_loader, validation_loader = config.configure_data_loaders(args)

    # Check whether any dataset could be found
    success = any(loader is not None for loader in [train_loader, validation_loader])
    if not success:
        logging.info("No dataset could be loaded successfully. "
                     "Please check dataset paths!")
        quit()

    # Configure runtime augmentations
    training_augmentation, validation_augmentation = \
        config.configure_runtime_augmentations(args)

    # Configure model and loss
    model_and_loss = config.configure_model_and_loss(args)

    # Print model visualization
    if args.logging_model_graph:
        with logging.block("Model Graph", emph=True):
            logger.log_module_info(model_and_loss.model)
    if args.logging_loss_graph:
        with logging.block("Loss Graph", emph=True):
            logger.log_module_info(model_and_loss.loss)

    # Possibly resume from checkpoint
    checkpoint_saver, checkpoint_stats = config.configure_checkpoint_saver(
        args, model_and_loss)
    if checkpoint_stats is not None:
        with logging.block():
            logging.info("Checkpoint Statistics:")
            with logging.block():
                logging.values(checkpoint_stats)
        # Set checkpoint stats
        if args.checkpoint_mode in ["resume_from_best", "resume_from_latest"]:
            args.start_epoch = checkpoint_stats["epoch"]

    # Checkpoint and save directory
    with logging.block("Save Directory", emph=True):
        if args.save is None:
            logging.info("No 'save' directory specified!")
            quit()
        logging.value("Save directory: ", args.save)
        if not os.path.exists(args.save):
            os.makedirs(args.save)

    # If this is just an evaluation: overwrite savers and epochs
    if args.training_dataset is None and args.validation_dataset is not None:
        args.start_epoch = 1
        args.total_epochs = 1
        train_loader = None
        checkpoint_saver = None
        args.optimizer = None
        args.lr_scheduler = None

    # Tensorboard summaries
    logger.configure_tensorboard_summaries(args.save)

    # From the PyTorch API:
    # If you need to move a model to GPU via .cuda(), please do so before
    # constructing optimizers for it. Parameters of a model after .cuda()
    # will be different objects from those before the call.
    # In general, you should make sure that optimized parameters live in
    # consistent locations when optimizers are constructed and used.
    model_and_loss = model_and_loss.to(args.device)

    # Configure optimizer
    optimizer = config.configure_optimizer(args, model_and_loss)

    # Configure learning rate scheduler
    lr_scheduler = config.configure_lr_scheduler(args, optimizer)

    # Configure parameter scheduling
    param_scheduler = config.configure_parameter_scheduler(args, model_and_loss)

    # Cuda optimization
    if args.cuda:
        torch.backends.cudnn.benchmark = constants.CUDNN_BENCHMARK

    # Configure runtime visualization
    visualizer = config.configure_visualizers(
        args,
        model_and_loss=model_and_loss,
        optimizer=optimizer,
        param_scheduler=param_scheduler,
        lr_scheduler=lr_scheduler,
        train_loader=train_loader,
        validation_loader=validation_loader)
    if visualizer is not None:
        visualizer = visualizer.to(args.device)

    # Kick off training, validation and/or testing
    return runtime.exec_runtime(
        args,
        checkpoint_saver=checkpoint_saver,
        lr_scheduler=lr_scheduler,
        param_scheduler=param_scheduler,
        model_and_loss=model_and_loss,
        optimizer=optimizer,
        train_loader=train_loader,
        training_augmentation=training_augmentation,
        validation_augmentation=validation_augmentation,
        validation_loader=validation_loader,
        visualizer=visualizer)
def configure_logging(filename):
    # set global indent level
    sys.modules[__name__].global_indent = 0

    # add custom tqdm logger
    add_logging_level("LOGBOOK", 1000)

    # create logger
    root_logger = logging.getLogger("")
    root_logger.setLevel(logging.INFO)

    # create console handler
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    fmt = get_default_logging_format(colorize=True, brackets=False)
    datefmt = constants.LOGGING_TIMESTAMP_FORMAT
    formatter = ConsoleFormatter(fmt=fmt, datefmt=datefmt)
    console.setFormatter(formatter)

    # Skip logging.logbook requests for console outputs
    skip_logbook_filter = SkipLogbookFilter()
    console.addFilter(skip_logbook_filter)

    # add console to root_logger
    root_logger.addHandler(console)

    # Show warnings in logger
    logging.captureWarnings(True)

    def _log_key_value_pair(key, value=None):
        if value is None:
            logging.info("{}{}".format(constants.COLOR_KEY_VALUE, str(key)))
        else:
            logging.info("{}{}{}".format(key, constants.COLOR_KEY_VALUE, str(value)))

    def _log_dict(indict):
        for key, value in sorted(indict.items()):
            logging.info("{}: {}{}".format(key, constants.COLOR_KEY_VALUE, str(value)))

    # this is for logging key/value pairs or dictionaries
    setattr(logging, "value", _log_key_value_pair)
    setattr(logging, "values", _log_dict)

    # this is for logging blocks
    setattr(logging, "block", LoggingBlock)

    # add logbook
    if filename is not None:
        # ensure dir
        d = os.path.dirname(filename)
        if not os.path.exists(d):
            os.makedirs(d)

        with logging.block("Creating Logbook", emph=True):
            logging.info(filename)

        # Configure handler that removes color codes from logbook
        logbook = logging.FileHandler(filename=filename, mode="a", encoding="utf-8")
        logbook.setLevel(logging.INFO)
        fmt = get_default_logging_format(colorize=False, brackets=True)
        logbook_formatter = LogbookFormatter(fmt=fmt, datefmt=datefmt)
        logbook.setFormatter(logbook_formatter)
        root_logger.addHandler(logbook)
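# A hedged sketch of what `add_logging_level("LOGBOOK", 1000)` presumably does
# (the helper itself is not part of this excerpt): register a named level with
# the standard `logging` module and attach a convenience method for it, so that
# calls such as `logging.logbook('')` elsewhere in this codebase work. Only
# documented `logging` APIs are used; the exact implementation may differ.
def _sketch_add_logging_level(name, level):
    logging.addLevelName(level, name)

    def _log_for_level(message, *args, **kwargs):
        logging.getLogger("").log(level, message, *args, **kwargs)

    setattr(logging, name.lower(), _log_for_level)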
def configure_data_loaders(args):
    with logging.block("Datasets", emph=True):

        def _sizes_to_str(value):
            if np.isscalar(value):
                return '1L'
            else:
                sizes = str([d for d in value.size()])
                return ' '.join([strings.replace_index(sizes, 1, '#')])

        def _log_statistics(loader, dataset):
            # get sizes from the first dataset example
            example_dict = loader.first_item()
            for key, value in sorted(example_dict.items()):
                if key == "index" or "name" in key:  # no need to display these
                    continue
                if isinstance(value, str):
                    logging.value("%s: " % key, value)
                elif isinstance(value, list) or isinstance(value, tuple):
                    logging.value("%s: " % key, _sizes_to_str(value[0]))
                else:
                    logging.value("%s: " % key, _sizes_to_str(value))
            logging.value("num_examples: ", len(dataset))

        # GPU parameters
        gpuargs = {"pin_memory": constants.DATALOADER_PIN_MEMORY} if args.cuda else {}

        train_loader_and_collation = None
        validation_loader_and_collation = None

        # This figures out from the args alone whether we need batch collation
        train_collation, validation_collation = configure_collation(args)

        # Training dataset
        if args.training_dataset is not None:
            # Figure out training_dataset arguments
            kwargs = typeinf.kwargs_from_args(args, "training_dataset")
            kwargs["args"] = args

            # Create training dataset and loader
            logging.value("Training Dataset: ", args.training_dataset)
            with logging.block():
                train_dataset = typeinf.instance_from_kwargs(
                    args.training_dataset_class, kwargs=kwargs)
                if args.batch_size > len(train_dataset):
                    logging.info("Problem: batch_size bigger than number of "
                                 "training dataset examples!")
                    quit()
                train_loader = DataLoader(
                    train_dataset,
                    batch_size=args.batch_size,
                    shuffle=constants.TRAINING_DATALOADER_SHUFFLE,
                    drop_last=constants.TRAINING_DATALOADER_DROP_LAST,
                    num_workers=args.training_dataset_num_workers,
                    **gpuargs)
                train_loader_and_collation = facade.LoaderAndCollation(
                    args, loader=train_loader, collation=train_collation)
                _log_statistics(train_loader_and_collation, train_dataset)

        # Validation dataset
        if args.validation_dataset is not None:
            # Figure out validation_dataset arguments
            kwargs = typeinf.kwargs_from_args(args, "validation_dataset")
            kwargs["args"] = args

            # By default the batch_size is the same as for training,
            # unless a validation_batch_size is specified.
            validation_batch_size = args.batch_size
            if args.validation_batch_size > 0:
                validation_batch_size = args.validation_batch_size

            # Create validation dataset and loader
            logging.value("Validation Dataset: ", args.validation_dataset)
            with logging.block():
                validation_dataset = typeinf.instance_from_kwargs(
                    args.validation_dataset_class, kwargs=kwargs)
                if validation_batch_size > len(validation_dataset):
                    logging.info("Problem: validation_batch_size bigger than "
                                 "number of validation dataset examples!")
                    quit()
                validation_loader = DataLoader(
                    validation_dataset,
                    batch_size=validation_batch_size,
                    shuffle=constants.VALIDATION_DATALOADER_SHUFFLE,
                    drop_last=constants.VALIDATION_DATALOADER_DROP_LAST,
                    num_workers=args.validation_dataset_num_workers,
                    **gpuargs)
                validation_loader_and_collation = facade.LoaderAndCollation(
                    args, loader=validation_loader, collation=validation_collation)
                _log_statistics(validation_loader_and_collation, validation_dataset)

    return train_loader_and_collation, validation_loader_and_collation