def forward(self, observations: ContinualRLSetting.Observations, representations: Tensor) -> PolicyHeadOutput:
    """Forward pass of a Policy head.

    Produces a categorical action distribution from the encoded
    `representations` and samples an action from it.

    TODO: Do we actually need the observations here? It is here so we
    have access to the 'done' from the env, but do we really need it
    here? or would there be another (cleaner) way to do this?
    """
    # Collapse lower-rank feature tensors into a [batch, features] matrix.
    if representations.ndim < 2:
        representations = representations.reshape(
            [-1, flatdim(self.input_space)])

    # Lazily size the per-environment episode buffers (most recent
    # observations / actions / rewards) on the first call.
    if not self.batch_size:
        self.batch_size = representations.shape[0]
        self.create_buffers()

    logits = self.dense(representations.float())
    # The policy is the distribution over actions given the current state.
    action_dist = Categorical(logits=logits)
    return PolicyHeadOutput(
        y_pred=action_dist.sample(),
        logits=logits,
        action_dist=action_dist,
    )
def _stack_distributions( first_item: Categorical, *others: Categorical, **kwargs ) -> Categorical: return Categorical( logits=torch.stack( [first_item.logits, *(other.logits for other in others)], **kwargs ) )
def _set_slice_categorical(target: Categorical, indices: Sequence[int], values: Sequence[Any]) -> None: target.logits[indices] = values.logits
task_labels=None, ), Observations( x=torch.tensor([5, 6, 7, 8, 9], dtype=int), task_labels=None, ) ], Observations( x=torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]], dtype=int), task_labels=np.array([None, None]), )), ( [ RLActions( y_pred=torch.tensor([0, 1, 2, 3, 4], dtype=int), action_dist=Categorical( logits=torch.ones([5, 5], dtype=float) / 5), ), RLActions( y_pred=torch.tensor([0, 1, 2, 3, 4], dtype=int), action_dist=Categorical( logits=torch.ones([5, 5], dtype=float) / 5), ), ], RLActions( y_pred=torch.tensor([[0, 1, 2, 3, 4], [0, 1, 2, 3, 4]], dtype=int), action_dist=Categorical(logits=torch.ones([2, 5, 5], dtype=float) / 5), ), ), ]) def test_stack(items: List[Batch], expected: Batch):