def test_get_selector(self):
    """Checks the boolean masks produced by `self.switch.get_selector`.

    Against the switch built in `setUp`, the expected mask for selector
    index `i` is True exactly where the corresponding switch entry is `i`.
    """
    # Expected level contents for get_selector(0) and get_selector(1), in order.
    expected_levels_by_index = (
        {
            'l1': [True, True, False],
            'l2': [False, True, True],
            'l3': [True, True, True]
        },
        {
            'l1': [False, False, True],
            'l2': [True, False, False],
            'l3': [False, False, False]
        },
    )
    for index, levels in enumerate(expected_levels_by_index):
        expected = multi_task.NamedLists(levels)
        self.assertEqual(expected, self.switch.get_selector(index))
def setUp(self):
    """Builds the shared fixtures: two inputs, a switch, and the merged result."""
    super().setUp()
    first_input = multi_task.NamedLists({
        'l1': [1.0, 2.0],
        'l2': [5.0, 6.0],
        'l3': [7.0, 8.0, 9.0]
    })
    # Note the empty 'l3' level: all of 'l3' comes from the first input.
    second_input = multi_task.NamedLists({
        'l1': [3.0],
        'l2': [4.0],
        'l3': []
    })
    self.inputs = [first_input, second_input]
    # The fixture values are consistent with each switch entry naming the
    # index in `self.inputs` that supplies that position of `self.merged`
    # (e.g. 'l2': [1, 0, 0] -> [4.0, 5.0, 6.0]).
    self.switch = multi_task.SwitchNamedLists({
        'l1': [0, 0, 1],
        'l2': [1, 0, 0],
        'l3': [0, 0, 0]
    })
    self.merged = multi_task.NamedLists({
        'l1': [1.0, 2.0, 3.0],
        'l2': [4.0, 5.0, 6.0],
        'l3': [7.0, 8.0, 9.0]
    })
def test_namedlists(self):
    """Exercises the NamedLists container API.

    Covers shape, flat iteration and attribute access; `constant_copy`;
    `copy` (the copy must not alias the original's levels); `pack`; and a
    `flatten` / `unflatten` round trip that is checked value by value, not
    only by shape.
    """
    container = multi_task.NamedLists({
        'first': (1, 2, 3),
        'second': [],
        'third': [4, 5]
    })
    self.assertEqual(container.shape, (3, 0, 2))
    self.assertSequenceEqual(list(container), [1, 2, 3, 4, 5])
    self.assertSequenceEqual(container.first, [1, 2, 3])
    # The tuple given for 'first' is expected to be stored as a list.
    self.assertIsInstance(container.first, list)
    # Accessing an unknown level name is expected to raise KeyError.
    with self.assertRaises(KeyError):
        _ = container.something

    # constant_copy: same shape, every entry replaced by the constant.
    constant = False
    const_copy = container.constant_copy(constant)
    self.assertEqual(const_copy.shape, container.shape)
    for level, num in zip(const_copy.levels, const_copy.shape):
        self.assertSequenceEqual(level, [constant] * num)

    # copy: equal contents, and mutating the copy must leave the original
    # untouched (i.e. levels are not shared).
    copied = container.copy()
    self.assertSequenceEqual(copied.first, container.first)
    copied.first[0] = 8
    self.assertNotEqual(copied.first[0], container.first[0])
    self.assertSequenceEqual(copied.first[1:], container.first[1:])

    # pack: redistributes a flat list of values into the container's shape.
    values = [0, 0, 0, 2, 2]
    packed = container.pack(values)
    self.assertEqual(packed.shape, container.shape)
    self.assertSequenceEqual(list(packed), values)
    self.assertSequenceEqual(packed.first, values[:3])
    self.assertSequenceEqual(packed.second, [])
    self.assertSequenceEqual(packed.third, values[3:])

    # flatten: one 'level/index' key per entry, plus a 'level/' key mapped to
    # None for each empty level (3 + 1 + 2 = 6 keys here).
    flattened = container.flatten()
    self.assertIsInstance(flattened, dict)
    self.assertLen(flattened, 6)
    self.assertIn('first/0', flattened)
    self.assertEqual(flattened['first/0'], 1)
    self.assertIn('second/', flattened)
    self.assertIsNone(flattened['second/'])

    # unflatten must invert flatten: same shape AND same values, so a
    # round trip that drops or reorders entries is caught.
    unflattened = multi_task.NamedLists.unflatten(flattened)
    self.assertEqual(unflattened.shape, container.shape)
    self.assertSequenceEqual(list(unflattened), list(container))
    self.assertSequenceEqual(unflattened.first, container.first)
    self.assertSequenceEqual(unflattened.third, container.third)