def process(num):
    global params_tmp, opt_state, params, image_f, optimizer

    if a.interpol is True: # linear topics interpolation
        txt_encs  = get_encs(key_txt_encs,  num)
        styl_encs = get_encs(key_styl_encs, num)
        img_encs  = get_encs(key_img_encs,  num)
    else: # change by cut
        txt_encs  = [key_txt_encs[min(num, len(key_txt_encs)-1)][0]]   * steps if len(key_txt_encs)  > 0 else []
        styl_encs = [key_styl_encs[min(num, len(key_styl_encs)-1)][0]] * steps if len(key_styl_encs) > 0 else []
        img_encs  = [key_img_encs[min(num, len(key_img_encs)-1)][0]]   * steps if len(key_img_encs)  > 0 else []

    if a.verbose is True:
        if len(texts)  > 0: print(' ref text: ',  texts[min(num, len(texts)-1)][:80])
        if len(styles) > 0: print(' ref style: ', styles[min(num, len(styles)-1)][:80])
        if len(images) > 0: print(' ref image: ', basename(images[min(num, len(images)-1)])[:80])

    pbar = ProgressBar(steps)
    for ii in range(steps):
        glob_step = num * steps + ii # save/transform

        txt_enc  = txt_encs[ii % len(txt_encs)].unsqueeze(0)   if len(txt_encs)  > 0 else None
        styl_enc = styl_encs[ii % len(styl_encs)].unsqueeze(0) if len(styl_encs) > 0 else None
        img_enc  = img_encs[ii % len(img_encs)].unsqueeze(0)   if len(img_encs)  > 0 else None

        # MOTION: transform frame, reload params
        scale = m_scale[glob_step]    if a.anima else 1 + a.scale
        shift = m_shift[glob_step]    if a.anima else [0, a.shift]
        angle = m_angle[glob_step][0] if a.anima else a.angle
        shear = m_shear[glob_step][0] if a.anima else a.shear

        if a.gen == 'RGB':
            if a.depth > 0:
                params_tmp = depth_transform(params_tmp, depth_infer, depth_mask, a.size, a.depth, scale, shift, a.colors, a.depth_dir, glob_step)
            params_tmp = frame_transform(params_tmp, a.size, angle, shift, scale, shear)
            params, image_f, _ = pixel_image([1, 3, *a.size], resume=params_tmp)
            img_tmp = None

        else: # FFT
            if old_torch(): # 1.7.1
                img_tmp = torch.irfft(params_tmp, 2, normalized=True, signal_sizes=a.size)
                if a.depth > 0:
                    img_tmp = depth_transform(img_tmp, depth_infer, depth_mask, a.size, a.depth, scale, shift, a.colors, a.depth_dir, glob_step)
                img_tmp = frame_transform(img_tmp, a.size, angle, shift, scale, shear)
                params_tmp = torch.rfft(img_tmp, 2, normalized=True)
            else: # 1.8+
                if not torch.is_complex(params_tmp):
                    params_tmp = torch.view_as_complex(params_tmp)
                img_tmp = torch.fft.irfftn(params_tmp, s=a.size, norm='ortho')
                if a.depth > 0:
                    img_tmp = depth_transform(img_tmp, depth_infer, depth_mask, a.size, a.depth, scale, shift, a.colors, a.depth_dir, glob_step)
                img_tmp = frame_transform(img_tmp, a.size, angle, shift, scale, shear)
                params_tmp = torch.fft.rfftn(img_tmp, s=a.size, dim=[2,3], norm='ortho')
                params_tmp = torch.view_as_real(params_tmp)
            params, image_f, _ = fft_image([1, 3, *a.size], sd=1, resume=params_tmp)

        if a.optimizer.lower() == 'adamw':
            optimizer = torch.optim.AdamW(params, a.lrate, weight_decay=0.01)
        elif a.optimizer.lower() == 'adamw_custom':
            optimizer = torch.optim.AdamW(params, a.lrate, weight_decay=0.01, betas=(.0,.999), amsgrad=True)
        elif a.optimizer.lower() == 'adam':
            optimizer = torch.optim.Adam(params, a.lrate)
        else: # adam_custom
            optimizer = torch.optim.Adam(params, a.lrate, betas=(.0,.999))
        image_f = to_valid_rgb(image_f, colors=a.colors)
        del img_tmp

        if a.smooth is True and num + ii > 0:
            optimizer.load_state_dict(opt_state)

        ### optimization
        for ss in range(a.opt_step):
            loss = 0
            noise = a.noise * (torch.rand(1, 1, a.size[0], a.size[1]//2+1, 1)-0.5).cuda() if a.noise > 0 else 0.
            img_out = image_f(noise, fixcontrast=a.fixcontrast)

            img_sliced = slice_imgs([img_out], a.samples, a.modsize, trform_f, a.align, a.macro)[0]
            out_enc = model_clip.encode_image(img_sliced)

            if a.gen == 'RGB': # empirical hack
                loss += abs(img_out.mean((2,3)) - 0.45).mean() # fix brightness
                loss += abs(img_out.std((2,3)) - 0.17).mean()  # fix contrast

            if txt_enc is not None:
                loss -= a.invert * sim_func(txt_enc, out_enc, a.sim)
            if styl_enc is not None:
                loss -= a.weight2 * sim_func(styl_enc, out_enc, a.sim)
            if img_enc is not None:
                loss -= a.weight_img * sim_func(img_enc, out_enc, a.sim)
            if a.in_txt0 is not None: # subtract text
                for anti_txt_enc in anti_txt_encs:
                    loss += 0.3 * sim_func(anti_txt_enc, out_enc, a.sim)
            if a.sharp != 0: # scharr|sobel|naive
                loss -= a.sharp * derivat(img_out, mode='naive')
            if a.enforce != 0:
                img_sliced = slice_imgs([image_f(noise, fixcontrast=a.fixcontrast)], a.samples, a.modsize, trform_f, a.align, a.macro)[0]
                out_enc2 = model_clip.encode_image(img_sliced)
                loss -= a.enforce * sim_func(out_enc, out_enc2, a.sim)
                del out_enc2; torch.cuda.empty_cache()
            if a.expand > 0:
                global prev_enc
                if ii > 0:
                    loss += a.expand * sim_func(prev_enc, out_enc, a.sim)
                prev_enc = out_enc.detach().clone()
            del img_out, img_sliced, out_enc; torch.cuda.empty_cache()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        ### save params & frame
        params_tmp = params[0].detach().clone()
        if a.smooth is True:
            opt_state = optimizer.state_dict()

        with torch.no_grad():
            img_t = image_f(contrast=a.contrast, fixcontrast=a.fixcontrast)[0].permute(1,2,0)
            img_np = torch.clip(img_t*255, 0, 255).cpu().numpy().astype(np.uint8)
        imsave(os.path.join(tempdir, '%06d.jpg' % glob_step), img_np, quality=95)
        if a.verbose is True: cvshow(img_np)
        del img_t, img_np
        pbar.upd()

    params_tmp = params[0].detach().clone()
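# sim_func is defined elsewhere in this repo (its behavior is selected via a.sim).
# A minimal sketch of what it computes, assuming plain cosine similarity as the
# default -- the same measure process() in illustra.py spells out directly via
# torch.cosine_similarity; the mode name and the 'dot' fallback below are
# illustrative assumptions, not the repo's actual API:
def sim_func_sketch(enc1, enc2, mode='cossim'):
    if mode == 'dot': # raw dot product over L2-normalized embeddings (assumption)
        enc1 = enc1 / enc1.norm(dim=-1, keepdim=True)
        enc2 = enc2 / enc2.norm(dim=-1, keepdim=True)
        return (enc1 @ enc2.T).mean()
    return torch.cosine_similarity(enc1, enc2, dim=-1).mean() # 'cossim'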
def enc_image(img_file):
    img_t = torch.from_numpy(img_read(img_file)/255.).unsqueeze(0).permute(0,3,1,2).cuda()[:,:3,:,:]
    in_sliced = slice_imgs([img_t], a.samples, a.modsize, transforms.normalize(), a.align)[0]
    emb = model_clip.encode_image(in_sliced)
    return emb.detach().clone()
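# Usage sketch for enc_image (the file name is hypothetical; enc_image relies
# on the globals a and model_clip set up before the run):
#   img_enc = enc_image('reference.jpg')                     # CLIP embeddings of sliced crops
#   loss -= a.weight_img * sim_func(img_enc, out_enc, a.sim) # pull the render toward the reference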
def main():
    a = get_args()

    shape = [1, 3, *a.size]
    if a.dwt is True:
        params, image_f, sz = dwt_image(shape, a.wave, 0.3, a.colors, a.resume)
    else:
        params, image_f, sz = fft_image(shape, 0.07, a.decay, a.resume)
    if sz is not None: a.size = sz
    image_f = to_valid_rgb(image_f, colors=a.colors)

    if a.prog is True:
        lr1 = a.lrate * 2
        lr0 = lr1 * 0.01
    else:
        lr0 = a.lrate

    if a.optimizer.lower() == 'adamw':
        optimizer = torch.optim.AdamW(params, lr0, weight_decay=0.01)
    elif a.optimizer.lower() == 'adamw_custom':
        optimizer = torch.optim.AdamW(params, lr0, weight_decay=0.01, betas=(.0,.999), amsgrad=True)
    elif a.optimizer.lower() == 'adam':
        optimizer = torch.optim.Adam(params, lr0)
    else: # adam_custom
        optimizer = torch.optim.Adam(params, lr0, betas=(.0,.999))
    sign = 1. if a.invert is True else -1.

    # Load CLIP models
    model_clip, _ = clip.load(a.model, jit=old_torch())
    try:
        a.modsize = model_clip.visual.input_resolution
    except:
        a.modsize = 288 if a.model == 'RN50x4' else 384 if a.model == 'RN50x16' else 224
    if a.verbose is True: print(' using model', a.model)
    xmem = {'ViT-B/16':0.25, 'RN50':0.5, 'RN50x4':0.16, 'RN50x16':0.06, 'RN101':0.33}
    if a.model in xmem.keys():
        a.samples = int(a.samples * xmem[a.model])

    if a.multilang is True:
        model_lang = SentenceTransformer('clip-ViT-B-32-multilingual-v1').cuda()

    def enc_text(txt):
        if a.multilang is True:
            emb = model_lang.encode([txt], convert_to_tensor=True, show_progress_bar=False)
        else:
            emb = model_clip.encode_text(clip.tokenize(txt).cuda())
        return emb.detach().clone()

    if a.enforce != 0:
        a.samples = int(a.samples * 0.5)
    if a.sync > 0:
        a.samples = int(a.samples * 0.5)

    if 'elastic' in a.transform:
        trform_f = transforms.transforms_elastic
        a.samples = int(a.samples * 0.95)
    elif 'custom' in a.transform:
        trform_f = transforms.transforms_custom
        a.samples = int(a.samples * 0.95)
    elif 'fast' in a.transform:
        trform_f = transforms.transforms_fast
        a.samples = int(a.samples * 0.95)
    else:
        trform_f = transforms.normalize()

    out_name = []
    if a.in_txt is not None:
        if a.verbose is True: print(' topic text: ', a.in_txt)
        if a.translate:
            translator = Translator()
            a.in_txt = translator.translate(a.in_txt, dest='en').text
            if a.verbose is True: print(' translated to:', a.in_txt)
        txt_enc = enc_text(a.in_txt)
        out_name.append(txt_clean(a.in_txt).lower()[:40])

        if a.notext > 0:
            txt_plot = torch.from_numpy(plot_text(a.in_txt, a.modsize)/255.).unsqueeze(0).permute(0,3,1,2).cuda()
            txt_plot_enc = model_clip.encode_image(txt_plot).detach().clone()

    if a.in_txt2 is not None:
        if a.verbose is True: print(' style text:', a.in_txt2)
        a.samples = int(a.samples * 0.75)
        if a.translate:
            translator = Translator()
            a.in_txt2 = translator.translate(a.in_txt2, dest='en').text
            if a.verbose is True: print(' translated to:', a.in_txt2)
        txt_enc2 = enc_text(a.in_txt2)
        out_name.append(txt_clean(a.in_txt2).lower()[:40])

    if a.in_txt0 is not None:
        if a.verbose is True: print(' subtract text:', a.in_txt0)
        a.samples = int(a.samples * 0.75)
        if a.translate:
            translator = Translator()
            a.in_txt0 = translator.translate(a.in_txt0, dest='en').text
            if a.verbose is True: print(' translated to:', a.in_txt0)
        txt_enc0 = enc_text(a.in_txt0)
        out_name.append('off-' + txt_clean(a.in_txt0).lower()[:40])

    if a.multilang is True: del model_lang

    if a.in_img is not None and os.path.isfile(a.in_img):
        if a.verbose is True: print(' ref image:', basename(a.in_img))
        img_in = torch.from_numpy(img_read(a.in_img)/255.).unsqueeze(0).permute(0,3,1,2).cuda()
        img_in = img_in[:,:3,:,:] # fix rgb channels
        in_sliced = slice_imgs([img_in], a.samples, a.modsize, transforms.normalize(), a.align)[0]
        img_enc = model_clip.encode_image(in_sliced).detach().clone()
        if a.sync > 0:
            sim_loss = lpips.LPIPS(net='vgg', verbose=False).cuda()
            sim_size = [s // 2 for s in a.size]
            img_in = F.interpolate(img_in, sim_size, mode='bicubic', align_corners=True).float()
        else:
            del img_in
        del in_sliced; torch.cuda.empty_cache()
        out_name.append(basename(a.in_img).replace(' ', '_'))

    if a.verbose is True: print(' samples:', a.samples)
    out_name = '-'.join(out_name)
    out_name += '-%s' % a.model if 'RN' in a.model.upper() else ''
    tempdir = os.path.join(a.out_dir, out_name)
    os.makedirs(tempdir, exist_ok=True)

    prev_enc = 0
    def train(i):
        loss = 0
        noise = a.noise * torch.rand(1, 1, *params[0].shape[2:4], 1).cuda() if a.noise > 0 else None
        img_out = image_f(noise)
        img_sliced = slice_imgs([img_out], a.samples, a.modsize, trform_f, a.align, a.macro)[0]
        out_enc = model_clip.encode_image(img_sliced)

        if a.in_txt is not None: # input text
            loss += sign * sim_func(txt_enc, out_enc, a.sim)
            if a.notext > 0:
                loss -= sign * a.notext * sim_func(txt_plot_enc, out_enc, a.sim)
        if a.in_txt2 is not None: # input text - style
            loss += sign * a.weight2 * sim_func(txt_enc2, out_enc, a.sim)
        if a.in_txt0 is not None: # subtract text
            loss += -sign * 0.3 * sim_func(txt_enc0, out_enc, a.sim)
        if a.in_img is not None and os.path.isfile(a.in_img): # input image
            loss += sign * 0.5 * sim_func(img_enc, out_enc, a.sim)
        if a.sync > 0 and a.in_img is not None and os.path.isfile(a.in_img): # image composition
            prog_sync = (a.steps // a.opt_step - i) / (a.steps // a.opt_step)
            loss += prog_sync * a.sync * sim_loss(F.interpolate(img_out, sim_size, mode='bicubic', align_corners=True).float(), img_in, normalize=True).squeeze()
        if a.sharp != 0 and a.dwt is not True: # scharr|sobel|default
            loss -= a.sharp * derivat(img_out, mode='naiv')
            # loss -= a.sharp * derivat(img_sliced, mode='scharr')
        if a.enforce != 0:
            img_sliced = slice_imgs([image_f(noise)], a.samples, a.modsize, trform_f, a.align, a.macro)[0]
            out_enc2 = model_clip.encode_image(img_sliced)
            loss -= a.enforce * sim_func(out_enc, out_enc2, a.sim)
            del out_enc2; torch.cuda.empty_cache()
        if a.expand > 0:
            nonlocal prev_enc
            if i > 0:
                loss += a.expand * sim_func(out_enc, prev_enc, a.sim)
            prev_enc = out_enc.detach() # .clone()
        del img_out, img_sliced, out_enc; torch.cuda.empty_cache()
        assert not isinstance(loss, int), ' Loss not defined, check the inputs'

        if a.prog is True:
            lr_cur = lr0 + (i / a.steps) * (lr1 - lr0)
            for g in optimizer.param_groups:
                g['lr'] = lr_cur

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % a.opt_step == 0:
            with torch.no_grad():
                img = image_f(contrast=a.contrast).cpu().numpy()[0]
            # empirical tone mapping
            if a.sync > 0 and a.in_img is not None:
                img = img ** 1.3
            elif a.sharp != 0:
                img = img ** (1 + a.sharp / 2.)
            checkout(img, os.path.join(tempdir, '%04d.jpg' % (i // a.opt_step)), verbose=a.verbose)
            pbar.upd()

    pbar = ProgressBar(a.steps // a.opt_step)
    for i in range(a.steps):
        train(i)

    os.system('ffmpeg -v warning -y -i %s/%%04d.jpg "%s.mp4"' % (tempdir, os.path.join(a.out_dir, out_name)))
    shutil.copy(img_list(tempdir)[-1], os.path.join(a.out_dir, '%s-%d.jpg' % (out_name, a.steps)))
    if a.save_pt is True:
        torch.save(params, '%s.pt' % os.path.join(a.out_dir, out_name))
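# derivat() comes from this repo's utils (modes like 'naiv', 'naive', 'sobel',
# 'scharr' are passed above). A minimal sketch of a 'naiv' mode, assuming it
# scores sharpness as the mean absolute finite difference along both spatial
# axes of a [N, C, H, W] tensor; the sobel/scharr modes would convolve with the
# respective edge kernels instead. The name is hypothetical:
def derivat_sketch(img, mode='naiv'):
    dy = (img[:, :, 1:, :] - img[:, :, :-1, :]).abs().mean() # vertical gradients
    dx = (img[:, :, :, 1:] - img[:, :, :, :-1]).abs().mean() # horizontal gradients
    return 0.5 * (dx + dy)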
def process(txt, num):
    sd = 0.01
    if a.keep > 0: sd = a.keep + (1 - a.keep) * sd
    params, image_f, _ = fft_image([1, 3, *a.size], resume='init.pt', sd=sd, decay_power=a.decay)
    image_f = to_valid_rgb(image_f, colors=a.colors)

    if a.prog is True:
        lr1 = a.lrate * 2
        lr0 = a.lrate * 0.1
    else:
        lr0 = a.lrate
    optimizer = torch.optim.AdamW(params, lr0, weight_decay=0.01, amsgrad=True)

    if a.verbose is True: print(' topic: ', txt)
    if a.translate:
        translator = Translator()
        txt = translator.translate(txt, dest='en').text
        if a.verbose is True: print(' translated to:', txt)
    txt_enc = enc_text(txt)
    if a.notext > 0:
        txt_plot = torch.from_numpy(plot_text(txt, a.modsize)/255.).unsqueeze(0).permute(0,3,1,2).cuda()
        txt_plot_enc = model_clip.encode_image(txt_plot).detach().clone()
    else:
        txt_plot_enc = None

    out_name = '%03d-%s' % (num + 1, txt_clean(txt))
    out_name += '-%s' % a.model if 'RN' in a.model.upper() else ''
    tempdir = os.path.join(workdir, out_name)
    os.makedirs(tempdir, exist_ok=True)

    pbar = ProgressBar(a.steps // a.fstep)
    for i in range(a.steps):
        loss = 0
        noise = a.noise * torch.randn(1, 1, *params[0].shape[2:4], 1).cuda() if a.noise > 0 else None
        img_out = image_f(noise)
        img_sliced = slice_imgs([img_out], a.samples, a.modsize, trform_f, a.align, macro=a.macro)[0]
        out_enc = model_clip.encode_image(img_sliced)

        loss -= torch.cosine_similarity(txt_enc, out_enc, dim=-1).mean()
        if a.in_txt2 is not None: # input text - style
            loss -= 0.5 * torch.cosine_similarity(txt_enc2, out_enc, dim=-1).mean()
        if a.in_txt0 is not None: # subtract text
            loss += 0.5 * torch.cosine_similarity(txt_enc0, out_enc, dim=-1).mean()
        if a.notext > 0:
            loss += a.notext * torch.cosine_similarity(txt_plot_enc, out_enc, dim=-1).mean()
        if a.sharp != 0: # mode = scharr|sobel|default
            loss -= a.sharp * derivat(img_out, mode='sobel')
            # loss -= a.sharp * derivat(img_sliced, mode='scharr')
        if a.enforce != 0:
            img_sliced = slice_imgs([image_f(noise)], a.samples, a.modsize, trform_f, a.align, macro=a.macro)[0]
            out_enc2 = model_clip.encode_image(img_sliced)
            loss -= a.enforce * torch.cosine_similarity(out_enc, out_enc2, dim=-1).mean()
            del out_enc2; torch.cuda.empty_cache()
        if a.expand > 0:
            global prev_enc
            if i > 0:
                loss += a.expand * torch.cosine_similarity(out_enc, prev_enc, dim=-1).mean()
            prev_enc = out_enc.detach().clone()
        del img_out, img_sliced, out_enc; torch.cuda.empty_cache()

        if a.prog is True:
            lr_cur = lr0 + (i / a.steps) * (lr1 - lr0)
            for g in optimizer.param_groups:
                g['lr'] = lr_cur

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % a.fstep == 0:
            with torch.no_grad():
                img = image_f(contrast=a.contrast).cpu().numpy()[0]
            if a.sharp != 0:
                img = img ** (1 + a.sharp / 2.) # empirical tone mapping
            checkout(img, os.path.join(tempdir, '%04d.jpg' % (i // a.fstep)), verbose=a.verbose)
            pbar.upd()
            del img

    if a.keep > 0:
        global params_start, params_ema
        params_ema = ema(params_ema, params[0].detach().clone(), num + 1)
        torch.save((1 - a.keep) * params_start + a.keep * params_ema, 'init.pt')

    torch.save(params[0], '%s.pt' % os.path.join(workdir, out_name))
    shutil.copy(img_list(tempdir)[-1], os.path.join(workdir, '%s-%d.jpg' % (out_name, a.steps)))
    os.system('ffmpeg -v warning -y -i %s/%%04d.jpg "%s.mp4"' % (tempdir, os.path.join(workdir, out_name)))
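# ema() is imported from this repo's utils. A minimal sketch, assuming it keeps
# a cumulative running average over the num+1 topics processed so far (the
# exact weighting in the repo may differ); the name is hypothetical:
def ema_sketch(avg, new, n):
    if avg is None:
        return new
    return avg + (new - avg) / n # running mean after n samples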