コード例 #1
0
ファイル: TRAM_estimator.py プロジェクト: yuxuanzhuang/PyEMMA
    def _estimate(self, X):
        """Run the TRAM (or, when equilibrium flags are set, TRAMMBAR) estimation.

        Parameters
        ----------
        X : tuple (ttrajs, dtrajs_full, btrajs)
            Three lists of equal length, one entry per trajectory.
            ``ttrajs[i]`` and ``dtrajs_full[i]`` must be 1-D integer arrays and
            ``btrajs[i]`` a 2-D float array. For every trajectory all three
            arrays must have the same number of rows, and ``btrajs[i]`` must
            have ``self.nthermo`` columns (``self.nthermo`` is inferred as the
            largest value in ``ttrajs`` plus one).

        Returns
        -------
        self
            The estimator itself, after ``set_model_params`` was called with
            the per-thermodynamic-state models, ``f_therm`` and ``f``.

        Raises
        ------
        RuntimeError
            If ``self.nstates_full`` was set and the data contain a larger
            state index than it allows.
        AssertionError
            If one of the shape / consistency ``assert`` statements fails
            (note that asserts are skipped when Python runs with ``-O``).

        Warns
        -----
        EmptyState
            If a thermodynamic state has an empty connected set, or has
            neither transitions nor equilibrium samples after restriction
            to the connected sets.

        Notes
        -----
        Besides the model parameters the method stores, among others,
        ``nthermo``, ``btrajs``, ``equilibrium_btrajs``, ``csets``,
        ``active_set``, ``state_counts``, ``count_matrices``, ``dtrajs``,
        ``biased_conf_energies``, ``therm_energies``, ``log_lagrangian_mult``,
        ``increments`` and ``loglikelihoods`` on ``self``.
        """
        ttrajs, dtrajs_full, btrajs = X
        # shape and type checks
        assert len(ttrajs) == len(dtrajs_full) == len(btrajs)
        for t in ttrajs:
            _types.assert_array(t, ndim=1, kind='i')
        for d in dtrajs_full:
            _types.assert_array(d, ndim=1, kind='i')
        for b in btrajs:
            _types.assert_array(b, ndim=2, kind='f')
        # find dimensions
        # (largest state index seen in the data + 1; a user-supplied
        # self.nstates_full may be larger than that, but never smaller)
        nstates_full = max(_np.max(d) for d in dtrajs_full) + 1
        if self.nstates_full is None:
            self.nstates_full = nstates_full
        elif self.nstates_full < nstates_full:
            raise RuntimeError("Found more states (%d) than specified by nstates_full (%d)" % (
                nstates_full, self.nstates_full))
        # the number of thermodynamic states is always inferred from the data
        self.nthermo = max(_np.max(t) for t in ttrajs) + 1
        # dimensionality checks
        for t, d, b, in zip(ttrajs, dtrajs_full, btrajs):
            assert t.shape[0] == d.shape[0] == b.shape[0]
            assert b.shape[1] == self.nthermo

        # cast types and change axis order if needed
        ttrajs = [_np.require(t, dtype=_np.intc, requirements='C') for t in ttrajs]
        dtrajs_full = [_np.require(d, dtype=_np.intc, requirements='C') for d in dtrajs_full]
        btrajs = [_np.require(b, dtype=_np.float64, requirements='C') for b in btrajs]

        # if equilibrium information is given, separate the trajectories
        # (self.equilibrium holds one flag per trajectory; truthy entries go
        # into the equilibrium_* lists, all others stay in the regular lists)
        if self.equilibrium is not None:
            assert len(self.equilibrium) == len(ttrajs)
            _ttrajs, _dtrajs_full, _btrajs = ttrajs, dtrajs_full, btrajs
            ttrajs = [ttraj for eq, ttraj in zip(self.equilibrium, _ttrajs) if not eq]
            dtrajs_full = [dtraj for eq, dtraj in zip(self.equilibrium, _dtrajs_full) if not eq]
            self.btrajs = [btraj for eq, btraj in zip(self.equilibrium, _btrajs) if not eq]
            equilibrium_ttrajs = [ttraj for eq, ttraj in zip(self.equilibrium, _ttrajs) if eq]
            equilibrium_dtrajs_full = [dtraj for eq, dtraj in zip(self.equilibrium, _dtrajs_full) if eq]
            self.equilibrium_btrajs = [btraj for eq, btraj in zip(self.equilibrium, _btrajs) if eq]
        else: # set dummy values
            equilibrium_ttrajs = []
            equilibrium_dtrajs_full = []
            self.equilibrium_btrajs = []
            self.btrajs = btrajs

        # find state visits and transition counts
        # (from here on ttrajs/dtrajs_full only contain the trajectories that
        # were not flagged as equilibrium)
        state_counts_full = _util.state_counts(ttrajs, dtrajs_full, nstates=self.nstates_full, nthermo=self.nthermo)
        count_matrices_full = _util.count_matrices(ttrajs, dtrajs_full,
            self.lag, sliding=self.count_mode, sparse_return=False, nstates=self.nstates_full, nthermo=self.nthermo)
        self.therm_state_counts_full = state_counts_full.sum(axis=1)

        # state visits of the equilibrium trajectories; all zeros if there are none
        if self.equilibrium is not None:
            self.equilibrium_state_counts_full = _util.state_counts(equilibrium_ttrajs, equilibrium_dtrajs_full,
                nstates=self.nstates_full, nthermo=self.nthermo)
        else:
            self.equilibrium_state_counts_full = _np.zeros((self.nthermo, self.nstates_full), dtype=_np.float64)

        # connected sets are computed from the regular and the equilibrium
        # trajectories together (the lists are concatenated in the call)
        pg = _ProgressReporter()
        stage = 'cset'
        with pg.context(stage=stage):
            self.csets, pcset = _cset.compute_csets_TRAM(
                self.connectivity, state_counts_full, count_matrices_full,
                equilibrium_state_counts=self.equilibrium_state_counts_full,
                ttrajs=ttrajs+equilibrium_ttrajs, dtrajs=dtrajs_full+equilibrium_dtrajs_full, bias_trajs=self.btrajs+self.equilibrium_btrajs,
                nn=self.nn, factor=self.connectivity_factor,
                callback=_IterationProgressIndicatorCallBack(pg, 'finding connected set', stage=stage))
            self.active_set = pcset

        # check for empty states
        for k in range(self.nthermo):
            if len(self.csets[k]) == 0:
                _warnings.warn(
                    'Thermodynamic state %d' % k \
                    + ' contains no samples after reducing to the connected set.', EmptyState)

        # deactivate samples not in the csets, states are *not* relabeled
        self.state_counts, self.count_matrices, self.dtrajs, _  = _cset.restrict_to_csets(
            self.csets,
            state_counts=state_counts_full, count_matrices=count_matrices_full,
            ttrajs=ttrajs, dtrajs=dtrajs_full)

        if self.equilibrium is not None:
            self.equilibrium_state_counts, _, self.equilibrium_dtrajs, _ =  _cset.restrict_to_csets(
                self.csets,
                state_counts=self.equilibrium_state_counts_full, ttrajs=equilibrium_ttrajs, dtrajs=equilibrium_dtrajs_full)
        else:
            self.equilibrium_state_counts = _np.zeros((self.nthermo, self.nstates_full), dtype=_np.intc) # (remember: no relabeling)
            self.equilibrium_dtrajs = []

        # self-consistency tests
        # (the `d >= 0` masks below select the samples kept by
        # restrict_to_csets; presumably deactivated samples are marked with a
        # negative state index -- confirm against _cset.restrict_to_csets)
        assert _np.all(self.state_counts >= _np.maximum(self.count_matrices.sum(axis=1), \
            self.count_matrices.sum(axis=2)))
        assert _np.all(_np.sum(
            [_np.bincount(d[d>=0], minlength=self.nstates_full) for d in self.dtrajs],
            axis=0) == self.state_counts.sum(axis=0))
        assert _np.all(_np.sum(
            [_np.bincount(t[d>=0], minlength=self.nthermo) for t, d in zip(ttrajs, self.dtrajs)],
            axis=0) == self.state_counts.sum(axis=1))
        if self.equilibrium is not None:
            assert _np.all(_np.sum(
                [_np.bincount(d[d >= 0], minlength=self.nstates_full) for d in self.equilibrium_dtrajs],
                axis=0) == self.equilibrium_state_counts.sum(axis=0))
            assert _np.all(_np.sum(
                [_np.bincount(t[d >= 0], minlength=self.nthermo) for t, d in zip(equilibrium_ttrajs, self.equilibrium_dtrajs)],
                axis=0) ==  self.equilibrium_state_counts.sum(axis=1))

        # check for empty states
        for k in range(self.state_counts.shape[0]):
            if self.count_matrices[k, :, :].sum() == 0 and self.equilibrium_state_counts[k, :].sum()==0:
                _warnings.warn(
                    'Thermodynamic state %d' % k \
                    + ' contains no transitions and no equilibrium data after reducing to the connected set.', EmptyState)

        # MBAR is only run as an initializer, and only when the caller did not
        # already provide biased_conf_energies; it uses the full (unrestricted)
        # counts and trajectories, including the equilibrium ones
        if self.init == 'mbar' and self.biased_conf_energies is None:
            if self.direct_space:
                mbar = _mbar_direct
            else:
                mbar = _mbar
            stage = 'MBAR init.'
            with pg.context(stage=stage):
                self.mbar_therm_energies, self.mbar_unbiased_conf_energies, \
                    self.mbar_biased_conf_energies, _ = mbar.estimate(
                        (state_counts_full.sum(axis=1)+self.equilibrium_state_counts_full.sum(axis=1)).astype(_np.intc),
                        self.btrajs+self.equilibrium_btrajs, dtrajs_full+equilibrium_dtrajs_full,
                        maxiter=self.init_maxiter, maxerr=self.init_maxerr,
                        callback=_ConvergenceProgressIndicatorCallBack(
                            pg, stage, self.init_maxiter, self.init_maxerr),
                        n_conf_states=self.nstates_full)
            self.biased_conf_energies = self.mbar_biased_conf_energies.copy()

        # run estimator
        if self.direct_space:
            tram = _tram_direct
            trammbar = _trammbar_direct
        else:
            tram = _tram
            trammbar = _trammbar
        #import warnings
        #with warnings.catch_warnings() as cm:
        # warnings.filterwarnings('ignore', RuntimeWarning)
        stage = 'TRAM'
        with pg.context(stage=stage):
            # plain TRAM without equilibrium flags, TRAMMBAR otherwise; both
            # calls return the same six values and start from the current
            # biased_conf_energies / log_lagrangian_mult
            if self.equilibrium is None:
                self.biased_conf_energies, conf_energies, self.therm_energies, self.log_lagrangian_mult, \
                    self.increments, self.loglikelihoods = tram.estimate(
                        self.count_matrices, self.state_counts, self.btrajs, self.dtrajs,
                        maxiter=self.maxiter, maxerr=self.maxerr,
                        biased_conf_energies=self.biased_conf_energies,
                        log_lagrangian_mult=self.log_lagrangian_mult,
                        save_convergence_info=self.save_convergence_info,
                        callback=_ConvergenceProgressIndicatorCallBack(
                            pg, stage, self.maxiter, self.maxerr, subcallback=self.callback),
                        N_dtram_accelerations=self.N_dtram_accelerations)
            else: # use trammbar
                self.biased_conf_energies, conf_energies, self.therm_energies, self.log_lagrangian_mult, \
                    self.increments, self.loglikelihoods = trammbar.estimate(
                        self.count_matrices, self.state_counts, self.btrajs, self.dtrajs,
                        equilibrium_therm_state_counts=self.equilibrium_state_counts.sum(axis=1).astype(_np.intc),
                        equilibrium_bias_energy_sequences=self.equilibrium_btrajs, equilibrium_state_sequences=self.equilibrium_dtrajs,
                        maxiter=self.maxiter, maxerr=self.maxerr,
                        save_convergence_info=self.save_convergence_info,
                        biased_conf_energies=self.biased_conf_energies,
                        log_lagrangian_mult=self.log_lagrangian_mult,
                        callback=_ConvergenceProgressIndicatorCallBack(
                            pg, stage, self.maxiter, self.maxerr, subcallback=self.callback),
                        N_dtram_accelerations=self.N_dtram_accelerations,
                        overcounting_factor=self.overcounting_factor)

        # compute models
        # one transition matrix per thermodynamic state K, restricted to the
        # rows and columns listed in self.active_set
        # NOTE(review): this always calls _tram.estimate_transition_matrix,
        # also when self.direct_space is set -- presumably intentional; confirm.
        fmsms = [_np.ascontiguousarray((
            _tram.estimate_transition_matrix(
                self.log_lagrangian_mult, self.biased_conf_energies, self.count_matrices, None,
                K)[self.active_set, :])[:, self.active_set]) for K in range(self.nthermo)]

        # further restrict every matrix to its own largest connected set
        # (computed with directed=False); indices in active_sets are relative
        # to self.active_set
        active_sets = [_largest_connected_set(msm, directed=False) for msm in fmsms]
        fmsms = [_np.ascontiguousarray(
            (msm[lcc, :])[:, lcc]) for msm, lcc in zip(fmsms, active_sets)]

        models = []
        for i, (msm, acs) in enumerate(zip(fmsms, active_sets)):
            # weights exp(therm_energies[i] - biased_conf_energies[i, :]) on the
            # model's states, normalized to sum to one and passed on as `pi`
            pi_acs = _np.exp(self.therm_energies[i] - self.biased_conf_energies[i, :])[self.active_set[acs]]
            pi_acs = pi_acs / pi_acs.sum()
            models.append(_ThermoMSM(
                msm, self.active_set[acs], self.nstates_full, pi=pi_acs,
                dt_model=self.timestep_traj.get_scaled(self.lag)))

        # set model parameters to self
        self.set_model_params(
            models=models, f_therm=self.therm_energies, f=conf_energies[self.active_set].copy())

        return self
コード例 #2
0
    def _estimate(self, X):
        """Run the TRAM estimation on ``X`` and store the resulting models.

        Parameters
        ----------
        X : tuple (ttrajs, dtrajs_full, btrajs)
            Three lists of equal length, one entry per trajectory.
            ``ttrajs[i]`` and ``dtrajs_full[i]`` must be 1-D integer arrays and
            ``btrajs[i]`` a 2-D float array. For every trajectory all three
            arrays must have the same number of rows, and ``btrajs[i]`` must
            have ``self.nthermo`` columns (``self.nthermo`` is inferred as the
            largest value in ``ttrajs`` plus one).

        Returns
        -------
        self
            The estimator itself, after ``set_model_params`` was called with
            the per-thermodynamic-state models, ``f_therm`` and ``f``.

        Raises
        ------
        AssertionError
            If one of the shape / consistency ``assert`` statements fails
            (note that asserts are skipped when Python runs with ``-O``).

        Warns
        -----
        EmptyState
            If a thermodynamic state has an empty connected set or has no
            transitions after restriction to the connected sets.

        Notes
        -----
        Besides the model parameters the method stores, among others,
        ``nstates_full``, ``nthermo``, ``csets``, ``active_set``,
        ``state_counts``, ``count_matrices``, ``dtrajs``, ``btrajs``,
        ``biased_conf_energies``, ``therm_energies``, ``log_lagrangian_mult``,
        ``increments``, ``loglikelihoods`` and ``model_active_set`` on ``self``.
        """
        ttrajs, dtrajs_full, btrajs = X
        # shape and type checks
        assert len(ttrajs) == len(dtrajs_full) == len(btrajs)
        for t in ttrajs:
            _types.assert_array(t, ndim=1, kind='i')
        for d in dtrajs_full:
            _types.assert_array(d, ndim=1, kind='i')
        for b in btrajs:
            _types.assert_array(b, ndim=2, kind='f')
        # find dimensions (largest index seen in the data + 1)
        self.nstates_full = max(_np.max(d) for d in dtrajs_full) + 1
        self.nthermo = max(_np.max(t) for t in ttrajs) + 1
        # dimensionality checks
        for t, d, b in zip(ttrajs, dtrajs_full, btrajs):
            assert t.shape[0] == d.shape[0] == b.shape[0]
            assert b.shape[1] == self.nthermo

        # cast types and change axis order if needed
        ttrajs = [
            _np.require(t, dtype=_np.intc, requirements='C') for t in ttrajs
        ]
        dtrajs_full = [
            _np.require(d, dtype=_np.intc, requirements='C')
            for d in dtrajs_full
        ]
        btrajs = [
            _np.require(b, dtype=_np.float64, requirements='C') for b in btrajs
        ]

        # find state visits and transition counts
        state_counts_full = _util.state_counts(ttrajs, dtrajs_full)
        count_matrices_full = _util.count_matrices(ttrajs,
                                                   dtrajs_full,
                                                   self.lag,
                                                   sliding=self.count_mode,
                                                   sparse_return=False,
                                                   nstates=self.nstates_full)
        self.therm_state_counts_full = state_counts_full.sum(axis=1)

        self.csets, pcset = _cset.compute_csets_TRAM(
            self.connectivity,
            state_counts_full,
            count_matrices_full,
            ttrajs=ttrajs,
            dtrajs=dtrajs_full,
            bias_trajs=btrajs,
            nn=self.nn,
            factor=self.connectivity_factor,
            callback=_IterationProgressIndicatorCallBack(
                self, 'finding connected set', 'cset'))
        self.active_set = pcset

        # check for empty states
        for k in range(self.nthermo):
            if len(self.csets[k]) == 0:
                _warnings.warn(
                    'Thermodynamic state %d contains no samples after '
                    'reducing to the connected set.' % k, EmptyState)

        # deactivate samples not in the csets, states are *not* relabeled
        self.state_counts, self.count_matrices, self.dtrajs, _ = _cset.restrict_to_csets(
            self.csets,
            state_counts=state_counts_full,
            count_matrices=count_matrices_full,
            ttrajs=ttrajs,
            dtrajs=dtrajs_full)

        # self-consistency tests
        # (the `d >= 0` masks select the samples kept by restrict_to_csets;
        # presumably deactivated samples carry a negative state index --
        # confirm against _cset.restrict_to_csets)
        assert _np.all(self.state_counts >= _np.maximum(
            self.count_matrices.sum(axis=1),
            self.count_matrices.sum(axis=2)))
        assert _np.all(
            _np.sum([
                _np.bincount(d[d >= 0], minlength=self.nstates_full)
                for d in self.dtrajs
            ],
                    axis=0) == self.state_counts.sum(axis=0))
        assert _np.all(
            _np.sum([
                _np.bincount(t[d >= 0], minlength=self.nthermo)
                for t, d in zip(ttrajs, self.dtrajs)
            ],
                    axis=0) == self.state_counts.sum(axis=1))

        # check for empty states
        for k in range(self.state_counts.shape[0]):
            if self.count_matrices[k, :, :].sum() == 0:
                # The message used to be concatenated without a separating
                # space ("state 3contains ..."); build it as one format string.
                _warnings.warn(
                    'Thermodynamic state %d contains no transitions after '
                    'reducing to the connected set.' % k, EmptyState)

        # MBAR is only run as an initializer, and only when the caller did not
        # already provide biased_conf_energies; it uses the full
        # (unrestricted) counts and trajectories
        if self.init == 'mbar' and self.biased_conf_energies is None:
            if self.direct_space:
                mbar = _mbar_direct
            else:
                mbar = _mbar
            self.mbar_therm_energies, self.mbar_unbiased_conf_energies, \
                self.mbar_biased_conf_energies, _ = mbar.estimate(
                    state_counts_full.sum(axis=1), btrajs, dtrajs_full,
                    maxiter=self.init_maxiter, maxerr=self.init_maxerr,
                    callback=_ConvergenceProgressIndicatorCallBack(
                        self, 'MBAR init.', self.init_maxiter, self.init_maxerr),
                    n_conf_states=self.nstates_full)
            self._progress_force_finish(stage='MBAR init.',
                                        description='MBAR init.')
            self.biased_conf_energies = self.mbar_biased_conf_energies.copy()

        # run estimator
        if self.direct_space:
            tram = _tram_direct
        else:
            tram = _tram
        self.biased_conf_energies, conf_energies, self.therm_energies, self.log_lagrangian_mult, \
            self.increments, self.loglikelihoods = tram.estimate(
                self.count_matrices, self.state_counts, btrajs, self.dtrajs,
                maxiter=self.maxiter, maxerr=self.maxerr,
                biased_conf_energies=self.biased_conf_energies,
                log_lagrangian_mult=self.log_lagrangian_mult,
                save_convergence_info=self.save_convergence_info,
                callback=_ConvergenceProgressIndicatorCallBack(
                    self, 'TRAM', self.maxiter, self.maxerr),
                N_dtram_accelerations=self.N_dtram_accelerations)
        self._progress_force_finish(stage='TRAM', description='TRAM')
        self.btrajs = btrajs

        # compute models
        # one transition matrix per thermodynamic state K, restricted to the
        # rows and columns listed in self.active_set
        # NOTE(review): this always calls _tram.estimate_transition_matrix,
        # also when self.direct_space is set -- presumably intentional; confirm.
        fmsms = [
            _np.ascontiguousarray((_tram.estimate_transition_matrix(
                self.log_lagrangian_mult, self.biased_conf_energies,
                self.count_matrices, None,
                K)[self.active_set, :])[:, self.active_set])
            for K in range(self.nthermo)
        ]

        # further restrict every matrix to its own largest connected set
        # (computed with directed=False); the indices are relative to
        # self.active_set
        self.model_active_set = [
            _largest_connected_set(msm, directed=False) for msm in fmsms
        ]
        fmsms = [
            _np.ascontiguousarray((msm[lcc, :])[:, lcc])
            for msm, lcc in zip(fmsms, self.model_active_set)
        ]
        models = [
            _MSM(msm, dt_model=self.timestep_traj.get_scaled(self.lag))
            for msm in fmsms
        ]

        # set model parameters to self
        self.set_model_params(models=models,
                              f_therm=self.therm_energies,
                              f=conf_energies[self.active_set].copy())

        return self