def test_equal_policy(self):
    testBench = commsTraceReplayBench()
    testBench.collectiveArgs.device = "cpu"
    testBench.collectiveArgs.world_size = 2
    testBench.rebalance_policy = "equal"

    testComm = commsArgs()
    testComm.comms = "all_to_allv"
    testComm.inMsgSize = 5
    testComm.outMsgSize = 3
    testComm.inSplit = [3, 2]
    testComm.outSplit = [1, 2]

    # Mock the all_reduce used by the rebalance to report a global size of 16,
    # i.e. the second rank contributes an inMsgSize of 11.
    ipTensor = torch.tensor([16], dtype=torch.int)
    testBench.backendFuncs = MockBackendFunction()
    testBench.backendFuncs.mock_collective = mock.MagicMock(
        side_effect=lambda collectiveArgs: setattr(collectiveArgs, "ipTensor", ipTensor)
    )

    testBench.rebalanceSplit(testComm)

    # The mocked all_reduce returns 16, so with world_size = 2 both inMsgSize and
    # outMsgSize should become 8, and inSplit/outSplit should both become [4, 4].
    print(f"ipTensor after: {testBench.collectiveArgs.ipTensor}")
    self.assertEqual(8, testComm.inMsgSize)
    self.assertEqual(8, testComm.outMsgSize)
    self.assertEqual([4, 4], testComm.inSplit)
    self.assertEqual([4, 4], testComm.outSplit)
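# A minimal sketch (not part of the original tests) of the arithmetic the "equal" policy
# is asserted to produce above: the mocked all_reduce reports a global input size of 16
# across world_size = 2, so each rank's message size becomes 16 // 2 = 8 and the per-peer
# splits become [4, 4]. The helper name below is hypothetical and purely illustrative.
def _expected_equal_rebalance(global_msg_size: int, world_size: int):
    per_rank_size = global_msg_size // world_size                  # 16 // 2 -> 8
    per_peer_split = [per_rank_size // world_size] * world_size    # -> [4, 4]
    return per_rank_size, per_peer_split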
def createCommsArgs(**kwargs) -> commsArgs:
    """
    Test utility to create comms args from a dict of values.
    """
    curComm = commsArgs()
    for key, value in kwargs.items():
        setattr(curComm, key, value)
    return curComm
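# Illustrative usage of createCommsArgs (values are hypothetical): each keyword is set as
# an attribute on the returned commsArgs, so the call below is equivalent to constructing
# commsArgs() and assigning comms, inMsgSize, and outMsgSize by hand.
#
#   curComm = createCommsArgs(comms="all_gather", inMsgSize=4, outMsgSize=4)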
def test_non_blocking_run(self):
    testBench = commsTraceReplayBench()
    testBench.is_blocking = False
    testBench.backendFuncs = MockBackendFunction()

    collName = "all_gather"
    curComm = commsArgs(req=0)

    (latency, global_latency) = testBench.runComms(collName, curComm, "test_stack")

    self.assertIsNotNone(latency)
    self.assertIsNotNone(global_latency)
    self.assertEqual(latency, global_latency)
def test_tensor_shrink_allgather(self):
    testBench = commsTraceReplayBench()
    testBench.backendFuncs = MockBackendFunction()
    commsParams = commsParamsTest()
    commsParams.dcheck = 1
    commsParams.device = "cpu"
    curComm = commsArgs(
        comms="all_gather", dtype="Int", inMsgSize=4, outMsgSize=4, worldSize=4
    )
    testBench.shrink = True
    testBench.collectiveArgs.world_size = 1

    (iptensor, optensor) = testBench.prepComms(curComm, commsParams)

    # tensor length should shrink to world size
    self.assertEqual(1, len(iptensor))
    self.assertEqual(1, len(optensor))
def test_tensor_shrink_alltoallv(self):
    testBench = commsTraceReplayBench()
    testBench.backendFuncs = MockBackendFunction()
    commsParams = commsParamsTest()
    commsParams.dcheck = 1
    commsParams.device = "cpu"
    curComm = commsArgs(
        comms="all_to_allv",
        dtype="Int",
        inMsgSize=4,
        outMsgSize=4,
        inSplit=[1, 1, 1, 1],
        outSplit=[1, 1, 1, 1],
        worldSize=4,
    )
    testBench.shrink = True
    testBench.collectiveArgs.world_size = 1

    (iptensor, optensor) = testBench.prepComms(curComm, commsParams)

    # tensor length should shrink to world size
    self.assertEqual(1, len(iptensor))
    self.assertEqual(1, len(optensor))
    # both input and output tensors should be equal to 1 for all_to_allv
    self.assertEqual(1, iptensor[0])
    self.assertEqual(1, optensor[0])
def test_tensor_no_shrink(self):
    testBench = commsTraceReplayBench()
    testBench.backendFuncs = MockBackendFunction()
    commsParams = commsParamsTest()
    commsParams.dcheck = 1
    commsParams.device = "cpu"
    curComm = commsArgs(comms="recv", dtype="Int", inMsgSize=1, outMsgSize=1)
    testBench.shrink = False
    testBench.collectiveArgs.world_size = 1

    (iptensor, optensor) = testBench.prepComms(curComm, commsParams)

    # tensor length needs to match world_size
    self.assertEqual(1, len(iptensor))
    self.assertEqual(1, len(optensor))
    # both input and output tensors should be equal to 1
    self.assertEqual(1, iptensor[0])
    self.assertEqual(1, optensor[0])
def test_no_tensor(self):
    # wait and barrier require no tensors
    testBench = commsTraceReplayBench()
    testBench.backendFuncs = MockBackendFunction()
    commsParams = commsParamsTest()
    commsParams.dcheck = 1
    commsParams.device = "cpu"

    curComm = commsArgs()
    curComm.comms = "wait"
    (iptensor, optensor) = testBench.prepComms(curComm, None)
    self.assertEqual(0, len(iptensor))
    self.assertEqual(0, len(optensor))

    curComm.comms = "barrier"
    (iptensor, optensor) = testBench.prepComms(curComm, None)
    self.assertEqual(0, len(iptensor))
    self.assertEqual(0, len(optensor))
def test_unsupported_policy(self):
    testBench = commsTraceReplayBench()
    # any string that is not a supported policy is treated as unsupported
    testBench.rebalance_policy = "unsupported"

    testComm = commsArgs()
    testComm.comms = "all_to_allv"
    testComm.inMsgSize = 5
    testComm.outMsgSize = 3
    testComm.worldSize = 2
    testComm.inSplit = [3, 2]
    testComm.outSplit = [1, 2]

    testBench.rebalanceSplit(testComm)

    # there should be no change to sizes or splits
    self.assertEqual(5, testComm.inMsgSize)
    self.assertEqual(3, testComm.outMsgSize)
    self.assertEqual([3, 2], testComm.inSplit)
    self.assertEqual([1, 2], testComm.outSplit)