Example #1
# Renders an image with the chosen generator, takes random crops, scores them
# against the CLIP text embedding `percep`, and returns losses plus the latents.
def ascend_txt(model, lats, sideX, sideY, perceptor, percep, gen, tokenizedtxt):
    if gen == 'biggan':
        cutn = 128  # number of random crops scored by CLIP
        zs = [*lats()]
        out = model(zs[0], zs[1], 1)
        
    elif gen == 'dall-e':
        cutn = 32
        zs = lats()
        out = unmap_pixels(torch.sigmoid(model(zs)[:, :3].float()))

    elif gen == 'stylegan':
        zs = lats.normu.repeat(1,18,1)
        img = model(zs)
        # F.upsample_bilinear is deprecated; interpolate with align_corners=True matches it
        img = torch.nn.functional.interpolate(img, (224, 224), mode='bilinear', align_corners=True)

        img_logits, _text_logits = perceptor(img, tokenizedtxt.cuda())

        return 1/img_logits * 100, img, zs
    
    # take `cutn` random square crops and resize each to CLIP's 224x224 input
    p_s = []
    for ch in range(cutn):
        # size = int(sideX*torch.zeros(1,).normal_(mean=.8, std=.3).clip(.5, .95))
        size = int(sideX*torch.zeros(1,).normal_(mean=.39, std=.865).clip(.362, .7099))
        offsetx = torch.randint(0, sideX - size, ())
        offsety = torch.randint(0, sideY - size, ())
        apper = out[:, :, offsetx:offsetx + size, offsety:offsety + size]
        apper = pad_augs(apper)
        # apper = kornia_augs(apper, sideX=sideX)
        apper = torch.nn.functional.interpolate(apper, (224, 224), mode='nearest')
        p_s.append(apper)
    into = torch.cat(p_s, 0)
    if gen == 'biggan':
        # into = nom((into + 1) / 2)
        up_noise = 0.01649
        into = into + up_noise * torch.randn_like(into, requires_grad=True)
        into = nom((into + 1) / 1.8)
    elif gen == 'dall-e':
        into = nom((into + 1) / 2)
    iii = perceptor.encode_image(into)
    llls = zs #lats()

    if gen == 'dall-e':
        # the DALL-E path has no latent regularizer; only the CLIP similarity loss applies
        return [0, -10 * torch.cosine_similarity(percep, iii).view(-1, 1).T.mean(1), zs]

    # latent regularizer: unit std, zero mean, and magnitude capped by lats.thrsh_lat
    lat_l = torch.abs(1 - torch.std(llls[0], dim=1)).mean() + \
            torch.abs(torch.mean(llls[0])).mean() + \
            4*torch.max(torch.square(llls[0]).mean(), lats.thrsh_lat)
    # penalize skew and kurtosis so each latent row stays roughly gaussian
    for array in llls[0]:
        mean = torch.mean(array)
        diffs = array - mean
        var = torch.mean(torch.pow(diffs, 2.0))
        std = torch.pow(var, 0.5)
        zscores = diffs / std
        skews = torch.mean(torch.pow(zscores, 3.0))
        kurtoses = torch.mean(torch.pow(zscores, 4.0)) - 3.0
        lat_l = lat_l + torch.abs(kurtoses) / llls[0].shape[0] + torch.abs(skews) / llls[0].shape[0]
    # class regularizer: push the 999 smallest class logits toward zero
    cls_l = ((50 * torch.topk(llls[1], k=999, largest=False, dim=1)[0]) ** 2).mean()
    return [lat_l, cls_l, -100*torch.cosine_similarity(percep, iii, dim=-1).mean(), zs]
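
A minimal sketch of how ascend_txt might be driven from an optimization loop; the optimizer choice, learning rate, and step count here are illustrative assumptions, not part of the original:

# Hypothetical driver loop around ascend_txt (settings are assumptions):
optimizer = torch.optim.Adam(lats.parameters(), lr=0.07)
for step in range(200):
    *losses, zs = ascend_txt(model, lats, sideX, sideY,
                             perceptor, percep, 'biggan', tokenizedtxt)
    loss = sum(losses)  # lat_l + cls_l + the CLIP similarity term
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()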
Example #2
# Flask-style endpoint: receives a grid of DALL-E token ids as JSON,
# one-hot encodes it, and returns the decoded image.
def decode():
  z = request.get_json(force=True)
  z = np.array(z)
  z = torch.from_numpy(z).to(dev)
  z = F.one_hot(z, num_classes=enc.vocab_size).permute(0, 3, 1, 2).float()
  x_stats = dec(z).float()
  x_rec = unmap_pixels(torch.sigmoid(x_stats[:, :3]))
  x_rec = T.ToPILImage(mode='RGB')(x_rec[0])
  return serve_pil_image(x_rec)
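
Assuming decode is registered as a Flask route, a client round-trip might look like this; the host, route, and token-grid shape are assumptions:

# Hypothetical client for the endpoint above:
import requests
tokens = [[[0] * 32 for _ in range(32)]]  # 1x32x32 grid of token ids
resp = requests.post("http://localhost:8000/decode", json=tokens)
with open("decoded.png", "wb") as f:
    f.write(resp.content)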
Example #3
# Flask-style endpoint: round-trips an uploaded image through the DALL-E
# encoder and decoder at the requested size.
def encode_decode(size):
  with torch.no_grad(): # https://github.com/pytorch/pytorch/issues/16417#issuecomment-566654504
    data = request.files['file']
    x = preprocess(PIL.Image.open(data), int(size))
    z_logits = enc(x.to(dev))
    z = torch.argmax(z_logits, dim=1)
    z = F.one_hot(z, num_classes=enc.vocab_size).permute(0, 3, 1, 2).float()
    x_stats = dec(z).float()
    x_rec = unmap_pixels(torch.sigmoid(x_stats[:, :3]))
    x_rec = T.ToPILImage(mode='RGB')(x_rec[0])
    return serve_pil_image(x_rec)
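
A matching client upload for encode_decode; host, port, and route are again assumptions:

# Hypothetical client upload for the endpoint above:
import requests
with open("input.jpg", "rb") as f:
    resp = requests.post("http://localhost:8000/encode_decode/256", files={"file": f})
with open("roundtrip.png", "wb") as f:
    f.write(resp.content)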
Example #4
def reconstruct_with_dalle(x, encoder, decoder, do_preprocess=False):
  # takes in tensor (or optionally, a PIL image) and returns a PIL image
  if do_preprocess:
    x = preprocess(x)
  z_logits = encoder(x)
  z = torch.argmax(z_logits, dim=1)
  
  print(f"DALL-E: latent shape: {z.shape}")
  z = F.one_hot(z, num_classes=encoder.vocab_size).permute(0, 3, 1, 2).float()

  x_stats = decoder(z).float()
  x_rec = unmap_pixels(torch.sigmoid(x_stats[:, :3]))
  x_rec = T.ToPILImage(mode='RGB')(x_rec[0])

  return x_rec
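
An example call, with enc and dec standing in for the loaded DALL-E encoder and decoder as in the other examples:

# Hypothetical usage of reconstruct_with_dalle:
img = PIL.Image.open("photo.jpg")
rec = reconstruct_with_dalle(img, enc, dec, do_preprocess=True)
rec.save("reconstruction.png")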
Example #5
def main():
    x = preprocess(
        download_image(
            "https://assets.bwbx.io/images/users/iqjWHBFdfxIU/iKIWgaiJUtss/v2/1000x-1.jpg"
        ))
    orig_image = T.ToPILImage(mode="RGB")(x[0])
    orig_image.show()
    # orig_image.save("test.jpg")

    z_logits = enc(x)
    z = torch.argmax(z_logits, dim=1)
    z = F.one_hot(z, num_classes=enc.vocab_size).permute(0, 3, 1, 2).float()

    x_stats = dec(z).float()
    x_rec = unmap_pixels(torch.sigmoid(x_stats[:, :3]))
    x_rec = T.ToPILImage(mode="RGB")(x_rec[0])

    x_rec.show()
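
download_image and preprocess are not defined in this example; sketches in the spirit of the DALL-E usage notebook might look like the following. target_image_size is an assumption, while map_pixels ships with the dall_e package:

import io
import requests
import PIL.Image
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from dall_e import map_pixels

target_image_size = 256  # assumed dVAE input resolution

def download_image(url):
    # fetch an image over HTTP into a PIL image
    resp = requests.get(url)
    resp.raise_for_status()
    return PIL.Image.open(io.BytesIO(resp.content))

def preprocess(img, target=target_image_size):
    # resize the short side to `target`, center-crop, and map into the dVAE pixel range
    r = target / min(img.size)
    size = (round(r * img.size[1]), round(r * img.size[0]))
    img = TF.resize(img, size)
    img = TF.center_crop(img, output_size=[target, target])
    return map_pixels(T.ToTensor()(img).unsqueeze(0))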
Example #6
    img = T.ToPILImage(mode='RGB')(x[0])
    plt.imshow(img)
    plt.show()
    import torch.nn.functional as F

    z_logits = enc(x)

    z = torch.argmax(z_logits, dim=1)
    z = F.one_hot(z, num_classes=enc.vocab_size).permute(0, 3, 1, 2).float()


    ch_array = z[0, 0:100, :, :]  # first 100 one-hot channels (unused below)

    # for each of the first 100 token channels, print which one-hot entry is
    # active, rendering its flat index as an ASCII character (mod 256)
    for i in range(100):
        ch_sub_array = z[0, i].flatten()
        for byte_ch, ch_item in enumerate(ch_sub_array):
            if ch_item > 0:
                print(ch_item)
                print(byte_ch)
                print(chr(byte_ch % 256))

    x_stats = dec(z).float()
    x_rec = unmap_pixels(torch.sigmoid(x_stats[:, :3]))
    x_rec = T.ToPILImage(mode='RGB')(x_rec[0])
    plt.imshow(x_rec)
    plt.show()
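
As an aside, the token ids can be read back out of the one-hot tensor far more directly with an argmax over the class dimension:

# Direct recovery of the token ids from the one-hot tensor:
token_ids = z[0].argmax(dim=0)  # (H, W) grid of DALL-E vocab indices
print(token_ids.flatten()[:100])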
Example #7
File: main.py  Project: yk/clip_music_video
def interpolate(templist, descs, model, audiofile):

    video_temp_list = []

    # interpolate between each consecutive pair of keyframe latents

    for idx1, pt in enumerate(descs):

        # get the next index into the descs list;
        # if z1_idx is out of range, break the loop
        z1_idx = idx1 + 1
        if z1_idx >= len(descs):
            break

        current_lyric = pt[1]

        # get the interval between the two lines/elements in seconds, `ttime`
        d1 = pt[0]
        d2 = descs[z1_idx][0]
        ttime = d2 - d1

        # on the very first index, load the first pt temp file;
        # otherwise reuse the previous iteration's latents (z1)
        if idx1 == 0:
            zs = torch.load(templist[idx1])
        else:
            zs = z1

        # compute the number of frames to insert between the two keyframes
        N = round(ttime * interpol)
        print(z1_idx)
        # BigGAN latents come as a list; wrap single tensors in a list
        # so both generator types can be iterated uniformly
        if not isinstance(zs, list):
            z0 = [zs]
            z1 = [torch.load(templist[z1_idx])]
        else:
            z0 = zs
            z1 = torch.load(templist[z1_idx])

        # loop over the range of elements and generate the images
        image_temp_list = []
        for t in range(N):

            azs = []
            for r in zip(z0, z1):
                z_diff = r[1] - r[0]
                inter_zs = r[0] + sigmoid(t / max(N - 1, 1)) * z_diff  # max() guards N == 1
                azs.append(inter_zs)

            # Generate image
            with torch.no_grad():
                if generator == 'biggan':
                    img = model(azs[0], azs[1], 1).cpu().numpy()
                    img = img[0]
                elif generator == 'dall-e':
                    img = unmap_pixels(
                        torch.sigmoid(model(
                            azs[0])[:, :3]).cpu().float()).numpy()
                    img = img[0]
                elif generator == 'stylegan':
                    img = model(azs[0])
                image_temp = create_image(img, t, current_lyric, generator)
            image_temp_list.append(image_temp)

        video_temp = create_video.createvid(f'{current_lyric}',
                                            image_temp_list,
                                            duration=ttime / N)
        video_temp_list.append(video_temp)
    # finally, concatenate all clips with the audio into the output video
    create_video.concatvids(descs, video_temp_list, audiofile, lyrics=lyrics)
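
interpolate relies on a module-level sigmoid easing function that is not shown here; a plausible stand-in (an assumption, not necessarily the project's exact helper) is a logistic curve rescaled so that t = 0 maps to 0 and t = 1 maps to 1:

import numpy as np

def sigmoid(t, k=10.0):
    # hypothetical easing helper: logistic curve centered at t = 0.5,
    # rescaled so sigmoid(0) == 0 and sigmoid(1) == 1
    raw = 1.0 / (1.0 + np.exp(-k * (t - 0.5)))
    lo = 1.0 / (1.0 + np.exp(k * 0.5))
    hi = 1.0 / (1.0 + np.exp(-k * 0.5))
    return (raw - lo) / (hi - lo)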