def get_action(self, observation):
    """Get a single action given an observation.

    Args:
        observation (np.ndarray or torch.Tensor): Single observation
            from the environment.

    Returns:
        tuple:
            * np.ndarray: Action sampled from the policy distribution.
            * dict: Distribution info with keys ``mean`` and
              ``log_std``, both np.ndarrays.

    """
    with torch.no_grad():
        # Use isinstance rather than an exact type comparison so that
        # torch.Tensor subclasses (e.g. nn.Parameter) are accepted
        # directly instead of being passed to tu.from_numpy.
        if not isinstance(observation, torch.Tensor):
            observation = tu.from_numpy(observation)
        observation = observation.unsqueeze(0)
        dist = self.forward(observation)
        mean = dist.mean.squeeze(0).cpu().numpy()
        log_std = (dist.variance**.5).log().squeeze(0).cpu().numpy()
        return (dist.rsample().squeeze(0).cpu().numpy(),
                dict(mean=mean, log_std=log_std))
    def update_context(self, inputs):
        """Append single transition to the current context.

        Args:
            inputs (dict): Dictionary of transition information in np
                arrays.

        """
        obs, act, rew, next_obs, _, _ = inputs
        # Add (task, timestep) leading dims so transitions can be
        # concatenated along the time axis below.
        obs_t = tu.from_numpy(obs[None, None, ...])
        act_t = tu.from_numpy(act[None, None, ...])
        rew_t = tu.from_numpy(np.array([rew])[None, None, ...])
        next_obs_t = tu.from_numpy(next_obs[None, None, ...])

        parts = [obs_t, act_t, rew_t]
        if self._use_next_obs:
            parts.append(next_obs_t)
        data = torch.cat(parts, dim=2)

        if self._context is None:
            self._context = data
        else:
            self._context = torch.cat([self._context, data], dim=1)
# Example #3
    def sample_data(self, indices):
        """Sample batch of training data from a list of tasks.

        Args:
            indices (list): List of task indices to sample from.

        Returns:
            tuple: Five torch.Tensors — observations, actions, rewards,
                next observations, and terminals — each with a leading
                task dimension of length ``len(indices)``.

        """
        # Transitions are sampled randomly from each task's replay
        # buffer.  Collect per-task arrays in lists and stack once at
        # the end — repeated np.vstack inside the loop copies the
        # accumulated arrays every iteration (quadratic cost).
        keys = ('observations', 'actions', 'rewards', 'next_observations',
                'terminals')
        columns = {k: [] for k in keys}
        for idx in indices:
            batch = self._replay_buffers[idx].sample_transitions(
                self._batch_size)
            for k in keys:
                columns[k].append(batch[k])

        return tuple(
            tu.from_numpy(np.stack(columns[k], axis=0)) for k in keys)
    def get_action(self, obs):
        """Sample action from the policy, conditioned on the task embedding.

        Args:
            obs (torch.Tensor): Observation values.

        Returns:
            torch.Tensor: Output action values.

        """
        # Concatenate the observation with the current latent task
        # embedding before delegating to the inner policy.
        obs_tensor = tu.from_numpy(obs[None])
        augmented_obs = torch.cat([obs_tensor, self.z], dim=1)
        action, info = self._policy.get_action(augmented_obs)
        info['mean'] = np.squeeze(info['mean'], axis=0)
        return np.squeeze(action, axis=0), info
# Example #5
    def sample_context(self, indices):
        """Sample batch of context from a list of tasks.

        Args:
            indices (list or int): Task index, or list of task indices,
                to sample context from.

        Returns:
            torch.Tensor: Context data with a leading task dimension.

        """
        # Make method work given a single task index.
        if not hasattr(indices, '__iter__'):
            indices = [indices]

        # Build each task's context as one hstack and stack once at the
        # end — repeated np.vstack inside the loop re-copies the
        # accumulated array every iteration (quadratic cost).
        contexts = []
        for idx in indices:
            batch = self._context_replay_buffers[idx].sample_transitions(
                self._embedding_batch_size)
            pieces = [batch['observations'], batch['actions'],
                      batch['rewards']]
            if self._use_next_obs_in_context:
                pieces.append(batch['next_observations'])
            contexts.append(np.hstack(pieces))

        final_context = tu.from_numpy(np.stack(contexts, axis=0))
        # NOTE(review): for a single task the original adds an extra
        # leading dimension (shape (1, 1, batch, dim)); preserved as-is
        # since downstream code may rely on it.
        if len(indices) == 1:
            final_context = final_context.unsqueeze(0)

        return final_context