Example #1
class HDHProcess:
    def __init__(self,
                 num_patterns,
                 alpha_0,
                 mu_0,
                 vocabulary,
                 omega=1,
                 doc_length=20,
                 doc_min_length=5,
                 words_per_pattern=10,
                 random_state=None):
        """
        Parameters
        ----------
        num_patterns : int
            The maximum number of patterns that will be shared across
            the users.

        alpha_0 : tuple
            The parameter that is used when sampling the time kernel weights
            of each pattern. The distribution that is being used is a Gamma.
            This tuple should be of the form (shape, scale).

        mu_0 : tuple
            The parameter of the Gamma distribution that is used to sample
            each user's \mu (activity level). This tuple should be of the
            form (shape, scale).

        vocabulary : list
            The list of available words to use when generating documents.

        omega : float, default is 1
            The decay parameter of the exponential decay kernel.

        doc_length : int, default is 20
            The maximum number of words per document.

        doc_min_length : int, default is 5
            The minimum number of words per document.

        words_per_pattern : int, default is 10
            The number of words that will have a non-zero probability of
            appearing in each pattern.

        random_state : int or RandomState object, default is None
            The random number generator, or a seed for it.
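
        Example
        -------
        A minimal usage sketch; the vocabulary and the parameter values below
        are illustrative only, not prescribed by the class.

        >>> vocab = ['word%d' % i for i in range(100)]
        >>> proc = HDHProcess(num_patterns=5, alpha_0=(4, 0.5),
        ...                   mu_0=(4, 0.5), vocabulary=vocab,
        ...                   random_state=42)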
        """
        self.prng = check_random_state(random_state)
        self.doc_prng = RandomState(self.prng.randint(200000000))
        self.time_kernel_prng = RandomState(self.prng.randint(2000000000))
        self.pattern_param_prng = RandomState(self.prng.randint(2000000000))

        self.num_patterns = num_patterns
        self.alpha_0 = alpha_0
        self.vocabulary = vocabulary
        self.mu_0 = mu_0
        self.document_length = doc_length
        self.document_min_length = doc_min_length
        self.omega = omega
        self.words_per_pattern = words_per_pattern
        self.pattern_params = self.sample_pattern_params()
        self.time_kernels = self.sample_time_kernels()
        self.pattern_popularity = self.sample_pattern_popularity()

        # Initialize all the counters etc.
        self.reset()

    def reset(self):
        """Removes all the events and users already sampled.


        Note
        ----
        It does not reseed the random number generator. It also retains the
        already sampled pattern parameters (word distributions and alphas)
        """
        self.mu_per_user = {}
        self.num_users = 0
        self.time_history = []
        self.time_history_per_user = {}
        self.table_history_per_user = {}
        self.dish_on_table_per_user = {}
        self.dish_counters = defaultdict(int)
        self.last_event_user_pattern = defaultdict(lambda: defaultdict(int))
        self.total_tables = 0
        self.first_observed_time = {}
        self.user_table_cache = defaultdict(dict)
        self.table_history_per_user = defaultdict(list)
        self.time_history_per_user = defaultdict(list)
        self.document_history_per_user = defaultdict(list)
        self.dish_on_table_per_user = defaultdict(dict)
        self.cache_per_user = defaultdict(dict)
        self.total_tables_per_user = defaultdict(int)
        self.events = []
        self.per_pattern_word_counts = defaultdict(lambda: defaultdict(int))
        self.per_pattern_word_count_total = defaultdict(int)

    def sample_pattern_params(self):
        """Returns the word distributions for each pattern.


        Returns
        -------
        parameters : dict
            A dictionary mapping each pattern to its word distribution.
        """
        sampled_params = {}
        V = len(self.vocabulary)
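        # For each pattern, place all of the Dirichlet concentration mass on a
        # random subset of `words_per_pattern` words; the remaining words keep
        # zero concentration and therefore zero probability under the pattern.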
        for pattern in range(self.num_patterns):
            custom_theta = [0] * V
            words_in_pattern = self.prng.choice(V,
                                                size=self.words_per_pattern,
                                                replace=False)
            for word in words_in_pattern:
                custom_theta[word] = 100. / self.words_per_pattern
            sampled_params[pattern] = \
                self.pattern_param_prng.dirichlet(custom_theta)
        return sampled_params

    def sample_time_kernels(self):
        """Returns the time decay parameter of each pattern.


        Returns
        -------
        alphas : dict
            A dictionary mapping each pattern to its time decay parameter.
        """
        alphas = {
            pattern: self.time_kernel_prng.gamma(self.alpha_0[0],
                                                 self.alpha_0[1])
            for pattern in range(self.num_patterns)
        }
        return alphas

    def sample_pattern_popularity(self):
        """Returns a popularity distribution over the patterns.


        Returns
        -------
        pattern_popularities : dict
            A dictionary mapping each pattern to its popularity.
        """
        pattern_popularity = {}
        pi = self.pattern_param_prng.dirichlet(
            [1 for pattern in range(self.num_patterns)])
        for pattern_i, pi_i in enumerate(pi):
            pattern_popularity[pattern_i] = pi_i
        return pattern_popularity

    def sample_mu(self):
        """Samples a value from the prior of the base intensity mu.


        Returns
        -------
        mu_u : float
            The base intensity of a user, sampled from the prior.
        """
        return self.prng.gamma(self.mu_0[0], self.mu_0[1])

    def sample_next_time(self, pattern, user):
        """Samples the time of the next event of a pattern for a given user.


        Parameters
        ----------
        pattern : int
            The pattern index that we want to sample the next event for.

        user : int
            The index of the user that we want to sample for.


        Returns
        -------
        timestamp : float
        """
        U = self.prng.rand
        mu_u = self.mu_per_user[user]
        lambda_u_pattern = mu_u * self.pattern_popularity[pattern]

        if user in self.total_tables_per_user:
            # We have seen the user before
            num_tables = self.total_tables_per_user[user]
            pattern_tables = [
                table for table in range(num_tables)
                if self.dish_on_table_per_user[user][table] == pattern
            ]
        else:
            pattern_tables = []
        alpha = self.time_kernels[pattern]
        if not pattern_tables:
            lambda_star = lambda_u_pattern
            s = -1 / lambda_star * np.log(U())
            return s
        else:
            # Add the \alpha of the most recent table (previous event) to the
            # user-pattern intensity
            s = self.last_event_user_pattern[user][pattern]
            pattern_intensity = 0
            for table in pattern_tables:
                t_last, sum_kernels = self.user_table_cache[user][table]
                update_value = self.kernel(s, t_last)
                # update_value should be 1, so practically we are just adding
                # \alpha to the intensity dt after the event
                table_intensity = alpha * sum_kernels * update_value
                table_intensity += alpha * update_value
                pattern_intensity += table_intensity
            lambda_star = lambda_u_pattern + pattern_intensity

            # New event
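            # Thinning (rejection) sampling: propose candidate times from an
            # exponential with the upper-bound rate lambda_star and accept a
            # candidate with probability lambda(s) / lambda_star; on rejection,
            # tighten the bound to lambda(s) and propose again.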
            accepted = False
            while not accepted:
                s = s - 1 / lambda_star * np.log(U())
                # Rejection test
                pattern_intensity = 0
                for table in pattern_tables:
                    t_last, sum_kernels = self.user_table_cache[user][table]
                    update_value = self.kernel(s, t_last)
                    # update_value should be 1, so practically we are just adding
                    # \alpha to the intensity dt after the event
                    table_intensity = alpha * sum_kernels * update_value
                    table_intensity += alpha * update_value
                    pattern_intensity += table_intensity
                lambda_s = lambda_u_pattern + pattern_intensity
                if U() < lambda_s / lambda_star:
                    return s
                else:
                    lambda_star = lambda_s

    def sample_user_events(self,
                           min_num_events=100,
                           max_num_events=None,
                           t_max=None):
        """Samples events for a user.


        Parameters
        ----------
        min_num_events : int, default is 100
            The minimum number of events to sample.

        max_num_events : int, default is None
            If not None, this is the maximum number of events to sample.

        t_max : float, default is None
            The time limit until which to sample events.


        Returns
        -------
        events : list
            A list of the form [(t_i, doc_i, user_i, meta_i), ...], sorted by
            increasing time, containing the events of the newly sampled user.
            Above, doc_i is the document and meta_i is any sort of metadata
            that we want for doc_i, e.g. question_id. The generator will return
            an empty list for meta_i.
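
        Example
        -------
        A hedged usage sketch, assuming an already constructed process `proc`
        (see the class docstring); the argument values are illustrative only.

        >>> for _ in range(3):
        ...     events = proc.sample_user_events(min_num_events=50, t_max=100)
        >>> # proc.events now holds the merged, time-sorted history of all users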
        """
        user = len(self.mu_per_user)
        mu_u = self.sample_mu()
        self.mu_per_user[user] = mu_u

        # Populate the list with the first event for each pattern
        next_time_per_pattern = [
            self.sample_next_time(pattern, user)
            for pattern in range(self.num_patterns)
        ]
        next_time_per_pattern = asfortranarray(next_time_per_pattern)

        iteration = 0
        over_tmax = False
        while iteration < min_num_events or not over_tmax:
            if max_num_events is not None and iteration > max_num_events:
                break
            z_n = next_time_per_pattern.argmin()
            t_n = next_time_per_pattern[z_n]
            if t_max is not None and t_n > t_max:
                over_tmax = True
                break
            num_tables_user = self.total_tables_per_user[user] \
                if user in self.total_tables_per_user else 0
            tables = range(num_tables_user)
            tables = [
                table for table in tables
                if self.dish_on_table_per_user[user][table] == z_n
            ]
            intensities = []

            alpha = self.time_kernels[z_n]
            for table in tables:
                t_last, sum_kernels = self.user_table_cache[user][table]
                update_value = self.kernel(t_n, t_last)
                table_intensity = alpha * sum_kernels * update_value
                table_intensity += alpha * update_value
                intensities.append(table_intensity)
            intensities.append(mu_u * self.pattern_popularity[z_n])
            log_intensities = [
                ln(inten_i) if inten_i > 0 else -float('inf')
                for inten_i in intensities
            ]

            normalizing_log_intensity = logsumexp(log_intensities)
            intensities = [
                exp(log_intensity - normalizing_log_intensity)
                for log_intensity in log_intensities
            ]
            k = weighted_choice(intensities, self.prng)

            if k < len(tables):
                # Assign to already existing table
                table = tables[k]
                # update cache for that table
                t_last, sum_kernels = self.user_table_cache[user][table]
                update_value = self.kernel(t_n, t_last)
                sum_kernels += 1
                sum_kernels *= update_value
                self.user_table_cache[user][table] = (t_n, sum_kernels)
            else:
                table = num_tables_user
                self.total_tables += 1
                self.total_tables_per_user[user] += 1
                # Since this is a new table, initialize the cache accordingly
                self.user_table_cache[user][table] = (t_n, 0)
                self.dish_on_table_per_user[user][table] = z_n

            if z_n not in self.first_observed_time or\
                    t_n < self.first_observed_time[z_n]:
                self.first_observed_time[z_n] = t_n
            self.dish_counters[z_n] += 1

            doc_n = self.sample_document(z_n)
            self._update_word_counters(doc_n.split(), z_n)
            self.document_history_per_user[user]\
                .append(doc_n)
            self.table_history_per_user[user].append(table)
            self.time_history_per_user[user].append(t_n)
            self.last_event_user_pattern[user][z_n] = t_n

            # Resample time for that pattern
            next_time_per_pattern[z_n] = self.sample_next_time(z_n, user)
            z_n = next_time_per_pattern.argmin()
            t_n = next_time_per_pattern[z_n]
            iteration += 1

        events = [(self.time_history_per_user[user][i],
                   self.document_history_per_user[user][i], user, [])
                  for i in range(len(self.time_history_per_user[user]))]

        # Update the full history of events with the ones generated for the
        # current user and re-order everything so that the events are
        # ordered by their timestamp
        self.events += events
        self.events = sorted(self.events, key=lambda x: x[0])
        self.num_users += 1
        return events

    def kernel(self, t_i, t_j):
        """Returns the kernel function for t_i and t_j.


        Parameters
        ----------
        t_i : float
            Timestamp representing `now`.

        t_j : float
            Timestamp representing `past`.


        Returns
        -------
        float
            The value of the exponential decay kernel,
            exp(-self.omega * (t_i - t_j)).
        """
        return exp(-self.omega * (t_i - t_j))

    def sample_document(self, pattern):
        """Sample a random document from a specific pattern.


        Parameters
        ----------
        pattern : int
            The pattern from which to sample the content.


        Returns
        -------
        str
            A space separated string that contains all the sampled words.
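
        Example
        -------
        Assuming an `HDHProcess` instance `proc` (see the class docstring),
        a single document for pattern 0 can be drawn as

        >>> doc = proc.sample_document(0)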
        """
        length = self.doc_prng.randint(self.document_length) + \
            self.document_min_length
        words = self.doc_prng.multinomial(length, self.pattern_params[pattern])
        return ' '.join([
            self.vocabulary[i] for i, repeats in enumerate(words)
            for j in range(repeats)
        ])

    def _update_word_counters(self, doc, pattern):
        """Updates the word counters of the process for the particular document


        Parameters
        ----------
        doc : list
            A list with the words in the document.

        pattern : int
            The index of the latent pattern that this document belongs to.
        """
        for word in doc:
            self.per_pattern_word_counts[pattern][word] += 1
            self.per_pattern_word_count_total[pattern] += 1
        return

    def pattern_content_str(self,
                            patterns=None,
                            show_words=-1,
                            min_word_occurence=5):
        """Return the content information for the patterns of the process.


        Parameters
        ----------
        patterns : list, default is None
            If this list is provided, only information about the patterns in
            the list will be returned.

        show_words : int, default is -1
            The maximum number of words to show for each pattern. Notice that
            the words are sorted according to their occurrence count.

        min_word_occurence : int, default is 5
            Only show words that appear at least `min_word_occurence` times
            in the documents of the respective pattern.


        Returns
        -------
        str
            A string with all the content information
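
        Example
        -------
        A hedged usage sketch, assuming events have already been sampled on a
        process `proc`.

        >>> print(proc.pattern_content_str(show_words=10,
        ...                                min_word_occurence=3))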
        """
        if patterns is None:
            patterns = self.per_pattern_word_counts.keys()
        text = [
            '___Pattern %d___ \n%s\n%s' % (pattern, '\n'.join([
                '%s : %d' % (k, v) for i, (k, v) in enumerate(
                    sorted(self.per_pattern_word_counts[pattern].iteritems(),
                           key=lambda x: (x[1], x[0]),
                           reverse=True)) if v >= min_word_occurence and
                (show_words == -1 or (show_words > 0 and i < show_words))
            ]), ' '.join([
                k for i, (k, v) in enumerate(
                    sorted(self.per_pattern_word_counts[pattern].iteritems(),
                           key=lambda x: (x[1], x[0]),
                           reverse=True)) if v < min_word_occurence and
                (show_words == -1 or (show_words > 0 and i < show_words))
            ])) for pattern in self.per_pattern_word_counts
            if pattern in patterns
        ]
        return '\n\n'.join(text)

    def user_patterns_set(self, user):
        """Return the patterns that a specific user adopted.


        Parameters
        ----------
        user : int
            The index of a user.


        Returns
        -------
        list
            A list of the unique patterns that the user adopted.
        """
        pattern_list = [
            self.dish_on_table_per_user[user][table]
            for table in self.table_history_per_user[user]
        ]
        return list(set(pattern_list))

    def user_pattern_history_str(self,
                                 user=None,
                                 patterns=None,
                                 show_time=True,
                                 t_min=0):
        """Returns a representation of the history of a user's actions and the
        pattern that they correspond to.


        Parameters
        ----------
        user : int, default is None
            An index to the user we want to inspect. If None, the function
            runs over all the users.

        patterns : list, default is None
            If not None, limit the actions returned to the ones that belong in
            the provided patterns.

        show_time : bool, default is True
            Controls whether the timestamp will appear in the representation
            or not.

        t_min : float, default is 0
            Only actions at or after this timestamp are considered.


        Returns
        -------
        str
        """
        if patterns is not None and type(patterns) is not set:
            patterns = set(patterns)

        return '\n'.join([
            '%spattern=%2d task=%3d (u=%d)  %s' %
            ('%5.3g ' % t if show_time else '', dish, table, u, doc)
            for u in range(self.num_users)
            for ((t, doc),
                 (table, dish
                  )) in zip([(t, d)
                             for t, d in zip(self.time_history_per_user[u],
                                             self.document_history_per_user[u])
                             ], [(table, self.dish_on_table_per_user[u][table])
                                 for table in self.table_history_per_user[u]])
            if (user is None or user == u) and (
                patterns is None or dish in patterns) and t >= t_min
        ])

    def _plot_user(self,
                   user,
                   fig,
                   num_samples,
                   T_max,
                   task_detail,
                   seed=None,
                   patterns=None,
                   colormap=None,
                   T_min=0,
                   paper=True):
        """Helper function that plots.
        """
        tics = np.arange(T_min, T_max, (T_max - T_min) / num_samples)
        tables = sorted(set(self.table_history_per_user[user]))
        active_tables = set()
        dish_set = set()

        times = [t for t in self.time_history_per_user[user]]
        start_event = min([
            i for i, t in enumerate(self.time_history_per_user[user])
            if t >= T_min
        ])
        table_cache = {}  # partial sum for each table
        dish_cache = {}  # partial sum for each dish
        table_intensities = [[] for _ in range(len(tables))]
        dish_intensities = [[] for _ in range(len(self.time_kernels))]
        i = 0  # index of tics
        j = start_event  # index of events

        while i < len(tics):
            if j >= len(times) or tics[i] < times[j]:
                # We need to measure the intensity
                dish_intensities, table_intensities = \
                    self._measure_intensities(tics[i], dish_cache=dish_cache,
                                              table_cache=table_cache,
                                              tables=tables,
                                              user=user,
                                              dish_intensities=dish_intensities,
                                              table_intensities=table_intensities)
                i += 1
            else:
                # We need to record an event and update our caches
                dish_cache, table_cache, active_tables, dish_set = \
                    self._update_cache(times[j],
                                       dish_cache=dish_cache,
                                       table_cache=table_cache,
                                       event_table=self.table_history_per_user[user][j],
                                       tables=tables,
                                       user=user,
                                       active_tables=active_tables,
                                       dish_set=dish_set)
                j += 1
        if patterns is not None:
            dish_set = set([d for d in patterns if d in dish_set])
        dish_dict = {dish: i for i, dish in enumerate(dish_set)}

        if colormap is None:
            prng = RandomState(seed)
            num_dishes = len(dish_set)
            colormap = qualitative_cmap(n_colors=num_dishes)
            colormap = prng.permutation(colormap)

        if not task_detail:
            for dish in dish_set:
                fig.plot(tics,
                         dish_intensities[dish],
                         color=colormap[dish_dict[dish]],
                         linestyle='-',
                         label="Pattern " + str(dish),
                         linewidth=3.)
        else:
            for table in active_tables:
                dish = self.dish_on_table_per_user[user][table]
                if patterns is not None and dish not in patterns:
                    continue
                fig.plot(tics,
                         table_intensities[table],
                         color=colormap[dish_dict[dish]],
                         linewidth=3.)
            for dish in dish_set:
                if patterns is not None and dish not in patterns:
                    continue
                fig.plot([], [],
                         color=colormap[dish_dict[dish]],
                         linestyle='-',
                         label="Pattern " + str(dish),
                         linewidth=5.)
        return fig

    def _measure_intensities(self, t, dish_cache, table_cache, tables, user,
                             dish_intensities, table_intensities):
        """Measures the intensities of all tables and dishes for the plot func.


        Parameters
        ----------
        t : float

        dish_cache : dict

        table_cache : dict

        tables : list

        user : int

        dish_intensities : list
            One list of intensity values per dish, appended to in place.

        table_intensities : list
            One list of intensity values per table, appended to in place.


        Returns
        -------
        dish_intensities : list

        table_intensities : list
        """
        updated_dish = [False] * len(self.time_kernels)
        for table in tables:
            dish = self.dish_on_table_per_user[user][table]
            lambda_uz = self.mu_per_user[user] * self.pattern_popularity[dish]
            alpha = self.time_kernels[dish]
            if dish in dish_cache:
                t_last_dish, sum_kernels_dish = dish_cache[dish]
                update_value_dish = self.kernel(t, t_last_dish)
                dish_intensity = lambda_uz + alpha * sum_kernels_dish *\
                    update_value_dish
                dish_intensity += alpha * update_value_dish
            else:
                dish_intensity = lambda_uz
            if table in table_cache:
                # table already exists
                t_last_table, sum_kernels_table = table_cache[table]
                update_value_table = self.kernel(t, t_last_table)
                table_intensity = alpha * sum_kernels_table *\
                    update_value_table
                table_intensity += alpha * update_value_table
            else:
                # table does not exist yet
                table_intensity = 0
            table_intensities[table].append(table_intensity)
            if not updated_dish[dish]:
                # make sure to update dish only once for all the tables
                dish_intensities[dish].append(dish_intensity)
                updated_dish[dish] = True
        return dish_intensities, table_intensities

    def _update_cache(self, t, dish_cache, table_cache, event_table, tables,
                      user, active_tables, dish_set):
        """Updates the caches for the plot function when an event is recorded.


        Parameters
        ----------
        t : float

        dish_cache : dict

        table_cache : dict

        event_table : int

        tables : list

        user : int

        active_tables : set

        dish_set : set


        Returns
        -------
        dish_cache : dict

        table_cache : dict

        active_tables : set

        dish_set : set
        """
        dish = self.dish_on_table_per_user[user][event_table]
        active_tables.add(event_table)
        dish_set.add(dish)
        if event_table not in table_cache:
            table_cache[event_table] = (t, 0)
        else:
            t_last, sum_kernels = table_cache[event_table]
            update_value = self.kernel(t, t_last)
            sum_kernels += 1
            sum_kernels *= update_value
            table_cache[event_table] = (t, sum_kernels)
        if dish not in dish_cache:
            dish_cache[dish] = (t, 0)
        else:
            t_last, sum_kernels = dish_cache[dish]
            update_value = self.kernel(t, t_last)
            sum_kernels += 1
            sum_kernels *= update_value
            dish_cache[dish] = (t, sum_kernels)
        return dish_cache, table_cache, active_tables, dish_set

    def plot(self,
             num_samples=500,
             T_min=0,
             T_max=None,
             start_date=None,
             users=None,
             user_limit=50,
             patterns=None,
             task_detail=False,
             save_to_file=False,
             filename="user_timelines",
             intensity_threshold=None,
             paper=True,
             colors=None,
             fig_width=20,
             fig_height_per_user=5,
             time_unit='months',
             label_every=3,
             seed=None):
        """Plots the intensity of a set of users for a set of patterns over a
        time period.

        In this plot, each user is a separate subplot and for each user the
        plot shows her event_rate for each separate pattern that she has been
        active at.


        Parameters
        ----------
        num_samples : int, default is 500
            The granularity level of the intensity line. Smaller number of
            samples results in faster plotting, while larger numbers give
            much more detailed result.

        T_min : float, default is 0
            The minimum timestamp that the plot shows, in the same units as
            the event timestamps (see `time_unit`).

        T_max : float, default is None
            If not None, this is the maximum timestamp that the plot
            considers, in the same units as the event timestamps.

        start_date : datetime, default is None
            If provided, this is the actual datetime that corresponds to
            time 0. This is required if `paper` is True.

        users : list, default is None
            If provided, this list contains the ids of the users that will be
            plotted. Only the first `user_limit` of them will be shown.

        user_limit : int, default is 50
            The maximum number of users to plot.

        patterns : list, default is None
            The list of patterns that will be shown in the final plot. If None,
            all of the patterns will be plotted.

        task_detail : bool, default is False
            If True, the plot has one line per task. Otherwise, we only plot
            the cumulative intensity of all tasks under the same pattern.

        save_to_file : bool, default is False
            If True, the plot will be saved to a `pdf` and a `png` file.

        filename : str, default is 'user_timelines'
            The name of the output file that will be used when saving the plot.

        intensity_threshold : float, default is None
            If provided, this is the maximum intensity value that will be
            plotted, i.e. the y_max that will be the cut-off threshold for the
            y-axis.

        paper : bool, default is True
            If True, the plot result will be the same as the figures that are
            in the published paper.

        colors : list, default is None
            A list of colors that will be used for the plot. Each color will
            correspond to a single pattern, and will be shared across all the
            users.

        fig_width : int, default is 20
            The width of the figure that will be returned.

        fig_height_per_user : int, default is 5
            The height of each separate user-plot of the final figure. When
            multiplied by the number of users, this determines the total
            height of the figure. Note that due to a matplotlib constraint,
            the total height of the figure is capped at 70.

        time_unit : str, default is 'months'
            Controls whether time is measured in days (in which case it
            should be set to 'days') or months.

        label_every : int, default is 3
            The frequency of the labels that show in the x-axis.

        seed : int, default is None
            A seed to the random number generator used to assign colors to
            patterns.


        Returns
        -------
        fig : matplotlib.Figure object
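
        Example
        -------
        A hedged usage sketch that skips the paper-style axis formatting, so
        no `start_date` is required; the argument values are illustrative.

        >>> fig = proc.plot(num_samples=200, T_max=100, user_limit=5,
        ...                 paper=False)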
        """
        prng = RandomState(seed)
        num_users = len(self.dish_on_table_per_user)
        if users is None:
            users = range(num_users)
        num_users_to_plot = min(len(users), user_limit)
        users = users[:num_users_to_plot]
        if T_max is None:
            T_max = max(
                [self.time_history_per_user[user][-1] for user in users])
        fig = plt.figure(figsize=(fig_width,
                                  min(fig_height_per_user *
                                      num_users_to_plot, 70)))

        num_patterns_global = len(self.time_kernels)
        colormap = qualitative_cmap(n_colors=num_patterns_global)
        colormap = prng.permutation(colormap)
        if colors is not None:
            colormap = matplotlib.colors.ListedColormap(colors)
        if paper:
            sns.set_style('white')
            sns.despine(bottom=True, top=False, right=False, left=False)

        user_plt_axes = []
        max_intensity = -float('inf')
        for i, user in enumerate(users):
            if user not in self.time_history_per_user \
                    or not self.time_history_per_user[user] \
                    or self.time_history_per_user[user][-1] < T_min:
                # no events generated for this user during the time window
                # we are interested in
                continue
            if patterns is not None:
                user_patterns = set([
                    self.dish_on_table_per_user[user][table]
                    for table in self.table_history_per_user[user]
                ])
                if not any([pattern in patterns for pattern in user_patterns]):
                    # user did not generate events in the patterns of interest
                    continue
            user_plt = plt.subplot(num_users_to_plot, 1, i + 1)
            user_plt = self._plot_user(user,
                                       user_plt,
                                       num_samples,
                                       T_max,
                                       task_detail=task_detail,
                                       seed=seed,
                                       patterns=patterns,
                                       colormap=colormap,
                                       T_min=T_min,
                                       paper=paper)
            user_plt.set_xlim((T_min, T_max))
            if paper:
                if start_date is None:
                    raise ValueError(
                        'For paper-level quality plots, the actual datetime for t=0 must be provided as `start_date`'
                    )
                if start_date.microsecond > 500000:
                    start_date = start_date.replace(microsecond=0) \
                        + datetime.timedelta(seconds=1)
                else:
                    start_date = start_date.replace(microsecond=0)
                if time_unit == 'days':
                    t_min_seconds = T_min * 86400
                    t_max_seconds = T_max * 86400
                    t1 = start_date + datetime.timedelta(0, t_min_seconds)
                    t2 = start_date + datetime.timedelta(0, t_max_seconds)
                    ticks = monthly_ticks_for_days(t1, t2)
                    labels = monthly_labels(t1, t2, every=label_every)
                elif time_unit == 'months':
                    t1 = month_add(start_date, T_min)
                    t2 = month_add(start_date, T_max)
                    ticks = monthly_ticks_for_months(t1, t2)
                    labels = monthly_labels(t1, t2, every=label_every)
                labels[-1] = ''
                user_plt.set_xlim((ticks[0], ticks[-1]))
                user_plt.yaxis.set_ticks([])
                user_plt.xaxis.set_ticks(ticks)
                plt.setp(user_plt.xaxis.get_majorticklabels(), rotation=-0)
                user_plt.tick_params('x',
                                     length=10,
                                     which='major',
                                     direction='out',
                                     top=False,
                                     bottom=True)
                user_plt.xaxis.set_ticklabels(labels, fontsize=30)
                user_plt.tick_params(axis='x', which='major', pad=10)
                user_plt.get_xaxis(
                ).majorTicks[1].label1.set_horizontalalignment('left')
                for tick in user_plt.xaxis.get_major_ticks():
                    tick.label1.set_horizontalalignment('left')

            current_intensity = user_plt.get_ylim()[1]
            if current_intensity > max_intensity:
                max_intensity = current_intensity
            user_plt_axes.append(user_plt)
            if not paper:
                plt.title('User %2d' % user)
                plt.legend()
        if intensity_threshold is None:
            intensity_threshold = max_intensity
        for ax in user_plt_axes:
            ax.set_ylim((-0.2, intensity_threshold))
        if paper and save_to_file:
            print("Create image %s.png" % (filename))
            plt.savefig(filename + '.png', transparent=True)
            plt.savefig(filename + '.pdf', transparent=True)
            sns.set_style(None)
        return fig

    def user_patterns(self, user):
        """Returns a list with the patterns that a user has adopted.

        Parameters
        ----------
        user : int
        """
        pattern_list = [
            self.dish_on_table_per_user[user][table]
            for table in self.table_history_per_user[user]
        ]
        return list(set(pattern_list))

    def show_annotated_events(self,
                              user=None,
                              patterns=None,
                              show_time=True,
                              T_min=0,
                              T_max=None):
        """Returns a string where each event is annotated with the inferred
        pattern.


        Parameters
        ----------
        user : int, default is None
            If given, the events returned are limited to the selected user

        patterns : list, default is None
            If not None, an event is returned only if it belongs to one of the
            selected patterns.

        show_time : bool, default is True
            Controls whether the time of the event will be shown

        T_min : float, default is 0
            Controls the minimum timestamp after which the events will be shown

        T_max : float, default is None
            If given, T_max controls the maximum timestamp shown


        Returns
        -------
        str
        """
        if patterns is not None and type(patterns) is not set:
            patterns = set(patterns)

        if show_time:
            return '\n'.join([
                '%5.3g pattern=%3d task=%3d (u=%d)  %s' %
                (t, dish, table, u, doc) for u in range(self.num_users)
                for ((t, doc), (table, dish))
                in zip([(t, d)
                        for t, d in zip(self.time_history_per_user[u],
                                        self.document_history_per_user[u])],
                       [(table, self.dish_on_table_per_user[u][table])
                        for table in self.table_history_per_user[u]])
                if (user is None or user == u) and (
                    patterns is None or dish in patterns) and t >= T_min and (
                        T_max is None or (T_max is not None and t <= T_max))
            ])
        else:
            return '\n'.join([
                'pattern=%3d task=%3d (u=%d)  %s' % (dish, table, u, doc)
                for u in range(self.num_users) for ((t, doc), (table, dish))
                in zip([(t, d)
                        for t, d in zip(self.time_history_per_user[u],
                                        self.document_history_per_user[u])],
                       [(table, self.dish_on_table_per_user[u][table])
                        for table in self.table_history_per_user[u]])
                if (user is None or user == u) and (
                    patterns is None or dish in patterns) and t >= T_min and (
                        T_max is None or (T_max is not None and t <= T_max))
            ])

    def show_pattern_content(self, patterns=None, words=0, detail_threshold=5):
        """Shows the content distrubution of the inferred patterns.


        Parameters
        ----------
        patterns : list, default is None
            If not None, only the content of the selected patterns will be
            shown

        words : int, default is 0
            A positive number that controls how many words will be shown.
            The words are shown sorted by their likelihood, starting
            with the most probable.

        detail_threshold : int, default is 5
            A positive number that sets the lower bound in the number of times
            that a word appeared in a pattern so that its count is shown.


        Returns
        -------
        str
        """
        if patterns is None:
            patterns = self.per_pattern_word_counts.keys()
        text = [
            '___Pattern %d___ \n%s\n%s' % (pattern, '\n'.join([
                '%s : %d' % (k, v) for i, (k, v) in enumerate(
                    sorted(self.per_pattern_word_counts[pattern].iteritems(),
                           key=lambda x: (x[1], x[0]),
                           reverse=True))
                if v >= detail_threshold and (words == 0 or i < words)
            ]), ' '.join([
                k for i, (k, v) in enumerate(
                    sorted(self.per_pattern_word_counts[pattern].iteritems(),
                           key=lambda x: (x[1], x[0]),
                           reverse=True))
                if v < detail_threshold and (words == 0 or i < words)
            ])) for pattern in self.per_pattern_word_counts
            if pattern in patterns
        ]
        return '\n\n'.join(text)
Example #2
def randomu(seed, di=None, binomial=None, double=False, gamma=False,
            normal=False, poisson=False):
    """
    Replicates the randomu function available within IDL
    (Interactive Data Language, EXELISvis).
    Returns an array of uniformly distributed random numbers of the
    specified dimensions.
    The randomu function returns one or more pseudo-random numbers
    with one or more of the following distributions:
    Uniform (default)
    Gaussian
    binomial
    gamma
    poisson

    :param seed:
        If seed is not of type mtrand.RandomState, then a new state is
        initialised. Otherwise seed will be used to generate the random
        values.

    :param di:
        A list specifying the dimensions of the resulting array. If di
        is None, then randomu returns a scalar.
        Dimensions are D1, D2, D3...D8 (x,y,z,lambda...).
        The list will be inverted to suit Python's inverted dimensions
        i.e. (D3,D2,D1).

    :param binomial:
        Set this keyword to a list of length 2, [n,p], to generate
        random deviates from a binomial distribution. If an event
        occurs with probability p, with n trials, then the number of
        times it occurs has a binomial distribution.

    :param double:
        If set to True, then randomu will return double precision
        random numbers.

    :param gamma:
        Set this keyword to an integer order i > 0 to generate random
        deviates from a gamma distribution.

    :param Long:
        If set to True, then randomu will return integer uniform
        random deviates in the range [0...2^31-1], using the Mersenne
        Twister algorithm. All other keywords will be ignored.
        (Currently disabled; see the overflow note in the function body.)

    :param normal:
        If set to True, then random deviates will be generated from a
        normal distribution.

    :param poisson:
        Set this keyword to the mean number of events occurring during
        a unit of time. The poisson keyword returns a random deviate
        drawn from a poisson distribution with that mean.

    :param ULong:
        If set to True, then randomu will return unsigned integer
        uniform deviates in the range [0..2^32-1], using the Mersenne
        Twister algorithm. All other keywords will be ignored.
        (Currently disabled; see the overflow note in the function body.)

    :return:
        A tuple (res, seed), where res is a NumPy array (or scalar, if di
        is None) of random numbers drawn from the requested distribution,
        and seed is the mtrand.RandomState instance that was used.

    Example:
        >>> seed = None
        >>> x, sd = randomu(seed, [10,10])
        >>> x, sd = randomu(seed, [100,100], binomial=[10,0.5])
        >>> x, sd = randomu(seed, [100,100], gamma=2)
        >>> # 200x by 100y array of normally distributed values
        >>> x, sd = randomu(seed, [200,100], normal=True)
        >>> # 1000 deviates from a poisson distribution with a mean of 1.5
        >>> x, sd = randomu(seed, [1000], poisson=1.5)
        >>> # Return a scalar from a uniform distribution
        >>> x, sd = randomu(seed)

    :author:
        Josh Sixsmith, [email protected], [email protected]

    :copyright:
        Copyright (c) 2014, Josh Sixsmith
        All rights reserved.

        Redistribution and use in source and binary forms, with or without
        modification, are permitted provided that the following conditions are met:

        1. Redistributions of source code must retain the above copyright notice, this
           list of conditions and the following disclaimer.
        2. Redistributions in binary form must reproduce the above copyright notice,
           this list of conditions and the following disclaimer in the documentation
           and/or other materials provided with the distribution.

        THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
        ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
        WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
        DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
        ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
        (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
        LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
        ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
        (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
        SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

        The views and conclusions contained in the software and documentation are those
        of the authors and should not be interpreted as representing official policies,
        either expressed or implied, of the FreeBSD Project.
    """

    # Initialise the data type
    if double:
        dtype = 'float64'
    else:
        dtype = 'float32'

    # Check the seed
    # http://stackoverflow.com/questions/5836335/consistenly-create-same-random-numpy-array
    if type(seed) != mtrand.RandomState:
        seed = RandomState()

    if di is not None:
        if type(di) is not list:
            raise TypeError("Dimensions must be a list or None.")
        if len(di) > 8:
            raise ValueError("Error. More than 8 dimensions specified.")
        # Invert the dimensions list
        dims = di[::-1]
    else:
        dims = 1

    # Python has issues with overflow:
    # OverflowError: Python int too large to convert to C long
    # Occurs with Long and ULong
    #if Long:
    #    res = seed.random_integers(0, 2**31-1, dims)
    #    if di is None:
    #        res = res[0]
    #    return res, seed

    #if ULong:
    #    res = seed.random_integers(0, 2**32-1, dims)
    #    if di is None:
    #        res = res[0]
    #    return res, seed

    # Check for other keywords
    distributions = 0
    kwds = [binomial, gamma, normal, poisson]
    for kwd in kwds:
        if kwd:
            distributions += 1

    if distributions > 1:
        raise ValueError("Conflicting keywords.")

    if binomial:
        if len(binomial) != 2:
            msg = "Error. binomial must contain [n,p] trials & probability."
            raise ValueError(msg)

        n = binomial[0]
        p = binomial[1]

        res = seed.binomial(n, p, dims)

    elif gamma:
        res = seed.gamma(gamma, size=dims)

    elif normal:
        res = seed.normal(0, 1, dims)

    elif poisson:
        res = seed.poisson(poisson, dims)

    else:
        res = seed.uniform(size=dims)

    res = res.astype(dtype)

    if di is None:
        res = res[0]

    return res, seed
Example #3
class Particle(object):
    def __init__(self,
                 vocabulary_length,
                 num_users,
                 time_kernels=None,
                 alpha_0=(2, 2),
                 mu_0=1,
                 theta_0=None,
                 seed=None,
                 logweight=0,
                 update_kernels=False,
                 uid=0,
                 omega=1,
                 beta=1,
                 keep_alpha_history=False,
                 mu_rate=0.6):
        self.vocabulary_length = vocabulary_length
        self.vocabulary = None
        self.seed = seed
        self.prng = RandomState(self.seed)
        self.first_observed_time = {}
        self.first_observed_user_time = {}
        self.per_topic_word_counts = {}
        self.per_topic_word_count_total = {}
        self.time_kernels = {}
        self.alpha_0 = alpha_0
        self.mu_0 = mu_0
        self.theta_0 = array(theta_0)
        self._lntheta = _ln(theta_0[0])
        self.logweight = logweight
        self.update_kernels = update_kernels
        self.uid = uid
        self.num_events = 0
        self.topic_previous_event = None
        # The following are for speed optimization purposes
        # A structure to save the total intensity of a topic
        # up to the most recent event t_i of that topic.
        # It will be used to measure the total intensity at
        # any time after t_i
        self._Qn = None
        self.omega = omega
        self.beta = beta
        self.num_users = num_users
        self.keep_alpha_history = keep_alpha_history

        self.user_table_cache = {}
        self.dish_on_table_per_user = {}
        self.dish_on_table_todelete = {}
        self.dish_counters = {}
        self._max_dish = -1
        self.total_tables = 0

        self.table_history_with_user = []
        self.time_previous_user_event = []
        self.total_tables_per_user = []
        self.dish_cache = {}
        self.time_kernel_prior = {}
        self.time_history_per_user = {}
        self.doc_history_per_user = {}
        self.question_history_per_user = {}
        self.table_history_per_user = {}
        self.alpha_history = {}
        self.alpha_distribution_history = {}
        self.mu_rate = mu_rate
        self.mu_per_user = {}
        self.time_elapsed = 0
        self.active_tables_per_user = {}

    def reseed(self, seed=None, uid=None):
        self.seed = seed
        self.prng = RandomState(self.seed)
        if uid is None:
            self.uid = self.prng.randint(maxint)
        else:
            self.uid = uid

    def reset_weight(self):
        self.logweight = 0

    def copy(self):
        new_p = Particle(num_users=self.num_users,
                         vocabulary_length=self.vocabulary_length,
                         seed=self.seed,
                         mu_rate=self.mu_rate,
                         theta_0=self.theta_0,
                         omega=self.omega,
                         beta=self.beta,
                         mu_0=self.mu_0,
                         uid=self.uid,
                         logweight=self.logweight,
                         update_kernels=self.update_kernels,
                         keep_alpha_history=self.keep_alpha_history)
        new_p.alpha_0 = copy(self.alpha_0)
        new_p.num_events = self.num_events
        new_p.topic_previous_event = self.topic_previous_event
        new_p.total_tables = self.total_tables
        new_p._max_dish = self._max_dish

        new_p.time_previous_user_event = copy(self.time_previous_user_event)
        new_p.total_tables_per_user = copy(self.total_tables_per_user)
        new_p.first_observed_time = copy(self.first_observed_time)
        new_p.first_observed_user_time = copy(self.first_observed_user_time)
        new_p.table_history_with_user = copy(self.table_history_with_user)

        new_p.dish_cache = copy_dict(self.dish_cache)
        new_p.dish_counters = copy_dict(self.dish_counters)
        new_p.dish_on_table_per_user = \
            copy_dict(self.dish_on_table_per_user)

        new_p.dish_on_table_per_user = {}
        new_p.dish_on_table_todelete = {}
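        # Rebuild the per-user table-to-dish maps, keeping only tables that
        # are still active; inactive tables are recorded in
        # dish_on_table_todelete (on both the original and the copy) and their
        # entries are dropped from the original's table cache.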
        for u in self.dish_on_table_per_user:
            new_p.dish_on_table_per_user[u] = {}
            new_p.dish_on_table_todelete[u] = {}
            self.dish_on_table_todelete[u] = {}

            for t in self.dish_on_table_per_user[u]:
                if t in self.active_tables_per_user[u]:
                    new_p.dish_on_table_per_user[u][t] = \
                        self.dish_on_table_per_user[u][t]
                else:
                    dish = self.dish_on_table_per_user[u][t]
                    self.dish_on_table_todelete[u][t] = dish
                    new_p.dish_on_table_todelete[u][t] = dish
                    if t in self.user_table_cache[u]:
                        del self.user_table_cache[u][t]

        new_p.per_topic_word_counts = copy_dict(self.per_topic_word_counts)
        new_p.per_topic_word_count_total = copy_dict(
            self.per_topic_word_count_total)
        new_p.time_kernels = copy_dict(self.time_kernels)
        new_p.time_kernel_prior = copy_dict(self.time_kernel_prior)
        new_p.user_table_cache = copy_dict(self.user_table_cache)
        if self.keep_alpha_history:
            new_p.alpha_history = copy_dict(self.alpha_history)
            new_p.alpha_distribution_history = \
                copy_dict(self.alpha_distribution_history)
        new_p.mu_per_user = copy_dict(self.mu_per_user)
        new_p.active_tables_per_user = copy_dict(self.active_tables_per_user)
        return new_p

    def update(self, event):
        """Parses an event and updates the particle


        Parameters
        ----------
        event : tuple
            The event is a 4-tuple of the form (time, content, user, metadata)
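
        Example
        -------
        A hedged usage sketch; the constructor arguments and the event values
        below are illustrative only.

        >>> p = Particle(vocabulary_length=100, num_users=10,
        ...              theta_0=[0.1] * 100, seed=1)
        >>> table, dish = p.update((0.5, 'word1 word2', 3, None))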
        """
        # u_n : user of the n-th event
        # t_n : time of the n-th event
        # d_n : text of the n-th event
        # q_n : any metadata for the n-th event, e.g. the question id
        t_n, d_n, u_n, q_n = event
        d_n = d_n.split()

        if self.num_events == 0:
            self.time_previous_user_event = [0 for i in range(self.num_users)]
            self.total_tables_per_user = [0 for i in range(self.num_users)]
            self.mu_per_user = {
                i: self.sample_mu()
                for i in range(self.num_users)
            }
            self.active_tables_per_user = {
                i: set()
                for i in range(self.num_users)
            }
        if self.num_events >= 1 and u_n in self.time_previous_user_event and \
                self.time_previous_user_event[u_n] > 0:
            log_likelihood_tn = self.time_event_log_likelihood(t_n, u_n)
        else:
            log_likelihood_tn = 0

        tables_before = self.total_tables_per_user[u_n]
        b_n, z_n, opened_table, log_likelihood_dn = self.sample_table(
            t_n, d_n, u_n)
        if self.total_tables_per_user[
                u_n] > tables_before and tables_before > 0:
            # opened a new table
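            # Refresh the user's base rate mu as an exponential moving average
            # of its old value and the empirical table-opening rate (tables
            # opened per unit of time since the user was first observed).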
            old_mu = self.mu_per_user[u_n]
            tables_num = tables_before + 1
            user_alive_time = t_n - self.first_observed_user_time[u_n]
            new_mu = (self.mu_rate * old_mu +
                      (1 - self.mu_rate) * tables_num / user_alive_time)
            self.mu_per_user[u_n] = new_mu

        if z_n not in self.time_kernels:
            self.time_kernels[z_n] = self.sample_time_kernel()
            self.first_observed_time[z_n] = t_n
            self.dish_cache[z_n] = (t_n, 0, 1, 1, 1)
            self._max_dish = z_n
        else:
            if self.update_kernels:
                self.update_time_kernel(t_n, z_n)
        if self.update_kernels and self.keep_alpha_history:
            if z_n not in self.alpha_history:
                self.alpha_history[z_n] = []
                self.alpha_distribution_history[z_n] = []
            self.alpha_history[z_n].append(self.time_kernels[z_n])
            self.alpha_distribution_history[z_n].append(
                self.time_kernel_prior[z_n])
        if self.num_events >= 1:
            self.logweight += log_likelihood_tn
            self.logweight += self._Qn
        self.num_events += 1
        self._update_word_counters(d_n, z_n)

        self.time_previous_user_event[u_n] = t_n
        self.topic_previous_event = z_n
        self.user_previous_event = u_n
        self.table_previous_event = b_n
        self.active_tables_per_user[u_n].add(b_n)
        if z_n not in self.dish_counters:
            self.dish_counters[z_n] = 1
        elif opened_table:
            self.dish_counters[z_n] += 1
        if u_n not in self.first_observed_user_time:
            self.first_observed_user_time[u_n] = t_n
        return b_n, z_n

    def sample_table(self, t_n, d_n, u_n):
        """Samples table b_n and topic z_n together for the event n.


        Parameters
        ----------
        t_n : float
            The time of the event.

        d_n : list
            The document for the event.

        u_n : int
            The user id.


        Returns
        -------
        table : int

        dish : int

        opened_table : bool
            Whether a new table was opened for this event.

        log_likelihood : float
            The log-likelihood of the document under the chosen dish.
        """
        if self.total_tables_per_user[u_n] == 0:
            # This is going to be the user's first table
            self.dish_on_table_per_user[u_n] = {}
            self.user_table_cache[u_n] = {}
            self.time_previous_user_event[u_n] = 0

        tables = range(self.total_tables_per_user[u_n])
        num_dishes = len(self.dish_counters)
        intensities = []
        dn_word_counts = Counter(d_n)
        count_dn = len(d_n)
        # Precompute the doc_log_likelihood for each of the dishes
        dish_log_likelihood = []
        for dish in self.dish_counters:
            dll = self.document_log_likelihood(dn_word_counts, count_dn, dish)
            dish_log_likelihood.append(dll)

        table_intensity_threshold = 1e-8  # below this, the table is inactive

        # Provide one option for each of the already open tables
        mu = self.mu_per_user[u_n]
        total_table_int = mu
        dish_log_likelihood_array = []
        for table in tables:
            if table in self.active_tables_per_user[u_n]:
                dish = self.dish_on_table_per_user[u_n][table]
                alpha = self.time_kernels[dish]
                t_last, sum_kernels = self.user_table_cache[u_n][table]
                update_value = self.kernel(t_n, t_last)
                table_intensity = alpha * sum_kernels * update_value
                table_intensity += alpha * update_value
                total_table_int += table_intensity
                if table_intensity < table_intensity_threshold:
                    self.active_tables_per_user[u_n].remove(table)
                dish_log_likelihood_array.append(dish_log_likelihood[dish])
                intensities.append(table_intensity)
            else:
                dish_log_likelihood_array.append(0)
                intensities.append(0)
        log_intensities = [
            ln(inten_i / total_table_int) +
            dish_log_likelihood_array[i] if inten_i > 0 else -float('inf')
            for i, inten_i in enumerate(intensities)
        ]

        # Provide one option for new table with already existing dish
        for dish in self.dish_counters:
            dish_intensity = (mu / total_table_int) * \
                             self.dish_counters[dish] / (self.total_tables + self.beta)
            dish_intensity = ln(dish_intensity)
            dish_intensity += dish_log_likelihood[dish]
            log_intensities.append(dish_intensity)

        # Provide a last option for new table with new dish
        new_dish_intensity = mu * self.beta / \
                             (total_table_int * (self.total_tables + self.beta))
        new_dish_intensity = ln(new_dish_intensity)
        new_dish_log_likelihood = self.document_log_likelihood(
            dn_word_counts, count_dn, num_dishes)
        new_dish_intensity += new_dish_log_likelihood
        log_intensities.append(new_dish_intensity)

        normalizing_log_intensity = logsumexp(log_intensities)
        intensities = [
            exp(log_intensity - normalizing_log_intensity)
            for log_intensity in log_intensities
        ]
        self._Qn = normalizing_log_intensity
        k = weighted_choice(intensities, self.prng)
        opened_table = False
        if k in tables:
            # Assign to one of the already existing tables
            table = k
            dish = self.dish_on_table_per_user[u_n][table]
            # update cache for that table
            t_last, sum_kernels = self.user_table_cache[u_n][table]
            update_value = self.kernel(t_n, t_last)
            sum_kernels += 1
            sum_kernels *= update_value
            self.user_table_cache[u_n][table] = (t_n, sum_kernels)
        else:
            k = k - len(tables)
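            # After the shift, k indexes the dish: values 0..num_dishes - 1
            # select an existing dish for the new table, while k == num_dishes
            # creates a brand-new dish.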
            table = len(tables)
            self.total_tables += 1
            self.total_tables_per_user[u_n] += 1
            dish = k
            # Since this is a new table, initialize the cache accordingly
            self.user_table_cache[u_n][table] = (t_n, 0)
            self.dish_on_table_per_user[u_n][table] = dish
            opened_table = True
            if dish not in self.time_kernel_prior:
                self.time_kernel_prior[dish] = self.alpha_0
                dll = self.document_log_likelihood(dn_word_counts, count_dn,
                                                   dish)
                dish_log_likelihood.append(dll)

        self.table_history_with_user.append((u_n, table))
        self.time_previous_user_event[u_n] = t_n
        return table, dish, opened_table, dish_log_likelihood[dish]

    def kernel(self, t_i, t_j):
        """Returns the kernel function for t_i and t_j.


        Parameters
        ----------
        t_i : float
            The later timestamp

        t_j : float
            The earlier timestamp


        Returns
        -------
        float
        """
        return exp(-self.omega * (t_i - t_j))

    def update_time_kernel(self, t_n, z_n):
        """Updates the parameter of the time kernel of the chosen pattern
        """
        v_1, v_2 = self.time_kernel_prior[z_n]
        t_last, sum_kernels, event_count, intensity, prod = \
            self.dish_cache[z_n]
        update_value = self.kernel(t_n, t_last)

        sum_kernels += 1
        sum_kernels *= update_value
        prod = sum_kernels
        sum_integrals = event_count - sum_kernels
        sum_integrals /= self.omega
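        # sum_integrals is the closed-form integral of the exponential kernel
        # over the dish's history up to t_n:
        #   sum_i (1 - exp(-omega * (t_n - t_i))) / omega
        #   = (event_count - sum_kernels) / omega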

        self.time_kernel_prior[z_n] = (
            self.alpha_0[0] + event_count - self.dish_counters[z_n],
            self.alpha_0[1] + sum_integrals)

        prior = self.time_kernel_prior[z_n]
        self.time_kernels[z_n] = self.sample_time_kernel(prior)

        self.dish_cache[z_n] = (t_n, sum_kernels, event_count + 1, intensity,
                                prod)

    def sample_time_kernel(self, alpha_0=None):
        if alpha_0 is None:
            alpha_0 = self.alpha_0
        return self.prng.gamma(alpha_0[0], 1. / alpha_0[1])

    def sample_mu(self):
        """Samples a value from the prior of the base intensity mu.


        Returns
        -------
        mu_u : float
            The base intensity of a user, sampled from the prior.
        """
        return self.prng.gamma(self.mu_0[0], self.mu_0[1])

    def document_log_likelihood(self, dn_word_counts, count_dn, z_n):
        """Returns the log likelihood of document d_n to belong to cluster z_n.

        Note: Assumes a symmetric Dirichlet prior (parameter theta_0) on the
        per-pattern word distribution.
        """
        theta = self.theta_0[0]
        V = self.vocabulary_length
        if z_n not in self.per_topic_word_count_total:
            count_zn_no_dn = 0
        else:
            count_zn_no_dn = self.per_topic_word_count_total[z_n]
        # TODO: The code below works only for uniform theta_0. We should
        # put the theta that corresponds to `word`. Here we assume that
        # all the elements of theta_0 are equal
        gamma_numerator = _gammaln(count_zn_no_dn + V * theta)
        gamma_denominator = _gammaln(count_zn_no_dn + count_dn + V * theta)
        is_old_topic = z_n <= self._max_dish
        unique_words = len(dn_word_counts) == count_dn
        topic_words = None
        if is_old_topic:
            topic_words = self.per_topic_word_counts[z_n]

        if unique_words:
            rest = [
                _ln(topic_words[word] + theta)
                if is_old_topic and word in topic_words else self._lntheta
                for word in dn_word_counts
            ]
        else:
            rest = [
                _gammaln(topic_words[word] + dn_word_counts[word] + theta) -
                _gammaln(topic_words[word] + theta)
                if is_old_topic and word in topic_words else
                _gammaln(dn_word_counts[word] + theta) - _gammaln(theta)
                for word in dn_word_counts
            ]
        return gamma_numerator - gamma_denominator + sum(rest)

    def document_history_log_likelihood(self):
        """Computes the log likelihood for the whole history of documents,
        using the inferred parameters.
        """
        doc_log_likelihood = 0
        for user in self.doc_history_per_user:
            for doc, table in zip(self.doc_history_per_user[user],
                                  self.table_history_per_user[user]):
                dish = self.dish_on_table_per_user[user][table]
                doc_word_counts = Counter(doc.split())
                count_doc = len(doc.split())
                doc_log_likelihood += self.document_log_likelihood(
                    doc_word_counts, count_doc, dish)
        return doc_log_likelihood

    def time_event_log_likelihood(self, t_n, u_n):
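        """Returns the log likelihood of user u_n generating an event at t_n:
        the log of the user's intensity at t_n minus the integral of that
        intensity since the user's previous event.
        """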
        mu = self.mu_per_user[u_n]
        integral = (t_n - self.time_previous_user_event[u_n]) * mu
        intensity = mu
        for table in self.user_table_cache[u_n]:
            t_last, sum_timedeltas = self.user_table_cache[u_n][table]
            update_value = self.kernel(t_n, t_last)
            topic_sum = (sum_timedeltas + 1) - \
                        (sum_timedeltas + 1) * update_value
            dish = self.dish_on_table_per_user[u_n][table]
            topic_sum *= self.time_kernels[dish]
            integral += topic_sum
            intensity += (sum_timedeltas + 1) \
                         * self.time_kernels[dish] * update_value
        return ln(intensity) - integral

    def _update_word_counters(self, d_n, z_n):
        if z_n not in self.per_topic_word_counts:
            self.per_topic_word_counts[z_n] = {}
        if z_n not in self.per_topic_word_count_total:
            self.per_topic_word_count_total[z_n] = 0
        for word in d_n:
            if word not in self.per_topic_word_counts[z_n]:
                self.per_topic_word_counts[z_n][word] = 0
            self.per_topic_word_counts[z_n][word] += 1
            self.per_topic_word_count_total[z_n] += 1
        return

    def to_process(self, events):
        """Exports the particle as a HDHProcess object.

        Use the exported object to plot the user timelines.

        Returns
        -------
        HDHProcess
        """
        process = HDHProcess(num_patterns=len(self.time_kernels),
                             mu_0=self.mu_0,
                             alpha_0=self.alpha_0,
                             vocabulary=self.vocabulary)
        process.mu_per_user = self.mu_per_user
        process.table_history_per_user = self.table_history_per_user
        process.time_history_per_user = self.time_history_per_user
        process.dish_on_table_per_user = self.dish_on_table_per_user
        process.time_kernels = self.time_kernels
        process.first_observed_time = self.first_observed_time
        process.omega = self.omega
        process.num_users = self.num_users
        process.document_history_per_user = self.doc_history_per_user
        process.per_pattern_word_counts = self.per_topic_word_counts
        process.per_pattern_word_count_total = self.per_topic_word_count_total
        process.events = events
        process.dish_counters = self.dish_counters
        process.total_tables = self.total_tables

        return process

    def get_intensity(self, t_n, u_n, z_n):
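        """Returns the intensity of pattern z_n for user u_n at time t_n.

        The intensity is the pattern's share of the user's base rate plus the
        Hawkes contribution of the user's open tables that serve that pattern.
        """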
        pi_z = self.dish_counters[z_n] / self.total_tables
        mu = self.mu_per_user[u_n]
        alpha = self.time_kernels[z_n]
        intensity = pi_z * mu
        for table in self.user_table_cache[u_n]:
            dish = self.dish_on_table_per_user[u_n][table]
            if dish == z_n:
                t_last, sum_timedeltas = self.user_table_cache[u_n][table]
                update_value = self.kernel(t_n, t_last)
                table_intensity = alpha * sum_timedeltas * update_value
                table_intensity += alpha * update_value
                intensity += table_intensity
        return intensity
Example #4
0
import numpy as np
import seaborn as sns
from numpy.random import RandomState

from pymoreg.model import MGNREnsemble
# random_mbc, sample_from_gn and BGe are pymoreg helpers used below; their
# exact import paths are not shown in this snippet.

sns.set(color_codes=True)
SAVE = False

n_variables = 15
n_features = 10
seeds = list(range(101, 200))
rng = RandomState(1802)
variables = list(range(n_variables))
n_samples = 300

# Data generation parameters
gen_mean = np.zeros(n_variables)
gen_var = rng.gamma(shape=1, size=n_variables)
# gen_weight = 2

# Generate some data from a GN
graph = random_mbc(n_features, n_variables - n_features, rng=rng, fan_in=5)
beta = np.multiply(graph.A.T, rng.normal(0, 2, size=graph.shape))
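# beta holds the linear edge weights: Gaussian draws (mean 0, scale 2) masked
# by the transposed adjacency matrix (graph.A.T) of the sampled graph.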

sample_seed = rng.randint(0, 2 ** 32 - 1)
data = sample_from_gn(graph, gen_mean, gen_var, beta, n_samples, sample_seed)
test_seed = rng.randint(0, 2 ** 32 - 1)  # fresh seed so the test set differs from the training data
test = sample_from_gn(graph, gen_mean, gen_var, beta, n_samples, test_seed)

X, Y = data[:, :n_features], data[:, n_features:]
X_test, Y_test = test[:, :n_features], test[:, n_features:]

graph_score = BGe(data)(graph)