def score_actions(stats: SearchStats,
                  exploration_coef: float,
                  noise_scale: float,
                  noise_alpha: float,
                  rng: RandomState) -> np.ndarray:
    """Score actions (child nodes) according to the UCT heuristic (AlphaZero flavor)."""
    prior = stats.prior_prob  # P(s, a)
    if noise_scale:
        # Mix in Dirichlet noise to explore random actions at the root node.
        noise = rng.dirichlet(np.full(len(prior), noise_alpha))
        prior = ((1.0 - noise_scale) * prior + noise_scale * noise).astype(
            np.float32)
    visit_gap = np.sqrt(np.sum(stats.num_visits)) / (1.0 + stats.num_visits)
    action_uct_boost = exploration_coef * prior * visit_gap  # U(s, a)
    action_value = stats.total_value / stats.num_visits.clip(min=1)  # Q(s, a)
    score = action_value + action_uct_boost  # Q(s, a) + U(s, a)
    return score
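# Hypothetical usage sketch (not part of the original module): `SearchStats`
# is assumed here to be a simple container exposing per-action `prior_prob`,
# `num_visits`, and `total_value` arrays of equal length; the real
# constructor may differ.
#
#     stats = SearchStats(
#         prior_prob=np.array([0.5, 0.3, 0.2], dtype=np.float32),
#         num_visits=np.array([10, 5, 0]),
#         total_value=np.array([6.0, 2.0, 0.0]),
#     )
#     scores = score_actions(stats, exploration_coef=1.25,
#                            noise_scale=0.25, noise_alpha=0.3,
#                            rng=RandomState(0))
#     best_action = int(np.argmax(scores))  # maximizer of Q(s, a) + U(s, a)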
class TestDirichletExpectation(unittest.TestCase):
    """Test log_dirichlet_expectation"""

    def setUp(self):
        self.rand = RandomState(0)

    def test_dirichlet_expectation_with_sampling(self):
        alpha = np.ones(10)
        samples = int(1e5)
        expectation = log_dirichlet_expectation(alpha)
        sample_mean = np.mean(np.log(self.rand.dirichlet(alpha, samples)), 0)
        assert_almost_equal(expectation, sample_mean, decimal=2)

    def test_2d_dirichlet_expectation(self):
        alpha = self.rand.choice(range(1, 10), 20).reshape(2, 10)
        exp1 = log_dirichlet_expectation(alpha)
        exp2 = psi(alpha) - psi(np.sum(alpha, 1))[:, np.newaxis]
        assert_almost_equal(exp1, exp2)
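# For context, a minimal sketch of the function under test, assuming the
# standard closed form E[log X_k] = psi(alpha_k) - psi(sum_j alpha_j) for
# X ~ Dirichlet(alpha), applied row-wise to 2-D inputs. The package's actual
# implementation lives elsewhere and may differ in detail:
#
#     def log_dirichlet_expectation(alpha):
#         if alpha.ndim == 1:
#             return psi(alpha) - psi(np.sum(alpha))
#         return psi(alpha) - psi(np.sum(alpha, 1))[:, np.newaxis]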
class HDHProcess:
    def __init__(self, num_patterns, alpha_0, mu_0, vocabulary, omega=1,
                 doc_length=20, doc_min_length=5, words_per_pattern=10,
                 random_state=None):
        """
        Parameters
        ----------
        num_patterns : int
            The maximum number of patterns that will be shared across the
            users.

        alpha_0 : tuple
            The parameter that is used when sampling the time kernel weights
            of each pattern. The distribution that is being used is a Gamma.
            This tuple should be of the form (shape, scale).

        mu_0 : tuple
            The parameter of the Gamma distribution that is used to sample
            each user's \mu (activity level). This tuple should be of the
            form (shape, scale).

        vocabulary : list
            The list of available words to use when generating documents.

        omega : float, default is 1
            The decay parameter of the exponential decay kernel.

        doc_length : int, default is 20
            The maximum number of words per document.

        doc_min_length : int, default is 5
            The minimum number of words per document.

        words_per_pattern : int, default is 10
            The number of words that will have a non-zero probability to
            appear in each pattern.

        random_state : int or RandomState object, default is None
            The random number generator.
        """
        self.prng = check_random_state(random_state)
        self.doc_prng = RandomState(self.prng.randint(200000000))
        self.time_kernel_prng = RandomState(self.prng.randint(2000000000))
        self.pattern_param_prng = RandomState(self.prng.randint(2000000000))

        self.num_patterns = num_patterns
        self.alpha_0 = alpha_0
        self.vocabulary = vocabulary
        self.mu_0 = mu_0
        self.document_length = doc_length
        self.document_min_length = doc_min_length
        self.omega = omega
        self.words_per_pattern = words_per_pattern

        self.pattern_params = self.sample_pattern_params()
        self.time_kernels = self.sample_time_kernels()
        self.pattern_popularity = self.sample_pattern_popularity()

        # Initialize all the counters etc.
        self.reset()

    def reset(self):
        """Removes all the events and users already sampled.

        Note
        ----
        It does not reseed the random number generator. It also retains the
        already sampled pattern parameters (word distributions and alphas).
        """
        self.mu_per_user = {}
        self.num_users = 0
        self.time_history = []
        self.dish_counters = defaultdict(int)
        self.last_event_user_pattern = defaultdict(lambda: defaultdict(int))
        self.total_tables = 0
        self.first_observed_time = {}
        self.user_table_cache = defaultdict(dict)
        self.table_history_per_user = defaultdict(list)
        self.time_history_per_user = defaultdict(list)
        self.document_history_per_user = defaultdict(list)
        self.dish_on_table_per_user = defaultdict(dict)
        self.cache_per_user = defaultdict(dict)
        self.total_tables_per_user = defaultdict(int)
        self.events = []
        self.per_pattern_word_counts = defaultdict(lambda: defaultdict(int))
        self.per_pattern_word_count_total = defaultdict(int)

    def sample_pattern_params(self):
        """Returns the word distributions for each pattern.

        Returns
        -------
        parameters : list
            A list of word distributions, one for each pattern.
        """
        sampled_params = {}
        V = len(self.vocabulary)
        for pattern in range(self.num_patterns):
            custom_theta = [0] * V
            words_in_pattern = self.prng.choice(V,
                                                size=self.words_per_pattern,
                                                replace=False)
            for word in words_in_pattern:
                custom_theta[word] = 100. / self.words_per_pattern
            sampled_params[pattern] = \
                self.pattern_param_prng.dirichlet(custom_theta)
        return sampled_params

    def sample_time_kernels(self):
        """Returns the time decay parameter of each pattern.
        Returns
        -------
        alphas : list
            A list of time decay parameters, one for each pattern.
        """
        alphas = {
            pattern: self.time_kernel_prng.gamma(self.alpha_0[0],
                                                 self.alpha_0[1])
            for pattern in range(self.num_patterns)
        }
        return alphas

    def sample_pattern_popularity(self):
        """Returns a popularity distribution over the patterns.

        Returns
        -------
        pattern_popularities : list
            A list with the popularity distribution of each pattern.
        """
        pattern_popularity = {}
        pi = self.pattern_param_prng.dirichlet(
            [1 for pattern in range(self.num_patterns)])
        for pattern_i, pi_i in enumerate(pi):
            pattern_popularity[pattern_i] = pi_i
        return pattern_popularity

    def sample_mu(self):
        """Samples a value from the prior of the base intensity mu.

        Returns
        -------
        mu_u : float
            The base intensity of a user, sampled from the prior.
        """
        return self.prng.gamma(self.mu_0[0], self.mu_0[1])

    def sample_next_time(self, pattern, user):
        """Samples the time of the next event of a pattern for a given user.

        Parameters
        ----------
        pattern : int
            The pattern index that we want to sample the next event for.

        user : int
            The index of the user that we want to sample for.

        Returns
        -------
        timestamp : float
        """
        U = self.prng.rand
        mu_u = self.mu_per_user[user]
        lambda_u_pattern = mu_u * self.pattern_popularity[pattern]

        if user in self.total_tables_per_user:
            # We have seen the user before
            num_tables = self.total_tables_per_user[user]
            pattern_tables = [
                table for table in range(num_tables)
                if self.dish_on_table_per_user[user][table] == pattern
            ]
        else:
            pattern_tables = []

        alpha = self.time_kernels[pattern]
        if not pattern_tables:
            # No previous events for this pattern: sample from a homogeneous
            # Poisson process with the base rate.
            lambda_star = lambda_u_pattern
            s = -1 / lambda_star * np.log(U())
            return s
        else:
            # Add the \alpha of the most recent table (previous event) in the
            # user-pattern intensity
            s = self.last_event_user_pattern[user][pattern]
            pattern_intensity = 0
            for table in pattern_tables:
                t_last, sum_kernels = self.user_table_cache[user][table]
                update_value = self.kernel(s, t_last)
                # update_value should be 1, so practically we are just adding
                # \alpha to the intensity dt after the event
                table_intensity = alpha * sum_kernels * update_value
                table_intensity += alpha * update_value
                pattern_intensity += table_intensity
            lambda_star = lambda_u_pattern + pattern_intensity

            # New event: thinning-style sampling, i.e. propose with the
            # upper-bound rate lambda_star, then accept or reject against
            # the true intensity at the proposed time.
            accepted = False
            while not accepted:
                s = s - 1 / lambda_star * np.log(U())

                # Rejection test
                pattern_intensity = 0
                for table in pattern_tables:
                    t_last, sum_kernels = self.user_table_cache[user][table]
                    update_value = self.kernel(s, t_last)
                    # update_value should be 1, so practically we are just
                    # adding \alpha to the intensity dt after the event
                    table_intensity = alpha * sum_kernels * update_value
                    table_intensity += alpha * update_value
                    pattern_intensity += table_intensity
                lambda_s = lambda_u_pattern + pattern_intensity

                if U() < lambda_s / lambda_star:
                    return s
                else:
                    lambda_star = lambda_s

    def sample_user_events(self, min_num_events=100, max_num_events=None,
                           t_max=None):
        """Samples events for a user.

        Parameters
        ----------
        min_num_events : int, default is 100
            The minimum number of events to sample.

        max_num_events : int, default is None
            If not None, this is the maximum number of events to sample.

        t_max : float, default is None
            The time limit until which to sample events.

        Returns
        -------
        events : list
            A list of the form [(t_i, doc_i, user_i, meta_i), ...] sorted by
            increasing time that has all the events of the sampled users.
            Above, doc_i is the document and meta_i is any sort of metadata
            that we want for doc_i, e.g. question_id. The generator will
            return an empty list for meta_i.
""" user = len(self.mu_per_user) mu_u = self.sample_mu() self.mu_per_user[user] = mu_u # Populate the list with the first event for each pattern next_time_per_pattern = [ self.sample_next_time(pattern, user) for pattern in range(self.num_patterns) ] next_time_per_pattern = asfortranarray(next_time_per_pattern) iteration = 0 over_tmax = False while iteration < min_num_events or not over_tmax: if max_num_events is not None and iteration > max_num_events: break z_n = next_time_per_pattern.argmin() t_n = next_time_per_pattern[z_n] if t_max is not None and t_n > t_max: over_tmax = True break num_tables_user = self.total_tables_per_user[user] \ if user in self.total_tables_per_user else 0 tables = range(num_tables_user) tables = [ table for table in tables if self.dish_on_table_per_user[user][table] == z_n ] intensities = [] alpha = self.time_kernels[z_n] for table in tables: t_last, sum_kernels = self.user_table_cache[user][table] update_value = self.kernel(t_n, t_last) table_intensity = alpha * sum_kernels * update_value table_intensity += alpha * update_value intensities.append(table_intensity) intensities.append(mu_u * self.pattern_popularity[z_n]) log_intensities = [ ln(inten_i) if inten_i > 0 else -float('inf') for inten_i in intensities ] normalizing_log_intensity = logsumexp(log_intensities) intensities = [ exp(log_intensity - normalizing_log_intensity) for log_intensity in log_intensities ] k = weighted_choice(intensities, self.prng) if k < len(tables): # Assign to already existing table table = tables[k] # update cache for that table t_last, sum_kernels = self.user_table_cache[user][table] update_value = self.kernel(t_n, t_last) sum_kernels += 1 sum_kernels *= update_value self.user_table_cache[user][table] = (t_n, sum_kernels) else: table = num_tables_user self.total_tables += 1 self.total_tables_per_user[user] += 1 # Since this is a new table, initialize the cache accordingly self.user_table_cache[user][table] = (t_n, 0) self.dish_on_table_per_user[user][table] = z_n if z_n not in self.first_observed_time or\ t_n < self.first_observed_time[z_n]: self.first_observed_time[z_n] = t_n self.dish_counters[z_n] += 1 doc_n = self.sample_document(z_n) self._update_word_counters(doc_n.split(), z_n) self.document_history_per_user[user]\ .append(doc_n) self.table_history_per_user[user].append(table) self.time_history_per_user[user].append(t_n) self.last_event_user_pattern[user][z_n] = t_n # Resample time for that pattern next_time_per_pattern[z_n] = self.sample_next_time(z_n, user) z_n = next_time_per_pattern.argmin() t_n = next_time_per_pattern[z_n] iteration += 1 events = [(self.time_history_per_user[user][i], self.document_history_per_user[user][i], user, []) for i in range(len(self.time_history_per_user[user]))] # Update the full history of events with the ones generated for the # current user and re-order everything so that the events are # ordered by their timestamp self.events += events self.events = sorted(self.events, key=lambda x: x[0]) self.num_users += 1 return events def kernel(self, t_i, t_j): """Returns the kernel function for t_i and t_j. Parameters ---------- t_i : float Timestamp representing `now`. t_j : float Timestamp representaing `past`. Returns ------- float """ return exp(-self.omega * (t_i - t_j)) def sample_document(self, pattern): """Sample a random document from a specific pattern. Parameters ---------- pattern : int The pattern from which to sample the content. Returns ------- str A space separeted string that contains all the sampled words. 
""" length = self.doc_prng.randint(self.document_length) + \ self.document_min_length words = self.doc_prng.multinomial(length, self.pattern_params[pattern]) return ' '.join([ self.vocabulary[i] for i, repeats in enumerate(words) for j in range(repeats) ]) def _update_word_counters(self, doc, pattern): """Updates the word counters of the process for the particular document Parameters ---------- doc : list A list with the words in the document. pattern : int The index of the latent pattern that this document belongs to. """ for word in doc: self.per_pattern_word_counts[pattern][word] += 1 self.per_pattern_word_count_total[pattern] += 1 return def pattern_content_str(self, patterns=None, show_words=-1, min_word_occurence=5): """Return the content information for the patterns of the process. Parameters ---------- patterns : list, default is None If this list is provided, only information about the patterns in the list will be returned. show_words : int, default is -1 The maximum number of words to show for each pattern. Notice that the words are sorted according to their occurence count. min_word_occurence : int, default is 5 Only show words that show up at least `min_word_occurence` number of times in the documents of the respective pattern. Returns ------- str A string with all the content information """ if patterns is None: patterns = self.per_pattern_word_counts.keys() text = [ '___Pattern %d___ \n%s\n%s' % (pattern, '\n'.join([ '%s : %d' % (k, v) for i, (k, v) in enumerate( sorted(self.per_pattern_word_counts[pattern].iteritems(), key=lambda x: (x[1], x[0]), reverse=True)) if v >= min_word_occurence and (show_words == -1 or (show_words > 0 and i < show_words)) ]), ' '.join([ k for i, (k, v) in enumerate( sorted(self.per_pattern_word_counts[pattern].iteritems(), key=lambda x: (x[1], x[0]), reverse=True)) if v < min_word_occurence and (show_words == -1 or (show_words > 0 and i < show_words)) ])) for pattern in self.per_pattern_word_counts if pattern in patterns ] return '\n\n'.join(text) def user_patterns_set(self, user): """Return the patterns that a specific user adopted. Parameters ---------- user : int The index of a user. Returns ------- set The set of the patterns that the user adopted. """ pattern_list = [ self.dish_on_table_per_user[user][table] for table in self.table_history_per_user[user] ] return list(set(pattern_list)) def user_pattern_history_str(self, user=None, patterns=None, show_time=True, t_min=0): """Returns a representation of the history of a user's actions and the pattern that they correspond to. Parameters ---------- user : int, default is None An index to the user we want to inspect. If None, the function runs over all the users. patterns : list, default is None If not None, limit the actions returned to the ones that belong in the provided patterns. show_time : bool, default is True Control wether the timestamp will appear in the representation or not. t_min : float, default is 0 The timestamp after which we only consider actions. 
        Returns
        -------
        str
        """
        if patterns is not None and type(patterns) is not set:
            patterns = set(patterns)
        return '\n'.join([
            '%spattern=%2d task=%3d (u=%d) %s' %
            ('%5.3g ' % t if show_time else '', dish, table, u, doc)
            for u in range(self.num_users)
            for ((t, doc), (table, dish)) in zip(
                [(t, d)
                 for t, d in zip(self.time_history_per_user[u],
                                 self.document_history_per_user[u])],
                [(table, self.dish_on_table_per_user[u][table])
                 for table in self.table_history_per_user[u]])
            if (user is None or user == u)
            and (patterns is None or dish in patterns) and t >= t_min
        ])

    def _plot_user(self, user, fig, num_samples, T_max, task_detail,
                   seed=None, patterns=None, colormap=None, T_min=0,
                   paper=True):
        """Helper function that plots the intensity of a single user."""
        tics = np.arange(T_min, T_max, (T_max - T_min) / num_samples)
        tables = sorted(set(self.table_history_per_user[user]))
        active_tables = set()
        dish_set = set()
        times = [t for t in self.time_history_per_user[user]]
        start_event = min([
            i for i, t in enumerate(self.time_history_per_user[user])
            if t >= T_min
        ])

        table_cache = {}  # partial sum for each table
        dish_cache = {}  # partial sum for each dish
        table_intensities = [[] for _ in range(len(tables))]
        dish_intensities = [[] for _ in range(len(self.time_kernels))]

        i = 0  # index of tics
        j = start_event  # index of events
        while i < len(tics):
            if j >= len(times) or tics[i] < times[j]:
                # We need to measure the intensity
                dish_intensities, table_intensities = \
                    self._measure_intensities(
                        tics[i],
                        dish_cache=dish_cache,
                        table_cache=table_cache,
                        tables=tables,
                        user=user,
                        dish_intensities=dish_intensities,
                        table_intensities=table_intensities)
                i += 1
            else:
                # We need to record an event and update our caches
                dish_cache, table_cache, active_tables, dish_set = \
                    self._update_cache(
                        times[j],
                        dish_cache=dish_cache,
                        table_cache=table_cache,
                        event_table=self.table_history_per_user[user][j],
                        tables=tables,
                        user=user,
                        active_tables=active_tables,
                        dish_set=dish_set)
                j += 1

        if patterns is not None:
            dish_set = set([d for d in patterns if d in dish_set])
        dish_dict = {dish: i for i, dish in enumerate(dish_set)}
        if colormap is None:
            prng = RandomState(seed)
            num_dishes = len(dish_set)
            colormap = qualitative_cmap(n_colors=num_dishes)
            colormap = prng.permutation(colormap)
        if not task_detail:
            for dish in dish_set:
                fig.plot(tics,
                         dish_intensities[dish],
                         color=colormap[dish_dict[dish]],
                         linestyle='-',
                         label="Pattern " + str(dish),
                         linewidth=3.)
        else:
            for table in active_tables:
                dish = self.dish_on_table_per_user[user][table]
                if patterns is not None and dish not in patterns:
                    continue
                fig.plot(tics,
                         table_intensities[table],
                         color=colormap[dish_dict[dish]],
                         linewidth=3.)
            for dish in dish_set:
                if patterns is not None and dish not in patterns:
                    continue
                fig.plot([], [],
                         color=colormap[dish_dict[dish]],
                         linestyle='-',
                         label="Pattern " + str(dish),
                         linewidth=5.)
        return fig

    def _measure_intensities(self, t, dish_cache, table_cache, tables, user,
                             dish_intensities, table_intensities):
        """Measures the intensities of all tables and dishes for the plot
        function.
        Parameters
        ----------
        t : float
        dish_cache : dict
        table_cache : dict
        tables : list
        user : int
        dish_intensities : list of lists
        table_intensities : list of lists

        Returns
        -------
        dish_intensities : list of lists
        table_intensities : list of lists
        """
        updated_dish = [False] * len(self.time_kernels)
        for table in tables:
            dish = self.dish_on_table_per_user[user][table]
            lambda_uz = self.mu_per_user[user] * self.pattern_popularity[dish]
            alpha = self.time_kernels[dish]
            if dish in dish_cache:
                t_last_dish, sum_kernels_dish = dish_cache[dish]
                update_value_dish = self.kernel(t, t_last_dish)
                dish_intensity = lambda_uz + alpha * sum_kernels_dish * \
                    update_value_dish
                dish_intensity += alpha * update_value_dish
            else:
                dish_intensity = lambda_uz
            if table in table_cache:
                # The table already exists
                t_last_table, sum_kernels_table = table_cache[table]
                update_value_table = self.kernel(t, t_last_table)
                table_intensity = alpha * sum_kernels_table * \
                    update_value_table
                table_intensity += alpha * update_value_table
            else:
                # The table does not exist yet
                table_intensity = 0
            table_intensities[table].append(table_intensity)
            if not updated_dish[dish]:
                # Make sure to update each dish only once across all tables
                dish_intensities[dish].append(dish_intensity)
                updated_dish[dish] = True
        return dish_intensities, table_intensities

    def _update_cache(self, t, dish_cache, table_cache, event_table, tables,
                      user, active_tables, dish_set):
        """Updates the caches for the plot function when an event is
        recorded.

        Parameters
        ----------
        t : float
        dish_cache : dict
        table_cache : dict
        event_table : int
        tables : list
        user : int
        active_tables : set
        dish_set : set

        Returns
        -------
        dish_cache : dict
        table_cache : dict
        active_tables : set
        dish_set : set
        """
        dish = self.dish_on_table_per_user[user][event_table]
        active_tables.add(event_table)
        dish_set.add(dish)
        if event_table not in table_cache:
            table_cache[event_table] = (t, 0)
        else:
            t_last, sum_kernels = table_cache[event_table]
            update_value = self.kernel(t, t_last)
            sum_kernels += 1
            sum_kernels *= update_value
            table_cache[event_table] = (t, sum_kernels)
        if dish not in dish_cache:
            dish_cache[dish] = (t, 0)
        else:
            t_last, sum_kernels = dish_cache[dish]
            update_value = self.kernel(t, t_last)
            sum_kernels += 1
            sum_kernels *= update_value
            dish_cache[dish] = (t, sum_kernels)
        return dish_cache, table_cache, active_tables, dish_set

    def plot(self, num_samples=500, T_min=0, T_max=None, start_date=None,
             users=None, user_limit=50, patterns=None, task_detail=False,
             save_to_file=False, filename="user_timelines",
             intensity_threshold=None, paper=True, colors=None, fig_width=20,
             fig_height_per_user=5, time_unit='months', label_every=3,
             seed=None):
        """Plots the intensity of a set of users for a set of patterns over
        a time period. In this plot, each user is a separate subplot and for
        each user the plot shows her event rate for each separate pattern
        that she has been active at.

        Parameters
        ----------
        num_samples : int, default is 500
            The granularity level of the intensity line. A smaller number of
            samples results in faster plotting, while a larger number gives
            a much more detailed result.

        T_min : float, default is 0
            The minimum timestamp that the plot shows, in seconds.

        T_max : float, default is None
            If not None, this is the maximum timestamp that the plot
            considers, in seconds.

        start_date : datetime, default is None
            If provided, this is the actual datetime that corresponds to
            time 0. This is required if `paper` is True.

        users : list, default is None
            If provided, this list contains the id's of the users that will
            be plotted.
            Actually, only the first `user_limit` of them will be shown.

        user_limit : int, default is 50
            The maximum number of users to plot.

        patterns : list, default is None
            The list of patterns that will be shown in the final plot. If
            None, all of the patterns will be plotted.

        task_detail : bool, default is False
            If True, the plot has one line per task. Otherwise, we only plot
            the cumulative intensity of all tasks under the same pattern.

        save_to_file : bool, default is False
            If True, the plot will be saved to a `pdf` and a `png` file.

        filename : str, default is 'user_timelines'
            The name of the output file that will be used when saving the
            plot.

        intensity_threshold : float, default is None
            If provided, this is the maximum intensity value that will be
            plotted, i.e. the y_max that will be the cut-off threshold for
            the y-axis.

        paper : bool, default is True
            If True, the plot result will be the same as the figures that
            are in the published paper.

        colors : list, default is None
            A list of colors that will be used for the plot. Each color will
            correspond to a single pattern, and will be shared across all
            the users.

        fig_width : int, default is 20
            The width of the figure that will be returned.

        fig_height_per_user : int, default is 5
            The height of each separate user-plot of the final figure. If
            multiplied by the number of users, this determines the total
            height of the figure. Notice that due to a matplotlib
            constraint(?) the total height of the figure cannot be over 70.

        time_unit : str, default is 'months'
            Controls whether the time unit is measured in days (in which
            case it should be set to 'days') or months.

        label_every : int, default is 3
            The frequency of the labels that show in the x-axis.

        seed : int, default is None
            A seed to the random number generator used to assign colors to
            patterns.
        Returns
        -------
        fig : matplotlib.Figure object
        """
        prng = RandomState(seed)
        num_users = len(self.dish_on_table_per_user)
        if users is None:
            users = range(num_users)
        num_users_to_plot = min(len(users), user_limit)
        users = users[:num_users_to_plot]
        if T_max is None:
            T_max = max(
                [self.time_history_per_user[user][-1] for user in users])
        fig = plt.figure(figsize=(fig_width,
                                  min(fig_height_per_user *
                                      num_users_to_plot, 70)))
        num_patterns_global = len(self.time_kernels)
        colormap = qualitative_cmap(n_colors=num_patterns_global)
        colormap = prng.permutation(colormap)
        if colors is not None:
            colormap = matplotlib.colors.ListedColormap(colors)
        if paper:
            sns.set_style('white')
            sns.despine(bottom=True, top=False, right=False, left=False)
        user_plt_axes = []
        max_intensity = -float('inf')
        for i, user in enumerate(users):
            if user not in self.time_history_per_user \
                    or not self.time_history_per_user[user] \
                    or self.time_history_per_user[user][-1] < T_min:
                # No events were generated for this user during the time
                # window that we are interested in
                continue
            if patterns is not None:
                user_patterns = set([
                    self.dish_on_table_per_user[user][table]
                    for table in self.table_history_per_user[user]
                ])
                if not any([pattern in patterns
                            for pattern in user_patterns]):
                    # The user did not generate events in the patterns of
                    # interest
                    continue
            user_plt = plt.subplot(num_users_to_plot, 1, i + 1)
            user_plt = self._plot_user(user,
                                       user_plt,
                                       num_samples,
                                       T_max,
                                       task_detail=task_detail,
                                       seed=seed,
                                       patterns=patterns,
                                       colormap=colormap,
                                       T_min=T_min,
                                       paper=paper)
            user_plt.set_xlim((T_min, T_max))
            if paper:
                if start_date is None:
                    raise ValueError(
                        'For paper-level quality plots, the actual datetime '
                        'for t=0 must be provided as `start_date`')
                if start_date.microsecond > 500000:
                    start_date = start_date.replace(microsecond=0) \
                        + datetime.timedelta(seconds=1)
                else:
                    start_date = start_date.replace(microsecond=0)
                if time_unit == 'days':
                    t_min_seconds = T_min * 86400
                    t_max_seconds = T_max * 86400
                    t1 = start_date + datetime.timedelta(0, t_min_seconds)
                    t2 = start_date + datetime.timedelta(0, t_max_seconds)
                    ticks = monthly_ticks_for_days(t1, t2)
                    labels = monthly_labels(t1, t2, every=label_every)
                elif time_unit == 'months':
                    t1 = month_add(start_date, T_min)
                    t2 = month_add(start_date, T_max)
                    ticks = monthly_ticks_for_months(t1, t2)
                    labels = monthly_labels(t1, t2, every=label_every)
                labels[-1] = ''
                user_plt.set_xlim((ticks[0], ticks[-1]))
                user_plt.yaxis.set_ticks([])
                user_plt.xaxis.set_ticks(ticks)
                plt.setp(user_plt.xaxis.get_majorticklabels(), rotation=-0)
                user_plt.tick_params('x',
                                     length=10,
                                     which='major',
                                     direction='out',
                                     top=False,
                                     bottom=True)
                user_plt.xaxis.set_ticklabels(labels, fontsize=30)
                user_plt.tick_params(axis='x', which='major', pad=10)
                user_plt.get_xaxis() \
                    .majorTicks[1].label1.set_horizontalalignment('left')
                for tick in user_plt.xaxis.get_major_ticks():
                    tick.label1.set_horizontalalignment('left')
            current_intensity = user_plt.get_ylim()[1]
            if current_intensity > max_intensity:
                max_intensity = current_intensity
            user_plt_axes.append(user_plt)
            if not paper:
                plt.title('User %2d' % user)
                plt.legend()
        if intensity_threshold is None:
            intensity_threshold = max_intensity
        for ax in user_plt_axes:
            ax.set_ylim((-0.2, intensity_threshold))
        if paper and save_to_file:
            print("Create image %s.png" % filename)
            plt.savefig(filename + '.png', transparent=True)
            plt.savefig(filename + '.pdf', transparent=True)
            sns.set_style(None)
        return fig

    def user_patterns(self, user):
        """Returns a list with the patterns that a user has adopted.
        Parameters
        ----------
        user : int
        """
        pattern_list = [
            self.dish_on_table_per_user[user][table]
            for table in self.table_history_per_user[user]
        ]
        return list(set(pattern_list))

    def show_annotated_events(self, user=None, patterns=None, show_time=True,
                              T_min=0, T_max=None):
        """Returns a string where each event is annotated with the inferred
        pattern.

        Parameters
        ----------
        user : int, default is None
            If given, the events returned are limited to the selected user.

        patterns : list, default is None
            If not None, an event is returned only if it belongs to one of
            the selected patterns.

        show_time : bool, default is True
            Controls whether the time of the event will be shown.

        T_min : float, default is 0
            Controls the minimum timestamp after which the events will be
            shown.

        T_max : float, default is None
            If given, T_max controls the maximum timestamp shown.

        Returns
        -------
        str
        """
        if patterns is not None and type(patterns) is not set:
            patterns = set(patterns)
        if show_time:
            return '\n'.join([
                '%5.3g pattern=%3d task=%3d (u=%d) %s' %
                (t, dish, table, u, doc)
                for u in range(self.num_users)
                for ((t, doc), (table, dish)) in zip(
                    [(t, d)
                     for t, d in zip(self.time_history_per_user[u],
                                     self.document_history_per_user[u])],
                    [(table, self.dish_on_table_per_user[u][table])
                     for table in self.table_history_per_user[u]])
                if (user is None or user == u)
                and (patterns is None or dish in patterns)
                and t >= T_min and (T_max is None or t <= T_max)
            ])
        else:
            return '\n'.join([
                'pattern=%3d task=%3d (u=%d) %s' % (dish, table, u, doc)
                for u in range(self.num_users)
                for ((t, doc), (table, dish)) in zip(
                    [(t, d)
                     for t, d in zip(self.time_history_per_user[u],
                                     self.document_history_per_user[u])],
                    [(table, self.dish_on_table_per_user[u][table])
                     for table in self.table_history_per_user[u]])
                if (user is None or user == u)
                and (patterns is None or dish in patterns)
                and t >= T_min and (T_max is None or t <= T_max)
            ])

    def show_pattern_content(self, patterns=None, words=0,
                             detail_threshold=5):
        """Shows the content distribution of the inferred patterns.

        Parameters
        ----------
        patterns : list, default is None
            If not None, only the content of the selected patterns will be
            shown.

        words : int, default is 0
            A positive number that controls how many words will be shown.
            The words are shown sorted by their likelihood, starting with
            the most probable.

        detail_threshold : int, default is 5
            A positive number that sets the lower bound on the number of
            times that a word appeared in a pattern so that its count is
            shown.

        Returns
        -------
        str
        """
        if patterns is None:
            patterns = self.per_pattern_word_counts.keys()
        text = [
            '___Pattern %d___ \n%s\n%s' % (pattern, '\n'.join([
                '%s : %d' % (k, v)
                for i, (k, v) in enumerate(
                    sorted(self.per_pattern_word_counts[pattern].items(),
                           key=lambda x: (x[1], x[0]),
                           reverse=True))
                if v >= detail_threshold and (words == 0 or i < words)
            ]), ' '.join([
                k
                for i, (k, v) in enumerate(
                    sorted(self.per_pattern_word_counts[pattern].items(),
                           key=lambda x: (x[1], x[0]),
                           reverse=True))
                if v < detail_threshold and (words == 0 or i < words)
            ])) for pattern in self.per_pattern_word_counts
            if pattern in patterns
        ]
        return '\n\n'.join(text)
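# A minimal usage sketch for HDHProcess (the vocabulary and parameter values
# below are illustrative assumptions, not values from the original source):
#
#     process = HDHProcess(num_patterns=10,
#                          alpha_0=(2.5, 0.75),
#                          mu_0=(2, 0.5),
#                          vocabulary=['word%d' % i for i in range(100)],
#                          omega=3.5,
#                          random_state=42)
#     for _ in range(5):  # sample the event streams of 5 users
#         process.sample_user_events(min_num_events=100, t_max=365)
#     print(process.user_pattern_history_str(user=0))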
def main(outdir):
    rng = RandomState(MT19937(SeedSequence(config.seed)))

    num_employees = 50000
    num_orders = 1000000
    num_jobsites = 2800
    num_areas = 180
    num_qualifications = 214
    num_qualigroups = 13
    num_shifts = 4
    num_days = 2708
    start_day = datetime(2013, 8, 1)

    print("create sliding window of active employees")
    active_employees = np.zeros((num_employees, num_days)).astype(bool)
    left = 0
    right = 100
    upkeep = 400
    change = (.95, 1 - .95)
    for irow, row in enumerate(active_employees):
        active_employees[irow, left:right] = 1
        left = left + rng.choice([0, 1], p=change)
        right = left + upkeep + rng.choice([0, 1], p=change)

    print("create base distributions for areas, qualis and shifts")
    areas = rng.dirichlet(np.ones(num_areas) * .1)
    jobsites = rng.dirichlet(np.ones(num_jobsites) * .1)
    area_of_jobsite = np.empty(num_jobsites)
    for ijobsite, jobsite in enumerate(jobsites):
        area_of_jobsite[ijobsite] = rng.choice(np.arange(num_areas), p=areas)

    qualigroups = rng.dirichlet(np.ones(num_qualigroups) * .1)
    qualis = rng.dirichlet(np.ones(num_qualifications) * .1)
    qualigroup_of_quali = np.empty(num_qualifications)
    for iquali, quali in enumerate(qualis):
        qualigroup_of_quali[iquali] = rng.choice(np.arange(num_qualigroups),
                                                 p=qualigroups)

    shifts = rng.dirichlet(np.ones(num_shifts))

    orders = []
    for _ in tqdm(range(num_orders), desc="create orders"):
        shift = rng.choice(range(num_shifts), p=shifts)
        jobsite = rng.choice(range(num_jobsites), p=jobsites)
        area = area_of_jobsite[jobsite]
        quali = rng.choice(range(num_qualifications), p=qualis)
        qualigroup = qualigroup_of_quali[quali]
        day = rng.randint(0, num_days)
        orders.append({
            "Schicht": shift,
            "Einsatzort": jobsite,
            "PLZ": area,
            "Qualifikation": quali,
            "Qualifikationgruppe": qualigroup,
            "Tag": day,
        })

    employee_qualifications = rng.multinomial(
        1, qualis, size=num_employees).astype(bool)
    employee_jobsites = rng.multinomial(
        1, jobsites, size=num_employees).astype(bool)

    orders = pd.DataFrame(orders)

    offers = []
    ps = np.ones(6) / np.arange(1, 7)
    ps /= ps.sum()
    for _, order in tqdm(orders.iterrows(), desc="create offers",
                         total=len(orders)):
        match_active = active_employees[:, int(order.Tag)]
        match_quali = employee_qualifications[:, int(order.Qualifikation)]
        match_jobsite = employee_jobsites[:, int(order.Einsatzort)]
        match, = (match_active & match_quali & match_jobsite).nonzero()
        offers.append(match[:6].tolist())
        if len(offers[-1]) == 0:
            # No exact match: fall back to a random sample of active
            # employees, drawing everything from the seeded generator so
            # the output stays reproducible.
            offers[-1] = rng.choice(match_active.nonzero()[0],
                                    rng.choice(range(1, 7), p=ps)).tolist()

    berlin_holidays = holidays.DE(prov="BE")
    orders["Mitarbeiter ID"] = offers

    print("add day meta data")
    orders["Tag"] = orders["Tag"].apply(
        lambda day: start_day + timedelta(day))
    orders["Wochentag"] = orders["Tag"].apply(lambda day: day.strftime("%a"))
    orders["Feiertag"] = orders["Tag"].apply(
        lambda day: day in berlin_holidays)
    orders = orders[[
        "Einsatzort", "PLZ", "Qualifikation", "Qualifikationgruppe",
        "Schicht", "Tag", "Wochentag", "Feiertag", "Mitarbeiter ID"
    ]]
    orders = orders.sort_values("Tag")

    train, test = train_test_split(orders)
    train.to_csv(os.path.join(outdir, "train.tsv"), index=False, sep="\t")
    test.to_csv(os.path.join(outdir, "test_truth.tsv"), index=False,
                sep="\t")
    test[[
        "Einsatzort", "PLZ", "Qualifikation", "Qualifikationgruppe",
        "Schicht", "Tag", "Wochentag", "Feiertag"
    ]].to_csv(os.path.join(outdir, "test_publish.tsv"), index=False,
              sep="\t")
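if __name__ == "__main__":
    # Hypothetical entry point (a sketch: the project's actual CLI may
    # differ; `sys.argv[1]` is assumed to be the output directory).
    import sys

    main(sys.argv[1] if len(sys.argv) > 1 else ".")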