Exemplo n.º 1
0
    def main_loop(self, method: Method) -> ContinualResults:
        """Run a continual learning training loop, whether in RL or CL.

        Sets up wandb (when a wandb project is configured), fits `method` on the
        environments returned by `self.train_dataloader()` and
        `self.val_dataloader()`, then evaluates it with `self.test_loop`.

        Args:
            method: The Method to train and evaluate.

        Returns:
            The ContinualResults returned by `self.test_loop(method)`, with the
            measured runtime attached and, when
            `self.monitor_training_performance` is set, the online training
            performance of the train environment.
        """
        # TODO: Add ways of restoring state to continue a given run.
        if self.wandb and self.wandb.project:
            # Init wandb, and then log the setting's options.
            self.wandb_run = self.setup_wandb(method)
            method.setup_wandb(self.wandb_run)

        train_env = self.train_dataloader()
        valid_env = self.val_dataloader()

        logger.info("Starting training")
        method.set_training()
        # NOTE(review): `time.process_time()` is the CPU time of this process
        # only (excludes sleeps and child processes), not wall-clock time.
        # Confirm this is the intended meaning of `results._runtime`.
        self._start_time = time.process_time()

        try:
            method.fit(
                train_env=train_env,
                valid_env=valid_env,
            )
        finally:
            # Release both environments even if `fit` (or the first `close`)
            # raises.
            try:
                train_env.close()
            finally:
                valid_env.close()

        logger.info("Finished Training.")
        results: ContinualResults = self.test_loop(method)

        if self.monitor_training_performance:
            # Read after `close()`; presumably the env keeps its recorded
            # metrics once closed — verify against the env wrapper.
            results._online_training_performance = (
                train_env.get_online_performance()
            )

        logger.info(f"Resulting objective of Test Loop: {results.objective}")

        # The runtime spans both training and the test loop above.
        self._end_time = time.process_time()
        runtime = self._end_time - self._start_time
        results._runtime = runtime
        logger.info(f"Finished main loop in {runtime} seconds.")
        self.log_results(method, results)
        return results
Exemplo n.º 2
0
    def main_loop(self, method: Method) -> IncrementalResults:
        """Run an incremental training loop, whether in RL or CL.

        For each training phase in `range(self.phases)`, this:
        - notifies the setting of the task boundary;
        - fits `method` on fresh train/validation environments;
        - evaluates the method with `self.test_loop`, appending the result as
          a new row of the results' transfer matrix.

        Args:
            method: The Method to train and evaluate.

        Returns:
            The IncrementalResults accumulated over all phases, with the
            measured runtime attached and, when
            `self.monitor_training_performance` is set, one online training
            performance entry per phase.
        """
        # TODO: Add ways of restoring state to continue a given run?
        # For each training task, for each test task, a list of the Metrics obtained
        # during testing on that task.
        # NOTE: We could also just store a single metric for each test task, but then
        # we'd lose the ability to create a plots to show the performance within a test
        # task.
        # IDEA: We could use a list of IIDResults! (but that might cause some circular
        # import issues)
        results = self.Results()
        if self.monitor_training_performance:
            results._online_training_performance = []

        # TODO: Fix this up, need to set the '_objective_scaling_factor' to a different
        # value depending on the 'dataset' / environment.
        results._objective_scaling_factor = self._get_objective_scaling_factor()

        if self.wandb:
            # Init wandb, and then log the setting's options.
            self.wandb_run = self.setup_wandb(method)
            method.setup_wandb(self.wandb_run)

        method.set_training()

        # NOTE(review): `time.process_time()` is the CPU time of this process
        # only (excludes sleeps and child processes), not wall-clock time.
        # Confirm this is the intended meaning of `results._runtime`.
        self._start_time = time.process_time()

        for task_id in range(self.phases):
            task_suffix = f" on task {task_id}." if self.nb_tasks > 1 else "."
            logger.info("Starting training" + task_suffix)
            self.current_task_id = task_id
            self.task_boundary_reached(method, task_id=task_id, training=True)

            # Creating the dataloaders ourselves (rather than passing 'self' as
            # the datamodule):
            task_train_env = self.train_dataloader()
            task_valid_env = self.val_dataloader()

            try:
                method.fit(
                    train_env=task_train_env,
                    valid_env=task_valid_env,
                )
            finally:
                # Release both environments even if `fit` (or the first
                # `close`) raises.
                try:
                    task_train_env.close()
                finally:
                    task_valid_env.close()

            if self.monitor_training_performance:
                # Read after `close()`; presumably the env keeps its recorded
                # metrics once closed — verify against the env wrapper.
                results._online_training_performance.append(
                    task_train_env.get_online_performance()
                )

            logger.info(f"Finished Training on task {task_id}.")
            test_metrics: TaskSequenceResults = self.test_loop(method)

            # Add a row to the transfer matrix.
            results.append(test_metrics)
            logger.info(
                f"Resulting objective of Test Loop: {test_metrics.objective}")

            if wandb.run:
                log_dict = add_prefix(
                    test_metrics.to_log_dict(), prefix="Test", sep="/"
                )
                log_dict["current_task"] = task_id
                wandb.log(log_dict)

        # The runtime spans training and testing over all phases.
        self._end_time = time.process_time()
        runtime = self._end_time - self._start_time
        results._runtime = runtime
        logger.info(f"Finished main loop in {runtime} seconds.")
        self.log_results(method, results)
        return results