Example 1
    def test_loop(self, method: Method) -> "IncrementalSetting.Results":
        """ (WIP): Runs an incremental test loop and returns the Results.

        The idea is that this loop should be exactly the same, regardless of
        whether you're on the RL or the CL side of the tree.

        NOTE: If `self.known_task_boundaries_at_test_time` is `True` and the
        method has the `on_task_switch` callback defined, then a callback
        wrapper is added that will invoke the method's `on_task_switch` and pass
        it the task id (or `None` if `not self.task_labels_at_test_time`)
        when a task boundary is encountered.

        This `on_task_switch` 'callback' wrapper gets added the same way for
        both Supervised and Reinforcement Learning settings.
        """
        test_env: TestEnvironment = self.test_dataloader()

        was_training = method.training
        method.set_testing()

        if self.known_task_boundaries_at_test_time and self.nb_tasks > 1:

            def _on_task_switch(step: int, *args) -> None:
                # TODO: This attribute isn't on IncrementalSetting itself, it's defined
                # on ContinualRLSetting.
                if step not in test_env.boundary_steps:
                    return
                if not hasattr(method, "on_task_switch"):
                    logger.warning(
                        "On a task boundary, but since your method doesn't "
                        "have an `on_task_switch` method, it won't know "
                        "about it!"
                    )
                    return

                if self.task_labels_at_test_time:
                    # TODO: Should this 'test boundary' step depend on the batch size?
                    task_steps = sorted(test_env.boundary_steps)
                    # TODO: If the ordering of tasks were different (shuffled
                    # tasks for example), then this wouldn't work, we'd need a
                    # list of the task ids or something like that.
                    task_id = task_steps.index(step)
                    logger.debug(
                        f"Calling `method.on_task_switch({task_id})` "
                        f"since task labels are available at test-time.")
                    method.on_task_switch(task_id)
                else:
                    logger.debug(f"Calling `method.on_task_switch(None)` "
                                 f"since task labels aren't available at "
                                 f"test-time, but task boundaries are known.")
                    method.on_task_switch(None)

            test_env = StepCallbackWrapper(test_env,
                                           callbacks=[_on_task_switch])
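            # NOTE: `StepCallbackWrapper` is assumed to invoke each callback in
            # `callbacks` with the current step count on every `step()` of the
            # wrapped env, so `_on_task_switch` fires exactly at the recorded
            # boundary steps.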

        try:
            # If the Method has `test` defined, use it.
            method.test(test_env)
            test_env.close()
            # Get the metrics from the test environment.
            test_results: Results = test_env.get_results()

        except NotImplementedError:
            logger.debug(f"Will query the method for actions at each step, "
                         f"since it doesn't implement a `test` method.")
            obs = test_env.reset()

            # TODO: Do we always have a maximum number of steps? or of episodes?
            # Will it work the same for Supervised and Reinforcement learning?
            max_steps: Optional[int] = getattr(test_env, "step_limit", None)

            # Resetting on the last step causes trouble, since the env is already closed.
            pbar = tqdm.tqdm(itertools.count(), total=max_steps, desc="Test")
            episode = 0

            for step in pbar:
                if obs is None:
                    break
                # NOTE: The env might not be closed, while `obs` is actually still there.
                # if test_env.is_closed():
                #     logger.debug(f"Env is closed")
                #     break
                # logger.debug(f"At step {step}")

                # BUG: This doesn't work if the env isn't batched.
                action_space = test_env.action_space
                batch_size = getattr(test_env, "num_envs",
                                     getattr(test_env, "batch_size", 0))
                env_is_batched = batch_size is not None and batch_size >= 1
                if env_is_batched:
                    # NOTE: Need to pass an action space that actually reflects the batch
                    # size, even for the last batch!
                    obs_batch_size = obs.x.shape[0] if obs.x.shape else None
                    action_space_batch_size = (test_env.action_space.shape[0]
                                               if test_env.action_space.shape
                                               else None)
                    if (obs_batch_size is not None
                            and obs_batch_size != action_space_batch_size):
                        action_space = batch_space(
                            test_env.single_action_space, obs_batch_size)

                action = method.get_actions(obs, action_space)

                # logger.debug(f"action: {action}")
                # TODO: Remove this:
                if isinstance(action, Actions):
                    action = action.y_pred
                if isinstance(action, Tensor):
                    action = action.detach().cpu().numpy()

                if test_env.is_closed():
                    break

                obs, reward, done, info = test_env.step(action)

                if done and not test_env.is_closed():
                    # logger.debug(f"end of test episode {episode}")
                    obs = test_env.reset()
                    episode += 1

            test_env.close()
            test_results: TaskSequenceResults = test_env.get_results()

        # Restore 'training' mode, if it was set at the start.
        if was_training:
            method.set_training()

        return test_results
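
To make the callback mechanism concrete, here is a minimal sketch of a Method that this loop can drive. It assumes Sequoia's `Method` base class and that the base `test` raises `NotImplementedError`; `RandomTestMethod` and its behaviour are illustrative, not part of the library.

    class RandomTestMethod(Method):
        """Illustrative Method: samples random actions, logs task switches."""

        def fit(self, train_env, valid_env):
            # No training is needed for this sketch.
            pass

        def get_actions(self, observations, action_space):
            # Called once per step by the fallback loop above, with an
            # `action_space` already resized to match the observation batch.
            return action_space.sample()

        def on_task_switch(self, task_id):
            # Invoked at each known task boundary; `task_id` is None when
            # task labels aren't available at test time.
            logger.info(f"Switched to task {task_id}")

Since `test` is left unimplemented, `test_loop` falls back to querying `get_actions` at each step, and the `StepCallbackWrapper` delivers the task boundaries.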
Example 2
    def test_loop(self, method: Method) -> "IncrementalAssumption.Results":
        """ WIP: Continual test loop.
        """
        test_env: TestEnvironment = self.test_dataloader()

        was_training = method.training
        method.set_testing()

        try:
            # If the Method has `test` defined, use it.
            method.test(test_env)
            test_env.close()
            # Get the metrics from the test environment.
            test_results: Results = test_env.get_results()

        except NotImplementedError:
            logger.debug(f"Will query the method for actions at each step, "
                         f"since it doesn't implement a `test` method.")
            obs = test_env.reset()

            # TODO: Do we always have a maximum number of steps? or of episodes?
            # Will it work the same for Supervised and Reinforcement learning?
            max_steps: Optional[int] = getattr(test_env, "step_limit", None)

            # Resetting on the last step causes trouble, since the env is already closed.
            pbar = tqdm.tqdm(itertools.count(), total=max_steps, desc="Test")
            episode = 0

            for step in pbar:
                if obs is None:
                    break
                # NOTE: The env might not be closed, while `obs` is actually still there.
                # if test_env.is_closed():
                #     logger.debug(f"Env is closed")
                #     break
                # logger.debug(f"At step {step}")

                # BUG: This doesn't work if the env isn't batched.
                action_space = test_env.action_space
                batch_size = getattr(test_env, "num_envs",
                                     getattr(test_env, "batch_size", 0))
                env_is_batched = batch_size is not None and batch_size >= 1
                if env_is_batched:
                    # NOTE: Need to pass an action space that actually reflects the batch
                    # size, even for the last batch!
                    obs_batch_size = obs.x.shape[0] if obs.x.shape else None
                    action_space_batch_size = (test_env.action_space.shape[0]
                                               if test_env.action_space.shape
                                               else None)
                    if (obs_batch_size is not None
                            and obs_batch_size != action_space_batch_size):
                        action_space = batch_space(
                            test_env.single_action_space, obs_batch_size)

                action = method.get_actions(obs, action_space)

                # logger.debug(f"action: {action}")
                # TODO: Remove this:
                if isinstance(action, Actions):
                    action = action.y_pred
                if isinstance(action, Tensor):
                    action = action.detach().cpu().numpy()

                if test_env.is_closed():
                    break

                obs, reward, done, info = test_env.step(action)

                if done and not test_env.is_closed():
                    # logger.debug(f"end of test episode {episode}")
                    obs = test_env.reset()
                    episode += 1

            test_env.close()
            test_results: Results = test_env.get_results()

        if wandb.run:
            d = add_prefix(test_results.to_log_dict(), prefix="Test", sep="/")
            # d = add_prefix(test_metrics.to_log_dict(), prefix="Test", sep="/")
            # d["current_task"] = task_id
            wandb.log(d)

        # Restore 'training' mode, if it was set at the start.
        if was_training:
            method.set_training()

        return test_results
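
The action-space adjustment in both loops exists because the last batch of a vectorized env can be smaller than the others. Below is a standalone sketch of that fix using `gym.vector.utils.batch_space`; the space and batch sizes are illustrative.

    import gym
    from gym.vector.utils import batch_space

    single_action_space = gym.spaces.Discrete(10)
    # Regular batches of 32 during the test loop:
    action_space = batch_space(single_action_space, n=32)
    assert action_space.shape == (32,)
    # If only 7 observations remain in the final batch, sampling from the
    # 32-wide space would yield too many actions, so the loop rebuilds the
    # space to match the observation batch:
    obs_batch_size = 7
    action_space = batch_space(single_action_space, n=obs_batch_size)
    assert action_space.sample().shape == (7,)  # one action per observation

This mirrors the `obs_batch_size != action_space_batch_size` branch in both loops above.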