Example #1
0
    def fit(self,
            train_loader,
            val_loader=None,
            max_epochs=1,
            metrics=None,
            metrics_on_train=False,
            callbacks=None,
            val_callbacks=None):
        """Train the model, optionally validating after every epoch.

        When ``val_loader`` is given, ``self.validate`` is called once before
        training starts, and a separate validation engine is run after each
        training epoch; its metrics are merged into the train state's metrics.

        Args:
            train_loader: Iterable of training batches.
            val_loader: Optional iterable of validation batches.
            max_epochs (int): Passed as the end epoch to ``train_engine.run``
                (start epoch is 0).
            metrics (list, optional): Metrics evaluated on validation data,
                and on train data too when ``metrics_on_train`` is True. A
                ``Loss()`` metric is always prepended.
            metrics_on_train (bool): Also evaluate ``metrics`` on train data.
            callbacks (list, optional): Callbacks attached to the train engine.
            val_callbacks (list, optional): Callbacks used both for the
                pre-training ``validate`` call and the per-epoch validation
                engine.

        Raises:
            AssertionError: If ``self.train_ready()`` is falsy.
        """
        metrics = [] if metrics is None else metrics
        # Explicit check instead of a bare ``assert``: asserts are stripped
        # under ``python -O``. The exception type is kept for existing callers.
        if not self.train_ready():
            raise AssertionError("Model is not ready for training")
        setup_logging()

        train_engine = Engine(self.train_step, model=self, logger=self.logger)
        train_metrics = [Loss()] + metrics if metrics_on_train else [Loss()]
        _attach_metrics(train_engine, train_metrics, name_prefix='train_')
        metrics_logging.attach(train_engine, train=True)

        if val_loader is not None:
            # One validation pass before the first training epoch.
            self.validate(val_loader, metrics, val_callbacks)
            val_engine = Engine(self.val_step, model=self, logger=self.logger)
            _attach_metrics(val_engine, [Loss()] + metrics, name_prefix='val_')
            _attach_callbacks(val_engine, val_callbacks)

            @on_epoch_complete
            def validation_epoch(train_state, val_engine, val_loader):
                # Validate for exactly the epoch that just finished training.
                epoch = train_state.epoch
                val_state = val_engine.run(val_loader, epoch, epoch + 1)
                train_state.metrics.update(val_state.metrics)

            validation_epoch.attach(train_engine, val_engine, val_loader)
            metrics_logging.attach(train_engine, train=False)

        _attach_callbacks(train_engine, callbacks)
        train_engine.run(train_loader, 0, max_epochs)
Example #2
0
    def validate(
            self,
            val_loader: Iterable,
            metrics: Optional[List[Union[Metric, str]]] = None,
            callbacks: Optional[List[Callback]] = None) -> Dict[str, float]:
        """Run a validation pass and report its metrics.

        Args:
            val_loader (Iterable): The validation data loader.
            metrics (list of :class:`argus.metrics.Metric` or str, optional):
                Extra metrics to evaluate; a ``Loss()`` metric is always
                evaluated in addition. Defaults to `None`.
            callbacks (list of :class:`argus.callbacks.Callback`, optional):
                Callbacks attached to the validation engine.
                Defaults to `None`.

        Returns:
            dict: The ``metrics`` of the state returned by the engine run.

        """
        self._check_train_ready()
        extra_metrics = list() if metrics is None else metrics
        engine = Engine(self.val_step, phase_states={})
        attach_metrics(engine, [Loss()] + extra_metrics)
        default_logging.attach(engine)
        attach_callbacks(engine, callbacks)
        # Start/end epoch bounds (-1, 0) -- presumably a single pass; verify
        # against Engine.run if this ever changes.
        final_state = engine.run(val_loader, -1, 0)
        return final_state.metrics
Example #3
0
    def test_custom_events(self):
        class CustomEvents(EventEnum):
            STEP_START = 'step_start'
            STEP_COMPLETE = 'step_complete'

        def step_function(batch, state):
            # Raise the custom events around an increment of the step output.
            engine = state.engine
            state.step_output = batch
            engine.raise_event(CustomEvents.STEP_START)
            state.step_output += 1
            engine.raise_event(CustomEvents.STEP_COMPLETE)

        class CustomCallback(Callback):
            def __init__(self):
                self.seen_on_start = []
                self.seen_on_complete = []

            def step_start(self, state):
                self.seen_on_start.append(state.step_output)

            def step_complete(self, state):
                self.seen_on_complete.append(state.step_output)

        batches = [4, 8, 15, 16, 23, 42]
        recorder = CustomCallback()
        engine = Engine(step_function)
        _attach_callbacks(engine, [recorder])
        engine.run(batches)
        assert recorder.seen_on_start == batches
        assert recorder.seen_on_complete == [b + 1 for b in batches]
Example #4
0
    def validate(self,
                 val_loader: Optional[Iterable],
                 metrics: Optional[List[Metric]] = None,
                 callbacks: Optional[List[Callback]] = None) -> Dict[str, float]:
        """Run one validation pass and return its metrics.

        Args:
            val_loader (Iterable): The validation data loader, passed
                unchanged to the validation engine's ``run``.
            metrics (list of :class:`argus.metrics.Metric`, optional):
                Extra metrics to evaluate alongside ``Loss()``.
                Defaults to `None`.
            callbacks (list of :class:`argus.callbacks.Callback`, optional):
                Callbacks attached to the validation engine.
                Defaults to `None`.

        Returns:
            dict: The ``metrics`` of the state returned by the engine run.

        """
        self._check_train_ready()
        if metrics is None:
            metrics = []
        engine = Engine(self.val_step,
                        model=self,
                        logger=self.logger,
                        phase='val')
        _attach_metrics(engine, [Loss()] + metrics)
        _attach_callbacks(engine, callbacks)
        metrics_logging.attach(engine, train=False, print_epoch=False)
        final_state = engine.run(val_loader)
        return final_state.metrics
Example #5
0
    def fit(self,
            train_loader: Iterable,
            val_loader: Optional[Iterable] = None,
            num_epochs: int = 1,
            metrics: Optional[List[Union[Metric, str]]] = None,
            metrics_on_train: bool = False,
            callbacks: Optional[List[Callback]] = None,
            val_callbacks: Optional[List[Callback]] = None):
        """Train the argus model.

        Metrics and callbacks are attached to a train engine (and, if
        ``val_loader`` is given, to a validation engine), then training runs
        from epoch 0 to ``num_epochs``. When validating, one validation pass
        is run before training starts and another after every training
        epoch; its metrics are merged into the train state's metrics.

        Args:
            train_loader (Iterable): The train data loader.
            val_loader (Iterable, optional):
                The validation data loader. Defaults to `None`.
            num_epochs (int, optional): Number of training epochs to
                run. Defaults to 1.
            metrics (list of :class:`argus.metrics.Metric` or str, optional):
                Metrics to evaluate in addition to the loss. By default they
                are evaluated on the validation data (if any) only.
                Defaults to `None`.
            metrics_on_train (bool, optional): Evaluate the metrics on train
                data as well. Defaults to False.
            callbacks (list of :class:`argus.callbacks.Callback`, optional):
                Callbacks attached to the training process.
                Defaults to `None`.
            val_callbacks (list of :class:`argus.callbacks.Callback`, optional):
                Callbacks attached to the validation process (both the
                pre-training pass and the per-epoch passes).
                Defaults to `None`.

        """
        self._check_train_ready()
        metrics = [] if metrics is None else metrics

        train_engine = Engine(self.train_step,
                              model=self,
                              logger=self.logger,
                              phase='train')
        if metrics_on_train:
            train_metrics = [Loss()] + metrics
        else:
            train_metrics = [Loss()]
        _attach_metrics(train_engine, train_metrics)
        metrics_logging.attach(train_engine, train=True)

        if val_loader is not None:
            # One validation pass before the first training epoch.
            self.validate(val_loader, metrics, val_callbacks)
            validator = Engine(self.val_step,
                               model=self,
                               logger=self.logger,
                               phase='val')
            _attach_metrics(validator, [Loss()] + metrics)
            _attach_callbacks(validator, val_callbacks)

            @on_epoch_complete
            def validation_epoch(train_state, engine, loader):
                # Validate for exactly the epoch that just finished training.
                current = train_state.epoch
                val_state = engine.run(loader, current, current + 1)
                train_state.metrics.update(val_state.metrics)

            validation_epoch.attach(train_engine, validator, val_loader)
            metrics_logging.attach(train_engine, train=False)

        _attach_callbacks(train_engine, callbacks)
        train_engine.run(train_loader, 0, num_epochs)
Example #6
0
    def test_custom_metric(self, engine):
        custom = CustomMetric()
        batches = [4, 8, 15, 16, 23, 42]
        expected_count = len(batches)

        _attach_metrics(engine, [custom])
        with pytest.raises(TypeError):
            _attach_metrics(engine, [None])
        state = engine.run(batches)
        assert custom.data == batches
        assert custom.compute() == expected_count
        assert state.metrics == {"custom_metric": expected_count}

        custom.reset()
        assert custom.data == []
        assert custom.compute() == 0

        # With phase='train' the reported metric name carries a prefix.
        train_engine = Engine(lambda batch, state: batch, phase='train')
        _attach_metrics(train_engine, [custom])
        state = train_engine.run(batches)
        assert custom.compute() == expected_count
        assert state.metrics == {"train_custom_metric": expected_count}

        @argus.callbacks.on_iteration_start
        def stop_on_first_iteration(state):
            state.stopped = True

        stop_on_first_iteration.attach(train_engine)
        train_engine.run(batches)
        assert custom.compute() == 1
Example #7
0
 def validate(self, val_loader, metrics=None, callbacks=None):
     """Run one validation pass and return its metrics.

     Args:
         val_loader: Iterable of validation batches.
         metrics (list, optional): Extra metrics evaluated alongside
             ``Loss()``; names are prefixed with ``'val_'``.
         callbacks (list, optional): Callbacks attached to the validation
             engine.

     Returns:
         dict: The ``metrics`` of the state returned by the engine run.

     Raises:
         AssertionError: If ``self.train_ready()`` is falsy.
     """
     metrics = [] if metrics is None else metrics
     # Explicit check instead of a bare ``assert``: asserts are stripped
     # under ``python -O``. The exception type is kept for existing callers.
     if not self.train_ready():
         raise AssertionError("Model is not ready for training")
     val_engine = Engine(self.val_step, model=self, logger=self.logger)
     _attach_metrics(val_engine, [Loss()] + metrics, name_prefix='val_')
     _attach_callbacks(val_engine, callbacks)
     metrics_logging.attach(val_engine, train=False, print_epoch=False)
     return val_engine.run(val_loader).metrics
Example #8
0
    def test_custom_callback_by_name(self):
        batches = [4, 8, 15, 16, 23, 42]
        val_engine = Engine(lambda batch, state: batch, phase='val')

        # A metric can be referenced by its string name.
        _attach_metrics(val_engine, ["custom_metric"])
        final_state = val_engine.run(batches)
        assert final_state.metrics == {"val_custom_metric": len(batches)}

        # Unknown metric names are rejected.
        with pytest.raises(ValueError):
            _attach_metrics(val_engine, ["qwerty"])
Example #9
0
 def test_phase_states(self, linear_argus_model_instance):
     model = linear_argus_model_instance
     shared_states = {}
     train_engine = Engine(model.train_step, phase_states=shared_states)
     val_engine = Engine(model.val_step, phase_states=shared_states)
     # Each engine must see both phases' states through the shared mapping.
     for current in (train_engine, val_engine):
         assert current.state.phase_states['train'] is train_engine.state
         assert current.state.phase_states['val'] is val_engine.state
Example #10
0
    def test_on_event(self, step_storage):
        secret = 42

        @argus.callbacks.on_event(Events.START)
        def some_function(state):
            state.special_secret = secret

        engine = Engine(step_storage.step_method)
        some_function.attach(engine)
        final_state = engine.run([4, 8, 15, 16, 23, 42])
        assert final_state.special_secret == secret
Example #11
0
    def attach(self, engine: Engine, *args, **kwargs):
        """Register this callback's handler on an :class:`argus.engine.Engine`.

        Args:
            engine (Engine): Engine that receives ``self.handler`` for
                ``self.event``.
            *args: Extra positional arguments forwarded to
                ``engine.add_event_handler`` for the handler.
            **kwargs: Extra keyword arguments forwarded likewise.

        """
        event, handler = self.event, self.handler
        engine.add_event_handler(event, handler, *args, **kwargs)
Example #12
0
    def attach(self, engine: Engine):
        """Attach callback to the :class:`argus.engine.Engine`.

        For every member of every class returned by ``inheritors(EventEnum)``,
        if this callback has an attribute whose name equals the event's
        value, that attribute is registered on ``engine`` as the event's
        handler.

        Args:
            engine (Engine): The engine to which the callback will be attached.

        Raises:
            TypeError: If an attribute named after an event value exists but
                is not callable.

        """
        for event_enum in inheritors(EventEnum):
            # ``__members__`` (unlike iterating the enum) also yields aliases.
            for event in event_enum.__members__.values():
                if not hasattr(self, event.value):
                    continue
                handler = getattr(self, event.value)
                # ``callable`` also accepts ``functools.partial`` objects and
                # instances with ``__call__``, which the previous
                # FunctionType/MethodType check rejected despite the message.
                if not callable(handler):
                    raise TypeError(
                        f"Attribute {event.value} is not callable.")
                engine.add_event_handler(event, handler)
Example #13
0
 def test_attach_callback(self, n_epochs, custom_test_callback,
                          step_storage):
     batches = [4, 8, 15, 16, 23, 42]
     per_run_steps = n_epochs * len(batches)
     engine = Engine(step_storage.step_method)
     custom_test_callback.attach(engine)
     engine.run(batches, start_epoch=0, end_epoch=n_epochs)

     # Run-level events fire exactly once, even for zero epochs.
     assert custom_test_callback.start_count == 1
     assert custom_test_callback.complete_count == 1
     # Epoch-level events fire once per epoch.
     assert custom_test_callback.epoch_start_count == n_epochs
     assert custom_test_callback.epoch_complete_count == n_epochs
     # Iteration- and step-level events fire once per batch per epoch.
     assert custom_test_callback.iteration_start_count == per_run_steps
     assert custom_test_callback.iteration_complete_count == per_run_steps
     assert custom_test_callback.catch_exception_count == 0
     assert custom_test_callback.step_start_count == per_run_steps
     assert custom_test_callback.step_complete_count == per_run_steps
     assert step_storage.batch_lst == n_epochs * batches
     if n_epochs:
         assert step_storage.state is engine.state
     else:
         assert step_storage.state is None
Example #14
0
    def test_custom_events(self, linear_net_class):
        class CustomEvents(EventEnum):
            STEP_START = 'step_start'
            STEP_COMPLETE = 'step_complete'

        class CustomEventsModel(argus.Model):
            nn_module = linear_net_class

            def count_step(self, batch, state):
                # Raise the custom events around an increment of the output.
                engine = state.engine
                state.step_output = batch
                engine.raise_event(CustomEvents.STEP_START)
                state.step_output += 1
                engine.raise_event(CustomEvents.STEP_COMPLETE)

        params = {
            'nn_module': {
                'in_features': 10,
                'out_features': 1,
            },
            'optimizer': None,
            'loss': None,
        }
        model = CustomEventsModel(params)

        class CustomCallback(Callback):
            def __init__(self):
                self.seen_on_start = []
                self.seen_on_complete = []

            def step_start(self, state):
                self.seen_on_start.append(state.step_output)

            def step_complete(self, state):
                self.seen_on_complete.append(state.step_output)

        batches = [4, 8, 15, 16, 23, 42]
        recorder = CustomCallback()
        engine = Engine(model.count_step)
        attach_callbacks(engine, [recorder])
        engine.run(batches)
        assert recorder.seen_on_start == batches
        assert recorder.seen_on_complete == [b + 1 for b in batches]
Example #15
0
    def test_add_event_handler(self):
        def some_function():
            pass

        def start_handlers():
            return engine.event_handlers[Events.START]

        engine = Engine(some_function)
        assert len(start_handlers()) == 0
        engine.add_event_handler(Events.START, some_function)
        assert len(start_handlers()) == 1
        registered = start_handlers()[0][0]
        assert registered is some_function

        # A non-event key is rejected.
        with pytest.raises(TypeError):
            engine.add_event_handler(42, some_function)
Example #16
0
    def fit(self,
            train_loader,
            val_loader=None,
            max_epochs=1,
            metrics=None,
            val_event_handlers=None,
            train_event_handlers=None):
        """Train the model, optionally validating after every epoch.

        A ``TrainLoss`` metric named ``'train_loss'`` is attached to the
        train engine and ``train_loss_logging`` runs on each epoch completion.
        When ``val_loader`` is given, the ``metrics`` (plus a default
        ``'val_loss'`` entry if missing) are attached to a validation engine;
        ``validation_logging`` is called once before training and again on
        every train epoch completion.

        Args:
            train_loader: Iterable of training batches.
            val_loader: Optional iterable of validation batches.
            max_epochs (int): Passed to ``train_engine.run``.
            metrics (dict, optional): Mapping of name -> metric to attach to
                the validation engine. The caller's dict is not modified.
            val_event_handlers: Currently unused.
            train_event_handlers: Currently unused.

        Raises:
            AssertionError: If ``self.train_ready()`` is falsy.
        """
        # Explicit check instead of a bare ``assert``: asserts are stripped
        # under ``python -O``. The exception type is kept for existing callers.
        if not self.train_ready():
            raise AssertionError("Model is not ready for training")

        setup_logging()
        # Copy so that inserting the default 'val_loss' below does not mutate
        # the dict the caller passed in (or leak state between fit calls).
        metrics = {} if metrics is None else dict(metrics)

        # NOTE(review): val_event_handlers / train_event_handlers are accepted
        # but never attached to any engine -- confirm intended behaviour.

        train_engine = Engine(self._train_step)

        train_loss = TrainLoss()
        train_loss.attach(train_engine, 'train_loss')
        train_engine.add_event_handler(Events.EPOCH_COMPLETE,
                                       train_loss_logging)

        if val_loader is not None:
            val_engine = Engine(self._val_step)

            if 'val_loss' not in metrics:
                metrics['val_loss'] = Loss(self.loss)

            for name, metric in metrics.items():
                metric.attach(val_engine, name)

            # One validation pass before training, then one per train epoch.
            validation_logging(train_engine, val_engine, val_loader)
            train_engine.add_event_handler(Events.EPOCH_COMPLETE,
                                           validation_logging, val_engine,
                                           val_loader)

        train_engine.run(train_loader, max_epochs)
Example #17
0
def engine(linear_argus_model_instance):
    """Build an Engine whose step returns each batch unchanged.

    The engine is bound to ``linear_argus_model_instance`` and its logger.
    """
    model = linear_argus_model_instance
    return Engine(lambda batch, state: batch,
                  model=model,
                  logger=model.logger)
Example #18
0
    def test_on_decorators(self, step_storage, linear_argus_model_instance):
        @argus.callbacks.on_start
        def on_start_function(state):
            state.call_count = 1
            state.on_start_flag = True

        @argus.callbacks.on_complete
        def on_complete_function(state):
            state.call_count += 1
            state.on_complete_flag = True

        @argus.callbacks.on_epoch_start
        def on_epoch_start_function(state):
            state.call_count += 1
            state.on_epoch_start_flag = True

        @argus.callbacks.on_epoch_complete
        def on_epoch_complete_function(state):
            state.call_count += 1
            state.on_epoch_complete_flag = True

        @argus.callbacks.on_iteration_start
        def on_iteration_start_function(state):
            state.call_count += 1
            state.on_iteration_start_flag = True

        @argus.callbacks.on_iteration_complete
        def on_iteration_complete_function(state):
            state.call_count += 1
            state.on_iteration_complete_flag = True

        @argus.callbacks.on_catch_exception
        def on_catch_exception_function(state):
            state.call_count += 1
            state.on_catch_exception_flag = True

        decorated = [
            on_start_function,
            on_complete_function,
            on_epoch_start_function,
            on_epoch_complete_function,
            on_iteration_start_function,
            on_iteration_complete_function,
            on_catch_exception_function,
        ]
        engine = Engine(step_storage.step_method,
                        model=linear_argus_model_instance)
        attach_callbacks(engine, decorated)

        batches = [4, 8, 15, 16, 23, 42]
        n_epochs = 3
        state = engine.run(batches, start_epoch=0, end_epoch=n_epochs)

        # start + complete, two epoch events per epoch, and two iteration
        # events per batch per epoch; catch_exception never fires here.
        expected_calls = 2 + 2 * n_epochs + 2 * len(batches) * n_epochs
        assert state.call_count == expected_calls
        for flag in ('on_start_flag', 'on_complete_flag',
                     'on_epoch_start_flag', 'on_epoch_complete_flag',
                     'on_iteration_start_flag', 'on_iteration_complete_flag'):
            assert getattr(state, flag)
        assert not hasattr(state, 'on_catch_exception_flag')

        class CustomException(Exception):
            pass

        @argus.callbacks.on_start
        def on_start_raise_exception(state):
            raise CustomException

        on_start_raise_exception.attach(engine)
        with pytest.raises(CustomException):
            engine.run(batches, start_epoch=0, end_epoch=n_epochs)
        assert engine.state.on_catch_exception_flag
Example #19
0
    def test_run(self):
        class StepStorage:
            def __init__(self):
                self.reset()

            def reset(self):
                self.batch_lst = []
                self.state = None

            def step_method(self, batch, state):
                self.batch_lst.append(batch)
                self.state = state

        storage = StepStorage()
        batches = [4, 8, 15, 16, 23, 42]
        engine = Engine(
            storage.step_method,
            logger=logging.getLogger('TestEngineMethods::test_run'))

        def run_three_epochs():
            return engine.run(batches, start_epoch=0, end_epoch=3)

        # Uninterrupted run over three epochs.
        state = run_three_epochs()
        assert storage.batch_lst == batches * 3
        assert state.epoch == 3
        assert state.iteration == len(batches)

        def stop_function(state):
            state.stopped = True

        # Stopping at the end of the first epoch.
        storage.reset()
        engine.add_event_handler(Events.EPOCH_COMPLETE, stop_function)
        state = run_three_epochs()
        assert storage.batch_lst == batches
        assert state.epoch == 1
        assert state.iteration == len(batches)

        # Stopping after the first iteration.
        storage.reset()
        engine.add_event_handler(Events.ITERATION_COMPLETE, stop_function)
        state = run_three_epochs()
        assert storage.batch_lst == [batches[0]]
        assert state.iteration == 1

        class CustomException(Exception):
            pass

        def exception_function(state):
            raise CustomException

        # An exception raised on START aborts before any step runs.
        storage.reset()
        engine.add_event_handler(Events.START, exception_function)
        with pytest.raises(CustomException):
            run_three_epochs()
        assert storage.batch_lst == []
        assert engine.state.iteration == 0
        assert engine.state.epoch == 0
Example #20
0
def test_engine(linear_argus_model_instance):
    """Return an Engine driven by the model's ``test_step``."""
    step_fn = linear_argus_model_instance.test_step
    return Engine(step_fn)
Example #21
0
def val_engine(linear_argus_model_instance):
    """Return an Engine driven by the model's ``val_step``."""
    step_fn = linear_argus_model_instance.val_step
    return Engine(step_fn)
Example #22
0
def train_engine(linear_argus_model_instance):
    """Return an Engine driven by the model's ``train_step``."""
    step_fn = linear_argus_model_instance.train_step
    return Engine(step_fn)