def train_smp(
    self,
    model,
    optimizer,
    lr_scheduler,
    start_train_path_index,
    start_batch_index,
    num_params,
    total_steps,
    args,
    prescaled_batch,
):
    """Run one pass over ``self.train_dataset`` under SageMaker model parallelism.

    Batches are sharded across data-parallel ranks with a ``DistributedSampler``
    (the reduced-data-parallel rank/size is used when ``prescaled_batch`` is
    set). Each step calls ``self.train_step`` and then advances ``lr_scheduler``.

    NOTE: ``start_train_path_index``, ``start_batch_index`` and ``num_params``
    are accepted for signature compatibility but are not used, and the incoming
    ``total_steps`` is ignored - the step counter restarts at 0.

    Args:
        model: model passed through to ``self.train_step``; put in train mode.
        optimizer: optimizer passed through to ``self.train_step``.
        lr_scheduler: scheduler stepped once per training step.
        args: must provide ``seed``, ``per_device_train_batch_size`` and
            ``max_steps`` (``None`` means "run the whole epoch").
        prescaled_batch: shard by ``smp.rdp_rank()``/``smp.rdp_size()`` instead
            of ``smp.dp_rank()``/``smp.dp_size()``.

    Returns:
        ``(total_steps, throughput, loss_metric)``: number of steps taken,
        samples/sec of the last step (``None`` if no step ran) and the last
        training loss as a Python number (``0`` if no step ran).
    """
    model.train()
    dp_rank = smp.dp_rank() if not prescaled_batch else smp.rdp_rank()
    dp_size = smp.dp_size() if not prescaled_batch else smp.rdp_size()

    start = time.time()
    throughput = None
    loss_metric = 0
    last_loss = None

    # Set the same seed for computation
    set_seed(args.seed)

    sampler = torch.utils.data.DistributedSampler(
        self.train_dataset,
        shuffle=True,
        seed=args.seed,
        rank=dp_rank,
        num_replicas=dp_size,
        drop_last=True,
    )
    train_dataloader = torch.utils.data.DataLoader(
        self.train_dataset,
        sampler=sampler,
        batch_size=args.per_device_train_batch_size,
        collate_fn=self.data_collator,
        num_workers=0,
        pin_memory=True,
        drop_last=True,
    )

    # The counter deliberately restarts at 0 (the parameter value is ignored).
    total_steps = 0
    for input_data in train_dataloader:
        step_start = time.time()
        optimizer.zero_grad(set_to_none=True)
        input_ids = input_data["input_ids"]
        attention_mask = input_data["attention_mask"]
        # NOTE(review): optimizer.step() is not called in this loop; presumably
        # self.train_step performs it - confirm.
        loss_mb = self.train_step(model, optimizer, input_ids, attention_mask, args)
        # smdistributed: average the loss across microbatches.
        loss = loss_mb.reduce_mean()
        # Keep the tensor; convert once at the end to avoid a device sync per step.
        last_loss = loss
        lr_scheduler.step()
        # Count each step exactly once. Incrementing twice made an odd
        # ``max_steps`` unreachable for an equality-based stop test.
        total_steps += 1

        time_elapsed = time.time() - start
        step_time = time.time() - step_start
        if step_time > 0:
            throughput = input_ids.shape[0] * dp_size / step_time
        if smp.rank() == 0 and not total_steps % 10:
            print(
                f"({int(time_elapsed)}s), Batch {total_steps - 1} Loss: {loss.item()}, Speed: {throughput} samples/sec"
            )
        if args.max_steps is not None and total_steps >= args.max_steps:
            break

    if last_loss is not None:
        loss_metric = last_loss.item()
    return total_steps, throughput, loss_metric
def split_dataset(args):
    """Shard ``args.ds`` across data-parallel ranks when DDP or Horovod is on.

    With more than one data-parallel rank, ``args.ds`` is replaced by a
    ``SplitDataset`` whose equal partitions are keyed ``"0"`` ... ``"<dp_size-1>"``
    and the partition belonging to this rank is selected. Otherwise ``args.ds``
    is returned untouched (and ``smp`` is not queried unless DDP/Horovod is on).

    Returns:
        ``args.ds`` after any sharding.
    """
    if not (args.ddp or args.horovod):
        return args.ds
    n_replicas = smp.dp_size()
    if n_replicas <= 1:
        return args.ds
    share = 1 / n_replicas
    partitions = dict.fromkeys((str(rank) for rank in range(n_replicas)), share)
    args.ds = SplitDataset(args.ds, partitions=partitions)
    args.ds.select(str(smp.dp_rank()))
    return args.ds
def train(
    model,
    optimizer,
    lr_scheduler,
    model_config,
    start_train_path_index,
    start_batch_index,
    num_params,
    total_steps,
    args,
):
    """Run the GPT/BERT pre-training loop over the files in ``args.training_dir``.

    The loop walks ``args.epochs`` passes over the sorted ``train_paths``, one
    file at a time, starting at ``start_train_path_index``. When resuming, the
    first ``start_batch_index`` batches of the first file are skipped. Each step
    runs ``train_step`` and ``optimizer.step()``, then steps ``lr_scheduler``
    unless an fp16 overflow was reported. Validation runs every
    ``args.validation_freq`` steps (if set) and a partial checkpoint is saved
    every ``args.checkpoint_freq`` steps. With
    ``args.parallel_proc_data_processing`` the next file's dataloader is built
    in a single background process while the current one is consumed.

    Returns:
        ``(total_steps, throughput, loss_metric)``: the updated step counter,
        samples/sec of the last step (``None`` if no step ran), and the last
        training loss (when validation is disabled) or the last validation
        loss (when it is enabled); ``0`` if neither was computed.
    """
    if args.enable_memory_profiling > 0:
        memory_status_cpu(msg="before train step")
    model.train()
    if args.parallel_proc_data_processing:
        # Background worker used to prefetch the next file's dataloader.
        # NOTE(review): this pool is never shut down in this function.
        pool = ProcessPoolExecutor(1)
    # With a prescaled batch, shard by the reduced-data-parallel group instead.
    dp_rank = smp.dp_rank() if not args.prescaled_batch else smp.rdp_rank()
    dp_size = smp.dp_size() if not args.prescaled_batch else smp.rdp_size()
    data_type = "BERT" if args.use_bert_data else "GPT"
    if args.use_bert_data:
        # BERT data: every regular file whose name contains "training".
        train_paths = sorted([
            os.path.join(args.training_dir, p)
            for p in os.listdir(args.training_dir)
            if os.path.isfile(os.path.join(args.training_dir, p)) and "training" in p
        ])
    else:
        # GPT data: JSON files, gzipped when args.zipped_data > 0.
        if args.zipped_data > 0:
            file_extension = ".json.gz"
        else:
            file_extension = ".json"
        train_paths = sorted([
            os.path.join(args.training_dir, p)
            for p in os.listdir(args.training_dir)
            if p.endswith(file_extension)
        ])

    # Dataloader for the file training resumes from.
    train_dataloader = create_pretraining_dataloader(
        [train_paths[start_train_path_index]],
        args.train_batch_size,
        args.max_context_width,
        seed=args.seed,
        dp_rank=dp_rank,
        dp_size=dp_size,
        shuffle=args.same_seed < 1,
        zipped=args.zipped_data > 0,
        use_last_file_only=args.fast_validation > 0,
        data_type=data_type,
    )

    if args.validation_freq is not None:
        # load all validation examples
        if smp.rank() == 0:
            print("Creating val dataloader")
        if args.use_bert_data:
            val_paths = sorted([
                os.path.join(args.test_dir, p)
                for p in os.listdir(args.test_dir)
                if os.path.isfile(os.path.join(args.test_dir, p)) and "testing" in p
            ])
        else:
            if args.zipped_data > 0:
                file_extension = ".json.gz"
            else:
                file_extension = ".json"
            val_paths = sorted([
                os.path.join(args.test_dir, p)
                for p in os.listdir(args.test_dir)
                if p.endswith(file_extension)
            ])
        val_dataloader = create_pretraining_dataloader(
            val_paths,
            args.val_batch_size,
            args.max_context_width,
            seed=args.seed,
            dp_rank=dp_rank,
            dp_size=dp_size,
            shuffle=True,
            zipped=args.zipped_data > 0,
            use_last_file_only=args.fast_validation > 0,
            data_type=data_type,
        )
        if smp.rank() == 0:
            print("Created val dataloader")

    start = time.time()
    throughput = None
    # Losses (and, at the end, logits) written to disk when args.logits_output is set.
    to_save = {"loss": [], "val_loss": []}
    loss_metric = 0

    def should_record():
        # only record the ranks that in the tp group that contains global rank 0
        if smp.tp_size() > 1:
            tp_group = smp.get_tp_group()
            return 0 in tp_group
        else:
            return smp.rank() == 0

    # Set the same seed for computation
    set_seed(args.seed)

    # One iteration per training file; `index` spans all epochs.
    for index in range(start_train_path_index, args.epochs * len(train_paths)):
        next_train_path_index = (index + 1) % len(train_paths)
        curr_train_path_index = index % len(train_paths)

        if total_steps >= args.max_steps:
            break

        if args.parallel_proc_data_processing:
            # Start building the next file's dataloader while this one trains.
            dataset_future = pool.submit(
                create_pretraining_dataloader,
                [train_paths[next_train_path_index]],
                args.train_batch_size,
                args.max_context_width,
                seed=args.seed,
                dp_rank=dp_rank,
                dp_size=dp_size,
                shuffle=args.same_seed < 1,
                zipped=args.zipped_data > 0,
                use_last_file_only=args.fast_validation > 0,
                data_type=data_type,
            )

        if smp.rank() == 0:
            if args.use_bert_data:
                print(
                    f"Reading data from training path {train_dataloader.dataset.input_file}"
                )
            else:
                print(
                    f"Reading data from training path {train_dataloader.dataset.input_paths}"
                )

        for batch_idx, input_data in enumerate(train_dataloader):
            # Resume support: skip batches already consumed before the checkpoint.
            if batch_idx < start_batch_index:
                if smp.rank() == 0:
                    print(
                        f"Resuming from saved batch index {start_batch_index}, skipping batch {batch_idx}..."
                    )
                if start_batch_index == len(train_dataloader):
                    # If saving at the last batch of the file, read from the next file
                    start_batch_index = 0
                    break
                continue
            else:
                # Only the first resumed file skips batches.
                start_batch_index = 0

            if args.use_bert_data:
                input_ids, _, attention_mask, _, _ = input_data
            else:
                input_ids, attention_mask = input_data

            if total_steps >= args.max_steps:
                break

            step_start = time.time()

            # smp versions below 110 spell the keyword differently.
            if args.smp_version < 110:
                optimizer.zero_grad(set_grads_to_None=True)
            else:
                optimizer.zero_grad(set_to_none=True)

            if args.logits_output:
                train_output = train_step(model, optimizer, input_ids, attention_mask, args)
                loss_mb = train_output["loss"]
                logits_mb = train_output["logits"]
                # Concatenate per-microbatch logits: dim=1 under tensor parallelism, dim=0 otherwise.
                if smp.tp_size() > 1:
                    logits = torch.cat(tuple(logits_mb.outputs), dim=1)
                else:
                    logits = torch.cat(tuple(logits_mb.outputs), dim=0)
            else:
                # Return value, loss_mb is a StepOutput object
                loss_mb = train_step(model, optimizer, input_ids, attention_mask, args)

            # smdistributed: Average the loss across microbatches.
            loss = loss_mb.reduce_mean()
            if not args.validation_freq:
                loss_metric = loss.item()

            if args.enable_memory_profiling > 0:
                memory_status_cpu("After_train_step_cpu")
                memory_status(msg="After_train_step")

            if args.clean_cache > 0:
                # empty the cache to avoid OOM
                torch.cuda.empty_cache()

            if args.fp16:
                if args.smp_version < 110:
                    optimizer.update_master_grads()
                optimizer.clip_master_grads(args.grad_clip)

            optimizer.step()
            # Skip the LR step when the fp16 optimizer reports an overflow.
            if not (args.fp16 and optimizer.overflow):
                lr_scheduler.step()

            if args.enable_memory_profiling > 0:
                memory_status(msg="After_opt_step")

            total_steps += 1
            time_elapsed = time.time() - start
            step_time = time.time() - step_start
            sample_processed = input_ids.shape[0] * dp_size
            throughput = sample_processed / step_time
            if smp.rank() == 0 and not total_steps % args.logging_freq:
                print(
                    f"({int(time_elapsed)}s), Batch {total_steps - 1} Loss: {loss.item()}, Speed: {throughput} samples/sec"
                )

            # evaluate on validation
            if args.validation_freq and not (total_steps % args.validation_freq):
                # Save the numpy RNG state so validation can optionally leave it untouched.
                cur_state = np.random.get_state()
                model = model.eval()
                val_loss, val_ppl = eval_model(model, val_dataloader, args.validation_batches, args.use_bert_data)
                if is_main_process(smp.rank()):
                    print(
                        f"({int(time.time()-start)}s) Batch {total_steps - 1} Validation loss: {val_loss}"
                    )
                    print(
                        f"({int(time.time()-start)}s) Batch {total_steps - 1} Validation perplexity: {val_ppl}"
                    )
                loss_metric = val_loss
                if args.logits_output:
                    to_save["val_loss"].append(val_loss)
                model = model.train()
                if args.preserve_np_state > 0:
                    np.random.set_state(cur_state)

            # checkpoint
            if not (total_steps % args.checkpoint_freq):
                base_path = f"trained_gpt_nparams-{num_params}_steps-{total_steps}.pt"
                out_path = os.path.join(args.checkpoint_dir, base_path)
                # NOTE(review): total_ckpts is computed but never used.
                total_ckpts = total_steps // args.checkpoint_freq

                delete_oldest_ckpt(args, delete_on_rank0_only=args.use_fsx > 0)

                # batch_idx + 1 is stored so a resume skips the batch just trained.
                save(
                    out_path,
                    model,
                    optimizer,
                    lr_scheduler,
                    model_config,
                    num_params,
                    total_steps,
                    curr_train_path_index,
                    args,
                    partial=True,
                    batch_idx=batch_idx + 1,
                )

            if args.logits_output:
                to_save["loss"].append(loss.item())

            if total_steps >= args.max_steps:
                # `logits` exists here only when args.logits_output is set, which
                # the condition below also requires.
                if should_record() and args.logits_output:
                    to_save["logits"] = logits.detach().cpu()
                    output_file = f"rank_{smp.rank()}_" + args.logits_output
                    torch.save(to_save, os.path.join(args.model_dir, output_file))
                    print(
                        f"logits and loss saved at {os.path.join(args.model_dir, output_file)}"
                    )
                break

        # NOTE(review): after the inner loop stops on max_steps, the next file's
        # dataloader is still awaited/built below before the outer check exits.
        del train_dataloader

        if args.parallel_proc_data_processing:
            s = time.time()
            train_dataloader = dataset_future.result(timeout=None)
            wait_time = time.time() - s
            if wait_time > 1:
                # TODO if this happens, we should try num_workers>1 in dataloader
                print(
                    f"[{smp.rank()}] Waited {wait_time} for data loader to be ready. Please check if dataloader performance can be improved to avoid these waits."
                )
        else:
            train_dataloader = create_pretraining_dataloader(
                [train_paths[next_train_path_index]],
                args.train_batch_size,
                args.max_context_width,
                seed=args.seed,
                dp_rank=dp_rank,
                dp_size=dp_size,
                shuffle=args.same_seed < 1,
                zipped=args.zipped_data > 0,
                use_last_file_only=args.fast_validation > 0,
                data_type=data_type,
            )

    return total_steps, throughput, loss_metric
def world_size(self):
    """Report the data-parallel world size under SageMaker model parallelism.

    Defers to the parent class's ``world_size`` when SageMaker model
    parallelism is not available.
    """
    if not is_sagemaker_model_parallel_available():
        return super().world_size
    return smp.dp_size()
def init_train():
    """Train the TabularNet model with SageMaker model parallelism (smdistributed).

    Configuration comes from the module-level ``args``: ``train`` (CSV path),
    ``batch_size``, ``epochs`` (logged only), ``learning_rate`` and ``model_dir``.
    The training set is sharded across data-parallel ranks. The model and
    optimizer are wrapped in ``smp.DistributedModel`` / ``smp.DistributedOptimizer``.
    The final state dict is written to ``<model_dir>/model.pth`` by global rank 0.
    """
    import os

    cat_mask = [
        False, True, True, True, True, False, True, True, True, True, True,
        False, False, False, False, False, False, False,
    ]
    train_ds = CsvDatasetSimple(args.train)

    batch_size = args.batch_size
    epochs = args.epochs
    learning_rate = args.learning_rate
    logger.info("batch_size = {}, epochs = {}, learning rate = {}".format(
        batch_size, epochs, learning_rate))

    # smdistributed: initialize the backend
    smp.init()

    # smdistributed: Set the device to the GPU ID used by the current process.
    # Input tensors should be transferred to this device.
    torch.cuda.set_device(smp.local_rank())
    device = torch.device("cuda")

    dataset = train_ds
    # smdistributed: Shard the dataset based on data-parallel ranks
    if smp.dp_size() > 1:
        partitions_dict = {
            f"{i}": 1 / smp.dp_size() for i in range(smp.dp_size())
        }
        dataset = SplitDataset(dataset, partitions=partitions_dict)
        dataset.select(f"{smp.dp_rank()}")

    # smdistributed: Set drop_last=True to ensure that batch size is always divisible
    # by the number of microbatches.
    # Use the configured (and logged) batch size instead of a hard-coded 64.
    train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, drop_last=True)

    model = TabularNet(n_cont=9,
                       n_cat=9,
                       cat_mask=cat_mask,
                       cat_dim=[
                           0, 2050, 13, 5, 366, 0, 50000, 50000, 50000, 50000,
                           50, 0, 0, 0, 0, 0, 0, 0
                       ],
                       y_min=0.,
                       y_max=1.)
    logger.debug(model)

    # Use the configured (and logged) learning rate instead of a hard-coded 4.0.
    optimizer = optim.Adadelta(model.parameters(), lr=learning_rate)

    # smdistributed: Use the DistributedModel container to provide the model
    # to be partitioned across different ranks. For the rest of the script,
    # the returned DistributedModel object should be used in place of
    # the model provided for DistributedModel class instantiation.
    model = smp.DistributedModel(model)
    optimizer = smp.DistributedOptimizer(optimizer)

    train(model, device, train_loader, optimizer)

    # state_dict() is called on every rank. The smp docs describe it as
    # assembling the full model across ranks (TODO confirm for the installed
    # version). Only rank 0 writes, so ranks do not race on the same file.
    state_dict = model.state_dict()
    if smp.rank() == 0:
        torch.save(state_dict, os.path.join(args.model_dir, "model.pth"))
args.is_distributed = len(args.hosts) > 1 and args.backend is not None args.is_multigpus = args.num_gpus > 1 args.multigpus_distributed = (args.is_distributed or args.is_multigpus) logger.debug(f"args.image_folder : {args.image_folder}") args.world_size = 1 args.local_rank = 0 args.rank = 0 if args.model_parallel: args.world_size = smp.size() args.local_rank = smp.local_rank() # rank per host args.rank = smp.rank() args.dp_size = smp.dp_size() args.dp_rank = smp.dp_rank() logger.debug(f"args.world_size : {args.world_size}, args.local_rank : {args.local_rank}, args.rank : {args.rank}, \ args.dp_size : {args.dp_size}, args.dp_rank : {args.dp_rank}") else: # initialize deepspeed print(f"args.deepspeed : {args.deepspeed}") deepspeed_utils.init_deepspeed(args.deepspeed) # args.LEARNING_RATE = args.LEARNING_RATE * float(args.world_size) ## SageMaker try: if os.environ.get('SM_CHANNEL_TRAINING') is not None: args.model_dir = os.environ.get('SM_MODEL_DIR') args.output_dir = os.environ.get('SM_OUTPUT_DATA_DIR')
def main():
    """Train GroupedNet on MNIST with smdistributed, checkpoint, and sync to S3.

    Seeds all RNGs and initialises smp. The MNIST training set is sharded across
    data-parallel ranks, and one epoch of training plus evaluation runs. Then:

    * ranks with ``dp_rank() == 0`` save a partial (per model-parallel rank)
      checkpoint into ``/opt/ml/local_checkpoints`` on their host;
    * local rank 0 of every host uploads that directory to
      ``<base_s3_path>/checkpoints/<host>/``.

    Raises:
        ValueError: if CUDA is not available.
    """
    if not torch.cuda.is_available():
        raise ValueError("The script requires CUDA support, but CUDA not available")
    use_ddp = True
    use_horovod = False

    # Fix seeds in order to get the same losses across runs
    random.seed(1)
    np.random.seed(1)
    torch.manual_seed(1)
    torch.cuda.manual_seed(1)

    smp.init()

    # SM Distributed: Set the device to the GPU ID used by the current process.
    # Input tensors should be transferred to this device.
    torch.cuda.set_device(smp.local_rank())
    device = torch.device("cuda")
    kwargs = {"batch_size": 64}
    kwargs.update({"num_workers": 1, "pin_memory": True, "shuffle": False})

    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )

    # SM Distributed: Download only on a single process per instance.
    # When this is not present, the file is corrupted by multiple processes trying
    # to download and extract at the same time
    if smp.local_rank() == 0:
        dataset1 = datasets.MNIST("../data", train=True, download=True, transform=transform)
    smp.barrier()
    dataset1 = datasets.MNIST("../data", train=True, download=False, transform=transform)

    if (use_ddp or use_horovod) and smp.dp_size() > 1:
        partitions_dict = {f"{i}": 1 / smp.dp_size() for i in range(smp.dp_size())}
        dataset1 = SplitDataset(dataset1, partitions=partitions_dict)
        dataset1.select(f"{smp.dp_rank()}")

    # Download and create dataloaders for train and test dataset
    dataset2 = datasets.MNIST("../data", train=False, transform=transform)

    train_loader = torch.utils.data.DataLoader(dataset1, **kwargs)
    test_loader = torch.utils.data.DataLoader(dataset2, **kwargs)

    model = GroupedNet()

    # SMP handles the transfer of parameters to the right device
    # and the user doesn't need to call 'model.to' explicitly.
    # model.to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=4.0)

    # SM Distributed: Use the DistributedModel container to provide the model
    # to be partitioned across different ranks. For the rest of the script,
    # the returned DistributedModel object should be used in place of
    # the model provided for DistributedModel class instantiation.
    model = smp.DistributedModel(model)
    scaler = smp.amp.GradScaler()
    optimizer = smp.DistributedOptimizer(optimizer)

    scheduler = StepLR(optimizer, step_size=1, gamma=0.7)
    for epoch in range(1, 2):
        train(model, scaler, device, train_loader, optimizer, epoch)
        test(model, device, test_loader)
        scheduler.step()

    checkpoint_dir = "/opt/ml/local_checkpoints"
    # Ranks with dp_rank() == 0 may live on any host. Each host therefore needs
    # its own local checkpoint directory, created once by its local rank 0;
    # creating it only on global rank 0 left other hosts without it.
    if smp.local_rank() == 0:
        if os.path.exists(checkpoint_dir):
            print("-INFO- PATH DO EXIST")
        else:
            os.makedirs(checkpoint_dir, exist_ok=True)
            print("-INFO- PATH DO NOT EXIST")

    # Ensure the directory exists on every host before any rank saves into it.
    smp.barrier()

    if smp.dp_rank() == 0:
        model_dict = model.local_state_dict()
        opt_dict = optimizer.local_state_dict()
        smp.save(
            {"model_state_dict": model_dict, "optimizer_state_dict": opt_dict},
            os.path.join(checkpoint_dir, "pt_mnist_checkpoint.pt"),
            partial=True,
        )

    # Wait until every checkpoint shard is written before syncing to S3.
    smp.barrier()

    if smp.local_rank() == 0:
        print("Start syncing")
        base_s3_path = os.path.dirname(os.path.dirname(os.getenv('SM_MODULE_DIR', '')))
        curr_host = os.getenv('SM_CURRENT_HOST')
        full_s3_path = f'{base_s3_path}/checkpoints/{curr_host}/'
        sync_local_checkpoints_to_s3(local_path=checkpoint_dir, s3_path=full_s3_path)
        print("Finished syncing")
def _select_shard_file(files, file_index, dp_size, dp_rank):
    """Return the input file that data-parallel rank ``dp_rank`` reads at ``file_index``.

    Indices are strided by ``dp_size`` and offset by ``dp_rank``. When there
    are more data-parallel ranks than files, an extra
    ``(dp_size % num_files) * file_index`` shift is added.

    Raises:
        ValueError: if ``files`` is empty.
    """
    num_files = len(files)
    if num_files == 0:
        raise ValueError("No training files available to shard across data-parallel ranks")
    if dp_size > 1 and dp_size > num_files:
        remainder = dp_size % num_files
        return files[(file_index * dp_size + dp_rank + remainder * file_index) % num_files]
    return files[(file_index * dp_size + dp_rank) % num_files]


def main():
    """BERT pre-training / evaluation loop with optional SageMaker model parallelism.

    Loops over epochs indefinitely. Each epoch walks the (shuffled) training
    files; while the current file's DataLoader is consumed, the next one is
    built in a single background process. The run ends once ``global_step``
    reaches ``args.steps_this_run`` or the global ``timeout_sent`` flag is set.
    On resume, the file list, epoch and (if present) the pickled DataLoader are
    taken from ``checkpoint``.

    Returns:
        With ``args.do_train``: ``(args, final_loss, train_time_raw, global_step)``.
        Otherwise: the mean of the per-step test losses.
    """
    global timeout_sent

    args = parse_arguments()

    random.seed(args.seed + args.local_rank)
    np.random.seed(args.seed + args.local_rank)
    torch.manual_seed(args.seed + args.local_rank)
    torch.cuda.manual_seed(args.seed + args.local_rank)
    worker_init = WorkerInitObj(args.seed + args.local_rank)

    device, args = setup_training(args)

    # Prepare optimizer
    model, optimizer, lr_scheduler, checkpoint, global_step, criterion = prepare_model_and_optimizer(
        args, device)

    raw_train_start = None
    most_recent_ckpts_paths = []
    average_loss = 0.0  # averaged loss every args.log_freq steps
    epoch = 0
    training_steps = 0
    test_losses = []

    # Single background worker that prefetches the next file's DataLoader.
    pool = ProcessPoolExecutor(1)

    # Note: We loop infinitely over epochs, termination is handled via iteration count
    # NOTE(review): with a single training file the inner file loop never runs,
    # so this loop would not terminate - confirm inputs always have >= 2 files.
    while True:
        restored_data_loader = None
        if not args.resume_from_checkpoint or epoch > 0 or (
                args.phase2 and global_step < 1) or args.init_checkpoint:
            files = [
                os.path.join(args.input_dir, f)
                for f in os.listdir(args.input_dir)
                if os.path.isfile(os.path.join(args.input_dir, f)) and 'training' in f
            ]
            files.sort()
            num_files = len(files)
            random.Random(args.seed + epoch).shuffle(files)
            f_start_id = 0
        else:
            f_start_id = checkpoint['files'][0]
            files = checkpoint['files'][1:]
            args.resume_from_checkpoint = False
            num_files = len(files)
            # may not exist in all checkpoints
            epoch = checkpoint.get('epoch', 0)
            # Was assigned to a misspelled name, so a saved loader was never restored.
            restored_data_loader = checkpoint.get('data_loader', None)

        shared_file_list = {}

        # Shard files by data-parallel rank so all ranks of one model replica
        # read the same data.
        if smp.is_initialized():
            dpsize = smp.dp_size()
            dprank = smp.dp_rank()
        elif torch.distributed.is_initialized():
            dpsize = get_world_size()
            dprank = get_rank()
        else:
            dpsize = 1
            dprank = 0

        data_file = _select_shard_file(files, f_start_id, dpsize, dprank)

        if restored_data_loader is None:
            train_data = pretraining_dataset(data_file, args.max_predictions_per_seq)
            train_sampler = RandomSampler(train_data)
            train_dataloader = DataLoader(train_data,
                                          sampler=train_sampler,
                                          batch_size=args.train_batch_size * args.n_gpu,
                                          num_workers=4,
                                          worker_init_fn=worker_init,
                                          pin_memory=True,
                                          drop_last=True)
        else:
            train_dataloader = restored_data_loader
            restored_data_loader = None

        overflow_buf = None
        if args.allreduce_post_accumulation:
            overflow_buf = torch.cuda.IntTensor([0])

        for f_id in range(f_start_id + 1, len(files)):
            # Use the same dp size/rank as for the first file. Previously
            # get_world_size()/get_rank() were used here, which disagreed under
            # model parallelism and could hit an undefined `remainder`.
            data_file = _select_shard_file(files, f_id, dpsize, dprank)

            dataset_future = pool.submit(create_pretraining_dataset, data_file,
                                         args.max_predictions_per_seq, shared_file_list,
                                         args, worker_init)

            train_iter = tqdm(train_dataloader,
                              desc="Iteration",
                              disable=args.disable_progress_bar) if is_main_process() else train_dataloader

            if raw_train_start is None:
                raw_train_start = time.time()
            for step, batch in enumerate(train_iter):
                training_steps += 1
                batch = [t.to(device) for t in batch]
                input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels = batch
                if args.do_train:
                    model.train()
                    if args.smp > 0:
                        loss_mbs = smp_step(args, device, input_ids, segment_ids, input_mask,
                                            masked_lm_labels, next_sentence_labels, model,
                                            optimizer, criterion, step)
                        loss = loss_mbs.reduce_mean()
                        if smp.rank() == 0:
                            print("Loss:", loss.item())
                    else:
                        loss = train_step(args, device, input_ids, segment_ids, input_mask,
                                          masked_lm_labels, next_sentence_labels, model,
                                          optimizer, criterion, step)
                    divisor = 1
                    average_loss += loss.item()

                    if training_steps % args.gradient_accumulation_steps == 0:
                        lr_scheduler.step()  # learning rate warmup
                        global_step = take_optimizer_step(args, optimizer, model, overflow_buf,
                                                          global_step)

                    if global_step >= args.steps_this_run or timeout_sent:
                        train_time_raw = time.time() - raw_train_start
                        last_num_steps = int(
                            training_steps / args.gradient_accumulation_steps) % args.log_freq
                        last_num_steps = args.log_freq if last_num_steps == 0 else last_num_steps
                        average_loss = torch.tensor(average_loss, dtype=torch.float32).cuda()
                        average_loss = average_loss / (last_num_steps * divisor)
                        if torch.distributed.is_initialized():
                            average_loss /= get_world_size()
                            torch.distributed.all_reduce(average_loss)
                        final_loss = loss.item()
                    elif training_steps % (args.log_freq * args.gradient_accumulation_steps) == 0:
                        average_loss = 0

                    if global_step >= args.steps_this_run or training_steps % (
                            args.num_steps_per_checkpoint * args.gradient_accumulation_steps
                    ) == 0 or timeout_sent:
                        if smp.dp_rank() == 0 and not args.skip_checkpoint:
                            if args.resume_step < 0 or not args.phase2:
                                output_save_file = os.path.join(
                                    args.output_dir, "ckpt_{}.pt".format(global_step))
                            else:
                                output_save_file = os.path.join(
                                    args.output_dir,
                                    "ckpt_{}.pt".format(global_step + args.phase1_end_step))
                            if args.do_train:
                                save_dict = {
                                    'model': model.local_state_dict(),
                                    'optimizer': optimizer.local_state_dict(),
                                    'files': [f_id] + files,
                                    'epoch': epoch,
                                    'data_loader': None if global_step >= args.steps_this_run else train_dataloader
                                }
                                if args.fp16:
                                    save_dict['master params'] = list(amp.master_params(optimizer))
                                # SMP: Checkpoint mp_rank specific state
                                smp.save(save_dict, output_save_file, partial=True)

                                # Keep only the three most recent checkpoints.
                                most_recent_ckpts_paths.append(output_save_file)
                                if len(most_recent_ckpts_paths) > 3 and (
                                        args.smp == 0 or smp.dp_rank() == 0):
                                    ckpt_to_be_removed = most_recent_ckpts_paths.pop(0)
                                    os.remove(ckpt_to_be_removed + f"_{smp.mp_rank()}")

                        # Exiting the training due to hitting max steps, or being sent a
                        # timeout from the cluster scheduler
                        if global_step >= args.steps_this_run or timeout_sent:
                            if smp.dp_rank() == 0 and args.save_full:
                                output_save_file = os.path.join(
                                    args.output_dir, "ckpt_{}.pt".format(global_step))
                                save_dict = {
                                    'model': model.local_state_dict(),
                                    'optimizer': optimizer.local_state_dict(),
                                    'files': [f_id] + files,
                                    'epoch': epoch,
                                    'data_loader': None if global_step >= args.steps_this_run else train_dataloader
                                }
                                if args.fp16:
                                    save_dict['master params'] = list(amp.master_params(optimizer))
                                # SMP: Save a single checkpoint containing entire model parameters
                                smp.save(save_dict, output_save_file, partial=False)
                            # Delete only after the full save: on a timeout, save_dict
                            # above still reads train_dataloader.
                            del train_dataloader
                            smp.barrier()
                            if smp.local_rank() == 0:
                                print("Start syncing model checkpoints to s3")
                                base_s3_path = os.path.dirname(
                                    os.path.dirname(os.getenv('SM_MODULE_DIR', '')))
                                curr_host = os.getenv('SM_CURRENT_HOST')
                                full_s3_path = f'{base_s3_path}/checkpoints/{curr_host}/'
                                sync_local_checkpoints_to_s3(local_path=args.output_dir,
                                                             s3_path=full_s3_path)
                                print("Finished syncing model checkpoints to s3")
                            return args, final_loss, train_time_raw, global_step
                else:
                    model.eval()
                    with torch.no_grad():
                        loss = test_step(args, device, input_ids, segment_ids, input_mask,
                                         masked_lm_labels, next_sentence_labels, model,
                                         criterion, step)
                        print(f"global_step {global_step} Test Loss:", loss)
                        test_losses.append(loss)
                    global_step += 1
                    if global_step >= args.steps_this_run:
                        return sum(test_losses) / len(test_losses)

            del train_dataloader
            # Make sure pool has finished and switch train_dataloader
            # NOTE: Will block until complete
            train_dataloader, data_file = dataset_future.result(timeout=None)

        epoch += 1
def world_size(self):
    """Report the data-parallel world size when smdistributed model parallelism is configured.

    Model parallelism counts as configured when smdistributed is available and
    ``self.mp_parameters`` is non-empty. Otherwise this defers to the parent
    class's ``world_size``.
    """
    smp_active = is_smdistributed_available() and self.mp_parameters != ""
    if not smp_active:
        return super().world_size
    return smp.dp_size()
def main():
    """Entry point: train GroupedNet on MNIST with smdistributed model parallelism.

    Steps:

    * parse CLI options, seed every RNG and initialise smp;
    * shard the training set across data-parallel ranks when ``--ddp`` is on;
    * optionally resume from a partial or a full checkpoint (partial wins);
    * train for ``args.epochs`` epochs;
    * optionally save partial and/or full checkpoints.

    With ``--assert_losses`` the final test loss is checked against a
    threshold. Under DDP it also checks that all replicas report the same loss.

    Raises:
        ValueError: if CUDA is not available.
    """
    args = get_parser().parse_args()

    if not torch.cuda.is_available():
        raise ValueError(
            "The script requires CUDA support, but CUDA not available")
    use_ddp = args.ddp > 0

    # Fix seeds in order to get the same losses across runs
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(args.seed)

    smp.init()

    # SM Distributed: bind this process to its GPU; inputs go to this device.
    torch.cuda.set_device(smp.local_rank())
    device = torch.device("cuda")

    loader_kwargs = {
        "batch_size": args.batch_size,
        "num_workers": 1,
        "pin_memory": True,
        "shuffle": False,
    }
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307, ), (0.3081, )),
    ])

    if args.data_dir is None:
        # SM Distributed: only one process per instance downloads, otherwise
        # concurrent download/extract corrupts the files.
        args.data_dir = "../data"
        if smp.local_rank() == 0:
            dataset1 = datasets.MNIST(args.data_dir, train=True, download=True, transform=transform)
        smp.barrier()
    dataset1 = datasets.MNIST(args.data_dir, train=True, download=False, transform=transform)

    if use_ddp and smp.dp_size() > 1:
        shards = {f"{i}": 1 / smp.dp_size() for i in range(smp.dp_size())}
        dataset1 = SplitDataset(dataset1, partitions=shards)
        dataset1.select(f"{smp.dp_rank()}")

    dataset2 = datasets.MNIST(args.data_dir, train=False, download=False, transform=transform)
    train_loader = torch.utils.data.DataLoader(dataset1, **loader_kwargs)
    test_loader = torch.utils.data.DataLoader(dataset2, **loader_kwargs)

    # SMP places parameters on the right device itself; no model.to(device).
    model = GroupedNet()
    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)

    # SM Distributed: from here on use the DistributedModel / DistributedOptimizer
    # wrappers in place of the raw objects.
    model = smp.DistributedModel(model)
    scaler = smp.amp.GradScaler()
    optimizer = smp.DistributedOptimizer(optimizer)

    # Resume: a partial checkpoint takes precedence over a full one.
    if args.partial_checkpoint:
        resume_path, resume_partial = args.partial_checkpoint, True
    elif args.full_checkpoint:
        resume_path, resume_partial = args.full_checkpoint, False
    else:
        resume_path, resume_partial = None, None
    if resume_path is not None:
        checkpoint = smp.load(resume_path, partial=resume_partial)
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    for epoch in range(1, args.epochs + 1):
        train(args, model, scaler, device, train_loader, optimizer, epoch)
        test_loss = test(args, model, device, test_loader)
        scheduler.step()

    if args.save_partial_model and smp.dp_rank() == 0:
        smp.save(
            {
                "model_state_dict": model.local_state_dict(),
                "optimizer_state_dict": optimizer.local_state_dict(),
            },
            "./pt_mnist_checkpoint.pt",
            partial=True,
        )

    if args.save_full_model and smp.dp_rank() == 0:
        smp.save(
            {
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
            },
            "./pt_mnist_checkpoint.pt",
            partial=False,
        )

    # Let checkpoint writes finish before the next collective (allgather).
    smp.barrier()

    if args.assert_losses:
        if not use_ddp:
            assert test_loss < 0.08
        else:
            # SM Distributed: every data-parallel replica must report the same loss.
            replica_losses = smp.allgather(test_loss, smp.DP_GROUP)
            assert all(math.isclose(other, replica_losses[0]) for other in replica_losses)
            assert test_loss < 0.18

    # For CI/CD
    smp.barrier()
    print("SMP training finished successfully")
def _strip_submodule_prefixes(state_dict):
    """Return a copy of *state_dict* with submodule prefixes removed from keys.

    Keys containing "decoder" have every "decoder." substring removed;
    otherwise keys containing "codebook" have every "codebook." substring
    removed; all other keys are copied unchanged. Values are not copied.
    """
    stripped = {}
    for key, val in state_dict.items():
        if "decoder" in key:
            key = key.replace("decoder.", "")
        elif "codebook" in key:
            key = key.replace("codebook.", "")
        stripped[key] = val
    return stripped


def _save_smp_checkpoint(model, optimizer, filename, full):
    """Save model and optimizer state with ``smp.save``.

    Args:
        model: the ``smp.DistributedModel`` being trained.
        optimizer: the ``smp.DistributedOptimizer`` wrapping its optimizer.
        filename: destination path of the checkpoint.
        full: when truthy, save ``state_dict()`` with ``partial=False``;
            otherwise save ``local_state_dict()`` with ``partial=True``.
    """
    if full:
        model_dict = model.state_dict()
        opt_dict = optimizer.state_dict()
    else:
        model_dict = model.local_state_dict()
        opt_dict = optimizer.local_state_dict()
    smp.save(
        {
            "model_state_dict": model_dict,
            "optimizer_state_dict": opt_dict
        },
        filename,
        partial=not full,
    )


def main():
    """Train a DiscreteVAE on an image folder.

    Runs with SageMaker model parallelism (``smp``) when
    ``args.model_parallel`` is set, otherwise through ``deepspeed_utils``.
    Every 10 iterations a checkpoint is written to ``{args.model_dir}/vae.pt``;
    in the non model-parallel path ``vae-final.pt`` is written at the end.

    Raises:
        ValueError: if CUDA is unavailable or the image folder yields no images.
    """
    parser = get_parser()
    args = parser.parse_args()

    if not torch.cuda.is_available():
        raise ValueError(
            "The script requires CUDA support, but CUDA not available")

    # Defaults; overwritten below depending on the distribution backend.
    args.rank = -1
    args.world_size = 1

    if args.model_parallel:
        args.deepspeed = False
        cfg = {
            "microbatches": args.num_microbatches,
            "placement_strategy": args.placement_strategy,
            "pipeline": args.pipeline,
            "optimize": args.optimize,
            "partitions": args.num_partitions,
            "horovod": args.horovod,
            "ddp": args.ddp,
        }
        smp.init(cfg)
        torch.cuda.set_device(smp.local_rank())
        # NOTE: args.rank is the *data-parallel* rank here, so every
        # model-parallel rank of the first replica sees args.rank == 0.
        args.rank = smp.dp_rank()
        args.world_size = smp.size()
    else:
        # initialize deepspeed
        print(f"args.deepspeed : {args.deepspeed}")
        deepspeed_utils.init_deepspeed(args.deepspeed)
        if deepspeed_utils.is_root_worker():
            args.rank = 0

    if args.seed is not None:
        random.seed(args.seed)
        # Offset torch's CPU seed by rank; the other RNGs share one seed.
        torch.manual_seed(args.seed + args.rank)
        np.random.seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)
        cudnn.deterministic = True

    if cudnn.deterministic:
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    args.kwargs = {'num_workers': args.num_worker, 'pin_memory': True}

    device = torch.device("cuda")

    logger.debug(f"args.image_folder : {args.image_folder}")
    logger.debug(f"args.rank : {args.rank}")

    ## SageMaker: model and data locations are taken from the environment.
    sm_model_dir = os.environ.get('SM_MODEL_DIR')
    if sm_model_dir is not None:
        args.model_dir = sm_model_dir
        args.image_folder = os.environ.get('SM_CHANNEL_TRAINING')
    else:
        logger.debug("not SageMaker")

    IMAGE_SIZE = args.image_size
    IMAGE_PATH = args.image_folder

    EPOCHS = args.EPOCHS
    BATCH_SIZE = args.BATCH_SIZE
    LEARNING_RATE = args.LEARNING_RATE
    LR_DECAY_RATE = args.LR_DECAY_RATE

    NUM_TOKENS = args.NUM_TOKENS
    NUM_LAYERS = args.NUM_LAYERS
    NUM_RESNET_BLOCKS = args.NUM_RESNET_BLOCKS
    SMOOTH_L1_LOSS = args.SMOOTH_L1_LOSS
    EMB_DIM = args.EMB_DIM
    HID_DIM = args.HID_DIM
    KL_LOSS_WEIGHT = args.KL_LOSS_WEIGHT

    STARTING_TEMP = args.STARTING_TEMP
    TEMP_MIN = args.TEMP_MIN
    ANNEAL_RATE = args.ANNEAL_RATE

    NUM_IMAGES_SAVE = args.NUM_IMAGES_SAVE

    # Force RGB, resize + center-crop to IMAGE_SIZE, convert to a tensor.
    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize(IMAGE_SIZE),
        T.CenterCrop(IMAGE_SIZE),
        T.ToTensor()
    ])

    # data
    logger.debug(f"IMAGE_PATH : {IMAGE_PATH}")
    ds = ImageFolder(
        IMAGE_PATH,
        transform=transform,
    )

    # With SMP data parallelism, each replica selects its own equal shard.
    if args.model_parallel and (args.ddp or args.horovod) and smp.dp_size() > 1:
        partitions_dict = {
            f"{i}": 1 / smp.dp_size()
            for i in range(smp.dp_size())
        }
        ds = SplitDataset(ds, partitions=partitions_dict)
        ds.select(f"{smp.dp_rank()}")

    # Checked before building the DataLoader: a shuffling DataLoader over an
    # empty dataset fails with a far less helpful error.
    if len(ds) == 0:
        raise ValueError('folder does not contain any images')

    dl = DataLoader(ds,
                    BATCH_SIZE,
                    shuffle=True,
                    drop_last=args.model_parallel,
                    **args.kwargs)

    vae_params = dict(image_size=IMAGE_SIZE,
                      num_layers=NUM_LAYERS,
                      num_tokens=NUM_TOKENS,
                      codebook_dim=EMB_DIM,
                      hidden_dim=HID_DIM,
                      num_resnet_blocks=NUM_RESNET_BLOCKS)

    vae = DiscreteVAE(**vae_params,
                      smooth_l1_loss=SMOOTH_L1_LOSS,
                      kl_div_loss_weight=KL_LOSS_WEIGHT).to(device)

    # optimizer
    opt = Adam(vae.parameters(), lr=LEARNING_RATE)
    sched = ExponentialLR(optimizer=opt, gamma=LR_DECAY_RATE)

    if args.model_parallel:
        import copy
        # Local, undistributed copies used only to render "hard"
        # reconstructions from codebook indices for previews.
        dummy_codebook = copy.deepcopy(vae.codebook)
        dummy_decoder = copy.deepcopy(vae.decoder)

        vae = smp.DistributedModel(vae)
        scaler = smp.amp.GradScaler()
        opt = smp.DistributedOptimizer(opt)
        if args.partial_checkpoint:
            args.checkpoint = smp.load(args.partial_checkpoint, partial=True)
            vae.load_state_dict(args.checkpoint["model_state_dict"])
            opt.load_state_dict(args.checkpoint["optimizer_state_dict"])
        elif args.full_checkpoint:
            args.checkpoint = smp.load(args.full_checkpoint, partial=False)
            vae.load_state_dict(args.checkpoint["model_state_dict"])
            opt.load_state_dict(args.checkpoint["optimizer_state_dict"])

    if (not args.model_parallel) and args.rank == 0:
        print(f'{len(ds)} images found for training')

    def save_model(path):
        """Write hparams + weights to *path* (rank 0 only; others no-op)."""
        if args.rank != 0:
            return
        save_obj = {'hparams': vae_params, 'weights': vae.state_dict()}
        torch.save(save_obj, path)

    # distribute with deepspeed
    if not args.model_parallel:
        deepspeed_utils.check_batch_size(BATCH_SIZE)
        deepspeed_config = {'train_batch_size': BATCH_SIZE}

        (distr_vae, opt, dl, sched) = deepspeed_utils.maybe_distribute(
            args=args,
            model=vae,
            optimizer=opt,
            model_parameters=vae.parameters(),
            training_data=ds if args.deepspeed else dl,
            lr_scheduler=sched,
            config_params=deepspeed_config,
        )

    if args.model_parallel:
        # Rubik: Define smp.step. Return any tensors needed outside.
        @smp.step
        def train_step(vae, images, temp):
            with autocast(enabled=(args.amp > 0)):
                loss, recons = vae(images,
                                   return_loss=True,
                                   return_recons=True,
                                   temp=temp)

            scaled_loss = scaler.scale(loss) if args.amp else loss
            vae.backward(scaled_loss)
            return loss, recons

        @smp.step
        def get_codes_step(vae, images, k):
            # NOTE: runs per microbatch, so this keeps the first k rows of
            # *each* microbatch.
            images = images[:k]
            logits = vae.forward(images, return_logits=True)
            codebook_indices = logits.argmax(dim=1).flatten(1)
            return codebook_indices

        def hard_recons_step(dummy_decoder, dummy_codebook, codebook_indices):
            from functools import partial

            # Re-bind each module's original (un-patched) forward so the
            # local copies run outside smp; presumably smp patches forward.
            for module in dummy_codebook.modules():
                method = smp_state.patch_manager.get_original_method(
                    "forward", type(module))
                module.forward = partial(method, module)
            image_embeds = dummy_codebook.forward(codebook_indices)
            b, n, d = image_embeds.shape
            h = w = int(sqrt(n))

            image_embeds = rearrange(image_embeds,
                                     'b (h w) d -> b d h w',
                                     h=h,
                                     w=w)
            for module in dummy_decoder.modules():
                method = smp_state.patch_manager.get_original_method(
                    "forward", type(module))
                module.forward = partial(method, module)
            hard_recons = dummy_decoder.forward(image_embeds)
            return hard_recons

    # starting temperature
    global_step = 0
    temp = STARTING_TEMP

    for epoch in range(EPOCHS):
        batch_time = util.AverageMeter('Time', ':6.3f')
        losses = util.AverageMeter('Loss', ':.4e')

        vae.train()
        # Start time of the current timing window and the number of
        # iterations already completed when that window started.
        window_start = time.time()
        window_first_iter = 0

        for i, (images, _) in enumerate(dl):
            images = images.to(device, non_blocking=True)
            opt.zero_grad()

            if args.model_parallel:
                loss, recons = train_step(vae, images, temp)
                # Rubik: Average the loss across microbatches.
                loss = loss.reduce_mean()
                # Reconstructions of different images: concatenate the
                # microbatches back into the batch instead of averaging.
                recons = recons.concat()
            else:
                loss, recons = distr_vae(images,
                                         return_loss=True,
                                         return_recons=True,
                                         temp=temp)

            if (not args.model_parallel) and args.deepspeed:
                # Gradients are automatically zeroed after the step
                distr_vae.backward(loss)
                distr_vae.step()
            elif args.model_parallel:
                if args.amp:
                    scaler.step(opt)
                    scaler.update()
                else:
                    # Some optimizers (e.g. Adadelta in PT 1.8) don't like
                    # optimizer.step() being called with no parameters.
                    if len(list(vae.local_parameters())) > 0:
                        opt.step()
            else:
                loss.backward()
                opt.step()

            if i % 10 == 0:
                if args.rank == 0:
                    k = NUM_IMAGES_SAVE

                    with torch.no_grad():
                        if args.model_parallel:
                            model_dict_updated = _strip_submodule_prefixes(
                                vae.state_dict())
                            dummy_decoder.load_state_dict(model_dict_updated,
                                                          strict=False)
                            dummy_codebook.load_state_dict(model_dict_updated,
                                                           strict=False)
                            codes = get_codes_step(vae, images, k)
                            # Codebook indices must never be averaged across
                            # microbatches; concatenating and keeping k rows
                            # yields exactly the codes of images[:k].
                            codes = codes.concat()[:k].to(torch.long)
                            hard_recons = hard_recons_step(
                                dummy_decoder, dummy_codebook, codes)
                        else:
                            codes = vae.get_codebook_indices(images[:k])
                            hard_recons = vae.decode(codes)

                    # Separate names so `images` keeps referring to the
                    # training batch (used for loss weighting below).
                    sample_images, sample_recons = map(
                        lambda t: t[:k], (images, recons))
                    sample_images, sample_recons, hard_recons, codes = map(
                        lambda t: t.detach().cpu(),
                        (sample_images, sample_recons, hard_recons, codes))
                    sample_images, sample_recons, hard_recons = map(
                        lambda t: make_grid(t.float(),
                                            nrow=int(sqrt(k)),
                                            normalize=True,
                                            range=(-1, 1)),
                        (sample_images, sample_recons, hard_recons))
                    # NOTE(review): these grids are not logged anywhere since
                    # experiment tracking was removed; re-attach a logger here.

                # Checkpoint. smp.barrier() is a collective, so every rank
                # reaches it; only data-parallel rank 0 writes the file.
                if args.model_parallel:
                    filename = f'{args.model_dir}/vae.pt'
                    if smp.dp_rank() == 0:
                        _save_smp_checkpoint(vae, opt, filename,
                                             full=args.save_full_model)
                    smp.barrier()
                else:
                    # save_model is a no-op on non-root ranks.
                    save_model(f'{args.model_dir}/vae.pt')

                # temperature anneal
                temp = max(temp * math.exp(-ANNEAL_RATE * global_step),
                           TEMP_MIN)

                # lr decay
                sched.step()

            if args.model_parallel:
                # NOTE: loss of this data-parallel replica only; it is not
                # reduced across replicas.
                avg_loss = loss.detach().clone()
            else:
                avg_loss = deepspeed_utils.average_all(loss)

            if args.rank == 0 and i % 100 == 0:
                lr = sched.get_last_lr()[0]
                print(epoch, i, f'lr - {lr:6f}, loss - {avg_loss.item()},')

            global_step += 1

            if args.rank == 0 and i % args.log_interval == 0:
                # Every log_interval iterations, check the loss and speed.
                # For best performance these metrics are not computed every
                # iteration, since they incur host<->device syncs.
                # to_python_float incurs a host<->device sync
                losses.update(util.to_python_float(loss), images.size(0))

                # Wait for queued GPU work (PyTorch is async) so the
                # measured time covers the whole window.
                torch.cuda.synchronize()
                now = time.time()
                iters_in_window = i + 1 - window_first_iter
                batch_time.update((now - window_start) / iters_in_window)
                window_start = now
                window_first_iter = i + 1

                print(
                    'Epoch: [{0}][{1}/{2}] '
                    'Train_Time={batch_time.val:.3f}: avg-{batch_time.avg:.3f}, '
                    'Train_Speed={3:.3f} ({4:.3f}), '
                    'Train_Loss={loss.val:.10f}:({loss.avg:.4f}),'.format(
                        epoch,
                        i,
                        len(dl),
                        args.world_size * BATCH_SIZE / batch_time.val,
                        args.world_size * BATCH_SIZE / batch_time.avg,
                        batch_time=batch_time,
                        loss=losses))

    if args.rank == 0:
        # save final vae and cleanup
        if args.model_parallel:
            logger.debug('save model_parallel')
        else:
            save_model(os.path.join(args.model_dir, 'vae-final.pt'))

    if args.model_parallel:
        if args.assert_losses:
            if args.horovod or args.ddp:
                # SM Distributed: If using data parallelism, gather all losses
                # across different model replicas and check if losses match.
                gathered = smp.allgather(loss, smp.DP_GROUP)
                for replica_loss in gathered:
                    print(replica_loss)
                    assert math.isclose(replica_loss, gathered[0])
                assert loss < 0.18
            else:
                assert loss < 0.08
        smp.barrier()
        print("SMP training finished successfully")