Example #1
    def __init__(self, state_size: int, action_size: int, hidden_layers=(300, 200), config=None, device=None, **kwargs):
        config = config if config is not None else {}
        self.device = device if device is not None else DEVICE

        self.state_size = state_size
        self.action_size = action_size
        self.iteration = 0

        self.actor_lr = float(config.get('actor_lr', 3e-4))
        self.critic_lr = float(config.get('critic_lr', 1e-3))
        self.gamma: float = float(config.get("gamma", 0.99))
        self.ppo_ratio_clip: float = float(config.get("ppo_ratio_clip", 0.2))

        self.rollout_length: int = int(config.get("rollout_length", 48))  # "Much less than the episode length"
        self.batch_size: int = int(config.get("batch_size", self.rollout_length // 2))
        self.number_updates: int = int(config.get("number_updates", 5))
        self.entropy_weight: float = float(config.get("entropy_weight", 0.0005))
        self.value_loss_weight: float = float(config.get("value_loss_weight", 1.0))

        self.local_memory_buffer = {}
        self.memory = ReplayBuffer(batch_size=self.batch_size, buffer_size=self.rollout_length)

        self.action_scale: float = float(config.get("action_scale", 1))
        self.action_min: float = float(config.get("action_min", -2))
        self.action_max: float = float(config.get("action_max", 2))
        self.max_grad_norm_actor: float = float(config.get("max_grad_norm_actor", 100.0))
        self.max_grad_norm_critic: float = float(config.get("max_grad_norm_critic", 100.0))

        self.hidden_layers = config.get('hidden_layers', hidden_layers)
        self.actor = ActorBody(state_size, action_size, self.hidden_layers).to(self.device)
        self.critic = CriticBody(state_size, action_size, self.hidden_layers).to(self.device)
        self.policy = GaussianPolicy(action_size).to(self.device)

        self.actor_params = list(self.actor.parameters()) + [self.policy.std]
        self.critic_params = list(self.critic.parameters())  # keep a list so the params can be reused (e.g. for gradient clipping)
        self.actor_opt = torch.optim.SGD(self.actor_params, lr=self.actor_lr)
        self.critic_opt = torch.optim.SGD(self.critic_params, lr=self.critic_lr)
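
Every hyperparameter above is read from the optional `config` dict via `config.get(key, default)`. A minimal sketch of such a dict, using only keys the constructor actually reads (values mirror the defaults and are illustrative, not tuned):

# Illustrative config for the constructor above; omitted keys fall back to the defaults shown in __init__.
ppo_config = {
    "actor_lr": 3e-4,
    "critic_lr": 1e-3,
    "gamma": 0.99,
    "ppo_ratio_clip": 0.2,
    "rollout_length": 48,       # kept "much less than the episode length"
    "batch_size": 24,           # defaults to rollout_length // 2
    "number_updates": 5,
    "entropy_weight": 0.0005,
    "value_loss_weight": 1.0,
    "action_scale": 1.0,
    "action_min": -2.0,
    "action_max": 2.0,
    "max_grad_norm_actor": 100.0,
    "max_grad_norm_critic": 100.0,
    "hidden_layers": (300, 200),
}
# agent = PPOAgent(state_size, action_size, config=ppo_config)  # the full class is shown in Example #5
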
Example #2
    def __init__(self, in_features: FeatureType, action_size: int, **kwargs):
        """
        Parameters:
            hidden_layers: (default: (128, 128)) Sizes of the fully connected hidden layers.
            gamma: (default: 0.99) Discount value.
            tau: (default: 0.02) Soft copy fraction.
            batch_size: (default: 64) Number of samples in a batch.
            buffer_size: (default: 1e6) Size of the prioritized experience replay buffer.
            warm_up: (default: 0) Number of samples that need to be observed before learning starts.
            update_freq: (default: 1) Number of samples between policy updates.
            number_updates: (default: 1) Number of batch sampling/training rounds per `update_freq`.
            alpha: (default: 0.2) Weight of the log probs in the value function.
            alpha_lr: (default: None) If provided, alpha becomes a trainable parameter and `alpha_lr` is its learning rate.
            action_scale: (default: 1.) Scale for returned action values.
            max_grad_norm_alpha: (default: 1.) Gradient norm clipping for alpha.
            max_grad_norm_actor: (default: 10.) Gradient norm clipping for the actor.
            max_grad_norm_critic: (default: 10.) Gradient norm clipping for the critic.
            device: Defaults to CUDA if available.

        """
        super().__init__(**kwargs)
        self.device = kwargs.get("device", DEVICE)
        self.in_features: Tuple[int] = (in_features, ) if isinstance(
            in_features, int) else tuple(in_features)
        self.state_size: int = in_features if isinstance(
            in_features, int) else reduce(operator.mul, in_features)
        self.action_size = action_size

        self.gamma: float = float(self._register_param(kwargs, 'gamma', 0.99))
        self.tau: float = float(self._register_param(kwargs, 'tau', 0.02))
        self.batch_size: int = int(
            self._register_param(kwargs, 'batch_size', 64))
        self.buffer_size: int = int(
            self._register_param(kwargs, 'buffer_size', int(1e6)))
        self.memory = PERBuffer(self.batch_size, self.buffer_size)

        self.action_min = self._register_param(kwargs, 'action_min', -1)
        self.action_max = self._register_param(kwargs, 'action_max', 1)
        self.action_scale = self._register_param(kwargs, 'action_scale', 1)

        self.warm_up = int(self._register_param(kwargs, 'warm_up', 0))
        self.update_freq = int(self._register_param(kwargs, 'update_freq', 1))
        self.number_updates = int(
            self._register_param(kwargs, 'number_updates', 1))
        self.actor_number_updates = int(
            self._register_param(kwargs, 'actor_number_updates', 1))
        self.critic_number_updates = int(
            self._register_param(kwargs, 'critic_number_updates', 1))

        # Reason sequence initiation.
        hidden_layers = to_numbers_seq(
            self._register_param(kwargs, 'hidden_layers', (128, 128)))
        actor_hidden_layers = to_numbers_seq(
            self._register_param(kwargs, 'actor_hidden_layers', hidden_layers))
        critic_hidden_layers = to_numbers_seq(
            self._register_param(kwargs, 'critic_hidden_layers',
                                 hidden_layers))

        self.simple_policy = bool(
            self._register_param(kwargs, "simple_policy", False))
        if self.simple_policy:
            self.policy = MultivariateGaussianPolicySimple(
                self.action_size, **kwargs)
            self.actor = ActorBody(self.state_size,
                                   self.policy.param_dim * self.action_size,
                                   hidden_layers=actor_hidden_layers,
                                   device=self.device)
        else:
            self.policy = GaussianPolicy(actor_hidden_layers[-1],
                                         self.action_size,
                                         out_scale=self.action_scale,
                                         device=self.device)
            self.actor = ActorBody(self.state_size,
                                   actor_hidden_layers[-1],
                                   hidden_layers=actor_hidden_layers[:-1],
                                   device=self.device)

        self.double_critic = DoubleCritic(self.in_features,
                                          self.action_size,
                                          CriticBody,
                                          hidden_layers=critic_hidden_layers,
                                          device=self.device)
        self.target_double_critic = DoubleCritic(
            self.in_features,
            self.action_size,
            CriticBody,
            hidden_layers=critic_hidden_layers,
            device=self.device)

        # Target sequence initiation
        hard_update(self.target_double_critic, self.double_critic)

        # Optimization sequence initiation.
        self.target_entropy = -self.action_size
        alpha_lr = self._register_param(kwargs, "alpha_lr")
        self.alpha_lr = float(alpha_lr) if alpha_lr else None
        alpha_init = float(self._register_param(kwargs, "alpha", 0.2))
        self.log_alpha = torch.tensor(np.log(alpha_init),
                                      device=self.device,
                                      requires_grad=True)

        actor_lr = float(self._register_param(kwargs, 'actor_lr', 3e-4))
        critic_lr = float(self._register_param(kwargs, 'critic_lr', 3e-4))

        self.actor_params = list(self.actor.parameters()) + list(
            self.policy.parameters())
        self.critic_params = list(self.double_critic.parameters())
        self.actor_optimizer = optim.Adam(self.actor_params, lr=actor_lr)
        self.critic_optimizer = optim.Adam(list(self.critic_params),
                                           lr=critic_lr)
        if self.alpha_lr is not None:
            self.alpha_optimizer = optim.Adam([self.log_alpha],
                                              lr=self.alpha_lr)
        self.max_grad_norm_alpha = float(
            self._register_param(kwargs, "max_grad_norm_alpha", 1.0))
        self.max_grad_norm_actor = float(
            self._register_param(kwargs, "max_grad_norm_actor", 10.0))
        self.max_grad_norm_critic = float(
            self._register_param(kwargs, "max_grad_norm_critic", 10.0))

        # Breathe, my child.
        self.iteration = 0

        self._loss_actor = float('inf')
        self._loss_critic = float('inf')
        self._metrics: Dict[str, Union[float, Dict[str, float]]] = {}
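
The `_register_param` calls above are how this `__init__` pulls hyperparameters out of `**kwargs`. Its implementation is not shown here; the following is only a minimal sketch of the assumed behaviour (read a value with a fallback and remember it, so that `get_state()`/`load_state()` in Example #3 can persist it under `self._config`):

def _register_param(self, kwargs: dict, name: str, default=None):
    # Assumed behaviour, not the library's actual code: fetch the value with a
    # fallback and record it so the agent's configuration can be serialized.
    value = kwargs.get(name, default)
    self._config[name] = value
    return value
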
Example #3
class SACAgent(AgentBase):
    """
    Soft Actor-Critic.

    Uses stochastic policy and dual value network (two critics).

    Based on
    "Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor"
    by Haarnoja et al. (2018) (http://arxiv.org/abs/1801.01290).
    """

    name = "SAC"

    def __init__(self, in_features: FeatureType, action_size: int, **kwargs):
        """
        Parameters:
            hidden_layers: (default: (128, 128)) Sizes of the fully connected hidden layers.
            gamma: (default: 0.99) Discount value.
            tau: (default: 0.02) Soft copy fraction.
            batch_size: (default: 64) Number of samples in a batch.
            buffer_size: (default: 1e6) Size of the prioritized experience replay buffer.
            warm_up: (default: 0) Number of samples that need to be observed before learning starts.
            update_freq: (default: 1) Number of samples between policy updates.
            number_updates: (default: 1) Number of batch sampling/training rounds per `update_freq`.
            alpha: (default: 0.2) Weight of the log probs in the value function.
            alpha_lr: (default: None) If provided, alpha becomes a trainable parameter and `alpha_lr` is its learning rate.
            action_scale: (default: 1.) Scale for returned action values.
            max_grad_norm_alpha: (default: 1.) Gradient norm clipping for alpha.
            max_grad_norm_actor: (default: 10.) Gradient norm clipping for the actor.
            max_grad_norm_critic: (default: 10.) Gradient norm clipping for the critic.
            device: Defaults to CUDA if available.

        """
        super().__init__(**kwargs)
        self.device = kwargs.get("device", DEVICE)
        self.in_features: Tuple[int] = (in_features, ) if isinstance(
            in_features, int) else tuple(in_features)
        self.state_size: int = in_features if isinstance(
            in_features, int) else reduce(operator.mul, in_features)
        self.action_size = action_size

        self.gamma: float = float(self._register_param(kwargs, 'gamma', 0.99))
        self.tau: float = float(self._register_param(kwargs, 'tau', 0.02))
        self.batch_size: int = int(
            self._register_param(kwargs, 'batch_size', 64))
        self.buffer_size: int = int(
            self._register_param(kwargs, 'buffer_size', int(1e6)))
        self.memory = PERBuffer(self.batch_size, self.buffer_size)

        self.action_min = self._register_param(kwargs, 'action_min', -1)
        self.action_max = self._register_param(kwargs, 'action_max', 1)
        self.action_scale = self._register_param(kwargs, 'action_scale', 1)

        self.warm_up = int(self._register_param(kwargs, 'warm_up', 0))
        self.update_freq = int(self._register_param(kwargs, 'update_freq', 1))
        self.number_updates = int(
            self._register_param(kwargs, 'number_updates', 1))
        self.actor_number_updates = int(
            self._register_param(kwargs, 'actor_number_updates', 1))
        self.critic_number_updates = int(
            self._register_param(kwargs, 'critic_number_updates', 1))

        # Reason sequence initiation.
        hidden_layers = to_numbers_seq(
            self._register_param(kwargs, 'hidden_layers', (128, 128)))
        actor_hidden_layers = to_numbers_seq(
            self._register_param(kwargs, 'actor_hidden_layers', hidden_layers))
        critic_hidden_layers = to_numbers_seq(
            self._register_param(kwargs, 'critic_hidden_layers',
                                 hidden_layers))

        self.simple_policy = bool(
            self._register_param(kwargs, "simple_policy", False))
        if self.simple_policy:
            self.policy = MultivariateGaussianPolicySimple(
                self.action_size, **kwargs)
            self.actor = ActorBody(self.state_size,
                                   self.policy.param_dim * self.action_size,
                                   hidden_layers=actor_hidden_layers,
                                   device=self.device)
        else:
            self.policy = GaussianPolicy(actor_hidden_layers[-1],
                                         self.action_size,
                                         out_scale=self.action_scale,
                                         device=self.device)
            self.actor = ActorBody(self.state_size,
                                   actor_hidden_layers[-1],
                                   hidden_layers=actor_hidden_layers[:-1],
                                   device=self.device)

        self.double_critic = DoubleCritic(self.in_features,
                                          self.action_size,
                                          CriticBody,
                                          hidden_layers=critic_hidden_layers,
                                          device=self.device)
        self.target_double_critic = DoubleCritic(
            self.in_features,
            self.action_size,
            CriticBody,
            hidden_layers=critic_hidden_layers,
            device=self.device)

        # Target sequence initiation
        hard_update(self.target_double_critic, self.double_critic)

        # Optimization sequence initiation.
        self.target_entropy = -self.action_size
        alpha_lr = self._register_param(kwargs, "alpha_lr")
        self.alpha_lr = float(alpha_lr) if alpha_lr else None
        alpha_init = float(self._register_param(kwargs, "alpha", 0.2))
        self.log_alpha = torch.tensor(np.log(alpha_init),
                                      device=self.device,
                                      requires_grad=True)

        actor_lr = float(self._register_param(kwargs, 'actor_lr', 3e-4))
        critic_lr = float(self._register_param(kwargs, 'critic_lr', 3e-4))

        self.actor_params = list(self.actor.parameters()) + list(
            self.policy.parameters())
        self.critic_params = list(self.double_critic.parameters())
        self.actor_optimizer = optim.Adam(self.actor_params, lr=actor_lr)
        self.critic_optimizer = optim.Adam(list(self.critic_params),
                                           lr=critic_lr)
        if self.alpha_lr is not None:
            self.alpha_optimizer = optim.Adam([self.log_alpha],
                                              lr=self.alpha_lr)
        self.max_grad_norm_alpha = float(
            self._register_param(kwargs, "max_grad_norm_alpha", 1.0))
        self.max_grad_norm_actor = float(
            self._register_param(kwargs, "max_grad_norm_actor", 10.0))
        self.max_grad_norm_critic = float(
            self._register_param(kwargs, "max_grad_norm_critic", 10.0))

        # Breathe, my child.
        self.iteration = 0

        self._loss_actor = float('inf')
        self._loss_critic = float('inf')
        self._metrics: Dict[str, Union[float, Dict[str, float]]] = {}

    @property
    def alpha(self):
        return self.log_alpha.exp()

    @property
    def loss(self):
        return {'actor': self._loss_actor, 'critic': self._loss_critic}

    @loss.setter
    def loss(self, value):
        if isinstance(value, dict):
            self._loss_actor = value['actor']
            self._loss_critic = value['critic']
        else:
            self._loss_actor = value
            self._loss_critic = value

    def reset_agent(self) -> None:
        self.actor.reset_parameters()
        self.policy.reset_parameters()
        self.double_critic.reset_parameters()
        hard_update(self.target_double_critic, self.double_critic)

    def state_dict(self) -> Dict[str, dict]:
        """
        Returns networks' weights in order:
        Actor, Policy, DoubleCritic, TargetDoubleCritic
        """
        return {
            "actor": self.actor.state_dict(),
            "policy": self.policy.state_dict(),
            "double_critic": self.double_critic.state_dict(),
            "target_double_critic": self.target_double_critic.state_dict(),
        }

    @torch.no_grad()
    def act(self,
            state,
            epsilon: float = 0.0,
            deterministic=False) -> List[float]:
        if self.iteration < self.warm_up or self._rng.random() < epsilon:
            # Uniform sample in [action_min, action_max].
            random_action = torch.rand(self.action_size) * (
                self.action_max - self.action_min) + self.action_min
            return random_action.cpu().tolist()

        state = to_tensor(state).view(1,
                                      self.state_size).float().to(self.device)
        proto_action = self.actor(state)
        action = self.policy(proto_action, deterministic)

        return action.flatten().tolist()

    def step(self, state, action, reward, next_state, done):
        self.iteration += 1
        self.memory.add(
            state=state,
            action=action,
            reward=reward,
            next_state=next_state,
            done=done,
        )

        if self.iteration < self.warm_up:
            return

        if len(self.memory) > self.batch_size and (self.iteration %
                                                   self.update_freq) == 0:
            for _ in range(self.number_updates):
                self.learn(self.memory.sample())

    def compute_value_loss(self, states, actions, rewards, next_states,
                           dones) -> Tuple[Tensor, Tensor]:
        Q1_expected, Q2_expected = self.double_critic(states, actions)

        with torch.no_grad():
            proto_next_action = self.actor(next_states)
            next_actions = self.policy(proto_next_action)
            log_prob = self.policy.logprob
            assert next_actions.shape == (self.batch_size, self.action_size)
            assert log_prob.shape == (self.batch_size, 1)

            Q1_target_next, Q2_target_next = self.target_double_critic.act(
                next_states, next_actions)
            assert Q1_target_next.shape == Q2_target_next.shape == (
                self.batch_size, 1)

            Q_min = torch.min(Q1_target_next, Q2_target_next)
            QH_target = Q_min - self.alpha * log_prob
            assert QH_target.shape == (self.batch_size, 1)

            Q_target = rewards + self.gamma * QH_target * (1 - dones)
            assert Q_target.shape == (self.batch_size, 1)

        Q1_diff = Q1_expected - Q_target
        error_1 = Q1_diff.pow(2)
        mse_loss_1 = error_1.mean()
        self._metrics['value/critic1'] = {
            'mean': float(Q1_expected.mean()),
            'std': float(Q1_expected.std())
        }
        self._metrics['value/critic1_lse'] = float(mse_loss_1.item())

        Q2_diff = Q2_expected - Q_target
        error_2 = Q2_diff.pow(2)
        mse_loss_2 = error_2.mean()
        self._metrics['value/critic2'] = {
            'mean': float(Q2_expected.mean()),
            'std': float(Q2_expected.std())
        }
        self._metrics['value/critic2_lse'] = float(mse_loss_2.item())

        Q_diff = Q1_expected - Q2_expected
        self._metrics['value/Q_diff'] = {
            'mean': float(Q_diff.mean()),
            'std': float(Q_diff.std())
        }

        error = torch.min(error_1, error_2)
        loss = mse_loss_1 + mse_loss_2
        return loss, error

    def compute_policy_loss(self, states):
        proto_actions = self.actor(states)
        pred_actions = self.policy(proto_actions)
        log_prob = self.policy.logprob
        assert pred_actions.shape == (self.batch_size, self.action_size)

        Q_estimate = torch.min(*self.double_critic(states, pred_actions))
        assert Q_estimate.shape == (self.batch_size, 1)

        self._metrics['policy/entropy'] = -float(log_prob.detach().mean())
        loss = (self.alpha * log_prob - Q_estimate).mean()

        # Update alpha
        if self.alpha_lr is not None:
            self.alpha_optimizer.zero_grad()
            loss_alpha = -(self.alpha *
                           (log_prob + self.target_entropy).detach()).mean()
            loss_alpha.backward()
            nn.utils.clip_grad_norm_(self.log_alpha, self.max_grad_norm_alpha)
            self.alpha_optimizer.step()

        return loss

    def learn(self, samples):
        """update the critics and actors of all the agents """

        rewards = to_tensor(samples['reward']).float().to(self.device).view(
            self.batch_size, 1)
        dones = to_tensor(samples['done']).int().to(self.device).view(
            self.batch_size, 1)
        states = to_tensor(samples['state']).float().to(self.device).view(
            self.batch_size, self.state_size)
        next_states = to_tensor(samples['next_state']).float().to(
            self.device).view(self.batch_size, self.state_size)
        actions = to_tensor(samples['action']).to(self.device).view(
            self.batch_size, self.action_size)

        # Critic (value) update
        for _ in range(self.critic_number_updates):
            value_loss, error = self.compute_value_loss(
                states, actions, rewards, next_states, dones)
            self.critic_optimizer.zero_grad()
            value_loss.backward()
            nn.utils.clip_grad_norm_(self.critic_params,
                                     self.max_grad_norm_critic)
            self.critic_optimizer.step()
            self._loss_critic = value_loss.item()

        # Actor (policy) update
        for _ in range(self.actor_number_updates):
            policy_loss = self.compute_policy_loss(states)
            self.actor_optimizer.zero_grad()
            policy_loss.backward()
            nn.utils.clip_grad_norm_(self.actor_params,
                                     self.max_grad_norm_actor)
            self.actor_optimizer.step()
            self._loss_actor = policy_loss.item()

        if hasattr(self.memory, 'priority_update'):
            assert any(~torch.isnan(error))
            self.memory.priority_update(samples['index'], error.abs())

        soft_update(self.target_double_critic, self.double_critic, self.tau)

    def log_metrics(self,
                    data_logger: DataLogger,
                    step: int,
                    full_log: bool = False):
        data_logger.log_value("loss/actor", self._loss_actor, step)
        data_logger.log_value("loss/critic", self._loss_critic, step)
        data_logger.log_value("loss/alpha", self.alpha, step)

        if self.simple_policy:
            policy_params = {
                str(i): v
                for i, v in enumerate(
                    itertools.chain.from_iterable(self.policy.parameters()))
            }
            data_logger.log_values_dict("policy/param", policy_params, step)

        for name, value in self._metrics.items():
            if isinstance(value, dict):
                data_logger.log_values_dict(name, value, step)
            else:
                data_logger.log_value(name, value, step)

        if full_log:
            # TODO: Add Policy layers
            for idx, layer in enumerate(self.actor.layers):
                if hasattr(layer, "weight"):
                    data_logger.create_histogram(f"policy/layer_weights_{idx}",
                                                 layer.weight, step)
                if hasattr(layer, "bias") and layer.bias is not None:
                    data_logger.create_histogram(f"policy/layer_bias_{idx}",
                                                 layer.bias, step)

            for idx, layer in enumerate(self.double_critic.critic_1.layers):
                if hasattr(layer, "weight"):
                    data_logger.create_histogram(f"critic_1/layer_{idx}",
                                                 layer.weight, step)
                if hasattr(layer, "bias") and layer.bias is not None:
                    data_logger.create_histogram(f"critic_1/layer_bias_{idx}",
                                                 layer.bias, step)

            for idx, layer in enumerate(self.double_critic.critic_2.layers):
                if hasattr(layer, "weight"):
                    data_logger.create_histogram(f"critic_2/layer_{idx}",
                                                 layer.weight, step)
                if hasattr(layer, "bias") and layer.bias is not None:
                    data_logger.create_histogram(f"critic_2/layer_bias_{idx}",
                                                 layer.bias, step)

    def get_state(self):
        return dict(
            actor=self.actor.state_dict(),
            policy=self.policy.state_dict(),
            double_critic=self.double_critic.state_dict(),
            target_double_critic=self.target_double_critic.state_dict(),
            config=self._config,
        )

    def save_state(self, path: str):
        agent_state = self.get_state()
        torch.save(agent_state, path)

    def load_state(self, path: str):
        agent_state = torch.load(path)
        self._config = agent_state.get('config', {})
        self.__dict__.update(**self._config)

        self.actor.load_state_dict(agent_state['actor'])
        self.policy.load_state_dict(agent_state['policy'])
        self.double_critic.load_state_dict(agent_state['double_critic'])
        self.target_double_critic.load_state_dict(
            agent_state['target_double_critic'])
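
Putting the pieces together, here is a minimal interaction sketch for the `SACAgent` above. It assumes the class and its dependencies are importable (the import path is an assumption); the environment is replaced by random data purely to show the `act()`/`step()` call pattern:

import numpy as np
# from ai_traineree.agents.sac import SACAgent  # import path is an assumption

state_size, action_size = 3, 1
agent = SACAgent(in_features=state_size, action_size=action_size,
                 warm_up=100, batch_size=64, update_freq=1)

state = np.random.randn(state_size)
for _ in range(1000):
    action = agent.act(state)                  # random until warm_up samples are seen
    next_state = np.random.randn(state_size)   # stand-in for a real env transition
    reward, done = float(np.random.randn()), False
    agent.step(state, action, reward, next_state, done)
    state = next_state

print(agent.loss)  # {'actor': ..., 'critic': ...}
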
Example #4
    def __init__(self,
                 state_size: int,
                 action_size: int,
                 hidden_layers: Sequence[int] = (128, 128),
                 actor_lr: float = 2e-3,
                 critic_lr: float = 2e-3,
                 clip: Tuple[int, int] = (-1, 1),
                 alpha: float = 0.2,
                 device=None,
                 **kwargs):
        self.device = device if device is not None else DEVICE
        self.action_size = action_size

        # Reason sequence initiation.
        self.hidden_layers = kwargs.get('hidden_layers', hidden_layers)
        self.policy = GaussianPolicy(action_size).to(self.device)
        self.actor = ActorBody(state_size,
                               action_size,
                               hidden_layers=hidden_layers).to(self.device)

        self.double_critic = DoubleCritic(state_size, action_size,
                                          hidden_layers).to(self.device)
        self.target_double_critic = DoubleCritic(state_size, action_size,
                                                 hidden_layers).to(self.device)

        # Target sequence initiation
        hard_update(self.target_double_critic, self.double_critic)

        # Optimization sequence initiation.
        self.target_entropy = -action_size
        self.alpha_lr = kwargs.get("alpha_lr")
        alpha_init = kwargs.get("alpha", alpha)
        self.log_alpha = torch.tensor(np.log(alpha_init),
                                      device=self.device,
                                      requires_grad=True)

        self.actor_params = list(self.actor.parameters()) + [self.policy.std]
        self.critic_params = list(self.double_critic.parameters())
        self.actor_optimizer = optim.Adam(self.actor_params, lr=actor_lr)
        self.critic_optimizer = optim.Adam(list(self.critic_params),
                                           lr=critic_lr)
        if self.alpha_lr is not None:
            self.alpha_optimizer = optim.Adam([self.log_alpha],
                                              lr=self.alpha_lr)
        self.action_min = clip[0]
        self.action_max = clip[1]
        self.action_scale = kwargs.get('action_scale', 1)
        self.max_grad_norm_alpha: float = float(
            kwargs.get("max_grad_norm_alpha", 1.0))
        self.max_grad_norm_actor: float = float(
            kwargs.get("max_grad_norm_actor", 20.0))
        self.max_grad_norm_critic: float = float(
            kwargs.get("max_grad_norm_critic", 20.0))

        self.gamma: float = float(kwargs.get('gamma', 0.99))
        self.tau: float = float(kwargs.get('tau', 0.02))
        self.batch_size: int = int(kwargs.get('batch_size', 64))
        self.buffer_size: int = int(kwargs.get('buffer_size', int(1e6)))
        self.memory = Buffer(self.batch_size, self.buffer_size)

        self.warm_up: int = int(kwargs.get('warm_up', 0))
        self.update_freq: int = int(kwargs.get('update_freq', 1))
        self.number_updates: int = int(kwargs.get('number_updates', 1))

        # Breathe, my child.
        self.reset_agent()
        self.iteration = 0

        self.actor_loss = np.nan
        self.critic_loss = np.nan
Example #5
class PPOAgent(AgentType):

    name = "PPO"

    def __init__(self, state_size: int, action_size: int, hidden_layers=(300, 200), config=None, device=None, **kwargs):
        config = config if config is not None else {}
        self.device = device if device is not None else DEVICE

        self.state_size = state_size
        self.action_size = action_size
        self.iteration = 0

        self.actor_lr = float(config.get('actor_lr', 3e-4))
        self.critic_lr = float(config.get('critic_lr', 1e-3))
        self.gamma: float = float(config.get("gamma", 0.99))
        self.ppo_ratio_clip: float = float(config.get("ppo_ratio_clip", 0.2))

        self.rollout_length: int = int(config.get("rollout_length", 48))  # "Much less than the episode length"
        self.batch_size: int = int(config.get("batch_size", self.rollout_length // 2))
        self.number_updates: int = int(config.get("number_updates", 5))
        self.entropy_weight: float = float(config.get("entropy_weight", 0.0005))
        self.value_loss_weight: float = float(config.get("value_loss_weight", 1.0))

        self.local_memory_buffer = {}
        self.memory = ReplayBuffer(batch_size=self.batch_size, buffer_size=self.rollout_length)

        self.action_scale: float = float(config.get("action_scale", 1))
        self.action_min: float = float(config.get("action_min", -2))
        self.action_max: float = float(config.get("action_max", 2))
        self.max_grad_norm_actor: float = float(config.get("max_grad_norm_actor", 100.0))
        self.max_grad_norm_critic: float = float(config.get("max_grad_norm_critic", 100.0))

        self.hidden_layers = config.get('hidden_layers', hidden_layers)
        self.actor = ActorBody(state_size, action_size, self.hidden_layers).to(self.device)
        self.critic = CriticBody(state_size, action_size, self.hidden_layers).to(self.device)
        self.policy = GaussianPolicy(action_size).to(self.device)

        self.actor_params = list(self.actor.parameters()) + [self.policy.std]
        self.critic_params = list(self.critic.parameters())  # keep a list; a bare generator would be exhausted before clip_grad_norm_ in learn()
        self.actor_opt = torch.optim.SGD(self.actor_params, lr=self.actor_lr)
        self.critic_opt = torch.optim.SGD(self.critic_params, lr=self.critic_lr)

    def __clear_memory(self):
        self.memory = ReplayBuffer(batch_size=self.batch_size, buffer_size=self.rollout_length)

    def act(self, state, noise=0):
        with torch.no_grad():
            state = torch.tensor(state.reshape(1, -1).astype(np.float32)).to(self.device)
            action_mu = self.actor(state)
            value = self.critic(state, action_mu)

            dist = self.policy(action_mu)
            action = dist.sample()
            logprob = dist.log_prob(action)

            self.local_memory_buffer['value'] = value
            self.local_memory_buffer['logprob'] = logprob

            action = action.cpu().numpy().flatten()
            return np.clip(action*self.action_scale, self.action_min, self.action_max)

    def step(self, states, actions, rewards, next_state, done, **kwargs):
        self.iteration += 1

        self.memory.add(
            state=states, action=actions, reward=rewards, done=done,
            logprob=self.local_memory_buffer['logprob'], value=self.local_memory_buffer['value']
        )

        if self.iteration % self.rollout_length == 0:
            self.update()
            self.__clear_memory()

    def ppo_iter(self, mini_batch_size, states, actions, log_probs, returns, advantage):
        all_indices = np.arange(self.batch_size)
        for _ in range(self.batch_size // mini_batch_size):
            rand_ids = np.random.choice(all_indices, mini_batch_size, replace=False)
            yield states[rand_ids], actions[rand_ids], log_probs[rand_ids], returns[rand_ids], advantage[rand_ids]

    def _unpack_experiences(self, experiences):
        unpacked_experiences = defaultdict(lambda: [])
        for experience in experiences:
            unpacked_experiences['rewards'].append(experience.reward)
            unpacked_experiences['dones'].append(experience.done)
            unpacked_experiences['values'].append(experience.value)
            unpacked_experiences['states'].append(experience.state)
            unpacked_experiences['actions'].append(experience.action)
            unpacked_experiences['logprobs'].append(experience.logprob)

        return unpacked_experiences

    def update(self):
        experiences = self.memory.sample()
        rewards = torch.tensor(experiences['reward']).to(self.device)
        dones = torch.tensor(experiences['done']).type(torch.int).to(self.device)
        states = torch.tensor(experiences['state']).to(self.device)
        actions = torch.tensor(experiences['action']).to(self.device)
        values = torch.cat(experiences['value'])
        log_probs = torch.cat(experiences['logprob'])

        returns = revert_norm_returns(rewards, dones, self.gamma, device=self.device).unsqueeze(1)
        advantages = returns - values

        for _ in range(self.number_updates):
            for samples in self.ppo_iter(self.batch_size, states, actions, log_probs, returns, advantages):
                self.learn(samples)

    def learn(self, samples):
        state, action, old_log_probs, return_, advantage = samples

        action_mu = self.actor(state.detach())
        dist = self.policy(action_mu)
        value = self.critic(state.detach(), action_mu.detach())

        entropy = dist.entropy()
        new_log_probs = dist.log_prob(action.detach())

        r_theta = (new_log_probs - old_log_probs).exp()
        r_theta_clip = torch.clamp(r_theta, 1.0 - self.ppo_ratio_clip, 1.0 + self.ppo_ratio_clip)

        policy_loss = -torch.min(r_theta * advantage, r_theta_clip * advantage).mean()
        entropy_loss = -self.entropy_weight * entropy.mean()
        actor_loss = policy_loss + entropy_loss

        self.actor_opt.zero_grad()
        actor_loss.backward()
        nn.utils.clip_grad_norm_(self.actor_params, self.max_grad_norm_actor)
        self.actor_opt.step()
        self.actor_loss = actor_loss.item()
        # loss = policy_loss + value_loss + entropy_loss

        value_loss = self.value_loss_weight * 0.5 * (return_ - value).pow(2).mean()

        self.critic_opt.zero_grad()
        value_loss.backward()
        nn.utils.clip_grad_norm_(self.critic_params, self.max_grad_norm_critic)
        self.critic_opt.step()
        self.critic_loss = value_loss.mean().item()

    def log_writer(self, writer, episode):
        writer.add_scalar("loss/actor", self.actor_loss, episode)
        writer.add_scalar("loss/critic", self.critic_loss, episode)

    def save_state(self, path: str):
        agent_state = dict(policy=self.policy.state_dict())
        torch.save(agent_state, path)

    def load_state(self, path: str):
        agent_state = torch.load(path)
        self.policy.load_state_dict(agent_state['policy'])
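
For reference, the actor and critic updates in `learn` above correspond to the standard PPO clipped-surrogate loss with an entropy bonus, and a weighted squared error against the normalized returns:

r_t(\theta) = \exp\big(\log \pi_\theta(a_t \mid s_t) - \log \pi_{\theta_{\text{old}}}(a_t \mid s_t)\big)

\mathcal{L}_{\text{actor}} = -\,\mathbb{E}_t\big[\min\big(r_t(\theta)\,\hat{A}_t,\ \mathrm{clip}(r_t(\theta),\, 1-\epsilon,\, 1+\epsilon)\,\hat{A}_t\big)\big] \;-\; c_H\,\mathbb{E}_t\big[\mathcal{H}[\pi_\theta](s_t)\big]

\mathcal{L}_{\text{critic}} = \tfrac{c_V}{2}\,\mathbb{E}_t\big[\big(\hat{R}_t - V(s_t)\big)^2\big]

where \epsilon is `ppo_ratio_clip`, c_H is `entropy_weight`, c_V is `value_loss_weight`, \hat{R}_t are the normalized returns from `revert_norm_returns`, and \hat{A}_t = \hat{R}_t - V(s_t) uses the values stored during the rollout. In this code V(s_t) is evaluated as `critic(state, actor(state))`.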