def setup_training(args):
    """Prepare the CUDA device and validate batch/checkpoint settings.

    Side effects on ``args``: sets ``n_gpu`` to 1, may clear the two
    ``allreduce_post_accumulation*`` flags, and divides ``train_batch_size``
    by ``gradient_accumulation_steps``. The output directory is created by
    the main process when needed.

    Returns:
        A ``(device, args)`` tuple.

    Raises:
        AssertionError: if CUDA is not available.
        ValueError: if ``gradient_accumulation_steps`` is < 1, does not divide
            ``train_batch_size``, or a fresh run would write into a directory
            that already contains ``ckpt*`` entries.
    """
    assert torch.cuda.is_available()

    if args.smp > 0:
        # SMP reads its configuration from the parameters passed to the
        # SageMaker PyTorch estimator.
        smp.init()

    # SMP: bind this process to the GPU matching its local rank; input
    # tensors should be moved to this device.
    local_gpu = smp.local_rank()
    torch.cuda.set_device(local_gpu)
    device = torch.device("cuda", local_gpu)
    args.n_gpu = 1

    accum_steps = args.gradient_accumulation_steps
    if accum_steps == 1:
        # Nothing is accumulated, so post-accumulation allreduce is disabled.
        args.allreduce_post_accumulation = False
        args.allreduce_post_accumulation_fp16 = False

    summary = "device: {} n_gpu: {}, mp_rank: {}, rank: {}, distributed training: {}, 16-bits training: {}"
    print(
        summary.format(device, args.n_gpu, smp.mp_rank(), smp.rank(),
                       bool(args.local_rank != -1), args.fp16))

    if accum_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(accum_steps))
    if args.train_batch_size % accum_steps != 0:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, batch size {} should be divisible"
            .format(accum_steps, args.train_batch_size))

    # From here on train_batch_size is the per-accumulation-step batch size.
    args.train_batch_size //= accum_steps

    fresh_run = not args.resume_from_checkpoint
    output_dir = args.output_dir

    # A fresh run must not land in a directory that already holds checkpoints.
    if fresh_run and os.path.exists(output_dir):
        if any(entry.startswith("ckpt") for entry in os.listdir(output_dir)):
            raise ValueError(
                "Output directory ({}) already exists and is not empty.".format(
                    output_dir))

    if (fresh_run or not os.path.exists(output_dir)) and is_main_process():
        os.makedirs(output_dir, exist_ok=True)

    return device, args
Beispiel #2
0
        def wrapper(*args, **kwargs):
            """Call ``f`` and print how much CUDA memory allocation grew across the call.

            The return value of ``f`` is discarded; this wrapper returns ``None``.
            """
            import gc

            baseline = torch.cuda.memory_allocated(device=smp.local_rank())
            f(*args, **kwargs)

            # Three collection passes before the second reading.
            for _ in range(3):
                gc.collect()

            delta = torch.cuda.memory_allocated(device=smp.local_rank()) - baseline
            print(
                f"rank is {smp.local_rank()}, function name is {f.__name__}, memory usage is {humanize.naturalsize(delta)}"
            )
Beispiel #3
0
def measure_additional_mem_context():
    """Generator that reports the CUDA memory growth across its single ``yield``.

    Calls ``smp.barrier()`` before the first reading and after the report.
    Presumably wrapped by ``contextlib.contextmanager`` -- the decorator is
    not visible in this excerpt.
    """
    import gc

    smp.barrier()
    baseline = torch.cuda.memory_allocated(device=smp.local_rank())

    yield

    # Three collection passes before the second reading.
    for _ in range(3):
        gc.collect()

    delta = torch.cuda.memory_allocated(device=smp.local_rank()) - baseline
    print(
        f"rank is {smp.local_rank()}, memory usage is {humanize.naturalsize(delta)}"
    )
    smp.barrier()
Beispiel #4
0
def dist_setting(args):
    """Fill in rank / world-size fields on ``args`` and rescale lr and batch size.

    Three modes are handled, checked in this order:

    * ``args.data_parallel``: ranks and world size come from ``sdp``.
    * ``args.model_parallel``: ranks and sizes come from ``smp``.
    * otherwise: world size is ``len(args.hosts) * args.num_gpus`` and a
      ``torch.distributed``-style process group is initialised through ``dist``.

    After that ``args.lr`` is multiplied by the world size and
    ``args.batch_size`` is divided by ``world_size // num_gpus`` (never
    dropping below 1).

    Args:
        args: mutable namespace; must provide ``hosts``, ``current_host``,
            ``num_gpus``, ``lr``, ``batch_size`` and the mode flags.

    Returns:
        The same ``args`` object, mutated in place.
    """
    print("args.data_parallel : {}".format(args.data_parallel))
    print("args.model_parallel : {}".format(args.model_parallel))
    print("args.apex : {}".format(args.apex))

    args.world_size = 1
    args.host_num = args.hosts.index(args.current_host)

    if args.data_parallel:
        args.world_size = sdp.get_world_size()
        args.rank = sdp.get_rank()  # total rank in all hosts
        args.local_rank = sdp.get_local_rank()  # rank per host
    elif args.model_parallel:
        args.world_size = smp.size()
        args.local_rank = smp.local_rank()  # rank per host
        args.rank = smp.rank()
        args.dp_size = smp.dp_size()
        args.dp_rank = smp.dp_rank()
        print(
            "smp.rank() : {}, smp.size() : {}, smp.mp_rank() : {}, smp.local_size() : {}, smp.get_mp_group() : {}, smp.get_dp_group() : {}, smp.local_rank() : {}, smp.dp_size() : {}, smp.dp_rank() : {}"
            .format(smp.rank(), smp.size(), smp.mp_rank(), smp.local_size(),
                    smp.get_mp_group(), smp.get_dp_group(), smp.local_rank(),
                    smp.dp_size(), smp.dp_rank()))
    else:
        args.world_size = len(args.hosts) * args.num_gpus
        # NOTE(review): when args.local_rank is None, args.rank is not set
        # here and must already exist on args -- confirm against the caller.
        if args.local_rank is not None:
            args.rank = args.num_gpus * args.host_num + \
                args.local_rank  # total rank in all hosts

        dist.init_process_group(backend=args.backend,
                                rank=args.rank,
                                world_size=args.world_size)
        logger.info(
            'Initialized the distributed environment: \'{}\' backend on {} nodes. '
            .format(args.backend, dist.get_world_size()) +
            'Current host rank is {}. Number of gpus: {}'.format(
                dist.get_rank(), args.num_gpus))

    print("**** [dist_setting] args.rank : {}".format(args.rank))
    print("args.world_size : {}".format(args.world_size))
    print("Use GPU: {} for training".format(args.local_rank))

    args.lr = args.lr * float(args.world_size)

    # Clamp both the GPU count and the resulting divisor to >= 1 so that
    # num_gpus == 0 or world_size < num_gpus cannot raise ZeroDivisionError.
    gpus_per_host = max(args.num_gpus, 1)
    batch_divisor = max(args.world_size // gpus_per_host, 1)
    args.batch_size = max(args.batch_size // batch_divisor, 1)

    return args
Beispiel #5
0
def delete_oldest_ckpt(args, delete_on_rank0_only=False):
    """Remove the partial-checkpoint files of the oldest saved step, if at the limit.

    Files in ``args.checkpoint_dir`` whose names start with
    ``trained_gpt_nparams-<n>_steps-<s>.pt_<pp_rank>_<tp_rank>`` are grouped
    by step ``<s>``. When the number of distinct steps is at least
    ``args.num_kept_checkpoints``, every file of the smallest step is deleted.

    Args:
        args: namespace providing ``checkpoint_dir`` and ``num_kept_checkpoints``.
        delete_on_rank0_only: if true only the process with ``smp.rank() == 0``
            deletes; otherwise every process with ``smp.local_rank() == 0`` does.

    Returns:
        None.
    """
    if delete_on_rank0_only:
        should_delete = smp.rank() == 0
    else:
        should_delete = smp.local_rank() == 0
    if not should_delete:
        return None

    # Raw strings: "\d" and "\." are invalid escapes in a normal literal.
    # The second fragment matches the per-partition suffix of partial checkpoints.
    ckpt_pattern = re.compile(
        r"trained_gpt_nparams-(?P<num_params>\d+)_steps-(?P<total_steps>\d+)\.pt"
        r"_(?P<pp_rank>\d+)_(?P<tp_rank>\d+)"
    )

    paths_per_step = collections.defaultdict(list)
    for name in os.listdir(args.checkpoint_dir):
        matched = ckpt_pattern.match(name)
        if matched:
            step = int(matched.group("total_steps"))
            paths_per_step[step].append(os.path.join(args.checkpoint_dir, name))

    # Below the limit there is nothing to delete yet.
    if paths_per_step and len(paths_per_step) >= args.num_kept_checkpoints:
        # Drop the oldest step to make room for the new checkpoint.
        for path in paths_per_step[min(paths_per_step)]:
            os.remove(path)
    return None
Beispiel #6
0
    def _setup_devices(self) -> "torch.device":
        """Select the torch device for this process and record ``self._n_gpu``.

        Backends are probed in a fixed order: ``no_cuda``, TPU, SageMaker
        model parallel, SageMaker data parallel, DeepSpeed, single-process
        (``local_rank == -1``), and finally ``torch.distributed`` with NCCL.
        Several branches also initialise a process group and overwrite
        ``self.local_rank``. When the chosen device is CUDA it is made the
        current CUDA device before being returned.
        """
        logger.info("PyTorch: setting up devices")

        if self.no_cuda:
            device, gpu_count = torch.device("cpu"), 0
        elif is_torch_tpu_available():
            device, gpu_count = xm.xla_device(), 0
        elif is_sagemaker_mp_enabled():
            device, gpu_count = torch.device("cuda", smp.local_rank()), 1
        elif is_sagemaker_dp_enabled():
            sm_dist.init_process_group()
            self.local_rank = sm_dist.get_local_rank()
            device, gpu_count = torch.device("cuda", self.local_rank), 1
        elif self.deepspeed:
            # deepspeed performs its own DDP internally, and requires the program to be started with:
            # deepspeed  ./program.py
            # rather than:
            # python -m torch.distributed.launch --nproc_per_node=2 ./program.py
            from .integrations import is_deepspeed_available

            if not is_deepspeed_available():
                raise ImportError("--deepspeed requires deepspeed: `pip install deepspeed`.")
            import deepspeed

            deepspeed.init_distributed()

            # Workaround for setups such as notebooks where the launcher cannot be
            # used but deepspeed still needs a dist env: LOCAL_RANK may be set by
            # the user, or by init_distributed when mpi4py is installed.
            self.local_rank = int(os.environ.get("LOCAL_RANK", "-1"))
            device, gpu_count = torch.device("cuda", self.local_rank), 1
        elif self.local_rank == -1:
            # Single process: if more than one GPU is counted, nn.DataParallel is used.
            # "cuda:0" is explicit because `set_device` needs a device index; index 0
            # is relative to CUDA_VISIBLE_DEVICES, so `CUDA_VISIBLE_DEVICES=1,2`
            # with `cuda:0` selects physical GPU #1.
            device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
            device, gpu_count = torch.device(device_name), torch.cuda.device_count()
        else:
            # torch.distributed: the backend takes care of synchronizing nodes/GPUs.
            torch.distributed.init_process_group(backend="nccl")
            device, gpu_count = torch.device("cuda", self.local_rank), 1

        self._n_gpu = gpu_count

        if device.type == "cuda":
            torch.cuda.set_device(device)

        return device
Beispiel #7
0
def smp_init(args):
    """Initialise SMP from ``args`` and publish the resulting ranks.

    Builds the SMP config dict from ``args``, calls ``smp.init``, stores
    ``rank`` (data-parallel rank), ``global_rank``, ``world_size`` and
    ``local_rank`` on ``args``, exports ``RANK`` / ``WORLD_SIZE`` /
    ``LOCAL_RANK`` / ``SMP_SKIP_GRAPH_VALIDATION`` to the environment, and
    selects the CUDA device matching the local rank. ``args.deepspeed`` is
    forced to ``False``.

    NOTE: this function does not seed any RNG or change cuDNN flags.

    Returns:
        The same ``args`` object, mutated in place.
    """
    args.deepspeed = False

    # (SMP config key, attribute on args that supplies its value)
    option_sources = (
        ("microbatches", "num_microbatches"),
        ("placement_strategy", "placement_strategy"),
        ("pipeline", "pipeline"),
        ("optimize", "optimize"),
        ("partitions", "num_partitions"),
        ("horovod", "horovod"),
        ("ddp", "ddp"),
    )
    smp_config = {key: getattr(args, attr) for key, attr in option_sources}

    smp.init(smp_config)

    args.rank = smp.dp_rank()
    args.global_rank = smp.rank()
    args.world_size = smp.size()
    local_rank = smp.local_rank()

    os.environ.update({
        'RANK': str(args.rank),
        'WORLD_SIZE': str(args.world_size),
        'LOCAL_RANK': str(local_rank),
        'SMP_SKIP_GRAPH_VALIDATION': "0",
    })

    torch.cuda.set_device(local_rank)
    args.local_rank = local_rank

    return args
 def is_world_process_zero(self) -> bool:
     """
     Report whether this process is the global main process.

     With model parallelism enabled this requires the SMP global, local, model-parallel and data-parallel ranks to
     all be zero; otherwise the decision is delegated to the parent class.
     """
     if not self.is_model_parallel_enabled:
         return super().is_world_process_zero()

     return (
         smp.rank() == 0
         and smp.local_rank() == 0
         and smp.mp_rank() == 0
         and smp.dp_rank() == 0
     )
 def local_process_index(self):
     """
     Return the index of the local process, as reported by the active backend (TPU, SageMaker model parallel,
     SageMaker data parallel), falling back to ``self.local_rank``, or 0 when that is -1.
     """
     if is_torch_tpu_available():
         return xm.get_local_ordinal()
     if is_sagemaker_mp_enabled():
         return smp.local_rank()
     if is_sagemaker_dp_enabled():
         # NOTE(review): this is get_rank(), not get_local_rank() -- confirm that is intended for a *local* index.
         return sm_dist.get_rank()
     return self.local_rank if self.local_rank != -1 else 0
Beispiel #10
0
    def _setup_devices(self) -> "torch.device":
        """Select the torch device for this process and record ``self._n_gpu``.

        Warns when a ``torch.distributed`` process group already exists but
        ``local_rank`` is -1. Backends are then probed in a fixed order:
        ``no_cuda``, SageMaker model parallel, SageMaker data parallel
        (``smddp`` backend), single-process (``local_rank == -1``), and
        finally ``torch.distributed`` with NCCL. When the chosen device is
        CUDA it is made the current CUDA device before being returned.
        """
        logger.info("PyTorch: setting up devices")

        dist_ready = torch.distributed.is_available() and torch.distributed.is_initialized()
        if dist_ready and self.local_rank == -1:
            logger.warning(
                "torch.distributed process group is initialized, but local_rank == -1. "
                "In order to use Torch DDP, launch your script with `python -m torch.distributed.launch"
            )

        if self.no_cuda:
            device, gpu_count = torch.device("cpu"), 0
        elif is_sagemaker_model_parallel_available():
            device, gpu_count = torch.device("cuda", smp.local_rank()), 1
        elif is_sagemaker_dp_enabled():
            # Imported for its side effect only.
            import smdistributed.dataparallel.torch.torch_smddp  # noqa: F401

            torch.distributed.init_process_group(backend="smddp")
            self.local_rank = int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))
            device, gpu_count = torch.device("cuda", self.local_rank), 1
        elif self.local_rank == -1:
            # Single process: if more than one GPU is counted, nn.DataParallel is used.
            # "cuda:0" is explicit because `set_device` needs a device index; index 0
            # is relative to CUDA_VISIBLE_DEVICES, so `CUDA_VISIBLE_DEVICES=1,2`
            # with `cuda:0` selects physical GPU #1.
            device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
            device, gpu_count = torch.device(device_name), torch.cuda.device_count()
        else:
            # torch.distributed: initialise the NCCL backend unless a group already exists.
            if not torch.distributed.is_initialized():
                torch.distributed.init_process_group(backend="nccl")
            device, gpu_count = torch.device("cuda", self.local_rank), 1

        self._n_gpu = gpu_count

        if device.type == "cuda":
            torch.cuda.set_device(device)

        return device
def memory_status(msg="", reset_max=True, sync=True):
    """Print current and peak CUDA memory figures (in GiB) for this process's GPU.

    Only processes whose ``smp.rdp_rank()`` is 0 print; all others return
    right after the optional synchronisation. When the module-level
    ``py3nvml`` is not ``None`` the device-wide used memory reported by NVML
    is appended to the line.

    Args:
        msg: label placed in square brackets at the start of the line.
        reset_max: if true, reset the peak memory statistics of the reported
            device after printing.
        sync: if true, call ``torch.cuda.synchronize()`` before measuring.

    Returns:
        None.
    """
    rank = smp.rank()
    tp_rank = smp.tp_rank()
    pp_rank = smp.pp_rank()
    rdp_rank = smp.rdp_rank()
    local_rank = smp.local_rank()

    if sync:
        torch.cuda.synchronize()

    if rdp_rank != 0:
        return

    total_used_str = ""
    if py3nvml is not None:
        py3nvml.nvmlInit()
        try:
            handle = py3nvml.nvmlDeviceGetHandleByIndex(local_rank)
            info = py3nvml.nvmlDeviceGetMemoryInfo(handle)
        finally:
            # Close the NVML session even if a query fails.
            py3nvml.nvmlShutdown()
        total_used = info.used / 1024**3
        total_used_str = f"Totally used GPU memory: {total_used}"

    # Convert bytes to GiB for printing.
    gib = 1024**3
    alloced = torch.cuda.memory_allocated(device=local_rank) / gib
    max_alloced = torch.cuda.max_memory_allocated(device=local_rank) / gib
    cached = torch.cuda.memory_reserved(device=local_rank) / gib
    max_cached = torch.cuda.max_memory_reserved(device=local_rank) / gib

    print(
        f'[{msg}] rank {rank} tp_rank {tp_rank} pp_rank {pp_rank} TORCH {torch.__version__}',
        f'device={local_rank} '
        f'alloc {alloced:0.4f} max_alloced {max_alloced:0.4f} '
        f'cache {cached:0.4f} max_cached {max_cached:0.4f} '
        f'{total_used_str}')

    if reset_max:
        # Replaces the deprecated reset_max_memory_cached / reset_max_memory_allocated
        # pair; the reset targets the device whose peaks were just reported.
        torch.cuda.reset_peak_memory_stats(device=local_rank)
Beispiel #12
0
def dist_setting(args):
    """Fill in rank / world-size fields on ``args`` and rescale lr and batch size.

    Three modes are handled, checked in this order:

    * ``args.data_parallel``: ``sdp`` is obtained from ``_sdp_import(args)``
      and supplies ranks and world size.
    * ``args.model_parallel``: ranks come from ``smp``; world size is
      ``args.num_gpus * len(args.hosts)``.
    * otherwise: world size is ``len(args.hosts) * args.num_gpus`` and a
      ``torch.distributed``-style process group is initialised through ``dist``.

    After that ``args.lr`` is multiplied by the world size and
    ``args.batch_size`` is divided by ``world_size // num_gpus`` (never
    dropping below 1).

    Args:
        args: mutable namespace; must provide ``hosts``, ``current_host``,
            ``num_gpus``, ``lr``, ``batch_size`` and the mode flags.

    Returns:
        The same ``args`` object, mutated in place.
    """
    print(f"args.data_parallel : {args.data_parallel}, args.model_parallel : {args.model_parallel}, args.apex : {args.apex}")

    args.world_size = 1
    args.host_num = args.hosts.index(args.current_host)

    if args.data_parallel:
        # Only the first element of the returned pair is needed here.
        sdp, _ = _sdp_import(args)

        args.world_size = sdp.get_world_size()
        args.rank = sdp.get_rank()  # total rank in all hosts
        args.local_rank = sdp.get_local_rank()  # rank per host
    elif args.model_parallel:
        # The original first assigned smp.size() and immediately overwrote it
        # with this value; only the effective assignment is kept.
        args.world_size = args.num_gpus * len(args.hosts)
        args.local_rank = smp.local_rank()  # rank per host
        args.rank = smp.rank()
        args.dp_size = smp.dp_size()
        args.dp_rank = smp.dp_rank()
    else:
        args.world_size = len(args.hosts) * args.num_gpus
        # NOTE(review): when args.local_rank is None, args.rank is not set
        # here and must already exist on args -- confirm against the caller.
        if args.local_rank is not None:
            args.rank = args.num_gpus * args.host_num + \
                args.local_rank  # total rank in all hosts

        dist.init_process_group(backend=args.backend,
                                rank=args.rank,
                                world_size=args.world_size)
        logger.info(
            'Initialized the distributed environment: \'{}\' backend on {} nodes. '
            .format(args.backend, dist.get_world_size()) +
            'Current host rank is {}. Number of gpus: {}'.format(
                dist.get_rank(), args.num_gpus))

    args.lr = args.lr * float(args.world_size)

    # Clamp both the GPU count and the resulting divisor to >= 1 so that
    # num_gpus == 0 or world_size < num_gpus cannot raise ZeroDivisionError.
    gpus_per_host = max(args.num_gpus, 1)
    batch_divisor = max(args.world_size // gpus_per_host, 1)
    args.batch_size = max(args.batch_size // batch_divisor, 1)

    return args
    def _setup_devices(self) -> "torch.device":
        """Select the torch device for this process and record ``self._n_gpu``.

        Backends are probed in a fixed order: ``no_cuda``, SageMaker model
        parallel (smdistributed available and ``mp_parameters`` non-empty),
        SageMaker data parallel, single-process (``local_rank == -1``), and
        finally ``torch.distributed`` with NCCL. When the chosen device is
        CUDA it is made the current CUDA device before being returned.
        """
        logger.info("PyTorch: setting up devices")

        if self.no_cuda:
            device, gpu_count = torch.device("cpu"), 0
        elif is_smdistributed_available() and self.mp_parameters != "":
            # smp.init() is not called here.
            device, gpu_count = torch.device("cuda", smp.local_rank()), 1
        elif is_sagemaker_distributed_available():
            import smdistributed.dataparallel.torch.distributed as dist

            dist.init_process_group()
            self.local_rank = dist.get_local_rank()
            device, gpu_count = torch.device("cuda", self.local_rank), 1
        elif self.local_rank == -1:
            # Single process: if more than one GPU is counted, nn.DataParallel is used.
            # "cuda:0" is explicit because `set_device` needs a device index; index 0
            # is relative to CUDA_VISIBLE_DEVICES, so `CUDA_VISIBLE_DEVICES=1,2`
            # with `cuda:0` selects physical GPU #1.
            device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
            device, gpu_count = torch.device(device_name), torch.cuda.device_count()
        else:
            # torch.distributed: the backend takes care of synchronizing nodes/GPUs.
            torch.distributed.init_process_group(backend="nccl")
            device, gpu_count = torch.device("cuda", self.local_rank), 1

        self._n_gpu = gpu_count

        if device.type == "cuda":
            torch.cuda.set_device(device)

        return device
def memory_status_cpu(msg=""):
    """Print host-memory usage figures for this process.

    Sums ``numel() * dtype_to_bit[dtype]`` over every non-CUDA tensor the
    garbage collector tracks, and reads ``process.memory_info().data`` to
    report usage relative to the module-level ``base_mem_usage`` and to the
    previous call (``last_mem_usage``, which is updated on every call). All
    figures are divided by ``1024**3`` before printing. Only processes whose
    ``smp.rdp_rank()`` is 0 print.

    Args:
        msg: label placed in square brackets at the start of the line.
    """
    import gc

    global last_mem_usage
    global base_mem_usage

    rdp_rank = smp.rdp_rank()

    for _ in range(3):
        gc.collect()

    tracked = gc.get_objects()
    cpu_tensors = [
        candidate for candidate in tracked
        if isinstance(candidate, torch.Tensor) and not candidate.is_cuda
    ]
    torch_usage = sum(t.numel() * dtype_to_bit[t.dtype] for t in cpu_tensors)

    # process.memory_info() is per-process, unlike psutil.virtual_memory().
    current_usage = process.memory_info().data
    total_usage = current_usage - base_mem_usage
    usage_change = current_usage - last_mem_usage
    last_mem_usage = current_usage

    scale = 1024**3
    torch_usage = torch_usage / scale
    total_usage = total_usage / scale
    usage_change = usage_change / scale
    base_usage = base_mem_usage / scale

    rank = smp.rank()
    tp_rank = smp.tp_rank()
    pp_rank = smp.pp_rank()
    local_rank = smp.local_rank()

    if rdp_rank == 0:
        print(
            f'[{msg}] rank {rank} tp_rank {tp_rank} pp_rank {pp_rank} TORCH {torch.__version__}',
            f'device={local_rank} '
            f'torch cpu tensor usage {torch_usage:0.4f} cpu mem usage {total_usage:0.4f} change since last measurement {usage_change:0.4f} base cpu mem usage {base_usage:0.4f}'
        )
def main():
    """Run the pretraining (or evaluation) loop over the sharded input files.

    Parses arguments, seeds the RNGs, sets up the device and model, then
    loops over epochs indefinitely; termination is driven by
    ``args.steps_this_run`` or the module-level ``timeout_sent`` flag.
    While one data file is being consumed, the next one is prepared in a
    single-worker process pool.

    Returns:
        In training mode (``args.do_train``): the tuple
        ``(args, final_loss, train_time_raw, global_step)``.
        Otherwise: the mean of the collected test losses.
        NOTE(review): the two modes return differently shaped values --
        confirm callers handle both.
    """
    global timeout_sent

    args = parse_arguments()

    # Every RNG is seeded with seed + local_rank.
    random.seed(args.seed + args.local_rank)
    np.random.seed(args.seed + args.local_rank)
    torch.manual_seed(args.seed + args.local_rank)
    torch.cuda.manual_seed(args.seed + args.local_rank)
    worker_init = WorkerInitObj(args.seed + args.local_rank)

    device, args = setup_training(args)

    # Prepare optimizer
    (
        model,
        optimizer,
        lr_scheduler,
        checkpoint,
        global_step,
        criterion,
    ) = prepare_model_and_optimizer(args, device)

    raw_train_start = None
    most_recent_ckpts_paths = []
    average_loss = 0.0  # averaged loss every args.log_freq steps
    epoch = 0
    training_steps = 0
    test_losses = []

    # Single background worker used to build the next file's dataloader.
    # NOTE(review): the pool is never shut down in this function.
    pool = ProcessPoolExecutor(1)

    # Note: We loop infinitely over epochs, termination is handled via iteration count
    while True:
        thread = None
        restored_data_loader = None
        # Fresh file list unless this is the first epoch of a resumed run.
        if (not args.resume_from_checkpoint or epoch > 0
                or (args.phase2 and global_step < 1) or args.init_checkpoint):
            files = [
                os.path.join(args.input_dir, f)
                for f in os.listdir(args.input_dir)
                if os.path.isfile(os.path.join(args.input_dir, f))
                and "training" in f
            ]
            files.sort()
            num_files = len(files)
            # Shuffle is seeded per epoch.
            random.Random(args.seed + epoch).shuffle(files)
            f_start_id = 0
        else:
            # Resume: the checkpoint stores [file index] + file list.
            f_start_id = checkpoint["files"][0]
            files = checkpoint["files"][1:]
            args.resume_from_checkpoint = False
            num_files = len(files)
            # may not exist in all checkpoints
            epoch = checkpoint.get("epoch", 0)
            # NOTE(review): this assigns `restored_dataloader`, but the check
            # below reads `restored_data_loader` (still None here), so the
            # checkpointed data loader is never used -- likely a typo.
            restored_dataloader = checkpoint.get("data_loader", None)

        shared_file_list = {}

        # Data-parallel size/rank decide which file shard this process reads.
        if smp.is_initialized():
            dpsize = smp.dp_size()
            dprank = smp.dp_rank()
        elif torch.distributed.is_initialized():
            dpsize = get_world_size()
            dprank = get_rank()
        else:
            dpsize = 1
            dprank = 0
        dparallel = dpsize > 1
        if dparallel and dpsize > num_files:
            remainder = dpsize % num_files
            data_file = files[(f_start_id * dpsize + dprank +
                               remainder * f_start_id) % num_files]
        else:
            data_file = files[(f_start_id * dpsize + dprank) % num_files]

        previous_file = data_file

        if restored_data_loader is None:
            train_data = pretraining_dataset(data_file,
                                             args.max_predictions_per_seq)
            train_sampler = RandomSampler(train_data)
            train_dataloader = DataLoader(
                train_data,
                sampler=train_sampler,
                batch_size=args.train_batch_size * args.n_gpu,
                num_workers=4,
                worker_init_fn=worker_init,
                pin_memory=True,
                drop_last=True,
            )
            # shared_file_list["0"] = (train_dataloader, data_file)
        else:
            train_dataloader = restored_data_loader
            restored_data_loader = None

        overflow_buf = None
        if args.allreduce_post_accumulation:
            overflow_buf = torch.cuda.IntTensor([0])

        for f_id in range(f_start_id + 1, len(files)):
            # NOTE(review): this uses get_world_size()/get_rank() while the
            # first file above used dpsize/dprank, and `remainder` is only
            # bound when `dparallel and dpsize > num_files` held -- it may be
            # unbound (or stale from an earlier epoch) here. Confirm.
            if get_world_size() > num_files:
                data_file = files[(f_id * get_world_size() + get_rank() +
                                   remainder * f_id) % num_files]
            else:
                data_file = files[(f_id * get_world_size() + get_rank()) %
                                  num_files]

            previous_file = data_file

            # Build the next file's dataloader in the background while the
            # current one is consumed.
            dataset_future = pool.submit(
                create_pretraining_dataset,
                data_file,
                args.max_predictions_per_seq,
                shared_file_list,
                args,
                worker_init,
            )

            # Progress bar only on the main process.
            train_iter = (tqdm(train_dataloader,
                               desc="Iteration",
                               disable=args.disable_progress_bar)
                          if is_main_process() else train_dataloader)

            if raw_train_start is None:
                raw_train_start = time.time()

            for step, batch in enumerate(train_iter):
                training_steps += 1
                batch = [t.to(device) for t in batch]
                input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels = batch
                if args.do_train:
                    # NOTE(review): executed on every step; neither imported
                    # name is used in this function.
                    from smdistributed.modelparallel.test.torch.utils import dump_model, verify

                    model.train()
                    if args.smp > 0:
                        loss_mbs = smp_step(
                            args,
                            device,
                            input_ids,
                            segment_ids,
                            input_mask,
                            masked_lm_labels,
                            next_sentence_labels,
                            model,
                            optimizer,
                            criterion,
                            step,
                        )
                        # smp_step returns per-microbatch losses; average them.
                        loss = loss_mbs.reduce_mean()
                        if smp.rank() == 0:
                            print("Loss:", loss.item())
                    else:
                        loss = train_step(
                            args,
                            device,
                            input_ids,
                            segment_ids,
                            input_mask,
                            masked_lm_labels,
                            next_sentence_labels,
                            model,
                            optimizer,
                            criterion,
                            step,
                        )
                    divisor = 1
                    average_loss += loss.item()

                    # Optimizer step once per gradient-accumulation window.
                    if training_steps % args.gradient_accumulation_steps == 0:
                        lr_scheduler.step()  # learning rate warmup
                        global_step = take_optimizer_step(
                            args, optimizer, model, overflow_buf, global_step)

                    if global_step >= args.steps_this_run or timeout_sent:
                        # Final bookkeeping: these values are returned below
                        # under the same condition.
                        train_time_raw = time.time() - raw_train_start
                        last_num_steps = (int(
                            training_steps / args.gradient_accumulation_steps)
                                          % args.log_freq)
                        last_num_steps = args.log_freq if last_num_steps == 0 else last_num_steps
                        average_loss = torch.tensor(
                            average_loss, dtype=torch.float32).cuda()
                        average_loss = average_loss / (last_num_steps *
                                                       divisor)
                        if torch.distributed.is_initialized():
                            # Divide first, then sum across ranks, giving the mean.
                            average_loss /= get_world_size()
                            torch.distributed.all_reduce(average_loss)
                        final_loss = loss.item()
                    elif training_steps % (
                            args.log_freq *
                            args.gradient_accumulation_steps) == 0:
                        average_loss = 0

                    # Checkpoint at the end of the run, every
                    # num_steps_per_checkpoint optimizer steps, or on timeout.
                    if (global_step >= args.steps_this_run or training_steps %
                        (args.num_steps_per_checkpoint *
                         args.gradient_accumulation_steps) == 0
                            or timeout_sent):
                        if smp.dp_rank() == 0 and not args.skip_checkpoint:
                            if args.resume_step < 0 or not args.phase2:
                                output_save_file = os.path.join(
                                    args.output_dir,
                                    "ckpt_{}.pt".format(global_step))
                            else:
                                # Phase 2 continues step numbering after phase 1.
                                output_save_file = os.path.join(
                                    args.output_dir,
                                    "ckpt_{}.pt".format(global_step +
                                                        args.phase1_end_step),
                                )
                            if args.do_train:
                                save_dict = {
                                    "model":
                                    model.local_state_dict(),
                                    "optimizer":
                                    optimizer.local_state_dict(),
                                    "files": [f_id] + files,
                                    "epoch":
                                    epoch,
                                    "data_loader":
                                    None if global_step >= args.steps_this_run
                                    else train_dataloader,
                                }
                                if args.fp16:
                                    save_dict["master params"] = list(
                                        amp.master_params(optimizer))
                                # SMP: Checkpoint mp_rank specific state
                                smp.save(save_dict,
                                         output_save_file,
                                         partial=True)

                                # Keep at most the 3 most recent partial
                                # checkpoints; the removed path carries the
                                # "_<mp_rank>" suffix.
                                most_recent_ckpts_paths.append(
                                    output_save_file)
                                if len(most_recent_ckpts_paths) > 3 and (
                                        args.smp == 0 or smp.dp_rank() == 0):
                                    ckpt_to_be_removed = most_recent_ckpts_paths.pop(
                                        0)
                                    os.remove(ckpt_to_be_removed +
                                              f"_{smp.mp_rank()}")

                        # Exiting the training due to hitting max steps, or being sent a
                        # timeout from the cluster scheduler
                        if global_step >= args.steps_this_run or timeout_sent:
                            del train_dataloader
                            # thread.join()
                            if smp.dp_rank() == 0 and args.save_full:
                                output_save_file = os.path.join(
                                    args.output_dir,
                                    "ckpt_{}.pt".format(global_step))
                                # NOTE(review): `train_dataloader` was deleted
                                # just above; on the timeout path
                                # (global_step < steps_this_run) the
                                # "data_loader" entry below reads it and will
                                # fail with an unbound-local error.
                                save_dict = {
                                    "model":
                                    model.local_state_dict(),
                                    "optimizer":
                                    optimizer.local_state_dict(),
                                    "files": [f_id] + files,
                                    "epoch":
                                    epoch,
                                    "data_loader":
                                    None if global_step >= args.steps_this_run
                                    else train_dataloader,
                                }
                                if args.fp16:
                                    save_dict["master params"] = list(
                                        amp.master_params(optimizer))
                                # SMP: Save a single checkpoint containing entire model parameters
                                smp.save(save_dict,
                                         output_save_file,
                                         partial=False)
                            smp.barrier()
                            # One process per host uploads the checkpoint directory.
                            if smp.local_rank() == 0:
                                print(f"Start syncing model checkpoints to s3")
                                base_s3_path = os.path.dirname(
                                    os.path.dirname(
                                        os.getenv("SM_MODULE_DIR", "")))
                                curr_host = os.getenv("SM_CURRENT_HOST")
                                full_s3_path = f"{base_s3_path}/checkpoints/{curr_host}/"
                                sync_local_checkpoints_to_s3(
                                    local_path=args.output_dir,
                                    s3_path=full_s3_path)
                                print(
                                    f"Finished syncing model checkpoints to s3"
                                )
                            return args, final_loss, train_time_raw, global_step
                else:
                    # Evaluation only: one test step per batch, no optimizer.
                    model.eval()
                    with torch.no_grad():
                        loss = test_step(
                            args,
                            device,
                            input_ids,
                            segment_ids,
                            input_mask,
                            masked_lm_labels,
                            next_sentence_labels,
                            model,
                            criterion,
                            step,
                        )
                        print(f"global_step {global_step} Test Loss:", loss)
                        test_losses.append(loss)
                    global_step += 1
                    if global_step >= args.steps_this_run:
                        return sum(test_losses) / len(test_losses)

            del train_dataloader
            # thread.join()
            # Make sure pool has finished and switch train_dataloader
            # NOTE: Will block until complete
            train_dataloader, data_file = dataset_future.result(timeout=None)
        epoch += 1
def prepare_model_and_optimizer(args, device):
    """Build the BERT pre-training model, optimizer, LR scheduler and criterion.

    When ``args.resume_from_checkpoint`` is set, model and optimizer state are
    restored either from ``args.init_checkpoint`` or from the
    ``ckpt_<step>.pt`` file in ``args.output_dir`` (after syncing it from
    ``args.s3_checkpoint_uri``).

    Args:
        args: Parsed training arguments. ``args.resume_step`` is updated in
            place when it is ``-1`` and no ``init_checkpoint`` is given.
        device: Device the model is moved to.

    Returns:
        Tuple ``(model, optimizer, lr_scheduler, checkpoint, global_step,
        criterion)``. ``checkpoint`` is ``None`` when not resuming.

    Raises:
        ValueError: If resuming without ``init_checkpoint`` and without
            ``s3_checkpoint_uri``, or if the latest step has to be discovered
            but ``args.output_dir`` contains no ``.pt`` file.
    """
    # Prepare model
    config = modeling.BertConfig.from_json_file(args.config_file)

    # Padding for divisibility by 8
    if config.vocab_size % 8 != 0:
        config.vocab_size += 8 - (config.vocab_size % 8)

    config.use_sequential = args.use_sequential > 0

    modeling.ACT2FN["bias_gelu"] = modeling.bias_gelu_training
    model = modeling.BertForPreTraining(config)
    model.checkpoint_activations(args.checkpoint_activations)
    if args.smp > 0:
        # SMP: Use the DistributedModel container to provide the model
        # to be partitioned across different ranks. For the rest of the script,
        # the returned DistributedModel object should be used in place of
        # the model provided for DistributedModel class instantiation.
        model = smp.DistributedModel(model)

    checkpoint = None
    if not args.resume_from_checkpoint:
        global_step = 0
    else:
        if not args.init_checkpoint:
            if not args.s3_checkpoint_uri:
                raise ValueError(
                    "Need to set s3_checkpoint_uri, if init_checkpoint not set"
                )
            # Only one process per host downloads; the others wait at the
            # barrier until the files are in place.
            if smp.local_rank() == 0:
                sync_s3_checkpoints_to_local(args.output_dir,
                                             args.s3_checkpoint_uri)
            smp.barrier()
        if args.resume_step == -1 and not args.init_checkpoint:
            # Pick the highest step among files named like "ckpt_<step>.pt".
            model_names = [
                f for f in os.listdir(args.output_dir) if ".pt" in f
            ]
            if not model_names:
                raise ValueError(
                    "No checkpoint (.pt) files found in {}".format(
                        args.output_dir))
            args.resume_step = max(
                int(x.split(".pt")[0].split("_")[1].strip())
                for x in model_names)

        global_step = args.resume_step if not args.init_checkpoint else 0

        # SMP: Load a model that was saved with smp.save
        if not args.init_checkpoint:
            checkpoint = smp.load(
                os.path.join(args.output_dir,
                             "ckpt_{}.pt".format(global_step)),
                partial=args.partial_checkpoint,
            )
        else:
            checkpoint = smp.load(args.init_checkpoint)

        model.load_state_dict(checkpoint["model"], strict=False)

        if args.phase2 and not args.init_checkpoint:
            # Phase 2 counts its steps from the end of phase 1.
            global_step -= args.phase1_end_step
        if is_main_process():
            print("resume step from ", args.resume_step)

    model.to(device)
    param_optimizer = list(model.named_parameters())
    # Parameters whose name contains one of these get no weight decay.
    no_decay = ["bias", "gamma", "beta", "LayerNorm"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.01,
        },
        {
            "params": [
                p for n, p in param_optimizer
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.0,
        },
    ]

    optimizer = FusedLAMB(optimizer_grouped_parameters, lr=args.learning_rate)
    if args.smp > 0:
        # SMP: Use Distributed Optimizer which allows the loading of optimizer state for a distributed model
        # Also provides APIs to obtain local optimizer state for the current mp_rank.
        optimizer = smp.DistributedOptimizer(optimizer)
    lr_scheduler = PolyWarmUpScheduler(optimizer,
                                       warmup=args.warmup_proportion,
                                       total_steps=args.max_steps)

    if args.fp16:
        # A loss_scale of 0 selects dynamic loss scaling.
        loss_scale = "dynamic" if args.loss_scale == 0 else args.loss_scale
        model, optimizer = amp.initialize(
            model,
            optimizer,
            opt_level="O2",
            loss_scale=loss_scale,
            cast_model_outputs=torch.float16,
        )
        amp._amp_state.loss_scalers[0]._loss_scale = args.init_loss_scale

    if args.resume_from_checkpoint:
        if args.phase2 or args.init_checkpoint:
            # Override hyperparameters from previous checkpoint
            saved_optimizer = checkpoint["optimizer"]
            for param_state in saved_optimizer["state"].values():
                param_state["step"] = global_step
            for group in saved_optimizer["param_groups"]:
                group["step"] = global_step
                group["t_total"] = args.max_steps
                group["warmup"] = args.warmup_proportion
                group["lr"] = args.learning_rate
        optimizer.load_state_dict(checkpoint["optimizer"])
        # Restore AMP master parameters
        if args.fp16:
            optimizer._lazy_init_maybe_master_weights()
            optimizer._amp_stash.lazy_init_called = True
            # Loaded a second time, now that the master weights exist.
            optimizer.load_state_dict(checkpoint["optimizer"])
            for param, saved_param in zip(amp.master_params(optimizer),
                                          checkpoint["master params"]):
                param.data.copy_(saved_param.data)

    criterion = BertPretrainingCriterion(config.vocab_size)

    return model, optimizer, lr_scheduler, checkpoint, global_step, criterion
Beispiel #17
0
    def init_master_params(self):
        """Create fp32 master copies of the fp16/bf16 parameters of ``self.optimizer``.

        Every trainable half-precision parameter gets an fp32 ``nn.Parameter``
        that is a view into one contiguous buffer
        (``self.fp32_param_buffer``). The wrapped optimizer's param groups and
        per-parameter state are re-pointed to these master parameters, and the
        bookkeeping lists/dicts on ``self`` are filled in. Trainable fp32
        parameters are kept as they are.

        Assumes ``self.float16_groups``, ``self.fp32_from_float16_groups``,
        ``self.fp32_from_fp16_paramid_groups``, ``self.fp32_from_fp32_groups``,
        ``self.master_is_distributed`` and ``self.master_distribution_axis``
        already exist on the instance (they are only appended/assigned to
        here) -- TODO confirm against ``__init__``.

        Raises:
            TypeError: If a trainable parameter is neither a CUDA float,
                half nor bfloat16 tensor.
        """
        if self.use_smp:
            torch.cuda.set_device(smp.local_rank())
            register_optimizer_hooks(self.model)
        self.fp32paramid_from_fp16paramid = {}
        # Tensor used to determine if a nan/inf has happened.
        # Any non-zero value indicates inf/nan.
        # Note that we keep this for the cases that grad scaler is none.
        # We still record nan/inf if we have a bfloat16 with a grad scaler.
        if self.grad_scaler:
            self.found_inf = torch.cuda.FloatTensor([0.0])

        # Dummy tensor needed for apex multi-apply tensor.
        # For bfloat, we don't have multi-tensor apply and for now
        # we set it to none so the multi-tensor apply gets ignored.
        if self.bf16:
            self._dummy_overflow_buf = None
        else:
            self._dummy_overflow_buf = torch.cuda.IntTensor([0])

        # In case grad scaler is not passed, define the unity scale.
        if self.grad_scaler is None:
            self._scale_one = torch.cuda.FloatTensor([1.0])

        # ======================
        # main parameter stuff
        # ======================

        float16_types = (
            "torch.cuda.HalfTensor",
            "torch.cuda.BFloat16Tensor",
        )

        # Only fp16/bf16 params which require grads need room in the
        # contiguous fp32 buffer. The buffer is sized and allocated once.
        contig_buffer_size = 0
        for param_group in self.optimizer.param_groups:
            for param in param_group["params"]:
                if param.requires_grad and param.type() in float16_types:
                    contig_buffer_size += param.numel()

        self.fp32_param_buffer = torch.empty(
            contig_buffer_size,
            device=torch.device("cuda", smp.local_rank()),
            dtype=torch.float32,
            requires_grad=True,
        )
        # Running position of the next free element in fp32_param_buffer.
        offset = 0

        # For all the groups in the original optimizer:
        for param_group in self.optimizer.param_groups:
            float16_params_this_group = []
            fp32_params_this_group = []
            fp32_from_float16_params_this_group = []
            fp32_from_fp16_paramids_this_group = []
            # For all the parameters in this group:
            for i, param in enumerate(param_group["params"]):
                if not param.requires_grad:
                    continue

                # float16 params:
                if param.type() in float16_types:
                    float16_params_this_group.append(param)
                    # Copy the values into this param's slice of the buffer.
                    with torch.no_grad():
                        master_param_buffer = self.fp32_param_buffer.narrow(
                            0, offset, param.numel()).view_as(param)
                        master_param_buffer.copy_(param.float())
                        offset += param.numel()

                    main_param = nn.Parameter(
                        master_param_buffer,
                        requires_grad=param.requires_grad)
                    self.master_is_distributed[
                        main_param] = self.model.is_distributed_parameter(
                            param)
                    self.master_distribution_axis[id(
                        main_param)] = get_distribution_axis(param)
                    fp32_from_fp16_paramids_this_group.append(
                        id(main_param))
                    if hasattr(param, "shared"):
                        main_param.shared = param.shared

                    # Replace the optimizer params with the new fp32 copy.
                    param_group["params"][i] = main_param
                    fp32_from_float16_params_this_group.append(main_param)
                    # Reset existing state dict key to the new main param.
                    if param in self.optimizer.state:
                        self.optimizer.state[
                            main_param] = self.optimizer.state.pop(param)
                    self.fp32paramid_from_fp16paramid[id(param)] = id(
                        main_param)

                # fp32 params.
                elif param.type() == "torch.cuda.FloatTensor":
                    fp32_params_this_group.append(param)
                    param_group["params"][i] = param

                else:
                    raise TypeError("Wrapped parameters must be one of "
                                    "torch.cuda.FloatTensor,  "
                                    "torch.cuda.HalfTensor, or "
                                    "torch.cuda.BFloat16Tensor. "
                                    "Received {}".format(param.type()))

            self.float16_groups.append(float16_params_this_group)
            self.fp32_from_float16_groups.append(
                fp32_from_float16_params_this_group)
            self.fp32_from_fp16_paramid_groups.append(
                fp32_from_fp16_paramids_this_group)
            self.fp32_from_fp32_groups.append(fp32_params_this_group)

        # Leverage state_dict() and load_state_dict() to
        # recast preexisting per-param state tensors
        self.optimizer.load_state_dict(self.optimizer.state_dict())
        self.master_params_created = True
Beispiel #18
0
def clip_grad_norm_fp32(parameters,
                        param_is_distributed,
                        max_norm,
                        norm_type=2):
    """Clip the gradient norm of fp32 parameters, in place.

    Adapted from torch.nn.utils.clip_grad.clip_grad_norm_, with added
    handling for model parallel parameters: the norm is reduced across
    the model-parallel process group before the gradients are scaled.

    Arguments:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor that will have gradients normalized
        param_is_distributed: container supporting ``in`` and indexing by
            parameter; a truthy entry lets the parameter count toward the
            norm on ranks whose ``smp.tp_rank()`` is not 0
        max_norm (float or int): max norm of the gradients
        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
            infinity norm.

    Returns:
        Total norm of the parameters (viewed as a single vector).
    """

    def _new_overflow_buf():
        # Scratch flag tensor handed to apex's multi-tensor applier.
        return torch.cuda.IntTensor([0],
                                    device=torch.device(
                                        "cuda", smp.local_rank()))

    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]

    torch.cuda.set_device(smp.local_rank())

    # all_grads: every existing gradient (all of them get scaled).
    # norm_grads: the subset that contributes to the norm, i.e. excluding
    #   shared parameters and tensor-parallel replicas.
    all_grads = []
    norm_grads = []
    for p in parameters:
        if p.grad is None:
            continue
        is_shared = hasattr(p, "shared") and p.shared
        counted_on_this_rank = smp.tp_rank() == 0 or (
            p in param_is_distributed and param_is_distributed[p])
        g = p.grad.detach()
        # Make sure the grads are in fp32
        assert p.grad.type() == 'torch.cuda.FloatTensor'
        all_grads.append(g)
        if not is_shared and counted_on_this_rank:
            norm_grads.append(g)

    max_norm = float(max_norm)
    norm_type = float(norm_type)
    total_norm = torch.tensor(0.0, device=torch.device("cuda"))

    mp_group_kwargs = {}
    if norm_type == inf:
        # Infinity norm: largest absolute entry, then max over the
        # model-parallel GPUs.
        if norm_grads:
            total_norm = max(g.abs().max() for g in norm_grads)
        reduced = torch.cuda.FloatTensor([float(total_norm)])
        torch.distributed.all_reduce(reduced,
                                     op=torch.distributed.ReduceOp.MAX,
                                     group=smp.get_mp_process_group(),
                                     **mp_group_kwargs)
        total_norm = reduced[0].item()
    else:
        if norm_type == 2.0:
            overflow_buf = _new_overflow_buf()
            # apex's multi-tensor applier computes the L2 norm over the
            # whole list in a single kernel.
            if norm_grads:
                l2_norm, _ = multi_tensor_applier(
                    amp_C.multi_tensor_l2norm,
                    overflow_buf,
                    [norm_grads],
                    False  # no per-parameter norm
                )
                # Raised to norm_type because the values are summed across
                # ranks before the root is taken.
                total_norm = l2_norm**norm_type
        else:
            for g in norm_grads:
                total_norm += torch.norm(g, norm_type)**norm_type

        # Sum across all model-parallel GPUs, then take the root.
        torch.distributed.all_reduce(total_norm,
                                     op=torch.distributed.ReduceOp.SUM,
                                     group=smp.get_mp_process_group())
        total_norm = total_norm.item()**(1.0 / norm_type)

    # Scale all gradients only when the norm exceeds max_norm.
    if all_grads:
        clip_coeff = max_norm / (total_norm + 1.0e-6)
        if clip_coeff < 1.0:
            multi_tensor_applier(amp_C.multi_tensor_scale,
                                 _new_overflow_buf(), [all_grads, all_grads],
                                 clip_coeff)

    return total_norm
Beispiel #19
0
    def init_master_params(self):
        """Set up fp32 master parameters for the fp16 params of ``self.optimizer``.

        Each trainable ``torch.cuda.HalfTensor`` parameter is mirrored by an
        fp32 ``nn.Parameter`` that views a slice of one contiguous buffer
        (``self.fp32_param_buffer``); the wrapped optimizer's param groups and
        state are switched over to those masters. Trainable fp32 parameters
        stay in place. Finally the overflow / closure flags are reset and
        ``self.master_params_created`` is set.

        Raises:
            TypeError: If a trainable parameter is neither a CUDA float nor a
                CUDA half tensor.
        """
        half_type = "torch.cuda.HalfTensor"
        float_type = "torch.cuda.FloatTensor"

        if self.use_smp:
            torch.cuda.set_device(smp.local_rank())
            register_optimizer_hooks(self.model)
        self.fp32paramid_from_fp16paramid = {}

        # The contiguous buffer only has to hold the fp16 params that
        # require grads.
        half_numel = sum(
            p.numel() for group in self.optimizer.param_groups
            for p in group["params"]
            if p.requires_grad and p.type() == half_type)

        self.fp32_param_buffer = torch.empty(
            half_numel,
            device=torch.device("cuda", smp.local_rank()),
            dtype=torch.float32,
            requires_grad=True,
        )

        # Position of the next unused element in the buffer.
        cursor = 0
        for group_idx, group in enumerate(self.optimizer.param_groups):
            self.maybe_print(
                "FP16_Optimizer processing param group {}:".format(group_idx))
            half_params = []
            float_params = []
            masters = []
            master_ids = []
            for slot, param in enumerate(group["params"]):
                if not param.requires_grad:
                    continue

                tensor_type = param.type()
                if tensor_type == half_type:
                    self.maybe_print(
                        "FP16_Optimizer received torch.cuda.HalfTensor with {}"
                        .format(param.size()))
                    half_params.append(param)

                    # Fill this param's slice of the buffer with its values.
                    count = param.numel()
                    with torch.no_grad():
                        master_view = self.fp32_param_buffer.narrow(
                            0, cursor, count).view_as(param)
                        master_view.copy_(param.float())
                    cursor += count

                    master = nn.Parameter(master_view,
                                          requires_grad=param.requires_grad)
                    master_id = id(master)

                    self.master_is_distributed[master] = (
                        self.model.is_distributed_parameter(param))
                    self.master_distribution_axis[master_id] = (
                        get_distribution_axis(param))
                    group["params"][slot] = master
                    masters.append(master)
                    master_ids.append(master_id)
                    # Move any existing optimizer state over to the master.
                    # Its tensors are recast by the state_dict round trip
                    # at the end of this method.
                    if param in self.optimizer.state:
                        self.optimizer.state[master] = (
                            self.optimizer.state.pop(param))
                    self.fp32paramid_from_fp16paramid[id(param)] = master_id
                elif tensor_type == float_type:
                    self.maybe_print(
                        "FP16_Optimizer received torch.cuda.FloatTensor with {}"
                        .format(param.size()))
                    float_params.append(param)
                    group["params"][slot] = param
                else:
                    raise TypeError(
                        "Wrapped parameters must be either "
                        "torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
                        "Received {}".format(tensor_type))

            self.fp16_groups.append(half_params)
            self.fp32_from_fp16_groups.append(masters)
            self.fp32_from_fp16_paramid_groups.append(master_ids)
            self.fp32_from_fp32_groups.append(float_params)

        # Round-trip through state_dict() / load_state_dict() so that
        # preexisting per-param state tensors get recast.
        self.optimizer.load_state_dict(self.optimizer.state_dict())

        self.overflow = False
        self.first_closure_call_this_step = True
        self.master_params_created = True
Beispiel #20
0
def init_train():
    """
    Train the PyTorch model

    Reads its configuration from the module-level ``args`` (``train``,
    ``test``, ``batch_size``, ``epochs``, ``learning_rate``, ``model_dir``),
    initializes smdistributed model parallelism, shards the training data
    across data-parallel ranks, trains a ``TabularNet`` and writes the
    resulting state dict to ``<args.model_dir>/model.pth``.
    """

    cat_mask = [
        False, True, True, True, True, False, True, True, True, True, True,
        False, False, False, False, False, False, False
    ]
    train_ds = CsvDatasetSimple(args.train)
    # NOTE(review): the test dataset is loaded but never consumed below;
    # kept so that loading it still happens exactly as before.
    test_ds = CsvDatasetSimple(args.test)

    batch_size = args.batch_size
    epochs = args.epochs
    learning_rate = args.learning_rate

    logger.info("batch_size = {}, epochs = {}, learning rate = {}".format(
        batch_size, epochs, learning_rate))

    # smdistributed: initialize the backend
    smp.init()

    # smdistributed: Set the device to the GPU ID used by the current process.
    # Input tensors should be transferred to this device.
    torch.cuda.set_device(smp.local_rank())
    device = torch.device("cuda")

    dataset = train_ds

    # smdistributed: Shard the dataset based on data-parallel ranks
    if smp.dp_size() > 1:
        partitions_dict = {
            f"{i}": 1 / smp.dp_size()
            for i in range(smp.dp_size())
        }
        dataset = SplitDataset(dataset, partitions=partitions_dict)
        dataset.select(f"{smp.dp_rank()}")

    # smdistributed: Set drop_last=True to ensure that batch size is always divisible
    # by the number of microbatches
    # Use the configured batch size (the value logged above) rather than a
    # hard-coded one.
    train_loader = torch.utils.data.DataLoader(dataset,
                                               batch_size=batch_size,
                                               drop_last=True)

    model = TabularNet(n_cont=9,
                       n_cat=9,
                       cat_mask=cat_mask,
                       cat_dim=[
                           0, 2050, 13, 5, 366, 0, 50000, 50000, 50000, 50000,
                           50, 0, 0, 0, 0, 0, 0, 0
                       ],
                       y_min=0.,
                       y_max=1.)

    logger.debug(model)

    # Use the configured learning rate (the value logged above) rather than
    # a hard-coded one.
    optimizer = optim.Adadelta(model.parameters(), lr=learning_rate)

    # smdistributed: Use the DistributedModel container to provide the model
    # to be partitioned across different ranks. For the rest of the script,
    # the returned DistributedModel object should be used in place of
    # the model provided for DistributedModel class instantiation.
    model = smp.DistributedModel(model)

    optimizer = smp.DistributedOptimizer(optimizer)

    # NOTE(review): ``epochs`` is logged but not forwarded; train() is called
    # once with the signature used here -- confirm whether it should loop.
    train(model, device, train_loader, optimizer)

    # state_dict() is called on every rank, as before, but only one process
    # per host writes the file so that ranks do not write the same path
    # concurrently.
    model_state = model.state_dict()
    if smp.local_rank() == 0:
        torch.save(model_state, args.model_dir + "/model.pth")
Beispiel #21
0
                          'from checkpoints.')

    args.is_distributed = len(args.hosts) > 1 and args.backend is not None
    args.is_multigpus = args.num_gpus > 1
    args.multigpus_distributed = (args.is_distributed or args.is_multigpus)        
    
    logger.debug(f"args.image_folder : {args.image_folder}")
    

    args.world_size = 1
    args.local_rank = 0
    args.rank = 0
    
    if args.model_parallel:
        args.world_size = smp.size()
        args.local_rank = smp.local_rank()  # rank per host
        args.rank = smp.rank()
        args.dp_size = smp.dp_size()
        args.dp_rank = smp.dp_rank()
        logger.debug(f"args.world_size : {args.world_size}, args.local_rank : {args.local_rank}, args.rank : {args.rank}, \
                    args.dp_size : {args.dp_size}, args.dp_rank : {args.dp_rank}")
    else:
        # initialize deepspeed
        print(f"args.deepspeed : {args.deepspeed}")
        deepspeed_utils.init_deepspeed(args.deepspeed)
#     args.LEARNING_RATE = args.LEARNING_RATE * float(args.world_size)


    ## SageMaker
    try:
        if os.environ.get('SM_CHANNEL_TRAINING') is not None:
def main():
    """Parse arguments, build model and datasets, and train with SMP.

    When SageMaker model parallelism is enabled, the model is wrapped in
    ``smp.DistributedModel``, an Adam/AdamW optimizer wrapped in
    ``smp.DistributedOptimizer`` is created, and training is delegated to
    ``SMPTrainer.train_smp``. When it is not enabled, nothing beyond model
    and dataset construction happens.
    """

    model_args, data_args, training_args, smp_args = parse_args()
    model, tokenizer = initialize_model_and_tokenizer(model_args)

    # Get datasets
    train_dataset, eval_dataset = Preprocess.datasets(model_args, data_args,
                                                      training_args)

    if is_sagemaker_mp_enabled():
        initialize_smp(smp_args, training_args)

        torch.set_default_dtype(torch.float32)

        num_params = print_num_parameters(model)

        # smdistributed: Set the device to the GPU ID used by the current process.
        # Input tensors should be transferred to this device.
        torch.cuda.set_device(smp.local_rank())

        if not training_args.same_seed:
            # Set seed by tp_rank to prevent weights from being the same on different tp_ranks
            set_seed(training_args.seed + smp.tp_rank())

        model = smp.DistributedModel(model,
                                     trace_device=smp_args.trace_device,
                                     gradient_as_bucket_view=True)

        # NOTE(review): the default dtype is set again after wrapping;
        # presumably a guard in case wrapping changed it -- confirm.
        torch.set_default_dtype(torch.float32)

        # Unwrap down to the innermost module before building the
        # parameter groups (weight decay and non-decay).
        iter_model = model
        while isinstance(iter_model, (DistributedDataParallel, FP16_Module)):
            iter_model = iter_model.module

        param_groups = get_param_groups_by_weight_decay(iter_model)

        # AdamW comes from ``optim`` like Adam does; it is not an attribute
        # of the parsed training arguments.
        optimizer_cls = (optim.AdamW
                         if training_args.use_adamw > 0 else optim.Adam)
        optimizer = optimizer_cls(
            param_groups,
            betas=(training_args.beta1, training_args.beta2),
            lr=training_args.lr,
            weight_decay=training_args.weight_decay,
        )

        optimizer = smp.DistributedOptimizer(optimizer)
        lr_scheduler = get_learning_rate_scheduler(optimizer, training_args)

        # Fresh run: start from the first data file and batch, at step 0.
        total_steps = 0
        start_train_path_index = 0
        start_batch_index = 0

        # Initialize Trainer instance
        trainer = SMPTrainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset if training_args.do_train else None,
            eval_dataset=eval_dataset if training_args.do_eval else None,
            tokenizer=tokenizer,
            data_collator=default_data_collator,
        )

        trainer.train_smp(
            model,
            optimizer,
            lr_scheduler,
            start_train_path_index,
            start_batch_index,
            num_params,
            total_steps,
            training_args,
            prescaled_batch=smp_args.prescaled_batch,
        )
Beispiel #23
0
def main():
    """Train MNIST with SageMaker model parallelism and checkpoint to S3.

    Runs one epoch of training and evaluation, saves a partial
    (per-partition) checkpoint under ``/opt/ml/local_checkpoints`` and syncs
    that directory to S3 from one process per host.

    Raises:
        ValueError: If CUDA is not available.
    """
    if not torch.cuda.is_available():
        raise ValueError(
            "The script requires CUDA support, but CUDA not available")
    use_ddp = True
    use_horovod = False

    # Fix seeds in order to get the same losses across runs
    random.seed(1)
    np.random.seed(1)
    torch.manual_seed(1)
    torch.cuda.manual_seed(1)

    smp.init()

    # SM Distributed: Set the device to the GPU ID used by the current process.
    # Input tensors should be transferred to this device.
    torch.cuda.set_device(smp.local_rank())
    device = torch.device("cuda")
    kwargs = {
        "batch_size": 64,
        "num_workers": 1,
        "pin_memory": True,
        "shuffle": False,
    }

    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.1307, ), (0.3081, ))])

    # SM Distributed: Download only on a single process per instance.
    # When this is not present, the file is corrupted by multiple processes trying
    # to download and extract at the same time
    if smp.local_rank() == 0:
        # Called for the download only; the dataset object is built below.
        datasets.MNIST("../data",
                       train=True,
                       download=True,
                       transform=transform)
    smp.barrier()
    dataset1 = datasets.MNIST("../data",
                              train=True,
                              download=False,
                              transform=transform)

    # Shard the training data across data-parallel ranks.
    if (use_ddp or use_horovod) and smp.dp_size() > 1:
        partitions_dict = {
            f"{i}": 1 / smp.dp_size()
            for i in range(smp.dp_size())
        }
        dataset1 = SplitDataset(dataset1, partitions=partitions_dict)
        dataset1.select(f"{smp.dp_rank()}")

    # Create dataloaders for train and test dataset
    dataset2 = datasets.MNIST("../data", train=False, transform=transform)

    train_loader = torch.utils.data.DataLoader(dataset1, **kwargs)
    test_loader = torch.utils.data.DataLoader(dataset2, **kwargs)

    model = GroupedNet()

    # SMP handles the transfer of parameters to the right device
    # and the user doesn't need to call 'model.to' explicitly.
    optimizer = optim.Adadelta(model.parameters(), lr=4.0)

    # SM Distributed: Use the DistributedModel container to provide the model
    # to be partitioned across different ranks. For the rest of the script,
    # the returned DistributedModel object should be used in place of
    # the model provided for DistributedModel class instantiation.
    model = smp.DistributedModel(model)
    scaler = smp.amp.GradScaler()
    optimizer = smp.DistributedOptimizer(optimizer)

    scheduler = StepLR(optimizer, step_size=1, gamma=0.7)
    for epoch in range(1, 2):
        train(model, scaler, device, train_loader, optimizer, epoch)
        test(model, device, test_loader)
        scheduler.step()

    checkpoint_dir = "/opt/ml/local_checkpoints"

    # Create the directory once per host (not only on global rank 0), so
    # that every host that saves or syncs below has it available.
    if smp.local_rank() == 0:
        if os.path.exists(checkpoint_dir):
            print("-INFO- PATH DO EXIST")
        else:
            os.makedirs(checkpoint_dir, exist_ok=True)
            print("-INFO- PATH DO NOT EXIST")

    # Wait until the checkpoint directory exists before any rank saves.
    smp.barrier()

    if smp.dp_rank() == 0:
        model_dict = model.local_state_dict()
        opt_dict = optimizer.local_state_dict()
        smp.save(
            {
                "model_state_dict": model_dict,
                "optimizer_state_dict": opt_dict
            },
            os.path.join(checkpoint_dir, "pt_mnist_checkpoint.pt"),
            partial=True,
        )
    # Wait until all checkpoint files are written before syncing them.
    smp.barrier()

    if smp.local_rank() == 0:
        print("Start syncing")
        base_s3_path = os.path.dirname(
            os.path.dirname(os.getenv("SM_MODULE_DIR", "")))
        curr_host = os.getenv("SM_CURRENT_HOST")
        full_s3_path = f"{base_s3_path}/checkpoints/{curr_host}/"
        sync_local_checkpoints_to_s3(local_path=checkpoint_dir,
                                     s3_path=full_s3_path)
        print("Finished syncing")
Beispiel #24
0
def main():
    """Configure SMP, build a GPT-2 causal LM, train it and optionally save it.

    The function performs, in this order:

    1. Parses the command line (``parse_args``) and normalises a few argument
       combinations.
    2. Builds ``smp_config`` from the arguments and calls ``smp.init``.
    3. Creates the output directories and a ``GPT2Config``-based model via
       ``AutoModelForCausalLM.from_config``.
    4. Wraps the model in ``smp.DistributedModel``, optionally assigns
       transformer layers to pipeline partitions, builds the optimizer and
       wraps it in ``smp.DistributedOptimizer``.
    5. Optionally restores state through ``load_model_and_optimizer``.
    6. Calls ``train`` and, when ``args.ci`` is set, prints metrics and asserts
       them against the thresholds given in ``args``.
    7. Optionally saves the final full model via ``save``.

    Raises:
        ValueError: if ``args.shard_optimizer_state > 0`` while
            ``args.skip_full_optimizer`` is falsy.
        AssertionError: if ``args.partition_assignment`` does not match the
            pipeline-parallel size or the number of layers, or if a CI
            threshold is not met.
    """
    args = parse_args()

    if args.shard_optimizer_state > 0 and not args.skip_full_optimizer:
        raise ValueError(
            "If shard_optimizer_state is enabled, skip_full_optimizer must also be enabled. Full optimizer saving is currently not supported under optimizer state sharding."
        )

    # An explicit partition assignment only makes sense with manual
    # partitioning, so switch it on instead of failing.
    if args.partition_assignment != "" and args.manual_partition == 0:
        print("[Warning] partition_assignment is set, enable manual_partition")
        args.manual_partition = 1

    # any value here is overridden by the config set in notebook when launching the sagemaker job
    smp_config = {
        "ddp": True,
        "tensor_parallel_degree": args.tensor_parallel_degree,
        "pipeline_parallel_degree": args.pipeline_parallel_degree,
        "microbatches": args.microbatches,
        # if activation_checkpointing true checkpoints transformer layers below
        "checkpoint_attentions":
        False if args.activation_checkpointing else True,
        "shard_optimizer_state": args.shard_optimizer_state > 0,
        "prescaled_batch": args.prescaled_batch > 0,
        "offload_activations": args.offload_activations > 0,
        "optimize": args.optimize,
        "auto_partition": False if args.manual_partition else True,
        "default_partition": 0,
        "static_mode": args.static_mode > 0,
        "fast_mode": args.fast_mode > 0,
    }

    # The config keys differ by library version: below 110 only "fp16_params"
    # is set, otherwise "fp16" plus several additional options.
    if args.smp_version < 110:
        smp_config["fp16_params"] = args.fp16 > 0
    else:
        smp_config["fp16"] = args.fp16 > 0
        smp_config["delayed_parameter_initialization"] = args.delayed_param > 0
        smp_config["placement_strategy"] = args.placement_strategy
        smp_config[
            "activation_loading_horizon"] = args.activation_loading_horizon
        smp_config["skip_tracing"] = args.skip_tracing > 0

    if args.active_microbatches is not None:
        smp_config["active_microbatches"] = args.active_microbatches

    smp.init(smp_config)

    # Log the effective configuration once, from global rank 0 only.
    if smp.rank() == 0:
        print("Arguments:", args.__dict__)
        print(f"Transformers version: {transformers.__version__}")
        print(
            f"smdistributed.modelparallel version: {smdistributed.modelparallel.__version__}"
        )
        print(f"smdistributed config: {smp_config}")

    if args.save_final_full_model and smp.rank() == 0:
        print(
            f"[Warning] Note that save_final_full_model only saves the final model at the end of all steps. It does not save optimizer state. Optimizer state is only saved with partial models which are saved at checkpointing_freq during training. If you want to restart training you need partial checkpoints."
        )

    # partition_assignment is a comma-separated list with one entry per
    # pipeline-parallel rank; it is consumed further down when
    # args.manual_partition is set.
    if args.partition_assignment != "":
        partition_assignment = args.partition_assignment.split(",")
        assert (
            len(partition_assignment) == smp.pp_size()
        ), f"partition_assignment must have the same size as pipeline parallel degree, but getting {len(partition_assignment)} vs {smp.pp_size()}"

    # Create the output directories on global rank 0, and additionally on
    # local rank 0 of every host when args.use_fsx == 0.
    # NOTE(review): presumably use_fsx != 0 means a shared filesystem, so one
    # creator suffices — confirm against the argument's definition.
    if smp.rank() == 0 or (smp.local_rank() == 0 and args.use_fsx == 0):
        for path in [args.model_dir, args.checkpoint_dir]:
            if not os.path.exists(path):
                os.makedirs(path, exist_ok=True)

    model_config = GPT2Config(
        vocab_size=args.vocab_size,
        n_positions=args.max_context_width,
        n_embd=args.hidden_width,
        n_layer=args.num_layers,
        n_head=args.num_heads,
        n_inner=None,
        activation_function="gelu_new",
        resid_pdrop=args.resid_pdrop,
        embd_pdrop=args.embd_pdrop,
        attn_pdrop=args.attn_pdrop,
        layer_norm_epsilon=1e-05,
        initializer_range=0.02,
        summary_type="cls_index",
        summary_use_proj=True,
        summary_activation=None,
        summary_proj_to_labels=True,
        summary_first_dropout=args.summary_first_pdrop,
        # gradient_checkpointing=args.gradient_checkpointing > 0,
        use_cache=False,
        bos_token_id=50256,
        eos_token_id=50256,
        return_dict=True,
    )

    # the following improves start-up time by skipping proper initialization
    # of weights in the original model. this is not a problem because DistributedModel
    # will override those weights anyway when tensor_parallel_degree > 1.
    # Note that this replaces the method on the class itself, i.e. for the
    # whole process.
    if smp.tp_size() > 1:
        from transformers.modeling_utils import PreTrainedModel

        PreTrainedModel.init_weights = lambda x: None

    set_seed(args.seed)

    if args.enable_memory_profiling > 0:
        memory_status_cpu(msg="before model creation")

    # Model construction uses a different set of context managers depending
    # on the library version.
    if args.smp_version < 110:
        if args.fp16:
            torch.set_default_dtype(torch.float16)
        with smp.tensor_parallelism(
                enabled=smp.tp_size() > 1,
                attention_in_fp32=args.attention_in_fp32 > 0):
            with smp.delay_param_initialization(
                    enabled=(smp.tp_size() > 1 and args.delayed_param > 0)):
                model = AutoModelForCausalLM.from_config(model_config)
    else:
        with smp.model_creation(
                tensor_parallelism=smp.tp_size() > 1,
                attention_in_fp32=args.attention_in_fp32 > 0,
                query_key_layer_scaling=args.query_key_layer_scaling > 0,
                fused_softmax=args.fused_softmax > 0,
                fused_bias_gelu=args.fused_bias_gelu > 0,
                dtype=torch.float16
                if args.fp16 else torch.get_default_dtype(),
        ):
            model = AutoModelForCausalLM.from_config(model_config)

    if args.smp_version < 110 and args.fp16:
        model = FP16_Module(model)

    if args.enable_memory_profiling > 0:
        memory_status_cpu(msg="after model creation")

    # Parameter count of the model as created on this rank, before it is
    # wrapped in DistributedModel; used for logging and in checkpoint names.
    num_params = sum([np.prod(p.size()) for p in model.parameters()])
    if smp.rank() == 0:
        print(f"# total parameters: {num_params}")

    # smdistributed: Set the device to the GPU ID used by the current process.
    # Input tensors should be transferred to this device.
    torch.cuda.set_device(smp.local_rank())
    # NOTE(review): `device` is not referenced again in this function.
    device = torch.device("cuda")

    if not args.same_seed:
        # Set seed by tp_rank to prevent weights from being the same on different tp_ranks
        set_seed(args.seed + smp.tp_rank())

    # smdistributed: Use the DistributedModel container to provide the model
    # to be partitioned across different ranks. For the rest of the script,
    # the returned DistributedModel object should be used in place of
    # the model provided for DistributedModel class instantiation.
    if args.smp_version < 110 and args.fp16:
        torch.set_default_dtype(torch.float16)
    if args.enable_memory_profiling > 0:
        memory_status_cpu(msg="before dist model creation")
    model = smp.DistributedModel(model, trace_device="gpu")
    if args.enable_memory_profiling > 0:
        memory_status_cpu(msg="after dist model creation")

    # Locate the container holding the transformer layers. The attribute is
    # `seq_layers` when tensor parallelism is active and `h` otherwise; the
    # path to the underlying module differs by library version.
    if args.smp_version < 110:
        if smp.tp_size() > 1:
            transformer_layers = model.module.module.module.transformer.seq_layers
        else:
            transformer_layers = model.module.module.module.transformer.h
    else:
        m = model.get_module()
        if smp.tp_size() > 1:
            transformer_layers = m.transformer.seq_layers
        else:
            transformer_layers = m.transformer.h

    if args.manual_partition:
        print(f"Manual partition enabled")
        if args.partition_assignment != "":
            # Layer counts per pipeline rank come from the user-supplied list.
            get_num_layers = lambda x: int(partition_assignment[x])
            total_layers = sum(
                [get_num_layers(pp_rank) for pp_rank in range(smp.pp_size())])
            assert (
                total_layers == args.num_layers
            ), f"partition_assignment must have the same total transformer layers as model, but getting {total_layers} vs {args.num_layers}"
        else:
            # evenly distribute layers across all partitions
            # (the last `rem` partitions receive one extra layer each)
            div, rem = divmod(args.num_layers, smp.pp_size())
            get_num_layers = lambda x: (div + 1
                                        if x >= smp.pp_size() - rem else div)
        # assignments[i] is the pipeline rank that layer i is placed on.
        assignments = []
        # (TODO) This is required for 175B otherwise a hang for partition "8,17,17,18,18,18"
        # Need further investigation
        # for pp_rank in reversed(range(smp.pp_size())):
        for pp_rank in range(smp.pp_size()):
            nl = get_num_layers(pp_rank)
            print(f"{nl} layers assigned to partition {pp_rank}")
            assignments += [pp_rank for _ in range(nl)]

        for i, c in enumerate(transformer_layers.children()):
            smp.set_partition(c, assignments[i])
    # Find the innermost module to derive the parameter groups from: unwrap
    # the DistributedDataParallel / FP16_Module wrappers for versions below
    # 110, otherwise reuse the module obtained from model.get_module().
    if args.smp_version < 110:
        iter_model = model
        # Build parameter groups (weight decay and non-decay).
        while isinstance(iter_model, (DistributedDataParallel, FP16_Module)):
            iter_model = iter_model.module
    else:
        iter_model = m
    param_groups = get_param_groups_by_weight_decay(iter_model)

    if args.use_adamw > 0:
        optimizer = optim.AdamW(param_groups,
                                betas=(args.beta1, args.beta2),
                                lr=args.lr,
                                weight_decay=args.weight_decay)
    else:
        optimizer = optim.Adam(param_groups,
                               betas=(args.beta1, args.beta2),
                               lr=args.lr,
                               weight_decay=args.weight_decay)

    if args.activation_checkpointing:
        kwargs = {}
        # The extra options are only passed when the layer container is an
        # nn.Sequential.
        if isinstance(transformer_layers, nn.Sequential):
            kwargs["pack_args_as_tuple"] = True
            kwargs["strategy"] = args.activation_strategy
        smp.set_activation_checkpointing(transformer_layers, **kwargs)

    # Loss scaling is configured on FP16_Optimizer for versions below 110 and
    # directly on smp.DistributedOptimizer otherwise; both use dynamic loss
    # scaling with the same arguments.
    if args.smp_version < 110:
        optimizer = FP16_Optimizer(
            model,
            optimizer,
            static_loss_scale=None,
            dynamic_loss_scale=True,
            use_smp=True,
            dynamic_loss_args={
                "scale_window": 1000,
                "min_scale": 1,
                "delayed_shift": 2
            },
            params_have_main_grad=False,
            shard_optimizer_state=args.shard_optimizer_state > 0,
        )

        optimizer = smp.DistributedOptimizer(optimizer)
        model.register_post_step_hook(
            lambda model, optimizer: optimizer.init_master_params())
    else:
        optimizer = smp.DistributedOptimizer(
            optimizer,
            static_loss_scale=None,
            dynamic_loss_scale=True,
            dynamic_loss_args={
                "scale_window": 1000,
                "min_scale": 1,
                "delayed_shift": 2
            },
        )
    lr_scheduler = get_learning_rate_scheduler(optimizer, args)

    if args.enable_memory_profiling > 0:
        model.register_post_partition_hook(
            lambda model, optimizer: memory_status(msg="After_partition"))

    # load after wrapping model and optimizer with smp Distributed...
    if args.load_full or args.load_partial:
        if args.load_partial and args.load_full:
            print(
                "Since both --load_partial and --load_full set, will try to load from full checkpoint."
                "If the intention is to load from partial checkpoint, please don't set --load_full"
            )
        # A full checkpoint is read from model_dir, a partial one from
        # checkpoint_dir; translate_from_hf is enabled only for full ones.
        partial = not args.load_full
        path = args.checkpoint_dir if partial else args.model_dir
        translate_from_hf = not partial
        model, optimizer, total_steps, start_train_path_index, start_batch_index = load_model_and_optimizer(
            path,
            model,
            optimizer,
            lr_scheduler,
            partial,
            args,
            translate_from_hf=translate_from_hf,
            seq_length=args.max_context_width,
            load_model=True,
            load_optimizer=args.load_partial > 0,
            num_params=num_params,
        )
    else:
        # Fresh run: start from the first file, first batch, step zero.
        total_steps = 0
        start_train_path_index = 0
        start_batch_index = 0

    start = time.time()
    total_steps, throughput, loss = train(
        model,
        optimizer,
        lr_scheduler,
        model_config,
        start_train_path_index,
        start_batch_index,
        num_params,
        total_steps,
        args,
    )
    time_to_train = time.time() - start
    if args.ci:
        print(f"[SMP_METRIC]__GPT2__Time_to_train__{time_to_train}")
        print(f"[SMP_METRIC]__GPT2__samples/second__{throughput}")
        print(f"[SMP_METRIC]__GPT2__Loss__{loss}")
        # Thresholds are only enforced for runs that did not resume from a
        # checkpoint.
        if not args.load_partial and not args.load_full:
            assert time_to_train < args.time_to_train
            assert throughput > args.throughput
            if args.loss:
                assert loss < args.loss

    if args.save_final_full_model:
        # saves full model at the end

        base_path = f"trained_gpt_nparams-{num_params}_steps-{total_steps}.pt"
        out_path = os.path.join(args.model_dir, base_path)

        if smp.rdp_rank() == 0:
            save(
                out_path,
                model,
                optimizer,
                lr_scheduler,
                model_config,
                num_params,
                total_steps,
                -1,
                args,
                partial=False,
                translate_to_hf=smp.tp_size() > 1,
                seq_length=args.max_context_width,
            )

    # Synchronise all ranks before reporting success.
    smp.barrier()
    if smp.rank() == 0:
        print("SMP training finished successfully")
def train(
    model,
    optimizer,
    lr_scheduler,
    model_config,
    start_train_path_index,
    start_batch_index,
    num_params,
    total_steps,
    args,
):
    """Run the training loop over the data files found in ``args.training_dir``.

    Training files are processed one at a time; a new dataloader is created
    for each file (optionally prefetched in a worker process). The loop stops
    when ``total_steps`` reaches ``args.max_steps`` or when
    ``args.epochs * len(train_paths)`` files have been visited.

    Args:
        model: model passed to ``train_step``; switched to eval mode and back
            around validation.
        optimizer: optimizer passed to ``train_step`` and stepped once per
            batch.
        lr_scheduler: stepped after every batch unless an fp16 overflow was
            reported for that batch.
        model_config: forwarded unchanged to ``save``.
        start_train_path_index: index into the sorted list of training files
            at which to start.
        start_batch_index: number of leading batches to skip in the first
            file that is processed (used when resuming).
        num_params: parameter count, used in checkpoint file names and
            forwarded to ``save``.
        total_steps: number of steps already performed; incremented once per
            processed batch.
        args: parsed command-line namespace.

    Returns:
        Tuple ``(total_steps, throughput, loss_metric)``. ``throughput`` is
        the samples/second of the last processed batch (``None`` if no batch
        was processed). ``loss_metric`` is the last training loss when
        ``args.validation_freq`` is falsy, otherwise the last validation loss
        (``0`` if none was computed).
    """
    model.train()
    if args.parallel_proc_data_processing:
        # Single worker process used to build the next file's dataloader
        # while the current file is being trained on.
        pool = ProcessPoolExecutor(1)

    # With prescaled batches the data-parallel rank/size are replaced by the
    # "rdp" variants.
    dp_rank = smp.dp_rank() if not args.prescaled_batch else smp.rdp_rank()
    dp_size = smp.dp_size() if not args.prescaled_batch else smp.rdp_size()
    data_type = "BERT" if args.use_bert_data else "GPT"

    # Collect the training files in sorted order so every rank sees the same
    # sequence.
    if args.use_bert_data:
        train_paths = sorted([
            os.path.join(args.training_dir, p)
            for p in os.listdir(args.training_dir)
            if os.path.isfile(os.path.join(args.training_dir, p)) and "training" in p
        ])
    else:
        if args.zipped_data > 0:
            file_extension = ".json.gz"
        else:
            file_extension = ".json"
        train_paths = sorted(
            [
                os.path.join(args.training_dir, p)
                for p in os.listdir(args.training_dir)
                if p.endswith(file_extension)
            ]
        )

    train_dataloader = create_pretraining_dataloader(
        [train_paths[start_train_path_index]],
        args.train_batch_size,
        args.max_context_width,
        seed=args.seed,
        dp_rank=dp_rank,
        dp_size=dp_size,
        shuffle=args.same_seed < 1,
        zipped=args.zipped_data > 0,
        use_last_file_only=args.fast_validation > 0,
        data_type=data_type,
    )

    if args.validation_freq is not None:
        # load all validation examples
        if smp.rank() == 0:
            print("Creating val dataloader")
        if args.use_bert_data:
            val_paths = sorted([
                os.path.join(args.test_dir, p)
                for p in os.listdir(args.test_dir)
                if os.path.isfile(os.path.join(args.test_dir, p)) and "testing" in p
            ])

        else:
            if args.zipped_data > 0:
                file_extension = ".json.gz"
            else:
                file_extension = ".json"
            val_paths = sorted(
                [
                    os.path.join(args.test_dir, p)
                    for p in os.listdir(args.test_dir)
                    if p.endswith(file_extension)
                ]
            )
        val_dataloader = create_pretraining_dataloader(
            val_paths,
            args.val_batch_size,
            args.max_context_width,
            seed=args.seed,
            dp_rank=dp_rank,
            dp_size=dp_size,
            shuffle=True,
            zipped=args.zipped_data > 0,
            use_last_file_only=args.fast_validation > 0,
            data_type=data_type,
        )
        if smp.rank() == 0:
            print("Created val dataloader")

    start = time.time()
    throughput = None
    # Values collected for the optional logits/loss dump (args.logits_output).
    to_save = {"loss": [], "val_loss": []}
    loss_metric = 0

    def should_record():
        # only record the ranks that in the tp group that contains global rank 0
        if smp.tp_size() > 1:
            tp_group = smp.get_tp_group()
            return 0 in tp_group
        else:
            return smp.rank() == 0

    # Set the same seed for computation
    set_seed(args.seed)

    # `index` counts files across epochs; the actual file is index modulo the
    # number of training files.
    for index in range(start_train_path_index, args.epochs*len(train_paths)):
        next_train_path_index = (index + 1) % len(train_paths)
        curr_train_path_index = index % len(train_paths)

        if total_steps >= args.max_steps:
            break

        if args.parallel_proc_data_processing:
            # Start building the next file's dataloader in the background.
            dataset_future = pool.submit(
                create_pretraining_dataloader,
                [train_paths[next_train_path_index]],
                args.train_batch_size,
                args.max_context_width,
                seed=args.seed,
                dp_rank=dp_rank,
                dp_size=dp_size,
                shuffle=args.same_seed < 1,
                zipped=args.zipped_data > 0,
                use_last_file_only=args.fast_validation > 0,
                data_type=data_type,
            )

        if smp.rank() == 0:
            if args.use_bert_data:
                print(f"Reading data from training path {train_dataloader.dataset.input_file}")
            else:
                print(f"Reading data from training path {train_dataloader.dataset.input_paths}")

        for batch_idx, input_data in enumerate(train_dataloader):
            # When resuming, skip the batches that were already consumed; the
            # skip only applies to the first file processed.
            if batch_idx < start_batch_index:
                if smp.rank() == 0:
                    print(f"Resuming from saved batch index {start_batch_index}, skipping batch {batch_idx}...")
                continue
            else:
                start_batch_index = 0

            if args.use_bert_data:
                input_ids, _, attention_mask, _, _ = input_data
            else:
                input_ids, attention_mask = input_data

            if total_steps >= args.max_steps:
                break

            step_start = time.time()

            optimizer.zero_grad(set_grads_to_None=True)

            if args.logits_output:
                train_output = train_step(model, optimizer, input_ids, attention_mask, args)
                loss_mb = train_output["loss"]
                logits_mb = train_output["logits"]
                # Per-microbatch logits are concatenated along dim 1 under
                # tensor parallelism and along dim 0 otherwise.
                if smp.tp_size() > 1:
                    logits = torch.cat(tuple(logits_mb.outputs), dim=1)
                else:
                    logits = torch.cat(tuple(logits_mb.outputs), dim=0)
            else:
                # Return value, loss_mb is a StepOutput object
                loss_mb = train_step(model, optimizer, input_ids, attention_mask, args)

            # smdistributed: Average the loss across microbatches.
            loss = loss_mb.reduce_mean()
            if not args.validation_freq:
                loss_metric = loss.item()

            if args.fp16:
                if args.megatron:
                    success, _, _ = optimizer.step()
                    overflow = not success
                else:
                    optimizer.update_master_grads()
                    optimizer.clip_master_grads(args.grad_clip)
                    optimizer.step()
                    overflow = optimizer.overflow
            else:
                # Without fp16 there is no master-gradient / overflow handling,
                # but the parameters still have to be updated; previously no
                # optimizer step was taken on this path at all.
                optimizer.step()
            # Skip the scheduler step only when an fp16 overflow was reported
            # (`overflow` is only evaluated when args.fp16 is truthy).
            if not (args.fp16 and overflow):
                lr_scheduler.step()

            total_steps += 1
            time_elapsed = time.time() - start
            step_time = time.time() - step_start
            sample_processed = input_ids.shape[0] * dp_size
            throughput = sample_processed / step_time
            if smp.rank() == 0 and not total_steps % args.logging_freq:
                print(
                    f"({int(time_elapsed)}s), Batch {total_steps - 1} Loss: {loss.item()}, Speed: {throughput} samples/sec"
                )

            # evaluate on validation
            if args.validation_freq and not (total_steps % args.validation_freq):
                # Remember numpy's RNG state so it can be restored after
                # evaluation when args.preserve_np_state is set.
                cur_state = np.random.get_state()
                model = model.eval()
                val_loss, val_ppl = eval_model(
                    model, val_dataloader, args.validation_batches, args.use_bert_data
                )
                if is_main_process(smp.rank()):
                    print(
                        f"({int(time.time()-start)}s) Batch {total_steps - 1} Validation loss: {val_loss}"
                    )
                    print(
                        f"({int(time.time()-start)}s) Batch {total_steps - 1} Validation perplexity: {val_ppl}"
                    )
                loss_metric = val_loss
                if args.logits_output:
                    to_save["val_loss"].append(val_loss)
                model = model.train()
                if args.preserve_np_state > 0:
                    np.random.set_state(cur_state)

            # checkpoint
            if not (total_steps % args.checkpoint_freq):
                base_path = f"trained_gpt_nparams-{num_params}_steps-{total_steps}.pt"
                out_path = os.path.join(args.checkpoint_dir, base_path)
                total_ckpts = total_steps // args.checkpoint_freq

                # save_or_verify_ckptsum if this is the last checkpoint
                # NOTE(review): as parenthesised, the second clause (next
                # checkpoint would exceed max_steps) triggers save_ckptsum
                # regardless of args.save_or_verify_ckptsum — confirm that
                # this grouping is intended.
                if (args.save_or_verify_ckptsum and total_steps >= args.max_steps) or (
                    (total_ckpts + 1) * args.checkpoint_freq
                ) > args.max_steps:
                    # Save optimizer and model tensor sums and scalars before saving
                    save_ckptsum(
                        args,
                        model,
                        optimizer,
                        filename=os.path.join(args.model_dir, "saved_partial_sum"),
                    )

                # batch_idx + 1 is stored so that a resumed run continues
                # with the batch after the one just processed.
                save(
                    out_path,
                    model,
                    optimizer,
                    lr_scheduler,
                    model_config,
                    num_params,
                    total_steps,
                    curr_train_path_index,
                    args,
                    partial=True,
                    batch_idx=batch_idx+1,
                )

                if smp.local_rank() == 0:
                    delete_oldest_ckpt(args)

            if args.logits_output:
                to_save["loss"].append(loss.item())

        if total_steps >= args.max_steps:
            # Final step reached: optionally dump the last logits and the
            # recorded losses, then leave the file loop.
            if should_record() and args.logits_output:
                to_save["logits"] = logits.detach().cpu()
                output_file = f"rank_{smp.rank()}_" + args.logits_output
                torch.save(to_save, os.path.join(args.model_dir, output_file))
                print(f"logits and loss saved at {os.path.join(args.model_dir, output_file)}")
            break

        # Drop the exhausted dataloader before switching to the next file.
        del train_dataloader

        if args.parallel_proc_data_processing:
            s = time.time()
            train_dataloader = dataset_future.result(timeout=None)
            wait_time = time.time() - s
            if wait_time > 1:
                # TODO if this happens, we should try num_workers>1 in dataloader
                print(
                    f"[{smp.rank()}] Waited {wait_time} for data loader to be ready. Please check if dataloader performance can be improved to avoid these waits."
                )
        else:
            train_dataloader = create_pretraining_dataloader(
                [train_paths[next_train_path_index]],
                args.train_batch_size,
                args.max_context_width,
                seed=args.seed,
                dp_rank=dp_rank,
                dp_size=dp_size,
                shuffle=args.same_seed < 1,
                zipped=args.zipped_data > 0,
                use_last_file_only=args.fast_validation > 0,
                data_type=data_type,
            )

    return total_steps, throughput, loss_metric
Beispiel #26
0
def main():
    """Train and evaluate the MNIST ``GroupedNet`` under smdistributed.modelparallel.

    Parses the command line, initialises SMP, prepares the MNIST loaders
    (sharding the training set across data-parallel replicas when DDP or
    Horovod is enabled), optionally restores a checkpoint, runs the
    train/test epochs, optionally writes partial and/or full checkpoints and
    finally, if requested, asserts on the resulting test loss.

    Raises:
        ValueError: if CUDA is not available.
        AssertionError: if ``args.assert_losses`` is set and a loss check fails.
    """
    parser = get_parser()
    args = parser.parse_args()
    if not torch.cuda.is_available():
        raise ValueError(
            "The script requires CUDA support, but CUDA not available")
    ddp_enabled = args.ddp > 0
    horovod_enabled = args.horovod > 0

    # Fix seeds in order to get the same losses across runs
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed,
                    torch.cuda.manual_seed):
        seed_fn(args.seed)

    smp.init({
        "microbatches": args.num_microbatches,
        "placement_strategy": "spread",
        "pipeline": args.pipeline,
        "optimize": "speed",
        "partitions": args.num_partitions,
        "horovod": horovod_enabled,
        "ddp": ddp_enabled,
    })

    # SM Distributed: Set the device to the GPU ID used by the current process.
    # Input tensors should be transferred to this device.
    torch.cuda.set_device(smp.local_rank())
    device = torch.device("cuda")
    loader_kwargs = {
        "batch_size": args.batch_size,
        "num_workers": 1,
        "pin_memory": True,
        "shuffle": False,
    }

    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.1307, ), (0.3081, ))])

    # SM Distributed: Download only on a single process per instance.
    # When this is not present, the file is corrupted by multiple processes
    # trying to download and extract at the same time. Every process then
    # opens the already-downloaded data after the barrier.
    if smp.local_rank() == 0:
        datasets.MNIST("../data",
                       train=True,
                       download=True,
                       transform=transform)
    smp.barrier()
    train_set = datasets.MNIST("../data",
                               train=True,
                               download=False,
                               transform=transform)

    # With data parallelism, give every replica an equal share of the
    # training set and select the share belonging to this replica.
    if (ddp_enabled or horovod_enabled) and smp.dp_size() > 1:
        shares = {}
        for replica in range(smp.dp_size()):
            shares[str(replica)] = 1 / smp.dp_size()
        train_set = SplitDataset(train_set, partitions=shares)
        train_set.select(str(smp.dp_rank()))

    # Download and create dataloaders for train and test dataset
    test_set = datasets.MNIST("../data", train=False, transform=transform)

    train_loader = torch.utils.data.DataLoader(train_set, **loader_kwargs)
    test_loader = torch.utils.data.DataLoader(test_set, **loader_kwargs)

    model = GroupedNet()

    # SMP handles the transfer of parameters to the right device
    # and the user doesn't need to call 'model.to' explicitly.
    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)

    # SM Distributed: Use the DistributedModel container to provide the model
    # to be partitioned across different ranks. For the rest of the script,
    # the returned DistributedModel object should be used in place of
    # the model provided for DistributedModel class instantiation.
    model = smp.DistributedModel(model)
    scaler = smp.amp.GradScaler()
    optimizer = smp.DistributedOptimizer(optimizer)

    # A partial checkpoint takes precedence over a full one when both are given.
    restore_from = None
    if args.partial_checkpoint:
        restore_from = (args.partial_checkpoint, True)
    elif args.full_checkpoint:
        restore_from = (args.full_checkpoint, False)
    if restore_from is not None:
        ckpt_path, is_partial = restore_from
        state = smp.load(ckpt_path, partial=is_partial)
        model.load_state_dict(state["model_state_dict"])
        optimizer.load_state_dict(state["optimizer_state_dict"])

    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    for epoch in range(1, args.epochs + 1):
        train(args, model, scaler, device, train_loader, optimizer, epoch)
        test_loss = test(args, model, device, test_loader)
        scheduler.step()

    # Checkpoints are only written from ranks with dp_rank 0.
    if args.save_partial_model and smp.dp_rank() == 0:
        smp.save(
            {
                "model_state_dict": model.local_state_dict(),
                "optimizer_state_dict": optimizer.local_state_dict(),
            },
            "./pt_mnist_checkpoint.pt",
            partial=True,
        )

    if args.save_full_model and smp.dp_rank() == 0:
        smp.save(
            {
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
            },
            "./pt_mnist_checkpoint.pt",
            partial=False,
        )

    # Waiting the save checkpoint to be finished before run another allgather_object
    smp.barrier()

    if not args.assert_losses:
        return

    if horovod_enabled or ddp_enabled:
        # SM Distributed: If using data parallelism, gather all losses across
        # different model replicas and check if losses match.
        losses = smp.allgather(test_loss, smp.DP_GROUP)
        for replica_loss in losses:
            assert math.isclose(replica_loss, losses[0])

        assert test_loss < 0.18
    else:
        assert test_loss < 0.08
Beispiel #27
0
def main():
    """Train a DiscreteVAE with SageMaker model parallelism (SMP) or DeepSpeed.

    Parses command-line arguments with ``get_parser()``, builds an
    ``ImageFolder`` dataset and ``DataLoader``, constructs the VAE together
    with an Adam optimizer and an exponential LR scheduler, and then runs the
    training loop for ``args.EPOCHS`` epochs.

    Every 10 iterations the loop builds preview reconstructions (on rank 0),
    writes a checkpoint to ``{args.model_dir}/vae.pt``, anneals the
    temperature and steps the LR scheduler.

    Two execution modes are supported, selected by ``args.model_parallel``:

    * model parallel: the model/optimizer are wrapped with
      ``smp.DistributedModel`` / ``smp.DistributedOptimizer`` and the forward
      and backward passes run inside ``@smp.step`` functions;
    * otherwise: the model is handed to ``deepspeed_utils.maybe_distribute``.

    Raises:
        ValueError: if CUDA is not available.
        AssertionError: if the dataset is empty, or if ``args.assert_losses``
            is set in model-parallel mode and the final loss checks fail.
    """
    parser = get_parser()
    args = parser.parse_args()

    if not torch.cuda.is_available():
        raise ValueError(
            "The script requires CUDA support, but CUDA not available")

    # Defaults for a process that is not the root worker / not distributed.
    args.rank = -1
    args.world_size = 1

    if args.model_parallel:
        args.deepspeed = False
        cfg = {
            "microbatches": args.num_microbatches,
            "placement_strategy": args.placement_strategy,
            "pipeline": args.pipeline,
            "optimize": args.optimize,
            "partitions": args.num_partitions,
            "horovod": args.horovod,
            "ddp": args.ddp,
        }

        smp.init(cfg)
        torch.cuda.set_device(smp.local_rank())
        args.rank = smp.dp_rank()
        args.world_size = smp.size()
    else:
        # initialize deepspeed
        print(f"args.deepspeed : {args.deepspeed}")
        deepspeed_utils.init_deepspeed(args.deepspeed)
        if deepspeed_utils.is_root_worker():
            args.rank = 0

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed + args.rank)
        np.random.seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)

    # args.LEARNING_RATE = args.LEARNING_RATE * float(args.world_size)

    cudnn.deterministic = True

    if cudnn.deterministic:
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    args.kwargs = {'num_workers': args.num_worker, 'pin_memory': True}

    device = torch.device("cuda")

    logger.debug(f"args.image_folder : {args.image_folder}")
    logger.debug(f"args.rank : {args.rank}")

    ## SageMaker
    # When SM_MODEL_DIR is set, take the model and data locations from the
    # environment instead of the command line.
    sm_model_dir = os.environ.get('SM_MODEL_DIR')
    if sm_model_dir is not None:
        args.model_dir = sm_model_dir
        #             args.output_dir = os.environ.get('SM_OUTPUT_DATA_DIR')
        args.image_folder = os.environ.get('SM_CHANNEL_TRAINING')
    else:
        logger.debug("not SageMaker")

    IMAGE_SIZE = args.image_size
    IMAGE_PATH = args.image_folder

    EPOCHS = args.EPOCHS
    BATCH_SIZE = args.BATCH_SIZE
    LEARNING_RATE = args.LEARNING_RATE
    LR_DECAY_RATE = args.LR_DECAY_RATE

    NUM_TOKENS = args.NUM_TOKENS
    NUM_LAYERS = args.NUM_LAYERS
    NUM_RESNET_BLOCKS = args.NUM_RESNET_BLOCKS
    SMOOTH_L1_LOSS = args.SMOOTH_L1_LOSS
    EMB_DIM = args.EMB_DIM
    HID_DIM = args.HID_DIM
    KL_LOSS_WEIGHT = args.KL_LOSS_WEIGHT

    STARTING_TEMP = args.STARTING_TEMP
    TEMP_MIN = args.TEMP_MIN
    ANNEAL_RATE = args.ANNEAL_RATE

    NUM_IMAGES_SAVE = args.NUM_IMAGES_SAVE

    #     transform = Compose(
    #         [
    #             RandomResizedCrop(args.image_size, args.image_size),
    #             OneOf(
    #                 [
    #                     IAAAdditiveGaussianNoise(),
    #                     GaussNoise(),
    #                 ],
    #                 p=0.2
    #             ),
    #             VerticalFlip(p=0.5),
    #             OneOf(
    #                 [
    #                     MotionBlur(p=.2),
    #                     MedianBlur(blur_limit=3, p=0.1),
    #                     Blur(blur_limit=3, p=0.1),
    #                 ],
    #                 p=0.2
    #             ),
    #             OneOf(
    #                 [
    #                     CLAHE(clip_limit=2),
    #                     IAASharpen(),
    #                     IAAEmboss(),
    #                     RandomBrightnessContrast(),
    #                 ],
    #                 p=0.3
    #             ),
    #             HueSaturationValue(p=0.3),
    # #             Normalize(
    # #                 mean=[0.485, 0.456, 0.406],
    # #                 std=[0.229, 0.224, 0.225],
    # #             )
    #         ],
    #         p=1.0
    #     )

    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize(IMAGE_SIZE),
        T.CenterCrop(IMAGE_SIZE),
        T.ToTensor()
    ])

    sampler = None
    dl = None

    # data
    logger.debug(f"IMAGE_PATH : {IMAGE_PATH}")
    #     ds = AlbumentationImageDataset(
    #         IMAGE_PATH,
    #         transform=transform,
    #         args=args
    #     )
    ds = ImageFolder(
        IMAGE_PATH,
        transform=transform,
    )

    # With data parallelism on top of model parallelism, give every
    # data-parallel rank an equal, distinct share of the dataset.
    if args.model_parallel and (args.ddp
                                or args.horovod) and smp.dp_size() > 1:
        partitions_dict = {
            f"{i}": 1 / smp.dp_size()
            for i in range(smp.dp_size())
        }
        ds = SplitDataset(ds, partitions=partitions_dict)
        ds.select(f"{smp.dp_rank()}")

    dl = DataLoader(ds,
                    BATCH_SIZE,
                    shuffle=True,
                    drop_last=args.model_parallel,
                    **args.kwargs)

    vae_params = dict(image_size=IMAGE_SIZE,
                      num_layers=NUM_LAYERS,
                      num_tokens=NUM_TOKENS,
                      codebook_dim=EMB_DIM,
                      hidden_dim=HID_DIM,
                      num_resnet_blocks=NUM_RESNET_BLOCKS)

    vae = DiscreteVAE(**vae_params,
                      smooth_l1_loss=SMOOTH_L1_LOSS,
                      kl_div_loss_weight=KL_LOSS_WEIGHT).to(device)
    # optimizer

    opt = Adam(vae.parameters(), lr=LEARNING_RATE)
    sched = ExponentialLR(optimizer=opt, gamma=LR_DECAY_RATE)

    if args.model_parallel:
        import copy
        # Plain (non-distributed) copies of the codebook and decoder, taken
        # before wrapping; they are used for the preview reconstructions in
        # the training loop.
        dummy_codebook = copy.deepcopy(vae.codebook)
        dummy_decoder = copy.deepcopy(vae.decoder)

        vae = smp.DistributedModel(vae)
        scaler = smp.amp.GradScaler()
        opt = smp.DistributedOptimizer(opt)

        if args.partial_checkpoint:
            args.checkpoint = smp.load(args.partial_checkpoint, partial=True)
            vae.load_state_dict(args.checkpoint["model_state_dict"])
            opt.load_state_dict(args.checkpoint["optimizer_state_dict"])
        elif args.full_checkpoint:
            args.checkpoint = smp.load(args.full_checkpoint, partial=False)
            vae.load_state_dict(args.checkpoint["model_state_dict"])
            opt.load_state_dict(args.checkpoint["optimizer_state_dict"])

    assert len(ds) > 0, 'folder does not contain any images'

    if (not args.model_parallel) and args.rank == 0:
        print(f'{len(ds)} images found for training')

        # weights & biases experiment tracking

        #         import wandb

        model_config = dict(num_tokens=NUM_TOKENS,
                            smooth_l1_loss=SMOOTH_L1_LOSS,
                            num_resnet_blocks=NUM_RESNET_BLOCKS,
                            kl_loss_weight=KL_LOSS_WEIGHT)

#         run = wandb.init(
#             project = 'dalle_train_vae',
#             job_type = 'train_model',
#             config = model_config
#         )

    def save_model(path):
        """Save hyper-parameters and weights of ``vae`` to ``path`` (rank 0 only)."""
        if args.rank != 0:
            return

        save_obj = {'hparams': vae_params, 'weights': vae.state_dict()}

        torch.save(save_obj, path)

    # distribute with deepspeed
    if not args.model_parallel:
        deepspeed_utils.check_batch_size(BATCH_SIZE)
        deepspeed_config = {'train_batch_size': BATCH_SIZE}

        (distr_vae, opt, dl, sched) = deepspeed_utils.maybe_distribute(
            args=args,
            model=vae,
            optimizer=opt,
            model_parameters=vae.parameters(),
            training_data=ds if args.deepspeed else dl,
            lr_scheduler=sched,
            config_params=deepspeed_config,
        )

    # The step functions below are only called in model-parallel mode, so
    # they are only defined there. Any error raised while defining them now
    # surfaces immediately instead of being swallowed and reappearing later
    # as a NameError.
    if args.model_parallel:
        # Rubik: Define smp.step. Return any tensors needed outside.
        @smp.step
        def train_step(vae, images, temp):
            """Run forward and backward for one microbatch; return loss and recons."""
            #             logger.debug(f"args.amp : {args.amp}")
            with autocast(enabled=(args.amp > 0)):
                loss, recons = vae(images,
                                   return_loss=True,
                                   return_recons=True,
                                   temp=temp)

            scaled_loss = scaler.scale(loss) if args.amp else loss
            vae.backward(scaled_loss)
            #             torch.nn.utils.clip_grad_norm_(vae.parameters(), 5)
            return loss, recons

        @smp.step
        def get_codes_step(vae, images, k):
            """Return the argmax codebook indices for the first ``k`` images."""
            images = images[:k]
            logits = vae.forward(images, return_logits=True)
            codebook_indices = logits.argmax(dim=1).flatten(1)
            return codebook_indices

        def hard_recons_step(dummy_decoder, dummy_codebook, codebook_indices):
            """Decode ``codebook_indices`` with the local codebook/decoder copies.

            Each module's ``forward`` is rebound to the original method
            obtained from ``smp_state.patch_manager`` before it is called.
            """
            from functools import partial
            for module in dummy_codebook.modules():
                method = smp_state.patch_manager.get_original_method(
                    "forward", type(module))
                module.forward = partial(method, module)
            image_embeds = dummy_codebook.forward(codebook_indices)
            b, n, d = image_embeds.shape
            h = w = int(sqrt(n))

            image_embeds = rearrange(image_embeds,
                                     'b (h w) d -> b d h w',
                                     h=h,
                                     w=w)
            for module in dummy_decoder.modules():
                method = smp_state.patch_manager.get_original_method(
                    "forward", type(module))
                module.forward = partial(method, module)
            hard_recons = dummy_decoder.forward(image_embeds)
            return hard_recons

    # starting temperature

    global_step = 0
    temp = STARTING_TEMP

    for epoch in range(EPOCHS):
        ##
        batch_time = util.AverageMeter('Time', ':6.3f')
        data_time = util.AverageMeter('Data', ':6.3f')
        losses = util.AverageMeter('Loss', ':.4e')
        top1 = util.AverageMeter('Acc@1', ':6.2f')
        top5 = util.AverageMeter('Acc@5', ':6.2f')
        progress = util.ProgressMeter(
            len(dl), [batch_time, data_time, losses, top1, top5],
            prefix="Epoch: [{}]".format(epoch))

        vae.train()
        start = time.time()

        for i, (images, _) in enumerate(dl):
            images = images.to(device, non_blocking=True)
            # Capture the batch size now: `images` is rebound to a preview
            # grid further down, so images.size(0) would no longer be the
            # number of samples when the loss meter is updated.
            n_samples = images.size(0)
            opt.zero_grad()

            if args.model_parallel:
                loss, recons = train_step(vae, images, temp)
                # Rubik: Average the loss across microbatches.
                loss = loss.reduce_mean()
                recons = recons.reduce_mean()
            else:
                loss, recons = distr_vae(images,
                                         return_loss=True,
                                         return_recons=True,
                                         temp=temp)

            if (not args.model_parallel) and args.deepspeed:
                # Gradients are automatically zeroed after the step
                distr_vae.backward(loss)
                distr_vae.step()
            elif args.model_parallel:
                if args.amp:
                    scaler.step(opt)
                    scaler.update()
                else:
                    # some optimizers like adadelta from PT 1.8 dont like it when optimizer.step is called with no param
                    if len(list(vae.local_parameters())) > 0:
                        opt.step()
            else:
                loss.backward()
                opt.step()

            logs = {}

            if i % 10 == 0:
                if args.rank == 0:
                    #                 if deepspeed_utils.is_root_worker():
                    k = NUM_IMAGES_SAVE

                    with torch.no_grad():
                        if args.model_parallel:
                            # Copy the current weights into the local
                            # codebook/decoder copies; the "decoder." and
                            # "codebook." prefixes are stripped so the keys
                            # match the sub-modules (strict=False ignores
                            # the keys that do not belong to each of them).
                            model_dict = vae.state_dict()
                            model_dict_updated = {}
                            for key, val in model_dict.items():
                                if "decoder" in key:
                                    key = key.replace("decoder.", "")
                                elif "codebook" in key:
                                    key = key.replace("codebook.", "")
                                model_dict_updated[key] = val

                            dummy_decoder.load_state_dict(model_dict_updated,
                                                          strict=False)
                            dummy_codebook.load_state_dict(model_dict_updated,
                                                           strict=False)
                            codes = get_codes_step(vae, images, k)
                            codes = codes.reduce_mean().to(torch.long)
                            hard_recons = hard_recons_step(
                                dummy_decoder, dummy_codebook, codes)
                        else:
                            codes = vae.get_codebook_indices(images[:k])
                            hard_recons = vae.decode(codes)

                    images, recons = map(lambda t: t[:k], (images, recons))
                    images, recons, hard_recons, codes = map(
                        lambda t: t.detach().cpu(),
                        (images, recons, hard_recons, codes))
                    # NOTE(review): `range=` is the older make_grid keyword;
                    # newer torchvision releases call it `value_range` -
                    # confirm against the pinned torchvision version.
                    images, recons, hard_recons = map(
                        lambda t: make_grid(t.float(),
                                            nrow=int(sqrt(k)),
                                            normalize=True,
                                            range=(-1, 1)),
                        (images, recons, hard_recons))

#                     logs = {
#                         **logs,
#                         'sample images':        wandb.Image(images, caption = 'original images'),
#                         'reconstructions':      wandb.Image(recons, caption = 'reconstructions'),
#                         'hard reconstructions': wandb.Image(hard_recons, caption = 'hard reconstructions'),
#                         'codebook_indices':     wandb.Histogram(codes),
#                         'temperature':          temp
#                     }

                if args.model_parallel:
                    filename = f'{args.model_dir}/vae.pt'
                    # dp_rank must be *called*: comparing the function object
                    # itself to 0 is always False and no checkpoint would
                    # ever be written.
                    if smp.dp_rank() == 0:
                        if args.save_full_model:
                            model_dict = vae.state_dict()
                            opt_dict = opt.state_dict()
                            smp.save(
                                {
                                    "model_state_dict": model_dict,
                                    "optimizer_state_dict": opt_dict
                                },
                                filename,
                                partial=False,
                            )
                        else:
                            model_dict = vae.local_state_dict()
                            opt_dict = opt.local_state_dict()
                            smp.save(
                                {
                                    "model_state_dict": model_dict,
                                    "optimizer_state_dict": opt_dict
                                },
                                filename,
                                partial=True,
                            )
                    # Wait for the checkpoint write before continuing.
                    smp.barrier()

                else:
                    save_model(f'{args.model_dir}/vae.pt')
                    #                     wandb.save(f'{args.model_dir}/vae.pt')

                # temperature anneal

                temp = max(temp * math.exp(-ANNEAL_RATE * global_step),
                           TEMP_MIN)

                # lr decay

                sched.step()

            # Collective loss, averaged
            if args.model_parallel:
                avg_loss = loss.detach().clone()
                #                 print("args.world_size : {}".format(args.world_size))
                avg_loss /= args.world_size

            else:
                avg_loss = deepspeed_utils.average_all(loss)

            if args.rank == 0:
                if i % 100 == 0:
                    lr = sched.get_last_lr()[0]
                    print(epoch, i, f'lr - {lr:6f}, loss - {avg_loss.item()},')

                    logs = {
                        **logs, 'epoch': epoch,
                        'iter': i,
                        'loss': avg_loss.item(),
                        'lr': lr
                    }

#                 wandb.log(logs)
            global_step += 1

            if args.rank == 0:
                # Every print_freq iterations, check the loss, accuracy, and speed.
                # For best performance, it doesn't make sense to print these metrics every
                # iteration, since they incur an allreduce and some host<->device syncs.

                # Measure accuracy
                #                 prec1, prec5 = util.accuracy(output, target, topk=(1, 5))

                # to_python_float incurs a host<->device sync
                losses.update(util.to_python_float(loss), n_samples)
                #                 top1.update(util.to_python_float(prec1), images.size(0))
                #                 top5.update(util.to_python_float(prec5), images.size(0))

                # Waiting until finishing operations on GPU (Pytorch default: async)
                torch.cuda.synchronize()
                # NOTE(review): `start` is only set once per epoch and `end`
                # is never read, so this measures time since the epoch began
                # - confirm whether `start` was meant to be reset here.
                batch_time.update((time.time() - start) / args.log_interval)
                end = time.time()

                print(
                    'Epoch: [{0}][{1}/{2}] '
                    'Train_Time={batch_time.val:.3f}: avg-{batch_time.avg:.3f}, '
                    'Train_Speed={3:.3f} ({4:.3f}), '
                    'Train_Loss={loss.val:.10f}:({loss.avg:.4f}),'.format(
                        epoch,
                        i,
                        len(dl),
                        args.world_size * BATCH_SIZE / batch_time.val,
                        args.world_size * BATCH_SIZE / batch_time.avg,
                        batch_time=batch_time,
                        loss=losses))

#         if deepspeed_utils.is_root_worker():
# save trained model to wandb as an artifact every epoch's end

#             model_artifact = wandb.Artifact('trained-vae', type = 'model', metadata = dict(model_config))
#             model_artifact.add_file(f'{args.model_dir}/vae.pt')
#             run.log_artifact(model_artifact)

    if args.rank == 0:
        #     if deepspeed_utils.is_root_worker():
        # save final vae and cleanup
        if args.model_parallel:
            logger.debug('save model_parallel')
        else:
            save_model(os.path.join(args.model_dir, 'vae-final.pt'))


#         wandb.save(f'{args.model_dir}/vae-final.pt')

#         model_artifact = wandb.Artifact('trained-vae', type = 'model', metadata = dict(model_config))
#         model_artifact.add_file(f'{args.model_dir}/vae-final.pt')
#         run.log_artifact(model_artifact)

#         wandb.finish()

    if args.model_parallel:
        if args.assert_losses:
            if args.horovod or args.ddp:
                # SM Distributed: If using data parallelism, gather all losses across different model
                # replicas and check if losses match.

                losses = smp.allgather(loss, smp.DP_GROUP)
                for l in losses:
                    print(l)
                    assert math.isclose(l, losses[0])

                assert loss < 0.18
            else:
                assert loss < 0.08

        smp.barrier()
        print("SMP training finished successfully")