Ejemplo n.º 1
0
def make_interactive_plot(x, y, im_arr, x_type=None, y_type=None):
    """Show ``y`` against ``x`` as markers with an image pop-up on mouse motion.

    A hidden ``AnnotationBbox`` wrapping an ``OffsetImage`` (initialised with
    ``im_arr[0]``) is attached to the axes, and every mouse-motion event is
    forwarded to ``my_hover`` together with the artists it needs. The call
    blocks in ``plt.show()``.

    Parameters:
        x, y: data plotted as unconnected circular markers
        im_arr: indexable collection of images; element 0 seeds the pop-up
        x_type, y_type: axis labels handed to ``plt.xlabel`` / ``plt.ylabel``
    """
    figure = plt.figure()
    axes = figure.add_subplot(111)
    points, = axes.plot(x, y, ls="", marker="o")
    plt.xlabel(x_type)
    plt.ylabel(y_type)

    # Pop-up box: offset (in points) from the annotated data position.
    box_offset = (50., 50.)
    thumbnail = OffsetImage(im_arr[0], zoom=5)
    popup = AnnotationBbox(
        thumbnail,
        (0, 0),
        xybox=box_offset,
        xycoords='data',
        boxcoords="offset points",
        pad=0.3,
        arrowprops=dict(arrowstyle="->"),
    )
    # Attach the box to the axes but keep it hidden until a hover shows it.
    axes.add_artist(popup)
    popup.set_visible(False)

    def _forward_motion(event):
        # Delegate all hover handling to the externally defined my_hover.
        return my_hover(event, im_arr, points, figure, box_offset, popup,
                        thumbnail, x, y)

    figure.canvas.mpl_connect('motion_notify_event', _forward_motion)
    plt.show()
Ejemplo n.º 2
0
 def __fill_retracement_text_annotations__(self, index: int, ret: float,
                                           value: float):
     """Create a hidden text annotation for one retracement level.

     The label shows ``ret`` (3 decimals) and ``value`` (2 decimals) and is
     anchored at ``(index, value)``. The box is only stored in the
     retracement-text dictionary under key ``ret``; it is not added to any
     axes here.
     """
     text = '{:=1.3f} - {:.2f}'.format(ret, value)
     # `minimumdescent` is intentionally not passed: True was the default on
     # Matplotlib releases that accepted the keyword, and newer releases
     # reject it with a TypeError.
     text_area = TextArea(text, textprops=dict(size=7))
     annotation_box = AnnotationBbox(text_area, (index, value))
     annotation_box.set_visible(False)
     self.__fib_retracement_rectangle_text_dic[ret] = annotation_box
Ejemplo n.º 3
0
def interactive_plot(xdata,ydata,images,supp_values=None,fig=None):
    """
    Displays a plot of xdata, ydata and displays the image corresponding to
    the hovered point when the mouse is over the data.
    Parameters:
        xdata: array, sequence of scalar
        ydata: array, metric values
        images: 3D array (or list/tuple of 2D arrays), of dimensions
        (n_images,width,height), containing the images from which the values
        ydata are calculated.
        supp_values: list of arrays, values of a different metric, each one
        plotted against xdata on the same axes
        fig: existing figure to draw into; a new one is created when None
    Returns:
        (fig, ax): the figure and the axes that were drawn on
    """
    if fig is None:
        fig = plt.figure()
    ax = fig.add_subplot(111)
    line, = ax.plot(xdata, ydata, ls="-", marker="o")
    if supp_values is not None:
        for ys in supp_values:
            ax.plot(xdata, ys)

    # Plain Python sequences do not support the [i, :, :] indexing used below.
    if isinstance(images, (list, tuple)):
        images = np.asarray(images)

    # create the annotations box
    # NOTE(review): the zoom divisor adds shape[0] (the image count, per the
    # (n_images, width, height) layout documented above) to shape[1];
    # shape[1] + shape[2] may have been intended -- confirm before changing.
    im = OffsetImage(images[0, :, :],
                     zoom=200/(images.shape[1]+images.shape[0]))
    xybox = (50., 50.)
    ab = AnnotationBbox(im, (0, 0), xybox=xybox, xycoords='data',
                        boxcoords="offset points", pad=0.3,
                        arrowprops=dict(arrowstyle="->"))
    # add it to the axes and make it invisible
    ax.add_artist(ab)
    ab.set_visible(False)

    def hover(event):
        """Show the image of the hovered point, or hide the box otherwise."""
        try:
            hit, details = line.contains(event)
            if hit:
                # Several markers can overlap under the cursor; use the first
                # one instead of failing on tuple unpacking.
                ind = details["ind"][0]
                # figure size in pixels
                w, h = fig.get_size_inches()*fig.dpi
                # Flip the box towards the figure centre so that it stays
                # inside the canvas when the cursor is in the right/top half.
                ws = -1 if event.x > w/2. else 1
                hs = -1 if event.y > h/2. else 1
                ab.xybox = (xybox[0]*ws, xybox[1]*hs)
                ab.set_visible(True)
                # place it at the position of the hovered point
                ab.xy = (xdata[ind], ydata[ind])
                # set the image corresponding to that point
                im.set_data(images[ind, :, :])
            else:
                # the mouse is not over a data point
                ab.set_visible(False)
        except Exception as e:
            # Best effort: a failing hover must not break the GUI event loop.
            print("Hovering error")
            print(e)
        fig.canvas.draw_idle()

    # add callback for mouse moves
    fig.canvas.mpl_connect('motion_notify_event', hover)
    return fig,ax
Ejemplo n.º 4
0
def plot_embedding_images(embedd, labels, paths, info, savefile):
  """Scatter a reduced embedding, overlay image thumbnails and save the figure.

  Parameters:
      embedd: embedding passed to ``reduce_dimensionality``; the first two
          columns of its result are plotted
      labels: array of per-point labels; each is cast to int and mapped to a
          colour of ``indexcolors`` (wrapping around when there are more
          labels than colours)
      paths: image file paths, one per embedded point, in the same order
      info: mapping whose keys become the legend texts, ordered by value
      savefile: destination handed to ``plt.savefig``
  """
  plt.figure(figsize=(15, 10))
  embedding = reduce_dimensionality(embedd)
  ax = plt.gca()
  ax.set_xlim(-15, 15)
  ax.set_ylim(-10, 10)

  n_colors = len(indexcolors)
  colors = [indexcolors[int(i) % n_colors] for i in labels.squeeze()]
  ax.scatter(embedding[:, 0], embedding[:, 1], s=12, c=colors)

  legend_texts = [x[0] for x in sorted(info.items(), key=lambda kv: kv[1])]
  # Wrap the colour index exactly like the scatter colours above so that a
  # long legend cannot run past the end of the palette.
  patches = [mpatches.Patch(color=indexcolors[i % n_colors], label=label)
             for i, label in enumerate(legend_texts)]
  plt.legend(handles=patches)

  for i, thumb in enumerate(paths):
    img = PILImage.open(thumb)
    # LANCZOS is the filter that the removed ANTIALIAS constant aliased.
    img.thumbnail((24, 18), PILImage.LANCZOS)
    img = OffsetImage(img, zoom=1)
    # Shift the thumbnail slightly so it does not cover its marker.
    ab = AnnotationBbox(img, (embedding[i, 0]+0.3, embedding[i, 1]+0.3),
                        xycoords='data', frameon=False)
    ax.add_artist(ab)
    ab.set_visible(True)

  plt.savefig(savefile)
0
 def __fill_retracement_spikes_text_annotations__(self, ret: str,
                                                  position_in_wave: int):
     """Create the hidden text box describing one component of the wave.

     Nothing is stored when ``position_in_wave`` is beyond the rows of
     ``self.xy`` (unfinished waves do not have that component). Otherwise the
     box is anchored at the component's index and at its value shifted by 1%
     (up for even positions, down for odd ones), and stored in the
     spikes-text dictionary under key ``ret``. The box is not added to any
     axes here.
     """
     if position_in_wave >= self.xy.shape[0]:
         # we don't have this component for unfinished waves
         return
     index = self.xy[position_in_wave, 0]
     value = self.xy[position_in_wave, 1]
     position = self.fib_wave.comp_position_list[position_in_wave]
     is_position_retracement = position_in_wave % 2 == 0
     prefix = 'Retr.' if is_position_retracement else 'Reg.'
     offset = value * 0.01 if is_position_retracement else value * -0.01
     value_adjusted = value + offset
     if position_in_wave == 1:
         # The second component shows the absolute move from the first one
         # instead of a percentage.
         reg_ret_value = round(value - self.xy[position_in_wave - 1, 1], 2)
         text = 'P_{}={:.2f}\n{}: {:.2f}'.format(position, value, prefix,
                                                 reg_ret_value)
     else:
         reg_ret_value = self.fib_wave.comp_reg_ret_percent_list[
             position_in_wave - 1]
         text = 'P_{}={:.2f}\n{}: {:=3.1f}%'.format(position, value, prefix,
                                                    reg_ret_value)
     text_props = dict(color='crimson',
                       backgroundcolor=self.color_bg,
                       size=7)
     # `minimumdescent` is intentionally not passed: True was the default on
     # Matplotlib releases that accepted the keyword, and newer releases
     # reject it with a TypeError.
     text_area = TextArea(text, textprops=text_props)
     annotation_box = AnnotationBbox(text_area, (index, value_adjusted))
     annotation_box.set_visible(False)
     self.__fib_retracement_spikes_text_dic[ret] = annotation_box
Ejemplo n.º 6
0
def annot_and_hover(x_pos, y_pos, fig, ax, bars, slider_val, df):
    '''
    Attach a hidden image annotation to ``ax`` and register a canvas callback
    that shows it on top of the bar under the pointer.

    x_pos : x offset (in points) of the image box from the annotated position
    y_pos : y offset (in points) of the image box from the annotated position
    fig : figure of the plot
    ax : axis of the plot
    bars : bars of the plot
    slider_val : object whose ``val`` is added to the bar centre to pick the image
    df : table whose 'image' column holds the images; entry 0 seeds the box

    '''
    thumbnail = OffsetImage(df['image'][0], zoom=0.05)
    annot = AnnotationBbox(
        thumbnail,
        xy=(0, 0),
        xybox=(x_pos, y_pos),
        xycoords='data',
        boxcoords="offset points",
        pad=0.3,
        arrowprops=dict(arrowstyle="->"),
    )
    ax.add_artist(annot)
    annot.set_visible(False)

    def _point_at(bar):
        # Anchor the annotation at the top centre of the bar and load the
        # image selected by the slider value plus the bar centre.
        center_x = bar.get_x() + bar.get_width() / 2
        print('{}'.format(center_x))
        top_y = bar.get_y() + bar.get_height()
        annot.xy = (center_x, top_y)
        picked = df['image'][int(slider_val.val + center_x)]
        thumbnail.set_data(picked)

    def _on_event(event):
        was_visible = annot.get_visible()
        print('slider value is{}'.format(slider_val.val))
        print('x is {}'.format(event.x))
        if event.inaxes == ax:
            # First bar containing the event, if any.
            target = next((b for b in bars if b.contains(event)[0]), None)
            if target is not None:
                _point_at(target)
                annot.set_visible(True)
                fig.canvas.draw_idle()
                return
        # No bar was hit: hide the annotation if it was showing.
        if was_visible:
            annot.set_visible(False)
            fig.canvas.draw_idle()

    # The callback is registered for mouse-button presses on the figure canvas.
    fig.canvas.mpl_connect("button_press_event", _on_event)
Ejemplo n.º 7
0
def generate_figures(arr, x, y):
    """Plot samples over an ``NN_SIZE`` x ``NN_SIZE`` grid with hover images.

    ``x`` / ``y`` are drawn as unconnected markers on figure number 2. Red
    diamonds mark every grid node and each node is labelled with a running
    cell number. Hovering a sample shows ``arr[index]`` in a pop-up box. The
    call blocks in ``plt.show()``. ``NN_SIZE`` is read from the enclosing
    module.
    """
    fig = plt.figure(2)
    fig.suptitle("Samples association for each neuron")
    ax = fig.add_subplot(111)
    samples, = ax.plot(x, y, ls="", marker="o")

    # Walk the grid in row-major order; the cell number is the flat index.
    for cell in range(NN_SIZE * NN_SIZE):
        row, col = divmod(cell, NN_SIZE)
        ax.plot(row, col, marker="D", ls="", color="red")
        plt.text(col, NN_SIZE - row - 1, cell)

    # Hidden pop-up box, seeded with the first image.
    offset = (50., 50.)
    picture = OffsetImage(arr[0, :, :], zoom=5)
    popup = AnnotationBbox(
        picture,
        (0, 0),
        xybox=offset,
        xycoords='data',
        boxcoords="offset points",
        pad=0.3,
        arrowprops=dict(arrowstyle="->"),
    )
    ax.add_artist(popup)
    popup.set_visible(False)

    def hover(event):
        hit, details = samples.contains(event)
        if not hit:
            # the mouse is not over a sample marker
            popup.set_visible(False)
        else:
            # index of the first sample under the cursor
            idx = np.array(details["ind"][0])
            width, height = fig.get_size_inches() * fig.dpi
            # Point the box towards the figure centre: flip the offset when
            # the cursor is in the right or top half of the figure.
            x_sign = -1 if event.x > width / 2. else 1
            y_sign = -1 if event.y > height / 2. else 1
            popup.xybox = (offset[0] * x_sign, offset[1] * y_sign)
            popup.set_visible(True)
            popup.xy = (x[idx], y[idx])
            picture.set_data(arr[idx, :, :])
        fig.canvas.draw_idle()

    fig.canvas.mpl_connect('motion_notify_event', hover)
    plt.grid(True)
    plt.show()
Ejemplo n.º 8
0
    def createLabelSquare(self):
        """Build a hidden, framed text label and return its two parts.

        The box is anchored at x = 0 in data coordinates / y = 0 in axes
        fraction, and placed at x = 1.02 (data) / y = 1 (axes fraction) with
        its lower-left corner on that point. It is not added to any axes here.

        Returns:
            (TextArea, AnnotationBbox): the text area and the box wrapping it
        """
        try:
            ta = TextArea("Test 1", minimumdescent=False)
        except TypeError:
            # Matplotlib releases that no longer accept the `minimumdescent`
            # keyword reject it with a TypeError; build the plain text area.
            ta = TextArea("Test 1")
        ab = AnnotationBbox(ta,
                            xy=(0, 0),
                            xybox=(1.02, 1),
                            xycoords=("data", "axes fraction"),
                            boxcoords=("data", "axes fraction"),
                            box_alignment=(0, 0),
                            frameon=True)
        ab.set_visible(False)

        return ta, ab
Ejemplo n.º 9
0
 def onpick3(event):
   """Show thumbnails for clicked points, or hide them on any other button.

   With mouse button 1, every scatter point under the event gets a 128x96
   (max) thumbnail of its image placed next to it; the created boxes are
   appended to ``t_list``. Any other button hides all boxes in ``t_list``.
   The canvas is redrawn in both cases.
   """
   if event.button==1:
     _, hit = sc.contains(event)
     for thumb in hit['ind']:
       print(paths[thumb])
       img = PILImage.open(paths[thumb])
       # LANCZOS is the filter that the removed ANTIALIAS constant aliased.
       img.thumbnail((128, 96), PILImage.LANCZOS)
       img = OffsetImage(img, zoom=1)
       ab = AnnotationBbox(img, (embedding[thumb,0]+0.2, embedding[thumb,1]+0.2), xycoords='data', frameon=False)
       ax.add_artist(ab)
       ab.set_visible(True)
       t_list.append(ab)
   else:
     for annot in t_list:
       annot.set_visible(False)
   event.canvas.draw()
Ejemplo n.º 10
0
def add_pie(treeplot, node, values, colors=None, size=16, norm=True,
        xoff=0, yoff=0,
        halign=0.5, valign=0.5,
        xycoords='data', boxcoords=('offset points'), vis=True):
    """
    Draw a pie chart

    Args:
    treeplot: Plot the pie is added to; must provide ``add_artist`` and
      ``figure``.
    node (Node): A single Node object or node label
    values (list): A list of floats.
    colors (list): A list of strings to pull colors from. Optional; when
      empty, one colour per value is drawn from ``_tango``.
    size (float): Diameter of the pie chart
    norm (bool): Whether or not to normalize the values so they
      add up to 360. When False each value is multiplied by 360.
    xoff, yoff (float): X and Y offset. Optional, defaults to 0
    halign, valign (float): Horizontal and vertical alignment within
      box. Optional, defaults to 0.5
    xycoords, boxcoords: Coordinate systems forwarded to ``AnnotationBbox``.
    vis (bool): Initial visibility of the pie.

    Returns:
    The ``AnnotationBbox`` holding the pie.

    """
    x, y = xy(treeplot, node)
    da = DrawingArea(size, size)
    r = size*0.5
    center = (r, r)

    # Degrees per unit of value.
    S = 360.0
    if norm:
        total = sum(values)
        # An empty or all-zero list has nothing to normalise (and no wedge
        # to draw); skip the division instead of raising ZeroDivisionError.
        if total:
            S = 360.0/total

    if not colors:
        c = _tango
        colors = [next(c) for v in values]

    x0 = 0
    for i, v in enumerate(values):
        theta = v*S
        # Zero-sized slices get no wedge but still consume a colour slot.
        if v:
            da.add_artist(Wedge(center, r, x0, x0+theta,
                                fc=colors[i], ec='none'))
        x0 += theta

    box = AnnotationBbox(da, (x, y), pad=0, frameon=False,
                         xybox=(xoff, yoff),
                         xycoords=xycoords,
                         box_alignment=(halign, valign),
                         boxcoords=boxcoords)
    treeplot.add_artist(box)
    box.set_visible(vis)
    treeplot.figure.canvas.draw_idle()
    return box
    def plot_chemical_space(self, embedding, legend_elements, colors=None):
        """Scatter a 2D embedding and show the matching PNG on hover.

        ``embedding[:, 0]`` and ``embedding[:, 1]`` are plotted as a scatter.
        The PNG files found in ``self._im_path`` are sorted by name and the
        hovered point's index selects the file to display; the same index
        into ``self._pdbs`` becomes the axes title. The call blocks in
        ``plt.show()``.

        Parameters:
            embedding: 2D array with at least two columns
            legend_elements: legend handles, or None for no legend
            colors: indices into the current seaborn palette, one per point;
                when None every point gets the value 0
        """
        if colors is None:
            colors = [0] * len(self._pdbs)
        else:
            palette = sns.color_palette()
            colors = [palette[x] for x in colors]

        annotations = self._pdbs
        fig, ax = plt.subplots()
        images_path = self._im_path
        png_names = np.array(
            [name for name in os.listdir(images_path) if name.endswith("png")])
        image_path = sorted(png_names)

        # Text annotation that is kept in sync with the hovered point but is
        # hidden at all times.
        annot1 = ax.annotate("",
                             xy=(5, 5),
                             xytext=(1000, 1000),
                             xycoords="figure pixels",
                             bbox=dict(boxstyle="round", fc="w"),
                             arrowprops=dict(arrowstyle="->"))

        # Image annotation, seeded with the first picture.
        first_picture = plt.imread(os.path.join(images_path, image_path[0]))
        im = OffsetImage(first_picture, zoom=0.6)
        annot = AnnotationBbox(im,
                               xy=(0, 0),
                               xybox=(1.2, 0.5),
                               boxcoords="offset points")
        ax.add_artist(annot)
        annot.set_visible(False)
        annot1.set_visible(False)

        sc = plt.scatter(embedding[:, 0],
                         embedding[:, 1],
                         c=colors,
                         s=200,
                         alpha=1,
                         edgecolors="k")
        if legend_elements is not None:
            ax.legend(handles=legend_elements)
        ax.set_xlabel("PCA 1", fontsize=26, labelpad=25)
        ax.set_ylabel("PCA 2", fontsize=26, labelpad=30)

        def update_annot(ind):
            # Only the first point under the cursor is used.
            first = ind["ind"][0]
            pos = sc.get_offsets()[first]
            annot.xy = pos
            annot1.xy = pos
            label = annotations[int(first)]
            annot1.set_text(label)
            ax.set_title(label)
            picture_file = os.path.join(images_path, image_path[int(first)])
            im.set_data(plt.imread(picture_file, format="png"))

        def hover(event):
            was_visible = annot.get_visible()
            # Events outside the scatter axes leave everything untouched.
            if event.inaxes != ax:
                return
            cont, ind = sc.contains(event)
            if cont:
                update_annot(ind)
                annot.set_visible(True)
                annot1.set_visible(False)
                fig.canvas.draw_idle()
            elif was_visible:
                annot.set_visible(False)
                annot1.set_visible(False)
                fig.canvas.draw_idle()

        fig.canvas.mpl_connect("motion_notify_event", hover)
        plt.show()
def main():
    """Cluster zone feature vectors with k-means and save the fitted model.

    Command-line arguments:
        --data      required; text file with one ';'-separated record per
                    line: scene, image path, zone index, then the feature
                    values. The last field of every line is discarded.
        --clusters  number of clusters (default 2)
        --output    required; base name of the saved model file

    For every record the image is opened, split into 200x200 blocks and the
    block at the zone index is kept for the hover preview. The features are
    clustered with KMeans, projected to 2D with PCA for the scatter plot, and
    the fitted KMeans model is written to
    ``<model_output_folder>/<output>.joblib``. ``model_output_folder`` is
    read from the enclosing module; it is not defined in this function.
    """
    parser = argparse.ArgumentParser(
        description=
        "Classify zones using complexity criteria (unsupervised learning)")

    parser.add_argument('--data',
                        type=str,
                        help='required data file',
                        required=True)
    parser.add_argument('--clusters',
                        type=int,
                        help='number of expected clusters',
                        default=2)
    parser.add_argument('--output',
                        type=str,
                        help='output folder name',
                        required=True)

    args = parser.parse_args()

    p_data = args.data
    p_clusters = args.clusters
    p_output = args.output

    x_values = []
    images = []
    scenes = []

    with open(p_data, 'r') as f:
        # Stream the file instead of loading every line up front.
        for line in f:
            # A blank line (e.g. at the end of the file) carries no record.
            if not line.strip():
                continue

            # The last field of every record is dropped.
            data = line.split(';')[:-1]

            scene = data[0]
            if scene not in scenes:
                scenes.append(scene)

            zone = int(data[2])
            img_arr = segmentation.divide_in_blocks(Image.open(data[1]),
                                                    (200, 200))[zone]
            images.append(np.array(img_arr))

            x_values.append([float(v) for v in data[3:]])

    print(scenes)

    kmeans = KMeans(init='k-means++', n_clusters=p_clusters, n_init=10)
    labels = kmeans.fit(x_values).labels_

    unique, counts = np.unique(labels, return_counts=True)
    print(dict(zip(unique, counts)))

    pca = PCA(n_components=2)
    x_data = pca.fit_transform(x_values)

    # The hover callbacks below are closures over fig, ax, sc and annot.
    fig, ax = plt.subplots()
    fig.set_figheight(20)
    fig.set_figwidth(40)

    ax.tick_params(axis='both', which='major', labelsize=20)

    sc = plt.scatter(x_data[:, 0], x_data[:, 1], c=labels, linewidths=10)

    imagebox = OffsetImage(images[0], zoom=1.2)
    imagebox.image.axes = ax

    annot = AnnotationBbox(imagebox,
                           xy=(0, 0),
                           xybox=(-150., 150.),
                           xycoords='data',
                           boxcoords="offset points",
                           pad=0.8,
                           arrowprops=dict(arrowstyle="->"))

    annot.set_visible(False)
    ax.add_artist(annot)

    def update_annot(ind):
        """Point the annotation at the first hovered sample and swap its image."""
        first = ind["ind"][0]

        new_box = OffsetImage(images[first], zoom=1.2)
        new_box.image.axes = ax

        annot.offsetbox = new_box
        annot.xy = sc.get_offsets()[first]

    def hover(event):
        """Show the zone image for the hovered sample; hide it otherwise."""
        vis = annot.get_visible()
        if event.inaxes == ax:
            cont, ind = sc.contains(event)
            if cont:
                update_annot(ind)
                annot.set_visible(True)
                fig.canvas.draw_idle()
            elif vis:
                annot.set_visible(False)
                fig.canvas.draw_idle()

    # The figure is never shown here (no plt.show()), so the callback only
    # fires if the figure is displayed elsewhere.
    fig.canvas.mpl_connect("motion_notify_event", hover)

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(model_output_folder, exist_ok=True)

    model_path = os.path.join(model_output_folder, p_output + '.joblib')
    joblib.dump(kmeans, model_path)
    # Report success only once the model is actually on disk.
    print('Model saved into {0}'.format(model_path))
Ejemplo n.º 13
0
class ZoomPlot():
    """Interactive Mercator map of points that is redrawn on zoom.

    The coastline resolution follows the zoom level, and clicking a single
    point shows the picture associated with it in a pop-up box.

    ``pnts`` is a mapping with the keys 'days', 'lons', 'lats', 'dirs' and
    'coms'. 'lons' / 'lats' are projected onto the map, and ``dirs[i]`` is
    the image file shown for point ``i``.
    """

    def __init__(self, pnts):
        self.fig = plt.figure(figsize=(15,9))
        self.ax = self.fig.add_subplot(111)

        self.days = pnts['days']
        self.lons = pnts['lons']
        self.lats = pnts['lats']
        self.dirs = pnts['dirs']
        self.coms = pnts['coms']

        # [lower lat, upper lat, lower lon, upper lon]; bnds_strt keeps the
        # initial extent for the "home" reset in onzoom.
        self.bnds = self.bnds_strt = [-58, 80, -180, 180]
        self.resolution = 'c'

        # add callback for mouse clicks
        self.fig.canvas.mpl_connect('button_press_event', self.onclick)

        self.plot_map()

    @staticmethod
    def _remove_artists(artists):
        """Detach every artist in ``artists`` from its axes.

        Iterates over a copy because removing an artist changes the container
        it was listed in. This is used instead of ``container.clear()``,
        which is unavailable where the Axes artist containers are read-only
        views.
        """
        for artist in list(artists):
            artist.remove()

    def plot_map(self):
        """(Re)build the basemap for the current bounds and draw everything."""
        self.map = Basemap(projection='merc',llcrnrlat=self.bnds[0],urcrnrlat=self.bnds[1],
                      llcrnrlon=self.bnds[2],urcrnrlon=self.bnds[3],resolution=self.resolution)

        self.map.drawcoastlines()
        self.map.drawmapboundary(fill_color='cornflowerblue')
        self.map.fillcontinents(color='lightgreen', lake_color='aqua')
        self.map.drawcountries()
        self.map.drawstates()

        self.plot_points()

        self.fig.canvas.draw()

        # Connected last so that the drawing above does not re-trigger onzoom.
        self.zoomcall = self.ax.callbacks.connect('ylim_changed', self.onzoom)

    def onzoom(self, axes):
        """Redraw the map for the new axes limits after a zoom."""
        self._remove_artists(self.ax.patches)
        self._remove_artists(self.ax.collections)
        self.ax.callbacks.disconnect(self.zoomcall)

        x1, y1 = self.map(self.ax.get_xlim()[0], self.ax.get_ylim()[0], inverse = True)
        x2, y2 = self.map(self.ax.get_xlim()[1], self.ax.get_ylim()[1], inverse = True)
        self.bnds = [y1, y2, x1, x2]

        # reset zoom to home (workaround for unidentified error when you press the home button)
        if any(a/b > 1 for a, b in zip(self.bnds, self.bnds_strt)):
            self.bnds = self.bnds_strt # reset map boundaries
            self._remove_artists(self.ax.lines) # reset points
            self.ab.set_visible(False) # hide picture if visible

        # change map resolution based on zoom level
        zoom_set = max(abs(self.bnds[0]-self.bnds[1]),abs(self.bnds[2]-self.bnds[3]))
        if 3 <= zoom_set < 30:
            self.resolution = 'l'   # low
        elif zoom_set < 3:
            self.resolution = 'i'   # intermediate
        else:
            self.resolution = 'c'   # coarse

        # NOTE(review): outside the reset branch the previous point markers
        # and annotation box stay on the axes while plot_points() adds new
        # ones -- confirm whether they should be removed as well.
        self.plot_map()

    def plot_points(self):
        """Project the points onto the current map and create the hidden picture box."""
        self.x, self.y = self.map(self.lons, self.lats)
        self.line, = self.map.plot(self.x, self.y, color='darkmagenta', linestyle='none', marker='o', markeredgecolor='gold')

        # create the annotations box
        self.pic = mpimg.imread('pics\\profpic.png') # just to set up variables, will change later
        self.im = OffsetImage(self.pic)
        self.xybox = (50., 50.)
        self.ab = AnnotationBbox(self.im, (0,0), xybox=self.xybox, xycoords='data',
                boxcoords="offset points",  pad=0.3,  arrowprops=dict(arrowstyle="->"))
        # add it to the axes and make it invisible
        self.ax.add_artist(self.ab)
        self.ab.set_visible(False)

    def onclick(self, event): # if you click on a data point
        """Show the picture of the clicked point, or hide it when nothing is hit."""
        hit, details = self.line.contains(event)
        if hit:
            # find out the index within the array from the event
            try:
                ind, = details["ind"]
            except ValueError:
                # Several points are under the cursor: ask for a closer zoom.
                self.ax.text(0.5, 0.5, 'Please zoom in!', fontsize=24,
                             ha='center', va='center', transform=self.ax.transAxes, weight='bold')
                self.fig.canvas.draw()
                # Blocks the GUI thread briefly so the message stays readable.
                time.sleep(0.7)
                self._remove_artists(self.ax.texts)
            else:
                # get the figure size
                w,h = self.fig.get_size_inches()*self.fig.dpi
                # if event occurs in the top or right half of the figure,
                # flip the annotation box position relative to the mouse.
                ws = -1 if event.x > w/2. else 1
                hs = -1 if event.y > h/2. else 1
                self.ab.xybox = (self.xybox[0]*ws, self.xybox[1]*hs)
                # make annotation box visible
                self.ab.set_visible(True)
                # place it at the position of the clicked point
                self.ab.xy = (self.x[ind], self.y[ind])
                # read the image for that point once and reuse it for sizing
                picture = mpimg.imread(self.dirs[ind])
                self.im.set_data(picture)
                # change zoom of the image to deal with different file sizes
                picsize = max(picture.shape)
                self.im.set_zoom(0.4*(1000/picsize)) # optimum: zoom = 0.4, picsize = 1000
        else:
            #if you didn't click on a data point
            self.ab.set_visible(False)
        self.fig.canvas.draw_idle()
def plot_interpolation(sess,
                       ae,
                       data,
                       first_image_id,
                       second_image_id,
                       file_list,
                       graph_output_name=""):
    """Plot latent codes, a linear interpolation between two of them, and the
    interpolation after repeated decode/encode passes.

    Parameters:
        sess: session object providing ``run(fetches, feed_dict=...)``
            (presumably a TensorFlow session -- confirm with the caller)
        ae: mapping with the entries 'x' (input), 'y' (reconstruction) and
            'z' (latent code) used as fetches / feed keys
        data: input samples fed to ``ae['x']``
        first_image_id, second_image_id: row indices of the two latent codes
            that are interpolated
        file_list: not used by this function
        graph_output_name: when empty the figure is shown interactively with
            hover images; otherwise it is saved under this name

    Notes:
        Images are assumed to be 64x64 (``img_size``), the interpolation uses
        100 steps and the decode/encode loop runs 10 times. Each decoded
        image is min-max normalised before being re-encoded; a constant image
        therefore divides by zero, as in the original implementation.
    """
    y, z = sess.run([ae['y'], ae['z']], feed_dict={ae['x']: np.asarray(data)})
    z = np.squeeze(z)
    y = np.squeeze(y)

    if len(z.shape) == 1:
        # A 1-D latent space is squeezed to a vector; restore the column axis.
        z_size = 1
        z = np.expand_dims(z, axis=1)
    else:
        z_size = z.shape[1]

    img_size = 64
    num_examples = 100
    z_first = z[first_image_id, :]
    z_second = z[second_image_id, :]

    # Linear interpolation from z_first to z_second, one row per step.
    z_list = z_first + np.expand_dims(np.linspace(0, 1, num_examples),
                                      axis=1) * (z_second - z_first)
    x_list_first = sess.run(ae['y'],
                            feed_dict={
                                ae['z']:
                                np.reshape(z_list,
                                           (z_list.shape[0], 1, 1, z_size))
                            })

    n_iters = 10
    z_list_iter = z_list
    x_list = x_list_first
    for _ in range(n_iters):
        # Decode the current latent codes ...
        latent = z_list_iter[:, 0] if z_size == 1 else z_list_iter
        x_list_temp = sess.run(
            ae['y'],
            feed_dict={
                ae['z']:
                np.reshape(latent, (z_list_iter.shape[0], 1, 1, z_size))
            })
        # ... flatten and min-max normalise every decoded image ...
        x_list = np.zeros((x_list_temp.shape[0], img_size * img_size))
        for row in range(x_list.shape[0]):
            x_list[row, :] = np.reshape(x_list_temp[row, :, :],
                                        (1, img_size * img_size))
            x_list[row, :] = x_list[row, :] - np.min(
                x_list[row, :].flatten())
            x_list[row, :] = x_list[row, :] / (np.max(
                x_list[row, :].flatten()))
        # ... and encode them again.
        z_list_iter = sess.run(ae['z'],
                               feed_dict={ae['x']: np.asarray(x_list)})

    if len(z_list_iter.shape) == 1:
        z_list_iter = np.expand_dims(z_list_iter, axis=1)
    else:
        z_list_iter = np.squeeze(z_list_iter)

    # plot result
    fig = plt.figure()
    ax = fig.add_subplot(111)

    # One colour value per group: data, interpolation, iterated interpolation.
    first_colour = 1
    second_colour = 5
    third_colour = 10
    colour_list = np.vstack((first_colour*np.ones((z.shape[0], 1)),
                             second_colour*np.ones((z_list.shape[0], 1)),
                             third_colour*np.ones((z_list_iter.shape[0], 1))))

    plot_data = np.vstack(
        (z[:, 0:z_size], z_list[:, 0:z_size], z_list_iter[:, 0:z_size]))
    img_data = np.vstack(
        (y, np.reshape(x_list_first, (x_list.shape[0], img_size, img_size)),
         np.reshape(x_list, (x_list.shape[0], img_size, img_size))))
    if z_size == 1:
        # Duplicate the single latent axis so there is a second plot axis.
        plot_data = np.hstack((plot_data, plot_data))

    line = ax.scatter(plot_data[:, 0], plot_data[:, 1], s=10,
                      c=colour_list)
    if graph_output_name == "":
        # create the annotations box
        im = OffsetImage(img_data[0, :, :], zoom=5)
        xybox = (50., 50.)
        ab = AnnotationBbox(im, (0, 0),
                            xybox=xybox,
                            xycoords='data',
                            boxcoords="offset points",
                            pad=0.3,
                            arrowprops=dict(arrowstyle="->"))
        # add it to the axes and make it invisible
        ax.add_artist(ab)
        ab.set_visible(False)

        def hover(event):
            """Show the image of the first point under the cursor."""
            hit, details = line.contains(event)
            if hit:
                first = details["ind"][0]
                # get the figure size
                w, h = fig.get_size_inches() * fig.dpi
                # if event occurs in the top or right half of the figure,
                # flip the annotation box position relative to the mouse.
                ws = -1 if event.x > w / 2. else 1
                hs = -1 if event.y > h / 2. else 1
                ab.xybox = (xybox[0] * ws, xybox[1] * hs)
                # make annotation box visible
                ab.set_visible(True)
                # place it at the position of the hovered scatter point
                ab.xy = (plot_data[first, 0], plot_data[first, 1])
                # set the image corresponding to that point
                im.set_data(img_data[first, :, :])
            else:
                # the mouse is not over a scatter point
                ab.set_visible(False)
            fig.canvas.draw_idle()

        # add callback for mouse moves
        fig.canvas.mpl_connect('motion_notify_event', hover)
        plt.show()
    else:
        plt.savefig(graph_output_name)
Ejemplo n.º 15
0
    def DrawtSNE(self):
        """Draw the 2D and 3D t-SNE scatter plots and embed them in the Qt layouts.

        Reads the module-level ``groups``, ``groups3D``, ``labels_dict``,
        ``labels``, ``test_data``, ``df`` and ``Y_3D`` and rebinds the
        module-level ``test_annotation`` to ``test_data`` reshaped to
        ``(-1, 190, 190)``.  Both figures get a mouse-motion handler that shows
        the image belonging to the hovered (2D) or nearest (3D) point.

        Note: the 2D image lookup computes
        ``index_in_group + group_number * (len(test_data) // len(labels))``,
        i.e. it assumes every label group holds the same number of samples and
        that the groups are stored back to back — TODO confirm with the caller.
        """
        # Only ``test_annotation`` is rebound here; every other module-level
        # name is merely read, so it needs no ``global`` declaration.
        global test_annotation
        from matplotlib.offsetbox import OffsetImage, AnnotationBbox

        print("DrawtSNE Start ")
        test_annotation = test_data.reshape(-1, 190, 190)

        def swap_canvas(layout, canvas):
            """Replace the single widget currently in ``layout`` by ``canvas``."""
            if int(layout.count()) == 1:
                layout.itemAt(0).widget().deleteLater()
            layout.addWidget(canvas)

        def make_preview_box(axes, offset):
            """Create a hidden image annotation box on ``axes``."""
            image = OffsetImage(test_annotation[0, :, :], zoom=0.25,
                                cmap='gray')
            box = AnnotationBbox(image, (10, 10),
                                 xybox=offset,
                                 xycoords='data',
                                 boxcoords="offset points",
                                 pad=0.3,
                                 arrowprops=dict(arrowstyle="->"))
            axes.add_artist(box)
            box.set_visible(False)
            return image, box

        # ------------------------------------------------------------ 2D tSNE
        plt.cla()
        fig, ax = plt.subplots()
        ax.margins(0.05)  # Optional, just adds 5% padding to the autoscaling
        points = []
        for label, group in groups:
            point, = ax.plot(group.x,
                             group.y,
                             marker='o',
                             linestyle='',
                             ms=5,
                             label=labels_dict[label],
                             alpha=0.5)
            points.append(point)
        ax.set_title('t-SNE Scattering Plot')
        ax.legend()
        cavans2D = FigureCanvas(fig)

        xybox = (10., 10.)
        im, ab = make_preview_box(ax, xybox)

        tsneprelabel = int(len(test_data) / len(labels))

        def hover(event):
            """Show the preview of the 2D point under the cursor, if any."""
            hit = False
            for group_idx, point in enumerate(points):
                # Query the artist once; it returns (hit?, {"ind": indices}).
                contained, details = point.contains(event)
                if not contained:
                    continue
                hit = True
                image_index = details["ind"][0] + group_idx * tsneprelabel
                # Flip the box towards the figure centre so it stays visible
                # when the cursor is in the top or right half of the figure.
                w, h = fig.get_size_inches() * fig.dpi
                ws = -1 if event.x > w / 2. else 1
                hs = -1 if event.y > h / 2. else 1
                ab.xybox = (xybox[0] * ws, xybox[1] * hs)
                # place it at the position of the hovered scatter point
                ab.xy = (df['x'][image_index], df['y'][image_index])
                # set the image corresponding to that point
                im.set_data(test_annotation[image_index, :, :])
            ab.set_visible(hit)
            fig.canvas.draw_idle()

        fig.canvas.mpl_connect('motion_notify_event', hover)
        swap_canvas(self.tSNE_Layout, cavans2D)

        print("tSNE 2D Finished")

        # ------------------------------------------------------------ 3D tSNE
        fig_3D = plt.figure()
        cavans3D = FigureCanvas(fig_3D)
        # ``Axes3D(fig)`` no longer adds itself to the figure on current
        # Matplotlib; ``add_axes`` with the same full-figure rectangle works on
        # old and new releases alike.
        ax_3D = fig_3D.add_axes([0.0, 0.0, 1.0, 1.0], projection='3d')
        ax_3D.margins(0.05)  # Optional, just adds 5% padding to the autoscaling
        for label, group in groups3D:
            ax_3D.scatter(group.x,
                          group.y,
                          group.z,
                          marker='o',
                          label=labels_dict[label],
                          alpha=0.8)
        ax_3D.legend()
        ax_3D.patch.set_visible(False)
        ax_3D.set_axis_off()
        ax_3D._axis3don = False

        im_3D, ab_3D = make_preview_box(ax_3D, (10., 10.))

        def onMouseMotion(event):
            """Attach the preview image to the 3D point nearest to the cursor."""
            if Y_3D.shape[0] == 0:
                return
            # Project all points at once: 3D data -> 2D axes data -> pixels.
            xs, ys, _ = proj3d.proj_transform(Y_3D[:, 0], Y_3D[:, 1],
                                              Y_3D[:, 2], ax_3D.get_proj())
            screen = ax_3D.transData.transform(np.column_stack((xs, ys)))
            closestIndex = np.argmin(
                np.hypot(screen[:, 0] - event.x, screen[:, 1] - event.y))
            ab_3D.xy = (xs[closestIndex], ys[closestIndex])
            im_3D.set_data(test_annotation[closestIndex, :, :])
            ab_3D.set_visible(True)
            fig_3D.canvas.draw_idle()

        fig_3D.canvas.mpl_connect('motion_notify_event', onMouseMotion)
        swap_canvas(self.tSNE3D_Layout, cavans3D)

        self.movie.stop()
        self.movie.jumpToFrame(0)
        self.label_7.setText('Finished')
        self.label_7.setStyleSheet("color: rgb(70, 70, 70);\n"
                                   "font: 75 14pt \"MS Shell Dlg 2\";")
Ejemplo n.º 16
0
def visualize2DData(X, values_to_show, fig, ax, image_path, centroids,
                    colors_scatter, datasets):
    """Visualize data in a 2D scatter plot with an image popover on hover.

    Args:
        X (np.array) - array of points; columns 0 and 1 are plotted
        values_to_show - per-point integer shown in the text label
            ("Value %d"), indexed like X
        fig - the figure to plot
        ax - the axis everything is drawn on
        image_path - image file paths, indexed like X
        centroids (np.array) - cluster centroids; columns 0 and 1 are plotted
        colors_scatter - the colors to be used in the scatter plot
        datasets - the dataset names used as legend labels
    Returns:
        None
    """

    cmap = plt.cm.RdYlGn
    # Draw on the axes that was handed in (not on pyplot's "current" axes) so
    # the scatter, the legend and the annotation box always share one axes.
    line = ax.scatter(X[:, 0], X[:, 1], s=10, cmap=cmap, c=colors_scatter)
    ax.legend(handles=line.legend_elements()[0], labels=datasets)

    # create the annotations box
    image = plt.imread(image_path[0])
    im = OffsetImage(image, zoom=0.1)
    xybox = (50., 50.)
    ab = AnnotationBbox(im, (0, 0),
                        xybox=xybox,
                        xycoords='data',
                        boxcoords="offset points",
                        pad=0.3,
                        arrowprops=dict(arrowstyle="->"))

    # add it to the axes and make it invisible
    ax.add_artist(ab)
    ab.set_visible(False)

    ax.scatter(centroids[:, 0],
               centroids[:, 1],
               marker='x',
               s=169,
               linewidths=3,
               color='w',
               zorder=10)

    # Text label currently shown (None until the first point is hovered).
    state = {'label': None}

    def calcClosestDatapoint(X, event):
        """Return the index of the point closest to the mouse, in screen space.

        Args:
            X (np.array) - array of points; columns 0 and 1 are used
            event (MouseEvent) - mouse event (pixel position in .x / .y)
        Returns:
            index (int) - index into X of the point nearest to the cursor
        """
        # Compare in display (pixel) coordinates so that differently scaled
        # x and y axes do not distort which point counts as "closest".
        screen = ax.transData.transform(np.asarray(X)[:, 0:2])
        return np.argmin(
            np.hypot(screen[:, 0] - event.x, screen[:, 1] - event.y))

    def annotatePlot(X, index):
        """Show the text label and the image popover for one point.

        Args:
            X (np.array) - array of points; columns 0 and 1 are used
            index (int) - index (into X) of the point to annotate
        Returns:
            None
        """
        # If we have previously displayed another label, remove it first
        if state['label'] is not None:
            state['label'].remove()
        x2 = X[index, 0]
        y2 = X[index, 1]

        state['label'] = ax.annotate("Value %d" % values_to_show[index],
                                     xy=(x2, y2),
                                     xytext=(-20, -40),
                                     textcoords='offset points',
                                     ha='right',
                                     va='bottom',
                                     bbox=dict(boxstyle='round,pad=0.5',
                                               fc='yellow',
                                               alpha=0.5),
                                     arrowprops=dict(
                                         arrowstyle='->',
                                         connectionstyle='arc3,rad=0'))

        # The figure size is in pixels, so the point has to be converted to
        # pixels as well before deciding which half of the figure it is in.
        w, h = fig.get_size_inches() * fig.dpi
        px, py = ax.transData.transform((x2, y2))
        ws = -1 if px > w / 2. else 1
        hs = -1 if py > h / 2. else 1
        # if the point lies in the top or right half of the figure, flip the
        # popover towards the centre so it stays inside the figure.
        ab.xybox = (xybox[0] * ws, xybox[1] * hs)
        # make annotation box visible
        ab.set_visible(True)
        # place it at the position of the hovered scatter point
        ab.xy = (x2, y2)
        # set the image corresponding to that point
        im.set_data(plt.imread(image_path[index]))

        fig.canvas.draw_idle()

    def onMouseMotion(event):
        """Mouse-move handler: annotate the hovered point or hide the popover."""
        if line.contains(event)[0]:
            closestIndex = calcClosestDatapoint(X, event)
            annotatePlot(X, closestIndex)
        elif ab.get_visible():
            # the mouse left the scatter points: hide the image and redraw,
            # otherwise the stale popover would stay on screen
            ab.set_visible(False)
            fig.canvas.draw_idle()

    fig.canvas.mpl_connect('motion_notify_event',
                           onMouseMotion)  # on mouse motion
Ejemplo n.º 17
0
def visualize3DData(X, fig, ax, image_path, C, colors_scatter, datasets):
    """Visualize data in a 3D scatter plot with an image popover near the mouse.

    On every mouse move the point whose screen projection is closest to the
    cursor gets a "Value <index>" text label and an image popover.

    Args:
        X (np.array) - array of points, of shape (numPoints, 3)
        fig - the figure to plot
        ax - the 3D axis everything is drawn on
        image_path - image file paths, indexed like X
        C (np.array) - second point set drawn as large '*' markers; columns
            0-2 are used (presumably cluster centres — confirm with caller)
        colors_scatter - the colors to be used in the scatter plot
        datasets - the dataset names used as legend labels
    Returns:
        None
    """
    scatter = ax.scatter(X[:, 0],
                         X[:, 1],
                         X[:, 2],
                         depthshade=False,
                         picker=True,
                         c=colors_scatter)
    ax.scatter(C[:, 0], C[:, 1], C[:, 2], marker='*', c='#050505', s=1000)

    # create the annotations box
    image = plt.imread(image_path[0])
    im = OffsetImage(image, zoom=0.1)
    xybox = (50., 50.)
    ab = AnnotationBbox(im, (0, 0),
                        xybox=xybox,
                        xycoords='data',
                        boxcoords="offset points",
                        pad=0.3,
                        arrowprops=dict(arrowstyle="->"))
    # add it to the axes and make it invisible
    ax.add_artist(ab)
    ab.set_visible(False)

    # Text label currently shown (None until the first mouse move).
    state = {'label': None}

    def project(xs, ys, zs):
        """Project 3D data coordinates onto the 2D data space of ``ax``."""
        x2, y2, _ = proj3d.proj_transform(xs, ys, zs, ax.get_proj())
        return x2, y2

    def calcClosestDatapoint(X, event):
        """Return the index of the point closest to the mouse, in screen space.

        Args:
            X (np.array) - array of points, of shape (numPoints, 3)
            event (MouseEvent) - mouse event (pixel position in .x / .y)
        Returns:
            index (int) - index into X of the point nearest to the cursor
        """
        # Project every point in one call, then convert to pixel coordinates.
        x2, y2 = project(X[:, 0], X[:, 1], X[:, 2])
        screen = ax.transData.transform(np.column_stack((x2, y2)))
        return np.argmin(
            np.hypot(screen[:, 0] - event.x, screen[:, 1] - event.y))

    def annotatePlot(X, index):
        """Show the text label and the image popover for one point.

        Args:
            X (np.array) - array of points, of shape (numPoints, 3)
            index (int) - index (into X) of the point to annotate
        Returns:
            None
        """
        # If we have previously displayed another label, remove it first
        if state['label'] is not None:
            state['label'].remove()
        x2, y2 = project(X[index, 0], X[index, 1], X[index, 2])
        state['label'] = ax.annotate("Value %d" % index,
                                     xy=(x2, y2),
                                     xytext=(-20, 20),
                                     textcoords='offset points',
                                     ha='right',
                                     va='bottom',
                                     bbox=dict(boxstyle='round,pad=0.5',
                                               fc='yellow',
                                               alpha=0.5),
                                     arrowprops=dict(
                                         arrowstyle='->',
                                         connectionstyle='arc3,rad=0'))

        # The figure size is in pixels, so compare it with the pixel position
        # of the projected point (not with its raw data coordinates).
        w, h = fig.get_size_inches() * fig.dpi
        px, py = ax.transData.transform((x2, y2))
        ws = -1 if px > w / 2. else 1
        hs = -1 if py > h / 2. else 1
        # if the point lies in the top or right half of the figure, flip the
        # popover towards the centre so it stays inside the figure.
        ab.xybox = (xybox[0] * ws, xybox[1] * hs)
        # make annotation box visible
        ab.set_visible(True)
        # place it at the position of the annotated scatter point
        ab.xy = (x2, y2)
        # set the image corresponding to that point
        im.set_data(plt.imread(image_path[index]))

        fig.canvas.draw_idle()

    def onMouseMotion(event):
        """Mouse-move handler: annotate the data point closest to the mouse."""
        if X.shape[0] == 0:
            return  # nothing to annotate
        closestIndex = calcClosestDatapoint(X, event)
        annotatePlot(X, closestIndex)

    fig.canvas.mpl_connect('motion_notify_event',
                           onMouseMotion)  # on mouse motion
    ax.legend(handles=scatter.legend_elements()[0], labels=datasets, loc=2)
Ejemplo n.º 18
0
class PolygonHover:
    """Mouse-move handler that labels and highlights the polygon under the cursor.

    ``polygon_info`` is an iterable of dict-like records; ``hover`` reads the
    keys ``'co-ordinates'`` (iterable of ``(x, y)`` pairs), ``'slice'``,
    ``'id'``, ``'tag'`` and ``'lines'`` from each of them.  ``parent`` must
    provide ``reset_polygon_cols(flag)``, ``selected_polygon`` and
    ``show_selected_plots(items, artists, colour)``.
    """

    def __init__(
        self, parent, polygon_info, precision, figure, axis, selected_colour
    ):  #takes in the polygon info, & then precision, figure, axis values
        self.parent = parent
        self.polygon_info = polygon_info
        self.precision = precision
        self.fig = figure
        self.ax = axis
        self.selected_polygon_colour = selected_colour
        # Label artists owned by this object (created lazily in ``hover``).
        self.offsetbox = None
        self.ab = None

    def hover(self, event):
        """Handle a ``motion_notify_event``.

        If the cursor is within ``precision`` (in data units, per axis) of a
        polygon vertex, that polygon's lines are coloured blue and an info
        label is shown next to the vertex; when several vertices match, the
        last one found gets the label.  Otherwise the label is removed and the
        parent is asked to restore the polygon colours.
        """
        #ImportanceOfBeingErnest (2017) Possible to make labels appear when hovering over a point in matplotlib? [Online]. Available at: https://stackoverflow.com/questions/7908636/possible-to-make-labels-appear-when-hovering-over-a-point-in-matplotlib [Accessed 03 August 2020].
        hit = None
        if event.xdata is not None:  #None if the cursor is not on the axes.
            for polygon in self.polygon_info:
                for x, y in polygon['co-ordinates']:
                    if (np.abs(x - event.xdata) < self.precision
                            and np.abs(y - event.ydata) < self.precision):
                        hit = (polygon, x, y)
                        self._highlight_lines(polygon)

        # The label is rebuilt on every event, so always drop the old one.
        self._remove_label()
        if hit is None:
            self._restore_colours()
        else:
            self._show_label(event, *hit)

        self.fig.canvas.draw_idle()

    def _highlight_lines(self, polygon):
        """Colour the axes lines that belong to ``polygon`` blue."""
        for line in polygon['lines']:
            for ax_line in self.ax.lines:
                if line == ax_line:
                    ax_line.set_color("blue")

    def _remove_label(self):
        """Remove the hover label created by this object, if there is one.

        Only the artist this class added is removed; ``Axes.artists`` is a
        read-only view on current Matplotlib and may hold unrelated artists.
        """
        box = getattr(self, 'ab', None)
        if box is None:
            return
        try:
            box.remove()
        except ValueError:
            # The axes was cleared elsewhere; the box is already gone.
            pass
        self.ab = None

    def _show_label(self, event, polygon, x, y):
        """Create the info label for vertex ``(x, y)`` of ``polygon``."""
        string = "Slice: {} \nId: {} \nTag: {} \nx: {} \ny: {}".format(
            polygon['slice'], polygon['id'], polygon['tag'],
            str(x)[:6],
            str(y)[:6])  #as coordinates can be quite long

        try:
            self.offsetbox = TextArea(string, minimumdescent=False)
        except TypeError:
            # Newer Matplotlib releases no longer accept ``minimumdescent``.
            self.offsetbox = TextArea(string)

        # If the cursor is in the top or right half of the figure, flip the
        # label towards the centre so it stays inside the figure.
        w, h = self.fig.get_size_inches() * self.fig.dpi
        ws = -1 if event.x > w / 2. else 1
        hs = -1 if event.y > h / 2. else 1
        self.ab = AnnotationBbox(self.offsetbox, (x, y),
                                 xybox=(50 * ws, 50 * hs),
                                 xycoords='data',
                                 boxcoords="offset points",
                                 pad=0.5)
        self.ax.add_artist(self.ab)  #add box to axis
        self.ab.set_visible(True)

    def _restore_colours(self):
        """Reset polygon colours, keeping the parent's selected polygon marked."""
        self.parent.reset_polygon_cols(
            False)  #but not selected polygon if one
        selected = self.parent.selected_polygon
        if selected is not None:
            self.parent.show_selected_plots(
                selected['scatter_points'],
                self.ax.collections, self.selected_polygon_colour)
            self.parent.show_selected_plots(
                selected['lines'], self.ax.lines,
                self.selected_polygon_colour)

    #If selected polygon colour needs changing by settings change
    def change_selected_colour(self, new_col):
        """Set the colour passed to the parent for the selected polygon."""
        self.selected_polygon_colour = new_col

    #If precision value needs to be updated by settings change
    def update_precision_value(self, new_precision_val):
        """Set the per-axis hover tolerance (data units)."""
        self.precision = new_precision_val

    def update_polygon_info(self, new_polygon_info):
        """Replace the polygon records searched by ``hover``."""
        self.polygon_info = new_polygon_info
Ejemplo n.º 19
0
class Discovery(object):
    def __init__(self,
                 embeddings_file='embeddings.p',
                 distance_metric='euclid',
                 method='TSNE',
                 embedding_size=2,
                 overwrite_embeddings=False,
                 n_jobs=10,
                 dpi=300,
                 main_plot_args=None,
                 tsne_args=None,
                 save_dir=join(CONFIG.metaseg_io_path, 'vis_embeddings')):
        """Loads the embedding file, computes the dimensionality reductions and calls the initialization
        of the main plot.

        Args:
            embeddings_file (str): Path to the file where all data of segments including feature embeddings is saved.
            distance_metric (str): Distance metric to use for nearest neighbor computation. Unknown values fall back
                to the first entry ('euclid').
            method (str): Method to use for dimensionality reduction of nearest neighbor embeddings ('TSNE', anything
                else selects Isomap). For plotting the points are always reduced in dimensionality using PCA to at
                most 50 dimensions followed by t-SNE to two dimensions.
            embedding_size (int or None): Dimensionality of the feature embeddings used for nearest neighbor search.
                None (or a value not smaller than the data dimensionality) keeps the unreduced embeddings.
            overwrite_embeddings (bool): If True, precomputed nearest neighbor and plotting embeddings from previous
                runs are overwritten with freshly computed ones. Otherwise precomputed embeddings are used if requested
                embedding_size is matching.
            n_jobs (int): Number of processes to use for t-SNE computation.
            dpi (int): Dots per inch for graphics that are saved to disk.
            main_plot_args (dict or None): Keyword arguments for the creation of the main plot.
            tsne_args (dict or None): Keyword arguments for the t-SNE algorithm.
            save_dir (str): Path to the directory where saved images are placed in.
        """
        # Fresh dicts per call: a literal ``{}`` default would be one object
        # shared by every instance.
        main_plot_args = {} if main_plot_args is None else main_plot_args
        tsne_args = {} if tsne_args is None else tsne_args

        self.log = logging.getLogger('Discovery')
        self.embeddings_file = embeddings_file
        self.distance_metrics = ['euclid', 'cos']
        self.dm = (self.distance_metrics.index(distance_metric)
                   if distance_metric in self.distance_metrics else 0)

        self.dpi = dpi
        self.save_dir = save_dir
        os.makedirs(self.save_dir, exist_ok=True)

        self.cluster_methods = OrderedDict()
        self.cluster_methods['kmeans'] = {'main': KMeans, 'kwargs': {}}
        # self.cluster_methods['spectral'] = {'main': SpectralClustering, 'kwargs': {}}
        self.cluster_methods['agglo'] = {'main': AgglomerativeClustering, 'kwargs': {'linkage': 'ward'}}

        self.methods_with_ncluster_param = ['kmeans', 'spectral', 'agglo']
        self.cme = 0
        self.clustering = None
        self.n_clusters = 25

        # colors:
        self.standard_color = (0, 0, 1, 1)
        self.current_color = (1, 0, 0, 1)
        self.nn_color = (1, 0, 1, 1)

        self.log.info('Loading data...')
        # NOTE: unpickling can execute arbitrary code - only open embedding
        # files that come from a trusted source.
        with open(self.embeddings_file, 'rb') as f:
            self.data = pkl.load(f)

        self.iou_preds = self.data['iou_pred']
        self.gt = np.array(self.data['gt']).flatten()
        self.pred = np.array(self.data['pred']).flatten()
        self.gi = self.data['image_level_index']  # global indices (on image level and not on component level)

        self.log.info('Loaded {} segment embeddings.'.format(self.pred.shape[0]))

        self.nearest_neighbors = None

        if len(self.data['embeddings']) == 1:
            # A single sample cannot be reduced; use its first two components.
            self.data['plot_embeddings'] = np.array([self.data['embeddings'][0][0],
                                                     self.data['embeddings'][0][1]]).reshape((1, 2))
            self.data['nn_embeddings'] = self.data['plot_embeddings']
            self.embeddings = self.data['nn_embeddings']
        else:
            n_samples = len(self.data['embeddings'])
            data_dim = self.data['embeddings'][0].shape[0]
            pca_cache = []  # holds the PCA result once it has been computed

            def raw_embeddings():
                """Stack the stored embeddings into a (n_samples, data_dim) array."""
                return np.stack(self.data['embeddings']).reshape((-1, data_dim))

            def pca_embeddings():
                """Return the PCA-reduced embeddings, computing them on first use."""
                if not pca_cache:
                    self.log.info('Computing PCA...')
                    pca_cache.append(
                        PCA(n_components=min(50, n_samples, data_dim)).fit_transform(raw_embeddings()))
                return pca_cache[0]

            # True as soon as self.data received anything worth persisting.
            data_changed = False

            if 'plot_embeddings' not in self.data or overwrite_embeddings:
                self.log.info('Computing t-SNE for plotting')
                self.data['plot_embeddings'] = TSNE(n_components=2, **tsne_args).fit_transform(pca_embeddings())
                new_plot_embeddings = True
                data_changed = True
            else:
                new_plot_embeddings = False

            # ``None`` has to be tested first: comparing None with an int raises.
            if embedding_size is None or embedding_size >= data_dim:
                self.embeddings = raw_embeddings()
                self.log.debug(
                    ('Requested embedding size of {} was greater or equal to data dimensionality of {}. '
                     'Data has thus not been reduced in dimensionality.').format(
                        embedding_size,
                        data_dim))
            elif ('nn_embeddings' in self.data
                  and self.data['nn_embeddings'].shape[1] == embedding_size
                  and not overwrite_embeddings):
                self.embeddings = self.data['nn_embeddings']
                self.log.info(('Loaded reduced embeddings ({} dimensions) from precomputed file '
                               + 'for nearest neighbor search.').format(self.embeddings.shape[1]))
            else:
                if method == 'TSNE':
                    if embedding_size == 2 and new_plot_embeddings:
                        self.embeddings = self.data['plot_embeddings']
                        self.log.info('Reused the precomputed manifold for plotting for nearest neighbor search.')
                    else:
                        self.log.info('Computing t-SNE of dimension {} for nearest neighbor search...'.format(
                            embedding_size))
                        self.embeddings = TSNE(
                            n_components=embedding_size,
                            n_jobs=n_jobs,
                            **tsne_args
                        ).fit_transform(pca_embeddings())
                else:
                    self.log.info('Computing Isomap of dimension {} for nearest neighbor search...'.format(embedding_size))
                    self.embeddings = Isomap(n_components=embedding_size,
                                             n_jobs=n_jobs,
                                             ).fit_transform(pca_embeddings())
                self.data['nn_embeddings'] = self.embeddings
                data_changed = True

            # Write added data into pickle file
            if data_changed:
                with open(self.embeddings_file, 'wb') as f:
                    pkl.dump(self.data, f)

        self.x = self.data['plot_embeddings'][:, 0]
        self.y = self.data['plot_embeddings'][:, 1]

        self.label_mapping = dict()
        for d in np.unique(self.data['dataset']).flatten():
            self.label_mapping[d] = getattr(
                importlib.import_module(datasets[d].module_name),
                datasets[d].class_name)(
                **datasets[d].kwargs,
            ).label_mapping
        # Keep the dataset object itself out of ``label_mapping``; only its
        # ``label_mapping`` attribute belongs there (added below if missing).
        train_dat = getattr(
            importlib.import_module(CONFIG.TRAIN_DATASET.module_name),
            CONFIG.TRAIN_DATASET.class_name)(
            **CONFIG.TRAIN_DATASET.kwargs,
        )
        self.pred_mapping = train_dat.pred_mapping
        if CONFIG.TRAIN_DATASET.name not in self.label_mapping:
            self.label_mapping[CONFIG.TRAIN_DATASET.name] = train_dat.label_mapping

        self.tnsize = (50, 50)
        self.fig_nn = None
        self.n_neighbors = 49
        self.current_pressed_key = None

        self.plot_main(**main_plot_args)

    def plot_main(self, **plot_args):
        """Initializes the main plot.

        Supported keyword arguments:
            legend (bool): If truthy, a class legend is drawn above the plot.
            save_path (str): If given (and not None) the figure is saved to
                this path with ``self.dpi`` and no interactive window is
                opened. Otherwise the event handlers are connected and the
                plot is shown.
        """
        self.fig_main = plt.figure(num=1)
        # ``canvas.set_window_title`` was removed from Matplotlib; the title is
        # set through the figure manager, which may be absent on non-GUI
        # backends.
        manager = getattr(self.fig_main.canvas, 'manager', None)
        if manager is not None:
            manager.set_window_title('Embedding space')
        ax = self.fig_main.add_subplot(111)
        ax.set_axis_off()

        # One RGBA colour per point: the ground-truth class colour (0-255 RGB
        # stored at index 1 of the label-mapping entry) scaled to 0-1, opaque.
        point_colors = np.stack([
            tuple(channel / 255.0
                  for channel in self.label_mapping[
                      self.data['dataset'][self.gi[ind]]
                  ][self.gt[ind]][1])
            + (1.0,)
            for ind in range(self.x.shape[0])])
        self.line_main = ax.scatter(self.x, self.y, marker="o", color=point_colors)
        self.line_main.set_picker(True)

        if plot_args.get('legend', False):
            # Shrink the axes to make room for the legend above it.
            box = ax.get_position()
            ax.set_position([box.x0, box.y0, box.width, box.height * 0.8])
            legend_elements = []
            for d in np.unique(self.data['dataset']).flatten():
                classes = np.unique(self.gt[np.array(self.data['dataset'])[self.gi] == d])
                # A set removes classes that share both name and colour.
                entries = list({(self.label_mapping[d][cl][0], self.label_mapping[d][cl][1])
                                for cl in classes})
                for name, rgb in entries:
                    # Names ending in a digit lose their last space-separated token.
                    label = name if not name[-1].isdigit() else name[:name.rfind(' ')]
                    legend_elements.append(
                        Patch(color=tuple(channel / 255.0 for channel in rgb) + (1.0,),
                              label=label))
            ax.legend(loc='upper left', handles=legend_elements, ncol=8, bbox_to_anchor=(0, 1.2))
        self.basecolors = self.line_main.get_facecolor()

        # Placeholder thumbnail (first segment) for the hover annotation box.
        tmp = Image.open(self.data['image_path'][self.gi[0]]).convert('RGB').crop(
            self.data['box'][0])
        # ``Image.ANTIALIAS`` was an alias of ``Image.LANCZOS`` and has been
        # removed from Pillow.
        tmp.thumbnail(self.tnsize, Image.LANCZOS)
        self.im = OffsetImage(tmp, zoom=2)
        self.xybox = (50., 50.)
        self.ab = AnnotationBbox(self.im, (0, 0), xybox=self.xybox, xycoords='data',
                                 boxcoords='offset points', pad=0.3, arrowprops=dict(arrowstyle='->'))
        ax.add_artist(self.ab)
        self.ab.set_visible(False)

        save_path = plot_args.get('save_path')
        if save_path is not None:
            self.fig_main.savefig(expanduser(save_path), dpi=self.dpi, bbox_inches='tight')
        else:
            self.fig_main.canvas.mpl_connect('motion_notify_event', self.hover_main)
            self.fig_main.canvas.mpl_connect('button_press_event', self.click_main)
            self.fig_main.canvas.mpl_connect('scroll_event', self.scroll)
            self.fig_main.canvas.mpl_connect('key_press_event', self.key_press)
            self.fig_main.canvas.mpl_connect('key_release_event', self.key_release)
            plt.show()

    def hover_main(self, event):
        """Action handler for mouse movement on the main plot.

        Shows a thumbnail of the underlying image crop next to a scatter point
        while the mouse hovers over it and hides the thumbnail otherwise.

        Args:
            event: The matplotlib mouse event.
        """
        # run the hit test once; it yields a flag and a dict with the hit indices
        hit, details = self.line_main.contains(event)
        if hit:
            # use the first hovered point
            ind, *_ = details["ind"]

            # figure size in pixels
            w, h = self.fig_main.get_size_inches() * self.fig_main.dpi
            # flip the offset direction when the mouse is in the right / upper
            # half so that the box points towards the figure center
            ws = -1 if event.x > w / 2. else 1
            hs = -1 if event.y > h / 2. else 1
            self.ab.xybox = (self.xybox[0] * ws, self.xybox[1] * hs)

            # make annotation box visible
            self.ab.set_visible(True)

            # place it at the position of the hovered scatter point
            self.ab.xy = (self.x[ind], self.y[ind])

            # set the image corresponding to that point
            tmp = Image.open(self.data['image_path'][self.gi[ind]]).convert('RGB').crop(self.data['box'][ind])
            tmp.thumbnail(self.tnsize, Image.ANTIALIAS)
            self.im.set_data(tmp)
            tmp.close()
        else:
            # the mouse is not over a scatter point
            self.ab.set_visible(False)
        self.fig_main.canvas.draw_idle()

    def click_main(self, event):
        """Action handler for mouse clicks on the main plot.

        Depending on the mouse button and the currently held key the clicked
        scatter point is stored as thumbnail ('t' + left), shown or saved as
        single image (left / control + left), shown or saved as detailed image
        (middle / control + middle) or used as query for the nearest neighbor
        window (right). Afterwards the clicked point is highlighted with the
        current color.

        Args:
            event: The matplotlib mouse event.
        """
        # run the hit test once; it yields a flag and a dict with the hit indices
        hit, details = self.line_main.contains(event)
        if not hit:
            return
        ind, *_ = details['ind']

        if self.current_pressed_key == 't' and event.button == 1:
            self.store_thumbnail(ind)
        elif self.current_pressed_key == 'control' and event.button == 1:
            self.show_single_image(ind, save=True)
        elif self.current_pressed_key == 'control' and event.button == 2:
            self.show_full_image(ind, save=True)
        elif event.button == 1:  # left mouse button
            self.show_single_image(ind)
        elif event.button == 2:  # middle mouse button
            self.show_full_image(ind)
        elif event.button == 3:  # right mouse button
            if plt.fignum_exists(2):
                # nearest neighbor figure is already open. Update the figure with new nearest neighbor.
                # update_nearest_neighbors highlights the point itself, hence the early return.
                self.update_nearest_neighbors(ind)
                return

            # nearest neighbor figure is not open anymore or has not been opened yet
            self.log.info('Loading nearest neighbors...')
            self.nearest_neighbors = self.get_nearest_neighbors(ind, metric=self.distance_metrics[self.dm])
            thumbnails = []
            for neighbor_ind in self.nearest_neighbors:
                thumbnails.append(Image.open(self.data['image_path'][self.gi[neighbor_ind]]).crop(
                    self.data['box'][neighbor_ind]))
            # at least one column so that zero requested neighbors cannot cause a division by zero
            columns = max(1, math.ceil(math.sqrt(self.n_neighbors)))
            rows = math.ceil(self.n_neighbors / columns)

            self.fig_nn = plt.figure(num=2, dpi=self.dpi)
            self.fig_nn.canvas.set_window_title('{} nearest neighbors to selected image'.format(
                self.n_neighbors))
            for p in range(columns * rows):
                ax = self.fig_nn.add_subplot(rows, columns, p + 1)
                ax.set_axis_off()
                if p < len(thumbnails):
                    ax.imshow(np.asarray(thumbnails[p]))
            self.fig_nn.canvas.mpl_connect('button_press_event', self.click_nn)
            self.fig_nn.canvas.mpl_connect('key_press_event', self.key_press)
            self.fig_nn.canvas.mpl_connect('key_release_event', self.key_release)
            self.fig_nn.canvas.mpl_connect('scroll_event', self.scroll)
            self.fig_nn.show()

        self.set_color(ind, self.current_color)
        self.flush_colors()

    def click_nn(self, event):
        """Action handler for mouse clicks in the nearest neighbor window.

        When clicking a cropped segment in the nearest neighbor window the same
        actions are taken as in the click handler for the main plot. A right
        click uses the clicked segment as the new nearest neighbor query.
        Clicks outside of the axes or on an empty filler cell are ignored.

        Args:
            event: The matplotlib mouse event.
        """
        if event.inaxes not in self.fig_nn.axes:
            return

        pos = self.get_ind_nn(event)
        if pos >= len(self.nearest_neighbors):
            # the grid can hold more cells than there are neighbors; such
            # filler cells have no segment attached to them
            return
        neighbor = self.nearest_neighbors[pos]

        if self.current_pressed_key == 't' and event.button == 1:
            self.store_thumbnail(neighbor)
        elif self.current_pressed_key == 'control' and event.button == 1:
            self.show_single_image(neighbor, save=True)
        elif self.current_pressed_key == 'control' and event.button == 2:
            self.show_full_image(neighbor, save=True)
        elif event.button == 1:  # left mouse button
            self.show_single_image(neighbor)
        elif event.button == 2:  # middle mouse button
            self.show_full_image(neighbor)
        elif event.button == 3:  # right mouse button
            self.update_nearest_neighbors(neighbor)

    def key_press(self, event):
        """Performs different actions based on pressed keys.

        'm' cycles through the distance metrics, '#' cycles through the
        clustering methods, 'c' runs the selected clustering and recolors the
        scatter plot, 'x' shows the cluster statistics of an existing
        clustering, 'g' / 'h' color the points by ground truth / prediction
        and 'b' resets all points to the standard color. The pressed key is
        always remembered in ``current_pressed_key``.

        Args:
            event: The matplotlib key event.
        """
        if event.key == 'm':
            self.dm = (self.dm + 1) % len(self.distance_metrics)
            self.log.info('Changed distance metric to {}'.format(self.distance_metrics[self.dm]))
        elif event.key == '#':
            self.cme = (self.cme + 1) % len(self.cluster_methods)
            self.log.info('Changed clustering method to {}'.format(list(self.cluster_methods.keys())[self.cme]))
        elif event.key == 'c':
            method = list(self.cluster_methods.keys())[self.cme]
            self.log.info('Started clustering with {}...'.format(method))
            self.cluster(method=method)
            if self.clustering is None:
                # nothing to visualize, e.g. because the clustering was not run
                self.log.warning('No clustering available, keeping current colors.')
            else:
                # the legend only exists if it has been drawn before
                legend = self.fig_main.axes[0].get_legend()
                if legend is not None:
                    legend.remove()
                self.basecolors = cm.get_cmap('viridis', (max(self.clustering) + 1))(self.clustering)
                self.flush_colors()
                self.cluster_statistics()
        elif event.key == 'x':
            if self.clustering is not None:
                self.cluster_statistics()
        elif event.key == 'g':
            self.color_gt()
        elif event.key == 'h':
            self.color_pred()
        elif event.key == 'b':
            self.set_color(list(range(self.basecolors.shape[0])), self.standard_color)
            self.flush_colors()

        self.current_pressed_key = event.key
        self.log.debug('Key \'{}\' pressed.'.format(event.key))

    def key_release(self, event):
        """Forgets the remembered key as soon as any key is released.

        Args:
            event: The matplotlib key event.
        """
        released = event.key
        self.current_pressed_key = None
        self.log.debug("Key '{}' released.".format(released))

    def scroll(self, event):
        """Increases or decreases number of nearest neighbors when scrolling on the main or nearest neighbor plot.

        The number never drops below one, since the nearest neighbor window
        cannot be laid out for zero neighbors.

        Args:
            event: The matplotlib scroll event.
        """
        if event.button == 'up':
            self.n_neighbors += 1
            self.log.info('Increased number of nearest neighbors to {}'.format(self.n_neighbors))
        elif event.button == 'down' and self.n_neighbors > 1:
            self.n_neighbors -= 1
            self.log.info('Decreased number of nearest neighbors to {}'.format(self.n_neighbors))

    def show_single_image(self, ind, save=False):
        """Displays the full image belonging to a segment. The segment is marked with a red bounding box.

        Args:
            ind: Index of the segment.
            save (bool): If True, the image is written to ``save_dir`` as
                'image_<ind>.jpg' instead of being shown in a window.
        """
        self.log.info('{} image...'.format('Saving' if save else 'Loading'))
        img_box = self.draw_box_on_image(ind)
        fig_tmp = plt.figure(max(3, max(plt.get_fignums()) + 1), dpi=self.dpi)
        ax = fig_tmp.add_subplot(111)
        ax.set_axis_off()
        ax.imshow(np.asarray(img_box), interpolation='nearest')
        if save:
            target = join(self.save_dir, 'image_{}.jpg'.format(ind))
            # remove all padding so that only the image itself is written
            fig_tmp.subplots_adjust(bottom=0, left=0, right=1, top=1, hspace=0, wspace=0)
            ax.margins(0.05, 0.05)
            fig_tmp.gca().xaxis.set_major_locator(plt.NullLocator())
            fig_tmp.gca().yaxis.set_major_locator(plt.NullLocator())
            fig_tmp.savefig(target, bbox_inches='tight', pad_inches=0.0)
            # the figure is never shown, so release it instead of piling up open figures
            plt.close(fig_tmp)
            self.log.debug('Saved image to {}'.format(target))
        else:
            fig_tmp.canvas.set_window_title('Dataset: {}, Image index: {}'.format(self.data['dataset'][self.gi[ind]],
                                                                                  self.data['image_index'][
                                                                                      self.gi[ind]]))
            fig_tmp.tight_layout(pad=0.0)
            fig_tmp.show()

    def show_full_image(self, ind, save=False):
        """Displays four panels of the full image belonging to a segment.

        Top left: Entropy heatmap of prediction.
        Top right: Predicted IoU of each segment.
        Bottom left: Source image with ground truth overlay.
        Bottom right: Predicted semantic segmentation.

        Args:
            ind: Index of the segment. Its bounding box is drawn in red onto
                every panel.
            save (bool): If True, the panels are written to ``save_dir`` as
                'detailed_image_<ind>.jpg' instead of being shown in a window.
        """
        self.log.info('{} detailed image...'.format('Saving' if save else 'Loading'))
        box = self.data['box'][ind]
        image = np.asarray(Image.open(self.data['image_path'][self.gi[ind]]).convert('RGB'))
        image_index = self.data['image_index'][self.gi[ind]]
        iou_pred = self.data['iou_pred'][self.gi[ind]]
        dataset = self.data['dataset'][self.gi[ind]]
        model_name = self.data['model_name'][self.gi[ind]]

        pred, gt, image_path = probs_gt_load(image_index,
                                             input_dir=join(CONFIG.metaseg_io_path, 'input', model_name, dataset))
        components = components_load(image_index,
                                     components_dir=join(CONFIG.metaseg_io_path, 'components', model_name, dataset))
        e = entropy(pred)
        pred = pred.argmax(2)

        def colorize(class_map, mapping):
            # Look up the color of every distinct class id once and spread it
            # over the pixels instead of doing one Python lookup per pixel.
            class_ids, inverse = np.unique(np.asarray(class_map).ravel(), return_inverse=True)
            palette = np.asarray([mapping[class_id][1] for class_id in class_ids])
            return palette[inverse].reshape(image.shape)

        predc = colorize(pred, self.pred_mapping)
        gtc = colorize(gt, self.label_mapping[dataset])

        # blend factor of each panel with the source image (1.0 = no source image visible)
        overlay_factor = [1.0, 0.5, 1.0]
        img_predc, img_gtc, img_entropy = [
            Image.fromarray(np.uint8(arr * overlay_factor[i] + image * (1 - overlay_factor[i])))
            for i, arr in enumerate([predc,
                                     gtc,
                                     cm.jet(e)[:, :, :3] * 255.0])]

        img_ioupred = Image.fromarray(self.visualize_segments(components, iou_pred))

        images = [img_gtc, img_predc, img_entropy, img_ioupred]

        # enlarge the box by the line width and clip it to the image
        box_line_width = 5
        left, upper = max(0, box[0] - box_line_width), max(0, box[1] - box_line_width)
        right, lower = min(pred.shape[1], box[2] + box_line_width), min(pred.shape[0], box[3] + box_line_width)

        for k in images:
            draw = ImageDraw.Draw(k)
            draw.rectangle([left, upper, right, lower], outline=(255, 0, 0), width=box_line_width)
            del draw

        images = [np.asarray(k).astype('uint8') for k in images]

        img_top = np.concatenate(images[2:], axis=1)
        img_bottom = np.concatenate(images[:2], axis=1)

        img_total = np.concatenate((img_top, img_bottom), axis=0)
        fig_tmp = plt.figure(max(3, max(plt.get_fignums()) + 1), dpi=self.dpi)
        fig_tmp.canvas.set_window_title('Dataset: {}, Image index: {}'.format(dataset,
                                                                              image_index))
        ax = fig_tmp.add_subplot(111)
        ax.set_axis_off()
        ax.imshow(img_total, interpolation='nearest')

        if save:
            target = join(self.save_dir, 'detailed_image_{}.jpg'.format(ind))
            # remove all padding so that only the panels themselves are written
            fig_tmp.subplots_adjust(bottom=0, left=0, right=1, top=1, hspace=0, wspace=0)
            ax.margins(0.05, 0.05)
            fig_tmp.gca().xaxis.set_major_locator(plt.NullLocator())
            fig_tmp.gca().yaxis.set_major_locator(plt.NullLocator())
            fig_tmp.savefig(target, bbox_inches='tight', pad_inches=0.0)
            # the figure is never shown, so release it instead of piling up open figures
            plt.close(fig_tmp)
            self.log.debug('Saved image to {}'.format(target))
        else:
            fig_tmp.tight_layout(pad=0.0)
            fig_tmp.show()

    def store_thumbnail(self, ind):
        """Stores a thumbnail of a segment. This does not save the whole image but only the cropped part.

        The file is written to ``save_dir`` and named after the ground truth
        class name of the segment and its coordinates in the main plot.

        Args:
            ind: Index of the segment.
        """
        image = Image.open(self.data['image_path'][self.gi[ind]]).convert('RGB')
        image = image.crop(self.data['box'][ind])

        name = self.label_mapping[self.data['dataset'][self.gi[ind]]][self.gt[ind]][0]
        if name[-1].isdigit():
            # cut off the numeric suffix at the last space, like the legend
            # labels do; this also covers suffixes with more than one digit
            name = name.rsplit(' ', 1)[0]

        name = name.replace(' ', '_')

        target = join(self.save_dir, 'thumbnail_{}_{:0>2.1f}_{:0>2.1f}.jpg'.format(
            name,
            self.x[ind],
            self.y[ind]))
        image.save(target)
        self.log.debug('Saved thumbnail to {}'.format(target))

    def get_nearest_neighbors(self, ind, metric='cos'):
        """Returns the indices of the segments closest to segment ``ind``.

        Args:
            ind: Index of the query segment within the embeddings.
            metric (str): 'euclid' selects the L2 distance, every other value
                the cosine distance.

        Returns:
            Indices of up to ``n_neighbors`` segments ordered by increasing
            distance. The first entry of the ranking is skipped.
        """
        query = self.embeddings[ind]
        if metric == 'euclid':
            dists = self.lp_dist(query, self.embeddings, d=2)
        else:
            dists = self.cos_dist(query, self.embeddings)
        ranking = np.argsort(dists)
        first, last = 1, self.n_neighbors + 1
        return ranking[first:last]

    def update_nearest_neighbors(self, ind):
        """Refreshes the open nearest neighbor window for a newly selected query segment.

        The neighbors of ``ind`` are recomputed, the subplot grid is resized if
        the number of neighbors changed, every cell is redrawn and the query
        segment is highlighted in the main plot.

        Args:
            ind: Index of the new query segment.
        """
        self.log.info('Loading nearest neighbors...')
        self.nearest_neighbors = self.get_nearest_neighbors(ind, metric=self.distance_metrics[self.dm])
        thumbnails = []
        for neighbor_ind in self.nearest_neighbors:
            b = self.data['box'][neighbor_ind]
            thumbnails.append(plt.imread(self.data['image_path'][self.gi[neighbor_ind]])[b[1]:b[3], b[0]:b[2], :])
        n = math.ceil(math.sqrt(len(self.nearest_neighbors)))
        if len(self.fig_nn.axes) != (n ** 2):
            # number of nearest neighbors has been changed
            # redefine number of subplots in fig_nn
            self.rearrange_axes(n, n)

        axes = self.fig_nn.axes
        for p in range(n ** 2):
            # drop the previous content so that images do not pile up on the axes
            axes[p].clear()
            axes[p].set_axis_off()
            # compare against the crops that were actually loaded; there can
            # be fewer of them than the requested number of neighbors
            if p < len(thumbnails):
                axes[p].imshow(thumbnails[p])

        self.fig_nn.canvas.draw()
        self.set_color(ind, self.current_color)
        self.flush_colors()

    def cluster(self, method='kmeans'):
        """Clusters the embeddings and stores the cluster assignment in ``clustering``.

        Only methods listed in ``methods_with_ncluster_param`` are run; for
        them the number of clusters is asked for interactively. For 'kmeans'
        the answer 'elbow' starts the elbow procedure to choose the number.
        If no valid number of clusters is obtained, the error is logged and
        the previous clustering is kept.

        Args:
            method (str): Key of the clustering method in ``cluster_methods``.
        """
        if method not in self.methods_with_ncluster_param:
            return

        n_clusters = self.n_cluster_prompt()
        if n_clusters == 'elbow' and method == 'kmeans':
            n_clusters = self.elbow()

        if not isinstance(n_clusters, int):
            # the prompt reports invalid input as a string; do not pass that on to the estimator
            self.log.error('Cannot cluster with {} using n_clusters={!r}.'.format(method, n_clusters))
            return

        spec = self.cluster_methods[method]
        self.clustering = spec['main'](n_clusters=n_clusters, **spec['kwargs']).fit_predict(self.embeddings)

    def cluster_statistics(self):
        """Shows one pie chart per cluster with the distribution of ground truth class names.

        The charts are arranged in a square grid in a new figure. The largest
        wedge of every chart is pulled out and only wedges above 20 percent
        are annotated with their percentage. A figure legend is built from
        the wedges of the last cluster's chart.
        """
        fig_clstats = plt.figure(max(3, max(plt.get_fignums()) + 1))
        clusters = np.unique(self.clustering).flatten()
        # side length of the smallest square grid that holds all clusters
        n = math.ceil(math.sqrt(clusters.shape[0]))
        # (name, color) pairs of the distinct labels seen in the clusters processed so far
        all_label_names = []
        self.log.debug('Size of cluster statistics plots: {}'.format(n))
        for i in range(n ** 2):
            ax = fig_clstats.add_subplot(n, n, i + 1)
            if i < clusters.shape[0]:
                # labels, label_counts = np.unique(self.gt[self.clustering == clusters[i]], return_counts=True)

                # collect the (name, color) pair of every segment assigned to this cluster
                label_names = []
                cols = []
                for j in range(self.clustering.shape[0]):
                    if self.clustering[j] == clusters[i]:
                        dat = self.data['dataset'][self.gi[j]]
                        name = self.label_mapping[dat][self.gt[j]][0]
                        label_names.append((name, self.label_mapping[dat][self.gt[j]][1]))

                if i == (clusters.shape[0] - 1):
                    # The last cluster additionally gets one entry for every label seen in earlier
                    # clusters so that its wedges (used for the legend below) cover all labels.
                    # NOTE(review): this raises each of those counts in the last chart by one.
                    missing = [all_label_names[ind] for ind in np.unique([lbl[0] for lbl in all_label_names],
                                                                         return_index=True)[1]]
                    label_names += missing

                # distinct names, index of their first occurrence and their number of occurrences
                labels, label_inds, label_counts = np.unique([lbl[0] for lbl in label_names],
                                                             return_counts=True,
                                                             return_index=True)
                # pull out the wedge of the most frequent label
                explode = np.zeros(labels.shape[0])
                explode[np.argmax(label_counts)] = 0.2
                # NOTE(review): colors are used as stored in label_mapping; plot_main divides them by
                # 255 before handing them to matplotlib - confirm that no scaling is needed here.
                cols = np.array([label_names[ind][1] for ind in label_inds])
                # order the wedges by decreasing count
                perm = np.argsort(label_counts)[::-1]
                all_label_names += [label_names[i] for i in np.unique([lbl[0] for lbl in label_names],
                                                                      return_index=True)[1]]

                ax.pie(label_counts[perm],
                       # labels=labels[perm],
                       colors=cols[perm],
                       explode=explode[perm],
                       autopct=lambda perc: '{:1.1f}%'.format(perc) if perc > 20 else '',
                       # shadow=True,
                       startangle=90,
                       wedgeprops=dict(edgecolor='w'),
                       textprops=dict(color="g"))
                ax.axis('equal')
            else:
                # unused grid cell
                ax.set_axis_off()

        # labels and perm still hold the values of the last cluster at this point
        fig_clstats.legend(fig_clstats.axes[clusters.shape[0] - 1].patches, labels[perm], loc='upper left')
        fig_clstats.show()

    def elbow(self):
        """Interactive elbow procedure to choose the number of clusters for k-means.

        Asks for a range of cluster numbers, fits one k-means model per number
        on the embeddings, plots the resulting inertia values and finally asks
        for the number of clusters to use.

        Returns:
            int: The number of clusters that was entered.
        """
        low = int(input('Enter the minimum number of clusters: '))
        high = int(input('Enter the maximum number of clusters: '))
        candidates = range(low, high + 1)

        models = [KMeans(n_clusters=k) for k in candidates]
        fitted = [model.fit(self.embeddings) for model in tqdm(models, total=len(models))]
        inertias = [model.inertia_ for model in fitted]

        fig_elbow = plt.figure(max(3, max(plt.get_fignums()) + 1))
        fig_elbow.add_subplot(111).plot(candidates, inertias)
        fig_elbow.show()
        return int(input('Enter number of clusters: '))

    def n_cluster_prompt(self):
        """Asks for the number of clusters on the command line.

        Returns:
            The entered number as int, the string 'elbow' if exactly that was
            entered, or the string 'error' if the input is not an integer.

        Raises:
            ValueError: If the entered number is not greater than 1.
        """
        answer = input('Enter the number of clusters: ')
        if answer == 'elbow':
            return answer

        try:
            count = int(answer)
        except ValueError:
            self.log.error("Input should be an int or 'elbow' but received {}!".format(answer))
            return 'error'

        if count <= 1:
            raise ValueError('Number of clusters should be greater than 1!')
        return count

    def rearrange_axes(self, nrows, ncols):
        """Fits the subplots of the nearest neighbor window to a grid of nrows x ncols cells.

        Surplus axes are removed starting from the last one, the remaining
        axes are moved to their position in the new grid and missing cells
        are filled with new axes whose axis decorations are switched off.

        Args:
            nrows (int): Number of rows of the new grid.
            ncols (int): Number of columns of the new grid.
        """
        total = nrows * ncols
        existing = self.fig_nn.axes

        # we need to remove some axes (no-op if the grid did not shrink)
        for surplus in reversed(existing[total:]):
            self.fig_nn.delaxes(surplus)

        for position, ax in enumerate(self.fig_nn.axes, start=1):
            ax.change_geometry(nrows, ncols, position)

        # we need to add more axes (no-op if the grid did not grow)
        for p in range(len(existing), total):
            self.fig_nn.add_subplot(nrows, ncols, p + 1).set_axis_off()

    def draw_box_on_image(self, ind):
        """Loads the source image of a segment and draws the segment's bounding box in red onto it.

        The box is enlarged by the line width on every side and clipped to the
        image borders.

        Args:
            ind: Index of the segment.

        Returns:
            The RGB image with the drawn box.
        """
        box_line_width = 5
        img_box = Image.open(self.data['image_path'][self.gi[ind]]).convert('RGB')
        img_width, img_height = img_box.size[0], img_box.size[1]

        left, upper, right, lower = self.data['box'][ind]
        corners = [max(0, left - box_line_width),
                   max(0, upper - box_line_width),
                   min(img_width, right + box_line_width),
                   min(img_height, lower + box_line_width)]

        ImageDraw.Draw(img_box).rectangle(corners, outline=(255, 0, 0), width=box_line_width)
        return img_box

    @staticmethod
    def visualize_segments(comp, metric):
        """Helper function for generation of the four panels in the detailed image function."""
        r = np.asarray(metric)
        r = 1 - 0.5 * r
        g = np.asarray(metric)
        b = 0.3 + 0.35 * np.asarray(metric)

        r = np.concatenate((r, np.asarray([0, 1])))
        g = np.concatenate((g, np.asarray([0, 1])))
        b = np.concatenate((b, np.asarray([0, 1])))

        components = np.asarray(comp.copy(), dtype='int16')
        components[components < 0] = len(r) - 1
        components[components == 0] = len(r)

        img = np.zeros(components.shape + (3,))
        x, y = np.meshgrid(np.arange(img.shape[0]), np.arange(img.shape[1]))
        x = x.reshape(-1)
        y = y.reshape(-1)

        img[x, y, 0] = r[components[x, y] - 1]
        img[x, y, 1] = g[components[x, y] - 1]
        img[x, y, 2] = b[components[x, y] - 1]

        img = np.asarray(255 * img).astype('uint8')

        return img

    @staticmethod
    def lp_dist(point, all_points, d=2):
        """Calculates the L_p distance from a point to a collection of points. Used for retrieval."""
        return ((all_points - point) ** d).sum(1) ** (1.0 / d)

    @staticmethod
    def cos_dist(point, all_points):
        """Calculates the cosine distance from a point to a collection of points. Used for retrieval."""
        return 1 - ((point * all_points).sum(1) / (norm(point) * norm(all_points, axis=1)))

    @staticmethod
    def get_gridsize(fig):
        """Helper function for the nearest neighbor plot."""
        return fig.axes[0].get_subplotspec().get_gridspec().get_geometry()

    def get_ind_nn(self, event):
        """Returns the row-major position of the clicked axes within the nearest neighbor grid.

        Helper function for the nearest neighbor plot.
        """
        clicked = event.inaxes
        n_columns = self.get_gridsize(self.fig_nn)[1]
        return (clicked.rowNum * n_columns) + clicked.colNum

    def color_gt(self):
        """Colors the scatter points of the main plot by their ground truth class and draws the matching legend."""

        def rgba(color):
            # scale a 0-255 color to 0-1 and append full opacity
            return tuple(channel / 255.0 for channel in color) + (1.0,)

        n_points = self.basecolors.shape[0]
        self.basecolors = np.stack([
            rgba(self.label_mapping[self.data['dataset'][self.gi[ind]]][self.gt[ind]][1])
            for ind in range(n_points)])

        legend_elements = []
        for d in np.unique(self.data['dataset']).flatten():
            cls = np.unique(self.gt[np.array(self.data['dataset'])[self.gi] == d])
            # the set drops duplicate (name, color) pairs
            name_color_pairs = {(self.label_mapping[d][cl][0], self.label_mapping[d][cl][1]) for cl in cls}
            for name, color in name_color_pairs:
                if name[-1].isdigit():
                    # names ending in a digit are cut off at the last space
                    label = name[:name.rfind(' ')]
                else:
                    label = name
                legend_elements.append(Patch(color=rgba(color), label=label))
        self.fig_main.axes[0].legend(loc='upper left', handles=legend_elements, ncol=8, bbox_to_anchor=(0, 1.2))
        self.flush_colors()

    def color_pred(self):
        """Colors the scatter points of the main plot by their predicted class and draws the matching legend."""

        def rgba(class_id):
            # scale the 0-255 color of the predicted class to 0-1 and append full opacity
            return tuple(channel / 255.0 for channel in self.pred_mapping[class_id][1]) + (1.0,)

        n_points = self.basecolors.shape[0]
        self.basecolors = np.stack([rgba(self.pred[ind]) for ind in range(n_points)])

        legend_elements = []
        for cl in np.unique(self.pred).flatten():
            legend_elements.append(Patch(color=rgba(cl), label=self.pred_mapping[cl][0]))
        self.fig_main.axes[0].legend(loc='upper left', handles=legend_elements, ncol=8, bbox_to_anchor=(0, 1.2))
        self.flush_colors()

    def set_color(self, ind, color):
        """Overwrites the stored color of the segment(s) selected by ``ind``.

        The change only becomes visible after the colors have been flushed to
        the scatter plot.
        """
        colors = self.basecolors
        colors[ind, :] = color

    def change_color(self, old_color, new_color):
        """Replaces every stored color that equals ``old_color`` in all channels by ``new_color``."""
        matches = (self.basecolors == old_color).all(axis=1)
        self.basecolors[matches] = new_color

    def flush_colors(self):
        """Applies the stored colors to the main scatter plot and redraws its canvas."""
        scatter, canvas = self.line_main, self.fig_main.canvas
        scatter.set_color(self.basecolors)
        canvas.draw()
Ejemplo n.º 20
0
class Discovery(object):
    def __init__(
            self,
            embeddings_file="embeddings.p",
            distance_metric="euclid",
            method="TSNE",
            embedding_size=2,
            overwrite_embeddings=False,
            n_jobs=10,
            dpi=300,
            main_plot_args=None,
            tsne_args=None,
            save_dir=join(CONFIG.metaseg_io_path, "vis_embeddings"),
    ):
        """Loads the embedding files, computes the dimensionality reductions and calls
        the initialization of the main plot.

        Args:
            embeddings_file (str): Path to the file where all data of segments
                including feature embeddings is saved.
            distance_metric (str): Distance metric to use for nearest neighbor
                computation. Unknown names fall back to "euclid".
            method (str): Method to use for dimensionality reduction of nearest
                neighbor embeddings ("TSNE", anything else selects Isomap). For
                plotting the points are always reduced in dimensionality using
                PCA to at most 50 dimensions followed by t-SNE to two dimensions.
            embedding_size (int): Dimensionality of the feature embeddings used for
                nearest neighbor search. ``None`` or a value greater or equal to
                the data dimensionality disables the reduction.
            overwrite_embeddings (bool): If True, precomputed nearest neighbor and
                plotting embeddings from previous runs are overwritten with freshly
                computed ones. Otherwise precomputed embeddings are used if requested
                embedding_size is matching.
            n_jobs (int): Number of processes to use for the nearest neighbor
                t-SNE / Isomap computation.
            dpi (int): Dots per inch for graphics that are saved to disk.
            main_plot_args (dict): Keyword arguments for the creation of the main
                plot. ``None`` means no arguments.
            tsne_args (dict): Keyword arguments for the t-SNE algorithm. ``None``
                means no arguments.
            save_dir (str): Path to the directory where saved images are placed in.

        Raises:
            ValueError: If stored nearest neighbor embeddings do not match the
                requested ``embedding_size`` and nothing may be recomputed.
        """
        # Fresh dicts per call instead of mutable defaults shared between calls.
        main_plot_args = {} if main_plot_args is None else main_plot_args
        tsne_args = {} if tsne_args is None else tsne_args

        self.log = logging.getLogger("Discovery")
        self.embeddings_file = embeddings_file
        self.distance_metrics = ["euclid", "cos"]
        # An unknown metric name silently selects index 0 ("euclid").
        self.dm = (self.distance_metrics.index(distance_metric)
                   if distance_metric in self.distance_metrics else 0)

        self.dpi = dpi
        self.save_dir = save_dir
        os.makedirs(self.save_dir, exist_ok=True)

        self.cluster_methods = OrderedDict()
        self.cluster_methods["kmeans"] = {"main": KMeans, "kwargs": {}}
        self.cluster_methods["spectral"] = {
            "main": SpectralClustering,
            "kwargs": {}
        }
        self.cluster_methods["agglo"] = {
            "main": AgglomerativeClustering,
            "kwargs": {
                "linkage": "ward"
            },
        }

        self.methods_with_ncluster_param = ["kmeans", "spectral", "agglo"]
        self.cme = 0
        self.clustering = None
        self.n_clusters = 25

        # colors:
        self.standard_color = (0, 0, 1, 1)
        self.current_color = (1, 0, 0, 1)
        self.nn_color = (1, 0, 1, 1)

        self.log.info("Loading data...")
        # NOTE(review): unpickling executes arbitrary code; only load trusted files.
        with open(self.embeddings_file, "rb") as f:
            self.data = pkl.load(f)

        self.iou_preds = self.data["iou_pred"]
        self.gt = np.array(self.data["gt"]).flatten()
        self.pred = np.array(self.data["pred"]).flatten()
        # global indices (on image level and not on component level)
        self.gi = self.data["image_level_index"]

        self.log.info("Loaded {} segment embeddings.".format(
            self.pred.shape[0]))

        self.nearest_neighbors = None

        if len(self.data["embeddings"]) == 1:
            self.data["plot_embeddings"] = np.array(
                [self.data["embeddings"][0][0],
                 self.data["embeddings"][0][1]]).reshape((1, 2))
            self.data["nn_embeddings"] = self.data["plot_embeddings"]
            # Previously left unset, which broke nearest neighbor queries.
            self.embeddings = self.data["nn_embeddings"]
        else:
            data_dim = self.data["embeddings"][0].shape[0]
            # ``None`` means "do not reduce"; test it before comparing numbers.
            reduce_dim = (embedding_size is not None
                          and embedding_size < data_dim)
            need_plot = ("plot_embeddings" not in self.data
                         or overwrite_embeddings)
            need_nn = "nn_embeddings" not in self.data or overwrite_embeddings
            rewrite = (need_plot or need_nn) and reduce_dim

            raw_embeddings = None
            if need_plot or rewrite or not reduce_dim:
                raw_embeddings = np.stack(self.data["embeddings"]).reshape(
                    (-1, data_dim))

            # The PCA output feeds both the plotting t-SNE and the nearest
            # neighbor reduction, so compute it whenever either one is needed.
            embeddings = None
            if need_plot or rewrite:
                self.log.info("Computing PCA...")
                n_comp = min(50, len(self.data["embeddings"]), data_dim)
                embeddings = PCA(n_components=n_comp).fit_transform(
                    raw_embeddings)

            if need_plot:
                self.log.info("Computing t-SNE for plotting")
                self.data["plot_embeddings"] = TSNE(
                    n_components=2, **tsne_args).fit_transform(embeddings)
            new_plot_embeddings = need_plot

            if not reduce_dim:
                self.embeddings = raw_embeddings
                self.log.debug(
                    ("Requested embedding size of {} was greater or equal "
                     "to data dimensionality of {}. "
                     "Data has thus not been reduced in dimensionality."
                     ).format(embedding_size, data_dim))
            elif ("nn_embeddings" in self.data and not overwrite_embeddings
                  and self.data["nn_embeddings"].shape[1] == embedding_size):
                self.embeddings = self.data["nn_embeddings"]
                self.log.info(("Loaded reduced embeddings "
                               "({} dimensions) from precomputed file " +
                               "for nearest neighbor search.").format(
                                   self.embeddings.shape[1]))
            elif rewrite:
                if method == "TSNE":
                    if ("plot_embeddings" in self.data
                            and embedding_size == 2 and new_plot_embeddings):
                        self.embeddings = self.data["plot_embeddings"]
                        self.log.info((
                            "Reused the precomputed manifold for plotting for "
                            "nearest neighbor search."))
                    else:
                        self.log.info(("Computing t-SNE of dimension "
                                       "{} for nearest neighbor search..."
                                       ).format(embedding_size))
                        self.embeddings = TSNE(
                            n_components=embedding_size,
                            n_jobs=n_jobs,
                            **tsne_args).fit_transform(embeddings)
                else:
                    self.log.info(("Computing Isomap of dimension "
                                   "{} for nearest neighbor search..."
                                   ).format(embedding_size))
                    self.embeddings = Isomap(
                        n_components=embedding_size,
                        n_jobs=n_jobs,
                    ).fit_transform(embeddings)
                self.data["nn_embeddings"] = self.embeddings
            else:
                raise ValueError(
                    ("Please specify a valid combination of arguments.\n"
                     "Loading fails if 'overwrite_embeddings' is False and "
                     "saved 'embedding_size' does not match the requested one."
                     ))

            # Write added data into pickle file
            if rewrite or new_plot_embeddings:
                with open(self.embeddings_file, "wb") as f:
                    pkl.dump(self.data, f)

        self.x = self.data["plot_embeddings"][:, 0]
        self.y = self.data["plot_embeddings"][:, 1]

        self.label_mapping = dict()
        for d in np.unique(self.data["dataset"]).flatten():
            try:
                self.label_mapping[d] = getattr(
                    importlib.import_module(datasets[d].module_name),
                    datasets[d].class_name,
                )(**datasets[d].kwargs, ).label_mapping
            except AttributeError:
                # Datasets without a label mapping are marked with None.
                self.label_mapping[d] = None

        # Instantiate the training dataset without storing the dataset object
        # itself in ``label_mapping``; only its mapping belongs there, and only
        # when the loop above did not already provide one.
        train_dat = getattr(
            importlib.import_module(CONFIG.TRAIN_DATASET.module_name),
            CONFIG.TRAIN_DATASET.class_name,
        )(**CONFIG.TRAIN_DATASET.kwargs, )
        self.pred_mapping = train_dat.pred_mapping
        if CONFIG.TRAIN_DATASET.name not in self.label_mapping:
            self.label_mapping[
                CONFIG.TRAIN_DATASET.name] = train_dat.label_mapping

        self.tnsize = (50, 50)
        self.fig_nn = None
        self.fig_main = None
        self.line_main = None
        self.im = None
        self.xybox = None
        self.ab = None
        self.basecolors = np.stack(
            [self.standard_color for _ in range(self.x.shape[0])])
        self.n_neighbors = 49
        self.current_pressed_key = None

        self.plot_main(**main_plot_args)

    def plot_main(self, **plot_args):
        """Initializes the main plot.

        Draws the scatter plot of the plotting embeddings, optionally adds a
        class legend, and prepares the (initially hidden) thumbnail annotation
        box. The figure is then either saved to disk or shown interactively
        with all event handlers connected.

        Keyword Args:
            legend (bool): If truthy and every dataset has a label mapping, a
                legend of ground truth classes is drawn above the plot.
            save_path (str): If given and not None, the figure is saved to this
                path (``~`` is expanded) using ``self.dpi`` instead of being
                shown interactively.
        """
        self.fig_main = plt.figure(num=1)
        # The canvas-level set_window_title was removed from Matplotlib; the
        # manager-level method is available in old and new releases.
        self.fig_main.canvas.manager.set_window_title("Embedding space")
        ax = self.fig_main.add_subplot(111)
        ax.set_axis_off()
        self.line_main = ax.scatter(self.x,
                                    self.y,
                                    marker="o",
                                    color=self.basecolors,
                                    zorder=2)
        self.line_main.set_picker(True)

        show_legend = plot_args.get("legend") and all(
            lm is not None for lm in self.label_mapping.values())
        if show_legend:
            # Shrink the axes to 80% height to make room for the legend.
            box = ax.get_position()
            ax.set_position([box.x0, box.y0, box.width, box.height * 0.8])
            legend_elements = []
            dataset_per_segment = np.array(self.data["dataset"])[self.gi]
            for d in np.unique(self.data["dataset"]).flatten():
                class_ids = np.unique(self.gt[dataset_per_segment == d])
                # A set merges classes that share the same (name, color) entry.
                entries = {(self.label_mapping[d][cl][0],
                            self.label_mapping[d][cl][1])
                           for cl in class_ids}
                for name, rgb in entries:
                    # Names ending in a digit lose their last space-separated
                    # token in the legend.
                    label = (name[:name.rfind(" ")]
                             if name[-1].isdigit() else name)
                    legend_elements.append(
                        Patch(color=tuple(c / 255.0 for c in rgb) + (1.0, ),
                              label=label))
            ax.legend(
                loc="upper left",
                handles=legend_elements,
                ncol=8,
                bbox_to_anchor=(0, 1.2),
            )
        self.basecolors = self.line_main.get_facecolor()

        # Seed the annotation box with the crop of the first segment.
        with Image.open(self.data["image_path"][self.gi[0]]) as source:
            tmp = source.convert("RGB").crop(self.data["box"][0])
        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
        tmp.thumbnail(self.tnsize, Image.LANCZOS)
        self.im = OffsetImage(tmp, zoom=2)
        self.xybox = (50.0, 50.0)
        self.ab = AnnotationBbox(
            self.im,
            (0, 0),
            xybox=self.xybox,
            xycoords="data",
            boxcoords="offset points",
            pad=0.3,
            arrowprops=dict(arrowstyle="->"),
        )
        ax.add_artist(self.ab)
        self.ab.set_visible(False)

        save_path = plot_args.get("save_path")
        if save_path is not None:
            # Honor the configured resolution instead of a hard-coded 300 dpi.
            plt.savefig(expanduser(save_path),
                        dpi=self.dpi,
                        bbox_inches="tight")
        else:
            canvas = self.fig_main.canvas
            canvas.mpl_connect("motion_notify_event", self.hover_main)
            canvas.mpl_connect("button_press_event", self.click_main)
            canvas.mpl_connect("scroll_event", self.scroll)
            canvas.mpl_connect("key_press_event", self.key_press)
            canvas.mpl_connect("key_release_event", self.key_release)
            plt.show()

    def hover_main(self, event):
        """Action handler for mouse movement over the main plot.

        Shows a thumbnail of the underlying image crop next to a scatter point
        while the mouse hovers over it and hides the thumbnail otherwise.

        Args:
            event: Matplotlib mouse event; its pixel position ``x``/``y`` is used
                to decide on which side of the point the thumbnail is placed.
        """
        # Run the hit test only once; it yields (hit flag, details with "ind").
        hit, details = self.line_main.contains(event)
        if hit:
            # Several points may overlap; use the first reported index.
            ind, *_ = details["ind"]

            # Figure size in pixels.
            w, h = self.fig_main.get_size_inches() * self.fig_main.dpi
            # Flip the box offset when the mouse is in the right/top half so the
            # thumbnail stays inside the figure.
            ws = -1 if event.x > w / 2.0 else 1
            hs = -1 if event.y > h / 2.0 else 1
            self.ab.xybox = (self.xybox[0] * ws, self.xybox[1] * hs)

            self.ab.set_visible(True)
            # Anchor the box at the hovered scatter point.
            self.ab.xy = (self.x[ind], self.y[ind])

            # Load the crop that belongs to the hovered point.
            with Image.open(self.data["image_path"][self.gi[ind]]) as source:
                tmp = source.convert("RGB").crop(self.data["box"][ind])
            # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
            tmp.thumbnail(self.tnsize, Image.LANCZOS)
            self.im.set_data(tmp)
            tmp.close()
        else:
            # The mouse is not over a scatter point.
            self.ab.set_visible(False)
        self.fig_main.canvas.draw_idle()

    def click_main(self, event):
        """Action handler for mouse clicks on the main plot.

        Depending on the mouse button and the currently held key, the clicked
        scatter point is stored as thumbnail, shown/saved as single image, shown/
        saved as detailed image, or used as query for the nearest neighbor
        window. Afterwards the point is highlighted with ``self.current_color``.

        Args:
            event: Matplotlib mouse button event.
        """
        # Run the hit test only once; it yields (hit flag, details with "ind").
        hit, details = self.line_main.contains(event)
        if not hit:
            return
        ind, *_ = details["ind"]

        if self.current_pressed_key == "t" and event.button == 1:
            self.store_thumbnail(ind)
        elif self.current_pressed_key == "control" and event.button == 1:
            self.show_single_image(ind, save=True)
        elif self.current_pressed_key == "control" and event.button == 2:
            self.show_full_image(ind, save=True)
        elif event.button == 1:  # left mouse button
            self.show_single_image(ind)
        elif event.button == 2:  # middle mouse button
            self.show_full_image(ind)
        elif event.button == 3:  # right mouse button
            if plt.fignum_exists(2):
                # Nearest neighbor figure is already open: refresh it. The
                # update routine highlights the point itself.
                self.update_nearest_neighbors(ind)
                return

            # Nearest neighbor figure is not open anymore or has not been
            # opened yet.
            self.log.info("Loading nearest neighbors...")
            self.nearest_neighbors = self.get_nearest_neighbors(
                ind, metric=self.distance_metrics[self.dm])
            thumbnails = []
            for neighbor_ind in self.nearest_neighbors:
                path = self.data["image_path"][self.gi[neighbor_ind]]
                with Image.open(path) as source:
                    thumbnails.append(
                        np.asarray(source.crop(self.data["box"][neighbor_ind])))
            # At least one column so that zero requested neighbors (reachable
            # by scrolling down) does not divide by zero.
            columns = max(1, math.ceil(math.sqrt(self.n_neighbors)))
            rows = math.ceil(self.n_neighbors / columns)

            self.fig_nn = plt.figure(num=2, dpi=self.dpi)
            # The canvas-level set_window_title was removed from Matplotlib;
            # the manager-level method is available in old and new releases.
            self.fig_nn.canvas.manager.set_window_title(
                "{} nearest neighbors to selected image".format(
                    self.n_neighbors))
            for p in range(columns * rows):
                ax = self.fig_nn.add_subplot(rows, columns, p + 1)
                ax.set_axis_off()
                if p < len(thumbnails):
                    ax.imshow(thumbnails[p])
            canvas = self.fig_nn.canvas
            canvas.mpl_connect("button_press_event", self.click_nn)
            canvas.mpl_connect("key_press_event", self.key_press)
            canvas.mpl_connect("key_release_event", self.key_release)
            canvas.mpl_connect("scroll_event", self.scroll)
            self.fig_nn.show()

        self.set_color(ind, self.current_color)
        self.flush_colors()

    def click_nn(self, event):
        """Action handler for mouse clicks in the nearest neighbor window.

        A click on one of the neighbor panels triggers the same actions as the
        click handler of the main plot, applied to the segment shown in that
        panel. A right click makes that segment the new query.
        """
        if event.inaxes not in self.fig_nn.axes:
            return

        ind = self.get_ind_nn(event)
        held_key = self.current_pressed_key
        button = event.button

        # Pick the action first, then resolve the segment index exactly once.
        if held_key == "t" and button == 1:
            action, extra = self.store_thumbnail, {}
        elif held_key == "control" and button == 1:
            action, extra = self.show_single_image, {"save": True}
        elif held_key == "control" and button == 2:
            action, extra = self.show_full_image, {"save": True}
        elif button == 1:  # left mouse button
            action, extra = self.show_single_image, {}
        elif button == 2:  # middle mouse button
            action, extra = self.show_full_image, {}
        elif button == 3:  # right mouse button
            action, extra = self.update_nearest_neighbors, {}
        else:
            return

        action(self.nearest_neighbors[ind], **extra)

    def key_press(self, event):
        """Dispatch keyboard shortcuts and remember the key that is held down.

        m: cycle the distance metric, #: cycle the clustering method,
        c: cluster and color by cluster, g: color by ground truth,
        h: color by prediction, b: reset all colors, d: show density.
        """
        key = event.key
        self.log.debug("Key '{}' pressed.".format(key))

        def drop_legend():
            # Remove the legend of the main axes if there is one.
            legend = self.fig_main.axes[0].get_legend()
            if legend is not None:
                legend.remove()

        if key == "m":
            self.dm = (self.dm + 1) % len(self.distance_metrics)
            self.log.info("Changed distance metric to {}".format(
                self.distance_metrics[self.dm]))
        elif key == "#":
            self.cme = (self.cme + 1) % len(self.cluster_methods)
            self.log.info("Changed clustering method to {}".format(
                list(self.cluster_methods.keys())[self.cme]))
        elif key == "c":
            chosen = list(self.cluster_methods.keys())[self.cme]
            self.log.info("Started clustering with {}...".format(chosen))
            self.cluster(method=chosen)
            drop_legend()
            # One viridis color per cluster id.
            palette = cm.get_cmap("viridis", (max(self.clustering) + 1))
            self.basecolors = palette(self.clustering)
            self.flush_colors()
        elif key == "g":
            self.color_gt()
        elif key == "h":
            self.color_pred()
        elif key == "b":
            every_point = list(range(self.basecolors.shape[0]))
            self.set_color(every_point, self.standard_color)
            drop_legend()
            self.flush_colors()
        elif key == "d":
            self.show_density()

        self.current_pressed_key = key

    def key_release(self, event):
        """Forget the currently held key once any key is released."""
        self.current_pressed_key = None
        message = "Key '{}' released.".format(event.key)
        self.log.debug(message)

    def scroll(self, event):
        """Adjust the number of nearest neighbors with the mouse wheel.

        Scrolling up adds one neighbor, scrolling down removes one as long as
        the count is still positive. Works on the main and the nearest
        neighbor plot.
        """
        direction = event.button
        if direction == "up":
            self.n_neighbors += 1
            message = "Increased number of nearest neighbors to {}"
        elif direction == "down" and self.n_neighbors > 0:
            self.n_neighbors -= 1
            message = "Decreased number of nearest neighbors to {}"
        else:
            # Other events, or already at zero neighbors: nothing to do.
            return
        self.log.info(message.format(self.n_neighbors))

    def show_single_image(self, ind, save=False):
        """Displays or saves the full image belonging to a segment.

        The segment is marked with a red bounding box.

        Args:
            ind: Index of the segment.
            save (bool): If True the image is written to
                ``<save_dir>/image_<ind>.jpg`` and the temporary figure is
                closed again; otherwise the figure is shown in a new window.
        """
        self.log.info("{} image...".format("Saving" if save else "Loading"))
        img_box = self.draw_box_on_image(ind)
        # Figures 1 and 2 are used by the main and nearest neighbor plots, so
        # start at 3. ``default`` covers the case of no open figure.
        fig_num = max(3, max(plt.get_fignums(), default=0) + 1)
        fig_tmp = plt.figure(fig_num, dpi=self.dpi)
        ax = fig_tmp.add_subplot(111)
        ax.set_axis_off()
        ax.imshow(np.asarray(img_box), interpolation="nearest")
        if save:
            target = join(self.save_dir, "image_{}.jpg".format(ind))
            # Strip all margins so that only the image itself is saved.
            fig_tmp.subplots_adjust(bottom=0,
                                    left=0,
                                    right=1,
                                    top=1,
                                    hspace=0,
                                    wspace=0)
            ax.margins(0.05, 0.05)
            fig_tmp.gca().xaxis.set_major_locator(plt.NullLocator())
            fig_tmp.gca().yaxis.set_major_locator(plt.NullLocator())
            fig_tmp.savefig(
                target,
                bbox_inches="tight",
                pad_inches=0.0,
            )
            # The figure is never shown in save mode; release it so saved
            # images do not pile up as open figures.
            plt.close(fig_tmp)
            self.log.debug("Saved image to {}".format(target))
        else:
            # The canvas-level set_window_title was removed from Matplotlib;
            # the manager-level method is available in old and new releases.
            fig_tmp.canvas.manager.set_window_title(
                "Dataset: {}, Image index: {}".format(
                    self.data["dataset"][self.gi[ind]],
                    self.data["image_index"][self.gi[ind]],
                ))
            fig_tmp.tight_layout(pad=0.0)
            fig_tmp.show()

    def show_full_image(self, ind, save=False):
        """Displays or saves four panels of the full image belonging to a segment.

        Top left: Entropy heatmap of prediction.
        Top right: Predicted IoU of each segment.
        Bottom left: Source image with ground truth overlay.
        Bottom right: Predicted semantic segmentation.

        The selected segment is marked with a red box in every panel.

        Args:
            ind: Index of the segment.
            save (bool): If True the panels are written to
                ``<save_dir>/detailed_image_<ind>.jpg`` and the temporary figure
                is closed again; otherwise the figure is shown in a new window.
        """
        self.log.info(
            "{} detailed image...".format("Saving" if save else "Loading"))
        box = self.data["box"][ind]
        with Image.open(self.data["image_path"][self.gi[ind]]) as source:
            image = np.asarray(source.convert("RGB"))
        image_index = self.data["image_index"][self.gi[ind]]
        iou_pred = self.data["iou_pred"][self.gi[ind]]
        dataset = self.data["dataset"][self.gi[ind]]
        model_name = self.data["model_name"][self.gi[ind]]

        # The third return value (an image path) is not needed here.
        pred, gt, _ = probs_gt_load(
            image_index,
            input_dir=join(CONFIG.metaseg_io_path, "input", model_name,
                           dataset),
        )
        components = components_load(
            image_index,
            components_dir=join(CONFIG.metaseg_io_path, "components",
                                model_name, dataset),
        )
        e = entropy(pred)
        pred = pred.argmax(2)

        def colorize(label_map, mapping):
            """Replace every class id in ``label_map`` by ``mapping[id][1]``."""
            # One masked assignment per distinct class instead of one Python
            # level lookup per pixel.
            colored = np.zeros(label_map.shape + (3, ), dtype="float64")
            for class_id in np.unique(label_map):
                colored[label_map == class_id] = mapping[class_id][1]
            return colored.reshape(image.shape)

        predc = colorize(pred, self.pred_mapping)
        # Blend weights for prediction, ground truth and entropy panels.
        overlay_factor = [1.0, 0.5, 1.0]

        if self.label_mapping[dataset] is not None:
            gtc = colorize(gt, self.label_mapping[dataset])
        else:
            # Without a label mapping only the plain source image is shown.
            gtc = np.zeros_like(image)
            overlay_factor[1] = 0.0

        img_predc, img_gtc, img_entropy = [
            Image.fromarray(
                np.uint8(arr * overlay_factor[i] + image *
                         (1 - overlay_factor[i])))
            for i, arr in enumerate([predc, gtc,
                                     cm.jet(e)[:, :, :3] * 255.0])
        ]

        img_ioupred = Image.fromarray(
            self.visualize_segments(components, iou_pred))

        images = [img_gtc, img_predc, img_entropy, img_ioupred]

        # Grow the box by the line width, clipped to the image borders.
        box_line_width = 5
        left = max(0, box[0] - box_line_width)
        upper = max(0, box[1] - box_line_width)
        right = min(pred.shape[1], box[2] + box_line_width)
        lower = min(pred.shape[0], box[3] + box_line_width)

        for panel in images:
            ImageDraw.Draw(panel).rectangle([left, upper, right, lower],
                                            outline=(255, 0, 0),
                                            width=box_line_width)

        images = [np.asarray(panel).astype("uint8") for panel in images]

        img_top = np.concatenate(images[2:], axis=1)
        img_bottom = np.concatenate(images[:2], axis=1)

        img_total = np.concatenate((img_top, img_bottom), axis=0)
        # Figures 1 and 2 are used by the main and nearest neighbor plots, so
        # start at 3. ``default`` covers the case of no open figure.
        fig_num = max(3, max(plt.get_fignums(), default=0) + 1)
        fig_tmp = plt.figure(fig_num, dpi=self.dpi)
        # The canvas-level set_window_title was removed from Matplotlib; the
        # manager-level method is available in old and new releases.
        fig_tmp.canvas.manager.set_window_title(
            "Dataset: {}, Image index: {}".format(dataset, image_index))
        ax = fig_tmp.add_subplot(111)
        ax.set_axis_off()
        ax.imshow(img_total, interpolation="nearest")

        if save:
            target = join(self.save_dir, "detailed_image_{}.jpg".format(ind))
            # Strip all margins so that only the panels themselves are saved.
            fig_tmp.subplots_adjust(bottom=0,
                                    left=0,
                                    right=1,
                                    top=1,
                                    hspace=0,
                                    wspace=0)
            ax.margins(0.05, 0.05)
            fig_tmp.gca().xaxis.set_major_locator(plt.NullLocator())
            fig_tmp.gca().yaxis.set_major_locator(plt.NullLocator())
            fig_tmp.savefig(
                target,
                bbox_inches="tight",
                pad_inches=0.0,
            )
            # The figure is never shown in save mode; release it so saved
            # images do not pile up as open figures.
            plt.close(fig_tmp)
            self.log.debug("Saved image to {}".format(target))
        else:
            fig_tmp.tight_layout(pad=0.0)
            fig_tmp.show()

    def store_thumbnail(self, ind):
        """Save only the cropped segment (not the whole image) as a JPEG.

        The file name is built from the ground truth class name of the segment
        ("None" if the dataset has no label mapping) and the plot coordinates
        of the segment.
        """
        image_idx = self.gi[ind]
        crop = Image.open(self.data["image_path"][image_idx]).convert("RGB")
        crop = crop.crop(self.data["box"][ind])

        mapping = self.label_mapping[self.data["dataset"][image_idx]]
        name = "None" if mapping is None else mapping[self.gt[ind]][0]
        # Names ending in a digit lose their last two characters.
        if name[-1].isdigit():
            name = name[:-2]
        name = name.replace(" ", "_")

        file_name = "thumbnail_{}_{:0>2.1f}_{:0>2.1f}.jpg".format(
            name, self.x[ind], self.y[ind])
        target = join(self.save_dir, file_name)
        crop.save(target)
        self.log.debug("Saved thumbnail to {}".format(target))

    def get_nearest_neighbors(self, ind, metric="cos"):
        """Return the indices of the segments closest to segment ``ind``.

        ``metric="euclid"`` uses ``lp_dist`` with ``d=2``; any other value uses
        ``cos_dist``. The first entry of the distance ranking is skipped and the
        following ``self.n_neighbors`` entries are returned.
        """
        query = self.embeddings[ind]
        use_euclid = metric == "euclid"
        if use_euclid:
            dists = self.lp_dist(query, self.embeddings, d=2)
        else:
            dists = self.cos_dist(query, self.embeddings)
        ranking = np.argsort(dists)
        # Position 0 is skipped (presumably the query itself - verify for ties).
        return ranking[1:(self.n_neighbors + 1)]

    def update_nearest_neighbors(self, ind):
        """Refresh the nearest neighbor window for a newly selected segment.

        Recomputes the nearest neighbors of ``ind``, adapts the subplot grid if
        the number of neighbors changed, redraws every panel and highlights the
        query segment in the main plot.

        Args:
            ind: Index of the new query segment.
        """
        self.log.info("Loading nearest neighbors...")
        self.nearest_neighbors = self.get_nearest_neighbors(
            ind, metric=self.distance_metrics[self.dm])
        thumbnails = []
        for neighbor_ind in self.nearest_neighbors:
            b = self.data["box"][neighbor_ind]
            # Box is indexed as (left, upper, right, lower); rows come first
            # when slicing the image array.
            thumbnails.append(
                plt.imread(self.data["image_path"][self.gi[neighbor_ind]])[
                    b[1]:b[3], b[0]:b[2], :])
        n = math.ceil(math.sqrt(len(self.nearest_neighbors)))
        if len(self.fig_nn.axes) != (n**2):
            # number of nearest neighbors has been changed
            # redefine number of subplots in fig_nn
            self.rearrange_axes(n, n)

        axes = self.fig_nn.axes
        for p in range(n**2):
            # Drop whatever the panel showed before so images do not stack up.
            axes[p].clear()
            axes[p].set_axis_off()
            # Bound by the thumbnails actually loaded: fewer than
            # ``self.n_neighbors`` may exist for small data sets.
            if p < len(thumbnails):
                axes[p].imshow(thumbnails[p])

        self.fig_nn.canvas.draw()
        self.set_color(ind, self.current_color)
        self.flush_colors()

    def cluster(self, method="kmeans"):
        """Cluster the nearest neighbor embeddings with the chosen method.

        The number of clusters is requested interactively. For k-means the
        answer "elbow" starts the elbow-plot helper instead. The resulting
        cluster labels are stored in ``self.clustering``.

        Args:
            method (str): Key into ``self.cluster_methods``. Methods that are not
                listed in ``self.methods_with_ncluster_param`` are ignored.

        Raises:
            ValueError: If no valid integer number of clusters was obtained
                (unparsable input, or "elbow" for a method other than k-means).
        """
        if method not in self.methods_with_ncluster_param:
            return

        n_clusters = self.n_cluster_prompt()
        if n_clusters == "elbow" and method == "kmeans":
            n_clusters = self.elbow()

        # The prompt reports problems through string sentinels; stop here with
        # a clear message instead of handing them to the clustering estimator.
        if not isinstance(n_clusters, int):
            raise ValueError(
                "Cannot cluster with method '{}': "
                "invalid number of clusters {!r}.".format(method, n_clusters))

        estimator = self.cluster_methods[method]["main"](
            n_clusters=n_clusters,
            **self.cluster_methods[method]["kwargs"])
        self.clustering = estimator.fit_predict(self.embeddings)

    def elbow(self):
        """Interactively choose the number of k-means clusters via an elbow plot.

        Asks for a range of cluster counts, fits one k-means model per count,
        plots the inertia of each fit in a new figure and returns the number of
        clusters the user enters afterwards.
        """
        low = int(input("Enter the minimum number of clusters: "))
        high = int(input("Enter the maximum number of clusters: "))
        candidates = range(low, high + 1)

        models = [KMeans(n_clusters=k) for k in candidates]
        fitted = [
            model.fit(self.embeddings)
            for model in tqdm(models, total=len(models))
        ]
        inertias = [model.inertia_ for model in fitted]

        # Use a figure number above the ones that are already open.
        fig_number = max(3, max(plt.get_fignums()) + 1)
        fig_elbow = plt.figure(fig_number)
        fig_elbow.add_subplot(111).plot(candidates, inertias)
        fig_elbow.show()
        return int(input("Enter number of clusters: "))

    def n_cluster_prompt(self):
        """Ask the user for the number of clusters.

        Returns:
            The string "elbow" if that was entered, the string "error" if the
            input is not an integer, otherwise the entered integer.

        Raises:
            ValueError: If the entered integer is not greater than 1.
        """
        answer = input("Enter the number of clusters: ")
        if answer == "elbow":
            return answer

        try:
            count = int(answer)
        except ValueError:
            self.log.error(
                "Input should be an int or 'elbow' but received {}!".format(
                    answer))
            return "error"

        if count > 1:
            return count
        raise ValueError("Number of clusters should be greater than 1!")

    def rearrange_axes(self, nrows, ncols):
        """Resize the subplot grid of the nearest neighbor window.

        Needed when the number of nearest neighbors changed and a new query
        segment was chosen: surplus axes are removed from the end, the remaining
        ones are moved to their slot in the ``nrows`` x ``ncols`` grid, and
        missing axes are appended with their axis decoration switched off.
        """
        # NOTE(review): Axes.change_geometry is not available in recent
        # Matplotlib releases - confirm the pinned version.
        figure = self.fig_nn
        wanted = nrows * ncols
        existing = len(figure.axes)

        # Shrinking: drop the axes beyond the new grid size, last one first.
        for surplus in reversed(figure.axes[wanted:]):
            figure.delaxes(surplus)

        # Move every remaining axes to its (1-based) position in the new grid.
        for position, ax in enumerate(figure.axes, start=1):
            ax.change_geometry(nrows, ncols, position)

        # Growing: fill the still empty grid positions.
        for position in range(existing + 1, wanted + 1):
            figure.add_subplot(nrows, ncols, position).set_axis_off()

    def draw_box_on_image(self, ind):
        """Return the source image of segment ``ind`` with its box drawn in red.

        The image at ``self.data["image_path"][self.gi[ind]]`` is opened and
        converted to RGB. The box ``self.data["box"][ind]`` is enlarged by the
        outline width on every side, clipped to the image borders and drawn as
        a red rectangle.
        """
        line_width = 5
        path = self.data["image_path"][self.gi[ind]]
        canvas = Image.open(path).convert("RGB")
        img_width, img_height = canvas.size

        x0, y0, x1, y1 = self.data["box"][ind]
        corners = [
            max(0, x0 - line_width),
            max(0, y0 - line_width),
            min(img_width, x1 + line_width),
            min(img_height, y1 + line_width),
        ]
        ImageDraw.Draw(canvas).rectangle(corners,
                                         outline=(255, 0, 0),
                                         width=line_width)
        return canvas

    @staticmethod
    def visualize_segments(comp, metric):
        """Helper function for generation of the four panels in the detailed
        image function."""
        r = np.asarray(metric)
        r = 1 - 0.5 * r
        g = np.asarray(metric)
        b = 0.3 + 0.35 * np.asarray(metric)

        r = np.concatenate((r, np.asarray([0, 1])))
        g = np.concatenate((g, np.asarray([0, 1])))
        b = np.concatenate((b, np.asarray([0, 1])))

        components = np.asarray(comp.copy(), dtype="int16")
        components[components < 0] = len(r) - 1
        components[components == 0] = len(r)

        img = np.zeros(components.shape + (3, ))
        x, y = np.meshgrid(np.arange(img.shape[0]), np.arange(img.shape[1]))
        x = x.reshape(-1)
        y = y.reshape(-1)

        img[x, y, 0] = r[components[x, y] - 1]
        img[x, y, 1] = g[components[x, y] - 1]
        img[x, y, 2] = b[components[x, y] - 1]

        img = np.asarray(255 * img).astype("uint8")

        return img

    @staticmethod
    def lp_dist(point, all_points, d=2):
        """Calculates the L_p distance from a point to a collection of points.
        Used for retrieval."""
        return ((all_points - point)**d).sum(1)**(1.0 / d)

    @staticmethod
    def cos_dist(point, all_points):
        """Calculates the cosine distance from a point to a collection of points.
        Used for retrieval."""
        return 1 - ((point * all_points).sum(1) /
                    (norm(point) * norm(all_points, axis=1)))

    @staticmethod
    def get_gridsize(fig):
        """Helper function for the nearest neighbor plot."""
        return fig.axes[0].get_subplotspec().get_gridspec().get_geometry()

    def get_ind_nn(self, event):
        """Return the row-major index of the subplot in which ``event`` occurred.

        Helper function for the nearest neighbor plot.

        Args:
            event: Event whose ``inaxes`` is a subplot of ``self.fig_nn``.

        Returns:
            ``row * ncols + col`` of the subplot, counted from 0.
        """
        _, ncols = self.get_gridsize(self.fig_nn)
        ax = event.inaxes
        spec = ax.get_subplotspec()
        try:
            row, col = spec.rowspan.start, spec.colspan.start
        except AttributeError:
            # older matplotlib: subplot specs have no rowspan/colspan yet, but
            # the axes still carry the since removed rowNum/colNum attributes
            row, col = ax.rowNum, ax.colNum
        return (row * ncols) + col

    def color_gt(self):
        """When called colors the scatter in the main plot according to the ground
        truth colors.

        Nothing happens unless the label mapping of every point's dataset is
        not None. Otherwise ``self.basecolors`` is replaced by one RGBA row
        per point (mapped color channels divided by 255, alpha 1.0), a legend
        with one patch per distinct (name, color) pair of every dataset is
        attached to the first axes of the main figure and the colors are
        flushed to the plot.
        """
        # only recolor if a label mapping exists for the dataset of every point
        if all(self.label_mapping[self.data["dataset"][self.gi[ind]]]
               is not None for ind in range(self.basecolors.shape[0])):
            # entry [1] of a label mapping item is the color; it is scaled by
            # 1/255 and extended by an alpha value of 1.0
            self.basecolors = np.stack([
                tuple(i / 255.0
                      for i in self.label_mapping[self.data["dataset"][
                          self.gi[ind]]][self.gt[ind]][1]) + (1.0, )
                for ind in range(self.basecolors.shape[0])
            ])
            legend_elements = []
            for d in np.unique(self.data["dataset"]).flatten():
                # ground truth class ids that occur in dataset d
                cls = np.unique(
                    self.gt[np.array(self.data["dataset"])[self.gi] == d])
                # the set collapses classes sharing the same (name, color)
                # pair -- presumably the colors are tuples, as they have to be
                # hashable here
                cls = list({(self.label_mapping[d][cl][0],
                             self.label_mapping[d][cl][1])
                            for cl in cls})
                names = np.array([i[0] for i in cls])
                cols = np.array([i[1] for i in cls])
                # NOTE: the inner generator variable ``i`` shadows the outer
                # index ``i``; ``cols[i]`` is still evaluated with the outer
                # one. A name ending in a digit is cut at its last space.
                legend_elements += [
                    Patch(
                        color=tuple(i / 255.0 for i in cols[i]) + (1.0, ),
                        label=names[i] if not names[i][-1].isdigit() else
                        names[i][:names[i].rfind(" ")],
                    ) for i in range(names.shape[0])
                ]
            self.fig_main.axes[0].legend(
                loc="upper left",
                handles=legend_elements,
                ncol=8,
                bbox_to_anchor=(0, 1.1),
            )
            self.flush_colors()

    def color_pred(self):
        """Color the scatter in the main plot according to the predicted class.

        ``self.basecolors`` is replaced by one RGBA row per point, a legend
        with one patch per distinct predicted class is attached to the first
        axes of the main figure and the colors are flushed to the plot.
        """

        def to_rgba(class_id):
            """Divide the mapped color channels by 255 and append alpha 1.0."""
            rgb = self.pred_mapping[class_id][1]
            return tuple(channel / 255.0 for channel in rgb) + (1.0, )

        n_points = self.basecolors.shape[0]
        self.basecolors = np.stack(
            [to_rgba(self.pred[ind]) for ind in range(n_points)])

        handles = []
        for class_id in np.unique(self.pred).flatten():
            handles.append(
                Patch(
                    color=to_rgba(class_id),
                    label=self.pred_mapping[class_id][0],
                ))

        main_axes = self.fig_main.axes[0]
        main_axes.legend(loc="upper left",
                         handles=handles,
                         ncol=8,
                         bbox_to_anchor=(0, 1.1))
        self.flush_colors()

    def show_density(self):
        """Draw a grey density background behind the main scatter plot.

        The callable returned by ``estimate_kernel_density`` for
        ``self.data["plot_embeddings"]`` is evaluated on a grid (step 1) that
        spans the padded range of ``self.x`` and ``self.y``. Grid values below
        the 0.55 quantile are masked and drawn white; the rest is drawn with
        the lower 75 % of the "Greys" colormap on the first axes of the main
        figure.
        """
        embedding_kde = estimate_kernel_density(self.data["plot_embeddings"])

        # pad the data range: the factors move each limit away from the data
        # when it lies on its own side of zero (min < 0, max > 0)
        xmin = self.x.min()
        xmin = xmin * 1.3 if xmin < 0 else xmin * 0.8
        xmax = self.x.max()
        xmax = xmax * 1.3 if xmax > 0 else xmax * 0.8

        ymin = self.y.min()
        ymin = ymin * 1.3 if ymin < 0 else ymin * 0.8
        ymax = self.y.max()
        ymax = ymax * 1.3 if ymax > 0 else ymax * 0.8

        grid_x, grid_y = np.mgrid[xmin:xmax, ymin:ymax]
        grid_z = embedding_kde(np.vstack([grid_x.flatten(), grid_y.flatten()]))

        # truncated copy of "Greys" so that the background never gets black
        colmap = plt.get_cmap("Greys")
        colmap = colors.LinearSegmentedColormap.from_list(
            "trunc({n},{a:.2f},{b:.2f})".format(n=colmap.name, a=0.0, b=0.75),
            colmap(np.linspace(0.0, 0.75, 256)),
        )
        # ``np.nan`` instead of the ``np.NaN`` alias, which NumPy 2.0 removed
        grid_z[grid_z < np.quantile(grid_z, 0.55)] = np.nan
        colmap.set_bad("white")
        self.fig_main.axes[0].pcolormesh(
            grid_x,
            grid_y,
            grid_z.reshape(grid_x.shape),
            cmap=colmap,
            shading="gouraud",
            zorder=1,
        )
        self.flush_colors()

    def set_color(self, ind, color):
        """Assign ``color`` to the row(s) ``ind`` of ``self.basecolors``.

        Helper function to set the color of a segment with index ``ind``.
        """
        palette = self.basecolors
        palette[ind, ...] = color

    def change_color(self, old_color, new_color):
        """Replace every row of ``self.basecolors`` equal to ``old_color``.

        Helper function to change a specific color to a different one.
        """
        matches = (self.basecolors == old_color).all(axis=1)
        self.basecolors[matches] = new_color

    def flush_colors(self):
        """Apply ``self.basecolors`` to the main scatter plot and redraw it."""
        scatter = self.line_main
        scatter.set_color(self.basecolors)
        canvas = self.fig_main.canvas
        canvas.draw()
Ejemplo n.º 21
0
def visualize():
    """Show a world map with three site markers that reveal an image on hover.

    Every site is drawn as a red marker on a Gall projection of the globe.
    An image annotation is attached to each marker; it stays hidden until the
    mouse is moved over the marker.
    """
    fig = plt.figure()
    ax = plt.axes()

    themap = Basemap(projection='gall',
                     llcrnrlon=-180,
                     llcrnrlat=-90,
                     urcrnrlon=180,
                     urcrnrlat=90,
                     resolution='l',
                     area_thresh=100000.0)

    themap.drawcoastlines()
    themap.drawcountries()
    themap.fillcontinents(color='gainsboro')
    themap.drawmapboundary(fill_color='steelblue')

    # (longitude, latitude, image file, arrow properties) of every site
    sites = [
        (0, 44, "/Users/rogpaxton/Galvanize/SiteMetrics/plot44.0.png",
         dict(arrowstyle="->")),
        (90, 44, "/Users/rogpaxton/Galvanize/SiteMetrics/plot45.0.png",
         None),
        (120, -34, "/Users/rogpaxton/Galvanize/SiteMetrics/plot46.0.png",
         dict(arrowstyle="->",
              connectionstyle="angle,angleA=0,angleB=90,rad=3")),
    ]

    points_with_annotation = []
    for lon, lat, image_path, arrowprops in sites:
        # project lon/lat into the map's coordinate system
        x, y = themap(lon, lat)
        point, = themap.plot(x, y, 'o', color='Red', markersize=10)

        fn = get_sample_data(image_path, asfileobj=False)
        imagebox = OffsetImage(read_png(fn), zoom=0.5)

        # Anchor the box at the projected marker position: the axes' data
        # coordinates are map coordinates, not degrees.
        annotation = AnnotationBbox(imagebox, (x, y),
                                    xybox=(50, 150),
                                    xycoords='data',
                                    boxcoords="offset points",
                                    pad=0.5,
                                    arrowprops=arrowprops)
        annotation.set_visible(False)
        # add the artist exactly once instead of on every mouse move
        ax.add_artist(annotation)

        points_with_annotation.append((point, annotation))

    def on_move(event):
        """Show the annotation of the hovered marker and hide the others."""
        visibility_changed = False
        for point, annotation in points_with_annotation:
            should_be_visible = bool(point.contains(event)[0])

            if should_be_visible != annotation.get_visible():
                visibility_changed = True
                annotation.set_visible(should_be_visible)

        if visibility_changed:
            plt.draw()

    fig.canvas.mpl_connect('motion_notify_event', on_move)

    plt.show()
Ejemplo n.º 22
0
def visualize_clusters(x, y, imgs, colors=None):
    """Create an interactive plot visualizing the clusters.

    Hovering over a point shows the corresponding face crop.

    Args:
        x: x coordinates of the points, indexable by point index.
        y: y coordinates of the points, indexable by point index.
        imgs: One image per point, in the same order as ``x`` and ``y``.
        colors: Optional sequence of point colors passed to ``ax.scatter``.
    """
    # create figure and plot scatter
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.tick_params(axis='x',
                   which='both',
                   bottom=False,
                   top=False,
                   labelbottom=False)
    ax.tick_params(axis='y',
                   which='both',
                   left=False,
                   right=False,
                   labelleft=False)

    # explicit check: the truth value of a numpy array is ambiguous
    has_colors = colors is not None and len(colors) > 0
    if has_colors:
        # marker-less line that is only used for hit testing
        line, = ax.plot(x, y, ls="")
        ax.scatter(x, y, c=colors)
    else:
        line, = ax.plot(x, y, ls="", marker="o")

    # create the annotations box
    im = OffsetImage(imgs[0], zoom=1)
    xybox = (50., 50.)
    ab = AnnotationBbox(im, (0, 0),
                        xybox=xybox,
                        xycoords='data',
                        boxcoords="offset points",
                        pad=0.3,
                        arrowprops=dict(arrowstyle="->"))
    # add it to the axes and make it invisible
    ax.add_artist(ab)
    ab.set_visible(False)

    def hover(event):
        """Show the image of the point under the cursor, hide it otherwise."""
        hit, details = line.contains(event)
        if hit:
            # several points can lie under the cursor; take the first one
            # instead of unpacking exactly one index
            ind = details["ind"][0]
            # get the figure size
            w, h = fig.get_size_inches() * fig.dpi
            ws = (event.x > w / 2.) * -1 + (event.x <= w / 2.)
            hs = (event.y > h / 2.) * -1 + (event.y <= h / 2.)
            # if event occurs in the top or right quadrant of the figure,
            # change the annotation box position relative to mouse.
            ab.xybox = (xybox[0] * ws, xybox[1] * hs)
            # make annotation box visible
            ab.set_visible(True)
            # place it at the position of the hovered scatter point
            ab.xy = (x[ind], y[ind])
            # set the image corresponding to that point
            im.set_data(imgs[ind])
        else:
            # the mouse is not over a scatter point
            ab.set_visible(False)
        fig.canvas.draw_idle()

    # add callback for mouse moves
    fig.canvas.mpl_connect('motion_notify_event', hover)
    plt.show()
Ejemplo n.º 23
0
def food_map(store_list):
    """Show the stores on a map image with an info box on hover.

    Every store is drawn as a marker at ``(store.lon, store.lat)`` on top of
    the map image. Hovering over a marker shows a box with the store's name,
    area, types, average ranking and price range.

    Args:
        store_list: Stores with the attributes ``lon``, ``lat``, ``name``,
            ``area``, ``types`` (iterable of strings), ``avg_ranking``,
            ``lowerbound`` and ``upperbound``. The stores are not modified.

    Returns:
        The string ``"finish"`` once the plot window has been closed.
    """
    # x coordinates: longitude, y coordinates: latitude
    lons = np.array([store.lon for store in store_list])
    lats = np.array([store.lat for store in store_list])
    x, y = (lons, lats)

    # read the map background image
    img = plt.imread("/Users/yuchiaching/Desktop/完稿壓縮正確尺寸/地圖+畫框.png")
    # create the figure and the axes
    fig, ax = plt.subplots(figsize=(16, 10), dpi=70)
    # show the background image and fix the coordinate range of the axes
    # (approach taken from http://hk.uwenku.com/question/p-qgzudzqa-bc.html)
    ax.imshow(
        img,
        extent=[121.52209923089963, 121.55634026412044, 25.00897, 25.02854])
    # hide the axes
    plt.axis('off')

    # mark the store locations
    line, = ax.plot(x, y, ls="", marker='o', color='#fa4a0c')

    # one info text per store; built without touching ``store.types`` so that
    # the function can be called again with the same stores
    message_list = []
    for store in store_list:
        types_text = '、'.join(store.types)
        message_list.append('\n'.join([
            store.name, '[' + str(store.area) + '] | ' + types_text,
            '★ ' + str(store.avg_ranking) + ' / 5.0 | ' + 'NT$' +
            str(store.lowerbound) + ' ~ NT$' + str(store.upperbound)
        ]))

    if message_list:
        # A single info box is enough: it is moved to the hovered marker and
        # its text is replaced on every hover.
        message = TextArea(message_list[0],
                           minimumdescent=False,
                           textprops=dict(fontproperties=prop))
        xybox = (50, 50)
        ab = AnnotationBbox(message, (x[0], y[0]),
                            xybox=xybox,
                            xycoords='data',
                            boxcoords="offset points",
                            pad=0.3,
                            arrowprops=dict(arrowstyle="->"))
        # put it on the axes, hidden until a marker is hovered
        ax.add_artist(ab)
        ab.set_visible(False)

        def hover(event):
            """Show the info box of the marker under the cursor."""
            hit, details = line.contains(event)
            if hit:
                # several markers can lie under the cursor; take the first
                # one instead of unpacking exactly one index
                ind = details["ind"][0]
                # get the figure size
                w, h = fig.get_size_inches() * fig.dpi
                ws = (event.x > w / 2.) * -1 + (event.x <= w / 2.)
                hs = (event.y > h / 2.) * -1 + (event.y <= h / 2.)
                # if event occurs in the top or right quadrant of the figure,
                # change the annotation box position relative to mouse.
                ab.xybox = (xybox[0] * ws, xybox[1] * hs)
                # make annotation box visible
                ab.set_visible(True)
                # place it at the position of the hovered marker
                ab.xy = (x[ind], y[ind])
                # show the text that belongs to that marker
                message.set_text(message_list[ind])
            else:
                # the mouse is not over a marker
                ab.set_visible(False)
            fig.canvas.draw_idle()

        fig.canvas.mpl_connect('motion_notify_event', hover)

    plt.show()

    return "finish"
Ejemplo n.º 24
0
    def final_plot(self):
        """Show the node weights as an image with an input image on hover.

        The first three weight components of every node are normalized and
        drawn as colors. Hovering over a node shows the (resized) image of the
        input vector that ``self.most_like_weight`` returns for that node.
        """
        w = self.nodes_weights_w
        h = self.nodes_weights_h

        # fixed thumbnail size (also used as the arrow offset) so that
        # neither the image nor the arrow gets too small
        fixed_h = fixed_w = 120

        # one resized image per input vector
        n_inputs = len(self.input_vectors)
        images_array = np.empty((n_inputs, fixed_w, fixed_h, 3))
        for i in range(n_inputs):
            picture = cv2.imread(self.input_vectors_images_path[i])
            images_array[i] = np.array(
                cv2.resize(picture, (fixed_h, fixed_w)))

        # Lattice coordinates as column vectors in row-major order:
        # x holds the column index and y the row index of every node.
        # NOTE: the original author states that a square n x n lattice is
        # expected here.
        row_idx, col_idx = np.indices((w, h))
        x = col_idx.reshape(w * h, 1).astype(float)
        y = row_idx.reshape(w * h, 1).astype(float)

        fig = plt.figure()
        # a single 1x1 subplot
        ax = fig.add_subplot(111)
        # invisible line that is only used for hit testing
        line = Line2D(x, y, ls="", marker="")
        ax.add_line(line)

        # the colors are built from the first three weight components only;
        # origin='upper' is important
        a = self.nodes_weights[:, :, :3]
        normalized_weights = (255 * (a - np.max(a)) /
                              -np.ptp(a)).astype(np.uint8)
        ax.imshow(np.invert(normalized_weights), origin='upper')

        # create the annotations box
        im = OffsetImage(images_array[0, :, :, :], zoom=1)
        xybox = (fixed_h, fixed_w)  # also the length of the arrow
        ab = AnnotationBbox(im, (0, 0),
                            xybox=xybox,
                            xycoords='data',
                            boxcoords="offset points",
                            pad=0.3,
                            arrowprops=dict(arrowstyle="->"))
        # add it to the axes and make it invisible
        ax.add_artist(ab)
        ab.set_visible(False)

        def hover(event):
            """Show the image that belongs to the node under the cursor."""
            hit, details = line.contains(event)
            if hit:
                # first node under the cursor
                ind = details["ind"][0]
                # get the figure size
                ww, hh = fig.get_size_inches() * fig.dpi
                ws = (event.x > ww / 2.) * -1 + (event.x <= ww / 2.)
                hs = (event.y > hh / 2.) * -1 + (event.y <= hh / 2.)
                # if event occurs in the top or right quadrant of the figure,
                # change the annotation box position relative to mouse.
                ab.xybox = (xybox[0] * ws, xybox[1] * hs)
                # make annotation box visible
                ab.set_visible(True)
                # place it at the position of the hovered node
                ab.xy = (x[ind], y[ind])

                # lattice position of the node from its row-major index
                x_index, y_index = divmod(int(ind), int(h))
                # index of the input image that labels this node
                image_index = self.most_like_weight(x_index, y_index)
                thumb = images_array[image_index, :, :]
                thumb = (255 * (thumb - np.max(thumb)) /
                         -np.ptp(thumb)).astype(np.uint8)

                # the invert and the channel swap [2, 1, 0] are both needed
                # to get from the opencv to the matplotlib representation
                im.set_data(np.invert(thumb)[:, :, [2, 1, 0]])
            else:
                # the mouse is not over a node
                ab.set_visible(False)

            fig.canvas.draw_idle()

        # add callback for mouse moves
        fig.canvas.mpl_connect('motion_notify_event', hover)
        plt.show()
# demo data: 100 points with random y values and one 10x10 image per point
x = np.arange(100)
y = np.random.rand(len(x))
arr = np.zeros((len(x), 10, 10))

# create the figure and draw the scatter plot
fig = plt.figure()
ax = fig.add_subplot(111)
line, = ax.plot(x, y, ls="", marker="o")

# create the annotation box
im = OffsetImage(arr[0, :, :], zoom=5)
xybox = (50., 50.)
ab = AnnotationBbox(im, (0, 0), xybox=xybox, xycoords='data', boxcoords="offset points",  pad=0.3,  arrowprops=dict(arrowstyle="-"))
# add it to the axes and make it invisible
ax.add_artist(ab)
ab.set_visible(False)


def hover(event):
    """Show the annotation box at the scatter point under the cursor."""
    # if the cursor is over a point of the plot
    hit, details = line.contains(event)
    if hit:
        # index of the point; several points can lie under the cursor at
        # once, so take the first one instead of unpacking exactly one
        ind = details["ind"][0]
        # figure size in pixels
        w, h = fig.get_size_inches()*fig.dpi
        ws = (event.x > w/2.)*-1 + (event.x <= w/2.)
        hs = (event.y > h/2.)*-1 + (event.y <= h/2.)
        # if the cursor is in the top or right half of the figure, flip the
        # offset of the annotation box
        ab.xybox = (xybox[0]*ws, xybox[1]*hs)
        ab.set_visible(True)
        ab.xy = (x[ind], y[ind])
Ejemplo n.º 26
0
def main(args):
    """Plot a low-dimensional projection of trajectory embeddings.

    The embeddings of all embodiments are loaded, subsampled to a common
    sequence length, reduced to ``args.ndims`` dimensions and drawn as one
    scatter per trajectory. Hovering over a point shows the frame that belongs
    to it. The figure is saved below ``tmp/plots/`` and then shown.

    Args:
        args: Namespace with the attributes ``embs_path``, ``num_traj``,
            ``l2_normalize``, ``reducer`` ("umap", "tsne" or "pca") and
            ``ndims``.

    Raises:
        ValueError: If ``args.reducer`` is not one of the supported names.
    """
    base_colors = ["b", "g", "r", "c", "m", "y", "k"]
    embodiment_names = [
        "two_hands_two_fingers",
        "tongs",
        "rms",
        # "ski_gloves",
        # "quick_grip",
        "one_hand_two_fingers",
        "one_hand_five_fingers",
        # "double_quick_grip",
        # "crab",
        # "quick_grasp",
    ]
    N = len(embodiment_names)

    embs, frames, pos_neg_labels = get_embs(args.embs_path, embodiment_names,
                                            args.num_traj)

    embeddings, rgb_frames, labels = [], [], []
    for name in embodiment_names:
        embeddings.extend(embs[name])
        rgb_frames.extend(frames[name])
        labels.extend(pos_neg_labels[name])
    if args.l2_normalize:
        embeddings = [x / (np.linalg.norm(x) + 1e-7) for x in embeddings]

    # figure out the min embedding length
    min_len = min(len(emb) for emb in embeddings)
    print(f"min seq len: {min_len}")

    # subsample embedding sequences to make them
    # the same long since TSNE expects a tensor
    # of fixed sequence length
    embs, frames = [], []
    for emb, rgb in zip(embeddings, rgb_frames):
        idxs = get_idxs(0, len(emb), min_len)
        embs.append(emb[idxs])
        frames.append(rgb[idxs])
    embs = np.stack(embs)
    frames = np.stack(frames)

    # flatten embeddings (N, 128)
    num_vids, num_frames, num_feats = embs.shape
    embs_flat = embs.reshape(-1, num_feats)

    # dimensionality reduction
    if args.reducer == "umap":
        reducer = umap.UMAP(n_components=args.ndims, random_state=0)
    elif args.reducer == "tsne":
        reducer = TSNE(n_components=args.ndims, n_jobs=-1, random_state=0)
    elif args.reducer == "pca":
        reducer = PCA(n_components=args.ndims, random_state=0)
    else:
        raise ValueError(f"{args.reducer} is not a valid reducer.")

    embs_2d = reducer.fit_transform(embs_flat)

    # TODO: add variance calculation

    # subsample data for less cluttered visualization
    idxs = np.arange(args.num_traj * N)
    # rows of the flattened embeddings that belong to the selected videos
    mask = (idxs[:, None] * min_len + np.arange(min_len)).reshape(-1)
    images = frames[idxs].reshape(-1, *frames.shape[2:])
    embs = embs_2d[mask]

    # resize frames
    frames = []
    for img in images:
        lo, hi = img.min(), img.max()
        span = hi - lo
        # a constant frame has no dynamic range: map it to black instead of
        # dividing by zero
        img = (img - lo) / span if span > 0 else np.zeros(img.shape)
        im = Image.fromarray((img * 255).astype(np.uint8))
        # LANCZOS is the filter formerly also available as ANTIALIAS, a name
        # that newer Pillow releases no longer provide
        im.thumbnail((IMG_HEIGHT, IMG_WIDTH), Image.LANCZOS)
        frames.append(np.asarray(im))
    frames = np.stack(frames)

    label_names, colors = [], []
    for i, name in enumerate(embodiment_names):
        label_names.extend([name] * args.num_traj)
        colors.extend([base_colors[i]] * args.num_traj)

    # create figure
    fig = plt.figure()
    if args.ndims == 2:
        ax = fig.add_subplot(111)
    else:
        ax = fig.add_subplot(111, projection="3d")

    xybox = (IMG_HEIGHT, IMG_WIDTH)
    n_axes = 3 if args.ndims == 3 else 2
    lines, imgz, abz = [], [], []
    for i in range(len(idxs)):
        # points of trajectory i, one array per plotted dimension
        traj = embs[i * min_len:(i + 1) * min_len]
        line = ax.scatter(
            *(traj[:, k] for k in range(n_axes)),
            marker="o" if labels[i] else "x",
            s=15,
            c=colors[i],
            label=label_names[i],
        )
        lines.append(line)

        # create annotation box
        im = OffsetImage(frames[i * min_len], zoom=5)
        imgz.append(im)
        ab = AnnotationBbox(
            im,
            (0, 0),
            xybox=xybox,
            xycoords="data",
            boxcoords="offset points",
            pad=0.3,
            arrowprops=dict(arrowstyle="->"),
        )
        abz.append(ab)

        # add annotation box to axes and make it invisible
        ax.add_artist(ab)
        ab.set_visible(False)

    def hover(event):
        """Show the frame of the hovered point of every trajectory."""
        for j, line in enumerate(lines):
            hit, details = line.contains(event)
            if hit:
                ind = details["ind"][0]
                w, h = fig.get_size_inches() * fig.dpi
                ws = (event.x > w / 2.0) * -1 + (event.x <= w / 2.0)
                hs = (event.y > h / 2.0) * -1 + (event.y <= h / 2.0)
                abz[j].xybox = (xybox[0] * ws, xybox[1] * hs)
                abz[j].set_visible(True)
                abz[j].xy = (
                    embs[j * min_len + ind, 0],
                    embs[j * min_len + ind, 1],
                )
                imgz[j].set_data(frames[j * min_len + ind])
            else:
                abz[j].set_visible(False)
        fig.canvas.draw_idle()

    # One callback handles all trajectories; registering it once per
    # trajectory would repeat the same work on every mouse move.
    fig.canvas.mpl_connect("motion_notify_event", hover)

    legend_without_duplicate_labels(ax)
    emb_names = "_".join(embodiment_names)
    task_name = args.embs_path.split("/")[-2]
    name = "{}_{}_{}_{}_{}dim.png".format(
        task_name,
        emb_names,
        args.num_traj,
        args.reducer,
        args.ndims,
    )
    plt.savefig(osp.join("tmp/plots/", name), format="png", dpi=300)
    plt.show()