예제 #1
0
    def __init__(self,
                 ms_config_p, dl_config_p,
                 log_dir_root, log_config: LogConfig,
                 num_workers,
                 saver: Saver, restorer: TrainRestorer=None,
                 sw_cls=vis.safe_summary_writer.SafeSummaryWriter):
        """
        Set up all components for multiscale training, then initialize the base trainer.

        In order: parse both config files, create the data loaders, build the blueprint (network and losses),
        create the optimizer, optionally restore from `restorer`, create the LR schedule, prepare the log and
        checkpoint directories, create the summary writer, and call the superclass constructor.

        :param ms_config_p: Path to the multiscale (network) config file, see README
        :param dl_config_p: Path to the dataloader config file, see README
        :param log_dir_root: Root directory for all outputs (checkpoints, tensorboard).
        :param log_config: Instance of train.trainer.LogConfig, contains intervals.
        :param num_workers: Number of workers to use for DataLoading, see train.py
        :param saver: Saver instance to use; its output directory is set to the checkpoint dir here.
        :param restorer: Instance of TrainRestorer, if we need to restore; None otherwise.
        :param sw_cls: Called as sw_cls(log_dir) to create the summary writer.
        """
        # Parse both configs: config_ms is for the network (ms = multiscale), config_dl for data loading.
        parsed = map(config_parser.parse, [ms_config_p, dl_config_p])
        configs, rel_paths = ft.unzip(parsed)
        self.config_ms, self.config_dl = configs
        # Let global_config adjust the network config.
        global_config.update_config(self.config_ms)

        train_loader, val_loader = self._get_dataloaders(num_workers)

        # The blueprint bundles the network and the losses in one class, so it can be reused for testing.
        self.blueprint = MultiscaleBlueprint(self.config_ms)
        print('Network:', self.blueprint.net)

        # Pick the optimizer class named in the config; an unknown name raises KeyError.
        optimizers_by_name = {
            'RMSprop': optim.RMSprop,
            'Adam': optim.Adam,
            'SGD': optim.SGD,
        }
        make_optimizer = optimizers_by_name[self.config_ms.optim]
        network = self.blueprint.net
        self.optim = make_optimizer(
                network.parameters(),
                self.config_ms.lr.initial,
                weight_decay=self.config_ms.weight_decay)

        # Rough per-batch timing only (CUDA is async), but good enough to get a feeling during training.
        self.time_accumulator = timer.TimeAccumulator()

        # maybe_restore returns None when there is nothing to restore.
        resume_itr = self.maybe_restore(restorer)
        if resume_itr is not None:
            print('Skipping to {}...'.format(resume_itr))

        # LR schedule that updates the optimizer's parameters.
        lr_cfg = self.config_ms.lr
        self.lr_schedule = lr_schedule.from_spec(
                lr_cfg.schedule, lr_cfg.initial, [self.optim], epoch_len=len(train_loader))

        # --- All nn.Modules are set up ---
        print('-' * 80)

        # Log directory and, inside it, the checkpoint directory used by the saver.
        self.log_dir = Trainer.get_log_dir(log_dir_root, rel_paths, restorer)
        self.log_date = logdir_helpers.log_date_from_log_dir(self.log_dir)
        self.ckpt_dir = os.path.join(self.log_dir, CKPTS_DIR_NAME)
        print(f'Checkpoints will be saved to {self.ckpt_dir}')
        saver.set_out_dir(self.ckpt_dir)

        # Summary writer, shared with the network and the blueprint through one Summarizer.
        writer = sw_cls(self.log_dir)
        self.summarizer = vis.summarizable_module.Summarizer(writer)
        for summarizable in (network, self.blueprint):
            summarizable.register_summarizer(self.summarizer)

        # Hand everything over to the base trainer.
        base_kwargs = dict(max_epochs=self.config_dl.max_epochs,
                           log_config=log_config,
                           saver=saver,
                           skip_to_itr=resume_itr)
        super(MultiscaleTrainer, self).__init__(
                train_loader, val_loader, [self.optim], network, writer, **base_kwargs)
    def __init__(self,
                 config_p,
                 dl_config_p,
                 log_dir_root,
                 log_config: LogConfig,
                 num_workers,
                 saver: Saver,
                 restorer: TrainRestorer = None,
                 sw_cls=vis.safe_summary_writer.SafeSummaryWriter):
        """
        Set up all components for training, then initialize the base trainer.

        In order: determine the style ('enhancement' or 'classifier') from the network config, parse both
        config files, sync them with global_config, create the data loaders, blueprint, optimizer and LR
        schedule, optionally restore from `restorer`, prepare the log and checkpoint directories, create the
        summary writer, write the training file names to the log dir, and call the superclass constructor.

        :param config_p: Path to the network config file, see README
        :param dl_config_p: Path to the dataloader config file, see README
        :param log_dir_root: All outputs (checkpoints, tensorboard) will be saved here.
        :param log_config: Instance of train.trainer.LogConfig, contains intervals.
        :param num_workers: Number of workers to use for DataLoading, see train.py
        :param saver: Saver instance to use; its output directory is set to the checkpoint dir here.
        :param restorer: Instance of TrainRestorer, if we need to restore
        :param sw_cls: Called as sw_cls(log_dir) to create the summary writer.
        :raises KeyError: if the style or the configured optimizer name is not one of the supported values.
        :raises AttributeError: if the training dataset has no `write_file_names_to_txt` attribute.
        """
        # The style selects which blueprint class is used.
        self.style = MultiscaleTrainer.get_style_from_config(config_p)
        self.blueprint_cls = {
            'enhancement': EnhancementBlueprint,
            'classifier': ClassifierBlueprint
        }[self.style]

        global_config.declare_used('filter_imgs')

        # Read configs
        # config = config for the network
        # config_dl = config for data loading
        (self.config, self.config_dl), rel_paths = ft.unzip(
            map(config_parser.parse, [config_p, dl_config_p]))
        # TODO only read by enhancement classes
        self.config.is_residual = self.config_dl.is_residual_dataset

        # Update global_config given config.global_config. The returned keys are remembered so that they can
        # be ignored when collecting global_config values for the log dir below.
        global_config_config_keys = global_config.add_from_str_without_overwriting(
            self.config.global_config)
        # Update config depending on global_config
        global_config.update_config(self.config)

        if self.style == 'enhancement':
            EnhancementBlueprint.read_evenly_spaced_bins(self.config_dl)

        # Hook for additional setup; runs after the configs are ready and before data loaders are created.
        self._custom_init()

        # Create data loaders. Only the training loader stays local; the validation dataset and
        # fixed_first_val are stored on self.
        dl_train, self.ds_val, self.fixed_first_val = self._get_dataloaders(
            num_workers)
        # Create blueprint. A blueprint collects the network as well as the losses in one class, for easy reuse
        # during testing.
        self.blueprint = self.blueprint_cls(self.config)
        print('Network:', self.blueprint.net)
        # Setup optimizer
        optim_cls = {
            'RMSprop': optim.RMSprop,
            'Adam': optim.Adam,
            'SGD': optim.SGD,
        }[self.config.optim]
        net = self.blueprint.net
        self.optim = optim_cls(net.parameters(),
                               self.config.lr.initial,
                               weight_decay=self.config.weight_decay)
        # Calculate a rough estimate for time per batch (does not take into account that CUDA is async,
        # but good enough to get a feeling during training).
        self.time_accumulator = timer.TimeAccumulator()
        # Restore network if requested
        skip_to_itr = self.maybe_restore(restorer)
        if skip_to_itr is not None:  # i.e., we have a restorer
            print('Skipping to {}...'.format(skip_to_itr))
        # Create LR schedule to update parameters
        self.lr_schedule = lr_schedule.from_spec(self.config.lr.schedule,
                                                 self.config.lr.initial,
                                                 [self.optim],
                                                 epoch_len=len(dl_train))

        # --- All nn.Modules are set up ---
        print('-' * 80)

        # Create log dir. global_config values that came from the config file itself are ignored here.
        self.log_dir_root = log_dir_root
        global_config_values = global_config.values(
            ignore=global_config_config_keys)
        self.log_dir = Trainer.get_log_dir(
            log_dir_root,
            rel_paths,
            restorer,
            global_config_values=global_config_values)
        self.log_date = logdir_helpers.log_date_from_log_dir(self.log_dir)
        self.ckpt_dir = os.path.join(self.log_dir, CKPTS_DIR_NAME)
        print(f'Checkpoints will be saved to {self.ckpt_dir}')
        saver.set_out_dir(self.ckpt_dir)

        if global_config.get('ds_syn', None):
            # Unwrap nested datasets via their `.ds` attribute until the _CheckerboardDataset is reached,
            # then save it into the log dir.
            underlying = dl_train.dataset
            while not isinstance(underlying, _CheckerboardDataset):
                underlying = underlying.ds
            underlying.save_all(self.log_dir)

        # Create summary writer
        sw = sw_cls(self.log_dir)
        self.summarizer = vis.summarizable_module.Summarizer(sw)
        net.register_summarizer(self.summarizer)
        self.blueprint.register_summarizer(self.summarizer)

        # Write the file names of the training dataset into the log dir.
        # Only the attribute lookup is guarded: an AttributeError raised *inside* write_file_names_to_txt
        # is a genuine bug there and must not be reported as "not supported".
        try:
            write_file_names_to_txt = dl_train.dataset.write_file_names_to_txt
        except AttributeError as e:
            raise AttributeError(
                f'dl_train.dataset of type {type(dl_train.dataset)} does not support '
                f'write_file_names_to_txt(log_dir)!') from e
        write_file_names_to_txt(self.log_dir)

        # superclass setup; only the training loader is passed on.
        super(MultiscaleTrainer,
              self).__init__(dl_train, [self.optim],
                             net,
                             sw,
                             max_epochs=self.config_dl.max_epochs,
                             log_config=log_config,
                             saver=saver,
                             skip_to_itr=skip_to_itr)