예제 #1
0
def test_client_server_eventloop():
    """Round-trip N pickled requests through a non-blocking ZmqServer loop.

    The server handler replies with ``x + 1`` for every request and calls
    ``server.stop()`` when it sees the last request (``N - 1``). This lets
    the ``server_thread.join()`` below return once all requests are served.
    """
    host = '127.0.0.1'
    port = 7001
    N = 20

    server = ZmqServer(host=host,
                       port=port,
                       serializer='pickle',
                       deserializer='pickle')
    all_requests = []

    def handler(x):
        # Stop on the final request; the test still expects its reply
        # (N) to be delivered to the client.
        if x == N - 1:
            server.stop()
        all_requests.append(x)
        return x + 1

    server_thread = server.start_loop(handler, blocking=False)

    client = ZmqClient(host=host,
                       port=port,
                       serializer='pickle',
                       deserializer='pickle')
    all_responses = [client.request(i) for i in range(N)]

    server_thread.join()

    assert all_requests == list(range(N))
    assert all_responses == [n + 1 for n in range(N)]
예제 #2
0
  def __init__(
      self,
      host,
      port,
      agent_scope,
      timeout=2,
      not_ready_sleep=2,
  ):
    """Create a client for the parameter server.

    Args:
      host: parameter server host.
      port: parameter server port.
      agent_scope: scope identifier, stored as `_agent_scope`.
      timeout: timeout forwarded to the underlying ZmqClient (how long the
        client waits if the parameter server is not available).
      not_ready_sleep: stored as `_not_ready_sleep`; presumably the pause
        before retrying when the server is not ready — TODO confirm.
    """
    self.host, self.port, self.timeout = host, port, timeout
    self._agent_scope = agent_scope
    self._not_ready_sleep = not_ready_sleep

    # Connection / cache bookkeeping.
    self.alive = False
    self._last_hash = ''
    self._current_info = {}

    self._client = ZmqClient(
        host=host,
        port=port,
        timeout=timeout,
        serializer=U.serialize,
        deserializer=U.deserialize,
    )
예제 #3
0
def client_fn():
    """Send the module-level ``MSG`` once, then print 'Done'.

    Connection settings are read from the module-level ``args`` namespace
    (``ip``, ``port``, ``timeout``); payloads use the pyarrow codec.
    """
    codec = 'pyarrow'
    conn = ZmqClient(
        host=args.ip,
        port=args.port,
        timeout=args.timeout,
        serializer=codec,
        deserializer=codec,
    )
    conn.request(MSG)
    print('Done')
예제 #4
0
 def client(N):
     """Send requests 0..N-1 and store the list of replies in ``data[1]``.

     ``host``, ``port`` and ``data`` come from the enclosing scope.
     """
     conn = ZmqClient(host=host,
                      port=port,
                      serializer='pickle',
                      deserializer='pickle')
     data[1] = [conn.request(i) for i in range(N)]
예제 #5
0
    def __init__(self, host, port, agent_scope):
        """Connect the publisher to the parameter server.

        Args:
            host: IP of the parameter server.
            port: port connected to the server's pub socket.
            agent_scope: scope identifier, stored as ``_agent_scope``.
        """
        self._publisher = ZmqClient(
            host=host,
            port=port,
            timeout=2,
            serializer=U.serialize,
            deserializer=U.deserialize,
        )
        self._agent_scope = agent_scope
        self.alive = False
예제 #6
0
def get_irs_client(timeout, auto_detect_proxy=True):
    """Return a pyarrow-encoded ZmqClient connected to the IRS.

    If ``auto_detect_proxy`` is true and SYMPH_IRS_PROXY_HOST is set, the
    proxy host/port environment variables are used; otherwise
    SYMPH_IRS_HOST / SYMPH_IRS_PORT. A missing variable raises KeyError.
    """
    env = os.environ
    if auto_detect_proxy and 'SYMPH_IRS_PROXY_HOST' in env:
        host_var, port_var = 'SYMPH_IRS_PROXY_HOST', 'SYMPH_IRS_PROXY_PORT'
    else:
        host_var, port_var = 'SYMPH_IRS_HOST', 'SYMPH_IRS_PORT'

    host = env[host_var]
    port = env[port_var]
    return ZmqClient(
        host=host,
        port=port,
        serializer='pyarrow',
        deserializer='pyarrow',
        timeout=timeout,
    )
예제 #7
0
 def client():
     """Check that a timed-out request does not break later requests.

     The server (defined elsewhere) is expected to answer request-1 and
     request-3 within the 0.3 timeout, but not request-2.
     """
     conn = ZmqClient(host=host,
                      port=port,
                      timeout=0.3,
                      serializer='pickle',
                      deserializer='pickle')
     first = conn.request('request-1')
     assert first == 'received-request-1'
     with pytest.raises(ZmqTimeoutError):
         conn.request('request-2')
     third = conn.request('request-3')
     assert third == 'received-request-3'
예제 #8
0
파일: learner.py 프로젝트: aravic/liaison
 def _setup_spec_client(self):
     """Create ``self.spec_client`` for the spec server.

     Host and port are read from SYMPH_SPEC_HOST / SYMPH_SPEC_PORT (KeyError
     if unset); payloads go through U.pickle_serialize / pickle_deserialize.
     """
     spec_host = os.environ['SYMPH_SPEC_HOST']
     spec_port = os.environ['SYMPH_SPEC_PORT']
     self.spec_client = ZmqClient(
         host=spec_host,
         port=spec_port,
         serializer=U.pickle_serialize,
         deserializer=U.pickle_deserialize,
         timeout=4,
     )
예제 #9
0
class ParameterPublisher(object):
    """Publishes parameters from the learner side.

    Parameters are sent as ``(var_dict, info)`` requests through a
    ``ZmqClient``; each publish is retried until a request succeeds.
    """

    def __init__(self, host, port, agent_scope):
        """
        Args:
            host: IP of the parameter server.
            port: port connected to the parameter server's pub socket.
            agent_scope: scope recorded in the info dict of every publish.
        """
        self._agent_scope = agent_scope
        # Whether the last request went through; only changes are logged.
        self.alive = False

        self._publisher = ZmqClient(
            host=host,
            port=port,
            timeout=2,
            serializer=U.serialize,
            deserializer=U.deserialize,
        )

    def publish(self, iteration, var_dict):
        """Publish model parameters together with metadata.

        Called by the learner. Blocks until a request completes without a
        ZmqTimeoutError; each timeout is reported through
        ``on_fetch_parameter_failed`` and the same request is re-sent.

        Args:
            iteration: current learning iteration.
            var_dict: dict of available variables.
        """
        info = {
            'agent_scope': self._agent_scope,
            'time': time.time(),
            'iteration': iteration,
            'variable_list': list(var_dict.keys()),
            'hash': U.pyobj_hash(var_dict),
        }
        while True:
            try:
                self._publisher.request((var_dict, info))
            except ZmqTimeoutError:
                self.on_fetch_parameter_failed()
            else:
                break
        self.on_fetch_parameter_success()

    def on_fetch_parameter_failed(self):
        """Called when a publish request times out.

        Marks the publisher as not alive; logs only on the alive -> dead
        transition.
        """
        if self.alive:
            self.alive = False
            logging.info('Parameter client request timed out')

    def on_fetch_parameter_success(self):
        """Called after a publish request succeeds.

        Marks the publisher as alive; logs only on the dead -> alive
        transition.
        """
        if not self.alive:
            self.alive = True
            logging.info('Parameter client came back alive')
예제 #10
0
 def _register_commands_with_irs(self, commands, host, port):
     """Send ``commands`` to the IRS at ``host:port`` as 'register_commands'.

     The pyarrow-encoded ZmqClient used is kept on ``self._cli``.
     """
     codec = 'pyarrow'
     self._cli = ZmqClient(host=host,
                           port=port,
                           serializer=codec,
                           deserializer=codec)
     payload = ['register_commands', [], commands]
     self._cli.request(payload)
예제 #11
0
class TurrealParser(SymphonyParser):
    """Symphony command-line parser that launches experiments in tmux.

    Provides the ``create`` subcommand, registers the experiment and its
    metadata with an XManager client, and launches experiments on the tmux
    cluster stored in ``self.cluster``.
    """

    def create_cluster(self, server_name):
        """Create a tmux cluster named ``server_name`` as ``self.cluster``."""
        self.cluster = Cluster.new('tmux', server_name=server_name)

    def setup(self):
        """Run the base-class setup, then register the ``create`` command."""
        super().setup()
        self._setup_create()

    def _setup_create(self):
        """Add the ``create`` subcommand (alias ``c``) and its arguments."""
        parser = self.add_subparser('create', aliases=['c'])
        self._add_experiment_name(parser, positional=False)
        parser.add_argument('--results_folder',
                            '-r',
                            required=True,
                            type=str,
                            help='Results folder.')
        parser.add_argument('--n_work_units', type=int, default=1)
        self._add_dry_run(parser)

    # ==================== helpers ====================
    def _add_dry_run(self, parser):
        """Add the ``-dr/--dry-run`` flag to ``parser``."""
        parser.add_argument(
            '-dr',
            '--dry-run',
            action='store_true',
            help='print the kubectl command without actually executing it.')

    def _process_experiment_name(self, experiment_name):
        """Normalize ``experiment_name`` so it can be used as a DNS name.

        Lower-cases the name and replaces '.' and '_' with '-'; prints a
        notice when the name had to be changed.
        """
        new_name = experiment_name.lower().replace('.', '-').replace('_', '-')
        if new_name != experiment_name:
            print('experiment name string has been fixed: {} -> {}'.format(
                experiment_name, new_name))
        return new_name

    def _setup_xmanager_client(self, args):
        """Connect ``self._xm_client`` to the XManager server in ``args``."""
        self._xm_client = get_xmanager_client(
            host=args.xmanager_server_host,
            port=int(args.xmanager_server_port),
            timeout=4)

    def _register_exp(self, exp_name):
        """Register ``exp_name`` with XManager and return the client's result.

        ``action_create`` uses the returned value as the experiment id.
        """
        return self._xm_client.register(name=exp_name)

    def _record_launch_command(self, exp_id):
        """Record the current command line (``sys.argv``) as XManager metadata."""
        self._xm_client.record_metadata(exp_id=exp_id,
                                        command=' '.join(sys.argv))

    def _record_hyper_config(self, exp_id, hyper_configs):
        """Record ``hyper_configs`` as XManager metadata for ``exp_id``."""
        self._xm_client.record_metadata(exp_id=exp_id,
                                        hyper_config=hyper_configs)

    def _register_commands_with_irs(self, commands, host, port):
        """Send ``commands`` to the IRS at ``host:port`` ('register_commands')."""
        self._cli = ZmqClient(host=host,
                              port=port,
                              serializer='pyarrow',
                              deserializer='pyarrow')
        self._cli.request(['register_commands', [], commands])

    def get_cluster(self):
        """Return the cluster most recently created by ``create_cluster``."""
        return self.cluster

    def action_create(self, args):
        """Handle the ``create`` subcommand.

        Normalizes the experiment name, registers the experiment with
        XManager, and records the launch command. It then stores the
        results folder, the remainder args and the experiment id for a
        later ``launch()`` call. Command-line args meant for the config
        script go after "--".
        """
        self.experiment_name = self._process_experiment_name(
            args.experiment_name)

        self._setup_xmanager_client(args)
        exp_id = self._register_exp(self.experiment_name)
        print(f'Registered Experiment ID: {exp_id}')
        # Without an explicit tmux server name, replace any previously
        # created cluster (main() creates one before dispatching) with one
        # named after the experiment id.
        if args.tmux_server_name is None:
            self.create_cluster(f'{exp_id}')
        self._record_launch_command(exp_id)

        # results_folder may contain {experiment_name} / {exp_id} fields.
        self.results_folder = args.results_folder.format(
            experiment_name=self.experiment_name, exp_id=exp_id)
        self.remainder_args = args.remainder
        self.exp_id = exp_id

    def launch(self, experiments, exp_configs, hyper_configs):
        """Launch ``experiments`` on the cluster.

        Reads ``exp_id``, ``results_folder``, ``remainder_args`` and
        ``experiment_name``, which are set by ``action_create``.

        Tasks:
          1. Record ``hyper_configs`` with the XManager client.
          2. Set PREAMBLE_CMDS on each experiment and append the generated
             launch command to every one of its processes.
          3. Launch each experiment on the cluster.

        ``experiments`` and ``exp_configs`` are paired with ``zip``; extra
        entries in the longer sequence are ignored.
        """
        exp_id = self.exp_id
        print('Experiment ID: %d' % exp_id)
        print('Results folder: %s' % (self.results_folder))
        # Copy first: extending self.remainder_args in place would mutate it
        # (and args.remainder, which it aliases), so a repeated launch()
        # would pass these flags twice.
        algorithm_args = list(self.remainder_args)
        algorithm_args += ["--experiment_id", str(exp_id)]
        algorithm_args += ["--experiment_name", self.experiment_name]
        algorithm_args += ["--results_folder", self.results_folder]
        self._record_hyper_config(exp_id, hyper_configs)

        # NOTE(review): the generated commands are not registered with the
        # IRS here, although _register_commands_with_irs is available.
        for exp, exp_config in zip(experiments, exp_configs):
            cmd_gen = CommandGenerator(executable='liaison/launch/main.py',
                                       config_commands=algorithm_args +
                                       exp_config)

            exp.set_preamble_cmds(PREAMBLE_CMDS)
            all_procs = [
                proc for pg in exp.list_process_groups()
                for proc in pg.list_processes()
            ] + list(exp.list_processes())

            for proc in all_procs:
                proc.append_cmds([cmd_gen.get_command(proc.name)])

            self.cluster.launch(exp)

    def main(self, argv):
        """Parse ``argv`` and dispatch to the selected subcommand.

        Everything after a single "--" is kept verbatim as ``args.remainder``.
        Each external parser consumes the arguments it recognizes, in order.
        The master parser then parses the rest.

        Returns:
            (args.func, list of namespaces returned by the external parsers)
        """
        assert argv.count('--') <= 1, \
            'command line can only have at most one "--"'

        argv = list(argv)
        if '--' in argv:
            idx = argv.index('--')
            remainder = argv[idx + 1:]
            argv = argv[:idx]
            has_remainder = True  # even if remainder itself is empty
        else:
            remainder = []
            has_remainder = False
        master_args = argv

        args_l = []
        for parser in self._external_parsers:
            args, unknown = parser.parse_known_args(master_args)
            master_args = unknown
            args_l.append(args)

        assert '--' not in master_args
        args = self.master_parser.parse_args(master_args)
        args.remainder = remainder
        args.has_remainder = has_remainder

        self.create_cluster(args.tmux_server_name or 'default')
        args.func(args)
        return args.func, args_l
예제 #12
0
class ParameterClient(object):
    """
        Agent-side client that asks parameter servers for the latest
        parameters.
    """

    def __init__(self, host, port, timeout=2):
        """
        Args:
            host: parameter server host
            port: parameter server port
            timeout: timeout handed to the underlying ZmqClient (how long
                the client waits if the parameter server is not available)
        """
        self.host, self.port, self.timeout = host, port, timeout
        self._current_info = {}
        self._last_hash = ''
        self.alive = False

        self._client = ZmqClient(host=host,
                                 port=port,
                                 timeout=timeout,
                                 serializer=U.serialize,
                                 deserializer=U.deserialize)

    def fetch_parameter_with_info(self, force_update=False):
        """
            Called by the agent to retrieve parameters.

            By default sends 'parameter:<last hash>' so the server can skip
            parameters whose hash has not changed; ``force_update`` sends a
            plain 'parameter' request instead.

        Args:
            force_update: ignore the cached hash and request parameters
                unconditionally

        Returns:
            (param, info), or (None, None) on timeout or when the reply's
            info is None
        """
        query = 'parameter' if force_update else 'parameter:' + self._last_hash
        try:
            response = self._client.request(query)
        except ZmqTimeoutError:
            self.on_fetch_parameter_failed()
            return None, None

        self.on_fetch_parameter_success()
        param, info = response
        if info is not None:
            self._last_hash = info['hash']
            return param, info
        return None, None

    def fetch_info(self):
        """
            Fetch the metadata of the parameters on the parameter server.

        Returns:
            the info part of the reply, or None on timeout
        """
        try:
            response = self._client.request('info')
        except ZmqTimeoutError:
            self.on_fetch_parameter_failed()
            return None
        else:
            self.on_fetch_parameter_success()
        _, info = response
        return info

    def on_fetch_parameter_failed(self):
        """
            Called when a request to the parameter server times out;
            reports only the alive -> dead transition.
        """
        if not self.alive:
            return
        self.alive = False
        print('Parameter client request timed out')

    def on_fetch_parameter_success(self):
        """
            Called when a request to the parameter server succeeds;
            reports only the dead -> alive transition.
        """
        if self.alive:
            return
        self.alive = True
        print('Parameter client came back alive')
예제 #13
0
class ParameterClient(object):
  """On the agent side, requests the latest parameters from a parameter server.

  Requests are sent as ``PSRequest(...)._asdict()``; replies are unpacked
  into ``PSResponse``.
  """

  def __init__(
      self,
      host,
      port,
      agent_scope,
      timeout=2,
      not_ready_sleep=2,
  ):
    """
    Args:
      host: parameter server host.
      port: parameter server port.
      agent_scope: scope sent with every request.
      timeout: timeout forwarded to the underlying ZmqClient (how long the
        client waits if the parameter server is not available).
      not_ready_sleep: seconds to sleep before retrying when the server
        replies `not_ready` to a parameter request.
    """
    self.host = host
    self.port = port
    self.timeout = timeout
    self._current_info = {}
    self._last_hash = ''
    self.alive = False
    self._agent_scope = agent_scope
    self._not_ready_sleep = not_ready_sleep

    self._client = ZmqClient(host=self.host,
                             port=self.port,
                             timeout=self.timeout,
                             serializer=U.serialize,
                             deserializer=U.deserialize)

  def _on_timeout(self):
    """Log a ZMQ timeout and mark the client as not alive."""
    logging.info('ZmQ timed out.')
    self.on_fetch_parameter_failed()

  def _request(self, req_type, req_hash, var_list):
    """Send one PSRequest and return the raw reply.

    Raises:
      ZmqTimeoutError: if the underlying client times out.
    """
    return self._client.request(
        PSRequest(type=req_type,
                  hash=req_hash,
                  var_list=var_list,
                  agent_scope=self._agent_scope)._asdict())

  def _parse_info_response(self, response):
    """Mark the client alive and return the `info` field of an info reply."""
    self.on_fetch_parameter_success()
    response = PSResponse(**response)
    assert response.type == 'info' or response.type == 'not_ready'
    return response.info

  def fetch_parameter_with_info(self, var_names, force_update=False):
    """Fetch parameters, retrying on timeouts and `not_ready` replies.

    Args:
      var_names: variable names to request (sent as `var_list`).
      force_update: if True, send hash None; the reply is then asserted
        not to be `no_change`.

    Returns:
      (None, info) when the server reports `no_change`; otherwise
      (parameters, info), after caching info['hash'].
    """
    use_hash = None if force_update else self._last_hash

    while True:
      try:
        response = self._request('parameter', use_hash, var_names)
      except ZmqTimeoutError:
        self._on_timeout()
        continue

      self.on_fetch_parameter_success()
      response = PSResponse(**response)

      if use_hash is None:
        # A forced fetch must never be answered with `no_change`.
        assert response.type != 'no_change'

      if response.type == 'not_ready':
        logging.info('PS not ready.')
        time.sleep(self._not_ready_sleep)

      elif response.type == 'no_change':
        assert self._last_hash == response.info['hash']
        return None, response.info

      else:
        self._last_hash = response.info['hash']
        return response.parameters, response.info

  def fetch_info(self):
    """Fetch parameter metadata, retrying indefinitely on timeouts.

    Returns:
      The `info` field of the reply (expected to be None when the server
      reports `not_ready` — confirm against the server).
    """
    while True:
      try:
        response = self._request('info', self._last_hash, None)
      except ZmqTimeoutError:
        self._on_timeout()
        continue
      return self._parse_info_response(response)

  def fetch_info_no_retry(self):
    """Fetch parameter metadata with a single attempt (no retries).

    Returns:
      None if the request times out; otherwise the `info` field of the
      reply (expected to be None when the server reports `not_ready` —
      confirm against the server).
    """
    try:
      response = self._request('info', self._last_hash, None)
    except ZmqTimeoutError:
      self._on_timeout()
      return None
    return self._parse_info_response(response)

  def on_fetch_parameter_failed(self):
    """Called when a request to the parameter server times out.

    Logs only on the alive -> dead transition.
    """
    if self.alive:
      self.alive = False
      logging.info('Parameter client request timed out')

  def on_fetch_parameter_success(self):
    """Called when a request to the parameter server succeeds.

    Logs only on the dead -> alive transition.
    """
    if not self.alive:
      self.alive = True
      logging.info('Parameter client came back alive')