コード例 #1
0
def msd(final_parameters, res, nboot=200, ntraj=100, nsteps=4806, dt=0.5,
        endshow=2000):
    """Plot the MSD of AR-model trajectories against the MD MSD of ``res``.

    Synthetic trajectories are generated with ``GenARData`` from
    ``final_parameters``; their MSD (with a bootstrapped 68 % confidence band)
    is drawn in blue. The MD reference loaded from
    ``trajectories/<res>_msd.pl`` is drawn in black.

    Parameters
    ----------
    final_parameters : dict
        Parameters passed straight to ``GenARData``.
    res : str
        Residue name; used for the MD file path and ``names[res]`` title.
    nboot : int
        Number of bootstrap resamples for the confidence band.
    ntraj : int
        Number of synthetic trajectories to generate.
    nsteps : int
        Number of steps per synthetic trajectory.
    dt : float
        Time between frames (x axis is labelled in ns).
    endshow : int
        Number of leading points to plot; clamped to the available data.
    """
    trajectory_generator = GenARData(params=final_parameters)
    trajectory_generator.gen_trajectory(nsteps, ntraj, bound_dimensions=[0])

    # Calculate MSD of the synthetic trajectories and its bootstrap band
    msd_values = ts.msd(trajectory_generator.traj, 1)
    error = ts.bootstrap_msd(msd_values, nboot, confidence=68)
    mean_msd = msd_values.mean(axis=1)

    MD_MSD = file_rw.load_object('trajectories/%s_msd.pl' % res)

    # Never slice past the shortest series, otherwise t and the curves differ
    # in length and matplotlib raises.
    endshow = min(endshow, len(mean_msd), len(MD_MSD.MSD_average))
    mean_msd = mean_msd[:endshow]
    md_mean = MD_MSD.MSD_average[:endshow]

    t = np.arange(endshow) * dt
    plt.plot(t, mean_msd, lw=2, color='xkcd:blue')
    plt.fill_between(t, mean_msd + error[0, :endshow],
                     mean_msd - error[1, :endshow],
                     alpha=0.3, color='xkcd:blue')

    # NOTE(review): `names` is a module-level mapping not visible here.
    plt.title(names[res], fontsize=18)
    plt.plot(t, md_mean, color='black', lw=2)
    plt.fill_between(t, md_mean + MD_MSD.limits[0, :endshow],
                     md_mean - MD_MSD.limits[1, :endshow],
                     alpha=0.3, color='black')
    plt.tick_params(labelsize=14)
    plt.xlabel('Time (ns)', fontsize=14)
    plt.ylabel('Mean Squared Displacement (nm$^2$)', fontsize=14)

    plt.show()
コード例 #2
0
def plot_realization(final_parameters, res):
    """Show one synthetic AR trajectory built from ``final_parameters``.

    The number of generated steps equals ``MSD.shape[0]`` of the MD MSD
    object stored at ``trajectories/<res>_msd.pl``. The upper panel shows
    column 1 of the generated trajectory (labelled z) and the lower panel
    column 0 (labelled r), both against the step index.
    """
    md_msd = file_rw.load_object('trajectories/%s_msd.pl' % res)
    n_frames = md_msd.MSD.shape[0]

    generator = GenARData(params=final_parameters)
    generator.gen_trajectory(n_frames, 1, bound_dimensions=[0])
    realization = generator.traj[:, 0, :]

    fig, axes = plt.subplots(2, 1, figsize=(12, 5))

    # (axis, trajectory column, y-axis label)
    panels = (
        (axes[0], 1, 'z coordinate'),
        (axes[1], 0, 'r coordinate'),
    )
    for axis, column, label in panels:
        axis.plot(realization[:, column], lw=2)
        axis.set_xlabel('Step number', fontsize=14)
        axis.set_ylabel(label, fontsize=14)
        axis.tick_params(labelsize=14)

    plt.show()
コード例 #3
0
def _mean_cluster_diagonals(labels, matrices):
    """Return the mean diagonal of the matrices belonging to each cluster.

    Parameters
    ----------
    labels : array_like
        Cluster label of each matrix, one entry per matrix along the last
        axis of ``matrices``.
    matrices : np.ndarray
        Stack of 2x2 matrices with shape ``(2, 2, n)``.

    Returns
    -------
    np.ndarray
        ``(n_clusters, 2)`` array; row ``j`` is the mean diagonal of the
        matrices whose label equals the ``j``-th entry of ``np.unique(labels)``.
    """
    labels = np.asarray(labels)
    unique_labels = np.unique(labels)
    means = np.zeros([unique_labels.size, 2])
    for j, label in enumerate(unique_labels):
        members = np.where(labels == label)[0]
        # np.diagonal over the two matrix axes gives shape (len(members), 2)
        means[j, :] = np.diagonal(matrices[..., members], axis1=0,
                                  axis2=1).mean(axis=0)
    return means


def test_cluster(final_parameters,
                 dt_A=None,
                 dt_sigma=None,
                 nclusters_A=None,
                 nclusters_sigma=None,
                 show=True):
    """Cluster per-state A and sigma matrices of 24 fitted trajectories.

    For each of ``final_parameters['ihmmr'][0..23]`` the posterior-mean A and
    sigma of every state found in ``z[0, :]`` are collected. A and sigma are
    clustered independently with ``Cluster``, and each state is assigned the
    combined label ``A_label * n_sigma_clusters + sigma_label``. When ``show``
    is true, the cluster summaries are shown, followed by one figure pair per
    combined cluster with up to four sample trajectories simulated from its
    members.

    Parameters
    ----------
    final_parameters : dict
        Must contain ``'ihmmr'``, indexable 0..23, each item exposing ``z`` and
        ``converged_params`` (keys ``'A'`` and ``'sigma'``).
    dt_A, dt_sigma : float or None
        Distance thresholds forwarded to ``Cluster``.
    nclusters_A, nclusters_sigma : int or None
        Cluster counts forwarded to ``Cluster``.
    show : bool
        Whether to display the per-cluster figures.
    """
    ihmmr = final_parameters['ihmmr']

    # Collect the posterior-mean A and sigma of every found state, per trajectory
    A_blocks = []
    sigma_blocks = []
    for t in range(24):
        found_states = np.unique(ihmmr[t].z[0, :])
        converged = ihmmr[t].converged_params

        # should probably include a dimension for AR order
        a = np.zeros([2, 2, found_states.size])
        s = np.zeros([2, 2, found_states.size])
        for i in range(found_states.size):
            a[..., i] = converged['A'][:, 0, ..., i].mean(axis=0)
            s[..., i] = converged['sigma'][:, ..., i].mean(axis=0)

        A_blocks.append(a)
        sigma_blocks.append(s)

    A = np.concatenate(A_blocks, axis=-1)
    sigma = np.concatenate(sigma_blocks, axis=-1)

    print(A.shape)

    # default is diags
    # NOTE(review): `cluster_vars` and `algorithm` are free names resolved
    # outside this function (presumably module globals) — confirm.
    eigs = cluster_vars == 'eigs'
    diags = not eigs

    fig, ax = plt.subplots(1, 2)

    sig_cluster = Cluster({'sigma': sigma},
                          eigs=eigs,
                          diags=diags,
                          algorithm=algorithm,
                          distance_threshold=dt_sigma,
                          nclusters=nclusters_sigma)
    A_cluster = Cluster({'A': A},
                        eigs=eigs,
                        diags=diags,
                        algorithm=algorithm,
                        distance_threshold=dt_A,
                        nclusters=nclusters_A)

    sig_cluster.fit()
    A_cluster.fit()

    sigma_clusters = _mean_cluster_diagonals(sig_cluster.labels, sigma)
    A_clusters = _mean_cluster_diagonals(A_cluster.labels, A)

    ax[0].scatter(A_clusters[:, 0], A_clusters[:, 1])
    ax[1].scatter(sigma_clusters[:, 0], sigma_clusters[:, 1])

    nA_clusters = np.unique(A_cluster.labels).size
    nsig_clusters = np.unique(sig_cluster.labels).size
    print('Found %d sigma clusters' % nsig_clusters)
    print('Found %d A clusters' % nA_clusters)

    # Combined label: unique for every (A cluster, sigma cluster) pair
    nstates = A.shape[-1]
    new_clusters = (np.asarray(A_cluster.labels)[:nstates] * nsig_clusters +
                    np.asarray(sig_cluster.labels)[:nstates]).astype(float)

    print('Found %d total clusters' % np.unique(new_clusters).size)

    if show:
        plt.show()
        shift = 3

        for c in np.unique(new_clusters):

            print('Cluster %d' % c)

            fig, ax = plt.subplots(1, 2, figsize=(12, 7), sharey=True)
            fig2, ax2 = plt.subplots(1, 2, figsize=(12, 7), sharey=True)

            members = np.where(new_clusters == c)[0]

            print(len(members))

            Adiags = np.array([np.diag(A[..., m]) for m in members])
            sigdiags = np.array([np.diag(sigma[..., m]) for m in members])

            ax2[0].scatter(Adiags[:, 0], Adiags[:, 1])
            ax2[1].scatter(sigdiags[:, 0], sigdiags[:, 1])

            samples = np.random.choice(members, size=min(4, len(members)),
                                       replace=False)
            for i, n in enumerate(samples):
                print(np.diag(A[..., n]), np.diag(sigma[..., n]))

                parameters = {
                    'pi_init': [1],
                    'T': np.array([[1]]),
                    'mu': np.zeros(2),
                    'A': A[..., n][np.newaxis, np.newaxis, ...],
                    'sigma': sigma[..., n][np.newaxis, ...]
                }

                trajectory_generator = GenARData(params=parameters)
                trajectory_generator.gen_trajectory(1000,
                                                    1,
                                                    state_no=0,
                                                    progress=False)

                t = trajectory_generator.traj[:, 0, :]
                t -= t.mean(axis=0)  # centre each sample before offsetting

                ax[0].plot(t[:, 1] + i * shift, lw=2)
                ax[1].plot(t[:, 0] + i * shift, lw=2)

            ax[0].set_xlabel('Step Number', fontsize=14)
            ax[1].set_xlabel('Step Number', fontsize=14)
            ax[0].set_title('$z$ direction', fontsize=16)
            ax[1].set_title('$r$ direction', fontsize=16)
            ax[0].tick_params(labelsize=14)
            ax[1].tick_params(labelsize=14)

            # plt.tight_layout() only affects the current figure (fig2);
            # lay out both figures explicitly.
            fig.tight_layout()
            fig2.tight_layout()

            plt.show()
コード例 #4
0
def cluster_behavior(
        params,
        percent,
        n=2,
        figure_dir='/home/ben/github/LLC_Membranes/Ben_Manuscripts/hdphmm/figures'):
    """Plot trajectories, dwell times and A/Sigma diagonals of prevalent states.

    A clustered state is "prevalent" if it appears in more than ``percent``
    percent of the 24 trajectories in ``params['ihmmr']``. For every prevalent
    state, a 100-step trajectory is simulated and drawn; a dwell-time bar is
    drawn; and its A/Sigma diagonals are scattered. Three PDFs are written to
    ``figure_dir``.

    Parameters
    ----------
    params : dict
        Must contain ``'z'``, ``'ihmmr'`` (indexable 0..23, items exposing
        ``clustered_state_sequence``) and ``'mu'``.
    percent : float
        only show states that are in this percent of total trajectories
    n : int
        Index into ``params['ihmmr']`` used to build the colour map when
        ``res == 'MET'``.
    figure_dir : str
        Directory the PDF figures are saved into.
    """
    import os

    # NOTE(review): `res`, `final_parameters`, `dwell` and `mlines` are free
    # names resolved outside this function (module globals / imports) — confirm.
    z = params['z']
    ihmmr = params['ihmmr']
    mu = params['mu']

    if res == 'MET':
        clustered_sequence = ihmmr[n].clustered_state_sequence[0, :]
        nclusters = np.unique(clustered_sequence).size
    else:
        nclusters = np.unique(z).size

    # Count in how many trajectories each clustered state occurs
    ntrajectories = 24
    counts = {}
    for traj in range(ntrajectories):
        for u in np.unique(ihmmr[traj].clustered_state_sequence[0, :]):
            counts[u] = counts.get(u, 0) + 1

    nstates = max(counts) + 1
    # Labels absent from every trajectory count as zero (previously KeyError)
    state_counts = np.array([counts.get(i, 0) for i in range(nstates)])
    fraction = state_counts / ntrajectories

    prevelant_states = np.where(fraction > (percent / 100))[0]

    # For methanol
    cmap = plt.cm.jet
    if res == 'MET':
        prevelant_states = np.concatenate((prevelant_states, [14]))

        shown_colors = np.array(
            [cmap(i) for i in np.linspace(50, 225, nclusters).astype(int)])
        colors = np.array([
            cmap(i)
            for i in np.linspace(50, 225,
                                 clustered_sequence.max() + 1).astype(int)
        ])
        colors[np.unique(clustered_sequence)] = shown_colors
        colors[14] = colors[28]  # hack

    else:
        colors = np.array(
            [cmap(i) for i in np.linspace(50, 225, nclusters + 1).astype(int)])

    print('Prevelant States:', prevelant_states)

    shift = 0.75

    fig, ax = plt.subplots(1,
                           3,
                           figsize=(10, 10),
                           sharey=False,
                           gridspec_kw={'width_ratios': [1, 1, 0.15]})

    trajectory_generator = GenARData(params=final_parameters)

    A = final_parameters['A']
    sigma = final_parameters['sigma']
    T = final_parameters['T']

    fig1, Tax = plt.subplots()
    fig2, Aax = plt.subplots()

    if res == 'MET':
        new_order = [0, 2, 4, 5, 1, 3, 6]
        prevelant_states = prevelant_states[new_order]

    for i, s in enumerate(prevelant_states):

        Adiag = np.diag(A[s, 0, ...])
        sigdiag = np.diag(sigma[s, ...])

        print(Adiag, sigdiag)
        trajectory_generator.gen_trajectory(100, 1, state_no=s, progress=False)

        t = trajectory_generator.traj[:, 0, :]
        t -= t.mean(axis=0)  # centre before vertically offsetting

        Tax.bar(i, 0.5 * dwell(T[s, s]), color=colors[s], edgecolor='black')

        # dimension 0 (radial) as circles, dimension 1 (axial) as triangles
        Aax.scatter(sigdiag[0],
                    Adiag[0],
                    color=colors[s],
                    edgecolor='black',
                    s=100)
        Aax.scatter(sigdiag[1],
                    Adiag[1],
                    color=colors[s],
                    edgecolor='black',
                    s=100,
                    marker='^')

        ax[0].plot(t[:, 1] + i * shift, lw=2, color=colors[s])
        ax[1].plot(t[:, 0] + i * shift, lw=2, color=colors[s])
        ax[2].text(0,
                   i * shift,
                   '%.1f %%' % (100 * fraction[s]),
                   fontsize=16,
                   horizontalalignment='center')

    nshown = len(prevelant_states)
    tick_positions = [i * shift for i in range(nshown)]

    ax[0].set_yticks(tick_positions)
    ax[0].set_yticklabels(['%d' % (s + 1) for s in range(nshown)])

    ax[1].set_yticks(tick_positions)
    ax[1].set_yticklabels(['%.1f' % mu[p, 0] for p in prevelant_states])

    ax[0].set_xlabel('Step Number', fontsize=18)
    ax[0].set_ylabel('State Number', fontsize=18)
    ax[1].set_xlabel('Step Number', fontsize=18)
    ax[0].set_title('$z$ direction', fontsize=18)
    ax[1].set_title('$r$ direction', fontsize=18)
    ax[0].tick_params(labelsize=16)
    ax[1].tick_params(labelsize=16)
    ax[1].set_ylabel('Cluster radial mean', fontsize=18)
    ax[2].axis('off')
    ax[2].set_yticks(tick_positions)
    ax[2].set_title('Percentage\nPrevalence', fontsize=18)

    ax[2].set_xlim(0, 1)

    for axis in ax:
        axis.set_ylim(-shift, shift * nshown)

    # linestyle must be the string 'None': the object None means "use the
    # default line style", which would draw a line in the legend entry.
    circle = mlines.Line2D([], [],
                           color='white',
                           markeredgecolor='black',
                           marker='o',
                           linestyle='None',
                           label='radial dimension',
                           markersize=12)
    triangle = mlines.Line2D([], [],
                             color='white',
                             markeredgecolor='black',
                             marker='^',
                             linestyle='None',
                             label='axial dimension',
                             markersize=12)

    Tax.set_ylabel('Average dwell time (ns)', fontsize=18)

    Aax.set_ylabel('A diagonals', fontsize=18)
    Aax.set_xlabel(r'$\Sigma$ diagonals', fontsize=18)
    Aax.tick_params(labelsize=16)
    Tax.tick_params(labelsize=16)
    Aax.legend(handles=[circle, triangle], fontsize=16)
    Tax.set_xticks(list(range(nshown)))
    Tax.set_xticklabels([i + 1 for i in range(nshown)])
    Tax.set_xlabel('State Number', fontsize=16)

    fig1.tight_layout()
    fig2.tight_layout()
    fig.tight_layout()

    fig.savefig(os.path.join(figure_dir, 'common_states_%s.pdf' % res))
    fig1.savefig(os.path.join(figure_dir, 'dwell_times_%s.pdf' % res))
    fig2.savefig(os.path.join(figure_dir, 'A_sigma_scatter_%s.pdf' % res))

    plt.show()
コード例 #5
0
def view_clusters(final_parameters, shift=3, show_states='all',
                  min_members=5, bubble_scale=0.02):
    """Inspect state clusters and plot their A/Sigma diagonal centroids.

    For every cluster label in ``final_parameters['all_state_params']
    ['state_labels']``, the mean diagonals of A and sigma over its members
    are computed. For clusters whose enumeration index is in
    ``show_states``, up to four member states are simulated for 1000 steps,
    and their trajectories and diagonal scatter are shown. Finally, each
    cluster with more than ``min_members`` members is drawn as a bubble
    (area proportional to its member count) in A-diagonal and
    Sigma-diagonal space.

    Parameters
    ----------
    final_parameters : dict
        Must contain ``'all_state_params'`` with keys ``'A'``, ``'sigma'``
        (2x2 matrices stacked on the last axis) and ``'state_labels'``.
    shift : float
        Vertical offset between overlaid sample trajectories.
    show_states : 'all' or iterable of int
        Enumeration indices of the clusters to show in detail.
    min_members : int
        Only clusters with more than this many members get a bubble.
    bubble_scale : float
        Scale factor for the bubble radii.
    """
    all_params = final_parameters['all_state_params']
    A = all_params['A']
    sigma = all_params['sigma']
    new_clusters = all_params['state_labels']

    unique_clusters = np.unique(new_clusters)
    # [A/sig, 2 dimensions + mass, cluster]
    centroids = np.zeros([2, 3, unique_clusters.size])

    # isinstance guard: comparing an ndarray to 'all' is elementwise, and its
    # truth value is ambiguous.
    if isinstance(show_states, str) and show_states == 'all':
        show_states = np.arange(unique_clusters.size)

    for i, c in enumerate(unique_clusters):

        ndx = np.where(new_clusters == c)[0]

        Adiags = np.array([np.diag(A[..., a]) for a in ndx])
        sigdiags = np.array([np.diag(sigma[..., s]) for s in ndx])

        centroids[0, :, i] = np.concatenate((Adiags.mean(axis=0), [len(ndx)]))
        centroids[1, :, i] = np.concatenate(
            (sigdiags.mean(axis=0), [len(ndx)]))

        if i in show_states:

            print('Cluster %d' % c)
            print(len(ndx))

            fig, ax = plt.subplots(2, 2, figsize=(12, 7))

            ax[1, 0].scatter(Adiags[:, 0], Adiags[:, 1])
            ax[1, 1].scatter(sigdiags[:, 0], sigdiags[:, 1])

            for j, n in enumerate(
                    np.random.choice(ndx, size=min(4, len(ndx)),
                                     replace=False)):
                print(np.diag(A[..., n]), np.diag(sigma[..., n]))
                parameters = {
                    'pi_init': [1],
                    'T': np.array([[1]]),
                    'mu': np.zeros(2),
                    'A': A[..., n][np.newaxis, np.newaxis, ...],
                    'sigma': sigma[..., n][np.newaxis, ...]
                }

                trajectory_generator = GenARData(params=parameters)
                trajectory_generator.gen_trajectory(1000,
                                                    1,
                                                    state_no=0,
                                                    progress=False)

                t = trajectory_generator.traj[:, 0, :]
                t -= t.mean(axis=0)  # centre before vertically offsetting

                ax[0, 0].plot(t[:, 1] + j * shift, lw=2)
                ax[0, 1].plot(t[:, 0] + j * shift, lw=2)

            ax[0, 0].set_title('$z$ direction', fontsize=16)
            ax[0, 1].set_title('$r$ direction', fontsize=16)

            ax[0, 0].set_xlabel('Step Number', fontsize=14)
            ax[0, 0].set_ylabel('Shifted $z$ coordinate', fontsize=14)
            ax[0, 1].set_xlabel('Step Number', fontsize=14)
            ax[0, 1].set_ylabel('Shifted $r$ coordinate', fontsize=14)

            ax[1, 0].set_xlabel('A(0, 0)', fontsize=14)
            ax[1, 0].set_ylabel('A(1, 1)', fontsize=14)
            ax[1, 1].set_xlabel(r'$\Sigma$(0, 0)', fontsize=14)
            ax[1, 1].set_ylabel(r'$\Sigma$(1, 1)', fontsize=14)

            # style all four panels (the top row was previously done twice
            # and the bottom row never)
            for axis in ax.flat:
                axis.tick_params(labelsize=14)

            plt.tight_layout()

            plt.show()

    # plot centroids
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    cmap = plt.cm.jet
    colors = np.array([
        cmap(i)
        for i in np.random.choice(np.arange(cmap.N), size=centroids.shape[-1])
    ])
    for i in range(centroids.shape[-1]):

        if centroids[0, 2, i] > min_members:

            ax[0].scatter(centroids[0, 0, i],
                          centroids[0, 1, i],
                          color=colors[i])
            ax[1].scatter(centroids[1, 0, i],
                          centroids[1, 1, i],
                          color=colors[i])
            # Circle centre must be (x, y) only; row index 2 holds the mass
            ax[0].add_patch(
                Circle(centroids[0, :2, i],
                       np.sqrt(centroids[0, 2, i] / np.pi) * bubble_scale,
                       color=colors[i],
                       alpha=0.3,
                       lw=2))
            ax[1].add_patch(
                Circle(centroids[1, :2, i],
                       np.sqrt(centroids[1, 2, i] / np.pi) * bubble_scale *
                       0.4,
                       color=colors[i],
                       alpha=0.3,
                       lw=2))

    ax[0].set_xlabel('A(0, 0)', fontsize=14)
    ax[0].set_ylabel('A(1, 1)', fontsize=14)
    ax[1].set_xlabel(r'$\Sigma$(0, 0)', fontsize=14)
    ax[1].set_ylabel(r'$\Sigma$(1, 1)', fontsize=14)

    ax[0].set_xlim(-.2, 1)
    ax[0].set_ylim(-.2, 1)

    ax[1].set_xlim(-0.1, 0.3)
    ax[1].set_ylim(-0.1, 0.3)

    ax[0].set_aspect(1)
    ax[1].set_aspect(1)

    plt.tight_layout()
    plt.show()
コード例 #6
0
def ihmm(res, traj_no, ntraj, hyperparams, plot=False, niter=100):
    """Fit an HDP-HMM to one solute trajectory and simulate from the result.

    A 3-D AR(1) infinite HMM is first fit to the centre-of-mass trajectory
    ``traj_no`` of ``res``. A second model is then fit to the (radial, z)
    projection, seeded with the first model's state sequence. Per-state mean
    A, sigma and unconditional mean, plus an empirical transition matrix and
    initial distribution, are assembled. ``ntraj`` synthetic trajectories are
    then generated, with the same length as the stored MD MSD.

    Parameters
    ----------
    res : str
        Residue name; selects ``trajectories/com_xy_radial_<res>.pl`` and
        ``trajectories/<res>_msd.pl``.
    traj_no : int
        Index of the solute trajectory to fit.
    ntraj : int
        Number of synthetic trajectories to generate.
    hyperparams : dict or None
        Hyperparameters for the first (3-D) model.
    plot : bool
        Unused; kept for backwards compatibility.
    niter : int
        Number of inference iterations for each model.

    Returns
    -------
    GenARData
        Generator whose ``traj`` holds the synthetic trajectories.
    """
    print('Trajectory %d' % traj_no)
    difference = False  # take first order difference of solute trajectories
    observation_model = 'AR'  # assume an autoregressive model (that's the only model implemented)
    order = 1  # autoregressive order
    max_states = 100  # More is usually better
    dim = [0, 1, 2]  # dimensions of trajectory to keep
    prior = 'MNIW-N'  # MNIW-N (includes means) or MNIW (forces means to zero)
    keep_xy = True
    save_every = 1

    # You can define a dictionary with some spline parameters
    spline_params = {
        'npts_spline': 10,
        'save': True,
        'savename': 'spline_hdphmm.pl'
    }

    com_savename = 'trajectories/com_xy_radial_%s.pl' % res

    # center of mass trajectories. If it exists, we can skip loading the MD
    # trajectory and just load this
    com = 'trajectories/com_xy_radial_%s.pl' % res
    gro = 'berendsen.gro'

    model = hdphmm.InfiniteHMM(com,
                               traj_no=traj_no,
                               load_com=True,
                               difference=difference,
                               observation_model=observation_model,
                               order=order,
                               max_states=max_states,
                               dim=dim,
                               spline_params=spline_params,
                               prior=prior,
                               hyperparams=hyperparams,
                               keep_xy=keep_xy,
                               com_savename=com_savename,
                               gro=gro,
                               radial=True,
                               save_com=True,
                               save_every=save_every)

    model.inference(niter)
    model._get_params(quiet=True)

    # Project onto (radial distance in xy, z)
    radial = np.zeros([model.com.shape[0], 1, 2])
    radial[:, 0, 0] = np.linalg.norm(model.com[:, 0, :2], axis=1)
    radial[:, 0, 1] = model.com[:, 0, 2]

    ihmmr = hdphmm.InfiniteHMM((radial, model.dt),
                               traj_no=[0],
                               load_com=False,
                               difference=False,
                               order=1,
                               max_states=100,
                               dim=[0, 1],
                               spline_params=spline_params,
                               prior='MNIW-N',
                               hyperparams=None,
                               save_com=False,
                               state_sequence=model.z)

    ihmmr.inference(niter)
    ihmmr._get_params(traj_no=0)

    estimated_states = ihmmr.z[0, :]

    # For rare cases where a state is found exactly once, in one of the last
    # two frames: drop it. Build a new list instead of deleting while
    # iterating, which skipped the element after each deletion.
    found_states = []
    for f in np.unique(estimated_states):
        occurrences = np.where(estimated_states == f)[0]
        if len(occurrences) == 1 and occurrences[0] >= ihmmr.nT - 2:
            continue
        found_states.append(f)

    ihmmr.found_states = found_states
    nstates = len(found_states)

    # should probably include a dimension for AR order
    A = np.zeros([nstates, 1, 2, 2])
    sigma = np.zeros([nstates, 2, 2])
    mu = np.zeros([nstates, 2])

    # NOTE(review): converged_params is indexed by position i in the filtered
    # list, as before; confirm this matches converged_params' state ordering.
    for i in range(nstates):

        A[i, 0, ...] = ihmmr.converged_params['A'][:, 0, ..., i].mean(axis=0)
        sigma[i, ...] = ihmmr.converged_params['sigma'][:, ..., i].mean(axis=0)

        # we want to cluster on unconditional mean: (I - A)^-1 mu_cond
        mucond = ihmmr.converged_params['mu'][..., i].mean(axis=0)
        mu[i, :] = np.linalg.solve(np.eye(2) - A[i, 0, ...], mucond)

    ndx_dict = {state: i for i, state in enumerate(found_states)}

    count_matrix = np.zeros([nstates, nstates])

    # start at frame 1. May need to truncate more as equilibration
    for frame in range(1, ihmmr.nT - 1):
        try:
            transitioned_from = [ndx_dict[i] for i in ihmmr.z[:, frame - 1]]
            transitioned_to = [ndx_dict[i] for i in ihmmr.z[:, frame]]
        except KeyError:
            # a dropped state is involved in this transition; skip the frame
            continue
        for src, dst in zip(transitioned_from, transitioned_to):
            count_matrix[src, dst] += 1

    # The following is very similar to ihmm3.pi_z. The difference is due to
    # the dirichlet process.
    transition_matrix = (count_matrix.T / count_matrix.sum(axis=1)).T

    # Initial distribution of states
    init_state = ihmmr.z[:, 0]
    pi_init = np.array(
        [np.count_nonzero(init_state == c) for c in found_states], dtype=float)
    pi_init /= pi_init.sum()

    final_parameters = {
        'A': A,
        'sigma': sigma,
        'mu': mu,
        'T': transition_matrix,
        'pi_init': pi_init
    }

    # Match the synthetic trajectory length to the MD MSD
    MD_MSD = file_rw.load_object('trajectories/%s_msd.pl' % res)
    nsteps = MD_MSD.MSD_average.shape[0]

    trajectory_generator = GenARData(params=final_parameters)
    trajectory_generator.gen_trajectory(nsteps, ntraj, bound_dimensions=[0])

    return trajectory_generator