def get_loss(self, *args, **kwargs) -> Loss:
        """Return the (simplified) EWC regularization loss.

        NOTE: This is a simplified version of EWC. There is no Fisher
        weighting: the loss is the sum of the p-norm distances
        (p = `self.options.distance_norm`) between the current weights of
        `self.model` and the weights stored in `self.previous_model_weights`
        (presumably a snapshot taken at the beginning of the task). The sum
        runs over the weight pairs yielded by `dict_intersection`, presumably
        the parameter names present in both dicts.

        None of the provided positional or keyword arguments are used.

        Returns a `Loss` with only its name set while `self.previous_task` is
        None (i.e. during the first task).
        """
        if self.previous_task is None:
            # First task: there is nothing to regularize towards yet.
            return Loss(name=self.name)

        stored_weights: Dict[str, Tensor] = self.previous_model_weights
        current_weights: Dict[str, Tensor] = dict(self.model.named_parameters())

        # The stored weights are cast to the current weights' tensor type
        # before measuring the distance. Starting the sum at `0.` keeps the
        # result a float when no weights are shared.
        total_distance = sum(
            (
                torch.dist(current_w,
                           stored_w.type_as(current_w),
                           p=self.options.distance_norm)
                for _, (current_w, stored_w) in dict_intersection(
                    current_weights, stored_weights)
            ),
            0.,
        )

        self._i += 1
        return Loss(name=self.name, loss=total_distance)
Example #2
0
    def get_episode_loss(self, env_index: int, done: bool) -> Optional[Loss]:
        """Compute the REINFORCE loss for the episode of one environment.

        The loss is computed from the items buffered for the environment at
        index `env_index` in the batch: the last (up to
        max_episode_window_length) observations/actions/rewards of its
        current episode.

        Args:
            env_index: Index of the environment within the batch.
            done: True at the end of an episode, False while the episode is
                still underway. No loss can be produced before the end of an
                episode, so `None` is returned when this is False.

        Returns:
            A `Loss` with the policy-gradient loss, an `EpisodeMetrics` metric
            and a "gradient_usage" metric. Returns `None` if the episode isn't
            done, the buffers of that env are empty, or the episode is too
            short (length <= 1).

        NOTE: The Observations/Actions/Rewards objects usually hold the
        "batches" of data coming from the N different environments. Here,
        after `self.stack_buffers(env_index)`, they hold the sequence of items
        coming from this single environment.
        """
        if not done:
            # REINFORCE needs the complete episode before giving a loss.
            return None

        if not len(self.actions[env_index]):
            logger.error(f"Weird, asked to get episode loss, but there is "
                         f"nothing in the buffer?")
            return None

        inputs: Tensor
        actions: PolicyHeadOutput
        rewards: ContinualRLSetting.Rewards
        inputs, actions, rewards = self.stack_buffers(env_index)

        episode_length = actions.batch_size
        assert len(inputs) == len(actions.y_pred) == len(rewards.y)

        if episode_length <= 1:
            # TODO: An episode of length 1 can't really give a loss.
            logger.error("Episode is too short!")
            return None

        log_probs = actions.y_pred_log_prob
        episode_rewards = rewards.y

        episode_loss = Loss(
            self.name,
            self.policy_gradient(
                rewards=episode_rewards,
                log_probs=log_probs,
                gamma=self.hparams.gamma,
            ),
        )
        episode_loss.metric = EpisodeMetrics(
            n_samples=1,
            mean_episode_reward=float(episode_rewards.sum()),
            mean_episode_length=len(episode_rewards),
        )
        # TODO: add something like
        # `add_metric(self, metric: Metrics, name: str=None)` to `Loss`.
        gradient_usage = self.get_gradient_usage_metrics(env_index)
        episode_loss.metrics["gradient_usage"] = gradient_usage
        return episode_loss
Example #3
0
    def get_loss(self, forward_pass: ForwardPass,
                 actions: ClassificationOutput, rewards: Rewards) -> Loss:
        """Compute the classification loss for a batch.

        Args:
            forward_pass: Not used by this output head.
            actions: Output of this head; only its `logits` are used.
            rewards: Holds the target class labels in `rewards.y`. It is moved
                to the device of the logits first.

        Returns:
            A `Loss` named after this head, holding `self.loss_fn(logits, y)`
            and a `ClassificationMetrics` entry stored under `self.name`.
        """
        logits: Tensor = actions.logits
        rewards = rewards.to(logits.device)
        y: Tensor = rewards.y

        n_classes = logits.shape[-1]
        # Sanity checks on the labels, only used for debugging (they are
        # stripped under `python -O`): a 1-D tensor of integer class indices
        # in [0, n_classes).
        assert len(y.shape) == 1, y.shape
        assert not torch.is_floating_point(y), y.dtype
        assert 0 <= y.min(), y
        assert y.max() < n_classes, y

        loss = self.loss_fn(logits, y)
        # `loss_fn` is expected to reduce the loss to a scalar.
        assert loss.shape == ()

        # NOTE: The logits (not the sampled/argmax predictions) are given as
        # `y_pred`; presumably `ClassificationMetrics` derives the predicted
        # classes from them. TODO: confirm.
        metrics = ClassificationMetrics(y_pred=logits, y=y)

        assert self.name, "Output Heads should have a name!"
        return Loss(
            name=self.name,
            loss=loss,
            metrics={self.name: metrics},
        )
Example #4
0
    def __init__(self,
                 input_space: spaces.Space,
                 action_space: spaces.Discrete,
                 reward_space: spaces.Box,
                 hparams: "PolicyHead.HParams" = None,
                 name: str = "policy"):
        """Create a policy output head.

        Args:
            input_space: Space of the head's inputs; must be a `spaces.Box`.
            action_space: Space of the actions; must be a `spaces.Discrete`.
            reward_space: Space of the (scalar) rewards; must be a `spaces.Box`.
            hparams: Hyper-parameters, forwarded to the parent constructor.
            name: Name of this output head.

        The per-environment buffers start out as empty lists and the batch
        size as 0. They are (re)built by `create_buffers()` once the batch
        size is known.
        """
        # Validate the spaces (these checks are skipped under `python -O`).
        assert isinstance(input_space, spaces.Box), (
            f"Only support Tensor (box) input space. (got {input_space}).")
        assert isinstance(action_space, spaces.Discrete), (
            f"Only support discrete action space (got {action_space}).")
        assert isinstance(reward_space, spaces.Box), (
            f"Reward space should be a Box (scalar rewards) (got {reward_space}).")

        super().__init__(
            input_space=input_space,
            action_space=action_space,
            reward_space=reward_space,
            hparams=hparams,
            name=name,
        )
        hparams_json = self.hparams.dumps_json(indent='\t')
        logger.debug("New Output head with hparams: " + hparams_json)

        # Narrower type hints for attributes set by the parent class.
        self.hparams: PolicyHead.HParams
        self.input_space: spaces.Box
        self.action_space: spaces.Discrete
        self.reward_space: spaces.Box

        # One buffer per environment, holding items of the current episode.
        # TODO: Only the representations from the encoder are stored (not the
        # raw observations); another, less ambiguous name could be found.
        # TODO: Registering these as (torch) buffers would persist them
        # correctly, but the gradient handling would need to work the same.
        self.representations: List[Deque[Tensor]] = []
        self.actions: List[Deque[PolicyHeadOutput]] = []
        self.rewards: List[Deque[ContinualRLSetting.Rewards]] = []

        # The actual "internal" loss used for training.
        self.loss: Loss = Loss(self.name)
        # Stays falsy until the batch size is known.
        self.batch_size: int = 0

        # Placeholder per-env counters, replaced once the buffers are created.
        self.num_episodes_since_update: np.ndarray = np.zeros(1)
        self.num_steps_in_episode: np.ndarray = np.zeros(1)

        self._training: bool = True

        self.device: Optional[Union[str, torch.device]] = None
Example #5
0
    def get_loss(self, forward_pass: ForwardPass, actions: Actions, rewards: Rewards) -> Loss:
        """Compute the regression loss for a batch.

        Args:
            forward_pass: The forward pass of the model. Its `actions` are only
                used as a fallback when `actions` is None.
            actions: Actions produced by this output head; `actions.y_pred`
                holds the predictions.
            rewards: Holds the regression targets in `rewards.y`.

        Returns:
            A `Loss` named after this head, holding `self.loss_fn(y_pred, y)`
            and a `RegressionMetrics` entry stored under `self.name`.
        """
        # Use the actions given by the caller, like the other output heads do,
        # rather than silently replacing them with `forward_pass.actions`.
        # Fall back to the forward pass only when none were given.
        if actions is None:
            actions = forward_pass.actions
        y_pred: Tensor = actions.y_pred
        y: Tensor = rewards.y

        loss_tensor = self.loss_fn(y_pred, y)
        metrics = RegressionMetrics(y_pred=y_pred, y=y)

        assert self.name, "Output Heads should have a name!"
        return Loss(
            name=self.name,
            loss=loss_tensor,
            metrics={self.name: metrics},
        )
Example #6
0
    def output_head_loss(self, forward_pass: ForwardPass, actions: Actions,
                         rewards: Rewards) -> Loss:
        """Ask the relevant output head(s) for their contribution to the loss.

        If the model is not multi-headed (`self.hp.multihead` is falsy), the
        loss from `self.output_head` is returned directly. Otherwise, the batch
        is split according to the task labels. Each part is given to the output
        head of its task, and the resulting losses are summed.

        The multi-head case handles an env (with `self.batch_size` of None or
        1) whose episode ended in a batch with a task label different from the
        previous batch's:
        - The end-of-episode loss is first taken from the previous task's
          output head.
        - That env's `observations.done` entry is temporarily set to False, so
          the new task's output head doesn't see the first observation of its
          episode as the last one.
        - The entry is set back to True before returning, even if an output
          head raises.

        Args:
            forward_pass: The forward pass for this batch. Its
                `observations.task_labels` select the output heads and its
                `observations.done` (when present) signals episode ends.
            actions: The actions for this batch.
            rewards: The rewards for this batch.

        Returns:
            The (summed) `Loss` from the output head(s).

        Raises:
            NotImplementedError: If the task labels are missing and there is no
                task inference module. Also raised if a task switch coincides
                with an episode end while `self.batch_size` is not None or 1.
        """
        observations: IncrementalSetting.Observations = forward_pass.observations
        task_labels = observations.task_labels
        if isinstance(task_labels, Tensor):
            task_labels = task_labels.cpu().numpy()

        batch_size = forward_pass.batch_size
        assert batch_size is not None

        if task_labels is None:
            if self.task_inference_module:
                # TODO: Predict the task ids using some kind of task
                # inference mechanism.
                task_labels = self.task_inference_module(forward_pass)
            else:
                # TODO: Maybe use the last trained output head, by default?
                raise NotImplementedError(
                    f"Multihead model doesn't have access to task labels and "
                    f"doesn't have a task inference module!")
        # BUG: We get no loss from the output head for the first episode after
        # a task switch.
        # NOTE: The `done` in an observation isn't necessarily associated with
        # the task given by the task label of that observation. Vectorized
        # environments reset the env when `done` is True and return the new
        # initial observation rather than the last observation of that env.
        if self.previous_task_labels is None:
            self.previous_task_labels = task_labels

        # Default behaviour: use the (only) output head.
        if not self.hp.multihead:
            return self.output_head.get_loss(
                forward_pass,
                actions=actions,
                rewards=rewards,
            )

        # The sum of all the losses from all the output heads.
        total_loss = Loss(self.output_head.name)

        task_switched_in_env = task_labels != self.previous_task_labels
        # TODO: This `done` attribute isn't added in supervised settings.
        episode_ended = getattr(observations, "done",
                                np.zeros(batch_size, dtype=bool))
        # TODO: Remove this conversion from Tensors to ndarrays, by making
        # Sequoia more numpy-centric.
        if isinstance(episode_ended, Tensor):
            episode_ended = episode_ended.cpu().numpy()

        # Indices of the envs whose `observations.done` entry is temporarily
        # set to False below. The entries are modified in-place, so the change
        # would persist after this method returns. The `finally` clause sets
        # them back to True, even if one of the output heads raises.
        done_set_to_false_temporarily_indices = []
        try:
            if any(episode_ended & task_switched_in_env):
                # In the envs where the task switched and an episode ended, the
                # end-of-episode loss must first be taken from the output head
                # of the previous task.
                if self.batch_size in {None, 1}:
                    # With a batch size of 1, this is simpler to deal with.
                    previous_task: int = self.previous_task_labels[0].item()
                    # Local import (presumably to avoid a circular import).
                    from sequoia.methods.models.output_heads.rl import PolicyHead

                    previous_output_head = self.output_heads[str(previous_task)]
                    assert isinstance(
                        previous_output_head, PolicyHead
                    ), "todo: assuming that this only happends in RL currently."
                    # We want the loss from that output head, but we don't want
                    # to re-compute it below!
                    env_index_in_previous_batch = 0
                    logger.debug(
                        f"Getting a loss from the output head for task {previous_task}, that was used for the last task."
                    )
                    env_episode_loss = previous_output_head.get_episode_loss(
                        env_index_in_previous_batch, done=True)
                    # BUG: This can sometimes (rarely) be None! Need to better
                    # understand why this is happening.
                    if env_episode_loss is None:
                        logger.warning(
                            RuntimeWarning(
                                f"BUG: Env {env_index_in_previous_batch} gave back a loss "
                                f"of `None`, when we expected a loss from that output head "
                                f"for task id {previous_task}."))
                    else:
                        # Add this end-of-episode loss to the total loss.
                        total_loss += env_episode_loss
                    # `get_episode_loss(env_index, done=True)` only computes a
                    # loss; `on_episode_end` lets the output head clear the
                    # relevant buffers.
                    previous_output_head.on_episode_end(
                        env_index_in_previous_batch)

                    # Set `done` to False for that env, so the output head of
                    # the new task doesn't see the first observation of the
                    # episode as the last.
                    observations.done[env_index_in_previous_batch] = False
                    done_set_to_false_temporarily_indices.append(
                        env_index_in_previous_batch)
                else:
                    # IDEA: Need to somehow pass the indices of which env to
                    # take care of to each output head, so they can create /
                    # clear buffers only when needed.
                    raise NotImplementedError(
                        "TODO: The BaselineModel doesn't yet support having multiple "
                        "different tasks within the same batch in RL. ")

            assert task_labels is not None
            all_task_indices: Dict[int, Tensor] = get_task_indices(task_labels)

            # Get the loss from each output head:
            if len(all_task_indices) == 1:
                # Everything is in the same task (only one key): no need to
                # split/merge anything.
                task_id: int = next(iter(all_task_indices))

                with self.switch_output_head(task_id):
                    total_loss += self.output_head.get_loss(
                        forward_pass,
                        actions=actions,
                        rewards=rewards,
                    )
            else:
                # Split the batch and get the loss for each sub-task.
                # TODO: Not sure if this will play well with DP, DDP, etc.
                for task_id, task_indices in all_task_indices.items():
                    forward_pass_slice = get_slice(forward_pass, task_indices)
                    actions_slice = get_slice(actions, task_indices)
                    rewards_slice = get_slice(rewards, task_indices)

                    logger.debug(
                        f"Getting output head loss for "
                        f"{len(task_indices)/batch_size:.0%} of the batch which "
                        f"has task_id of '{task_id}'.")
                    task_output_head = self.output_heads[str(task_id)]
                    task_loss = task_output_head.get_loss(
                        forward_pass_slice,
                        actions=actions_slice,
                        rewards=rewards_slice,
                    )
                    logger.debug(f"Task {task_id} loss: {task_loss}")
                    total_loss += task_loss

            self.previous_task_labels = task_labels
        finally:
            # Undo the temporary in-place changes to `observations.done`.
            for index in done_set_to_false_temporarily_indices:
                observations.done[index] = True

        return total_loss
Example #7
0
 def get_loss(self, forward_pass, actions, rewards):
     """Return a zero-valued `Loss` named `self.name`; all arguments are ignored."""
     zero = 0.
     return Loss(self.name, zero)
Example #8
0
    def get_loss(self, forward_pass: ForwardPass, actions: PolicyHeadOutput,
                 rewards: ContinualRLSetting.Rewards) -> Loss:
        """Get a Loss to use for training at the current step.

        For each environment in the batch:
        - The end-of-episode loss, if any, is computed from the items buffered
          so far. This happens *before* the current step's
          representations/actions/rewards are added to that env's buffers.
        - Envs whose episode ended go through `on_episode_end`, which clears
          their buffers.

        Once every env has completed at least
        `hparams.min_episodes_before_update` episodes since the last update:
        - The loss (with its graph) is returned, so the model can do the
          backward pass and update the weights.
        - The internal loss, buffers and episode counters are reset.

        Otherwise:
        - With `hparams.accumulate_losses_before_backward`, the losses keep
          accumulating in `self.loss` and a fresh `Loss(self.name)` is
          returned.
        - Without it, if this step's loss requires grad, a backward pass is
          done right away (with `retain_graph=True`) and a detached copy of
          the loss is returned. Otherwise `self.loss` is returned as-is.

        TODO: Replace the `forward_pass` argument with just `observations` and
        `representations` and provide the right (augmented) observations to the
        aux tasks. (Need to design that part later).

        Args:
            forward_pass: The forward pass; its `observations.done` and
                `representations` are used.
            actions: The actions produced by this output head at this step.
            rewards: The rewards received at this step.

        Raises:
            NotImplementedError: If the batch size differs from the one the
                buffers were created with (e.g. because the batch contains
                different tasks).
        """
        observations: ContinualRLSetting.Observations = forward_pass.observations
        representations: Tensor = forward_pass.representations
        assert self.batch_size, "forward() should have been called before this."

        if not self.hparams.accumulate_losses_before_backward:
            # Reset the loss for the current step, if we're not accumulating it.
            self.loss = Loss(self.name)

        assert observations.done is not None, "need the end-of-episode signal"

        # Check this *before* touching the per-env buffers: with a different
        # batch size, the env indices below wouldn't line up with the buffers.
        # That could fail with an unrelated IndexError or mix up the envs.
        if self.batch_size != forward_pass.batch_size:
            # IDEA: Need to get access to the 'original' env indices (before
            # slicing), so that even when one more environment is in this task,
            # the other environment's buffers remain at the same index..
            # Something like a remapping of env indices?
            raise NotImplementedError(
                "TODO: The batch size changed, because the batch contains different "
                "tasks. The BaselineModel isn't yet applicable in the setup where "
                "there are multiple different tasks in the same batch in RL. ")

        # Calculate the end-of-episode loss for each environment.
        for env_index, done in enumerate(observations.done):
            env_loss = self.get_episode_loss(env_index, done=done)
            if env_loss is not None:
                self.loss += env_loss
            if done:
                # NOTE: `env_loss` can be None here even in training mode (e.g.
                # for episodes that are too short); asserting that it isn't
                # used to fail during testing.
                self.on_episode_end(env_index)

        # Add the current step's items to the buffers of each env.
        for env_index in range(self.batch_size):
            env_representations = representations[env_index]
            env_actions = actions.slice(env_index)
            env_rewards = rewards.slice(env_index)
            self.representations[env_index].append(env_representations)
            self.actions[env_index].append(env_actions)
            self.rewards[env_index].append(env_rewards)

        self.num_steps_in_episode += 1

        ready_for_update = all(self.num_episodes_since_update >=
                               self.hparams.min_episodes_before_update)
        if ready_for_update:
            # Every environment has seen the required number of episodes:
            # return the loss, with gradients, so the model can do the
            # backward pass and update the weights.
            returned_loss = self.loss
            self.loss = Loss(self.name)
            self.detach_all_buffers()
            self.num_episodes_since_update[:] = 0
            return returned_loss

        if self.hparams.accumulate_losses_before_backward:
            # Keep accumulating; no update yet.
            return Loss(self.name)

        if self.loss.requires_grad:
            # Not all environments are done, but some of them gave a loss: do
            # the 'small' backward pass now, keeping the graph for later.
            self.loss.backward(retain_graph=True)
            # `self.loss` is reset at the start of the next call.
            return self.loss.detach()

        # TODO: `self.loss.loss` can apparently be non-zero here, i.e. some env
        # produced a loss that doesn't require grad (presumably when the model
        # isn't in training mode). It is returned as-is.
        return self.loss
Example #9
0
class PolicyHead(ClassificationHead):
    """ [WIP] Output head for RL settings.
    
    Uses the REINFORCE algorithm to calculate its loss. 
    
    TODOs/issues:
    - Only currently works with batch_size == 1
    - The buffers are common to training/validation/testing atm..
    
    """
    name: ClassVar[str] = "policy"

    @dataclass
    class HParams(ClassificationHead.HParams):
        """Hyper-parameters of the PolicyHead."""
        hidden_layers: int = 0
        hidden_neurons: List[int] = list_field()
        # The discount factor for the Return term.
        gamma: float = 0.99

        # The maximum length of the buffer that will hold the most recent
        # states/actions/rewards of the current episode.
        max_episode_window_length: int = 1000

        # Minimum number of episodes that need to be completed in each env
        # before we update the parameters of the output head.
        min_episodes_before_update: int = 1

        # TODO: Add this mechanism, so that this method could work even when
        # episodes are very long.
        max_steps_between_updates: Optional[int] = None

        # NOTE: Here we have two options:
        # 1- `True`: sum up all the losses and do one larger backward pass,
        # with `retain_graph=False`, or
        # 2- `False`: Perform multiple little backward passes, one for each
        # end-of-episode in a single env, w/ `retain_graph=True`.
        # Option 1 is maybe more performant, as it might only require
        # unrolling the graph once, but would use more memory to store all the
        # intermediate graphs.
        accumulate_losses_before_backward: bool = flag(True)

    def __init__(self,
                 input_space: spaces.Space,
                 action_space: spaces.Discrete,
                 reward_space: spaces.Box,
                 hparams: "PolicyHead.HParams" = None,
                 name: str = "policy"):
        """Create a policy output head.

        Args:
            input_space: Space of the head's inputs; must be a `spaces.Box`.
            action_space: Space of the actions; must be a `spaces.Discrete`.
            reward_space: Space of the (scalar) rewards; must be a `spaces.Box`.
            hparams: Hyper-parameters, forwarded to the parent constructor.
            name: Name of this output head.

        The per-environment buffers start out as empty lists and the batch
        size as 0. They are (re)built by `create_buffers()` once the batch
        size is known.
        """
        # Validate the spaces (these checks are skipped under `python -O`).
        assert isinstance(input_space, spaces.Box), (
            f"Only support Tensor (box) input space. (got {input_space}).")
        assert isinstance(action_space, spaces.Discrete), (
            f"Only support discrete action space (got {action_space}).")
        assert isinstance(reward_space, spaces.Box), (
            f"Reward space should be a Box (scalar rewards) (got {reward_space}).")

        super().__init__(
            input_space=input_space,
            action_space=action_space,
            reward_space=reward_space,
            hparams=hparams,
            name=name,
        )
        hparams_json = self.hparams.dumps_json(indent='\t')
        logger.debug("New Output head with hparams: " + hparams_json)

        # Narrower type hints for attributes set by the parent class.
        self.hparams: PolicyHead.HParams
        self.input_space: spaces.Box
        self.action_space: spaces.Discrete
        self.reward_space: spaces.Box

        # One buffer per environment, holding items of the current episode.
        # TODO: Only the representations from the encoder are stored (not the
        # raw observations); another, less ambiguous name could be found.
        # TODO: Registering these as (torch) buffers would persist them
        # correctly, but the gradient handling would need to work the same.
        self.representations: List[Deque[Tensor]] = []
        self.actions: List[Deque[PolicyHeadOutput]] = []
        self.rewards: List[Deque[ContinualRLSetting.Rewards]] = []

        # The actual "internal" loss used for training.
        self.loss: Loss = Loss(self.name)
        # Stays falsy until `forward` sees the first batch.
        self.batch_size: int = 0

        # Placeholder per-env counters, replaced by `create_buffers()`.
        self.num_episodes_since_update: np.ndarray = np.zeros(1)
        self.num_steps_in_episode: np.ndarray = np.zeros(1)

        self._training: bool = True

    def create_buffers(self):
        """(Re)create the per-environment buffers and counters.

        The representations, actions and rewards buffers are each rebuilt with
        `self._make_buffers()` (presumably one buffer per env in the batch).
        The per-env step and episode counters are reset to integer zeros of
        length `self.batch_size`.
        """
        logger.debug(f"Creating buffers (batch size={self.batch_size})")
        logger.debug(
            f"Maximum buffer length: {self.hparams.max_episode_window_length}")

        for buffer_attr in ("representations", "actions", "rewards"):
            setattr(self, buffer_attr, self._make_buffers())

        n_envs = self.batch_size
        self.num_steps_in_episode = np.zeros(n_envs, dtype=int)
        self.num_episodes_since_update = np.zeros(n_envs, dtype=int)

    def forward(self, observations: ContinualRLSetting.Observations,
                representations: Tensor) -> PolicyHeadOutput:
        """Sample actions from the policy, given the representations.

        `self.dense` maps the (float) representations to logits. Actions are
        sampled from the categorical distribution over those logits. On the
        first call (while `self.batch_size` is falsy), the batch size is taken
        from the representations and the buffers are created.

        TODO: `observations` isn't used here. It is accepted to have access to
        the 'done' from the env, but there may be a cleaner way to do this.
        """
        if representations.ndim < 2:
            # Reshape into a batch of flattened inputs.
            representations = representations.reshape(
                [-1, flatdim(self.input_space)])

        # Lazily set up the buffers that hold the most recent observations,
        # actions and rewards of the current episode in each environment.
        if not self.batch_size:
            self.batch_size = representations.shape[0]
            self.create_buffers()

        logits = self.dense(representations.float())

        # The policy: the distribution over actions given the current state.
        policy = Categorical(logits=logits)
        return PolicyHeadOutput(
            y_pred=policy.sample(),
            logits=logits,
            action_dist=policy,
        )

    def get_loss(self, forward_pass: ForwardPass, actions: PolicyHeadOutput,
                 rewards: ContinualRLSetting.Rewards) -> Loss:
        """Get a Loss to use for training at the current step.

        For each environment in the batch:
        - The end-of-episode loss, if any, is computed from the items buffered
          so far. This happens *before* the current step's
          representations/actions/rewards are added to that env's buffers.
        - Envs whose episode ended go through `on_episode_end`, which clears
          their buffers.

        Once every env has completed at least
        `hparams.min_episodes_before_update` episodes since the last update:
        - The loss (with its graph) is returned, so the model can do the
          backward pass and update the weights.
        - The internal loss, buffers and episode counters are reset.

        Otherwise:
        - With `hparams.accumulate_losses_before_backward`, the losses keep
          accumulating in `self.loss` and a fresh `Loss(self.name)` is
          returned.
        - Without it, if this step's loss requires grad, a backward pass is
          done right away (with `retain_graph=True`) and a detached copy of
          the loss is returned. Otherwise `self.loss` is returned as-is.

        TODO: Replace the `forward_pass` argument with just `observations` and
        `representations` and provide the right (augmented) observations to the
        aux tasks. (Need to design that part later).

        Args:
            forward_pass: The forward pass; its `observations.done` and
                `representations` are used.
            actions: The actions produced by this output head at this step.
            rewards: The rewards received at this step.

        Raises:
            NotImplementedError: If the batch size differs from the one the
                buffers were created with (e.g. because the batch contains
                different tasks).
        """
        observations: ContinualRLSetting.Observations = forward_pass.observations
        representations: Tensor = forward_pass.representations
        assert self.batch_size, "forward() should have been called before this."

        if not self.hparams.accumulate_losses_before_backward:
            # Reset the loss for the current step, if we're not accumulating it.
            self.loss = Loss(self.name)

        assert observations.done is not None, "need the end-of-episode signal"

        # Check this *before* touching the per-env buffers: with a different
        # batch size, the env indices below wouldn't line up with the buffers.
        # That could fail with an unrelated IndexError or mix up the envs.
        if self.batch_size != forward_pass.batch_size:
            # IDEA: Need to get access to the 'original' env indices (before
            # slicing), so that even when one more environment is in this task,
            # the other environment's buffers remain at the same index..
            # Something like a remapping of env indices?
            raise NotImplementedError(
                "TODO: The batch size changed, because the batch contains different "
                "tasks. The BaselineModel isn't yet applicable in the setup where "
                "there are multiple different tasks in the same batch in RL. ")

        # Calculate the end-of-episode loss for each environment.
        for env_index, done in enumerate(observations.done):
            env_loss = self.get_episode_loss(env_index, done=done)
            if env_loss is not None:
                self.loss += env_loss
            if done:
                # NOTE: `env_loss` can be None here even in training mode (e.g.
                # for episodes that are too short); asserting that it isn't
                # used to fail during testing.
                self.on_episode_end(env_index)

        # Add the current step's items to the buffers of each env.
        for env_index in range(self.batch_size):
            env_representations = representations[env_index]
            env_actions = actions.slice(env_index)
            env_rewards = rewards.slice(env_index)
            self.representations[env_index].append(env_representations)
            self.actions[env_index].append(env_actions)
            self.rewards[env_index].append(env_rewards)

        self.num_steps_in_episode += 1

        ready_for_update = all(self.num_episodes_since_update >=
                               self.hparams.min_episodes_before_update)
        if ready_for_update:
            # Every environment has seen the required number of episodes:
            # return the loss, with gradients, so the model can do the
            # backward pass and update the weights.
            returned_loss = self.loss
            self.loss = Loss(self.name)
            self.detach_all_buffers()
            self.num_episodes_since_update[:] = 0
            return returned_loss

        if self.hparams.accumulate_losses_before_backward:
            # Keep accumulating; no update yet.
            return Loss(self.name)

        if self.loss.requires_grad:
            # Not all environments are done, but some of them gave a loss: do
            # the 'small' backward pass now, keeping the graph for later.
            self.loss.backward(retain_graph=True)
            # `self.loss` is reset at the start of the next call.
            return self.loss.detach()

        # TODO: `self.loss.loss` can apparently be non-zero here, i.e. some env
        # produced a loss that doesn't require grad (presumably when the model
        # isn't in training mode). It is returned as-is.
        return self.loss

    def on_episode_end(self, env_index: int) -> None:
        """Book-keeping for the end of an episode in environment `env_index`.

        Counts one more finished episode for that environment, resets its
        step counter to zero and empties its buffers.
        """
        finished_episodes = self.num_episodes_since_update
        finished_episodes[env_index] += 1
        step_counts = self.num_steps_in_episode
        step_counts[env_index] = 0
        self.clear_buffers(env_index)

    def get_episode_loss(self, env_index: int, done: bool) -> Optional[Loss]:
        """Calculate the REINFORCE loss for the episode of one environment.

        Uses the last (up to `max_episode_window_length`)
        representations/actions/rewards stored in the buffers of the
        environment at index `env_index` in the batch.

        If `done` is True, then this is for the end of an episode. If `done` is
        False, the episode is still underway and `None` is returned, because
        REINFORCE can only give a loss once the episode is over.

        NOTE: While the Batch Observations/Actions/Rewards objects usually
        contain the "batches" of data coming from the N different environments,
        here they are actually a sequence of items (one per step) coming from
        this single environment. For more info on how this is done, see
        `stack_buffers`.

        Returns
        -------
        Optional[Loss]
            The episode loss, with episode metrics and "gradient_usage" metrics
            attached, or `None` when no loss can be computed (episode not done,
            empty buffers, or an episode of length <= 1).
        """
        if not done:
            # This particular algorithm (REINFORCE) can't give a loss until the
            # end of the episode is reached.
            return None

        if len(self.actions[env_index]) == 0:
            logger.error("Weird, asked to get episode loss, but there is "
                         "nothing in the buffer?")
            return None

        inputs: Tensor
        actions: PolicyHeadOutput
        stacked_rewards: ContinualRLSetting.Rewards
        inputs, actions, stacked_rewards = self.stack_buffers(env_index)

        episode_length = actions.batch_size
        assert len(inputs) == len(actions.y_pred) == len(stacked_rewards.y)

        if episode_length <= 1:
            # TODO: If the episode has len of 1, we can't really get a loss!
            logger.error("Episode is too short!")
            return None

        log_probabilities = actions.y_pred_log_prob
        # Use a distinct name for the raw reward values, rather than rebinding
        # the `Rewards` object to its `.y` attribute.
        episode_rewards = stacked_rewards.y

        loss_tensor = self.policy_gradient(
            rewards=episode_rewards,
            log_probs=log_probabilities,
            gamma=self.hparams.gamma,
        )
        loss = Loss(self.name, loss_tensor)
        loss.metric = EpisodeMetrics(
            n_samples=1,
            mean_episode_reward=float(episode_rewards.sum()),
            mean_episode_length=len(episode_rewards),
        )
        # TODO: add something like `add_metric(self, metric: Metrics, name: str=None)`
        # to `Loss`.
        loss.metrics["gradient_usage"] = self.get_gradient_usage_metrics(
            env_index)
        return loss

    def get_gradient_usage_metrics(self,
                                   env_index: int) -> GradientUsageMetric:
        """Count how many buffered actions of an environment still have a graph.

        Looks at every item in `self.actions[env_index]`: items whose `logits`
        require grad are counted as "used" gradients, and the remaining ones
        (detached, e.g. because they were created before the last model
        update) are counted as "wasted" gradients.
        """
        stored_actions = self.actions[env_index]
        with_grad = 0
        for stored_action in stored_actions:
            if stored_action.logits.requires_grad:
                with_grad += 1
        without_grad = len(stored_actions) - with_grad
        return GradientUsageMetric(
            used_gradients=with_grad,
            wasted_gradients=without_grad,
        )

    @staticmethod
    def get_returns(rewards: Union[Tensor, List[Tensor]],
                    gamma: float) -> Tensor:
        """Compute the return at every step of an episode.

        The return at a step is the sum of the future rewards discounted by
        `gamma`. The computation itself is delegated to
        `discounted_sum_of_future_rewards`.
        """
        discounted_returns = discounted_sum_of_future_rewards(rewards,
                                                              gamma=gamma)
        return discounted_returns

    @staticmethod
    def policy_gradient(rewards: Union[Tensor, List[float]],
                        log_probs: Union[Tensor, List[Tensor]],
                        gamma: float = 0.95) -> Tensor:
        """Implementation of the REINFORCE algorithm.

        Adapted from https://medium.com/@thechrisyoon/deriving-policy-gradients-and-implementing-reinforce-f887949bd63

        Parameters
        ----------
        - rewards : Union[Tensor, List[float]]

            The rewards at each step in an episode.

        - log_probs : Union[Tensor, List[Tensor]]

            The log probabilities associated with the actions that were taken at
            each step.

        - gamma : float, optional

            The discount applied to future rewards. Defaults to 0.95.

        Returns
        -------
        Tensor
            The "vanilla policy gradient" / REINFORCE loss resulting from that
            episode. Callers wrap it in a `Loss` and backpropagate through it.
        """
        return vanilla_policy_gradient(rewards, log_probs, gamma=gamma)

    @property
    def training(self) -> bool:
        """Whether this output head is currently in training mode."""
        current_mode = self._training
        return current_mode

    @training.setter
    def training(self, value: bool) -> None:
        """Set the training mode of this output head.

        When the mode actually changes (train <-> test), all per-environment
        buffers are cleared, the batch size is reset to None and the
        per-environment episode counters are zeroed before storing the new
        value. Setting the same mode again, or setting it for the first time,
        only stores the value.
        """
        # logger.debug(f"setting training to {value} on the Policy output head")
        previously_set = hasattr(self, "_training")
        if previously_set and value != self._training:

            def mode_name(flag) -> str:
                return "train" if flag else "test"

            logger.debug(
                f"Clearing buffers, since we're transitioning between from "
                f"{mode_name(self._training)}->{mode_name(value)}"
            )
            self.clear_all_buffers()
            self.batch_size = None
            self.num_episodes_since_update[:] = 0
        self._training = value

    def clear_all_buffers(self) -> None:
        """Empty every per-environment buffer and forget the batch size.

        When no batch size is set, the buffer lists are expected to already be
        empty (this is asserted) and nothing else happens.
        """
        if self.batch_size is None:
            # Nothing was ever allocated, so there must be nothing to clear.
            for buffer_list in (self.rewards, self.representations,
                                self.actions):
                assert not buffer_list
            return
        for env_id in range(self.batch_size):
            self.clear_buffers(env_id)
        for buffer_list in (self.rewards, self.representations, self.actions):
            buffer_list.clear()
        self.batch_size = None

    def clear_buffers(self, env_index: int) -> None:
        """Empty the representation, action and reward buffers of one env."""
        for per_env_buffers in (self.representations, self.actions,
                                self.rewards):
            per_env_buffers[env_index].clear()

    def detach_all_buffers(self):
        """Detach the buffered tensors of every environment in the batch.

        Does nothing when the batch size is unset or zero, in which case the
        action buffers are expected (and asserted) to be empty.
        """
        if self.batch_size:
            for env_index in range(self.batch_size):
                self.detach_buffers(env_index)
            return
        # No buffers to detach!
        assert not self.actions

    def detach_buffers(self, env_index: int) -> None:
        """Replace one environment's buffers with detached copies.

        This is needed when the model is updated while the episode in one of
        the environments isn't done yet: the representation, action and reward
        buffers of `env_index` are each rebuilt via `_detach_buffer`.
        """
        for attribute_name in ("representations", "actions", "rewards"):
            per_env_buffers = getattr(self, attribute_name)
            per_env_buffers[env_index] = self._detach_buffer(
                per_env_buffers[env_index])

    def _detach_buffer(self, old_buffer: Sequence[Tensor]) -> deque:
        """Return a fresh buffer holding `.detach()`-ed copies of the items."""
        detached_buffer = self._make_buffer()
        detached_buffer.extend(entry.detach() for entry in old_buffer)
        return detached_buffer

    def _make_buffer(self,
                     elements: Optional[Sequence[T]] = None) -> Deque[T]:
        """Create a new buffer, optionally pre-filled with `elements`.

        The buffer is a deque whose `maxlen` is
        `hparams.max_episode_window_length`. When that is not None, appending
        past it discards the oldest items.
        """
        buffer: Deque[T] = deque(maxlen=self.hparams.max_episode_window_length)
        # Compare against None instead of relying on truthiness. `elements`
        # may be a tensor, and `bool()` of a tensor raises when it has more
        # than one element, or is False for a single zero-valued element.
        if elements is not None:
            buffer.extend(elements)
        return buffer

    def _make_buffers(self) -> List[deque]:
        """Create one fresh, independent buffer per environment in the batch."""
        buffers: List[deque] = []
        for _ in range(self.batch_size):
            buffers.append(self._make_buffer())
        return buffers

    def stack_buffers(self, env_index: int):
        """Stack the buffered representations/actions/rewards of one env.

        Each buffer for `env_index` is snapshotted into a tuple, asserted to be
        non-empty, and then passed to `stack`. The three stacked results are
        returned as `(inputs, actions, rewards)`.
        """
        snapshots = [
            tuple(per_env_buffers[env_index])
            for per_env_buffers in (self.representations, self.actions,
                                    self.rewards)
        ]
        for snapshot in snapshots:
            assert len(snapshot)
        stacked_inputs, stacked_actions, stacked_rewards = (
            stack(snapshot) for snapshot in snapshots)
        return stacked_inputs, stacked_actions, stacked_rewards
Example #10
0
    def get_episode_loss(self, env_index: int, done: bool) -> Optional[Loss]:
        """Calculate the A2C loss for the episode of the environment at `env_index`.

        A loss is only produced at the end of an episode (`done` is True) and
        when at least 5 steps are stored for that environment. Otherwise
        `None` is returned.

        The returned Loss sums the following, each weighted by its `hparams`
        coefficient:
        - "actor": the policy-gradient loss, using detached advantages;
        - "critic": the MSE between the critic's values and the returns;
        - "entropy": the negative mean entropy of the action distribution,
          to favor exploration.
        Episode metrics and "gradient_usage" metrics are attached to it.
        """
        # IDEA: Actually, now that I think about it, instead of detaching the
        # tensors, we could instead use the critic's 'value' estimate and get a
        # loss for that incomplete episode using the tensors in the buffer,
        # rather than detaching them!

        if not done:
            return None

        # TODO: Add something like a 'num_steps_since_update' for each env? (it
        # would actually be a num_steps_since_backward)
        # if self.num_steps_since_update?
        min_stored_steps = 5
        n_stored_steps = self.num_stored_steps(env_index)
        if n_stored_steps < min_stored_steps:
            # For now, we only give back a loss at the end of the episode.
            # TODO: Test if giving back a loss at each step or every few steps
            # would work better!
            logger.warning(
                RuntimeWarning(
                    f"Returning None as the episode loss, because only have "
                    f"{n_stored_steps} steps stored for that environment."))
            return None

        actions: A2CHeadOutput
        rewards: Rewards
        _, actions, rewards = self.stack_buffers(env_index)
        action_log_probs: Tensor = actions.action_log_prob
        values: Tensor = actions.value
        assert rewards.y is not None
        episode_rewards: Tensor = rewards.y

        # Returns are the discounted sums of future rewards, computed from the
        # rewards alone. Reshape them once to the critic's value shape, so
        # that `returns - values` cannot silently broadcast to a (T, T) matrix
        # when e.g. `values` is (T, 1) and `returns` is (T,).
        returns = self.get_returns(episode_rewards, gamma=self.hparams.gamma)
        returns = returns.type_as(values).reshape(values.shape)
        advantages = returns - values

        # Normalize advantage (not present in the original implementation)
        if self.hparams.normalize_advantages:
            advantages = normalize(advantages)

        # Create the Loss to be returned.
        loss = Loss(self.name)

        # Policy gradient loss (actor loss). Advantages are detached so that
        # only the actor's log-probabilities receive gradients from this term.
        policy_gradient_loss = -(advantages.detach() * action_log_probs).mean()
        actor_loss = Loss("actor", policy_gradient_loss)
        loss += self.hparams.actor_loss_coef * actor_loss

        # Value loss: Try to get the critic's values close to the actual return,
        # which means the advantages should be close to zero.
        value_loss_tensor = F.mse_loss(values, returns)
        critic_loss = Loss("critic", value_loss_tensor)
        loss += self.hparams.critic_loss_coef * critic_loss

        # Entropy loss, to "favor exploration".
        entropy_loss_tensor = -actions.action_dist.entropy().mean()
        entropy_loss = Loss("entropy", entropy_loss_tensor)
        loss += self.hparams.entropy_loss_coef * entropy_loss

        # `done` is always True here (see the early return above), so the
        # episode metrics can be attached unconditionally.
        episode_rewards_array = episode_rewards.reshape([-1])
        loss.metric = EpisodeMetrics(
            n_samples=1,
            mean_episode_reward=float(episode_rewards_array.sum()),
            mean_episode_length=len(episode_rewards_array),
        )
        loss.metrics["gradient_usage"] = self.get_gradient_usage_metrics(
            env_index)
        return loss