def _search_for_prev_state(path, extensions=None):
    """
    Look for checkpoint files of previous epochs inside a directory.

    Thin wrapper that fills in the default checkpoint extension and then
    delegates the actual search to
    ``BaseNetworkTrainer._search_for_prev_state``.

    Parameters
    ----------
    path : str
        the directory to search in
    extensions : list, optional
        list of strings containing the file extensions that identify
        checkpoint files; falls back to ``[".pkl"]`` if None is given

    Returns
    -------
    str
        the file containing the latest checkpoint (if available),
        None if no latest checkpoint was found
    int
        the latest epoch (1 if no checkpoint was found)

    """
    # Resolve the default here rather than in the signature to avoid a
    # mutable default argument.
    valid_extensions = [".pkl"] if extensions is None else extensions
    return BaseNetworkTrainer._search_for_prev_state(path, valid_extensions)
def calc_metrics(batch, metrics: dict = None, metric_keys=None):
    """
    Compute metrics for a batch via ``BaseNetworkTrainer.calc_metrics``.

    Thin wrapper that fills in defaults for ``metrics`` and
    ``metric_keys`` before delegating.

    Parameters
    ----------
    batch :
        the batch to evaluate; forwarded unchanged to the base
        implementation
    metrics : dict, optional
        mapping of metric names to metrics; an empty dict is used if
        None is given
    metric_keys : dict, optional
        mapping of metric names to the keys used for that metric; if
        None, every metric in ``metrics`` is mapped to ``("pred", "y")``
        (presumably the batch entries holding prediction and ground
        truth -- confirm against the base implementation)

    Returns
    -------
    whatever ``BaseNetworkTrainer.calc_metrics`` returns for the given
    arguments

    """
    metric_fns = {} if metrics is None else metrics

    if metric_keys is None:
        # Every metric gets the same default key pair.
        metric_keys = dict.fromkeys(metric_fns.keys(), ("pred", "y"))

    return BaseNetworkTrainer.calc_metrics(batch, metric_fns, metric_keys)