def fit(self, train_loader, val_loader=None, max_epochs=1, metrics=None,
        metrics_on_train=False, callbacks=None, val_callbacks=None):
    """Train the model, optionally validating after every training epoch.

    Args:
        train_loader: Iterable of training batches consumed by the train engine.
        val_loader: Optional iterable of validation batches. When given, one
            validation pass is run via ``self.validate`` before training, and
            a validation engine is run at the end of every training epoch.
            Its metrics are merged into the training state's metrics.
        max_epochs (int): Upper bound passed to ``train_engine.run`` as the
            end epoch (training starts at epoch 0). Defaults to 1.
        metrics (sequence of Metric, optional): Metrics evaluated on the
            validation data, and also on the train data when
            ``metrics_on_train`` is set. A ``Loss`` metric is always
            evaluated in addition. Any sequence (e.g. a tuple) is accepted.
            Defaults to None (no extra metrics).
        metrics_on_train (bool): Also attach ``metrics`` to the train engine.
            Defaults to False.
        callbacks (list, optional): Callbacks attached to the train engine.
        val_callbacks (list, optional): Callbacks attached to the validation
            engines (both the pre-training ``validate`` run and the per-epoch
            validation engine).

    Raises:
        AssertionError: If ``self.train_ready()`` returns a falsy value.
    """
    # Normalise once so every downstream concatenation and the forwarded
    # ``validate`` call all receive a real list.
    metrics = [] if metrics is None else list(metrics)
    # Explicit check rather than an ``assert`` statement. An assert is stripped
    # under ``python -O``, and training would then start on an unready model.
    # AssertionError is kept so callers that already catch it still work.
    if not self.train_ready():
        raise AssertionError(
            "Model is not ready for training: train_ready() returned False")
    setup_logging()

    train_engine = Engine(self.train_step, model=self, logger=self.logger)
    train_metrics = [Loss()] + metrics if metrics_on_train else [Loss()]
    _attach_metrics(train_engine, train_metrics, name_prefix='train_')
    metrics_logging.attach(train_engine, train=True)

    if val_loader is not None:
        # Baseline validation pass before the first training epoch.
        self.validate(val_loader, metrics, val_callbacks)
        val_engine = Engine(self.val_step, model=self, logger=self.logger)
        _attach_metrics(val_engine, [Loss()] + metrics, name_prefix='val_')
        _attach_callbacks(val_engine, val_callbacks)

        @on_epoch_complete
        def validation_epoch(train_state, val_engine, val_loader):
            epoch = train_state.epoch
            # Start/end arguments mirror the (0, max_epochs) training call,
            # presumably running exactly one validation epoch.
            val_state = val_engine.run(val_loader, epoch, epoch + 1)
            train_state.metrics.update(val_state.metrics)

        validation_epoch.attach(train_engine, val_engine, val_loader)
        metrics_logging.attach(train_engine, train=False)

    _attach_callbacks(train_engine, callbacks)
    train_engine.run(train_loader, 0, max_epochs)
def validate(self,
             val_loader: Iterable,
             metrics: Optional[List[Union[Metric, str]]] = None,
             callbacks: Optional[List[Callback]] = None) -> Dict[str, float]:
    """Perform a validation.

    Builds a ``'val'``-phase engine around ``self.val_step``, attaches a
    ``Loss`` metric plus ``metrics`` and the given callbacks, runs it once
    over ``val_loader`` and returns the resulting metrics.

    Args:
        val_loader (Iterable): The validation data loader. It is required;
            the engine is run directly over it.
        metrics (list of :class:`argus.metrics.Metric` or str, optional):
            Metrics to evaluate on the data, in addition to the loss. Items
            are forwarded to ``_attach_metrics`` unchanged, so the same kinds
            of values ``fit`` accepts may be passed. Any sequence is
            accepted. Defaults to `None`.
        callbacks (list of :class:`argus.callbacks.Callback`, optional): List
            of callbacks to be attached to the validation process. Defaults
            to `None`.

    Returns:
        dict: The metrics dictionary of the validation run.
    """
    self._check_train_ready()
    # Normalise to a list so ``[Loss()] + metrics`` also works for tuples.
    metrics = [] if metrics is None else list(metrics)
    val_engine = Engine(self.val_step, model=self, logger=self.logger,
                        phase='val')
    _attach_metrics(val_engine, [Loss()] + metrics)
    _attach_callbacks(val_engine, callbacks)
    # Log the metrics without a per-epoch header (``print_epoch=False``).
    metrics_logging.attach(val_engine, train=False, print_epoch=False)
    return val_engine.run(val_loader).metrics
def fit(self,
        train_loader: Iterable,
        val_loader: Optional[Iterable] = None,
        num_epochs: int = 1,
        metrics: Optional[List[Union[Metric, str]]] = None,
        metrics_on_train: bool = False,
        callbacks: Optional[List[Callback]] = None,
        val_callbacks: Optional[List[Callback]] = None):
    """Train the argus model for ``num_epochs`` epochs.

    A ``'train'``-phase engine is built around ``self.train_step`` with a
    ``Loss`` metric, plus ``metrics`` when ``metrics_on_train`` is enabled.

    If ``val_loader`` is supplied, three extra things happen:

    * One validation pass is run through :meth:`validate` before training.
    * A ``'val'``-phase engine is built.
    * That engine is run at the end of every training epoch, and its metrics
      are merged into the training state's metrics.

    Args:
        train_loader (Iterable): Source of training batches.
        val_loader (Iterable, optional): Source of validation batches. When
            omitted, no validation is performed. Defaults to `None`.
        num_epochs (int, optional): How many training epochs to run.
            Defaults to 1.
        metrics (list of :class:`argus.metrics.Metric`, optional): Metrics
            to compute. They are computed on validation data only, unless
            ``metrics_on_train`` asks for train data too. Defaults to `None`.
        metrics_on_train (bool, optional): Whether to compute ``metrics`` on
            the train data as well. Defaults to False.
        callbacks (list of :class:`argus.callbacks.Callback`, optional):
            Callbacks hooked into the training engine. Defaults to `None`.
        val_callbacks (list of :class:`argus.callbacks.Callback`, optional):
            Callbacks hooked into the validation engine(s). Defaults to
            `None`.
    """
    self._check_train_ready()
    if metrics is None:
        metrics = []

    trainer = Engine(self.train_step, model=self, logger=self.logger,
                     phase='train')
    if metrics_on_train:
        train_metrics = [Loss()] + metrics
    else:
        train_metrics = [Loss()]
    _attach_metrics(trainer, train_metrics)
    metrics_logging.attach(trainer, train=True)

    if val_loader is not None:
        # Evaluate once before any training step so a baseline is logged.
        self.validate(val_loader, metrics, val_callbacks)

        evaluator = Engine(self.val_step, model=self, logger=self.logger,
                           phase='val')
        _attach_metrics(evaluator, [Loss()] + metrics)
        _attach_callbacks(evaluator, val_callbacks)

        @on_epoch_complete
        def validation_epoch(train_state, val_engine, val_loader):
            # Uses (epoch, epoch + 1) as start/end, mirroring the
            # (0, num_epochs) training call. Presumably this is one epoch.
            current = train_state.epoch
            outcome = val_engine.run(val_loader, current, current + 1)
            train_state.metrics.update(outcome.metrics)

        validation_epoch.attach(trainer, evaluator, val_loader)
        metrics_logging.attach(trainer, train=False)

    _attach_callbacks(trainer, callbacks)
    trainer.run(train_loader, 0, num_epochs)
def validate(self, val_loader, metrics=None, callbacks=None):
    """Run a single validation pass and return its metrics.

    Args:
        val_loader: Iterable of validation batches; the engine is run over it.
        metrics (sequence of Metric, optional): Extra metrics to evaluate. A
            ``Loss`` metric is always evaluated first. Any sequence (e.g. a
            tuple) is accepted. Defaults to None.
        callbacks (list, optional): Callbacks attached to the validation
            engine. Defaults to None.

    Returns:
        The ``metrics`` of the state returned by ``val_engine.run``. The
        metrics are attached with ``name_prefix='val_'``.

    Raises:
        AssertionError: If ``self.train_ready()`` returns a falsy value.
    """
    # Normalise to a list so ``[Loss()] + metrics`` also works for tuples.
    metrics = [] if metrics is None else list(metrics)
    # Explicit check rather than an ``assert`` statement, which ``python -O``
    # would strip. AssertionError is kept for existing callers.
    if not self.train_ready():
        raise AssertionError(
            "Model is not ready for validation: train_ready() returned False")
    val_engine = Engine(self.val_step, model=self, logger=self.logger)
    _attach_metrics(val_engine, [Loss()] + metrics, name_prefix='val_')
    _attach_callbacks(val_engine, callbacks)
    metrics_logging.attach(val_engine, train=False, print_epoch=False)
    return val_engine.run(val_loader).metrics