Beispiel #1
0
def main():
    """Drive the PARAM-Comm collective benchmark from the command line.

    Creates a ``commsCollBench`` object, lets it read and check the
    command-line arguments, prints a configuration summary when the MPI
    environment reports global rank 0, and then passes the world-info and
    parameter holders to ``runBench``.
    """
    bench = commsCollBench()

    ### parse arguments ###
    arg_parser = argparse.ArgumentParser(
        description="PARAM-Comm Benchmark",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # readArgs returns a pair; only the first item is used here.
    args, _ = bench.readArgs(arg_parser)
    bench.checkArgs(args)

    mpi_env = comms_utils.read_mpi_env_vars()
    on_rank_zero = mpi_env["global_rank"] == 0
    if on_rank_zero:
        # Only one rank prints the configuration summary.
        print("\t MPI environment: %s " % str(mpi_env))
        summary_values = (
            args.backend,
            args.nw_stack,
            args.mode,
            args.b,
            args.e,
            args.f,
            args.z,
            args.master_ip,
        )
        print(
            "\t backend: %s nw-stack: %s mode: %s args.b: %d args.e: %d args.f: %d args.z: %s args.master_ip: %s "
            % summary_values
        )

    # Size in bytes of a single element of the selected dtype.
    probe = torch.ones([1], dtype=args.dtype)
    element_size = probe.element_size()

    world_info = comms_utils.comms_world_info_holder(
        args.master_ip,
        args.master_port,
        args.num_tpu_cores,
        mpi_env,
    )
    params = comms_utils.commsParamsHolder(
        args,
        element_size,
        bench.benchTime,
    )
    bench.runBench(world_info, params)
def main():
    """Drive the PARAM-Comms trace replay mode from the command line.

    Reads the MPI environment, creates a ``commsTraceReplayBench`` object,
    lets it read the arguments, set the trace file and check the arguments,
    then initializes and runs the replay with the world-info and parameter
    holders built from those arguments.
    """
    mpi_env = comms_utils.read_mpi_env_vars()

    replay_bench = commsTraceReplayBench()
    arg_parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="PARAM-Comms Trace Replay Mode",
    )

    args = replay_bench.readArgs(arg_parser)
    # The trace file is set before the arguments are checked.
    replay_bench.setTraceFile(args, mpi_env)
    replay_bench.checkArgs(args)

    # NOTE(review): fixed one-second pause before setup; its purpose is not
    # visible from this function.
    time.sleep(1)

    world_info = comms_utils.comms_world_info_holder(
        args.master_ip,
        args.master_port,
        args.num_tpu_cores,
        mpi_env,
    )
    params = comms_utils.commsParamsHolderBase(args)

    replay_bench.initBench(world_info, params, args)
    replay_bench.runBench(world_info, params)
Beispiel #3
0
def main():
    """Parse the command line, validate it, and launch the comms benchmark.

    Options are read with ``argparse``; options the parser does not know are
    accepted and ignored (``parse_known_args``). The ``--b`` / ``--e`` size
    strings are converted with ``comms_utils.parsesize``. The network stack,
    collective, blocking mode, data type and TPU core count are checked
    against the supported values listed below, and an error is printed
    before exiting when a value is not supported. A configuration summary is
    printed when the MPI environment reports global rank 0, and the
    benchmark is started through ``runComms``.
    """
    ### import packages ###
    import sys
    import argparse

    supportedCommStyle = [0, 1]  # 0 : non-blocking, 1 : blocking.
    # Not (yet) accepted here: "scatter", "gather", "all_gather", "broadcast".
    supportedCollectives = [
        "reduce",
        "all_reduce",
        "all_to_all",
        "all_to_allv",
    ]
    supportedNwstacks = ["pytorch-nccl", "pytorch-xla-tpu"]
    supported_tpu_core_values = [1, 8]
    # Maps the --data-type string to the torch dtype stored in args.dtype.
    dtypeMap = {
        "float32": torch.float32,
        "int32": torch.int32,
        "float16": torch.half,
        "float64": torch.double,
    }
    supportedDtype = list(dtypeMap.keys())

    ### parse arguments ###
    parser = argparse.ArgumentParser(
        description="PARAM-Comm Benchmark",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # experiment related parameters
    parser.add_argument(
        "--backend",
        type=str,
        default="nccl",
        help="The backend to be used in PyTorch distributed process group",
    )
    parser.add_argument(
        "--mode",
        type=str,
        default="comms",
        help="benchmark mode",
    )  # alternative is DLRM mode or comm-compute mode
    parser.add_argument(
        "--b",
        type=str,
        default="8",
        help="minimum size, in bytes, to start with",
    )  # COMMS mode, begin the sweep at.
    parser.add_argument(
        "--e",
        type=str,
        default="64",
        help="maximum size, in bytes, to end at",
    )  # COMMS mode, end the sweep at.
    parser.add_argument(
        "--f",
        type=int,
        default=2,
        help="multiplication factor between sizes",
    )  # COMMS mode, multiplication factor.
    parser.add_argument(
        "--z",
        type=int,
        default=1,
        help="use blocking mode for collectives",
    )  # 'sync/blocking' : 1 , 'async/non-blocking' : 0

    parser.add_argument(
        "--w",
        type=int,
        default=5,
        help="number of warmup iterations",
    )
    parser.add_argument(
        "--n",
        type=int,
        default=5,
        help="number of iterations",
    )
    parser.add_argument(
        "--collective",
        type=str,
        default="all_reduce",
        help='Collective to benchmark, supports ' + str(supportedCollectives),
    )
    parser.add_argument(
        "--master-ip",
        type=str,
        default="127.0.0.1",
        help="The master-IP to coordinate",
    )
    parser.add_argument(
        "--master-port",
        type=str,
        default="29500",
        help="The master-port to coordinate",
    )
    parser.add_argument(
        "--nw-stack",
        type=str,
        default="pytorch-nccl",
        help="network stack to be used, supports " + str(supportedNwstacks),
    )
    parser.add_argument(
        "--dtype",
        type=torch.dtype,
        default=torch.float32,
    )  # will be overwritten based on args.data_type and dtypeMap.
    parser.add_argument(
        "--data-type",
        type=str,
        default="float32",
        help="the base data type, supports " + str(supportedDtype),
    )

    parser.add_argument(
        "--num-tpu-cores",
        type=int,
        default=1,
        help="number of TPU cores to be used",
    )

    # For comm-compute or compute mode
    parser.add_argument(
        "--kernel",
        type=str,
        default="gemm",
        help="compute kernel",
    )  # Compute kernel: "gemm"
    parser.add_argument(
        "--num-compute",
        type=int,
        default=100,
        help="one collective for every NUM_COMPUTE compute kernels",
    )
    # For GEMM
    parser.add_argument(
        "--mm-dim",
        type=int,
        default=100,
        help="dimension size for GEMM compute kernel",
    )  # Matrix multiplication dim n, A[n,n] * B [n,n]
    # For emb lookup
    parser.add_argument(
        "--emb-dim",
        type=int,
        default=128,
        help="dimension size for Embedding table compute kernel",
    )
    parser.add_argument(
        "--num-embs",
        type=int,
        default=100000,
        help="Embedding table hash size for Embedding table compute kernel",
    )
    parser.add_argument(
        "--avg-len",
        type=int,
        default=28,
        help="Average lookup operations per sample",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=512,
        help="number of samples reading the table concurrently",
    )
    parser.add_argument(
        "--root",
        type=int,
        default=0,
        help="root process for reduce benchmark",
    )  # root process for reduce (and gather, scatter, bcast, etc., if support in the future)
    # TODO: check the correctness of root, should be between 0 to [world_size -1]

    # Unknown options are tolerated; the leftover list is not used.
    args, _ = parser.parse_known_args()
    # --b / --e arrive as strings and are converted by comms_utils.parsesize.
    args.b = comms_utils.parsesize(args.b)
    args.e = comms_utils.parsesize(args.e)

    if args.nw_stack not in supportedNwstacks:
        # The message names the nw-stack explicitly: "--backend" is a
        # different option, so calling this value a "backend" was misleading.
        print(
            "\t ERROR: Specified nw-stack: %s is not one of the supported nw-stacks: %s. Make sure the input is using the correct case."
            % (args.nw_stack, str(supportedNwstacks)))
        # NOTE(review): sys.exit() with no argument exits with status 0 even
        # though this is an error path; left unchanged here.
        sys.exit()

    if args.collective not in supportedCollectives:
        print(
            "\t ERROR: Specified collective: %s is not one of the supported collectives: %s. Make sure the input is using the correct case."
            % (args.collective, str(supportedCollectives)))
        # NOTE(review): same status-0 exit as above.
        sys.exit()

    if args.z not in supportedCommStyle:
        print(
            "\t ERROR: Specified blocking: %d is not one of the supported commstyle: %s"
            % (args.z, str(supportedCommStyle)))
        comms_utils.gracefulExit()

    if args.data_type not in supportedDtype:
        # args.data_type is a string, so it must be formatted with %s;
        # the previous %d raised TypeError instead of printing this error.
        print(
            "\t ERROR: Specified dtype: %s is not one of the supported dtypes: %s"
            % (args.data_type, str(supportedDtype)))
        comms_utils.gracefulExit()

    # Replace the placeholder --dtype value with the dtype named by --data-type.
    args.dtype = dtypeMap[args.data_type]

    mpi_env_params = comms_utils.read_mpi_env_vars()
    if mpi_env_params["global_rank"] == 0:
        # Only one rank prints the configuration summary.
        print("\t MPI environment: %s " % (str(mpi_env_params)))
        print(
            "\t backend: %s nw-stack: %s mode: %s args.b: %d args.e: %d args.f: %d args.z: %s args.master_ip: %s "
            % (
                args.backend,
                args.nw_stack,
                args.mode,
                args.b,
                args.e,
                args.f,
                args.z,
                args.master_ip,
            ))

    if args.num_tpu_cores not in supported_tpu_core_values:
        print(
            "\t ERROR: TPU core value: %d is not one of the supported values: %s "
            % (args.num_tpu_cores, supported_tpu_core_values))
        comms_utils.gracefulExit()

    if args.b < 1:
        # A start size below 1 is reported and then clamped to 1.
        print("\t Starting size: %d should atleast be 1! " % (args.b))
        args.b = 1

    # Size in bytes of a single element of the selected dtype.
    element_size = torch.ones([1], dtype=args.dtype).element_size()
    comms_world_info = comms_utils.comms_world_info_holder(
        args.master_ip, args.master_port, args.num_tpu_cores, mpi_env_params)
    commsParams = comms_utils.commsParamsHolder(args, element_size, benchTime)
    runComms(comms_world_info, commsParams)