def run(self, data_loaders, workflow, max_epochs=None, **kwargs):
    """Drive the workflow phase by phase until the epoch budget is spent.

    Args:
        data_loaders (list[:obj:`DataLoader`]): One loader per workflow
            entry, in the same order as ``workflow``.
        workflow (list[tuple]): ``(phase, epochs)`` pairs giving the order
            of phases and how many epochs each one runs per pass. ``phase``
            must be the name of a method on this runner. For example,
            ``[('train', 2), ('val', 1)]`` alternates two training epochs
            with one validation epoch.
        max_epochs (int | None): Deprecated. When given, it overrides the
            value configured on the runner and a ``DeprecationWarning`` is
            emitted. Default: None.
        **kwargs: Forwarded unchanged to every epoch-runner call.

    Raises:
        ValueError: If a phase name is not an attribute of the runner.
        TypeError: If a phase is not a string.
    """
    assert isinstance(data_loaders, list)
    assert mmcv.is_list_of(workflow, tuple)
    assert len(data_loaders) == len(workflow)

    if max_epochs is not None:
        warnings.warn(
            "setting max_epochs in run is deprecated, please set max_epochs "
            "in runner_config",
            DeprecationWarning,
        )
        self._max_epochs = max_epochs

    assert self._max_epochs is not None, (
        "max_epochs must be specified during instantiation")

    # The iteration budget is derived from the first 'train' entry only.
    for idx, stage in enumerate(workflow):
        phase, _ = stage
        if phase == "train":
            self._max_iters = self._max_epochs * len(data_loaders[idx])
            break

    work_dir = self.work_dir
    if work_dir is None:
        work_dir = "NONE"
    self.logger.info("Start running, host: %s, work_dir: %s",
                     get_host_info(), work_dir)
    self.logger.info("workflow: %s, max: %d epochs", workflow,
                     self._max_epochs)
    self.call_hook("before_run")

    while self.epoch < self._max_epochs:
        for idx, stage in enumerate(workflow):
            phase, num_epochs = stage

            # Resolve the phase name to the bound method that runs one epoch.
            if not isinstance(phase, str):
                raise TypeError(
                    f"mode in workflow must be a str, but got {type(phase)}")
            if not hasattr(self, phase):
                raise ValueError(
                    f'runner has no method named "{phase}" to run an epoch')
            run_one_epoch = getattr(self, phase)

            is_train = phase == "train"
            for _ in range(num_epochs):
                # Stop training as soon as the epoch budget is exhausted;
                # later phases in this pass still get their turn.
                if is_train and self.epoch >= self._max_epochs:
                    break
                run_one_epoch(data_loaders[idx], **kwargs)

    time.sleep(1)  # wait for some hooks like loggers to finish
    self.call_hook("after_run")
def run_step_alter(self, data_loaders, workflow, max_epochs,
                   arch_update_epoch, **kwargs):
    """Start running. Arch and weight optimization alternates by step.

    Args:
        data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
            and validation. The whole list is handed to
            ``train_step_alter`` on every epoch.
        workflow (list[tuple]): A list of (phase, epochs) pairs. It is
            validated and logged but does not otherwise steer this loop.
        max_epochs (int): Total training epochs.
        arch_update_epoch (int | None): Epoch threshold for the search
            stage: ``self.search_stage`` is 0 while
            ``self.epoch < arch_update_epoch`` and 1 afterwards. When
            ``None`` is passed, ``self.cfg.arch_update_epoch`` is used.
        **kwargs: Accepted for interface compatibility; not forwarded.
    """
    assert isinstance(data_loaders, list)
    assert mmcv.is_list_of(workflow, tuple)

    self._max_epochs = max_epochs

    # Honour the explicit argument instead of silently ignoring it; the
    # config value remains available as a fallback for callers passing None.
    if arch_update_epoch is None:
        arch_update_epoch = self.cfg.arch_update_epoch

    work_dir = self.work_dir if self.work_dir is not None else 'NONE'
    self.logger.info('Start running, host: %s, work_dir: %s',
                     get_host_info(), work_dir)
    self.logger.info('workflow: %s, max: %d epochs', workflow, max_epochs)
    self.call_hook('before_run', 'train')

    while self.epoch < max_epochs:
        # NOTE(review): presumably stage 0 trains weights only and stage 1
        # enables architecture updates -- confirm in train_step_alter.
        self.search_stage = 0 if self.epoch < arch_update_epoch else 1
        self.train_step_alter(data_loaders)

    time.sleep(1)  # wait for some hooks like loggers to finish
    self.call_hook('after_run', 'train')
def run(self, data_loaders, workflow, max_epochs=None, **kwargs):
    """Run a train-only workflow over a main loader plus auxiliary loaders.

    Args:
        data_loaders (list[:obj:`DataLoader`]): Training loaders.
            ``data_loaders[0]`` is the main loader (target datasets) and
            fixes the epoch length; ``data_loaders[1:]`` are auxiliary
            loaders for the auxiliary web datasets. The full list is passed
            to the epoch runner.
        workflow (list[tuple]): Exactly one ``('train', epochs)`` entry,
            e.g. ``[('train', 2)]``. Validation epochs are not supported by
            this runner.
        max_epochs (int | None): Deprecated. When given, it overrides the
            value configured on the runner and a ``DeprecationWarning`` is
            emitted. Default: None.
        **kwargs: Forwarded unchanged to every epoch-runner call.

    Raises:
        ValueError: If the phase name is not an attribute of the runner.
        TypeError: If the phase is not a string.
    """
    assert isinstance(data_loaders, list)
    assert mmcv.is_list_of(workflow, tuple)
    assert len(workflow) == 1
    assert workflow[0][0] == 'train'

    if max_epochs is not None:
        warnings.warn(
            'setting max_epochs in run is deprecated, please set max_epochs '
            'in runner_config',
            DeprecationWarning)
        self._max_epochs = max_epochs

    assert self._max_epochs is not None, (
        'max_epochs must be specified during instantiation')

    phase, num_epochs = workflow[0]
    # Only the main loader contributes to the iteration budget.
    self._max_iters = self._max_epochs * len(data_loaders[0])

    work_dir = self.work_dir
    if work_dir is None:
        work_dir = 'NONE'
    self.logger.info('Start running, host: %s, work_dir: %s',
                     get_host_info(), work_dir)
    self.logger.info('workflow: %s, max: %d epochs', workflow,
                     self._max_epochs)
    self.call_hook('before_run')

    while self.epoch < self._max_epochs:
        # Resolve the phase name to the bound method that runs one epoch.
        if not isinstance(phase, str):
            raise TypeError(
                f'mode in workflow must be a str, but got {phase}')
        if not hasattr(self, phase):
            raise ValueError(
                f'runner has no method named "{phase}" to run an epoch')
        run_one_epoch = getattr(self, phase)

        is_train = phase == 'train'
        for _ in range(num_epochs):
            if is_train and self.epoch >= self._max_epochs:
                break
            run_one_epoch(data_loaders, **kwargs)

    time.sleep(1)  # wait for some hooks like loggers to finish
    self.call_hook('after_run')
def run_epoch_alter(self, data_loaders, workflow, max_epochs,
                    arch_update_epoch, **kwargs):
    """Start running. Arch and weight optimization alternates by epoch.

    Args:
        data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
            and validation. ``data_loaders[0]`` is used for the 'train'
            phase and ``data_loaders[1]`` for every other phase.
        workflow (list[tuple]): A list of (phase, epochs) to specify the
            running order and epochs. A string phase must be one of
            'train', 'arch' or 'val'; a callable phase is invoked directly.
        max_epochs (int): Total training epochs.
        arch_update_epoch (int): 'arch' and 'val' phases are skipped while
            ``self.epoch`` is below this value.
        **kwargs: Accepted for interface compatibility; not forwarded to
            the epoch runners.

    Raises:
        TypeError: If a phase is neither a string nor callable.
    """
    assert isinstance(data_loaders, list)
    assert mmcv.is_list_of(workflow, tuple)

    self._max_epochs = max_epochs
    work_dir = self.work_dir if self.work_dir is not None else 'NONE'
    self.logger.info('Start running, host: %s, work_dir: %s',
                     get_host_info(), work_dir)
    self.logger.info('workflow: %s, max: %d epochs', workflow, max_epochs)
    self.call_hook('before_run', 'train')

    reached_max = False
    while self.epoch < max_epochs:
        for mode, epochs in workflow:
            if isinstance(mode, str):  # self.train()
                assert mode in ['train', 'arch', 'val']
                # 'train' and 'arch' share one runner; the phase is told
                # apart through the ``mode`` keyword passed below.
                if mode in ['train', 'arch']:
                    epoch_runner = self.train_epoch_alter
                else:
                    epoch_runner = self.val
            elif callable(mode):  # custom train()
                epoch_runner = mode
            else:
                raise TypeError(
                    'mode in workflow must be a str or callable function, '
                    f'not {type(mode)}')

            for _ in range(epochs):
                if mode == 'train' and self.epoch >= max_epochs:
                    # Stop the whole workflow, but fall through to the
                    # shutdown code below instead of returning, so that
                    # the 'after_run' hooks are not skipped.
                    reached_max = True
                    break
                if mode in ['arch', 'val'] and self.epoch < arch_update_epoch:
                    break
                data_loader = (data_loaders[0]
                               if mode == 'train' else data_loaders[1])
                epoch_runner(data_loader, mode=mode)

            if reached_max:
                # No further phases run once the epoch budget is spent.
                break

    time.sleep(1)  # wait for some hooks like loggers to finish
    self.call_hook('after_run', 'train')
def run(self, data_loaders, workflow, max_epochs, **kwargs):
    """Start running.

    Args:
        data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
            and validation, one per workflow entry.
        workflow (list[tuple]): A list of (phase, epochs) to specify the
            running order and epochs. E.g, [('train', 2), ('val', 1)] means
            running 2 epochs for training and 1 epoch for validation,
            iteratively. A phase is either the name of a runner method or
            a callable.
        max_epochs (int): Total training epochs.
        **kwargs: Forwarded unchanged to every epoch-runner call.

    Raises:
        ValueError: If a string phase is not an attribute of the runner.
        TypeError: If a phase is neither a string nor callable.
    """
    assert isinstance(data_loaders, list)
    assert mmcv.is_list_of(workflow, tuple)
    assert len(data_loaders) == len(workflow)

    self._max_epochs = max_epochs
    work_dir = self.work_dir if self.work_dir is not None else 'NONE'
    self.logger.info('Start running, host: %s, work_dir: %s',
                     get_host_info(), work_dir)
    self.logger.info('workflow: %s, max: %d epochs', workflow, max_epochs)
    self.call_hook('before_run')

    reached_max = False
    while self.epoch < max_epochs:
        for i, flow in enumerate(workflow):
            mode, epochs = flow
            if isinstance(mode, str):  # self.train()
                if not hasattr(self, mode):
                    raise ValueError(
                        f'runner has no method named "{mode}" to run an '
                        'epoch')
                epoch_runner = getattr(self, mode)
            elif callable(mode):  # custom train()
                epoch_runner = mode
            else:
                raise TypeError(
                    'mode in workflow must be a str or callable function, '
                    f'not {type(mode)}')

            for _ in range(epochs):
                if mode == 'train' and self.epoch >= max_epochs:
                    # Stop the whole workflow, but fall through to the
                    # shutdown code below instead of returning, so that
                    # the 'after_run' hooks are not skipped.
                    reached_max = True
                    break
                epoch_runner(data_loaders[i], **kwargs)

            if reached_max:
                # No further phases run once the epoch budget is spent.
                break

    time.sleep(1)  # wait for some hooks like loggers to finish
    self.call_hook('after_run')
def run(self, data_loaders, workflow, max_epochs, **kwargs):
    """Start running, optionally restoring state from a checkpoint first.

    Args:
        data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
            and validation, one per workflow entry.
        workflow (list[tuple]): A list of (phase, epochs) to specify the
            running order and epochs. E.g, [('train', 2), ('val', 1)] means
            running 2 epochs for training and 1 epoch for validation,
            iteratively. A phase is either the name of a runner attribute
            or a callable. String phases containing 'stage' are called
            with stage-specific keyword arguments; string phases
            containing 'train' stop early once the epoch budget or the
            per-stage epoch budget is reached.
        max_epochs (int): Total training epochs.
        **kwargs: Forwarded unchanged to every epoch-runner call.

    Raises:
        ValueError: If a string phase is not an attribute of the runner,
            or re-raised from checkpoint restoring.
        TypeError: If a phase is neither a string nor callable.
    """
    assert isinstance(data_loaders, list)
    assert mmcv.is_list_of(workflow, tuple)
    assert len(data_loaders) == len(workflow)

    try:
        if self.resume_from:
            self.resume(self.resume_from)
        elif self.load_from:
            self.load_checkpoint(self.load_from)
        elif self.auto_resume_bool:
            self.auto_resume()
        resume_optimizer = True
    except ValueError as e:
        # cannot load the optimizer state_dict because of mismatching
        # param_groups
        self.logger.warning(str(e))
        raise

    self._max_epochs = max_epochs
    work_dir = self.work_dir if self.work_dir is not None else 'NONE'
    self.logger.info('Start running, host: %s, work_dir: %s',
                     get_host_info(), work_dir)
    self.logger.info('workflow: %s, max: %d epochs', workflow, max_epochs)
    self.call_hook('before_run')

    # NOTE(review): if every workflow entry is already below
    # ``self.current_stage`` while ``self.epoch < max_epochs``, this loop
    # makes no progress -- confirm that stage epochs always add up to at
    # least ``max_epochs``.
    while self.epoch < max_epochs:
        for i, flow in enumerate(workflow):
            # Skip stages that were already completed (e.g. after resume).
            if i < self.current_stage:
                continue
            mode, epochs = flow
            if isinstance(mode, str):  # self.train() or custom functions
                if not hasattr(self, mode):
                    raise ValueError(
                        f'runner has no method named "{mode}" to run an '
                        'epoch')
                epoch_runner = getattr(self, mode)
                mode_name = mode
            elif callable(mode):  # custom train()
                epoch_runner = mode
                # A callable has no phase name; the substring checks below
                # must not be applied to the callable itself, which would
                # raise TypeError.
                mode_name = ''
            else:
                raise TypeError(
                    'mode in workflow must be a str or callable function, '
                    f'not {type(mode)}')

            self._stage_max_epochs = epochs
            for epoch in range(epochs):
                if 'train' in mode_name and (self.epoch >= max_epochs
                                             or self.stage_epoch >= epochs):
                    break
                if 'stage' in mode_name:
                    # TODO: fix multiple data_loaders
                    epoch_runner(self,
                                 data_loader=data_loaders[0],
                                 stage_epoch=epoch,
                                 resume_optimizer=resume_optimizer,
                                 **kwargs)
                else:
                    epoch_runner(data_loaders[i], **kwargs)
                # Only the first epoch after start-up may restore the
                # optimizer state.
                resume_optimizer = False

            # Advance to the next stage and reset its per-stage counters.
            self._stage += 1
            self._stage_epoch = 0
            self._stage_iter = 0

    time.sleep(1)  # wait for some hooks like loggers to finish
    self.call_hook('after_run')