Beispiel #1
0
def test_per_buffer_priority_update():
    """Setting all priorities to the same value makes every sampled weight 1."""
    # Assign
    batch_size = 5
    buffer_size = 10
    per_buffer = PERBuffer(batch_size, buffer_size)
    for _ in range(2 * buffer_size):  # Make sure we fill the whole buffer
        per_buffer.add(priority=np.random.randint(10),
                       state=np.random.random(10))
    # Make sure there is a single highest priority.
    per_buffer.add(priority=100, state=np.random.random(10))

    # Act & Assert
    experiences = per_buffer.sample(beta=0.5)
    assert experiences is not None
    # Weights are capped at 1, so a sum below batch_size means some are < 1.
    assert sum(experiences['weight']) < batch_size

    per_buffer.priority_update(indices=range(buffer_size),
                               priorities=np.ones(buffer_size))
    experiences = per_buffer.sample(beta=0.9)
    assert experiences is not None

    # Uniform priorities -> every weight equals 1.
    assert sum(experiences['weight']) == batch_size
    assert all(w == 1 for w in experiences['weight'])
Beispiel #2
0
    def load_state(self,
                   *,
                   path: Optional[str] = None,
                   state: Optional[AgentState] = None) -> None:
        """Restore the agent from a saved state.

        The state is either loaded from a file (``path``) or passed in
        directly (``state``). If both are given, the one loaded from
        ``path`` is used.

        Parameters:
            path: String path to a file readable by ``torch.load`` that holds
                the agent's state.
            state: An already loaded ``AgentState``.

        Raises:
            ValueError: If neither ``path`` nor ``state`` is provided.

        Note:
            The replay buffer is re-created from the restored config rather
            than restored from ``state``.
        """
        if path is None and state is None:
            raise ValueError(
                "Either `path` or `state` must be provided to load agent's state."
            )
        if path is not None:
            state = torch.load(path)

        # Populate agent: the config doubles as the set of agent attributes.
        agent_state = state.agent
        self._config = agent_state.config
        self.__dict__.update(**self._config)

        # Populate network (online and target weights).
        network_state = state.network
        self.net.load_state_dict(network_state.net['net'])
        self.target_net.load_state_dict(network_state.net['target_net'])
        self.buffer = PERBuffer(**self._config)
Beispiel #3
0
def test_per_from_state_wrong_type():
    # Assign: start from a valid snapshot and corrupt its type tag
    original = PERBuffer(batch_size=5, buffer_size=20)
    corrupted_state = original.get_state()
    corrupted_state.type = "WrongType"

    # Act & Assert: a mismatched type must be rejected
    with pytest.raises(ValueError):
        PERBuffer.from_state(state=corrupted_state)
Beispiel #4
0
def test_per_buffer_len():
    # Assign
    capacity = 10
    per_buffer = PERBuffer(5, capacity)

    # Act & Assert: length tracks the number of adds and saturates at capacity
    for added_so_far in range(capacity + 2):
        expected_len = capacity if added_so_far > capacity else added_so_far
        assert len(per_buffer) == expected_len
        per_buffer.add(priority=1, state=1)
Beispiel #5
0
    def __init__(self,
                 state_size: Union[Sequence[int], int],
                 action_size: int,
                 lr: float = 0.001,
                 gamma: float = 0.99,
                 tau: float = 0.002,
                 network_fn: Optional[Callable[[], NetworkType]] = None,
                 hidden_layers: Sequence[int] = (64, 64),
                 state_transform: Optional[Callable] = None,
                 reward_transform: Optional[Callable] = None,
                 device=None,
                 **kwargs):
        """
        Parameters:
            state_size: Shape of the state; an int is treated as a 1-D shape.
            action_size: Number of actions.
            lr: Learning rate (default: 1e-3).
            gamma: Discount factor (default: 0.99).
            tau: Soft-copy factor (default: 0.002).
            network_fn: Optional factory returning a fresh network. It is
                called twice (online and target network). When omitted, a
                ``DuelingNet`` built from ``hidden_layers`` is used.
            hidden_layers: Hidden layer sizes of the default ``DuelingNet``.
            state_transform: Optional callable stored for states; identity by default.
            reward_transform: Optional callable stored for rewards; identity by default.
            device: Torch device; falls back to ``DEVICE`` when None.

        Keyword parameters (read from ``kwargs``, defaults in parentheses):
            update_freq (1), batch_size (32), warm_up (0), number_updates (1),
            max_grad_norm (10), using_double_q (False), n_steps (1).
        """

        self.device = device if device is not None else DEVICE
        self.state_size = (state_size, ) if isinstance(state_size, int) else state_size
        self.action_size = action_size

        # Named parameters are never part of **kwargs, so read them directly.
        self.lr = float(lr)
        self.gamma = float(gamma)
        self.tau = float(tau)

        self.update_freq = int(kwargs.get('update_freq', 1))
        self.batch_size = int(kwargs.get('batch_size', 32))
        self.warm_up = int(kwargs.get('warm_up', 0))
        self.number_updates = int(kwargs.get('number_updates', 1))
        self.max_grad_norm = float(kwargs.get('max_grad_norm', 10))

        self.iteration: int = 0
        self.buffer = PERBuffer(self.batch_size)
        self.using_double_q = bool(kwargs.get("using_double_q", False))

        # Cast like the other numeric kwargs so string-valued configs work.
        self.n_steps = int(kwargs.get("n_steps", 1))
        self.n_buffer = NStepBuffer(n_steps=self.n_steps, gamma=self.gamma)

        self.state_transform = state_transform if state_transform is not None else lambda x: x
        self.reward_transform = reward_transform if reward_transform is not None else lambda x: x
        if network_fn:
            self.net = network_fn().to(self.device)
            self.target_net = network_fn().to(self.device)
        else:
            # The default network only uses the first state dimension.
            self.net = DuelingNet(self.state_size[0],
                                  self.action_size,
                                  hidden_layers=hidden_layers).to(self.device)
            self.target_net = DuelingNet(self.state_size[0],
                                         self.action_size,
                                         hidden_layers=hidden_layers).to(
                                             self.device)
        self.optimizer = optim.SGD(self.net.parameters(), lr=self.lr)
Beispiel #6
0
def test_per_get_state_without_data():
    # Assign: a freshly created, empty buffer
    empty_buffer = PERBuffer(batch_size=5, buffer_size=20)

    # Act
    snapshot = empty_buffer.get_state()

    # Assert: metadata is captured while no data is stored
    assert snapshot.type == PERBuffer.type
    assert (snapshot.buffer_size, snapshot.batch_size) == (20, 5)
    assert snapshot.data is None
Beispiel #7
0
def test_per_from_state_without_data():
    # Assign
    source = PERBuffer(batch_size=5, buffer_size=20)
    snapshot = source.get_state()

    # Act
    restored = PERBuffer.from_state(state=snapshot)

    # Assert: an empty snapshot restores to an equal, empty buffer
    assert restored == source
    assert restored.buffer_size == snapshot.buffer_size
    assert restored.batch_size == snapshot.batch_size
    assert restored.data == []
Beispiel #8
0
def test_per_buffer_add_one_sample_one():
    # Assign
    per_buffer = PERBuffer(1, 20)
    only_state = range(5)

    # Act
    per_buffer.add(priority=0.5, state=only_state)

    # Assert: the single sample is returned with the maximal weight of 1
    batch = per_buffer.sample()
    assert batch is not None
    expected = {'state': [range(5)], 'weight': [1.], 'index': [0]}
    for field, value in expected.items():
        assert batch[field] == value
Beispiel #9
0
def test_per_buffer_add_one_sample_one():
    # Assign
    per_buffer = PERBuffer(1, 20)

    # Act
    per_buffer.add(priority=0.5, state=range(5))

    # Assert: the only experience comes back unchanged
    drawn = per_buffer.sample_list()
    assert drawn is not None
    only = drawn[0]
    assert only.state == range(5)
    assert only.weight == 1.  # weights are scaled so the maximum is 1
    assert only.index == 0
Beispiel #10
0
def test_per_buffer_reset_alpha():
    """Changing alpha re-weights the same samples without altering their data."""
    # Assign
    per_buffer = PERBuffer(10, 10, alpha=0.1)
    for _ in range(30):
        # numpy documents integer bounds for `randint`; pass an int, not 1e5.
        per_buffer.add(reward=np.random.randint(0, 100000),
                       priority=np.random.random())

    # Act
    old_experiences = per_buffer.sample()
    per_buffer.reset_alpha(0.5)
    new_experiences = per_buffer.sample()

    # Assert
    assert old_experiences is not None and new_experiences is not None
    old_index = np.array(old_experiences['index'])
    new_index = np.array(new_experiences['index'])
    old_weight = np.array(old_experiences['weight'])
    new_weight = np.array(new_experiences['weight'])
    old_reward = np.array(old_experiences['reward'])
    new_reward = np.array(new_experiences['reward'])

    # Align both batches by sample index before comparing element-wise.
    old_sort, new_sort = np.argsort(old_index), np.argsort(new_index)
    assert np.array_equal(old_index[old_sort], new_index[new_sort])
    assert np.all(old_weight[old_sort] != new_weight[new_sort])
    assert np.array_equal(old_reward[old_sort], new_reward[new_sort])
Beispiel #11
0
def test_per_from_state_with_data():
    # Assign: fill the buffer with more samples than its capacity
    original = populate_buffer(PERBuffer(batch_size=5, buffer_size=20), 30)
    snapshot = original.get_state()

    # Act
    restored = PERBuffer.from_state(state=snapshot)

    # Assert: configuration and data survive the round trip
    assert restored == original
    assert restored.buffer_size == snapshot.buffer_size
    assert restored.batch_size == snapshot.batch_size
    assert restored.data == snapshot.data
    assert len(original.data) == snapshot.buffer_size
Beispiel #12
0
def test_per_buffer_sample():
    # Assign
    batch_size = 5
    per_buffer = PERBuffer(batch_size)  # first positional argument is the batch size

    # Act
    for priority in range(batch_size):
        state = np.arange(priority, priority + 10)
        # The offset keeps the first priority non-zero.
        per_buffer.add(priority=priority + 0.01, state=state)

    # Assert
    experiences = per_buffer.sample()
    assert experiences is not None
    state = experiences['state']
    weight = experiences['weight']
    index = experiences['index']
    assert len(state) == len(weight) == len(index) == batch_size
    assert all(s is not None for s in state)
Beispiel #13
0
def test_priority_buffer_load_json_dump():
    # Assign
    expected_fields = ["state", "action", "reward", "next_state", "done"]
    buffer = PERBuffer(batch_size=10, buffer_size=20)
    serialized = [Experience(**sars)
                  for sars in generate_sample_SARS(10, dict_type=True)]

    # Act
    buffer.load_buffer(serialized)

    # Assert: every loaded experience exposes each field as a list
    loaded = buffer._sample_list()
    assert len(buffer) == 10
    assert len(loaded) == 10
    for experience in loaded:
        assert all(hasattr(experience, field) for field in expected_fields)
        assert all(isinstance(getattr(experience, field), list)
                   for field in expected_fields)
Beispiel #14
0
def test_per_buffer_too_few_samples():
    # Assign
    batch_size = 5
    per_buffer = PERBuffer(batch_size, 10)

    # Act & Assert: nothing can be sampled until a full batch is stored
    for _ in range(batch_size - 1):
        per_buffer.add(priority=0.1, reward=0.1)
        assert per_buffer.sample() is None

    per_buffer.add(priority=0.1, reward=0.1)
    full_batch = per_buffer.sample()
    assert len(full_batch['reward']) == batch_size
Beispiel #15
0
def test_per_get_state_with_data():
    # Assign: add more identical experiences than the buffer can hold
    buffer = PERBuffer(batch_size=5, buffer_size=20)
    template = Experience(state=[0, 1], action=[0], reward=0)
    for _ in range(25):
        buffer.add(**template.data)

    # Act
    snapshot = buffer.get_state()

    # Assert: only `buffer_size` entries are kept, each equal to the template
    assert snapshot.type == PERBuffer.type
    assert snapshot.buffer_size == 20
    assert snapshot.batch_size == 5
    assert snapshot.data is not None
    assert len(snapshot.data) == 20
    assert all(entry == template for entry in snapshot.data)
 def from_state(state: BufferState) -> BufferBase:
     """Rebuild a buffer from its serialized state.

     Dispatches on ``state.type`` to the matching buffer class, checking
     ReplayBuffer, PERBuffer, NStepBuffer and RolloutBuffer in that order.

     Raises:
         ValueError: If ``state.type`` matches none of these buffer types.
     """
     for buffer_cls in (ReplayBuffer, PERBuffer, NStepBuffer, RolloutBuffer):
         if state.type == buffer_cls.type:
             return buffer_cls.from_state(state)
     raise ValueError(f"Buffer state contains unsupported buffer type: '{state.type}'")
Beispiel #17
0
def test_priority_buffer_dump_serializable():
    """A serialized buffer dump survives a JSON round trip unchanged."""
    import json
    import torch
    # Assign
    filled_buffer = 8
    buffer = PERBuffer(batch_size=5, buffer_size=10)
    for sars in generate_sample_SARS(filled_buffer):
        buffer.add(state=torch.tensor(sars[0]),
                   reward=sars[1],
                   action=[sars[2]],
                   next_state=torch.tensor(sars[3]),
                   done=sars[4])  # same field name the agents and other tests use

    # Act
    dump = list(buffer.dump_buffer(serialize=True))

    # Assert
    ser_dump = json.dumps(dump)
    assert isinstance(ser_dump, str)
    assert json.loads(ser_dump) == dump
Beispiel #18
0
def test_per_buffer_sample():
    # Assign
    batch_size = 5
    per_buffer = PERBuffer(batch_size)  # first positional argument is the batch size

    # Act
    for priority in range(batch_size):
        state = np.arange(priority, priority + 10)
        # The offset keeps the first priority non-zero.
        per_buffer.add(priority=priority + 0.01, state=state)

    # Assert
    experiences = per_buffer.sample_list()
    assert experiences is not None
    assert len(experiences) == batch_size
    states, rewards, weights, indices = zip(
        *[(exp.state, exp.reward, exp.weight, exp.index) for exp in experiences])
    assert len(weights) == len(indices) == batch_size
    assert all(s is not None for s in states)
    # No reward was added, so every sampled reward is None.
    assert all(r is None for r in rewards)
Beispiel #19
0
def test_per_buffer_too_few_samples():
    # Assign
    batch_size = 5
    per_buffer = PERBuffer(batch_size, 10)

    # Act & Assert: `sample_list` yields None until `batch_size` items exist
    for stored in range(batch_size):
        assert per_buffer.sample_list() is None, f"unexpected batch after {stored} adds"
        per_buffer.add(priority=0.1, reward=0.1)

    full_batch = per_buffer.sample_list()
    assert full_batch is not None
Beispiel #20
0
def test_per_buffer_add_two_sample_two_beta():
    # Assign
    per_buffer = PERBuffer(2, 20, 0.4)

    # Act
    high_priority_state, low_priority_state = range(5), range(3, 8)
    per_buffer.add(state=high_priority_state, priority=0.9)
    per_buffer.add(state=low_priority_state, priority=0.1)

    # Assert: the low-priority sample carries the max weight of 1
    experiences = per_buffer.sample(beta=0.6)
    assert experiences is not None
    for state, weight in zip(experiences['state'], experiences['weight']):
        if weight != 1:
            assert 0.6421 < weight < 0.6422
            assert state == high_priority_state
            continue
        assert state == low_priority_state
Beispiel #21
0
def test_per_buffer_reset_alpha():
    """Changing alpha re-weights the same samples without altering their data."""
    # Assign
    per_buffer = PERBuffer(10, 10, alpha=0.1)
    for _ in range(30):
        # numpy documents integer bounds for `randint`; pass an int, not 1e5.
        per_buffer.add(reward=np.random.randint(0, 100000),
                       priority=np.random.random())

    # Act
    old_experiences = per_buffer.sample_list()
    per_buffer.reset_alpha(0.5)
    new_experiences = per_buffer.sample_list()

    # Assert
    assert old_experiences is not None and new_experiences is not None
    # Align both batches by sample index before comparing pairwise.
    sorted_new_experiences = sorted(new_experiences, key=lambda exp: exp.index)
    sorted_old_experiences = sorted(old_experiences, key=lambda exp: exp.index)
    for new_sample, old_sample in zip(sorted_new_experiences,
                                      sorted_old_experiences):
        assert new_sample.index == old_sample.index
        assert new_sample.weight != old_sample.weight
        assert new_sample.reward == old_sample.reward
Beispiel #22
0
def test_per_buffer_add_two_sample_two_beta():
    # Assign
    per_buffer = PERBuffer(2, 20)

    # Act
    per_buffer.add(state=range(5), priority=0.9)
    per_buffer.add(state=range(3, 8), priority=0.1)

    # Assert
    experiences = per_buffer.sample_list(beta=0.6)
    assert experiences is not None
    for experience in experiences:
        if experience.index == 0:
            # Higher-priority sample: weighted below the max of 1.
            assert experience.state == range(5)
            assert 0.946 < experience.weight < 0.947
        else:
            # Lower-priority sample: carries the max weight.
            assert experience.state == range(3, 8)
            assert experience.weight == 1.
Beispiel #23
0
    def __init__(self, in_features: FeatureType, action_size: int, **kwargs):
        """
        Parameters:
            in_features: Observation shape; an int is treated as a 1-D shape.
            action_size: Number of action dimensions.

        Keyword parameters:
            hidden_layers: (default: (128, 128)) Shape of the hidden layers that are fully connected networks.
            actor_hidden_layers: (default: `hidden_layers`) Hidden layers of the actor.
            critic_hidden_layers: (default: `hidden_layers`) Hidden layers of each critic.
            gamma: (default: 0.99) Discount value.
            tau: (default: 0.02) Soft copy fraction.
            batch_size: (default 64) Number of samples in a batch.
            buffer_size: (default: 1e6) Size of the prioritized experience replay buffer.
            warm_up: (default: 0) Number of samples that needs to be observed before starting to learn.
            update_freq: (default: 1) Number of samples between policy updates.
            number_updates: (default: 1) Number of times of batch sampling/training per `update_freq`.
            actor_number_updates: (default: 1) Stored as `self.actor_number_updates`.
            critic_number_updates: (default: 1) Stored as `self.critic_number_updates`.
            action_min: (default: -1) Lower action bound.
            action_max: (default: 1) Upper action bound.
            action_scale: (default: 1.) Scale for returned action values.
            simple_policy: (default: False) Use `MultivariateGaussianPolicySimple` instead of `GaussianPolicy`.
            alpha: (default: 0.2) Initial weight of log probs in value function.
            alpha_lr: (default: None) If provided, it will add alpha as a training parameters and `alpha_lr` is its learning rate.
            actor_lr: (default: 3e-4) Learning rate for the actor and policy parameters.
            critic_lr: (default: 3e-4) Learning rate for the critics.
            max_grad_norm_alpha: (default: 1.) Gradient clipping for the alpha.
            max_grad_norm_actor: (default 10.) Gradient clipping for the actor.
            max_grad_norm_critic: (default: 10.) Gradient clipping for the critic.
            device: Defaults to `DEVICE` when missing or None.

        """
        super().__init__(**kwargs)
        # An explicit `device=None` means "use the default", same as omitting it.
        device = kwargs.get("device")
        self.device = device if device is not None else DEVICE
        self.in_features: Tuple[int, ...] = (in_features, ) if isinstance(
            in_features, int) else tuple(in_features)
        self.state_size: int = in_features if isinstance(
            in_features, int) else reduce(operator.mul, in_features)
        self.action_size = action_size

        self.gamma: float = float(self._register_param(kwargs, 'gamma', 0.99))
        self.tau: float = float(self._register_param(kwargs, 'tau', 0.02))
        self.batch_size: int = int(
            self._register_param(kwargs, 'batch_size', 64))
        self.buffer_size: int = int(
            self._register_param(kwargs, 'buffer_size', int(1e6)))
        self.memory = PERBuffer(self.batch_size, self.buffer_size)

        self.action_min = self._register_param(kwargs, 'action_min', -1)
        self.action_max = self._register_param(kwargs, 'action_max', 1)
        self.action_scale = self._register_param(kwargs, 'action_scale', 1)

        self.warm_up = int(self._register_param(kwargs, 'warm_up', 0))
        self.update_freq = int(self._register_param(kwargs, 'update_freq', 1))
        self.number_updates = int(
            self._register_param(kwargs, 'number_updates', 1))
        self.actor_number_updates = int(
            self._register_param(kwargs, 'actor_number_updates', 1))
        self.critic_number_updates = int(
            self._register_param(kwargs, 'critic_number_updates', 1))

        # Reason sequence initiation.
        hidden_layers = to_numbers_seq(
            self._register_param(kwargs, 'hidden_layers', (128, 128)))
        actor_hidden_layers = to_numbers_seq(
            self._register_param(kwargs, 'actor_hidden_layers', hidden_layers))
        critic_hidden_layers = to_numbers_seq(
            self._register_param(kwargs, 'critic_hidden_layers',
                                 hidden_layers))

        self.simple_policy = bool(
            self._register_param(kwargs, "simple_policy", False))
        if self.simple_policy:
            self.policy = MultivariateGaussianPolicySimple(
                self.action_size, **kwargs)
            self.actor = ActorBody(self.state_size,
                                   self.policy.param_dim * self.action_size,
                                   hidden_layers=actor_hidden_layers,
                                   device=self.device)
        else:
            # The last actor hidden layer is the policy's input width.
            self.policy = GaussianPolicy(actor_hidden_layers[-1],
                                         self.action_size,
                                         out_scale=self.action_scale,
                                         device=self.device)
            self.actor = ActorBody(self.state_size,
                                   actor_hidden_layers[-1],
                                   hidden_layers=actor_hidden_layers[:-1],
                                   device=self.device)

        self.double_critic = DoubleCritic(self.in_features,
                                          self.action_size,
                                          CriticBody,
                                          hidden_layers=critic_hidden_layers,
                                          device=self.device)
        self.target_double_critic = DoubleCritic(
            self.in_features,
            self.action_size,
            CriticBody,
            hidden_layers=critic_hidden_layers,
            device=self.device)

        # Target sequence initiation
        hard_update(self.target_double_critic, self.double_critic)

        # Optimization sequence initiation.
        self.target_entropy = -self.action_size
        alpha_lr = self._register_param(kwargs, "alpha_lr")
        self.alpha_lr = float(alpha_lr) if alpha_lr else None
        alpha_init = float(self._register_param(kwargs, "alpha", 0.2))
        # Alpha is stored in log-space; it is only optimized when alpha_lr is set.
        self.log_alpha = torch.tensor(np.log(alpha_init),
                                      device=self.device,
                                      requires_grad=True)

        actor_lr = float(self._register_param(kwargs, 'actor_lr', 3e-4))
        critic_lr = float(self._register_param(kwargs, 'critic_lr', 3e-4))

        self.actor_params = list(self.actor.parameters()) + list(
            self.policy.parameters())
        self.critic_params = list(self.double_critic.parameters())
        self.actor_optimizer = optim.Adam(self.actor_params, lr=actor_lr)
        self.critic_optimizer = optim.Adam(self.critic_params, lr=critic_lr)
        if self.alpha_lr is not None:
            self.alpha_optimizer = optim.Adam([self.log_alpha],
                                              lr=self.alpha_lr)
        self.max_grad_norm_alpha = float(
            self._register_param(kwargs, "max_grad_norm_alpha", 1.0))
        self.max_grad_norm_actor = float(
            self._register_param(kwargs, "max_grad_norm_actor", 10.0))
        self.max_grad_norm_critic = float(
            self._register_param(kwargs, "max_grad_norm_critic", 10.0))

        # Step counter and trackers for the latest losses/metrics.
        self.iteration = 0

        self._loss_actor = float('inf')
        self._loss_critic = float('inf')
        self._metrics: Dict[str, Union[float, Dict[str, float]]] = {}
Beispiel #24
0
class SACAgent(AgentBase):
    """
    Soft Actor-Critic.

    Uses stochastic policy and dual value network (two critics).

    Based on
    "Soft Actor-Critic: Off-Policy Maximum Entropy Deep Reinforcement Learning with a Stochastic Actor"
    by Haarnoja et al. (2018) (http://arxiv.org/abs/1801.01290).
    """

    name = "SAC"

    def __init__(self, in_features: FeatureType, action_size: int, **kwargs):
        """
        Parameters:
            in_features: Observation shape; an int is treated as a 1-D shape.
            action_size: Number of action dimensions.

        Keyword parameters:
            hidden_layers: (default: (128, 128)) Shape of the hidden layers that are fully connected networks.
            actor_hidden_layers: (default: `hidden_layers`) Hidden layers of the actor.
            critic_hidden_layers: (default: `hidden_layers`) Hidden layers of each critic.
            gamma: (default: 0.99) Discount value.
            tau: (default: 0.02) Soft copy fraction.
            batch_size: (default 64) Number of samples in a batch.
            buffer_size: (default: 1e6) Size of the prioritized experience replay buffer.
            warm_up: (default: 0) Number of samples that needs to be observed before starting to learn.
            update_freq: (default: 1) Number of samples between policy updates.
            number_updates: (default: 1) Number of times of batch sampling/training per `update_freq`.
            actor_number_updates: (default: 1) Stored as `self.actor_number_updates`.
            critic_number_updates: (default: 1) Stored as `self.critic_number_updates`.
            action_min: (default: -1) Lower bound for random actions in `act`.
            action_max: (default: 1) Upper bound for random actions in `act`.
            action_scale: (default: 1.) Scale for returned action values.
            simple_policy: (default: False) Use `MultivariateGaussianPolicySimple` instead of `GaussianPolicy`.
            alpha: (default: 0.2) Initial weight of log probs in value function.
            alpha_lr: (default: None) If provided, it will add alpha as a training parameters and `alpha_lr` is its learning rate.
            actor_lr: (default: 3e-4) Learning rate for the actor and policy parameters.
            critic_lr: (default: 3e-4) Learning rate for the critics.
            max_grad_norm_alpha: (default: 1.) Gradient clipping for the alpha.
            max_grad_norm_actor: (default 10.) Gradient clipping for the actor.
            max_grad_norm_critic: (default: 10.) Gradient clipping for the critic.
            device: Defaults to `DEVICE` when missing or None.

        """
        super().__init__(**kwargs)
        # An explicit `device=None` means "use the default", same as omitting it.
        device = kwargs.get("device")
        self.device = device if device is not None else DEVICE
        self.in_features: Tuple[int, ...] = (in_features, ) if isinstance(
            in_features, int) else tuple(in_features)
        self.state_size: int = in_features if isinstance(
            in_features, int) else reduce(operator.mul, in_features)
        self.action_size = action_size

        self.gamma: float = float(self._register_param(kwargs, 'gamma', 0.99))
        self.tau: float = float(self._register_param(kwargs, 'tau', 0.02))
        self.batch_size: int = int(
            self._register_param(kwargs, 'batch_size', 64))
        self.buffer_size: int = int(
            self._register_param(kwargs, 'buffer_size', int(1e6)))
        self.memory = PERBuffer(self.batch_size, self.buffer_size)

        self.action_min = self._register_param(kwargs, 'action_min', -1)
        self.action_max = self._register_param(kwargs, 'action_max', 1)
        self.action_scale = self._register_param(kwargs, 'action_scale', 1)

        self.warm_up = int(self._register_param(kwargs, 'warm_up', 0))
        self.update_freq = int(self._register_param(kwargs, 'update_freq', 1))
        self.number_updates = int(
            self._register_param(kwargs, 'number_updates', 1))
        self.actor_number_updates = int(
            self._register_param(kwargs, 'actor_number_updates', 1))
        self.critic_number_updates = int(
            self._register_param(kwargs, 'critic_number_updates', 1))

        # Reason sequence initiation.
        hidden_layers = to_numbers_seq(
            self._register_param(kwargs, 'hidden_layers', (128, 128)))
        actor_hidden_layers = to_numbers_seq(
            self._register_param(kwargs, 'actor_hidden_layers', hidden_layers))
        critic_hidden_layers = to_numbers_seq(
            self._register_param(kwargs, 'critic_hidden_layers',
                                 hidden_layers))

        self.simple_policy = bool(
            self._register_param(kwargs, "simple_policy", False))
        if self.simple_policy:
            self.policy = MultivariateGaussianPolicySimple(
                self.action_size, **kwargs)
            self.actor = ActorBody(self.state_size,
                                   self.policy.param_dim * self.action_size,
                                   hidden_layers=actor_hidden_layers,
                                   device=self.device)
        else:
            # The last actor hidden layer is the policy's input width.
            self.policy = GaussianPolicy(actor_hidden_layers[-1],
                                         self.action_size,
                                         out_scale=self.action_scale,
                                         device=self.device)
            self.actor = ActorBody(self.state_size,
                                   actor_hidden_layers[-1],
                                   hidden_layers=actor_hidden_layers[:-1],
                                   device=self.device)

        self.double_critic = DoubleCritic(self.in_features,
                                          self.action_size,
                                          CriticBody,
                                          hidden_layers=critic_hidden_layers,
                                          device=self.device)
        self.target_double_critic = DoubleCritic(
            self.in_features,
            self.action_size,
            CriticBody,
            hidden_layers=critic_hidden_layers,
            device=self.device)

        # Target sequence initiation
        hard_update(self.target_double_critic, self.double_critic)

        # Optimization sequence initiation.
        self.target_entropy = -self.action_size
        alpha_lr = self._register_param(kwargs, "alpha_lr")
        self.alpha_lr = float(alpha_lr) if alpha_lr else None
        alpha_init = float(self._register_param(kwargs, "alpha", 0.2))
        # Alpha is stored in log-space; it is only optimized when alpha_lr is set.
        self.log_alpha = torch.tensor(np.log(alpha_init),
                                      device=self.device,
                                      requires_grad=True)

        actor_lr = float(self._register_param(kwargs, 'actor_lr', 3e-4))
        critic_lr = float(self._register_param(kwargs, 'critic_lr', 3e-4))

        self.actor_params = list(self.actor.parameters()) + list(
            self.policy.parameters())
        self.critic_params = list(self.double_critic.parameters())
        self.actor_optimizer = optim.Adam(self.actor_params, lr=actor_lr)
        self.critic_optimizer = optim.Adam(self.critic_params, lr=critic_lr)
        if self.alpha_lr is not None:
            self.alpha_optimizer = optim.Adam([self.log_alpha],
                                              lr=self.alpha_lr)
        self.max_grad_norm_alpha = float(
            self._register_param(kwargs, "max_grad_norm_alpha", 1.0))
        self.max_grad_norm_actor = float(
            self._register_param(kwargs, "max_grad_norm_actor", 10.0))
        self.max_grad_norm_critic = float(
            self._register_param(kwargs, "max_grad_norm_critic", 10.0))

        # Step counter and trackers for the latest losses/metrics.
        self.iteration = 0

        self._loss_actor = float('inf')
        self._loss_critic = float('inf')
        self._metrics: Dict[str, Union[float, Dict[str, float]]] = {}

    @property
    def alpha(self):
        """Entropy temperature, recovered from its log-space parameter."""
        return torch.exp(self.log_alpha)

    @property
    def loss(self):
        """Latest actor and critic losses under the keys 'actor' and 'critic'."""
        return dict(actor=self._loss_actor, critic=self._loss_critic)

    @loss.setter
    def loss(self, value):
        """Set both losses from a dict with 'actor'/'critic' keys, or one shared value."""
        if not isinstance(value, dict):
            self._loss_actor = self._loss_critic = value
            return
        self._loss_actor = value['actor']
        self._loss_critic = value['critic']

    def reset_agent(self) -> None:
        """Re-initialize actor, policy and critics, then re-sync the target critic."""
        for module in (self.actor, self.policy, self.double_critic):
            module.reset_parameters()
        hard_update(self.target_double_critic, self.double_critic)

    def state_dict(self) -> Dict[str, dict]:
        """Return the weights of every network, keyed by component name.

        Keys: ``actor``, ``policy``, ``double_critic`` and
        ``target_double_critic``. This agent has no target actor.
        """
        return {
            "actor": self.actor.state_dict(),
            "policy": self.policy.state_dict(),
            "double_critic": self.double_critic.state_dict(),
            "target_double_critic": self.target_double_critic.state_dict(),
        }

    @torch.no_grad()
    def act(self,
            state,
            epsilon: float = 0.0,
            deterministic=False) -> List[float]:
        """Select an action for a single state.

        During warm-up, or with probability ``epsilon``, each action component
        is drawn uniformly from ``[action_min, action_max)``. Otherwise the
        actor produces policy parameters and the policy samples the action;
        ``deterministic`` is forwarded to the policy.

        Returns:
            The action as a flat list of floats.
        """
        if self.iteration < self.warm_up or self._rng.random() < epsilon:
            # Scale U[0, 1) by the range width (max - min), then shift by min.
            action_range = self.action_max - self.action_min
            random_action = torch.rand(self.action_size) * action_range + self.action_min
            return random_action.cpu().tolist()

        state = to_tensor(state).view(1,
                                      self.state_size).float().to(self.device)
        proto_action = self.actor(state)
        action = self.policy(proto_action, deterministic)

        return action.flatten().tolist()

    def step(self, state, action, reward, next_state, done):
        """Store one transition and, when due, run ``number_updates`` learning steps.

        Learning is skipped during warm-up and until the memory holds more
        than ``batch_size`` samples; after that it runs every
        ``update_freq`` iterations.
        """
        self.iteration += 1
        transition = dict(state=state, action=action, reward=reward,
                          next_state=next_state, done=done)
        self.memory.add(**transition)

        if self.iteration < self.warm_up:
            return

        should_learn = (len(self.memory) > self.batch_size
                        and self.iteration % self.update_freq == 0)
        if not should_learn:
            return
        for _ in range(self.number_updates):
            self.learn(self.memory.sample())

    def compute_value_loss(self, states, actions, rewards, next_states,
                           dones) -> Tuple[Tensor, Tensor]:
        """Soft Bellman loss for both critics.

        The target uses the next state: a next action is sampled from the
        current policy at ``next_states``, scored by the target critics
        (minimum of the two), and the entropy term ``alpha * log_prob`` is
        subtracted.

        Returns:
            ``(loss, error)``. ``loss`` is the sum of both critics' mean
            squared errors. ``error`` is the element-wise minimum of their
            squared errors.
        """
        Q1_expected, Q2_expected = self.double_critic(states, actions)

        with torch.no_grad():
            # The target is evaluated at s', so a' ~ pi(.|s') must be sampled
            # from `next_states`, not from `states`.
            proto_next_action = self.actor(next_states)
            next_actions = self.policy(proto_next_action)
            log_prob = self.policy.logprob
            assert next_actions.shape == (self.batch_size, self.action_size)
            assert log_prob.shape == (self.batch_size, 1)

            Q1_target_next, Q2_target_next = self.target_double_critic.act(
                next_states, next_actions)
            assert Q1_target_next.shape == Q2_target_next.shape == (
                self.batch_size, 1)

            Q_min = torch.min(Q1_target_next, Q2_target_next)
            QH_target = Q_min - self.alpha * log_prob
            assert QH_target.shape == (self.batch_size, 1)

            # `(1 - dones)` removes the bootstrap term for terminal transitions.
            Q_target = rewards + self.gamma * QH_target * (1 - dones)
            assert Q_target.shape == (self.batch_size, 1)

        Q1_diff = Q1_expected - Q_target
        error_1 = Q1_diff.pow(2)
        mse_loss_1 = error_1.mean()
        self._metrics['value/critic1'] = {
            'mean': float(Q1_expected.mean()),
            'std': float(Q1_expected.std())
        }
        self._metrics['value/critic1_lse'] = float(mse_loss_1.item())

        Q2_diff = Q2_expected - Q_target
        error_2 = Q2_diff.pow(2)
        mse_loss_2 = error_2.mean()
        self._metrics['value/critic2'] = {
            'mean': float(Q2_expected.mean()),
            'std': float(Q2_expected.std())
        }
        self._metrics['value/critic2_lse'] = float(mse_loss_2.item())

        Q_diff = Q1_expected - Q2_expected
        self._metrics['value/Q_diff'] = {
            'mean': float(Q_diff.mean()),
            'std': float(Q_diff.std())
        }

        error = torch.min(error_1, error_2)
        loss = mse_loss_1 + mse_loss_2
        return loss, error

    def compute_policy_loss(self, states):
        """Return the actor loss ``mean(alpha * log pi(a|s) - min(Q1, Q2)(s, a))``.

        Side effects: records ``policy/entropy`` in ``self._metrics``. When
        ``self.alpha_lr`` is not None, also performs one optimizer step on the
        temperature (``self.log_alpha``) before returning.
        """
        sampled_actions = self.policy(self.actor(states))
        log_prob = self.policy.logprob
        assert sampled_actions.shape == (self.batch_size, self.action_size)

        q_first, q_second = self.double_critic(states, sampled_actions)
        q_pessimistic = torch.min(q_first, q_second)
        assert q_pessimistic.shape == (self.batch_size, 1)

        self._metrics['policy/entropy'] = -float(log_prob.detach().mean())
        policy_loss = (self.alpha * log_prob - q_pessimistic).mean()

        if self.alpha_lr is None:
            return policy_loss

        # Temperature update; the entropy gap is detached so only alpha receives gradients here.
        self.alpha_optimizer.zero_grad()
        entropy_gap = (log_prob + self.target_entropy).detach()
        alpha_loss = -(self.alpha * entropy_gap).mean()
        alpha_loss.backward()
        nn.utils.clip_grad_norm_(self.log_alpha, self.max_grad_norm_alpha)
        self.alpha_optimizer.step()

        return policy_loss

    def learn(self, samples):
        """Update the critics and the actor from one sampled batch.

        Parameters:
            samples: Batch returned by ``self.memory.sample()``. Keys ``state``,
                ``action``, ``reward``, ``next_state`` and ``done`` are read;
                ``index`` is also read when the memory has ``priority_update``.

        """
        rewards = to_tensor(samples['reward']).float().to(self.device).view(
            self.batch_size, 1)
        dones = to_tensor(samples['done']).int().to(self.device).view(
            self.batch_size, 1)
        states = to_tensor(samples['state']).float().to(self.device).view(
            self.batch_size, self.state_size)
        next_states = to_tensor(samples['next_state']).float().to(
            self.device).view(self.batch_size, self.state_size)
        # Cast like every other input so the critics always receive one float dtype.
        actions = to_tensor(samples['action']).float().to(self.device).view(
            self.batch_size, self.action_size)

        # Critic (value) update
        error = None  # Stays None if `critic_number_updates` is 0.
        for _ in range(self.critic_number_updates):
            value_loss, error = self.compute_value_loss(
                states, actions, rewards, next_states, dones)
            self.critic_optimizer.zero_grad()
            value_loss.backward()
            nn.utils.clip_grad_norm_(self.critic_params,
                                     self.max_grad_norm_critic)
            self.critic_optimizer.step()
            self._loss_critic = value_loss.item()

        # Actor (policy) update
        for _ in range(self.actor_number_updates):
            policy_loss = self.compute_policy_loss(states)
            self.actor_optimizer.zero_grad()
            policy_loss.backward()
            nn.utils.clip_grad_norm_(self.actor_params,
                                     self.max_grad_norm_actor)
            self.actor_optimizer.step()
            self._loss_actor = policy_loss.item()

        if error is not None and hasattr(self.memory, 'priority_update'):
            # Every priority must be a number; a single NaN would poison the sampling tree.
            assert not torch.isnan(error).any(), "NaN in critic error; cannot update priorities"
            # `error` is still attached to the critic graph; detach it and hand the
            # buffer plain numpy values (the same form RainbowAgent.learn uses).
            priorities = error.detach().abs().cpu().numpy()
            self.memory.priority_update(samples['index'], priorities)

        soft_update(self.target_double_critic, self.double_critic, self.tau)

    def log_metrics(self,
                    data_logger: DataLogger,
                    step: int,
                    full_log: bool = False):
        """Push losses, alpha and the collected ``self._metrics`` to ``data_logger``.

        With ``simple_policy`` set, the flattened policy parameters are logged too.
        With ``full_log`` set, weight/bias histograms of the actor and both critics
        are created as well.
        """
        data_logger.log_value("loss/actor", self._loss_actor, step)
        data_logger.log_value("loss/critic", self._loss_critic, step)
        data_logger.log_value("loss/alpha", self.alpha, step)

        if self.simple_policy:
            flat_params = itertools.chain.from_iterable(self.policy.parameters())
            policy_params = {str(pos): param for pos, param in enumerate(flat_params)}
            data_logger.log_values_dict("policy/param", policy_params, step)

        for metric_name, metric_value in self._metrics.items():
            # Dict-valued metrics (e.g. mean/std pairs) go through the dict logger.
            log_fn = (data_logger.log_values_dict
                      if isinstance(metric_value, dict) else data_logger.log_value)
            log_fn(metric_name, metric_value, step)

        if not full_log:
            return

        def histogram_layers(layers, weight_tag, bias_tag):
            # One histogram per weight and per (non-None) bias, tagged with the layer index.
            for layer_idx, layer in enumerate(layers):
                if hasattr(layer, "weight"):
                    data_logger.create_histogram(weight_tag.format(layer_idx),
                                                 layer.weight, step)
                if hasattr(layer, "bias") and layer.bias is not None:
                    data_logger.create_histogram(bias_tag.format(layer_idx),
                                                 layer.bias, step)

        # TODO: Add Policy layers
        histogram_layers(self.actor.layers,
                         "policy/layer_weights_{}", "policy/layer_bias_{}")
        histogram_layers(self.double_critic.critic_1.layers,
                         "critic_1/layer_{}", "critic_1/layer_bias_{}")
        histogram_layers(self.double_critic.critic_2.layers,
                         "critic_2/layer_{}", "critic_2/layer_bias_{}")

    def get_state(self):
        """Collect network ``state_dict``s and the agent config into one plain dict."""
        snapshot = {
            net_name: getattr(self, net_name).state_dict()
            for net_name in ('actor', 'policy', 'double_critic', 'target_double_critic')
        }
        snapshot['config'] = self._config
        return snapshot

    def save_state(self, path: str):
        """Write the dict produced by ``get_state`` to ``path`` using ``torch.save``."""
        torch.save(self.get_state(), path)

    def load_state(self, path: str):
        """Restore config and network weights from a file written by ``save_state``.

        The stored config is merged into the instance attributes before the
        networks are loaded. ``torch.load`` unpickles data, so only load files
        from a trusted source.
        """
        saved = torch.load(path)
        self._config = saved.get('config', {})
        self.__dict__.update(**self._config)

        for net_name in ('actor', 'policy', 'double_critic', 'target_double_critic'):
            getattr(self, net_name).load_state_dict(saved[net_name])
Beispiel #25
0
    def __init__(self, state_size: int, action_size: int, hidden_layers: Sequence[int]=(128, 128), **kwargs):
        """
        Parameters:
            state_size (int): Number of input dimensions.
            action_size (int): Number of output dimensions.
            hidden_layers (tuple of ints): Tuple defining hidden dimensions in fully connected nets. Default: (128, 128).

        Keyword parameters:
            device: Device on which networks are created. Default: `DEVICE`.
            gamma (float): Discount value. Default: 0.99.
            tau (float): Soft-copy factor. Default: 0.02.
            actor_lr (float): Learning rate for the actor (policy). Default: 0.0003.
            critic_lr (float): Learning rate for the critic (value function). Default: 0.0003.
            actor_hidden_layers (tuple of ints): Shape of network for actor. Default: `hidden_layers`.
            critic_hidden_layers (tuple of ints): Shape of network for critic. Default: `hidden_layers`.
            max_grad_norm_actor (float): Maximum norm value for actor gradient. Default: 100.
            max_grad_norm_critic (float): Maximum norm value for critic gradient. Default: 100.
            num_atoms (int): Number of discrete values for the value distribution. Default: 51.
            v_min (float): Value distribution minimum (left most) value. Default: -10.
            v_max (float): Value distribution maximum (right most) value. Default: 10.
            n_steps (int): Number of steps (N-steps) for the TD. Default: 3.
            batch_size (int): Number of samples used in learning. Default: 64.
            buffer_size (int): Maximum number of samples to store. Default: 1e6.
            warm_up (int): Number of samples to observe before starting any learning step. Default: 0.
            update_freq (int): Number of steps between each learning step. Default: 1.
            action_min (float): Minimum returned action value. Default: -1.
            action_max (float): Maximum returned action value. Default: 1.
            action_scale (float): Multiplier value for action. Default: 1.
            simple_policy (bool): Use `MultivariateGaussianPolicySimple` instead of
                `MultivariateGaussianPolicy`. Default: False.
            std_init, std_min, std_max (float): Forwarded to `MultivariateGaussianPolicySimple`
                when `simple_policy` is set. Defaults: 1.0, 0.25, 1.5.

        """
        super().__init__(**kwargs)
        self.device = self._register_param(kwargs, "device", DEVICE)
        self.state_size = state_size
        self.action_size = action_size

        self.num_atoms = int(self._register_param(kwargs, 'num_atoms', 51))
        v_min = float(self._register_param(kwargs, 'v_min', -10))
        v_max = float(self._register_param(kwargs, 'v_max', 10))

        # Action range and scaling.
        self.action_min = float(self._register_param(kwargs, 'action_min', -1))
        self.action_max = float(self._register_param(kwargs, 'action_max', 1))
        # A multiplier: parse as float, since int() would truncate e.g. 0.5 to 0.
        self.action_scale = float(self._register_param(kwargs, 'action_scale', 1))

        self.gamma = float(self._register_param(kwargs, 'gamma', 0.99))
        self.tau = float(self._register_param(kwargs, 'tau', 0.02))
        self.batch_size: int = int(self._register_param(kwargs, 'batch_size', 64))
        self.buffer_size: int = int(self._register_param(kwargs, 'buffer_size', int(1e6)))
        self.buffer = PERBuffer(self.batch_size, self.buffer_size)

        self.n_steps = int(self._register_param(kwargs, "n_steps", 3))
        self.n_buffer = NStepBuffer(n_steps=self.n_steps, gamma=self.gamma)

        self.warm_up: int = int(self._register_param(kwargs, 'warm_up', 0))
        self.update_freq: int = int(self._register_param(kwargs, 'update_freq', 1))

        if kwargs.get("simple_policy", False):
            std_init = kwargs.get("std_init", 1.0)
            std_max = kwargs.get("std_max", 1.5)
            std_min = kwargs.get("std_min", 0.25)
            self.policy = MultivariateGaussianPolicySimple(self.action_size, std_init=std_init, std_min=std_min, std_max=std_max, device=self.device)
        else:
            self.policy = MultivariateGaussianPolicy(self.action_size, device=self.device)

        self.actor_hidden_layers = to_numbers_seq(self._register_param(kwargs, 'actor_hidden_layers', hidden_layers))
        self.critic_hidden_layers = to_numbers_seq(self._register_param(kwargs, 'critic_hidden_layers', hidden_layers))

        # Online networks: actor, critic_net and Critic(critic_net). The same is then built for `target_`.
        # The actor outputs `param_dim` policy parameters per action dimension.
        self.actor = ActorBody(
            state_size, self.policy.param_dim*action_size, hidden_layers=self.actor_hidden_layers,
            gate_out=torch.tanh, device=self.device
        )
        critic_net = CriticBody(
            state_size, action_size, out_features=self.num_atoms, hidden_layers=self.critic_hidden_layers, device=self.device
        )
        self.critic = CategoricalNet(
            num_atoms=self.num_atoms, v_min=v_min, v_max=v_max, net=critic_net, device=self.device
        )

        self.target_actor = ActorBody(
            state_size, self.policy.param_dim*action_size, hidden_layers=self.actor_hidden_layers,
            gate_out=torch.tanh, device=self.device
        )
        target_critic_net = CriticBody(
            state_size, action_size, out_features=self.num_atoms, hidden_layers=self.critic_hidden_layers, device=self.device
        )
        self.target_critic = CategoricalNet(
            num_atoms=self.num_atoms, v_min=v_min, v_max=v_max, net=target_critic_net, device=self.device
        )

        # Targets start as exact copies of the online networks.
        hard_update(self.target_actor, self.actor)
        hard_update(self.target_critic, self.critic)

        # Optimization setup.
        self.actor_lr = float(self._register_param(kwargs, 'actor_lr', 3e-4))
        self.critic_lr = float(self._register_param(kwargs, 'critic_lr', 3e-4))
        self.value_loss_func = nn.BCELoss(reduction='none')

        # Policy parameters are optimized together with the actor body.
        self.actor_params = list(self.actor.parameters()) + list(self.policy.parameters())
        self.actor_optimizer = Adam(self.actor_params, lr=self.actor_lr)
        self.critic_optimizer = Adam(self.critic.parameters(), lr=self.critic_lr)
        self.max_grad_norm_actor = float(self._register_param(kwargs, "max_grad_norm_actor", 100))
        self.max_grad_norm_critic = float(self._register_param(kwargs, "max_grad_norm_critic", 100))

        # Bookkeeping and metric placeholders.
        self.iteration = 0
        self._loss_actor = float('nan')
        self._loss_critic = float('nan')
        self._display_dist = torch.zeros(self.critic.z_atoms.shape)
        self._metric_batch_error = torch.zeros(self.batch_size)
        self._metric_batch_value_dist = torch.zeros(self.batch_size)
Beispiel #26
0
class RainbowAgent(AgentBase):
    """Rainbow agent as described in [1].

    Rainbow is a DQN agent with several improvements that were suggested before 2017.
    As the authors note, the list isn't exhaustive, but all changes are in relatively
    separate areas, so combining them makes sense. These improvements are:
    * Prioritized Experience Replay
    * Multi-step
    * Double Q net
    * Dueling nets
    * NoisyNet
    * CategoricalNet for Q estimate

    Consider this class as a particular version of the DQN agent.

    [1] "Rainbow: Combining Improvements in Deep Reinforcement Learning" by Hessel et al. (DeepMind team)
    https://arxiv.org/abs/1710.02298
    """

    name = "Rainbow"

    def __init__(self,
                 input_shape: Union[Sequence[int], int],
                 output_shape: Union[Sequence[int], int],
                 state_transform: Optional[Callable] = None,
                 reward_transform: Optional[Callable] = None,
                 **kwargs):
        """
        A wrapper over the DQN, so the majority of the logic is in the DQNAgent.
        Special treatment is required because the Rainbow agent uses categorical nets,
        which operate on probability distributions. Each action is taken as the estimate
        from such distributions.

        Parameters:
            input_shape (tuple of ints): Most likely that's your *state* shape.
            output_shape (tuple of ints): Most likely that's your *action* shape.
            state_transform (callable): Applied to every state in `step` and `act`. Default: identity.
            reward_transform (callable): Applied to every reward in `step`. Default: identity.
            pre_network_fn (function that takes input_shape and returns network):
                Used to preprocess state before it is used in the value- and advantage-function in the dueling nets.
            hidden_layers (tuple of ints): Shape and sizes of fully connected networks used. Default: (100, 100).
            device: Device for networks and tensors. Default: `DEVICE`.
            lr (float): Learning rate value. Default: 3e-4.
            gamma (float): Discount factor. Default: 0.99.
            tau (float): Soft-copy factor. Default: 0.002.
            update_freq (int): Number of steps between each learning step. Default: 1.
            batch_size (int): Number of samples to use at each learning step. Default: 80.
            buffer_size (int): Number of most recent samples to keep in memory for learning. Default: 1e5.
            warm_up (int): Number of samples to observe before starting any learning step. Default: 0.
            number_updates (int): How many times to use learning step in the learning phase. Default: 1.
            max_grad_norm (float): Maximum norm of the gradient used in learning. Default: 10.
            using_double_q (bool): Whether to use Double Q Learning network. Default: True.
            n_steps (int): Number of lookahead steps when estimating reward. See :ref:`NStepBuffer`. Default: 3.
            v_min (float): Lower bound for distributional value V. Default: -10.
            v_max (float): Upper bound for distributional value V. Default: 10.
            num_atoms (int): Number of atoms (discrete states) in the value V distribution. Default: 21.

        """
        super().__init__(**kwargs)
        self.device = self._register_param(kwargs,
                                           "device",
                                           DEVICE,
                                           update=True)
        self.input_shape: Sequence[int] = input_shape if not isinstance(
            input_shape, int) else (input_shape, )

        self.state_size: int = self.input_shape[0]
        self.output_shape: Sequence[int] = output_shape if not isinstance(
            output_shape, int) else (output_shape, )
        self.action_size: int = self.output_shape[0]

        self.lr = float(self._register_param(kwargs, 'lr', 3e-4))
        self.gamma = float(self._register_param(kwargs, 'gamma', 0.99))
        self.tau = float(self._register_param(kwargs, 'tau', 0.002))
        self.update_freq = int(self._register_param(kwargs, 'update_freq', 1))
        self.batch_size = int(
            self._register_param(kwargs, 'batch_size', 80, update=True))
        self.buffer_size = int(
            self._register_param(kwargs, 'buffer_size', int(1e5), update=True))
        self.warm_up = int(self._register_param(kwargs, 'warm_up', 0))
        self.number_updates = int(
            self._register_param(kwargs, 'number_updates', 1))
        self.max_grad_norm = float(
            self._register_param(kwargs, 'max_grad_norm', 10))

        self.iteration: int = 0
        self.using_double_q = bool(
            self._register_param(kwargs, "using_double_q", True))

        self.state_transform = state_transform if state_transform is not None else lambda x: x
        self.reward_transform = reward_transform if reward_transform is not None else lambda x: x

        v_min = float(self._register_param(kwargs, "v_min", -10))
        v_max = float(self._register_param(kwargs, "v_max", 10))
        self.num_atoms = int(
            self._register_param(kwargs, "num_atoms", 21, drop=True))
        # Support of the value distribution: `num_atoms` evenly spaced points in [v_min, v_max].
        self.z_atoms = torch.linspace(v_min,
                                      v_max,
                                      self.num_atoms,
                                      device=self.device)
        self.z_delta = self.z_atoms[1] - self.z_atoms[0]

        self.buffer = PERBuffer(**kwargs)
        # Row indices used to pick one action's distribution per sample in `learn`.
        self.__batch_indices = torch.arange(self.batch_size,
                                            device=self.device)

        self.n_steps = int(self._register_param(kwargs, "n_steps", 3))
        self.n_buffer = NStepBuffer(n_steps=self.n_steps, gamma=self.gamma)

        # Note that in case a pre_network is provided, e.g. a shared net that extracts pixels values,
        # it should be explicitly passed in kwargs
        kwargs["hidden_layers"] = to_numbers_seq(
            self._register_param(kwargs, "hidden_layers", (100, 100)))
        self.net = RainbowNet(self.input_shape,
                              self.output_shape,
                              num_atoms=self.num_atoms,
                              **kwargs)
        self.target_net = RainbowNet(self.input_shape,
                                     self.output_shape,
                                     num_atoms=self.num_atoms,
                                     **kwargs)

        self.optimizer = optim.Adam(self.net.parameters(), lr=self.lr)
        self.dist_probs = None
        self._loss = float('inf')

    @property
    def loss(self):
        return {'loss': self._loss}

    @loss.setter
    def loss(self, value):
        if isinstance(value, dict):
            value = value['loss']
        self._loss = value

    def step(self, state, action, reward, next_state, done) -> None:
        """Lets the agent take a step.

        On some steps the agent will initiate a learning step. This depends on
        the `update_freq` value.

        Parameters:
            state: S(t)
            action: A(t)
            reward: R(t)
            next_state: S(t+1)
            done: (bool) Whether the state is terminal.

        """
        self.iteration += 1
        state = to_tensor(self.state_transform(state)).float().to("cpu")
        next_state = to_tensor(
            self.state_transform(next_state)).float().to("cpu")
        reward = self.reward_transform(reward)

        # Delay adding to buffer to account for n_steps (particularly the reward)
        self.n_buffer.add(state=state.numpy(),
                          action=[int(action)],
                          reward=[reward],
                          done=[done],
                          next_state=next_state.numpy())
        if not self.n_buffer.available:
            return

        self.buffer.add(**self.n_buffer.get().get_dict())

        if self.iteration < self.warm_up:
            return

        if len(self.buffer) >= self.batch_size and (self.iteration %
                                                    self.update_freq) == 0:
            for _ in range(self.number_updates):
                self.learn(self.buffer.sample())

            # Update networks only once - sync local & target
            soft_update(self.target_net, self.net, self.tau)

    def act(self, state, eps: float = 0.) -> int:
        """
        Returns actions for given state as per current policy.

        Parameters:
            state: Current available state from the environment.
            eps: Epsilon value in the epsilon-greedy policy.

        """
        # Epsilon-greedy action selection
        if self._rng.random() < eps:
            return self._rng.randint(0, self.action_size - 1)

        state = to_tensor(self.state_transform(state)).float().unsqueeze(0).to(
            self.device)
        self.dist_probs = self.net.act(state)
        # Expected value of each action's distribution over the atoms.
        q_values = (self.dist_probs * self.z_atoms).sum(-1)
        return int(
            q_values.argmax(-1))  # Action maximizes state-action value Q(s, a)

    def learn(self, experiences: Dict[str, List]) -> None:
        """
        Parameters:
            experiences: Contains all experiences for the agent. Typically sampled from the memory buffer.
                Keys `state`, `action`, `reward`, `next_state` and `done` are required; `index` is
                also read when the buffer supports `priority_update`.
                Each key contains an array and all arrays have to have the same length.

        """
        rewards = to_tensor(experiences['reward']).float().to(self.device)
        dones = to_tensor(experiences['done']).type(torch.int).to(self.device)
        states = to_tensor(experiences['state']).float().to(self.device)
        next_states = to_tensor(experiences['next_state']).float().to(
            self.device)
        actions = to_tensor(experiences['action']).type(torch.long).to(
            self.device)
        assert rewards.shape == dones.shape == (self.batch_size, 1)
        assert states.shape == next_states.shape == (self.batch_size,
                                                     self.state_size)
        assert actions.shape == (self.batch_size, 1)  # Discrete domain

        with torch.no_grad():
            prob_next = self.target_net.act(next_states)
            q_next = (prob_next * self.z_atoms).sum(-1) * self.z_delta
            if self.using_double_q:
                # Double Q: the online net picks the next action, the target net evaluates it.
                duel_prob_next = self.net.act(next_states)
                a_next = torch.argmax((duel_prob_next * self.z_atoms).sum(-1),
                                      dim=-1)
            else:
                a_next = torch.argmax(q_next, dim=-1)

            prob_next = prob_next[self.__batch_indices, a_next, :]

        m = self.net.dist_projection(rewards, 1 - dones,
                                     self.gamma**self.n_steps, prob_next)
        assert m.shape == (self.batch_size, self.num_atoms)

        log_prob = self.net(states, log_prob=True)
        assert log_prob.shape == (self.batch_size, self.action_size,
                                  self.num_atoms)
        log_prob = log_prob[self.__batch_indices, actions.squeeze(), :]
        assert log_prob.shape == m.shape == (self.batch_size, self.num_atoms)

        # Cross-entropy loss error and the loss is batch mean
        error = -torch.sum(m * log_prob, 1)
        assert error.shape == (self.batch_size, )
        loss = error.mean()
        assert loss >= 0

        self.optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.net.parameters(), self.max_grad_norm)
        self.optimizer.step()
        self._loss = float(loss.item())

        if hasattr(self.buffer, 'priority_update'):
            # Every priority must be a number; a single NaN would poison the buffer's sampling.
            assert not torch.isnan(error).any(), "NaN in per-sample error; cannot update priorities"
            self.buffer.priority_update(experiences['index'],
                                        error.detach().cpu().numpy())

        # Update networks - sync local & target
        # NOTE(review): `step` also soft-updates after its learning loop, so the target
        # moves `number_updates + 1` times per learning phase. Confirm this is intended.
        soft_update(self.target_net, self.net, self.tau)

    def state_dict(self) -> Dict[str, dict]:
        """Returns agent's state dictionary.

        Returns:
            State dictionary for internal networks.

        """
        return {
            "net": self.net.state_dict(),
            "target_net": self.target_net.state_dict()
        }

    def log_metrics(self,
                    data_logger: DataLogger,
                    step: int,
                    full_log: bool = False):
        data_logger.log_value("loss/agent", self._loss, step)

        if full_log and self.dist_probs is not None:
            for action_idx in range(self.action_size):
                dist = self.dist_probs[0, action_idx]
                data_logger.log_value(f'dist/expected_{action_idx}',
                                      (dist * self.z_atoms).sum().item(), step)
                data_logger.add_histogram(f'dist/Q_{action_idx}',
                                          min=self.z_atoms[0],
                                          max=self.z_atoms[-1],
                                          num=len(self.z_atoms),
                                          sum=dist.sum(),
                                          sum_squares=dist.pow(2).sum(),
                                          bucket_limits=self.z_atoms +
                                          self.z_delta,
                                          bucket_counts=dist,
                                          global_step=step)

        # This method, `log_metrics`, isn't executed on every iteration, but weights are only plotted
        # with `full_log` just in case: it might be quite costly. Tread wisely.
        if full_log:
            for idx, layer in enumerate(self.net.value_net.layers):
                if hasattr(layer, "weight"):
                    data_logger.create_histogram(
                        f"value_net/layer_weights_{idx}", layer.weight.cpu(),
                        step)
                if hasattr(layer, "bias") and layer.bias is not None:
                    data_logger.create_histogram(f"value_net/layer_bias_{idx}",
                                                 layer.bias.cpu(), step)
            for idx, layer in enumerate(self.net.advantage_net.layers):
                if hasattr(layer, "weight"):
                    data_logger.create_histogram(f"advantage_net/layer_{idx}",
                                                 layer.weight.cpu(), step)
                if hasattr(layer, "bias") and layer.bias is not None:
                    data_logger.create_histogram(
                        f"advantage_net/layer_bias_{idx}", layer.bias.cpu(),
                        step)

    def get_state(self) -> AgentState:
        """Provides agent's internal state."""
        return AgentState(
            model=self.name,
            state_space=self.state_size,
            action_space=self.action_size,
            config=self._config,
            buffer=copy.deepcopy(self.buffer.get_state()),
            network=copy.deepcopy(self.get_network_state()),
        )

    def get_network_state(self) -> NetworkState:
        """Returns `NetworkState` wrapping the online and target net state dicts."""
        return NetworkState(net=dict(net=self.net.state_dict(),
                                     target_net=self.target_net.state_dict()))

    @staticmethod
    def from_state(state: AgentState) -> AgentBase:
        """Builds a new agent from `state`, restoring networks and buffer when present."""
        config = copy.copy(state.config)
        config.update({
            'input_shape': state.state_space,
            'output_shape': state.action_space
        })
        agent = RainbowAgent(**config)
        if state.network is not None:
            agent.set_network(state.network)
        if state.buffer is not None:
            agent.set_buffer(state.buffer)
        return agent

    def set_network(self, network_state: NetworkState) -> None:
        """Loads online and target net weights from `network_state`."""
        self.net.load_state_dict(network_state.net['net'])
        self.target_net.load_state_dict(network_state.net['target_net'])

    def set_buffer(self, buffer_state: BufferState) -> None:
        """Replaces the replay buffer with one rebuilt from `buffer_state`."""
        self.buffer = BufferFactory.from_state(buffer_state)

    def save_state(self, path: str) -> None:
        """Saves agent's state (an `AgentState` from `get_state`) into a file.

        Parameters:
            path: String path where to write the state.

        """
        agent_state = self.get_state()
        torch.save(agent_state, path)

    def load_state(self, path: str) -> None:
        """Loads state from a file under provided path.

        Accepts both the `AgentState` written by `save_state` and the legacy
        plain-dict format with `config`, `net` and `target_net` keys.
        `torch.load` unpickles data, so only load files from a trusted source.

        Parameters:
            path: String path indicating where the state is stored.

        """
        agent_state = torch.load(path)

        if isinstance(agent_state, dict):
            # Legacy format: a plain dict of config and network state dicts.
            self._config = agent_state.get('config', {})
            self.__dict__.update(**self._config)
            self.net.load_state_dict(agent_state['net'])
            self.target_net.load_state_dict(agent_state['target_net'])
            return

        # Current format produced by `save_state` / `get_state`.
        self._config = agent_state.config
        self.__dict__.update(**self._config)
        if agent_state.network is not None:
            self.set_network(agent_state.network)

    def save_buffer(self, path: str) -> None:
        """Saves data from the buffer into a file under provided path.

        Parameters:
            path: String path where to write the buffer.

        """
        import json
        dump = self.buffer.dump_buffer(serialize=True)
        with open(path, 'w') as f:
            json.dump(dump, f)

    def load_buffer(self, path: str) -> None:
        """Loads data into the buffer from provided file path.

        Parameters:
            path: String path indicating where the buffer is stored.

        """
        import json
        with open(path, 'r') as f:
            buffer_dump = json.load(f)
        self.buffer.load_buffer(buffer_dump)

    def __eq__(self, o: object) -> bool:
        return super().__eq__(o) \
               and self._config == o._config \
               and self.buffer == o.buffer \
               and self.get_network_state() == o.get_network_state()
Beispiel #27
0
    def __init__(self,
                 input_shape: Union[Sequence[int], int],
                 output_shape: Union[Sequence[int], int],
                 state_transform: Optional[Callable] = None,
                 reward_transform: Optional[Callable] = None,
                 **kwargs):
        """
        A wrapper over the DQN thus majority of the logic is in the DQNAgent.
        Special treatment is required because the Rainbow agent uses categorical nets
        which operate on probability distributions. Each action is taken as the estimate
        from such distributions.

        Parameters:
            input_shape (tuple of ints or int): Most likely that's your *state* shape.
                An int is wrapped into a 1-tuple.
            output_shape (tuple of ints or int): Most likely that's your *action* shape.
                An int is wrapped into a 1-tuple.
            state_transform (callable, optional): Stored as `self.state_transform`;
                identity function when not provided.
            reward_transform (callable, optional): Stored as `self.reward_transform`;
                identity function when not provided.

        Keyword parameters:
            device: Torch device for tensors and networks. Default: DEVICE.
            pre_network_fn (function that takes input_shape and returns network):
                Used to preprocess state before it is used in the value- and advantage-function in the dueling nets.
                Must be passed explicitly in kwargs; it is forwarded to `RainbowNet`.
            hidden_layers (tuple of ints): Shape and sizes of fully connected networks used. Default: (100, 100).
            lr (float): Learning rate value. Default: 3e-4.
            gamma (float): Discount factor. Default: 0.99.
            tau (float): Soft-copy factor. Default: 0.002.
            update_freq (int): Number of steps between each learning step. Default 1.
            batch_size (int): Number of samples to use at each learning step. Default: 80.
            buffer_size (int): Number of most recent samples to keep in memory for learning. Default: 1e5.
            warm_up (int): Number of samples to observe before starting any learning step. Default: 0.
            number_updates (int): How many times to use learning step in the learning phase. Default: 1.
            max_grad_norm (float): Maximum norm of the gradient used in learning. Default: 10.
            using_double_q (bool): Whether to use Double Q Learning network. Default: True.
            n_steps (int): Number of lookahead steps when estimating reward. See :ref:`NStepBuffer`. Default: 3.
            v_min (float): Lower bound for distributional value V. Default: -10.
            v_max (float): Upper bound for distributional value V. Default: 10.
            num_atoms (int): Number of atoms (discrete states) in the value V distribution. Default: 21.

        """
        super().__init__(**kwargs)
        # NOTE(review): `update=True` presumably writes the resolved value back into
        # kwargs so that PERBuffer(**kwargs) / RainbowNet(**kwargs) below see it —
        # confirm against `_register_param`.
        self.device = self._register_param(kwargs,
                                           "device",
                                           DEVICE,
                                           update=True)
        self.input_shape: Sequence[int] = input_shape if not isinstance(
            input_shape, int) else (input_shape, )

        self.state_size: int = self.input_shape[0]
        self.output_shape: Sequence[int] = output_shape if not isinstance(
            output_shape, int) else (output_shape, )
        self.action_size: int = self.output_shape[0]

        self.lr = float(self._register_param(kwargs, 'lr', 3e-4))
        self.gamma = float(self._register_param(kwargs, 'gamma', 0.99))
        self.tau = float(self._register_param(kwargs, 'tau', 0.002))
        self.update_freq = int(self._register_param(kwargs, 'update_freq', 1))
        self.batch_size = int(
            self._register_param(kwargs, 'batch_size', 80, update=True))
        self.buffer_size = int(
            self._register_param(kwargs, 'buffer_size', int(1e5), update=True))
        self.warm_up = int(self._register_param(kwargs, 'warm_up', 0))
        self.number_updates = int(
            self._register_param(kwargs, 'number_updates', 1))
        self.max_grad_norm = float(
            self._register_param(kwargs, 'max_grad_norm', 10))

        self.iteration: int = 0
        self.using_double_q = bool(
            self._register_param(kwargs, "using_double_q", True))

        # Default transforms are the identity function.
        self.state_transform = state_transform if state_transform is not None else lambda x: x
        self.reward_transform = reward_transform if reward_transform is not None else lambda x: x

        # Support of the categorical value distribution: `num_atoms` evenly spaced
        # points in [v_min, v_max]; `z_delta` is the spacing between adjacent atoms.
        v_min = float(self._register_param(kwargs, "v_min", -10))
        v_max = float(self._register_param(kwargs, "v_max", 10))
        # NOTE(review): `drop=True` presumably removes `num_atoms` from kwargs because it is
        # passed explicitly to RainbowNet below (avoids a duplicate keyword) — confirm.
        self.num_atoms = int(
            self._register_param(kwargs, "num_atoms", 21, drop=True))
        self.z_atoms = torch.linspace(v_min,
                                      v_max,
                                      self.num_atoms,
                                      device=self.device)
        self.z_delta = self.z_atoms[1] - self.z_atoms[0]

        # The prioritized buffer is configured from the (possibly updated) kwargs.
        self.buffer = PERBuffer(**kwargs)
        # Name-mangled (double underscore) index tensor 0..batch_size-1.
        self.__batch_indices = torch.arange(self.batch_size,
                                            device=self.device)

        self.n_steps = int(self._register_param(kwargs, "n_steps", 3))
        self.n_buffer = NStepBuffer(n_steps=self.n_steps, gamma=self.gamma)

        # Note that in case a pre_network is provided, e.g. a shared net that extracts pixels values,
        # it should be explicitly passed in kwargs
        kwargs["hidden_layers"] = to_numbers_seq(
            self._register_param(kwargs, "hidden_layers", (100, 100)))
        # Online and target networks are built with identical arguments.
        self.net = RainbowNet(self.input_shape,
                              self.output_shape,
                              num_atoms=self.num_atoms,
                              **kwargs)
        self.target_net = RainbowNet(self.input_shape,
                                     self.output_shape,
                                     num_atoms=self.num_atoms,
                                     **kwargs)

        # Only the online network is optimized.
        self.optimizer = optim.Adam(self.net.parameters(), lr=self.lr)
        self.dist_probs = None
        self._loss = float('inf')
Beispiel #28
0
class D3PGAgent(AgentBase):
    """Distributional DDPG (D3PG) [1].

    It's closely related to, and sits in-between, D4PG and DDPG. Compared to D4PG it lacks
    the multi actors support. It extends the DDPG agent with:
    1. Distributional critic update.
    2. N-step returns.
    3. Prioritization of the experience replay (PER).

    [1] "Distributed Distributional Deterministic Policy Gradients"
        (2018, ICLR) by G. Barth-Maron & M. Hoffman et al.

    """

    name = "D3PG"

    def __init__(self,
                 state_size: int,
                 action_size: int,
                 hidden_layers: Sequence[int] = (128, 128),
                 **kwargs):
        """
        Parameters:
            state_size: Length of the state vector fed to actor and critic.
            action_size: Length of the continuous action vector.
            hidden_layers: Default hidden layer sizes for both actor and critic;
                overridden by `actor_hidden_layers` / `critic_hidden_layers`.

        Keyword parameters registered via `_register_param` (default in brackets):
            device [DEVICE], num_atoms [51], v_min [-10], v_max [10],
            action_min [-1], action_max [1], action_scale [1],
            gamma [0.99], tau [0.02], batch_size [64], buffer_size [1e6],
            n_steps [3], warm_up [0], update_freq [1],
            actor_hidden_layers / critic_hidden_layers [hidden_layers],
            actor_lr [3e-4], critic_lr [3e-4],
            max_grad_norm_actor [50.0], max_grad_norm_critic [50.0].

        Keyword parameters read directly from kwargs (not registered):
            simple_policy [False]; when true also std_init [1.0],
            std_min [0.25], std_max [1.5].

        """
        super().__init__(**kwargs)
        self.device = self._register_param(kwargs, "device", DEVICE)
        self.state_size = state_size
        self.action_size = action_size

        self.num_atoms = int(self._register_param(kwargs, 'num_atoms', 51))
        v_min = float(self._register_param(kwargs, 'v_min', -10))
        v_max = float(self._register_param(kwargs, 'v_max', 10))

        # Reason sequence initiation.
        self.action_min = float(self._register_param(kwargs, 'action_min', -1))
        self.action_max = float(self._register_param(kwargs, 'action_max', 1))
        # NOTE(review): `int()` truncates fractional scales (e.g. 0.5 -> 0); confirm intended.
        self.action_scale = int(self._register_param(kwargs, 'action_scale',
                                                     1))

        self.gamma = float(self._register_param(kwargs, 'gamma', 0.99))
        self.tau = float(self._register_param(kwargs, 'tau', 0.02))
        self.batch_size: int = int(
            self._register_param(kwargs, 'batch_size', 64))
        self.buffer_size: int = int(
            self._register_param(kwargs, 'buffer_size', int(1e6)))
        self.buffer = PERBuffer(self.batch_size, self.buffer_size)

        # Experiences are delayed by `n_steps` so that n-step returns can be formed.
        self.n_steps = int(self._register_param(kwargs, "n_steps", 3))
        self.n_buffer = NStepBuffer(n_steps=self.n_steps, gamma=self.gamma)

        self.warm_up: int = int(self._register_param(kwargs, 'warm_up', 0))
        self.update_freq: int = int(
            self._register_param(kwargs, 'update_freq', 1))

        if kwargs.get("simple_policy", False):
            std_init = kwargs.get("std_init", 1.0)
            std_max = kwargs.get("std_max", 1.5)
            std_min = kwargs.get("std_min", 0.25)
            self.policy = MultivariateGaussianPolicySimple(self.action_size,
                                                           std_init=std_init,
                                                           std_min=std_min,
                                                           std_max=std_max,
                                                           device=self.device)
        else:
            self.policy = MultivariateGaussianPolicy(self.action_size,
                                                     device=self.device)

        self.actor_hidden_layers = to_numbers_seq(
            self._register_param(kwargs, 'actor_hidden_layers', hidden_layers))
        self.critic_hidden_layers = to_numbers_seq(
            self._register_param(kwargs, 'critic_hidden_layers',
                                 hidden_layers))

        # This looks messy but it's not that bad. Actor, critic_net and Critic(critic_net). Then the same for `target_`.
        # The actor outputs `param_dim` policy parameters per action dimension.
        self.actor = ActorBody(state_size,
                               self.policy.param_dim * action_size,
                               hidden_layers=self.actor_hidden_layers,
                               gate_out=torch.tanh,
                               device=self.device)
        critic_net = CriticBody(state_size,
                                action_size,
                                out_features=self.num_atoms,
                                hidden_layers=self.critic_hidden_layers,
                                device=self.device)
        self.critic = CategoricalNet(num_atoms=self.num_atoms,
                                     v_min=v_min,
                                     v_max=v_max,
                                     net=critic_net,
                                     device=self.device)

        self.target_actor = ActorBody(state_size,
                                      self.policy.param_dim * action_size,
                                      hidden_layers=self.actor_hidden_layers,
                                      gate_out=torch.tanh,
                                      device=self.device)
        target_critic_net = CriticBody(state_size,
                                       action_size,
                                       out_features=self.num_atoms,
                                       hidden_layers=self.critic_hidden_layers,
                                       device=self.device)
        self.target_critic = CategoricalNet(num_atoms=self.num_atoms,
                                            v_min=v_min,
                                            v_max=v_max,
                                            net=target_critic_net,
                                            device=self.device)

        # Target sequence initiation
        hard_update(self.target_actor, self.actor)
        hard_update(self.target_critic, self.critic)

        # Optimization sequence initiation.
        self.actor_lr = float(self._register_param(kwargs, 'actor_lr', 3e-4))
        self.critic_lr = float(self._register_param(kwargs, 'critic_lr', 3e-4))
        # Element-wise BCE; reduced manually in `compute_value_loss`.
        self.value_loss_func = nn.BCELoss(reduction='none')

        # The actor optimizer also updates the policy's own parameters.
        self.actor_params = list(self.actor.parameters()) + list(
            self.policy.parameters())
        self.actor_optimizer = Adam(self.actor_params, lr=self.actor_lr)
        self.critic_optimizer = Adam(self.critic.parameters(),
                                     lr=self.critic_lr)
        self.max_grad_norm_actor = float(
            self._register_param(kwargs, "max_grad_norm_actor", 50.0))
        self.max_grad_norm_critic = float(
            self._register_param(kwargs, "max_grad_norm_critic", 50.0))

        # Breath, my child.
        self.iteration = 0
        self._loss_actor = float('nan')
        self._loss_critic = float('nan')
        self._display_dist = torch.zeros(self.critic.z_atoms.shape)
        self._metric_batch_error = torch.zeros(self.batch_size)
        self._metric_batch_value_dist = torch.zeros(self.batch_size)

    @property
    def loss(self) -> Dict[str, float]:
        """Most recent actor and critic losses."""
        return {'actor': self._loss_actor, 'critic': self._loss_critic}

    @loss.setter
    def loss(self, value):
        # A dict sets each loss separately; any other value sets both.
        if isinstance(value, dict):
            self._loss_actor = value['actor']
            self._loss_critic = value['critic']
        else:
            self._loss_actor = value
            self._loss_critic = value

    @torch.no_grad()
    def act(self, state, epsilon: float = 0.0) -> List[float]:
        """
        Returns actions for given state as per current policy.

        With probability `epsilon` a uniform random action in
        [-action_scale/2, action_scale/2) is used instead of a policy sample.
        The action is clamped to [action_min, action_max].

        Parameters:
            state: Current available state from the environment.
            epsilon: Epsilon value in the epsilon-greedy policy.

        Returns:
            The action converted with `Tensor.tolist()`.

        """
        state = to_tensor(state).float().to(self.device)
        if self._rng.random() < epsilon:
            action = self.action_scale * (torch.rand(self.action_size) - 0.5)

        else:
            action_seed = self.actor.act(state).view(1, -1)
            action_dist = self.policy(action_seed)
            action = action_dist.sample()
            action *= self.action_scale
            action = action.squeeze()

        # Purely for logging
        self._display_dist = self.target_critic.act(
            state, action.to(self.device)).squeeze().cpu()
        self._display_dist = F.softmax(self._display_dist, dim=0)

        return torch.clamp(action, self.action_min,
                           self.action_max).cpu().tolist()

    def step(self, state, action, reward, next_state, done):
        """Records a transition and, when due, performs a learning step.

        Transitions pass through the n-step buffer first. Learning happens
        once `warm_up` iterations have passed, the PER buffer holds more than
        `batch_size` samples, and the iteration is a multiple of `update_freq`.
        """
        self.iteration += 1

        # Delay adding to buffer to account for n_steps (particularly the reward)
        self.n_buffer.add(state=state,
                          action=action,
                          reward=[reward],
                          done=[done],
                          next_state=next_state)
        if not self.n_buffer.available:
            return

        self.buffer.add(**self.n_buffer.get().get_dict())

        if self.iteration < self.warm_up:
            return

        if len(self.buffer) > self.batch_size and (self.iteration %
                                                   self.update_freq) == 0:
            self.learn(self.buffer.sample())

    def compute_value_loss(self,
                           states,
                           actions,
                           next_states,
                           rewards,
                           dones,
                           indices=None):
        """Computes the distributional critic loss and refreshes PER priorities.

        The critic's distribution at (s, a) is compared with the projected
        n-step Bellman target computed by the target networks at (s', a'),
        where a' is sampled from the policy around the target actor's output.

        Parameters:
            states, actions, next_states, rewards, dones: Batch tensors as
                prepared by `learn`.
            indices: Buffer indices of the sampled experiences. When given and
                the buffer has `priority_update`, per-sample errors become the
                new priorities.

        Returns:
            Scalar loss tensor: mean over the batch of the squared, atom-summed
            BCE between predicted and target distributions.

        """
        # Q_w estimate
        value_dist_estimate = self.critic(states, actions)
        assert value_dist_estimate.shape == (self.batch_size, 1,
                                             self.num_atoms)
        # squeeze(1) only drops the singleton dim (safe for batch_size == 1).
        value_dist = F.softmax(value_dist_estimate.squeeze(1), dim=1)
        assert value_dist.shape == (self.batch_size, self.num_atoms)

        # Q_w' estimate via Bellman's dist operator. The target is a constant
        # for the critic update, so no autograd graph is built through the
        # target networks.
        with torch.no_grad():
            next_action_seeds = self.target_actor.act(next_states)
            next_actions = self.policy(next_action_seeds).sample()
            assert next_actions.shape == (self.batch_size, self.action_size)

            # The target critic must be evaluated at the *next* state s'.
            target_value_dist_estimate = self.target_critic.act(
                next_states, next_actions)
            assert target_value_dist_estimate.shape == (self.batch_size, 1,
                                                        self.num_atoms)
            target_value_dist_estimate = target_value_dist_estimate.squeeze(1)
            assert target_value_dist_estimate.shape == (self.batch_size,
                                                        self.num_atoms)

            # n-step returns: discount by gamma^n.
            discount = self.gamma**self.n_steps
            target_value_projected = self.target_critic.dist_projection(
                rewards, 1 - dones, discount, target_value_dist_estimate)
            assert target_value_projected.shape == (self.batch_size,
                                                    self.num_atoms)

        # Comparing Q_w with Q_w'
        loss = self.value_loss_func(value_dist, target_value_projected)
        self._metric_batch_error = loss.detach().sum(dim=-1)
        samples_error = loss.sum(dim=-1).pow(2)
        loss_critic = samples_error.mean()

        if hasattr(self.buffer, 'priority_update') and indices is not None:
            # NaN priorities would corrupt the PER sampling distribution.
            assert not torch.isnan(samples_error).any(), \
                "NaN in critic per-sample errors"
            self.buffer.priority_update(indices,
                                        samples_error.detach().cpu().numpy())

        return loss_critic

    def compute_policy_loss(self, states):
        """Computes the actor loss as the negated mean of critic output weighted by the atoms.

        Actions are drawn with `rsample()` so gradients flow back into the
        actor and the policy parameters.
        """
        # Compute actor loss
        pred_action_seeds = self.actor(states)
        pred_actions = self.policy(pred_action_seeds).rsample()
        # Negative because the optimizer minimizes, but we want to maximize the value
        value_dist = self.critic(states, pred_actions)
        self._metric_batch_value_dist = value_dist.detach()
        # Estimate on Z support
        # NOTE(review): elsewhere the critic output is passed through softmax before being
        # treated as a distribution; here raw outputs are weighted by z_atoms — confirm intended.
        return -torch.mean(value_dist * self.critic.z_atoms)

    def learn(self, experiences):
        """Update critics and actors.

        Parameters:
            experiences: Mapping with 'reward', 'done', 'state', 'action' and
                'next_state' batches, plus 'index' when the buffer supports
                priority updates.

        """
        rewards = to_tensor(experiences['reward']).float().to(self.device)
        dones = to_tensor(experiences['done']).type(torch.int).to(self.device)
        states = to_tensor(experiences['state']).float().to(self.device)
        actions = to_tensor(experiences['action']).to(self.device)
        next_states = to_tensor(experiences['next_state']).float().to(
            self.device)
        assert rewards.shape == dones.shape == (self.batch_size, 1)
        assert states.shape == next_states.shape == (self.batch_size,
                                                     self.state_size)
        assert actions.shape == (self.batch_size, self.action_size)

        indices = None
        if hasattr(self.buffer, 'priority_update'):  # When using PER buffer
            indices = experiences['index']
        loss_critic = self.compute_value_loss(states, actions, next_states,
                                              rewards, dones, indices)

        # Value (critic) optimization
        self.critic_optimizer.zero_grad()
        loss_critic.backward()
        # Clip the gradients of the parameters the critic optimizer updates.
        nn.utils.clip_grad_norm_(self.critic.parameters(),
                                 self.max_grad_norm_critic)
        self.critic_optimizer.step()
        self._loss_critic = float(loss_critic.item())

        # Policy (actor) optimization
        loss_actor = self.compute_policy_loss(states)
        self.actor_optimizer.zero_grad()
        loss_actor.backward()
        nn.utils.clip_grad_norm_(self.actor.parameters(),
                                 self.max_grad_norm_actor)
        self.actor_optimizer.step()
        self._loss_actor = float(loss_actor.item())

        # Networks gradual sync
        soft_update(self.target_actor, self.actor, self.tau)
        soft_update(self.target_critic, self.critic, self.tau)

    def state_dict(self) -> Dict[str, dict]:
        """Describes agent's networks.

        Returns:
            state: (dict) Provides actors and critics states.

        """
        return {
            "actor": self.actor.state_dict(),
            "target_actor": self.target_actor.state_dict(),
            "critic": self.critic.state_dict(),
            "target_critic": self.target_critic.state_dict(),
        }

    def log_metrics(self,
                    data_logger: DataLogger,
                    step: int,
                    full_log: bool = False):
        """Logs losses, policy parameters and batch histograms to `data_logger`.

        When `full_log` is true, also logs the last value distribution
        computed in `act` as a histogram over the critic's atoms.
        """
        data_logger.log_value("loss/actor", self._loss_actor, step)
        data_logger.log_value("loss/critic", self._loss_critic, step)
        policy_params = {
            str(i): v
            for i, v in enumerate(
                itertools.chain.from_iterable(self.policy.parameters()))
        }
        data_logger.log_values_dict("policy/param", policy_params, step)

        data_logger.create_histogram('metric/batch_errors',
                                     self._metric_batch_error, step)
        data_logger.create_histogram('metric/batch_value_dist',
                                     self._metric_batch_value_dist, step)

        if full_log:
            dist = self._display_dist
            z_atoms = self.critic.z_atoms
            z_delta = self.critic.z_delta
            data_logger.add_histogram('dist/dist_value',
                                      min=z_atoms[0],
                                      max=z_atoms[-1],
                                      num=self.num_atoms,
                                      sum=dist.sum(),
                                      sum_squares=dist.pow(2).sum(),
                                      bucket_limits=z_atoms + z_delta,
                                      bucket_counts=dist,
                                      global_step=step)

    def get_state(self):
        """Returns network state dicts together with the agent's config."""
        return dict(
            actor=self.actor.state_dict(),
            target_actor=self.target_actor.state_dict(),
            critic=self.critic.state_dict(),
            target_critic=self.target_critic.state_dict(),
            config=self._config,
        )

    def save_state(self, path: str):
        """Saves `get_state()` to `path` with `torch.save`."""
        agent_state = self.get_state()
        torch.save(agent_state, path)

    def load_state(self, path: str):
        """Restores config and networks from a file written by `save_state`.

        Tensors are mapped onto this agent's device, so a checkpoint saved on
        a GPU can be loaded on a CPU-only machine. `torch.load` unpickles the
        file; only load checkpoints from trusted sources.
        """
        agent_state = torch.load(path, map_location=self.device)
        self._config = agent_state.get('config', {})
        self.__dict__.update(**self._config)

        self.actor.load_state_dict(agent_state['actor'])
        self.critic.load_state_dict(agent_state['critic'])
        self.target_actor.load_state_dict(agent_state['target_actor'])
        self.target_critic.load_state_dict(agent_state['target_critic'])
Beispiel #29
0
    def __init__(self,
                 state_size: int,
                 action_size: int,
                 hidden_layers: Sequence[int] = (128, 128),
                 **kwargs):
        """
        Parameters:
            state_size: Length of the state vector fed to actor and critic.
            action_size: Length of the continuous action vector.
            hidden_layers: Default hidden layer sizes for both actor and critic;
                overridden by `actor_hidden_layers` / `critic_hidden_layers`.

        Keyword parameters registered via `_register_param` (default in brackets):
            device [DEVICE], num_atoms [51], v_min [-10], v_max [10],
            action_min [-1], action_max [1], action_scale [1],
            gamma [0.99], tau [0.02], batch_size [64], buffer_size [1e6],
            n_steps [3], warm_up [0], update_freq [1],
            actor_hidden_layers / critic_hidden_layers [hidden_layers],
            actor_lr [3e-4], critic_lr [3e-4],
            max_grad_norm_actor [50.0], max_grad_norm_critic [50.0].

        Keyword parameters read directly from kwargs (not registered):
            simple_policy [False]; when true also std_init [1.0],
            std_min [0.25], std_max [1.5].

        """
        super().__init__(**kwargs)
        self.device = self._register_param(kwargs, "device", DEVICE)
        self.state_size = state_size
        self.action_size = action_size

        # Categorical value distribution: `num_atoms` atoms over [v_min, v_max].
        self.num_atoms = int(self._register_param(kwargs, 'num_atoms', 51))
        v_min = float(self._register_param(kwargs, 'v_min', -10))
        v_max = float(self._register_param(kwargs, 'v_max', 10))

        # Reason sequence initiation.
        self.action_min = float(self._register_param(kwargs, 'action_min', -1))
        self.action_max = float(self._register_param(kwargs, 'action_max', 1))
        # NOTE(review): `int()` truncates fractional scales (e.g. 0.5 -> 0); confirm intended.
        self.action_scale = int(self._register_param(kwargs, 'action_scale',
                                                     1))

        self.gamma = float(self._register_param(kwargs, 'gamma', 0.99))
        self.tau = float(self._register_param(kwargs, 'tau', 0.02))
        self.batch_size: int = int(
            self._register_param(kwargs, 'batch_size', 64))
        self.buffer_size: int = int(
            self._register_param(kwargs, 'buffer_size', int(1e6)))
        self.buffer = PERBuffer(self.batch_size, self.buffer_size)

        # Transitions are staged here so that n-step returns can be formed.
        self.n_steps = int(self._register_param(kwargs, "n_steps", 3))
        self.n_buffer = NStepBuffer(n_steps=self.n_steps, gamma=self.gamma)

        self.warm_up: int = int(self._register_param(kwargs, 'warm_up', 0))
        self.update_freq: int = int(
            self._register_param(kwargs, 'update_freq', 1))

        # Policy choice; the std_* options are read from kwargs without being registered.
        if kwargs.get("simple_policy", False):
            std_init = kwargs.get("std_init", 1.0)
            std_max = kwargs.get("std_max", 1.5)
            std_min = kwargs.get("std_min", 0.25)
            self.policy = MultivariateGaussianPolicySimple(self.action_size,
                                                           std_init=std_init,
                                                           std_min=std_min,
                                                           std_max=std_max,
                                                           device=self.device)
        else:
            self.policy = MultivariateGaussianPolicy(self.action_size,
                                                     device=self.device)

        self.actor_hidden_layers = to_numbers_seq(
            self._register_param(kwargs, 'actor_hidden_layers', hidden_layers))
        self.critic_hidden_layers = to_numbers_seq(
            self._register_param(kwargs, 'critic_hidden_layers',
                                 hidden_layers))

        # This looks messy but it's not that bad. Actor, critic_net and Critic(critic_net). Then the same for `target_`.
        # The construction order below determines random weight initialisation; keep it stable.
        # The actor outputs `param_dim` policy parameters per action dimension.
        self.actor = ActorBody(state_size,
                               self.policy.param_dim * action_size,
                               hidden_layers=self.actor_hidden_layers,
                               gate_out=torch.tanh,
                               device=self.device)
        critic_net = CriticBody(state_size,
                                action_size,
                                out_features=self.num_atoms,
                                hidden_layers=self.critic_hidden_layers,
                                device=self.device)
        self.critic = CategoricalNet(num_atoms=self.num_atoms,
                                     v_min=v_min,
                                     v_max=v_max,
                                     net=critic_net,
                                     device=self.device)

        self.target_actor = ActorBody(state_size,
                                      self.policy.param_dim * action_size,
                                      hidden_layers=self.actor_hidden_layers,
                                      gate_out=torch.tanh,
                                      device=self.device)
        target_critic_net = CriticBody(state_size,
                                       action_size,
                                       out_features=self.num_atoms,
                                       hidden_layers=self.critic_hidden_layers,
                                       device=self.device)
        self.target_critic = CategoricalNet(num_atoms=self.num_atoms,
                                            v_min=v_min,
                                            v_max=v_max,
                                            net=target_critic_net,
                                            device=self.device)

        # Target sequence initiation
        # Targets start as exact copies of the online networks.
        hard_update(self.target_actor, self.actor)
        hard_update(self.target_critic, self.critic)

        # Optimization sequence initiation.
        self.actor_lr = float(self._register_param(kwargs, 'actor_lr', 3e-4))
        self.critic_lr = float(self._register_param(kwargs, 'critic_lr', 3e-4))
        # Element-wise BCE (reduction='none'); reduction is left to the caller.
        self.value_loss_func = nn.BCELoss(reduction='none')

        # self.actor_params = list(self.actor.parameters()) #+ list(self.policy.parameters())
        # The actor optimizer also updates the policy's own parameters.
        self.actor_params = list(self.actor.parameters()) + list(
            self.policy.parameters())
        self.actor_optimizer = Adam(self.actor_params, lr=self.actor_lr)
        self.critic_optimizer = Adam(self.critic.parameters(),
                                     lr=self.critic_lr)
        self.max_grad_norm_actor = float(
            self._register_param(kwargs, "max_grad_norm_actor", 50.0))
        self.max_grad_norm_critic = float(
            self._register_param(kwargs, "max_grad_norm_critic", 50.0))

        # Breath, my child.
        # Runtime counters, losses and logging buffers.
        self.iteration = 0
        self._loss_actor = float('nan')
        self._loss_critic = float('nan')
        self._display_dist = torch.zeros(self.critic.z_atoms.shape)
        self._metric_batch_error = torch.zeros(self.batch_size)
        self._metric_batch_value_dist = torch.zeros(self.batch_size)
Beispiel #30
0
def test_per_buffer_seed():
    """Buffers sharing a seed sample identically; an unseeded one does not."""
    # Assign: one unseeded buffer followed by two buffers with the same seed.
    batch_size = 4
    unseeded = PERBuffer(batch_size)
    seeded_a, seeded_b = (PERBuffer(batch_size, seed=32167) for _ in range(2))
    buffers = (unseeded, seeded_a, seeded_b)

    # Act: every buffer receives its own deep copy of each sample.
    for sars in generate_sample_SARS(400, dict_type=True):
        for buf in buffers:
            buf.add(**copy.deepcopy(sars))

    # Assert
    for _ in range(10):
        drawn_unseeded, drawn_a, drawn_b = (buf.sample() for buf in buffers)

        assert drawn_unseeded != drawn_a
        assert drawn_unseeded != drawn_b
        assert drawn_a == drawn_b