def _Run(args):
  """Runs an online prediction for the model or version named in `args`.

  The instances are read up front; the model/version reference is resolved and
  the prediction is issued while the ML endpoint override for `args.region` is
  active. If the user did not choose an output format, one is derived from the
  'predictions' entry of the response.

  Args:
    args: an argparse namespace. All the arguments that were provided to this
      command invocation.

  Returns:
    The response returned by predict.Predict (per the original documentation,
    a json object that contains predictions).
  """
  request_instances = predict_utilities.ReadInstancesFromArgs(
      args.json_request,
      args.json_instances,
      args.text_instances,
      limit=INPUT_INSTANCES_LIMIT)

  with endpoint_util.MlEndpointOverrides(region=args.region):
    target_ref = predict_utilities.ParseModelOrVersionRef(
        args.model, args.version)

    # Only consult the runtime version when no signature name was supplied;
    # the hint below is irrelevant otherwise.
    if args.signature_name is None:
      if predict_utilities.CheckRuntimeVersion(args.model, args.version):
        signature_hint = (
            'You are running on a runtime version >= 1.8. If the signature '
            'defined in the model is not serving_default then you must '
            'specify it via --signature-name flag, otherwise the command '
            'may fail.')
        log.status.Print(signature_hint)

    response = predict.Predict(
        target_ref, request_instances, signature_name=args.signature_name)

  # An explicit --format always wins; nothing more to do in that case.
  if args.IsSpecified('format'):
    return response

  # Otherwise the default format is based on the response.
  predictions = response.get('predictions')
  args.format = predict_utilities.GetDefaultFormat(predictions)
  return response
Example #2
0
    def Run(self, args):
        """Runs an online prediction for the model or version named in `args`.

        Reads the instances, resolves the model/version reference and issues
        the prediction. If the user did not choose an output format, one is
        derived from the 'predictions' entry of the response.

        Args:
          args: an argparse namespace. All the arguments that were provided to
            this command invocation.

        Returns:
          The response returned by predict.Predict, which is printed later.
        """
        request_instances = predict_utilities.ReadInstancesFromArgs(
            args.json_instances,
            args.text_instances,
            limit=INPUT_INSTANCES_LIMIT)

        target_ref = predict_utilities.ParseModelOrVersionRef(
            args.model, args.version)
        response = predict.Predict(target_ref, request_instances)

        # An explicit --format always wins; nothing more to do in that case.
        if args.IsSpecified('format'):
            return response

        # Otherwise the default format is based on the response.
        predictions = response.get('predictions')
        args.format = predict_utilities.GetDefaultFormat(predictions)
        return response
 def testParseModelOrVersionRef_Version(self):
     # Passing both a model and a version must produce a version resource
     # reference in the current project.
     parsed_ref = predict_utilities.ParseModelOrVersionRef('m', 'v')
     expected_ref = resources.REGISTRY.Create(
         'ml.projects.models.versions',
         projectsId=self.Project(),
         modelsId='m',
         versionsId='v')
     self.assertEqual(parsed_ref, expected_ref)
 def testParseModelOrVersionRef_MissingModelAndVersion(self):
     # With neither a model nor a version there is nothing to reference, so
     # parsing is expected to raise RequiredFieldOmittedException.
     self.assertRaises(
         resources.RequiredFieldOmittedException,
         predict_utilities.ParseModelOrVersionRef,
         None,
         None)