def predict(self, X_Test, patClassIdTest):
    """
    Perform classification on a test set.

    result = predict(X_Test, patClassIdTest)

    INPUT:
        X_Test           Test data (rows = objects, columns = features)
        patClassIdTest   Test data class labels (crisp)

    OUTPUT:
        result   An object with Bunch datatype containing all results:
                 + summis    Number of misclassified objects
                 + misclass  Binary error map
                 + sumamb    Number of objects with maximum membership
                             in more than one class
                 + out       Soft class memberships
                 + mem       Hyperbox memberships

        Returns None when every test sample falls outside the trained
        [loLim, hiLim] interval after normalization.
    """
    # Normalize testing dataset if training datasets were normalized
    if len(self.mins) > 0:
        noSamples = X_Test.shape[0]
        # apply the same min-max scaling learned from the training data
        X_Test = self.loLim + (self.hiLim - self.loLim) * (
            X_Test - np.ones((noSamples, 1)) * self.mins) / (
            np.ones((noSamples, 1)) * (self.maxs - self.mins))

        if X_Test.min() < self.loLim or X_Test.max() > self.hiLim:
            print('Test sample falls outside', self.loLim, '-', self.hiLim, 'interval')
            print('Number of original samples = ', noSamples)

            # only keep samples within the interval loLim-hiLim
            indX_Keep = np.where((X_Test >= self.loLim).all(axis=1) &
                                 (X_Test <= self.hiLim).all(axis=1))[0]
            X_Test = X_Test[indX_Keep, :]
            # BUG FIX: the class labels must be filtered with the same mask,
            # otherwise X_Test and patClassIdTest end up with different
            # lengths when samples are dropped
            patClassIdTest = patClassIdTest[indX_Keep]

            print('Number of kept samples =', X_Test.shape[0])

    # do classification (skip when every sample was filtered out)
    result = None
    if X_Test.shape[0] > 0:
        result = predict(self.V, self.W, self.classId, X_Test,
                         patClassIdTest, self.gamma)

    return result
def pruning_val(self, XTest, patClassIdTest, accuracy_threshold=0.5):
    """
    Pruning handling based on validation (validation routine) with hyperboxes
    stored in self.V, self.W, self.classId.

    result = pruning_val(XTest, patClassIdTest)

    INPUT:
        XTest               Test data (rows = objects, columns = features)
        patClassIdTest      Test data class labels (crisp)
        accuracy_threshold  The minimum accuracy for each hyperbox
    """
    # initialization: column 0 counts correct predictions per hyperbox,
    # column 1 counts incorrect predictions per hyperbox
    yX = XTest.shape[0]
    mem = np.zeros((yX, self.V.shape[0]))
    no_predicted_samples_hyperboxes = np.zeros((len(self.classId), 2))

    # classifications
    for i in range(yX):
        # calculate memberships for all hyperboxes
        mem[i, :] = simpsonMembership(XTest[i, :], self.V, self.W, self.gamma)
        bmax = mem[i, :].max()  # get max membership value
        # get indexes of all hyperboxes with max membership
        maxVind = np.nonzero(mem[i, :] == bmax)[0]

        if len(maxVind) == 1:
            # Only one hyperbox with the highest membership function
            if self.classId[maxVind[0]] == patClassIdTest[i]:
                no_predicted_samples_hyperboxes[maxVind[0], 0] += 1
            else:
                no_predicted_samples_hyperboxes[maxVind[0], 1] += 1
        else:
            # More than one hyperbox with highest membership => random choosing
            id_min = maxVind[np.random.randint(len(maxVind))]
            # label 0 is treated as "unlabeled" and never counted as an error
            if self.classId[id_min] != patClassIdTest[i] and patClassIdTest[i] != 0:
                no_predicted_samples_hyperboxes[id_min, 1] += 1
            else:
                no_predicted_samples_hyperboxes[id_min, 0] += 1

    # pruning handling based on the validation results
    tmp_no_box = no_predicted_samples_hyperboxes.shape[0]
    # BUG FIX: np.bool was deprecated in NumPy 1.20 and removed in 1.24;
    # the builtin bool is the correct dtype here
    accuracy_larger_half = np.zeros(tmp_no_box).astype(bool)
    accuracy_larger_half_keep_nojoin = np.zeros(tmp_no_box).astype(bool)
    for i in range(tmp_no_box):
        n_correct = no_predicted_samples_hyperboxes[i, 0]
        n_wrong = no_predicted_samples_hyperboxes[i, 1]
        # keep hyperboxes whose validation accuracy reaches the threshold
        if (n_correct + n_wrong != 0) and n_correct / (n_correct + n_wrong) >= accuracy_threshold:
            accuracy_larger_half[i] = True
            accuracy_larger_half_keep_nojoin[i] = True
        if n_correct + n_wrong == 0:
            # hyperbox never selected during validation: keep it in the
            # more conservative "keep" candidate set
            accuracy_larger_half_keep_nojoin[i] = True

    # keep one hyperbox for class prunned all
    current_classes = np.unique(self.classId)
    class_tmp = self.classId[accuracy_larger_half]
    for c in current_classes:
        if c not in class_tmp:
            # BUG FIX: np.nonzero returns a tuple of index arrays, so
            # len(pos) on the tuple is always 1 and pos[id_kept] was the
            # whole index array (keeping ALL boxes of class c instead of
            # one). Take the index array first, then pick one at random.
            pos = np.nonzero(self.classId == c)[0]
            id_kept = np.random.randint(len(pos))
            # keep pos[id_kept]
            accuracy_larger_half[pos[id_kept]] = True

    # Pruning: build both candidate classifiers and keep the better one
    V_prun_remove = self.V[accuracy_larger_half]
    W_prun_remove = self.W[accuracy_larger_half]
    classId_prun_remove = self.classId[accuracy_larger_half]

    W_prun_keep = self.W[accuracy_larger_half_keep_nojoin]
    V_prun_keep = self.V[accuracy_larger_half_keep_nojoin]
    classId_prun_keep = self.classId[accuracy_larger_half_keep_nojoin]

    # evaluate both pruned models on the validation set
    result_prun_remove = predict(V_prun_remove, W_prun_remove,
                                 classId_prun_remove, XTest,
                                 patClassIdTest, self.gamma)
    result_prun_keep_nojoin = predict(V_prun_keep, W_prun_keep,
                                      classId_prun_keep, XTest,
                                      patClassIdTest, self.gamma)

    # adopt whichever variant misclassifies fewer validation samples
    if (result_prun_remove.summis <= result_prun_keep_nojoin.summis):
        self.V = V_prun_remove
        self.W = W_prun_remove
        self.classId = classId_prun_remove
    else:
        self.V = V_prun_keep
        self.W = W_prun_keep
        self.classId = classId_prun_keep