# Assumed imports for this excerpt; project helpers such as torch_float,
# torch_to_np, action_from_dist, action_log_prob, action_entropy,
# soft_update, cfg, and TrajectorySet come from elsewhere in the package.
import numpy as np
import torch
import torch.nn.functional as F
from typing import Dict


def get_val(self, ob, action, tgt=False, first=True, *args, **kwargs):
    self.eval_mode()
    ob = torch_float(ob, device=cfg.alg.device)
    action = torch_float(action, device=cfg.alg.device)
    # Select q1/q2 (optionally the target network) by attribute name.
    idx = 1 if first else 2
    tgt_suffix = '_tgt' if tgt else ''
    q_func = getattr(self, f'q{idx}{tgt_suffix}')
    val = q_func((ob, action))[0]
    val = val.squeeze(-1)
    return val

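# torch_float (used throughout this file) is a project helper that is not
# defined here; the sketch below is a plausible assumption about its
# behavior, not the project's exact implementation.
def _torch_float_sketch(x, device='cpu'):
    if isinstance(x, torch.Tensor):
        return x.float().to(device)
    return torch.as_tensor(np.asarray(x), dtype=torch.float32, device=device)
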
def get_val(self, ob, *args, **kwargs):
    self.eval_mode()
    if isinstance(ob, dict):
        ob = {key: torch_float(ob[key], device=cfg.alg.device) for key in ob}
    else:
        ob = torch_float(ob, device=cfg.alg.device)
    val, body_out = self.critic(x=ob)
    val = val.squeeze(-1)
    return val

def get_act_val(self, ob, *args, **kwargs):
    if isinstance(ob, dict):
        ob = {key: torch_float(ob[key], device=cfg.alg.device) for key in ob}
    else:
        ob = torch_float(ob, device=cfg.alg.device)
    act_dist_cont, act_dist_disc, body_out = self.actor(ob)
    # Reuse the actor's body features when actor and critic share a trunk.
    if self.same_body:
        val, body_out = self.critic(body_x=body_out)
    else:
        val, body_out = self.critic(x=ob)
    val = val.squeeze(-1)
    return act_dist_cont, act_dist_disc, val

def optim_preprocess(self, data):
    self.train_mode()
    for key, val in data.items():
        data[key] = torch_float(val, device=cfg.alg.device)
    ob = data['ob']
    state = data['state']
    action = data['action']
    ret = data['ret']
    adv = data['adv']
    old_log_prob = data['log_prob']
    old_val = data['val']
    act_dist, val = self.get_act_val({"ob": ob, "state": state})
    log_prob = action_log_prob(action, act_dist)
    entropy = action_entropy(act_dist, log_prob)
    if not all(x.ndim == 1 for x in [val, entropy, log_prob]):
        raise ValueError('val, entropy, log_prob should be 1-dim!')
    processed_data = dict(val=val,
                          old_val=old_val,
                          ret=ret,
                          log_prob=log_prob,
                          old_log_prob=old_log_prob,
                          adv=adv,
                          entropy=entropy)
    return processed_data

def get_action(self, ob, sample=True, *args, **kwargs):
    self.eval_mode()
    t_ob = {key: torch_float(ob[key], device=cfg.alg.device) for key in ob}
    act_dist_cont, act_dist_disc, val = self.get_act_val(t_ob)
    action_cont = action_from_dist(act_dist_cont, sample=sample)
    action_discrete = action_from_dist(act_dist_disc, sample=sample)
    log_prob_disc = action_log_prob(action_discrete, act_dist_disc)
    log_prob_cont = action_log_prob(action_cont, act_dist_cont)
    entropy_disc = action_entropy(act_dist_disc, log_prob_disc)
    entropy_cont = action_entropy(act_dist_cont, log_prob_cont)
    # The hybrid policy factorizes, so the log-probs and entropies of the
    # continuous part and the discrete heads add up.
    log_prob = log_prob_cont + torch.sum(log_prob_disc, dim=1)
    entropy = entropy_cont + torch.sum(entropy_disc, dim=1)
    action_info = dict(log_prob=torch_to_np(log_prob),
                       entropy=torch_to_np(entropy),
                       val=torch_to_np(val))
    # The environment consumes one flat action: continuous dims first,
    # discrete dims after.
    action = np.concatenate(
        (torch_to_np(action_cont), torch_to_np(action_discrete)), axis=1)
    return action, action_info

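# Why log_prob_cont and log_prob_disc simply add in get_action above: the
# hybrid policy factorizes into independent continuous and discrete parts.
# A standalone sketch with torch.distributions; the distribution
# construction here is an illustrative assumption, not this project's code.
def _hybrid_log_prob_sketch(mu, std, logits, a_cont, a_disc):
    cont = torch.distributions.Normal(mu, std)
    disc = torch.distributions.Categorical(logits=logits)
    # Independent factors: the joint log-prob is the sum of the parts.
    return cont.log_prob(a_cont).sum(-1) + disc.log_prob(a_disc)
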
def get_val(self, ob, hidden_state=None, *args, **kwargs):
    self.eval_mode()
    if isinstance(ob, dict):
        ob = {key: torch_float(ob[key], device=cfg.alg.device) for key in ob}
    else:
        ob = torch_float(ob, device=cfg.alg.device)
    val, body_out, out_hidden_state = self.critic(x=ob,
                                                  hidden_state=hidden_state)
    val = val.squeeze(-1)
    return val, out_hidden_state

def get_action(self, ob, sample=True, *args, **kwargs):
    self.eval_mode()
    ob = torch_float(ob, device=cfg.alg.device)
    act_dist = self.actor(ob)[0]
    action = action_from_dist(act_dist, sample=sample)
    action_info = dict()
    return torch_to_np(action), action_info

def optim_preprocess(self, data):
    self.train_mode()
    for key, val in data.items():
        if val is not None:
            data[key] = torch_float(val, device=cfg.alg.device)
    ob = data['ob']
    action = data['action']
    ret = data['ret']
    adv = data['adv']
    old_log_prob = data['log_prob']
    old_val = data['val']
    done = data['done']
    hidden_state = data['hidden_state']
    # Move the layer axis first, as the recurrent module expects.
    hidden_state = hidden_state.permute(1, 0, 2)
    act_dist, val, out_hidden_state = self.get_act_val(
        ob, hidden_state=hidden_state, done=done)
    log_prob = action_log_prob(action, act_dist)
    entropy = action_entropy(act_dist, log_prob)
    processed_data = dict(val=val,
                          old_val=old_val,
                          ret=ret,
                          log_prob=log_prob,
                          old_log_prob=old_log_prob,
                          adv=adv,
                          entropy=entropy)
    return processed_data

def get_val(self, ob, hidden_state=None, *args, **kwargs):
    self.eval_mode()
    # Add a length-1 time axis for the recurrent critic.
    ob = torch_float(ob, device=cfg.alg.device).unsqueeze(dim=1)
    val, body_out, out_hidden_state = self.critic(x=ob,
                                                  hidden_state=hidden_state)
    val = val.squeeze(-1)
    return val, out_hidden_state

def get_act_val(self, ob, hidden_state=None, done=None, *args, **kwargs):
    if isinstance(ob, dict):
        ob = {key: torch_float(ob[key], device=cfg.alg.device) for key in ob}
    else:
        ob = torch_float(ob, device=cfg.alg.device)
    act_dist, body_out, out_hidden_state = self.actor(
        ob, hidden_state=hidden_state, done=done)
    val, body_out, _ = self.critic(body_x=body_out,
                                   hidden_state=hidden_state,
                                   done=done)
    val = val.squeeze(-1)
    return act_dist, val, out_hidden_state

def get_action(self, ob, sample=True, *args, **kwargs):
    self.eval_mode()
    if isinstance(ob, dict):
        t_ob = {
            key: torch_float(ob[key], device=cfg.alg.device) for key in ob
        }
    else:
        t_ob = torch_float(ob, device=cfg.alg.device)
    act_dist, val = self.get_act_val(t_ob)
    action = action_from_dist(act_dist, sample=sample)
    log_prob = action_log_prob(action, act_dist)
    entropy = action_entropy(act_dist, log_prob)
    action_info = dict(log_prob=torch_to_np(log_prob),
                       entropy=torch_to_np(entropy),
                       val=torch_to_np(val))
    return torch_to_np(action), action_info

def get_act_val(self, ob, *args, **kwargs):
    ob = torch_float(ob, device=cfg.alg.device)
    act_dist, body_out = self.actor(ob)
    if self.same_body:
        val, body_out = self.critic(body_x=body_out)
    else:
        val, body_out = self.critic(x=ob)
    val = val.squeeze(-1)
    return act_dist, val

def get_act_val(self, ob, hidden_state=None, done=None, *args, **kwargs):
    ob = torch_float(ob, device=cfg.alg.device)
    act_dist, body_out, out_hidden_state = self.actor(
        ob, hidden_state=hidden_state, done=done)
    val, body_out, _ = self.critic(body_x=body_out,
                                   hidden_state=hidden_state,
                                   done=done)
    val = val.squeeze(-1)
    return act_dist, val, out_hidden_state

def cal_gae_torch(gamma, lam, rewards, value_estimates, last_value, dones):
    device = value_estimates.device
    rewards = torch_float(rewards, device)
    value_estimates = torch_float(value_estimates, device)
    last_value = torch_float(last_value, device)
    if len(value_estimates.shape) > 1:
        last_value = last_value.view(1, -1)
    dones = torch_float(dones, device)
    advs = torch.zeros_like(rewards)
    last_gae_lam = 0
    # Append the bootstrap value so value_estimates[t + 1] exists at t = T - 1.
    value_estimates = torch.cat((value_estimates, last_value), dim=0)
    for t in reversed(range(rewards.shape[0])):
        non_terminal = 1.0 - dones[t]
        # delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
        delta = (rewards[t]
                 + gamma * value_estimates[t + 1].flatten() * non_terminal
                 - value_estimates[t].flatten())
        # A_t = delta_t + gamma * lam * (1 - done_t) * A_{t+1}
        last_gae_lam = delta + gamma * lam * non_terminal * last_gae_lam
        advs[t] = last_gae_lam.clone()
    return advs

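# A minimal usage sketch for cal_gae_torch; the shapes (T steps, N parallel
# environments) and values below are illustrative assumptions.
def _gae_usage_sketch():
    T, N = 5, 3
    rewards = torch.randn(T, N)
    values = torch.randn(T, N)
    last_value = torch.randn(N)   # bootstrap value V(s_T)
    dones = torch.zeros(T, N)     # no episode boundaries in this toy batch
    advs = cal_gae_torch(gamma=0.99, lam=0.95, rewards=rewards,
                         value_estimates=values, last_value=last_value,
                         dones=dones)
    return advs + values          # GAE advantages + baselines = lambda-returns
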
def get_act_val(self, ob, hidden_state=None, done=None, *args, **kwargs):
    if isinstance(ob, dict):
        ob = {key: torch_float(ob[key], device=cfg.alg.device) for key in ob}
    else:
        ob = torch_float(ob, device=cfg.alg.device)
    act_dist_cont, act_dist_disc, body_out, out_hidden_state = self.actor(
        ob, hidden_state=hidden_state, done=done)
    if self.same_body:
        val, body_out, _ = self.critic(body_x=body_out,
                                       hidden_state=hidden_state,
                                       done=done)
    else:
        val, body_out, _ = self.critic(x=ob,
                                       hidden_state=hidden_state,
                                       done=done)
    val = val.squeeze(-1)
    return act_dist_cont, act_dist_disc, val, out_hidden_state

def optim_preprocess(self, data):
    self.train_mode()
    for key, val in data.items():
        data[key] = torch_float(val, device=cfg.alg.device)
    ob = data['ob']
    state = data['state']
    action = data['action']
    ret = data['ret']
    adv = data['adv']
    old_log_prob = data['log_prob']
    old_val = data['val']
    done = data['done']
    hidden_state = data['hidden_state']
    hidden_state = hidden_state.permute(1, 0, 2)
    act_dist_cont, act_dist_disc, val, out_hidden_state = self.get_act_val(
        {"ob": ob, "state": state}, hidden_state=hidden_state, done=done)
    # Split the stored flat action back into its continuous and discrete parts.
    action_cont = action[:, :, :self.dim_cont]
    action_discrete = action[:, :, self.dim_cont:]
    log_prob_disc = action_log_prob(action_discrete, act_dist_disc)
    log_prob_cont = action_log_prob(action_cont, act_dist_cont)
    entropy_disc = action_entropy(act_dist_disc, log_prob_disc)
    entropy_cont = action_entropy(act_dist_cont, log_prob_cont)
    # Sum over the discrete heads; the head axis differs between flattened
    # (batch, heads) and sequential (batch, time, heads) data.
    if len(log_prob_disc.shape) == 2:
        log_prob = log_prob_cont + torch.sum(log_prob_disc, dim=1)
        entropy = entropy_cont + torch.sum(entropy_disc, dim=1)
    else:
        log_prob = log_prob_cont + torch.sum(log_prob_disc, dim=2)
        entropy = entropy_cont + torch.sum(entropy_disc, dim=2)
    processed_data = dict(val=val,
                          old_val=old_val,
                          ret=ret,
                          log_prob=log_prob,
                          old_log_prob=old_log_prob,
                          adv=adv,
                          entropy=entropy)
    return processed_data

def optimize(self, trajs: TrajectorySet, critic_only: bool) -> Dict[str, float]:
    obs, next_obs, actions, desired_goal, achieved_goal = trajs.data
    state = torch_float(
        self.normalizer(np.concatenate([obs, desired_goal], axis=-1)))
    next_state = torch_float(
        self.normalizer(np.concatenate([next_obs, desired_goal], axis=-1)))
    u = torch_float(actions)
    r = torch_float(self.reward_fn(achieved_goal, desired_goal, info=None))

    with torch.no_grad():
        q_next = self.critic_targ(next_state, self.actor_targ(next_state))
        y = r.view(-1, 1) + self.cfg.discount * q_next
        # Clip targets to the feasible return range, which for rewards in
        # [-1, 0] is [-1 / (1 - discount), 0].
        y = torch.clamp(y, -1 / (1 - self.cfg.discount), 0)
    q = self.critic(state, u)
    critic_loss = F.mse_loss(q, y)

    if not critic_only:
        u_pred = self.actor(state)
        actor_loss = -self.critic(state, u_pred).mean()
        # Regularize action magnitudes relative to the action range.
        actor_reg = torch.square(u_pred / self.cfg.action_range).mean()
        actor_loss += self.cfg.action_reg * actor_reg
        self.actor_optim.zero_grad()
        actor_loss.backward()
        self.actor_optim.step()
    else:
        actor_loss = torch.tensor(0)

    self.critic_optim.zero_grad()
    critic_loss.backward()
    self.critic_optim.step()
    return {
        'train/actor_loss': actor_loss.item(),
        'train/critic_loss': critic_loss.item()
    }

def get_action(self, obs: Dict[str, np.ndarray], sample: bool) -> np.ndarray:
    state = np.concatenate([obs['observation'], obs['desired_goal']], axis=-1)
    state_tensor = torch_float(self.normalizer(state))
    # Inference only: .numpy() requires tensors detached from the graph.
    with torch.no_grad():
        u = self.actor(state_tensor).numpy()
    if sample:
        # Gaussian exploration noise, clipped to the valid action range.
        noise_scale = self.cfg.noise_eps * self.cfg.action_range
        u += noise_scale * np.random.randn(*u.shape)
        u = np.clip(u, -self.cfg.action_range, self.cfg.action_range)
        # With probability epsilon, replace the action with a uniform sample.
        u_rand = np.random.uniform(low=-self.cfg.action_range,
                                   high=self.cfg.action_range,
                                   size=u.shape)
        use_rand = np.random.binomial(1, self.cfg.epsilon, size=u.shape[0])
        u += use_rand.reshape(-1, 1) * (u_rand - u)
    if self.pretrain is not None:
        with torch.no_grad():
            u += self.pretrain(state_tensor).numpy()
    return u

def optimize(self, data, *args, **kwargs):
    self.train_mode()
    for key, val in data.items():
        data[key] = torch_float(val, device=cfg.alg.device)
    obs = data['obs']
    actions = data['actions']
    next_obs = data['next_obs']
    rewards = data['rewards'].unsqueeze(-1)
    dones = data['dones'].unsqueeze(-1)
    q_info = self.update_q(obs=obs,
                           actions=actions,
                           next_obs=next_obs,
                           rewards=rewards,
                           dones=dones)
    pi_info = self.update_pi(obs=obs)
    alpha_info = self.update_alpha(pi_info['pi_neg_log_prob'])
    optim_info = {**q_info, **pi_info, **alpha_info}
    optim_info['alpha'] = self.alpha
    if hasattr(self, 'log_alpha'):
        optim_info['log_alpha'] = self.log_alpha.item()
    # Polyak-average the target Q networks toward the online networks.
    soft_update(self.q1_tgt, self.q1, cfg.alg.polyak)
    soft_update(self.q2_tgt, self.q2, cfg.alg.polyak)
    return optim_info

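# soft_update is used above but defined elsewhere in the project. A common
# Polyak-averaging implementation looks like the sketch below; the argument
# order (target, source, polyak) matches the call sites above, but the body
# is an assumption, not the project's exact code.
def _soft_update_sketch(tgt_net, src_net, polyak):
    with torch.no_grad():
        for p_tgt, p_src in zip(tgt_net.parameters(), src_net.parameters()):
            # tgt <- polyak * tgt + (1 - polyak) * src
            p_tgt.data.mul_(polyak)
            p_tgt.data.add_((1 - polyak) * p_src.data)
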
def get_action(self, ob, sample=True, hidden_state=None, *args, **kwargs):
    self.eval_mode()
    if isinstance(ob, dict):
        t_ob = {
            key: torch_float(ob[key], device=cfg.alg.device) for key in ob
        }
    else:
        # Add a length-1 time axis for the recurrent actor-critic.
        t_ob = torch_float(ob, device=cfg.alg.device).unsqueeze(dim=1)
    act_dist, val, out_hidden_state = self.get_act_val(
        t_ob, hidden_state=hidden_state)
    action = action_from_dist(act_dist, sample=sample)
    log_prob = action_log_prob(action, act_dist)
    entropy = action_entropy(act_dist, log_prob)
    in_hidden_state = (torch_to_np(hidden_state)
                       if hidden_state is not None else hidden_state)
    # Drop the time axis again before handing results back to the sampler.
    action_info = dict(log_prob=torch_to_np(log_prob.squeeze(1)),
                       entropy=torch_to_np(entropy.squeeze(1)),
                       val=torch_to_np(val.squeeze(1)),
                       in_hidden_state=in_hidden_state)
    return torch_to_np(action.squeeze(1)), action_info, out_hidden_state

def optim_preprocess(self, data):
    self.train_mode()
    for key, val in data.items():
        data[key] = torch_float(val, device=cfg.alg.device)
    ob = data['ob']
    state = data['state']
    action = data['action']
    ret = data['ret']
    adv = data['adv']
    old_log_prob = data['log_prob']
    old_val = data['val']
    act_dist_cont, act_dist_disc, val = self.get_act_val({
        "ob": ob,
        "state": state
    })
    # Split the stored flat action back into its continuous and discrete parts.
    action_cont = action[:, :self.dim_cont]
    action_discrete = action[:, self.dim_cont:]
    log_prob_disc = action_log_prob(action_discrete, act_dist_disc)
    log_prob_cont = action_log_prob(action_cont, act_dist_cont)
    entropy_disc = action_entropy(act_dist_disc, log_prob_disc)
    entropy_cont = action_entropy(act_dist_cont, log_prob_cont)
    log_prob = log_prob_cont + torch.sum(log_prob_disc, dim=1)
    entropy = entropy_cont + torch.sum(entropy_disc, dim=1)
    if not all(x.ndim == 1 for x in [val, entropy, log_prob]):
        raise ValueError('val, entropy, log_prob should be 1-dim!')
    processed_data = dict(val=val,
                          old_val=old_val,
                          ret=ret,
                          log_prob=log_prob,
                          old_log_prob=old_log_prob,
                          adv=adv,
                          entropy=entropy)
    return processed_data

def get_val(self, ob, *args, **kwargs):
    self.eval_mode()
    ob = torch_float(ob, device=cfg.alg.device)
    val, body_out = self.critic(x=ob)
    val = val.squeeze(-1)
    return val