def un_rgb(image, colors=1.):
    color_correlation_svd_sqrt = torch.tensor([[0.26, 0.09, 0.02], [0.27, 0.00, -0.05], [0.27, -0.09, 0.03]])
    color_correlation_svd_sqrt /= torch.tensor([colors, 1., 1.])  # saturate, empirical
    max_norm_svd_sqrt = color_correlation_svd_sqrt.norm(dim=0).max()
    color_correlation_normalized = color_correlation_svd_sqrt / max_norm_svd_sqrt
    colcorr_t = color_correlation_normalized.T.cuda()
    colcorr_t_inv = torch.linalg.inv(colcorr_t)
    if not isinstance(image, torch.Tensor):  # numpy int array [0..255]
        image = torch.Tensor(image).cuda().permute(2, 0, 1).unsqueeze(0) / 255.
    # image = inv_sigmoid(image)
    image = normalize()(image)  # experimental
    return torch.einsum('nchw,cd->ndhw', image, colcorr_t_inv)  # edit by katherine crowson
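# Usage sketch (not part of the original script): un_rgb() maps an [0..255] RGB array into the
# decorrelated color space via the inverse color-correlation matrix. Names below are hypothetical
# and assume a CUDA device plus the repo's normalize() transform:
#   example_img = np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8)   # H,W,C uint8 image
#   decorrelated = un_rgb(example_img, colors=1.)                            # -> [1, 3, 256, 256] CUDA tensor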
def enc_image(img_file):
    img_t = torch.from_numpy(img_read(img_file)/255.).unsqueeze(0).permute(0,3,1,2).cuda()[:,:3,:,:]
    in_sliced = slice_imgs([img_t], a.samples, a.modsize, transforms.normalize(), a.align)[0]
    emb = model_clip.encode_image(in_sliced)
    return emb.detach().clone()
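# Reference sketch (illustrative only, not this script's pipeline): a comparable CLIP image
# embedding can be computed with the stock CLIP preprocessing instead of slice_imgs(); the file
# name is hypothetical:
#   import clip
#   from PIL import Image
#   model, preprocess = clip.load('ViT-B/32')
#   img = preprocess(Image.open('ref.jpg')).unsqueeze(0).cuda()   # resize + center-crop + normalize
#   emb = model.encode_image(img).detach()                        # [1, 512] embedding for ViT-B/32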
def main():
    a = get_args()

    # Load CLIP models
    model_clip, _ = clip.load(a.model, jit=old_torch())
    try:
        a.modsize = model_clip.visual.input_resolution
    except:
        a.modsize = 288 if a.model == 'RN50x4' else 384 if a.model == 'RN50x16' else 224
    if a.verbose is True: print(' using model', a.model)
    xmem = {'ViT-B/16':0.25, 'RN50':0.5, 'RN50x4':0.16, 'RN50x16':0.06, 'RN101':0.33}
    if a.model in xmem.keys():
        a.samples = int(a.samples * xmem[a.model])

    if a.translate:
        translator = Translator()

    if a.enforce != 0:
        a.samples = int(a.samples * 0.5)

    if 'elastic' in a.transform:
        trform_f = transforms.transforms_elastic
        a.samples = int(a.samples * 0.95)
    elif 'custom' in a.transform:
        trform_f = transforms.transforms_custom
        a.samples = int(a.samples * 0.95)
    elif 'fast' in a.transform:
        trform_f = transforms.transforms_fast
        a.samples = int(a.samples * 0.95)
    else:
        trform_f = transforms.normalize()

    def enc_text(txt):
        if a.translate:
            txt = translator.translate(txt, dest='en').text
        emb = model_clip.encode_text(clip.tokenize(txt).cuda()[:77])
        return emb.detach().clone()

    def enc_image(img_file):
        img_t = torch.from_numpy(img_read(img_file)/255.).unsqueeze(0).permute(0,3,1,2).cuda()[:,:3,:,:]
        in_sliced = slice_imgs([img_t], a.samples, a.modsize, transforms.normalize(), a.align)[0]
        emb = model_clip.encode_image(in_sliced)
        return emb.detach().clone()

    # Encode inputs
    count = 0
    texts = []
    styles = []
    images = []

    if a.in_txt is not None:
        if os.path.isfile(a.in_txt):
            with open(a.in_txt, 'r', encoding="utf-8") as f:
                texts = f.readlines()
            texts = [tt.strip() for tt in texts if len(tt.strip()) > 0 and tt[0] != '#']
        else:
            texts = [a.in_txt]
    if a.in_txt_pre is not None:
        texts = [' '.join([a.in_txt_pre, tt]).strip() for tt in texts]
    if a.in_txt_post is not None:
        texts = [' '.join([tt, a.in_txt_post]).strip() for tt in texts]
    key_txt_encs = [enc_text(txt) for txt in texts]
    count = max(count, len(key_txt_encs))

    if a.in_txt2 is not None:
        if os.path.isfile(a.in_txt2):
            with open(a.in_txt2, 'r', encoding="utf-8") as f:
                styles = f.readlines()
            styles = [tt.strip() for tt in styles if len(tt.strip()) > 0 and tt[0] != '#']
        else:
            styles = [a.in_txt2]
    key_styl_encs = [enc_text(style) for style in styles]
    count = max(count, len(key_styl_encs))

    if a.in_img is not None and os.path.exists(a.in_img):
        images = file_list(a.in_img) if os.path.isdir(a.in_img) else [a.in_img]
    key_img_encs = [enc_image(image) for image in images]
    count = max(count, len(key_img_encs))

    assert count > 0, "No inputs found!"
    if a.in_txt0 is not None:
        if a.verbose is True: print(' subtract text:', a.in_txt0)
        if a.translate:
            a.in_txt0 = translator.translate(a.in_txt0, dest='en').text
        anti_txt_encs = [enc_text(txt) for txt in a.in_txt0.split('.')]

    if a.verbose is True: print(' samples:', a.samples)

    global params_tmp
    shape = [1, 3, *a.size]

    if a.gen == 'RGB':
        params_tmp, _, sz = pixel_image(shape, a.resume)
        params_tmp = params_tmp[0].cuda().detach()
    else:
        params_tmp, sz = resume_fft(a.resume, shape, decay=1.5, sd=1)
    if sz is not None: a.size = sz

    if a.depth != 0:
        depth_infer, depth_mask = depth.init_adabins(size=a.size, model_path=a.depth_model, mask_path=a.depth_mask, tridepth=a.tridepth)
        if a.depth_dir is not None:
            os.makedirs(a.depth_dir, exist_ok=True)
            print(' depth dir:', a.depth_dir)

    steps = a.steps
    glob_steps = count * steps
    if glob_steps == a.fstep: a.fstep = glob_steps // 2  # otherwise no motion

    workname = basename(a.in_txt) if a.in_txt is not None else basename(a.in_img)
    workname = txt_clean(workname)
    workdir = os.path.join(a.out_dir, workname + '-%s' % a.gen.lower())
    if a.rem is not None: workdir += '-%s' % a.rem
    if 'RN' in a.model.upper(): workdir += '-%s' % a.model
    tempdir = os.path.join(workdir, 'ttt')
    os.makedirs(tempdir, exist_ok=True)
    save_cfg(a, workdir)
    if a.in_txt is not None and os.path.isfile(a.in_txt):
        shutil.copy(a.in_txt, os.path.join(workdir, os.path.basename(a.in_txt)))
    if a.in_txt2 is not None and os.path.isfile(a.in_txt2):
        shutil.copy(a.in_txt2, os.path.join(workdir, os.path.basename(a.in_txt2)))

    midp = 0.5
    if a.anima:
        if a.gen == 'RGB':  # zoom in
            m_scale = latent_anima([1], glob_steps, a.fstep, uniform=True, cubic=True, start_lat=[-0.3], verbose=False)
            m_scale = 1 + (m_scale + 0.3) * a.scale
        else:
            m_scale = latent_anima([1], glob_steps, a.fstep, uniform=True, cubic=True, start_lat=[0.6], verbose=False)
            m_scale = 1 - (m_scale-0.6) * a.scale
        m_shift = latent_anima([2], glob_steps, a.fstep, uniform=True, cubic=True, start_lat=[midp,midp], verbose=False)
        m_angle = latent_anima([1], glob_steps, a.fstep, uniform=True, cubic=True, start_lat=[midp], verbose=False)
        m_shear = latent_anima([1], glob_steps, a.fstep, uniform=True, cubic=True, start_lat=[midp], verbose=False)
        m_shift = (midp-m_shift) * a.shift * abs(m_scale-1) / a.scale
        m_angle = (midp-m_angle) * a.angle * abs(m_scale-1) / a.scale
        m_shear = (midp-m_shear) * a.shear * abs(m_scale-1) / a.scale

    def get_encs(encs, num):
        cnt = len(encs)
        if cnt == 0: return []
        enc_1 = encs[min(num, cnt-1)]
        enc_2 = encs[min(num+1, cnt-1)]
        return slerp(enc_1, enc_2, steps)

    prev_enc = 0
    def process(num):
        global params_tmp, opt_state, params, image_f, optimizer

        if a.interpol is True:  # linear topics interpolation
            txt_encs  = get_encs(key_txt_encs,  num)
            styl_encs = get_encs(key_styl_encs, num)
            img_encs  = get_encs(key_img_encs,  num)
        else:  # change by cut
            txt_encs  = [key_txt_encs[min(num, len(key_txt_encs)-1)][0]]   * steps if len(key_txt_encs)  > 0 else []
            styl_encs = [key_styl_encs[min(num, len(key_styl_encs)-1)][0]] * steps if len(key_styl_encs) > 0 else []
            img_encs  = [key_img_encs[min(num, len(key_img_encs)-1)][0]]   * steps if len(key_img_encs)  > 0 else []

        if a.verbose is True:
            if len(texts)  > 0: print(' ref text: ',  texts[min(num, len(texts)-1)][:80])
            if len(styles) > 0: print(' ref style: ', styles[min(num, len(styles)-1)][:80])
            if len(images) > 0: print(' ref image: ', basename(images[min(num, len(images)-1)])[:80])

        pbar = ProgressBar(steps)
        for ii in range(steps):
            glob_step = num * steps + ii  # save/transform

            txt_enc  = txt_encs[ii % len(txt_encs)].unsqueeze(0)   if len(txt_encs)  > 0 else None
            styl_enc = styl_encs[ii % len(styl_encs)].unsqueeze(0) if len(styl_encs) > 0 else None
            img_enc  = img_encs[ii % len(img_encs)].unsqueeze(0)   if len(img_encs)  > 0 else None

            # MOTION: transform frame, reload params
            scale = m_scale[glob_step]    if a.anima else 1 + a.scale
            shift = m_shift[glob_step]    if a.anima else [0, a.shift]
            angle = m_angle[glob_step][0] if a.anima else a.angle
            shear = m_shear[glob_step][0] if a.anima else a.shear

            if a.gen == 'RGB':
                if a.depth > 0:
                    params_tmp = depth_transform(params_tmp, depth_infer, depth_mask, a.size, a.depth, scale, shift, a.colors, a.depth_dir, glob_step)
                params_tmp = frame_transform(params_tmp, a.size, angle, shift, scale, shear)
                params, image_f, _ = pixel_image([1, 3, *a.size], resume=params_tmp)
                img_tmp = None
            else:  # FFT
                if old_torch():  # 1.7.1
                    img_tmp = torch.irfft(params_tmp, 2, normalized=True, signal_sizes=a.size)
                    if a.depth > 0:
                        img_tmp = depth_transform(img_tmp, depth_infer, depth_mask, a.size, a.depth, scale, shift, a.colors, a.depth_dir, glob_step)
                    img_tmp = frame_transform(img_tmp, a.size, angle, shift, scale, shear)
                    params_tmp = torch.rfft(img_tmp, 2, normalized=True)
                else:  # 1.8+
                    if params_tmp.dtype != torch.complex64:
                        params_tmp = torch.view_as_complex(params_tmp)
                    img_tmp = torch.fft.irfftn(params_tmp, s=a.size, norm='ortho')
                    if a.depth > 0:
                        img_tmp = depth_transform(img_tmp, depth_infer, depth_mask, a.size, a.depth, scale, shift, a.colors, a.depth_dir, glob_step)
                    img_tmp = frame_transform(img_tmp, a.size, angle, shift, scale, shear)
                    params_tmp = torch.fft.rfftn(img_tmp, s=a.size, dim=[2,3], norm='ortho')
                    params_tmp = torch.view_as_real(params_tmp)
                params, image_f, _ = fft_image([1, 3, *a.size], sd=1, resume=params_tmp)

            if a.optimizer.lower() == 'adamw':
                optimizer = torch.optim.AdamW(params, a.lrate, weight_decay=0.01)
            elif a.optimizer.lower() == 'adamw_custom':
                optimizer = torch.optim.AdamW(params, a.lrate, weight_decay=0.01, betas=(.0,.999), amsgrad=True)
            elif a.optimizer.lower() == 'adam':
                optimizer = torch.optim.Adam(params, a.lrate)
            else:  # adam_custom
                optimizer = torch.optim.Adam(params, a.lrate, betas=(.0,.999))
            image_f = to_valid_rgb(image_f, colors=a.colors)
            del img_tmp

            if a.smooth is True and num + ii > 0:
                optimizer.load_state_dict(opt_state)

            ### optimization
            for ss in range(a.opt_step):
                loss = 0

                noise = a.noise * (torch.rand(1, 1, a.size[0], a.size[1]//2+1, 1)-0.5).cuda() if a.noise > 0 else 0.
                img_out = image_f(noise, fixcontrast=a.fixcontrast)

                img_sliced = slice_imgs([img_out], a.samples, a.modsize, trform_f, a.align, a.macro)[0]
                out_enc = model_clip.encode_image(img_sliced)

                if a.gen == 'RGB':  # empirical hack
                    loss += abs(img_out.mean((2,3)) - 0.45).mean()  # fix brightness
                    loss += abs(img_out.std((2,3)) - 0.17).mean()   # fix contrast

                if txt_enc is not None:
                    loss -= a.invert * sim_func(txt_enc, out_enc, a.sim)
                if styl_enc is not None:
                    loss -= a.weight2 * sim_func(styl_enc, out_enc, a.sim)
                if img_enc is not None:
                    loss -= a.weight_img * sim_func(img_enc, out_enc, a.sim)
                if a.in_txt0 is not None:  # subtract text
                    for anti_txt_enc in anti_txt_encs:
                        loss += 0.3 * sim_func(anti_txt_enc, out_enc, a.sim)
                if a.sharp != 0:  # scharr|sobel|naive
                    loss -= a.sharp * derivat(img_out, mode='naive')
                if a.enforce != 0:
                    img_sliced = slice_imgs([image_f(noise, fixcontrast=a.fixcontrast)], a.samples, a.modsize, trform_f, a.align, a.macro)[0]
                    out_enc2 = model_clip.encode_image(img_sliced)
                    loss -= a.enforce * sim_func(out_enc, out_enc2, a.sim)
                    del out_enc2; torch.cuda.empty_cache()
                if a.expand > 0:
                    global prev_enc
                    if ii > 0:
                        loss += a.expand * sim_func(prev_enc, out_enc, a.sim)
                    prev_enc = out_enc.detach().clone()
                del img_out, img_sliced, out_enc; torch.cuda.empty_cache()

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            ### save params & frame
            params_tmp = params[0].detach().clone()
            if a.smooth is True:
                opt_state = optimizer.state_dict()

            with torch.no_grad():
                img_t = image_f(contrast=a.contrast, fixcontrast=a.fixcontrast)[0].permute(1,2,0)
                img_np = torch.clip(img_t*255, 0, 255).cpu().numpy().astype(np.uint8)
            imsave(os.path.join(tempdir, '%06d.jpg' % glob_step), img_np, quality=95)
            if a.verbose is True: cvshow(img_np)
            del img_t, img_np
            pbar.upd()

        params_tmp = params[0].detach().clone()

    glob_start = time.time()
    try:
        for i in range(count):
            process(i)
    except KeyboardInterrupt:
        pass

    os.system('ffmpeg -v warning -y -i %s/\%%06d.jpg "%s.mp4"' % (tempdir, os.path.join(workdir, workname)))
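# Hedged sketch of the spherical interpolation that get_encs() relies on: a generic slerp between
# two CLIP embeddings, written here only for illustration -- the repo's own slerp() may differ in
# signature and edge-case handling.
#   def slerp_sketch(e1, e2, steps):
#       low, high = e1[0], e2[0]                                   # [D] embedding vectors
#       omega = torch.arccos(torch.clamp(torch.cosine_similarity(low, high, dim=-1), -1., 1.))
#       out = []
#       for t in torch.linspace(0., 1., steps):
#           if omega.abs() < 1e-6:                                 # nearly parallel -> linear mix
#               out.append(low * (1 - t) + high * t)
#           else:
#               out.append((torch.sin((1 - t) * omega) * low + torch.sin(t * omega) * high) / torch.sin(omega))
#       return out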
def main():
    a = get_args()

    shape = [1, 3, *a.size]
    if a.dwt is True:
        params, image_f, sz = dwt_image(shape, a.wave, 0.3, a.colors, a.resume)
    else:
        params, image_f, sz = fft_image(shape, 0.07, a.decay, a.resume)
    if sz is not None: a.size = sz
    image_f = to_valid_rgb(image_f, colors=a.colors)

    if a.prog is True:
        lr1 = a.lrate * 2
        lr0 = lr1 * 0.01
    else:
        lr0 = a.lrate

    if a.optimizer.lower() == 'adamw':
        optimizer = torch.optim.AdamW(params, lr0, weight_decay=0.01)
    elif a.optimizer.lower() == 'adamw_custom':
        optimizer = torch.optim.AdamW(params, lr0, weight_decay=0.01, betas=(.0, .999), amsgrad=True)
    elif a.optimizer.lower() == 'adam':
        optimizer = torch.optim.Adam(params, lr0)
    else:  # adam_custom
        optimizer = torch.optim.Adam(params, lr0, betas=(.0, .999))
    sign = 1. if a.invert is True else -1.

    # Load CLIP models
    model_clip, _ = clip.load(a.model, jit=old_torch())
    try:
        a.modsize = model_clip.visual.input_resolution
    except:
        a.modsize = 288 if a.model == 'RN50x4' else 384 if a.model == 'RN50x16' else 224
    if a.verbose is True: print(' using model', a.model)
    xmem = {'ViT-B/16': 0.25, 'RN50': 0.5, 'RN50x4': 0.16, 'RN50x16': 0.06, 'RN101': 0.33}
    if a.model in xmem.keys():
        a.samples = int(a.samples * xmem[a.model])

    if a.multilang is True:
        model_lang = SentenceTransformer('clip-ViT-B-32-multilingual-v1').cuda()

    def enc_text(txt):
        if a.multilang is True:
            emb = model_lang.encode([txt], convert_to_tensor=True, show_progress_bar=False)
        else:
            emb = model_clip.encode_text(clip.tokenize(txt).cuda())
        return emb.detach().clone()

    if a.enforce != 0:
        a.samples = int(a.samples * 0.5)
    if a.sync > 0:
        a.samples = int(a.samples * 0.5)

    if 'elastic' in a.transform:
        trform_f = transforms.transforms_elastic
        a.samples = int(a.samples * 0.95)
    elif 'custom' in a.transform:
        trform_f = transforms.transforms_custom
        a.samples = int(a.samples * 0.95)
    elif 'fast' in a.transform:
        trform_f = transforms.transforms_fast
        a.samples = int(a.samples * 0.95)
    else:
        trform_f = transforms.normalize()

    out_name = []
    if a.in_txt is not None:
        if a.verbose is True: print(' topic text: ', a.in_txt)
        if a.translate:
            translator = Translator()
            a.in_txt = translator.translate(a.in_txt, dest='en').text
            if a.verbose is True: print(' translated to:', a.in_txt)
        txt_enc = enc_text(a.in_txt)
        out_name.append(txt_clean(a.in_txt).lower()[:40])

        if a.notext > 0:
            txt_plot = torch.from_numpy(plot_text(a.in_txt, a.modsize)/255.).unsqueeze(0).permute(0, 3, 1, 2).cuda()
            txt_plot_enc = model_clip.encode_image(txt_plot).detach().clone()

    if a.in_txt2 is not None:
        if a.verbose is True: print(' style text:', a.in_txt2)
        a.samples = int(a.samples * 0.75)
        if a.translate:
            translator = Translator()
            a.in_txt2 = translator.translate(a.in_txt2, dest='en').text
            if a.verbose is True: print(' translated to:', a.in_txt2)
        txt_enc2 = enc_text(a.in_txt2)
        out_name.append(txt_clean(a.in_txt2).lower()[:40])

    if a.in_txt0 is not None:
        if a.verbose is True: print(' subtract text:', a.in_txt0)
        a.samples = int(a.samples * 0.75)
        if a.translate:
            translator = Translator()
            a.in_txt0 = translator.translate(a.in_txt0, dest='en').text
            if a.verbose is True: print(' translated to:', a.in_txt0)
        txt_enc0 = enc_text(a.in_txt0)
        out_name.append('off-' + txt_clean(a.in_txt0).lower()[:40])

    if a.multilang is True: del model_lang

    if a.in_img is not None and os.path.isfile(a.in_img):
        if a.verbose is True: print(' ref image:', basename(a.in_img))
        img_in = torch.from_numpy(img_read(a.in_img)/255.).unsqueeze(0).permute(0, 3, 1, 2).cuda()
        img_in = img_in[:, :3, :, :]  # fix rgb channels
        in_sliced = slice_imgs([img_in], a.samples, a.modsize, transforms.normalize(), a.align)[0]
        img_enc = model_clip.encode_image(in_sliced).detach().clone()
        if a.sync > 0:
            sim_loss = lpips.LPIPS(net='vgg', verbose=False).cuda()
            sim_size = [s // 2 for s in a.size]
            img_in = F.interpolate(img_in, sim_size, mode='bicubic', align_corners=True).float()
        else:
            del img_in
        del in_sliced
        torch.cuda.empty_cache()
        out_name.append(basename(a.in_img).replace(' ', '_'))

    if a.verbose is True: print(' samples:', a.samples)
    out_name = '-'.join(out_name)
    out_name += '-%s' % a.model if 'RN' in a.model.upper() else ''
    tempdir = os.path.join(a.out_dir, out_name)
    os.makedirs(tempdir, exist_ok=True)

    prev_enc = 0
    def train(i):
        loss = 0

        noise = a.noise * torch.rand(1, 1, *params[0].shape[2:4], 1).cuda() if a.noise > 0 else None
        img_out = image_f(noise)
        img_sliced = slice_imgs([img_out], a.samples, a.modsize, trform_f, a.align, a.macro)[0]
        out_enc = model_clip.encode_image(img_sliced)

        if a.in_txt is not None:  # input text
            loss += sign * sim_func(txt_enc, out_enc, a.sim)
            if a.notext > 0:
                loss -= sign * a.notext * sim_func(txt_plot_enc, out_enc, a.sim)
        if a.in_txt2 is not None:  # input text - style
            loss += sign * a.weight2 * sim_func(txt_enc2, out_enc, a.sim)
        if a.in_txt0 is not None:  # subtract text
            loss += -sign * 0.3 * sim_func(txt_enc0, out_enc, a.sim)
        if a.in_img is not None and os.path.isfile(a.in_img):  # input image
            loss += sign * 0.5 * sim_func(img_enc, out_enc, a.sim)
        if a.sync > 0 and a.in_img is not None and os.path.isfile(a.in_img):  # image composition
            prog_sync = (a.steps // a.opt_step - i) / (a.steps // a.opt_step)
            loss += prog_sync * a.sync * sim_loss(F.interpolate(img_out, sim_size, mode='bicubic', align_corners=True).float(), img_in, normalize=True).squeeze()
        if a.sharp != 0 and a.dwt is not True:  # scharr|sobel|default
            loss -= a.sharp * derivat(img_out, mode='naiv')
            # loss -= a.sharp * derivat(img_sliced, mode='scharr')
        if a.enforce != 0:
            img_sliced = slice_imgs([image_f(noise)], a.samples, a.modsize, trform_f, a.align, a.macro)[0]
            out_enc2 = model_clip.encode_image(img_sliced)
            loss -= a.enforce * sim_func(out_enc, out_enc2, a.sim)
            del out_enc2
            torch.cuda.empty_cache()
        if a.expand > 0:
            global prev_enc
            if i > 0:
                loss += a.expand * sim_func(out_enc, prev_enc, a.sim)
            prev_enc = out_enc.detach()  # .clone()
        del img_out, img_sliced, out_enc
        torch.cuda.empty_cache()
        assert not isinstance(loss, int), ' Loss not defined, check the inputs'

        if a.prog is True:
            lr_cur = lr0 + (i / a.steps) * (lr1 - lr0)
            for g in optimizer.param_groups:
                g['lr'] = lr_cur

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % a.opt_step == 0:
            with torch.no_grad():
                img = image_f(contrast=a.contrast).cpu().numpy()[0]
            # empirical tone mapping
            if a.sync > 0 and a.in_img is not None:
                img = img ** 1.3
            elif a.sharp != 0:
                img = img ** (1 + a.sharp / 2.)
            checkout(img, os.path.join(tempdir, '%04d.jpg' % (i // a.opt_step)), verbose=a.verbose)
            pbar.upd()

    pbar = ProgressBar(a.steps // a.opt_step)
    for i in range(a.steps):
        train(i)

    os.system('ffmpeg -v warning -y -i %s/\%%04d.jpg "%s.mp4"' % (tempdir, os.path.join(a.out_dir, out_name)))
    shutil.copy(img_list(tempdir)[-1], os.path.join(a.out_dir, '%s-%d.jpg' % (out_name, a.steps)))
    if a.save_pt is True:
        torch.save(params, '%s.pt' % os.path.join(a.out_dir, out_name))
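# Assumption for illustration: sim_func(x, y, a.sim) above is some similarity metric between CLIP
# embeddings selected by the a.sim argument; a plain cosine-similarity version (used verbatim in
# the next script) would look like:
#   def cossim_sketch(enc_a, enc_b):
#       return torch.cosine_similarity(enc_a, enc_b, dim=-1).mean()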
def main():
    a = get_args()

    # Load CLIP models
    model_clip, _ = clip.load(a.model, jit=old_torch())
    try:
        a.modsize = model_clip.visual.input_resolution
    except:
        a.modsize = 288 if a.model == 'RN50x4' else 384 if a.model == 'RN50x16' else 224
    if a.verbose is True: print(' using model', a.model)
    xmem = {'ViT-B/16': 0.25, 'RN50': 0.5, 'RN50x4': 0.16, 'RN50x16': 0.06, 'RN101': 0.33}
    if a.model in xmem.keys():
        a.samples = int(a.samples * xmem[a.model])

    workdir = os.path.join(a.out_dir, basename(a.in_txt))
    workdir += '-%s' % a.model if 'RN' in a.model.upper() else ''
    os.makedirs(workdir, exist_ok=True)

    def enc_text(txt):
        if a.multilang is True:
            model_lang = SentenceTransformer('clip-ViT-B-32-multilingual-v1').cuda()
            emb = model_lang.encode([txt], convert_to_tensor=True, show_progress_bar=False)
            del model_lang
        else:
            emb = model_clip.encode_text(clip.tokenize(txt).cuda())
        return emb.detach().clone()

    if a.enforce != 0:
        a.samples = int(a.samples * 0.5)

    if 'elastic' in a.transform:
        trform_f = transforms.transforms_elastic
        a.samples = int(a.samples * 0.95)
    elif 'custom' in a.transform:
        trform_f = transforms.transforms_custom
        a.samples = int(a.samples * 0.95)
    elif 'fast' in a.transform:
        trform_f = transforms.transforms_fast
        a.samples = int(a.samples * 0.95)
    else:
        trform_f = transforms.normalize()

    if a.in_txt2 is not None:
        if a.verbose is True: print(' style:', basename(a.in_txt2))
        # a.samples = int(a.samples * 0.75)
        if a.translate:
            translator = Translator()
            a.in_txt2 = translator.translate(a.in_txt2, dest='en').text
            if a.verbose is True: print(' translated to:', a.in_txt2)
        txt_enc2 = enc_text(a.in_txt2)

    if a.in_txt0 is not None:
        if a.verbose is True: print(' subtract text:', basename(a.in_txt0))
        if a.translate:
            translator = Translator()
            a.in_txt0 = translator.translate(a.in_txt0, dest='en').text
            if a.verbose is True: print(' translated to:', a.in_txt0)
        txt_enc0 = enc_text(a.in_txt0)

    # make init
    global params_start, params_ema
    params_shape = [1, 3, a.size[0], a.size[1] // 2 + 1, 2]
    params_start = torch.randn(*params_shape).cuda()  # random init
    params_ema = 0.
    if a.resume is not None and os.path.isfile(a.resume):
        if a.verbose is True: print(' resuming from', a.resume)
        params_start = load_params(a.resume).cuda()
        if a.keep > 0:
            params_ema = params_start[0].detach().clone()
    else:
        a.resume = 'init.pt'

    torch.save(params_start, 'init.pt')  # final init
    shutil.copy(a.resume, os.path.join(workdir, '000-%s.pt' % basename(a.resume)))

    prev_enc = 0
    def process(txt, num):

        sd = 0.01
        if a.keep > 0: sd = a.keep + (1 - a.keep) * sd
        params, image_f, _ = fft_image([1, 3, *a.size], resume='init.pt', sd=sd, decay_power=a.decay)
        image_f = to_valid_rgb(image_f, colors=a.colors)

        if a.prog is True:
            lr1 = a.lrate * 2
            lr0 = a.lrate * 0.1
        else:
            lr0 = a.lrate
        optimizer = torch.optim.AdamW(params, lr0, weight_decay=0.01, amsgrad=True)

        if a.verbose is True: print(' topic: ', txt)
        if a.translate:
            translator = Translator()
            txt = translator.translate(txt, dest='en').text
            if a.verbose is True: print(' translated to:', txt)
        txt_enc = enc_text(txt)
        if a.notext > 0:
            txt_plot = torch.from_numpy(plot_text(txt, a.modsize)/255.).unsqueeze(0).permute(0, 3, 1, 2).cuda()
            txt_plot_enc = model_clip.encode_image(txt_plot).detach().clone()
        else:
            txt_plot_enc = None

        out_name = '%03d-%s' % (num + 1, txt_clean(txt))
        out_name += '-%s' % a.model if 'RN' in a.model.upper() else ''
        tempdir = os.path.join(workdir, out_name)
        os.makedirs(tempdir, exist_ok=True)

        pbar = ProgressBar(a.steps // a.fstep)
        for i in range(a.steps):
            loss = 0

            noise = a.noise * torch.randn(1, 1, *params[0].shape[2:4], 1).cuda() if a.noise > 0 else None
            img_out = image_f(noise)
            img_sliced = slice_imgs([img_out], a.samples, a.modsize, trform_f, a.align, macro=a.macro)[0]
            out_enc = model_clip.encode_image(img_sliced)

            loss -= torch.cosine_similarity(txt_enc, out_enc, dim=-1).mean()
            if a.in_txt2 is not None:  # input text - style
                loss -= 0.5 * torch.cosine_similarity(txt_enc2, out_enc, dim=-1).mean()
            if a.in_txt0 is not None:  # subtract text
                loss += 0.5 * torch.cosine_similarity(txt_enc0, out_enc, dim=-1).mean()
            if a.notext > 0:
                loss += a.notext * torch.cosine_similarity(txt_plot_enc, out_enc, dim=-1).mean()
            if a.sharp != 0:  # mode = scharr|sobel|default
                loss -= a.sharp * derivat(img_out, mode='sobel')
                # loss -= a.sharp * derivat(img_sliced, mode='scharr')
            if a.enforce != 0:
                img_sliced = slice_imgs([image_f(noise)], a.samples, a.modsize, trform_f, a.align, macro=a.macro)[0]
                out_enc2 = model_clip.encode_image(img_sliced)
                loss -= a.enforce * torch.cosine_similarity(out_enc, out_enc2, dim=-1).mean()
                del out_enc2
                torch.cuda.empty_cache()
            if a.expand > 0:
                global prev_enc
                if i > 0:
                    loss += a.expand * torch.cosine_similarity(out_enc, prev_enc, dim=-1).mean()
                prev_enc = out_enc.detach().clone()
            del img_out, img_sliced, out_enc
            torch.cuda.empty_cache()

            if a.prog is True:
                lr_cur = lr0 + (i / a.steps) * (lr1 - lr0)
                for g in optimizer.param_groups:
                    g['lr'] = lr_cur

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if i % a.fstep == 0:
                with torch.no_grad():
                    img = image_f(contrast=a.contrast).cpu().numpy()[0]
                if a.sharp != 0:
                    img = img ** (1 + a.sharp / 2.)  # empirical tone mapping
                checkout(img, os.path.join(tempdir, '%04d.jpg' % (i // a.fstep)), verbose=a.verbose)
                pbar.upd()
                del img

        if a.keep > 0:
            global params_start, params_ema
            params_ema = ema(params_ema, params[0].detach().clone(), num + 1)
            torch.save((1 - a.keep) * params_start + a.keep * params_ema, 'init.pt')

        torch.save(params[0], '%s.pt' % os.path.join(workdir, out_name))
        shutil.copy(img_list(tempdir)[-1], os.path.join(workdir, '%s-%d.jpg' % (out_name, a.steps)))
        os.system('ffmpeg -v warning -y -i %s/\%%04d.jpg "%s.mp4"' % (tempdir, os.path.join(workdir, out_name)))

    with open(a.in_txt, 'r', encoding="utf-8") as f:
        texts = f.readlines()
    texts = [tt.strip() for tt in texts if len(tt.strip()) > 0 and tt[0] != '#']
    if a.verbose is True:
        print(' total lines:', len(texts))
        print(' samples:', a.samples)

    for i, txt in enumerate(texts):
        process(txt, i)

    vsteps = int(a.length * 25 / len(texts))  # 25 fps
    tempdir = os.path.join(workdir, '_final')
    os.makedirs(tempdir, exist_ok=True)

    def read_pt(file):
        return torch.load(file).cuda()

    if a.verbose is True: print(' rendering complete piece')
    ptfiles = file_list(workdir, 'pt')

    pbar = ProgressBar(vsteps * len(ptfiles))
    for px in range(len(ptfiles)):
        params1 = read_pt(ptfiles[px])
        params2 = read_pt(ptfiles[(px + 1) % len(ptfiles)])

        params, image_f, _ = fft_image([1, 3, *a.size], resume=params1, sd=1., decay_power=a.decay)
        image_f = to_valid_rgb(image_f, colors=a.colors)

        for i in range(vsteps):
            with torch.no_grad():
                img = image_f((params2 - params1) * math.sin(1.5708 * i / vsteps)**2)[0].permute(1, 2, 0)
                img = torch.clip(img * 255, 0, 255).cpu().numpy().astype(np.uint8)
            imsave(os.path.join(tempdir, '%05d.jpg' % (px * vsteps + i)), img)
            if a.verbose is True: cvshow(img)
            pbar.upd()

    os.system('ffmpeg -v warning -y -i %s/\%%05d.jpg "%s.mp4"' % (tempdir, os.path.join(a.out_dir, basename(a.in_txt))))
    if a.keep > 0: os.remove('init.pt')
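# Assumption for illustration: ema(prev, new, step) above appears to keep a running average of the
# per-topic FFT params; an incremental (cumulative) mean with that signature could look like the
# sketch below, though the repo's actual ema() may weight the terms differently:
#   def ema_sketch(prev, new, step):
#       if step <= 1 or not torch.is_tensor(prev):
#           return new
#       return prev * (step - 1) / step + new / step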