def validate(self,
             val_loader: Iterable,
             metrics: Optional[List[Union[Metric, str]]] = None,
             callbacks: Optional[List[Callback]] = None) -> Dict[str, float]:
    """Run one validation epoch over ``val_loader`` and return its metrics.

    A :class:`Loss` metric is always evaluated in addition to any
    user-supplied ``metrics``. The validation engine is run with start
    epoch -1 and end epoch 0, i.e. for a single epoch.

    Args:
        val_loader (Iterable): The validation data loader.
        metrics (list of :class:`argus.metrics.Metric` or str, optional):
            Extra metrics to evaluate on the data. Defaults to `None`.
        callbacks (list of :class:`argus.callbacks.Callback`, optional):
            Callbacks attached to the validation engine. Defaults to `None`.

    Returns:
        dict: The metrics dictionary of the finished validation run.
    """
    self._check_train_ready()
    extra_metrics = metrics if metrics is not None else []
    engine = Engine(self.val_step, phase_states={})
    attach_metrics(engine, [Loss()] + extra_metrics)
    default_logging.attach(engine)
    attach_callbacks(engine, callbacks)
    final_state = engine.run(val_loader, -1, 0)
    return final_state.metrics
def fit(self,
        train_loader: Iterable,
        val_loader: Optional[Iterable] = None,
        num_epochs: int = 1,
        metrics: Optional[List[Union[Metric, str]]] = None,
        metrics_on_train: bool = False,
        callbacks: Optional[List[Callback]] = None,
        val_callbacks: Optional[List[Callback]] = None) -> None:
    """Train the argus model.

    The method attaches metrics and callbacks to the train and validation
    process, and performs training itself. When ``val_loader`` is given,
    one validation pass is run before training starts (as epoch -1). A
    further validation pass runs at the end of every training epoch, and
    its metrics are merged into the training state's ``metrics``.

    Args:
        train_loader (Iterable): The train data loader.
        val_loader (Iterable, optional): The validation data loader.
            Defaults to `None`.
        num_epochs (int, optional): Number of training epochs to run.
            Defaults to 1.
        metrics (list of :class:`argus.metrics.Metric` or str, optional):
            List of metrics to evaluate. By default, the metrics are
            evaluated on the validation data (if any) only.
            Defaults to `None`.
        metrics_on_train (bool, optional): Evaluate the metrics on train
            data as well. Defaults to False. The loss is always tracked on
            the train data regardless of this flag.
        callbacks (list of :class:`argus.callbacks.Callback`, optional):
            List of callbacks to be attached to the training process.
            Defaults to `None`.
        val_callbacks (list of :class:`argus.callbacks.Callback`, optional):
            List of callbacks to be attached to the validation process.
            Defaults to `None`.
    """
    self._check_train_ready()
    metrics = [] if metrics is None else metrics
    # The same dict is handed to both engines; presumably this lets the
    # train and validation phases share per-phase state. TODO: confirm in Engine.
    phase_states = dict()
    train_engine = Engine(self.train_step, phase_states=phase_states)
    # Loss is always tracked on train; user metrics only when requested.
    train_metrics = [Loss()] + metrics if metrics_on_train else [Loss()]
    attach_metrics(train_engine, train_metrics)
    default_logging.attach(train_engine)

    if val_loader is not None:
        val_engine = Engine(self.val_step, phase_states=phase_states)
        attach_metrics(val_engine, [Loss()] + metrics)
        default_logging.attach(val_engine)
        attach_callbacks(val_engine, val_callbacks)

        @argus.callbacks.on_epoch_complete
        def validation_epoch(train_state, val_engine, val_loader):
            # Validate under the same epoch number as the train epoch that
            # just finished, then copy the validation metrics onto the
            # train state.
            epoch = train_state.epoch
            val_state = val_engine.run(val_loader, epoch, epoch + 1)
            train_state.metrics.update(val_state.metrics)

        # NOTE: attached before the user `callbacks` below. If handlers fire
        # in attach order (TODO: confirm against Engine), the train callbacks'
        # epoch-complete handlers already see the merged validation metrics.
        validation_epoch.attach(train_engine, val_engine, val_loader)
        # Baseline validation pass before any training step (epoch -1).
        val_engine.run(val_loader, -1, 0)

    attach_callbacks(train_engine, callbacks)
    train_engine.run(train_loader, 0, num_epochs)
def test_on_decorators(self, step_storage, linear_argus_model_instance):
    """Each event decorator wires its function to the matching engine event.

    A clean run must trigger every callback except the exception handler.
    A run whose ``on_start`` handler raises must trigger the exception
    handler.
    """
    @argus.callbacks.on_start
    def on_start_function(state):
        # Expected to be the first event of a run, so it initialises the counter.
        state.call_count = 1
        state.on_start_flag = True

    @argus.callbacks.on_complete
    def on_complete_function(state):
        state.call_count += 1
        state.on_complete_flag = True

    @argus.callbacks.on_epoch_start
    def on_epoch_start_function(state):
        state.call_count += 1
        state.on_epoch_start_flag = True

    @argus.callbacks.on_epoch_complete
    def on_epoch_complete_function(state):
        state.call_count += 1
        state.on_epoch_complete_flag = True

    @argus.callbacks.on_iteration_start
    def on_iteration_start_function(state):
        state.call_count += 1
        state.on_iteration_start_flag = True

    @argus.callbacks.on_iteration_complete
    def on_iteration_complete_function(state):
        state.call_count += 1
        state.on_iteration_complete_flag = True

    @argus.callbacks.on_catch_exception
    def on_catch_exception_function(state):
        state.call_count += 1
        state.on_catch_exception_flag = True

    engine = Engine(step_storage.step_method,
                    model=linear_argus_model_instance)
    event_callbacks = [
        on_start_function,
        on_complete_function,
        on_epoch_start_function,
        on_epoch_complete_function,
        on_iteration_start_function,
        on_iteration_complete_function,
        on_catch_exception_function,
    ]
    attach_callbacks(engine, event_callbacks)

    data_loader = [4, 8, 15, 16, 23, 42]
    num_epochs = 3
    state = engine.run(data_loader, start_epoch=0, end_epoch=num_epochs)

    # Iteration start + complete per batch per epoch, epoch start + complete
    # per epoch, and run start + complete once.
    expected_calls = len(data_loader) * num_epochs * 2 + num_epochs * 2 + 2
    assert state.call_count == expected_calls
    for flag_name in ('on_start_flag',
                      'on_complete_flag',
                      'on_epoch_start_flag',
                      'on_epoch_complete_flag',
                      'on_iteration_start_flag',
                      'on_iteration_complete_flag'):
        assert getattr(state, flag_name)
    assert not hasattr(state, 'on_catch_exception_flag')

    class StartError(Exception):
        pass

    @argus.callbacks.on_start
    def on_start_raise_exception(state):
        raise StartError

    on_start_raise_exception.attach(engine)
    with pytest.raises(StartError):
        engine.run(data_loader, start_epoch=0, end_epoch=num_epochs)
    assert engine.state.on_catch_exception_flag
def test_attach_not_a_callback(self, test_engine):
    """Attaching anything that is not a callback must raise TypeError."""
    for not_a_callback in (None, test_engine):
        with pytest.raises(TypeError):
            attach_callbacks(test_engine, [not_a_callback])