def main(unused_argv):
    """Start a MAML blackbox worker gRPC server and block until interrupted.

    Builds the experiment config from flags, constructs the blackbox object,
    registers either a zero-order or a first-order evaluation servicer
    (selected by ``config.algorithm``) and serves it on a secure port.

    Args:
      unused_argv: Positional arguments left over after flag parsing; ignored.

    Raises:
      ValueError: If ``config.algorithm`` is neither "zero_order" nor
        "first_order". Previously such a config silently started a server
        with no servicer registered, which then ran forever doing nothing.
    """
    base_config = config_util.get_config(urdf_root=FLAGS.urdf_data_path)
    config = config_util.generate_config(
        base_config, current_time_string=FLAGS.current_time_string)

    # Fail fast, before creating the server or the blackbox object.
    if config.algorithm not in ("zero_order", "first_order"):
        raise ValueError(
            "Unknown algorithm: {!r}; expected 'zero_order' or "
            "'first_order'.".format(config.algorithm))

    server_creds = loas2.loas2_server_credentials()
    port = FLAGS.port
    if not config.run_on_borg:
        # Off Borg, each worker gets its own local port derived from its id.
        port = 20000 + FLAGS.server_id
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=100),
                         ports=(port, ))

    blackbox_object = config.blackbox_object_fn()

    # Seed per worker so that each worker draws a different random stream.
    np.random.seed(FLAGS.server_id)

    if config.algorithm == "zero_order":
        # The first `test_workers` server ids evaluate the held-out test
        # tasks; all remaining workers evaluate training tasks.
        if FLAGS.server_id < config.test_workers:
            worker_mode = "Test"
            task_ids = range(config.train_set_size,
                             config.train_set_size + config.test_set_size)
        else:
            worker_mode = "Train"
            task_ids = range(config.train_set_size)

        servicer = blackbox_maml_objects.GeneralMAMLBlackboxWorker(
            worker_id=FLAGS.server_id,
            blackbox_object=blackbox_object,
            task_ids=task_ids,
            task_batch_size=config.task_batch_size,
            worker_mode=worker_mode)
        zero_order_pb2_grpc.add_EvaluationServicer_to_server(servicer, server)
    else:  # "first_order" (validated above)
        # First-order workers hold every task, train and test alike.
        tasks = [
            config.make_task_fn(s)
            for s in range(config.train_set_size + config.test_set_size)
        ]
        servicer = blackbox_maml_objects.GradientMAMLWorker(
            FLAGS.server_id, blackbox_object=blackbox_object, tasks=tasks)
        first_order_pb2_grpc.add_EvaluationServicer_to_server(servicer, server)

    server.add_secure_port("[::]:{}".format(port), server_creds)
    server.start()
    print("Start server {}".format(FLAGS.server_id))

    # server.start() does not block, so keep the main thread alive here.
    try:
        while True:
            time.sleep(_ONE_DAY_IN_SECONDS)
    except KeyboardInterrupt:
        server.stop(0)
# Example #2 (score: 0)
def main(unused_argv):
  """Serve ARS parameter evaluation over gRPC until interrupted.

  Creates a gRPC server backed by a 10-thread pool, registers a
  ParameterEvaluationServicer for FLAGS.config_name and binds it to a
  secure port using LOAS2 credentials. It then parks the main thread.
  On KeyboardInterrupt the server is stopped with no grace period.

  Args:
    unused_argv: Leftover command-line arguments; ignored.
  """
  active = []
  creds = loas2.loas2_server_credentials()
  default_port = FLAGS.port
  # Outside Borg, each worker derives a distinct local port from its id.
  listen_port = default_port if FLAGS.run_on_borg else 20000 + FLAGS.server_id

  pool = futures.ThreadPoolExecutor(max_workers=10)
  rpc_server = grpc.server(pool, ports=(listen_port,))

  evaluator = ars_evaluation_service.ParameterEvaluationServicer(
      FLAGS.config_name, worker_id=FLAGS.server_id)
  register = ars_evaluation_service_pb2_grpc.add_EvaluationServicer_to_server
  register(evaluator, rpc_server)

  rpc_server.add_secure_port("[::]:{}".format(listen_port), creds)
  active.append(rpc_server)
  rpc_server.start()
  print("Start server {}".format(FLAGS.server_id))

  # start() returns immediately; sleep here so the process stays up.
  try:
    while True:
      time.sleep(_ONE_DAY_IN_SECONDS)
  except KeyboardInterrupt:
    for running in active:
      running.stop(0)
# Example #3 (score: 0)
def main(unused_argv):
    """Run the ARS parameter-evaluation gRPC worker until Ctrl-C.

    The worker serves ``ParameterEvaluationServicer`` for
    ``FLAGS.config_name`` on a secure port. The port is ``FLAGS.port`` on
    Borg, otherwise ``20000 + FLAGS.server_id``.

    Args:
      unused_argv: Remaining command-line arguments; not used.
    """
    running_servers = []
    credentials = loas2.loas2_server_credentials()

    listen_port = FLAGS.port
    if not FLAGS.run_on_borg:
        listen_port = 20000 + FLAGS.server_id

    executor = futures.ThreadPoolExecutor(max_workers=10)
    grpc_server = grpc.server(executor, ports=(listen_port, ))

    evaluation_servicer = ars_evaluation_service.ParameterEvaluationServicer(
        FLAGS.config_name, worker_id=FLAGS.server_id)
    ars_evaluation_service_pb2_grpc.add_EvaluationServicer_to_server(
        evaluation_servicer, grpc_server)

    bind_address = "[::]:{}".format(listen_port)
    grpc_server.add_secure_port(bind_address, credentials)
    running_servers.append(grpc_server)
    grpc_server.start()
    print("Start server {}".format(FLAGS.server_id))

    # Keep the main thread parked; the server does its work on the
    # executor passed to grpc.server above.
    try:
        while True:
            time.sleep(_ONE_DAY_IN_SECONDS)
    except KeyboardInterrupt:
        for srv in running_servers:
            srv.stop(0)
# Example #4 (score: 0)
def add_port(server, port=None):
    """Bind ``server`` to a secure port using LOAS2 server credentials.

    Args:
      server: Server object exposing ``add_secure_port(address, credentials)``.
      port: Port to listen on. When omitted (``None``), falls back to
        ``FLAGS.env_service_port``, which matches the original single-argument
        behavior.

    Returns:
      Whatever ``server.add_secure_port`` returns (for standard gRPC servers
      this is the bound port number).
    """
    if port is None:
        port = FLAGS.env_service_port
    server_credentials = loas2.loas2_server_credentials()
    return server.add_secure_port(
        _ADDRESS_FORMAT.format(port), server_credentials)
# Example #5 (score: 0)
def add_port(server, port):
  """Attach a LOAS2-secured listening address on `port` to `server`.

  Args:
    server: Server object exposing `add_secure_port(address, credentials)`.
    port: Port substituted into `_ADDRESS_FORMAT`.

  Returns:
    The value returned by `server.add_secure_port` (for standard gRPC
    servers, the bound port number).
  """
  creds = loas2.loas2_server_credentials()
  address = _ADDRESS_FORMAT.format(port)
  return server.add_secure_port(address, creds)