Example #1
    def forward(self, observations: ContinualRLSetting.Observations,
                representations: Tensor) -> PolicyHeadOutput:
        """ Forward pass of a Policy head.

        TODO: Do we actually need the observations here? They are here so we
        have access to the 'done' from the env, but do we really need that
        here, or would there be another (cleaner) way to do this?
        """
        if len(representations.shape) < 2:
            # Flatten the representations.
            representations = representations.reshape(
                [-1, flatdim(self.input_space)])

        # Setup the buffers, which will hold the most recent observations,
        # actions and rewards within the current episode for each environment.
        if not self.batch_size:
            self.batch_size = representations.shape[0]
            self.create_buffers()

        representations = representations.float()

        logits = self.dense(representations)

        # The policy is the distribution over actions given the current state.
        action_dist = Categorical(logits=logits)
        sample = action_dist.sample()
        actions = PolicyHeadOutput(
            y_pred=sample,
            logits=logits,
            action_dist=action_dist,
        )
        return actions
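The essential pattern in this forward pass is: flatten the representations, project them to one logit per action, wrap the logits in a Categorical, and sample. Below is a minimal standalone sketch of just that pattern; the layer sizes and variable names are illustrative and not taken from the head above.

import torch
from torch import nn
from torch.distributions import Categorical

# Illustrative sizes: 2 parallel environments, 8-dim representations, 4 actions.
dense = nn.Linear(8, 4)

representations = torch.randn(2, 8).float()
logits = dense(representations)

# The policy is the distribution over actions given the current representations.
action_dist = Categorical(logits=logits)
sample = action_dist.sample()              # shape [2]: one action per environment
log_probs = action_dist.log_prob(sample)   # typically kept for the policy-gradient loss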
Example #2
def _stack_distributions(
    first_item: Categorical, *others: Categorical, **kwargs
) -> Categorical:
    """Stack the logits of the given distributions into a single Categorical.

    Extra keyword arguments (e.g. `dim`) are forwarded to `torch.stack`.
    """
    return Categorical(
        logits=torch.stack(
            [first_item.logits, *(other.logits for other in others)], **kwargs
        )
    )
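To see what this helper produces, here is a small usage sketch; it assumes `_stack_distributions` from the example above is in scope, and the shapes follow from `torch.stack` with `dim=0`.

import torch
from torch.distributions import Categorical

a = Categorical(logits=torch.zeros(5, 5))   # 5 distributions over 5 actions
b = Categorical(logits=torch.zeros(5, 5))

stacked = _stack_distributions(a, b, dim=0)
print(stacked.logits.shape)   # torch.Size([2, 5, 5])
print(stacked.batch_shape)    # torch.Size([2, 5]); the last logits dim is the action dim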
Example #3
def _set_slice_categorical(target: Categorical, indices: Sequence[int],
                           values: Categorical) -> None:
    """In-place: overwrite the logits of `target` at `indices` with those of `values`."""
    target.logits[indices] = values.logits
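For illustration, the in-place slice assignment can be exercised as follows; it assumes `_set_slice_categorical` from the example above is in scope, and the values are arbitrary. Note that PyTorch's Categorical normalizes the logits it stores, so the copied row is the normalized version.

import torch
from torch.distributions import Categorical

target = Categorical(logits=torch.zeros(4, 3))                # 4 distributions over 3 actions
values = Categorical(logits=torch.tensor([[1.0, 2.0, 3.0]]))  # single replacement row

_set_slice_categorical(target, [2], values)
print(target.logits[2])   # row 2 now holds the (normalized) logits of `values`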
Example #4
            task_labels=None,
        ),
        Observations(
            x=torch.tensor([5, 6, 7, 8, 9], dtype=int),
            task_labels=None,
        )
    ],
     Observations(
         x=torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]], dtype=int),
         task_labels=np.array([None, None]),
     )),
    (
        [
            RLActions(
                y_pred=torch.tensor([0, 1, 2, 3, 4], dtype=int),
                action_dist=Categorical(
                    logits=torch.ones([5, 5], dtype=float) / 5),
            ),
            RLActions(
                y_pred=torch.tensor([0, 1, 2, 3, 4], dtype=int),
                action_dist=Categorical(
                    logits=torch.ones([5, 5], dtype=float) / 5),
            ),
        ],
        RLActions(
            y_pred=torch.tensor([[0, 1, 2, 3, 4], [0, 1, 2, 3, 4]], dtype=int),
            action_dist=Categorical(logits=torch.ones([2, 5, 5], dtype=float) /
                                    5),
        ),
    ),
])
def test_stack(items: List[Batch], expected: Batch):
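Each parametrized case above pairs a list of per-environment batches with the single batch expected after stacking them. As a small sanity check of the tensor part of the RLActions case, using only torch.stack (the Batch / RLActions machinery itself is not reproduced here):

import torch

y_preds = [torch.tensor([0, 1, 2, 3, 4], dtype=int),
           torch.tensor([0, 1, 2, 3, 4], dtype=int)]
stacked = torch.stack(y_preds)
print(stacked)          # tensor([[0, 1, 2, 3, 4], [0, 1, 2, 3, 4]])
print(stacked.shape)    # torch.Size([2, 5])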