示例#1
0
文件: model.py 项目: lRomul/argus
    def validate(
            self,
            val_loader: Iterable,
            metrics: Optional[List[Union[Metric, str]]] = None,
            callbacks: Optional[List[Callback]] = None) -> Dict[str, float]:
        """Evaluate the model once on validation data.

        A dedicated validation engine is created for this call. The loss is
        always tracked; any extra ``metrics`` and ``callbacks`` are attached
        alongside it, and the engine then runs over ``val_loader``.

        Args:
            val_loader (Iterable): The validation data loader.
            metrics (list of :class:`argus.metrics.Metric` or str, optional):
                Additional metrics to evaluate besides the loss. Defaults to
                `None`.
            callbacks (list of :class:`argus.callbacks.Callback`, optional):
                Callbacks to attach to the validation engine. Defaults to
                `None`.

        Returns:
            dict: The metrics gathered by the validation engine.

        """
        self._check_train_ready()
        requested = metrics if metrics is not None else []

        engine = Engine(self.val_step, phase_states={})
        attach_metrics(engine, [Loss()] + requested)
        default_logging.attach(engine)
        attach_callbacks(engine, callbacks)

        # start_epoch=-1, end_epoch=0: a single epoch, with the same numbering
        # that fit() uses for its validation pass before training.
        final_state = engine.run(val_loader, -1, 0)
        return final_state.metrics
示例#2
0
文件: model.py 项目: lRomul/argus
    def fit(self,
            train_loader: Iterable,
            val_loader: Optional[Iterable] = None,
            num_epochs: int = 1,
            metrics: Optional[List[Union[Metric, str]]] = None,
            metrics_on_train: bool = False,
            callbacks: Optional[List[Callback]] = None,
            val_callbacks: Optional[List[Callback]] = None):
        """Train the argus model.

        The method attaches metrics and callbacks to the train and validation
        process, and performs training itself. When ``val_loader`` is given,
        one validation epoch (numbered -1) is run before training starts. A
        further validation epoch runs after every training epoch, and its
        metrics are merged into the training state's metrics.

        Args:
            train_loader (Iterable): The train data loader.
            val_loader (Iterable, optional): The validation data loader.
                Defaults to `None`.
            num_epochs (int, optional): Number of training epochs to
                run. Defaults to 1.
            metrics (list of :class:`argus.metrics.Metric` or str, optional):
                List of metrics to evaluate. By default, the metrics are
                evaluated on the validation data (if any) only.
                Defaults to `None`.
            metrics_on_train (bool, optional): Evaluate the metrics on train
                data as well. Defaults to False.
            callbacks (list of :class:`argus.callbacks.Callback`, optional):
                List of callbacks to be attached to the training process.
                Defaults to `None`.
            val_callbacks (list of :class:`argus.callbacks.Callback`, optional):
                List of callbacks to be attached to the validation process.
                Defaults to `None`.

        """
        self._check_train_ready()
        metrics = [] if metrics is None else metrics
        # The same dict is handed to both the train and validation engines.
        phase_states = dict()

        train_engine = Engine(self.train_step, phase_states=phase_states)
        # Loss is always tracked on train data. User metrics are added only
        # when metrics_on_train is set.
        train_metrics = [Loss()] + metrics if metrics_on_train else [Loss()]
        attach_metrics(train_engine, train_metrics)
        default_logging.attach(train_engine)

        if val_loader is not None:
            val_engine = Engine(self.val_step, phase_states=phase_states)
            attach_metrics(val_engine, [Loss()] + metrics)
            default_logging.attach(val_engine)
            attach_callbacks(val_engine, val_callbacks)

            @argus.callbacks.on_epoch_complete
            def validation_epoch(train_state, val_engine, val_loader):
                # Validate under the epoch number of the training epoch that
                # just completed, then expose the results on train_state.
                epoch = train_state.epoch
                val_state = val_engine.run(val_loader, epoch, epoch + 1)
                train_state.metrics.update(val_state.metrics)

            # NOTE(review): attached before the user ``callbacks`` below,
            # presumably so that user epoch-complete callbacks already see
            # the validation metrics. This assumes handlers fire in attach
            # order; confirm in Engine.
            validation_epoch.attach(train_engine, val_engine, val_loader)
            # One validation pass before training (epoch -1), presumably to
            # log baseline metrics of the initial model.
            val_engine.run(val_loader, -1, 0)

        attach_callbacks(train_engine, callbacks)
        train_engine.run(train_loader, 0, num_epochs)
示例#3
0
    def test_on_decorators(self, step_storage, linear_argus_model_instance):
        """Each event decorator fires its handler the expected number of times.

        Every handler bumps a shared counter on the engine state and sets its
        own flag. ``on_start`` initializes the counter. The exception handler
        must only fire once a callback raises.
        """
        @argus.callbacks.on_start
        def on_start_function(state):
            state.call_count = 1
            state.on_start_flag = True

        @argus.callbacks.on_complete
        def on_complete_function(state):
            state.call_count += 1
            state.on_complete_flag = True

        @argus.callbacks.on_epoch_start
        def on_epoch_start_function(state):
            state.call_count += 1
            state.on_epoch_start_flag = True

        @argus.callbacks.on_epoch_complete
        def on_epoch_complete_function(state):
            state.call_count += 1
            state.on_epoch_complete_flag = True

        @argus.callbacks.on_iteration_start
        def on_iteration_start_function(state):
            state.call_count += 1
            state.on_iteration_start_flag = True

        @argus.callbacks.on_iteration_complete
        def on_iteration_complete_function(state):
            state.call_count += 1
            state.on_iteration_complete_flag = True

        @argus.callbacks.on_catch_exception
        def on_catch_exception_function(state):
            state.call_count += 1
            state.on_catch_exception_flag = True

        handlers = [
            on_start_function, on_complete_function, on_epoch_start_function,
            on_epoch_complete_function, on_iteration_start_function,
            on_iteration_complete_function, on_catch_exception_function,
        ]
        engine = Engine(step_storage.step_method,
                        model=linear_argus_model_instance)
        attach_callbacks(engine, handlers)

        data_loader = [4, 8, 15, 16, 23, 42]
        num_epochs = 3
        state = engine.run(data_loader, start_epoch=0, end_epoch=num_epochs)

        # start/complete fire once each; epoch start/complete fire once per
        # epoch; iteration start/complete fire once per batch per epoch.
        once_per_run = 2
        per_epoch = 2 * num_epochs
        per_iteration = 2 * num_epochs * len(data_loader)
        assert state.call_count == per_iteration + per_epoch + once_per_run

        fired_flags = (
            'on_start_flag', 'on_complete_flag',
            'on_epoch_start_flag', 'on_epoch_complete_flag',
            'on_iteration_start_flag', 'on_iteration_complete_flag',
        )
        for flag_name in fired_flags:
            assert getattr(state, flag_name)
        assert not hasattr(state, 'on_catch_exception_flag')

        class CustomException(Exception):
            pass

        @argus.callbacks.on_start
        def on_start_raise_exception(state):
            raise CustomException

        on_start_raise_exception.attach(engine)
        with pytest.raises(CustomException):
            engine.run(data_loader, start_epoch=0, end_epoch=num_epochs)
        assert engine.state.on_catch_exception_flag
示例#4
0
 def test_attach_not_a_callback(self, test_engine):
     """attach_callbacks raises TypeError for non-callback objects."""
     # Both None and an engine object are invalid entries in the list.
     for not_a_callback in (None, test_engine):
         with pytest.raises(TypeError):
             attach_callbacks(test_engine, [not_a_callback])