def main(args):
    # constants

    VAE_PATH = args.vae_path
    DALLE_PATH = args.dalle_path
    RESUME = exists(DALLE_PATH)

    EPOCHS = args.epochs
    BATCH_SIZE = args.batch_size

    LEARNING_RATE = args.learning_rate
    GRAD_CLIP_NORM = args.clip_grad_norm
    LR_DECAY = args.lr_decay

    MODEL_DIM = args.dim
    TEXT_SEQ_LEN = args.text_seq_len
    DEPTH = args.depth
    HEADS = args.heads
    DIM_HEAD = args.dim_head
    REVERSIBLE = args.reversible
    LOSS_IMG_WEIGHT = args.loss_img_weight

    ATTN_TYPES = tuple(args.attn_types.split(','))

    DEEPSPEED_CP_AUX_FILENAME = 'auxiliary.pt'

    # initialize distributed backend
    if args.sagemakermp:
        args.deepspeed = False
        using_deepspeed = False
    else:
        args.deepspeed = True

    distr_backend = distributed_utils.set_backend_from_args(args)
    distr_backend.initialize(args)

    if args.sagemakermp:
        args = smp_init(args)
        distributed_utils.using_backend(distributed_utils.SageMakerMPBackend)
    else:
        using_deepspeed = \
            distributed_utils.using_backend(distributed_utils.DeepSpeedBackend)

    args.rank = int(os.environ.get('RANK'))
    args.world_size = int(os.environ.get('WORLD_SIZE'))
    args.local_rank = int(os.environ.get('LOCAL_RANK'))
    args.global_rank = args.rank

    logger.debug(f"using_deepspeed : {using_deepspeed}")
    logger.debug(
        f"args.local_rank : {args.local_rank}, args.rank : {args.rank}")

    # tokenizer
    logger.debug(
        f"exists(args.bpe_path) : {exists(args.bpe_path)}, args.chinese : {args.chinese}")

    if exists(args.bpe_path):
        klass = HugTokenizer if args.hug else YttmTokenizer
        tokenizer = klass(args.bpe_path)
    elif args.chinese:
        tokenizer = ChineseTokenizer()
    else:
        tokenizer = SimpleTokenizer()

    # reconstitute vae
    if RESUME:
        dalle_path = Path(DALLE_PATH)
        if using_deepspeed:
            cp_dir = cp_path_to_dir(dalle_path, 'ds')
            assert cp_dir.is_dir(), \
                f'DeepSpeed checkpoint directory {cp_dir} not found'
            dalle_path = cp_dir / DEEPSPEED_CP_AUX_FILENAME
        else:
            assert dalle_path.exists(), 'DALL-E model file does not exist'

        loaded_obj = torch.load(str(dalle_path), map_location='cpu')

        dalle_params, vae_params, weights = loaded_obj['hparams'], loaded_obj[
            'vae_params'], loaded_obj['weights']

        if vae_params is not None:
            vae = DiscreteVAE(**vae_params)
        else:
            vae_klass = OpenAIDiscreteVAE if not args.taming else VQGanVAE1024
            vae = vae_klass(args)

        dalle_params = dict(**dalle_params)
        IMAGE_SIZE = vae.image_size
    else:
        if exists(VAE_PATH):
            vae_path = Path(VAE_PATH)
            assert vae_path.exists(), 'VAE model file does not exist'
            assert not vae_path.is_dir(), \
                ('Cannot load VAE model from directory; please use a '
                 'standard *.pt checkpoint. '
                 'Currently, merging a DeepSpeed-partitioned VAE into a DALLE '
                 'model is not supported.')

            loaded_obj = torch.load(str(vae_path))

            vae_params, weights = loaded_obj['hparams'], loaded_obj['weights']

            vae = DiscreteVAE(**vae_params)
            vae.load_state_dict(weights)
        else:
            if args.rank == 0:  # if distr_backend.is_root_worker():
                print('using pretrained VAE for encoding images to tokens')
            vae_params = None

            logger.debug(f"************* args.taming : {args.taming}")
            vae_klass = OpenAIDiscreteVAE if not args.taming else VQGanVAE1024
            vae = vae_klass(args)

        IMAGE_SIZE = vae.image_size

        dalle_params = dict(
            num_text_tokens=tokenizer.vocab_size,
            text_seq_len=TEXT_SEQ_LEN,
            dim=MODEL_DIM,
            depth=DEPTH,
            heads=HEADS,
            dim_head=DIM_HEAD,
            reversible=REVERSIBLE,
            loss_img_weight=LOSS_IMG_WEIGHT,
            attn_types=ATTN_TYPES,
        )

    # configure OpenAI VAE for float16s
    if isinstance(vae, OpenAIDiscreteVAE) and args.fp16:
        vae.enc.blocks.output.conv.use_float16 = True

    # create dataset and dataloader
    is_shuffle = not distributed_utils.using_backend(
        distributed_utils.HorovodBackend)

    ds = TextImageDataset(
        args.image_text_folder,
        text_len=TEXT_SEQ_LEN,
        image_size=IMAGE_SIZE,
        resize_ratio=args.resize_ratio,
        truncate_captions=args.truncate_captions,
        tokenizer=tokenizer,
        shuffle=is_shuffle,
    )

    assert len(ds) > 0, 'dataset is empty'
    # if distr_backend.is_root_worker():
    if args.rank == 0:
        print(f'{len(ds)} image-text pairs found for training')

    if not is_shuffle:
        data_sampler = torch.utils.data.distributed.DistributedSampler(
            ds, num_replicas=args.world_size, rank=args.rank)
    elif args.sagemakermp:
        args.ds = ds
        ds = split_dataset(args)
        data_sampler = None
    else:
        data_sampler = None
    print(f"data_sampler : {data_sampler}")

    # uncorrectable NVLink error was detected during the execution --> remove
    kwargs = {'num_workers': args.num_worker, 'pin_memory': True}

    dl = DataLoader(ds,
                    batch_size=BATCH_SIZE,
                    shuffle=is_shuffle,
                    drop_last=True,
                    sampler=data_sampler,
                    **kwargs)

    logger.info("Processes {}/{} ({:.0f}%) of train data".format(
        len(dl.sampler), len(dl.dataset),
        100. * len(dl.sampler) / len(dl.dataset)))

    # initialize DALL-E
    dalle = DALLE(vae=vae, **dalle_params)
    if not using_deepspeed:
        if args.fp16:
            dalle = dalle.half()
        dalle = dalle.cuda()

    if RESUME and not using_deepspeed:
        dalle.load_state_dict(weights)

    # optimizer
    opt = Adam(get_trainable_params(dalle), lr=LEARNING_RATE)

    if LR_DECAY:
        scheduler = ReduceLROnPlateau(
            opt,
            mode="min",
            factor=0.5,
            patience=10,
            cooldown=10,
            min_lr=1e-6,
            verbose=True,
        )

    # if distr_backend.is_root_worker():
    if args.global_rank == 0:
        # experiment tracker
        model_config = dict(depth=DEPTH, heads=HEADS, dim_head=DIM_HEAD)

        logger.debug(f"args.wandb_name : {args.wandb_name}, RESUME : {RESUME}")

        run = wandb.init(
            project=args.wandb_name,  # 'dalle_train_transformer' by default
            resume=RESUME,
            config=model_config,
        )

    # distribute
    distr_backend.check_batch_size(BATCH_SIZE)
    deepspeed_config = {
        'train_batch_size': BATCH_SIZE,
        'gradient_clipping': GRAD_CLIP_NORM,
        'fp16': {
            'enabled': args.fp16,
        },
    }

    (distr_dalle, distr_opt, distr_dl,
     distr_scheduler) = distr_backend.distribute(
         args=args,
         model=dalle,
         optimizer=opt,
         model_parameters=get_trainable_params(dalle),
         training_data=ds if using_deepspeed else dl,
         lr_scheduler=scheduler if LR_DECAY else None,
         config_params=deepspeed_config,
     )
    avoid_model_calls = using_deepspeed and args.fp16

    if args.sagemakermp:
        args.distr_dalle = smp.DistributedModel(distr_dalle)
        args.scaler = smp.amp.GradScaler()
        args.distr_opt = smp.DistributedOptimizer(distr_opt)

    if RESUME and using_deepspeed:
        distr_dalle.load_checkpoint(str(cp_dir))

    # training
    for epoch in range(EPOCHS):
        logger.debug(f"********* epoch : {epoch} **********")
        if data_sampler:
            data_sampler.set_epoch(epoch)

        for i, (text, images) in enumerate(distr_dl):
            if args.fp16:
                images = images.half()

            text, images = map(lambda t: t.cuda(), (text, images))

            if args.sagemakermp:
                args.distr_opt.zero_grad()

                loss = train_step(args, text, images, return_loss=True)
                loss = loss.reduce_mean()
            else:
                loss = distr_dalle(text, images, return_loss=True, args=args)

            if using_deepspeed:
                distr_dalle.backward(loss)
                distr_dalle.step()
                # Gradients are automatically zeroed after the step
            elif args.sagemakermp:
                if args.amp:
                    args.scaler.step(args.distr_opt)
                    args.scaler.update()
                else:
                    # some optimizers like adadelta from PT 1.8 dont like it
                    # when optimizer.step is called with no param
                    if len(list(args.distr_dalle.local_parameters())) > 0:
                        args.distr_opt.step()
            else:
                loss.backward()
                clip_grad_norm_(distr_dalle.parameters(), GRAD_CLIP_NORM)
                distr_opt.step()
                distr_opt.zero_grad()

            # Collective loss, averaged
            avg_loss = distr_backend.average_all(loss)

            log = {}

            # if i % 10 == 0 and distr_backend.is_root_worker():
            if i % 10 == 0 and args.rank == 0:
                print(epoch, i, f'loss - {avg_loss.item()}')

                log = {
                    **log,
                    'epoch': epoch,
                    'iter': i,
                    'loss': avg_loss.item()
                }

            if i % 100 == 0:
                # if distr_backend.is_root_worker():
                if args.rank == 0:
                    sample_text = text[:1]
                    token_list = sample_text.masked_select(
                        sample_text != 0).tolist()
                    decoded_text = tokenizer.decode(token_list)

                    logger.debug(
                        f"******* avoid_model_calls : {avoid_model_calls}")
                    if not avoid_model_calls:
                        # CUDA index errors when we don't guard this
                        image = dalle.generate_images(
                            text[:1], filter_thres=0.9)  # topk sampling at 0.9

                    wandb.save('./dalle.pt')

                    log = {
                        **log,
                    }
                    if not avoid_model_calls:
                        log['image'] = wandb.Image(image, caption=decoded_text)

                args.distr_dalle = distr_dalle
                args.dalle_params = dalle_params
                args.vae_params = vae_params
                args.using_deepspeed = using_deepspeed
                args.DEEPSPEED_CP_AUX_FILENAME = DEEPSPEED_CP_AUX_FILENAME

                save_model(args, f'{args.model_dir}/dalle.pt')

            # if distr_backend.is_root_worker():
            if args.rank == 0:
                wandb.log(log)

            # text, images = prefetcher.next()

        if LR_DECAY and not using_deepspeed:
            # Scheduler is automatically progressed after the step when
            # using DeepSpeed.
            distr_scheduler.step(loss)

        # if distr_backend.is_root_worker():
        if args.global_rank == 0:
            # save trained model to wandb as an artifact every epoch's end
            model_artifact = wandb.Artifact('trained-dalle',
                                            type='model',
                                            metadata=dict(model_config))
            # model_artifact.add_file('dalle.pt')
            run.log_artifact(model_artifact)

    args.distr_dalle = distr_dalle
    args.dalle_params = dalle_params
    args.vae_params = vae_params
    args.using_deepspeed = using_deepspeed
    args.DEEPSPEED_CP_AUX_FILENAME = DEEPSPEED_CP_AUX_FILENAME

    save_model(args, f'{args.model_dir}/dalle-final.pt')

    # if distr_backend.is_root_worker():
    if args.global_rank == 0:
        wandb.save('./dalle-final.pt')
        model_artifact = wandb.Artifact('trained-dalle',
                                        type='model',
                                        metadata=dict(model_config))
        # model_artifact.add_file('dalle-final.pt')
        run.log_artifact(model_artifact)

        wandb.finish()
def main(args):
    # constants
    print(f"torch.cuda.nccl.version() : {torch.cuda.nccl.version()}")

    VAE_PATH = args.vae_path
    DALLE_PATH = args.dalle_path
    RESUME = exists(DALLE_PATH)

    EPOCHS = args.epochs
    BATCH_SIZE = args.batch_size

    LEARNING_RATE = args.learning_rate
    GRAD_CLIP_NORM = args.clip_grad_norm
    LR_DECAY = args.lr_decay
    SAVE_EVERY_N_STEPS = args.save_every_n_steps

    MODEL_DIM = args.dim
    TEXT_SEQ_LEN = args.text_seq_len
    DEPTH = args.depth
    HEADS = args.heads
    DIM_HEAD = args.dim_head
    REVERSIBLE = args.reversible
    LOSS_IMG_WEIGHT = args.loss_img_weight

    ATTN_TYPES = tuple(args.attn_types.split(','))

    DEEPSPEED_CP_AUX_FILENAME = 'auxiliary.pt'
    DALLE_OUTPUT_FILE_NAME = args.dalle_output_file_name

    # initialize distributed backend
    args.deepspeed = True
    distr_backend = distributed_utils.set_backend_from_args(args)
    distr_backend.initialize(args)

    using_deepspeed = \
        distributed_utils.using_backend(distributed_utils.DeepSpeedBackend)

    args.rank = int(os.environ.get('RANK'))
    args.world_size = int(os.environ.get('WORLD_SIZE'))
    args.local_rank = int(os.environ.get('LOCAL_RANK'))
    args.global_rank = args.rank

    logger.debug(f"using_deepspeed : {using_deepspeed}")
    logger.debug(
        f"args.local_rank : {args.local_rank}, args.rank : {args.rank}")
    print(
        f"********* torch.distributed.get_rank() : {torch.distributed.get_rank()}"
    )

    # tokenizer
    logger.debug(
        f"exists(args.bpe_path) : {exists(args.bpe_path)}, args.chinese : {args.chinese}"
    )

    # if args.local_rank == 0 or args.local_rank == 1:
    #     # print(f"args.job_name : {args.job_name}")
    #     gpu_mon_thread = aws_util.GPUMon(device_index=args.local_rank, job_name=args.job_name)
    #     gpu_mon_thread.start()

    if exists(args.bpe_path):
        klass = HugTokenizer if args.hug else YttmTokenizer
        tokenizer = klass(args.bpe_path)
    elif args.chinese:
        tokenizer = ChineseTokenizer()
    else:
        tokenizer = SimpleTokenizer()

    # reconstitute vae
    if RESUME:
        dalle_path = Path(DALLE_PATH)
        if using_deepspeed:
            cp_dir = cp_path_to_dir(dalle_path, 'ds')
            assert cp_dir.is_dir(), \
                f'DeepSpeed checkpoint directory {cp_dir} not found'
            dalle_path = cp_dir / DEEPSPEED_CP_AUX_FILENAME
        else:
            assert dalle_path.exists(), 'DALL-E model file does not exist'

        loaded_obj = torch.load(str(dalle_path), map_location='cpu')

        dalle_params, vae_params, weights = loaded_obj['hparams'], loaded_obj[
            'vae_params'], loaded_obj['weights']

        if vae_params is not None:
            vae = DiscreteVAE(**vae_params)
        else:
            vae_klass = OpenAIDiscreteVAE if not args.taming else VQGanVAE1024
            vae = vae_klass(args)

        dalle_params = dict(**dalle_params)
        IMAGE_SIZE = vae.image_size
    else:
        if exists(VAE_PATH):
            vae_path = Path(VAE_PATH)
            assert vae_path.exists(), 'VAE model file does not exist'
            assert not vae_path.is_dir(), \
                ('Cannot load VAE model from directory; please use a '
                 'standard *.pt checkpoint. '
                 'Currently, merging a DeepSpeed-partitioned VAE into a DALLE '
                 'model is not supported.')

            loaded_obj = torch.load(str(vae_path))

            vae_params, weights = loaded_obj['hparams'], loaded_obj['weights']

            vae = DiscreteVAE(**vae_params)
            vae.load_state_dict(weights)
        else:
            if args.rank == 0:  # if distr_backend.is_root_worker():
                print('using pretrained VAE for encoding images to tokens')
            vae_params = None

            logger.debug(f"************* args.taming : {args.taming}")
            vae_klass = OpenAIDiscreteVAE if not args.taming else VQGanVAE1024
            vae = vae_klass(args)

        IMAGE_SIZE = vae.image_size

        dalle_params = dict(
            num_text_tokens=tokenizer.vocab_size,
            text_seq_len=TEXT_SEQ_LEN,
            dim=MODEL_DIM,
            depth=DEPTH,
            heads=HEADS,
            dim_head=DIM_HEAD,
            reversible=REVERSIBLE,
            loss_img_weight=LOSS_IMG_WEIGHT,
            attn_types=ATTN_TYPES,
        )

    # configure OpenAI VAE for float16s
    if isinstance(vae, OpenAIDiscreteVAE) and args.fp16:
        vae.enc.blocks.output.conv.use_float16 = True

    # create dataset and dataloader
    is_shuffle = not distributed_utils.using_backend(
        distributed_utils.HorovodBackend)

    ds = TextImageDataset(
        args.image_text_folder,
        text_len=TEXT_SEQ_LEN,
        image_size=IMAGE_SIZE,
        resize_ratio=args.resize_ratio,
        truncate_captions=args.truncate_captions,
        tokenizer=tokenizer,
        shuffle=is_shuffle,
    )

    assert len(ds) > 0, 'dataset is empty'
    # if distr_backend.is_root_worker():
    if args.rank == 0:
        print(f'{len(ds)} image-text pairs found for training')

    if not is_shuffle:
        data_sampler = torch.utils.data.distributed.DistributedSampler(
            ds, num_replicas=args.world_size, rank=args.rank)
    else:
        data_sampler = None

    kwargs = {'num_workers': args.num_worker, 'pin_memory': True}

    dl = DataLoader(ds,
                    batch_size=BATCH_SIZE,
                    shuffle=is_shuffle,
                    drop_last=True,
                    sampler=data_sampler,
                    **kwargs)

    logger.info("Processes {}/{} ({:.0f}%) of train data".format(
        len(dl.sampler), len(dl.dataset),
        100. * len(dl.sampler) / len(dl.dataset)))

    # initialize DALL-E
    dalle = DALLE(vae=vae, **dalle_params)
    if not using_deepspeed:
        if args.fp16:
            dalle = dalle.half()
        dalle = dalle.cuda()

    if RESUME and not using_deepspeed:
        dalle.load_state_dict(weights)

    # optimizer
    opt = Adam(get_trainable_params(dalle), lr=LEARNING_RATE)

    if LR_DECAY:
        scheduler = ReduceLROnPlateau(
            opt,
            mode="min",
            factor=0.5,
            patience=10,
            cooldown=10,
            min_lr=1e-6,
            verbose=True,
        )

    # if distr_backend.is_root_worker():
    if args.rank == 0:
        # experiment tracker
        model_config = dict(depth=DEPTH, heads=HEADS, dim_head=DIM_HEAD)

        # wandb_dir = '/tmp/wandb'
        # if not os.path.exists(wandb_dir):
        #     os.makedirs(wandb_dir)

        run = wandb.init(
            project=args.wandb_name,  # 'dalle_train_transformer' by default
            resume=RESUME,
            config=model_config,
            # dir=wandb_dir
        )

    # distribute
    distr_backend.check_batch_size(BATCH_SIZE)
    deepspeed_config = {
        'train_batch_size': BATCH_SIZE,
        'gradient_clipping': GRAD_CLIP_NORM,
        'fp16': {
            'enabled': args.fp16,
        },
    }

    (distr_dalle, distr_opt, distr_dl,
     distr_scheduler) = distr_backend.distribute(
         args=args,
         model=dalle,
         optimizer=opt,
         model_parameters=get_trainable_params(dalle),
         training_data=ds if using_deepspeed else dl,
         lr_scheduler=scheduler if LR_DECAY else None,
         config_params=deepspeed_config,
     )
    avoid_model_calls = using_deepspeed and args.fp16

    if RESUME and using_deepspeed:
        distr_dalle.load_checkpoint(str(cp_dir))

    # training
    for epoch in range(EPOCHS):
        logger.debug(f"********* epoch : {epoch} **********")
        if data_sampler:
            data_sampler.set_epoch(epoch)

        for i, (text, images) in enumerate(distr_dl):
            if i % 10 == 0 and args.rank == 0:
                t = time.time()

            if args.fp16:
                images = images.half()

            text, images = map(lambda t: t.cuda(), (text, images))

            loss = distr_dalle(text, images, return_loss=True)

            if using_deepspeed:
                distr_dalle.backward(loss)
                distr_dalle.step()
                # Gradients are automatically zeroed after the step
            else:
                loss.backward()
                clip_grad_norm_(distr_dalle.parameters(), GRAD_CLIP_NORM)
                distr_opt.step()
                distr_opt.zero_grad()

            # Collective loss, averaged
            avg_loss = distr_backend.average_all(loss)

            log = {}

            # if i % 10 == 0 and distr_backend.is_root_worker():
            if i % 10 == 0 and args.rank == 0:
                print(epoch, i, f'loss - {avg_loss.item()}')

                log = {
                    **log,
                    'epoch': epoch,
                    'iter': i,
                    'loss': avg_loss.item()
                }

            if i % SAVE_EVERY_N_STEPS == 0:
                args.distr_dalle = distr_dalle
                args.dalle_params = dalle_params
                args.vae_params = vae_params
                args.using_deepspeed = using_deepspeed
                args.DEEPSPEED_CP_AUX_FILENAME = DEEPSPEED_CP_AUX_FILENAME

                save_model(args, f"{args.model_dir + '/' + DALLE_OUTPUT_FILE_NAME}")

            if i % 100 == 0:
                # if distr_backend.is_root_worker():
                if args.rank == 0:
                    sample_text = text[:1]
                    token_list = sample_text.masked_select(
                        sample_text != 0).tolist()
                    decoded_text = tokenizer.decode(token_list)

                    if not avoid_model_calls:
                        # CUDA index errors when we don't guard this
                        image = dalle.generate_images(
                            text[:1], filter_thres=0.9)  # topk sampling at 0.9

                    log = {
                        **log,
                    }
                    if not avoid_model_calls:
                        log['image'] = wandb.Image(image, caption=decoded_text)

            if i % 10 == 9 and args.rank == 0:
                sample_per_sec = BATCH_SIZE * 10 / (time.time() - t)
                log["sample_per_sec"] = sample_per_sec
                print(epoch, i, f'sample_per_sec - {sample_per_sec}')

            # if distr_backend.is_root_worker():
            if args.rank == 0:
                wandb.log(log)

        if LR_DECAY and not using_deepspeed:
            # Scheduler is automatically progressed after the step when
            # using DeepSpeed.
            distr_scheduler.step(loss)

        args.distr_dalle = distr_dalle
        args.dalle_params = dalle_params
        args.vae_params = vae_params
        args.using_deepspeed = using_deepspeed
        args.DEEPSPEED_CP_AUX_FILENAME = DEEPSPEED_CP_AUX_FILENAME

        save_model(args, f"{args.model_dir + '/' + DALLE_OUTPUT_FILE_NAME}")

        # sync_local_checkpoints_to_s3(local_path=f'{args.model_dir}', s3_path='s3://lgaivision-coco-usva/Dalle_Model/tmd/')
        sync_local_checkpoints_to_s3(
            args.model_dir, os.path.join(args.output_s3, args.job_name + "/temp"))

        # if distr_backend.is_root_worker():
        if args.rank == 0:
            # save trained model to wandb as an artifact every epoch's end
            model_artifact = wandb.Artifact('trained-dalle',
                                            type='model',
                                            metadata=dict(model_config))

            import glob
            print(f"************** file : {glob.glob(args.model_dir + '/*')}")
            try:
                print(f"wandb.run.dir : {wandb.run.dir}")
                print(
                    f"************** file wandb: {glob.glob(wandb.run.dir + '/*')}"
                )
            except:
                pass

            model_artifact.add_file(
                f"{args.model_dir + '/' + DALLE_OUTPUT_FILE_NAME}")
            run.log_artifact(model_artifact)

    args.distr_dalle = distr_dalle
    args.dalle_params = dalle_params
    args.vae_params = vae_params
    args.using_deepspeed = using_deepspeed
    args.DEEPSPEED_CP_AUX_FILENAME = DEEPSPEED_CP_AUX_FILENAME

    resource_check(args)
    save_model(args, f"{args.model_dir + '/' + DALLE_OUTPUT_FILE_NAME}")

    if args.rank == 0:
        # from distutils.dir_util import copy_tree
        # copy_tree(f'{args.model_dir}', f'{args.model_dir_last}')
        sync_local_checkpoints_to_s3(
            args.model_dir, os.path.join(args.output_s3, args.job_name + "/temp"))

    resource_check(args)

    # if distr_backend.is_root_worker():
    if args.rank == 0:
        wandb.save(f"{args.model_dir + '/' + DALLE_OUTPUT_FILE_NAME}",
                   base_path=args.model_dir)

        model_artifact = wandb.Artifact('trained-dalle',
                                        type='model',
                                        metadata=dict(model_config))
        model_artifact.add_file(
            f"{args.model_dir + '/' + DALLE_OUTPUT_FILE_NAME}")
        run.log_artifact(model_artifact)

        wandb.finish()

    # if args.local_rank == 0 or args.local_rank == 1:
    #     gpu_mon_thread.kill()

    distributed_utils.backend.local_barrier()
    print("************************ Finished *************************")