Example #1
 def test_all_reduce(self):
     commsParams = commsParamsTest()
     commsParams.collective = "all_reduce"
     commsParams.beginSize = 0
     commsParams.element_size = 2
     world_size = 16
     comms_utils.fixBeginSize(commsParams, world_size)
     # For reduce collectives, beginSize should be at least element_size
     self.assertEqual(2, commsParams.beginSize)
Example #2
 def test_all_to_all(self):
     commsParams = commsParamsTest()
     commsParams.collective = "all_to_all"
     commsParams.beginSize = 0
     commsParams.element_size = 2
     commsParams.bitwidth = 32
     world_size = 16
     comms_utils.fixBeginSize(commsParams, world_size)
     # beginSize / element_size < world_size, so the new begin size should be element_size * world_size
     self.assertEqual(32, commsParams.beginSize)
Example #3
 def test_all_to_all_quantized(self):
     commsParams = commsParamsTest()
     commsParams.collective = "all_to_all"
     commsParams.beginSize = 0
     commsParams.element_size = 2
     commsParams.bitwidth = 31 # Bitwidth less than 32 triggers quantization
     commsParams.quant_a2a_embedding_dim = 2
     world_size = 16
     comms_utils.fixBeginSize(commsParams, world_size)
     # (beginSize / element_size / world_size) < quant_a2a_embedding_dim, so the new begin size should be element_size * world_size * quant_a2a_embedding_dim
     self.assertEqual(64, commsParams.beginSize)
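Taken together, the three tests pin down the adjustment rules: reduce collectives need at least one element in total, all-to-all needs at least one element per rank, and quantized all-to-all (bitwidth < 32) additionally needs a full embedding row per rank. A minimal sketch of a fixBeginSize that satisfies these tests follows; the real comms_utils implementation may differ in naming and in which collectives it handles.

def fixBeginSize(commsParams, world_size):
    # Reduce collectives: at least one element in total.
    if commsParams.collective in ("all_reduce", "reduce"):
        if commsParams.beginSize < commsParams.element_size:
            commsParams.beginSize = commsParams.element_size
    # All-to-all: at least one element per rank; quantized all-to-all
    # additionally needs one embedding row per rank.
    elif commsParams.collective == "all_to_all":
        if commsParams.beginSize / commsParams.element_size < world_size:
            commsParams.beginSize = commsParams.element_size * world_size
        if commsParams.bitwidth < 32 and (
            commsParams.beginSize / commsParams.element_size / world_size
            < commsParams.quant_a2a_embedding_dim
        ):
            commsParams.beginSize = (
                commsParams.element_size * world_size * commsParams.quant_a2a_embedding_dim
            )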
Example #4
    def initCollectiveArgs(self, commsParams):
        # lint was complaining that benchTime was too complex!
        (
            local_rank,
            global_rank,
            world_size,
            group,
            curDevice,
        ) = comms_utils.get_rank_details(self.backendFuncs)  # Get ranks from the backendFuncs object, since we cannot use MPI (e.g., on TPU) to launch all the processes.

        comms_utils.fixBeginSize(commsParams, world_size)  # Ensure that all-reduce and all-to-all have at least one member per rank.
        self.backendFuncs.sayHello()  # Informs us where each process is running.
        allSizes = comms_utils.getSizes(
            commsParams.beginSize, commsParams.endSize, commsParams.stepFactor
        )  # Given the begin size, end size, and step factor, compute the message sizes to iterate over.

        if global_rank == 0:
            print(
                "\t global_rank: %d allSizes: %s local_rank: %d element_size: %d "
                % (global_rank, allSizes, local_rank, commsParams.element_size)
            )
            print("\t global_rank: %d commsParams: %s " % (global_rank, commsParams))

        self.collectiveArgs.group = group
        self.collectiveArgs.device = curDevice
        self.collectiveArgs.world_size = world_size
        self.collectiveArgs.numIters = commsParams.numIters
        self.collectiveArgs.numWarmupIters = commsParams.numWarmupIters
        self.collectiveArgs.global_rank = global_rank
        self.collectiveArgs.backendFuncs = self.backendFuncs
        self.collectiveArgs.srcOrDst = ""
        self.collectiveArgs.collective = commsParams.collective
        op = self.backendFuncs.get_reduce_op("sum")
        self.collectiveArgs.op = op
        self.collectiveArgs.dst = commsParams.dst

        if commsParams.bitwidth < 32:
            logging.warning(f'communication bitwidth set to {commsParams.bitwidth}')
            try:
                from internals import initialize_collectiveArgs_internal
                initialize_collectiveArgs_internal(self.collectiveArgs, commsParams)
            except ImportError:
                if commsParams.collective not in ("reduce", "all_reduce"):
                    raise NotImplementedError(
                        "quantized communication for %s is currently unsupported."
                        % commsParams.collective
                    )

        computeFunc = None
        if commsParams.mode != "comms":  # Compute mode related initialization.
            if commsParams.kernel == "gemm":
                computeFunc = self.backendFuncs.gemm

                mm_dim = commsParams.mm_dim
                in1 = np.random.rand(mm_dim, mm_dim)
                MMin1 = torch.FloatTensor(in1).to(curDevice)
                in2 = np.random.rand(mm_dim, mm_dim)
                MMin2 = torch.FloatTensor(in2).to(curDevice)
                in3 = np.random.rand(mm_dim, mm_dim)
                MMin3 = torch.FloatTensor(in3).to(curDevice)
                MMout = self.backendFuncs.alloc_empty(
                    [mm_dim, mm_dim], commsParams.dtype, curDevice
                )
                self.collectiveArgs.MMout = MMout
                self.collectiveArgs.MMin1 = MMin1
                self.collectiveArgs.MMin2 = MMin2
                self.collectiveArgs.MMin3 = MMin3
                self.collectiveArgs.numComputePerColl = commsParams.num_compute
            elif commsParams.kernel == "emb_lookup":
                computeFunc = self.backendFuncs.emb_lookup

                emb_dim = commsParams.emb_dim
                num_embeddings = commsParams.num_embs
                avg_length = commsParams.avg_len
                batch_size = commsParams.batch_size
                print(
                    f"emb_dim {emb_dim} num_embs {num_embeddings} avg_len {avg_length} bs {batch_size}"
                )
                self.collectiveArgs.EmbWeights = self.backendFuncs.alloc_empty(
                    [num_embeddings, emb_dim], torch.double, curDevice
                )
                self.collectiveArgs.TableOffsets = torch.LongTensor([0, num_embeddings]).to(
                    curDevice
                )
                self.collectiveArgs.Indices = torch.LongTensor(
                    np.random.randint(0, num_embeddings - 1, avg_length * batch_size)
                ).to(curDevice)
                lengths = np.ones((1, batch_size)) * avg_length
                flat_lengths = lengths.flatten()
                self.collectiveArgs.Offsets = torch.LongTensor(
                    [0] + np.cumsum(flat_lengths).tolist()
                ).to(curDevice)
                self.collectiveArgs.LookupOut = self.backendFuncs.alloc_empty(
                    [batch_size, emb_dim], torch.double, curDevice
                )
                self.collectiveArgs.AvgLengths = avg_length
                self.collectiveArgs.numComputePerColl = commsParams.num_compute

        return (
            local_rank,
            global_rank,
            world_size,
            group,
            curDevice,
            allSizes,
            computeFunc,
        )
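Example #4 (and the variants below) relies on comms_utils.getSizes to enumerate the message sizes between beginSize and endSize. A plausible sketch, assuming sizes simply grow geometrically by stepFactor (the actual helper may cap iterations or handle non-integer steps differently):

def getSizes(beginSize, endSize, stepFactor):
    # Grow the message size geometrically until endSize is exceeded.
    allSizes = []
    curSize = beginSize
    while curSize <= endSize and len(allSizes) < 100:  # guard against stepFactor <= 1
        allSizes.append(curSize)
        curSize *= stepFactor
    return allSizes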
Example #5
def initializeCollectiveArgs(commsParams, backendFuncs):
    # lint was complaining that benchTime was too complex!
    (
        local_rank,
        global_rank,
        world_size,
        group,
        curDevice,
    ) = comms_utils.get_rank_details(
        backendFuncs
    )  # Get ranks from the backendFuncs object, since we cannot use MPI (e.g., on TPU) to launch all the processes.

    comms_utils.fixBeginSize(
        commsParams, world_size
    )  # Ensure that all-reduce and all-to-all have at least one member per rank.
    backendFuncs.sayHello()  # Informs us where each process is running.
    allSizes = comms_utils.getSizes(
        commsParams.beginSize, commsParams.endSize, commsParams.stepFactor
    )  # Given the begin size, end size, and step factor, compute the message sizes to iterate over.

    if global_rank == 0:
        print(
            "\t global_rank: %d allSizes: %s local_rank: %d element_size: %d "
            % (global_rank, allSizes, local_rank, commsParams.element_size))
        print("\t global_rank: %d commsParams: %s " %
              (global_rank, commsParams))

    collectiveArgs = comms_utils.collectiveArgsHolder()
    collectiveArgs.group = group
    collectiveArgs.device = curDevice
    collectiveArgs.world_size = world_size
    collectiveArgs.numIters = commsParams.numIters
    collectiveArgs.numWarmupIters = commsParams.numWarmupIters
    collectiveArgs.global_rank = global_rank
    collectiveArgs.backendFuncs = backendFuncs
    collectiveArgs.srcOrDst = ""
    collectiveArgs.collective = commsParams.collective
    op = backendFuncs.get_reduce_op("sum")
    collectiveArgs.op = op
    collectiveArgs.dst = commsParams.dst

    computeFunc = None
    if commsParams.mode != "comms":  # Compute mode related initialization.
        if commsParams.kernel == "gemm":
            computeFunc = backendFuncs.gemm

            mm_dim = commsParams.mm_dim
            in1 = np.random.rand(mm_dim, mm_dim)
            MMin1 = torch.FloatTensor(in1).to(curDevice)
            in2 = np.random.rand(mm_dim, mm_dim)
            MMin2 = torch.FloatTensor(in2).to(curDevice)
            in3 = np.random.rand(mm_dim, mm_dim)
            MMin3 = torch.FloatTensor(in3).to(curDevice)
            MMout = backendFuncs.alloc_empty([mm_dim, mm_dim],
                                             commsParams.dtype, curDevice)
            collectiveArgs.MMout = MMout
            collectiveArgs.MMin1 = MMin1
            collectiveArgs.MMin2 = MMin2
            collectiveArgs.MMin3 = MMin3
            collectiveArgs.numComputePerColl = commsParams.num_compute
        else:
            print(f"Compute kernel {commsParams.kernel} is not supported. Aborting!")
            comms_utils.gracefulExit()

    return (
        collectiveArgs,
        local_rank,
        global_rank,
        world_size,
        group,
        curDevice,
        allSizes,
        computeFunc,
    )
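Example #5 builds a fresh comms_utils.collectiveArgsHolder rather than reusing one stored on self. Conceptually this is just a mutable bag of attributes that the benchmark fills in before each collective; a simplified sketch follows (field list abridged, the real class carries more state):

class collectiveArgsHolder:
    def __init__(self):
        # Process-group / device context.
        self.group = None
        self.device = {}
        self.world_size = 0
        self.global_rank = -1
        self.backendFuncs = None
        # Collective selection and iteration counts.
        self.collective = ""
        self.op = None
        self.srcOrDst = ""
        self.dst = 0
        self.numIters = 0
        self.numWarmupIters = 0
        # Optional compute-kernel state (gemm / emb_lookup).
        self.MMout = None
        self.MMin1 = None
        self.MMin2 = None
        self.MMin3 = None
        self.numComputePerColl = 0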
Example #6
    def initCollectiveArgs(self, commsParams):
        # lint was complaining that benchTime was too complex!
        (
            local_rank,
            global_rank,
            world_size,
            group,
            curDevice,
            curHwDevice,
        ) = comms_utils.get_rank_details(
            self.backendFuncs
        )  # Get ranks from the backendFuncs object, since we cannot use MPI (e.g., on TPU) to launch all the processes.
        self.backendFuncs.sayHello()  # Informs us where each process is running.
        groups = self.backendFuncs.get_groups()
        num_pgs = len(groups)

        self.comm_size = world_size
        self.global_rank = global_rank

        comms_utils.fixBeginSize(
            commsParams, world_size
        )  # Ensure that all-reduce and all-to-all have at least one member per rank.
        allSizes = comms_utils.getSizes(
            commsParams.beginSize, commsParams.endSize, commsParams.stepFactor
        )  # Given the begin size, end size, and step factor, compute the message sizes to iterate over.

        if global_rank == 0:
            print(
                f"[Rank {global_rank:>3}] allSizes: {allSizes} local_rank: {local_rank} element_size: {commsParams.element_size}"
            )

        self.collectiveArgs.group = group
        self.collectiveArgs.groups = groups
        self.collectiveArgs.num_pgs = num_pgs
        self.collectiveArgs.device = curDevice
        self.collectiveArgs.world_size = world_size
        self.collectiveArgs.numIters = commsParams.numIters
        self.collectiveArgs.numWarmupIters = commsParams.numWarmupIters
        self.collectiveArgs.global_rank = global_rank
        self.collectiveArgs.backendFuncs = self.backendFuncs
        self.collectiveArgs.collective = commsParams.collective
        op = self.backendFuncs.get_reduce_op("sum")
        self.collectiveArgs.op = op
        self.collectiveArgs.srcOrDst = commsParams.srcOrDst
        self.collectiveArgs.src_ranks = commsParams.src_ranks
        self.collectiveArgs.dst_ranks = commsParams.dst_ranks
        self.collectiveArgs.pair = commsParams.pair
        self.collectiveArgs.collective_pair = commsParams.collective_pair
        self.collectiveArgs.pt2pt = commsParams.pt2pt
        self.collectiveArgs.window = commsParams.window
        self.collectiveArgs.asyncOp = commsParams.blockingFlag != 1

        if commsParams.bitwidth < 32:
            comms_utils.initQuantCommCtx(self.collectiveArgs, commsParams)

        if self.collectiveArgs.collective == "pt2pt":
            self.checkPt2PtRanks()
        else:
            self.checkCollectiveRanks()

        computeFunc = self.backendFuncs.noop
        if commsParams.mode != "comms":  # Compute mode related initialization if not in comms-only mode
            if commsParams.kernel == "gemm":
                computeFunc = self.backendFuncs.gemm

                mm_dim = commsParams.mm_dim
                in1 = np.random.rand(mm_dim, mm_dim)
                MMin1 = torch.FloatTensor(in1).to(curDevice)
                in2 = np.random.rand(mm_dim, mm_dim)
                MMin2 = torch.FloatTensor(in2).to(curDevice)
                in3 = np.random.rand(mm_dim, mm_dim)
                MMin3 = torch.FloatTensor(in3).to(curDevice)
                MMout = self.backendFuncs.alloc_empty([mm_dim, mm_dim],
                                                      commsParams.dtype,
                                                      curDevice)
                self.collectiveArgs.MMout = MMout
                self.collectiveArgs.MMin1 = MMin1
                self.collectiveArgs.MMin2 = MMin2
                self.collectiveArgs.MMin3 = MMin3
                self.collectiveArgs.numComputePerColl = commsParams.num_compute
            elif commsParams.kernel == "emb_lookup":
                computeFunc = self.backendFuncs.emb_lookup

                emb_dim = commsParams.emb_dim
                num_embeddings = commsParams.num_embs
                avg_length = commsParams.avg_len
                batch_size = commsParams.batch_size
                print(
                    f"emb_dim {emb_dim} num_embs {num_embeddings} avg_len {avg_length} bs {batch_size}"
                )
                self.collectiveArgs.EmbWeights = self.backendFuncs.alloc_empty(
                    [num_embeddings, emb_dim], torch.double, curDevice)
                self.collectiveArgs.TableOffsets = torch.LongTensor(
                    [0, num_embeddings]).to(curDevice)
                self.collectiveArgs.Indices = torch.LongTensor(
                    np.random.randint(0, num_embeddings - 1,
                                      avg_length * batch_size)).to(curDevice)
                lengths = np.ones((1, batch_size)) * avg_length
                flat_lengths = lengths.flatten()
                self.collectiveArgs.Offsets = torch.LongTensor(
                    [0] + np.cumsum(flat_lengths).tolist()).to(curDevice)
                self.collectiveArgs.LookupOut = self.backendFuncs.alloc_empty(
                    [batch_size, emb_dim], torch.double, curDevice)
                self.collectiveArgs.AvgLengths = avg_length
                self.collectiveArgs.numComputePerColl = commsParams.num_compute

        return (
            local_rank,
            global_rank,
            world_size,
            group,
            curDevice,
            curHwDevice,
            allSizes,
            computeFunc,
        )
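The compute kernels themselves come from the backend. As a rough illustration of how the tensors staged above are consumed, hypothetical gemm and emb_lookup implementations might look like the sketch below; the real backendFuncs methods (and their exact use of MMin3, TableOffsets, and AvgLengths) may differ.

import torch

def gemm(collectiveArgs):
    # Chain two matrix multiplications over the pre-allocated inputs so the
    # benchmark has GPU compute to overlap with communication.
    collectiveArgs.MMout = torch.mm(collectiveArgs.MMin1, collectiveArgs.MMin2)
    collectiveArgs.MMout = torch.mm(collectiveArgs.MMout, collectiveArgs.MMin3)

def emb_lookup(collectiveArgs):
    # Pooled embedding lookup over the staged indices/offsets; Offsets already
    # includes the trailing total, hence include_last_offset=True.
    collectiveArgs.LookupOut = torch.nn.functional.embedding_bag(
        input=collectiveArgs.Indices,
        weight=collectiveArgs.EmbWeights,
        offsets=collectiveArgs.Offsets,
        mode="sum",
        include_last_offset=True,
    )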