Ejemplo n.º 1
0
    def checkArgs(self, args):
        """Validate the parsed benchmark arguments and normalize them in place.

        After delegating to the parent class, this checks the requested
        collective and blocking style against the supported lists, converts
        the begin/end sizes with ``comms_utils.parsesize``, resolves the torch
        dtype from ``self.dtypeMap`` and clamps the begin size to at least 1.

        Args:
            args: parsed argument namespace; ``b``, ``e`` and ``dtype`` are
                overwritten on it.

        Raises:
            ValueError: if the ``nccl`` backend is requested with the ``cpu``
                device.
        """
        super().checkArgs(args)

        # Requested collective must be one of the known names.
        collective_known = args.collective in supportedCollectives
        if not collective_known:
            collective_msg = (
                "\t ERROR: Specified collective: %s is not one of the supported collectives: %s. Make sure the input is using the correct case."
                % (args.collective, str(supportedCollectives))
            )
            print(collective_msg)
            comms_utils.gracefulExit()

        # Blocking flag must be one of the supported communication styles.
        style_known = args.z in supportedCommStyle
        if not style_known:
            style_msg = (
                "\t ERROR: Specified blocking: %d is not one of the supported commstyle: %s"
                % (args.z, str(supportedCommStyle))
            )
            print(style_msg)
            comms_utils.gracefulExit()

        # Normalize sizes and dtype before the range checks below.
        args.b = comms_utils.parsesize(args.b)
        args.e = comms_utils.parsesize(args.e)
        args.dtype = self.dtypeMap[args.data_type]

        if args.b < 1:
            print("\t Starting size: %d should atleast be 1! " % (args.b))
            args.b = 1

        # An inverted range is only reported; execution continues.
        if args.e < args.b:
            range_msg = (
                "\t ERROR: In COMMS-mode, the begin-size: %d is larger than the end-size: %d "
                % (args.b, args.e)
            )
            print(range_msg)

        if args.device == "cpu":
            if args.backend == "nccl":
                raise ValueError("NCCL is not supported for device type CPU")
Ejemplo n.º 2
0
    def checkArgs(self, args):
        """Validate the parsed benchmark arguments and normalize them in place.

        After delegating to the parent class, this:
          * forces ``args.collective`` to ``"pt2pt"`` when a pt2pt pattern is
            given and checks that pattern against ``pt2ptPatterns``;
          * converts ``args.b`` / ``args.e`` with ``comms_utils.parsesize``;
          * resolves ``args.dtype`` from ``self.dtypeMap``;
          * clamps the begin size to at least 1 and reports an inverted range;
          * disables data validation (``args.c``) in non-blocking mode.

        Args:
            args: parsed argument namespace; mutated in place.

        Raises:
            ValueError: if the ``nccl`` backend is requested with the ``cpu``
                device.
        """
        super().checkArgs(args)

        if args.pt2pt is not None:
            args.collective = "pt2pt"
            if args.pt2pt not in pt2ptPatterns:
                # Use %s, not %d: %d raises TypeError for a non-numeric value,
                # which would hide the real error from the user.
                # NOTE(review): pt2pt patterns look like string names - confirm.
                print(
                    "\t ERROR: Specified pt2pt pattern: %s is not one of the supported pt2pt patterns: %s"
                    % (args.pt2pt, str(pt2ptPatterns)))
                comms_utils.gracefulExit()

        # Normalize sizes and dtype before the range checks below.
        args.b = comms_utils.parsesize(args.b)
        args.e = comms_utils.parsesize(args.e)
        args.dtype = self.dtypeMap[args.data_type]

        if args.b < 1:
            print("\t Starting size: %d should atleast be 1! " % (args.b))
            args.b = 1

        # An inverted range is only reported; execution continues.
        if args.e < args.b:
            print(
                "\t ERROR: In COMMS-mode, the begin-size: %d is larger than the end-size: %d "
                % (args.b, args.e))

        if args.device == "cpu" and args.backend == "nccl":
            raise ValueError("NCCL is not supported for device type CPU")

        # Validation (c == 1) is switched off when running non-blocking (z == 0).
        if args.c == 1 and args.z == 0:
            logging.warning(
                "Data validation is not supported for non-blocking mode...disable validation check and proceed..."
            )
            args.c = 0
Ejemplo n.º 3
0
    def checkArgs(self, args):
        """Validate arguments and make sure a local trace file is available.

        Delegates to the parent class first. Unless ``self.use_remote_trace``
        is set, ``self.trace_file`` must point to an existing regular file.

        Args:
            args: parsed argument namespace, forwarded to the parent check.

        Raises:
            ValueError: if a local trace is expected but ``self.trace_file``
                does not exist or is not a file.
        """
        super().checkArgs(args)

        # isfile() is already False for a missing path, so one call covers both
        # the "does not exist" and the "not a regular file" cases.
        if not self.use_remote_trace and not path.isfile(self.trace_file):
            raise ValueError(
                f"Trace file {self.trace_file} not exist or not a file! Please specifiy the correct path using --trace-path"
            )
Ejemplo n.º 4
0
    def setBench(self, comms_world_info, commsParams):
        """Create the backend and populate ``self.collectiveArgs`` for replay.

        Picks the backend class from ``commsParams.nw_stack``, initializes it
        with the master address/port, copies the rank details into
        ``self.collectiveArgs`` and resolves ``self.allowList`` into the
        collection of collectives to replay.

        Args:
            comms_world_info: world description; ``master_ip`` and
                ``master_port`` are read from it and it is passed to the
                backend constructor.
            commsParams: benchmark parameters; ``nw_stack``, ``backend`` and
                ``quant_threshold`` are read here.
        """
        # Backend modules are imported lazily, only for the selected stack.
        nw_stack = commsParams.nw_stack
        if nw_stack == "pytorch-dist":
            from pytorch_dist_backend import PyTorchDistBackend

            self.backendFuncs = PyTorchDistBackend(comms_world_info, commsParams)
        elif nw_stack == "pytorch-xla-tpu":
            from pytorch_tpu_backend import PyTorchTPUBackend

            self.backendFuncs = PyTorchTPUBackend(comms_world_info, commsParams)
        else:
            logger.error("Unsopported NW stack! ")
            comms_utils.gracefulExit()

        self.backendFuncs.initialize_backend(
            comms_world_info.master_ip,
            comms_world_info.master_port,
            backend=commsParams.backend,
        )
        self.backendFuncs.sayHello()

        # Rank details come from the backend object because MPI cannot be used
        # to launch all the processes (e.g. on TPU). Local rank and hardware
        # device are not needed here.
        rank_details = comms_utils.get_rank_details(self.backendFuncs)
        _, global_rank, world_size, group, curDevice, _ = rank_details

        coll_args = self.collectiveArgs
        coll_args.group = group
        coll_args.device = curDevice
        coll_args.world_size = world_size
        coll_args.global_rank = global_rank
        coll_args.backendFuncs = self.backendFuncs
        # FIXME: 0 is a common case; the trace should supply this for a more accurate replay.
        coll_args.srcOrDst = 0
        # FIXME: assumes the reduction is always a sum for reduce/allreduce.
        coll_args.op = self.backendFuncs.get_reduce_op("sum")
        # Asynchronous ops are used exactly when the replay is not blocking.
        coll_args.asyncOp = not self.is_blocking
        coll_args.ipTensor = None
        coll_args.opTensor = None
        coll_args.quant_threshold = commsParams.quant_threshold

        # Resolve the set of collectives to be replayed.
        replay_everything = self.allowList in ("all", "default", "*")
        if not replay_everything:
            self.allowList = list(map(paramToCommName, self.allowList.split(",")))
        else:
            self.allowList = self.backendFuncs.collectiveFunc.keys()
Ejemplo n.º 5
0
    def checkArgs(self, args):
        """Validate the parsed benchmark arguments and normalize them in place.

        After delegating to the parent class, this checks the pt2pt pattern,
        converts the begin/end sizes, resolves the dtype, clamps the begin
        size, disables data validation for reduce-style collectives in
        non-blocking mode and runs the quantization checks when
        ``args.bitwidth`` is below 32.

        Args:
            args: parsed argument namespace; mutated in place.

        Raises:
            ValueError: if the ``nccl`` backend is requested with the ``cpu``
                device.
        """
        super().checkArgs(args)

        # A pt2pt pattern overrides whatever collective was requested.
        if args.pt2pt is not None:
            args.collective = "pt2pt"
            pattern_known = args.pt2pt in pt2ptPatterns
            if not pattern_known:
                logger.error(
                    f"Specified pt2pt pattern: {args.pt2pt} is not one of the supported pt2pt patterns: {str(pt2ptPatterns)}"
                )
                comms_utils.gracefulExit()

        # Normalize sizes and dtype before the range checks below.
        args.b = comms_utils.parsesize(args.b)
        args.e = comms_utils.parsesize(args.e)
        args.dtype = self.dtypeMap[args.data_type]

        if args.b < 1:
            logger.warning(
                f"Starting size (--b {args.b}) should be greater than 1 byte...fix and continue"
            )
            args.b = 1

        # An inverted range only produces a warning.
        if args.e < args.b:
            logger.warning(
                f"the begin-size (--b {args.b}) is larger than the end-size (--e {args.e})"
            )

        if args.device == "cpu":
            if args.backend == "nccl":
                raise ValueError(
                    f"NCCL is not supported for device type {args.device}")

        # Validation is switched off for these collectives when non-blocking.
        unvalidated = ("all_reduce", "reduce", "reduce_scatter")
        if args.c == 1 and args.z == 0 and args.collective in unvalidated:
            logger.warning(
                f"Data validation is not supported for {args.collective} in non-blocking mode, disabled and continue"
            )
            args.c = 0

        # Quantization sanity checks apply only below 32 bits.
        if not args.bitwidth < 32:
            return

        if args.device != "cuda":
            logger.error(
                f"collective quantization may not be fully supported for {args.device}"
            )
        comms_utils.checkQuantArgs(
            args.collective,
            args.dtype,
            args.b,
            args.quant_a2a_embedding_dim,
            args.z,
        )
Ejemplo n.º 6
0
    def runBench(self, comms_world_info, commsParams):
        """Instantiate the backend for the requested stack and run the benchmark.

        Args:
            comms_world_info: world description passed to the backend
                constructor.
            commsParams: benchmark parameters; ``nw_stack`` selects the
                backend class.
        """
        # Backend modules are imported lazily, only for the selected stack.
        stack_name = commsParams.nw_stack
        if stack_name == "pytorch-nccl":
            from pytorch_dist_backend import PyTorchDistBackend

            backend = PyTorchDistBackend(comms_world_info, commsParams)
        elif stack_name == "pytorch-xla-tpu":
            from pytorch_tpu_backend import PyTorchTPUBackend

            backend = PyTorchTPUBackend(comms_world_info, commsParams)
        else:
            print("\t Error: Unsopported NW stack! ")
            comms_utils.gracefulExit()

        # Keep a handle on the backend for the rest of the object.
        self.backendFuncs = backend
        backend.benchmark_comms()
Ejemplo n.º 7
0
def runComms(comms_world_info, commsParams):
    """Sanity-check the size range, build the backend and run the benchmark.

    Args:
        comms_world_info: world description passed to the backend constructor.
        commsParams: benchmark parameters; ``beginSize``, ``endSize`` and
            ``nw_stack`` are read here.
    """
    # An inverted size range is only reported; execution continues.
    if commsParams.endSize < commsParams.beginSize:
        range_msg = (
            "\t ERROR: In COMMS-mode, the begin-size: %d is larger than the end-size: %d "
            % (commsParams.beginSize, commsParams.endSize))
        print(range_msg)

    # Backend modules are imported lazily, only for the selected stack.
    stack_name = commsParams.nw_stack
    if stack_name == "pytorch-nccl":
        from pytorch_nccl_backend import PyTorchNCCLBackend

        backend = PyTorchNCCLBackend(comms_world_info, commsParams)
    elif stack_name == "pytorch-xla-tpu":
        from tpu_backend import PyTorchTPUBackend

        backend = PyTorchTPUBackend(comms_world_info, commsParams)
    else:
        print("\t Error: Unsopported NW stack! ")
        comms_utils.gracefulExit()

    backend.benchmark_comms()
Ejemplo n.º 8
0
    def runBench(self, comms_world_info, commsParams):
        """Instantiate the backend for the requested stack and run the benchmark.

        A ``ValueError`` raised by the benchmark is always re-raised; when the
        ``ucc`` backend is in use, a critical log entry is emitted first.

        Args:
            comms_world_info: world description passed to the backend
                constructor.
            commsParams: benchmark parameters; ``nw_stack`` selects the
                backend class and ``backend`` is consulted on failure.
        """
        # Backend modules are imported lazily, only for the selected stack.
        stack_name = commsParams.nw_stack
        if stack_name == "pytorch-dist":
            from pytorch_dist_backend import PyTorchDistBackend

            backend = PyTorchDistBackend(comms_world_info, commsParams)
        elif stack_name == "pytorch-xla-tpu":
            from pytorch_tpu_backend import PyTorchTPUBackend

            backend = PyTorchTPUBackend(comms_world_info, commsParams)
        else:
            print("\t Error: Unsopported NW stack! ")
            comms_utils.gracefulExit()

        self.backendFuncs = backend
        try:
            backend.benchmark_comms()
        except ValueError as err:
            using_ucc = commsParams.backend == "ucc"
            if using_ucc:
                details = repr(err)
                logging.critical(
                    "PyTorch UCC not implemented? {}".format(details))
            raise
Ejemplo n.º 9
0
    def checkPt2PtRanks(self):
        """Fill in default pt2pt rank lists and sanity-check them.

        Empty ``src_ranks`` / ``dst_ranks`` on ``self.collectiveArgs`` default
        to ``[0]`` and ``[1]``. For ``one2one`` each list may hold at most one
        rank; for ``pairwise`` the lists must be of equal length and share no
        rank. A failed check logs an error on global rank 0 and calls
        ``comms_utils.gracefulExit()``. Rank 0 finally prints the resolved
        configuration.
        """
        coll_args = self.collectiveArgs
        on_root = self.global_rank == 0

        def _abort(message):
            # Only rank 0 logs the reason; every rank requests the exit.
            if on_root:
                logger.error(message)
            comms_utils.gracefulExit()

        # Apply defaults for missing or empty rank lists.
        if not coll_args.src_ranks:
            coll_args.src_ranks = [0]
        if not coll_args.dst_ranks:
            coll_args.dst_ranks = [1]

        sources = coll_args.src_ranks
        targets = coll_args.dst_ranks
        pattern = coll_args.pt2pt

        if pattern == "one2one":
            if len(sources) > 1 or len(targets) > 1:
                _abort(
                    "One2one Pt2Pt requires only a single rank is specified in src_ranks and dst_ranks! "
                )
        elif pattern == "pairwise":
            # Both sides need the same number of ranks...
            if len(sources) != len(targets):
                _abort(
                    "Pairwise Pt2Pt requires identical number of members in src_ranks and dst_ranks! "
                )
            # ...and no rank may appear on both sides.
            if not set(sources).isdisjoint(targets):
                _abort(
                    "Pairwise Pt2Pt requires distinct members in src_ranks and dst_ranks! "
                )

        if on_root:
            print(
                f"\t collective={self.collectiveArgs.collective}\t{self.collectiveArgs.pt2pt}, src_ranks={self.collectiveArgs.src_ranks}, dst_ranks={self.collectiveArgs.dst_ranks}"
            )
Ejemplo n.º 10
0
def initializeCollectiveArgs(commsParams, backendFuncs):
    """Build the collective-arguments holder and the message sizes to sweep.

    Split out of benchTime to keep that function's complexity down.

    Args:
        commsParams: benchmark parameters (sizes, iteration counts, mode,
            kernel, dtype, ...).
        backendFuncs: backend object used for rank discovery, reduce-op lookup
            and tensor allocation.

    Returns:
        Tuple ``(collectiveArgs, local_rank, global_rank, world_size, group,
        curDevice, allSizes, computeFunc)``; ``computeFunc`` stays ``None`` in
        "comms" mode.
    """
    # Ranks come from the backend object because MPI cannot be used to launch
    # all the processes (e.g. on TPU).
    rank_details = comms_utils.get_rank_details(backendFuncs)
    local_rank, global_rank, world_size, group, curDevice = rank_details

    # Ensure all-reduce and all-to-all have at least one member per rank.
    comms_utils.fixBeginSize(commsParams, world_size)
    # Report where each process is running.
    backendFuncs.sayHello()
    # Message sizes to iterate over, derived from begin-size, end-size and step-factor.
    allSizes = comms_utils.getSizes(
        commsParams.beginSize, commsParams.endSize, commsParams.stepFactor
    )

    if global_rank == 0:
        summary = (
            "\t global_rank: %d allSizes: %s local_rank: %d element_size: %d "
            % (global_rank, allSizes, local_rank, commsParams.element_size))
        print(summary)
        print("\t global_rank: %d commsParams: %s " %
              (global_rank, commsParams))

    collectiveArgs = comms_utils.collectiveArgsHolder()
    settings = (
        ("group", group),
        ("device", curDevice),
        ("world_size", world_size),
        ("numIters", commsParams.numIters),
        ("numWarmupIters", commsParams.numWarmupIters),
        ("global_rank", global_rank),
        ("backendFuncs", backendFuncs),
        ("srcOrDst", ""),
        ("collective", commsParams.collective),
        ("op", backendFuncs.get_reduce_op("sum")),
        ("dst", commsParams.dst),
    )
    for attr_name, attr_value in settings:
        setattr(collectiveArgs, attr_name, attr_value)

    computeFunc = None
    if commsParams.mode != "comms":  # Compute-mode related initialization.
        if commsParams.kernel == "gemm":
            computeFunc = backendFuncs.gemm
            mm_dim = commsParams.mm_dim

            def _random_square():
                # One mm_dim x mm_dim random float matrix on the current device.
                return torch.FloatTensor(
                    np.random.rand(mm_dim, mm_dim)).to(curDevice)

            first_in = _random_square()
            second_in = _random_square()
            third_in = _random_square()
            gemm_out = backendFuncs.alloc_empty([mm_dim, mm_dim],
                                                commsParams.dtype, curDevice)

            collectiveArgs.MMout = gemm_out
            collectiveArgs.MMin1 = first_in
            collectiveArgs.MMin2 = second_in
            collectiveArgs.MMin3 = third_in
            collectiveArgs.numComputePerColl = commsParams.num_compute
        else:
            print("Compute kernel " + commsParams.kernel +
                  " not supported...Abort!")
            comms_utils.gracefulExit()

    return (
        collectiveArgs,
        local_rank,
        global_rank,
        world_size,
        group,
        curDevice,
        allSizes,
        computeFunc,
    )
Ejemplo n.º 11
0
def main():
    """Parse the command line, validate the options and launch the benchmark.

    Unsupported values for the network stack, collective, blocking style,
    data type or TPU core count are reported and the program is asked to exit.
    The begin size is clamped to at least 1. The validated arguments are then
    wrapped in the ``comms_utils`` holder objects and handed to ``runComms``.
    """
    ### import packages ###
    import sys
    import argparse

    supportedCommStyle = [0, 1]  # 0 : non-blocking, 1 : blocking.
    supportedCollectives = [
        "reduce",
        "all_reduce",
        "all_to_all",
        "all_to_allv",
    ]  # , "scatter", "gather", "all_gather", "broadcast", "all_to_all"]
    supportedNwstacks = ["pytorch-nccl", "pytorch-xla-tpu"]
    supported_tpu_core_valuses = [1, 8]
    dtypeMap = {
        "float32": torch.float32,
        "int32": torch.int32,
        "float16": torch.half,
        "float64": torch.double,
    }
    supportedDtype = list(dtypeMap.keys())

    ### parse arguments ###
    parser = argparse.ArgumentParser(
        description="PARAM-Comm Benchmark",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    # experiment related parameters
    parser.add_argument(
        "--backend",
        type=str,
        default="nccl",
        help="The backend to be used in PyTorch distributed process group"
    )  # alternative is DLRM mode.
    parser.add_argument(
        "--mode", type=str, default="comms",
        help="benchmark mode")  # alternative is DLRM mode or comm-compute mode
    parser.add_argument("--b",
                        type=str,
                        default="8",
                        help="minimum size, in bytes, to start with"
                        )  # COMMS mode, begin the sweep at.
    parser.add_argument("--e",
                        type=str,
                        default="64",
                        help="maximum size, in bytes, to end at"
                        )  # COMMS mode, end the sweep at.
    parser.add_argument("--f",
                        type=int,
                        default=2,
                        help="multiplication factor between sizes"
                        )  # COMMS mode, multiplication factor.
    parser.add_argument("--z",
                        type=int,
                        default=1,
                        help="use blocking mode for collectives"
                        )  # 'sync/blocking' : 1 , 'async/non-blocking' : 0

    parser.add_argument(
        "--w", type=int, default=5,
        help="number of warmup iterations")  # number of warmup-iterations
    parser.add_argument("--n",
                        type=int,
                        default=5,
                        help="number of iterations")  # number of iterations
    parser.add_argument(
        "--collective",
        type=str,
        default="all_reduce",
        help='Collective to benchmark, supports ' +
        str(supportedCollectives))  # collective op to benchmark
    parser.add_argument(
        "--master-ip",
        type=str,
        default="127.0.0.1",
        help="The master-IP to coordinate")  # The master-IP to coordinate.
    parser.add_argument(
        "--master-port",
        type=str,
        default="29500",
        help="The master-port to coordinate")  # The master-port to coordinate.
    parser.add_argument(
        "--nw-stack",
        type=str,
        default="pytorch-nccl",
        help="network stack to be used, supports " +
        str(supportedNwstacks))  # The network stack to profile.
    parser.add_argument(
        "--dtype", type=torch.dtype, default=torch.float32
    )  # will be overwritten based on args.data_type and dtypeMap.
    parser.add_argument("--data-type",
                        type=str,
                        default="float32",
                        help="the base data type, supports " +
                        str(supportedDtype))  # The data type

    parser.add_argument(
        "--num-tpu-cores",
        type=int,
        default=1,
        help="number of TPU cores to be used")  # number of TPU cores

    # For comm-compute or compute mode
    parser.add_argument("--kernel",
                        type=str,
                        default="gemm",
                        help="compute kernel")  # Compute kernel: "gemm"
    parser.add_argument(
        "--num-compute",
        type=int,
        default=100,
        help="one collective for every NUM_COMPUTE compute kernels"
    )  # Launch one coll for every n compute kernels
    # For GEMM
    parser.add_argument("--mm-dim",
                        type=int,
                        default=100,
                        help="dimension size for GEMM compute kernel"
                        )  # Matrix multiplication dim n, A[n,n] * B [n,n]
    # For emb lookup
    parser.add_argument(
        "--emb-dim",
        type=int,
        default=128,
        help="dimension size for Embedding table compute kernel"
    )  # Embedding table dimension
    parser.add_argument(
        "--num-embs",
        type=int,
        default=100000,
        help="Embedding table hash size for Embedding table compute kernel"
    )  # Embedding table hash size
    parser.add_argument("--avg-len",
                        type=int,
                        default=28,
                        help="Average lookup operations per sample"
                        )  # Average #lookup per sample
    parser.add_argument("--batch-size",
                        type=int,
                        default=512,
                        help="number of samples reading the table concurrently"
                        )  # #Samples reading the table concurrently
    parser.add_argument(
        "--root",
        type=int,
        default=0,
        help="root process for reduce benchmark"
    )  # root process for reduce (and gather, scatter, bcast, etc., if support in the future)
    # TODO: check the correctness of root, should be between 0 to [world_size -1]

    # Unknown options are tolerated; the leftovers are not used further here.
    args, leftovers = parser.parse_known_args()
    args.b = comms_utils.parsesize(args.b)
    args.e = comms_utils.parsesize(args.e)

    if args.nw_stack not in supportedNwstacks:
        print(
            "\t ERROR: Specified backend: %s is not one of the supported backends: %s. Make sure the input is using the correct case."
            % (args.nw_stack, str(supportedNwstacks)))
        sys.exit(
        )  # WARNING: Assuming sys is always used, should find a platform-independent way to gracefully exit.

    if args.collective not in supportedCollectives:
        print(
            "\t ERROR: Specified collective: %s is not one of the supported collectives: %s. Make sure the input is using the correct case."
            % (args.collective, str(supportedCollectives)))
        sys.exit(
        )  # WARNING: Assuming sys is always used, should find a platform-independent way to gracefully exit.

    if args.z not in supportedCommStyle:
        print(
            "\t ERROR: Specified blocking: %d is not one of the supported commstyle: %s"
            % (args.z, str(supportedCommStyle)))
        comms_utils.gracefulExit()

    if args.data_type not in supportedDtype:
        # args.data_type is a string (type=str above), so it must be formatted
        # with %s; %d would raise TypeError instead of printing the message.
        print(
            "\t ERROR: Specified dtype: %s is not one of the supported dtypes: %s"
            % (args.data_type, str(supportedDtype)))
        comms_utils.gracefulExit()

    # Replace the placeholder --dtype value with the dtype named by --data-type.
    args.dtype = dtypeMap[args.data_type]

    mpi_env_params = comms_utils.read_mpi_env_vars()
    if mpi_env_params["global_rank"] == 0:
        print("\t MPI environment: %s " % (str(mpi_env_params)))
        print(
            "\t backend: %s nw-stack: %s mode: %s args.b: %d args.e: %d args.f: %d args.z: %s args.master_ip: %s "
            % (
                args.backend,
                args.nw_stack,
                args.mode,
                args.b,
                args.e,
                args.f,
                args.z,
                args.master_ip,
            ))

    if args.num_tpu_cores not in supported_tpu_core_valuses:
        print(
            "\t ERROR: TPU core value: %d is not one of the supported values: %s "
            % (args.num_tpu_cores, supported_tpu_core_valuses))
        comms_utils.gracefulExit()

    if args.b < 1:
        print("\t Starting size: %d should atleast be 1! " % (args.b))
        args.b = 1

    # Size in bytes of a single element of the chosen dtype.
    element_size = torch.ones([1], dtype=args.dtype).element_size()
    comms_world_info = comms_utils.comms_world_info_holder(
        args.master_ip, args.master_port, args.num_tpu_cores, mpi_env_params)
    commsParams = comms_utils.commsParamsHolder(args, element_size, benchTime)
    runComms(comms_world_info, commsParams)
Ejemplo n.º 12
0
    def benchTime(self, index, commsParams, backendFuncs):
        """Time the selected collective over a sweep of message sizes and report.

        For every size returned by ``self.initCollectiveArgs`` this allocates
        the input/output tensors, selects the collective function matching
        ``commsParams.collective``, runs it through ``self.runColl`` and
        records the latency and bandwidth figures. The per-size latencies are
        then all-gathered and global rank 0 passes them to
        ``self.reportBenchTime``.

        Args:
            index: not referenced in this method's body.
            commsParams: benchmark parameters (collective, mode, dtype,
                element_size, blockingFlag, backend, ...).
            backendFuncs: backend object providing allocation and the
                collective functions.
        """
        # Get NW stack specific parameters
        (
            local_rank,
            global_rank,
            world_size,
            group,
            curDevice,
            curHwDevice,
            allSizes,
            computeFunc,
        ) = self.initCollectiveArgs(commsParams)

        results = {}  # per-size metrics, keyed by message size
        timeElapsedList = []  # latencies in microseconds, in allSizes order
        for curSize in allSizes:
            # Allocating memory.
            numElements = int(curSize // commsParams.element_size)
            scaleFactor = numElements * numElements
            if commsParams.collective == "all_to_all":
                # numElements = int(numElements // world_size)  # assuming that world_size won't be zero!
                scaleFactor = 1
            ipTensor = backendFuncs.alloc_random([numElements], curDevice,
                                                 commsParams.dtype,
                                                 scaleFactor)

            # By default the output aliases the input tensor; the all_to_all
            # variants below allocate a separate output tensor instead.
            opTensor = ipTensor
            # ignoring all_gather, scatter-gather, for now # FUTURE-TODO- make interface accept scatter and gather list.
            asyncOp = True
            collectiveFunc = None

            if (commsParams.blockingFlag == 1
                ):  # if blockingFlag is 1, it means asyncOp should be false.
                asyncOp = False

            if commsParams.mode != "compute":  # comms specific initializations
                if commsParams.collective == "all_reduce":
                    collectiveFunc = backendFuncs.all_reduce

                elif commsParams.collective == "all_to_all":
                    opTensor = backendFuncs.alloc_empty([numElements],
                                                        commsParams.dtype,
                                                        curDevice)
                    collectiveFunc = backendFuncs.all_to_all

                elif commsParams.collective == "all_to_allv":
                    opTensor = backendFuncs.alloc_empty([numElements],
                                                        commsParams.dtype,
                                                        curDevice)
                    # Equal split: every rank gets numElements // world_size
                    # elements on both the input and the output side.
                    self.collectiveArgs.ipTensor_split = [
                        int(numElements // world_size)
                        for i in range(world_size)
                    ]
                    self.collectiveArgs.opTensor_split = [
                        int(numElements // world_size)
                        for i in range(world_size)
                    ]
                    collectiveFunc = backendFuncs.all_to_allv

                elif commsParams.collective == "reduce":
                    collectiveFunc = backendFuncs.reduce
                else:
                    print("This should not happen")
                    # NOTE(review): called unqualified here, unlike
                    # comms_utils.gracefulExit() elsewhere - confirm that
                    # gracefulExit is imported into this module's namespace.
                    gracefulExit()

            # Setup the arguments.
            self.collectiveArgs.ipTensor = ipTensor
            self.collectiveArgs.opTensor = opTensor
            self.collectiveArgs.asyncOp = asyncOp
            self.collectiveArgs.dataSize = curSize
            self.collectiveArgs.numElements = numElements
            self.collectiveArgs.waitObj = []

            # self.collectiveArgs has all the information on the experiment.
            timeElapsedNS, algBW, busBW, memSize, x = self.runColl(
                comm_fn=collectiveFunc, compute_fn=computeFunc)

            results[curSize] = {}
            # Elapsed time arrives in nanoseconds; store it in microseconds.
            results[curSize]["timeUS"] = timeElapsedNS / 1e3
            timeElapsedList.append(
                results[curSize]["timeUS"]
            )  # assuming that order is known at each rank, so it's OK to not identify it by message-size
            results[curSize]["algBW"] = algBW
            results[curSize]["busBW"] = busBW
            results[curSize]["memSize"] = memSize
            # For the all_to_all variants the per-rank element count is recorded.
            if (commsParams.collective
                    == "all_to_all") or (commsParams.collective
                                         == "all_to_allv"):
                results[curSize]["num_elements"] = int(numElements //
                                                       world_size)
            else:
                results[curSize]["num_elements"] = int(numElements)
            results[curSize]["x"] = x

            # Drop the local tensor references before asking the backend to
            # clear memory, then synchronize before the next size.
            del ipTensor
            del opTensor
            backendFuncs.clear_memory()
            self.backendFuncs.barrier(self.collectiveArgs, "curSize")

        # Push the list to device, then do an all-gather.
        timeElapsedTensor = torch.tensor(timeElapsedList, device=curDevice)
        # The per-rank receive list is only built for non-xla backends.
        if not commsParams.backend == "xla":
            tensorList = [
                torch.ones_like(timeElapsedTensor) for _ in range(world_size)
            ]
            self.collectiveArgs.tensorList = tensorList

        self.collectiveArgs.ipTensor = timeElapsedTensor
        self.collectiveArgs.asyncOp = False
        self.collectiveArgs.dataSize = (timeElapsedTensor.nelement() *
                                        timeElapsedTensor.element_size())
        self.collectiveArgs.numElements = timeElapsedTensor.nelement()

        # Quantization handlers, if installed, are removed before the
        # all-gather of the timing results.
        if self.collectiveArgs.reducescatter_allgather_qcomm is not None:
            try:
                logging.warning("Removing installed quantization handlers.")
                from internals import remove_quantization_handlers

                remove_quantization_handlers(self.collectiveArgs)
            except ImportError:
                pass
            finally:
                # NOTE(review): if the import above fails, the handler is
                # presumably still set and this assert fires - confirm that
                # this is the intended outcome.
                assert self.collectiveArgs.reducescatter_allgather_qcomm is None

        self.collectiveArgs.waitObj.append(
            backendFuncs.all_gather(self.collectiveArgs, retFlag=True))
        backendFuncs.complete_accel_ops(self.collectiveArgs)

        if global_rank == 0:
            # xla reports from opTensor; other backends from the gathered
            # tensorList built above.
            if commsParams.backend == "xla":
                self.reportBenchTime(commsParams, allSizes,
                                     self.collectiveArgs.opTensor, results)
            else:
                self.reportBenchTime(commsParams, allSizes,
                                     self.collectiveArgs.tensorList, results)

        # Wait for rank 0 to finish reporting so other ranks do not garble the output.
        self.backendFuncs.barrier(self.collectiveArgs, "benchtime")