def _get_train_sampler(self):
    """Return the sampler used to iterate over the training dataset.

    When model parallelism is enabled the sampler is sharded with
    ``smp.dp_size()`` / ``smp.dp_rank()``; otherwise the parent class
    decides which sampler to use.
    """
    if not self.is_model_parallel_enabled:
        return super()._get_train_sampler()

    if self.args.group_by_length:
        return DistributedLengthGroupedSampler(
            self.train_dataset,
            self.args.train_batch_size,
            num_replicas=smp.dp_size(),
            rank=smp.dp_rank(),
        )

    return DistributedSampler(
        self.train_dataset,
        num_replicas=smp.dp_size(),
        rank=smp.dp_rank(),
    )
def dist_setting(args):
    """Populate rank / world-size fields on ``args`` and rescale lr and batch size.

    Three modes are supported, selected by flags on ``args``:

    * ``args.data_parallel``: ranks and world size are read from ``sdp``.
    * ``args.model_parallel``: ranks and sizes are read from ``smp``.
    * otherwise: the world size is ``len(args.hosts) * args.num_gpus`` and a
      ``torch.distributed`` process group is initialised with
      ``args.backend``.

    Afterwards ``args.lr`` is multiplied by the world size and
    ``args.batch_size`` is divided by ``world_size // num_gpus`` (never
    dropping below 1).

    Args:
        args: Mutable namespace; must provide ``data_parallel``,
            ``model_parallel``, ``apex``, ``hosts``, ``current_host``,
            ``num_gpus``, ``lr`` and ``batch_size`` (plus ``local_rank`` and
            ``backend`` for the plain distributed mode).

    Returns:
        The same ``args`` object, updated in place.
    """
    print("args.data_parallel : {}".format(args.data_parallel))
    print("args.model_parallel : {}".format(args.model_parallel))
    print("args.apex : {}".format(args.apex))

    args.world_size = 1
    args.host_num = args.hosts.index(args.current_host)

    if args.data_parallel:
        args.world_size = sdp.get_world_size()
        args.rank = sdp.get_rank()  # total rank in all hosts
        args.local_rank = sdp.get_local_rank()  # rank per host
    elif args.model_parallel:
        args.world_size = smp.size()
        args.local_rank = smp.local_rank()  # rank per host
        args.rank = smp.rank()
        args.dp_size = smp.dp_size()
        args.dp_rank = smp.dp_rank()
        print(
            "smp.rank() : {}, smp.size() : {}, smp.mp_rank() : {}, smp.local_size() : {}, smp.get_mp_group() : {}, smp.get_dp_group() : {}, smp.local_rank() : {}, smp.dp_size() : {}, smp.dp_rank() : {}"
            .format(smp.rank(), smp.size(), smp.mp_rank(), smp.local_size(),
                    smp.get_mp_group(), smp.get_dp_group(), smp.local_rank(),
                    smp.dp_size(), smp.dp_rank()))
    else:
        args.world_size = len(args.hosts) * args.num_gpus
        if args.local_rank is not None:
            # total rank in all hosts
            args.rank = args.num_gpus * args.host_num + args.local_rank

        dist.init_process_group(backend=args.backend,
                                rank=args.rank,
                                world_size=args.world_size)
        logger.info(
            'Initialized the distributed environment: \'{}\' backend on {} nodes. '
            .format(args.backend, dist.get_world_size()) +
            'Current host rank is {}. Number of gpus: {}'.format(
                dist.get_rank(), args.num_gpus))

    print("**** [dist_setting] args.rank : {}".format(args.rank))
    print("args.world_size : {}".format(args.world_size))
    print("Use GPU: {} for training".format(args.local_rank))

    args.lr = args.lr * float(args.world_size)

    # world_size // num_gpus is 0 whenever fewer processes than GPUs-per-host
    # take part in the job; clamp it so the floor-division below cannot raise
    # ZeroDivisionError. For quotients >= 1 the result is unchanged.
    batch_divisor = max(args.world_size // args.num_gpus, 1)
    args.batch_size = max(args.batch_size // batch_divisor, 1)

    return args
def _save_smp(self, output_dir: Optional[str] = None):
    """Write the model, tokenizer and training arguments to ``output_dir``.

    Only processes with ``smp.dp_rank() == 0`` take part. Each of them calls
    ``state_dict()`` on the wrapped model; the files themselves are written
    by the world-zero process only.

    Args:
        output_dir: Destination directory; falls back to
            ``self.args.output_dir`` when ``None``.
    """
    if smp.dp_rank() != 0:
        return

    target_dir = self.args.output_dir if output_dir is None else output_dir
    os.makedirs(target_dir, exist_ok=True)
    logger.info("Saving model checkpoint to %s", target_dir)

    # The state dict has to be requested from the wrapped model, and this
    # happens before the world-zero check below.
    weights = self.model_wrapped.state_dict()

    # Everything past this point is done by the main process only.
    if not self.is_world_process_zero():
        return

    net = self.model
    if not isinstance(net, PreTrainedModel):
        net = unwrap_model(net)

    if isinstance(net, PreTrainedModel):
        net.save_pretrained(target_dir, state_dict=weights)
    else:
        logger.info(
            "Trainer.model is not a `PreTrainedModel`, only saving its state dict."
        )
        torch.save(weights, os.path.join(target_dir, WEIGHTS_NAME))

    if self.tokenizer is not None:
        self.tokenizer.save_pretrained(target_dir)

    # Keep the training arguments next to the trained model.
    torch.save(self.args, os.path.join(target_dir, "training_args.bin"))
def _get_eval_sampler(
    self, eval_dataset: Dataset
) -> Optional[torch.utils.data.sampler.Sampler]:
    """Return the sampler used to iterate over ``eval_dataset``.

    With model parallelism enabled, a ``SequentialDistributedSampler``
    sharded by ``smp.dp_size()`` / ``smp.dp_rank()`` is returned; otherwise
    the choice is delegated to the parent class.
    """
    if not self.is_model_parallel_enabled:
        return super()._get_eval_sampler(eval_dataset)

    return SequentialDistributedSampler(
        eval_dataset,
        num_replicas=smp.dp_size(),
        rank=smp.dp_rank(),
    )
def is_world_process_zero(self) -> bool:
    """
    Tell whether this process is the global main process.

    With model parallelism enabled this requires the smp global rank, local
    rank, mp rank and dp rank to all be 0; otherwise the parent class
    answers.
    """
    if not self.is_model_parallel_enabled:
        return super().is_world_process_zero()

    return (
        smp.rank() == 0
        and smp.local_rank() == 0
        and smp.mp_rank() == 0
        and smp.dp_rank() == 0
    )
def _get_train_sampler(self):
    """Return the sampler used to iterate over the training dataset.

    Under model parallelism one of three samplers is built, each sharded by
    ``smp.dp_size()`` / ``smp.dp_rank()``:

    * ``DistributedLengthGroupedSampler`` when ``args.group_by_length``;
    * ``DistributedSamplerWithLoop`` when ``args.dataloader_drop_last`` is
      falsy;
    * ``DistributedSampler`` otherwise.

    Without model parallelism the parent class decides.
    """
    if not self.is_model_parallel_enabled:
        return super()._get_train_sampler()

    if self.args.group_by_length:
        return DistributedLengthGroupedSampler(
            self.train_dataset,
            self.args.train_batch_size,
            num_replicas=smp.dp_size(),
            rank=smp.dp_rank(),
        )

    if self.args.dataloader_drop_last:
        return DistributedSampler(
            self.train_dataset,
            num_replicas=smp.dp_size(),
            rank=smp.dp_rank(),
        )

    return DistributedSamplerWithLoop(
        self.train_dataset,
        self.args.per_device_train_batch_size,
        num_replicas=smp.dp_size(),
        rank=smp.dp_rank(),
    )
def process_index(self):
    """
    The index of the current process among the processes used in parallel.

    Backends are probed in order: TPU ordinal, SageMaker model-parallel
    ``dp_rank``, SageMaker data-parallel rank, then the
    ``torch.distributed`` rank when ``self.local_rank != -1``. Falls back
    to 0.
    """
    if is_torch_tpu_available():
        return xm.get_ordinal()
    if is_sagemaker_mp_enabled():
        return smp.dp_rank()
    if is_sagemaker_dp_enabled():
        return sm_dist.get_rank()
    if self.local_rank != -1:
        return torch.distributed.get_rank()
    return 0
def process_index(self):
    """
    Index of the current process.

    The first matching backend wins: TPU (``xm.get_ordinal()``), SageMaker
    model parallel (``smp.dp_rank()``), SageMaker data parallel
    (``sm_dist.get_rank()``), ``torch.distributed`` (when
    ``self.local_rank != -1``). Returns 0 when none applies.
    """
    if is_torch_tpu_available():
        index = xm.get_ordinal()
    elif is_sagemaker_mp_enabled():
        index = smp.dp_rank()
    elif is_sagemaker_dp_enabled():
        index = sm_dist.get_rank()
    elif self.local_rank != -1:
        index = torch.distributed.get_rank()
    else:
        index = 0
    return index
def _save_checkpoint(self, model, trial, metrics=None):
    """Save a checkpoint, with a dedicated path when model parallelism is on.

    With model parallelism enabled, only processes whose ``smp.dp_rank()``
    is 0 take part: they save the model via ``self.save_model`` and call
    ``optimizer.state_dict()``; the world-zero process then writes the
    optimizer / scheduler states and the trainer state, and rotates old
    checkpoints. The best-metric bookkeeping on ``self.state`` is updated
    on every participating process.

    Without model parallelism the call is forwarded to the parent class.

    Args:
        model: Forwarded to the parent implementation; unused on the
            model-parallel path.
        trial: Forwarded to the parent implementation; unused on the
            model-parallel path.
        metrics: Optional mapping of evaluation metrics used to track the
            best checkpoint.

    Raises:
        KeyError: If ``args.metric_for_best_model`` (with the ``eval_``
            prefix applied) is missing from ``metrics``.
    """
    if not self.is_model_parallel_enabled:
        # `super()` is already bound to this instance: passing `self`
        # again would shift every positional argument and make the parent
        # raise TypeError (multiple values for `metrics`).
        super()._save_checkpoint(model, trial, metrics=metrics)
        return

    if smp.dp_rank() != 0:
        return

    checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"

    run_dir = self.args.output_dir
    self.store_flos()

    output_dir = os.path.join(run_dir, checkpoint_folder)
    self.save_model(output_dir)

    # Consolidate the state dict on all processes of dp_rank 0
    opt_state_dict = self.optimizer.state_dict()

    # Save it and the scheduler on the main process
    if self.is_world_process_zero():
        torch.save(opt_state_dict, os.path.join(output_dir, "optimizer.pt"))
        with warnings.catch_warnings(record=True) as caught_warnings:
            torch.save(self.lr_scheduler.state_dict(),
                       os.path.join(output_dir, "scheduler.pt"))
        reissue_pt_warnings(caught_warnings)

    # Determine the new best metric / best model checkpoint
    if metrics is not None and self.args.metric_for_best_model is not None:
        metric_to_check = self.args.metric_for_best_model
        if not metric_to_check.startswith("eval_"):
            metric_to_check = f"eval_{metric_to_check}"
        metric_value = metrics[metric_to_check]

        operator = np.greater if self.args.greater_is_better else np.less
        if (self.state.best_metric is None
                or self.state.best_model_checkpoint is None
                or operator(metric_value, self.state.best_metric)):
            self.state.best_metric = metric_value
            self.state.best_model_checkpoint = output_dir

    # Save the Trainer state
    if self.is_world_process_zero():
        self.state.save_to_json(
            os.path.join(output_dir, "trainer_state.json"))

    # Maybe delete some older checkpoints.
    if self.is_world_process_zero():
        self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
def smp_init(args):
    """Initialise ``smp`` from ``args`` and record the resulting ranks.

    Builds the ``smp.init`` configuration from ``args``, then stores the
    dp rank, global rank, world size and local rank on ``args``, exports
    ``RANK`` / ``WORLD_SIZE`` / ``LOCAL_RANK`` /
    ``SMP_SKIP_GRAPH_VALIDATION`` into the environment and selects the CUDA
    device matching the local rank. ``args.deepspeed`` is forced to
    ``False``.

    Args:
        args: Mutable namespace providing ``num_microbatches``,
            ``placement_strategy``, ``pipeline``, ``optimize``,
            ``num_partitions``, ``horovod`` and ``ddp``.

    Returns:
        The same ``args`` object, updated in place.
    """
    args.deepspeed = False

    smp.init({
        "microbatches": args.num_microbatches,
        "placement_strategy": args.placement_strategy,
        "pipeline": args.pipeline,
        "optimize": args.optimize,
        "partitions": args.num_partitions,
        "horovod": args.horovod,
        "ddp": args.ddp,
    })

    # NOTE: args.rank is the data-parallel rank; the global rank is kept
    # separately in args.global_rank.
    args.rank = smp.dp_rank()
    args.global_rank = smp.rank()
    args.world_size = smp.size()

    local_rank = smp.local_rank()
    os.environ.update({
        'RANK': str(args.rank),
        'WORLD_SIZE': str(args.world_size),
        'LOCAL_RANK': str(local_rank),
        'SMP_SKIP_GRAPH_VALIDATION': "0",
    })

    torch.cuda.set_device(local_rank)
    args.local_rank = local_rank

    # No random seeding is performed here.
    return args
def dist_setting(args):
    """Populate rank / world-size fields on ``args`` and rescale lr and batch size.

    Three modes are supported, selected by flags on ``args``:

    * ``args.data_parallel``: ``sdp`` is obtained from ``_sdp_import(args)``
      and ranks / world size are read from it.
    * ``args.model_parallel``: ranks are read from ``smp``; the world size
      is ``args.num_gpus * len(args.hosts)``.
    * otherwise: the world size is ``len(args.hosts) * args.num_gpus`` and a
      ``torch.distributed`` process group is initialised with
      ``args.backend``.

    Afterwards ``args.lr`` is multiplied by the world size and
    ``args.batch_size`` is divided by ``world_size // num_gpus`` (never
    dropping below 1).

    Args:
        args: Mutable namespace; must provide ``data_parallel``,
            ``model_parallel``, ``apex``, ``hosts``, ``current_host``,
            ``num_gpus``, ``lr`` and ``batch_size`` (plus ``local_rank`` and
            ``backend`` for the plain distributed mode).

    Returns:
        The same ``args`` object, updated in place.
    """
    print(f"args.data_parallel : {args.data_parallel}, args.model_parallel : {args.model_parallel}, args.apex : {args.apex}")

    args.world_size = 1
    args.host_num = args.hosts.index(args.current_host)

    if args.data_parallel:
        sdp, _ = _sdp_import(args)

        args.world_size = sdp.get_world_size()
        args.rank = sdp.get_rank()  # total rank in all hosts
        args.local_rank = sdp.get_local_rank()  # rank per host
    elif args.model_parallel:
        # The original also stored smp.size() here but overwrote it on the
        # very next line; only the effective value is kept.
        args.world_size = args.num_gpus * len(args.hosts)
        args.local_rank = smp.local_rank()  # rank per host
        args.rank = smp.rank()
        args.dp_size = smp.dp_size()
        args.dp_rank = smp.dp_rank()
    else:
        args.world_size = len(args.hosts) * args.num_gpus
        if args.local_rank is not None:
            # total rank in all hosts
            args.rank = args.num_gpus * args.host_num + args.local_rank

        dist.init_process_group(backend=args.backend,
                                rank=args.rank,
                                world_size=args.world_size)
        logger.info(
            'Initialized the distributed environment: \'{}\' backend on {} nodes. '
            .format(args.backend, dist.get_world_size()) +
            'Current host rank is {}. Number of gpus: {}'.format(
                dist.get_rank(), args.num_gpus))

    args.lr = args.lr * float(args.world_size)

    # world_size // num_gpus is 0 whenever the world size is smaller than
    # num_gpus; clamp it so the floor-division below cannot raise
    # ZeroDivisionError. For quotients >= 1 the result is unchanged.
    batch_divisor = max(args.world_size // args.num_gpus, 1)
    args.batch_size = max(args.batch_size // batch_divisor, 1)

    return args
def main():
    """Run BERT pre-training (or evaluation) over the files in ``args.input_dir``.

    The function parses arguments, seeds the RNGs with
    ``args.seed + args.local_rank``, builds the model / optimizer and then
    loops over epochs forever; termination is driven by the step counter
    (``args.steps_this_run``) or by the module-level ``timeout_sent`` flag.
    While one data file is being trained on, the dataset for the next file
    is prepared in a single-worker process pool.

    Returns:
        * training (``args.do_train``): the tuple
          ``(args, final_loss, train_time_raw, global_step)``;
        * evaluation: the mean of the collected test losses.
    """
    global timeout_sent

    args = parse_arguments()

    random.seed(args.seed + args.local_rank)
    np.random.seed(args.seed + args.local_rank)
    torch.manual_seed(args.seed + args.local_rank)
    torch.cuda.manual_seed(args.seed + args.local_rank)
    worker_init = WorkerInitObj(args.seed + args.local_rank)

    device, args = setup_training(args)

    # Prepare optimizer
    (
        model,
        optimizer,
        lr_scheduler,
        checkpoint,
        global_step,
        criterion,
    ) = prepare_model_and_optimizer(args, device)

    raw_train_start = None
    most_recent_ckpts_paths = []
    average_loss = 0.0  # averaged loss every args.log_freq steps
    epoch = 0
    training_steps = 0
    test_losses = []

    pool = ProcessPoolExecutor(1)

    # Note: We loop infinitely over epochs, termination is handled via iteration count
    while True:
        thread = None
        restored_data_loader = None
        if (not args.resume_from_checkpoint or epoch > 0
                or (args.phase2 and global_step < 1) or args.init_checkpoint):
            # Fresh epoch: list the training files and shuffle them with a
            # seed that depends on the epoch.
            files = [
                os.path.join(args.input_dir, f)
                for f in os.listdir(args.input_dir)
                if os.path.isfile(os.path.join(args.input_dir, f))
                and "training" in f
            ]
            files.sort()
            num_files = len(files)
            random.Random(args.seed + epoch).shuffle(files)
            f_start_id = 0
        else:
            # Resuming: the checkpoint stores the file index followed by
            # the file list it was produced with.
            f_start_id = checkpoint["files"][0]
            files = checkpoint["files"][1:]
            args.resume_from_checkpoint = False
            num_files = len(files)
            # may not exist in all checkpoints
            epoch = checkpoint.get("epoch", 0)
            # Fixed: this used to be assigned to the misspelled name
            # `restored_dataloader`, so the restored loader was never used.
            restored_data_loader = checkpoint.get("data_loader", None)

        shared_file_list = {}

        # Pick the data-parallel size / rank from whichever backend is
        # initialised.
        if smp.is_initialized():
            dpsize = smp.dp_size()
            dprank = smp.dp_rank()
        elif torch.distributed.is_initialized():
            dpsize = get_world_size()
            dprank = get_rank()
        else:
            dpsize = 1
            dprank = 0
        dparallel = dpsize > 1
        if dparallel and dpsize > num_files:
            remainder = dpsize % num_files
            data_file = files[(f_start_id * dpsize + dprank +
                               remainder * f_start_id) % num_files]
        else:
            data_file = files[(f_start_id * dpsize + dprank) % num_files]

        previous_file = data_file

        if restored_data_loader is None:
            train_data = pretraining_dataset(data_file,
                                             args.max_predictions_per_seq)
            train_sampler = RandomSampler(train_data)
            train_dataloader = DataLoader(
                train_data,
                sampler=train_sampler,
                batch_size=args.train_batch_size * args.n_gpu,
                num_workers=4,
                worker_init_fn=worker_init,
                pin_memory=True,
                drop_last=True,
            )
            # shared_file_list["0"] = (train_dataloader, data_file)
        else:
            train_dataloader = restored_data_loader
            restored_data_loader = None

        overflow_buf = None
        if args.allreduce_post_accumulation:
            overflow_buf = torch.cuda.IntTensor([0])

        for f_id in range(f_start_id + 1, len(files)):
            # NOTE(review): this selection uses get_world_size()/get_rank()
            # while the first file above uses dpsize/dprank, and
            # `remainder` is only bound when dpsize > num_files. Confirm
            # whether the dp values were intended here as well.
            if get_world_size() > num_files:
                data_file = files[(f_id * get_world_size() + get_rank() +
                                   remainder * f_id) % num_files]
            else:
                data_file = files[(f_id * get_world_size() + get_rank()) %
                                  num_files]

            previous_file = data_file

            # Prepare the next file's dataset in the background while the
            # current one is consumed.
            dataset_future = pool.submit(
                create_pretraining_dataset,
                data_file,
                args.max_predictions_per_seq,
                shared_file_list,
                args,
                worker_init,
            )

            train_iter = (tqdm(train_dataloader,
                               desc="Iteration",
                               disable=args.disable_progress_bar)
                          if is_main_process() else train_dataloader)

            if raw_train_start is None:
                raw_train_start = time.time()
            for step, batch in enumerate(train_iter):
                training_steps += 1
                batch = [t.to(device) for t in batch]
                input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels = batch
                if args.do_train:
                    from smdistributed.modelparallel.test.torch.utils import dump_model, verify

                    model.train()
                    if args.smp > 0:
                        loss_mbs = smp_step(
                            args,
                            device,
                            input_ids,
                            segment_ids,
                            input_mask,
                            masked_lm_labels,
                            next_sentence_labels,
                            model,
                            optimizer,
                            criterion,
                            step,
                        )
                        loss = loss_mbs.reduce_mean()
                        if smp.rank() == 0:
                            print("Loss:", loss.item())
                    else:
                        loss = train_step(
                            args,
                            device,
                            input_ids,
                            segment_ids,
                            input_mask,
                            masked_lm_labels,
                            next_sentence_labels,
                            model,
                            optimizer,
                            criterion,
                            step,
                        )
                    divisor = 1
                    average_loss += loss.item()

                    if training_steps % args.gradient_accumulation_steps == 0:
                        lr_scheduler.step()  # learning rate warmup
                        global_step = take_optimizer_step(
                            args, optimizer, model, overflow_buf,
                            global_step)

                    if global_step >= args.steps_this_run or timeout_sent:
                        # Final bookkeeping before leaving: average the
                        # accumulated loss over the last logging window.
                        train_time_raw = time.time() - raw_train_start
                        last_num_steps = (int(
                            training_steps / args.gradient_accumulation_steps)
                                          % args.log_freq)
                        last_num_steps = args.log_freq if last_num_steps == 0 else last_num_steps
                        average_loss = torch.tensor(
                            average_loss, dtype=torch.float32).cuda()
                        average_loss = average_loss / (last_num_steps *
                                                       divisor)
                        if torch.distributed.is_initialized():
                            average_loss /= get_world_size()
                            torch.distributed.all_reduce(average_loss)
                        final_loss = loss.item()
                    elif training_steps % (
                            args.log_freq *
                            args.gradient_accumulation_steps) == 0:
                        average_loss = 0

                    if (global_step >= args.steps_this_run or training_steps %
                        (args.num_steps_per_checkpoint *
                         args.gradient_accumulation_steps) == 0
                            or timeout_sent):
                        if smp.dp_rank() == 0 and not args.skip_checkpoint:
                            if args.resume_step < 0 or not args.phase2:
                                output_save_file = os.path.join(
                                    args.output_dir,
                                    "ckpt_{}.pt".format(global_step))
                            else:
                                output_save_file = os.path.join(
                                    args.output_dir,
                                    "ckpt_{}.pt".format(global_step +
                                                        args.phase1_end_step),
                                )
                            if args.do_train:
                                save_dict = {
                                    "model": model.local_state_dict(),
                                    "optimizer": optimizer.local_state_dict(),
                                    "files": [f_id] + files,
                                    "epoch": epoch,
                                    "data_loader": None
                                    if global_step >= args.steps_this_run else
                                    train_dataloader,
                                }
                                if args.fp16:
                                    save_dict["master params"] = list(
                                        amp.master_params(optimizer))
                                # SMP: Checkpoint mp_rank specific state
                                smp.save(save_dict,
                                         output_save_file,
                                         partial=True)

                                # Keep at most three partial checkpoints.
                                most_recent_ckpts_paths.append(
                                    output_save_file)
                                if len(most_recent_ckpts_paths) > 3 and (
                                        args.smp == 0 or smp.dp_rank() == 0):
                                    ckpt_to_be_removed = most_recent_ckpts_paths.pop(
                                        0)
                                    os.remove(ckpt_to_be_removed +
                                              f"_{smp.mp_rank()}")

                        # Exiting the training due to hitting max steps, or being sent a
                        # timeout from the cluster scheduler
                        if global_step >= args.steps_this_run or timeout_sent:
                            # thread.join()
                            if smp.dp_rank() == 0 and args.save_full:
                                output_save_file = os.path.join(
                                    args.output_dir,
                                    "ckpt_{}.pt".format(global_step))
                                save_dict = {
                                    "model": model.local_state_dict(),
                                    "optimizer": optimizer.local_state_dict(),
                                    "files": [f_id] + files,
                                    "epoch": epoch,
                                    "data_loader": None
                                    if global_step >= args.steps_this_run else
                                    train_dataloader,
                                }
                                if args.fp16:
                                    save_dict["master params"] = list(
                                        amp.master_params(optimizer))
                                # SMP: Save a single checkpoint containing entire model parameters
                                smp.save(save_dict,
                                         output_save_file,
                                         partial=False)
                            # Fixed: the loader used to be deleted before
                            # `save_dict` above read it, which raised
                            # UnboundLocalError on a timeout exit. It is
                            # now released after the save.
                            del train_dataloader
                            smp.barrier()
                            if smp.local_rank() == 0:
                                print(f"Start syncing model checkpoints to s3")
                                base_s3_path = os.path.dirname(
                                    os.path.dirname(
                                        os.getenv("SM_MODULE_DIR", "")))
                                curr_host = os.getenv("SM_CURRENT_HOST")
                                full_s3_path = f"{base_s3_path}/checkpoints/{curr_host}/"
                                sync_local_checkpoints_to_s3(
                                    local_path=args.output_dir,
                                    s3_path=full_s3_path)
                                print(
                                    f"Finished syncing model checkpoints to s3"
                                )
                            return args, final_loss, train_time_raw, global_step
                else:
                    model.eval()
                    with torch.no_grad():
                        loss = test_step(
                            args,
                            device,
                            input_ids,
                            segment_ids,
                            input_mask,
                            masked_lm_labels,
                            next_sentence_labels,
                            model,
                            criterion,
                            step,
                        )
                        print(f"global_step {global_step} Test Loss:", loss)
                        test_losses.append(loss)
                    global_step += 1
                    if global_step >= args.steps_this_run:
                        return sum(test_losses) / len(test_losses)

            del train_dataloader
            # thread.join()
            # Make sure pool has finished and switch train_dataloader
            # NOTE: Will block until complete
            train_dataloader, data_file = dataset_future.result(timeout=None)
        epoch += 1
def main():
    """Train the MNIST ``GroupedNet`` with ``smp`` for one epoch and checkpoint it.

    Steps: initialise ``smp`` and pin the CUDA device to the local rank,
    download MNIST once per instance, optionally split the training set
    across data-parallel ranks, wrap model and optimizer in their ``smp``
    counterparts, train/test for a single epoch, save a partial checkpoint
    under ``/opt/ml/local_checkpoints`` and sync that directory to S3.

    Raises:
        ValueError: If CUDA is not available.
    """
    if not torch.cuda.is_available():
        raise ValueError(
            "The script requires CUDA support, but CUDA not available")

    use_ddp, use_horovod = True, False
    checkpoint_dir = "/opt/ml/local_checkpoints"

    # Fixed seeds so that losses are the same across runs.
    random.seed(1)
    np.random.seed(1)
    torch.manual_seed(1)
    torch.cuda.manual_seed(1)

    smp.init()

    # SM Distributed: each process works on the GPU matching its local
    # rank; input tensors should be transferred to this device.
    torch.cuda.set_device(smp.local_rank())
    device = torch.device("cuda")

    loader_options = {
        "batch_size": 64,
        "num_workers": 1,
        "pin_memory": True,
        "shuffle": False,
    }

    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.1307, ), (0.3081, ))])

    # SM Distributed: only one process per instance downloads the data;
    # concurrent download + extraction corrupts the files. Everyone else
    # waits at the barrier and then opens the already-downloaded copy.
    if smp.local_rank() == 0:
        datasets.MNIST("../data",
                       train=True,
                       download=True,
                       transform=transform)
    smp.barrier()
    dataset1 = datasets.MNIST("../data",
                              train=True,
                              download=False,
                              transform=transform)

    if (use_ddp or use_horovod) and smp.dp_size() > 1:
        # One equally sized partition per data-parallel rank; this rank
        # selects its own.
        replicas = smp.dp_size()
        partitions_dict = {str(i): 1 / replicas for i in range(replicas)}
        dataset1 = SplitDataset(dataset1, partitions=partitions_dict)
        dataset1.select(str(smp.dp_rank()))

    # Test dataset and the two dataloaders.
    dataset2 = datasets.MNIST("../data", train=False, transform=transform)

    train_loader = torch.utils.data.DataLoader(dataset1, **loader_options)
    test_loader = torch.utils.data.DataLoader(dataset2, **loader_options)

    model = GroupedNet()

    # SMP handles the transfer of parameters to the right device, so no
    # explicit `model.to(device)` call is made.
    optimizer = optim.Adadelta(model.parameters(), lr=4.0)

    # SM Distributed: from here on the DistributedModel wrapper is used in
    # place of the original model.
    model = smp.DistributedModel(model)
    scaler = smp.amp.GradScaler()
    optimizer = smp.DistributedOptimizer(optimizer)

    scheduler = StepLR(optimizer, step_size=1, gamma=0.7)

    for epoch in range(1, 2):
        train(model, scaler, device, train_loader, optimizer, epoch)
        test_loss = test(model, device, test_loader)
        scheduler.step()

    if smp.rank() == 0:
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)
            print("-INFO- PATH DO NOT EXIST")
        else:
            print("-INFO- PATH DO EXIST")

    # Wait for the directory step above before any rank saves.
    smp.barrier()

    if smp.dp_rank() == 0:
        model_dict = model.local_state_dict()
        opt_dict = optimizer.local_state_dict()
        smp.save(
            {
                "model_state_dict": model_dict,
                "optimizer_state_dict": opt_dict
            },
            checkpoint_dir + "/pt_mnist_checkpoint.pt",
            partial=True,
        )
    smp.barrier()

    if smp.local_rank() == 0:
        print("Start syncing")
        module_dir = os.getenv("SM_MODULE_DIR", "")
        base_s3_path = os.path.dirname(os.path.dirname(module_dir))
        curr_host = os.getenv("SM_CURRENT_HOST")
        full_s3_path = f"{base_s3_path}/checkpoints/{curr_host}/"
        sync_local_checkpoints_to_s3(local_path=checkpoint_dir,
                                     s3_path=full_s3_path)
        print("Finished syncing")
def train_smp(
    self,
    model,
    optimizer,
    lr_scheduler,
    start_train_path_index,
    start_batch_index,
    num_params,
    total_steps,
    args,
    prescaled_batch,
):
    """Run one pass over ``self.train_dataset`` with an smp-sharded sampler.

    For every batch the gradients are cleared, ``self.train_step`` is
    invoked, the micro-batch losses are averaged and the learning-rate
    scheduler is stepped. Rank 0 prints the loss every 10 steps. The loop
    stops early once ``args.max_steps`` steps have been taken.

    Args:
        model: Model to train; put into training mode here.
        optimizer: Optimizer; only ``zero_grad`` is called directly in this
            function. NOTE(review): the optimizer step presumably happens
            inside ``self.train_step`` -- confirm.
        lr_scheduler: Scheduler stepped once per batch.
        start_train_path_index: Unused.
        start_batch_index: Unused.
        num_params: Unused.
        total_steps: Ignored; the step counter always restarts at 0.
        args: Provides ``seed``, ``per_device_train_batch_size`` and
            ``max_steps``.
        prescaled_batch: When truthy, the sampler is sharded by
            ``smp.rdp_rank()`` / ``smp.rdp_size()`` instead of the dp
            rank / size.

    Returns:
        None.
    """
    model.train()

    dp_rank = smp.dp_rank() if not prescaled_batch else smp.rdp_rank()
    dp_size = smp.dp_size() if not prescaled_batch else smp.rdp_size()

    start = time.time()

    # Set the same seed for computation
    set_seed(args.seed)

    sampler = torch.utils.data.DistributedSampler(
        self.train_dataset,
        shuffle=True,
        seed=args.seed,
        rank=dp_rank,
        num_replicas=dp_size,
        drop_last=True,
    )
    train_dataloader = torch.utils.data.DataLoader(
        self.train_dataset,
        sampler=sampler,
        batch_size=args.per_device_train_batch_size,
        collate_fn=self.data_collator,
        num_workers=0,
        pin_memory=True,
        drop_last=True,
    )

    total_steps = 0
    for input_data in train_dataloader:
        optimizer.zero_grad(set_to_none=True)

        input_ids = input_data["input_ids"]
        attention_mask = input_data["attention_mask"]

        loss_mb = self.train_step(model, optimizer, input_ids,
                                  attention_mask, args)
        loss = loss_mb.reduce_mean()

        lr_scheduler.step()

        # Fixed: the counter used to be incremented twice per batch, which
        # skewed the logged batch numbers and meant an odd `max_steps` was
        # never matched by the `==` test below.
        total_steps += 1

        time_elapsed = time.time() - start
        if smp.rank() == 0 and not total_steps % 10:
            # NOTE(review): the speed field is an empty placeholder; no
            # throughput is computed in this function.
            print(
                f"({int(time_elapsed)}s), Batch {total_steps - 1} Loss: {loss.item()}, Speed: {''} samples/sec"
            )

        # `==` rather than `>=` so a non-positive max_steps never stops
        # the loop, matching the original comparison.
        if total_steps == args.max_steps:
            break
def train(
    model,
    optimizer,
    lr_scheduler,
    model_config,
    start_train_path_index,
    start_batch_index,
    num_params,
    total_steps,
    args,
):
    """Train ``model`` over the files in ``args.training_dir`` with smp.

    Training files are consumed one at a time, in sorted order, for
    ``args.epochs`` passes or until ``args.max_steps`` steps were taken.
    Along the way the function periodically evaluates on the validation
    set (``args.validation_freq``), writes partial checkpoints
    (``args.checkpoint_freq``) and, when ``args.logits_output`` is set,
    records losses and the last logits to ``args.model_dir``.

    Args:
        model: Model to train.
        optimizer: Optimizer stepped once per batch.
        lr_scheduler: Scheduler; stepped unless an fp16 overflow occurred.
        model_config: Passed through to ``save``.
        start_train_path_index: Index of the first training file to read
            (used when resuming).
        start_batch_index: Number of leading batches of the first file to
            skip (used when resuming).
        num_params: Only used in the checkpoint file name and ``save``.
        total_steps: Number of steps already taken.
        args: Parsed command-line arguments.

    Returns:
        Tuple ``(total_steps, throughput, loss_metric)``; ``throughput`` is
        ``None`` and ``loss_metric`` is 0 if no step was run.
    """
    if args.enable_memory_profiling > 0:
        memory_status_cpu(msg="before train step")
    model.train()
    if args.parallel_proc_data_processing:
        # Single worker used to build the next file's dataloader in the
        # background.
        pool = ProcessPoolExecutor(1)

    # With prescaled batches the data is sharded over the reduced
    # data-parallel (rdp) group instead of the dp group.
    dp_rank = smp.dp_rank() if not args.prescaled_batch else smp.rdp_rank()
    dp_size = smp.dp_size() if not args.prescaled_batch else smp.rdp_size()
    data_type = "BERT" if args.use_bert_data else "GPT"

    if args.use_bert_data:
        train_paths = sorted([
            os.path.join(args.training_dir, p)
            for p in os.listdir(args.training_dir)
            if os.path.isfile(os.path.join(args.training_dir, p))
            and "training" in p
        ])
    else:
        if args.zipped_data > 0:
            file_extension = ".json.gz"
        else:
            file_extension = ".json"
        train_paths = sorted([
            os.path.join(args.training_dir, p)
            for p in os.listdir(args.training_dir)
            if p.endswith(file_extension)
        ])

    # Dataloader for the first file to train on.
    train_dataloader = create_pretraining_dataloader(
        [train_paths[start_train_path_index]],
        args.train_batch_size,
        args.max_context_width,
        seed=args.seed,
        dp_rank=dp_rank,
        dp_size=dp_size,
        shuffle=args.same_seed < 1,
        zipped=args.zipped_data > 0,
        use_last_file_only=args.fast_validation > 0,
        data_type=data_type,
    )

    if args.validation_freq is not None:
        # load all validation examples
        if smp.rank() == 0:
            print("Creating val dataloader")
        if args.use_bert_data:
            val_paths = sorted([
                os.path.join(args.test_dir, p)
                for p in os.listdir(args.test_dir)
                if os.path.isfile(os.path.join(args.test_dir, p))
                and "testing" in p
            ])
        else:
            if args.zipped_data > 0:
                file_extension = ".json.gz"
            else:
                file_extension = ".json"
            val_paths = sorted([
                os.path.join(args.test_dir, p)
                for p in os.listdir(args.test_dir)
                if p.endswith(file_extension)
            ])
        val_dataloader = create_pretraining_dataloader(
            val_paths,
            args.val_batch_size,
            args.max_context_width,
            seed=args.seed,
            dp_rank=dp_rank,
            dp_size=dp_size,
            shuffle=True,
            zipped=args.zipped_data > 0,
            use_last_file_only=args.fast_validation > 0,
            data_type=data_type,
        )
        if smp.rank() == 0:
            print("Created val dataloader")

    start = time.time()
    throughput = None
    to_save = {"loss": [], "val_loss": []}
    loss_metric = 0

    def should_record():
        # only record the ranks that in the tp group that contains global rank 0
        if smp.tp_size() > 1:
            tp_group = smp.get_tp_group()
            return 0 in tp_group
        else:
            return smp.rank() == 0

    # Set the same seed for computation
    set_seed(args.seed)

    # `index` counts files across epochs; the modulo maps it back onto
    # train_paths.
    for index in range(start_train_path_index, args.epochs * len(train_paths)):
        next_train_path_index = (index + 1) % len(train_paths)
        curr_train_path_index = index % len(train_paths)

        if total_steps >= args.max_steps:
            break

        if args.parallel_proc_data_processing:
            # Start building the next file's dataloader while this file is
            # being trained on.
            dataset_future = pool.submit(
                create_pretraining_dataloader,
                [train_paths[next_train_path_index]],
                args.train_batch_size,
                args.max_context_width,
                seed=args.seed,
                dp_rank=dp_rank,
                dp_size=dp_size,
                shuffle=args.same_seed < 1,
                zipped=args.zipped_data > 0,
                use_last_file_only=args.fast_validation > 0,
                data_type=data_type,
            )

        if smp.rank() == 0:
            if args.use_bert_data:
                print(
                    f"Reading data from training path {train_dataloader.dataset.input_file}"
                )
            else:
                print(
                    f"Reading data from training path {train_dataloader.dataset.input_paths}"
                )

        for batch_idx, input_data in enumerate(train_dataloader):
            # When resuming, skip the batches that were already consumed.
            if batch_idx < start_batch_index:
                if smp.rank() == 0:
                    print(
                        f"Resuming from saved batch index {start_batch_index}, skipping batch {batch_idx}..."
                    )
                if start_batch_index == len(train_dataloader):
                    # If saving at the last batch of the file, read from the next file
                    start_batch_index = 0
                    break
                continue
            else:
                # Skipping only applies to the first file read.
                start_batch_index = 0

            if args.use_bert_data:
                input_ids, _, attention_mask, _, _ = input_data
            else:
                input_ids, attention_mask = input_data

            if total_steps >= args.max_steps:
                break

            step_start = time.time()

            # The keyword for clearing gradients differs between smp
            # versions.
            if args.smp_version < 110:
                optimizer.zero_grad(set_grads_to_None=True)
            else:
                optimizer.zero_grad(set_to_none=True)

            if args.logits_output:
                train_output = train_step(model, optimizer, input_ids,
                                          attention_mask, args)
                loss_mb = train_output["loss"]
                logits_mb = train_output["logits"]
                # Micro-batch logits are concatenated along dim 1 under
                # tensor parallelism and along dim 0 otherwise.
                if smp.tp_size() > 1:
                    logits = torch.cat(tuple(logits_mb.outputs), dim=1)
                else:
                    logits = torch.cat(tuple(logits_mb.outputs), dim=0)
            else:
                # Return value, loss_mb is a StepOutput object
                loss_mb = train_step(model, optimizer, input_ids,
                                     attention_mask, args)

            # smdistributed: Average the loss across microbatches.
            loss = loss_mb.reduce_mean()
            if not args.validation_freq:
                # Without validation, the training loss is the reported
                # metric.
                loss_metric = loss.item()

            if args.enable_memory_profiling > 0:
                memory_status_cpu("After_train_step_cpu")
                memory_status(msg="After_train_step")

            if args.clean_cache > 0:
                # empty the cache to avoid OOM
                torch.cuda.empty_cache()

            if args.fp16:
                if args.smp_version < 110:
                    optimizer.update_master_grads()
                optimizer.clip_master_grads(args.grad_clip)

            optimizer.step()
            # Do not advance the schedule when the fp16 step overflowed.
            if not (args.fp16 and optimizer.overflow):
                lr_scheduler.step()

            if args.enable_memory_profiling > 0:
                memory_status(msg="After_opt_step")

            total_steps += 1
            time_elapsed = time.time() - start
            step_time = time.time() - step_start
            # Throughput is the local batch size times dp_size per step.
            sample_processed = input_ids.shape[0] * dp_size
            throughput = sample_processed / step_time
            if smp.rank() == 0 and not total_steps % args.logging_freq:
                print(
                    f"({int(time_elapsed)}s), Batch {total_steps - 1} Loss: {loss.item()}, Speed: {throughput} samples/sec"
                )

            # evaluate on validation
            if args.validation_freq and not (total_steps % args.validation_freq):
                # Captured so it can be restored after evaluation when
                # args.preserve_np_state is set.
                cur_state = np.random.get_state()
                model = model.eval()
                val_loss, val_ppl = eval_model(model, val_dataloader,
                                               args.validation_batches,
                                               args.use_bert_data)
                if is_main_process(smp.rank()):
                    print(
                        f"({int(time.time()-start)}s) Batch {total_steps - 1} Validation loss: {val_loss}"
                    )
                    print(
                        f"({int(time.time()-start)}s) Batch {total_steps - 1} Validation perplexity: {val_ppl}"
                    )
                loss_metric = val_loss
                if args.logits_output:
                    to_save["val_loss"].append(val_loss)
                model = model.train()
                if args.preserve_np_state > 0:
                    np.random.set_state(cur_state)

            # checkpoint
            if not (total_steps % args.checkpoint_freq):
                base_path = f"trained_gpt_nparams-{num_params}_steps-{total_steps}.pt"
                out_path = os.path.join(args.checkpoint_dir, base_path)
                total_ckpts = total_steps // args.checkpoint_freq

                delete_oldest_ckpt(args, delete_on_rank0_only=args.use_fsx > 0)
                # batch_idx + 1 is stored so a resumed run starts at the
                # batch after this one.
                save(
                    out_path,
                    model,
                    optimizer,
                    lr_scheduler,
                    model_config,
                    num_params,
                    total_steps,
                    curr_train_path_index,
                    args,
                    partial=True,
                    batch_idx=batch_idx + 1,
                )

            if args.logits_output:
                to_save["loss"].append(loss.item())

        if total_steps >= args.max_steps:
            # Final step reached: dump the recorded losses and last logits
            # on the recording ranks, then stop.
            if should_record() and args.logits_output:
                to_save["logits"] = logits.detach().cpu()
                output_file = f"rank_{smp.rank()}_" + args.logits_output
                torch.save(to_save, os.path.join(args.model_dir, output_file))
                print(
                    f"logits and loss saved at {os.path.join(args.model_dir, output_file)}"
                )
            break

        del train_dataloader

        if args.parallel_proc_data_processing:
            # Blocks until the background dataloader is ready.
            s = time.time()
            train_dataloader = dataset_future.result(timeout=None)
            wait_time = time.time() - s
            if wait_time > 1:
                # TODO if this happens, we should try num_workers>1 in dataloader
                print(
                    f"[{smp.rank()}] Waited {wait_time} for data loader to be ready. Please check if dataloader performance can be improved to avoid these waits."
                )
        else:
            train_dataloader = create_pretraining_dataloader(
                [train_paths[next_train_path_index]],
                args.train_batch_size,
                args.max_context_width,
                seed=args.seed,
                dp_rank=dp_rank,
                dp_size=dp_size,
                shuffle=args.same_seed < 1,
                zipped=args.zipped_data > 0,
                use_last_file_only=args.fast_validation > 0,
                data_type=data_type,
            )

    return total_steps, throughput, loss_metric
def main():
    """Run BERT pre-training (or evaluation) over the sharded input files.

    Parses the command line, seeds every RNG with ``args.seed + args.local_rank``,
    builds the model/optimizer via ``prepare_model_and_optimizer`` and then loops
    over epochs until ``args.steps_this_run`` optimizer steps were taken or the
    module-level ``timeout_sent`` flag is set.  While one data file is being
    trained on, the next one is loaded in a single-worker process pool.

    Returns:
        When ``args.do_train`` is set: the tuple
        ``(args, final_loss, train_time_raw, global_step)``.
        Otherwise: the mean of the collected test losses.
    """
    global timeout_sent

    args = parse_arguments()

    seed = args.seed + args.local_rank
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    worker_init = WorkerInitObj(seed)

    device, args = setup_training(args)
    if is_main_process():
        dllogger.log(step="PARAMETER", data={"Config": [str(args)]})

    # Prepare optimizer
    (model, optimizer, lr_scheduler, checkpoint, global_step,
     criterion) = prepare_model_and_optimizer(args, device)

    if is_main_process():
        dllogger.log(step="PARAMETER", data={"SEED": args.seed})

    raw_train_start = None
    if is_main_process():
        dllogger.log(step="PARAMETER", data={"train_start": True})
        dllogger.log(step="PARAMETER",
                     data={"batch_size_per_gpu": args.train_batch_size})
        dllogger.log(step="PARAMETER",
                     data={"learning_rate": args.learning_rate})

    most_recent_ckpts_paths = []
    average_loss = 0.0  # averaged loss every args.log_freq steps
    epoch = 0
    training_steps = 0
    test_losses = []

    pool = ProcessPoolExecutor(1)

    # Note: We loop infinitely over epochs, termination is handled via
    # iteration count
    while True:
        restored_data_loader = None
        if (not args.resume_from_checkpoint or epoch > 0
                or (args.phase2 and global_step < 1) or args.init_checkpoint):
            files = [
                os.path.join(args.input_dir, f)
                for f in os.listdir(args.input_dir)
                if os.path.isfile(os.path.join(args.input_dir, f))
                and 'training' in f
            ]
            files.sort()
            num_files = len(files)
            random.Random(args.seed + epoch).shuffle(files)
            f_start_id = 0
        else:
            f_start_id = checkpoint['files'][0]
            files = checkpoint['files'][1:]
            args.resume_from_checkpoint = False
            num_files = len(files)
            # may not exist in all checkpoints
            epoch = checkpoint.get('epoch', 0)
            # Fixed: this used to be assigned to a misspelled name
            # (``restored_dataloader``), so the restored loader was ignored.
            restored_data_loader = checkpoint.get('data_loader', None)

        shared_file_list = {}

        # Fixed: ``remainder`` is now always bound.  Previously it only existed
        # when torch.distributed was initialized, although the per-file loop
        # below reads it whenever ``get_world_size() > num_files``.
        remainder = (get_world_size() % num_files
                     if get_world_size() > num_files else 0)

        if torch.distributed.is_initialized() and get_world_size() > num_files:
            data_file = files[(f_start_id * get_world_size() + get_rank()
                               + remainder * f_start_id) % num_files]
        else:
            data_file = files[(f_start_id * get_world_size() + get_rank())
                              % num_files]

        if restored_data_loader is None:
            train_data = pretraining_dataset(data_file,
                                             args.max_predictions_per_seq)
            train_sampler = RandomSampler(train_data)
            train_dataloader = DataLoader(
                train_data,
                sampler=train_sampler,
                batch_size=args.train_batch_size * args.n_gpu,
                num_workers=4,
                worker_init_fn=worker_init,
                pin_memory=True)
        else:
            train_dataloader = restored_data_loader
            restored_data_loader = None

        overflow_buf = None
        if args.allreduce_post_accumulation:
            overflow_buf = torch.cuda.IntTensor([0])

        for f_id in range(f_start_id + 1, len(files)):
            if get_world_size() > num_files:
                data_file = files[(f_id * get_world_size() + get_rank()
                                   + remainder * f_id) % num_files]
            else:
                data_file = files[(f_id * get_world_size() + get_rank())
                                  % num_files]

            # Load the next file in the background while training on the
            # current loader.
            dataset_future = pool.submit(create_pretraining_dataset, data_file,
                                         args.max_predictions_per_seq,
                                         shared_file_list, args, worker_init)

            if is_main_process():
                train_iter = tqdm(train_dataloader, desc="Iteration",
                                  disable=args.disable_progress_bar)
            else:
                train_iter = train_dataloader

            if raw_train_start is None:
                raw_train_start = time.time()

            for step, batch in enumerate(train_iter):
                training_steps += 1
                batch = [t.to(device) for t in batch]
                (input_ids, segment_ids, input_mask, masked_lm_labels,
                 next_sentence_labels) = batch

                if not args.do_train:
                    # Evaluation only: accumulate test losses.
                    model.eval()
                    with torch.no_grad():
                        loss = test_step(args, device, input_ids, segment_ids,
                                         input_mask, masked_lm_labels,
                                         next_sentence_labels, model,
                                         criterion, step)
                        print(f"global_step {global_step} Test Loss:", loss)
                        test_losses.append(loss)
                    global_step += 1
                    if global_step >= args.steps_this_run:
                        return sum(test_losses) / len(test_losses)
                    continue

                model.train()
                if args.smp > 0:
                    loss_mbs = smp_step(args, device, input_ids, segment_ids,
                                        input_mask, masked_lm_labels,
                                        next_sentence_labels, model, optimizer,
                                        criterion, step)
                    loss = loss_mbs.reduce_mean()
                    if smp.rank() == 0:
                        print("Loss:", loss.item())
                else:
                    loss = train_step(args, device, input_ids, segment_ids,
                                      input_mask, masked_lm_labels,
                                      next_sentence_labels, model, optimizer,
                                      criterion, step)
                divisor = 1
                average_loss += loss.item()

                if training_steps % args.gradient_accumulation_steps == 0:
                    lr_scheduler.step()  # learning rate warmup
                    global_step = take_optimizer_step(args, optimizer, model,
                                                      overflow_buf, global_step)

                finished = global_step >= args.steps_this_run or timeout_sent

                if finished:
                    train_time_raw = time.time() - raw_train_start
                    last_num_steps = int(
                        training_steps / args.gradient_accumulation_steps
                    ) % args.log_freq
                    if last_num_steps == 0:
                        last_num_steps = args.log_freq
                    average_loss = torch.tensor(average_loss,
                                                dtype=torch.float32).cuda()
                    average_loss = average_loss / (last_num_steps * divisor)
                    if torch.distributed.is_initialized():
                        average_loss /= get_world_size()
                        torch.distributed.all_reduce(average_loss)
                    final_loss = loss.item()
                    if is_main_process():
                        dllogger.log(step=(epoch, global_step, ),
                                     data={"final_loss": final_loss})
                elif training_steps % (args.log_freq
                                       * args.gradient_accumulation_steps) == 0:
                    if is_main_process():
                        dllogger.log(
                            step=(epoch, global_step, ),
                            data={
                                "average_loss":
                                average_loss / (args.log_freq * divisor),
                                "step_loss":
                                loss.item() * args.gradient_accumulation_steps
                                / divisor,
                                "learning_rate":
                                optimizer.param_groups[0]['lr'],
                            })
                    average_loss = 0

                checkpoint_due = training_steps % (
                    args.num_steps_per_checkpoint
                    * args.gradient_accumulation_steps) == 0
                if not (finished or checkpoint_due):
                    continue

                if smp.dp_rank() == 0 and not args.skip_checkpoint:
                    # Save a trained model
                    dllogger.log(step="PARAMETER",
                                 data={"checkpoint_step": global_step})
                    if args.resume_step < 0 or not args.phase2:
                        output_save_file = os.path.join(
                            args.output_dir, "ckpt_{}.pt".format(global_step))
                    else:
                        output_save_file = os.path.join(
                            args.output_dir,
                            "ckpt_{}.pt".format(global_step
                                                + args.phase1_end_step))
                    save_dict = {
                        'model': model.local_state_dict(),
                        'optimizer': optimizer.local_state_dict(),
                        'files': [f_id] + files,
                        'epoch': epoch,
                        'data_loader':
                        None if global_step >= args.steps_this_run
                        else train_dataloader,
                    }
                    if args.fp16:
                        save_dict['master params'] = list(
                            amp.master_params(optimizer))
                    # SMP: Checkpoint mp_rank specific state
                    smp.save(save_dict, output_save_file, partial=True)

                    # Keep only the three most recent partial checkpoints.
                    most_recent_ckpts_paths.append(output_save_file)
                    if len(most_recent_ckpts_paths) > 3 and (
                            args.smp == 0 or smp.dp_rank() == 0):
                        ckpt_to_be_removed = most_recent_ckpts_paths.pop(0)
                        os.remove(ckpt_to_be_removed + f"_{smp.mp_rank()}")

                # Exiting the training due to hitting max steps, or being sent
                # a timeout from the cluster scheduler
                if finished:
                    if smp.dp_rank() == 0 and args.save_full:
                        output_save_file = os.path.join(
                            args.output_dir, "ckpt_{}.pt".format(global_step))
                        # NOTE(review): this stores local_state_dict() with
                        # partial=False; confirm whether the full state_dict()
                        # was intended for the single-file checkpoint.
                        save_dict = {
                            'model': model.local_state_dict(),
                            'optimizer': optimizer.local_state_dict(),
                            'files': [f_id] + files,
                            'epoch': epoch,
                            'data_loader':
                            None if global_step >= args.steps_this_run
                            else train_dataloader,
                        }
                        if args.fp16:
                            save_dict['master params'] = list(
                                amp.master_params(optimizer))
                        # SMP: Save a single checkpoint containing entire
                        # model parameters
                        smp.save(save_dict, output_save_file, partial=False)
                    # Fixed: the loader is released only after the checkpoint
                    # dict was built; deleting it first raised
                    # UnboundLocalError on a timeout exit with --save_full.
                    del train_dataloader
                    smp.barrier()
                    return args, final_loss, train_time_raw, global_step

            del train_dataloader
            # Make sure pool has finished and switch train_dataloader
            # NOTE: Will block until complete
            train_dataloader, data_file = dataset_future.result(timeout=None)

        epoch += 1
args.is_multigpus = args.num_gpus > 1 args.multigpus_distributed = (args.is_distributed or args.is_multigpus) logger.debug(f"args.image_folder : {args.image_folder}") args.world_size = 1 args.local_rank = 0 args.rank = 0 if args.model_parallel: args.world_size = smp.size() args.local_rank = smp.local_rank() # rank per host args.rank = smp.rank() args.dp_size = smp.dp_size() args.dp_rank = smp.dp_rank() logger.debug(f"args.world_size : {args.world_size}, args.local_rank : {args.local_rank}, args.rank : {args.rank}, \ args.dp_size : {args.dp_size}, args.dp_rank : {args.dp_rank}") else: # initialize deepspeed print(f"args.deepspeed : {args.deepspeed}") deepspeed_utils.init_deepspeed(args.deepspeed) # args.LEARNING_RATE = args.LEARNING_RATE * float(args.world_size) ## SageMaker try: if os.environ.get('SM_CHANNEL_TRAINING') is not None: args.model_dir = os.environ.get('SM_MODEL_DIR') args.output_dir = os.environ.get('SM_OUTPUT_DATA_DIR') args.image_folder = os.environ.get('SM_CHANNEL_TRAINING')
def main():
    """Train and evaluate the grouped MNIST network with SageMaker model parallelism.

    Parses the CLI, initializes ``smp``, prepares the MNIST loaders (sharded per
    data-parallel rank when DDP or Horovod is enabled), optionally resumes from
    a partial or full checkpoint, trains for ``args.epochs`` epochs, optionally
    saves partial/full checkpoints and optionally asserts on the final test loss.

    Raises:
        ValueError: if CUDA is not available.
    """
    args = get_parser().parse_args()

    if not torch.cuda.is_available():
        raise ValueError(
            "The script requires CUDA support, but CUDA not available")

    use_ddp = args.ddp > 0
    use_horovod = args.horovod > 0

    # Fix seeds in order to get the same losses across runs
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed,
                    torch.cuda.manual_seed):
        seed_fn(args.seed)

    smp.init({
        "microbatches": args.num_microbatches,
        "placement_strategy": "spread",
        "pipeline": args.pipeline,
        "optimize": "speed",
        "partitions": args.num_partitions,
        "horovod": use_horovod,
        "ddp": use_ddp,
    })

    # SM Distributed: Set the device to the GPU ID used by the current process.
    # Input tensors should be transferred to this device.
    torch.cuda.set_device(smp.local_rank())
    device = torch.device("cuda")

    loader_kwargs = {
        "batch_size": args.batch_size,
        "num_workers": 1,
        "pin_memory": True,
        "shuffle": False,
    }

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307, ), (0.3081, )),
    ])

    # SM Distributed: Download only on a single process per instance.
    # When this is not present, the file is corrupted by multiple processes
    # trying to download and extract at the same time
    if smp.local_rank() == 0:
        datasets.MNIST("../data", train=True, download=True,
                       transform=transform)
    smp.barrier()
    train_set = datasets.MNIST("../data", train=True, download=False,
                               transform=transform)

    if (use_ddp or use_horovod) and smp.dp_size() > 1:
        # Give every data-parallel rank an equally sized slice of the data.
        dp_size = smp.dp_size()
        shares = {f"{idx}": 1 / dp_size for idx in range(dp_size)}
        train_set = SplitDataset(train_set, partitions=shares)
        train_set.select(f"{smp.dp_rank()}")

    # Download and create dataloaders for train and test dataset
    test_set = datasets.MNIST("../data", train=False, transform=transform)

    train_loader = torch.utils.data.DataLoader(train_set, **loader_kwargs)
    test_loader = torch.utils.data.DataLoader(test_set, **loader_kwargs)

    # SMP handles the transfer of parameters to the right device
    # and the user doesn't need to call 'model.to' explicitly.
    net = GroupedNet()
    base_optimizer = optim.Adadelta(net.parameters(), lr=args.lr)

    # SM Distributed: Use the DistributedModel container to provide the model
    # to be partitioned across different ranks. For the rest of the script,
    # the returned DistributedModel object should be used in place of
    # the model provided for DistributedModel class instantiation.
    model = smp.DistributedModel(net)
    scaler = smp.amp.GradScaler()
    optimizer = smp.DistributedOptimizer(base_optimizer)

    # A partial checkpoint takes precedence over a full one.
    if args.partial_checkpoint:
        resume_from, resume_partial = args.partial_checkpoint, True
    elif args.full_checkpoint:
        resume_from, resume_partial = args.full_checkpoint, False
    else:
        resume_from, resume_partial = None, None

    if resume_from is not None:
        state = smp.load(resume_from, partial=resume_partial)
        model.load_state_dict(state["model_state_dict"])
        optimizer.load_state_dict(state["optimizer_state_dict"])

    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    for epoch in range(1, args.epochs + 1):
        train(args, model, scaler, device, train_loader, optimizer, epoch)
        test_loss = test(args, model, device, test_loader)
        scheduler.step()

    checkpoint_path = "./pt_mnist_checkpoint.pt"

    if args.save_partial_model and smp.dp_rank() == 0:
        partial_state = {
            "model_state_dict": model.local_state_dict(),
            "optimizer_state_dict": optimizer.local_state_dict(),
        }
        smp.save(partial_state, checkpoint_path, partial=True)

    if args.save_full_model and smp.dp_rank() == 0:
        full_state = {
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
        }
        smp.save(full_state, checkpoint_path, partial=False)

    # Waiting the save checkpoint to be finished before run another
    # allgather_object
    smp.barrier()

    if not args.assert_losses:
        return

    if use_horovod or use_ddp:
        # SM Distributed: If using data parallelism, gather all losses across
        # different model replicas and check if losses match.
        gathered = smp.allgather(test_loss, smp.DP_GROUP)
        reference = gathered[0]
        for replica_loss in gathered:
            assert math.isclose(replica_loss, reference)
        assert test_loss < 0.18
    else:
        assert test_loss < 0.08
def main():
    """Train a DALL-E discrete VAE with SageMaker model parallelism or DeepSpeed.

    With ``args.model_parallel`` the model and optimizer are wrapped by ``smp``
    and trained through ``smp.step`` functions; otherwise training is handed to
    ``deepspeed_utils``.  Every 10 iterations sample reconstructions are built,
    a checkpoint is written to ``{args.model_dir}/vae.pt``, the temperature is
    annealed and the LR scheduler is stepped.

    Raises:
        ValueError: if CUDA is not available.
        AssertionError: if the image folder is empty, or if ``args.assert_losses``
            is set and the final loss check fails.
    """
    parser = get_parser()
    args = parser.parse_args()

    if not torch.cuda.is_available():
        raise ValueError(
            "The script requires CUDA support, but CUDA not available")

    args.rank = -1
    args.world_size = 1

    if args.model_parallel:
        args.deepspeed = False
        cfg = {
            "microbatches": args.num_microbatches,
            "placement_strategy": args.placement_strategy,
            "pipeline": args.pipeline,
            "optimize": args.optimize,
            "partitions": args.num_partitions,
            "horovod": args.horovod,
            "ddp": args.ddp,
        }
        smp.init(cfg)
        torch.cuda.set_device(smp.local_rank())
        args.rank = smp.dp_rank()
        args.world_size = smp.size()
    else:
        # initialize deepspeed
        print(f"args.deepspeed : {args.deepspeed}")
        deepspeed_utils.init_deepspeed(args.deepspeed)
        if deepspeed_utils.is_root_worker():
            args.rank = 0

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed + args.rank)
        np.random.seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)

    # NOTE(review): assumed to run regardless of args.seed (the original
    # nesting is ambiguous) - confirm.
    cudnn.deterministic = True
    if cudnn.deterministic:
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    args.kwargs = {'num_workers': args.num_worker, 'pin_memory': True}

    device = torch.device("cuda")

    logger.debug(f"args.image_folder : {args.image_folder}")
    logger.debug(f"args.rank : {args.rank}")

    # SageMaker: take model/data locations from the environment when present.
    # (The former bare ``except:`` around this was dropped: ``os.environ.get``
    # does not raise.)
    if os.environ.get('SM_MODEL_DIR') is not None:
        args.model_dir = os.environ.get('SM_MODEL_DIR')
        args.image_folder = os.environ.get('SM_CHANNEL_TRAINING')

    IMAGE_SIZE = args.image_size
    IMAGE_PATH = args.image_folder

    EPOCHS = args.EPOCHS
    BATCH_SIZE = args.BATCH_SIZE
    LEARNING_RATE = args.LEARNING_RATE
    LR_DECAY_RATE = args.LR_DECAY_RATE

    NUM_TOKENS = args.NUM_TOKENS
    NUM_LAYERS = args.NUM_LAYERS
    NUM_RESNET_BLOCKS = args.NUM_RESNET_BLOCKS
    SMOOTH_L1_LOSS = args.SMOOTH_L1_LOSS
    EMB_DIM = args.EMB_DIM
    HID_DIM = args.HID_DIM
    KL_LOSS_WEIGHT = args.KL_LOSS_WEIGHT

    STARTING_TEMP = args.STARTING_TEMP
    TEMP_MIN = args.TEMP_MIN
    ANNEAL_RATE = args.ANNEAL_RATE

    NUM_IMAGES_SAVE = args.NUM_IMAGES_SAVE

    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize(IMAGE_SIZE),
        T.CenterCrop(IMAGE_SIZE),
        T.ToTensor()
    ])

    # data
    logger.debug(f"IMAGE_PATH : {IMAGE_PATH}")
    ds = ImageFolder(
        IMAGE_PATH,
        transform=transform,
    )

    if args.model_parallel and (args.ddp or args.horovod) and smp.dp_size() > 1:
        # Give each data-parallel rank an equal share of the dataset.
        partitions_dict = {
            f"{i}": 1 / smp.dp_size()
            for i in range(smp.dp_size())
        }
        ds = SplitDataset(ds, partitions=partitions_dict)
        ds.select(f"{smp.dp_rank()}")

    dl = DataLoader(ds,
                    BATCH_SIZE,
                    shuffle=True,
                    drop_last=args.model_parallel,
                    **args.kwargs)

    vae_params = dict(image_size=IMAGE_SIZE,
                      num_layers=NUM_LAYERS,
                      num_tokens=NUM_TOKENS,
                      codebook_dim=EMB_DIM,
                      hidden_dim=HID_DIM,
                      num_resnet_blocks=NUM_RESNET_BLOCKS)

    vae = DiscreteVAE(**vae_params,
                      smooth_l1_loss=SMOOTH_L1_LOSS,
                      kl_div_loss_weight=KL_LOSS_WEIGHT).to(device)

    # optimizer
    opt = Adam(vae.parameters(), lr=LEARNING_RATE)
    sched = ExponentialLR(optimizer=opt, gamma=LR_DECAY_RATE)

    if args.model_parallel:
        import copy
        # Plain (non-distributed) copies used to build hard reconstructions.
        dummy_codebook = copy.deepcopy(vae.codebook)
        dummy_decoder = copy.deepcopy(vae.decoder)

        vae = smp.DistributedModel(vae)
        scaler = smp.amp.GradScaler()
        opt = smp.DistributedOptimizer(opt)

        if args.partial_checkpoint:
            args.checkpoint = smp.load(args.partial_checkpoint, partial=True)
            vae.load_state_dict(args.checkpoint["model_state_dict"])
            opt.load_state_dict(args.checkpoint["optimizer_state_dict"])
        elif args.full_checkpoint:
            args.checkpoint = smp.load(args.full_checkpoint, partial=False)
            vae.load_state_dict(args.checkpoint["model_state_dict"])
            opt.load_state_dict(args.checkpoint["optimizer_state_dict"])

    assert len(ds) > 0, 'folder does not contain any images'

    if (not args.model_parallel) and args.rank == 0:
        print(f'{len(ds)} images found for training')

    def save_model(path):
        """Write hyper-parameters and weights to ``path`` on rank 0 only."""
        if args.rank != 0:
            return

        save_obj = {'hparams': vae_params, 'weights': vae.state_dict()}
        torch.save(save_obj, path)

    # distribute with deepspeed
    if not args.model_parallel:
        deepspeed_utils.check_batch_size(BATCH_SIZE)
        deepspeed_config = {'train_batch_size': BATCH_SIZE}

        (distr_vae, opt, dl, sched) = deepspeed_utils.maybe_distribute(
            args=args,
            model=vae,
            optimizer=opt,
            model_parameters=vae.parameters(),
            training_data=ds if args.deepspeed else dl,
            lr_scheduler=sched,
            config_params=deepspeed_config,
        )

    # The smp step functions are only used on the model-parallel path, so they
    # are defined only there (this replaces a bare ``try/except: pass``).
    if args.model_parallel:

        # Rubik: Define smp.step. Return any tensors needed outside.
        @smp.step
        def train_step(vae, images, temp):
            with autocast(enabled=(args.amp > 0)):
                loss, recons = vae(images,
                                   return_loss=True,
                                   return_recons=True,
                                   temp=temp)

            scaled_loss = scaler.scale(loss) if args.amp else loss
            vae.backward(scaled_loss)
            return loss, recons

        @smp.step
        def get_codes_step(vae, images, k):
            images = images[:k]
            logits = vae.forward(images, return_logits=True)
            codebook_indices = logits.argmax(dim=1).flatten(1)
            return codebook_indices

        def hard_recons_step(dummy_decoder, dummy_codebook, codebook_indices):
            from functools import partial

            # Rebind the original ``forward`` obtained from smp's patch manager
            # on every submodule of the plain copies.
            for module in dummy_codebook.modules():
                method = smp_state.patch_manager.get_original_method(
                    "forward", type(module))
                module.forward = partial(method, module)
            image_embeds = dummy_codebook.forward(codebook_indices)
            b, n, d = image_embeds.shape
            h = w = int(sqrt(n))

            image_embeds = rearrange(image_embeds,
                                     'b (h w) d -> b d h w',
                                     h=h,
                                     w=w)
            for module in dummy_decoder.modules():
                method = smp_state.patch_manager.get_original_method(
                    "forward", type(module))
                module.forward = partial(method, module)
            hard_recons = dummy_decoder.forward(image_embeds)
            return hard_recons

    # starting temperature
    global_step = 0
    temp = STARTING_TEMP

    for epoch in range(EPOCHS):
        batch_time = util.AverageMeter('Time', ':6.3f')
        data_time = util.AverageMeter('Data', ':6.3f')
        losses = util.AverageMeter('Loss', ':.4e')
        top1 = util.AverageMeter('Acc@1', ':6.2f')
        top5 = util.AverageMeter('Acc@5', ':6.2f')
        progress = util.ProgressMeter(
            len(dl), [batch_time, data_time, losses, top1, top5],
            prefix="Epoch: [{}]".format(epoch))

        vae.train()
        start = time.time()

        for i, (images, _) in enumerate(dl):
            images = images.to(device, non_blocking=True)
            opt.zero_grad()

            if args.model_parallel:
                loss, recons = train_step(vae, images, temp)
                # Rubik: Average the loss across microbatches.
                loss = loss.reduce_mean()
                recons = recons.reduce_mean()
            else:
                loss, recons = distr_vae(images,
                                         return_loss=True,
                                         return_recons=True,
                                         temp=temp)

            if (not args.model_parallel) and args.deepspeed:
                # Gradients are automatically zeroed after the step
                distr_vae.backward(loss)
                distr_vae.step()
            elif args.model_parallel:
                if args.amp:
                    scaler.step(opt)
                    scaler.update()
                else:
                    # some optimizers like adadelta from PT 1.8 dont like it
                    # when optimizer.step is called with no param
                    if len(list(vae.local_parameters())) > 0:
                        opt.step()
            else:
                loss.backward()
                opt.step()

            # Collected for an experiment tracker; wandb logging is currently
            # disabled in this script.
            logs = {}

            if i % 10 == 0:
                if args.rank == 0:
                    k = NUM_IMAGES_SAVE

                    with torch.no_grad():
                        if args.model_parallel:
                            # Copy decoder/codebook weights into the plain
                            # modules, stripping the submodule prefix.
                            model_dict = vae.state_dict()
                            model_dict_updated = {}
                            for key, val in model_dict.items():
                                if "decoder" in key:
                                    key = key.replace("decoder.", "")
                                elif "codebook" in key:
                                    key = key.replace("codebook.", "")
                                model_dict_updated[key] = val

                            dummy_decoder.load_state_dict(model_dict_updated,
                                                          strict=False)
                            dummy_codebook.load_state_dict(model_dict_updated,
                                                           strict=False)
                            codes = get_codes_step(vae, images, k)
                            codes = codes.reduce_mean().to(torch.long)
                            hard_recons = hard_recons_step(
                                dummy_decoder, dummy_codebook, codes)
                        else:
                            codes = vae.get_codebook_indices(images[:k])
                            hard_recons = vae.decode(codes)

                    images, recons = map(lambda t: t[:k], (images, recons))
                    images, recons, hard_recons, codes = map(
                        lambda t: t.detach().cpu(),
                        (images, recons, hard_recons, codes))
                    images, recons, hard_recons = map(
                        lambda t: make_grid(t.float(),
                                            nrow=int(sqrt(k)),
                                            normalize=True,
                                            range=(-1, 1)),
                        (images, recons, hard_recons))

                # NOTE(review): assumed to run on every rank each 10th
                # iteration (smp.barrier() below is a collective) - confirm.
                if args.model_parallel:
                    filename = f'{args.model_dir}/vae.pt'
                    # Fixed: ``smp.dp_rank`` was compared without being called,
                    # which is always False, so no checkpoint was ever saved.
                    if smp.dp_rank() == 0:
                        if args.save_full_model:
                            model_dict = vae.state_dict()
                            opt_dict = opt.state_dict()
                            smp.save(
                                {
                                    "model_state_dict": model_dict,
                                    "optimizer_state_dict": opt_dict
                                },
                                filename,
                                partial=False,
                            )
                        else:
                            model_dict = vae.local_state_dict()
                            opt_dict = opt.local_state_dict()
                            smp.save(
                                {
                                    "model_state_dict": model_dict,
                                    "optimizer_state_dict": opt_dict
                                },
                                filename,
                                partial=True,
                            )
                    smp.barrier()
                else:
                    save_model(f'{args.model_dir}/vae.pt')

                # temperature anneal
                temp = max(temp * math.exp(-ANNEAL_RATE * global_step),
                           TEMP_MIN)

                # lr decay
                sched.step()

            # Collective loss, averaged
            if args.model_parallel:
                avg_loss = loss.detach().clone()
                avg_loss /= args.world_size
            else:
                avg_loss = deepspeed_utils.average_all(loss)

            if args.rank == 0:
                if i % 100 == 0:
                    lr = sched.get_last_lr()[0]
                    print(epoch, i,
                          f'lr - {lr:6f}, loss - {avg_loss.item()},')

                    logs = {
                        **logs, 'epoch': epoch,
                        'iter': i,
                        'loss': avg_loss.item(),
                        'lr': lr
                    }

            global_step += 1

            # NOTE(review): assumed to run every iteration on rank 0; the
            # original comment mentions "every print_freq iterations" but no
            # such condition is present - confirm.
            if args.rank == 0:
                # to_python_float incurs a host<->device sync
                losses.update(util.to_python_float(loss), images.size(0))

                # Waiting until finishing operations on GPU (Pytorch default:
                # async)
                torch.cuda.synchronize()
                batch_time.update((time.time() - start) / args.log_interval)

                print(
                    'Epoch: [{0}][{1}/{2}] '
                    'Train_Time={batch_time.val:.3f}: avg-{batch_time.avg:.3f}, '
                    'Train_Speed={3:.3f} ({4:.3f}), '
                    'Train_Loss={loss.val:.10f}:({loss.avg:.4f}),'.format(
                        epoch,
                        i,
                        len(dl),
                        args.world_size * BATCH_SIZE / batch_time.val,
                        args.world_size * BATCH_SIZE / batch_time.avg,
                        batch_time=batch_time,
                        loss=losses))

    if args.rank == 0:
        # save final vae and cleanup
        if args.model_parallel:
            logger.debug('save model_parallel')
        else:
            save_model(os.path.join(args.model_dir, 'vae-final.pt'))

    if args.model_parallel:
        if args.assert_losses:
            if args.horovod or args.ddp:
                # SM Distributed: If using data parallelism, gather all losses
                # across different model replicas and check if losses match.
                losses = smp.allgather(loss, smp.DP_GROUP)
                for l in losses:
                    print(l)
                    assert math.isclose(l, losses[0])
                assert loss < 0.18
            else:
                assert loss < 0.08

        smp.barrier()
        print("SMP training finished successfully")