def make_interactive_plot(x, y, im_arr, x_type=None, y_type=None):
    """Scatter ``y`` against ``x`` and pop up an image while hovering a point.

    Args:
        x, y: coordinates of the points to plot.
        im_arr: indexable collection of images; ``im_arr[0]`` seeds the popup,
            the hover handler ``my_hover`` swaps in the hovered point's image.
        x_type, y_type: optional axis labels.

    Blocks in ``plt.show()`` until the window is closed.
    """
    figure = plt.figure()
    axes = figure.add_subplot(111)
    points, = axes.plot(x, y, ls="", marker="o")
    plt.xlabel(x_type)
    plt.ylabel(y_type)

    # Image popup, hidden until the pointer lands on a point.
    offset = (50., 50.)
    thumbnail = OffsetImage(im_arr[0], zoom=5)
    popup = AnnotationBbox(
        thumbnail,
        (0, 0),
        xybox=offset,
        xycoords='data',
        boxcoords="offset points",
        pad=0.3,
        arrowprops=dict(arrowstyle="->"),
    )
    axes.add_artist(popup)
    popup.set_visible(False)

    def _on_motion(event):
        # Delegate to the shared hover handler with everything it needs.
        my_hover(event, im_arr, points, figure, offset, popup, thumbnail, x, y)

    figure.canvas.mpl_connect('motion_notify_event', _on_motion)
    plt.show()
def __fill_retracement_text_annotations__(self, index: int, ret: float, value: float):
    """Create a hidden text annotation for one Fibonacci retracement level.

    The label reads ``"<ret> - <value>"`` (``ret`` with 3 decimals, ``value``
    with 2) and is anchored at ``(index, value)`` in data coordinates. The box
    starts invisible and is stored in
    ``self.__fib_retracement_rectangle_text_dic`` under the key ``ret``; it is
    not added to any axes here.
    """
    text = '{:=1.3f} - {:.2f}'.format(ret, value)
    # ``minimumdescent`` was deprecated in Matplotlib 3.3 and removed in 3.5
    # (passing it raises TypeError there). True was its default, so omitting
    # it keeps the old layout on older versions as well.
    text_area = TextArea(text, textprops=dict(size=7))
    annotation_box = AnnotationBbox(text_area, (index, value))
    annotation_box.set_visible(False)
    self.__fib_retracement_rectangle_text_dic[ret] = annotation_box
def interactive_plot(xdata, ydata, images, supp_values=None, fig=None):
    """
    Displays a plot of xdata, ydata and displays the image corresponding to
    this point when hovering over data.

    Parameters:
        xdata: array, sequence of scalar
        ydata: array, metric values
        images: 3D array of dimensions (n_images, width, height), or a
            sequence of 2D arrays, containing the images from which the values
            ydata are calculated. ``images[i]`` is shown for point ``i``.
        supp_values: list of arrays, values of a different metric, plotted
            against the same xdata (no hover).
        fig: optional existing figure; a new one is created when None.

    Returns:
        (fig, ax): the figure and the axes added to it.
    """
    if fig is None:
        fig = plt.figure()
    ax = fig.add_subplot(111)
    line, = ax.plot(xdata, ydata, ls="-", marker="o")
    if supp_values is not None:
        for ys in supp_values:
            ax.plot(xdata, ys)

    # create the annotations box
    images = np.asarray(images)  # no copy when already an ndarray
    # Axis 0 indexes the images, so the pixel dimensions are axes 1 and 2;
    # scaling by their sum keeps the popup roughly the same size whatever the
    # image resolution (the number of images must not affect it).
    im = OffsetImage(images[0, :, :], zoom=200 / (images.shape[1] + images.shape[2]))
    xybox = (50., 50.)
    ab = AnnotationBbox(im, (0, 0), xybox=xybox, xycoords='data',
                        boxcoords="offset points", pad=0.3,
                        arrowprops=dict(arrowstyle="->"))
    # add it to the axes and make it invisible
    ax.add_artist(ab)
    ab.set_visible(False)

    def hover(event):
        try:
            hit, info = line.contains(event)
            if hit:
                # Overlapping markers can put several indices under the cursor;
                # show the first instead of failing on a one-element unpack.
                ind = info["ind"][0]
                # get the figure size
                w, h = fig.get_size_inches() * fig.dpi
                ws = (event.x > w / 2.) * -1 + (event.x <= w / 2.)
                hs = (event.y > h / 2.) * -1 + (event.y <= h / 2.)
                # if event occurs in the top or right quadrant of the figure,
                # change the annotation box position relative to mouse.
                ab.xybox = (xybox[0] * ws, xybox[1] * hs)
                ab.set_visible(True)
                # place it at the position of the hovered point
                ab.xy = (xdata[ind], ydata[ind])
                # set the image corresponding to that point
                im.set_data(images[ind, :, :])
            else:
                # if the mouse is not over a point
                ab.set_visible(False)
        except Exception as e:
            # Best effort: report the problem and keep the callback alive.
            print("Hovering error")
            print(e)
        fig.canvas.draw_idle()

    # add callback for mouse moves
    fig.canvas.mpl_connect('motion_notify_event', hover)
    return fig, ax
def plot_embedding_images(embedd, labels, paths, info, savefile):
    """Save a 2D scatter of ``embedd`` with a small thumbnail beside each point.

    Args:
        embedd: embeddings, reduced to 2D with ``reduce_dimensionality``.
        labels: per-point class labels; ``labels.squeeze()`` must yield values
            convertible to ``int`` (used to pick a colour from ``indexcolors``).
        paths: image file paths, one per embedded point, in the same order.
        info: mapping legend name -> class index; legend entries are listed
            in ascending index order.
        savefile: output path passed to ``plt.savefig``.
    """
    fig = plt.figure(figsize=(15, 10))
    embedding = reduce_dimensionality(embedd)
    ax = plt.gca()
    ax.set_xlim(-15, 15)
    ax.set_ylim(-10, 10)

    n_colors = len(indexcolors)
    colors = [indexcolors[int(i) % n_colors] for i in labels.squeeze()]
    ax.scatter(embedding[:, 0], embedding[:, 1], s=12, c=colors)

    legend_texts = [name for name, _ in sorted(info.items(), key=lambda kv: kv[1])]
    # Wrap the palette exactly like the point colours above, so legend colours
    # match and more classes than colours no longer raise IndexError.
    patches = [mpatches.Patch(color=indexcolors[i % n_colors], label=label)
               for i, label in enumerate(legend_texts)]
    plt.legend(handles=patches)

    for i, thumb in enumerate(paths):
        img = PILImage.open(thumb)
        # LANCZOS is the filter ANTIALIAS aliased; ANTIALIAS was removed in Pillow 10.
        img.thumbnail((24, 18), PILImage.LANCZOS)
        ab = AnnotationBbox(OffsetImage(img, zoom=1),
                            (embedding[i, 0] + 0.3, embedding[i, 1] + 0.3),
                            xycoords='data', frameon=False)
        ax.add_artist(ab)

    plt.savefig(savefile)
def __fill_retracement_spikes_text_annotations__(self, ret: str, position_in_wave: int):
    """Create a hidden text annotation for one wave component's spike.

    Returns without doing anything when ``position_in_wave`` is not a row of
    ``self.xy`` (the component does not exist yet for an unfinished wave).
    Otherwise labels the point ``self.xy[position_in_wave]`` with its position
    name and value plus either the difference to the previous point
    (``position_in_wave == 1``) or the entry of
    ``self.fib_wave.comp_reg_ret_percent_list`` shown as a percentage. The box
    is stored, invisible, in ``self.__fib_retracement_spikes_text_dic[ret]``.
    """
    if position_in_wave >= self.xy.shape[0]:
        # we don't have this component for unfinished waves
        return
    index = self.xy[position_in_wave, 0]
    value = self.xy[position_in_wave, 1]
    position = self.fib_wave.comp_position_list[position_in_wave]
    # Even positions are labelled 'Retr.' and shifted by +1% of the value,
    # odd positions 'Reg.' and shifted by -1%.
    is_position_retracement = position_in_wave % 2 == 0
    prefix = 'Retr.' if is_position_retracement else 'Reg.'
    value_adjusted = value + (value * 0.01 if is_position_retracement else value * -0.01)
    if position_in_wave == 1:
        # First leg: show the absolute difference to the starting point.
        reg_ret_value = round(
            self.xy[position_in_wave, 1] - self.xy[position_in_wave - 1, 1], 2)
        text = 'P_{}={:.2f}\n{}: {:.2f}'.format(position, value, prefix, reg_ret_value)
    else:
        # NOTE(review): for position_in_wave == 0 this reads index -1, i.e. the
        # last entry of the list - confirm that is intended.
        reg_ret_value = self.fib_wave.comp_reg_ret_percent_list[position_in_wave - 1]
        text = 'P_{}={:.2f}\n{}: {:=3.1f}%'.format(position, value, prefix, reg_ret_value)
    text_props = dict(color='crimson', backgroundcolor=self.color_bg, size=7)
    # ``minimumdescent`` was removed from TextArea in Matplotlib 3.5; True was
    # its default, so leaving it out preserves the layout on older versions.
    text_area = TextArea(text, textprops=text_props)
    annotation_box = AnnotationBbox(text_area, (index, value_adjusted))
    annotation_box.set_visible(False)
    self.__fib_retracement_spikes_text_dic[ret] = annotation_box
def annot_and_hover(x_pos, y_pos, fig, ax, bars, slider_val, df):
    '''Attach an image popup to the bars of a bar chart.

    Clicking on a bar (the handler is bound to ``button_press_event``, not to
    mouse motion) shows ``df['image'][int(slider_val.val + x)]`` above that
    bar, where ``x`` is the bar's centre; clicking anywhere else hides it.

    x_pos : x offset, in points, of the popup from the top of the bar
    y_pos : y offset, in points, of the popup from the top of the bar
    fig : figure of the plot
    ax : axis of the plot
    bars : bars of the plot
    slider_val : slider whose ``val`` is added to the bar position to pick the image
    df : table with an ``'image'`` column; ``df['image'][0]`` seeds the popup
    '''
    img_show = OffsetImage(df['image'][0], zoom=0.05)
    annot = AnnotationBbox(img_show, xy=(0, 0), xybox=(x_pos, y_pos),
                           xycoords='data', boxcoords="offset points",
                           pad=0.3, arrowprops=dict(arrowstyle="->"))
    ax.add_artist(annot)
    annot.set_visible(False)

    def update_annot(bar):
        """Point the popup at the top centre of ``bar`` and load its image."""
        x = bar.get_x() + bar.get_width() / 2  # centre of the bar
        y = bar.get_y() + bar.get_height()
        annot.xy = (x, y)
        num = int(slider_val.val + x)
        img_show.set_data(df['image'][num])

    def hover(event):
        """Show the popup on the clicked bar, hide it otherwise."""
        vis = annot.get_visible()
        if event.inaxes == ax:
            for bar in bars:
                cont, _ = bar.contains(event)
                if cont:
                    update_annot(bar)
                    annot.set_visible(True)
                    fig.canvas.draw_idle()
                    return  # stop at the first bar under the pointer
        if vis:
            annot.set_visible(False)
            fig.canvas.draw_idle()

    fig.canvas.mpl_connect("button_press_event", hover)
def generate_figures(arr, x, y):
    """Show samples (x, y) over the NN_SIZE x NN_SIZE neuron grid, with image hover.

    Each grid node is drawn as a red diamond, and a running cell number is
    written at (col, NN_SIZE - row - 1). Hovering a sample point pops up its
    image ``arr[k]``. Uses figure number 2 and blocks in ``plt.show()``.
    """
    fig = plt.figure(2)
    fig.suptitle("Samples association for each neuron")
    ax = fig.add_subplot(111)
    samples, = ax.plot(x, y, ls="", marker="o")

    # Row-major walk over the grid: cell == row * NN_SIZE + col.
    for cell in range(NN_SIZE * NN_SIZE):
        row, col = divmod(cell, NN_SIZE)
        ax.plot(row, col, marker="D", ls="", color="red")
        plt.text(col, NN_SIZE - row - 1, cell)

    # Image popup, hidden until a sample is hovered.
    offset = (50., 50.)
    thumb = OffsetImage(arr[0, :, :], zoom=5)
    popup = AnnotationBbox(thumb, (0, 0), xybox=offset, xycoords='data',
                           boxcoords="offset points", pad=0.3,
                           arrowprops=dict(arrowstyle="->"))
    ax.add_artist(popup)
    popup.set_visible(False)

    def hover(event):
        hit, details = samples.contains(event)
        if not hit:
            popup.set_visible(False)
        else:
            k = np.array(details["ind"][0])
            width, height = fig.get_size_inches() * fig.dpi
            # Flip the popup towards the centre in the right/top half.
            sx = -1 if event.x > width / 2. else 1
            sy = -1 if event.y > height / 2. else 1
            popup.xybox = (offset[0] * sx, offset[1] * sy)
            popup.set_visible(True)
            popup.xy = (x[k], y[k])
            thumb.set_data(arr[k, :, :])
        fig.canvas.draw_idle()

    fig.canvas.mpl_connect('motion_notify_event', hover)
    plt.grid(True)
    plt.show()
def createLabelSquare(self):
    """Build a hidden, framed "Test 1" label box.

    The box is positioned at x=1.02 in data coordinates and y=1 in axes
    fraction, aligned by its lower-left corner. Returns
    ``(text_area, annotation_box)``; the box is not added to any axes here.
    """
    # NOTE(review): TextArea's ``minimumdescent`` parameter was removed in
    # Matplotlib 3.5; this call needs adapting on newer versions.
    label_text = TextArea("Test 1", minimumdescent=False)
    box = AnnotationBbox(
        label_text,
        xy=(0, 0),
        xybox=(1.02, 1),
        xycoords=("data", "axes fraction"),
        boxcoords=("data", "axes fraction"),
        box_alignment=(0, 0),
        frameon=True,
    )
    box.set_visible(False)
    return label_text, box
def onpick3(event):
    """Mouse-click handler for the embedding scatter ``sc``.

    Left click (button 1): for every point under the cursor, print its image
    path and draw a thumbnail (max 128x96) of the image next to it. Any other
    button: remove every thumbnail added so far. Relies on the outer-scope
    names ``sc``, ``paths``, ``embedding``, ``ax`` and ``t_list``.
    """
    if event.button == 1:
        _, ind = sc.contains(event)
        for thumb in ind['ind']:
            print(paths[thumb])
            img = PILImage.open(paths[thumb])
            # LANCZOS is the filter ANTIALIAS aliased; ANTIALIAS was removed in Pillow 10.
            img.thumbnail((128, 96), PILImage.LANCZOS)
            ab = AnnotationBbox(OffsetImage(img, zoom=1),
                                (embedding[thumb, 0] + 0.2, embedding[thumb, 1] + 0.2),
                                xycoords='data', frameon=False)
            ax.add_artist(ab)
            t_list.append(ab)
    else:
        # Remove (rather than merely hide) the thumbnails and forget them, so
        # repeated clicks do not pile up invisible artists on the axes.
        for annot in t_list:
            annot.remove()
        t_list.clear()
    event.canvas.draw()
def add_pie(treeplot, node, values, colors=None, size=16, norm=True,
            xoff=0, yoff=0, halign=0.5, valign=0.5,
            xycoords='data', boxcoords=('offset points'), vis=True):
    """
    Draw a pie chart at a node of ``treeplot``

    Args:
        node (Node): A single Node object or node label
        values (list): A list of floats.
        colors (list): A list of strings to pull colors from. Optional; when
          omitted, one colour per value is drawn from ``_tango``.
        size (float): Diameter of the pie chart
        norm (bool): Whether or not to normalize the values so they add up
          to 360 degrees. When False, each value is multiplied by 360, i.e.
          treated as a fraction of the full circle.
        xoff, yoff (float): X and Y offset of the box, in ``boxcoords``.
          Optional, defaults to 0
        halign, valign (float): Horizontal and vertical alignment within box.
          Optional, defaults to 0.5
        xycoords: coordinate system of the node position. Defaults to 'data'.
        boxcoords: coordinate system of (xoff, yoff). Defaults to
          'offset points'.
        vis (bool): initial visibility of the pie.

    Returns:
        AnnotationBbox: the box holding the pie, already added to ``treeplot``.
    """
    x, y = xy(treeplot, node)
    da = DrawingArea(size, size)
    r = size * 0.5
    center = (r, r)
    x0 = 0
    S = 360.0
    if norm:
        total = sum(values)
        # A zero total would raise ZeroDivisionError; zero-valued slices are
        # skipped below anyway, so keep the default scale instead.
        if total:
            S = 360.0 / total
    if not colors:
        c = _tango
        colors = [next(c) for _ in values]
    for i, v in enumerate(values):
        theta = v * S
        if v:
            da.add_artist(Wedge(center, r, x0, x0 + theta, fc=colors[i], ec='none'))
        x0 += theta
    box = AnnotationBbox(da, (x, y), pad=0, frameon=False,
                         xybox=(xoff, yoff),
                         xycoords=xycoords,
                         box_alignment=(halign, valign),
                         boxcoords=boxcoords)
    treeplot.add_artist(box)
    box.set_visible(vis)
    treeplot.figure.canvas.draw_idle()
    return box
def plot_chemical_space(self, embedding, legend_elements, colors=None):
    """Scatter a 2D embedding of ``self._pdbs`` and show each entry's PNG on hover.

    Args:
        embedding: array whose first two columns are plotted (labelled
            "PCA 1" / "PCA 2"), one row per entry of ``self._pdbs``.
        legend_elements: legend handles, or None for no legend.
        colors: optional indices into ``sns.color_palette()``, one per point;
            every point gets the value 0 when omitted.

    The ``png`` files in ``self._im_path`` are sorted by file name and matched
    to points by position. Hovering a point sets the axes title to its name
    and shows its image. Blocks in ``plt.show()``.
    """
    if colors is None:
        colors = [0] * len(self._pdbs)
    else:
        palette = sns.color_palette()
        colors = [palette[x] for x in colors]
    names = self._pdbs
    fig, ax = plt.subplots()
    images_path = self._im_path
    image_path = sorted(name for name in os.listdir(images_path) if name.endswith("png"))

    # Text annotation that is kept in sync but never shown.
    annot1 = ax.annotate("", xy=(5, 5), xytext=(1000, 1000),
                         xycoords="figure pixels",
                         bbox=dict(boxstyle="round", fc="w"),
                         arrowprops=dict(arrowstyle="->"))
    im = OffsetImage(plt.imread(os.path.join(images_path, image_path[0])), zoom=0.6)
    annot = AnnotationBbox(im, xy=(0, 0), xybox=(1.2, 0.5), boxcoords="offset points")
    ax.add_artist(annot)
    annot.set_visible(False)
    annot1.set_visible(False)

    sc = plt.scatter(embedding[:, 0], embedding[:, 1], c=colors, s=200,
                     alpha=1, edgecolors="k")
    if legend_elements is not None:
        ax.legend(handles=legend_elements)
    ax.set_xlabel("PCA 1", fontsize=26, labelpad=25)
    ax.set_ylabel("PCA 2", fontsize=26, labelpad=30)

    def update_annot(ind):
        first = int(ind["ind"][0])
        pos = sc.get_offsets()[ind["ind"][0]]
        annot.xy = pos
        annot1.xy = pos
        annot1.set_text(names[first])
        ax.set_title(names[first])
        im.set_data(plt.imread(os.path.join(images_path, image_path[first]), format="png"))

    def hover(event):
        was_visible = annot.get_visible()
        if event.inaxes != ax:
            return
        hit, ind = sc.contains(event)
        if hit:
            update_annot(ind)
            annot.set_visible(True)
            annot1.set_visible(False)
            fig.canvas.draw_idle()
        elif was_visible:
            annot.set_visible(False)
            annot1.set_visible(False)
            fig.canvas.draw_idle()

    fig.canvas.mpl_connect("motion_notify_event", hover)
    plt.show()
def main():
    """Cluster image zones with KMeans on their feature values and save the model.

    Command-line arguments:
        --data      required ';'-separated data file (format: see loop below)
        --clusters  number of KMeans clusters (default 2)
        --output    base name of the saved ``.joblib`` model file (required)

    The fitted model is dumped with ``joblib`` to
    ``<model_output_folder>/<output>.joblib``. A PCA scatter with an image
    hover popup is also built, but ``plt.show()`` is commented out, so that
    figure is never displayed.
    """
    parser = argparse.ArgumentParser(
        description=
        "Classify zones using complexity criteria (unsupervised learning)")
    parser.add_argument('--data', type=str, help='required data file', required=True)
    parser.add_argument('--clusters', type=int, help='number of expected clusters', default=2)
    # NOTE(review): the help text says "folder", but the value is used as the
    # model file's base name inside model_output_folder.
    parser.add_argument('--output', type=str, help='output folder name', required=True)
    args = parser.parse_args()
    p_data = args.data
    p_clusters = args.clusters
    p_output = args.output
    x_values = []
    images_path = []
    images = []
    zones = []
    scenes = []
    with open(p_data, 'r') as f:
        for line in f.readlines():
            # Row layout: scene;image_path;zone_index;feature_1;...;feature_k;
            # The element after the final ';' (presumably just the newline) is dropped.
            data = line.split(';')
            del data[-1]
            scene = data[0]
            if scene not in scenes:
                scenes.append(scene)
            images_path.append(data[1])
            zones.append(int(data[2]))
            # Cut the image into blocks (divide_in_blocks with (200, 200)) and keep
            # the block for this zone as a numpy array, used by the hover popup.
            img_arr = segmentation.divide_in_blocks(Image.open(data[1]), (200, 200))[int(data[2])]
            images.append(np.array(img_arr))
            x = []
            for v in data[3:]:
                x.append(float(v))
            x_values.append(x)
    print(scenes)
    # plt.show()
    # TODO : save kmean model
    kmeans = KMeans(init='k-means++', n_clusters=p_clusters, n_init=10)
    labels = kmeans.fit(x_values).labels_
    # Report the number of samples per cluster.
    unique, counts = np.unique(labels, return_counts=True)
    print(dict(zip(unique, counts)))
    # 2D projection used only for the scatter plot below.
    pca = PCA(n_components=2)
    x_data = pca.fit_transform(x_values)
    # NOTE(review): nothing here is actually global; the nested callbacks
    # below close over fig, ax, sc and annot.
    fig, ax = plt.subplots()
    fig.set_figheight(20)
    fig.set_figwidth(40)
    ax.tick_params(axis='both', which='major', labelsize=20)
    sc = plt.scatter(x_data[:, 0], x_data[:, 1], c=labels, linewidths=10)
    # annot = ax.annotate("", xy=(0,0), xytext=(20,20), textcoords="offset points",
    #                     bbox=dict(boxstyle="round", fc="w"),
    #                     arrowprops=dict(arrowstyle="->"))
    imagebox = OffsetImage(images[0], zoom=1.2)
    imagebox.image.axes = ax
    annot = AnnotationBbox(imagebox, xy=(0, 0), xybox=(-150., 150.), xycoords='data', boxcoords="offset points", pad=0.8, arrowprops=dict(arrowstyle="->"))
    annot.set_visible(False)
    ax.add_artist(annot)

    def update_annot(ind):
        # Swap in a new image box for the first hovered point (equivalent to
        # images[ind["ind"][0]]) and move the annotation onto that point.
        imagebox = OffsetImage([images[n] for n in ind["ind"]][0], zoom=1.2)
        imagebox.image.axes = ax
        pos = sc.get_offsets()[ind["ind"][0]]
        setattr(annot, 'offsetbox', imagebox)
        annot.xy = pos
        # text = "{}, {}".format(" ".join(list(map(str,ind["ind"]))),
        #                        " ".join([images_path[n] for n in ind["ind"]]))
        # annot.text(text)
        # #annot.get_bbox_patch().set_facecolor(cmap(norm(c[ind["ind"][0]])))
        # annot.get_bbox_patch().set_alpha(0.4)

    def hover(event):
        # Show the popup while the mouse is over a point; hide it (redrawing
        # only if it was visible) when the mouse is in the axes but off the points.
        vis = annot.get_visible()
        if event.inaxes == ax:
            cont, ind = sc.contains(event)
            if cont:
                update_annot(ind)
                annot.set_visible(True)
                fig.canvas.draw_idle()
            else:
                if vis:
                    annot.set_visible(False)
                    fig.canvas.draw_idle()

    fig.canvas.mpl_connect("motion_notify_event", hover)
    #plt.show()
    # model_output_folder is not defined in this function; presumably a
    # module-level setting - confirm.
    if not os.path.exists(model_output_folder):
        os.makedirs(model_output_folder)
    model_path = os.path.join(model_output_folder, p_output + '.joblib')
    # Note: the message is printed before the dump actually happens.
    print('Model saved into {0}'.format(model_path))
    joblib.dump(kmeans, model_path)
class ZoomPlot():
    """Interactive Basemap of points that re-renders on zoom and shows a picture on click.

    ``pnts`` must provide the keys ``'days'``, ``'lons'``, ``'lats'``,
    ``'dirs'`` (image file path per point) and ``'coms'``. Any y-limit change
    (zoom/pan) redraws the map for the new bounds, choosing the coastline
    resolution from the zoom level; clicking a point shows its image.
    """

    def __init__(self, pnts):
        self.fig = plt.figure(figsize=(15, 9))
        self.ax = self.fig.add_subplot(111)
        self.days = pnts['days']
        self.lons = pnts['lons']
        self.lats = pnts['lats']
        self.dirs = pnts['dirs']
        self.coms = pnts['coms']
        # [lat_min, lat_max, lon_min, lon_max]; bnds_strt is the "home" view.
        self.bnds = self.bnds_strt = [-58, 80, -180, 180]
        self.resolution = 'c'
        # add callback for mouse clicks
        self.fig.canvas.mpl_connect('button_press_event', self.onclick)
        self.plot_map()

    @staticmethod
    def _remove_all(artists):
        """Remove every artist in ``artists`` from its axes.

        ``Axes.lines``/``patches``/``collections``/``texts`` are read-only
        ArtistLists since Matplotlib 3.5 (``.clear()`` is gone in 3.7);
        removing artists one by one works on all versions.
        """
        for artist in list(artists):
            artist.remove()

    def plot_map(self):
        """(Re)draw the map for ``self.bnds`` plus the points, then re-arm the zoom hook."""
        self.map = Basemap(projection='merc', llcrnrlat=self.bnds[0], urcrnrlat=self.bnds[1],
                           llcrnrlon=self.bnds[2], urcrnrlon=self.bnds[3],
                           resolution=self.resolution)
        self.map.drawcoastlines()
        self.map.drawmapboundary(fill_color='cornflowerblue')
        self.map.fillcontinents(color='lightgreen', lake_color='aqua')
        self.map.drawcountries()
        self.map.drawstates()
        self.plot_points()
        self.fig.canvas.draw()
        # Connected only after drawing; onzoom disconnects it before redrawing,
        # presumably so the limit changes made while drawing don't re-enter it.
        self.zoomcall = self.ax.callbacks.connect('ylim_changed', self.onzoom)

    def onzoom(self, axes):
        """Redraw the map for the new view limits ('ylim_changed' callback)."""
        self._remove_all(self.ax.patches)
        self._remove_all(self.ax.collections)
        self.ax.callbacks.disconnect(self.zoomcall)
        x1, y1 = self.map(self.ax.get_xlim()[0], self.ax.get_ylim()[0], inverse=True)
        x2, y2 = self.map(self.ax.get_xlim()[1], self.ax.get_ylim()[1], inverse=True)
        self.bnds = [y1, y2, x1, x2]
        # reset zoom to home (workaround for unidentified error when you press the home button)
        if any(a / b > 1 for a, b in zip(self.bnds, self.bnds_strt)):
            self.bnds = self.bnds_strt  # reset map boundaries
        self._remove_all(self.ax.lines)  # reset points
        # Drop the old picture popup; plot_points() creates a fresh one.
        self.ab.remove()
        # change map resolution based on zoom level
        zoom_set = max(abs(self.bnds[0] - self.bnds[1]), abs(self.bnds[2] - self.bnds[3]))
        if 3 <= zoom_set < 30:
            self.resolution = 'l'  # low resolution
        elif zoom_set < 3:
            self.resolution = 'i'  # intermediate resolution
        else:
            self.resolution = 'c'  # coarse resolution
        self.plot_map()

    def plot_points(self):
        """Plot the points in map coordinates and create the hidden picture popup."""
        self.x, self.y = self.map(self.lons, self.lats)
        self.line, = self.map.plot(self.x, self.y, color='darkmagenta', linestyle='none',
                                   marker='o', markeredgecolor='gold')
        # create the annotations box
        # placeholder picture, just to set up the popup; replaced on click
        self.pic = mpimg.imread('pics\\profpic.png')
        self.im = OffsetImage(self.pic)
        self.xybox = (50., 50.)
        self.ab = AnnotationBbox(self.im, (0, 0), xybox=self.xybox, xycoords='data',
                                 boxcoords="offset points", pad=0.3,
                                 arrowprops=dict(arrowstyle="->"))
        # add it to the axes and make it invisible
        self.ax.add_artist(self.ab)
        self.ab.set_visible(False)

    def onclick(self, event):
        """Show the clicked point's picture, or hide it when clicking elsewhere."""
        hit, info = self.line.contains(event)
        if hit:
            # find out the index within the array from the event
            try:
                ind, = info["ind"]
            except ValueError:
                # Several points under the cursor: ask the user to zoom in.
                self.ax.text(0.5, 0.5, 'Please zoom in!', fontsize=24, ha='center',
                             va='center', transform=self.ax.transAxes, weight='bold')
                self.fig.canvas.draw()
                time.sleep(0.7)
                self._remove_all(self.ax.texts)
            else:
                # get the figure size
                w, h = self.fig.get_size_inches() * self.fig.dpi
                ws = (event.x > w / 2.) * -1 + (event.x <= w / 2.)
                hs = (event.y > h / 2.) * -1 + (event.y <= h / 2.)
                # if event occurs in the top or right quadrant of the figure,
                # change the annotation box position relative to mouse.
                self.ab.xybox = (self.xybox[0] * ws, self.xybox[1] * hs)
                self.ab.set_visible(True)
                # place it at the position of the clicked point
                self.ab.xy = (self.x[ind], self.y[ind])
                # set the image corresponding to that point (read once)
                pic_path = self.dirs[ind]
                pic = mpimg.imread(pic_path)
                self.im.set_data(pic)
                # change zoom of the image to deal with different file sizes
                picsize = max(pic.shape)
                self.im.set_zoom(0.4 * (1000 / picsize))  # optimum: zoom = 0.4, picsize = 1000
        else:
            # if you didn't click on a data point
            self.ab.set_visible(False)
        self.fig.canvas.draw_idle()
def plot_interpolation(sess, ae, data, first_image_id, second_image_id, file_list, graph_output_name=""):
    """Plot latent codes of ``data`` plus a linear interpolation between two of them.

    ``ae`` maps ``'x'`` (input), ``'y'`` (reconstruction) and ``'z'`` (latent
    code) to tensors evaluated with ``sess.run``. The codes of
    ``first_image_id`` and ``second_image_id`` are linearly interpolated in
    100 steps and decoded. The decoded images then go through 10 rounds of
    decode -> per-image min-max normalisation -> re-encode. The first two
    latent dimensions of the original codes, the interpolated codes and the
    re-encoded codes are scattered with three colour values; a 1D latent
    space is duplicated so it is drawn on the diagonal.

    Assumes the decoder outputs 64x64 images (see ``img_size``). If
    ``graph_output_name`` is empty the plot is shown with a hover popup of each
    point's image; otherwise it is saved there. ``file_list`` is unused.
    """
    img_size = 64
    num_examples = 100
    n_iters = 10

    def as_columns(codes):
        """Squeeze ``codes`` and make sure the result is 2D (n_points, z_size)."""
        codes = np.squeeze(codes)
        if codes.ndim == 1:
            codes = np.expand_dims(codes, axis=1)
        return codes

    y, z = sess.run([ae['y'], ae['z']], feed_dict={ae['x']: np.asarray(data)})
    z = as_columns(z)
    y = np.squeeze(y)
    z_size = z.shape[1]

    def decode(codes):
        """Run the decoder on (n, z_size) latent codes."""
        return sess.run(ae['y'], feed_dict={
            ae['z']: np.reshape(codes, (codes.shape[0], 1, 1, z_size))
        })

    # Linear interpolation between the two chosen latent codes.
    z_first = z[first_image_id, :]
    z_second = z[second_image_id, :]
    z_list = z_first + np.expand_dims(np.linspace(0, 1, num_examples), axis=1) * (z_second - z_first)
    x_list_first = decode(z_list)

    z_list_iter = z_list
    x_list = x_list_first
    for _ in range(n_iters):
        decoded = decode(z_list_iter)
        # Flatten each image and min-max normalise it to [0, 1], row by row.
        x_list = np.reshape(decoded, (decoded.shape[0], img_size * img_size)).astype(np.float64)
        x_list = x_list - x_list.min(axis=1, keepdims=True)
        x_list = x_list / x_list.max(axis=1, keepdims=True)
        # Squeeze first, then restore the column axis: a 1D latent space would
        # otherwise collapse to shape (n,) and break the [:, 0:z_size] indexing.
        z_list_iter = as_columns(sess.run(ae['z'], feed_dict={ae['x']: np.asarray(x_list)}))

    # plot result
    fig = plt.figure()
    ax = fig.add_subplot(111)
    # colour values: original codes, interpolated codes, re-encoded codes
    first_colour = 1
    second_colour = 5
    third_colour = 10
    colour_list = np.vstack((first_colour * np.ones((z.shape[0], 1)),
                             second_colour * np.ones((z_list.shape[0], 1)),
                             third_colour * np.ones((z_list_iter.shape[0], 1))))
    plot_data = np.vstack((z[:, 0:z_size], z_list[:, 0:z_size], z_list_iter[:, 0:z_size]))
    img_data = np.vstack((y,
                          np.reshape(x_list_first, (x_list.shape[0], img_size, img_size)),
                          np.reshape(x_list, (x_list.shape[0], img_size, img_size))))
    if z_size == 1:
        # 1D latent space: duplicate the column so points can be drawn as x == y.
        plot_data = np.hstack((plot_data, plot_data))
    line = ax.scatter(plot_data[:, 0], plot_data[:, 1], s=10, c=colour_list)

    if graph_output_name == "":
        # create the annotations box
        im = OffsetImage(img_data[0, :, :], zoom=5)
        xybox = (50., 50.)
        ab = AnnotationBbox(im, (0, 0), xybox=xybox, xycoords='data',
                            boxcoords="offset points", pad=0.3,
                            arrowprops=dict(arrowstyle="->"))
        # add it to the axes and make it invisible
        ax.add_artist(ab)
        ab.set_visible(False)

        def hover(event):
            hit, info = line.contains(event)
            if hit:
                ind = info["ind"]
                # get the figure size
                w, h = fig.get_size_inches() * fig.dpi
                ws = (event.x > w / 2.) * -1 + (event.x <= w / 2.)
                hs = (event.y > h / 2.) * -1 + (event.y <= h / 2.)
                # if event occurs in the top or right quadrant of the figure,
                # change the annotation box position relative to mouse.
                ab.xybox = (xybox[0] * ws, xybox[1] * hs)
                ab.set_visible(True)
                # place it at the position of the hovered scatter point
                ab.xy = (plot_data[ind[0], 0], plot_data[ind[0], 1])
                # set the image corresponding to that point
                im.set_data(img_data[ind[0], :, :])
            else:
                # if the mouse is not over a scatter point
                ab.set_visible(False)
            fig.canvas.draw_idle()

        # add callback for mouse moves
        fig.canvas.mpl_connect('motion_notify_event', hover)
        plt.show()
    else:
        plt.savefig(graph_output_name)
def DrawtSNE(self):
    """Render the 2D and 3D t-SNE scatter plots into the Qt layouts, with image hover.

    Reads its inputs from module globals (``groups`` / ``groups3D`` iterated as
    (label, group) pairs with ``.x``/``.y``(/``.z``) columns, ``labels_dict``,
    ``labels``, ``test_data``, ``df``, ``Y_3D``) and assigns the global
    ``test_annotation``. Each new canvas replaces the single widget already
    in ``self.tSNE_Layout`` / ``self.tSNE3D_Layout``, if there is one. Finally
    stops ``self.movie`` and sets ``self.label_7`` to 'Finished'.
    """
    global imagesavepath, currentdate, labels, cavans, groups, labels_dict, groups3D, test_data, test_annotation
    from matplotlib.offsetbox import OffsetImage, AnnotationBbox
    global df, df_3D, Y, Y_3D
    print("DrawtSNE Start ")
    # Images shown in the hover popups; the reshape fixes them to 190x190.
    test_annotation = test_data.reshape(-1, 190, 190)
    # 2D tSNE
    plt.cla()
    fig, ax = plt.subplots()
    ax.margins(0.05)  # Optional, just adds 5% padding to the autoscaling
    # One Line2D per label group, kept (wrapped in a list) for hit-testing in hover().
    points_with_annotation = []
    for label, group in groups:
        name = labels_dict[label]
        point, = ax.plot(group.x, group.y, marker='o', linestyle='', ms=5, label=name, alpha=0.5)
        points_with_annotation.append([point])
    plt.title('t-SNE Scattering Plot')
    ax.legend()
    cavans2D = FigureCanvas(fig)
    #Annotation
    # create the annotations box
    im = OffsetImage(test_annotation[0, :, :], zoom=0.25, cmap='gray')
    xybox = (10., 10.)
    ab = AnnotationBbox(im, (10, 10), xybox=xybox, xycoords='data', boxcoords="offset points", pad=0.3, arrowprops=dict(arrowstyle="->"))
    # add it to the axes and make it invisible
    ax.add_artist(ab)
    ab.set_visible(False)
    # NOTE(review): assumes every group holds exactly len(test_data)/len(labels)
    # points and that groups are ordered like test_data - confirm with the caller.
    tsneprelabel = int(len(test_data) / len(labels))

    def hover(event):
        # Show the image of the hovered point (from any group); hide the popup
        # when no group is under the cursor.
        global df, test_annotation
        i = 0
        ispointed = np.zeros((len(groups), ), dtype=bool)
        for point in points_with_annotation:
            if point[0].contains(event)[0]:
                ispointed[i] = True
                cont, ind = point[0].contains(event)
                # Map the in-group index to a global row index (see NOTE above).
                image_index = ind["ind"][0] + i * tsneprelabel
                # get the figure size
                w, h = fig.get_size_inches() * fig.dpi
                ws = (event.x > w / 2.) * -1 + (event.x <= w / 2.)
                hs = (event.y > h / 2.) * -1 + (event.y <= h / 2.)
                # if event occurs in the top or right quadrant of the figure,
                # change the annotation box position relative to mouse.
                ab.xybox = (xybox[0] * ws, xybox[1] * hs)
                # place it at the position of the hovered scatter point
                # (the repeated global statement and `df = df` are no-ops)
                global df, test_annotation
                df = df
                ab.xy = (df['x'][image_index], df['y'][image_index])
                # set the image corresponding to that point
                im.set_data(test_annotation[image_index, :, :])
                ab.set_visible(True)
            else:
                ispointed[i] = False
            i = i + 1
        # Final visibility: shown if any group was hit.
        ab.set_visible(max(ispointed))
        fig.canvas.draw_idle()

    cid = fig.canvas.mpl_connect('motion_notify_event', hover)
    # Replace a previously added canvas, if any.
    rows = int(self.tSNE_Layout.count())
    if rows == 1:
        myWidget = self.tSNE_Layout.itemAt(0).widget()
        myWidget.deleteLater()
    self.tSNE_Layout.addWidget(cavans2D)
    print("tSNE 2D Finished")
    # 3D tSNE
    fig_3D = plt.figure()
    cavans3D = FigureCanvas(fig_3D)
    # NOTE(review): since Matplotlib 3.4 Axes3D(fig) no longer adds itself to the
    # figure automatically; fig_3D.add_subplot(projection='3d') may be needed.
    ax_3D = Axes3D(fig_3D)
    ax_3D.margins(0.05)  # Optional, just adds 5% padding to the autoscaling
    test_annotation = test_data.reshape(-1, 190, 190)
    for label, group in groups3D:
        name = labels_dict[label]
        ax_3D.scatter(group.x, group.y, group.z, marker='o', label=name, alpha=0.8)
    ax_3D.legend()
    # Hide the 3D panes and axes so only the points remain.
    ax_3D.patch.set_visible(False)
    ax_3D.set_axis_off()
    ax_3D._axis3don = False
    from matplotlib.offsetbox import OffsetImage, AnnotationBbox
    im_3D = OffsetImage(test_annotation[0, :, :], zoom=0.25, cmap='gray')
    xybox = (10., 10.)
    ab_3D = AnnotationBbox(im_3D, (10, 10), xybox=xybox, xycoords='data', boxcoords="offset points", pad=0.3, arrowprops=dict(arrowstyle="->"))
    # add it to the axes and make it invisible
    ax_3D.add_artist(ab_3D)
    ab_3D.set_visible(False)

    def onMouseMotion(event):
        # Find the row of Y_3D whose projection is closest to the mouse in
        # screen pixels and show its image there (O(n) Python loop per motion event).
        # NOTE(review): assumes row i of Y_3D corresponds to test_annotation[i].
        global Y_3D
        distances = []
        for i in range(Y_3D.shape[0]):
            x2, y2, _ = proj3d.proj_transform(Y_3D[i, 0], Y_3D[i, 1], Y_3D[i, 2], ax_3D.get_proj())
            x3, y3 = ax_3D.transData.transform((x2, y2))
            distance = np.sqrt((x3 - event.x)**2 + (y3 - event.y)**2)
            distances.append(distance)
        closestIndex = np.argmin(distances)
        print(closestIndex)
        x2, y2, _ = proj3d.proj_transform(Y_3D[closestIndex, 0], Y_3D[closestIndex, 1], Y_3D[closestIndex, 2], ax_3D.get_proj())
        ab_3D.xy = (x2, y2)
        im_3D.set_data(test_annotation[closestIndex, :, :])
        ab_3D.set_visible(True)
        fig_3D.canvas.draw_idle()

    cid3d = fig_3D.canvas.mpl_connect('motion_notify_event', onMouseMotion)  # on mouse motion
    # Replace a previously added 3D canvas, if any.
    rows = int(self.tSNE3D_Layout.count())
    if rows == 1:
        myWidget = self.tSNE3D_Layout.itemAt(0).widget()
        myWidget.deleteLater()
    self.tSNE3D_Layout.addWidget(cavans3D)
    # Stop the busy animation and report completion in the UI.
    self.movie.stop()
    self.movie.jumpToFrame(0)
    self.label_7.setText('Finished')
    self.label_7.setStyleSheet("color: rgb(70, 70, 70);\n"
                               "font: 75 14pt \"MS Shell Dlg 2\";")
def visualize2DData(X, values_to_show, fig, ax, image_path, centroids, colors_scatter, datasets):
    """Visualize data in 2d plot with popover next to mouse position.

    While the mouse is over a scatter point, the point closest to the cursor
    (measured in screen pixels) gets a yellow "Value <n>" label and a popover
    with its image. Both disappear when the mouse leaves the points.

    Args:
        X (np.array) - array of points, of shape (numPoints, 2)
        values_to_show - per-point values shown (formatted with %d) in the label
        fig - the figure to plot
        ax - the axis to draw on
        image_path - the images paths, of shape (N); image_path[i] belongs to X[i]
        centroids - the clusters' centroids, of shape (#C, 2)
        colors_scatter - the colors to be used in scatter plot, of shape (N)
        datasets - the datasets names, used as legend labels
    Returns:
        None
    """
    cmap = plt.cm.RdYlGn
    # Draw on the axes we were given rather than whatever plt considers current.
    line = ax.scatter(X[:, 0], X[:, 1], s=10, cmap=cmap, c=colors_scatter)
    ax.legend(handles=line.legend_elements()[0], labels=datasets)

    # create the annotations box (hidden until a point is hovered)
    im = OffsetImage(plt.imread(image_path[0]), zoom=0.1)
    xybox = (50., 50.)
    ab = AnnotationBbox(im, (0, 0), xybox=xybox, xycoords='data',
                        boxcoords="offset points", pad=0.3,
                        arrowprops=dict(arrowstyle="->"))
    ax.add_artist(ab)
    ab.set_visible(False)
    # TODO: add skeleton images (a second AnnotationBbox next to the first).

    ax.scatter(centroids[:, 0], centroids[:, 1], marker='x', s=169,
               linewidths=3, color='w', zorder=10)

    # Text label currently displayed and the index of the point it belongs to.
    state = {'label': None, 'index': None}

    def calcClosestDatapoint(X, event):
        """Return the index of the point of X closest to the mouse, in screen pixels."""
        # Compare in display space (like event.x / event.y), so that different
        # x / y data scales do not distort which point is "closest".
        pixels = ax.transData.transform(X[:, 0:2])
        return int(np.argmin(np.hypot(pixels[:, 0] - event.x, pixels[:, 1] - event.y)))

    def clearPopover():
        """Hide the image popover and remove the text label, if any."""
        if state['label'] is not None:
            state['label'].remove()
            state['label'] = None
        state['index'] = None
        ab.set_visible(False)

    def annotatePlot(X, index):
        """Show the value label and image popover for point ``index`` of X."""
        if index == state['index']:
            return  # already showing this point; avoid re-reading its image
        clearPopover()
        x2 = X[index, 0]
        y2 = X[index, 1]
        state['label'] = ax.annotate(
            "Value %d" % values_to_show[index],
            xy=(x2, y2), xytext=(-20, -40), textcoords='offset points',
            ha='right', va='bottom',
            bbox=dict(boxstyle='round,pad=0.5', fc='yellow', alpha=0.5),
            arrowprops=dict(arrowstyle='->', connectionstyle='arc3,rad=0'))
        state['index'] = index
        # Flip the popover towards the figure centre when the point lies in the
        # right/top half. X is in data coordinates, so convert to pixels before
        # comparing with the figure size in pixels.
        px, py = ax.transData.transform((x2, y2))
        w, h = fig.get_size_inches() * fig.dpi
        ws = -1 if px > w / 2. else 1
        hs = -1 if py > h / 2. else 1
        ab.xybox = (xybox[0] * ws, xybox[1] * hs)
        # place it at the position of the hovered point, with its image
        ab.xy = (x2, y2)
        im.set_data(plt.imread(image_path[index]))
        ab.set_visible(True)
        fig.canvas.draw_idle()

    def onMouseMotion(event):
        """Show the popover for the closest point while the mouse is over a point."""
        if line.contains(event)[0]:
            annotatePlot(X, calcClosestDatapoint(X, event))
        elif state['index'] is not None:
            # Mouse left the points: drop label and image and redraw once.
            clearPopover()
            fig.canvas.draw_idle()

    fig.canvas.mpl_connect('motion_notify_event', onMouseMotion)  # on mouse motion
def visualize3DData(X, fig, ax, image_path, C, colors_scatter, datasets):
    """Visualize data in a 3d scatter plot with an image popover next to the mouse position.

    On every mouse move the point of ``X`` whose 2d projection is closest to the
    mouse (in screen pixels) gets a text label and an image popover showing
    ``image_path[index]``.

    Args:
        X (np.array): array of points, of shape (numPoints, 3).
        fig (matplotlib.figure.Figure): figure owning ``ax``; its canvas receives
            the mouse-motion callback.
        ax: 3d axes (created with ``projection='3d'``) to draw into.
        image_path (sequence of str): one image file per row of ``X``.
        C (np.array): additional points drawn as large black stars, shape
            (numCenters, 3) (presumably cluster centres -- TODO confirm).
        colors_scatter: value passed as ``c=`` to the scatter of ``X``.
        datasets (sequence of str): legend labels for the scatter colours.

    Returns:
        None
    """
    scatter = ax.scatter(X[:, 0], X[:, 1], X[:, 2], depthshade=False, picker=True, c=colors_scatter)
    ax.scatter(C[:, 0], C[:, 1], C[:, 2], marker='*', c='#050505', s=1000)

    # create the annotations box
    image = plt.imread(image_path[0])
    im = OffsetImage(image, zoom=0.1)
    xybox = (50., 50.)
    ab = AnnotationBbox(im, (0, 0), xybox=xybox, xycoords='data', boxcoords="offset points",
                        pad=0.3, arrowprops=dict(arrowstyle="->"))
    # add it to the axes and make it invisible
    ax.add_artist(ab)
    ab.set_visible(False)

    # Index of the image currently held by ``im``; avoids re-reading the file on every mouse move.
    shown = {'index': 0}

    def projectToScreen(points):
        """Project 3d data points onto the axes.

        Args:
            points (np.array): array of shape (n, 3) in data coordinates.

        Returns:
            (x2, y2, sx, sy): arrays of length n -- the projected 2d data coords
            (usable with xycoords='data') and the display (pixel) coords.
        """
        x2, y2, _ = proj3d.proj_transform(points[:, 0], points[:, 1], points[:, 2], ax.get_proj())
        screen = ax.transData.transform(np.column_stack((x2, y2)))
        return x2, y2, screen[:, 0], screen[:, 1]

    def calcClosestDatapoint(X, event):
        """Return the index of the point in X whose projection is closest to the mouse (in pixels)."""
        _, _, sx, sy = projectToScreen(X[:, 0:3])
        return int(np.argmin(np.hypot(sx - event.x, sy - event.y)))

    def annotatePlot(X, index):
        """Show the text label and image popover for point ``index`` of ``X``."""
        # If we have previously displayed another label, remove it first
        if hasattr(annotatePlot, 'label'):
            annotatePlot.label.remove()

        x2, y2, sx, sy = projectToScreen(X[index:index + 1, 0:3])
        x2, y2, sx, sy = x2[0], y2[0], sx[0], sy[0]
        annotatePlot.label = ax.annotate(
            "Value %d" % index, xy=(x2, y2), xytext=(-20, 20), textcoords='offset points',
            ha='right', va='bottom', bbox=dict(boxstyle='round,pad=0.5', fc='yellow', alpha=0.5),
            arrowprops=dict(arrowstyle='->', connectionstyle='arc3,rad=0'))

        # Flip the popover towards the figure centre when the point lies in the
        # right/top half. Compare the point's *pixel* position with the figure's
        # size in pixels (comparing data coordinates with pixels was meaningless).
        w, h = fig.get_size_inches() * fig.dpi
        ws = -1 if sx > w / 2. else 1
        hs = -1 if sy > h / 2. else 1
        ab.xybox = (xybox[0] * ws, xybox[1] * hs)
        ab.set_visible(True)
        # place it at the (projected) position of the hovered point
        ab.xy = (x2, y2)
        if shown['index'] != index:
            im.set_data(plt.imread(image_path[index]))
            shown['index'] = index
        fig.canvas.draw_idle()

    def onMouseMotion(event):
        """Motion callback: annotate the data point closest to the mouse."""
        closestIndex = calcClosestDatapoint(X, event)
        annotatePlot(X, closestIndex)

    fig.canvas.mpl_connect('motion_notify_event', onMouseMotion)  # on mouse motion
    ax.legend(handles=scatter.legend_elements()[0], labels=datasets, loc=2)
class PolygonHover:
    """Label and highlight a polygon while the mouse hovers over one of its vertices.

    A vertex (x, y) counts as hovered when both |x - mouse_x| and
    |y - mouse_y| (data coordinates) are below ``precision``.
    """

    def __init__(self, parent, polygon_info, precision, figure, axis, selected_colour):
        """
        Args:
            parent: owning object; ``hover`` calls its ``reset_polygon_cols`` and
                ``show_selected_plots`` methods and reads ``selected_polygon``.
            polygon_info (list of dict): one dict per polygon; ``hover`` reads the
                keys 'co-ordinates' (iterable of (x, y)), 'lines', 'slice', 'id'
                and 'tag'.
            precision (float): hover tolerance in data units.
            figure: matplotlib figure that is redrawn after each hover event.
            axis: matplotlib axes holding the polygons and hover labels.
            selected_colour: colour used to re-highlight the parent's selected polygon.
        """
        self.parent = parent
        self.polygon_info = polygon_info
        self.precision = precision
        self.fig = figure
        self.ax = axis
        self.selected_polygon_colour = selected_colour

    def hover(self, event):
        """Mouse-motion callback: show a label for the hovered polygon vertex, if any.

        Technique adapted from ImportanceOfBeingErnest (2017) "Possible to make labels
        appear when hovering over a point in matplotlib?" [Online]. Available at:
        https://stackoverflow.com/questions/7908636/possible-to-make-labels-appear-when-hovering-over-a-point-in-matplotlib
        [Accessed 03 August 2020].
        """
        # This runs on every mouse move, so the label must be removed again as
        # soon as no vertex is hovered.
        hovered_over = False
        if event.xdata is not None:  # None if the cursor is not on the axes
            for polygon in self.polygon_info:
                for x, y in polygon['co-ordinates']:
                    if (np.abs(x - event.xdata) < self.precision
                            and np.abs(y - event.ydata) < self.precision):
                        hovered_over = True
                        self._show_label(event, polygon, x, y)
        if not hovered_over:
            self._reset_hover()
        self.fig.canvas.draw_idle()

    def _show_label(self, event, polygon, x, y):
        """Replace any existing label with one describing ``polygon`` at vertex (x, y)."""
        self._clear_labels()
        text = "Slice: {} \nId: {} \nTag: {} \nx: {} \ny: {}".format(
            polygon['slice'], polygon['id'], polygon['tag'],
            str(x)[:6], str(y)[:6])  # truncated, as coordinates can be quite long
        # NOTE(review): `minimumdescent` is deprecated in newer matplotlib releases;
        # verify against the matplotlib version this project pins.
        self.offsetbox = TextArea(text, minimumdescent=False)
        self.ab = AnnotationBbox(self.offsetbox, (0, 0), xybox=(50., 50.), xycoords='data',
                                 boxcoords="offset points", pad=0.5)
        self.ax.add_artist(self.ab)

        # highlight the polygon's lines that are drawn on this axes
        for line in polygon['lines']:
            if line in self.ax.lines:
                line.set_color("blue")

        # if the event occurs in the top or right half of the figure,
        # place the annotation box on the other side of the mouse.
        w, h = self.fig.get_size_inches() * self.fig.dpi
        ws = -1 if event.x > w / 2. else 1
        hs = -1 if event.y > h / 2. else 1
        self.ab.xybox = (50 * ws, 50 * hs)
        self.ab.set_visible(True)
        # place it at the position of the hovered vertex
        self.ab.xy = (x, y)

    def _clear_labels(self):
        """Remove every artist (i.e. hover labels) from the axes."""
        # Artist.remove() properly detaches each artist and works on all
        # matplotlib versions, unlike mutating ``ax.artists`` in place.
        for artist in list(self.ax.artists):
            artist.remove()

    def _reset_hover(self):
        """Remove labels and restore polygon colours, keeping the selected polygon highlighted."""
        self._clear_labels()
        self.parent.reset_polygon_cols(False)
        selected = self.parent.selected_polygon
        if selected is not None:
            self.parent.show_selected_plots(selected['scatter_points'], self.ax.collections,
                                            self.selected_polygon_colour)
            self.parent.show_selected_plots(selected['lines'], self.ax.lines,
                                            self.selected_polygon_colour)

    def change_selected_colour(self, new_col):
        """Set the colour used for the selected polygon (e.g. after a settings change)."""
        self.selected_polygon_colour = new_col

    def update_precision_value(self, new_precision_val):
        """Set the hover tolerance (e.g. after a settings change)."""
        self.precision = new_precision_val

    def update_polygon_info(self, new_polygon_info):
        """Replace the list of polygons checked on hover."""
        self.polygon_info = new_polygon_info
class Discovery(object):
    """Interactive exploration of segment embeddings.

    Shows a 2d t-SNE scatter of all segments. Hovering shows a thumbnail of the
    segment. Clicking opens the source image, a detailed four-panel view, or a
    nearest neighbor window. Keyboard shortcuts trigger clustering and
    recoloring (see ``key_press``).
    """

    def __init__(self, embeddings_file='embeddings.p', distance_metric='euclid', method='TSNE',
                 embedding_size=2, overwrite_embeddings=False, n_jobs=10, dpi=300,
                 main_plot_args=None, tsne_args=None, save_dir=None):
        """Loads the embedding files, computes the dimensionality reductions and calls the
        initialization of the main plot.

        Args:
            embeddings_file (str): Path to the pickle file where all data of segments including
                feature embeddings is saved. Newly computed embeddings are written back to it.
            distance_metric (str): Distance metric to use for nearest neighbor computation
                ('euclid' or 'cos'; unknown values fall back to 'euclid').
            method (str): Method to use for dimensionality reduction of nearest neighbor
                embeddings ('TSNE'; any other value selects Isomap). For plotting, the points are
                always reduced in dimensionality using PCA to (at most) 50 dimensions, followed by
                t-SNE to two dimensions.
            embedding_size (int or None): Dimensionality of the feature embeddings used for
                nearest neighbor search. None, or a value >= the data dimensionality, uses the
                raw embeddings.
            overwrite_embeddings (bool): If True, precomputed nearest neighbor and plotting
                embeddings from previous runs are overwritten with freshly computed ones.
                Otherwise precomputed embeddings are used if the requested embedding_size matches.
            n_jobs (int): Number of processes to use for the t-SNE/Isomap computation of the
                nearest neighbor embeddings.
            dpi (int): Dots per inch for graphics that are saved to disk.
            main_plot_args (dict or None): Keyword arguments for the creation of the main plot
                (see ``plot_main``).
            tsne_args (dict or None): Keyword arguments for the t-SNE algorithm.
            save_dir (str or None): Path to the directory where saved images are placed.
                Defaults to ``<CONFIG.metaseg_io_path>/vis_embeddings``.
        """
        # None defaults avoid shared mutable defaults and evaluating CONFIG at import time.
        main_plot_args = {} if main_plot_args is None else main_plot_args
        tsne_args = {} if tsne_args is None else tsne_args
        if save_dir is None:
            save_dir = join(CONFIG.metaseg_io_path, 'vis_embeddings')

        self.log = logging.getLogger('Discovery')
        self.embeddings_file = embeddings_file
        self.distance_metrics = ['euclid', 'cos']
        self.dm = (self.distance_metrics.index(distance_metric)
                   if distance_metric in self.distance_metrics else 0)
        self.dpi = dpi
        self.save_dir = save_dir
        os.makedirs(self.save_dir, exist_ok=True)

        self.cluster_methods = OrderedDict()
        self.cluster_methods['kmeans'] = {'main': KMeans, 'kwargs': {}}
        # Spectral clustering (SpectralClustering) is intentionally not registered here.
        self.cluster_methods['agglo'] = {'main': AgglomerativeClustering, 'kwargs': {'linkage': 'ward'}}
        self.methods_with_ncluster_param = ['kmeans', 'spectral', 'agglo']
        self.cme = 0
        self.clustering = None
        self.n_clusters = 25

        # colors:
        self.standard_color = (0, 0, 1, 1)
        self.current_color = (1, 0, 0, 1)
        self.nn_color = (1, 0, 1, 1)

        self.log.info('Loading data...')
        # NOTE: pickle can execute arbitrary code -- only load embedding files you created yourself.
        with open(self.embeddings_file, 'rb') as f:
            self.data = pkl.load(f)
        self.iou_preds = self.data['iou_pred']
        self.gt = np.array(self.data['gt']).flatten()
        self.pred = np.array(self.data['pred']).flatten()
        self.gi = self.data['image_level_index']  # global indices (on image level and not on component level)
        self.log.info('Loaded {} segment embeddings.'.format(self.pred.shape[0]))
        self.nearest_neighbors = None

        self._prepare_embeddings(method, embedding_size, overwrite_embeddings, n_jobs, tsne_args)

        self.x = self.data['plot_embeddings'][:, 0]
        self.y = self.data['plot_embeddings'][:, 1]

        self.label_mapping = dict()
        for d in np.unique(self.data['dataset']).flatten():
            self.label_mapping[d] = getattr(
                importlib.import_module(datasets[d].module_name),
                datasets[d].class_name)(**datasets[d].kwargs).label_mapping
        # Only the dataset object is kept here; its label mapping is registered below if missing.
        train_dat = getattr(
            importlib.import_module(CONFIG.TRAIN_DATASET.module_name),
            CONFIG.TRAIN_DATASET.class_name)(**CONFIG.TRAIN_DATASET.kwargs)
        self.pred_mapping = train_dat.pred_mapping
        if CONFIG.TRAIN_DATASET.name not in self.label_mapping:
            self.label_mapping[CONFIG.TRAIN_DATASET.name] = train_dat.label_mapping

        self.tnsize = (50, 50)
        self.fig_nn = None
        self.n_neighbors = 49
        self.current_pressed_key = None
        self.plot_main(**main_plot_args)

    def _prepare_embeddings(self, method, embedding_size, overwrite_embeddings, n_jobs, tsne_args):
        """Set ``self.data['plot_embeddings']`` and ``self.embeddings``; cache new results in the pickle."""
        if len(self.data['embeddings']) == 1:
            # A single sample cannot be reduced; plot its first two coordinates directly.
            self.data['plot_embeddings'] = np.array(
                [self.data['embeddings'][0][0], self.data['embeddings'][0][1]]).reshape((1, 2))
            self.data['nn_embeddings'] = self.data['plot_embeddings']
            self.embeddings = self.data['nn_embeddings']
            return

        dim = self.data['embeddings'][0].shape[0]
        raw = np.stack(self.data['embeddings']).reshape((-1, dim))
        reduced = None

        def pca_embeddings():
            """PCA to at most 50 dimensions, computed lazily and only once."""
            nonlocal reduced
            if reduced is None:
                self.log.info('Computing PCA...')
                n_comp = min(50, len(self.data['embeddings']), dim)
                reduced = PCA(n_components=n_comp).fit_transform(raw)
            return reduced

        new_plot_embeddings = False
        if 'plot_embeddings' not in self.data or overwrite_embeddings:
            self.log.info('Computing t-SNE for plotting')
            self.data['plot_embeddings'] = TSNE(n_components=2, **tsne_args).fit_transform(pca_embeddings())
            new_plot_embeddings = True

        new_nn_embeddings = False
        if embedding_size is None or embedding_size >= dim:
            self.embeddings = raw
            self.log.debug(
                ('Requested embedding size of {} was greater or equal to data dimensionality of {}. '
                 'Data has thus not been reduced in dimensionality.').format(embedding_size, dim))
        elif ('nn_embeddings' in self.data
              and self.data['nn_embeddings'].shape[1] == embedding_size
              and not overwrite_embeddings):
            self.embeddings = self.data['nn_embeddings']
            self.log.info(('Loaded reduced embeddings ({} dimensions) from precomputed file '
                           'for nearest neighbor search.').format(self.embeddings.shape[1]))
        else:
            if method == 'TSNE':
                if embedding_size == 2 and new_plot_embeddings:
                    self.embeddings = self.data['plot_embeddings']
                    self.log.info('Reused the precomputed manifold for plotting for nearest neighbor search.')
                else:
                    self.log.info('Computing t-SNE of dimension {} for nearest neighbor search...'.format(
                        embedding_size))
                    self.embeddings = TSNE(n_components=embedding_size, n_jobs=n_jobs,
                                           **tsne_args).fit_transform(pca_embeddings())
            else:
                self.log.info('Computing Isomap of dimension {} for nearest neighbor search...'.format(
                    embedding_size))
                self.embeddings = Isomap(n_components=embedding_size,
                                         n_jobs=n_jobs).fit_transform(pca_embeddings())
            self.data['nn_embeddings'] = self.embeddings
            new_nn_embeddings = True

        # Write added data into pickle file
        if new_plot_embeddings or new_nn_embeddings:
            with open(self.embeddings_file, 'wb') as f:
                pkl.dump(self.data, f)

    def _gt_colors(self):
        """Return an (N, 4) RGBA array holding the ground truth class color of every segment."""
        return np.stack([
            tuple(c / 255.0 for c in self.label_mapping[self.data['dataset'][self.gi[ind]]][self.gt[ind]][1])
            + (1.0,)
            for ind in range(self.x.shape[0])])

    def _gt_legend_elements(self):
        """Return one legend Patch per ground truth class that occurs in each dataset."""
        dataset_per_image = np.array(self.data['dataset'])
        legend_elements = []
        for d in np.unique(self.data['dataset']).flatten():
            cls = np.unique(self.gt[dataset_per_image[self.gi] == d])
            cls = list({(self.label_mapping[d][cl][0], self.label_mapping[d][cl][1]) for cl in cls})
            names = np.array([i[0] for i in cls])
            cols = np.array([i[1] for i in cls])
            # Strip a trailing " <number>" suffix from class names.
            legend_elements += [
                Patch(color=tuple(c / 255.0 for c in cols[i]) + (1.0,),
                      label=names[i] if not names[i][-1].isdigit() else names[i][:names[i].rfind(' ')])
                for i in range(names.shape[0])]
        return legend_elements

    def _load_thumbnail(self, ind):
        """Return the RGB crop of segment ``ind``, shrunk to fit ``self.tnsize``, as an array."""
        with Image.open(self.data['image_path'][self.gi[ind]]) as img:
            crop = img.convert('RGB').crop(self.data['box'][ind])
        # LANCZOS is the filter formerly aliased as Image.ANTIALIAS (removed in Pillow 10).
        crop.thumbnail(self.tnsize, Image.LANCZOS)
        return np.asarray(crop)

    def _load_crop(self, ind):
        """Return the (unconverted) crop of segment ``ind`` as an array."""
        with Image.open(self.data['image_path'][self.gi[ind]]) as img:
            return np.asarray(img.crop(self.data['box'][ind]))

    @staticmethod
    def _next_figure_number():
        """Return a figure number above all open figures; 1 and 2 are reserved for main/NN plots."""
        return max([2] + plt.get_fignums()) + 1

    def _save_figure(self, fig, ax, path):
        """Save ``fig`` without margins or ticks to ``path``, then close it (it is never shown)."""
        fig.subplots_adjust(bottom=0, left=0, right=1, top=1, hspace=0, wspace=0)
        ax.margins(0.05, 0.05)
        ax.xaxis.set_major_locator(plt.NullLocator())
        ax.yaxis.set_major_locator(plt.NullLocator())
        fig.savefig(path, bbox_inches='tight', pad_inches=0.0)
        self.log.debug('Saved image to {}'.format(path))
        plt.close(fig)

    def plot_main(self, **plot_args):
        """Initializes the main plot.

        Supported keyword arguments:
            legend (bool): draw a ground truth class legend above the plot (default False).
            save_path (str): if given (not None), save the plot there (at ``self.dpi``)
                instead of connecting the interactive handlers and showing it.
        """
        self.fig_main = plt.figure(num=1)
        self.fig_main.canvas.set_window_title('Embedding space')
        ax = self.fig_main.add_subplot(111)
        ax.set_axis_off()
        self.line_main = ax.scatter(self.x, self.y, marker="o", color=self._gt_colors())
        self.line_main.set_picker(True)

        if plot_args.get('legend', False):
            box = ax.get_position()
            ax.set_position([box.x0, box.y0, box.width, box.height * 0.8])
            ax.legend(loc='upper left', handles=self._gt_legend_elements(), ncol=8, bbox_to_anchor=(0, 1.2))

        self.basecolors = self.line_main.get_facecolor()

        self.im = OffsetImage(self._load_thumbnail(0), zoom=2)
        self.xybox = (50., 50.)
        self.ab = AnnotationBbox(self.im, (0, 0), xybox=self.xybox, xycoords='data',
                                 boxcoords='offset points', pad=0.3, arrowprops=dict(arrowstyle='->'))
        ax.add_artist(self.ab)
        self.ab.set_visible(False)

        save_path = plot_args.get('save_path')
        if save_path is not None:
            plt.savefig(expanduser(save_path), dpi=self.dpi, bbox_inches='tight')
        else:
            self.fig_main.canvas.mpl_connect('motion_notify_event', self.hover_main)
            self.fig_main.canvas.mpl_connect('button_press_event', self.click_main)
            self.fig_main.canvas.mpl_connect('scroll_event', self.scroll)
            self.fig_main.canvas.mpl_connect('key_press_event', self.key_press)
            self.fig_main.canvas.mpl_connect('key_release_event', self.key_release)
            plt.show()

    def hover_main(self, event):
        """Action handler for the main plot.

        Shows a thumbnail of the underlying image when a scatter point is hovered.
        """
        contains, details = self.line_main.contains(event)
        if contains:
            ind, *_ = details["ind"]
            # Put the box on the side of the mouse facing the figure centre.
            w, h = self.fig_main.get_size_inches() * self.fig_main.dpi
            ws = -1 if event.x > w / 2. else 1
            hs = -1 if event.y > h / 2. else 1
            self.ab.xybox = (self.xybox[0] * ws, self.xybox[1] * hs)
            self.ab.set_visible(True)
            # place it at the position of the hovered scatter point
            self.ab.xy = (self.x[ind], self.y[ind])
            self.im.set_data(self._load_thumbnail(ind))
        else:
            self.ab.set_visible(False)
        self.fig_main.canvas.draw_idle()

    def _handle_segment_click(self, ind, event):
        """Run the thumbnail/single-image/detailed-image action for a click on segment ``ind``.

        Returns:
            bool: True if the click was handled, False otherwise (e.g. right button).
        """
        left, middle = event.button == 1, event.button == 2
        if self.current_pressed_key == 't' and left:
            self.store_thumbnail(ind)
        elif self.current_pressed_key == 'control' and left:
            self.show_single_image(ind, save=True)
        elif self.current_pressed_key == 'control' and middle:
            self.show_full_image(ind, save=True)
        elif left:
            self.show_single_image(ind)
        elif middle:
            self.show_full_image(ind)
        else:
            return False
        return True

    def _open_nn_figure(self, ind):
        """Create the nearest neighbor window (figure 2) for segment ``ind``."""
        self.log.info('Loading nearest neighbors...')
        self.nearest_neighbors = self.get_nearest_neighbors(ind, metric=self.distance_metrics[self.dm])
        thumbnails = [self._load_crop(neighbor_ind) for neighbor_ind in self.nearest_neighbors]
        columns = math.ceil(math.sqrt(self.n_neighbors))
        rows = math.ceil(self.n_neighbors / columns)
        self.fig_nn = plt.figure(num=2, dpi=self.dpi)
        self.fig_nn.canvas.set_window_title('{} nearest neighbors to selected image'.format(self.n_neighbors))
        for p in range(columns * rows):
            ax = self.fig_nn.add_subplot(rows, columns, p + 1)
            ax.set_axis_off()
            if p < len(thumbnails):
                ax.imshow(thumbnails[p])
        self.fig_nn.canvas.mpl_connect('button_press_event', self.click_nn)
        self.fig_nn.canvas.mpl_connect('key_press_event', self.key_press)
        self.fig_nn.canvas.mpl_connect('key_release_event', self.key_release)
        self.fig_nn.canvas.mpl_connect('scroll_event', self.scroll)
        self.fig_nn.show()

    def click_main(self, event):
        """Action handler for the main plot.

        Left/middle click (optionally with 't' or 'control' held) shows or saves the
        image of the clicked segment. Right click opens or updates the nearest
        neighbor window.
        """
        contains, details = self.line_main.contains(event)
        if not contains:
            return
        ind, *_ = details['ind']
        if self._handle_segment_click(ind, event) or event.button != 3:
            return
        if plt.fignum_exists(2):
            # nearest neighbor figure is already open; update_nearest_neighbors also marks the query
            self.update_nearest_neighbors(ind)
            return
        # nearest neighbor figure is not open anymore or has not been opened yet
        self._open_nn_figure(ind)
        self.set_color(ind, self.current_color)
        self.flush_colors()

    def click_nn(self, event):
        """Action handler for the nearest neighbor window.

        Clicking a cropped segment triggers the same actions as in the main plot.
        """
        if event.inaxes not in self.fig_nn.axes:
            return
        ind = self.get_ind_nn(event)
        # trailing axes of the grid may be empty when it is larger than the neighbor list
        if ind >= len(self.nearest_neighbors):
            return
        neighbor = self.nearest_neighbors[ind]
        if not self._handle_segment_click(neighbor, event) and event.button == 3:
            self.update_nearest_neighbors(neighbor)

    def key_press(self, event):
        """Performs different actions based on pressed keys.

        'm' cycles the distance metric, '#' cycles the clustering method, 'c' clusters
        and colors by cluster, 'x' shows cluster statistics, 'g'/'h' color by ground
        truth/prediction, and 'b' resets all points to the standard color.
        """
        if event.key == 'm':
            self.dm = (self.dm + 1) % len(self.distance_metrics)
            self.log.info('Changed distance metric to {}'.format(self.distance_metrics[self.dm]))
        elif event.key == '#':
            self.cme = (self.cme + 1) % len(self.cluster_methods)
            self.log.info('Changed clustering method to {}'.format(list(self.cluster_methods.keys())[self.cme]))
        elif event.key == 'c':
            method = list(self.cluster_methods.keys())[self.cme]
            self.log.info('Started clustering with {}...'.format(method))
            if self.cluster(method=method):
                legend = self.fig_main.axes[0].get_legend()
                if legend is not None:
                    legend.remove()
                self.basecolors = cm.get_cmap('viridis', (max(self.clustering) + 1))(self.clustering)
                self.flush_colors()
                self.cluster_statistics()
        elif event.key == 'x':
            if self.clustering is not None:
                self.cluster_statistics()
        elif event.key == 'g':
            self.color_gt()
        elif event.key == 'h':
            self.color_pred()
        elif event.key == 'b':
            self.set_color(list(range(self.basecolors.shape[0])), self.standard_color)
            self.flush_colors()
        self.current_pressed_key = event.key
        self.log.debug('Key \'{}\' pressed.'.format(event.key))

    def key_release(self, event):
        """Clears the variable where the last pressed key is saved."""
        self.current_pressed_key = None
        self.log.debug('Key \'{}\' released.'.format(event.key))

    def scroll(self, event):
        """Increases or decreases the number of nearest neighbors (minimum 1) when scrolling."""
        if event.button == 'up':
            self.n_neighbors += 1
            self.log.info('Increased number of nearest neighbors to {}'.format(self.n_neighbors))
        elif event.button == 'down':
            # 0 neighbors would make the neighbor grid size 0 (division by zero on open)
            if self.n_neighbors > 1:
                self.n_neighbors -= 1
                self.log.info('Decreased number of nearest neighbors to {}'.format(self.n_neighbors))

    def show_single_image(self, ind, save=False):
        """Displays (or saves) the full image belonging to a segment.

        The segment is marked with a red bounding box.
        """
        self.log.info('{} image...'.format('Saving' if save else 'Loading'))
        img_box = self.draw_box_on_image(ind)
        fig_tmp = plt.figure(self._next_figure_number(), dpi=self.dpi)
        ax = fig_tmp.add_subplot(111)
        ax.set_axis_off()
        ax.imshow(np.asarray(img_box), interpolation='nearest')
        if save:
            self._save_figure(fig_tmp, ax, join(self.save_dir, 'image_{}.jpg'.format(ind)))
        else:
            fig_tmp.canvas.set_window_title('Dataset: {}, Image index: {}'.format(
                self.data['dataset'][self.gi[ind]], self.data['image_index'][self.gi[ind]]))
            fig_tmp.tight_layout(pad=0.0)
            fig_tmp.show()

    @staticmethod
    def _colorize(label_img, mapping, shape):
        """Map every pixel's class id ``k`` to ``mapping[k][1]`` (its RGB color) and reshape to ``shape``."""
        ids, inverse = np.unique(np.asarray(label_img).ravel(), return_inverse=True)
        palette = np.asarray([mapping[k][1] for k in ids])
        return palette[inverse.ravel()].reshape(shape)

    def show_full_image(self, ind, save=False):
        """Displays (or saves) four panels of the full image belonging to a segment.

        Top left: Entropy heatmap of prediction.
        Top right: Predicted IoU of each segment.
        Bottom left: Source image with ground truth overlay.
        Bottom right: Predicted semantic segmentation.
        """
        self.log.info('{} detailed image...'.format('Saving' if save else 'Loading'))
        box = self.data['box'][ind]
        with Image.open(self.data['image_path'][self.gi[ind]]) as img:
            image = np.asarray(img.convert('RGB'))
        image_index = self.data['image_index'][self.gi[ind]]
        iou_pred = self.data['iou_pred'][self.gi[ind]]
        dataset = self.data['dataset'][self.gi[ind]]
        model_name = self.data['model_name'][self.gi[ind]]

        pred, gt, _ = probs_gt_load(image_index,
                                    input_dir=join(CONFIG.metaseg_io_path, 'input', model_name, dataset))
        components = components_load(image_index,
                                     components_dir=join(CONFIG.metaseg_io_path, 'components', model_name, dataset))
        e = entropy(pred)
        pred = pred.argmax(2)
        predc = self._colorize(pred, self.pred_mapping, image.shape)
        gtc = self._colorize(gt, self.label_mapping[dataset], image.shape)

        overlay_factor = [1.0, 0.5, 1.0]
        img_predc, img_gtc, img_entropy = [
            Image.fromarray(np.uint8(arr * overlay_factor[i] + image * (1 - overlay_factor[i])))
            for i, arr in enumerate([predc, gtc, cm.jet(e)[:, :, :3] * 255.0])]
        img_ioupred = Image.fromarray(self.visualize_segments(components, iou_pred))

        images = [img_gtc, img_predc, img_entropy, img_ioupred]
        box_line_width = 5
        left, upper = max(0, box[0] - box_line_width), max(0, box[1] - box_line_width)
        right, lower = min(pred.shape[1], box[2] + box_line_width), min(pred.shape[0], box[3] + box_line_width)
        for k in images:
            draw = ImageDraw.Draw(k)
            draw.rectangle([left, upper, right, lower], outline=(255, 0, 0), width=box_line_width)
            del draw
        images = [np.asarray(k).astype('uint8') for k in images]
        img_top = np.concatenate(images[2:], axis=1)
        img_bottom = np.concatenate(images[:2], axis=1)
        img_total = np.concatenate((img_top, img_bottom), axis=0)

        fig_tmp = plt.figure(self._next_figure_number(), dpi=self.dpi)
        fig_tmp.canvas.set_window_title('Dataset: {}, Image index: {}'.format(dataset, image_index))
        ax = fig_tmp.add_subplot(111)
        ax.set_axis_off()
        ax.imshow(img_total, interpolation='nearest')
        if save:
            self._save_figure(fig_tmp, ax, join(self.save_dir, 'detailed_image_{}.jpg'.format(ind)))
        else:
            fig_tmp.tight_layout(pad=0.0)
            fig_tmp.show()

    def store_thumbnail(self, ind):
        """Stores the cropped part of a segment's image (not the whole image) to ``save_dir``."""
        with Image.open(self.data['image_path'][self.gi[ind]]) as img:
            image = img.convert('RGB').crop(self.data['box'][ind])
        name = self.label_mapping[self.data['dataset'][self.gi[ind]]][self.gt[ind]][0]
        if name[-1].isdigit():
            name = name[:-2]
        name = name.replace(' ', '_')
        path = join(self.save_dir, 'thumbnail_{}_{:0>2.1f}_{:0>2.1f}.jpg'.format(name, self.x[ind], self.y[ind]))
        image.save(path)
        self.log.debug('Saved thumbnail to {}'.format(path))

    def get_nearest_neighbors(self, ind, metric='cos'):
        """Return the indices of the ``n_neighbors`` segments closest to segment ``ind``.

        The closest hit (normally ``ind`` itself) is skipped. Any metric other than
        'euclid' uses the cosine distance.
        """
        if metric == 'euclid':
            dists = self.lp_dist(self.embeddings[ind], self.embeddings, d=2)
        else:
            dists = self.cos_dist(self.embeddings[ind], self.embeddings)
        return np.argsort(dists)[1:(self.n_neighbors + 1)]

    def update_nearest_neighbors(self, ind):
        """Recompute the nearest neighbors for segment ``ind`` and redraw the open neighbor window."""
        self.log.info('Loading nearest neighbors...')
        self.nearest_neighbors = self.get_nearest_neighbors(ind, metric=self.distance_metrics[self.dm])
        thumbnails = [self._load_crop(neighbor_ind) for neighbor_ind in self.nearest_neighbors]
        n = max(1, math.ceil(math.sqrt(len(thumbnails))))
        if len(self.fig_nn.axes) != (n ** 2):
            # number of nearest neighbors has been changed: redefine the subplot grid
            self.rearrange_axes(n, n)
        for p, ax in enumerate(self.fig_nn.axes):
            ax.clear()  # avoid stacking images on top of previous ones
            if p < len(thumbnails):
                ax.imshow(thumbnails[p])
            ax.set_axis_off()
        self.fig_nn.canvas.draw()
        self.set_color(ind, self.current_color)
        self.flush_colors()

    def cluster(self, method='kmeans'):
        """Cluster ``self.embeddings`` and store the labels in ``self.clustering``.

        For methods taking ``n_clusters`` the number is requested interactively.

        Returns:
            bool: True on success; False if no valid number of clusters was given
            (``self.clustering`` is then left unchanged).
        """
        kwargs = dict(self.cluster_methods[method]['kwargs'])
        if method in self.methods_with_ncluster_param:
            try:
                n_clusters = self.n_cluster_prompt()
                if n_clusters == 'elbow' and method == 'kmeans':
                    n_clusters = self.elbow()
            except ValueError as err:
                self.log.error(str(err))
                return False
            if not isinstance(n_clusters, int):
                self.log.error('Cannot cluster with n_clusters={!r}.'.format(n_clusters))
                return False
            kwargs['n_clusters'] = n_clusters
        self.clustering = self.cluster_methods[method]['main'](**kwargs).fit_predict(self.embeddings)
        return True

    def cluster_statistics(self):
        """Show one pie chart per cluster with the ground truth label distribution of its members."""
        fig_clstats = plt.figure(self._next_figure_number())
        clusters = np.unique(self.clustering).flatten()
        n = math.ceil(math.sqrt(clusters.shape[0]))
        all_label_names = []
        self.log.debug('Size of cluster statistics plots: {}'.format(n))
        for i in range(n ** 2):
            ax = fig_clstats.add_subplot(n, n, i + 1)
            if i < clusters.shape[0]:
                label_names = []
                for j in range(self.clustering.shape[0]):
                    if self.clustering[j] == clusters[i]:
                        dat = self.data['dataset'][self.gi[j]]
                        name = self.label_mapping[dat][self.gt[j]][0]
                        label_names.append((name, self.label_mapping[dat][self.gt[j]][1]))
                if i == (clusters.shape[0] - 1):
                    # Add every label seen in earlier clusters to the last pie so that its
                    # wedges, used for the figure legend below, cover all labels.
                    missing = [all_label_names[k] for k in
                               np.unique([lbl[0] for lbl in all_label_names], return_index=True)[1]]
                    label_names += missing
                labels, label_inds, label_counts = np.unique([lbl[0] for lbl in label_names],
                                                             return_counts=True, return_index=True)
                explode = np.zeros(labels.shape[0])
                explode[np.argmax(label_counts)] = 0.2
                # label colors are 0-255 RGB; matplotlib expects 0-1
                cols = np.array([label_names[k][1] for k in label_inds]) / 255.0
                perm = np.argsort(label_counts)[::-1]
                all_label_names += [label_names[k] for k in
                                    np.unique([lbl[0] for lbl in label_names], return_index=True)[1]]
                ax.pie(label_counts[perm],
                       colors=cols[perm],
                       explode=explode[perm],
                       autopct=lambda perc: '{:1.1f}%'.format(perc) if perc > 20 else '',
                       startangle=90,
                       wedgeprops=dict(edgecolor='w'),
                       textprops=dict(color="g"))
                ax.axis('equal')
            else:
                ax.set_axis_off()
        fig_clstats.legend(fig_clstats.axes[clusters.shape[0] - 1].patches, labels[perm], loc='upper left')
        fig_clstats.show()

    def elbow(self):
        """Plot the k-means inertia over a user-given range of cluster numbers and ask for the final number.

        Raises:
            ValueError: if an entered value is not an integer.
        """
        low = int(input('Enter the minimum number of clusters: '))
        high = int(input('Enter the maximum number of clusters: '))
        km = [KMeans(n_clusters=i) for i in range(low, high + 1)]
        km = [k.fit(self.embeddings) for k in tqdm(km, total=len(km))]
        score = [k.inertia_ for k in km]
        fig_elbow = plt.figure(self._next_figure_number())
        ax = fig_elbow.add_subplot(111)
        ax.plot(range(low, high + 1), score)
        fig_elbow.show()
        return int(input('Enter number of clusters: '))

    def n_cluster_prompt(self):
        """Ask for the number of clusters.

        Returns:
            int, 'elbow' or 'error': the entered number, the literal 'elbow', or
            'error' if the input was neither.

        Raises:
            ValueError: if the entered number is <= 1.
        """
        inp = input('Enter the number of clusters: ')
        if inp == 'elbow':
            return inp
        try:
            inp = int(inp)
        except ValueError:
            self.log.error('Input should be an int or \'elbow\' but received {}!'.format(inp))
            return 'error'
        if inp <= 1:
            raise ValueError('Number of clusters should be greater than 1!')
        return inp

    def rearrange_axes(self, nrows, ncols):
        """Helper function for the nearest neighbor plot window.

        If the number of nearest neighbors has changed and a new query segment has
        been chosen, the arrangement of subplots within the window has to change.
        """
        n = len(self.fig_nn.axes)
        if n <= (nrows * ncols):
            # we need to add more axes
            for i, ax in enumerate(self.fig_nn.axes):
                ax.change_geometry(nrows, ncols, i + 1)
            for p in range(n, nrows * ncols):
                ax = self.fig_nn.add_subplot(nrows, ncols, p + 1)
                ax.set_axis_off()
        else:
            # we need to remove some axes
            for p in range(n - 1, (nrows * ncols) - 1, -1):
                self.fig_nn.delaxes(self.fig_nn.axes[p])
            for i, ax in enumerate(self.fig_nn.axes):
                ax.change_geometry(nrows, ncols, i + 1)

    def draw_box_on_image(self, ind):
        """Return the source image of segment ``ind`` with a red bounding box drawn around the segment."""
        box_line_width = 5
        with Image.open(self.data['image_path'][self.gi[ind]]) as img:
            img_box = img.convert('RGB')
        draw = ImageDraw.Draw(img_box)
        left, upper, right, lower = self.data['box'][ind]
        left, upper = max(0, left - box_line_width), max(0, upper - box_line_width)
        right, lower = min(img_box.size[0], right + box_line_width), min(img_box.size[1], lower + box_line_width)
        draw.rectangle([left, upper, right, lower], outline=(255, 0, 0), width=box_line_width)
        del draw
        return img_box

    @staticmethod
    def visualize_segments(comp, metric):
        """Color each segment by its metric value (panel of the detailed image view).

        Args:
            comp: integer component image; id k > 0 selects ``metric[k - 1]``,
                ids < 0 are drawn black and id 0 white.
            metric: one value per segment id.

        Returns:
            np.ndarray: uint8 RGB image of shape ``comp.shape + (3,)``.
        """
        metric = np.asarray(metric)
        # one RGB row per segment id, followed by black (ids < 0) and white (id 0)
        palette = np.stack((np.concatenate((1 - 0.5 * metric, [0, 1])),
                            np.concatenate((metric, [0, 1])),
                            np.concatenate((0.3 + 0.35 * metric, [0, 1]))), axis=1)
        n_colors = palette.shape[0]
        # int64 (instead of int16) so large segment ids cannot overflow
        components = np.asarray(comp, dtype=np.int64).copy()
        components[components < 0] = n_colors - 1
        components[components == 0] = n_colors
        img = palette[components - 1]
        return np.asarray(255 * img).astype('uint8')

    @staticmethod
    def lp_dist(point, all_points, d=2):
        """Calculates the L_p distance from a point to each row of ``all_points``. Used for retrieval."""
        # abs() is required for odd p; for p=2 the result is unchanged
        return (np.abs(all_points - point) ** d).sum(1) ** (1.0 / d)

    @staticmethod
    def cos_dist(point, all_points):
        """Calculates the cosine distance from a point to each row of ``all_points``. Used for retrieval."""
        return 1 - ((point * all_points).sum(1) / (norm(point) * norm(all_points, axis=1)))

    @staticmethod
    def get_gridsize(fig):
        """Return (nrows, ncols) of the subplot grid of ``fig`` (nearest neighbor plot helper)."""
        return fig.axes[0].get_subplotspec().get_gridspec().get_geometry()

    def get_ind_nn(self, event):
        """Return the flat (row-major) grid index of the nearest neighbor subplot that was clicked."""
        _, ncols = self.get_gridsize(self.fig_nn)
        eventrow = event.inaxes.rowNum
        eventcol = event.inaxes.colNum
        return (eventrow * ncols) + eventcol

    def color_gt(self):
        """Colors the main scatter by ground truth class and shows the matching legend."""
        self.basecolors = self._gt_colors()
        self.fig_main.axes[0].legend(loc='upper left', handles=self._gt_legend_elements(), ncol=8,
                                     bbox_to_anchor=(0, 1.2))
        self.flush_colors()

    def color_pred(self):
        """Colors the main scatter by predicted class and shows the matching legend."""
        self.basecolors = np.stack([tuple(c / 255.0 for c in self.pred_mapping[self.pred[ind]][1]) + (1.0,)
                                    for ind in range(self.basecolors.shape[0])])
        legend_elements = [Patch(color=tuple(c / 255.0 for c in self.pred_mapping[cl][1]) + (1.0,),
                                 label=self.pred_mapping[cl][0])
                           for cl in np.unique(self.pred).flatten()]
        self.fig_main.axes[0].legend(loc='upper left', handles=legend_elements, ncol=8, bbox_to_anchor=(0, 1.2))
        self.flush_colors()

    def set_color(self, ind, color):
        """Set the color of the segment(s) with index (or indices) ``ind``; call ``flush_colors`` to draw."""
        self.basecolors[ind, :] = color

    def change_color(self, old_color, new_color):
        """Replace every occurrence of ``old_color`` in the base colors with ``new_color``."""
        self.basecolors[(self.basecolors == old_color).all(axis=1)] = new_color

    def flush_colors(self):
        """Flushes all color changes onto the main scatter plot."""
        self.line_main.set_color(self.basecolors)
        self.fig_main.canvas.draw()
class Discovery(object):
    """Interactive explorer for a collection of segment embeddings.

    Loads per-segment feature embeddings from a pickle file, reduces them for
    plotting (PCA followed by 2-D t-SNE) and for nearest neighbor search, and
    opens a matplotlib scatter plot whose points can be hovered, clicked,
    clustered and colored by ground truth or by predicted class.
    """

    def __init__(
        self,
        embeddings_file="embeddings.p",
        distance_metric="euclid",
        method="TSNE",
        embedding_size=2,
        overwrite_embeddings=False,
        n_jobs=10,
        dpi=300,
        main_plot_args=None,
        tsne_args=None,
        save_dir=None,
    ):
        """Loads the embedding file, computes the dimensionality reductions
        and initializes the main plot.

        Args:
            embeddings_file (str): Path to the pickle file holding all segment
                data, including the feature embeddings.
            distance_metric (str): Distance metric for nearest neighbor search,
                "euclid" or "cos". Unknown values fall back to "euclid".
            method (str): Dimensionality reduction for the nearest neighbor
                embeddings: "TSNE"; any other value selects Isomap. For
                plotting, the points are always reduced with PCA to at most 50
                dimensions followed by t-SNE to two dimensions.
            embedding_size (int or None): Dimensionality of the embeddings used
                for nearest neighbor search. None, or a value not smaller than
                the data dimensionality, uses the raw embeddings.
            overwrite_embeddings (bool): If True, nearest neighbor and plotting
                embeddings stored by previous runs are recomputed. Otherwise
                they are reused if the stored size matches embedding_size.
            n_jobs (int): Number of processes for the nearest neighbor t-SNE or
                Isomap computation.
            dpi (int): Dots per inch for figures that are saved to disk.
            main_plot_args (dict or None): Keyword arguments for plot_main.
            tsne_args (dict or None): Keyword arguments for the t-SNE algorithm.
            save_dir (str or None): Directory where saved images are placed.
                Defaults to "<CONFIG.metaseg_io_path>/vis_embeddings".

        Raises:
            ValueError: If stored nearest neighbor embeddings do not match
                embedding_size and overwrite_embeddings is False.
        """
        main_plot_args = {} if main_plot_args is None else main_plot_args
        tsne_args = {} if tsne_args is None else tsne_args
        if save_dir is None:
            # Resolved at call time rather than when the class is defined.
            save_dir = join(CONFIG.metaseg_io_path, "vis_embeddings")

        self.log = logging.getLogger("Discovery")
        self.embeddings_file = embeddings_file
        self.distance_metrics = ["euclid", "cos"]
        self.dm = (self.distance_metrics.index(distance_metric)
                   if distance_metric in self.distance_metrics else 0)
        self.dpi = dpi
        self.save_dir = save_dir
        os.makedirs(self.save_dir, exist_ok=True)

        self.cluster_methods = OrderedDict()
        self.cluster_methods["kmeans"] = {"main": KMeans, "kwargs": {}}
        self.cluster_methods["spectral"] = {
            "main": SpectralClustering,
            "kwargs": {},
        }
        self.cluster_methods["agglo"] = {
            "main": AgglomerativeClustering,
            "kwargs": {"linkage": "ward"},
        }
        self.methods_with_ncluster_param = ["kmeans", "spectral", "agglo"]
        self.cme = 0  # index into cluster_methods, cycled with the '#' key
        self.clustering = None
        self.n_clusters = 25

        # RGBA colors of the scatter points.
        self.standard_color = (0, 0, 1, 1)
        self.current_color = (1, 0, 0, 1)
        self.nn_color = (1, 0, 1, 1)

        self.log.info("Loading data...")
        # NOTE(review): pickle.load can execute arbitrary code; only open
        # embedding files that come from a trusted source.
        with open(self.embeddings_file, "rb") as f:
            self.data = pkl.load(f)
        self.iou_preds = self.data["iou_pred"]
        self.gt = np.array(self.data["gt"]).flatten()
        self.pred = np.array(self.data["pred"]).flatten()
        # global indices (on image level and not on component level)
        self.gi = self.data["image_level_index"]
        self.log.info("Loaded {} segment embeddings.".format(
            self.pred.shape[0]))

        self.nearest_neighbors = None
        rewrite = self._prepare_embeddings(method, embedding_size,
                                           overwrite_embeddings, n_jobs,
                                           tsne_args)
        # Write added data into pickle file
        if rewrite:
            with open(self.embeddings_file, "wb") as f:
                pkl.dump(self.data, f)

        self.x = self.data["plot_embeddings"][:, 0]
        self.y = self.data["plot_embeddings"][:, 1]

        self._load_label_mappings()

        self.tnsize = (50, 50)
        self.fig_nn = None
        self.fig_main = None
        self.line_main = None
        self.im = None
        self.xybox = None
        self.ab = None
        self.basecolors = np.stack(
            [self.standard_color for _ in range(self.x.shape[0])])
        self.n_neighbors = 49
        self.current_pressed_key = None
        self.plot_main(**main_plot_args)

    # ------------------------------------------------------------------
    # Initialization helpers
    # ------------------------------------------------------------------

    def _prepare_embeddings(self, method, embedding_size,
                            overwrite_embeddings, n_jobs, tsne_args):
        """Sets self.embeddings and self.data["plot_embeddings"].

        Precomputed entries in self.data are reused where allowed; missing
        ones are computed.

        Returns:
            bool: True if self.data gained freshly computed entries that
            should be written back to the embeddings file.

        Raises:
            ValueError: If the stored nearest neighbor embeddings do not
                match embedding_size and overwrite_embeddings is False.
        """
        raw = self.data["embeddings"]
        if len(raw) == 1:
            # A single segment cannot be reduced: use its first two
            # coordinates for plotting and nearest neighbor search.
            self.data["plot_embeddings"] = np.array(
                [raw[0][0], raw[0][1]]).reshape((1, 2))
            self.data["nn_embeddings"] = self.data["plot_embeddings"]
            self.embeddings = self.data["nn_embeddings"]
            return False

        dim = raw[0].shape[0]
        reduce_dim = embedding_size is not None and embedding_size < dim
        need_plot = "plot_embeddings" not in self.data or overwrite_embeddings
        need_nn = "nn_embeddings" not in self.data or overwrite_embeddings
        rewrite = False

        pca_embeddings = None
        if reduce_dim and (need_nn or need_plot):
            pca_embeddings = self._pca_embeddings()

        if need_plot:
            if pca_embeddings is None:
                # The plotting t-SNE always runs on the PCA output, even when
                # no reduction is requested for nearest neighbor search.
                pca_embeddings = self._pca_embeddings()
            self.log.info("Computing t-SNE for plotting")
            self.data["plot_embeddings"] = TSNE(
                n_components=2, **tsne_args).fit_transform(pca_embeddings)
            rewrite = True

        if not reduce_dim:
            self.embeddings = np.stack(raw).reshape((-1, dim))
            self.log.debug(
                ("Requested embedding size of {} was greater or equal "
                 "to data dimensionality of {}. "
                 "Data has thus not been reduced in dimensionality."
                 ).format(embedding_size, dim))
        elif (not overwrite_embeddings and "nn_embeddings" in self.data
              and self.data["nn_embeddings"].shape[1] == embedding_size):
            self.embeddings = self.data["nn_embeddings"]
            self.log.info(("Loaded reduced embeddings "
                           "({} dimensions) from precomputed file " +
                           "for nearest neighbor search.").format(
                               self.embeddings.shape[1]))
        elif pca_embeddings is not None:
            if method == "TSNE":
                if embedding_size == 2 and need_plot:
                    self.embeddings = self.data["plot_embeddings"]
                    self.log.info((
                        "Reused the precomputed manifold for plotting for "
                        "nearest neighbor search."))
                else:
                    self.log.info(("Computing t-SNE of dimension "
                                   "{} for nearest neighbor search..."
                                   ).format(embedding_size))
                    self.embeddings = TSNE(
                        n_components=embedding_size,
                        n_jobs=n_jobs,
                        **tsne_args).fit_transform(pca_embeddings)
            else:
                self.log.info(("Computing Isomap of dimension "
                               "{} for nearest neighbor search..."
                               ).format(embedding_size))
                self.embeddings = Isomap(
                    n_components=embedding_size,
                    n_jobs=n_jobs,
                ).fit_transform(pca_embeddings)
            self.data["nn_embeddings"] = self.embeddings
            rewrite = True
        else:
            raise ValueError(
                ("Please specify a valid combination of arguments.\n"
                 "Loading fails if 'overwrite_embeddings' is False and "
                 "saved 'embedding_size' does not match the requested one."
                 ))
        return rewrite

    def _pca_embeddings(self):
        """Returns the raw embeddings projected onto at most 50 principal
        components."""
        raw = self.data["embeddings"]
        dim = raw[0].shape[0]
        self.log.info("Computing PCA...")
        n_comp = min(50, len(raw), dim)
        return PCA(n_components=n_comp).fit_transform(
            np.stack(raw).reshape((-1, dim)))

    @staticmethod
    def _instantiate(spec):
        """Instantiates spec.class_name from spec.module_name with
        spec.kwargs."""
        return getattr(importlib.import_module(spec.module_name),
                       spec.class_name)(**spec.kwargs)

    def _load_label_mappings(self):
        """Fills self.label_mapping (per dataset name) and self.pred_mapping.

        If building a dataset or reading its label_mapping raises
        AttributeError, that dataset maps to None. The prediction mapping is
        taken from CONFIG.TRAIN_DATASET.
        """
        self.label_mapping = dict()
        for d in np.unique(self.data["dataset"]).flatten():
            try:
                self.label_mapping[d] = self._instantiate(
                    datasets[d]).label_mapping
            except AttributeError:
                self.label_mapping[d] = None
        # Instantiate the training dataset once. A chained assignment used to
        # store the dataset object itself as its "label mapping".
        train_dat = self._instantiate(CONFIG.TRAIN_DATASET)
        self.pred_mapping = train_dat.pred_mapping
        if CONFIG.TRAIN_DATASET.name not in self.label_mapping:
            self.label_mapping[CONFIG.TRAIN_DATASET.name] = getattr(
                train_dat, "label_mapping", None)

    # ------------------------------------------------------------------
    # Small shared helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _set_window_title(fig, title):
        """Sets a figure window title (canvas.set_window_title was removed
        in Matplotlib 3.6)."""
        manager = getattr(fig.canvas, "manager", None)
        if manager is not None:
            manager.set_window_title(title)
        elif hasattr(fig.canvas, "set_window_title"):
            fig.canvas.set_window_title(title)

    @staticmethod
    def _next_fignum():
        """Returns a free figure number >= 3 (1 is the main plot, 2 the
        nearest neighbor window)."""
        return max([3] + [num + 1 for num in plt.get_fignums()])

    @staticmethod
    def _nn_grid(count):
        """Returns (rows, columns) of a near-square grid for count
        thumbnails, at least 1 x 1."""
        columns = max(1, math.ceil(math.sqrt(count)))
        rows = max(1, math.ceil(count / columns))
        return rows, columns

    @staticmethod
    def _rgba(color):
        """Converts a 0-255 color tuple to an opaque 0-1 RGBA tuple."""
        return tuple(c / 255.0 for c in color) + (1.0, )

    @staticmethod
    def _short_label(name):
        """Drops a trailing " <number>" suffix from a class name."""
        if name and name[-1].isdigit():
            return name[:name.rfind(" ")]
        return name

    @staticmethod
    def _colorize(labels, mapping):
        """Maps every class id in labels to mapping[id][1].

        Returns an array of shape (labels.size, len(color)) in row-major
        pixel order; callers reshape it to the image shape.
        """
        ids, inverse = np.unique(np.asarray(labels).ravel(),
                                 return_inverse=True)
        palette = np.asarray([mapping[i][1] for i in ids])
        return palette[inverse]

    def _load_crop(self, ind):
        """Returns the RGB crop of segment ind from its source image (the
        file is closed afterwards)."""
        with Image.open(self.data["image_path"][self.gi[ind]]) as src:
            return src.convert("RGB").crop(self.data["box"][ind])

    def _thumbnail(self, ind):
        """Returns the crop of segment ind shrunk to fit self.tnsize."""
        crop = self._load_crop(ind)
        # LANCZOS is what the removed Image.ANTIALIAS alias pointed to.
        crop.thumbnail(self.tnsize, Image.LANCZOS)
        return crop

    def _gt_legend_elements(self):
        """Builds one legend Patch per (class name, color) pair occurring in
        the ground truth of each dataset."""
        dataset_per_segment = np.array(self.data["dataset"])[self.gi]
        legend_elements = []
        for d in np.unique(self.data["dataset"]).flatten():
            classes = np.unique(self.gt[dataset_per_segment == d])
            pairs = {(self.label_mapping[d][cl][0],
                      self.label_mapping[d][cl][1]) for cl in classes}
            legend_elements += [
                Patch(color=self._rgba(color), label=self._short_label(name))
                for name, color in pairs
            ]
        return legend_elements

    def _remove_legend(self):
        """Removes the legend of the main plot, if any."""
        legend = self.fig_main.axes[0].get_legend()
        if legend is not None:
            legend.remove()

    def _save_figure(self, fig, ax, filename):
        """Saves fig borderless to save_dir/filename and closes it."""
        fig.subplots_adjust(bottom=0, left=0, right=1, top=1, hspace=0,
                            wspace=0)
        ax.margins(0.05, 0.05)
        fig.gca().xaxis.set_major_locator(plt.NullLocator())
        fig.gca().yaxis.set_major_locator(plt.NullLocator())
        path = join(self.save_dir, filename)
        fig.savefig(path, bbox_inches="tight", pad_inches=0.0)
        self.log.debug("Saved image to {}".format(path))
        plt.close(fig)

    # ------------------------------------------------------------------
    # Main plot
    # ------------------------------------------------------------------

    def plot_main(self, **plot_args):
        """Initializes the main plot.

        Keyword Args:
            legend (bool): Show a ground truth class legend (only if every
                dataset has a label mapping).
            save_path (str): If given, the plot is saved there (with
                self.dpi) instead of being shown interactively.
        """
        self.fig_main = plt.figure(num=1)
        self._set_window_title(self.fig_main, "Embedding space")
        ax = self.fig_main.add_subplot(111)
        ax.set_axis_off()
        self.line_main = ax.scatter(self.x,
                                    self.y,
                                    marker="o",
                                    color=self.basecolors,
                                    zorder=2)
        self.line_main.set_picker(True)

        if plot_args.get("legend", False) and all(
                lm is not None for lm in self.label_mapping.values()):
            box = ax.get_position()
            ax.set_position([box.x0, box.y0, box.width, box.height * 0.8])
            ax.legend(
                loc="upper left",
                handles=self._gt_legend_elements(),
                ncol=8,
                bbox_to_anchor=(0, 1.2),
            )
        self.basecolors = self.line_main.get_facecolor()

        # Hover annotation box, hidden until a point is hovered.
        self.im = OffsetImage(self._thumbnail(0), zoom=2)
        self.xybox = (50.0, 50.0)
        self.ab = AnnotationBbox(
            self.im,
            (0, 0),
            xybox=self.xybox,
            xycoords="data",
            boxcoords="offset points",
            pad=0.3,
            arrowprops=dict(arrowstyle="->"),
        )
        ax.add_artist(self.ab)
        self.ab.set_visible(False)

        save_path = plot_args.get("save_path")
        if save_path is not None:
            plt.savefig(expanduser(save_path), dpi=self.dpi,
                        bbox_inches="tight")
        else:
            canvas = self.fig_main.canvas
            canvas.mpl_connect("motion_notify_event", self.hover_main)
            canvas.mpl_connect("button_press_event", self.click_main)
            canvas.mpl_connect("scroll_event", self.scroll)
            canvas.mpl_connect("key_press_event", self.key_press)
            canvas.mpl_connect("key_release_event", self.key_release)
            plt.show()

    def hover_main(self, event):
        """Action handler for the main plot.

        Shows a thumbnail of the segment crop while a scatter point is
        hovered and hides it otherwise.
        """
        hit, info = self.line_main.contains(event)
        if hit:
            # Several points may be under the cursor; use the first.
            ind = info["ind"][0]
            w, h = self.fig_main.get_size_inches() * self.fig_main.dpi
            # Mirror the box offset when the cursor is in the right / top
            # half of the figure.
            ws = -1 if event.x > w / 2.0 else 1
            hs = -1 if event.y > h / 2.0 else 1
            self.ab.xybox = (self.xybox[0] * ws, self.xybox[1] * hs)
            self.ab.set_visible(True)
            self.ab.xy = (self.x[ind], self.y[ind])
            tmp = self._thumbnail(ind)
            self.im.set_data(tmp)
            tmp.close()
        else:
            self.ab.set_visible(False)
        self.fig_main.canvas.draw_idle()

    def click_main(self, event):
        """Action handler for clicks on the main plot.

        Left click shows the source image with the segment boxed, middle
        click the four-panel detailed image, and right click opens or
        updates the nearest neighbor window. With 't' held, left click
        stores a thumbnail. With 'control' held, left or middle click saves
        the respective image instead of showing it.
        """
        hit, info = self.line_main.contains(event)
        if not hit:
            return
        ind = info["ind"][0]
        key = self.current_pressed_key
        if key == "t" and event.button == 1:
            self.store_thumbnail(ind)
        elif key == "control" and event.button == 1:
            self.show_single_image(ind, save=True)
        elif key == "control" and event.button == 2:
            self.show_full_image(ind, save=True)
        elif event.button == 1:  # left mouse button
            self.show_single_image(ind)
        elif event.button == 2:  # middle mouse button
            self.show_full_image(ind)
        elif event.button == 3:  # right mouse button
            if plt.fignum_exists(2):
                # nearest neighbor window already open: update it
                # (update_nearest_neighbors also recolors the point)
                self.update_nearest_neighbors(ind)
                return
            self._open_nn_figure(ind)
        self.set_color(ind, self.current_color)
        self.flush_colors()

    # ------------------------------------------------------------------
    # Nearest neighbor window
    # ------------------------------------------------------------------

    def _open_nn_figure(self, ind):
        """Creates figure 2 showing the nearest neighbor crops of ind."""
        self.log.info("Loading nearest neighbors...")
        self.nearest_neighbors = self.get_nearest_neighbors(
            ind, metric=self.distance_metrics[self.dm])
        thumbnails = [self._load_crop(i) for i in self.nearest_neighbors]
        rows, columns = self._nn_grid(len(thumbnails))
        self.fig_nn = plt.figure(num=2, dpi=self.dpi)
        self._set_window_title(
            self.fig_nn,
            "{} nearest neighbors to selected image".format(self.n_neighbors))
        for p in range(rows * columns):
            ax = self.fig_nn.add_subplot(rows, columns, p + 1)
            ax.set_axis_off()
            if p < len(thumbnails):
                ax.imshow(np.asarray(thumbnails[p]))
        canvas = self.fig_nn.canvas
        canvas.mpl_connect("button_press_event", self.click_nn)
        canvas.mpl_connect("key_press_event", self.key_press)
        canvas.mpl_connect("key_release_event", self.key_release)
        canvas.mpl_connect("scroll_event", self.scroll)
        self.fig_nn.show()

    def click_nn(self, event):
        """Action handler for the nearest neighbor window.

        Clicking a crop triggers the same actions as clicking the
        corresponding point in the main plot. Right click re-queries the
        neighbors of the clicked crop.
        """
        if event.inaxes not in self.fig_nn.axes:
            return
        pos = self.get_ind_nn(event)
        if pos >= len(self.nearest_neighbors):
            return  # empty grid cell
        target = self.nearest_neighbors[pos]
        key = self.current_pressed_key
        if key == "t" and event.button == 1:
            self.store_thumbnail(target)
        elif key == "control" and event.button == 1:
            self.show_single_image(target, save=True)
        elif key == "control" and event.button == 2:
            self.show_full_image(target, save=True)
        elif event.button == 1:  # left mouse button
            self.show_single_image(target)
        elif event.button == 2:  # middle mouse button
            self.show_full_image(target)
        elif event.button == 3:  # right mouse button
            self.update_nearest_neighbors(target)

    def update_nearest_neighbors(self, ind):
        """Replaces the crops in the open nearest neighbor window with the
        neighbors of segment ind, re-laying out the grid if needed."""
        self.log.info("Loading nearest neighbors...")
        self.nearest_neighbors = self.get_nearest_neighbors(
            ind, metric=self.distance_metrics[self.dm])
        thumbnails = [
            np.asarray(self._load_crop(i)) for i in self.nearest_neighbors
        ]
        rows, columns = self._nn_grid(len(thumbnails))
        if (len(self.fig_nn.axes) != rows * columns
                or self.get_gridsize(self.fig_nn) != (rows, columns)):
            self.rearrange_axes(rows, columns)
        for p, ax in enumerate(self.fig_nn.axes):
            # clear first so old images are not stacked underneath
            ax.clear()
            ax.set_axis_off()
            if p < len(thumbnails):
                ax.imshow(thumbnails[p])
        self.fig_nn.canvas.draw()
        self.set_color(ind, self.current_color)
        self.flush_colors()

    def rearrange_axes(self, nrows, ncols):
        """Re-lays out the nearest neighbor window as an nrows x ncols grid.

        Surplus axes at the end are deleted, the remaining ones move to the
        new grid cells in order, and missing cells get new empty axes.
        """
        fig = self.fig_nn
        n_cells = nrows * ncols
        for surplus in fig.axes[n_cells:]:
            fig.delaxes(surplus)
        grid = fig.add_gridspec(nrows, ncols)
        # set_subplotspec replaces Axes.change_geometry (removed in mpl 3.6).
        for i, ax in enumerate(fig.axes):
            ax.set_subplotspec(grid[i])
            ax.set_position(grid[i].get_position(fig))
        for p in range(len(fig.axes), n_cells):
            fig.add_subplot(grid[p]).set_axis_off()

    @staticmethod
    def get_gridsize(fig):
        """Returns (nrows, ncols) of the grid of the first axes of fig."""
        return fig.axes[0].get_subplotspec().get_gridspec().get_geometry()

    def get_ind_nn(self, event):
        """Returns the row-major grid position of the nearest neighbor
        subplot under event."""
        # num1 is the flat grid index (row * ncols + col); it replaces the
        # rowNum/colNum attributes removed in Matplotlib 3.6.
        return event.inaxes.get_subplotspec().num1

    def get_nearest_neighbors(self, ind, metric="cos"):
        """Returns the indices of the n_neighbors segments closest to ind,
        nearest first, excluding ind itself.

        metric "euclid" uses the L2 distance; anything else uses the cosine
        distance.
        """
        if metric == "euclid":
            dists = self.lp_dist(self.embeddings[ind], self.embeddings, d=2)
        else:
            dists = self.cos_dist(self.embeddings[ind], self.embeddings)
        order = np.argsort(dists, kind="stable")
        # Exclude the query explicitly; with duplicate points it need not
        # sort first.
        return order[order != ind][:self.n_neighbors]

    # ------------------------------------------------------------------
    # Keyboard / scroll handling
    # ------------------------------------------------------------------

    def key_press(self, event):
        """Performs different actions based on pressed keys.

        m: cycle distance metric, #: cycle clustering method, c: cluster and
        color by cluster, g: color by ground truth, h: color by prediction,
        b: reset colors, d: overlay point density. The key is remembered
        until it is released.
        """
        self.log.debug("Key '{}' pressed.".format(event.key))
        key = event.key
        if key == "m":
            self.dm = (self.dm + 1) % len(self.distance_metrics)
            self.log.info("Changed distance metric to {}".format(
                self.distance_metrics[self.dm]))
        elif key == "#":
            self.cme = (self.cme + 1) % len(self.cluster_methods)
            self.log.info("Changed clustering method to {}".format(
                list(self.cluster_methods.keys())[self.cme]))
        elif key == "c":
            method = list(self.cluster_methods.keys())[self.cme]
            self.log.info("Started clustering with {}...".format(method))
            if self.cluster(method=method):
                self._remove_legend()
                n_colors = int(np.max(self.clustering)) + 1
                # plt.get_cmap survives the removal of cm.get_cmap (3.9).
                self.basecolors = plt.get_cmap("viridis",
                                               n_colors)(self.clustering)
                self.flush_colors()
        elif key == "g":
            self.color_gt()
        elif key == "h":
            self.color_pred()
        elif key == "b":
            self.set_color(slice(None), self.standard_color)
            self._remove_legend()
            self.flush_colors()
        elif key == "d":
            self.show_density()
        self.current_pressed_key = event.key

    def key_release(self, event):
        """Clears the variable where the last pressed key is saved."""
        self.current_pressed_key = None
        self.log.debug("Key '{}' released.".format(event.key))

    def scroll(self, event):
        """Increases (scroll up) or decreases (scroll down) the number of
        nearest neighbors; it never drops below 1."""
        if event.button == "up":
            self.n_neighbors += 1
            self.log.info("Increased number of nearest neighbors to {}".format(
                self.n_neighbors))
        elif event.button == "down":
            # 0 neighbors would make the grid computation divide by zero.
            if self.n_neighbors > 1:
                self.n_neighbors -= 1
                self.log.info(
                    "Decreased number of nearest neighbors to {}".format(
                        self.n_neighbors))
            else:
                self.log.info("Number of nearest neighbors is already 1.")

    # ------------------------------------------------------------------
    # Image windows
    # ------------------------------------------------------------------

    def show_single_image(self, ind, save=False):
        """Displays (or, with save=True, saves) the full source image of a
        segment with the segment marked by a red box."""
        self.log.info("{} image...".format("Saving" if save else "Loading"))
        img_box = self.draw_box_on_image(ind)
        fig_tmp = plt.figure(self._next_fignum(), dpi=self.dpi)
        ax = fig_tmp.add_subplot(111)
        ax.set_axis_off()
        ax.imshow(np.asarray(img_box), interpolation="nearest")
        if save:
            self._save_figure(fig_tmp, ax, "image_{}.jpg".format(ind))
        else:
            self._set_window_title(
                fig_tmp, "Dataset: {}, Image index: {}".format(
                    self.data["dataset"][self.gi[ind]],
                    self.data["image_index"][self.gi[ind]],
                ))
            fig_tmp.tight_layout(pad=0.0)
            fig_tmp.show()

    def show_full_image(self, ind, save=False):
        """Displays (or saves) four panels of the image belonging to a
        segment, each with the segment marked by a red box.

        Top left: Entropy heatmap of prediction.
        Top right: Predicted IoU of each segment.
        Bottom left: Source image with ground truth overlay.
        Bottom right: Predicted semantic segmentation.
        """
        self.log.info(
            "{} detailed image...".format("Saving" if save else "Loading"))
        box = self.data["box"][ind]
        with Image.open(self.data["image_path"][self.gi[ind]]) as src:
            image = np.asarray(src.convert("RGB"))
        image_index = self.data["image_index"][self.gi[ind]]
        iou_pred = self.data["iou_pred"][self.gi[ind]]
        dataset = self.data["dataset"][self.gi[ind]]
        model_name = self.data["model_name"][self.gi[ind]]
        pred, gt, _ = probs_gt_load(
            image_index,
            input_dir=join(CONFIG.metaseg_io_path, "input", model_name,
                           dataset),
        )
        components = components_load(
            image_index,
            components_dir=join(CONFIG.metaseg_io_path, "components",
                                model_name, dataset),
        )
        e = entropy(pred)
        pred = pred.argmax(2)
        # One lookup per distinct class instead of per pixel.
        predc = self._colorize(pred, self.pred_mapping).reshape(image.shape)
        overlay_factor = [1.0, 0.5, 1.0]
        if self.label_mapping[dataset] is not None:
            gtc = self._colorize(gt,
                                 self.label_mapping[dataset]).reshape(
                                     image.shape)
        else:
            gtc = np.zeros_like(image)
            overlay_factor[1] = 0.0
        img_predc, img_gtc, img_entropy = [
            Image.fromarray(
                np.uint8(arr * overlay_factor[i] + image *
                         (1 - overlay_factor[i])))
            for i, arr in enumerate(
                [predc, gtc, cm.jet(e)[:, :, :3] * 255.0])
        ]
        img_ioupred = Image.fromarray(
            self.visualize_segments(components, iou_pred))
        images = [img_gtc, img_predc, img_entropy, img_ioupred]

        box_line_width = 5
        left, upper = max(0, box[0] - box_line_width), max(
            0, box[1] - box_line_width)
        right, lower = min(pred.shape[1], box[2] + box_line_width), min(
            pred.shape[0], box[3] + box_line_width)
        for k in images:
            draw = ImageDraw.Draw(k)
            draw.rectangle([left, upper, right, lower],
                           outline=(255, 0, 0),
                           width=box_line_width)
            del draw
        images = [np.asarray(k).astype("uint8") for k in images]
        img_top = np.concatenate(images[2:], axis=1)
        img_bottom = np.concatenate(images[:2], axis=1)
        img_total = np.concatenate((img_top, img_bottom), axis=0)

        fig_tmp = plt.figure(self._next_fignum(), dpi=self.dpi)
        self._set_window_title(
            fig_tmp, "Dataset: {}, Image index: {}".format(dataset,
                                                            image_index))
        ax = fig_tmp.add_subplot(111)
        ax.set_axis_off()
        ax.imshow(img_total, interpolation="nearest")
        if save:
            self._save_figure(fig_tmp, ax,
                              "detailed_image_{}.jpg".format(ind))
        else:
            fig_tmp.tight_layout(pad=0.0)
            fig_tmp.show()

    def store_thumbnail(self, ind):
        """Stores the cropped segment (not the whole image) as a JPEG in
        save_dir, named after its ground truth class and plot position."""
        image = self._load_crop(ind)
        mapping = self.label_mapping[self.data["dataset"][self.gi[ind]]]
        if mapping is None:
            name = "None"
        else:
            # same suffix stripping as the legends
            name = self._short_label(mapping[self.gt[ind]][0])
        name = name.replace(" ", "_")
        path = join(
            self.save_dir,
            "thumbnail_{}_{:0>2.1f}_{:0>2.1f}.jpg".format(
                name, self.x[ind], self.y[ind]),
        )
        image.save(path)
        self.log.debug("Saved thumbnail to {}".format(path))

    def draw_box_on_image(self, ind):
        """Returns the RGB source image of a segment with a red bounding
        box drawn around the segment."""
        box_line_width = 5
        with Image.open(self.data["image_path"][self.gi[ind]]) as src:
            img_box = src.convert("RGB")
        draw = ImageDraw.Draw(img_box)
        left, upper, right, lower = self.data["box"][ind]
        left, upper = max(0, left - box_line_width), max(
            0, upper - box_line_width)
        right, lower = min(img_box.size[0], right + box_line_width), min(
            img_box.size[1], lower + box_line_width)
        draw.rectangle([left, upper, right, lower],
                       outline=(255, 0, 0),
                       width=box_line_width)
        del draw
        return img_box

    # ------------------------------------------------------------------
    # Clustering
    # ------------------------------------------------------------------

    def cluster(self, method="kmeans"):
        """Clusters self.embeddings with the given method, storing the
        labels in self.clustering.

        Returns:
            bool: True if a clustering was computed, False if the number of
            clusters entered by the user was invalid.
        """
        kwargs = dict(self.cluster_methods[method]["kwargs"])
        if method in self.methods_with_ncluster_param:
            try:
                n_clusters = self.n_cluster_prompt()
            except ValueError as err:
                self.log.error(str(err))
                return False
            if n_clusters == "error":
                return False
            if n_clusters == "elbow":
                if method != "kmeans":
                    self.log.error(
                        "'elbow' is only supported for kmeans clustering.")
                    return False
                n_clusters = self.elbow()
            kwargs["n_clusters"] = n_clusters
        self.clustering = self.cluster_methods[method]["main"](
            **kwargs).fit_predict(self.embeddings)
        return True

    def elbow(self):
        """Plots k-means inertia over a user-given range of cluster counts
        and returns the count the user then enters."""
        low = int(input("Enter the minimum number of clusters: "))
        high = int(input("Enter the maximum number of clusters: "))
        km = [KMeans(n_clusters=i) for i in range(low, high + 1)]
        km = [k.fit(self.embeddings) for k in tqdm(km, total=len(km))]
        score = [k.inertia_ for k in km]
        fig_elbow = plt.figure(self._next_fignum())
        ax = fig_elbow.add_subplot(111)
        ax.plot(range(low, high + 1), score)
        fig_elbow.show()
        return int(input("Enter number of clusters: "))

    def n_cluster_prompt(self):
        """Asks for the number of clusters.

        Returns:
            int, "elbow" or "error": the entered count, "elbow" if requested,
            or "error" if the input is not an int.

        Raises:
            ValueError: If the entered count is 1 or less.
        """
        inp = input("Enter the number of clusters: ")
        if inp == "elbow":
            return inp
        try:
            inp = int(inp)
        except ValueError:
            self.log.error(
                "Input should be an int or 'elbow' but received {}!".format(
                    inp))
            return "error"
        if inp <= 1:
            raise ValueError("Number of clusters should be greater than 1!")
        return inp

    # ------------------------------------------------------------------
    # Distances and visualisation helpers
    # ------------------------------------------------------------------

    @staticmethod
    def visualize_segments(comp, metric):
        """Colors each segment of a component map by its metric value.

        Segment id k (> 0) gets the color of metric[k - 1]; negative ids
        become black and id 0 becomes white. Returns a uint8 RGB image of
        shape comp.shape + (3,).
        """
        metric = np.asarray(metric)
        # One RGB row per segment plus two extra rows: black, then white.
        palette = np.stack(
            [
                np.concatenate((1 - 0.5 * metric, [0, 1])),
                np.concatenate((metric, [0, 1])),
                np.concatenate((0.3 + 0.35 * metric, [0, 1])),
            ],
            axis=1,
        )
        n_colors = palette.shape[0]
        # int64 instead of int16: large segment ids / counts overflowed.
        components = np.array(comp, dtype=np.int64)
        components[components < 0] = n_colors - 1
        components[components == 0] = n_colors
        img = palette[components - 1]
        return np.asarray(255 * img).astype("uint8")

    @staticmethod
    def lp_dist(point, all_points, d=2):
        """Calculates the L_p distance from a point to each row of
        all_points. Used for retrieval."""
        # abs() is required for odd p; it is a no-op for p = 2.
        return (np.abs(all_points - point)**d).sum(1)**(1.0 / d)

    @staticmethod
    def cos_dist(point, all_points):
        """Calculates the cosine distance from a point to each row of
        all_points. Used for retrieval."""
        return 1 - ((point * all_points).sum(1) /
                    (norm(point) * norm(all_points, axis=1)))

    # ------------------------------------------------------------------
    # Coloring of the main scatter plot
    # ------------------------------------------------------------------

    def color_gt(self):
        """Colors the scatter in the main plot by ground truth class and
        shows a class legend (only if every segment's dataset has a label
        mapping)."""
        n_points = self.basecolors.shape[0]
        dataset_of = [self.data["dataset"][self.gi[ind]]
                      for ind in range(n_points)]
        if any(self.label_mapping[d] is None for d in dataset_of):
            self.log.warning("Ground truth colors are not available: "
                             "a dataset has no label mapping.")
            return
        self.basecolors = np.stack([
            self._rgba(self.label_mapping[d][self.gt[ind]][1])
            for ind, d in enumerate(dataset_of)
        ])
        self.fig_main.axes[0].legend(
            loc="upper left",
            handles=self._gt_legend_elements(),
            ncol=8,
            bbox_to_anchor=(0, 1.1),
        )
        self.flush_colors()

    def color_pred(self):
        """Colors the scatter in the main plot by predicted class and shows
        a class legend."""
        self.basecolors = np.stack([
            self._rgba(self.pred_mapping[self.pred[ind]][1])
            for ind in range(self.basecolors.shape[0])
        ])
        legend_elements = [
            Patch(
                color=self._rgba(self.pred_mapping[cl][1]),
                label=self.pred_mapping[cl][0],
            ) for cl in np.unique(self.pred).flatten()
        ]
        self.fig_main.axes[0].legend(loc="upper left",
                                     handles=legend_elements,
                                     ncol=8,
                                     bbox_to_anchor=(0, 1.1))
        self.flush_colors()

    def show_density(self):
        """Overlays a kernel density estimate of the plotted points on the
        main plot; densities below the 55% quantile are left white."""
        embedding_kde = estimate_kernel_density(self.data["plot_embeddings"])
        xmin = self.x.min()
        xmin = xmin * 1.3 if xmin < 0 else xmin * 0.8
        xmax = self.x.max()
        xmax = xmax * 1.3 if xmax > 0 else xmax * 0.8
        ymin = self.y.min()
        ymin = ymin * 1.3 if ymin < 0 else ymin * 0.8
        ymax = self.y.max()
        ymax = ymax * 1.3 if ymax > 0 else ymax * 0.8
        grid_x, grid_y = np.mgrid[xmin:xmax, ymin:ymax]
        grid_z = embedding_kde(np.vstack([grid_x.flatten(),
                                          grid_y.flatten()]))
        colmap = plt.get_cmap("Greys")
        colmap = colors.LinearSegmentedColormap.from_list(
            "trunc({n},{a:.2f},{b:.2f})".format(n=colmap.name, a=0.0, b=0.75),
            colmap(np.linspace(0.0, 0.75, 256)),
        )
        # np.NaN was removed in NumPy 2.0.
        grid_z[grid_z < np.quantile(grid_z, 0.55)] = np.nan
        colmap.set_bad("white")
        self.fig_main.axes[0].pcolormesh(
            grid_x,
            grid_y,
            grid_z.reshape(grid_x.shape),
            cmap=colmap,
            shading="gouraud",
            zorder=1,
        )
        self.flush_colors()

    def set_color(self, ind, color):
        """Helper function to set the color of the segment(s) at index ind."""
        self.basecolors[ind, :] = color

    def change_color(self, old_color, new_color):
        """Helper function to change a specific color to a different one."""
        self.basecolors[(self.basecolors == old_color).all(axis=1)] = new_color

    def flush_colors(self):
        """Flushes all color changes onto the main scatter plot."""
        self.line_main.set_color(self.basecolors)
        self.fig_main.canvas.draw()
def visualize(sites=None):
    """Show a world map whose site markers reveal an image on hover.

    Draws a Gall-projection Basemap, plots one red marker per site and shows
    the site's image in an annotation box while the mouse is over the marker.

    Args:
        sites: Optional sequence of ``(lon, lat, image_path, arrowprops)``
            tuples; ``arrowprops`` may be None for a box without an arrow.
            Defaults to the three originally hard-coded sites.
    """
    if sites is None:
        sites = [
            (0, 44, "/Users/rogpaxton/Galvanize/SiteMetrics/plot44.0.png",
             dict(arrowstyle="->")),
            (90, 44, "/Users/rogpaxton/Galvanize/SiteMetrics/plot45.0.png",
             None),
            (120, -34, "/Users/rogpaxton/Galvanize/SiteMetrics/plot46.0.png",
             dict(arrowstyle="->",
                  connectionstyle="angle,angleA=0,angleB=90,rad=3")),
        ]

    fig = plt.figure()
    ax = plt.axes()
    themap = Basemap(projection='gall',
                     llcrnrlon=-180,
                     llcrnrlat=-90,
                     urcrnrlon=180,
                     urcrnrlat=90,
                     resolution='l',
                     area_thresh=100000.0,
                     )
    themap.drawcoastlines()
    themap.drawcountries()
    themap.fillcontinents(color='gainsboro')
    themap.drawmapboundary(fill_color='steelblue')

    points_with_annotation = []
    for lon, lat, image_path, arrowprops in sites:
        # Basemap data coordinates are projected, not degrees: anchor both
        # the marker and its annotation at the projected position.
        map_x, map_y = themap(lon, lat)
        point, = themap.plot(map_x, map_y, 'o', color='Red', markersize=10)
        # plt.imread replaces matplotlib's removed read_png.
        imagebox = OffsetImage(plt.imread(image_path), zoom=0.5)
        annotation = AnnotationBbox(imagebox, (map_x, map_y),
                                    xybox=(50, 150),
                                    xycoords='data',
                                    boxcoords="offset points",
                                    pad=0.5,
                                    arrowprops=arrowprops)
        annotation.set_visible(False)
        # Add each annotation once; re-adding on every mouse move piled up
        # duplicate artists.
        ax.add_artist(annotation)
        points_with_annotation.append((point, annotation))

    def on_move(event):
        """Show the annotation of the hovered marker and hide all others."""
        visibility_changed = False
        for point, annotation in points_with_annotation:
            should_be_visible = bool(point.contains(event)[0])
            if should_be_visible != annotation.get_visible():
                visibility_changed = True
                annotation.set_visible(should_be_visible)
        if visibility_changed:
            fig.canvas.draw_idle()

    fig.canvas.mpl_connect('motion_notify_event', on_move)
    plt.show()
def visualize_clusters(x, y, imgs, colors=None):
    """Create an interactive plot visualizing the clusters.

    Hovering over a point shows the corresponding image (e.g. a face crop)
    in an annotation box next to it.

    Args:
        x, y: Coordinates of the points.
        imgs: Sequence of images; ``imgs[i]`` is shown for point ``i``.
        colors: Optional colors forwarded to ``ax.scatter(c=...)``; may be a
            list or a NumPy array.
    """
    # create figure and plot scatter
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.tick_params(axis='x', which='both', bottom=False, top=False,
                   labelbottom=False)
    ax.tick_params(axis='y', which='both', left=False, right=False,
                   labelleft=False)
    # `if colors:` raised for NumPy arrays (ambiguous truth value).
    if colors is not None and len(colors) > 0:
        # Marker-less line used only for hit-testing; the scatter draws the
        # colored points.
        line, = ax.plot(x, y, ls="")
        ax.scatter(x, y, c=colors)
    else:
        line, = ax.plot(x, y, ls="", marker="o")

    # create the annotation box, hidden until a point is hovered
    im = OffsetImage(imgs[0], zoom=1)
    xybox = (50., 50.)
    ab = AnnotationBbox(im, (0, 0), xybox=xybox, xycoords='data',
                        boxcoords="offset points", pad=0.3,
                        arrowprops=dict(arrowstyle="->"))
    ax.add_artist(ab)
    ab.set_visible(False)

    def hover(event):
        hit, info = line.contains(event)
        if hit:
            # Several points may lie under the cursor; show the first one
            # (unpacking exactly one index raised ValueError).
            ind = info["ind"][0]
            w, h = fig.get_size_inches() * fig.dpi
            # Mirror the box offset when the cursor is in the right / top
            # half of the figure.
            ws = -1 if event.x > w / 2. else 1
            hs = -1 if event.y > h / 2. else 1
            ab.xybox = (xybox[0] * ws, xybox[1] * hs)
            ab.set_visible(True)
            ab.xy = (x[ind], y[ind])
            im.set_data(imgs[ind])
        else:
            ab.set_visible(False)
        fig.canvas.draw_idle()

    # add callback for mouse moves
    fig.canvas.mpl_connect('motion_notify_event', hover)
    plt.show()
def food_map(store_list, map_image_path="/Users/yuchiaching/Desktop/完稿壓縮正確尺寸/地圖+畫框.png"):
    """Plot stores on a background map and show store info on hover.

    Each store is drawn as an orange marker at (lon, lat) over the map image.
    Hovering over a marker shows the store's name, area, types, average
    ranking and price range.

    Parameters:
        store_list: sequence of store objects exposing ``lon``, ``lat``,
            ``name``, ``area``, ``types`` (iterable of strings),
            ``avg_ranking``, ``lowerbound`` and ``upperbound``. The objects
            are not modified.
        map_image_path: path of the background map image. It defaults to the
            previously hard-coded location.

    Returns:
        The string ``"finish"`` after ``plt.show()`` returns.
    """
    stores = list(store_list)

    # x: longitude, y: latitude of every store
    x = np.array([store.lon for store in stores])
    y = np.array([store.lat for store in stores])

    # load the background map image
    img = plt.imread(map_image_path)

    fig, ax = plt.subplots(figsize=(16, 10), dpi=70)
    # Stretch the map image over this lon/lat bounding box so that store
    # coordinates land on it (approach from
    # http://hk.uwenku.com/question/p-qgzudzqa-bc.html).
    ax.imshow(
        img, extent=[121.52209923089963, 121.55634026412044, 25.00897, 25.02854])
    # hide the axes
    plt.axis('off')

    # mark store locations
    line, = ax.plot(x, y, ls="", marker='o', color='#fa4a0c')

    # Build the info text for every store. The joined types go into a local
    # so the caller's `store.types` is left untouched. Overwriting it would
    # make a second call join the characters of the already-joined string.
    message_list = []
    for store in stores:
        types_text = '、'.join(store.types)
        message_list.append('\n'.join([
            store.name,
            '[' + str(store.area) + '] | ' + types_text,
            '★ ' + str(store.avg_ranking) + ' / 5.0 | ' + 'NT$' + str(store.lowerbound) + ' ~ NT$' + str(store.upperbound),
        ]))

    # A single info box is reused for every store; hover() sets its text and
    # anchor. It is created up front, so hover() also works for an empty list.
    message = TextArea('', minimumdescent=False,
                       textprops=dict(fontproperties=prop))
    xybox = (50, 50)
    ab = AnnotationBbox(message, (0, 0), xybox=xybox, xycoords='data',
                        boxcoords="offset points", pad=0.3,
                        arrowprops=dict(arrowstyle="->"))
    ax.add_artist(ab)
    # hidden until the mouse is over a store
    ab.set_visible(False)

    def hover(event):
        """Show the info box for the store under the mouse, or hide it."""
        hit, details = line.contains(event)
        if hit:
            # several stores may lie under the cursor; show the first one
            ind = details["ind"][0]
            # figure size in pixels
            w, h = fig.get_size_inches() * fig.dpi
            # In the right/top half of the figure, flip the box offset so the
            # box stays inside the figure.
            ws = -1 if event.x > w / 2. else 1
            hs = -1 if event.y > h / 2. else 1
            ab.xybox = (xybox[0] * ws, xybox[1] * hs)
            ab.set_visible(True)
            ab.xy = (x[ind], y[ind])
            message.set_text(message_list[ind])
        else:
            ab.set_visible(False)
        fig.canvas.draw_idle()

    fig.canvas.mpl_connect('motion_notify_event', hover)
    plt.show()
    return "finish"
def final_plot(self):
    """Show the SOM lattice colours and, on hover, the best-matching input image.

    The first three components of ``self.nodes_weights`` are linearly
    rescaled to 0..255 and shown with ``imshow``. An invisible line with one
    point per lattice node is used for hit-testing. Hovering over node
    number ``k`` calls ``self.most_like_weight(k // h, k % h)`` (with
    ``h = self.nodes_weights_h``) and shows the returned input image, which
    is loaded from ``self.input_vectors_images_path``, in an annotation box.
    """
    w = self.nodes_weights_w
    h = self.nodes_weights_h

    # fixed image size; also reused as the annotation-box offset so the arrow
    # is not too short
    fixed_h = fixed_w = 120
    images_array = np.empty((len(self.input_vectors), fixed_w, fixed_h, 3))
    # load every input-vector image, resized to the fixed size
    for i in range(len(self.input_vectors)):
        images_array[i] = cv2.resize(
            cv2.imread(self.input_vectors_images_path[i]), (fixed_h, fixed_w))

    # Lattice coordinates flattened row-major: point k is node
    # (row=k // h, col=k % h), drawn at (x=col, y=row).
    # The original author notes the lattice is expected to be square (n x n).
    rows, cols = np.indices((w, h))
    x = cols.reshape(-1, 1).astype(float)
    y = rows.reshape(-1, 1).astype(float)

    fig = plt.figure()
    ax = fig.add_subplot(111)
    # no line style and no marker: the line is invisible, used for hit-testing
    line = Line2D(x, y, ls="", marker="")
    ax.add_line(line)

    def to_uint8(arr):
        """Map arr linearly onto 0..255, with its maximum at 0 and minimum at 255."""
        # a constant array has ptp == 0; divide by 1 instead of producing NaNs
        span = np.ptp(arr) or 1
        return (255 * (arr - np.max(arr)) / -span).astype(np.uint8)

    # only the first three weight components are used as colours;
    # np.invert flips the reversed scale back so larger weights are brighter
    normalized_weights = to_uint8(self.nodes_weights[:, :, :3])
    ax.imshow(np.invert(normalized_weights), origin='upper')

    # create the annotation box, hidden until a node is hovered
    im = OffsetImage(images_array[0, :, :, :], zoom=1)
    xybox = (fixed_h, fixed_w)
    ab = AnnotationBbox(im, (0, 0), xybox=xybox, xycoords='data',
                        boxcoords="offset points", pad=0.3,
                        arrowprops=dict(arrowstyle="->"))
    ax.add_artist(ab)
    ab.set_visible(False)

    def hover(event):
        """Show the image associated with the node under the mouse."""
        hit, details = line.contains(event)
        if hit:
            # first node under the cursor
            ind = int(details["ind"][0])
            # figure size in pixels
            fig_w, fig_h = fig.get_size_inches() * fig.dpi
            # In the right/top half of the figure, flip the box offset so the
            # box stays inside the figure.
            ws = -1 if event.x > fig_w / 2. else 1
            hs = -1 if event.y > fig_h / 2. else 1
            ab.xybox = (xybox[0] * ws, xybox[1] * hs)
            ab.set_visible(True)
            ab.xy = (x[ind], y[ind])
            # lattice position of the hovered node
            row, col = divmod(ind, h)
            image_index = self.most_like_weight(row, col)
            picture = to_uint8(images_array[image_index, :, :])
            # Invert back to the original intensity order, then reorder the
            # channels [2, 1, 0] (OpenCV BGR -> matplotlib RGB).
            im.set_data(np.invert(picture)[:, :, [2, 1, 0]])
        else:
            ab.set_visible(False)
        fig.canvas.draw_idle()

    # add callback for mouse moves
    fig.canvas.mpl_connect('motion_notify_event', hover)
    plt.show()
# Demo: scatter plot of random points; hovering over a point shows its image
# (all-zero 10x10 placeholders here) in an annotation box.
x = np.arange(100)
y = np.random.rand(len(x))
arr = np.zeros((len(x), 10, 10))

# create the figure and draw the scatter plot
fig = plt.figure()
ax = fig.add_subplot(111)
line, = ax.plot(x, y, ls="", marker="o")

# create the annotation box
im = OffsetImage(arr[0, :, :], zoom=5)
xybox = (50., 50.)
ab = AnnotationBbox(im, (0, 0), xybox=xybox, xycoords='data',
                    boxcoords="offset points", pad=0.3,
                    arrowprops=dict(arrowstyle="-"))
# add it to the axes and make it invisible
ax.add_artist(ab)
ab.set_visible(False)


def hover(event):
    """Show the annotation box next to the point under the cursor, else hide it."""
    hit, details = line.contains(event)
    if hit:
        # index of the (first) point under the cursor
        ind = details["ind"][0]
        # figure size in pixels
        w, h = fig.get_size_inches() * fig.dpi
        ws = -1 if event.x > w / 2. else 1
        hs = -1 if event.y > h / 2. else 1
        # if the point is in the top/right half, flip the box offset
        ab.xybox = (xybox[0] * ws, xybox[1] * hs)
        ab.set_visible(True)
        ab.xy = (x[ind], y[ind])
        im.set_data(arr[ind, :, :])
    else:
        ab.set_visible(False)
    # without a redraw, visibility/position changes never reach the screen
    fig.canvas.draw_idle()


# register the handler; without this, hover() is never called
# (call plt.show() afterwards when running as a script)
fig.canvas.mpl_connect('motion_notify_event', hover)
def main(args): base_colors = ["b", "g", "r", "c", "m", "y", "k"] embodiment_names = [ "two_hands_two_fingers", "tongs", "rms", # "ski_gloves", # "quick_grip", "one_hand_two_fingers", "one_hand_five_fingers", # "double_quick_grip", # "crab", # "quick_grasp", ] N = len(embodiment_names) embs, frames, pos_neg_labels = get_embs(args.embs_path, embodiment_names, args.num_traj) embeddings, rgb_frames, labels = [], [], [] for name in embodiment_names: embeddings.extend(embs[name]) rgb_frames.extend(frames[name]) labels.extend(pos_neg_labels[name]) if args.l2_normalize: embeddings = [x / (np.linalg.norm(x) + 1e-7) for x in embeddings] # figure out the min embedding length min_len = 10000000 for emb in embeddings: if len(emb) < min_len: min_len = len(emb) print(f"min seq len: {min_len}") # subsample embedding sequences to make them # the same long since TSNE expects a tensor # of fixed sequence length embs, frames = [], [] for i, (emb, rgb) in enumerate(zip(embeddings, rgb_frames)): idxs = get_idxs(0, len(emb), min_len) embs.append(emb[idxs]) frames.append(rgb[idxs]) embs = np.stack(embs) frames = np.stack(frames) # flatten embeddings (N, 128) num_vids, num_frames, num_feats = embs.shape embs_flat = embs.reshape(-1, num_feats) # dimensionality reduction if args.reducer == "umap": reducer = umap.UMAP(n_components=args.ndims, random_state=0) elif args.reducer == "tsne": reducer = TSNE(n_components=args.ndims, n_jobs=-1, random_state=0) elif args.reducer == "pca": reducer = PCA(n_components=args.ndims, random_state=0) else: raise ValueError(f"{args.reducer} is not a valid reducer.") embs_2d = reducer.fit_transform(embs_flat) # TODO: add variance calculation # subsample data for less cluttered visualization idxs = np.arange(args.num_traj * N) mask = [] for idx in idxs: mask.extend(np.arange(idx * min_len, (idx + 1) * min_len)) mask = np.asarray(mask) images = frames[idxs].reshape(-1, *frames.shape[2:]) embs = embs_2d[mask] # resize frames frames = [] for img in images: img = 
(img - img.min()) / (img.max() - img.min()) im = Image.fromarray((img * 255).astype(np.uint8)) im.thumbnail((IMG_HEIGHT, IMG_WIDTH), Image.ANTIALIAS) frames.append(np.asarray(im)) frames = np.stack(frames) label_names, colors = [], [] for i, name in enumerate(embodiment_names): label_names.extend([name] * args.num_traj) colors.extend([base_colors[i]] * args.num_traj) # create figure fig = plt.figure() if args.ndims == 2: ax = fig.add_subplot(111) else: ax = fig.add_subplot(111, projection="3d") lines, imgz, abz = [], [], [] for i in range(len(idxs)): if args.ndims == 3: line = ax.scatter( embs[i * min_len:(i + 1) * min_len, 0], embs[i * min_len:(i + 1) * min_len, 1], embs[i * min_len:(i + 1) * min_len, 2], marker="o" if labels[i] else "x", s=15, c=colors[i], label=label_names[i], ) else: line = ax.scatter( embs[i * min_len:(i + 1) * min_len, 0], embs[i * min_len:(i + 1) * min_len, 1], marker="o" if labels[i] else "x", s=15, c=colors[i], label=label_names[i], ) lines.append(line) # create annotation box im = OffsetImage(frames[i * min_len], zoom=5) imgz.append(im) xybox = (IMG_HEIGHT, IMG_WIDTH) ab = AnnotationBbox( im, (0, 0), xybox=xybox, xycoords="data", boxcoords="offset points", pad=0.3, arrowprops=dict(arrowstyle="->"), ) abz.append(ab) # add annotation box to axes and make it invisible ax.add_artist(ab) ab.set_visible(False) def hover(event): # if the mouse is over the scatter points for j, line in enumerate(lines): if line.contains(event)[0]: ind = line.contains(event)[1]["ind"][0] w, h = fig.get_size_inches() * fig.dpi ws = (event.x > w / 2.0) * -1 + (event.x <= w / 2.0) hs = (event.y > h / 2.0) * -1 + (event.y <= h / 2.0) abz[j].xybox = (xybox[0] * ws, xybox[1] * hs) abz[j].set_visible(True) abz[j].xy = ( embs[j * min_len + ind, 0], embs[j * min_len + ind, 1], ) imgz[j].set_data(frames[j * min_len + ind]) else: abz[j].set_visible(False) fig.canvas.draw_idle() # add callback for mouse move fig.canvas.mpl_connect("motion_notify_event", hover) 
legend_without_duplicate_labels(ax) emb_names = "_".join(embodiment_names) task_name = args.embs_path.split("/")[-2] name = "{}_{}_{}_{}_{}dim.png".format( task_name, emb_names, args.num_traj, args.reducer, args.ndims, ) plt.savefig(osp.join("tmp/plots/", name), format="png", dpi=300) plt.show()