示例#1
0
def test_02():
    """Sample one rollout from a server fed by a single client and check it.

    Starts the REP worker (``tcp://*:10101``), the PULL worker
    (``tcp://*:10102``) and one ``client_main`` thread. It then blocks on
    ``server.get_data`` for a batch of one and checks the batch layout and
    the values of the first field.

    Cleanup (sleep, ``server.close()``, sleep) runs in a ``finally`` clause.
    A failing assertion, or an exception from ``get_data``, therefore still
    shuts the server down instead of leaving it (and presumably its bound
    ports) alive for later tests.
    """
    server = ReplayMemoryServer(entry, 4, 64)
    try:
        server.rem.rollout_len = 4
        server.rem.max_episode = 0
        batch_size = 1

        threads = [
            Thread(target=server.rep_worker_main,
                   args=("tcp://*:10101", Bind)),
            Thread(target=server.pull_worker_main,
                   args=("tcp://*:10102", Bind)),
            Thread(target=client_main, args=(1, 6, 0)),
        ]
        for th in threads:
            th.start()

        data, weight = server.get_data(batch_size)
        # The returned batch carries one more field than `entry` declares.
        assert len(data) == len(entry) + 1
        assert len(weight) == batch_size
        s = data[0]
        assert len(s) == batch_size
        # Each sampled trajectory holds exactly `rollout_len` steps.
        assert len(s[0]) == server.rem.rollout_len
        traj = s[0]
        # NOTE(review): 2..5 are presumably the values client_main pushed
        # for these steps -- confirm against client_main.
        assert traj[0] == 2
        assert traj[1] == 3
        assert traj[2] == 4
        assert traj[3] == 5
    finally:
        time.sleep(1)
        server.close()  # Prevent auto-deletion
        time.sleep(1)
示例#2
0
def test_04():
    """Exercise sampling with padding disabled (the ``not_first`` case).

    ``rollout_len`` is 4 and ``do_padding`` is False. The client runs as
    ``client_main(1, 3, 0)`` (presumably a 3-step episode -- confirm).

    A watchdog thread closes the server after 3 seconds. Reaching the final
    ``assert False`` means ``get_data`` returned a batch.
    """
    batch_size = 1
    server = ReplayMemoryServer(entry, 4, 64)
    server.rem.rollout_len = 4
    server.rem.max_episode = 0
    server.rem.do_padding = False

    # (worker entry point, endpoint) pairs; both workers bind their endpoint.
    bindings = (
        (server.rep_worker_main, "tcp://*:10101"),
        (server.pull_worker_main, "tcp://*:10102"),
    )
    workers = [Thread(target=fn, args=(addr, Bind)) for fn, addr in bindings]
    workers.append(Thread(target=client_main, args=(1, 3, 0)))
    for worker in workers:
        worker.start()

    def close_after_delay():
        # Shut the server down after 3 s, presumably so that a blocking
        # get_data() below cannot hang the test forever.
        time.sleep(3)
        server.close()

    Thread(target=close_after_delay).start()

    _data, _weight = server.get_data(batch_size)
    # NOTE(review): the test fails whether get_data raises or returns. If the
    # intent is "get_data must raise once the server is closed", an expected-
    # exception check would state that directly -- confirm the intent.
    assert False
示例#3
0
#!/usr/bin/env python

from memoire import ReplayMemoryServer, ReplayMemoryClient, Bind, Conn
import numpy as np
import time, os
from threading import Thread

# Template arrays for the per-step fields. Their contents are uninitialized
# (np.empty), so presumably only their shape and dtype are meaningful to
# ReplayMemoryServer -- confirm.
s = np.empty((2, 2), dtype=np.float32)
r, p, v = (np.empty((), dtype=np.float32) for _ in range(3))

entry = (s, r, p, v)

# Build the replay-memory server for the field layout in `entry`.
# NOTE(review): the meaning of the positional arguments (4, 64) is not visible
# here -- confirm against ReplayMemoryServer's signature.
server = ReplayMemoryServer(entry, 4, 64)
# Presumably the number of steps in each sampled trajectory -- confirm.
server.rem.rollout_len = 4
# NOTE(review): 0 presumably means "no episode limit" -- confirm.
server.rem.max_episode = 0
# Presumably prints the server configuration.
server.print_info()
# Send server logging to test.log; "w" presumably truncates like an open()
# mode -- confirm.
server.set_logfile("test.log", "w")

try:
    batch_size = 1
    #
    threads = []
    threads.append(
        Thread(target=server.rep_worker_main, args=("tcp://*:10101", Bind)))
    threads.append(
        Thread(target=server.pull_worker_main, args=("tcp://*:10102", Bind)))
    for th in threads:
        th.start()
    while True:
示例#4
0
文件: pub.py 项目: whao1160/memoire
#!/usr/bin/env python

from memoire import ReplayMemoryServer, ReplayMemoryClient, Bind, Conn
import numpy as np
import time, os
from threading import Thread

# Three scalar (0-d) float32 template arrays. Their contents are uninitialized;
# presumably only shape/dtype matter to ReplayMemoryServer -- confirm.
r, p, v = (np.empty((), dtype=np.float32) for _ in range(3))

entry = (r, p, v)

server = ReplayMemoryServer(entry, 0, 0)
server.pub_endpoint = "tcp://*:10100"

# Publish "message:<n>" on topic "topic" once per second until Ctrl+C.
try:
    index = 0
    while True:
        payload = f"message:{index}"
        server.pub_bytes("topic", payload)
        index += 1
        time.sleep(1)
except KeyboardInterrupt:
    pass

# Terminate this process with signal 9 (SIGKILL on POSIX) rather than a
# normal interpreter shutdown.
# NOTE(review): presumably done to avoid hanging on resources held by the
# server at exit -- confirm.
pid = os.getpid()
os.kill(pid, 9)
示例#5
0
#!/usr/bin/env python

from memoire import ReplayMemoryServer, ReplayMemoryClient, Bind, Conn
import numpy as np
import time, os
from threading import Thread

# Template arrays for the per-step fields: two length-1 vectors followed by
# three 0-d scalars, all float32. Their contents are uninitialized; presumably
# only shape/dtype matter to ReplayMemoryServer -- confirm.
s, a = (np.empty((1,), dtype=np.float32) for _ in range(2))
r, p, v = (np.empty((), dtype=np.float32) for _ in range(3))

entry = (s, a, r, p, v)

# Build the replay-memory server for the field layout in `entry`.
# NOTE(review): the meaning of the positional arguments (4, 64) is not visible
# here -- confirm against ReplayMemoryServer's signature.
server = ReplayMemoryServer(entry, 4, 64)
# Presumably the number of steps in each sampled trajectory -- confirm.
server.rem.rollout_len = 4
# NOTE(review): 0 presumably means "no episode limit" -- confirm.
server.rem.max_episode = 0
# Presumably prints the server configuration.
server.print_info()

try:
    batch_size = 2
    #
    threads = []
    threads.append(
        Thread(target=server.rep_worker_main, args=("tcp://*:10101", Bind)))
    threads.append(
        Thread(target=server.pull_worker_main, args=("tcp://*:10102", Bind)))
    #threads.append(Thread(target=server.rep_worker_main,  args=("ipc:///tmp/memoire_reqrep_test", Bind)))
    #threads.append(Thread(target=server.pull_worker_main, args=("ipc:///tmp/memoire_pushpull_test", Bind)))
    for th in threads: