def get_metrics(
        self,
        action: Union[Actions, Any],
        reward: Union[Rewards, Any],
        done: Union[bool, Sequence[bool]],
    ) -> Optional[EpisodeMetrics]:
        metrics = []

        rewards = reward.y if isinstance(reward, Rewards) else reward
        actions = action.y_pred if isinstance(action, Actions) else action
        dones: Sequence[bool]
        if not self.is_batched_env:
            rewards = [rewards]
            actions = [actions]
            assert isinstance(done, bool)
            dones = [done]
        else:
            assert isinstance(done, (np.ndarray, Tensor))
            dones = done

        for env_index, (env_is_done, env_reward) in enumerate(zip(dones, rewards)):
            if env_is_done:
                metrics.append(
                    EpisodeMetrics(
                        n_samples=1,
                        # The average reward per episode.
                        mean_episode_reward=self._current_episode_reward[env_index],
                        # The average length of each episode.
                        mean_episode_length=self._current_episode_steps[env_index],
                    )
                )
                self._current_episode_reward[env_index] = 0
                self._current_episode_steps[env_index] = 0

        if not metrics:
            return None

        metric = sum(metrics, Metrics())
        if wandb.run:
            log_dict = metric.to_log_dict()
            if self.wandb_prefix:
                log_dict = add_prefix(log_dict,
                                      prefix=self.wandb_prefix,
                                      sep="/")
            log_dict["steps"] = self._steps
            log_dict["episode"] = self._episodes
            wandb.log(log_dict)

        return metric
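
A minimal usage sketch for the method above, assuming it lives on a metrics-tracking wrapper around a (possibly vectorized) gym environment and is called once per step; `wrapper`, `actions`, `rewards` and `dones` are illustrative names, not part of the original code:

# Hypothetical per-step call: None is returned until at least one environment
# in the batch finishes its episode; at that point the accumulated reward and
# step counters for that environment are reset.
episode_metrics = wrapper.get_metrics(action=actions, reward=rewards, done=dones)
if episode_metrics is not None:
    print(episode_metrics.to_log_dict())
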
Example #2
    def get_metrics(self, action: Actions, reward: Rewards) -> Metrics:
        assert action.y_pred.shape == reward.y.shape, (action.shapes,
                                                       reward.shapes)
        metric = ClassificationMetrics(y_pred=action.y_pred,
                                       y=reward.y,
                                       num_classes=self.n_classes)

        if wandb.run:
            log_dict = metric.to_log_dict()
            if self.wandb_prefix:
                log_dict = add_prefix(log_dict,
                                      prefix=self.wandb_prefix,
                                      sep="/")
            log_dict["steps"] = self._steps
            wandb.log(log_dict)
        return metric
Example #3
    def all_metrics(self) -> Dict[str, Metrics]:
        """ Returns a 'cleaned up' dictionary of all the Metrics objects. """
        assert self.name
        result: Dict[str, Metrics] = {}
        result.update(self.metrics)

        for name, loss in self.losses.items():
            # TODO: Aren't we potentially colliding with 'self.metrics' here?
            subloss_metrics = loss.all_metrics()
            for key, metric in subloss_metrics.items():
                assert key not in result, (
                    f"Collision in metric keys of subloss {name}: key={key}, "
                    f"result={result}"
                )
                result[key] = metric
        result = add_prefix(result, prefix=self.name, sep="/")
        return result
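
`add_prefix` is used throughout these examples but is not shown; here is a rough sketch of what it is assumed to do (prepend `prefix` plus `sep` to every key of a dict), not the library's actual implementation:

from typing import Dict, TypeVar

V = TypeVar("V")

def add_prefix_sketch(d: Dict[str, V], prefix: str, sep: str = "/") -> Dict[str, V]:
    """Rough stand-in for `add_prefix`: prepend `prefix` + `sep` to every key."""
    return {f"{prefix}{sep}{key}": value for key, value in d.items()}

assert add_prefix_sketch({"accuracy": 0.9}, prefix="Test") == {"Test/accuracy": 0.9}
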
Example #4
    def to_pbar_message(self) -> Dict[str, float]:
        """ Smaller, less-detailed version of `to_log_dict()` for progress bars.
        """
        # NOTE: PL (PyTorch Lightning) doesn't seem to accept strings as values here.
        message: Dict[str, Union[str, float]] = {}
        message["Loss"] = float(self.loss)

        for name, metric in self.metrics.items():
            if isinstance(metric, Metrics):
                message[name] = metric.to_pbar_message()
            else:
                message[name] = metric

        for name, loss_info in self.losses.items():
            message[name] = loss_info.to_pbar_message()

        message = add_prefix(message, prefix=self.name, sep=" ")

        return cleanup(message, sep=" ")
Example #5
    def to_log_dict(self, verbose: bool = False) -> Dict[str, Union[str, float, Dict]]:
        """Creates a dictionary to be logged (e.g. by `wandb.log`).

        Args:
            verbose (bool, optional): Whether to include a lot of information,
                or to only log the 'essential' stuff. See the `cleanup`
                function for more info. Defaults to False.

        Returns:
            Dict: A dict containing the things to be logged.
        """
        # TODO: Could also produce some wandb plots and stuff here when verbose?
        log_dict: Dict[str, Union[str, float, Dict, Tensor]] = {}
        log_dict["loss"] = round(float(self.loss), 6)

        for name, metric in self.metrics.items():
            if isinstance(metric, Serializable):
                log_dict[name] = metric.to_log_dict(verbose=verbose)
            else:
                log_dict[name] = metric

        for name, loss in self.losses.items():
            if isinstance(loss, Serializable):
                log_dict[name] = loss.to_log_dict(verbose=verbose)
            else:
                log_dict[name] = loss

        log_dict = add_prefix(log_dict, prefix=self.name, sep="/")
        keys_to_remove: List[str] = []
        if not verbose:
            # When NOT verbose, remove any entries whose key matches one of these.
            # TODO: Add/remove keys here to customize what doesn't get logged to wandb.
            # TODO: Could maybe make this a class variable so that it could be
            # extended/overwritten, but that seems like a bit too much right now.
            keys_to_remove = [
                "n_samples",
                "name",
                "confusion_matrix",
                "class_accuracy",
                "_coefficient",
            ]
        result = cleanup(log_dict, keys_to_remove=keys_to_remove, sep="/")
        return result
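
`cleanup` is also not shown in these examples; based on how it is called above, it is assumed to flatten nested dictionaries into `sep`-joined keys and drop entries whose key is in `keys_to_remove`. A rough sketch under that assumption (not the library's actual implementation):

from typing import Any, Dict, List, Optional

def cleanup_sketch(d: Dict[str, Any],
                   keys_to_remove: Optional[List[str]] = None,
                   sep: str = "/") -> Dict[str, Any]:
    """Rough stand-in for `cleanup`: flatten nested dicts and drop unwanted keys."""
    keys_to_remove = keys_to_remove or []
    flat: Dict[str, Any] = {}
    for key, value in d.items():
        if key in keys_to_remove:
            continue
        if isinstance(value, dict):
            nested = cleanup_sketch(value, keys_to_remove=keys_to_remove, sep=sep)
            for sub_key, sub_value in nested.items():
                flat[f"{key}{sep}{sub_key}"] = sub_value
        else:
            flat[key] = value
    return flat
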
Example #6
    def main_loop(self, method: Method) -> IncrementalResults:
        """ Runs an incremental training loop, wether in RL or CL. """
        # TODO: Add ways of restoring state to continue a given run?
        # For each training task, for each test task, a list of the Metrics obtained
        # during testing on that task.
        # NOTE: We could also just store a single metric for each test task, but then
        # we'd lose the ability to create plots showing the performance within a test
        # task.
        # IDEA: We could use a list of IIDResults! (but that might cause some circular
        # import issues)
        results = self.Results()
        if self.monitor_training_performance:
            results._online_training_performance = []

        # TODO: Fix this up, need to set the '_objective_scaling_factor' to a different
        # value depending on the 'dataset' / environment.
        results._objective_scaling_factor = self._get_objective_scaling_factor()

        if self.wandb:
            # Init wandb, and then log the setting's options.
            self.wandb_run = self.setup_wandb(method)
            method.setup_wandb(self.wandb_run)

        method.set_training()

        self._start_time = time.process_time()

        for task_id in range(self.phases):
            logger.info(f"Starting training" +
                        (f" on task {task_id}." if self.nb_tasks > 1 else "."))
            self.current_task_id = task_id
            self.task_boundary_reached(method, task_id=task_id, training=True)

            # Creating the dataloaders ourselves (rather than passing 'self' as
            # the datamodule):
            task_train_env = self.train_dataloader()
            task_valid_env = self.val_dataloader()

            method.fit(
                train_env=task_train_env,
                valid_env=task_valid_env,
            )
            task_train_env.close()
            task_valid_env.close()

            if self.monitor_training_performance:
                results._online_training_performance.append(
                    task_train_env.get_online_performance())

            logger.info(f"Finished Training on task {task_id}.")
            test_metrics: TaskSequenceResults = self.test_loop(method)

            # Add a row to the transfer matrix.
            results.append(test_metrics)
            logger.info(
                f"Resulting objective of Test Loop: {test_metrics.objective}")

            if wandb.run:
                d = add_prefix(test_metrics.to_log_dict(),
                               prefix="Test",
                               sep="/")
                d["current_task"] = task_id
                wandb.log(d)

        self._end_time = time.process_time()
        runtime = self._end_time - self._start_time
        results._runtime = runtime
        logger.info(f"Finished main loop in {runtime} seconds.")
        self.log_results(method, results)
        return results
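
A hedged illustration of how the accumulated results might be inspected afterwards, assuming the `Results` container is list-like (as the `append` calls above suggest); the indexing scheme is an assumption, not a documented API:

# Hypothetical post-run inspection: each entry is the TaskSequenceResults
# produced by test_loop after training on the corresponding task, i.e. one
# row of the transfer matrix mentioned in the comments above.
for train_task_id, row in enumerate(results):
    print(f"After training on task {train_task_id}: objective = {row.objective}")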