Example #1
 def test_step(self, *args, **kwargs) -> dict:
     self.eval()
     outputs = self._test_step(*args, **kwargs)
     assert is_dict(outputs), "Output of _test_step should be a dict"
     # Detach and move to cpu to avoid holding on to gpu memory
     outputs = detach_dict(outputs)
     outputs = to_device(outputs, 'cpu')
     return outputs
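
The snippet relies on yapt's `is_dict`, `detach_dict` and `to_device` helpers, whose implementations are not shown here. The following is only a minimal sketch of what such helpers could look like, assuming the outputs are flat dicts of tensors; the actual yapt utilities may handle nesting and non-tensor values differently.

    import torch

    def is_dict(obj):
        # Anything behaving like a plain mapping counts as a dict here
        return isinstance(obj, dict)

    def detach_dict(d):
        # Detach every tensor so no autograd graph is kept alive
        return {k: (v.detach() if torch.is_tensor(v) else v)
                for k, v in d.items()}

    def to_device(d, device):
        # Move every tensor in the dict to the target device
        return {k: (v.to(device) if torch.is_tensor(v) else v)
                for k, v in d.items()}
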
Example #2
 def training_step(self, batch, *args, **kwargs) -> dict:
     self.train()
     outputs = self._training_step(batch, *args, **kwargs)
     assert is_dict(outputs), "Output of _training_step should be a dict"
     # Detach and move to cpu to avoid holding on to gpu memory
     outputs = detach_dict(outputs)
     outputs = to_device(outputs, 'cpu')
     self._train_step += 1
     return outputs
Example #3
    def _check_optimizer_accum(self):
        """
        Check if helper functions for `loss.backward()`, `optim.step()` and
        `optim.zero_grad()` are used instead of the pytorch originals when
        gradient accumulation or clipping is used.
        """
        # If not already, convert into a dict for convenience
        optim_dict = (self.optimizer if is_dict(self.optimizer) else {
            'optimizer': self.optimizer
        })

        clip_val = self.args.grads_norm_clip.max_norm
        accum_batches = self.args.accum_batches
        if clip_val > 0 or accum_batches > 1:
            # -- Warning messages stuff
            warn_pattern = ("{1} has been called directly while using {0}, "
                            "which will prevent it from working. Please use "
                            "{2} instead.")
            used_methods = {
                'grads clipping': clip_val > 0,
                'grads accumulation': accum_batches > 1
            }
            methods = "".join([
                "%s and " % name for name, active in used_methods.items()
                if active
            ])[:-len(" and ")]

            # -- Check if optimizer.step() has been directly called
            steps_calls = [o.step.calls for o in optim_dict.values()]
            if max(steps_calls) > 0:
                self.console_log.warning(
                    warn_pattern.format(methods, "optimizer.step()",
                                        "self.optimize_step(optimizer)"))

            # -- Check if optimizer.zero_grad() has been directly called
            zero_calls = [o.zero_grad.calls for o in optim_dict.values()]
            if max(zero_calls) > 0:
                self.console_log.warning(
                    warn_pattern.format(methods, "optimizer.zero_grad()",
                                        "self.zero_grad_step(optimizer)"))

            # -- Check if model.zero_grad() has been directly called
            if self.zero_grad.calls > 0:
                self.console_log.warning(
                    warn_pattern.format(methods, "self.zero_grad()",
                                        "self.zero_grad_step(optimizer)"))

            # -- Check if self.compute_grad_step has been called at least once
            if self.compute_grad_step.calls == 0:
                self.console_log.warning(
                    warn_pattern.format(methods, "loss.backward()",
                                        "self.compute_grad_step(loss)"))
Example #4
File: base.py  Project: marcociccone/yapt
    def call_schedulers_optimizers(self):

        schedulers = self._model.scheduler_optimizer

        if isinstance(schedulers, torch.optim.lr_scheduler._LRScheduler):
            schedulers.step()

        elif is_dict(schedulers):
            for _, scheduler in schedulers.items():
                scheduler.step()
        else:
            raise ValueError(
                "scheduler_optimizer should be a dict or a "
                "torch.optim.lr_scheduler._LRScheduler object")
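
`call_schedulers_optimizers` accepts either a single scheduler or a dict of schedulers, so a model can expose `scheduler_optimizer` in either form. A minimal sketch of the two options (the attribute wiring below is assumed for illustration, not taken from yapt's model base class):

    import torch

    net = torch.nn.Linear(10, 2)
    optimizer = torch.optim.SGD(net.parameters(), lr=0.1)

    # Either a single scheduler ...
    scheduler_optimizer = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=10, gamma=0.5)

    # ... or a dict with one scheduler per optimizer
    scheduler_optimizer = {
        'sgd': torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=10, gamma=0.5),
    }
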
Example #5
    def save_checkpoint(self, path=None, filename=None, is_best=False):
        if filename is None:
            filename = self._epoch

        if isinstance(filename, int):
            filename = self.checkpoints_format.format(filename)

        if path is None:
            path = self.checkpoints_dir

        if is_best:
            path = self.best_checkpoints_dir

        safe_mkdirs(path, exist_ok=True)

        try:
            filename = os.path.join(path, filename)

            current_state_dict = {
                'global_step': self._global_step,
                'epoch': self._epoch,
                'best_epoch': self.best_epoch,
                'beaten_epochs': self.beaten_epochs,
                'best_epoch_score': self.best_epoch_score,
                'best_stats': self.best_stats,
                'model_state_dict': self._model.state_dict(),
            }

            # -- there might be more than one optimizer
            if is_dict(self._model.optimizer):
                optimizer_state_dict = {}
                for key, opt in self._model.optimizer.items():
                    optimizer_state_dict.update({key: opt.state_dict()})
            else:
                optimizer_state_dict = self._model.optimizer.state_dict()

            current_state_dict.update(
                {'optimizer_state_dict': optimizer_state_dict})

            torch.save(current_state_dict, filename)

            # -- track the filename so it can be deleted later,
            # if keep_only_last_checkpoint is set to True
            if not is_best and 'init' not in filename:
                self.last_checkpoint = filename

        except Exception as e:
            self.console_log.error(
                "Error occurred while saving the checkpoint: %s", e)
        return filename
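
Typical usage would be to call `save_checkpoint` with the defaults at the end of every epoch, and again with `is_best=True` whenever the monitored score improves. The calls below are only a sketch, assuming a `trainer` instance that exposes the method as shown:

    # Regular per-epoch checkpoint, named via checkpoints_format
    last_ckpt = trainer.save_checkpoint()

    # Extra copy of the best epoch, stored under best_checkpoints_dir
    best_ckpt = trainer.save_checkpoint(is_best=True)
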
Example #6
    def _train(self):
        args = self.args
        if args.dry_run:
            print(args.extra_args.pretty())
            self.stop()
            return {}

        # -- Training epoch
        # TODO: the 'labelled' loader key is hardcoded for now, make it general
        self._runner.train_epoch(self._runner.train_loader['labelled'])
        self.epoch = self._runner.epoch

        # -- Validate over all datasets
        val_outputs_flat = OrderedDict()
        for key_loader, val_loader in self._runner.val_loader.items():
            num_batches = self._runner.num_batches_val[key_loader]
            # -- Validation on val_loader
            outputs = self._runner.validate(val_loader,
                                            num_batches=num_batches,
                                            set_name=key_loader,
                                            logger=self._runner.logger)

            # -- TODO: flatten_dict should not be necessary,
            # -- the prefix key is already concatenated in the validate method
            # -- collect and return flattened metrics
            for key_stats, val_stats in outputs['stats'].items():
                if is_dict(val_stats):
                    if 'scalar' in val_stats.keys():
                        _flat = flatten_dict(val_stats['scalar'], False)
                    elif 'scalars' in val_stats.keys():
                        _flat = flatten_dict(val_stats['scalars'], False)
                    else:
                        _flat = flatten_dict(val_stats, False)
                else:
                    _flat = {key_stats: val_stats}
                val_outputs_flat.update(_flat)

        # -- Be sure that values are scalar and not tensor
        remove_keys = []
        for key, val in val_outputs_flat.items():
            if val.dim() == 0:
                val_outputs_flat[key] = self._get_scalar(val)
            else:
                remove_keys.append(key)

        for key in remove_keys:
            del val_outputs_flat[key]

        return val_outputs_flat
Example #7
    def log_each_epoch(self):
        # -- TODO: add anything else to be logged here

        # -- Log learning rates
        optimizers = self._model.optimizer
        if is_optimizer(optimizers):
            current_lr = optimizers.param_groups[0]['lr']
            self._logger.log_metric('optim/lr', current_lr, self._epoch)

        elif is_dict(optimizers):
            for key, opt in optimizers.items():
                current_lr = opt.param_groups[0]['lr']
                self._logger.log_metric(
                    'optim/lr_{}'.format(key), current_lr, self._epoch)
        else:
            raise ValueError(
                "optimizer should be a dict or a torch.optim.Optimizer object")
Example #8
    def load_checkpoint(self, filename=None, is_best=False):
        """
            This function actually restores a checkpoint with:

            - model state_dict
            - optimizers state_dict
            - global_step
            - epoch
            - best_epoch
            - beaten_epochs
            - best_epoch_score
            - best_stats
        """

        path = self.checkpoints_dir
        ckp_format = self.checkpoints_format
        if is_best:
            path = self.best_checkpoints_dir
        if filename is None:
            filename = os.path.join(path, ckp_format.format(self._epoch))

        elif isinstance(filename, int):
            filename = os.path.join(path, ckp_format.format(filename))

        assert isinstance(filename, str), \
            'filename should be the epoch (int) or the checkpoint path (str)'

        checkpoint = torch.load(filename)
        self._global_step = checkpoint.get('global_step', checkpoint.get('seen', 0))
        self._epoch = checkpoint['epoch']

        self._model.best_epoch = checkpoint.get('best_epoch', -1)
        self._model.beaten_epochs = checkpoint.get('beaten_epochs', 0)
        self._model.best_epoch_score = checkpoint.get('best_epoch_score', 0)
        self._model.best_stats = checkpoint.get('best_stats', [])

        self._model.load_state_dict(checkpoint['model_state_dict'])

        if is_dict(self._model.optimizer):
            for key in self._model.optimizer.keys():
                self._model.optimizer[key].load_state_dict(
                    checkpoint['optimizer_state_dict'][key])
        else:
            self._model.optimizer.load_state_dict(
                checkpoint['optimizer_state_dict'])
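
`load_checkpoint` accepts either an epoch number, which is formatted through `checkpoints_format` and resolved inside `checkpoints_dir` (or `best_checkpoints_dir` when `is_best=True`), or a full checkpoint path as a string. A hedged usage sketch, with a hypothetical `trainer` instance and a placeholder path:

    # Restore epoch 42 from checkpoints_dir
    trainer.load_checkpoint(42)

    # Restore from an explicit checkpoint path
    trainer.load_checkpoint('experiments/run1/checkpoints/42.pt')
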
Example #9
    def _init_optimizer_accum(self):
        """
        Initialize a dictionary of counters to keep track, for each optimizer,
        how many times the `optimizer.step()` function has been called. The
        optimizer itself is used as key of the dictionary.
        """
        # If not already, convert into a dict for convenience
        optim_dict = (self.optimizer if is_dict(self.optimizer) else {
            'optimizer': self.optimizer
        })
        # -- Initialize the dict of steps per optimizer
        self._optimizers_accum = {optim: 0 for optim in optim_dict.values()}

        # -- Decorate optimizer's step() and zero_grad() with counters
        for optim in optim_dict.values():
            optim.step = call_counter(optim.step)
            optim.zero_grad = call_counter(optim.zero_grad)

        # -- Decorate model zero_grad
        self.zero_grad = call_counter(self.zero_grad)
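
`call_counter` itself is not shown in these snippets; conceptually it wraps a bound method, counts invocations, and exposes the count as a `.calls` attribute, which is exactly what `_check_optimizer_accum` reads. A minimal sketch under that assumption:

    import functools

    def call_counter(fn):
        # Wrap fn and count how many times it gets called
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            wrapper.calls += 1
            return fn(*args, **kwargs)
        wrapper.calls = 0
        return wrapper
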
Example #10
    def restore_exp(self):
        # TODO: fix this
        checkpoint = self.reload_checkpoint(
            self.args.restore,
            os.path.join(self._restore_path, 'checkpoints'),
            self.args)

        self._global_step = checkpoint.get('global_step', checkpoint.get('seen', 0))
        self._epoch = checkpoint['epoch']

        self._model.best_epoch = checkpoint.get('best_epoch', -1)
        self._model.beaten_epochs = checkpoint.get('beaten_epochs', 0)
        self._model.best_epoch_score = checkpoint.get('best_epoch_score', 0)
        self._model.best_stats = checkpoint.get('best_stats', [])

        self._model.load_state_dict(checkpoint['model_state_dict'])

        if is_dict(self._model.optimizer):
            for key in self._model.optimizer.keys():
                self._model.optimizer[key].load_state_dict(
                    checkpoint['optimizer_state_dict'][key])
        else:
            self._model.optimizer.load_state_dict(
                checkpoint['optimizer_state_dict'])
Example #11
File: base.py  Project: marcociccone/yapt
    def override_with_custom_args(self, extra_args=None):
        """
        Specific arguments can be overridden by:

        - a `custom_config` file, passed via the cli.
        - an `extra_args` dict passed to the constructor of the Trainer object.
        - the command line, using dotted notation.

        The arguments must already be defined in the default_config;
        otherwise an exception is raised, since you would be trying to
        modify an argument that does not exist.
        """

        # -- From command line
        self._cli_args = OmegaConf.from_cli()

        # -- From experiment custom config file (passed from cli)
        self._custom_config_args = OmegaConf.create(dict())
        if self._cli_args.custom_config is not None:
            self._custom_config_args = OmegaConf.load(
                self._cli_args.custom_config)

        # -- Extra config from Tune or any script
        if is_dict(extra_args):
            matching = [s for s in extra_args.keys() if "." in s]
            if len(matching) > 0:
                self.console_log.warning(
                    "It seems you are using dotted notation in a "
                    "dictionary! Please use a list instead to modify "
                    "the correct values: %s", matching)
            self._extra_args = OmegaConf.create(extra_args)

        elif is_list(extra_args):
            self._extra_args = OmegaConf.from_dotlist(extra_args)

        elif extra_args is None:
            self._extra_args = OmegaConf.create(dict())

        else:
            raise ValueError("extra_args should be a list of \
                             dotted strings or a dict")

        # -- Save optimizer args for later
        dict_opt_custom = deepcopy(self._custom_config_args.optimizer)
        dict_opt_extra = deepcopy(self._extra_args.optimizer)
        if dict_opt_custom is not None:
            del self._custom_config_args['optimizer']
        if dict_opt_extra is not None:
            del self._extra_args['optimizer']

        # -- override custom args, ONLY IF THEY EXIST
        self._args = OmegaConf.merge(self._args, self._custom_config_args,
                                     self._extra_args)

        # !!NOTE!! Optimizer could drastically change
        OmegaConf.set_struct(self._args, False)
        if dict_opt_custom is not None:
            self._args = OmegaConf.merge(
                self._args, OmegaConf.create({'optimizer': dict_opt_custom}))

        if dict_opt_extra is not None:
            self._args = OmegaConf.merge(
                self._args, OmegaConf.create({'optimizer': dict_opt_extra}))
        OmegaConf.set_struct(self._args, True)

        # !!NOTE!! WORKAROUND because of Tune command line args
        OmegaConf.set_struct(self._args, False)
        self._args = OmegaConf.merge(self._args, self._cli_args)
        OmegaConf.set_struct(self._args, True)

        # -- Resolve interpolations to be sure all nodes are explicit
        self._args = OmegaConf.to_container(self._args, resolve=True)
        self._args = OmegaConf.create(self._args)
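
Per the warning above, dotted keys belong in a list (one "key=value" string per override), while a plain dict should be nested instead of using dots in its keys. A hedged sketch of the two accepted forms of `extra_args`, with illustrative argument names and an assumed Trainer constructor:

    # Dotted-list form: each entry overrides one leaf node
    trainer = Trainer(extra_args=['optimizer.lr=0.001',
                                  'accum_batches=4'])

    # Dict form: nest the structure instead of dotting the keys
    trainer = Trainer(extra_args={'optimizer': {'lr': 0.001},
                                  'accum_batches': 4})
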