Example #1
    def train(self, batch, **kwargs):
        states, actions, returns, action_logprobs = \
            batch["state"], batch["action"], batch["return"],\
            batch["action_logprob"]

        states = utils.any2device(states, device=self._device)
        actions = utils.any2device(actions, device=self._device)
        returns = utils.any2device(returns, device=self._device)
        old_logprobs = utils.any2device(action_logprobs, device=self._device)

        # actor loss
        _, logprobs = self.actor(states, logprob=actions)

        # REINFORCE objective function
        policy_loss = -torch.mean(logprobs * returns)

        entropy = -(torch.exp(logprobs) * logprobs).mean()
        entropy_loss = self.entropy_reg_coefficient * entropy
        policy_loss = policy_loss + entropy_loss

        # actor update
        actor_update_metrics = self.actor_update(policy_loss) or {}

        # metrics
        kl = 0.5 * (logprobs - old_logprobs).pow(2).mean()
        metrics = {
            "loss_actor": policy_loss.item(),
            "kl": kl.item(),
        }
        metrics = {**metrics, **actor_update_metrics}
        return metrics
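
The snippet above implements the REINFORCE objective with an entropy term. Below is a minimal sketch of the same loss on toy tensors, independent of Catalyst; the batch values and the entropy_reg_coefficient are illustrative assumptions.

import torch

# toy batch: log-probabilities of the taken actions and their discounted returns
logprobs = torch.log(torch.tensor([0.6, 0.3, 0.8]))   # shape [bs]
returns = torch.tensor([1.0, -0.5, 2.0])              # shape [bs]
entropy_reg_coefficient = 1e-3                        # illustrative value

# REINFORCE: maximize E[logprob * return], i.e. minimize its negative mean
policy_loss = -torch.mean(logprobs * returns)

# entropy term computed from the same log-probabilities, as in the example above
entropy = -(torch.exp(logprobs) * logprobs).mean()
policy_loss = policy_loss + entropy_reg_coefficient * entropy
print(policy_loss.item())
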
Example #2
    def train(self, batch, actor_update=True, critic_update=True):
        states_t, actions_t, rewards_t, states_tp1, done_t = \
            batch["state"], batch["action"], batch["reward"], \
            batch["next_state"], batch["done"]

        states_t = utils.any2device(states_t, device=self._device)
        actions_t = utils.any2device(actions_t, device=self._device)
        rewards_t = utils.any2device(
            rewards_t, device=self._device
        ).unsqueeze(1)
        states_tp1 = utils.any2device(states_tp1, device=self._device)
        done_t = utils.any2device(done_t, device=self._device).unsqueeze(1)
        """
        states_t: [bs; history_len; observation_len]
        actions_t: [bs; action_len]
        rewards_t: [bs; 1]
        states_tp1: [bs; history_len; observation_len]
        done_t: [bs; 1]
        """

        policy_loss, value_loss = self._loss_fn(
            states_t, actions_t, rewards_t, states_tp1, done_t
        )

        metrics = self.update_step(
            policy_loss=policy_loss,
            value_loss=value_loss,
            actor_update=actor_update,
            critic_update=critic_update
        )

        return metrics
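
The docstring above fixes the expected tensor shapes. As a hedged sketch, a matching batch could be assembled from numpy arrays as follows; bs, history_len, observation_len, and action_len are illustrative sizes, and the final train call is only indicated in a comment.

import numpy as np

bs, history_len, observation_len, action_len = 32, 4, 17, 6  # illustrative sizes

batch = {
    "state": np.zeros((bs, history_len, observation_len), dtype=np.float32),
    "action": np.zeros((bs, action_len), dtype=np.float32),
    "reward": np.zeros((bs,), dtype=np.float32),      # unsqueezed to [bs; 1] inside train
    "next_state": np.zeros((bs, history_len, observation_len), dtype=np.float32),
    "done": np.zeros((bs,), dtype=np.float32),        # unsqueezed to [bs; 1] inside train
}

# metrics = algorithm.train(batch, actor_update=True, critic_update=True)
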
Example #3
    def _init(self, critics: List[CriticSpec], reward_scale: float = 1.0):
        self.reward_scale = reward_scale
        # @TODO: policy regularization

        critics = [x.to(self._device) for x in critics]
        target_critics = [copy.deepcopy(x).to(self._device) for x in critics]
        critics_optimizer = []
        critics_scheduler = []

        for critic in critics:
            critic_components = utils.get_trainer_components(
                agent=critic,
                loss_params=self._critic_loss_params,
                optimizer_params=self._critic_optimizer_params,
                scheduler_params=self._critic_scheduler_params,
                grad_clip_params=self._critic_grad_clip_params)
            critics_optimizer.append(critic_components["optimizer"])
            critics_scheduler.append(critic_components["scheduler"])

        self.critics = [self.critic] + critics
        self.critics_optimizer = [self.critic_optimizer] + critics_optimizer
        self.critics_scheduler = [self.critic_scheduler] + critics_scheduler
        self.target_critics = [self.target_critic] + target_critics

        # value distribution approximation
        critic_distribution = self.critic.distribution
        self._loss_fn = self._base_loss
        self._num_heads = self.critic.num_heads
        self._num_critics = len(self.critics)
        self._hyperbolic_constant = self.critic.hyperbolic_constant
        self._gammas = \
            utils.hyperbolic_gammas(
                self._gamma,
                self._hyperbolic_constant,
                self._num_heads
            )
        self._gammas = utils.any2device(self._gammas, device=self._device)
        assert critic_distribution in [None, "categorical", "quantile"]

        if critic_distribution == "categorical":
            self.num_atoms = self.critic.num_atoms
            values_range = self.critic.values_range
            self.v_min, self.v_max = values_range
            self.delta_z = (self.v_max - self.v_min) / (self.num_atoms - 1)
            z = torch.linspace(start=self.v_min,
                               end=self.v_max,
                               steps=self.num_atoms)
            self.z = utils.any2device(z, device=self._device)
            self._loss_fn = self._categorical_loss
        elif critic_distribution == "quantile":
            self.num_atoms = self.critic.num_atoms
            tau_min = 1 / (2 * self.num_atoms)
            tau_max = 1 - tau_min
            tau = torch.linspace(start=tau_min,
                                 end=tau_max,
                                 steps=self.num_atoms)
            self.tau = utils.any2device(tau, device=self._device)
            self._loss_fn = self._quantile_loss
        else:
            assert self.critic_criterion is not None
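
The categorical and quantile branches above precompute the value-distribution supports. Here is a standalone sketch of the same two constructions, mirroring the linspace calls in the example; num_atoms and the value range are illustrative, and the loss functions themselves are not reproduced.

import torch

num_atoms = 51
v_min, v_max = -10.0, 10.0   # illustrative value range

# categorical (C51-style) support: evenly spaced atoms over [v_min, v_max]
delta_z = (v_max - v_min) / (num_atoms - 1)
z = torch.linspace(start=v_min, end=v_max, steps=num_atoms)        # [num_atoms]

# quantile support: midpoints of num_atoms equal-probability bins
tau_min = 1 / (2 * num_atoms)
tau_max = 1 - tau_min
tau = torch.linspace(start=tau_min, end=tau_max, steps=num_atoms)  # [num_atoms]
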
Example #4
    def get_rollout(self, states, actions, rewards, dones):
        assert len(states) == len(actions) == len(rewards) == len(dones)

        trajectory_len = \
            rewards.shape[0] if dones[-1] else rewards.shape[0] - 1
        states_len = states.shape[0]

        states = utils.any2device(states, device=self._device)
        actions = utils.any2device(actions, device=self._device)
        rewards = np.array(rewards)[:trajectory_len]
        values = torch.zeros(
            (states_len + 1, self._num_heads, self._num_atoms)
        ).to(self._device)
        values[:states_len, ...] = self.critic(states).squeeze_(dim=2)
        # Each column corresponds to a different gamma
        values = values.cpu().numpy()[:trajectory_len + 1, ...]
        _, logprobs = self.actor(states, logprob=actions)
        logprobs = logprobs.cpu().numpy().reshape(-1)[:trajectory_len]
        # len x num_heads x num_atoms
        deltas = rewards[:, None, None] \
            + self._gammas[:, None] * values[1:] - values[:-1]

        # For each gamma in the list of gammas, compute the
        # advantages and returns
        # len x num_heads x num_atoms
        advantages = np.stack(
            [
                utils.geometric_cumsum(gamma * self.gae_lambda, deltas[:, i])
                for i, gamma in enumerate(self._gammas)
            ],
            axis=1
        )

        # len x num_heads
        returns = np.stack(
            [
                utils.geometric_cumsum(gamma, rewards[:, None])[:, 0]
                for gamma in self._gammas
            ],
            axis=1
        )

        # final rollout
        dones = dones[:trajectory_len]
        values = values[:trajectory_len]
        assert len(logprobs) == len(advantages) \
            == len(dones) == len(returns) == len(values)
        rollout = {
            "action_logprob": logprobs,
            "advantage": advantages,
            "done": dones,
            "return": returns,
            "value": values,
        }

        return rollout
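
The rollout above delegates the discounted sums to utils.geometric_cumsum. Below is a minimal numpy sketch of a reverse discounted cumulative sum with the same intent; the axis conventions of the Catalyst helper are an assumption here, so treat this as an illustration rather than its implementation.

import numpy as np

def discounted_cumsum(coef, x):
    """Reverse cumulative sum along time: out[t] = x[t] + coef * out[t + 1]."""
    out = np.zeros_like(x, dtype=np.float64)
    running = np.zeros(x.shape[1:], dtype=np.float64)
    for t in reversed(range(len(x))):
        running = x[t] + coef * running
        out[t] = running
    return out

rewards = np.array([1.0, 0.0, 2.0])
gamma, gae_lambda = 0.99, 0.95

returns = discounted_cumsum(gamma, rewards)          # discounted returns per step
# with deltas[t] = r[t] + gamma * V(s[t+1]) - V(s[t]), GAE advantages would be
# discounted_cumsum(gamma * gae_lambda, deltas)
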
Example #5
    def _init(self,
              use_value_clipping: bool = True,
              gae_lambda: float = 0.95,
              clip_eps: float = 0.2,
              entropy_regularization: float = None):
        self.use_value_clipping = use_value_clipping
        self.gae_lambda = gae_lambda
        self.clip_eps = clip_eps
        self.entropy_regularization = entropy_regularization

        critic_distribution = self.critic.distribution
        self._value_loss_fn = self._base_value_loss
        self._num_atoms = self.critic.num_atoms
        self._num_heads = self.critic.num_heads
        self._hyperbolic_constant = self.critic.hyperbolic_constant
        self._gammas = \
            utils.hyperbolic_gammas(
                self._gamma,
                self._hyperbolic_constant,
                self._num_heads
            )
        # 1 x num_heads x 1
        self._gammas_torch = utils.any2device(
            self._gammas, device=self._device
        )[None, :, None]

        if critic_distribution == "categorical":
            self.num_atoms = self.critic.num_atoms
            values_range = self.critic.values_range
            self.v_min, self.v_max = values_range
            self.delta_z = (self.v_max - self.v_min) / (self._num_atoms - 1)
            z = torch.linspace(start=self.v_min,
                               end=self.v_max,
                               steps=self._num_atoms)
            self.z = utils.any2device(z, device=self._device)
            self._value_loss_fn = self._categorical_value_loss
        elif critic_distribution == "quantile":
            assert self.critic_criterion is not None

            self.num_atoms = self.critic.num_atoms
            tau_min = 1 / (2 * self._num_atoms)
            tau_max = 1 - tau_min
            tau = torch.linspace(start=tau_min,
                                 end=tau_max,
                                 steps=self._num_atoms)
            self.tau = utils.any2device(tau, device=self._device)
            self._value_loss_fn = self._quantile_value_loss

        if not self.use_value_clipping:
            assert self.critic_criterion is not None
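
The [None, :, None] indexing above reshapes the per-head discount factors to 1 x num_heads x 1 so that they broadcast against [bs; num_heads; num_atoms] tensors. A small sketch of that broadcast with illustrative sizes:

import torch

bs, num_heads, num_atoms = 8, 3, 1            # illustrative sizes
gammas = torch.tensor([0.9, 0.99, 0.999])     # one discount factor per head

gammas_torch = gammas[None, :, None]          # 1 x num_heads x 1
values_tp1 = torch.randn(bs, num_heads, num_atoms)
done = torch.zeros(bs, 1, 1)                  # broadcasts over heads and atoms

# each head is discounted with its own gamma; done masks terminal transitions
discounted = (1 - done) * gammas_torch * values_tp1   # bs x num_heads x num_atoms
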
Example #6
    def train(self, batch, **kwargs):
        states_t, actions_t, returns_t, states_tp1, done_t, \
            values_t, advantages_t, action_logprobs_t = \
            batch["state"], batch["action"], batch["return"], \
            batch["state_tp1"], batch["done"], batch["value"], \
            batch["advantage"], batch["action_logprob"]

        states_t = utils.any2device(states_t, device=self._device)
        actions_t = utils.any2device(actions_t, device=self._device)
        returns_t = utils.any2device(returns_t,
                                     device=self._device).unsqueeze_(-1)
        states_tp1 = utils.any2device(states_tp1, device=self._device)
        done_t = utils.any2device(done_t, device=self._device)[:, None, None]

        values_t = utils.any2device(values_t, device=self._device)
        advantages_t = utils.any2device(advantages_t, device=self._device)
        action_logprobs_t = utils.any2device(action_logprobs_t,
                                             device=self._device)

        # critic loss
        value_loss = self._value_loss_fn(states_t, values_t, returns_t,
                                         states_tp1, done_t)

        # actor loss
        _, action_logprobs_tp0 = self.actor(states_t, logprob=actions_t)

        ratio = torch.exp(action_logprobs_tp0 - action_logprobs_t)
        ratio = ratio[:, None, None]
        # The same ratio for each head of the critic
        policy_loss_unclipped = advantages_t * ratio
        policy_loss_clipped = advantages_t * torch.clamp(
            ratio, 1.0 - self.clip_eps, 1.0 + self.clip_eps)
        policy_loss = -torch.min(policy_loss_unclipped,
                                 policy_loss_clipped).mean()

        if self.entropy_regularization is not None:
            entropy = -(torch.exp(action_logprobs_tp0) *
                        action_logprobs_tp0).mean()
            entropy_loss = self.entropy_regularization * entropy
            policy_loss = policy_loss + entropy_loss

        # actor update
        actor_update_metrics = self.actor_update(policy_loss) or {}

        # critic update
        critic_update_metrics = self.critic_update(value_loss) or {}

        # metrics
        kl = 0.5 * (action_logprobs_tp0 - action_logprobs_t).pow(2).mean()
        clipped_fraction = \
            (torch.abs(ratio - 1.0) > self.clip_eps).float().mean()
        metrics = {
            "loss_actor": policy_loss.item(),
            "loss_critic": value_loss.item(),
            "kl": kl.item(),
            "clipped_fraction": clipped_fraction.item()
        }
        metrics = {**metrics, **actor_update_metrics, **critic_update_metrics}
        return metrics
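
The policy part of this train step is the standard PPO clipped surrogate. Here is a self-contained sketch of the same computation on toy tensors, without the per-head broadcasting; the clip_eps value and the toy log-probabilities are illustrative.

import torch

clip_eps = 0.2

logprobs_new = torch.tensor([-0.9, -1.2, -0.3])   # current policy
logprobs_old = torch.tensor([-1.0, -1.0, -1.0])   # policy that collected the data
advantages = torch.tensor([1.0, -0.5, 2.0])

ratio = torch.exp(logprobs_new - logprobs_old)
surrogate_unclipped = advantages * ratio
surrogate_clipped = advantages * torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps)
policy_loss = -torch.min(surrogate_unclipped, surrogate_clipped).mean()

# the same diagnostics as in the example above
kl = 0.5 * (logprobs_new - logprobs_old).pow(2).mean()
clipped_fraction = (torch.abs(ratio - 1.0) > clip_eps).float().mean()
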
Example #7
    def get_rollout(self, states, actions, rewards, dones):
        trajectory_len = \
            rewards.shape[0] if dones[-1] else rewards.shape[0] - 1

        states = utils.any2device(states, device=self._device)
        actions = utils.any2device(actions, device=self._device)
        rewards = np.array(rewards)[:trajectory_len]

        _, logprobs = self.actor(states, logprob=actions)
        logprobs = logprobs.cpu().numpy().reshape(-1)[:trajectory_len]

        returns = utils.geometric_cumsum(self.gamma, rewards)[0]

        rollout = {"return": returns, "action_logprob": logprobs}
        return rollout
Example #8
def _state2device(array: np.ndarray, device):
    array = utils.any2device(array, device)

    if isinstance(array, dict):
        array = {
            key: value.to(device).unsqueeze(0)
            for key, value in array.items()
        }
    else:
        array = array.to(device).unsqueeze(0)

    return array
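
For illustration, here is a self-contained sketch of the same transformation using plain torch; torch.from_numpy stands in for utils.any2device on numpy inputs, which is an assumption about that helper's role.

import numpy as np
import torch

def state2device_sketch(array, device):
    # plain observation: [observation_len] -> [1; observation_len] on the device
    if isinstance(array, dict):
        return {
            key: torch.from_numpy(value).to(device).unsqueeze(0)
            for key, value in array.items()
        }
    return torch.from_numpy(array).to(device).unsqueeze(0)

device = torch.device("cpu")
obs = np.zeros(17, dtype=np.float32)
print(state2device_sketch(obs, device).shape)       # torch.Size([1, 17])

dict_obs = {"features": np.zeros(17, dtype=np.float32),
            "goal": np.zeros(3, dtype=np.float32)}
print({k: v.shape for k, v in state2device_sketch(dict_obs, device).items()})
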
Example #9
    def _init(self, entropy_regularization: float = None):
        self.entropy_regularization = entropy_regularization

        # value distribution approximation
        critic_distribution = self.critic.distribution
        self._loss_fn = self._base_loss
        self._num_heads = self.critic.num_heads
        self._hyperbolic_constant = self.critic.hyperbolic_constant
        self._gammas = \
            utils.hyperbolic_gammas(
                self._gamma,
                self._hyperbolic_constant,
                self._num_heads
            )
        self._gammas = utils.any2device(self._gammas, device=self._device)
        assert critic_distribution in [None, "categorical", "quantile"]

        if critic_distribution == "categorical":
            assert self.critic_criterion is None
            self.num_atoms = self.critic.num_atoms
            values_range = self.critic.values_range
            self.v_min, self.v_max = values_range
            self.delta_z = (self.v_max - self.v_min) / (self.num_atoms - 1)
            z = torch.linspace(start=self.v_min,
                               end=self.v_max,
                               steps=self.num_atoms)
            self.z = utils.any2device(z, device=self._device)
            self._loss_fn = self._categorical_loss
        elif critic_distribution == "quantile":
            assert self.critic_criterion is not None
            self.num_atoms = self.critic.num_atoms
            tau_min = 1 / (2 * self.num_atoms)
            tau_max = 1 - tau_min
            tau = torch.linspace(start=tau_min,
                                 end=tau_max,
                                 steps=self.num_atoms)
            self.tau = utils.any2device(tau, device=self._device)
            self._loss_fn = self._quantile_loss
        else:
            assert self.critic_criterion is not None
Example #10
    def reset(self, exploration_strategy=None):
        from catalyst.rl.exploration import \
            ParameterSpaceNoise, OrnsteinUhlenbeckProcess

        if isinstance(exploration_strategy, OrnsteinUhlenbeckProcess):
            exploration_strategy.reset_state(self.env.action_space.shape[0])

        if isinstance(exploration_strategy, ParameterSpaceNoise) \
                and len(self.observations) > 1:
            states = self._get_states_history()
            states = utils.any2device(states, device=self._device)
            exploration_strategy.update_actor(self.agent, states)

        self._init_buffers()
        self._init_with_observation(self.env.reset())
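
reset_state above re-seeds the exploration noise for the new episode. As a generic illustration only (this is the textbook Ornstein-Uhlenbeck recurrence, not the catalyst.rl.exploration implementation, and theta, sigma, dt are illustrative), such a process can be sketched as:

import numpy as np

class SimpleOUNoise:
    """x_{t+1} = x_t + theta * (mu - x_t) * dt + sigma * sqrt(dt) * N(0, 1)."""

    def __init__(self, theta=0.15, sigma=0.2, mu=0.0, dt=1e-2):
        self.theta, self.sigma, self.mu, self.dt = theta, sigma, mu, dt
        self.state = None

    def reset_state(self, action_size):
        # start every episode from the long-run mean
        self.state = np.full(action_size, self.mu, dtype=np.float32)

    def sample(self):
        noise = self.sigma * np.sqrt(self.dt) * np.random.randn(*self.state.shape)
        self.state = self.state + self.theta * (self.mu - self.state) * self.dt + noise
        return self.state

ou = SimpleOUNoise()
ou.reset_state(action_size=6)
action_noise = ou.sample()
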
Example #11
    def train(self, batch, **kwargs):
        states, actions, returns, values, advantages, action_logprobs = \
            batch["state"], batch["action"], batch["return"], \
            batch["value"], batch["advantage"], batch["action_logprob"]

        states = utils.any2device(states, device=self._device)
        actions = utils.any2device(actions, device=self._device)
        returns = utils.any2device(returns, device=self._device)
        old_values = utils.any2device(values, device=self._device)
        advantages = utils.any2device(advantages, device=self._device)
        old_logprobs = utils.any2device(action_logprobs, device=self._device)

        # critic loss
        values = self.critic(states).squeeze(-1)

        values_clip = old_values + torch.clamp(values - old_values,
                                               -self.clip_eps, self.clip_eps)
        value_loss_unclipped = (values - returns).pow(2)
        value_loss_clipped = (values_clip - returns).pow(2)
        value_loss = 0.5 * torch.max(value_loss_unclipped,
                                     value_loss_clipped).mean()

        # actor loss
        _, logprobs = self.actor(states, logprob=actions)

        ratio = torch.exp(logprobs - old_logprobs)
        # The same ratio for each head of the critic
        policy_loss_unclipped = advantages * ratio[:, None]
        policy_loss_clipped = advantages * torch.clamp(
            ratio[:, None], 1.0 - self.clip_eps, 1.0 + self.clip_eps)
        policy_loss = -torch.min(policy_loss_unclipped,
                                 policy_loss_clipped).mean()

        entropy = -(torch.exp(logprobs) * logprobs).mean()
        entropy_loss = self.entropy_reg_coefficient * entropy
        policy_loss = policy_loss + entropy_loss

        # actor update
        actor_update_metrics = self.actor_update(policy_loss) or {}

        # critic update
        critic_update_metrics = self.critic_update(value_loss) or {}

        # metrics
        kl = 0.5 * (logprobs - old_logprobs).pow(2).mean()
        clipped_fraction = \
            (torch.abs(ratio - 1.0) > self.clip_eps).float().mean()
        metrics = {
            "loss_actor": policy_loss.item(),
            "loss_critic": value_loss.item(),
            "kl": kl.item(),
            "clipped_fraction": clipped_fraction.item()
        }
        metrics = {**metrics, **actor_update_metrics, **critic_update_metrics}
        return metrics
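
The critic part of this example clips how far the new value estimates may move from the old ones before computing the squared error. Here is a standalone sketch of that clipped value loss on toy tensors; the clip_eps value and the numbers are illustrative.

import torch

clip_eps = 0.2
old_values = torch.tensor([1.0, 0.5, -0.2])   # values stored at rollout time
values = torch.tensor([1.5, 0.4, 0.3])        # current critic predictions
returns = torch.tensor([1.2, 0.0, 0.1])

values_clip = old_values + torch.clamp(values - old_values, -clip_eps, clip_eps)
value_loss_unclipped = (values - returns).pow(2)
value_loss_clipped = (values_clip - returns).pow(2)
# pessimistic choice: take the larger of the two squared errors
value_loss = 0.5 * torch.max(value_loss_unclipped, value_loss_clipped).mean()
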