def _wrap_model(self, model, training=True):
    """Wrap ``model`` for the parallelism mode that is currently active.

    Args:
        model: The model to wrap.
        training: Passed through to the parent implementation when model
            parallelism is disabled.

    Returns:
        ``self.model_wrapped`` when it is already an
        ``smp.model.DistributedModel``; a new ``smp.DistributedModel`` around
        ``model`` when model parallelism is enabled; otherwise whatever the
        parent class's ``_wrap_model`` returns.
    """
    if self.is_model_parallel_enabled:
        # Wrapping the base model twice in a DistributedModel will raise an error.
        if isinstance(self.model_wrapped, smp.model.DistributedModel):
            return self.model_wrapped
        return smp.DistributedModel(
            model,
            backward_passes_per_step=self.args.gradient_accumulation_steps,
        )
    # Forward the caller's ``training`` flag; previously it was dropped, so the
    # parent always ran with its own default regardless of what was requested.
    return super()._wrap_model(model, training=training)
def _wrap_model(self, model, training=True):
    """Wrap ``model`` for the parallelism mode that is currently active.

    Args:
        model: The model to wrap.
        training: Passed through to the parent implementation when model
            parallelism is disabled.

    Returns:
        ``self.model_wrapped`` when it is already an
        ``smp.model.DistributedModel``; a new ``smp.DistributedModel`` around
        ``model`` when model parallelism is enabled; otherwise whatever the
        parent class's ``_wrap_model`` returns.
    """
    if self.is_model_parallel_enabled:
        # Wrapping the base model twice in a DistributedModel will raise an error.
        if isinstance(self.model_wrapped, smp.model.DistributedModel):
            return self.model_wrapped
        return smp.DistributedModel(model)
    # Forward the caller's ``training`` flag; previously it was dropped, so the
    # parent always ran with its own default regardless of what was requested.
    return super()._wrap_model(model, training=training)
def smp_init(model, optimizer, args):
    """Wrap the model and optimizer with SMP and optionally restore a checkpoint.

    The model is wrapped in ``smp.DistributedModel``, the optimizer in
    ``smp.DistributedOptimizer``, and a fresh ``smp.amp.GradScaler`` is stored
    on ``args.scaler``. If ``args.partial_checkpoint`` is truthy it is loaded
    with ``partial=True``; otherwise, if ``args.full_checkpoint`` is truthy, it
    is loaded with ``partial=False``. The loaded object is kept on
    ``args.checkpoint`` and its ``"model_state_dict"`` /
    ``"optimizer_state_dict"`` entries are applied.

    Returns:
        The tuple ``(model, optimizer, args)`` with the wrapped objects.
    """
    model = smp.DistributedModel(model)
    args.scaler = smp.amp.GradScaler()
    optimizer = smp.DistributedOptimizer(optimizer)

    # A partial checkpoint takes precedence over a full one.
    if args.partial_checkpoint:
        ckpt_path, is_partial = args.partial_checkpoint, True
    elif args.full_checkpoint:
        ckpt_path, is_partial = args.full_checkpoint, False
    else:
        # Nothing to restore.
        return model, optimizer, args

    args.checkpoint = smp.load(ckpt_path, partial=is_partial)
    model.load_state_dict(args.checkpoint["model_state_dict"])
    optimizer.load_state_dict(args.checkpoint["optimizer_state_dict"])
    return model, optimizer, args
def main():
    """Train and evaluate ``GroupedNet`` on MNIST with SageMaker model parallelism.

    Parses the command line, seeds every RNG from ``args.seed``, initializes
    ``smp`` from a config built out of the parsed arguments, builds the data
    loaders, wraps the model/optimizer in their ``smp`` containers, optionally
    restores a checkpoint, runs ``args.epochs`` train/test epochs, optionally
    saves partial and/or full checkpoints, and finally (when
    ``args.assert_losses`` is set) asserts on the last test loss.

    Raises:
        ValueError: If CUDA is not available.
        AssertionError: If ``args.assert_losses`` is set and a loss check fails.
    """
    parser = get_parser()
    args = parser.parse_args()

    if not torch.cuda.is_available():
        raise ValueError(
            "The script requires CUDA support, but CUDA not available")
    use_ddp = args.ddp > 0
    use_horovod = args.horovod > 0

    # Fix seeds in order to get the same losses across runs
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    cfg = {
        "microbatches": args.num_microbatches,
        "placement_strategy": "spread",
        "pipeline": args.pipeline,
        "optimize": "speed",
        "partitions": args.num_partitions,
        "horovod": use_horovod,
        "ddp": use_ddp,
    }

    smp.init(cfg)

    # SM Distributed: Set the device to the GPU ID used by the current process.
    # Input tensors should be transferred to this device.
    torch.cuda.set_device(smp.local_rank())
    device = torch.device("cuda")
    kwargs = {"batch_size": args.batch_size}
    kwargs.update({"num_workers": 1, "pin_memory": True, "shuffle": False})

    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.1307, ), (0.3081, ))])

    # SM Distributed: Download only on a single process per instance.
    # When this is not present, the file is corrupted by multiple processes trying
    # to download and extract at the same time
    if smp.local_rank() == 0:
        dataset1 = datasets.MNIST("../data",
                                  train=True,
                                  download=True,
                                  transform=transform)
    # Every rank waits here, then opens the dataset without downloading.
    smp.barrier()
    dataset1 = datasets.MNIST("../data",
                              train=True,
                              download=False,
                              transform=transform)

    # With data parallelism, give each dp_rank an equal-sized shard of the
    # training set.
    if (use_ddp or use_horovod) and smp.dp_size() > 1:
        partitions_dict = {
            f"{i}": 1 / smp.dp_size()
            for i in range(smp.dp_size())
        }
        dataset1 = SplitDataset(dataset1, partitions=partitions_dict)
        dataset1.select(f"{smp.dp_rank()}")

    # Download and create dataloaders for train and test dataset
    dataset2 = datasets.MNIST("../data", train=False, transform=transform)

    train_loader = torch.utils.data.DataLoader(dataset1, **kwargs)
    test_loader = torch.utils.data.DataLoader(dataset2, **kwargs)

    model = GroupedNet()

    # SMP handles the transfer of parameters to the right device
    # and the user doesn't need to call 'model.to' explicitly.
    # model.to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)

    # SM Distributed: Use the DistributedModel container to provide the model
    # to be partitioned across different ranks. For the rest of the script,
    # the returned DistributedModel object should be used in place of
    # the model provided for DistributedModel class instantiation.
    model = smp.DistributedModel(model)
    scaler = smp.amp.GradScaler()
    optimizer = smp.DistributedOptimizer(optimizer)

    # Optionally restore state; a partial checkpoint takes precedence over a
    # full one when both arguments are set.
    if args.partial_checkpoint:
        checkpoint = smp.load(args.partial_checkpoint, partial=True)
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    elif args.full_checkpoint:
        checkpoint = smp.load(args.full_checkpoint, partial=False)
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    for epoch in range(1, args.epochs + 1):
        train(args, model, scaler, device, train_loader, optimizer, epoch)
        test_loss = test(args, model, device, test_loader)
        scheduler.step()

    # NOTE(review): the partial and the full save below are given the same
    # path string; confirm how smp.save names partial files if both flags can
    # be set together.
    if args.save_partial_model:
        if smp.dp_rank() == 0:
            model_dict = model.local_state_dict()
            opt_dict = optimizer.local_state_dict()
            smp.save(
                {
                    "model_state_dict": model_dict,
                    "optimizer_state_dict": opt_dict
                },
                f"./pt_mnist_checkpoint.pt",
                partial=True,
            )

    if args.save_full_model:
        if smp.dp_rank() == 0:
            model_dict = model.state_dict()
            opt_dict = optimizer.state_dict()
            smp.save(
                {
                    "model_state_dict": model_dict,
                    "optimizer_state_dict": opt_dict
                },
                "./pt_mnist_checkpoint.pt",
                partial=False,
            )

    # Waiting the save checkpoint to be finished before run another allgather_object
    smp.barrier()

    if args.assert_losses:
        if use_horovod or use_ddp:
            # SM Distributed: If using data parallelism, gather all losses across different model
            # replicas and check if losses match.
            losses = smp.allgather(test_loss, smp.DP_GROUP)
            for l in losses:
                assert math.isclose(l, losses[0])
            assert test_loss < 0.18
        else:
            assert test_loss < 0.08
def main():
    """Train ``GroupedNet`` on MNIST for one epoch, then checkpoint and sync to S3.

    All hyperparameters are hard-coded (seed 1, batch size 64, Adadelta with
    ``lr=4.0``, ``StepLR`` gamma 0.7, a single epoch). After training, a
    partial ``smp`` checkpoint is written under ``/opt/ml/local_checkpoints``
    and each instance's local rank 0 uploads that directory through
    ``sync_local_checkpoints_to_s3``.

    Raises:
        ValueError: If CUDA is not available.
    """
    if not torch.cuda.is_available():
        raise ValueError(
            "The script requires CUDA support, but CUDA not available")
    use_ddp = True
    use_horovod = False

    # Fix seeds in order to get the same losses across runs
    random.seed(1)
    np.random.seed(1)
    torch.manual_seed(1)
    torch.cuda.manual_seed(1)

    smp.init()

    # SM Distributed: Set the device to the GPU ID used by the current process.
    # Input tensors should be transferred to this device.
    torch.cuda.set_device(smp.local_rank())
    device = torch.device("cuda")
    kwargs = {"batch_size": 64}
    kwargs.update({"num_workers": 1, "pin_memory": True, "shuffle": False})

    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.1307, ), (0.3081, ))])

    # SM Distributed: Download only on a single process per instance.
    # When this is not present, the file is corrupted by multiple processes trying
    # to download and extract at the same time
    if smp.local_rank() == 0:
        dataset1 = datasets.MNIST("../data",
                                  train=True,
                                  download=True,
                                  transform=transform)
    # Every rank waits here, then opens the dataset without downloading.
    smp.barrier()
    dataset1 = datasets.MNIST("../data",
                              train=True,
                              download=False,
                              transform=transform)

    # With data parallelism, give each dp_rank an equal-sized shard of the
    # training set.
    if (use_ddp or use_horovod) and smp.dp_size() > 1:
        partitions_dict = {
            f"{i}": 1 / smp.dp_size()
            for i in range(smp.dp_size())
        }
        dataset1 = SplitDataset(dataset1, partitions=partitions_dict)
        dataset1.select(f"{smp.dp_rank()}")

    # Download and create dataloaders for train and test dataset
    dataset2 = datasets.MNIST("../data", train=False, transform=transform)

    train_loader = torch.utils.data.DataLoader(dataset1, **kwargs)
    test_loader = torch.utils.data.DataLoader(dataset2, **kwargs)

    model = GroupedNet()

    # SMP handles the transfer of parameters to the right device
    # and the user doesn't need to call 'model.to' explicitly.
    # model.to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=4.0)

    # SM Distributed: Use the DistributedModel container to provide the model
    # to be partitioned across different ranks. For the rest of the script,
    # the returned DistributedModel object should be used in place of
    # the model provided for DistributedModel class instantiation.
    model = smp.DistributedModel(model)
    scaler = smp.amp.GradScaler()
    optimizer = smp.DistributedOptimizer(optimizer)

    scheduler = StepLR(optimizer, step_size=1, gamma=0.7)

    # range(1, 2) yields a single epoch (epoch == 1).
    for epoch in range(1, 2):
        train(model, scaler, device, train_loader, optimizer, epoch)
        test_loss = test(model, device, test_loader)
        scheduler.step()

    # Only global rank 0 creates the local checkpoint directory.
    # NOTE(review): on multi-instance jobs other hosts never create it here;
    # confirm whether smp.save creates missing parent directories.
    if smp.rank() == 0:
        if os.path.exists("/opt/ml/local_checkpoints"):
            print("-INFO- PATH DO EXIST")
        else:
            os.makedirs("/opt/ml/local_checkpoints")
            print("-INFO- PATH DO NOT EXIST")

    # Waiting the save checkpoint to be finished before run another allgather_object
    smp.barrier()

    if smp.dp_rank() == 0:
        model_dict = model.local_state_dict()
        opt_dict = optimizer.local_state_dict()
        smp.save(
            {
                "model_state_dict": model_dict,
                "optimizer_state_dict": opt_dict
            },
            f"/opt/ml/local_checkpoints/pt_mnist_checkpoint.pt",
            partial=True,
        )
    # All ranks wait until the save above has completed before syncing.
    smp.barrier()

    if smp.local_rank() == 0:
        print("Start syncing")
        # Two dirname() calls strip the last two path components of
        # SM_MODULE_DIR; an unset variable yields an empty base path.
        base_s3_path = os.path.dirname(
            os.path.dirname(os.getenv("SM_MODULE_DIR", "")))
        curr_host = os.getenv("SM_CURRENT_HOST")
        full_s3_path = f"{base_s3_path}/checkpoints/{curr_host}/"
        sync_local_checkpoints_to_s3(local_path="/opt/ml/local_checkpoints",
                                     s3_path=full_s3_path)
        print("Finished syncing")
def prepare_model_and_optimizer(args, device):
    """Build the BERT pre-training model, optimizer, LR scheduler and criterion.

    Creates ``modeling.BertForPreTraining`` from ``args.config_file`` (vocab
    size padded up to a multiple of 8), optionally wraps it for SMP, optionally
    restores model/optimizer state from a checkpoint, and sets up FusedLAMB,
    the poly warm-up scheduler and (with ``args.fp16``) apex AMP at O2.

    Args:
        args: Parsed run configuration. ``args.resume_step`` is overwritten
            when it is ``-1`` and no ``init_checkpoint`` is given.
        device: Device the model is moved to with ``model.to(device)``.

    Returns:
        Tuple ``(model, optimizer, lr_scheduler, checkpoint, global_step,
        criterion)``; ``checkpoint`` is ``None`` when not resuming.

    Raises:
        ValueError: When resuming without ``init_checkpoint`` and without
            ``s3_checkpoint_uri``.
    """
    # Prepare model
    config = modeling.BertConfig.from_json_file(args.config_file)

    # Padding for divisibility by 8
    if config.vocab_size % 8 != 0:
        config.vocab_size += 8 - (config.vocab_size % 8)

    if args.use_sequential > 0:
        config.use_sequential = True
    else:
        config.use_sequential = False

    modeling.ACT2FN["bias_gelu"] = modeling.bias_gelu_training
    model = modeling.BertForPreTraining(config)
    model.checkpoint_activations(args.checkpoint_activations)
    if args.smp > 0:
        # SMP: Use the DistributedModel container to provide the model
        # to be partitioned across different ranks. For the rest of the script,
        # the returned DistributedModel object should be used in place of
        # the model provided for DistributedModel class instantiation.
        model = smp.DistributedModel(model)

    checkpoint = None
    if not args.resume_from_checkpoint:
        global_step = 0
    else:
        if not args.init_checkpoint:
            if not args.s3_checkpoint_uri:
                raise ValueError(
                    "Need to set s3_checkpoint_uri, if init_checkpoint not set"
                )
            # One process per instance pulls the checkpoints; the others wait
            # at the barrier.
            if smp.local_rank() == 0:
                sync_s3_checkpoints_to_local(args.output_dir,
                                             args.s3_checkpoint_uri)
            smp.barrier()
        if args.resume_step == -1 and not args.init_checkpoint:
            # Pick the highest step number among files whose name contains
            # ".pt"; the step is the token after the first "_" and before
            # ".pt". max() raises ValueError if no such file exists.
            model_names = [
                f for f in os.listdir(args.output_dir) if ".pt" in f
            ]
            args.resume_step = max([
                int(x.split(".pt")[0].split("_")[1].strip())
                for x in model_names
            ])

        global_step = args.resume_step if not args.init_checkpoint else 0

        # SMP: Load a model that was saved with smp.save
        if not args.init_checkpoint:
            checkpoint = smp.load(
                os.path.join(args.output_dir,
                             "ckpt_{}.pt".format(global_step)),
                partial=args.partial_checkpoint,
            )
        else:
            checkpoint = smp.load(args.init_checkpoint)

        model.load_state_dict(checkpoint["model"], strict=False)

        # In phase 2 the step counter restarts relative to the end of phase 1.
        if args.phase2 and not args.init_checkpoint:
            global_step -= args.phase1_end_step
        if is_main_process():
            print("resume step from ", args.resume_step)

    model.to(device)
    param_optimizer = list(model.named_parameters())
    no_decay = ["bias", "gamma", "beta", "LayerNorm"]

    # Two parameter groups: weight decay 0.01 for everything except names
    # containing one of the ``no_decay`` substrings, which get 0.0.
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.01,
        },
        {
            "params":
            [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            "weight_decay":
            0.0,
        },
    ]

    optimizer = FusedLAMB(optimizer_grouped_parameters, lr=args.learning_rate)
    if args.smp > 0:
        # SMP: Use Distributed Optimizer which allows the loading of optimizer state for a distributed model
        # Also provides APIs to obtain local optimizer state for the current mp_rank.
        optimizer = smp.DistributedOptimizer(optimizer)
    lr_scheduler = PolyWarmUpScheduler(optimizer,
                                       warmup=args.warmup_proportion,
                                       total_steps=args.max_steps)
    if args.fp16:
        # loss_scale == 0 selects dynamic loss scaling; any other value is
        # passed through as the loss scale.
        if args.loss_scale == 0:
            model, optimizer = amp.initialize(
                model,
                optimizer,
                opt_level="O2",
                loss_scale="dynamic",
                cast_model_outputs=torch.float16,
            )
        else:
            model, optimizer = amp.initialize(
                model,
                optimizer,
                opt_level="O2",
                loss_scale=args.loss_scale,
                cast_model_outputs=torch.float16,
            )
        # Overrides the scaler's starting value via a private apex attribute.
        amp._amp_state.loss_scalers[0]._loss_scale = args.init_loss_scale

    if args.resume_from_checkpoint:
        if args.phase2 or args.init_checkpoint:
            keys = list(checkpoint["optimizer"]["state"].keys())
            # Override hyperparameters from previous checkpoint
            for key in keys:
                checkpoint["optimizer"]["state"][key]["step"] = global_step
            # NOTE(review): the loop variable shadows the builtin ``iter``.
            for iter, item in enumerate(
                    checkpoint["optimizer"]["param_groups"]):
                checkpoint["optimizer"]["param_groups"][iter][
                    "step"] = global_step
                checkpoint["optimizer"]["param_groups"][iter][
                    "t_total"] = args.max_steps
                checkpoint["optimizer"]["param_groups"][iter][
                    "warmup"] = args.warmup_proportion
                checkpoint["optimizer"]["param_groups"][iter][
                    "lr"] = args.learning_rate
        optimizer.load_state_dict(checkpoint["optimizer"])  # , strict=False)

        # Restore AMP master parameters
        if args.fp16:
            # The optimizer state is loaded a second time after the master
            # weights are initialized.
            # NOTE(review): presumably required so the state attaches to the
            # fp32 master params - confirm before simplifying.
            optimizer._lazy_init_maybe_master_weights()
            optimizer._amp_stash.lazy_init_called = True
            optimizer.load_state_dict(checkpoint["optimizer"])
            for param, saved_param in zip(amp.master_params(optimizer),
                                          checkpoint["master params"]):
                param.data.copy_(saved_param.data)

    # if args.local_rank != -1:
    #     if not args.allreduce_post_accumulation:
    #         model = DDP(model, message_size=250000000, gradient_predivide_factor=get_world_size())
    #     else:
    #         flat_dist_call([param.data for param in model.parameters()], torch.distributed.broadcast, (0,) )
    # elif args.n_gpu > 1:
    #     model = torch.nn.DataParallel(model)

    criterion = BertPretrainingCriterion(config.vocab_size)

    return model, optimizer, lr_scheduler, checkpoint, global_step, criterion
def main(args):
    """Train a DALL-E model with either DeepSpeed or SageMaker model parallelism.

    Selects the distributed backend from ``args.sagemakermp``, builds (or
    restores) the VAE and DALL-E model, creates the dataset/dataloader,
    optimizer and optional LR scheduler, distributes them through the backend
    and runs the training loop for ``args.epochs`` epochs. Checkpoints are
    written through ``save_model`` every 100 iterations and at the end; the
    process with ``args.rank``/``args.global_rank`` equal to 0 logs to
    Weights & Biases.

    Args:
        args: Parsed run configuration. It is also used as a mutable carrier:
            rank information, the SMP-wrapped model/optimizer/scaler and the
            values needed by ``save_model`` are stored on it.
    """
    # constants
    VAE_PATH = args.vae_path
    DALLE_PATH = args.dalle_path
    RESUME = exists(DALLE_PATH)

    EPOCHS = args.epochs
    BATCH_SIZE = args.batch_size

    LEARNING_RATE = args.learning_rate
    GRAD_CLIP_NORM = args.clip_grad_norm
    LR_DECAY = args.lr_decay

    MODEL_DIM = args.dim
    TEXT_SEQ_LEN = args.text_seq_len
    DEPTH = args.depth
    HEADS = args.heads
    DIM_HEAD = args.dim_head
    REVERSIBLE = args.reversible
    LOSS_IMG_WEIGHT = args.loss_img_weight

    ATTN_TYPES = tuple(args.attn_types.split(','))

    DEEPSPEED_CP_AUX_FILENAME = 'auxiliary.pt'

    # initialize distributed backend
    if args.sagemakermp:
        args.deepspeed = False
        using_deepspeed = False
    else:
        args.deepspeed = True

    distr_backend = distributed_utils.set_backend_from_args(args)
    distr_backend.initialize(args)

    if args.sagemakermp:
        args = smp_init(args)
        distributed_utils.using_backend(distributed_utils.SageMakerMPBackend)
    else:
        using_deepspeed = \
            distributed_utils.using_backend(distributed_utils.DeepSpeedBackend)
        # Rank information comes from the launcher's environment variables.
        args.rank = int(os.environ.get('RANK'))
        args.world_size = int(os.environ.get('WORLD_SIZE'))
        args.local_rank = int(os.environ.get('LOCAL_RANK'))
        args.global_rank = args.rank

    logger.debug(f"using_deepspeed : {using_deepspeed}")

    logger.debug(
        f"args.local_rank : {args.local_rank}, args.rank : {args.rank}")

    # tokenizer
    logger.debug(f"exists(args.bpe_path) : {exists(args.bpe_path)}, args.chinese : {args.chinese}")
    if exists(args.bpe_path):
        klass = HugTokenizer if args.hug else YttmTokenizer
        tokenizer = klass(args.bpe_path)
    elif args.chinese:
        tokenizer = ChineseTokenizer()
    else:
        tokenizer = SimpleTokenizer()

    # reconstitute vae
    if RESUME:
        dalle_path = Path(DALLE_PATH)
        if using_deepspeed:
            cp_dir = cp_path_to_dir(dalle_path, 'ds')
            assert cp_dir.is_dir(), \
                f'DeepSpeed checkpoint directory {cp_dir} not found'
            dalle_path = cp_dir / DEEPSPEED_CP_AUX_FILENAME
        else:
            assert dalle_path.exists(), 'DALL-E model file does not exist'
        loaded_obj = torch.load(str(dalle_path), map_location='cpu')

        dalle_params, vae_params, weights = loaded_obj['hparams'], loaded_obj[
            'vae_params'], loaded_obj['weights']

        if vae_params is not None:
            vae = DiscreteVAE(**vae_params)
        else:
            vae_klass = OpenAIDiscreteVAE if not args.taming else VQGanVAE1024
            vae = vae_klass(args)

        dalle_params = dict(**dalle_params)
        IMAGE_SIZE = vae.image_size
    else:
        if exists(VAE_PATH):
            vae_path = Path(VAE_PATH)
            assert vae_path.exists(), 'VAE model file does not exist'
            assert not vae_path.is_dir(), \
                ('Cannot load VAE model from directory; please use a '
                 'standard *.pt checkpoint. '
                 'Currently, merging a DeepSpeed-partitioned VAE into a DALLE '
                 'model is not supported.')

            loaded_obj = torch.load(str(vae_path))

            vae_params, weights = loaded_obj['hparams'], loaded_obj['weights']

            vae = DiscreteVAE(**vae_params)
            vae.load_state_dict(weights)
        else:
            if args.rank == 0:
                # if distr_backend.is_root_worker():
                print('using pretrained VAE for encoding images to tokens')
            vae_params = None
            logger.debug(f"************* args.taming : {args.taming}")
            vae_klass = OpenAIDiscreteVAE if not args.taming else VQGanVAE1024
            vae = vae_klass(args)

        IMAGE_SIZE = vae.image_size

        dalle_params = dict(
            num_text_tokens=tokenizer.vocab_size,
            text_seq_len=TEXT_SEQ_LEN,
            dim=MODEL_DIM,
            depth=DEPTH,
            heads=HEADS,
            dim_head=DIM_HEAD,
            reversible=REVERSIBLE,
            loss_img_weight=LOSS_IMG_WEIGHT,
            attn_types=ATTN_TYPES,
        )

    # configure OpenAI VAE for float16s
    if isinstance(vae, OpenAIDiscreteVAE) and args.fp16:
        vae.enc.blocks.output.conv.use_float16 = True

    # create dataset and dataloader
    is_shuffle = not distributed_utils.using_backend(
        distributed_utils.HorovodBackend)

    ds = TextImageDataset(
        args.image_text_folder,
        text_len=TEXT_SEQ_LEN,
        image_size=IMAGE_SIZE,
        resize_ratio=args.resize_ratio,
        truncate_captions=args.truncate_captions,
        tokenizer=tokenizer,
        shuffle=is_shuffle,
    )
    assert len(ds) > 0, 'dataset is empty'
    # if distr_backend.is_root_worker():
    if args.rank == 0:
        print(f'{len(ds)} image-text pairs found for training')

    if not is_shuffle:
        data_sampler = torch.utils.data.distributed.DistributedSampler(
            ds, num_replicas=args.world_size, rank=args.rank)
    elif args.sagemakermp:
        # split_dataset reads the dataset from ``args.ds``.
        args.ds = ds
        ds = split_dataset(args)
        data_sampler = None
    else:
        data_sampler = None

    print(f"data_sampler : {data_sampler}")

    # uncorrectable NVLink error was detected during the execution --> remove
    kwargs = {'num_workers': args.num_worker, 'pin_memory': True}
    dl = DataLoader(ds,
                    batch_size=BATCH_SIZE,
                    shuffle=is_shuffle,
                    drop_last=True,
                    sampler=data_sampler,
                    **kwargs
                    )
    logger.info("Processes {}/{} ({:.0f}%) of train data".format(
        len(dl.sampler), len(dl.dataset),
        100. * len(dl.sampler) / len(dl.dataset)))

    # initialize DALL-E
    dalle = DALLE(vae=vae, **dalle_params)
    if not using_deepspeed:
        if args.fp16:
            dalle = dalle.half()
        dalle = dalle.cuda()

    if RESUME and not using_deepspeed:
        dalle.load_state_dict(weights)

    # optimizer
    opt = Adam(get_trainable_params(dalle), lr=LEARNING_RATE)

    if LR_DECAY:
        scheduler = ReduceLROnPlateau(
            opt,
            mode="min",
            factor=0.5,
            patience=10,
            cooldown=10,
            min_lr=1e-6,
            verbose=True,
        )

    # if distr_backend.is_root_worker():
    if args.global_rank == 0:
        # experiment tracker
        model_config = dict(depth=DEPTH, heads=HEADS, dim_head=DIM_HEAD)

        logger.debug(f"args.wandb_name : {args.wandb_name}, RESUME : {RESUME}")

        run = wandb.init(
            project=args.wandb_name,  # 'dalle_train_transformer' by default
            resume=RESUME,
            config=model_config,
        )

    # distribute
    distr_backend.check_batch_size(BATCH_SIZE)
    deepspeed_config = {
        'train_batch_size': BATCH_SIZE,
        'gradient_clipping': GRAD_CLIP_NORM,
        'fp16': {
            'enabled': args.fp16,
        },
    }

    (distr_dalle, distr_opt, distr_dl,
     distr_scheduler) = distr_backend.distribute(
         args=args,
         model=dalle,
         optimizer=opt,
         model_parameters=get_trainable_params(dalle),
         training_data=ds if using_deepspeed else dl,
         lr_scheduler=scheduler if LR_DECAY else None,
         config_params=deepspeed_config,
     )
    avoid_model_calls = using_deepspeed and args.fp16

    if args.sagemakermp:
        args.distr_dalle = smp.DistributedModel(distr_dalle)
        args.scaler = smp.amp.GradScaler()
        args.distr_opt = smp.DistributedOptimizer(distr_opt)

    if RESUME and using_deepspeed:
        distr_dalle.load_checkpoint(str(cp_dir))

    # training
    for epoch in range(EPOCHS):
        logger.debug(f"********* epoch : {epoch} **********")
        if data_sampler:
            data_sampler.set_epoch(epoch)

        for i, (text, images) in enumerate(distr_dl):
            if args.fp16:
                images = images.half()

            text, images = map(lambda t: t.cuda(), (text, images))

            if args.sagemakermp:
                args.distr_opt.zero_grad()
                loss = train_step(args, text, images, return_loss=True)
                loss = loss.reduce_mean()
            else:
                loss = distr_dalle(text, images, return_loss=True, args=args)

            if using_deepspeed:
                distr_dalle.backward(loss)
                distr_dalle.step()
                # Gradients are automatically zeroed after the step
            elif args.sagemakermp:
                if args.amp:
                    # The GradScaler lives on ``args`` (set when the model was
                    # wrapped above); there is no local ``scaler`` in this
                    # function, so the bare name raised NameError.
                    args.scaler.step(args.distr_opt)
                    args.scaler.update()
                else:
                    # some optimizers like adadelta from PT 1.8 dont like it when optimizer.step is called with no param
                    if len(list(args.distr_dalle.local_parameters())) > 0:
                        args.distr_opt.step()
            else:
                loss.backward()
                clip_grad_norm_(distr_dalle.parameters(), GRAD_CLIP_NORM)
                distr_opt.step()
                distr_opt.zero_grad()

            # Collective loss, averaged
            avg_loss = distr_backend.average_all(loss)

            log = {}

            # if i % 10 == 0 and distr_backend.is_root_worker():
            if i % 10 == 0 and args.rank == 0:
                print(epoch, i, f'loss - {avg_loss.item()}')

                log = {
                    **log, 'epoch': epoch,
                    'iter': i,
                    'loss': avg_loss.item()
                }

            if i % 100 == 0:
                # if distr_backend.is_root_worker():
                if args.rank == 0:
                    sample_text = text[:1]
                    token_list = sample_text.masked_select(
                        sample_text != 0).tolist()
                    decoded_text = tokenizer.decode(token_list)

                    logger.debug(f"******* avoid_model_calls : {avoid_model_calls}")
                    if not avoid_model_calls:
                        # CUDA index errors when we don't guard this
                        image = dalle.generate_images(
                            text[:1], filter_thres=0.9)  # topk sampling at 0.9

                    wandb.save(f'./dalle.pt')

                    log = {
                        **log,
                    }
                    if not avoid_model_calls:
                        log['image'] = wandb.Image(image, caption=decoded_text)

                # save_model takes everything it needs from ``args``.
                # NOTE(review): this rebinds args.distr_dalle to the backend's
                # model, replacing the smp.DistributedModel stored above when
                # args.sagemakermp is set - confirm this is intended.
                args.distr_dalle = distr_dalle
                args.dalle_params = dalle_params
                args.vae_params = vae_params
                args.using_deepspeed = using_deepspeed
                args.DEEPSPEED_CP_AUX_FILENAME = DEEPSPEED_CP_AUX_FILENAME

                save_model(args, f'{args.model_dir}/dalle.pt')

            # if distr_backend.is_root_worker():
            if args.rank == 0:
                wandb.log(log)

            # text, images = prefetcher.next()

        if LR_DECAY and not using_deepspeed:
            # Scheduler is automatically progressed after the step when
            # using DeepSpeed.
            distr_scheduler.step(loss)

        # if distr_backend.is_root_worker():
        if args.global_rank == 0:
            # save trained model to wandb as an artifact every epoch's end

            model_artifact = wandb.Artifact('trained-dalle',
                                            type='model',
                                            metadata=dict(model_config))
            # model_artifact.add_file('dalle.pt')
            run.log_artifact(model_artifact)

    args.distr_dalle = distr_dalle
    args.dalle_params = dalle_params
    args.vae_params = vae_params
    args.using_deepspeed = using_deepspeed
    args.DEEPSPEED_CP_AUX_FILENAME = DEEPSPEED_CP_AUX_FILENAME

    save_model(args, f'{args.model_dir}/dalle-final.pt')

    # if distr_backend.is_root_worker():
    if args.global_rank == 0:
        wandb.save('./dalle-final.pt')
        model_artifact = wandb.Artifact('trained-dalle',
                                        type='model',
                                        metadata=dict(model_config))
        # model_artifact.add_file('dalle-final.pt')
        run.log_artifact(model_artifact)

        wandb.finish()
def main():
    """Train a GPT-2 style causal LM with the SageMaker model parallel library.

    Parses arguments, builds the ``smp`` configuration (its keys depend on
    ``args.smp_version``), initializes ``smp``, creates the model from a
    ``GPT2Config``, wraps it in ``smp.DistributedModel``, optionally assigns
    transformer layers to pipeline partitions manually, builds the optimizer
    and LR scheduler, optionally loads a checkpoint, runs ``train`` and
    optionally saves the final full model.

    Raises:
        ValueError: If optimizer state sharding is enabled without
            ``skip_full_optimizer``.
        AssertionError: On inconsistent ``partition_assignment`` or, with
            ``args.ci``, when a metric threshold is missed.
    """
    args = parse_args()

    if args.shard_optimizer_state > 0 and not args.skip_full_optimizer:
        raise ValueError(
            "If shard_optimizer_state is enabled, skip_full_optimizer must also be enabled. Full optimizer saving is currently not supported under optimizer state sharding."
        )

    # An explicit partition assignment implies manual partitioning.
    if args.partition_assignment != "" and args.manual_partition == 0:
        print("[Warning] partition_assignment is set, enable manual_partition")
        args.manual_partition = 1

    # any value here is overridden by the config set in notebook when launching the sagemaker job
    smp_config = {
        "ddp": True,
        "tensor_parallel_degree": args.tensor_parallel_degree,
        "pipeline_parallel_degree": args.pipeline_parallel_degree,
        "microbatches": args.microbatches,
        # if activation_checkpointing true checkpoints transformer layers below
        "checkpoint_attentions":
        False if args.activation_checkpointing else True,
        "shard_optimizer_state": args.shard_optimizer_state > 0,
        "prescaled_batch": args.prescaled_batch > 0,
        "offload_activations": args.offload_activations > 0,
        "optimize": args.optimize,
        "auto_partition": False if args.manual_partition else True,
        "default_partition": 0,
        "static_mode": args.static_mode > 0,
        "fast_mode": args.fast_mode > 0,
    }

    # Versions below 110 take "fp16_params"; newer ones take "fp16" plus the
    # additional keys set in the else branch.
    if args.smp_version < 110:
        smp_config["fp16_params"] = args.fp16 > 0
    else:
        smp_config["fp16"] = args.fp16 > 0
        smp_config["delayed_parameter_initialization"] = args.delayed_param > 0
        smp_config["placement_strategy"] = args.placement_strategy
        smp_config[
            "activation_loading_horizon"] = args.activation_loading_horizon
        smp_config["skip_tracing"] = args.skip_tracing > 0

    if args.active_microbatches is not None:
        smp_config["active_microbatches"] = args.active_microbatches

    smp.init(smp_config)

    if smp.rank() == 0:
        print("Arguments:", args.__dict__)
        print(f"Transformers version: {transformers.__version__}")
        print(
            f"smdistributed.modelparallel version: {smdistributed.modelparallel.__version__}"
        )
        print(f"smdistributed config: {smp_config}")

    if args.save_final_full_model and smp.rank() == 0:
        print(
            f"[Warning] Note that save_final_full_model only saves the final model at the end of all steps. It does not save optimizer state. Optimizer state is only saved with partial models which are saved at checkpointing_freq during training. If you want to restart training you need partial checkpoints."
        )

    # One comma-separated layer count per pipeline-parallel rank.
    if args.partition_assignment != "":
        partition_assignment = args.partition_assignment.split(",")
        assert (
            len(partition_assignment) == smp.pp_size()
        ), f"partition_assignment must have the same size as pipeline parallel degree, but getting {len(partition_assignment)} vs {smp.pp_size()}"

    # Global rank 0 creates the output directories; when use_fsx is 0, local
    # rank 0 of each instance does so as well.
    if smp.rank() == 0 or (smp.local_rank() == 0 and args.use_fsx == 0):
        for path in [args.model_dir, args.checkpoint_dir]:
            if not os.path.exists(path):
                os.makedirs(path, exist_ok=True)

    model_config = GPT2Config(
        vocab_size=args.vocab_size,
        n_positions=args.max_context_width,
        n_embd=args.hidden_width,
        n_layer=args.num_layers,
        n_head=args.num_heads,
        n_inner=None,
        activation_function="gelu_new",
        resid_pdrop=args.resid_pdrop,
        embd_pdrop=args.embd_pdrop,
        attn_pdrop=args.attn_pdrop,
        layer_norm_epsilon=1e-05,
        initializer_range=0.02,
        summary_type="cls_index",
        summary_use_proj=True,
        summary_activation=None,
        summary_proj_to_labels=True,
        summary_first_dropout=args.summary_first_pdrop,
        # gradient_checkpointing=args.gradient_checkpointing > 0,
        use_cache=False,
        bos_token_id=50256,
        eos_token_id=50256,
        return_dict=True,
    )

    # the following improves start-up time by skipping proper initialization
    # of weights in the original model. this is not a problem because DistributedModel
    # will override those weights anyway when tensor_parallel_degree > 1.
    if smp.tp_size() > 1:
        from transformers.modeling_utils import PreTrainedModel

        # Patches the class attribute, so it affects every PreTrainedModel
        # created afterwards in this process.
        PreTrainedModel.init_weights = lambda x: None

    set_seed(args.seed)

    if args.enable_memory_profiling > 0:
        memory_status_cpu(msg="before model creation")

    # Model creation differs by library version: nested tensor_parallelism /
    # delay_param_initialization contexts below 110, a single model_creation
    # context otherwise.
    if args.smp_version < 110:
        if args.fp16:
            torch.set_default_dtype(torch.float16)
        with smp.tensor_parallelism(
                enabled=smp.tp_size() > 1,
                attention_in_fp32=args.attention_in_fp32 > 0):
            with smp.delay_param_initialization(
                    enabled=(smp.tp_size() > 1 and args.delayed_param > 0)):
                model = AutoModelForCausalLM.from_config(model_config)
    else:
        with smp.model_creation(
                tensor_parallelism=smp.tp_size() > 1,
                attention_in_fp32=args.attention_in_fp32 > 0,
                query_key_layer_scaling=args.query_key_layer_scaling > 0,
                fused_softmax=args.fused_softmax > 0,
                fused_bias_gelu=args.fused_bias_gelu > 0,
                dtype=torch.float16
                if args.fp16 else torch.get_default_dtype(),
        ):
            model = AutoModelForCausalLM.from_config(model_config)

    if args.smp_version < 110 and args.fp16:
        model = FP16_Module(model)

    if args.enable_memory_profiling > 0:
        memory_status_cpu(msg="after model creation")

    # Parameter count taken before the model is wrapped by smp.
    num_params = sum([np.prod(p.size()) for p in model.parameters()])
    if smp.rank() == 0:
        print(f"# total parameters: {num_params}")

    # smdistributed: Set the device to the GPU ID used by the current process.
    # Input tensors should be transferred to this device.
    torch.cuda.set_device(smp.local_rank())
    device = torch.device("cuda")

    if not args.same_seed:
        # Set seed by tp_rank to prevent weights from being the same on different tp_ranks
        set_seed(args.seed + smp.tp_rank())

    # smdistributed: Use the DistributedModel container to provide the model
    # to be partitioned across different ranks. For the rest of the script,
    # the returned DistributedModel object should be used in place of
    # the model provided for DistributedModel class instantiation.
    if args.smp_version < 110 and args.fp16:
        torch.set_default_dtype(torch.float16)
    if args.enable_memory_profiling > 0:
        memory_status_cpu(msg="before dist model creation")
    model = smp.DistributedModel(model, trace_device="gpu")
    if args.enable_memory_profiling > 0:
        memory_status_cpu(msg="after dist model creation")

    # Locate the transformer layer container; the attribute path depends on
    # the library version and on whether tensor parallelism is active.
    if args.smp_version < 110:
        if smp.tp_size() > 1:
            transformer_layers = model.module.module.module.transformer.seq_layers
        else:
            transformer_layers = model.module.module.module.transformer.h
    else:
        m = model.get_module()
        if smp.tp_size() > 1:
            transformer_layers = m.transformer.seq_layers
        else:
            transformer_layers = m.transformer.h

    if args.manual_partition:
        print(f"Manual partition enabled")
        if args.partition_assignment != "":
            get_num_layers = lambda x: int(partition_assignment[x])
            total_layers = sum(
                [get_num_layers(pp_rank) for pp_rank in range(smp.pp_size())])
            assert (
                total_layers == args.num_layers
            ), f"partition_assignment must have the same total transformer layers as model, but getting {total_layers} vs {args.num_layers}"
        else:
            # evenly distribute layers across all partitions
            # The last ``rem`` partitions each receive one extra layer.
            div, rem = divmod(args.num_layers, smp.pp_size())
            get_num_layers = lambda x: (div + 1
                                        if x >= smp.pp_size() - rem else div)
        assignments = []
        # (TODO) This is required for 175B otherwise a hang for partition "8,17,17,18,18,18"
        # Need further investigation
        # for pp_rank in reversed(range(smp.pp_size())):
        for pp_rank in range(smp.pp_size()):
            nl = get_num_layers(pp_rank)
            print(f"{nl} layers assigned to partition {pp_rank}")
            assignments += [pp_rank for _ in range(nl)]

        # assignments[i] is the partition for the i-th child layer.
        for i, c in enumerate(transformer_layers.children()):
            smp.set_partition(c, assignments[i])

    if args.smp_version < 110:
        iter_model = model
        # Build parameter groups (weight decay and non-decay).
        while isinstance(iter_model, (DistributedDataParallel, FP16_Module)):
            iter_model = iter_model.module
    else:
        iter_model = m
    param_groups = get_param_groups_by_weight_decay(iter_model)

    if args.use_adamw > 0:
        optimizer = optim.AdamW(param_groups,
                                betas=(args.beta1, args.beta2),
                                lr=args.lr,
                                weight_decay=args.weight_decay)
    else:
        optimizer = optim.Adam(param_groups,
                               betas=(args.beta1, args.beta2),
                               lr=args.lr,
                               weight_decay=args.weight_decay)

    if args.activation_checkpointing:
        kwargs = {}
        if isinstance(transformer_layers, nn.Sequential):
            kwargs["pack_args_as_tuple"] = True
            kwargs["strategy"] = args.activation_strategy
        smp.set_activation_checkpointing(transformer_layers, **kwargs)

    # Below version 110 loss scaling is handled by FP16_Optimizer; otherwise
    # the loss-scaling options go straight to smp.DistributedOptimizer.
    # NOTE(review): the <110 branch is not additionally gated on args.fp16 -
    # confirm that is intended.
    if args.smp_version < 110:
        optimizer = FP16_Optimizer(
            model,
            optimizer,
            static_loss_scale=None,
            dynamic_loss_scale=True,
            use_smp=True,
            dynamic_loss_args={
                "scale_window": 1000,
                "min_scale": 1,
                "delayed_shift": 2
            },
            params_have_main_grad=False,
            shard_optimizer_state=args.shard_optimizer_state > 0,
        )
        optimizer = smp.DistributedOptimizer(optimizer)
        model.register_post_step_hook(
            lambda model, optimizer: optimizer.init_master_params())
    else:
        optimizer = smp.DistributedOptimizer(
            optimizer,
            static_loss_scale=None,
            dynamic_loss_scale=True,
            dynamic_loss_args={
                "scale_window": 1000,
                "min_scale": 1,
                "delayed_shift": 2
            },
        )
    lr_scheduler = get_learning_rate_scheduler(optimizer, args)

    if args.enable_memory_profiling > 0:
        model.register_post_partition_hook(
            lambda model, optimizer: memory_status(msg="After_partition"))

    # load after wrapping model and optimizer with smp Distributed...
    if args.load_full or args.load_partial:
        if args.load_partial and args.load_full:
            print(
                "Since both --load_partial and --load_full set, will try to load from full checkpoint."
                "If the intention is to load from partial checkpoint, please don't set --load_full"
            )
        # A full checkpoint is read from model_dir, a partial one from
        # checkpoint_dir; translate_from_hf is requested only for full ones.
        partial = not args.load_full
        path = args.checkpoint_dir if partial else args.model_dir
        translate_from_hf = not partial
        model, optimizer, total_steps, start_train_path_index, start_batch_index = load_model_and_optimizer(
            path,
            model,
            optimizer,
            lr_scheduler,
            partial,
            args,
            translate_from_hf=translate_from_hf,
            seq_length=args.max_context_width,
            load_model=True,
            load_optimizer=args.load_partial > 0,
            num_params=num_params,
        )
    else:
        total_steps = 0
        start_train_path_index = 0
        start_batch_index = 0

    start = time.time()
    total_steps, throughput, loss = train(
        model,
        optimizer,
        lr_scheduler,
        model_config,
        start_train_path_index,
        start_batch_index,
        num_params,
        total_steps,
        args,
    )
    time_to_train = time.time() - start
    if args.ci:
        print(f"[SMP_METRIC]__GPT2__Time_to_train__{time_to_train}")
        print(f"[SMP_METRIC]__GPT2__samples/second__{throughput}")
        print(f"[SMP_METRIC]__GPT2__Loss__{loss}")
        # Thresholds are only enforced for runs that started from scratch.
        if not args.load_partial and not args.load_full:
            assert time_to_train < args.time_to_train
            assert throughput > args.throughput
            if args.loss:
                assert loss < args.loss

    if args.save_final_full_model:
        # saves full model at the end

        base_path = f"trained_gpt_nparams-{num_params}_steps-{total_steps}.pt"
        out_path = os.path.join(args.model_dir, base_path)

        if smp.rdp_rank() == 0:
            save(
                out_path,
                model,
                optimizer,
                lr_scheduler,
                model_config,
                num_params,
                total_steps,
                -1,
                args,
                partial=False,
                translate_to_hf=smp.tp_size() > 1,
                seq_length=args.max_context_width,
            )

    smp.barrier()
    if smp.rank() == 0:
        print("SMP training finished successfully")
# Decoder handle, optimizer and LR schedule are all created from the VAE
# before it is moved to a device or wrapped.
get_hard_recons = vae.decode
opt = Adam(vae.parameters(), lr=LEARNING_RATE)
sched = ExponentialLR(optimizer=opt, gamma=LR_DECAY_RATE)

logger.debug(f"args.local_rank : {args.local_rank}")

# Use GPU 0 when no local rank was provided.
torch.cuda.set_device(0 if args.local_rank is None else args.local_rank)

if args.multigpus_distributed:
    vae.cuda(args.local_rank)

if args.multigpus_distributed and args.model_parallel:
    # Model-parallel path: wrap the model and optimizer with SMP and keep the
    # gradient scaler on ``args``.
    vae = smp.DistributedModel(vae)
    args.scaler = smp.amp.GradScaler()
    opt = smp.DistributedOptimizer(opt)

    # A partial checkpoint takes precedence over a full one.
    ckpt_source = None
    if args.partial_checkpoint:
        ckpt_source = (args.partial_checkpoint, True)
    elif args.full_checkpoint:
        ckpt_source = (args.full_checkpoint, False)

    if ckpt_source is not None:
        ckpt_path, is_partial = ckpt_source
        args.checkpoint = smp.load(ckpt_path, partial=is_partial)
        vae.load_state_dict(args.checkpoint["model_state_dict"])
        opt.load_state_dict(args.checkpoint["optimizer_state_dict"])
else:
    # Every other configuration simply moves the VAE to the current device.
    vae = vae.cuda()
def init_train():
    """Train the PyTorch tabular model with the smp model-parallel backend.

    Reads the dataset paths and hyper-parameters from the module-level
    ``args``, shards the training set across data-parallel ranks, wraps the
    model and optimizer in their smp containers, runs ``train`` and finally
    writes the model ``state_dict`` to ``args.model_dir + "/model.pth"``.
    """
    cat_mask = [
        False, True, True, True, True, False, True, True, True, True, True,
        False, False, False, False, False, False, False
    ]

    train_ds = CsvDatasetSimple(args.train)
    # NOTE(review): the test dataset is loaded but never used in this function.
    test_ds = CsvDatasetSimple(args.test)

    batch_size = args.batch_size
    epochs = args.epochs
    learning_rate = args.learning_rate
    logger.info("batch_size = {}, epochs = {}, learning rate = {}".format(
        batch_size, epochs, learning_rate))

    # smdistributed: initialize the backend
    smp.init()

    # smdistributed: Set the device to the GPU ID used by the current process.
    # Input tensors should be transferred to this device.
    torch.cuda.set_device(smp.local_rank())
    device = torch.device("cuda")

    dataset = train_ds

    # smdistributed: Shard the dataset based on data-parallel ranks
    if smp.dp_size() > 1:
        partitions_dict = {
            f"{i}": 1 / smp.dp_size()
            for i in range(smp.dp_size())
        }
        dataset = SplitDataset(dataset, partitions=partitions_dict)
        dataset.select(f"{smp.dp_rank()}")

    # smdistributed: Set drop_last=True to ensure that batch size is always
    # divisible by the number of microbatches.
    # The configured batch size is used here; it was previously hard-coded to
    # 64, which silently ignored the value that is logged above.
    train_loader = torch.utils.data.DataLoader(dataset,
                                               batch_size=batch_size,
                                               drop_last=True)

    model = TabularNet(n_cont=9,
                       n_cat=9,
                       cat_mask=cat_mask,
                       cat_dim=[
                           0, 2050, 13, 5, 366, 0, 50000, 50000, 50000, 50000,
                           50, 0, 0, 0, 0, 0, 0, 0
                       ],
                       y_min=0.,
                       y_max=1.)
    logger.debug(model)

    # NOTE(review): `epochs` and `learning_rate` are logged above but not used
    # below; the optimizer uses a fixed lr=4.0 -- confirm this is intended.
    optimizer = optim.Adadelta(model.parameters(), lr=4.0)

    # smdistributed: Use the DistributedModel container to provide the model
    # to be partitioned across different ranks. For the rest of the script,
    # the returned DistributedModel object should be used in place of
    # the model provided for DistributedModel class instantiation.
    model = smp.DistributedModel(model)
    optimizer = smp.DistributedOptimizer(optimizer)

    train(model, device, train_loader, optimizer)

    # NOTE(review): every process writes to the same path here; confirm
    # whether saving should be restricted to a single rank.
    torch.save(model.state_dict(), args.model_dir + "/model.pth")
def main():
    """Parse arguments, wrap model/optimizer with smp and run ``SMPTrainer``.

    Steps performed, in order: build the model, tokenizer and datasets,
    initialize smp when SageMaker model parallelism is enabled, wrap the model
    in ``smp.DistributedModel``, build the (AdamW or Adam) optimizer over the
    weight-decay parameter groups, wrap it in ``smp.DistributedOptimizer`` and
    hand everything to ``SMPTrainer.train_smp``.
    """
    model_args, data_args, training_args, smp_args = parse_args()
    model, tokenizer = initialize_model_and_tokenizer(model_args)

    # Get datasets
    train_dataset, eval_dataset = Preprocess.datasets(model_args, data_args,
                                                      training_args)

    # NOTE(review): smp is used unconditionally below, so this script
    # presumably only works when SageMaker model parallelism is enabled.
    if is_sagemaker_mp_enabled():
        initialize_smp(smp_args, training_args)

    torch.set_default_dtype(torch.float32)
    num_params = print_num_parameters(model)

    # smdistributed: Set the device to the GPU ID used by the current process.
    # Input tensors should be transferred to this device.
    torch.cuda.set_device(smp.local_rank())

    if not training_args.same_seed:
        # Set seed by tp_rank to prevent weights from being the same on
        # different tp_ranks
        set_seed(training_args.seed + smp.tp_rank())

    model = smp.DistributedModel(model,
                                 trace_device=smp_args.trace_device,
                                 gradient_as_bucket_view=True)
    torch.set_default_dtype(torch.float32)

    # Unwrap any DDP / FP16 containers to reach the module whose parameters
    # are split into weight-decay and non-decay groups.
    iter_model = model
    while isinstance(iter_model, (DistributedDataParallel, FP16_Module)):
        iter_model = iter_model.module
    param_groups = get_param_groups_by_weight_decay(iter_model)

    # Both optimizers take identical arguments; only the class differs.
    # The AdamW class comes from `optim` (it was previously looked up on the
    # arguments object, which is not where optimizer classes live).
    optimizer_cls = optim.AdamW if training_args.use_adamw > 0 else optim.Adam
    optimizer = optimizer_cls(
        param_groups,
        betas=(training_args.beta1, training_args.beta2),
        lr=training_args.lr,
        weight_decay=training_args.weight_decay,
    )
    optimizer = smp.DistributedOptimizer(optimizer)
    lr_scheduler = get_learning_rate_scheduler(optimizer, training_args)

    # Training always starts from scratch here (no checkpoint resume).
    total_steps = 0
    start_train_path_index = 0
    start_batch_index = 0

    # Initialize Trainer instance
    trainer = SMPTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=eval_dataset if training_args.do_eval else None,
        tokenizer=tokenizer,
        data_collator=default_data_collator,
    )

    trainer.train_smp(
        model,
        optimizer,
        lr_scheduler,
        start_train_path_index,
        start_batch_index,
        num_params,
        total_steps,
        training_args,
        prescaled_batch=smp_args.prescaled_batch,
    )
def main():
    """Train a DiscreteVAE with SageMaker model parallelism (smp) or DeepSpeed.

    Parses the command-line arguments, initializes either the smp backend
    (``args.model_parallel``) or DeepSpeed, builds an ``ImageFolder`` input
    pipeline and runs ``args.EPOCHS`` epochs of training. Every 10 iterations
    a checkpoint is written to ``{args.model_dir}/vae.pt``, the temperature is
    annealed and the learning-rate schedule is stepped.

    Raises:
        ValueError: if CUDA is not available.
        AssertionError: if the dataset is empty, or if ``args.assert_losses``
            is set and the final loss check fails.
    """
    parser = get_parser()
    args = parser.parse_args()

    if not torch.cuda.is_available():
        raise ValueError(
            "The script requires CUDA support, but CUDA not available")

    args.rank = -1
    args.world_size = 1

    if args.model_parallel:
        args.deepspeed = False
        cfg = {
            "microbatches": args.num_microbatches,
            "placement_strategy": args.placement_strategy,
            "pipeline": args.pipeline,
            "optimize": args.optimize,
            "partitions": args.num_partitions,
            "horovod": args.horovod,
            "ddp": args.ddp,
        }
        smp.init(cfg)
        torch.cuda.set_device(smp.local_rank())
        args.rank = smp.dp_rank()
        args.world_size = smp.size()
    else:
        # initialize deepspeed
        print(f"args.deepspeed : {args.deepspeed}")
        deepspeed_utils.init_deepspeed(args.deepspeed)
        if deepspeed_utils.is_root_worker():
            args.rank = 0

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed + args.rank)
        np.random.seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    args.kwargs = {'num_workers': args.num_worker, 'pin_memory': True}
    device = torch.device("cuda")

    logger.debug(f"args.image_folder : {args.image_folder}")
    logger.debug(f"args.rank : {args.rank}")

    # SageMaker: when SM_MODEL_DIR is set, take the model directory and the
    # training channel from the environment instead of the CLI arguments.
    sm_model_dir = os.environ.get('SM_MODEL_DIR')
    if sm_model_dir is not None:
        args.model_dir = sm_model_dir
        args.image_folder = os.environ.get('SM_CHANNEL_TRAINING')
    else:
        logger.debug("not SageMaker")

    IMAGE_SIZE = args.image_size
    IMAGE_PATH = args.image_folder
    EPOCHS = args.EPOCHS
    BATCH_SIZE = args.BATCH_SIZE
    LEARNING_RATE = args.LEARNING_RATE
    LR_DECAY_RATE = args.LR_DECAY_RATE
    NUM_TOKENS = args.NUM_TOKENS
    NUM_LAYERS = args.NUM_LAYERS
    NUM_RESNET_BLOCKS = args.NUM_RESNET_BLOCKS
    SMOOTH_L1_LOSS = args.SMOOTH_L1_LOSS
    EMB_DIM = args.EMB_DIM
    HID_DIM = args.HID_DIM
    KL_LOSS_WEIGHT = args.KL_LOSS_WEIGHT
    STARTING_TEMP = args.STARTING_TEMP
    TEMP_MIN = args.TEMP_MIN
    ANNEAL_RATE = args.ANNEAL_RATE
    NUM_IMAGES_SAVE = args.NUM_IMAGES_SAVE

    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB')
                 if img.mode != 'RGB' else img),
        T.Resize(IMAGE_SIZE),
        T.CenterCrop(IMAGE_SIZE),
        T.ToTensor()
    ])

    # data
    logger.debug(f"IMAGE_PATH : {IMAGE_PATH}")
    ds = ImageFolder(
        IMAGE_PATH,
        transform=transform,
    )

    # Shard the dataset across data-parallel ranks when smp runs with
    # data parallelism (ddp or horovod).
    if args.model_parallel and (args.ddp or args.horovod) and smp.dp_size() > 1:
        partitions_dict = {
            f"{i}": 1 / smp.dp_size()
            for i in range(smp.dp_size())
        }
        ds = SplitDataset(ds, partitions=partitions_dict)
        ds.select(f"{smp.dp_rank()}")

    dl = DataLoader(ds,
                    BATCH_SIZE,
                    shuffle=True,
                    drop_last=args.model_parallel,
                    **args.kwargs)

    vae_params = dict(image_size=IMAGE_SIZE,
                      num_layers=NUM_LAYERS,
                      num_tokens=NUM_TOKENS,
                      codebook_dim=EMB_DIM,
                      hidden_dim=HID_DIM,
                      num_resnet_blocks=NUM_RESNET_BLOCKS)

    vae = DiscreteVAE(**vae_params,
                      smooth_l1_loss=SMOOTH_L1_LOSS,
                      kl_div_loss_weight=KL_LOSS_WEIGHT).to(device)

    # optimizer
    opt = Adam(vae.parameters(), lr=LEARNING_RATE)
    sched = ExponentialLR(optimizer=opt, gamma=LR_DECAY_RATE)

    if args.model_parallel:
        import copy

        # Plain (non-distributed) copies of the codebook and decoder, taken
        # before the model is wrapped; used for the hard reconstructions.
        dummy_codebook = copy.deepcopy(vae.codebook)
        dummy_decoder = copy.deepcopy(vae.decoder)

        vae = smp.DistributedModel(vae)
        scaler = smp.amp.GradScaler()
        opt = smp.DistributedOptimizer(opt)

        # A partial checkpoint takes precedence over a full one.
        if args.partial_checkpoint:
            args.checkpoint = smp.load(args.partial_checkpoint, partial=True)
            vae.load_state_dict(args.checkpoint["model_state_dict"])
            opt.load_state_dict(args.checkpoint["optimizer_state_dict"])
        elif args.full_checkpoint:
            args.checkpoint = smp.load(args.full_checkpoint, partial=False)
            vae.load_state_dict(args.checkpoint["model_state_dict"])
            opt.load_state_dict(args.checkpoint["optimizer_state_dict"])

    assert len(ds) > 0, 'folder does not contain any images'
    if (not args.model_parallel) and args.rank == 0:
        print(f'{len(ds)} images found for training')

    def save_model(path):
        """Save hyper-parameters and weights to *path* (rank 0 only)."""
        if not args.rank == 0:
            return

        save_obj = {'hparams': vae_params, 'weights': vae.state_dict()}
        torch.save(save_obj, path)

    # distribute with deepspeed
    if not args.model_parallel:
        deepspeed_utils.check_batch_size(BATCH_SIZE)
        deepspeed_config = {'train_batch_size': BATCH_SIZE}

        (distr_vae, opt, dl, sched) = deepspeed_utils.maybe_distribute(
            args=args,
            model=vae,
            optimizer=opt,
            model_parameters=vae.parameters(),
            training_data=ds if args.deepspeed else dl,
            lr_scheduler=sched,
            config_params=deepspeed_config,
        )

    # The smp step functions are only needed (and only used) in
    # model-parallel mode, so they are defined under an explicit guard
    # instead of a blanket try/except that would hide real errors.
    if args.model_parallel:

        # Rubik: Define smp.step. Return any tensors needed outside.
        @smp.step
        def train_step(vae, images, temp):
            with autocast(enabled=(args.amp > 0)):
                loss, recons = vae(images,
                                   return_loss=True,
                                   return_recons=True,
                                   temp=temp)

            scaled_loss = scaler.scale(loss) if args.amp else loss
            vae.backward(scaled_loss)
            return loss, recons

        @smp.step
        def get_codes_step(vae, images, k):
            images = images[:k]
            logits = vae.forward(images, return_logits=True)
            codebook_indices = logits.argmax(dim=1).flatten(1)
            return codebook_indices

        def hard_recons_step(dummy_decoder, dummy_codebook, codebook_indices):
            from functools import partial

            def restore_original_forward(root):
                # Rebind each submodule's `forward` to the original method
                # returned by the smp patch manager.
                for module in root.modules():
                    method = smp_state.patch_manager.get_original_method(
                        "forward", type(module))
                    module.forward = partial(method, module)

            restore_original_forward(dummy_codebook)
            image_embeds = dummy_codebook.forward(codebook_indices)

            # Fold the flat token axis back into a square h x w grid.
            _, n, _ = image_embeds.shape
            h = w = int(sqrt(n))
            image_embeds = rearrange(image_embeds,
                                     'b (h w) d -> b d h w',
                                     h=h,
                                     w=w)

            restore_original_forward(dummy_decoder)
            hard_recons = dummy_decoder.forward(image_embeds)
            return hard_recons

    # starting temperature
    global_step = 0
    temp = STARTING_TEMP

    for epoch in range(EPOCHS):
        batch_time = util.AverageMeter('Time', ':6.3f')
        data_time = util.AverageMeter('Data', ':6.3f')
        losses = util.AverageMeter('Loss', ':.4e')
        top1 = util.AverageMeter('Acc@1', ':6.2f')
        top5 = util.AverageMeter('Acc@5', ':6.2f')
        progress = util.ProgressMeter(
            len(dl), [batch_time, data_time, losses, top1, top5],
            prefix="Epoch: [{}]".format(epoch))

        vae.train()
        start = time.time()

        for i, (images, _) in enumerate(dl):
            images = images.to(device, non_blocking=True)
            opt.zero_grad()

            if args.model_parallel:
                loss, recons = train_step(vae, images, temp)
                # Rubik: Average the loss across microbatches.
                loss = loss.reduce_mean()
                recons = recons.reduce_mean()
            else:
                loss, recons = distr_vae(images,
                                         return_loss=True,
                                         return_recons=True,
                                         temp=temp)

            if (not args.model_parallel) and args.deepspeed:
                # Gradients are automatically zeroed after the step
                distr_vae.backward(loss)
                distr_vae.step()
            elif args.model_parallel:
                if args.amp:
                    scaler.step(opt)
                    scaler.update()
                else:
                    # some optimizers like adadelta from PT 1.8 dont like it
                    # when optimizer.step is called with no param
                    if len(list(vae.local_parameters())) > 0:
                        opt.step()
            else:
                loss.backward()
                opt.step()

            # Collected for experiment tracking; currently not sent anywhere.
            logs = {}

            if i % 10 == 0:
                if args.rank == 0:
                    k = NUM_IMAGES_SAVE

                    with torch.no_grad():
                        if args.model_parallel:
                            # Copy the current decoder/codebook weights into
                            # the plain copies by stripping the sub-module
                            # prefix from the state-dict keys.
                            model_dict = vae.state_dict()
                            model_dict_updated = {}
                            for key, val in model_dict.items():
                                if "decoder" in key:
                                    key = key.replace("decoder.", "")
                                elif "codebook" in key:
                                    key = key.replace("codebook.", "")
                                model_dict_updated[key] = val

                            dummy_decoder.load_state_dict(model_dict_updated,
                                                          strict=False)
                            dummy_codebook.load_state_dict(model_dict_updated,
                                                           strict=False)
                            codes = get_codes_step(vae, images, k)
                            codes = codes.reduce_mean().to(torch.long)
                            hard_recons = hard_recons_step(
                                dummy_decoder, dummy_codebook, codes)
                        else:
                            codes = vae.get_codebook_indices(images[:k])
                            hard_recons = vae.decode(codes)

                    # Preview grids. Separate names are used so that the
                    # training batch `images` is not overwritten: its size is
                    # still needed for the loss meter below. The grids are
                    # not consumed anywhere at the moment.
                    sample_images, sample_recons = map(lambda t: t[:k],
                                                       (images, recons))
                    sample_images, sample_recons, hard_recons, codes = map(
                        lambda t: t.detach().cpu(),
                        (sample_images, sample_recons, hard_recons, codes))
                    sample_images, sample_recons, hard_recons = map(
                        lambda t: make_grid(t.float(),
                                            nrow=int(sqrt(k)),
                                            normalize=True,
                                            range=(-1, 1)),
                        (sample_images, sample_recons, hard_recons))

                if args.model_parallel:
                    filename = f'{args.model_dir}/vae.pt'
                    # `smp.dp_rank` is a function and must be called;
                    # comparing the function object itself to 0 is always
                    # False, which skipped checkpointing entirely.
                    if smp.dp_rank() == 0:
                        if args.save_full_model:
                            model_dict = vae.state_dict()
                            opt_dict = opt.state_dict()
                            smp.save(
                                {
                                    "model_state_dict": model_dict,
                                    "optimizer_state_dict": opt_dict
                                },
                                filename,
                                partial=False,
                            )
                        else:
                            model_dict = vae.local_state_dict()
                            opt_dict = opt.local_state_dict()
                            smp.save(
                                {
                                    "model_state_dict": model_dict,
                                    "optimizer_state_dict": opt_dict
                                },
                                filename,
                                partial=True,
                            )
                    smp.barrier()
                else:
                    save_model(f'{args.model_dir}/vae.pt')

                # temperature anneal
                temp = max(temp * math.exp(-ANNEAL_RATE * global_step),
                           TEMP_MIN)

                # lr decay
                sched.step()

            # Collective loss, averaged
            if args.model_parallel:
                # NOTE(review): this divides the local loss by the world size
                # without summing across ranks first -- confirm intent.
                avg_loss = loss.detach().clone()
                avg_loss /= args.world_size
            else:
                avg_loss = deepspeed_utils.average_all(loss)

            if args.rank == 0:
                if i % 100 == 0:
                    lr = sched.get_last_lr()[0]
                    print(epoch, i,
                          f'lr - {lr:6f}, loss - {avg_loss.item()},')

                    logs = {
                        **logs, 'epoch': epoch,
                        'iter': i,
                        'loss': avg_loss.item(),
                        'lr': lr
                    }

            global_step += 1

            if args.rank == 0:
                # to_python_float incurs a host<->device sync
                losses.update(util.to_python_float(loss), images.size(0))

                # Waiting until finishing operations on GPU (Pytorch default: async)
                torch.cuda.synchronize()
                # NOTE(review): `start` is set once per epoch, so this is the
                # time elapsed since the epoch began divided by
                # args.log_interval, not a per-batch time.
                batch_time.update((time.time() - start) / args.log_interval)

                print(
                    'Epoch: [{0}][{1}/{2}] '
                    'Train_Time={batch_time.val:.3f}: avg-{batch_time.avg:.3f}, '
                    'Train_Speed={3:.3f} ({4:.3f}), '
                    'Train_Loss={loss.val:.10f}:({loss.avg:.4f}),'.format(
                        epoch,
                        i,
                        len(dl),
                        args.world_size * BATCH_SIZE / batch_time.val,
                        args.world_size * BATCH_SIZE / batch_time.avg,
                        batch_time=batch_time,
                        loss=losses))

    if args.rank == 0:
        # save final vae and cleanup
        if args.model_parallel:
            logger.debug('save model_parallel')
        else:
            save_model(os.path.join(args.model_dir, 'vae-final.pt'))

    if args.model_parallel:
        if args.assert_losses:
            if args.horovod or args.ddp:
                # SM Distributed: If using data parallelism, gather all losses
                # across different model replicas and check if losses match.
                losses = smp.allgather(loss, smp.DP_GROUP)
                for gathered_loss in losses:
                    print(gathered_loss)
                    assert math.isclose(gathered_loss, losses[0])

                assert loss < 0.18
            else:
                assert loss < 0.08

        smp.barrier()
        print("SMP training finished successfully")