コード例 #1
0
    def plot_cg_TICA_contours(self, msm, macrostate, n_states, lag, file_name,
                              name):
        """
        DEPRECATED. Plot the distribution of each PCCA+ macrostate along the
        first two TICA components instead of doing coarse masking.
        Becomes cumbersome for systems with more than 6 macrostates.

        Parameters
        ----------
        msm : TYPE
            Estimated MSM. ``msm.metastable_distributions[i]`` is indexed by
            ``self.dtrajs_concatenated`` to weight the TICA free-energy plot of
            macrostate ``i``.
        macrostate : TYPE
            Number of macrostates. It is passed to ``tools_plots.plot_layout``
            to choose the subplot grid and is shown in the title.
        n_states : TYPE
            Number of microstates, shown in the title.
        lag : TYPE
            Lag time. Multiplied by ``self.timestep`` for the title.
        file_name : TYPE
            Base name of the PNG written to ``self.results``.
        name : TYPE
            Label of the TICA model, shown in the title.

        Returns
        -------
        None.
        """
        rows, columns = tools_plots.plot_layout(macrostate)
        # squeeze=False guarantees a 2-D array of Axes even for a 1x1 grid,
        # so ``axes.flat`` is always available.
        pcca_plot, axes = plt.subplots(rows,
                                       columns,
                                       figsize=(8, 6),
                                       sharex=True,
                                       sharey=True,
                                       squeeze=False)

        distributions = msm.metastable_distributions

        # One panel per macrostate. enumerate() is required because
        # ``axes.flat`` yields bare Axes objects that cannot be unpacked.
        for idx, ax in enumerate(axes.flat):
            if idx >= len(distributions):
                # The layout grid may hold more panels than there are
                # macrostates. Hide the spare ones instead of indexing past
                # the end of the distributions.
                ax.set_visible(False)
                continue
            pyemma.plots.plot_free_energy(
                *self.tica_concat[:, :2].T,
                distributions[idx][self.dtrajs_concatenated],
                legacy=False,
                ax=ax,
                method='nearest')
            ax.set_xlabel('IC 1')
            ax.set_ylabel('IC 2')
            ax.set_title(f'Macrostate {idx+1}')

        pcca_plot.suptitle(
            f'PCCA+ of {self.ft_name}: {n_states} -> {macrostate} macrostates @ {lag*self.timestep}\nTICA of {name}',
            ha='center',
            weight='bold',
            fontsize=12)
        pcca_plot.tight_layout()
        pcca_plot.savefig(f'{self.results}/{file_name}.png',
                          dpi=600,
                          bbox_inches="tight")
        plt.show()
コード例 #2
0
def plot_MFPT(self,
              mfpt_df,
              scheme,
              feature,
              parameters,
              error=0.2,
              regions=None,
              labels=None):
    """Plot heatmaps of the MFPTs between all states, one panel per parameter.

    All panels share one logarithmic colour scale. It spans the smallest and
    largest non-zero values found across the drawn panels.

    Parameters
    ----------
    mfpt_df : TYPE
        MFPT table, filtered for each panel by ``self.mfpt_filter``.
    scheme : str
        Discretization scheme. ``'combinatorial'`` panels get cell edges and
        state labels from ``tools.sampledStateLabels``. Any other scheme gets
        integer state ticks.
    feature : str
        Feature name, shown in the title and passed to ``self.mfpt_filter``.
    parameters : iterable
        One panel is drawn per parameter.
    error : float, optional
        Error tolerance passed to ``self.mfpt_filter``. The default is 0.2.
    regions, labels : optional
        Forwarded to ``tools.sampledStateLabels`` for combinatorial panels.

    Returns
    -------
    list
        The ``pcolormesh`` images that could be drawn (may be empty).
    """
    import copy

    # Adapted from the matplotlib multi-image example. Keep every image in
    # sync when the cmap/clim of one changes (e.g. via the Qt "edit axis,
    # curves and images parameters" GUI). The inequality check prevents
    # infinite recursion between the callbacks.
    def update(changed_image):
        for im in images:
            if (changed_image.get_cmap() != im.get_cmap()
                    or changed_image.get_clim() != im.get_clim()):
                im.set_cmap(changed_image.get_cmap())
                im.set_clim(changed_image.get_clim())

    images = []

    rows, columns = tools_plots.plot_layout(parameters)
    # squeeze=False: ``axes`` is always a 2-D array, even for a single panel.
    fig, axes = plt.subplots(rows,
                             columns,
                             constrained_layout=True,
                             figsize=(9, 6),
                             squeeze=False)
    fig.suptitle(
        f'Discretization: {scheme}\nFeature: {feature} (error tolerance {error:.1%})',
        fontsize=14)

    # Work on a copy so set_bad() does not modify the globally registered
    # colormap. plt.get_cmap is used instead of the deprecated plt.cm.get_cmap.
    cmap_plot = copy.copy(plt.get_cmap("gist_rainbow"))
    cmap_plot.set_bad(color='white')

    for plot, parameter in zip(axes.flat, parameters):

        means = self.mfpt_filter(mfpt_df, scheme, feature, parameter, error)

        try:
            if scheme == 'combinatorial':
                contour_plot = plot.pcolormesh(means,
                                               edgecolors='k',
                                               linewidths=1,
                                               cmap=cmap_plot)
                label_names = tools.sampledStateLabels(
                    regions, sampled_states=means.index.values, labels=labels)
                positions = np.arange(0.5, len(label_names) + 0.5)
                plot.set_xticks(positions)
                plot.set_xticklabels(label_names, fontsize=7, rotation=70)
                plot.set_yticks(positions)
                plot.set_yticklabels(label_names, fontsize=7)
            else:
                contour_plot = plot.pcolormesh(means, cmap=cmap_plot)
                # Centre the ticks on the cells and label them with 1-based
                # state indices.
                ticks = plot.get_xticks() + 0.5
                plot.set_xticks(ticks)
                plot.set_xticklabels((ticks + 0.5).astype(int))
                plot.set_yticks(ticks[:-1])
                plot.set_yticklabels((ticks + 0.5).astype(int)[:-1])

            plot.set_facecolor('white')
            plot.set_title(parameter)
            plot.set_xlabel('From state', fontsize=10)
            plot.set_ylabel('To state', fontsize=10)
            images.append(contour_plot)

        except Exception:
            # Best effort: a panel that cannot be drawn is reported and
            # skipped. KeyboardInterrupt/SystemExit are no longer swallowed.
            print('No values to plot')

    if not images:
        # Nothing was drawn. There is no colour range to compute and no
        # image to attach a colorbar to.
        return images

    # Find the min and max of the non-zero values of all images, for the
    # shared colour scale. Defaults apply when an image has no non-zero value.
    vmins = []
    vmaxs = []
    for image in images:
        array = image.get_array()
        try:
            vmin_i = np.min(array[np.nonzero(array)])
        except (ValueError, TypeError):
            vmin_i = 1
        try:
            vmax_i = np.max(array[np.nonzero(array)])
        except (ValueError, TypeError):
            vmax_i = 1e12
        vmins.append(vmin_i)
        vmaxs.append(vmax_i)

    vmin = min(vmins)
    vmax = max(vmaxs)

    norm = ml.colors.LogNorm(vmin=vmin, vmax=vmax)

    for im in images:
        im.set_norm(norm)

    print(f'limits: {vmin:e}, {vmax:e}')

    cbar = fig.colorbar(images[-1], ax=axes)
    cbar.set_label(label=r'MFPT (ps)', size='large')
    cbar.ax.tick_params(labelsize=12)
    for im in images:
        # ``callbacks`` replaces the removed ``callbacksSM`` alias.
        im.callbacks.connect('changed', update)

    return images
コード例 #3
0
ファイル: Featurize.py プロジェクト: hcarv-itb/SimFound_v2
    def plot(self,
             input_df=None,
             method=None,
             feature_name=None,
             level=2,
             subplots_=True,
             stats=True):
        """
        Plot a *Featurized* pandas DataFrame, either stored under a Featurize
        instance or given as external input.

        The DataFrame columns are a MultiIndex whose levels are named ``l1``,
        ``l2``, ... following the Project *hierarchy* ontology.

        For each value of level ``l{level}``, a figure of sublevel
        (``l{level+1}``) subplots is drawn and saved. A summary of each
        sublevel, chosen by the method ``kind``, is drawn on that value's
        panel of the *level* figure, which is saved as well.

        Parameters
        ----------
        input_df : pandas.DataFrame, optional
            DataFrame used when ``self.features`` has no entry for
            ``feature_name``. The default is None.
        method : str, optional
            Method name passed to ``self.getMethod``. Only used when the
            DataFrame has no ``name`` attribute. The default is None.
        feature_name : str, optional
            Key looked up in ``self.features``. The default is None.
        level : int, optional
            Hierarchy level to summarize. The default is 2.
        subplots_ : bool, optional
            Forwarded as ``subplots`` to ``DataFrame.plot`` for the sublevel
            panels. The default is True.
        stats : bool, optional
            Currently unused. The default is True.

        Example
        -------
        When Project.hierarchy=['protein', 'ligand', 'parameter']
        Using level=2 will plot *ligand* stats and sublevel *parameter* plots.

        Returns
        -------
        None.
        """
        try:
            input_df = self.features[f'{feature_name}']
            print(f'Feature {feature_name} found.')
        except KeyError:
            print(f'Feature {feature_name} not found. Using input DataFrame.')

        # A DataFrame that carries a ``name`` supplies the method. Otherwise
        # the ``method`` argument is used.
        method = getattr(input_df, 'name', method)
        function, kind, _, _, _ = self.getMethod(method)

        levels = input_df.columns.levels  # might need to subselect here
        units = levels[-1].to_list()[0]

        level_iterables = input_df.columns.get_level_values(
            f'l{level}').unique()
        sublevel_iterables = input_df.columns.get_level_values(
            f'l{level+1}').unique()
        # Defined up front so the suptitle exists even without sublevels.
        title_ = f'{method}: {feature_name}'

        # Level figure: one panel per level value. squeeze=False keeps a 2-D
        # array of Axes even for a single panel.
        rows, columns = tools_plots.plot_layout(level_iterables)
        fig_, axes_ = plt.subplots(rows,
                                   columns,
                                   sharex=True,
                                   sharey=True,
                                   constrained_layout=True,
                                   figsize=(9, 6),
                                   squeeze=False)

        for sup_it, ax_ in zip(level_iterables, axes_.flat):

            sup_df = input_df.loc[:,
                                  input_df.columns.
                                  get_level_values(f'l{level}') == sup_it]

            # Sublevel figure for this level value.
            rows_, columns_ = tools_plots.plot_layout(sublevel_iterables)
            fig_it, axes_it = plt.subplots(rows_,
                                           columns_,
                                           sharex=True,
                                           sharey=True,
                                           constrained_layout=True,
                                           figsize=(9, 6),
                                           squeeze=False)

            for iterable, ax_it in zip(sublevel_iterables, axes_it.flat):

                # level + 1 to access the elements below in the hierarchy
                df_it = sup_df.loc[:,
                                   sup_df.columns.
                                   get_level_values(f'l{level+1}') == iterable]
                # NOTE(review): ``sort_columns`` is deprecated in newer pandas
                # releases. Confirm against the pinned pandas version.
                df_it.plot(kind='line',
                           subplots=subplots_,
                           sharey=True,
                           figsize=(9, 6),
                           legend=False,
                           sort_columns=True,
                           linewidth=1,
                           ax=ax_it)

                ax_it.set_xlabel(df_it.index.name)
                ax_it.set_ylabel(f'{method} ({units})')

                # Summarize the sublevel on the level panel ``ax_``.
                if kind == 'by_time':
                    # Pass ax explicitly. A bare ``.plot()`` would draw on the
                    # current axes, which belongs to the sublevel figure.
                    df_it.mean().plot(ax=ax_)

                elif kind == 'by_element':
                    mean = df_it.mean(axis=1)
                    std = df_it.std(axis=1).values
                    x = mean.index.values
                    ax_.plot(x, mean)
                    # Single +/- 1 std band. A second band was previously drawn
                    # with the mean *values* as x coordinates, which put a
                    # stray polygon on the shared x-axis.
                    ax_.fill_between(x, mean + std, mean - std, alpha=0.3)

                    ax_.set_xlabel(input_df.index.name)
                    ax_.set_ylabel(f'{method} ({units})')

                elif kind == 'global':
                    sns.distplot(df_it.to_numpy().flatten(),
                                 ax=ax_,
                                 hist=False,
                                 kde=True,
                                 kde_kws={
                                     'shade': True,
                                     'linewidth': 2
                                 })
                    ax_.set_yscale('log')

                ax_.set_xlabel(f'{method} ({units})')
                ax_.set_title(f'{method}: {sup_it}')

            fig_it.suptitle(title_)
            fig_it.show()
            fig_it.savefig(os.path.abspath(
                f'{self.results}/{method}_{feature_name}_sub_l{level+1}-{sup_it}.png'
            ),
                           bbox_inches="tight",
                           dpi=600)

        fig_.legend(sublevel_iterables)
        fig_.savefig(os.path.abspath(
            f'{self.results}/{method}_{feature_name}_l{level}_stats.png'),
                     bbox_inches="tight",
                     dpi=600)

        return plt.show()
コード例 #4
0
    def plot(self,
             input_df=None,
             level=2,
             subplots_=True):
        """Plot the discretizations stored in ``self.discretized``.

        Entries are drawn according to their key:

        - keys matching 'shellProfile': log-log line plots of counts;
        - keys matching 'combinatorial': scatter plots of the state index over
          trajectory time;
        - neither present: ``input_df`` is drawn as plain line plots instead.

        One figure is drawn per DataFrame, with one panel per value of
        ``df.columns.levels[level]``. Each figure is saved once as
        ``discretized_{kind}_{name}.png`` in ``self.results``.

        Parameters
        ----------
        input_df : pandas.DataFrame, optional
            Fallback DataFrame with MultiIndex columns whose levels are named
            ``l1``, ``l2``, .... The default is None.
        level : int, optional
            Position in ``df.columns.levels`` whose values define the panels.
            This matches the column level named ``l{level+1}``.
            The default is 2.
        subplots_ : bool, optional
            Forwarded as ``subplots`` to ``DataFrame.plot``. The default is True.

        Returns
        -------
        None.
        """
        dfs = []
        for k, v in self.discretized.items():
            if re.search('shellProfile', k):
                print(f'Discretization {k} found.')
                dfs.append((k, v, 'ShellProfile'))

            if re.search('combinatorial', k):
                print(f'Discretization {k} found.')
                dfs.append((k, v, 'Combinatorial'))

        if not dfs:
            if input_df is None:
                print('Feature not found and no input DataFrame given.')
                return None
            print('Feature not found. Using input DataFrame.')
            # Same (name, df, kind) shape as the discretizations, so the loop
            # below plots it.
            dfs = [('external', input_df, 'external')]

        for name, df, kind in dfs:

            level_iterables = df.columns.levels
            panel_values = level_iterables[level]
            rows, columns = tools_plots.plot_layout(panel_values)
            # squeeze=False keeps a 2-D array of Axes even for a single panel.
            fig_, axes_it = plt.subplots(rows, columns, sharex=True, sharey=False,
                                         constrained_layout=True, figsize=(12, 9),
                                         squeeze=False)

            for iterable, ax_it in zip(panel_values, axes_it.flat):

                # level + 1 because the column level names start at l1
                df_it = df.loc[:, df.columns.get_level_values(f'l{level+1}') == iterable]
                if kind == 'ShellProfile':
                    df_it.plot(kind='line',
                               ax=ax_it,
                               subplots=subplots_,
                               title=f'{name} @{iterable}',
                               figsize=(9, 6),
                               legend=False,
                               sort_columns=True,
                               linewidth=1,
                               loglog=True,
                               xlabel=f'{level_iterables[-1].to_list()[0]}',
                               ylabel='counts')
                elif kind == 'Combinatorial':
                    df_it.plot(x=self.data.index.name, y=np.arange(0, len(df_it.columns)),
                               kind='scatter',
                               ax=ax_it,
                               subplots=subplots_,
                               sharex=True,
                               sharey=True,
                               title=f'{name} @{iterable}',
                               xlabel=f'Trajectory time ({self.data.index.name})',
                               ylabel='State Index',
                               figsize=(9, 6))
                else:
                    # External input of unknown kind: plain line plots.
                    df_it.plot(kind='line',
                               ax=ax_it,
                               subplots=subplots_,
                               title=f'{name} @{iterable}',
                               figsize=(9, 6),
                               legend=False,
                               linewidth=1)

            # Save once per DataFrame, after all of its panels are drawn.
            fig_.savefig(f'{self.results}/discretized_{kind}_{name}.png', bbox_inches="tight", dpi=600)

        return plt.show()
コード例 #5
0
        def dG_plot():
            """Plot the binding free-energy (dG) profile of each scalar group.

            Uses names from the enclosing scope: ``scalars``,
            ``ordered_concentrations``, ``input_df``, ``level``, ``describe``,
            ``quantiles``, ``ranges``, ``bulk_range``, ``resolution``,
            ``name_fit``, ``name_original``, ``shells`` and ``feature_name``.

            For every iterable, dG (in kT) is -(ln N_enz - ln N_opt).
            N_enz (with its error curves) comes from ``self.get_sim_counts``.
            N_opt is the first value returned by ``self.get_fittings``. The
            lower/upper error curves use the same expression.

            Returns
            -------
            tuple
                ``(dG_fits, fig_dG)``.

                - ``dG_fits``: DataFrame collecting, for every iterable, dG,
                  its error curves, N_enz and the reference/fitted curves.
                  It is also written to ``dGProfile_-{describe}.csv`` in
                  ``self.results``.
                - ``fig_dG``: the figure, saved as
                  ``binding_profile_{describe}.png``.
            """
            dG_fits = pd.DataFrame()
            dG_fits.name = r'$\Delta$G'

            # Order the scalar groups by concentration: a group belongs to
            # concentration ``o`` when its first iterable equals ``o``.
            ord_scalars = [(scalar, iterables)
                           for o in ordered_concentrations
                           for scalar, iterables in scalars.items()
                           if iterables[0] == o]

            rows, columns = tools_plots.plot_layout(ord_scalars)
            # squeeze=False keeps ``axes_dG`` a 2-D array even for one panel.
            fig_dG, axes_dG = plt.subplots(rows, columns, sharex=True, sharey=True,
                                           constrained_layout=True, figsize=(6, 6),
                                           squeeze=False)

            for (scalar, iterables), ax_dG in zip(ord_scalars, axes_dG.flat):

                for iterable in iterables:

                    # Important stuff going on here.
                    iterable_df = tools.Functions.get_descriptors(input_df, level, iterable,
                                                                  describe=describe,
                                                                  quantiles=quantiles)
                    N_enz, N_error_p, N_error_m = self.get_sim_counts(iterable_df, iterable,
                                                                      quantiles, describe)
                    (N_opt, N_enz_fit, N_t, bulk_value, fitted_bulk, factor,
                     unit) = self.get_fittings(iterable,
                                               N_enz,
                                               ranges,
                                               bulk_range,
                                               resolution,
                                               name_fit)

                    # dG (kT) = -(ln N_enz - ln N_opt), likewise for the
                    # lower/upper error curves.
                    a, a_p, a_m, b = [np.log(i) for i in (N_enz, N_error_p, N_error_m, N_opt)]

                    dG = pd.DataFrame({rf'$\Delta$G {iterable}': np.negative(a - b)})
                    dG_err_m = np.negative(a_m.subtract(b, axis='rows'))
                    dG_err_p = np.negative(a_p.subtract(b, axis='rows'))

                    theoretic_df = pd.DataFrame({f'{name_original} {iterable}': N_t,
                                                 f'{name_fit} {iterable}': N_opt},
                                                index=N_enz.index.values)
                    dG_fits = pd.concat([dG_fits, dG, dG_err_m, dG_err_p, N_enz, theoretic_df],
                                        axis=1)

                    if describe == 'mean':
                        ax_dG.plot(ranges, dG, color='green')
                        ax_dG.fill_between(ranges, dG_err_m, dG_err_p, alpha=0.5, color='green')
                        ax_dG.set_ylim(-4, 4)

                    elif describe == 'quantile':
                        ax_dG.plot(ranges, dG, color='orange')
                        ax_dG.set_ylim(-9, 4)
                        # Successive quantile curves fade out.
                        for idx, m in enumerate(dG_err_m, 1):
                            ax_dG.plot(ranges, dG_err_m[m], alpha=1 - (0.15 * idx))
                        for idx, p in enumerate(dG_err_p, 1):
                            ax_dG.plot(ranges, dG_err_p[p], alpha=1 - (0.15 * idx))

                ax_dG.axhline(y=0, ls='--', color='black')
                ax_dG.set_xlim(1, bulk_range[-1] + 10)
                ax_dG.set_title(iterable, fontsize=10)
                ax_dG.set_xscale('log')

                if len(shells):
                    # Axes.vlines requires explicit ymin/ymax. Span the
                    # current y-range of the panel.
                    ymin, ymax = ax_dG.get_ylim()
                    ax_dG.vlines(shells, ymin, ymax, colors='grey', linewidth=1)

            # TODO: Change this for even number of iterables, otherwise data is wiped.
            # Hide the trailing panel of a large, odd-sized grid.
            n_axes = axes_dG.size
            if n_axes > 7 and n_axes % 2 != 0:
                axes_dG.flat[-1].axis("off")

            handles, labels = ax_dG.get_legend_handles_labels()
            fig_dG.legend(handles, labels, bbox_to_anchor=Discretize.def_locs)
            fig_dG.text(0.5, -0.04, r'$\itd$$_{NAC}$ ($\AA$)', ha='center', va='center', fontsize=12)
            fig_dG.text(-0.04, 0.5, r'$\Delta$G ($\it{k}$$_B$T)', ha='center', va='center',
                        rotation='vertical', fontsize=12)
            fig_dG.suptitle(f'Feature: {feature_name}\n{describe}')

            fig_dG.savefig(f'{self.results}/binding_profile_{describe}.png', dpi=600, bbox_inches="tight")
            fig_dG.show()
            dG_fits.to_csv(f'{self.results}/dGProfile_-{describe}.csv')

            return dG_fits, fig_dG
コード例 #6
0
        def bulk_fitting():
            
            rows, columns=tools_plots.plot_layout(iterables)
            fig_fits,axes_fit=plt.subplots(rows, columns, sharex=True, sharey=True, constrained_layout=True, figsize=(6,6))
            
            try:
                subplots=axes_fit.flat
            except AttributeError:
                subplots=axes_fit
                subplots=[subplots]
            
            
            legends = [name_original, name_fit]
            
            for iterable, ax_fit in zip(iterables, subplots):
                
                #TODO: Store and send to dG to avoid double calc. Dict.
                #Important stuff going on here.
                
                iterable_df = tools.Functions.get_descriptors(input_df, 
                                                              level, 
                                                              iterable, 
                                                              describe=describe, 
                                                              quantiles=quantiles)
                N_enz, N_error_p, N_error_m  = self.get_sim_counts(iterable_df, iterable, quantiles, describe)            
                N_opt, N_enz_fit, N_t, bulk_value, fitted_bulk, factor, unit = self.get_fittings(iterable, 
                                                                                                 N_enz, 
                                                                                                 ranges, 
                                                                                                 bulk_range, 
                                                                                                 resolution, 
                                                                                                 name_fit)
    
                ax_fit.plot(ranges, N_t, label=r'N$_{(i)}$reference (initial)', color='red', ls='--')
                ax_fit.plot(ranges, N_opt, label=r'N$_{(i)}$reference (fit)', color='black')
    

                if describe == 'mean':
                       
                    ax_fit.plot(ranges, N_enz, label='N$_{(i)}$enzyme', color='green')
                    ax_fit.fill_betweenx(N_enz_fit, bulk_range[0], bulk_range[-1], label='Bulk', color='grey', alpha=0.8)
                    ax_fit.set_ylim(1e-4, 100)               
                    ax_fit.fill_between(ranges, N_error_m, N_error_p, color='green', alpha=0.3)
                    locs=(0.9, 0.13)
                
                elif describe == 'quantile':
                        
                    ax_fit.plot(ranges, N_enz, label='N$_{(i)}$enzyme', color='orange')
                    ax_fit.set_ylim(1e-3, 200)
                    
                    for idx, m in enumerate(N_error_m, 1):
                        ax_fit.plot(ranges, N_error_m[m], alpha=1-(0.15*idx), label=m, color='green')
                        legends.append(m)
                    for idx, p in enumerate(N_error_p, 1):
                        ax_fit.plot(ranges, N_error_p[p], label=p, alpha=1-(0.15*idx), color='red') 
                        legends.append(p)
                    locs=(0.79, 0.12)
                
                
                #ax_fit.grid()        
                ax_fit.set_yscale('log')
                ax_fit.set_xscale('log')
                ax_fit.set_xlim(1,bulk_range[-1] + 10)
                ax_fit.set_title(f'{iterable} ({np.round(fitted_bulk, decimals=1)} {unit})', fontsize=12)
    
            if describe == 'mean':
                legends.append('N$_{(i)}$enzyme')
                legends.append('Bulk')
            if describe == 'quantile':
                legends.append(N_enz.name)
            
            if not axes_fit.flat[-1].lines: 
                axes_fit.flat[-1].set_visible(False)
            
            if len(axes_fit.flat) > 7 and (len(axes_fit.flat % 2)) != 0:
                axes_fit.flat[-1].axis("off")
                locs=(0.9, 0.5)
                
            handles, labels = ax_fit.get_legend_handles_labels()
            fig_fits.subplots_adjust(wspace=0) #wspace=0, hspace=0
            fig_fits.legend(handles, labels, bbox_to_anchor=Discretize.def_locs) #legends, loc=locs)
            fig_fits.text(0.5, -0.04, r'Shell $\iti$ ($\AA$)', ha='center', va='center', fontsize=12)
            fig_fits.text(-0.04, 0.5, r'$\itN$', ha='center', va='center', rotation='vertical', fontsize=12)
            fig_fits.suptitle(f'Feature: {feature_name}\n{describe}')
            fig_fits.tight_layout()
            fig_fits.show()
            fig_fits.savefig(f'{self.results}/{feature_name}_{describe}_fittings.png', dpi=600, bbox_inches="tight")