    def test_equal_policy(self):
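        # The "equal" policy all_reduces the total message size across ranks
        # and splits the result evenly over world_size.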
        testBench = commsTraceReplayBench()
        testBench.collectiveArgs.device = "cpu"
        testBench.collectiveArgs.world_size = 2
        testBench.rebalance_policy = "equal"

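        # Uneven starting point: this rank sends 5 elements (split [3, 2]) but
        # receives only 3 (split [1, 2]).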
        testComm = commsArgs()
        testComm.comms = "all_to_allv"
        testComm.inMsgSize = 5
        testComm.outMsgSize = 3
        testComm.inSplit = [3, 2]
        testComm.outSplit = [1, 2]

        # Mock the all_reduce result: a second rank contributes inMsgSize 11,
        # so the reduced total is 5 + 11 = 16.
        ipTensor = torch.tensor([16], dtype=torch.int)
        testBench.backendFuncs = MockBackendFunction()
        testBench.backendFuncs.mock_collective = mock.MagicMock(
            side_effect=lambda collectiveArgs: setattr(collectiveArgs, "ipTensor", ipTensor)
        )

        testBench.rebalanceSplit(testComm)
        # The mocked all_reduce returns 16, so with world_size = 2 both
        # inMsgSize and outMsgSize should become 16 / 2 = 8.
        # inSplit and outSplit should each be [4, 4].
        self.assertEqual(8, testComm.inMsgSize)
        self.assertEqual(8, testComm.outMsgSize)
        self.assertEqual([4, 4], testComm.inSplit)
        self.assertEqual([4, 4], testComm.outSplit)

    def test_dry_run(self):
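        # A dry run gathers message-size stats and records per-marker comm
        # blocks without replaying any collectives.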
        test_trace = [
            createCommsArgs(
                comms="test", inMsgSize=1, outMsgSize=1, markerStack=["test_stack"]
            ),
            createCommsArgs(comms="all_gather", inMsgSize=2, outMsgSize=2),
            createCommsArgs(comms="wait", markerStack=["test_stack"]),
        ]
        testBench = commsTraceReplayBench()
        testBench.comms_trace = test_trace
        testBench.is_dry_run = True
        testBench.initTraceStat()
        # Only 2 messages had msg sizes
        self.assertEqual(2, len(testBench.collInMsgSizes))
        self.assertEqual(2, len(testBench.collOutMsgSizes))
        # The sum of the sizes of all all_gather msgs is 2 for in and out
        self.assertEqual(2, sum(testBench.collInMsgSizes["all_gather"]))
        self.assertEqual(2, sum(testBench.collOutMsgSizes["all_gather"]))
        # Dry run records comm blocks. We have two colls in test_stack
        self.assertEqual(2, len(testBench.comms_blocks["test_stack"]))
        # check values of comm_blocks
        self.assertEqual("test", testBench.comms_blocks["test_stack"][0]["comms"]) # first comm in "test_stack" is test
        self.assertEqual(1, testBench.comms_blocks["test_stack"][0]["in_msg_size"])
        self.assertEqual(1, testBench.comms_blocks["test_stack"][0]["out_msg_size"])

        self.assertEqual("wait", testBench.comms_blocks["test_stack"][1]["comms"]) # second comm in "test_stack" is wait

    def test_non_blocking_run(self):
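        # With the mocked backend, non-blocking replay should still report
        # latencies, and local latency should equal global latency.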
        testBench = commsTraceReplayBench()
        testBench.is_blocking = False
        testBench.backendFuncs = MockBackendFunction()
        collName = "all_gather"
        curComm = commsArgs(req=0)
        (latency, global_latency) = testBench.runComms(collName, curComm, "test_stack")
        self.assertIsNotNone(latency)
        self.assertIsNotNone(global_latency)
        self.assertEqual(latency, global_latency)

    def test_tensor_shrink_allgather(self):
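        # shrink=True: an all_gather recorded at worldSize=4 should be scaled
        # down to the current world_size of 1.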
        testBench = commsTraceReplayBench()
        testBench.backendFuncs = MockBackendFunction()
        commsParams = commsParamsTest()
        commsParams.dcheck = 1
        commsParams.device = "cpu"
        curComm = commsArgs(
            comms="all_gather", dtype="Int", inMsgSize=4, outMsgSize=4, worldSize=4
        )
        testBench.shrink = True
        testBench.collectiveArgs.world_size = 1
        (iptensor, optensor) = testBench.prepComms(curComm, commsParams)
        # tensor length should shrink to world size
        self.assertEqual(1, len(iptensor))
        self.assertEqual(1, len(optensor))

    def test_warm_up_bench(self):
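        # Smoke test: warm-up replays the trace against the mocked backend
        # and should complete without raising.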
        test_trace = [
            createCommsArgs(
                comms="test", inMsgSize=1, outMsgSize=1, markerStack=["test_stack"]
            ),
            createCommsArgs(comms="all_gather", inMsgSize=2, outMsgSize=2),
            createCommsArgs(comms="wait", markerStack=["test_stack"]),
        ]
        testBench = commsTraceReplayBench()
        testBench.backendFuncs = MockBackendFunction()
        testBench.comms_trace = test_trace
        commsParams = commsParamsTest()
        testBench.warmUpBench(commsParams)
        self.assertTrue(True)  # just check that warmUpBench ran without failure

    def test_init_bench(self):
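        # initBench should copy the command-line args onto the bench instance.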
        testBench = commsTraceReplayBench()
        commsParams = commsParamsTest()
        args = testArgs()
        args.use_timestamp = True
        args.num_msg = 1000
        args.auto_shrink = False
        args.no_warm_up = False
        testBench.initBench(commsParams, args)
        # check that the parameters were propagated to the bench
        self.assertEqual(args.use_timestamp, testBench.use_timestamp)
        self.assertEqual(args.num_msg, testBench.max_msg_cnt)
        self.assertEqual(args.auto_shrink, testBench.shrink)
        self.assertEqual(args.no_warm_up, not testBench.do_warm_up)

    def test_tensor_shrink_alltoallv(self):
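        # shrink=True: an all_to_allv recorded at worldSize=4 should collapse
        # its tensors and splits to the current world_size of 1.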
        testBench = commsTraceReplayBench()
        testBench.backendFuncs = MockBackendFunction()
        commsParams = commsParamsTest()
        commsParams.dcheck = 1
        commsParams.device = "cpu"
        curComm = commsArgs(
            comms="all_to_allv",
            dtype="Int",
            inMsgSize=4,
            outMsgSize=4,
            inSplit=[1, 1, 1, 1],
            outSplit=[1, 1, 1, 1],
            worldSize=4,
        )
        testBench.shrink = True
        testBench.collectiveArgs.world_size = 1
        (iptensor, optensor) = testBench.prepComms(curComm, commsParams)
        # tensor length should shrink to world size
        self.assertEqual(1, len(iptensor))
        self.assertEqual(1, len(optensor))
        # both input and output tensors should be equal to 1 for all_to_allv
        self.assertEqual(1, iptensor[0])
        self.assertEqual(1, optensor[0])

    def test_tensor_no_shrink(self):
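        # shrink=False: prepComms should leave tensor sizes as recorded in
        # the trace.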
        testBench = commsTraceReplayBench()
        testBench.backendFuncs = MockBackendFunction()
        commsParams = commsParamsTest()
        commsParams.dcheck = 1
        commsParams.device = "cpu"
        curComm = commsArgs(comms="recv", dtype="Int", inMsgSize=1, outMsgSize=1)
        testBench.shrink = False
        testBench.collectiveArgs.world_size = 1
        (iptensor, optensor) = testBench.prepComms(curComm, commsParams)
        # tensor length needs to match world_size
        self.assertEqual(1, len(iptensor))
        self.assertEqual(1, len(optensor))
        # both input and output tensors should be equal to 1
        self.assertEqual(1, iptensor[0])
        self.assertEqual(1, optensor[0])

    def test_no_tensor(self):
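        # prepComms should return empty tensors for ops that need none.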
        # wait and barrier require no tensors
        testBench = commsTraceReplayBench()
        testBench.backendFuncs = MockBackendFunction()
        commsParams = commsParamsTest()
        commsParams.dcheck = 1
        commsParams.device = "cpu"
        curComm = commsArgs()
        curComm.comms = "wait"
        (iptensor, optensor) = testBench.prepComms(curComm, None)
        self.assertEqual(0, len(iptensor))
        self.assertEqual(0, len(optensor))
        curComm.comms = "barrier"
        (iptensor, optensor) = testBench.prepComms(curComm, None)
        self.assertEqual(0, len(iptensor))
        self.assertEqual(0, len(optensor))

    def test_unsupported_policy(self):
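        # An unrecognized policy should leave sizes and splits untouched.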
        testBench = commsTraceReplayBench()
        testBench.rebalance_policy = "unsupported"  # any string not in the supported set is treated as unsupported

        testComm = commsArgs()
        testComm.comms = "all_to_allv"
        testComm.inMsgSize = 5
        testComm.outMsgSize = 3
        testComm.worldSize = 2
        testComm.inSplit = [3, 2]
        testComm.outSplit = [1, 2]

        testBench.rebalanceSplit(testComm)

        # should be no change
        self.assertEqual(5, testComm.inMsgSize)
        self.assertEqual(3, testComm.outMsgSize)
        self.assertEqual([3, 2], testComm.inSplit)
        self.assertEqual([1, 2], testComm.outSplit)

    def test_not_dry_run(self):
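        # Without dry run, initTraceStat still gathers size stats but does
        # not record per-marker comm blocks.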
        test_trace = [
            createCommsArgs(
                comms="test", inMsgSize=1, outMsgSize=1, markerStack=["test_stack"]
            ),
            createCommsArgs(comms="all_gather", inMsgSize=2, outMsgSize=2),
            createCommsArgs(comms="wait", markerStack=["test_stack"]),
        ]
        testBench = commsTraceReplayBench()
        testBench.comms_trace = test_trace
        testBench.initTraceStat()
        # Only 2 messages had msg sizes
        self.assertEqual(2, len(testBench.collInMsgSizes))
        self.assertEqual(2, len(testBench.collOutMsgSizes))
        # The sum of the sizes of all all_gather msgs is 2 for in and out
        self.assertEqual(2, sum(testBench.collInMsgSizes["all_gather"]))
        self.assertEqual(2, sum(testBench.collOutMsgSizes["all_gather"]))
        # A non-dry run does not record comm blocks.
        self.assertEqual(0, len(testBench.comms_blocks["test_stack"]))