def compute_det_pr_and_hard_neg(dets, gt, min_overlap=0.5):
    """
    Compute the Precision-Recall and find hard negatives of the given
    detections for the ground truth.

    Args:
        dets (skpyutils.Table): detections.
            !NOTE: the first four columns must be the bounding box coordinates!
            A "score" column is required (used for sorting), as is "cls_ind";
            "img_ind" is read when gt has an "img_ind" column.

        gt (skpyutils.Table): detection ground truth.
            Can be for a single image or a whole dataset, and can contain
            either all classes or a single class. The 'cls_ind' column must be
            present in either case. A 'diff' column is also read.

            Note that depending on these choices, the meaning of the PR
            evaluation is different. In particular, if gt is for a single
            class but detections are for multiple classes, there will be a lot
            of false positives!

        min_overlap (float): minimum required overlap (area of intersection
            over area of union) between a detection and a ground truth box
            for the detection to count as a true positive.

    Returns:
        (ap, recall, precision, hard_negatives, sorted_dets): tuple where
        hard_negatives is a 0/1 mask onto the sorted detections and
        sorted_dets is a copy of dets sorted by descending score.

        If dets or gt is empty, ap is 0, the three arrays are np.array([0]),
        and the last element is dets exactly as passed in (not copied or
        sorted), so the result is always a 5-tuple.
    """
    # If dets or gt are empty, return 0's. The tuple has the same arity as the
    # regular return value so that callers can always unpack five elements.
    nd = dets.arr.shape[0]
    if nd < 1 or gt.shape[0] < 1:
        ap = 0
        rec = np.array([0])
        prec = np.array([0])
        hard_negs = np.array([0])
        return (ap, rec, prec, hard_negs, dets)

    # The timer is only used to throttle the progress messages in the loop.
    tt = TicToc().tic()

    # augment gt with a column keeping track of matches
    cols = list(gt.cols) + ["matched"]
    arr = np.zeros((gt.arr.shape[0], gt.arr.shape[1] + 1))
    arr[:, :-1] = gt.arr
    gt = Table(arr, cols)

    # sort detections by confidence (on a copy, so the caller's table is
    # left untouched)
    dets = dets.copy()
    dets.sort_by_column("score", descending=True)

    # match detections to ground truth objects
    npos = gt.filter_on_column("diff", 0).shape[0]
    tp = np.zeros(nd)
    fp = np.zeros(nd)
    hard_neg = np.zeros(nd)

    # The gt columns do not change inside the loop, so look these up once.
    has_img_ind = "img_ind" in gt.cols
    diff_col = gt.ind("diff")
    matched_col = gt.ind("matched")

    for d in range(nd):
        if tt.qtoc() > 15:
            print("... on %d/%d dets" % (d, nd))
            tt.tic()

        det = dets.arr[d, :]

        # find ground truth for this image
        if has_img_ind:
            img_ind = det[dets.ind("img_ind")]
            inds = gt.arr[:, gt.ind("img_ind")] == img_ind
            # boolean-mask indexing yields a copy; it is written back below
            gt_for_image = gt.arr[inds, :]
        else:
            gt_for_image = gt.arr

        if gt_for_image.shape[0] < 1:
            # false positive due to a det in image that does not contain the
            # class
            # NOTE: this can happen if we're passing ground truth for a class
            fp[d] = 1
            hard_neg[d] = 1
            continue

        # find the maximally overlapping ground truth element for this
        # detection
        overlaps = BoundingBox.get_overlap(gt_for_image[:, :4], det[:4])
        jmax = overlaps.argmax()
        ovmax = overlaps[jmax]

        # assign detection as true positive/don't care/false positive
        if ovmax >= min_overlap:
            if gt_for_image[jmax, diff_col]:
                # not a false positive because object is difficult
                pass
            elif gt_for_image[jmax, matched_col] == 0:
                if gt_for_image[jmax, gt.ind("cls_ind")] == det[dets.ind("cls_ind")]:
                    # true positive
                    tp[d] = 1
                    gt_for_image[jmax, matched_col] = 1
                else:
                    # false positive due to wrong class
                    fp[d] = 1
                    hard_neg[d] = 1
            else:
                # false positive due to multiple detection
                # this is still a correct answer, so not a hard negative
                fp[d] = 1
        else:
            # false positive due to not matching any ground truth object
            fp[d] = 1
            hard_neg[d] = 1

        # NOTE: must do this for gt.arr to get the changes we made to
        # gt_for_image
        if has_img_ind:
            gt.arr[inds, :] = gt_for_image

    ap, rec, prec = compute_rec_prec_ap(tp, fp, npos)
    return (ap, rec, prec, hard_neg, dets)
def plot_coocurrence(self, cmap=plt.cm.Reds, color_anchor=(0, 1),
                     x_tick_rot=90, size=None, title=None, plot_vals=True,
                     second_order=False):
    """
    Plot a heat map of conditional occurrence, where cell (i,j) means
    P(C_j|C_i). In the K x (K+2) heat map, the second-to-last column is
    P("nothing"|C_i) and the last column is the prior P(C_i).

    If second_order, plots (K choose 2) x (K+2) heat map corresponding to
    P(C_i|C_j,C_k): second-order correlations.

    Args:
        cmap: matplotlib colormap passed to imshow and the colorbar.
        color_anchor: two-element sequence (vmin, vmax) for the color scale.
        x_tick_rot: rotation of the x tick labels.
        size: figure size passed to plt.figure; if falsy, a size is derived
            from the matrix shape (at least 12 x 12).
        title: optional axes title.
        plot_vals: if True, write the numeric value into every non-NaN cell.
        second_order: if True, condition on pairs of classes.

    The figure is saved as 'cooccur.png' (or 'cooccur_second_order.png') in
    the directory given by self.config.get_dataset_stats_dir(self).

    Return the figure.
    """
    table = self.get_cls_ground_truth(with_diff=False, with_trun=True)

    # This takes care of most of the difference between normal and
    # second_order. In the former case, a "combination" is just one class to
    # condition on.
    combinations = combination_strs = table.cols
    if second_order:
        combinations = list(itertools.combinations(table.cols, 2))
        combination_strs = ['%s, %s' % (x[0], x[1]) for x in combinations]

    total = table.shape[0]
    N = len(table.cols)
    K = len(combinations)
    # extra columns are for P("nothing"|C) and P(C)
    data = np.zeros((K, N + 2))
    for i, combination in enumerate(combinations):
        if second_order:
            cls1 = combination[0]
            cls2 = combination[1]
            conditioned = table.filter_on_column(
                cls1).filter_on_column(cls2)
        else:
            conditioned = table.filter_on_column(combination)

        # count all the classes
        data[i, :-2] = conditioned.sum()

        # count the number of times that cls was the only one present to get
        # P("nothing"|C)
        n_conditioned_on = 2 if second_order else 1
        data[i, -2] = ((conditioned.sum(1) - n_conditioned_on) == 0).sum()

        # normalize
        row = data[i, :]  # a view: edits below land in data
        max_val = np.max(row)
        if max_val > 0:
            row /= max_val
            # cells equal to the row maximum are blanked out (not plotted)
            row[row == 1] = np.nan
        else:
            # All counts are zero. Dividing would produce NaN everywhere via
            # 0/0 (plus a RuntimeWarning); set the same result explicitly.
            row[:] = np.nan

        # use the max count to compute the prior
        row[-1] = max_val / total

    m = Table(
        data, table.cols + ['nothing', 'prior'], index=combination_strs)

    # If second_order, sort by prior and remove rows with 0 prior
    if second_order:
        m = m.filter_on_column('prior', 0.001, operator.gt).\
            sort_by_column('prior', descending=True)
        # TODO: just take the top K actually, for a side-by-side figure
        m.arr = m.arr[:len(self.classes), :]

    if size:
        fig = plt.figure(figsize=size)
    else:
        w = max(12, m.shape[1])
        h = max(12, m.shape[0])
        fig = plt.figure(figsize=(w, h))
    ax_im = fig.add_subplot(111)

    # make axes for colorbar
    divider = make_axes_locatable(ax_im)
    ax_cb = divider.new_vertical(size="5%", pad=0.1, pack_start=True)
    fig.add_axes(ax_cb)

    # The call to imshow produces the matrix plot:
    im = ax_im.imshow(m.arr, origin='upper', interpolation='nearest',
                      vmin=color_anchor[0], vmax=color_anchor[1], cmap=cmap)

    # Formatting:
    ax = ax_im
    ax.set_xticks(np.arange(m.shape[1]))
    ax.set_xticklabels(m.cols)
    # Show the x tick labels on the second label slot instead of the first.
    # NOTE(review): iter_ticks/label1On/label2On look like an older
    # matplotlib API - confirm against the matplotlib version in use.
    for tick in ax.xaxis.iter_ticks():
        tick[0].label2On = True
        tick[0].label1On = False
        tick[0].label2.set_rotation(x_tick_rot)
        tick[0].label2.set_fontsize('x-large')

    ax.set_yticks(np.arange(m.shape[0]))
    ax.set_yticklabels(m.index, size='x-large')

    # Minor ticks sit on the cell boundaries; the white minor grid drawn on
    # them visually separates the cells.
    ax.yaxis.set_minor_locator(
        mpl.ticker.FixedLocator(np.arange(-.5, m.shape[0] + 0.5)))
    ax.xaxis.set_minor_locator(
        mpl.ticker.FixedLocator(np.arange(-.5, m.shape[1] - 0.5)))
    ax.grid(False, which='major')
    ax.grid(True, which='minor', ls='-', lw=7, c='w')

    # Make the major and minor tick marks invisible
    for line in ax.xaxis.get_ticklines() + ax.yaxis.get_ticklines():
        line.set_markeredgewidth(0)
    for line in ax.xaxis.get_minorticklines() + ax.yaxis.get_minorticklines():
        line.set_markeredgewidth(0)

    # Limit the area of the plot
    ax.set_ybound([-0.5, m.shape[0] - 0.5])
    ax.set_xbound([-0.5, m.shape[1] - 0.5])

    # The following produces the colorbar and sets the ticks
    # Set the ticks - if 0 is in the interval of values, set that, as well
    # as the maximal and minimal values:
    # Extract the minimum and maximum values for scaling
    max_val = np.nanmax(m.arr)
    min_val = np.nanmin(m.arr)
    if min_val < 0:
        ticks = [color_anchor[0], min_val, 0, max_val, color_anchor[1]]
    # Otherwise - only set the maximal value:
    else:
        ticks = [color_anchor[0], max_val, color_anchor[1]]

    # Plot line separating 'nothing' and 'prior' from rest of plot
    l = ax.add_line(mpl.lines.Line2D(
        [m.shape[1] - 2.5, m.shape[1] - 2.5], [-.5, m.shape[0] - 0.5],
        ls='--', c='gray', lw=2))
    l.set_zorder(3)

    # Display the actual values in the cells
    if plot_vals:
        # range (not xrange) so this also runs on Python 3
        for i in range(0, m.shape[0]):
            for j in range(0, m.shape[1]):
                val = m.arr[i, j]
                if np.isnan(val):
                    continue
                # white text on dark cells, black text on light cells
                text_color = 'w' if val > 0.5 else 'k'
                ax.text(j - 0.2, i + 0.1, '%.2f' % val, color=text_color)

    # Hide the black frame around the plot
    # Doing ax.set_frame_on(False) results in weird thin lines
    # from imshow() at the edges. Instead, we set the frame to white.
    for spine in ax.spines.values():
        spine.set_edgecolor('w')

    # Set title
    if title is not None:
        ax.set_title(title)

    # Plot the colorbar and remove its frame as well.
    cb = fig.colorbar(im, cax=ax_cb, orientation='horizontal',
                      cmap=cmap, ticks=ticks, format='%.2f')
    cb.ax.artists.remove(cb.outline)

    # Save figure
    dirname = self.config.get_dataset_stats_dir(self)
    suffix = '_second_order' if second_order else ''
    filename = os.path.join(dirname, 'cooccur%s.png' % suffix)
    fig.savefig(filename)

    return fig