def test_create_warmup_requests_numpy(self):
  """Checks the warmup request file written by the export generator.

  The test pairs each requested batch size, in order, with one record read
  from the returned TFRecord file. Each record is parsed as a
  PredictionLog. Its predict request must target 'MockT2RModel', and every
  input tensor's first dimension must equal the paired batch size.
  """
  mock_t2r_model = mocks.MockT2RModel(
      preprocessor_cls=noop_preprocessor.NoOpPreprocessor)
  exporter = mocks.MockExportGenerator()
  exporter.set_specification_from_model(mock_t2r_model)
  export_dir = self.create_tempdir()
  batch_sizes = [2, 4]
  request_filename = exporter.create_warmup_requests_numpy(
      batch_sizes=batch_sizes, export_dir=export_dir.full_path)

  # Materialize the records and check the count first. Otherwise zip()
  # would silently stop at the shorter sequence, and a short or empty
  # warmup file would still pass.
  records = list(tf.compat.v1.io.tf_record_iterator(request_filename))
  self.assertLen(records, len(batch_sizes))

  for expected_batch_size, record in zip(batch_sizes, records):
    record_proto = prediction_log_pb2.PredictionLog()
    record_proto.ParseFromString(record)
    request = record_proto.predict_log.request
    self.assertEqual(request.model_spec.name, 'MockT2RModel')
    # Guard against the per-tensor loop below passing vacuously.
    self.assertNotEmpty(request.inputs)
    for in_tensor in request.inputs.values():
      self.assertEqual(in_tensor.tensor_shape.dim[0].size, expected_batch_size)
def test_hooks(self, mock_create_warmup_requests_numpy,
               mock_create_serving_input_receiver_numpy_fn,
               mock_checkpoint_init, mock_export_saved_model):
  """TD3Hooks builds a single hook that exports a SavedModel with warmup data.

  The patched checkpoint-init runs the export callback right away with
  global_step=1. The test can then check the calls made to the patched
  warmup-request builder, the serving-input-fn builder and
  export_saved_model.
  """

  def _fake_checkpoint_init(export_fn, export_dir, **unused_kwargs):
    # Export immediately at a fixed step instead of waiting on a checkpoint.
    export_fn(export_dir, global_step=1)

  mock_checkpoint_init.side_effect = _fake_checkpoint_init

  export_generator = mocks.MockExportGenerator()
  hook_builder = td3.TD3Hooks(
      export_dir=_EXPORT_DIR,
      lagged_export_dir=_LAGGED_EXPORT_DIR,
      batch_sizes_for_export=_BATCH_SIZES_FOR_EXPORT,
      export_generator=export_generator)
  model = mocks.MockT2RModel()
  estimator = MockEstimator()
  mock_create_warmup_requests_numpy.return_value = _NUMPY_WARMUP_REQUESTS

  hooks = hook_builder.create_hooks(t2r_model=model, estimator=estimator)

  self.assertLen(hooks, 1)
  mock_create_warmup_requests_numpy.assert_called_with(
      batch_sizes=_BATCH_SIZES_FOR_EXPORT, export_dir=_MODEL_DIR)
  expected_assets = {
      "tf_serving_warmup_requests": _NUMPY_WARMUP_REQUESTS,
      tensorspec_utils.T2R_ASSETS_FILENAME: mock.ANY,
  }
  mock_export_saved_model.assert_called_with(
      serving_input_receiver_fn=mock.ANY,
      export_dir_base=_EXPORT_DIR,
      assets_extra=expected_assets)
  mock_create_serving_input_receiver_numpy_fn.assert_called()