def parse_ps_args(ps_args=None):
    """Parse the command-line arguments of an ElasticDL parameter server.

    Args:
        ps_args: Optional list of argument strings. When ``None``, argparse
            falls back to ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: the parsed arguments. Arguments the parser does
        not recognise are logged as a warning instead of being rejected.
        When asynchronous SGD is enabled, ``grads_to_wait`` is capped at 1.
    """
    ps_parser = argparse.ArgumentParser(description="ElasticDL PS")
    ps_parser.add_argument(
        "--ps_id", type=int, required=True, help="ID unique to the PS"
    )
    ps_parser.add_argument(
        "--port", type=int, required=True, help="Port used by the PS pod"
    )
    ps_parser.add_argument("--master_addr", help="Master ip:port")

    # Shared option groups; registration order determines help output order.
    add_common_params(ps_parser)
    add_train_params(ps_parser)
    # TODO: add PS replica address for RPC stub creation

    parsed, leftovers = ps_parser.parse_known_args(args=ps_args)
    print_args(parsed, groups=ALL_ARGS_GROUPS)
    if leftovers:
        logger.warning("Unknown arguments: %s", leftovers)

    # With asynchronous SGD, grads_to_wait is forced down to 1.
    must_cap_grads = parsed.use_async and parsed.grads_to_wait > 1
    if must_cap_grads:
        parsed.grads_to_wait = 1
        logger.warning(
            "grads_to_wait is set to 1 when using asynchronous SGD."
        )
    return parsed
def parse_master_args(master_args=None):
    """Parse and validate the command-line arguments of the ElasticDL master.

    Args:
        master_args: Optional list of argument strings. When ``None``,
            argparse falls back to ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: the parsed arguments. Unrecognised arguments are
        only logged. ``grads_to_wait`` is capped at 1 under asynchronous SGD,
        and ``get_model_steps`` is capped at 1 under synchronous SGD.

    Raises:
        ValueError: if none of training / validation / prediction data is
            given; if prediction data is combined with training or
            validation data; or if prediction data is given without
            ``checkpoint_dir_for_init``.
    """
    master_parser = argparse.ArgumentParser(description="ElasticDL Master")
    master_parser.add_argument(
        "--port",
        type=pos_int,
        default=50001,
        help="The listening port of master",
    )
    master_parser.add_argument(
        "--worker_image", default=None, help="Docker image for workers"
    )
    master_parser.add_argument(
        "--prediction_data",
        default="",
        help="Either the data directory that contains RecordIO files "
        "or an ODPS table name used for prediction.",
    )
    add_common_params(master_parser)
    add_train_params(master_parser)

    parsed, leftovers = master_parser.parse_known_args(args=master_args)
    print_args(parsed, groups=ALL_ARGS_GROUPS)
    if leftovers:
        logger.warning("Unknown arguments: %s", leftovers)

    data_sources = (
        parsed.training_data,
        parsed.validation_data,
        parsed.prediction_data,
    )
    # Both "" and None mean "not provided".
    if all(source in ("", None) for source in data_sources):
        raise ValueError(
            "At least one of the data directories needs to be provided"
        )

    if parsed.prediction_data:
        if parsed.training_data or parsed.validation_data:
            raise ValueError(
                "Running prediction together with training or evaluation "
                "is not supported"
            )
        if not parsed.checkpoint_dir_for_init:
            raise ValueError(
                "checkpoint_dir_for_init is required for running "
                "prediction job"
            )

    # Only one of the two caps can apply, depending on the SGD mode.
    if parsed.use_async:
        if parsed.grads_to_wait > 1:
            parsed.grads_to_wait = 1
            logger.warning(
                "grads_to_wait is set to 1 when using asynchronous SGD."
            )
    elif parsed.get_model_steps > 1:
        parsed.get_model_steps = 1
        logger.warning(
            "get_model_steps is set to 1 when using synchronous SGD."
        )
    return parsed
def build_argument_parser():
    """Build the top-level ``elasticdl`` command-line parser.

    Sub-commands and the handler each one stores under ``func`` via
    ``set_defaults``:

    * ``zoo init``  -> ``init_zoo``
    * ``zoo build`` -> ``build_zoo``
    * ``zoo push``  -> ``push_zoo``
    * ``train``     -> ``train``
    * ``evaluate``  -> ``evaluate``
    * ``predict``   -> ``predict``

    Both the top-level and the ``zoo`` sub-command groups are marked
    required.

    Returns:
        argparse.ArgumentParser: the configured, not-yet-invoked parser.
    """
    parser = argparse.ArgumentParser()
    # NOTE(review): a required sub-parser group without ``dest``/``metavar``
    # can raise TypeError (instead of a usage error) on some older Python 3
    # releases when the sub-command is omitted -- confirm against the
    # supported interpreter versions before relying on the error message.
    subparsers = parser.add_subparsers()
    subparsers.required = True

    # Initialize the parser for the `elasticdl zoo` commands
    zoo_parser = subparsers.add_parser(
        "zoo",
        help="Initialize | Build | Push a docker image for the model zoo.",
    )
    zoo_subparsers = zoo_parser.add_subparsers()
    zoo_subparsers.required = True

    # elasticdl zoo init
    zoo_init_parser = zoo_subparsers.add_parser(
        "init", help="Initialize the model zoo."
    )
    zoo_init_parser.set_defaults(func=init_zoo)
    args.add_zoo_init_arguments(zoo_init_parser)

    # elasticdl zoo build
    zoo_build_parser = zoo_subparsers.add_parser(
        "build", help="Build a docker image for the model zoo."
    )
    zoo_build_parser.set_defaults(func=build_zoo)
    args.add_zoo_build_arguments(zoo_build_parser)

    # elasticdl zoo push
    zoo_push_parser = zoo_subparsers.add_parser(
        "push",
        # The trailing space matters: the two literals are concatenated.
        help="Push the docker image to a remote registry for the distributed "
        "ElasticDL job.",
    )
    zoo_push_parser.set_defaults(func=push_zoo)
    args.add_zoo_push_arguments(zoo_push_parser)

    # elasticdl train
    train_parser = subparsers.add_parser(
        "train", help="Submit an ElasticDL distributed training job"
    )
    train_parser.set_defaults(func=train)
    args.add_common_params(train_parser)
    args.add_train_params(train_parser)

    # elasticdl evaluate
    evaluate_parser = subparsers.add_parser(
        "evaluate", help="Submit an ElasticDL distributed evaluation job"
    )
    evaluate_parser.set_defaults(func=evaluate)
    args.add_common_params(evaluate_parser)
    args.add_evaluate_params(evaluate_parser)

    # elasticdl predict
    predict_parser = subparsers.add_parser(
        "predict", help="Submit an ElasticDL distributed prediction job"
    )
    predict_parser.set_defaults(func=predict)
    args.add_common_params(predict_parser)
    args.add_predict_params(predict_parser)

    return parser