Example #1
0
 def test_request_body_with_binary_data(self):
     """Checks the instances built from an Example holding bytes features.

     The expected request body wraps the "x_bytes" value in a {'b64': ...}
     dict (the base64 text of the raw value), while "x", "y" and "z" are
     expected as plain Python values.
     """
     example_pbtxt = """
   features {
     feature { key: "x_bytes" value { bytes_list { value: ["ASa8asdf"] }}}
     feature { key: "x" value { bytes_list { value: "JLK7ljk3" }}}
     feature { key: "y" value { int64_list { value: [1, 2] }}}
     feature { key: "z" value { float_list { value: [4.5, 5, 5.5] }}}
   }
   """
     example = text_format.Parse(example_pbtxt, tf.train.Example())

     # Build the nested model spec first so the outer message reads cleanly.
     model_spec = model_spec_pb2.AIPlatformPredictionModelSpec(
         project_id='test_project',
         model_name='test_model',
         version_name='test_version')
     inference_spec_type = model_spec_pb2.InferenceSpecType(
         ai_platform_prediction_model_spec=model_spec)
     do_fn = run_inference._RemotePredictDoFn(inference_spec_type, None)

     expected_instance = {
         'x_bytes': {'b64': 'QVNhOGFzZGY='},
         'x': 'JLK7ljk3',
         'y': [1, 2],
         'z': [4.5, 5, 5.5],
     }
     instances = list(do_fn._prepare_instances([example]))
     self.assertEqual(instances, [expected_instance])
Example #2
0
 def test_request_serialized_example(self):
   """Checks the instances built when use_serialization_config is set.

   The single expected instance is {'b64': <text>}, where <text> is the
   base64 encoding of the input Example's serialized bytes.
   """
   example_pbtxt = """
     features {
       feature { key: "x_bytes" value { bytes_list { value: ["ASa8asdf"] }}}
       feature { key: "x" value { bytes_list { value: "JLK7ljk3" }}}
       feature { key: "y" value { int64_list { value: [1, 2] }}}
     }
     """
   example = text_format.Parse(example_pbtxt, tf.train.Example())

   # Build the nested model spec first so the outer message reads cleanly.
   model_spec = model_spec_pb2.AIPlatformPredictionModelSpec(
       project_id='test_project',
       model_name='test_model',
       version_name='test_version',
       use_serialization_config=True)
   inference_spec_type = model_spec_pb2.InferenceSpecType(
       ai_platform_prediction_model_spec=model_spec)
   do_fn = run_inference._RemotePredictDoFn(inference_spec_type, None)

   instances = list(do_fn._prepare_instances([example]))
   encoded_example = base64.b64encode(example.SerializeToString()).decode()
   self.assertEqual(instances, [{'b64': encoded_example}])