def _test_predict(self, mock_validate_method, test_predict_probability):
        """ Predict with a saved MNIST model and check validation call count and accuracy

        :param mock_validate_method: Mock whose call count is inspected after prediction
            (presumably patched over the method named by
            RepurposerTestUtils.VALIDATE_PREDICT_METHOD_NAME -- confirm against callers)
        :param bool test_predict_probability: If True, labels are the argmax (axis 1) of
            predict_probability output; otherwise predict_label is called directly
        """
        neural_network_repurposer = NeuralNetworkRepurposer(source_model=None)
        # Prediction only needs a target model, so load one from file instead of repurposing
        neural_network_repurposer.target_model = mx.module.Module.load(
            prefix=RepurposerTestUtils.MNIST_MODEL_PATH_PREFIX,
            epoch=10,
            label_names=None)
        test_iterator = RepurposerTestUtils.create_mnist_test_iterator()
        # Forget any calls recorded so far so that only calls made during predict are counted
        mock_validate_method.reset_mock()

        if test_predict_probability:
            probabilities = neural_network_repurposer.predict_probability(test_iterator)
            labels = np.argmax(probabilities, axis=1)
        else:
            labels = neural_network_repurposer.predict_label(test_iterator)

        # Check if predict called validate exactly once.
        # assertEqual reports both values on failure, unlike assertTrue(a == b).
        self.assertEqual(
            mock_validate_method.call_count,
            1,
            "Predict expected to call {} once. Found {} calls".format(
                RepurposerTestUtils.VALIDATE_PREDICT_METHOD_NAME,
                mock_validate_method.call_count))

        # Fraction of predicted labels that match the labels of the test iterator
        expected_accuracy = 0.96985
        accuracy = np.mean(
            labels == RepurposerTestUtils.get_labels(test_iterator))
        self.assertTrue(
            np.isclose(accuracy, expected_accuracy, rtol=1e-3),
            "Prediction accuracy is incorrect. Expected:{}. Got:{}".format(
                expected_accuracy, accuracy))
    def test_repurpose_calls_validate(self, mock_create_target_module,
                                      mock_validate_method):
        """ Test that repurpose invokes the validate method exactly once

        :param mock_create_target_module: Mock that is not used in the test body
            (presumably injected by a patch decorator -- confirm at the definition site)
        :param mock_validate_method: Mock whose call count is inspected after repurpose
        """
        neural_network_repurposer = NeuralNetworkRepurposer(
            source_model=self.mxnet_model)
        neural_network_repurposer.target_model = mx.module.Module.load(
            prefix=self.dropout_model_path_prefix, epoch=0)

        # Forget any calls recorded so far so that only calls made during repurpose are counted
        mock_validate_method.reset_mock()
        neural_network_repurposer.repurpose(Mock())

        # assertEqual reports both values on failure, unlike assertTrue(a == b)
        self.assertEqual(
            mock_validate_method.call_count,
            1,
            "Repurpose expected to call {} once. Found {} calls".format(
                RepurposerTestUtils.VALIDATE_REPURPOSE_METHOD_NAME,
                mock_validate_method.call_count))
    def test_validate_before_predict(self):
        """ Test that _validate_before_predict rejects unusable target models and accepts a loaded one """
        not_a_module_error = "Cannot predict because target_model is not an `mxnet.mod.Module` object"
        params_not_initialized_error = \
            "target_model params aren't initialized. Ensure model is trained before calling predict"

        repurposer = NeuralNetworkRepurposer(source_model=self.mxnet_model)

        # Invalid: target model was neither produced by repurpose nor assigned by hand
        with self.assertRaisesRegex(TypeError, not_a_module_error):
            repurposer._validate_before_predict()

        # Invalid: target model is an object of the wrong type
        repurposer.target_model = {}
        with self.assertRaisesRegex(TypeError, not_a_module_error):
            repurposer._validate_before_predict()

        # Invalid: an mxnet module that has not been trained yet
        repurposer.target_model = self.mxnet_model
        with self.assertRaisesRegex(ValueError, params_not_initialized_error):
            repurposer._validate_before_predict()

        # Valid: a module loaded from file must pass validation without raising
        loaded_module = mx.module.Module.load(
            prefix=self.dropout_model_path_prefix, epoch=0)
        repurposer.target_model = loaded_module
        repurposer._validate_before_predict()
    def test_prediction_consistency(self):
        """ Test if predict method returns consistent predictions using the same model and test data """
        if self.repurposer_class != NeuralNetworkRepurposer:
            # Report the test as skipped rather than silently passing without running anything
            self.skipTest("Prediction consistency test only applies to NeuralNetworkRepurposer")

        # Create test data iterator to run predictions on
        test_iterator = RepurposerTestUtils.get_image_iterator()
        # Load a pre-trained model to predict. The model has a dropout layer used for training.
        # This test is to ensure that dropout doesn't happen during prediction.
        target_model = mx.module.Module.load(
            prefix=self.dropout_model_path_prefix, epoch=0, label_names=None)

        # Create repurposer and set the target model loaded from file
        repurposer = NeuralNetworkRepurposer(source_model=None)
        repurposer.target_model = target_model

        # Ensure prediction results are consistent, first for probabilities and then for labels
        for test_predict_probability in (True, False):
            self._predict_and_compare_results(
                repurposer,
                test_iterator,
                test_predict_probability=test_predict_probability)
    def test_validate_before_repurpose(self):
        """ Test that _validate_before_repurpose rejects invalid source models and accepts self.mxnet_model """
        not_a_module_error = "Cannot repurpose because source_model is not an `mxnet.mod.Module` object"

        # Invalid inputs: a missing source model and a source model of the wrong type
        for invalid_source_model in (None, ''):
            repurposer = NeuralNetworkRepurposer(source_model=invalid_source_model)
            with self.assertRaisesRegex(TypeError, not_a_module_error):
                repurposer._validate_before_repurpose()

        # Valid input: validation must complete without raising
        repurposer = NeuralNetworkRepurposer(source_model=self.mxnet_model)
        repurposer._validate_before_repurpose()