Example #1
def generate_images(captions):
    print(sum(captions, start=[]))
    all_words = list(sorted(frozenset(sum(captions, start=[]))))
    word_tokens = dict(zip(all_words, range(1, len(all_words) + 1)))
    caption_tokens = [[word_tokens[w] for w in c] for c in captions]

    logging.info(f"{all_words =}")
    logging.info(f"{word_tokens =}")
    logging.info(f"{caption_tokens =}")

    longest_caption = max(len(c) for c in captions)
    captions_array = np.zeros((len(caption_tokens), longest_caption),
                              dtype=np.int64)
    for i in range(len(caption_tokens)):
        captions_array[i, :len(caption_tokens[i])] = caption_tokens[i]

    # captions_array = torch.from_numpy(captions_array).cuda()
    captions_array = torch.from_numpy(captions_array)
    captions_mask = captions_array != 0

    logging.info(f"{captions_array = }")

    dalle = DALLE(
        dim=1024,
        vae=vae,  # automatically infer (1) image sequence length and (2) number of image tokens
        num_text_tokens=len(word_tokens) + 1,  # vocab size for text
        text_seq_len=longest_caption,  # text sequence length
        depth=12,  # should aim to be 64
        heads=16,  # attention heads
        dim_head=64,  # attention head dimension
        attn_dropout=0.1,  # attention dropout
        ff_dropout=0.1,  # feedforward dropout
    )
    # .cuda

    generated_image_codes = []
    with torch.no_grad():
        for i in range(0, len(captions), 128):
            generated = generate_image_code(
                dalle,
                captions_array[i:i + 128, ...],
                mask=captions_mask[i:i + 128, ...],
            )
            generated_image_codes.append(generated)

    generated_image_codes = torch.cat(generated_image_codes, dim=0)

    with torch.no_grad():
        generated_images = vae.decode(generated_image_codes)
        logging.info(f"{generated_images = }")
Example #2
def get_dalle(vae, vocab, args):
    dalle = DALLE(dim=args.codebook_dims,
                  vae=vae,
                  num_text_tokens=len(vocab) + 1,
                  text_seq_len=len(vocab),
                  depth=16,
                  heads=8,
                  dim_head=64,
                  attn_dropout=0.1,
                  ff_dropout=0.1,
                  reversible=True)

    if args.dalle is not None and os.path.isfile(args.dalle):
        print(f"loading state dict from {args.dalle}")
        dalle.load_state_dict(torch.load(args.dalle))

    dalle.to(args.device)
    vae.to(args.device)

    return dalle
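A minimal usage sketch for get_dalle, assuming the original script's imports (torch, os, DALLE) are in scope; the DiscreteVAE settings, the toy vocabulary, and the argparse Namespace fields below are illustrative assumptions that only need to provide what get_dalle reads (codebook_dims, dalle, device):

import argparse
from dalle_pytorch import DiscreteVAE

args = argparse.Namespace(codebook_dims=512, dalle=None, device='cpu')  # hypothetical values

vae = DiscreteVAE(image_size=128,
                  num_layers=2,
                  num_tokens=1024,
                  codebook_dim=args.codebook_dims,
                  hidden_dim=64)
vocab = {'<pad>': 0, 'a': 1, 'cat': 2, 'on': 3, 'the': 4, 'mat': 5}  # toy word -> id map

dalle = get_dalle(vae, vocab, args)  # text_seq_len will be len(vocab), per the function above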
Example #3
    data_sampler = torch.utils.data.distributed.DistributedSampler(
        ds,
        num_replicas=distr_backend.get_world_size(),
        rank=distr_backend.get_rank())
else:
    data_sampler = None

dl = DataLoader(ds,
                batch_size=BATCH_SIZE,
                shuffle=not data_sampler,
                drop_last=True,
                sampler=data_sampler)

# initialize DALL-E

dalle = DALLE(vae=vae, **dalle_params)
if args.fp16:
    dalle = dalle.half()
dalle = dalle.cuda()

if RESUME:
    dalle.load_state_dict(weights)

# optimizer

opt = Adam(dalle.parameters(), lr=LEARNING_RATE)

if LR_DECAY:
    scheduler = ReduceLROnPlateau(
        opt,
        mode="min",
Example #4
assert dalle_path.exists(), 'trained DALL-E must exist'

load_obj = torch.load(str(dalle_path))
dalle_params, vae_params, weights = load_obj.pop('hparams'), load_obj.pop('vae_params'), load_obj.pop('weights')

dalle_params.pop('vae', None) # cleanup later

if vae_params is not None:
    vae = DiscreteVAE(**vae_params)
elif not args.taming:
    vae = OpenAIDiscreteVAE()
else:
    vae = VQGanVAE1024()


dalle = DALLE(vae = vae, **dalle_params).cuda()

dalle.load_state_dict(weights)

# generate images

image_size = vae.image_size

texts = args.text.split('|')

for text in tqdm(texts):
    text = tokenizer.tokenize([text], dalle.text_seq_len).cuda()

    text = repeat(text, '() n -> b n', b = args.num_images)

    outputs = []
Example #5
    hidden_dim=64,  # hidden dimension
    num_resnet_blocks=1,  # number of resnet blocks
    temperature=0.9,  # gumbel softmax temperature; the lower it is, the harder the discretization
    straight_through=False,  # straight-through for gumbel softmax; unclear if one way is better than the other
).cuda()

vae.load_state_dict(torch.load("Vae-small.pth"))

dalle = DALLE(
    dim=1024,
    vae=vae,  # automatically infer (1) image sequence length and (2) number of image tokens
    num_text_tokens=NUM_TOKENS,  # vocab size for text
    text_seq_len=TEXTSEQLEN,  # text sequence length
    depth=12,  # should aim to be 64
    heads=16,  # attention heads
    dim_head=64,  # attention head dimension
    attn_dropout=0.1,  # attention dropout
    ff_dropout=0.1  # feedforward dropout
).cuda()

dalle.load_state_dict(torch.load("dalle-small.pth"))
"""
text = torch.randint(0, NUM_TOKENS, (BATCH_SIZE, TEXTSEQLEN))
images = torch.randn(BATCH_SIZE, 3, IMAGE_SIZE, IMAGE_SIZE)
mask = torch.ones_like(text).bool()
"""

tokenDset = token_dataset('./coco/merged-smallsample.txt')
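The commented-out block above shows the tensor shapes DALLE expects during training; here is a minimal forward/backward sketch built from such random tensors, assuming NUM_TOKENS, TEXTSEQLEN and IMAGE_SIZE are the constants used to build the VAE/DALLE above:

import torch

BATCH_SIZE = 4  # illustrative

text = torch.randint(0, NUM_TOKENS, (BATCH_SIZE, TEXTSEQLEN)).cuda()
images = torch.randn(BATCH_SIZE, 3, IMAGE_SIZE, IMAGE_SIZE).cuda()
mask = torch.ones_like(text).bool()

loss = dalle(text, images, mask=mask, return_loss=True)  # joint text/image token cross-entropy
loss.backward()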
Example #6
# create dataset and dataloader

ds = TextImageDataset(
    args.image_text_folder,
    text_len = TEXT_SEQ_LEN,
    image_size = IMAGE_SIZE
)

assert len(ds) > 0, 'dataset is empty'
print(f'{len(ds)} image-text pairs found for training')

dl = DataLoader(ds, batch_size = BATCH_SIZE, shuffle = True, drop_last = True)

# initialize DALL-E

dalle = DALLE(**dalle_params).cuda()

if exists(args.dalle_path):
    dalle.load_state_dict(weights)

# optimizer

opt = Adam(dalle.parameters(), lr = LEARNING_RATE)

# experiment tracker

import wandb

wandb.config.depth = DEPTH
wandb.config.heads = HEADS
wandb.config.dim_head = DIM_HEAD
Example #7
                                             EPOCHS * DATASET_SIZE))
        loss = vae(img, return_recon_loss=True)
        VAEloss.append(loss.cpu().detach().numpy())
        loss.backward()
        optimizerVAE.step()

np.savetxt("vaeloss.csv", np.asarray(VAEloss), delimiter=",")

torch.save(vae.state_dict(), "Vae-small.pth")

dalle = DALLE(
    dim=1024,
    vae=vae,  # automatically infer (1) image sequence length and (2) number of image tokens
    num_text_tokens=NUM_TOKENS,  # vocab size for text
    text_seq_len=TEXTSEQLEN,  # text sequence length
    depth=12,  # should aim to be 64
    heads=16,  # attention heads
    dim_head=64,  # attention head dimension
    attn_dropout=0.1,  # attention dropout
    ff_dropout=0.1  # feedforward dropout
).cuda()

optimizerDALLE = torch.optim.Adam(dalle.parameters(), lr=learning_rate)
DALLEloss = []

for epoch in range(EPOCHS):
    for i in range(DATASET_SIZE):
        #print(i,":",tokenDset.getRand(i),img.size())
        optimizerDALLE.zero_grad()
        img, strs = cap[i]
        #print(img.size())
Example #8
    data_sampler = None

if ENABLE_WEBDATASET:
    # WebLoader for WebDataset and DeepSpeed compatibility
    dl = wds.WebLoader(ds, batch_size=None, shuffle=False)  # optionally pass a num_workers argument (e.g. num_workers=2)
    number_of_batches = DATASET_SIZE // (BATCH_SIZE * distr_backend.get_world_size())
    dl = dl.repeat(2).slice(number_of_batches)
    dl.length = number_of_batches
else:
    # Regular DataLoader for image-text-folder datasets
    dl = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=is_shuffle, drop_last=True, sampler=data_sampler)


# initialize DALL-E

dalle = DALLE(vae=vae, **dalle_params)
if not using_deepspeed:
    if args.fp16:
        dalle = dalle.half()
    dalle = dalle.cuda()

if RESUME and not using_deepspeed:
    dalle.load_state_dict(weights)

# optimizer

opt = Adam(get_trainable_params(dalle), lr=LEARNING_RATE)
if RESUME and opt_state:
    opt.load_state_dict(opt_state)

if LR_DECAY:
Example #9
def main(args):
    # constants

    VAE_PATH = args.vae_path
    DALLE_PATH = args.dalle_path
    RESUME = exists(DALLE_PATH)

    EPOCHS = args.epochs
    BATCH_SIZE = args.batch_size

    LEARNING_RATE = args.learning_rate
    GRAD_CLIP_NORM = args.clip_grad_norm
    LR_DECAY = args.lr_decay

    MODEL_DIM = args.dim
    TEXT_SEQ_LEN = args.text_seq_len
    DEPTH = args.depth
    HEADS = args.heads
    DIM_HEAD = args.dim_head
    REVERSIBLE = args.reversible
    LOSS_IMG_WEIGHT = args.loss_img_weight

    ATTN_TYPES = tuple(args.attn_types.split(','))

    DEEPSPEED_CP_AUX_FILENAME = 'auxiliary.pt'

    # initialize distributed backend

    if args.sagemakermp:
        args.deepspeed = False
        using_deepspeed = False
    else:
        args.deepspeed = True
    
    distr_backend = distributed_utils.set_backend_from_args(args)
    distr_backend.initialize(args)

    if args.sagemakermp:
        args = smp_init(args)
        distributed_utils.using_backend(distributed_utils.SageMakerMPBackend)
    else:
        using_deepspeed = \
            distributed_utils.using_backend(distributed_utils.DeepSpeedBackend) 
        args.rank = int(os.environ.get('RANK'))
        args.world_size = int(os.environ.get('WORLD_SIZE'))
        args.local_rank = int(os.environ.get('LOCAL_RANK'))
        args.global_rank = args.rank
    
    logger.debug(f"using_deepspeed : {using_deepspeed}")
    
    logger.debug(
        f"args.local_rank : {args.local_rank}, args.rank : {args.rank}")
    

    # tokenizer
    logger.debug(f"exists(args.bpe_path) : {exists(args.bpe_path)}, args.chinese : {args.chinese}")
    if exists(args.bpe_path):
        klass = HugTokenizer if args.hug else YttmTokenizer
        tokenizer = klass(args.bpe_path)
    elif args.chinese:
        tokenizer = ChineseTokenizer()
    else:
        tokenizer = SimpleTokenizer()

    # reconstitute vae

    if RESUME:
        dalle_path = Path(DALLE_PATH)
        if using_deepspeed:
            cp_dir = cp_path_to_dir(dalle_path, 'ds')
            assert cp_dir.is_dir(), \
                f'DeepSpeed checkpoint directory {cp_dir} not found'
            dalle_path = cp_dir / DEEPSPEED_CP_AUX_FILENAME
        else:
            assert dalle_path.exists(), 'DALL-E model file does not exist'
        loaded_obj = torch.load(str(dalle_path), map_location='cpu')

        dalle_params, vae_params, weights = loaded_obj['hparams'], loaded_obj[
            'vae_params'], loaded_obj['weights']

        if vae_params is not None:
            vae = DiscreteVAE(**vae_params)
        else:
            vae_klass = OpenAIDiscreteVAE if not args.taming else VQGanVAE1024
            vae = vae_klass(args)

        dalle_params = dict(**dalle_params)
        IMAGE_SIZE = vae.image_size
    else:
        if exists(VAE_PATH):
            vae_path = Path(VAE_PATH)
            assert vae_path.exists(), 'VAE model file does not exist'
            assert not vae_path.is_dir(), \
                ('Cannot load VAE model from directory; please use a '
                'standard *.pt checkpoint. '
                'Currently, merging a DeepSpeed-partitioned VAE into a DALLE '
                'model is not supported.')

            loaded_obj = torch.load(str(vae_path))

            vae_params, weights = loaded_obj['hparams'], loaded_obj['weights']

            vae = DiscreteVAE(**vae_params)
            vae.load_state_dict(weights)
        else:
            if args.rank == 0:
                # if distr_backend.is_root_worker():
                print('using pretrained VAE for encoding images to tokens')
            vae_params = None
            logger.debug(f"************* args.taming : {args.taming}")
            vae_klass = OpenAIDiscreteVAE if not args.taming else VQGanVAE1024
            vae = vae_klass(args)

        IMAGE_SIZE = vae.image_size

        dalle_params = dict(
            num_text_tokens=tokenizer.vocab_size,
            text_seq_len=TEXT_SEQ_LEN,
            dim=MODEL_DIM,
            depth=DEPTH,
            heads=HEADS,
            dim_head=DIM_HEAD,
            reversible=REVERSIBLE,
            loss_img_weight=LOSS_IMG_WEIGHT,
            attn_types=ATTN_TYPES,
        )

    # configure OpenAI VAE for float16s

    if isinstance(vae, OpenAIDiscreteVAE) and args.fp16:
        vae.enc.blocks.output.conv.use_float16 = True

    # create dataset and dataloader

    is_shuffle = not distributed_utils.using_backend(
        distributed_utils.HorovodBackend)

    ds = TextImageDataset(
        args.image_text_folder,
        text_len=TEXT_SEQ_LEN,
        image_size=IMAGE_SIZE,
        resize_ratio=args.resize_ratio,
        truncate_captions=args.truncate_captions,
        tokenizer=tokenizer,
        shuffle=is_shuffle,
    )

    assert len(ds) > 0, 'dataset is empty'
    # if distr_backend.is_root_worker():
    if args.rank == 0:
        print(f'{len(ds)} image-text pairs found for training')

    if not is_shuffle:
        data_sampler = torch.utils.data.distributed.DistributedSampler(
            ds, num_replicas=args.world_size, rank=args.rank)
    elif args.sagemakermp:
        args.ds = ds
        ds = split_dataset(args)
        data_sampler = None
    else:
        data_sampler = None
        
        
    print(f"data_sampler : {data_sampler}")

    # note: if an "uncorrectable NVLink error" occurs during execution, remove these DataLoader kwargs
    kwargs = {'num_workers': args.num_worker, 'pin_memory': True}
    dl = DataLoader(ds,
                    batch_size=BATCH_SIZE,
                    shuffle=is_shuffle,
                    drop_last=True,
                    sampler=data_sampler,
                    **kwargs
                   )
    
    logger.info("Processes {}/{} ({:.0f}%) of train data".format(
        len(dl.sampler), len(dl.dataset),
        100. * len(dl.sampler) / len(dl.dataset)))

    # initialize DALL-E
    dalle = DALLE(vae=vae, **dalle_params)
    if not using_deepspeed:
        if args.fp16:
            dalle = dalle.half()
        dalle = dalle.cuda()

    if RESUME and not using_deepspeed:
        dalle.load_state_dict(weights)

    # optimizer
    opt = Adam(get_trainable_params(dalle), lr=LEARNING_RATE)
    

    if LR_DECAY:
        scheduler = ReduceLROnPlateau(
            opt,
            mode="min",
            factor=0.5,
            patience=10,
            cooldown=10,
            min_lr=1e-6,
            verbose=True,
        )
    # if distr_backend.is_root_worker():
    if args.global_rank == 0:
        # experiment tracker

        model_config = dict(depth=DEPTH, heads=HEADS, dim_head=DIM_HEAD)
        
        
        logger.debug(f"args.wandb_name : {args.wandb_name}, RESUME : {RESUME}")
        
        run = wandb.init(
            project=args.wandb_name,  # 'dalle_train_transformer' by default
            resume=RESUME,
            config=model_config,
        )

    # distribute
    distr_backend.check_batch_size(BATCH_SIZE)
    deepspeed_config = {
        'train_batch_size': BATCH_SIZE,
        'gradient_clipping': GRAD_CLIP_NORM,
        'fp16': {
            'enabled': args.fp16,
        },
    }
    (distr_dalle, distr_opt, distr_dl,
     distr_scheduler) = distr_backend.distribute(
         args=args,
         model=dalle,
         optimizer=opt,
         model_parameters=get_trainable_params(dalle),
         training_data=ds if using_deepspeed else dl,
         lr_scheduler=scheduler if LR_DECAY else None,
         config_params=deepspeed_config,
     )
    avoid_model_calls = using_deepspeed and args.fp16

    if args.sagemakermp:
        args.distr_dalle = smp.DistributedModel(distr_dalle)
        args.scaler = smp.amp.GradScaler()
        args.distr_opt = smp.DistributedOptimizer(distr_opt)
    
    if RESUME and using_deepspeed:
        distr_dalle.load_checkpoint(str(cp_dir))
        


    # training

    for epoch in range(EPOCHS):
        logger.debug(f"********* epoch : {epoch} **********")
        if data_sampler:
            data_sampler.set_epoch(epoch)
            
        for i, (text, images) in enumerate(distr_dl):
            if args.fp16:
                images = images.half()
                
            text, images = map(lambda t: t.cuda(), (text, images))
            
            if args.sagemakermp:
                args.distr_opt.zero_grad()
                
                loss = train_step(args, text, images, return_loss=True)
                loss = loss.reduce_mean()

            else:  
                loss = distr_dalle(text, images, return_loss=True, args=args)

            if using_deepspeed:
                distr_dalle.backward(loss)
                distr_dalle.step()
                # Gradients are automatically zeroed after the step
            elif args.sagemakermp:
                if args.amp:
                    args.scaler.step(args.distr_opt)
                    args.scaler.update()
                else:
                    # some optimizers (e.g. Adadelta in PyTorch 1.8) don't like optimizer.step() being called with no params
                    if len(list(args.distr_dalle.local_parameters())) > 0:
                        args.distr_opt.step()
            else:
                loss.backward()
                clip_grad_norm_(distr_dalle.parameters(), GRAD_CLIP_NORM)
                distr_opt.step()
                distr_opt.zero_grad()

            # Collective loss, averaged
            avg_loss = distr_backend.average_all(loss)

            log = {}

            # if i % 10 == 0 and distr_backend.is_root_worker():
            if i % 10 == 0 and args.rank == 0:
                print(epoch, i, f'loss - {avg_loss.item()}')

                log = {
                    **log, 'epoch': epoch,
                    'iter': i,
                    'loss': avg_loss.item()
                }

            if i % 100 == 0:
                # if distr_backend.is_root_worker():
                if args.rank == 0:
                    sample_text = text[:1]
                    token_list = sample_text.masked_select(
                        sample_text != 0).tolist()
                    decoded_text = tokenizer.decode(token_list)
                    
                    logger.debug(f"******* avoid_model_calls : {avoid_model_calls}")
                    if not avoid_model_calls:
                        # CUDA index errors when we don't guard this
                        image = dalle.generate_images(
                            text[:1], filter_thres=0.9)  # topk sampling at 0.9

                    wandb.save('./dalle.pt')

                    log = {
                        **log,
                    }
                    if not avoid_model_calls:
                        log['image'] = wandb.Image(image, caption=decoded_text)

                args.distr_dalle = distr_dalle
                args.dalle_params = dalle_params
                args.vae_params = vae_params
                args.using_deepspeed = using_deepspeed
                args.DEEPSPEED_CP_AUX_FILENAME = DEEPSPEED_CP_AUX_FILENAME

                save_model(args, f'{args.model_dir}/dalle.pt')

            # if distr_backend.is_root_worker():
            if args.rank == 0:
                wandb.log(log)
                
#             text, images = prefetcher.next()

        if LR_DECAY and not using_deepspeed:
            # Scheduler is automatically progressed after the step when
            # using DeepSpeed.
            distr_scheduler.step(loss)

        # if distr_backend.is_root_worker():
        if args.global_rank == 0:
            # save trained model to wandb as an artifact every epoch's end

            model_artifact = wandb.Artifact('trained-dalle',
                                            type='model',
                                            metadata=dict(model_config))
            # model_artifact.add_file('dalle.pt')
            run.log_artifact(model_artifact)

    args.distr_dalle = distr_dalle
    args.dalle_params = dalle_params
    args.vae_params = vae_params
    args.using_deepspeed = using_deepspeed
    args.DEEPSPEED_CP_AUX_FILENAME = DEEPSPEED_CP_AUX_FILENAME

    save_model(args, f'{args.model_dir}/dalle-final.pt')

    # if distr_backend.is_root_worker():
    if args.global_rank == 0:
        wandb.save('./dalle-final.pt')
        model_artifact = wandb.Artifact('trained-dalle',
                                        type='model',
                                        metadata=dict(model_config))
        # model_artifact.add_file('dalle-final.pt')
        run.log_artifact(model_artifact)

        wandb.finish()
Example #10
                  codebook_dim=256,
                  hidden_dim=128,
                  temperature=0.9)

# load pretrained vae

vae_dict = torch.load("./models/" + vaename + "-" + str(load_epoch) + ".pth")
vae.load_state_dict(vae_dict)
vae.to(device)

dalle = DALLE(
    dim=256,  #512,
    vae=vae,  # automatically infer (1) image sequence length and (2) number of image tokens
    num_text_tokens=10000,  # vocab size for text
    text_seq_len=256,  # text sequence length
    depth=6,  # should be 64
    heads=8,  # attention heads
    dim_head=64,  # attention head dimension
    attn_dropout=0.1,  # attention dropout
    ff_dropout=0.1  # feedforward dropout
)

# load pretrained dalle if continuing training

dalle_dict = torch.load(loadfn)
dalle.load_state_dict(dalle_dict)

dalle.to(device)

# get image and text data
Example #11
def main(args):
    # constants
    print(f"torch.cuda.nccl.version() : {torch.cuda.nccl.version()}")
    VAE_PATH = args.vae_path
    DALLE_PATH = args.dalle_path
    RESUME = exists(DALLE_PATH)

    EPOCHS = args.epochs
    BATCH_SIZE = args.batch_size

    LEARNING_RATE = args.learning_rate
    GRAD_CLIP_NORM = args.clip_grad_norm
    LR_DECAY = args.lr_decay
    SAVE_EVERY_N_STEPS = args.save_every_n_steps
    MODEL_DIM = args.dim
    TEXT_SEQ_LEN = args.text_seq_len
    DEPTH = args.depth
    HEADS = args.heads
    DIM_HEAD = args.dim_head
    REVERSIBLE = args.reversible
    LOSS_IMG_WEIGHT = args.loss_img_weight

    ATTN_TYPES = tuple(args.attn_types.split(','))

    DEEPSPEED_CP_AUX_FILENAME = 'auxiliary.pt'
    DALLE_OUTPUT_FILE_NAME = args.dalle_output_file_name

    # initialize distributed backend

    args.deepspeed = True

    distr_backend = distributed_utils.set_backend_from_args(args)
    distr_backend.initialize(args)

    using_deepspeed = \
        distributed_utils.using_backend(distributed_utils.DeepSpeedBackend)
    args.rank = int(os.environ.get('RANK'))
    args.world_size = int(os.environ.get('WORLD_SIZE'))
    args.local_rank = int(os.environ.get('LOCAL_RANK'))
    args.global_rank = args.rank

    logger.debug(f"using_deepspeed : {using_deepspeed}")

    logger.debug(
        f"args.local_rank : {args.local_rank}, args.rank : {args.rank}")
    print(
        f"********* torch.distributed.get_rank() : {torch.distributed.get_rank()}"
    )
    # tokenizer
    logger.debug(
        f"exists(args.bpe_path) : {exists(args.bpe_path)}, args.chinese : {args.chinese}"
    )

    #     if args.local_rank == 0 or args.local_rank == 1:
    # #         print(f"args.job_name : {args.job_name}")
    #         gpu_mon_thread = aws_util.GPUMon(device_index= args.local_rank, job_name=args.job_name)
    #         gpu_mon_thread.start()

    if exists(args.bpe_path):
        klass = HugTokenizer if args.hug else YttmTokenizer
        tokenizer = klass(args.bpe_path)
    elif args.chinese:
        tokenizer = ChineseTokenizer()
    else:
        tokenizer = SimpleTokenizer()

    # reconstitute vae

    if RESUME:
        dalle_path = Path(DALLE_PATH)
        if using_deepspeed:
            cp_dir = cp_path_to_dir(dalle_path, 'ds')
            assert cp_dir.is_dir(), \
                f'DeepSpeed checkpoint directory {cp_dir} not found'
            dalle_path = cp_dir / DEEPSPEED_CP_AUX_FILENAME
        else:
            assert dalle_path.exists(), 'DALL-E model file does not exist'
        loaded_obj = torch.load(str(dalle_path), map_location='cpu')

        dalle_params, vae_params, weights = loaded_obj['hparams'], loaded_obj[
            'vae_params'], loaded_obj['weights']

        if vae_params is not None:
            vae = DiscreteVAE(**vae_params)
        else:
            vae_klass = OpenAIDiscreteVAE if not args.taming else VQGanVAE1024
            vae = vae_klass(args)

        dalle_params = dict(**dalle_params)
        IMAGE_SIZE = vae.image_size
    else:
        if exists(VAE_PATH):
            vae_path = Path(VAE_PATH)
            assert vae_path.exists(), 'VAE model file does not exist'
            assert not vae_path.is_dir(), \
                ('Cannot load VAE model from directory; please use a '
                 'standard *.pt checkpoint. '
                 'Currently, merging a DeepSpeed-partitioned VAE into a DALLE '
                 'model is not supported.')

            loaded_obj = torch.load(str(vae_path))

            vae_params, weights = loaded_obj['hparams'], loaded_obj['weights']

            vae = DiscreteVAE(**vae_params)
            vae.load_state_dict(weights)
        else:
            if args.rank == 0:
                # if distr_backend.is_root_worker():
                print('using pretrained VAE for encoding images to tokens')
            vae_params = None
            logger.debug(f"************* args.taming : {args.taming}")
            vae_klass = OpenAIDiscreteVAE if not args.taming else VQGanVAE1024
            vae = vae_klass(args)

        IMAGE_SIZE = vae.image_size

        dalle_params = dict(
            num_text_tokens=tokenizer.vocab_size,
            text_seq_len=TEXT_SEQ_LEN,
            dim=MODEL_DIM,
            depth=DEPTH,
            heads=HEADS,
            dim_head=DIM_HEAD,
            reversible=REVERSIBLE,
            loss_img_weight=LOSS_IMG_WEIGHT,
            attn_types=ATTN_TYPES,
        )

    # configure OpenAI VAE for float16s

    if isinstance(vae, OpenAIDiscreteVAE) and args.fp16:
        vae.enc.blocks.output.conv.use_float16 = True

    # create dataset and dataloader

    is_shuffle = not distributed_utils.using_backend(
        distributed_utils.HorovodBackend)

    ds = TextImageDataset(
        args.image_text_folder,
        text_len=TEXT_SEQ_LEN,
        image_size=IMAGE_SIZE,
        resize_ratio=args.resize_ratio,
        truncate_captions=args.truncate_captions,
        tokenizer=tokenizer,
        shuffle=is_shuffle,
    )

    assert len(ds) > 0, 'dataset is empty'
    # if distr_backend.is_root_worker():
    if args.rank == 0:
        print(f'{len(ds)} image-text pairs found for training')

    if not is_shuffle:
        data_sampler = torch.utils.data.distributed.DistributedSampler(
            ds, num_replicas=args.world_size, rank=args.rank)
    else:
        data_sampler = None

    kwargs = {'num_workers': args.num_worker, 'pin_memory': True}
    dl = DataLoader(ds,
                    batch_size=BATCH_SIZE,
                    shuffle=is_shuffle,
                    drop_last=True,
                    sampler=data_sampler,
                    **kwargs)

    logger.info("Processes {}/{} ({:.0f}%) of train data".format(
        len(dl.sampler), len(dl.dataset),
        100. * len(dl.sampler) / len(dl.dataset)))

    # initialize DALL-E

    dalle = DALLE(vae=vae, **dalle_params)
    if not using_deepspeed:
        if args.fp16:
            dalle = dalle.half()
        dalle = dalle.cuda()

    if RESUME and not using_deepspeed:
        dalle.load_state_dict(weights)

    # optimizer

    opt = Adam(get_trainable_params(dalle), lr=LEARNING_RATE)

    if LR_DECAY:
        scheduler = ReduceLROnPlateau(
            opt,
            mode="min",
            factor=0.5,
            patience=10,
            cooldown=10,
            min_lr=1e-6,
            verbose=True,
        )

    # if distr_backend.is_root_worker():
    if args.rank == 0:
        # experiment tracker

        model_config = dict(depth=DEPTH, heads=HEADS, dim_head=DIM_HEAD)

        #         wandb_dir = '/tmp/wandb'
        #         if not os.path.exists(wandb_dir):
        #             os.makedirs(wandb_dir)

        run = wandb.init(
            project=args.wandb_name,  # 'dalle_train_transformer' by default
            resume=RESUME,
            config=model_config,
            #             dir=wandb_dir
        )

    # distribute

    distr_backend.check_batch_size(BATCH_SIZE)
    deepspeed_config = {
        'train_batch_size': BATCH_SIZE,
        'gradient_clipping': GRAD_CLIP_NORM,
        'fp16': {
            'enabled': args.fp16,
        },
    }

    (distr_dalle, distr_opt, distr_dl,
     distr_scheduler) = distr_backend.distribute(
         args=args,
         model=dalle,
         optimizer=opt,
         model_parameters=get_trainable_params(dalle),
         training_data=ds if using_deepspeed else dl,
         lr_scheduler=scheduler if LR_DECAY else None,
         config_params=deepspeed_config,
     )
    avoid_model_calls = using_deepspeed and args.fp16

    if RESUME and using_deepspeed:
        distr_dalle.load_checkpoint(str(cp_dir))

    # training

    for epoch in range(EPOCHS):
        logger.debug(f"********* epoch : {epoch} **********")
        if data_sampler:
            data_sampler.set_epoch(epoch)
        for i, (text, images) in enumerate(distr_dl):

            if i % 10 == 0 and args.rank == 0:
                t = time.time()
            if args.fp16:
                images = images.half()
            text, images = map(lambda t: t.cuda(), (text, images))

            loss = distr_dalle(text, images, return_loss=True)

            if using_deepspeed:
                distr_dalle.backward(loss)
                distr_dalle.step()
                # Gradients are automatically zeroed after the step
            else:
                loss.backward()
                clip_grad_norm_(distr_dalle.parameters(), GRAD_CLIP_NORM)
                distr_opt.step()
                distr_opt.zero_grad()

            # Collective loss, averaged
            avg_loss = distr_backend.average_all(loss)

            log = {}

            # if i % 10 == 0 and distr_backend.is_root_worker():
            if i % 10 == 0 and args.rank == 0:
                print(epoch, i, f'loss - {avg_loss.item()}')

                log = {
                    **log, 'epoch': epoch,
                    'iter': i,
                    'loss': avg_loss.item()
                }
            if i % SAVE_EVERY_N_STEPS == 0:
                args.distr_dalle = distr_dalle
                args.dalle_params = dalle_params
                args.vae_params = vae_params
                args.using_deepspeed = using_deepspeed
                args.DEEPSPEED_CP_AUX_FILENAME = DEEPSPEED_CP_AUX_FILENAME
                save_model(args,
                           f"{args.model_dir+'/'+DALLE_OUTPUT_FILE_NAME}")

            if i % 100 == 0:
                # if distr_backend.is_root_worker():
                if args.rank == 0:
                    sample_text = text[:1]
                    token_list = sample_text.masked_select(
                        sample_text != 0).tolist()
                    decoded_text = tokenizer.decode(token_list)

                    if not avoid_model_calls:
                        # CUDA index errors when we don't guard this
                        image = dalle.generate_images(
                            text[:1], filter_thres=0.9)  # topk sampling at 0.9

                    log = {
                        **log,
                    }
                    if not avoid_model_calls:
                        log['image'] = wandb.Image(image, caption=decoded_text)
            if i % 10 == 9 and args.rank == 0:
                sample_per_sec = BATCH_SIZE * 10 / (time.time() - t)
                log["sample_per_sec"] = sample_per_sec
                print(epoch, i, f'sample_per_sec - {sample_per_sec}')
            # if distr_backend.is_root_worker():
            if args.rank == 0:
                wandb.log(log)

        if LR_DECAY and not using_deepspeed:
            # Scheduler is automatically progressed after the step when
            # using DeepSpeed.
            distr_scheduler.step(loss)

        args.distr_dalle = distr_dalle
        args.dalle_params = dalle_params
        args.vae_params = vae_params
        args.using_deepspeed = using_deepspeed
        args.DEEPSPEED_CP_AUX_FILENAME = DEEPSPEED_CP_AUX_FILENAME

        save_model(args, f"{args.model_dir+'/'+DALLE_OUTPUT_FILE_NAME}")
        #                 sync_local_checkpoints_to_s3(local_path=f'{args.model_dir}', s3_path='s3://lgaivision-coco-usva/Dalle_Model/tmd/')
        sync_local_checkpoints_to_s3(
            args.model_dir,
            os.path.join(args.output_s3, args.job_name + "/temp"))

        # if distr_backend.is_root_worker():
        if args.rank == 0:
            # save trained model to wandb as an artifact every epoch's end

            model_artifact = wandb.Artifact('trained-dalle',
                                            type='model',
                                            metadata=dict(model_config))
            import glob
            print(f"************** file : {glob.glob(args.model_dir+'/*')}")
            try:
                print(f"wandb.run.dir : {wandb.run.dir}")
                print(
                    f"************** file wandb: {glob.glob(wandb.run.dir+'/*')}"
                )
            except:
                pass

            model_artifact.add_file(
                f"{args.model_dir+'/'+DALLE_OUTPUT_FILE_NAME}")
            run.log_artifact(model_artifact)

    args.distr_dalle = distr_dalle
    args.dalle_params = dalle_params
    args.vae_params = vae_params
    args.using_deepspeed = using_deepspeed
    args.DEEPSPEED_CP_AUX_FILENAME = DEEPSPEED_CP_AUX_FILENAME
    resource_check(args)
    save_model(args, f"{args.model_dir +'/'+DALLE_OUTPUT_FILE_NAME}")
    if args.rank == 0:
        #         from distutils.dir_util import copy_tree

        #         copy_tree(f'{args.model_dir}', f'{args.model_dir_last}')
        sync_local_checkpoints_to_s3(
            args.model_dir,
            os.path.join(args.output_s3, args.job_name + "/temp"))
        resource_check(args)

    # if distr_backend.is_root_worker():
    if args.rank == 0:
        wandb.save(f"{args.model_dir+'/'+DALLE_OUTPUT_FILE_NAME}",
                   base_path=args.model_dir)

        model_artifact = wandb.Artifact('trained-dalle',
                                        type='model',
                                        metadata=dict(model_config))
        model_artifact.add_file(f"{args.model_dir+'/'+DALLE_OUTPUT_FILE_NAME}")
        run.log_artifact(model_artifact)

        wandb.finish()


#     if args.local_rank == 0 or args.local_rank == 1:
#         gpu_mon_thread.kill()

    distributed_utils.backend.local_barrier()
    print("************************ Finished *************************")
Example #12
    data_sampler = torch.utils.data.distributed.DistributedSampler(
        ds,
        num_replicas=distr_backend.get_world_size(),
        rank=distr_backend.get_rank())
else:
    data_sampler = None

dl = DataLoader(ds,
                batch_size=BATCH_SIZE,
                shuffle=not data_sampler,
                drop_last=True,
                sampler=data_sampler)

# initialize DALL-E

dalle = DALLE(vae=vae, **dalle_params)
if args.fp16:
    dalle = dalle.half()
dalle = dalle.cuda()

if RESUME:
    dalle.load_state_dict(weights)

# optimizer

opt = Adam(dalle.parameters(), lr=LEARNING_RATE)

if LR_DECAY:
    scheduler = ReduceLROnPlateau(
        opt,
        mode="min",
Example #13
compose = T.Compose([T.Resize(IMAGE_SIZE),
                     T.CenterCrop(IMAGE_SIZE),
                     T.ToTensor(),])

def collate_fn(batch):
    return tuple(zip(*batch))

ds = CocoCaptions(root=IMAGE_PATH, annFile=ANNO_PATH, transform=compose)
dl = DataLoader(ds, BATCH_SIZE, shuffle=True, num_workers=0, collate_fn=collate_fn)

assert len(ds) > 0, 'dataset is empty'
print(f'{len(ds)} image-text pairs found for training')

# initialize DALL-E
dalle = DALLE(**dalle_params)

if RESUME:
    dalle.load_state_dict(weights)


dalle = torch.nn.DataParallel(dalle).cuda()

# optimizer
opt = Adam(dalle.parameters(), lr = LEARNING_RATE)

# experiment tracker
import wandb
wandb.config.depth = DEPTH
wandb.config.heads = HEADS
wandb.config.dim_head = DIM_HEAD
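With the collate_fn above, each batch from dl is a tuple of images and a tuple of caption lists (CocoCaptions yields several captions per image). A hedged sketch of one training step using the dl, dalle and opt defined above; SimpleTokenizer and the TEXT_SEQ_LEN value stand in for whatever tokenization the rest of this script uses:

import torch
from dalle_pytorch.tokenizer import SimpleTokenizer

tokenizer = SimpleTokenizer()
TEXT_SEQ_LEN = 256  # assumed; should equal the text_seq_len in dalle_params

for images, captions in dl:
    text = tokenizer.tokenize([caps[0] for caps in captions], TEXT_SEQ_LEN).cuda()  # first caption per image
    images = torch.stack(images).cuda()
    loss = dalle(text, images, return_loss=True).mean()  # .mean() because dalle is DataParallel-wrapped
    loss.backward()
    opt.step()
    opt.zero_grad()
    break  # single illustrative step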
Example #14
assert dalle_path.exists(), 'trained DALL-E must exist'

load_obj = torch.load(str(dalle_path))
dalle_params, vae_params, weights = load_obj.pop('hparams'), load_obj.pop(
    'vae_params'), load_obj.pop('weights')

dalle_params.pop('vae', None)  # cleanup later

if args.taming:
    vae = VQGanVAE(args.vqgan_model_path, args.vqgan_config_path)
elif vae_params is not None:
    vae = DiscreteVAE(**vae_params)
else:
    vae = OpenAIDiscreteVAE()

dalle = DALLE(vae=vae, **dalle_params).cuda()

dalle.load_state_dict(weights)

# generate images

image_size = vae.image_size

texts = args.text.split('|')

for j, text in tqdm(enumerate(texts)):
    if args.gentxt:
        text_tokens, gen_texts = dalle.generate_texts(tokenizer,
                                                      text=text,
                                                      filter_thres=args.top_k)
        text = gen_texts[0]
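The rest of this loop is not shown. A plausible continuation of the per-text loop body, written as a standalone sketch that reuses tokenizer, dalle, args and text from above and is modeled on the library's standard generation API; the args fields num_images, batch_size, top_k and outputs_dir are assumptions:

import torch
from pathlib import Path
from einops import repeat
from torchvision.utils import save_image

text_tokens = tokenizer.tokenize([text], dalle.text_seq_len).cuda()
text_tokens = repeat(text_tokens, '() n -> b n', b=args.num_images)  # same prompt repeated num_images times

outputs = []
for chunk in text_tokens.split(args.batch_size):
    outputs.append(dalle.generate_images(chunk, filter_thres=args.top_k))
outputs = torch.cat(outputs)

outdir = Path(args.outputs_dir) / text.replace(' ', '_')[:100]
outdir.mkdir(parents=True, exist_ok=True)
for k, image in enumerate(outputs):
    save_image(image, outdir / f'{k}.png', normalize=True)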