示例#1
0
    def get_action(self, ob, sample=True, *args, **kwargs):
        """Select a hybrid (continuous + discrete) action for a dict observation.

        Each entry of ``ob`` is converted with ``torch_float`` onto
        ``cfg.alg.device``. ``self.get_act_val`` then returns a continuous
        action distribution, a discrete action distribution and a value
        estimate. An action is drawn from each distribution, and the
        per-distribution log-probabilities and entropies are combined.

        Args:
            ob: Mapping of observation name -> array-like value.
            sample: Forwarded to ``action_from_dist`` for both distributions.
                Presumably it selects sampling over the deterministic action;
                TODO confirm against ``action_from_dist``.
            *args, **kwargs: Accepted for interface compatibility; unused.

        Returns:
            ``(action, action_info)``:

            - ``action`` is a numpy array, the continuous and discrete
              actions concatenated along axis 1.
            - ``action_info`` is a dict with numpy ``log_prob``,
              ``entropy`` and ``val`` entries.
        """
        self.eval_mode()
        t_ob = {key: torch_float(ob[key], device=cfg.alg.device) for key in ob}
        act_dist_cont, act_dist_disc, val = self.get_act_val(t_ob)
        action_cont = action_from_dist(act_dist_cont, sample=sample)
        action_discrete = action_from_dist(act_dist_disc, sample=sample)
        log_prob_disc = action_log_prob(action_discrete, act_dist_disc)
        log_prob_cont = action_log_prob(action_cont, act_dist_cont)
        entropy_disc = action_entropy(act_dist_disc, log_prob_disc)
        entropy_cont = action_entropy(act_dist_cont, log_prob_cont)
        # The discrete terms are summed over dim 1 before being added to the
        # continuous terms. This assumes dim 1 indexes independent discrete
        # action components; TODO confirm the shapes produced by
        # action_log_prob / action_entropy for the discrete distribution.
        log_prob = log_prob_cont + torch.sum(log_prob_disc, dim=1)
        entropy = entropy_cont + torch.sum(entropy_disc, dim=1)

        action_info = dict(log_prob=torch_to_np(log_prob),
                           entropy=torch_to_np(entropy),
                           val=torch_to_np(val))
        # Concatenation along axis 1 requires both action arrays to be at
        # least 2-D with matching leading dimension.
        action = np.concatenate(
            (torch_to_np(action_cont), torch_to_np(action_discrete)), axis=1)

        return action, action_info
示例#2
0
 def get_action(self, ob, sample=True, *args, **kwargs):
     """Choose an action from the actor's distribution for ``ob``.

     The model is switched to eval mode, and ``ob`` is converted with
     ``torch_float`` onto ``cfg.alg.device``. The first element of
     ``self.actor``'s output is used as the action distribution.

     Args:
         ob: Observation to act on.
         sample: Forwarded to ``action_from_dist``. Presumably it toggles
             sampling versus a deterministic action; TODO confirm.
         *args, **kwargs: Accepted for interface compatibility; unused.

     Returns:
         Tuple of (action converted with ``torch_to_np``, empty info dict).
     """
     self.eval_mode()
     obs_tensor = torch_float(ob, device=cfg.alg.device)
     policy_dist = self.actor(obs_tensor)[0]
     chosen = action_from_dist(policy_dist, sample=sample)
     return torch_to_np(chosen), {}
示例#3
0
 def update_q(self, obs, actions, next_obs, rewards, dones):
     """Take one optimizer step on both Q networks toward a soft Bellman target.

     The target is computed without gradient tracking:

         rewards + rew_discount * (1 - dones)
             * (min(q1_tgt, q2_tgt) - alpha * log_prob(next_action))

     Here ``next_action`` is sampled from the current actor at
     ``next_obs``. Both Q networks are regressed onto this target with MSE.
     The summed loss is backpropagated, gradients of ``self.q_params`` are
     clipped to ``cfg.alg.max_grad_norm``, and ``self.q_optimizer`` steps.

     Returns:
         Dict with scalar ``q1_loss``/``q2_loss``, numpy ``vec_q1_val``,
         ``vec_q2_val`` and ``vec_q_tgt_val``, and ``q_grad_norm`` as
         returned by ``clip_grad``.
     """
     sa = (obs, actions)
     q1_pred = self.q1(sa)[0]
     q2_pred = self.q2(sa)[0]

     with torch.no_grad():
         next_dist = self.actor(next_obs)[0]
         next_act = action_from_dist(next_dist, sample=True)
         # Trailing singleton axis added to the log-prob.
         next_logp = action_log_prob(next_act, next_dist).unsqueeze(-1)
         next_sa = (next_obs, next_act)
         # Clipped double-Q: take the element-wise min of the two targets.
         min_next_q = torch.min(self.q1_tgt(next_sa)[0],
                                self.q2_tgt(next_sa)[0])
         soft_next_v = min_next_q - self.alpha * next_logp
         not_done = 1 - dones
         q_target = rewards + cfg.alg.rew_discount * not_done * soft_next_v

     q1_loss = F.mse_loss(q1_pred, q_target)
     q2_loss = F.mse_loss(q2_pred, q_target)
     total_loss = q1_loss + q2_loss

     self.q_optimizer.zero_grad()
     total_loss.backward()
     norm = clip_grad(self.q_params, cfg.alg.max_grad_norm)
     self.q_optimizer.step()

     return {
         'q1_loss': q1_loss.item(),
         'q2_loss': q2_loss.item(),
         'vec_q1_val': torch_to_np(q1_pred),
         'vec_q2_val': torch_to_np(q2_pred),
         'vec_q_tgt_val': torch_to_np(q_target),
         'q_grad_norm': norm,
     }
示例#4
0
 def get_action(self, ob, sample=True, *args, **kwargs):
     """Pick an action and report its log-prob, entropy and value estimate.

     ``ob`` is converted with ``torch_float`` onto ``cfg.alg.device`` and
     fed to ``self.get_act_val``, which returns ``(action_dist, value)``.

     Args:
         ob: Observation to act on.
         sample: Forwarded to ``action_from_dist``. Presumably it toggles
             sampling versus a deterministic action; TODO confirm.
         *args, **kwargs: Accepted for interface compatibility; unused.

     Returns:
         Tuple of (numpy action, dict with numpy ``log_prob``, ``entropy``
         and ``val``).
     """
     self.eval_mode()
     obs_tensor = torch_float(ob, device=cfg.alg.device)
     dist, value = self.get_act_val(obs_tensor)
     act = action_from_dist(dist, sample=sample)
     logp = action_log_prob(act, dist)
     ent = action_entropy(dist, logp)
     info = {
         'log_prob': torch_to_np(logp),
         'entropy': torch_to_np(ent),
         'val': torch_to_np(value),
     }
     return torch_to_np(act), info
示例#5
0
 def get_action(self, ob, sample=True, hidden_state=None, *args, **kwargs):
     """Recurrent variant: act on ``ob`` given an incoming hidden state.

     ``ob`` must be a numpy array, because it goes through
     ``torch.from_numpy``. It is cast to float, moved to
     ``cfg.alg.device`` and given a new size-1 axis at dim 1 before
     calling ``self.get_act_val``. That axis is removed again with
     ``squeeze(1)`` on every returned tensor. The new axis presumably
     stands in for a sequence/time dimension; TODO confirm.

     Args:
         ob: numpy observation array.
         sample: Forwarded to ``action_from_dist``.
         hidden_state: Passed through to ``self.get_act_val``.
         *args, **kwargs: Accepted for interface compatibility; unused.

     Returns:
         Tuple of (numpy action, dict with numpy ``log_prob``, ``entropy``
         and ``val``, hidden state returned by ``get_act_val``).
     """
     self.eval_mode()
     float_ob = torch.from_numpy(ob).float()
     seq_ob = float_ob.to(cfg.alg.device).unsqueeze(1)
     dist, value, next_hidden = self.get_act_val(seq_ob,
                                                 hidden_state=hidden_state)
     act = action_from_dist(dist, sample=sample)
     logp = action_log_prob(act, dist)
     ent = action_entropy(dist, logp)
     named = (('log_prob', logp), ('entropy', ent), ('val', value))
     info = {name: torch_to_np(tensor.squeeze(1)) for name, tensor in named}
     return torch_to_np(act.squeeze(1)), info, next_hidden
示例#6
0
    def update_pi(self, obs):
        """Take one optimizer step on the actor.

        Loss is ``mean(alpha * log_prob(a) - min(q1(s, a), q2(s, a)))``, where
        ``a`` is sampled from the current actor at ``obs``. Actor gradients
        are clipped to ``cfg.alg.max_grad_norm`` before ``self.pi_optimizer``
        steps.

        ``self.q1``/``self.q2`` are passed to ``freeze_model`` for the
        duration of the update. Presumably this keeps this loss from
        producing Q-network gradients; confirm against ``freeze_model``.
        ``unfreeze_model`` runs in a ``finally`` so the Q networks are
        restored even if any step raises.

        Args:
            obs: Batch of observations passed to the actor and Q networks.

        Returns:
            Dict with ``pi_loss``, ``pi_neg_log_prob`` (negated mean
            log-prob) and ``pi_grad_norm`` (as returned by ``clip_grad``).
        """
        freeze_model([self.q1, self.q2])
        try:
            act_dist = self.actor(obs)[0]
            new_actions = action_from_dist(act_dist, sample=True)
            new_log_prob = action_log_prob(new_actions,
                                           act_dist).unsqueeze(-1)
            new_q1 = self.q1((obs, new_actions))[0]
            new_q2 = self.q2((obs, new_actions))[0]
            new_q = torch.min(new_q1, new_q2)

            loss_pi = (self.alpha * new_log_prob - new_q).mean()
            # Also clear gradients held for q_optimizer's parameters, not
            # only the actor's.
            self.q_optimizer.zero_grad()
            self.pi_optimizer.zero_grad()
            loss_pi.backward()
            grad_norm = clip_grad(self.actor.parameters(),
                                  cfg.alg.max_grad_norm)
            self.pi_optimizer.step()
            pi_info = dict(pi_loss=loss_pi.item(),
                           pi_neg_log_prob=-new_log_prob.mean().item())
            pi_info['pi_grad_norm'] = grad_norm
        finally:
            # Always undo the freeze; otherwise an exception above would
            # leave the Q networks frozen for all subsequent updates.
            unfreeze_model([self.q1, self.q2])
        return pi_info
示例#7
0
    def get_action(self, ob, sample=True, hidden_state=None, *args, **kwargs):
        """Recurrent action selection for dict or numpy-array observations.

        Observation handling depends on the type of ``ob``:

        - Dict observations (including dict subclasses such as
          ``OrderedDict``) have each value converted with ``torch_float``
          onto ``cfg.alg.device``.
        - Any other ``ob`` must be a numpy array. It is cast to float, moved
          to the device and given a new size-1 axis at dim 1.

        All returned tensors are then ``squeeze(1)``-ed.

        Args:
            ob: Dict of array-likes, or a numpy array.
            sample: Forwarded to ``action_from_dist``.
            hidden_state: Passed to ``self.get_act_val``. A numpy copy is
                also recorded in the info dict as ``in_hidden_state``
                (``None`` when no hidden state was given).
            *args, **kwargs: Accepted for interface compatibility; unused.

        Returns:
            Tuple of (numpy action, info dict with numpy ``log_prob``,
            ``entropy``, ``val`` and ``in_hidden_state``, hidden state
            returned by ``get_act_val``).
        """
        self.eval_mode()

        # isinstance (rather than an exact type check) so dict subclasses,
        # e.g. OrderedDict, take the dict path instead of failing in
        # torch.from_numpy.
        if isinstance(ob, dict):
            # NOTE(review): unlike the array branch, dict observations get no
            # unsqueeze(dim=1), yet squeeze(1) is applied to the outputs
            # below. Confirm get_act_val returns a size-1 dim 1 for dict
            # inputs.
            t_ob = {
                key: torch_float(ob[key], device=cfg.alg.device)
                for key in ob
            }
        else:
            t_ob = torch.from_numpy(ob).float().to(
                cfg.alg.device).unsqueeze(dim=1)

        act_dist, val, out_hidden_state = self.get_act_val(
            t_ob, hidden_state=hidden_state)
        action = action_from_dist(act_dist, sample=sample)
        log_prob = action_log_prob(action, act_dist)
        entropy = action_entropy(act_dist, log_prob)
        in_hidden_state = (torch_to_np(hidden_state)
                           if hidden_state is not None else None)
        action_info = dict(log_prob=torch_to_np(log_prob.squeeze(1)),
                           entropy=torch_to_np(entropy.squeeze(1)),
                           val=torch_to_np(val.squeeze(1)),
                           in_hidden_state=in_hidden_state)
        return torch_to_np(action.squeeze(1)), action_info, out_hidden_state