def imagine(self, action: TensorType, state: List[TensorType] = None) -> List[TensorType]:
    """Imagines the trajectory starting from state through a list of actions.

    Similar to observe(), requires rolling out the RNN for each timestep,
    but only ``self.img_step`` (the prior transition) is applied, since no
    observation embeddings are involved.

    Args:
        action (TensorType): Actions. A 3-D tensor is treated as
            (batch, time, dim) -- presumably batch-major, since its first
            axis sizes the initial state; TODO confirm against callers. A
            2-D tensor is treated as a single timestep (batch, 1, dim),
            matching the input handling in observe().
        state (List[TensorType]): Starting state before rollout. If None,
            ``self.get_initial_state(action.size()[0])`` is used.

    Returns:
        Prior states: one tensor per state component, stacked over time
        and returned with the first two axes as (batch, time, ...). Each
        component returned by ``img_step`` must be 2-D for the final
        ``permute(1, 0, 2)`` to apply.
    """
    if state is None:
        state = self.get_initial_state(action.size()[0])

    # Promote a single-step (batch, dim) action to (batch, 1, dim) so the
    # time-major permute below works, consistent with observe().
    if action.dim() <= 2:
        action = torch.unsqueeze(action, 1)

    # Time-major: iterating over the first axis now walks timesteps.
    action = action.permute(1, 0, 2)

    # One history list per state component.
    priors = [[] for _ in state]
    last = state
    for action_t in action:
        last = self.img_step(last, action_t)
        for history, component in zip(priors, last):
            history.append(component)

    # Stack along time (dim 0), then swap back so batch comes first.
    prior = [torch.stack(x, dim=0) for x in priors]
    return [e.permute(1, 0, 2) for e in prior]
def observe(
    self, embed: TensorType, action: TensorType, state: List[TensorType] = None
) -> Tuple[List[TensorType], List[TensorType]]:
    """Returns the corresponding states from the embedding from ConvEncoder and actions.

    This is accomplished by rolling out the RNN from the starting state
    through each index of embed and action, saving all intermediate states
    between. At every step ``self.obs_step`` is called with the previous
    posterior state, and its result is treated as a (posterior, prior) pair.

    Args:
        embed (TensorType): ConvEncoder embedding. 3-D tensors are treated as
            (batch, time, dim) -- presumably batch-major; TODO confirm. A 2-D
            tensor is treated as a single timestep (batch, 1, dim).
        action (TensorType): Actions, same layout conventions as ``embed``.
            The rollout length is the action's time length; ``embed`` is
            indexed by that timestep, so it must have at least as many steps
            (extra embedding steps are ignored).
        state (List[TensorType]): Initial state before rollout. If None,
            ``self.get_initial_state(action.size()[0])`` is used.

    Returns:
        Posterior states and prior states (both List[TensorType]), one tensor
        per state component, stacked over time and returned with the first
        two axes as (batch, time, ...). Each component must be 2-D for the
        final ``permute(1, 0, 2)`` to apply.
    """
    if state is None:
        state = self.get_initial_state(action.size()[0])

    # Promote single-step (batch, dim) inputs to (batch, 1, dim).
    if embed.dim() <= 2:
        embed = torch.unsqueeze(embed, 1)
    if action.dim() <= 2:
        action = torch.unsqueeze(action, 1)

    # Time-major: the first axis now indexes timesteps.
    embed = embed.permute(1, 0, 2)
    action = action.permute(1, 0, 2)

    # One history list per state component.
    priors = [[] for _ in state]
    posts = [[] for _ in state]
    # Tuple of (post, prior); only the posterior feeds the next step.
    last = (state, state)
    for t, action_t in enumerate(action):
        last = self.obs_step(last[0], action_t, embed[t])
        for history, component in zip(posts, last[0]):
            history.append(component)
        for history, component in zip(priors, last[1]):
            history.append(component)

    # Stack along time (dim 0), then swap back so batch comes first.
    prior = [torch.stack(x, dim=0).permute(1, 0, 2) for x in priors]
    post = [torch.stack(x, dim=0).permute(1, 0, 2) for x in posts]
    return post, prior