def checkArgs(self, args):
    super().checkArgs(args)

    if args.collective not in supportedCollectives:
        print(
            "\t ERROR: Specified collective: %s is not one of the supported collectives: %s. Make sure the input is using the correct case."
            % (args.collective, str(supportedCollectives))
        )
        comms_utils.gracefulExit()

    if args.z not in supportedCommStyle:
        print(
            "\t ERROR: Specified blocking: %d is not one of the supported commstyles: %s"
            % (args.z, str(supportedCommStyle))
        )
        comms_utils.gracefulExit()

    args.b = comms_utils.parsesize(args.b)
    args.e = comms_utils.parsesize(args.e)
    args.dtype = self.dtypeMap[args.data_type]

    if args.b < 1:
        print("\t Starting size: %d should be at least 1! " % (args.b))
        args.b = 1

    if args.e < args.b:
        print(
            "\t ERROR: In COMMS-mode, the begin-size: %d is larger than the end-size: %d "
            % (args.b, args.e)
        )

    if args.device == "cpu" and args.backend == "nccl":
        raise ValueError("NCCL is not supported for device type CPU")
def checkArgs(self, args):
    super().checkArgs(args)

    if args.pt2pt is not None:
        args.collective = "pt2pt"
        if args.pt2pt not in pt2ptPatterns:
            print(
                "\t ERROR: Specified pt2pt pattern: %s is not one of the supported pt2pt patterns: %s"
                % (args.pt2pt, str(pt2ptPatterns))
            )
            comms_utils.gracefulExit()

    args.b = comms_utils.parsesize(args.b)
    args.e = comms_utils.parsesize(args.e)
    args.dtype = self.dtypeMap[args.data_type]

    if args.b < 1:
        print("\t Starting size: %d should be at least 1! " % (args.b))
        args.b = 1

    if args.e < args.b:
        print(
            "\t ERROR: In COMMS-mode, the begin-size: %d is larger than the end-size: %d "
            % (args.b, args.e)
        )

    if args.device == "cpu" and args.backend == "nccl":
        raise ValueError("NCCL is not supported for device type CPU")

    if args.c == 1 and args.z == 0:
        logging.warning(
            "Data validation is not supported in non-blocking mode...disabling validation check and proceeding..."
        )
        args.c = 0
def checkArgs(self, args):
    super().checkArgs(args)

    if (not self.use_remote_trace) and (
        path.exists(self.trace_file) is False
        or path.isfile(self.trace_file) is False
    ):
        raise ValueError(
            f"Trace file {self.trace_file} does not exist or is not a file! Please specify the correct path using --trace-path"
        )
def setBench(self, comms_world_info, commsParams):
    # init backend and corresponding function pointers
    if commsParams.nw_stack == "pytorch-dist":
        from pytorch_dist_backend import PyTorchDistBackend

        self.backendFuncs = PyTorchDistBackend(comms_world_info, commsParams)
    elif commsParams.nw_stack == "pytorch-xla-tpu":
        from pytorch_tpu_backend import PyTorchTPUBackend

        self.backendFuncs = PyTorchTPUBackend(comms_world_info, commsParams)
    else:
        logger.error("Unsupported NW stack! ")
        comms_utils.gracefulExit()

    self.backendFuncs.initialize_backend(
        comms_world_info.master_ip,
        comms_world_info.master_port,
        backend=commsParams.backend,
    )
    self.backendFuncs.sayHello()

    # set basic collective info
    (
        local_rank,
        global_rank,
        world_size,
        group,
        curDevice,
        curHwDevice,
    ) = comms_utils.get_rank_details(
        self.backendFuncs
    )  # Getting ranks from the backendFuncs object, since we cannot use MPI (e.g., TPU) to launch all the processes

    self.collectiveArgs.group = group
    self.collectiveArgs.device = curDevice
    self.collectiveArgs.world_size = world_size
    self.collectiveArgs.global_rank = global_rank
    self.collectiveArgs.backendFuncs = self.backendFuncs
    # FIXME: 0 is a common case, need this info from trace for more accurate replay
    self.collectiveArgs.srcOrDst = 0
    # FIXME: assuming it's always sum for reduce/allreduce operations
    self.collectiveArgs.op = self.backendFuncs.get_reduce_op("sum")
    # FIXME: always perform blocking comms; may study non-blocking in the future
    self.collectiveArgs.asyncOp = not self.is_blocking
    self.collectiveArgs.ipTensor = None
    self.collectiveArgs.opTensor = None
    self.collectiveArgs.quant_threshold = commsParams.quant_threshold

    # set of collectives to be replayed
    if self.allowList in ("all", "default", "*"):
        self.allowList = self.backendFuncs.collectiveFunc.keys()
    else:
        self.allowList = [paramToCommName(op) for op in self.allowList.split(",")]
def checkArgs(self, args):
    super().checkArgs(args)

    if args.pt2pt is not None:
        args.collective = "pt2pt"
        if args.pt2pt not in pt2ptPatterns:
            logger.error(
                f"Specified pt2pt pattern: {args.pt2pt} is not one of the supported pt2pt patterns: {str(pt2ptPatterns)}"
            )
            comms_utils.gracefulExit()

    args.b = comms_utils.parsesize(args.b)
    args.e = comms_utils.parsesize(args.e)
    args.dtype = self.dtypeMap[args.data_type]

    if args.b < 1:
        logger.warning(
            f"Starting size (--b {args.b}) should be at least 1 byte...fix and continue"
        )
        args.b = 1

    if args.e < args.b:
        logger.warning(
            f"The begin-size (--b {args.b}) is larger than the end-size (--e {args.e})"
        )

    if args.device == "cpu" and args.backend == "nccl":
        raise ValueError(f"NCCL is not supported for device type {args.device}")

    if args.c == 1 and args.z == 0 and args.collective in (
        "all_reduce",
        "reduce",
        "reduce_scatter",
    ):
        logger.warning(
            f"Data validation is not supported for {args.collective} in non-blocking mode, disabling it and continuing"
        )
        args.c = 0

    # run a few sanity checks
    if args.bitwidth < 32:
        if args.device != "cuda":
            logger.error(
                f"collective quantization may not be fully supported for {args.device}"
            )
        comms_utils.checkQuantArgs(
            args.collective,
            args.dtype,
            args.b,
            args.quant_a2a_embedding_dim,
            args.z,
        )
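# --- Illustrative helper (not part of the benchmark) ---------------------------
# The checkArgs variants above feed --b/--e through comms_utils.parsesize before
# validating the sweep range. The sketch below shows one plausible behavior for
# such a size parser (plain byte counts plus K/M/G suffixes); the real
# comms_utils.parsesize may accept different suffixes or semantics.
def parsesize_sketch(size_str: str) -> int:
    units = {"K": 1024, "M": 1024 ** 2, "G": 1024 ** 3}
    size_str = size_str.strip().upper()
    if size_str and size_str[-1] in units:
        return int(size_str[:-1]) * units[size_str[-1]]
    return int(size_str)


assert parsesize_sketch("8") == 8
assert parsesize_sketch("32K") == 32 * 1024
assert parsesize_sketch("64M") == 64 * 1024 ** 2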
def runBench(self, comms_world_info, commsParams):
    # Init the desired backend
    if commsParams.nw_stack == "pytorch-nccl":
        from pytorch_dist_backend import PyTorchDistBackend

        backendObj = PyTorchDistBackend(comms_world_info, commsParams)
    elif commsParams.nw_stack == "pytorch-xla-tpu":
        from pytorch_tpu_backend import PyTorchTPUBackend

        backendObj = PyTorchTPUBackend(comms_world_info, commsParams)
    else:
        print("\t Error: Unsupported NW stack! ")
        comms_utils.gracefulExit()

    self.backendFuncs = backendObj
    backendObj.benchmark_comms()
    return
def runComms(comms_world_info, commsParams):
    # Run sanity checks.
    if commsParams.endSize < commsParams.beginSize:
        print(
            "\t ERROR: In COMMS-mode, the begin-size: %d is larger than the end-size: %d "
            % (commsParams.beginSize, commsParams.endSize)
        )

    # Run-loop
    if commsParams.nw_stack == "pytorch-nccl":
        from pytorch_nccl_backend import PyTorchNCCLBackend

        backendObj = PyTorchNCCLBackend(comms_world_info, commsParams)
    elif commsParams.nw_stack == "pytorch-xla-tpu":
        from tpu_backend import PyTorchTPUBackend

        backendObj = PyTorchTPUBackend(comms_world_info, commsParams)
    else:
        print("\t Error: Unsupported NW stack! ")
        comms_utils.gracefulExit()

    backendObj.benchmark_comms()
    return
def runBench(self, comms_world_info, commsParams):
    # Init the desired backend
    if commsParams.nw_stack == "pytorch-dist":
        from pytorch_dist_backend import PyTorchDistBackend

        backendObj = PyTorchDistBackend(comms_world_info, commsParams)
    elif commsParams.nw_stack == "pytorch-xla-tpu":
        from pytorch_tpu_backend import PyTorchTPUBackend

        backendObj = PyTorchTPUBackend(comms_world_info, commsParams)
    else:
        print("\t Error: Unsupported NW stack! ")
        comms_utils.gracefulExit()

    self.backendFuncs = backendObj

    try:
        backendObj.benchmark_comms()
    except ValueError as ve:
        if commsParams.backend == "ucc":
            logging.critical("PyTorch UCC not implemented? {}".format(repr(ve)))
        raise
def checkPt2PtRanks(self):
    # set default values
    if not self.collectiveArgs.src_ranks:
        self.collectiveArgs.src_ranks = [0]
    if not self.collectiveArgs.dst_ranks:
        self.collectiveArgs.dst_ranks = [1]

    # sanity check
    if self.collectiveArgs.pt2pt == "one2one":
        if (
            len(self.collectiveArgs.src_ranks) > 1
            or len(self.collectiveArgs.dst_ranks) > 1
        ):
            if self.global_rank == 0:
                logger.error(
                    "One2one Pt2Pt requires that only a single rank is specified in src_ranks and dst_ranks! "
                )
            comms_utils.gracefulExit()
    elif self.collectiveArgs.pt2pt == "pairwise":
        # pairwise pt2pt requires an identical number of ranks in src_ranks and dst_ranks.
        if len(self.collectiveArgs.src_ranks) != len(self.collectiveArgs.dst_ranks):
            if self.global_rank == 0:
                logger.error(
                    "Pairwise Pt2Pt requires an identical number of members in src_ranks and dst_ranks! "
                )
            comms_utils.gracefulExit()
        # pairwise pt2pt does not allow the same rank to appear in both groups
        if bool(
            set(self.collectiveArgs.src_ranks).intersection(
                self.collectiveArgs.dst_ranks
            )
        ):
            if self.global_rank == 0:
                logger.error(
                    "Pairwise Pt2Pt requires distinct members in src_ranks and dst_ranks! "
                )
            comms_utils.gracefulExit()

    if self.global_rank == 0:
        print(
            f"\t collective={self.collectiveArgs.collective}\t{self.collectiveArgs.pt2pt}, src_ranks={self.collectiveArgs.src_ranks}, dst_ranks={self.collectiveArgs.dst_ranks}"
        )
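# --- Illustrative check (not part of the benchmark) ----------------------------
# A standalone restatement of the pairwise pt2pt constraints enforced by
# checkPt2PtRanks above: src_ranks and dst_ranks must be the same size and must
# not share any member. The rank lists below are arbitrary example values.
src_ranks, dst_ranks = [0, 1], [2, 3]
assert len(src_ranks) == len(dst_ranks), "pairwise needs equal-sized groups"
assert not set(src_ranks) & set(dst_ranks), "pairwise needs disjoint groups"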
def initializeCollectiveArgs(commsParams, backendFuncs):
    # lint was complaining that benchTime was too complex!
    (
        local_rank,
        global_rank,
        world_size,
        group,
        curDevice,
    ) = comms_utils.get_rank_details(
        backendFuncs
    )  # Getting ranks from the backendFuncs object, since we cannot use MPI (e.g., TPU) to launch all the processes.

    comms_utils.fixBeginSize(
        commsParams, world_size
    )  # Ensuring that all-reduce and all-to-all have at least one member per rank.
    backendFuncs.sayHello()  # Informs us where each process is running.
    allSizes = comms_utils.getSizes(
        commsParams.beginSize, commsParams.endSize, commsParams.stepFactor
    )  # Given the begin-size, end-size, and step-factor, what are the message sizes to iterate on.

    if global_rank == 0:
        print(
            "\t global_rank: %d allSizes: %s local_rank: %d element_size: %d "
            % (global_rank, allSizes, local_rank, commsParams.element_size)
        )
        print("\t global_rank: %d commsParams: %s " % (global_rank, commsParams))

    collectiveArgs = comms_utils.collectiveArgsHolder()
    collectiveArgs.group = group
    collectiveArgs.device = curDevice
    collectiveArgs.world_size = world_size
    collectiveArgs.numIters = commsParams.numIters
    collectiveArgs.numWarmupIters = commsParams.numWarmupIters
    collectiveArgs.global_rank = global_rank
    collectiveArgs.backendFuncs = backendFuncs
    collectiveArgs.srcOrDst = ""
    collectiveArgs.collective = commsParams.collective
    op = backendFuncs.get_reduce_op("sum")
    collectiveArgs.op = op
    collectiveArgs.dst = commsParams.dst

    computeFunc = None
    if commsParams.mode != "comms":  # Compute-mode related initialization.
        if commsParams.kernel == "gemm":
            computeFunc = backendFuncs.gemm

            mm_dim = commsParams.mm_dim
            in1 = np.random.rand(mm_dim, mm_dim)
            MMin1 = torch.FloatTensor(in1).to(curDevice)
            in2 = np.random.rand(mm_dim, mm_dim)
            MMin2 = torch.FloatTensor(in2).to(curDevice)
            in3 = np.random.rand(mm_dim, mm_dim)
            MMin3 = torch.FloatTensor(in3).to(curDevice)
            MMout = backendFuncs.alloc_empty(
                [mm_dim, mm_dim], commsParams.dtype, curDevice
            )
            collectiveArgs.MMout = MMout
            collectiveArgs.MMin1 = MMin1
            collectiveArgs.MMin2 = MMin2
            collectiveArgs.MMin3 = MMin3
            collectiveArgs.numComputePerColl = commsParams.num_compute
        else:
            print("Compute kernel " + commsParams.kernel + " not supported...Abort!")
            comms_utils.gracefulExit()

    return (
        collectiveArgs,
        local_rank,
        global_rank,
        world_size,
        group,
        curDevice,
        allSizes,
        computeFunc,
    )
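# --- Illustrative helper (not part of the benchmark) ---------------------------
# initializeCollectiveArgs relies on comms_utils.getSizes to build the message-size
# sweep. The sketch below assumes a simple geometric progression from beginSize to
# endSize using stepFactor; the actual comms_utils.getSizes may differ in details
# (for example, how it handles an endSize that is not an exact multiple).
def get_sizes_sketch(begin_size: int, end_size: int, step_factor: int) -> list:
    sizes = []
    cur = begin_size
    while cur <= end_size:
        sizes.append(cur)
        cur *= step_factor
    return sizes


# e.g., --b 8 --e 64 --f 2 would sweep message sizes [8, 16, 32, 64]
assert get_sizes_sketch(8, 64, 2) == [8, 16, 32, 64]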
def main():
    ### import packages ###
    import sys
    import argparse

    supportedCommStyle = [0, 1]  # 0: non-blocking, 1: blocking.
    supportedCollectives = [
        "reduce",
        "all_reduce",
        "all_to_all",
        "all_to_allv",
    ]  # , "scatter", "gather", "all_gather", "broadcast"
    supportedNwstacks = ["pytorch-nccl", "pytorch-xla-tpu"]
    supported_tpu_core_values = [1, 8]
    dtypeMap = {
        "float32": torch.float32,
        "int32": torch.int32,
        "float16": torch.half,
        "float64": torch.double,
    }
    supportedDtype = list(dtypeMap.keys())

    ### parse arguments ###
    parser = argparse.ArgumentParser(
        description="PARAM-Comm Benchmark",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # experiment related parameters
    parser.add_argument(
        "--backend",
        type=str,
        default="nccl",
        help="The backend to be used in the PyTorch distributed process group",
    )
    parser.add_argument(
        "--mode", type=str, default="comms", help="benchmark mode"
    )  # alternatives are DLRM mode and comm-compute mode
    parser.add_argument(
        "--b", type=str, default="8", help="minimum size, in bytes, to start with"
    )  # COMMS mode, begin the sweep at.
    parser.add_argument(
        "--e", type=str, default="64", help="maximum size, in bytes, to end at"
    )  # COMMS mode, end the sweep at.
    parser.add_argument(
        "--f", type=int, default=2, help="multiplication factor between sizes"
    )  # COMMS mode, multiplication factor.
    parser.add_argument(
        "--z", type=int, default=1, help="use blocking mode for collectives"
    )  # 'sync/blocking': 1, 'async/non-blocking': 0
    parser.add_argument(
        "--w", type=int, default=5, help="number of warmup iterations"
    )
    parser.add_argument(
        "--n", type=int, default=5, help="number of iterations"
    )
    parser.add_argument(
        "--collective",
        type=str,
        default="all_reduce",
        help="collective to benchmark, supports " + str(supportedCollectives),
    )  # collective op to benchmark
    parser.add_argument(
        "--master-ip",
        type=str,
        default="127.0.0.1",
        help="The master-IP to coordinate",
    )
    parser.add_argument(
        "--master-port",
        type=str,
        default="29500",
        help="The master-port to coordinate",
    )
    parser.add_argument(
        "--nw-stack",
        type=str,
        default="pytorch-nccl",
        help="network stack to be used, supports " + str(supportedNwstacks),
    )  # The network stack to profile.
    parser.add_argument(
        "--dtype", type=torch.dtype, default=torch.float32
    )  # will be overwritten based on args.data_type and dtypeMap.
    parser.add_argument(
        "--data-type",
        type=str,
        default="float32",
        help="the base data type, supports " + str(supportedDtype),
    )
    parser.add_argument(
        "--num-tpu-cores", type=int, default=1, help="number of TPU cores to be used"
    )

    # For comm-compute or compute mode
    parser.add_argument(
        "--kernel", type=str, default="gemm", help="compute kernel"
    )  # Compute kernel: "gemm"
    parser.add_argument(
        "--num-compute",
        type=int,
        default=100,
        help="one collective for every NUM_COMPUTE compute kernels",
    )  # Launch one collective for every n compute kernels

    # For GEMM
    parser.add_argument(
        "--mm-dim",
        type=int,
        default=100,
        help="dimension size for the GEMM compute kernel",
    )  # Matrix multiplication dim n, A[n,n] * B[n,n]

    # For embedding lookup
    parser.add_argument(
        "--emb-dim",
        type=int,
        default=128,
        help="dimension size for the embedding table compute kernel",
    )
    parser.add_argument(
        "--num-embs",
        type=int,
        default=100000,
        help="embedding table hash size for the embedding table compute kernel",
    )
    parser.add_argument(
        "--avg-len",
        type=int,
        default=28,
        help="average lookup operations per sample",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=512,
        help="number of samples reading the table concurrently",
    )
    parser.add_argument(
        "--root", type=int, default=0, help="root process for reduce benchmark"
    )  # root process for reduce (and gather, scatter, bcast, etc., if supported in the future)
    # TODO: check the correctness of root; it should be between 0 and world_size - 1

    args, leftovers = parser.parse_known_args()

    args.b = comms_utils.parsesize(args.b)
    args.e = comms_utils.parsesize(args.e)

    if args.nw_stack not in supportedNwstacks:
        print(
            "\t ERROR: Specified nw-stack: %s is not one of the supported stacks: %s. Make sure the input is using the correct case."
            % (args.nw_stack, str(supportedNwstacks))
        )
        sys.exit()  # WARNING: Assuming sys is always used; should find a platform-independent way to gracefully exit.

    if args.collective not in supportedCollectives:
        print(
            "\t ERROR: Specified collective: %s is not one of the supported collectives: %s. Make sure the input is using the correct case."
            % (args.collective, str(supportedCollectives))
        )
        sys.exit()  # WARNING: Assuming sys is always used; should find a platform-independent way to gracefully exit.

    if args.z not in supportedCommStyle:
        print(
            "\t ERROR: Specified blocking: %d is not one of the supported commstyles: %s"
            % (args.z, str(supportedCommStyle))
        )
        comms_utils.gracefulExit()

    if args.data_type not in supportedDtype:
        print(
            "\t ERROR: Specified dtype: %s is not one of the supported dtypes: %s"
            % (args.data_type, str(supportedDtype))
        )
        comms_utils.gracefulExit()
    args.dtype = dtypeMap[args.data_type]

    mpi_env_params = comms_utils.read_mpi_env_vars()
    if mpi_env_params["global_rank"] == 0:
        print("\t MPI environment: %s " % (str(mpi_env_params)))
        print(
            "\t backend: %s nw-stack: %s mode: %s args.b: %d args.e: %d args.f: %d args.z: %s args.master_ip: %s "
            % (
                args.backend,
                args.nw_stack,
                args.mode,
                args.b,
                args.e,
                args.f,
                args.z,
                args.master_ip,
            )
        )

    if args.num_tpu_cores not in supported_tpu_core_values:
        print(
            "\t ERROR: TPU core value: %d is not one of the supported values: %s "
            % (args.num_tpu_cores, supported_tpu_core_values)
        )
        comms_utils.gracefulExit()

    if args.b < 1:
        print("\t Starting size: %d should be at least 1! " % (args.b))
        args.b = 1

    element_size = torch.ones([1], dtype=args.dtype).element_size()
    comms_world_info = comms_utils.comms_world_info_holder(
        args.master_ip, args.master_port, args.num_tpu_cores, mpi_env_params
    )
    commsParams = comms_utils.commsParamsHolder(args, element_size, benchTime)
    runComms(comms_world_info, commsParams)
def benchTime(self, index, commsParams, backendFuncs):
    # Get NW stack specific parameters
    (
        local_rank,
        global_rank,
        world_size,
        group,
        curDevice,
        curHwDevice,
        allSizes,
        computeFunc,
    ) = self.initCollectiveArgs(commsParams)

    results = {}
    timeElapsedList = []
    for curSize in allSizes:
        # Allocating memory.
        numElements = int(curSize // commsParams.element_size)
        scaleFactor = numElements * numElements
        if commsParams.collective == "all_to_all":
            # numElements = int(numElements // world_size)  # assuming that world_size won't be zero!
            scaleFactor = 1
        ipTensor = backendFuncs.alloc_random(
            [numElements], curDevice, commsParams.dtype, scaleFactor
        )

        opTensor = ipTensor
        # ignoring all_gather, scatter-gather, for now
        # FUTURE-TODO: make the interface accept scatter and gather lists.
        asyncOp = True
        collectiveFunc = None

        if commsParams.blockingFlag == 1:  # if blockingFlag is 1, asyncOp should be False.
            asyncOp = False

        if commsParams.mode != "compute":  # comms-specific initializations
            if commsParams.collective == "all_reduce":
                collectiveFunc = backendFuncs.all_reduce
            elif commsParams.collective == "all_to_all":
                opTensor = backendFuncs.alloc_empty(
                    [numElements], commsParams.dtype, curDevice
                )
                collectiveFunc = backendFuncs.all_to_all
            elif commsParams.collective == "all_to_allv":
                opTensor = backendFuncs.alloc_empty(
                    [numElements], commsParams.dtype, curDevice
                )
                self.collectiveArgs.ipTensor_split = [
                    int(numElements // world_size) for i in range(world_size)
                ]
                self.collectiveArgs.opTensor_split = [
                    int(numElements // world_size) for i in range(world_size)
                ]
                collectiveFunc = backendFuncs.all_to_allv
            elif commsParams.collective == "reduce":
                collectiveFunc = backendFuncs.reduce
            else:
                print("\t ERROR: Unsupported collective: %s" % (commsParams.collective))
                comms_utils.gracefulExit()

        # Setup the arguments.
        self.collectiveArgs.ipTensor = ipTensor
        self.collectiveArgs.opTensor = opTensor
        self.collectiveArgs.asyncOp = asyncOp
        self.collectiveArgs.dataSize = curSize
        self.collectiveArgs.numElements = numElements
        self.collectiveArgs.waitObj = []

        # self.collectiveArgs has all the information on the experiment.
        timeElapsedNS, algBW, busBW, memSize, x = self.runColl(
            comm_fn=collectiveFunc, compute_fn=computeFunc
        )

        results[curSize] = {}
        results[curSize]["timeUS"] = timeElapsedNS / 1e3
        timeElapsedList.append(
            results[curSize]["timeUS"]
        )  # assuming the order is known at each rank, so it's OK not to identify it by message-size
        results[curSize]["algBW"] = algBW
        results[curSize]["busBW"] = busBW
        results[curSize]["memSize"] = memSize
        if commsParams.collective in ("all_to_all", "all_to_allv"):
            results[curSize]["num_elements"] = int(numElements // world_size)
        else:
            results[curSize]["num_elements"] = int(numElements)
        results[curSize]["x"] = x

        del ipTensor
        del opTensor
        backendFuncs.clear_memory()
        self.backendFuncs.barrier(self.collectiveArgs, "curSize")

    # Push the list to device, then do an all-gather.
    timeElapsedTensor = torch.tensor(timeElapsedList, device=curDevice)
    if commsParams.backend != "xla":
        tensorList = [torch.ones_like(timeElapsedTensor) for _ in range(world_size)]
        self.collectiveArgs.tensorList = tensorList

    self.collectiveArgs.ipTensor = timeElapsedTensor
    self.collectiveArgs.asyncOp = False
    self.collectiveArgs.dataSize = (
        timeElapsedTensor.nelement() * timeElapsedTensor.element_size()
    )
    self.collectiveArgs.numElements = timeElapsedTensor.nelement()

    if self.collectiveArgs.reducescatter_allgather_qcomm is not None:
        try:
            logging.warning("Removing installed quantization handlers.")
            from internals import remove_quantization_handlers

            remove_quantization_handlers(self.collectiveArgs)
        except ImportError:
            pass
        finally:
            assert self.collectiveArgs.reducescatter_allgather_qcomm is None

    self.collectiveArgs.waitObj.append(
        backendFuncs.all_gather(self.collectiveArgs, retFlag=True)
    )
    backendFuncs.complete_accel_ops(self.collectiveArgs)

    if global_rank == 0:
        if commsParams.backend == "xla":
            self.reportBenchTime(
                commsParams, allSizes, self.collectiveArgs.opTensor, results
            )
        else:
            self.reportBenchTime(
                commsParams, allSizes, self.collectiveArgs.tensorList, results
            )
    # wait for rank 0 to report results, to avoid other ranks messing up the output
    self.backendFuncs.barrier(self.collectiveArgs, "benchtime")
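# --- Illustrative helper (not part of the benchmark) ---------------------------
# benchTime records algBW and busBW as returned by runColl, which is not shown
# here. The sketch below assumes the NCCL-tests-style conventions (algorithm
# bandwidth = bytes / time; bus bandwidth = algorithm bandwidth scaled by a
# per-collective correction factor); the actual runColl math may differ.
def bandwidth_sketch(mem_size_bytes, time_elapsed_ns, collective, world_size):
    alg_bw = mem_size_bytes / time_elapsed_ns  # bytes per nanosecond == GB/s
    factors = {
        "all_reduce": 2.0 * (world_size - 1) / world_size,
        "all_to_all": (world_size - 1) / world_size,
        "all_to_allv": (world_size - 1) / world_size,
        "reduce": 1.0,
    }
    bus_bw = alg_bw * factors.get(collective, 1.0)
    return alg_bw, bus_bw


# e.g., a 64 MiB all_reduce taking 1 ms across 8 ranks -> ~67.1 GB/s algBW, ~117.4 GB/s busBW
print(bandwidth_sketch(64 * 1024 ** 2, 1e6, "all_reduce", 8))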