def _Run(args):
  """Sends an online prediction request and returns the service response.

  Instances are read from the json_request / json_instances / text_instances
  arguments (capped at INPUT_INSTANCES_LIMIT). The model or version reference
  is resolved and the prediction is issued inside a regional endpoint
  override. When --format was not explicitly specified, args.format is set
  as a side effect, based on the 'predictions' entry of the response.

  Args:
    args: an argparse namespace holding every argument supplied to this
      command invocation.

  Returns:
    The response returned by predict.Predict.
  """
  instances = predict_utilities.ReadInstancesFromArgs(
      args.json_request,
      args.json_instances,
      args.text_instances,
      limit=INPUT_INSTANCES_LIMIT)

  with endpoint_util.MlEndpointOverrides(region=args.region):
    target_ref = predict_utilities.ParseModelOrVersionRef(
        args.model, args.version)
    # Warn only when no signature was given AND the runtime check passes.
    # Per the message text, the check presumably means runtime >= 1.8.
    should_warn = (
        args.signature_name is None and
        predict_utilities.CheckRuntimeVersion(args.model, args.version))
    if should_warn:
      log.status.Print(
          'You are running on a runtime version >= 1.8. '
          'If the signature defined in the model is '
          'not serving_default then you must specify it via '
          '--signature-name flag, otherwise the command may fail.')
    response = predict.Predict(
        target_ref, instances, signature_name=args.signature_name)

  if args.IsSpecified('format'):
    return response

  # No explicit --format: derive a default from the returned predictions.
  args.format = predict_utilities.GetDefaultFormat(response.get('predictions'))
  return response
# Example #2
    def Run(self, args):
        """Issues an online prediction request and returns the response.

        Args:
          args: an argparse namespace holding every argument supplied to
            this command invocation.

        Returns:
          The response from predict.Predict. If --format was not specified,
          args.format is set as a side effect from the response's
          'predictions' entry.
        """
        request_instances = predict_utilities.ReadInstancesFromArgs(
            args.json_instances, args.text_instances,
            limit=INPUT_INSTANCES_LIMIT)

        target_ref = predict_utilities.ParseModelOrVersionRef(
            args.model, args.version)
        response = predict.Predict(target_ref, request_instances)

        format_given = args.IsSpecified('format')
        if not format_given:
            # Choose an output format that suits the prediction payload.
            predictions = response.get('predictions')
            args.format = predict_utilities.GetDefaultFormat(predictions)
        return response
# Example #3
    def Run(self, args):
        """Runs a local prediction and returns its results.

        Args:
          args: the parsed argument namespace for this invocation.

        Returns:
          Whatever local_utils.RunPredict returned. If --format was not
          specified, args.format is set as a side effect from the predictions.
        """
        output = local_utils.RunPredict(
            args.model_dir, args.json_instances, args.text_instances)

        if args.IsSpecified('format'):
            return output

        # A bare list is taken as the predictions themselves; otherwise they
        # are looked up under the 'predictions' key.
        predictions = (output if isinstance(output, list)
                       else output.get('predictions'))
        args.format = predict_utilities.GetDefaultFormat(predictions)
        return output
# Example #4
    def Run(self, args):
        """Runs a local prediction with the chosen framework.

        The framework defaults to 'tensorflow' when the --framework choice
        does not map to an enum value.

        Args:
          args: the parsed argument namespace for this invocation.

        Returns:
          Whatever local_utils.RunPredict returned. If --format was not
          specified, args.format is set as a side effect from the predictions.
        """
        chosen = flags.FRAMEWORK_MAPPER.GetEnumForChoice(args.framework)
        if chosen:
            framework_flag = chosen.name.lower()
        else:
            framework_flag = 'tensorflow'

        output = local_utils.RunPredict(args.model_dir,
                                        args.json_instances,
                                        args.text_instances,
                                        framework=framework_flag)

        if args.IsSpecified('format'):
            return output

        # A bare list is taken as the predictions themselves.
        if isinstance(output, list):
            predictions = output
        else:
            predictions = output.get('predictions')
        args.format = predict_utilities.GetDefaultFormat(predictions)
        return output
  def Run(self, args):
    """Runs a local prediction, optionally with a named model signature.

    The framework defaults to 'tensorflow' when the --framework choice does
    not map to an enum value. A notice is printed when no signature name was
    supplied.

    Args:
      args: the parsed argument namespace for this invocation.

    Returns:
      Whatever local_utils.RunPredict returned. If --format was not specified,
      args.format is set as a side effect from the predictions.
    """
    mapped = flags.FRAMEWORK_MAPPER.GetEnumForChoice(args.framework)
    framework_flag = 'tensorflow' if not mapped else mapped.name.lower()

    signature = args.signature_name
    if signature is None:
      log.status.Print('If the signature defined in the model is '
                       'not serving_default then you must specify it via '
                       '--signature-name flag, otherwise the command may fail.')

    output = local_utils.RunPredict(
        args.model_dir,
        json_instances=args.json_instances,
        text_instances=args.text_instances,
        framework=framework_flag,
        signature_name=signature)

    if args.IsSpecified('format'):
      return output

    # A bare list is taken as the predictions themselves.
    predictions = (output if isinstance(output, list)
                   else output.get('predictions'))
    args.format = predict_utilities.GetDefaultFormat(predictions)
    return output