示例#1
0
    def Run(self, args):
        """Sends the user's instances to the prediction service.

        Args:
          args: an argparse namespace with every argument provided to this
            command invocation.

        Returns:
          The response from predict.Predict, which is printed later.
        """
        parsed_instances = predict_utilities.ReadInstancesFromArgs(
            args.json_instances, args.text_instances,
            limit=INPUT_INSTANCES_LIMIT)
        target_ref = predict_utilities.ParseModelOrVersionRef(
            args.model, args.version)
        response = predict.Predict(target_ref, parsed_instances)

        # An explicit --format from the user always wins.
        if args.IsSpecified('format'):
            return response
        # Otherwise derive the default format from the 'predictions' entry
        # of the response.
        args.format = predict_utilities.GetDefaultFormat(
            response.get('predictions'))
        return response
def _Run(args):
  """Reads the instances, requests predictions and returns the response.

  Args:
    args: an argparse namespace with every argument provided to this
      command invocation.

  Returns:
    A json object that contains predictions.
  """
  parsed_instances = predict_utilities.ReadInstancesFromArgs(
      args.json_request, args.json_instances, args.text_instances,
      limit=INPUT_INSTANCES_LIMIT)

  # The region override covers reference parsing, the runtime-version check
  # and the prediction request itself.
  with endpoint_util.MlEndpointOverrides(region=args.region):
    target_ref = predict_utilities.ParseModelOrVersionRef(
        args.model, args.version)
    # The runtime version is only checked when no signature was given.
    needs_signature_warning = (
        args.signature_name is None and
        predict_utilities.CheckRuntimeVersion(args.model, args.version))
    if needs_signature_warning:
      log.status.Print(
          'You are running on a runtime version >= 1.8. '
          'If the signature defined in the model is '
          'not serving_default then you must specify it via '
          '--signature-name flag, otherwise the command may fail.')
    response = predict.Predict(
        target_ref, parsed_instances, signature_name=args.signature_name)

  if args.IsSpecified('format'):
    return response
  # No explicit --format: pick a default from the returned predictions.
  args.format = predict_utilities.GetDefaultFormat(
      response.get('predictions'))
  return response
示例#3
0
    def testPredictTextInstancesWithJSON(self):
        """Text instances that look like JSON are sent as escaped strings."""
        self.mock_http.request.return_value = [
            self.http_response, self.http_body]
        text_lines = [
            '{"images": [0, 1], "key": 3}',
            '{"images": [0.3, 0.2], "key": 2}',
            '{"images": [0.2, 0.1], "key": 1}',
        ]
        expected_uri = (self._BASE_URL + 'projects/fake-project/'
                        'models/my_model/versions/v1:predict')
        expected_body = (
            '{"instances": ["{\\"images\\": [0, 1], \\"key\\": 3}", '
            '"{\\"images\\": [0.3, 0.2], \\"key\\": 2}", '
            '"{\\"images\\": [0.2, 0.1], \\"key\\": 1}"]}')

        prediction = predict.Predict(self.version_ref, instances=text_lines)

        self.mock_http.request.assert_called_once_with(
            uri=expected_uri,
            method='POST',
            body=expected_body,
            headers={'Content-Type': 'application/json'})
        self.assertEqual(json.loads(self.http_body), prediction)
示例#4
0
    def testPredictMultipleJsonInstances(self):
        """Several dict instances are serialized as a JSON array of objects."""
        self.mock_http.request.return_value = [
            self.http_response, self.http_body]
        test_instances = [
            {'images': [0, 1], 'key': 3},
            {'images': [2, 3], 'key': 2},
            {'images': [3, 1], 'key': 1},
        ]
        expected_uri = (self._BASE_URL + 'projects/fake-project/'
                        'models/my_model/versions/v1:predict')
        expected_body = ('{"instances": [{"images": [0, 1], "key": 3}, '
                         '{"images": [2, 3], "key": 2}, '
                         '{"images": [3, 1], "key": 1}]}')

        prediction = predict.Predict(self.version_ref, test_instances)

        self.mock_http.request.assert_called_once_with(
            uri=expected_uri,
            method='POST',
            body=expected_body,
            headers={'Content-Type': 'application/json'})
        self.assertEqual(json.loads(self.http_body), prediction)
示例#5
0
    def testPredictFailedRequest(self):
        """A 502 response from the server is raised as core_exceptions.Error."""
        self.mock_http.request.return_value = [{'status': '502'}, 'Error 502']
        expected_message = 'HTTP request failed. Response: Error 502'

        with self.assertRaisesRegex(core_exceptions.Error, expected_message):
            predict.Predict(self.version_ref, self.test_instances)
示例#6
0
    def testPredictInvalidResponse(self):
        """A response body that is not valid JSON is raised as an Error."""
        # 'abcd' cannot be parsed as JSON.
        self.mock_http.request.return_value = [self.http_response, 'abcd']
        expected_message = ('No JSON object could be decoded from the '
                            'HTTP response body: abcd')

        with self.assertRaisesRegex(core_exceptions.Error, expected_message):
            predict.Predict(self.version_ref, self.test_instances)
示例#7
0
    def testPredictJsonInstances(self):
        """The fixture instances are POSTed as the fixture's expected body."""
        self.mock_http.request.return_value = [
            self.http_response, self.http_body]
        expected_uri = (self._BASE_URL + 'projects/fake-project/'
                        'models/my_model/versions/v1:predict')

        prediction = predict.Predict(self.version_ref, self.test_instances)

        self.mock_http.request.assert_called_once_with(
            uri=expected_uri,
            method='POST',
            body=self.expected_body,
            headers={'Content-Type': 'application/json'})
        self.assertEqual(json.loads(self.http_body), prediction)
示例#8
0
    def testPredictionSignatureName(self):
        """A custom signature_name is added to the request body."""
        self.mock_http.request.return_value = [
            self.http_response, self.http_body]
        expected_uri = (self._BASE_URL + 'projects/fake-project/'
                        'models/my_model/versions/v1:predict')
        expected_body = (
            '{"instances": ["2, 3"], "signature_name": "my-custom-signature"}')

        prediction = predict.Predict(
            self.version_ref, ['2, 3'], signature_name='my-custom-signature')

        self.mock_http.request.assert_called_once_with(
            uri=expected_uri,
            method='POST',
            body=expected_body,
            headers={'Content-Type': 'application/json'})
        self.assertEqual(json.loads(self.http_body), prediction)
示例#9
0
 def testPredictNonUtf8Instances(self):
     """Instances holding non-UTF-8 bytes are rejected with an Error."""
     non_utf8_payload = [b'\x89PNG']
     expected_message = 'Instances cannot be JSON encoded'
     with self.assertRaisesRegex(core_exceptions.Error, expected_message):
         predict.Predict(self.version_ref, non_utf8_payload)