Example #1
    def test_mock_backend(self):
        mockBackend = MockBackendFunction()
        # get_rank_details returns (local_rank, global_rank, world_size,
        # group, curDevice, curHwDevice); the mock uses the same device for
        # the last two fields, so device appears twice.
        mockTuple = (
            mockBackend.local_rank,
            mockBackend.global_rank,
            mockBackend.world_size,
            mockBackend.group,
            mockBackend.device,
            mockBackend.device,
        )
        self.assertEqual(comms_utils.get_rank_details(mockBackend), mockTuple)
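For orientation, here is a minimal sketch of what comms_utils.get_rank_details is assumed to do, inferred purely from the assertion above; the actual helper in param may use different accessors:

def get_rank_details(backendFuncs):
    # Hypothetical sketch: repackage the backend's rank/device attributes
    # into a fixed-order six-element tuple.
    return (
        backendFuncs.local_rank,
        backendFuncs.global_rank,
        backendFuncs.world_size,
        backendFuncs.group,
        backendFuncs.device,  # curDevice
        backendFuncs.device,  # curHwDevice; same as curDevice for the mock
    )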
Example #2
    def setBench(self, comms_world_info, commsParams):
        # init backend and corresponding function pointers
        if commsParams.nw_stack == "pytorch-dist":
            from pytorch_dist_backend import PyTorchDistBackend

            self.backendFuncs = PyTorchDistBackend(comms_world_info, commsParams)
        elif commsParams.nw_stack == "pytorch-xla-tpu":
            from pytorch_tpu_backend import PyTorchTPUBackend

            self.backendFuncs = PyTorchTPUBackend(comms_world_info, commsParams)
        else:
            logger.error("Unsupported NW stack!")
            comms_utils.gracefulExit()

        self.backendFuncs.initialize_backend(
            comms_world_info.master_ip,
            comms_world_info.master_port,
            backend=commsParams.backend,
        )
        self.backendFuncs.sayHello()

        # set basic collective info
        (
            local_rank,
            global_rank,
            world_size,
            group,
            curDevice,
            curHwDevice,
        ) = comms_utils.get_rank_details(
            self.backendFuncs
        )  # Get ranks from the backendFuncs object, since we cannot use MPI (e.g., on TPU) to launch all the processes

        self.collectiveArgs.group = group
        self.collectiveArgs.device = curDevice
        self.collectiveArgs.world_size = world_size
        self.collectiveArgs.global_rank = global_rank
        self.collectiveArgs.backendFuncs = self.backendFuncs
        # FIXME:  0 is a common case, need this info from trace for more accurate replay
        self.collectiveArgs.srcOrDst = 0
        # FIXME: assuming it's always sum for reduce/allreduce operations
        self.collectiveArgs.op = self.backendFuncs.get_reduce_op("sum")
        # FIXME: always perform blocking comms; may study non-blocking in the future
        self.collectiveArgs.asyncOp = not self.is_blocking
        self.collectiveArgs.ipTensor = None
        self.collectiveArgs.opTensor = None
        self.collectiveArgs.quant_threshold = commsParams.quant_threshold

        # set of collectives to be replayed
        if self.allowList in ("all", "default", "*"):
            self.allowList = self.backendFuncs.collectiveFunc.keys()
        else:
            self.allowList = [paramToCommName(op) for op in self.allowList.split(",")]
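The else branch above relies on paramToCommName to normalize user-supplied collective names before matching them against the keys of backendFuncs.collectiveFunc. A hedged sketch of such a normalizer, with an illustrative alias table (the actual mapping in param's comms_utils may differ):

def paramToCommName(name: str) -> str:
    # Hypothetical sketch: lower-case, strip whitespace, and map common
    # aliases onto the canonical collective names used as collectiveFunc keys.
    aliases = {
        "allreduce": "all_reduce",
        "alltoall": "all_to_all",
        "allgather": "all_gather",
        "reducescatter": "reduce_scatter",
    }
    name = name.strip().lower()
    return aliases.get(name, name)

# e.g. an allow-list of "AllReduce,broadcast" would be replayed as
# ["all_reduce", "broadcast"].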
Example #3
    def initCollectiveArgs(self, commsParams):
        # lint was complaining that benchTime was too complex!
        (
            local_rank,
            global_rank,
            world_size,
            group,
            curDevice,
        ) = comms_utils.get_rank_details(self.backendFuncs)  # Get ranks from the backendFuncs object, since we cannot use MPI (e.g., on TPU) to launch all the processes.

        comms_utils.fixBeginSize(commsParams, world_size)  # Ensure that all-reduce and all-to-all have at least one member per rank.
        self.backendFuncs.sayHello()  # Informs us where each process is running.
        allSizes = comms_utils.getSizes(
            commsParams.beginSize, commsParams.endSize, commsParams.stepFactor
        )  # Given the begin size, end size, and step factor, compute the message sizes to iterate over.

        if global_rank == 0:
            print(
                "\t global_rank: %d allSizes: %s local_rank: %d element_size: %d "
                % (global_rank, allSizes, local_rank, commsParams.element_size)
            )
            print("\t global_rank: %d commsParams: %s " % (global_rank, commsParams))

        # self.collectiveArgs = comms_utils.collectiveArgsHolder()
        self.collectiveArgs.group = group
        self.collectiveArgs.device = curDevice
        self.collectiveArgs.world_size = world_size
        self.collectiveArgs.numIters = commsParams.numIters
        self.collectiveArgs.numWarmupIters = commsParams.numWarmupIters
        self.collectiveArgs.global_rank = global_rank
        self.collectiveArgs.backendFuncs = self.backendFuncs
        self.collectiveArgs.srcOrDst = ""
        self.collectiveArgs.collective = commsParams.collective
        op = self.backendFuncs.get_reduce_op("sum")
        self.collectiveArgs.op = op
        self.collectiveArgs.dst = commsParams.dst

        if commsParams.bitwidth < 32:
            logging.warning(f"communication bitwidth set to {commsParams.bitwidth}")
            try:
                from internals import initialize_collectiveArgs_internal

                initialize_collectiveArgs_internal(self.collectiveArgs, commsParams)
            except ImportError:
                # Quantized comms are currently only supported for reduce and all_reduce.
                if commsParams.collective not in ("reduce", "all_reduce"):
                    raise NotImplementedError(
                        "quantized communication for %s is currently unsupported."
                        % commsParams.collective
                    )

        computeFunc = None
        if commsParams.mode != "comms":  # Compute mode related initialization.
            if commsParams.kernel == "gemm":
                computeFunc = self.backendFuncs.gemm

                mm_dim = commsParams.mm_dim
                in1 = np.random.rand(mm_dim, mm_dim)
                MMin1 = torch.FloatTensor(in1).to(curDevice)
                in2 = np.random.rand(mm_dim, mm_dim)
                MMin2 = torch.FloatTensor(in2).to(curDevice)
                in3 = np.random.rand(mm_dim, mm_dim)
                MMin3 = torch.FloatTensor(in3).to(curDevice)
                MMout = self.backendFuncs.alloc_empty(
                    [mm_dim, mm_dim], commsParams.dtype, curDevice
                )
                self.collectiveArgs.MMout = MMout
                self.collectiveArgs.MMin1 = MMin1
                self.collectiveArgs.MMin2 = MMin2
                self.collectiveArgs.MMin3 = MMin3
                self.collectiveArgs.numComputePerColl = commsParams.num_compute
            elif commsParams.kernel == "emb_lookup":
                computeFunc = self.backendFuncs.emb_lookup

                emb_dim = commsParams.emb_dim
                num_embeddings = commsParams.num_embs
                avg_length = commsParams.avg_len
                batch_size = commsParams.batch_size
                print(
                    f"emb_dim {emb_dim} num_embs {num_embeddings} avg_len {avg_length} bs {batch_size}"
                )
                self.collectiveArgs.EmbWeights = self.backendFuncs.alloc_empty(
                    [num_embeddings, emb_dim], torch.double, curDevice
                )
                self.collectiveArgs.TableOffsets = torch.LongTensor([0, num_embeddings]).to(
                    curDevice
                )
                self.collectiveArgs.Indices = torch.LongTensor(
                    np.random.randint(0, num_embeddings - 1, avg_length * batch_size)
                ).to(curDevice)
                lengths = np.ones((1, batch_size)) * avg_length
                flat_lengths = lengths.flatten()
                self.collectiveArgs.Offsets = torch.LongTensor(
                    [0] + np.cumsum(flat_lengths).tolist()
                ).to(curDevice)
                self.collectiveArgs.LookupOut = self.backendFuncs.alloc_empty(
                    [batch_size, emb_dim], torch.double, curDevice
                )
                self.collectiveArgs.AvgLengths = avg_length
                self.collectiveArgs.numComputePerColl = commsParams.num_compute

        return (
            local_rank,
            global_rank,
            world_size,
            group,
            curDevice,
            allSizes,
            computeFunc,
        )
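For reference, comms_utils.getSizes is assumed to produce the geometric sweep of message sizes described in the comments above; a minimal sketch under that assumption:

def getSizes(beginSize, endSize, stepFactor):
    # Hypothetical sketch of comms_utils.getSizes: grow geometrically by
    # stepFactor from beginSize up to (and including) endSize. The real
    # helper may handle an endSize that the sweep does not land on differently.
    sizes = []
    curSize = beginSize
    while curSize <= endSize:
        sizes.append(curSize)
        curSize *= stepFactor
    return sizes

# getSizes(64, 1024, 2) -> [64, 128, 256, 512, 1024]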
Example #4
def initializeCollectiveArgs(commsParams, backendFuncs):
    # lint was complaining that benchTime was too complex!
    (
        local_rank,
        global_rank,
        world_size,
        group,
        curDevice,
    ) = comms_utils.get_rank_details(
        backendFuncs
    )  # Get ranks from the backendFuncs object, since we cannot use MPI (e.g., on TPU) to launch all the processes.

    comms_utils.fixBeginSize(
        commsParams, world_size
    )  # Ensure that all-reduce and all-to-all have at least one member per rank.
    backendFuncs.sayHello()  # Informs us where each process is running.
    allSizes = comms_utils.getSizes(
        commsParams.beginSize, commsParams.endSize, commsParams.stepFactor
    )  # Given the begin size, end size, and step factor, compute the message sizes to iterate over.

    if global_rank == 0:
        print(
            "\t global_rank: %d allSizes: %s local_rank: %d element_size: %d "
            % (global_rank, allSizes, local_rank, commsParams.element_size))
        print("\t global_rank: %d commsParams: %s " %
              (global_rank, commsParams))

    collectiveArgs = comms_utils.collectiveArgsHolder()
    collectiveArgs.group = group
    collectiveArgs.device = curDevice
    collectiveArgs.world_size = world_size
    collectiveArgs.numIters = commsParams.numIters
    collectiveArgs.numWarmupIters = commsParams.numWarmupIters
    collectiveArgs.global_rank = global_rank
    collectiveArgs.backendFuncs = backendFuncs
    collectiveArgs.srcOrDst = ""
    collectiveArgs.collective = commsParams.collective
    op = backendFuncs.get_reduce_op("sum")
    collectiveArgs.op = op
    collectiveArgs.dst = commsParams.dst

    computeFunc = None
    if commsParams.mode != "comms":  # Compute mode related initialization.
        if commsParams.kernel == "gemm":
            computeFunc = backendFuncs.gemm

            mm_dim = commsParams.mm_dim
            in1 = np.random.rand(mm_dim, mm_dim)
            MMin1 = torch.FloatTensor(in1).to(curDevice)
            in2 = np.random.rand(mm_dim, mm_dim)
            MMin2 = torch.FloatTensor(in2).to(curDevice)
            in3 = np.random.rand(mm_dim, mm_dim)
            MMin3 = torch.FloatTensor(in3).to(curDevice)
            MMout = backendFuncs.alloc_empty([mm_dim, mm_dim],
                                             commsParams.dtype, curDevice)
            collectiveArgs.MMout = MMout
            collectiveArgs.MMin1 = MMin1
            collectiveArgs.MMin2 = MMin2
            collectiveArgs.MMin3 = MMin3
            collectiveArgs.numComputePerColl = commsParams.num_compute
        else:
            print("Compute kernel " + commsParams.kernel +
                  " not supported...Abort!")
            comms_utils.gracefulExit()

    return (
        collectiveArgs,
        local_rank,
        global_rank,
        world_size,
        group,
        curDevice,
        allSizes,
        computeFunc,
    )
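A hedged usage sketch of this free-function variant: backend setup follows Example #2, and the caller unpacks the returned tuple in order. Construction of comms_world_info and commsParams is elided, and the surrounding benchmark code may differ:

# Illustrative call site only; names follow Examples #2 and #4.
from pytorch_dist_backend import PyTorchDistBackend

backendFuncs = PyTorchDistBackend(comms_world_info, commsParams)
backendFuncs.initialize_backend(
    comms_world_info.master_ip,
    comms_world_info.master_port,
    backend=commsParams.backend,
)
(
    collectiveArgs,
    local_rank,
    global_rank,
    world_size,
    group,
    curDevice,
    allSizes,
    computeFunc,
) = initializeCollectiveArgs(commsParams, backendFuncs)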
Example #5
    def initCollectiveArgs(self, commsParams):
        # lint was complaining that benchTime was too complex!
        (
            local_rank,
            global_rank,
            world_size,
            group,
            curDevice,
            curHwDevice,
        ) = comms_utils.get_rank_details(
            self.backendFuncs
        )  # Get ranks from the backendFuncs object, since we cannot use MPI (e.g., on TPU) to launch all the processes.
        self.backendFuncs.sayHello()  # Informs us where each process is running.
        groups = self.backendFuncs.get_groups()
        num_pgs = len(groups)

        self.comm_size = world_size
        self.global_rank = global_rank

        comms_utils.fixBeginSize(
            commsParams, world_size
        )  # Ensure that all-reduce and all-to-all have at least one member per rank.
        allSizes = comms_utils.getSizes(
            commsParams.beginSize, commsParams.endSize, commsParams.stepFactor
        )  # Given the begin size, end size, and step factor, compute the message sizes to iterate over.

        if global_rank == 0:
            print(
                f"[Rank {global_rank:>3}] allSizes: {allSizes} local_rank: {local_rank} element_size: {commsParams.element_size}"
            )

        self.collectiveArgs.group = group
        self.collectiveArgs.groups = groups
        self.collectiveArgs.num_pgs = num_pgs
        self.collectiveArgs.device = curDevice
        self.collectiveArgs.world_size = world_size
        self.collectiveArgs.numIters = commsParams.numIters
        self.collectiveArgs.numWarmupIters = commsParams.numWarmupIters
        self.collectiveArgs.global_rank = global_rank
        self.collectiveArgs.backendFuncs = self.backendFuncs
        self.collectiveArgs.collective = commsParams.collective
        op = self.backendFuncs.get_reduce_op("sum")
        self.collectiveArgs.op = op
        self.collectiveArgs.srcOrDst = commsParams.srcOrDst
        self.collectiveArgs.src_ranks = commsParams.src_ranks
        self.collectiveArgs.dst_ranks = commsParams.dst_ranks
        self.collectiveArgs.pair = commsParams.pair
        self.collectiveArgs.collective_pair = commsParams.collective_pair
        self.collectiveArgs.pt2pt = commsParams.pt2pt
        self.collectiveArgs.window = commsParams.window
        self.collectiveArgs.asyncOp = commsParams.blockingFlag != 1

        if commsParams.bitwidth < 32:
            comms_utils.initQuantCommCtx(self.collectiveArgs, commsParams)

        if self.collectiveArgs.collective == "pt2pt":
            self.checkPt2PtRanks()
        else:
            self.checkCollectiveRanks()

        computeFunc = self.backendFuncs.noop
        if commsParams.mode != "comms":  # Compute mode related initialization if not in comms-only mode
            if commsParams.kernel == "gemm":
                computeFunc = self.backendFuncs.gemm

                mm_dim = commsParams.mm_dim
                in1 = np.random.rand(mm_dim, mm_dim)
                MMin1 = torch.FloatTensor(in1).to(curDevice)
                in2 = np.random.rand(mm_dim, mm_dim)
                MMin2 = torch.FloatTensor(in2).to(curDevice)
                in3 = np.random.rand(mm_dim, mm_dim)
                MMin3 = torch.FloatTensor(in3).to(curDevice)
                MMout = self.backendFuncs.alloc_empty([mm_dim, mm_dim],
                                                      commsParams.dtype,
                                                      curDevice)
                self.collectiveArgs.MMout = MMout
                self.collectiveArgs.MMin1 = MMin1
                self.collectiveArgs.MMin2 = MMin2
                self.collectiveArgs.MMin3 = MMin3
                self.collectiveArgs.numComputePerColl = commsParams.num_compute
            elif commsParams.kernel == "emb_lookup":
                computeFunc = self.backendFuncs.emb_lookup

                emb_dim = commsParams.emb_dim
                num_embeddings = commsParams.num_embs
                avg_length = commsParams.avg_len
                batch_size = commsParams.batch_size
                print(
                    f"emb_dim {emb_dim} num_embs {num_embeddings} avg_len {avg_length} bs {batch_size}"
                )
                self.collectiveArgs.EmbWeights = self.backendFuncs.alloc_empty(
                    [num_embeddings, emb_dim], torch.double, curDevice)
                self.collectiveArgs.TableOffsets = torch.LongTensor(
                    [0, num_embeddings]).to(curDevice)
                self.collectiveArgs.Indices = torch.LongTensor(
                    np.random.randint(0, num_embeddings - 1,
                                      avg_length * batch_size)).to(curDevice)
                lengths = np.ones((1, batch_size)) * avg_length
                flat_lengths = lengths.flatten()
                self.collectiveArgs.Offsets = torch.LongTensor(
                    [0] + np.cumsum(flat_lengths).tolist()).to(curDevice)
                self.collectiveArgs.LookupOut = self.backendFuncs.alloc_empty(
                    [batch_size, emb_dim], torch.double, curDevice)
                self.collectiveArgs.AvgLengths = avg_length
                self.collectiveArgs.numComputePerColl = commsParams.num_compute

        return (
            local_rank,
            global_rank,
            world_size,
            group,
            curDevice,
            curHwDevice,
            allSizes,
            computeFunc,
        )
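To make the embedding-lookup setup concrete, here is a small standalone check of the Offsets construction used above; batch_size and avg_length values are illustrative. With a constant avg_length per bag, the cumulative sum yields evenly spaced bag boundaries:

import numpy as np
import torch

# Mirrors the Offsets construction above for batch_size=4, avg_length=3.
batch_size, avg_length = 4, 3
lengths = np.ones((1, batch_size)) * avg_length  # [[3., 3., 3., 3.]]
flat_lengths = lengths.flatten()
offsets = torch.LongTensor([0] + np.cumsum(flat_lengths).tolist())
print(offsets)  # tensor([ 0,  3,  6,  9, 12])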