Example #1
    def path(name):
        if self._restore_path is not None and not self._use_new_dir:
            # - we don't want to overwrite the previous args
            dir = os.path.join(savedir,
                               'args_restore_%s' % self._timestring)
            safe_mkdirs(dir, exist_ok=True)
            return os.path.join(dir, name)
        else:
            return os.path.join(savedir, name)
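All of the snippets on this page call safe_mkdirs to make sure a target directory exists before writing into it. The helper itself is not shown here; the following is only a minimal sketch of what such a wrapper typically looks like (the default value of exist_ok is an assumption), not YAPT's actual implementation:

import os


def safe_mkdirs(path, exist_ok=False):
    """Sketch: create `path` and any missing parents without failing if it exists."""
    try:
        os.makedirs(path, exist_ok=exist_ok)
    except FileExistsError:
        # The directory already exists (or was created concurrently); nothing to do.
        pass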
Example #2
    def save_checkpoint(self, path=None, filename=None, is_best=False):
        if filename is None:
            filename = self._epoch

        if isinstance(filename, int):
            filename = self.checkpoints_format.format(filename)

        if path is None:
            path = self.checkpoints_dir

        if is_best:
            path = self.best_checkpoints_dir

        safe_mkdirs(path, exist_ok=True)

        try:
            filename = os.path.join(path, filename)

            current_state_dict = {
                'global_step': self._global_step,
                'epoch': self._epoch,
                'best_epoch': self.best_epoch,
                'beaten_epochs': self.beaten_epochs,
                'best_epoch_score': self.best_epoch_score,
                'best_stats': self.best_stats,
                'model_state_dict': self._model.state_dict(),
            }

            # -- there might be more than one optimizer
            if is_dict(self._model.optimizer):
                optimizer_state_dict = {}
                for key, opt in self._model.optimizer.items():
                    optimizer_state_dict.update({key: opt.state_dict()})
            else:
                optimizer_state_dict = self._model.optimizer.state_dict()

            current_state_dict.update(
                {'optimizer_state_dict': optimizer_state_dict})

            torch.save(current_state_dict, filename)

            # -- Track the filename so it can be deleted later
            # when keep_only_last_checkpoint is enabled
            if not is_best and 'init' not in filename:
                self.last_checkpoint = filename

        except Exception as e:
            self.console_log.error(
                "Error occurred while saving the checkpoint: %s", e)
        return filename
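A checkpoint written by save_checkpoint can be restored with torch.load. The sketch below only illustrates the reverse operation under the same state-dict layout; the function name load_checkpoint and its signature are assumptions, not part of YAPT:

import torch


def load_checkpoint(model, optimizer, filename, map_location='cpu'):
    # Illustrative counterpart to save_checkpoint above (hypothetical helper).
    state = torch.load(filename, map_location=map_location)
    model.load_state_dict(state['model_state_dict'])

    opt_state = state['optimizer_state_dict']
    if isinstance(optimizer, dict):
        # Mirror the multi-optimizer case handled during saving.
        for key, opt in optimizer.items():
            opt.load_state_dict(opt_state[key])
    else:
        optimizer.load_state_dict(opt_state)

    return state['epoch'], state['global_step']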
Example #3
    def save_results(self, outputs, name, idx=''):
        try:
            path = os.path.join(self.results_dir, name)
            filename = os.path.join(path, "%s%s.pt" % (name, idx))
            safe_mkdirs(path)

            results = {
                'global_step': self.global_step,
                'epoch': self.epoch,
                'best_epoch': self.best_epoch,
                'beaten_epochs': self.beaten_epochs,
                'best_epoch_score': self.best_epoch_score,
                'best_stats': self.best_stats,
                'outputs': outputs
            }
            torch.save(results, filename)

        except Exception as e:
            self.console_log.error(
                "Error occurred while saving results into : %s", e)
Example #4
    def configure_loggers(self, external_logdir=None):
        """
        YAPT supports logging experiments with multiple loggers at the same time.
        By default, an experiment is logged by TensorBoardLogger.

        external_logdir: overrides the default logdir.
            This can be useful, e.g., to log into the same directory used by Ray Tune.

        """
        if external_logdir is not None:
            self.args.loggers.logdir = self._logdir = external_logdir
            self.console_log.warning(
                "external logdir {}".format(external_logdir))

            if self.args.loggers.debug:
                self.console_log.warning(
                    "Debug flag is disable when external logdir is used")

        # -- No loggers in test mode
        if self.mode == 'test':
            return None

        args_logger = self.args.loggers
        loggers = dict()

        safe_mkdirs(self._logdir, exist_ok=True)

        # -- Tensorboard is not the default anymore
        if args_logger.tensorboard:
            loggers['tb'] = TensorBoardLogger(self._logdir)

        # -- Neptune
        if (get_maybe_missing_args(args_logger, 'neptune') is not None
                and len(args_logger.neptune.keys()) > 0):
            # TODO: because of the API key and other sensitive data,
            # the neptune settings should be kept per-project in a separate file

            # TODO: THIS SHOULD BE DONE FOR EACH LEAF
            args_neptune = dict()
            for key, val in args_logger.neptune.items():
                if isinstance(val, ListConfig):
                    val = list(val)
                elif isinstance(val, DictConfig):
                    val = dict(val)
                args_neptune[key] = val

            # -- Recursively search for files or extensions
            if 'upload_source_files' in args_neptune.keys():
                source_files = [
                    str(path) for ext in args_neptune['upload_source_files']
                    for path in Path('./').rglob(ext)
                ]
                del args_neptune['upload_source_files']
            else:
                source_files = None

            loggers['neptune'] = NeptuneLogger(
                api_key=os.environ['NEPTUNE_API_TOKEN'],
                experiment_name=self.args.exp_name,
                params=flatten_dict(self.args),
                logger=self.console_log,
                upload_source_files=source_files,
                **(args_neptune))

        # Wrap loggers
        loggers = LoggerDict(loggers)
        return loggers
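LoggerDict wraps the configured backends so that a single call fans out to all of them. Its real interface is not shown on this page; a minimal sketch of such a container (the method names log_metrics and close are assumptions) could look like this:

class LoggerDict(dict):
    """Sketch: a dict of loggers that forwards calls to every configured backend."""

    def log_metrics(self, metrics, step=None):
        # Assumed interface: each wrapped logger exposes a compatible log_metrics().
        for logger in self.values():
            logger.log_metrics(metrics, step=step)

    def close(self):
        for logger in self.values():
            logger.close()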
Example #5
    @property
    def videos_dir(self):
        _results_dir = os.path.join(self._logdir, 'videos')
        safe_mkdirs(_results_dir, True)
        return _results_dir
Example #6
    @property
    def images_dir(self):
        _results_dir = os.path.join(self._logdir, 'images')
        safe_mkdirs(_results_dir, True)
        return _results_dir
Example #7
    @property
    def plots_dir(self):
        _results_dir = os.path.join(self._logdir, 'plots')
        safe_mkdirs(_results_dir, True)
        return _results_dir
Example #8
    @property
    def best_checkpoints_dir(self):
        _checkpoints_dir = os.path.join(self.checkpoints_dir, 'best')
        safe_mkdirs(_checkpoints_dir, True)
        return _checkpoints_dir
Example #9
    @property
    def checkpoints_dir(self):
        _checkpoints_dir = os.path.join(self._logdir, 'checkpoints')
        safe_mkdirs(_checkpoints_dir, True)
        return _checkpoints_dir
Example #10
    @property
    def inputs_dir(self):
        _inputs_dir = os.path.join(self.debug_dir, 'inputs')
        safe_mkdirs(_inputs_dir, True)
        return _inputs_dir
Example #11
    @property
    def debug_dir(self):
        _debug_dir = os.path.join(self._logdir, 'debug')
        safe_mkdirs(_debug_dir, True)
        return _debug_dir
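Examples #5 to #11 all share the same shape: a read-only property that creates a subdirectory of the log directory on access and returns its path, so callers such as train_epoch below can simply write into self.inputs_dir without any setup. A condensed sketch of that pattern with a hypothetical, trimmed-down class (os.makedirs stands in for safe_mkdirs here to keep it self-contained):

import os


class ExperimentDirs:
    """Hypothetical class illustrating the lazy-directory properties above."""

    def __init__(self, logdir):
        self._logdir = logdir

    @property
    def debug_dir(self):
        _debug_dir = os.path.join(self._logdir, 'debug')
        os.makedirs(_debug_dir, exist_ok=True)  # stand-in for safe_mkdirs(_debug_dir, True)
        return _debug_dir

    @property
    def inputs_dir(self):
        # Accessing inputs_dir also creates debug_dir, since that property runs first.
        _inputs_dir = os.path.join(self.debug_dir, 'inputs')
        os.makedirs(_inputs_dir, exist_ok=True)
        return _inputs_dir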
Example #12
    def train_epoch(self, dataloader):

        # -- Call Optimizer schedulers and track lr
        if self._epoch > 1:
            # -- PyTorch >= 1.1.0 expects optimizer.step() before scheduler.step()
            self.call_schedulers_optimizers()
        self.log_each_epoch()

        # -- Initialization and training mode
        # Enters train mode
        self._model.train()
        # Zero the parameter gradients
        self._model.zero_grad()
        # Don't count this initial reset towards the model's zero_grad call counter
        self._model.zero_grad.calls -= 1
        # Track training statistics
        self._model.init_train_stats()

        pbar_descr_prefix = "ep %d (best: %d beaten: %d) - " % (
            self._epoch, self.best_epoch, self.beaten_epochs)

        self._train_pbar = tqdm(
            dataloader, total=self.num_batches_train,
            desc='', **self.args.loggers.tqdm)
        try:
            # -- Start epoch
            outputs = None
            accum_stats = []
            self._model.on_epoch_start()
            for batch_idx, batch in enumerate(self._train_pbar):
                if batch_idx >= self.num_batches_train:
                    break

                if self.args.debug.save_inputs:
                    # -- Save each raw input batch under inputs_dir/epoch_<n>/ for debugging
                    path = os.path.join(self.inputs_dir, 'epoch_{}'.format(self.epoch))
                    safe_mkdirs(path, exist_ok=True)
                    filename = os.path.join(path, 'batch_{}.pt'.format(batch_idx))
                    torch.save(batch, filename)

                device_batch = self.to_device(batch)

                # -- Model specific schedulers
                self._model.custom_schedulers()

                # -- Execute a training step
                outputs = self._model.training_step(
                    device_batch)

                # -- Accumulate stats from grads accum steps
                accum_stats.append(outputs.get('stats', dict()))

                # -- Save output for each training step
                # self.collect_outputs(outputs)

                if self._model._train_step % self.args.accum_batches == 0:
                    # -- Increment the global step only
                    # every accum_batches batches
                    self._global_step += 1

                    # -- Aggregate stats from grads accum steps
                    stats = self._model.aggregate_accum_stats(accum_stats)
                    accum_stats = []

                    # -- Eventually log statistics
                    if self._global_step % self.log_every == 0:
                        self._model.log_train(stats)
                        self._model.log_grads()

                running_tqdm = outputs.get('running_tqdm', dict())
                # self._train_pbar.set_postfix(ordered_dict=running_tqdm)
                self._train_pbar.set_description("ep %d - %s" % (
                    self.epoch, stats_to_str(running_tqdm)))
                self._train_pbar.update()

            self.console_log.info("Processed {} batches.".format(batch_idx))
            self._train_pbar.clear()
            self._train_pbar.close()

        except KeyboardInterrupt:
            self.console_log.info('Detected KeyboardInterrupt, attempting graceful shutdown...')
            self.shutdown()

        # -- End Epoch
        self._model.on_epoch_end()
        self._model.reset_train_stats()

        if outputs is not None:
            final_tqdm = outputs.get('final_tqdm', dict())
            print(pbar_descr_prefix + stats_to_str(final_tqdm))
        return outputs
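train_epoch formats its progress-bar text with stats_to_str, which is not shown on this page. A minimal sketch of such a formatter, assuming the stats are a flat mapping of scalar values (YAPT's exact formatting may differ):

def stats_to_str(stats):
    # Hypothetical sketch: join scalar statistics into a compact one-line string.
    return " ".join("%s: %.4f" % (key, float(val)) for key, val in stats.items())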