Example 1
class SACAgent(AgentBase):
    """
    Soft Actor-Critic.

    Uses a stochastic policy and a dual value network (two critics).

    Based on
    "Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor"
    by Haarnoja et al. (2018) (http://arxiv.org/abs/1801.01290).
    """

    name = "SAC"

    def __init__(self, in_features: FeatureType, action_size: int, **kwargs):
        """
        Parameters:
            hidden_layers: (default: (128, 128)) Sizes of the hidden layers in the fully connected networks.
            gamma: (default: 0.99) Discount value.
            tau: (default: 0.02) Soft copy fraction.
            batch_size: (default: 64) Number of samples in a batch.
            buffer_size: (default: 1e6) Size of the prioritized experience replay buffer.
            warm_up: (default: 0) Number of samples that need to be observed before learning starts.
            update_freq: (default: 1) Number of samples between policy updates.
            number_updates: (default: 1) Number of times of batch sampling/training per `update_freq`.
            alpha: (default: 0.2) Weight of log probs in value function.
            alpha_lr: (default: None) If provided, alpha is added as a trainable parameter and `alpha_lr` is its learning rate.
            action_scale: (default: 1.) Scale for returned action values.
            max_grad_norm_alpha: (default: 1.) Gradient clipping for the alpha.
            max_grad_norm_actor: (default: 10.) Gradient clipping for the actor.
            max_grad_norm_critic: (default: 10.) Gradient clipping for the critic.
            device: Defaults to CUDA if available.

        """
        super().__init__(**kwargs)
        self.device = kwargs.get("device", DEVICE)
        self.in_features: Tuple[int] = (in_features, ) if isinstance(
            in_features, int) else tuple(in_features)
        self.state_size: int = in_features if isinstance(
            in_features, int) else reduce(operator.mul, in_features)
        self.action_size = action_size

        self.gamma: float = float(self._register_param(kwargs, 'gamma', 0.99))
        self.tau: float = float(self._register_param(kwargs, 'tau', 0.02))
        self.batch_size: int = int(
            self._register_param(kwargs, 'batch_size', 64))
        self.buffer_size: int = int(
            self._register_param(kwargs, 'buffer_size', int(1e6)))
        self.memory = PERBuffer(self.batch_size, self.buffer_size)

        self.action_min = self._register_param(kwargs, 'action_min', -1)
        self.action_max = self._register_param(kwargs, 'action_max', 1)
        self.action_scale = self._register_param(kwargs, 'action_scale', 1)

        self.warm_up = int(self._register_param(kwargs, 'warm_up', 0))
        self.update_freq = int(self._register_param(kwargs, 'update_freq', 1))
        self.number_updates = int(
            self._register_param(kwargs, 'number_updates', 1))
        self.actor_number_updates = int(
            self._register_param(kwargs, 'actor_number_updates', 1))
        self.critic_number_updates = int(
            self._register_param(kwargs, 'critic_number_updates', 1))

        # Reason sequence initiation.
        hidden_layers = to_numbers_seq(
            self._register_param(kwargs, 'hidden_layers', (128, 128)))
        actor_hidden_layers = to_numbers_seq(
            self._register_param(kwargs, 'actor_hidden_layers', hidden_layers))
        critic_hidden_layers = to_numbers_seq(
            self._register_param(kwargs, 'critic_hidden_layers',
                                 hidden_layers))

        self.simple_policy = bool(
            self._register_param(kwargs, "simple_policy", False))
        if self.simple_policy:
            self.policy = MultivariateGaussianPolicySimple(
                self.action_size, **kwargs)
            self.actor = ActorBody(self.state_size,
                                   self.policy.param_dim * self.action_size,
                                   hidden_layers=actor_hidden_layers,
                                   device=self.device)
        else:
            self.policy = GaussianPolicy(actor_hidden_layers[-1],
                                         self.action_size,
                                         out_scale=self.action_scale,
                                         device=self.device)
            self.actor = ActorBody(self.state_size,
                                   actor_hidden_layers[-1],
                                   hidden_layers=actor_hidden_layers[:-1],
                                   device=self.device)

        self.double_critic = DoubleCritic(self.in_features,
                                          self.action_size,
                                          CriticBody,
                                          hidden_layers=critic_hidden_layers,
                                          device=self.device)
        self.target_double_critic = DoubleCritic(
            self.in_features,
            self.action_size,
            CriticBody,
            hidden_layers=critic_hidden_layers,
            device=self.device)

        # Target sequence initiation
        hard_update(self.target_double_critic, self.double_critic)

        # Optimization sequence initiation.
        self.target_entropy = -self.action_size
        alpha_lr = self._register_param(kwargs, "alpha_lr")
        self.alpha_lr = float(alpha_lr) if alpha_lr else None
        alpha_init = float(self._register_param(kwargs, "alpha", 0.2))
        self.log_alpha = torch.tensor(np.log(alpha_init),
                                      device=self.device,
                                      requires_grad=True)

        actor_lr = float(self._register_param(kwargs, 'actor_lr', 3e-4))
        critic_lr = float(self._register_param(kwargs, 'critic_lr', 3e-4))

        self.actor_params = list(self.actor.parameters()) + list(
            self.policy.parameters())
        self.critic_params = list(self.double_critic.parameters())
        self.actor_optimizer = optim.Adam(self.actor_params, lr=actor_lr)
        self.critic_optimizer = optim.Adam(list(self.critic_params),
                                           lr=critic_lr)
        if self.alpha_lr is not None:
            self.alpha_optimizer = optim.Adam([self.log_alpha],
                                              lr=self.alpha_lr)
        self.max_grad_norm_alpha = float(
            self._register_param(kwargs, "max_grad_norm_alpha", 1.0))
        self.max_grad_norm_actor = float(
            self._register_param(kwargs, "max_grad_norm_actor", 10.0))
        self.max_grad_norm_critic = float(
            self._register_param(kwargs, "max_grad_norm_critic", 10.0))

        # Breath, my child.
        self.iteration = 0

        self._loss_actor = float('inf')
        self._loss_critic = float('inf')
        self._metrics: Dict[str, Union[float, Dict[str, float]]] = {}

    @property
    def alpha(self):
        return self.log_alpha.exp()

    @property
    def loss(self):
        return {'actor': self._loss_actor, 'critic': self._loss_critic}

    @loss.setter
    def loss(self, value):
        if isinstance(value, dict):
            self._loss_actor = value['actor']
            self._loss_critic = value['critic']
        else:
            self._loss_actor = value
            self._loss_critic = value

    def reset_agent(self) -> None:
        self.actor.reset_parameters()
        self.policy.reset_parameters()
        self.double_critic.reset_parameters()
        hard_update(self.target_double_critic, self.double_critic)

    def state_dict(self) -> Dict[str, dict]:
        """
        Returns networks' weights in order:
        Actor, Policy, DoubleCritic, TargetDoubleCritic
        """
        return {
            "actor": self.actor.state_dict(),
            "policy": self.policy.state_dict(),
            "double_critic": self.double_critic.state_dict(),
            "target_double_critic": self.target_double_critic.state_dict(),
        }

    @torch.no_grad()
    def act(self,
            state,
            epsilon: float = 0.0,
            deterministic=False) -> List[float]:
        if self.iteration < self.warm_up or self._rng.random() < epsilon:
            random_action = torch.rand(self.action_size) * (
                self.action_max - self.action_min) + self.action_min
            return random_action.cpu().tolist()

        state = to_tensor(state).view(1,
                                      self.state_size).float().to(self.device)
        proto_action = self.actor(state)
        action = self.policy(proto_action, deterministic)

        return action.flatten().tolist()

    def step(self, state, action, reward, next_state, done):
        self.iteration += 1
        self.memory.add(
            state=state,
            action=action,
            reward=reward,
            next_state=next_state,
            done=done,
        )

        if self.iteration < self.warm_up:
            return

        if len(self.memory) > self.batch_size and (self.iteration %
                                                   self.update_freq) == 0:
            for _ in range(self.number_updates):
                self.learn(self.memory.sample())

    def compute_value_loss(self, states, actions, rewards, next_states,
                           dones) -> Tuple[Tensor, Tensor]:
        Q1_expected, Q2_expected = self.double_critic(states, actions)

        with torch.no_grad():
            proto_next_action = self.actor(next_states)
            next_actions = self.policy(proto_next_action)
            log_prob = self.policy.logprob
            assert next_actions.shape == (self.batch_size, self.action_size)
            assert log_prob.shape == (self.batch_size, 1)

            Q1_target_next, Q2_target_next = self.target_double_critic.act(
                next_states, next_actions)
            assert Q1_target_next.shape == Q2_target_next.shape == (
                self.batch_size, 1)

            Q_min = torch.min(Q1_target_next, Q2_target_next)
            QH_target = Q_min - self.alpha * log_prob
            assert QH_target.shape == (self.batch_size, 1)

            Q_target = rewards + self.gamma * QH_target * (1 - dones)
            assert Q_target.shape == (self.batch_size, 1)

        Q1_diff = Q1_expected - Q_target
        error_1 = Q1_diff.pow(2)
        mse_loss_1 = error_1.mean()
        self._metrics['value/critic1'] = {
            'mean': float(Q1_expected.mean()),
            'std': float(Q1_expected.std())
        }
        self._metrics['value/critic1_lse'] = float(mse_loss_1.item())

        Q2_diff = Q2_expected - Q_target
        error_2 = Q2_diff.pow(2)
        mse_loss_2 = error_2.mean()
        self._metrics['value/critic2'] = {
            'mean': float(Q2_expected.mean()),
            'std': float(Q2_expected.std())
        }
        self._metrics['value/critic2_lse'] = float(mse_loss_2.item())

        Q_diff = Q1_expected - Q2_expected
        self._metrics['value/Q_diff'] = {
            'mean': float(Q_diff.mean()),
            'std': float(Q_diff.std())
        }

        error = torch.min(error_1, error_2)
        loss = mse_loss_1 + mse_loss_2
        return loss, error

    def compute_policy_loss(self, states):
        proto_actions = self.actor(states)
        pred_actions = self.policy(proto_actions)
        log_prob = self.policy.logprob
        assert pred_actions.shape == (self.batch_size, self.action_size)

        Q_estimate = torch.min(*self.double_critic(states, pred_actions))
        assert Q_estimate.shape == (self.batch_size, 1)

        self._metrics['policy/entropy'] = -float(log_prob.detach().mean())
        loss = (self.alpha * log_prob - Q_estimate).mean()

        # Update alpha
        if self.alpha_lr is not None:
            self.alpha_optimizer.zero_grad()
            loss_alpha = -(self.alpha *
                           (log_prob + self.target_entropy).detach()).mean()
            loss_alpha.backward()
            nn.utils.clip_grad_norm_(self.log_alpha, self.max_grad_norm_alpha)
            self.alpha_optimizer.step()

        return loss

    def learn(self, samples):
        """update the critics and actors of all the agents """

        rewards = to_tensor(samples['reward']).float().to(self.device).view(
            self.batch_size, 1)
        dones = to_tensor(samples['done']).int().to(self.device).view(
            self.batch_size, 1)
        states = to_tensor(samples['state']).float().to(self.device).view(
            self.batch_size, self.state_size)
        next_states = to_tensor(samples['next_state']).float().to(
            self.device).view(self.batch_size, self.state_size)
        actions = to_tensor(samples['action']).to(self.device).view(
            self.batch_size, self.action_size)

        # Critic (value) update
        for _ in range(self.critic_number_updates):
            value_loss, error = self.compute_value_loss(
                states, actions, rewards, next_states, dones)
            self.critic_optimizer.zero_grad()
            value_loss.backward()
            nn.utils.clip_grad_norm_(self.critic_params,
                                     self.max_grad_norm_critic)
            self.critic_optimizer.step()
            self._loss_critic = value_loss.item()

        # Actor (policy) update
        for _ in range(self.actor_number_updates):
            policy_loss = self.compute_policy_loss(states)
            self.actor_optimizer.zero_grad()
            policy_loss.backward()
            nn.utils.clip_grad_norm_(self.actor_params,
                                     self.max_grad_norm_actor)
            self.actor_optimizer.step()
            self._loss_actor = policy_loss.item()

        if hasattr(self.memory, 'priority_update'):
            assert any(~torch.isnan(error))
            self.memory.priority_update(samples['index'], error.abs())

        soft_update(self.target_double_critic, self.double_critic, self.tau)

    def log_metrics(self,
                    data_logger: DataLogger,
                    step: int,
                    full_log: bool = False):
        data_logger.log_value("loss/actor", self._loss_actor, step)
        data_logger.log_value("loss/critic", self._loss_critic, step)
        data_logger.log_value("loss/alpha", self.alpha, step)

        if self.simple_policy:
            policy_params = {
                str(i): v
                for i, v in enumerate(
                    itertools.chain.from_iterable(self.policy.parameters()))
            }
            data_logger.log_values_dict("policy/param", policy_params, step)

        for name, value in self._metrics.items():
            if isinstance(value, dict):
                data_logger.log_values_dict(name, value, step)
            else:
                data_logger.log_value(name, value, step)

        if full_log:
            # TODO: Add Policy layers
            for idx, layer in enumerate(self.actor.layers):
                if hasattr(layer, "weight"):
                    data_logger.create_histogram(f"policy/layer_weights_{idx}",
                                                 layer.weight, step)
                if hasattr(layer, "bias") and layer.bias is not None:
                    data_logger.create_histogram(f"policy/layer_bias_{idx}",
                                                 layer.bias, step)

            for idx, layer in enumerate(self.double_critic.critic_1.layers):
                if hasattr(layer, "weight"):
                    data_logger.create_histogram(f"critic_1/layer_{idx}",
                                                 layer.weight, step)
                if hasattr(layer, "bias") and layer.bias is not None:
                    data_logger.create_histogram(f"critic_1/layer_bias_{idx}",
                                                 layer.bias, step)

            for idx, layer in enumerate(self.double_critic.critic_2.layers):
                if hasattr(layer, "weight"):
                    data_logger.create_histogram(f"critic_2/layer_{idx}",
                                                 layer.weight, step)
                if hasattr(layer, "bias") and layer.bias is not None:
                    data_logger.create_histogram(f"critic_2/layer_bias_{idx}",
                                                 layer.bias, step)

    def get_state(self):
        return dict(
            actor=self.actor.state_dict(),
            policy=self.policy.state_dict(),
            double_critic=self.double_critic.state_dict(),
            target_double_critic=self.target_double_critic.state_dict(),
            config=self._config,
        )

    def save_state(self, path: str):
        agent_state = self.get_state()
        torch.save(agent_state, path)

    def load_state(self, path: str):
        agent_state = torch.load(path)
        self._config = agent_state.get('config', {})
        self.__dict__.update(**self._config)

        self.actor.load_state_dict(agent_state['actor'])
        self.policy.load_state_dict(agent_state['policy'])
        self.double_critic.load_state_dict(agent_state['double_critic'])
        self.target_double_critic.load_state_dict(
            agent_state['target_double_critic'])
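A minimal interaction-loop sketch for the SACAgent above. Only the act()/step() calls mirror the agent's API as shown; the gymnasium environment ("Pendulum-v1") and the chosen hyperparameter values are illustrative assumptions rather than part of the original example.

import gymnasium as gym  # Assumed environment provider, not part of the agent code.

env = gym.make("Pendulum-v1")
state_size = env.observation_space.shape[0]
action_size = env.action_space.shape[0]
agent = SACAgent(state_size, action_size, batch_size=128, warm_up=1000)

for episode in range(100):
    state, _ = env.reset()
    done = False
    while not done:
        action = agent.act(state)                            # Sample from the stochastic policy.
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        agent.step(state, action, reward, next_state, done)  # Store transition; learn on schedule.
        state = next_state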
Example 2
class D3PGAgent(AgentBase):
    """Distributional DDPG (D3PG) [1].

    It is closely related to, and sits between, DDPG and D4PG. Compared to D4PG it lacks
    multi-actor support. It extends the DDPG agent with:
    1. Distributional critic update.
    2. N-step returns.
    3. Prioritization of the experience replay (PER).

    [1] "Distributed Distributional Deterministic Policy Gradients"
        (2018, ICLR) by G. Barth-Maron & M. Hoffman et al.

    """

    name = "D3PG"

    def __init__(self,
                 state_size: int,
                 action_size: int,
                 hidden_layers: Sequence[int] = (128, 128),
                 **kwargs):
        super().__init__(**kwargs)
        self.device = self._register_param(kwargs, "device", DEVICE)
        self.state_size = state_size
        self.action_size = action_size

        self.num_atoms = int(self._register_param(kwargs, 'num_atoms', 51))
        v_min = float(self._register_param(kwargs, 'v_min', -10))
        v_max = float(self._register_param(kwargs, 'v_max', 10))

        # Reason sequence initiation.
        self.action_min = float(self._register_param(kwargs, 'action_min', -1))
        self.action_max = float(self._register_param(kwargs, 'action_max', 1))
        self.action_scale = int(self._register_param(kwargs, 'action_scale',
                                                     1))

        self.gamma = float(self._register_param(kwargs, 'gamma', 0.99))
        self.tau = float(self._register_param(kwargs, 'tau', 0.02))
        self.batch_size: int = int(
            self._register_param(kwargs, 'batch_size', 64))
        self.buffer_size: int = int(
            self._register_param(kwargs, 'buffer_size', int(1e6)))
        self.buffer = PERBuffer(self.batch_size, self.buffer_size)

        self.n_steps = int(self._register_param(kwargs, "n_steps", 3))
        self.n_buffer = NStepBuffer(n_steps=self.n_steps, gamma=self.gamma)

        self.warm_up: int = int(self._register_param(kwargs, 'warm_up', 0))
        self.update_freq: int = int(
            self._register_param(kwargs, 'update_freq', 1))

        if kwargs.get("simple_policy", False):
            std_init = kwargs.get("std_init", 1.0)
            std_max = kwargs.get("std_max", 1.5)
            std_min = kwargs.get("std_min", 0.25)
            self.policy = MultivariateGaussianPolicySimple(self.action_size,
                                                           std_init=std_init,
                                                           std_min=std_min,
                                                           std_max=std_max,
                                                           device=self.device)
        else:
            self.policy = MultivariateGaussianPolicy(self.action_size,
                                                     device=self.device)

        self.actor_hidden_layers = to_numbers_seq(
            self._register_param(kwargs, 'actor_hidden_layers', hidden_layers))
        self.critic_hidden_layers = to_numbers_seq(
            self._register_param(kwargs, 'critic_hidden_layers',
                                 hidden_layers))

        # This looks messy but it's not that bad. Actor, critic_net and Critic(critic_net). Then the same for `target_`.
        self.actor = ActorBody(state_size,
                               self.policy.param_dim * action_size,
                               hidden_layers=self.actor_hidden_layers,
                               gate_out=torch.tanh,
                               device=self.device)
        critic_net = CriticBody(state_size,
                                action_size,
                                out_features=self.num_atoms,
                                hidden_layers=self.critic_hidden_layers,
                                device=self.device)
        self.critic = CategoricalNet(num_atoms=self.num_atoms,
                                     v_min=v_min,
                                     v_max=v_max,
                                     net=critic_net,
                                     device=self.device)

        self.target_actor = ActorBody(state_size,
                                      self.policy.param_dim * action_size,
                                      hidden_layers=self.actor_hidden_layers,
                                      gate_out=torch.tanh,
                                      device=self.device)
        target_critic_net = CriticBody(state_size,
                                       action_size,
                                       out_features=self.num_atoms,
                                       hidden_layers=self.critic_hidden_layers,
                                       device=self.device)
        self.target_critic = CategoricalNet(num_atoms=self.num_atoms,
                                            v_min=v_min,
                                            v_max=v_max,
                                            net=target_critic_net,
                                            device=self.device)

        # Target sequence initiation
        hard_update(self.target_actor, self.actor)
        hard_update(self.target_critic, self.critic)

        # Optimization sequence initiation.
        self.actor_lr = float(self._register_param(kwargs, 'actor_lr', 3e-4))
        self.critic_lr = float(self._register_param(kwargs, 'critic_lr', 3e-4))
        self.value_loss_func = nn.BCELoss(reduction='none')

        # self.actor_params = list(self.actor.parameters()) #+ list(self.policy.parameters())
        self.actor_params = list(self.actor.parameters()) + list(
            self.policy.parameters())
        self.actor_optimizer = Adam(self.actor_params, lr=self.actor_lr)
        self.critic_optimizer = Adam(self.critic.parameters(),
                                     lr=self.critic_lr)
        self.max_grad_norm_actor = float(
            self._register_param(kwargs, "max_grad_norm_actor", 50.0))
        self.max_grad_norm_critic = float(
            self._register_param(kwargs, "max_grad_norm_critic", 50.0))

        # Breath, my child.
        self.iteration = 0
        self._loss_actor = float('nan')
        self._loss_critic = float('nan')
        self._display_dist = torch.zeros(self.critic.z_atoms.shape)
        self._metric_batch_error = torch.zeros(self.batch_size)
        self._metric_batch_value_dist = torch.zeros(self.batch_size)

    @property
    def loss(self) -> Dict[str, float]:
        return {'actor': self._loss_actor, 'critic': self._loss_critic}

    @loss.setter
    def loss(self, value):
        if isinstance(value, dict):
            self._loss_actor = value['actor']
            self._loss_critic = value['critic']
        else:
            self._loss_actor = value
            self._loss_critic = value

    @torch.no_grad()
    def act(self, state, epsilon: float = 0.0) -> List[float]:
        """
        Returns actions for given state as per current policy.

        Parameters:
            state: Current available state from the environment.
            epsilon: Epsilon value in the epsilon-greedy policy.

        """
        state = to_tensor(state).float().to(self.device)
        if self._rng.random() < epsilon:
            action = self.action_scale * (torch.rand(self.action_size) - 0.5)

        else:
            action_seed = self.actor.act(state).view(1, -1)
            action_dist = self.policy(action_seed)
            action = action_dist.sample()
            action *= self.action_scale
            action = action.squeeze()

        # Purely for logging
        self._display_dist = self.target_critic.act(
            state, action.to(self.device)).squeeze().cpu()
        self._display_dist = F.softmax(self._display_dist, dim=0)

        return torch.clamp(action, self.action_min,
                           self.action_max).cpu().tolist()

    def step(self, state, action, reward, next_state, done):
        self.iteration += 1

        # Delay adding to buffer to account for n_steps (particularly the reward)
        self.n_buffer.add(state=state,
                          action=action,
                          reward=[reward],
                          done=[done],
                          next_state=next_state)
        if not self.n_buffer.available:
            return

        self.buffer.add(**self.n_buffer.get().get_dict())

        if self.iteration < self.warm_up:
            return

        if len(self.buffer) > self.batch_size and (self.iteration %
                                                   self.update_freq) == 0:
            self.learn(self.buffer.sample())

    def compute_value_loss(self,
                           states,
                           actions,
                           next_states,
                           rewards,
                           dones,
                           indices=None):
        # Q_w estimate
        value_dist_estimate = self.critic(states, actions)
        assert value_dist_estimate.shape == (self.batch_size, 1,
                                             self.num_atoms)
        value_dist = F.softmax(value_dist_estimate.squeeze(), dim=1)
        assert value_dist.shape == (self.batch_size, self.num_atoms)

        # Q_w' estimate via Bellman's dist operator
        next_action_seeds = self.target_actor.act(next_states)
        next_actions = self.policy(next_action_seeds).sample()
        assert next_actions.shape == (self.batch_size, self.action_size)

        target_value_dist_estimate = self.target_critic.act(
            next_states, next_actions)
        assert target_value_dist_estimate.shape == (self.batch_size, 1,
                                                    self.num_atoms)
        target_value_dist_estimate = target_value_dist_estimate.squeeze()
        assert target_value_dist_estimate.shape == (self.batch_size,
                                                    self.num_atoms)

        discount = self.gamma**self.n_steps
        target_value_projected = self.target_critic.dist_projection(
            rewards, 1 - dones, discount, target_value_dist_estimate)
        assert target_value_projected.shape == (self.batch_size,
                                                self.num_atoms)

        target_value_dist = F.softmax(target_value_dist_estimate,
                                      dim=-1).detach()
        assert target_value_dist.shape == (self.batch_size, self.num_atoms)

        # Comparing Q_w with Q_w'
        loss = self.value_loss_func(value_dist, target_value_projected)
        self._metric_batch_error = loss.detach().sum(dim=-1)
        samples_error = loss.sum(dim=-1).pow(2)
        loss_critic = samples_error.mean()

        if hasattr(self.buffer, 'priority_update') and indices is not None:
            assert (~torch.isnan(samples_error)).any()
            self.buffer.priority_update(indices,
                                        samples_error.detach().cpu().numpy())

        return loss_critic

    def compute_policy_loss(self, states):
        # Compute actor loss
        pred_action_seeds = self.actor(states)
        pred_actions = self.policy(pred_action_seeds).rsample()
        # Negative because the optimizer minimizes, but we want to maximize the value
        value_dist = self.critic(states, pred_actions)
        self._metric_batch_value_dist = value_dist.detach()
        # Estimate on Z support
        return -torch.mean(value_dist * self.critic.z_atoms)

    def learn(self, experiences):
        """Update critics and actors"""
        rewards = to_tensor(experiences['reward']).float().to(self.device)
        dones = to_tensor(experiences['done']).type(torch.int).to(self.device)
        states = to_tensor(experiences['state']).float().to(self.device)
        actions = to_tensor(experiences['action']).to(self.device)
        next_states = to_tensor(experiences['next_state']).float().to(
            self.device)
        assert rewards.shape == dones.shape == (self.batch_size, 1)
        assert states.shape == next_states.shape == (self.batch_size,
                                                     self.state_size)
        assert actions.shape == (self.batch_size, self.action_size)

        indices = None
        if hasattr(self.buffer, 'priority_update'):  # When using PER buffer
            indices = experiences['index']
        loss_critic = self.compute_value_loss(states, actions, next_states,
                                              rewards, dones, indices)

        # Value (critic) optimization
        self.critic_optimizer.zero_grad()
        loss_critic.backward()
        nn.utils.clip_grad_norm_(self.critic.parameters(),
                                 self.max_grad_norm_critic)
        self.critic_optimizer.step()
        self._loss_critic = float(loss_critic.item())

        # Policy (actor) optimization
        loss_actor = self.compute_policy_loss(states)
        self.actor_optimizer.zero_grad()
        loss_actor.backward()
        nn.utils.clip_grad_norm_(self.actor.parameters(),
                                 self.max_grad_norm_actor)
        self.actor_optimizer.step()
        self._loss_actor = float(loss_actor.item())

        # Networks gradual sync
        soft_update(self.target_actor, self.actor, self.tau)
        soft_update(self.target_critic, self.critic, self.tau)

    def state_dict(self) -> Dict[str, dict]:
        """Describes agent's networks.

        Returns:
            state: (dict) Provides actors and critics states.

        """
        return {
            "actor": self.actor.state_dict(),
            "target_actor": self.target_actor.state_dict(),
            "critic": self.critic.state_dict(),
            "target_critic": self.target_critic()
        }

    def log_metrics(self,
                    data_logger: DataLogger,
                    step: int,
                    full_log: bool = False):
        data_logger.log_value("loss/actor", self._loss_actor, step)
        data_logger.log_value("loss/critic", self._loss_critic, step)
        policy_params = {
            str(i): v
            for i, v in enumerate(
                itertools.chain.from_iterable(self.policy.parameters()))
        }
        data_logger.log_values_dict("policy/param", policy_params, step)

        data_logger.create_histogram('metric/batch_errors',
                                     self._metric_batch_error, step)
        data_logger.create_histogram('metric/batch_value_dist',
                                     self._metric_batch_value_dist, step)

        if full_log:
            dist = self._display_dist
            z_atoms = self.critic.z_atoms
            z_delta = self.critic.z_delta
            data_logger.add_histogram('dist/dist_value',
                                      min=z_atoms[0],
                                      max=z_atoms[-1],
                                      num=self.num_atoms,
                                      sum=dist.sum(),
                                      sum_squares=dist.pow(2).sum(),
                                      bucket_limits=z_atoms + z_delta,
                                      bucket_counts=dist,
                                      global_step=step)

    def get_state(self):
        return dict(
            actor=self.actor.state_dict(),
            target_actor=self.target_actor.state_dict(),
            critic=self.critic.state_dict(),
            target_critic=self.target_critic.state_dict(),
            config=self._config,
        )

    def save_state(self, path: str):
        agent_state = self.get_state()
        torch.save(agent_state, path)

    def load_state(self, path: str):
        agent_state = torch.load(path)
        self._config = agent_state.get('config', {})
        self.__dict__.update(**self._config)

        self.actor.load_state_dict(agent_state['actor'])
        self.critic.load_state_dict(agent_state['critic'])
        self.target_actor.load_state_dict(agent_state['target_actor'])
        self.target_critic.load_state_dict(agent_state['target_critic'])
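CategoricalNet.dist_projection is called in compute_value_loss above but its implementation is not shown. Below is a minimal sketch of the standard categorical (C51-style) projection it presumably performs, assuming rewards and masks of shape (batch, 1), next_dist holding per-sample probabilities over num_atoms atoms, and an evenly spaced support between v_min and v_max; the function name and signature are illustrative.

import torch

def project_distribution(rewards, masks, discount, next_dist, v_min, v_max):
    batch_size, num_atoms = next_dist.shape
    delta_z = (v_max - v_min) / (num_atoms - 1)
    z_atoms = torch.linspace(v_min, v_max, num_atoms, device=next_dist.device)

    # Bellman backup of every atom, clipped to the support [v_min, v_max].
    tz = (rewards + discount * masks * z_atoms.unsqueeze(0)).clamp(v_min, v_max)
    b = (tz - v_min) / delta_z                   # Fractional position on the support.
    lower, upper = b.floor().long(), b.ceil().long()
    # When b lands exactly on an atom, nudge the indices so no mass is dropped.
    lower[(upper > 0) & (lower == upper)] -= 1
    upper[(lower < num_atoms - 1) & (lower == upper)] += 1

    # Split each atom's probability mass between its two nearest support points.
    projected = torch.zeros_like(next_dist)
    offset = (torch.arange(batch_size, device=next_dist.device) * num_atoms).unsqueeze(1)
    projected.view(-1).index_add_(0, (lower + offset).view(-1),
                                  (next_dist * (upper.float() - b)).view(-1))
    projected.view(-1).index_add_(0, (upper + offset).view(-1),
                                  (next_dist * (b - lower.float())).view(-1))
    return projected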
Example 3
class DDPGAgent(AgentBase):
    """
    Deep Deterministic Policy Gradients (DDPG).

    Instead of the popular Ornstein-Uhlenbeck (OU) noise process, this agent uses Gaussian noise.
    """

    name = "DDPG"

    def __init__(self,
                 state_size: int,
                 action_size: int,
                 actor_lr: float = 2e-3,
                 critic_lr: float = 2e-3,
                 noise_scale: float = 0.2,
                 noise_sigma: float = 0.1,
                 **kwargs):
        super().__init__(**kwargs)
        self.device = self._register_param(kwargs, "device", DEVICE)
        self.state_size = state_size
        self.action_size = action_size

        # Reason sequence initiation.
        hidden_layers = to_numbers_seq(
            self._register_param(kwargs, 'hidden_layers', (128, 128)))
        self.actor = ActorBody(state_size,
                               action_size,
                               hidden_layers=hidden_layers,
                               gate_out=torch.tanh).to(self.device)
        self.critic = CriticBody(state_size,
                                 action_size,
                                 hidden_layers=hidden_layers).to(self.device)
        self.target_actor = ActorBody(state_size,
                                      action_size,
                                      hidden_layers=hidden_layers,
                                      gate_out=torch.tanh).to(self.device)
        self.target_critic = CriticBody(state_size,
                                        action_size,
                                        hidden_layers=hidden_layers).to(
                                            self.device)

        # Noise sequence initiation
        self.noise = GaussianNoise(shape=(action_size, ),
                                   mu=1e-8,
                                   sigma=noise_sigma,
                                   scale=noise_scale,
                                   device=self.device)

        # Target sequence initiation
        hard_update(self.target_actor, self.actor)
        hard_update(self.target_critic, self.critic)

        # Optimization sequence initiation.
        self.actor_lr = float(
            self._register_param(kwargs, 'actor_lr', actor_lr))
        self.critic_lr = float(
            self._register_param(kwargs, 'critic_lr', critic_lr))
        self.actor_optimizer = Adam(self.actor.parameters(), lr=self.actor_lr)
        self.critic_optimizer = Adam(self.critic.parameters(),
                                     lr=self.critic_lr)
        self.max_grad_norm_actor = float(
            self._register_param(kwargs, "max_grad_norm_actor", 10.0))
        self.max_grad_norm_critic = float(
            self._register_param(kwargs, "max_grad_norm_critic", 10.0))
        self.action_min = float(self._register_param(kwargs, 'action_min', -1))
        self.action_max = float(self._register_param(kwargs, 'action_max', 1))
        self.action_scale = float(
            self._register_param(kwargs, 'action_scale', 1))

        self.gamma = float(self._register_param(kwargs, 'gamma', 0.99))
        self.tau = float(self._register_param(kwargs, 'tau', 0.02))
        self.batch_size = int(self._register_param(kwargs, 'batch_size', 64))
        self.buffer_size = int(
            self._register_param(kwargs, 'buffer_size', int(1e6)))
        self.buffer = ReplayBuffer(self.batch_size, self.buffer_size)

        self.warm_up = int(self._register_param(kwargs, 'warm_up', 0))
        self.update_freq = int(self._register_param(kwargs, 'update_freq', 1))
        self.number_updates = int(
            self._register_param(kwargs, 'number_updates', 1))

        # Breath, my child.
        self.reset_agent()
        self.iteration = 0
        self._loss_actor = 0.
        self._loss_critic = 0.

    def reset_agent(self) -> None:
        self.actor.reset_parameters()
        self.critic.reset_parameters()
        self.target_actor.reset_parameters()
        self.target_critic.reset_parameters()

    @property
    def loss(self) -> Dict[str, float]:
        return {'actor': self._loss_actor, 'critic': self._loss_critic}

    @loss.setter
    def loss(self, value):
        if isinstance(value, dict):
            self._loss_actor = value['actor']
            self._loss_critic = value['critic']
        else:
            self._loss_actor = value
            self._loss_critic = value

    @torch.no_grad()
    def act(self, obs, noise: float = 0.0) -> List[float]:
        """Acting on the observations. Returns action.

        Returns:
            action: (list float) Action values.
        """
        obs = to_tensor(obs).float().to(self.device)
        action = self.actor(obs)
        action += noise * self.noise.sample()
        action = torch.clamp(action * self.action_scale, self.action_min,
                             self.action_max)
        return action.cpu().numpy().tolist()

    def step(self, state, action, reward, next_state, done) -> None:
        self.iteration += 1
        self.buffer.add(state=state,
                        action=action,
                        reward=reward,
                        next_state=next_state,
                        done=done)

        if self.iteration < self.warm_up:
            return

        if len(self.buffer) > self.batch_size and (self.iteration %
                                                   self.update_freq) == 0:
            for _ in range(self.number_updates):
                self.learn(self.buffer.sample())

    def compute_value_loss(self, states, actions, next_states, rewards, dones):
        next_actions = self.target_actor.act(next_states)
        assert next_actions.shape == actions.shape
        Q_target_next = self.target_critic.act(next_states, next_actions)
        Q_target = rewards + self.gamma * Q_target_next * (1 - dones)
        Q_expected = self.critic(states, actions)
        assert Q_expected.shape == Q_target.shape == Q_target_next.shape
        return mse_loss(Q_expected, Q_target)

    def compute_policy_loss(self, states) -> torch.Tensor:
        """Compute Policy loss based on provided states.

        Loss = Mean(-Q(s, _a) ),
        where _a is actor's estimate based on state, _a = Actor(s).
        """
        pred_actions = self.actor(states)
        return -self.critic(states, pred_actions).mean()

    def learn(self, experiences) -> None:
        """Update critics and actors"""
        rewards = to_tensor(experiences['reward']).float().to(
            self.device).unsqueeze(1)
        dones = to_tensor(experiences['done']).type(torch.int).to(
            self.device).unsqueeze(1)
        states = to_tensor(experiences['state']).float().to(self.device)
        actions = to_tensor(experiences['action']).to(self.device)
        next_states = to_tensor(experiences['next_state']).float().to(
            self.device)
        assert rewards.shape == dones.shape == (self.batch_size, 1)
        assert states.shape == next_states.shape == (self.batch_size,
                                                     self.state_size)
        assert actions.shape == (self.batch_size, self.action_size)

        # Value (critic) optimization
        loss_critic = self.compute_value_loss(states, actions, next_states,
                                              rewards, dones)
        self.critic_optimizer.zero_grad()
        loss_critic.backward()
        nn.utils.clip_grad_norm_(self.critic.parameters(),
                                 self.max_grad_norm_critic)
        self.critic_optimizer.step()
        self._loss_critic = float(loss_critic.item())

        # Policy (actor) optimization
        loss_actor = self.compute_policy_loss(states)
        self.actor_optimizer.zero_grad()
        loss_actor.backward()
        nn.utils.clip_grad_norm_(self.actor.parameters(),
                                 self.max_grad_norm_actor)
        self.actor_optimizer.step()
        self._loss_actor = loss_actor.item()

        # Soft update target weights
        soft_update(self.target_actor, self.actor, self.tau)
        soft_update(self.target_critic, self.critic, self.tau)

    def state_dict(self) -> Dict[str, dict]:
        """Describes agent's networks.

        Returns:
            state: (dict) Provides actors and critics states.

        """
        return {
            "actor": self.actor.state_dict(),
            "target_actor": self.target_actor.state_dict(),
            "critic": self.critic.state_dict(),
            "target_critic": self.target_critic.state_dict()
        }

    def log_metrics(self,
                    data_logger: DataLogger,
                    step: int,
                    full_log: bool = False):
        data_logger.log_value("loss/actor", self._loss_actor, step)
        data_logger.log_value("loss/critic", self._loss_critic, step)

        if full_log:
            for idx, layer in enumerate(self.actor.layers):
                if hasattr(layer, "weight"):
                    data_logger.create_histogram(f"actor/layer_weights_{idx}",
                                                 layer.weight, step)
                if hasattr(layer, "bias") and layer.bias is not None:
                    data_logger.create_histogram(f"actor/layer_bias_{idx}",
                                                 layer.bias, step)

            for idx, layer in enumerate(self.critic.layers):
                if hasattr(layer, "weight"):
                    data_logger.create_histogram(f"critic/layer_weights_{idx}",
                                                 layer.weight, step)
                if hasattr(layer, "bias") and layer.bias is not None:
                    data_logger.create_histogram(f"critic/layer_bias_{idx}",
                                                 layer.bias, step)

    def get_state(self) -> AgentState:
        net = dict(
            actor=self.actor.state_dict(),
            target_actor=self.target_actor.state_dict(),
            critic=self.critic.state_dict(),
            target_critic=self.target_critic.state_dict(),
        )
        network_state: NetworkState = NetworkState(net=net)
        return AgentState(model=self.name,
                          state_space=self.state_size,
                          action_space=self.action_size,
                          config=self._config,
                          buffer=self.buffer.get_state(),
                          network=network_state)

    def save_state(self, path: str) -> None:
        agent_state = self.get_state()
        torch.save(agent_state, path)

    def load_state(self,
                   *,
                   path: Optional[str] = None,
                   agent_state: Optional[dict] = None):
        if path is None and agent_state is None:
            raise ValueError(
                "Either `path` or `agent_state` must be provided to load agent's state."
            )
        if path is not None and agent_state is None:
            agent_state = torch.load(path)
        self._config = agent_state.get('config', {})
        self.__dict__.update(**self._config)

        self.actor.load_state_dict(agent_state['actor'])
        self.critic.load_state_dict(agent_state['critic'])
        self.target_actor.load_state_dict(agent_state['target_actor'])
        self.target_critic.load_state_dict(agent_state['target_critic'])
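The helpers hard_update and soft_update are imported from elsewhere in the package and are not shown in these examples. The sketch below captures the Polyak-averaging behaviour they are assumed to implement, presuming both modules expose their parameters in the same order.

import torch.nn as nn

def soft_update(target: nn.Module, source: nn.Module, tau: float) -> None:
    # Polyak averaging: target <- tau * source + (1 - tau) * target.
    for target_param, source_param in zip(target.parameters(), source.parameters()):
        target_param.data.copy_(tau * source_param.data +
                                (1.0 - tau) * target_param.data)

def hard_update(target: nn.Module, source: nn.Module) -> None:
    # Exact copy of the source weights into the target network.
    soft_update(target, source, tau=1.0)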
Example 4
class TD3Agent(AgentBase):
    """
    Twin Delayed Deep Deterministic (TD3) Policy Gradient.

    In short, it is a slightly modified/improved version of DDPG. Compared to the DDPG in this package,
    which uses Gaussian noise, this TD3 uses an Ornstein–Uhlenbeck process as the noise.
    """

    name = "TD3"

    def __init__(self,
                 state_size: int,
                 action_size: int,
                 noise_scale: float = 0.2,
                 noise_sigma: float = 0.1,
                 **kwargs):
        """
        Parameters:
            state_size (int): Number of input dimensions.
            action_size (int): Number of output dimensions
            noise_scale (float): Added noise amplitude. Default: 0.2.
            noise_sigma (float): Added noise variance. Default: 0.1.

        Keyword parameters:
            hidden_layers (tuple of ints): Tuple defining hidden dimensions in fully connected nets. Default: (128, 128).
            actor_lr (float): Learning rate for the actor (policy). Default: 0.003.
            critic_lr (float): Learning rate for the critic (value function). Default: 0.003.
            gamma (float): Discount value. Default: 0.99.
            tau (float): Soft-copy factor. Default: 0.02.
            actor_hidden_layers (tuple of ints): Shape of network for actor. Default: `hidden_layers`.
            critic_hidden_layers (tuple of ints): Shape of network for critic. Default: `hidden_layers`.
            max_grad_norm_actor (float): Maximum norm value for actor gradient. Default: 100.
            max_grad_norm_critic (float): Maximum norm value for critic gradient. Default: 100.
            batch_size (int): Number of samples used in learning. Default: 64.
            buffer_size (int): Maximum number of samples to store. Default: 1e6.
            warm_up (int): Number of samples to observe before starting any learning step. Default: 0.
            update_freq (int): Number of steps between each learning step. Default: 1.
            number_updates (int): How many times to use learning step in the learning phase. Default: 1.
            action_min (float): Minimum returned action value. Default: -1.
            action_max (float): Maximum returned action value. Default: 1.
            action_scale (float): Multiplier value for action. Default: 1.

        """
        super().__init__(**kwargs)
        self.device = self._register_param(
            kwargs, "device", DEVICE)  # Default device is CUDA if available

        # Reason sequence initiation.
        self.state_size = state_size
        self.action_size = action_size

        hidden_layers = to_numbers_seq(
            self._register_param(kwargs, 'hidden_layers', (128, 128)))
        self.actor = ActorBody(state_size,
                               action_size,
                               hidden_layers=hidden_layers).to(self.device)
        self.critic = DoubleCritic(state_size,
                                   action_size,
                                   CriticBody,
                                   hidden_layers=hidden_layers).to(self.device)
        self.target_actor = ActorBody(state_size,
                                      action_size,
                                      hidden_layers=hidden_layers).to(
                                          self.device)
        self.target_critic = DoubleCritic(state_size,
                                          action_size,
                                          CriticBody,
                                          hidden_layers=hidden_layers).to(
                                              self.device)

        # Noise sequence initiation
        # self.noise = GaussianNoise(shape=(action_size,), mu=1e-8, sigma=noise_sigma, scale=noise_scale, device=device)
        self.noise = OUProcess(shape=action_size,
                               scale=noise_scale,
                               sigma=noise_sigma,
                               device=self.device)

        # Target sequence initiation
        hard_update(self.target_actor, self.actor)
        hard_update(self.target_critic, self.critic)

        # Optimization sequence initiation.
        actor_lr = float(self._register_param(kwargs, 'actor_lr', 3e-3))
        critic_lr = float(self._register_param(kwargs, 'critic_lr', 3e-3))
        self.actor_optimizer = AdamW(self.actor.parameters(), lr=actor_lr)
        self.critic_optimizer = AdamW(self.critic.parameters(), lr=critic_lr)
        self.max_grad_norm_actor: float = float(
            kwargs.get("max_grad_norm_actor", 100))
        self.max_grad_norm_critic: float = float(
            kwargs.get("max_grad_norm_critic", 100))
        self.action_min = float(self._register_param(kwargs, 'action_min',
                                                     -1.))
        self.action_max = float(self._register_param(kwargs, 'action_max', 1.))
        self.action_scale = float(
            self._register_param(kwargs, 'action_scale', 1.))

        self.gamma = float(self._register_param(kwargs, 'gamma', 0.99))
        self.tau = float(self._register_param(kwargs, 'tau', 0.02))
        self.batch_size = int(self._register_param(kwargs, 'batch_size', 64))
        self.buffer_size = int(
            self._register_param(kwargs, 'buffer_size', int(1e6)))
        self.buffer = ReplayBuffer(self.batch_size, self.buffer_size)

        self.warm_up = int(self._register_param(kwargs, 'warm_up', 0))
        self.update_freq = int(self._register_param(kwargs, 'update_freq', 1))
        self.update_policy_freq = int(
            self._register_param(kwargs, 'update_policy_freq', 1))
        self.number_updates = int(
            self._register_param(kwargs, 'number_updates', 1))
        self.noise_reset_freq = int(
            self._register_param(kwargs, 'noise_reset_freq', 10000))

        # Breath, my child.
        self.reset_agent()
        self.iteration = 0
        self._loss_actor = 0.
        self._loss_critic = 0.

    @property
    def loss(self) -> Dict[str, float]:
        return {'actor': self._loss_actor, 'critic': self._loss_critic}

    @loss.setter
    def loss(self, value):
        if isinstance(value, dict):
            self._loss_actor = value['actor']
            self._loss_critic = value['critic']
        else:
            self._loss_actor = value
            self._loss_critic = value

    def reset_agent(self) -> None:
        self.actor.reset_parameters()
        self.critic.reset_parameters()
        self.target_actor.reset_parameters()
        self.target_critic.reset_parameters()

    def act(self,
            state,
            epsilon: float = 0.0,
            training_mode=True) -> List[float]:
        """
        Agent acting on observations.

        When training_mode is True (default), noise is added to each action.
        """
        # Epsilon greedy
        if self._rng.random() < epsilon:
            rnd_actions = torch.rand(self.action_size) * (
                self.action_max - self.action_min) + self.action_min
            return rnd_actions.tolist()

        with torch.no_grad():
            state = to_tensor(state).float().to(self.device)
            action = self.actor(state)
            if training_mode:
                action += self.noise.sample()
            return (self.action_scale * torch.clamp(action, self.action_min,
                                                    self.action_max)).tolist()

    def target_act(self, staten, noise: float = 0.0):
        with torch.no_grad():
            staten = to_tensor(staten).float().to(self.device)
            action = self.target_actor(staten) + noise * self.noise.sample()
            return torch.clamp(action, self.action_min,
                               self.action_max).cpu().numpy().astype(
                                   np.float32)

    def step(self, state, action, reward, next_state, done):
        self.iteration += 1
        self.buffer.add(state=state,
                        action=action,
                        reward=reward,
                        next_state=next_state,
                        done=done)

        if (self.iteration % self.noise_reset_freq) == 0:
            self.noise.reset_states()

        if self.iteration < self.warm_up:
            return

        if len(self.buffer) <= self.batch_size:
            return

        if not (self.iteration % self.update_freq) or not (
                self.iteration % self.update_policy_freq):
            for _ in range(self.number_updates):
                # Note: Inside this there's a delayed policy update.
                #       Every `update_policy_freq` it will learn `number_updates` times.
                self.learn(self.buffer.sample())

    def learn(self, experiences):
        """Update critics and actors"""
        rewards = to_tensor(experiences['reward']).float().to(
            self.device).unsqueeze(1)
        dones = to_tensor(experiences['done']).type(torch.int).to(
            self.device).unsqueeze(1)
        states = to_tensor(experiences['state']).float().to(self.device)
        actions = to_tensor(experiences['action']).to(self.device)
        next_states = to_tensor(experiences['next_state']).float().to(
            self.device)

        if (self.iteration % self.update_freq) == 0:
            self._update_value_function(states, actions, rewards, next_states,
                                        dones)

        if (self.iteration % self.update_policy_freq) == 0:
            self._update_policy(states)

            soft_update(self.target_actor, self.actor, self.tau)
            soft_update(self.target_critic, self.critic, self.tau)

    def _update_value_function(self, states, actions, rewards, next_states,
                               dones):
        # critic loss
        next_actions = self.target_actor.act(next_states)
        Q_target_next = torch.min(
            *self.target_critic.act(next_states, next_actions))
        Q_target = rewards + (self.gamma * Q_target_next * (1 - dones))
        Q1_expected, Q2_expected = self.critic(states, actions)
        loss_critic = mse_loss(Q1_expected, Q_target) + mse_loss(
            Q2_expected, Q_target)

        # Minimize the loss
        self.critic_optimizer.zero_grad()
        loss_critic.backward()
        nn.utils.clip_grad_norm_(self.critic.parameters(),
                                 self.max_grad_norm_critic)
        self.critic_optimizer.step()
        self._loss_critic = float(loss_critic.item())

    def _update_policy(self, states):
        # Compute actor loss
        pred_actions = self.actor(states)
        loss_actor = -self.critic(states, pred_actions)[0].mean()
        self.actor_optimizer.zero_grad()
        loss_actor.backward()
        nn.utils.clip_grad_norm_(self.actor.parameters(),
                                 self.max_grad_norm_actor)
        self.actor_optimizer.step()
        self._loss_actor = loss_actor.item()

    def state_dict(self) -> Dict[str, dict]:
        """Describes agent's networks.

        Returns:
            state: (dict) Provides actors and critics states.

        """
        return {
            "actor": self.actor.state_dict(),
            "target_actor": self.target_actor.state_dict(),
            "critic": self.critic.state_dict(),
            "target_critic": self.target_critic()
        }

    def log_metrics(self,
                    data_logger: DataLogger,
                    step: int,
                    full_log: bool = False):
        data_logger.log_value("loss/actor", self._loss_actor, step)
        data_logger.log_value("loss/critic", self._loss_critic, step)

    def get_state(self):
        return dict(
            actor=self.actor.state_dict(),
            target_actor=self.target_actor.state_dict(),
            critic=self.critic.state_dict(),
            target_critic=self.target_critic.state_dict(),
            config=self._config,
        )

    def save_state(self, path: str):
        agent_state = self.get_state()
        torch.save(agent_state, path)

    def load_state(self, path: str):
        agent_state = torch.load(path)
        self._config = agent_state.get('config', {})
        self.__dict__.update(**self._config)

        self.actor.load_state_dict(agent_state['actor'])
        self.critic.load_state_dict(agent_state['critic'])
        self.target_actor.load_state_dict(agent_state['target_actor'])
        self.target_critic.load_state_dict(agent_state['target_critic'])
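

# The `soft_update` and `hard_update` helpers used by the agents in this listing are not
# shown here. A minimal sketch of a Polyak-style soft update, assuming the usual
# soft_update(target, source, tau) signature (illustrative, not the library's verbatim
# implementation):
import torch
import torch.nn as nn


def polyak_soft_update(target: nn.Module, source: nn.Module, tau: float) -> None:
    """Blend source parameters into target: theta_target <- tau * theta_source + (1 - tau) * theta_target."""
    with torch.no_grad():
        for target_param, source_param in zip(target.parameters(), source.parameters()):
            target_param.data.mul_(1.0 - tau).add_(tau * source_param.data)
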
Example #5
class PPOAgent(AgentBase):
    """
    Proximal Policy Optimization (PPO) [1] is an on-policy policy gradient method
    that can be considered an implementation-wise simplified version of
    Trust Region Policy Optimization (TRPO).

    [1] "Proximal Policy Optimization Algorithms" (2017) by J. Schulman, F. Wolski,
        P. Dhariwal, A. Radford, O. Klimov. https://arxiv.org/abs/1707.06347
    """

    name = "PPO"

    def __init__(self, state_size: int, action_size: int, **kwargs):
        """
        Parameters:
            state_size: Number of input dimensions.
            action_size: Number of output dimensions
            hidden_layers: (default: (100, 100)) Tuple defining hidden dimensions in fully connected nets.
            is_discrete: (default: False) Whether to return discrete actions.
            using_kl_div: (default: False) Whether to add a KL divergence penalty to the loss.
            kl_beta: (default: 0.1) Initial weight of the KL divergence penalty.
            target_kl: (default: 0.01) Target KL divergence used to adapt `kl_beta`.
            using_gae: (default: True) Whether to use the Generalized Advantage Estimator.
            gae_lambda: (default: 0.96) Value of lambda in GAE.
            actor_lr: (default: 0.0003) Learning rate for the actor (policy).
            critic_lr: (default: 0.001) Learning rate for the critic (value function).
            actor_betas: (default: (0.9, 0.999)) Adam's betas for actor optimizer.
            critic_betas: (default: (0.9, 0.999)) Adam's betas for critic optimizer.
            gamma: (default: 0.99) Discount value.
            ppo_ratio_clip: (default: 0.25) Policy ratio clipping value.
            num_epochs: (default: 1) Number of times to learn from samples.
            rollout_length: (default: 48) Number of actions to take before update.
            batch_size: (default: rollout_length) Number of samples used in learning.
            actor_number_updates: (default: 10) Number of times policy losses are propagated.
            critic_number_updates: (default: 10) Number of times value losses are propagated.
            entropy_weight: (default: 0.5) Weight of the entropy term in the loss.
            value_loss_weight: (default: 1.0) Weight of the value loss term in the loss.
            num_workers: (default: 1) Number of parallel workers/environments.

        """
        super().__init__(**kwargs)

        self.device = self._register_param(
            kwargs, "device", DEVICE)  # Default device is CUDA if available

        self.state_size = state_size
        self.action_size = action_size
        self.hidden_layers = to_numbers_seq(
            self._register_param(kwargs, "hidden_layers", (100, 100)))
        self.iteration = 0

        self.is_discrete = bool(
            self._register_param(kwargs, "is_discrete", False))
        self.using_gae = bool(self._register_param(kwargs, "using_gae", True))
        self.gae_lambda = float(
            self._register_param(kwargs, "gae_lambda", 0.96))

        self.actor_lr = float(self._register_param(kwargs, 'actor_lr', 3e-4))
        self.actor_betas: Tuple[float, float] = to_numbers_seq(
            self._register_param(kwargs, 'actor_betas', (0.9, 0.999)))
        self.critic_lr = float(self._register_param(kwargs, 'critic_lr', 1e-3))
        self.critic_betas: Tuple[float, float] = to_numbers_seq(
            self._register_param(kwargs, 'critic_betas', (0.9, 0.999)))
        self.gamma = float(self._register_param(kwargs, "gamma", 0.99))
        self.ppo_ratio_clip = float(
            self._register_param(kwargs, "ppo_ratio_clip", 0.25))

        self.using_kl_div = bool(
            self._register_param(kwargs, "using_kl_div", False))
        self.kl_beta = float(self._register_param(kwargs, 'kl_beta', 0.1))
        self.target_kl = float(self._register_param(kwargs, "target_kl", 0.01))
        self.kl_div = float('inf')

        self.num_workers = int(self._register_param(kwargs, "num_workers", 1))
        self.num_epochs = int(self._register_param(kwargs, "num_epochs", 1))
        self.rollout_length = int(
            self._register_param(kwargs, "rollout_length",
                                 48))  # "Much less than the episode length"
        self.batch_size = int(
            self._register_param(kwargs, "batch_size", self.rollout_length))
        self.actor_number_updates = int(
            self._register_param(kwargs, "actor_number_updates", 10))
        self.critic_number_updates = int(
            self._register_param(kwargs, "critic_number_updates", 10))
        self.entropy_weight = float(
            self._register_param(kwargs, "entropy_weight", 0.5))
        self.value_loss_weight = float(
            self._register_param(kwargs, "value_loss_weight", 1.0))

        self.local_memory_buffer = {}

        self.action_scale = float(
            self._register_param(kwargs, "action_scale", 1))
        self.action_min = float(self._register_param(kwargs, "action_min", -1))
        self.action_max = float(self._register_param(kwargs, "action_max", 1))
        self.max_grad_norm_actor = float(
            self._register_param(kwargs, "max_grad_norm_actor", 100.0))
        self.max_grad_norm_critic = float(
            self._register_param(kwargs, "max_grad_norm_critic", 100.0))

        if kwargs.get("simple_policy", False):
            self.policy = MultivariateGaussianPolicySimple(
                self.action_size, **kwargs)
        else:
            self.policy = MultivariateGaussianPolicy(self.action_size,
                                                     device=self.device)

        self.buffer = RolloutBuffer(batch_size=self.batch_size,
                                    buffer_size=self.rollout_length)
        self.actor = ActorBody(state_size,
                               self.policy.param_dim * action_size,
                               gate_out=torch.tanh,
                               hidden_layers=self.hidden_layers,
                               device=self.device)
        self.critic = ActorBody(state_size,
                                1,
                                gate_out=None,
                                hidden_layers=self.hidden_layers,
                                device=self.device)
        self.actor_params = list(self.actor.parameters()) + list(
            self.policy.parameters())
        self.critic_params = list(self.critic.parameters())

        self.actor_opt = optim.Adam(self.actor_params,
                                    lr=self.actor_lr,
                                    betas=self.actor_betas)
        self.critic_opt = optim.Adam(self.critic_params,
                                     lr=self.critic_lr,
                                     betas=self.critic_betas)
        self._loss_actor = float('nan')
        self._loss_critic = float('nan')
        self._metrics: Dict[str, float] = {}

    @property
    def loss(self) -> Dict[str, float]:
        return {'actor': self._loss_actor, 'critic': self._loss_critic}

    @loss.setter
    def loss(self, value):
        if isinstance(value, dict):
            self._loss_actor = value['actor']
            self._loss_critic = value['critic']
        else:
            self._loss_actor = value
            self._loss_critic = value

    def __eq__(self, o: object) -> bool:
        return super().__eq__(o) \
            and self._config == o._config \
            and self.buffer == o.buffer \
            and self.get_network_state() == o.get_network_state()  # TODO @dawid: Currently net isn't compared properly

    def __clear_memory(self):
        self.buffer.clear()

    @torch.no_grad()
    def act(self, state, epsilon: float = 0.):
        actions = []
        logprobs = []
        values = []
        state = to_tensor(state).view(self.num_workers,
                                      self.state_size).float().to(self.device)
        for worker in range(self.num_workers):
            actor_est = self.actor.act(state[worker].unsqueeze(0))
            assert not torch.any(torch.isnan(actor_est))

            dist = self.policy(actor_est)
            action = dist.sample()
            value = self.critic.act(
                state[worker].unsqueeze(0))  # Shape: (1, 1)
            logprob = self.policy.log_prob(dist, action)  # Shape: (1,)
            values.append(value)
            logprobs.append(logprob)

            if self.is_discrete:  # *Technically* it's the max of Softmax but that's monotonic.
                action = int(torch.argmax(action))
            else:
                action = torch.clamp(action * self.action_scale,
                                     self.action_min, self.action_max)
                action = action.cpu().numpy().flatten().tolist()
            actions.append(action)

        self.local_memory_buffer['value'] = torch.cat(values)
        self.local_memory_buffer['logprob'] = torch.stack(logprobs)
        assert len(actions) == self.num_workers
        return actions if self.num_workers > 1 else actions[0]

    def step(self, state, action, reward, next_state, done, **kwargs):
        self.iteration += 1

        self.buffer.add(
            state=torch.tensor(state).reshape(self.num_workers,
                                              self.state_size).float(),
            action=torch.tensor(action).reshape(self.num_workers,
                                                self.action_size).float(),
            reward=torch.tensor(reward).reshape(self.num_workers, 1),
            done=torch.tensor(done).reshape(self.num_workers, 1),
            logprob=self.local_memory_buffer['logprob'].reshape(
                self.num_workers, 1),
            value=self.local_memory_buffer['value'].reshape(
                self.num_workers, 1),
        )

        if self.iteration % self.rollout_length == 0:
            self.train()
            self.__clear_memory()

    def train(self):
        """
        Main loop that initiates the training.
        """
        experiences = self.buffer.all_samples()
        rewards = to_tensor(experiences['reward']).to(self.device)
        dones = to_tensor(experiences['done']).type(torch.int).to(self.device)
        states = to_tensor(experiences['state']).to(self.device)
        actions = to_tensor(experiences['action']).to(self.device)
        values = to_tensor(experiences['value']).to(self.device)
        logprobs = to_tensor(experiences['logprob']).to(self.device)
        assert rewards.shape == dones.shape == values.shape == logprobs.shape
        assert states.shape == (
            self.rollout_length, self.num_workers,
            self.state_size), f"Wrong states shape: {states.shape}"
        assert actions.shape == (
            self.rollout_length, self.num_workers,
            self.action_size), f"Wrong action shape: {actions.shape}"

        with torch.no_grad():
            if self.using_gae:
                next_value = self.critic.act(states[-1])
                advantages = compute_gae(rewards, dones, values, next_value,
                                         self.gamma, self.gae_lambda)
                advantages = normalize(advantages)
                returns = advantages + values
                # returns = normalize(advantages + values)
                assert advantages.shape == returns.shape == values.shape
            else:
                returns = revert_norm_returns(rewards, dones, self.gamma)
                returns = returns.float()
                advantages = normalize(returns - values)
                assert advantages.shape == returns.shape == values.shape

        for _ in range(self.num_epochs):
            idx = 0
            self.kl_div = 0
            while idx < self.rollout_length:
                _states = states[idx:idx + self.batch_size].view(
                    -1, self.state_size).detach()
                _actions = actions[idx:idx + self.batch_size].view(
                    -1, self.action_size).detach()
                _logprobs = logprobs[idx:idx + self.batch_size].view(
                    -1, 1).detach()
                _returns = returns[idx:idx + self.batch_size].view(-1,
                                                                   1).detach()
                _advantages = advantages[idx:idx + self.batch_size].view(
                    -1, 1).detach()
                idx += self.batch_size
                self.learn(
                    (_states, _actions, _logprobs, _returns, _advantages))

            self.kl_div = abs(
                self.kl_div) / (self.actor_number_updates *
                                self.rollout_length / self.batch_size)
            if self.kl_div > self.target_kl * 1.75:
                self.kl_beta = min(2 * self.kl_beta, 1e2)  # Max 100
            if self.kl_div < self.target_kl / 1.75:
                self.kl_beta = max(0.5 * self.kl_beta, 1e-6)  # Min 0.000001
            self._metrics['policy/kl_beta'] = self.kl_beta
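            # Worked example of the adaptive KL rule above (illustrative numbers): with
            # target_kl=0.01, a measured KL of 0.02 (> 0.01 * 1.75) doubles kl_beta, while
            # a measured KL of 0.005 (< 0.01 / 1.75 ~= 0.0057) halves it.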

    def compute_policy_loss(self, samples):
        states, actions, old_log_probs, _, advantages = samples

        actor_est = self.actor(states)
        dist = self.policy(actor_est)

        entropy = dist.entropy()
        new_log_probs = self.policy.log_prob(dist, actions).view(-1, 1)
        assert new_log_probs.shape == old_log_probs.shape

        r_theta = (new_log_probs - old_log_probs).exp()
        r_theta_clip = torch.clamp(r_theta, 1.0 - self.ppo_ratio_clip,
                                   1.0 + self.ppo_ratio_clip)
        assert r_theta.shape == r_theta_clip.shape

        # KL = E[log(P/Q)] = sum_{P}( P * log(P/Q) ) -- \approx --> avg_{P}( log(P) - log(Q) )
        approx_kl_div = (old_log_probs - new_log_probs).mean()
        if self.using_kl_div:
            # Ratio threshold for the `kl_beta` adaptation is 1.75 (although it should be configurable).
            # Keep the KL estimate as a tensor here so the penalty term contributes to the gradient.
            policy_loss = -torch.mean(
                r_theta * advantages) + self.kl_beta * approx_kl_div
        else:
            joint_theta_adv = torch.stack(
                (r_theta * advantages, r_theta_clip * advantages))
            assert joint_theta_adv.shape[0] == 2
            policy_loss = -torch.amin(joint_theta_adv, dim=0).mean()
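            # Worked example of the clipping above (illustrative numbers): with ppo_ratio_clip=0.25,
            # advantage +1 and r_theta=1.6, the clipped term is min(1.6, 1.25) = 1.25, so the
            # minimum picks the clipped branch and the gradient cannot push the ratio further from 1.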
        entropy_loss = -self.entropy_weight * entropy.mean()

        loss = policy_loss + entropy_loss
        self._metrics['policy/kl_div'] = float(approx_kl_div)
        self._metrics['policy/policy_ratio'] = float(r_theta.mean())
        self._metrics['policy/policy_ratio_clip_mean'] = float(
            r_theta_clip.mean())
        return loss, float(approx_kl_div)

    def compute_value_loss(self, samples):
        states, _, _, returns, _ = samples
        values = self.critic(states)
        self._metrics['value/value_mean'] = float(values.mean())
        self._metrics['value/value_std'] = float(values.std())
        return F.mse_loss(values, returns)

    def learn(self, samples):
        self._loss_actor = 0.

        for _ in range(self.actor_number_updates):
            self.actor_opt.zero_grad()
            loss_actor, kl_div = self.compute_policy_loss(samples)
            self.kl_div += kl_div
            if kl_div > 1.5 * self.target_kl:
                # Early break
                # print(f"Iter: {i:02} Early break")
                break
            loss_actor.backward()
            nn.utils.clip_grad_norm_(self.actor_params,
                                     self.max_grad_norm_actor)
            self.actor_opt.step()
            self._loss_actor = loss_actor.item()

        for _ in range(self.critic_number_updates):
            self.critic_opt.zero_grad()
            loss_critic = self.compute_value_loss(samples)
            loss_critic.backward()
            nn.utils.clip_grad_norm_(self.critic_params,
                                     self.max_grad_norm_critic)
            self.critic_opt.step()
            self._loss_critic = float(loss_critic.item())

    def log_metrics(self,
                    data_logger: DataLogger,
                    step: int,
                    full_log: bool = False):
        data_logger.log_value("loss/actor", self._loss_actor, step)
        data_logger.log_value("loss/critic", self._loss_critic, step)
        for metric_name, metric_value in self._metrics.items():
            data_logger.log_value(metric_name, metric_value, step)

        policy_params = {
            str(i): v
            for i, v in enumerate(
                itertools.chain.from_iterable(self.policy.parameters()))
        }
        data_logger.log_values_dict("policy/param", policy_params, step)

        if full_log:
            for idx, layer in enumerate(self.actor.layers):
                if hasattr(layer, "weight"):
                    data_logger.create_histogram(f"actor/layer_weights_{idx}",
                                                 layer.weight, step)
                if hasattr(layer, "bias") and layer.bias is not None:
                    data_logger.create_histogram(f"actor/layer_bias_{idx}",
                                                 layer.bias, step)

            for idx, layer in enumerate(self.critic.layers):
                if hasattr(layer, "weight"):
                    data_logger.create_histogram(f"critic/layer_weights_{idx}",
                                                 layer.weight, step)
                if hasattr(layer, "bias") and layer.bias is not None:
                    data_logger.create_histogram(f"critic/layer_bias_{idx}",
                                                 layer.bias, step)

    def get_state(self) -> AgentState:
        return AgentState(model=self.name,
                          state_space=self.state_size,
                          action_space=self.action_size,
                          config=self._config,
                          buffer=copy.deepcopy(self.buffer.get_state()),
                          network=copy.deepcopy(self.get_network_state()))

    def get_network_state(self) -> NetworkState:
        return NetworkState(net=dict(
            policy=self.policy.state_dict(),
            actor=self.actor.state_dict(),
            critic=self.critic.state_dict(),
        ))

    def set_buffer(self, buffer_state: BufferState) -> None:
        self.buffer = BufferFactory.from_state(buffer_state)

    def set_network(self, network_state: NetworkState) -> None:
        self.policy.load_state_dict(network_state.net['policy'])
        self.actor.load_state_dict(network_state.net['actor'])
        self.critic.load_state_dict(network_state.net['critic'])

    @staticmethod
    def from_state(state: AgentState) -> AgentBase:
        agent = PPOAgent(state.state_space, state.action_space, **state.config)
        if state.network is not None:
            agent.set_network(state.network)
        if state.buffer is not None:
            agent.set_buffer(state.buffer)
        return agent

    def save_state(self, path: str):
        agent_state = self.get_state()
        torch.save(agent_state, path)

    def load_state(self, path: str):
        agent_state = torch.load(path)
        self._config = agent_state.get('config', {})
        self.__dict__.update(**self._config)

        self.policy.load_state_dict(agent_state['policy'])
        self.actor.load_state_dict(agent_state['actor'])
        self.critic.load_state_dict(agent_state['critic'])
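

# The `compute_gae`, `normalize` and `revert_norm_returns` helpers used by PPOAgent.train
# are not part of this listing. Below is a minimal, self-contained sketch of generalized
# advantage estimation under the usual recursion
#     delta_t = r_t + gamma * (1 - done_t) * V_{t+1} - V_t
#     A_t     = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}
# (illustrative only; the library's own compute_gae may differ in details):
import torch


def gae_sketch(rewards: torch.Tensor,
               dones: torch.Tensor,
               values: torch.Tensor,
               next_value: torch.Tensor,
               gamma: float = 0.99,
               lam: float = 0.96) -> torch.Tensor:
    """Return advantages with the same (rollout, workers, 1) shape as `rewards`."""
    advantages = torch.zeros_like(rewards)
    gae = torch.zeros_like(next_value)
    values = torch.cat((values, next_value.unsqueeze(0)), dim=0)
    for t in reversed(range(rewards.shape[0])):
        not_done = 1.0 - dones[t].float()
        delta = rewards[t] + gamma * values[t + 1] * not_done - values[t]
        gae = delta + gamma * lam * not_done * gae
        advantages[t] = gae
    return advantages
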
Example #6
class DDPGAgent(AgentBase):
    """
    Deep Deterministic Policy Gradients (DDPG).

    Instead of the popular Ornstein-Uhlenbeck (OU) process, this agent uses Gaussian noise for exploration.
    """

    name = "DDPG"

    def __init__(self,
                 state_size: int,
                 action_size: int,
                 noise_scale: float = 0.2,
                 noise_sigma: float = 0.1,
                 **kwargs):
        """
        Parameters:
            state_size: Number of input dimensions.
            action_size: Number of output dimensions
            noise_scale (float): Added noise amplitude. Default: 0.2.
            noise_sigma (float): Added noise standard deviation. Default: 0.1.

        Keyword parameters:
            hidden_layers (tuple of ints): Tuple defining hidden dimensions in fully connected nets. Default: (64, 64).
            gamma (float): Discount value. Default: 0.99.
            tau (float): Soft-copy factor. Default: 0.02.
            actor_lr (float): Learning rate for the actor (policy). Default: 0.0003.
            critic_lr (float): Learning rate for the critic (value function). Default: 0.0003.
            max_grad_norm_actor (float): Maximum norm value for actor gradient. Default: 10.
            max_grad_norm_critic (float): Maximum norm value for critic gradient. Default: 10.
            batch_size (int): Number of samples used in learning. Default: 64.
            buffer_size (int): Maximum number of samples to store. Default: 1e6.
            warm_up (int): Number of samples to observe before starting any learning step. Default: 0.
            update_freq (int): Number of steps between each learning step. Default 1.
            number_updates (int): How many times to use learning step in the learning phase. Default: 1.
            action_min (float): Minimum returned action value. Default: -1.
            action_max (float): Maximum returned action value. Default: 1.
            action_scale (float): Multiplier value for action. Default: 1.

        """
        super().__init__(**kwargs)
        self.device = self._register_param(kwargs, "device", DEVICE)
        self.state_size = state_size
        self.action_size = action_size

        # Reason sequence initiation.
        hidden_layers = to_numbers_seq(
            self._register_param(kwargs, 'hidden_layers', (64, 64)))
        self.actor = ActorBody(state_size,
                               action_size,
                               hidden_layers=hidden_layers,
                               gate_out=torch.tanh).to(self.device)
        self.critic = CriticBody(state_size,
                                 action_size,
                                 hidden_layers=hidden_layers).to(self.device)
        self.target_actor = ActorBody(state_size,
                                      action_size,
                                      hidden_layers=hidden_layers,
                                      gate_out=torch.tanh).to(self.device)
        self.target_critic = CriticBody(state_size,
                                        action_size,
                                        hidden_layers=hidden_layers).to(
                                            self.device)

        # Noise sequence initiation
        self.noise = GaussianNoise(shape=(action_size, ),
                                   mu=1e-8,
                                   sigma=noise_sigma,
                                   scale=noise_scale,
                                   device=self.device)

        # Target sequence initiation
        hard_update(self.target_actor, self.actor)
        hard_update(self.target_critic, self.critic)

        # Optimization sequence initiation.
        self.actor_lr = float(self._register_param(kwargs, 'actor_lr', 3e-4))
        self.critic_lr = float(self._register_param(kwargs, 'critic_lr', 3e-4))
        self.actor_optimizer = Adam(self.actor.parameters(), lr=self.actor_lr)
        self.critic_optimizer = Adam(self.critic.parameters(),
                                     lr=self.critic_lr)
        self.max_grad_norm_actor = float(
            self._register_param(kwargs, "max_grad_norm_actor", 10.0))
        self.max_grad_norm_critic = float(
            self._register_param(kwargs, "max_grad_norm_critic", 10.0))
        self.action_min = float(self._register_param(kwargs, 'action_min', -1))
        self.action_max = float(self._register_param(kwargs, 'action_max', 1))
        self.action_scale = float(
            self._register_param(kwargs, 'action_scale', 1))

        self.gamma = float(self._register_param(kwargs, 'gamma', 0.99))
        self.tau = float(self._register_param(kwargs, 'tau', 0.02))
        self.batch_size = int(self._register_param(kwargs, 'batch_size', 64))
        self.buffer_size = int(
            self._register_param(kwargs, 'buffer_size', int(1e6)))
        self.buffer = ReplayBuffer(self.batch_size, self.buffer_size)

        self.warm_up = int(self._register_param(kwargs, 'warm_up', 0))
        self.update_freq = int(self._register_param(kwargs, 'update_freq', 1))
        self.number_updates = int(
            self._register_param(kwargs, 'number_updates', 1))

        # Breathe, my child.
        self.reset_agent()
        self.iteration = 0
        self._loss_actor = 0.
        self._loss_critic = 0.

    def reset_agent(self) -> None:
        self.actor.reset_parameters()
        self.critic.reset_parameters()
        self.target_actor.reset_parameters()
        self.target_critic.reset_parameters()

    @property
    def loss(self) -> Dict[str, float]:
        return {'actor': self._loss_actor, 'critic': self._loss_critic}

    @loss.setter
    def loss(self, value):
        if isinstance(value, dict):
            self._loss_actor = value['actor']
            self._loss_critic = value['critic']
        else:
            self._loss_actor = value
            self._loss_critic = value

    def __eq__(self, o: object) -> bool:
        return super().__eq__(o) \
            and self._config == o._config \
            and self.buffer == o.buffer \
            and self.get_network_state() == o.get_network_state()

    @torch.no_grad()
    def act(self, obs, noise: float = 0.0) -> List[float]:
        """Acting on the observations. Returns action.

        Returns:
            action: (list float) Action values.
        """
        obs = to_tensor(obs).float().to(self.device)
        action = self.actor(obs)
        action += noise * self.noise.sample()
        action = torch.clamp(action * self.action_scale, self.action_min,
                             self.action_max)
        return action.cpu().numpy().tolist()

    def step(self, state, action, reward, next_state, done) -> None:
        self.iteration += 1
        self.buffer.add(state=state,
                        action=action,
                        reward=reward,
                        next_state=next_state,
                        done=done)

        if self.iteration < self.warm_up:
            return

        if len(self.buffer) > self.batch_size and (self.iteration %
                                                   self.update_freq) == 0:
            for _ in range(self.number_updates):
                self.learn(self.buffer.sample())

    def compute_value_loss(self, states, actions, next_states, rewards, dones):
        next_actions = self.target_actor.act(next_states)
        assert next_actions.shape == actions.shape
        Q_target_next = self.target_critic.act(next_states, next_actions)
        Q_target = rewards + self.gamma * Q_target_next * (1 - dones)
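        # Worked example (illustrative numbers): with gamma=0.99, reward 1.0, Q_target_next 2.0
        # and done=0, Q_target = 1.0 + 0.99 * 2.0 = 2.98; with done=1 the bootstrap term drops
        # and Q_target = 1.0.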
        Q_expected = self.critic(states, actions)
        assert Q_expected.shape == Q_target.shape == Q_target_next.shape
        return mse_loss(Q_expected, Q_target)

    def compute_policy_loss(self, states) -> torch.Tensor:
        """Compute Policy loss based on provided states.

        Loss = Mean(-Q(s, _a) ),
        where _a is actor's estimate based on state, _a = Actor(s).
        """
        pred_actions = self.actor(states)
        return -self.critic(states, pred_actions).mean()

    def learn(self, experiences) -> None:
        """Update critics and actors"""
        rewards = to_tensor(experiences['reward']).float().to(
            self.device).unsqueeze(1)
        dones = to_tensor(experiences['done']).type(torch.int).to(
            self.device).unsqueeze(1)
        states = to_tensor(experiences['state']).float().to(self.device)
        actions = to_tensor(experiences['action']).to(self.device)
        next_states = to_tensor(experiences['next_state']).float().to(
            self.device)
        assert rewards.shape == dones.shape == (self.batch_size, 1)
        assert states.shape == next_states.shape == (self.batch_size,
                                                     self.state_size)
        assert actions.shape == (self.batch_size, self.action_size)

        # Value (critic) optimization
        loss_critic = self.compute_value_loss(states, actions, next_states,
                                              rewards, dones)
        self.critic_optimizer.zero_grad()
        loss_critic.backward()
        nn.utils.clip_grad_norm_(self.critic.parameters(),
                                 self.max_grad_norm_critic)
        self.critic_optimizer.step()
        self._loss_critic = float(loss_critic.item())

        # Policy (actor) optimization
        loss_actor = self.compute_policy_loss(states)
        self.actor_optimizer.zero_grad()
        loss_actor.backward()
        nn.utils.clip_grad_norm_(self.actor.parameters(),
                                 self.max_grad_norm_actor)
        self.actor_optimizer.step()
        self._loss_actor = loss_actor.item()

        # Soft update target weights
        soft_update(self.target_actor, self.actor, self.tau)
        soft_update(self.target_critic, self.critic, self.tau)

    def state_dict(self) -> Dict[str, dict]:
        """Describes agent's networks.

        Returns:
            state: (dict) Provides actors and critics states.

        """
        return {
            "actor": self.actor.state_dict(),
            "target_actor": self.target_actor.state_dict(),
            "critic": self.critic.state_dict(),
            "target_critic": self.target_critic.state_dict()
        }

    def log_metrics(self,
                    data_logger: DataLogger,
                    step: int,
                    full_log: bool = False):
        data_logger.log_value("loss/actor", self._loss_actor, step)
        data_logger.log_value("loss/critic", self._loss_critic, step)

        if full_log:
            for idx, layer in enumerate(self.actor.layers):
                if hasattr(layer, "weight"):
                    data_logger.create_histogram(f"actor/layer_weights_{idx}",
                                                 layer.weight, step)
                if hasattr(layer, "bias") and layer.bias is not None:
                    data_logger.create_histogram(f"actor/layer_bias_{idx}",
                                                 layer.bias, step)

            for idx, layer in enumerate(self.critic.layers):
                if hasattr(layer, "weight"):
                    data_logger.create_histogram(f"critic/layer_weights_{idx}",
                                                 layer.weight, step)
                if hasattr(layer, "bias") and layer.bias is not None:
                    data_logger.create_histogram(f"critic/layer_bias_{idx}",
                                                 layer.bias, step)

    def get_state(self) -> AgentState:
        return AgentState(
            model=self.name,
            state_space=self.state_size,
            action_space=self.action_size,
            config=self._config,
            buffer=copy.deepcopy(self.buffer.get_state()),
            network=copy.deepcopy(self.get_network_state()),
        )

    def get_network_state(self) -> NetworkState:
        net = dict(
            actor=self.actor.state_dict(),
            target_actor=self.target_actor.state_dict(),
            critic=self.critic.state_dict(),
            target_critic=self.target_critic.state_dict(),
        )
        return NetworkState(net=net)

    @staticmethod
    def from_state(state: AgentState) -> AgentBase:
        config = copy.copy(state.config)
        config.update({
            'state_size': state.state_space,
            'action_size': state.action_space
        })
        agent = DDPGAgent(**config)
        if state.network is not None:
            agent.set_network(state.network)
        if state.buffer is not None:
            agent.set_buffer(state.buffer)
        return agent

    def set_buffer(self, buffer_state: BufferState) -> None:
        self.buffer = BufferFactory.from_state(buffer_state)

    def set_network(self, network_state: NetworkState) -> None:
        self.actor.load_state_dict(copy.deepcopy(network_state.net['actor']))
        self.target_actor.load_state_dict(network_state.net['target_actor'])
        self.critic.load_state_dict(network_state.net['critic'])
        self.target_critic.load_state_dict(network_state.net['target_critic'])

    def save_state(self, path: str) -> None:
        agent_state = self.get_state()
        torch.save(agent_state, path)

    def load_state(self,
                   *,
                   path: Optional[str] = None,
                   agent_state: Optional[dict] = None):
        if path is None and agent_state is None:
            raise ValueError(
                "Either `path` or `agent_state` must be provided to load agent's state."
            )
        if path is not None and agent_state is None:
            agent_state = torch.load(path)
        self._config = agent_state.get('config', {})
        self.__dict__.update(**self._config)

        self.actor.load_state_dict(agent_state['actor'])
        self.critic.load_state_dict(agent_state['critic'])
        self.target_actor.load_state_dict(agent_state['target_actor'])
        self.target_critic.load_state_dict(agent_state['target_critic'])
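

# A hedged usage sketch for DDPGAgent with a gym-style control loop. The environment name
# and the pre-0.26 `gym` reset/step signatures are assumptions made purely for illustration;
# the agent calls (act/step/save_state) follow the class above.
import gym  # assumed to be installed

env = gym.make("Pendulum-v1")
agent = DDPGAgent(state_size=env.observation_space.shape[0],
                  action_size=env.action_space.shape[0])

state = env.reset()
for _ in range(10_000):
    action = agent.act(state, noise=0.1)
    next_state, reward, done, _ = env.step(action)
    agent.step(state, action, reward, next_state, done)
    state = env.reset() if done else next_state

agent.save_state("ddpg_pendulum.pt")
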
Example #7
class D4PGAgent(AgentBase):
    """
    Distributed Distributional DDPG (D4PG) [1].

    Extends the DDPG agent with:
    1. Distributional critic update.
    2. The use of distributed parallel actors.
    3. N-step returns.
    4. Prioritization of the experience replay (PER).

    [1] "Distributed Distributional Deterministic Policy Gradients"
        (2018, ICLR) by G. Barth-Maron & M. Hoffman et al.

    """

    name = "D4PG"

    def __init__(self,
                 state_size: int,
                 action_size: int,
                 hidden_layers: Sequence[int] = (128, 128),
                 **kwargs):
        """
        Parameters:
            state_size (int): Number of input dimensions.
            action_size (int): Number of output dimensions
            hidden_layers (tuple of ints): Tuple defining hidden dimensions in fully connected nets. Default: (128, 128).

        Keyword parameters:
            gamma (float): Discount value. Default: 0.99.
            tau (float): Soft-copy factor. Default: 0.02.
            actor_lr (float): Learning rate for the actor (policy). Default: 0.0003.
            critic_lr (float): Learning rate for the critic (value function). Default: 0.0003.
            actor_hidden_layers (tuple of ints): Shape of network for actor. Default: `hidden_layers`.
            critic_hidden_layers (tuple of ints): Shape of network for critic. Default: `hidden_layers`.
            max_grad_norm_actor (float): Maximum norm value for actor gradient. Default: 100.
            max_grad_norm_critic (float): Maximum norm value for critic gradient. Default: 100.
            num_atoms (int): Number of discrete values for the value distribution. Default: 51.
            v_min (float): Value distribution minimum (left most) value. Default: -10.
            v_max (float): Value distribution maximum (right most) value. Default: 10.
            n_steps (int): Number of steps (N-steps) for the TD target. Default: 3.
            batch_size (int): Number of samples used in learning. Default: 64.
            buffer_size (int): Maximum number of samples to store. Default: 1e6.
            warm_up (int): Number of samples to observe before starting any learning step. Default: 0.
            update_freq (int): Number of steps between each learning step. Default 1.
            number_updates (int): How many times to use learning step in the learning phase. Default: 1.
            action_min (float): Minimum returned action value. Default: -1.
            action_max (float): Maximum returned action value. Default: 1.
            action_scale (float): Multiplier value for action. Default: 1.
            num_workers (int): Number of workers (parallel environments) that will use this agent. Default: 1.

        """
        super().__init__(**kwargs)
        self.device = self._register_param(kwargs, "device", DEVICE)
        self.state_size = state_size
        self.action_size = action_size

        self.num_atoms = int(self._register_param(kwargs, 'num_atoms', 51))
        v_min = float(self._register_param(kwargs, 'v_min', -10))
        v_max = float(self._register_param(kwargs, 'v_max', 10))

        # Reason sequence initiation.
        self.action_min = float(self._register_param(kwargs, 'action_min', -1))
        self.action_max = float(self._register_param(kwargs, 'action_max', 1))
        self.action_scale = float(
            self._register_param(kwargs, 'action_scale', 1))

        self.gamma = float(self._register_param(kwargs, 'gamma', 0.99))
        self.tau = float(self._register_param(kwargs, 'tau', 0.02))
        self.batch_size = int(self._register_param(kwargs, 'batch_size', 64))
        self.buffer_size = int(
            self._register_param(kwargs, 'buffer_size', int(1e6)))
        self.buffer = PERBuffer(self.batch_size, self.buffer_size)

        self.n_steps = int(self._register_param(kwargs, "n_steps", 3))
        self.n_buffer = NStepBuffer(n_steps=self.n_steps, gamma=self.gamma)
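        # The NStepBuffer is assumed to aggregate the next `n_steps` rewards, i.e.
        # r_t + gamma * r_{t+1} + ... + gamma**(n_steps - 1) * r_{t+n-1}, which is why
        # `compute_value_loss` below bootstraps with discount = gamma**n_steps.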

        self.warm_up = int(self._register_param(kwargs, 'warm_up', 0))
        self.update_freq = int(self._register_param(kwargs, 'update_freq', 1))

        self.actor_hidden_layers = to_numbers_seq(
            self._register_param(kwargs, 'actor_hidden_layers', hidden_layers))
        self.critic_hidden_layers = to_numbers_seq(
            self._register_param(kwargs, 'critic_hidden_layers',
                                 hidden_layers))

        if kwargs.get("simple_policy", False):
            std_init = float(self._register_param(kwargs, "std_init", 1.0))
            std_max = float(self._register_param(kwargs, "std_max", 2.0))
            std_min = float(self._register_param(kwargs, "std_min", 0.05))
            self.policy = MultivariateGaussianPolicySimple(self.action_size,
                                                           std_init=std_init,
                                                           std_min=std_min,
                                                           std_max=std_max,
                                                           device=self.device)
        else:
            self.policy = MultivariateGaussianPolicy(self.action_size,
                                                     device=self.device)

        # This looks messy but it's not that bad. Actor, critic_net and Critic(critic_net). Then the same for `target_`.
        self.actor = ActorBody(state_size,
                               self.policy.param_dim * action_size,
                               hidden_layers=self.actor_hidden_layers,
                               gate_out=torch.tanh,
                               device=self.device)
        critic_net = CriticBody(state_size,
                                action_size,
                                out_features=self.num_atoms,
                                hidden_layers=self.critic_hidden_layers,
                                device=self.device)
        self.critic = CategoricalNet(num_atoms=self.num_atoms,
                                     v_min=v_min,
                                     v_max=v_max,
                                     net=critic_net,
                                     device=self.device)

        self.target_actor = ActorBody(state_size,
                                      self.policy.param_dim * action_size,
                                      hidden_layers=self.actor_hidden_layers,
                                      gate_out=torch.tanh,
                                      device=self.device)
        target_critic_net = CriticBody(state_size,
                                       action_size,
                                       out_features=self.num_atoms,
                                       hidden_layers=self.critic_hidden_layers,
                                       device=self.device)
        self.target_critic = CategoricalNet(num_atoms=self.num_atoms,
                                            v_min=v_min,
                                            v_max=v_max,
                                            net=target_critic_net,
                                            device=self.device)

        # Target sequence initiation
        hard_update(self.target_actor, self.actor)
        hard_update(self.target_critic, self.critic)

        # Optimization sequence initiation.
        self.actor_lr = float(self._register_param(kwargs, 'actor_lr', 3e-4))
        self.critic_lr = float(self._register_param(kwargs, 'critic_lr', 3e-4))
        self.value_loss_func = nn.BCELoss(reduction='none')

        self.actor_params = list(self.actor.parameters()) + list(
            self.policy.parameters())
        self.actor_optimizer = Adam(self.actor_params, lr=self.actor_lr)
        self.critic_optimizer = Adam(self.critic.parameters(),
                                     lr=self.critic_lr)
        self.max_grad_norm_actor = float(
            self._register_param(kwargs, "max_grad_norm_actor", 100))
        self.max_grad_norm_critic = float(
            self._register_param(kwargs, "max_grad_norm_critic", 100))

        self.num_workers = int(self._register_param(kwargs, "num_workers", 1))

        # Breathe, my child.
        self.iteration = 0
        self._loss_actor = float('nan')
        self._loss_critic = float('nan')

    @property
    def loss(self) -> Dict[str, float]:
        return {'actor': self._loss_actor, 'critic': self._loss_critic}

    @loss.setter
    def loss(self, value):
        if isinstance(value, dict):
            self._loss_actor = value['actor']
            self._loss_critic = value['critic']
        else:
            self._loss_actor = value
            self._loss_critic = value

    @torch.no_grad()
    def act(self, state, epsilon: float = 0.) -> List[float]:
        """
        Returns actions for given state as per current policy.

        Parameters:
            state: Current available state from the environment.
            epsilon: Epsilon value in the epsilon-greedy policy.

        """
        actions = []
        state = to_tensor(state).view(self.num_workers,
                                      self.state_size).float().to(self.device)
        for worker in range(self.num_workers):
            if self._rng.random() < epsilon:
                action = self.action_scale * (torch.rand(self.action_size) -
                                              0.5)
            else:
                action_seed = self.actor.act(state[worker].view(1, -1))
                action_dist = self.policy(action_seed)
                action = action_dist.sample()
                action *= self.action_scale
                action = torch.clamp(action.squeeze(), self.action_min,
                                     self.action_max).cpu()
            actions.append(action.tolist())

        assert len(actions) == self.num_workers
        return actions

    def step(self, states, actions, rewards, next_states, dones):
        self.iteration += 1

        # Delay adding to buffer to account for n_steps (particularly the reward)
        self.n_buffer.add(
            state=torch.tensor(states).reshape(self.num_workers,
                                               self.state_size).float(),
            next_state=torch.tensor(next_states).reshape(
                self.num_workers, self.state_size).float(),
            action=torch.tensor(actions).reshape(self.num_workers,
                                                 self.action_size).float(),
            reward=torch.tensor(rewards).reshape(self.num_workers, 1),
            done=torch.tensor(dones).reshape(self.num_workers, 1),
        )
        if not self.n_buffer.available:
            return

        samples = self.n_buffer.get().get_dict()
        for worker_idx in range(self.num_workers):
            self.buffer.add(
                state=samples['state'][worker_idx],
                next_state=samples['next_state'][worker_idx],
                action=samples['action'][worker_idx],
                done=samples['done'][worker_idx],
                reward=samples['reward'][worker_idx],
            )

        if self.iteration < self.warm_up:
            return

        if len(self.buffer) > self.batch_size and (self.iteration %
                                                   self.update_freq) == 0:
            self.learn(self.buffer.sample())

    def compute_value_loss(self,
                           states,
                           actions,
                           next_states,
                           rewards,
                           dones,
                           indices=None):
        # Q_w estimate
        value_dist_estimate = self.critic(states, actions)
        assert value_dist_estimate.shape == (self.batch_size, 1,
                                             self.num_atoms)
        value_dist = F.softmax(value_dist_estimate.squeeze(), dim=1)
        assert value_dist.shape == (self.batch_size, self.num_atoms)

        # Q_w' estimate via Bellman's dist operator
        next_action_seeds = self.target_actor.act(next_states)
        next_actions = self.policy(next_action_seeds).sample()
        assert next_actions.shape == (self.batch_size, self.action_size)

        target_value_dist_estimate = self.target_critic.act(
            next_states, next_actions)
        assert target_value_dist_estimate.shape == (self.batch_size, 1,
                                                    self.num_atoms)
        target_value_dist_estimate = target_value_dist_estimate.squeeze()
        assert target_value_dist_estimate.shape == (self.batch_size,
                                                    self.num_atoms)

        discount = self.gamma**self.n_steps
        target_value_projected = self.target_critic.dist_projection(
            rewards, 1 - dones, discount, target_value_dist_estimate)
        assert target_value_projected.shape == (self.batch_size,
                                                self.num_atoms)

        target_value_dist = F.softmax(target_value_dist_estimate,
                                      dim=-1).detach()
        assert target_value_dist.shape == (self.batch_size, self.num_atoms)

        # Comparing Q_w with Q_w'
        loss = self.value_loss_func(value_dist, target_value_projected)
        self._metric_batch_error = loss.detach().sum(dim=-1)
        samples_error = loss.sum(dim=-1).pow(2)
        loss_critic = samples_error.mean()

        if hasattr(self.buffer, 'priority_update') and indices is not None:
            assert (~torch.isnan(samples_error)).all()
            self.buffer.priority_update(indices,
                                        samples_error.detach().cpu().numpy())

        return loss_critic

    def compute_policy_loss(self, states):
        # Compute actor loss
        pred_action_seeds = self.actor(states)
        pred_actions = self.policy(pred_action_seeds).rsample()
        # Negative because the optimizer minimizes, but we want to maximize the value
        value_dist = self.critic(states, pred_actions)
        self._batch_value_dist_metric = value_dist.detach()
        # Estimate on Z support
        return -torch.mean(value_dist * self.critic.z_atoms)

    def learn(self, experiences):
        """Update critics and actors"""
        # No need for size assertion since .view() has explicit sizes
        rewards = to_tensor(experiences['reward']).view(
            self.batch_size, 1).float().to(self.device)
        dones = to_tensor(experiences['done']).view(self.batch_size, 1).type(
            torch.int).to(self.device)
        states = to_tensor(experiences['state']).view(
            self.batch_size, self.state_size).float().to(self.device)
        actions = to_tensor(experiences['action']).view(
            self.batch_size, self.action_size).to(self.device)
        next_states = to_tensor(experiences['next_state']).view(
            self.batch_size, self.state_size).float().to(self.device)

        indices = None
        if hasattr(self.buffer, 'priority_update'):  # When using PER buffer
            indices = experiences['index']

        # Value (critic) optimization
        loss_critic = self.compute_value_loss(states, actions, next_states,
                                              rewards, dones, indices)
        self.critic_optimizer.zero_grad()
        loss_critic.backward()
        nn.utils.clip_grad_norm_(self.critic.parameters(),
                                 self.max_grad_norm_critic)
        self.critic_optimizer.step()
        self._loss_critic = float(loss_critic.item())

        # Policy (actor) optimization
        loss_actor = self.compute_policy_loss(states)
        self.actor_optimizer.zero_grad()
        loss_actor.backward()
        nn.utils.clip_grad_norm_(self.actor.parameters(),
                                 self.max_grad_norm_actor)
        self.actor_optimizer.step()
        self._loss_actor = float(loss_actor.item())

        # Networks gradual sync
        soft_update(self.target_actor, self.actor, self.tau)
        soft_update(self.target_critic, self.critic, self.tau)

    def state_dict(self) -> Dict[str, dict]:
        """Describes agent's networks.

        Returns:
            state: (dict) Provides actors and critics states.

        """
        return {
            "actor": self.actor.state_dict(),
            "target_actor": self.target_actor.state_dict(),
            "critic": self.critic.state_dict(),
            "target_critic": self.target_critic()
        }

    def log_metrics(self, data_logger: DataLogger, step, full_log=False):
        data_logger.log_value("loss/actor", self._loss_actor, step)
        data_logger.log_value("loss/critic", self._loss_critic, step)
        policy_params = {
            str(i): v
            for i, v in enumerate(
                itertools.chain.from_iterable(self.policy.parameters()))
        }
        data_logger.log_values_dict("policy/param", policy_params, step)

        data_logger.create_histogram('metric/batch_errors',
                                     self._metric_batch_error.sum(-1), step)
        data_logger.create_histogram('metric/batch_value_dist',
                                     self._batch_value_dist_metric, step)

        # This method, `log_metrics`, isn't executed on every iteration, and weight histograms are
        # only plotted when `full_log` is set since creating them can be quite costly. Tread wisely.
        if full_log:
            for idx, layer in enumerate(self.actor.layers):
                if hasattr(layer, "weight"):
                    data_logger.create_histogram(f"actor/layer_weights_{idx}",
                                                 layer.weight, step)
                if hasattr(layer, "bias") and layer.bias is not None:
                    data_logger.create_histogram(f"actor/layer_bias_{idx}",
                                                 layer.bias, step)

            for idx, layer in enumerate(self.critic.net.layers):
                if hasattr(layer, "weight"):
                    data_logger.create_histogram(f"critic/layer_{idx}",
                                                 layer.weight, step)
                if hasattr(layer, "bias") and layer.bias is not None:
                    data_logger.create_histogram(f"critic/layer_bias_{idx}",
                                                 layer.bias, step)

    def get_state(self):
        return dict(
            actor=self.actor.state_dict(),
            target_actor=self.target_actor.state_dict(),
            critic=self.critic.state_dict(),
            target_critic=self.target_critic.state_dict(),
            config=self._config,
        )

    def save_state(self, path: str):
        agent_state = self.get_state()
        torch.save(agent_state, path)

    def load_state(self, path: str):
        agent_state = torch.load(path)
        self._config = agent_state.get('config', {})
        self.__dict__.update(**self._config)

        self.actor.load_state_dict(agent_state['actor'])
        self.critic.load_state_dict(agent_state['critic'])
        self.target_actor.load_state_dict(agent_state['target_actor'])
        self.target_critic.load_state_dict(agent_state['target_critic'])
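

# The categorical (C51-style) projection performed by `CategoricalNet.dist_projection` above
# is not shown in this listing. A self-contained sketch of what such a projection is assumed
# to do -- shift the fixed support by the n-step reward and discount, clip it to [v_min, v_max],
# and redistribute probability mass onto the neighbouring atoms (function and argument names
# are illustrative):
import torch


def categorical_projection_sketch(rewards: torch.Tensor,
                                  not_dones: torch.Tensor,
                                  discount: float,
                                  target_probs: torch.Tensor,
                                  v_min: float = -10.0,
                                  v_max: float = 10.0,
                                  num_atoms: int = 51) -> torch.Tensor:
    """Project the shifted target distribution back onto the fixed support."""
    batch_size = rewards.shape[0]
    delta_z = (v_max - v_min) / (num_atoms - 1)
    z_atoms = torch.linspace(v_min, v_max, num_atoms)

    # Bellman-shifted support, clipped to the original range; shape (batch, num_atoms).
    tz = (rewards + not_dones * discount * z_atoms.unsqueeze(0)).clamp(v_min, v_max)
    b = (tz - v_min) / delta_z
    lower, upper = b.floor().long(), b.ceil().long()
    # When b lands exactly on an atom, lower == upper and both weights below are zero;
    # nudge one side so the probability mass is not lost.
    lower[(upper > 0) & (lower == upper)] -= 1
    upper[(lower < num_atoms - 1) & (lower == upper)] += 1

    projected = torch.zeros(batch_size, num_atoms)
    offset = (torch.arange(batch_size) * num_atoms).unsqueeze(1)
    projected.view(-1).index_add_(0, (lower + offset).view(-1),
                                  (target_probs * (upper.float() - b)).view(-1))
    projected.view(-1).index_add_(0, (upper + offset).view(-1),
                                  (target_probs * (b - lower.float())).view(-1))
    return projected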