def test_proxyed_send_repceive():
    """Push N pickled integers through a router-dealer proxy and check
    that they arrive intact and in order."""
    host = '127.0.0.1'
    port_frontend = 7000
    port_backend = 7001
    N = 6
    msg_send = list(range(N))
    msg_receive = [None] * len(msg_send)

    def send(count):
        # Sender connects to the proxy frontend.
        sender = ZmqSender(host=host, port=port_frontend, serializer='pickle')
        for msg in msg_send[:count]:
            sender.send(msg)

    def receive(count):
        # Receiver connects (bind=False) to the proxy backend.
        receiver = ZmqReceiver(host=host,
                               port=port_backend,
                               bind=False,
                               deserializer='pickle')
        for slot in range(count):
            msg_receive[slot] = receiver.recv()

    sender_thread = Thread(target=send, args=[N])
    receiver_thread = Thread(target=receive, args=[N])

    frontend_add = 'tcp://{}:{}'.format(host, port_frontend)
    backend_add = 'tcp://{}:{}'.format(host, port_backend)
    proxy_thread = ZmqProxyThread(frontend_add, backend_add,
                                  pattern='router-dealer')

    # Same start order as the sibling test: receiver, sender, then proxy.
    receiver_thread.start()
    sender_thread.start()
    proxy_thread.start()

    receiver_thread.join()
    sender_thread.join()
    assert msg_send == msg_receive
def test_proxyed_pull_push():
    """Push N pickled integers through a pull-push proxy and check
    that they arrive intact and in order."""
    host = '127.0.0.1'
    port_frontend = 7002
    port_backend = 7003
    N = 6
    msg_push = list(range(N))
    msg_pull = [None] * len(msg_push)

    def push(count):
        # Pusher connects to the proxy frontend.
        pusher = ZmqPusher(host=host, port=port_frontend, serializer='pickle')
        for msg in msg_push[:count]:
            pusher.push(msg)

    def pull(count):
        # Puller connects (bind=False) to the proxy backend.
        puller = ZmqPuller(host=host,
                           port=port_backend,
                           bind=False,
                           deserializer='pickle')
        for slot in range(count):
            msg_pull[slot] = puller.pull()

    pusher_thread = Thread(target=push, args=[N])
    puller_thread = Thread(target=pull, args=[N])

    frontend_add = 'tcp://{}:{}'.format(host, port_frontend)
    backend_add = 'tcp://{}:{}'.format(host, port_backend)
    proxy_thread = ZmqProxyThread(frontend_add, backend_add,
                                  pattern='pull-push')

    # Same start order as the sibling test: puller, pusher, then proxy.
    puller_thread.start()
    pusher_thread.start()
    proxy_thread.start()

    puller_thread.join()
    pusher_thread.join()
    assert msg_push == msg_pull
class ReplayLoadBalancer(object):
    """Load-balances replay collector and sampler traffic.

    Runs two router-dealer proxy threads: one forwarding experience
    collection requests, one forwarding sampling requests.  All four
    ports are read from SYMPH_* environment variables (set by the
    orchestrator); a missing variable raises KeyError at construction.
    """

    def __init__(self):
        self.sampler_proxy = None
        self.collector_proxy = None
        self.collector_frontend_port = os.environ['SYMPH_COLLECTOR_FRONTEND_PORT']
        self.collector_backend_port = os.environ['SYMPH_COLLECTOR_BACKEND_PORT']
        self.sampler_frontend_port = os.environ['SYMPH_SAMPLER_FRONTEND_PORT']
        self.sampler_backend_port = os.environ['SYMPH_SAMPLER_BACKEND_PORT']
        # Bind on all interfaces ("*") for both sides of both proxies.
        self.collector_frontend_add = "tcp://*:{}".format(
            self.collector_frontend_port)
        self.collector_backend_add = "tcp://*:{}".format(
            self.collector_backend_port)
        self.sampler_frontend_add = "tcp://*:{}".format(self.sampler_frontend_port)
        self.sampler_backend_add = "tcp://*:{}".format(self.sampler_backend_port)

    def launch(self):
        """Create and start both proxy threads.

        The threads are marked non-daemon so they keep the process alive
        until join() returns.
        """
        self.collector_proxy = ZmqProxyThread(in_add=self.collector_frontend_add,
                                              out_add=self.collector_backend_add,
                                              pattern='router-dealer')
        self.sampler_proxy = ZmqProxyThread(in_add=self.sampler_frontend_add,
                                            out_add=self.sampler_backend_add,
                                            pattern='router-dealer')
        # Fix: Thread.setDaemon() is deprecated (since Python 3.10);
        # assign the `daemon` attribute directly instead.
        self.collector_proxy.daemon = False
        self.collector_proxy.start()
        self.sampler_proxy.daemon = False
        self.sampler_proxy.start()

    def join(self):
        """Block until both proxy threads exit."""
        self.collector_proxy.join()
        self.sampler_proxy.join()
class ShardedParameterServer(object):
    """
    Runs multiple parameter servers in parallel processes behind a
    router-dealer load-balancing proxy.
    """

    def __init__(self, shards, supress_output=False):
        # Number of ParameterServer worker processes to spawn.
        self.shards = shards

        # Serving parameters to agents: the proxy binds both sides on
        # all interfaces; ports come from the orchestrator environment.
        self.frontend_port = os.environ['SYMPH_PS_FRONTEND_PORT']
        self.backend_port = os.environ['SYMPH_PS_BACKEND_PORT']
        self.serving_frontend_add = "tcp://*:{}".format(self.frontend_port)
        self.serving_backend_add = "tcp://*:{}".format(self.backend_port)

        # Where the learner publishes fresh parameters.
        self.publisher_host = os.environ['SYMPH_PARAMETER_PUBLISH_HOST']
        self.publisher_port = os.environ['SYMPH_PARAMETER_PUBLISH_PORT']

        self._supress_output = supress_output
        self.proxy = None
        self.workers = []

    def launch(self):
        """
        Runs load balancing proxy thread and self.shards ParameterServer processes
        Returns after all threads and processes are running
        """
        self.proxy = ZmqProxyThread(in_add=self.serving_frontend_add,
                                    out_add=self.serving_backend_add,
                                    pattern='router-dealer')
        self.proxy.start()

        self.workers = []
        for _ in range(self.shards):
            shard_worker = ParameterServer(publisher_host=self.publisher_host,
                                           publisher_port=self.publisher_port,
                                           serving_host='localhost',
                                           serving_port=self.backend_port,
                                           load_balanced=True,
                                           supress_output=self._supress_output)
            shard_worker.start()
            self.workers.append(shard_worker)

    def join(self):
        """
        Wait for all parameter server workers to exit
        (Currently this means they crashed)
        NOTE(review): docstring claims the proxy is a daemon thread and
        doesn't need waiting -- not visible here; confirm in ZmqProxyThread.
        """
        for index, shard_worker in enumerate(self.workers):
            shard_worker.join()
            U.report_exitcode(shard_worker.exitcode, 'ps-{}'.format(index))

    def quit(self):
        """Terminate every worker process."""
        for shard_worker in self.workers:
            shard_worker.terminate()
def main(_):
    """Smoke-test the parameter-server path through a router-dealer proxy.

    Starts the proxy and a backend ZmqServer loop, then issues ten
    parameter-fetch requests from a client.
    """
    proxy = ZmqProxyThread("tcp://*:%d" % FRONTEND_PORT,
                           "tcp://*:%d" % BACKEND_PORT)
    proxy.start()

    # Backend server connects (bind=False) behind the proxy.
    server = ZmqServer(host='localhost',
                       port=BACKEND_PORT,
                       serializer=U.serialize,
                       deserializer=U.deserialize,
                       bind=False)
    server_thread = server.start_loop(handler=server_f, blocking=False)

    client = get_ps_client()
    for _ in range(10):
        client.fetch_parameter_with_info(['x'])
    print('Done!')
class ShardedReplay(object):
    """
    Launches several replay shard processes plus two router-dealer
    proxies that load-balance experience collection and sampling
    across the shards.
    """

    def __init__(self, replay_class, learner_config, env_config,
                 session_config,):
        """
        Args:
            *_config: passed on to replay
        """
        self.sampler_proxy = None
        self.collector_proxy = None
        self.processes = []
        self.learner_config = learner_config
        self.env_config = env_config
        self.session_config = session_config
        self.replay_class = replay_class
        self.shards = self.learner_config.replay.replay_shards

        # Orchestrator-assigned ports for the two proxies.
        self.collector_frontend_port = os.environ['SYMPH_COLLECTOR_FRONTEND_PORT']
        self.collector_backend_port = os.environ['SYMPH_COLLECTOR_BACKEND_PORT']
        self.sampler_frontend_port = os.environ['SYMPH_SAMPLER_FRONTEND_PORT']
        self.sampler_backend_port = os.environ['SYMPH_SAMPLER_BACKEND_PORT']

        # Bind addresses on all interfaces for both sides of both proxies.
        address = "tcp://*:{}".format
        self.collector_frontend_add = address(self.collector_frontend_port)
        self.collector_backend_add = address(self.collector_backend_port)
        self.sampler_frontend_add = address(self.sampler_frontend_port)
        self.sampler_backend_add = address(self.sampler_backend_port)

    def launch(self):
        """Fork the replay shard processes, then start both proxies."""
        self.processes = []
        print('Starting {} replay shards'.format(self.shards))
        for shard_id in range(self.shards):
            print('Replay {} starting'.format(shard_id))
            proc = Process(target=self.start_replay, args=[shard_id])
            proc.start()
            self.processes.append(proc)

        self.collector_proxy = ZmqProxyThread(
            in_add=self.collector_frontend_add,
            out_add=self.collector_backend_add,
            pattern='router-dealer')
        self.sampler_proxy = ZmqProxyThread(
            in_add=self.sampler_frontend_add,
            out_add=self.sampler_backend_add,
            pattern='router-dealer')
        self.collector_proxy.start()
        self.sampler_proxy.start()

    def start_replay(self, index):
        """Entry point of one shard process: build the replay and block."""
        replay = self.replay_class(self.learner_config,
                                   self.env_config,
                                   self.session_config,
                                   index=index)
        replay.start_threads()
        replay.join()

    def join(self):
        """Wait for every shard process, then for both proxy threads."""
        for shard_id, proc in enumerate(self.processes):
            proc.join()
            U.report_exitcode(proc.exitcode, 'replay-{}'.format(shard_id))
        self.collector_proxy.join()
        self.sampler_proxy.join()