def stack_buffers(self, env_index: int):
    """ Stack the representations/actions/rewards for this env and return them.
    """
    # episode_observations = tuple(self.observations[env_index])
    episode_representations = tuple(self.representations[env_index])
    episode_actions = tuple(self.actions[env_index])
    episode_rewards = tuple(self.rewards[env_index])
    assert len(episode_representations)
    assert len(episode_actions)
    assert len(episode_rewards)

    # Make sure that all tensors are on the same device before stacking.
    # NOTE: `move`, like `stack`, is assumed to come from
    # `sequoia.utils.generic_functions`.
    assert self.device is not None
    episode_representations = [
        move(item, device=self.device) for item in episode_representations
    ]
    episode_actions = [move(item, device=self.device) for item in episode_actions]
    episode_rewards = [move(item, device=self.device) for item in episode_rewards]

    stacked_inputs = stack(episode_representations)
    stacked_actions = stack(episode_actions)
    stacked_rewards = stack(episode_rewards)
    return stacked_inputs, stacked_actions, stacked_rewards
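# Minimal standalone sketch of the buffering-and-stacking pattern above, using
# plain tensors. `EnvBuffers` is a hypothetical stand-in for this illustration,
# and `torch.stack` is used in place of Sequoia's generic `stack`; this is not
# the actual buffer implementation.
import torch


class EnvBuffers:
    def __init__(self, num_envs: int):
        # One list of per-step reward tensors for each environment.
        self.rewards = [[] for _ in range(num_envs)]

    def append(self, env_index: int, reward: torch.Tensor) -> None:
        self.rewards[env_index].append(reward)

    def stack_rewards(self, env_index: int) -> torch.Tensor:
        episode_rewards = tuple(self.rewards[env_index])
        assert len(episode_rewards)
        # torch.stack adds a new leading "time" dimension: [T, ...]
        return torch.stack(episode_rewards)


buffers = EnvBuffers(num_envs=2)
for step in range(3):
    buffers.append(0, torch.tensor(float(step)))
print(buffers.stack_rewards(0))  # tensor([0., 1., 2.])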
@classmethod
def stack(cls: Type[B], items: List[B]) -> B:
    """ Stack a list of `cls` instances into a single, batched instance.
    """
    items = list(items)
    # NOTE: Imported here rather than at module level, presumably to avoid a
    # circular import.
    from sequoia.utils.generic_functions import stack

    # Just to make sure that the returned item will be of the type `cls`.
    assert isinstance(items[0], cls)
    return stack(items)
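# Minimal sketch of what a generic `stack` like the one in
# `sequoia.utils.generic_functions` can do: stack tensors directly, and stack
# dataclass-style batch objects field by field, returning an instance of the
# same type. The real function is dispatch-based and handles more types; this
# simplified version is an assumption made for illustration only.
from dataclasses import dataclass, fields, is_dataclass
from typing import List, TypeVar

import torch
from torch import Tensor

ItemT = TypeVar("ItemT")


def generic_stack(items: List[ItemT], dim: int = 0) -> ItemT:
    first = items[0]
    if isinstance(first, Tensor):
        return torch.stack(list(items), dim=dim)
    if is_dataclass(first):
        # Recursively stack each field, then rebuild an instance of the same
        # class, so the result has the same type as the inputs.
        return type(first)(**{
            f.name: generic_stack([getattr(item, f.name) for item in items], dim=dim)
            for f in fields(first)
        })
    raise NotImplementedError(f"Don't know how to stack {type(first)} items.")


@dataclass
class Rewards:
    y: Tensor


stacked = generic_stack([Rewards(y=torch.tensor(0.5)), Rewards(y=torch.tensor(1.0))])
print(stacked.y)  # tensor([0.5000, 1.0000])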
def task_inference_forward_pass(self, observations: Observations) -> ForwardPass:
    """ Forward pass with a simple form of task inference.
    """
    # We don't have access to task labels (`task_labels` is None).
    # --> Perform a simple kind of task inference:
    # 1. Perform a forward pass with each task's output head;
    # 2. Merge these predictions into a single prediction somehow.
    assert observations.task_labels is None

    # NOTE: This assumes that the observations are batched.
    # These are used below to indicate the shape of the different tensors.
    B = observations.x.shape[0]
    T = n_known_tasks = len(self.output_heads)
    N = self.action_space.n

    # Tasks encountered previously and for which we have an output head.
    known_task_ids: list[int] = list(range(n_known_tasks))
    assert known_task_ids

    # Placeholder for the predictions from each output head for each item in
    # the batch.
    task_outputs = [None for _ in known_task_ids]  # [T, B, N]

    # Get the forward pass for each task.
    for task_id in known_task_ids:
        # Create 'fake' Observations for this forward pass, with 'fake' task
        # labels.
        # NOTE: We do this so we can call `self.forward` and not get an
        # infinite recursion.
        task_labels = torch.full([B], task_id, device=self.device, dtype=int)
        task_observations = replace(observations, task_labels=task_labels)
        # Set up the model for task `task_id`, and then do a forward pass.
        task_forward_pass = self.forward(task_observations)
        task_outputs[task_id] = task_forward_pass

    # 'Merge' the predictions from each output head using some kind of task
    # inference.
    assert all(item is not None for item in task_outputs)

    # Stack the predictions (logits) from each output head.
    stacked_forward_pass: ForwardPass = stack(task_outputs, dim=1)
    logits_from_each_head = stacked_forward_pass.actions.logits
    assert logits_from_each_head.shape == (B, T, N), (
        logits_from_each_head.shape,
        (B, T, N),
    )

    # Normalize the logits from each output head with a softmax.
    # Example with a batch size of 1, 2 output heads, and 4 classes:
    # logits from each head:  [[[123, 456, 123, 123], [1, 1, 2, 1]]]
    # 'probs' from each head: [[[0.1, 0.6, 0.1, 0.1], [0.2, 0.2, 0.4, 0.2]]]
    probs_from_each_head = torch.softmax(logits_from_each_head, dim=-1)
    assert probs_from_each_head.shape == (B, T, N)

    # Simple kind of task inference: for each item in the batch, use the class
    # that has the highest probability across all output heads.
    max_probs_across_heads, chosen_head_per_class = probs_from_each_head.max(dim=1)
    assert max_probs_across_heads.shape == (B, N)
    assert chosen_head_per_class.shape == (B, N)
    # Example (continued):
    # max probs across heads:        [[0.2, 0.6, 0.4, 0.2]]
    # chosen output heads per class: [[1, 0, 1, 1]]

    # Determine which output head has the highest "confidence":
    max_prob_value, most_probable_class = max_probs_across_heads.max(dim=1)
    assert max_prob_value.shape == (B,)
    assert most_probable_class.shape == (B,)
    # Example (continued):
    # max_prob_value:      [0.6]
    # most_probable_class: [1]

    # A bit of boolean trickery to get what we need: for each item, the index
    # of the output head that gave the most confident prediction.
    mask = F.one_hot(most_probable_class, N).to(dtype=bool, device=self.device)
    chosen_output_head_per_item = chosen_head_per_class[mask]
    assert mask.shape == (B, N)
    assert chosen_output_head_per_item.shape == (B,)
    # Example (continued):
    # mask:                        [[False, True, False, False]]
    # chosen_output_head_per_item: [0]

    # Create a bool tensor to select items associated with the chosen output
    # head.
    selected_mask = F.one_hot(chosen_output_head_per_item, T).to(
        dtype=bool, device=self.device
    )
    assert selected_mask.shape == (B, T)

    # Select the logits using the mask:
    selected_forward_pass = stacked_forward_pass[selected_mask]
    assert selected_forward_pass.actions.logits.shape == (B, N)
    return selected_forward_pass
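# Standalone numeric walkthrough of the merge step above, using the same values
# as the worked example in the comments. Everything here is hardcoded for
# illustration; only the tensor operations mirror the actual method.
import torch
import torch.nn.functional as F

B, T, N = 1, 2, 4  # batch size, known tasks (output heads), classes

probs_from_each_head = torch.tensor(
    [[[0.1, 0.6, 0.1, 0.1],    # 'probs' from output head 0
      [0.2, 0.2, 0.4, 0.2]]]   # 'probs' from output head 1
)  # shape [B, T, N]

# For each class, the best probability across heads, and which head gave it:
max_probs_across_heads, chosen_head_per_class = probs_from_each_head.max(dim=1)
print(max_probs_across_heads)  # tensor([[0.2000, 0.6000, 0.4000, 0.2000]])
print(chosen_head_per_class)   # tensor([[1, 0, 1, 1]])

# The single most confident class prediction for each item in the batch:
max_prob_value, most_probable_class = max_probs_across_heads.max(dim=1)
print(most_probable_class)     # tensor([1])

# The output head that produced that most confident prediction:
mask = F.one_hot(most_probable_class, N).to(dtype=torch.bool)
chosen_output_head_per_item = chosen_head_per_class[mask]
print(chosen_output_head_per_item)  # tensor([0])

# Boolean mask over heads, used to slice out that head's forward pass:
selected_mask = F.one_hot(chosen_output_head_per_item, T).to(dtype=torch.bool)
print(selected_mask)           # tensor([[ True, False]])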