Пример #1
0
def plot_individual_energy(data, plot_trace_kwargs=None, save=True,
                           title=None, output_file=None):
    """
    Individual plot of the Delta G time evolution.

    Draws the raw values faintly (window=1, alpha=0.2), overlays a moving
    average on the main axis, and marks mean and mean +/- one standard
    deviation (with their numeric values) on the side axis.

    :param data: np.array
    :param plot_trace_kwargs: kwargs for msmexplorer's plot_trace function
    :param save: Bool, save image or not
    :param title: plot title; if None, the global ``args.plot_title`` is used
    :param output_file: path prefix for the saved figures; if None, the
        global ``args.output_file`` is used. ``'_individual.pdf'`` and
        ``'_individual.png'`` are appended to it.
    :return ax, side_ax: The main and side axis produced by plot_trace
    """
    if plot_trace_kwargs is None:
        plot_trace_kwargs = {}
    mean, std = data.mean(), data.std()
    # actual values
    ax, side_ax = plot_trace(data=data, window=1, alpha=0.2,
                             **plot_trace_kwargs)
    # mean and std lines in side_ax
    side_ax.axhline(y=mean)
    side_ax.axhline(y=mean - std, ls='--', lw=0.8)
    side_ax.axhline(y=mean + std, ls='--', lw=0.8)
    # annotate side_ax with actual values
    side_ax.annotate(f"\u03bc = {mean:.02f}", xy=(0.55, 0.90), xycoords='axes fraction')
    side_ax.annotate(f"\u03c3 = {std:.02f}", xy=(0.55, 0.85), xycoords='axes fraction')
    # moving avg in ax: window of ~1% of the series length, at least 1 frame
    plot_trace(data=data, window=max(1, int(len(data) / 100)),
               ax=ax, alpha=1,
               **plot_trace_kwargs)
    # Resolve the title lazily so callers relying on the global `args` keep working.
    ax.set(title=args.plot_title if title is None else title,
           xlabel='Time (ns)',
           # raw string: avoids invalid-escape warnings for \D and \c
           ylabel=r'$\Delta G (kcal \cdot mol^{-1}$)')
    formatter = FuncFormatter(to_ns)
    ax.xaxis.set_major_formatter(formatter)
    if save:
        prefix = args.output_file if output_file is None else output_file
        f = plt.gcf()
        f.savefig(prefix + '_individual.pdf')
        f.savefig(prefix + '_individual.png', dpi=300)
    return ax, side_ax
Пример #2
0
def plot_individual_energy(data,
                           plot_trace_kwargs=None,
                           save=True,
                           title=None,
                           output_file=None):
    """
    Individual plot of the Delta G time evolution.

    Draws the raw values faintly (window=1, alpha=0.2), overlays a moving
    average on the main axis, and marks mean and mean +/- one standard
    deviation (with their numeric values) on the side axis.

    :param data: np.array
    :param plot_trace_kwargs: kwargs for msmexplorer's plot_trace function
    :param save: Bool, save image or not
    :param title: plot title; if None, the global ``args.plot_title`` is used
    :param output_file: path prefix for the saved figures; if None, the
        global ``args.output_file`` is used. ``'_individual.pdf'`` and
        ``'_individual.png'`` are appended to it.
    :return ax, side_ax: The main and side axis produced by plot_trace
    """
    if plot_trace_kwargs is None:
        plot_trace_kwargs = {}
    mean = data.mean()
    std = data.std()
    # actual values
    ax, side_ax = plot_trace(data=data,
                             window=1,
                             alpha=0.2,
                             **plot_trace_kwargs)
    # mean and std lines in side_ax
    side_ax.axhline(y=mean)
    side_ax.axhline(y=mean - std, ls='--', lw=0.8)
    side_ax.axhline(y=mean + std, ls='--', lw=0.8)
    # annotate side_ax with actual values
    side_ax.annotate(f"\u03bc = {mean:.02f}",
                     xy=(0.55, 0.90),
                     xycoords='axes fraction')
    side_ax.annotate(f"\u03c3 = {std:.02f}",
                     xy=(0.55, 0.85),
                     xycoords='axes fraction')
    # moving avg in ax: window of ~1% of the series length, at least 1 frame
    plot_trace(data=data,
               window=max(1, int(len(data) / 100)),
               ax=ax,
               alpha=1,
               **plot_trace_kwargs)
    # Resolve the title lazily so callers relying on the global `args` keep working.
    ax.set(title=args.plot_title if title is None else title,
           xlabel='Time (ns)',
           # raw string: avoids invalid-escape warnings for \D and \c
           ylabel=r'$\Delta G (kcal \cdot mol^{-1}$)')
    formatter = FuncFormatter(to_ns)
    ax.xaxis.set_major_formatter(formatter)
    if save:
        prefix = args.output_file if output_file is None else output_file
        f = plt.gcf()
        f.savefig(prefix + '_individual.pdf')
        f.savefig(prefix + '_individual.png', dpi=300)
    return ax, side_ax
Пример #3
0
def plot_rsmd(traj_list, fout=None):
    """
    Overlay the RMSD time traces of several trajectories on shared axes.

    Each trajectory's RMSD is computed against its own frame 0 and multiplied
    by 10 (presumably nm -> Angstrom, matching the 'RMSD (Å)' label — confirm).
    Legend labels come from the global ``args.Trajectories`` with the last
    three characters removed (presumably a file extension — confirm). Line
    styles come from the global ``palette_cycled`` iterator.

    :param traj_list: non-empty sequence of trajectories accepted by mdtraj.rmsd
    :param fout: path to save the current figure to; if None, nothing is saved
    :return ax, side_ax: the main and side axes created by the first trace
    """
    rmsd_list = [mdtraj.rmsd(traj, traj, 0) * 10 for traj in traj_list]

    ax, side_ax = msme.plot_trace(rmsd_list[0], ylabel='RMSD (Å)', xlabel='Time (ns)',
                                  label=args.Trajectories[0][:-3],
                                  **next(palette_cycled))
    formatter = FuncFormatter(to_ns)
    ax.xaxis.set_major_formatter(formatter)
    # Remaining traces are drawn onto the axes created by the first call;
    # start=1 keeps the label index aligned with rmsd_list / args.Trajectories.
    for i, rmsd in enumerate(rmsd_list[1:], start=1):
        msme.plot_trace(rmsd, ylabel='RMSD (Å)', xlabel='Time (ns)', ax=ax,
                        side_ax=side_ax,
                        label=args.Trajectories[i][:-3],
                        **next(palette_cycled))

    # With many traces the legend becomes unreadable; drop it if one exists.
    legend = ax.get_legend()
    if len(rmsd_list) > 5 and legend is not None:
        legend.remove()
    sns.despine()
    # savefig(None) fails, so honour the None default by skipping the save.
    if fout is not None:
        f = plt.gcf()
        f.savefig(fout)
    return ax, side_ax
Пример #4
0
"""
Trace Plot
==========
"""
from msmbuilder.example_datasets import FsPeptide
from msmbuilder.featurizer import RMSDFeaturizer

import msmexplorer as msme

# Fetch the Fs peptide example dataset and keep only its first trajectory.
traj = FsPeptide().get().trajectories[0]

# RMSD of every frame with respect to frame 0, flattened to a 1-D array.
featurizer = RMSDFeaturizer(reference_traj=traj[0])
rmsd = featurizer.partial_transform(traj).flatten()

# Draw the RMSD time trace.
msme.plot_trace(
    rmsd,
    ylabel='RMSD (nm)',
    xlabel='Timestep',
    label='traj0',
)
Пример #5
0
def plot_tic_array(tica_trajs,
                   nrows,
                   ncols,
                   ts=0.2,
                   subplot_kwargs=None,
                   trace_kwargs=None,
                   free_energy_kwargs=None):
    """
    Plot an nrows x ncols array of the tIC projections, starting from the 1st one.

    Diagonal panels (row i == column j) show the time trace of tIC i.
    Off-diagonal panels call msme.plot_free_energy with obs=(j, i).
    The bottom row of off-diagonal panels is labelled with tIC j on x.
    The first column is labelled with tIC i on y.

    Parameters
    ----------
    tica_trajs: list, or np.ndarray
        The tica transformed trajectory(ies). A list is concatenated along
        the first axis before plotting.
    nrows: int
        Number of rows
    ncols: int
        Number of cols
    ts: float or int (default: 0.2)
        Timestep (in microseconds) between each frame in the trajectory;
        used to relabel the x axis of the trace panels.
    subplot_kwargs: dict, optional
        Arguments to pass to plt.subplots
    trace_kwargs: dict, optional
        Arguments to pass to msme.plot_trace
    free_energy_kwargs: dict, optional
        Arguments to pass to msme.plot_free_energy

    Returns
    -------
    fig : matplotlib figure
        The figure holding the array of panels.

    Raises
    ------
    ValueError
        If tica_trajs is neither a list nor an np.ndarray.
    """
    # None defaults avoid sharing one mutable dict across calls.
    subplot_kwargs = {} if subplot_kwargs is None else subplot_kwargs
    trace_kwargs = {} if trace_kwargs is None else trace_kwargs
    free_energy_kwargs = {} if free_energy_kwargs is None else free_energy_kwargs

    def convert_to_mus(x, pos):
        'function for formatting the x axis of time trace plot'
        # ':g' avoids float-repr noise such as 300.00000000000006 in tick labels
        return '{:g}'.format(x * ts)

    if isinstance(tica_trajs, list):
        tica_trajs = np.concatenate(tica_trajs)
    elif not isinstance(tica_trajs, np.ndarray):
        raise ValueError('tica_trajs must be of type list or np.ndarray')

    fig, ax_list = plt.subplots(nrows=nrows, ncols=ncols, **subplot_kwargs)
    # plt.subplots squeezes singleton dimensions by default (a bare Axes for
    # 1x1, a 1-D array for one row/column); normalise to a 2-D grid so the
    # indexing below also works when the caller passes squeeze=False.
    axes = np.asarray(ax_list, dtype=object).reshape(nrows, ncols)
    for i in range(nrows):
        for j in range(ncols):
            ax = axes[i, j]
            ax.grid(False)
            if i == j:
                msme.plot_trace(tica_trajs[:, i], ax=ax, **trace_kwargs)
                # A Formatter must not be shared between axes, so build one per panel.
                ax.xaxis.set_major_formatter(FuncFormatter(convert_to_mus))

                ax.grid(False, axis='y')
                ax.set_xlabel(r'Time ($\mu$s)')
            else:
                msme.plot_free_energy(tica_trajs,
                                      obs=(j, i),
                                      ax=ax,
                                      **free_energy_kwargs)
                # Bottom row
                if i == (nrows - 1):
                    ax.set_xlabel('tIC{}'.format(j + 1))
            # First column
            if j == 0:
                ax.set_ylabel('tIC{}'.format(i + 1))
    fig.tight_layout()
    return fig