Ejemplo n.º 1
0
    def fit(self,
            train_loader,
            val_loader=None,
            max_epochs=1,
            metrics=None,
            metrics_on_train=False,
            callbacks=None,
            val_callbacks=None):
        """Train the model, optionally validating after every epoch.

        Args:
            train_loader: Training data iterable, consumed by an Engine
                built around ``self.train_step``.
            val_loader (optional): Validation data iterable. When given,
                ``self.validate`` is run once before training starts and a
                validation engine is run after every completed training
                epoch. Defaults to `None`.
            max_epochs (int, optional): End epoch passed to the training
                engine, which is run starting from epoch 0. Defaults to 1.
            metrics (list, optional): Extra metrics evaluated on validation
                data (and on training data when ``metrics_on_train`` is
                true). ``Loss()`` is always added in front of them.
                Defaults to `None`.
            metrics_on_train (bool, optional): Also evaluate ``metrics`` on
                the training data. Defaults to False.
            callbacks (list, optional): Callbacks attached to the training
                engine. Defaults to `None`.
            val_callbacks (list, optional): Callbacks attached to the
                validation engine, and also passed to the initial
                ``self.validate`` call. Defaults to `None`.

        Raises:
            AssertionError: If ``self.train_ready()`` is falsy.
        """
        metrics = [] if metrics is None else metrics
        # Explicit raise instead of ``assert`` so the readiness check is not
        # stripped under ``python -O``; the exception type is unchanged for
        # callers that catch it.
        if not self.train_ready():
            raise AssertionError("Model is not ready for training: "
                                 "train_ready() returned False")
        setup_logging()

        train_engine = Engine(self.train_step, model=self, logger=self.logger)
        # Loss is always tracked on train data; the extra metrics only if
        # explicitly requested.
        train_metrics = [Loss()] + metrics if metrics_on_train else [Loss()]
        _attach_metrics(train_engine, train_metrics, name_prefix='train_')
        metrics_logging.attach(train_engine, train=True)

        if val_loader is not None:
            # Standalone validation pass before any training takes place.
            self.validate(val_loader, metrics, val_callbacks)
            val_engine = Engine(self.val_step, model=self, logger=self.logger)
            _attach_metrics(val_engine, [Loss()] + metrics, name_prefix='val_')
            _attach_callbacks(val_engine, val_callbacks)

            @on_epoch_complete
            def validation_epoch(train_state, val_engine, val_loader):
                # Run the validation engine with start/end epochs
                # (epoch, epoch + 1) for the train epoch that just finished,
                # then merge its metrics into the training state.
                epoch = train_state.epoch
                val_state = val_engine.run(val_loader, epoch, epoch + 1)
                train_state.metrics.update(val_state.metrics)

            validation_epoch.attach(train_engine, val_engine, val_loader)
            metrics_logging.attach(train_engine, train=False)

        # User callbacks are attached after the metric and validation
        # handlers (presumably so they fire after them, if handlers run in
        # registration order — confirm against Engine).
        _attach_callbacks(train_engine, callbacks)
        train_engine.run(train_loader, 0, max_epochs)
Ejemplo n.º 2
0
    def validate(self,
                 val_loader: Optional[Iterable],
                 metrics: Optional[List[Metric]] = None,
                 callbacks: Optional[List[Callback]] = None) -> Dict[str, float]:
        """Evaluate the model with a single pass over validation data.

        A fresh ``'val'``-phase engine is built around ``self.val_step``;
        a loss metric is always computed in addition to ``metrics``.

        Args:
            val_loader (Iterable): Source of validation batches.
            metrics (list of :class:`argus.metrics.Metric`, optional):
                Extra metrics to compute besides the loss.
                Defaults to `None`.
            callbacks (list of :class:`argus.callbacks.Callback`, optional):
                Callbacks to attach to this validation run.
                Defaults to `None`.

        Returns:
            dict: Metric values from the finished validation run.

        """
        self._check_train_ready()
        if metrics is None:
            extra_metrics = []
        else:
            extra_metrics = metrics
        evaluator = Engine(self.val_step, model=self,
                           logger=self.logger, phase='val')
        _attach_metrics(evaluator, [Loss()] + extra_metrics)
        _attach_callbacks(evaluator, callbacks)
        # print_epoch=False: a standalone validation run is not tied to a
        # training epoch (presumably this omits the epoch from log lines).
        metrics_logging.attach(evaluator, train=False, print_epoch=False)
        final_state = evaluator.run(val_loader)
        return final_state.metrics
Ejemplo n.º 3
0
    def fit(self,
            train_loader: Iterable,
            val_loader: Optional[Iterable] = None,
            num_epochs: int = 1,
            metrics: Optional[List[Union[Metric, str]]] = None,
            metrics_on_train: bool = False,
            callbacks: Optional[List[Callback]] = None,
            val_callbacks: Optional[List[Callback]] = None):
        """Run the training loop, with optional per-epoch validation.

        Loss is always tracked on the training data. The extra ``metrics``
        are tracked on the validation data and, when ``metrics_on_train``
        is set, on the training data as well.

        Args:
            train_loader (Iterable): Source of training batches.
            val_loader (Iterable, optional): Source of validation batches.
                When given, a validation pass is run once before training
                and again after every completed training epoch.
                Defaults to `None`.
            num_epochs (int, optional): How many training epochs to run.
                Defaults to 1.
            metrics (list of :class:`argus.metrics.Metric` or str, optional):
                Additional metrics to compute; a loss metric is always
                added in front of them. Defaults to `None`.
            metrics_on_train (bool, optional): Also compute ``metrics`` on
                the training data. Defaults to False.
            callbacks (list of :class:`argus.callbacks.Callback`, optional):
                Callbacks attached to the training engine.
                Defaults to `None`.
            val_callbacks (list of :class:`argus.callbacks.Callback`, optional):
                Callbacks attached to the validation engine.
                Defaults to `None`.

        """
        self._check_train_ready()
        if metrics is None:
            metrics = []

        trainer = Engine(self.train_step, model=self,
                         logger=self.logger, phase='train')
        if metrics_on_train:
            trainer_metrics = [Loss()] + metrics
        else:
            trainer_metrics = [Loss()]
        _attach_metrics(trainer, trainer_metrics)
        metrics_logging.attach(trainer, train=True)

        if val_loader is not None:
            # One standalone validation pass before training begins.
            self.validate(val_loader, metrics, val_callbacks)

            evaluator = Engine(self.val_step, model=self,
                               logger=self.logger, phase='val')
            _attach_metrics(evaluator, [Loss()] + metrics)
            _attach_callbacks(evaluator, val_callbacks)

            @on_epoch_complete
            def validation_epoch(train_state, val_engine, val_loader):
                # Run the validation engine with start/end epochs
                # (start, start + 1) for the train epoch that just finished
                # and fold its metrics into the training state.
                start = train_state.epoch
                val_state = val_engine.run(val_loader, start, start + 1)
                train_state.metrics.update(val_state.metrics)

            validation_epoch.attach(trainer, evaluator, val_loader)
            metrics_logging.attach(trainer, train=False)

        # User callbacks go on last, after metric and validation handlers.
        _attach_callbacks(trainer, callbacks)
        trainer.run(train_loader, 0, num_epochs)
Ejemplo n.º 4
0
 def validate(self, val_loader, metrics=None, callbacks=None):
     """Run a single validation pass and return the computed metrics.

     Args:
         val_loader: Validation data iterable, consumed by an Engine built
             around ``self.val_step``.
         metrics (list, optional): Extra metrics evaluated alongside the
             always-present ``Loss()``; all of them are attached with
             ``name_prefix='val_'``. Defaults to `None` (loss only).
         callbacks (list, optional): Callbacks attached to the validation
             engine. Defaults to `None`.

     Returns:
         The ``metrics`` attribute of the state returned by the validation
         engine's ``run`` (presumably a dict of metric name -> value).

     Raises:
         AssertionError: If ``self.train_ready()`` is falsy.
     """
     metrics = [] if metrics is None else metrics
     # Explicit raise instead of ``assert`` so the readiness check is not
     # stripped under ``python -O``; the exception type is unchanged for
     # callers that catch it.
     if not self.train_ready():
         raise AssertionError("Model is not ready for validation: "
                              "train_ready() returned False")
     val_engine = Engine(self.val_step, model=self, logger=self.logger)
     _attach_metrics(val_engine, [Loss()] + metrics, name_prefix='val_')
     _attach_callbacks(val_engine, callbacks)
     # print_epoch=False: a standalone pass is not tied to a training epoch
     # (presumably this omits the epoch from log output).
     metrics_logging.attach(val_engine, train=False, print_epoch=False)
     return val_engine.run(val_loader).metrics