def test_constructor_1D(self):
    """Test constructor 1D

    Every constructor variant below wraps the same coordinates and
    payloads (self.input c1/p1), so each must iterate in input order and
    compare equal to one reference fiber; only the recorded defaults and
    shape differ between variants.
    """
    c1 = self.input["c1"]
    p1 = self.input["p1"]

    reference = Fiber([0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
                      [7, 1, 8, 3, 8, 4, 6, 3, 7, 5])

    # Expected [defaults, depth, owner, shape] for each variant, in order
    expected_attrs = [
        [[0], 1, None, [10]],
        [[-1], 1, None, [10]],
        [[0], 1, None, [20]],
        [[-2], 1, None, [20]],
    ]

    fibers = [
        Fiber(c1, p1),
        Fiber(c1, p1, default=-1),
        Fiber(c1, p1, shape=[20]),
        Fiber(c1, p1, shape=[20], default=-2),
    ]

    for idx, (fiber, want_attrs) in enumerate(zip(fibers, expected_attrs)):
        with self.subTest(test=idx):
            got_attrs = self.attributes(fiber)

            for pos, (coord, payload) in enumerate(fiber):
                self.assertEqual(coord, c1[pos])
                self.assertEqual(payload, p1[pos])

            self.assertEqual(fiber, reference)
            self.assertEqual(got_attrs, want_attrs)
def test_union_2x_1d(self):
    """Test union 2-way for 1d fibers

    The ``|`` operator and ``Fiber.union`` must produce the same result.
    Each element is a (mask, a_payload, b_payload) tuple, where the mask
    string records which operands ('A', 'B') held the coordinate.
    """
    expected = Fiber(
        [0, 2, 3, 4, 6],
        [('AB', Payload(1), Payload(2)),
         ('AB', Payload(3), Payload(4)),
         ('B', Payload(0), Payload(5)),
         ('A', Payload(5), Payload(0)),
         ('A', Payload(7), Payload(0))])

    lhs = self.input["a1_m"]
    rhs = self.input["b1_m"]

    results = [lhs | rhs, Fiber.union(lhs, rhs)]

    for variant, result in enumerate(results):
        with self.subTest(test=variant):
            # Right answer
            self.assertEqual(result, expected)

            # Tuple members must remain wrapped as Payload objects
            for position, slot in ((0, 1), (2, 1), (3, 2)):
                self.assertIsInstance(result[position].payload.value[slot],
                                      Payload)

            # The default is the "empty" union tuple, itself a Payload
            default = result.getDefault()
            self.assertEqual(default, Payload(('', 0, 0)))
            self.assertIsInstance(default, Payload)

            # Shape of the union
            self.assertEqual(result.getShape(), [7])
def test_split_unequal(self):
    """Test splitUnequal

    Split a 7-element fiber into partitions of 1, 2 and 4 elements and
    check every (upper coordinate, partition) pair, including that the
    split yields exactly as many partitions as the reference.
    """
    #
    # Create the fiber to be split
    #
    c = [0, 1, 9, 10, 12, 31, 41]
    p = [0, 10, 20, 100, 120, 310, 410]

    f = Fiber(c, p)

    #
    # Create list of reference fibers after the split
    #
    css = [[0], [1, 9], [10, 12, 31, 41]]
    pss = [[0], [10, 20], [100, 120, 310, 410]]

    split_ref = [Fiber(cs, ps) for cs, ps in zip(css, pss)]

    #
    # Do the split
    #
    sizes = [1, 2, 4]
    split = f.splitUnEqual(sizes)

    #
    # Check the split. Materialize it first so that a split producing too
    # few partitions fails instead of passing vacuously, and one producing
    # too many fails cleanly instead of raising IndexError below.
    #
    parts = list(split)
    self.assertEqual(len(parts), len(split_ref))

    for i, (sc, sp) in enumerate(parts):
        # Upper coordinate is the first coordinate of each partition
        self.assertEqual(sc, css[i][0])
        self.assertEqual(sp, split_ref[i])
def test_split_nonuniform2(self):
    """Test splitNonUniform - not starting at coordinate 0

    The first split point is 8, and the reference partitions contain no
    elements below it (coordinates 0 and 1 appear in no partition). Each
    partition's upper coordinate is its split point.
    """
    #
    # Create the fiber to be split
    #
    c = [0, 1, 9, 10, 12, 31, 41]
    p = [0, 10, 20, 100, 120, 310, 410]

    f = Fiber(c, p)

    #
    # Create list of reference fibers after the split
    #
    css = [[9, 10], [12], [31, 41]]
    pss = [[20, 100], [120], [310, 410]]

    split_ref = [Fiber(cs, ps) for cs, ps in zip(css, pss)]

    #
    # Do the split
    #
    splits = [8, 12, 31]
    split = f.splitNonUniform(splits)

    #
    # Check the split. Materialize it first so that a wrong number of
    # partitions fails instead of passing vacuously (too few) or raising
    # IndexError (too many).
    #
    parts = list(split)
    self.assertEqual(len(parts), len(split_ref))

    for i, (sc, sp) in enumerate(parts):
        self.assertEqual(sc, splits[i])
        self.assertEqual(sp, split_ref[i])
def test_print_3D_flattened(self):
    """Test str format 3D flattened

    A fiber with tuple (flattened) coordinates whose payloads are fibers
    must render nested in the "n*" format and round-trip through repr.
    """
    inner_a = Fiber([2, 4, 6, 8], [3, 5, 7, 9])
    inner_b = Fiber([3, 5, 7], [4, 6, 8])
    outer = Fiber([(0, 2), (1, 5)], [inner_a, inner_b])

    expected_str = (
        "F/[( (0, 2) -> F/[(2 -> <3>) \n"
        " (4 -> <5>) \n"
        " (6 -> <7>) \n"
        " (8 -> <9>) ])\n"
        " ( (1, 5) -> F/[(3 -> <4>) \n"
        " (5 -> <6>) \n"
        " (7 -> <8>) ])"
    )
    self.assertEqual(format(outer, "n*"), expected_str)

    expected_repr = "Fiber([(0, 2), (1, 5)], [Fiber([2, 4, 6, 8], [3, 5, 7, 9]), Fiber([3, 5, 7], [4, 6, 8])])"
    self.assertEqual(repr(outer), expected_repr)
def test_getitem_2D(self):
    """Test chained [] indexing into a 2-D fiber"""
    lower0 = Fiber([1, 4, 7], [2, 5, 8])
    lower1 = Fiber([2, 4, 6], [3, 5, 7])
    upper = Fiber([2, 4], [lower0, lower1])

    self.assertEqual(upper[1][1], 5)
def test_split_uniform_empty(self):
    """Test splitUniform on empty fiber"""
    split = Fiber().splitUniform(5)

    # Even with no elements, splitting must add a rank: the new top
    # level's default therefore has to be a Fiber rather than a scalar.
    self.assertIsInstance(split.getDefault(), Fiber)
def setUp(self):
    """Build the shared coordinate/payload inputs used by the tests."""
    # c1/p1: coordinates 0..9 with scalar payloads (1-D)
    # c2/p2: coordinates 0..2 whose payloads are one-element fibers (2-D)
    self.input = {
        "c1": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
        "p1": [7, 1, 8, 3, 8, 4, 6, 3, 7, 5],
        "c2": [0, 1, 2],
        "p2": [Fiber([2], [4]), Fiber([1], [4]), Fiber([2], [2])],
    }
def test_fromRandom_1D_dense(self):
    """Test random 1d dense"""
    # The expected fiber has every coordinate in [0, 10) populated.
    ans = Fiber([0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
                [7, 1, 8, 3, 8, 4, 6, 3, 7, 5])

    # [defaults, depth, owner, shape]
    attr = [[0], 1, None, [10]]

    # NOTE(review): the arguments appear to be (shape, per-rank density,
    # max payload value, seed) - confirm against Fiber.fromRandom.
    f = Fiber.fromRandom([10], [1.0], 9, 10)
    f_attr = self.attributes(f)

    self.assertEqual(f, ans)
    self.assertEqual(f_attr, attr)
def test_fromRandom_1D_sparse(self):
    """Test random 1d sparse

    With fixed arguments the generated fiber is reproducible; over a
    shape of 50 it is expected to hold exactly the 18 elements below.
    """
    expected = Fiber(
        [3, 6, 8, 9, 12, 16, 19, 20, 28, 30, 32, 38, 40, 43, 46, 47, 48, 49],
        [8, 9, 6, 3, 5, 4, 1, 4, 6, 4, 1, 6, 2, 6, 5, 9, 2, 5])
    expected_attrs = [[0], 1, None, [50]]

    actual = Fiber.fromRandom([50], [0.3], 9, 10)
    actual_attrs = self.attributes(actual)

    self.assertEqual(actual, expected)
    self.assertEqual(actual_attrs, expected_attrs)
def test_fromRandom_2D_sparse(self):
    """Test random 2d sparse"""
    # Every top-level coordinate is present; each lower fiber holds a
    # single element.
    ans = Fiber([0, 1, 2],
                [Fiber([2], [4]), Fiber([1], [4]), Fiber([2], [2])])

    # [defaults per level, depth, owner, shape]; the top level's default
    # is the Fiber class because its payloads are fibers.
    attr = [[Fiber, 0], 2, None, [3, 3]]

    # NOTE(review): the arguments appear to be (shape, per-rank density,
    # max payload value, seed) - confirm against Fiber.fromRandom.
    f = Fiber.fromRandom([3, 3], [1.0, 0.3], 4, 10)
    f_attr = self.attributes(f)

    self.assertEqual(f, ans)
    self.assertEqual(f_attr, attr)
def test_add_fiber(self):
    """Test __add__ fiber

    The operands' coordinates are disjoint here, so the sum simply
    contains every element of both fibers.
    """
    lhs = Fiber.fromUncompressed([1, 2, 3, 0, 0, 6])
    rhs = Fiber([6, 8], [20, 22])
    expected = Fiber([0, 1, 2, 5, 6, 8], [1, 2, 3, 6, 20, 22])

    with self.subTest("f+g"):
        self.assertEqual(expected, lhs + rhs)

    with self.subTest("f+=g"):
        # The in-place form clobbers lhs, so it has to run last
        lhs += rhs
        self.assertEqual(expected, lhs)
def test_mul_int(self):
    """Test __mul__ integers

    Scaling by an int must commute and leave empty coordinates empty.
    """
    f_in = Fiber.fromUncompressed([1, 2, 3, 0, 0, 6])
    f_ref = Fiber([0, 1, 2, 5], [2, 4, 6, 12])

    # Out-of-place forms, evaluated lazily inside each subtest
    for label, compute in (("f_in * 2", lambda: f_in * 2),
                           ("2*f_in", lambda: 2 * f_in)):
        with self.subTest(label):
            self.assertEqual(f_ref, compute())

    with self.subTest("f_in *=2"):
        # Done last: the in-place form clobbers f_in
        f_in *= 2
        self.assertEqual(f_ref, f_in)
def test_add_int(self):
    """Test __add__ integers

    Adding a scalar affects every position of the uncompressed input, so
    the empty positions also end up holding the scalar (2).
    """
    base = Fiber.fromUncompressed([1, 2, 3, 0, 0, 6])
    expected = Fiber.fromUncompressed([3, 4, 5, 2, 2, 8])

    with self.subTest("f_in + 2"):
        self.assertEqual(expected, base + 2)

    with self.subTest("2 + f_in"):
        self.assertEqual(expected, 2 + base)

    with self.subTest("f_in += 2"):
        # The in-place form clobbers the input, so it must run last
        base += 2
        self.assertEqual(expected, base)
def test_add_payload(self):
    """Test __add__ payload

    Same expectations as adding the int 2, but with the scalar wrapped in
    a Payload on either side of the operator.
    """
    base = Fiber.fromUncompressed([1, 2, 3, 0, 0, 6])
    expected = Fiber.fromUncompressed([3, 4, 5, 2, 2, 8])
    two = Payload(2)

    with self.subTest("f_in + 2"):
        self.assertEqual(expected, base + two)

    with self.subTest("2 + f_in"):
        self.assertEqual(expected, two + base)

    with self.subTest("f_in += 2"):
        # The in-place form clobbers the input, so it must run last
        base += two
        self.assertEqual(expected, base)
def test_mul_payload(self):
    """Test __mul__ payload

    Scaling by a Payload-wrapped scalar must commute and leave empty
    coordinates empty.
    """
    f_in = Fiber.fromUncompressed([1, 2, 3, 0, 0, 6])
    f_ref = Fiber([0, 1, 2, 5], [2, 4, 6, 12])
    two = Payload(2)

    # Out-of-place forms, evaluated lazily inside each subtest
    for label, compute in (("f_in * 2", lambda: f_in * two),
                           ("2*f_in", lambda: two * f_in)):
        with self.subTest(label):
            self.assertEqual(f_ref, compute())

    with self.subTest("f_in *=2"):
        # Done last: the in-place form clobbers f_in
        f_in *= two
        self.assertEqual(f_ref, f_in)
def test_split_nonuniform_then_flatten(self):
    """Test that flattenRanks can undo splitNonUniform"""
    original = Fiber([0, 1, 9, 10, 12, 31, 41],
                     [0, 10, 20, 100, 120, 310, 410])

    # Split at explicit boundaries, then flatten back using absolute
    # coordinates: the round trip must reproduce the original fiber.
    split = original.splitNonUniform([0, 12, 31])
    self.assertEqual(split.flattenRanks(style="absolute"), original)
def test_split_unequal_then_flatten(self):
    """Test that flattenRanks can undo splitUnequal"""
    original = Fiber([0, 1, 9, 10, 12, 31, 41],
                     [0, 10, 20, 100, 120, 310, 410])

    # Partitions of 1, 2 and 4 elements, flattened back with absolute
    # coordinates, must give the original fiber again.
    split = original.splitUnEqual([1, 2, 4])
    self.assertEqual(split.flattenRanks(style="absolute"), original)
def test_split_uniform_relative_then_flatten(self):
    """Test that flattenRanks can undo splitUniform (relative)"""
    original = Fiber([0, 1, 9, 10, 12, 31, 41],
                     [0, 10, 20, 100, 120, 310, 410])

    # A split using relative lower coordinates has to be flattened with
    # the matching "relative" style to recover the original coordinates.
    split = original.splitUniform(10, relativeCoords=True)
    self.assertEqual(split.flattenRanks(style="relative"), original)
def test_split_uniform_then_flatten(self):
    """Test that flattenRanks() can undo splitUniform"""
    original = Fiber([0, 1, 9, 10, 12, 31, 41],
                     [0, 10, 20, 100, 120, 310, 410])

    # Uniform split with step 10 (absolute lower coordinates), then
    # flatten: the round trip must be the identity.
    split = original.splitUniform(10)
    self.assertEqual(split.flattenRanks(style="absolute"), original)
def test_union_2x_1d2d(self):
    """Test union 2-way for 1d/2d fibers

    Union of a 1-D fiber with a 2-D fiber: the second tuple member holds
    scalar payloads and the third holds fibers. Where only 'A' has a
    coordinate, the 'B' slot is an empty fiber.
    """
    expected = Fiber(
        [0, 2, 4, 6],
        [('AB', 1, Fiber([0, 2, 3], [2, 4, 5])),
         ('AB', 3, Fiber([0, 1, 2], [3, 4, 6])),
         ('AB', 5, Fiber([0, 1, 2, 3], [1, 2, 3, 4])),
         ('A', 7, Fiber([], []))])

    lhs = self.input["a1_m"]
    rhs = self.input["b2_m"]

    results = [lhs | rhs, Fiber.union(lhs, rhs)]

    # (element position, tuple slot, required type)
    type_checks = ((0, 1, Payload),
                   (0, 2, Fiber),
                   (2, 1, Payload),
                   (3, 2, Fiber))

    for variant, result in enumerate(results):
        with self.subTest(test=variant):
            # Right answer
            self.assertEqual(result, expected)

            # Tuple members keep their expected wrapper types
            for position, slot, kind in type_checks:
                self.assertIsInstance(result[position].payload.value[slot],
                                      kind)

            # The default tuple uses the Fiber class for the 2-D operand
            default = result.getDefault()
            self.assertEqual(default, Payload(('', 0, Fiber)))
            self.assertIsInstance(default, Payload)

            # The union itself is 1-D
            self.assertEqual(result.getShape(), [7])
def test_getPayloadRef_2d(self):
    """Test getPayloadRef of a 2-D tensor

    A reference obtained with getPayloadRef must be live: updating it
    with ``<<=`` is visible through a later getPayload, whether or not
    the element (or the fiber holding it) existed beforehand.
    """
    t = Tensor.fromYAMLfile("./data/test_tensor-1.yaml")

    with self.subTest(test="Existing element"):
        p23_ref = 203
        p23 = t.getPayloadRef(2, 3)
        self.assertEqual(p23_ref, p23)

        # Make sure change is seen
        p23_new_ref = 310
        p23 <<= p23_new_ref
        p23_new = t.getPayload(2, 3)
        self.assertEqual(p23_new_ref, p23_new)

    with self.subTest(test="Non-existing element"):
        p31_ref = 0
        p31 = t.getPayloadRef(3, 1)
        self.assertEqual(p31_ref, p31)

        # Make sure change is seen
        p31_new_ref = 100
        p31 <<= p31_new_ref
        p31_new = t.getPayload(3, 1)
        self.assertEqual(p31_new_ref, p31_new)

    with self.subTest(test="Element of non-existing fiber"):
        p51_ref = 0
        p51 = t.getPayloadRef(5, 1)
        self.assertEqual(p51_ref, p51)

        # Make sure change IS seen: even though fiber 5 did not exist
        # beforehand, the returned reference is live in the tensor.
        p51_new_ref = 100
        p51 <<= p51_new_ref
        p51_new = t.getPayload(5, 1)
        self.assertEqual(p51_new_ref, p51_new)

    with self.subTest(test="Existing fiber"):
        p4_ref = Fiber([0, 2], [400, 402])
        p4 = t.getPayloadRef(4)
        self.assertEqual(p4_ref, p4)
def test_flatten_below(self):
    """Test {,un}flattenRanksBelow

    After an extra level is added with splitUnEqualBelow, flattening and
    then unflattening the ranks below the top must be a no-op.
    """
    lower_a = Fiber([0, 1, 9, 10, 12, 31, 41],
                    [0, 10, 20, 100, 120, 310, 410])
    lower_b = Fiber([1, 2, 10, 11, 13, 32, 42],
                    [1, 11, 21, 101, 121, 311, 411])
    f = Fiber([2, 4], [lower_a, lower_b])

    # Only used to create another level to flatten
    f.splitUnEqualBelow([3, 3, 1], depth=0)
    snapshot = deepcopy(f)

    # Round trip must leave f equal to its snapshot
    f.flattenRanksBelow()
    f.unflattenRanksBelow()
    self.assertEqual(f, snapshot)
def test_split_uniform_relative(self):
    """Test splitUniform with relativeCoords=True

    Upper coordinates are multiples of the step (10); lower coordinates
    are offsets from each partition's upper coordinate (e.g. 10, 12 ->
    0, 2).
    """
    #
    # Create the fiber to be split
    #
    c = [0, 1, 9, 10, 12, 31, 41]
    p = [0, 10, 20, 100, 120, 310, 410]

    f = Fiber(c, p)

    #
    # Create list of reference fibers after the split
    #
    split_ref_coords = [0, 10, 30, 40]

    css = [[0, 1, 9], [0, 2], [1], [1]]
    pss = [[0, 10, 20], [100, 120], [310], [410]]

    split_ref_payloads = [Fiber(cs, ps) for cs, ps in zip(css, pss)]

    #
    # Do the split
    #
    coords = 10
    split = f.splitUniform(coords, relativeCoords=True)

    #
    # Check the split. Materialize it first so that a wrong number of
    # partitions fails instead of passing vacuously (too few) or raising
    # IndexError (too many).
    #
    parts = list(split)
    self.assertEqual(len(parts), len(split_ref_coords))

    for i, (sc, sp) in enumerate(parts):
        self.assertEqual(sc, split_ref_coords[i])
        self.assertEqual(sp, split_ref_payloads[i])
def test_split_equal_partioned(self):
    """Test splitEqual(2, partitions=2)

    Per the reference below, chunks of two elements are dealt
    alternately to the two partitions, and each partition is keyed by the
    first coordinate of each of its chunks.
    """
    f = Fiber([0, 1, 9, 10, 12, 31, 41],
              [0, 10, 20, 100, 120, 310, 410])

    # Partition 0 receives the 1st and 3rd chunks
    part0 = Fiber(coords=[0, 12],
                  payloads=[Fiber([0, 1], [0, 10]),
                            Fiber([12, 31], [120, 310])])

    # Partition 1 receives the 2nd and 4th chunks
    part1 = Fiber(coords=[9, 41],
                  payloads=[Fiber([9, 10], [20, 100]),
                            Fiber([41], [410])])

    expected = Fiber(payloads=[part0, part1])

    self.assertEqual(f.splitEqual(2, partitions=2), expected)
def test_fromUncompressed_1D(self):
    """Test construction of a 1-D tensor from an uncompressed list"""
    # Uncompressed payloads, indexed by coordinate:
    #
    #      0    1    2    3
    t = [100, 101, 0, 102]

    # The zero at coordinate 2 is expected to be absent from the fiber
    fiber = Fiber([0, 1, 3], [100, 101, 102])
    tensor = Tensor.fromUncompressed(["M"], t)

    self.assertEqual(tensor.getRoot(), fiber)
def test_print_2D_flattened(self):
    """Test str format 2D flattened

    Tuple (flattened) coordinates must appear verbatim in the "n*"
    format and in repr.
    """
    fiber = Fiber([(2, 3), (2, 4), (3, 1), (8, 2)], [3, 5, 7, 9])

    expected_str = "F/[((2, 3) -> <3>) \n ((2, 4) -> <5>) \n ((3, 1) -> <7>) \n ((8, 2) -> <9>) ]"
    self.assertEqual(format(fiber, "n*"), expected_str)

    expected_repr = "Fiber([(2, 3), (2, 4), (3, 1), (8, 2)], [3, 5, 7, 9])"
    self.assertEqual(repr(fiber), expected_repr)
def test_print_1D(self):
    """Test str format 1D

    A flat fiber renders one "(coord -> <payload>)" entry per element
    in the "n*" format, and its repr is constructor syntax.
    """
    fiber = Fiber([2, 4, 6, 8], [3, 5, 7, 9])

    expected_str = "F/[(2 -> <3>) \n (4 -> <5>) \n (6 -> <7>) \n (8 -> <9>) ]"
    self.assertEqual(format(fiber, "n*"), expected_str)

    expected_repr = "Fiber([2, 4, 6, 8], [3, 5, 7, 9])"
    self.assertEqual(repr(fiber), expected_repr)
def attributes(f):
    """Get all attributes of a fiber

    Returns [defaults, depth, owner, shape], where ``defaults`` lists the
    default of ``f`` followed by the default of each fiber reached by
    repeatedly descending into the first payload.
    """
    defaults = []

    ff = f
    while isinstance(ff, Fiber):
        defaults.append(ff.getDefault())
        if not ff.payloads:
            # An empty fiber has no first payload to descend into. Stop
            # here: substituting another empty Fiber would never terminate.
            break
        ff = ff.payloads[0]

    return [defaults, f.getDepth(), f.getOwner(), f.getShape()]
def test_flattenRanks_f02(self):
    """ Test flattenRanks - f02 """
    t0 = Tensor.fromYAMLfile("./data/tensor_3d-0.yaml")
    t1 = Tensor.fromYAMLfile("./data/tensor_3d-1.yaml")

    # Stack the two 3-D tensors under a new top rank to get a 4-D tensor
    root = Fiber([1, 4], [t0.getRoot(), t1.getRoot()])
    t2 = Tensor.fromFiber(["A", "B", "C", "D"], root, name="t2")

    # Each (depth, levels) flatten must be exactly undone by the matching
    # unflatten
    for depth, levels in ((1, 2), (0, 3)):
        flattened = t2.flattenRanks(depth=depth, levels=levels)
        restored = flattened.unflattenRanks(depth=depth, levels=levels)
        self.assertEqual(restored, t2)