Exemple #1
0
def test_muller_potential_stats():
    """Compare MSLDS E-step statistics with a reference HMM E-step.

    A ``GaussianHMM`` is fitted to a dataset generated by ``MullerModel``.
    Its parameters are copied into a ``MetastableSwitchingLDS`` whose ``As_``
    matrices are all zero, and every sufficient statistic returned by
    ``model.inferrer.do_estep()`` is compared against the matching entry
    from ``reference_estep``.

    Yields
    ------
    callable
        One zero-argument check per sufficient statistic (yield-style test).
    """
    # Set constants
    n_seq = 1
    num_trajs = 1
    T = 2500
    num_hotstart = 0

    # Generate data
    warnings.filterwarnings("ignore", category=DeprecationWarning)
    muller = MullerModel()
    data, _trajectory, _start = muller.generate_dataset(n_seq, num_trajs, T)
    n_features = muller.x_dim
    n_components = muller.K

    # Fit reference model and initial MSLDS model
    refmodel = GaussianHMM(n_components=n_components,
                           covariance_type='full').fit(data)
    model = MetastableSwitchingLDS(n_components, n_features,
                                   n_hotstart=num_hotstart)
    model.inferrer._sequences = data
    model.means_ = refmodel.means_
    model.covars_ = refmodel.covars_
    model.transmat_ = refmodel.transmat_
    model.populations_ = refmodel.startprob_
    # One all-zero matrix per component.
    model.As_ = [np.zeros((n_features, n_features))
                 for _ in range(n_components)]
    model.Qs_ = refmodel.covars_
    model.bs_ = refmodel.means_

    _, stats = model.inferrer.do_estep()
    _, rstats = reference_estep(refmodel, data)

    def make_check(key, decimal):
        # Bind key/decimal per check (avoids late-binding closures); the
        # dictionary lookups stay inside the lambda so a missing statistic
        # only fails its own check.
        return lambda: np.testing.assert_array_almost_equal(
            stats[key], rstats[key], decimal=decimal)

    # (statistic name, required decimal agreement)
    checks = [
        ('post', 3),
        ('post[1:]', 3),
        ('post[:-1]', 3),
        ('obs', 3),
        ('obs[1:]', 3),
        ('obs[:-1]', 3),
        ('obs*obs.T', 3),
        ('obs*obs[t-1].T', 3),
        ('obs[1:]*obs[1:].T', 3),
        ('obs[:-1]*obs[:-1].T', 3),
        ('trans', 1),
    ]
    for key, decimal in checks:
        yield make_check(key, decimal)
Exemple #2
0
    def __init__(self):
        """Set up a two-state, one-dimensional model and its solver.

        Stores the hard-coded parameter arrays on the instance and copies
        them onto a ``MetastableSwitchingLDS`` kept as ``self._model``.
        """
        self.K = 2
        self.x_dim = 1
        n_states, dim = self.K, self.x_dim

        # Per-state parameters, reshaped so the leading axis indexes states.
        self.As = np.array([[0.6], [0.6]]).reshape((n_states, dim, dim))
        self.bs = np.array([[0.4], [-0.4]]).reshape((n_states, dim))
        self.Qs = np.array([[0.01], [0.01]]).reshape((n_states, dim, dim))
        self.Z = np.array([[0.995, 0.005],
                           [0.005, 0.995]]).reshape((n_states, n_states))
        self.pi = np.array([0.99, 0.01]).reshape((n_states,))
        self.mus = np.array([[1], [-1]]).reshape((n_states, dim))
        self.Sigmas = np.array([[0.01], [0.01]]).reshape(
            (n_states, dim, dim))

        # Generate Solver
        solver = MetastableSwitchingLDS(n_states, dim)
        solver.As_ = self.As
        solver.bs_ = self.bs
        solver.Qs_ = self.Qs
        solver.transmat_ = self.Z
        solver.populations_ = self.pi
        solver.means_ = self.mus
        solver.covars_ = self.Sigmas
        self._model = solver
Exemple #3
0
def test_plusmin_stats():
    """Compare MSLDS E-step statistics with a reference HMM E-step.

    A ``GaussianHMM`` is fitted to a dataset generated by ``PlusminModel``.
    Its parameters are copied into a ``MetastableSwitchingLDS`` whose ``As_``
    matrices are all zero, and every sufficient statistic returned by
    ``model.inferrer.do_estep()`` is compared against the matching entry
    from ``reference_estep``.

    Yields
    ------
    callable
        One zero-argument check per sufficient statistic (yield-style test).
    """
    # Set constants
    # The model is built with no hot-start iterations; the constant is now
    # actually passed to the constructor instead of a separate literal 0.
    num_hotstart = 0
    n_seq = 1
    T = 2000

    # Install the filter before fitting so it also covers the reference fit.
    warnings.filterwarnings("ignore", category=DeprecationWarning)

    # Generate data
    plusmin = PlusminModel()
    data, _hidden = plusmin.generate_dataset(n_seq, T)
    n_features = plusmin.x_dim
    n_components = plusmin.K

    # Fit reference model
    refmodel = GaussianHMM(n_components=n_components,
                           covariance_type='full').fit(data)

    # Fit initial MSLDS model from reference model
    model = MetastableSwitchingLDS(n_components, n_features,
                                   n_hotstart=num_hotstart)
    model.inferrer._sequences = data
    model.means_ = refmodel.means_
    model.covars_ = refmodel.covars_
    model.transmat_ = refmodel.transmat_
    model.populations_ = refmodel.startprob_
    # One all-zero matrix per component (previously hard-coded to two).
    model.As_ = [np.zeros((n_features, n_features))
                 for _ in range(n_components)]
    model.Qs_ = refmodel.covars_
    model.bs_ = refmodel.means_

    _, stats = model.inferrer.do_estep()
    _, rstats = reference_estep(refmodel, data)

    def make_check(key, decimal):
        # Bind key/decimal per check (avoids late-binding closures); the
        # dictionary lookups stay inside the lambda so a missing statistic
        # only fails its own check.
        return lambda: np.testing.assert_array_almost_equal(
            stats[key], rstats[key], decimal=decimal)

    # (statistic name, required decimal agreement)
    checks = [
        ('post', 3),
        ('post[1:]', 3),
        ('post[:-1]', 3),
        ('obs', 3),
        ('obs[1:]', 3),
        ('obs[:-1]', 3),
        ('obs*obs.T', 3),
        ('obs*obs[t-1].T', 3),
        ('obs[1:]*obs[1:].T', 3),
        ('obs[:-1]*obs[:-1].T', 3),
        ('trans', 1),
    ]
    for key, decimal in checks:
        yield make_check(key, decimal)
Exemple #4
0
def test_randn_stats():
    """
    Sanity test MSLDS sufficient statistic gathering by setting
    dynamics model to 0 and testing that E-step matches that of
    HMM

    The data are two sequences of 100 standard-normal samples with three
    features each. Yields one zero-argument check per sufficient statistic.
    """
    warnings.filterwarnings("ignore", category=DeprecationWarning)
    # Generate reference data
    n_states = 2
    n_features = 3
    data = [np.random.randn(100, n_features),
            np.random.randn(100, n_features)]
    refmodel = GaussianHMM(n_components=n_states,
                           covariance_type='full').fit(data)

    # test all of the sufficient statistics against sklearn and pure python

    model = MetastableSwitchingLDS(n_states=n_states,
                                   n_features=n_features, n_hotstart=0)
    model.inferrer._sequences = data
    model.means_ = refmodel.means_
    model.covars_ = refmodel.covars_
    model.transmat_ = refmodel.transmat_
    model.populations_ = refmodel.startprob_
    # One all-zero matrix per state, so the list follows n_states.
    model.As_ = [np.zeros((n_features, n_features))
                 for _ in range(n_states)]
    model.Qs_ = refmodel.covars_
    model.bs_ = refmodel.means_

    _, stats = model.inferrer.do_estep()
    _, rstats = reference_estep(refmodel, data)

    def make_check(key):
        # Bind the key per check (avoids late-binding closures); lookups
        # stay inside the lambda so a missing statistic only fails its own
        # check.
        return lambda: np.testing.assert_array_almost_equal(
            stats[key], rstats[key], decimal=3)

    stat_keys = [
        'post',
        'post[1:]',
        'post[:-1]',
        'obs',
        'obs[1:]',
        'obs[:-1]',
        'obs*obs.T',
        'obs*obs[t-1].T',
        'obs[1:]*obs[1:].T',
        'obs[:-1]*obs[:-1].T',
        'trans',
    ]
    for key in stat_keys:
        yield make_check(key)
Exemple #5
0
    def fit(self, train, n_states, train_lag_time, fold, args, outfile):
        """Fit an MSLDS model on ``train`` and write one JSON record.

        The model is configured from ``n_states``, ``self.n_features`` and
        the option values on ``args``. After fitting, the learned parameters,
        the wall-clock training time and the log-probability history are
        dumped as a single JSON object followed by a newline to ``outfile``.
        ``fold`` is accepted but not used here.
        """
        kwargs = {
            'n_states': n_states,
            'n_features': self.n_features,
            'n_init': args.n_init,
            'max_iters': args.max_iters,
            'n_em_iter': args.n_em_iter,
            'n_hotstart': args.n_hotstart,
            'reversible_type': args.reversible_type,
            'platform': args.platform,
            'display_solver_output': args.display_solver_output,
        }
        print(kwargs)
        model = MetastableSwitchingLDS(**kwargs)

        # Time only the fitting call itself.
        t_begin = time.time()
        model.fit(train)
        t_finish = time.time()

        result = {}
        result['model'] = 'MetastableSwitchingLinearDynamicalSystem'
        result['transmat'] = model.transmat_.tolist()
        result['populations'] = model.populations_.tolist()
        result['n_states'] = model.n_states
        result['split'] = args.split
        result['train_lag_time'] = train_lag_time
        result['train_time'] = t_finish - t_begin
        result['n_features'] = model.n_features
        result['means'] = model.means_.tolist()
        result['covars'] = model.covars_.tolist()
        result['As'] = model.As_.tolist()
        result['bs'] = model.bs_.tolist()
        result['Qs'] = model.Qs_.tolist()
        result['train_logprob'] = model.fit_logprob_[-1]
        result['n_train_observations'] = sum(len(t) for t in train)
        result['train_logprobs'] = model.fit_logprob_
        # No test-set score is recorded; the test lag time simply mirrors
        # the training lag time.
        result['test_lag_time'] = train_lag_time

        transmat_is_finite = np.all(np.isfinite(model.transmat_))
        if not transmat_is_finite:
            print('Nonfinite numbers in transmat !!')

        json.dump(result, outfile)
        outfile.write('\n')
Exemple #6
0
def test_plusmin():
    """Fit an MSLDS and a Gaussian HMM to PlusminModel data and plot a sample.

    Prints the log-likelihood of both fitted models on the training data,
    then plots the observations next to a trajectory sampled from the fitted
    MSLDS. Ends with ``plt.show()``.
    """
    # Set constants
    n_hotstart = 3
    n_em_iter = 3
    n_experiments = 1
    n_seq = 1
    T = 2000
    gamma = 512.

    # Generate data
    plusmin = PlusminModel()
    data, _hidden = plusmin.generate_dataset(n_seq, T)
    n_features = plusmin.x_dim
    n_components = plusmin.K

    # Train MSLDS
    mslds = MetastableSwitchingLDS(n_components, n_features,
                                   n_hotstart=n_hotstart,
                                   n_em_iter=n_em_iter,
                                   n_experiments=n_experiments)
    mslds.fit(data, gamma=gamma)
    mslds_score = mslds.score(data)
    print("gamma = %f" % gamma)
    print("MSLDS Log-Likelihood = %f" %  mslds_score)
    print()

    # Fit Gaussian HMM for comparison
    hmm = GaussianFusionHMM(n_components, n_features)
    hmm.fit(data)
    hmm_score = hmm.score(data)
    print("HMM Log-Likelihood = %f" %  hmm_score)
    print()

    # Plot sample from MSLDS
    sim_xs, _sim_states = mslds.sample(T, init_state=0,
                                       init_obs=plusmin.mus[0])
    sim_xs = np.reshape(sim_xs, (n_seq, T, n_features))
    plt.close('all')
    plt.figure(1)
    plt.plot(range(T), data[0], label="Observations")
    plt.plot(range(T), sim_xs[0], label='Sampled Observations')
    plt.legend()
    plt.show()
Exemple #7
0
def test_doublewell():
    """Fit an MSLDS and a Gaussian HMM to the double-well data and plot.

    Prints the log-likelihood of both fitted models on the trajectories and
    plots the first trajectory next to a sample drawn from the fitted MSLDS.
    On failure the traceback is printed, a post-mortem debugger is opened,
    and the exception is then re-raised so the test is reported as failed.
    """
    import pdb
    import sys
    import traceback
    try:
        n_components = 2
        n_features = 1
        n_em_iter = 1
        n_experiments = 1
        tol = 1e-1

        data = load_doublewell(random_state=0)['trajectories']
        T = len(data[0])

        # Fit MSLDS model
        model = MetastableSwitchingLDS(n_components, n_features,
                                       n_experiments=n_experiments,
                                       n_em_iter=n_em_iter)
        model.fit(data, gamma=.1, tol=tol)
        mslds_score = model.score(data)
        print("MSLDS Log-Likelihood = %f" %  mslds_score)

        # Fit Gaussian HMM for comparison
        g = GaussianFusionHMM(n_components, n_features)
        g.fit(data)
        hmm_score = g.score(data)
        print("HMM Log-Likelihood = %f" %  hmm_score)
        print()

        # Plot sample from MSLDS
        sim_xs, _sim_states = model.sample(T, init_state=0)
        plt.close('all')
        plt.figure(1)
        plt.plot(range(T), data[0], label="Observations")
        plt.plot(range(T), sim_xs, label='Sampled Observations')
        plt.legend()
        plt.show()
    except Exception:
        # Only catch real errors (not KeyboardInterrupt/SystemExit), keep the
        # debugging aid, and re-raise: swallowing the exception made this
        # test pass no matter what went wrong.
        traceback.print_exc()
        pdb.post_mortem(sys.exc_info()[2])
        raise
Exemple #8
0
def test_sufficient_statistics():
    """Yield one check per E-step sufficient statistic.

    Copies the parameters of the module-level ``refmodel`` into a
    ``MetastableSwitchingLDS`` and compares the statistics returned by
    ``model._impl.do_estep()`` with those from ``_sklearn_estep()``, each to
    three decimals.
    """
    # test all of the sufficient statistics against sklearn and pure python
    model = MetastableSwitchingLDS(n_states=N_STATES,
                                   n_features=refmodel.n_features)
    model._impl._sequences = data
    model.means_ = refmodel.means_
    model.covars_ = refmodel.covars_
    model.transmat_ = refmodel.transmat_
    model.populations_ = refmodel.startprob_

    _, stats = model._impl.do_estep()
    _, rstats = _sklearn_estep()

    def check_for(name):
        # The lookups happen when the check runs, not when it is created.
        return lambda: np.testing.assert_array_almost_equal(
            stats[name], rstats[name], decimal=3)

    stat_names = (
        'post',
        'post[1:]',
        'post[:-1]',
        'obs',
        'obs[1:]',
        'obs[:-1]',
        'obs*obs.T',
        'obs*obs[t-1].T',
        'obs[1:]*obs[1:].T',
        'obs[:-1]*obs[:-1].T',
        'trans',
    )
    for name in stat_names:
        yield check_for(name)
Exemple #9
0
def test_sufficient_statistics():
    """Check the MSLDS E-step statistics against the sklearn reference.

    The module-level ``refmodel`` supplies the means, covariances, transition
    matrix and start probabilities; the module-level ``data`` supplies the
    sequences. One zero-argument check is yielded per statistic, each
    requiring agreement to three decimals.
    """
    # test all of the sufficient statistics against sklearn and pure python
    model = MetastableSwitchingLDS(
        n_states=N_STATES, n_features=refmodel.n_features)
    model._impl._sequences = data
    model.means_ = refmodel.means_
    model.covars_ = refmodel.covars_
    model.transmat_ = refmodel.transmat_
    model.populations_ = refmodel.startprob_

    estep_result = model._impl.do_estep()
    reference_result = _sklearn_estep()
    stats = estep_result[1]
    rstats = reference_result[1]

    def build_check(stat_key):
        # Defer the dictionary lookups until the check is actually invoked.
        def run_check():
            return np.testing.assert_array_almost_equal(
                stats[stat_key], rstats[stat_key], decimal=3)
        return run_check

    yield build_check('post')
    yield build_check('post[1:]')
    yield build_check('post[:-1]')
    yield build_check('obs')
    yield build_check('obs[1:]')
    yield build_check('obs[:-1]')
    yield build_check('obs*obs.T')
    yield build_check('obs*obs[t-1].T')
    yield build_check('obs[1:]*obs[1:].T')
    yield build_check('obs[:-1]*obs[:-1].T')
    yield build_check('trans')
Exemple #10
0
def test_alanine_dipeptide_stats():
    """Compare MSLDS E-step statistics with a reference HMM on alanine data.

    Only the first alanine dipeptide trajectory is used. It is superposed on
    the ``ala2.pdb`` topology and flattened to one row per frame; a
    ``GaussianHMM`` is fitted to it and its parameters are copied into a
    ``MetastableSwitchingLDS`` with all-zero ``As_``. Each sufficient
    statistic from ``model.inferrer.do_estep()`` is then compared with
    ``reference_estep``.

    Yields one zero-argument check per statistic. The yielded checks run in
    the caller, so the ``try`` block here only covers the setup code. Setup
    errors print a traceback, open a post-mortem debugger and are re-raised.
    """
    import pdb
    import sys
    import traceback
    warnings.filterwarnings("ignore", category=DeprecationWarning)
    try:
        b = fetch_alanine_dipeptide()
        trajs = b.trajectories
        # While debugging, restrict to first trajectory only
        trajs = [trajs[0]]
        n_frames = trajs[0].n_frames
        n_atoms = trajs[0].n_atoms
        n_features = n_atoms * 3

        data_home = get_data_home()
        data_dir = join(data_home, TARGET_DIRECTORY_ALANINE)
        top = md.load(join(data_dir, 'ala2.pdb'))
        n_components = 2
        # Superpose each trajectory on the topology, then flatten frames
        data = []
        for traj in trajs:
            traj.superpose(top)
            Z = np.reshape(traj.xyz, (n_frames, n_features), order='F')
            data.append(Z)

        # NOTE(review): the other *_stats tests build the model with
        # n_hotstart=0; this one uses 3 -- possibly related to the failing
        # 'trans' check below. Confirm before changing.
        n_hotstart = 3
        # Fit reference model and initial MSLDS model
        refmodel = GaussianHMM(n_components=n_components,
                               covariance_type='full').fit(data)
        _, rstats = reference_estep(refmodel, data)

        model = MetastableSwitchingLDS(n_components, n_features,
                                       n_hotstart=n_hotstart)
        model.inferrer._sequences = data
        model.means_ = refmodel.means_
        model.covars_ = refmodel.covars_
        model.transmat_ = refmodel.transmat_
        model.populations_ = refmodel.startprob_
        model.As_ = [np.zeros((n_features, n_features))
                     for _ in range(n_components)]
        # Add a small multiple of the identity to each reference covariance.
        eps = 1e-7
        model.Qs_ = [refmodel.covars_[i] + eps * np.eye(n_features)
                     for i in range(n_components)]
        model.bs_ = refmodel.means_
        _, stats = model.inferrer.do_estep()

        def make_check(key, decimal):
            # Bind key/decimal per check; lookups stay lazy inside the
            # lambda so a missing statistic only fails its own check.
            return lambda: np.testing.assert_array_almost_equal(
                stats[key], rstats[key], decimal=decimal)

        # (statistic name, required decimal agreement)
        checks = [
            ('post', 2),
            ('post[1:]', 2),
            ('post[:-1]', 2),
            ('obs', 1),
            ('obs[1:]', 1),
            ('obs[:-1]', 1),
            ('obs*obs.T', 1),
            ('obs*obs[t-1].T', 1),
            ('obs[1:]*obs[1:].T', 1),
            ('obs[:-1]*obs[:-1].T', 1),
            # This test fails consistently. TODO: Figure out why.
            # ('trans', 2),
        ]
        for key, decimal in checks:
            yield make_check(key, decimal)

    except Exception:
        # A bare ``except:`` here would also trap GeneratorExit when the
        # generator is closed early, and would hide setup failures. Catch
        # real errors only, keep the debugging aid, then re-raise.
        traceback.print_exc()
        pdb.post_mortem(sys.exc_info()[2])
        raise
Exemple #11
0
def test_muller_potential():
    """Fit an MSLDS and a Gaussian HMM to Muller-potential data and plot.

    Prints the log-likelihood of both fitted models, then draws the
    generating trajectory, the observations, the fitted MSLDS means and a
    trajectory sampled from the MSLDS on top of ``MullerForce.plot``.
    On failure the traceback is printed, a post-mortem debugger is opened,
    and the exception is then re-raised so the test is reported as failed.
    """
    import pdb
    import sys
    import traceback
    try:
        # Set constants
        n_hotstart = 3
        n_em_iter = 3
        n_experiments = 1
        n_seq = 1
        num_trajs = 1
        T = 2500
        sim_T = 2500
        gamma = 200.

        # Generate data
        warnings.filterwarnings("ignore", category=DeprecationWarning)
        muller = MullerModel()
        data, trajectory, start = \
                muller.generate_dataset(n_seq, num_trajs, T)
        n_features = muller.x_dim
        n_components = muller.K

        # Train MSLDS
        model = MetastableSwitchingLDS(n_components, n_features,
            n_hotstart=n_hotstart, n_em_iter=n_em_iter,
            n_experiments=n_experiments)
        model.fit(data, gamma=gamma)
        mslds_score = model.score(data)
        print("MSLDS Log-Likelihood = %f" %  mslds_score)

        # Fit Gaussian HMM for comparison
        g = GaussianFusionHMM(n_components, n_features)
        g.fit(data)
        hmm_score = g.score(data)
        print("HMM Log-Likelihood = %f" %  hmm_score)

        # Clear Display
        plt.cla()
        plt.plot(trajectory[start:, 0], trajectory[start:, 1], color='k')
        plt.scatter(model.means_[:, 0], model.means_[:, 1],
                    color='r', zorder=10)
        plt.scatter(data[0][:, 0], data[0][:, 1],
                edgecolor='none', facecolor='k', zorder=1)

        sim_xs, _sim_states = model.sample(sim_T, init_state=0,
                init_obs=model.means_[0])

        # Plot window: bounding box of observed and sampled points, padded
        # by Delta on every side.
        Delta = 0.5
        minx = min(min(sim_xs[:, 0]), min(data[0][:, 0])) - Delta
        maxx = max(max(sim_xs[:, 0]), max(data[0][:, 0])) + Delta
        miny = min(min(sim_xs[:, 1]), min(data[0][:, 1])) - Delta
        maxy = max(max(sim_xs[:, 1]), max(data[0][:, 1])) + Delta
        plt.scatter(sim_xs[:, 0], sim_xs[:, 1], edgecolor='none',
                   zorder=5, facecolor='g')
        plt.plot(sim_xs[:, 0], sim_xs[:, 1], zorder=5, color='g')

        MullerForce.plot(ax=plt.gca(), minx=minx, maxx=maxx,
                miny=miny, maxy=maxy)
        plt.show()
    except Exception:
        # Only catch real errors (not KeyboardInterrupt/SystemExit), keep the
        # debugging aid, and re-raise: swallowing the exception made this
        # test pass no matter what went wrong.
        traceback.print_exc()
        pdb.post_mortem(sys.exc_info()[2])
        raise
Exemple #12
0
def test_alanine_dipeptide():
    """Fit an MSLDS to alanine dipeptide and write a generated trajectory.

    Steps:
      1. Superpose every trajectory on the ``ala2.pdb`` topology and flatten
         each frame to a feature row.
      2. Fit a ``MetastableSwitchingLDS`` and a ``GaussianFusionHMM`` and
         print their log-likelihoods.
      3. Sample ``sim_T`` steps from the MSLDS.
      4. Assign each original frame to the state under which it has the
         highest log density.
      5. For every sampled step, pick the original frame from the sampled
         hidden state that scores highest around the sampled observation,
         and join the picked frames into one trajectory.
      6. Save that trajectory as ``<out>.xtc`` and its first frame as
         ``<out>.xtc.pdb``.

    On failure the traceback is printed, a post-mortem debugger is opened,
    and the exception is then re-raised so the test is reported as failed.
    """
    import pdb
    import sys
    import traceback
    warnings.filterwarnings("ignore",
                    category=DeprecationWarning)
    try:
        b = fetch_alanine_dipeptide()
        trajs = b.trajectories
        n_atoms = trajs[0].n_atoms
        n_features = n_atoms * 3
        sim_T = 1000
        # Output file prefix. The original code read ``self.out``, which does
        # not exist in a module-level function.
        # NOTE(review): prefix chosen here as a placeholder -- confirm.
        out = 'alanine_dipeptide_mslds'
        data_home = get_data_home()
        data_dir = join(data_home, TARGET_DIRECTORY_ALANINE)
        top = md.load(join(data_dir, 'ala2.pdb'))
        n_components = 2
        # Superpose each trajectory on the topology, then flatten frames
        data = []
        for traj in trajs:
            traj.superpose(top)
            Z = traj.xyz
            Z = np.reshape(Z, (len(Z), n_features), order='F')
            data.append(Z)

        # Fit MSLDS model
        n_experiments = 1
        n_em_iter = 1
        tol = 1e-1
        model = MetastableSwitchingLDS(n_components,
            n_features, n_experiments=n_experiments,
            n_em_iter=n_em_iter)
        model.fit(data, gamma=.1, tol=tol, verbose=True)
        mslds_score = model.score(data)
        print("MSLDS Log-Likelihood = %f" %  mslds_score)

        # Fit Gaussian HMM for comparison
        g = GaussianFusionHMM(n_components, n_features)
        g.fit(data)
        hmm_score = g.score(data)
        print("HMM Log-Likelihood = %f" %  hmm_score)
        print()

        # Generate a trajectory from learned model.
        sample_traj, hidden_states = model.sample(sim_T)

        # Presort the data into the metastable wells: states[k][i] holds the
        # frames of trajectory i assigned to state k. The density evaluation
        # does not depend on k, so it is done once per trajectory.
        states = [[] for _ in range(n_components)]
        for traj in trajs:
            Z = traj.xyz
            Z = np.reshape(Z, (len(Z), n_features), order='F')
            logprob = log_multivariate_normal_density(Z,
                np.array(model.means_),
                np.array(model.covars_), covariance_type='full')
            # pick structures that have highest log probability in state
            assignments = np.argmax(logprob, axis=1)
            for k in range(n_components):
                states[k].append(traj[assignments == k])

        # Pick frame from original trajectories closest to current sample
        gen_traj = None
        for t in range(sim_T):
            h = hidden_states[t]
            best_logprob = -np.inf
            best_frame = None
            for candidates in states[h]:
                if t > 0:
                    # Align the candidates on the previously chosen frame.
                    candidates.superpose(gen_traj, t-1)
                Z = candidates.xyz
                Z = np.reshape(Z, (len(Z), n_features), order='F')
                mean = sample_traj[t]
                # NOTE(review): a single mean and a single covariance are
                # passed here, unlike the stacked arrays used in the presort
                # step -- confirm the density helper accepts these shapes.
                logprobs = log_multivariate_normal_density(Z,
                    mean, model.Qs_[h], covariance_type='full')
                ind = np.argmax(logprobs, axis=0)
                logprob = logprobs[ind]
                # Fixed: compared against an undefined name and assigned in
                # the wrong direction, so the running best never updated.
                if logprob > best_logprob:
                    best_logprob = logprob
                    best_frame = candidates[ind]
            if t == 0:
                gen_traj = best_frame
            else:
                # Fixed: joined the undefined name ``frame``.
                gen_traj = gen_traj.join(best_frame)
        gen_traj.save('%s.xtc' % out)
        gen_traj[0].save('%s.xtc.pdb' % out)
    except Exception:
        # Only catch real errors (not KeyboardInterrupt/SystemExit), keep the
        # debugging aid, and re-raise: swallowing the exception hid the
        # NameErrors in this function and made the test always pass.
        traceback.print_exc()
        pdb.post_mortem(sys.exc_info()[2])
        raise