def test_create_warmup_requests_numpy(self):
    """Checks the warmup requests written by `create_warmup_requests_numpy`.

    The test expects one serialized `PredictionLog` record per requested
    batch size, in the order the batch sizes were given. Each record must
    name the mock model, and the leading dimension of every input tensor
    must equal the corresponding batch size.
    """
    mock_t2r_model = mocks.MockT2RModel(
        preprocessor_cls=noop_preprocessor.NoOpPreprocessor)
    exporter = mocks.MockExportGenerator()
    exporter.set_specification_from_model(mock_t2r_model)

    export_dir = self.create_tempdir()
    batch_sizes = [2, 4]
    request_filename = exporter.create_warmup_requests_numpy(
        batch_sizes=batch_sizes, export_dir=export_dir.full_path)

    # Materialize the records so their count can be verified: zip() stops at
    # the shorter input, so a truncated or empty file would otherwise let the
    # loop below pass without checking anything.
    records = list(tf.compat.v1.io.tf_record_iterator(request_filename))
    self.assertLen(records, len(batch_sizes))

    for expected_batch_size, record in zip(batch_sizes, records):
      record_proto = prediction_log_pb2.PredictionLog()
      record_proto.ParseFromString(record)
      request = record_proto.predict_log.request
      self.assertEqual(request.model_spec.name, 'MockT2RModel')
      # Only the tensors matter here; the input names are not checked.
      for in_tensor in request.inputs.values():
        self.assertEqual(in_tensor.tensor_shape.dim[0].size,
                         expected_batch_size)
示例#2
0
  def test_hooks(self, mock_create_warmup_requests_numpy,
                 mock_create_serving_input_receiver_numpy_fn,
                 mock_checkpoint_init, mock_export_saved_model):
    """Checks that `TD3Hooks.create_hooks` builds one hook and runs an export.

    The four mock arguments are presumably injected by `mock.patch`
    decorators outside this view. The checkpoint mock is wired so that the
    export function it receives is invoked immediately, which lets the test
    assert on the export-related mocks right after `create_hooks` returns.
    """

    def _run_export_immediately(export_fn, export_dir, **unused_kwargs):
      # Stand-in for the patched checkpoint initializer: trigger the export
      # right away instead of waiting for a checkpoint to be saved.
      export_fn(export_dir, global_step=1)

    # Configure the mocks before any object under test is created.
    mock_checkpoint_init.side_effect = _run_export_immediately
    mock_create_warmup_requests_numpy.return_value = _NUMPY_WARMUP_REQUESTS

    td3_hook_builder = td3.TD3Hooks(
        export_dir=_EXPORT_DIR,
        lagged_export_dir=_LAGGED_EXPORT_DIR,
        batch_sizes_for_export=_BATCH_SIZES_FOR_EXPORT,
        export_generator=mocks.MockExportGenerator())

    created_hooks = td3_hook_builder.create_hooks(
        t2r_model=mocks.MockT2RModel(), estimator=MockEstimator())
    self.assertLen(created_hooks, 1)

    # Warmup requests are requested for the configured batch sizes.
    mock_create_warmup_requests_numpy.assert_called_with(
        batch_sizes=_BATCH_SIZES_FOR_EXPORT,
        export_dir=_MODEL_DIR)

    # The saved model is exported with the warmup requests and the T2R assets
    # attached as extra assets.
    expected_assets_extra = {
        "tf_serving_warmup_requests": _NUMPY_WARMUP_REQUESTS,
        tensorspec_utils.T2R_ASSETS_FILENAME: mock.ANY,
    }
    mock_export_saved_model.assert_called_with(
        serving_input_receiver_fn=mock.ANY,
        export_dir_base=_EXPORT_DIR,
        assets_extra=expected_assets_extra)

    mock_create_serving_input_receiver_numpy_fn.assert_called()