def train_refined_SRClassifier(wcd):
    """Train a weak classifier with train_SRClassifier(), then refine its threshold.

    Training is delegated to train_SRClassifier() (configured, per the
    original author, for 6 rectangles and 2 seconds). Every input of `wcd`
    is then projected onto the learned direction, and the threshold is
    re-fitted on those 1D values with thresh_1d(). According to the original
    author, the objective minimised is the classification error with
    lambda = 1.0.

    :Parameters:
        wcd : WeightedCDataset
            a weighted dataset of vector-integrated patches

    :Returns:
        fc : THaarClassifier
            the best feature classifier obtained, where
            fc.b       -- the threshold after refinement
            fc.err     -- the estimated 'error' corresponding to fc.b
            fc.b2      -- the threshold originally estimated by
                          train_SRClassifier()
            fc.err2    -- the 'estimated error' corresponding to fc.b2
            fc.stats2e -- the 1D statistics obtained by projecting the
                          dataset statistics onto the learned direction
    """
    global_stats = compute_Stats2_integral(wcd, True)

    # Positional arguments are kept exactly as before. The docstring's
    # "6 rectangles and 2 seconds" refers to the last two; the meaning of
    # 0 and 1 is not visible from here.
    fc = train_SRClassifier(global_stats, 0, 1, 6, 2)[0]

    # Keep train_SRClassifier()'s own threshold/error before refinement
    # overwrites fc.b and fc.err below.
    fc.b2, fc.err2 = fc.b, fc.err

    # Project every input onto the newly learned direction.
    w, p = fc.w, fc.p
    projected = [project_Haar(x, w, p) for x in wcd.input_data]
    wcd_1d = WeightedCDataset(projected, wcd.weights)

    # 1D statistics of the original dataset along the same direction.
    fc.stats2e = project_SR_stats(global_stats, w, p)

    # Re-fit the threshold on the projected, sorted 1D data.
    refined = thresh_1d(0, 1, wcd_1d, sort_1d(wcd_1d))
    fc.b = refined[0]
    fc.err = refined[1]
    return fc
def get_local_stats(cd, haars):
    """Project the global statistics of a 2-class dataset onto Haar features.

    The global statistics of `cd` are accumulated with
    compute_Stats2_integral(cd, True). The result is then projected through
    the Haar feature set by _project_Haar_stats(). Progress messages are
    emitted via tprint() before each stage.

    Input:
        cd: a 2-class ClassificationDataset
        haars: the generated Haar feature set of J features

    Output:
        whatever _project_Haar_stats() returns. Per the original author:
          B: a list of J 2-class Stats2e objects, one per Haar feature
          invert: a list of numpy.arrays, each indicating which features
                  need their projection direction negated
    """
    tprint("Accumulating data...")
    global_stats = compute_Stats2_integral(cd, True)

    tprint("Projecting all the Haar features...")
    projected = _project_Haar_stats(global_stats, haars)
    return projected