### import packages ###
import argparse
import sys

import torch

import comms_utils
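
# PARAM-Comm benchmark driver: sweeps the payload of a chosen collective from
# --b to --e bytes (multiplying by --f each step), timing --n iterations after
# --w warmup iterations on the selected backend and network stack.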
def main():
    supportedCommStyle = [0, 1]  # 0 : non-blocking, 1 : blocking.
    supportedCollectives = [
        "reduce",
        "all_reduce",
        "all_to_all",
        "all_to_allv",
    ]  # not yet supported: "scatter", "gather", "all_gather", "broadcast"
    supportedNwstacks = ["pytorch-nccl", "pytorch-xla-tpu"]
    supported_tpu_core_values = [1, 8]
    dtypeMap = {
        "float32": torch.float32,
        "int32": torch.int32,
        "float16": torch.half,
        "float64": torch.double,
    }
    supportedDtype = list(dtypeMap.keys())

    ### parse arguments ###
    parser = argparse.ArgumentParser(
        description="PARAM-Comm Benchmark",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # experiment related parameters
    parser.add_argument(
        "--backend",
        type=str,
        default="nccl",
        help="The backend to be used in the PyTorch distributed process group",
    )
    parser.add_argument(
        "--mode",
        type=str,
        default="comms",
        help="benchmark mode",
    )  # alternatives are DLRM mode and comm-compute mode
    parser.add_argument(
        "--b",
        type=str,
        default="8",
        help="minimum size, in bytes, to begin the sweep at",
    )  # COMMS mode only
    parser.add_argument(
        "--e",
        type=str,
        default="64",
        help="maximum size, in bytes, to end the sweep at",
    )  # COMMS mode only
    parser.add_argument(
        "--f",
        type=int,
        default=2,
        help="multiplication factor between sizes",
    )  # COMMS mode only
    parser.add_argument(
        "--z",
        type=int,
        default=1,
        help="use blocking mode for collectives",
    )  # 'sync/blocking' : 1, 'async/non-blocking' : 0
    parser.add_argument(
        "--w",
        type=int,
        default=5,
        help="number of warmup iterations",
    )
    parser.add_argument(
        "--n",
        type=int,
        default=5,
        help="number of iterations",
    )
    parser.add_argument(
        "--collective",
        type=str,
        default="all_reduce",
        help="collective to benchmark, supports " + str(supportedCollectives),
    )
    parser.add_argument(
        "--master-ip",
        type=str,
        default="127.0.0.1",
        help="the master IP used to coordinate the process group",
    )
    parser.add_argument(
        "--master-port",
        type=str,
        default="29500",
        help="the master port used to coordinate the process group",
    )
    parser.add_argument(
        "--nw-stack",
        type=str,
        default="pytorch-nccl",
        help="network stack to be used, supports " + str(supportedNwstacks),
    )
    parser.add_argument(
        "--dtype",
        type=torch.dtype,
        default=torch.float32,
    )  # will be overwritten based on args.data_type and dtypeMap
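    # Illustration (assumed from the flag semantics above, not stated in the
    # source): with the defaults --b 8 --e 64 --f 2, the sweep visits message
    # sizes 8, 16, 32, and 64 bytes; per-collective element counts are
    # presumably size / element_size (4 bytes for float32, 2 for float16).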
    parser.add_argument(
        "--data-type",
        type=str,
        default="float32",
        help="the base data type, supports " + str(supportedDtype),
    )
    parser.add_argument(
        "--num-tpu-cores",
        type=int,
        default=1,
        help="number of TPU cores to be used",
    )

    # For comm-compute or compute mode
    parser.add_argument(
        "--kernel",
        type=str,
        default="gemm",
        help='compute kernel, e.g. "gemm"',
    )
    parser.add_argument(
        "--num-compute",
        type=int,
        default=100,
        help="launch one collective for every NUM_COMPUTE compute kernels",
    )
    # For GEMM
    parser.add_argument(
        "--mm-dim",
        type=int,
        default=100,
        help="dimension size for the GEMM compute kernel",
    )  # matrix multiplication dim n: A[n, n] * B[n, n]
    # For embedding lookups
    parser.add_argument(
        "--emb-dim",
        type=int,
        default=128,
        help="dimension size for the embedding-table compute kernel",
    )
    parser.add_argument(
        "--num-embs",
        type=int,
        default=100000,
        help="embedding-table hash size for the embedding-table compute kernel",
    )
    parser.add_argument(
        "--avg-len",
        type=int,
        default=28,
        help="average number of lookup operations per sample",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=512,
        help="number of samples reading the table concurrently",
    )
    parser.add_argument(
        "--root",
        type=int,
        default=0,
        help="root process for the reduce benchmark",
    )  # also for gather, scatter, bcast, etc., if supported in the future
    # TODO: check the correctness of root; it should be between 0 and world_size - 1

    args, leftovers = parser.parse_known_args()

    args.b = comms_utils.parsesize(args.b)
    args.e = comms_utils.parsesize(args.e)

    if args.nw_stack not in supportedNwstacks:
        print(
            "\t ERROR: Specified nw-stack: %s is not one of the supported network stacks: %s. Make sure the input uses the correct case."
            % (args.nw_stack, str(supportedNwstacks))
        )
        sys.exit()  # WARNING: assumes sys is always available; a platform-independent graceful exit would be preferable

    if args.collective not in supportedCollectives:
        print(
            "\t ERROR: Specified collective: %s is not one of the supported collectives: %s. Make sure the input uses the correct case."
            % (args.collective, str(supportedCollectives))
        )
        sys.exit()  # WARNING: assumes sys is always available; a platform-independent graceful exit would be preferable

    if args.z not in supportedCommStyle:
        print(
            "\t ERROR: Specified blocking mode: %d is not one of the supported comm styles: %s"
            % (args.z, str(supportedCommStyle))
        )
        comms_utils.gracefulExit()

    if args.data_type not in supportedDtype:
        print(
            "\t ERROR: Specified data type: %s is not one of the supported data types: %s"
            % (args.data_type, str(supportedDtype))
        )
        comms_utils.gracefulExit()

    args.dtype = dtypeMap[args.data_type]

    mpi_env_params = comms_utils.read_mpi_env_vars()
    if mpi_env_params["global_rank"] == 0:
        print("\t MPI environment: %s " % (str(mpi_env_params)))
        print(
            "\t backend: %s nw-stack: %s mode: %s args.b: %d args.e: %d args.f: %d args.z: %s args.master_ip: %s "
            % (
                args.backend,
                args.nw_stack,
                args.mode,
                args.b,
                args.e,
                args.f,
                args.z,
                args.master_ip,
            )
        )

    if args.num_tpu_cores not in supported_tpu_core_values:
        print(
            "\t ERROR: TPU core value: %d is not one of the supported values: %s "
            % (args.num_tpu_cores, supported_tpu_core_values)
        )
        comms_utils.gracefulExit()

    if args.b < 1:
        print("\t Starting size: %d should be at least 1! " % (args.b))
        args.b = 1

    element_size = torch.ones([1], dtype=args.dtype).element_size()
    comms_world_info = comms_utils.comms_world_info_holder(
        args.master_ip, args.master_port, args.num_tpu_cores, mpi_env_params
    )
    # benchTime and runComms are module-level helpers assumed to be defined
    # elsewhere in this file.
    commsParams = comms_utils.commsParamsHolder(args, element_size, benchTime)
    runComms(comms_world_info, commsParams)