def test_client_server_eventloop():
    """Round-trip N pickle-serialized requests through a ZmqServer event loop."""
    host = '127.0.0.1'
    port = 7001
    N = 20
    data = [None, None]
    server = ZmqServer(host=host,
                       port=port,
                       serializer='pickle',
                       deserializer='pickle')
    all_requests = []

    def handler(payload):
        # Ask the loop to shut down once the final request arrives.
        if payload == N - 1:
            server.stop()
        all_requests.append(payload)
        return payload + 1

    server_thread = server.start_loop(handler, blocking=False)
    client = ZmqClient(host=host,
                       port=port,
                       serializer='pickle',
                       deserializer='pickle')
    all_responses = [client.request(i) for i in range(N)]
    data[1] = all_responses
    server_thread.join()
    assert all_requests == list(range(N))
    assert all_responses == [n + 1 for n in range(N)]
def run(self):
    """Start the subscriber and agent-serving loops, then block until they exit.

    NOTE(review): assumes ZmqSub/ZmqServer start_loop(blocking=False) returns a
    joinable thread — confirm against the zmq wrapper module.
    """
    self._subscriber = ZmqSub(
        host=self.publisher_host,
        port=self.publisher_port,
        topic='ps',
        deserializer=U.deserialize,
    )
    self._server = ZmqServer(
        host=self.serving_host,
        port=self.serving_port,
        serializer=U.serialize,
        deserializer=U.deserialize,
        # When several parameter servers share a load-balanced address,
        # connect instead of binding.
        bind=not self.load_balanced,
    )
    self._subscriber_thread = self._subscriber.start_loop(
        handler=self._set_storage, blocking=False)
    self._server_thread = self._server.start_loop(
        handler=self._handle_agent_request, blocking=False)
    print('Parameter server started')
    # Both joins only return on loop failure/stop.
    self._subscriber_thread.join()
    self._server_thread.join()
def __init__(self,
             seed,
             evict_interval,
             compress_before_send,
             load_balanced=True,
             index=0,
             **kwargs):
    """Wire up the collector and sampler servers.

    Ports come from the SYMPH_* environment variables; backend ports are used
    when a load balancer fronts this replay, frontend ports otherwise.
    """
    self.config = ConfigDict(kwargs)
    self.index = index
    if load_balanced:
        # Sit behind the balancer: talk to its backend sockets.
        collector_port = os.environ['SYMPH_COLLECTOR_BACKEND_PORT']
        sampler_port = os.environ['SYMPH_SAMPLER_BACKEND_PORT']
    else:
        # No balancer: serve the frontend ports directly.
        collector_port = os.environ['SYMPH_COLLECTOR_FRONTEND_PORT']
        sampler_port = os.environ['SYMPH_SAMPLER_FRONTEND_PORT']
    serving_host = 'localhost' if load_balanced else '*'
    self._collector_server = ExperienceCollectorServer(
        host=serving_host,
        port=collector_port,
        exp_handler=self._insert_wrapper,
        load_balanced=load_balanced,
        compress_before_send=compress_before_send)
    self._sampler_server = ZmqServer(
        host=serving_host,
        port=sampler_port,
        bind=not load_balanced,
        serializer=get_serializer(compress_before_send),
        deserializer=get_deserializer(compress_before_send))
    self._sampler_server_thread = None
    self._evict_interval = evict_interval
    self._evict_thread = None
    self._setup_logging()
def run(self):
    """Start the parameter-receiving and agent-serving loops; wait until they die."""
    if self._supress_output:
        # Redirect this process's stdout/stderr to scratch files.
        sys.stdout = open('/tmp/latest.out', 'w')
        sys.stderr = open('/tmp/latest.err', 'w')
    self._param_reciever = ZmqServer(
        host='*',
        port=self.publish_port,
        serializer=U.serialize,
        deserializer=U.deserialize,
    )
    self._server = ZmqServer(
        host='*',
        port=self.serving_port,
        serializer=U.serialize,
        deserializer=U.deserialize,
    )
    self._subscriber_thread = self._param_reciever.start_loop(
        handler=self._set_storage, blocking=False)
    self._server_thread = self._server.start_loop(
        handler=self._handle_agent_request, blocking=False)
    logging.info('Parameter server started')
    # Joins only return on loop failure/stop.
    self._subscriber_thread.join()
    self._server_thread.join()
def __init__(self, learner_config, env_config, session_config, index=0):
    """Replay init: collector + sampler servers behind the load balancer.

    There are two replay configs in play: learner_config.replay controls the
    algorithmic side; session_config.replay controls system settings.
    """
    self.learner_config = learner_config
    self.env_config = env_config
    self.session_config = session_config
    self.index = index
    collector_port = os.environ['SYMPH_COLLECTOR_BACKEND_PORT']
    sampler_port = os.environ['SYMPH_SAMPLER_BACKEND_PORT']
    self._collector_server = ExperienceCollectorServer(
        host='localhost',
        port=collector_port,
        exp_handler=self._insert_wrapper,
        load_balanced=True,
    )
    # bind=False: the sampler connects to the balancer's backend socket.
    self._sampler_server = ZmqServer(host='localhost',
                                     port=sampler_port,
                                     bind=False)
    self._sampler_server_thread = None
    self._evict_interval = self.session_config.replay.evict_interval
    self._evict_thread = None
    self._setup_logging()
def run(self):
    """Bind a server on the configured host/port and serve forever."""
    self._server = ZmqServer(host=self.serving_host,
                             port=self.serving_port,
                             serializer=U.serialize,
                             deserializer=U.deserialize,
                             bind=True)
    # Blocking loop: run() does not return under normal operation.
    self._server.start_loop(handler=self._handle_request, blocking=True)
def server(N):
    """Reply n+1 to each of N requests and record what was received."""
    srv = ZmqServer(host=host,
                    port=port,
                    serializer='pickle',
                    deserializer='pickle')
    received = []
    for _ in range(N):
        n = srv.recv()
        received.append(n)
        srv.send(n + 1)
    data[0] = received
def server_fn():
    """Serve pyarrow-encoded requests, asserting each message equals MSG.

    Runs the server's event loop in the calling thread (blocking by default).
    """
    server = ZmqServer(host='*',
                       port=args.port,
                       serializer='pyarrow',
                       deserializer='pyarrow')

    def f(msg):
        # Fail loudly if the payload was corrupted in transit.
        assert msg == MSG
        # Fixed typo: "succesfully" -> "successfully".
        print('message received successfully')

    server.start_loop(f)
def run(self):
    """Start the spec-serving loop on a worker thread and wait on it."""
    self._server = ZmqServer(
        host='*',
        port=self.port,
        serializer=U.pickle_serialize,
        deserializer=U.pickle_deserialize,
        bind=True,
    )
    self._server_thread = self._server.start_loop(
        handler=self._handle_request, blocking=False)
    logging.info('Spec server started')
    # Only returns when the serving loop terminates.
    self._server_thread.join()
def _start_remote_server(self):
    """Serve pre-computed graph features on REMOTE_PORT from a daemon thread.

    Returns:
        The started daemon thread running the server loop.
    """
    server = ZmqServer(host='*',
                       port=REMOTE_PORT,
                       serializer=get_serializer(),
                       deserializer=get_deserializer(),
                       auth=False)
    graph = self._get_graph_features()

    def handler(_):
        # Every request gets the same pre-computed payload.
        return [graph]

    thr = Thread(target=lambda: server.start_loop(handler, blocking=True),
                 daemon=True)
    thr.start()
    return thr
class SpecServer(Thread):
    """Thread that serves trajectory/action spec requests over ZMQ."""

    def __init__(self, port, traj_spec, action_spec):
        self._traj_spec = traj_spec
        self._action_spec = action_spec
        self.port = port
        super(SpecServer, self).__init__()

    def run(self):
        """Bind the server, handle requests off-thread, and wait."""
        self._server = ZmqServer(
            host='*',
            port=self.port,
            serializer=U.pickle_serialize,
            deserializer=U.pickle_deserialize,
            bind=True,
        )
        self._server_thread = self._server.start_loop(
            handler=self._handle_request, blocking=False)
        logging.info('Spec server started')
        self._server_thread.join()

    def _handle_request(self, req):
        """req -> (batch_size, traj_length)"""
        batch_size, traj_length = req
        traj_spec = Trajectory.format_traj_spec(self._traj_spec, batch_size,
                                                traj_length)
        # Pin the leading (batch) dimension of the action spec.
        self._action_spec.set_shape((batch_size,) +
                                    self._action_spec.shape[1:])
        return traj_spec, self._action_spec
def __init__(self):
    """Start serving an empty parameter set on the PS frontend port."""
    # Info describing the (empty) parameter state; hash lets agents detect
    # that nothing has changed.
    self.param_info = {
        'time': time.time(),
        'iteration': 0,
        'variable_list': [],
        'hash': U.pyobj_hash({}),
    }
    self._server = ZmqServer(
        host=_LOCALHOST,
        port=PS_FRONTEND_PORT,
        serializer=U.serialize,
        deserializer=U.deserialize,
        bind=True,
    )
    self._server_thread = self._server.start_loop(
        handler=self._handle_agent_request, blocking=False)
class Proxy:
    """Relays remote procedure requests to a local IRSClient."""

    def __init__(self, serving_host, serving_port):
        self.serving_host = serving_host
        self.serving_port = serving_port
        self._irs_client = IRSClient(auto_detect_proxy=False)

    def run(self):
        """Bind on all interfaces and serve requests forever."""
        self._server = ZmqServer(host='*',
                                 port=self.serving_port,
                                 serializer=U.serialize,
                                 deserializer=U.deserialize,
                                 bind=True)
        self._server.start_loop(handler=self._handle_request, blocking=True)

    def _handle_request(self, req):
        # Request is a (method_name, args, kwargs) triple.
        method_name, args, kwargs = req
        return getattr(self._irs_client, method_name)(*args, **kwargs)
def main(_):
    """Smoke test: serve on PORT and issue ten parameter fetches."""
    server = ZmqServer(host='localhost',
                       port=PORT,
                       serializer=U.serialize,
                       deserializer=U.deserialize,
                       bind=True)
    server_thread = server.start_loop(handler=server_f, blocking=False)
    client = get_ps_client()
    for _ in range(10):
        client.fetch_parameter_with_info(['x'])
    print('Done!')
def server():
    """Reply to three scripted requests; stop after 'request-3'."""
    srv = ZmqServer(host=host,
                    port=port,
                    serializer='pickle',
                    deserializer='pickle')
    serving = True
    while serving:
        req = srv.recv()
        if req == 'request-1':
            srv.send('received-request-1')
        elif req == 'request-2':
            # Simulate a slow handler before replying.
            time.sleep(0.5)
            srv.send('received-request-2')
        elif req == 'request-3':
            srv.send('received-request-3')
            serving = False
def main(_):
    """Smoke test with a frontend/backend proxy between client and server."""
    proxy = ZmqProxyThread("tcp://*:%d" % FRONTEND_PORT,
                           "tcp://*:%d" % BACKEND_PORT)
    proxy.start()
    # bind=False: the server connects to the proxy's backend socket.
    server = ZmqServer(host='localhost',
                       port=BACKEND_PORT,
                       serializer=U.serialize,
                       deserializer=U.deserialize,
                       bind=False)
    server_thread = server.start_loop(handler=server_f, blocking=False)
    client = get_ps_client()
    for _ in range(10):
        client.fetch_parameter_with_info(['x'])
    print('Done!')
class DummyPS:
    """Minimal parameter server that always serves an empty parameter set."""

    def __init__(self):
        # Info describing the (empty) parameter state; the hash lets agents
        # detect that nothing has changed between fetches.
        self.param_info = {
            'time': time.time(),
            'iteration': 0,
            'variable_list': [],
            'hash': U.pyobj_hash({}),
        }
        self._server = ZmqServer(
            host=_LOCALHOST,
            port=PS_FRONTEND_PORT,
            serializer=U.serialize,
            deserializer=U.deserialize,
            bind=True,
        )
        self._server_thread = self._server.start_loop(
            handler=self._handle_agent_request, blocking=False)

    def _handle_agent_request(self, request):
        """Reply to agents' request for parameters."""
        request = PSRequest(**request)
        if request.type == 'info':
            return PSResponse(type='info')._asdict()
        if request.type != 'parameter':
            raise ValueError('invalid request type received: %s' %
                             (request.type))
        if request.hash is not None and request.hash == self.param_info['hash']:
            # Agent already holds the current parameters.
            return PSResponse(type='no_change', info=self.param_info)._asdict()
        return PSResponse(type='parameters',
                          info=self.param_info,
                          parameters={})._asdict()
class Worker(Thread):
    """Thread serving a remote file/checkpoint API over a ZmqServer.

    Each incoming request is a (func_name, args, kwargs) triple that is
    dispatched to the matching public method on this object.
    """

    def __init__(self, serving_host, serving_port, checkpoint_folder,
                 profile_folder, kvstream_folder, **kwargs):
        """Store serving address and destination folders; extra kwargs go to config."""
        Thread.__init__(self)
        self.config = ConfigDict(**kwargs)
        self.checkpoint_folder = checkpoint_folder
        self.profile_folder = profile_folder
        self.kvstream_folder = kvstream_folder
        self.serving_host = serving_host
        self.serving_port = serving_port
        # Attributes
        self._server = None  # created lazily in run()

    def run(self):
        """Bind the server and process requests forever (blocking loop)."""
        self._server = ZmqServer(host=self.serving_host,
                                 port=self.serving_port,
                                 serializer=U.serialize,
                                 deserializer=U.deserialize,
                                 bind=True)
        self._server.start_loop(handler=self._handle_request, blocking=True)

    def _handle_request(self, req):
        """Dispatch a (func_name, args, kwargs) request to a method on self.

        Unknown function names are logged and yield None (no reply payload).
        """
        req_fn, args, kwargs = req
        assert isinstance(req_fn, str)
        try:
            fn = getattr(self, req_fn)
            return fn(*args, **kwargs)
        except AttributeError:
            # NOTE(review): an AttributeError raised *inside* fn is also
            # swallowed here, not only unknown names.
            logging.error('Unknown request func name received: %s', req_fn)

    def _read_checkpoint_info(self):
        """Read info.txt (one JSON object per line) into a sorted
        [[time, dst_dir_name], ...] list."""
        with open(os.path.join(self.checkpoint_folder, 'info.txt'), 'r') as f:
            ckpts = []
            for line in f.readlines():
                j = json.loads(line)
                ckpts.append([j['time'], j['dst_dir_name']])
            ckpts.sort()
            return ckpts

    def _write_checkpoint_info(self, ckpts):
        """Rewrite info.txt from a [[time, dst_dir_name], ...] list."""
        with open(os.path.join(self.checkpoint_folder, 'info.txt'), 'w') as f:
            for t, d in ckpts:
                print('{"dst_dir_name": "%s", "time":%d}' % (d, t), file=f)

    def _enforce_checkpoint_policy(self):
        """Delete checkpoints per retention policy.

        max_to_keep should include latest checkpoint;
        keep_ckpt_every_n_hrs < 0 => disables this option.
        """
        max_to_keep = self.config.max_to_keep
        assert max_to_keep >= 0
        keep_ckpt_every_n_hrs = self.config.keep_ckpt_every_n_hrs
        ckpts = self._read_checkpoint_info()
        if keep_ckpt_every_n_hrs >= 0:
            # Keep the oldest checkpoint and then one per n-hour window;
            # everything closer than the window to the last kept one is marked.
            prev_keep_t = ckpts[0][0]
            to_delete = []
            for i, (t, d) in enumerate(ckpts[1:]):
                if t - prev_keep_t < 3600 * keep_ckpt_every_n_hrs:
                    to_delete.append([t, d])
                else:
                    prev_keep_t = t
        else:
            to_delete = list(ckpts)
        if max_to_keep:
            # The newest max_to_keep checkpoints are always spared.
            for ckpt in ckpts[-max_to_keep:]:
                if ckpt in to_delete:
                    to_delete.remove(ckpt)
        self._write_checkpoint_info(
            [ckpt for ckpt in ckpts if ckpt not in to_delete])
        for _, d in to_delete:
            U.f_remove(os.path.join(self.checkpoint_folder, d))

    def _stream_to_file(self, offset, data, fname, done):
        """Write a chunk at `offset` into fname + '.part'; rename when done.

        fname should be with full path.
        """
        fname = fname.rstrip('/')
        Path(fname + '.part').touch()
        with open(fname + '.part', 'r+b') as f:
            f.seek(offset, os.SEEK_SET)
            f.write(data)
        if done:
            # Atomic-ish publish: only expose the file once fully received.
            shutil.move(fname + '.part', fname)

    # ================== PUBLIC REMOTE API ==================
    def register_commands(self, **cmds):
        """Persist the command mapping to cmds.txt in the command folder."""
        U.f_mkdir(self.config.cmd_folder)
        U.pretty_dump(cmds, os.path.join(self.config.cmd_folder, 'cmds.txt'))

    def register_metagraph(self, offset, data, _, fname, done):
        """Stream a metagraph chunk into the checkpoint folder.

        TODO: If multi threaded client, add filelock support here.
        """
        U.f_mkdir(self.checkpoint_folder)
        self._stream_to_file(offset, data,
                             os.path.join(self.checkpoint_folder, fname), done)

    def register_checkpoint(self, offset, data, dst_dir_name, fname, done):
        """Stream a checkpoint chunk; on completion, record it and enforce policy.

        TODO: If multi threaded client, add filelock support here.
        """
        U.f_mkdir(os.path.join(self.checkpoint_folder, dst_dir_name))
        self._stream_to_file(
            offset, data,
            os.path.join(self.checkpoint_folder, dst_dir_name, fname), done)
        if done:
            # Append a JSON line describing the finished checkpoint.
            with open(os.path.join(self.checkpoint_folder, 'info.txt'),
                      'a') as f:
                print('{"dst_dir_name": "%s", "time":%d}' %
                      (dst_dir_name, int(time.time())),
                      file=f)
            logging.info("Received new checkpoint which is saved at %s/%s/%s",
                         self.checkpoint_folder, dst_dir_name, fname)
            self.enforce_checkpoint_policy()

    def register_profile(self, offset, data, dst_dir_name, fname, done):
        """Stream a profile chunk into the profile folder.

        TODO: If multi threaded client, add filelock support here.
        """
        del dst_dir_name  # unused
        U.f_mkdir(self.profile_folder)
        self._stream_to_file(offset, data,
                             os.path.join(self.profile_folder, fname), done)
        if done:
            # NOTE(review): logs checkpoint_folder although the file is written
            # to profile_folder — presumably a copy/paste slip in the message.
            logging.info("Received new profile which is saved at %s/%s",
                         self.checkpoint_folder, fname)

    def enforce_checkpoint_policy(self):
        """De-duplicate info.txt by dst_dir_name, then apply retention policy."""
        # Remove duplicates from checkpoint info.txt first
        ckpts = self._read_checkpoint_info()
        # Swap to [dst_dir_name, time] and sort so duplicates are adjacent;
        # for equal names the earlier (older) entry is the one removed.
        ckpts = sorted([[sec, first] for first, sec in ckpts])
        to_remove = []
        for i, (d, t) in enumerate(ckpts):
            if i > 0:
                if d == ckpts[i - 1][0]:
                    to_remove.append(i - 1)
        self._write_checkpoint_info(
            [[t, d] for i, (d, t) in enumerate(ckpts) if i not in to_remove])
        self._enforce_checkpoint_policy()

    def record_kv_data(self, stream, kv_data, **kwargs):
        """Add key-value data to the stream."""
        logging.info(f'Received new kvdata on stream {stream}')
        U.f_mkdir(self.kvstream_folder)
        with open(os.path.join(self.kvstream_folder, stream) + '.pkl',
                  'wb') as f:
            d = dict(kv=kv_data, stream=stream, **kwargs)
            pickle.dump(d, f)

    def save_file(self, fname, data, **kwargs):
        """Write raw bytes under the configured visualization-files folder."""
        fname = f'{self.config.vis_files_folder}/{fname}'
        U.f_mkdir(os.path.dirname(fname))
        with open(fname, 'wb') as f:
            f.write(data)
class Replay:
    """
    Important: When extending this class, make sure to follow the init method
    signature so that orchestrating functions can properly initialize
    the replay server.
    """

    def __init__(self,
                 seed,
                 evict_interval,
                 compress_before_send,
                 load_balanced=True,
                 index=0,
                 **kwargs):
        """Set up collector/sampler servers; ports come from SYMPH_* env vars."""
        self.config = ConfigDict(kwargs)
        self.index = index
        if load_balanced:
            # Sit behind the balancer: use its backend sockets.
            collector_port = os.environ['SYMPH_COLLECTOR_BACKEND_PORT']
            sampler_port = os.environ['SYMPH_SAMPLER_BACKEND_PORT']
        else:
            # No balancer: serve the frontend ports directly.
            collector_port = os.environ['SYMPH_COLLECTOR_FRONTEND_PORT']
            sampler_port = os.environ['SYMPH_SAMPLER_FRONTEND_PORT']
        self._collector_server = ExperienceCollectorServer(
            host='localhost' if load_balanced else '*',
            port=collector_port,
            exp_handler=self._insert_wrapper,
            load_balanced=load_balanced,
            compress_before_send=compress_before_send)
        self._sampler_server = ZmqServer(
            host='localhost' if load_balanced else '*',
            port=sampler_port,
            bind=not load_balanced,
            serializer=get_serializer(compress_before_send),
            deserializer=get_deserializer(compress_before_send))
        self._sampler_server_thread = None
        self._evict_interval = evict_interval
        self._evict_thread = None
        self._setup_logging()

    def start_threads(self):
        """Start tensorplex reporting, collection, eviction, and sampling loops."""
        if self._has_tensorplex:
            self.start_tensorplex_thread()
        self._collector_server.start()
        if self._evict_interval:
            self.start_evict_thread()
        self._sampler_server_thread = self._sampler_server.start_loop(
            handler=self._sample_request_handler)

    def join(self):
        """Block until every started worker thread finishes."""
        self._collector_server.join()
        self._sampler_server_thread.join()
        if self._has_tensorplex:
            self._tensorplex_thread.join()
        if self._evict_interval:
            self._evict_thread.join()

    def insert(self, exp_dict):
        """
        Add a new experience to the replay.
        Includes passive evict logic if memory capacity is exceeded.

        Args:
            exp_dict: {[obs], action, reward, done, info}
        """
        raise NotImplementedError

    def sample(self, batch_size):
        """
        This function is called in _sample_handler for learner side Zmq request

        Args:
            batch_size

        Returns:
            a list of exp_tuples
        """
        raise NotImplementedError

    def evict(self):
        """
        Actively evict old experiences.
        """
        pass

    def start_sample_condition(self):
        """
        Tells the thread to start sampling only when this condition is met.
        For example, only when the replay memory has > 10K experiences.

        Returns:
            bool: whether to start sampling or not
        """
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    # ======================== internal methods ========================
    def _sample_request_handler(self, req):
        """
        Handle requests to the learner
        https://stackoverflow.com/questions/29082268/python-time-sleep-vs-event-wait
        Since we don't have external notify, we'd better just use sleep
        """
        batch_size = req
        U.assert_type(batch_size, int)
        # Busy-wait (with short sleeps) until the replay is warm enough.
        while not self.start_sample_condition():
            time.sleep(0.01)
        self.cumulative_sampled_count += batch_size
        self.cumulative_request_count += 1
        with self.sample_time.time():
            # Retry until a sample succeeds; underflow means data was evicted
            # between the condition check and the sample call.
            while True:
                try:
                    sample = self.sample(batch_size)
                    break
                except ReplayUnderFlowException:
                    time.sleep(1e-3)
        with self.serialize_time.time():
            # NOTE(review): serialization itself happens downstream; this
            # timer wraps only the return.
            return sample

    def _insert_wrapper(self, exp):
        """
        Allows us to do some book keeping in the base class
        """
        self.cumulative_collected_count += 1
        with self.insert_time.time():
            self.insert(exp)

    def _get_tensorplex_client(self, client_id):
        """Build a TensorplexClient from SYMPH_TENSORPLEX_SYSTEM_* env vars."""
        host = os.environ['SYMPH_TENSORPLEX_SYSTEM_HOST']
        port = os.environ['SYMPH_TENSORPLEX_SYSTEM_PORT']
        return TensorplexClient(
            client_id,
            host=host,
            port=port,
            serializer=self.config.tensorplex_config.serializer,
            deserializer=self.config.tensorplex_config.deserializer)

    def _setup_logging(self):
        """Initialize tensorplex client, counters, and timing recorders."""
        self.tensorplex = self._get_tensorplex_client('{}/{}'.format(
            'replay', self.index))
        self._tensorplex_thread = None
        self._has_tensorplex = self.config.tensorboard_display
        # Origin of all global steps
        self.init_time = time.time()
        # Number of experience collected by agents
        self.cumulative_collected_count = 0
        # Number of experience sampled by learner
        self.cumulative_sampled_count = 0
        # Number of sampling requests from the learner
        self.cumulative_request_count = 0
        # Timer for tensorplex reporting
        self.last_tensorplex_iter_time = time.time()
        # Last reported values used for speed computation
        self.last_experience_count = 0
        self.last_sample_count = 0
        self.last_request_count = 0
        self.insert_time = U.TimeRecorder(decay=0.99998)
        self.sample_time = U.TimeRecorder()
        self.serialize_time = U.TimeRecorder()
        # moving avrage of about 100s
        self.exp_in_speed = U.MovingAverageRecorder(decay=0.99)
        self.exp_out_speed = U.MovingAverageRecorder(decay=0.99)
        self.handle_sample_request_speed = U.MovingAverageRecorder(decay=0.99)

    def start_evict_thread(self):
        """Start the periodic eviction loop; raises if already running."""
        if self._evict_thread is not None:
            raise RuntimeError('evict thread already running')
        self._evict_thread = U.start_thread(self._evict_loop)
        return self._evict_thread

    def _evict_loop(self):
        """Call evict() every _evict_interval seconds, forever."""
        assert self._evict_interval
        while True:
            time.sleep(self._evict_interval)
            self.evict()

    def start_tensorplex_thread(self):
        """Start periodic tensorplex reporting; raises if already running."""
        if self._tensorplex_thread is not None:
            raise RuntimeError('tensorplex thread already running')
        self._tensorplex_thread = U.PeriodicWakeUpWorker(
            target=self.generate_tensorplex_report)
        self._tensorplex_thread.start()
        return self._tensorplex_thread

    def generate_tensorplex_report(self):
        """
        Generates stats to be reported to tensorplex
        """
        global_step = int(time.time() - self.init_time)
        # +1e-6 avoids division by zero on back-to-back invocations.
        time_elapsed = time.time() - self.last_tensorplex_iter_time + 1e-6
        cum_count_collected = self.cumulative_collected_count
        new_exp_count = cum_count_collected - self.last_experience_count
        self.last_experience_count = cum_count_collected
        cum_count_sampled = self.cumulative_sampled_count
        new_sample_count = cum_count_sampled - self.last_sample_count
        self.last_sample_count = cum_count_sampled
        cum_count_requests = self.cumulative_request_count
        new_request_count = cum_count_requests - self.last_request_count
        self.last_request_count = cum_count_requests
        # Smoothed rates over the reporting interval.
        exp_in_speed = self.exp_in_speed.add_value(new_exp_count / time_elapsed)
        exp_out_speed = self.exp_out_speed.add_value(new_sample_count /
                                                     time_elapsed)
        handle_sample_request_speed = self.handle_sample_request_speed.add_value(
            new_request_count / time_elapsed)
        insert_time = self.insert_time.avg
        sample_time = self.sample_time.avg
        serialize_time = self.serialize_time.avg
        core_metrics = {
            'num_exps': len(self),
            'total_collected_exps': cum_count_collected,
            'total_sampled_exps': cum_count_sampled,
            'total_sample_requests': self.cumulative_request_count,
            'exp_in_per_s': exp_in_speed,
            'exp_out_per_s': exp_out_speed,
            'requests_per_s': handle_sample_request_speed,
            'insert_time_s': insert_time,
            'sample_time_s': sample_time,
            'serialize_time_s': serialize_time,
        }
        if hasattr(self, 'per_sample_size'):
            core_metrics['per_sample_size_MB'] = self.per_sample_size / 1e6
        # Fraction-of-wallclock load estimates for each pipeline stage.
        serialize_load = serialize_time * handle_sample_request_speed / time_elapsed
        collect_exp_load = insert_time * exp_in_speed / time_elapsed
        sample_exp_load = sample_time * handle_sample_request_speed / time_elapsed
        system_metrics = {
            # +1 in denominators guards against division by zero.
            'lifetime_experience_utilization_percent':
                cum_count_sampled / (cum_count_collected + 1) * 100,
            'current_experience_utilization_percent':
                exp_out_speed / (exp_in_speed + 1) * 100,
            'serialization_load_percent': serialize_load * 100,
            'collect_exp_load_percent': collect_exp_load * 100,
            'sample_exp_load_percent': sample_exp_load * 100,
        }
        all_metrics = {}
        for k in core_metrics:
            all_metrics['.core/' + k] = core_metrics[k]
        for k in system_metrics:
            all_metrics['.system/' + k] = system_metrics[k]
        self.tensorplex.add_scalars(all_metrics, global_step=global_step)
        self.last_tensorplex_iter_time = time.time()
class ParameterServer(Thread):
    """
    Standalone script for PS node that runs in an infinite loop.
    The ParameterServer subscribes to learner to get the latest
        model parameters and serves these parameters to agents
    It implements a simple hash based caching mechanism to avoid
        serving duplicate parameters to agent
    """

    def __init__(
        self,
        publish_port,
        serving_port,
        supress_output=False,
    ):
        """
        Args:
            publish_port: where learner should send parameters to.
            load_balanced: whether multiple parameter servers are sharing the
                same address
        """
        Thread.__init__(self)
        self.publish_port = publish_port
        self.serving_port = serving_port
        self._supress_output = supress_output
        # storage
        self.parameters = None   # latest parameter payload from the learner
        self.param_info = None   # metadata dict; None until first publish
        # threads
        self._subscriber = None
        self._server = None
        self._subscriber_thread = None
        self._server_thread = None

    def run(self):
        """
        Run relative threads and wait until they finish (due to error)
        """
        if self._supress_output:
            # Redirect this thread's process-wide stdout/stderr to scratch files.
            sys.stdout = open('/tmp/' + 'latest' + ".out", "w")
            sys.stderr = open('/tmp/' + 'latest' + ".err", "w")
        self._param_reciever = ZmqServer(
            host='*',
            port=self.publish_port,
            serializer=U.serialize,
            deserializer=U.deserialize,
        )
        self._server = ZmqServer(
            host='*',
            port=self.serving_port,
            serializer=U.serialize,
            deserializer=U.deserialize,
        )
        self._subscriber_thread = self._param_reciever.start_loop(
            handler=self._set_storage, blocking=False)
        self._server_thread = self._server.start_loop(
            handler=self._handle_agent_request, blocking=False)
        logging.info('Parameter server started')
        # Joins only return if a loop dies.
        self._subscriber_thread.join()
        self._server_thread.join()

    def _set_storage(self, data):
        """Store a (parameters, info) pair pushed by the learner."""
        self.parameters, self.param_info = data
        logging.info('_set_storage received info: {}'.format(self.param_info))

    def _handle_agent_request(self, request):
        """Reply to agents' request for parameters."""
        request = PSRequest(**request)
        logging.info('Request received of type: %s', request.type)
        if self.param_info is None:
            # Nothing published yet.
            return PSResponse(type='not_ready')._asdict()
        if request.type == 'info':
            return PSResponse(type='info', info=self.param_info)._asdict()
        elif request.type == 'parameter':
            if request.hash is not None:
                if request.hash == self.param_info[
                        'hash']:  # param not changed
                    return PSResponse(type='no_change',
                                      info=self.param_info)._asdict()
            # Translate the agent's variable names into the learner's scope
            # before looking them up in the stored parameters.
            params_asked_for = {
                var_name: self.parameters[var_name.replace(
                    request.agent_scope + '/',
                    self.param_info['agent_scope'] + '/', 1)]
                for var_name in request.var_list
            }
            return PSResponse(type='parameters',
                              info=self.param_info,
                              parameters=params_asked_for)._asdict()
        else:
            raise ValueError('invalid request type received: %s' %
                             (request.type))
class ParameterServer(Process):
    """
    Standalone script for PS node that runs in an infinite loop.
    The ParameterServer subscribes to learner to get the latest
        model parameters and serves these parameters to agents
    It implements a simple hash based caching mechanism to avoid
        serving duplicate parameters to agent
    """

    def __init__(
        self,
        publisher_host,
        publisher_port,
        serving_host,
        serving_port,
        load_balanced=False,
    ):
        """
        Args:
            publisher_host, publisher_port: where learner publish parameters
            serving_host, serving_port: where to serve parameters to agents
            load_balanced: whether multiple parameter servers are sharing the
                same address
        """
        Process.__init__(self)
        self.publisher_host = publisher_host
        self.publisher_port = publisher_port
        self.serving_host = serving_host
        self.serving_port = serving_port
        self.load_balanced = load_balanced
        # storage
        self.parameters = None   # latest parameter payload from the learner
        self.param_info = None   # metadata dict; None until first publish
        # threads
        self._subscriber = None
        self._server = None
        self._subscriber_thread = None
        self._server_thread = None

    def run(self):
        """
        Run relative threads and wait until they finish (due to error)
        """
        self._subscriber = ZmqSub(
            host=self.publisher_host,
            port=self.publisher_port,
            topic='ps',
            deserializer=U.deserialize,
        )
        self._server = ZmqServer(
            host=self.serving_host,
            port=self.serving_port,
            serializer=U.serialize,
            deserializer=U.deserialize,
            # Connect instead of bind when sharing a load-balanced address.
            bind=not self.load_balanced,
        )
        self._subscriber_thread = self._subscriber.start_loop(
            handler=self._set_storage, blocking=False)
        self._server_thread = self._server.start_loop(
            handler=self._handle_agent_request, blocking=False)
        print('Parameter server started')
        # Joins only return if a loop dies.
        self._subscriber_thread.join()
        self._server_thread.join()

    def _set_storage(self, data):
        """Store a (parameters, info) pair published by the learner."""
        self.parameters, self.param_info = data

    def _handle_agent_request(self, request):
        """
        Reply to agents' request for parameters

        Args:
            request: 3 types
                - "info": (None, info)
                - "parameter": (param, info)
                - "parameter:<agent-hash>":
                    returns (None, None) if no parameter has been published
                    returns (None, info) if the hash of server side parameters
                        is the same as the agent's
                    otherwise returns (param, info)
        """
        if request == 'info':
            return None, self.param_info
        elif request.startswith('parameter'):
            if self.parameters is None:
                return None, None
            if ':' in request:
                # Agent sent its current hash; skip resending if unchanged.
                _, last_hash = request.split(':', 1)
                current_hash = self.param_info['hash']
                if last_hash == current_hash:  # param not changed
                    return None, self.param_info
                else:
                    return self.parameters, self.param_info
            else:
                return self.parameters, self.param_info
        else:
            raise ValueError('invalid request: ' + str(request))