def Run(self, args):
  """Submit a batch prediction job assembled from the parsed flags.

  Args:
    args: The parsed command-line arguments (a namespace of flag values).

  Returns:
    The result of jobs_util.SubmitPrediction for the submitted job.
  """
  # The --data-format choice is mapped to an enum value; only its `.name`
  # is handed to SubmitPrediction.
  format_enum = jobs_util.DataFormatFlagMap().GetEnumForChoice(
      args.data_format)
  client = jobs.JobsClient()

  submit_kwargs = dict(
      model_dir=args.model_dir,
      model=args.model,
      version=args.version,
      input_paths=args.input_paths,
      data_format=format_enum.name,
      output_path=args.output_path,
      region=args.region,
      runtime_version=args.runtime_version,
      max_worker_count=args.max_worker_count,
      batch_size=args.batch_size,
  )
  return jobs_util.SubmitPrediction(client, args.job, **submit_kwargs)
def Run(self, args):
  """Submit a batch prediction job, including labels and accelerator options.

  Args:
    args: The parsed command-line arguments (a namespace of flag values).

  Returns:
    The result of jobs_util.SubmitPrediction for the submitted job.
  """
  chosen_format = jobs_util.DataFormatFlagMap().GetEnumForChoice(
      args.data_format)
  client = jobs.JobsClient()
  # Labels are parsed with the same client instance that submits the job.
  job_labels = jobs_util.ParseCreateLabels(client, args)

  prediction_options = {
      'model_dir': args.model_dir,
      'model': args.model,
      'version': args.version,
      'input_paths': args.input_paths,
      'data_format': chosen_format.name,
      'output_path': args.output_path,
      'region': args.region,
      'runtime_version': args.runtime_version,
      'max_worker_count': args.max_worker_count,
      'batch_size': args.batch_size,
      'signature_name': args.signature_name,
      'labels': job_labels,
      'accelerator_type': args.accelerator_type,
      'accelerator_count': args.accelerator_count,
  }
  return jobs_util.SubmitPrediction(client, args.job, **prediction_options)
def _AddSubmitPredictionArgs(parser):
  """Add arguments for `jobs submit prediction` command.

  Registers on `parser`:
    * the positional job name;
    * a required, mutually exclusive model source (--model-dir or --model);
    * --version;
    * input, data-format, output and region flags;
    * worker-count and batch-size flags;
    * the shared runtime-version flag.

  Args:
    parser: The argument parser for the command; flags are added in place.
  """
  parser.add_argument('job', help='Name of the batch prediction job.')

  # Exactly one model source must be given: a GCS model directory or a model
  # name.
  model_group = parser.add_mutually_exclusive_group(required=True)
  model_group.add_argument(
      '--model-dir',
      help=('Google Cloud Storage location where '
            'the model files are located.'))
  model_group.add_argument(
      '--model', help='Name of the model to use for prediction.')

  # NOTE(review): the "--version requires --model" rule below is stated only
  # in the help text; this parser does not enforce it. Presumably it is
  # checked downstream; confirm.
  parser.add_argument(
      '--version',
      help="""\
Model version to be used.

This flag may only be given if --model is specified. If unspecified, the default
version of the model will be used. To list versions for a model, run

    $ gcloud ml-engine versions list
""")

  # input location is a repeated field.
  parser.add_argument(
      '--input-paths',
      type=arg_parsers.ArgList(min_length=1),
      required=True,
      metavar='INPUT_PATH',
      help="""\
Google Cloud Storage paths to the instances to run prediction on.

Wildcards (```*```) accepted at the *end* of a path. More than one path can be
specified if multiple file patterns are needed. For example,

  gs://my-bucket/instances*,gs://my-bucket/other-instances1

will match any objects whose names start with `instances` in `my-bucket` as well
as the object named `other-instances1` in `my-bucket`, while

  gs://my-bucket/instance-dir/*

will match any objects in the `instance-dir` "directory" (since directories
aren't a first-class Cloud Storage concept) of `my-bucket`.
""")
  jobs_util.DataFormatFlagMap().choice_arg.AddToParser(parser)
  parser.add_argument(
      '--output-path',
      required=True,
      help='Google Cloud Storage path to which to save the output. '
      'Example: gs://my-bucket/output.')
  parser.add_argument(
      '--region',
      required=True,
      help='The Google Compute Engine region to run the job in.')
  parser.add_argument(
      '--max-worker-count',
      required=False,
      type=int,
      help=('The maximum number of workers to be used for parallel processing. '
            'Defaults to 10 if not specified.'))
  parser.add_argument(
      '--batch-size',
      required=False,
      type=int,
      help=('The number of records per batch. The service will buffer '
            'batch_size number of records in memory before invoking TensorFlow.'
            ' Defaults to 64 if not specified.'))

  flags.RUNTIME_VERSION.AddToParser(parser)