def _run_stage(self, stage: str):
    self._prepare_for_stage(stage)

    loaders = self.experiment.get_loaders(stage)
    callbacks = self.experiment.get_callbacks(stage)

    # Split the stage callbacks into loggers and regular callbacks,
    # so loggers can be attached to the state and handled separately.
    loggers = utils.process_callbacks(
        OrderedDict([
            (k, v) for k, v in callbacks.items()
            if isinstance(v, LoggerCallback)
        ])
    )
    callbacks = utils.process_callbacks(
        OrderedDict([
            (k, v) for k, v in callbacks.items()
            if not isinstance(v, LoggerCallback)
        ])
    )

    self.state.loggers = loggers
    self.loggers = loggers
    self.callbacks = callbacks

    self._run_event("stage", moment="start")
    for epoch in range(self.state.num_epochs):
        self.state.stage_epoch = epoch

        self._run_event("epoch", moment="start")
        self._run_epoch(loaders)
        self._run_event("epoch", moment="end")

        # In check mode, stop after a few epochs to smoke-test the pipeline.
        if self._check_run and self.state.epoch >= 3:
            break
        # Early stopping: reset the flag and leave the stage.
        if self.state.early_stop:
            self.state.early_stop = False
            break

        self.state.epoch += 1
    self._run_event("stage", moment="end")
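# A minimal sketch of how the event sequence above ("stage"/"epoch", with
# moment="start"/"end") maps onto callback hooks. The `on_<event>_<moment>`
# naming that `_run_event` dispatches to is an assumption based on the usual
# Catalyst convention; a real callback would subclass `catalyst.dl.Callback`.
import time

class EpochTimerCallbackSketch:
    """Hypothetical callback: measures wall-clock time per epoch of a stage."""

    def on_stage_start(self, state):
        self.epoch_times = []

    def on_epoch_start(self, state):
        self._start = time.time()

    def on_epoch_end(self, state):
        self.epoch_times.append(time.time() - self._start)

    def on_stage_end(self, state):
        mean = sum(self.epoch_times) / max(len(self.epoch_times), 1)
        print(f"mean epoch time: {mean:.2f}s")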
def __init__(
    self,
    model: _Model,
    loaders: "OrderedDict[str, DataLoader]",
    callbacks: "Union[OrderedDict[str, Callback], List[Callback]]" = None,
    logdir: str = None,
    stage: str = "train",
    criterion: _Criterion = None,
    optimizer: _Optimizer = None,
    scheduler: _Scheduler = None,
    num_epochs: int = 1,
    valid_loader: str = "valid",
    main_metric: str = "loss",
    minimize_metric: bool = True,
    verbose: bool = False,
    state_kwargs: Dict = None,
    checkpoint_data: Dict = None,
    distributed_params: Dict = None,
    monitoring_params: Dict = None,
    initial_seed: int = 42,
):
    self._model = model
    self._loaders = loaders
    self._callbacks = process_callbacks(callbacks)
    self._criterion = criterion
    self._optimizer = optimizer
    self._scheduler = scheduler
    self._initial_seed = initial_seed

    self._logdir = logdir
    self._stage = stage
    self._num_epochs = num_epochs
    self._valid_loader = valid_loader
    self._main_metric = main_metric
    self._minimize_metric = minimize_metric
    self._verbose = verbose
    self._additional_state_kwargs = state_kwargs or {}
    self.checkpoint_data = checkpoint_data or {}
    self._distributed_params = distributed_params or {}
    self._monitoring_params = monitoring_params or {}
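# For reference, a minimal sketch of what `process_callbacks` is assumed to
# do: normalize a list of callbacks (or None) into an OrderedDict and pass an
# existing OrderedDict through unchanged. The real helper lives in
# `catalyst.utils`; this re-implementation is illustrative only, and the
# generated keys are an assumption.
from collections import OrderedDict
from typing import List, Union

def _process_callbacks_sketch(
    callbacks: Union[List, "OrderedDict", None],
) -> "OrderedDict":
    if callbacks is None:
        return OrderedDict()
    if isinstance(callbacks, OrderedDict):
        return callbacks
    # Key list entries by position so later code can address them by name.
    return OrderedDict(
        (f"callback_{i}", cb) for i, cb in enumerate(callbacks)
    )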
def train(
    self,
    model: Model,
    criterion: Criterion,
    optimizer: Optimizer,
    loaders: "OrderedDict[str, DataLoader]",
    logdir: str,
    callbacks: "Union[List[Callback], OrderedDict[str, Callback]]" = None,
    scheduler: Scheduler = None,
    resume: str = None,
    num_epochs: int = 1,
    valid_loader: str = "valid",
    main_metric: str = "loss",
    minimize_metric: bool = True,
    verbose: bool = False,
    state_kwargs: Dict = None,
    checkpoint_data: Dict = None,
    fp16: Union[Dict, bool] = None,
    monitoring_params: Dict = None,
    check: bool = False,
) -> None:
    """
    Starts the training process of the model.

    Args:
        model (Model): model to train
        criterion (Criterion): criterion function for training
        optimizer (Optimizer): optimizer for training
        loaders (dict): dictionary containing one or several
            ``torch.utils.data.DataLoader`` for training and validation
        logdir (str): path to output directory
        callbacks (List[catalyst.dl.Callback]): list of callbacks
        scheduler (Scheduler): scheduler for training
        resume (str): path to checkpoint for model
        num_epochs (int): number of training epochs
        valid_loader (str): loader name used to calculate the metrics and
            save the checkpoints. For example, you can pass ``train`` and
            then the metrics will be taken from the ``train`` loader.
        main_metric (str): the key to the name of the metric by which
            the checkpoints will be selected.
        minimize_metric (bool): flag to indicate whether
            the ``main_metric`` should be minimized.
        verbose (bool): if true, displays the status of the training
            to the console.
        state_kwargs (dict): additional state params to ``State``
        checkpoint_data (dict): additional data to save in checkpoint,
            for example: ``class_names``, ``date_of_training``, etc.
        fp16 (Union[Dict, bool]): if not None, sets training to FP16.
            See https://nvidia.github.io/apex/amp.html#properties
            If fp16=True, params by default will be ``{"opt_level": "O1"}``.
        monitoring_params (dict): if not None, creates monitoring through
            Alchemy or Weights&Biases. For example,
            ``{"token": "api_token", "experiment": "experiment_name"}``
        check (bool): if True, only checks that the pipeline is working
            (3 epochs only)
    """
    if len(loaders) == 1:
        valid_loader = list(loaders.keys())[0]
        logger.warning(
            "Attention, there is only one data loader - "
            + str(valid_loader)
        )

    if isinstance(fp16, bool) and fp16:
        fp16 = {"opt_level": "O1"}

    if model is not None:
        self.model = model

    if resume is not None:
        callbacks = utils.process_callbacks(callbacks)
        checkpoint_callback_flag = any(
            isinstance(x, CheckpointCallback) for x in callbacks.values()
        )
        if not checkpoint_callback_flag:
            callbacks["loader"] = CheckpointCallback(resume=resume)
        else:
            raise NotImplementedError("CheckpointCallback already exists")

    experiment = self._experiment_fn(
        stage="train",
        model=model,
        loaders=loaders,
        callbacks=callbacks,
        logdir=logdir,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        num_epochs=num_epochs,
        valid_loader=valid_loader,
        main_metric=main_metric,
        minimize_metric=minimize_metric,
        verbose=verbose,
        check_run=check,
        state_kwargs=state_kwargs,
        checkpoint_data=checkpoint_data,
        distributed_params=fp16,
        monitoring_params=monitoring_params,
    )
    self.run_experiment(experiment)
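# A minimal end-to-end usage sketch for `train`. `SupervisedRunner` and the
# toy model/data below are illustrative assumptions, not part of this module.
import torch
from collections import OrderedDict
from torch.utils.data import DataLoader, TensorDataset
from catalyst.dl import SupervisedRunner

model = torch.nn.Linear(10, 2)
dataset = TensorDataset(torch.randn(64, 10), torch.randint(0, 2, (64,)))
loaders = OrderedDict(
    train=DataLoader(dataset, batch_size=16),
    valid=DataLoader(dataset, batch_size=16),
)

runner = SupervisedRunner()
runner.train(
    model=model,
    criterion=torch.nn.CrossEntropyLoss(),
    optimizer=torch.optim.Adam(model.parameters(), lr=1e-3),
    loaders=loaders,
    logdir="./logs",
    num_epochs=2,
    valid_loader="valid",
    main_metric="loss",
    minimize_metric=True,
    verbose=True,
    check=True,  # smoke-test mode: runs only a few epochs
)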
def __init__(
    self,
    model: Model,
    loaders: "OrderedDict[str, DataLoader]",
    callbacks: "Union[OrderedDict[str, Callback], List[Callback]]" = None,
    logdir: str = None,
    stage: str = "train",
    criterion: Criterion = None,
    optimizer: Optimizer = None,
    scheduler: Scheduler = None,
    num_epochs: int = 1,
    valid_loader: str = "valid",
    main_metric: str = "loss",
    minimize_metric: bool = True,
    verbose: bool = False,
    check_run: bool = False,
    state_kwargs: Dict = None,
    checkpoint_data: Dict = None,
    distributed_params: Dict = None,
    monitoring_params: Dict = None,
    initial_seed: int = 42,
):
    """
    Args:
        model (Model): model
        loaders (dict): dictionary containing one or several
            ``torch.utils.data.DataLoader`` for training and validation
        callbacks (List[catalyst.dl.Callback]): list of callbacks
        logdir (str): path to output directory
        stage (str): current stage
        criterion (Criterion): criterion function
        optimizer (Optimizer): optimizer
        scheduler (Scheduler): scheduler
        num_epochs (int): number of experiment's epochs
        valid_loader (str): loader name used to calculate the metrics and
            save the checkpoints. For example, you can pass ``train`` and
            then the metrics will be taken from the ``train`` loader.
        main_metric (str): the key to the name of the metric by which
            the checkpoints will be selected.
        minimize_metric (bool): flag to indicate whether
            the ``main_metric`` should be minimized.
        verbose (bool): if true, displays the status of the training
            to the console.
        check_run (bool): if True, only checks that the pipeline is working
            (a few epochs only)
        state_kwargs (dict): additional state params to ``State``
        checkpoint_data (dict): additional data to save in checkpoint,
            for example: ``class_names``, ``date_of_training``, etc.
        distributed_params (dict): dictionary with the parameters
            for distributed and FP16 method
        monitoring_params (dict): dict with the parameters
            for monitoring services
        initial_seed (int): experiment's initial seed value
    """
    self._model = model
    self._loaders = loaders
    self._callbacks = utils.process_callbacks(callbacks)
    self._criterion = criterion
    self._optimizer = optimizer
    self._scheduler = scheduler
    self._initial_seed = initial_seed

    self._logdir = logdir
    self._stage = stage
    self._num_epochs = num_epochs
    self._valid_loader = valid_loader
    self._main_metric = main_metric
    self._minimize_metric = minimize_metric
    self._verbose = verbose
    self._check_run = check_run
    self._additional_state_kwargs = state_kwargs or {}
    self._checkpoint_data = checkpoint_data or {}
    self._distributed_params = distributed_params or {}
    self._monitoring_params = monitoring_params or {}
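# The runner's `train` above simply builds an experiment via `_experiment_fn`
# and hands it to `run_experiment`; the same can be done by hand. The class
# name `BaseExperiment` is an assumption for whatever class owns this
# `__init__` (reusing the toy `model`, `loaders`, and `runner` from the
# earlier sketch).
experiment = BaseExperiment(
    model=model,
    loaders=loaders,
    logdir="./logs",
    stage="train",
    criterion=torch.nn.CrossEntropyLoss(),
    optimizer=torch.optim.Adam(model.parameters(), lr=1e-3),
    num_epochs=2,
    check_run=True,  # smoke-test mode, mirrors `check=True` in `train`
)
runner.run_experiment(experiment)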