Code Example #1
File: commsTraceReplay.py Project: louisfeng/param
    def prepComms(self, curComm, commsParams):
        commOp = paramToCommName(curComm["comms"])
        if commOp in ("wait", "barrier"):
            return ([], [])

        # for all_to_allv, we can shrink the sizes when running at a smaller scale
        # this is for sanity testing or debugging only, since we cannot always run at very large scale
        if self.shrink:
            cur_world_size = self.collectiveArgs.world_size
            real_world_size = cur_world_size

            if "world_size" in curComm.keys():
                real_world_size = curComm["world_size"]
            else:
                # if the trace does not record world size, we may use a2av splits to infer it
                if commOp == "all_to_allv":
                    in_split_len = len(curComm["in_split"])
                    out_split_len = len(curComm["out_split"])
                    if in_split_len > 0:
                        real_world_size = in_split_len
                    elif out_split_len > 0:
                        real_world_size = out_split_len

            newNumElemsIn = (curComm["in_msg_size"] // real_world_size) * cur_world_size
            newNumElemsOut = (
                curComm["out_msg_size"] // real_world_size
            ) * cur_world_size

            if commOp == "all_to_allv":
                curComm["out_split"] = (
                    curComm["out_split"][:cur_world_size]
                    if ("out_split" in curComm.keys())
                    else []
                )
                curComm["in_split"] = (
                    curComm["in_split"][:cur_world_size]
                    if ("in_split" in curComm.keys())
                    else []
                )
                if len(curComm["in_split"]) > 0:
                    newNumElemsIn = sum(curComm["in_split"])
                if len(curComm["out_split"]) > 0:
                    newNumElemsOut = sum(curComm["out_split"])
            elif commOp == "all_gather":
                newNumElemsOut = newNumElemsIn * cur_world_size

            curComm["in_msg_size"] = newNumElemsIn
            curComm["out_msg_size"] = newNumElemsOut

            logger.debug(
                f"shrink message sizes to curInNumElem {curComm['in_msg_size']}, curOutNumElem {curComm['out_msg_size']}"
            )

        commsParams.dtype = self.strToTorchDtype[curComm["dtype"]]
        # allocate and return tensors
        return super().prepComm(curComm, commsParams)
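
Below is a minimal standalone sketch of the shrink rescaling performed above, assuming a single trace entry represented as a plain dict and no backend or class state; the helper name shrink_entry and the sample sizes are hypothetical, and the all_gather special case is omitted for brevity. It shows how an all_to_allv entry recorded on 8 ranks is cut down when replayed on 2 ranks.

def shrink_entry(entry, cur_world_size):
    # real_world_size: the scale the trace was recorded at
    real_world_size = entry.get("world_size", cur_world_size)
    if entry["comms"] == "all_to_allv":
        # if not recorded, infer the original world size from the split lists
        if "world_size" not in entry:
            real_world_size = (
                len(entry.get("in_split", []))
                or len(entry.get("out_split", []))
                or cur_world_size
            )
        # keep only the first cur_world_size split entries
        entry["in_split"] = entry.get("in_split", [])[:cur_world_size]
        entry["out_split"] = entry.get("out_split", [])[:cur_world_size]
        # recompute element counts from the truncated splits when available,
        # otherwise scale the per-rank share up to the current world size
        entry["in_msg_size"] = sum(entry["in_split"]) or (
            entry["in_msg_size"] // real_world_size
        ) * cur_world_size
        entry["out_msg_size"] = sum(entry["out_split"]) or (
            entry["out_msg_size"] // real_world_size
        ) * cur_world_size
    else:
        # generic case: scale the per-rank share up to the current world size
        entry["in_msg_size"] = (entry["in_msg_size"] // real_world_size) * cur_world_size
        entry["out_msg_size"] = (entry["out_msg_size"] // real_world_size) * cur_world_size
    return entry

entry = {
    "comms": "all_to_allv",
    "in_msg_size": 80,
    "out_msg_size": 80,
    "in_split": [10] * 8,
    "out_split": [10] * 8,
}
print(shrink_entry(entry, cur_world_size=2))
# in/out_split become [10, 10] and in/out_msg_size become 20
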
Code Example #2
File: commsTraceReplay.py Project: louisfeng/param
    def setBench(self, comms_world_info, commsParams):
        # init backend and corresponding function pointers
        if commsParams.nw_stack == "pytorch-dist":
            from pytorch_dist_backend import PyTorchDistBackend

            self.backendFuncs = PyTorchDistBackend(comms_world_info, commsParams)
        elif commsParams.nw_stack == "pytorch-xla-tpu":
            from pytorch_tpu_backend import PyTorchTPUBackend

            self.backendFuncs = PyTorchTPUBackend(comms_world_info, commsParams)
        else:
            logger.error("Unsopported NW stack! ")
            comms_utils.gracefulExit()

        self.backendFuncs.initialize_backend(
            comms_world_info.master_ip,
            comms_world_info.master_port,
            backend=commsParams.backend,
        )
        self.backendFuncs.sayHello()

        # set basic collective info
        (
            local_rank,
            global_rank,
            world_size,
            group,
            curDevice,
            curHwDevice,
        ) = comms_utils.get_rank_details(
            self.backendFuncs
        )  # Getting ranks from the backendFuncs object, since we cannot use MPI (e.g., TPU) to launch all the processes

        self.collectiveArgs.group = group
        self.collectiveArgs.device = curDevice
        self.collectiveArgs.world_size = world_size
        self.collectiveArgs.global_rank = global_rank
        self.collectiveArgs.backendFuncs = self.backendFuncs
        # FIXME: 0 is a common case; we need this info from the trace for a more accurate replay
        self.collectiveArgs.srcOrDst = 0
        # FIXME: assuming it's always sum for reduce/allreduce operations
        self.collectiveArgs.op = self.backendFuncs.get_reduce_op("sum")
        # FIXME: always perform blocking comms; may study non-blocking in the future
        self.collectiveArgs.asyncOp = not self.is_blocking
        self.collectiveArgs.ipTensor = None
        self.collectiveArgs.opTensor = None
        self.collectiveArgs.quant_threshold = commsParams.quant_threshold

        # set of collectives to be replayed
        if self.allowList in ("all", "default", "*"):
            self.allowList = self.backendFuncs.collectiveFunc.keys()
        else:
            self.allowList = [paramToCommName(op) for op in self.allowList.split(",")]
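
The last step above decides which collectives will actually be replayed. Here is a standalone sketch of that allow-list handling, using a hypothetical normalize() helper as a stand-in for comms_utils.paramToCommName() and a hard-coded trace; the real code takes the supported set from the backend's collectiveFunc table.

def normalize(name):
    # stand-in for comms_utils.paramToCommName(); the real mapping also
    # canonicalizes aliases (see the tests in Code Examples #5 and #6)
    return name.strip().lower()

def build_allow_list(allow_list_str, supported_colls):
    # mirrors the allow-list handling at the end of setBench()
    if allow_list_str in ("all", "default", "*"):
        return set(supported_colls)
    return {normalize(op) for op in allow_list_str.split(",")}

trace = [{"comms": "all_reduce"}, {"comms": "all_to_allv"}, {"comms": "wait"}]
allowed = build_allow_list(
    "all_reduce,wait", {"all_reduce", "all_to_allv", "wait", "barrier"}
)
print([e["comms"] for e in trace if normalize(e["comms"]) in allowed])
# prints ['all_reduce', 'wait']; entries outside the allow list are skipped
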
Code Example #3
File: commsTraceReplay.py Project: louisfeng/param
    def initTraceStat(self):
        maxInMsgsize = 0
        maxOutMsgsize = 0
        self.num_msg = len(self.comms_trace)
        self.max_msg_cnt = self.num_msg if self.max_msg_cnt == 0 else self.max_msg_cnt
        # first pass to collect statistics and gather the required info
        for curComm in self.comms_trace[: self.max_msg_cnt]:
            # record the current comm
            collName = paramToCommName(curComm["comms"])
            curBlocks = curComm["marker_stack"] if "marker_stack" in curComm else []
            if collName not in self.collLat.keys():
                self.collLat[collName] = []
                # some ops don't have sizes
                if "in_msg_size" in curComm:
                    self.collInMsgSizes[collName] = []
                    self.collInUniMsgSizes[collName] = set()
                    self.collOutMsgSizes[collName] = []
                    self.collOutUniMsgSizes[collName] = set()
            if "in_msg_size" in curComm:
                self.collInMsgSizes[collName].append(curComm["in_msg_size"])
                self.collInUniMsgSizes[collName].add(curComm["in_msg_size"])
                self.collOutMsgSizes[collName].append(curComm["out_msg_size"])
                self.collOutUniMsgSizes[collName].add(curComm["out_msg_size"])
                maxInMsgsize = max(curComm["in_msg_size"], maxInMsgsize)
                maxOutMsgsize = max(curComm["out_msg_size"], maxOutMsgsize)
            # get info sorted by code block
            for curBlock in curBlocks:
                if curBlock not in self.comms_blocks:
                    self.comms_blocks[curBlock] = []
                # only add entries on a dry run; otherwise, we'll handle them later during replay with more info
                if self.is_dry_run:
                    if collName not in ("wait", "barrier"):
                        self.comms_blocks[curBlock].append(
                            {
                                "comms": collName,
                                "in_msg_size": curComm["in_msg_size"],
                                "out_msg_size": curComm["out_msg_size"],
                            }
                        )
                    else:
                        self.comms_blocks[curBlock].append(
                            {
                                "comms": collName,
                            }
                        )
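
A standalone sketch of this first-pass statistics gathering, run on a tiny hypothetical in-memory trace with no class state and no marker_stack handling, just to show how the per-collective size lists and unique-size sets are built:

from collections import defaultdict

trace = [
    {"comms": "all_reduce", "in_msg_size": 1024, "out_msg_size": 1024},
    {"comms": "all_reduce", "in_msg_size": 2048, "out_msg_size": 2048},
    {"comms": "wait"},  # ops like wait/barrier carry no size fields
]

coll_in_sizes = defaultdict(list)
coll_in_uni_sizes = defaultdict(set)
for entry in trace:
    name = entry["comms"]
    if "in_msg_size" in entry:
        coll_in_sizes[name].append(entry["in_msg_size"])
        coll_in_uni_sizes[name].add(entry["in_msg_size"])

print(dict(coll_in_sizes))      # {'all_reduce': [1024, 2048]}
print(dict(coll_in_uni_sizes))  # {'all_reduce': {1024, 2048}}
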
Code Example #4
File: commsTraceReplay.py Project: louisfeng/param
    def benchTime(self, commsParams):
        """
        The json format is expecting to be either
        {
            "marker_stack": ["## all2all ##"]
            "comms": "all_to_allv",
            "in_msg_size": 10357149,
            "out_msg_size": 23093760,
            "in_split": [],
            "out_split": [],
            "dtype": "Int"
        },
        or w/o in/out_split
        {
            "marker_stack": ["## all2all ##"]
            "comms": "all_reduce",
            "in_msg_size": 1048576,
            "out_msg_size": 1048576,
            "dtype": "Int"
        }
        or wait/barrier
        {
            "marker_stack": ["## all2all ##"]
            "comms": "wait",
        }
        NOTE:
            - this format is subject to be changed anytime
            - the unit of all size fields is # of elements (not bytes)
        """
        # warm-up
        if self.do_warm_up:
            self.warmUpBench(commsParams)

        # sync everything before starting real runs
        self.backendFuncs.sync_barrier(self.collectiveArgs)

        if self.backendFuncs.get_global_rank() == 0:
            print(
                f"\n+ {self.max_msg_cnt} messages in the trace...replaying (if present) {list(self.allowList)}"
            )
            for coll, sizes in self.collInMsgSizes.items():
                logger.info(f"\t{coll}: {len(sizes)}")

        coll_in_batch_num = 0
        for cnt, curComm in enumerate(self.comms_trace[: self.max_msg_cnt]):
            collName = paramToCommName(curComm["comms"])
            if collName not in self.allowList:
                continue

            curBlocks = curComm["marker_stack"] if "marker_stack" in curComm else []
            curBlockStack = (
                " ".join(curBlocks) if len(curBlocks) > 0 else "Unamed/Unknown"
            )

            if self.backendFuncs.get_global_rank() == 0:
                logger.debug(
                    f"[Rank {self.collectiveArgs.global_rank:3}] Replaying \n{str(curComm)}\n"
                )
                print(f"[{cnt} / {self.max_msg_cnt}]", end="\r")

            # read fields and prepare the tensors
            (
                self.collectiveArgs.ipTensor,
                self.collectiveArgs.opTensor,
            ) = self.prepComms(curComm, commsParams)

            if self.colls_per_batch > 0 and coll_in_batch_num == 0:
                batch_begin = time.monotonic()

            (latency, global_latency) = self.runComms(collName, curBlockStack)

            # calculating batch latency (batch defined by --colls-per-batch)
            if collName == "wait" and self.colls_per_batch > 0:
                coll_in_batch_num += 1
                if coll_in_batch_num == self.colls_per_batch:
                    batch_latency = (
                        time.monotonic() - batch_begin
                    ) * 1e3  # convert to milliseconds
                    coll_in_batch_num = 0
                    self.batchLat.append(batch_latency)

            self.collLat[collName].append(latency)

            curComm["seqnum"] = cnt
            curComm["latency_us"] = latency
            curComm["global_latency_us"] = global_latency
            curComm["quant_us"] = self.collectiveArgs.quant_time.getTimeUS()
            curComm["dequant_us"] = self.collectiveArgs.dequant_time.getTimeUS()
            self.totalCommsLatency += latency
            # Keep a copy of trace with performance (latency) and seqnum
            self.traceWithPerf.append(curComm)

            # categorized by the marker
            for curBlock in curBlocks:
                # elem_size = self.collectiveArgs.ipTensor.element_size()
                self.comms_blocks[curBlock].append(curComm)

            if self.backendFuncs.get_global_rank() == 0:
                logger.info(
                    f"[{cnt} / {self.max_msg_cnt}] Replayed {collName} in block [{curBlockStack}]... {global_latency:.2f} us"
                )

        # make sure all ops are completed
        self.backendFuncs.sync_barrier(self.collectiveArgs)
        self.backendFuncs.clear_memory(self.collectiveArgs)
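
A standalone sketch of producing a minimal trace file in the JSON layout described by the docstring above (all size fields are element counts, not bytes); the file name is arbitrary, and how the replayer is pointed at the file depends on its command-line options:

import json

trace = [
    {
        "marker_stack": ["## all2all ##"],
        "comms": "all_to_allv",
        "in_msg_size": 10357149,
        "out_msg_size": 23093760,
        "in_split": [],
        "out_split": [],
        "dtype": "Int",
    },
    {"marker_stack": ["## all2all ##"], "comms": "wait"},
]

with open("example_trace.json", "w") as f:
    json.dump(trace, f, indent=2)
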
Code Example #5
    def test_change(self):
        testName = "all12345to___a3l1l"  # weird way of typing all_to_all
        result = comms_utils.paramToCommName(testName)
        self.assertEqual("all_to_all", result)
Code Example #6
    def test_no_change(self):
        testName = "all_to_all"
        result = comms_utils.paramToCommName(testName)
        self.assertEqual("all_to_all", result)
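
The two tests above are written as unittest.TestCase methods; a minimal standalone module wrapping one of them might look like the following sketch, assuming comms_utils from the param project is importable:

import unittest

import comms_utils

class TestParamToCommName(unittest.TestCase):
    def test_no_change(self):
        # an already-canonical name should pass through unchanged
        self.assertEqual("all_to_all", comms_utils.paramToCommName("all_to_all"))

if __name__ == "__main__":
    unittest.main()
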