예제 #1
0
def _build_worker_arg_parser():
    """Build the command-line parser for the land cover model worker.

    Returns:
        argparse.ArgumentParser: a parser whose namespace exposes ``verbose``,
        ``port``, ``gpu_id``, ``model_key``, ``model_fn``, ``fine_tune_layer``
        and ``num_models``.
    """
    parser = argparse.ArgumentParser(
        description="AI for Earth Land Cover Worker")

    parser.add_argument("-v",
                        "--verbose",
                        action="store_true",
                        help="Enable verbose debugging",
                        default=False)
    parser.add_argument("--port",
                        action="store",
                        type=int,
                        help="Port we are listenning on",
                        default=0)
    parser.add_argument("--gpu_id",
                        action="store",
                        dest="gpu_id",
                        type=int,
                        help="GPU to use",
                        required=False)
    parser.add_argument("--model_key",
                        action="store",
                        dest="model_key",
                        type=str,
                        help="Model key from models.json to use")
    # The PyTorch model types in main() read the options below. Without these
    # declarations, selecting those model types raised AttributeError.
    parser.add_argument("--model_fn",
                        action="store",
                        dest="model_fn",
                        type=str,
                        help="Model fn to use",
                        default=None)
    parser.add_argument("--fine_tune_layer",
                        action="store",
                        dest="fine_tune_layer",
                        type=int,
                        help="Layer of model to fine tune",
                        default=-2)
    # NOTE(review): the value TorchSmoothingCycleFineTune expects when this
    # is omitted (None) is not visible here -- confirm.
    parser.add_argument("--num_models",
                        action="store",
                        dest="num_models",
                        type=int,
                        help="Number of models (pytorch_smoothing_multiple)",
                        default=None)
    return parser


def main():
    """Run a land cover model worker selected by ``--model_key``.

    Parses the command line and configures logging under ``tmp/logs/``. It
    then sets the CUDA/TensorFlow environment variables and looks up the
    model config returned by ``load_models()``. Finally it builds the model
    for that config's ``"type"`` and serves it via ``OneShotServer`` on
    ``--port``.

    If ``--model_key`` is missing or not a key of the loaded configs, an
    error is logged and the function returns without starting a server.

    Raises:
        NotImplementedError: if the config's ``"type"`` is not one of
            ``keras_example``, ``pytorch_example`` or
            ``pytorch_smoothing_multiple``.
    """
    args = _build_worker_arg_parser().parse_args(sys.argv[1:])

    # Setup logging
    log_path = os.path.join(os.getcwd(), "tmp/logs/")
    setup_logging(log_path, "worker")

    # Number GPUs by PCI bus ID and expose only the requested one. An empty
    # CUDA_VISIBLE_DEVICES hides all GPUs when --gpu_id is not given.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = ("" if args.gpu_id is None else
                                          str(args.gpu_id))
    # Silence TensorFlow's C++ logging.
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

    model_configs = load_models()
    if args.model_key not in model_configs:
        LOGGER.error("'%s' is not recognized as a valid model, exiting...",
                     args.model_key)
        return
    model_config = model_configs[args.model_key]
    model_type = model_config["type"]

    if model_type == "keras_example":
        # The whole config entry (including "type") is forwarded as kwargs.
        model = KerasDenseFineTune(args.gpu_id, **model_config)
    elif model_type == "pytorch_example":
        model = TorchFineTuning(args.model_fn, args.gpu_id,
                                args.fine_tune_layer)
    elif model_type == "pytorch_smoothing_multiple":
        model = TorchSmoothingCycleFineTune(args.model_fn, args.gpu_id,
                                            args.fine_tune_layer,
                                            args.num_models)
    else:
        raise NotImplementedError(
            "The given model type is not implemented yet.")

    t = OneShotServer(MyService(model), port=args.port)
    t.start()
예제 #2
0
파일: worker.py 프로젝트: zgle-me/landcover
def main():
    """Start a land cover worker for the model family chosen by ``--model``.

    The command line selects the model family (``keras_dense`` or
    ``pytorch``), the model file, the layer to fine-tune and the GPU. After
    logging and the CUDA/TensorFlow environment variables are configured, the
    model is wrapped in ``MyService`` and served by ``OneShotServer`` on
    ``--port``.

    Raises:
        NotImplementedError: if ``--model`` names an unhandled model type.
            This is defensive only, since argparse restricts the choices.
    """
    # NOTE(review): MODEL is declared global but never assigned here.
    global MODEL

    parser = argparse.ArgumentParser(
        description="AI for Earth Land Cover Worker")
    parser.add_argument(
        "-v", "--verbose",
        action="store_true",
        help="Enable verbose debugging",
        default=False,
    )
    parser.add_argument(
        "--port",
        action="store",
        type=int,
        help="Port we are listenning on",
        default=0,
    )
    parser.add_argument(
        "--model",
        action="store",
        dest="model",
        choices=["keras_dense", "pytorch"],
        help="Model to use",
        required=True,
    )
    parser.add_argument(
        "--model_fn",
        action="store",
        dest="model_fn",
        type=str,
        help="Model fn to use",
        default=None,
    )
    parser.add_argument(
        "--fine_tune_layer",
        action="store",
        dest="fine_tune_layer",
        type=int,
        help="Layer of model to fine tune",
        default=-2,
    )
    parser.add_argument(
        "--gpu",
        action="store",
        dest="gpuid",
        type=int,
        help="GPU to use",
        required=False,
    )
    options = parser.parse_args(sys.argv[1:])

    # Logs go to ./tmp/logs/ under the current working directory.
    setup_logging(os.path.join(os.getcwd(), "tmp/logs/"), "worker")

    # Number GPUs by PCI bus ID; with no --gpu, expose no GPU at all.
    gpu = options.gpuid
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    if gpu is None:
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
    else:
        os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu)
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

    # Both model families share the same constructor arguments.
    if options.model == "keras_dense":
        model_cls = KerasDenseFineTune
    elif options.model == "pytorch":
        model_cls = TorchFineTuning
    else:
        raise NotImplementedError("The given model type is not implemented yet.")
    model = model_cls(options.model_fn, gpu, options.fine_tune_layer)

    server = OneShotServer(MyService(model), port=options.port)
    server.start()
예제 #3
0
def main():
    """Start the AI for Earth Land Cover web server.

    Steps, in order:

    1. Parse ``-v/--verbose``, ``--host``, ``--port`` and
       ``--disable_checkpoints``. The parsed namespace is also handed to
       ``SessionHandler``.
    2. Create the ``tmp/`` working directories.
    3. Set up logging.
    4. Start the global ``SESSION_HANDLER`` and its timeout monitor.
    5. Register the API and content routes on a bottle app.
    6. Wrap the app in beaker's file-backed session middleware.
    7. Run it on a cheroot WSGI server.

    ``server.stop()`` is always called on exit, including when
    ``server.start()`` raises.
    """
    global SESSION_HANDLER
    parser = argparse.ArgumentParser(description="AI for Earth Land Cover")

    parser.add_argument("-v",
                        "--verbose",
                        action="store_true",
                        help="Enable verbose debugging",
                        default=False)
    parser.add_argument("--host",
                        action="store",
                        dest="host",
                        type=str,
                        help="Host to bind to",
                        default="0.0.0.0")
    parser.add_argument("--port",
                        action="store",
                        dest="port",
                        type=int,
                        help="Port to listen on",
                        default=8080)

    parser.add_argument(
        "--disable_checkpoints",
        action="store_true",
        help="Disables the ability to save checkpoints on the server")

    args = parser.parse_args(sys.argv[1:])

    # Make sure the working directories exist before anything is set up that
    # may write into them; setup_logging() below is pointed at tmp/logs/.
    for directory in (
            "tmp/checkpoints/",
            "tmp/downloads/",
            "tmp/logs/",
            "tmp/output/",  # TODO: Remove this after we rework
            "tmp/session/",
    ):
        os.makedirs(directory, exist_ok=True)

    # Configure logging before starting the session handler, so the handler
    # and its monitor start with logging already in place.
    log_path = os.path.join(os.getcwd(), "tmp/logs/")
    setup_logging(log_path, "server")

    # Create session factory to handle incoming requests
    SESSION_HANDLER = SessionHandler(args)
    SESSION_HANDLER.start_monitor(SESSION_TIMEOUT_SECONDS)

    # Setup the bottle server
    app = bottle.Bottle()

    app.add_hook("after_request", enable_cors)
    # Before every request, check that there are no session issues.
    app.add_hook("before_request", manage_sessions)

    # API endpoints. Each is registered for POST and also for the OPTIONS
    # request that is fired ahead of it.
    # TODO: all of our web requests from index.html fire an OPTIONS call
    # because of https://stackoverflow.com/questions/1256593/why-am-i-getting-an-options-request-instead-of-a-get-request,
    # we should fix this
    api_routes = [
        ("/predPatch", pred_patch),
        ("/predTile", pred_tile),
        ("/downloadAll", download_all),
        ("/getInput", get_input),
        ("/recordCorrection", record_correction),
        ("/retrainModel", retrain_model),
        ("/resetModel", reset_model),
        ("/doUndo", do_undo),
        ("/createSession", create_session),
        ("/killSession", kill_session),
        ("/getSessionStatus", get_session_status),
        # Checkpoints
        ("/createCheckpoint", checkpoint_wrapper(args.disable_checkpoints)),
    ]
    for path, callback in api_routes:
        app.route(path, method="OPTIONS", callback=do_options)
        app.route(path, method="POST", callback=callback)

    # GET routes, registered in this order. The catch-all pattern stays last
    # so the more specific routes before it are registered first.
    get_routes = [
        ("/getCheckpoints", get_checkpoints),
        # Sessions
        ("/whoami", whoami),
        # Content paths
        ("/", get_landing_page),
        ("/data/basemaps/<filepath:re:.*>", get_basemap_data),
        ("/data/zones/<filepath:re:.*>", get_zone_data),
        ("/tmp/downloads/<filepath:re:.*>", get_downloads),
        ("/favicon.ico", get_favicon),
        ("/<filepath:re:.*>", get_everything_else),
    ]
    for path, callback in get_routes:
        app.route(path, method="GET", callback=callback)

    manage_session_folders()
    session_opts = {
        'session.type': 'file',
        #'session.cookie_expires': 3000, # session cookie
        'session.data_dir': SESSION_FOLDER,
        'session.auto': True
    }
    app = beaker.middleware.SessionMiddleware(app, session_opts)

    server = cheroot.wsgi.Server((args.host, args.port), app)
    server.max_request_header_size = 2**13  # 8 KiB
    server.max_request_body_size = 2**27  # 128 MiB

    LOGGER.info("Server initialized")
    try:
        server.start()
    finally:
        server.stop()