def _test_predict(self, mock_validate_method, test_predict_probability):
    """Predict on the MNIST test set; check the validate call count and the accuracy.

    Loads the model stored at ``RepurposerTestUtils.MNIST_MODEL_PATH_PREFIX`` (epoch 10)
    as the target model of a ``NeuralNetworkRepurposer`` and predicts labels for the
    MNIST test iterator.

    :param mock_validate_method: Mock standing in for the validation method named by
        ``RepurposerTestUtils.VALIDATE_PREDICT_METHOD_NAME`` (presumably patched in by
        the calling test). It must be called exactly once by the predict call.
    :param test_predict_probability: If True, labels are the argmax (axis 1) of
        ``predict_probability``; otherwise they come from ``predict_label``.
    """
    neural_network_repurposer = NeuralNetworkRepurposer(source_model=None)
    neural_network_repurposer.target_model = mx.module.Module.load(
        prefix=RepurposerTestUtils.MNIST_MODEL_PATH_PREFIX, epoch=10, label_names=None)
    test_iterator = RepurposerTestUtils.create_mnist_test_iterator()

    # Clear any calls recorded before this point so only the predict call is counted.
    mock_validate_method.reset_mock()
    if test_predict_probability:
        labels = np.argmax(
            neural_network_repurposer.predict_probability(test_iterator), axis=1)
    else:
        labels = neural_network_repurposer.predict_label(test_iterator)

    # Check that predict called validate exactly once; assertEqual reports the actual
    # count in the failure output (assertTrue(a == b) would only report "False").
    self.assertEqual(
        mock_validate_method.call_count, 1,
        "Predict expected to call {} once. Found {} calls".format(
            RepurposerTestUtils.VALIDATE_PREDICT_METHOD_NAME,
            mock_validate_method.call_count))

    expected_accuracy = 0.96985
    accuracy = np.mean(labels == RepurposerTestUtils.get_labels(test_iterator))
    # Relative tolerance of 1e-3 around the reference accuracy.
    self.assertTrue(
        np.isclose(accuracy, expected_accuracy, rtol=1e-3),
        "Prediction accuracy is incorrect. Expected:{}. Got:{}".format(
            expected_accuracy, accuracy))
def test_repurpose_calls_validate(self, mock_create_target_module, mock_validate_method):
    """Test that ``repurpose`` calls the repurpose-validation method exactly once.

    :param mock_create_target_module: Mock injected by a patch decorator (not visible
        here); presumably replaces target-module creation so no real training runs.
    :param mock_validate_method: Mock standing in for the validation method named by
        ``RepurposerTestUtils.VALIDATE_REPURPOSE_METHOD_NAME``.
    """
    neural_network_repurposer = NeuralNetworkRepurposer(source_model=self.mxnet_model)
    neural_network_repurposer.target_model = mx.module.Module.load(
        prefix=self.dropout_model_path_prefix, epoch=0)

    # Clear earlier calls so only the repurpose call below is counted.
    mock_validate_method.reset_mock()
    neural_network_repurposer.repurpose(Mock())

    # assertEqual shows the actual call count on failure, unlike assertTrue(a == b).
    self.assertEqual(
        mock_validate_method.call_count, 1,
        "Repurpose expected to call {} once. Found {} calls".format(
            RepurposerTestUtils.VALIDATE_REPURPOSE_METHOD_NAME,
            mock_validate_method.call_count))
def test_validate_before_predict(self):
    """``_validate_before_predict`` rejects missing, non-Module and untrained target models.

    A target model loaded from ``self.dropout_model_path_prefix`` (epoch 0) must
    pass validation without raising.
    """
    not_module_msg = "Cannot predict because target_model is not an `mxnet.mod.Module` object"
    untrained_msg = "target_model params aren't initialized. Ensure model is trained before calling predict"

    repurposer = NeuralNetworkRepurposer(source_model=self.mxnet_model)

    # target_model was neither produced by repurpose nor explicitly assigned.
    with self.assertRaisesRegex(TypeError, not_module_msg):
        repurposer._validate_before_predict()

    # target_model assigned a value that is not an mxnet Module.
    repurposer.target_model = {}
    with self.assertRaisesRegex(TypeError, not_module_msg):
        repurposer._validate_before_predict()

    # target_model is a module whose params have not been trained yet.
    repurposer.target_model = self.mxnet_model
    with self.assertRaisesRegex(ValueError, untrained_msg):
        repurposer._validate_before_predict()

    # A module loaded from a checkpoint is accepted (no exception expected).
    repurposer.target_model = mx.module.Module.load(
        prefix=self.dropout_model_path_prefix, epoch=0)
    repurposer._validate_before_predict()
def test_prediction_consistency(self):
    """
    Test if predict method returns consistent predictions using the same model and test data

    Only applies to ``NeuralNetworkRepurposer``. For any other ``self.repurposer_class``
    the test is reported as skipped instead of silently passing.
    """
    if self.repurposer_class is not NeuralNetworkRepurposer:
        self.skipTest('test_prediction_consistency only applies to NeuralNetworkRepurposer')

    # Create test data iterator to run predictions on
    test_iterator = RepurposerTestUtils.get_image_iterator()

    # Load a pre-trained model to predict. The model has a dropout layer used for training.
    # This test is to ensure that dropout doesn't happen during prediction.
    target_model = mx.module.Module.load(
        prefix=self.dropout_model_path_prefix, epoch=0, label_names=None)

    # Create repurposer and set the target model loaded from file
    repurposer = NeuralNetworkRepurposer(source_model=None)
    repurposer.target_model = target_model

    # Ensure prediction results are consistent for both probability and label outputs
    self._predict_and_compare_results(repurposer, test_iterator, test_predict_probability=True)
    self._predict_and_compare_results(repurposer, test_iterator, test_predict_probability=False)
def test_validate_before_repurpose(self):
    """``_validate_before_repurpose`` rejects non-Module source models.

    Checks that ``None`` and ``''`` raise ``TypeError``, and that ``self.mxnet_model``
    (expected to be an mxnet Module) passes without raising.
    """
    expected_message = "Cannot repurpose because source_model is not an `mxnet.mod.Module` object"

    # Invalid inputs, checked in the same order as before: missing, then an empty string.
    for bad_source in (None, ''):
        repurposer = NeuralNetworkRepurposer(source_model=bad_source)
        with self.assertRaisesRegex(TypeError, expected_message):
            repurposer._validate_before_repurpose()

    # Valid input: no exception expected.
    NeuralNetworkRepurposer(source_model=self.mxnet_model)._validate_before_repurpose()