Beispiel #1
0
class ZoomPlot():

    def __init__(self, pnts):
        self.fig = plt.figure(figsize=(15,9))
        self.ax = self.fig.add_subplot(111)

        self.days = pnts['days']
        self.lons = pnts['lons']
        self.lats = pnts['lats']
        self.dirs = pnts['dirs']
        self.coms = pnts['coms']

        self.bnds = self.bnds_strt = [-58, 80, -180, 180]
        self.resolution = 'c'

        # add callback for mouse clicks
        self.fig.canvas.mpl_connect('button_press_event', self.onclick)

        self.plot_map()

    def plot_map(self):
        self.map = Basemap(projection='merc',llcrnrlat=self.bnds[0],urcrnrlat=self.bnds[1],
                      llcrnrlon=self.bnds[2],urcrnrlon=self.bnds[3],resolution=self.resolution)

        self.map.drawcoastlines()
        self.map.drawmapboundary(fill_color='cornflowerblue')
        self.map.fillcontinents(color='lightgreen', lake_color='aqua')
        self.map.drawcountries()
        self.map.drawstates()

        self.plot_points()

        self.fig.canvas.draw()

        self.zoomcall = self.ax.callbacks.connect('ylim_changed', self.onzoom)

    def onzoom(self, axes):
        #print('zoom triggered')
        self.ax.patches.clear()
        self.ax.collections.clear()
        self.ax.callbacks.disconnect(self.zoomcall)

        x1, y1 = self.map(self.ax.get_xlim()[0], self.ax.get_ylim()[0], inverse = True)
        x2, y2 = self.map(self.ax.get_xlim()[1], self.ax.get_ylim()[1], inverse = True)
        self.bnds = [y1, y2, x1, x2]

        # reset zoom to home (workaround for unidentified error when you press the home button)
        if any([a/b > 1 for a,b in zip(self.bnds,self.bnds_strt)]):
            self.bnds = self.bnds_strt # reset map boundaryies
            self.ax.lines.clear() # reset points
            self.ab.set_visible(False) # hide picture if visible

        # change map resolution based on zoom level
        zoom_set = max(abs(self.bnds[0]-self.bnds[1]),abs(self.bnds[2]-self.bnds[3]))
        if zoom_set < 30 and zoom_set >= 3:
            self.resolution = 'l'
            #print('   --- low resolution')
        elif zoom_set < 3:
            self.resolution = 'i'
            #print('   --- intermeditate resolution')
        else:
            self.resolution = 'c'
            #print('   --- coarse resolution')

        self.plot_map()

    def plot_points(self):
        self.x, self.y = self.map(self.lons, self.lats)
        self.line, = self.map.plot(self.x, self.y, color='darkmagenta', linestyle='none', marker='o', markeredgecolor='gold')

        # create the annotations box
        self.pic = mpimg.imread('pics\\profpic.png') # just to set up variables, will change later
        self.im = OffsetImage(self.pic)
        self.xybox = (50., 50.)
        self.ab = AnnotationBbox(self.im, (0,0), xybox=self.xybox, xycoords='data',
                boxcoords="offset points",  pad=0.3,  arrowprops=dict(arrowstyle="->"))
        # add it to the axes and make it invisible
        self.ax.add_artist(self.ab)
        self.ab.set_visible(False)

    def onclick(self, event): # if you click on a data point
        if self.line.contains(event)[0]:
            # find out the index within the array from the event
            try:
                ind, = self.line.contains(event)[1]["ind"]
            except ValueError:
                self.ax.text(0.5, 0.5, 'Please zoom in!', fontsize=24,

                             ha='center', va='center', transform=self.ax.transAxes, weight='bold')
                self.fig.canvas.draw()
                time.sleep(0.7)
                self.ax.texts.clear()
            else:
                # get the figure size
                w,h = self.fig.get_size_inches()*self.fig.dpi
                ws = (event.x > w/2.)*-1 + (event.x <= w/2.)
                hs = (event.y > h/2.)*-1 + (event.y <= h/2.)
                # if event occurs in the top or right quadrant of the figure,
                # change the annotation box position relative to mouse.
                self.ab.xybox = (self.xybox[0]*ws, self.xybox[1]*hs)
                # make annotation box visible
                self.ab.set_visible(True)
                # place it at the position of the hovered scatter point
                self.ab.xy = (self.x[ind], self.y[ind])
                # set the image corresponding to that point
                dir = self.dirs[ind]
                self.im.set_data(mpimg.imread(dir))
                # change zoom of the image to deal with different file sizes
                picsize = max(mpimg.imread(dir).shape)
                self.im.set_zoom(0.4*(1000/picsize)) # optimum: zoom = 0.4, picsize = 1000
        else:
            #if you didn't click on a data point
            self.ab.set_visible(False)
        self.fig.canvas.draw_idle()
Beispiel #2
0
class Discovery(object):
    def __init__(
            self,
            embeddings_file="embeddings.p",
            distance_metric="euclid",
            method="TSNE",
            embedding_size=2,
            overwrite_embeddings=False,
            n_jobs=10,
            dpi=300,
            main_plot_args={},
            tsne_args={},
            save_dir=join(CONFIG.metaseg_io_path, "vis_embeddings"),
    ):
        """Loads the embedding files, computes the dimensionality reductions and calls
        the initilization of the main plot.

        Args:
            embeddings_file (str): Path to the file where all data of segments
                including feature embeddings is saved.
            distance_metric (str): Distance metric to use for nearest neighbor
                computation.
            method (str): Method to use for dimensionality reduction of nearest
                neighbor embeddings. For plotting the points are always reduced in
                dimensionality using PCA to 50 dimensions followed by t-SNE to two
                dimensions.
            embedding_size (int): Dimensionality of the feature embeddings used for
                nearest neighbor search.
            overwrite_embeddings (bool): If True, precomputed nearest neighbor and
                plotting embeddings from previous runs are overwritten with freshly
                computed ones. Otherwise precomputed embeddings are used if requested
                embedding_size is matching.
            n_jobs (int): Number of processes to use for t-SNE computation.
            dpi (int): Dots per inch for graphics that are saved to disk.
            main_plot_args (dict): Keyword arguments for the creation of the main plot.
            tsne_args (dict): Keyword arguments for the t-SNE algorithm.
            save_dir (str): Path to the directory where saved images are placed in.
        """
        self.log = logging.getLogger("Discovery")
        self.embeddings_file = embeddings_file
        self.distance_metrics = ["euclid", "cos"]
        self.dm = (0 if distance_metric not in self.distance_metrics else
                   self.distance_metrics.index(distance_metric))

        self.dpi = dpi
        self.save_dir = save_dir
        os.makedirs(self.save_dir, exist_ok=True)

        self.cluster_methods = OrderedDict()
        self.cluster_methods["kmeans"] = {"main": KMeans, "kwargs": {}}
        self.cluster_methods["spectral"] = {
            "main": SpectralClustering,
            "kwargs": {}
        }
        self.cluster_methods["agglo"] = {
            "main": AgglomerativeClustering,
            "kwargs": {
                "linkage": "ward"
            },
        }

        self.methods_with_ncluster_param = ["kmeans", "spectral", "agglo"]
        self.cme = 0
        self.clustering = None
        self.n_clusters = 25

        # colors:
        self.standard_color = (0, 0, 1, 1)
        self.current_color = (1, 0, 0, 1)
        self.nn_color = (1, 0, 1, 1)

        self.log.info("Loading data...")
        with open(self.embeddings_file, "rb") as f:
            self.data = pkl.load(f)

        self.iou_preds = self.data["iou_pred"]
        self.gt = np.array(self.data["gt"]).flatten()
        self.pred = np.array(self.data["pred"]).flatten()
        self.gi = self.data[
            "image_level_index"]  # global indices (on image level and not on component level)

        self.log.info("Loaded {} segment embeddings.".format(
            self.pred.shape[0]))

        self.nearest_neighbors = None

        if len(self.data["embeddings"]) == 1:
            self.data["plot_embeddings"] = np.array(
                [self.data["embeddings"][0][0],
                 self.data["embeddings"][0][1]]).reshape((1, 2))
            self.data["nn_embeddings"] = self.data["plot_embeddings"]
        else:
            if ("nn_embeddings" not in self.data.keys() or overwrite_embeddings
                    or "plot_embeddings" not in self.data.keys()
                ) and embedding_size < self.data["embeddings"][0].shape[0]:
                self.log.info("Computing PCA...")
                n_comp = (50 if 50 < min(
                    len(self.data["embeddings"]),
                    self.data["embeddings"][0].shape[0],
                ) else min(
                    len(self.data["embeddings"]),
                    self.data["embeddings"][0].shape[0],
                ))
                embeddings = PCA(n_components=n_comp).fit_transform(
                    np.stack(self.data["embeddings"]).reshape(
                        (-1, self.data["embeddings"][0].shape[0])))
                rewrite = True
            else:
                rewrite = False

            if "plot_embeddings" not in self.data.keys(
            ) or overwrite_embeddings:
                self.log.info("Computing t-SNE for plotting")
                self.data["plot_embeddings"] = TSNE(
                    n_components=2, **tsne_args).fit_transform(embeddings)
                new_plot_embeddings = True
            else:
                new_plot_embeddings = False

            if (embedding_size >= self.data["embeddings"][0].shape[0]
                    or embedding_size is None):
                self.embeddings = np.stack(self.data["embeddings"]).reshape(
                    (-1, self.data["embeddings"][0].shape[0]))
                self.log.debug(
                    ("Requested embedding size of {} was greater or equal "
                     "to data dimensionality of {}. "
                     "Data has thus not been reduced in dimensionality."
                     ).format(embedding_size,
                              self.data["embeddings"].shape[1]))
            elif (self.data["nn_embeddings"].shape[1] == embedding_size
                  if "nn_embeddings" in self.data.keys() else
                  False) and not overwrite_embeddings:
                self.embeddings = self.data["nn_embeddings"]
                self.log.info(("Loaded reduced embeddings "
                               "({} dimensions) from precomputed file " +
                               "for nearest neighbor search.").format(
                                   self.embeddings.shape[1]))
            elif rewrite:
                if method == "TSNE":
                    if ("plot_embeddings" in self.data.keys()
                            and embedding_size == 2 and new_plot_embeddings):
                        self.embeddings = self.data["plot_embeddings"]
                        self.log.info((
                            "Reused the precomputed manifold for plotting for "
                            "nearest neighbor search."))
                    else:
                        self.log.info(("Computing t-SNE of dimension "
                                       "{} for nearest neighbor search..."
                                       ).format(embedding_size))
                        self.embeddings = TSNE(
                            n_components=embedding_size,
                            n_jobs=n_jobs,
                            **tsne_args).fit_transform(embeddings)
                else:
                    self.log.info(("Computing Isomap of dimension "
                                   "{} for nearest neighbor search..."
                                   ).format(embedding_size))
                    self.embeddings = Isomap(
                        n_components=embedding_size,
                        n_jobs=n_jobs,
                    ).fit_transform(embeddings)
                self.data["nn_embeddings"] = self.embeddings
            else:
                raise ValueError(
                    ("Please specify a valid combination of arguments.\n"
                     "Loading fails if 'overwrite_embeddings' is False and "
                     "saved 'embedding_size' does not match the requested one."
                     ))

            # Write added data into pickle file
            if rewrite:
                with open(self.embeddings_file, "wb") as f:
                    pkl.dump(self.data, f)

        self.x = self.data["plot_embeddings"][:, 0]
        self.y = self.data["plot_embeddings"][:, 1]

        self.label_mapping = dict()
        for d in np.unique(self.data["dataset"]).flatten():
            try:
                self.label_mapping[d] = getattr(
                    importlib.import_module(datasets[d].module_name),
                    datasets[d].class_name,
                )(**datasets[d].kwargs, ).label_mapping
            except AttributeError:
                self.label_mapping[d] = None

        train_dat = self.label_mapping[CONFIG.TRAIN_DATASET.name] = getattr(
            importlib.import_module(CONFIG.TRAIN_DATASET.module_name),
            CONFIG.TRAIN_DATASET.class_name,
        )(**CONFIG.TRAIN_DATASET.kwargs, )
        self.pred_mapping = train_dat.pred_mapping
        if CONFIG.TRAIN_DATASET.name not in self.label_mapping:
            self.label_mapping[
                CONFIG.TRAIN_DATASET.name] = train_dat.label_mapping

        self.tnsize = (50, 50)
        self.fig_nn = None
        self.fig_main = None
        self.line_main = None
        self.im = None
        self.xybox = None
        self.ab = None
        self.basecolors = np.stack(
            [self.standard_color for _ in range(self.x.shape[0])])
        self.n_neighbors = 49
        self.current_pressed_key = None

        self.plot_main(**main_plot_args)

    def plot_main(self, **plot_args):
        """Initializes the main plot.

        Only 'legend' (bool) is currently supported as keyword argument.
        """
        self.fig_main = plt.figure(num=1)
        self.fig_main.canvas.set_window_title("Embedding space")
        ax = self.fig_main.add_subplot(111)
        ax.set_axis_off()
        self.line_main = ax.scatter(self.x,
                                    self.y,
                                    marker="o",
                                    color=self.basecolors,
                                    zorder=2)
        self.line_main.set_picker(True)

        if ((plot_args["legend"]
             and all(lm is not None for lm in self.label_mapping.values()))
                if "legend" in plot_args else False):
            box = ax.get_position()
            ax.set_position([box.x0, box.y0, box.width, box.height * 0.8])
            legend_elements = []
            for d in np.unique(self.data["dataset"]).flatten():
                cls = np.unique(
                    self.gt[np.array(self.data["dataset"])[self.gi] == d])
                cls = list({(self.label_mapping[d][cl][0],
                             self.label_mapping[d][cl][1])
                            for cl in cls})
                names = np.array([i[0] for i in cls])
                cols = np.array([i[1] for i in cls])
                legend_elements += [
                    Patch(
                        color=tuple(i / 255.0 for i in cols[i]) + (1.0, ),
                        label=names[i] if not names[i][-1].isdigit() else
                        names[i][:names[i].rfind(" ")],
                    ) for i in range(names.shape[0])
                ]
            ax.legend(
                loc="upper left",
                handles=legend_elements,
                ncol=8,
                bbox_to_anchor=(0, 1.2),
            )
        self.basecolors = self.line_main.get_facecolor()

        tmp = (Image.open(
            self.data["image_path"][self.gi[0]]).convert("RGB").crop(
                self.data["box"][0]))
        tmp.thumbnail(self.tnsize, Image.ANTIALIAS)
        self.im = OffsetImage(tmp, zoom=2)
        self.xybox = (50.0, 50.0)
        self.ab = AnnotationBbox(
            self.im,
            (0, 0),
            xybox=self.xybox,
            xycoords="data",
            boxcoords="offset points",
            pad=0.3,
            arrowprops=dict(arrowstyle="->"),
        )
        ax.add_artist(self.ab)
        self.ab.set_visible(False)

        if plot_args[
                "save_path"] is not None if "save_path" in plot_args else False:
            plt.savefig(expanduser(plot_args["save_path"]),
                        dpi=300,
                        bbox_inches="tight")

        else:
            self.fig_main.canvas.mpl_connect("motion_notify_event",
                                             self.hover_main)
            self.fig_main.canvas.mpl_connect("button_press_event",
                                             self.click_main)
            self.fig_main.canvas.mpl_connect("scroll_event", self.scroll)
            self.fig_main.canvas.mpl_connect("key_press_event", self.key_press)
            self.fig_main.canvas.mpl_connect("key_release_event",
                                             self.key_release)
            plt.show()

    def hover_main(self, event):
        """Action handler for the main plot.

        This function shows a thumbnail of the underlying image when a scatter point
        is hovered with the mouse.
        """
        # if the mouse is over the scatter points
        if self.line_main.contains(event)[0]:
            # find out the index within the array from the event
            ind, *_ = self.line_main.contains(event)[1]["ind"]

            # get the figure size
            w, h = self.fig_main.get_size_inches() * self.fig_main.dpi
            ws = (event.x > w / 2.0) * -1 + (event.x <= w / 2.0)
            hs = (event.y > h / 2.0) * -1 + (event.y <= h / 2.0)

            # if event occurs in the top or right quadrant of the figure,
            # change the annotation box position relative to mouse.
            self.ab.xybox = (self.xybox[0] * ws, self.xybox[1] * hs)

            # make annotation box visible
            self.ab.set_visible(True)

            # place it at the position of the hovered scatter point
            self.ab.xy = (self.x[ind], self.y[ind])

            # set the image corresponding to that point
            tmp = (Image.open(
                self.data["image_path"][self.gi[ind]]).convert("RGB").crop(
                    self.data["box"][ind]))
            tmp.thumbnail(self.tnsize, Image.ANTIALIAS)
            self.im.set_data(tmp)
            tmp.close()
        else:
            # if the mouse is not over a scatter point
            self.ab.set_visible(False)
        self.fig_main.canvas.draw_idle()

    def click_main(self, event):
        """Action handler for the main plot.

        This function shows a single or full image or displays nearest neighbors based
        on the button that has been pressed and which scatter point was pressed.
        """
        if self.line_main.contains(event)[0]:
            ind, *_ = self.line_main.contains(event)[1]["ind"]

            if self.current_pressed_key == "t" and event.button == 1:
                self.store_thumbnail(ind)
            elif self.current_pressed_key == "control" and event.button == 1:
                self.show_single_image(ind, save=True)
            elif self.current_pressed_key == "control" and event.button == 2:
                self.show_full_image(ind, save=True)
            elif event.button == 1:  # left mouse button
                self.show_single_image(ind)
            elif event.button == 2:  # middle mouse button
                self.show_full_image(ind)
            elif event.button == 3:  # right mouse button
                if not plt.fignum_exists(2):
                    # nearest neighbor figure is not open anymore or has not been
                    # opened yet
                    self.log.info("Loading nearest neighbors...")
                    self.nearest_neighbors = self.get_nearest_neighbors(
                        ind, metric=self.distance_metrics[self.dm])
                    thumbnails = []
                    for neighbor_ind in self.nearest_neighbors:
                        thumbnails.append(
                            Image.open(self.data["image_path"][
                                self.gi[neighbor_ind]]).crop(
                                    self.data["box"][neighbor_ind]))
                    columns = math.ceil(math.sqrt(self.n_neighbors))
                    rows = math.ceil(self.n_neighbors / columns)

                    self.fig_nn = plt.figure(num=2, dpi=self.dpi)
                    self.fig_nn.canvas.set_window_title(
                        "{} nearest neighbors to selected image".format(
                            self.n_neighbors))
                    for p in range(columns * rows):
                        ax = self.fig_nn.add_subplot(rows, columns, p + 1)
                        ax.set_axis_off()
                        if p < len(thumbnails):
                            ax.imshow(np.asarray(thumbnails[p]))
                    self.fig_nn.canvas.mpl_connect("button_press_event",
                                                   self.click_nn)
                    self.fig_nn.canvas.mpl_connect("key_press_event",
                                                   self.key_press)
                    self.fig_nn.canvas.mpl_connect("key_release_event",
                                                   self.key_release)
                    self.fig_nn.canvas.mpl_connect("scroll_event", self.scroll)
                    self.fig_nn.show()
                else:
                    # nearest neighbor figure is already open. Update the figure with
                    # new nearest neighbor
                    self.update_nearest_neighbors(ind)
                    return

            self.set_color(ind, self.current_color)
            self.flush_colors()

    def click_nn(self, event):
        """Action handler for the nearest neighbor window.

        When clicking a cropped segment in the nearest neighbor window the same actions
        are taken as in the click handler for the main plot.
        """
        if event.inaxes in self.fig_nn.axes:
            ind = self.get_ind_nn(event)

            if self.current_pressed_key == "t" and event.button == 1:
                self.store_thumbnail(self.nearest_neighbors[ind])
            elif self.current_pressed_key == "control" and event.button == 1:
                self.show_single_image(self.nearest_neighbors[ind], save=True)
            elif self.current_pressed_key == "control" and event.button == 2:
                self.show_full_image(self.nearest_neighbors[ind], save=True)
            elif event.button == 1:  # left mouse button
                self.show_single_image(self.nearest_neighbors[ind])
            elif event.button == 2:  # middle mouse button
                self.show_full_image(self.nearest_neighbors[ind])
            elif event.button == 3:  # right mouse button
                self.update_nearest_neighbors(self.nearest_neighbors[ind])

    def key_press(self, event):
        """Performs different actions based on pressed keys."""
        self.log.debug("Key '{}' pressed.".format(event.key))
        if event.key == "m":
            self.dm += 1
            self.dm = self.dm % len(self.distance_metrics)
            self.log.info("Changed distance metric to {}".format(
                self.distance_metrics[self.dm]))
        elif event.key == "#":
            self.cme += 1
            self.cme = self.cme % len(self.cluster_methods)
            self.log.info("Changed clustering method to {}".format(
                list(self.cluster_methods.keys())[self.cme]))
        elif event.key == "c":
            self.log.info("Started clustering with {}...".format(
                list(self.cluster_methods.keys())[self.cme]))
            self.cluster(method=list(self.cluster_methods.keys())[self.cme])
            if self.fig_main.axes[0].get_legend() is not None:
                self.fig_main.axes[0].get_legend().remove()
            self.basecolors = cm.get_cmap(
                "viridis", (max(self.clustering) + 1))(self.clustering)
            self.flush_colors()
        elif event.key == "g":
            self.color_gt()
        elif event.key == "h":
            self.color_pred()
        elif event.key == "b":
            self.set_color(list(range(self.basecolors.shape[0])),
                           self.standard_color)
            if self.fig_main.axes[0].get_legend() is not None:
                self.fig_main.axes[0].get_legend().remove()
            self.flush_colors()
        elif event.key == "d":
            self.show_density()

        self.current_pressed_key = event.key

    def key_release(self, event):
        """Clears the variable where the last pressed key is saved."""
        self.current_pressed_key = None
        self.log.debug("Key '{}' released.".format(event.key))

    def scroll(self, event):
        """Increases or decreases number of nearest neighbors when scrolling on
        the main or nearest neighbor plot."""
        if event.button == "up":
            self.n_neighbors += 1
            self.log.info("Increased number of nearest neighbors to {}".format(
                self.n_neighbors))
        elif event.button == "down":
            if self.n_neighbors > 0:
                self.n_neighbors -= 1
                self.log.info(
                    "Decreased number of nearest neighbors to {}".format(
                        self.n_neighbors))

    def show_single_image(self, ind, save=False):
        """Displays the full image belonging to a segment. The segment is marked with
        a red bounding box."""
        self.log.info("{} image...".format("Saving" if save else "Loading"))
        img_box = self.draw_box_on_image(ind)
        fig_tmp = plt.figure(max(3, max(plt.get_fignums()) + 1), dpi=self.dpi)
        ax = fig_tmp.add_subplot(111)
        ax.set_axis_off()
        ax.imshow(np.asarray(img_box), interpolation="nearest")
        if save:
            fig_tmp.subplots_adjust(bottom=0,
                                    left=0,
                                    right=1,
                                    top=1,
                                    hspace=0,
                                    wspace=0)
            ax.margins(0.05, 0.05)
            fig_tmp.gca().xaxis.set_major_locator(plt.NullLocator())
            fig_tmp.gca().yaxis.set_major_locator(plt.NullLocator())
            fig_tmp.savefig(
                join(self.save_dir, "image_{}.jpg".format(ind)),
                bbox_inches="tight",
                pad_inches=0.0,
            )
            self.log.debug("Saved image to {}".format(
                join(self.save_dir, "image_{}.jpg".format(ind))))
        else:
            fig_tmp.canvas.set_window_title(
                "Dataset: {}, Image index: {}".format(
                    self.data["dataset"][self.gi[ind]],
                    self.data["image_index"][self.gi[ind]],
                ))
            fig_tmp.tight_layout(pad=0.0)
            fig_tmp.show()

    def show_full_image(self, ind, save=False):
        """Displays four panels of the full image belonging to a segment.

        Top left: Entropy heatmap of prediction.
        Top right: Predicted IoU of each segment.
        Bottom left: Source image with ground truth overlay.
        Bottom right: Predicted semantic segmentation.
        """
        self.log.info(
            "{} detailed image...".format("Saving" if save else "Loading"))
        box = self.data["box"][ind]
        image = np.asarray(
            Image.open(self.data["image_path"][self.gi[ind]]).convert("RGB"))
        image_index = self.data["image_index"][self.gi[ind]]
        iou_pred = self.data["iou_pred"][self.gi[ind]]
        dataset = self.data["dataset"][self.gi[ind]]
        model_name = self.data["model_name"][self.gi[ind]]

        pred, gt, image_path = probs_gt_load(
            image_index,
            input_dir=join(CONFIG.metaseg_io_path, "input", model_name,
                           dataset),
        )
        components = components_load(
            image_index,
            components_dir=join(CONFIG.metaseg_io_path, "components",
                                model_name, dataset),
        )
        e = entropy(pred)
        pred = pred.argmax(2)
        predc = np.asarray([
            self.pred_mapping[pred[ind_i, ind_j]][1]
            for ind_i in range(pred.shape[0]) for ind_j in range(pred.shape[1])
        ]).reshape(image.shape)
        overlay_factor = [1.0, 0.5, 1.0]

        if self.label_mapping[dataset] is not None:
            gtc = np.asarray([
                self.label_mapping[dataset][gt[ind_i, ind_j]][1]
                for ind_i in range(gt.shape[0]) for ind_j in range(gt.shape[1])
            ]).reshape(image.shape)
        else:
            gtc = np.zeros_like(image)
            overlay_factor[1] = 0.0

        img_predc, img_gtc, img_entropy = [
            Image.fromarray(
                np.uint8(arr * overlay_factor[i] + image *
                         (1 - overlay_factor[i])))
            for i, arr in enumerate([predc, gtc,
                                     cm.jet(e)[:, :, :3] * 255.0])
        ]

        img_ioupred = Image.fromarray(
            self.visualize_segments(components, iou_pred))

        images = [img_gtc, img_predc, img_entropy, img_ioupred]

        box_line_width = 5
        left, upper = max(0, box[0] - box_line_width), max(
            0, box[1] - box_line_width)
        right, lower = min(pred.shape[1], box[2] + box_line_width), min(
            pred.shape[0], box[3] + box_line_width)

        for k in images:
            draw = ImageDraw.Draw(k)
            draw.rectangle([left, upper, right, lower],
                           outline=(255, 0, 0),
                           width=box_line_width)
            del draw

        for k in range(len(images)):
            images[k] = np.asarray(images[k]).astype("uint8")

        img_top = np.concatenate(images[2:], axis=1)
        img_bottom = np.concatenate(images[:2], axis=1)

        img_total = np.concatenate((img_top, img_bottom), axis=0)
        fig_tmp = plt.figure(max(3, max(plt.get_fignums()) + 1), dpi=self.dpi)
        fig_tmp.canvas.set_window_title("Dataset: {}, Image index: {}".format(
            dataset, image_index))
        ax = fig_tmp.add_subplot(111)
        ax.set_axis_off()
        ax.imshow(img_total, interpolation="nearest")

        if save:
            fig_tmp.subplots_adjust(bottom=0,
                                    left=0,
                                    right=1,
                                    top=1,
                                    hspace=0,
                                    wspace=0)
            ax.margins(0.05, 0.05)
            fig_tmp.gca().xaxis.set_major_locator(plt.NullLocator())
            fig_tmp.gca().yaxis.set_major_locator(plt.NullLocator())
            fig_tmp.savefig(
                join(self.save_dir, "detailed_image_{}.jpg".format(ind)),
                bbox_inches="tight",
                pad_inches=0.0,
            )
            self.log.debug("Saved image to {}".format(
                join(self.save_dir, "detailed_image_{}.jpg".format(ind))))
        else:
            fig_tmp.tight_layout(pad=0.0)
            fig_tmp.show()

    def store_thumbnail(self, ind):
        """Stores a thumbnail of a segment if requested. Thus is not saving the whole
        image but only the cropped part."""
        image = Image.open(
            self.data["image_path"][self.gi[ind]]).convert("RGB")
        image = image.crop(self.data["box"][ind])

        if self.label_mapping[self.data["dataset"][self.gi[ind]]] is None:
            name = "None"
        else:
            name = self.label_mapping[self.data["dataset"][self.gi[ind]]][
                self.gt[ind]][0]
        if name[-1].isdigit():
            name = name[:-2]

        name = name.replace(" ", "_")

        image.save(
            join(
                self.save_dir,
                "thumbnail_{}_{:0>2.1f}_{:0>2.1f}.jpg".format(
                    name, self.x[ind], self.y[ind]),
            ))
        self.log.debug("Saved thumbnail to {}".format(
            join(
                self.save_dir,
                "thumbnail_{}_{:0>2.1f}_{:0>2.1f}.jpg".format(
                    name, self.x[ind], self.y[ind]),
            )))

    def get_nearest_neighbors(self, ind, metric="cos"):
        """Computes nearest neighbors to the specified index in the collection of
        segment crops."""
        if metric == "euclid":
            dists = self.lp_dist(self.embeddings[ind], self.embeddings, d=2)
        else:
            dists = self.cos_dist(self.embeddings[ind], self.embeddings)
        return np.argsort(dists)[1:(self.n_neighbors + 1)]

    def update_nearest_neighbors(self, ind):
        """If requesting nearest neighbors of a segment within the nearest neighbor
        plot window the nearest neighbors are updated according to the newly
        selected segment.
        """
        self.log.info("Loading nearest neighbors...")
        self.nearest_neighbors = self.get_nearest_neighbors(
            ind, metric=self.distance_metrics[self.dm])
        thumbnails = []
        for neighbor_ind in self.nearest_neighbors:
            b = self.data["box"][neighbor_ind]
            thumbnails.append(
                plt.imread(self.data["image_path"][self.gi[neighbor_ind]])[
                    b[1]:b[3], b[0]:b[2], :])
        n = math.ceil(math.sqrt(len(self.nearest_neighbors)))
        if len(self.fig_nn.axes) != (n**2):
            # number of nearest neighbors has been changed
            # redefine number of subplots in fig_nn
            self.rearrange_axes(n, n)

        for p in range(n**2):
            if p < self.n_neighbors:
                self.fig_nn.axes[p].imshow(thumbnails[p])
            else:
                self.fig_nn.axes[p].clear()
                self.fig_nn.axes[p].set_axis_off()

        self.fig_nn.canvas.draw()
        self.set_color(ind, self.current_color)
        self.flush_colors()

    def cluster(self, method="kmeans"):
        if method in self.methods_with_ncluster_param:
            n_clusters = self.n_cluster_prompt()
            if n_clusters == "elbow" and method == "kmeans":
                n_clusters = self.elbow()

            self.clustering = self.cluster_methods[method]["main"](
                n_clusters=n_clusters,
                **self.cluster_methods[method]["kwargs"]).fit_predict(
                    self.embeddings)

    def elbow(self):
        low = int(input("Enter the minimum number of clusters: "))
        high = int(input("Enter the maximum number of clusters: "))
        km = [KMeans(n_clusters=i) for i in range(low, high + 1)]

        km = [k.fit(self.embeddings) for k in tqdm(km, total=len(km))]
        score = [k.inertia_ for k in km]

        fig_elbow = plt.figure(max(3, max(plt.get_fignums()) + 1))
        ax = fig_elbow.add_subplot(111)
        ax.plot(range(low, high + 1), score)
        fig_elbow.show()
        return int(input("Enter number of clusters: "))

    def n_cluster_prompt(self):
        inp = input("Enter the number of clusters: ")
        if inp == "elbow":
            return inp
        else:
            try:
                inp = int(inp)
            except ValueError:
                self.log.error(
                    "Input should be an int or 'elbow' but received {}!".
                    format(inp))
                return "error"
            if inp <= 1:
                raise ValueError(
                    "Number of clusters should be greater than 1!")
            else:
                return inp

    def rearrange_axes(self, nrows, ncols):
        """Helper function for the nearest neighbor plot window. If number of nearest
        neighbors has been changed and a new query segment has been chosen the
        arrangement of subplots within the window has to be changed.
        """
        n = len(self.fig_nn.axes)
        if n <= (nrows * ncols):
            # we need to add more axes
            for i, ax in enumerate(self.fig_nn.axes):
                ax.change_geometry(nrows, ncols, i + 1)
            for p in range(n, nrows * ncols):
                ax = self.fig_nn.add_subplot(nrows, ncols, p + 1)
                ax.set_axis_off()
        else:
            # we need to remove some axes
            for p in range(n - 1, (nrows * ncols) - 1, -1):
                self.fig_nn.delaxes(self.fig_nn.axes[p])
            for i, ax in enumerate(self.fig_nn.axes):
                ax.change_geometry(nrows, ncols, i + 1)

    def draw_box_on_image(self, ind):
        """Draws the red bounding of a selected segment onto the source image."""
        box_line_width = 5
        img_box = Image.open(
            self.data["image_path"][self.gi[ind]]).convert("RGB")
        draw = ImageDraw.Draw(img_box)
        left, upper, right, lower = self.data["box"][ind]
        left, upper = max(0,
                          left - box_line_width), max(0,
                                                      upper - box_line_width)
        right, lower = min(img_box.size[0], right + box_line_width), min(
            img_box.size[1], lower + box_line_width)
        draw.rectangle([left, upper, right, lower],
                       outline=(255, 0, 0),
                       width=box_line_width)
        del draw
        return img_box

    @staticmethod
    def visualize_segments(comp, metric):
        """Helper function for generation of the four panels in the detailed
        image function."""
        r = np.asarray(metric)
        r = 1 - 0.5 * r
        g = np.asarray(metric)
        b = 0.3 + 0.35 * np.asarray(metric)

        r = np.concatenate((r, np.asarray([0, 1])))
        g = np.concatenate((g, np.asarray([0, 1])))
        b = np.concatenate((b, np.asarray([0, 1])))

        components = np.asarray(comp.copy(), dtype="int16")
        components[components < 0] = len(r) - 1
        components[components == 0] = len(r)

        img = np.zeros(components.shape + (3, ))
        x, y = np.meshgrid(np.arange(img.shape[0]), np.arange(img.shape[1]))
        x = x.reshape(-1)
        y = y.reshape(-1)

        img[x, y, 0] = r[components[x, y] - 1]
        img[x, y, 1] = g[components[x, y] - 1]
        img[x, y, 2] = b[components[x, y] - 1]

        img = np.asarray(255 * img).astype("uint8")

        return img

    @staticmethod
    def lp_dist(point, all_points, d=2):
        """Calculates the L_p distance from a point to a collection of points.
        Used for retrieval."""
        return ((all_points - point)**d).sum(1)**(1.0 / d)

    @staticmethod
    def cos_dist(point, all_points):
        """Calculates the cosine distance from a point to a collection of points.
        Used for retrieval."""
        return 1 - ((point * all_points).sum(1) /
                    (norm(point) * norm(all_points, axis=1)))

    @staticmethod
    def get_gridsize(fig):
        """Helper function for the nearest neighbor plot."""
        return fig.axes[0].get_subplotspec().get_gridspec().get_geometry()

    def get_ind_nn(self, event):
        """Helper function for the nearest neighbor plot."""
        _, ncols = self.get_gridsize(self.fig_nn)
        eventrow = event.inaxes.rowNum
        eventcol = event.inaxes.colNum
        return (eventrow * ncols) + eventcol

    def color_gt(self):
        """When called colors the scatter in the main plot according to the ground
        truth colors."""
        if all(self.label_mapping[self.data["dataset"][self.gi[ind]]]
               is not None for ind in range(self.basecolors.shape[0])):
            self.basecolors = np.stack([
                tuple(i / 255.0
                      for i in self.label_mapping[self.data["dataset"][
                          self.gi[ind]]][self.gt[ind]][1]) + (1.0, )
                for ind in range(self.basecolors.shape[0])
            ])
            legend_elements = []
            for d in np.unique(self.data["dataset"]).flatten():
                cls = np.unique(
                    self.gt[np.array(self.data["dataset"])[self.gi] == d])
                cls = list({(self.label_mapping[d][cl][0],
                             self.label_mapping[d][cl][1])
                            for cl in cls})
                names = np.array([i[0] for i in cls])
                cols = np.array([i[1] for i in cls])
                legend_elements += [
                    Patch(
                        color=tuple(i / 255.0 for i in cols[i]) + (1.0, ),
                        label=names[i] if not names[i][-1].isdigit() else
                        names[i][:names[i].rfind(" ")],
                    ) for i in range(names.shape[0])
                ]
            self.fig_main.axes[0].legend(
                loc="upper left",
                handles=legend_elements,
                ncol=8,
                bbox_to_anchor=(0, 1.1),
            )
            self.flush_colors()

    def color_pred(self):
        """When called colors the scatter in the main plot according to the predicted
        class color."""
        self.basecolors = np.stack([
            tuple(i / 255.0
                  for i in self.pred_mapping[self.pred[ind]][1]) + (1.0, )
            for ind in range(self.basecolors.shape[0])
        ])
        legend_elements = [
            Patch(
                color=tuple(i / 255.0
                            for i in self.pred_mapping[cl][1]) + (1.0, ),
                label=self.pred_mapping[cl][0],
            ) for cl in np.unique(self.pred).flatten()
        ]
        self.fig_main.axes[0].legend(loc="upper left",
                                     handles=legend_elements,
                                     ncol=8,
                                     bbox_to_anchor=(0, 1.1))
        self.flush_colors()

    def show_density(self):
        embedding_kde = estimate_kernel_density(self.data["plot_embeddings"])
        xmin = self.x.min()
        xmin = xmin * 1.3 if xmin < 0 else xmin * 0.8
        xmax = self.x.max()
        xmax = xmax * 1.3 if xmax > 0 else xmax * 0.8

        ymin = self.y.min()
        ymin = ymin * 1.3 if ymin < 0 else ymin * 0.8
        ymax = self.y.max()
        ymax = ymax * 1.3 if ymax > 0 else ymax * 0.8

        grid_x, grid_y = np.mgrid[xmin:xmax, ymin:ymax]
        grid_z = embedding_kde(np.vstack([grid_x.flatten(), grid_y.flatten()]))
        colmap = plt.get_cmap("Greys")
        colmap = colors.LinearSegmentedColormap.from_list(
            "trunc({n},{a:.2f},{b:.2f})".format(n=colmap.name, a=0.0, b=0.75),
            colmap(np.linspace(0.0, 0.75, 256)),
        )
        grid_z[grid_z < np.quantile(grid_z, 0.55)] = np.NaN
        colmap.set_bad("white")
        self.fig_main.axes[0].pcolormesh(
            grid_x,
            grid_y,
            grid_z.reshape(grid_x.shape),
            cmap=colmap,
            shading="gouraud",
            zorder=1,
        )
        self.flush_colors()

    def set_color(self, ind, color):
        """Helper function to set a color of a segment with index ind."""
        self.basecolors[ind, :] = color

    def change_color(self, old_color, new_color):
        """Helper function to change a specific color to a different one."""
        self.basecolors[(self.basecolors == old_color).all(axis=1)] = new_color

    def flush_colors(self):
        """Flushes all color changes onto the main scatter plot."""
        self.line_main.set_color(self.basecolors)
        self.fig_main.canvas.draw()
class Discovery(object):
    def __init__(self,
                 embeddings_file='embeddings.p',
                 distance_metric='euclid',
                 method='TSNE',
                 embedding_size=2,
                 overwrite_embeddings=False,
                 n_jobs=10,
                 dpi=300,
                 main_plot_args={},
                 tsne_args={},
                 save_dir=join(CONFIG.metaseg_io_path, 'vis_embeddings')):
        """Loads the embedding files, computes the dimensionality reductions and calls the initilization
        of the main plot.

        Args:
            embeddings_file (str): Path to the file where all data of segments including feature embeddings is saved.
            distance_metric (str): Distance metric to use for nearest neighbor computation.
            method (str): Method to use for dimensionality reduction of nearest neighbor embeddings. For plotting the
                points are always reduced in dimensionality using PCA to 50 dimensions followed by t-SNE to two
                dimensions.
            embedding_size (int): Dimensionality of the feature embeddings used for nearest neighbor search.
            overwrite_embeddings (bool): If True, precomputed nearest neighbor and plotting embeddings from previous
                runs are overwritten with freshly computed ones. Otherwise precomputed embeddings are used if requested
                embedding_size is matching.
            n_jobs (int): Number of processes to use for t-SNE computation.
            dpi (int): Dots per inch for graphics that are saved to disk.
            main_plot_args (dict): Keyword arguments for the creation of the main plot.
            tsne_args (dict): Keyword arguments for the t-SNE algorithm.
            save_dir (str): Path to the directory where saved images are placed in.
        """
        self.log = logging.getLogger('Discovery')
        self.embeddings_file = embeddings_file
        self.distance_metrics = ['euclid', 'cos']
        self.dm = 0 if distance_metric not in self.distance_metrics else self.distance_metrics.index(distance_metric)

        self.dpi = dpi
        self.save_dir = save_dir
        os.makedirs(self.save_dir, exist_ok=True)

        self.cluster_methods = OrderedDict()
        self.cluster_methods['kmeans'] = {'main': KMeans, 'kwargs': {}}
        # self.cluster_methods['spectral'] = {'main': SpectralClustering, 'kwargs': {}}
        self.cluster_methods['agglo'] = {'main': AgglomerativeClustering, 'kwargs': {'linkage': 'ward'}}

        self.methods_with_ncluster_param = ['kmeans', 'spectral', 'agglo']
        self.cme = 0
        self.clustering = None
        self.n_clusters = 25

        # colors:
        self.standard_color = (0, 0, 1, 1)
        self.current_color = (1, 0, 0, 1)
        self.nn_color = (1, 0, 1, 1)

        self.log.info('Loading data...')
        with open(self.embeddings_file, 'rb') as f:
            self.data = pkl.load(f)

        self.iou_preds = self.data['iou_pred']
        self.gt = np.array(self.data['gt']).flatten()
        self.pred = np.array(self.data['pred']).flatten()
        self.gi = self.data['image_level_index']  # global indices (on image level and not on component level)

        self.log.info('Loaded {} segment embeddings.'.format(self.pred.shape[0]))

        self.nearest_neighbors = None

        if len(self.data['embeddings']) == 1:
            self.data['plot_embeddings'] = np.array([self.data['embeddings'][0][0],
                                                     self.data['embeddings'][0][1]]).reshape((1, 2))
            self.data['nn_embeddings'] = self.data['plot_embeddings']
        else:
            if ('nn_embeddings' not in self.data.keys()
                or overwrite_embeddings
                or 'plot_embeddings' not in self.data.keys()) \
                    and embedding_size < self.data['embeddings'][0].shape[0]:
                self.log.info('Computing PCA...')
                n_comp = 50 if 50 < min(len(self.data['embeddings']),
                                        self.data['embeddings'][0].shape[0]) else min(len(self.data['embeddings']),
                                                                                      self.data['embeddings'][0].shape[0])
                embeddings = PCA(
                    n_components=n_comp
                ).fit_transform(np.stack(self.data['embeddings']).reshape((-1, self.data['embeddings'][0].shape[0])))
                rewrite = True
            else:
                rewrite = False

            if 'plot_embeddings' not in self.data.keys() or overwrite_embeddings:
                self.log.info('Computing t-SNE for plotting')
                self.data['plot_embeddings'] = TSNE(n_components=2, **tsne_args).fit_transform(embeddings)
                new_plot_embeddings = True
            else:
                new_plot_embeddings = False

            if embedding_size >= self.data['embeddings'][0].shape[0] or embedding_size is None:
                self.embeddings = np.stack(self.data['embeddings']).reshape((-1, self.data['embeddings'][0].shape[0]))
                self.log.debug(
                    ('Requested embedding size of {} was greater or equal to data dimensionality of {}. '
                     'Data has thus not been reduced in dimensionality.').format(
                        embedding_size,
                        self.data['embeddings'].shape[1]))
            elif (self.data['nn_embeddings'].shape[1] == embedding_size if 'nn_embeddings' in self.data.keys() else False) \
                    and not overwrite_embeddings:
                self.embeddings = self.data['nn_embeddings']
                self.log.info(('Loaded reduced embeddings ({} dimensions) from precomputed file '
                               + 'for nearest neighbor search.').format(self.embeddings.shape[1]))
            else:
                if method == 'TSNE':
                    if 'plot_embeddings' in self.data.keys() and embedding_size == 2 and new_plot_embeddings:
                        self.embeddings = self.data['plot_embeddings']
                        self.log.info('Reused the precomputed manifold for plotting for nearest neighbor search.')
                    else:
                        self.log.info('Computing t-SNE of dimension {} for nearest neighbor search...'.format(
                            embedding_size))
                        self.embeddings = TSNE(
                            n_components=embedding_size,
                            n_jobs=n_jobs,
                            **tsne_args
                        ).fit_transform(embeddings)
                else:
                    self.log.info('Computing Isomap of dimension {} for nearest neighbor search...'.format(embedding_size))
                    self.embeddings = Isomap(n_components=embedding_size,
                                             n_jobs=n_jobs,
                                             ).fit_transform(embeddings)
                self.data['nn_embeddings'] = self.embeddings

            # Write added data into pickle file
            if rewrite:
                with open(self.embeddings_file, 'wb') as f:
                    pkl.dump(self.data, f)

        self.x = self.data['plot_embeddings'][:, 0]
        self.y = self.data['plot_embeddings'][:, 1]

        self.label_mapping = dict()
        for d in np.unique(self.data['dataset']).flatten():
            self.label_mapping[d] = getattr(
                importlib.import_module(datasets[d].module_name),
                datasets[d].class_name)(
                **datasets[d].kwargs,
            ).label_mapping
        train_dat = self.label_mapping[CONFIG.TRAIN_DATASET.name] = getattr(
            importlib.import_module(CONFIG.TRAIN_DATASET.module_name),
            CONFIG.TRAIN_DATASET.class_name)(
            **CONFIG.TRAIN_DATASET.kwargs,
        )
        self.pred_mapping = train_dat.pred_mapping
        if CONFIG.TRAIN_DATASET.name not in self.label_mapping:
            self.label_mapping[CONFIG.TRAIN_DATASET.name] = train_dat.label_mapping

        self.tnsize = (50, 50)
        self.fig_nn = None
        self.n_neighbors = 49
        self.current_pressed_key = None

        self.plot_main(**main_plot_args)

    def plot_main(self, **plot_args):
        """Initializes the main plot.

        Only 'legend' (bool) is currently supported as keyword argument.
        """
        self.fig_main = plt.figure(num=1)
        self.fig_main.canvas.set_window_title('Embedding space')
        ax = self.fig_main.add_subplot(111)
        ax.set_axis_off()
        self.line_main = ax.scatter(self.x, self.y,
                                    marker="o",
                                    color=np.stack([tuple(i / 255.0
                                                          for i in self.label_mapping[
                                                              self.data['dataset'][self.gi[ind]]
                                                          ][self.gt[ind]][1])
                                                    + (1.0,) for ind in range(self.x.shape[0])]))
        self.line_main.set_picker(True)

        if plot_args['legend'] if 'legend' in plot_args else False:
            box = ax.get_position()
            ax.set_position([box.x0, box.y0, box.width, box.height * 0.8])
            legend_elements = []
            for d in np.unique(self.data['dataset']).flatten():
                cls = np.unique(self.gt[np.array(self.data['dataset'])[self.gi] == d])
                cls = list({(self.label_mapping[d][cl][0], self.label_mapping[d][cl][1]) for cl in cls})
                names = np.array([i[0] for i in cls])
                cols = np.array([i[1] for i in cls])
                legend_elements += [Patch(color=tuple(i / 255.0 for i in cols[i]) + (1.0,),
                                          label=names[i]
                                          if not names[i][-1].isdigit()
                                          else names[i][:names[i].rfind(' ')])
                                    for i in range(names.shape[0])]
            ax.legend(loc='upper left', handles=legend_elements, ncol=8, bbox_to_anchor=(0, 1.2))
        self.basecolors = self.line_main.get_facecolor()

        tmp = Image.open(self.data['image_path'][self.gi[0]]).convert('RGB').crop(
            self.data['box'][0])
        tmp.thumbnail(self.tnsize, Image.ANTIALIAS)
        self.im = OffsetImage(tmp, zoom=2)
        self.xybox = (50., 50.)
        self.ab = AnnotationBbox(self.im, (0, 0), xybox=self.xybox, xycoords='data',
                                 boxcoords='offset points', pad=0.3, arrowprops=dict(arrowstyle='->'))
        ax.add_artist(self.ab)
        self.ab.set_visible(False)

        if plot_args['save_path'] is not None if 'save_path' in plot_args else False:
            plt.savefig(expanduser(plot_args['save_path']), dpi=300, bbox_inches='tight')

        else:
            self.fig_main.canvas.mpl_connect('motion_notify_event', self.hover_main)
            self.fig_main.canvas.mpl_connect('button_press_event', self.click_main)
            self.fig_main.canvas.mpl_connect('scroll_event', self.scroll)
            self.fig_main.canvas.mpl_connect('key_press_event', self.key_press)
            self.fig_main.canvas.mpl_connect('key_release_event', self.key_release)
            plt.show()

    def hover_main(self, event):
        """Action handler for the main plot.

        This function shows a thumbnail of the underlying image when a scatter point is hovered with the mouse.
        """
        # if the mouse is over the scatter points
        if self.line_main.contains(event)[0]:
            # find out the index within the array from the event
            ind, *_ = self.line_main.contains(event)[1]["ind"]

            # get the figure size
            w, h = self.fig_main.get_size_inches() * self.fig_main.dpi
            ws = (event.x > w / 2.) * -1 + (event.x <= w / 2.)
            hs = (event.y > h / 2.) * -1 + (event.y <= h / 2.)

            # if event occurs in the top or right quadrant of the figure,
            # change the annotation box position relative to mouse.
            self.ab.xybox = (self.xybox[0] * ws, self.xybox[1] * hs)

            # make annotation box visible
            self.ab.set_visible(True)

            # place it at the position of the hovered scatter point
            self.ab.xy = (self.x[ind], self.y[ind])

            # set the image corresponding to that point
            tmp = Image.open(self.data['image_path'][self.gi[ind]]).convert('RGB').crop(self.data['box'][ind])
            tmp.thumbnail(self.tnsize, Image.ANTIALIAS)
            self.im.set_data(tmp)
            tmp.close()
        else:
            # if the mouse is not over a scatter point
            self.ab.set_visible(False)
        self.fig_main.canvas.draw_idle()

    def click_main(self, event):
        """Action handler for the main plot.

        This function shows a single or full image or displays nearest neighbors based on the button that has been
        pressed and which scatter point was pressed.
        """
        if self.line_main.contains(event)[0]:
            ind, *_ = self.line_main.contains(event)[1]['ind']

            if self.current_pressed_key == 't' and event.button == 1:
                self.store_thumbnail(ind)
            elif self.current_pressed_key == 'control' and event.button == 1:
                self.show_single_image(ind, save=True)
            elif self.current_pressed_key == 'control' and event.button == 2:
                self.show_full_image(ind, save=True)
            elif event.button == 1:  # left mouse button
                self.show_single_image(ind)
            elif event.button == 2:  # middle mouse button
                self.show_full_image(ind)
            elif event.button == 3:  # right mouse button
                if not plt.fignum_exists(2):
                    # nearest neighbor figure is not open anymore or has not been opened yet
                    self.log.info('Loading nearest neighbors...')
                    self.nearest_neighbors = self.get_nearest_neighbors(ind, metric=self.distance_metrics[self.dm])
                    thumbnails = []
                    for neighbor_ind in self.nearest_neighbors:
                        thumbnails.append(Image.open(self.data['image_path'][self.gi[neighbor_ind]]).crop(
                            self.data['box'][neighbor_ind]))
                    columns = math.ceil(math.sqrt(self.n_neighbors))
                    rows = math.ceil(self.n_neighbors / columns)

                    self.fig_nn = plt.figure(num=2, dpi=self.dpi)
                    self.fig_nn.canvas.set_window_title('{} nearest neighbors to selected image'.format(
                        self.n_neighbors))
                    for p in range(columns * rows):
                        ax = self.fig_nn.add_subplot(rows, columns, p + 1)
                        ax.set_axis_off()
                        if p < len(thumbnails):
                            ax.imshow(np.asarray(thumbnails[p]))
                    self.fig_nn.canvas.mpl_connect('button_press_event', self.click_nn)
                    self.fig_nn.canvas.mpl_connect('key_press_event', self.key_press)
                    self.fig_nn.canvas.mpl_connect('key_release_event', self.key_release)
                    self.fig_nn.canvas.mpl_connect('scroll_event', self.scroll)
                    self.fig_nn.show()
                else:
                    # nearest neighbor figure is already open. Update the figure with new nearest neighbor
                    self.update_nearest_neighbors(ind)
                    return

            self.set_color(ind, self.current_color)
            self.flush_colors()

    def click_nn(self, event):
        """Action handler for the nearest neighbor window.

        When clicking a cropped segment in the nearest neighbor window the same actions are taken as in the click
        handler for the main plot.
        """
        if event.inaxes in self.fig_nn.axes:
            ind = self.get_ind_nn(event)

            if self.current_pressed_key == 't' and event.button == 1:
                self.store_thumbnail(self.nearest_neighbors[ind])
            elif self.current_pressed_key == 'control' and event.button == 1:
                self.show_single_image(self.nearest_neighbors[ind], save=True)
            elif self.current_pressed_key == 'control' and event.button == 2:
                self.show_full_image(self.nearest_neighbors[ind], save=True)
            elif event.button == 1:  # left mouse button
                self.show_single_image(self.nearest_neighbors[ind])
            elif event.button == 2:  # middle mouse button
                self.show_full_image(self.nearest_neighbors[ind])
            elif event.button == 3:  # right mouse button
                self.update_nearest_neighbors(self.nearest_neighbors[ind])

    def key_press(self, event):
        """Performs different actions based on pressed keys."""
        if event.key == 'm':
            self.dm += 1
            self.dm = self.dm % len(self.distance_metrics)
            self.log.info('Changed distance metric to {}'.format(self.distance_metrics[self.dm]))
        elif event.key == '#':
            self.cme += 1
            self.cme = self.cme % len(self.cluster_methods)
            self.log.info('Changed clustering method to {}'.format(list(self.cluster_methods.keys())[self.cme]))
        elif event.key == 'c':
            self.log.info('Started clustering with {}...'.format(list(self.cluster_methods.keys())[self.cme]))
            self.cluster(method=list(self.cluster_methods.keys())[self.cme])
            self.fig_main.axes[0].get_legend().remove()
            self.basecolors = cm.get_cmap('viridis', (max(self.clustering) + 1))(self.clustering)
            self.flush_colors()
            self.cluster_statistics()
        elif event.key == 'x':
            if self.clustering is not None:
                self.cluster_statistics()
        elif event.key == 'g':
            self.color_gt()
        elif event.key == 'h':
            self.color_pred()
        elif event.key == 'b':
            self.set_color(list(range(self.basecolors.shape[0])), self.standard_color)
            self.flush_colors()

        self.current_pressed_key = event.key
        self.log.debug('Key \'{}\' pressed.'.format(event.key))

    def key_release(self, event):
        """Clears the variable where the last pressed key is saved."""
        self.current_pressed_key = None
        self.log.debug('Key \'{}\' released.'.format(event.key))

    def scroll(self, event):
        """Increases or decreases number of nearest neighbors when scrolling on the main or nearest neighbor plot."""
        if event.button == 'up':
            self.n_neighbors += 1
            self.log.info('Increased number of nearest neighbors to {}'.format(self.n_neighbors))
        elif event.button == 'down':
            if self.n_neighbors > 0:
                self.n_neighbors -= 1
                self.log.info('Decreased number of nearest neighbors to {}'.format(self.n_neighbors))

    def show_single_image(self, ind, save=False):
        """Displays the full image belonging to a segment. The segment is marked with a red bounding box."""
        self.log.info('{} image...'.format('Saving' if save else 'Loading'))
        img_box = self.draw_box_on_image(ind)
        fig_tmp = plt.figure(max(3, max(plt.get_fignums()) + 1), dpi=self.dpi)
        ax = fig_tmp.add_subplot(111)
        ax.set_axis_off()
        ax.imshow(np.asarray(img_box), interpolation='nearest')
        if save:
            fig_tmp.subplots_adjust(bottom=0, left=0, right=1, top=1, hspace=0, wspace=0)
            ax.margins(0.05, 0.05)
            fig_tmp.gca().xaxis.set_major_locator(plt.NullLocator())
            fig_tmp.gca().yaxis.set_major_locator(plt.NullLocator())
            fig_tmp.savefig(join(self.save_dir, 'image_{}.jpg'.format(ind)),
                            bbox_inches='tight', pad_inches=0.0)
            self.log.debug('Saved image to {}'.format(join(self.save_dir, 'image_{}.jpg'.format(ind))))
        else:
            fig_tmp.canvas.set_window_title('Dataset: {}, Image index: {}'.format(self.data['dataset'][self.gi[ind]],
                                                                                  self.data['image_index'][
                                                                                      self.gi[ind]]))
            fig_tmp.tight_layout(pad=0.0)
            fig_tmp.show()

    def show_full_image(self, ind, save=False):
        """Displays four panels of the full image belonging to a segment.

        Top left: Entropy heatmap of prediction.
        Top right: Predicted IoU of each segment.
        Bottom left: Source image with ground truth overlay.
        Bottom right: Predicted semantic segmentation.
        """
        self.log.info('{} detailed image...'.format('Saving' if save else 'Loading'))
        box = self.data['box'][ind]
        image = np.asarray(Image.open(self.data['image_path'][self.gi[ind]]).convert('RGB'))
        image_index = self.data['image_index'][self.gi[ind]]
        iou_pred = self.data['iou_pred'][self.gi[ind]]
        dataset = self.data['dataset'][self.gi[ind]]
        model_name = self.data['model_name'][self.gi[ind]]

        pred, gt, image_path = probs_gt_load(image_index,
                                             input_dir=join(CONFIG.metaseg_io_path, 'input', model_name, dataset))
        components = components_load(image_index,
                                     components_dir=join(CONFIG.metaseg_io_path, 'components', model_name, dataset))
        e = entropy(pred)
        pred = pred.argmax(2)
        predc = np.asarray([self.pred_mapping[pred[ind_i, ind_j]][1]
                            for ind_i in range(pred.shape[0])
                            for ind_j in range(pred.shape[1])]).reshape(image.shape)
        gtc = np.asarray([self.label_mapping[dataset][gt[ind_i, ind_j]][1]
                          for ind_i in range(gt.shape[0])
                          for ind_j in range(gt.shape[1])]).reshape(image.shape)

        overlay_factor = [1.0, 0.5, 1.0]
        img_predc, img_gtc, img_entropy = [
            Image.fromarray(np.uint8(arr * overlay_factor[i] + image * (1 - overlay_factor[i])))
            for i, arr in enumerate([predc,
                                     gtc,
                                     cm.jet(e)[:, :, :3] * 255.0])]

        img_ioupred = Image.fromarray(self.visualize_segments(components, iou_pred))

        images = [img_gtc, img_predc, img_entropy, img_ioupred]

        box_line_width = 5
        left, upper = max(0, box[0] - box_line_width), max(0, box[1] - box_line_width)
        right, lower = min(pred.shape[1], box[2] + box_line_width), min(pred.shape[0], box[3] + box_line_width)

        for k in images:
            draw = ImageDraw.Draw(k)
            draw.rectangle([left, upper, right, lower], outline=(255, 0, 0), width=box_line_width)
            del draw

        for k in range(len(images)):
            images[k] = np.asarray(images[k]).astype('uint8')

        img_top = np.concatenate(images[2:], axis=1)
        img_bottom = np.concatenate(images[:2], axis=1)

        img_total = np.concatenate((img_top, img_bottom), axis=0)
        fig_tmp = plt.figure(max(3, max(plt.get_fignums()) + 1), dpi=self.dpi)
        fig_tmp.canvas.set_window_title('Dataset: {}, Image index: {}'.format(dataset,
                                                                              image_index))
        ax = fig_tmp.add_subplot(111)
        ax.set_axis_off()
        ax.imshow(img_total, interpolation='nearest')

        if save:
            fig_tmp.subplots_adjust(bottom=0, left=0, right=1, top=1, hspace=0, wspace=0)
            ax.margins(0.05, 0.05)
            fig_tmp.gca().xaxis.set_major_locator(plt.NullLocator())
            fig_tmp.gca().yaxis.set_major_locator(plt.NullLocator())
            fig_tmp.savefig(join(self.save_dir, 'detailed_image_{}.jpg'.format(ind)),
                            bbox_inches='tight', pad_inches=0.0)
            self.log.debug('Saved image to {}'.format(join(self.save_dir, 'detailed_image_{}.jpg'.format(ind))))
        else:
            fig_tmp.tight_layout(pad=0.0)
            fig_tmp.show()

    def store_thumbnail(self, ind):
        """Stores a thumbnail of a segment if requested. Thus is not saving the whole image but only the cropped part.
        """
        image = Image.open(self.data['image_path'][self.gi[ind]]).convert('RGB')
        image = image.crop(self.data['box'][ind])

        name = self.label_mapping[self.data['dataset'][self.gi[ind]]][self.gt[ind]][0]
        if name[-1].isdigit():
            name = name[:-2]

        name = name.replace(' ', '_')

        image.save(join(self.save_dir, 'thumbnail_{}_{:0>2.1f}_{:0>2.1f}.jpg'.format(
            name,
            self.x[ind],
            self.y[ind])))
        self.log.debug('Saved thumbnail to {}'.format(join(self.save_dir, 'thumbnail_{}_{:0>2.1f}_{:0>2.1f}.jpg'.format(
            name,
            self.x[ind],
            self.y[ind]))))

    def get_nearest_neighbors(self, ind, metric='cos'):
        """Computes nearest neighbors to the specified index in the collection of segment crops."""
        if metric == 'euclid':
            dists = self.lp_dist(self.embeddings[ind], self.embeddings, d=2)
        else:
            dists = self.cos_dist(self.embeddings[ind], self.embeddings)
        return np.argsort(dists)[1:(self.n_neighbors + 1)]

    def update_nearest_neighbors(self, ind):
        """If requesting nearest neighbors of a segment within the nearest neighbor plot window the nearest neighbors
            are updated according to the newly selected segment.
        """
        self.log.info('Loading nearest neighbors...')
        self.nearest_neighbors = self.get_nearest_neighbors(ind, metric=self.distance_metrics[self.dm])
        thumbnails = []
        for neighbor_ind in self.nearest_neighbors:
            b = self.data['box'][neighbor_ind]
            thumbnails.append(plt.imread(self.data['image_path'][self.gi[neighbor_ind]])[b[1]:b[3], b[0]:b[2], :])
        n = math.ceil(math.sqrt(len(self.nearest_neighbors)))
        if len(self.fig_nn.axes) != (n ** 2):
            # number of nearest neighbors has been changed
            # redefine number of subplots in fig_nn
            self.rearrange_axes(n, n)

        for p in range(n ** 2):
            if p < self.n_neighbors:
                self.fig_nn.axes[p].imshow(thumbnails[p])
            else:
                self.fig_nn.axes[p].clear()
                self.fig_nn.axes[p].set_axis_off()

        self.fig_nn.canvas.draw()
        self.set_color(ind, self.current_color)
        self.flush_colors()

    def cluster(self, method='kmeans'):
        if method in self.methods_with_ncluster_param:
            n_clusters = self.n_cluster_prompt()
            if n_clusters == 'elbow' and method == 'kmeans':
                n_clusters = self.elbow()

            self.clustering = self.cluster_methods[method]['main'](
                n_clusters=n_clusters,
                **self.cluster_methods[method]['kwargs']).fit_predict(self.embeddings)

    def cluster_statistics(self):
        fig_clstats = plt.figure(max(3, max(plt.get_fignums()) + 1))
        clusters = np.unique(self.clustering).flatten()
        n = math.ceil(math.sqrt(clusters.shape[0]))
        all_label_names = []
        self.log.debug('Size of cluster statistics plots: {}'.format(n))
        for i in range(n ** 2):
            ax = fig_clstats.add_subplot(n, n, i + 1)
            if i < clusters.shape[0]:
                # labels, label_counts = np.unique(self.gt[self.clustering == clusters[i]], return_counts=True)

                label_names = []
                cols = []
                for j in range(self.clustering.shape[0]):
                    if self.clustering[j] == clusters[i]:
                        dat = self.data['dataset'][self.gi[j]]
                        name = self.label_mapping[dat][self.gt[j]][0]
                        label_names.append((name, self.label_mapping[dat][self.gt[j]][1]))

                if i == (clusters.shape[0] - 1):
                    missing = [all_label_names[ind] for ind in np.unique([lbl[0] for lbl in all_label_names],
                                                                         return_index=True)[1]]
                    label_names += missing

                labels, label_inds, label_counts = np.unique([lbl[0] for lbl in label_names],
                                                             return_counts=True,
                                                             return_index=True)
                explode = np.zeros(labels.shape[0])
                explode[np.argmax(label_counts)] = 0.2
                cols = np.array([label_names[ind][1] for ind in label_inds])
                perm = np.argsort(label_counts)[::-1]
                all_label_names += [label_names[i] for i in np.unique([lbl[0] for lbl in label_names],
                                                                      return_index=True)[1]]

                ax.pie(label_counts[perm],
                       # labels=labels[perm],
                       colors=cols[perm],
                       explode=explode[perm],
                       autopct=lambda perc: '{:1.1f}%'.format(perc) if perc > 20 else '',
                       # shadow=True,
                       startangle=90,
                       wedgeprops=dict(edgecolor='w'),
                       textprops=dict(color="g"))
                ax.axis('equal')
            else:
                ax.set_axis_off()

        fig_clstats.legend(fig_clstats.axes[clusters.shape[0] - 1].patches, labels[perm], loc='upper left')
        fig_clstats.show()

    def elbow(self):
        low = int(input('Enter the minimum number of clusters: '))
        high = int(input('Enter the maximum number of clusters: '))
        km = [KMeans(n_clusters=i) for i in range(low, high + 1)]

        km = [k.fit(self.embeddings) for k in tqdm(km, total=len(km))]
        score = [k.inertia_ for k in km]

        fig_elbow = plt.figure(max(3, max(plt.get_fignums()) + 1))
        ax = fig_elbow.add_subplot(111)
        ax.plot(range(low, high + 1), score)
        fig_elbow.show()
        return int(input('Enter number of clusters: '))

    def n_cluster_prompt(self):
        # inp = input('Enter the number of clusters. Typing \'elbow\' will start the elbow process.')
        inp = input('Enter the number of clusters: ')
        if inp == 'elbow':
            return inp
        else:
            try:
                inp = int(inp)
            except ValueError:
                self.log.error('Input should be an int or \'elbow\' but received {}!'.format(inp))
                return 'error'
            if inp <= 1:
                raise ValueError('Number of clusters should be greater than 1!')
            else:
                return inp

    def rearrange_axes(self, nrows, ncols):
        """Helper function for the nearest neighbor plot window. If number of nearest neighbors has been changed and a
            new query segment has been chosen the arrangement of subplots within the window has to be changed.
        """
        n = len(self.fig_nn.axes)
        if n <= (nrows * ncols):
            # we need to add more axes
            for i, ax in enumerate(self.fig_nn.axes):
                ax.change_geometry(nrows, ncols, i + 1)
            for p in range(n, nrows * ncols):
                ax = self.fig_nn.add_subplot(nrows, ncols, p + 1)
                ax.set_axis_off()
        else:
            # we need to remove some axes
            for p in range(n - 1, (nrows * ncols) - 1, -1):
                self.fig_nn.delaxes(self.fig_nn.axes[p])
            for i, ax in enumerate(self.fig_nn.axes):
                ax.change_geometry(nrows, ncols, i + 1)

    def draw_box_on_image(self, ind):
        """Draws the red bounding of a selected segment onto the source image."""
        box_line_width = 5
        img_box = Image.open(self.data['image_path'][self.gi[ind]]).convert('RGB')
        draw = ImageDraw.Draw(img_box)
        left, upper, right, lower = self.data['box'][ind]
        left, upper = max(0, left - box_line_width), max(0, upper - box_line_width)
        right, lower = min(img_box.size[0], right + box_line_width), min(img_box.size[1], lower + box_line_width)
        draw.rectangle([left, upper, right, lower], outline=(255, 0, 0), width=box_line_width)
        del draw
        return img_box

    @staticmethod
    def visualize_segments(comp, metric):
        """Helper function for generation of the four panels in the detailed image function."""
        r = np.asarray(metric)
        r = 1 - 0.5 * r
        g = np.asarray(metric)
        b = 0.3 + 0.35 * np.asarray(metric)

        r = np.concatenate((r, np.asarray([0, 1])))
        g = np.concatenate((g, np.asarray([0, 1])))
        b = np.concatenate((b, np.asarray([0, 1])))

        components = np.asarray(comp.copy(), dtype='int16')
        components[components < 0] = len(r) - 1
        components[components == 0] = len(r)

        img = np.zeros(components.shape + (3,))
        x, y = np.meshgrid(np.arange(img.shape[0]), np.arange(img.shape[1]))
        x = x.reshape(-1)
        y = y.reshape(-1)

        img[x, y, 0] = r[components[x, y] - 1]
        img[x, y, 1] = g[components[x, y] - 1]
        img[x, y, 2] = b[components[x, y] - 1]

        img = np.asarray(255 * img).astype('uint8')

        return img

    @staticmethod
    def lp_dist(point, all_points, d=2):
        """Calculates the L_p distance from a point to a collection of points. Used for retrieval."""
        return ((all_points - point) ** d).sum(1) ** (1.0 / d)

    @staticmethod
    def cos_dist(point, all_points):
        """Calculates the cosine distance from a point to a collection of points. Used for retrieval."""
        return 1 - ((point * all_points).sum(1) / (norm(point) * norm(all_points, axis=1)))

    @staticmethod
    def get_gridsize(fig):
        """Helper function for the nearest neighbor plot."""
        return fig.axes[0].get_subplotspec().get_gridspec().get_geometry()

    def get_ind_nn(self, event):
        """Helper function for the nearest neighbor plot."""
        _, ncols = self.get_gridsize(self.fig_nn)
        eventrow = event.inaxes.rowNum
        eventcol = event.inaxes.colNum
        return (eventrow * ncols) + eventcol

    def color_gt(self):
        """When called colors the scatter in the main plot according to the ground truth colors."""
        self.basecolors = np.stack([tuple(i / 255.0
                                          for i in self.label_mapping[self.data['dataset'][self.gi[ind]]][
                                              self.gt[ind]
                                          ][1])
                                    + (1.0,) for ind in range(self.basecolors.shape[0])])
        legend_elements = []
        for d in np.unique(self.data['dataset']).flatten():
            cls = np.unique(self.gt[np.array(self.data['dataset'])[self.gi] == d])
            cls = list({(self.label_mapping[d][cl][0], self.label_mapping[d][cl][1]) for cl in cls})
            names = np.array([i[0] for i in cls])
            cols = np.array([i[1] for i in cls])
            legend_elements += [Patch(color=tuple(i / 255.0 for i in cols[i]) + (1.0,),
                                      label=names[i]
                                      if not names[i][-1].isdigit()
                                      else names[i][:names[i].rfind(' ')])
                                for i in range(names.shape[0])]
        self.fig_main.axes[0].legend(loc='upper left', handles=legend_elements, ncol=8, bbox_to_anchor=(0, 1.2))
        self.flush_colors()

    def color_pred(self):
        """When called colors the scatter in the main plot according the predicted class color."""
        self.basecolors = np.stack([tuple(i / 255.0
                                          for i in self.pred_mapping[self.pred[ind]][1])
                                    + (1.0,)
                                    for ind in range(self.basecolors.shape[0])])
        legend_elements = [Patch(color=tuple(i / 255.0
                                             for i in self.pred_mapping[cl][1]) + (1.0,),
                                 label=self.pred_mapping[cl][0])
                           for cl in np.unique(self.pred).flatten()]
        self.fig_main.axes[0].legend(loc='upper left', handles=legend_elements, ncol=8, bbox_to_anchor=(0, 1.2))
        self.flush_colors()

    def set_color(self, ind, color):
        """Helper function to set a color of a segment with index ind."""
        self.basecolors[ind, :] = color

    def change_color(self, old_color, new_color):
        """Helper function to change a specific color to a different one."""
        self.basecolors[(self.basecolors == old_color).all(axis=1)] = new_color

    def flush_colors(self):
        """Flushes all color changes onto the main scatter plot."""
        self.line_main.set_color(self.basecolors)
        self.fig_main.canvas.draw()