Example #1
    def __getitem__(self, ind):
        # Look up the caption file and image file that share this key
        key = self.keys[ind]
        text_file = self.text_files[key]
        image_file = self.image_files[key]

        image = Image.open(image_file)

        # The caption file holds one description per line; drop empty lines
        # and pick one at random
        descriptions = text_file.read_text().split('\n')
        descriptions = list(filter(lambda t: len(t) > 0, descriptions))
        description = choice(descriptions)

        # Tokenize the chosen caption and transform the image into a tensor
        tokenized_text = tokenize(description, self.text_len, truncate_text=args.truncate_captions).squeeze(0)

        image_tensor = self.image_transform(image)
        return tokenized_text, image_tensor
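The transform applied by self.image_transform is not part of the snippet; below is a minimal sketch of a typical torchvision pipeline it could correspond to. The IMAGE_SIZE value and the exact steps are assumptions, not shown in the example.

from torchvision import transforms as T

IMAGE_SIZE = 256  # assumed value; the snippet does not show it

# Hypothetical pipeline standing in for self.image_transform
image_transform = T.Compose([
    T.Lambda(lambda img: img.convert('RGB')),
    T.Resize(IMAGE_SIZE),
    T.CenterCrop(IMAGE_SIZE),
    T.ToTensor(),
])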
Example #2
    def __getitem__(self, ind):
        # Look up the caption file and image file that share this key
        key = self.keys[ind]
        text_file = self.text_files[key]
        image_file = self.image_files[key]

        image = Image.open(image_file)

        # The caption file holds one description per line; drop empty lines
        # and pick one at random
        descriptions = text_file.read_text().split('\n')
        descriptions = list(filter(lambda t: len(t) > 0, descriptions))
        description = choice(descriptions)

        tokenized_text = tokenize(
            description, self.text_len,
            truncate_text=args.truncate_captions).squeeze(0)

        image_tensor = self.image_transform(image)

        # Side effect: also write the transformed image to ../dataset/COCO_256,
        # keeping the original file name
        save_image(image_tensor,
                   os.path.join('../dataset/COCO_256',
                                str(image_file).split('/')[-1]),
                   normalize=True)
        return tokenized_text, image_tensor
Example #3
# Choose the image tokenizer: a custom DiscreteVAE if its parameters were
# saved alongside the checkpoint, otherwise OpenAI's pretrained discrete VAE
# or the Taming Transformers VQGAN
if vae_params is not None:
    vae = DiscreteVAE(**vae_params)
elif not args.taming:
    vae = OpenAIDiscreteVAE()
else:
    vae = VQGanVAE1024()

dalle = DALLE(vae=vae, **dalle_params).cuda()

dalle.load_state_dict(weights)

# generate images

image_size = vae.image_size

# Tokenize the prompt and repeat it once per requested image
text = tokenize([args.text], dalle.text_seq_len).cuda()

text = repeat(text, '() n -> b n', b=args.num_images)

# create masks marking non-padding token positions
mask = text != 0

outputs = []

# Generate in batches of args.batch_size to keep memory usage bounded
for text_chunk, mask_chunk in tqdm(zip(text.split(args.batch_size),
                                       mask.split(args.batch_size)),
                                   desc='generating images'):
    output = dalle.generate_images(text_chunk,
                                   mask=mask_chunk,
                                   filter_thres=args.top_k)
    outputs.append(output)
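The loop above only collects the generated batches; below is a minimal sketch of writing them to disk, assuming torchvision's save_image (also used in Example #2) and a hypothetical outputs directory that is not part of the example.

from pathlib import Path

import torch
from torchvision.utils import save_image

# Hypothetical output location; not part of the example above
outputs_dir = Path('./outputs') / args.text.replace(' ', '_')
outputs_dir.mkdir(parents=True, exist_ok=True)

# Concatenate the per-batch results into one (num_images, 3, H, W) tensor
# and save each generated image as its own PNG
images = torch.cat(outputs)
for i, image in enumerate(images):
    save_image(image, outputs_dir / f'{i}.png', normalize=True)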
Example #4
# Initialize the wandb run before recording hyperparameters in wandb.config
wandb.init(project = 'dalle_train_transformer_coco', resume = RESUME)

wandb.config.depth = DEPTH
wandb.config.heads = HEADS
wandb.config.dim_head = DIM_HEAD

# training
for epoch in range(epoch_start, EPOCHS):
    for i, (images, text) in enumerate(dl):
        images = torch.stack(images)

        # Pick one non-empty caption at random for every sample in the batch
        text_list = []
        for descriptions in text:
            descriptions = list(filter(lambda t: len(t) > 0, descriptions))
            description = choice(descriptions)
            text_list.append(description)

        text = tokenize(text_list).squeeze(0)
        mask = text != 0  # mask out padding token positions
        text, images, mask = map(lambda t: t.cuda(), (text, images, mask))

        loss = dalle(text, images, mask = mask, return_loss = True)
        loss = torch.sum(loss)  # collapse to a scalar in case a per-device loss vector is returned
        loss.backward()
#         clip_grad_norm_(dalle.parameters(), GRAD_CLIP_NORM)

        opt.step()
        opt.zero_grad()

        log = {}

        if i % 10 == 0:
            print(epoch, i, f'loss - {loss.item()}')
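The snippet assumes an optimizer opt and constants such as DEPTH, HEADS, EPOCHS, and GRAD_CLIP_NORM are defined elsewhere; below is a minimal sketch of one way opt could be set up, with a hypothetical LEARNING_RATE value.

from torch.optim import Adam
from torch.nn.utils import clip_grad_norm_  # needed if the commented-out clipping line is re-enabled

LEARNING_RATE = 3e-4  # assumed value; not shown in the example
opt = Adam(dalle.parameters(), lr=LEARNING_RATE)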