def test_explain(self, mock_post_request_func, mock_get_metadata,
                 mock_get_modality_map):
    """Checks that explain() exposes mocked attributions as tensors.

    The mocked response carries attributions {'data': [0.01, 0.02]}; the
    first explanation's as_tensors() is expected to return them under 'data'.

    Args:
      mock_post_request_func: Mock whose return_value stands in for the
        request response; presumably injected by a patch decorator that is
        not visible here.
      mock_get_metadata: Injected mock; not referenced in this test body.
      mock_get_modality_map: Injected mock; not referenced in this test body.
    """
    # NOTE(review): another `test_explain(self, mock_response)` appears in
    # this listing; if both are defined in the same class, the later
    # definition silently replaces this one. Confirm and rename if so.
    mock_post_request_func.return_value = {
        'explanations': [{
            'attributions_by_label': [{
                'attributions': {
                    'data': [0.01, 0.02]
                },
                'baseline_score': 0.0001,
                'example_score': 0.80,
                'label_index': 3,
                'output_name': 'probability'
            }]
        }]
    }
    m = ai_platform_model.AIPlatformModel('fake_end_point')
    instances = [{'input': [0.05]}]
    explanations = m.explain(instances)
    self.assertTrue(mock_post_request_func.called)
    tensor_dict = explanations[0].as_tensors()
    # assert_array_equal reports the differing elements/shapes on failure,
    # whereas assertTrue(np.array_equal(...)) only says "False is not true".
    np.testing.assert_array_equal(tensor_dict['data'],
                                  np.asarray([0.01, 0.02]))
def test_predict(self):
    """Checks that predict() surfaces the first mocked prediction value."""
    fake_response = {'predictions': [0.5]}
    self._mock_post_request_func.return_value = fake_response
    model = ai_platform_model.AIPlatformModel('fake_end_point')
    # A single instance carrying a base64-encoded (fake) image string.
    image_instance = {'images_str': {'b64': u'fake_b64_str'}}
    result = model.predict([image_instance])
    first_prediction = result['predictions'][0]
    self.assertEqual(first_prediction, 0.5)
def test_explain_attribution_key_error(self):
    """Checks that explain() raises KeyError when attributions are missing.

    The mocked explanation contains only an unrecognised key, so explain()
    is expected to fail with a message matching the pattern below.
    """
    response_without_attributions = {
        'explanations': [{'unknown_attribution_key': {}}]
    }
    self._mock_post_request_func.return_value = response_without_attributions
    model = ai_platform_model.AIPlatformModel('fake_end_point')
    expected_message = 'Attribution keys are not present'
    with self.assertRaisesRegex(KeyError, expected_message):
        model.explain([{'input': [0.05]}])
def test_explain_with_error(self, mock_post_request_func, mock_get_metadata,
                            mock_get_modality_map):
    """Checks that an 'error' response makes explain() raise ValueError.

    The raised message must match a pattern that includes the mocked
    'error' string.

    Args:
      mock_post_request_func: Mock whose return_value stands in for the
        request response; presumably injected by a patch decorator that is
        not visible here.
      mock_get_metadata: Injected mock; not referenced in this test body.
      mock_get_modality_map: Injected mock; not referenced in this test body.
    """
    mock_post_request_func.return_value = {'error': 'This is an error.'}
    model = ai_platform_model.AIPlatformModel('fake_end_point')
    # NOTE(review): the '.' characters in this pattern are unescaped regex
    # wildcards, so they match any character, not only a literal period.
    error_pattern = ('Explanation call failed. .*\n'
                     'Original error message: "This is an error."')
    with self.assertRaisesRegex(ValueError, error_pattern):
        model.explain([{'input': [0.05]}])
def test_explain(self, mock_response):
    """Checks that explain() exposes the mocked attributions under 'data'.

    The model is built with every modality mapped to the 'data' input, and
    the first explanation's as_tensors() must yield [0.01, 0.02] for it.

    Args:
      mock_response: Payload used as the mocked request's return_value;
        presumably supplied by a parameterisation decorator not visible
        here, and expected to encode attributions [0.01, 0.02] for 'data'.
    """
    self._mock_post_request_func.return_value = mock_response
    m = ai_platform_model.AIPlatformModel(
        'fake_end_point',
        modality_to_inputs_map={constants.ALL_MODALITY: ['data']})
    instances = [{'input': [0.05]}]
    explanations = m.explain(instances)
    self.assertTrue(self._mock_post_request_func.called)
    tensor_dict = explanations[0].as_tensors()
    # assert_array_equal reports the differing elements/shapes on failure,
    # whereas assertTrue(np.array_equal(...)) only says "False is not true".
    np.testing.assert_array_equal(tensor_dict['data'],
                                  np.asarray([0.01, 0.02]))