def get_action(self, observation):
    """Get a single action given an observation.

    Args:
        observation (np.ndarray): Observation from the environment.

    Returns:
        tuple:
            * np.ndarray: Predicted action.
            * dict: Distribution parameters, containing the mean and the
                log standard deviation of the action distribution.

    """
    with torch.no_grad():
        if not isinstance(observation, torch.Tensor):
            observation = tu.from_numpy(observation)
        # Add a batch dimension, sample from the action distribution, then
        # strip the batch dimension off the returned arrays.
        observation = observation.unsqueeze(0)
        dist = self.forward(observation)
        return (dist.rsample().squeeze(0).cpu().numpy(),
                dict(mean=dist.mean.squeeze(0).cpu().numpy(),
                     log_std=(dist.variance**.5).log().squeeze(0).cpu().numpy()))
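
# Standalone sketch (not part of the class above) of how the returned action
# and distribution statistics relate: rsample() draws a reparameterized sample
# and log_std is recovered from the distribution's variance. The sizes below
# are made up for illustration.
#
#     import torch
#     from torch.distributions import Normal
#
#     mean = torch.zeros(1, 4)          # batch of one observation, action_dim=4
#     std = torch.full((1, 4), 0.5)
#     dist = Normal(mean, std)
#     action = dist.rsample().squeeze(0).numpy()              # shape (4,)
#     log_std = (dist.variance**.5).log().squeeze(0).numpy()  # equals log(0.5)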

def update_context(self, inputs):
    """Append a single transition to the current context.

    Args:
        inputs (tuple): Transition information (observation, action,
            reward, next observation, ...) as numpy arrays.

    """
    o, a, r, no, _, _ = inputs
    # Reshape each element to (1, 1, dim) so transitions can be stacked
    # along the time dimension (dim=1).
    o = tu.from_numpy(o[None, None, ...])
    a = tu.from_numpy(a[None, None, ...])
    r = tu.from_numpy(np.array([r])[None, None, ...])
    no = tu.from_numpy(no[None, None, ...])

    if self._use_next_obs:
        data = torch.cat([o, a, r, no], dim=2)
    else:
        data = torch.cat([o, a, r], dim=2)

    if self._context is None:
        self._context = data
    else:
        self._context = torch.cat([self._context, data], dim=1)
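
# Standalone sketch of the context-accumulation pattern used above, with
# made-up dimensions (obs_dim=3, act_dim=2, no next observation): each call
# appends one transition along dim=1, giving a context of shape
# (1, num_transitions, obs_dim + act_dim + 1).
#
#     import numpy as np
#     import torch
#
#     context = None
#     for step in range(4):
#         o = torch.from_numpy(np.zeros((3,))[None, None, ...])   # (1, 1, 3)
#         a = torch.from_numpy(np.zeros((2,))[None, None, ...])   # (1, 1, 2)
#         r = torch.from_numpy(np.array([0.0])[None, None, ...])  # (1, 1, 1)
#         data = torch.cat([o, a, r], dim=2)                      # (1, 1, 6)
#         context = (data if context is None
#                    else torch.cat([context, data], dim=1))
#     print(context.shape)  # torch.Size([1, 4, 6])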

def sample_data(self, indices):
    """Sample a batch of training data from a list of tasks.

    Args:
        indices (list): List of task indices to sample from.

    Returns:
        tuple: Observations, actions, rewards, next observations, and
            terminals as torch.Tensors, each with a leading task dimension
            of size len(indices).

    """
    # Transitions are sampled randomly from each task's replay buffer and
    # stacked along a new leading task dimension.
    initialized = False
    for idx in indices:
        batch = self._replay_buffers[idx].sample_transitions(
            self._batch_size)
        if not initialized:
            o = batch['observations'][np.newaxis]
            a = batch['actions'][np.newaxis]
            r = batch['rewards'][np.newaxis]
            no = batch['next_observations'][np.newaxis]
            t = batch['terminals'][np.newaxis]
            initialized = True
        else:
            o = np.vstack((o, batch['observations'][np.newaxis]))
            a = np.vstack((a, batch['actions'][np.newaxis]))
            r = np.vstack((r, batch['rewards'][np.newaxis]))
            no = np.vstack((no, batch['next_observations'][np.newaxis]))
            t = np.vstack((t, batch['terminals'][np.newaxis]))

    o = tu.from_numpy(o)
    a = tu.from_numpy(a)
    r = tu.from_numpy(r)
    no = tu.from_numpy(no)
    t = tu.from_numpy(t)

    return o, a, r, no, t
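
# Standalone sketch of the per-task stacking above: each task contributes a
# (batch_size, dim) array, and np.vstack over the [np.newaxis]-expanded
# batches yields (num_tasks, batch_size, dim). Sizes are illustrative only.
#
#     import numpy as np
#
#     num_tasks, batch_size, obs_dim = 3, 5, 4
#     stacked = None
#     for _ in range(num_tasks):
#         batch_obs = np.zeros((batch_size, obs_dim))  # one task's sampled obs
#         if stacked is None:
#             stacked = batch_obs[np.newaxis]          # (1, batch_size, obs_dim)
#         else:
#             stacked = np.vstack((stacked, batch_obs[np.newaxis]))
#     print(stacked.shape)  # (3, 5, 4)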

def get_action(self, obs):
    """Sample an action from the policy, conditioned on the task embedding.

    Args:
        obs (np.ndarray): Observation values.

    Returns:
        tuple:
            * np.ndarray: Output action values.
            * dict: Distribution parameters of the action distribution.

    """
    z = self.z
    obs = tu.from_numpy(obs[None])
    # Condition the inner policy on the latent task variable by
    # concatenating it with the observation.
    obs_in = torch.cat([obs, z], dim=1)
    action, info = self._policy.get_action(obs_in)
    action = np.squeeze(action, axis=0)
    info['mean'] = np.squeeze(info['mean'], axis=0)
    return action, info
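
# Standalone sketch of conditioning on the latent task variable: the
# observation is batched with [None] and concatenated with z along dim=1,
# so the inner policy sees an input of size obs_dim + latent_dim.
# Dimensions are made up for illustration.
#
#     import numpy as np
#     import torch
#
#     obs_dim, latent_dim = 4, 2
#     obs = torch.from_numpy(np.zeros(obs_dim, dtype=np.float32)[None])
#     z = torch.zeros(1, latent_dim)           # current task embedding
#     obs_in = torch.cat([obs, z], dim=1)      # (1, obs_dim + latent_dim)
#     print(obs_in.shape)  # torch.Size([1, 6])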

def sample_context(self, indices):
    """Sample batch of context from a list of tasks.

    Args:
        indices (list): List of tasks.

    Returns:
        torch.Tensor: Context data.

    """
    # Make the method work when given a single task index.
    if not hasattr(indices, '__iter__'):
        indices = [indices]

    initialized = False
    for idx in indices:
        batch = self._context_replay_buffers[idx].sample_transitions(
            self._embedding_batch_size)
        o = batch['observations']
        a = batch['actions']
        r = batch['rewards']
        context = np.hstack((np.hstack((o, a)), r))
        if self._use_next_obs_in_context:
            context = np.hstack((context, batch['next_observations']))

        if not initialized:
            final_context = context[np.newaxis]
            initialized = True
        else:
            final_context = np.vstack((final_context, context[np.newaxis]))

    final_context = tu.from_numpy(final_context)
    if len(indices) == 1:
        final_context = final_context.unsqueeze(0)

    return final_context
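
# Standalone sketch of the context batch assembled above: per task,
# (embedding_batch_size, obs + act + 1) transitions are hstacked, then the
# per-task contexts are vstacked into (num_tasks, embedding_batch_size, dim).
# Sizes are illustrative only.
#
#     import numpy as np
#
#     embedding_batch_size, obs_dim, act_dim = 5, 4, 2
#     final_context = None
#     for _ in range(3):                                # three tasks
#         o = np.zeros((embedding_batch_size, obs_dim))
#         a = np.zeros((embedding_batch_size, act_dim))
#         r = np.zeros((embedding_batch_size, 1))
#         context = np.hstack((np.hstack((o, a)), r))   # (5, 7)
#         if final_context is None:
#             final_context = context[np.newaxis]       # (1, 5, 7)
#         else:
#             final_context = np.vstack((final_context, context[np.newaxis]))
#     print(final_context.shape)  # (3, 5, 7)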