def _Run(args):
  """Sends an online prediction request built from the command arguments.

  The instances are read from the request/instance flags. The model or
  version reference is parsed and the prediction is issued inside the
  regional ML endpoint override for `args.region`. If the user did not pass
  --format, a default output format is derived from the response's
  'predictions' entry and stored back on `args`.

  Args:
    args: an argparse namespace. All the arguments that were provided to this
      command invocation.

  Returns:
    The response object returned by predict.Predict (a json object that
    contains predictions).
  """
  instances = predict_utilities.ReadInstancesFromArgs(
      args.json_request,
      args.json_instances,
      args.text_instances,
      limit=INPUT_INSTANCES_LIMIT)

  with endpoint_util.MlEndpointOverrides(region=args.region):
    target_ref = predict_utilities.ParseModelOrVersionRef(
        args.model, args.version)

    # The runtime-version lookup only happens when no signature was given
    # explicitly (the `and` short-circuits otherwise).
    should_warn_about_signature = (
        args.signature_name is None and
        predict_utilities.CheckRuntimeVersion(args.model, args.version))
    if should_warn_about_signature:
      log.status.Print(
          'You are running on a runtime version >= 1.8. '
          'If the signature defined in the model is '
          'not serving_default then you must specify it via '
          '--signature-name flag, otherwise the command may fail.')

    response = predict.Predict(
        target_ref, instances, signature_name=args.signature_name)

  if args.IsSpecified('format'):
    return response

  # No explicit --format: choose one based on the shape of the predictions.
  args.format = predict_utilities.GetDefaultFormat(response.get('predictions'))
  return response
def Run(self, args):
  """Runs an online prediction for the model or version named in `args`.

  If the user did not pass --format, a default output format is derived from
  the response's 'predictions' entry and stored back on `args`.

  Args:
    args: an argparse namespace. All the arguments that were provided to this
      command invocation.

  Returns:
    The response returned by predict.Predict, which is printed later.
  """
  instances = predict_utilities.ReadInstancesFromArgs(
      args.json_instances, args.text_instances, limit=INPUT_INSTANCES_LIMIT)
  target_ref = predict_utilities.ParseModelOrVersionRef(
      args.model, args.version)
  response = predict.Predict(target_ref, instances)

  if args.IsSpecified('format'):
    return response

  # No explicit --format: choose one based on the shape of the predictions.
  args.format = predict_utilities.GetDefaultFormat(response.get('predictions'))
  return response
def testParseModelOrVersionRef_Version(self):
  """A model plus a version parses to an ml.projects.models.versions ref."""
  actual = predict_utilities.ParseModelOrVersionRef('m', 'v')
  expected = resources.REGISTRY.Create(
      'ml.projects.models.versions',
      projectsId=self.Project(),
      modelsId='m',
      versionsId='v')
  self.assertEqual(actual, expected)
def testParseModelOrVersionRef_MissingModelAndVersion(self):
  """Omitting both model and version raises RequiredFieldOmittedException."""
  missing_model, missing_version = None, None
  with self.assertRaises(resources.RequiredFieldOmittedException):
    predict_utilities.ParseModelOrVersionRef(missing_model, missing_version)