Example #1
class _DefaultROIManager(ROIManager, QObject):
    def __init__(self, parent: QObject = None):
        super().__init__(parent=parent)
        self._cache = LRUCache(maxsize=2048)  # Store this many ROIs at once

    @staticmethod
    def _getCacheKey(roiFile: pwsdt.RoiFile):
        return os.path.split(roiFile.filePath)[0], roiFile.name, roiFile.number

    def removeRoi(self, roiFile: pwsdt.RoiFile):
        self._cache.pop(self._getCacheKey(roiFile))
        roiFile.delete()
        self.roiRemoved.emit(roiFile)

    def updateRoi(self, roiFile: pwsdt.RoiFile, roi: pwsdt.Roi):
        roiFile.update(roi)
        self._cache[self._getCacheKey(roiFile)] = roiFile
        self.roiUpdated.emit(roiFile)

    def createRoi(self, acq: pwsdt.Acquisition, roi: pwsdt.Roi, roiName: str, roiNumber: int, overwrite: bool = False) -> pwsdt.RoiFile:
        """

        Args:
            acq: The acquisition to save the ROI to
            roi: The ROI to save.
            roiName: The name to save the ROI as.
            roiNumber: The number to save the ROI as.
            overwrite: Whether to overwrite existing ROIs with conflicting name/number combo.

        Returns:
            A reference to the created ROIFile

        Raises:
            OSError: If `overwrite` is False and a RoiFile for this name and number already exists.

        """
        roiFile = acq.saveRoi(roiName, roiNumber, roi, overwrite=overwrite)  # raises OSError if overwrite is False and the file exists
        self._cache[self._getCacheKey(roiFile)] = roiFile
        self.roiCreated.emit(roiFile, overwrite)
        return roiFile

    @cachedmethod(lambda self: self._cache, key=lambda acq, roiName, roiNum: (acq.filePath, roiName, roiNum))  # Cache results
    def getROI(self, acq: pwsdt.Acquisition, roiName: str, roiNum: int) -> pwsdt.RoiFile:
        return acq.loadRoi(roiName, roiNum)

    def close(self):
        self._cache.clear()
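Example #1 shows the basic cachetools pattern: a per-instance LRUCache plus @cachedmethod, where the first decorator argument is a callable that fetches the cache from self and the key callable maps the method arguments to a hashable tuple. A minimal self-contained sketch of the same pattern (the FileIndex class and load_entry method below are hypothetical, not part of the PWS codebase):

import operator

from cachetools import LRUCache, cachedmethod


class FileIndex:
    """Hypothetical stand-in for the ROI manager above."""

    def __init__(self, maxsize=2048):
        # One cache per instance, like _DefaultROIManager._cache.
        self._cache = LRUCache(maxsize=maxsize)

    @cachedmethod(operator.attrgetter('_cache'))
    def load_entry(self, path, name, number):
        # Pretend this is an expensive disk lookup; the decorator stores the
        # result in self._cache so repeated calls with the same arguments
        # return the cached object.
        print('loading', path, name, number)
        return (path, name, number)


idx = FileIndex()
idx.load_entry('/tmp/acq', 'nucleus', 1)  # prints "loading ..."
idx.load_entry('/tmp/acq', 'nucleus', 1)  # served from the LRUCache

operator.attrgetter('_cache') is interchangeable with the lambda self: self._cache accessor used above; both simply hand the decorator the per-instance cache object.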
Example #2
class ResidualNet(NeuralNet):
    def __init__(self, nb_filters, nb_res_blocks):
        self.nn = AlphaZeroNetwork(nb_filters, nb_res_blocks)
        if args.cuda:
            self.nn.cuda()
        # if args.half_precision:
        #     self.nn.half()
        #     for module in self.nn.modules():
        #         if isinstance(module, nn.BatchNorm2d):
        #             module.float()
        self.cache = LRUCache(maxsize=500000)

    def clear_cache(self):
        self.cache.clear()
        print("position evaluation cache cleared")

    def train(self, examples):
        print(len(examples))
        # optimizer = optim.Adam(self.nn.parameters(), lr=args.lr)
        optimizer = optim.SGD(self.nn.parameters(),
                              lr=args.lr,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)
        for epoch in range(args.epochs):
            print('EPOCH ::: ' + str(epoch + 1))
            self.nn.train()
            running_pi_loss = 0
            running_v_loss = 0
            data_time = AverageMeter()
            batch_time = AverageMeter()
            pi_losses = AverageMeter()
            v_losses = AverageMeter()
            end = time.time()

            bar = Bar('Training Net', max=int(len(examples) / args.batch_size))
            batch_idx = 0

            while batch_idx < int(len(examples) / args.batch_size):
                sample_ids = np.random.randint(len(examples),
                                               size=args.batch_size)
                boards, pis, vs = list(zip(*[examples[i] for i in sample_ids]))
                boards = torch.Tensor(boards)
                target_pis = torch.Tensor(pis)
                target_vs = torch.Tensor(vs)

                # predict
                if args.cuda:
                    boards, target_pis, target_vs = boards.contiguous().cuda(
                    ), target_pis.contiguous().cuda(), target_vs.contiguous(
                    ).cuda()

                # measure data loading time
                data_time.update(time.time() - end)

                # Compute output
                p_vector, v = self.nn(boards)
                policy_loss = self.loss_pi(target_pis, p_vector)
                value_loss = self.loss_v(target_vs, v)
                total_loss = policy_loss + value_loss

                # record losses
                pi_losses.update(policy_loss.item())
                v_losses.update(value_loss.item())
                running_v_loss += value_loss.item()
                running_pi_loss += policy_loss.item()

                # compute gradient and do SGD step
                optimizer.zero_grad()
                total_loss.backward()
                optimizer.step()

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()
                batch_idx += 1

                # plot progress
                print(
                    '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss_pi: {lpi:.4f} | Loss_v: {lv:.3f}'
                    .format(
                        batch=batch_idx,
                        size=int(len(examples) / args.batch_size),
                        data=data_time.avg,
                        bt=batch_time.avg,
                        total=bar.elapsed_td,
                        eta=bar.eta_td,
                        lpi=pi_losses.avg,
                        lv=v_losses.avg,
                    ))
                bar.next()
            bar.finish()
        self.clear_cache()
        infos = {
            "value_loss": running_v_loss / float(len(examples)),
            "policy_loss": running_pi_loss / float(len(examples))
        }
        return infos

    @cachedmethod(lambda self: self.cache, key=lambda board: board.tostring())
    def predict(self, board):
        start = time.time()
        with torch.no_grad():
            board = torch.Tensor([board])
            if args.cuda:
                board = board.cuda()
            self.nn.eval()
            p, v = self.nn(board)
        # print(f'PREDICTION TIME TAKEN : {time.time() - start}')

        return torch.exp(p).data.cpu().numpy()[0], v.data.cpu().numpy()[0]

    def loss_pi(self, targets, outputs):
        return -torch.sum(targets * outputs) / targets.size()[0]

    def loss_v(self, targets, outputs):
        return torch.sum((targets - outputs.view(-1))**2) / targets.size()[0]

    def save_checkpoint(self,
                        folder='checkpoint',
                        filename='checkpoint.pth.tar'):
        filepath = os.path.join(folder, filename)
        if not os.path.exists(folder):
            print("Checkpoint Directory does not exist! Making directory {}".
                  format(folder))
            os.mkdir(folder)
        else:
            print("Checkpoint Directory exists! ")
        torch.save({
            'state_dict': self.nn.state_dict(),
        }, filepath)

    def load_checkpoint(self,
                        folder='checkpoint',
                        filename='checkpoint.pth.tar'):
        # https://github.com/pytorch/examples/blob/master/imagenet/main.py#L98
        filepath = os.path.join(folder, filename)
        if not os.path.exists(filepath):
            raise ("No model in path {}".format(filepath))
        map_location = None if args.cuda else 'cpu'
        checkpoint = torch.load(filepath, map_location=map_location)
        self.nn.load_state_dict(checkpoint['state_dict'])
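Example #2 memoizes network evaluations: predict is decorated with @cachedmethod and a key derived from the board's byte string, and clear_cache empties the LRUCache whenever the weights change. Because raw numpy arrays are not hashable, one version-robust way to sketch the same idea (Evaluator, _predict_cached and the dummy return value below are illustrative, not part of the original project) is to reduce the board to bytes before it reaches the cached method:

import operator

import numpy as np
from cachetools import LRUCache, cachedmethod


class Evaluator:
    """Hypothetical caching layer mirroring ResidualNet.predict."""

    def __init__(self, maxsize=500000):
        self.cache = LRUCache(maxsize=maxsize)

    def predict(self, board):
        # numpy arrays are not hashable, so reduce the board to bytes (plus
        # its shape) and route the call through the cached helper.
        return self._predict_cached(board.tobytes(), board.shape)

    @cachedmethod(operator.attrgetter('cache'))
    def _predict_cached(self, board_bytes, shape):
        board = np.frombuffer(board_bytes, dtype=np.float64).reshape(shape)
        # The real network forward pass would go here; a dummy value stands in.
        return board.sum(), 0.0


ev = Evaluator()
b = np.zeros((3, 3))
ev.predict(b)  # computed
ev.predict(b)  # cache hit, the helper is not called again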
Example #3
class TextDoc(object):
    """
    Class that tokenizes, tags, and parses a text document, and provides an easy
    interface to information extraction, alternative document representations,
    and statistical measures of the text.

    Args:
        text (str)
        spacy_pipeline (``spacy.<lang>.<Lang>()``, optional)
        lang (str, optional)
        metadata (dict, optional)
        max_cachesize (int, optional)
    """
    def __init__(self, text, spacy_pipeline=None, lang='auto',
                 metadata=None, max_cachesize=5):
        self.metadata = {} if metadata is None else metadata
        self.lang = text_utils.detect_language(text) if lang == 'auto' else lang
        if spacy_pipeline is None:
            self.spacy_pipeline = data.load_spacy_pipeline(lang=self.lang)
        else:
            # check for match between text and supplied spacy pipeline language
            if spacy_pipeline.lang != self.lang:
                msg = 'TextDoc.lang {} != spacy_pipeline.lang {}'.format(
                    self.lang, spacy_pipeline.lang)
                raise ValueError(msg)
            else:
                self.spacy_pipeline = spacy_pipeline
        self.spacy_vocab = self.spacy_pipeline.vocab
        self.spacy_stringstore = self.spacy_vocab.strings
        self.spacy_doc = self.spacy_pipeline(text)
        self._term_counts = Counter()
        self._cache = LRUCache(maxsize=max_cachesize)

    def __repr__(self):
        return 'TextDoc({} tokens: {})'.format(
            len(self.spacy_doc), repr(self.text[:50].replace('\n',' ').strip() + '...'))

    def __len__(self):
        return self.n_tokens

    def __getitem__(self, index):
        return self.spacy_doc[index]

    def __iter__(self):
        for tok in self.spacy_doc:
            yield tok

    @property
    def tokens(self):
        """Yield the document's tokens as tokenized by spacy; same as ``__iter__``."""
        for tok in self.spacy_doc:
            yield tok

    @property
    def sents(self):
        """Yield the document's sentences as segmented by spacy."""
        for sent in self.spacy_doc.sents:
            yield sent

    def merge(self, spans):
        """
        Merge spans *in-place* within doc so that each takes up a single token.
        Note: All cached methods on this doc will be cleared.

        Args:
            spans (iterable(``spacy.Span``)): for example, the results from
                :func:`extract.named_entities() <textacy.extract.named_entities>`
                or :func:`extract.pos_regex_matches() <textacy.extract.pos_regex_matches>`
        """
        with LOCK:
            self._cache.clear()
        spacy_utils.merge_spans(spans)

    ###############
    # DOC AS TEXT #

    @property
    def text(self):
        """Return the document's raw text."""
        return self.spacy_doc.text_with_ws

    @property
    def tokenized_text(self):
        """Return text as an ordered, nested list of tokens per sentence."""
        return [[token.text for token in sent]
                for sent in self.spacy_doc.sents]

    @property
    def pos_tagged_text(self):
        """Return text as an ordered, nested list of (token, POS) pairs per sentence."""
        return [[(token.text, token.pos_) for token in sent]
                for sent in self.spacy_doc.sents]

    #######################
    # DOC REPRESENTATIONS #

    def as_bag_of_terms(self, weighting='tf', normalized=True, binary=False,
                        idf=None, lemmatize='auto',
                        ngram_range=(1, 1),
                        include_nes=False, include_nps=False, include_kts=False):
        """
        Represent doc as a "bag of terms", an unordered set of (term id, term weight)
        pairs, where term weight may be by TF or TF*IDF.

        Args:
            weighting (str {'tf', 'tfidf'}, optional): weighting of term weights,
                either term frequency ('tf') or tf * inverse doc frequency ('tfidf')
            idf (dict, optional): if `weighting` = 'tfidf', idf's must be supplied
                externally, such as from a `TextCorpus` object
            lemmatize (bool or 'auto', optional): if True, lemmatize all terms
                when getting their frequencies
            ngram_range (tuple(int), optional): (min n, max n) values for n-grams
                to include in terms list; default (1, 1) only includes unigrams
            include_nes (bool, optional): if True, include named entities in terms list
            include_nps (bool, optional): if True, include noun phrases in terms list
            include_kts (bool, optional): if True, include key terms in terms list
            normalized (bool, optional): if True, normalize term freqs by the
                total number of tokens in the doc
            binary (bool, optional): if True, set all (non-zero) term freqs equal to 1

        Returns:
            :class:`collections.Counter <collections.Counter>`: mapping of term ids
                to corresponding term weights
        """
        term_weights = self.term_counts(
            lemmatize=lemmatize, ngram_range=ngram_range, include_nes=include_nes,
            include_nps=include_nps, include_kts=include_kts)

        if binary is True:
            term_weights = Counter({key: 1 for key in term_weights.keys()})
        elif normalized is True:
            # n_terms = sum(term_freqs.values())
            n_tokens = self.n_tokens
            term_weights = Counter({key: val / n_tokens
                                    for key, val in term_weights.items()})

        if weighting == 'tfidf' and idf:
            term_weights = Counter({key: val * idf[key]
                                    for key, val in term_weights.items()})

        return term_weights

    def as_bag_of_concepts(self):
        raise NotImplementedError()

    def as_semantic_network(self):
        raise NotImplementedError()

    ##########################
    # INFORMATION EXTRACTION #

    @cachedmethod(attrgetter('_cache'), key=partial(hashkey, 'words'))
    def words(self, **kwargs):
        """
        Extract an ordered list of words from a spacy-parsed doc, optionally
        filtering words by part-of-speech (etc.) and frequency.

        .. seealso:: :func:`extract.words() <textacy.extract.words>` for all function kwargs.
        """
        return list(extract.words(self.spacy_doc, **kwargs))

    @cachedmethod(attrgetter('_cache'), key=partial(hashkey, 'ngrams'))
    def ngrams(self, n, **kwargs):
        """
        Extract an ordered list of n-grams (``n`` consecutive words) from doc,
        optionally filtering n-grams by the types and parts-of-speech of the
        constituent words.

        Args:
            n (int): number of tokens to include in n-grams;
                1 => unigrams, 2 => bigrams

        .. seealso:: :func:`extract.ngrams() <textacy.extract.ngrams>` for all function kwargs.
        """
        return list(extract.ngrams(self.spacy_doc, n, **kwargs))

    @cachedmethod(attrgetter('_cache'), key=partial(hashkey, 'named_entities'))
    def named_entities(self, **kwargs):
        """
        Extract an ordered list of named entities (PERSON, ORG, LOC, etc.) from
        doc, optionally filtering by the entity types and frequencies.

        .. seealso:: :func:`extract.named_entities() <textacy.extract.named_entities>`
        for all function kwargs.
        """
        return list(extract.named_entities(self.spacy_doc, **kwargs))

    @cachedmethod(attrgetter('_cache'), key=partial(hashkey, 'noun_chunks'))
    def noun_chunks(self, **kwargs):
        """
        Extract an ordered list of noun phrases from doc, optionally
        filtering by frequency and dropping leading determiners.

        .. seealso:: :func:`extract.noun_chunks() <textacy.extract.noun_chunks>`
        for all function kwargs.
        """
        return list(extract.noun_chunks(self.spacy_doc, **kwargs))

    @cachedmethod(attrgetter('_cache'), key=partial(hashkey, 'pos_regex_matches'))
    def pos_regex_matches(self, pattern):
        """
        Extract sequences of consecutive tokens from a spacy-parsed doc whose
        part-of-speech tags match the specified regex pattern.

        Args:
            pattern (str): Pattern of consecutive POS tags whose corresponding words
                are to be extracted, inspired by the regex patterns used in NLTK's
                ``nltk.chunk.regexp``. Tags are uppercase, from the universal tag set;
                delimited by < and >, which are basically converted to parentheses
                with spaces as needed to correctly extract matching word sequences;
                white space in the input doesn't matter.

                Examples (see :obj:`POS_REGEX_PATTERNS <textacy.regexes_etc.POS_REGEX_PATTERNS>`):

                * noun phrase: r'<DET>? (<NOUN>+ <ADP|CONJ>)* <NOUN>+'
                * compound nouns: r'<NOUN>+'
                * verb phrase: r'<VERB>?<ADV>*<VERB>+'
                * prepositional phrase: r'<PREP> <DET>? (<NOUN>+<ADP>)* <NOUN>+'
        """
        return list(extract.pos_regex_matches(self.spacy_doc, pattern))

    @cachedmethod(attrgetter('_cache'), key=partial(hashkey, 'subject_verb_object_triples'))
    def subject_verb_object_triples(self):
        """
        Extract an *un*ordered list of distinct subject-verb-object (SVO) triples
        from doc.
        """
        return list(extract.subject_verb_object_triples(self.spacy_doc))

    @cachedmethod(attrgetter('_cache'), key=partial(hashkey, 'acronyms_and_definitions'))
    def acronyms_and_definitions(self, **kwargs):
        """
        Extract a collection of acronyms and their most likely definitions,
        if available, from doc. If multiple definitions are found for a given acronym,
        only the most frequently occurring definition is returned.

        .. seealso:: :func:`extract.acronyms_and_definitions() <textacy.extract.acronyms_and_definitions>`
        for all function kwargs.
        """
        return extract.acronyms_and_definitions(self.spacy_doc, **kwargs)

    @cachedmethod(attrgetter('_cache'), key=partial(hashkey, 'semistructured_statements'))
    def semistructured_statements(self, entity, **kwargs):
        """
        Extract "semi-structured statements" from doc, each as a (entity, cue, fragment)
        triple. This is similar to subject-verb-object triples.

        Args:
            entity (str): a noun or noun phrase of some sort (e.g. "President Obama",
                "global warming", "Python")

        .. seealso:: :func:`extract.semistructured_statements() <textacy.extract.semistructured_statements>`
        for all function kwargs.
        """
        return list(extract.semistructured_statements(
            self.spacy_doc, entity, **kwargs))

    @cachedmethod(attrgetter('_cache'), key=partial(hashkey, 'direct_quotations'))
    def direct_quotations(self):
        """
        Baseline, not-great attempt at direct quotation extraction (no indirect
        or mixed quotations) using rules and patterns. English only.
        """
        return list(extract.direct_quotations(self.spacy_doc))

    @cachedmethod(attrgetter('_cache'), key=partial(hashkey, 'key_terms'))
    def key_terms(self, algorithm='sgrank', n=10):
        """
        Extract key terms from a document using `algorithm`.

        Args:
            algorithm (str {'sgrank', 'textrank', 'singlerank'}, optional): name
                of algorithm to use for key term extraction
            n (int or float, optional): if int, number of top-ranked terms to return
                as keyterms; if float, must be in the open interval (0.0, 1.0),
                representing the fraction of top-ranked terms to return as keyterms

        Raises:
            ValueError: if ``algorithm`` not in {'sgrank', 'textrank', 'singlerank'}
        """
        if algorithm == 'sgrank':
            return keyterms.sgrank(self.spacy_doc, window_width=1500, n_keyterms=n)
        elif algorithm == 'textrank':
            return keyterms.textrank(self.spacy_doc, n_keyterms=n)
        elif algorithm == 'singlerank':
            return keyterms.singlerank(self.spacy_doc, n_keyterms=n)
        else:
            raise ValueError('algorithm {} not a valid option'.format(algorithm))

    ##############
    # STATISTICS #

    @cachedmethod(attrgetter('_cache'), key=partial(hashkey, 'term_counts'))
    def term_counts(self, lemmatize='auto', ngram_range=(1, 1),
                    include_nes=False, include_nps=False, include_kts=False):
        """
        Get the number of occurrences ("counts") of each unique term in doc;
        terms may be words, n-grams, named entities, noun phrases, and key terms.

        Args:
            lemmatize (bool or 'auto', optional): if True, lemmatize all terms
                when getting their frequencies; if 'auto', lemmatize all terms
                that aren't proper nouns or acronyms
            ngram_range (tuple(int), optional): (min n, max n) values for n-grams
                to include in terms list; default (1, 1) only includes unigrams
            include_nes (bool, optional): if True, include named entities in terms list
            include_nps (bool, optional): if True, include noun phrases in terms list
            include_kts (bool, optional): if True, include key terms in terms list

        Returns:
            :class:`collections.Counter() <collections.Counter>`: mapping of unique
                term ids to corresponding term counts
        """
        if lemmatize == 'auto':
            get_id = lambda x: self.spacy_stringstore[spacy_utils.normalized_str(x)]
        elif lemmatize is True:
            get_id = lambda x: self.spacy_stringstore[x.lemma_]
        else:
            get_id = lambda x: self.spacy_stringstore[x.text]

        for n in range(ngram_range[0], ngram_range[1] + 1):
            if n == 1:
                self._term_counts = self._term_counts | Counter(
                    get_id(word) for word in self.words())
            else:
                self._term_counts = self._term_counts | Counter(
                    get_id(ngram) for ngram in self.ngrams(n))
        if include_nes is True:
            self._term_counts = self._term_counts | Counter(
                get_id(ne) for ne in self.named_entities())
        if include_nps is True:
            self._term_counts = self._term_counts | Counter(
                get_id(np) for np in self.noun_chunks())
        if include_kts is True:
            # HACK: key terms are currently returned as strings
            # TODO: cache key terms, and return them as spacy spans
            get_id = lambda x: self.spacy_stringstore[x]
            self._term_counts = self._term_counts | Counter(
                get_id(kt) for kt, _ in self.key_terms())

        return self._term_counts

    def term_count(self, term):
        """
        Get the number of occurrences ("count") of term in doc.

        Args:
            term (str or ``spacy.Token`` or ``spacy.Span``)

        Returns:
            int
        """
        # figure out what object we're dealing with here; convert as necessary
        if isinstance(term, str):
            term_text = term
            term_id = self.spacy_stringstore[term_text]
            term_len = term_text.count(' ') + 1
        elif isinstance(term, spacy_token):
            term_text = spacy_utils.normalized_str(term)
            term_id = self.spacy_stringstore[term_text]
            term_len = 1
        elif isinstance(term, spacy_span):
            term_text = spacy_utils.normalized_str(term)
            term_id = self.spacy_stringstore[term_text]
            term_len = len(term)
        else:
            raise ValueError('term must be a str, spacy Token, or spacy Span')

        term_count_ = self._term_counts[term_id]
        if term_count_ > 0:
            return term_count_
        # have we not already counted the appropriate `n` n-grams?
        if not any(self.spacy_stringstore[t].count(' ') == term_len
                   for t in self._term_counts):
            get_id = lambda x: self.spacy_stringstore[spacy_utils.normalized_str(x)]
            if term_len == 1:
                self._term_counts += Counter(get_id(w) for w in self.words())
            else:
                self._term_counts += Counter(get_id(ng) for ng in self.ngrams(term_len))
            term_count_ = self._term_counts[term_id]
            if term_count_ > 0:
                return term_count_
        # last resort: try a regular expression
        return sum(1 for _ in re.finditer(re.escape(term_text), self.text))

    @property
    def n_tokens(self):
        """The number of tokens in the document -- including punctuation."""
        return len(self.spacy_doc)

    def n_words(self, filter_stops=False, filter_punct=True, filter_nums=False):
        """
        The number of words in the document, with optional filtering of stop words,
        punctuation (on by default), and numbers.
        """
        return len(self.words(filter_stops=filter_stops,
                              filter_punct=filter_punct,
                              filter_nums=filter_nums))

    @property
    def n_sents(self):
        """The number of sentences in the document."""
        return sum(1 for _ in self.spacy_doc.sents)

    def n_paragraphs(self, pattern=r'\n\n+'):
        """The number of paragraphs in the document, as delimited by ``pattern``."""
        return sum(1 for _ in re.finditer(pattern, self.text)) + 1

    @cachedmethod(attrgetter('_cache'), key=partial(hashkey, 'readability_stats'))
    def readability_stats(self):
        return text_stats.readability_stats(self)
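Example #3 demonstrates the shared-cache variant: all extraction methods of TextDoc cache into one small per-document LRUCache, and each decorator namespaces its entries with key=partial(hashkey, '<method name>') so results from different methods cannot collide; merge() clears the cache when the underlying spacy doc is mutated. A minimal sketch of that namespacing pattern (the Doc class and the bodies of words and ngrams are placeholders, not textacy code):

import operator
from functools import partial

from cachetools import LRUCache, cachedmethod
from cachetools.keys import hashkey


class Doc:
    """Hypothetical stand-in for TextDoc's shared-cache layout."""

    def __init__(self, maxsize=5):
        # One small cache shared by every cached method of this instance.
        self._cache = LRUCache(maxsize=maxsize)

    @cachedmethod(operator.attrgetter('_cache'), key=partial(hashkey, 'words'))
    def words(self):
        print('computing words')
        return ['alpha', 'beta']

    @cachedmethod(operator.attrgetter('_cache'), key=partial(hashkey, 'ngrams'))
    def ngrams(self, n):
        # The 'ngrams' prefix keeps these entries distinct from words()
        # even though both live in the same self._cache.
        print('computing ngrams')
        return list(zip(self.words(), self.words()[1:]))[:n]


doc = Doc()
doc.words()
doc.words()    # cache hit
doc.ngrams(2)
doc.ngrams(2)  # cache hit, stored under a different namespaced key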
Example #4
class EP(object):
    r"""Generic EP implementation.

    Let :math:`\mathrm Q \mathrm S \mathrm Q^{\intercal}` be the economic
    eigendecomposition of the genetic covariance matrix.
    Let :math:`\mathrm U\mathrm S\mathrm V^{\intercal}` be the singular value
    decomposition of the user-provided covariates :math:`\mathrm M`. We define

    .. math::

        \mathrm K = v ((1-\delta)\mathrm Q \mathrm S \mathrm Q^{\intercal} +
                    \delta \mathrm I)

    as the covariance of the prior distribution. As such,
    :math:`v` and :math:`\delta` refer to :py:attr:`_v` and :py:attr:`_delta`
    class attributes, respectively. We also use the following variables for
    convenience:

    .. math::
        :nowrap:

        \begin{eqnarray}
            \sigma_b^2          & = & v (1-\delta) \\
            \sigma_{\epsilon}^2 & = & v \delta
        \end{eqnarray}

    The covariate effect-sizes is given by :math:`\boldsymbol\beta`, which
    implies

    .. math::

        \mathbf m = \mathrm M \boldsymbol\beta

    The prior is thus defined as

    .. math::

        \mathcal N(\mathbf z ~|~ \mathbf m; \mathrm K)

    and the marginal likelihood is given by

    .. math::

        p(\mathbf y) = \int \prod_i p(y_i | g(\mathrm E[y_i | z_i])=z_i)
            \mathcal N(\mathbf z ~|~ \mathbf m, \mathrm K) \mathrm d\mathbf z

    However, the singular value decomposition of the covariates allows us to
    automatically remove dependence between covariates, which would otherwise
    create infinitely many :math:`\boldsymbol\beta` that lead to global optima.
    Let us define

    .. math::

        \tilde{\boldsymbol\beta} = \mathrm S^{1/2} \mathrm V^{\intercal}
                                    \boldsymbol\beta

    as the covariate effect-sizes we will effectively work with during the
    optimization process. Let us also define the

    .. math::

        \tilde{\mathrm M} = \mathrm U \mathrm S^{1/2}

    as the redundancy-free covariates. Naturally,

    .. math::

        \mathbf m = \tilde{\mathrm M} \tilde{\boldsymbol\beta}

    In summary, we will optimize :math:`\tilde{\boldsymbol{\beta}}`, even
    though the user will be able to retrieve the corresponding
    :math:`\boldsymbol{\beta}`.


    Let

    .. math::

        \mathrm{KL}[p(y_i|z_i) p_{-}(z_i|y_i)_{\text{EP}} ~|~
            p(y_i|z_i)_{\text{EP}} p_{-}(z_i|y_i)_{\text{EP}}]

    be the KL divergence we want to minimize at each EP iteration.
    The left-hand side can be described as
    :math:`\hat c_i \mathcal N(z_i | \hat \mu_i; \hat \sigma_i^2)`


    Args:
        M (array_like): :math:`\mathrm M` covariates.
        Q (array_like): :math:`\mathrm Q` of the economic
                        eigendecomposition.
        S (array_like): :math:`\mathrm S` of the economic
                        eigendecomposition.
        overdispersion (bool): `True` for :math:`\sigma_{\epsilon}^2 \ge 0`,
                `False` for :math:`\sigma_{\epsilon}^2=0`.
        QSQt (array_like): :math:`\mathrm Q \mathrm S
                        \mathrm Q^{\intercal}` in case this has already
                        been computed. Defaults to `None`.


    Attributes:
        _v (float): Total variance :math:`v` from the prior distribution.
        _delta (float): Fraction of the total variance due to the identity
                        matrix :math:`\mathrm I`.
        _loghz (array_like): This is :math:`\log(\hat c)` for each site.
        _hmu (array_like): This is :math:`\hat \mu` for each site.
        _hvar (array_like): This is :math:`\hat \sigma^2` for each site.

    """
    def __init__(self, M, Q, S, overdispersion):
        self._cache_SQt = LRUCache(maxsize=1)
        self._cache_m = LRUCache(maxsize=1)
        self._cache_K = LRUCache(maxsize=1)
        self._cache_diagK = LRUCache(maxsize=1)
        self._cache_update = LRUCache(maxsize=1)
        self._cache_lml_components = LRUCache(maxsize=1)
        self._cache_L = LRUCache(maxsize=1)
        self._cache_A = LRUCache(maxsize=1)
        self._cache_C = LRUCache(maxsize=1)
        self._cache_BiQt = LRUCache(maxsize=1)
        self._cache_QBiQtAm = LRUCache(maxsize=1)
        self._cache_QBiQtCteta = LRUCache(maxsize=1)

        self._logger = logging.getLogger(__name__)

        if not is_all_finite(Q) or not is_all_finite(S):
            raise ValueError("There are non-finite numbers in the provided" +
                             " eigen decomposition.")

        if S.min() <= 0:
            raise ValueError("The provided covariance matrix is not" +
                             " positive-definite because the minimum" +
                             " eigvalue is %f." % S.min())

        make_sure_reasonable_conditioning(S)

        self._S = S
        self._Q = Q
        self.__QSQt = None

        nsamples = M.shape[0]
        self._previous_sitelik_tau = zeros(nsamples)
        self._previous_sitelik_eta = zeros(nsamples)

        self._sitelik_tau = zeros(nsamples)
        self._sitelik_eta = zeros(nsamples)

        self._cav_tau = zeros(nsamples)
        self._cav_eta = zeros(nsamples)

        self._joint_tau = zeros(nsamples)
        self._joint_eta = zeros(nsamples)

        self._v = None
        self._delta = 0
        self._overdispersion = overdispersion
        self._tM = None
        self.__tbeta = None
        self._covariate_setup(M)

        self._loghz = empty(nsamples)
        self._hmu = empty(nsamples)
        self._hvar = empty(nsamples)
        self._ep_params_initialized = False

    def _copy_to(self, ep):
        ep._cache_SQt = LRUCache(maxsize=1)
        ep._cache_m = LRUCache(maxsize=1)
        ep._cache_K = LRUCache(maxsize=1)
        ep._cache_diagK = LRUCache(maxsize=1)
        ep._cache_update = LRUCache(maxsize=1)
        ep._cache_lml_components = LRUCache(maxsize=1)
        ep._cache_L = LRUCache(maxsize=1)
        ep._cache_A = LRUCache(maxsize=1)
        ep._cache_C = LRUCache(maxsize=1)
        ep._cache_BiQt = LRUCache(maxsize=1)
        ep._cache_QBiQtAm = LRUCache(maxsize=1)
        ep._cache_QBiQtCteta = LRUCache(maxsize=1)

        ep._logger = logging.getLogger(__name__)

        ep._S = self._S
        ep._Q = self._Q
        ep.__QSQt = self.__QSQt

        ep._previous_sitelik_tau = self._previous_sitelik_tau.copy()
        ep._previous_sitelik_eta = self._previous_sitelik_eta.copy()

        ep._sitelik_tau = self._sitelik_tau.copy()
        ep._sitelik_eta = self._sitelik_eta.copy()

        ep._cav_tau = self._cav_tau.copy()
        ep._cav_eta = self._cav_eta.copy()

        ep._joint_tau = self._joint_tau.copy()
        ep._joint_eta = self._joint_eta.copy()

        ep._v = self._v
        ep._delta = self._delta
        ep._overdispersion = self._overdispersion
        ep._tM = self._tM
        ep.__tbeta = self.__tbeta.copy()
        ep._M = self._M
        ep._svd_U = self._svd_U
        ep._svd_S12 = self._svd_S12
        ep._svd_V = self._svd_V

        ep._loghz = self._loghz.copy()
        ep._hmu = self._hmu.copy()
        ep._hvar = self._hvar.copy()
        ep._ep_params_initialized = self._ep_params_initialized

    def _covariate_setup(self, M):
        self._M = M
        SVD = economic_svd(M)
        self._svd_U = SVD[0]
        self._svd_S12 = sqrt(SVD[1])
        self._svd_V = SVD[2]
        self._tM = ddot(self._svd_U, self._svd_S12, left=False)
        if self.__tbeta is not None:
            self.__tbeta = resize(self.__tbeta, self._tM.shape[1])

    def _init_ep_params(self):
        self._logger.debug("EP parameters initialization.")

        if self._ep_params_initialized:
            self._joint_update()
        else:
            self._joint_initialize()
            self._sitelik_initialize()
            self._ep_params_initialized = True

    def fixed_ep(self):
        w1, w2, w3, _, _, w6, w7 = self._lml_components()

        lml_const = w1 + w2 + w3 + w6 + w7

        beta_nom = self._optimal_beta_nom()

        return FixedEP(lml_const, self._A(), self._C(), self._L(), self._Q,
                       self._QBiQtCteta(), self._sitelik_eta, beta_nom)

    def _joint_initialize(self):
        r"""Initialize the mean and covariance of the posterior.

        Given that :math:`\tilde{\mathrm T}` is a matrix of zeros before the
        first EP iteration, we have

        .. math::
            :nowrap:

            \begin{eqnarray}
                \Sigma         & = & \mathrm K \\
                \boldsymbol\mu & = & \mathrm K^{-1} \mathbf m
            \end{eqnarray}
        """
        self._joint_tau[:] = 1 / self._diagK()
        self._joint_eta[:] = self.m()
        self._joint_eta[:] *= self._joint_tau

    def _sitelik_initialize(self):
        self._sitelik_tau[:] = 0.
        self._sitelik_eta[:] = 0.

    @cachedmethod(attrgetter('_cache_K'))
    def K(self):
        r"""Covariance matrix of the prior.

        Returns:
            :math:`\sigma_b^2 \mathrm Q_0 \mathrm S_0 \mathrm Q_0^{\intercal} + \sigma_{\epsilon}^2 \mathrm I`.
        """
        return sum2diag(self.sigma2_b * self._QSQt(), self.sigma2_epsilon)

    def _Kdot(self, x):
        Q = self._Q
        S = self._S
        out = dot(Q.T, x)
        out *= S
        out = dot(Q, out)
        out *= (1 - self.delta)
        out += self.delta * x
        out *= self.v
        return out

    @cachedmethod(attrgetter('_cache_diagK'))
    def _diagK(self):
        r"""Returns the diagonal of :math:`\mathrm K`."""
        return self.sigma2_b * self._diagQSQt() + self.sigma2_epsilon

    def _diagQSQt(self):
        return self._QSQt().diagonal()

    @cachedmethod(attrgetter('_cache_m'))
    def m(self):
        r"""Mean vector of the prior.

        Returns:
            :math:`\mathrm M \boldsymbol\beta`.
        """
        return dot(self._tM, self._tbeta)

    @property
    def covariates_variance(self):
        r"""Variance explained by the covariates.

        It is defined as

        .. math::

            \sigma_a^2 = \sum_{s=1}^p \left\{ \sum_{i=1}^n \left(
                \mathrm M_{i,s}\beta_s - \sum_{j=1}^n
                \frac{\mathrm M_{j,s}\beta_s}{n} \right)^2 \Big/ n
            \right\}

        where :math:`p` is the number of covariates and :math:`n` is the number
        of individuals. One can show that it amounts to
        :math:`\sum_s \beta_s^2` whenever the columns of :math:`\mathrm M`
        are normalized to have mean and standard deviation equal to zero and
        one, respectively.
        """
        return fsum(variance(self.M * self.beta, axis=0))

    @property
    def sigma2_b(self):
        r"""Returns :math:`v (1-\delta)`."""
        return self.v * (1 - self.delta)

    @property
    def sigma2_epsilon(self):
        r"""Returns :math:`v \delta`."""
        return self.v * self.delta

    @property
    def delta(self):
        r"""Returns :math:`\delta`."""
        return self._delta

    @delta.setter
    def delta(self, v):
        r"""Set :math:`\delta`."""
        self._cache_K.clear()
        self._cache_diagK.clear()
        self._cache_update.clear()
        self._cache_lml_components.clear()
        self._cache_L.clear()
        self._cache_A.clear()
        self._cache_C.clear()
        self._cache_BiQt.clear()
        self._cache_QBiQtAm.clear()
        self._cache_QBiQtCteta.clear()
        if not (0 <= v <= 1):
            raise ValueError("delta should not be %f." % v)
        self._delta = v

    @property
    def v(self):
        r"""Returns :math:`v`."""
        return self._v

    @v.setter
    def v(self, v):
        r"""Set :math:`v`."""
        self._cache_K.clear()
        self._cache_diagK.clear()
        self._cache_update.clear()
        self._cache_lml_components.clear()
        self._cache_L.clear()
        self._cache_A.clear()
        self._cache_C.clear()
        self._cache_BiQt.clear()
        self._cache_QBiQtAm.clear()
        self._cache_QBiQtCteta.clear()
        if v < 0:
            raise ValueError("v should not be %f." % v)
        self._v = max(v, epsilon.small)

    @property
    def _tbeta(self):
        return self.__tbeta

    @_tbeta.setter
    def _tbeta(self, value):
        self._cache_lml_components.clear()
        self._cache_QBiQtAm.clear()
        self._cache_m.clear()
        self._cache_update.clear()
        if not is_all_finite(value):
            raise ValueError("tbeta should not be %s." % str(value))
        if self.__tbeta is None:
            self.__tbeta = asarray(value, float).copy()
        else:
            self.__tbeta[:] = value

    @property
    def beta(self):
        r"""Returns :math:`\boldsymbol\beta`."""
        return solve(self._svd_V.T, self._tbeta / self._svd_S12)

    @beta.setter
    def beta(self, value):
        if not is_all_finite(value):
            raise ValueError("beta should not be %s." % str(value))
        self._tbeta = self._svd_S12 * dot(self._svd_V.T, value)

    @property
    def M(self):
        r"""Returns :math:`\mathrm M`."""
        return self._M

    @M.setter
    def M(self, value):
        self._covariate_setup(value)
        self._cache_m.clear()
        self._cache_QBiQtAm.clear()
        self._cache_update.clear()
        self._cache_lml_components.clear()

    @cachedmethod(attrgetter('_cache_lml_components'))
    def _lml_components(self):
        self._update()

        S = self._S
        m = self.m()
        ttau = self._sitelik_tau
        teta = self._sitelik_eta
        ctau = self._cav_tau
        ceta = self._cav_eta
        tctau = ttau + ctau
        A = self._A()
        C = self._C()
        L = self._L()
        Am = A * m

        QBiQtCteta = self._QBiQtCteta()
        QBiQtAm = self._QBiQtAm()

        gS = self.sigma2_b * S
        eC = self.sigma2_epsilon * C

        w1 = -sum(log(diagonal(L))) + (-sum(log(gS)) / 2 + log(A).sum() / 2)

        w2 = eC * teta
        w2 += C * QBiQtCteta
        w2 -= teta / tctau
        w2 = dot(teta, w2) / 2

        w3 = dot(ceta, (ttau * ceta - 2 * teta * ctau) / (ctau * tctau)) / 2

        w4 = dot(m * C, teta) - dot(Am, QBiQtCteta)

        w5 = -dot(Am, m) / 2 + dot(Am, QBiQtAm) / 2

        w6 = -sum(log(ttau)) + sum(log(tctau)) - sum(log(ctau))
        w6 /= 2

        w7 = sum(self._loghz)

        return (w1, w2, w3, w4, w5, w6, w7)

    def lml(self, fast=False):
        if fast:
            return self._normal_lml()
        v = fsum(self._lml_components())
        if not isfinite(v):
            raise ValueError("LML should not be %f." % v)
        return v

    def _normal_lml(self):
        self._update()

        m = self.m()
        ttau = self._sitelik_tau
        teta = self._sitelik_eta

        # NEW PHENOTYPE
        y = teta.copy()

        # NEW MEAN
        m = ttau * m

        # NEW COVARIANCE
        K = self.K()
        K = ddot(ttau, ddot(K, ttau, left=False), left=True)
        sum2diag(K, ttau, out=K)
        (Q, S0) = economic_qs(K)
        Q0, Q1 = Q

        from ...lmm import FastLMM
        from numpy import newaxis

        fastlmm = FastLMM(y, Q0, Q1, S0, covariates=m[:, newaxis])
        fastlmm.learn(progress=False)
        return fastlmm.lml()

    def _gradient_over_v(self):
        self._update()

        A = self._A()
        Q = self._Q
        S = self._S
        C = self._C()
        m = self.m()
        delta = self.delta
        v = self.v
        teta = self._sitelik_eta

        AQ = ddot(A, Q, left=True)
        SQt = ddot(S, Q.T, left=True)

        Am = A * m
        Em = Am - A * self._QBiQtAm()

        Cteta = C * teta
        Eu = Cteta - A * self._QBiQtCteta()

        u = Em - Eu

        uBiQtAK0, uBiQtAK1 = self._uBiQtAK()

        out = dot(u, self._Kdot(u))
        out /= v
        out -= (1 - delta) * trace2(AQ, SQt)
        out -= delta * A.sum()
        out += (1 - delta) * trace2(AQ, uBiQtAK0)
        out += delta * trace2(AQ, uBiQtAK1)
        out /= 2
        return out

    def _gradient_over_delta(self):
        self._update()

        v = self.v
        delta = self.delta
        Q = self._Q
        S = self._S

        A = self._A()
        C = self._C()
        m = self.m()
        teta = self._sitelik_eta

        Am = A * m
        Em = Am - A * self._QBiQtAm()

        Cteta = C * teta
        Eu = Cteta - A * self._QBiQtCteta()

        u = Em - Eu

        AQ = ddot(A, Q, left=True)
        SQt = ddot(S, Q.T, left=True)

        BiQt = self._BiQt()

        uBiQtAK0, uBiQtAK1 = self._uBiQtAK()

        out = -trace2(AQ, uBiQtAK0)
        out -= (delta / (1 - delta)) * trace2(AQ, uBiQtAK1)
        out += trace2(AQ, ddot(BiQt, A, left=False)) * \
            ((delta / (1 - delta)) + 1)
        out += (1 + delta / (1 - delta)) * dot(u, u)
        out += trace2(AQ, SQt) + (delta / (1 - delta)) * A.sum()
        out -= (1 + delta / (1 - delta)) * A.sum()

        out *= v

        out -= dot(u, self._Kdot(u)) / (1 - delta)

        return out / 2

    def _gradient_over_both(self):
        self._update()

        v = self.v
        delta = self.delta
        Q = self._Q
        S = self._S
        A = self._A()
        AQ = ddot(A, Q, left=True)
        SQt = ddot(S, Q.T, left=True)
        BiQt = self._BiQt()
        uBiQtAK0, uBiQtAK1 = self._uBiQtAK()

        C = self._C()
        m = self.m()
        teta = self._sitelik_eta
        Q = self._Q
        As = A.sum()

        Am = A * m
        Em = Am - A * self._QBiQtAm()

        Cteta = C * teta
        Eu = Cteta - A * self._QBiQtCteta()

        u = Em - Eu
        uKu = dot(u, self._Kdot(u))
        tr1 = trace2(AQ, uBiQtAK0)
        tr2 = trace2(AQ, uBiQtAK1)

        dv = uKu / v
        dv -= (1 - delta) * trace2(AQ, SQt)
        dv -= delta * As
        dv += (1 - delta) * tr1
        dv += delta * tr2
        dv /= 2

        dd = delta / (1 - delta)
        ddelta = -tr1
        ddelta -= dd * tr2
        ddelta += trace2(AQ, ddot(BiQt, A, left=False)) * (dd + 1)
        ddelta += (dd + 1) * dot(u, u)
        ddelta += trace2(AQ, SQt)
        ddelta -= As
        ddelta *= v
        ddelta -= uKu / (1 - delta)
        ddelta /= 2

        v = asarray([dv, ddelta])

        if not is_all_finite(v):
            raise ValueError("LML gradient should not be %s." % str(v))

        return v

    @cachedmethod(attrgetter('_cache_update'))
    def _update(self):
        self._init_ep_params()

        self._logger.debug('EP loop has started.')

        pttau = self._previous_sitelik_tau
        pteta = self._previous_sitelik_eta

        ttau = self._sitelik_tau
        teta = self._sitelik_eta

        jtau = self._joint_tau
        jeta = self._joint_eta

        ctau = self._cav_tau
        ceta = self._cav_eta

        i = 0
        while i < MAX_EP_ITER:
            pttau[:] = ttau
            pteta[:] = teta

            ctau[:] = jtau - ttau
            ceta[:] = jeta - teta
            self._tilted_params()

            if not all(isfinite(self._hvar)) or any(self._hvar == 0.):
                raise Exception('Error: not all(isfinite(hsig2))' +
                                ' or any(hsig2 == 0.).')

            self._sitelik_update()
            self._cache_lml_components.clear()
            self._cache_L.clear()
            self._cache_A.clear()
            self._cache_C.clear()
            self._cache_BiQt.clear()
            self._cache_QBiQtAm.clear()
            self._cache_QBiQtCteta.clear()

            self._joint_update()

            tdiff = abs(pttau - ttau)
            ediff = abs(pteta - teta)
            aerr = tdiff.max() + ediff.max()

            if pttau.min() <= 0. or (0. in pteta):
                rerr = inf
            else:
                rtdiff = tdiff / abs(pttau)
                rediff = ediff / abs(pteta)
                rerr = rtdiff.max() + rediff.max()

            i += 1
            if aerr < 2 * EP_EPS or rerr < 2 * EP_EPS:
                break

        if i == MAX_EP_ITER:
            self._logger.warning('Maximum number of EP iterations has' +
                                 ' been attained.')

        self._logger.debug('EP loop has performed %d iterations.', i)

    def _joint_update(self):
        A = self._A()
        C = self._C()
        m = self.m()
        Q = self._Q
        v = self.v
        delta = self.delta
        teta = self._sitelik_eta
        jtau = self._joint_tau
        jeta = self._joint_eta
        Kteta = self._Kdot(teta)

        BiQt = self._BiQt()
        uBiQtAK0, uBiQtAK1 = self._uBiQtAK()

        jtau[:] = -dotd(Q, uBiQtAK0)
        jtau *= 1 - delta
        jtau -= delta * dotd(Q, uBiQtAK1)
        jtau *= v
        jtau += self._diagK()

        jtau[:] = 1 / jtau

        dot(Q, dot(BiQt, -A * Kteta), out=jeta)
        jeta += Kteta
        jeta += m
        jeta -= self._QBiQtAm()
        jeta *= jtau
        jtau /= C

    def _sitelik_update(self):
        hmu = self._hmu
        hvar = self._hvar
        tau = self._cav_tau
        eta = self._cav_eta
        self._sitelik_tau[:] = clip(1.0 / hvar - tau, epsilon.tiny,
                                    1 / epsilon.small)
        self._sitelik_eta[:] = hmu / hvar - eta

    def _optimal_beta_nom(self):
        A = self._A()
        C = self._C()
        teta = self._sitelik_eta
        Cteta = C * teta
        v = Cteta - A * self._QBiQtCteta()
        if not is_all_finite(v):
            raise ValueError("beta_nom should not be %s." % str(v))
        return v

    def _optimal_tbeta_denom(self):
        L = self._L()
        Q = self._Q
        AM = ddot(self._A(), self._tM, left=True)
        QBiQtAM = dot(Q, cho_solve(L, dot(Q.T, AM)))
        v = dot(self._tM.T, AM) - dot(AM.T, QBiQtAM)
        if not is_all_finite(v):
            raise ValueError("tbeta_denom should not be %s." % str(v))
        return v

    def _optimal_tbeta(self):
        self._update()

        if all(abs(self._M) < 1e-15):
            return zeros_like(self._tbeta)

        u = dot(self._tM.T, self._optimal_beta_nom())
        Z = self._optimal_tbeta_denom()

        try:
            with errstate(all='raise'):
                self._tbeta = solve(Z, u)

        except (LinAlgError, FloatingPointError):
            self._logger.warning('Failed to compute the optimal beta.' +
                                 ' Zeroing it.')
            self.__tbeta[:] = 0.

        return self.__tbeta

    def _optimize_beta(self):
        ptbeta = empty_like(self._tbeta)

        step = inf
        i = 0
        alpha = 1.0
        maxiter = 30
        while step > epsilon.small and i < maxiter:
            ptbeta[:] = self._tbeta
            self._optimal_tbeta()
            self._tbeta = clip(alpha * (self._tbeta - ptbeta) + ptbeta, -10,
                               +10)
            nstep = sum((self._tbeta - ptbeta)**2)

            if nstep > step:
                alpha /= 10
            step = nstep

            i += 1

        if i == maxiter:
            self._logger.warning('Maximum number of beta iterations has' +
                                 ' been attained.')

    @property
    def bounds(self):
        bounds = dict(v=(1e-3, 1 / epsilon.large),
                      delta=(0, 1 - epsilon.small))
        return bounds

    def _start_optimizer(self):
        bound = self.bounds
        x0 = [self.v]
        bounds = [bound['v']]

        if self._overdispersion:
            klass = FunCostOverdispersion
            x0 += [self.delta]
            bounds += [bound['delta']]
        else:
            klass = FunCost

        return (klass, x0, bounds)

    def _finish_optimizer(self, x):
        self.v = x[0]
        if self._overdispersion:
            self.delta = x[1]

        self._optimize_beta()

    def learn(self, progress=True):
        self._logger.debug("Start of optimization.")
        progress = tqdm(desc='EP', disable=not progress)

        (klass, x0, bounds) = self._start_optimizer()

        start = time()
        with progress as pbar:
            func = klass(self, pbar)
            x = fmin_tnc(func, x0, bounds=bounds, **_magic_numbers)[0]

        self._finish_optimizer(x)

        msg = "End of optimization (%.3f seconds, %d function calls)."
        self._logger.debug(msg, time() - start, func.nfev)

    @cachedmethod(attrgetter('_cache_A'))
    def _A(self):
        r"""Returns :math:`\mathcal A = \tilde{\mathrm T} \mathcal C^{-1}`."""
        ttau = self._sitelik_tau
        s2 = self.sigma2_epsilon
        return ttau / (ttau * s2 + 1)

    @cachedmethod(attrgetter('_cache_C'))
    def _C(self):
        r"""Returns :math:`\mathcal C = \sigma_{\epsilon}^2 \tilde{\mathrm T} +
            \mathrm I`."""
        ttau = self._sitelik_tau
        s2 = self.sigma2_epsilon
        return 1 / (ttau * s2 + 1)

    @cachedmethod(attrgetter('_cache_SQt'))
    def _SQt(self):
        r"""Returns :math:`\mathrm S \mathrm Q^\intercal`."""
        return ddot(self._S, self._Q.T, left=True)

    def _QSQt(self):
        r"""Returns :math:`\mathrm Q \mathrm S \mathrm Q^\intercal`."""
        if self.__QSQt is None:
            Q = self._Q
            self.__QSQt = dot(Q, self._SQt())
        return self.__QSQt

    @cachedmethod(attrgetter('_cache_BiQt'))
    def _BiQt(self):
        Q = self._Q
        return cho_solve(self._L(), Q.T)

    @cachedmethod(attrgetter('_cache_L'))
    def _L(self):
        r"""Returns the Cholesky factorization of :math:`\mathcal B`.

        .. math::

            \mathcal B = \mathrm Q^{\intercal}\mathcal A\mathrm Q
                (\sigma_b^2 \mathrm S)^{-1}
        """
        Q = self._Q
        A = self._A()
        B = dot(Q.T, ddot(A, Q, left=True))
        sum2diag(B, 1. / (self.sigma2_b * self._S), out=B)
        return cho_factor(B, lower=True)[0]

    @cachedmethod(attrgetter('_cache_QBiQtCteta'))
    def _QBiQtCteta(self):
        Q = self._Q
        L = self._L()
        C = self._C()
        teta = self._sitelik_eta
        return dot(Q, cho_solve(L, dot(Q.T, C * teta)))

    @cachedmethod(attrgetter('_cache_QBiQtAm'))
    def _QBiQtAm(self):
        Q = self._Q
        L = self._L()
        A = self._A()
        m = self.m()

        return dot(Q, cho_solve(L, dot(Q.T, A * m)))

    def _uBiQtAK(self):
        BiQt = self._BiQt()
        S = self._S
        Q = self._Q
        BiQtA = ddot(BiQt, self._A(), left=False)
        BiQtAQS = dot(BiQtA, Q)
        ddot(BiQtAQS, S, left=False, out=BiQtAQS)

        return dot(BiQtAQS, Q.T), BiQtA

    def covariance(self):
        K = self.K()
        return sum2diag(K, 1 / self._sitelik_tau)

    def get_normal_likelihood_trick(self):
        # Covariance: nK = K + \tilde\Sigma, with \tilde\Sigma = diag(1 / self._sitelik_tau),
        # inverted via (K + \tilde\Sigma)^{-1} = A1 - A1 Q B1^{-1} Q^T A1.
        # Mean: \mathbf m
        # Pseudo-phenotype: \tilde\mu = \tilde\eta / \tilde\tau
        #
        # I.e.: \tilde\mu \sim N(\mathbf m, K + \tilde\Sigma)
        #
        # We transform the above Normal into an equivalent but more robust
        # one, \tilde y \sim N(\tilde m, \tilde K), where
        #
        # \tilde y = \tilde\Sigma^{-1} \tilde\mu = \tilde\eta
        # \tilde m = \tilde\Sigma^{-1} \mathbf m
        # \tilde K = \tilde\Sigma^{-1} K \tilde\Sigma^{-1} + \tilde\Sigma^{-1}

        m = self.m()
        ttau = self._sitelik_tau
        teta = self._sitelik_eta

        # NEW PHENOTYPE
        y = teta.copy()

        # NEW MEAN
        m = ttau * m

        # NEW COVARIANCE
        K = self.K()
        K = ddot(ttau, ddot(K, ttau, left=False), left=True)
        sum2diag(K, ttau, out=K)
        (Q, S0) = economic_qs(K)
        Q0, Q1 = Q

        from ...lmm import FastLMM
        from numpy import newaxis

        fastlmm = FastLMM(y, Q0, Q1, S0, covariates=m[:, newaxis])
        fastlmm.learn(progress=False)
        return fastlmm.get_normal_likelihood_trick()

    def _paolo(self):
        tK = self.covariance()
        tmu = self._sitelik_eta / self._sitelik_tau
        tS = 1 / self._sitelik_tau
        sigg2 = self.sigma2_b
        sige2 = self.sigma2_epsilon
        return dict(tK=tK, tmu=tmu, tS=tS, sigg2=sigg2, sige2=sige2)
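Example #4 uses cachetools differently: each derived quantity gets its own LRUCache(maxsize=1), so @cachedmethod acts as lazily computed, explicitly invalidated memoization, and the property setters (v, delta, _tbeta, M) clear exactly the caches that depend on the changed parameter. A small sketch of that invalidation pattern (the Model class, K, diagK and scale are made up for illustration):

import operator

from cachetools import LRUCache, cachedmethod


class Model:
    """Hypothetical sketch of EP's one-cache-per-quantity invalidation."""

    def __init__(self, scale):
        self._scale = scale
        # maxsize=1: each cache memoizes a single derived quantity until it
        # is explicitly invalidated.
        self._cache_K = LRUCache(maxsize=1)
        self._cache_diagK = LRUCache(maxsize=1)

    @cachedmethod(operator.attrgetter('_cache_K'))
    def K(self):
        print('recomputing K')
        return [[self._scale, 0.0], [0.0, self._scale]]

    @cachedmethod(operator.attrgetter('_cache_diagK'))
    def diagK(self):
        print('recomputing diagK')
        return [self._scale, self._scale]

    @property
    def scale(self):
        return self._scale

    @scale.setter
    def scale(self, value):
        # Clear every cached quantity that depends on the parameter,
        # mirroring the v/delta setters above.
        self._cache_K.clear()
        self._cache_diagK.clear()
        self._scale = value


m = Model(2.0)
m.K()
m.K()          # cached
m.scale = 3.0  # setter invalidates both caches
m.K()          # recomputed with the new parameter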
Example #5
class CodeStorageMongo(REIL.CodeStorageMem):

    # mongodb host
    DEF_HOST = '127.0.0.1'
    DEF_PORT = 27017

    # default database name
    DEF_DB = 'openreil'

    # index for instructions collection
    INDEX = [('addr', pymongo.ASCENDING), ('inum', pymongo.ASCENDING)]

    CACHE_SIZE = 1024

    def __init__(self, arch, collection, db = None, host = None, port = None):
        
        self.arch = arch
        self.db_name = self.DEF_DB if db is None else db
        self.collection_name = collection

        self.host = self.DEF_HOST if host is None else host
        self.port = self.DEF_PORT if port is None else port

        # instructions cache
        self.cache = LRUCache(maxsize = self.CACHE_SIZE)

        # connect to the server
        self.client = pymongo.Connection(self.host, self.port)
        self.db = self.client[self.db_name]

        # get collection        
        self.collection = self.db[self.collection_name]
        self.collection.ensure_index(self.INDEX)     

    def __iter__(self):

        for item in self.collection.find().sort(self.INDEX): 

            yield REIL.Insn(self._insn_from_item(item))

    def _insn_to_item(self, insn):

        insn = REIL.Insn(insn)        

        def _arg_in(arg):

            if arg.type == REIL.A_NONE:  

                return ()
            
            elif arg.type == REIL.A_CONST:

                return ( arg.type, arg.size, _U64IN(arg.val) )
            
            else:

                return ( arg.type, arg.size, arg.name )

        if insn.has_attr(REIL.IATTR_BIN):

            # JSON doesn't support binary data
            insn.set_attr(REIL.IATTR_BIN, base64.b64encode(insn.get_attr(REIL.IATTR_BIN)))

        # JSON doesn't support numeric keys
        attr = [ (key, val) for key, val in insn.attr.items() ]

        return {

            'addr': _U64IN(insn.addr), 'size': insn.size, 'inum': insn.inum, 'op': insn.op, \
            'a': _arg_in(insn.a), 'b': _arg_in(insn.b), 'c': _arg_in(insn.c), \
            'attr': attr
        }

    def _insn_from_item(self, item):

        attr, attr_dict = item['attr'], {}

        def _arg_out(arg):

            if len(arg) == 0: 

                return ()

            elif REIL.Arg_type(arg) == REIL.A_CONST:

                arg = ( REIL.Arg_type(arg), REIL.Arg_size(arg), _U64OUT(REIL.Arg_val(arg)) )

            return arg                

        for key, val in attr:

            attr_dict[key] = val

        if REIL.IATTR_BIN in attr_dict:

            # get instruction binary data from base64
            attr_dict[REIL.IATTR_BIN] = base64.b64decode(attr_dict[REIL.IATTR_BIN])

        return ( 

            ( _U64OUT(item['addr']), item['size'] ), item['inum'], item['op'], \
            ( _arg_out(item['a']), _arg_out(item['b']), _arg_out(item['c']) ), \
            attr_dict
        ) 

    def _get_key(self, ir_addr):

        return { 'addr': ir_addr[0], 'inum': ir_addr[1] }

    def _find(self, ir_addr):

        return self.collection.find_one(self._get_key(ir_addr))

    def _get_insn(self, ir_addr): 

        # get item from cache
        try: return self.cache[ir_addr]
        except KeyError: pass

        # get item from collection
        insn = self._find(ir_addr)
        if insn is not None: 

            insn = self._insn_from_item(insn)

            # update cache
            self.cache[ir_addr] = insn

            return insn

        else:

            raise REIL.StorageError(*ir_addr)

    def _del_insn(self, ir_addr):

        insn = self._find(ir_addr)
        if insn is not None: 

            # remove item from collection
            self.collection.delete_many(self._get_key(ir_addr))

            # remove item from cache
            try: del self.cache[ir_addr]
            except KeyError: pass

        else:

            raise REIL.StorageError(*ir_addr)

    def _put_insn(self, insn):

        ir_addr = REIL.Insn_ir_addr(insn)

        if self._find(ir_addr) is not None:

            # update existing item
            self.collection.replace_one(self._get_key(ir_addr), self._insn_to_item(insn))

        else:

            # add a new item
            self.collection.insert_one(self._insn_to_item(insn))

        # update cache
        self.cache[ir_addr] = insn

    def size(self): 

        return self.collection.count_documents({})

    def clear(self): 

        self.cache.clear()

        # remove all items from the collection
        return self.collection.delete_many({})
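
The `_get_insn` / `_put_insn` pair above implements a read-through cache: try the `LRUCache` first, fall back to the MongoDB collection, and fill the cache on the way back. Below is a minimal, self-contained sketch of the same pattern with a plain dict standing in for the collection; all names in it are illustrative and not part of OpenREIL.

from cachetools import LRUCache

class ReadThroughStore(object):

    def __init__(self, maxsize = 1024):

        self.cache = LRUCache(maxsize = maxsize)
        self.backend = {}                   # stands in for the MongoDB collection

    def put(self, key, value):

        self.backend[key] = value           # write to the backing store
        self.cache[key] = value             # keep the cache coherent, as _put_insn() does

    def get(self, key):

        # fast path: serve from the LRU cache
        try: return self.cache[key]
        except KeyError: pass

        # slow path: hit the backing store (raises KeyError,
        # much like _get_insn() raises StorageError)
        value = self.backend[key]

        # populate the cache for the next lookup
        self.cache[key] = value

        return value

store = ReadThroughStore(maxsize = 2)
store.put((0x401000, 0), 'STR V_00:32, , R_EAX:32')
assert store.get((0x401000, 0)) == 'STR V_00:32, , R_EAX:32'
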
class MultipleRandomWalk(object):
    cache_misses = 0
    total_steps = 0

    def __init__(self,
                 alpha=0.1,
                 n=100000,
                 n_p=500,
                 n_v=4,
                 n_jobs=0,
                 use_boosting=False,
                 dynamic_allocation=False,
                 pixie_weighting=False,
                 cache_maxsize=2048):
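        # Parameter meanings inferred from the Pixie-style walk below:
        # alpha ~ restart probability, n ~ total step budget, n_p ~ how many
        # highly-visited nodes trigger early stopping, n_v ~ visit count at
        # which a node counts as highly visited.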
        self._alpha = alpha
        self._n = n
        self._n_p = n_p
        self._n_v = n_v
        self._n_jobs = n_jobs
        self._use_boosting = use_boosting
        self._dynamic_allocation = dynamic_allocation
        self._pixie_weighting = pixie_weighting
        self._cache_maxsize = cache_maxsize
        self._playlist_model = None
        self._cache = Manager().dict()
        self._query_pid = None
        self._edge_mask = None
        self.cache = LRUCache(maxsize=cache_maxsize)

    def run_random_walk(self,
                        query_tracks,
                        playlist_model=None,
                        pid=None,
                        counts_per_node=False):
        # Clear the cache, we have a new model
        self._edge_mask = None
        self.cache.clear()
        self.total_steps = 0
        self.cache_misses = 0
        self._query_pid = pid

        self._playlist_model = playlist_model
        if self._n_jobs <= 0:
            return self.pixie_random_walk_multiple(
                query_tracks, counts_per_node=counts_per_node)
        else:
            return self.pixie_random_walk_multiple_parallel(
                query_tracks, counts_per_node=counts_per_node)

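    # cachedmethod memoizes on the per-instance LRUCache created in __init__,
    # so the bodies below (and the miss counter) only run on a cache miss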
    @cachetools.cachedmethod(operator.attrgetter('cache'))
    def get_neighborhood(self, node):
        self.cache_misses += 1
        return shared.graph.neighbors_fast(node, self._edge_mask)

    @cachetools.cachedmethod(operator.attrgetter('cache'))
    def get_transition_probabilities(self, playlist, query_pid):
        self.cache_misses += 1
        neighbors = shared.graph.neighbors(playlist, only_enabled=True)
        if len(neighbors) > 0:
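            # track node IDs appear to be offset by 1,000,000 in the shared
            # graph; subtracting maps them onto rows of the feature matrix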
            track_features = shared.graph.features[neighbors - 1000000]
            scores = self._playlist_model.score(track_features)
            return scores, neighbors
        else:
            return None, None

    def get_transition_probabilities_nocache(self, playlist, query_pid):
        self.cache_misses += 1
        neighbors = shared.graph.neighbors(playlist, only_enabled=True)
        if len(neighbors) > 0:
            track_features = shared.graph.features[neighbors - 1000000]
            scores = self._playlist_model.score(track_features)
            return scores, neighbors
        else:
            return None, None

    def pixie_random_walk(self, q, edge_mask, n=None):
        self._edge_mask = edge_mask
        if n is None:
            n = self._n
        visit_count = Counter()
        total_steps = 0
        n_high_visited = 0

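        # Pixie-style termination: stop once the step budget n is spent or
        # n_p tracks have each been visited at least n_v times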
        while total_steps < n and n_high_visited < self._n_p:
            # Restart to the query track
            curr_track = q
            sample_steps = np.random.geometric(self._alpha)
            for i in range(sample_steps):
                curr_playlist = shared.graph.sample_neighbor_fast(
                    curr_track, edge_mask)
                # Reached a dead end
                if curr_playlist is None:
                    if total_steps == 0:
                        # The query node does not have any connections
                        return Counter(), False
                    else:
                        break

                if self._playlist_model is None:
                    curr_track = shared.graph.sample_neighbor_fast(
                        curr_playlist, edge_mask)
                else:
                    p, neighbors = self.get_transition_probabilities(
                        curr_playlist, self._query_pid)
                    if neighbors is not None and len(neighbors) > 0:
                        curr_track = random.choices(neighbors, weights=p)[0]
                    else:
                        curr_track = None
                # Reached a dead end
                if curr_track is None:
                    break

                visit_count[curr_track] += 1

                if visit_count[curr_track] == self._n_v:
                    n_high_visited += 1

                total_steps += 1
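                # the instance-wide counter advances by two, presumably because
                # each step traverses two edges (track -> playlist -> track)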
                self.total_steps += 2
                if total_steps >= self._n or n_high_visited >= self._n_p:
                    break
        return visit_count, n_high_visited >= self._n_p

    def pixie_random_walk_multiple(self, query_tracks, counts_per_node=False):
        scaling_factors = [
            scaling_factor(q, self._pixie_weighting) for q in query_tracks
        ]
        summed_scaling_factors = np.sum(scaling_factors)
        visit_counts = []
        boosted_visit_counts = {}
        total_random_walks = len(query_tracks)
        early_stopped_count = 0.0

        for q, s in zip(query_tracks, scaling_factors):
            n_q = self._n * s / summed_scaling_factors if self._dynamic_allocation else int(
                self._n / len(query_tracks))
            visit_q, early_stopped = self.pixie_random_walk(
                q, shared.graph.edge_mask, n=n_q)
            if early_stopped:
                early_stopped_count += 1.0
            visit_counts.append(visit_q)

        print("misses = {}, hits = {}, total steps = {}".format(
            self.cache_misses, self.total_steps - self.cache_misses,
            self.total_steps))

        if counts_per_node:
            return visit_counts, early_stopped_count, total_random_walks

        for v in visit_counts:
            for track, visit_count in v.items():
                if self._use_boosting:
                    if track in boosted_visit_counts:
                        boosted_visit_counts[track] += np.sqrt(visit_count)
                    else:
                        boosted_visit_counts[track] = np.sqrt(visit_count)
                else:
                    if track in boosted_visit_counts:
                        boosted_visit_counts[track] += visit_count
                    else:
                        boosted_visit_counts[track] = visit_count

        if self._use_boosting:
            boosted_visit_counts = {
                k: v**2
                for k, v in boosted_visit_counts.items()
            }

        return boosted_visit_counts, early_stopped_count, total_random_walks

    def worker_random_walk(self, q, summed_scaling_factors, n, pid):
        self.cache.clear()
        if self._dynamic_allocation:
            n_q = n * scaling_factor(
                q, self._pixie_weighting) / summed_scaling_factors
        else:
            n_q = n
        edge_mask = np.frombuffer(shared.arr_enabled, dtype='bool')
        return self.pixie_random_walk(q, edge_mask, n=n_q)

    def pixie_random_walk_multiple_parallel(self,
                                            query_tracks,
                                            counts_per_node=False):
        scaling_factors = [
            scaling_factor(q, self._pixie_weighting) for q in query_tracks
        ]
        summed_scaling_factors = np.sum(scaling_factors)
        boosted_visit_counts = {}
        total_random_walks = len(query_tracks)

        if self._dynamic_allocation:
            n_q = self._n
        else:
            n_q = int(self._n / len(query_tracks))

        self.cache.clear()
        results = shared.pool.map(
            functools.partial(self.worker_random_walk,
                              summed_scaling_factors=summed_scaling_factors,
                              n=n_q,
                              pid=self._query_pid), query_tracks)

        early_stopped_count = np.sum([int(e) for _, e in results])
        if counts_per_node:
            return [v for v, _ in results], early_stopped_count, total_random_walks

        for v, early_stopped in results:
            for track, visit_count in v.items():
                if self._use_boosting:
                    if track in boosted_visit_counts:
                        boosted_visit_counts[track] += np.sqrt(visit_count)
                    else:
                        boosted_visit_counts[track] = np.sqrt(visit_count)
                else:
                    if track in boosted_visit_counts:
                        boosted_visit_counts[track] += visit_count
                    else:
                        boosted_visit_counts[track] = visit_count

        if self._use_boosting:
            boosted_visit_counts = {
                k: v**2
                for k, v in boosted_visit_counts.items()
            }

        return boosted_visit_counts, early_stopped_count, total_random_walks
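
The decorated methods above rely on `cachetools.cachedmethod` with `operator.attrgetter('cache')`, which memoizes per instance on the same `LRUCache` that `run_random_walk` clears when a new model arrives. Here is a small, self-contained sketch of that pattern; the graph is faked with a dict and `NeighborhoodCache` with its names is illustrative only.

import operator
from cachetools import LRUCache, cachedmethod

class NeighborhoodCache(object):
    def __init__(self, graph, maxsize=2048):
        self.graph = graph                      # e.g. {node: [neighbor, ...]}
        self.cache = LRUCache(maxsize=maxsize)
        self.cache_misses = 0

    @cachedmethod(operator.attrgetter('cache'))
    def get_neighborhood(self, node):
        # the body only runs on a cache miss, so the miss counter stays accurate
        self.cache_misses += 1
        return self.graph[node]

nc = NeighborhoodCache({'a': ['b', 'c']})
nc.get_neighborhood('a')
nc.get_neighborhood('a')        # second call is served from the LRU cache
assert nc.cache_misses == 1

nc.cache.clear()                # a new model invalidates everything, as in run_random_walk
nc.get_neighborhood('a')
assert nc.cache_misses == 2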