Example #1
def _subsample(counts, n, replace=False, seed=0):
    """Randomly subsample from a vector of counts.

    Parameters
    ----------
    counts : 1-D array_like
        Vector of counts.
    n : int
        Number of elements to subsample (<= the total number of counts).
    replace : bool, optional
        Subsample with or without replacement.
    seed : int, optional
        Random seed.

    Returns
    -------
    subcounts : 1-D ndarray
        Subsampled vector of counts

    Raises
    ------
    ValueError, TypeError
    """

    if n < 0:
        raise ValueError("'n' must be >= 0")

    counts = np.asarray(counts)

    if counts.ndim != 1:
        raise ValueError("counts must be a 1-D array_like object")

    counts = counts.astype(int, casting='safe')
    counts_sum = counts.sum()

    if n > counts_sum:
        raise ValueError("'n' must be <= the total number of counts")

    prng = RandomState(seed)

    if replace:
        p = counts / counts_sum
        subcounts = prng.multinomial(n, p)
    else:
        nonzero = np.flatnonzero(counts)
        expanded = np.concatenate([np.repeat(i, counts[i]) for i in nonzero])
        permuted = prng.permutation(expanded)[:n]
        subcounts = np.bincount(permuted, minlength=counts.size)

    return subcounts
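A minimal usage sketch for the function above (the count values are made up), assuming NumPy is installed and _subsample is in scope:

import numpy as np

counts = np.array([5, 0, 3, 2])              # made-up count vector (total = 10)
sub = _subsample(counts, n=6, replace=False, seed=42)
print(sub, sub.sum())                        # same length as counts; sums to 6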
Example #2
def _collect_pathways_by_clusters(cluster_distribution: NDArray) -> List[List[PathwayList]]:
    pathways_by_clusters = []
    for i, cluster_size in enumerate(cluster_distribution):
        sources, transition_probabilities_by_sources = obtain_sources_and_transition_probabilities(i, START_LOCATION,
                                                                                                   FINISH_LOCATION)
        cluster_pathways = [[START_LOCATION] for _ in range(cluster_size)]
        current_sources = set(sources[:])
        while not _check_all_pathways_is_over(current_sources):
            for source in current_sources:
                if source == FINISH_LOCATION:
                    continue
                transitions_probabilities = transition_probabilities_by_sources[source]
                selected_pathways = [pathway for pathway in cluster_pathways if pathway[-1] == source]
                random = RandomState()
                target_distribution = random.multinomial(len(selected_pathways), transitions_probabilities, size=1)[0]
                for source_index, target_size in enumerate(target_distribution):
                    for _ in range(target_size):
                        pathway = selected_pathways.pop()
                        pathway.append(sources[source_index])
            current_sources = {pathway[-1] for pathway in cluster_pathways}
        pathways_by_clusters.append(cluster_pathways)
    return pathways_by_clusters
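The key step above is the multinomial draw that splits the pathways currently sitting at a source among the candidate next locations. A self-contained sketch of just that step, with made-up transition probabilities:

from numpy.random import RandomState

random = RandomState(0)
transition_probabilities = [0.5, 0.3, 0.2]   # hypothetical values
# Distribute 10 pathways over 3 candidate next locations; the counts sum to 10.
target_distribution = random.multinomial(10, transition_probabilities, size=1)[0]
print(target_distribution, target_distribution.sum())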
Example #3
File: hlda.py Project: wangjs/hlda
class NCRPNode(object):
    
    # class variable to keep track of total nodes created so far
    total_nodes = 0
    last_node_id = 0
    
    def __init__(self, num_levels, vocab, parent=None, level=0, 
                 random_state=None):

        self.node_id = NCRPNode.last_node_id
        NCRPNode.last_node_id += 1
        
        self.customers = 0
        self.parent = parent
        self.children = []
        self.level = level        
        self.total_words = 0
        self.num_levels = num_levels

        self.vocab = np.array(vocab)
        self.word_counts = np.zeros(len(vocab))
                
        if random_state is None:
            self.random_state = RandomState()
        else:
            self.random_state = random_state  
            
    def __repr__(self):
        parent_id = None
        if self.parent is not None:
            parent_id = self.parent.node_id
        return 'Node=%d level=%d customers=%d total_words=%d parent=%s' % (self.node_id, 
            self.level, self.customers, self.total_words, parent_id)    
                        
    def add_child(self):
        ''' Adds a child to the next level of this node '''
        node = NCRPNode(self.num_levels, self.vocab, parent=self, level=self.level+1)
        self.children.append(node)
        NCRPNode.total_nodes += 1
        return node

    def is_leaf(self):
        ''' Check if this node is a leaf node '''
        return self.level == self.num_levels-1
    
    def get_new_leaf(self):
        ''' Keeps adding nodes along the path until a leaf node is generated'''
        node = self
        for l in range(self.level, self.num_levels-1):
            node = node.add_child()
        return node
            
    def drop_path(self):
        ''' Removes a document from a path starting from this node '''
        node = self
        node.customers -= 1
        if node.customers == 0:
            node.parent.remove(node)
        for level in range(1, self.num_levels): # skip the root
            node = node.parent
            node.customers -= 1
            if node.customers == 0:
                node.parent.remove(node)

    def remove(self, node):
        ''' Removes a child node '''
        self.children.remove(node)
        NCRPNode.total_nodes -= 1
        
    def add_path(self):
        ''' Adds a document to a path starting from this node '''
        node = self
        node.customers += 1
        for level in range(1, self.num_levels):
            node = node.parent
            node.customers += 1

    def select(self, gamma):
        ''' Selects an existing child or creates a new one according to the CRP '''
        
        weights = np.zeros(len(self.children)+1)
        weights[0] = float(gamma) / (gamma+self.customers)
        i = 1
        for child in self.children:
            weights[i] = float(child.customers) / (gamma + self.customers)
            i += 1

        choice = self.random_state.multinomial(1, weights).argmax()
        if choice == 0:
            return self.add_child()
        else:
            return self.children[choice-1]   
                
    def get_top_words(self, n_words, with_weight):  
        ''' Get the top n words in this node '''   

        pos = np.argsort(self.word_counts)[::-1]
        sorted_vocab = self.vocab[pos]
        sorted_vocab = sorted_vocab[:n_words]
        sorted_weights = self.word_counts[pos]
        sorted_weights = sorted_weights[:n_words]
        
        output = ''
        for word, weight in zip(sorted_vocab, sorted_weights):
            if with_weight:
                output += '%s (%d), ' % (word, weight)
            else:
                output += '%s, ' % word                
        return output            
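A minimal sketch of the CRP-style child selection performed by select(), assuming the NCRPNode class above is in scope; the vocabulary and customer counts are made up:

from numpy.random import RandomState

vocab = ['w0', 'w1', 'w2']
root = NCRPNode(num_levels=3, vocab=vocab, random_state=RandomState(1))
child = root.add_child()          # one existing child under the root
root.customers = 4                # pretend four documents sit at the root...
child.customers = 4               # ...and all of them chose this child
picked = root.select(gamma=1.0)   # existing child w.p. 4/5, brand-new child w.p. 1/5
print(picked is child, NCRPNode.total_nodes)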
Example #4
File: hlda.py Project: wangjs/hlda
class HierarchicalLDA(object):
    
    def __init__(self, corpus, vocab, 
                 alpha=10.0, gamma=1.0, eta=0.1, 
                 seed=0, verbose=True, num_levels=3):
        
        NCRPNode.total_nodes = 0
        NCRPNode.last_node_id = 0        
        
        self.corpus = corpus
        self.vocab = vocab
        self.alpha = alpha  # smoothing on doc-topic distributions
        self.gamma = gamma  # "imaginary" customers at the next, as yet unused table
        self.eta = eta      # smoothing on topic-word distributions

        self.seed = seed
        self.random_state = RandomState(seed)        
        self.verbose = verbose

        self.num_levels = num_levels
        self.num_documents = len(corpus)
        self.num_types = len(vocab)
        self.eta_sum = eta * self.num_types

        # if self.verbose:        
        #     for d in range(len(self.corpus)):
        #         doc = self.corpus[d]
        #         words = ' '.join([self.vocab[n] for n in doc])
        #         print 'doc_%d = %s' % (d, words)  
        
        # initialise a single path
        path = np.zeros(self.num_levels, dtype=object)
        
        # initialize and fill the topic pointer arrays for 
        # every document. Set everything to the single path that 
        # we added earlier.
        self.root_node = NCRPNode(self.num_levels, self.vocab)
        self.document_leaves = {}                                   # currently selected path (ie leaf node) through the NCRP tree
        self.levels = np.zeros(self.num_documents, dtype=object) # indexed < doc, token >
        for d in range(len(self.corpus)):
            
            # populate nodes into the path of this document
            doc = self.corpus[d]
            doc_len = len(doc)
            path[0] = self.root_node
            self.root_node.customers += 1 # always add to the root node first
            for level in range(1, self.num_levels):
                # at each level, a node is selected by its parent node based on the CRP prior
                parent_node = path[level-1]
                level_node = parent_node.select(self.gamma)
                level_node.customers += 1
                path[level] = level_node
                
            # set the leaf node for this document                 
            leaf_node = path[self.num_levels-1]
            self.document_leaves[d] = leaf_node
                        
            # randomly assign each word in the document to a level (node) along the path
            self.levels[d] = np.zeros(doc_len, dtype=int)
            for n in range(doc_len):
                w = doc[n]                
                random_level = self.random_state.randint(self.num_levels)
                random_node = path[random_level]
                random_node.word_counts[w] += 1
                random_node.total_words += 1
                self.levels[d][n] = random_level                

    def estimate(self, num_samples, display_topics=50, n_words=5, with_weights=True):
        
        print('HierarchicalLDA sampling')
        for s in range(num_samples):
            
            sys.stdout.write('.')
            
            for d in range(len(self.corpus)):
                self.sample_path(d)
            
            for d in range(len(self.corpus)):
                self.sample_topics(d)
                
            if (s > 0) and ((s+1) % display_topics == 0):
                print()
                self.print_nodes(n_words, with_weights)

    def sample_path(self, d):
        
        # define a path starting from the leaf node of this doc
        path = np.zeros(self.num_levels, dtype=object)
        node = self.document_leaves[d]
        for level in range(self.num_levels-1, -1, -1): # e.g. [3, 2, 1, 0] for num_levels = 4
            path[level] = node
            node = node.parent
            
        # remove this document from the path, deleting empty nodes if necessary
        self.document_leaves[d].drop_path()
        
        ############################################################
        # calculates the prior p(c_d | c_{-d}) in eq. (4)
        ############################################################

        node_weights = {}
        self.calculate_ncrp_prior(node_weights, self.root_node, 0.0)
        
        ############################################################
        # calculates the likelihood p(w_d | c, w_{-d}, z) in eq. (4)
        ############################################################

        level_word_counts = {}
        for level in range(self.num_levels):
            level_word_counts[level] = {}        
        doc_levels = self.levels[d]
        doc = self.corpus[d]
        
        # remove doc from path
        for n in range(len(doc)): # for each word in the doc
            
            # count the word at each level
            level = doc_levels[n]
            w = doc[n]
            if w not in level_word_counts[level]:
                level_word_counts[level][w] = 1
            else:
                level_word_counts[level][w] += 1

            # remove word count from the node at that level
            level_node = path[level]
            level_node.word_counts[w] -= 1
            level_node.total_words -= 1
            assert level_node.word_counts[w] >= 0
            assert level_node.total_words >= 0

        self.calculate_doc_likelihood(node_weights, level_word_counts)

        ############################################################
        # pick a new path
        ############################################################

        nodes = np.array(list(node_weights.keys()))
        weights = np.array([node_weights[node] for node in nodes])
        weights = np.exp(weights - np.max(weights)) # normalise so the largest weight is 1
        weights = weights / np.sum(weights)

        choice = self.random_state.multinomial(1, weights).argmax()
        node = nodes[choice]
        
        # if we picked an internal node, we need to add a new path to the leaf
        if not node.is_leaf():
            node = node.get_new_leaf()

        # add the doc back to the path
        node.add_path()                     # add a customer to the path
        self.document_leaves[d] = node      # store the leaf node for this doc

        # add the words
        for level in range(self.num_levels-1, -1, -1): # e.g. [3, 2, 1, 0] for num_levels = 4
            word_counts = level_word_counts[level]
            for w in word_counts:
                node.word_counts[w] += word_counts[w]
                node.total_words += word_counts[w]
            node = node.parent        
        
    def calculate_ncrp_prior(self, node_weights, node, weight):
        ''' Calculates the prior on the path according to the nested CRP '''

        for child in node.children:
            child_weight = log( float(child.customers) / (node.customers + self.gamma) )
            self.calculate_ncrp_prior(node_weights, child, weight + child_weight)
        
        node_weights[node] = weight + log( self.gamma / (node.customers + self.gamma))

    def calculate_doc_likelihood(self, node_weights, level_word_counts):

        # calculate the weight for a new path at a given level
        new_topic_weights = np.zeros(self.num_levels)
        for level in range(1, self.num_levels):  # skip the root

            word_counts = level_word_counts[level]
            total_tokens = 0

            for w in word_counts:
                count = word_counts[w]
                for i in range(count):  # one Polya-urn (Dirichlet-multinomial) factor per occurrence of w
                    new_topic_weights[level] += log((self.eta + i) / (self.eta_sum + total_tokens))
                    total_tokens += 1

        self.calculate_word_likelihood(node_weights, self.root_node, 0.0, level_word_counts, new_topic_weights, 0)

    def calculate_word_likelihood(self, node_weights, node, weight, level_word_counts, new_topic_weights, level):
                
        # first calculate the likelihood of the words at this level, given this topic
        node_weight = 0.0
        word_counts = level_word_counts[level]
        total_words = 0
        
        for w in word_counts:
            count = word_counts[w]
            for i in range(count):  # one Polya-urn factor per occurrence of w
                node_weight += log( (self.eta + node.word_counts[w] + i) / 
                                    (self.eta_sum + node.total_words + total_words) )
                total_words += 1
                
        # propagate that weight to the child nodes
        for child in node.children:
            self.calculate_word_likelihood(node_weights, child, weight + node_weight, 
                                           level_word_counts, new_topic_weights, level+1)
            
        # finally if this is an internal node, add the weight of a new path
        level += 1
        while level < self.num_levels:
            node_weight += new_topic_weights[level]
            level += 1
            
        node_weights[node] += node_weight
        
    def sample_topics(self, d):

        doc = self.corpus[d]
        
        # initialise level counts
        doc_levels = self.levels[d]
        level_counts = np.zeros(self.num_levels, dtype=int)
        for c in doc_levels:
            level_counts[c] += 1

        # get the leaf node and populate the path
        path = np.zeros(self.num_levels, dtype=object)
        node = self.document_leaves[d]
        for level in range(self.num_levels-1, -1, -1): # e.g. [3, 2, 1, 0] for num_levels = 4
            path[level] = node
            node = node.parent

        # sample a new level for each word
        level_weights = np.zeros(self.num_levels)            
        for n in range(len(doc)):

            w = doc[n]            
            word_level = doc_levels[n]

            # remove from model
            level_counts[word_level] -= 1
            node = path[word_level]
            node.word_counts[w] -= 1
            node.total_words -= 1

            # pick new level
            for level in range(self.num_levels):
                level_weights[level] = (self.alpha + level_counts[level]) *                     \
                    (self.eta + path[level].word_counts[w]) /                                   \
                    (self.eta_sum + path[level].total_words)
            level_weights = level_weights / np.sum(level_weights)
            level = self.random_state.multinomial(1, level_weights).argmax()
            
            # put the word back into the model
            doc_levels[n] = level
            level_counts[level] += 1
            node = path[level]
            node.word_counts[w] += 1
            node.total_words += 1
        
    def print_nodes(self, n_words, with_weights):
        self.print_node(self.root_node, 0, n_words, with_weights)
        
    def print_node(self, node, indent, n_words, with_weights):
        out = '    ' * indent
        out += 'topic %d (level=%d, total_words=%d, documents=%d): ' % (node.node_id, node.level, node.total_words, node.customers)
        out += node.get_top_words(n_words, with_weights)
        print(out)
        for child in node.children:
            self.print_node(child, indent+1, n_words, with_weights)        
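A hypothetical end-to-end sketch, assuming the NCRPNode and HierarchicalLDA classes above are importable together with their module-level imports (numpy, sys, and log from the math module). Each document in the toy corpus is a list of indices into vocab:

vocab = ['cat', 'dog', 'fish', 'bird']
corpus = [[0, 1, 1, 2], [2, 3, 0, 0], [1, 3, 3, 2]]   # made-up documents
hlda = HierarchicalLDA(corpus, vocab, alpha=10.0, gamma=1.0, eta=0.1,
                       seed=0, verbose=True, num_levels=3)
hlda.estimate(num_samples=50, display_topics=25, n_words=3, with_weights=True)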
Example #5
class HDHProcess:
    def __init__(self,
                 num_patterns,
                 alpha_0,
                 mu_0,
                 vocabulary,
                 omega=1,
                 doc_length=20,
                 doc_min_length=5,
                 words_per_pattern=10,
                 random_state=None):
        """
        Parameters
        ----------
        num_patterns : int
            The maximum number of patterns that will be shared across
            the users.

        alpha_0 : tuple
            The parameter that is used when sampling the time kernel weights
            of each pattern. The distribution that is being used is a Gamma.
            This tuple should be of the form (shape, scale).

        mu_0 : tuple
            The parameter of the Gamma distribution that is used to sample
            each user's \mu (activity level). This tuple should be of the
            form (shape, scale).

        vocabulary : list
            The list of available words to use when generating documents.

        omega : float, default is 1
            The decay parameter of the exponential decay kernel.

        doc_length : int, default is 20
            The maximum number of words per document.

        doc_min_length : int, default is 5
            The minimum number of words per document.

        words_per_pattern: int, default is 10
            The number of words that will have a non-zero probability to appear
            in each pattern.

        random_state: int or RandomState object, default is None
            The random number generator.
        """
        self.prng = check_random_state(random_state)
        self.doc_prng = RandomState(self.prng.randint(200000000))
        self.time_kernel_prng = RandomState(self.prng.randint(2000000000))
        self.pattern_param_prng = RandomState(self.prng.randint(2000000000))

        self.num_patterns = num_patterns
        self.alpha_0 = alpha_0
        self.vocabulary = vocabulary
        self.mu_0 = mu_0
        self.document_length = doc_length
        self.document_min_length = doc_min_length
        self.omega = omega
        self.words_per_pattern = words_per_pattern
        self.pattern_params = self.sample_pattern_params()
        self.time_kernels = self.sample_time_kernels()
        self.pattern_popularity = self.sample_pattern_popularity()

        # Initialize all the counters etc.
        self.reset()

    def reset(self):
        """Removes all the events and users already sampled.


        Note
        ----
        It does not reseed the random number generator. It also retains the
        already sampled pattern parameters (word distributions and alphas)
        """
        self.mu_per_user = {}
        self.num_users = 0
        self.time_history = []
        self.time_history_per_user = {}
        self.table_history_per_user = {}
        self.dish_on_table_per_user = {}
        self.dish_counters = defaultdict(int)
        self.last_event_user_pattern = defaultdict(lambda: defaultdict(int))
        self.total_tables = 0
        self.first_observed_time = {}
        self.user_table_cache = defaultdict(dict)
        self.table_history_per_user = defaultdict(list)
        self.time_history_per_user = defaultdict(list)
        self.document_history_per_user = defaultdict(list)
        self.dish_on_table_per_user = defaultdict(dict)
        self.cache_per_user = defaultdict(dict)
        self.total_tables_per_user = defaultdict(int)
        self.events = []
        self.per_pattern_word_counts = defaultdict(lambda: defaultdict(int))
        self.per_pattern_word_count_total = defaultdict(int)

    def sample_pattern_params(self):
        """Returns the word distributions for each pattern.


        Returns
        -------
        parameters : dict
            A dictionary mapping each pattern to its word distribution.
        """
        sampled_params = {}
        V = len(self.vocabulary)
        for pattern in range(self.num_patterns):
            custom_theta = [0] * V
            words_in_pattern = self.prng.choice(V,
                                                size=self.words_per_pattern,
                                                replace=False)
            for word in words_in_pattern:
                custom_theta[word] = 100. / self.words_per_pattern
            sampled_params[pattern] = \
                self.pattern_param_prng.dirichlet(custom_theta)
        return sampled_params

    def sample_time_kernels(self):
        """Returns the time decay parameter of each pattern.


        Returns
        -------
        alphas : dict
            A dictionary mapping each pattern to its time decay parameter.
        """
        alphas = {
            pattern: self.time_kernel_prng.gamma(self.alpha_0[0],
                                                 self.alpha_0[1])
            for pattern in range(self.num_patterns)
        }
        return alphas

    def sample_pattern_popularity(self):
        """Returns a popularity distribution over the patterns.


        Returns
        -------
        pattern_popularities : dict
            A dictionary mapping each pattern to its popularity weight; the
            weights sum to 1.
        """
        pattern_popularity = {}
        pi = self.pattern_param_prng.dirichlet(
            [1 for pattern in range(self.num_patterns)])
        for pattern_i, pi_i in enumerate(pi):
            pattern_popularity[pattern_i] = pi_i
        return pattern_popularity

    def sample_mu(self):
        """Samples a value from the prior of the base intensity mu.


        Returns
        -------
        mu_u : float
            The base intensity of a user, sampled from the prior.
        """
        return self.prng.gamma(self.mu_0[0], self.mu_0[1])

    def sample_next_time(self, pattern, user):
        """Samples the time of the next event of a pattern for a given user.


        Parameters
        ----------
        pattern : int
            The pattern index that we want to sample the next event for.

        user : int
            The index of the user that we want to sample for.


        Returns
        -------
        timestamp : float
        """
        U = self.prng.rand
        mu_u = self.mu_per_user[user]
        lambda_u_pattern = mu_u * self.pattern_popularity[pattern]

        if user in self.total_tables_per_user:
            # We have seen the user before
            num_tables = self.total_tables_per_user[user]
            pattern_tables = [
                table for table in range(num_tables)
                if self.dish_on_table_per_user[user][table] == pattern
            ]
        else:
            pattern_tables = []
        alpha = self.time_kernels[pattern]
        if not pattern_tables:
            lambda_star = lambda_u_pattern
            s = -1 / lambda_star * np.log(U())
            return s
        else:
            # Add the \alpha of the most recent table (previous event) in the
            # user-pattern intensity
            s = self.last_event_user_pattern[user][pattern]
            pattern_intensity = 0
            for table in pattern_tables:
                t_last, sum_kernels = self.user_table_cache[user][table]
                update_value = self.kernel(s, t_last)
                # update_value should be 1, so practically we are just adding
                # \alpha to the intensity dt after the event
                table_intensity = alpha * sum_kernels * update_value
                table_intensity += alpha * update_value
                pattern_intensity += table_intensity
            lambda_star = lambda_u_pattern + pattern_intensity

            # New event
            accepted = False
            while not accepted:
                s = s - 1 / lambda_star * np.log(U())
                # Rejection test
                pattern_intensity = 0
                for table in pattern_tables:
                    t_last, sum_kernels = self.user_table_cache[user][table]
                    update_value = self.kernel(s, t_last)
                    # update_value should be 1, so practically we are just adding
                    # \alpha to the intensity dt after the event
                    table_intensity = alpha * sum_kernels * update_value
                    table_intensity += alpha * update_value
                    pattern_intensity += table_intensity
                lambda_s = lambda_u_pattern + pattern_intensity
                if U() < lambda_s / lambda_star:
                    return s
                else:
                    lambda_star = lambda_s

    def sample_user_events(self,
                           min_num_events=100,
                           max_num_events=None,
                           t_max=None):
        """Samples events for a user.


        Parameters
        ----------
        min_num_events : int, default is 100
            The minimum number of events to sample.

        max_num_events : int, default is None
            If not None, this is the maximum number of events to sample.

        t_max : float, default is None
            The time limit until which to sample events.


        Returns
        -------
        events : list
            A list of the form [(t_i, doc_i, user_i, meta_i), ...] sorted by
            increasing time that has all the events of the sampled users.
            Above, doc_i is the document and meta_i is any sort of metadata
            that we want for doc_i, e.g. question_id. The generator will return
            an empty list for meta_i.
        """
        user = len(self.mu_per_user)
        mu_u = self.sample_mu()
        self.mu_per_user[user] = mu_u

        # Populate the list with the first event for each pattern
        next_time_per_pattern = [
            self.sample_next_time(pattern, user)
            for pattern in range(self.num_patterns)
        ]
        next_time_per_pattern = asfortranarray(next_time_per_pattern)

        iteration = 0
        over_tmax = False
        while iteration < min_num_events or not over_tmax:
            if max_num_events is not None and iteration > max_num_events:
                break
            z_n = next_time_per_pattern.argmin()
            t_n = next_time_per_pattern[z_n]
            if t_max is not None and t_n > t_max:
                over_tmax = True
                break
            num_tables_user = self.total_tables_per_user[user] \
                if user in self.total_tables_per_user else 0
            tables = range(num_tables_user)
            tables = [
                table for table in tables
                if self.dish_on_table_per_user[user][table] == z_n
            ]
            intensities = []

            alpha = self.time_kernels[z_n]
            for table in tables:
                t_last, sum_kernels = self.user_table_cache[user][table]
                update_value = self.kernel(t_n, t_last)
                table_intensity = alpha * sum_kernels * update_value
                table_intensity += alpha * update_value
                intensities.append(table_intensity)
            intensities.append(mu_u * self.pattern_popularity[z_n])
            log_intensities = [
                ln(inten_i) if inten_i > 0 else -float('inf')
                for inten_i in intensities
            ]

            normalizing_log_intensity = logsumexp(log_intensities)
            intensities = [
                exp(log_intensity - normalizing_log_intensity)
                for log_intensity in log_intensities
            ]
            k = weighted_choice(intensities, self.prng)

            if k < len(tables):
                # Assign to already existing table
                table = tables[k]
                # update cache for that table
                t_last, sum_kernels = self.user_table_cache[user][table]
                update_value = self.kernel(t_n, t_last)
                sum_kernels += 1
                sum_kernels *= update_value
                self.user_table_cache[user][table] = (t_n, sum_kernels)
            else:
                table = num_tables_user
                self.total_tables += 1
                self.total_tables_per_user[user] += 1
                # Since this is a new table, initialize the cache accordingly
                self.user_table_cache[user][table] = (t_n, 0)
                self.dish_on_table_per_user[user][table] = z_n

            if z_n not in self.first_observed_time or\
                    t_n < self.first_observed_time[z_n]:
                self.first_observed_time[z_n] = t_n
            self.dish_counters[z_n] += 1

            doc_n = self.sample_document(z_n)
            self._update_word_counters(doc_n.split(), z_n)
            self.document_history_per_user[user]\
                .append(doc_n)
            self.table_history_per_user[user].append(table)
            self.time_history_per_user[user].append(t_n)
            self.last_event_user_pattern[user][z_n] = t_n

            # Resample time for that pattern
            next_time_per_pattern[z_n] = self.sample_next_time(z_n, user)
            z_n = next_time_per_pattern.argmin()
            t_n = next_time_per_pattern[z_n]
            iteration += 1

        events = [(self.time_history_per_user[user][i],
                   self.document_history_per_user[user][i], user, [])
                  for i in range(len(self.time_history_per_user[user]))]

        # Update the full history of events with the ones generated for the
        # current user and re-order everything so that the events are
        # ordered by their timestamp
        self.events += events
        self.events = sorted(self.events, key=lambda x: x[0])
        self.num_users += 1
        return events

    def kernel(self, t_i, t_j):
        """Returns the kernel function for t_i and t_j.


        Parameters
        ----------
        t_i : float
            Timestamp representing `now`.

        t_j : float
            Timestamp representing `past`.


        Returns
        -------
        float
        """
        return exp(-self.omega * (t_i - t_j))

    def sample_document(self, pattern):
        """Sample a random document from a specific pattern.


        Parameters
        ----------
        pattern : int
            The pattern from which to sample the content.


        Returns
        -------
        str
            A space-separated string that contains all the sampled words.
        """
        length = self.doc_prng.randint(self.document_length) + \
            self.document_min_length
        words = self.doc_prng.multinomial(length, self.pattern_params[pattern])
        return ' '.join([
            self.vocabulary[i] for i, repeats in enumerate(words)
            for j in range(repeats)
        ])

    def _update_word_counters(self, doc, pattern):
        """Updates the word counters of the process for the particular document


        Parameters
        ----------
        doc : list
            A list with the words in the document.

        pattern : int
            The index of the latent pattern that this document belongs to.
        """
        for word in doc:
            self.per_pattern_word_counts[pattern][word] += 1
            self.per_pattern_word_count_total[pattern] += 1
        return

    def pattern_content_str(self,
                            patterns=None,
                            show_words=-1,
                            min_word_occurence=5):
        """Return the content information for the patterns of the process.


        Parameters
        ----------
        patterns : list, default is None
            If this list is provided, only information about the patterns in
            the list will be returned.

        show_words : int, default is -1
            The maximum number of words to show for each pattern. Notice that
            the words are sorted according to their occurrence count.

        min_word_occurence : int, default is 5
            Only show words that appear at least `min_word_occurence` times
            in the documents of the respective pattern.


        Returns
        -------
        str
            A string with all the content information
        """
        if patterns is None:
            patterns = self.per_pattern_word_counts.keys()
        text = [
            '___Pattern %d___ \n%s\n%s' % (pattern, '\n'.join([
                '%s : %d' % (k, v) for i, (k, v) in enumerate(
                    sorted(self.per_pattern_word_counts[pattern].items(),
                           key=lambda x: (x[1], x[0]),
                           reverse=True)) if v >= min_word_occurence and
                (show_words == -1 or (show_words > 0 and i < show_words))
            ]), ' '.join([
                k for i, (k, v) in enumerate(
                    sorted(self.per_pattern_word_counts[pattern].items(),
                           key=lambda x: (x[1], x[0]),
                           reverse=True)) if v < min_word_occurence and
                (show_words == -1 or (show_words > 0 and i < show_words))
            ])) for pattern in self.per_pattern_word_counts
            if pattern in patterns
        ]
        return '\n\n'.join(text)

    def user_patterns_set(self, user):
        """Return the patterns that a specific user adopted.


        Parameters
        ----------
        user : int
            The index of a user.


        Returns
        -------
        list
            The list of unique patterns that the user adopted.
        """
        pattern_list = [
            self.dish_on_table_per_user[user][table]
            for table in self.table_history_per_user[user]
        ]
        return list(set(pattern_list))

    def user_pattern_history_str(self,
                                 user=None,
                                 patterns=None,
                                 show_time=True,
                                 t_min=0):
        """Returns a representation of the history of a user's actions and the
        pattern that they correspond to.


        Parameters
        ----------
        user : int, default is None
            An index to the user we want to inspect. If None, the function
            runs over all the users.

        patterns : list, default is None
            If not None, limit the actions returned to the ones that belong in
            the provided patterns.

        show_time : bool, default is True
            Controls whether the timestamp appears in the representation.

        t_min : float, default is 0
            The timestamp after which we only consider actions.


        Returns
        -------
        str
        """
        if patterns is not None and type(patterns) is not set:
            patterns = set(patterns)

        return '\n'.join([
            '%spattern=%2d task=%3d (u=%d)  %s' %
            ('%5.3g ' % t if show_time else '', dish, table, u, doc)
            for u in range(self.num_users)
            for ((t, doc),
                 (table, dish
                  )) in zip([(t, d)
                             for t, d in zip(self.time_history_per_user[u],
                                             self.document_history_per_user[u])
                             ], [(table, self.dish_on_table_per_user[u][table])
                                 for table in self.table_history_per_user[u]])
            if (user is None or user == u) and (
                patterns is None or dish in patterns) and t >= t_min
        ])

    def _plot_user(self,
                   user,
                   fig,
                   num_samples,
                   T_max,
                   task_detail,
                   seed=None,
                   patterns=None,
                   colormap=None,
                   T_min=0,
                   paper=True):
        """Helper function that plots.
        """
        tics = np.arange(T_min, T_max, (T_max - T_min) / num_samples)
        tables = sorted(set(self.table_history_per_user[user]))
        active_tables = set()
        dish_set = set()

        times = [t for t in self.time_history_per_user[user]]
        start_event = min([
            i for i, t in enumerate(self.time_history_per_user[user])
            if t >= T_min
        ])
        table_cache = {}  # partial sum for each table
        dish_cache = {}  # partial sum for each dish
        table_intensities = [[] for _ in range(len(tables))]
        dish_intensities = [[] for _ in range(len(self.time_kernels))]
        i = 0  # index of tics
        j = start_event  # index of events

        while i < len(tics):
            if j >= len(times) or tics[i] < times[j]:
                # We need to measure the intensity
                dish_intensities, table_intensities = \
                    self._measure_intensities(tics[i], dish_cache=dish_cache,
                                              table_cache=table_cache,
                                              tables=tables,
                                              user=user,
                                              dish_intensities=dish_intensities,
                                              table_intensities=table_intensities)
                i += 1
            else:
                # We need to record an event and update our caches
                dish_cache, table_cache, active_tables, dish_set = \
                    self._update_cache(times[j],
                                       dish_cache=dish_cache,
                                       table_cache=table_cache,
                                       event_table=self.table_history_per_user[user][j],
                                       tables=tables,
                                       user=user,
                                       active_tables=active_tables,
                                       dish_set=dish_set)
                j += 1
        if patterns is not None:
            dish_set = set([d for d in patterns if d in dish_set])
        dish_dict = {dish: i for i, dish in enumerate(dish_set)}

        if colormap is None:
            prng = RandomState(seed)
            num_dishes = len(dish_set)
            colormap = qualitative_cmap(n_colors=num_dishes)
            colormap = prng.permutation(colormap)

        if not task_detail:
            for dish in dish_set:
                fig.plot(tics,
                         dish_intensities[dish],
                         color=colormap[dish_dict[dish]],
                         linestyle='-',
                         label="Pattern " + str(dish),
                         linewidth=3.)
        else:
            for table in active_tables:
                dish = self.dish_on_table_per_user[user][table]
                if patterns is not None and dish not in patterns:
                    continue
                fig.plot(tics,
                         table_intensities[table],
                         color=colormap[dish_dict[dish]],
                         linewidth=3.)
            for dish in dish_set:
                if patterns is not None and dish not in patterns:
                    continue
                fig.plot([], [],
                         color=colormap[dish_dict[dish]],
                         linestyle='-',
                         label="Pattern " + str(dish),
                         linewidth=5.)
        return fig

    def _measure_intensities(self, t, dish_cache, table_cache, tables, user,
                             dish_intensities, table_intensities):
        """Measures the intensities of all tables and dishes for the plot func.


        Parameters
        ----------
        t : float

        dish_cache : dict

        table_cache : dict

        tables : list

        user : int

        dish_intensities : dict

        table_intensities : dict


        Returns
        -------
        dish_intensities : dict

        table_intensities : dict
        """
        updated_dish = [False] * len(self.time_kernels)
        for table in tables:
            dish = self.dish_on_table_per_user[user][table]
            lambda_uz = self.mu_per_user[user] * self.pattern_popularity[dish]
            alpha = self.time_kernels[dish]
            if dish in dish_cache:
                t_last_dish, sum_kernels_dish = dish_cache[dish]
                update_value_dish = self.kernel(t, t_last_dish)
                dish_intensity = lambda_uz + alpha * sum_kernels_dish *\
                    update_value_dish
                dish_intensity += alpha * update_value_dish
            else:
                dish_intensity = lambda_uz
            if table in table_cache:
                # table already exists
                t_last_table, sum_kernels_table = table_cache[table]
                update_value_table = self.kernel(t, t_last_table)
                table_intensity = alpha * sum_kernels_table *\
                    update_value_table
                table_intensity += alpha * update_value_table
            else:
                # table does not exist yet
                table_intensity = 0
            table_intensities[table].append(table_intensity)
            if not updated_dish[dish]:
                # make sure to update dish only once for all the tables
                dish_intensities[dish].append(dish_intensity)
                updated_dish[dish] = True
        return dish_intensities, table_intensities

    def _update_cache(self, t, dish_cache, table_cache, event_table, tables,
                      user, active_tables, dish_set):
        """Updates the caches for the plot function when an event is recorded.


        Parameters
        ----------
        t : float

        dish_cache : dict

        table_cache : dict

        event_table : int

        tables : list

        user : int

        active_tables : set

        dish_set : set


        Returns
        -------
        dish_cache : dict

        table_cache : dict

        active_tables : set

        dish_set : set
        """
        dish = self.dish_on_table_per_user[user][event_table]
        active_tables.add(event_table)
        dish_set.add(dish)
        if event_table not in table_cache:
            table_cache[event_table] = (t, 0)
        else:
            t_last, sum_kernels = table_cache[event_table]
            update_value = self.kernel(t, t_last)
            sum_kernels += 1
            sum_kernels *= update_value
            table_cache[event_table] = (t, sum_kernels)
        if dish not in dish_cache:
            dish_cache[dish] = (t, 0)
        else:
            t_last, sum_kernels = dish_cache[dish]
            update_value = self.kernel(t, t_last)
            sum_kernels += 1
            sum_kernels *= update_value
            dish_cache[dish] = (t, sum_kernels)
        return dish_cache, table_cache, active_tables, dish_set

    def plot(self,
             num_samples=500,
             T_min=0,
             T_max=None,
             start_date=None,
             users=None,
             user_limit=50,
             patterns=None,
             task_detail=False,
             save_to_file=False,
             filename="user_timelines",
             intensity_threshold=None,
             paper=True,
             colors=None,
             fig_width=20,
             fig_height_per_user=5,
             time_unit='months',
             label_every=3,
             seed=None):
        """Plots the intensity of a set of users for a set of patterns over a
        time period.

        In this plot, each user is a separate subplot and for each user the
        plot shows her event_rate for each separate pattern that she has been
        active at.


        Parameters
        ----------
        num_samples : int, default is 500
            The granularity level of the intensity line. Smaller numbers of
            samples result in faster plotting, while larger numbers give a
            much more detailed result.

        T_min : float, default is 0
            The minimum timestamp that the plot shows, in seconds.

        T_max : float, default is None
            If not None, this is the maximum timestamp that the plot considers,
            in seconds.

        start_date : datetime, default is None
            If provided, this is the actual datetime that corresponds to
            time 0. This is required if `paper` is True.

        users : list, default is None
            If provided, this list contains the id's of the users that will be
            plotted. Actually, only the first `user_limit` of them will be
            shown.

        user_limit : int, default is 50
            The maximum number of users to plot.

        patterns : list, default is None
            The list of patterns that will be shown in the final plot. If None,
            all of the patterns will be plotted.

        task_detail : bool, default is False
            If True, the plot has one line per task. Otherwise, only the
            cumulative intensity of all tasks under the same pattern is plotted.

        save_to_file : bool, default is False
            If True, the plot will be saved to a `pdf` and a `png` file.

        filename : str, default is 'user_timelines'
            The name of the output file that will be used when saving the plot.

        intensity_threshold : float, default is None
            If provided, this is the maximum intensity value that will be
            plotted, i.e. the y_max that will be the cut-off threshold for the
            y-axis.

        paper : bool, default is True
            If True, the plot result will be the same as the figures that are
            in the published paper.

        colors : list, default is None
            A list of colors that will be used for the plot. Each color will
            correspond to a single pattern, and will be shared across all the
            users.

        fig_width : int, default is 20
            The width of the figure that will be returned.

        fig_height_per_user : int, default is 5
            The height of each separate user-plot of the final figure. If
            multiplied by the number of users, this determines the total height
            of the figure. Notice that due to a matplotlib constraint the
            total height of the figure cannot exceed 70.

        time_unit : str, default is 'months'
            Controls whether time is measured in days (in which case it
            should be set to 'days') or months.

        label_every : int, default is 3
            The frequency of the labels that show in the x-axis.

        seed : int, default is None
            A seed to the random number generator used to assign colors to
            patterns.


        Returns
        -------
        fig : matplotlib.Figure object
        """
        prng = RandomState(seed)
        num_users = len(self.dish_on_table_per_user)
        if users is None:
            users = range(num_users)
        num_users_to_plot = min(len(users), user_limit)
        users = users[:num_users_to_plot]
        if T_max is None:
            T_max = max(
                [self.time_history_per_user[user][-1] for user in users])
        fig = plt.figure(figsize=(fig_width,
                                  min(fig_height_per_user *
                                      num_users_to_plot, 70)))

        num_patterns_global = len(self.time_kernels)
        colormap = qualitative_cmap(n_colors=num_patterns_global)
        colormap = prng.permutation(colormap)
        if colors is not None:
            colormap = matplotlib.colors.ListedColormap(colors)
        if paper:
            sns.set_style('white')
            sns.despine(bottom=True, top=False, right=False, left=False)

        user_plt_axes = []
        max_intensity = -float('inf')
        for i, user in enumerate(users):
            if user not in self.time_history_per_user \
                    or not self.time_history_per_user[user] \
                    or self.time_history_per_user[user][-1] < T_min:
                # no events generated for this user during the time window
                # we are interested in
                continue
            if patterns is not None:
                user_patterns = set([
                    self.dish_on_table_per_user[user][table]
                    for table in self.table_history_per_user[user]
                ])
                if not any([pattern in patterns for pattern in user_patterns]):
                    # user did not generate events in the patterns of interest
                    continue
            user_plt = plt.subplot(num_users_to_plot, 1, i + 1)
            user_plt = self._plot_user(user,
                                       user_plt,
                                       num_samples,
                                       T_max,
                                       task_detail=task_detail,
                                       seed=seed,
                                       patterns=patterns,
                                       colormap=colormap,
                                       T_min=T_min,
                                       paper=paper)
            user_plt.set_xlim((T_min, T_max))
            if paper:
                if start_date is None:
                    raise ValueError(
                        'For paper-level quality plots, the actual datetime for t=0 must be provided as `start_date`'
                    )
                if start_date.microsecond > 500000:
                    start_date = start_date.replace(microsecond=0) \
                        + datetime.timedelta(seconds=1)
                else:
                    start_date = start_date.replace(microsecond=0)
                if time_unit == 'days':
                    t_min_seconds = T_min * 86400
                    t_max_seconds = T_max * 86400
                    t1 = start_date + datetime.timedelta(0, t_min_seconds)
                    t2 = start_date + datetime.timedelta(0, t_max_seconds)
                    ticks = monthly_ticks_for_days(t1, t2)
                    labels = monthly_labels(t1, t2, every=label_every)
                elif time_unit == 'months':
                    t1 = month_add(start_date, T_min)
                    t2 = month_add(start_date, T_max)
                    ticks = monthly_ticks_for_months(t1, t2)
                    labels = monthly_labels(t1, t2, every=label_every)
                labels[-1] = ''
                user_plt.set_xlim((ticks[0], ticks[-1]))
                user_plt.yaxis.set_ticks([])
                user_plt.xaxis.set_ticks(ticks)
                plt.setp(user_plt.xaxis.get_majorticklabels(), rotation=-0)
                user_plt.tick_params('x',
                                     length=10,
                                     which='major',
                                     direction='out',
                                     top=False,
                                     bottom=True)
                user_plt.xaxis.set_ticklabels(labels, fontsize=30)
                user_plt.tick_params(axis='x', which='major', pad=10)
                user_plt.get_xaxis(
                ).majorTicks[1].label1.set_horizontalalignment('left')
                for tick in user_plt.xaxis.get_major_ticks():
                    tick.label1.set_horizontalalignment('left')

            current_intensity = user_plt.get_ylim()[1]
            if current_intensity > max_intensity:
                max_intensity = current_intensity
            user_plt_axes.append(user_plt)
            if not paper:
                plt.title('User %2d' % user)
                plt.legend()
        if intensity_threshold is None:
            intensity_threshold = max_intensity
        for ax in user_plt_axes:
            ax.set_ylim((-0.2, intensity_threshold))
        if paper and save_to_file:
            print("Create image %s.png" % (filename))
            plt.savefig(filename + '.png', transparent=True)
            plt.savefig(filename + '.pdf', transparent=True)
            sns.set_style(None)
        return fig

    def user_patterns(self, user):
        """Returns a list with the patterns that a user has adopted.

        Parameters
        ----------
        user : int
        """
        pattern_list = [
            self.dish_on_table_per_user[user][table]
            for table in self.table_history_per_user[user]
        ]
        return list(set(pattern_list))

    def show_annotated_events(self,
                              user=None,
                              patterns=None,
                              show_time=True,
                              T_min=0,
                              T_max=None):
        """Returns a string where each event is annotated with the inferred
        pattern.


        Parameters
        ----------
        user : int, default is None
            If given, the events returned are limited to the selected user

        patterns : list, default is None
            If not None, an event is returned only if it belongs to one of the
            selected patterns

        show_time : bool, default is True
            Controls whether the time of the event will be shown

        T_min : float, default is 0
            Controls the minimum timestamp after which the events will be shown

        T_max : float, default is None
            If given, T_max controls the maximum timestamp shown


        Returns
        -------
        str
        """
        if patterns is not None and type(patterns) is not set:
            patterns = set(patterns)

        if show_time:
            return '\n'.join([
                '%5.3g pattern=%3d task=%3d (u=%d)  %s' %
                (t, dish, table, u, doc) for u in range(self.num_users)
                for ((t, doc), (table, dish))
                in zip([(t, d)
                        for t, d in zip(self.time_history_per_user[u],
                                        self.document_history_per_user[u])],
                       [(table, self.dish_on_table_per_user[u][table])
                        for table in self.table_history_per_user[u]])
                if (user is None or user == u) and (
                    patterns is None or dish in patterns) and t >= T_min and (
                        T_max is None or (T_max is not None and t <= T_max))
            ])
        else:
            return '\n'.join([
                'pattern=%3d task=%3d (u=%d)  %s' % (dish, table, u, doc)
                for u in range(self.num_users) for ((t, doc), (table, dish))
                in zip([(t, d)
                        for t, d in zip(self.time_history_per_user[u],
                                        self.document_history_per_user[u])],
                       [(table, self.dish_on_table_per_user[u][table])
                        for table in self.table_history_per_user[u]])
                if (user is None or user == u) and (
                    patterns is None or dish in patterns) and t >= T_min and (
                        T_max is None or (T_max is not None and t <= T_max))
            ])
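
    # A hedged usage sketch (not part of the original code): assuming `model`
    # is a fitted instance of this class, the events of user 3 that belong to
    # patterns 0 or 2 and fall in the time window [10, 50] could be printed
    # with:
    #
    #     print(model.show_annotated_events(user=3, patterns=[0, 2],
    #                                       T_min=10, T_max=50))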

    def show_pattern_content(self, patterns=None, words=0, detail_threshold=5):
        """Shows the content distrubution of the inferred patterns.


        Parameters
        ----------
        patterns : list, default is None
            If not None, only the content of the selected patterns will be
            shown

        words : int, default is 0
            A positive number that controls how many words will be shown.
            The words are shown sorted by their likelihood, starting
            with the most probable.

        detail_threshold : int, default is 5
            A positive number that sets the lower bound on the number of times
            a word must appear in a pattern for its count to be shown.


        Returns
        -------
        str
        """
        if patterns is None:
            patterns = self.per_pattern_word_counts.keys()
        text = [
            '___Pattern %d___ \n%s\n%s' % (pattern, '\n'.join([
                '%s : %d' % (k, v) for i, (k, v) in enumerate(
                    sorted(self.per_pattern_word_counts[pattern].items(),
                           key=lambda x: (x[1], x[0]),
                           reverse=True))
                if v >= detail_threshold and (words == 0 or i < words)
            ]), ' '.join([
                k for i, (k, v) in enumerate(
                    sorted(self.per_pattern_word_counts[pattern].items(),
                           key=lambda x: (x[1], x[0]),
                           reverse=True))
                if v < detail_threshold and (words == 0 or i < words)
            ])) for pattern in self.per_pattern_word_counts
            if pattern in patterns
        ]
        return '\n\n'.join(text)
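
    # A hedged usage sketch (not part of the original code): assuming `model`
    # is a fitted instance of this class, the ten most likely words of every
    # inferred pattern, with counts shown only for words seen at least three
    # times, could be printed with:
    #
    #     print(model.show_pattern_content(words=10, detail_threshold=3))
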
    def anytime_explain(self, instance, callback=None, update_func=None, update_prediction=None):
        data_rows, no_atr = self.data.X.shape
        class_value = self.model(instance)[0]
        prng = RandomState(self.seed)

        self.init_arrays(no_atr)
        attr_values = self.get_atr_column(instance)

        batch_mx_size = self.batch_size * no_atr
        z_sq = abs(st.norm.ppf(self.p_val/2))**2

        tiled_inst = self.tile_instance(instance)
        inst1 = copy.deepcopy(tiled_inst)
        inst2 = copy.deepcopy(tiled_inst)

        worst_case = self.max_iter*no_atr
        time_point = time.time()
        update_table = False

        domain = Domain([ContinuousVariable("Score"),
                         ContinuousVariable("Error")],
                        metas=[StringVariable(name="Feature"), StringVariable(name = "Value")])

        if update_prediction is not None:
            update_prediction(class_value)

        def create_res_table():
            nonzero = self.steps != 0
            expl_scaled = (self.expl[nonzero]/self.steps[nonzero]).reshape(1, -1)
            # creating return array
            ips = np.hstack((expl_scaled.T, np.sqrt(
                z_sq * self.var[nonzero] / self.steps[nonzero]).reshape(-1, 1)))
            table = Table.from_numpy(domain, ips,
                                     metas=np.hstack((np.asarray(self.atr_names)[nonzero[0]].reshape(-1, 1),
                                                        attr_values[nonzero[0]].reshape(-1,1))))
            return table

        while not all(self.iterations_reached[0, :] > self.max_iter):
            prog = 1 - np.sum(self.max_iter - self.iterations_reached)/worst_case
            if callback is not None and callback(int(prog*100)):
                break
            if not any(self.iterations_reached[0, :] > self.max_iter):
                a = np.argmax(prng.multinomial(
                    1, pvals=(self.var[0, :]/(np.sum(self.var[0, :])))))
            else:
                a = np.argmin(self.iterations_reached[0, :])

            perm = (prng.random_sample(batch_mx_size).reshape(
                self.batch_size, no_atr)) > 0.5
            rand_data = self.data.X[prng.randint(0,
                                                 data_rows, size=self.batch_size), :]
            inst1.X = np.copy(tiled_inst.X)
            inst1.X[perm] = rand_data[perm]
            inst2.X = np.copy(inst1.X)

            inst1.X[:, a] = tiled_inst.X[:, a]
            inst2.X[:, a] = rand_data[:, a]
            f1 = self._get_predictions(inst1, class_value)
            f2 = self._get_predictions(inst2, class_value)

            diff = np.sum(f1 - f2)
            self.expl[0, a] += diff

            # update variance
            self.steps[0, a] += self.batch_size
            self.iterations_reached[0, a] += self.batch_size
            d = diff - self.mu[0, a]
            self.mu[0, a] += d / self.steps[0, a]
            self.M2[0, a] += d * (diff - self.mu[0, a])
            self.var[0, a] = self.M2[0, a] / (self.steps[0, a] - 1)

            if time.time() - time_point > 1:
                update_table = True
                time_point = time.time()

            if update_table:
                update_table = False
                if update_func is not None:
                    update_func(create_res_table())

            # exclude from sampling if necessary
            needed_iter = z_sq * self.var[0, a] / (self.error**2)
            if (needed_iter <= self.steps[0, a]) and (self.steps[0, a] >= self.min_iter) or (self.steps[0, a] > self.max_iter):
                self.iterations_reached[0, a] = self.max_iter + 1

        return class_value, create_res_table()
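
# The loop above appears to accumulate each attribute's contribution with an
# online (Welford-style) mean/variance update and to stop sampling an attribute
# once the half-width of its confidence interval drops below `self.error`.
# A minimal, self-contained sketch of that update (illustrative only, not part
# of the original code; the helper name is made up):
#
#     import numpy as np
#     import scipy.stats as st
#
#     def running_mean_and_halfwidth(samples, p_val=0.05):
#         z_sq = abs(st.norm.ppf(p_val / 2)) ** 2
#         n, mu, m2 = 0, 0.0, 0.0
#         for x in samples:
#             n += 1
#             d = x - mu
#             mu += d / n
#             m2 += d * (x - mu)
#             var = m2 / (n - 1) if n > 1 else 0.0
#             yield mu, np.sqrt(z_sq * var / n)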
Example #8
0
def main(outdir):
    rng = RandomState(MT19937(SeedSequence(config.seed)))

    num_employees = 50000

    num_orders = 1000000

    num_jobsites = 2800
    num_areas = 180
    num_qualifications = 214
    num_qualigroups = 13
    num_shifts = 4
    num_days = 2708

    start_day = datetime(2013, 8, 1)

    print("create sliding window of active employees")
    active_employees = np.zeros((num_employees, num_days)).astype(bool)

    left = 0
    right = 100
    upkeep = 400
    change = (.95, 1 - .95)
    for irow, row in enumerate(active_employees):
        active_employees[irow, left:right] = 1
        left = left + rng.choice([0, 1], p=change)
        right = left + upkeep + rng.choice([0, 1], p=change)

    print("create base distributions for areas, qualis and shifts")
    areas = rng.dirichlet(np.ones(num_areas) * .1)

    jobsites = rng.dirichlet(np.ones(num_jobsites) * .1)

    area_of_jobsite = np.empty(num_jobsites)
    for ijobsite, jobsite in enumerate(jobsites):
        area_of_jobsite[ijobsite] = rng.choice(np.arange(num_areas), p=areas)

    qualigroups = rng.dirichlet(np.ones(num_qualigroups) * .1)

    qualis = rng.dirichlet(np.ones(num_qualifications) * .1)

    qualigroup_of_quali = np.empty(num_qualifications)
    for iquali, quali in enumerate(qualis):
        qualigroup_of_quali[iquali] = rng.choice(np.arange(num_qualigroups),
                                                 p=qualigroups)

    shifts = rng.dirichlet(np.ones(num_shifts))

    orders = []
    for _ in tqdm(range(num_orders), desc="create orders"):
        shift = rng.choice(range(num_shifts), p=shifts)

        jobsite = rng.choice(range(num_jobsites), p=jobsites)
        area = area_of_jobsite[jobsite]

        quali = rng.choice(range(num_qualifications), p=qualis)
        qualigroup = qualigroup_of_quali[quali]

        day = rng.randint(0, num_days)

        orders.append({
            "Schicht": shift,
            "Einsatzort": jobsite,
            "PLZ": area,
            "Qualifikation": quali,
            "Qualifikationgruppe": qualigroup,
            "Tag": day,
        })

    employee_qualifications = rng.multinomial(
        1, qualis, size=(num_employees)).astype(bool)
    employee_jobsites = rng.multinomial(1, jobsites,
                                        size=(num_employees)).astype(bool)

    orders = pd.DataFrame(orders)
    offers = []

    ps = np.ones(6) / np.arange(1, 7)
    ps /= ps.sum()

    for _, order in tqdm(orders.iterrows(),
                         desc="create offers",
                         total=len(orders)):

        match_active = active_employees[:, int(order.Tag)]
        match_quali = employee_qualifications[:, int(order.Qualifikation)]
        match_jobsite = employee_jobsites[:, int(order.Einsatzort)]

        match, = (match_active & match_quali & match_jobsite).nonzero()

        offers.append(match[:6].tolist())
        if len(offers[-1]) == 0:
            offers[-1] = rng.choice(match_active.nonzero()[0],
                                    rng.choice(np.arange(1, 7),
                                               p=ps)).tolist()

    berlin_holidays = holidays.DE(prov="BE")

    orders["Mitarbeiter ID"] = offers
    print("add day meta data")
    orders["Tag"] = orders["Tag"].apply(lambda day: start_day + timedelta(day))
    orders["Wochentag"] = orders["Tag"].apply(lambda day: day.strftime("%a"))
    orders["Feiertag"] = orders["Tag"].apply(
        lambda day: day in berlin_holidays)

    orders = orders[[
        "Einsatzort", "PLZ", "Qualifikation", "Qualifikationgruppe", "Schicht",
        "Tag", "Wochentag", "Feiertag", "Mitarbeiter ID"
    ]]
    orders = orders.sort_values("Tag")

    train, test = train_test_split(orders)

    train.to_csv(os.path.join(outdir, "train.tsv"), index=False, sep="\t")
    test.to_csv(os.path.join(outdir, "test_truth.tsv"), index=False, sep="\t")
    test[[
        "Einsatzort", "PLZ", "Qualifikation", "Qualifikationgruppe", "Schicht",
        "Tag", "Wochentag", "Feiertag"
    ]].to_csv(os.path.join(outdir, "test_publish.tsv"), index=False, sep="\t")
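
# The generator above repeatedly draws a sparse categorical distribution from a
# Dirichlet prior and then samples discrete ids from it. A minimal,
# self-contained sketch of that pattern (illustrative only, not part of the
# original code; the sizes are made up):
#
#     import numpy as np
#     from numpy.random import RandomState, MT19937, SeedSequence
#
#     rng = RandomState(MT19937(SeedSequence(42)))
#     probs = rng.dirichlet(np.ones(5) * .1)       # concentrated on few ids
#     samples = rng.choice(np.arange(5), size=10, p=probs)
#     print(probs, samples)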
Example #9
0
def calc_cluster_distribution(population_size: int, random_seed=47) -> NDArray:
    cluster_probabilities = _obtain_cluster_probabilities(
        CLUSTER_INFO_FILE_NAME)
    random = RandomState(random_seed)
    return random.multinomial(population_size, cluster_probabilities,
                              size=1)[0]
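
# A hedged usage sketch (not part of the original code): assuming
# CLUSTER_INFO_FILE_NAME and _obtain_cluster_probabilities are defined as in
# the surrounding project, a population of 1000 individuals could be split
# across the clusters with:
#
#     cluster_sizes = calc_cluster_distribution(1000, random_seed=47)
#     assert cluster_sizes.sum() == 1000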
Example #10
0
def _subsample_nonzero(counts, ns, replace=False, seed=0):
    """Randomly subsample from a vector of counts and returns the number of
    nonzero values for each number of element to subsample specified.

    Parameters
    ----------
    counts : 1-D array_like of integers
        Vector of counts.
    ns : 1-D array_like of integers
        List of numbers of elements to subsample.
    replace : bool, optional
        Subsample with or without replacement.
    seed : int, optional
        Random seed.

    Returns
    -------
    nonzero : 1-D ndarray
        Number of nonzero values for each value of ns.

    Raises
    ------
    ValueError, TypeError
    """

    counts = np.asarray(counts)
    ns = np.asarray(ns)

    if counts.ndim != 1:
        raise ValueError("'counts' must be an 1-D array_like object")

    if (ns < 0).sum() > 0:
        raise ValueError("values in 'ns' must be > 0 ")

    counts = counts.astype(int, casting='safe')
    ns = ns.astype(int, casting='safe')

    counts_sum = counts.sum()

    prng = RandomState(seed)
    nonzero = []

    if replace:
        p = counts / counts_sum
        for n in ns:
            if n > counts_sum:
                nonzero.append(np.nan)
            else:
                subcounts = prng.multinomial(n, p)
                nonzero.append(np.count_nonzero(subcounts))
    else:
        nz = np.flatnonzero(counts)
        expanded = np.concatenate([np.repeat(i, counts[i]) for i in nz])
        permuted = prng.permutation(expanded)
        for n in ns:
            if n > counts_sum:
                nonzero.append(np.nan)
            else:
                subcounts = np.bincount(permuted[:n], minlength=counts.size)
                nonzero.append(np.count_nonzero(subcounts))

    return np.array(nonzero)
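
# A hedged usage sketch (not part of the original code): computing how many
# categories stay non-empty at several subsampling depths for a toy count
# vector.
#
#     import numpy as np
#     counts = np.array([5, 0, 3, 2, 0, 1])
#     print(_subsample_nonzero(counts, ns=[1, 3, 5, 11, 20]))
#     # entries for which n exceeds counts.sum() (here 20 > 11) come back as nan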
Example #11
0
def _subsample_nonzero(counts, ns, replace=False, seed=0):
    """Randomly subsample from a vector of counts and returns the number of
    nonzero values for each number of element to subsample specified.

    Parameters
    ----------
    counts : 1-D array_like of integers
        Vector of counts.
    ns : 1-D array_like of integers
        List of numbers of elements to subsample.
    replace : bool, optional
        Subsample with or without replacement.
    seed : int, optional
        Random seed.

    Returns
    -------
    nonzero : 1-D ndarray
        Number of nonzero values for each value of ns.

    Raises
    ------
    ValueError, TypeError
    """

    counts = np.asarray(counts)
    ns = np.asarray(ns)

    if counts.ndim != 1:
        raise ValueError("'counts' must be an 1-D array_like object")

    if (ns < 0).sum() > 0:
        raise ValueError("values in 'ns' must be > 0 ")

    counts = counts.astype(int, casting='safe')
    ns = ns.astype(int, casting='safe')

    counts_sum = counts.sum()

    prng = RandomState(seed)
    nonzero = []

    if replace:
        p = counts / counts_sum
        for n in ns:
            if n > counts_sum:
                nonzero.append(np.nan)
            else:
                subcounts = prng.multinomial(n, p)
                nonzero.append(np.count_nonzero(subcounts))
    else:
        nz = np.flatnonzero(counts)
        expanded = np.concatenate([np.repeat(i, counts[i]) for i in nz])
        permuted = prng.permutation(expanded)
        for n in ns:
            if n > counts_sum:
                nonzero.append(np.nan)
            else:
                subcounts = np.bincount(permuted[:n], minlength=counts.size)
                nonzero.append(np.count_nonzero(subcounts))

    return np.array(nonzero)
Example #12
0
    def anytime_explain(self, instance, callback=None, update_func=None, update_prediction=None):
        data_rows, no_atr = self.data.X.shape
        class_value = self.model(instance)[0]
        prng = RandomState(self.seed)

        self.init_arrays(no_atr)
        attr_values = self.get_atr_column(instance)

        batch_mx_size = self.batch_size * no_atr
        z_sq = abs(st.norm.ppf(self.p_val/2))**2

        tiled_inst = self.tile_instance(instance)
        inst1 = copy.deepcopy(tiled_inst)
        inst2 = copy.deepcopy(tiled_inst)

        worst_case = self.max_iter*no_atr
        time_point = time.time()
        update_table = False

        domain = Domain([ContinuousVariable("Score"),
                         ContinuousVariable("Error")],
                        metas=[StringVariable(name="Feature"), StringVariable(name="Value")])

        if update_prediction is not None:
            update_prediction(class_value)

        def create_res_table():
            nonzero = self.steps != 0
            expl_scaled = (self.expl[nonzero] /
                           self.steps[nonzero]).reshape(1, -1)
            """ creating return array"""
            ips = np.hstack((expl_scaled.T, np.sqrt(
                z_sq * self.var[nonzero] / self.steps[nonzero]).reshape(-1, 1)))
            table = Table.from_numpy(domain, ips,
                                     metas=np.hstack((np.asarray(self.atr_names)[nonzero[0]].reshape(-1, 1),
                                                      attr_values[nonzero[0]].reshape(-1, 1))))
            return table

        while not all(self.iterations_reached[0, :] > self.max_iter):
            prog = 1 - np.sum(self.max_iter -
                              self.iterations_reached)/worst_case
            if callback is not None and callback(int(prog*100)):
                break
            if not any(self.iterations_reached[0, :] > self.max_iter):
                a = np.argmax(prng.multinomial(
                    1, pvals=(self.var[0, :]/(np.sum(self.var[0, :])))))
            else:
                a = np.argmin(self.iterations_reached[0, :])

            perm = (prng.random_sample(batch_mx_size).reshape(
                self.batch_size, no_atr)) > 0.5
            rand_data = self.data.X[prng.randint(0,
                                                 data_rows, size=self.batch_size), :]
            inst1.X = np.copy(tiled_inst.X)
            inst1.X[perm] = rand_data[perm]
            inst2.X = np.copy(inst1.X)

            inst1.X[:, a] = tiled_inst.X[:, a]
            inst2.X[:, a] = rand_data[:, a]
            f1 = self._get_predictions(inst1, class_value)
            f2 = self._get_predictions(inst2, class_value)

            diff = np.sum(f1 - f2)
            self.expl[0, a] += diff

            """update variance"""
            self.steps[0, a] += self.batch_size
            self.iterations_reached[0, a] += self.batch_size
            d = diff - self.mu[0, a]
            self.mu[0, a] += d / self.steps[0, a]
            self.M2[0, a] += d * (diff - self.mu[0, a])
            self.var[0, a] = self.M2[0, a] / (self.steps[0, a] - 1)

            if time.time() - time_point > 1:
                update_table = True
                time_point = time.time()

            if update_table:
                update_table = False
                if update_func is not None:
                    update_func(create_res_table())

            # exclude from sampling if necessary
            needed_iter = z_sq * self.var[0, a] / (self.error**2)
            if (needed_iter <= self.steps[0, a]) and (self.steps[0, a] >= self.min_iter) or (self.steps[0, a] > self.max_iter):
                self.iterations_reached[0, a] = self.max_iter + 1

        return class_value, create_res_table()
Example #13
0
class ChildGenerator(object):
    def __init__(self, parents: List[str]):
        self.__unique_genes = [parent[1:-1] for parent in parents]
        self.__gene_lengths = self.__get_gene_lengths()
        self.__unique_lengths, self.__length_repetitions = (
            self.__get_length_repetitions())
        self.__random = RandomState()
        self.__length_probabilities = self.__determine_length_probabilities()
        self.__code_probabilities_by_position = (
            self.__determine_code_probabilities_by_position())
        self.__forbidden_lengths = set()

    def generate(self) -> Union[str, None]:
        length = self.__determine_length()
        while length in self.__forbidden_lengths:
            length = self.__determine_length()
        child = ""
        previous_letter = ""
        for i in range(length):
            curr_letter, next_letter = self.__produce_letter(
                i, i == (length - 1))
            while curr_letter == previous_letter:
                curr_letter, next_letter = self.__produce_letter(
                    i, i == (length - 1))
                curr_codes, curr_probabilities, curr_random = self.__code_probabilities_by_position[
                    i]
                if (len(curr_codes) == 2 and previous_letter in curr_codes
                        and next_letter in curr_codes):
                    self.__forbidden_lengths.add(length)
                    return None
            child += curr_letter
            previous_letter = curr_letter
        return "X" + child + "Y"

    def __produce_letter(self, position: int,
                         is_last: bool) -> Tuple[Letter, Letter]:
        curr_codes, curr_probabilities, curr_random = self.__code_probabilities_by_position[
            position]
        index = list(
            curr_random.multinomial(1, curr_probabilities, size=1)[0]).index(1)
        if not is_last:
            next_codes, next_probabilities, next_random = self.__code_probabilities_by_position[
                position + 1]
            if len(next_probabilities) == 1:
                next_code = next_codes[0]
                i = curr_codes.index(
                    next_code) if next_code in curr_codes else None
                if i is not None:
                    p = curr_probabilities[i] / (len(curr_probabilities) - 1)
                    new_curr_codes = curr_codes[:]
                    new_curr_probabilities = [
                        curr_probability + p
                        for curr_probability in curr_probabilities
                    ]
                    del new_curr_probabilities[i]
                    del new_curr_codes[i]
                    index = list(
                        curr_random.multinomial(1,
                                                new_curr_probabilities,
                                                size=1)[0]).index(1)
                    return new_curr_codes[index], next_code
        return curr_codes[index], ""

    def __get_gene_lengths(self):
        return [len(gene) for gene in self.__unique_genes]

    def __get_length_repetitions(self):
        return maths.count_repetitions(sorted(self.__gene_lengths))

    def __determine_length_probabilities(self):
        return [
            repetition / len(self.__gene_lengths)
            for repetition in self.__length_repetitions
        ]

    def __determine_length(self):
        index = list(
            self.__random.multinomial(1, self.__length_probabilities,
                                      size=1)[0]).index(1)
        return self.__unique_lengths[index]

    def __determine_code_probabilities_by_position(
            self) -> List[Tuple[List[Letter], List[float], RandomState]]:
        max_length = max(self.__unique_lengths)
        all_code_probabilities = []
        for position in range(max_length):
            codes = []
            for gene in self.__unique_genes:
                if len(gene) - 1 < position:
                    continue
                code = gene[position]
                codes.append(code)
            unique_codes, code_repetitions = maths.count_repetitions(
                sorted(codes))
            code_probabilities = [
                repetition / len(codes) for repetition in code_repetitions
            ]
            all_code_probabilities.append(
                tuple((unique_codes, code_probabilities, RandomState())))
        return all_code_probabilities
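
# A hedged usage sketch (not part of the original code): assuming the `maths`
# helper with count_repetitions is available as above and parents are strings
# delimited by 'X'...'Y', a new child could be generated with:
#
#     generator = ChildGenerator(["XabcY", "XacbY", "XabY"])
#     child = generator.generate()   # may return None if no valid child exists
#     print(child)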
Example #14
0
class NCRPNode(object):
    
    total_nodes = 0
    last_node_id = 0

    def __init__(self, vocab, parent=None, level=0,
                 random_state=None):

        self.node_id = NCRPNode.last_node_id
        NCRPNode.last_node_id += 1

        self.customers = 0
        self.parent = parent
        self.children = []
        self.level = level
        self.total_words = 0
        self.vocab = np.array(vocab)
        self.word_counts = np.zeros(len(vocab))

        if random_state is None:
            self.random_state = RandomState()
        else:
            self.random_state = random_state

    def __repr__(self):
        parent_id = None
        if self.parent is not None:
            parent_id = self.parent.node_id
        return 'Node=%d level=%d customers=%d total_words=%d parent=%s' % (self.node_id,
            self.level, self.customers, self.total_words, parent_id)

    def add_child(self):
        ''' Adds a child to the next level of this node '''
        node = NCRPNode(self.vocab, parent=self, level=self.level+1)
        self.children.append(node)
        NCRPNode.total_nodes += 1
        return node

    def is_leaf(self, num_levels):
        ''' Check if this node is a leaf node '''
        return self.level == num_levels-1

    def get_new_leaf(self, num_levels):
        ''' Keeps adding nodes along the path until a leaf node is generated'''
        node = self
        for l in range(self.level, num_levels):
            node = node.add_child()
        return node

    def drop_path(self, num_levels):
        ''' Removes a document from a path starting from this node '''
        node = self
        node.customers -= 1
        if node.customers == 0:
            node.parent.remove(node)
        for level in range(1, num_levels): # skip the root
            node = node.parent
            node.customers -= 1
            if node.customers == 0:
                node.parent.remove(node)

    def remove(self, node):
        self.children.remove(node)
        NCRPNode.total_nodes -= 1

    def add_path(self, num_levels):
        node = self
        node.customers += 1
        for level in range(1, num_levels):
            node = node.parent
            node.customers += 1

    def select(self, gamma):
        weights = np.zeros(len(self.children)+1)
        weights[0] = float(gamma) / (gamma+self.customers)
        i = 1
        
        for child in self.children:
            weights[i] = float(child.customers) / (gamma + self.customers)
            i += 1
            
        choice = self.random_state.multinomial(1, weights).argmax()
       
        if choice == 0:
            return self.add_child()
        else:
            return self.children[choice-1]

    
    def get_node_words(self, temp):
        
        output = ''
        for i in range(len(self.word_counts)):
            if self.word_counts[i] != 0:
                if self.vocab[i] not in temp:
                    output += '%s, ' % self.vocab[i]
        return output        
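
# `select` above appears to implement the Chinese Restaurant Process prior: an
# existing child is chosen with probability proportional to its customer count,
# and a brand-new child with weight gamma. A minimal, self-contained sketch of
# that weighting (illustrative only, not part of the original code):
#
#     import numpy as np
#     gamma, child_customers = 1.0, np.array([3, 1, 6])
#     weights = np.concatenate(([gamma], child_customers))
#     weights = weights / weights.sum()
#     # index 0 means "open a new child"; index i > 0 picks children[i - 1]
#     choice = np.random.RandomState(0).multinomial(1, weights).argmax()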
Example #15
0
class HierarchicalLDA(object):
    def __init__(self, corpus, vocab, features, VAF, n_samples, seed,
                 crp_alpha, gamma, eta, mcmc_passes, stats_interval, max_k,
                 thinning, burnIn, mu, kappa, nu, sigma, ddcrp_alpha):

        NCRPNode.total_nodes = 0
        NCRPNode.last_node_id = 0

        self.num_samples = n_samples

        self.corpus = corpus
        self.vocab = vocab
        self.crp_alpha = crp_alpha  # smoothing on doc-topic distributions
        self.gamma = gamma  # "imaginary" customers at the next, as yet unused table
        self.eta = eta  # smoothing on topic-word distributions

        self.mcmc_passes = mcmc_passes
        self.stats_interval = stats_interval
        self.max_k = max_k
        self.thinning = thinning
        self.burnIn = burnIn
        self.VAF = VAF

        self.mu = mu
        self.kappa = kappa
        self.nu = nu
        self.sigma = sigma
        self.ddcrp_alpha = ddcrp_alpha
        self.features = features

        self.clusterPath = './clust-trace.csv'
        self.seed = seed
        self.random_state = RandomState(seed)
        self.num_documents = len(corpus)
        self.num_types = len(vocab)
        self.eta_sum = eta * self.num_types
        self.root_node = NCRPNode(self.vocab)
        self.document_leaves = {
        }  # currently selected path (ie leaf node) through the NCRP tree
        self.levels = np.zeros(self.num_documents,
                               dtype=object)  # indexed < doc, token >
        self.documents = []

        for d in range(len(self.corpus)):

            doc = NCRPDocument(doc_id=d,
                               num_levels=3,
                               root_node=self.root_node,
                               vocab=vocab,
                               words=self.corpus[d])
            self.documents.append(doc)

        for d in range(len(self.documents)):
            document = self.documents[d]
            feature = self.features[d]
            level_assignment = self.clone_detection(document, feature)
            document.num_levels = np.max(level_assignment)

        for d in range(len(self.documents)):

            # populate nodes into the path of this document
            doc = self.documents[d].words
            doc_num_levels = self.documents[d].num_levels + 1
            path = np.zeros(doc_num_levels, dtype=object)
            doc_len = len(doc)
            path[0] = self.root_node
            self.root_node.customers += 1  # always add to the root node first

            for level in range(1, doc_num_levels):
                # at each level, a node is selected by its parent node based on the CRP prior
                parent_node = path[level - 1]
                level_node = parent_node.select(self.gamma)
                level_node.level = level
                level_node.customers += 1
                path[level] = level_node

            # set the leaf node for this document
            leaf_node = path[doc_num_levels - 1]

            self.document_leaves[d] = leaf_node

            # randomly assign each word in the document to a level (node) along the path
            self.levels[d] = np.zeros(doc_len, dtype=int)
            for n in range(doc_len):
                w = doc[n]
                random_level = self.random_state.randint(doc_num_levels)
                random_node = path[random_level]
                random_node.word_counts[w] += 1
                random_node.total_words += 1
                self.levels[d][n] = random_level

    def estimate(self):

        print('HierarchicalLDA sampling\n')
        for s in range(self.num_samples):
            sys.stdout.write('.')
            for cd in range(len(self.documents)):
                self.sample_path(cd)

            for zd in range(len(self.documents)):
                self.sample_level(zd)

            #if (s > 0) and ((s+1) % display_topics == 0):
            print(" %d" % (s + 1))
            self.print_nodes()
            print()

    def sample_path(self, d):

        # define a path starting from the leaf node of this doc
        document = self.documents[d]
        doc_num_levels = document.num_levels + 1
        path = np.zeros(doc_num_levels, dtype=object)
        node = self.document_leaves[d]

        for level in range(doc_num_levels - 1, -1,
                           -1):  # e.g. [3, 2, 1, 0] for num_levels = 4
            path[level] = node
            node = node.parent

        # remove this document from the path, deleting empty nodes if necessary
        self.document_leaves[d].drop_path(doc_num_levels)

        ############################################################
        # calculates the prior p(c_d | c_{-d}) in eq. (4)
        ############################################################

        node_weights = {}
        self.calculate_ncrp_prior(node_weights, self.root_node, 0.0)

        ############################################################
        # calculates the likelihood p(w_d | c, w_{-d}, z) in eq. (4)
        ############################################################

        level_word_counts = {}
        for level in range(doc_num_levels):
            level_word_counts[level] = {}
        doc_levels = self.levels[d]
        words = document.words

        # remove doc from path
        for n in range(len(words)):  # for each word in the doc

            # count the word at each level
            level = doc_levels[n]
            w = words[n]

            if w not in level_word_counts[level]:
                level_word_counts[level][w] = 1
            else:
                level_word_counts[level][w] += 1

            # remove word count from the node at that level
            level_node = path[level]
            level_node.word_counts[w] -= 1
            level_node.total_words -= 1
            assert level_node.word_counts[w] >= 0
            assert level_node.total_words >= 0

        self.calculate_doc_likelihood(node_weights, level_word_counts,
                                      doc_num_levels)

        ############################################################
        # pick a new path
        ############################################################

        nodes = np.array(list(node_weights.keys()))
        weights = np.array([node_weights[node] for node in nodes])
        weights = np.exp(
            weights - np.max(weights))  # normalise so the largest weight is 1
        weights = weights / np.sum(weights)

        choice = self.random_state.multinomial(1, weights).argmax()
        node = nodes[choice]

        # if we picked an internal node, we need to add a new path to the leaf

        if node.level > doc_num_levels - 1:
            for i in range(doc_num_levels - 1, node.level):
                temp_node = node.parent
                node = temp_node

        if node.level < doc_num_levels - 1:
            for l in range(node.level, doc_num_levels):
                node = node.add_child()

        # add the doc back to the path
        node.add_path(doc_num_levels)  # add a customer to the path
        self.document_leaves[d] = node  # store the leaf node for this doc

        # add the words
        for level in range(doc_num_levels - 1, -1,
                           -1):  # e.g. [3, 2, 1, 0] for num_levels = 4
            word_counts = level_word_counts[level]
            for w in word_counts:
                node.word_counts[w] += word_counts[w]
                node.total_words += word_counts[w]
            node = node.parent

    def calculate_ncrp_prior(self, node_weights, node, weight):
        ''' Calculates the prior on the path according to the nested CRP '''

        for child in node.children:
            child_weight = log(
                float(child.customers) / (node.customers + self.gamma))
            self.calculate_ncrp_prior(node_weights, child,
                                      weight + child_weight)

        node_weights[node] = weight + log(self.gamma /
                                          (node.customers + self.gamma))

    def calculate_doc_likelihood(self, node_weights, level_word_counts,
                                 doc_num_levels):

        # calculate the weight for a new path at a given level
        new_topic_weights = np.zeros(doc_num_levels)
        for level in range(1, doc_num_levels):  # skip the root

            word_counts = level_word_counts[level]
            total_tokens = 0

            for w in word_counts:
                count = word_counts[w]
                for i in range(count):
                    new_topic_weights[level] += log(
                        (self.eta + i) / (self.eta_sum + total_tokens))
                    total_tokens += 1

        self.calculate_word_likelihood(node_weights, self.root_node, 0.0,
                                       level_word_counts, new_topic_weights, 0,
                                       doc_num_levels)

    def calculate_word_likelihood(self, node_weights, node, weight,
                                  level_word_counts, new_topic_weights, level,
                                  doc_num_levels):

        # first calculate the likelihood of the words at this level, given this topic
        node_weight = 0.0
        word_counts = level_word_counts[level]
        total_words = 0

        for w in word_counts:
            count = word_counts[w]
            for i in range(count):
                node_weight += log(
                    (self.eta + node.word_counts[w] + i) /
                    (self.eta_sum + node.total_words + total_words))
                total_words += 1

        # propagate that weight to the child nodes
        for child in node.children:
            if level + 1 < doc_num_levels:
                self.calculate_word_likelihood(node_weights, child,
                                               weight + node_weight,
                                               level_word_counts,
                                               new_topic_weights, level + 1,
                                               doc_num_levels)

        # finally if this is an internal node, add the weight of a new path
        level += 1
        while level < doc_num_levels:
            node_weight += new_topic_weights[level]
            level += 1

        node_weights[node] += node_weight

    def sample_level(self, d):

        document = self.documents[d]
        feature = self.features[d]
        doc_num_levels = document.num_levels + 1

        level_assignment = self.clone_detection(document, feature)
        #sort level assignment

        # initialise level counts
        doc_levels = self.levels[d]
        level_counts = np.zeros(doc_num_levels, dtype=int)
        for c in doc_levels:
            level_counts[c] += 1

        # get the leaf node and populate the path
        path = np.zeros(doc_num_levels, dtype=object)
        node = self.document_leaves[d]
        for level in range(doc_num_levels - 1, -1,
                           -1):  # e.g. [3, 2, 1, 0] for num_levels = 4
            path[level] = node
            node = node.parent

        # put the word back into the model
        level_weights = np.zeros(doc_num_levels)
        words = document.words
        for n in range(len(words)):

            w = words[n]
            word_level = doc_levels[n]

            # remove from model
            level_counts[word_level] -= 1
            node = path[word_level]
            node.word_counts[w] -= 1
            node.total_words -= 1

            level = level_assignment[n]

            doc_levels[n] = level
            level_counts[level] += 1

            node = path[level]

            node.word_counts[w] += 1
            node.total_words += 1

    def clone_detection(self, document, feature):

        adj_list = self.get_adj_list_by_co_occurence_frequency(document)
        feature = np.array(list(feature))

        dimensionality_of_data = len(self.VAF)
        np.random.seed(seed=2)
        mu_bar = np.zeros((dimensionality_of_data, ))
        lambda_bar = np.random.rand(
            dimensionality_of_data,
            dimensionality_of_data) + np.eye(dimensionality_of_data)
        niw = Priors.NIW(mu0=mu_bar,
                         kappa0=self.kappa,
                         nu0=self.nu,
                         lambda0=lambda_bar)

        crp = ddCRP.ddCRP(alpha=self.ddcrp_alpha,
                          model=niw,
                          mcmc_passes=self.mcmc_passes,
                          stats_interval=self.stats_interval,
                          parcelPath=self.clusterPath,
                          ward=False,
                          n_clusters=3)
        crp.fit(feature, adj_list)

        pc = PointClustering(thinning=self.thinning,
                             burnIn=self.burnIn,
                             clust_trace_filepath=self.clusterPath,
                             method='avg',
                             max_k=self.max_k)
        result_clustering = pc.estimatePointClustering()

        #sort according to the VAF

        return result_clustering

    def get_adj_list_by_co_occurence_frequency(self, document):

        adj_list = {}.fromkeys(np.arange(len(self.VAF)))
        this_doc_words = document.words

        words_documents = []
        for cnt in range(len(self.documents)):
            words = self.documents[cnt].words
            row = np.zeros(len(self.vocab))
            for w in range(len(words)):
                row[words[w]] = 1
            words_documents.append(row)

        words_frequencies = []
        for k in range(len(self.vocab)):
            words_frequencies.append(np.zeros(len(self.vocab)))

        counter = 0
        for cnt in range(len(words_documents)):
            row = np.zeros(len(self.vocab))
            for i in range(len(this_doc_words)):
                for j in range(i + 1, len(this_doc_words)):
                    occur = words_documents[cnt][i] + words_documents[cnt][j]
                    if occur == 2:
                        counter += 1
                        words_frequencies[i][j] += 1
                        words_frequencies[j][i] += 1
        words_frequencies = [
            x / len(self.documents) for x in words_frequencies
        ]

        mean_co_occurrence = sum(sum(words_frequencies)) / counter

        curr_adj = []
        for i in range(len(this_doc_words)):
            curr_adj = []
            for j in range(len(this_doc_words)):
                if words_frequencies[i][j] > mean_co_occurrence:
                    curr_adj.append(j)
            curr_adj = list(dict.fromkeys(curr_adj))
            adj_list[i] = list(np.array(curr_adj))

        return adj_list

    def print_nodes(self):
        temp = []
        self.print_node(self.root_node, 0, temp)

    def print_node(self, node, indent, temp):
        output = node.get_node_words(temp)
        out = '    ' * indent
        if indent != 0 and output != '':
            out += 'topic=%d level=%d (documents=%d): ' % (
                node.node_id, node.level, node.customers)
            out += output
            temp.extend(output.split(', '))
            print(out)
        for child in node.children:
            self.print_node(child, indent + 1, temp)
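
# In `sample_path` above, the path posterior is accumulated in log space and
# turned into a categorical distribution by subtracting the maximum log weight
# before exponentiating, which avoids numerical underflow. A minimal,
# self-contained sketch of that step (illustrative only, not part of the
# original code):
#
#     import numpy as np
#     log_w = np.array([-1050.0, -1052.3, -1049.1])
#     w = np.exp(log_w - log_w.max())
#     w /= w.sum()
#     pick = np.random.RandomState(0).multinomial(1, w).argmax()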