Example #1
    def checkArgs(self, args):
        super().checkArgs(args)
        if args.collective not in supportedCollectives:
            print(
                "\t ERROR: Specified collective: %s is not one of the supported collectives: %s. Make sure the input is using the correct case."
                % (args.collective, str(supportedCollectives))
            )
            comms_utils.gracefulExit()

        if args.z not in supportedCommStyle:
            print(
                "\t ERROR: Specified blocking: %d is not one of the supported commstyle: %s"
                % (args.z, str(supportedCommStyle))
            )
            comms_utils.gracefulExit()

        args.b = comms_utils.parsesize(args.b)
        args.e = comms_utils.parsesize(args.e)
        args.dtype = self.dtypeMap[args.data_type]

        if args.b < 1:
            print("\t Starting size: %d should atleast be 1! " % (args.b))
            args.b = 1

        if args.e < args.b:
            print(
                "\t ERROR: In COMMS-mode, the begin-size: %d is larger than the end-size: %d "
                % (args.b, args.e)
            )

        if args.device == "cpu" and args.backend == "nccl":
            raise ValueError("NCCL is not supported for device type CPU")
Example #2
    def checkArgs(self, args):
        super().checkArgs(args)

        if args.pt2pt is not None:
            args.collective = "pt2pt"
            if args.pt2pt not in pt2ptPatterns:
                print(
                    "\t ERROR: Specified pt2pt pattern: %d is not one of the supported pt2pt patterns: %s"
                    % (args.pt2pt, str(pt2ptPatterns)))
                comms_utils.gracefulExit()

        args.b = comms_utils.parsesize(args.b)
        args.e = comms_utils.parsesize(args.e)
        args.dtype = self.dtypeMap[args.data_type]

        if args.b < 1:
            print("\t Starting size: %d should atleast be 1! " % (args.b))
            args.b = 1

        if args.e < args.b:
            print(
                "\t ERROR: In COMMS-mode, the begin-size: %d is larger than the end-size: %d "
                % (args.b, args.e))

        if args.device == "cpu" and args.backend == "nccl":
            raise ValueError("NCCL is not supported for device type CPU")

        if args.c == 1 and args.z == 0:
            logging.warning(
                "Data validation is not supported for non-blocking mode...disable validation check and proceed..."
            )
            args.c = 0
Example #3
    def checkArgs(self, args):
        super().checkArgs(args)

        if args.pt2pt is not None:
            args.collective = "pt2pt"
            if args.pt2pt not in pt2ptPatterns:
                logger.error(
                    f"Specified pt2pt pattern: {args.pt2pt} is not one of the supported pt2pt patterns: {str(pt2ptPatterns)}"
                )
                comms_utils.gracefulExit()

        args.b = comms_utils.parsesize(args.b)
        args.e = comms_utils.parsesize(args.e)
        args.dtype = self.dtypeMap[args.data_type]

        if args.b < 1:
            logger.warning(
                f"Starting size (--b {args.b}) should be greater than 1 byte...fix and continue"
            )
            args.b = 1

        if args.e < args.b:
            logger.warning(
                f"the begin-size (--b {args.b}) is larger than the end-size (--e {args.e})"
            )

        if args.device == "cpu" and args.backend == "nccl":
            raise ValueError(
                f"NCCL is not supported for device type {args.device}")

        if args.c == 1 and args.z == 0 and args.collective in (
                "all_reduce", "reduce", "reduce_scatter"):
            logger.warning(
                f"Data validation is not supported for {args.collective} in non-blocking mode, disabled and continue"
            )
            args.c = 0

        # run a few sanity checks
        if args.bitwidth < 32:
            if args.device != "cuda":
                logger.error(
                    f"collective quantization may not be fully supported for {args.device}"
                )
            comms_utils.checkQuantArgs(
                args.collective,
                args.dtype,
                args.b,
                args.quant_a2a_embedding_dim,
                args.z,
            )
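
Example #3 also adds a quantization gate: when --bitwidth is below 32, quantized collectives are only expected to be fully supported on CUDA devices. A tiny illustrative distillation of that condition (the helper name is hypothetical):

def quantization_supported(bitwidth, device):
    # Per the warning above: sub-32-bit collectives are only expected
    # to be fully supported on "cuda" devices.
    return bitwidth >= 32 or device == "cuda"

assert quantization_supported(32, "cpu")
assert quantization_supported(16, "cuda")
assert not quantization_supported(16, "cpu")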
Example #4
def main():
    ### import packages ###
    import sys
    import argparse

    import torch  # used below for dtypeMap and element_size computation
    import comms_utils  # size parsing / exit helpers used throughout (assumed importable here)

    supportedCommStyle = [0, 1]  # 0 : non-blocking, 1 : blocking.
    supportedCollectives = [
        "reduce",
        "all_reduce",
        "all_to_all",
        "all_to_allv",
    ]  # , "scatter", "gather", "all_gather", "broadcast", "all_to_all"]
    supportedNwstacks = ["pytorch-nccl", "pytorch-xla-tpu"]
    supported_tpu_core_values = [1, 8]
    dtypeMap = {
        "float32": torch.float32,
        "int32": torch.int32,
        "float16": torch.half,
        "float64": torch.double,
    }
    supportedDtype = list(dtypeMap.keys())

    ### parse arguments ###
    parser = argparse.ArgumentParser(
        description="PARAM-Comm Benchmark",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    # experiment related parameters
    parser.add_argument(
        "--backend",
        type=str,
        default="nccl",
        help="The backend to be used in PyTorch distributed process group"
    )
    parser.add_argument(
        "--mode", type=str, default="comms",
        help="benchmark mode")  # alternative is DLRM mode or comm-compute mode
    parser.add_argument("--b",
                        type=str,
                        default="8",
                        help="minimum size, in bytes, to start with"
                        )  # COMMS mode, begin the sweep at.
    parser.add_argument("--e",
                        type=str,
                        default="64",
                        help="maximum size, in bytes, to end at"
                        )  # COMMS mode, end the sweep at.
    parser.add_argument("--f",
                        type=int,
                        default=2,
                        help="multiplication factor between sizes"
                        )  # COMMS mode, multiplication factor.
    parser.add_argument("--z",
                        type=int,
                        default=1,
                        help="use blocking mode for collectives"
                        )  # 'sync/blocking' : 1 , 'async/non-blocking' : 0

    parser.add_argument(
        "--w", type=int, default=5,
        help="number of warmup iterations")  # number of warmup-iterations
    parser.add_argument("--n",
                        type=int,
                        default=5,
                        help="number of iterations")  # number of iterations
    parser.add_argument(
        "--collective",
        type=str,
        default="all_reduce",
        help='Collective to benchmark, supports ' +
        str(supportedCollectives))  # collective op to benchmark
    parser.add_argument(
        "--master-ip",
        type=str,
        default="127.0.0.1",
        help="The master-IP to coordinate")  # The master-IP to coordinate.
    parser.add_argument(
        "--master-port",
        type=str,
        default="29500",
        help="The master-port to coordinate")  # The master-port to coordinate.
    parser.add_argument(
        "--nw-stack",
        type=str,
        default="pytorch-nccl",
        help="network stack to be used, supports " +
        str(supportedNwstacks))  # The network stack to profile.
    parser.add_argument(
        "--dtype", type=torch.dtype, default=torch.float32
    )  # will be overwritten based on args.data_type and dtypeMap.
    parser.add_argument("--data-type",
                        type=str,
                        default="float32",
                        help="the base data type, supports " +
                        str(supportedDtype))  # The data type

    parser.add_argument(
        "--num-tpu-cores",
        type=int,
        default=1,
        help="number of TPU cores to be used")  # number of TPU cores

    # For comm-compute or compute mode
    parser.add_argument("--kernel",
                        type=str,
                        default="gemm",
                        help="compute kernel")  # Compute kernel: "gemm"
    parser.add_argument(
        "--num-compute",
        type=int,
        default=100,
        help="one collective for every NUM_COMPUTE compute kernels"
    )  # Launch one coll for every n compute kernels
    # For GEMM
    parser.add_argument("--mm-dim",
                        type=int,
                        default=100,
                        help="dimension size for GEMM compute kernel"
                        )  # Matrix multiplication dim n, A[n,n] * B [n,n]
    # For emb lookup
    parser.add_argument(
        "--emb-dim",
        type=int,
        default=128,
        help="dimension size for Embedding table compute kernel"
    )  # Embedding table dimension
    parser.add_argument(
        "--num-embs",
        type=int,
        default=100000,
        help="Embedding table hash size for Embedding table compute kernel"
    )  # Embedding table hash size
    parser.add_argument("--avg-len",
                        type=int,
                        default=28,
                        help="Average lookup operations per sample"
                        )  # Average #lookup per sample
    parser.add_argument("--batch-size",
                        type=int,
                        default=512,
                        help="number of samples reading the table concurrently"
                        )  # #Samples reading the table concurrently
    parser.add_argument(
        "--root",
        type=int,
        default=0,
        help="root process for reduce benchmark"
    )  # root process for reduce (and gather, scatter, bcast, etc., if support in the future)
    # TODO: check the correctness of root, should be between 0 to [world_size -1]

    args, leftovers = parser.parse_known_args()
    args.b = comms_utils.parsesize(args.b)
    args.e = comms_utils.parsesize(args.e)

    if args.nw_stack not in supportedNwstacks:
        print(
            "\t ERROR: Specified backend: %s is not one of the supported backends: %s. Make sure the input is using the correct case."
            % (args.nw_stack, str(supportedNwstacks)))
        sys.exit()  # WARNING: Assuming sys is always used, should find a platform-independent way to gracefully exit.

    if args.collective not in supportedCollectives:
        print(
            "\t ERROR: Specified collective: %s is not one of the supported collectives: %s. Make sure the input is using the correct case."
            % (args.collective, str(supportedCollectives)))
        sys.exit()  # WARNING: Assuming sys is always used, should find a platform-independent way to gracefully exit.

    if args.z not in supportedCommStyle:
        print(
            "\t ERROR: Specified blocking: %d is not one of the supported commstyle: %s"
            % (args.z, str(supportedCommStyle)))
        comms_utils.gracefulExit()

    if args.data_type not in supportedDtype:
        print(
            "\t ERROR: Specified dtype: %d is not one of the supported commstyle: %s"
            % (args.data_type, str(supportedDtype)))
        comms_utils.gracefulExit()

    args.dtype = dtypeMap[args.data_type]

    mpi_env_params = comms_utils.read_mpi_env_vars()
    if mpi_env_params["global_rank"] == 0:
        print("\t MPI environment: %s " % (str(mpi_env_params)))
        print(
            "\t backend: %s nw-stack: %s mode: %s args.b: %d args.e: %d args.f: %d args.z: %s args.master_ip: %s "
            % (
                args.backend,
                args.nw_stack,
                args.mode,
                args.b,
                args.e,
                args.f,
                args.z,
                args.master_ip,
            ))

    if args.num_tpu_cores not in supported_tpu_core_values:
        print(
            "\t ERROR: TPU core value: %d is not one of the supported values: %s "
            % (args.num_tpu_cores, supported_tpu_core_values))
        comms_utils.gracefulExit()

    if args.b < 1:
        print("\t Starting size: %d should atleast be 1! " % (args.b))
        args.b = 1

    element_size = torch.ones([1], dtype=args.dtype).element_size()
    comms_world_info = comms_utils.comms_world_info_holder(
        args.master_ip, args.master_port, args.num_tpu_cores, mpi_env_params)
    # benchTime and runComms are defined elsewhere in the benchmark module (not shown in this excerpt).
    commsParams = comms_utils.commsParamsHolder(args, element_size, benchTime)
    runComms(comms_world_info, commsParams)
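
The --b, --e, and --f flags define a geometric sweep of message sizes from the begin size to the end size. The loop below is illustrative only, not the benchmark's actual code, and shows the sweep implied by the defaults:

# Illustrative sweep, assuming --b 8 --e 64 --f 2 (the defaults, after parsesize).
b, e, f = 8, 64, 2
sizes = []
cur = b
while cur <= e:
    sizes.append(cur)
    cur *= f
print(sizes)  # [8, 16, 32, 64]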
Example #5
 def test_single_size(self):
     sizeStr = "1024"
     size = comms_utils.parsesize(sizeStr)
     self.assertEqual(1024, size)
Example #6
 def test_kb_size(self):
     sizeStr = "5KB"
     size = comms_utils.parsesize(sizeStr)
     self.assertEqual(5120, size)
Example #7
 def test_mb_size(self):
     sizeStr = "3MB"
     size = comms_utils.parsesize(sizeStr)
     self.assertEqual(3145728, size)
Example #8
 def test_gb_size(self):
     sizeStr = "2GB"
     size = comms_utils.parsesize(sizeStr)
     # size is in bytes
     self.assertEqual(2147483648, size)
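
The four tests above pin down parsesize's contract: bare numbers are bytes, and the KB/MB/GB suffixes are binary (1024-based) multipliers. A minimal sketch consistent with those expectations follows; the real comms_utils.parsesize may differ in details:

def parsesize(size_str):
    # Bare numbers are bytes; KB/MB/GB are 1024-based multipliers.
    units = {"KB": 1024, "MB": 1024 ** 2, "GB": 1024 ** 3}
    s = size_str.strip().upper()
    for suffix, mult in units.items():
        if s.endswith(suffix):
            return int(s[: -len(suffix)]) * mult
    return int(s)

assert parsesize("1024") == 1024
assert parsesize("5KB") == 5120
assert parsesize("3MB") == 3145728
assert parsesize("2GB") == 2147483648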