コード例 #1
0
ファイル: svmMet.py プロジェクト: niuye8911/rapidlib-linux
    def getQoS(self):
        """Return the linear SVM's CIFAR-10 test accuracy as a percentage.

        Loads the CIFAR-10 split from ``self.data_path``, flattens each
        sample to a row vector, centres the test set on the per-feature mean
        of the *training* set, and appends a constant-1 column. The SVM
        weights are unpickled from ``self.run_dir + "model_svm.p"`` and used
        to predict the test labels.

        Returns:
            Fraction of correctly predicted test labels, times 100.
            Returns 0.0 when the model file cannot be opened or unpickled,
            or when prediction fails. This is deliberate best-effort
            behaviour: a missing or broken model scores as zero QoS.
        """
        X_train, _, X_test, y_test = load_CIFAR10(self.data_path)

        # Flatten each sample to one row so the mean below is per-feature.
        X_test = np.reshape(X_test, (X_test.shape[0], -1))
        X_train = np.reshape(X_train, (X_train.shape[0], -1))

        # Centre the test data with the *training* mean. This presumably
        # matches the preprocessing used at training time (TODO confirm).
        mean_image = np.mean(X_train, axis=0)
        # Only the mean is needed from the training set. Drop it before the
        # next allocation to keep peak memory down.
        del X_train

        # Use out-of-place subtraction. An in-place `-=` raises a casting
        # error if the loader ever returns an integer dtype.
        X_test = X_test - mean_image
        # Append a constant-1 column so a bias term can be folded into W.
        X_test = np.hstack([X_test, np.ones((X_test.shape[0], 1))])

        svm = LinearSVM()
        try:
            # NOTE(security): pickle.load can execute arbitrary code. Only
            # load model files produced by a trusted training run.
            # encoding='latin1' lets Python 3 read pickles (e.g. NumPy
            # arrays) that were written by Python 2.
            with open(self.run_dir + "model_svm.p", "rb") as model_file:
                svm.W = pickle.load(model_file, encoding='latin1')
            y_test_pred = svm.predict(X_test)
            test_accuracy = np.mean(y_test == y_test_pred)
        except Exception:
            # Best-effort: any failure to load or apply the model scores 0.
            # Unlike a bare `except:`, this still lets KeyboardInterrupt and
            # SystemExit propagate.
            test_accuracy = 0.0
        return test_accuracy * 100.0