Example #1
File: runner.py Project: unhoang/catalyst
    def _run_stage(self, stage: str):
        self._prepare_for_stage(stage)
        loaders = self.experiment.get_loaders(stage)
        callbacks = self.experiment.get_callbacks(stage)
        # Partition the stage's callbacks by type: loggers are tracked
        # separately from the rest of the callback pipeline.
        loggers = utils.process_callbacks(
            OrderedDict([(k, v) for k, v in callbacks.items()
                         if isinstance(v, LoggerCallback)]))
        callbacks = utils.process_callbacks(
            OrderedDict([(k, v) for k, v in callbacks.items()
                         if not isinstance(v, LoggerCallback)]))
        self.state.loggers = loggers
        self.loggers = loggers
        self.callbacks = callbacks

        self._run_event("stage", moment="start")
        for epoch in range(self.state.num_epochs):
            self.state.stage_epoch = epoch

            self._run_event("epoch", moment="start")
            self._run_epoch(loaders)
            self._run_event("epoch", moment="end")

            # In check mode, stop after three epochs to verify the pipeline.
            if self._check_run and self.state.epoch >= 3:
                break
            # Early stopping is a one-shot flag: reset it, then leave the loop.
            if self.state.early_stop:
                self.state.early_stop = False
                break

            self.state.epoch += 1
        self._run_event("stage", moment="end")
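
The pattern worth noting above is the type-based partition of one ordered callback mapping into loggers and everything else. A minimal, self-contained sketch of that partition (the callback classes here are stand-ins, not catalyst's real ones):

from collections import OrderedDict

class Callback:
    pass

class LoggerCallback(Callback):
    pass

class MetricCallback(Callback):
    pass

callbacks = OrderedDict(
    [("metrics", MetricCallback()), ("console", LoggerCallback())]
)
# Split the mapping by type, preserving insertion order in each group.
loggers = OrderedDict(
    (k, v) for k, v in callbacks.items() if isinstance(v, LoggerCallback)
)
others = OrderedDict(
    (k, v) for k, v in callbacks.items() if not isinstance(v, LoggerCallback)
)
assert list(loggers) == ["console"]
assert list(others) == ["metrics"]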
Example #2
File: base.py Project: wmmxk/catalyst
    def __init__(
        self,
        model: _Model,
        loaders: "OrderedDict[str, DataLoader]",
        callbacks: "Union[OrderedDict[str, Callback], List[Callback]]" = None,
        logdir: str = None,
        stage: str = "train",
        criterion: _Criterion = None,
        optimizer: _Optimizer = None,
        scheduler: _Scheduler = None,
        num_epochs: int = 1,
        valid_loader: str = "valid",
        main_metric: str = "loss",
        minimize_metric: bool = True,
        verbose: bool = False,
        state_kwargs: Dict = None,
        checkpoint_data: Dict = None,
        distributed_params: Dict = None,
        monitoring_params: Dict = None,
        initial_seed: int = 42,
    ):
        self._model = model
        self._loaders = loaders
        # Normalize callbacks (list or dict) into an OrderedDict.
        self._callbacks = process_callbacks(callbacks)

        self._criterion = criterion
        self._optimizer = optimizer
        self._scheduler = scheduler

        self._initial_seed = initial_seed
        self._logdir = logdir
        self._stage = stage
        self._num_epochs = num_epochs
        self._valid_loader = valid_loader
        self._main_metric = main_metric
        self._minimize_metric = minimize_metric
        self._verbose = verbose
        self._additional_state_kwargs = state_kwargs or {}
        self.checkpoint_data = checkpoint_data or {}
        self._distributed_params = distributed_params or {}
        self._monitoring_params = monitoring_params or {}
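
Constructing the experiment then amounts to passing a model, a name-to-DataLoader mapping, and whichever optional pieces are needed; the rest falls back to the defaults above. A hedged usage sketch, assuming the class is importable as BaseExperiment from catalyst.dl (the exact import path varies across catalyst versions):

from collections import OrderedDict

import torch
from torch.utils.data import DataLoader, TensorDataset
from catalyst.dl import BaseExperiment  # assumption: version-dependent path

# Toy dataset; the experiment only needs a name -> DataLoader mapping.
X, y = torch.randn(64, 10), torch.randint(0, 2, (64,))
loaders = OrderedDict([
    ("train", DataLoader(TensorDataset(X, y), batch_size=16)),
    ("valid", DataLoader(TensorDataset(X, y), batch_size=16)),
])

model = torch.nn.Linear(10, 2)
experiment = BaseExperiment(
    model=model,
    loaders=loaders,
    criterion=torch.nn.CrossEntropyLoss(),
    optimizer=torch.optim.Adam(model.parameters()),
    num_epochs=2,
    logdir="./logs",
)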
Example #3
    def train(
        self,
        model: Model,
        criterion: Criterion,
        optimizer: Optimizer,
        loaders: "OrderedDict[str, DataLoader]",
        logdir: str,
        callbacks: "Union[List[Callback], OrderedDict[str, Callback]]" = None,
        scheduler: Scheduler = None,
        resume: str = None,
        num_epochs: int = 1,
        valid_loader: str = "valid",
        main_metric: str = "loss",
        minimize_metric: bool = True,
        verbose: bool = False,
        state_kwargs: Dict = None,
        checkpoint_data: Dict = None,
        fp16: Union[Dict, bool] = None,
        monitoring_params: Dict = None,
        check: bool = False,
    ) -> None:
        """
        Starts the training process of the model.

        Args:
            model (Model): model to train
            criterion (Criterion): criterion function for training
            optimizer (Optimizer): optimizer for training
            loaders (dict): dictionary containing one or several
                ``torch.utils.data.DataLoader`` for training and validation
            logdir (str): path to output directory
            callbacks (List[catalyst.dl.Callback]): list of callbacks
            scheduler (Scheduler): scheduler for training
            resume (str): path to checkpoint for model
            num_epochs (int): number of training epochs
            valid_loader (str): loader name used to calculate
                the metrics and save the checkpoints. For example,
                you can pass `train` and then the metrics will be
                taken from the `train` loader.
            main_metric (str): the key of the metric by which
                the checkpoints will be selected.
            minimize_metric (bool): flag to indicate whether
                the ``main_metric`` should be minimized.
            verbose (bool): if True, displays the status of the
                training in the console.
            state_kwargs (dict): additional state params to ``State``
            checkpoint_data (dict): additional data to save in checkpoint,
                for example: ``class_names``, ``date_of_training``, etc
            fp16 (Union[Dict, bool]): if not None, then sets training to FP16.
                See https://nvidia.github.io/apex/amp.html#properties
                If fp16=True, the params default to ``{"opt_level": "O1"}``
            monitoring_params (dict): If not None, then create monitoring
                through Alchemy or Weights&Biases.
                For example,
                ``{"token": "api_token", "experiment": "experiment_name"}``
            check (bool): if True, only checks that the pipeline is working
                (3 epochs only)
        """
        if len(loaders) == 1:
            # With a single loader, validation falls back to that loader.
            valid_loader = list(loaders.keys())[0]
            logger.warning(
                "Attention, there is only one data loader - "
                + str(valid_loader)
            )
        if isinstance(fp16, bool) and fp16:
            # Expand the boolean shorthand into the default apex amp params.
            fp16 = {"opt_level": "O1"}

        if model is not None:
            self.model = model

        if resume is not None:
            callbacks = utils.process_callbacks(callbacks)
            checkpoint_callback_flag = any(
                isinstance(x, CheckpointCallback) for x in callbacks.values()
            )
            # Inject a CheckpointCallback to restore from `resume`, unless
            # the user already registered one.
            if not checkpoint_callback_flag:
                callbacks["loader"] = CheckpointCallback(resume=resume)
            else:
                raise NotImplementedError("CheckpointCallback already exists")

        experiment = self._experiment_fn(
            stage="train",
            model=model,
            loaders=loaders,
            callbacks=callbacks,
            logdir=logdir,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            num_epochs=num_epochs,
            valid_loader=valid_loader,
            main_metric=main_metric,
            minimize_metric=minimize_metric,
            verbose=verbose,
            check_run=check,
            state_kwargs=state_kwargs,
            checkpoint_data=checkpoint_data,
            distributed_params=fp16,
            monitoring_params=monitoring_params,
        )
        self.run_experiment(experiment)
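
End to end, this method is reached through a runner instance. A hedged sketch of a minimal call, assuming SupervisedRunner from catalyst.dl (a standard entry point in catalyst):

from collections import OrderedDict

import torch
from torch.utils.data import DataLoader, TensorDataset
from catalyst.dl import SupervisedRunner

X, y = torch.randn(128, 10), torch.randint(0, 2, (128,))
loaders = OrderedDict([
    ("train", DataLoader(TensorDataset(X, y), batch_size=32)),
    ("valid", DataLoader(TensorDataset(X, y), batch_size=32)),
])

model = torch.nn.Linear(10, 2)
runner = SupervisedRunner()
runner.train(
    model=model,
    criterion=torch.nn.CrossEntropyLoss(),
    optimizer=torch.optim.Adam(model.parameters(), lr=1e-3),
    loaders=loaders,
    logdir="./logs",
    num_epochs=2,
    verbose=True,
    check=True,  # sanity run only: three epochs, as documented above
)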
Example #4
    def __init__(
        self,
        model: Model,
        loaders: "OrderedDict[str, DataLoader]",
        callbacks: "Union[OrderedDict[str, Callback], List[Callback]]" = None,
        logdir: str = None,
        stage: str = "train",
        criterion: Criterion = None,
        optimizer: Optimizer = None,
        scheduler: Scheduler = None,
        num_epochs: int = 1,
        valid_loader: str = "valid",
        main_metric: str = "loss",
        minimize_metric: bool = True,
        verbose: bool = False,
        check_run: bool = False,
        state_kwargs: Dict = None,
        checkpoint_data: Dict = None,
        distributed_params: Dict = None,
        monitoring_params: Dict = None,
        initial_seed: int = 42,
    ):
        """
        Args:
            model (Model): model
            loaders (dict): dictionary containing one or several
                ``torch.utils.data.DataLoader`` for training and validation
            callbacks (List[catalyst.dl.Callback]): list of callbacks
            logdir (str): path to output directory
            stage (str): current stage
            criterion (Criterion): criterion function
            optimizer (Optimizer): optimizer
            scheduler (Scheduler): scheduler
            num_epochs (int): number of the experiment's epochs
            valid_loader (str): loader name used to calculate
                the metrics and save the checkpoints. For example,
                you can pass `train` and then the metrics will be
                taken from the `train` loader.
            main_metric (str): the key of the metric by which
                the checkpoints will be selected.
            minimize_metric (bool): flag to indicate whether
                the ``main_metric`` should be minimized.
            verbose (bool): if True, displays the status of the
                training in the console.
            check_run (bool): if True, only checks that the pipeline
                is working (3 epochs only, as in ``_run_stage``)
            state_kwargs (dict): additional state params to ``State``
            checkpoint_data (dict): additional data to save in checkpoint,
                for example: ``class_names``, ``date_of_training``, etc
            distributed_params (dict): dictionary with the parameters
                for distributed and FP16 method
            monitoring_params (dict): dict with the parameters
                for monitoring services
            initial_seed (int): experiment's initial seed value
        """
        self._model = model
        self._loaders = loaders
        # Normalize callbacks (list or dict) into an OrderedDict.
        self._callbacks = utils.process_callbacks(callbacks)

        self._criterion = criterion
        self._optimizer = optimizer
        self._scheduler = scheduler

        self._initial_seed = initial_seed
        self._logdir = logdir
        self._stage = stage
        self._num_epochs = num_epochs
        self._valid_loader = valid_loader
        self._main_metric = main_metric
        self._minimize_metric = minimize_metric
        self._verbose = verbose
        self._check_run = check_run
        self._additional_state_kwargs = state_kwargs or {}
        self._checkpoint_data = checkpoint_data or {}
        self._distributed_params = distributed_params or {}
        self._monitoring_params = monitoring_params or {}
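
Both constructors accept callbacks as either a list or a named OrderedDict and normalize them through process_callbacks. Its body is not shown in these examples; a hypothetical sketch of the normalization it plausibly performs, under that assumption:

from collections import OrderedDict
from typing import Dict, List, Union

def process_callbacks(callbacks: Union[Dict, List, None]) -> OrderedDict:
    # Hypothetical reimplementation -- the real utils.process_callbacks is
    # not shown here. It accepts None, a list, or a dict and always returns
    # an OrderedDict so later code can address callbacks by name.
    if callbacks is None:
        return OrderedDict()
    if isinstance(callbacks, dict):
        return OrderedDict(callbacks)
    # Lists get auto-generated keys, as _run_stage addresses them by name.
    return OrderedDict(
        ("callback_{}".format(i), cb) for i, cb in enumerate(callbacks)
    )

assert list(process_callbacks([object(), object()])) == [
    "callback_0", "callback_1"
]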