Example #1
    def _split_subtree(tree: SortedList,
                       max_length: int) -> Tuple[SortedList, SortedList]:
        """
        Split a tree by its median element into two smaller trees.

        :param tree: The tree to split
        :param max_length: The size of the largest possible element in the trie, in bits
        :return: The tree with the smaller elements,
                 and the tree with the larger elements
        """
        median = tree.bisect_right(
            YFastTrie._calculate_representative(tree[len(tree) // 2],
                                                max_length))
        return SortedList(tree.islice(stop=median)), SortedList(
            tree.islice(start=median))
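For reference, here is the same split-by-index pattern with a plain SortedList and islice, leaving out the YFastTrie-specific representative logic; the sample values are made up.

from sortedcontainers import SortedList

tree = SortedList([3, 1, 4, 1, 5, 9, 2, 6])      # stored sorted: [1, 1, 2, 3, 4, 5, 6, 9]
median = len(tree) // 2
smaller = SortedList(tree.islice(stop=median))   # the 4 smallest elements
larger = SortedList(tree.islice(start=median))   # the 4 largest elements
assert list(smaller) + list(larger) == list(tree)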
Example #2
def test_merge_subtrees(values):
	values = SortedList(values)
	split = randint(1, len(values) - 1)
	left_tree = SortedList(values.islice(stop=split))
	right_tree = SortedList(values.islice(start=split))
	new_left, new_right = YFastTrie._merge_subtrees(left_tree, right_tree, 2 * max_trie_entry_size)

	if len(values) <= 2 * max_trie_entry_size:
		assert new_right is None
		assert isinstance(new_left, SortedList)
		assert len(new_left) == len(values)
	else:
		assert isinstance(new_left, SortedList)
		assert isinstance(new_right, SortedList)
		assert len(new_left) + len(new_right) == len(values)
		assert YFastTrie._calculate_representative(max(new_left), max_trie_entry_size) < min(new_right)
Example #3
class Leaderboard:
    def __init__(self):
        # Time Complexity: O(1) (init)
        # Space Complexity: O(N) (overall)

        self.scores = {}
        self.sorted_scores = SortedList()

    def addScore(self, playerId: int, score: int) -> None:
        # Time Complexity: O(log N)
        # Space Complexity: O(1)

        if playerId in self.scores:
            self.sorted_scores.remove(self.scores[playerId])
            score += self.scores[playerId]

        self.scores[playerId] = score
        self.sorted_scores.add(score)

    def top(self, K: int) -> int:
        # Time Complexity: O(K) or O(N) (depending on implementation)
        # Space Complexity: O(1)

        start = max(0, len(self.scores) - K)
        return sum(self.sorted_scores.islice(start))

    def reset(self, playerId: int) -> None:
        # Time Complexity: O(log N)
        # Space Complexity: O(1)

        self.addScore(playerId, -self.scores.get(playerId, 0))
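A hypothetical usage sketch of the Leaderboard above; the scores mirror the familiar LeetCode-style example, and the expected values assume the methods behave exactly as written.

lb = Leaderboard()
for player_id, score in [(1, 73), (2, 56), (3, 39), (4, 51), (5, 4)]:
    lb.addScore(player_id, score)
print(lb.top(1))   # 73
lb.reset(1)        # reset() zeroes the score in place by adding its negative
lb.reset(2)
lb.addScore(2, 51)
print(lb.top(3))   # 141 (51 + 51 + 39)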
Example #4
class RangeList:
    """Add and query non-overlapping intervals. Intervals are semi-closed, e.g.
    the interval [1, 3) contains the points {1, 2}.
    """
    def __init__(self, init=None):
        self.data = SortedList(init, key=lambda x: x[0])

    def add(self, start, end):
        left = self.data.bisect_right((start, 0))
        if left > 0:
            if self.data[left - 1][1] >= start:
                start = self.data[left - 1][0]
                left -= 1

        right = self.data.bisect_right((end, 0))
        if right > 0:
            if self.data[right - 1][1] >= end:
                end = self.data[right - 1][1]

        for _ in range(right - left):
            self.data.pop(left)

        self.data.add((start, end))

    def list(self):
        return list(self.data)

    def iter(self):
        return self.data.islice(start=0)
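A hypothetical usage sketch of RangeList showing how overlapping and adjacent semi-closed intervals get merged; the interval values are made up.

rl = RangeList()
rl.add(1, 3)            # [1, 3)
rl.add(5, 8)            # [1, 3), [5, 8)
rl.add(2, 6)            # bridges the gap, so everything collapses to [1, 8)
print(rl.list())        # [(1, 8)]
rl.add(10, 12)
print(list(rl.iter()))  # [(1, 8), (10, 12)]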
Example #5
def test_islice():
    sl = SortedList(load=7)

    assert [] == list(sl.islice())

    values = list(range(53))
    sl.update(values)

    for start in range(53):
        for stop in range(53):
            assert list(sl.islice(start, stop)) == values[start:stop]

    for start in range(53):
        for stop in range(53):
            assert list(sl.islice(start, stop,
                                  reverse=True)) == values[start:stop][::-1]

    for start in range(53):
        assert list(sl.islice(start=start)) == values[start:]
        assert list(sl.islice(start=start,
                              reverse=True)) == values[start:][::-1]

    for stop in range(53):
        assert list(sl.islice(stop=stop)) == values[:stop]
        assert list(sl.islice(stop=stop, reverse=True)) == values[:stop][::-1]
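For quick reference, a minimal stand-alone demonstration of the islice semantics exercised by this test, written against a current sortedcontainers release (the `load` argument above comes from an older API).

from sortedcontainers import SortedList

sl = SortedList('acegikmoqs')                  # stored sorted: a c e g i k m o q s
print(list(sl.islice(1, 4)))                   # ['c', 'e', 'g']  (indices 1..3)
print(list(sl.islice(stop=3, reverse=True)))   # ['e', 'c', 'a']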
Example #6
def test_merge_large_subtrees(values):
	values = SortedList(values)
	split = len(values) // 2
	left_tree = SortedList(values.islice(stop=split))
	right_tree = SortedList(values.islice(start=split))
	new_left, new_right = YFastTrie._merge_subtrees(left_tree, right_tree, 2 * max_trie_entry_size)
	assert isinstance(new_left, SortedList)
	assert isinstance(new_right, SortedList)
	assert len(new_left) + len(new_right) == len(values)
	assert YFastTrie._calculate_representative(max(new_left), max_trie_entry_size) < min(new_right)

	split += 1
	left_tree = SortedList(values.islice(stop=split))
	right_tree = SortedList(values.islice(start=split))
	new_left, new_right = YFastTrie._merge_subtrees(left_tree, right_tree, 2 * max_trie_entry_size)
	assert isinstance(new_left, SortedList)
	assert isinstance(new_right, SortedList)
	assert len(new_left) + len(new_right) == len(values)
	assert YFastTrie._calculate_representative(max(new_left), max_trie_entry_size) < min(new_right)
Example #7
class ZSet:
    def __init__(self):
        self.mem2score = {}
        self.scores = SortedList()

    def __contains__(self, val):
        return val in self.mem2score

    def __setitem__(self, val, score):
        self.add(val, score)

    def __getitem__(self, key):
        return self.mem2score[key]

    def __len__(self):
        return len(self.mem2score)

    def __iter__(self):
        def f():
            for score, val in self.scores:
                yield val
        return f()

    def __str__(self):
        ans = ['{}: {}'.format(val, score) for score, val in self.scores]
        return 'ZSet({})'.format(', '.join(ans))

    def get(self, key, default=None):
        return self.mem2score.get(key, default)

    def add(self, val, score):
        s_prev = self.mem2score.get(val, None)
        if s_prev is not None:
            if s_prev == score:
                return False
            self.scores.remove((s_prev, val))
        self.mem2score[val] = score
        self.scores.add((score, val))
        return True

    def discard(self, key):
        try:
            score = self.mem2score.pop(key)
        except KeyError:
            return
        self.scores.remove((score, key))

    def items(self):
        return self.mem2score.items()

    def rank(self, member):
        return self.scores.index((self.mem2score[member], member))

    def islice_score(self, start, stop, reverse=False):
        return self.scores.islice(start, stop, reverse)
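A hypothetical usage sketch of ZSet with invented members and scores, assuming the SortedList import the snippet relies on.

zs = ZSet()
zs.add('alice', 3.0)
zs.add('bob', 1.5)
zs['carol'] = 2.0                     # __setitem__ delegates to add()
print(list(zs))                       # ['bob', 'carol', 'alice']  (ordered by score)
print(zs.rank('carol'))               # 1
print(list(zs.islice_score(0, 2)))    # [(1.5, 'bob'), (2.0, 'carol')]
zs.discard('bob')
print(len(zs))                        # 2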
Example #8
class MovieRentingSystem(object):
    def __init__(self, n, entries):
        """
        :type n: int
        :type entries: List[List[int]]
        """
        self.movies = defaultdict(SortedList)
        self.shops = defaultdict(dict)
        self.renting = SortedList([])
        for shop, movie, price in entries:
            self.movies[movie].add((price, shop))
            self.shops[shop][movie] = price

    def search(self, movie):
        """
        :type movie: int
        :rtype: List[int]
        """
        return [i[1] for i in list(self.movies[movie].islice(stop=5))]

    def rent(self, shop, movie):
        """
        :type shop: int
        :type movie: int
        :rtype: None
        """
        price = self.shops[shop][movie]
        self.movies[movie].discard((price, shop))
        self.renting.add((price, shop, movie))

    def drop(self, shop, movie):
        """
        :type shop: int
        :type movie: int
        :rtype: None
        """
        price = self.shops[shop][movie]
        self.movies[movie].add((price, shop))
        self.renting.discard((price, shop, movie))

    def report(self):
        """
        :rtype: List[List[int]]
        """
        return [[x, y] for _, x, y in self.renting.islice(stop=5)]
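A hypothetical usage sketch of MovieRentingSystem; the entries mirror the familiar LeetCode 1912 example ([shop, movie, price] triples), and the expected outputs assume the methods behave as written above.

mrs = MovieRentingSystem(3, [[0, 1, 5], [0, 2, 6], [0, 3, 7],
                             [1, 1, 4], [1, 2, 7], [2, 1, 5]])
print(mrs.search(1))   # [1, 0, 2]  -- cheapest unrented copies of movie 1
mrs.rent(0, 1)
mrs.rent(1, 2)
print(mrs.report())    # [[0, 1], [1, 2]]  -- cheapest currently rented copies
mrs.drop(1, 2)
print(mrs.search(2))   # [0, 1]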
Example #9
def generateTestData(queryPositivePercentage, sampleSize, maxUniverse):
    uncompressedList = []
    queryList = SortedList([])
    negativeList = []
    samplePercentage = sampleSize / maxUniverse * 100
    random.seed(datetime.now().timestamp())
    for x in range(1, maxUniverse):
        if (random.randint(1, 100) <= samplePercentage):
            uncompressedList.append(x)
            if (random.randint(1, 100) <= queryPositivePercentage):
                queryList.add(x)
            elif (len(negativeList) > 0):
                # you add some random element from the negative list
                randomIndex = random.randrange(len(negativeList))
                queryList.add(negativeList.pop(randomIndex))
        else:
            negativeList.append(x)
    return (uncompressedList, list(queryList.islice(0, len(queryList))))
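A hypothetical call to generateTestData; the percentages and universe size are made up, and the concrete output is random by construction.

uncompressed, queries = generateTestData(queryPositivePercentage=60,
                                         sampleSize=100, maxUniverse=1000)
print(len(uncompressed), len(queries))   # roughly 100 sampled values and their (sorted) queries
print(queries[:5])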
Example #10
  def get_top_recs(self, user_id, sim_users, n):
    '''
    Given user_id with its top k similar users, return top n
    chefs with highest predicted ratings. Does not filter out
    chefs that the user has already rated.

    @param int user_id user to get n chef recommendations for
    @param int sim_users the list of similar users for user_id
    @param int n number of chef recommendations to return
    @returns list top_chefs top n predicted rated chefs for user_id
    '''
    
    summation = None
    for s_user in sim_users:
      if summation is None:
        summation = self.sparse_ratings_matrix[s_user]
      else:
        summation += self.sparse_ratings_matrix[s_user]

    length = len(sim_users)
    summation = summation/length

    rated_chefs = summation[0].nonzero()[1] # list of chef ids with non-zero ratings

    sorted_chefs = SortedList([])
    for rated_chef_id in rated_chefs:
      sorted_chefs.add(KeyValPair(rated_chef_id, summation[0, rated_chef_id]))

    # give predicted rating, too, maybe
    top_chefs = list(map(lambda x: x.key, list(sorted_chefs.islice(0, n, reverse=True))))
    return top_chefs


# start = time.time()
# x = ChefRecommendationEngine()
# end = time.time()
# print(end - start)
# print(x.recs)

# start = time.time()
# x.receive_new_rating(32, 22, 2)
# end = time.time()
# print(end - start)
# print(x.recs)
Example #12
    class ManagerNode(object):
        def __init__(self, node, parent, position):
            """
            :param node: the behavior that is represented by this ManagerNode
            :type node: py_trees.behaviour.Behaviour
            :param parent: the parent of the behavior that is represented by this ManagerNode
            :type parent: py_trees.behaviour.Behaviour
            :param position: the position of the node in the list of children of the parent
            :type position: int
            """
            self.node = node
            self.parent = parent
            self.position = position
            self.disabled_children = SortedList()
            self.enabled_children = SortedList()

        def __lt__(self, other):
            return self.position < other.position

        def __gt__(self, other):
            return self.position > other.position

        def __eq__(self, other):
            return self.node == other.node and self.parent == other.parent

        def disable_child(self, manager_node):
            """
            marks the given manager node as disabled in the internal tree representation and removes it from the behavior tree
            :param manager_node:
            :type manager_node: TreeManager.ManagerNode
            :return:
            """
            self.enabled_children.remove(manager_node)
            self.disabled_children.add(manager_node)
            if isinstance(self.node, PluginBehavior):
                self.node.remove_plugin(manager_node.node.name)
            else:
                self.node.remove_child(manager_node.node)

        def enable_child(self, manager_node):
            """
            marks the given manager node as enabled in the internal tree representation and adds it to the behavior tree
            :param manager_node:
            :type manager_node: TreeManager.ManagerNode
            :return:
            """
            self.disabled_children.remove(manager_node)
            self.enabled_children.add(manager_node)
            if isinstance(self.node, PluginBehavior):
                self.node.add_plugin(manager_node.node)
            else:
                idx = self.enabled_children.index(manager_node)
                self.node.insert_child(manager_node.node, idx)

        def add_child(self, manager_node):
            """
            adds the given manager node to the internal tree map and the corresponding behavior to the behavior tree
            :param manager_node:
            :type manager_node: TreeManager.ManagerNode
            :return:
            """
            if isinstance(self.node, PluginBehavior):
                self.enabled_children.add(manager_node)
                self.node.add_plugin(manager_node.node)
            else:
                if manager_node.position < 0:
                    manager_node.position = 0
                    if self.enabled_children:
                        manager_node.position = max(
                            manager_node.position,
                            self.enabled_children[-1].position + 1)
                    if self.disabled_children:
                        manager_node.position = max(
                            manager_node.position,
                            self.disabled_children[-1].position + 1)
                    idx = manager_node.position
                else:
                    idx = self.disabled_children.bisect_left(manager_node)
                    for c in self.disabled_children.islice(start=idx):
                        c.position += 1
                    idx = self.enabled_children.bisect_left(manager_node)
                    for c in self.enabled_children.islice(start=idx):
                        c.position += 1
                self.node.insert_child(manager_node.node, idx)
                self.enabled_children.add(manager_node)

        def remove_child(self, manager_node):
            """
            removes the given manager_node from the internal tree map and the corresponding behavior from the behavior tree
            :param manager_node:
            :type manager_node: TreeManager.ManagerNode
            :return:
            """
            if isinstance(self.node, PluginBehavior):
                if manager_node in self.enabled_children:
                    self.enabled_children.remove(manager_node)
                    self.node.remove_plugin(manager_node.node.name)
                elif manager_node in self.disabled_children:
                    self.disabled_children.remove(manager_node)
                else:
                    raise RuntimeError(
                        'could not remove node from parent. this probably means that the tree is inconsistent'
                    )
            else:
                if manager_node in self.enabled_children:
                    self.enabled_children.remove(manager_node)
                    self.node.remove_child(manager_node.node)
                elif manager_node in self.disabled_children:
                    self.disabled_children.remove(manager_node)
                else:
                    raise RuntimeError(
                        'could not remove node. this probably means that the tree is inconsistent'
                    )
                idx = self.disabled_children.bisect_right(manager_node)
                for c in self.disabled_children.islice(start=idx):
                    c.position -= 1
                idx = self.enabled_children.bisect_right(manager_node)
                for c in self.enabled_children.islice(start=idx):
                    c.position -= 1
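A small, self-contained sketch of the position-shifting idiom used in add_child and remove_child above; Sibling is a made-up stand-in for ManagerNode.

from sortedcontainers import SortedList

class Sibling:
    def __init__(self, position):
        self.position = position
    def __lt__(self, other):
        return self.position < other.position

siblings = SortedList(Sibling(p) for p in range(5))
idx = siblings.bisect_left(Sibling(2))   # first sibling at position >= 2
for s in siblings.islice(start=idx):     # shift positions 2, 3, 4 up by one
    s.position += 1                      # a uniform shift keeps the sort order intact
print([s.position for s in siblings])    # [0, 1, 3, 4, 5]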
Example #13
def process_swaps_for_eon(operator_eon_number):
    checkpoint_created = RootCommitment.objects.filter(
        eon_number=operator_eon_number).exists()

    notification_queue = []

    with transaction.atomic():
        default_time = timezone.now() - datetime.timedelta(days=365000)
        # Match swaps
        last_unprocessed_swap_time = timezone.make_aware(
            datetime.datetime.fromtimestamp(
                cache.get_or_set('last_unprocessed_swap_time',
                                 default_time.timestamp())))

        unprocessed_swaps = Transfer.objects \
            .filter(
                time__gte=last_unprocessed_swap_time,
                processed=False,
                complete=False,
                voided=False,
                cancelled=False,
                swap=True,
                eon_number=operator_eon_number,
                sender_active_state__operator_signature__isnull=False,
                recipient_active_state__operator_signature__isnull=False) \
            .select_for_update() \
            .order_by('time')

        order_books_cache = {}

        for swap in unprocessed_swaps:
            last_unprocessed_swap_time = max(
                last_unprocessed_swap_time,
                swap.time + datetime.timedelta(milliseconds=1))
            matched_successfully = False
            with transaction.atomic(), swap.lock(
                    auto_renewal=True), swap.wallet.lock(
                        auto_renewal=True), swap.recipient.lock(
                            auto_renewal=True):
                swap_wallet_view_context = WalletTransferContext(
                    wallet=swap.wallet, transfer=swap)
                swap_recipient_view_context = WalletTransferContext(
                    wallet=swap.recipient, transfer=swap)

                if swap_expired(swap, operator_eon_number, checkpoint_created):
                    logger.info('Retiring swap')
                    swap.retire_swap()
                    continue
                if should_void_swap(swap, swap_wallet_view_context,
                                    swap_recipient_view_context,
                                    operator_eon_number, checkpoint_created):
                    logger.info('Voiding swap.')
                    swap.close(voided=True)
                    continue
                elif swap.is_fulfilled_swap():
                    logger.info('Skipping finalized swap.')
                    continue

                opposite_order_book_name = '{}-{}'.format(
                    swap.recipient.token.short_name,
                    swap.wallet.token.short_name)

                # If this is a sell order then the opposite order book is for buys, which should be sorted
                # in descending order by price so that the first element in the list is the highest priced.
                opposite_comparison_function = price_comparison_function(
                    inverse=swap.sell_order, reverse=swap.sell_order)
                if opposite_order_book_name not in order_books_cache:
                    print("FETCHED")
                    opposite_swaps = Transfer.objects\
                        .filter(
                            id__lte=swap.id,
                            wallet__token=swap.recipient.token,
                            recipient__token=swap.wallet.token,
                            processed=False,
                            complete=False,
                            voided=False,
                            cancelled=False,
                            swap=True,
                            eon_number=operator_eon_number,
                            sender_active_state__operator_signature__isnull=False,
                            recipient_active_state__operator_signature__isnull=False)\
                        .select_for_update()

                    order_books_cache[opposite_order_book_name] = SortedList(
                        opposite_swaps,
                        key=cmp_to_key(opposite_comparison_function))
                else:
                    print("CACHED")

                opposite_order_book = SortedList(
                    order_books_cache[opposite_order_book_name],
                    key=cmp_to_key(opposite_comparison_function))
                opposite_orders_consumed = 0

                if len(opposite_order_book) == 0:
                    print("EMPTY")

                opposite_swaps = Transfer.objects \
                    .filter(
                        id__lte=swap.id,
                        wallet__token=swap.recipient.token,
                        recipient__token=swap.wallet.token,
                        processed=False,
                        complete=False,
                        voided=False,
                        cancelled=False,
                        swap=True,
                        eon_number=operator_eon_number,
                        sender_active_state__operator_signature__isnull=False,
                        recipient_active_state__operator_signature__isnull=False) \
                    .select_for_update()
                assert (len(opposite_order_book) == opposite_swaps.count())

                for opposite in opposite_order_book:
                    # BUY Price: amount / amount_swapped
                    # SELL Price: amount_swapped / amount

                    if swap.sell_order:
                        logger.info('SELL FOR {} VS BUY AT {}'.format(
                            swap.amount_swapped / swap.amount,
                            opposite.amount / opposite.amount_swapped))

                    else:
                        logger.info('BUY AT {} VS SELL FOR {}'.format(
                            swap.amount / swap.amount_swapped,
                            opposite.amount_swapped / opposite.amount))

                    # The invariant is that the buy order price is greater than or equal to the sell order price
                    invariant = swap.amount * \
                        opposite.amount >= opposite.amount_swapped * swap.amount_swapped

                    if not invariant:
                        break

                    with opposite.lock(
                            auto_renewal=True), opposite.wallet.lock(
                                auto_renewal=True), opposite.recipient.lock(
                                    auto_renewal=True):
                        opposite_wallet_view_context = WalletTransferContext(
                            wallet=opposite.wallet, transfer=opposite)
                        opposite_recipient_view_context = WalletTransferContext(
                            wallet=opposite.recipient, transfer=opposite)

                        if swap_expired(opposite, operator_eon_number,
                                        checkpoint_created):
                            opposite.retire_swap()
                            opposite_orders_consumed += 1
                            continue
                        if should_void_swap(opposite,
                                            opposite_wallet_view_context,
                                            opposite_recipient_view_context,
                                            operator_eon_number,
                                            checkpoint_created):
                            opposite.close(voided=True)
                            opposite_orders_consumed += 1
                            continue
                        elif opposite.is_fulfilled_swap():
                            opposite_orders_consumed += 1
                            continue

                        matched_successfully = match_limit_to_limit(
                            swap, opposite)

                        if opposite.is_fulfilled_swap():
                            opposite_orders_consumed += 1
                            try:
                                opposite.sign_swap_fulfillment(
                                    settings.HUB_OWNER_ACCOUNT_ADDRESS,
                                    settings.HUB_OWNER_ACCOUNT_KEY)
                            except LookupError as e:
                                logger.error(e)

                        if swap.is_fulfilled_swap():
                            try:
                                swap.sign_swap_fulfillment(
                                    settings.HUB_OWNER_ACCOUNT_ADDRESS,
                                    settings.HUB_OWNER_ACCOUNT_KEY)
                            except LookupError as e:
                                logger.error(e)

                        if swap.is_fulfilled_swap():
                            break

                order_books_cache[
                    opposite_order_book_name] = opposite_order_book.islice(
                        opposite_orders_consumed)

                swap_order_book_name = '{}-{}'.format(
                    swap.wallet.token.short_name,
                    swap.recipient.token.short_name)
                if not swap.is_fulfilled_swap(
                ) and swap_order_book_name in order_books_cache:
                    swap_comparison_function = price_comparison_function(
                        inverse=not swap.sell_order,
                        reverse=not swap.sell_order)
                    swap_order_book = SortedList(
                        order_books_cache[swap_order_book_name],
                        key=cmp_to_key(swap_comparison_function))
                    swap_order_book.add(swap)
                    order_books_cache[swap_order_book_name] = swap_order_book

            if matched_successfully:
                notification_queue.append((swap.id, opposite.id))

        cache.set('last_unprocessed_swap_time',
                  last_unprocessed_swap_time.timestamp())

    for swap_id, opposite_id in notification_queue:
        operator_celery.send_task('auditor.tasks.on_swap_matching',
                                  args=[swap_id, opposite_id])
Example #14
def band_search(alpha, A, y, A_fix, s_start, N1_max, bandwidth, adaptive, a, b,
                algorithm):
    """Performs the forward band search fo FBMP2 from a given s-vector

    Args:
        alpha (float): hyperparameter alpha = sigma / sigma_x
        A (numpy.ndarray): the feature matrix [A_{m,n}]
        y (numpy.ndarray): target vector [y_m]
        A_fix (numpy.ndarray): fixed feature matrix [A_fix_{m,n_fix}]
        s_start (sequence of :obj:`int`): starting s-vector
        N1_max (int): highest total number of active features to be investigated
        bandwidth (int): If `adaptive`==False, this is the number of extensions
            in each layer of the search. If `adaptive`==True, it is the minimal
            number of times any one feature's s'_n must be found 0 and 1 among
            the set of nodes (s') targeted by the extensions.
        adaptive (bool): If True, the set of extensions is adaptively chosen,
            such that it includes both 0 and 1 for all s'_n values at least
            `bandwidth` number of times.
        a (float): "a" parameter of inverse-gamma prior of sigma_x^2 (>= 0)
        b (float): "b" parameter of inverse-gamma prior of sigma_x^2 (>= 0)
        algorithm (str): "M-space" or "N1-space", which selects
            whether NodeM and ExtensionM, or NodeN1 and ExtensionN1 classes
            are used

    Returns:
        dict: dict containing

            "alpha" (float): input alpha value

            "s" (:obj:`list` of :obj:`bytes`): discovered s vectors

            "logL" (:obj:`list` of :obj:`float`): discovered ln(L(s)) values

    """
    M, N = A.shape
    s_start = np.array(s_start, dtype=np.uint8)
    if algorithm not in {'M-space', 'N1-space'}:
        raise ValueError('argument `algorithm` must be either '
                         '"M-space" or "N1-space"')

    if adaptive:
        # In the adaptive mode, we do not know how many of the possible forward
        # extensions we will need from each discovery, so we keep all of them.
        discovery_bandwidth = N
    else:
        # When the bandwidth is truly fixed, it's enough to keep the best
        # `bandwidth` extensions from each discovery
        discovery_bandwidth = bandwidth

    # Initialize starting node
    if algorithm == 'M-space':
        starting_node = initialize_nodeM(alpha, A, y, A_fix, s_start)
    elif algorithm == 'N1-space':
        starting_node = initialize_nodeN1(alpha, A, y, A_fix, s_start)
        AAdiag = np.einsum('mn,mn->n', A, A)
        Ay = A.T.dot(y)
        alphasq = alpha**2
    else:
        raise ValueError('argument `algorithm` must be either '
                         '"M-space" or "N1-space"')

    # initialize result containers
    nodes_to_extend = [starting_node]
    results = {
        bytes(s_start): compute_logL(starting_node.G, starting_node.H, M, a, b)
    }

    # perform layer-by-layer discovery + extension iterations
    for N1 in range(sum(s_start), min(N + 1, N1_max), 1):
        best_forward_extensions = SortedList()  # always sorted

        # We discover the neighbors
        # and keep track of all discovered s, ln(L(s)),
        # and all necessary forward extensions.
        #
        # Crucially, `s_record_set` is also updated in each iteration.
        # This enables avoiding re-computing already discovered nodes
        for node in nodes_to_extend:
            if algorithm == 'M-space':
                new_logLs, new_forward_extensions = \
                    node.discover_neighbors(A, y,
                                            discovery_bandwidth,
                                            a, b)
            elif algorithm == 'N1-space':
                new_logLs, new_forward_extensions = \
                    node.discover_neighbors(A, AAdiag, Ay, alphasq,
                                            bandwidth, a, b)
            else:
                raise ValueError('argument `algorithm` must be either '
                                 '"M-space" or "N1-space"')

            s = node.s
            s_neighbors = list(
                map(bytes, (np.repeat(s[None, :], len(s), axis=0) != np.eye(
                    len(s), dtype=np.uint8)).astype(np.uint8)))
            results.update(dict(zip(s_neighbors, new_logLs)))
            best_forward_extensions.update(new_forward_extensions)
            if not adaptive:
                # trim list to contain only the best `bandwidth` extensions
                best_forward_extensions = SortedList(
                    best_forward_extensions.islice(-bandwidth, None))

        if adaptive:
            # The adaptive logic works the following way:
            #   - For each feature n, we keep two counters, counting
            #     how many s-vectors we've included in the set of nodes to
            #     extend that have s_n = 0 (and 1, respectively)
            #   - We iterate through the possible extensions in descending
            #     order with respect to their log_L values
            #   - We select an extension only if it moves at least one of the
            #     counters forward that is below the goal count,
            #     i.e. the `bandwidth`

            c0 = np.zeros(N, dtype=int)  # counter for (s_n == 0)
            c1 = np.zeros(N, dtype=int)  # counter for (s_n == 1)
            selected_forward_extensions = []
            for ext in best_forward_extensions[::-1]:

                # check if this extension moves any of the low counter up
                s_new = ext.s.copy()
                s_new[ext.n] = 1 - s_new[ext.n]
                if (1 - s_new)[c0 < bandwidth].any() or \
                        s_new[c1 < bandwidth].any():
                    c0 = c0 + 1 - s_new  # record the inactive features
                    c1 = c1 + s_new  # record the active features
                    selected_forward_extensions.append(ext)

            best_forward_extensions = selected_forward_extensions

        # Compute the target nodes of each selected extension.
        # This enables the next iteration to start.
        if algorithm == 'M-space':
            nodes_to_extend = [
                ext.get_target_node(A) for ext in best_forward_extensions
            ]
        elif algorithm == 'N1-space':
            nodes_to_extend = [
                ext.get_target_node(A, y) for ext in best_forward_extensions
            ]
        else:
            raise ValueError('argument `algorithm` must be either '
                             '"M-space" or "N1-space"')

    s_list, logL_list = zip(*results.items())
    return {'alpha': alpha, 's': s_list, 'logL': logL_list}
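For reference, a minimal stand-alone sketch of the trimming step above: keeping only the `bandwidth` largest items of a SortedList via islice with a negative start (sample values are made up).

from sortedcontainers import SortedList

bandwidth = 3
extensions = SortedList([5, 1, 9, 7, 3, 8])                   # stored sorted: [1, 3, 5, 7, 8, 9]
extensions = SortedList(extensions.islice(-bandwidth, None))  # keep the 3 largest
print(list(extensions))                                       # [7, 8, 9]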