def train(self, batch, **kwargs):
    states, actions, returns, action_logprobs = \
        batch["state"], batch["action"], batch["return"], \
        batch["action_logprob"]
    states = utils.any2device(states, device=self._device)
    actions = utils.any2device(actions, device=self._device)
    returns = utils.any2device(returns, device=self._device)
    old_logprobs = utils.any2device(action_logprobs, device=self._device)

    # actor loss
    _, logprobs = self.actor(states, logprob=actions)

    # REINFORCE objective function
    policy_loss = -torch.mean(logprobs * returns)

    entropy = -(torch.exp(logprobs) * logprobs).mean()
    entropy_loss = self.entropy_reg_coefficient * entropy
    policy_loss = policy_loss + entropy_loss

    # actor update
    actor_update_metrics = self.actor_update(policy_loss) or {}

    # metrics
    kl = 0.5 * (logprobs - old_logprobs).pow(2).mean()
    metrics = {
        "loss_actor": policy_loss.item(),
        "kl": kl.item(),
    }
    metrics = {**metrics, **actor_update_metrics}

    return metrics

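# A minimal sanity check of the REINFORCE objective used above, kept separate
# from the method: the gradient of `-mean(logprobs * returns)` with respect to
# the log-probabilities is `-returns / n`, i.e. the standard policy-gradient
# weighting. Toy tensors only; no actor network or batch involved.
import torch

logprobs = torch.tensor([-0.5, -1.2, -0.1], requires_grad=True)
returns = torch.tensor([1.0, 0.5, -0.5])
policy_loss = -torch.mean(logprobs * returns)
policy_loss.backward()
print(logprobs.grad)  # tensor([-0.3333, -0.1667,  0.1667]) == -returns / 3
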
def train(self, batch, actor_update=True, critic_update=True):
    states_t, actions_t, rewards_t, states_tp1, done_t = \
        batch["state"], batch["action"], batch["reward"], \
        batch["next_state"], batch["done"]

    states_t = utils.any2device(states_t, device=self._device)
    actions_t = utils.any2device(actions_t, device=self._device)
    rewards_t = utils.any2device(
        rewards_t, device=self._device
    ).unsqueeze(1)
    states_tp1 = utils.any2device(states_tp1, device=self._device)
    done_t = utils.any2device(done_t, device=self._device).unsqueeze(1)
    """
    states_t: [bs; history_len; observation_len]
    actions_t: [bs; action_len]
    rewards_t: [bs; 1]
    states_tp1: [bs; history_len; observation_len]
    done_t: [bs; 1]
    """

    policy_loss, value_loss = self._loss_fn(
        states_t, actions_t, rewards_t, states_tp1, done_t
    )

    metrics = self.update_step(
        policy_loss=policy_loss,
        value_loss=value_loss,
        actor_update=actor_update,
        critic_update=critic_update
    )

    return metrics

def _init(self, critics: List[CriticSpec], reward_scale: float = 1.0):
    self.reward_scale = reward_scale
    # @TODO: policy regularization

    critics = [x.to(self._device) for x in critics]
    target_critics = [copy.deepcopy(x).to(self._device) for x in critics]
    critics_optimizer = []
    critics_scheduler = []

    for critic in critics:
        critic_components = utils.get_trainer_components(
            agent=critic,
            loss_params=self._critic_loss_params,
            optimizer_params=self._critic_optimizer_params,
            scheduler_params=self._critic_scheduler_params,
            grad_clip_params=self._critic_grad_clip_params
        )
        critics_optimizer.append(critic_components["optimizer"])
        critics_scheduler.append(critic_components["scheduler"])

    self.critics = [self.critic] + critics
    self.critics_optimizer = [self.critic_optimizer] + critics_optimizer
    self.critics_scheduler = [self.critic_scheduler] + critics_scheduler
    self.target_critics = [self.target_critic] + target_critics

    # value distribution approximation
    critic_distribution = self.critic.distribution
    self._loss_fn = self._base_loss
    self._num_heads = self.critic.num_heads
    self._num_critics = len(self.critics)
    self._hyperbolic_constant = self.critic.hyperbolic_constant
    self._gammas = utils.hyperbolic_gammas(
        self._gamma,
        self._hyperbolic_constant,
        self._num_heads
    )
    self._gammas = utils.any2device(self._gammas, device=self._device)
    assert critic_distribution in [None, "categorical", "quantile"]

    if critic_distribution == "categorical":
        self.num_atoms = self.critic.num_atoms
        values_range = self.critic.values_range
        self.v_min, self.v_max = values_range
        self.delta_z = (self.v_max - self.v_min) / (self.num_atoms - 1)
        z = torch.linspace(
            start=self.v_min, end=self.v_max, steps=self.num_atoms
        )
        self.z = utils.any2device(z, device=self._device)
        self._loss_fn = self._categorical_loss
    elif critic_distribution == "quantile":
        self.num_atoms = self.critic.num_atoms
        tau_min = 1 / (2 * self.num_atoms)
        tau_max = 1 - tau_min
        tau = torch.linspace(
            start=tau_min, end=tau_max, steps=self.num_atoms
        )
        self.tau = utils.any2device(tau, device=self._device)
        self._loss_fn = self._quantile_loss
    else:
        assert self.critic_criterion is not None

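# The categorical and quantile branches above only precompute fixed supports
# for the distributional critic. A standalone sketch of those supports, with
# illustrative values for v_min, v_max and num_atoms (in the code above they
# come from the critic's configuration, not from here):
import torch

v_min, v_max, num_atoms = -10.0, 10.0, 51
delta_z = (v_max - v_min) / (num_atoms - 1)            # bin width of the categorical support
z = torch.linspace(v_min, v_max, steps=num_atoms)      # atom values, as in the "categorical" branch
tau_min = 1 / (2 * num_atoms)
tau = torch.linspace(tau_min, 1 - tau_min, steps=num_atoms)  # quantile midpoints, as in the "quantile" branch
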
def get_rollout(self, states, actions, rewards, dones):
    assert len(states) == len(actions) == len(rewards) == len(dones)

    trajectory_len = \
        rewards.shape[0] if dones[-1] else rewards.shape[0] - 1
    states_len = states.shape[0]

    states = utils.any2device(states, device=self._device)
    actions = utils.any2device(actions, device=self._device)
    rewards = np.array(rewards)[:trajectory_len]

    values = torch.zeros(
        (states_len + 1, self._num_heads, self._num_atoms)
    ).to(self._device)
    values[:states_len, ...] = self.critic(states).squeeze_(dim=2)
    # Each column corresponds to a different gamma
    values = values.cpu().numpy()[:trajectory_len + 1, ...]

    _, logprobs = self.actor(states, logprob=actions)
    logprobs = logprobs.cpu().numpy().reshape(-1)[:trajectory_len]

    # len x num_heads
    deltas = rewards[:, None, None] \
        + self._gammas[:, None] * values[1:] - values[:-1]

    # For each gamma in the list of gammas compute the
    # advantage and returns
    # len x num_heads x num_atoms
    advantages = np.stack([
        utils.geometric_cumsum(gamma * self.gae_lambda, deltas[:, i])
        for i, gamma in enumerate(self._gammas)
    ], axis=1)

    # len x num_heads
    returns = np.stack([
        utils.geometric_cumsum(gamma, rewards[:, None])[:, 0]
        for gamma in self._gammas
    ], axis=1)

    # final rollout
    dones = dones[:trajectory_len]
    values = values[:trajectory_len]

    assert len(logprobs) == len(advantages) \
        == len(dones) == len(returns) == len(values)

    rollout = {
        "action_logprob": logprobs,
        "advantage": advantages,
        "done": dones,
        "return": returns,
        "value": values,
    }

    return rollout

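# A single-gamma sketch of the advantage/return computation above, assuming
# `utils.geometric_cumsum(coef, x)` is a reverse discounted cumulative sum
# (the helper itself is not shown in this file, so its exact behavior is an
# assumption). The numbers are made up for illustration.
import numpy as np

def discounted_cumsum(coef, x):
    """Reverse cumulative sum: out[t] = x[t] + coef * out[t + 1]."""
    out = np.zeros_like(x, dtype=np.float64)
    running = 0.0
    for t in reversed(range(len(x))):
        running = x[t] + coef * running
        out[t] = running
    return out

rewards = np.array([1.0, 0.0, 1.0])
values = np.array([0.5, 0.4, 0.3, 0.0])  # V(s_t) for t = 0..2 plus a bootstrap value
gamma, gae_lambda = 0.99, 0.95
deltas = rewards + gamma * values[1:] - values[:-1]          # TD residuals
advantages = discounted_cumsum(gamma * gae_lambda, deltas)   # GAE(gamma, lambda)
returns = discounted_cumsum(gamma, rewards)                  # discounted returns
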
def _init(
    self,
    use_value_clipping: bool = True,
    gae_lambda: float = 0.95,
    clip_eps: float = 0.2,
    entropy_regularization: float = None
):
    self.use_value_clipping = use_value_clipping
    self.gae_lambda = gae_lambda
    self.clip_eps = clip_eps
    self.entropy_regularization = entropy_regularization

    critic_distribution = self.critic.distribution
    self._value_loss_fn = self._base_value_loss
    self._num_atoms = self.critic.num_atoms
    self._num_heads = self.critic.num_heads
    self._hyperbolic_constant = self.critic.hyperbolic_constant
    self._gammas = utils.hyperbolic_gammas(
        self._gamma,
        self._hyperbolic_constant,
        self._num_heads
    )
    # 1 x num_heads x 1
    self._gammas_torch = utils.any2device(
        self._gammas, device=self._device
    )[None, :, None]

    if critic_distribution == "categorical":
        self.num_atoms = self.critic.num_atoms
        values_range = self.critic.values_range
        self.v_min, self.v_max = values_range
        self.delta_z = (self.v_max - self.v_min) / (self._num_atoms - 1)
        z = torch.linspace(
            start=self.v_min, end=self.v_max, steps=self._num_atoms
        )
        self.z = utils.any2device(z, device=self._device)
        self._value_loss_fn = self._categorical_value_loss
    elif critic_distribution == "quantile":
        assert self.critic_criterion is not None
        self.num_atoms = self.critic.num_atoms
        tau_min = 1 / (2 * self._num_atoms)
        tau_max = 1 - tau_min
        tau = torch.linspace(
            start=tau_min, end=tau_max, steps=self._num_atoms
        )
        self.tau = utils.any2device(tau, device=self._device)
        self._value_loss_fn = self._quantile_value_loss

    if not self.use_value_clipping:
        assert self.critic_criterion is not None

def train(self, batch, **kwargs):
    (states_t, actions_t, returns_t, states_tp1, done_t,
     values_t, advantages_t, action_logprobs_t) = (
        batch["state"], batch["action"], batch["return"],
        batch["state_tp1"], batch["done"], batch["value"],
        batch["advantage"], batch["action_logprob"]
    )

    states_t = utils.any2device(states_t, device=self._device)
    actions_t = utils.any2device(actions_t, device=self._device)
    returns_t = utils.any2device(
        returns_t, device=self._device
    ).unsqueeze_(-1)
    states_tp1 = utils.any2device(states_tp1, device=self._device)
    done_t = utils.any2device(done_t, device=self._device)[:, None, None]
    values_t = utils.any2device(values_t, device=self._device)
    advantages_t = utils.any2device(advantages_t, device=self._device)
    action_logprobs_t = utils.any2device(
        action_logprobs_t, device=self._device
    )

    # critic loss
    value_loss = self._value_loss_fn(
        states_t, values_t, returns_t, states_tp1, done_t
    )

    # actor loss
    _, action_logprobs_tp0 = self.actor(states_t, logprob=actions_t)
    ratio = torch.exp(action_logprobs_tp0 - action_logprobs_t)
    ratio = ratio[:, None, None]  # the same ratio for each head of the critic
    policy_loss_unclipped = advantages_t * ratio
    policy_loss_clipped = advantages_t * torch.clamp(
        ratio, 1.0 - self.clip_eps, 1.0 + self.clip_eps
    )
    policy_loss = -torch.min(
        policy_loss_unclipped, policy_loss_clipped
    ).mean()

    if self.entropy_regularization is not None:
        entropy = -(
            torch.exp(action_logprobs_tp0) * action_logprobs_tp0
        ).mean()
        entropy_loss = self.entropy_regularization * entropy
        policy_loss = policy_loss + entropy_loss

    # actor update
    actor_update_metrics = self.actor_update(policy_loss) or {}

    # critic update
    critic_update_metrics = self.critic_update(value_loss) or {}

    # metrics
    kl = 0.5 * (action_logprobs_tp0 - action_logprobs_t).pow(2).mean()
    clipped_fraction = \
        (torch.abs(ratio - 1.0) > self.clip_eps).float().mean()
    metrics = {
        "loss_actor": policy_loss.item(),
        "loss_critic": value_loss.item(),
        "kl": kl.item(),
        "clipped_fraction": clipped_fraction.item()
    }
    metrics = {**metrics, **actor_update_metrics, **critic_update_metrics}

    return metrics

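# A toy illustration of the PPO clipped surrogate used above, with hypothetical
# numbers (no batch, no critic heads): the pessimistic `min` between the
# clipped and unclipped terms bounds how far a single update can move the
# policy away from the behavior that collected the data.
import torch

clip_eps = 0.2
advantages = torch.tensor([1.0, -1.0])
ratio = torch.tensor([1.5, 0.5])  # exp(logprob_new - logprob_old)
unclipped = advantages * ratio
clipped = advantages * ratio.clamp(1.0 - clip_eps, 1.0 + clip_eps)
policy_loss = -torch.min(unclipped, clipped).mean()                 # == -0.2 for these numbers
clipped_fraction = ((ratio - 1.0).abs() > clip_eps).float().mean()  # == 1.0 here
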
def get_rollout(self, states, actions, rewards, dones):
    trajectory_len = \
        rewards.shape[0] if dones[-1] else rewards.shape[0] - 1
    states = utils.any2device(states, device=self._device)
    actions = utils.any2device(actions, device=self._device)
    rewards = np.array(rewards)[:trajectory_len]

    _, logprobs = self.actor(states, logprob=actions)
    logprobs = logprobs.cpu().numpy().reshape(-1)[:trajectory_len]
    returns = utils.geometric_cumsum(self.gamma, rewards)[0]

    rollout = {"return": returns, "action_logprob": logprobs}

    return rollout

def _state2device(array: np.ndarray, device):
    array = utils.any2device(array, device)

    if isinstance(array, dict):
        array = {
            key: value.to(device).unsqueeze(0)
            for key, value in array.items()
        }
    else:
        array = array.to(device).unsqueeze(0)

    return array

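# A hedged usage sketch for the helper above: it presumably turns a single
# observation (an ndarray, or a dict of ndarrays for composite observations)
# into a batch of size 1 on the target device before it is fed to the actor.
# The observation shapes below are invented for illustration.
import numpy as np
import torch

device = torch.device("cpu")
single_obs = np.zeros((4, 84), dtype=np.float32)   # [history_len; observation_len]
state_t = _state2device(single_obs, device)        # -> tensor of shape [1, 4, 84]

dict_obs = {
    "proprio": np.zeros(17, dtype=np.float32),
    "goal": np.zeros(3, dtype=np.float32),
}
state_dict_t = _state2device(dict_obs, device)     # -> dict of [1, ...] tensors
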
def _init(self, entropy_regularization: float = None):
    self.entropy_regularization = entropy_regularization

    # value distribution approximation
    critic_distribution = self.critic.distribution
    self._loss_fn = self._base_loss
    self._num_heads = self.critic.num_heads
    self._hyperbolic_constant = self.critic.hyperbolic_constant
    self._gammas = utils.hyperbolic_gammas(
        self._gamma,
        self._hyperbolic_constant,
        self._num_heads
    )
    self._gammas = utils.any2device(self._gammas, device=self._device)
    assert critic_distribution in [None, "categorical", "quantile"]

    if critic_distribution == "categorical":
        assert self.critic_criterion is None
        self.num_atoms = self.critic.num_atoms
        values_range = self.critic.values_range
        self.v_min, self.v_max = values_range
        self.delta_z = (self.v_max - self.v_min) / (self.num_atoms - 1)
        z = torch.linspace(
            start=self.v_min, end=self.v_max, steps=self.num_atoms
        )
        self.z = utils.any2device(z, device=self._device)
        self._loss_fn = self._categorical_loss
    elif critic_distribution == "quantile":
        assert self.critic_criterion is not None
        self.num_atoms = self.critic.num_atoms
        tau_min = 1 / (2 * self.num_atoms)
        tau_max = 1 - tau_min
        tau = torch.linspace(
            start=tau_min, end=tau_max, steps=self.num_atoms
        )
        self.tau = utils.any2device(tau, device=self._device)
        self._loss_fn = self._quantile_loss
    else:
        assert self.critic_criterion is not None

def reset(self, exploration_strategy=None):
    from catalyst.rl.exploration import \
        ParameterSpaceNoise, OrnsteinUhlenbeckProcess

    if isinstance(exploration_strategy, OrnsteinUhlenbeckProcess):
        exploration_strategy.reset_state(self.env.action_space.shape[0])

    if isinstance(exploration_strategy, ParameterSpaceNoise) \
            and len(self.observations) > 1:
        states = self._get_states_history()
        states = utils.any2device(states, device=self._device)
        exploration_strategy.update_actor(self.agent, states)

    self._init_buffers()
    self._init_with_observation(self.env.reset())

def train(self, batch, **kwargs):
    states, actions, returns, values, advantages, action_logprobs = \
        batch["state"], batch["action"], batch["return"], \
        batch["value"], batch["advantage"], batch["action_logprob"]

    states = utils.any2device(states, device=self._device)
    actions = utils.any2device(actions, device=self._device)
    returns = utils.any2device(returns, device=self._device)
    old_values = utils.any2device(values, device=self._device)
    advantages = utils.any2device(advantages, device=self._device)
    old_logprobs = utils.any2device(action_logprobs, device=self._device)

    # critic loss
    values = self.critic(states).squeeze(-1)
    values_clip = old_values + torch.clamp(
        values - old_values, -self.clip_eps, self.clip_eps
    )
    value_loss_unclipped = (values - returns).pow(2)
    value_loss_clipped = (values_clip - returns).pow(2)
    value_loss = 0.5 * torch.max(
        value_loss_unclipped, value_loss_clipped
    ).mean()

    # actor loss
    _, logprobs = self.actor(states, logprob=actions)
    ratio = torch.exp(logprobs - old_logprobs)
    # the same ratio for each head of the critic
    policy_loss_unclipped = advantages * ratio[:, None]
    policy_loss_clipped = advantages * torch.clamp(
        ratio[:, None], 1.0 - self.clip_eps, 1.0 + self.clip_eps
    )
    policy_loss = -torch.min(
        policy_loss_unclipped, policy_loss_clipped
    ).mean()

    entropy = -(torch.exp(logprobs) * logprobs).mean()
    entropy_loss = self.entropy_reg_coefficient * entropy
    policy_loss = policy_loss + entropy_loss

    # actor update
    actor_update_metrics = self.actor_update(policy_loss) or {}

    # critic update
    critic_update_metrics = self.critic_update(value_loss) or {}

    # metrics
    kl = 0.5 * (logprobs - old_logprobs).pow(2).mean()
    clipped_fraction = \
        (torch.abs(ratio - 1.0) > self.clip_eps).float().mean()
    metrics = {
        "loss_actor": policy_loss.item(),
        "loss_critic": value_loss.item(),
        "kl": kl.item(),
        "clipped_fraction": clipped_fraction.item()
    }
    metrics = {**metrics, **actor_update_metrics, **critic_update_metrics}

    return metrics