예제 #1
0
 def test_calculate_saliency_score_invalid(self):
     """Expect ``None`` when scoring an image with an all-zero saliency map."""
     # Stand-in session: every ``run`` call yields the same softmax entry.
     softmax_pair = (np.ones(2), np.ones(2))
     fake_session = mock.MagicMock()
     fake_session.run.return_value = {'softmax': softmax_pair}

     restore_patch = mock.patch.object(
         utils, 'restore_model', return_value=fake_session)
     with restore_patch:
         score_kwargs = dict(
             run_params=mock.MagicMock(),
             image=np.ones((2, 2, 2)),
             saliency_map=np.zeros((2, 2)),
             area_threshold=0.05,
         )
         result = utils.calculate_saliency_score(**score_kwargs)

     self.assertIsNone(result)
예제 #2
0
 def test_calculate_saliency_score_valid(self):
     """Check the result keys when scoring with a random 2x2 saliency map.

     ``utils.restore_model`` is patched to hand back a mock session whose
     ``run`` returns a fixed ``softmax`` entry.
     """
     mock_session = mock.MagicMock()
     mock_session.run.return_value = {'softmax': (np.ones(2), np.ones(2))}
     with mock.patch.object(utils,
                            'restore_model',
                            return_value=mock_session):
         output = utils.calculate_saliency_score(
             run_params=mock.MagicMock(),
             # ``np.float`` was deprecated in NumPy 1.20 and removed in 1.24.
             # It aliased the builtin ``float``, which is float64 as a dtype.
             image=np.ones((2, 2, 2)).astype(np.float64),
             # NOTE(review): the saliency map is drawn from the unseeded
             # global RNG, so inputs differ between runs - confirm the
             # assertion below cannot depend on the drawn values.
             saliency_map=np.random.rand(2, 2),
             area_threshold=0.05)
     self.assertCountEqual(output.keys(), [
         'true_label', 'true_confidence', 'cropped_label',
         'cropped_confidence', 'crop_mask', 'saliency_map', 'image',
         'saliency_score'
     ])
예제 #3
0
    def test_calculate_saliency_score_text(self):
        """Check the result keys when ``run_params.model_type`` is 'text_cnn'.

        The image and saliency map are 1-D float arrays here, and
        ``utils.restore_model`` is patched to hand back a mock session whose
        ``run`` returns a fixed ``softmax`` entry.
        """
        mock_session = mock.MagicMock()
        mock_session.run.return_value = {'softmax': (np.ones(2), np.ones(2))}
        mock_run_params = mock.MagicMock()
        mock_run_params.model_type = 'text_cnn'

        with mock.patch.object(utils,
                               'restore_model',
                               return_value=mock_session):
            output = utils.calculate_saliency_score(
                run_params=mock_run_params,
                # ``np.float`` was deprecated in NumPy 1.20 and removed in
                # 1.24. It aliased the builtin ``float``, which is float64
                # as a dtype.
                image=np.asarray([0, 1, 1, 1]).astype(np.float64),
                saliency_map=np.asarray([0, 0, 1, 1]).astype(np.float64),
                area_threshold=0)
        self.assertCountEqual(output.keys(), [
            'true_label', 'true_confidence', 'cropped_label',
            'cropped_confidence', 'crop_mask', 'saliency_map', 'image',
            'saliency_score'
        ])