def test_largest_connected_set(self):
    """Check the largest connected set for both directed and undirected connectivity."""
    # Directed connectivity (the default).
    directed_lcc = largest_connected_set(self.C)
    self.assertTrue(np.all(self.lcc_directed == np.sort(directed_lcc)))
    # Undirected connectivity.
    undirected_lcc = largest_connected_set(self.C, directed=False)
    self.assertTrue(np.all(self.lcc_undirected == np.sort(undirected_lcc)))
def setUp(self):
    """Build a deterministic birth-death chain trajectory and a reference MSM."""
    # Remember the rng state so it can be restored after the test.
    self.state = np.random.mtrand.get_state()
    # Fixed seed -> reproducible simulated trajectory.
    np.random.mtrand.seed(42)
    # Meta-stable birth-death chain on 7 states; eps controls the
    # small transition probabilities at states 2 and 4.
    b = 2
    eps = 10 ** (-b)
    q = np.zeros(7)
    p = np.zeros(7)
    q[1:] = 0.5
    p[0:-1] = 0.5
    q[2] = 1.0 - eps
    q[4] = eps
    p[2] = eps
    p[4] = 1.0 - eps
    bdc = BirthDeathChain(q, p)
    self.dtraj = bdc.msm.simulate(10000, start=0)
    self.tau = 1
    # Estimate the reference MSM on the largest connected set.
    self.C_MSM = count_matrix(self.dtraj, self.tau, sliding=True)
    self.lcc_MSM = largest_connected_set(self.C_MSM)
    self.Ccc_MSM = largest_connected_submatrix(self.C_MSM, lcc=self.lcc_MSM)
    self.mle_rev_max_err = 1E-8
    self.P_MSM = transition_matrix(self.Ccc_MSM, reversible=True, maxerr=self.mle_rev_max_err)
    self.mu_MSM = stationary_distribution(self.P_MSM)
    self.k = 3
    self.ts = timescales(self.P_MSM, k=self.k, tau=self.tau)
def _prepare_input_revpi(self, C, pi):
    """Return the active set for reversible estimation with a fixed stationary vector.

    Intersects the states visited by the trajectories (rows of C) with the
    states of positive stationary probability in pi, then restricts to the
    largest connected component (undirected) of the resulting count matrix.
    """
    n_visited = C.shape[0]  # max. state index visited by trajectories
    n_pi = pi.shape[0]  # max. state index of the stationary vector array
    # pi has to be defined on all states visited by the trajectories
    if n_visited > n_pi:
        raise ValueError(
            'There are visited states for which no stationary probability is given'
        )
    # Restrict pi to the visited set and keep the states with
    # positive stationary probability.
    positive = _np.where(pi[0:n_visited] > 0.0)[0]
    # Reduce C to the positive-probability states.
    C_positive = largest_connected_submatrix(C, lcc=positive)
    if C_positive.sum() == 0.0:
        errstr = """The set of states with positive stationary probabilities is not visited by the trajectories. A MSM reversible with respect to the given stationary vector can not be estimated"""
        raise ValueError(errstr)
    # Largest connected set of C_positive, undirected connectivity.
    component = largest_connected_set(C_positive, directed=False)
    return positive[component]
def equilibrium_transition_matrix(Xi, omega, sigma, reversible=True, return_lcc=True):
    """Compute the equilibrium transition matrix from OOM components.

    Parameters
    ----------
    Xi : ndarray(M, N, M)
        matrix of set-observable operators
    omega : ndarray(M,)
        information state vector of OOM
    sigma : ndarray(M,)
        evaluator of OOM
    reversible : bool, optional, default=True
        symmetrize the corrected count matrix in order to obtain a
        reversible transition matrix.
    return_lcc : bool, optional, default=True
        return indices of largest connected set.

    Returns
    -------
    Tt_Eq : ndarray(N, N)
        equilibrium transition matrix
    lcc : ndarray(M,)
        the largest connected set of the transition matrix.
    """
    import deeptime.markov.tools.estimation as me

    # Equilibrium count matrix from the OOM components; clip negative entries.
    count_eq = np.einsum('j,jkl,lmn,n->km', omega, Xi, Xi, sigma)
    count_eq[count_eq < 0.0] = 0.0
    row_sums = np.sum(count_eq, axis=1)
    if reversible:
        # Symmetrize and normalize by the symmetrized sums.
        col_sums = np.sum(count_eq, axis=0)
        sym_sums = row_sums + col_sums
        # Avoid zero row-sums; such states are removed by the active set update below.
        sym_sums[np.where(sym_sums == 0.0)[0]] = 1.0
        T_eq = (count_eq + count_eq.T) / sym_sums[:, None]
    else:
        # Avoid zero row-sums; such states are removed by the active set update below.
        row_sums[np.where(row_sums == 0.0)[0]] = 1.0
        T_eq = count_eq / row_sums[:, None]
    # Active set update: restrict to the largest connected set.
    lcc = me.largest_connected_set(T_eq)
    T_eq = me.largest_connected_submatrix(T_eq, lcc=lcc)
    return (T_eq, lcc) if return_lcc else T_eq
def test_birth_death_chain(fixed_seed, sparse):
    """Estimate an MSM from a simulated birth-death chain and compare it against
    reference quantities computed with the low-level msmest/msmana functions,
    in both dense and sparse mode."""
    # Meta-stable birth-death chain on 7 states; 10**(-b) sets the small
    # switching probabilities at states 2 and 4.
    b = 2
    q = np.zeros(7)
    p = np.zeros(7)
    q[1:] = 0.5
    p[0:-1] = 0.5
    q[2] = 1.0 - 10**(-b)
    q[4] = 10**(-b)
    p[2] = 10**(-b)
    p[4] = 1.0 - 10**(-b)
    bdc = deeptime.data.birth_death_chain(q, p)
    dtraj = bdc.msm.simulate(10000, start=0)
    tau = 1
    # Reference pipeline: count matrix -> largest connected set -> reversible
    # transition matrix -> stationary distribution -> timescales.
    reference_count_matrix = msmest.count_matrix(dtraj, tau, sliding=True)
    reference_largest_connected_component = msmest.largest_connected_set(
        reference_count_matrix)
    reference_lcs = msmest.largest_connected_submatrix(
        reference_count_matrix, lcc=reference_largest_connected_component)
    reference_msm = msmest.transition_matrix(reference_lcs, reversible=True,
                                             maxerr=1e-8)
    reference_statdist = msmana.stationary_distribution(reference_msm)
    k = 3
    reference_timescales = msmana.timescales(reference_msm, k=k, tau=tau)
    # Estimator under test.
    msm = estimate_markov_model(dtraj, tau, sparse=sparse)
    assert_equal(tau, msm.count_model.lagtime)
    assert_array_equal(reference_largest_connected_component,
                       msm.count_model.connected_sets()[0])
    # The sparse flag must propagate to both the count and transition matrices.
    assert_(scipy.sparse.issparse(msm.count_model.count_matrix) == sparse)
    assert_(scipy.sparse.issparse(msm.transition_matrix) == sparse)
    if sparse:
        count_matrix = msm.count_model.count_matrix.toarray()
        transition_matrix = msm.transition_matrix.toarray()
    else:
        count_matrix = msm.count_model.count_matrix
        transition_matrix = msm.transition_matrix
    assert_array_almost_equal(reference_lcs.toarray(), count_matrix)
    # NOTE(review): this also compares the *full* count matrix to the model's
    # (connected-set) count matrix; it only holds when the chain is fully
    # connected, which is apparently the case here — confirm intent.
    assert_array_almost_equal(reference_count_matrix.toarray(), count_matrix)
    assert_array_almost_equal(reference_msm.toarray(), transition_matrix)
    assert_array_almost_equal(reference_statdist, msm.stationary_distribution)
    assert_array_almost_equal(reference_timescales[1:], msm.timescales(k - 1))
def __init__(self, complete: bool = True):
    """Build reference OOM-reweighted MSM test data.

    Loads pre-generated discrete trajectories and two-step count matrices
    from ``resources/TestData_OOM_MSM.npz``, fits OOM-reweighted MSMs
    (dense/sparse, reversible/non-reversible), and precomputes reference
    quantities (transition matrices, committors, MFPTs, PCCA, fingerprint
    correlations) for the tests to compare against.

    Parameters
    ----------
    complete : bool, optional, default=True
        If True, use all 1000 trajectories (all 5 states visited).
        If False, exclude trajectories that visit the last state so that
        the active set is a strict subset of the full state space.
    """
    self.complete = complete
    data = np.load(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'resources', 'TestData_OOM_MSM.npz'))
    if complete:
        self.dtrajs = [data['arr_%d' % k] for k in range(1000)]
    else:
        # Trajectories that visit the excluded state.
        excluded = [
            21, 25, 30, 40, 66, 72, 74, 91, 116, 158, 171, 175, 201, 239, 246, 280, 300, 301, 310, 318,
            322, 323, 339, 352, 365, 368, 407, 412, 444, 475, 486, 494, 510, 529, 560, 617, 623, 637,
            676, 689, 728, 731, 778, 780, 811, 828, 838, 845, 851, 859, 868, 874, 895, 933, 935, 938,
            958, 961, 968, 974, 984, 990, 999
        ]
        self.dtrajs = [data['arr_%d' % k] for k in np.setdiff1d(np.arange(1000), excluded)]
    # Number of states:
    self.N = 5
    # Lag time:
    self.tau = 5
    self.dtrajs_lag = [traj[:-self.tau] for traj in self.dtrajs]
    # Rank:
    if complete:
        self.rank = 3
    else:
        self.rank = 2
    # Build models:
    self.msmrev = OOMReweightedMSM(lagtime=self.tau, rank_mode='bootstrap_trajs').fit(self.dtrajs)
    self.msmrev_sparse = OOMReweightedMSM(lagtime=self.tau, sparse=True, rank_mode='bootstrap_trajs') \
        .fit(self.dtrajs)
    self.msm = OOMReweightedMSM(lagtime=self.tau, reversible=False, rank_mode='bootstrap_trajs').fit(self.dtrajs)
    self.msm_sparse = OOMReweightedMSM(lagtime=self.tau, reversible=False, sparse=True,
                                       rank_mode='bootstrap_trajs').fit(self.dtrajs)
    self.estimators = [self.msmrev, self.msm, self.msmrev_sparse, self.msm_sparse]
    self.msms = [est.fetch_model() for est in self.estimators]
    # Reference count matrices at lag time tau and 2*tau:
    if complete:
        self.C2t = data['C2t']
    else:
        self.C2t = data['C2t_s']
    self.Ct = np.sum(self.C2t, axis=1)
    if complete:
        self.Ct_active = self.Ct
        self.C2t_active = self.C2t
        # FIX: this branch previously set the misspelled attribute
        # 'active_faction', inconsistent with the else-branch below.
        self.active_fraction = 1.
        self.active_faction = 1.  # deprecated misspelling, kept for backward compatibility
    else:
        lcc = msmest.largest_connected_set(self.Ct)
        self.Ct_active = msmest.largest_connected_submatrix(self.Ct, lcc=lcc)
        self.C2t_active = self.C2t[:4, :4, :4]
        self.active_fraction = np.sum(self.Ct_active) / np.sum(self.Ct)
    # Compute OOM-components:
    self.Xi, self.omega, self.sigma, self.l = oom_transformations(self.Ct_active, self.C2t_active, self.rank)
    # Compute corrected transition matrix:
    Tt_rev = compute_transition_matrix(self.Xi, self.omega, self.sigma, reversible=True)
    Tt = compute_transition_matrix(self.Xi, self.omega, self.sigma, reversible=False)
    # Build reference models:
    self.rmsmrev = MarkovStateModel(Tt_rev)
    self.rmsm = MarkovStateModel(Tt)
    # Active count fraction:
    self.hist = count_states(self.dtrajs)
    self.active_hist = self.hist[:-1] if not complete else self.hist
    self.active_count_frac = float(np.sum(self.active_hist)) / np.sum(self.hist) if not complete else 1.
    self.active_state_frac = 0.8 if not complete else 1.
    # Committor and MFPT:
    a = np.array([0, 1])
    b = np.array([4]) if complete else np.array([3])
    self.comm_forward = self.rmsm.committor_forward(a, b)
    self.comm_forward_rev = self.rmsmrev.committor_forward(a, b)
    self.comm_backward = self.rmsm.committor_backward(a, b)
    self.comm_backward_rev = self.rmsmrev.committor_backward(a, b)
    self.mfpt = self.tau * self.rmsm.mfpt(a, b)
    self.mfpt_rev = self.tau * self.rmsmrev.mfpt(a, b)
    # PCCA:
    pcca = self.rmsmrev.pcca(3 if complete else 2)
    self.pcca_ass = pcca.assignments
    self.pcca_dist = pcca.metastable_distributions
    self.pcca_mem = pcca.memberships
    self.pcca_sets = pcca.sets
    # Experimental quantities:
    a = np.array([1, 2, 3, 4, 5])
    b = np.array([1, -1, 0, -2, 4])
    p0 = np.array([0.5, 0.2, 0.2, 0.1, 0.0])
    if not complete:
        a = a[:-1]
        b = b[:-1]
        p0 = p0[:-1]
    pi = self.rmsm.stationary_distribution
    pi_rev = self.rmsmrev.stationary_distribution
    _, _, L_rev = ma.rdl_decomposition(Tt_rev)
    self.exp = np.dot(self.rmsm.stationary_distribution, a)
    self.exp_rev = np.dot(self.rmsmrev.stationary_distribution, a)
    # Correlation and relaxation curves over 10 lag multiples:
    self.corr_rev = np.zeros(10)
    self.rel = np.zeros(10)
    self.rel_rev = np.zeros(10)
    for k in range(10):
        Ck_rev = np.dot(np.diag(pi_rev), np.linalg.matrix_power(Tt_rev, k))
        self.corr_rev[k] = np.dot(a.T, np.dot(Ck_rev, b))
        self.rel[k] = np.dot(p0.T, np.dot(np.linalg.matrix_power(Tt, k), a))
        self.rel_rev[k] = np.dot(p0.T, np.dot(np.linalg.matrix_power(Tt_rev, k), a))
    # Fingerprint amplitudes from the left eigenvectors of the reversible model:
    self.fing_cor = np.dot(a.T, L_rev.T) * np.dot(b.T, L_rev.T)
    self.fing_rel = np.dot(a.T, L_rev.T) * np.dot((p0 / pi_rev).T, L_rev.T)
def _compute_csets(
    connectivity, state_counts, count_matrices, ttrajs, dtrajs, bias_trajs, nn,
    equilibrium_state_counts=None, factor=1.0, callback=None):
    """Compute the connected sets of a multi-ensemble count model.

    Parameters
    ----------
    connectivity : str or None
        one of None, 'summed_count_matrix', 'reversible_pathways', 'largest',
        'neighbors', 'post_hoc_RE', 'BAR_variance' (see branches below)
    state_counts : ndarray(T, M)
        state counts per thermodynamic state (T) and conformational state (M)
    count_matrices : ndarray(T, M, M)
        transition count matrices per thermodynamic state
    ttrajs, dtrajs : lists of ndarray
        thermodynamic-state and conformational-state trajectories
        (only used by the 'post_hoc_RE'/'BAR_variance' branches)
    bias_trajs : list of ndarray
        bias energies per frame and thermodynamic state
        (only used by the 'post_hoc_RE'/'BAR_variance' branches)
    nn : int or None
        number of neighboring umbrellas assumed to overlap
        (only used by the 'neighbors' branch)
    equilibrium_state_counts : ndarray(T, M), optional
        extra counts from global-equilibrium data; treated as fully
        self-connected within each thermodynamic state
    factor : float, optional, default=1.0
        passed through to the overlap criterion
    callback : callable, optional
        progress-reporting callback

    Returns
    -------
    (csets, cset_projected) : per-thermodynamic-state connected sets and the
        projected (union-level) connected set of conformational states.
    """
    n_therm_states, n_conf_states = state_counts.shape
    # Equilibrium counts contribute to "visited" just like sampled counts.
    if equilibrium_state_counts is not None:
        all_state_counts = state_counts + equilibrium_state_counts
    else:
        all_state_counts = state_counts
    if connectivity is None:
        # No connectivity restriction: every visited state is kept.
        cset_projected = _np.where(all_state_counts.sum(axis=0) > 0)[0]
        csets = [
            _np.where(all_state_counts[k, :] > 0)[0] for k in range(n_therm_states)
        ]
        return csets, cset_projected
    elif connectivity == 'summed_count_matrix':
        # assume that two thermodynamic states overlap when there are samples from both
        # ensembles in some Markov state
        C_sum = count_matrices.sum(axis=0)
        if equilibrium_state_counts is not None:
            # Equilibrium-visited states are made mutually connected.
            eq_states = _np.where(equilibrium_state_counts.sum(axis=0) > 0)[0]
            C_sum[eq_states, eq_states[:, _np.newaxis]] = 1
        cset_projected = largest_connected_set(C_sum, directed=True)
        csets = []
        for k in range(n_therm_states):
            cset = _np.intersect1d(_np.where(all_state_counts[k, :] > 0), cset_projected)
            csets.append(cset)
        return csets, cset_projected
    elif connectivity == 'reversible_pathways' or connectivity == 'largest':
        # Build a proxy adjacency matrix that links, per ensemble, all states
        # within the same strongly connected component of its count matrix.
        C_proxy = _np.zeros((n_conf_states, n_conf_states), dtype=int)
        for C in count_matrices:
            for comp in connected_sets(C, directed=True):
                C_proxy[comp[0:-1], comp[1:]] = 1  # add chain of states
        if equilibrium_state_counts is not None:
            eq_states = _np.where(equilibrium_state_counts.sum(axis=0) > 0)[0]
            C_proxy[eq_states, eq_states[:, _np.newaxis]] = 1
        cset_projected = largest_connected_set(C_proxy, directed=False)
        csets = []
        for k in range(n_therm_states):
            cset = _np.intersect1d(_np.where(all_state_counts[k, :] > 0), cset_projected)
            csets.append(cset)
        return csets, cset_projected
    elif connectivity in ['neighbors', 'post_hoc_RE', 'BAR_variance']:
        # Work on the product space of (thermodynamic state, conformational state).
        dim = n_therm_states * n_conf_states
        if connectivity == 'post_hoc_RE' or connectivity == 'BAR_variance':
            # Overlap between ensembles is decided by a statistical criterion
            # applied to the bias energies of samples in the same Markov state.
            if connectivity == 'post_hoc_RE':
                overlap = _util._overlap_post_hoc_RE
            else:
                overlap = _overlap_BAR_variance
            i_s = []
            j_s = []
            for i in range(n_conf_states):
                # can take a very long time, allow to report progress via callback
                if callback is not None:
                    callback(maxiter=n_conf_states, iteration_step=i)
                therm_states = _np.where(all_state_counts[:, i] > 0)[0]  # therm states that have samples
                # prepare list of indices for all thermodynamic states
                traj_indices = {}
                frame_indices = {}
                for k in therm_states:
                    frame_indices[k] = [_np.where(
                        _np.logical_and(d == i, t == k))[0] for t, d in zip(ttrajs, dtrajs)]
                    traj_indices[k] = [j for j, fi in enumerate(frame_indices[k]) if len(fi) > 0]
                for k in therm_states:
                    for l in therm_states:
                        if k!=l:
                            kl = _np.array([k, l])
                            # Bias energies (restricted to ensembles k,l) of the
                            # samples from ensemble k resp. l in Markov state i.
                            a = _np.concatenate([
                                bias_trajs[j][:, kl][frame_indices[k][j], :] for j in traj_indices[k]])
                            b = _np.concatenate([
                                bias_trajs[j][:, kl][frame_indices[l][j], :] for j in traj_indices[l]])
                            if overlap(a, b, factor=factor):
                                x = i + k * n_conf_states
                                y = i + l * n_conf_states
                                i_s.append(x)
                                j_s.append(y)
        else:  # assume overlap between nn neighboring umbrellas
            assert nn is not None, 'With connectivity="neighbors", nn can\'t be None.'
            assert nn >= 1 and nn <= n_therm_states - 1
            i_s = []
            j_s = []
            # connectivity between thermodynamic states
            for l in range(1, nn + 1):
                if callback is not None:
                    callback(maxiter=nn, iteration_step=l)
                for k in range(n_therm_states - l):
                    # States visited in both ensemble k and ensemble k+l.
                    w = _np.where(_np.logical_and(
                        all_state_counts[k, :] > 0, all_state_counts[k + l, :] > 0))[0]
                    a = w + k * n_conf_states
                    b = w + (k + l) * n_conf_states
                    i_s += list(a)
                    j_s += list(b)
        # connectivity between conformational states:
        # just copy it from the count matrices
        for k in range(n_therm_states):
            for comp in connected_sets(count_matrices[k, :, :], directed=True):
                # add chain that links all states in the component
                i_s += list(comp[0:-1] + k * n_conf_states)
                j_s += list(comp[1:] + k * n_conf_states)
        # If there is global equilibrium data, assume full connectivity
        # between all visited conformational states within the same thermodynamic state.
        if equilibrium_state_counts is not None:
            for k in range(n_therm_states):
                vertices = _np.where(equilibrium_state_counts[k, :]>0)[0]
                # add bidirectional chain that links all states
                chain = (vertices[0:-1], vertices[1:])
                i_s += list(chain[0] + k * n_conf_states)
                j_s += list(chain[1] + k * n_conf_states)
        # Largest undirected component of the product-space graph.
        data = _np.ones(len(i_s), dtype=int)
        A = _sp.sparse.coo_matrix((data, (i_s, j_s)), shape=(dim, dim))
        cset = largest_connected_set(A, directed=False)
        # group by thermodynamic state
        cset = _np.unravel_index(cset, (n_therm_states, n_conf_states), order='C')
        csets = [[] for k in range(n_therm_states)]
        for k,i in zip(*cset):
            csets[k].append(i)
        csets = [_np.array(c,dtype=int) for c in csets]
        projected_cset = _np.unique(_np.concatenate(csets))
        return csets, projected_cset
    else:
        raise Exception(
            'Unknown value "%s" of connectivity. Should be one of: \
summed_count_matrix, strong_in_every_ensemble, neighbors, \
post_hoc_RE or BAR_variance.' % connectivity)
def _estimate(self, trajs):
    """Run the dTRAM estimation.

    Parameters
    ----------
    trajs : tuple or list of length 2
        (ttrajs, dtrajs): lists of 1-d thermodynamic-state and
        conformational-state trajectories of matching lengths.

    Returns
    -------
    self, with estimated energies, Lagrange multipliers and per-ensemble
    MSM models set as model parameters.
    """
    # check input
    assert isinstance(trajs, (tuple, list))
    assert len(trajs) == 2
    ttrajs = trajs[0]
    dtrajs = trajs[1]
    # validate input
    for ttraj, dtraj in zip(ttrajs, dtrajs):
        _types.assert_array(ttraj, ndim=1, kind='numeric')
        _types.assert_array(dtraj, ndim=1, kind='numeric')
        assert _np.shape(ttraj)[0] == _np.shape(dtraj)[0]
    # harvest transition counts
    self.count_matrices_full = _util.count_matrices(
        ttrajs, dtrajs, self.lag,
        sliding=self.count_mode, sparse_return=False, nstates=self.nstates_full)
    # harvest state counts (for WHAM)
    self.state_counts_full = _util.state_counts(
        ttrajs, dtrajs, nthermo=self.nthermo, nstates=self.nstates_full)
    # restrict to connected set
    # NOTE(review): C_sum is computed but never used below — dead code?
    C_sum = self.count_matrices_full.sum(axis=0)
    # TODO: use improved cset
    _, cset = _cset.compute_csets_dTRAM(self.connectivity, self.count_matrices_full)
    self.active_set = cset
    # correct counts
    self.count_matrices = self.count_matrices_full[:, cset[:, _np.newaxis], cset]
    self.count_matrices = _np.require(self.count_matrices, dtype=_np.intc, requirements=['C', 'A'])
    # correct bias matrix
    self.bias_energies = self.bias_energies_full[:, cset]
    self.bias_energies = _np.require(self.bias_energies, dtype=_np.float64, requirements=['C', 'A'])
    # correct state counts
    self.state_counts = self.state_counts_full[:, cset]
    self.state_counts = _np.require(self.state_counts, dtype=_np.intc, requirements=['C', 'A'])
    # run initialisation
    pg = _ProgressReporter()
    if self.init is not None and self.init == 'wham':
        # Optional WHAM pre-run to provide initial energies for dTRAM.
        stage = 'WHAM init.'
        with pg.context(stage=stage):
            self.therm_energies, self.conf_energies, _increments, _loglikelihoods = \
                _wham.estimate(
                    self.state_counts, self.bias_energies,
                    maxiter=self.init_maxiter, maxerr=self.init_maxerr, save_convergence_info=0,
                    therm_energies=self.therm_energies, conf_energies=self.conf_energies,
                    callback=_ConvergenceProgressIndicatorCallBack(
                        pg, stage, self.init_maxiter, self.init_maxerr))
    # run estimator
    stage = 'DTRAM'
    with pg.context(stage=stage):
        self.therm_energies, self.conf_energies, self.log_lagrangian_mult, \
            self.increments, self.loglikelihoods = _dtram.estimate(
                self.count_matrices, self.bias_energies,
                maxiter=self.maxiter, maxerr=self.maxerr,
                log_lagrangian_mult=self.log_lagrangian_mult,
                conf_energies=self.conf_energies,
                save_convergence_info=self.save_convergence_info,
                callback=_ConvergenceProgressIndicatorCallBack(
                    pg, stage, self.maxiter, self.maxerr))
    # compute models: one transition matrix per thermodynamic state K
    fmsms = [
        _dtram.estimate_transition_matrix(
            self.log_lagrangian_mult, self.bias_energies, self.conf_energies,
            self.count_matrices,
            _np.zeros(shape=self.conf_energies.shape, dtype=_np.float64), K)
        for K in range(self.nthermo)
    ]
    # Restrict each transition matrix to its own (undirected) connected set.
    active_sets = [
        largest_connected_set(msm, directed=False) for msm in fmsms
    ]
    fmsms = [
        _np.ascontiguousarray((msm[lcc, :])[:, lcc]) for msm, lcc in zip(fmsms, active_sets)
    ]
    models = []
    for i, (msm, acs) in enumerate(zip(fmsms, active_sets)):
        # Stationary vector on the active set from the estimated energies.
        pi_acs = _np.exp(self.therm_energies[i] - self.bias_energies[i, :] - self.conf_energies)[acs]
        pi_acs = pi_acs / pi_acs.sum()
        models.append(
            _ThermoMSM(msm, self.active_set[acs], self.nstates_full,
                       pi=pi_acs, dt_model=self.timestep_traj.get_scaled(self.lag)))
    # set model parameters to self
    self.set_model_params(models=models, f_therm=self.therm_energies, f=self.conf_energies)
    # done
    return self
def cktest_resource():
    """Fixture: yield (MSM, p_MSM, p_MD) reference data for a Chapman-Kolmogorov
    test on a meta-stable birth-death chain; restores the global rng state on
    teardown."""
    # Reseed the rng to enforce 'deterministic' behavior; original state is
    # restored after the yield.
    rnd_state = np.random.mtrand.get_state()
    np.random.mtrand.seed(42)
    """Meta-stable birth-death chain"""
    b = 2
    q = np.zeros(7)
    p = np.zeros(7)
    q[1:] = 0.5
    p[0:-1] = 0.5
    q[2] = 1.0 - 10**(-b)
    q[4] = 10**(-b)
    p[2] = 10**(-b)
    p[4] = 1.0 - 10**(-b)
    bdc = BirthDeathChain(q, p)
    dtraj = bdc.msm.simulate(10000, start=0)
    tau = 1
    """Estimate MSM"""
    MSM = estimate_markov_model(dtraj, tau)
    P_MSM = MSM.transition_matrix
    mu_MSM = MSM.stationary_distribution
    """Meta-stable sets"""
    A = [0, 1, 2]
    B = [4, 5, 6]
    # Initial distributions restricted to set A (row 0) and set B (row 1).
    w_MSM = np.zeros((2, mu_MSM.shape[0]))
    w_MSM[0, A] = mu_MSM[A] / mu_MSM[A].sum()
    w_MSM[1, B] = mu_MSM[B] / mu_MSM[B].sum()
    K = 10
    P_MSM_dense = P_MSM
    # MSM prediction: propagate the set distributions k steps and record the
    # probability of remaining in the respective set.
    p_MSM = np.zeros((K, 2))
    w_MSM_k = 1.0 * w_MSM
    for k in range(1, K):
        w_MSM_k = np.dot(w_MSM_k, P_MSM_dense)
        p_MSM[k, 0] = w_MSM_k[0, A].sum()
        p_MSM[k, 1] = w_MSM_k[1, B].sum()
    """Assume that sets are equal, A(\tau)=A(k \tau) for all k"""
    w_MD = 1.0 * w_MSM
    p_MD = np.zeros((K, 2))
    eps_MD = np.zeros((K, 2))
    # At k=0 the sets are fully occupied by construction.
    p_MSM[0, :] = 1.0
    p_MD[0, :] = 1.0
    eps_MD[0, :] = 0.0
    for k in range(1, K):
        """Build MSM at lagtime k*tau"""
        C_MD = count_matrix(dtraj, k * tau, sliding=True) / (k * tau)
        lcc_MD = largest_connected_set(C_MD)
        Ccc_MD = largest_connected_submatrix(C_MD, lcc=lcc_MD)
        c_MD = Ccc_MD.sum(axis=1)
        P_MD = transition_matrix(Ccc_MD).toarray()
        w_MD_k = np.dot(w_MD, P_MD)
        """Set A"""
        prob_MD = w_MD_k[0, A].sum()
        c = c_MD[A].sum()
        p_MD[k, 0] = prob_MD
        # Statistical error estimate for the set probability.
        eps_MD[k, 0] = np.sqrt(k * (prob_MD - prob_MD**2) / c)
        """Set B"""
        prob_MD = w_MD_k[1, B].sum()
        c = c_MD[B].sum()
        p_MD[k, 1] = prob_MD
        eps_MD[k, 1] = np.sqrt(k * (prob_MD - prob_MD**2) / c)
    """Input"""
    yield MSM, p_MSM, p_MD
    # Teardown: restore the rng state saved above.
    np.random.mtrand.set_state(rnd_state)
def _estimate(self, X):
    """Run the TRAM (or TRAMMBAR, when equilibrium data is given) estimation.

    Parameters
    ----------
    X : tuple (ttrajs, dtrajs_full, btrajs)
        lists of 1-d thermodynamic-state trajectories, 1-d conformational-state
        trajectories and 2-d bias-energy trajectories of matching lengths.

    Returns
    -------
    self, with estimated energies, Lagrange multipliers and per-ensemble
    MSM models set as model parameters.
    """
    ttrajs, dtrajs_full, btrajs = X
    # shape and type checks
    assert len(ttrajs) == len(dtrajs_full) == len(btrajs)
    for t in ttrajs:
        _types.assert_array(t, ndim=1, kind='i')
    for d in dtrajs_full:
        _types.assert_array(d, ndim=1, kind='i')
    for b in btrajs:
        _types.assert_array(b, ndim=2, kind='f')
    # find dimensions
    nstates_full = max(_np.max(d) for d in dtrajs_full) + 1
    if self.nstates_full is None:
        self.nstates_full = nstates_full
    elif self.nstates_full < nstates_full:
        raise RuntimeError("Found more states (%d) than specified by nstates_full (%d)" % (
            nstates_full, self.nstates_full))
    self.nthermo = max(_np.max(t) for t in ttrajs) + 1
    # dimensionality checks
    for t, d, b, in zip(ttrajs, dtrajs_full, btrajs):
        assert t.shape[0] == d.shape[0] == b.shape[0]
        assert b.shape[1] == self.nthermo
    # cast types and change axis order if needed
    ttrajs = [_np.require(t, dtype=_np.intc, requirements='C') for t in ttrajs]
    dtrajs_full = [_np.require(d, dtype=_np.intc, requirements='C') for d in dtrajs_full]
    btrajs = [_np.require(b, dtype=_np.float64, requirements='C') for b in btrajs]
    # if equilibrium information is given, separate the trajectories
    if self.equilibrium is not None:
        assert len(self.equilibrium) == len(ttrajs)
        _ttrajs, _dtrajs_full, _btrajs = ttrajs, dtrajs_full, btrajs
        # Non-equilibrium (TRAM) part:
        ttrajs = [ttraj for eq, ttraj in zip(self.equilibrium, _ttrajs) if not eq]
        dtrajs_full = [dtraj for eq, dtraj in zip(self.equilibrium, _dtrajs_full) if not eq]
        self.btrajs = [btraj for eq, btraj in zip(self.equilibrium, _btrajs) if not eq]
        # Equilibrium (MBAR) part:
        equilibrium_ttrajs = [ttraj for eq, ttraj in zip(self.equilibrium, _ttrajs) if eq]
        equilibrium_dtrajs_full = [dtraj for eq, dtraj in zip(self.equilibrium, _dtrajs_full) if eq]
        self.equilibrium_btrajs = [btraj for eq, btraj in zip(self.equilibrium, _btrajs) if eq]
    else:
        # set dummy values
        equilibrium_ttrajs = []
        equilibrium_dtrajs_full = []
        self.equilibrium_btrajs = []
        self.btrajs = btrajs
    # find state visits and transition counts
    state_counts_full = _util.state_counts(
        ttrajs, dtrajs_full, nstates=self.nstates_full, nthermo=self.nthermo)
    count_matrices_full = _util.count_matrices(
        ttrajs, dtrajs_full, self.lag,
        sliding=self.count_mode, sparse_return=False,
        nstates=self.nstates_full, nthermo=self.nthermo)
    self.therm_state_counts_full = state_counts_full.sum(axis=1)
    if self.equilibrium is not None:
        self.equilibrium_state_counts_full = _util.state_counts(
            equilibrium_ttrajs, equilibrium_dtrajs_full,
            nstates=self.nstates_full, nthermo=self.nthermo)
    else:
        self.equilibrium_state_counts_full = _np.zeros(
            (self.nthermo, self.nstates_full), dtype=_np.float64)
    pg = _ProgressReporter()
    # Compute the connected sets (may be expensive, hence the progress stage).
    stage = 'cset'
    with pg.context(stage=stage):
        self.csets, pcset = _cset.compute_csets_TRAM(
            self.connectivity, state_counts_full, count_matrices_full,
            equilibrium_state_counts=self.equilibrium_state_counts_full,
            ttrajs=ttrajs+equilibrium_ttrajs, dtrajs=dtrajs_full+equilibrium_dtrajs_full,
            bias_trajs=self.btrajs+self.equilibrium_btrajs,
            nn=self.nn, factor=self.connectivity_factor,
            callback=_IterationProgressIndicatorCallBack(pg, 'finding connected set', stage=stage))
        self.active_set = pcset
    # check for empty states
    for k in range(self.nthermo):
        if len(self.csets[k]) == 0:
            _warnings.warn(
                'Thermodynamic state %d' % k \
                + ' contains no samples after reducing to the connected set.', EmptyState)
    # deactivate samples not in the csets, states are *not* relabeled
    self.state_counts, self.count_matrices, self.dtrajs, _ = _cset.restrict_to_csets(
        self.csets,
        state_counts=state_counts_full, count_matrices=count_matrices_full,
        ttrajs=ttrajs, dtrajs=dtrajs_full)
    if self.equilibrium is not None:
        self.equilibrium_state_counts, _, self.equilibrium_dtrajs, _ = _cset.restrict_to_csets(
            self.csets,
            state_counts=self.equilibrium_state_counts_full,
            ttrajs=equilibrium_ttrajs, dtrajs=equilibrium_dtrajs_full)
    else:
        self.equilibrium_state_counts = _np.zeros(
            (self.nthermo, self.nstates_full), dtype=_np.intc)  # (remember: no relabeling)
        self.equilibrium_dtrajs = []
    # self-consistency tests
    assert _np.all(self.state_counts >= _np.maximum(self.count_matrices.sum(axis=1), \
        self.count_matrices.sum(axis=2)))
    assert _np.all(_np.sum(
        [_np.bincount(d[d>=0], minlength=self.nstates_full) for d in self.dtrajs],
        axis=0) == self.state_counts.sum(axis=0))
    assert _np.all(_np.sum(
        [_np.bincount(t[d>=0], minlength=self.nthermo) for t, d in zip(ttrajs, self.dtrajs)],
        axis=0) == self.state_counts.sum(axis=1))
    if self.equilibrium is not None:
        assert _np.all(_np.sum(
            [_np.bincount(d[d >= 0], minlength=self.nstates_full) for d in self.equilibrium_dtrajs],
            axis=0) == self.equilibrium_state_counts.sum(axis=0))
        assert _np.all(_np.sum(
            [_np.bincount(t[d >= 0], minlength=self.nthermo)
             for t, d in zip(equilibrium_ttrajs, self.equilibrium_dtrajs)],
            axis=0) == self.equilibrium_state_counts.sum(axis=1))
    # check for empty states
    for k in range(self.state_counts.shape[0]):
        if self.count_matrices[k, :, :].sum() == 0 and self.equilibrium_state_counts[k, :].sum()==0:
            _warnings.warn(
                'Thermodynamic state %d' % k \
                + ' contains no transitions and no equilibrium data after reducing to the connected set.', EmptyState)
    if self.init == 'mbar' and self.biased_conf_energies is None:
        # Optional MBAR pre-run to initialise the biased conformational energies.
        if self.direct_space:
            mbar = _mbar_direct
        else:
            mbar = _mbar
        stage = 'MBAR init.'
        with pg.context(stage=stage):
            self.mbar_therm_energies, self.mbar_unbiased_conf_energies, \
                self.mbar_biased_conf_energies, _ = mbar.estimate(
                    (state_counts_full.sum(axis=1)+self.equilibrium_state_counts_full.sum(axis=1)).astype(_np.intc),
                    self.btrajs+self.equilibrium_btrajs, dtrajs_full+equilibrium_dtrajs_full,
                    maxiter=self.init_maxiter, maxerr=self.init_maxerr,
                    callback=_ConvergenceProgressIndicatorCallBack(
                        pg, stage, self.init_maxiter, self.init_maxerr),
                    n_conf_states=self.nstates_full)
            self.biased_conf_energies = self.mbar_biased_conf_energies.copy()
    # run estimator
    if self.direct_space:
        tram = _tram_direct
        trammbar = _trammbar_direct
    else:
        tram = _tram
        trammbar = _trammbar
    #import warnings
    #with warnings.catch_warnings() as cm:
    # warnings.filterwarnings('ignore', RuntimeWarning)
    stage = 'TRAM'
    with pg.context(stage=stage):
        if self.equilibrium is None:
            self.biased_conf_energies, conf_energies, self.therm_energies, self.log_lagrangian_mult, \
                self.increments, self.loglikelihoods = tram.estimate(
                    self.count_matrices, self.state_counts, self.btrajs, self.dtrajs,
                    maxiter=self.maxiter, maxerr=self.maxerr,
                    biased_conf_energies=self.biased_conf_energies,
                    log_lagrangian_mult=self.log_lagrangian_mult,
                    save_convergence_info=self.save_convergence_info,
                    callback=_ConvergenceProgressIndicatorCallBack(
                        pg, stage, self.maxiter, self.maxerr, subcallback=self.callback),
                    N_dtram_accelerations=self.N_dtram_accelerations)
        else:  # use trammbar
            self.biased_conf_energies, conf_energies, self.therm_energies, self.log_lagrangian_mult, \
                self.increments, self.loglikelihoods = trammbar.estimate(
                    self.count_matrices, self.state_counts, self.btrajs, self.dtrajs,
                    equilibrium_therm_state_counts=self.equilibrium_state_counts.sum(axis=1).astype(_np.intc),
                    equilibrium_bias_energy_sequences=self.equilibrium_btrajs,
                    equilibrium_state_sequences=self.equilibrium_dtrajs,
                    maxiter=self.maxiter, maxerr=self.maxerr,
                    save_convergence_info=self.save_convergence_info,
                    biased_conf_energies=self.biased_conf_energies,
                    log_lagrangian_mult=self.log_lagrangian_mult,
                    callback=_ConvergenceProgressIndicatorCallBack(
                        pg, stage, self.maxiter, self.maxerr, subcallback=self.callback),
                    N_dtram_accelerations=self.N_dtram_accelerations,
                    overcounting_factor=self.overcounting_factor)
    # compute models: one transition matrix per thermodynamic state K,
    # restricted to the projected active set
    fmsms = [_np.ascontiguousarray((
        _tram.estimate_transition_matrix(
            self.log_lagrangian_mult, self.biased_conf_energies, self.count_matrices,
            None, K)[self.active_set, :])[:, self.active_set]) for K in range(self.nthermo)]
    # Restrict each transition matrix to its own (undirected) connected set.
    active_sets = [largest_connected_set(msm, directed=False) for msm in fmsms]
    fmsms = [_np.ascontiguousarray(
        (msm[lcc, :])[:, lcc]) for msm, lcc in zip(fmsms, active_sets)]
    models = []
    for i, (msm, acs) in enumerate(zip(fmsms, active_sets)):
        # Stationary vector on the active set from the estimated energies.
        pi_acs = _np.exp(self.therm_energies[i] - self.biased_conf_energies[i, :])[self.active_set[acs]]
        pi_acs = pi_acs / pi_acs.sum()
        models.append(_ThermoMSM(
            msm, self.active_set[acs], self.nstates_full,
            pi=pi_acs,
            dt_model=self.timestep_traj.get_scaled(self.lag)))
    # set model parameters to self
    self.set_model_params(
        models=models, f_therm=self.therm_energies, f=conf_energies[self.active_set].copy())
    return self