def prepComms(self, curComm, commsParams):
    commOp = paramToCommName(curComm["comms"])
    if commOp in ("wait", "barrier"):
        return ([], [])

    # for all_to_allv, we can shrink the size if running on smaller scale
    # this is for sanity test or debug purpose only since we don't always get to run very large scale
    if self.shrink:
        cur_world_size = self.collectiveArgs.world_size
        real_world_size = cur_world_size

        if "world_size" in curComm.keys():
            real_world_size = curComm["world_size"]
        else:
            # if the trace does not record world size, we may use a2av splits to infer it
            if commOp == "all_to_allv":
                in_split_len = len(curComm["in_split"])
                out_split_len = len(curComm["out_split"])
                if in_split_len > 0:
                    real_world_size = in_split_len
                elif out_split_len > 0:
                    real_world_size = out_split_len

        newNumElemsIn = (curComm["in_msg_size"] // real_world_size) * cur_world_size
        newNumElemsOut = (
            curComm["out_msg_size"] // real_world_size
        ) * cur_world_size

        if commOp == "all_to_allv":
            curComm["out_split"] = (
                curComm["out_split"][:cur_world_size]
                if ("out_split" in curComm.keys())
                else []
            )
            curComm["in_split"] = (
                curComm["in_split"][:cur_world_size]
                if ("in_split" in curComm.keys())
                else []
            )
            if len(curComm["in_split"]) > 0:
                newNumElemsIn = sum(curComm["in_split"])
            if len(curComm["out_split"]) > 0:
                newNumElemsOut = sum(curComm["out_split"])
        elif commOp == "all_gather":
            newNumElemsOut = newNumElemsIn * cur_world_size

        curComm["in_msg_size"] = newNumElemsIn
        curComm["out_msg_size"] = newNumElemsOut

        logger.debug(
            f"shrink message sizes to curInNumElem {curComm['in_msg_size']}, curOutNumElem {curComm['out_msg_size']}"
        )

    commsParams.dtype = self.strToTorchDtype[curComm["dtype"]]
    # allocate and return tensors
    return super().prepComm(curComm, commsParams)
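# --- Illustrative sketch (not part of the replay code) ---
# A minimal, standalone example of the shrink arithmetic applied in prepComms()
# above: an all_to_allv entry recorded at world_size 8 is rescaled for a 2-rank
# replay. The sample entry and its sizes are hypothetical; only the field names
# follow the trace format used by prepComms().
entry = {
    "comms": "all_to_allv",
    "in_msg_size": 800,        # elements, recorded at world_size 8
    "out_msg_size": 1600,
    "in_split": [100] * 8,
    "out_split": [200] * 8,
    "world_size": 8,
}
cur_world_size = 2
# keep only the first cur_world_size splits and recompute element counts from them
entry["in_split"] = entry["in_split"][:cur_world_size]
entry["out_split"] = entry["out_split"][:cur_world_size]
entry["in_msg_size"] = sum(entry["in_split"])    # 200, matches (800 // 8) * 2
entry["out_msg_size"] = sum(entry["out_split"])  # 400, matches (1600 // 8) * 2
print(entry["in_msg_size"], entry["out_msg_size"])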
def setBench(self, comms_world_info, commsParams):
    # init backend and corresponding function pointers
    if commsParams.nw_stack == "pytorch-dist":
        from pytorch_dist_backend import PyTorchDistBackend

        self.backendFuncs = PyTorchDistBackend(comms_world_info, commsParams)
    elif commsParams.nw_stack == "pytorch-xla-tpu":
        from pytorch_tpu_backend import PyTorchTPUBackend

        self.backendFuncs = PyTorchTPUBackend(comms_world_info, commsParams)
    else:
        logger.error("Unsupported NW stack!")
        comms_utils.gracefulExit()

    self.backendFuncs.initialize_backend(
        comms_world_info.master_ip,
        comms_world_info.master_port,
        backend=commsParams.backend,
    )
    self.backendFuncs.sayHello()

    # set basic collective info
    (
        local_rank,
        global_rank,
        world_size,
        group,
        curDevice,
        curHwDevice,
    ) = comms_utils.get_rank_details(
        self.backendFuncs
    )  # Getting ranks from the backendFuncs object, since we cannot use MPI (e.g., TPU) to launch all the processes

    self.collectiveArgs.group = group
    self.collectiveArgs.device = curDevice
    self.collectiveArgs.world_size = world_size
    self.collectiveArgs.global_rank = global_rank
    self.collectiveArgs.backendFuncs = self.backendFuncs
    # FIXME: 0 is a common case, need this info from trace for more accurate replay
    self.collectiveArgs.srcOrDst = 0
    # FIXME: assuming it's always sum for reduce/allreduce operations
    self.collectiveArgs.op = self.backendFuncs.get_reduce_op("sum")
    # FIXME: always perform blocking comms; may study non-blocking in the future
    self.collectiveArgs.asyncOp = not self.is_blocking
    self.collectiveArgs.ipTensor = None
    self.collectiveArgs.opTensor = None
    self.collectiveArgs.quant_threshold = commsParams.quant_threshold

    # set of collectives to be replayed
    if self.allowList in ("all", "default", "*"):
        self.allowList = self.backendFuncs.collectiveFunc.keys()
    else:
        self.allowList = [paramToCommName(op) for op in self.allowList.split(",")]
def initTraceStat(self):
    maxInMsgsize = 0
    maxOutMsgsize = 0
    self.num_msg = len(self.comms_trace)
    self.max_msg_cnt = self.num_msg if self.max_msg_cnt == 0 else self.max_msg_cnt
    # first pass to gather statistics and required info
    for curComm in self.comms_trace[: self.max_msg_cnt]:
        # record the current comm
        collName = paramToCommName(curComm["comms"])
        curBlocks = curComm["marker_stack"] if "marker_stack" in curComm else []
        if collName not in self.collLat.keys():
            self.collLat[collName] = []
            # some ops don't have sizes
            if "in_msg_size" in curComm:
                self.collInMsgSizes[collName] = []
                self.collInUniMsgSizes[collName] = set()
                self.collOutMsgSizes[collName] = []
                self.collOutUniMsgSizes[collName] = set()
        if "in_msg_size" in curComm:
            self.collInMsgSizes[collName].append(curComm["in_msg_size"])
            self.collInUniMsgSizes[collName].add(curComm["in_msg_size"])
            self.collOutMsgSizes[collName].append(curComm["out_msg_size"])
            self.collOutUniMsgSizes[collName].add(curComm["out_msg_size"])
            maxInMsgsize = max(curComm["in_msg_size"], maxInMsgsize)
            maxOutMsgsize = max(curComm["out_msg_size"], maxOutMsgsize)
        # get info sorted by code block
        for curBlock in curBlocks:
            if curBlock not in self.comms_blocks:
                self.comms_blocks[curBlock] = []
            # only add entries during a dry run; otherwise, we'll deal with it later during replay w/ more info
            if self.is_dry_run:
                if collName not in ("wait", "barrier"):
                    self.comms_blocks[curBlock].append(
                        {
                            "comms": collName,
                            "in_msg_size": curComm["in_msg_size"],
                            "out_msg_size": curComm["out_msg_size"],
                        }
                    )
                else:
                    self.comms_blocks[curBlock].append(
                        {
                            "comms": collName,
                        }
                    )
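# --- Illustrative sketch (not part of the replay code) ---
# What the per-collective containers populated by initTraceStat() above would
# hold after scanning a hypothetical two-entry trace. The sizes are made up;
# the variable names mirror the fields used in initTraceStat().
example_trace = [
    {"comms": "all_reduce", "in_msg_size": 1024, "out_msg_size": 1024, "dtype": "Float"},
    {"comms": "all_reduce", "in_msg_size": 1024, "out_msg_size": 1024, "dtype": "Float"},
]
collInMsgSizes = {"all_reduce": [1024, 1024]}  # one entry per message in the trace
collInUniMsgSizes = {"all_reduce": {1024}}     # unique sizes only
collLat = {"all_reduce": []}                   # latencies are filled in during replay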
def benchTime(self, commsParams):
    """
    The expected JSON format is either
    {
        "marker_stack": ["## all2all ##"],
        "comms": "all_to_allv",
        "in_msg_size": 10357149,
        "out_msg_size": 23093760,
        "in_split": [],
        "out_split": [],
        "dtype": "Int"
    },
    or, without in/out_split,
    {
        "marker_stack": ["## all2all ##"],
        "comms": "all_reduce",
        "in_msg_size": 1048576,
        "out_msg_size": 1048576,
        "dtype": "Int"
    }
    or, for wait/barrier,
    {
        "marker_stack": ["## all2all ##"],
        "comms": "wait"
    }
    NOTE:
    - this format is subject to change at any time
    - the unit of all size fields is # of elements (not bytes)
    """
    # warm-up
    if self.do_warm_up:
        self.warmUpBench(commsParams)

    # sync everything before starting real runs
    self.backendFuncs.sync_barrier(self.collectiveArgs)

    if self.backendFuncs.get_global_rank() == 0:
        print(
            f"\n+ {self.max_msg_cnt} messages in the trace...replaying (if present) {list(self.allowList)}"
        )
        for coll, sizes in self.collInMsgSizes.items():
            logger.info(f"\t{coll}: {len(sizes)}")

    coll_in_batch_num = 0
    for cnt, curComm in enumerate(self.comms_trace[: self.max_msg_cnt]):
        collName = paramToCommName(curComm["comms"])
        if collName not in self.allowList:
            continue

        curBlocks = curComm["marker_stack"] if "marker_stack" in curComm else []
        curBlockStack = (
            " ".join(curBlocks) if len(curBlocks) > 0 else "Unnamed/Unknown"
        )

        if self.backendFuncs.get_global_rank() == 0:
            logger.debug(
                f"[Rank {self.collectiveArgs.global_rank:3}] Replaying \n{str(curComm)}\n"
            )
            print(f"[{cnt} / {self.max_msg_cnt}]", end="\r")

        # read fields and prepare the tensors
        (
            self.collectiveArgs.ipTensor,
            self.collectiveArgs.opTensor,
        ) = self.prepComms(curComm, commsParams)

        if self.colls_per_batch > 0 and coll_in_batch_num == 0:
            batch_begin = time.monotonic()

        (latency, global_latency) = self.runComms(collName, curBlockStack)

        # calculating batch latency (batch defined by --colls-per-batch)
        if collName == "wait" and self.colls_per_batch > 0:
            coll_in_batch_num += 1
            if coll_in_batch_num == self.colls_per_batch:
                batch_latency = (
                    time.monotonic() - batch_begin
                ) * 1e3  # make it milliseconds
                coll_in_batch_num = 0
                self.batchLat.append(batch_latency)

        self.collLat[collName].append(latency)

        curComm["seqnum"] = cnt
        curComm["latency_us"] = latency
        curComm["global_latency_us"] = global_latency
        curComm["quant_us"] = self.collectiveArgs.quant_time.getTimeUS()
        curComm["dequant_us"] = self.collectiveArgs.dequant_time.getTimeUS()
        self.totalCommsLatency += latency
        # Keep a copy of trace with performance (latency) and seqnum
        self.traceWithPerf.append(curComm)

        # categorized by the marker
        for curBlock in curBlocks:
            # elem_size = self.collectiveArgs.ipTensor.element_size()
            self.comms_blocks[curBlock].append(curComm)

        if self.backendFuncs.get_global_rank() == 0:
            logger.info(
                f"[{cnt} / {self.max_msg_cnt}] Replayed {collName} in block [{curBlockStack}]... {global_latency:.2f} us"
            )

    # make sure all ops are completed
    self.backendFuncs.sync_barrier(self.collectiveArgs)
    self.backendFuncs.clear_memory(self.collectiveArgs)
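# --- Illustrative sketch (not part of the replay code) ---
# A minimal in-memory trace matching the format documented in benchTime() above,
# plus the kind of allow-list filtering the replay loop applies before replaying.
# paramToCommName is the helper already used throughout this file; the sample
# entries and the allow set below are made up for illustration.
sample_trace = [
    {
        "marker_stack": ["## all2all ##"],
        "comms": "all_to_allv",
        "in_msg_size": 10357149,
        "out_msg_size": 23093760,
        "in_split": [],
        "out_split": [],
        "dtype": "Int",
    },
    {"marker_stack": ["## all2all ##"], "comms": "wait"},
]
allow = {"all_to_allv", "all_reduce"}  # hypothetical allow list
to_replay = [
    e for e in sample_trace if paramToCommName(e["comms"]) in allow
]  # keeps only the all_to_allv entry; the wait entry is skipped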
def test_change(self):
    testName = "all12345to___a3l1l"  # weird way of typing all_to_all
    result = comms_utils.paramToCommName(testName)
    self.assertEqual("all_to_all", result)
def test_no_change(self):
    testName = "all_to_all"
    result = comms_utils.paramToCommName(testName)
    self.assertEqual("all_to_all", result)