예제 #1
0
파일: args.py 프로젝트: weblfe/elasticdl
def parse_ps_args(ps_args=None):
    """Parse the command line of an ElasticDL parameter-server (PS) process.

    Besides the PS-specific flags (``--ps_id``, ``--port``,
    ``--master_addr``), the options registered by ``add_common_params`` and
    ``add_train_params`` are accepted as well.

    Args:
        ps_args: Optional list of argument strings. When ``None``, argparse
            reads ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: the parsed arguments. Arguments the parser does
        not recognise are logged as a warning instead of causing an error.
        If ``use_async`` is set and ``grads_to_wait`` exceeds 1,
        ``grads_to_wait`` is forced down to 1 (with a warning).
    """
    ps_parser = argparse.ArgumentParser(description="ElasticDL PS")
    ps_parser.add_argument(
        "--ps_id", type=int, required=True, help="ID unique to the PS"
    )
    ps_parser.add_argument(
        "--port", type=int, required=True, help="Port used by the PS pod"
    )
    ps_parser.add_argument("--master_addr", help="Master ip:port")

    # Options shared with the other ElasticDL entry points.
    add_common_params(ps_parser)
    add_train_params(ps_parser)
    # TODO: add PS replica address for RPC stub creation

    parsed, leftovers = ps_parser.parse_known_args(args=ps_args)
    print_args(parsed, groups=ALL_ARGS_GROUPS)
    if leftovers:
        logger.warning("Unknown arguments: %s", leftovers)

    # Asynchronous SGD never waits on more than one gradient here.
    must_clamp_grads = parsed.use_async and parsed.grads_to_wait > 1
    if must_clamp_grads:
        parsed.grads_to_wait = 1
        logger.warning(
            "grads_to_wait is set to 1 when using asynchronous SGD."
        )
    return parsed
예제 #2
0
def parse_master_args(master_args=None):
    """Parse and validate the command line of the ElasticDL master.

    The master accepts ``--port``, ``--worker_image`` and
    ``--prediction_data`` plus everything registered by
    ``add_common_params`` and ``add_train_params``.

    Args:
        master_args: Optional list of argument strings. When ``None``,
            argparse reads ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: the parsed arguments. Unrecognised arguments are
        only logged as a warning. With synchronous SGD (``use_async``
        falsy), ``get_model_steps`` is forced down to 1; with asynchronous
        SGD, ``grads_to_wait`` is forced down to 1. Each adjustment is
        logged.

    Raises:
        ValueError: if none of ``training_data``, ``validation_data`` and
            ``prediction_data`` is given; if ``prediction_data`` is combined
            with training or validation data; or if ``prediction_data`` is
            given without ``checkpoint_dir_for_init``.
    """
    master_parser = argparse.ArgumentParser(description="ElasticDL Master")
    master_parser.add_argument(
        "--port",
        type=pos_int,
        default=50001,
        help="The listening port of master",
    )
    master_parser.add_argument(
        "--worker_image", default=None, help="Docker image for workers"
    )
    master_parser.add_argument(
        "--prediction_data",
        default="",
        help="Either the data directory that contains RecordIO files "
        "or an ODPS table name used for prediction.",
    )
    add_common_params(master_parser)
    add_train_params(master_parser)

    parsed, leftovers = master_parser.parse_known_args(args=master_args)
    print_args(parsed, groups=ALL_ARGS_GROUPS)
    if leftovers:
        logger.warning("Unknown arguments: %s", leftovers)

    def _is_blank(value):
        # A data source counts as absent when it is None or an empty string.
        return value is None or value == ""

    data_sources = (
        parsed.training_data,
        parsed.validation_data,
        parsed.prediction_data,
    )
    if all(_is_blank(source) for source in data_sources):
        raise ValueError(
            "At least one of the data directories needs to be provided"
        )

    # Prediction is a standalone job that needs a checkpoint to load from.
    if parsed.prediction_data:
        if parsed.training_data or parsed.validation_data:
            raise ValueError(
                "Running prediction together with training or evaluation "
                "is not supported"
            )
        if not parsed.checkpoint_dir_for_init:
            raise ValueError(
                "checkpoint_dir_for_init is required for running "
                "prediction job"
            )

    # Clamp the knob that does not apply to the selected SGD mode.
    if parsed.use_async:
        if parsed.grads_to_wait > 1:
            parsed.grads_to_wait = 1
            logger.warning(
                "grads_to_wait is set to 1 when using asynchronous SGD."
            )
    elif parsed.get_model_steps > 1:
        parsed.get_model_steps = 1
        logger.warning(
            "get_model_steps is set to 1 when using synchronous SGD."
        )

    return parsed
예제 #3
0
def build_argument_parser():
    """Build the top-level ``elasticdl`` command-line parser.

    Each leaf subcommand stores its handler under ``func`` through
    ``set_defaults``, so the caller can dispatch with ``args.func(args)``:

    * ``zoo init``  -> ``init_zoo``  (options from ``add_zoo_init_arguments``)
    * ``zoo build`` -> ``build_zoo`` (options from ``add_zoo_build_arguments``)
    * ``zoo push``  -> ``push_zoo``  (options from ``add_zoo_push_arguments``)
    * ``train``     -> ``train``     (common + train params)
    * ``evaluate``  -> ``evaluate``  (common + evaluate params)
    * ``predict``   -> ``predict``   (common + predict params)

    Returns:
        argparse.ArgumentParser: the configured parser. A subcommand is
        mandatory both at the top level and under ``zoo``.
    """
    parser = argparse.ArgumentParser()
    subparsers = parser.add_subparsers()
    # Subcommands are optional by default in Python 3; make one mandatory.
    # NOTE(review): on some Python 3 versions, a required subparser group
    # without a ``dest`` raises TypeError (instead of a usage error) when no
    # subcommand is given (CPython bpo-29298) — consider setting dest/metavar
    # once callers are confirmed not to depend on the namespace contents.
    subparsers.required = True

    # Initialize the parser for the `elasticdl zoo` commands
    zoo_parser = subparsers.add_parser(
        "zoo",
        help="Initialize | Build | Push a docker image for the model zoo.",
    )
    zoo_subparsers = zoo_parser.add_subparsers()
    zoo_subparsers.required = True

    # elasticdl zoo init
    zoo_init_parser = zoo_subparsers.add_parser(
        "init", help="Initialize the model zoo.")
    zoo_init_parser.set_defaults(func=init_zoo)
    args.add_zoo_init_arguments(zoo_init_parser)

    # elasticdl zoo build
    zoo_build_parser = zoo_subparsers.add_parser(
        "build", help="Build a docker image for the model zoo.")
    zoo_build_parser.set_defaults(func=build_zoo)
    args.add_zoo_build_arguments(zoo_build_parser)

    # elasticdl zoo push
    zoo_push_parser = zoo_subparsers.add_parser(
        "push",
        # The trailing space matters: adjacent literals are concatenated
        # without a separator.
        help="Push the docker image to a remote registry for the distributed "
        "ElasticDL job.",
    )
    zoo_push_parser.set_defaults(func=push_zoo)
    args.add_zoo_push_arguments(zoo_push_parser)

    # elasticdl train
    train_parser = subparsers.add_parser(
        "train", help="Submit an ElasticDL distributed training job")
    train_parser.set_defaults(func=train)
    args.add_common_params(train_parser)
    args.add_train_params(train_parser)

    # elasticdl evaluate
    evaluate_parser = subparsers.add_parser(
        "evaluate", help="Submit an ElasticDL distributed evaluation job")
    evaluate_parser.set_defaults(func=evaluate)
    args.add_common_params(evaluate_parser)
    args.add_evaluate_params(evaluate_parser)

    # elasticdl predict
    predict_parser = subparsers.add_parser(
        "predict", help="Submit an ElasticDL distributed prediction job")
    predict_parser.set_defaults(func=predict)
    args.add_common_params(predict_parser)
    args.add_predict_params(predict_parser)

    return parser