def _Run(args):
  """Run an online prediction against a deployed model or version.

  Args:
    args: an argparse namespace. All the arguments that were provided to
      this command invocation.

  Returns:
    A json object that contains predictions.
  """
  input_instances = predict_utilities.ReadInstancesFromArgs(
      args.json_request,
      args.json_instances,
      args.text_instances,
      limit=INPUT_INSTANCES_LIMIT)

  with endpoint_util.MlEndpointOverrides(region=args.region):
    ref = predict_utilities.ParseModelOrVersionRef(args.model, args.version)

    # On runtime >= 1.8 a non-default serving signature must be named
    # explicitly; warn the user when no --signature-name was given.
    needs_signature_hint = (
        args.signature_name is None and
        predict_utilities.CheckRuntimeVersion(args.model, args.version))
    if needs_signature_hint:
      log.status.Print(
          'You are running on a runtime version >= 1.8. '
          'If the signature defined in the model is '
          'not serving_default then you must specify it via '
          '--signature-name flag, otherwise the command may fail.')

    response = predict.Predict(
        ref, input_instances, signature_name=args.signature_name)

  if not args.IsSpecified('format'):
    # default format is based on the response.
    args.format = predict_utilities.GetDefaultFormat(
        response.get('predictions'))

  return response
def Run(self, args):
  """Run an online prediction and return the raw results.

  Args:
    args: an argparse namespace. All the arguments that were provided to
      this command invocation.

  Returns:
    Some value that we want to have printed later.
  """
  parsed_instances = predict_utilities.ReadInstancesFromArgs(
      args.json_instances,
      args.text_instances,
      limit=INPUT_INSTANCES_LIMIT)
  target_ref = predict_utilities.ParseModelOrVersionRef(
      args.model, args.version)

  response = predict.Predict(target_ref, parsed_instances)

  if not args.IsSpecified('format'):
    # default format is based on the response.
    args.format = predict_utilities.GetDefaultFormat(
        response.get('predictions'))

  return response
def Run(self, args):
  """Run a local prediction and return the raw results.

  Args:
    args: an argparse namespace. All the arguments that were provided to
      this command invocation.

  Returns:
    The prediction results (a list, or a dict with a 'predictions' key).
  """
  output = local_utils.RunPredict(
      args.model_dir, args.json_instances, args.text_instances)

  if not args.IsSpecified('format'):
    # default format is based on the response; local prediction may hand
    # back either a bare list or a dict keyed by 'predictions'.
    preds = output if isinstance(output, list) else output.get('predictions')
    args.format = predict_utilities.GetDefaultFormat(preds)

  return output
def Run(self, args):
  """Run a local prediction with an explicit framework choice.

  Args:
    args: an argparse namespace. All the arguments that were provided to
      this command invocation.

  Returns:
    The prediction results (a list, or a dict with a 'predictions' key).
  """
  # Map the user-facing framework choice to its enum; fall back to
  # TensorFlow when no framework was supplied.
  framework_enum = flags.FRAMEWORK_MAPPER.GetEnumForChoice(args.framework)
  chosen_framework = (
      framework_enum.name.lower() if framework_enum else 'tensorflow')

  output = local_utils.RunPredict(
      args.model_dir,
      args.json_instances,
      args.text_instances,
      framework=chosen_framework)

  if not args.IsSpecified('format'):
    # default format is based on the response; local prediction may hand
    # back either a bare list or a dict keyed by 'predictions'.
    preds = output if isinstance(output, list) else output.get('predictions')
    args.format = predict_utilities.GetDefaultFormat(preds)

  return output
def Run(self, args):
  """Run a local prediction, honoring framework and signature flags.

  Args:
    args: an argparse namespace. All the arguments that were provided to
      this command invocation.

  Returns:
    The prediction results (a list, or a dict with a 'predictions' key).
  """
  # Map the user-facing framework choice to its enum; fall back to
  # TensorFlow when no framework was supplied.
  framework_enum = flags.FRAMEWORK_MAPPER.GetEnumForChoice(args.framework)
  chosen_framework = (
      framework_enum.name.lower() if framework_enum else 'tensorflow')

  # A non-default serving signature must be named explicitly; warn the
  # user when no --signature-name was given.
  if args.signature_name is None:
    log.status.Print('If the signature defined in the model is '
                     'not serving_default then you must specify it via '
                     '--signature-name flag, otherwise the command may fail.')

  output = local_utils.RunPredict(
      args.model_dir,
      json_instances=args.json_instances,
      text_instances=args.text_instances,
      framework=chosen_framework,
      signature_name=args.signature_name)

  if not args.IsSpecified('format'):
    # default format is based on the response; local prediction may hand
    # back either a bare list or a dict keyed by 'predictions'.
    preds = output if isinstance(output, list) else output.get('predictions')
    args.format = predict_utilities.GetDefaultFormat(preds)

  return output