def test_get_sizes(self):
    """getSizes should list every size from beginSize to endSize, scaling by stepFactor."""
    beginSize = 32
    endSize = 1024
    stepFactor = 2
    # Start at 32 and double each step, up to and including 1024.
    correct_list = [32, 64, 128, 256, 512, 1024]
    result_list = comms_utils.getSizes(beginSize, endSize, stepFactor)
    # assertEqual on lists checks length and order together and, unlike
    # assertTrue(a == b), reports the mismatching elements on failure.
    self.assertEqual(correct_list, result_list)
def initCollectiveArgs(self, commsParams):
    """Populate ``self.collectiveArgs`` for a benchmark run and pick the compute kernel.

    Split out of benchTime, which lint flagged as too complex. Assumes
    ``self.collectiveArgs`` and ``self.backendFuncs`` already exist; this
    method only writes attributes onto the existing ``self.collectiveArgs``.

    Args:
        commsParams: benchmark configuration. ``comms_utils.fixBeginSize`` is
            called on it with the return value ignored, so it presumably
            adjusts the begin size in place.

    Returns:
        ``(local_rank, global_rank, world_size, group, curDevice, allSizes,
        computeFunc)``. ``computeFunc`` is ``None`` in comms-only mode or when
        ``commsParams.kernel`` is neither "gemm" nor "emb_lookup".

    Raises:
        NotImplementedError: if ``bitwidth < 32``, the optional ``internals``
            module cannot be imported, and the collective is neither
            "reduce" nor "all_reduce".
    """
    # Ranks come from the backend object rather than MPI, since not every
    # launcher (e.g. TPU) starts the processes through MPI.
    (
        local_rank,
        global_rank,
        world_size,
        group,
        curDevice,
    ) = comms_utils.get_rank_details(self.backendFuncs)
    # Ensure all-reduce and all-to-all have at least one element per rank.
    comms_utils.fixBeginSize(commsParams, world_size)
    self.backendFuncs.sayHello()  # Informs us where each process is running.
    # Message sizes to iterate over, from begin size to end size by step factor.
    allSizes = comms_utils.getSizes(
        commsParams.beginSize, commsParams.endSize, commsParams.stepFactor
    )
    if global_rank == 0:
        print(
            "\t global_rank: %d allSizes: %s local_rank: %d element_size: %d "
            % (global_rank, allSizes, local_rank, commsParams.element_size)
        )
        print("\t global_rank: %d commsParams: %s " % (global_rank, commsParams))

    self.collectiveArgs.group = group
    self.collectiveArgs.device = curDevice
    self.collectiveArgs.world_size = world_size
    self.collectiveArgs.numIters = commsParams.numIters
    self.collectiveArgs.numWarmupIters = commsParams.numWarmupIters
    self.collectiveArgs.global_rank = global_rank
    self.collectiveArgs.backendFuncs = self.backendFuncs
    self.collectiveArgs.srcOrDst = ""
    self.collectiveArgs.collective = commsParams.collective
    self.collectiveArgs.op = self.backendFuncs.get_reduce_op("sum")
    self.collectiveArgs.dst = commsParams.dst

    if commsParams.bitwidth < 32:
        logging.warning(f'communication bitwidth set to {commsParams.bitwidth}')
        try:
            # Optional module; only the import itself is guarded so that
            # ImportErrors raised inside the initializer are not masked.
            from internals import initialize_collectiveArgs_internal
        except ImportError as err:
            # Without the internal hooks, only reduce/all_reduce may proceed
            # (with no internal quantization setup).
            if commsParams.collective not in ("reduce", "all_reduce"):
                raise NotImplementedError(
                    "quantized communication for %s is currently unsupported."
                    % commsParams.collective
                ) from err
        else:
            initialize_collectiveArgs_internal(self.collectiveArgs, commsParams)

    computeFunc = None
    if commsParams.mode != "comms":  # Compute mode related initialization.
        if commsParams.kernel == "gemm":
            computeFunc = self.backendFuncs.gemm
            mm_dim = commsParams.mm_dim
            # Three random square fp32 operands, drawn in order in1, in2, in3.
            MMin1, MMin2, MMin3 = (
                torch.FloatTensor(np.random.rand(mm_dim, mm_dim)).to(curDevice)
                for _ in range(3)
            )
            self.collectiveArgs.MMout = self.backendFuncs.alloc_empty(
                [mm_dim, mm_dim], commsParams.dtype, curDevice
            )
            self.collectiveArgs.MMin1 = MMin1
            self.collectiveArgs.MMin2 = MMin2
            self.collectiveArgs.MMin3 = MMin3
            self.collectiveArgs.numComputePerColl = commsParams.num_compute
        elif commsParams.kernel == "emb_lookup":
            computeFunc = self.backendFuncs.emb_lookup
            emb_dim = commsParams.emb_dim
            num_embeddings = commsParams.num_embs
            avg_length = commsParams.avg_len
            batch_size = commsParams.batch_size
            print(
                f"emb_dim {emb_dim} num_embs {num_embeddings} avg_len {avg_length} bs {batch_size}"
            )
            self.collectiveArgs.EmbWeights = self.backendFuncs.alloc_empty(
                [num_embeddings, emb_dim], torch.double, curDevice
            )
            self.collectiveArgs.TableOffsets = torch.LongTensor(
                [0, num_embeddings]
            ).to(curDevice)
            # np.random.randint's upper bound is exclusive: passing num_embeddings
            # lets every row be sampled (``num_embeddings - 1`` skipped the last
            # row and raised when num_embeddings == 1).
            self.collectiveArgs.Indices = torch.LongTensor(
                np.random.randint(0, num_embeddings, avg_length * batch_size)
            ).to(curDevice)
            # Every bag holds exactly avg_length indices, so the offsets are
            # 0, avg_length, 2 * avg_length, ..., batch_size * avg_length.
            self.collectiveArgs.Offsets = (
                torch.arange(batch_size + 1, dtype=torch.long) * avg_length
            ).to(curDevice)
            self.collectiveArgs.LookupOut = self.backendFuncs.alloc_empty(
                [batch_size, emb_dim], torch.double, curDevice
            )
            self.collectiveArgs.AvgLengths = avg_length
            self.collectiveArgs.numComputePerColl = commsParams.num_compute
    return (
        local_rank,
        global_rank,
        world_size,
        group,
        curDevice,
        allSizes,
        computeFunc,
    )
def initializeCollectiveArgs(commsParams, backendFuncs):
    """Build a fresh collectiveArgsHolder for a run and select the compute kernel.

    Factored out of benchTime to keep lint's complexity check quiet.

    Args:
        commsParams: run configuration. ``comms_utils.fixBeginSize`` is called
            on it with the result discarded, so it presumably adjusts fields
            in place.
        backendFuncs: communication backend; supplies rank details, the "sum"
            reduce op, tensor allocation and the gemm kernel.

    Returns:
        ``(collectiveArgs, local_rank, global_rank, world_size, group,
        curDevice, allSizes, computeFunc)``. ``computeFunc`` stays ``None`` in
        comms-only mode. Any kernel other than "gemm" prints an error and
        calls ``comms_utils.gracefulExit()``.
    """
    # Rank info is taken from the backend, not MPI, because some launchers
    # (e.g. TPU) do not start processes through MPI.
    rank_details = comms_utils.get_rank_details(backendFuncs)
    local_rank, global_rank, world_size, group, curDevice = rank_details

    # Guarantee all-reduce / all-to-all carry at least one element per rank.
    comms_utils.fixBeginSize(commsParams, world_size)
    backendFuncs.sayHello()  # report where this process is running

    # Every message size to benchmark: beginSize .. endSize by stepFactor.
    allSizes = comms_utils.getSizes(
        commsParams.beginSize, commsParams.endSize, commsParams.stepFactor
    )

    if global_rank == 0:
        size_report = "\t global_rank: %d allSizes: %s local_rank: %d element_size: %d " % (
            global_rank,
            allSizes,
            local_rank,
            commsParams.element_size,
        )
        print(size_report)
        print("\t global_rank: %d commsParams: %s " % (global_rank, commsParams))

    args = comms_utils.collectiveArgsHolder()
    args.group = group
    args.device = curDevice
    args.world_size = world_size
    args.numIters = commsParams.numIters
    args.numWarmupIters = commsParams.numWarmupIters
    args.global_rank = global_rank
    args.backendFuncs = backendFuncs
    args.srcOrDst = ""
    args.collective = commsParams.collective
    args.op = backendFuncs.get_reduce_op("sum")
    args.dst = commsParams.dst

    compute_fn = None
    wants_compute = commsParams.mode != "comms"
    if wants_compute and commsParams.kernel == "gemm":
        compute_fn = backendFuncs.gemm
        dim = commsParams.mm_dim
        # Random fp32 operands, generated one after another on the target device.
        operands = [
            torch.FloatTensor(np.random.rand(dim, dim)).to(curDevice)
            for _ in range(3)
        ]
        args.MMout = backendFuncs.alloc_empty([dim, dim], commsParams.dtype, curDevice)
        args.MMin1, args.MMin2, args.MMin3 = operands
        args.numComputePerColl = commsParams.num_compute
    elif wants_compute:
        print("Compute kernel " + commsParams.kernel + " not supported...Abort!")
        comms_utils.gracefulExit()

    return (
        args,
        local_rank,
        global_rank,
        world_size,
        group,
        curDevice,
        allSizes,
        compute_fn,
    )
def initCollectiveArgs(self, commsParams):
    """Fill ``self.collectiveArgs`` for a benchmark run and choose the compute kernel.

    Split out of benchTime, which lint flagged as too complex. Besides writing
    attributes onto the existing ``self.collectiveArgs``, this sets
    ``self.comm_size`` and ``self.global_rank``. It then checks the configured
    ranks via ``self.checkPt2PtRanks()`` (pt2pt) or
    ``self.checkCollectiveRanks()``.

    Args:
        commsParams: benchmark configuration. ``comms_utils.fixBeginSize`` is
            called on it with the return value ignored, so it presumably
            adjusts the begin size in place.

    Returns:
        ``(local_rank, global_rank, world_size, group, curDevice, curHwDevice,
        allSizes, computeFunc)``. ``computeFunc`` is ``self.backendFuncs.noop``
        unless compute mode is requested with kernel "gemm" or "emb_lookup".
    """
    # Ranks come from the backend object rather than MPI, since not every
    # launcher (e.g. TPU) starts the processes through MPI.
    (
        local_rank,
        global_rank,
        world_size,
        group,
        curDevice,
        curHwDevice,
    ) = comms_utils.get_rank_details(self.backendFuncs)
    self.backendFuncs.sayHello()  # Informs us where each process is running.
    groups = self.backendFuncs.get_groups()
    num_pgs = len(groups)

    self.comm_size = world_size
    self.global_rank = global_rank

    # Ensure all-reduce and all-to-all have at least one element per rank.
    comms_utils.fixBeginSize(commsParams, world_size)
    # Message sizes to iterate over, from begin size to end size by step factor.
    allSizes = comms_utils.getSizes(
        commsParams.beginSize, commsParams.endSize, commsParams.stepFactor
    )
    if global_rank == 0:
        print(
            f"[Rank {global_rank:>3}] allSizes: {allSizes} local_rank: {local_rank} element_size: {commsParams.element_size}"
        )

    self.collectiveArgs.group = group
    self.collectiveArgs.groups = groups
    self.collectiveArgs.num_pgs = num_pgs
    self.collectiveArgs.device = curDevice
    self.collectiveArgs.world_size = world_size
    self.collectiveArgs.numIters = commsParams.numIters
    self.collectiveArgs.numWarmupIters = commsParams.numWarmupIters
    self.collectiveArgs.global_rank = global_rank
    self.collectiveArgs.backendFuncs = self.backendFuncs
    self.collectiveArgs.collective = commsParams.collective
    self.collectiveArgs.op = self.backendFuncs.get_reduce_op("sum")
    self.collectiveArgs.srcOrDst = commsParams.srcOrDst
    self.collectiveArgs.src_ranks = commsParams.src_ranks
    self.collectiveArgs.dst_ranks = commsParams.dst_ranks
    self.collectiveArgs.pair = commsParams.pair
    self.collectiveArgs.collective_pair = commsParams.collective_pair
    self.collectiveArgs.pt2pt = commsParams.pt2pt
    self.collectiveArgs.window = commsParams.window
    # asyncOp is True for every blockingFlag value except 1.
    self.collectiveArgs.asyncOp = commsParams.blockingFlag != 1

    if commsParams.bitwidth < 32:
        comms_utils.initQuantCommCtx(self.collectiveArgs, commsParams)

    if self.collectiveArgs.collective == "pt2pt":
        self.checkPt2PtRanks()
    else:
        self.checkCollectiveRanks()

    computeFunc = self.backendFuncs.noop
    if commsParams.mode != "comms":
        # Compute mode related initialization if not in comms-only mode.
        if commsParams.kernel == "gemm":
            computeFunc = self.backendFuncs.gemm
            mm_dim = commsParams.mm_dim
            # Three random square fp32 operands, drawn in order in1, in2, in3.
            MMin1, MMin2, MMin3 = (
                torch.FloatTensor(np.random.rand(mm_dim, mm_dim)).to(curDevice)
                for _ in range(3)
            )
            self.collectiveArgs.MMout = self.backendFuncs.alloc_empty(
                [mm_dim, mm_dim], commsParams.dtype, curDevice
            )
            self.collectiveArgs.MMin1 = MMin1
            self.collectiveArgs.MMin2 = MMin2
            self.collectiveArgs.MMin3 = MMin3
            self.collectiveArgs.numComputePerColl = commsParams.num_compute
        elif commsParams.kernel == "emb_lookup":
            computeFunc = self.backendFuncs.emb_lookup
            emb_dim = commsParams.emb_dim
            num_embeddings = commsParams.num_embs
            avg_length = commsParams.avg_len
            batch_size = commsParams.batch_size
            print(
                f"emb_dim {emb_dim} num_embs {num_embeddings} avg_len {avg_length} bs {batch_size}"
            )
            self.collectiveArgs.EmbWeights = self.backendFuncs.alloc_empty(
                [num_embeddings, emb_dim], torch.double, curDevice
            )
            self.collectiveArgs.TableOffsets = torch.LongTensor(
                [0, num_embeddings]
            ).to(curDevice)
            # np.random.randint's upper bound is exclusive: passing num_embeddings
            # lets every row be sampled (``num_embeddings - 1`` skipped the last
            # row and raised when num_embeddings == 1).
            self.collectiveArgs.Indices = torch.LongTensor(
                np.random.randint(0, num_embeddings, avg_length * batch_size)
            ).to(curDevice)
            # Every bag holds exactly avg_length indices, so the offsets are
            # 0, avg_length, 2 * avg_length, ..., batch_size * avg_length.
            self.collectiveArgs.Offsets = (
                torch.arange(batch_size + 1, dtype=torch.long) * avg_length
            ).to(curDevice)
            self.collectiveArgs.LookupOut = self.backendFuncs.alloc_empty(
                [batch_size, emb_dim], torch.double, curDevice
            )
            self.collectiveArgs.AvgLengths = avg_length
            self.collectiveArgs.numComputePerColl = commsParams.num_compute
    return (
        local_rank,
        global_rank,
        world_size,
        group,
        curDevice,
        curHwDevice,
        allSizes,
        computeFunc,
    )