Ejemplo n.º 1
0
def format_figure(fig: plt.Figure,
                  title: str = None,
                  x_label: str = None,
                  y_label: str = None,
                  title_y_position: float = 0.93):
    """Give ``fig`` and every axes it holds a uniform white, large-font look.

    The figure is resized to 12 x 8 inches with a white background. Each axes
    gets white top/right spines, a white face, thin grid lines on both axes
    and enlarged label/title fonts.

    Parameters
    ----------
    fig : matplotlib figure to restyle in place.
    title : optional figure-level title (font size 24).
    x_label : optional label written centred near the bottom of the figure.
    y_label : optional label written vertically near the left edge.
    title_y_position : vertical position handed to ``fig.suptitle``.
    """
    shared_label_font = {'size': 18}

    fig.set_figwidth(12)
    fig.set_figheight(8)
    fig.set_facecolor("white")

    if title is not None:
        fig.suptitle(title,
                     y=title_y_position,
                     verticalalignment='top',
                     fontsize=24)

    for panel in fig.axes:
        # Painting these two spines white hides them on the white face.
        for side in ('top', 'right'):
            panel.spines[side].set_color('white')
        panel.set_facecolor("white")

        both_axes = (panel.xaxis, panel.yaxis)
        for axis in both_axes:
            axis.grid(which="both", linewidth=0.5)
        for axis in both_axes:
            axis.label.set_fontsize(18)
        panel.title.set_fontsize(20)

    # Figure-level labels are free text anchored in figure coordinates.
    if x_label is not None:
        fig.text(0.5, 0.04, x_label, ha='center', fontdict=shared_label_font)

    if y_label is not None:
        fig.text(0.04,
                 0.5,
                 y_label,
                 va='center',
                 rotation='vertical',
                 fontdict=shared_label_font)
Ejemplo n.º 2
0
 def __mouseRealseEvent(self, e: MouseEvent, fig: Figure, ax: Axes):
     """Mark the position of a left click inside ``ax`` and report it.

     Nothing happens unless the event is inside ``ax``, was made with the
     left button, and the toolbar is not in zoom or pan mode. Otherwise the
     previous marker is removed, a green star is drawn at the click, and the
     position is shown in the figure title and printed.

     Parameters
     ----------
     e : the matplotlib mouse event being handled.
     fig : figure whose title is updated and whose canvas is redrawn.
     ax : the only axes for which clicks are honoured.
     """
     if e.inaxes is not ax or e.button != MouseButton.LEFT:
         return
     # Clicks made while zooming or panning belong to the toolbar, not to us.
     if plt.get_current_fig_manager().toolbar.mode in {'zoom rect', 'pan/zoom'}:
         return

     # e.x / e.y are canvas pixels; ax.plot() works in data coordinates, so
     # the marker must be placed with xdata / ydata to land under the cursor.
     clicked = (e.xdata, e.ydata)
     message = f'Clicked Pos : {clicked}'

     # Drop the marker left by the previous click before drawing the new one.
     self.__removePoint(ax, 'clicked point')
     ax.plot(clicked[0], clicked[1], 'g*', label='clicked point')
     fig.suptitle(message)
     fig.canvas.draw()
     print(message)
Ejemplo n.º 3
0
class Example(Frame):

    def __init__(self, parent):
        """Load the models, build the widgets and draw the first plot.

        Three pickle files are read from ``curdir``: the Gaussian model
        (``self.gm``), the memoized binomial curves (``self.curve_data``) and
        the sampling normalizer (``self.normalizer``).

        Parameters
        ----------
        parent : the Tk container this frame is placed in.
        """
        Frame.__init__(self, parent)

        def load_pickle(filename):
            # The context manager closes the handle even if unpickling fails.
            # The file is opened in the default mode, exactly as before.
            # NOTE(review): unpickling runs arbitrary code; this is only
            # acceptable if these files ship with the application - confirm.
            with open(curdir + "/" + filename) as fp:
                return pickle.load(fp)

        # Initial probe colour is pure blue, kept as both HSV and hex.
        self.hsv_color = colorsys.rgb_to_hsv(0.0, 0.0, 1.0)
        self.hex_color = '#0000ff'
        self.color_list = ["red", "blue", "green", "orange", "purple"]
        self.parent = parent

        print("Loading model...")
        self.lux = Lux()
        self.gm = load_pickle("gauss_model.pkl")
        self.curve_data = load_pickle("memoized_binomial_data.pkl")
        self.normalizer = load_pickle("sampling_normalizer.pkl")

        print("Creating UI")
        self.initUI()
        self.update_output()
        self.replot()

    def update_output(self):
        """Refresh the HSV read-out and the posterior list for the probe colour.

        The probe colour is scaled to (0-360, 0-100, 0-100), sent to
        ``self.lux.full_posterior`` and the result is stored in
        ``self.current_post``. Up to the first 25 entries are listed, and the
        first row is selected.
        """
        h, s, v = self.hsv_color
        probe = (h * 360, s * 100, v * 100)
        self.hsv_var.set("Hue: \t %2.1f \nSat:\t %2.1f \nValue:\t %2.1f" % probe)

        items = self.lux.full_posterior(probe)
        self.current_post = items

        self.display.config(state=NORMAL)
        self.display.delete(0, END)
        # Slicing instead of indexing 0..24 keeps this working when the
        # posterior holds fewer than 25 entries.
        for entry in items[:25]:
            self.display.insert(END, '{:<20} ({:.3f})'.format(entry[0], entry[1]))

        self.display.select_set(0, 0)

    def plot_lux_model(self, params,ax1,label,support,dim):
        """Draw one LUX curve on ``ax1`` as three black segments.

        The segments are an upper tail from ``mu2`` to 360, a flat top at
        height 1 between the two means, and a lower tail from ``mu1`` to -180.
        Each tail is the survival function of a ``gam_dist`` evaluated at the
        distance from its mean.

        Parameters
        ----------
        params : six values ``(mu1, shape1, scale1, mu2, shape2, scale2)``.
        ax1 : axes to draw on.
        label, support, dim : accepted for call compatibility; not used here.

        Returns
        -------
        Whatever ``ax1.plot`` returns for the lower-tail segment.
        """
        line_color = 'black'
        line_width = 3
        mu1, sh1, sc1, mu2, sh2, sc2 = params

        # NOTE(review): gam_dist is presumably scipy.stats.gamma, whose frozen
        # distributions accept arrays - confirm against the file's imports.
        lower_tail = gam_dist(sh1, scale=sc1)
        upper_tail = gam_dist(sh2, scale=sc2)

        lx = np.linspace(mu1, -180)
        rx = np.linspace(mu2, 360)

        ax1.plot(rx, upper_tail.sf(np.abs(rx - mu2)),
                 linewidth=line_width, c=line_color)
        # Flat top, pulled in by 1% at each end.
        ax1.plot([1.01*mu1, 0.99*mu2], [1., 1.],
                 linewidth=line_width, c=line_color)
        return ax1.plot(lx, lower_tail.sf(np.abs(lx - mu1)),
                        c=line_color, linewidth=line_width)

    def plot_gm_model(self, params, ax, label, support):
        """Plot a normal density in red on ``ax`` and report its peak height.

        Parameters
        ----------
        params : sequence whose first two items are the mean and the spread
            handed to ``norm.pdf``.
        ax : axes to draw on.
        label : accepted for call compatibility; not used here.
        support : pair giving the x-range, sampled at 360 points.

        Returns
        -------
        A 2-tuple: the result of ``ax.plot`` and the density at the mean.
        """
        mean, spread = params[0], params[1]
        grid = np.linspace(support[0], support[1], 360)

        curve = ax.plot(grid, norm.pdf(grid, mean, spread), c='red', linewidth=3)
        # The density at the mean is the curve's maximum.
        peak = norm.pdf([mean], mean, [spread])[0]
        return curve, peak


    def initUI(self):
        """Create and pack every widget: the probe sidebar and the figure canvas.

        The sidebar (``self.color_frame``) holds, top to bottom: a heading,
        the HSV read-out, the colour swatch, the "Select Color" button, the
        posterior heading, a hint, the posterior list box and the
        "Show details" button. The matplotlib canvas is packed next to it.
        """
        heading_font = "Helvetica 16 bold italic"

        self.parent.title("Interactive LUX visualization")
        self.pack(fill=BOTH, expand=1)

        self.color_frame = Frame(self, border=1)
        sidebar = self.color_frame
        sidebar.pack(side=LEFT)

        # --- probe heading -------------------------------------------------
        probe_heading_text = StringVar()
        probe_heading = Label(sidebar, textvariable=probe_heading_text,
                              justify=CENTER, font=heading_font)
        probe_heading_text.set("Color Probe X")
        probe_heading.pack(side=TOP)

        # --- HSV read-out ----------------------------------------------------
        self.hsv_var = StringVar()
        self.hsv_label = Label(sidebar, textvariable=self.hsv_var, justify=LEFT)
        hue, sat, val = self.hsv_color
        self.hsv_var.set("Hue: %2.1f \nSaturation: %2.1f \nValue: %2.1f"
                         % (hue * 360, sat * 100, val * 100))
        self.hsv_label.pack(side=TOP)

        # --- colour swatch; clicking it opens the colour chooser -------------
        self.frame = Frame(sidebar, border=1, relief=SUNKEN,
                           width=200, height=200)
        self.frame.pack(side=TOP)
        self.frame.config(bg=self.hex_color)
        self.frame.bind("<Button-1>", self.onChoose)

        self.btn = Button(sidebar, text="Select Color", command=self.onChoose)
        self.btn.pack(side=TOP)

        # --- posterior heading and hint --------------------------------------
        posterior_heading_text = StringVar()
        posterior_heading = Label(sidebar, textvariable=posterior_heading_text,
                                  justify=CENTER, font=heading_font)
        posterior_heading_text.set("\n\nLUX's Posterior")
        posterior_heading.pack(side=TOP)

        hint = Label(sidebar, text="Double click to show details \n(Wait time dependent on computer)")
        hint.pack(side=TOP)

        # --- posterior list ----------------------------------------------------
        list_font = tkFont.Font(family="Courier", size=10)
        self.display = Listbox(sidebar, border=1, relief=SUNKEN,
                               width=30, height=25, font=list_font)
        self.display.pack(side=TOP, fill=Y, expand=1)
        self.display.bind("<Double-Button-1>", self.onSelect)

        self.display_btn = Button(sidebar, text="Show details", command=self.onSelect)
        self.display_btn.pack(side=TOP)

        # The list box exists now, so it can be filled.
        self.update_output()

        # --- matplotlib canvas -------------------------------------------------
        self.fig = Figure(figsize=(10, 4), dpi=100)
        self.canvas = FigureCanvasTkAgg(self.fig, master=self)
        self.canvas.show()
        self.canvas.get_tk_widget().pack(side=LEFT, fill=BOTH, expand=1)
        self.canvas._tkcanvas.pack(side='top', fill='both', expand=1)

    def replot(self):
        """Draw the detail figure for the posterior entry selected in the list box.

        For each of the three HSV dimensions a subplot is added to a 4-row
        grid on ``self.fig``. Each subplot shows a bar chart of
        ``self.curve_data`` for the selected name, a blue line at the probe
        value, the LUX curve, and a Gaussian density on a twin y-axis.
        Afterwards a figure legend and a title with the selected name are
        added, and "done replotting" is printed.

        The figure is not cleared here; callers clear it before calling.
        Indexing ``curselection()[0]`` raises IndexError if no row is selected.
        """

        def gb(x,i,t):
            """Return the bin index of ``x`` when the range up to ``t`` is cut into ``i`` bins."""
            #t is max value, i in number of bins, x is the thing to be binned
            if x==t:
                # The maximum goes into the last bin, not one past it.
                return i-1
            elif x==0.0:
                return 0
            return int(floor(float(x)*i/t))
        # NOTE(review): hsv_title is filled below but never read in this method.
        hsv_title = []
        j=self.display.curselection()[0]
        name = self.current_post[j][0]
        # NOTE(review): mult is only referenced by commented-out code below.
        mult = lambda x: reduce(operator.mul, x)
        g_labels = []; lux_labels=[]; all_g_params=[]
        # One pass per dimension: 0 = hue, 1 = saturation, 2 = value.
        for i in range(3):

            def align_yaxis(ax1, v1, ax2, v2):
                """Adjust the upper y-limit of ax2 so that v2 in ax2 is aligned to v1 in ax1."""
                _, y1 = ax1.transData.transform((0, v1))
                _, y2 = ax2.transData.transform((0, v2))
                inv = ax2.transData.inverted()
                _, dy = inv.transform((0, 0)) - inv.transform((0, y1-y2))
                miny, maxy = ax2.get_ylim()

                ax2.set_ylim(miny, maxy+dy)

            # Only three of the four rows are used by subplots.
            subplot = self.fig.add_subplot(4,1,i+1)
            dim_label = ["H", "S","V"][i]
            subplot.set_ylabel(r"$P(k^{true}_{%s}|x)$" % ["H", "S","V"][i] )

            curve_data = self.curve_data[name][i]

            # scale maps a value in [0, 1] linearly onto [a, b] = [0.3, 0.9].
            scale = lambda x,a=0.3,b=0.9: (b-a)*(x)+a
            # p_x looks up the normalizer bin that holds x for this dimension.
            p_x = lambda x: self.normalizer[i][gb(x,len(self.normalizer[i]),[360,100,100][i])]
            max_p_x = max(self.normalizer[i])
            # Bar colours are grey levels given as strings:
            #1 is white, 0 is black. so we want highly probable thigns to be black..


            if self.lux.get_adj(self.current_post[j][0]):
                # Adjusted case: hue is shown on [-180, 180], so values above
                # 180 are shifted down by 360 (pp); other dimensions pass through.
                support = [[-180,180], [0,100],[0,100]][i]
                pp = lambda x,i: x-360 if i==0 and x>180 else x
                hacky_solution = [360,100,100][i]
                w = 1.5 if i==0 else 1

                conv = lambda x: x*support[1]/len(curve_data)
                bar_colors = ["%s" % (scale(1-p_x(conv(x))/max_p_x)) for x in range(len(curve_data))]
                # Bar positions go through atan2(sin, cos) in degrees before pp is applied.
                bar1 = subplot.bar([pp(atan2(sin((x*hacky_solution/len(curve_data))*pi/180),cos((x*hacky_solution/len(curve_data))*pi/180))*180/pi,i) for x in range(len(curve_data))],[x/max(curve_data) for x in curve_data], label="%s data" % j,ec="black",width=w,linewidth=0,color=bar_colors)
            else:
                # Unadjusted case: hue is shown on [0, 360] and pp is the identity.
                support = [[0,360], [0,100],[0,100]][i]
                w = 1.5 if i==0 else 1
                conv = lambda x: x*support[1]/len(curve_data)
                bar_colors = ["%s" % (scale(1-p_x(conv(x))/max_p_x)) for x in range(len(curve_data))]
                bar1 = subplot.bar([x*support[1]/len(curve_data) for x in range(len(curve_data))],[x/max(curve_data) for x in curve_data], label="%s data" % name[0],ec="black",width=w,linewidth=0,color=bar_colors)
                pp = lambda x,*args: x

            # Vertical blue line at the probe's value in this dimension.
            point = pp(self.hsv_color[i]*[360,100,100][i],i)
            hsv_title.append(point)
            probeplot = subplot.plot([point,point], [0,1],linewidth=3,c='blue',label="Probe")

            #for j in range(5):
            lux_plot = self.plot_lux_model(self.lux.get_params(self.current_post[j][0])[i], subplot, self.current_post[j][0],support, i)
            # The Gaussian lives on a twin y-axis because its scale differs.
            subplot2 = subplot.twinx()
            gm_plot,gm_height = self.plot_gm_model([pp(g_param,i) for g_param in self.gm[self.current_post[j][0]][0][i]], subplot2, self.current_post[j][0], support)
            # Unfilled, edgeless rectangle: a blank handle for text-only legend rows.
            extra = Rectangle((0, 0), 1, 1, fc="w", fill=False, edgecolor='none', linewidth=0)

            subplot.legend([extra], [["Hue", "Saturation", "Value"][i]],loc=2,frameon=False)


            # Figure-legend handles are taken from the first (hue) subplot only.
            if i==0: legend_set=lux_plot+[extra,extra,extra]+gm_plot+[extra,extra,extra]
            lux_params = self.lux.get_params(self.current_post[j][0])[i]
            g_params = [pp(g_param,i) for g_param in self.gm[self.current_post[j][0]][0][i]]
            all_g_params.append(g_params)

            g_labels.append(r"$\mu^{%s}=$%2.2f, $\sigma^{%s}$=%2.2f" % (dim_label, g_params[0],dim_label,g_params[1]))
            #lux_labels.append(r"$\mu^{L,%s}=$%2.2f, $E[\tau^{L,%s}]$=%2.2f, $\alpha^{L,%s}$=%2.2f, $\beta^{L,%s}$=%2.2f, $\mu^{U,%s}=$%2.2f, $E[\tau^{L,%s}]$=%2.2f, $\alpha^{U,%s}$=%2.2f, $\beta^{U,%s}$=%2.2f" % (dim_label, lux_params[0],dim_label, (lux_params[0]-lux_params[1]*lux_params[2]),dim_label,lux_params[1],dim_label, lux_params[2],dim_label,lux_params[3],dim_label,(lux_params[3]+lux_params[4]*lux_params[5]),dim_label, lux_params[4],dim_label,lux_params[5]))
            lux_labels.append(r"$\mu^{L,%s}=$%2.2f, $\alpha^{L,%s}$=%2.2f, $\beta^{L,%s}$=%2.2f, $\mu^{U,%s}=$%2.2f,  $\alpha^{U,%s}$=%2.2f, $\beta^{U,%s}$=%2.2f" % (dim_label, lux_params[0],dim_label, lux_params[1],dim_label, lux_params[2],dim_label,lux_params[3],dim_label,lux_params[4],dim_label,lux_params[5]))

            subplot.set_xlim(support[0],support[1])
            subplot.set_ylim(0,1.05)
            subplot2.set_xlim(support[0],support[1])
            subplot2.set_ylabel(r"$P(x|Gaussian_{%s})$" % ["H", "S","V"][i])
            # Line the Gaussian's peak up with height 1 on the main axis.
            align_yaxis(subplot, 1., subplot2, gm_height)


        # NOTE(review): leg_loc is assigned but never used.
        leg_loc =(0.9,0.2)

        # Probe colour rescaled to (0-360, 0-100, 0-100) for get_phi.
        datum = [x*[360,100,100][i] for i,x in enumerate(self.hsv_color)];  phi_value = self.lux.get_phi(datum,self.current_post[j][0])
        #gauss_value = mult([norm.pdf(datum[i],all_g_params[i][0],all_g_params[i][1]) for i in range(3)])
        # Handles and labels are matched by position: probe line, LUX curve,
        # three blank handles for the per-dimension LUX parameters, Gaussian
        # curve, then one blank handle for the joined Gaussian parameters.
        # probeplot is the one left over from the last loop iteration.
        leg=self.fig.legend(probeplot+legend_set, ["Probe X"]+ [r"$\mathbf{\phi}_{%s}(X)=\mathbf{%2.5f}$; $\mathbf{\alpha}=\mathbf{%2.4f}$" % (self.current_post[j][0],phi_value,self.lux.get_avail(self.current_post[j][0]))]+lux_labels+
                                     [r"$Normal^{Hue}_{%s}$; $prior(%s)=%2.4f$" % (self.current_post[j][0],self.current_post[j][0], self.gm[self.current_post[j][0]][2])]+[g_labels[0]+"; "+g_labels[1]+"; "+g_labels[2]]
                                     , loc=8, handletextpad=4,labelspacing=0.1)


        self.fig.suptitle("%s" % name, size=30)

        print "done replotting"

    def onChoose(self, *args):
        """Ask for a new probe colour and refresh the whole display.

        Bound both to the "Select Color" button and to a click on the swatch,
        hence the ignored ``*args``. If the dialog is cancelled, a message is
        printed and nothing changes.
        """
        rgb, hx = tkColorChooser.askcolor()
        if rgb is None:
            # askcolor() returns (None, None) on cancel. Testing for that
            # explicitly, rather than catching everything, lets real errors
            # surface instead of being reported as a cancel.
            print("I think you hit cancel")
            return

        red, green, blue = rgb
        self.hex_color = hx
        self.hsv_color = colorsys.rgb_to_hsv(red/255.0, green/255.0, blue/255.0)
        self.frame.config(bg=hx)
        self.update_output()
        self.fig.clear()
        self.replot()
        self.canvas.draw()

    def onSelect(self, *args):
        """Redraw the detail figure for the row currently selected in the list.

        Used as both a button command and an event binding, so any arguments
        are accepted and ignored.
        """
        figure = self.fig
        figure.clear()      # start from an empty figure
        self.replot()       # rebuild the subplots for the selection
        self.canvas.draw()  # push the new figure to the screen
Ejemplo n.º 4
0
Archivo: demo.py Proyecto: cogniton/lux
class Example(Frame):
  
    def __init__(self, parent):
        """Load the LUX model from ``lux.xml`` and build the interface.

        Parameters
        ----------
        parent : the Tk container this frame is placed in.
        """
        Frame.__init__(self, parent)

        self.parent = parent
        # One curve colour per plotted posterior entry.
        self.color_list = ["red", "blue", "green", "orange", "purple"]
        # The probe starts as pure blue, kept as both hex and HSV.
        self.hex_color = '#0000ff'
        self.hsv_color = colorsys.rgb_to_hsv(0.0, 0.0, 1.0)

        print("Getting Model")
        self.lux = lux.LUX("lux.xml")

        print("Creating UI")
        self.initUI()
        self.update_output()

    def update_output(self):
        """Recompute the posterior for the probe colour and list it in the text box.

        The probe colour is scaled to (0-360, 0-100, 0-100), sent to
        ``self.lux.full_posterior`` and the result is stored in
        ``self.current_post``. Up to the first 25 entries are written, one
        per line, and the widget is left read-only.
        """
        h, s, v = self.hsv_color
        items = self.lux.full_posterior((h * 360, s * 100, v * 100))
        self.current_post = items

        # Slicing instead of indexing 0..24 keeps this working when the
        # posterior holds fewer than 25 entries.
        listing = ''.join('{} ({:.3f})\n'.format(entry[0], entry[1])
                          for entry in items[:25])

        # The widget must be editable while its content is replaced.
        self.display.config(state=NORMAL)
        # Text indices are "line.column" strings; "1.0" is the very start.
        self.display.delete("1.0", END)
        self.display.insert(END, listing)
        self.display.config(state=DISABLED)
    
    def make_plotter(self, params,ax1,label,cur_color,support):
        """Draw one phi curve on ``ax1`` as three segments in ``cur_color``.

        The segments are an upper tail from ``mu2`` to 360, a flat top at
        height 1 between the two means, and a lower tail from ``mu1`` to -180.
        Each tail is ``1 - cdf`` of a ``gam_dist`` evaluated at the distance
        from its mean. Only the lower tail carries the legend label.

        Parameters
        ----------
        params : six values ``(mu1, shape1, scale1, mu2, shape2, scale2)``.
        ax1 : axes to draw on.
        label : name inserted into the legend label of the curve.
        cur_color : colour used for all three segments.
        support : accepted for call compatibility; not used here.

        Returns
        -------
        Whatever ``ax1.plot`` returns for the lower-tail segment.
        """
        line_width = 3
        mu1, sh1, sc1, mu2, sh2, sc2 = params

        # NOTE(review): gam_dist is presumably scipy.stats.gamma, whose frozen
        # distributions accept arrays - confirm against the file's imports.
        lower_tail = gam_dist(sh1, scale=sc1)
        upper_tail = gam_dist(sh2, scale=sc2)

        lx = np.linspace(mu1, -180)
        rx = np.linspace(mu2, 360)

        ax1.plot(rx, 1 - upper_tail.cdf(np.abs(rx - mu2)),
                 linewidth=line_width, c=cur_color)
        # Flat top, pulled in by 1% at each end.
        ax1.plot([1.01*mu1, 0.99*mu2], [1, 1],
                 linewidth=line_width, c=cur_color)
        return ax1.plot(lx, 1 - lower_tail.cdf(np.abs(lx - mu1)),
                        c=cur_color, label=r"$\phi^{Hue}_{%s}$" % label,
                        linewidth=line_width)
        
        
        
    def initUI(self):
        """Create and pack the widgets: colour picker, posterior text, figure canvas.

        From left to right the window holds ``self.color_frame`` (swatch
        above the "Select Color" button), the posterior text box and the
        matplotlib canvas. The first plot is drawn into ``self.fig`` before
        the canvas is created.
        """
        self.parent.title("Interactive LUX visualization")
        self.pack(fill=BOTH, expand=1)

        # --- colour picker column ---------------------------------------------
        self.color_frame = Frame(self, border=1)
        picker = self.color_frame
        picker.pack(side=LEFT)

        self.frame = Frame(picker, border=1, relief=SUNKEN,
                           width=100, height=100)
        self.frame.pack(side=TOP)
        self.frame.config(bg=self.hex_color)

        self.btn = Button(picker, text="Select Color", command=self.onChoose)
        self.btn.pack(side=TOP)

        # --- posterior listing ------------------------------------------------
        self.display = Text(self, border=1, relief=SUNKEN, width=30, height=5)
        self.display.pack(side=LEFT, fill=Y, expand=1)
        # Fills the text box and sets self.current_post, which replot() reads.
        self.update_output()

        # --- figure and canvas ------------------------------------------------
        self.fig = Figure(figsize=(10, 4), dpi=100)
        self.replot()
        self.canvas = FigureCanvasTkAgg(self.fig, master=self)
        self.canvas.show()
        self.canvas.get_tk_widget().pack(side=LEFT, fill=BOTH, expand=1)
        self.canvas._tkcanvas.pack(side='top', fill='both', expand=1)

    def replot(self):
        """Plot the phi curves of the top posterior entries, one subplot per HSV dimension.

        Three stacked subplots (hue, saturation, value) are added to
        ``self.fig``. Each shows a black line at the probe value and one curve
        per entry among the first five of ``self.current_post``, coloured
        from ``self.color_list``. A figure legend and a title with the probe
        coordinates are added last.

        The figure is not cleared here; callers clear it before calling.
        """
        top_entries = self.current_post[:5]
        hsv_title = []
        legend_set = []
        for i in range(3):
            # If the best entry is flagged by get_adj, hue is shown on
            # [-180, 180] and probe hues above 180 are shifted down by 360.
            adjusted = self.lux.get_adj(self.current_post[0][0])
            if adjusted:
                support = [[-180, 180], [0, 100], [0, 100]][i]
            else:
                support = [[0, 360], [0, 100], [0, 100]][i]

            subplot = self.fig.add_subplot(3, 1, i+1)
            subplot.set_xlim(support[0], support[1])
            subplot.set_ylabel("%s" % ["Hue", "Saturation", "Value"][i])

            point = self.hsv_color[i] * [360, 100, 100][i]
            if adjusted and i == 0 and point > 180:
                point = point - 360
            hsv_title.append(point)
            subplot.plot([point, point], [0, 1], linewidth=3, c='black', label="Probe")

            # Restarted for every subplot, so the figure legend ends up with
            # the handles of the last one: one per plotted entry.
            legend_set = []
            # Pairing entries with colours, rather than indexing 0..4, keeps
            # this working when the posterior holds fewer than five entries.
            for entry, curve_color in zip(top_entries, self.color_list):
                drawn = self.make_plotter(self.lux.get_params(entry[0])[i],
                                          subplot, entry[0], curve_color, support)
                if isinstance(drawn, list):
                    legend_set += drawn
                else:
                    legend_set.append(drawn)

        self.fig.legend(legend_set, [r"$\phi_{%s}$; $\alpha=%2.4f$" % (x[0],self.lux.get_avail(x[0])) for x in top_entries], loc=1)
        self.fig.suptitle("HSV (%2.2f,%2.2f,%2.2f) top 5 Phi curves" % (hsv_title[0],hsv_title[1],hsv_title[2]))

    def onChoose(self):
        """Ask for a new probe colour and refresh the swatch, listing and plot.

        Cancelling the dialog leaves everything unchanged.
        """
        rgb, hx = tkColorChooser.askcolor()
        if rgb is None:
            # askcolor() returns (None, None) on cancel; unpacking that into
            # red/green/blue used to raise TypeError.
            return

        red, green, blue = rgb
        self.hex_color = hx
        self.hsv_color = colorsys.rgb_to_hsv(red/255.0, green/255.0, blue/255.0)
        self.frame.config(bg=hx)
        self.update_output()
        self.fig.clear()
        self.replot()
        self.canvas.draw()
Ejemplo n.º 5
0
def set_figure_properties(fig: plt.Figure, **kwargs) -> None:
    """
    Ease the configuration of a :class:`matplotlib.figure.Figure`.

    Note that this also sets ``rcParams["font.size"]``, which is a global
    matplotlib setting and not limited to ``fig``.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        The figure.

    Keyword arguments
    -----------------
    fontsize : int
        Value written to ``rcParams["font.size"]``; the figure title uses
        ``fontsize + 1``. Defaults to 12.
    tight_layout : bool
        `True` to fit the whole subplots into the figure area,
        `False` otherwise. Defaults to `True`.
    tight_layout_rect : Sequence[float]
        A rectangle (left, bottom, right, top) in the normalized figure
        coordinate that the whole subplots area (including labels) will
        fit into. Defaults to (0, 0, 1, 1).
    suptitle : str
        The figure title. Defaults to an empty string (no title).
    x_label : str
        Label shared by the whole figure along x. It is carried by an extra
        frameless, tickless axes added with ``fig.add_subplot(111)``.
        ``xlabel`` is accepted as an alias. Defaults to an empty string.
    x_labelpad : float
        ``labelpad`` used for ``x_label``. Defaults to 20.
    y_label : str
        Same as ``x_label`` for the y direction; ``ylabel`` is accepted as
        an alias. Defaults to an empty string.
    y_labelpad : float
        ``labelpad`` used for ``y_label``. Defaults to 20.
    figlegend_on : bool
        `True` to add a figure-level legend built from the handles and
        labels of one axes. Defaults to `False`.
    figlegend_ax : int
        Index into ``fig.get_axes()`` of the axes supplying the legend
        entries. Defaults to 0.
    figlegend_loc : `str` or `Tuple[float, float]`
        ``loc`` of the figure legend. Defaults to "lower center".
    figlegend_framealpha : float
        ``framealpha`` of the figure legend. Defaults to 1.0.
    figlegend_ncol : int
        Number of legend columns. Defaults to 1.
    subplots_adjust_wspace : float
        ``wspace`` passed to ``fig.subplots_adjust``; left untouched when
        omitted.
    subplots_adjust_hspace : float
        ``hspace`` passed to ``fig.subplots_adjust``; left untouched when
        omitted.
    """
    fontsize = kwargs.get("fontsize", 12)
    tight_layout = kwargs.get("tight_layout", True)
    tight_layout_rect = kwargs.get("tight_layout_rect", (0, 0, 1, 1))
    suptitle = kwargs.get("suptitle", "")
    # The docstring historically advertised `xlabel` / `ylabel` while only
    # `x_label` / `y_label` were read; both spellings are honoured now.
    x_label = kwargs.get("x_label", kwargs.get("xlabel", ""))
    x_labelpad = kwargs.get("x_labelpad", 20)
    y_label = kwargs.get("y_label", kwargs.get("ylabel", ""))
    y_labelpad = kwargs.get("y_labelpad", 20)
    figlegend_on = kwargs.get("figlegend_on", False)
    figlegend_ax = kwargs.get("figlegend_ax", 0)
    figlegend_loc = kwargs.get("figlegend_loc", "lower center")
    figlegend_framealpha = kwargs.get("figlegend_framealpha", 1.0)
    figlegend_ncol = kwargs.get("figlegend_ncol", 1)
    wspace = kwargs.get("subplots_adjust_wspace", None)
    hspace = kwargs.get("subplots_adjust_hspace", None)

    # Global side effect: affects every figure drawn afterwards.
    rcParams["font.size"] = fontsize

    if suptitle:
        fig.suptitle(suptitle, fontsize=fontsize + 1)

    # Truthiness treats an explicit None like "no label".
    if x_label or y_label:
        # An invisible axes spanning the figure carries the shared labels.
        # It is appended after the existing axes, so figlegend_ax indices
        # of the real subplots are unaffected.
        ax = fig.add_subplot(111)
        ax.set_frame_on(False)
        ax.set_xticks([])
        ax.set_xticklabels([], visible=False)
        ax.set_yticks([])
        ax.set_yticklabels([], visible=False)

        if x_label:
            ax.set_xlabel(x_label, labelpad=x_labelpad)
        if y_label:
            ax.set_ylabel(y_label, labelpad=y_labelpad)

    if tight_layout:
        fig.tight_layout(rect=tight_layout_rect)

    if figlegend_on:
        handles, labels = fig.get_axes()[figlegend_ax].get_legend_handles_labels()
        fig.legend(
            handles,
            labels,
            loc=figlegend_loc,
            framealpha=figlegend_framealpha,
            ncol=figlegend_ncol,
        )

    if wspace is not None:
        fig.subplots_adjust(wspace=wspace)
    if hspace is not None:
        fig.subplots_adjust(hspace=hspace)
Ejemplo n.º 6
0
def plot_joint_parameter_posteriors(fig: plt.Figure,
                                    parameter_names: [str],
                                    accepted_parameters: [[float]],
                                    predicted_vals: [float],
                                    priors: ["stats.Distribution"],
                                    weights=None) -> plt.Figure:
    """Plot the joint posterior of two parameters with both marginals.

    The figure is a 4x4 grid: a 2-D KDE of the accepted parameters in the
    lower-left 3x3 area, the marginal of the first parameter above it and
    the marginal of the second to its right. Each marginal shows a
    histogram of the accepted values, the prior density, a weighted
    Gaussian KDE and a line at the predicted value.

    Parameters
    ----------
    fig : NOTE(review): this argument is currently ignored; a new figure is
        always created with ``plt.figure()``. Kept so callers stay valid.
    parameter_names : two axis labels; any other length is replaced by two
        empty strings.
    accepted_parameters : accepted samples, each of length exactly 2.
    predicted_vals : the two values marked with orange lines.
    priors : two distributions providing ``ppf`` and ``pdf``.
    weights : optional per-sample weights; uniform when omitted or empty.

    Returns
    -------
    The newly created figure.

    Raises
    ------
    ValueError
        If any accepted sample is not of length 2, or if ``priors`` does not
        hold exactly two entries.
    """
    if len(parameter_names) != 2:
        parameter_names = ["", ""]
    if any(len(p) != 2 for p in accepted_parameters):
        raise ValueError("All `accepted_parameters` must be 2d exactly.")
    if len(priors) != 2:
        raise ValueError("Exactly 2 `priors` must be provided.")

    # An explicit None/empty test: plain truthiness raises for NumPy arrays.
    if weights is None or len(weights) == 0:
        weights = [1 / len(accepted_parameters) for _ in accepted_parameters]

    acc_params_x = [p[0] for p in accepted_parameters]
    acc_params_y = [p[1] for p in accepted_parameters]

    def plot_range(samples, prior):
        # Wide enough for every sample and the prior's 1%-99% interval.
        low = min(samples + [prior.ppf(.01)])
        high = max(samples + [prior.ppf(.99)])
        return np.linspace(low, high, 100)

    def legend_handle(marker, line_style, color):
        # Empty line carrying only the style, usable as a legend proxy.
        return ax_scatter.plot([], [], marker=marker, color=color, ls=line_style)[0]

    fig = plt.figure()
    gs = GridSpec(4, 4)

    ax_scatter = fig.add_subplot(gs[1:4, 0:3])
    ax_hist_x = fig.add_subplot(gs[0, 0:3])
    ax_hist_y = fig.add_subplot(gs[1:4, 3])

    # Joint density. Keyword x/y and `fill` are the forms accepted by
    # current seaborn; positional y and `shade` are the deprecated spelling.
    sns.kdeplot(x=acc_params_x,
                y=acc_params_y,
                cmap="Blues",
                fill=True,
                weights=weights,
                ax=ax_scatter)

    # Histograms of the accepted values.
    ax_hist_x.hist(acc_params_x, density=True)
    ax_hist_y.hist(acc_params_y, orientation='horizontal', density=True)

    # Priors used.
    hist_x_xs = plot_range(acc_params_x, priors[0])
    ax_hist_x.plot(hist_x_xs,
                   priors[0].pdf(hist_x_xs),
                   "k-",
                   lw=2,
                   label='Prior')
    hist_y_xs = plot_range(acc_params_y, priors[1])
    ax_hist_y.plot(priors[1].pdf(hist_y_xs),
                   hist_y_xs,
                   "k-",
                   lw=2,
                   label='Prior')

    # Smooth posterior (weighted KDE). `stats.gaussian_kde` is the public
    # name; the `stats.kde` namespace is deprecated.
    density_x = stats.gaussian_kde(acc_params_x, weights=weights)
    ax_hist_x.plot(hist_x_xs,
                   density_x(hist_x_xs),
                   "--",
                   lw=2,
                   c="orange",
                   label="Posterior KDE")
    density_y = stats.gaussian_kde(acc_params_y, weights=weights)
    ax_hist_y.plot(density_y(hist_y_xs),
                   hist_y_xs,
                   "--",
                   lw=2,
                   c="orange",
                   label="Posterior KDE")

    # Predicted values, drawn after the curves so the current axis limits
    # give the line length.
    ax_hist_x.vlines(predicted_vals[0],
                     ymin=0,
                     ymax=ax_hist_x.get_ylim()[1],
                     colors="orange")
    ax_hist_y.hlines(predicted_vals[1],
                     xmin=0,
                     xmax=ax_hist_y.get_xlim()[1],
                     colors="orange")

    # Proxy handles for a legend. NOTE(review): the legend call itself is
    # disabled, so `handles` is currently unused; the proxies are still
    # created to keep the axes content as before.
    handles = [
        legend_handle(None, "-", "black"),
        legend_handle(None, "-", "orange"),
        legend_handle(None, "--", "orange"),
        legend_handle("s", "", "#1f77b4")
    ]

    #     ax_scatter.legend(handles=handles,labels=["Prior","Posterior Mean","Posterior KDE","Accepted"],fontsize=18)
    ax_scatter.set_xlim((hist_x_xs[0], hist_x_xs[-1]))
    ax_scatter.set_ylim((hist_y_xs[0], hist_y_xs[-1]))

    ax_scatter.set_xlabel(parameter_names[0])
    ax_scatter.set_ylabel(parameter_names[1])

    ax_scatter.margins(0)
    ax_hist_x.margins(0)
    ax_hist_y.margins(0)

    # Name the parameters in the title only when both names are non-empty.
    title_ending = "" if "" in parameter_names else " for {} & {}".format(
        parameter_names[0], parameter_names[1])
    fig.suptitle("Joint and Marginal Posteriors{}.".format(title_ending))

    return fig