예제 #1
0
def parse_args():
    """Parse and validate the ElasticDL master's command-line arguments.

    Registers the master-specific flags, then the shared ones added by
    ``add_common_params`` and ``add_train_params``. Those helpers presumably
    define ``training_data_dir``, ``evaluation_data_dir`` and
    ``checkpoint_filename_for_init``, which are read below (verify). The
    function then parses ``sys.argv`` and echoes the result via ``print_args``.

    Returns:
        argparse.Namespace: the parsed arguments.

    Raises:
        ValueError: if no data directory is provided; if a prediction data
            directory is combined with a training or evaluation one; or if a
            prediction job is requested without
            ``checkpoint_filename_for_init``.
    """
    parser = argparse.ArgumentParser(description="ElasticDL Master")
    parser.add_argument(
        "--port",
        default=50001,
        type=pos_int,
        help="The listening port of master",
    )
    parser.add_argument("--worker_image",
                        help="Docker image for workers",
                        default=None)
    parser.add_argument("--worker_pod_priority",
                        help="Priority requested by workers")
    parser.add_argument(
        "--prediction_data_dir",
        help="Prediction data directory. Files should be in RecordIO format",
        default="",
    )
    parser.add_argument(
        "--log_level",
        choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
        # Upper-case the value so e.g. "--log_level info" is accepted.
        type=str.upper,
        default="INFO",
        # argparse expands %(default)s, keeping the help text in sync with
        # the real default above.
        help="The logging level. Defaults to %(default)s",
    )
    add_common_params(parser)
    add_train_params(parser)

    args = parser.parse_args()
    print_args(args, groups=ALL_ARGS_GROUPS)

    data_dirs = (
        args.training_data_dir,
        args.evaluation_data_dir,
        args.prediction_data_dir,
    )
    # An unset directory may be "" (e.g. the default above) or None.
    if all(d == "" or d is None for d in data_dirs):
        raise ValueError(
            "At least one of the data directories needs to be provided")

    # Prediction is a standalone mode: it cannot share a job with
    # training or evaluation, and it needs a checkpoint to load from.
    if args.prediction_data_dir and (args.training_data_dir
                                     or args.evaluation_data_dir):
        raise ValueError(
            "Running prediction together with training or evaluation "
            "is not supported")
    if args.prediction_data_dir and not args.checkpoint_filename_for_init:
        raise ValueError(
            "checkpoint_filename_for_init is required for running "
            "prediction job")

    return args
예제 #2
0
def main():
    """Parse the ``elasticdl`` command line and dispatch to a sub-command.

    Registers the ``train``, ``evaluate`` and ``predict`` sub-commands. Each
    one gets the common parameters plus its own command-specific ones, and
    binds its handler as ``args.func``. The handler chosen on the command line
    is then called with the parsed namespace.
    """
    parser = argparse.ArgumentParser()
    subparsers = parser.add_subparsers(dest="cmd")
    # Make argparse reject a missing sub-command with a usage error instead
    # of leaving ``args.func`` unset.
    subparsers.required = True

    # (command name, job kind for the help text, handler, param registrar),
    # registered in this order.
    commands = (
        ("train", "training", train, add_train_params),
        ("evaluate", "evaluation", evaluate, add_evaluate_params),
        ("predict", "prediction", predict, add_predict_params),
    )
    for name, job_kind, handler, add_command_params in commands:
        command_parser = subparsers.add_parser(
            name, help="Submit an ElasticDL distributed %s job" % job_kind
        )
        command_parser.set_defaults(func=handler)
        add_common_params(command_parser)
        add_command_params(command_parser)

    # Unrecognised arguments are discarded here.
    # NOTE(review): presumably intentional, but a mistyped flag is silently
    # ignored.
    args, _ = parser.parse_known_args()
    args.func(args)
예제 #3
0
def main():
    """Entry point of the ``elasticdl`` command-line client.

    Registers the ``train``, ``evaluate`` and ``predict`` sub-commands. Each
    one gets the common parameters plus its own command-specific ones, and
    binds its handler as ``args.func``. The handler selected on the command
    line is then invoked with the parsed namespace.
    """
    parser = argparse.ArgumentParser(
        usage="""elasticdl <command> [<args>]

Below is the list of supported commands:
train         Submit an ElasticDL distributed training job.
evaluate      Submit an ElasticDL distributed evaluation job.
predict       Submit an ElasticDL distributed prediction job.
"""
    )
    # Without a required sub-command, running plain ``elasticdl`` parses
    # successfully and then crashes on ``args.func`` with AttributeError.
    # A named ``dest`` is needed so argparse can report the missing
    # command in its usage error.
    subparsers = parser.add_subparsers(dest="cmd")
    subparsers.required = True

    train_parser = subparsers.add_parser("train", help="elasticdl train -h")
    train_parser.set_defaults(func=train)
    add_common_params(train_parser)
    add_train_params(train_parser)

    evaluate_parser = subparsers.add_parser(
        "evaluate", help="elasticdl evaluate -h"
    )
    evaluate_parser.set_defaults(func=evaluate)
    add_common_params(evaluate_parser)
    add_evaluate_params(evaluate_parser)

    predict_parser = subparsers.add_parser(
        "predict", help="elasticdl predict -h"
    )
    predict_parser.set_defaults(func=predict)
    add_common_params(predict_parser)
    add_predict_params(predict_parser)

    # Unrecognised arguments are discarded here.
    # NOTE(review): presumably intentional, but a mistyped flag is silently
    # ignored.
    args, _ = parser.parse_known_args()
    args.func(args)