def test_custom_events(self):
    """Custom events raised in a step function reach the matching callback methods.

    The step function raises STEP_START with the raw batch exposed on the
    state, then STEP_COMPLETE with the batch incremented by one; the
    callback records what it sees at each event.
    """
    class CustomEvents(EventEnum):
        STEP_START = 'step_start'
        STEP_COMPLETE = 'step_complete'

    def step_function(batch, state):
        state.step_output = batch
        state.engine.raise_event(CustomEvents.STEP_START)
        state.step_output += 1
        state.engine.raise_event(CustomEvents.STEP_COMPLETE)

    class CustomCallback(Callback):
        def __init__(self):
            self.start_storage = []
            self.end_storage = []

        def step_start(self, state):
            self.start_storage.append(state.step_output)

        def step_complete(self, state):
            self.end_storage.append(state.step_output)

    batches = [4, 8, 15, 16, 23, 42]
    recorder = CustomCallback()
    step_engine = Engine(step_function)
    _attach_callbacks(step_engine, [recorder])
    step_engine.run(batches)

    assert recorder.start_storage == batches
    expected_end = [value + 1 for value in batches]
    assert recorder.end_storage == expected_end
def test_custom_metric(self, engine):
    """Check CustomMetric accumulation, reset, phase-prefixed naming and early stop."""
    metric = CustomMetric()
    batches = [4, 8, 15, 16, 23, 42]
    expected_count = len(batches)

    _attach_metrics(engine, [metric])
    # `None` is not an acceptable metric entry.
    with pytest.raises(TypeError):
        _attach_metrics(engine, [None])

    run_state = engine.run(batches)
    assert metric.data == batches
    assert metric.compute() == expected_count
    assert run_state.metrics == {"custom_metric": expected_count}

    metric.reset()
    assert metric.data == []
    assert metric.compute() == 0

    # On an engine with phase='train' the reported name gains a "train_" prefix.
    train_engine = Engine(lambda batch, state: batch, phase='train')
    _attach_metrics(train_engine, [metric])
    run_state = train_engine.run(batches)
    assert metric.compute() == expected_count
    assert run_state.metrics == {"train_custom_metric": expected_count}

    @argus.callbacks.on_iteration_start
    def stop_on_first_iteration(state):
        state.stopped = True

    # Stopping at the first iteration leaves exactly one batch accumulated.
    stop_on_first_iteration.attach(train_engine)
    train_engine.run(batches)
    assert metric.compute() == 1
def fit(self, train_loader, val_loader=None, max_epochs=1, metrics=None,
        metrics_on_train=False, callbacks=None, val_callbacks=None):
    """Train the model for ``max_epochs`` epochs.

    A Loss metric is always tracked on the train engine (prefixed "train_").
    The user ``metrics`` join it only when ``metrics_on_train`` is set.
    When ``val_loader`` is given, ``self.validate`` runs once before
    training. After that, a validation engine (metrics prefixed "val_") runs
    after every train epoch and its metrics are merged into the train state.
    """
    if metrics is None:
        metrics = []
    assert self.train_ready()
    setup_logging()

    train_engine = Engine(self.train_step, model=self, logger=self.logger)
    if metrics_on_train:
        train_metrics = [Loss()] + metrics
    else:
        train_metrics = [Loss()]
    _attach_metrics(train_engine, train_metrics, name_prefix='train_')
    metrics_logging.attach(train_engine, train=True)

    if val_loader is not None:
        self.validate(val_loader, metrics, val_callbacks)

        val_engine = Engine(self.val_step, model=self, logger=self.logger)
        _attach_metrics(val_engine, [Loss()] + metrics, name_prefix='val_')
        _attach_callbacks(val_engine, val_callbacks)

        @on_epoch_complete
        def validation_epoch(train_state, val_engine, val_loader):
            # Run exactly one validation epoch labelled with the train epoch.
            epoch = train_state.epoch
            val_state = val_engine.run(val_loader, epoch, epoch + 1)
            train_state.metrics.update(val_state.metrics)

        validation_epoch.attach(train_engine, val_engine, val_loader)
        metrics_logging.attach(train_engine, train=False)

    _attach_callbacks(train_engine, callbacks)
    train_engine.run(train_loader, 0, max_epochs)
def fit(self,
        train_loader: Iterable,
        val_loader: Optional[Iterable] = None,
        num_epochs: int = 1,
        metrics: Optional[List[Union[Metric, str]]] = None,
        metrics_on_train: bool = False,
        callbacks: Optional[List[Callback]] = None,
        val_callbacks: Optional[List[Callback]] = None):
    """Train the argus model.

    The method attaches metrics and callbacks to the train and validation
    engines, and runs the training process. A Loss metric is always
    evaluated. If a validation loader is given, one validation pass is run
    before training, and another after every training epoch. The resulting
    validation metrics are merged into the train state's metrics.

    Args:
        train_loader (Iterable): The train data loader.
        val_loader (Iterable, optional): The validation data loader.
            Defaults to `None`.
        num_epochs (int, optional): Number of training epochs to run.
            Defaults to 1.
        metrics (list of :class:`argus.metrics.Metric` or str, optional):
            List of metrics to evaluate. By default, the metrics are
            evaluated on the validation data (if any) only.
            Defaults to `None`.
        metrics_on_train (bool, optional): Evaluate the metrics on train
            data as well. Defaults to False.
        callbacks (list of :class:`argus.callbacks.Callback`, optional):
            List of callbacks to be attached to the training process.
            Defaults to `None`.
        val_callbacks (list of :class:`argus.callbacks.Callback`, optional):
            List of callbacks to be attached to the validation process.
            Defaults to `None`.
    """
    self._check_train_ready()
    metrics = [] if metrics is None else metrics
    train_engine = Engine(self.train_step, model=self, logger=self.logger,
                          phase='train')
    # NOTE(review): with metrics_on_train, the same Metric objects in
    # `metrics` are attached to both the train and val engines.
    train_metrics = ([Loss()] + metrics) if metrics_on_train else [Loss()]
    _attach_metrics(train_engine, train_metrics)
    metrics_logging.attach(train_engine, train=True)

    if val_loader is not None:
        # Initial validation pass before the first training epoch.
        self.validate(val_loader, metrics, val_callbacks)
        val_engine = Engine(self.val_step, model=self, logger=self.logger,
                            phase='val')
        _attach_metrics(val_engine, [Loss()] + metrics)
        _attach_callbacks(val_engine, val_callbacks)

        @on_epoch_complete
        def validation_epoch(train_state, val_engine, val_loader):
            # Run one validation epoch at the current train epoch and merge
            # its metrics into the train state.
            epoch = train_state.epoch
            val_state = val_engine.run(val_loader, epoch, epoch + 1)
            train_state.metrics.update(val_state.metrics)

        validation_epoch.attach(train_engine, val_engine, val_loader)
        metrics_logging.attach(train_engine, train=False)

    _attach_callbacks(train_engine, callbacks)
    train_engine.run(train_loader, 0, num_epochs)
def test_run(self):
    """Exercise Engine.run: a full run, early stops, and an exception on START."""
    class StepStorage:
        def __init__(self):
            self.reset()

        def reset(self):
            self.batch_lst = []
            self.state = None

        def step_method(self, batch, state):
            self.batch_lst.append(batch)
            self.state = state

    storage = StepStorage()
    batches = [4, 8, 15, 16, 23, 42]
    num_batches = len(batches)
    engine = Engine(
        storage.step_method,
        logger=logging.getLogger('TestEngineMethods::test_run'))

    def run_three_epochs():
        return engine.run(batches, start_epoch=0, end_epoch=3)

    # Unrestricted run: every batch of every epoch reaches the step method.
    state = run_three_epochs()
    assert storage.batch_lst == batches * 3
    assert state.epoch == 3
    assert state.iteration == num_batches

    def stop_function(state):
        state.stopped = True

    # Stop at the end of the first epoch.
    storage.reset()
    engine.add_event_handler(Events.EPOCH_COMPLETE, stop_function)
    state = run_three_epochs()
    assert storage.batch_lst == batches
    assert state.epoch == 1
    assert state.iteration == num_batches

    # Stop after the first iteration (the epoch handler is still attached).
    storage.reset()
    engine.add_event_handler(Events.ITERATION_COMPLETE, stop_function)
    state = run_three_epochs()
    assert storage.batch_lst == [batches[0]]
    assert state.iteration == 1

    class CustomException(Exception):
        pass

    def exception_function(state):
        raise CustomException

    # An exception from a START handler propagates before any step runs.
    storage.reset()
    engine.add_event_handler(Events.START, exception_function)
    with pytest.raises(CustomException):
        run_three_epochs()
    assert storage.batch_lst == []
    assert engine.state.iteration == 0
    assert engine.state.epoch == 0
def validate(
        self,
        val_loader: Iterable,
        metrics: Optional[List[Union[Metric, str]]] = None,
        callbacks: Optional[List[Callback]] = None) -> Dict[str, float]:
    """Run a single validation pass and return the collected metrics.

    A Loss metric is always evaluated in addition to ``metrics``.

    Args:
        val_loader (Iterable): The validation data loader.
        metrics (list of :class:`argus.metrics.Metric` or str, optional):
            List of metrics to evaluate with the data. Defaults to `None`.
        callbacks (list of :class:`argus.callbacks.Callback`, optional):
            List of callbacks to be attached to the validation process.
            Defaults to `None`.

    Returns:
        dict: The metrics dictionary.
    """
    self._check_train_ready()
    if metrics is None:
        metrics = []

    val_engine = Engine(self.val_step, phase_states={})
    attach_metrics(val_engine, [Loss()] + metrics)
    default_logging.attach(val_engine)
    attach_callbacks(val_engine, callbacks)

    # start_epoch=-1, end_epoch=0: a single epoch.
    final_state = val_engine.run(val_loader, -1, 0)
    return final_state.metrics
def validate(self,
             val_loader: Iterable,
             metrics: Optional[List[Union[Metric, str]]] = None,
             callbacks: Optional[List[Callback]] = None) -> Dict[str, float]:
    """Perform a validation.

    Runs the validation engine over ``val_loader`` once. A Loss metric is
    always evaluated in addition to ``metrics``.

    Args:
        val_loader (Iterable): The validation data loader.
        metrics (list of :class:`argus.metrics.Metric` or str, optional):
            List of metrics to evaluate with the data. Defaults to `None`.
        callbacks (list of :class:`argus.callbacks.Callback`, optional):
            List of callbacks to be attached to the validation process.
            Defaults to `None`.

    Returns:
        dict: The metrics dictionary.
    """
    self._check_train_ready()
    metrics = [] if metrics is None else metrics
    val_engine = Engine(self.val_step, model=self, logger=self.logger,
                        phase='val')
    _attach_metrics(val_engine, [Loss()] + metrics)
    _attach_callbacks(val_engine, callbacks)
    metrics_logging.attach(val_engine, train=False, print_epoch=False)
    return val_engine.run(val_loader).metrics
def validate(self, val_loader, metrics=None, callbacks=None):
    """Run one validation pass over ``val_loader`` and return its metrics.

    A Loss metric is always evaluated alongside ``metrics``. All metric names
    get the "val_" prefix, and ``callbacks`` are attached to the validation
    engine.
    """
    if metrics is None:
        metrics = []
    assert self.train_ready()

    engine = Engine(self.val_step, model=self, logger=self.logger)
    _attach_metrics(engine, [Loss()] + metrics, name_prefix='val_')
    _attach_callbacks(engine, callbacks)
    metrics_logging.attach(engine, train=False, print_epoch=False)

    final_state = engine.run(val_loader)
    return final_state.metrics
def test_custom_callback_by_name(self):
    """A metric can be attached by its name string.

    An unrecognized name ("qwerty") raises ValueError.
    NOTE(review): despite the test name, this exercises metrics, not callbacks.
    """
    batches = [4, 8, 15, 16, 23, 42]
    val_engine = Engine(lambda batch, state: batch, phase='val')
    _attach_metrics(val_engine, ["custom_metric"])

    final_state = val_engine.run(batches)
    expected = {"val_custom_metric": len(batches)}
    assert final_state.metrics == expected

    with pytest.raises(ValueError):
        _attach_metrics(val_engine, ["qwerty"])
def test_attach_callback(self, n_epochs, custom_test_callback, step_storage):
    """Every callback hook fires the expected number of times over a run."""
    batches = [4, 8, 15, 16, 23, 42]
    engine = Engine(step_storage.step_method)
    custom_test_callback.attach(engine)
    engine.run(batches, start_epoch=0, end_epoch=n_epochs)

    cb = custom_test_callback
    total_iterations = n_epochs * len(batches)

    # Run-level hooks fire once.
    assert cb.start_count == 1
    assert cb.complete_count == 1
    # Epoch-level hooks fire once per epoch.
    assert cb.epoch_start_count == n_epochs
    assert cb.epoch_complete_count == n_epochs
    # Iteration- and step-level hooks fire once per batch per epoch.
    assert cb.iteration_start_count == total_iterations
    assert cb.iteration_complete_count == total_iterations
    assert cb.catch_exception_count == 0
    assert cb.step_start_count == total_iterations
    assert cb.step_complete_count == total_iterations

    assert step_storage.batch_lst == n_epochs * batches
    if n_epochs:
        assert step_storage.state is engine.state
    else:
        # With zero epochs no step ran, so no state was stored.
        assert step_storage.state is None
def test_on_event(self, step_storage):
    """A handler built with ``on_event(Events.START)`` runs when the engine runs."""
    @argus.callbacks.on_event(Events.START)
    def some_function(state):
        state.special_secret = 42

    engine = Engine(step_storage.step_method)
    some_function.attach(engine)

    final_state = engine.run([4, 8, 15, 16, 23, 42])
    assert final_state.special_secret == 42
def test_custom_events(self, linear_net_class):
    """Custom events raised from a model's step method reach callback methods.

    The model's ``count_step`` raises STEP_START with the raw batch exposed on
    the state, then STEP_COMPLETE with the batch incremented by one.
    """
    class CustomEvents(EventEnum):
        STEP_START = 'step_start'
        STEP_COMPLETE = 'step_complete'

    class CustomEventsModel(argus.Model):
        nn_module = linear_net_class

        def count_step(self, batch, state):
            state.step_output = batch
            state.engine.raise_event(CustomEvents.STEP_START)
            state.step_output += 1
            state.engine.raise_event(CustomEvents.STEP_COMPLETE)

    model_params = {
        'nn_module': {
            'in_features': 10,
            'out_features': 1,
        },
        'optimizer': None,
        'loss': None,
    }
    model = CustomEventsModel(model_params)

    class CustomCallback(Callback):
        def __init__(self):
            self.start_storage = []
            self.end_storage = []

        def step_start(self, state):
            self.start_storage.append(state.step_output)

        def step_complete(self, state):
            self.end_storage.append(state.step_output)

    batches = [4, 8, 15, 16, 23, 42]
    recorder = CustomCallback()
    step_engine = Engine(model.count_step)
    attach_callbacks(step_engine, [recorder])
    step_engine.run(batches)

    assert recorder.start_storage == batches
    assert recorder.end_storage == [value + 1 for value in batches]
def fit(self, train_loader, val_loader=None, max_epochs=1, metrics=None,
        val_event_handlers=None, train_event_handlers=None):
    """Train the model, optionally evaluating ``metrics`` on validation data.

    Args:
        train_loader: The train data loader.
        val_loader: Optional validation data loader. Without it, ``metrics``
            are not attached anywhere.
        max_epochs: Passed to ``train_engine.run`` as the epoch count.
        metrics: Optional mapping of name -> metric, each attached to the
            validation engine under that name. A ``'val_loss'`` entry
            (``Loss(self.loss)``) is added if missing. The caller's mapping
            is copied and never modified.
        val_event_handlers: NOTE(review): accepted but unused by this method.
        train_event_handlers: NOTE(review): accepted but unused by this method.
    """
    assert self.train_ready()
    setup_logging()
    # Work on a copy: adding 'val_loss' below must not mutate the caller's
    # dict (the previous code leaked the injected Loss back to the caller).
    metrics = {} if metrics is None else dict(metrics)

    train_engine = Engine(self._train_step)
    train_loss = TrainLoss()
    train_loss.attach(train_engine, 'train_loss')
    train_engine.add_event_handler(Events.EPOCH_COMPLETE, train_loss_logging)

    if val_loader is not None:
        val_engine = Engine(self._val_step)
        if 'val_loss' not in metrics:
            metrics['val_loss'] = Loss(self.loss)
        for name, metric in metrics.items():
            metric.attach(val_engine, name)
        # Called once directly before training (presumably an initial
        # validation pass — TODO confirm), then again on every EPOCH_COMPLETE.
        validation_logging(train_engine, val_engine, val_loader)
        train_engine.add_event_handler(Events.EPOCH_COMPLETE,
                                       validation_logging,
                                       val_engine, val_loader)

    train_engine.run(train_loader, max_epochs)
def test_on_decorators(self, step_storage, linear_argus_model_instance):
    """Each ``on_*`` decorator fires at its event.

    ``on_catch_exception`` fires only when a handler raises.
    """
    @argus.callbacks.on_start
    def on_start_function(state):
        state.call_count = 1
        state.on_start_flag = True

    @argus.callbacks.on_complete
    def on_complete_function(state):
        state.call_count += 1
        state.on_complete_flag = True

    @argus.callbacks.on_epoch_start
    def on_epoch_start_function(state):
        state.call_count += 1
        state.on_epoch_start_flag = True

    @argus.callbacks.on_epoch_complete
    def on_epoch_complete_function(state):
        state.call_count += 1
        state.on_epoch_complete_flag = True

    @argus.callbacks.on_iteration_start
    def on_iteration_start_function(state):
        state.call_count += 1
        state.on_iteration_start_flag = True

    @argus.callbacks.on_iteration_complete
    def on_iteration_complete_function(state):
        state.call_count += 1
        state.on_iteration_complete_flag = True

    @argus.callbacks.on_catch_exception
    def on_catch_exception_function(state):
        state.call_count += 1
        state.on_catch_exception_flag = True

    engine = Engine(step_storage.step_method,
                    model=linear_argus_model_instance)
    handlers = [
        on_start_function,
        on_complete_function,
        on_epoch_start_function,
        on_epoch_complete_function,
        on_iteration_start_function,
        on_iteration_complete_function,
        on_catch_exception_function,
    ]
    attach_callbacks(engine, handlers)

    batches = [4, 8, 15, 16, 23, 42]
    n_epochs = 3
    state = engine.run(batches, start_epoch=0, end_epoch=n_epochs)

    # Two iteration hooks per batch per epoch, two epoch hooks per epoch,
    # plus on_start (which sets the count to 1) and on_complete.
    expected_calls = len(batches) * n_epochs * 2 + n_epochs * 2 + 2
    assert state.call_count == expected_calls
    for flag_name in ('on_start_flag',
                      'on_complete_flag',
                      'on_epoch_start_flag',
                      'on_epoch_complete_flag',
                      'on_iteration_start_flag',
                      'on_iteration_complete_flag'):
        assert getattr(state, flag_name)
    assert not hasattr(state, 'on_catch_exception_flag')

    class CustomException(Exception):
        pass

    @argus.callbacks.on_start
    def on_start_raise_exception(state):
        raise CustomException

    on_start_raise_exception.attach(engine)
    with pytest.raises(CustomException):
        engine.run(batches, start_epoch=0, end_epoch=n_epochs)
    assert engine.state.on_catch_exception_flag