Example #1 — frame_transform(): affine frame transform with a PyTorch 1.7.1 / 1.8+ branch
def frame_transform(img, size, angle, shift, scale, shear):
    if old_torch(): # 1.7.1
        img = T.functional.affine(img, angle, tuple(shift), scale, shear, fillcolor=0, resample=PIL.Image.BILINEAR)
        img = T.functional.center_crop(img, size)
        img = pad_up_to(img, size)
    else: # 1.8+
        img = T.functional.affine(img, angle, tuple(shift), scale, shear, fill=0, interpolation=T.InterpolationMode.BILINEAR)
        img = T.functional.center_crop(img, size) # in torchvision 0.9+ (torch 1.8+) center_crop also pads if needed
    return img
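
The two branches above exist because torchvision renamed the affine keyword arguments between releases (fillcolor/resample became fill/interpolation). A minimal self-contained sketch of the same switch, probing for InterpolationMode instead of the repo's old_torch() helper (pad_up_to() is likewise a repo helper and not reproduced here):

import torch
import torchvision.transforms as T

img = torch.rand(1, 3, 256, 256)  # dummy NCHW frame
if hasattr(T, 'InterpolationMode'):  # torchvision 0.9+ / PyTorch 1.8+
    out = T.functional.affine(img, angle=5., translate=[4, 0], scale=1.05, shear=0.,
                              fill=0, interpolation=T.InterpolationMode.BILINEAR)
else:  # torchvision 0.8.x / PyTorch 1.7.1
    import PIL.Image
    out = T.functional.affine(img, angle=5., translate=[4, 0], scale=1.05, shear=0.,
                              fillcolor=0, resample=PIL.Image.BILINEAR)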
Example #2 — inner(): inverse FFT image reconstruction (torch.irfft vs torch.fft.irfftn)
def inner(shift=None, contrast=1., *noargs, **nokwargs):
    scaled_spectrum_t = scale * spectrum_real_imag_t
    if shift is not None:
        scaled_spectrum_t += scale * shift
    if old_torch():
        image = torch.irfft(scaled_spectrum_t,
                            2,
                            normalized=True,
                            signal_sizes=(h, w))
    else:
        if not torch.is_complex(scaled_spectrum_t):
            scaled_spectrum_t = torch.view_as_complex(scaled_spectrum_t)
        image = torch.fft.irfftn(scaled_spectrum_t, s=(h, w), norm='ortho')
    image = image * contrast / image.std()  # keep contrast, empirical
    return image
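
The dtype guard before torch.view_as_complex() is needed because on PyTorch 1.8+ the spectrum may arrive either as a proper complex tensor or as its real view with a trailing dimension of 2, and view_as_complex() only accepts the latter. A small self-contained illustration of the two layouts:

import torch

spec_complex = torch.fft.rfftn(torch.rand(1, 3, 8, 8), s=(8, 8), norm='ortho')
spec_realview = torch.view_as_real(spec_complex)  # real dtype, shape [..., 2]

assert torch.is_complex(spec_complex)
assert not torch.is_complex(spec_realview)
assert torch.equal(torch.view_as_complex(spec_realview), spec_complex)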
Example #3 — img2fft(): forward FFT of an image into a real-valued spectrum
def img2fft(img_in, decay=1., colors=1.):
    image_t = un_rgb(img_in, colors=colors)
    h, w = image_t.shape[2], image_t.shape[3]

    with torch.no_grad():
        if old_torch():
            spectrum = torch.rfft(image_t, 2, normalized=True)  # 1.7
        else:
            spectrum = torch.fft.rfftn(image_t,
                                       s=(h, w),
                                       dim=[2, 3],
                                       norm='ortho')  # 1.8
            spectrum = torch.view_as_real(spectrum)
        spectrum = un_spectrum(spectrum, decay_power=decay)
        spectrum *= 500000.  # [sic!!!]
    return spectrum
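
Together with the previous example, the 1.8+ branches form an exact round trip: rfftn produces the spectrum that gets stored as a real view, and irfftn restores the image from it. A quick self-contained check (the repo's un_rgb / un_spectrum preprocessing and the 500000 scaling are omitted here):

import torch

image_t = torch.rand(1, 3, 64, 96)  # dummy NCHW image
h, w = image_t.shape[2], image_t.shape[3]

spectrum = torch.view_as_real(torch.fft.rfftn(image_t, s=(h, w), dim=[2, 3], norm='ortho'))
restored = torch.fft.irfftn(torch.view_as_complex(spectrum), s=(h, w), dim=[2, 3], norm='ortho')
assert torch.allclose(restored, image_t, atol=1e-4)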
Example #4 — process(): per-topic optimization loop with motion transforms and CLIP-guided losses
    def process(num):
        global params_tmp, opt_state, params, image_f, optimizer

        if a.interpol is True: # linear interpolation between topics
            txt_encs  = get_encs(key_txt_encs,  num)
            styl_encs = get_encs(key_styl_encs, num)
            img_encs  = get_encs(key_img_encs,  num)
        else: # change by cut
            txt_encs  = [key_txt_encs[min(num,  len(key_txt_encs)-1)][0]]  * steps if len(key_txt_encs)  > 0 else []
            styl_encs = [key_styl_encs[min(num, len(key_styl_encs)-1)][0]] * steps if len(key_styl_encs) > 0 else []
            img_encs  = [key_img_encs[min(num,  len(key_img_encs)-1)][0]]  * steps if len(key_img_encs)  > 0 else []
        
        if a.verbose is True: 
            if len(texts)  > 0: print(' ref text: ',  texts[min(num, len(texts)-1)][:80])
            if len(styles) > 0: print(' ref style: ', styles[min(num, len(styles)-1)][:80])
            if len(images) > 0: print(' ref image: ', basename(images[min(num, len(images)-1)])[:80])
        
        pbar = ProgressBar(steps)
        for ii in range(steps):
            glob_step = num * steps + ii # save/transform
            
            txt_enc  = txt_encs[ii % len(txt_encs)].unsqueeze(0)   if len(txt_encs)  > 0 else None
            styl_enc = styl_encs[ii % len(styl_encs)].unsqueeze(0) if len(styl_encs) > 0 else None
            img_enc  = img_encs[ii % len(img_encs)].unsqueeze(0)   if len(img_encs)  > 0 else None

            # MOTION: transform frame, reload params

            scale = m_scale[glob_step]    if a.anima else 1 + a.scale
            shift = m_shift[glob_step]    if a.anima else [0, a.shift]
            angle = m_angle[glob_step][0] if a.anima else a.angle
            shear = m_shear[glob_step][0] if a.anima else a.shear

            if a.gen == 'RGB':
                if a.depth > 0:
                    params_tmp = depth_transform(params_tmp, depth_infer, depth_mask, a.size, a.depth, scale, shift, a.colors, a.depth_dir, glob_step)
                params_tmp = frame_transform(params_tmp, a.size, angle, shift, scale, shear)
                params, image_f, _ = pixel_image([1, 3, *a.size], resume=params_tmp)
                img_tmp = None

            else: # FFT
                if old_torch(): # 1.7.1
                    img_tmp = torch.irfft(params_tmp, 2, normalized=True, signal_sizes=a.size)
                    if a.depth > 0:
                        img_tmp = depth_transform(img_tmp, depth_infer, depth_mask, a.size, a.depth, scale, shift, a.colors, a.depth_dir, glob_step)
                    img_tmp = frame_transform(img_tmp, a.size, angle, shift, scale, shear)
                    params_tmp = torch.rfft(img_tmp, 2, normalized=True)
                else: # 1.8+
                    if not torch.is_complex(params_tmp):
                        params_tmp = torch.view_as_complex(params_tmp)
                    img_tmp = torch.fft.irfftn(params_tmp, s=a.size, norm='ortho')
                    if a.depth > 0:
                        img_tmp = depth_transform(img_tmp, depth_infer, depth_mask, a.size, a.depth, scale, shift, a.colors, a.depth_dir, glob_step)
                    img_tmp = frame_transform(img_tmp, a.size, angle, shift, scale, shear)
                    params_tmp = torch.fft.rfftn(img_tmp, s=a.size, dim=[2,3], norm='ortho')
                    params_tmp = torch.view_as_real(params_tmp)
                params, image_f, _ = fft_image([1, 3, *a.size], sd=1, resume=params_tmp)

            if a.optimizer.lower() == 'adamw':
                optimizer = torch.optim.AdamW(params, a.lrate, weight_decay=0.01)
            elif a.optimizer.lower() == 'adamw_custom':
                optimizer = torch.optim.AdamW(params, a.lrate, weight_decay=0.01, betas=(.0,.999), amsgrad=True)
            elif a.optimizer.lower() == 'adam':
                optimizer = torch.optim.Adam(params, a.lrate)
            else: # adam_custom
                optimizer = torch.optim.Adam(params, a.lrate, betas=(.0,.999))
            image_f = to_valid_rgb(image_f, colors = a.colors)
            del img_tmp

            if a.smooth is True and num + ii > 0:
                optimizer.load_state_dict(opt_state)

            ### optimization
            for ss in range(a.opt_step):
                loss = 0

                noise = a.noise * (torch.rand(1, 1, a.size[0], a.size[1]//2+1, 1)-0.5).cuda() if a.noise>0 else 0.
                img_out = image_f(noise, fixcontrast=a.fixcontrast)
                
                img_sliced = slice_imgs([img_out], a.samples, a.modsize, trform_f, a.align, a.macro)[0]
                out_enc = model_clip.encode_image(img_sliced)

                if a.gen == 'RGB': # empirical hack
                    loss += abs(img_out.mean((2,3)) - 0.45).mean() # fix brightness
                    loss += abs(img_out.std((2,3)) - 0.17).mean() # fix contrast

                if txt_enc is not None:
                    loss -= a.invert * sim_func(txt_enc, out_enc, a.sim)
                if styl_enc is not None:
                    loss -= a.weight2 * sim_func(styl_enc, out_enc, a.sim)
                if img_enc is not None:
                    loss -= a.weight_img * sim_func(img_enc, out_enc, a.sim)
                if a.in_txt0 is not None: # subtract text
                    for anti_txt_enc in anti_txt_encs:
                        loss += 0.3 * sim_func(anti_txt_enc, out_enc, a.sim)
                if a.sharp != 0: # scharr|sobel|naive
                    loss -= a.sharp * derivat(img_out, mode='naive')
                if a.enforce != 0:
                    img_sliced = slice_imgs([image_f(noise, fixcontrast=a.fixcontrast)], a.samples, a.modsize, trform_f, a.align, a.macro)[0]
                    out_enc2 = model_clip.encode_image(img_sliced)
                    loss -= a.enforce * sim_func(out_enc, out_enc2, a.sim)
                    del out_enc2; torch.cuda.empty_cache()
                if a.expand > 0:
                    global prev_enc
                    if ii > 0:
                        loss += a.expand * sim_func(prev_enc, out_enc, a.sim)
                    prev_enc = out_enc.detach().clone()
                del img_out, img_sliced, out_enc; torch.cuda.empty_cache()

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            ### save params & frame

            params_tmp = params[0].detach().clone()
            if a.smooth is True:
                opt_state = optimizer.state_dict()

            with torch.no_grad():
                img_t = image_f(contrast=a.contrast, fixcontrast=a.fixcontrast)[0].permute(1,2,0)
                img_np = torch.clip(img_t*255, 0, 255).cpu().numpy().astype(np.uint8)
            imsave(os.path.join(tempdir, '%06d.jpg' % glob_step), img_np, quality=95)
            if a.verbose is True: cvshow(img_np)
            del img_t, img_np
            pbar.upd()

        params_tmp = params[0].detach().clone()
Example #5 — main(): animation pipeline interpolating text/style/image keyframes
def main():
    a = get_args()
    
    # Load CLIP models
    model_clip, _ = clip.load(a.model, jit=old_torch())
    try:
        a.modsize = model_clip.visual.input_resolution 
    except:
        a.modsize = 288 if a.model == 'RN50x4' else 384 if a.model == 'RN50x16' else 224
    if a.verbose is True: print(' using model', a.model)
    xmem = {'ViT-B/16':0.25, 'RN50':0.5, 'RN50x4':0.16, 'RN50x16':0.06, 'RN101':0.33}
    if a.model in xmem.keys():
        a.samples = int(a.samples * xmem[a.model])

    if a.translate:
        translator = Translator()

    if a.enforce != 0:
        a.samples = int(a.samples * 0.5)

    if 'elastic' in a.transform:
        trform_f = transforms.transforms_elastic
        a.samples = int(a.samples * 0.95)
    elif 'custom' in a.transform:
        trform_f = transforms.transforms_custom
        a.samples = int(a.samples * 0.95)
    elif 'fast' in a.transform:
        trform_f = transforms.transforms_fast
        a.samples = int(a.samples * 0.95)
    else:
        trform_f = transforms.normalize()

    def enc_text(txt):
        if a.translate:
            txt = translator.translate(txt, dest='en').text
        emb = model_clip.encode_text(clip.tokenize(txt).cuda()[:77])
        return emb.detach().clone()

    def enc_image(img_file):
        img_t = torch.from_numpy(img_read(img_file)/255.).unsqueeze(0).permute(0,3,1,2).cuda()[:,:3,:,:]
        in_sliced = slice_imgs([img_t], a.samples, a.modsize, transforms.normalize(), a.align)[0]
        emb = model_clip.encode_image(in_sliced)
        return emb.detach().clone()

    # Encode inputs
    count = 0
    texts = []
    styles = []
    images = []
    
    if a.in_txt is not None:
        if os.path.isfile(a.in_txt):
            with open(a.in_txt, 'r', encoding="utf-8") as f:
                texts = f.readlines()
                texts = [tt.strip() for tt in texts if len(tt.strip()) > 0 and tt[0] != '#']
        else:
            texts = [a.in_txt]
    if a.in_txt_pre is not None:
        texts = [' '.join([a.in_txt_pre, tt]).strip() for tt in texts]
    if a.in_txt_post is not None:
        texts = [' '.join([tt, a.in_txt_post]).strip() for tt in texts]
    key_txt_encs = [enc_text(txt) for txt in texts]
    count = max(count, len(key_txt_encs))

    if a.in_txt2 is not None:
        if os.path.isfile(a.in_txt2):
            with open(a.in_txt2, 'r', encoding="utf-8") as f:
                styles = f.readlines()
                styles = [tt.strip() for tt in styles if len(tt.strip()) > 0 and tt[0] != '#']
        else:
            styles = [a.in_txt2]
    key_styl_encs = [enc_text(style) for style in styles]
    count = max(count, len(key_styl_encs))

    if a.in_img is not None and os.path.exists(a.in_img):
        images = file_list(a.in_img) if os.path.isdir(a.in_img) else [a.in_img]
    key_img_encs = [enc_image(image) for image in images]
    count = max(count, len(key_img_encs))
    
    assert count > 0, "No inputs found!"
    
    if a.in_txt0 is not None:
        if a.verbose is True: print(' subtract text:', a.in_txt0)
        if a.translate:
            a.in_txt0 = translator.translate(a.in_txt0, dest='en').text
        anti_txt_encs = [enc_text(txt) for txt in a.in_txt0.split('.')]

    if a.verbose is True: print(' samples:', a.samples)

    global params_tmp
    shape = [1, 3, *a.size]

    if a.gen == 'RGB':
        params_tmp, _, sz = pixel_image(shape, a.resume)
        params_tmp = params_tmp[0].cuda().detach()
    else:
        params_tmp, sz = resume_fft(a.resume, shape, decay=1.5, sd=1)
    if sz is not None: a.size = sz

    if a.depth != 0:
        depth_infer, depth_mask = depth.init_adabins(size=a.size, model_path=a.depth_model, mask_path=a.depth_mask, tridepth=a.tridepth)
        if a.depth_dir is not None:
            os.makedirs(a.depth_dir, exist_ok=True)
            print(' depth dir:', a.depth_dir)

    steps = a.steps
    glob_steps = count * steps
    if glob_steps == a.fstep: a.fstep = glob_steps // 2 # otherwise no motion

    workname = basename(a.in_txt) if a.in_txt is not None else basename(a.in_img)
    workname = txt_clean(workname)
    workdir = os.path.join(a.out_dir, workname + '-%s' % a.gen.lower())
    if a.rem is not None:        workdir += '-%s' % a.rem
    if 'RN' in a.model.upper():  workdir += '-%s' % a.model
    tempdir = os.path.join(workdir, 'ttt')
    os.makedirs(tempdir, exist_ok=True)
    save_cfg(a, workdir)
    if a.in_txt is not None and os.path.isfile(a.in_txt):
        shutil.copy(a.in_txt, os.path.join(workdir, os.path.basename(a.in_txt)))
    if a.in_txt2 is not None and os.path.isfile(a.in_txt2):
        shutil.copy(a.in_txt2, os.path.join(workdir, os.path.basename(a.in_txt2)))

    midp = 0.5
    if a.anima:
        if a.gen == 'RGB': # zoom in
            m_scale = latent_anima([1], glob_steps, a.fstep, uniform=True, cubic=True, start_lat=[-0.3], verbose=False)
            m_scale = 1 + (m_scale + 0.3) * a.scale
        else:
            m_scale = latent_anima([1], glob_steps, a.fstep, uniform=True, cubic=True, start_lat=[0.6],  verbose=False)
            m_scale = 1 - (m_scale-0.6) * a.scale
        m_shift = latent_anima([2], glob_steps, a.fstep, uniform=True, cubic=True, start_lat=[midp,midp], verbose=False)
        m_angle = latent_anima([1], glob_steps, a.fstep, uniform=True, cubic=True, start_lat=[midp],    verbose=False)
        m_shear = latent_anima([1], glob_steps, a.fstep, uniform=True, cubic=True, start_lat=[midp],    verbose=False)
        m_shift = (midp-m_shift) * a.shift * abs(m_scale-1) / a.scale
        m_angle = (midp-m_angle) * a.angle * abs(m_scale-1) / a.scale
        m_shear = (midp-m_shear) * a.shear * abs(m_scale-1) / a.scale
    
    def get_encs(encs, num):
        cnt = len(encs)
        if cnt == 0: return []
        enc_1 = encs[min(num,   cnt-1)]
        enc_2 = encs[min(num+1, cnt-1)]
        return slerp(enc_1, enc_2, steps)

    prev_enc = 0
    def process(num):
        global params_tmp, opt_state, params, image_f, optimizer

        if a.interpol is True: # linear interpolation between topics
            txt_encs  = get_encs(key_txt_encs,  num)
            styl_encs = get_encs(key_styl_encs, num)
            img_encs  = get_encs(key_img_encs,  num)
        else: # change by cut
            txt_encs  = [key_txt_encs[min(num,  len(key_txt_encs)-1)][0]]  * steps if len(key_txt_encs)  > 0 else []
            styl_encs = [key_styl_encs[min(num, len(key_styl_encs)-1)][0]] * steps if len(key_styl_encs) > 0 else []
            img_encs  = [key_img_encs[min(num,  len(key_img_encs)-1)][0]]  * steps if len(key_img_encs)  > 0 else []
        
        if a.verbose is True: 
            if len(texts)  > 0: print(' ref text: ',  texts[min(num, len(texts)-1)][:80])
            if len(styles) > 0: print(' ref style: ', styles[min(num, len(styles)-1)][:80])
            if len(images) > 0: print(' ref image: ', basename(images[min(num, len(images)-1)])[:80])
        
        pbar = ProgressBar(steps)
        for ii in range(steps):
            glob_step = num * steps + ii # save/transform
            
            txt_enc  = txt_encs[ii % len(txt_encs)].unsqueeze(0)   if len(txt_encs)  > 0 else None
            styl_enc = styl_encs[ii % len(styl_encs)].unsqueeze(0) if len(styl_encs) > 0 else None
            img_enc  = img_encs[ii % len(img_encs)].unsqueeze(0)   if len(img_encs)  > 0 else None

            # MOTION: transform frame, reload params

            scale = m_scale[glob_step]    if a.anima else 1 + a.scale
            shift = m_shift[glob_step]    if a.anima else [0, a.shift]
            angle = m_angle[glob_step][0] if a.anima else a.angle
            shear = m_shear[glob_step][0] if a.anima else a.shear

            if a.gen == 'RGB':
                if a.depth > 0:
                    params_tmp = depth_transform(params_tmp, depth_infer, depth_mask, a.size, a.depth, scale, shift, a.colors, a.depth_dir, glob_step)
                params_tmp = frame_transform(params_tmp, a.size, angle, shift, scale, shear)
                params, image_f, _ = pixel_image([1, 3, *a.size], resume=params_tmp)
                img_tmp = None

            else: # FFT
                if old_torch(): # 1.7.1
                    img_tmp = torch.irfft(params_tmp, 2, normalized=True, signal_sizes=a.size)
                    if a.depth > 0:
                        img_tmp = depth_transform(img_tmp, depth_infer, depth_mask, a.size, a.depth, scale, shift, a.colors, a.depth_dir, glob_step)
                    img_tmp = frame_transform(img_tmp, a.size, angle, shift, scale, shear)
                    params_tmp = torch.rfft(img_tmp, 2, normalized=True)
                else: # 1.8+
                    if not torch.is_complex(params_tmp):
                        params_tmp = torch.view_as_complex(params_tmp)
                    img_tmp = torch.fft.irfftn(params_tmp, s=a.size, norm='ortho')
                    if a.depth > 0:
                        img_tmp = depth_transform(img_tmp, depth_infer, depth_mask, a.size, a.depth, scale, shift, a.colors, a.depth_dir, glob_step)
                    img_tmp = frame_transform(img_tmp, a.size, angle, shift, scale, shear)
                    params_tmp = torch.fft.rfftn(img_tmp, s=a.size, dim=[2,3], norm='ortho')
                    params_tmp = torch.view_as_real(params_tmp)
                params, image_f, _ = fft_image([1, 3, *a.size], sd=1, resume=params_tmp)

            if a.optimizer.lower() == 'adamw':
                optimizer = torch.optim.AdamW(params, a.lrate, weight_decay=0.01)
            elif a.optimizer.lower() == 'adamw_custom':
                optimizer = torch.optim.AdamW(params, a.lrate, weight_decay=0.01, betas=(.0,.999), amsgrad=True)
            elif a.optimizer.lower() == 'adam':
                optimizer = torch.optim.Adam(params, a.lrate)
            else: # adam_custom
                optimizer = torch.optim.Adam(params, a.lrate, betas=(.0,.999))
            image_f = to_valid_rgb(image_f, colors = a.colors)
            del img_tmp

            if a.smooth is True and num + ii > 0:
                optimizer.load_state_dict(opt_state)

            ### optimization
            for ss in range(a.opt_step):
                loss = 0

                noise = a.noise * (torch.rand(1, 1, a.size[0], a.size[1]//2+1, 1)-0.5).cuda() if a.noise>0 else 0.
                img_out = image_f(noise, fixcontrast=a.fixcontrast)
                
                img_sliced = slice_imgs([img_out], a.samples, a.modsize, trform_f, a.align, a.macro)[0]
                out_enc = model_clip.encode_image(img_sliced)

                if a.gen == 'RGB': # empirical hack
                    loss += abs(img_out.mean((2,3)) - 0.45).mean() # fix brightness
                    loss += abs(img_out.std((2,3)) - 0.17).mean() # fix contrast

                if txt_enc is not None:
                    loss -= a.invert * sim_func(txt_enc, out_enc, a.sim)
                if styl_enc is not None:
                    loss -= a.weight2 * sim_func(styl_enc, out_enc, a.sim)
                if img_enc is not None:
                    loss -= a.weight_img * sim_func(img_enc, out_enc, a.sim)
                if a.in_txt0 is not None: # subtract text
                    for anti_txt_enc in anti_txt_encs:
                        loss += 0.3 * sim_func(anti_txt_enc, out_enc, a.sim)
                if a.sharp != 0: # scharr|sobel|naive
                    loss -= a.sharp * derivat(img_out, mode='naive')
                if a.enforce != 0:
                    img_sliced = slice_imgs([image_f(noise, fixcontrast=a.fixcontrast)], a.samples, a.modsize, trform_f, a.align, a.macro)[0]
                    out_enc2 = model_clip.encode_image(img_sliced)
                    loss -= a.enforce * sim_func(out_enc, out_enc2, a.sim)
                    del out_enc2; torch.cuda.empty_cache()
                if a.expand > 0:
                    global prev_enc
                    if ii > 0:
                        loss += a.expand * sim_func(prev_enc, out_enc, a.sim)
                    prev_enc = out_enc.detach().clone()
                del img_out, img_sliced, out_enc; torch.cuda.empty_cache()

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            ### save params & frame

            params_tmp = params[0].detach().clone()
            if a.smooth is True:
                opt_state = optimizer.state_dict()

            with torch.no_grad():
                img_t = image_f(contrast=a.contrast, fixcontrast=a.fixcontrast)[0].permute(1,2,0)
                img_np = torch.clip(img_t*255, 0, 255).cpu().numpy().astype(np.uint8)
            imsave(os.path.join(tempdir, '%06d.jpg' % glob_step), img_np, quality=95)
            if a.verbose is True: cvshow(img_np)
            del img_t, img_np
            pbar.upd()

        params_tmp = params[0].detach().clone()
        
    glob_start = time.time()
    try:
        for i in range(count):
            process(i)
    except KeyboardInterrupt:
        pass

    os.system('ffmpeg -v warning -y -i %s/%%06d.jpg "%s.mp4"' % (tempdir, os.path.join(workdir, workname)))
Example #6 — main(): single-image CLIP-guided generation (FFT or DWT parameterization)
def main():
    a = get_args()

    shape = [1, 3, *a.size]
    if a.dwt is True:
        params, image_f, sz = dwt_image(shape, a.wave, 0.3, a.colors, a.resume)
    else:
        params, image_f, sz = fft_image(shape, 0.07, a.decay, a.resume)
    if sz is not None: a.size = sz
    image_f = to_valid_rgb(image_f, colors=a.colors)

    if a.prog is True:
        lr1 = a.lrate * 2
        lr0 = lr1 * 0.01
    else:
        lr0 = a.lrate
    if a.optimizer.lower() == 'adamw':
        optimizer = torch.optim.AdamW(params, lr0, weight_decay=0.01)
    elif a.optimizer.lower() == 'adamw_custom':
        optimizer = torch.optim.AdamW(params,
                                      lr0,
                                      weight_decay=0.01,
                                      betas=(.0, .999),
                                      amsgrad=True)
    elif a.optimizer.lower() == 'adam':
        optimizer = torch.optim.Adam(params, lr0)
    else:  # adam_custom
        optimizer = torch.optim.Adam(params, lr0, betas=(.0, .999))
    sign = 1. if a.invert is True else -1.

    # Load CLIP models
    model_clip, _ = clip.load(a.model, jit=old_torch())
    try:
        a.modsize = model_clip.visual.input_resolution
    except:
        a.modsize = 288 if a.model == 'RN50x4' else 384 if a.model == 'RN50x16' else 224
    if a.verbose is True: print(' using model', a.model)
    xmem = {
        'ViT-B/16': 0.25,
        'RN50': 0.5,
        'RN50x4': 0.16,
        'RN50x16': 0.06,
        'RN101': 0.33
    }
    if a.model in xmem.keys():
        a.samples = int(a.samples * xmem[a.model])

    if a.multilang is True:
        model_lang = SentenceTransformer(
            'clip-ViT-B-32-multilingual-v1').cuda()

    def enc_text(txt):
        if a.multilang is True:
            emb = model_lang.encode([txt],
                                    convert_to_tensor=True,
                                    show_progress_bar=False)
        else:
            emb = model_clip.encode_text(clip.tokenize(txt).cuda())
        return emb.detach().clone()

    if a.enforce != 0:
        a.samples = int(a.samples * 0.5)
    if a.sync > 0:
        a.samples = int(a.samples * 0.5)

    if 'elastic' in a.transform:
        trform_f = transforms.transforms_elastic
        a.samples = int(a.samples * 0.95)
    elif 'custom' in a.transform:
        trform_f = transforms.transforms_custom
        a.samples = int(a.samples * 0.95)
    elif 'fast' in a.transform:
        trform_f = transforms.transforms_fast
        a.samples = int(a.samples * 0.95)
    else:
        trform_f = transforms.normalize()

    out_name = []
    if a.in_txt is not None:
        if a.verbose is True: print(' topic text: ', a.in_txt)
        if a.translate:
            translator = Translator()
            a.in_txt = translator.translate(a.in_txt, dest='en').text
            if a.verbose is True: print(' translated to:', a.in_txt)
        txt_enc = enc_text(a.in_txt)
        out_name.append(txt_clean(a.in_txt).lower()[:40])

        if a.notext > 0:
            txt_plot = torch.from_numpy(plot_text(a.in_txt, a.modsize) /
                                        255.).unsqueeze(0).permute(0, 3, 1,
                                                                   2).cuda()
            txt_plot_enc = model_clip.encode_image(txt_plot).detach().clone()

    if a.in_txt2 is not None:
        if a.verbose is True: print(' style text:', a.in_txt2)
        a.samples = int(a.samples * 0.75)
        if a.translate:
            translator = Translator()
            a.in_txt2 = translator.translate(a.in_txt2, dest='en').text
            if a.verbose is True: print(' translated to:', a.in_txt2)
        txt_enc2 = enc_text(a.in_txt2)
        out_name.append(txt_clean(a.in_txt2).lower()[:40])

    if a.in_txt0 is not None:
        if a.verbose is True: print(' subtract text:', a.in_txt0)
        a.samples = int(a.samples * 0.75)
        if a.translate:
            translator = Translator()
            a.in_txt0 = translator.translate(a.in_txt0, dest='en').text
            if a.verbose is True: print(' translated to:', a.in_txt0)
        txt_enc0 = enc_text(a.in_txt0)
        out_name.append('off-' + txt_clean(a.in_txt0).lower()[:40])

    if a.multilang is True: del model_lang

    if a.in_img is not None and os.path.isfile(a.in_img):
        if a.verbose is True: print(' ref image:', basename(a.in_img))
        img_in = torch.from_numpy(
            img_read(a.in_img) / 255.).unsqueeze(0).permute(0, 3, 1, 2).cuda()
        img_in = img_in[:, :3, :, :]  # fix rgb channels
        in_sliced = slice_imgs([img_in], a.samples, a.modsize,
                               transforms.normalize(), a.align)[0]
        img_enc = model_clip.encode_image(in_sliced).detach().clone()
        if a.sync > 0:
            sim_loss = lpips.LPIPS(net='vgg', verbose=False).cuda()
            sim_size = [s // 2 for s in a.size]
            img_in = F.interpolate(img_in,
                                   sim_size,
                                   mode='bicubic',
                                   align_corners=True).float()
        else:
            del img_in
        del in_sliced
        torch.cuda.empty_cache()
        out_name.append(basename(a.in_img).replace(' ', '_'))

    if a.verbose is True: print(' samples:', a.samples)
    out_name = '-'.join(out_name)
    out_name += '-%s' % a.model if 'RN' in a.model.upper() else ''
    tempdir = os.path.join(a.out_dir, out_name)
    os.makedirs(tempdir, exist_ok=True)

    prev_enc = 0

    def train(i):
        loss = 0

        noise = a.noise * torch.rand(1, 1, *params[0].shape[2:4],
                                     1).cuda() if a.noise > 0 else None
        img_out = image_f(noise)
        img_sliced = slice_imgs([img_out], a.samples, a.modsize, trform_f,
                                a.align, a.macro)[0]
        out_enc = model_clip.encode_image(img_sliced)

        if a.in_txt is not None:  # input text
            loss += sign * sim_func(txt_enc, out_enc, a.sim)
            if a.notext > 0:
                loss -= sign * a.notext * sim_func(txt_plot_enc, out_enc,
                                                   a.sim)
        if a.in_txt2 is not None:  # input text - style
            loss += sign * a.weight2 * sim_func(txt_enc2, out_enc, a.sim)
        if a.in_txt0 is not None:  # subtract text
            loss += -sign * 0.3 * sim_func(txt_enc0, out_enc, a.sim)
        if a.in_img is not None and os.path.isfile(a.in_img):  # input image
            loss += sign * 0.5 * sim_func(img_enc, out_enc, a.sim)
        if a.sync > 0 and a.in_img is not None and os.path.isfile(
                a.in_img):  # image composition
            prog_sync = (a.steps // a.opt_step - i) / (a.steps // a.opt_step)
            loss += prog_sync * a.sync * sim_loss(F.interpolate(
                img_out, sim_size, mode='bicubic', align_corners=True).float(),
                                                  img_in,
                                                  normalize=True).squeeze()
        if a.sharp != 0 and a.dwt is not True:  # scharr|sobel|default
            loss -= a.sharp * derivat(img_out, mode='naiv')
            # loss -= a.sharp * derivat(img_sliced, mode='scharr')
        if a.enforce != 0:
            img_sliced = slice_imgs([image_f(noise)], a.samples, a.modsize,
                                    trform_f, a.align, a.macro)[0]
            out_enc2 = model_clip.encode_image(img_sliced)
            loss -= a.enforce * sim_func(out_enc, out_enc2, a.sim)
            del out_enc2
            torch.cuda.empty_cache()
        if a.expand > 0:
            global prev_enc
            if i > 0:
                loss += a.expand * sim_func(out_enc, prev_enc, a.sim)
            prev_enc = out_enc.detach()  # .clone()

        del img_out, img_sliced, out_enc
        torch.cuda.empty_cache()
        assert not isinstance(loss, int), ' Loss not defined, check the inputs'

        if a.prog is True:
            lr_cur = lr0 + (i / a.steps) * (lr1 - lr0)
            for g in optimizer.param_groups:
                g['lr'] = lr_cur

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % a.opt_step == 0:
            with torch.no_grad():
                img = image_f(contrast=a.contrast).cpu().numpy()[0]
            # empirical tone mapping
            if (a.sync > 0 and a.in_img is not None):
                img = img**1.3
            elif a.sharp != 0:
                img = img**(1 + a.sharp / 2.)
            checkout(img,
                     os.path.join(tempdir, '%04d.jpg' % (i // a.opt_step)),
                     verbose=a.verbose)
            pbar.upd()

    pbar = ProgressBar(a.steps // a.opt_step)
    for i in range(a.steps):
        train(i)

    os.system('ffmpeg -v warning -y -i %s/%%04d.jpg "%s.mp4"' %
              (tempdir, os.path.join(a.out_dir, out_name)))
    shutil.copy(
        img_list(tempdir)[-1],
        os.path.join(a.out_dir, '%s-%d.jpg' % (out_name, a.steps)))
    if a.save_pt is True:
        torch.save(params, '%s.pt' % os.path.join(a.out_dir, out_name))
Example #7 — main(): batch generation over lines of a text file, then an interpolated video render
def main():
    a = get_args()

    # Load CLIP models
    model_clip, _ = clip.load(a.model, jit=old_torch())
    try:
        a.modsize = model_clip.visual.input_resolution
    except:
        a.modsize = 288 if a.model == 'RN50x4' else 384 if a.model == 'RN50x16' else 224
    if a.verbose is True: print(' using model', a.model)
    xmem = {
        'ViT-B/16': 0.25,
        'RN50': 0.5,
        'RN50x4': 0.16,
        'RN50x16': 0.06,
        'RN101': 0.33
    }
    if a.model in xmem.keys():
        a.samples = int(a.samples * xmem[a.model])
    workdir = os.path.join(a.out_dir, basename(a.in_txt))
    workdir += '-%s' % a.model if 'RN' in a.model.upper() else ''
    os.makedirs(workdir, exist_ok=True)

    def enc_text(txt):
        if a.multilang is True:
            model_lang = SentenceTransformer(
                'clip-ViT-B-32-multilingual-v1').cuda()
            emb = model_lang.encode([txt],
                                    convert_to_tensor=True,
                                    show_progress_bar=False)
            del model_lang
        else:
            emb = model_clip.encode_text(clip.tokenize(txt).cuda())
        return emb.detach().clone()

    if a.enforce != 0:
        a.samples = int(a.samples * 0.5)

    if 'elastic' in a.transform:
        trform_f = transforms.transforms_elastic
        a.samples = int(a.samples * 0.95)
    elif 'custom' in a.transform:
        trform_f = transforms.transforms_custom
        a.samples = int(a.samples * 0.95)
    elif 'fast' in a.transform:
        trform_f = transforms.transforms_fast
        a.samples = int(a.samples * 0.95)
    else:
        trform_f = transforms.normalize()

    if a.in_txt2 is not None:
        if a.verbose is True: print(' style:', basename(a.in_txt2))
        # a.samples = int(a.samples * 0.75)
        if a.translate:
            translator = Translator()
            a.in_txt2 = translator.translate(a.in_txt2, dest='en').text
            if a.verbose is True: print(' translated to:', a.in_txt2)
        txt_enc2 = enc_text(a.in_txt2)

    if a.in_txt0 is not None:
        if a.verbose is True: print(' subtract text:', basename(a.in_txt0))
        if a.translate:
            translator = Translator()
            a.in_txt0 = translator.translate(a.in_txt0, dest='en').text
            if a.verbose is True: print(' translated to:', a.in_txt0)
        txt_enc0 = enc_text(a.in_txt0)

    # make init
    global params_start, params_ema
    params_shape = [1, 3, a.size[0], a.size[1] // 2 + 1, 2]
    params_start = torch.randn(*params_shape).cuda()  # random init
    params_ema = 0.
    if a.resume is not None and os.path.isfile(a.resume):
        if a.verbose is True: print(' resuming from', a.resume)
        params_start = load_params(a.resume).cuda()
        if a.keep > 0:
            params_ema = params_start[0].detach().clone()
    else:
        a.resume = 'init.pt'

    torch.save(params_start, 'init.pt')  # final init
    shutil.copy(a.resume,
                os.path.join(workdir, '000-%s.pt' % basename(a.resume)))

    prev_enc = 0

    def process(txt, num):

        sd = 0.01
        if a.keep > 0: sd = a.keep + (1 - a.keep) * sd
        params, image_f, _ = fft_image([1, 3, *a.size],
                                       resume='init.pt',
                                       sd=sd,
                                       decay_power=a.decay)
        image_f = to_valid_rgb(image_f, colors=a.colors)

        if a.prog is True:
            lr1 = a.lrate * 2
            lr0 = a.lrate * 0.1
        else:
            lr0 = a.lrate
        optimizer = torch.optim.AdamW(params,
                                      lr0,
                                      weight_decay=0.01,
                                      amsgrad=True)

        if a.verbose is True: print(' topic: ', txt)
        if a.translate:
            translator = Translator()
            txt = translator.translate(txt, dest='en').text
            if a.verbose is True: print(' translated to:', txt)
        txt_enc = enc_text(txt)
        if a.notext > 0:
            txt_plot = torch.from_numpy(plot_text(txt, a.modsize) /
                                        255.).unsqueeze(0).permute(0, 3, 1,
                                                                   2).cuda()
            txt_plot_enc = model_clip.encode_image(txt_plot).detach().clone()
        else:
            txt_plot_enc = None

        out_name = '%03d-%s' % (num + 1, txt_clean(txt))
        out_name += '-%s' % a.model if 'RN' in a.model.upper() else ''
        tempdir = os.path.join(workdir, out_name)
        os.makedirs(tempdir, exist_ok=True)

        pbar = ProgressBar(a.steps // a.fstep)
        for i in range(a.steps):
            loss = 0
            noise = a.noise * torch.randn(1, 1, *params[0].shape[2:4],
                                          1).cuda() if a.noise > 0 else None
            img_out = image_f(noise)
            img_sliced = slice_imgs([img_out],
                                    a.samples,
                                    a.modsize,
                                    trform_f,
                                    a.align,
                                    macro=a.macro)[0]
            out_enc = model_clip.encode_image(img_sliced)

            loss -= torch.cosine_similarity(txt_enc, out_enc, dim=-1).mean()
            if a.in_txt2 is not None:  # input text - style
                loss -= 0.5 * torch.cosine_similarity(
                    txt_enc2, out_enc, dim=-1).mean()
            if a.in_txt0 is not None:  # subtract text
                loss += 0.5 * torch.cosine_similarity(
                    txt_enc0, out_enc, dim=-1).mean()
            if a.notext > 0:
                loss += a.notext * torch.cosine_similarity(
                    txt_plot_enc, out_enc, dim=-1).mean()
            if a.sharp != 0:  # mode = scharr|sobel|default
                loss -= a.sharp * derivat(img_out, mode='sobel')
                # loss -= a.sharp * derivat(img_sliced, mode='scharr')
            if a.enforce != 0:
                img_sliced = slice_imgs([image_f(noise)],
                                        a.samples,
                                        a.modsize,
                                        trform_f,
                                        a.align,
                                        macro=a.macro)[0]
                out_enc2 = model_clip.encode_image(img_sliced)
                loss -= a.enforce * torch.cosine_similarity(
                    out_enc, out_enc2, dim=-1).mean()
                del out_enc2
                torch.cuda.empty_cache()
            if a.expand > 0:
                global prev_enc
                if i > 0:
                    loss += a.expand * torch.cosine_similarity(
                        out_enc, prev_enc, dim=-1).mean()
                prev_enc = out_enc.detach().clone()
            del img_out, img_sliced, out_enc
            torch.cuda.empty_cache()

            if a.prog is True:
                lr_cur = lr0 + (i / a.steps) * (lr1 - lr0)
                for g in optimizer.param_groups:
                    g['lr'] = lr_cur

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if i % a.fstep == 0:
                with torch.no_grad():
                    img = image_f(contrast=a.contrast).cpu().numpy()[0]
                if a.sharp != 0:
                    img = img**(1 + a.sharp / 2.)  # empirical tone mapping
                checkout(img,
                         os.path.join(tempdir, '%04d.jpg' % (i // a.fstep)),
                         verbose=a.verbose)
                pbar.upd()
                del img

        if a.keep > 0:
            global params_start, params_ema
            params_ema = ema(params_ema, params[0].detach().clone(), num + 1)
            torch.save((1 - a.keep) * params_start + a.keep * params_ema,
                       'init.pt')

        torch.save(params[0], '%s.pt' % os.path.join(workdir, out_name))
        shutil.copy(
            img_list(tempdir)[-1],
            os.path.join(workdir, '%s-%d.jpg' % (out_name, a.steps)))
        os.system('ffmpeg -v warning -y -i %s/%%04d.jpg "%s.mp4"' %
                  (tempdir, os.path.join(workdir, out_name)))

    with open(a.in_txt, 'r', encoding="utf-8") as f:
        texts = f.readlines()
        texts = [
            tt.strip() for tt in texts if len(tt.strip()) > 0 and tt[0] != '#'
        ]
    if a.verbose is True:
        print(' total lines:', len(texts))
        print(' samples:', a.samples)

    for i, txt in enumerate(texts):
        process(txt, i)

    vsteps = int(a.length * 25 / len(texts))  # 25 fps
    tempdir = os.path.join(workdir, '_final')
    os.makedirs(tempdir, exist_ok=True)

    def read_pt(file):
        return torch.load(file).cuda()

    if a.verbose is True: print(' rendering complete piece')
    ptfiles = file_list(workdir, 'pt')
    pbar = ProgressBar(vsteps * len(ptfiles))
    for px in range(len(ptfiles)):
        params1 = read_pt(ptfiles[px])
        params2 = read_pt(ptfiles[(px + 1) % len(ptfiles)])

        params, image_f, _ = fft_image([1, 3, *a.size],
                                       resume=params1,
                                       sd=1.,
                                       decay_power=a.decay)
        image_f = to_valid_rgb(image_f, colors=a.colors)

        for i in range(vsteps):
            with torch.no_grad():
                img = image_f(
                    (params2 - params1) *
                    math.sin(1.5708 * i / vsteps)**2)[0].permute(1, 2, 0)
                img = torch.clip(img * 255, 0,
                                 255).cpu().numpy().astype(np.uint8)
            imsave(os.path.join(tempdir, '%05d.jpg' % (px * vsteps + i)), img)
            if a.verbose is True: cvshow(img)
            pbar.upd()

    os.system('ffmpeg -v warning -y -i %s/%%05d.jpg "%s.mp4"' %
              (tempdir, os.path.join(a.out_dir, basename(a.in_txt))))
    if a.keep > 0: os.remove('init.pt')
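
For reference, the blend weight in the final rendering loop above, sin(1.5708 * i / vsteps) ** 2, uses 1.5708 ≈ pi/2, so each segment ramps smoothly from the current saved params toward the next (the loop stops one frame short of 1, where the next segment takes over). A quick check of the ramp with an illustrative vsteps:

import math

vsteps = 4  # illustrative value
print([round(math.sin(1.5708 * i / vsteps) ** 2, 3) for i in range(vsteps + 1)])
# [0.0, 0.146, 0.5, 0.854, 1.0]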