def main_loop(self, method: Method) -> ContinualResults:
    """ Runs a continual learning training loop, whether in RL or CL. """
    # TODO: Add ways of restoring state to continue a given run.
    if self.wandb and self.wandb.project:
        # Init wandb, and then log the setting's options.
        self.wandb_run = self.setup_wandb(method)
        method.setup_wandb(self.wandb_run)

    train_env = self.train_dataloader()
    valid_env = self.val_dataloader()

    logger.info("Starting training.")
    method.set_training()
    self._start_time = time.process_time()

    method.fit(
        train_env=train_env,
        valid_env=valid_env,
    )
    train_env.close()
    valid_env.close()

    logger.info("Finished training.")
    results: ContinualResults = self.test_loop(method)

    if self.monitor_training_performance:
        results._online_training_performance = train_env.get_online_performance()

    logger.info(f"Resulting objective of Test Loop: {results.objective}")

    self._end_time = time.process_time()
    runtime = self._end_time - self._start_time
    results._runtime = runtime
    logger.info(f"Finished main loop in {runtime} seconds.")
    self.log_results(method, results)
    return results
def main_loop(self, method: Method) -> IncrementalResults:
    """ Runs an incremental training loop, whether in RL or CL. """
    # TODO: Add ways of restoring state to continue a given run?

    # For each training task, for each test task, a list of the Metrics obtained
    # during testing on that task.
    # NOTE: We could also just store a single metric for each test task, but then
    # we'd lose the ability to create plots to show the performance within a test
    # task.
    # IDEA: We could use a list of IIDResults! (but that might cause some circular
    # import issues)
    results = self.Results()
    if self.monitor_training_performance:
        results._online_training_performance = []

    # TODO: Fix this up, need to set the '_objective_scaling_factor' to a different
    # value depending on the 'dataset' / environment.
    results._objective_scaling_factor = self._get_objective_scaling_factor()

    if self.wandb:
        # Init wandb, and then log the setting's options.
        self.wandb_run = self.setup_wandb(method)
        method.setup_wandb(self.wandb_run)

    method.set_training()
    self._start_time = time.process_time()

    for task_id in range(self.phases):
        logger.info(
            "Starting training"
            + (f" on task {task_id}." if self.nb_tasks > 1 else ".")
        )
        self.current_task_id = task_id
        self.task_boundary_reached(method, task_id=task_id, training=True)

        # Creating the dataloaders ourselves (rather than passing 'self' as
        # the datamodule):
        task_train_env = self.train_dataloader()
        task_valid_env = self.val_dataloader()
        method.fit(
            train_env=task_train_env,
            valid_env=task_valid_env,
        )
        task_train_env.close()
        task_valid_env.close()

        if self.monitor_training_performance:
            results._online_training_performance.append(
                task_train_env.get_online_performance()
            )

        logger.info(f"Finished training on task {task_id}.")
        test_metrics: TaskSequenceResults = self.test_loop(method)

        # Add a row to the transfer matrix.
        results.append(test_metrics)
        logger.info(f"Resulting objective of Test Loop: {test_metrics.objective}")

        if wandb.run:
            d = add_prefix(test_metrics.to_log_dict(), prefix="Test", sep="/")
            d["current_task"] = task_id
            wandb.log(d)

    self._end_time = time.process_time()
    runtime = self._end_time - self._start_time
    results._runtime = runtime
    logger.info(f"Finished main loop in {runtime} seconds.")
    self.log_results(method, results)
    return results
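
# --- Illustrative sketch (not part of the original code) ----------------------
# The incremental loop above appends one set of per-task test metrics after each
# training phase, which is what forms the "transfer matrix" mentioned in the
# comment: entry [i][j] is the test performance on task j after training on
# task i. The minimal, self-contained example below shows that idea with plain
# floats; `make_transfer_matrix` and the sample values are hypothetical and not
# part of this codebase.
from typing import List


def make_transfer_matrix(per_task_objectives: List[List[float]]) -> List[List[float]]:
    """Stack per-phase test objectives into a (training phases x test tasks) matrix."""
    return [list(row) for row in per_task_objectives]


if __name__ == "__main__":
    # Hypothetical objectives: row i holds the test results gathered after
    # training on task i (two test tasks in this toy example).
    rows = [
        [0.80, 0.10],  # after training on task 0
        [0.65, 0.85],  # after training on task 1
    ]
    transfer_matrix = make_transfer_matrix(rows)
    # Entry [1][0] being lower than [0][0] would indicate some forgetting of task 0.
    print(transfer_matrix)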