def nn_chain_core(distances):
    """Build a merge hierarchy by following nearest-neighbour chains.

    Every node owns an ``IndexableSkipList`` that is filled with the indices
    of its neighbours and the matching distances.  Starting from an active
    node, the chain repeatedly hops to the nearest active neighbour of the
    current node until it comes back to the node visited two steps earlier;
    that pair is then merged into a new node.  For each neighbour, the new
    node keeps the larger of the distances held by its two children.

    Parameters
    ----------
    distances : 2D array or sparse matrix
        Pairwise distances; ``distances.shape[0]`` is the number of samples
        ``n``.  For a dense input the diagonal entry of each row is skipped.
        For an input accepted by ``sparse.issparse`` only the entries stored
        in each CSR row are used.

    Returns
    -------
    children : list of tuples
        One ``(a, b, distance)`` tuple per merge, in merge order (``n - 1``
        tuples).  ``a`` and ``b`` are node indices: the input samples are
        numbered ``0 .. n - 1`` and the node created by the k-th merge is
        numbered ``n + k``.

    Notes
    -----
    The function prints debugging traces (the chain and the activity of its
    members) to standard output on every iteration.
    """
    # Important hack: nodes are removed lazily, as the cost of removal
    # from the skip list is too important
    n = distances.shape[0]
    # XXX: Being smart with iterators is probably creating more lines than it
    # is saving
    if sparse.issparse(distances):
        distance_iter = ((d.indices, d.data)
                         for d in sparse.csr_matrix(distances))
        vmax = 1 + distances.data.max()
    else:
        # Loop-invariant: built once and masked for every row
        all_indices = np.arange(n)

        def distance_iter_():
            for i, row in enumerate(distances):
                # Every index but the diagonal one
                indices = all_indices[all_indices != i]
                yield indices, row[indices]
        distance_iter = distance_iter_()
        vmax = 1 + distances.max()
    # vmax is strictly larger than any input distance: it is used below as
    # the default when a node is absent from a distance list.

    # There will be at most 2*n nodes in the hierarchical clustering
    # algorithm. We can preallocate and use arrays
    # (builtin bool rather than the np.bool alias, which NumPy deprecated)
    active = np.zeros(2*n, dtype=bool)
    active[:n] = 1
    chain = []
    children = []
    # XXX: should this be a list, or a dict
    distance_dict = np.empty((2*n, ), dtype=object)
    for index, (indices, data) in enumerate(distance_iter):
        # Need to be able to have an actual connectivity at some point:
        # col being a sparse matrix
        # Should cater for empty cols?
        # Probably not: they should never occur
        this_distance = IndexableSkipList(expected_size=2*indices.size)
        this_distance.multiple_insert(indices, data)
        distance_dict[index] = this_distance
    print('Distance matrix ready')

    # One iteration per merge; this_n is the index of the node being created
    for this_n in range(n, 2*n - 1):
        # Debug traces
        print(chain)
        print(active[chain].astype(int))
        if len(chain) < 4:
            # Pick any 2 active elements to complete the chain
            # The last element is active: it just got added
            a = this_n - 1
            b = this_n - 2
            while not active[b]:
                b -= 1
            chain = [a, ]
        else:
            # Resume from the part of the chain that precedes the pair
            # merged at the previous iteration
            a = chain[-4]
            b = chain[-3]
            chain = chain[:-3]
        while True:
            distance_a = distance_dict[a]
            c, min_value = distance_a.argmin()
            while not active[c]:
                # Remove previously merged node lazily
                distance_a._get_node(c, remove=1, default=0)
                c, min_value = distance_a.argmin()
            if not active[b]:
                distance_a._get_node(b, remove=1, default=0)
            elif min_value == distance_a._get_node(b, default=vmax):
                # On a tie, go back to the previous node of the chain
                c = b
            a, b = c, a
            chain.append(a)
            # Stop once the chain bounces back to the node seen 2 steps ago
            if len(chain) > 2 and a == chain[-3]:
                break
        # distance_a is still the list of the node now called b
        children.append((a, b, distance_a[a]))
        # Remove the corresponding skip_lists from the distance dictionary
        new_distances = distance_dict[a]
        distance_a = distance_dict[b]
        distance_dict[a] = None
        active[a] = False
        distance_dict[b] = None
        active[b] = False
        # Augment the distance matrix: keep the larger of the two distances
        indices, values = distance_a.items()
        for other_index, other_value in zip(indices, values):
            new_distances[other_index] = max(
                        new_distances._get_node(other_index, default=0),
                        other_value)
        # Drop the entries pointing to nodes that are no longer active
        indices, values = new_distances.items()
        for index in indices:
            if not active[index]:
                new_distances._get_node(index, remove=1, default=0)
        distance_dict[this_n] = new_distances
        active[this_n] = True
        #for distance_list in distance_dict[active]:
        #    distance_list._get_node(a, remove=1, default=0)
        #    distance_list._get_node(b, remove=1, default=0)
        # Register the new node in the list of each of its neighbours
        indices, values = new_distances.items()
        for distance_list, value in zip(distance_dict[indices], values):
            distance_list[this_n] = value
    return children
from nose.tools import assert_equal

import numpy as np

from skip_list import IndexableSkipList

#def test_skip_list():
if 1:
    # Checks of IndexableSkipList that run when the module is executed
    N = 5
    indices = np.arange(N)**2
    # Builtin float rather than the np.float alias, which NumPy deprecated
    values = np.random.randint(1000, size=N).astype(float)
    # Test bulk insertion
    slist = IndexableSkipList()
    slist.multiple_insert(indices, values)
    # list(zip(...)) so that the comparison also holds when zip is lazy
    assert_equal(list(slist.iteritems()), list(zip(indices, values)))
    assert_equal(len(slist), N)
    # Test item assignment and lookup (key 1 is already present, the other
    # keys are new)
    for i, v in zip(indices, values):
        slist[i + 1] = v - 1
        assert_equal(slist[i + 1], v - 1)

    l = list(slist.iteritems())
    # Test removal of elements: pop the same key from the skip list and
    # from the reference list, then compare the two
    slist.pop(l[-2][0])
    l.remove(l[-2])
    assert_equal(list(slist.iteritems()), l)
    slist.pop(l[1][0])
    l.remove(l[1])
    assert_equal(list(slist.iteritems()), l)
    slist.pop(l[1][0])
    l.remove(l[1])