Example #1
 def get_val(self, ob, action, tgt=False, first=True, *args, **kwargs):
     self.eval_mode()
     ob = torch_float(ob, device=cfg.alg.device)
     action = torch_float(action, device=cfg.alg.device)
     idx = 1 if first else 2
     tgt_suffix = '_tgt' if tgt else ''
     q_func = getattr(self, f'q{idx}{tgt_suffix}')
     val = q_func((ob, action))[0]
     val = val.squeeze(-1)
     return val
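Every example on this page relies on a torch_float helper that is not shown here. A minimal sketch of what such a helper might look like, assuming it only coerces numpy arrays, lists, or tensors to float32 tensors on the requested device (the real implementation may differ):

import numpy as np
import torch

def torch_float(x, device='cpu'):
    # Hypothetical sketch: coerce array-like input to a float32 tensor on `device`.
    if isinstance(x, torch.Tensor):
        return x.float().to(device)
    if isinstance(x, np.ndarray):
        return torch.from_numpy(x).float().to(device)
    return torch.tensor(x, dtype=torch.float32, device=device)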
Example #2
    def get_val(self, ob, *args, **kwargs):
        self.eval_mode()

        if type(ob) is dict:
            ob = {
                key: torch_float(ob[key], device=cfg.alg.device)
                for key in ob
            }
        else:
            ob = torch_float(ob, device=cfg.alg.device)

        val, body_out = self.critic(x=ob)
        val = val.squeeze(-1)
        return val
Example #3
 def get_act_val(self, ob, *args, **kwargs):
     if type(ob) is dict:
         ob = {
             key: torch_float(ob[key], device=cfg.alg.device)
             for key in ob
         }
     else:
         ob = torch_float(ob, device=cfg.alg.device)
     act_dist_cont, act_dist_disc, body_out = self.actor(ob)
     if self.same_body:
         val, body_out = self.critic(body_x=body_out)
     else:
         val, body_out = self.critic(x=ob)
     val = val.squeeze(-1)
     return act_dist_cont, act_dist_disc, val
Example #4
    def optim_preprocess(self, data):
        self.train_mode()
        for key, val in data.items():
            data[key] = torch_float(val, device=cfg.alg.device)
        ob = data['ob']
        state = data['state']
        action = data['action']
        ret = data['ret']
        adv = data['adv']
        old_log_prob = data['log_prob']
        old_val = data['val']

        act_dist, val = self.get_act_val({"ob": ob, "state": state})
        log_prob = action_log_prob(action, act_dist)
        entropy = action_entropy(act_dist, log_prob)
        if not all([x.ndim == 1 for x in [val, entropy, log_prob]]):
            raise ValueError('val, entropy, log_prob should be 1-dim!')
        processed_data = dict(val=val,
                              old_val=old_val,
                              ret=ret,
                              log_prob=log_prob,
                              old_log_prob=old_log_prob,
                              adv=adv,
                              entropy=entropy)
        return processed_data
Example #5
    def get_action(self, ob, sample=True, *args, **kwargs):
        self.eval_mode()
        t_ob = {key: torch_float(ob[key], device=cfg.alg.device) for key in ob}
        act_dist_cont, act_dist_disc, val = self.get_act_val(t_ob)
        action_cont = action_from_dist(act_dist_cont, sample=sample)
        action_discrete = action_from_dist(act_dist_disc, sample=sample)
        log_prob_disc = action_log_prob(action_discrete, act_dist_disc)
        log_prob_cont = action_log_prob(action_cont, act_dist_cont)
        entropy_disc = action_entropy(act_dist_disc, log_prob_disc)
        entropy_cont = action_entropy(act_dist_cont, log_prob_cont)
        #print("cont:", torch_to_np(log_prob_cont).reshape(-1, 1))
        log_prob = log_prob_cont + torch.sum(log_prob_disc, axis=1)
        entropy = entropy_cont + torch.sum(entropy_disc, axis=1)

        action_info = dict(log_prob=torch_to_np(log_prob),
                           entropy=torch_to_np(entropy),
                           val=torch_to_np(val))
        #print("cd", action_cont.shape, action_discrete.shape)
        action = np.concatenate(
            (torch_to_np(action_cont), torch_to_np(action_discrete)), axis=1)
        #print("action:", action)

        return action, action_info
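The example above mixes a continuous and a multi-branch discrete action head and adds their log-probabilities. A standalone sketch of the same joint log-probability computation, using plain torch.distributions and made-up shapes (batch of 4, 2 continuous dimensions, 3 discrete branches of 5 choices each) in place of the actor's act_dist_cont and act_dist_disc:

import torch
from torch.distributions import Categorical, Independent, Normal

dist_cont = Independent(Normal(torch.zeros(4, 2), torch.ones(4, 2)), 1)  # stands in for act_dist_cont
dist_disc = Categorical(logits=torch.zeros(4, 3, 5))                     # stands in for act_dist_disc

a_cont = dist_cont.sample()  # shape (4, 2)
a_disc = dist_disc.sample()  # shape (4, 3)

# Joint log-probability: continuous log-prob plus the sum over discrete branches,
# mirroring log_prob_cont + torch.sum(log_prob_disc, axis=1) above.
log_prob = dist_cont.log_prob(a_cont) + dist_disc.log_prob(a_disc).sum(dim=1)  # shape (4,)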
Example #6
    def get_val(self, ob, hidden_state=None, *args, **kwargs):
        self.eval_mode()

        if type(ob) is dict:
            ob = {
                key: torch_float(ob[key], device=cfg.alg.device)
                for key in ob
            }
        else:
            ob = torch_float(ob, device=cfg.alg.device)

        val, body_out, out_hidden_state = self.critic(
            x=ob, hidden_state=hidden_state)
        val = val.squeeze(-1)
        return val, out_hidden_state
Example #7
 def get_action(self, ob, sample=True, *args, **kwargs):
     self.eval_mode()
     ob = torch_float(ob, device=cfg.alg.device)
     act_dist = self.actor(ob)[0]
     action = action_from_dist(act_dist, sample=sample)
     action_info = dict()
     return torch_to_np(action), action_info
Example #8
    def optim_preprocess(self, data):
        self.train_mode()
        for key, val in data.items():
            if val is not None:
                data[key] = torch_float(val, device=cfg.alg.device)
        ob = data['ob']
        action = data['action']
        ret = data['ret']
        adv = data['adv']
        old_log_prob = data['log_prob']
        old_val = data['val']
        done = data['done']
        hidden_state = data['hidden_state']
        hidden_state = hidden_state.permute(1, 0, 2)

        act_dist, val, out_hidden_state = self.get_act_val(
            ob, hidden_state=hidden_state, done=done)
        log_prob = action_log_prob(action, act_dist)
        entropy = action_entropy(act_dist, log_prob)
        processed_data = dict(val=val,
                              old_val=old_val,
                              ret=ret,
                              log_prob=log_prob,
                              old_log_prob=old_log_prob,
                              adv=adv,
                              entropy=entropy)
        return processed_data
Example #9
 def get_val(self, ob, hidden_state=None, *args, **kwargs):
     self.eval_mode()
     ob = torch_float(ob, device=cfg.alg.device).unsqueeze(dim=1)
     val, body_out, out_hidden_state = self.critic(
         x=ob, hidden_state=hidden_state)
     val = val.squeeze(-1)
     return val, out_hidden_state
Example #10
    def get_act_val(self, ob, hidden_state=None, done=None, *args, **kwargs):
        if type(ob) is dict:
            ob = {
                key: torch_float(ob[key], device=cfg.alg.device)
                for key in ob
            }
        else:
            ob = torch_float(ob, device=cfg.alg.device)

        act_dist, body_out, out_hidden_state = self.actor(
            ob, hidden_state=hidden_state, done=done)

        val, body_out, _ = self.critic(body_x=body_out,
                                       hidden_state=hidden_state,
                                       done=done)
        val = val.squeeze(-1)
        return act_dist, val, out_hidden_state
Example #11
    def get_action(self, ob, sample=True, *args, **kwargs):
        self.eval_mode()
        if type(ob) is dict:
            t_ob = {
                key: torch_float(ob[key], device=cfg.alg.device)
                for key in ob
            }
        else:
            t_ob = torch_float(ob, device=cfg.alg.device)

        act_dist, val = self.get_act_val(t_ob)
        action = action_from_dist(act_dist, sample=sample)
        log_prob = action_log_prob(action, act_dist)
        entropy = action_entropy(act_dist, log_prob)
        action_info = dict(log_prob=torch_to_np(log_prob),
                           entropy=torch_to_np(entropy),
                           val=torch_to_np(val))
        return torch_to_np(action), action_info
Example #12
 def get_act_val(self, ob, *args, **kwargs):
     ob = torch_float(ob, device=cfg.alg.device)
     act_dist, body_out = self.actor(ob)
     if self.same_body:
         val, body_out = self.critic(body_x=body_out)
     else:
         val, body_out = self.critic(x=ob)
     val = val.squeeze(-1)
     return act_dist, val
Example #13
    def get_act_val(self, ob, hidden_state=None, done=None, *args, **kwargs):
        ob = torch_float(ob, device=cfg.alg.device)
        act_dist, body_out, out_hidden_state = self.actor(
            ob, hidden_state=hidden_state, done=done)

        val, body_out, _ = self.critic(body_x=body_out,
                                       hidden_state=hidden_state,
                                       done=done)
        val = val.squeeze(-1)
        return act_dist, val, out_hidden_state
Example #14
def cal_gae_torch(gamma, lam, rewards, value_estimates, last_value, dones):
    device = value_estimates.device
    rewards = torch_float(rewards, device)
    value_estimates = torch_float(value_estimates, device)
    last_value = torch_float(last_value, device)
    if len(value_estimates.shape) > 1:
        last_value = last_value.view(1, -1)
    dones = torch_float(dones, device)
    advs = torch.zeros_like(rewards).to(device)
    last_gae_lam = 0
    value_estimates = torch.cat((value_estimates,
                                 last_value),
                                dim=0)
    for t in reversed(range(rewards.shape[0])):
        non_terminal = 1.0 - dones[t]
        delta = rewards[t] + gamma * value_estimates[t + 1].flatten() * non_terminal - value_estimates[t].flatten()
        last_gae_lam = delta + gamma * lam * non_terminal * last_gae_lam
        advs[t] = last_gae_lam.clone()
    return advs
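A minimal usage sketch for cal_gae_torch, with a hypothetical single-environment rollout of five steps (the gamma and lam values are arbitrary):

import torch

rewards = torch.tensor([1.0, 0.0, 1.0, 0.0, 1.0])
values = torch.tensor([0.5, 0.4, 0.6, 0.3, 0.2])
last_value = torch.tensor([0.1])   # bootstrap value after the final step
dones = torch.tensor([0.0, 0.0, 0.0, 0.0, 1.0])

advs = cal_gae_torch(gamma=0.99, lam=0.95,
                     rewards=rewards,
                     value_estimates=values,
                     last_value=last_value,
                     dones=dones)
# A common follow-up (not shown on this page): targets for the value loss.
returns = advs + values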
Example #15
    def get_act_val(self, ob, hidden_state=None, done=None, *args, **kwargs):
        if type(ob) is dict:
            ob = {key: torch_float(ob[key], device=cfg.alg.device) for key in ob}
        else:
            ob = torch_float(ob, device=cfg.alg.device)

        #print(ob["state"].shape)
        act_dist_cont, act_dist_disc, body_out, out_hidden_state = self.actor(ob,
                                                                              hidden_state=hidden_state,
                                                                              done=done)
        if self.same_body:
            val, body_out, _ = self.critic(body_x=body_out,
                                           hidden_state=hidden_state,
                                           done=done)
        else:
            val, body_out, _ = self.critic(x=ob,
                                           hidden_state=hidden_state,
                                           done=done)
        val = val.squeeze(-1)
        return act_dist_cont, act_dist_disc, val, out_hidden_state
Example #16
    def optim_preprocess(self, data):
        self.train_mode()
        for key, val in data.items():
            data[key] = torch_float(val, device=cfg.alg.device)
        ob = data['ob']
        state = data['state']
        action = data['action']
        ret = data['ret']
        adv = data['adv']
        old_log_prob = data['log_prob']
        old_val = data['val']
        done = data['done']
        hidden_state = data['hidden_state']
        hidden_state = hidden_state.permute(1, 0, 2)

        act_dist_cont, act_dist_disc, val, out_hidden_state = self.get_act_val(
            {"ob": ob, "state": state},
            hidden_state=hidden_state,
            done=done)
        action_cont = action[:, :, :self.dim_cont]
        action_discrete = action[:, :, self.dim_cont:]
        log_prob_disc = action_log_prob(action_discrete, act_dist_disc)
        log_prob_cont = action_log_prob(action_cont, act_dist_cont)
        entropy_disc = action_entropy(act_dist_disc, log_prob_disc)
        entropy_cont = action_entropy(act_dist_cont, log_prob_cont)
        #print("cont:", torch_to_np(log_prob_cont).reshape(-1, 1))
        if len(log_prob_disc.shape) == 2:
            log_prob = log_prob_cont + torch.sum(log_prob_disc, axis=1)
            entropy = entropy_cont + torch.sum(entropy_disc, axis=1)
        else:
            log_prob = log_prob_cont + torch.sum(log_prob_disc, axis=2)
            entropy = entropy_cont + torch.sum(entropy_disc, axis=2)

        processed_data = dict(
            val=val,
            old_val=old_val,
            ret=ret,
            log_prob=log_prob,
            old_log_prob=old_log_prob,
            adv=adv,
            entropy=entropy
        )
        return processed_data
Example #17
    def optimize(self, trajs: TrajectorySet,
                 critic_only: bool) -> Dict[str, float]:
        obs, next_obs, actions, desired_goal, achieved_goal = trajs.data
        state = torch_float(
            self.normalizer(np.concatenate([obs, desired_goal], axis=-1)))
        next_state = torch_float(
            self.normalizer(np.concatenate([next_obs, desired_goal], axis=-1)))
        u = torch_float(actions)
        r = torch_float(self.reward_fn(achieved_goal, desired_goal, info=None))

        with torch.no_grad():
            q_next = self.critic_targ(next_state, self.actor_targ(next_state))
            y = r.view(-1, 1) + self.cfg.discount * q_next
            y = torch.clamp(y, -1 / (1 - self.cfg.discount), 0)
        q = self.critic(state, u)
        critic_loss = F.mse_loss(q, y)

        if not critic_only:
            u_pred = self.actor(state)
            actor_loss = -self.critic(state, u_pred).mean()
            actor_reg = torch.square(u_pred / self.cfg.action_range).mean()
            actor_loss += self.cfg.action_reg * actor_reg

            self.actor_optim.zero_grad()
            actor_loss.backward()
            self.actor_optim.step()
        else:
            actor_loss = torch.tensor(0)

        self.critic_optim.zero_grad()
        critic_loss.backward()
        self.critic_optim.step()

        return {
            'train/actor_loss': actor_loss.item(),
            'train/critic_loss': critic_loss.item()
        }
Example #18
 def get_action(self, obs: Dict[str, np.ndarray],
                sample: bool) -> np.ndarray:
     state = np.concatenate([obs['observation'], obs['desired_goal']],
                            axis=-1)
     state_tensor = torch_float(self.normalizer(state))
     u = self.actor(state_tensor).numpy()
     if sample:
         noise_scale = self.cfg.noise_eps * self.cfg.action_range
         u += noise_scale * np.random.randn(*u.shape)
         u = np.clip(u, -self.cfg.action_range, self.cfg.action_range)
         u_rand = np.random.uniform(low=-self.cfg.action_range,
                                    high=self.cfg.action_range,
                                    size=u.shape)
         use_rand = np.random.binomial(1, self.cfg.epsilon, size=u.shape[0])
         u += use_rand.reshape(-1, 1) * (u_rand - u)
     if self.pretrain is not None:
         u += self.pretrain(state_tensor).numpy()
     return u
Example #19
    def optimize(self, data, *args, **kwargs):
        self.train_mode()
        for key, val in data.items():
            data[key] = torch_float(val, device=cfg.alg.device)
        obs = data['obs']
        actions = data['actions']
        next_obs = data['next_obs']
        rewards = data['rewards'].unsqueeze(-1)
        dones = data['dones'].unsqueeze(-1)
        q_info = self.update_q(obs=obs,
                               actions=actions,
                               next_obs=next_obs,
                               rewards=rewards,
                               dones=dones)
        pi_info = self.update_pi(obs=obs)
        alpha_info = self.update_alpha(pi_info['pi_neg_log_prob'])
        optim_info = {**q_info, **pi_info, **alpha_info}
        optim_info['alpha'] = self.alpha
        if hasattr(self, 'log_alpha'):
            optim_info['log_alpha'] = self.log_alpha.item()

        soft_update(self.q1_tgt, self.q1, cfg.alg.polyak)
        soft_update(self.q2_tgt, self.q2, cfg.alg.polyak)
        return optim_info
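The optimizer above ends by calling a soft_update helper that is not shown on this page. A minimal sketch under the usual Polyak-averaging convention, where the third argument is the weight kept on the target parameters (the actual helper may define the weighting the other way around):

import torch

@torch.no_grad()
def soft_update(target_net, source_net, polyak):
    # Hypothetical sketch: target <- polyak * target + (1 - polyak) * source.
    for tgt_param, src_param in zip(target_net.parameters(), source_net.parameters()):
        tgt_param.data.mul_(polyak).add_(src_param.data, alpha=1 - polyak)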
Example #20
    def get_action(self, ob, sample=True, hidden_state=None, *args, **kwargs):
        self.eval_mode()

        if type(ob) is dict:
            t_ob = {
                key: torch_float(ob[key], device=cfg.alg.device)
                for key in ob
            }
        else:
            t_ob = torch.from_numpy(ob).float().to(
                cfg.alg.device).unsqueeze(dim=1)

        act_dist, val, out_hidden_state = self.get_act_val(
            t_ob, hidden_state=hidden_state)
        action = action_from_dist(act_dist, sample=sample)
        log_prob = action_log_prob(action, act_dist)
        entropy = action_entropy(act_dist, log_prob)
        in_hidden_state = torch_to_np(
            hidden_state) if hidden_state is not None else hidden_state
        action_info = dict(log_prob=torch_to_np(log_prob.squeeze(1)),
                           entropy=torch_to_np(entropy.squeeze(1)),
                           val=torch_to_np(val.squeeze(1)),
                           in_hidden_state=in_hidden_state)
        return torch_to_np(action.squeeze(1)), action_info, out_hidden_state
Example #21
    def optim_preprocess(self, data):
        self.train_mode()
        for key, val in data.items():
            data[key] = torch_float(val, device=cfg.alg.device)
        ob = data['ob']
        state = data['state']
        action = data['action']
        ret = data['ret']
        adv = data['adv']
        old_log_prob = data['log_prob']
        old_val = data['val']

        act_dist_cont, act_dist_disc, val = self.get_act_val({
            "ob": ob,
            "state": state
        })
        action_cont = action[:, :self.dim_cont]
        action_discrete = action[:, self.dim_cont:]
        log_prob_disc = action_log_prob(action_discrete, act_dist_disc)
        log_prob_cont = action_log_prob(action_cont, act_dist_cont)
        entropy_disc = action_entropy(act_dist_disc, log_prob_disc)
        entropy_cont = action_entropy(act_dist_cont, log_prob_cont)
        #print("cont:", torch_to_np(log_prob_cont).reshape(-1, 1))
        log_prob = log_prob_cont + torch.sum(log_prob_disc, axis=1)
        entropy = entropy_cont + torch.sum(entropy_disc, axis=1)

        if not all([x.ndim == 1 for x in [val, entropy, log_prob]]):
            raise ValueError('val, entropy, log_prob should be 1-dim!')
        processed_data = dict(val=val,
                              old_val=old_val,
                              ret=ret,
                              log_prob=log_prob,
                              old_log_prob=old_log_prob,
                              adv=adv,
                              entropy=entropy)
        return processed_data
Example #22
 def get_val(self, ob, *args, **kwargs):
     self.eval_mode()
     ob = torch_float(ob, device=cfg.alg.device)
     val, body_out = self.critic(x=ob)
     val = val.squeeze(-1)
     return val