def test_split_equal(self):
    """Test splitEqual"""
    #
    # Create the fiber to be split
    #
    c = [0, 1, 9, 10, 12, 31, 41]
    p = [0, 10, 20, 100, 120, 310, 410]
    f = Fiber(c, p)
    #
    # Create list of reference fibers after the split
    #
    css = [[0, 1], [9, 10], [12, 31], [41]]
    pss = [[0, 10], [20, 100], [120, 310], [410]]
    split_ref = [Fiber(cs, ps) for cs, ps in zip(css, pss)]
    #
    # Do the split
    #
    size = 2
    split = f.splitEqual(size)
    #
    # Check the split
    #
    # Materialize the partitions so their number can be verified: looping
    # over `split` alone would pass vacuously if too few partitions came back.
    partitions = list(split)
    self.assertEqual(len(partitions), len(split_ref))
    for (sc, sp), cs, sp_ref in zip(partitions, css, split_ref):
        # The test expects each partition's coordinate to be the first
        # coordinate of the group it holds.
        self.assertEqual(sc, cs[0])
        self.assertEqual(sp, sp_ref)
def test_split_equal_empty(self):
    """Test splitEqual on empty fiber

    Splitting an empty fiber is still expected to introduce a new rank,
    which this test observes through the type of the default payload.
    """
    result = Fiber().splitEqual(3)
    # The added level should make the default payload a Fiber itself.
    self.assertIsInstance(result.getDefault(), Fiber)
def test_split_equal_below(self):
    """Test splitEqualBelow"""
    group_size = 4

    coords_a = [0, 1, 9, 10, 12, 31, 41]
    pays_a = [0, 10, 20, 100, 120, 310, 410]
    lower_a = Fiber(coords_a, pays_a)

    coords_b = [1, 2, 10, 11, 13, 32, 42]
    pays_b = [1, 11, 21, 101, 121, 311, 411]
    lower_b = Fiber(coords_b, pays_b)

    top_coords = [2, 4]
    f = Fiber(top_coords, [lower_a, lower_b])

    # The test expects this call to modify `f` itself, since `f` is what
    # gets compared against the reference below.
    f.splitEqualBelow(group_size, depth=0)

    # Reference: each lower fiber split on its own with the same group size
    expected = Fiber(
        top_coords,
        [lower.splitEqual(group_size) for lower in (lower_a, lower_b)],
    )
    self.assertEqual(f, expected)
def test_split_equal_then_flatten(self):
    """Test that flattenRanks can undo splitEqual"""
    original = Fiber([0, 1, 9, 10, 12, 31, 41],
                     [0, 10, 20, 100, 120, 310, 410])

    # Split into groups of two positions, then flatten back using absolute
    # coordinates; the round trip should reproduce the input fiber.
    round_trip = original.splitEqual(2).flattenRanks(style="absolute")

    self.assertEqual(round_trip, original)
def test_split_equal_partioned(self):
    """Test splitEqual(2, partitions=2)"""
    fiber = Fiber([0, 1, 9, 10, 12, 31, 41],
                  [0, 10, 20, 100, 120, 310, 410])

    # Expected layout: the size-2 groups [0,1], [9,10], [12,31], [41]
    # alternate between the two partitions.
    part0 = Fiber(coords=[0, 12],
                  payloads=[Fiber([0, 1], [0, 10]),
                            Fiber([12, 31], [120, 310])])
    part1 = Fiber(coords=[9, 41],
                  payloads=[Fiber([9, 10], [20, 100]),
                            Fiber([41], [410])])
    expected = Fiber(payloads=[part0, part1])

    self.assertEqual(fiber.splitEqual(2, partitions=2), expected)
# splits = [0, 12, 31] print(f"NonUniform coordinate split (splits at {splits})\n") fibers = f.splitNonUniform(splits) for c, s in fibers: s.print() # # Equal position-based split # size = 2 print(f"Equal position split (groups of {size})\n") fibers = f.splitEqual(size) for c, s in fibers: s.print() sizes = [1, 2, 4] print(f"NonEqual position split (splits of sizes {sizes})\n") fibers = f.splitUnEqual(sizes) for c, s in fibers: s.print() # # Create multiple partitions #