Example #1
    def _act(self, obs):
        assert len(obs.shape) == 2 and obs.shape[0] == 1  # assume single-observation batches with leading 1-dim
        if not self.action_plan:
            # generate action plan if the current one is empty
            split_obs = self._split_obs(obs)
            with no_batchnorm_update(self._policy) if obs.shape[0] == 1 else contextlib.suppress():
                actions = self._policy.decode(map2torch(split_obs.z, self._hp.device),
                                              map2torch(split_obs.cond_input, self._hp.device),
                                              self._policy.n_rollout_steps)
            self.action_plan = deque(split_along_axis(map2np(actions), axis=1))
        return AttrDict(action=self.action_plan.popleft())
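
The examples on this page shuttle (possibly nested) dicts of arrays between numpy and torch with map2torch / map2np, and turn a decoded action tensor into a per-step plan with split_along_axis. As orientation only, here is a minimal sketch of what such helpers could look like; the actual spirl implementations may differ in details such as dtype handling and the container types they accept.

import numpy as np
import torch


def map2torch_sketch(struct, device):
    # assumed behavior: recursively convert numpy arrays (possibly nested in dicts) to tensors on `device`
    if isinstance(struct, dict):
        return type(struct)({k: map2torch_sketch(v, device) for k, v in struct.items()})
    if isinstance(struct, np.ndarray):
        return torch.from_numpy(struct).float().to(device)
    return struct


def map2np_sketch(struct):
    # assumed behavior: recursively convert torch tensors (possibly nested in dicts) back to numpy arrays
    if isinstance(struct, dict):
        return type(struct)({k: map2np_sketch(v) for k, v in struct.items()})
    if isinstance(struct, torch.Tensor):
        return struct.detach().cpu().numpy()
    return struct


def split_along_axis_sketch(array, axis):
    # assumed behavior: split an array into a list of its slices along `axis`, e.g. a
    # [batch, time, action_dim] array with axis=1 becomes a list of per-step actions
    return [np.take(array, i, axis=axis) for i in range(array.shape[axis])]
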
Example #2
    def run(self, inputs, use_learned_prior=True):
        """Policy interface for model. Runs decoder if action plan is empty, otherwise returns next action from action plan.
        :arg inputs: dict with 'states', 'actions', 'images' keys from environment
        :arg use_learned_prior: if True, uses learned prior otherwise samples latent from uniform prior
        """
        if not self._action_plan:
            inputs = map2torch(inputs, device=self.device)

            # sample latent variable from prior
            z = self.compute_learned_prior(self._learned_prior_input(inputs), first_only=True).sample() \
                if use_learned_prior else Gaussian(torch.zeros((1, self._hp.nz_vae*2), device=self.device)).sample()

            # decode into action plan
            # HACK: the flat LSTM decoder can only take batch_size inputs, so tile the single latent
            z = z.repeat(self._hp.batch_size, 1)
            input_obs = self._learned_prior_input(inputs).repeat(self._hp.batch_size, 1)
            actions = self.decode(z,
                                  cond_inputs=input_obs,
                                  steps=self._hp.n_rollout_steps)[0]
            self._action_plan = deque(split_along_axis(map2np(actions),
                                                       axis=0))

        return AttrDict(action=self._action_plan.popleft()[None])
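
The "uniform prior" branch above builds a Gaussian from a zero tensor of width nz_vae*2. Assuming spirl's Gaussian splits its input into a mean half and a log-sigma half, this amounts to sampling from a unit Gaussian over the latent. A standalone equivalent using torch.distributions (with a made-up latent size):

import torch

nz_vae = 10                                          # hypothetical latent dimension
params = torch.zeros((1, nz_vae * 2))                # first half: mean, second half: log-sigma (assumed convention)
mu, log_sigma = torch.chunk(params, 2, dim=-1)
z = torch.distributions.Normal(mu, log_sigma.exp()).sample()  # shape (1, nz_vae), standard normal sample
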
Example #3
File: ac_agent.py Project: clvrai/spirl
    def _act(self, obs):
        # TODO implement non-sampling validation mode
        obs = map2torch(self._obs_normalizer(obs), self._hp.device)
        if len(obs.shape) == 1:  # we need batched inputs for policy
            policy_output = self._remove_batch(self.policy(obs[None]))
            if 'dist' in policy_output:
                del policy_output['dist']
            return map2np(policy_output)
        return map2np(self.policy(obs))
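
_remove_batch is not shown on this page. Assuming it simply strips the leading size-1 batch dimension from every array-like entry of the policy output (the name and behavior below are an assumption, not the verified spirl helper), a minimal version could look like this:

def remove_batch_sketch(output_dict):
    # assumed behavior: drop the leading size-1 batch dimension from every array/tensor entry,
    # leaving non-array entries (e.g. distribution objects) untouched
    return type(output_dict)({k: (v[0] if hasattr(v, 'shape') and len(v.shape) > 0 else v)
                              for k, v in output_dict.items()})
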
Example #4
    def _split_obs(self, obs):
        unflattened_obs = map2np(self._policy.unflatten_obs(
            map2torch(obs[:, :-self._policy.latent_dim], device=self.device)))
        return AttrDict(
            cond_input=unflattened_obs.prior_obs,
            z=obs[:, -self._policy.latent_dim:],
        )
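
To make the slicing concrete, here is a small standalone illustration with made-up dimensions: the flat observation carries the conditioning observation first, followed by the latent z in its trailing latent_dim entries.

import numpy as np

latent_dim = 10                                 # hypothetical latent dimension
flat_obs = np.random.randn(1, 60 + latent_dim)  # hypothetical: 60-dim prior obs followed by the 10-dim latent
cond_input = flat_obs[:, :-latent_dim]          # everything but the trailing latent -> shape (1, 60)
z = flat_obs[:, -latent_dim:]                   # trailing latent_dim entries        -> shape (1, 10)
assert cond_input.shape == (1, 60) and z.shape == (1, 10)
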
Example #5
    def update(self, experience_batch):
        if 'delay' in self._hp.omega_schedule_params and self._update_steps < self._hp.omega_schedule_params.delay:
            # if schedule has warmup phase in which *only* prior is sampled, train policy to minimize divergence
            self.replay_buffer.append(experience_batch)
            experience_batch = self.replay_buffer.sample(n_samples=self._hp.batch_size)
            experience_batch = map2torch(experience_batch, self._hp.device)
            policy_output = self._run_policy(experience_batch.observation)
            policy_loss = policy_output.prior_divergence.mean()
            self._perform_update(policy_loss, self.policy_opt, self.policy)
            self._update_steps += 1
            info = AttrDict(prior_divergence=policy_output.prior_divergence.mean())
        else:
            info = super().update(experience_batch)
        info.omega = self._omega(self._update_steps)
        return info
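
self._omega(self._update_steps) evaluates a mixing-weight schedule over update steps; the actual schedule is configured via omega_schedule_params and is not shown here. As an illustration only, a linear ramp with a warmup delay could be implemented like this:

class LinearOmegaScheduleSketch:
    # hypothetical schedule: omega stays at 0 during the warmup delay,
    # then ramps linearly to `final_value` over `ramp_steps` update steps
    def __init__(self, delay, ramp_steps, final_value=1.0):
        self.delay = delay
        self.ramp_steps = ramp_steps
        self.final_value = final_value

    def __call__(self, step):
        if step < self.delay:
            return 0.0
        progress = min((step - self.delay) / float(self.ramp_steps), 1.0)
        return progress * self.final_value
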
Example #6
File: ac_agent.py Project: clvrai/spirl
    def _act_rand(self, obs):
        policy_output = self.policy.sample_rand(map2torch(obs, self.policy.device))
        if 'dist' in policy_output:
            del policy_output['dist']
        return map2np(policy_output)
Example #7
File: ac_agent.py Project: clvrai/spirl
    def update(self, experience_batch):
        """Updates actor and critics."""
        # push experience batch into replay buffer
        self.add_experience(experience_batch)

        for _ in range(self._hp.update_iterations):
            # sample batch and normalize
            experience_batch = self._sample_experience()
            experience_batch = self._normalize_batch(experience_batch)
            experience_batch = map2torch(experience_batch, self._hp.device)
            experience_batch = self._preprocess_experience(experience_batch)

            policy_output = self._run_policy(experience_batch.observation)

            # update alpha
            alpha_loss = self._update_alpha(experience_batch, policy_output)

            # compute policy loss
            policy_loss = self._compute_policy_loss(experience_batch,
                                                    policy_output)

            # compute target Q value
            with torch.no_grad():
                policy_output_next = self._run_policy(
                    experience_batch.observation_next)
                value_next = self._compute_next_value(experience_batch,
                                                      policy_output_next)
                q_target = experience_batch.reward * self._hp.reward_scale + \
                                (1 - experience_batch.done) * self._hp.discount_factor * value_next
                if self._hp.clip_q_target:
                    q_target = self._clip_q_target(q_target)
                q_target = q_target.detach()
                check_shape(q_target, [self._hp.batch_size])

            # compute critic loss
            critic_losses, qs = self._compute_critic_loss(
                experience_batch, q_target)

            # update critic networks
            [
                self._perform_update(critic_loss, critic_opt, critic)
                for critic_loss, critic_opt, critic in zip(
                    critic_losses, self.critic_opts, self.critics)
            ]

            # update target networks
            [
                self._soft_update_target_network(critic_target, critic) for
                critic_target, critic in zip(self.critic_targets, self.critics)
            ]

            # update policy network on policy loss
            self._perform_update(policy_loss, self.policy_opt, self.policy)

            # logging
            info = AttrDict(  # losses
                policy_loss=policy_loss,
                alpha_loss=alpha_loss,
                critic_loss_1=critic_losses[0],
                critic_loss_2=critic_losses[1],
            )
            if self._update_steps % 100 == 0:
                info.update(
                    AttrDict(  # gradient norms
                        policy_grad_norm=avg_grad_norm(self.policy),
                        critic_1_grad_norm=avg_grad_norm(self.critics[0]),
                        critic_2_grad_norm=avg_grad_norm(self.critics[1]),
                    ))
            info.update(
                AttrDict(  # misc
                    alpha=self.alpha,
                    pi_log_prob=policy_output.log_prob.mean(),
                    policy_entropy=policy_output.dist.entropy().mean(),
                    q_target=q_target.mean(),
                    q_1=qs[0].mean(),
                    q_2=qs[1].mean(),
                ))
            info.update(self._aux_info(experience_batch, policy_output))
            info = map_dict(ten2ar, info)

            self._update_steps += 1

        return info
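
Two pieces of the update above are delegated to helpers that are not shown: the next-state value entering the Q-target and the target-network update. The sketches below show the standard SAC formulations such helpers are commonly based on; they are assumptions for illustration, not the verified spirl implementations.

import torch
import torch.nn as nn


def next_state_value_sketch(q1_next, q2_next, log_prob_next, alpha):
    # standard SAC soft value of the next state (assumed behavior of _compute_next_value):
    # V(s') = min(Q1(s', a'), Q2(s', a')) - alpha * log pi(a'|s')
    return torch.min(q1_next, q2_next) - alpha * log_prob_next


def soft_update_sketch(target_net: nn.Module, source_net: nn.Module, tau: float = 0.005):
    # Polyak averaging (assumed behavior of _soft_update_target_network):
    # target <- tau * source + (1 - tau) * target
    with torch.no_grad():
        for target_param, param in zip(target_net.parameters(), source_net.parameters()):
            target_param.data.mul_(1.0 - tau).add_(tau * param.data)
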