def __init__(self, ms_config_p, dl_config_p, log_dir_root, log_config: LogConfig,
             num_workers,
             saver: Saver, restorer: TrainRestorer = None,
             sw_cls=vis.safe_summary_writer.SafeSummaryWriter):
    """
    :param ms_config_p: Path to the multiscale config file, see README
    :param dl_config_p: Path to the dataloader config file, see README
    :param log_dir_root: All outputs (checkpoints, tensorboard) will be saved here.
    :param log_config: Instance of train.trainer.LogConfig, contains intervals.
    :param num_workers: Number of workers to use for DataLoading, see train.py
    :param saver: Saver instance to use.
    :param restorer: Instance of TrainRestorer, if we need to restore
    """
    # Read configs
    # config_ms = config for the network (ms = multiscale)
    # config_dl = config for data loading
    (self.config_ms, self.config_dl), rel_paths = ft.unzip(map(
            config_parser.parse, [ms_config_p, dl_config_p]))
    # Update config_ms depending on global_config
    global_config.update_config(self.config_ms)

    # Create data loaders
    dl_train, dl_val = self._get_dataloaders(num_workers)

    # Create blueprint. A blueprint collects the network as well as the losses in one class,
    # for easy reuse during testing.
    self.blueprint = MultiscaleBlueprint(self.config_ms)
    print('Network:', self.blueprint.net)

    # Setup optimizer
    optim_cls = {'RMSprop': optim.RMSprop,
                 'Adam': optim.Adam,
                 'SGD': optim.SGD,
                 }[self.config_ms.optim]
    net = self.blueprint.net
    self.optim = optim_cls(net.parameters(), self.config_ms.lr.initial,
                           weight_decay=self.config_ms.weight_decay)

    # Calculate a rough estimate for time per batch (does not take into account that CUDA is
    # async, but good enough to get a feeling during training).
    self.time_accumulator = timer.TimeAccumulator()

    # Restore network if requested
    skip_to_itr = self.maybe_restore(restorer)
    if skip_to_itr is not None:  # i.e., we have a restorer
        print('Skipping to {}...'.format(skip_to_itr))

    # Create LR schedule to update parameters
    self.lr_schedule = lr_schedule.from_spec(
            self.config_ms.lr.schedule, self.config_ms.lr.initial, [self.optim],
            epoch_len=len(dl_train))

    # --- All nn.Modules are setup ---
    print('-' * 80)

    # Create log dir and summary writer
    self.log_dir = Trainer.get_log_dir(log_dir_root, rel_paths, restorer)
    self.log_date = logdir_helpers.log_date_from_log_dir(self.log_dir)
    self.ckpt_dir = os.path.join(self.log_dir, CKPTS_DIR_NAME)
    print(f'Checkpoints will be saved to {self.ckpt_dir}')
    saver.set_out_dir(self.ckpt_dir)

    # Create summary writer
    sw = sw_cls(self.log_dir)
    self.summarizer = vis.summarizable_module.Summarizer(sw)
    net.register_summarizer(self.summarizer)
    self.blueprint.register_summarizer(self.summarizer)

    # Superclass setup
    super(MultiscaleTrainer, self).__init__(dl_train, dl_val, [self.optim], net, sw,
                                            max_epochs=self.config_dl.max_epochs,
                                            log_config=log_config, saver=saver,
                                            skip_to_itr=skip_to_itr)
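
# Sketch (assumption, not part of the original trainer): the optimizer setup above selects
# the optimizer class by the name stored in the config. The helper below is a hypothetical,
# self-contained version of that dispatch pattern; it only relies on the `optim` import
# (torch.optim) already used in this module, and its name is introduced here for illustration.
def _optimizer_dispatch_sketch(net, name, lr, weight_decay=0.0):
    # Map config string -> torch optimizer class; an unknown name raises a KeyError,
    # mirroring the behaviour of the dict lookup in __init__ above.
    optim_cls = {'RMSprop': optim.RMSprop,
                 'Adam': optim.Adam,
                 'SGD': optim.SGD}[name]
    return optim_cls(net.parameters(), lr, weight_decay=weight_decay)
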
def __init__(self, config_p, dl_config_p, log_dir_root, log_config: LogConfig,
             num_workers,
             saver: Saver, restorer: TrainRestorer = None,
             sw_cls=vis.safe_summary_writer.SafeSummaryWriter):
    """
    :param config_p: Path to the network config file, see README
    :param dl_config_p: Path to the dataloader config file, see README
    :param log_dir_root: All outputs (checkpoints, tensorboard) will be saved here.
    :param log_config: Instance of train.trainer.LogConfig, contains intervals.
    :param num_workers: Number of workers to use for DataLoading, see train.py
    :param saver: Saver instance to use.
    :param restorer: Instance of TrainRestorer, if we need to restore
    """
    self.style = MultiscaleTrainer.get_style_from_config(config_p)
    self.blueprint_cls = {'enhancement': EnhancementBlueprint,
                          'classifier': ClassifierBlueprint}[self.style]

    global_config.declare_used('filter_imgs')

    # Read configs
    # config = config for the network
    # config_dl = config for data loading
    (self.config, self.config_dl), rel_paths = ft.unzip(map(
            config_parser.parse, [config_p, dl_config_p]))
    # TODO: only read by enhancement classes
    self.config.is_residual = self.config_dl.is_residual_dataset

    # Update global_config given config.global_config
    global_config_config_keys = global_config.add_from_str_without_overwriting(
            self.config.global_config)
    # Update config depending on global_config
    global_config.update_config(self.config)

    if self.style == 'enhancement':
        EnhancementBlueprint.read_evenly_spaced_bins(self.config_dl)

    self._custom_init()

    # Create data loaders
    dl_train, self.ds_val, self.fixed_first_val = self._get_dataloaders(num_workers)

    # Create blueprint. A blueprint collects the network as well as the losses in one class,
    # for easy reuse during testing.
    self.blueprint = self.blueprint_cls(self.config)
    print('Network:', self.blueprint.net)

    # Setup optimizer
    optim_cls = {'RMSprop': optim.RMSprop,
                 'Adam': optim.Adam,
                 'SGD': optim.SGD,
                 }[self.config.optim]
    net = self.blueprint.net
    self.optim = optim_cls(net.parameters(), self.config.lr.initial,
                           weight_decay=self.config.weight_decay)

    # Calculate a rough estimate for time per batch (does not take into account that CUDA is
    # async, but good enough to get a feeling during training).
    self.time_accumulator = timer.TimeAccumulator()

    # Restore network if requested
    skip_to_itr = self.maybe_restore(restorer)
    if skip_to_itr is not None:  # i.e., we have a restorer
        print('Skipping to {}...'.format(skip_to_itr))

    # Create LR schedule to update parameters
    self.lr_schedule = lr_schedule.from_spec(
            self.config.lr.schedule, self.config.lr.initial, [self.optim],
            epoch_len=len(dl_train))

    # --- All nn.Modules are setup ---
    print('-' * 80)

    # Create log dir and summary writer
    self.log_dir_root = log_dir_root
    global_config_values = global_config.values(ignore=global_config_config_keys)
    self.log_dir = Trainer.get_log_dir(log_dir_root, rel_paths, restorer,
                                       global_config_values=global_config_values)
    self.log_date = logdir_helpers.log_date_from_log_dir(self.log_dir)
    self.ckpt_dir = os.path.join(self.log_dir, CKPTS_DIR_NAME)
    print(f'Checkpoints will be saved to {self.ckpt_dir}')
    saver.set_out_dir(self.ckpt_dir)

    if global_config.get('ds_syn', None):
        # Unwrap nested dataset wrappers until the underlying checkerboard dataset is reached,
        # then save its synthetic images to the log dir.
        underlying = dl_train.dataset
        while not isinstance(underlying, _CheckerboardDataset):
            underlying = underlying.ds
        underlying.save_all(self.log_dir)

    # Create summary writer
    sw = sw_cls(self.log_dir)
    self.summarizer = vis.summarizable_module.Summarizer(sw)
    net.register_summarizer(self.summarizer)
    self.blueprint.register_summarizer(self.summarizer)

    # Try to write filenames somewhere
    try:
        dl_train.dataset.write_file_names_to_txt(self.log_dir)
    except AttributeError:
        raise AttributeError(f'dl_train.dataset of type {type(dl_train.dataset)} does not '
                             f'support write_file_names_to_txt(log_dir)!')

    # Superclass setup
    super(MultiscaleTrainer, self).__init__(dl_train, [self.optim], net, sw,
                                            max_epochs=self.config_dl.max_epochs,
                                            log_config=log_config, saver=saver,
                                            skip_to_itr=skip_to_itr)
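
# Sketch (assumption, not the repo's timer.TimeAccumulator): a minimal stand-in for the
# "rough time per batch" estimate used above. It accumulates wall-clock durations around
# each batch and reports their mean; as the comment in __init__ notes, it does not
# synchronize CUDA, so the numbers give a feeling for speed rather than a benchmark.
# The class and method names here are hypothetical, introduced only for illustration.
import time  # import used only by this sketch


class _RoughTimeAccumulatorSketch:
    def __init__(self):
        self._total = 0.0
        self._count = 0

    def start(self):
        # Call before running a batch.
        self._start = time.time()

    def stop(self):
        # Call after the batch; adds the elapsed wall-clock time to the running total.
        self._total += time.time() - self._start
        self._count += 1

    def mean_time_spent(self):
        # Mean seconds per recorded batch (0.0 before the first batch completes).
        return self._total / max(self._count, 1)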