def test_get_metadata_double_output(self):
     """Restricting the 'double' signature to 'lin' leaves exactly that output."""
     builder = saved_model_metadata_builder.SavedModelMetadataBuilder(
         self.model_path, signature_name='double', outputs_to_explain=['lin'])
     explained_outputs = builder.get_metadata()['outputs']
     self.assertLen(explained_outputs, 1)
     self.assertIn('lin', explained_outputs)
 def test_save_metadata(self):
     """save_metadata writes explanation_metadata.json into the given dir."""
     builder = saved_model_metadata_builder.SavedModelMetadataBuilder(
         self.model_path, tags=[tf.saved_model.tag_constants.SERVING])
     output_dir = self.create_tempdir().full_path
     builder.save_metadata(output_dir)
     expected_file = os.path.join(output_dir, 'explanation_metadata.json')
     self.assertTrue(os.path.exists(expected_file))
 def test_constructor_incorrect_signature_name(self):
     """An unknown signature name is rejected with a ValueError."""
     # Regex: '.*' accepts any text between 'key' and 'not in'. The trailing
     # '.' is unescaped, so it matches any single character.
     expected_error = 'Serving sigdef key .* not in the signature def.'
     with self.assertRaisesRegex(ValueError, expected_error):
         saved_model_metadata_builder.SavedModelMetadataBuilder(
             self.model_path,
             tags=[tf.saved_model.tag_constants.SERVING],
             signature_name='incorrect_signature')
 def test_save_model_with_metadata_successfully(self):
     """The metadata file written by save_model_with_metadata matches.

     Both the 'inputs' and 'outputs' sections read back from
     explanation_metadata.json must equal those from get_metadata().
     """
     export_dir = self.create_tempdir().full_path
     builder = saved_model_metadata_builder.SavedModelMetadataBuilder(
         self.model_path, tags=[tf.saved_model.tag_constants.SERVING])
     builder.save_model_with_metadata(export_dir)
     metadata_file = os.path.join(export_dir, 'explanation_metadata.json')
     saved_md = explain_metadata.ExplainMetadata.from_file(metadata_file)
     loaded = saved_md.to_dict()
     in_memory = builder.get_metadata()
     for section in ('inputs', 'outputs'):
         self.assertDictEqual(loaded[section], in_memory[section])
 def test_get_metadata_correct_inputs(self):
     """With the SERVING tag, metadata has one input and one output entry."""
     builder = saved_model_metadata_builder.SavedModelMetadataBuilder(
         self.model_path, tags=[tf.saved_model.tag_constants.SERVING])
     metadata = builder.get_metadata()
     for section in ('inputs', 'outputs'):
         self.assertLen(metadata[section], 1)
 def test_constructor_no_outputs_explain(self):
     """A multi-output signature without outputs_to_explain is rejected."""
     with self.assertRaisesRegex(ValueError,
                                 'The signature contains multiple outputs'):
         saved_model_metadata_builder.SavedModelMetadataBuilder(
             self.model_path, signature_name='double')
 def test_constructor_multiple_outputs_to_explain(self):
     """Passing more than one name in outputs_to_explain raises ValueError."""
     expected_error = 'Only one output is supported'
     with self.assertRaisesRegex(ValueError, expected_error):
         saved_model_metadata_builder.SavedModelMetadataBuilder(
             self.model_path, outputs_to_explain=['out1', 'out2'])
 def test_constructor_empty_tags(self):
     """An empty tag list still yields one input and one output entry."""
     md_builder = saved_model_metadata_builder.SavedModelMetadataBuilder(
         self.model_path, tags=[])
     metadata = md_builder.get_metadata()
     self.assertLen(metadata['inputs'], 1)
     self.assertLen(metadata['outputs'], 1)