def murty(C):
    """Algorithm due to Murty: enumerate assignments of ``C`` by cost.

    Solutions are kept in a priority queue ordered by cost.  Each popped
    solution is yielded, then its solution space is partitioned into child
    subproblems: child ``t`` keeps the first ``t`` (row, column) pairs of the
    parent's free part fixed and forbids pair ``t``.  Each child is solved
    with ``lap`` and pushed back onto the queue.

    Parameters
    ----------
    C : 2-D array
        Cost matrix with ``M = C.shape[0]`` rows and ``N = C.shape[1]``
        columns.  While a solution is being expanded, its forbidden entries
        are temporarily overwritten with ``LARGE``.  They are always restored
        afterwards, including when ``lap`` raises, so the caller's matrix is
        left unchanged.

    Yields
    ------
    (cost, assignment)
        ``assignment`` is a list of length ``M`` whose entry ``i`` is the
        column assigned to row ``i``.
    """
    try:
        Q = queue.PriorityQueue()
        M = C.shape[0]
        N = C.shape[1]
        cost, assign = lap(C)[0:2]
        # Queue entries: (cost, assignment, fixed_rows, fixed_cols,
        #                 banned_rows, banned_cols).
        # fixed_rows[i] is forced onto fixed_cols[i];
        # (banned_rows[i], banned_cols[i]) may not be used.
        Q.put((cost, list(assign), (), (), (), ()))
        while not Q.empty():
            S = Q.get_nowait()
            yield (S[0], S[1][:M])
            _, sol, fixed_rows, fixed_cols, banned_rows, banned_cols = S
            ni = len(fixed_rows)

            # Build the membership sets once instead of scanning tuples/lists
            # for every candidate index.
            fixed_row_set = set(fixed_rows)
            fixed_col_set = set(fixed_cols)
            used_col_set = set(sol)

            # Free rows in order.  Free columns are ordered as the current
            # solution uses them, followed by columns it leaves unused, so the
            # diagonal of the reduced matrix holds the current solution's
            # entries (assuming it assigns distinct columns).
            rmap = tuple(x for x in range(M) if x not in fixed_row_set)
            cmap = tuple(x for x in sol if x not in fixed_col_set)
            cmap += tuple(x for x in range(N)
                          if x not in fixed_col_set and x not in used_col_set)

            # Temporarily forbid the banned pairs inherited by this subproblem.
            removed_values = C[banned_rows, banned_cols]
            C[banned_rows, banned_cols] = LARGE
            try:
                C_ = C[rmap, :][:, cmap]
                for t in range(M - ni):
                    # Child t: keep diagonal pairs 0..t-1, forbid pair t.
                    removed_value = C_[t, t]
                    C_[t, t] = LARGE
                    cost, lassign = lap(C_[t:, t:])[0:2]
                    # Drop children whose optimum must use a forbidden entry.
                    if LARGE not in C_[range(t, t + len(lassign)), lassign + t]:
                        cost += C[fixed_rows, fixed_cols].sum()
                        cost += C_[range(t), range(t)].sum()
                        assign = [None] * M
                        for r in range(ni):
                            assign[fixed_rows[r]] = fixed_cols[r]
                        for r in range(t):
                            assign[rmap[r]] = cmap[r]
                        for r in range(len(lassign)):
                            assign[rmap[r + t]] = cmap[lassign[r] + t]
                        Q.put((cost, assign,
                               fixed_rows + tuple(rmap[:t]),
                               fixed_cols + tuple(cmap[:t]),
                               banned_rows + (rmap[t], ),
                               banned_cols + (cmap[t], )))
                    # Restore: later children include this pair in their
                    # fixed prefix cost.
                    C_[t, t] = removed_value
            finally:
                # Undo the temporary bans even if lap() raised, so the
                # caller's matrix is never left corrupted.
                C[banned_rows, banned_cols] = removed_values
    except GeneratorExit:
        # Closing the generator early simply ends the enumeration.
        pass
def test_lap_square():
    """Solve the module-level ``cost`` matrix and check both assignment views.

    ``ret[1]`` maps each row to its column and ``ret[2]`` maps each column to
    its row.  Summing the chosen entries through either view must reproduce
    the reported optimal cost.
    """
    result = lapjv.lap(cost)
    total = result[0]
    row_to_col = result[1]
    col_to_row = result[2]

    assert total == 17.0
    assert np.all(row_to_col == [1, 2, 0, 4, 5, 3, 7, 6])
    assert np.all(col_to_row == [2, 0, 1, 5, 3, 4, 7, 6])

    rows = range(cost.shape[0])
    cols = range(cost.shape[1])
    assert cost[rows, row_to_col].sum() == total
    assert cost[col_to_row, cols].sum() == total
def test_lap_duals():
    """Same checks as the plain solve, but with an explicit zero ``u`` passed.

    NOTE(review): ``u`` is presumably the initial row-dual vector; the
    expected result is identical to the solve without it.
    """
    zero_u = np.zeros((cost.shape[0],), dtype=np.double)
    result = lapjv.lap(cost, u=zero_u)
    total = result[0]
    row_to_col = result[1]
    col_to_row = result[2]

    assert total == 17.0
    assert np.all(row_to_col == [1, 2, 0, 4, 5, 3, 7, 6])
    assert np.all(col_to_row == [2, 0, 1, 5, 3, 4, 7, 6])

    rows = range(cost.shape[0])
    cols = range(cost.shape[1])
    assert cost[rows, row_to_col].sum() == total
    assert cost[col_to_row, cols].sum() == total
def test_lap_cost_eps():
    """Solve the matrix stored in ``cost_eps.csv.gz``; it only has to return.

    The path is relative, so the file is looked up in the current working
    directory.  The gzip handle is closed as soon as the CSV has been parsed
    (the original left it open).
    """
    # This test should just return.
    with GzipFile('cost_eps.csv.gz') as fh:
        cost = np.genfromtxt(fh, delimiter=",")
    lapjv.lap(cost)
def test_lap_cost_limit():
    """With ``cost_limit`` set, some rows/columns may stay unmatched (-1).

    On the top-left 3x3 corner of ``cost`` with a limit of 4.99, only two
    pairs are matched, for a total cost of 3.
    """
    corner = cost[:3, :3]
    result = lapjv.lap(corner, cost_limit=4.99)
    total, row_to_col, col_to_row = result[0], result[1], result[2]

    assert total == 3.0
    assert np.all(row_to_col == [1, 2, -1])
    assert np.all(col_to_row == [-1, 0, 1])
def test_lap_extension():
    """A non-square 2x4 slice is solvable when ``extend_cost=True``.

    Both rows get a column; the two columns left over are reported as -1.
    """
    wide = cost[:2, :4]
    result = lapjv.lap(wide, extend_cost=True)
    total, row_to_col, col_to_row = result[0], result[1], result[2]

    assert total == 3.0
    assert np.all(row_to_col == [1, 2])
    assert np.all(col_to_row == [-1, 0, 1, -1])
def test_lap_non_contigous():
    """Solve a 3x3 slice of ``cost``, which is a non-contiguous view."""
    view = cost[:3, :3]
    result = lapjv.lap(view)
    total, row_to_col, col_to_row = result[0], result[1], result[2]

    assert total == 8.0
    assert np.all(row_to_col == [1, 2, 0])
    assert np.all(col_to_row == [2, 0, 1])
def test_lap_non_square_fail():
    """Call ``lap`` on a 3x2 (non-square) matrix without ``extend_cost``.

    NOTE(review): despite its name, this body does not itself assert that the
    call fails.  Presumably an expected-exception decorator is applied outside
    this view; confirm, otherwise the test passes only when ``lap`` does NOT
    raise.
    """
    non_square = np.zeros((3, 2))
    lapjv.lap(non_square)
def test_lap_empty():
    """Call ``lap`` on a degenerate, element-less input.

    NOTE(review): ``np.ndarray([])`` builds a 0-d array (shape ``()``), not a
    0x0 matrix.  As with the non-square test, an expected-exception decorator
    may live outside this view; confirm the intended input and outcome.
    """
    degenerate = np.ndarray([])
    lapjv.lap(degenerate)