def setup_training(args):
    """Initialise SMP (if requested), pick the CUDA device and validate ``args``.

    Side effects on ``args``: sets ``n_gpu`` to 1, may clear the two
    ``allreduce_post_accumulation*`` flags, and divides ``train_batch_size``
    by ``gradient_accumulation_steps``.

    Returns:
        A ``(device, args)`` tuple.

    Raises:
        ValueError: for an invalid ``gradient_accumulation_steps`` value, or
            when not resuming and ``output_dir`` already holds ``ckpt*`` files.
    """
    assert torch.cuda.is_available()

    if args.smp > 0:
        # Initialize SMP. The configuration is obtained from the parameters
        # passed to the Sagemaker PyTorch estimator.
        smp.init()

    # SMP: the GPU used by this process is selected by its local rank.
    # Input tensors should be transferred to this device.
    gpu_index = smp.local_rank()
    torch.cuda.set_device(gpu_index)
    device = torch.device("cuda", gpu_index)
    args.n_gpu = 1

    accumulation = args.gradient_accumulation_steps
    if accumulation == 1:
        args.allreduce_post_accumulation = False
        args.allreduce_post_accumulation_fp16 = False

    print(
        "device: {} n_gpu: {}, mp_rank: {}, rank: {}, distributed training: {}, 16-bits training: {}"
        .format(device, args.n_gpu, smp.mp_rank(), smp.rank(),
                bool(args.local_rank != -1), args.fp16))

    if accumulation < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(accumulation))
    if args.train_batch_size % accumulation != 0:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, batch size {} should be divisible"
            .format(accumulation, args.train_batch_size))

    # Per-step batch size once accumulation is taken into account.
    args.train_batch_size = args.train_batch_size // accumulation

    # A fresh run must not land in a directory that already holds checkpoints.
    if not args.resume_from_checkpoint and os.path.exists(args.output_dir):
        if any(entry.startswith("ckpt")
               for entry in os.listdir(args.output_dir)):
            raise ValueError(
                "Output directory ({}) already exists and is not empty.".format(
                    args.output_dir))

    needs_output_dir = (not args.resume_from_checkpoint
                        or not os.path.exists(args.output_dir))
    if needs_output_dir and is_main_process():
        os.makedirs(args.output_dir, exist_ok=True)

    return device, args
def wrapper(*args, **kwargs):
    """Call ``f`` and print how much allocated CUDA memory changed across the call.

    ``f`` is a free variable taken from the enclosing scope (not visible in
    this block). All positional and keyword arguments are forwarded unchanged.

    Returns:
        Whatever ``f`` returns. The previous version dropped this value, so a
        wrapped function always appeared to return ``None``.
    """
    import gc

    mem_before = torch.cuda.memory_allocated(device=smp.local_rank())
    result = f(*args, **kwargs)

    # Three collection passes, as in the original, before the second reading.
    for _ in range(3):
        gc.collect()

    mem_after = torch.cuda.memory_allocated(device=smp.local_rank())
    print(
        f"rank is {smp.local_rank()}, function name is {f.__name__}, memory usage is {humanize.naturalsize(mem_after - mem_before)}"
    )
    return result
def measure_additional_mem_context():
    """Generator that reports the change in allocated CUDA memory around a ``yield``.

    It synchronises with ``smp.barrier()`` before the first reading and again
    after printing. NOTE(review): presumably decorated with
    ``contextlib.contextmanager`` outside this view - confirm.
    """
    smp.barrier()
    allocated_at_entry = torch.cuda.memory_allocated(device=smp.local_rank())

    yield

    import gc

    # Three collection passes before taking the second reading.
    for _ in range(3):
        gc.collect()

    allocated_at_exit = torch.cuda.memory_allocated(device=smp.local_rank())
    print(
        f"rank is {smp.local_rank()}, memory usage is {humanize.naturalsize(allocated_at_exit - allocated_at_entry)}"
    )
    smp.barrier()
def dist_setting(args):
    """Populate the distributed-training fields of ``args`` and rescale lr/batch size.

    Three modes are handled:

    * ``args.data_parallel``: world size and ranks are read from ``sdp``.
    * ``args.model_parallel``: world size and ranks are read from ``smp``.
    * otherwise: ``torch.distributed`` (``dist``) is initialised with
      ``args.backend`` and a world size of ``len(args.hosts) * args.num_gpus``.

    Afterwards ``args.lr`` is multiplied by the world size and
    ``args.batch_size`` is floor-divided by ``world_size // num_gpus``
    (never dropping below 1).

    Returns:
        The mutated ``args`` object.
    """
    # args.data_parallel = False
    print("args.data_parallel : {}".format(args.data_parallel))
    print("args.model_parallel : {}".format(args.model_parallel))
    print("args.apex : {}".format(args.apex))

    args.world_size = 1
    args.host_num = args.hosts.index(args.current_host)

    if args.data_parallel:
        args.world_size = sdp.get_world_size()
        args.rank = sdp.get_rank()  # total rank in all hosts
        args.local_rank = sdp.get_local_rank()  # rank per host
    elif args.model_parallel:
        args.world_size = smp.size()
        args.local_rank = smp.local_rank()  # rank per host
        args.rank = smp.rank()
        args.dp_size = smp.dp_size()
        args.dp_rank = smp.dp_rank()
        print(
            "smp.rank() : {}, smp.size() : {}, smp.mp_rank() : {}, smp.local_size() : {}, smp.get_mp_group() : {}, smp.get_dp_group() : {}, smp.local_rank() : {}, smp.dp_size() : {}, smp.dp_rank() : {}"
            .format(smp.rank(), smp.size(), smp.mp_rank(), smp.local_size(),
                    smp.get_mp_group(), smp.get_dp_group(), smp.local_rank(),
                    smp.dp_size(), smp.dp_rank()))
    else:
        args.world_size = len(args.hosts) * args.num_gpus
        if args.local_rank is not None:
            # total rank in all hosts
            args.rank = args.num_gpus * args.host_num + args.local_rank

        dist.init_process_group(backend=args.backend,
                                rank=args.rank,
                                world_size=args.world_size)
        logger.info(
            'Initialized the distributed environment: \'{}\' backend on {} nodes. '
            .format(args.backend, dist.get_world_size()) +
            'Current host rank is {}. Number of gpus: {}'.format(
                dist.get_rank(), args.num_gpus))

    print("**** [dist_setting] args.rank : {}".format(args.rank))
    print("args.world_size : {}".format(args.world_size))
    print("Use GPU: {} for training".format(args.local_rank))

    args.lr = args.lr * float(args.world_size)

    # world_size // num_gpus is 0 whenever fewer processes than GPUs per host
    # are running; clamp the divisor so the division cannot raise
    # ZeroDivisionError. For every non-zero quotient the result is unchanged.
    batch_divisor = max(args.world_size // args.num_gpus, 1)
    args.batch_size //= batch_divisor
    args.batch_size = max(args.batch_size, 1)

    return args
def delete_oldest_ckpt(args, delete_on_rank0_only=False):
    """Remove the partial-checkpoint files of the oldest saved step, if needed.

    Files in ``args.checkpoint_dir`` whose names match the partial checkpoint
    pattern are grouped by their ``total_steps`` value. When the number of
    distinct steps has reached ``args.num_kept_checkpoints``, every file that
    belongs to the smallest step is deleted to make room for a new one.

    Args:
        args: needs ``checkpoint_dir`` and ``num_kept_checkpoints``.
        delete_on_rank0_only: when true only the process with
            ``smp.rank() == 0`` deletes; otherwise every process with
            ``smp.local_rank() == 0`` does.

    Returns:
        Always ``None``.
    """
    if delete_on_rank0_only:
        to_delete = smp.rank() == 0
    else:
        to_delete = smp.local_rank() == 0
    if not to_delete:
        return None

    # Raw strings: "\d" in a normal literal is an invalid escape sequence.
    # Compiled once so each file name is matched a single time.
    ckpt_pattern = re.compile(
        r"trained_gpt_nparams-(?P<num_params>\d+)_steps-(?P<total_steps>\d+)\.pt"
        # partial
        r"_(?P<pp_rank>\d+)_(?P<tp_rank>\d+)")

    paths_per_step = collections.defaultdict(list)
    for file_name in os.listdir(args.checkpoint_dir):
        matched = ckpt_pattern.match(file_name)
        if matched:
            step = int(matched.group("total_steps"))
            paths_per_step[step].append(
                os.path.join(args.checkpoint_dir, file_name))

    # Below the limit there is nothing to delete yet.
    if paths_per_step and len(paths_per_step) >= args.num_kept_checkpoints:
        oldest_step = min(paths_per_step)
        # delete oldest step to save the new one
        for stale_path in paths_per_step[oldest_step]:
            os.remove(stale_path)

    return None
def _setup_devices(self) -> "torch.device":
    """Select the ``torch.device`` for this process and record ``self._n_gpu``.

    Branches are tried in this order: CPU (``no_cuda``), TPU, SageMaker model
    parallel, SageMaker data parallel, DeepSpeed, single-process
    (``local_rank == -1``) and finally ``torch.distributed`` with NCCL.
    A CUDA device is also made the current device before it is returned.
    """

    def _activate(selected):
        # Only CUDA devices are registered as the current device.
        if selected.type == "cuda":
            torch.cuda.set_device(selected)
        return selected

    logger.info("PyTorch: setting up devices")

    if self.no_cuda:
        device = torch.device("cpu")
        self._n_gpu = 0
        return _activate(device)

    if is_torch_tpu_available():
        device = xm.xla_device()
        self._n_gpu = 0
        return _activate(device)

    if is_sagemaker_mp_enabled():
        smp_local_rank = smp.local_rank()
        device = torch.device("cuda", smp_local_rank)
        self._n_gpu = 1
        return _activate(device)

    if is_sagemaker_dp_enabled():
        sm_dist.init_process_group()
        self.local_rank = sm_dist.get_local_rank()
        device = torch.device("cuda", self.local_rank)
        self._n_gpu = 1
        return _activate(device)

    if self.deepspeed:
        # deepspeed performs its own DDP internally, and requires the program to be started with:
        # deepspeed  ./program.py
        # rather than:
        # python -m torch.distributed.launch --nproc_per_node=2 ./program.py
        from .integrations import is_deepspeed_available

        if not is_deepspeed_available():
            raise ImportError("--deepspeed requires deepspeed: `pip install deepspeed`.")
        import deepspeed

        deepspeed.init_distributed()

        # workaround for setups like notebooks where the launcher can't be used,
        # but deepspeed requires a dist env.
        # env LOCAL_RANK could be set manually by the user, or via init_distributed if mpi4py is installed
        self.local_rank = int(os.environ.get("LOCAL_RANK", "-1"))

        device = torch.device("cuda", self.local_rank)
        self._n_gpu = 1
        return _activate(device)

    if self.local_rank == -1:
        # if n_gpu is > 1 we'll use nn.DataParallel.
        # If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0`
        # Explicitly set CUDA to the first (index 0) CUDA device, otherwise `set_device` will
        # trigger an error that a device index is missing. Index 0 takes into account the
        # GPUs available in the environment, so `CUDA_VISIBLE_DEVICES=1,2` with `cuda:0`
        # will use the first GPU in that env, i.e. GPU#1
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        # Sometimes the line in the postinit has not been run before we end up here, so just checking we're not at
        # the default value.
        self._n_gpu = torch.cuda.device_count()
        return _activate(device)

    # Here, we'll use torch.distributed.
    # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
    torch.distributed.init_process_group(backend="nccl")
    device = torch.device("cuda", self.local_rank)
    self._n_gpu = 1
    return _activate(device)
def smp_init(args):
    """Initialise SMP from ``args`` and mirror the resulting ranks into ``args`` and the environment.

    Sets ``args.deepspeed`` to ``False``, calls ``smp.init`` with a config built
    from ``args``, stores ``rank`` (the data-parallel rank), ``global_rank``,
    ``world_size`` and ``local_rank`` on ``args``, exports ``RANK``,
    ``WORLD_SIZE``, ``LOCAL_RANK`` and ``SMP_SKIP_GRAPH_VALIDATION`` and makes
    the local-rank GPU the current CUDA device.

    Returns:
        The mutated ``args`` object.
    """
    args.deepspeed = False

    smp.init({
        "microbatches": args.num_microbatches,
        "placement_strategy": args.placement_strategy,
        "pipeline": args.pipeline,
        "optimize": args.optimize,
        "partitions": args.num_partitions,
        "horovod": args.horovod,
        "ddp": args.ddp,
    })

    args.rank = smp.dp_rank()
    args.global_rank = smp.rank()
    args.world_size = smp.size()
    gpu_index = smp.local_rank()

    os.environ.update({
        "RANK": str(args.rank),
        "WORLD_SIZE": str(args.world_size),
        "LOCAL_RANK": str(gpu_index),
        # Graph validation is left enabled ("1" would skip it).
        "SMP_SKIP_GRAPH_VALIDATION": "0",
    })

    torch.cuda.set_device(gpu_index)
    args.local_rank = gpu_index

    return args
def is_world_process_zero(self) -> bool:
    """
    Whether or not this process is the global main process (when training in a distributed fashion on several
    machines, this is only going to be :obj:`True` for one process).

    With model parallelism enabled the process qualifies only when its SMP
    global, local, model-parallel and data-parallel ranks are all zero;
    otherwise the decision is delegated to the parent class.
    """
    if not self.is_model_parallel_enabled:
        return super().is_world_process_zero()

    # Checked in this order, stopping at the first non-zero rank.
    rank_probes = (smp.rank, smp.local_rank, smp.mp_rank, smp.dp_rank)
    return all(probe() == 0 for probe in rank_probes)
def local_process_index(self):
    """
    The index of the local process used.

    Resolution order: TPU local ordinal, SageMaker model-parallel local rank,
    SageMaker data-parallel local rank, then ``self.local_rank`` when it is
    set (not ``-1``). Falls back to ``0``.
    """
    if is_torch_tpu_available():
        return xm.get_local_ordinal()
    elif is_sagemaker_mp_enabled():
        return smp.local_rank()
    elif is_sagemaker_dp_enabled():
        # The global rank (``get_rank``) was returned here before, which is
        # only equal to the local index on a single host.
        return sm_dist.get_local_rank()
    elif self.local_rank != -1:
        return self.local_rank
    return 0
def _setup_devices(self) -> "torch.device":
    """Select the ``torch.device`` for this process and record ``self._n_gpu``.

    Warns when a ``torch.distributed`` process group already exists while
    ``local_rank`` is still ``-1``. Branch order: CPU (``no_cuda``), SageMaker
    model parallel, SageMaker data parallel (``smddp`` backend),
    single-process, then ``torch.distributed`` with NCCL. A CUDA device is
    also made the current device before it is returned.
    """

    def _activate(selected):
        # Only CUDA devices are registered as the current device.
        if selected.type == "cuda":
            torch.cuda.set_device(selected)
        return selected

    logger.info("PyTorch: setting up devices")

    group_already_up = (torch.distributed.is_available()
                        and torch.distributed.is_initialized())
    if group_already_up and self.local_rank == -1:
        logger.warning(
            "torch.distributed process group is initialized, but local_rank == -1. "
            "In order to use Torch DDP, launch your script with `python -m torch.distributed.launch"
        )

    if self.no_cuda:
        device = torch.device("cpu")
        self._n_gpu = 0
        return _activate(device)

    if is_sagemaker_model_parallel_available():
        smp_local_rank = smp.local_rank()
        device = torch.device("cuda", smp_local_rank)
        self._n_gpu = 1
        return _activate(device)

    if is_sagemaker_dp_enabled():
        # Imported for its side effect before the "smddp" backend is requested.
        import smdistributed.dataparallel.torch.torch_smddp  # noqa: F401

        torch.distributed.init_process_group(backend="smddp")
        self.local_rank = int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))
        device = torch.device("cuda", self.local_rank)
        self._n_gpu = 1
        return _activate(device)

    if self.local_rank == -1:
        # if n_gpu is > 1 we'll use nn.DataParallel.
        # If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0`
        # Explicitly set CUDA to the first (index 0) CUDA device, otherwise `set_device` will
        # trigger an error that a device index is missing. Index 0 takes into account the
        # GPUs available in the environment, so `CUDA_VISIBLE_DEVICES=1,2` with `cuda:0`
        # will use the first GPU in that env, i.e. GPU#1
        device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")
        # Sometimes the line in the postinit has not been run before we end up here, so just checking we're not at
        # the default value.
        self._n_gpu = torch.cuda.device_count()
        return _activate(device)

    # Here, we'll use torch.distributed.
    # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl")
    device = torch.device("cuda", self.local_rank)
    self._n_gpu = 1
    return _activate(device)
def memory_status(msg="", reset_max=True, sync=True):
    """Print CUDA memory statistics for this process's local-rank device.

    Only processes whose ``smp.rdp_rank()`` is 0 print; all others return
    after the optional synchronisation.

    Args:
        msg: label placed in square brackets at the start of the line.
        reset_max: when true, the peak-memory counters of the reported device
            are reset after printing.
        sync: when true, ``torch.cuda.synchronize()`` is called first.
    """
    rank = smp.rank()
    tp_rank = smp.tp_rank()
    pp_rank = smp.pp_rank()
    rdp_rank = smp.rdp_rank()
    local_rank = smp.local_rank()

    if sync:
        torch.cuda.synchronize()

    if rdp_rank != 0:
        return

    # ``py3nvml`` is expected to be None when the package is unavailable.
    if py3nvml is not None:
        py3nvml.nvmlInit()
        try:
            handle = py3nvml.nvmlDeviceGetHandleByIndex(local_rank)
            info = py3nvml.nvmlDeviceGetMemoryInfo(handle)
        finally:
            # Always pair nvmlInit with nvmlShutdown, even if a query fails.
            py3nvml.nvmlShutdown()
        total_used = info.used / 1024**3
        total_used_str = f"Totally used GPU memory: {total_used}"
    else:
        total_used_str = ""

    # convert to GB for printing
    gib = 1024**3
    alloced = torch.cuda.memory_allocated(device=local_rank) / gib
    max_alloced = torch.cuda.max_memory_allocated(device=local_rank) / gib
    cached = torch.cuda.memory_reserved(device=local_rank) / gib
    max_cached = torch.cuda.max_memory_reserved(device=local_rank) / gib

    # Two positional arguments on purpose: print() joins them with a space.
    print(
        f'[{msg}] rank {rank} tp_rank {tp_rank} pp_rank {pp_rank} TORCH {torch.__version__}',
        f'device={local_rank} '
        f'alloc {alloced:0.4f} max_alloced {max_alloced:0.4f} '
        f'cache {cached:0.4f} max_cached {max_cached:0.4f} '
        f'{total_used_str}')

    if reset_max:
        # Replaces the deprecated reset_max_memory_cached/allocated pair and
        # targets the same device whose statistics were just reported.
        torch.cuda.reset_peak_memory_stats(device=local_rank)
def dist_setting(args):
    """Populate the distributed-training fields of ``args`` and rescale lr/batch size.

    Three modes are handled:

    * ``args.data_parallel``: the data-parallel module is obtained through
      ``_sdp_import(args)`` and supplies world size and ranks.
    * ``args.model_parallel``: ranks come from ``smp``; the world size is
      ``args.num_gpus * len(args.hosts)``.
    * otherwise: ``torch.distributed`` (``dist``) is initialised with
      ``args.backend`` and a world size of ``len(args.hosts) * args.num_gpus``.

    Afterwards ``args.lr`` is multiplied by the world size and
    ``args.batch_size`` is floor-divided by ``world_size // num_gpus``
    (never dropping below 1).

    Returns:
        The mutated ``args`` object.
    """
    # args.data_parallel = False
    print(f"args.data_parallel : {args.data_parallel}, args.model_parallel : {args.model_parallel}, args.apex : {args.apex}")

    args.world_size = 1
    args.host_num = args.hosts.index(args.current_host)

    if args.data_parallel:
        # The second element of the returned pair is not needed here.
        sdp, _ = _sdp_import(args)

        args.world_size = sdp.get_world_size()
        args.rank = sdp.get_rank()  # total rank in all hosts
        args.local_rank = sdp.get_local_rank()  # rank per host
    elif args.model_parallel:
        # The original first stored smp.size() and immediately overwrote it;
        # only the effective value is kept.
        args.world_size = args.num_gpus * len(args.hosts)
        args.local_rank = smp.local_rank()  # rank per host
        args.rank = smp.rank()
        args.dp_size = smp.dp_size()
        args.dp_rank = smp.dp_rank()
    else:
        args.world_size = len(args.hosts) * args.num_gpus
        if args.local_rank is not None:
            # total rank in all hosts
            args.rank = args.num_gpus * args.host_num + args.local_rank

        dist.init_process_group(backend=args.backend,
                                rank=args.rank,
                                world_size=args.world_size)
        logger.info(
            'Initialized the distributed environment: \'{}\' backend on {} nodes. '
            .format(args.backend, dist.get_world_size()) +
            'Current host rank is {}. Number of gpus: {}'.format(
                dist.get_rank(), args.num_gpus))

    # if not args.model_parallel:
    args.lr = args.lr * float(args.world_size)

    # world_size // num_gpus is 0 whenever the world is smaller than the
    # per-host GPU count; clamp the divisor so the division cannot raise
    # ZeroDivisionError. For every non-zero quotient the result is unchanged.
    batch_divisor = max(args.world_size // args.num_gpus, 1)
    args.batch_size //= batch_divisor
    args.batch_size = max(args.batch_size, 1)

    return args
def _setup_devices(self) -> "torch.device":
    """Select the ``torch.device`` for this process and record ``self._n_gpu``.

    Branch order: CPU (``no_cuda``), SageMaker model parallel (smdistributed
    available and non-empty ``mp_parameters``), SageMaker data parallel,
    single-process, then ``torch.distributed`` with NCCL. A CUDA device is
    also made the current device before it is returned.
    """

    def _activate(selected):
        # Only CUDA devices are registered as the current device.
        if selected.type == "cuda":
            torch.cuda.set_device(selected)
        return selected

    logger.info("PyTorch: setting up devices")

    if self.no_cuda:
        device = torch.device("cpu")
        self._n_gpu = 0
        return _activate(device)

    if is_smdistributed_available() and self.mp_parameters != "":
        # smp.init() is not called here.
        smp_local_rank = smp.local_rank()
        device = torch.device("cuda", smp_local_rank)
        self._n_gpu = 1
        return _activate(device)

    if is_sagemaker_distributed_available():
        import smdistributed.dataparallel.torch.distributed as dist

        dist.init_process_group()
        self.local_rank = dist.get_local_rank()
        device = torch.device("cuda", self.local_rank)
        self._n_gpu = 1
        return _activate(device)

    if self.local_rank == -1:
        # if n_gpu is > 1 we'll use nn.DataParallel.
        # If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0`
        # Explicitly set CUDA to the first (index 0) CUDA device, otherwise `set_device` will
        # trigger an error that a device index is missing. Index 0 takes into account the
        # GPUs available in the environment, so `CUDA_VISIBLE_DEVICES=1,2` with `cuda:0`
        # will use the first GPU in that env, i.e. GPU#1
        device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")
        # Sometimes the line in the postinit has not been run before we end up here, so just checking we're not at
        # the default value.
        self._n_gpu = torch.cuda.device_count()
        return _activate(device)

    # Here, we'll use torch.distributed.
    # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
    torch.distributed.init_process_group(backend="nccl")
    device = torch.device("cuda", self.local_rank)
    self._n_gpu = 1
    return _activate(device)
def memory_status_cpu(msg=""):
    """Print host-memory statistics for this process.

    Walks the objects tracked by ``gc`` to total the size of non-CUDA tensors
    (``numel`` times the ``dtype_to_bit`` entry for the dtype), reads the
    process data-segment size from the module-level ``process`` handle, and
    updates the module-level ``last_mem_usage``. Only processes whose
    ``smp.rdp_rank()`` is 0 print; the bookkeeping runs on every process.
    """
    import gc

    global last_mem_usage
    global base_mem_usage

    for _ in range(3):
        gc.collect()

    cpu_tensors = [
        candidate for candidate in gc.get_objects()
        if isinstance(candidate, torch.Tensor) and not candidate.is_cuda
    ]
    torch_usage = sum(
        tensor.numel() * dtype_to_bit[tensor.dtype] for tensor in cpu_tensors)

    # psutil.virtual_memory()[3] would report usage across all processes,
    # so the per-process figure is used instead.
    current_usage = process.memory_info().data
    total_usage = current_usage - base_mem_usage
    usage_change = current_usage - last_mem_usage
    last_mem_usage = current_usage

    gib = 1024**3
    torch_usage /= gib
    total_usage /= gib
    usage_change /= gib
    base_usage = base_mem_usage / gib

    rank = smp.rank()
    tp_rank = smp.tp_rank()
    pp_rank = smp.pp_rank()
    rdp_rank = smp.rdp_rank()
    local_rank = smp.local_rank()
    if rdp_rank != 0:
        return

    # Two positional arguments on purpose: print() joins them with a space.
    print(
        f'[{msg}] rank {rank} tp_rank {tp_rank} pp_rank {pp_rank} TORCH {torch.__version__}',
        f'device={local_rank} '
        f'torch cpu tensor usage {torch_usage:0.4f} cpu mem usage {total_usage:0.4f} change since last measurement {usage_change:0.4f} base cpu mem usage {base_usage:0.4f}'
    )
def main():
    """Run the pretraining (or evaluation) loop driven by command-line arguments.

    Seeds the RNGs with ``args.seed + args.local_rank``, calls
    ``setup_training`` and ``prepare_model_and_optimizer``, then loops over
    epochs and data files until ``args.steps_this_run`` is reached or the
    module-level ``timeout_sent`` flag is set.

    Returns:
        In training mode (``args.do_train``):
        ``(args, final_loss, train_time_raw, global_step)``.
        In evaluation mode: the mean of the collected test losses.
    """
    global timeout_sent

    args = parse_arguments()

    # Every RNG is offset by the local rank.
    random.seed(args.seed + args.local_rank)
    np.random.seed(args.seed + args.local_rank)
    torch.manual_seed(args.seed + args.local_rank)
    torch.cuda.manual_seed(args.seed + args.local_rank)
    worker_init = WorkerInitObj(args.seed + args.local_rank)

    device, args = setup_training(args)

    # Prepare optimizer
    (
        model,
        optimizer,
        lr_scheduler,
        checkpoint,
        global_step,
        criterion,
    ) = prepare_model_and_optimizer(args, device)

    raw_train_start = None
    most_recent_ckpts_paths = []
    average_loss = 0.0  # averaged loss every args.log_freq steps
    epoch = 0
    training_steps = 0
    test_losses = []

    # Single background worker used to prepare the next data file.
    pool = ProcessPoolExecutor(1)

    # Note: We loop infinitely over epochs, termination is handled via iteration count
    while True:
        thread = None
        restored_data_loader = None
        if (not args.resume_from_checkpoint or epoch > 0
                or (args.phase2 and global_step < 1) or args.init_checkpoint):
            # Fresh file list: every regular file in input_dir whose name
            # contains "training", shuffled deterministically per epoch.
            files = [
                os.path.join(args.input_dir, f)
                for f in os.listdir(args.input_dir)
                if os.path.isfile(os.path.join(args.input_dir, f))
                and "training" in f
            ]
            files.sort()
            num_files = len(files)
            random.Random(args.seed + epoch).shuffle(files)
            f_start_id = 0
        else:
            # Resume the file list stored in the checkpoint: the first entry
            # is the file index, the remainder are the file paths.
            f_start_id = checkpoint["files"][0]
            files = checkpoint["files"][1:]
            args.resume_from_checkpoint = False
            num_files = len(files)
            # may not exist in all checkpoints
            epoch = checkpoint.get("epoch", 0)
            # NOTE(review): this binds `restored_dataloader`, but the check
            # below reads `restored_data_loader`, which is still None here.
            # As written the checkpointed loader is never used - confirm.
            restored_dataloader = checkpoint.get("data_loader", None)

        shared_file_list = {}

        # Data-parallel size/rank: SMP first, then torch.distributed, else 1/0.
        if smp.is_initialized():
            dpsize = smp.dp_size()
            dprank = smp.dp_rank()
        elif torch.distributed.is_initialized():
            dpsize = get_world_size()
            dprank = get_rank()
        else:
            dpsize = 1
            dprank = 0
        dparallel = dpsize > 1
        if dparallel and dpsize > num_files:
            remainder = dpsize % num_files
            data_file = files[(f_start_id * dpsize + dprank +
                               remainder * f_start_id) % num_files]
        else:
            data_file = files[(f_start_id * dpsize + dprank) % num_files]

        previous_file = data_file

        if restored_data_loader is None:
            train_data = pretraining_dataset(data_file,
                                             args.max_predictions_per_seq)
            train_sampler = RandomSampler(train_data)
            train_dataloader = DataLoader(
                train_data,
                sampler=train_sampler,
                batch_size=args.train_batch_size * args.n_gpu,
                num_workers=4,
                worker_init_fn=worker_init,
                pin_memory=True,
                drop_last=True,
            )
            # shared_file_list["0"] = (train_dataloader, data_file)
        else:
            train_dataloader = restored_data_loader
            restored_data_loader = None

        overflow_buf = None
        if args.allreduce_post_accumulation:
            overflow_buf = torch.cuda.IntTensor([0])

        for f_id in range(f_start_id + 1, len(files)):
            # Choose the file to prefetch for the next iteration.
            # NOTE(review): `remainder` is only assigned above when
            # `dparallel and dpsize > num_files`; this test uses
            # get_world_size() instead - confirm the first branch cannot be
            # reached with `remainder` unbound.
            if get_world_size() > num_files:
                data_file = files[(f_id * get_world_size() + get_rank() +
                                   remainder * f_id) % num_files]
            else:
                data_file = files[(f_id * get_world_size() + get_rank()) %
                                  num_files]

            previous_file = data_file

            # Build the next dataset in the background while this one trains.
            dataset_future = pool.submit(
                create_pretraining_dataset,
                data_file,
                args.max_predictions_per_seq,
                shared_file_list,
                args,
                worker_init,
            )

            # Only the main process shows a progress bar.
            train_iter = (tqdm(train_dataloader,
                               desc="Iteration",
                               disable=args.disable_progress_bar)
                          if is_main_process() else train_dataloader)

            if raw_train_start is None:
                raw_train_start = time.time()
            for step, batch in enumerate(train_iter):
                training_steps += 1
                batch = [t.to(device) for t in batch]
                input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels = batch
                if args.do_train:
                    # NOTE(review): neither imported name is used in this
                    # function; the import runs on every training step.
                    from smdistributed.modelparallel.test.torch.utils import dump_model, verify

                    model.train()
                    if args.smp > 0:
                        loss_mbs = smp_step(
                            args,
                            device,
                            input_ids,
                            segment_ids,
                            input_mask,
                            masked_lm_labels,
                            next_sentence_labels,
                            model,
                            optimizer,
                            criterion,
                            step,
                        )
                        # Average the per-microbatch losses.
                        loss = loss_mbs.reduce_mean()
                        if smp.rank() == 0:
                            print("Loss:", loss.item())
                    else:
                        loss = train_step(
                            args,
                            device,
                            input_ids,
                            segment_ids,
                            input_mask,
                            masked_lm_labels,
                            next_sentence_labels,
                            model,
                            optimizer,
                            criterion,
                            step,
                        )
                    divisor = 1
                    average_loss += loss.item()

                    # The optimizer only steps once per accumulation window.
                    if training_steps % args.gradient_accumulation_steps == 0:
                        lr_scheduler.step()  # learning rate warmup
                        global_step = take_optimizer_step(
                            args, optimizer, model, overflow_buf,
                            global_step)

                    if global_step >= args.steps_this_run or timeout_sent:
                        # Final statistics for this run.
                        train_time_raw = time.time() - raw_train_start
                        last_num_steps = (int(
                            training_steps / args.gradient_accumulation_steps)
                                          % args.log_freq)
                        last_num_steps = args.log_freq if last_num_steps == 0 else last_num_steps
                        average_loss = torch.tensor(
                            average_loss, dtype=torch.float32).cuda()
                        average_loss = average_loss / (last_num_steps *
                                                       divisor)
                        if torch.distributed.is_initialized():
                            average_loss /= get_world_size()
                            torch.distributed.all_reduce(average_loss)
                        # The returned loss is the last step's loss, not the
                        # average computed just above.
                        final_loss = loss.item()
                    elif training_steps % (
                            args.log_freq *
                            args.gradient_accumulation_steps) == 0:
                        average_loss = 0

                    if (global_step >= args.steps_this_run or training_steps %
                        (args.num_steps_per_checkpoint *
                         args.gradient_accumulation_steps) == 0
                            or timeout_sent):
                        if smp.dp_rank() == 0 and not args.skip_checkpoint:
                            if args.resume_step < 0 or not args.phase2:
                                output_save_file = os.path.join(
                                    args.output_dir,
                                    "ckpt_{}.pt".format(global_step))
                            else:
                                # In phase 2 the step is offset by the end of
                                # phase 1.
                                output_save_file = os.path.join(
                                    args.output_dir,
                                    "ckpt_{}.pt".format(global_step +
                                                        args.phase1_end_step),
                                )
                            if args.do_train:
                                save_dict = {
                                    "model":
                                    model.local_state_dict(),
                                    "optimizer":
                                    optimizer.local_state_dict(),
                                    "files": [f_id] + files,
                                    "epoch":
                                    epoch,
                                    "data_loader":
                                    None
                                    if global_step >= args.steps_this_run else
                                    train_dataloader,
                                }
                                if args.fp16:
                                    save_dict["master params"] = list(
                                        amp.master_params(optimizer))
                                # SMP: Checkpoint mp_rank specific state
                                smp.save(save_dict,
                                         output_save_file,
                                         partial=True)

                                # At most three partial checkpoints are kept;
                                # the oldest one is removed.
                                most_recent_ckpts_paths.append(
                                    output_save_file)
                                if len(most_recent_ckpts_paths) > 3 and (
                                        args.smp == 0 or smp.dp_rank() == 0):
                                    ckpt_to_be_removed = most_recent_ckpts_paths.pop(
                                        0)
                                    os.remove(ckpt_to_be_removed +
                                              f"_{smp.mp_rank()}")

                        # Exiting the training due to hitting max steps, or being sent a
                        # timeout from the cluster scheduler
                        if global_step >= args.steps_this_run or timeout_sent:
                            del train_dataloader
                            # thread.join()
                            if smp.dp_rank() == 0 and args.save_full:
                                output_save_file = os.path.join(
                                    args.output_dir,
                                    "ckpt_{}.pt".format(global_step))
                                # NOTE(review): `train_dataloader` was deleted
                                # just above; if this point is reached through
                                # `timeout_sent` with global_step below
                                # steps_this_run, the "data_loader" entry
                                # reads an unbound name - confirm.
                                save_dict = {
                                    "model":
                                    model.local_state_dict(),
                                    "optimizer":
                                    optimizer.local_state_dict(),
                                    "files": [f_id] + files,
                                    "epoch":
                                    epoch,
                                    "data_loader":
                                    None
                                    if global_step >= args.steps_this_run else
                                    train_dataloader,
                                }
                                if args.fp16:
                                    save_dict["master params"] = list(
                                        amp.master_params(optimizer))
                                # SMP: Save a single checkpoint containing entire model parameters
                                smp.save(save_dict,
                                         output_save_file,
                                         partial=False)
                            smp.barrier()
                            if smp.local_rank() == 0:
                                # One process per host uploads output_dir.
                                print(f"Start syncing model checkpoints to s3")
                                base_s3_path = os.path.dirname(
                                    os.path.dirname(
                                        os.getenv("SM_MODULE_DIR", "")))
                                curr_host = os.getenv("SM_CURRENT_HOST")
                                full_s3_path = f"{base_s3_path}/checkpoints/{curr_host}/"
                                sync_local_checkpoints_to_s3(
                                    local_path=args.output_dir,
                                    s3_path=full_s3_path)
                                print(
                                    f"Finished syncing model checkpoints to s3"
                                )
                            return args, final_loss, train_time_raw, global_step
                else:
                    # Evaluation: no gradients, one "global step" per batch.
                    model.eval()
                    with torch.no_grad():
                        loss = test_step(
                            args,
                            device,
                            input_ids,
                            segment_ids,
                            input_mask,
                            masked_lm_labels,
                            next_sentence_labels,
                            model,
                            criterion,
                            step,
                        )
                        print(f"global_step {global_step} Test Loss:", loss)
                        test_losses.append(loss)
                    global_step += 1
                    if global_step >= args.steps_this_run:
                        return sum(test_losses) / len(test_losses)

            del train_dataloader
            # thread.join()
            # Make sure pool has finished and switch train_dataloader
            # NOTE: Will block until complete
            train_dataloader, data_file = dataset_future.result(timeout=None)

        epoch += 1
def prepare_model_and_optimizer(args, device):
    """Build the BERT pretraining model, optimizer, LR scheduler and criterion.

    Optionally wraps model and optimizer in their SMP distributed containers,
    restores state from a checkpoint (the latest ``ckpt_*.pt`` in
    ``args.output_dir`` or ``args.init_checkpoint``) and initialises AMP when
    ``args.fp16`` is set.

    Returns:
        ``(model, optimizer, lr_scheduler, checkpoint, global_step, criterion)``;
        ``checkpoint`` is ``None`` when not resuming.

    Raises:
        ValueError: when resuming without ``init_checkpoint`` and
            ``s3_checkpoint_uri`` is not set.
    """
    # Prepare model
    config = modeling.BertConfig.from_json_file(args.config_file)

    # Padding for divisibility by 8
    if config.vocab_size % 8 != 0:
        config.vocab_size += 8 - (config.vocab_size % 8)
    if args.use_sequential > 0:
        config.use_sequential = True
    else:
        config.use_sequential = False

    modeling.ACT2FN["bias_gelu"] = modeling.bias_gelu_training
    model = modeling.BertForPreTraining(config)
    model.checkpoint_activations(args.checkpoint_activations)
    if args.smp > 0:
        # SMP: Use the DistributedModel container to provide the model
        # to be partitioned across different ranks. For the rest of the script,
        # the returned DistributedModel object should be used in place of
        # the model provided for DistributedModel class instantiation.
        model = smp.DistributedModel(model)

    checkpoint = None
    if not args.resume_from_checkpoint:
        global_step = 0
    else:
        if not args.init_checkpoint:
            if not args.s3_checkpoint_uri:
                raise ValueError(
                    "Need to set s3_checkpoint_uri, if init_checkpoint not set"
                )
            # One process per host downloads; everyone waits at the barrier.
            if smp.local_rank() == 0:
                sync_s3_checkpoints_to_local(args.output_dir,
                                             args.s3_checkpoint_uri)
            smp.barrier()
        if args.resume_step == -1 and not args.init_checkpoint:
            # Pick the highest step number among files named "<x>_<step>.pt...".
            # NOTE(review): max() raises ValueError if no ".pt" file exists.
            model_names = [
                f for f in os.listdir(args.output_dir) if ".pt" in f
            ]
            args.resume_step = max([
                int(x.split(".pt")[0].split("_")[1].strip())
                for x in model_names
            ])

        global_step = args.resume_step if not args.init_checkpoint else 0

        # SMP: Load a model that was saved with smp.save
        if not args.init_checkpoint:
            checkpoint = smp.load(
                os.path.join(args.output_dir,
                             "ckpt_{}.pt".format(global_step)),
                partial=args.partial_checkpoint,
            )
        else:
            checkpoint = smp.load(args.init_checkpoint)

        model.load_state_dict(checkpoint["model"], strict=False)

        # Phase-2 step counting restarts after the end of phase 1.
        if args.phase2 and not args.init_checkpoint:
            global_step -= args.phase1_end_step
        if is_main_process():
            print("resume step from ", args.resume_step)

    model.to(device)
    param_optimizer = list(model.named_parameters())
    # Parameters whose name contains one of these substrings get no weight decay.
    no_decay = ["bias", "gamma", "beta", "LayerNorm"]

    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.01,
        },
        {
            "params":
            [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            "weight_decay":
            0.0,
        },
    ]

    optimizer = FusedLAMB(optimizer_grouped_parameters, lr=args.learning_rate)
    if args.smp > 0:
        # SMP: Use Distributed Optimizer which allows the loading of optimizer state for a distributed model
        # Also provides APIs to obtain local optimizer state for the current mp_rank.
        optimizer = smp.DistributedOptimizer(optimizer)
    lr_scheduler = PolyWarmUpScheduler(optimizer,
                                       warmup=args.warmup_proportion,
                                       total_steps=args.max_steps)
    if args.fp16:
        # A loss_scale of 0 selects dynamic loss scaling.
        if args.loss_scale == 0:
            model, optimizer = amp.initialize(
                model,
                optimizer,
                opt_level="O2",
                loss_scale="dynamic",
                cast_model_outputs=torch.float16,
            )
        else:
            model, optimizer = amp.initialize(
                model,
                optimizer,
                opt_level="O2",
                loss_scale=args.loss_scale,
                cast_model_outputs=torch.float16,
            )
        # Writes AMP's private loss-scaler state directly.
        amp._amp_state.loss_scalers[0]._loss_scale = args.init_loss_scale

    if args.resume_from_checkpoint:
        if args.phase2 or args.init_checkpoint:
            keys = list(checkpoint["optimizer"]["state"].keys())
            # Override hyperparameters from previous checkpoint
            for key in keys:
                checkpoint["optimizer"]["state"][key]["step"] = global_step
            # NOTE(review): the loop variable `iter` shadows the builtin.
            for iter, item in enumerate(
                    checkpoint["optimizer"]["param_groups"]):
                checkpoint["optimizer"]["param_groups"][iter][
                    "step"] = global_step
                checkpoint["optimizer"]["param_groups"][iter][
                    "t_total"] = args.max_steps
                checkpoint["optimizer"]["param_groups"][iter][
                    "warmup"] = args.warmup_proportion
                checkpoint["optimizer"]["param_groups"][iter][
                    "lr"] = args.learning_rate
        optimizer.load_state_dict(checkpoint["optimizer"])  # , strict=False)

        # Restore AMP master parameters
        if args.fp16:
            # The optimizer state is loaded a second time after the master
            # weights have been initialised.
            optimizer._lazy_init_maybe_master_weights()
            optimizer._amp_stash.lazy_init_called = True
            optimizer.load_state_dict(checkpoint["optimizer"])
            for param, saved_param in zip(amp.master_params(optimizer),
                                          checkpoint["master params"]):
                param.data.copy_(saved_param.data)

    # if args.local_rank != -1:
    #     if not args.allreduce_post_accumulation:
    #         model = DDP(model, message_size=250000000, gradient_predivide_factor=get_world_size())
    #     else:
    #         flat_dist_call([param.data for param in model.parameters()], torch.distributed.broadcast, (0,) )
    # elif args.n_gpu > 1:
    #     model = torch.nn.DataParallel(model)

    criterion = BertPretrainingCriterion(config.vocab_size)

    return model, optimizer, lr_scheduler, checkpoint, global_step, criterion
def init_master_params(self):
    """Create fp32 master copies of the fp16/bf16 parameters held by the optimizer.

    For every parameter that requires grad:

    * fp16/bf16 parameters get an fp32 copy that is a view into one contiguous
      buffer (``self.fp32_param_buffer``); the copy replaces the original in
      the optimizer's param group and inherits any existing optimizer state.
    * fp32 parameters are kept as they are.

    The per-group bookkeeping lists and the id/distribution maps on ``self``
    are filled in, and ``self.master_params_created`` is set at the end.

    Raises:
        TypeError: when a parameter that requires grad is not a CUDA
            float, half or bfloat16 tensor.
    """
    low_precision_types = [
        "torch.cuda.HalfTensor",
        "torch.cuda.BFloat16Tensor",
    ]

    if self.use_smp:
        torch.cuda.set_device(smp.local_rank())
        register_optimizer_hooks(self.model)

    self.fp32paramid_from_fp16paramid = {}

    # Tensor used to determine if a nan/if has happend.
    # Any non-zero value indicates inf/nan.
    # Note that we keep this for the cases that grad scaler is none.
    # We still record nan/inf if we have a bfloat16 with a grad scaler.
    if self.grad_scaler:
        self.found_inf = torch.cuda.FloatTensor([0.0])

    # Dummy tensor needed for apex multi-apply tensor.
    # For bfloat, we don't have multi-tensor apply and for now
    # we set it to none so the multi-tensor apply gets ignored.
    if self.bf16:
        self._dummy_overflow_buf = None
    else:
        self._dummy_overflow_buf = torch.cuda.IntTensor([0])

    # In case grad scaler is not passed, define the unity scale.
    if self.grad_scaler is None:
        self._scale_one = torch.cuda.FloatTensor([1.0])

    # ======================
    # main parameter stuff
    # ======================

    # only need to create contiguous buffer for fp16 params which require grads
    # (this size computation and the allocation below used to be repeated
    # verbatim, allocating a second buffer and discarding the first).
    contig_buffer_size = 0
    for param_group in self.optimizer.param_groups:
        for param in param_group["params"]:
            if param.requires_grad and param.type() in low_precision_types:
                contig_buffer_size += param.numel()

    self.fp32_param_buffer = torch.empty(
        contig_buffer_size,
        device=torch.device("cuda", smp.local_rank()),
        dtype=torch.float32,
        requires_grad=True,
    )
    # Running start index of the next master copy inside the buffer.
    offset = 0

    # For all the groups in the original optimizer:
    for param_group in self.optimizer.param_groups:
        float16_params_this_group = []
        fp32_params_this_group = []
        fp32_from_float16_params_this_group = []
        fp32_from_fp16_paramids_this_group = []
        # For all the parameters in this group:
        for i, param in enumerate(param_group["params"]):
            if param.requires_grad:
                # float16 params:
                if param.type() in low_precision_types:
                    float16_params_this_group.append(param)
                    # Create a copy
                    with torch.no_grad():
                        master_param_buffer = self.fp32_param_buffer.narrow(
                            0, offset, param.numel()).view_as(param)
                        master_param_buffer.copy_(param.float())
                        offset += param.numel()

                    main_param = nn.Parameter(
                        master_param_buffer,
                        requires_grad=param.requires_grad)
                    self.master_is_distributed[
                        main_param] = self.model.is_distributed_parameter(
                            param)
                    self.master_distribution_axis[id(
                        main_param)] = get_distribution_axis(param)
                    fp32_from_fp16_paramids_this_group.append(
                        id(main_param))
                    if hasattr(param, "shared"):
                        main_param.shared = param.shared

                    # Replace the optimizer params with the new fp32 copy.
                    param_group["params"][i] = main_param
                    fp32_from_float16_params_this_group.append(main_param)
                    # Reset existing state dict key to the new main param.
                    if param in self.optimizer.state:
                        self.optimizer.state[
                            main_param] = self.optimizer.state.pop(param)
                    self.fp32paramid_from_fp16paramid[id(param)] = id(
                        main_param)

                # fp32 params.
                elif param.type() == "torch.cuda.FloatTensor":
                    fp32_params_this_group.append(param)
                    param_group["params"][i] = param
                else:
                    raise TypeError("Wrapped parameters must be one of "
                                    "torch.cuda.FloatTensor,  "
                                    "torch.cuda.HalfTensor, or "
                                    "torch.cuda.BFloat16Tensor. "
                                    "Received {}".format(param.type()))

        self.float16_groups.append(float16_params_this_group)
        self.fp32_from_float16_groups.append(
            fp32_from_float16_params_this_group)
        self.fp32_from_fp16_paramid_groups.append(
            fp32_from_fp16_paramids_this_group)
        self.fp32_from_fp32_groups.append(fp32_params_this_group)

    # Leverage state_dict() and load_state_dict() to
    # recast preexisting per-param state tensors
    self.optimizer.load_state_dict(self.optimizer.state_dict())
    self.master_params_created = True
def clip_grad_norm_fp32(parameters,
                        param_is_distributed,
                        max_norm,
                        norm_type=2):
    """Clip, in place, the gradient norm of parameters whose grads are fp32.

    Adapted from torch.nn.utils.clip_grad.clip_grad_norm_ with extra handling
    for model parallel parameters: the norm is built from the gradients this
    rank is responsible for and then reduced over the model-parallel process
    group before any scaling is applied.

    Arguments:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor that will have gradients normalized
        param_is_distributed: container supporting ``in`` and indexing by
            parameter; a truthy entry means the parameter's gradient is
            counted on every tensor-parallel rank, not only on tp rank 0
        max_norm (float or int): max norm of the gradients
        norm_type (float or int): type of the used p-norm. Can be ``'inf'``
            for infinity norm.

    Returns:
        Total norm of the parameters (viewed as a single vector), as a
        Python number.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]

    torch.cuda.set_device(smp.local_rank())

    # all_grads: every gradient that exists (these are the ones scaled).
    # norm_grads: the subset that contributes to the norm, i.e. gradients of
    # parameters that are not shared and are not tensor-parallel replicas.
    all_grads = []
    norm_grads = []
    for p in parameters:
        if p.grad is None:
            continue
        g = p.grad.detach()
        # Make sure the grads are in fp32
        assert p.grad.type() == 'torch.cuda.FloatTensor'
        all_grads.append(g)

        shared = hasattr(p, "shared") and p.shared
        counted_on_this_rank = smp.tp_rank() == 0 or (
            p in param_is_distributed and param_is_distributed[p])
        if not shared and counted_on_this_rank:
            norm_grads.append(g)

    max_norm = float(max_norm)
    norm_type = float(norm_type)
    total_norm = torch.tensor(0.0, device=torch.device("cuda"))

    if norm_type == inf:
        if norm_grads:
            total_norm = max(g.abs().max() for g in norm_grads)
        reduced = torch.cuda.FloatTensor([float(total_norm)])
        # Take max across all model-parallel GPUs.
        torch.distributed.all_reduce(reduced,
                                     op=torch.distributed.ReduceOp.MAX,
                                     group=smp.get_mp_process_group())
        total_norm = reduced[0].item()
    else:
        if norm_type != 2.0:
            for g in norm_grads:
                total_norm += torch.norm(g, norm_type)**norm_type
        else:
            overflow_buf = torch.cuda.IntTensor(
                [0], device=torch.device("cuda", smp.local_rank()))
            # apex's multi-tensor applier runs the l2-norm over the whole
            # list of gradients in a single kernel.
            if norm_grads:
                l2_norm, _ = multi_tensor_applier(
                    amp_C.multi_tensor_l2norm,
                    overflow_buf,
                    [norm_grads],
                    False  # no per-parameter norm
                )
                # The cross-rank reduction is a sum, so it has to operate on
                # norm**norm_type rather than on the norm itself.
                total_norm = l2_norm**norm_type

        # Sum across all model-parallel GPUs.
        torch.distributed.all_reduce(total_norm,
                                     op=torch.distributed.ReduceOp.SUM,
                                     group=smp.get_mp_process_group())
        total_norm = total_norm.item()**(1.0 / norm_type)

    # Scale every existing gradient when the norm exceeds max_norm.
    if all_grads:
        clip_coeff = max_norm / (total_norm + 1.0e-6)
        if clip_coeff < 1.0:
            overflow_buf = torch.cuda.IntTensor(
                [0], device=torch.device("cuda", smp.local_rank()))
            multi_tensor_applier(amp_C.multi_tensor_scale, overflow_buf,
                                 [all_grads, all_grads], clip_coeff)
    return total_norm
def init_master_params(self):
    """Build fp32 master copies of the trainable fp16 optimizer parameters.

    All master copies are views into one contiguous fp32 buffer
    (``self.fp32_param_buffer``). Each trainable ``torch.cuda.HalfTensor``
    parameter in the wrapped optimizer's param groups is replaced by its
    master copy, existing per-parameter optimizer state is re-keyed to the
    master, and the per-group bookkeeping lists are appended to. Trainable
    ``torch.cuda.FloatTensor`` parameters are recorded as-is.

    Raises:
        TypeError: if a trainable parameter is neither a
            torch.cuda.FloatTensor nor a torch.cuda.HalfTensor.
    """
    half_type = "torch.cuda.HalfTensor"
    float_type = "torch.cuda.FloatTensor"

    if self.use_smp:
        torch.cuda.set_device(smp.local_rank())
        register_optimizer_hooks(self.model)

    self.fp32paramid_from_fp16paramid = {}

    # only fp16 params which require grads need room in the contiguous buffer
    contig_buffer_size = sum(
        p.numel()
        for group in self.optimizer.param_groups
        for p in group["params"]
        if p.requires_grad and p.type() == half_type)

    self.fp32_param_buffer = torch.empty(
        contig_buffer_size,
        device=torch.device("cuda", smp.local_rank()),
        dtype=torch.float32,
        requires_grad=True,
    )

    offset = 0
    for group_idx, group in enumerate(self.optimizer.param_groups):
        self.maybe_print(
            "FP16_Optimizer processing param group {}:".format(group_idx))

        half_params = []
        float_params = []
        masters = []
        master_ids = []

        for slot, param in enumerate(group["params"]):
            if not param.requires_grad:
                continue

            param_type = param.type()
            if param_type == half_type:
                self.maybe_print(
                    "FP16_Optimizer received torch.cuda.HalfTensor with {}"
                    .format(param.size()))
                half_params.append(param)

                # Carve this parameter's slice out of the shared buffer and
                # fill it with the fp32 value of the fp16 parameter.
                count = param.numel()
                with torch.no_grad():
                    master_param_buffer = self.fp32_param_buffer.narrow(
                        0, offset, count).view_as(param)
                    master_param_buffer.copy_(param.float())
                    offset += count

                master = nn.Parameter(master_param_buffer,
                                      requires_grad=param.requires_grad)
                self.master_is_distributed[master] = (
                    self.model.is_distributed_parameter(param))
                self.master_distribution_axis[id(master)] = (
                    get_distribution_axis(param))

                # The optimizer now updates the master copy instead.
                group["params"][slot] = master
                masters.append(master)
                master_ids.append(id(master))

                # Reset existing state dict key to the new master param.
                # We still need to recast per-param state tensors, if any, to FP32.
                if param in self.optimizer.state:
                    self.optimizer.state[master] = self.optimizer.state.pop(
                        param)
                self.fp32paramid_from_fp16paramid[id(param)] = id(master)
            elif param_type == float_type:
                self.maybe_print(
                    "FP16_Optimizer received torch.cuda.FloatTensor with {}"
                    .format(param.size()))
                float_params.append(param)
                group["params"][slot] = param
            else:
                raise TypeError(
                    "Wrapped parameters must be either "
                    "torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
                    "Received {}".format(param.type()))

        self.fp16_groups.append(half_params)
        self.fp32_from_fp16_groups.append(masters)
        self.fp32_from_fp16_paramid_groups.append(master_ids)
        self.fp32_from_fp32_groups.append(float_params)

    # Leverage state_dict() and load_state_dict() to recast preexisting per-param state tensors
    self.optimizer.load_state_dict(self.optimizer.state_dict())
    # alternative way to cast per-param state tensors:
    # self.optimizer.load_state_dict(init_state_dict)

    self.overflow = False
    self.first_closure_call_this_step = True
    self.master_params_created = True
def init_train():
    """Train the PyTorch model.

    Reads data locations and hyperparameters from the module-level ``args``,
    initializes smdistributed model parallelism, shards the training set
    across data-parallel ranks, trains a ``TabularNet`` through ``train`` and
    writes the resulting state dict to ``args.model_dir + "/model.pth"``.

    The DataLoader batch size and the optimizer learning rate come from
    ``args.batch_size`` and ``args.learning_rate`` (the values that are
    logged), not from hard-coded constants.
    """
    cat_mask = [
        False, True, True, True, True, False, True, True, True, True, True,
        False, False, False, False, False, False, False
    ]
    train_ds = CsvDatasetSimple(args.train)
    # NOTE(review): test_ds is not used below; it is kept because constructing
    # it may be relied on to validate the test channel - confirm before
    # removing.
    test_ds = CsvDatasetSimple(args.test)

    batch_size = args.batch_size
    epochs = args.epochs
    learning_rate = args.learning_rate
    logger.info("batch_size = {}, epochs = {}, learning rate = {}".format(
        batch_size, epochs, learning_rate))

    # smdistributed: initialize the backend
    smp.init()

    # smdistributed: Set the device to the GPU ID used by the current process.
    # Input tensors should be transferred to this device.
    torch.cuda.set_device(smp.local_rank())
    device = torch.device("cuda")

    dataset = train_ds

    # smdistributed: Shard the dataset based on data-parallel ranks
    if smp.dp_size() > 1:
        partitions_dict = {
            f"{i}": 1 / smp.dp_size()
            for i in range(smp.dp_size())
        }
        dataset = SplitDataset(dataset, partitions=partitions_dict)
        dataset.select(f"{smp.dp_rank()}")

    # smdistributed: Set drop_last=True to ensure that batch size is always divisible
    # by the number of microbatches
    train_loader = torch.utils.data.DataLoader(dataset,
                                               batch_size=batch_size,
                                               drop_last=True)

    model = TabularNet(n_cont=9,
                       n_cat=9,
                       cat_mask=cat_mask,
                       cat_dim=[
                           0, 2050, 13, 5, 366, 0, 50000, 50000, 50000, 50000,
                           50, 0, 0, 0, 0, 0, 0, 0
                       ],
                       y_min=0.,
                       y_max=1.)
    logger.debug(model)

    optimizer = optim.Adadelta(model.parameters(), lr=learning_rate)

    # smdistributed: Use the DistributedModel container to provide the model
    # to be partitioned across different ranks. For the rest of the script,
    # the returned DistributedModel object should be used in place of
    # the model provided for DistributedModel class instantiation.
    model = smp.DistributedModel(model)
    optimizer = smp.DistributedOptimizer(optimizer)

    train(model, device, train_loader, optimizer)

    torch.save(model.state_dict(), args.model_dir + "/model.pth")
'from checkpoints.') args.is_distributed = len(args.hosts) > 1 and args.backend is not None args.is_multigpus = args.num_gpus > 1 args.multigpus_distributed = (args.is_distributed or args.is_multigpus) logger.debug(f"args.image_folder : {args.image_folder}") args.world_size = 1 args.local_rank = 0 args.rank = 0 if args.model_parallel: args.world_size = smp.size() args.local_rank = smp.local_rank() # rank per host args.rank = smp.rank() args.dp_size = smp.dp_size() args.dp_rank = smp.dp_rank() logger.debug(f"args.world_size : {args.world_size}, args.local_rank : {args.local_rank}, args.rank : {args.rank}, \ args.dp_size : {args.dp_size}, args.dp_rank : {args.dp_rank}") else: # initialize deepspeed print(f"args.deepspeed : {args.deepspeed}") deepspeed_utils.init_deepspeed(args.deepspeed) # args.LEARNING_RATE = args.LEARNING_RATE * float(args.world_size) ## SageMaker try: if os.environ.get('SM_CHANNEL_TRAINING') is not None:
def main():
    """Set up model, data, smp and optimizer, then run ``SMPTrainer.train_smp``.

    Parses the four argument groups, builds the model/tokenizer and datasets,
    wraps the model in ``smp.DistributedModel`` and the optimizer in
    ``smp.DistributedOptimizer``, and starts training from step 0.

    The AdamW branch builds the optimizer with ``optim.AdamW`` (the module the
    Adam branch already uses) rather than looking ``AdamW`` up on the
    arguments object.
    """
    model_args, data_args, training_args, smp_args = parse_args()
    model, tokenizer = initialize_model_and_tokenizer(model_args)

    # Get datasets
    train_dataset, eval_dataset = Preprocess.datasets(model_args, data_args,
                                                      training_args)

    if is_sagemaker_mp_enabled():
        initialize_smp(smp_args, training_args)

    torch.set_default_dtype(torch.float32)

    num_params = print_num_parameters(model)

    # smdistributed: Set the device to the GPU ID used by the current process.
    # Input tensors should be transferred to this device.
    torch.cuda.set_device(smp.local_rank())
    device = torch.device("cuda")

    if not training_args.same_seed:
        # Set seed by tp_rank to prevent weights from being the same on different tp_ranks
        set_seed(training_args.seed + smp.tp_rank())

    model = smp.DistributedModel(model,
                                 trace_device=smp_args.trace_device,
                                 gradient_as_bucket_view=True)

    torch.set_default_dtype(torch.float32)

    # Unwrap the wrapper modules to reach the model whose parameters are
    # grouped (weight decay and non-decay).
    iter_model = model
    while isinstance(iter_model, (DistributedDataParallel, FP16_Module)):
        iter_model = iter_model.module

    param_groups = get_param_groups_by_weight_decay(iter_model)

    if training_args.use_adamw > 0:
        optimizer = optim.AdamW(
            param_groups,
            betas=(training_args.beta1, training_args.beta2),
            lr=training_args.lr,
            weight_decay=training_args.weight_decay,
        )
    else:
        optimizer = optim.Adam(
            param_groups,
            betas=(training_args.beta1, training_args.beta2),
            lr=training_args.lr,
            weight_decay=training_args.weight_decay,
        )

    optimizer = smp.DistributedOptimizer(optimizer)
    lr_scheduler = get_learning_rate_scheduler(optimizer, training_args)

    # Fresh run: no steps taken, start at the first data file and batch.
    total_steps = 0
    start_train_path_index = 0
    start_batch_index = 0

    # Initialize Trainer instance
    trainer = SMPTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=eval_dataset if training_args.do_eval else None,
        tokenizer=tokenizer,
        data_collator=default_data_collator,
    )

    trainer.train_smp(
        model,
        optimizer,
        lr_scheduler,
        start_train_path_index,
        start_batch_index,
        num_params,
        total_steps,
        training_args,
        prescaled_batch=smp_args.prescaled_batch,
    )
def main():
    """Train GroupedNet on MNIST with smp and sync partial checkpoints to S3.

    Runs one epoch of training and evaluation, saves a partial checkpoint
    from the ranks with ``smp.dp_rank() == 0`` into
    ``/opt/ml/local_checkpoints`` and has local rank 0 of each host sync that
    directory to S3.

    The local checkpoint directory is created by local rank 0 of every host
    (not only by global rank 0), because saving and syncing both use the
    host-local path.

    Raises:
        ValueError: if CUDA is not available.
    """
    if not torch.cuda.is_available():
        raise ValueError(
            "The script requires CUDA support, but CUDA not available")
    use_ddp = True
    use_horovod = False

    # Fix seeds in order to get the same losses across runs
    random.seed(1)
    np.random.seed(1)
    torch.manual_seed(1)
    torch.cuda.manual_seed(1)

    smp.init()

    # SM Distributed: Set the device to the GPU ID used by the current process.
    # Input tensors should be transferred to this device.
    torch.cuda.set_device(smp.local_rank())
    device = torch.device("cuda")
    kwargs = {"batch_size": 64}
    kwargs.update({"num_workers": 1, "pin_memory": True, "shuffle": False})

    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.1307, ), (0.3081, ))])

    # SM Distributed: Download only on a single process per instance.
    # When this is not present, the file is corrupted by multiple processes trying
    # to download and extract at the same time
    if smp.local_rank() == 0:
        dataset1 = datasets.MNIST("../data",
                                  train=True,
                                  download=True,
                                  transform=transform)
    smp.barrier()
    dataset1 = datasets.MNIST("../data",
                              train=True,
                              download=False,
                              transform=transform)

    if (use_ddp or use_horovod) and smp.dp_size() > 1:
        partitions_dict = {
            f"{i}": 1 / smp.dp_size()
            for i in range(smp.dp_size())
        }
        dataset1 = SplitDataset(dataset1, partitions=partitions_dict)
        dataset1.select(f"{smp.dp_rank()}")

    # Download and create dataloaders for train and test dataset
    dataset2 = datasets.MNIST("../data", train=False, transform=transform)

    train_loader = torch.utils.data.DataLoader(dataset1, **kwargs)
    test_loader = torch.utils.data.DataLoader(dataset2, **kwargs)

    model = GroupedNet()

    # SMP handles the transfer of parameters to the right device
    # and the user doesn't need to call 'model.to' explicitly.
    # model.to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=4.0)

    # SM Distributed: Use the DistributedModel container to provide the model
    # to be partitioned across different ranks. For the rest of the script,
    # the returned DistributedModel object should be used in place of
    # the model provided for DistributedModel class instantiation.
    model = smp.DistributedModel(model)
    scaler = smp.amp.GradScaler()
    optimizer = smp.DistributedOptimizer(optimizer)

    scheduler = StepLR(optimizer, step_size=1, gamma=0.7)

    for epoch in range(1, 2):
        train(model, scaler, device, train_loader, optimizer, epoch)
        test_loss = test(model, device, test_loader)
        scheduler.step()

    checkpoint_dir = "/opt/ml/local_checkpoints"

    # One process per host prepares the host-local checkpoint directory.
    if smp.local_rank() == 0:
        if os.path.exists(checkpoint_dir):
            print("-INFO- PATH DO EXIST")
        else:
            os.makedirs(checkpoint_dir, exist_ok=True)
            print("-INFO- PATH DO NOT EXIST")

    # Every rank waits until the directory is in place before saving.
    smp.barrier()

    if smp.dp_rank() == 0:
        model_dict = model.local_state_dict()
        opt_dict = optimizer.local_state_dict()
        smp.save(
            {
                "model_state_dict": model_dict,
                "optimizer_state_dict": opt_dict
            },
            "/opt/ml/local_checkpoints/pt_mnist_checkpoint.pt",
            partial=True,
        )
    # Wait for the checkpoint save to finish before syncing.
    smp.barrier()

    if smp.local_rank() == 0:
        print("Start syncing")
        base_s3_path = os.path.dirname(
            os.path.dirname(os.getenv("SM_MODULE_DIR", "")))
        curr_host = os.getenv("SM_CURRENT_HOST")
        full_s3_path = f"{base_s3_path}/checkpoints/{curr_host}/"
        sync_local_checkpoints_to_s3(local_path=checkpoint_dir,
                                     s3_path=full_s3_path)
        print("Finished syncing")
def main():
    """Configure smp, build a GPT-2 model and optimizer, then train it.

    Steps, in order: validate and normalize arguments, build the smp config
    and call ``smp.init``, create the output directories, create the model
    from a ``GPT2Config``, wrap it in ``smp.DistributedModel``, optionally
    assign transformer layers to pipeline partitions, build and wrap the
    optimizer, optionally load a checkpoint, call ``train``, and optionally
    save the final full model.

    Raises:
        ValueError: if shard_optimizer_state is enabled without
            skip_full_optimizer.
        AssertionError: if partition_assignment does not match the pipeline
            parallel degree or the number of layers, or if a CI threshold
            (time to train, throughput, loss) is not met.
    """
    args = parse_args()

    if args.shard_optimizer_state > 0 and not args.skip_full_optimizer:
        raise ValueError(
            "If shard_optimizer_state is enabled, skip_full_optimizer must also be enabled. Full optimizer saving is currently not supported under optimizer state sharding."
        )

    # An explicit partition assignment implies manual partitioning.
    if args.partition_assignment != "" and args.manual_partition == 0:
        print("[Warning] partition_assignment is set, enable manual_partition")
        args.manual_partition = 1

    # any value here is overridden by the config set in notebook when launching the sagemaker job
    smp_config = {
        "ddp": True,
        "tensor_parallel_degree": args.tensor_parallel_degree,
        "pipeline_parallel_degree": args.pipeline_parallel_degree,
        "microbatches": args.microbatches,
        # if activation_checkpointing true checkpoints transformer layers below
        "checkpoint_attentions": False if args.activation_checkpointing else True,
        "shard_optimizer_state": args.shard_optimizer_state > 0,
        "prescaled_batch": args.prescaled_batch > 0,
        "offload_activations": args.offload_activations > 0,
        "optimize": args.optimize,
        "auto_partition": False if args.manual_partition else True,
        "default_partition": 0,
        "static_mode": args.static_mode > 0,
        "fast_mode": args.fast_mode > 0,
    }

    # The fp16 key name and several extra keys depend on args.smp_version.
    if args.smp_version < 110:
        smp_config["fp16_params"] = args.fp16 > 0
    else:
        smp_config["fp16"] = args.fp16 > 0
        smp_config["delayed_parameter_initialization"] = args.delayed_param > 0
        smp_config["placement_strategy"] = args.placement_strategy
        smp_config["activation_loading_horizon"] = args.activation_loading_horizon
        smp_config["skip_tracing"] = args.skip_tracing > 0

    if args.active_microbatches is not None:
        smp_config["active_microbatches"] = args.active_microbatches

    smp.init(smp_config)

    if smp.rank() == 0:
        print("Arguments:", args.__dict__)
        print(f"Transformers version: {transformers.__version__}")
        print(
            f"smdistributed.modelparallel version: {smdistributed.modelparallel.__version__}"
        )
        print(f"smdistributed config: {smp_config}")

    if args.save_final_full_model and smp.rank() == 0:
        print(
            f"[Warning] Note that save_final_full_model only saves the final model at the end of all steps. It does not save optimizer state. Optimizer state is only saved with partial models which are saved at checkpointing_freq during training. If you want to restart training you need partial checkpoints."
        )

    # partition_assignment is a comma-separated list with one layer count per
    # pipeline-parallel rank.
    if args.partition_assignment != "":
        partition_assignment = args.partition_assignment.split(",")
        assert (
            len(partition_assignment) == smp.pp_size()
        ), f"partition_assignment must have the same size as pipeline parallel degree, but getting {len(partition_assignment)} vs {smp.pp_size()}"

    # Global rank 0 always creates the output directories; when use_fsx is 0,
    # local rank 0 of each host does so as well.
    if smp.rank() == 0 or (smp.local_rank() == 0 and args.use_fsx == 0):
        for path in [args.model_dir, args.checkpoint_dir]:
            if not os.path.exists(path):
                os.makedirs(path, exist_ok=True)

    model_config = GPT2Config(
        vocab_size=args.vocab_size,
        n_positions=args.max_context_width,
        n_embd=args.hidden_width,
        n_layer=args.num_layers,
        n_head=args.num_heads,
        n_inner=None,
        activation_function="gelu_new",
        resid_pdrop=args.resid_pdrop,
        embd_pdrop=args.embd_pdrop,
        attn_pdrop=args.attn_pdrop,
        layer_norm_epsilon=1e-05,
        initializer_range=0.02,
        summary_type="cls_index",
        summary_use_proj=True,
        summary_activation=None,
        summary_proj_to_labels=True,
        summary_first_dropout=args.summary_first_pdrop,
        # gradient_checkpointing=args.gradient_checkpointing > 0,
        use_cache=False,
        bos_token_id=50256,
        eos_token_id=50256,
        return_dict=True,
    )

    # the following improves start-up time by skipping proper initialization
    # of weights in the original model. this is not a problem because DistributedModel
    # will override those weights anyway when tensor_parallel_degree > 1.
    if smp.tp_size() > 1:
        from transformers.modeling_utils import PreTrainedModel

        PreTrainedModel.init_weights = lambda x: None

    set_seed(args.seed)

    if args.enable_memory_profiling > 0:
        memory_status_cpu(msg="before model creation")

    # Model creation uses a different smp API depending on args.smp_version.
    if args.smp_version < 110:
        if args.fp16:
            torch.set_default_dtype(torch.float16)
        with smp.tensor_parallelism(
                enabled=smp.tp_size() > 1,
                attention_in_fp32=args.attention_in_fp32 > 0):
            with smp.delay_param_initialization(
                    enabled=(smp.tp_size() > 1 and args.delayed_param > 0)):
                model = AutoModelForCausalLM.from_config(model_config)
    else:
        with smp.model_creation(
                tensor_parallelism=smp.tp_size() > 1,
                attention_in_fp32=args.attention_in_fp32 > 0,
                query_key_layer_scaling=args.query_key_layer_scaling > 0,
                fused_softmax=args.fused_softmax > 0,
                fused_bias_gelu=args.fused_bias_gelu > 0,
                dtype=torch.float16 if args.fp16 else torch.get_default_dtype(),
        ):
            model = AutoModelForCausalLM.from_config(model_config)

    if args.smp_version < 110 and args.fp16:
        model = FP16_Module(model)

    if args.enable_memory_profiling > 0:
        memory_status_cpu(msg="after model creation")

    num_params = sum([np.prod(p.size()) for p in model.parameters()])
    if smp.rank() == 0:
        print(f"# total parameters: {num_params}")

    # smdistributed: Set the device to the GPU ID used by the current process.
    # Input tensors should be transferred to this device.
    torch.cuda.set_device(smp.local_rank())
    device = torch.device("cuda")

    if not args.same_seed:
        # Set seed by tp_rank to prevent weights from being the same on different tp_ranks
        set_seed(args.seed + smp.tp_rank())

    # smdistributed: Use the DistributedModel container to provide the model
    # to be partitioned across different ranks. For the rest of the script,
    # the returned DistributedModel object should be used in place of
    # the model provided for DistributedModel class instantiation.
    if args.smp_version < 110 and args.fp16:
        torch.set_default_dtype(torch.float16)
    if args.enable_memory_profiling > 0:
        memory_status_cpu(msg="before dist model creation")
    model = smp.DistributedModel(model, trace_device="gpu")
    if args.enable_memory_profiling > 0:
        memory_status_cpu(msg="after dist model creation")

    # Locate the container of transformer layers; the attribute path depends
    # on args.smp_version and on whether tensor parallelism is active.
    if args.smp_version < 110:
        if smp.tp_size() > 1:
            transformer_layers = model.module.module.module.transformer.seq_layers
        else:
            transformer_layers = model.module.module.module.transformer.h
    else:
        m = model.get_module()
        if smp.tp_size() > 1:
            transformer_layers = m.transformer.seq_layers
        else:
            transformer_layers = m.transformer.h

    if args.manual_partition:
        print(f"Manual partition enabled")
        if args.partition_assignment != "":
            get_num_layers = lambda x: int(partition_assignment[x])
            total_layers = sum(
                [get_num_layers(pp_rank) for pp_rank in range(smp.pp_size())])
            assert (
                total_layers == args.num_layers
            ), f"partition_assignment must have the same total transformer layers as model, but getting {total_layers} vs {args.num_layers}"
        else:
            # evenly distribute layers across all partitions
            # (the last `rem` partitions get one extra layer)
            div, rem = divmod(args.num_layers, smp.pp_size())
            get_num_layers = lambda x: (div + 1
                                        if x >= smp.pp_size() - rem else div)
        # assignments[i] is the pipeline partition of the i-th layer.
        assignments = []
        # (TODO) This is required for 175B otherwise a hang for partition "8,17,17,18,18,18"
        # Need further investigation
        # for pp_rank in reversed(range(smp.pp_size())):
        for pp_rank in range(smp.pp_size()):
            nl = get_num_layers(pp_rank)
            print(f"{nl} layers assigned to partition {pp_rank}")
            assignments += [pp_rank for _ in range(nl)]

        for i, c in enumerate(transformer_layers.children()):
            smp.set_partition(c, assignments[i])

    if args.smp_version < 110:
        iter_model = model
        # Build parameter groups (weight decay and non-decay).
        while isinstance(iter_model, (DistributedDataParallel, FP16_Module)):
            iter_model = iter_model.module
    else:
        iter_model = m
    param_groups = get_param_groups_by_weight_decay(iter_model)

    if args.use_adamw > 0:
        optimizer = optim.AdamW(param_groups,
                                betas=(args.beta1, args.beta2),
                                lr=args.lr,
                                weight_decay=args.weight_decay)
    else:
        optimizer = optim.Adam(param_groups,
                               betas=(args.beta1, args.beta2),
                               lr=args.lr,
                               weight_decay=args.weight_decay)

    if args.activation_checkpointing:
        kwargs = {}
        # The extra options are only passed for an nn.Sequential container.
        if isinstance(transformer_layers, nn.Sequential):
            kwargs["pack_args_as_tuple"] = True
            kwargs["strategy"] = args.activation_strategy
        smp.set_activation_checkpointing(transformer_layers, **kwargs)

    # Below smp_version 110 the script's own FP16_Optimizer does the loss
    # scaling; otherwise the loss-scale options go to DistributedOptimizer.
    if args.smp_version < 110:
        optimizer = FP16_Optimizer(
            model,
            optimizer,
            static_loss_scale=None,
            dynamic_loss_scale=True,
            use_smp=True,
            dynamic_loss_args={
                "scale_window": 1000,
                "min_scale": 1,
                "delayed_shift": 2
            },
            params_have_main_grad=False,
            shard_optimizer_state=args.shard_optimizer_state > 0,
        )

        optimizer = smp.DistributedOptimizer(optimizer)
        # Master parameters are created through a post-step hook on the model.
        model.register_post_step_hook(
            lambda model, optimizer: optimizer.init_master_params())
    else:
        optimizer = smp.DistributedOptimizer(
            optimizer,
            static_loss_scale=None,
            dynamic_loss_scale=True,
            dynamic_loss_args={
                "scale_window": 1000,
                "min_scale": 1,
                "delayed_shift": 2
            },
        )
    lr_scheduler = get_learning_rate_scheduler(optimizer, args)

    if args.enable_memory_profiling > 0:
        model.register_post_partition_hook(
            lambda model, optimizer: memory_status(msg="After_partition"))

    # load after wrapping model and optimizer with smp Distributed...
    if args.load_full or args.load_partial:
        if args.load_partial and args.load_full:
            print(
                "Since both --load_partial and --load_full set, will try to load from full checkpoint."
                "If the intention is to load from partial checkpoint, please don't set --load_full"
            )
        # A full checkpoint is read from model_dir and translated from the
        # HF format; a partial one is read from checkpoint_dir.
        partial = not args.load_full
        path = args.checkpoint_dir if partial else args.model_dir
        translate_from_hf = not partial
        model, optimizer, total_steps, start_train_path_index, start_batch_index = load_model_and_optimizer(
            path,
            model,
            optimizer,
            lr_scheduler,
            partial,
            args,
            translate_from_hf=translate_from_hf,
            seq_length=args.max_context_width,
            load_model=True,
            load_optimizer=args.load_partial > 0,
            num_params=num_params,
        )
    else:
        total_steps = 0
        start_train_path_index = 0
        start_batch_index = 0

    start = time.time()
    total_steps, throughput, loss = train(
        model,
        optimizer,
        lr_scheduler,
        model_config,
        start_train_path_index,
        start_batch_index,
        num_params,
        total_steps,
        args,
    )
    time_to_train = time.time() - start
    if args.ci:
        print(f"[SMP_METRIC]__GPT2__Time_to_train__{time_to_train}")
        print(f"[SMP_METRIC]__GPT2__samples/second__{throughput}")
        print(f"[SMP_METRIC]__GPT2__Loss__{loss}")
        # Thresholds are only enforced for runs that did not resume from a
        # checkpoint.
        if not args.load_partial and not args.load_full:
            assert time_to_train < args.time_to_train
            assert throughput > args.throughput
            if args.loss:
                assert loss < args.loss

    if args.save_final_full_model:
        # saves full model at the end

        base_path = f"trained_gpt_nparams-{num_params}_steps-{total_steps}.pt"
        out_path = os.path.join(args.model_dir, base_path)

        if smp.rdp_rank() == 0:
            save(
                out_path,
                model,
                optimizer,
                lr_scheduler,
                model_config,
                num_params,
                total_steps,
                -1,
                args,
                partial=False,
                translate_to_hf=smp.tp_size() > 1,
                seq_length=args.max_context_width,
            )

    smp.barrier()
    if smp.rank() == 0:
        print("SMP training finished successfully")
def train(
    model,
    optimizer,
    lr_scheduler,
    model_config,
    start_train_path_index,
    start_batch_index,
    num_params,
    total_steps,
    args,
):
    """Run the training loop over the training files until ``args.max_steps``.

    Iterates over the sorted training files (one dataloader per file), runs
    ``train_step`` per batch, steps the optimizer and LR scheduler, and
    periodically logs, validates and saves partial checkpoints.

    Arguments:
        model: model to train; switched between train and eval mode here.
        optimizer: optimizer used for zero_grad/step (and, under fp16,
            master-grad update and clipping or a megatron-style step).
        lr_scheduler: stepped once per non-overflow step.
        model_config: passed through to ``save``.
        start_train_path_index: index into the sorted training files at
            which to start.
        start_batch_index: number of leading batches to skip in the first
            file processed (reset to 0 once reached).
        num_params: used in the checkpoint file name and passed to ``save``.
        total_steps: number of steps already taken.
        args: parsed script arguments.

    Returns:
        Tuple ``(total_steps, throughput, loss_metric)``. ``throughput`` is
        the samples/sec of the last step (``None`` if no step ran);
        ``loss_metric`` is the last training loss when ``args.validation_freq``
        is falsy, otherwise the last validation loss (0 if never set).
    """
    model.train()
    # A single worker process prepares the next file's dataloader while the
    # current file is being trained on.
    if args.parallel_proc_data_processing:
        pool = ProcessPoolExecutor(1)

    # With prescaled_batch the rdp rank/size are used for data sharding.
    dp_rank = smp.dp_rank() if not args.prescaled_batch else smp.rdp_rank()
    dp_size = smp.dp_size() if not args.prescaled_batch else smp.rdp_size()
    data_type = "BERT" if args.use_bert_data else "GPT"

    # BERT data: regular files whose name contains "training".
    # GPT data: files ending in .json or .json.gz.
    if args.use_bert_data:
        train_paths = sorted([
            os.path.join(args.training_dir, p)
            for p in os.listdir(args.training_dir)
            if os.path.isfile(os.path.join(args.training_dir, p))
            and "training" in p
        ])
    else:
        if args.zipped_data > 0:
            file_extension = ".json.gz"
        else:
            file_extension = ".json"
        train_paths = sorted(
            [
                os.path.join(args.training_dir, p)
                for p in os.listdir(args.training_dir)
                if p.endswith(file_extension)
            ]
        )

    train_dataloader = create_pretraining_dataloader(
        [train_paths[start_train_path_index]],
        args.train_batch_size,
        args.max_context_width,
        seed=args.seed,
        dp_rank=dp_rank,
        dp_size=dp_size,
        shuffle=args.same_seed < 1,
        zipped=args.zipped_data > 0,
        use_last_file_only=args.fast_validation > 0,
        data_type=data_type,
    )

    if args.validation_freq is not None:
        # load all validation examples
        if smp.rank() == 0:
            print("Creating val dataloader")
        if args.use_bert_data:
            val_paths = sorted([
                os.path.join(args.test_dir, p)
                for p in os.listdir(args.test_dir)
                if os.path.isfile(os.path.join(args.test_dir, p))
                and "testing" in p
            ])
        else:
            if args.zipped_data > 0:
                file_extension = ".json.gz"
            else:
                file_extension = ".json"
            val_paths = sorted(
                [
                    os.path.join(args.test_dir, p)
                    for p in os.listdir(args.test_dir)
                    if p.endswith(file_extension)
                ]
            )
        val_dataloader = create_pretraining_dataloader(
            val_paths,
            args.val_batch_size,
            args.max_context_width,
            seed=args.seed,
            dp_rank=dp_rank,
            dp_size=dp_size,
            shuffle=True,
            zipped=args.zipped_data > 0,
            use_last_file_only=args.fast_validation > 0,
            data_type=data_type,
        )
        if smp.rank() == 0:
            print("Created val dataloader")

    start = time.time()
    throughput = None
    # Losses (and finally logits) collected for args.logits_output.
    to_save = {"loss": [], "val_loss": []}
    loss_metric = 0

    def should_record():
        # only record the ranks that are in the tp group that contains global rank 0
        if smp.tp_size() > 1:
            tp_group = smp.get_tp_group()
            return 0 in tp_group
        else:
            return smp.rank() == 0

    # Set the same seed for computation
    set_seed(args.seed)

    # One iteration per training file; indices wrap around the file list.
    for index in range(start_train_path_index, args.epochs*len(train_paths)):
        next_train_path_index = (index + 1) % len(train_paths)
        curr_train_path_index = index % len(train_paths)

        if total_steps >= args.max_steps:
            break

        # Start building the next file's dataloader in the worker process.
        if args.parallel_proc_data_processing:
            dataset_future = pool.submit(
                create_pretraining_dataloader,
                [train_paths[next_train_path_index]],
                args.train_batch_size,
                args.max_context_width,
                seed=args.seed,
                dp_rank=dp_rank,
                dp_size=dp_size,
                shuffle=args.same_seed < 1,
                zipped=args.zipped_data > 0,
                use_last_file_only=args.fast_validation > 0,
                data_type=data_type,
            )

        if smp.rank() == 0:
            if args.use_bert_data:
                print(f"Reading data from training path {train_dataloader.dataset.input_file}")
            else:
                print(f"Reading data from training path {train_dataloader.dataset.input_paths}")

        for batch_idx, input_data in enumerate(train_dataloader):
            # When resuming, skip batches before the saved batch index; the
            # index is cleared once reached so later files start at batch 0.
            if batch_idx < start_batch_index:
                if smp.rank() == 0:
                    print(f"Resuming from saved batch index {start_batch_index}, skipping batch {batch_idx}...")
                continue
            else:
                start_batch_index = 0

            if args.use_bert_data:
                input_ids, _, attention_mask, _, _ = input_data
            else:
                input_ids, attention_mask = input_data

            if total_steps >= args.max_steps:
                break

            step_start = time.time()

            optimizer.zero_grad(set_grads_to_None=True)

            if args.logits_output:
                train_output = train_step(model, optimizer, input_ids, attention_mask, args)
                loss_mb = train_output["loss"]
                logits_mb = train_output["logits"]
                # Per-microbatch logits are concatenated along dim 1 under
                # tensor parallelism and along dim 0 otherwise.
                if smp.tp_size() > 1:
                    logits = torch.cat(tuple(logits_mb.outputs), dim=1)
                else:
                    logits = torch.cat(tuple(logits_mb.outputs), dim=0)
            else:
                # Return value, loss_mb is a StepOutput object
                loss_mb = train_step(model, optimizer, input_ids, attention_mask, args)

            # smdistributed: Average the loss across microbatches.
            loss = loss_mb.reduce_mean()
            if not args.validation_freq:
                loss_metric = loss.item()

            # NOTE(review): optimizer.step() appears only under args.fp16
            # here - confirm how the non-fp16 path is expected to step.
            if args.fp16:
                if args.megatron:
                    success, _, _ = optimizer.step()
                    overflow = not success
                else:
                    optimizer.update_master_grads()
                    optimizer.clip_master_grads(args.grad_clip)
                    optimizer.step()
                    overflow = optimizer.overflow

            # The LR schedule does not advance on an fp16 overflow step.
            if not (args.fp16 and overflow):
                lr_scheduler.step()

            total_steps += 1
            time_elapsed = time.time() - start
            step_time = time.time() - step_start
            sample_processed = input_ids.shape[0] * dp_size
            throughput = sample_processed / step_time
            if smp.rank() == 0 and not total_steps % args.logging_freq:
                print(
                    f"({int(time_elapsed)}s), Batch {total_steps - 1} Loss: {loss.item()}, Speed: {throughput} samples/sec"
                )

            # evaluate on validation
            if args.validation_freq and not (total_steps % args.validation_freq):
                # Captured so the numpy RNG state can be restored afterwards
                # when args.preserve_np_state is set.
                cur_state = np.random.get_state()
                model = model.eval()
                val_loss, val_ppl = eval_model(
                    model, val_dataloader, args.validation_batches, args.use_bert_data
                )
                if is_main_process(smp.rank()):
                    print(
                        f"({int(time.time()-start)}s) Batch {total_steps - 1} Validation loss: {val_loss}"
                    )
                    print(
                        f"({int(time.time()-start)}s) Batch {total_steps - 1} Validation perplexity: {val_ppl}"
                    )
                loss_metric = val_loss
                if args.logits_output:
                    to_save["val_loss"].append(val_loss)
                model = model.train()
                if args.preserve_np_state > 0:
                    np.random.set_state(cur_state)

            # checkpoint
            if not (total_steps % args.checkpoint_freq):
                base_path = f"trained_gpt_nparams-{num_params}_steps-{total_steps}.pt"
                out_path = os.path.join(args.checkpoint_dir, base_path)
                total_ckpts = total_steps // args.checkpoint_freq

                # save_or_verify_ckptsum if this is the last checkpoint
                if (args.save_or_verify_ckptsum and total_steps >= args.max_steps) or (
                    (total_ckpts + 1) * args.checkpoint_freq
                ) > args.max_steps:
                    # Save optimizer and model tensor sums and scalars before saving
                    save_ckptsum(
                        args,
                        model,
                        optimizer,
                        filename=os.path.join(args.model_dir, "saved_partial_sum"),
                    )

                # batch_idx+1 is recorded so a resumed run starts at the
                # batch after the one just processed.
                save(
                    out_path,
                    model,
                    optimizer,
                    lr_scheduler,
                    model_config,
                    num_params,
                    total_steps,
                    curr_train_path_index,
                    args,
                    partial=True,
                    batch_idx=batch_idx+1,
                )

                if smp.local_rank() == 0:
                    delete_oldest_ckpt(args)

            if args.logits_output:
                to_save["loss"].append(loss.item())

        if total_steps >= args.max_steps:
            # Dump the collected losses and the last logits before stopping.
            if should_record() and args.logits_output:
                to_save["logits"] = logits.detach().cpu()
                output_file = f"rank_{smp.rank()}_" + args.logits_output
                torch.save(to_save, os.path.join(args.model_dir, output_file))
                print(f"logits and loss saved at {os.path.join(args.model_dir, output_file)}")
            break

        del train_dataloader

        # Switch to the next file's dataloader: take the prefetched one if
        # available, otherwise build it now.
        if args.parallel_proc_data_processing:
            s = time.time()
            train_dataloader = dataset_future.result(timeout=None)
            wait_time = time.time() - s
            if wait_time > 1:
                # TODO if this happens, we should try num_workers>1 in dataloader
                print(
                    f"[{smp.rank()}] Waited {wait_time} for data loader to be ready. Please check if dataloader performance can be improved to avoid these waits."
                )
        else:
            train_dataloader = create_pretraining_dataloader(
                [train_paths[next_train_path_index]],
                args.train_batch_size,
                args.max_context_width,
                seed=args.seed,
                dp_rank=dp_rank,
                dp_size=dp_size,
                shuffle=args.same_seed < 1,
                zipped=args.zipped_data > 0,
                use_last_file_only=args.fast_validation > 0,
                data_type=data_type,
            )

    return total_steps, throughput, loss_metric
def main():
    """Train and evaluate the MNIST model with SageMaker model parallelism.

    Steps, in order:
      1. Parse CLI arguments and require CUDA.
      2. Seed ``random``, NumPy and PyTorch from ``args.seed``.
      3. Initialise SMP and pin this process to its local GPU.
      4. Download MNIST once per instance, then build the train/test loaders
         (the train set is split across data-parallel ranks when DDP or
         Horovod is enabled and ``smp.dp_size() > 1``).
      5. Wrap the model and optimizer with SMP, optionally restoring a
         partial or full checkpoint.
      6. Run ``args.epochs`` epochs of ``train`` + ``test``.
      7. Optionally save a partial and/or full checkpoint.
      8. Optionally assert that the final test loss is below a threshold.

    Raises:
        ValueError: if CUDA is unavailable, or if ``args.assert_losses`` is
            set but no epoch ran (so there is no test loss to check).
        AssertionError: if ``args.assert_losses`` is set and the loss check
            fails.
    """
    parser = get_parser()
    args = parser.parse_args()

    if not torch.cuda.is_available():
        raise ValueError(
            "The script requires CUDA support, but CUDA not available")

    use_ddp = args.ddp > 0
    use_horovod = args.horovod > 0

    # Fix seeds in order to get the same losses across runs
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    cfg = {
        "microbatches": args.num_microbatches,
        "placement_strategy": "spread",
        "pipeline": args.pipeline,
        "optimize": "speed",
        "partitions": args.num_partitions,
        "horovod": use_horovod,
        "ddp": use_ddp,
    }
    smp.init(cfg)

    # SM Distributed: Set the device to the GPU ID used by the current process.
    # Input tensors should be transferred to this device.
    torch.cuda.set_device(smp.local_rank())
    device = torch.device("cuda")

    kwargs = {
        "batch_size": args.batch_size,
        "num_workers": 1,
        "pin_memory": True,
        "shuffle": False,
    }

    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.1307, ), (0.3081, ))])

    # SM Distributed: Download only on a single process per instance.
    # When this is not present, the file is corrupted by multiple processes
    # trying to download and extract at the same time.
    if smp.local_rank() == 0:
        datasets.MNIST("../data",
                       train=True,
                       download=True,
                       transform=transform)
    # Every process waits here until the download above has completed.
    smp.barrier()
    dataset1 = datasets.MNIST("../data",
                              train=True,
                              download=False,
                              transform=transform)

    if (use_ddp or use_horovod) and smp.dp_size() > 1:
        # Give every data-parallel rank an equally sized share of the data.
        partitions_dict = {
            f"{i}": 1 / smp.dp_size()
            for i in range(smp.dp_size())
        }
        dataset1 = SplitDataset(dataset1, partitions=partitions_dict)
        dataset1.select(f"{smp.dp_rank()}")

    # Create dataloaders for the train and test datasets
    dataset2 = datasets.MNIST("../data", train=False, transform=transform)
    train_loader = torch.utils.data.DataLoader(dataset1, **kwargs)
    test_loader = torch.utils.data.DataLoader(dataset2, **kwargs)

    model = GroupedNet()

    # SMP handles the transfer of parameters to the right device
    # and the user doesn't need to call 'model.to' explicitly.
    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)

    # SM Distributed: Use the DistributedModel container to provide the model
    # to be partitioned across different ranks. For the rest of the script,
    # the returned DistributedModel object should be used in place of
    # the model provided for DistributedModel class instantiation.
    model = smp.DistributedModel(model)
    scaler = smp.amp.GradScaler()
    optimizer = smp.DistributedOptimizer(optimizer)

    # A partial checkpoint takes precedence over a full one when both are set.
    resume_path = args.partial_checkpoint or args.full_checkpoint
    if resume_path:
        checkpoint = smp.load(resume_path,
                              partial=bool(args.partial_checkpoint))
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)

    # Stays None when no epoch runs, so the loss check below can report that
    # explicitly instead of failing on an unbound name.
    test_loss = None
    for epoch in range(1, args.epochs + 1):
        train(args, model, scaler, device, train_loader, optimizer, epoch)
        test_loss = test(args, model, device, test_loader)
        scheduler.step()

    checkpoint_path = "./pt_mnist_checkpoint.pt"

    if args.save_partial_model and smp.dp_rank() == 0:
        smp.save(
            {
                "model_state_dict": model.local_state_dict(),
                "optimizer_state_dict": optimizer.local_state_dict(),
            },
            checkpoint_path,
            partial=True,
        )

    if args.save_full_model and smp.dp_rank() == 0:
        smp.save(
            {
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
            },
            checkpoint_path,
            partial=False,
        )

    # Waiting the save checkpoint to be finished before run another allgather_object
    smp.barrier()

    if args.assert_losses:
        if test_loss is None:
            raise ValueError(
                "assert_losses requires at least one epoch (args.epochs >= 1)")
        if use_horovod or use_ddp:
            # SM Distributed: If using data parallelism, gather all losses across different model
            # replicas and check if losses match.
            gathered_losses = smp.allgather(test_loss, smp.DP_GROUP)
            for replica_loss in gathered_losses:
                assert math.isclose(replica_loss, gathered_losses[0])
            assert test_loss < 0.18
        else:
            assert test_loss < 0.08
def main():
    """Train a ``DiscreteVAE`` with SageMaker model parallelism or DeepSpeed.

    ``args.model_parallel`` selects the backend:

    * truthy  -> SMP: the model/optimizer are wrapped in
      ``smp.DistributedModel`` / ``smp.DistributedOptimizer`` and each batch
      goes through an ``@smp.step`` function;
    * falsy   -> the model is handed to ``deepspeed_utils.maybe_distribute``.

    Every 10th batch the function builds sample reconstructions on rank 0,
    writes ``{args.model_dir}/vae.pt``, anneals the temperature and steps
    the LR scheduler. At the end, the non-model-parallel path writes
    ``vae-final.pt``; the model-parallel path optionally asserts on the last
    loss.

    Raises:
        ValueError: if CUDA is not available.
        AssertionError: if the dataset is empty, or if ``args.assert_losses``
            is set and the final loss check fails.
    """
    parser = get_parser()
    args = parser.parse_args()

    if not torch.cuda.is_available():
        raise ValueError(
            "The script requires CUDA support, but CUDA not available")

    args.rank = -1
    args.world_size = 1

    if args.model_parallel:
        args.deepspeed = False
        cfg = {
            "microbatches": args.num_microbatches,
            "placement_strategy": args.placement_strategy,
            "pipeline": args.pipeline,
            "optimize": args.optimize,
            "partitions": args.num_partitions,
            "horovod": args.horovod,
            "ddp": args.ddp,
        }
        smp.init(cfg)
        torch.cuda.set_device(smp.local_rank())
        args.rank = smp.dp_rank()
        args.world_size = smp.size()
    else:
        # initialize deepspeed
        print(f"args.deepspeed : {args.deepspeed}")
        deepspeed_utils.init_deepspeed(args.deepspeed)
        if deepspeed_utils.is_root_worker():
            args.rank = 0

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed + args.rank)
        np.random.seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    args.kwargs = {'num_workers': args.num_worker, 'pin_memory': True}
    device = torch.device("cuda")

    logger.debug(f"args.image_folder : {args.image_folder}")
    logger.debug(f"args.rank : {args.rank}")

    # SageMaker: when SM_MODEL_DIR is set, take the model dir and the
    # training channel from the environment. If SM_CHANNEL_TRAINING is
    # missing, keep the CLI value instead of overwriting it with None.
    sm_model_dir = os.environ.get('SM_MODEL_DIR')
    if sm_model_dir is not None:
        args.model_dir = sm_model_dir
        args.image_folder = os.environ.get('SM_CHANNEL_TRAINING',
                                           args.image_folder)
    else:
        logger.debug("not SageMaker")

    IMAGE_SIZE = args.image_size
    IMAGE_PATH = args.image_folder
    EPOCHS = args.EPOCHS
    BATCH_SIZE = args.BATCH_SIZE
    LEARNING_RATE = args.LEARNING_RATE
    LR_DECAY_RATE = args.LR_DECAY_RATE
    NUM_TOKENS = args.NUM_TOKENS
    NUM_LAYERS = args.NUM_LAYERS
    NUM_RESNET_BLOCKS = args.NUM_RESNET_BLOCKS
    SMOOTH_L1_LOSS = args.SMOOTH_L1_LOSS
    EMB_DIM = args.EMB_DIM
    HID_DIM = args.HID_DIM
    KL_LOSS_WEIGHT = args.KL_LOSS_WEIGHT
    STARTING_TEMP = args.STARTING_TEMP
    TEMP_MIN = args.TEMP_MIN
    ANNEAL_RATE = args.ANNEAL_RATE
    NUM_IMAGES_SAVE = args.NUM_IMAGES_SAVE

    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB')
                 if img.mode != 'RGB' else img),
        T.Resize(IMAGE_SIZE),
        T.CenterCrop(IMAGE_SIZE),
        T.ToTensor()
    ])

    # data
    logger.debug(f"IMAGE_PATH : {IMAGE_PATH}")
    ds = ImageFolder(
        IMAGE_PATH,
        transform=transform,
    )

    if args.model_parallel and (args.ddp
                                or args.horovod) and smp.dp_size() > 1:
        # Give every data-parallel rank an equally sized share of the data.
        partitions_dict = {
            f"{i}": 1 / smp.dp_size()
            for i in range(smp.dp_size())
        }
        ds = SplitDataset(ds, partitions=partitions_dict)
        ds.select(f"{smp.dp_rank()}")

    dl = DataLoader(ds,
                    BATCH_SIZE,
                    shuffle=True,
                    drop_last=args.model_parallel,
                    **args.kwargs)

    vae_params = dict(image_size=IMAGE_SIZE,
                      num_layers=NUM_LAYERS,
                      num_tokens=NUM_TOKENS,
                      codebook_dim=EMB_DIM,
                      hidden_dim=HID_DIM,
                      num_resnet_blocks=NUM_RESNET_BLOCKS)

    vae = DiscreteVAE(**vae_params,
                      smooth_l1_loss=SMOOTH_L1_LOSS,
                      kl_div_loss_weight=KL_LOSS_WEIGHT).to(device)

    # optimizer
    opt = Adam(vae.parameters(), lr=LEARNING_RATE)
    sched = ExponentialLR(optimizer=opt, gamma=LR_DECAY_RATE)

    if args.model_parallel:
        import copy

        # Plain (non-SMP) copies of the codebook and decoder; they are
        # refreshed from the distributed weights before each preview below.
        dummy_codebook = copy.deepcopy(vae.codebook)
        dummy_decoder = copy.deepcopy(vae.decoder)

        vae = smp.DistributedModel(vae)
        scaler = smp.amp.GradScaler()
        opt = smp.DistributedOptimizer(opt)

        if args.partial_checkpoint:
            args.checkpoint = smp.load(args.partial_checkpoint, partial=True)
            vae.load_state_dict(args.checkpoint["model_state_dict"])
            opt.load_state_dict(args.checkpoint["optimizer_state_dict"])
        elif args.full_checkpoint:
            args.checkpoint = smp.load(args.full_checkpoint, partial=False)
            vae.load_state_dict(args.checkpoint["model_state_dict"])
            opt.load_state_dict(args.checkpoint["optimizer_state_dict"])

    assert len(ds) > 0, 'folder does not contain any images'

    if (not args.model_parallel) and args.rank == 0:
        print(f'{len(ds)} images found for training')

    def save_model(path):
        """Write hyper-parameters and weights to ``path`` (rank 0 only)."""
        if args.rank != 0:
            return
        save_obj = {'hparams': vae_params, 'weights': vae.state_dict()}
        torch.save(save_obj, path)

    # distribute with deepspeed
    if not args.model_parallel:
        deepspeed_utils.check_batch_size(BATCH_SIZE)
        deepspeed_config = {'train_batch_size': BATCH_SIZE}

        (distr_vae, opt, dl, sched) = deepspeed_utils.maybe_distribute(
            args=args,
            model=vae,
            optimizer=opt,
            model_parameters=vae.parameters(),
            training_data=ds if args.deepspeed else dl,
            lr_scheduler=sched,
            config_params=deepspeed_config,
        )

    # The step functions are only called on the model-parallel path, so they
    # are only defined there. This replaces a bare ``try/except: pass`` that
    # would also have hidden genuine errors raised while defining them.
    if args.model_parallel:
        # Rubik: Define smp.step. Return any tensors needed outside.
        @smp.step
        def train_step(vae, images, temp):
            with autocast(enabled=(args.amp > 0)):
                loss, recons = vae(images,
                                   return_loss=True,
                                   return_recons=True,
                                   temp=temp)

            scaled_loss = scaler.scale(loss) if args.amp else loss
            vae.backward(scaled_loss)
            return loss, recons

        @smp.step
        def get_codes_step(vae, images, k):
            images = images[:k]
            logits = vae.forward(images, return_logits=True)
            codebook_indices = logits.argmax(dim=1).flatten(1)
            return codebook_indices

        def hard_recons_step(dummy_decoder, dummy_codebook, codebook_indices):
            from functools import partial

            # Rebind each module's ``forward`` to the original method
            # obtained from the SMP patch manager before calling it.
            for module in dummy_codebook.modules():
                method = smp_state.patch_manager.get_original_method(
                    "forward", type(module))
                module.forward = partial(method, module)
            image_embeds = dummy_codebook.forward(codebook_indices)
            b, n, d = image_embeds.shape
            h = w = int(sqrt(n))

            image_embeds = rearrange(image_embeds,
                                     'b (h w) d -> b d h w',
                                     h=h,
                                     w=w)
            for module in dummy_decoder.modules():
                method = smp_state.patch_manager.get_original_method(
                    "forward", type(module))
                module.forward = partial(method, module)
            hard_recons = dummy_decoder.forward(image_embeds)
            return hard_recons

    # starting temperature
    global_step = 0
    temp = STARTING_TEMP

    for epoch in range(EPOCHS):
        batch_time = util.AverageMeter('Time', ':6.3f')
        data_time = util.AverageMeter('Data', ':6.3f')
        losses = util.AverageMeter('Loss', ':.4e')
        top1 = util.AverageMeter('Acc@1', ':6.2f')
        top5 = util.AverageMeter('Acc@5', ':6.2f')
        # Only batch_time and losses are updated below; the others are just
        # registered with the progress meter.
        progress = util.ProgressMeter(
            len(dl), [batch_time, data_time, losses, top1, top5],
            prefix="Epoch: [{}]".format(epoch))

        vae.train()
        start = time.time()

        for i, (images, _) in enumerate(dl):
            images = images.to(device, non_blocking=True)
            # Captured now so the loss meter below is weighted by the real
            # batch size rather than by a preview tensor's first dimension.
            batch_size = images.size(0)
            opt.zero_grad()

            if args.model_parallel:
                loss, recons = train_step(vae, images, temp)
                # Rubik: Average the loss across microbatches.
                loss = loss.reduce_mean()
                recons = recons.reduce_mean()
            else:
                loss, recons = distr_vae(images,
                                         return_loss=True,
                                         return_recons=True,
                                         temp=temp)

            if (not args.model_parallel) and args.deepspeed:
                # Gradients are automatically zeroed after the step
                distr_vae.backward(loss)
                distr_vae.step()
            elif args.model_parallel:
                if args.amp:
                    scaler.step(opt)
                    scaler.update()
                else:
                    # some optimizers like adadelta from PT 1.8 dont like it when optimizer.step is called with no param
                    if len(list(vae.local_parameters())) > 0:
                        opt.step()
            else:
                loss.backward()
                opt.step()

            if i % 10 == 0:
                if args.rank == 0:
                    k = NUM_IMAGES_SAVE

                    with torch.no_grad():
                        if args.model_parallel:
                            # Refresh the plain decoder/codebook copies from
                            # the distributed weights, stripping the
                            # sub-module prefix from the keys.
                            model_dict = vae.state_dict()
                            model_dict_updated = {}
                            for key, val in model_dict.items():
                                if "decoder" in key:
                                    key = key.replace("decoder.", "")
                                elif "codebook" in key:
                                    key = key.replace("codebook.", "")
                                model_dict_updated[key] = val

                            dummy_decoder.load_state_dict(model_dict_updated,
                                                          strict=False)
                            dummy_codebook.load_state_dict(model_dict_updated,
                                                           strict=False)
                            codes = get_codes_step(vae, images, k)
                            codes = codes.reduce_mean().to(torch.long)
                            hard_recons = hard_recons_step(
                                dummy_decoder, dummy_codebook, codes)
                        else:
                            codes = vae.get_codebook_indices(images[:k])
                            hard_recons = vae.decode(codes)

                    # Preview tensors get their own names so the training
                    # batch in `images` is not overwritten. Nothing consumes
                    # them at present (the experiment-tracker logging is
                    # disabled); the computation is retained unchanged.
                    preview_images, preview_recons = images[:k], recons[:k]
                    preview_images, preview_recons, hard_recons, codes = (
                        t.detach().cpu() for t in (
                            preview_images, preview_recons, hard_recons,
                            codes))
                    # NOTE(review): newer torchvision releases name this
                    # keyword `value_range` - confirm against the pinned
                    # version before changing it.
                    preview_images, preview_recons, hard_recons = (
                        make_grid(t.float(),
                                  nrow=int(sqrt(k)),
                                  normalize=True,
                                  range=(-1, 1))
                        for t in (preview_images, preview_recons,
                                  hard_recons))

                if args.model_parallel:
                    filename = f'{args.model_dir}/vae.pt'
                    # `smp.dp_rank` must be called: comparing the function
                    # object itself with 0 is always False, which meant this
                    # checkpoint was never written.
                    if smp.dp_rank() == 0:
                        if args.save_full_model:
                            model_dict = vae.state_dict()
                            opt_dict = opt.state_dict()
                            smp.save(
                                {
                                    "model_state_dict": model_dict,
                                    "optimizer_state_dict": opt_dict
                                },
                                filename,
                                partial=False,
                            )
                        else:
                            model_dict = vae.local_state_dict()
                            opt_dict = opt.local_state_dict()
                            smp.save(
                                {
                                    "model_state_dict": model_dict,
                                    "optimizer_state_dict": opt_dict
                                },
                                filename,
                                partial=True,
                            )
                    smp.barrier()
                else:
                    save_model(f'{args.model_dir}/vae.pt')

                # temperature anneal
                temp = max(temp * math.exp(-ANNEAL_RATE * global_step),
                           TEMP_MIN)

                # lr decay
                sched.step()

            # Collective loss, averaged
            if args.model_parallel:
                # NOTE(review): no cross-rank reduction happens here; the
                # local loss is only divided by args.world_size.
                avg_loss = loss.detach().clone()
                avg_loss /= args.world_size
            else:
                avg_loss = deepspeed_utils.average_all(loss)

            if args.rank == 0 and i % 100 == 0:
                lr = sched.get_last_lr()[0]
                print(epoch, i, f'lr - {lr:6f}, loss - {avg_loss.item()},')

            global_step += 1

            if args.rank == 0:
                # to_python_float incurs a host<->device sync
                losses.update(util.to_python_float(loss), batch_size)

                # Waiting until finishing operations on GPU (Pytorch default: async)
                torch.cuda.synchronize()
                # NOTE(review): `start` is set once per epoch and never
                # reset, so this is time since the epoch began divided by
                # args.log_interval - confirm that is the intended metric.
                batch_time.update((time.time() - start) / args.log_interval)

                print(
                    'Epoch: [{0}][{1}/{2}] '
                    'Train_Time={batch_time.val:.3f}: avg-{batch_time.avg:.3f}, '
                    'Train_Speed={3:.3f} ({4:.3f}), '
                    'Train_Loss={loss.val:.10f}:({loss.avg:.4f}),'.format(
                        epoch,
                        i,
                        len(dl),
                        args.world_size * BATCH_SIZE / batch_time.val,
                        args.world_size * BATCH_SIZE / batch_time.avg,
                        batch_time=batch_time,
                        loss=losses))

    if args.rank == 0:
        # save final vae and cleanup
        if args.model_parallel:
            logger.debug('save model_parallel')
        else:
            save_model(os.path.join(args.model_dir, 'vae-final.pt'))

    if args.model_parallel:
        if args.assert_losses:
            if args.horovod or args.ddp:
                # SM Distributed: If using data parallelism, gather all losses across different model
                # replicas and check if losses match.
                gathered_losses = smp.allgather(loss, smp.DP_GROUP)
                for replica_loss in gathered_losses:
                    print(replica_loss)
                    assert math.isclose(replica_loss, gathered_losses[0])
                assert loss < 0.18
            else:
                assert loss < 0.08

        smp.barrier()
        print("SMP training finished successfully")