import os

import torch
from dalle_pytorch import DiscreteVAE


def get_vae(args):
    vae = DiscreteVAE(image_size=args.size,
                      num_layers=args.vae_layers,
                      num_tokens=8192,
                      codebook_dim=args.codebook_dims,
                      num_resnet_blocks=9,
                      hidden_dim=128,
                      temperature=args.temperature)

    if args.vae is not None and os.path.isfile(args.vae):
        print(f"loading state dict from {args.vae}")
        vae.load_state_dict(torch.load(args.vae))

    vae.to(args.device)
    return vae
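# Usage sketch (assumption): the Namespace fields below mirror exactly what
# get_vae() reads from `args`; the real argparse flags in this repo may differ.
from argparse import Namespace

example_args = Namespace(size=256, vae_layers=3, codebook_dims=512,
                         temperature=0.9, vae=None, device='cuda')
vae = get_vae(example_args)  # builds a DiscreteVAE and moves it to the device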
import torch
from torchvision import transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from dalle_pytorch import DiscreteVAE

imgSize = 256
load_epoch = 280

vae = DiscreteVAE(image_size=imgSize,
                  num_layers=3,
                  channels=3,
                  num_tokens=2048,
                  codebook_dim=1024,
                  hidden_dim=128)

vae_dict = torch.load("./models/dvae-" + str(load_epoch) + ".pth")
vae.load_state_dict(vae_dict)
vae.cuda()

batchSize = 12
n_epochs = 500
log_interval = 20

# images = torch.randn(4, 3, 256, 256)
t = transforms.Compose([
    transforms.Resize(imgSize),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # (0.267, 0.233, 0.234)
])
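# Quick reconstruction check with the loaded dVAE -- a sketch, assuming an
# ImageFolder-style dataset at ./data and the transform `t` defined above.
from torchvision.datasets import ImageFolder

ds = ImageFolder("./data", transform=t)
dl = DataLoader(ds, batch_size=batchSize, shuffle=True)

images, _ = next(iter(dl))
with torch.no_grad():
    codes = vae.get_codebook_indices(images.cuda())  # (B, 32*32) token ids for num_layers=3
    recons = vae.decode(codes)                       # (B, 3, imgSize, imgSize)
save_image(recons, f"recons-{load_epoch}.png", normalize=True)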
# (excerpt) tail of the `if RESUME:` branch: fall back to a pretrained VAE
        vae_klass = OpenAIDiscreteVAE if not args.taming else VQGanVAE1024
        vae = vae_klass()

    dalle_params = dict(**dalle_params)
    IMAGE_SIZE = vae.image_size
else:
    if exists(VAE_PATH):
        vae_path = Path(VAE_PATH)
        assert vae_path.exists(), 'VAE model file does not exist'

        loaded_obj = torch.load(str(vae_path))
        vae_params, weights = loaded_obj['hparams'], loaded_obj['weights']

        vae = DiscreteVAE(**vae_params)
        vae.load_state_dict(weights)
    else:
        if distr_backend.is_root_worker():
            print('using pretrained VAE for encoding images to tokens')
        vae_params = None

        vae_klass = OpenAIDiscreteVAE if not args.taming else VQGanVAE1024
        vae = vae_klass()

    IMAGE_SIZE = vae.image_size

dalle_params = dict(num_text_tokens=tokenizer.vocab_size,
                    text_seq_len=TEXT_SEQ_LEN,
                    dim=MODEL_DIM,
                    depth=DEPTH,
                    heads=HEADS,
vae = DiscreteVAE(
    image_size=IMAGE_SIZE,
    num_layers=3,            # number of downsamples, e.g. 256 / (2 ** 3) = 32 x 32 feature map
    num_tokens=8192,         # number of visual tokens. the paper uses 8192, but it can be smaller for downsized projects
    codebook_dim=512,        # codebook dimension
    hidden_dim=64,           # hidden dimension
    num_resnet_blocks=1,     # number of resnet blocks
    temperature=0.9,         # gumbel softmax temperature; the lower it is, the harder the discretization
    straight_through=False,  # straight-through for gumbel softmax. unclear if it is better one way or the other
).cuda()
vae.load_state_dict(torch.load("Vae-small.pth"))

dalle = DALLE(
    dim=1024,
    vae=vae,                     # automatically infer (1) image sequence length and (2) number of image tokens
    num_text_tokens=NUM_TOKENS,  # vocab size for text
    text_seq_len=TEXTSEQLEN,     # text sequence length
    depth=12,                    # should aim to be 64
    heads=16,                    # attention heads
    dim_head=64,                 # attention head dimension
    attn_dropout=0.1,            # attention dropout
    ff_dropout=0.1,              # feedforward dropout
).cuda()
dalle.load_state_dict(torch.load("dalle-small.pth"))
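# Generation sketch: tokenize a caption and sample an image from the loaded
# DALLE. Assumes the default SimpleTokenizer shipped with dalle_pytorch was
# also used during training; swap in your own tokenizer if it was not.
from dalle_pytorch.tokenizer import tokenizer

text = tokenizer.tokenize(['a photo of a small red car'], TEXTSEQLEN).cuda()
images = dalle.generate_images(text, filter_thres=0.9)  # (1, 3, IMAGE_SIZE, IMAGE_SIZE)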
def main(args):
    # constants
    VAE_PATH = args.vae_path
    DALLE_PATH = args.dalle_path
    RESUME = exists(DALLE_PATH)

    EPOCHS = args.epochs
    BATCH_SIZE = args.batch_size

    LEARNING_RATE = args.learning_rate
    GRAD_CLIP_NORM = args.clip_grad_norm
    LR_DECAY = args.lr_decay

    MODEL_DIM = args.dim
    TEXT_SEQ_LEN = args.text_seq_len
    DEPTH = args.depth
    HEADS = args.heads
    DIM_HEAD = args.dim_head
    REVERSIBLE = args.reversible
    LOSS_IMG_WEIGHT = args.loss_img_weight

    ATTN_TYPES = tuple(args.attn_types.split(','))

    DEEPSPEED_CP_AUX_FILENAME = 'auxiliary.pt'

    # initialize distributed backend
    if args.sagemakermp:
        args.deepspeed = False
        using_deepspeed = False
    else:
        args.deepspeed = True

    distr_backend = distributed_utils.set_backend_from_args(args)
    distr_backend.initialize(args)

    if args.sagemakermp:
        args = smp_init(args)
        distributed_utils.using_backend(distributed_utils.SageMakerMPBackend)
    else:
        using_deepspeed = \
            distributed_utils.using_backend(distributed_utils.DeepSpeedBackend)

    args.rank = int(os.environ.get('RANK'))
    args.world_size = int(os.environ.get('WORLD_SIZE'))
    args.local_rank = int(os.environ.get('LOCAL_RANK'))
    args.global_rank = args.rank

    logger.debug(f"using_deepspeed : {using_deepspeed}")
    logger.debug(f"args.local_rank : {args.local_rank}, args.rank : {args.rank}")

    # tokenizer
    logger.debug(f"exists(args.bpe_path) : {exists(args.bpe_path)}, args.chinese : {args.chinese}")

    if exists(args.bpe_path):
        klass = HugTokenizer if args.hug else YttmTokenizer
        tokenizer = klass(args.bpe_path)
    elif args.chinese:
        tokenizer = ChineseTokenizer()
    else:
        tokenizer = SimpleTokenizer()

    # reconstitute vae
    if RESUME:
        dalle_path = Path(DALLE_PATH)
        if using_deepspeed:
            cp_dir = cp_path_to_dir(dalle_path, 'ds')
            assert cp_dir.is_dir(), \
                f'DeepSpeed checkpoint directory {cp_dir} not found'
            dalle_path = cp_dir / DEEPSPEED_CP_AUX_FILENAME
        else:
            assert dalle_path.exists(), 'DALL-E model file does not exist'

        loaded_obj = torch.load(str(dalle_path), map_location='cpu')
        dalle_params, vae_params, weights = loaded_obj['hparams'], loaded_obj[
            'vae_params'], loaded_obj['weights']

        if vae_params is not None:
            vae = DiscreteVAE(**vae_params)
        else:
            vae_klass = OpenAIDiscreteVAE if not args.taming else VQGanVAE1024
            vae = vae_klass(args)

        dalle_params = dict(**dalle_params)
        IMAGE_SIZE = vae.image_size
    else:
        if exists(VAE_PATH):
            vae_path = Path(VAE_PATH)
            assert vae_path.exists(), 'VAE model file does not exist'
            assert not vae_path.is_dir(), \
                ('Cannot load VAE model from directory; please use a '
                 'standard *.pt checkpoint. '
                 'Currently, merging a DeepSpeed-partitioned VAE into a DALLE '
                 'model is not supported.')

            loaded_obj = torch.load(str(vae_path))
            vae_params, weights = loaded_obj['hparams'], loaded_obj['weights']

            vae = DiscreteVAE(**vae_params)
            vae.load_state_dict(weights)
        else:
            if args.rank == 0:  # if distr_backend.is_root_worker():
                print('using pretrained VAE for encoding images to tokens')
            vae_params = None

            logger.debug(f"************* args.taming : {args.taming}")

            vae_klass = OpenAIDiscreteVAE if not args.taming else VQGanVAE1024
            vae = vae_klass(args)

        IMAGE_SIZE = vae.image_size

        dalle_params = dict(
            num_text_tokens=tokenizer.vocab_size,
            text_seq_len=TEXT_SEQ_LEN,
            dim=MODEL_DIM,
            depth=DEPTH,
            heads=HEADS,
            dim_head=DIM_HEAD,
            reversible=REVERSIBLE,
            loss_img_weight=LOSS_IMG_WEIGHT,
            attn_types=ATTN_TYPES,
        )

    # configure OpenAI VAE for float16s
    if isinstance(vae, OpenAIDiscreteVAE) and args.fp16:
        vae.enc.blocks.output.conv.use_float16 = True

    # create dataset and dataloader
    is_shuffle = not distributed_utils.using_backend(
        distributed_utils.HorovodBackend)

    ds = TextImageDataset(
        args.image_text_folder,
        text_len=TEXT_SEQ_LEN,
        image_size=IMAGE_SIZE,
        resize_ratio=args.resize_ratio,
        truncate_captions=args.truncate_captions,
        tokenizer=tokenizer,
        shuffle=is_shuffle,
    )

    assert len(ds) > 0, 'dataset is empty'
    if args.rank == 0:  # if distr_backend.is_root_worker():
        print(f'{len(ds)} image-text pairs found for training')

    if not is_shuffle:
        data_sampler = torch.utils.data.distributed.DistributedSampler(
            ds, num_replicas=args.world_size, rank=args.rank)
    elif args.sagemakermp:
        args.ds = ds
        ds = split_dataset(args)
        data_sampler = None
    else:
        data_sampler = None
    print(f"data_sampler : {data_sampler}")

    # an uncorrectable NVLink error was detected during the execution --> remove
    kwargs = {'num_workers': args.num_worker, 'pin_memory': True}

    dl = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=is_shuffle,
                    drop_last=True, sampler=data_sampler, **kwargs)

    logger.info("Processes {}/{} ({:.0f}%) of train data".format(
        len(dl.sampler), len(dl.dataset),
        100. * len(dl.sampler) / len(dl.dataset)))

    # initialize DALL-E
    dalle = DALLE(vae=vae, **dalle_params)
    if not using_deepspeed:
        if args.fp16:
            dalle = dalle.half()
        dalle = dalle.cuda()

    if RESUME and not using_deepspeed:
        dalle.load_state_dict(weights)

    # optimizer
    opt = Adam(get_trainable_params(dalle), lr=LEARNING_RATE)

    if LR_DECAY:
        scheduler = ReduceLROnPlateau(
            opt,
            mode="min",
            factor=0.5,
            patience=10,
            cooldown=10,
            min_lr=1e-6,
            verbose=True,
        )

    if args.global_rank == 0:  # if distr_backend.is_root_worker():
        # experiment tracker
        model_config = dict(depth=DEPTH, heads=HEADS, dim_head=DIM_HEAD)

        logger.debug(f"args.wandb_name : {args.wandb_name}, RESUME : {RESUME}")

        run = wandb.init(
            project=args.wandb_name,  # 'dalle_train_transformer' by default
            resume=RESUME,
            config=model_config,
        )

    # distribute
    distr_backend.check_batch_size(BATCH_SIZE)
    deepspeed_config = {
        'train_batch_size': BATCH_SIZE,
        'gradient_clipping': GRAD_CLIP_NORM,
        'fp16': {
            'enabled': args.fp16,
        },
    }

    (distr_dalle, distr_opt, distr_dl, distr_scheduler) = distr_backend.distribute(
        args=args,
        model=dalle,
        optimizer=opt,
        model_parameters=get_trainable_params(dalle),
        training_data=ds if using_deepspeed else dl,
        lr_scheduler=scheduler if LR_DECAY else None,
        config_params=deepspeed_config,
    )
    avoid_model_calls = using_deepspeed and args.fp16

    if args.sagemakermp:
        args.distr_dalle = smp.DistributedModel(distr_dalle)
        args.scaler = smp.amp.GradScaler()
        args.distr_opt = smp.DistributedOptimizer(distr_opt)

    if RESUME and using_deepspeed:
        distr_dalle.load_checkpoint(str(cp_dir))

    # training
    for epoch in range(EPOCHS):
        logger.debug(f"********* epoch : {epoch} **********")
        if data_sampler:
            data_sampler.set_epoch(epoch)

        for i, (text, images) in enumerate(distr_dl):
            if args.fp16:
                images = images.half()

            text, images = map(lambda t: t.cuda(), (text, images))

            if args.sagemakermp:
                args.distr_opt.zero_grad()
                loss = train_step(args, text, images, return_loss=True)
                loss = loss.reduce_mean()
            else:
                loss = distr_dalle(text, images, return_loss=True, args=args)

            if using_deepspeed:
                distr_dalle.backward(loss)
                distr_dalle.step()
                # Gradients are automatically zeroed after the step
            elif args.sagemakermp:
                if args.amp:
                    args.scaler.step(args.distr_opt)
                    args.scaler.update()
                else:
                    # some optimizers like adadelta from PT 1.8 don't like it
                    # when optimizer.step is called with no param
                    if len(list(args.distr_dalle.local_parameters())) > 0:
                        args.distr_opt.step()
            else:
                loss.backward()
                clip_grad_norm_(distr_dalle.parameters(), GRAD_CLIP_NORM)
                distr_opt.step()
                distr_opt.zero_grad()

            # Collective loss, averaged
            avg_loss = distr_backend.average_all(loss)

            log = {}

            if i % 10 == 0 and args.rank == 0:  # if distr_backend.is_root_worker():
                print(epoch, i, f'loss - {avg_loss.item()}')

                log = {
                    **log, 'epoch': epoch,
                    'iter': i,
                    'loss': avg_loss.item()
                }

            if i % 100 == 0:
                if args.rank == 0:  # if distr_backend.is_root_worker():
                    sample_text = text[:1]
                    token_list = sample_text.masked_select(
                        sample_text != 0).tolist()
                    decoded_text = tokenizer.decode(token_list)

                    logger.debug(f"******* avoid_model_calls : {avoid_model_calls}")
                    if not avoid_model_calls:
                        # CUDA index errors when we don't guard this
                        image = dalle.generate_images(
                            text[:1], filter_thres=0.9)  # topk sampling at 0.9

                    wandb.save('./dalle.pt')

                    log = {
                        **log,
                    }
                    if not avoid_model_calls:
                        log['image'] = wandb.Image(image, caption=decoded_text)

                args.distr_dalle = distr_dalle
                args.dalle_params = dalle_params
                args.vae_params = vae_params
                args.using_deepspeed = using_deepspeed
                args.DEEPSPEED_CP_AUX_FILENAME = DEEPSPEED_CP_AUX_FILENAME

                save_model(args, f'{args.model_dir}/dalle.pt')

            if args.rank == 0:  # if distr_backend.is_root_worker():
                wandb.log(log)

            # text, images = prefetcher.next()

        if LR_DECAY and not using_deepspeed:
            # Scheduler is automatically progressed after the step when
            # using DeepSpeed.
            distr_scheduler.step(loss)

        if args.global_rank == 0:  # if distr_backend.is_root_worker():
            # save trained model to wandb as an artifact every epoch's end
            model_artifact = wandb.Artifact('trained-dalle', type='model',
                                            metadata=dict(model_config))
            # model_artifact.add_file('dalle.pt')
            run.log_artifact(model_artifact)

    args.distr_dalle = distr_dalle
    args.dalle_params = dalle_params
    args.vae_params = vae_params
    args.using_deepspeed = using_deepspeed
    args.DEEPSPEED_CP_AUX_FILENAME = DEEPSPEED_CP_AUX_FILENAME

    save_model(args, f'{args.model_dir}/dalle-final.pt')

    if args.global_rank == 0:  # if distr_backend.is_root_worker():
        wandb.save('./dalle-final.pt')
        model_artifact = wandb.Artifact('trained-dalle', type='model',
                                        metadata=dict(model_config))
        # model_artifact.add_file('dalle-final.pt')
        run.log_artifact(model_artifact)

        wandb.finish()
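# save_model() is defined elsewhere in this repo. Based on how the RESUME branch
# above reloads checkpoints ('hparams' / 'vae_params' / 'weights'), a minimal
# stand-in would look like this sketch (ignoring the DeepSpeed-partitioned case):
def save_model(args, path):
    if args.rank != 0:  # assumption: only the root worker writes the file
        return
    save_obj = {
        'hparams': args.dalle_params,
        'vae_params': args.vae_params,
        'weights': args.distr_dalle.state_dict(),
    }
    torch.save(save_obj, path)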
def main(args):
    # constants
    print(f"torch.cuda.nccl.version() : {torch.cuda.nccl.version()}")

    VAE_PATH = args.vae_path
    DALLE_PATH = args.dalle_path
    RESUME = exists(DALLE_PATH)

    EPOCHS = args.epochs
    BATCH_SIZE = args.batch_size

    LEARNING_RATE = args.learning_rate
    GRAD_CLIP_NORM = args.clip_grad_norm
    LR_DECAY = args.lr_decay
    SAVE_EVERY_N_STEPS = args.save_every_n_steps

    MODEL_DIM = args.dim
    TEXT_SEQ_LEN = args.text_seq_len
    DEPTH = args.depth
    HEADS = args.heads
    DIM_HEAD = args.dim_head
    REVERSIBLE = args.reversible
    LOSS_IMG_WEIGHT = args.loss_img_weight

    ATTN_TYPES = tuple(args.attn_types.split(','))

    DEEPSPEED_CP_AUX_FILENAME = 'auxiliary.pt'
    DALLE_OUTPUT_FILE_NAME = args.dalle_output_file_name

    # initialize distributed backend
    args.deepspeed = True
    distr_backend = distributed_utils.set_backend_from_args(args)
    distr_backend.initialize(args)

    using_deepspeed = \
        distributed_utils.using_backend(distributed_utils.DeepSpeedBackend)

    args.rank = int(os.environ.get('RANK'))
    args.world_size = int(os.environ.get('WORLD_SIZE'))
    args.local_rank = int(os.environ.get('LOCAL_RANK'))
    args.global_rank = args.rank

    logger.debug(f"using_deepspeed : {using_deepspeed}")
    logger.debug(f"args.local_rank : {args.local_rank}, args.rank : {args.rank}")
    print(f"********* torch.distributed.get_rank() : {torch.distributed.get_rank()}")

    # tokenizer
    logger.debug(f"exists(args.bpe_path) : {exists(args.bpe_path)}, args.chinese : {args.chinese}")

    # if args.local_rank == 0 or args.local_rank == 1:
    #     # print(f"args.job_name : {args.job_name}")
    #     gpu_mon_thread = aws_util.GPUMon(device_index=args.local_rank, job_name=args.job_name)
    #     gpu_mon_thread.start()

    if exists(args.bpe_path):
        klass = HugTokenizer if args.hug else YttmTokenizer
        tokenizer = klass(args.bpe_path)
    elif args.chinese:
        tokenizer = ChineseTokenizer()
    else:
        tokenizer = SimpleTokenizer()

    # reconstitute vae
    if RESUME:
        dalle_path = Path(DALLE_PATH)
        if using_deepspeed:
            cp_dir = cp_path_to_dir(dalle_path, 'ds')
            assert cp_dir.is_dir(), \
                f'DeepSpeed checkpoint directory {cp_dir} not found'
            dalle_path = cp_dir / DEEPSPEED_CP_AUX_FILENAME
        else:
            assert dalle_path.exists(), 'DALL-E model file does not exist'

        loaded_obj = torch.load(str(dalle_path), map_location='cpu')
        dalle_params, vae_params, weights = loaded_obj['hparams'], loaded_obj[
            'vae_params'], loaded_obj['weights']

        if vae_params is not None:
            vae = DiscreteVAE(**vae_params)
        else:
            vae_klass = OpenAIDiscreteVAE if not args.taming else VQGanVAE1024
            vae = vae_klass(args)

        dalle_params = dict(**dalle_params)
        IMAGE_SIZE = vae.image_size
    else:
        if exists(VAE_PATH):
            vae_path = Path(VAE_PATH)
            assert vae_path.exists(), 'VAE model file does not exist'
            assert not vae_path.is_dir(), \
                ('Cannot load VAE model from directory; please use a '
                 'standard *.pt checkpoint. '
                 'Currently, merging a DeepSpeed-partitioned VAE into a DALLE '
                 'model is not supported.')

            loaded_obj = torch.load(str(vae_path))
            vae_params, weights = loaded_obj['hparams'], loaded_obj['weights']

            vae = DiscreteVAE(**vae_params)
            vae.load_state_dict(weights)
        else:
            if args.rank == 0:  # if distr_backend.is_root_worker():
                print('using pretrained VAE for encoding images to tokens')
            vae_params = None

            logger.debug(f"************* args.taming : {args.taming}")

            vae_klass = OpenAIDiscreteVAE if not args.taming else VQGanVAE1024
            vae = vae_klass(args)

        IMAGE_SIZE = vae.image_size

        dalle_params = dict(
            num_text_tokens=tokenizer.vocab_size,
            text_seq_len=TEXT_SEQ_LEN,
            dim=MODEL_DIM,
            depth=DEPTH,
            heads=HEADS,
            dim_head=DIM_HEAD,
            reversible=REVERSIBLE,
            loss_img_weight=LOSS_IMG_WEIGHT,
            attn_types=ATTN_TYPES,
        )

    # configure OpenAI VAE for float16s
    if isinstance(vae, OpenAIDiscreteVAE) and args.fp16:
        vae.enc.blocks.output.conv.use_float16 = True

    # create dataset and dataloader
    is_shuffle = not distributed_utils.using_backend(
        distributed_utils.HorovodBackend)

    ds = TextImageDataset(
        args.image_text_folder,
        text_len=TEXT_SEQ_LEN,
        image_size=IMAGE_SIZE,
        resize_ratio=args.resize_ratio,
        truncate_captions=args.truncate_captions,
        tokenizer=tokenizer,
        shuffle=is_shuffle,
    )

    assert len(ds) > 0, 'dataset is empty'
    if args.rank == 0:  # if distr_backend.is_root_worker():
        print(f'{len(ds)} image-text pairs found for training')

    if not is_shuffle:
        data_sampler = torch.utils.data.distributed.DistributedSampler(
            ds, num_replicas=args.world_size, rank=args.rank)
    else:
        data_sampler = None

    kwargs = {'num_workers': args.num_worker, 'pin_memory': True}

    dl = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=is_shuffle,
                    drop_last=True, sampler=data_sampler, **kwargs)

    logger.info("Processes {}/{} ({:.0f}%) of train data".format(
        len(dl.sampler), len(dl.dataset),
        100. * len(dl.sampler) / len(dl.dataset)))

    # initialize DALL-E
    dalle = DALLE(vae=vae, **dalle_params)
    if not using_deepspeed:
        if args.fp16:
            dalle = dalle.half()
        dalle = dalle.cuda()

    if RESUME and not using_deepspeed:
        dalle.load_state_dict(weights)

    # optimizer
    opt = Adam(get_trainable_params(dalle), lr=LEARNING_RATE)

    if LR_DECAY:
        scheduler = ReduceLROnPlateau(
            opt,
            mode="min",
            factor=0.5,
            patience=10,
            cooldown=10,
            min_lr=1e-6,
            verbose=True,
        )

    if args.rank == 0:  # if distr_backend.is_root_worker():
        # experiment tracker
        model_config = dict(depth=DEPTH, heads=HEADS, dim_head=DIM_HEAD)

        # wandb_dir = '/tmp/wandb'
        # if not os.path.exists(wandb_dir):
        #     os.makedirs(wandb_dir)

        run = wandb.init(
            project=args.wandb_name,  # 'dalle_train_transformer' by default
            resume=RESUME,
            config=model_config,
            # dir=wandb_dir
        )

    # distribute
    distr_backend.check_batch_size(BATCH_SIZE)
    deepspeed_config = {
        'train_batch_size': BATCH_SIZE,
        'gradient_clipping': GRAD_CLIP_NORM,
        'fp16': {
            'enabled': args.fp16,
        },
    }

    (distr_dalle, distr_opt, distr_dl, distr_scheduler) = distr_backend.distribute(
        args=args,
        model=dalle,
        optimizer=opt,
        model_parameters=get_trainable_params(dalle),
        training_data=ds if using_deepspeed else dl,
        lr_scheduler=scheduler if LR_DECAY else None,
        config_params=deepspeed_config,
    )
    avoid_model_calls = using_deepspeed and args.fp16

    if RESUME and using_deepspeed:
        distr_dalle.load_checkpoint(str(cp_dir))

    # training
    for epoch in range(EPOCHS):
        logger.debug(f"********* epoch : {epoch} **********")
        if data_sampler:
            data_sampler.set_epoch(epoch)

        for i, (text, images) in enumerate(distr_dl):
            if i % 10 == 0 and args.rank == 0:
                t = time.time()

            if args.fp16:
                images = images.half()

            text, images = map(lambda t: t.cuda(), (text, images))

            loss = distr_dalle(text, images, return_loss=True)

            if using_deepspeed:
                distr_dalle.backward(loss)
                distr_dalle.step()
                # Gradients are automatically zeroed after the step
            else:
                loss.backward()
                clip_grad_norm_(distr_dalle.parameters(), GRAD_CLIP_NORM)
                distr_opt.step()
                distr_opt.zero_grad()

            # Collective loss, averaged
            avg_loss = distr_backend.average_all(loss)

            log = {}

            if i % 10 == 0 and args.rank == 0:  # if distr_backend.is_root_worker():
                print(epoch, i, f'loss - {avg_loss.item()}')

                log = {
                    **log, 'epoch': epoch,
                    'iter': i,
                    'loss': avg_loss.item()
                }

            if i % SAVE_EVERY_N_STEPS == 0:
                args.distr_dalle = distr_dalle
                args.dalle_params = dalle_params
                args.vae_params = vae_params
                args.using_deepspeed = using_deepspeed
                args.DEEPSPEED_CP_AUX_FILENAME = DEEPSPEED_CP_AUX_FILENAME

                save_model(args, f"{args.model_dir + '/' + DALLE_OUTPUT_FILE_NAME}")

            if i % 100 == 0:
                if args.rank == 0:  # if distr_backend.is_root_worker():
                    sample_text = text[:1]
                    token_list = sample_text.masked_select(
                        sample_text != 0).tolist()
                    decoded_text = tokenizer.decode(token_list)

                    if not avoid_model_calls:
                        # CUDA index errors when we don't guard this
                        image = dalle.generate_images(
                            text[:1], filter_thres=0.9)  # topk sampling at 0.9

                    log = {
                        **log,
                    }
                    if not avoid_model_calls:
                        log['image'] = wandb.Image(image, caption=decoded_text)

            if i % 10 == 9 and args.rank == 0:
                sample_per_sec = BATCH_SIZE * 10 / (time.time() - t)
                log["sample_per_sec"] = sample_per_sec
                print(epoch, i, f'sample_per_sec - {sample_per_sec}')

            if args.rank == 0:  # if distr_backend.is_root_worker():
                wandb.log(log)

        if LR_DECAY and not using_deepspeed:
            # Scheduler is automatically progressed after the step when
            # using DeepSpeed.
            distr_scheduler.step(loss)

        args.distr_dalle = distr_dalle
        args.dalle_params = dalle_params
        args.vae_params = vae_params
        args.using_deepspeed = using_deepspeed
        args.DEEPSPEED_CP_AUX_FILENAME = DEEPSPEED_CP_AUX_FILENAME

        save_model(args, f"{args.model_dir + '/' + DALLE_OUTPUT_FILE_NAME}")
        # sync_local_checkpoints_to_s3(local_path=f'{args.model_dir}', s3_path='s3://lgaivision-coco-usva/Dalle_Model/tmd/')
        sync_local_checkpoints_to_s3(
            args.model_dir,
            os.path.join(args.output_s3, args.job_name + "/temp"))

        if args.rank == 0:  # if distr_backend.is_root_worker():
            # save trained model to wandb as an artifact every epoch's end
            model_artifact = wandb.Artifact('trained-dalle', type='model',
                                            metadata=dict(model_config))

            import glob
            print(f"************** file : {glob.glob(args.model_dir + '/*')}")
            try:
                print(f"wandb.run.dir : {wandb.run.dir}")
                print(f"************** file wandb: {glob.glob(wandb.run.dir + '/*')}")
            except:
                pass

            model_artifact.add_file(f"{args.model_dir + '/' + DALLE_OUTPUT_FILE_NAME}")
            run.log_artifact(model_artifact)

    args.distr_dalle = distr_dalle
    args.dalle_params = dalle_params
    args.vae_params = vae_params
    args.using_deepspeed = using_deepspeed
    args.DEEPSPEED_CP_AUX_FILENAME = DEEPSPEED_CP_AUX_FILENAME

    resource_check(args)
    save_model(args, f"{args.model_dir + '/' + DALLE_OUTPUT_FILE_NAME}")

    if args.rank == 0:
        # from distutils.dir_util import copy_tree
        # copy_tree(f'{args.model_dir}', f'{args.model_dir_last}')
        sync_local_checkpoints_to_s3(
            args.model_dir,
            os.path.join(args.output_s3, args.job_name + "/temp"))

    resource_check(args)

    if args.rank == 0:  # if distr_backend.is_root_worker():
        wandb.save(f"{args.model_dir + '/' + DALLE_OUTPUT_FILE_NAME}",
                   base_path=args.model_dir)

        model_artifact = wandb.Artifact('trained-dalle', type='model',
                                        metadata=dict(model_config))
        model_artifact.add_file(f"{args.model_dir + '/' + DALLE_OUTPUT_FILE_NAME}")
        run.log_artifact(model_artifact)

        wandb.finish()

    # if args.local_rank == 0 or args.local_rank == 1:
    #     gpu_mon_thread.kill()

    distributed_utils.backend.local_barrier()
    print("************************ Finished *************************")
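# Reload sketch for the checkpoint written above. The key layout follows the
# RESUME branch of main() ('hparams' / 'vae_params' / 'weights'); the filename
# here is a placeholder for args.model_dir + '/' + DALLE_OUTPUT_FILE_NAME, and
# a custom DiscreteVAE (non-None vae_params) is assumed.
loaded = torch.load('dalle.pt', map_location='cpu')
vae = DiscreteVAE(**loaded['vae_params'])
dalle = DALLE(vae=vae, **loaded['hparams'])
dalle.load_state_dict(loaded['weights'])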
logger.debug(f"args.local_rank : {args.local_rank}") if args.local_rank is not None: torch.cuda.set_device(args.local_rank) else: torch.cuda.set_device(0) if args.multigpus_distributed: vae.cuda(args.local_rank) if args.model_parallel: vae = smp.DistributedModel(vae) args.scaler = smp.amp.GradScaler() opt = smp.DistributedOptimizer(opt) if args.partial_checkpoint: args.checkpoint = smp.load(args.partial_checkpoint, partial=True) vae.load_state_dict(args.checkpoint["model_state_dict"]) opt.load_state_dict(args.checkpoint["optimizer_state_dict"]) elif args.full_checkpoint: args.checkpoint = smp.load(args.full_checkpoint, partial=False) vae.load_state_dict(args.checkpoint["model_state_dict"]) opt.load_state_dict(args.checkpoint["optimizer_state_dict"]) else: vae = vae.cuda() else: vae = vae.cuda() assert len(ds) > 0, 'folder does not contain any images' if (not args.model_parallel) and deepspeed_utils.is_root_worker(): print(f'{len(ds)} images found for training')
def main():
    parser = get_parser()
    args = parser.parse_args()

    if not torch.cuda.is_available():
        raise ValueError(
            "The script requires CUDA support, but CUDA not available")

    args.rank = -1
    args.world_size = 1

    if args.model_parallel:
        args.deepspeed = False
        cfg = {
            "microbatches": args.num_microbatches,
            "placement_strategy": args.placement_strategy,
            "pipeline": args.pipeline,
            "optimize": args.optimize,
            "partitions": args.num_partitions,
            "horovod": args.horovod,
            "ddp": args.ddp,
        }

        smp.init(cfg)
        torch.cuda.set_device(smp.local_rank())
        args.rank = smp.dp_rank()
        args.world_size = smp.size()
    else:
        # initialize deepspeed
        print(f"args.deepspeed : {args.deepspeed}")
        deepspeed_utils.init_deepspeed(args.deepspeed)
        if deepspeed_utils.is_root_worker():
            args.rank = 0

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed + args.rank)
        np.random.seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)
        # args.LEARNING_RATE = args.LEARNING_RATE * float(args.world_size)
        cudnn.deterministic = True

        if cudnn.deterministic:
            warnings.warn('You have chosen to seed training. '
                          'This will turn on the CUDNN deterministic setting, '
                          'which can slow down your training considerably! '
                          'You may see unexpected behavior when restarting '
                          'from checkpoints.')

    args.kwargs = {'num_workers': args.num_worker, 'pin_memory': True}

    device = torch.device("cuda")

    logger.debug(f"args.image_folder : {args.image_folder}")
    logger.debug(f"args.rank : {args.rank}")

    ## SageMaker
    try:
        if os.environ.get('SM_MODEL_DIR') is not None:
            args.model_dir = os.environ.get('SM_MODEL_DIR')
            # args.output_dir = os.environ.get('SM_OUTPUT_DATA_DIR')
            args.image_folder = os.environ.get('SM_CHANNEL_TRAINING')
    except:
        logger.debug("not SageMaker")
        pass

    IMAGE_SIZE = args.image_size
    IMAGE_PATH = args.image_folder

    EPOCHS = args.EPOCHS
    BATCH_SIZE = args.BATCH_SIZE
    LEARNING_RATE = args.LEARNING_RATE
    LR_DECAY_RATE = args.LR_DECAY_RATE

    NUM_TOKENS = args.NUM_TOKENS
    NUM_LAYERS = args.NUM_LAYERS
    NUM_RESNET_BLOCKS = args.NUM_RESNET_BLOCKS
    SMOOTH_L1_LOSS = args.SMOOTH_L1_LOSS
    EMB_DIM = args.EMB_DIM
    HID_DIM = args.HID_DIM
    KL_LOSS_WEIGHT = args.KL_LOSS_WEIGHT

    STARTING_TEMP = args.STARTING_TEMP
    TEMP_MIN = args.TEMP_MIN
    ANNEAL_RATE = args.ANNEAL_RATE

    NUM_IMAGES_SAVE = args.NUM_IMAGES_SAVE

    # transform = Compose([...])  # (commented-out) alternative Albumentations pipeline:
    # RandomResizedCrop, additive Gaussian noise, VerticalFlip, motion/median blur,
    # CLAHE / sharpen / emboss / brightness-contrast, HueSaturationValue, Normalize

    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize(IMAGE_SIZE),
        T.CenterCrop(IMAGE_SIZE),
        T.ToTensor()
    ])

    sampler = None
    dl = None

    # data
    logger.debug(f"IMAGE_PATH : {IMAGE_PATH}")
    # ds = AlbumentationImageDataset(
    #     IMAGE_PATH,
    #     transform=transform,
    #     args=args
    # )
    ds = ImageFolder(
        IMAGE_PATH,
        transform=transform,
    )

    if args.model_parallel and (args.ddp or args.horovod) and smp.dp_size() > 1:
        partitions_dict = {
            f"{i}": 1 / smp.dp_size()
            for i in range(smp.dp_size())
        }
        ds = SplitDataset(ds, partitions=partitions_dict)
        ds.select(f"{smp.dp_rank()}")

    dl = DataLoader(ds, BATCH_SIZE, shuffle=True,
                    drop_last=args.model_parallel, **args.kwargs)

    vae_params = dict(image_size=IMAGE_SIZE,
                      num_layers=NUM_LAYERS,
                      num_tokens=NUM_TOKENS,
                      codebook_dim=EMB_DIM,
                      hidden_dim=HID_DIM,
                      num_resnet_blocks=NUM_RESNET_BLOCKS)

    vae = DiscreteVAE(**vae_params,
                      smooth_l1_loss=SMOOTH_L1_LOSS,
                      kl_div_loss_weight=KL_LOSS_WEIGHT).to(device)

    # optimizer
    opt = Adam(vae.parameters(), lr=LEARNING_RATE)
    sched = ExponentialLR(optimizer=opt, gamma=LR_DECAY_RATE)

    if args.model_parallel:
        import copy
        dummy_codebook = copy.deepcopy(vae.codebook)
        dummy_decoder = copy.deepcopy(vae.decoder)

        vae = smp.DistributedModel(vae)
        scaler = smp.amp.GradScaler()
        opt = smp.DistributedOptimizer(opt)

        if args.partial_checkpoint:
            args.checkpoint = smp.load(args.partial_checkpoint, partial=True)
            vae.load_state_dict(args.checkpoint["model_state_dict"])
            opt.load_state_dict(args.checkpoint["optimizer_state_dict"])
        elif args.full_checkpoint:
            args.checkpoint = smp.load(args.full_checkpoint, partial=False)
            vae.load_state_dict(args.checkpoint["model_state_dict"])
            opt.load_state_dict(args.checkpoint["optimizer_state_dict"])

    assert len(ds) > 0, 'folder does not contain any images'
    if (not args.model_parallel) and args.rank == 0:
        print(f'{len(ds)} images found for training')

    # weights & biases experiment tracking
    # import wandb

    model_config = dict(num_tokens=NUM_TOKENS,
                        smooth_l1_loss=SMOOTH_L1_LOSS,
                        num_resnet_blocks=NUM_RESNET_BLOCKS,
                        kl_loss_weight=KL_LOSS_WEIGHT)

    # run = wandb.init(
    #     project='dalle_train_vae',
    #     job_type='train_model',
    #     config=model_config
    # )

    def save_model(path):
        if not args.rank == 0:
            return
        save_obj = {'hparams': vae_params, 'weights': vae.state_dict()}
        torch.save(save_obj, path)

    # distribute with deepspeed
    if not args.model_parallel:
        deepspeed_utils.check_batch_size(BATCH_SIZE)
        deepspeed_config = {'train_batch_size': BATCH_SIZE}

        (distr_vae, opt, dl, sched) = deepspeed_utils.maybe_distribute(
            args=args,
            model=vae,
            optimizer=opt,
            model_parameters=vae.parameters(),
            training_data=ds if args.deepspeed else dl,
            lr_scheduler=sched,
            config_params=deepspeed_config,
        )

    try:
        # Rubik: Define smp.step. Return any tensors needed outside.
        @smp.step
        def train_step(vae, images, temp):
            # logger.debug(f"args.amp : {args.amp}")
            with autocast(enabled=(args.amp > 0)):
                loss, recons = vae(images,
                                   return_loss=True,
                                   return_recons=True,
                                   temp=temp)

            scaled_loss = scaler.scale(loss) if args.amp else loss
            vae.backward(scaled_loss)
            # torch.nn.utils.clip_grad_norm_(vae.parameters(), 5)
            return loss, recons

        @smp.step
        def get_codes_step(vae, images, k):
            images = images[:k]
            logits = vae.forward(images, return_logits=True)
            codebook_indices = logits.argmax(dim=1).flatten(1)
            return codebook_indices

        def hard_recons_step(dummy_decoder, dummy_codebook, codebook_indices):
            from functools import partial
            for module in dummy_codebook.modules():
                method = smp_state.patch_manager.get_original_method(
                    "forward", type(module))
                module.forward = partial(method, module)
            image_embeds = dummy_codebook.forward(codebook_indices)
            b, n, d = image_embeds.shape
            h = w = int(sqrt(n))

            image_embeds = rearrange(image_embeds, 'b (h w) d -> b d h w', h=h, w=w)
            for module in dummy_decoder.modules():
                method = smp_state.patch_manager.get_original_method(
                    "forward", type(module))
                module.forward = partial(method, module)
            hard_recons = dummy_decoder.forward(image_embeds)
            return hard_recons
    except:
        pass

    # starting temperature
    global_step = 0
    temp = STARTING_TEMP

    for epoch in range(EPOCHS):
        batch_time = util.AverageMeter('Time', ':6.3f')
        data_time = util.AverageMeter('Data', ':6.3f')
        losses = util.AverageMeter('Loss', ':.4e')
        top1 = util.AverageMeter('Acc@1', ':6.2f')
        top5 = util.AverageMeter('Acc@5', ':6.2f')
        progress = util.ProgressMeter(
            len(dl), [batch_time, data_time, losses, top1, top5],
            prefix="Epoch: [{}]".format(epoch))

        vae.train()
        start = time.time()

        for i, (images, _) in enumerate(dl):
            images = images.to(device, non_blocking=True)
            opt.zero_grad()

            if args.model_parallel:
                loss, recons = train_step(vae, images, temp)
                # Rubik: Average the loss across microbatches.
                loss = loss.reduce_mean()
                recons = recons.reduce_mean()
            else:
                loss, recons = distr_vae(images,
                                         return_loss=True,
                                         return_recons=True,
                                         temp=temp)

            if (not args.model_parallel) and args.deepspeed:
                # Gradients are automatically zeroed after the step
                distr_vae.backward(loss)
                distr_vae.step()
            elif args.model_parallel:
                if args.amp:
                    scaler.step(opt)
                    scaler.update()
                else:
                    # some optimizers like adadelta from PT 1.8 don't like it
                    # when optimizer.step is called with no param
                    if len(list(vae.local_parameters())) > 0:
                        opt.step()
            else:
                loss.backward()
                opt.step()

            logs = {}

            if i % 10 == 0:
                if args.rank == 0:  # if deepspeed_utils.is_root_worker():
                    k = NUM_IMAGES_SAVE

                    with torch.no_grad():
                        if args.model_parallel:
                            model_dict = vae.state_dict()
                            model_dict_updated = {}
                            for key, val in model_dict.items():
                                if "decoder" in key:
                                    key = key.replace("decoder.", "")
                                elif "codebook" in key:
                                    key = key.replace("codebook.", "")
                                model_dict_updated[key] = val

                            dummy_decoder.load_state_dict(model_dict_updated, strict=False)
                            dummy_codebook.load_state_dict(model_dict_updated, strict=False)

                            codes = get_codes_step(vae, images, k)
                            codes = codes.reduce_mean().to(torch.long)
                            hard_recons = hard_recons_step(
                                dummy_decoder, dummy_codebook, codes)
                        else:
                            codes = vae.get_codebook_indices(images[:k])
                            hard_recons = vae.decode(codes)

                    images, recons = map(lambda t: t[:k], (images, recons))
                    images, recons, hard_recons, codes = map(
                        lambda t: t.detach().cpu(),
                        (images, recons, hard_recons, codes))
                    images, recons, hard_recons = map(
                        lambda t: make_grid(t.float(),
                                            nrow=int(sqrt(k)),
                                            normalize=True,
                                            range=(-1, 1)),
                        (images, recons, hard_recons))

                    # logs = {
                    #     **logs,
                    #     'sample images': wandb.Image(images, caption='original images'),
                    #     'reconstructions': wandb.Image(recons, caption='reconstructions'),
                    #     'hard reconstructions': wandb.Image(hard_recons, caption='hard reconstructions'),
                    #     'codebook_indices': wandb.Histogram(codes),
                    #     'temperature': temp
                    # }

                if args.model_parallel:
                    filename = f'{args.model_dir}/vae.pt'
                    if smp.dp_rank() == 0:
                        if args.save_full_model:
                            model_dict = vae.state_dict()
                            opt_dict = opt.state_dict()
                            smp.save(
                                {
                                    "model_state_dict": model_dict,
                                    "optimizer_state_dict": opt_dict
                                },
                                filename,
                                partial=False,
                            )
                        else:
                            model_dict = vae.local_state_dict()
                            opt_dict = opt.local_state_dict()
                            smp.save(
                                {
                                    "model_state_dict": model_dict,
                                    "optimizer_state_dict": opt_dict
                                },
                                filename,
                                partial=True,
                            )
                    smp.barrier()
                else:
                    save_model(f'{args.model_dir}/vae.pt')
                    # wandb.save(f'{args.model_dir}/vae.pt')

                # temperature anneal
                temp = max(temp * math.exp(-ANNEAL_RATE * global_step), TEMP_MIN)

                # lr decay
                sched.step()

            # Collective loss, averaged
            if args.model_parallel:
                avg_loss = loss.detach().clone()
                # print("args.world_size : {}".format(args.world_size))
                avg_loss /= args.world_size
            else:
                avg_loss = deepspeed_utils.average_all(loss)

            if args.rank == 0:
                if i % 100 == 0:
                    lr = sched.get_last_lr()[0]
                    print(epoch, i, f'lr - {lr:6f}, loss - {avg_loss.item()},')

                    logs = {
                        **logs,
                        'epoch': epoch,
                        'iter': i,
                        'loss': avg_loss.item(),
                        'lr': lr
                    }
                # wandb.log(logs)

            global_step += 1

            if args.rank == 0:
                # Every print_freq iterations, check the loss, accuracy, and speed.
                # For best performance, it doesn't make sense to print these metrics every
                # iteration, since they incur an allreduce and some host<->device syncs.

                # Measure accuracy
                # prec1, prec5 = util.accuracy(output, target, topk=(1, 5))

                # to_python_float incurs a host<->device sync
                losses.update(util.to_python_float(loss), images.size(0))
                # top1.update(util.to_python_float(prec1), images.size(0))
                # top5.update(util.to_python_float(prec5), images.size(0))

                # Waiting until finishing operations on GPU (PyTorch default: async)
                torch.cuda.synchronize()
                batch_time.update((time.time() - start) / args.log_interval)
                end = time.time()

                print(
                    'Epoch: [{0}][{1}/{2}] '
                    'Train_Time={batch_time.val:.3f}: avg-{batch_time.avg:.3f}, '
                    'Train_Speed={3:.3f} ({4:.3f}), '
                    'Train_Loss={loss.val:.10f}:({loss.avg:.4f}),'.format(
                        epoch,
                        i,
                        len(dl),
                        args.world_size * BATCH_SIZE / batch_time.val,
                        args.world_size * BATCH_SIZE / batch_time.avg,
                        batch_time=batch_time,
                        loss=losses))

        # if deepspeed_utils.is_root_worker():
        #     # save trained model to wandb as an artifact every epoch's end
        #     model_artifact = wandb.Artifact('trained-vae', type='model', metadata=dict(model_config))
        #     model_artifact.add_file(f'{args.model_dir}/vae.pt')
        #     run.log_artifact(model_artifact)

    if args.rank == 0:  # if deepspeed_utils.is_root_worker():
        # save final vae and cleanup
        if args.model_parallel:
            logger.debug('save model_parallel')
        else:
            save_model(os.path.join(args.model_dir, 'vae-final.pt'))
            # wandb.save(f'{args.model_dir}/vae-final.pt')

            # model_artifact = wandb.Artifact('trained-vae', type='model', metadata=dict(model_config))
            # model_artifact.add_file(f'{args.model_dir}/vae-final.pt')
            # run.log_artifact(model_artifact)

            # wandb.finish()

    if args.model_parallel:
        if args.assert_losses:
            if args.horovod or args.ddp:
                # SM Distributed: If using data parallelism, gather all losses across
                # different model replicas and check if losses match.
                losses = smp.allgather(loss, smp.DP_GROUP)
                for l in losses:
                    print(l)
                    assert math.isclose(l, losses[0])
                assert loss < 0.18
            else:
                assert loss < 0.08

        smp.barrier()
        print("SMP training finished successfully")
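# Reload sketch for vae-final.pt: save_model() above stores {'hparams', 'weights'},
# so the VAE can be rebuilt the same way the DALLE training scripts handle --vae_path.
# The path below is a placeholder for os.path.join(args.model_dir, 'vae-final.pt').
loaded = torch.load('vae-final.pt', map_location='cpu')
vae = DiscreteVAE(**loaded['hparams'])
vae.load_state_dict(loaded['weights'])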