예제 #1
0
def test_proxyed_send_repceive():
    """Round-trip N pickled messages through a 'router-dealer' ZmqProxyThread.

    A ZmqSender targets the proxy's frontend port and a ZmqReceiver
    (``bind=False``) targets its backend port; the test asserts that every
    message arrives, in order. The misspelled name ("repceive") is kept so
    existing test selections keep working.
    """
    host = '127.0.0.1'
    port_frontend = 7000
    port_backend = 7001
    N = 6
    # Seconds to wait for each worker thread before declaring the test hung.
    timeout = 10
    msg_send = list(range(N))
    msg_receive = [None] * len(msg_send)

    def send(N):
        sender = ZmqSender(host=host, port=port_frontend, serializer='pickle')
        for i in range(N):
            sender.send(msg_send[i])

    def receive(N):
        receiver = ZmqReceiver(host=host,
                               port=port_backend,
                               bind=False,
                               deserializer='pickle')
        for i in range(N):
            msg_receive[i] = receiver.recv()

    # Daemon threads + bounded joins: a lost message now fails the test
    # instead of blocking the interpreter (and the whole suite) forever.
    server_thread = Thread(target=send, args=[N], daemon=True)
    client_thread = Thread(target=receive, args=[N], daemon=True)
    in_add = 'tcp://{}:{}'.format(host, port_frontend)
    out_add = 'tcp://{}:{}'.format(host, port_backend)
    proxy_thread = ZmqProxyThread(in_add, out_add, pattern='router-dealer')

    client_thread.start()
    server_thread.start()
    proxy_thread.start()
    client_thread.join(timeout)
    server_thread.join(timeout)

    assert not client_thread.is_alive(), 'receiver did not finish in time'
    assert not server_thread.is_alive(), 'sender did not finish in time'
    assert msg_send == msg_receive
예제 #2
0
def test_proxyed_pull_push():
    """Round-trip N pickled messages through a 'pull-push' ZmqProxyThread.

    A ZmqPusher targets the proxy's frontend port and a ZmqPuller
    (``bind=False``) targets its backend port; the test asserts that every
    pushed message is pulled, in order.
    """
    host = '127.0.0.1'
    port_frontend = 7002
    port_backend = 7003
    N = 6
    # Seconds to wait for each worker thread before declaring the test hung.
    timeout = 10
    msg_push = list(range(N))
    msg_pull = [None] * len(msg_push)

    def push(N):
        pusher = ZmqPusher(host=host, port=port_frontend, serializer='pickle')
        for i in range(N):
            pusher.push(msg_push[i])

    def pull(N):
        puller = ZmqPuller(host=host,
                           port=port_backend,
                           bind=False,
                           deserializer='pickle')
        for i in range(N):
            msg_pull[i] = puller.pull()

    # Daemon threads + bounded joins: a lost message now fails the test
    # instead of blocking the interpreter (and the whole suite) forever.
    server_thread = Thread(target=push, args=[N], daemon=True)
    client_thread = Thread(target=pull, args=[N], daemon=True)
    in_add = 'tcp://{}:{}'.format(host, port_frontend)
    out_add = 'tcp://{}:{}'.format(host, port_backend)
    proxy_thread = ZmqProxyThread(in_add, out_add, pattern='pull-push')

    client_thread.start()
    server_thread.start()
    proxy_thread.start()
    client_thread.join(timeout)
    server_thread.join(timeout)

    assert not client_thread.is_alive(), 'puller did not finish in time'
    assert not server_thread.is_alive(), 'pusher did not finish in time'
    assert msg_push == msg_pull
예제 #3
0
class ShardedParameterServer(object):
  """Runs multiple parameter servers in parallel processes.

  A ZmqProxyThread using the 'router-dealer' pattern forwards requests that
  arrive on the frontend address to the backend address, where ``shards``
  ParameterServer workers serve with ``load_balanced=True``.
  """

  def __init__(self, shards, supress_output=False):
    """Read the proxy and publisher endpoints from the environment.

    Args:
      shards: number of ParameterServer workers started by ``launch``.
      supress_output: forwarded unchanged to every ParameterServer (the
        misspelling is part of the public keyword and is kept).

    Raises:
      KeyError: if SYMPH_PS_FRONTEND_PORT, SYMPH_PS_BACKEND_PORT,
        SYMPH_PARAMETER_PUBLISH_HOST or SYMPH_PARAMETER_PUBLISH_PORT is
        not set.
    """
    self.shards = shards

    # Serving parameters to agents: agents reach the proxy frontend, the
    # workers sit behind its backend. '*' is ZeroMQ's all-interfaces wildcard.
    self.frontend_port = os.environ['SYMPH_PS_FRONTEND_PORT']
    self.backend_port = os.environ['SYMPH_PS_BACKEND_PORT']
    self.serving_frontend_add = "tcp://*:{}".format(self.frontend_port)
    self.serving_backend_add = "tcp://*:{}".format(self.backend_port)

    # Subscribing to learner published parameters
    self.publisher_host = os.environ['SYMPH_PARAMETER_PUBLISH_HOST']
    self.publisher_port = os.environ['SYMPH_PARAMETER_PUBLISH_PORT']

    self._supress_output = supress_output
    self.proxy = None
    self.workers = []

  def launch(self):
    """Start the load-balancing proxy thread and ``self.shards`` workers.

    Returns after the proxy and every ParameterServer have been started.
    """
    self.proxy = ZmqProxyThread(in_add=self.serving_frontend_add,
                                out_add=self.serving_backend_add,
                                pattern='router-dealer')
    self.proxy.start()

    self.workers = []
    for _ in range(self.shards):
      worker = ParameterServer(publisher_host=self.publisher_host,
                               publisher_port=self.publisher_port,
                               serving_host='localhost',
                               serving_port=self.backend_port,
                               load_balanced=True,
                               supress_output=self._supress_output)
      worker.start()
      self.workers.append(worker)

  def join(self):
    """Wait for all parameter server workers to exit.

    Workers are not expected to exit normally, so returning here currently
    means they crashed. Each exit code is reported as ``ps-<index>``. The
    proxy thread is not joined; per the original design it is treated as a
    daemon thread.
    """
    for i, worker in enumerate(self.workers):
      worker.join()
      U.report_exitcode(worker.exitcode, 'ps-{}'.format(i))

  def quit(self):
    """Terminate every worker, then join them so none is left unreaped.

    All workers are signalled first and only then joined, so shutdown waits
    for the slowest worker rather than for each worker in turn. Joining after
    terminate() is what multiprocessing requires to avoid zombie children
    (assumes workers are multiprocessing.Process-like).
    """
    for worker in self.workers:
      worker.terminate()
    for worker in self.workers:
      worker.join()
예제 #4
0
  def launch(self):
    """Create and start the collector and sampler load-balancing proxies.

    Both proxies use the 'router-dealer' pattern. They are explicitly made
    non-daemon, so (assuming ZmqProxyThread is a threading.Thread) they keep
    the interpreter alive after the main thread finishes.
    """
    self.collector_proxy = ZmqProxyThread(in_add=self.collector_frontend_add,
                                          out_add=self.collector_backend_add,
                                          pattern='router-dealer')
    self.sampler_proxy = ZmqProxyThread(in_add=self.sampler_frontend_add,
                                        out_add=self.sampler_backend_add,
                                        pattern='router-dealer')

    # Thread.setDaemon() is deprecated since Python 3.10; assigning the
    # ``daemon`` attribute is the equivalent supported form.
    self.collector_proxy.daemon = False
    self.collector_proxy.start()
    self.sampler_proxy.daemon = False
    self.sampler_proxy.start()
예제 #5
0
  def launch(self):
    """Bring up the load-balancing proxy, then one ParameterServer per shard.

    The proxy forwards ``serving_frontend_add`` to ``serving_backend_add``
    with the 'router-dealer' pattern. Each worker serves on localhost at
    ``backend_port`` with ``load_balanced=True``. Control returns once the
    proxy and all ``self.shards`` workers have been started.
    """
    proxy = ZmqProxyThread(in_add=self.serving_frontend_add,
                           out_add=self.serving_backend_add,
                           pattern='router-dealer')
    self.proxy = proxy
    proxy.start()

    # Same list object is exposed on self while it is being filled.
    self.workers = workers = []
    for _ in range(self.shards):
      server = ParameterServer(
          publisher_host=self.publisher_host,
          publisher_port=self.publisher_port,
          serving_host='localhost',
          serving_port=self.backend_port,
          load_balanced=True,
          supress_output=self._supress_output,
      )
      server.start()
      workers.append(server)
예제 #6
0
    def launch(self):
        """Spawn one replay process per shard, then start both proxies.

        Each shard runs ``self.start_replay(index)`` in its own ``Process``.
        Afterwards one 'router-dealer' ZmqProxyThread is started for the
        collector addresses and another for the sampler addresses.
        """
        self.processes = []

        print('Starting {} replay shards'.format(self.shards))
        for shard_index in range(self.shards):
            print('Replay {} starting'.format(shard_index))
            replay_process = Process(target=self.start_replay,
                                     args=[shard_index])
            replay_process.start()
            self.processes.append(replay_process)

        self.collector_proxy = ZmqProxyThread(
            in_add=self.collector_frontend_add,
            out_add=self.collector_backend_add,
            pattern='router-dealer')
        self.sampler_proxy = ZmqProxyThread(
            in_add=self.sampler_frontend_add,
            out_add=self.sampler_backend_add,
            pattern='router-dealer')

        # Collector first, then sampler — same start order as before.
        for proxy in (self.collector_proxy, self.sampler_proxy):
            proxy.start()
예제 #7
0
def main(_):
  """Smoke-test a proxied parameter-server round trip.

  Starts a ZmqProxyThread from FRONTEND_PORT to BACKEND_PORT and a ZmqServer
  that connects to the backend (``bind=False``) and handles requests with
  ``server_f``. It then issues ten ``fetch_parameter_with_info(['x'])``
  calls through ``get_ps_client()``.

  Args:
    _: unused positional argument (presumably argv from an app runner —
      TODO confirm).
  """
  proxy = ZmqProxyThread("tcp://*:%d" % FRONTEND_PORT,
                         "tcp://*:%d" % BACKEND_PORT)
  proxy.start()
  server = ZmqServer(host='localhost',
                     port=BACKEND_PORT,
                     serializer=U.serialize,
                     deserializer=U.deserialize,
                     bind=False)
  # blocking=False presumably runs the handler loop in the background; the
  # returned handle is kept referenced for the lifetime of main().
  server_thread = server.start_loop(handler=server_f, blocking=False)

  client = get_ps_client()
  for _ in range(10):
    client.fetch_parameter_with_info(['x'])
  print('Done!')
예제 #8
0
class ShardedReplay(object):
    """Runs ``replay_shards`` replay processes behind two ZMQ proxies.

    One 'router-dealer' proxy serves the collector addresses and another
    serves the sampler addresses; all four ports come from the environment.
    """

    def __init__(self,
                 replay_class,
                 learner_config,
                 env_config,
                 session_config,):
        """
        Args:
            replay_class: callable constructed in each shard process as
                ``replay_class(learner_config, env_config, session_config,
                index=i)``.
            *_config: passed on to replay. ``learner_config.replay.
                replay_shards`` also sets the number of shards.

        Raises:
            KeyError: if any SYMPH_{COLLECTOR,SAMPLER}_{FRONTEND,BACKEND}_PORT
                environment variable is not set.
        """
        self.sampler_proxy = None
        self.collector_proxy = None
        self.processes = []

        self.learner_config = learner_config
        self.env_config = env_config
        self.session_config = session_config

        self.replay_class = replay_class

        self.shards = self.learner_config.replay.replay_shards

        self.collector_frontend_port = os.environ['SYMPH_COLLECTOR_FRONTEND_PORT']
        self.collector_backend_port = os.environ['SYMPH_COLLECTOR_BACKEND_PORT']
        self.sampler_frontend_port = os.environ['SYMPH_SAMPLER_FRONTEND_PORT']
        self.sampler_backend_port = os.environ['SYMPH_SAMPLER_BACKEND_PORT']

        # '*' is ZeroMQ's all-interfaces wildcard.
        self.collector_frontend_add = "tcp://*:{}".format(self.collector_frontend_port)
        self.collector_backend_add = "tcp://*:{}".format(self.collector_backend_port)
        self.sampler_frontend_add = "tcp://*:{}".format(self.sampler_frontend_port)
        self.sampler_backend_add = "tcp://*:{}".format(self.sampler_backend_port)

    def launch(self):
        """Start one replay process per shard, then both proxy threads."""
        self.processes = []

        print('Starting {} replay shards'.format(self.shards))
        for i in range(self.shards):
            print('Replay {} starting'.format(i))
            p = Process(target=self.start_replay, args=[i])
            p.start()
            self.processes.append(p)

        self.collector_proxy = ZmqProxyThread(
            in_add=self.collector_frontend_add,
            out_add=self.collector_backend_add,
            pattern='router-dealer')
        self.sampler_proxy = ZmqProxyThread(
            in_add=self.sampler_frontend_add,
            out_add=self.sampler_backend_add,
            pattern='router-dealer')

        self.collector_proxy.start()
        self.sampler_proxy.start()

    def start_replay(self, index):
        """Process entry point: build shard ``index``, start it, wait on it."""
        replay = self.replay_class(self.learner_config,
                                   self.env_config,
                                   self.session_config,
                                   index=index)
        replay.start_threads()
        replay.join()

    def join(self):
        """Wait for every replay process, then for both proxy threads.

        Each process's exit code is reported as ``replay-<index>``. Proxies
        that were never created (``launch`` not called) are skipped instead
        of raising AttributeError on ``None``.
        """
        for i, p in enumerate(self.processes):
            p.join()
            U.report_exitcode(p.exitcode, 'replay-{}'.format(i))
        for proxy in (self.collector_proxy, self.sampler_proxy):
            if proxy is not None:
                proxy.join()
예제 #9
0
class ReplayLoadBalancer(object):
  """Runs the collector and sampler 'router-dealer' proxies for replay.

  All four ports are read from the environment at construction time; the
  proxies themselves are only created by ``launch``.
  """

  def __init__(self):
    """Read the collector/sampler proxy ports from the environment.

    Raises:
      KeyError: if any SYMPH_{COLLECTOR,SAMPLER}_{FRONTEND,BACKEND}_PORT
        environment variable is not set.
    """
    self.sampler_proxy = None
    self.collector_proxy = None

    self.collector_frontend_port = os.environ['SYMPH_COLLECTOR_FRONTEND_PORT']
    self.collector_backend_port = os.environ['SYMPH_COLLECTOR_BACKEND_PORT']
    self.sampler_frontend_port = os.environ['SYMPH_SAMPLER_FRONTEND_PORT']
    self.sampler_backend_port = os.environ['SYMPH_SAMPLER_BACKEND_PORT']

    # '*' is ZeroMQ's all-interfaces wildcard.
    self.collector_frontend_add = "tcp://*:{}".format(
        self.collector_frontend_port)
    self.collector_backend_add = "tcp://*:{}".format(
        self.collector_backend_port)
    self.sampler_frontend_add = "tcp://*:{}".format(self.sampler_frontend_port)
    self.sampler_backend_add = "tcp://*:{}".format(self.sampler_backend_port)

  def launch(self):
    """Create and start both proxies as non-daemon threads."""
    self.collector_proxy = ZmqProxyThread(in_add=self.collector_frontend_add,
                                          out_add=self.collector_backend_add,
                                          pattern='router-dealer')
    self.sampler_proxy = ZmqProxyThread(in_add=self.sampler_frontend_add,
                                        out_add=self.sampler_backend_add,
                                        pattern='router-dealer')

    # Thread.setDaemon() is deprecated since Python 3.10; assigning the
    # ``daemon`` attribute is the equivalent supported form.
    self.collector_proxy.daemon = False
    self.collector_proxy.start()
    self.sampler_proxy.daemon = False
    self.sampler_proxy.start()

  def join(self):
    """Wait for both proxies; skip any that ``launch`` never created."""
    for proxy in (self.collector_proxy, self.sampler_proxy):
      if proxy is not None:
        proxy.join()