Example 1
    def imagine(self,
                action: TensorType,
                state: List[TensorType] = None) -> List[TensorType]:
        """Imagines the trajectory starting from state through a list of actions.
        Like observe(), this requires rolling out the RNN for each timestep.

        Args:
            action (TensorType): Actions
            state (List[TensorType]): Starting state before rollout

        Returns:
            Prior states (List[TensorType])
        """
        if state is None:
            state = self.get_initial_state(action.size()[0])

        action = action.permute(1, 0, 2)

        priors = [[] for _ in range(len(state))]
        last = state
        for index in range(len(action)):
            last = self.img_step(last, action[index])
            for s, o in zip(last, priors):
                o.append(s)

        prior = [torch.stack(x, dim=0) for x in priors]
        prior = [e.permute(1, 0, 2) for e in prior]
        return prior
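A standalone, runnable sketch of the rollout pattern above; the sizes, the single-tensor state, and the identity stand-in for img_step are illustrative assumptions, not taken from the source:

    import torch

    # Permute to time-major, step once per timestep, then stack the per-step
    # outputs and permute back to batch-major.
    B, T, A, F = 16, 10, 4, 30                       # assumed batch/time/action/feature sizes
    action = torch.zeros(B, T, A).permute(1, 0, 2)   # -> (T, B, A), time-major
    last = [torch.zeros(B, F)]                       # one state tensor; real states hold several
    priors = [[] for _ in last]
    for index in range(len(action)):
        last = [s.clone() for s in last]             # stand-in for self.img_step(last, action[index])
        for s, o in zip(last, priors):
            o.append(s)
    prior = [torch.stack(x, dim=0).permute(1, 0, 2) for x in priors]
    assert prior[0].shape == (B, T, F)               # back to (batch, time, feature)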
Example 2
    def forward(self, inputs: TensorType) -> TensorType:
        L = list(inputs.size())[1]  # length of segment
        H = self._num_heads  # number of attention heads
        D = self._head_dim  # attention head dimension

        qkv = self._qkv_layer(inputs)

        queries, keys, values = torch.chunk(input=qkv, chunks=3, dim=-1)
        queries = queries[:, -L:]  # only query based on the segment

        queries = torch.reshape(queries, [-1, L, H, D])
        keys = torch.reshape(keys, [-1, L, H, D])
        values = torch.reshape(values, [-1, L, H, D])

        # Per-head scaled dot-product scores, shape (batch, query_pos, key_pos, heads).
        score = torch.einsum("bihd,bjhd->bijh", queries, keys)
        score = score / D**0.5

        # causal mask of the same length as the sequence
        mask = sequence_mask(torch.arange(1, L + 1), dtype=score.dtype)
        mask = mask[None, :, :, None]
        mask = mask.float()

        masked_score = score * mask + 1e30 * (mask - 1.)
        wmat = nn.functional.softmax(masked_score, dim=2)

        out = torch.einsum("bijh,bjhd->bihd", wmat, values)
        shape = list(out.size())[:2] + [H * D]
        out = torch.reshape(out, shape)
        return self._linear_layer(out)
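The masking arithmetic above is the usual large-negative trick: where the causal mask is 1 the score passes through unchanged, and where it is 0 the term 1e30 * (mask - 1.) pushes the logit to -1e30, which softmax maps to a zero weight. A runnable standalone illustration, using torch.tril in place of the sequence_mask helper used above:

    import torch

    L = 4
    mask = torch.tril(torch.ones(L, L))              # 1 where key_pos <= query_pos
    score = torch.randn(L, L)
    masked = score * mask + 1e30 * (mask - 1.)       # masked entries become -1e30
    probs = torch.softmax(masked, dim=-1)
    assert torch.allclose(probs, probs.tril())       # no weight on future positions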
Example 3
    def observe(
        self,
        embed: TensorType,
        action: TensorType,
        state: List[TensorType] = None
    ) -> Tuple[List[TensorType], List[TensorType]]:
        """Returns the corresponding states from the embedding from ConvEncoder
        and actions. This is accomplished by rolling out the RNN from the
        starting state through each index of embed and action, saving all
        intermediate states between.

        Args:
            embed (TensorType): ConvEncoder embedding
            action (TensorType): Actions
            state (List[TensorType]): Initial state before rollout

        Returns:
            Posterior states and prior states (both List[TensorType])
        """
        if state is None:
            state = self.get_initial_state(action.size()[0])

        if embed.dim() <= 2:
            embed = torch.unsqueeze(embed, 1)

        if action.dim() <= 2:
            action = torch.unsqueeze(action, 1)

        embed = embed.permute(1, 0, 2)
        action = action.permute(1, 0, 2)

        priors = [[] for _ in range(len(state))]
        posts = [[] for _ in range(len(state))]
        last = (state, state)
        for index in range(len(action)):
            # obs_step returns a (posterior, prior) tuple of state lists.
            last = self.obs_step(last[0], action[index], embed[index])
            for s, o in zip(last[0], posts):
                o.append(s)
            for s, o in zip(last[1], priors):
                o.append(s)

        prior = [torch.stack(x, dim=0) for x in priors]
        post = [torch.stack(x, dim=0) for x in posts]

        prior = [e.permute(1, 0, 2) for e in prior]
        post = [e.permute(1, 0, 2) for e in post]

        return post, prior
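One detail worth isolating is the dim() <= 2 branch: a single-step input without a time axis gains one of length 1 before the time-major permute. A runnable sketch, where the embedding size of 1024 is an assumption:

    import torch

    embed = torch.zeros(16, 1024)                    # (batch, embed), no time axis
    if embed.dim() <= 2:
        embed = torch.unsqueeze(embed, 1)            # -> (batch, 1, embed)
    assert embed.permute(1, 0, 2).shape == (1, 16, 1024)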
Example 4
    def _predict_next_obs(self, obs: TensorType, action: TensorType):
        """
        Returns the predicted next state, given an action and state.

        obs (TensorType): Observed state at time t.
        action (TensorType): Action taken at time t
        """
        return self.forward_model(
            torch.cat((self._get_latent_vector(obs), action.unsqueeze(1)),
                      dim=-1))
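A shape sketch of the concatenation; the latent size of 288 is an assumption, and action.unsqueeze(1) implies one scalar action per batch element:

    import torch

    latent = torch.zeros(32, 288)                    # stand-in for self._get_latent_vector(obs)
    action = torch.zeros(32)                         # one scalar action per batch element
    x = torch.cat((latent, action.unsqueeze(1)), dim=-1)
    assert x.shape == (32, 289)                      # forward_model input: latent plus action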
Example 5
    def forward(self, input_dict: Dict[str, TensorType],
                state: List[TensorType],
                seq_lens: TensorType) -> Tuple[TensorType, List[TensorType]]:
        nbr_agents = self.nbr_agents
        cell_size = self.gru_cell_size
        obs = input_dict[SampleBatch.OBS]['obs']
        B = obs.shape[0]
        h = state[0]
        R = h.shape[0]
        max_T = seq_lens.max().item()
        obs = add_time_dimension(obs,
                                 max_seq_len=max_T,
                                 framework=self.framework,
                                 time_major=self.is_time_major())

        # One-hot agent IDs, expanded to (T, R, n_agents, n_agents) and
        # appended to each agent's observation below.
        agent_indexes = torch.eye(n=nbr_agents, dtype=h.dtype,
                                  device=h.device).unsqueeze(0).unsqueeze(0)
        agent_indexes = agent_indexes.expand(max_T, R, -1, -1)
        x = torch.cat([obs, agent_indexes], dim=-1)
        x = self.stage1(x)
        # Fold agents into the batch dimension so a single GRU pass handles
        # all agents in parallel.
        x = x.view(max_T, R * nbr_agents, -1)
        h = h.view(1, R * nbr_agents, cell_size)
        mems, h = self.gru(x, h)
        # Unfold back to per-agent shapes.
        h = h.view(R, nbr_agents, cell_size)
        mems = mems.view(max_T, R, nbr_agents, cell_size)

        output = self.stage2(mems)

        if self.has_avail_actions:
            avail_actions = add_time_dimension(
                input_dict[SampleBatch.OBS]['avail_actions'],
                max_seq_len=max_T,
                framework=self.framework,
                time_major=self.is_time_major())
            avail_actions = avail_actions.view(max_T, R, nbr_agents,
                                               self.nbr_actions)
            inf_mask = torch.clamp(torch.log(avail_actions), FLOAT_MIN,
                                   FLOAT_MAX)
            output = output + inf_mask
        output = output.view(B, self.num_outputs)

        return output, [h]
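The inf_mask step above is the standard log-of-mask trick: log(1) = 0 leaves an available action's logit unchanged, while log(0) = -inf is clamped to FLOAT_MIN, effectively removing unavailable actions from the softmax. A runnable standalone illustration, where the constant values are assumed stand-ins for the FLOAT_MIN/FLOAT_MAX used above:

    import torch

    FLOAT_MIN, FLOAT_MAX = -3.4e38, 3.4e38           # assumed values of the constants above
    avail_actions = torch.tensor([[1., 0., 1.]])     # action 1 is unavailable
    logits = torch.tensor([[0.5, 2.0, 0.1]])
    inf_mask = torch.clamp(torch.log(avail_actions), FLOAT_MIN, FLOAT_MAX)
    probs = torch.softmax(logits + inf_mask, dim=-1)
    assert probs[0, 1] == 0.0                        # masked action gets zero probability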