Beispiel #1
0
class Gaussian_Density_Estimator:
    """Small convenience wrapper around a KDEpy ``FFTKDE`` estimator.

    The wrapped estimator is stored on ``self.estimator``; fitting and
    evaluation are delegated to it unchanged.
    """

    def __init__(self, kernel='gaussian', bw='silverman'):
        """Create the underlying ``FFTKDE`` from ``kernel`` and ``bw``."""
        self.estimator = FFTKDE(kernel=kernel, bw=bw)

    def train(self, data, weights=None):
        """Fit the wrapped estimator on ``data`` with optional ``weights``."""
        self.estimator.fit(data, weights=weights)

    def score_samples(self, input_x=None):
        """Evaluate the wrapped estimator.

        When ``input_x`` is given it is forwarded to ``evaluate`` and the
        single result is returned. When it is ``None``, ``evaluate()`` is
        called without arguments and its two results are returned as an
        ``(x, y)`` pair.
        """
        if input_x is not None:
            return self.estimator.evaluate(input_x)

        grid, density = self.estimator.evaluate()
        return grid, density


# import numpy as np
# import matplotlib.pyplot as plt

# data = np.random.randn(2**6)
# density_estimator = Gaussian_Density_Estimator()
# density_estimator.train(data)
# # x, y = density_estimator.score_samples()
# # print(x.shape, y.shape)

# x, y = density_estimator.score_samples(10)
# print(y)

# plt.plot(x, y); plt.tight_layout()
# plt.show()
def loadCautiousDict(filename):
    """Load a parameter table from a CSV file into a dict.

    The CSV must have a ``keys`` column; every key maps to the list of the
    remaining cells of its row. The first cell of each row is parsed with
    ``ast.literal_eval`` and unpacked as a ``(prob, length)`` pair.

    Entries whose ``length`` holds more than 200 items are replaced by the
    tuple ``(prob, bandwidth, length)``, where ``bandwidth`` is the ``bw``
    attribute of a Gaussian ``FFTKDE`` (ISJ rule) fitted on ``length``.
    All other entries keep their raw list form.

    :param filename: path (or anything ``pandas.read_csv`` accepts)
    :return: dict keyed by the values of the ``keys`` column
    """
    table = pd.read_csv(filename)
    paramDict = table.set_index('keys').T.to_dict('list')

    for key in list(paramDict):
        prob, length = ast.literal_eval(paramDict[key][0])
        if len(length) <= 200:
            continue

        kde = FFTKDE(kernel='gaussian', bw='ISJ').fit(length)
        # evaluate() is run only for its effect on the estimator before
        # kde.bw is read -- presumably so the ISJ bandwidth is resolved to a
        # number; confirm against the installed KDEpy version.
        kde.evaluate()
        paramDict[key] = prob, kde.bw, length

    return paramDict
def is_outlier(y, x, th=10):
    """Flag entries of ``y`` whose binned robust score exceeds ``th``.

    ``x`` is split into equally wide bins twice -- once starting at
    ``min(x)`` and once shifted left by half a bin (plus a small epsilon).
    ``robust_scale_binned`` scores ``y`` for both binnings, and an entry is
    reported as an outlier only if the smaller of its two absolute scores is
    greater than ``th``.

    The bin width is the range of ``x`` times half the ``bw`` attribute of a
    Gaussian ``FFTKDE`` (ISJ rule) fitted on ``x``.

    :param y: values to score
    :param x: values used for binning
    :param th: score threshold
    :return: boolean array, True where the entry is an outlier
    """
    kde = FFTKDE(kernel='gaussian', bw='ISJ').fit(x)
    kde.evaluate()

    lowest, highest = min(x), max(x)
    bin_width = (highest - lowest) * kde.bw / 2
    eps = _EPS * 10
    stop = highest + bin_width

    aligned_breaks = np.arange(lowest, stop, bin_width)
    shifted_breaks = np.arange(lowest - eps - bin_width / 2, stop, bin_width)

    scores = np.vstack((
        robust_scale_binned(y, x, aligned_breaks),
        robust_scale_binned(y, x, shifted_breaks),
    ))
    return np.abs(scores).min(0) > th
Beispiel #4
0
def kdepy_fftkde(data, a, b, num_bin_joint):
    """ Estimate a joint density with a Gaussian KDEpy.FFTKDE.

    Note: KDEpy.FFTKDE is used here with a single scalar bandwidth.
    The samples are therefore first rescaled so that the box [a, b] (nudged
    outward by 1e-10 on each side) maps onto [-1, 1] in every dimension, and
    the bandwidth passed to the estimator is the mean of the values returned
    by ``bw_from_kdescipy(data, 'scott')`` on the rescaled samples.

    :param data: array of parameter samples
    :param a: list of left boundaries
    :param b: list of right boundaries
    :param num_bin_joint: grid resolution handed to ``grid_for_kde``
    :return: density values reshaped to (num_bin_joint + 1, )*dimensions
    """
    n_dims = len(data[0])
    logging.info('KDEpy.FFTKDe: Gaussian KDE {} dimensions'.format(n_dims))
    started = time()

    lower = np.array(a) - 1e-10
    upper = np.array(b) + 1e-10
    scaled = 2 * (data - lower) / (upper - lower) - 1

    bandwidth = bw_from_kdescipy(scaled, 'scott')
    _, grid_ravel = grid_for_kde(-1 * np.ones(n_dims), np.ones(n_dims), num_bin_joint)

    kde = FFTKDE(kernel='gaussian', bw=np.mean(bandwidth))
    kde.fit(scaled)
    density = kde.evaluate(grid_ravel.T)
    density = density.reshape((num_bin_joint + 1, ) * n_dims)

    finished = time()
    timer(started, finished, "Time for kdepy_fftkde")
    return density
Beispiel #5
0
class KDEpyFFTwithISJBandwidth:
    """Density estimator backed by a Gaussian ``FFTKDE`` using the ISJ rule."""

    description = 'KDE using KDEpy.FFTKDE and ISJ bandwidth'

    def __init__(self, data, bandwidth, xlim=None):
        """Fit the estimator on ``data``.

        ``bandwidth`` and ``xlim`` are accepted but not used by this class;
        the bandwidth rule is always ``'ISJ'``.
        """
        estimator = FFTKDE(kernel="gaussian", bw='ISJ')
        self._instance = estimator.fit(data)

    def pdf(self, x):
        """Evaluate the fitted estimator at ``x``.

        ``x`` must provide a ``.numpy()`` method; its result is what gets
        passed to the estimator.
        """
        points = x.numpy()
        return self._instance.evaluate(points)
Beispiel #6
0
def _KDE(x, y, nGS):
    """Compute a bivariate kde using KDEpy and return E[y | x].

    The density of the stacked ``(x, y)`` samples is evaluated on a grid of
    ``nGS + 6`` by ``2**8`` points. For every grid column in x the
    conditional mean of y is formed, and only the values whose grid position
    lies inside ``[min(x), max(x)]`` are returned.
    """
    # Number of grid points along x and y
    n_grid_x = nGS + 6
    n_grid_y = 2**8

    # Fit on (n_samples, 2) data and evaluate on the auto-generated grid
    samples = np.vstack((x, y)).T
    estimator = FFTKDE(bw=0.025).fit(samples)
    grid, density = estimator.evaluate((n_grid_x, n_grid_y))

    # Grid axes and the density as a (n_grid_x, n_grid_y) table
    x_axis = np.unique(grid[:, 0])
    y_axis = np.unique(grid[:, 1])
    z = density.reshape(n_grid_x, n_grid_y)

    # y_pred = E[y | x] = sum_y p(y | x) * y, with p(y | x) = z / row sum
    row_totals = np.sum(z, axis=1)
    conditional = (z.T / row_totals).T
    y_pred = np.sum(conditional * y_axis, axis=1)

    # Drop grid positions left of the data, then those right of it
    below = np.where(x_axis < np.min(x))
    x_axis = np.delete(x_axis, below)
    y_pred = np.delete(y_pred, below)
    above = np.where(x_axis > np.max(x))
    y_pred = np.delete(y_pred, above)

    return y_pred
    def density(self, compare=None):
        """Plot Gaussian KDE curves of every input band and of the outputs.

        One subplot is drawn per entry of ``self.bands`` (using
        ``self.inputs[:, :, :, i]``) plus a last one for ``self.outputs``.
        Each curve is a ``FFTKDE('gaussian', bw=0.13)`` evaluated on a
        100-point grid.

        When ``compare`` is given, every subplot additionally shows the
        curve of ``compare``, the difference (compare minus self) and a
        dashed zero line spanning this object's value range. Both curves of
        a subplot share one grid whose limits come from ``self.domain``.

        Args:
            compare: optional object exposing ``inputs``, ``outputs`` and
                ``name`` like this one. ``None`` plots only this object.

        The figure is shown with ``plt.show()``; nothing is returned.
        """
        ep = 1e-10  # widens the grid slightly beyond the data limits

        def evaluation_grid(own, other):
            """Return the 100-point grid used for one subplot."""
            if other is not None:
                min_v, max_v = self.domain(own, other)
            else:
                min_v, max_v = own.min(), own.max()
            return np.linspace(min_v - ep, max_v + ep, 100)

        def kde_on(values, grid):
            """Fit a fresh estimator on ``values`` and evaluate it on ``grid``."""
            kde = FFTKDE('gaussian', bw=0.13)
            kde.fit(values)
            return kde.evaluate(grid)

        # squeeze=False keeps an indexable axes array even when there is
        # only the outputs subplot (no bands).
        fig, axes = plt.subplots(1,
                                 len(self.bands) + 1,
                                 figsize=(30, 5),
                                 squeeze=False)
        ax = axes[0]

        # (grid, evaluation) of this object per subplot: bands first,
        # outputs last. Reused below for the comparison curves.
        curves = []
        for i, band in enumerate(self.bands):
            own = self.inputs[:, :, :, i]
            other = compare.inputs[:, :, :, i] if compare is not None else None
            grid = evaluation_grid(own, other)
            evaluation = kde_on(own.ravel(), grid)
            ax[i].plot(grid, evaluation, label=self.name)
            ax[i].set_title(f"{self.name} {band}")
            curves.append((grid, evaluation))

        other = compare.outputs if compare is not None else None
        grid = evaluation_grid(self.outputs, other)
        evaluation = kde_on(self.outputs, grid)
        ax[-1].plot(grid, evaluation, label=self.name)
        ax[-1].set_title(f"{self.name} Outputs")
        curves.append((grid, evaluation))

        if compare is not None:
            for i, band in enumerate(self.bands):
                grid, own_evaluation = curves[i]
                # Evaluate once and reuse for both the curve and the difference
                other_evaluation = kde_on(compare.inputs[:, :, :, i].ravel(),
                                          grid)

                ax[i].plot(grid, other_evaluation, label=compare.name)
                ax[i].plot(grid,
                           other_evaluation - own_evaluation,
                           label="Difference")
                ax[i].set_title(
                    f"{self.name} {band} | Compare: {compare.name}")
                own = self.inputs[:, :, :, i]
                ax[i].plot([own.min(), own.max()], [0.0, 0.0],
                           linestyle='--',
                           alpha=0.3)

            grid, own_evaluation = curves[-1]
            other_evaluation = kde_on(compare.outputs, grid)

            ax[-1].plot(grid, other_evaluation, label=compare.name)
            ax[-1].plot(grid,
                        other_evaluation - own_evaluation,
                        label="Difference")
            ax[-1].set_title(f"{self.name} Outputs | Compare: {compare.name}")
            ax[-1].plot(
                [self.outputs.min(), self.outputs.max()], [0.0, 0.0],
                linestyle='--',
                alpha=0.3)

        # NOTE(review): plt.legend() only decorates the current axes, so the
        # band subplots get labels but no legend -- confirm this is intended.
        plt.legend()
        plt.show()
Beispiel #8
0
def bwSJ(genes_log10_gmean_step1, bw_adjust=3):
    """Return the ISJ bandwidth of the input, scaled by ``bw_adjust``.

    A Gaussian ``FFTKDE`` with ``bw="ISJ"`` is fitted on the values and
    evaluated once; its ``bw`` attribute times ``bw_adjust`` is returned as
    a one-element float array.
    """
    # See https://kdepy.readthedocs.io/en/latest/bandwidth.html
    values = npy.asarray(genes_log10_gmean_step1)
    kde = FFTKDE(kernel="gaussian", bw="ISJ").fit(values)
    kde.evaluate()
    adjusted = kde.bw * bw_adjust
    return npy.array([adjusted], dtype=float)
Beispiel #9
0
def generate_selection(sid, file, kind="degree", dir="in", dataframe=False):
    """Build a Bokeh density plot with a horizontal box-select tool.

    The plotted curve is a Gaussian ``FFTKDE`` (Silverman bandwidth) of
    either the per-row count of non-zero cells (``kind == "degree"``) or of
    all strictly positive cell values (any other ``kind``). Selecting a
    range re-colours the curve and posts the range to ``/postmethod``; the
    reply is written into a ``<p>`` element by the attached JavaScript.

    Args:
        sid: passed to ``get_directed`` to pick the wording of the edge
            message; only used when ``kind`` is not ``"degree"``.
        file: a DataFrame when ``dataframe`` is true, otherwise whatever
            ``read_hdf`` accepts.
        kind: ``"degree"`` for degree mode, anything else for edge-weight
            mode.
        dir: ``"in"`` or ``"out"``; used for labels and, in degree mode,
            validated.
        dataframe: whether ``file`` already is a DataFrame.

    Returns:
        A Bokeh figure, or a ``Paragraph`` saying "No vertices left" when
        ``file`` is ``None`` or empty.

    Raises:
        ValueError: in degree mode when ``dir`` is neither ``"in"`` nor
            ``"out"``.
    """
    # Only objects exposing ``size`` (e.g. a DataFrame) can be checked for
    # emptiness; a path destined for read_hdf has no such attribute and must
    # not crash here.
    if file is None or getattr(file, "size", None) == 0:
        p = Paragraph(text="""No vertices left""")
        return p

    edges = kind != "degree"

    # Upper bound on the number of values fed to the KDE
    limit = 1000000

    if not dataframe:
        df = read_hdf(file)
    else:
        df = file

    ### BASIC DEGREE COUNTING
    if not edges:
        # NOTE(review): both directions count non-zero cells along axis=1,
        # so "in" and "out" yield the same values. One of them presumably
        # should use axis=0 -- confirm against the matrix orientation.
        if dir == "in":
            deg_all = (df.ne(0).sum(axis=1)).to_numpy(copy=True)
        elif dir == "out":
            deg_all = (df[df.columns].ne(0).sum(axis=1)).to_numpy(copy=True)
        else:
            raise ValueError(
                "dir must be 'in' or 'out', got {!r}".format(dir))
    else:
        # convert dataframe to numpy array for efficiency
        adj_matrix = df.to_numpy(copy=True)
        deg_all = adj_matrix.flatten()

    if len(deg_all) > limit:
        deg = choice(deg_all, limit, replace=False)
        # numpy.append returns a new array, so the result has to be kept;
        # the extremes are added so the curve spans the full value range.
        deg = append(deg, array([deg_all.max(), deg_all.min()]))
    else:
        deg = deg_all

    if edges:
        # Only existing edges (positive weight) take part in the estimate
        deg = deg[deg > 0]

    deg = reshape(deg, (-1, 1))

    kde = FFTKDE(kernel='gaussian', bw='silverman').fit(deg)

    # Close the curve on the x axis at both ends so it renders as a patch
    X, Y = kde.evaluate()
    X = append(X, X[-1])
    X = insert(X, 0, X[0])
    Y = append(Y, 0)
    Y = insert(Y, 0, 0)
    complete = ColumnDataSource(data=dict(x=X, y=Y))
    before = ColumnDataSource(data=dict(x=[], y=[]))
    middle = ColumnDataSource(data=dict(x=X, y=Y))
    after = ColumnDataSource(data=dict(x=[], y=[]))

    # type_dependent1: JS that finds or creates the message <p> element.
    # type_dependent2: JS run with the server reply to fill that element.
    if (not edges):

        type_dependent1 = "let p = document.getElementById('between-{}-degree')".format(
            dir) + """
                if(!p){
                p = document.createElement("p")
                """ + 'p.id = "between-{}-degree"'.format(dir) + """
                document.getElementsByClassName("bk-root")[0].appendChild(p)
            }
            """
        type_dependent2 = """
        amount = result;
            let hue = 120 - amount/5;
            if (hue < 0){
                hue = 0;
            }
            colored_amount = `<span style='color: hsl(${hue},100%,43%); font-weight:bold'>` + amount + "</span>"
                   
            let lower = Math.ceil(geometry.x0);
            let upper = Math.floor(geometry.x1);
            if(lower < 0){
                lower = 0;
            }
            if(upper < lower){
                p.innerHTML = "Selected no vertices, since selection doesn't contain an integer degree.";
            }
            else{
            """ + 'p.innerHTML = "Selected " + colored_amount + " vertices with {}-degree between " + lower + " and " + upper + ".";'.format(
            dir) + '}'

    else:
        type_dependent1 = """
            let p = document.getElementById('between-weight')

            if(!p){
                p = document.createElement("p")
                p.id = "between-weight"
                document.getElementsByClassName("bk-root")[0].appendChild(p)
            }
            """

        if get_directed(sid):
            type_dependent2 = """
            amount = result;
            let hue = 120 - amount/20;
            if (hue < 0){
                hue = 0;
            }
            colored_amount = `<span style='color: hsl(${hue},100%,43%); font-weight:bold'>` + amount + "</span>"
            
            let lower = Math.ceil(geometry.x0*100)/100
            if(lower < 0){
                lower = 0;
            }
    
                p.innerHTML = "Selected " + colored_amount + " edges with weight between " + lower + " and " + Math.floor(geometry.x1*100)/100 + "."
            """

        else:
            type_dependent2 = """
                        amount = result;
                        let hue = 120 - amount/20;
                        if (hue < 0){
                            hue = 0;
                        }
                        colored_amount = `<span style='color: hsl(${hue},100%,43%); font-weight:bold'>` + amount + "</span>"

                        let lower = Math.ceil(geometry.x0*100)/100
                        if(lower < 0){
                            lower = 0;
                        }

                            p.innerHTML = "Selected approximately " + colored_amount + " edges with weight between " + lower + " and " + Math.floor(geometry.x1*100)/100 + "."
                        """

    # Splits the full curve into the parts left of, inside and right of the
    # selection, then posts the selected range to the server.
    geometry_callback = CustomJS(args=dict(complete=complete,
                                           before=before,
                                           middle=middle,
                                           after=after),
                                 code="""
    let geometry = cb_data["geometry"]
    let Xs = complete.data.x
    let Ys = complete.data.y

    let bXs = before.data.x
    let bYs = before.data.y
    bXs = []
    bYs = []
    let mXs = middle.data.x
    let mYs = middle.data.y
    mXs = []
    mYs = []
    let aXs = after.data.x
    let aYs = after.data.y
    aXs = []
    aYs = []

    for (let i = 0; i < Xs.length; i++){
    // should use binary search
    let x = Xs[i]
    let y = Ys[i]
    if(x < geometry.x0){
    bXs.push(x)
    bYs.push(y)
    }
    else if (x > geometry.x1){
    aXs.push(x)
    aYs.push(y)
    }
    else {
    mXs.push(x)
    mYs.push(y)
    }
    }

    bXs.unshift(bXs[0])
    bYs.unshift(0)
    bXs.push(bXs[bXs.length-1])
    bYs[bYs.length] = 0
    mXs.unshift(mXs[0])
    mYs.unshift(0)
    mXs.push(mXs[mXs.length-1])
    mYs[mYs.length] = 0
    aXs.unshift(aXs[0])
    aYs.unshift(0)
    aXs.push(aXs[aXs.length-1])
    aYs[aYs.length] = 0

    before.data.x = bXs
    before.data.y = bYs
    middle.data.x = mXs
    middle.data.y = mYs
    after.data.x = aXs
    after.data.y = aYs

    before.change.emit()
    middle.change.emit()
    after.change.emit()
    var amount = 0
    let data = {
        left: geometry.x0,
        right: geometry.x1,
        file: window.location.pathname.substring(8),
        """ + "type: '{}', dir: '{}'".format(kind, dir) + """
    }
    $.post("/postmethod", data, function(result){ """ + type_dependent2 +
                                 "});" + type_dependent1)

    p = figure(plot_width=300, plot_height=300)
    select_tool = BoxSelectTool(dimensions="width", callback=geometry_callback)
    p.add_tools(select_tool)

    p.patch("x", "y", source=before, alpha=0.3, line_width=0, color="#3189ff")
    p.patch("x", "y", source=middle, alpha=1, line_width=0, color="#3189ff")
    p.patch("x", "y", source=after, alpha=0.3, line_width=0, color="#3189ff")

    if (not edges):
        p.xaxis.axis_label = "{}-degree".format(dir)
    else:
        p.xaxis.axis_label = "Edge weight"

    p.yaxis.visible = False
    p.grid.visible = False

    p.toolbar.active_drag = select_tool
    p.toolbar.autohide = True

    p.background_fill_color = None
    p.border_fill_color = None
    p.outline_line_color = None

    return p
Beispiel #10
0
    def save_iter_plot(self,iteration,filename=None,C=None,G=None,D=None,params=None,save_conf=True):
        '''
        Plots a kde plot and the confidence of the discriminator D(x). Saves the plot in `filename'.

        The main figure has between one and three panels:
          1. exact pdf against a kde of the post-processed generator samples (always);
          2. kde of the raw generator output against the pre-processed exact samples
             (only if self.proc_type is not None);
          3. discriminator output against the optimal discriminator p*/(p_th+p*)
             (only if not self.supervised).
        Panel 1 is additionally saved as 'kde_output_iter_XX.pdf' and, if save_conf is set,
        panel 3 as 'D_conf_iter_XX.pdf', both in self.results_path.

        iteration: iteration number used in titles and file names
        filename:  target of the main figure; defaults to 'plot_iter_XX.<self.format>' in self.results_path
        C:         dict of conditional parameters (defaults to self.C_test when self.CGAN is set)
        G, D:      generator / discriminator (default to self.G / self.D)
        params:    parameter dict (defaults to a copy of self.params); entries of C override it
        save_conf: whether to also save the discriminator panel on its own
        '''
        if params is None:
            params = self.params.copy()
        if G is None:
            G = self.G
        if D is None:
            D = self.D
        if (C is None) and self.CGAN:
            C = self.C_test
        if filename is None:
            filename = os.path.join(self.results_path,'plot_iter_%02d.'%iteration + self.format)
        if self.CGAN:
            assert G.c_dim > 0, 'Generator must be a conditional GAN if CGAN is toggled in the dataset.'
            assert str(C.keys()) == str(self.C_test.keys()), 'The tensor specified must include all conditional parameters, in the same order as C_test.'
        else:
            # Vanilla GAN case
            C = dict()

        # Update the relevant parameters with C
        params = {**params,**C}

        #------------------------------------------------------
        # Compute inputs
        #------------------------------------------------------

        if self.CGAN:
            C_test_tensor = make_test_tensor(C,self.N_test,device=self.device)
            in_sample = input_sample(self.N_test,C=C_test_tensor,Z=self.fixed_noise.to(self.device),device=self.device)
        else:
            in_sample = input_sample(self.N_test,C=None,Z=self.fixed_noise,device=self.device)

        output = G(in_sample).detach()
        gendata = postprocess(output.view(-1),params['S0'],proc_type=self.proc_type,S_ref=torch.tensor(params['S_bar'],device=self.device,dtype=torch.float32),eps=self.eps).cpu().view(-1).numpy()

        # Samples of the exact distribution, pre-processed the same way as the training data
        exact_raw = preprocess(self.sample_exact(N=self.N_test,params=params),torch.tensor(params['S0'],\
                        dtype=torch.float32).view(-1,1),proc_type=self.proc_type,S_ref=torch.tensor(params['S_bar'],device=torch.device('cpu'),dtype=torch.float32),eps=self.eps).cpu().view(-1).numpy()

        output = output.view(-1).cpu().numpy()

        # Define domain based on GAN output and exact samples, padded by 10% on both sides
        a1 = np.min(output)-0.1*np.abs(output.min())
        b1 = np.min(exact_raw)-0.1*np.abs(exact_raw.min())
        a2 = np.max(output)+0.1*np.abs(output.max())
        b2 = np.max(exact_raw)+0.1*np.abs(exact_raw.max())
        a = np.min((a1,b1))
        b = np.max((a2,b2))
        if a == 0:
            a -= 1e-20

        if not self.supervised:

            # Define grid for KDE to be computed on
            x_opt = np.linspace(a,b,1000)

            # Compute exact density p* and generator density p_th

            if self.proc_type is None:
                # Use exact pdf for p* if no pre-processing is used
                p_star = self.get_exact_pdf(params)(x_opt)
            else:
                # Otherwise use kernel estimate to compute p*
                p_star = FFTKDE(kernel='gaussian',bw='silverman').fit(exact_raw).evaluate(x_opt)

            kde_th = FFTKDE(kernel='gaussian',bw='silverman').fit(output)
            p_th = kde_th.evaluate(x_opt)

            # Optimal discriminator given G
            D_opt = p_star/(p_th+p_star)

            x_D = torch.linspace(x_opt.min(),x_opt.max(),self.N_test)

            # Build the input to the discriminator
            input_D = x_D.view(-1,1).to(self.device)
            if self.CGAN:
                input_D = torch.cat((input_D,C_test_tensor),axis=1)

        #------------------------------------------------------
        # Select amount of subplots to be shown
        #------------------------------------------------------
        # Only plot pre-processed data if pre-processing is not None
        # Only plot discriminator confidence if vanilla GAN is used

        n_axes = 1
        if self.proc_type is not None:
            n_axes += 1
        if not self.supervised:
            n_axes += 1
        title_string = 'Generator output' if self.proc_type is None else 'Post-processed data'

        # squeeze=False always yields an axes array, also for a single panel
        fig,ax = plt.subplots(1,n_axes,figsize=(10*n_axes,10),dpi=100,squeeze=False)
        ax = ax[0]

        k_ax = 0 # counter for axis index

        #------------------------------------------------------
        # Plot 1: Post-processed data
        #------------------------------------------------------
        y = self.x
        ymin = y.min()-0.1*np.abs(y.min())
        ymax = y.max()+0.1*np.abs(y.max())

        exact_pdf = self.get_exact_pdf(params)

        ax_plot_1 = ax[k_ax]

        ax_plot_1.plot(y,exact_pdf(y),'--k',label='Exact pdf')
        sns.kdeplot(gendata,shade=True,ax=ax_plot_1,label='Generated data')
        ax_plot_1.set_xlabel('$S_t$')
        ax_plot_1.legend()
        ax_plot_1.autoscale(enable=True, axis='x', tight=True)
        ax_plot_1.autoscale(enable=True, axis='y')
        ax_plot_1.set_ylim(bottom=0)
        ax_plot_1.set_title(title_string)

        # Also plot only the kde plot as pdf
        f_kde,ax_kde = plt.subplots(1,1,dpi=100)
        ax_kde.plot(y,exact_pdf(y),'--k',label='Exact pdf')
        sns.kdeplot(gendata,shade=True,ax=ax_kde,label='Generated data')
        ax_kde.set_xlabel('$S_t$')
        ax_kde.legend()
        ax_kde.set_xlim(xmin=ymin,xmax=ymax)
        f_kde.suptitle(f'Iteration {iteration}')
        f_kde.savefig(os.path.join(self.results_path,'kde_output_iter_%02d'%iteration+'.pdf'),format='pdf')
        plt.close(f_kde)

        #------------------------------------------------------
        # Plot 2: Generator output
        #------------------------------------------------------
        if self.proc_type is not None:
            k_ax += 1
            sns.kdeplot(exact_raw,linestyle='--',color='k',ax=ax[k_ax],label='Pre-processed exact')
            sns.kdeplot(output,shade=True,ax=ax[k_ax],label='Generated data')
            ax[k_ax].set_xlabel('$R_t$')
            ax[k_ax].legend()
            ax[k_ax].autoscale(enable=True, axis='x', tight=True)
            ax[k_ax].autoscale(enable=True, axis='y')
            ax[k_ax].set_ylim(bottom=0)
            ax[k_ax].set_title('Generator output')

        #------------------------------------------------------
        # Plot 3: Discriminator confidence
        #------------------------------------------------------

        if not self.supervised:
            k_ax += 1
            # One forward pass serves both the panel and the stand-alone figure
            D_conf = D(input_D).view(-1,1).detach().view(-1).cpu().numpy()
            conf_xlabel = '$S_t$' if self.proc_type is None else '$R_t$'

            ax[k_ax].plot(x_D,D_conf,label='Discriminator output')
            ax[k_ax].plot(x_opt,D_opt,'--k',label='Optimal discriminator')
            ax[k_ax].set_xlabel(conf_xlabel)
            ax[k_ax].legend()

            ax[k_ax].autoscale(enable=True, axis='x', tight=True)
            ax[k_ax].autoscale(enable=True, axis='y')
            ax[k_ax].set_ylim(bottom=0)

            if save_conf:
                # Repeat plot to save discriminator confidence itself as well
                f_conf,ax_conf = plt.subplots(1,1,dpi=100)
                ax_conf.plot(x_D,D_conf,label='Discriminator output')
                ax_conf.plot(x_opt,D_opt,'--k',label='Optimal discriminator')
                ax_conf.set_xlabel(conf_xlabel)
                ax_conf.legend()
                ax_conf.set_xlim(xmin=a,xmax=b)
                f_conf.suptitle(f'Iteration {iteration}')
                f_conf.savefig(os.path.join(self.results_path,'D_conf_iter_%02d'%iteration+'.pdf'),format='pdf')
                plt.close(f_conf)

        #------------------------------------------------------
        # Wrap up
        #------------------------------------------------------

        fig.suptitle(f'Iteration {iteration}')
        fig.savefig(filename,format=self.format)
        # Close this figure explicitly instead of relying on the current-figure state
        plt.close(fig)
def SCTransform(adata,
                min_cells=5,
                gmean_eps=1,
                n_genes=2000,
                n_cells=None,
                bin_size=500,
                bw_adjust=3,
                inplace=True):
    """
    This is a port of SCTransform from the Satija lab. See the R package for original documentation.

    Currently, only regression against the log UMI counts is supported.

    The only significant modification is that negative Pearson residuals are zeroed out to preserve
    the sparsity structure of the data.

    Parameters
    ----------
    adata
        Annotated data object providing ``X``, ``var_names``, ``obs_names``,
        ``var``, ``obs``, ``raw``, ``shape`` and ``copy()``.
    min_cells
        Genes whose column sum in ``X`` is below this value are dropped
        before fitting.
    gmean_eps
        Forwarded to ``gmean`` as ``eps``.
    n_genes
        Number of genes sampled for the step-1 per-gene fits. Genes are
        drawn without replacement with probability inversely proportional
        to the estimated density of their log10 geometric mean. ``None``
        (or a value not smaller than the number of candidates) uses all.
    n_cells
        Number of cells randomly drawn for step 1; ``None`` uses all cells.
    bin_size
        Number of genes handed to the worker pool per batch.
    bw_adjust
        Multiplier applied to the ISJ bandwidth used by the kernel
        regression that smooths the step-1 parameters.
    inplace
        If true, ``adata`` is modified (``X``, ``raw``, ``var`` and ``obs``)
        and nothing is returned; otherwise a new ``AnnData`` restricted to
        the retained genes is returned.

    Returns
    -------
    ``None`` when ``inplace`` is true, otherwise the new ``AnnData``.
    """
    X = adata.X.copy()
    X = sp.sparse.csr_matrix(X)
    X.eliminate_zeros()
    gn = np.array(list(adata.var_names))
    cn = np.array(list(adata.obs_names))
    # NOTE(review): this sums counts per gene, it does not count the cells
    # expressing the gene, so `min_cells` acts as a total-count threshold.
    # Left unchanged to keep results stable; confirm against the R package.
    genes_cell_count = X.sum(0).A.flatten()
    genes = np.where(genes_cell_count >= min_cells)[0]
    # Positions of the retained genes in the original column order; needed
    # to scatter the residuals back into a full-width matrix when inplace.
    genes_ix = genes.copy()

    X = X[:, genes]
    gn = gn[genes]
    genes = np.arange(X.shape[1])

    genes_log_gmean = np.log10(gmean(X, axis=0, eps=gmean_eps))

    # ------------------------------------------------------------------
    # Step 1 subset: optionally subsample cells
    # ------------------------------------------------------------------
    if n_cells is not None and n_cells < X.shape[0]:
        cells_step1 = np.sort(
            np.random.choice(X.shape[0], replace=False, size=n_cells))
        genes_cell_count_step1 = X[cells_step1].sum(0).A.flatten()
        genes_step1 = np.where(genes_cell_count_step1 >= min_cells)[0]
        genes_log_gmean_step1 = np.log10(
            gmean(X[cells_step1][:, genes_step1], axis=0, eps=gmean_eps))
    else:
        cells_step1 = np.arange(X.shape[0])
        genes_step1 = genes
        genes_log_gmean_step1 = genes_log_gmean

    # ------------------------------------------------------------------
    # Per-cell attributes
    # ------------------------------------------------------------------
    umi = X.sum(1).A.flatten()
    log_umi = np.log10(umi)
    X2 = X.copy()
    X2.data[:] = 1  # binarise so the row sum counts detected genes
    gene = X2.sum(1).A.flatten()
    log_gene = np.log10(gene)
    umi_per_gene = umi / gene
    log_umi_per_gene = np.log10(umi_per_gene)

    cell_attrs = pd.DataFrame(index=cn,
                              data=np.vstack(
                                  (umi, log_umi, gene, log_gene, umi_per_gene,
                                   log_umi_per_gene)).T,
                              columns=[
                                  'umi', 'log_umi', 'gene', 'log_gene',
                                  'umi_per_gene', 'log_umi_per_gene'
                              ])

    data_step1 = cell_attrs.iloc[cells_step1]

    # ------------------------------------------------------------------
    # Step 1 subset: optionally subsample genes, favouring sparse regions
    # of the log-gmean distribution
    # ------------------------------------------------------------------
    if n_genes is not None and n_genes < len(genes_step1):
        log_gmean_dens = stats.gaussian_kde(genes_log_gmean_step1,
                                            bw_method='scott')
        xlo = np.linspace(genes_log_gmean_step1.min(),
                          genes_log_gmean_step1.max(), 512)
        ylo = log_gmean_dens.evaluate(xlo)
        xolo = genes_log_gmean_step1
        sampling_prob = 1 / (np.interp(xolo, xlo, ylo) + _EPS)
        genes_step1 = np.sort(
            np.random.choice(genes_step1,
                             size=n_genes,
                             p=sampling_prob / sampling_prob.sum(),
                             replace=False))
        # axis passed explicitly to match the other gmean calls
        genes_log_gmean_step1 = np.log10(
            gmean(X[cells_step1, :][:, genes_step1], axis=0, eps=gmean_eps))

    # ------------------------------------------------------------------
    # Per-gene fits, farmed out to a process pool one bin at a time
    # ------------------------------------------------------------------
    bin_ind = np.ceil(np.arange(1, genes_step1.size + 1) / bin_size)
    max_bin = int(bin_ind.max())

    ps = Manager().dict()

    # os.cpu_count() may return None when the count is undeterminable
    n_procs = os.cpu_count() or 1

    # Design matrix (intercept, log_umi); identical for every bin
    mm = np.vstack((np.ones(data_step1.shape[0]),
                    data_step1['log_umi'].values.flatten())).T

    for i in range(1, max_bin + 1):
        genes_bin_regress = genes_step1[bin_ind == i]
        umi_bin = X[cells_step1, :][:, genes_bin_regress]

        pc_chunksize = umi_bin.shape[1] // n_procs + 1
        pool = Pool(n_procs, _parallel_init,
                    [genes_bin_regress, umi_bin, gn, mm, ps])
        try:
            pool.map(_parallel_wrapper,
                     range(umi_bin.shape[1]),
                     chunksize=pc_chunksize)
        finally:
            pool.close()
            pool.join()

    # Pull a plain copy of the shared results out of the manager proxy
    ps = ps._getvalue()

    model_pars = pd.DataFrame(data=np.vstack([ps[g] for g in gn[genes_step1]]),
                              columns=['Intercept', 'log_umi', 'theta'],
                              index=gn[genes_step1])

    # Floor theta so the dispersion below stays finite
    min_theta = 1e-7
    model_pars['theta'] = np.maximum(model_pars['theta'].values, min_theta)
    dispersion_par = np.log10(1 + 10**genes_log_gmean_step1 /
                              model_pars['theta'].values.flatten())

    model_pars = model_pars.iloc[:, model_pars.columns != 'theta'].copy()
    model_pars['dispersion'] = dispersion_par

    # A gene is an outlier if any of its parameters is flagged
    outliers = np.vstack(([
        is_outlier(model_pars.values[:, i], genes_log_gmean_step1)
        for i in range(model_pars.shape[1])
    ])).sum(0) > 0

    filt = np.invert(outliers)
    model_pars = model_pars[filt]
    genes_step1 = genes_step1[filt]
    genes_log_gmean_step1 = genes_log_gmean_step1[filt]

    # ------------------------------------------------------------------
    # Step 2: smooth the parameters over log gmean with kernel regression
    # ------------------------------------------------------------------
    z = FFTKDE(kernel='gaussian', bw='ISJ').fit(genes_log_gmean_step1)
    z.evaluate()
    bw = z.bw * bw_adjust

    # Predict only inside the range covered by the step-1 genes
    x_points = np.clip(genes_log_gmean, genes_log_gmean_step1.min(),
                       genes_log_gmean_step1.max())

    full_model_pars = pd.DataFrame(data=np.zeros(
        (x_points.size, model_pars.shape[1])),
                                   index=gn,
                                   columns=model_pars.columns)
    for i in model_pars.columns:
        kr = statsmodels.nonparametric.kernel_regression.KernelReg(
            model_pars[i].values,
            genes_log_gmean_step1[:, None], ['c'],
            reg_type='ll',
            bw=[bw])
        full_model_pars[i] = kr.fit(data_predict=x_points)[0]

    # Convert the smoothed dispersion back to theta
    theta = 10**genes_log_gmean / (10**full_model_pars['dispersion'].values -
                                   1)
    full_model_pars['theta'] = theta
    del full_model_pars['dispersion']

    # ------------------------------------------------------------------
    # Step 3: Pearson residuals on the stored (non-zero) entries
    # ------------------------------------------------------------------
    # Residuals are written into X.data in place; with an integer matrix
    # that assignment would silently truncate them.
    if not np.issubdtype(X.dtype, np.floating):
        X = X.astype(np.float64)

    d = X.data
    x, y = X.nonzero()

    mud = np.exp(full_model_pars.values[:, 0][y] +
                 full_model_pars.values[:, 1][y] *
                 cell_attrs['log_umi'].values[x])
    vard = mud + mud**2 / full_model_pars['theta'].values.flatten()[y]

    X.data[:] = (d - mud) / vard**0.5
    X.data[X.data < 0] = 0
    X.eliminate_zeros()

    clip = np.sqrt(X.shape[0] / 30)
    X.data[X.data > clip] = clip

    # ------------------------------------------------------------------
    # Store results
    # ------------------------------------------------------------------
    if inplace:
        adata.raw = adata.copy()

        # Scatter the residuals back to the original gene columns
        rows, cols = X.nonzero()
        Xnew = sp.sparse.coo_matrix((X.data, (rows, genes_ix[cols])),
                                    shape=adata.shape).tocsr()
        adata.X = Xnew  # TODO: add log1p of corrected umi counts to layers
        target = adata
    else:
        target = AnnData(X=X)
        target.var_names = pd.Index(gn)
        target.obs_names = adata.obs_names
        target.raw = adata.copy()

    for c in full_model_pars.columns:
        target.var[c + '_sct'] = full_model_pars[c]

    for c in cell_attrs.columns:
        target.obs[c + '_sct'] = cell_attrs[c]

    for c in model_pars.columns:
        target.var[c + '_step1_sct'] = model_pars[c]

    step1_flag = pd.Series(index=gn, data=np.zeros(gn.size, dtype='int'))
    step1_flag[gn[genes_step1]] = 1
    target.var['genes_step1_sct'] = step1_flag
    # Built directly from the float values; indexed by gene name so that it
    # aligns with `var` even when `var` still holds the filtered-out genes.
    target.var['log10_gmean_sct'] = pd.Series(index=gn, data=genes_log_gmean)

    if not inplace:
        return target
    def perform_fit(self, amp, pixel_pos,  training_library, max_fitpoints=None,
                    nodes=(64, 64, 64, 64, 64, 64, 64, 64, 64)):
        """
        Fit a regression model to individual template pixels

        :param amp: ndarray
            Pixel amplitudes, one per point
        :param pixel_pos: ndarray
            Pixel XY coordinates. The array is transposed on entry and then
            indexed per point, so it is expected as (2, N).
            NOTE(review): an earlier docstring said (N, 2) - confirm callers.
        :param training_library: str
            Backend used for the fit: "sklearn", "kde", "KNN", "loess" or
            "keras"
        :param max_fitpoints: int
            Maximum number of points to include in the fit; a random subset
            is used when more are supplied
        :param nodes: tuple
            Hidden layer sizes (used by the "sklearn" and "keras" backends)
        :return:
            Fitted model for "sklearn", "KNN" and "keras"; a
            LinearNDInterpolator for "kde" and "loess"
        :raises ValueError:
            If training_library is not one of the supported backends
        """
        pixel_pos = pixel_pos.T

        # If we put a limit on this then randomly choose points
        if max_fitpoints is not None and amp.shape[0] > max_fitpoints:
            indices = np.arange(amp.shape[0])
            np.random.shuffle(indices)
            keep = indices[:max_fitpoints]
            amp = amp[keep]
            pixel_pos = pixel_pos[keep]

        if self.verbose:
            print("Fitting template using", training_library, "with", amp.shape[0],
                  "total pixels")

        # We need a large number of layers to get this fit right
        if training_library == "sklearn":
            from sklearn.neural_network import MLPRegressor

            model = MLPRegressor(hidden_layer_sizes=nodes, activation="relu",
                                 max_iter=1000, tol=0,
                                 early_stopping=True, verbose=False,
                                 n_iter_no_change=10)

            model.fit(pixel_pos, amp)
        elif training_library == "kde":
            from KDEpy import FFTKDE
            from scipy.interpolate import LinearNDInterpolator

            # Number of grid points along the amplitude axis
            n_amp_bins = 200
            grid_shape = (self.bins[0], self.bins[1], n_amp_bins)

            x, y = pixel_pos.T
            data = np.vstack((x, y, amp))
            kde = FFTKDE(bw=0.015).fit(data.T)
            points, out = kde.evaluate(grid_shape)
            points_x, points_y, points_z = points.T

            # Density-weighted mean amplitude at each (x, y) grid node
            av_val = np.sum((out * points_z).reshape(grid_shape), axis=-1) / \
                np.sum(out.reshape(grid_shape), axis=-1)

            # x and y are constant along the amplitude axis; take one slice
            points_x = points_x.reshape(grid_shape)[:, :, 0].ravel()
            points_y = points_y.reshape(grid_shape)[:, :, 0].ravel()

            lin = LinearNDInterpolator(np.vstack((points_x, points_y)).T,
                                       av_val.ravel(), fill_value=0)

            return lin

        elif training_library == "KNN":
            from sklearn.neighbors import KNeighborsRegressor

            model = KNeighborsRegressor(10)
            model.fit(pixel_pos, amp)

        elif training_library == "loess":
            from loess.loess_2d import loess_2d
            from scipy.interpolate import LinearNDInterpolator
            # Only non-zero amplitudes take part in the smoothing
            sel = amp != 0
            model = loess_2d(pixel_pos.T[0][sel], pixel_pos.T[1][sel], amp[sel],
                             degree=3, frac=0.005)
            lin = LinearNDInterpolator(pixel_pos[sel], model[0])
            return lin

        elif training_library == "keras":
            from keras.models import Sequential
            from keras.layers import Dense
            import keras

            model = Sequential()
            model.add(Dense(nodes[0], activation="relu", input_shape=(2,)))

            for n in nodes[1:]:
                model.add(Dense(n, activation="relu"))

            model.add(Dense(1, activation='linear'))
            model.compile(loss='mean_absolute_error',
                          optimizer="adam", metrics=['accuracy'])
            stopping = keras.callbacks.EarlyStopping(monitor='val_loss',
                                                     min_delta=0.0,
                                                     patience=10,
                                                     verbose=2, mode='auto')

            model.fit(pixel_pos, amp, epochs=10000,
                      batch_size=100000,
                      callbacks=[stopping], validation_split=0.1, verbose=0)

        else:
            # Previously this fell through to `return model` and failed with
            # an UnboundLocalError that did not name the offending value.
            raise ValueError(
                "Unknown training_library %r; expected one of 'sklearn', "
                "'kde', 'KNN', 'loess' or 'keras'" % (training_library,))

        return model