def test_keras_layer_fails_if_setting_both_output_key_and_as_dict(self):
    """Checks that output_key and signature_outputs_as_dict=True conflict.

    Constructing a KerasLayer on a signature with both options set must
    raise a ValueError carrying the "either ... or ..." guidance message.
    """
    expected_message = ("When using a signature, either output_key or "
                        "signature_outputs_as_dict=True should be set.")
    module_path = test_utils.get_test_data_path("hub_module_v1_mini")
    with self.assertRaisesRegex(ValueError, expected_message):
        hub.KerasLayer(
            module_path,
            signature="default",
            output_key="output",
            signature_outputs_as_dict=True,
        )
def test_load_v1(self):
    """Checks that hub.load can fetch a v1 module archive served over HTTP.

    The test only applies when running in V2 (eager) mode; otherwise it is
    reported as skipped instead of silently passing.
    """
    if (not hasattr(tf.compat.v1.saved_model, "load_v2") or
            not tf.compat.v1.executing_eagerly()):
        # Report a skip rather than a silent pass when the test doesn't apply.
        self.skipTest("The test only applies when running V2 mode.")
    full_module_path = test_utils.get_test_data_path("half_plus_two_v1.tar.gz")
    # The HTTP server is started after the chdir (presumably it serves files
    # relative to the current directory -- TODO confirm). Restore the original
    # working directory afterwards so the chdir does not leak into other tests.
    self.addCleanup(os.chdir, os.getcwd())
    os.chdir(os.path.dirname(full_module_path))
    server_port = test_utils.start_http_server()
    handle = "http://localhost:%d/half_plus_two_v1.tar.gz" % server_port
    hub.load(handle)
def test_keras_layer_get_config(
        self, module_name, signature, output_key, as_dict):
    """Checks that a KerasLayer survives a get_config/from_config round trip.

    The layer rebuilt from its config, after passing it through
    `_json_cycle`, must produce the same output as the original layer for
    the same input.
    """
    value = 10.  # Test modules perform increment op.
    original_layer = hub.KerasLayer(
        test_utils.get_test_data_path(module_name),
        signature=signature,
        output_key=output_key,
        signature_outputs_as_dict=as_dict)
    original_result = original_layer(value)
    restored_layer = hub.KerasLayer.from_config(
        _json_cycle(original_layer.get_config()))
    self.assertEqual(original_result, restored_layer(value))
def test_load_legacy_hub_module_v1_with_signature(
        self, module_name, signature, output_key, as_dict):
    """Checks calling a legacy (v1) Hub module signature through KerasLayer.

    With as_dict the result must be {"default": 11.}; otherwise it must be
    the bare incremented value.
    """
    value, incremented = 10., 11.  # Test modules perform increment op.
    layer = hub.KerasLayer(
        test_utils.get_test_data_path(module_name),
        signature=signature,
        output_key=output_key,
        signature_outputs_as_dict=as_dict)
    result = layer(value)
    expected = {"default": incremented} if as_dict else incremented
    self.assertEqual(result, expected)
def test_load_callable_saved_model_v2_with_signature(
        self, module_name, signature, output_key, as_dict):
    """Checks calling a v2 SavedModel signature through KerasLayer.

    With as_dict the result must be a dict holding the incremented value
    under "output_0"; otherwise it must be the bare incremented value.
    """
    value, incremented = 10., 11.  # Test modules perform increment op.
    layer = hub.KerasLayer(
        test_utils.get_test_data_path(module_name),
        signature=signature,
        output_key=output_key,
        signature_outputs_as_dict=as_dict)
    result = layer(value)
    if not as_dict:
        self.assertEqual(result, incremented)
        return
    self.assertIsInstance(result, dict)
    self.assertEqual(result["output_0"], incremented)
def test_load(self, module_name, tags, is_hub_module_v1):
    """Checks that module_v2.load sets _is_hub_module_v1 as expected."""
    loaded = module_v2.load(test_utils.get_test_data_path(module_name), tags)
    self.assertEqual(loaded._is_hub_module_v1, is_hub_module_v1)
def test_keras_layer_fails_if_output_key_not_in_layer_outputs(self):
    """Checks that an output_key missing from the layer outputs is rejected.

    The layer is constructed outside the assertion; the ValueError is
    expected when the layer is called.
    """
    path = test_utils.get_test_data_path("hub_module_v1_mini")
    layer = hub.KerasLayer(path, output_key="unknown")
    # assertRaisesRegex uses re.search, so the pattern only needs to match a
    # prefix of the message; no trailing wildcard is required.
    with self.assertRaisesRegex(
            ValueError, "KerasLayer output does not contain the output key"):
        layer(10.)
def test_keras_layer_fails_if_output_is_not_dict(self):
    """Checks that output_key is rejected when the layer output is not a dict.

    The layer is constructed outside the assertion; the ValueError is
    expected when the layer is called.
    """
    path = test_utils.get_test_data_path("saved_model_v2_mini")
    layer = hub.KerasLayer(path, output_key="output_0")
    # assertRaisesRegex uses re.search, so the pattern only needs to match a
    # prefix of the message; no trailing wildcard is required.
    with self.assertRaisesRegex(
            ValueError, "Specifying `output_key` is forbidden if output type"):
        layer(10.)
def test_keras_layer_fails_if_saved_model_v2_with_tags(self):
    """Checks that passing tags for a v2 SavedModel raises ValueError."""
    saved_model_path = test_utils.get_test_data_path("saved_model_v2_mini")
    self.assertRaises(
        ValueError,
        hub.KerasLayer, saved_model_path, signature=None, tags=["train"])
def test_keras_layer_fails_if_with_outputs_as_dict_but_no_signature(self):
    """Checks that signature_outputs_as_dict=True requires a signature.

    Constructing the layer without a signature must raise ValueError.
    """
    path = test_utils.get_test_data_path("saved_model_v2_mini")
    # assertRaisesRegex uses re.search, so the pattern only needs to match a
    # prefix of the message; no trailing wildcard is required.
    with self.assertRaisesRegex(
            ValueError,
            "signature_outputs_as_dict is only valid if specifying a signature"):
        hub.KerasLayer(path, signature_outputs_as_dict=True)
def test_keras_layer_fails_if_signature_output_not_specified(self):
    """Checks that a signature without an output selection is rejected.

    Requesting signature="serving_default" with neither output_key nor
    signature_outputs_as_dict=True must raise ValueError at construction.
    """
    expected_message = ("When using a signature, either output_key or "
                        "signature_outputs_as_dict=True should be set.")
    saved_model_path = test_utils.get_test_data_path("saved_model_v2_mini")
    with self.assertRaisesRegex(ValueError, expected_message):
        hub.KerasLayer(saved_model_path, signature="serving_default")
def test_load_with_defaults(self, module_name):
    """Checks a KerasLayer built with default arguments on a test module."""
    layer = hub.KerasLayer(test_utils.get_test_data_path(module_name))
    # Test modules perform increment op, so 10. must come back as 11.
    self.assertEqual(layer(10.), 11.)