from typing import Dict, List, Tuple

import torch
from torch import nn

# NOTE: the helpers below are assumed to come from RLlib; exact module paths
# vary across Ray versions, so treat these imports as a best guess.
from ray.rllib.policy.rnn_sequencing import add_time_dimension
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.framework import TensorType
from ray.rllib.utils.torch_utils import FLOAT_MAX, FLOAT_MIN, sequence_mask


def imagine(self, action: TensorType,
            state: List[TensorType] = None) -> List[TensorType]:
    """Imagines the trajectory starting from state through a list of actions.

    Similar to observe(), requires rolling out the RNN for each timestep.

    Args:
        action (TensorType): Actions.
        state (List[TensorType]): Starting state before rollout.

    Returns:
        Prior states.
    """
    if state is None:
        state = self.get_initial_state(action.size()[0])

    # Switch to time-major layout: (T, B, action_dim).
    action = action.permute(1, 0, 2)

    priors = [[] for _ in range(len(state))]
    last = state
    for index in range(len(action)):
        last = self.img_step(last, action[index])
        for s, o in zip(last, priors):
            o.append(s)

    # Stack the per-timestep tensors and switch back to batch-major.
    prior = [torch.stack(x, dim=0) for x in priors]
    prior = [e.permute(1, 0, 2) for e in prior]
    return prior
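# A standalone, shape-only sketch of the time-major bookkeeping imagine()
# relies on: permute to (T, B, D), collect one tensor per step, stack, then
# permute back. All sizes are hypothetical.
B, T, D = 4, 10, 6
demo_action = torch.zeros(B, T, D)
demo_action = demo_action.permute(1, 0, 2)          # (T, B, D)
steps = [demo_action[t] for t in range(T)]          # T tensors of shape (B, D)
stacked = torch.stack(steps, dim=0)                 # (T, B, D)
assert stacked.permute(1, 0, 2).shape == (B, T, D)  # back to batch-major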
def forward(self, inputs: TensorType) -> TensorType:
    L = list(inputs.size())[1]  # length of segment
    H = self._num_heads  # number of attention heads
    D = self._head_dim  # attention head dimension

    qkv = self._qkv_layer(inputs)
    queries, keys, values = torch.chunk(input=qkv, chunks=3, dim=-1)
    queries = queries[:, -L:]  # only query based on the segment

    queries = torch.reshape(queries, [-1, L, H, D])
    keys = torch.reshape(keys, [-1, L, H, D])
    values = torch.reshape(values, [-1, L, H, D])

    score = torch.einsum("bihd,bjhd->bijh", queries, keys)
    score = score / D**0.5

    # Causal mask of the same length as the sequence: position i may only
    # attend to positions j <= i. Masked-out logits are pushed to -1e30 so
    # they vanish under the softmax.
    mask = sequence_mask(torch.arange(1, L + 1), dtype=score.dtype)
    mask = mask[None, :, :, None]
    mask = mask.float()
    masked_score = score * mask + 1e30 * (mask - 1.)

    wmat = nn.functional.softmax(masked_score, dim=2)

    out = torch.einsum("bijh,bjhd->bihd", wmat, values)
    # Fold the (heads, head_dim) axes back into a single feature axis.
    shape = list(out.size())[:2] + [H * D]
    out = torch.reshape(out, shape)
    return self._linear_layer(out)
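# A standalone demo of the additive causal-mask trick used above (2-D here,
# where the model masks the j axis of "bijh"): masked logits land at -1e30,
# so the softmax assigns them zero weight.
L = 3
demo_mask = torch.tril(torch.ones(L, L))  # same pattern as sequence_mask(arange(1, L + 1))
demo_score = torch.randn(L, L)
demo_masked = demo_score * demo_mask + 1e30 * (demo_mask - 1.)
demo_wmat = torch.softmax(demo_masked, dim=-1)
assert torch.allclose(demo_wmat.triu(1), torch.zeros(L, L))  # no weight on the future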
def observe(self,
            embed: TensorType,
            action: TensorType,
            state: List[TensorType] = None
            ) -> Tuple[List[TensorType], List[TensorType]]:
    """Returns the corresponding states from the ConvEncoder embedding and
    actions. This is accomplished by rolling out the RNN from the starting
    state through each index of embed and action, saving all intermediate
    states in between.

    Args:
        embed (TensorType): ConvEncoder embedding.
        action (TensorType): Actions.
        state (List[TensorType]): Initial state before rollout.

    Returns:
        Posterior states and prior states (both List[TensorType]).
    """
    if state is None:
        state = self.get_initial_state(action.size()[0])

    # Ensure a time axis exists, then switch to time-major layout.
    if embed.dim() <= 2:
        embed = torch.unsqueeze(embed, 1)
    if action.dim() <= 2:
        action = torch.unsqueeze(action, 1)
    embed = embed.permute(1, 0, 2)
    action = action.permute(1, 0, 2)

    priors = [[] for _ in range(len(state))]
    posts = [[] for _ in range(len(state))]
    last = (state, state)
    for index in range(len(action)):
        # obs_step returns a (posterior, prior) tuple.
        last = self.obs_step(last[0], action[index], embed[index])
        for s, o in zip(last[0], posts):
            o.append(s)
        for s, o in zip(last[1], priors):
            o.append(s)

    # Stack the per-timestep tensors and switch back to batch-major.
    prior = [torch.stack(x, dim=0) for x in priors]
    post = [torch.stack(x, dim=0) for x in posts]
    prior = [e.permute(1, 0, 2) for e in prior]
    post = [e.permute(1, 0, 2) for e in post]
    return post, prior
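# A standalone sketch of the single-step path through observe(): inputs that
# arrive without a time axis get a singleton one inserted, so the same
# rollout loop handles both cases. Sizes are hypothetical.
demo_embed = torch.zeros(4, 1024)  # one ConvEncoder output per batch entry
if demo_embed.dim() <= 2:
    demo_embed = torch.unsqueeze(demo_embed, 1)
assert demo_embed.shape == (4, 1, 1024)  # now safe to permute to time-major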
def _predict_next_obs(self, obs: TensorType, action: TensorType):
    """Returns the predicted next state, given an observation and an action.

    Args:
        obs (TensorType): Observed state at time t.
        action (TensorType): Action taken at time t.
    """
    return self.forward_model(
        torch.cat((self._get_latent_vector(obs), action.unsqueeze(1)),
                  dim=-1))
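# A standalone, shape-only sketch of the concatenation feeding forward_model,
# assuming the latent vector carries a singleton time axis (all sizes
# hypothetical).
demo_latent = torch.zeros(4, 1, 128)  # _get_latent_vector(obs): (B, 1, latent_dim)
demo_act = torch.zeros(4, 6)          # (B, action_dim)
demo_joint = torch.cat((demo_latent, demo_act.unsqueeze(1)), dim=-1)
assert demo_joint.shape == (4, 1, 128 + 6)  # forward_model input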
def forward(self, input_dict: Dict[str, TensorType],
            state: List[TensorType],
            seq_lens: TensorType) -> (TensorType, List[TensorType]):
    nbr_agents = self.nbr_agents
    cell_size = self.gru_cell_size

    obs = input_dict[SampleBatch.OBS]['obs']
    B = obs.shape[0]
    h = state[0]
    R = h.shape[0]
    max_T = seq_lens.max().item()

    # Fold the flat batch back into a (time, rollout) layout.
    obs = add_time_dimension(obs,
                             max_seq_len=max_T,
                             framework=self.framework,
                             time_major=self.is_time_major())

    # One-hot agent IDs, broadcast over time and rollouts.
    agent_indexes = torch.eye(n=nbr_agents, dtype=h.dtype,
                              device=h.device).unsqueeze(0).unsqueeze(0)
    agent_indexes = agent_indexes.expand(max_T, R, -1, -1)

    x = torch.cat([obs, agent_indexes], dim=-1)
    x = self.stage1(x)

    # Merge the rollout and agent axes so the GRU treats each agent as its
    # own sequence.
    x = x.view(max_T, R * nbr_agents, -1)
    h = h.view(1, R * nbr_agents, cell_size)
    mems, h = self.gru(x, h)
    h = h.view(R, nbr_agents, cell_size)
    mems = mems.view(max_T, R, nbr_agents, cell_size)

    output = self.stage2(mems)

    if self.has_avail_actions:
        # Mask out unavailable actions with a large negative logit:
        # log(0) -> -inf, clamped to FLOAT_MIN.
        avail_actions = add_time_dimension(
            input_dict['obs']['avail_actions'],
            max_seq_len=max_T,
            framework=self.framework,
            time_major=self.is_time_major())
        avail_actions = avail_actions.view(max_T, R, nbr_agents,
                                           self.nbr_actions)
        inf_mask = torch.clamp(torch.log(avail_actions), FLOAT_MIN, FLOAT_MAX)
        output = output + inf_mask

    output = output.view(B, self.num_outputs)
    return output, [h]
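# A standalone walk-through of the (time, rollout, agent) folding around the
# GRU above, with hypothetical sizes.
max_T, R, nbr_agents, cell_size, feat = 8, 2, 3, 64, 32
demo_x = torch.zeros(max_T, R, nbr_agents, feat)  # per-agent features
demo_h = torch.zeros(R, nbr_agents, cell_size)    # recurrent state
demo_gru = nn.GRU(feat, cell_size)
demo_mems, demo_h = demo_gru(demo_x.view(max_T, R * nbr_agents, feat),
                             demo_h.view(1, R * nbr_agents, cell_size))
assert demo_mems.shape == (max_T, R * nbr_agents, cell_size)
assert demo_h.view(R, nbr_agents, cell_size).shape == (R, nbr_agents, cell_size)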