예제 #1
0
def launch_command(args):
    """Complete ``args`` from the config file when one applies, then launch.

    A config is loaded when ``args.config_file`` is given, or when
    ``default_config_file`` exists on disk and ``args.cpu`` is not set.
    Options the user left unset are then filled from that config; without a
    config, ``num_processes`` falls back to 1. ``args`` is modified in place
    and handed to exactly one launcher.

    Args:
        args: Parsed command-line namespace. Reads ``multi_gpu``, ``tpu``,
            ``cpu``, ``config_file``, ``num_processes``, ``fp16`` and
            ``main_training_function``.

    Raises:
        ValueError: If ``multi_gpu`` and ``tpu`` are both requested.
    """
    if args.multi_gpu and args.tpu:
        raise ValueError(
            "You can only pick one between `--multi_gpu` and `--tpu`.")

    # `and` binds tighter than `or`: `--cpu` only disables the implicit
    # default config file, never an explicitly passed `--config_file`.
    use_config = args.config_file is not None or (
        os.path.isfile(default_config_file) and not args.cpu)

    defaults = None
    if not use_config:
        if args.num_processes is None:
            args.num_processes = 1
    else:
        defaults = load_config_from_file(args.config_file)

        # Only adopt the configured back-end when none was forced on the CLI.
        if not (args.multi_gpu or args.tpu):
            backend = defaults.distributed_type
            args.multi_gpu = backend == DistributedType.MULTI_GPU
            args.tpu = backend == DistributedType.TPU

        # Process count and entry point are only taken from configs that
        # target the local machine.
        is_local = defaults.compute_environment == ComputeEnvironment.LOCAL_MACHINE
        if is_local and args.num_processes is None:
            args.num_processes = defaults.num_processes

        args.fp16 = args.fp16 or defaults.fp16

        if is_local and args.main_training_function is None:
            args.main_training_function = defaults.main_training_function

    # Pick exactly one launcher; `--cpu` rules out the multi-GPU and TPU ones.
    if not args.cpu and args.multi_gpu:
        multi_gpu_launcher(args)
    elif not args.cpu and args.tpu:
        tpu_launcher(args)
    elif defaults is not None and defaults.compute_environment == ComputeEnvironment.AMAZON_SAGEMAKER:
        sagemaker_launcher(defaults, args)
    else:
        simple_launcher(args)
예제 #2
0
def launch_command(args):
    """Resolve launch options from the CLI and config file, then dispatch.

    A config is loaded when ``args.config_file`` is given, or when
    ``default_config_file`` exists on disk and ``args.cpu`` is not set. If no
    distributed back-end was chosen on the command line, the config's
    ``distributed_type`` selects one. For configs whose compute environment is
    the local machine, every config field the user left as ``None`` on
    ``args`` is copied over; dict-valued config sections are flattened onto
    ``args`` key by key. ``args`` is modified in place and handed to exactly
    one launcher.

    Args:
        args: Parsed command-line namespace. Reads ``multi_gpu``, ``tpu``,
            ``use_deepspeed``, ``cpu``, ``config_file``, ``num_processes``
            and ``fp16``; may gain further attributes from the config.

    Raises:
        ValueError: If more than one of ``multi_gpu``, ``tpu`` and
            ``use_deepspeed`` is requested.
    """
    # Sanity checks: the distributed back-ends are mutually exclusive.
    if sum([args.multi_gpu, args.tpu, args.use_deepspeed]) > 1:
        raise ValueError(
            "You can only pick one between `--multi_gpu`, `--use_deepspeed`, `--tpu`."
        )

    defaults = None
    # Get the default from the config file. `and` binds tighter than `or`:
    # `--cpu` only disables the implicit default config file, never an
    # explicitly passed `--config_file`.
    if args.config_file is not None or (
            os.path.isfile(default_config_file) and not args.cpu):
        defaults = load_config_from_file(args.config_file)
        if not args.multi_gpu and not args.tpu and not args.use_deepspeed:
            args.use_deepspeed = defaults.distributed_type == DistributedType.DEEPSPEED
            args.multi_gpu = defaults.distributed_type == DistributedType.MULTI_GPU
            args.tpu = defaults.distributed_type == DistributedType.TPU
        if defaults.compute_environment == ComputeEnvironment.LOCAL_MACHINE:
            # Update args with the defaults; values already set on the CLI
            # (anything not None) take priority.
            for name, attr in defaults.__dict__.items():
                if isinstance(attr, dict):
                    # Flatten the nested section actually being visited (e.g.
                    # the DeepSpeed settings). Use a None default so a config
                    # key without a matching CLI option cannot raise.
                    for key, value in attr.items():
                        if getattr(args, key, None) is None:
                            setattr(args, key, value)
                    continue

                # Those args are handled separately
                if (name not in ("compute_environment", "fp16", "distributed_type")
                        and getattr(args, name, None) is None):
                    setattr(args, name, attr)

        if not args.fp16:
            args.fp16 = defaults.fp16
    elif args.num_processes is None:
        args.num_processes = 1

    # Use the proper launcher; `--cpu` rules out DeepSpeed, multi-GPU and TPU.
    if args.use_deepspeed and not args.cpu:
        deepspeed_launcher(args)
    elif args.multi_gpu and not args.cpu:
        multi_gpu_launcher(args)
    elif args.tpu and not args.cpu:
        tpu_launcher(args)
    elif defaults is not None and defaults.compute_environment == ComputeEnvironment.AMAZON_SAGEMAKER:
        sagemaker_launcher(defaults, args)
    else:
        simple_launcher(args)