Example #1
    def run(self, data_loaders, workflow, max_epochs=None, **kwargs):
        """Start running.

        Args:
            data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
                and validation.
            workflow (list[tuple]): A list of (phase, epochs) to specify the
                running order and epochs. E.g., [('train', 2), ('val', 1)] means
                running 2 epochs for training and 1 epoch for validation,
                iteratively.
            max_epochs (int | None): The max epochs that training lasts,
                deprecated now. Default: None.
        """
        assert isinstance(data_loaders, list)
        assert mmcv.is_list_of(workflow, tuple)
        assert len(data_loaders) == len(workflow)
        if max_epochs is not None:
            warnings.warn(
                "setting max_epochs in run is deprecated, "
                "please set max_epochs in runner_config",
                DeprecationWarning,
            )
            self._max_epochs = max_epochs

        assert self._max_epochs is not None, (
            "max_epochs must be specified during instantiation")

        for i, flow in enumerate(workflow):
            mode, epochs = flow
            if mode == "train":
                self._max_iters = self._max_epochs * len(data_loaders[i])
                break

        work_dir = self.work_dir if self.work_dir is not None else "NONE"
        self.logger.info("Start running, host: %s, work_dir: %s",
                         get_host_info(), work_dir)
        self.logger.info("workflow: %s, max: %d epochs", workflow,
                         self._max_epochs)
        self.call_hook("before_run")

        while self.epoch < self._max_epochs:
            for i, flow in enumerate(workflow):
                mode, epochs = flow
                if isinstance(mode, str):  # self.train()
                    if not hasattr(self, mode):
                        raise ValueError(
                            f'runner has no method named "{mode}" to run an '
                            "epoch")
                    epoch_runner = getattr(self, mode)
                else:
                    raise TypeError(
                        "mode in workflow must be a str, but got {}".format(
                            type(mode)))

                for _ in range(epochs):
                    if mode == "train" and self.epoch >= self._max_epochs:
                        break
                    epoch_runner(data_loaders[i], **kwargs)

        time.sleep(1)  # wait for some hooks like loggers to finish
        self.call_hook("after_run")
Example #2
    def run_step_alter(self, data_loaders, workflow, max_epochs,
                       arch_update_epoch, **kwargs):
        """Start running. Arch and weight optimization alternates by step.

        Args:
            data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
                and validation.
            workflow (list[tuple]): A list of (phase, epochs) to specify the
                running order and epochs. E.g., [('train', 2), ('val', 1)] means
                running 2 epochs for training and 1 epoch for validation,
                iteratively.
            max_epochs (int): Total training epochs.
            arch_update_epoch (int): Epoch from which architecture
                optimization is enabled.
        """
        assert isinstance(data_loaders, list)
        assert mmcv.is_list_of(workflow, tuple)

        self._max_epochs = max_epochs
        work_dir = self.work_dir if self.work_dir is not None else 'NONE'
        self.logger.info('Start running, host: %s, work_dir: %s',
                         get_host_info(), work_dir)
        self.logger.info('workflow: %s, max: %d epochs', workflow, max_epochs)
        self.call_hook('before_run', 'train')

        while self.epoch < max_epochs:
            self.search_stage = 0 if self.epoch < arch_update_epoch else 1
            self.train_step_alter(data_loaders)

        time.sleep(1)  # wait for some hooks like loggers to finish
        self.call_hook('after_run', 'train')
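A minimal sketch of the two-stage schedule implied by run_step_alter: weight-only training (stage 0) before arch_update_epoch, then step-wise alternation of architecture and weight updates (stage 1). `search_stage_schedule` is an illustrative helper, not part of the runner:

def search_stage_schedule(max_epochs, arch_update_epoch):
    # stage per epoch: 0 = weights only, 1 = alternate arch/weight steps
    return [0 if epoch < arch_update_epoch else 1 for epoch in range(max_epochs)]

print(search_stage_schedule(max_epochs=8, arch_update_epoch=5))
# -> [0, 0, 0, 0, 0, 1, 1, 1]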
Example #3
    def run(self, data_loaders, workflow, max_epochs=None, **kwargs):
        """Start running.

        Args:
            data_loaders (list[:obj:`DataLoader`]): Dataloaders for training.
                `data_loaders[0]` is the main data_loader, which contains
                target datasets and determines the epoch length.
                `data_loaders[1:]` are auxiliary data loaders, which contain
                auxiliary web datasets.
            workflow (list[tuple]): A list of (phase, epochs) to specify the
                running order and epochs. E.g., [('train', 2)] means running 2
                epochs for training iteratively. Note that val epoch is not
                supported for this runner for simplicity.
            max_epochs (int | None): The max epochs that training lasts,
                deprecated now. Default: None.
        """
        assert isinstance(data_loaders, list)
        assert mmcv.is_list_of(workflow, tuple)
        assert len(workflow) == 1 and workflow[0][0] == 'train'
        if max_epochs is not None:
            warnings.warn(
                'setting max_epochs in run is deprecated, '
                'please set max_epochs in runner_config', DeprecationWarning)
            self._max_epochs = max_epochs

        assert self._max_epochs is not None, (
            'max_epochs must be specified during instantiation')

        mode, epochs = workflow[0]
        self._max_iters = self._max_epochs * len(data_loaders[0])

        work_dir = self.work_dir if self.work_dir is not None else 'NONE'
        self.logger.info('Start running, host: %s, work_dir: %s',
                         get_host_info(), work_dir)
        self.logger.info('workflow: %s, max: %d epochs', workflow,
                         self._max_epochs)
        self.call_hook('before_run')

        while self.epoch < self._max_epochs:
            if isinstance(mode, str):  # self.train()
                if not hasattr(self, mode):
                    raise ValueError(
                        f'runner has no method named "{mode}" to run an '
                        'epoch')
                epoch_runner = getattr(self, mode)
            else:
                raise TypeError(
                    f'mode in workflow must be a str, but got {mode}')

            for _ in range(epochs):
                if mode == 'train' and self.epoch >= self._max_epochs:
                    break
                epoch_runner(data_loaders, **kwargs)

        time.sleep(1)  # wait for some hooks like loggers to finish
        self.call_hook('after_run')
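To illustrate the docstring's point that only the main loader determines the epoch length, a small standalone sketch of the `self._max_iters` computation above (`FakeLoader` and the batch counts are made up):

class FakeLoader:
    def __init__(self, num_batches):
        self.num_batches = num_batches

    def __len__(self):
        return self.num_batches

def max_iters(data_loaders, max_epochs):
    # mirrors `self._max_iters = self._max_epochs * len(data_loaders[0])`:
    # auxiliary (web) loaders in data_loaders[1:] do not affect the budget
    return max_epochs * len(data_loaders[0])

print(max_iters([FakeLoader(500), FakeLoader(2000)], max_epochs=12))  # 6000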
Example #4
    def run_epoch_alter(self, data_loaders, workflow, max_epochs,
                        arch_update_epoch, **kwargs):
        """Start running. Arch and weight optimization alternates by epoch.

        Args:
            data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
                and validation.
            workflow (list[tuple]): A list of (phase, epochs) to specify the
                running order and epochs. E.g., [('train', 2), ('val', 1)] means
                running 2 epochs for training and 1 epoch for validation,
                iteratively.
            max_epochs (int): Total training epochs.
            arch_update_epoch (int): Epoch from which architecture
                optimization is enabled.
        """
        assert isinstance(data_loaders, list)
        assert mmcv.is_list_of(workflow, tuple)

        self._max_epochs = max_epochs
        work_dir = self.work_dir if self.work_dir is not None else 'NONE'
        self.logger.info('Start running, host: %s, work_dir: %s',
                         get_host_info(), work_dir)
        self.logger.info('workflow: %s, max: %d epochs', workflow, max_epochs)
        self.call_hook('before_run', 'train')

        while self.epoch < max_epochs:
            for i, flow in enumerate(workflow):
                mode, epochs = flow
                if isinstance(mode, str):  # self.train()
                    assert mode in ['train', 'arch', 'val']
                    if mode in ['train', 'arch']:
                        epoch_runner = getattr(self, 'train_epoch_alter')
                    else:
                        epoch_runner = getattr(self, 'val')
                elif callable(mode):  # custom train()
                    epoch_runner = mode
                else:
                    raise TypeError('mode in workflow must be a str or '
                                    'callable function, not {}'.format(
                                        type(mode)))
                for _ in range(epochs):
                    if mode == 'train' and self.epoch >= max_epochs:
                        return
                    elif (mode in ['arch', 'val']
                          and self.epoch < arch_update_epoch):
                        break
                    data_loader = (data_loaders[0]
                                   if mode == 'train' else data_loaders[1])
                    epoch_runner(data_loader, mode=mode)

        time.sleep(1)  # wait for some hooks like loggers to finish
        self.call_hook('after_run', 'train')
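The scheduling rules of run_epoch_alter can be hard to see through the nested loops. The standalone sketch below (illustrative only; `expand_epoch_alter` is a hypothetical helper) records which mode runs against which loader, assuming, as in the runner, that only 'train' epochs advance the epoch counter:

def expand_epoch_alter(workflow, max_epochs, arch_update_epoch):
    schedule, epoch = [], 0
    while epoch < max_epochs:
        for mode, epochs in workflow:
            for _ in range(epochs):
                if mode == 'train' and epoch >= max_epochs:
                    return schedule
                if mode in ('arch', 'val') and epoch < arch_update_epoch:
                    break  # arch/val phases start only after arch_update_epoch
                loader_idx = 0 if mode == 'train' else 1
                schedule.append((mode, loader_idx))
                if mode == 'train':
                    epoch += 1  # self.epoch grows inside the epoch runner
    return schedule

print(expand_epoch_alter([('train', 1), ('arch', 1)],
                         max_epochs=4, arch_update_epoch=2))
# -> [('train', 0), ('train', 0), ('arch', 1), ('train', 0), ('arch', 1),
#     ('train', 0), ('arch', 1)]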
Example #5
    def run(self, data_loaders, workflow, max_epochs, **kwargs):
        """Start running.

        Args:
            data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
                and validation.
            workflow (list[tuple]): A list of (phase, epochs) to specify the
                running order and epochs. E.g., [('train', 2), ('val', 1)] means
                running 2 epochs for training and 1 epoch for validation,
                iteratively.
            max_epochs (int): Total training epochs.
        """
        assert isinstance(data_loaders, list)
        assert mmcv.is_list_of(workflow, tuple)
        assert len(data_loaders) == len(workflow)

        self._max_epochs = max_epochs
        work_dir = self.work_dir if self.work_dir is not None else 'NONE'
        self.logger.info('Start running, host: %s, work_dir: %s',
                         get_host_info(), work_dir)
        self.logger.info('workflow: %s, max: %d epochs', workflow, max_epochs)
        self.call_hook('before_run')

        while self.epoch < max_epochs:
            for i, flow in enumerate(workflow):
                mode, epochs = flow
                if isinstance(mode, str):  # self.train()
                    if not hasattr(self, mode):
                        raise ValueError(
                            'runner has no method named "{}" to run an epoch'.
                            format(mode))
                    epoch_runner = getattr(self, mode)
                elif callable(mode):  # custom train()
                    epoch_runner = mode
                else:
                    raise TypeError('mode in workflow must be a str or '
                                    'callable function, not {}'.format(
                                        type(mode)))
                for _ in range(epochs):
                    if mode == 'train' and self.epoch >= max_epochs:
                        return
                    epoch_runner(data_loaders[i], **kwargs)

        time.sleep(1)  # wait for some hooks like loggers to finish
        self.call_hook('after_run')
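Unlike the first example, this run() also accepts a callable as a workflow mode. A minimal standalone sketch of that resolution logic (`resolve_epoch_runner` and `_DummyRunner` are hypothetical names, not taken from the runner above):

def resolve_epoch_runner(runner, mode):
    if isinstance(mode, str):       # e.g. 'train' or 'val' -> bound runner method
        if not hasattr(runner, mode):
            raise ValueError(
                'runner has no method named "{}" to run an epoch'.format(mode))
        return getattr(runner, mode)
    elif callable(mode):            # e.g. a custom epoch function
        return mode
    raise TypeError('mode in workflow must be a str or '
                    'callable function, not {}'.format(type(mode)))

class _DummyRunner:
    def train(self, data_loader, **kwargs):
        pass

print(resolve_epoch_runner(_DummyRunner(), 'train'))          # bound method train
print(resolve_epoch_runner(_DummyRunner(), lambda dl: None))  # the callable itself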
Example #6
    def run(self, data_loaders, workflow, max_epochs, **kwargs):
        """Start running.

        Args:
            data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
                and validation.
            workflow (list[tuple]): A list of (phase, epochs) to specify the
                running order and epochs. E.g., [('train', 2), ('val', 1)] means
                running 2 epochs for training and 1 epoch for validation,
                iteratively.
            max_epochs (int): Total training epochs.
        """
        assert isinstance(data_loaders, list)
        assert mmcv.is_list_of(workflow, tuple)
        assert len(data_loaders) == len(workflow)

        try:
            if self.resume_from:
                self.resume(self.resume_from)
            elif self.load_from:
                self.load_checkpoint(self.load_from)
            elif self.auto_resume_bool:
                self.auto_resume()
            resume_optimizer = True
        except ValueError as e:
            # cannot load optimizer state_dict due to param_group mismatch
            self.logger.warning(str(e))
            raise

        self._max_epochs = max_epochs
        work_dir = self.work_dir if self.work_dir is not None else 'NONE'
        self.logger.info('Start running, host: %s, work_dir: %s',
                         get_host_info(), work_dir)
        self.logger.info('workflow: %s, max: %d epochs', workflow, max_epochs)
        self.call_hook('before_run')

        while self.epoch < max_epochs:
            for i, flow in enumerate(workflow):
                if i < self.current_stage:
                    continue
                mode, epochs = flow
                if isinstance(mode, str):  # self.train() or custom functions
                    if not hasattr(self, mode):
                        raise ValueError(
                            'runner has no method named "{}" to run an epoch'.
                            format(mode))
                    epoch_runner = getattr(self, mode)
                elif callable(mode):  # custom train()
                    epoch_runner = mode
                else:
                    raise TypeError('mode in workflow must be a str or '
                                    'callable function, not {}'.format(
                                        type(mode)))

                self._stage_max_epochs = epochs
                for epoch in range(epochs):
                    if 'train' in mode and (self.epoch >= max_epochs
                                            or self.stage_epoch >= epochs):
                        break
                    if 'stage' in mode:
                        # TODO: fix multiple data_loaders
                        epoch_runner(self,
                                     data_loader=data_loaders[0],
                                     stage_epoch=epoch,
                                     resume_optimizer=resume_optimizer,
                                     **kwargs)
                    else:
                        epoch_runner(data_loaders[i], **kwargs)
                resume_optimizer = False
                self._stage += 1
                self._stage_epoch = 0
                self._stage_iter = 0

        time.sleep(1)  # wait for some hooks like loggers to finish
        self.call_hook('after_run')
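The distinctive part of this variant is the checkpoint handling at the top of run(). A hedged standalone sketch of that priority order (resume_from, then load_from, then auto-resume) follows; `resolve_checkpoint` and the path are illustrative, and the comments describe typical mmcv-style runner behaviour rather than this exact class:

def resolve_checkpoint(resume_from=None, load_from=None, auto_resume=False):
    if resume_from:
        # typically restores weights, optimizer state and epoch counter
        return ('resume', resume_from)
    elif load_from:
        # typically restores model weights only
        return ('load', load_from)
    elif auto_resume:
        # typically picks up the latest checkpoint found in work_dir
        return ('auto_resume', None)
    return ('from_scratch', None)

print(resolve_checkpoint(load_from='ckpt/latest.pth'))
# -> ('load', 'ckpt/latest.pth')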