Example #1
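This forward() implements causal multi-head self-attention in the style of RLlib's PyTorch attention nets. The snippet assumes it sits inside an nn.Module whose __init__ defined self._num_heads, self._head_dim, a fused QKV projection self._qkv_layer, and an output projection self._linear_layer, and that torch, torch.nn as nn, and a sequence_mask helper (RLlib ships one, e.g. in ray.rllib.utils.torch_utils) are in scope.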
    def forward(self, inputs: TensorType) -> TensorType:
        L = inputs.size(1)  # length of the input segment (time dimension)
        H = self._num_heads  # number of attention heads
        D = self._head_dim  # attention head dimension

        qkv = self._qkv_layer(inputs)

        queries, keys, values = torch.chunk(input=qkv, chunks=3, dim=-1)
        # Only the last L timesteps act as queries; every timestep still
        # serves as a key and value.
        queries = queries[:, -L:]

        queries = torch.reshape(queries, [-1, L, H, D])
        keys = torch.reshape(keys, [-1, L, H, D])
        values = torch.reshape(values, [-1, L, H, D])

        score = torch.einsum("bihd,bjhd->bijh", queries, keys)
        score = score / D**0.5

        # Causal (lower-triangular) mask: query i may attend to keys 0..i.
        # sequence_mask already returns score.dtype, so no extra cast needed.
        mask = sequence_mask(torch.arange(1, L + 1), dtype=score.dtype)
        mask = mask[None, :, :, None]  # broadcast over batch and heads

        # Keep scores where mask == 1; push masked positions toward -1e30 so
        # they contribute ~0 after the softmax over the key axis (dim=2).
        masked_score = score * mask + 1e30 * (mask - 1.0)
        wmat = nn.functional.softmax(masked_score, dim=2)

        out = torch.einsum("bijh,bjhd->bihd", wmat, values)
        # Merge the head and head-dim axes back into one feature axis.
        shape = list(out.size())[:2] + [H * D]
        out = torch.reshape(out, shape)
        return self._linear_layer(out)
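For an end-to-end view of the same computation, here is a self-contained functional sketch. It mirrors the method above under the assumption that queries, keys, and values all span the same L timesteps (no memory), and it swaps the sequence_mask helper for torch.tril; every name below (causal_mha, the layer sizes) is illustrative, not from the original.

import torch
from torch import nn


def causal_mha(inputs: torch.Tensor, qkv: nn.Linear, out_proj: nn.Linear,
               num_heads: int, head_dim: int) -> torch.Tensor:
    """Functional sketch of the causal attention above (illustrative only)."""
    B, L, _ = inputs.shape
    H, D = num_heads, head_dim

    q, k, v = torch.chunk(qkv(inputs), chunks=3, dim=-1)
    q = q.reshape(B, L, H, D)
    k = k.reshape(B, L, H, D)
    v = v.reshape(B, L, H, D)

    # Scaled dot-product scores: [batch, query, key, head].
    score = torch.einsum("bihd,bjhd->bijh", q, k) / D ** 0.5

    # Lower-triangular causal mask in place of sequence_mask.
    mask = torch.tril(torch.ones(L, L, dtype=score.dtype))[None, :, :, None]
    masked = score * mask + 1e30 * (mask - 1.0)
    wmat = torch.softmax(masked, dim=2)

    # Weighted sum of values, then merge heads into one feature axis.
    out = torch.einsum("bijh,bjhd->bihd", wmat, v).reshape(B, L, H * D)
    return out_proj(out)


H, D = 4, 8
x = torch.randn(2, 10, 16)                       # [batch, time, features]
y = causal_mha(x, nn.Linear(16, 3 * H * D), nn.Linear(H * D, 16), H, D)
print(y.shape)                                   # torch.Size([2, 10, 16])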
Example #2
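imagine() rolls out a latent dynamics model open-loop (Dreamer-style): starting from state, it applies self.img_step once per action and collects the resulting prior states. The snippet assumes from typing import List, Optional, a TensorType alias, and get_initial_state() / img_step() defined on the same class.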
    def imagine(self,
                action: TensorType,
                state: Optional[List[TensorType]] = None
                ) -> List[TensorType]:
        """Imagines the trajectory starting from state through a list of actions.
        Similar to observe(), requires rolling out the RNN for each timestep.

        Args:
            action (TensorType): Batch-major sequence of actions,
                shaped [batch, time, ...].
            state (List[TensorType]): Starting state before the rollout.

        Returns:
            The prior states, one stacked tensor per state component.
        """
        if state is None:
            state = self.get_initial_state(action.size()[0])

        action = action.permute(1, 0, 2)  # [batch, time, ...] -> [time, batch, ...]

        priors = [[] for _ in range(len(state))]
        last = state
        for act in action:  # iterate over timesteps
            last = self.img_step(last, act)
            # Collect each state component of the new prior.
            for s, o in zip(last, priors):
                o.append(s)

        # Stack along time, then swap back to batch-major ordering.
        prior = [torch.stack(x, dim=0) for x in priors]
        prior = [e.permute(1, 0, 2) for e in prior]
        return prior
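A toy rollout showing the same collect-then-stack pattern with a dummy one-tensor state. ToyDynamics and its trivial img_step are illustrative stand-ins, not the original model:

import torch


class ToyDynamics:
    """Dummy stand-in for the RSSM: state is a single [batch, dim] tensor."""

    def get_initial_state(self, batch: int):
        return [torch.zeros(batch, 3)]

    def img_step(self, state, action):
        # Trivial "dynamics": fold the action into the state.
        return [state[0] + action]

    def imagine(self, action, state=None):
        if state is None:
            state = self.get_initial_state(action.size(0))
        action = action.permute(1, 0, 2)      # [batch, T, a] -> [T, batch, a]
        priors = [[] for _ in range(len(state))]
        last = state
        for act in action:
            last = self.img_step(last, act)
            for s, o in zip(last, priors):
                o.append(s)
        # Stack along time, then swap back to batch-major ordering.
        return [torch.stack(x, dim=0).permute(1, 0, 2) for x in priors]


model = ToyDynamics()
actions = torch.randn(2, 5, 3)                # [batch, T, action_dim]
prior = model.imagine(actions)
print(prior[0].shape)                         # torch.Size([2, 5, 3])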
Example #3
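observe() is the closed-loop counterpart to imagine(): at each timestep, self.obs_step consumes the previous posterior, the action, and the ConvEncoder embedding, and returns a (posterior, prior) pair; both sequences are collected and stacked. The same context is assumed as for imagine(): typing imports (List, Optional, Tuple), a TensorType alias, and get_initial_state() / obs_step() on the class.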
    def observe(
        self,
        embed: TensorType,
        action: TensorType,
        state: Optional[List[TensorType]] = None
    ) -> Tuple[List[TensorType], List[TensorType]]:
        """Returns the corresponding states from the embedding from ConvEncoder
        and actions. This is accomplished by rolling out the RNN from the
        starting state through each index of embed and action, saving all
        intermediate states between.

        Args:
            embed (TensorType): ConvEncoder embedding
            action (TensorType): Actions
            state (List[TensorType]): Initial state before rollout

        Returns:
            Posterior states and prior states (both List[TensorType])
        """
        if state is None:
            state = self.get_initial_state(action.size()[0])

        # Ensure a time dimension exists: [batch, feat] -> [batch, 1, feat].
        if embed.dim() <= 2:
            embed = torch.unsqueeze(embed, 1)

        if action.dim() <= 2:
            action = torch.unsqueeze(action, 1)

        # [batch, time, ...] -> [time, batch, ...] for the rollout loop.
        embed = embed.permute(1, 0, 2)
        action = action.permute(1, 0, 2)

        priors = [[] for _ in range(len(state))]
        posts = [[] for _ in range(len(state))]
        last = (state, state)  # (posterior, prior) pair
        for index in range(len(action)):
            # obs_step returns a tuple of (posterior, prior) state lists.
            last = self.obs_step(last[0], action[index], embed[index])
            for s, o in zip(last[0], posts):
                o.append(s)
            for s, o in zip(last[1], priors):
                o.append(s)

        # Stack along time, then swap back to batch-major ordering.
        prior = [torch.stack(x, dim=0) for x in priors]
        post = [torch.stack(x, dim=0) for x in posts]

        prior = [e.permute(1, 0, 2) for e in prior]
        post = [e.permute(1, 0, 2) for e in post]

        return post, prior
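Extending the toy model above with an obs_step shows the same (posterior, prior) bookkeeping end to end. ToyObserver and its predict/correct arithmetic are illustrative only:

import torch


class ToyObserver:
    """Dummy stand-in: state is a single [batch, dim] tensor."""

    def get_initial_state(self, batch: int):
        return [torch.zeros(batch, 3)]

    def obs_step(self, prev_post, action, embed):
        prior = [prev_post[0] + action]       # "predict" from the action
        post = [prior[0] + embed]             # "correct" with the embedding
        return post, prior

    def observe(self, embed, action, state=None):
        if state is None:
            state = self.get_initial_state(action.size(0))
        embed = embed.permute(1, 0, 2)        # -> [T, batch, e]
        action = action.permute(1, 0, 2)      # -> [T, batch, a]
        priors = [[] for _ in range(len(state))]
        posts = [[] for _ in range(len(state))]
        last = (state, state)
        for t in range(len(action)):
            last = self.obs_step(last[0], action[t], embed[t])
            for s, o in zip(last[0], posts):
                o.append(s)
            for s, o in zip(last[1], priors):
                o.append(s)
        post = [torch.stack(x, dim=0).permute(1, 0, 2) for x in posts]
        prior = [torch.stack(x, dim=0).permute(1, 0, 2) for x in priors]
        return post, prior


model = ToyObserver()
embeds = torch.randn(2, 5, 3)                 # [batch, T, embed_dim]
actions = torch.randn(2, 5, 3)                # [batch, T, action_dim]
post, prior = model.observe(embeds, actions)
print(post[0].shape, prior[0].shape)          # both torch.Size([2, 5, 3])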