Beispiel #1
0
def test_get_latest_from_buffer_not_full():
    """Check get_latest() on a buffer that has not yet reached capacity.

    1300 transitions are pushed into a buffer of capacity 2150; asking for
    the latest 1300 must hand all of them back in insertion order.
    """
    count = 1300
    buffer = ReplayMemory(2150, "cpu")

    # Build the reference transitions up front, then push them one by one.
    expected = [
        Transition(step, 'New', None, None, None) for step in range(count)
    ]
    for item in expected:
        buffer.add(*item)

    fetched = buffer.get_latest(count)
    assert len(fetched) == count

    # Position-by-position comparison against what was pushed.
    for idx, want in enumerate(expected):
        got = fetched[idx]
        assert want.state == got.state
        assert want.action == got.action
    def __init__(self, state_dim, action_dim, gamma, tau, buffer_size,
                 is_mem_cuda, out_act):
        """Create the networks, optimizers, replay buffer and noise source.

        Args:
            state_dim: Forwarded to the Actor and Critic constructors.
            action_dim: Forwarded to the Actor, Critic and OUNoise
                constructors.
            gamma: Stored unchanged as ``self.gamma``.
            tau: Stored unchanged as ``self.tau``.
            buffer_size: First argument of the ReplayMemory constructor.
            is_mem_cuda: Second argument of the ReplayMemory constructor.
            out_act: Forwarded to both Actor instances as ``out_act``.
        """
        # The online and the target actor are built from identical arguments.
        actor_kwargs = {"is_evo": False, "out_act": out_act}
        self.actor = Actor(state_dim, action_dim, **actor_kwargs)
        self.actor_target = Actor(state_dim, action_dim, **actor_kwargs)
        self.actor_optim = Adam(self.actor.parameters(), lr=1e-4)

        # Same layout for the critic; only the online network is optimized.
        self.critic = Critic(state_dim, action_dim)
        self.critic_target = Critic(state_dim, action_dim)
        self.critic_optim = Adam(self.critic.parameters(), lr=1e-3)

        self.gamma = gamma
        self.tau = tau
        self.loss = nn.MSELoss()
        self.replay_buffer = ReplayMemory(buffer_size, is_mem_cuda)
        self.exploration_noise = OUNoise(action_dim)

        # Make sure each target network starts with the same weights as its
        # online counterpart (actor first, then critic).
        network_pairs = (
            (self.actor_target, self.actor),
            (self.critic_target, self.critic),
        )
        for target_net, online_net in network_pairs:
            hard_update(target_net, online_net)
Beispiel #3
0
def test_add_content_of():
    """Check that add_content_of() copies another buffer's entries in order.

    The source buffer (capacity 1560) receives 2000 transitions, so the test
    expects only the newest 1560 to arrive in the target buffer, oldest
    first.
    """
    source_capacity = 1560
    target = ReplayMemory(2150, "cpu")
    source = ReplayMemory(source_capacity, "cpu")

    pushed = []
    for step in range(2000):
        item = Transition(step, 'New', None, None, None)
        pushed.append(item)
        source.add(*item)

    target.add_content_of(source)

    # Compare the tail of what was pushed with the head of the target storage.
    survivors = pushed[-source_capacity:]
    for idx, want in enumerate(survivors):
        assert want.state == target.memory[idx].state
Beispiel #4
0
def test_add_latest_from():
    """Check that add_latest_from() copies only the newest n transitions.

    1300 transitions go into the source buffer; the destination then pulls
    the latest 1000 and must hold exactly those, in insertion order.
    """
    wanted = 1000
    source = ReplayMemory(2150, "cpu")
    destination = ReplayMemory(3000, "cpu")

    pushed = []
    for step in range(1300):
        item = Transition(step, 'New', None, None, None)
        pushed.append(item)
        source.add(*item)

    destination.add_latest_from(source, wanted)
    assert len(destination) == wanted

    # The destination storage must mirror the tail of the pushed sequence.
    tail = pushed[-wanted:]
    for idx, want in enumerate(tail):
        stored = destination.memory[idx]
        assert want.state == stored.state
        assert want.action == stored.action
Beispiel #5
0
def test_get_latest():
    """Check get_latest() on a full buffer, with and without wrap-around.

    The buffer is first overfilled with 'Old' transitions. Two batches of
    2000 newer transitions are then pushed; after each batch the latest 2000
    entries must be exactly that batch, in insertion order.
    """
    capacity = 4250
    batch = 2000
    buffer = ReplayMemory(capacity, "cpu")

    def push(count, action):
        """Add `count` transitions tagged `action`; return them in order."""
        pushed = []
        for step in range(count):
            item = Transition(step, action, None, None, None)
            pushed.append(item)
            buffer.add(*item)
        return pushed

    def check_latest(expected):
        """Assert that the latest `batch` entries equal `expected`."""
        fetched = buffer.get_latest(batch)
        assert len(fetched) == batch
        for idx, want in enumerate(expected):
            got = fetched[idx]
            assert want.state == got.state
            assert want.action == got.action

    # Overfill so that every later push overwrites existing entries.
    push(5000, 'Old')
    assert len(buffer) == capacity

    # Case 1: the write position is at least as large as the requested count.
    newest = push(batch, 'New')
    assert len(buffer) == capacity
    check_latest(newest)

    # Case 2: the write position is smaller than the requested count while
    # the buffer is full.
    newest = push(batch, 'New2')
    assert len(buffer) == capacity
    assert buffer.position < batch
    check_latest(newest)
Beispiel #6
0
def test_get_latest_from_small_capacity():
    """Check get_latest() when more entries are requested than fit.

    The buffer holds 1150 entries; after pushing 2000 new transitions a
    request for the latest 2000 must return only the newest 1150, in
    insertion order.
    """
    capacity = 1150
    requested = 2000
    buffer = ReplayMemory(capacity, "cpu")

    # Overfill with 'Old' entries so the buffer is full before the real run.
    for step in range(3000):
        buffer.add(*Transition(step, 'Old', None, None, None))

    assert len(buffer) == capacity

    pushed = []
    for step in range(requested):
        item = Transition(step, 'New', None, None, None)
        pushed.append(item)
        buffer.add(*item)

    fetched = buffer.get_latest(requested)
    assert len(fetched) == capacity

    # Only the last `capacity` pushed transitions are expected back.
    tail = pushed[-capacity:]
    for idx, want in enumerate(tail):
        got = fetched[idx]
        assert want.state == got.state
        assert want.action == got.action
Beispiel #7
0
def test_shuffle():
    """Check that shuffle() reorders the stored transitions without loss.

    After shuffling, the sequence of stored states must differ from the
    insertion order, yet sorting it must give back 0..1299 exactly.
    """
    count = 1300
    buffer = ReplayMemory(2150, "cpu")
    for step in range(count):
        buffer.add(*Transition(step, 'New', None, None, None))

    buffer.shuffle()
    fetched = buffer.get_latest(count)

    # Each stored state is read through [0][0] to get its scalar entry.
    shuffled_states = [entry.state[0][0] for entry in fetched]
    ordered_states = np.sort(shuffled_states)

    insertion_order = list(np.arange(count))

    # The shuffled sequence must not match the insertion order ...
    assert insertion_order != shuffled_states

    # ... but it must contain exactly the same values once sorted.
    assert insertion_order == list(ordered_states)