def test_equal_policy(self):
        # With the "equal" policy, rebalanceSplit is expected to replace the
        # uneven all_to_allv sizes below with an even share per rank.
        bench = commsTraceReplayBench()
        bench.collectiveArgs.device = "cpu"
        bench.collectiveArgs.world_size = 2
        bench.rebalance_policy = "equal"

        comm = commsArgs()
        comm.comms = "all_to_allv"
        comm.inMsgSize = 5
        comm.outMsgSize = 3
        comm.inSplit = [3, 2]
        comm.outSplit = [1, 2]

        # 16 = this rank's inMsgSize (5) plus a pretend second rank's 11.
        reducedTotal = torch.tensor([16], dtype=torch.int)

        def fakeCollective(collectiveArgs):
            # Stand-in for the mocked collective: leave the "reduced" total
            # in the input tensor slot of the collective args.
            collectiveArgs.ipTensor = reducedTotal

        bench.backendFuncs = MockBackendFunction()
        bench.backendFuncs.mock_collective = mock.MagicMock(side_effect=fakeCollective)

        bench.rebalanceSplit(comm)

        # The mocked all_reduce yields 16, so with world_size = 2 both message
        # sizes should become 8 and both splits should become [4, 4].
        print(f"ipTensor after: {bench.collectiveArgs.ipTensor}")
        self.assertEqual(8, comm.inMsgSize)
        self.assertEqual(8, comm.outMsgSize)
        self.assertEqual([4, 4], comm.inSplit)
        self.assertEqual([4, 4], comm.outSplit)
# Example no. 2
# 0
def createCommsArgs(**kwargs) -> commsArgs:
    """
    Test helper: build a commsArgs object and set one attribute on it
    for every keyword argument supplied.
    """
    comm = commsArgs()
    for attrName, attrValue in kwargs.items():
        setattr(comm, attrName, attrValue)
    return comm
 def test_non_blocking_run(self):
     # With is_blocking turned off, runComms should still report both
     # latencies, and the two values should match.
     bench = commsTraceReplayBench()
     bench.is_blocking = False
     bench.backendFuncs = MockBackendFunction()

     collective = "all_gather"
     request = commsArgs(req=0)
     latency, global_latency = bench.runComms(collective, request, "test_stack")

     self.assertIsNotNone(latency)
     self.assertIsNotNone(global_latency)
     self.assertEqual(latency, global_latency)
 def test_tensor_shrink_allgather(self):
     # The traced all_gather used 4 ranks; with shrink enabled and a
     # replay world_size of 1, both tensors are expected to have length 1.
     bench = commsTraceReplayBench()
     bench.backendFuncs = MockBackendFunction()

     params = commsParamsTest()
     params.dcheck = 1
     params.device = "cpu"

     tracedComm = commsArgs(
         comms="all_gather",
         dtype="Int",
         inMsgSize=4,
         outMsgSize=4,
         worldSize=4,
     )
     bench.shrink = True
     bench.collectiveArgs.world_size = 1

     inputTensor, outputTensor = bench.prepComms(tracedComm, params)

     # tensor length should shrink to world size
     self.assertEqual(1, len(inputTensor))
     self.assertEqual(1, len(outputTensor))
 def test_tensor_shrink_alltoallv(self):
     # Same shrink scenario as the all_gather case, but for all_to_allv
     # with per-rank splits of 1 recorded for 4 ranks.
     bench = commsTraceReplayBench()
     bench.backendFuncs = MockBackendFunction()

     params = commsParamsTest()
     params.dcheck = 1
     params.device = "cpu"

     tracedComm = commsArgs(
         comms="all_to_allv",
         dtype="Int",
         inMsgSize=4,
         outMsgSize=4,
         inSplit=[1, 1, 1, 1],
         outSplit=[1, 1, 1, 1],
         worldSize=4,
     )
     bench.shrink = True
     bench.collectiveArgs.world_size = 1

     inputTensor, outputTensor = bench.prepComms(tracedComm, params)

     # tensor length should shrink to world size
     self.assertEqual(1, len(inputTensor))
     self.assertEqual(1, len(outputTensor))
     # both input and output tensors should be equal to 1 for all_to_allv
     self.assertEqual(1, inputTensor[0])
     self.assertEqual(1, outputTensor[0])
 def test_tensor_no_shrink(self):
     # With shrink disabled, a single-element recv should come back from
     # prepComms with one-element tensors whose first entry is 1.
     bench = commsTraceReplayBench()
     bench.backendFuncs = MockBackendFunction()

     params = commsParamsTest()
     params.dcheck = 1
     params.device = "cpu"

     tracedComm = commsArgs(
         comms="recv",
         dtype="Int",
         inMsgSize=1,
         outMsgSize=1,
     )
     bench.shrink = False
     bench.collectiveArgs.world_size = 1

     inputTensor, outputTensor = bench.prepComms(tracedComm, params)

     # tensor length needs to match world_size
     self.assertEqual(1, len(inputTensor))
     self.assertEqual(1, len(outputTensor))
     # both input and output tensors should be equal to 1
     self.assertEqual(1, inputTensor[0])
     self.assertEqual(1, outputTensor[0])
 def test_no_tensor(self):
     # wait and barrier require no tensors, so prepComms is expected to
     # return empty input and output tensors for both of them.
     testBench = commsTraceReplayBench()
     testBench.backendFuncs = MockBackendFunction()
     curComm = commsArgs()

     # The same commsArgs object is reused; only its collective name changes.
     for collName in ("wait", "barrier"):
         curComm.comms = collName
         # No comms params are built for this test: None is passed on purpose.
         (iptensor, optensor) = testBench.prepComms(curComm, None)
         self.assertEqual(
             0, len(iptensor), f"{collName}: input tensor should be empty"
         )
         self.assertEqual(
             0, len(optensor), f"{collName}: output tensor should be empty"
         )
    def test_unsupported_policy(self):
        # An unrecognised rebalance policy should leave the comm untouched.
        bench = commsTraceReplayBench()
        bench.rebalance_policy = "unsupported"  # any str that isn't in supported is considered unsupported

        # Attribute values are applied in this order, one setattr per entry.
        initialFields = {
            "comms": "all_to_allv",
            "inMsgSize": 5,
            "outMsgSize": 3,
            "worldSize": 2,
            "inSplit": [3, 2],
            "outSplit": [1, 2],
        }
        comm = commsArgs()
        for fieldName, fieldValue in initialFields.items():
            setattr(comm, fieldName, fieldValue)

        bench.rebalanceSplit(comm)

        # should be no change
        self.assertEqual(5, comm.inMsgSize)
        self.assertEqual(3, comm.outMsgSize)
        self.assertEqual([3, 2], comm.inSplit)
        self.assertEqual([1, 2], comm.outSplit)