def generate_images(captions):
    print(sum(captions, start=[]))
    all_words = list(sorted(frozenset(sum(captions, start=[]))))
    word_tokens = dict(zip(all_words, range(1, len(all_words) + 1)))
    caption_tokens = [[word_tokens[w] for w in c] for c in captions]
    logging.info(f"{all_words = }")
    logging.info(f"{word_tokens = }")
    logging.info(f"{caption_tokens = }")

    longest_caption = max(len(c) for c in captions)
    captions_array = np.zeros((len(caption_tokens), longest_caption), dtype=np.int64)
    for i in range(len(caption_tokens)):
        captions_array[i, :len(caption_tokens[i])] = caption_tokens[i]
    # captions_array = torch.from_numpy(captions_array).cuda()
    captions_array = torch.from_numpy(captions_array)
    captions_mask = captions_array != 0
    logging.info(f"{captions_array = }")

    dalle = DALLE(
        dim=1024,
        vae=vae,                                # automatically infer (1) image sequence length and (2) number of image tokens
        num_text_tokens=len(word_tokens) + 1,   # vocab size for text
        text_seq_len=longest_caption,           # text sequence length
        depth=12,                               # should aim to be 64
        heads=16,                               # attention heads
        dim_head=64,                            # attention head dimension
        attn_dropout=0.1,                       # attention dropout
        ff_dropout=0.1,                         # feedforward dropout
    )  # .cuda

    generated_image_codes = []
    with torch.no_grad():
        for i in range(0, len(captions), 128):
            generated = generate_image_code(
                dalle,
                captions_array[i:i + 128, ...],
                mask=captions_mask[i:i + 128, ...],
            )
            generated_image_codes.append(generated)
    generated_image_codes = torch.cat(generated_image_codes, dim=0)

    with torch.no_grad():
        generated_images = vae.decode(generated_image_codes)
    logging.info(f"{generated_images = }")
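# Hypothetical usage sketch: generate_images() expects each caption as a list of
# words, since the whitespace-level vocabulary above is built directly from them.
# Assumes `vae` and the `generate_image_code` helper used above are defined
# elsewhere in this script.
toy_captions = [
    "a red bird on a branch".split(),
    "a blue car parked outside".split(),
]
generate_images(toy_captions)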
def get_dalle(vae, vocab, args):
    dalle = DALLE(dim=args.codebook_dims,
                  vae=vae,
                  num_text_tokens=len(vocab) + 1,
                  text_seq_len=len(vocab),
                  depth=16,
                  heads=8,
                  dim_head=64,
                  attn_dropout=0.1,
                  ff_dropout=0.1,
                  reversible=True)
    if args.dalle is not None and os.path.isfile(args.dalle):
        print(f"loading state dict from {args.dalle}")
        dalle.load_state_dict(torch.load(args.dalle))
    dalle.to(args.device)
    vae.to(args.device)
    return dalle
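# Hypothetical usage sketch: `args` only needs the attributes get_dalle() reads
# (codebook_dims, dalle, device); the values below are placeholders, and `vae`
# and `vocab` are assumed to exist already.
from argparse import Namespace

args = Namespace(codebook_dims=512, dalle="checkpoints/dalle.pth", device="cuda")
dalle = get_dalle(vae, vocab, args)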
    data_sampler = torch.utils.data.distributed.DistributedSampler(
        ds,
        num_replicas=distr_backend.get_world_size(),
        rank=distr_backend.get_rank())
else:
    data_sampler = None

dl = DataLoader(ds,
                batch_size=BATCH_SIZE,
                shuffle=not data_sampler,
                drop_last=True,
                sampler=data_sampler)

# initialize DALL-E
dalle = DALLE(vae=vae, **dalle_params)
if args.fp16:
    dalle = dalle.half()
dalle = dalle.cuda()

if RESUME:
    dalle.load_state_dict(weights)

# optimizer
opt = Adam(dalle.parameters(), lr=LEARNING_RATE)

if LR_DECAY:
    scheduler = ReduceLROnPlateau(
        opt,
        mode="min",
assert dalle_path.exists(), 'trained DALL-E must exist'

load_obj = torch.load(str(dalle_path))
dalle_params, vae_params, weights = load_obj.pop('hparams'), load_obj.pop('vae_params'), load_obj.pop('weights')
dalle_params.pop('vae', None)  # cleanup later

if vae_params is not None:
    vae = DiscreteVAE(**vae_params)
elif not args.taming:
    vae = OpenAIDiscreteVAE()
else:
    vae = VQGanVAE1024()

dalle = DALLE(vae=vae, **dalle_params).cuda()
dalle.load_state_dict(weights)

# generate images
image_size = vae.image_size
texts = args.text.split('|')

for text in tqdm(texts):
    # tokenize the current caption (not args.text, which is the full '|'-joined string)
    text = tokenizer.tokenize([text], dalle.text_seq_len).cuda()
    text = repeat(text, '() n -> b n', b=args.num_images)
    outputs = []
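    # Hedged sketch of how this loop typically continues (the snippet stops at
    # `outputs = []`): sample candidates in chunks with top-k filtering and
    # collect them. `args.batch_size` and `args.top_k` are assumptions here.
    for text_chunk in text.split(args.batch_size):
        output = dalle.generate_images(text_chunk, filter_thres=args.top_k)
        outputs.append(output)
    outputs = torch.cat(outputs)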
    hidden_dim=64,            # hidden dimension
    num_resnet_blocks=1,      # number of resnet blocks
    temperature=0.9,          # gumbel softmax temperature, the lower this is, the harder the discretization
    straight_through=False,   # straight-through for gumbel softmax. unclear if it is better one way or the other
).cuda()
vae.load_state_dict(torch.load("Vae-small.pth"))

dalle = DALLE(
    dim=1024,
    vae=vae,                      # automatically infer (1) image sequence length and (2) number of image tokens
    num_text_tokens=NUM_TOKENS,   # vocab size for text
    text_seq_len=TEXTSEQLEN,      # text sequence length
    depth=12,                     # should aim to be 64
    heads=16,                     # attention heads
    dim_head=64,                  # attention head dimension
    attn_dropout=0.1,             # attention dropout
    ff_dropout=0.1                # feedforward dropout
).cuda()
dalle.load_state_dict(torch.load("dalle-small.pth"))

"""
text = torch.randint(0, NUM_TOKENS, (BATCH_SIZE, TEXTSEQLEN))
images = torch.randn(BATCH_SIZE, 3, IMAGE_SIZE, IMAGE_SIZE)
mask = torch.ones_like(text).bool()
"""

tokenDset = token_dataset('./coco/merged-smallsample.txt')
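# Quick smoke test (sketch, not part of the original script): the commented-out
# random tensors above can be used to sanity-check the loaded DALL-E end to end
# before switching to the real COCO token dataset.
text = torch.randint(0, NUM_TOKENS, (BATCH_SIZE, TEXTSEQLEN)).cuda()
images = torch.randn(BATCH_SIZE, 3, IMAGE_SIZE, IMAGE_SIZE).cuda()
mask = torch.ones_like(text).bool()
loss = dalle(text, images, mask=mask, return_loss=True)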
# create dataset and dataloader
ds = TextImageDataset(
    args.image_text_folder,
    text_len=TEXT_SEQ_LEN,
    image_size=IMAGE_SIZE
)

assert len(ds) > 0, 'dataset is empty'
print(f'{len(ds)} image-text pairs found for training')

dl = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

# initialize DALL-E
dalle = DALLE(**dalle_params).cuda()

if exists(args.dalle_path):
    dalle.load_state_dict(weights)

# optimizer
opt = Adam(dalle.parameters(), lr=LEARNING_RATE)

# experiment tracker
import wandb

wandb.config.depth = DEPTH
wandb.config.heads = HEADS
wandb.config.dim_head = DIM_HEAD
                                 EPOCHS * DATASET_SIZE))
        loss = vae(img, return_recon_loss=True)
        VAEloss.append(loss.cpu().detach().numpy())
        loss.backward()
        optimizerVAE.step()

np.savetxt("vaeloss.csv", np.asarray(VAEloss), delimiter=",")
torch.save(vae.state_dict(), "Vae-small.pth")

dalle = DALLE(
    dim=1024,
    vae=vae,                      # automatically infer (1) image sequence length and (2) number of image tokens
    num_text_tokens=NUM_TOKENS,   # vocab size for text
    text_seq_len=TEXTSEQLEN,      # text sequence length
    depth=12,                     # should aim to be 64
    heads=16,                     # attention heads
    dim_head=64,                  # attention head dimension
    attn_dropout=0.1,             # attention dropout
    ff_dropout=0.1                # feedforward dropout
).cuda()

optimizerDALLE = torch.optim.Adam(dalle.parameters(), lr=learning_rate)
DALLEloss = []

for epoch in range(EPOCHS):
    for i in range(DATASET_SIZE):
        # print(i, ":", tokenDset.getRand(i), img.size())
        optimizerDALLE.zero_grad()
        img, strs = cap[i]
        # print(img.size())
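        # Hedged sketch of the rest of this training step (the snippet cuts off
        # here). It assumes tokenDset.getRand(i) returns padded caption token
        # ids for sample i, matching how it is printed in the comment above.
        text = torch.tensor(tokenDset.getRand(i)).unsqueeze(0).cuda()
        images = img.unsqueeze(0).cuda()
        mask = text != 0
        loss = dalle(text, images, mask=mask, return_loss=True)
        loss.backward()
        optimizerDALLE.step()
        DALLEloss.append(loss.cpu().detach().numpy())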
    data_sampler = None

if ENABLE_WEBDATASET:
    # WebLoader for WebDataset and DeepSpeed compatibility
    dl = wds.WebLoader(ds, batch_size=None, shuffle=False)  # optionally add num_workers=2 (n) argument
    number_of_batches = DATASET_SIZE // (BATCH_SIZE * distr_backend.get_world_size())
    dl = dl.repeat(2).slice(number_of_batches)
    dl.length = number_of_batches
else:
    # Regular DataLoader for image-text-folder datasets
    dl = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=is_shuffle, drop_last=True, sampler=data_sampler)

# initialize DALL-E
dalle = DALLE(vae=vae, **dalle_params)
if not using_deepspeed:
    if args.fp16:
        dalle = dalle.half()
    dalle = dalle.cuda()

if RESUME and not using_deepspeed:
    dalle.load_state_dict(weights)

# optimizer
opt = Adam(get_trainable_params(dalle), lr=LEARNING_RATE)

if RESUME and opt_state:
    opt.load_state_dict(opt_state)

if LR_DECAY:
def main(args):
    # constants
    VAE_PATH = args.vae_path
    DALLE_PATH = args.dalle_path
    RESUME = exists(DALLE_PATH)
    EPOCHS = args.epochs
    BATCH_SIZE = args.batch_size
    LEARNING_RATE = args.learning_rate
    GRAD_CLIP_NORM = args.clip_grad_norm
    LR_DECAY = args.lr_decay
    MODEL_DIM = args.dim
    TEXT_SEQ_LEN = args.text_seq_len
    DEPTH = args.depth
    HEADS = args.heads
    DIM_HEAD = args.dim_head
    REVERSIBLE = args.reversible
    LOSS_IMG_WEIGHT = args.loss_img_weight
    ATTN_TYPES = tuple(args.attn_types.split(','))
    DEEPSPEED_CP_AUX_FILENAME = 'auxiliary.pt'

    # initialize distributed backend
    if args.sagemakermp:
        args.deepspeed = False
        using_deepspeed = False
    else:
        args.deepspeed = True

    distr_backend = distributed_utils.set_backend_from_args(args)
    distr_backend.initialize(args)

    if args.sagemakermp:
        args = smp_init(args)
        distributed_utils.using_backend(distributed_utils.SageMakerMPBackend)
    else:
        using_deepspeed = \
            distributed_utils.using_backend(distributed_utils.DeepSpeedBackend)

    args.rank = int(os.environ.get('RANK'))
    args.world_size = int(os.environ.get('WORLD_SIZE'))
    args.local_rank = int(os.environ.get('LOCAL_RANK'))
    args.global_rank = args.rank

    logger.debug(f"using_deepspeed : {using_deepspeed}")
    logger.debug(
        f"args.local_rank : {args.local_rank}, args.rank : {args.rank}")

    # tokenizer
    logger.debug(f"exists(args.bpe_path) : {exists(args.bpe_path)}, args.chinese : {args.chinese}")
    if exists(args.bpe_path):
        klass = HugTokenizer if args.hug else YttmTokenizer
        tokenizer = klass(args.bpe_path)
    elif args.chinese:
        tokenizer = ChineseTokenizer()
    else:
        tokenizer = SimpleTokenizer()

    # reconstitute vae
    if RESUME:
        dalle_path = Path(DALLE_PATH)
        if using_deepspeed:
            cp_dir = cp_path_to_dir(dalle_path, 'ds')
            assert cp_dir.is_dir(), \
                f'DeepSpeed checkpoint directory {cp_dir} not found'
            dalle_path = cp_dir / DEEPSPEED_CP_AUX_FILENAME
        else:
            assert dalle_path.exists(), 'DALL-E model file does not exist'
        loaded_obj = torch.load(str(dalle_path), map_location='cpu')

        dalle_params, vae_params, weights = loaded_obj['hparams'], loaded_obj[
            'vae_params'], loaded_obj['weights']

        if vae_params is not None:
            vae = DiscreteVAE(**vae_params)
        else:
            vae_klass = OpenAIDiscreteVAE if not args.taming else VQGanVAE1024
            vae = vae_klass(args)

        dalle_params = dict(**dalle_params)
        IMAGE_SIZE = vae.image_size
    else:
        if exists(VAE_PATH):
            vae_path = Path(VAE_PATH)
            assert vae_path.exists(), 'VAE model file does not exist'
            assert not vae_path.is_dir(), \
                ('Cannot load VAE model from directory; please use a '
                 'standard *.pt checkpoint. '
                 'Currently, merging a DeepSpeed-partitioned VAE into a DALLE '
                 'model is not supported.')

            loaded_obj = torch.load(str(vae_path))
            vae_params, weights = loaded_obj['hparams'], loaded_obj['weights']
            vae = DiscreteVAE(**vae_params)
            vae.load_state_dict(weights)
        else:
            if args.rank == 0:  # if distr_backend.is_root_worker():
                print('using pretrained VAE for encoding images to tokens')
            vae_params = None
            logger.debug(f"************* args.taming : {args.taming}")
            vae_klass = OpenAIDiscreteVAE if not args.taming else VQGanVAE1024
            vae = vae_klass(args)

        IMAGE_SIZE = vae.image_size

        dalle_params = dict(
            num_text_tokens=tokenizer.vocab_size,
            text_seq_len=TEXT_SEQ_LEN,
            dim=MODEL_DIM,
            depth=DEPTH,
            heads=HEADS,
            dim_head=DIM_HEAD,
            reversible=REVERSIBLE,
            loss_img_weight=LOSS_IMG_WEIGHT,
            attn_types=ATTN_TYPES,
        )

    # configure OpenAI VAE for float16s
    if isinstance(vae, OpenAIDiscreteVAE) and args.fp16:
        vae.enc.blocks.output.conv.use_float16 = True

    # create dataset and dataloader
    is_shuffle = not distributed_utils.using_backend(
        distributed_utils.HorovodBackend)

    ds = TextImageDataset(
        args.image_text_folder,
        text_len=TEXT_SEQ_LEN,
        image_size=IMAGE_SIZE,
        resize_ratio=args.resize_ratio,
        truncate_captions=args.truncate_captions,
        tokenizer=tokenizer,
        shuffle=is_shuffle,
    )
    assert len(ds) > 0, 'dataset is empty'

    # if distr_backend.is_root_worker():
    if args.rank == 0:
        print(f'{len(ds)} image-text pairs found for training')

    if not is_shuffle:
        data_sampler = torch.utils.data.distributed.DistributedSampler(
            ds, num_replicas=args.world_size, rank=args.rank)
    elif args.sagemakermp:
        args.ds = ds
        ds = split_dataset(args)
        data_sampler = None
    else:
        data_sampler = None
    print(f"data_sampler : {data_sampler}")

    # uncorrectable NVLink error was detected during the execution --> remove
    kwargs = {'num_workers': args.num_worker, 'pin_memory': True}
    dl = DataLoader(ds,
                    batch_size=BATCH_SIZE,
                    shuffle=is_shuffle,
                    drop_last=True,
                    sampler=data_sampler,
                    **kwargs)

    logger.info("Processes {}/{} ({:.0f}%) of train data".format(
        len(dl.sampler), len(dl.dataset), 100.
        * len(dl.sampler) / len(dl.dataset)))

    # initialize DALL-E
    dalle = DALLE(vae=vae, **dalle_params)
    if not using_deepspeed:
        if args.fp16:
            dalle = dalle.half()
        dalle = dalle.cuda()

    if RESUME and not using_deepspeed:
        dalle.load_state_dict(weights)

    # optimizer
    opt = Adam(get_trainable_params(dalle), lr=LEARNING_RATE)

    if LR_DECAY:
        scheduler = ReduceLROnPlateau(
            opt,
            mode="min",
            factor=0.5,
            patience=10,
            cooldown=10,
            min_lr=1e-6,
            verbose=True,
        )

    # if distr_backend.is_root_worker():
    if args.global_rank == 0:
        # experiment tracker
        model_config = dict(depth=DEPTH, heads=HEADS, dim_head=DIM_HEAD)
        logger.debug(f"args.wandb_name : {args.wandb_name}, RESUME : {RESUME}")
        run = wandb.init(
            project=args.wandb_name,  # 'dalle_train_transformer' by default
            resume=RESUME,
            config=model_config,
        )

    # distribute
    distr_backend.check_batch_size(BATCH_SIZE)
    deepspeed_config = {
        'train_batch_size': BATCH_SIZE,
        'gradient_clipping': GRAD_CLIP_NORM,
        'fp16': {
            'enabled': args.fp16,
        },
    }

    (distr_dalle, distr_opt, distr_dl, distr_scheduler) = distr_backend.distribute(
        args=args,
        model=dalle,
        optimizer=opt,
        model_parameters=get_trainable_params(dalle),
        training_data=ds if using_deepspeed else dl,
        lr_scheduler=scheduler if LR_DECAY else None,
        config_params=deepspeed_config,
    )
    avoid_model_calls = using_deepspeed and args.fp16

    if args.sagemakermp:
        args.distr_dalle = smp.DistributedModel(distr_dalle)
        args.scaler = smp.amp.GradScaler()
        args.distr_opt = smp.DistributedOptimizer(distr_opt)

    if RESUME and using_deepspeed:
        distr_dalle.load_checkpoint(str(cp_dir))

    # training
    for epoch in range(EPOCHS):
        logger.debug(f"********* epoch : {epoch} **********")
        if data_sampler:
            data_sampler.set_epoch(epoch)
        for i, (text, images) in enumerate(distr_dl):
            if args.fp16:
                images = images.half()
            text, images = map(lambda t: t.cuda(), (text, images))

            if args.sagemakermp:
                args.distr_opt.zero_grad()
                loss = train_step(args, text, images, return_loss=True)
                loss = loss.reduce_mean()
            else:
                loss = distr_dalle(text, images, return_loss=True, args=args)

            if using_deepspeed:
                distr_dalle.backward(loss)
                distr_dalle.step()
                # Gradients are automatically zeroed after the step
            elif args.sagemakermp:
                if args.amp:
                    # use the GradScaler created above (args.scaler)
                    args.scaler.step(args.distr_opt)
                    args.scaler.update()
                else:
                    # some optimizers like adadelta from PT 1.8 dont like it when optimizer.step is called with no param
                    if len(list(args.distr_dalle.local_parameters())) > 0:
                        args.distr_opt.step()
            else:
                loss.backward()
                clip_grad_norm_(distr_dalle.parameters(), GRAD_CLIP_NORM)
                distr_opt.step()
                distr_opt.zero_grad()

            # Collective loss, averaged
            avg_loss = distr_backend.average_all(loss)

            log = {}

            # if i % 10 == 0 and distr_backend.is_root_worker():
            if i % 10 == 0 and args.rank == 0:
                print(epoch, i, f'loss - {avg_loss.item()}')
                log = {
                    **log, 'epoch': epoch,
                    'iter': i,
                    'loss': avg_loss.item()
                }

            if i % 100 == 0:
                # if distr_backend.is_root_worker():
                if args.rank == 0:
                    sample_text = text[:1]
                    token_list = sample_text.masked_select(
                        sample_text != 0).tolist()
                    decoded_text = tokenizer.decode(token_list)

                    logger.debug(f"******* avoid_model_calls : {avoid_model_calls}")
                    if not avoid_model_calls:
                        # CUDA index errors when we don't guard this
                        image = dalle.generate_images(
                            text[:1], filter_thres=0.9)  # topk sampling at 0.9

                    wandb.save('./dalle.pt')
                    log = {
                        **log,
                    }
                    if not avoid_model_calls:
                        log['image'] = wandb.Image(image, caption=decoded_text)

                args.distr_dalle = distr_dalle
                args.dalle_params = dalle_params
                args.vae_params = vae_params
                args.using_deepspeed = using_deepspeed
                args.DEEPSPEED_CP_AUX_FILENAME = DEEPSPEED_CP_AUX_FILENAME
                save_model(args,
                           f'{args.model_dir}/dalle.pt')

            # if distr_backend.is_root_worker():
            if args.rank == 0:
                wandb.log(log)
            # text, images = prefetcher.next()

        if LR_DECAY and not using_deepspeed:
            # Scheduler is automatically progressed after the step when
            # using DeepSpeed.
            distr_scheduler.step(loss)

        # if distr_backend.is_root_worker():
        if args.global_rank == 0:
            # save trained model to wandb as an artifact every epoch's end
            model_artifact = wandb.Artifact('trained-dalle',
                                            type='model',
                                            metadata=dict(model_config))
            # model_artifact.add_file('dalle.pt')
            run.log_artifact(model_artifact)

    args.distr_dalle = distr_dalle
    args.dalle_params = dalle_params
    args.vae_params = vae_params
    args.using_deepspeed = using_deepspeed
    args.DEEPSPEED_CP_AUX_FILENAME = DEEPSPEED_CP_AUX_FILENAME
    save_model(args, f'{args.model_dir}/dalle-final.pt')

    # if distr_backend.is_root_worker():
    if args.global_rank == 0:
        wandb.save('./dalle-final.pt')
        model_artifact = wandb.Artifact('trained-dalle',
                                        type='model',
                                        metadata=dict(model_config))
        # model_artifact.add_file('dalle-final.pt')
        run.log_artifact(model_artifact)
        wandb.finish()
    codebook_dim=256,
    hidden_dim=128,
    temperature=0.9)

# load pretrained vae
vae_dict = torch.load("./models/" + vaename + "-" + str(load_epoch) + ".pth")
vae.load_state_dict(vae_dict)
vae.to(device)

dalle = DALLE(
    dim=256,                 # 512,
    vae=vae,                 # automatically infer (1) image sequence length and (2) number of image tokens
    num_text_tokens=10000,   # vocab size for text
    text_seq_len=256,        # text sequence length
    depth=6,                 # should be 64
    heads=8,                 # attention heads
    dim_head=64,             # attention head dimension
    attn_dropout=0.1,        # attention dropout
    ff_dropout=0.1           # feedforward dropout
)

# load pretrained dalle if continuing training
dalle_dict = torch.load(loadfn)
dalle.load_state_dict(dalle_dict)
dalle.to(device)

# get image and text data
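# Hedged sketch (not part of this repo's loader): with text_seq_len=256 above,
# caption token ids must be truncated or padded to that length before they are
# fed to DALL-E; 0 is assumed to be the pad id here.
def pad_caption(token_ids, seq_len=256, pad_id=0):
    token_ids = list(token_ids)[:seq_len]
    return token_ids + [pad_id] * (seq_len - len(token_ids))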
def main(args):
    # constants
    print(f"torch.cuda.nccl.version() : {torch.cuda.nccl.version()}")
    VAE_PATH = args.vae_path
    DALLE_PATH = args.dalle_path
    RESUME = exists(DALLE_PATH)
    EPOCHS = args.epochs
    BATCH_SIZE = args.batch_size
    LEARNING_RATE = args.learning_rate
    GRAD_CLIP_NORM = args.clip_grad_norm
    LR_DECAY = args.lr_decay
    SAVE_EVERY_N_STEPS = args.save_every_n_steps
    MODEL_DIM = args.dim
    TEXT_SEQ_LEN = args.text_seq_len
    DEPTH = args.depth
    HEADS = args.heads
    DIM_HEAD = args.dim_head
    REVERSIBLE = args.reversible
    LOSS_IMG_WEIGHT = args.loss_img_weight
    ATTN_TYPES = tuple(args.attn_types.split(','))
    DEEPSPEED_CP_AUX_FILENAME = 'auxiliary.pt'
    DALLE_OUTPUT_FILE_NAME = args.dalle_output_file_name

    # initialize distributed backend
    args.deepspeed = True
    distr_backend = distributed_utils.set_backend_from_args(args)
    distr_backend.initialize(args)
    using_deepspeed = \
        distributed_utils.using_backend(distributed_utils.DeepSpeedBackend)

    args.rank = int(os.environ.get('RANK'))
    args.world_size = int(os.environ.get('WORLD_SIZE'))
    args.local_rank = int(os.environ.get('LOCAL_RANK'))
    args.global_rank = args.rank

    logger.debug(f"using_deepspeed : {using_deepspeed}")
    logger.debug(
        f"args.local_rank : {args.local_rank}, args.rank : {args.rank}")
    print(
        f"********* torch.distributed.get_rank() : {torch.distributed.get_rank()}"
    )

    # tokenizer
    logger.debug(
        f"exists(args.bpe_path) : {exists(args.bpe_path)}, args.chinese : {args.chinese}"
    )

    # if args.local_rank == 0 or args.local_rank == 1:
    #     # print(f"args.job_name : {args.job_name}")
    #     gpu_mon_thread = aws_util.GPUMon(device_index=args.local_rank, job_name=args.job_name)
    #     gpu_mon_thread.start()

    if exists(args.bpe_path):
        klass = HugTokenizer if args.hug else YttmTokenizer
        tokenizer = klass(args.bpe_path)
    elif args.chinese:
        tokenizer = ChineseTokenizer()
    else:
        tokenizer = SimpleTokenizer()

    # reconstitute vae
    if RESUME:
        dalle_path = Path(DALLE_PATH)
        if using_deepspeed:
            cp_dir = cp_path_to_dir(dalle_path, 'ds')
            assert cp_dir.is_dir(), \
                f'DeepSpeed checkpoint directory {cp_dir} not found'
            dalle_path = cp_dir / DEEPSPEED_CP_AUX_FILENAME
        else:
            assert dalle_path.exists(), 'DALL-E model file does not exist'
        loaded_obj = torch.load(str(dalle_path), map_location='cpu')

        dalle_params, vae_params, weights = loaded_obj['hparams'], loaded_obj[
            'vae_params'], loaded_obj['weights']

        if vae_params is not None:
            vae = DiscreteVAE(**vae_params)
        else:
            vae_klass = OpenAIDiscreteVAE if not args.taming else VQGanVAE1024
            vae = vae_klass(args)

        dalle_params = dict(**dalle_params)
        IMAGE_SIZE = vae.image_size
    else:
        if exists(VAE_PATH):
            vae_path = Path(VAE_PATH)
            assert vae_path.exists(), 'VAE model file does not exist'
            assert not vae_path.is_dir(), \
                ('Cannot load VAE model from directory; please use a '
                 'standard *.pt checkpoint. '
                 'Currently, merging a DeepSpeed-partitioned VAE into a DALLE '
                 'model is not supported.')

            loaded_obj = torch.load(str(vae_path))
            vae_params, weights = loaded_obj['hparams'], loaded_obj['weights']
            vae = DiscreteVAE(**vae_params)
            vae.load_state_dict(weights)
        else:
            if args.rank == 0:  # if distr_backend.is_root_worker():
                print('using pretrained VAE for encoding images to tokens')
            vae_params = None
            logger.debug(f"************* args.taming : {args.taming}")
            vae_klass = OpenAIDiscreteVAE if not args.taming else VQGanVAE1024
            vae = vae_klass(args)

        IMAGE_SIZE = vae.image_size

        dalle_params = dict(
            num_text_tokens=tokenizer.vocab_size,
            text_seq_len=TEXT_SEQ_LEN,
            dim=MODEL_DIM,
            depth=DEPTH,
            heads=HEADS,
            dim_head=DIM_HEAD,
            reversible=REVERSIBLE,
            loss_img_weight=LOSS_IMG_WEIGHT,
            attn_types=ATTN_TYPES,
        )

    # configure OpenAI VAE for float16s
    if isinstance(vae, OpenAIDiscreteVAE) and args.fp16:
        vae.enc.blocks.output.conv.use_float16 = True

    # create dataset and dataloader
    is_shuffle = not distributed_utils.using_backend(
        distributed_utils.HorovodBackend)

    ds = TextImageDataset(
        args.image_text_folder,
        text_len=TEXT_SEQ_LEN,
        image_size=IMAGE_SIZE,
        resize_ratio=args.resize_ratio,
        truncate_captions=args.truncate_captions,
        tokenizer=tokenizer,
        shuffle=is_shuffle,
    )
    assert len(ds) > 0, 'dataset is empty'

    # if distr_backend.is_root_worker():
    if args.rank == 0:
        print(f'{len(ds)} image-text pairs found for training')

    if not is_shuffle:
        data_sampler = torch.utils.data.distributed.DistributedSampler(
            ds, num_replicas=args.world_size, rank=args.rank)
    else:
        data_sampler = None

    kwargs = {'num_workers': args.num_worker, 'pin_memory': True}
    dl = DataLoader(ds,
                    batch_size=BATCH_SIZE,
                    shuffle=is_shuffle,
                    drop_last=True,
                    sampler=data_sampler,
                    **kwargs)

    logger.info("Processes {}/{} ({:.0f}%) of train data".format(
        len(dl.sampler), len(dl.dataset), 100.
        * len(dl.sampler) / len(dl.dataset)))

    # initialize DALL-E
    dalle = DALLE(vae=vae, **dalle_params)
    if not using_deepspeed:
        if args.fp16:
            dalle = dalle.half()
        dalle = dalle.cuda()

    if RESUME and not using_deepspeed:
        dalle.load_state_dict(weights)

    # optimizer
    opt = Adam(get_trainable_params(dalle), lr=LEARNING_RATE)

    if LR_DECAY:
        scheduler = ReduceLROnPlateau(
            opt,
            mode="min",
            factor=0.5,
            patience=10,
            cooldown=10,
            min_lr=1e-6,
            verbose=True,
        )

    # if distr_backend.is_root_worker():
    if args.rank == 0:
        # experiment tracker
        model_config = dict(depth=DEPTH, heads=HEADS, dim_head=DIM_HEAD)
        # wandb_dir = '/tmp/wandb'
        # if not os.path.exists(wandb_dir):
        #     os.makedirs(wandb_dir)
        run = wandb.init(
            project=args.wandb_name,  # 'dalle_train_transformer' by default
            resume=RESUME,
            config=model_config,
            # dir=wandb_dir
        )

    # distribute
    distr_backend.check_batch_size(BATCH_SIZE)
    deepspeed_config = {
        'train_batch_size': BATCH_SIZE,
        'gradient_clipping': GRAD_CLIP_NORM,
        'fp16': {
            'enabled': args.fp16,
        },
    }

    (distr_dalle, distr_opt, distr_dl, distr_scheduler) = distr_backend.distribute(
        args=args,
        model=dalle,
        optimizer=opt,
        model_parameters=get_trainable_params(dalle),
        training_data=ds if using_deepspeed else dl,
        lr_scheduler=scheduler if LR_DECAY else None,
        config_params=deepspeed_config,
    )
    avoid_model_calls = using_deepspeed and args.fp16

    if RESUME and using_deepspeed:
        distr_dalle.load_checkpoint(str(cp_dir))

    # training
    for epoch in range(EPOCHS):
        logger.debug(f"********* epoch : {epoch} **********")
        if data_sampler:
            data_sampler.set_epoch(epoch)
        for i, (text, images) in enumerate(distr_dl):
            if i % 10 == 0 and args.rank == 0:
                t = time.time()
            if args.fp16:
                images = images.half()
            text, images = map(lambda t: t.cuda(), (text, images))

            loss = distr_dalle(text, images, return_loss=True)

            if using_deepspeed:
                distr_dalle.backward(loss)
                distr_dalle.step()
                # Gradients are automatically zeroed after the step
            else:
                loss.backward()
                clip_grad_norm_(distr_dalle.parameters(), GRAD_CLIP_NORM)
                distr_opt.step()
                distr_opt.zero_grad()

            # Collective loss, averaged
            avg_loss = distr_backend.average_all(loss)

            log = {}

            # if i % 10 == 0 and distr_backend.is_root_worker():
            if i % 10 == 0 and args.rank == 0:
                print(epoch, i, f'loss - {avg_loss.item()}')
                log = {
                    **log, 'epoch': epoch,
                    'iter': i,
                    'loss': avg_loss.item()
                }

            if i % SAVE_EVERY_N_STEPS == 0:
                args.distr_dalle = distr_dalle
                args.dalle_params = dalle_params
                args.vae_params = vae_params
                args.using_deepspeed = using_deepspeed
                args.DEEPSPEED_CP_AUX_FILENAME = DEEPSPEED_CP_AUX_FILENAME
                save_model(args, f"{args.model_dir+'/'+DALLE_OUTPUT_FILE_NAME}")

            if i % 100 == 0:
                # if distr_backend.is_root_worker():
                if args.rank == 0:
                    sample_text = text[:1]
                    token_list = sample_text.masked_select(
                        sample_text != 0).tolist()
                    decoded_text = tokenizer.decode(token_list)

                    if not avoid_model_calls:
                        # CUDA index errors when we don't guard this
                        image = dalle.generate_images(
                            text[:1], filter_thres=0.9)  # topk sampling at 0.9

                    log = {
                        **log,
                    }
                    if not avoid_model_calls:
                        log['image'] = wandb.Image(image, caption=decoded_text)

            if i % 10 == 9 and args.rank == 0:
                sample_per_sec = BATCH_SIZE * 10 / (time.time() - t)
                log["sample_per_sec"] = sample_per_sec
                print(epoch, i, f'sample_per_sec - {sample_per_sec}')

            # if distr_backend.is_root_worker():
            if args.rank == 0:
                wandb.log(log)

        if LR_DECAY and not using_deepspeed:
            # Scheduler is automatically progressed after the step when
            # using DeepSpeed.
            distr_scheduler.step(loss)

        args.distr_dalle = distr_dalle
        args.dalle_params = dalle_params
        args.vae_params = vae_params
        args.using_deepspeed = using_deepspeed
        args.DEEPSPEED_CP_AUX_FILENAME = DEEPSPEED_CP_AUX_FILENAME
        save_model(args, f"{args.model_dir+'/'+DALLE_OUTPUT_FILE_NAME}")

        # sync_local_checkpoints_to_s3(local_path=f'{args.model_dir}', s3_path='s3://lgaivision-coco-usva/Dalle_Model/tmd/')
        sync_local_checkpoints_to_s3(
            args.model_dir, os.path.join(args.output_s3, args.job_name + "/temp"))

        # if distr_backend.is_root_worker():
        if args.rank == 0:
            # save trained model to wandb as an artifact every epoch's end
            model_artifact = wandb.Artifact('trained-dalle',
                                            type='model',
                                            metadata=dict(model_config))

            import glob
            print(f"************** file : {glob.glob(args.model_dir+'/*')}")
            try:
                print(f"wandb.run.dir : {wandb.run.dir}")
                print(
                    f"************** file wandb: {glob.glob(wandb.run.dir+'/*')}"
                )
            except:
                pass

            model_artifact.add_file(
                f"{args.model_dir+'/'+DALLE_OUTPUT_FILE_NAME}")
            run.log_artifact(model_artifact)

    args.distr_dalle = distr_dalle
    args.dalle_params = dalle_params
    args.vae_params = vae_params
    args.using_deepspeed = using_deepspeed
    args.DEEPSPEED_CP_AUX_FILENAME = DEEPSPEED_CP_AUX_FILENAME

    resource_check(args)
    save_model(args, f"{args.model_dir +'/'+DALLE_OUTPUT_FILE_NAME}")

    if args.rank == 0:
        # from distutils.dir_util import copy_tree
        # copy_tree(f'{args.model_dir}', f'{args.model_dir_last}')
        sync_local_checkpoints_to_s3(
            args.model_dir, os.path.join(args.output_s3, args.job_name + "/temp"))
    resource_check(args)

    # if distr_backend.is_root_worker():
    if args.rank == 0:
        wandb.save(f"{args.model_dir+'/'+DALLE_OUTPUT_FILE_NAME}",
                   base_path=args.model_dir)
        model_artifact = wandb.Artifact('trained-dalle',
                                        type='model',
                                        metadata=dict(model_config))
        model_artifact.add_file(f"{args.model_dir+'/'+DALLE_OUTPUT_FILE_NAME}")
        run.log_artifact(model_artifact)
        wandb.finish()

    # if args.local_rank == 0 or args.local_rank == 1:
    #     gpu_mon_thread.kill()

    distributed_utils.backend.local_barrier()
    print("************************ Finished *************************")
compose = T.Compose([
    T.Resize(IMAGE_SIZE),
    T.CenterCrop(IMAGE_SIZE),
    T.ToTensor(),
])


def collate_fn(batch):
    return tuple(zip(*batch))


ds = CocoCaptions(root=IMAGE_PATH, annFile=ANNO_PATH, transform=compose)
dl = DataLoader(ds, BATCH_SIZE, shuffle=True, num_workers=0, collate_fn=collate_fn)

assert len(ds) > 0, 'dataset is empty'
print(f'{len(ds)} image-text pairs found for training')

# initialize DALL-E
dalle = DALLE(**dalle_params)
if RESUME:
    dalle.load_state_dict(weights)
dalle = torch.nn.DataParallel(dalle).cuda()

# optimizer
opt = Adam(dalle.parameters(), lr=LEARNING_RATE)

# experiment tracker
import wandb

wandb.config.depth = DEPTH
wandb.config.heads = HEADS
wandb.config.dim_head = DIM_HEAD
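# Hedged sketch: collate_fn() above just zips the batch, so each iteration of
# `dl` yields a tuple of image tensors and a tuple of caption lists (CocoCaptions
# returns several captions per image). The `tokenize` helper and TEXT_SEQ_LEN
# padding below are assumptions about the rest of this script.
for images, captions in dl:
    images = torch.stack(images).cuda()
    text = torch.stack([tokenize(c[0], TEXT_SEQ_LEN) for c in captions]).cuda()
    loss = dalle(text, images, return_loss=True)
    break  # smoke test of a single batch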
assert dalle_path.exists(), 'trained DALL-E must exist'

load_obj = torch.load(str(dalle_path))
dalle_params, vae_params, weights = load_obj.pop('hparams'), load_obj.pop(
    'vae_params'), load_obj.pop('weights')
dalle_params.pop('vae', None)  # cleanup later

if args.taming:
    vae = VQGanVAE(args.vqgan_model_path, args.vqgan_config_path)
elif vae_params is not None:
    vae = DiscreteVAE(**vae_params)
else:
    vae = OpenAIDiscreteVAE()

dalle = DALLE(vae=vae, **dalle_params).cuda()
dalle.load_state_dict(weights)

# generate images
image_size = vae.image_size
texts = args.text.split('|')

for j, text in tqdm(enumerate(texts)):
    if args.gentxt:
        text_tokens, gen_texts = dalle.generate_texts(tokenizer, text=text, filter_thres=args.top_k)
        text = gen_texts[0]