def beam_search(data, target_class, time_budget=conf.TIME_BUDGET, enable_i=True, top_k=conf.TOP_K,
                beam_width=conf.BEAM_WIDTH, iterations_limit=conf.ITERATIONS_NUMBER, theta=conf.THETA,
                quality_measure=conf.QUALITY_MEASURE, diverse=True):
    """Beam search over patterns, returning the best non-redundant ones found.

    The search starts from the empty sequence. At each level every seed is
    expanded with ``compute_children``, and each child is scored with
    ``compute_quality`` against ``target_class``. The child is recorded both in
    the global result set and in the level's beam. The best ``beam_width``
    children of the beam become the seeds of the next level.

    The search stops as soon as any of these happens:
      * ``iterations_limit`` candidate patterns have been evaluated,
      * ``time_budget`` seconds have elapsed (checked before every evaluation),
      * a level produces no children, so there is nothing left to expand.

    :param data: dataset forwarded to ``extract_items``, ``compute_quality`` and
        the redundancy filtering of ``PrioritySet``.
    :param target_class: class against which pattern quality is computed.
    :param time_budget: elapsed-time budget, in seconds.
    :param enable_i: forwarded unchanged to ``compute_children``.
    :param top_k: number of patterns returned.
    :param beam_width: number of seeds kept from one level to the next.
    :param iterations_limit: maximum number of candidate patterns evaluated.
    :param theta: forwarded to the result ``PrioritySet``.
    :param quality_measure: forwarded to ``compute_quality``.
    :param diverse: if True, seeds are chosen with
        ``get_top_k_non_redundant``; otherwise with ``get_top_k``.
    :return: ``get_top_k_non_redundant(data, top_k)`` of all evaluated patterns.
    """
    import time  # monotonic clock: elapsed time is immune to system clock changes

    items = extract_items(data)
    deadline = time.monotonic() + time_budget

    def can_continue():
        # Shared stopping condition: iteration limit and time budget.
        return nb_iteration < iterations_limit and time.monotonic() < deadline

    sorted_patterns = PrioritySet(top_k, theta=theta)
    nb_iteration = 0
    seeds = [[]]  # the empty sequence is the root of the search

    # `seeds` being empty means the search space is exhausted. Stop instead of
    # spinning on empty levels until the time budget runs out.
    while seeds and can_continue():
        beam = PrioritySet()
        for seed in seeds:
            if not can_continue():
                break
            for child in compute_children(seed, items, enable_i):
                # Checked per child so one expensive level cannot overrun the budget.
                if not can_continue():
                    break
                quality = compute_quality(data, child, target_class, quality_measure=quality_measure)
                sorted_patterns.add(child, quality)
                beam.add(child, quality)
                nb_iteration += 1

        if diverse:
            best = beam.get_top_k_non_redundant(data, beam_width)
        else:
            best = beam.get_top_k(beam_width)
        # Entries are (quality, pattern) pairs; only the patterns are re-expanded.
        seeds = [pattern for _, pattern in best]

    return sorted_patterns.get_top_k_non_redundant(data, top_k)
def test_over_top_k():
    """Asking for more entries than are stored yields only the stored ones."""
    queue = PrioritySet()
    queue.add(frozenset({1, 2}), 0.1)

    retrieved = queue.get_top_k(2)
    assert len(retrieved) == 1