Ejemplo n.º 1
0
def create_toggle_partial_tree_scorer(model_list,
                                      data,
                                      tree,
                                      num_extra_nodes=0
                                      ):
    '''Simple constructor for a TogglePartialTreeScorer and a
    LikeCalcEnvironment to support it.

    Arguments:
        `model_list` -- models to register with the environment. Each model's
            `asrv` attribute is read: None counts as one rate category,
            otherwise `asrv.num_categories` categories are counted.
        `data` -- one row of state codes per leaf; row `n` is loaded as state
            code array `n`. The length of the first row is used as the number
            of patterns (all rows are assumed to have that length).
        `tree` -- passed unchanged to TogglePartialTreeScorer.
        `num_extra_nodes` -- number of additional nodes to allocate storage
            for. If you want one "extra" node that can be swapped in and out
            of the tree, use num_extra_nodes=1 and then manually load up the
            _LCE_xxx attributes for that node.

    Infers:
        num_leaves and num_state_code_arrays from len(data)
        num_patterns from len(data[0])
        the total number of model/rate-category combinations from model_list

    Storage sizing (R = total model/rate-category combinations,
    L = number of leaves, so a rooted binary tree has 2L - 1 nodes):
        probability matrices: 2 * R * (2L - 1 + num_extra_nodes), i.e. two
            sets of matrices for every node of the tree plus the extra nodes.
        partials: 2 * R * (L - 1 + num_extra_nodes), i.e. two partials for
            every internal node plus the extra nodes.
        eigen solution storage: 2 per model.
        rescaling multipliers: 2 * (1 + L // 4).

    Raises:
        ValueError -- if `data` is empty (there are no leaves to score).
    '''
    num_leaves = len(data)
    if num_leaves == 0:
        # Fail early with a clear message instead of an IndexError on data[0].
        raise ValueError("create_toggle_partial_tree_scorer requires at least one row of data")
    from pytbeaglehon.like_calc_environ import LikeCalcEnvironment
    num_patterns = len(data[0])
    num_models = len(model_list) #TODO more generic for mixtures!
    # A model without asrv contributes a single rate category.
    num_model_rate_cats = sum(1 if model.asrv is None else model.asrv.num_categories
                              for model in model_list)
    num_internals = num_leaves - 1
    num_nodes = num_internals + num_leaves
    LCE = LikeCalcEnvironment(model_list=model_list,
                              num_patterns=num_patterns,
                              num_leaves=num_leaves,
                              num_state_code_arrays=num_leaves,
                              num_partials=(num_internals + num_extra_nodes)*2*num_model_rate_cats,
                              num_prob_matrices=(num_nodes + num_extra_nodes)*2*num_model_rate_cats,
                              num_eigen_storage_structs=2*num_models,
                              num_rescalings_multipliers=2*(1 + num_leaves//4))
    for n, row in enumerate(data):
        LCE.set_state_code_array(n, row)
    return TogglePartialTreeScorer(LCE, tree)
 def _incarnate(self):
     '''Ensures that there is a LikeCalcEnvironment associated with this object.

     If any of self._cmodel, self._model_index or self._calc_env is None, a
     small private LikeCalcEnvironment (2 leaves, 1 pattern, only this model)
     is created, stored in self._calc_env and marked as owned by setting
     self._owns_calc_env to True.

     Raises ValueError if self.num_states is None when an environment has to
     be created.

     NOTE(review): this is a verbatim copy of DiscStateContTimeModel._incarnate
     sitting at module level with a one-space indent, which does not parse as
     Python. It looks like a scraping artifact -- confirm and remove.
     '''
     if (self._cmodel is None) or (self._model_index is None) or (self._calc_env is None):
         if self.num_states is None:
             raise ValueError("DiscStateContTimeModel.num_states must be set before calculations can be performed")
         from pytbeaglehon.like_calc_environ import LikeCalcEnvironment
         self._calc_env = LikeCalcEnvironment(num_leaves=2,
                                              num_patterns=1,
                                              num_states=self.num_states,
                                              num_state_code_arrays=2,
                                              num_partials=1,
                                              model_list=[self],
                                              num_prob_matrices=1,
                                              num_eigen_storage_structs=1,
                                              num_rescalings_multipliers=0,
                                              resource_index=-1)
         self._owns_calc_env = True
class DiscStateContTimeModel(object):
    '''Base class for a discrete-state, continuous-time substitution model.

    Subclasses must implement `calc_q_mat`, which is expected to store the
    instantaneous rate matrix (for example by assigning to `self.q_mat`).

    The model caches a hash of its rate matrix (`q_mat_hash`) and of its
    whole state (`state_hash`, which combines the model identity, the
    q-matrix hash and the asrv hash). `state_hash` is handed to the
    LikeCalcEnvironment as `model_state_hash` when an eigen solution is
    requested. Parameter objects notify the model of changes through
    `param_changed`, which marks the q-matrix as dirty.

    NOTE(review): `state_freq` / `_state_freq` are read by this class but
    never set here; presumably subclasses provide them.
    '''
    def __init__(self, **kwargs):
        '''Keyword arguments (all optional):
            `cmodel`, `char_type`, `model_index`, `calc_env`, `asrv`,
            `num_states` -- only consulted when `char_type` is None,
            `param_list` -- parameters this model registers itself with
                (via `add_listener`) so that it learns about changes.
        '''
        self._cmodel = kwargs.get('cmodel')
        self._char_type = kwargs.get('char_type')
        self._model_index = kwargs.get('model_index')
        self._calc_env = kwargs.get('calc_env')
        self.asrv = kwargs.get('asrv')
        self._owns_calc_env = False
        self._changed_params = set()
        self._q_mat = None
        self._q_mat_hash = None
        self._prev_asrv_hash = None
        self._asrv_hash = None
        self._total_state_hash = None
        self._last_asrv_rates_hash = None
        self._prev_state_hash = None
        # Only used if _char_type is None (see get_num_states); previously there
        # was no way to supply it, so such models could never be incarnated.
        self._num_states = kwargs.get('num_states')
        self._eigen_soln_wrapper = None
        self._state_freq_hash = None
        param_list = kwargs.get('param_list')
        if param_list is not None:
            for p in param_list:
                _LOG.debug("Model %s registering self as listener of parameter %s" % (str(self), str(p)))
                p.add_listener(self.param_changed)
        # Having a non-empty set assures that the q_mat will be recognized as dirty.
        self._changed_params.add(None)

    def prob_matrices(self, edge_length, eigen_soln_caching=None, prob_mat_caching=None, as_wrappers=False):
        """Returns the probability matrices for all rate categories of the
        model for the branch length `edge_length`, as fetched from the
        LikeCalcEnvironment.

        NOTE(review): `as_wrappers` is accepted but currently ignored.
        """
        prob_wrapper_list = self.calc_prob_matrices(edge_length,
                                                    eigen_soln_caching=eigen_soln_caching,
                                                    prob_mat_caching=prob_mat_caching)
        return self._fetch_prob_matrices(prob_wrapper_list)

    def __str__(self):
        return 'DiscStateContTimeModel for %s with asrv=%s at %d' % (str(self.char_type), str(self.asrv), id(self))

    def calc_prob_matrices(self, edge_length, eigen_soln_caching=None, prob_mat_caching=None):
        '''Makes sure a LikeCalcEnvironment exists, then computes the
        probability matrices for every rate category.

        Returns whatever `_calc_prob_matrices` returns: the object(s) that
        record where the matrices are stored in the environment (pass them to
        `_fetch_prob_matrices` to obtain the matrices themselves).
        '''
        self._incarnate()
        return self._calc_prob_matrices(edge_length, eigen_soln_caching=eigen_soln_caching, prob_mat_caching=prob_mat_caching)

    def param_changed(self, p):
        '''Listener callback: records that parameter `p` changed and
        invalidates the cached total-state and state-frequency hashes.'''
        _LOG.debug("Model %s learned that %s changed." % (str(self), str(p)))
        if self._total_state_hash is not None:
            self._prev_state_hash = self._total_state_hash
        self._total_state_hash = None
        self._state_freq_hash = None
        self._changed_params.add(p)

    def get_char_type(self):
        return self._char_type
    char_type = property(get_char_type)

    def get_num_states(self):
        '''Number of states: from char_type if set, else the stored value.'''
        if self._char_type is None:
            return self._num_states
        return self._char_type.num_states
    num_states = property(get_num_states)

    def q_mat_is_dirty(self):
        '''True if any change has been recorded since the q-matrix was last computed.'''
        return bool(self._changed_params)

    def asrv_is_dirty(self):
        '''Refreshes `_asrv_hash` and reports whether it differs from the hash
        that was used for the last total state hash.

        A missing asrv is hashed as None, so removing the asrv is detected as
        a change (previously it was silently ignored).
        '''
        a = self.asrv
        self._asrv_hash = None if a is None else a.state_hash
        return self._asrv_hash != self._prev_asrv_hash

    def calc_q_mat(self):
        raise NotImplementedError()

    def get_q_mat(self):
        '''Returns the rate matrix, recomputing it first if it is dirty.'''
        if self.q_mat_is_dirty():
            self.calc_q_mat()
            self._changed_params.clear()
            # The cached hash no longer describes the freshly computed matrix.
            self._q_mat_hash = None
        return self._q_mat

    def set_q_mat(self, v):
        '''Stores `v` as a tuple of tuples of floats; None marks the matrix dirty.'''
        self._changed_params.clear()
        self._total_state_hash = None
        # Invalidate the q-matrix hash so q_mat_hash reflects the new matrix.
        self._q_mat_hash = None
        if v is None:
            # Having a non-empty set assures that the q_mat will be recognized as dirty.
            self._changed_params.add(None)
            self._q_mat = None
        else:
            self._q_mat = tuple(tuple(float(i) for i in row) for row in v)
    q_mat = property(get_q_mat, set_q_mat)

    def get_q_mat_hash(self):
        '''Returns hash(q_mat), recomputing it when the matrix is dirty or
        the cached hash was invalidated.'''
        if (self._q_mat_hash is None) or self.q_mat_is_dirty():
            qm = self.q_mat
            _LOG.debug('generating hash for %s' % repr(qm))
            self._q_mat_hash = hash(qm)
        return self._q_mat_hash
    q_mat_hash = property(get_q_mat_hash)

    def get_state_hash(self):
        '''Returns hash((id(self), q_mat_hash, asrv hash)), recomputing it
        whenever the q-matrix, the asrv or the cached value is out of date.'''
        # Always refresh the asrv hash first. Evaluating it only after the
        # other checks let a stale asrv hash leak into a freshly computed
        # state hash, making the value flip on the next call.
        asrv_changed = self.asrv_is_dirty()
        if asrv_changed or (self._total_state_hash is None) or self.q_mat_is_dirty():
            self._total_state_hash = hash((id(self), self.q_mat_hash, self._asrv_hash))
            self._prev_asrv_hash = self._asrv_hash
        return self._total_state_hash
    state_hash = property(get_state_hash)

    def convert_eigen_soln_caching(self, in_eigen_soln_caching):
        '''Wraps the caching request in a 1-tuple; None means DO_NOT_SAVE.'''
        if in_eigen_soln_caching is None:
            return (CachingFacets.DO_NOT_SAVE,)
        return (in_eigen_soln_caching,)

    def convert_prob_mat_caching(self, in_prob_mat_caching):
        '''Wraps the caching request in a 1-tuple; None means DO_NOT_SAVE.'''
        if in_prob_mat_caching is None:
            return (CachingFacets.DO_NOT_SAVE,)
        return (in_prob_mat_caching,)

    def _calc_prob_matrices(self, edge_length, eigen_soln_caching=None, prob_mat_caching=None):
        """Returns an index and state list object which records where the
        probability matrices are stored in the LikeCalcEnvironment.
        """
        es_wrapper = self.get_eigen_soln(eigen_soln_caching=eigen_soln_caching)
        return self._calc_env.calc_prob_from_eigen(edge_length,
                                                   self.asrv,
                                                   eigen_soln=es_wrapper,
                                                   prob_mat_caching=self.convert_prob_mat_caching(prob_mat_caching))

    def get_eigen_soln(self, eigen_soln_caching=None):
        '''Asks the environment for the eigen solution of the current state
        and remembers the returned wrapper for later transmit_* calls.'''
        state_id = self.state_hash
        es_wrapper = self._calc_env.calc_eigen_soln(model=self,
                                                    model_state_hash=state_id,
                                                    eigen_soln_caching=self.convert_eigen_soln_caching(eigen_soln_caching))
        self._eigen_soln_wrapper = es_wrapper
        return self._eigen_soln_wrapper
    eigen_soln = property(get_eigen_soln)

    def _fetch_prob_matrices(self, prob_wrapper_list):
        return self._calc_env.get_prob_matrices(prob_wrapper_list)

    def _incarnate(self):
        '''Ensures that there is a LikeCalcEnvironment associated with this object.

        If any of _cmodel, _model_index or _calc_env is None, a small private
        environment (2 leaves, 1 pattern, only this model) is created and
        marked as owned. Raises ValueError if num_states is None then.
        '''
        if (self._cmodel is None) or (self._model_index is None) or (self._calc_env is None):
            if self.num_states is None:
                raise ValueError("DiscStateContTimeModel.num_states must be set before calculations can be performed")
            from pytbeaglehon.like_calc_environ import LikeCalcEnvironment
            self._calc_env = LikeCalcEnvironment(num_leaves=2,
                                                 num_patterns=1,
                                                 num_states=self.num_states,
                                                 num_state_code_arrays=2,
                                                 num_partials=1,
                                                 model_list=[self],
                                                 num_prob_matrices=1,
                                                 num_eigen_storage_structs=1,
                                                 num_rescalings_multipliers=0,
                                                 resource_index=-1)
            self._owns_calc_env = True

    def _reassign_environ(self, calc_env, model_index, cmodel, asrv=None):
        '''Associates the model instance with a new LikeCalcEnvironment.

        `model_index` is the new index of this model in that environment.
        `cmodel` is a reference to the C object that represents the model.
        If this model owned its previous environment, that environment's
        resources are released first.
        '''
        if self._owns_calc_env:
            self._calc_env.release_resources()
            self._owns_calc_env = False
        self._calc_env = calc_env
        self._model_index = model_index
        self._cmodel = cmodel
        # NOTE(review): the asrv is stored in `_asrv`, which nothing in this
        # class reads (all methods use the public `self.asrv`). Confirm
        # whether this should update `self.asrv` when asrv is not None.
        self._asrv = asrv

    def get_num_eigen_solutions(self):
        return 1

    def get_num_categories(self):
        '''Number of rate categories: 1 without asrv, else asrv.num_categories.'''
        a = self.asrv
        if a is None:
            return 1
        return a.num_categories
    num_rate_categories = property(get_num_categories)

    def get_all_submodels(self):
        return [self]
    submodels = property(get_all_submodels)

    def get_num_prob_models(self):
        return 1
    num_prob_models = property(get_num_prob_models)

    def get_cmodel(self):
        return self._cmodel
    cmodel = property(get_cmodel)

    def transmit_category_weights(self):
        "Intended for internal use only -- passes category weights to likelihood calculator for integration of likelihood"
        es_wrapper = self._eigen_soln_wrapper
        if es_wrapper is None:
            raise ValueError("eigen solution must be calculated before transmit_category_weights can be called")
        asrv = self.asrv
        if asrv is None:
            from pytbeaglehon.like_calc_environ import NONE_HASH
            w = (1.0, )
            wph = NONE_HASH
        else:
            w = asrv.probabilities
            wph = asrv.get_prob_hash()
        es_wrapper.transmit_category_weights(w, wph)

    def transmit_state_freq(self):
        "Intended for internal use only -- passes equilibrium state frequencies to likelihood calculator for integration of likelihood"
        es_wrapper = self._eigen_soln_wrapper
        if es_wrapper is None:
            raise ValueError("eigen solution must be calculated before transmit_state_freq can be called")
        sf = self.state_freq
        sfh = self.state_freq_hash
        es_wrapper.transmit_state_freq(sf, sfh)

    def get_state_freq_hash(self):
        '''Returns repr(self._state_freq), cached until param_changed runs.

        Raises ValueError if no state frequencies are available (this used to
        be an `assert`, which is stripped when Python runs with -O).
        '''
        if self._state_freq_hash is None:
            sf = getattr(self, '_state_freq', None)
            if sf is None:
                raise ValueError("state frequencies must be set before state_freq_hash can be computed")
            self._state_freq_hash = repr(sf)
        return self._state_freq_hash
    state_freq_hash = property(get_state_freq_hash)