def ascend_txt(model, lats, sideX, sideY, perceptor, percep, gen, tokenizedtxt):
    # Render an image from the current latents and score it against the text
    # prompt with CLIP; returns the loss terms plus the latents used.
    if gen == 'biggan':
        cutn = 128
        zs = [*lats()]
        out = model(zs[0], zs[1], 1)
    elif gen == 'dall-e':
        cutn = 32
        zs = lats()
        out = unmap_pixels(torch.sigmoid(model(zs)[:, :3].float()))
    elif gen == 'stylegan':
        zs = lats.normu.repeat(1, 18, 1)
        img = model(zs)
        img = torch.nn.functional.interpolate(img, (224, 224), mode='bilinear',
                                              align_corners=True)  # non-deprecated equivalent of upsample_bilinear
        img_logits, _text_logits = perceptor(img, tokenizedtxt.cuda())
        return 1 / img_logits * 100, img, zs

    # Sample `cutn` random crops of the generated image and resize them to the
    # CLIP input resolution.
    p_s = []
    for ch in range(cutn):
        # size = int(sideX*torch.zeros(1,).normal_(mean=.8, std=.3).clip(.5, .95))
        size = int(sideX * torch.zeros(1,).normal_(mean=.39, std=.865).clip(.362, .7099))
        offsetx = torch.randint(0, sideX - size, ())
        offsety = torch.randint(0, sideY - size, ())
        apper = out[:, :, offsetx:offsetx + size, offsety:offsety + size]
        apper = pad_augs(apper)
        # apper = kornia_augs(apper, sideX=sideX)
        apper = torch.nn.functional.interpolate(apper, (224, 224), mode='nearest')
        p_s.append(apper)
    into = torch.cat(p_s, 0)

    # Normalize the crop batch for CLIP (with a little added noise for BigGAN).
    if gen == 'biggan':
        # into = nom((into + 1) / 2)
        up_noise = 0.01649
        into = into + up_noise * torch.randn_like(into, requires_grad=True)
        into = nom((into + 1) / 1.8)
    elif gen == 'dall-e':
        into = nom((into + 1) / 2)

    iii = perceptor.encode_image(into)

    llls = zs  # lats()
    if gen == 'dall-e':
        return [0, 10 * -torch.cosine_similarity(percep, iii).view(-1, 1).T.mean(1), zs]

    # Regularize the BigGAN noise vector towards zero mean / unit variance and
    # penalize heavy skew and kurtosis per row.
    lat_l = torch.abs(1 - torch.std(llls[0], dim=1)).mean() + \
        torch.abs(torch.mean(llls[0])).mean() + \
        4 * torch.max(torch.square(llls[0]).mean(), lats.thrsh_lat)

    for array in llls[0]:
        mean = torch.mean(array)
        diffs = array - mean
        var = torch.mean(torch.pow(diffs, 2.0))
        std = torch.pow(var, 0.5)
        zscores = diffs / std
        skews = torch.mean(torch.pow(zscores, 3.0))
        kurtoses = torch.mean(torch.pow(zscores, 4.0)) - 3.0
        lat_l = lat_l + torch.abs(kurtoses) / llls[0].shape[0] + torch.abs(skews) / llls[0].shape[0]

    # Push all but the largest class logits toward zero so the class vector
    # stays concentrated on a few classes.
    cls_l = ((50 * torch.topk(llls[1], largest=False, dim=1, k=999)[0]) ** 2).mean()

    return [lat_l, cls_l, -100 * torch.cosine_similarity(percep, iii, dim=-1).mean(), zs]
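
# A minimal sketch (an assumption, not part of the original file) of how
# ascend_txt could be driven for the 'biggan' path: the three returned loss
# terms are summed and an optimizer step is taken over the latent parameters.
# The optimizer and its setup are illustrative assumptions.
def train_step_biggan(model, lats, optimizer, sideX, sideY, perceptor, percep, tokenizedtxt):
    lat_l, cls_l, clip_l, _zs = ascend_txt(model, lats, sideX, sideY,
                                           perceptor, percep, 'biggan', tokenizedtxt)
    loss = lat_l + cls_l + clip_l
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()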
def decode():
    # Decode a JSON-encoded grid of DALL-E codebook indices back into an image.
    z = request.get_json(force=True)
    z = np.array(z)
    z = torch.from_numpy(z).to(dev)
    z = F.one_hot(z, num_classes=enc.vocab_size).permute(0, 3, 1, 2).float()
    x_stats = dec(z).float()
    x_rec = unmap_pixels(torch.sigmoid(x_stats[:, :3]))
    x_rec = T.ToPILImage(mode='RGB')(x_rec[0])
    return serve_pil_image(x_rec)
def encode_decode(size):
    with torch.no_grad():  # https://github.com/pytorch/pytorch/issues/16417#issuecomment-566654504
        # Encode the uploaded image to discrete latents, then decode it back.
        data = request.files['file']
        x = preprocess(PIL.Image.open(data), int(size))
        z_logits = enc(x.to(dev))
        z = torch.argmax(z_logits, axis=1)
        z = F.one_hot(z, num_classes=enc.vocab_size).permute(0, 3, 1, 2).float()
        x_stats = dec(z).float()
        x_rec = unmap_pixels(torch.sigmoid(x_stats[:, :3]))
        x_rec = T.ToPILImage(mode='RGB')(x_rec[0])
        return serve_pil_image(x_rec)
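
# Hedged client-side sketch (assumed, not part of the original file) for Flask
# handlers like decode() and encode_decode() above. The route paths, host, and
# port are illustrative assumptions; only the request shapes mirror what the
# handlers read (a JSON latent grid, and a multipart 'file' upload).
import requests

def decode_latents_remote(z_grid, url="http://localhost:8000/decode"):
    # z_grid: nested list of ints, shape [1, H', W'], of DALL-E codebook indices.
    resp = requests.post(url, json=z_grid)
    resp.raise_for_status()
    return resp.content  # image bytes served by serve_pil_image

def roundtrip_image_remote(path, size=256, url="http://localhost:8000/encode_decode"):
    with open(path, "rb") as f:
        resp = requests.post(f"{url}/{size}", files={"file": f})
    resp.raise_for_status()
    return resp.content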
def reconstruct_with_dalle(x, encoder, decoder, do_preprocess=False):
    # Takes a tensor (or, optionally, a PIL image) and returns a PIL image.
    if do_preprocess:
        x = preprocess(x)
    z_logits = encoder(x)
    z = torch.argmax(z_logits, axis=1)
    print(f"DALL-E: latent shape: {z.shape}")
    z = F.one_hot(z, num_classes=encoder.vocab_size).permute(0, 3, 1, 2).float()
    x_stats = decoder(z).float()
    x_rec = unmap_pixels(torch.sigmoid(x_stats[:, :3]))
    x_rec = T.ToPILImage(mode='RGB')(x_rec[0])
    return x_rec
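
# Small convenience wrapper (an assumption, not part of the original file):
# load an image from disk and round-trip it through reconstruct_with_dalle.
# The module-level enc/dec pair used elsewhere in this code would be passed as
# encoder/decoder; the file path is whatever the caller supplies.
def reconstruct_file_with_dalle(path, encoder, decoder):
    img = PIL.Image.open(path).convert("RGB")
    return reconstruct_with_dalle(img, encoder, decoder, do_preprocess=True)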
def main():
    x = preprocess(
        download_image(
            "https://assets.bwbx.io/images/users/iqjWHBFdfxIU/iKIWgaiJUtss/v2/1000x-1.jpg"
        ))
    orig_image = T.ToPILImage(mode="RGB")(x[0])
    orig_image.show()
    # orig_image.save("test.jpg")

    z_logits = enc(x)
    z = torch.argmax(z_logits, axis=1)
    z = F.one_hot(z, num_classes=enc.vocab_size).permute(0, 3, 1, 2).float()

    x_stats = dec(z).float()
    x_rec = unmap_pixels(torch.sigmoid(x_stats[:, :3]))
    x_rec = T.ToPILImage(mode="RGB")(x_rec[0])
    x_rec.show()
img = T.ToPILImage(mode='RGB')(x[0])
plt.imshow(img)
plt.show()

import torch.nn.functional as F

z_logits = enc(x)
z = torch.argmax(z_logits, axis=1)
z = F.one_hot(z, num_classes=enc.vocab_size).permute(0, 3, 1, 2).float()

# Walk the first 100 codebook channels and, for every spatial position that is
# active in the one-hot map, print its flat index modulo 256 as a character.
ch_array = z[0, 0:100, :, :]
for i in range(100):
    ch_sub_array = z[0, i].flatten()
    byte_ch = 0
    for ch_item in ch_sub_array:
        if ch_item > 0:
            print(ch_item)
            print(byte_ch)
            ch_print = byte_ch % 256
            print(chr(ch_print))
        byte_ch += 1

x_stats = dec(z).float()
x_rec = unmap_pixels(torch.sigmoid(x_stats[:, :3]))
x_rec = T.ToPILImage(mode='RGB')(x_rec[0])
plt.imshow(x_rec)
plt.show()
def interpolate(templist, descs, model, audiofile):
    video_temp_list = []
    # Interpolate frames between each pair of consecutive images.
    for idx1, pt in enumerate(descs):
        # Get the next index of the descs list;
        # if z1_idx is out of range, stop the loop.
        z1_idx = idx1 + 1
        if z1_idx >= len(descs):
            break
        current_lyric = pt[1]

        # Interval between the two lines/elements in seconds (`ttime`).
        d1 = pt[0]
        d2 = descs[z1_idx][0]
        ttime = d2 - d1

        # For the very first index, load the first pt temp file;
        # otherwise reuse the previous iteration's latents (z1).
        if idx1 == 0:
            zs = torch.load(templist[idx1])
        else:
            zs = z1

        # Number of frames to insert between the two elements.
        N = round(ttime * interpol)
        print(z1_idx)

        # BigGAN latents come as a list; otherwise wrap the latents in a list
        # so the interpolation loop below can treat both cases the same way.
        if not isinstance(zs, list):
            z0 = [zs]
            z1 = [torch.load(templist[z1_idx])]
        else:
            z0 = zs
            z1 = torch.load(templist[z1_idx])

        # Loop over the frames and generate the images.
        image_temp_list = []
        for t in range(N):
            azs = []
            for r in zip(z0, z1):
                z_diff = r[1] - r[0]
                inter_zs = r[0] + sigmoid(t / (N - 1)) * z_diff
                azs.append(inter_zs)

            # Generate the frame.
            with torch.no_grad():
                if generator == 'biggan':
                    img = model(azs[0], azs[1], 1).cpu().numpy()
                    img = img[0]
                elif generator == 'dall-e':
                    img = unmap_pixels(
                        torch.sigmoid(model(azs[0])[:, :3]).cpu().float()).numpy()
                    img = img[0]
                elif generator == 'stylegan':
                    img = model(azs[0])
            image_temp = create_image(img, t, current_lyric, generator)
            image_temp_list.append(image_temp)

        video_temp = create_video.createvid(f'{current_lyric}', image_temp_list,
                                            duration=ttime / N)
        video_temp_list.append(video_temp)

    # Finally create the final output and save it to the output folder.
    create_video.concatvids(descs, video_temp_list, audiofile, lyrics=lyrics)
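
# interpolate() depends on a module-level sigmoid() easing function (plus the
# globals interpol, generator, lyrics, create_image, and create_video) defined
# elsewhere in this repo. Below is a minimal sketch of such an easing helper,
# assuming a logistic curve over [0, 1]; the repo's actual definition may differ.
import math

def sigmoid(x, steepness=10.0):
    # Map x in [0, 1] onto an S-curve so frames ease in and out of each keyframe.
    return 1.0 / (1.0 + math.exp(-steepness * (x - 0.5)))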