def _test_save_load_repurposed_model(self, save_source):
        """Round-trip a repurposed model through save_repurposer/load and compare.

        Loads the test network from 'tests/data/testnetv1', builds a repurposer
        with self._get_repurposer, repurposes it on self.train_iter and predicts
        on self.test_iter. It then saves the repurposer under 'test_serialisation',
        reloads it and checks that:
        - the reloaded object has exactly the same type;
        - the test accuracies differ by less than 0.05;
        - get_params() and the attributes match.

        :param save_source: If True, the source model is saved together with the
            repurposer and load() is called without a source model; otherwise the
            in-memory source model is passed to load().
        """
        file_path = 'test_serialisation'
        RepurposerTestUtils._remove_files_with_prefix(file_path)
        source_model = mx.module.Module.load(
            'tests/data/testnetv1',
            0,
            label_names=['softmaxoutput1_label'],
            data_names=('data', ))
        repurposer = self._get_repurposer(source_model)
        repurposer.repurpose(self.train_iter)
        results = repurposer.predict_label(test_iterator=self.test_iter)
        assert not os.path.isfile(file_path + '.json')
        try:
            repurposer.save_repurposer(model_name=file_path,
                                       save_source_model=save_source)
            assert os.path.isfile(file_path + '.json')
            if save_source:
                # The source model was saved with the repurposer, so none is passed in.
                loaded_repurposer = load(file_path)
            else:
                loaded_repurposer = load(file_path,
                                         source_model=repurposer.source_model)
            results_loaded = loaded_repurposer.predict_label(
                test_iterator=self.test_iter)

            # Exact type match (not isinstance): loading must restore the same class.
            assert type(repurposer) is type(loaded_repurposer)
            accuracy1 = np.mean(results == self.test_labels)
            accuracy2 = np.mean(results_loaded == self.test_labels)
            assert abs(accuracy1 - accuracy2) < 0.05

            assert repurposer.get_params() == loaded_repurposer.get_params()
            self._assert_attributes_equal(repurposer, loaded_repurposer)
        finally:
            # Always delete the saved files, even when a save/load step or assertion fails.
            RepurposerTestUtils._remove_files_with_prefix(file_path)
Пример #2
0
    def _test_save_load_repurposed_model(self, mock_model_handler, save_source_model):
        """Round-trip a repurposer trained on a tiny feature subset through save/load.

        The model handler is mocked. The target model is trained directly from
        the first two training features, and predictions use the mocked layer
        outputs (self.test_feature_dict, self.test_labels).

        The test checks that:
        - the source-model files are written when requested;
        - the reloaded repurposer has the same type, target model and attributes;
        - test accuracies agree, within 0.1 for BnnRepurposer and exactly otherwise.

        :param mock_model_handler: Mock standing in for the model-handler class; its
            return_value is configured here.
        :param save_source_model: Whether to save the source model alongside the repurposer.
        """
        # To speed-up unit test running time. Accuracy is validated in integration tests.
        num_train_points = 2
        self.train_features = self.train_features[:num_train_points]
        self.train_labels = self.train_labels[:num_train_points]

        mock_model_handler.return_value = RepurposerTestUtils.get_mock_model_handler_object()
        file_path = 'test_serialisation'
        RepurposerTestUtils._remove_files_with_prefix(file_path)
        source_model = mx.module.Module.load('tests/data/testnetv1', 0, label_names=['softmaxoutput1_label'],
                                             data_names=('data',))
        is_bnn = self.repurposer_class is BnnRepurposer
        if is_bnn:
            # Single epoch and few MC samples, presumably to keep this unit test fast.
            repurposer = BnnRepurposer(source_model, self.source_model_layers, num_epochs=1,
                                       num_samples_mc_prediction=15)
        else:
            repurposer = self.repurposer_class(source_model, self.source_model_layers)
        repurposer.target_model = repurposer._train_model_from_features(self.train_features, self.train_labels)
        # Manually setting provide_data and provide_label because repurpose() is not called
        repurposer.provide_data = [('data', (2, 3, 224, 224))]
        repurposer.provide_label = [('softmaxoutput1_label', (2,))]
        # Mocking iterator because get_layer_output is patched
        mock_model_handler.return_value.get_layer_output.return_value = self.test_feature_dict, self.test_labels
        results = repurposer.predict_label(test_iterator=self.mock_object)
        assert not os.path.isfile(file_path + '.json')

        source_files = (file_path + '_source-symbol.json', file_path + '_source-0000.params')
        try:
            if save_source_model:
                for source_file in source_files:
                    assert not os.path.isfile(source_file)
                repurposer.save_repurposer(model_name=file_path, save_source_model=save_source_model)
                for source_file in source_files:
                    assert os.path.isfile(source_file)
                # Source model files were written, so load() is called without a source model.
                loaded_repurposer = load(file_path)
            else:
                repurposer.save_repurposer(model_name=file_path, save_source_model=save_source_model)
                loaded_repurposer = load(file_path, source_model=repurposer.source_model)
            assert os.path.isfile(file_path + '.json')
        finally:
            # Always delete the saved files, even when saving/loading or a file check fails.
            RepurposerTestUtils._remove_files_with_prefix(file_path)

        results_loaded = loaded_repurposer.predict_label(test_iterator=self.mock_object)
        # Exact type match (not isinstance): loading must restore the same class.
        assert type(repurposer) is type(loaded_repurposer)
        self._assert_target_model_equal(repurposer.target_model, loaded_repurposer.target_model)
        accuracy1 = np.mean(results == self.test_labels)
        accuracy2 = np.mean(results_loaded == self.test_labels)
        message = 'Inconsistent accuracies: {}, {}.'.format(accuracy1, accuracy2)

        if is_bnn:
            # BNN predictions are presumably sampled (MC), so only approximate agreement is required.
            assert np.isclose(accuracy1, accuracy2, atol=0.1), message
        else:
            assert accuracy1 == accuracy2, message

        self._assert_attributes_equal(repurposer, loaded_repurposer)
Пример #3
0
 def _save_and_load_repurposer(self, gp_repurposer):
     """Save *gp_repurposer* without its source model, reload it and return the copy.

     The repurposer is written under 'test_serialisation'. It is reloaded with
     gp_repurposer.source_model supplied explicitly, and the saved files are
     deleted before returning, even if saving or loading fails.

     :param gp_repurposer: Repurposer to round-trip.
     :return: The repurposer returned by load().
     """
     file_path = 'test_serialisation'
     RepurposerTestUtils._remove_files_with_prefix(file_path)
     assert not os.path.isfile(file_path + '.json')
     try:
         gp_repurposer.save_repurposer(model_name=file_path, save_source_model=False)
         assert os.path.isfile(file_path + '.json')
         loaded_repurposer = load(file_path, source_model=gp_repurposer.source_model)
     finally:
         # Clean up on every path so a failed round-trip does not leave files behind.
         RepurposerTestUtils._remove_files_with_prefix(file_path)
     return loaded_repurposer
Пример #4
0
 def test_load_pre_saved_repurposer(self):
     """Check backward compatibility of deserialization.

     Loads the repurposer previously saved under pre_saved_prefix followed by
     the test class name, supplying the in-memory source model. It then asserts
     the prediction accuracy on the test set.
     """
     # Nothing to check when running on the WorkflowTestCase base class itself.
     if self.__class__ != WorkflowTestCase:
         saved_prefix = self.pre_saved_prefix + self.__class__.__name__
         restored = xfer.load(saved_prefix, source_model=self.source_model)
         predictions = restored.predict_label(self.test_iter)
         self.assert_accuracy(np.mean(predictions == self.test_labels))
Пример #5
0
    def test_workflow(self):
        """
        Test workflow

        Instantiate repurposer(1), repurpose(2), predict(3), save & load with source model(4),
        predict(5), repurpose(6), predict(7), save & load without model(8), predict(9)

        Files written under self.save_name are deleted after every save & load step,
        including when saving or loading raises.
        """
        if self.__class__ == WorkflowTestCase:  # base class
            return
        # remove any old saved repurposer files
        RepurposerTestUtils._remove_files_with_prefix(self.save_name)

        # instantiate repurposer (1)
        rep = self.get_repurposer(self.source_model)

        for save_source_model in [True, False]:
            # (2/6) repurpose
            # random seeds are set before repurposing to ensure training is the same
            np.random.seed(1)
            random.seed(1)
            mx.random.seed(1)
            rep.repurpose(self.train_iter)
            # (3/7) predict
            results = rep.predict_label(self.test_iter)
            accuracy = np.mean(results == self.test_labels)
            self.assert_accuracy(accuracy)
            # (4/8) serialise
            try:
                rep.save_repurposer(self.save_name,
                                    save_source_model=save_source_model)
                # Drop the in-memory repurposer so later steps only use the reloaded one.
                del rep
                if save_source_model:
                    rep = xfer.load(self.save_name)
                else:
                    rep = xfer.load(self.save_name, source_model=self.source_model)
            finally:
                # Remove the saved files even if saving or loading failed.
                RepurposerTestUtils._remove_files_with_prefix(self.save_name)
            # (5/9) predict
            results = rep.predict_label(self.test_iter)
            accuracy = np.mean(results == self.test_labels)
            self.assert_accuracy(accuracy)