コード例 #1
0
    def __init__(self, dtrajs):
        """Collect basic count statistics of a set of discrete trajectories.

        Parameters
        ----------
        dtrajs : array-like or list of array-like
            Discrete trajectories. They are normalized with
            ``ensure_dtraj_list`` before any statistic is computed.
        """
        # TODO: extensive input checking!
        from pyemma.util.types import ensure_dtraj_list

        # discrete trajectories, normalized once and reused below
        self._dtrajs = ensure_dtraj_list(dtrajs)

        ## basic count statistics
        # histogram of state occurrences over all trajectories
        self._hist = msmest.count_states(self._dtrajs)
        # total counts
        self._total_count = np.sum(self._hist)
        # number of states; computed from the same normalized list as the
        # histogram so that all statistics are derived from identical input
        self._nstates = msmest.number_of_states(self._dtrajs)

        # not yet estimated
        self._counted_at_lag = False
コード例 #2
0
ファイル: _dtraj_stats.py プロジェクト: hackyhacker/PyERNA
    def __init__(self, dtrajs):
        """Collect basic count statistics of a set of discrete trajectories.

        Parameters
        ----------
        dtrajs : array-like or list of array-like
            Discrete trajectories, normalized with ``ensure_dtraj_list``.
            Entries equal to -1 are allowed (negative entries are counted with
            ``ignore_negative=True``); anything smaller is rejected.

        Raises
        ------
        ValueError
            If any trajectory contains an element smaller than -1.
        """
        from pyerna.util.types import ensure_dtraj_list

        # discrete trajectories
        self._dtrajs = ensure_dtraj_list(dtrajs)

        # TODO: extensive input checking!
        if any(np.any(d < -1) for d in self._dtrajs):
            raise ValueError('Discrete trajectory contains elements < -1.')

        ## basic count statistics
        # histogram; negative entries are skipped via ignore_negative=True
        from msmtools.dtraj import count_states
        self._hist = count_states(self._dtrajs, ignore_negative=True)
        # total counts
        self._total_count = np.sum(self._hist)
        # number of states; computed from the same normalized list as the
        # histogram so that all statistics are derived from identical input
        self._nstates = msmest.number_of_states(self._dtrajs)

        # not yet estimated
        self._counted_at_lag = False
コード例 #3
0
ファイル: bayesian_hmsm.py プロジェクト: greenTara/PyEMMA
    def _estimate(self, dtrajs):
        """Sample hidden Markov state models with a Bayesian HMM sampler.

        Without ``self.init_hmsm``, a maximum-likelihood HMSM is first estimated
        through the superclass and used as the starting point of the sampler;
        otherwise the estimation results of ``self.init_hmsm`` are copied.

        Parameters
        ----------
        dtrajs : array-like or list of array-like
            Discrete trajectories, normalized with ``ensure_dtraj_list``.

        Returns
        -------
        The result of ``self.submodel(...)``, restricted according to
        ``self.connectivity`` and ``self.observe_nonempty``.

        Raises
        ------
        NotImplementedError
            If the model is reversible and the hidden count matrix is
            disconnected.
        """
        # ensure right format
        dtrajs = ensure_dtraj_list(dtrajs)

        if self.init_hmsm is None:  # estimate using maximum-likelihood superclass
            # memorize the observation state for bhmm and reset
            # TODO: more elegant solution is to set Estimator params only temporarily in estimate(X, **kwargs)
            default_connectivity = self.connectivity
            default_mincount_connectivity = self.mincount_connectivity
            default_observe_nonempty = self.observe_nonempty
            self.connectivity = None
            self.observe_nonempty = False
            self.mincount_connectivity = 0
            # NOTE(review): unlike the settings above, accuracy is not restored
            # after the initial estimate -- confirm this is intended.
            self.accuracy = 1e-2  # this is sufficient for an initial guess
            try:
                super(BayesianHMSM, self)._estimate(dtrajs)
            finally:
                # restore the user's settings even if the ML estimation fails
                self.connectivity = default_connectivity
                self.mincount_connectivity = default_mincount_connectivity
                self.observe_nonempty = default_observe_nonempty
        else:  # if given another initialization, must copy its attributes
            # TODO: this is too tedious - need to automatize parameter+result copying between estimators.
            self.nstates = self.init_hmsm.nstates
            self.reversible = self.init_hmsm.is_reversible
            self.stationary = self.init_hmsm.stationary
            # trajectories
            self._dtrajs_full = self.init_hmsm._dtrajs_full
            self._dtrajs_lagged = self.init_hmsm._dtrajs_lagged
            self._observable_set = self.init_hmsm._observable_set
            self._dtrajs_obs = self.init_hmsm._dtrajs_obs
            # MLE estimation results
            self.likelihoods = self.init_hmsm.likelihoods  # Likelihood history
            self.likelihood = self.init_hmsm.likelihood
            self.hidden_state_probabilities = self.init_hmsm.hidden_state_probabilities  # gamma variables
            self.hidden_state_trajectories = self.init_hmsm.hidden_state_trajectories  # Viterbi path
            self.count_matrix = self.init_hmsm.count_matrix  # hidden count matrix
            self.initial_count = self.init_hmsm.initial_count  # hidden init count
            self.initial_distribution = self.init_hmsm.initial_distribution
            self._active_set = self.init_hmsm._active_set
            # update HMM Model
            self.update_model_params(
                P=self.init_hmsm.transition_matrix,
                pobs=self.init_hmsm.observation_probabilities,
                dt_model=TimeUnit(self.dt_traj).get_scaled(self.lag))

        # check if we have a valid initial model
        import msmtools.estimation as msmest
        if self.reversible and not msmest.is_connected(self.count_matrix):
            raise NotImplementedError(
                'Encountered disconnected count matrix:\n ' +
                str(self.count_matrix) +
                ' with reversible Bayesian HMM sampler using lag=' +
                str(self.lag) + ' and stride=' + str(self.stride) +
                '. Consider using shorter lag, ' +
                'or shorter stride (to use more of the data), ' +
                'or using a lower value for mincount_connectivity.')

        # here we blow up the output matrix (if needed) to the FULL state space because we want to use dtrajs in the
        # Bayesian HMM sampler. This is just an initialization.
        nstates_full = msmest.number_of_states(dtrajs)
        if self.nstates_obs < nstates_full:
            eps = 0.01 / nstates_full  # default output probability, in order to avoid zero columns
            # full state space output matrix. make sure there are no zero columns
            B_init = eps * _np.ones(
                (self.nstates, nstates_full), dtype=_np.float64)
            # fill active states
            B_init[:, self.observable_set] = _np.maximum(
                eps, self.observation_probabilities)
            # renormalize B to make it row-stochastic
            B_init /= B_init.sum(axis=1)[:, None]
        else:
            B_init = self.observation_probabilities

        # HMM sampler
        if self.show_progress:
            self._progress_register(self.nsamples,
                                    description='Sampling HMSMs',
                                    stage=0)

            def call_back():
                self._progress_update(1, stage=0)
        else:
            call_back = None

        try:
            from bhmm import discrete_hmm, bayesian_hmm
            hmm_mle = discrete_hmm(self.initial_distribution,
                                   self.transition_matrix, B_init)

            sampled_hmm = bayesian_hmm(
                self.discrete_trajectories_lagged,
                hmm_mle,
                nsample=self.nsamples,
                reversible=self.reversible,
                stationary=self.stationary,
                p0_prior=self.p0_prior,
                transition_matrix_prior=self.transition_matrix_prior,
                store_hidden=self.store_hidden,
                call_back=call_back)
        finally:
            # close the progress bar even if sampling fails
            if self.show_progress:
                self._progress_force_finish(stage=0)

        # Samples: restrict output probabilities to the observable set (if necessary) and renormalize
        sampled_models = [sampled_hmm.sampled_hmms[i] for i in range(self.nsamples)]
        samples = []
        for model in sampled_models:
            Bobs = model.output_model.output_probabilities[:, self.observable_set]
            samples.append(
                _HMSM(model.transition_matrix,
                      Bobs / Bobs.sum(axis=1)[:, None],  # renormalize
                      pi=model.stationary_distribution,
                      dt_model=self.dt_model))

        # store results
        self.sampled_trajs = [model.hidden_state_trajectories for model in sampled_models]
        self.update_model_params(samples=samples)

        # deal with connectivity
        states_subset = None
        if self.connectivity == 'largest':
            states_subset = 'largest-strong'
        elif self.connectivity == 'populous':
            states_subset = 'populous-strong'
        # OBSERVATION SET
        if self.observe_nonempty:
            observe_subset = 'nonempty'
        else:
            observe_subset = None

        # return submodel (will return self if all None)
        return self.submodel(states=states_subset,
                             obs=observe_subset,
                             mincount_connectivity=self.mincount_connectivity)
コード例 #4
0
    def _estimate(self, dtrajs):
        """Estimate a maximum-likelihood discrete hidden Markov state model.

        The trajectories are lagged and strided, an initial HMM is built
        according to ``self.msm_init`` and refined with bhmm's
        maximum-likelihood estimator; the results parametrize ``self``.

        Parameters
        ----------
        dtrajs : array-like or list of array-like
            Discrete trajectories, normalized with ``_types.ensure_dtraj_list``.

        Returns
        -------
        The result of ``self.submodel(...)``, restricted according to
        ``self.connectivity`` and the observation subset.

        Raises
        ------
        ValueError
            If ``self.lag`` is not shorter than the longest trajectory, or if
            ``self.msm_init`` is not a known initialization option.
        """
        import bhmm
        # ensure right format
        dtrajs = _types.ensure_dtraj_list(dtrajs)

        # CHECK LAG
        trajlengths = [_np.size(dtraj) for dtraj in dtrajs]
        if self.lag >= _np.max(trajlengths):
            raise ValueError('Illegal lag time ' + str(self.lag) + ' exceeds longest trajectory length')
        mean_length = _np.mean(trajlengths)
        if self.lag > mean_length:
            self.logger.warning('Lag time ' + str(self.lag) + ' is on the order of mean trajectory length '
                                + str(mean_length) + '. It is recommended to fit four lag times in each '
                                + 'trajectory. HMM might be inaccurate.')

        # EVALUATE STRIDE
        if self.stride == 'effective':
            # by default use lag as stride (=lag sampling), because we currently have no better theory for deciding
            # how many uncorrelated counts we can make
            self.stride = self.lag
            # get a quick estimate from the spectral radius of the nonreversible
            from pyemma.msm import estimate_markov_model
            msm_nr = estimate_markov_model(dtrajs, lag=self.lag, reversible=False, sparse=False,
                                           connectivity='largest', dt_traj=self.timestep_traj)
            # if we have more than nstates timescales in our MSM, we use the next (neglected) timescale as an
            # estimate of the decorrelation time
            if msm_nr.nstates > self.nstates:
                corrtime = max(1, msm_nr.timescales()[self.nstates-1])
                # use the smaller of these two pessimistic estimates
                self.stride = int(min(self.lag, 2*corrtime))

        # LAG AND STRIDE DATA
        dtrajs_lagged_strided = bhmm.lag_observations(dtrajs, self.lag, stride=self.stride)

        # OBSERVATION SET
        if self.observe_nonempty:
            observe_subset = 'nonempty'
        else:
            observe_subset = None

        # INIT HMM
        from bhmm import init_discrete_hmm
        from pyemma.msm.estimators import MaximumLikelihoodMSM
        if self.msm_init == 'largest-strong' or self.msm_init == 'all':
            # spectral initialization on the lagged/strided data; the two string
            # options differ only in the bhmm initialization method
            method = 'lcs-spectral' if self.msm_init == 'largest-strong' else 'spectral'
            hmm_init = init_discrete_hmm(dtrajs_lagged_strided, self.nstates, lag=1,
                                         reversible=self.reversible, stationary=True, regularize=True,
                                         method=method, separate=self.separate)
        elif isinstance(self.msm_init, MaximumLikelihoodMSM):  # initial MSM given.
            from bhmm.init.discrete import init_discrete_hmm_spectral
            p0, P0, pobs0 = init_discrete_hmm_spectral(self.msm_init.count_matrix_full, self.nstates,
                                                       reversible=self.reversible, stationary=True,
                                                       active_set=self.msm_init.active_set,
                                                       P=self.msm_init.transition_matrix, separate=self.separate)
            hmm_init = bhmm.discrete_hmm(p0, P0, pobs0)
            observe_subset = self.msm_init.active_set  # override observe_subset.
        else:
            raise ValueError('Unknown MSM initialization option: ' + str(self.msm_init))

        # ---------------------------------------------------------------------------------------
        # Estimate discrete HMM
        # ---------------------------------------------------------------------------------------

        # run EM
        from bhmm.estimators.maximum_likelihood import MaximumLikelihoodEstimator as _MaximumLikelihoodEstimator
        hmm_est = _MaximumLikelihoodEstimator(dtrajs_lagged_strided, self.nstates, initial_model=hmm_init,
                                              output='discrete', reversible=self.reversible, stationary=self.stationary,
                                              accuracy=self.accuracy, maxit=self.maxit)
        # run
        hmm_est.fit()
        # package in discrete HMM
        self.hmm = bhmm.DiscreteHMM(hmm_est.hmm)

        # get model parameters
        self.initial_distribution = self.hmm.initial_distribution
        transition_matrix = self.hmm.transition_matrix
        observation_probabilities = self.hmm.output_probabilities

        # get estimation parameters
        self.likelihoods = hmm_est.likelihoods  # Likelihood history
        self.likelihood = self.likelihoods[-1]
        self.hidden_state_probabilities = hmm_est.hidden_state_probabilities  # gamma variables
        self.hidden_state_trajectories = hmm_est.hmm.hidden_state_trajectories  # Viterbi path
        self.count_matrix = hmm_est.count_matrix  # hidden count matrix
        self.initial_count = hmm_est.initial_count  # hidden init count
        self._active_set = _np.arange(self.nstates)

        # TODO: it can happen that we loose states due to striding. Should we lift the output probabilities afterwards?
        # parametrize self
        self._dtrajs_full = dtrajs
        self._dtrajs_lagged = dtrajs_lagged_strided
        self._nstates_obs_full = msmest.number_of_states(dtrajs)
        self._nstates_obs = msmest.number_of_states(dtrajs_lagged_strided)
        self._observable_set = _np.arange(self._nstates_obs)
        self._dtrajs_obs = dtrajs
        self.set_model_params(P=transition_matrix, pobs=observation_probabilities,
                              reversible=self.reversible, dt_model=self.timestep_traj.get_scaled(self.lag))

        # TODO: perhaps remove connectivity and just rely on .submodel()?
        # deal with connectivity
        states_subset = None
        if self.connectivity == 'largest':
            states_subset = 'largest-strong'
        elif self.connectivity == 'populous':
            states_subset = 'populous-strong'

        # return submodel (will return self if all None)
        return self.submodel(states=states_subset, obs=observe_subset, mincount_connectivity=self.mincount_connectivity)
コード例 #5
0
ファイル: bayesian_hmsm.py プロジェクト: dseeliger/PyEMMA
    def _estimate(self, dtrajs):
        """ Sample hidden Markov state models with a Bayesian HMM sampler.

        Parameters
        ----------
        dtrajs : array-like or list of array-like
            Discrete trajectories, normalized with ``ensure_dtraj_list``.

        Return
        ------
        hmsm : :class:`EstimatedHMSM <pyemma.msm.estimators.hmsm_estimated.EstimatedHMSM>`
            Estimated Hidden Markov state model

        Raises
        ------
        TypeError
            If ``self.init_hmsm`` is given but is not an ``EstimatedHMSM``.
        ValueError
            If ``self.prior`` is not one of 'sparse', 'uniform' or 'mixed'.

        """
        # ensure right format
        dtrajs = ensure_dtraj_list(dtrajs)

        # if no initial MSM is given, estimate it now
        if self.init_hmsm is None:
            # estimate with store_data=True, because we need an EstimatedHMSM
            hmsm_estimator = _MaximumLikelihoodHMSM(
                lag=self.lag,
                stride=self.stride,
                nstates=self.nstates,
                reversible=self.reversible,
                connectivity=self.connectivity,
                observe_active=self.observe_active,
                dt_traj=self.dt_traj)
            init_hmsm = hmsm_estimator.estimate(
                dtrajs)  # estimate with lagged trajectories
        else:
            # check input explicitly: an assert statement is removed under python -O
            if not isinstance(self.init_hmsm, _EstimatedHMSM):
                raise TypeError('hmsm must be of type EstimatedHMSM')
            init_hmsm = self.init_hmsm
            self.nstates = init_hmsm.nstates
            self.reversible = init_hmsm.is_reversible

        # here we blow up the output matrix (if needed) to the FULL state space because we want to use dtrajs in the
        # Bayesian HMM sampler
        if self.observe_active:
            import msmtools.estimation as msmest
            nstates_full = msmest.number_of_states(dtrajs)
            eps = 0.01 / nstates_full  # default output probability, in order to avoid zero columns
            # full state space output matrix. make sure there are no zero columns
            pobs = eps * _np.ones(
                (self.nstates, nstates_full), dtype=_np.float64)
            # fill active states
            pobs[:, init_hmsm.observable_set] = _np.maximum(
                eps, init_hmsm.observation_probabilities)
            # renormalize B to make it row-stochastic
            pobs /= pobs.sum(axis=1)[:, None]
        else:
            pobs = init_hmsm.observation_probabilities

        # define prior. This is done before the progress bar is registered, so
        # that an invalid prior mode does not leave a dangling progress bar.
        if self.prior == 'sparse':
            self.prior_count_matrix = _np.zeros((self.nstates, self.nstates),
                                                dtype=_np.float64)
        elif self.prior == 'uniform':
            self.prior_count_matrix = _np.ones((self.nstates, self.nstates),
                                               dtype=_np.float64)
        elif self.prior == 'mixed':
            P0 = init_hmsm.transition_matrix
            P0_offdiag = P0 - _np.diag(_np.diag(P0))
            # NOTE(review): a row without off-diagonal mass divides by zero here
            # (inf/nan prior) -- confirm that cannot happen for init_hmsm.
            scaling_factor = 1.0 / _np.sum(P0_offdiag, axis=1)
            self.prior_count_matrix = P0 * scaling_factor[:, None]
        else:
            # str(): self.prior need not be a string, and '+' would then raise
            # TypeError instead of the intended ValueError
            raise ValueError('Unknown prior mode: ' + str(self.prior))

        # HMM sampler
        if self.show_progress:
            self._progress_register(self.nsamples,
                                    description='Sampling HMSMs',
                                    stage=0)

            def call_back():
                self._progress_update(1, stage=0)
        else:
            call_back = None

        try:
            from bhmm import discrete_hmm, bayesian_hmm
            hmm_mle = discrete_hmm(init_hmsm.transition_matrix,
                                   pobs,
                                   stationary=True,
                                   reversible=self.reversible)

            sampled_hmm = bayesian_hmm(
                init_hmsm.discrete_trajectories_lagged,
                hmm_mle,
                nsample=self.nsamples,
                transition_matrix_prior=self.prior_count_matrix,
                call_back=call_back)
        finally:
            # close the progress bar even if sampling fails
            if self.show_progress:
                self._progress_force_finish(stage=0)

        # Samples: restrict output probabilities to the observable set (if necessary) and renormalize
        samples = []
        for i in range(self.nsamples):
            model = sampled_hmm.sampled_hmms[i]
            Bobs = model.output_model.output_probabilities[:, init_hmsm.observable_set]
            samples.append(
                _HMSM(model.transition_matrix,
                      Bobs / Bobs.sum(axis=1)[:, None],  # renormalize
                      pi=model.stationary_distribution,
                      dt_model=init_hmsm.dt_model))

        # parametrize self
        self._dtrajs_full = dtrajs
        self._observable_set = init_hmsm._observable_set
        self._dtrajs_obs = init_hmsm._dtrajs_obs
        self.set_model_params(samples=samples,
                              P=init_hmsm.transition_matrix,
                              pobs=init_hmsm.observation_probabilities,
                              dt_model=init_hmsm.dt_model)

        return self
コード例 #6
0
ファイル: bayesian_hmsm.py プロジェクト: analisahill/PyEMMA
    def _estimate(self, dtrajs):
        """Sample hidden Markov state models with a Bayesian HMM sampler.

        Without ``self.init_hmsm``, a maximum-likelihood HMSM is first estimated
        through the superclass and used as the sampler's starting point. With
        ``self.init_hmsm``, its estimation results are copied after checking
        that lag, number of hidden states and data are compatible.

        Parameters
        ----------
        dtrajs : array-like or list of array-like
            Discrete trajectories, normalized with ``ensure_dtraj_list``.

        Returns
        -------
        The result of ``self.submodel(...)``, restricted according to
        ``self.connectivity`` and ``self.observe_nonempty``.

        Raises
        ------
        UserWarning
            If ``init_hmsm`` has an incompatible lag or number of hidden states,
            or if the chosen stride excludes a different set of microstates
            than was used for ``init_hmsm``.
        NotImplementedError
            If ``init_hmsm`` was estimated on different data, or if the model is
            reversible and the hidden count matrix is disconnected.
        """
        # ensure right format
        dtrajs = ensure_dtraj_list(dtrajs)

        if self.init_hmsm is None:  # estimate using maximum-likelihood superclass
            # memorize the observation state for bhmm and reset
            # TODO: more elegant solution is to set Estimator params only temporarily in estimate(X, **kwargs)
            default_connectivity = self.connectivity
            default_mincount_connectivity = self.mincount_connectivity
            default_observe_nonempty = self.observe_nonempty
            self.connectivity = None
            self.observe_nonempty = False
            self.mincount_connectivity = 0
            # NOTE(review): unlike the settings above, accuracy is not restored
            # after the initial estimate -- confirm this is intended.
            self.accuracy = 1e-2  # this is sufficient for an initial guess
            try:
                super(BayesianHMSM, self)._estimate(dtrajs)
            finally:
                # restore the user's settings even if the ML estimation fails
                self.connectivity = default_connectivity
                self.mincount_connectivity = default_mincount_connectivity
                self.observe_nonempty = default_observe_nonempty
        else:  # if given another initialization, must copy its attributes
            copy_attributes = ['_nstates', '_reversible', '_pi', '_observable_set', 'likelihoods', 'likelihood',
                               'hidden_state_probabilities', 'hidden_state_trajectories', 'count_matrix',
                               'initial_count', 'initial_distribution', '_active_set']
            check_user_choices = ['lag', '_nstates']

            # check if nstates and lag are compatible
            for attr in check_user_choices:
                if not getattr(self, attr) == getattr(self.init_hmsm, attr):
                    raise UserWarning('BayesianHMSM cannot be initialized with init_hmsm with '
                                      'incompatible lag or nstates.')

            if (len(dtrajs) != len(self.init_hmsm.dtrajs_full) or
                    not all((_np.array_equal(d1, d2) for d1, d2 in zip(dtrajs, self.init_hmsm.dtrajs_full)))):
                raise NotImplementedError('Bayesian HMM estimation with init_hmsm is currently only implemented ' +
                                          'if applied to the same data.')

            # TODO: implement more elegant solution to copy-pasting effective stride evaluation from ML HMM.
            # EVALUATE STRIDE
            if self.stride == 'effective':
                # by default use lag as stride (=lag sampling), because we currently have no better theory for deciding
                # how many uncorrelated counts we can make
                self.stride = self.lag
                # get a quick estimate from the spectral radius of the nonreversible
                from pyemma.msm import estimate_markov_model
                msm_nr = estimate_markov_model(dtrajs, lag=self.lag, reversible=False, sparse=False,
                                               connectivity='largest', dt_traj=self.timestep_traj)
                # if we have more than nstates timescales in our MSM, we use the next (neglected) timescale as an
                # estimate of the decorrelation time
                if msm_nr.nstates > self.nstates:
                    corrtime = max(1, msm_nr.timescales()[self.nstates - 1])
                    # use the smaller of these two pessimistic estimates
                    self.stride = int(min(self.lag, 2 * corrtime))

            # if stride is different to init_hmsm, check if microstates in lagged-strided trajs are compatible
            if self.stride != self.init_hmsm.stride:
                dtrajs_lagged_strided = _lag_observations(dtrajs, self.lag, stride=self.stride)
                _nstates_obs = _number_of_states(dtrajs_lagged_strided, only_used=True)
                _nstates_obs_full = _number_of_states(dtrajs)

                if _np.setxor1d(_np.concatenate(dtrajs_lagged_strided),
                                 _np.concatenate(self.init_hmsm._dtrajs_lagged)).size != 0:
                    raise UserWarning('Choice of stride has excluded a different set of microstates than in ' +
                                      'init_hmsm. Set of observed microstates in time-lagged strided trajectories ' +
                                      'must match to the one used for init_hmsm estimation.')

                self._dtrajs_full = dtrajs
                self._dtrajs_lagged = dtrajs_lagged_strided
                self._nstates_obs_full = _nstates_obs_full
                self._nstates_obs = _nstates_obs
                self._observable_set = _np.arange(self._nstates_obs)
                self._dtrajs_obs = dtrajs
            else:
                copy_attributes += ['_dtrajs_full', '_dtrajs_lagged', '_nstates_obs_full',
                                    '_nstates_obs', '_observable_set', '_dtrajs_obs']

            # update self with estimates from init_hmsm
            self.__dict__.update(
                {k: i for k, i in self.init_hmsm.__dict__.items() if k in copy_attributes})

            # as mentioned in the docstring, take init_hmsm observed set observation probabilities
            self.observe_nonempty = False

            # update HMM Model
            self.update_model_params(P=self.init_hmsm.transition_matrix, pobs=self.init_hmsm.observation_probabilities,
                                     dt_model=TimeUnit(self.dt_traj).get_scaled(self.lag))

        # check if we have a valid initial model
        import msmtools.estimation as msmest
        if self.reversible and not msmest.is_connected(self.count_matrix):
            raise NotImplementedError('Encountered disconnected count matrix:\n ' + str(self.count_matrix)
                                      + ' with reversible Bayesian HMM sampler using lag=' + str(self.lag)
                                      + ' and stride=' + str(self.stride) + '. Consider using shorter lag, '
                                      + 'or shorter stride (to use more of the data), '
                                      + 'or using a lower value for mincount_connectivity.')

        from bhmm import discrete_hmm, bayesian_hmm

        if self.init_hmsm is not None:
            hmm_mle = self.init_hmsm.hmm
        else:
            # here we blow up the output matrix (if needed) to the FULL state space because we want to use dtrajs
            # in the Bayesian HMM sampler. This is just an initialization, and it is only needed when the
            # initial HMM is built here rather than taken from init_hmsm.
            nstates_full = msmest.number_of_states(dtrajs)
            if self.nstates_obs < nstates_full:
                eps = 0.01 / nstates_full  # default output probability, in order to avoid zero columns
                # full state space output matrix. make sure there are no zero columns
                B_init = eps * _np.ones((self.nstates, nstates_full), dtype=_np.float64)
                # fill active states
                B_init[:, self.observable_set] = _np.maximum(eps, self.observation_probabilities)
                # renormalize B to make it row-stochastic
                B_init /= B_init.sum(axis=1)[:, None]
            else:
                B_init = self.observation_probabilities
            hmm_mle = discrete_hmm(self.initial_distribution, self.transition_matrix, B_init)

        # HMM sampler
        if self.show_progress:
            self._progress_register(self.nsamples, description='Sampling HMSMs', stage=0)

            def call_back():
                self._progress_update(1, stage=0)
        else:
            call_back = None

        try:
            sampled_hmm = bayesian_hmm(self.discrete_trajectories_lagged, hmm_mle, nsample=self.nsamples,
                                       reversible=self.reversible, stationary=self.stationary,
                                       p0_prior=self.p0_prior, transition_matrix_prior=self.transition_matrix_prior,
                                       store_hidden=self.store_hidden, call_back=call_back)
        finally:
            # close the progress bar even if sampling fails
            if self.show_progress:
                self._progress_force_finish(stage=0)

        # Samples: restrict output probabilities to the observable set (if necessary) and renormalize
        samples = []
        for m in sampled_hmm.sampled_hmms:
            Bobs = m.output_probabilities[:, self.observable_set]
            samples.append(_HMSM(m.transition_matrix, Bobs / Bobs.sum(axis=1)[:, None],
                                 pi=m.stationary_distribution, dt_model=self.dt_model))

        # store results
        self.sampled_trajs = [m.hidden_state_trajectories for m in sampled_hmm.sampled_hmms]
        self.update_model_params(samples=samples)

        # deal with connectivity
        states_subset = None
        if self.connectivity == 'largest':
            states_subset = 'largest-strong'
        elif self.connectivity == 'populous':
            states_subset = 'populous-strong'
        # OBSERVATION SET
        if self.observe_nonempty:
            observe_subset = 'nonempty'
        else:
            observe_subset = None

        # return submodel (will return self if all None)
        return self.submodel(states=states_subset, obs=observe_subset,
                             mincount_connectivity=self.mincount_connectivity)