예제 #1
0
def test_client_server_eventloop():
    """Round-trip N requests through a ZmqServer event loop.

    The server runs ``handler`` on a background thread (``blocking=False``);
    each request ``x`` is answered with ``x + 1``. When the last request
    (``N - 1``) arrives, the handler asks the server to stop so the loop
    thread exits and can be joined.
    """
    host = '127.0.0.1'
    port = 7001
    N = 20

    server = ZmqServer(host=host,
                       port=port,
                       serializer='pickle',
                       deserializer='pickle')
    all_requests = []

    def handler(x):
        # Request stop on the final request; the assertion below expects the
        # reply to this last request (N) to still reach the client.
        if x == N - 1:
            server.stop()
        all_requests.append(x)
        return x + 1

    server_thread = server.start_loop(handler, blocking=False)

    client = ZmqClient(host=host,
                       port=port,
                       serializer='pickle',
                       deserializer='pickle')
    all_responses = [client.request(i) for i in range(N)]

    server_thread.join()

    assert all_requests == list(range(N))
    assert all_responses == [n + 1 for n in range(N)]
예제 #2
0
    def run(self):
        """
            Start the parameter subscriber and the agent-facing server loops,
            then block until both loop threads exit (e.g. due to an error).
        """
        # Receives parameter updates published on topic 'ps'.
        self._subscriber = ZmqSub(host=self.publisher_host,
                                  port=self.publisher_port,
                                  topic='ps',
                                  deserializer=U.deserialize)
        # Answers agent requests; the socket is bound only when this server
        # is not load balanced (presumably a proxy owns the address otherwise).
        self._server = ZmqServer(host=self.serving_host,
                                 port=self.serving_port,
                                 serializer=U.serialize,
                                 deserializer=U.deserialize,
                                 bind=not self.load_balanced)

        self._subscriber_thread = self._subscriber.start_loop(
            handler=self._set_storage, blocking=False)
        self._server_thread = self._server.start_loop(
            handler=self._handle_agent_request, blocking=False)
        print('Parameter server started')

        for loop_thread in (self._subscriber_thread, self._server_thread):
            loop_thread.join()
예제 #3
0
파일: base.py 프로젝트: aravic/liaison
    def __init__(self,
                 seed,
                 evict_interval,
                 compress_before_send,
                 load_balanced=True,
                 index=0,
                 **kwargs):
        """Create the collector and sampler servers for a replay shard.

        Args:
            seed: accepted for signature compatibility; not used here.
            evict_interval: passed through to ``self._evict_interval``.
            compress_before_send: selects the wire (de)serializers.
            load_balanced: True -> SYMPH_*_BACKEND_PORT ports on localhost with
                an unbound sampler socket; False -> SYMPH_*_FRONTEND_PORT ports
                on all interfaces with a bound sampler socket.
            index: shard index.
            **kwargs: wrapped into ``self.config``.
        """
        self.config = ConfigDict(kwargs)
        self.index = index

        if load_balanced:
            port_vars = ('SYMPH_COLLECTOR_BACKEND_PORT',
                         'SYMPH_SAMPLER_BACKEND_PORT')
        else:
            port_vars = ('SYMPH_COLLECTOR_FRONTEND_PORT',
                         'SYMPH_SAMPLER_FRONTEND_PORT')
        collector_port, sampler_port = (os.environ[var] for var in port_vars)
        listen_host = 'localhost' if load_balanced else '*'

        self._collector_server = ExperienceCollectorServer(
            host=listen_host,
            port=collector_port,
            exp_handler=self._insert_wrapper,
            load_balanced=load_balanced,
            compress_before_send=compress_before_send)
        self._sampler_server = ZmqServer(
            host=listen_host,
            port=sampler_port,
            bind=not load_balanced,
            serializer=get_serializer(compress_before_send),
            deserializer=get_deserializer(compress_before_send))
        self._sampler_server_thread = None

        self._evict_interval, self._evict_thread = evict_interval, None

        self._setup_logging()
예제 #4
0
    def run(self):
        """
        Start the parameter receiver and the agent server loops, then block
        until both threads finish (which happens only on error).
        """
        if self._supress_output:
            # Redirect this process's stdout/stderr into files under /tmp.
            stem = '/tmp/latest'
            sys.stdout = open(stem + '.out', "w")
            sys.stderr = open(stem + '.err', "w")

        codec = dict(serializer=U.serialize, deserializer=U.deserialize)
        # Learner pushes (parameters, param_info) here.
        self._param_reciever = ZmqServer(host='*', port=self.publish_port,
                                         **codec)
        # Agents request parameters here.
        self._server = ZmqServer(host='*', port=self.serving_port, **codec)

        self._subscriber_thread = self._param_reciever.start_loop(
            handler=self._set_storage, blocking=False)
        self._server_thread = self._server.start_loop(
            handler=self._handle_agent_request, blocking=False)
        logging.info('Parameter server started')

        for loop_thread in (self._subscriber_thread, self._server_thread):
            loop_thread.join()
예제 #5
0
파일: base.py 프로젝트: wwxFromTju/surreal
    def __init__(self, learner_config, env_config, session_config, index=0):
        """
        Set up a load-balanced replay shard.

        There are two replay configs: the one in ``learner_config`` controls
        the algorithmic side of replay, the one in ``session_config`` controls
        system settings (``session_config.replay.evict_interval`` is read here).
        """
        self.learner_config = learner_config
        self.env_config = env_config
        self.session_config = session_config
        self.index = index

        # Always load balanced: backend ports on localhost, sampler not bound.
        env = os.environ
        collector_port = env['SYMPH_COLLECTOR_BACKEND_PORT']
        sampler_port = env['SYMPH_SAMPLER_BACKEND_PORT']

        self._collector_server = ExperienceCollectorServer(
            host='localhost',
            port=collector_port,
            exp_handler=self._insert_wrapper,
            load_balanced=True,
        )
        self._sampler_server = ZmqServer(
            host='localhost',
            port=sampler_port,
            bind=False,
        )
        self._sampler_server_thread = None

        self._evict_interval = session_config.replay.evict_interval
        self._evict_thread = None

        self._setup_logging()
예제 #6
0
파일: worker.py 프로젝트: aravic/liaison
 def run(self):
     """Bind a ZmqServer on (serving_host, serving_port) and serve in the
     foreground until the loop exits."""
     server_kwargs = dict(host=self.serving_host,
                          port=self.serving_port,
                          serializer=U.serialize,
                          deserializer=U.deserialize,
                          bind=True)
     self._server = ZmqServer(**server_kwargs)
     self._server.start_loop(handler=self._handle_request, blocking=True)
예제 #7
0
 def server(N):
     """Answer N pickled requests with ``request + 1``.

     Uses ``host``, ``port`` and ``data`` from the enclosing scope; the list
     of received requests is stored in ``data[0]``.
     """
     srv = ZmqServer(host=host, port=port,
                     serializer='pickle', deserializer='pickle')
     received = []
     for _ in range(N):
         request = srv.recv()
         received.append(request)
         srv.send(request + 1)
     data[0] = received
예제 #8
0
def server_fn():
    """Serve on all interfaces at ``args.port`` and check every message.

    Each incoming message must equal the module-level ``MSG``. The handler
    returns None (what the loop does with that is up to ZmqServer).
    """
    server = ZmqServer(host='*', port=args.port,
                       serializer='pyarrow', deserializer='pyarrow')

    def check_message(msg):
        assert msg == MSG
        print('message received succesfully')

    server.start_loop(check_message)
예제 #9
0
  def run(self):
    """Serve spec requests on ``self.port`` until the loop thread exits."""
    self._server = ZmqServer(host='*',
                             port=self.port,
                             serializer=U.pickle_serialize,
                             deserializer=U.pickle_deserialize,
                             bind=True)
    loop_thread = self._server.start_loop(handler=self._handle_request,
                                          blocking=False)
    self._server_thread = loop_thread
    logging.info('Spec server started')
    loop_thread.join()
예제 #10
0
    def _start_remote_server(self):
        """Serve this object's graph features on REMOTE_PORT from a daemon thread.

        The features are computed once, up front; every request, whatever its
        payload, is answered with a one-element list containing them.

        Returns:
            The started daemon ``Thread`` running the blocking server loop.
        """
        server = ZmqServer(host='*',
                           port=REMOTE_PORT,
                           serializer=get_serializer(),
                           deserializer=get_deserializer(),
                           auth=False)
        graph = self._get_graph_features()

        def reply_with_graph(_request):
            return [graph]

        worker = Thread(target=server.start_loop,
                        args=(reply_with_graph,),
                        kwargs={'blocking': True},
                        daemon=True)
        worker.start()
        return worker
예제 #11
0
class SpecServer(Thread):
  """Thread answering trajectory/action spec queries over ZMQ.

  A client sends ``(batch_size, traj_length)``; the reply is the trajectory
  spec formatted for those sizes plus the action spec whose leading dimension
  is set to ``batch_size``.
  """

  def __init__(self, port, traj_spec, action_spec):
    self._traj_spec = traj_spec
    self._action_spec = action_spec
    self.port = port
    super().__init__()

  def run(self):
    """Bind on all interfaces at ``self.port`` and serve until the loop ends."""
    self._server = ZmqServer(host='*',
                             port=self.port,
                             serializer=U.pickle_serialize,
                             deserializer=U.pickle_deserialize,
                             bind=True)
    loop_thread = self._server.start_loop(handler=self._handle_request,
                                          blocking=False)
    self._server_thread = loop_thread
    logging.info('Spec server started')
    loop_thread.join()

  def _handle_request(self, req):
    """req -> (batch_size, traj_length); returns (traj_spec, action_spec).

    Note: resets the shape of the shared ``self._action_spec`` in place on
    every request and returns that same object.
    """
    batch_size, _traj_length = req
    traj_spec = Trajectory.format_traj_spec(self._traj_spec, *req)
    action_spec = self._action_spec
    action_spec.set_shape((batch_size, ) + action_spec.shape[1:])
    return traj_spec, action_spec
예제 #12
0
    def __init__(self):
        """Create placeholder metadata and start serving agent requests.

        The metadata describes an empty model: iteration 0, no variables, and
        the hash of an empty dict. The server is bound on
        ``_LOCALHOST:PS_FRONTEND_PORT`` and its loop runs in the background.
        """
        self.param_info = dict(
            time=time.time(),
            iteration=0,
            variable_list=[],
            hash=U.pyobj_hash({}),
        )

        codec = dict(serializer=U.serialize, deserializer=U.deserialize)
        self._server = ZmqServer(host=_LOCALHOST, port=PS_FRONTEND_PORT,
                                 bind=True, **codec)
        self._server_thread = self._server.start_loop(
            handler=self._handle_agent_request,
            blocking=False,
        )
예제 #13
0
class Proxy:
  """Expose an IRSClient over a blocking ZMQ server.

  Each request is ``(method_name, args, kwargs)``; the named IRSClient method
  is called and its return value is the reply.
  """

  def __init__(self, serving_host, serving_port):
    # serving_host is stored but run() binds on all interfaces ('*').
    self.serving_host = serving_host
    self.serving_port = serving_port
    self._irs_client = IRSClient(auto_detect_proxy=False)

  def run(self):
    """Bind on all interfaces at ``serving_port`` and serve (blocking)."""
    self._server = ZmqServer(
        host='*',
        port=self.serving_port,
        serializer=U.serialize,
        deserializer=U.deserialize,
        bind=True,
    )
    self._server.start_loop(handler=self._handle_request, blocking=True)

  def _handle_request(self, req):
    # NOTE(review): method_name comes straight off the wire and is resolved
    # with getattr, so any IRSClient attribute (private ones included) is
    # callable by a remote peer — consider an allow-list.
    method_name, call_args, call_kwargs = req
    method = getattr(self._irs_client, method_name)
    return method(*call_args, **call_kwargs)
예제 #14
0
def main(_):
    """Smoke test: serve ``server_f`` on PORT and fetch parameters 10 times.

    The server loop runs in the background and is never joined or stopped.
    """
    server = ZmqServer(
        host='localhost',
        port=PORT,
        serializer=U.serialize,
        deserializer=U.deserialize,
        bind=True,
    )
    loop_thread = server.start_loop(handler=server_f, blocking=False)

    client = get_ps_client()
    for _attempt in range(10):
        client.fetch_parameter_with_info(['x'])
    print('Done!')
예제 #15
0
 def server():
     """Scripted request/reply server used by a client test.

     'request-1' and 'request-3' are answered immediately, 'request-2' after
     a 0.5 s delay, and the function returns after answering 'request-3'.
     Any other request is received but left unanswered.
     """
     srv = ZmqServer(host=host, port=port,
                     serializer='pickle', deserializer='pickle')
     while True:
         req = srv.recv()
         if req == 'request-1':
             srv.send('received-request-1')
             continue
         if req == 'request-2':
             time.sleep(0.5)
             srv.send('received-request-2')
             continue
         if req == 'request-3':
             srv.send('received-request-3')
             return
예제 #16
0
def main(_):
  """Smoke test through a proxy: start a ZmqProxyThread between FRONTEND_PORT
  and BACKEND_PORT, attach ``server_f`` to the backend, then fetch
  parameters 10 times."""
  frontend = "tcp://*:%d" % FRONTEND_PORT
  backend = "tcp://*:%d" % BACKEND_PORT
  proxy = ZmqProxyThread(frontend, backend)
  proxy.start()

  # bind=False: the proxy owns the backend port.
  server = ZmqServer(
      host='localhost',
      port=BACKEND_PORT,
      serializer=U.serialize,
      deserializer=U.deserialize,
      bind=False,
  )
  loop_thread = server.start_loop(handler=server_f, blocking=False)

  client = get_ps_client()
  for _attempt in range(10):
    client.fetch_parameter_with_info(['x'])
  print('Done!')
예제 #17
0
class DummyPS:
    """Minimal stand-in parameter server.

    Binds a ZmqServer on ``_LOCALHOST:PS_FRONTEND_PORT`` and answers agent
    requests with placeholder metadata and an empty parameter dict. The
    request loop is started in the background from ``__init__``.
    """
    def __init__(self):

        # Placeholder metadata: iteration 0, no variables, hash of {}.
        self.param_info = {
            'time': time.time(),
            'iteration': 0,
            'variable_list': [],
            'hash': U.pyobj_hash({}),
        }

        self._server = ZmqServer(
            host=_LOCALHOST,
            port=PS_FRONTEND_PORT,
            serializer=U.serialize,
            deserializer=U.deserialize,
            bind=True,
        )
        self._server_thread = self._server.start_loop(
            handler=self._handle_agent_request, blocking=False)

    def _handle_agent_request(self, request):
        """Reply to agents' request for parameters.

        Args:
            request: dict of PSRequest fields.

        Returns:
            A PSResponse as a dict:
              * 'info'      -> type 'info' carrying ``self.param_info``;
              * 'parameter' -> 'no_change' when the client's hash matches,
                               otherwise 'parameters' with an empty dict.

        Raises:
            ValueError: for any other request type.
        """
        request = PSRequest(**request)

        if request.type == 'info':
            # Carry the metadata, as a real parameter server's 'info' reply
            # does; clients asking for info otherwise received none.
            return PSResponse(type='info', info=self.param_info)._asdict()

        if request.type == 'parameter':
            if (request.hash is not None
                    and request.hash == self.param_info['hash']):
                # Client already holds the current (empty) parameters.
                return PSResponse(type='no_change',
                                  info=self.param_info)._asdict()
            return PSResponse(type='parameters',
                              info=self.param_info,
                              parameters={})._asdict()

        raise ValueError('invalid request type received: %s' %
                         (request.type))
예제 #18
0
파일: worker.py 프로젝트: aravic/liaison
class Worker(Thread):
    """Thread serving a disk-backed RPC API over ZMQ.

    Clients send ``(method_name, args, kwargs)``; the named public method of
    this worker is called and its return value is sent back. The public API
    (below "PUBLIC REMOTE API") stores streamed checkpoints, metagraphs,
    profiles, key-value streams and arbitrary files on local disk.
    """
    def __init__(self, serving_host, serving_port, checkpoint_folder,
                 profile_folder, kvstream_folder, **kwargs):
        Thread.__init__(self)
        # Remaining settings, e.g. max_to_keep, keep_ckpt_every_n_hrs,
        # cmd_folder and vis_files_folder (all read below).
        self.config = ConfigDict(**kwargs)
        self.checkpoint_folder = checkpoint_folder
        self.profile_folder = profile_folder
        self.kvstream_folder = kvstream_folder
        self.serving_host = serving_host
        self.serving_port = serving_port

        # Attributes
        self._server = None

    def run(self):
        """Bind the ZMQ server and serve requests in the foreground."""
        self._server = ZmqServer(host=self.serving_host,
                                 port=self.serving_port,
                                 serializer=U.serialize,
                                 deserializer=U.deserialize,
                                 bind=True)
        self._server.start_loop(handler=self._handle_request, blocking=True)

    def _handle_request(self, req):
        """Dispatch ``(req_fn, args, kwargs)`` to the public method ``req_fn``.

        Returns the method's result, or None (after logging an error) when
        ``req_fn`` is private or not an attribute of this worker. Exceptions
        raised inside the called method propagate unchanged.
        """
        req_fn, args, kwargs = req
        assert isinstance(req_fn, str)
        # req_fn comes from the network: expose only the public API, never
        # private helpers such as _stream_to_file (which writes to any path).
        if req_fn.startswith('_'):
            logging.error('Refusing private request func name received: %s',
                          req_fn)
            return None
        try:
            fn = getattr(self, req_fn)
        except AttributeError:
            logging.error('Unknown request func name received: %s', req_fn)
            return None
        # Called outside the try so an AttributeError raised by the handler
        # itself is not misreported as an unknown function name.
        return fn(*args, **kwargs)

    @staticmethod
    def _checkpoint_record(dst_dir_name, t):
        """Serialize one info.txt line; json.dumps escapes quotes/backslashes."""
        return json.dumps({'dst_dir_name': dst_dir_name, 'time': int(t)})

    def _read_checkpoint_info(self):
        """Return ``[[time, dst_dir_name], ...]`` from info.txt, sorted."""
        ckpts = []
        with open(os.path.join(self.checkpoint_folder, 'info.txt'), 'r') as f:
            for line in f:
                if not line.strip():
                    continue  # tolerate blank lines
                j = json.loads(line)
                ckpts.append([j['time'], j['dst_dir_name']])

        ckpts.sort()
        return ckpts

    def _write_checkpoint_info(self, ckpts):
        """Overwrite info.txt with one JSON record per ``[time, dir]`` pair."""
        with open(os.path.join(self.checkpoint_folder, 'info.txt'), 'w') as f:
            for t, d in ckpts:
                print(self._checkpoint_record(d, t), file=f)

    def _enforce_checkpoint_policy(self):
        """Delete checkpoints that fall outside the retention policy.

        * ``keep_ckpt_every_n_hrs >= 0``: walking oldest-first, keep a
          checkpoint whenever at least that many hours have passed since the
          last kept one (the oldest is always kept by this rule). A negative
          value disables the rule, making every checkpoint a candidate.
        * ``max_to_keep > 0``: the newest ``max_to_keep`` checkpoints are
          always kept (this includes the latest checkpoint).
        """
        max_to_keep = self.config.max_to_keep
        assert max_to_keep >= 0
        keep_ckpt_every_n_hrs = self.config.keep_ckpt_every_n_hrs
        ckpts = self._read_checkpoint_info()
        if not ckpts:
            return  # nothing recorded; ckpts[0] below would raise

        if keep_ckpt_every_n_hrs >= 0:
            prev_keep_t = ckpts[0][0]
            to_delete = []
            for t, d in ckpts[1:]:
                if t - prev_keep_t < 3600 * keep_ckpt_every_n_hrs:
                    to_delete.append([t, d])
                else:
                    prev_keep_t = t
        else:
            to_delete = list(ckpts)

        if max_to_keep:
            for ckpt in ckpts[-max_to_keep:]:
                if ckpt in to_delete:
                    to_delete.remove(ckpt)

        self._write_checkpoint_info(
            [ckpt for ckpt in ckpts if ckpt not in to_delete])

        for _, d in to_delete:
            U.f_remove(os.path.join(self.checkpoint_folder, d))

    def _stream_to_file(self, offset, data, fname, done):
        """Write ``data`` at ``offset`` into ``fname + '.part'``.

        fname should be with full path. When ``done`` is true the part file
        is renamed to ``fname``.
        """
        fname = fname.rstrip('/')

        Path(fname + '.part').touch()
        with open(fname + '.part', 'r+b') as f:
            f.seek(offset, os.SEEK_SET)
            f.write(data)

        if done:
            shutil.move(fname + '.part', fname)

    # ================== PUBLIC REMOTE API ==================
    def register_commands(self, **cmds):
        """Dump ``cmds`` to ``<cmd_folder>/cmds.txt``."""
        U.f_mkdir(self.config.cmd_folder)
        U.pretty_dump(cmds, os.path.join(self.config.cmd_folder, 'cmds.txt'))

    def register_metagraph(self, offset, data, _, fname, done):
        """Stream a metagraph chunk into the checkpoint folder.

        TODO: If multi threaded client, add filelock support here.
        """
        U.f_mkdir(self.checkpoint_folder)
        self._stream_to_file(offset, data,
                             os.path.join(self.checkpoint_folder, fname), done)

    def register_checkpoint(self, offset, data, dst_dir_name, fname, done):
        """Stream a checkpoint chunk; on the last chunk record it and prune.

        TODO: If multi threaded client, add filelock support here.
        """
        U.f_mkdir(os.path.join(self.checkpoint_folder, dst_dir_name))
        self._stream_to_file(
            offset, data,
            os.path.join(self.checkpoint_folder, dst_dir_name, fname), done)

        if done:
            with open(os.path.join(self.checkpoint_folder, 'info.txt'),
                      'a') as f:
                print(self._checkpoint_record(dst_dir_name, time.time()),
                      file=f)
            logging.info("Received new checkpoint which is saved at %s/%s/%s",
                         self.checkpoint_folder, dst_dir_name, fname)
            self.enforce_checkpoint_policy()

    def register_profile(self, offset, data, dst_dir_name, fname, done):
        """Stream a profile chunk into the profile folder.

        TODO: If multi threaded client, add filelock support here.
        """
        del dst_dir_name  # unused
        U.f_mkdir(self.profile_folder)
        self._stream_to_file(offset, data,
                             os.path.join(self.profile_folder, fname), done)

        if done:
            # Log the folder the profile was actually written to.
            logging.info("Received new profile which is saved at %s/%s",
                         self.profile_folder, fname)

    def enforce_checkpoint_policy(self):
        """De-duplicate info.txt, then apply the retention policy.

        When a ``dst_dir_name`` was recorded more than once, only its newest
        entry is kept.
        """
        latest = {}
        for t, d in self._read_checkpoint_info():
            if d not in latest or t > latest[d]:
                latest[d] = t

        # Written in directory-name order; _enforce_checkpoint_policy re-reads
        # the file and sorts by time.
        self._write_checkpoint_info([[t, d] for d, t in sorted(latest.items())])

        self._enforce_checkpoint_policy()

    def record_kv_data(self, stream, kv_data, **kwargs):
        """Add key-value data to the stream (overwrites ``<stream>.pkl``)."""
        logging.info('Received new kvdata on stream %s', stream)
        U.f_mkdir(self.kvstream_folder)
        with open(os.path.join(self.kvstream_folder, stream) + '.pkl',
                  'wb') as f:
            d = dict(kv=kv_data, stream=stream, **kwargs)
            pickle.dump(d, f)

    def save_file(self, fname, data, **kwargs):
        """Write raw bytes to ``<vis_files_folder>/<fname>``."""
        fname = f'{self.config.vis_files_folder}/{fname}'
        U.f_mkdir(os.path.dirname(fname))
        with open(fname, 'wb') as f:
            f.write(data)
예제 #19
0
파일: base.py 프로젝트: aravic/liaison
class Replay:
    """
        Base class for replay servers.

        Experiences arrive through an ExperienceCollectorServer and are passed
        to ``insert``; learner batch requests arrive on a ZmqServer whose
        handler calls ``sample``. Subclasses implement ``insert``, ``sample``,
        ``start_sample_condition`` and ``__len__`` (optionally ``evict``).

        Important: When extending this class, make sure to follow the init
        method signature so that orchestrating functions can properly
        initialize the replay server.
    """
    def __init__(self,
                 seed,
                 evict_interval,
                 compress_before_send,
                 load_balanced=True,
                 index=0,
                 **kwargs):
        """
        Args:
            seed: part of the required signature; unused by this base class.
            evict_interval: seconds slept between ``evict`` calls; falsy
                disables the evict thread.
            compress_before_send: selects the (de)serializers on the wire.
            load_balanced: True -> SYMPH_*_BACKEND_PORT ports on localhost and
                an unbound sampler socket; False -> SYMPH_*_FRONTEND_PORT
                ports on all interfaces with a bound sampler socket.
            index: shard index, used in the tensorplex client id.
            **kwargs: remaining settings, wrapped in a ConfigDict.
        """
        self.config = ConfigDict(kwargs)
        self.index = index

        if load_balanced:
            collector_port = os.environ['SYMPH_COLLECTOR_BACKEND_PORT']
            sampler_port = os.environ['SYMPH_SAMPLER_BACKEND_PORT']
        else:
            collector_port = os.environ['SYMPH_COLLECTOR_FRONTEND_PORT']
            sampler_port = os.environ['SYMPH_SAMPLER_FRONTEND_PORT']
        host = 'localhost' if load_balanced else '*'
        self._collector_server = ExperienceCollectorServer(
            host=host,
            port=collector_port,
            exp_handler=self._insert_wrapper,
            load_balanced=load_balanced,
            compress_before_send=compress_before_send)
        self._sampler_server = ZmqServer(
            host=host,
            port=sampler_port,
            bind=not load_balanced,
            serializer=get_serializer(compress_before_send),
            deserializer=get_deserializer(compress_before_send))
        self._sampler_server_thread = None

        self._evict_interval = evict_interval
        self._evict_thread = None

        self._setup_logging()

    def start_threads(self):
        """Start tensorplex reporting, collector, evict and sampler loops."""
        if self._has_tensorplex:
            self.start_tensorplex_thread()

        self._collector_server.start()

        if self._evict_interval:
            self.start_evict_thread()

        # Explicitly non-blocking: join() expects a thread object back.
        self._sampler_server_thread = self._sampler_server.start_loop(
            handler=self._sample_request_handler, blocking=False)

    def join(self):
        """Wait for every thread started by ``start_threads``."""
        self._collector_server.join()
        self._sampler_server_thread.join()
        if self._has_tensorplex:
            self._tensorplex_thread.join()
        if self._evict_interval:
            self._evict_thread.join()

    def insert(self, exp_dict):
        """
        Add a new experience to the replay.
        Includes passive evict logic if memory capacity is exceeded.

        Args:
            exp_dict: {[obs], action, reward, done, info}
        """
        raise NotImplementedError

    def sample(self, batch_size):
        """
        This function is called in _sample_handler for learner side Zmq request

        Args:
            batch_size

        Returns:
            a list of exp_tuples
        """
        raise NotImplementedError

    def evict(self):
        """
        Actively evict old experiences.
        """
        pass

    def start_sample_condition(self):
        """
        Tells the thread to start sampling only when this condition is met.
        For example, only when the replay memory has > 10K experiences.

        Returns:
            bool: whether to start sampling or not
        """
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    # ======================== internal methods ========================
    def _sample_request_handler(self, req):
        """
        Handle a learner request: ``req`` is the batch size (an int).

        Polls until ``start_sample_condition()`` holds, then retries
        ``sample`` on ReplayUnderFlowException. Plain sleeps are used since
        nothing notifies this thread:
        https://stackoverflow.com/questions/29082268/python-time-sleep-vs-event-wait
        """
        batch_size = req
        U.assert_type(batch_size, int)
        while not self.start_sample_condition():
            time.sleep(0.01)
        self.cumulative_sampled_count += batch_size
        self.cumulative_request_count += 1

        with self.sample_time.time():
            while True:
                try:
                    sample = self.sample(batch_size)
                    break
                except ReplayUnderFlowException:
                    time.sleep(1e-3)

        # Serialization now happens inside ZmqServer; this timer only wraps
        # the return.
        with self.serialize_time.time():
            return sample

    def _insert_wrapper(self, exp):
        """
            Allows us to do some book keeping in the base class
        """
        self.cumulative_collected_count += 1
        with self.insert_time.time():
            self.insert(exp)

    def _get_tensorplex_client(self, client_id):
        """Build a TensorplexClient from the SYMPH_TENSORPLEX_SYSTEM_* env."""
        host = os.environ['SYMPH_TENSORPLEX_SYSTEM_HOST']
        port = os.environ['SYMPH_TENSORPLEX_SYSTEM_PORT']
        return TensorplexClient(
            client_id,
            host=host,
            port=port,
            serializer=self.config.tensorplex_config.serializer,
            deserializer=self.config.tensorplex_config.deserializer)

    def _setup_logging(self):
        """Create the tensorplex client, counters and timing recorders."""
        self.tensorplex = self._get_tensorplex_client('{}/{}'.format(
            'replay', self.index))
        self._tensorplex_thread = None
        self._has_tensorplex = self.config.tensorboard_display

        # Origin of all global steps
        self.init_time = time.time()
        # Number of experience collected by agents
        self.cumulative_collected_count = 0
        # Number of experience sampled by learner
        self.cumulative_sampled_count = 0
        # Number of sampling requests from the learner
        self.cumulative_request_count = 0
        # Timer for tensorplex reporting
        self.last_tensorplex_iter_time = time.time()
        # Last reported values used for speed computation
        self.last_experience_count = 0
        self.last_sample_count = 0
        self.last_request_count = 0

        self.insert_time = U.TimeRecorder(decay=0.99998)
        self.sample_time = U.TimeRecorder()
        self.serialize_time = U.TimeRecorder()

        # moving average of about 100s
        self.exp_in_speed = U.MovingAverageRecorder(decay=0.99)
        self.exp_out_speed = U.MovingAverageRecorder(decay=0.99)
        self.handle_sample_request_speed = U.MovingAverageRecorder(decay=0.99)

    def start_evict_thread(self):
        """Start the periodic evict loop; raises if already running."""
        if self._evict_thread is not None:
            raise RuntimeError('evict thread already running')
        self._evict_thread = U.start_thread(self._evict_loop)
        return self._evict_thread

    def _evict_loop(self):
        """Call ``evict`` every ``_evict_interval`` seconds, forever."""
        assert self._evict_interval
        while True:
            time.sleep(self._evict_interval)
            self.evict()

    def start_tensorplex_thread(self):
        """Start the periodic tensorplex reporter; raises if already running."""
        if self._tensorplex_thread is not None:
            raise RuntimeError('tensorplex thread already running')
        self._tensorplex_thread = U.PeriodicWakeUpWorker(
            target=self.generate_tensorplex_report)
        self._tensorplex_thread.start()
        return self._tensorplex_thread

    def generate_tensorplex_report(self):
        """
            Generates stats to be reported to tensorplex.

            Speeds are per-second rates over the window since the previous
            report, smoothed by moving averages. A "load" is the fraction of
            wall time spent in an operation: average seconds per operation
            times operations per second.
        """
        # One timestamp closes this window and opens the next, so no time
        # slips between windows.
        now = time.time()
        global_step = int(now - self.init_time)

        # +1e-6 guards against division by zero on back-to-back reports.
        time_elapsed = now - self.last_tensorplex_iter_time + 1e-6

        cum_count_collected = self.cumulative_collected_count
        new_exp_count = cum_count_collected - self.last_experience_count
        self.last_experience_count = cum_count_collected

        cum_count_sampled = self.cumulative_sampled_count
        new_sample_count = cum_count_sampled - self.last_sample_count
        self.last_sample_count = cum_count_sampled

        cum_count_requests = self.cumulative_request_count
        new_request_count = cum_count_requests - self.last_request_count
        self.last_request_count = cum_count_requests

        exp_in_speed = self.exp_in_speed.add_value(new_exp_count /
                                                   time_elapsed)
        exp_out_speed = self.exp_out_speed.add_value(new_sample_count /
                                                     time_elapsed)
        handle_sample_request_speed = self.handle_sample_request_speed.add_value(
            new_request_count / time_elapsed)

        insert_time = self.insert_time.avg
        sample_time = self.sample_time.avg
        serialize_time = self.serialize_time.avg

        core_metrics = {
            'num_exps': len(self),
            'total_collected_exps': cum_count_collected,
            'total_sampled_exps': cum_count_sampled,
            'total_sample_requests': cum_count_requests,
            'exp_in_per_s': exp_in_speed,
            'exp_out_per_s': exp_out_speed,
            'requests_per_s': handle_sample_request_speed,
            'insert_time_s': insert_time,
            'sample_time_s': sample_time,
            'serialize_time_s': serialize_time,
        }

        if hasattr(self, 'per_sample_size'):
            core_metrics['per_sample_size_MB'] = self.per_sample_size / 1e6
        # Speeds are already per second; dividing by time_elapsed again (as
        # before) skewed the loads by the report interval.
        serialize_load = serialize_time * handle_sample_request_speed
        collect_exp_load = insert_time * exp_in_speed
        sample_exp_load = sample_time * handle_sample_request_speed

        system_metrics = {
            'lifetime_experience_utilization_percent':
            cum_count_sampled / (cum_count_collected + 1) * 100,
            'current_experience_utilization_percent':
            exp_out_speed / (exp_in_speed + 1) * 100,
            'serialization_load_percent':
            serialize_load * 100,
            'collect_exp_load_percent':
            collect_exp_load * 100,
            'sample_exp_load_percent':
            sample_exp_load * 100,
        }

        all_metrics = {}
        for k in core_metrics:
            all_metrics['.core/' + k] = core_metrics[k]
        for k in system_metrics:
            all_metrics['.system/' + k] = system_metrics[k]
        self.tensorplex.add_scalars(all_metrics, global_step=global_step)

        self.last_tensorplex_iter_time = now
예제 #20
0
class ParameterServer(Thread):
    """
      Standalone parameter-server node; ``run`` blocks until its loops end.

      The learner sends ``(parameters, param_info)`` to ``publish_port`` and
      agents request parameters on ``serving_port``. A simple hash-based
      check answers 'no_change' instead of re-sending parameters the agent
      already holds.
    """
    def __init__(
        self,
        publish_port,
        serving_port,
        supress_output=False,
    ):
        """
        Args:
            publish_port: where the learner should send parameters to.
            serving_port: where agents request parameters from.
            supress_output: if True, ``run`` redirects stdout/stderr to
                /tmp/latest.out and /tmp/latest.err.
        """
        Thread.__init__(self)
        self.publish_port = publish_port
        self.serving_port = serving_port
        self._supress_output = supress_output
        # storage: latest values received from the learner
        self.parameters = None
        self.param_info = None
        # servers / threads, created in run()
        self._param_reciever = None
        self._subscriber = None  # not used by this class; kept for compatibility
        self._server = None
        self._subscriber_thread = None
        self._server_thread = None

    def run(self):
        """
        Start the receiver and agent-serving loops, then wait until both
        threads finish (due to error).
        """
        if self._supress_output:
            sys.stdout = open('/tmp/' + 'latest' + ".out", "w")
            sys.stderr = open('/tmp/' + 'latest' + ".err", "w")

        self._param_reciever = ZmqServer(
            host='*',
            port=self.publish_port,
            serializer=U.serialize,
            deserializer=U.deserialize,
        )
        self._server = ZmqServer(
            host='*',
            port=self.serving_port,
            serializer=U.serialize,
            deserializer=U.deserialize,
        )
        self._subscriber_thread = self._param_reciever.start_loop(
            handler=self._set_storage, blocking=False)
        self._server_thread = self._server.start_loop(
            handler=self._handle_agent_request, blocking=False)
        logging.info('Parameter server started')

        self._subscriber_thread.join()
        self._server_thread.join()

    def _set_storage(self, data):
        """Store a ``(parameters, param_info)`` pair sent by the learner."""
        self.parameters, self.param_info = data
        logging.info('_set_storage received info: {}'.format(self.param_info))

    def _handle_agent_request(self, request):
        """Reply to agents' request for parameters.

        Args:
            request: dict of PSRequest fields.

        Returns:
            A PSResponse as a dict: 'not_ready' until the learner has sent
            anything; 'info' with param_info; for 'parameter' requests either
            'no_change' (client hash matches) or 'parameters' with the
            requested variables.

        Raises:
            ValueError: for any other request type.
        """
        request = PSRequest(**request)
        logging.info('Request received of type: %s', request.type)

        if self.param_info is None:
            return PSResponse(type='not_ready')._asdict()

        if request.type == 'info':
            return PSResponse(type='info', info=self.param_info)._asdict()

        if request.type == 'parameter':
            if (request.hash is not None
                    and request.hash == self.param_info['hash']):
                # param not changed
                return PSResponse(type='no_change',
                                  info=self.param_info)._asdict()

            # Rewrite the first "<agent_scope>/" in each requested name to the
            # learner's "<agent_scope>/" before looking it up.
            params_asked_for = {
                var_name: self.parameters[var_name.replace(
                    request.agent_scope + '/',
                    self.param_info['agent_scope'] + '/', 1)]
                for var_name in request.var_list
            }
            return PSResponse(type='parameters',
                              info=self.param_info,
                              parameters=params_asked_for)._asdict()

        raise ValueError('invalid request type received: %s' %
                         (request.type))
예제 #21
0
class ParameterServer(Process):
    """
        Standalone script for a PS node that runs in an infinite loop.
        The ParameterServer subscribes to the learner to get the latest
            model parameters and serves these parameters to agents.
        It implements a simple hash-based caching mechanism to avoid
            serving duplicate parameters to an agent.
    """
    def __init__(
        self,
        publisher_host,
        publisher_port,
        serving_host,
        serving_port,
        load_balanced=False,
    ):
        """
        Args:
            publisher_host, publisher_port: where the learner publishes
                parameters (subscribed to on topic 'ps')
            serving_host, serving_port: where to serve parameters to agents
            load_balanced: whether multiple parameter servers are sharing the
                same address; when True the serving socket is created with
                ``bind=False``
        """
        Process.__init__(self)
        self.publisher_host = publisher_host
        self.publisher_port = publisher_port
        self.serving_host = serving_host
        self.serving_port = serving_port
        self.load_balanced = load_balanced
        # storage: written by _set_storage (subscriber thread) and read by
        # _handle_agent_request (server thread)
        self.parameters = None
        self.param_info = None
        # threads, created in run()
        self._subscriber = None
        self._server = None
        self._subscriber_thread = None
        self._server_thread = None

    def run(self):
        """
            Run the subscriber and server threads and wait until they finish
            (e.g. due to an error).
        """
        self._subscriber = ZmqSub(
            host=self.publisher_host,
            port=self.publisher_port,
            topic='ps',
            deserializer=U.deserialize,
        )
        self._server = ZmqServer(
            host=self.serving_host,
            port=self.serving_port,
            serializer=U.serialize,
            deserializer=U.deserialize,
            bind=not self.load_balanced,
        )
        # Both loops run in background threads; handlers are passed here
        # rather than to the constructors.
        self._subscriber_thread = self._subscriber.start_loop(
            handler=self._set_storage, blocking=False)
        self._server_thread = self._server.start_loop(
            handler=self._handle_agent_request, blocking=False)
        print('Parameter server started')

        self._subscriber_thread.join()
        self._server_thread.join()

    def _set_storage(self, data):
        """
            Store a published ``(parameters, param_info)`` pair.

        Args:
            data: 2-item sequence unpacked into ``self.parameters`` and
                ``self.param_info``; param_info is expected to be a mapping
                with a 'hash' key (read by _handle_agent_request)
        """
        self.parameters, self.param_info = data

    def _handle_agent_request(self, request):
        """
            Reply to agents' request for parameters

        Args:
            request: a string in one of three forms
             - "info": returns (None, info)
             - "parameter": returns (param, info)
             - "parameter:<agent-hash>": returns (None, info) if the hash
                of the server-side parameters is the same as the agent's,
                otherwise (param, info)
            Both "parameter" forms return (None, None) if no parameters
            have been published yet.

        Returns:
            tuple (parameters or None, info or None)

        Raises:
            ValueError: if ``request`` is not a string in one of the forms
                above
        """
        if not isinstance(request, str):
            raise ValueError('invalid request: ' + str(request))

        # Snapshot shared state once: _set_storage runs on the subscriber
        # thread and may update these attributes while a request is handled.
        param_info = self.param_info
        parameters = self.parameters

        if request == 'info':
            return None, param_info

        # Accept exactly "parameter" or "parameter:<hash>" as documented.
        head, sep, last_hash = request.partition(':')
        if head != 'parameter':
            raise ValueError('invalid request: ' + str(request))
        if parameters is None or param_info is None:
            return None, None
        if sep and last_hash == param_info['hash']:
            # param not changed since the agent's last fetch
            return None, param_info
        return parameters, param_info