Exemplo n.º 1
0
def main():
    """Optimise an FFT-parameterised image so CLIP matches it to a text prompt.

    Configuration comes from ``get_args()`` (fields read here: model,
    input_text, out_dir, size, batch_size, decay, noise, lrate, num_steps,
    save_freq, save_pt).

    Every ``save_freq`` steps the current image is written as a numbered JPEG
    into ``<out_dir>/<out_name>/``. After optimisation:

    * the frames are stitched into ``<out_dir>/<out_name>.mp4`` with ffmpeg
      (skipped with a message if ffmpeg is not installed);
    * the last frame is copied to ``<out_dir>/<out_name>-<num_steps>.jpg``;
    * if ``args.save_pt`` is set, the spectrum tensor is saved to
      ``<out_dir>/<out_name>.pt``.
    """
    import subprocess  # local import: only needed for the ffmpeg call below

    args = get_args()
    device = "cuda" if torch.cuda.is_available() else "cpu"

    clip_model, _ = clip.load(args.model, device=device)
    # Only the image spectrum is optimised. Freezing CLIP stops backward()
    # from accumulating (never-cleared) gradients into its weights.
    for param in clip_model.parameters():
        param.requires_grad_(False)
    print(f"Using model {args.model}")

    input_text = args.input_text
    print(f"Generating from '{input_text}'")

    out_name = txt_clean(input_text)
    if 'RN' in args.model.upper():
        out_name += '-%s' % args.model

    tempdir = os.path.join(args.out_dir, out_name)
    os.makedirs(tempdir, exist_ok=True)

    # The prompt embedding does not depend on the image, so encode it once.
    with torch.no_grad():
        tokenized_text = clip.tokenize([input_text]).to(device)
        text_logits = clip_model.encode_text(tokenized_text)

    num_channels = 3
    spectrum_size = [args.batch_size, num_channels, *args.size]
    fft_img, img_freqs = get_fft_img(
        spectrum_size,
        std=0.01,
        return_img_freqs=True,
    )
    fft_img = fft_img.to(device)
    fft_img.requires_grad = True

    scale = get_scale_from_img_freqs(
        img_freqs=img_freqs,
        decay_power=args.decay,
    ).to(device)

    shift = None
    if args.noise > 0:
        noise_size = (1, 1, *img_freqs.shape, 1)
        shift = args.noise * torch.randn(noise_size).to(device)

    def render():
        """Decode the current spectrum into an RGB image batch via fft_to_rgb."""
        return fft_to_rgb(
            fft_img=fft_img,
            scale=scale,
            img_size=args.size,
            shift=shift,
            contrast=1.0,
            decorrelate=True,
            device=device,
        )

    optimizer = torch.optim.Adam([fft_img], args.lrate)

    num_steps = args.num_steps
    num_crops = 200
    crop_size = 224
    # One progress tick per saved frame: steps 0, save_freq, 2*save_freq, ...
    pbar = ProgressBar(len(range(0, num_steps, args.save_freq)))

    for step in range(num_steps):
        crop_img_out = random_crop(
            render(),
            num_crops,
            crop_size,
            normalize=True,
        )
        img_logits = clip_model.encode_image(crop_img_out)

        # Maximise text/image cosine similarity, averaged over all crops.
        loss = -torch.cosine_similarity(
            text_logits,
            img_logits,
            dim=-1,
        ).mean()

        torch.cuda.empty_cache()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % args.save_freq == 0:
            with torch.no_grad():
                img = render().cpu().numpy()

            img_out_path = os.path.join(
                tempdir, '%04d.jpg' % (step // args.save_freq))
            checkout(img[0], img_out_path)
            pbar.upd()

    # Stitch the frames into a video. Use an argument list (no shell) and a
    # portable path join instead of a hard-coded Windows separator.
    frame_pattern = os.path.join(tempdir, '%04d.jpg')
    video_path = '%s.mp4' % os.path.join(args.out_dir, out_name)
    try:
        subprocess.run(
            ['ffmpeg', '-v', 'warning', '-y', '-i', frame_pattern, video_path],
            check=False,
        )
    except FileNotFoundError:
        print('ffmpeg not found; skipping video export')

    # No frames exist when num_steps == 0; avoid an IndexError in that case.
    frames = img_list(tempdir)
    if frames:
        shutil.copy(
            frames[-1],
            os.path.join(args.out_dir, '%s-%d.jpg' % (out_name, num_steps)))

    if args.save_pt:
        torch.save(fft_img, '%s.pt' % os.path.join(args.out_dir, out_name))