def prepare_input(self, target_image):
    # Add a batch dimension if a single image was passed
    if len(target_image.shape) == 3:
        target_image = target_image.unsqueeze(0)
    # Resize to 256 x 256 if the image is larger
    if target_image.shape[2] > 256:
        target_image = utils.downsample_256(target_image)
    self.target_image = target_image
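# The helper utils.downsample_256 used throughout these functions is defined
# elsewhere in the project. A minimal sketch of what such a helper could look
# like, assuming plain bilinear resizing of an (N, C, H, W) batch to 256 x 256
# (the project's actual version may differ, e.g. by using antialiased pooling):
import torch
import torch.nn.functional as F

def downsample_256_sketch(img: torch.Tensor) -> torch.Tensor:
    """Resize a batch of images to 256 x 256 if it is larger."""
    if img.shape[-1] <= 256:
        return img
    return F.interpolate(img, size=(256, 256), mode='bilinear', align_corners=False)

# Example: a 1024 x 1024 batch becomes 256 x 256.
# downsample_256_sketch(torch.randn(2, 3, 1024, 1024)).shape -> (2, 3, 256, 256)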
def forward(self, img, evaluation=False):
    # Encode
    if evaluation:
        self.e.eval()
    latent_offset = self.e(img)
    if evaluation:
        self.e.train()

    # Add mean (the encoder only predicts the offset to the mean latent)
    latent = latent_offset + self.latent_avg

    # Decode
    img_gen, _ = self.g([latent], input_is_latent=True, noise=self.g.noises)

    # Downsample to 256 x 256
    img_gen = utils.downsample_256(img_gen)

    # Compute perceptual loss
    loss = self.criterion(img_gen, img).mean()

    return loss, img_gen
def test_model(self, val_loader):
    with torch.no_grad():
        # Generate a random image
        z = torch.randn(self.args.batch_size, 512, device=self.device)
        img, _ = self.g([z], truncation=0.9, truncation_latent=self.latent_avg)
        img = utils.downsample_256(img)

        # Forward
        _, img_gen = self.forward(img, evaluation=True)

        img_tensor = torch.cat((img, img_gen.clamp(-1., 1.)), dim=0)
        save_image(
            img_tensor,
            f'{self.args.save_dir}test_model_train.png',
            normalize=True,
            range=(-1, 1),
            nrow=min(8, self.args.batch_size))

        # Test on validation data
        _, img_val, img_gen_val = self.eval(val_loader)
        save_tensor = torch.cat((img_val, img_gen_val.clamp(-1., 1.)), dim=0)
        save_image(
            save_tensor,
            f'{self.args.save_dir}test_model_val.png',
            normalize=True,
            range=(-1, 1),
            nrow=min(8, self.args.batch_size))
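# The truncation arguments passed to self.g above follow the standard StyleGAN2
# truncation trick: mapped latents are pulled towards the average W latent before
# synthesis. A minimal illustration with stand-in tensors (the real logic lives
# inside the generator, and latent_avg there comes from the pretrained model):
import torch

latent_avg = torch.zeros(512)              # stand-in for the mean W latent
w = torch.randn(512)                       # stand-in for a mapped latent
truncation = 0.9
w_truncated = latent_avg + truncation * (w - latent_avg)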
def image_from_latent(latentfile, eafa_model):
    latent = torch.load(latentfile).unsqueeze(0).cuda()
    with torch.no_grad():
        img = eafa_model.g(
            [latent], input_is_latent=True, noise=eafa_model.g.noises)[0].cpu()
        img = downsample_256(img)
        img = make_grid(img, normalize=True, range=(-1, 1))
    return img
def eval(self, data_loader, sample_name):
    # Unpack batch
    batch = next(iter(data_loader))
    audio, input_latent, aux_input, target_latent, target_img, _ = self.unpack_data(
        batch)

    n_display = min(4, self.args.batch_size)
    audio = audio[:n_display]
    target_latent = target_latent[:n_display]
    target_img = target_img[:n_display]
    input_latent = input_latent[:n_display]
    aux_input = aux_input[:n_display]

    with torch.no_grad():
        # Forward
        pred = self.forward(audio, input_latent, aux_input)

        input_img, _ = self.g(
            [input_latent], input_is_latent=True, noise=self.g.noises)
        input_img = utils.downsample_256(input_img)

        pred, _ = self.g([pred], input_is_latent=True, noise=self.g.noises)
        pred = utils.downsample_256(pred)

        target_img, _ = self.g(
            [target_latent], input_is_latent=True, noise=self.g.noises)
        target_img = utils.downsample_256(target_img)

    # Normalize images to display
    input_img = make_grid(input_img, normalize=True, range=(-1, 1))
    pred = make_grid(pred, normalize=True, range=(-1, 1))
    target_img = make_grid(target_img, normalize=True, range=(-1, 1))

    diff = (target_img - pred) * 5

    img_tensor = torch.stack((pred, target_img, diff, input_img), dim=0)
    save_image(img_tensor, f'{self.args.save_dir}sample/{sample_name}', nrow=1)
def __iter__(self):
    # Sample random z
    z = torch.randn(self.batch_size, 512, device=self.device)

    # Generate image
    with torch.no_grad():
        img, _ = self.g([z], truncation=0.9, truncation_latent=self.g.latent_avg)

    # Resize from 1024 to 256
    if self.downsample:
        img = downsample_256(img)

    yield {'x': img}
def get_loss(self, pred, target_latent, target_image, validate=False):
    # Latent MSE on layers 4:8, weighted by the latent mask
    latent_mse = F.mse_loss(pred[:, 4:8], target_latent[:, 4:8], reduction='none')
    latent_mse *= self.mse_mask
    latent_mse = latent_mse.mean()

    # Reconstruct image
    pred_img = self.g([pred], input_is_latent=True, noise=self.g.noises)[0]
    pred_img = utils.downsample_256(pred_img)

    # Image loss (masked)
    if self.args.image_loss_type == 'lpips':
        l1_loss = self.lpips(pred_img * self.image_mask,
                             target_image * self.image_mask).mean()
    elif self.args.image_loss_type == 'l1':
        l1_loss = F.l1_loss(pred_img, target_image, reduction='none')
        l1_loss *= self.image_mask
        l1_loss = l1_loss.sum() / self.image_mask.sum()
    else:
        raise NotImplementedError

    loss = self.args.latent_loss_weight * latent_mse + \
        self.args.photometric_loss_weight * l1_loss

    return {'loss': loss, 'latent_mse': latent_mse, 'image_l1': l1_loss}
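# Illustration of the masked L1 reduction used in the 'l1' branch above, with toy
# tensors (the real image_mask is provided by the surrounding class): the loss is
# summed over masked elements and normalized by the number of masked elements.
import torch
import torch.nn.functional as F

pred_img = torch.rand(1, 3, 4, 4)
target_image = torch.rand(1, 3, 4, 4)
image_mask = torch.zeros(1, 3, 4, 4)
image_mask[..., 2:, :] = 1.                # e.g. only penalize the lower half

l1 = F.l1_loss(pred_img, target_image, reduction='none')
l1 = l1 * image_mask
l1 = l1.sum() / image_mask.sum()           # mean over the masked region only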
def step(self):
    # Optimization progress in [0, 1]
    t = self.cur_step / self.num_steps

    # Add noise to dlatents
    noise_strength = self.latent_std * self.initial_noise_factor * \
        max(0.0, 1.0 - t / self.noise_ramp_length) ** 2
    latent_noise = (torch.randn_like(self.latent_in) * noise_strength).to(
        self.device)
    self.latent_expr = self.latent_in + latent_noise

    # Update learning rate
    self.update_lr(t)

    # Generate image
    self.img_gen = self.g_ema([self.latent_expr.unsqueeze(0)],
                              input_is_latent=True, noise=self.g_ema.noises)[0]

    # Downsample to 256 x 256
    self.img_gen = utils.downsample_256(self.img_gen)

    # Compute perceptual loss
    self.loss = self.lpips(self.img_gen, self.target_image).sum()

    # Additional MSE loss
    if self.mse_strength:
        self.loss += F.mse_loss(self.img_gen, self.target_image) * self.mse_strength

    # Noise regularization (disabled)
    # reg_loss = self.noise_regularization()
    # self.loss += reg_loss * self.regularize_noise_weight

    # Update params
    self.opt.zero_grad()
    self.loss.backward()
    self.opt.step()
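# The noise_strength schedule above mirrors the StyleGAN2 projector: latent noise
# starts at latent_std * initial_noise_factor and decays quadratically to zero over
# the first noise_ramp_length fraction of the optimization. Illustrative values only;
# the class's actual hyperparameters are defined elsewhere.
latent_std = 1.0
initial_noise_factor = 0.05
noise_ramp_length = 0.75

for t in (0.0, 0.25, 0.5, 0.75, 1.0):
    noise_strength = latent_std * initial_noise_factor * \
        max(0.0, 1.0 - t / noise_ramp_length) ** 2
    print(f"t={t:.2f} -> noise_strength={noise_strength:.4f}")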
def encode_frames(root_path):
    if root_path[-1] != '/':
        root_path += '/'

    # Collect all frames from all video subdirectories
    videos = sorted(glob(root_path + '*/'))
    videos = [sorted(glob(v + '*.png')) for v in videos]
    all_frames = [item for sublist in videos for item in sublist]
    assert len(all_frames) > 0
    print(f"Found {len(all_frames)} frames")

    # Select device
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Load encoder
    from my_models.models import resnetEncoder
    e = resnetEncoder(net=18).eval().to(device)
    # checkpoint = torch.load("PATH_HERE", map_location=device)
    checkpoint = torch.load("/mnt/sdb1/meissen/Networks/GRID_new.pt",
                            map_location=device)
    if type(checkpoint) == dict:
        e.load_state_dict(checkpoint['model'])
    else:
        e.load_state_dict(checkpoint)

    # Get latent avg
    from my_models.style_gan_2 import PretrainedGenerator1024
    g = PretrainedGenerator1024().eval()
    latent_avg = g.latent_avg.view(1, -1).repeat(18, 1)

    # Transforms
    t = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ])

    for frame in tqdm(all_frames):
        save_path = frame.split('.')[0] + '.latent.pt'
        if os.path.exists(save_path):
            continue

        # Load image
        img = t(Image.open(frame)).unsqueeze(0).to(device)

        # Encode image
        with torch.no_grad():
            latent_offset = e(img)[0].cpu()
            latent = latent_offset + latent_avg

        # Save
        torch.save(latent, save_path)
def test_video(self, test_latent_path, test_sentence_path, audio_file_path,
               frames=-1, mode=""):
    print(f"Testing {test_sentence_path}\n{audio_file_path}\n")
    self.audio_encoder.eval()

    if test_sentence_path[-1] != '/':
        test_sentence_path += '/'

    test_latent = torch.load(test_latent_path).unsqueeze(0).to(self.device)
    aux_input = test_latent[:, 4:8]

    sentence_name = test_sentence_path.split('/')[-2]

    # Load audio features
    audio_type = 'deepspeech' if self.args.audio_type == 'deepspeech-synced' \
        else self.args.audio_type
    audio_paths = sorted(glob(test_sentence_path + f'*.{audio_type}.npy'))[:frames]
    audios = torch.stack([
        torch.tensor(np.load(p), dtype=torch.float32) for p in audio_paths
    ]).to(self.device)

    # Pad audio features and build sliding windows of length T
    pad = self.args.T // 2
    audios = F.pad(audios, (0, 0, 0, 0, pad, pad - 1), 'constant', 0.)
    audios = audios.unfold(0, self.args.T, 1).permute(0, 3, 1, 2)

    target_latent_paths = sorted(glob(test_sentence_path + '*.latent.pt'))[:frames]
    target_latents = torch.stack(
        [torch.load(p) for p in target_latent_paths]).to(self.device)

    pbar = tqdm(total=len(target_latents))
    video = []

    # Generate
    for i, (audio, target_latent) in enumerate(zip(audios, target_latents)):
        audio = audio.unsqueeze(0)
        target_latent = target_latent.unsqueeze(0)

        with torch.no_grad():
            input_latent = test_latent.clone()
            latent = self.forward(audio, input_latent, aux_input)

            # Generate images
            pred = self.g([latent], input_is_latent=True, noise=self.g.noises)[0]
            # target_img = self.g([target_latent], input_is_latent=True, noise=self.g.noises)[0]

            # Downsample
            pred = utils.downsample_256(pred)
            # target_img = utils.downsample_256(target_img)

        pbar.update()

        # Normalize
        pred = make_grid(pred.cpu(), normalize=True, range=(-1, 1))
        # target_img = make_grid(target_img.cpu(), normalize=True, range=(-1, 1))
        # diff = (target_img - pred) * 5
        # save_tensor = torch.stack((pred, target_img, diff), dim=0)
        # video.append(make_grid(save_tensor))
        video.append(make_grid(pred))

    # Save frames as video
    video = torch.stack(video, dim=0)
    video_name = f"{self.args.save_dir}results/{mode}{sentence_name}"
    os.makedirs(f"{self.args.save_dir}results", exist_ok=True)
    utils.write_video(f'{video_name}.mp4', video, fps=25)

    # Add audio
    p = Popen([
        'ffmpeg', '-y', '-i', f'{video_name}.mp4', '-i', audio_file_path,
        '-codec', 'copy', '-shortest', f'{video_name}.mov'
    ], stdout=PIPE, stderr=PIPE)
    output, error = p.communicate()
    if p.returncode != 0:
        print("Adding audio from %s to video %s failed with error\n%d %s %s" % (
            audio_file_path, f'{video_name}.mp4', p.returncode, output, error))
    os.system(f"rm {video_name}.mp4")

    self.audio_encoder.train()
def __call__(self, test_latent, test_sentence_path, direction=None,
             use_landmark_input=False, audio_multiplier=2.0,
             audio_truncation=0.8, direction_multiplier=1.0, max_sec=None):
    # Load test latent
    if type(test_latent) is str:
        test_latent = torch.load(test_latent).unsqueeze(0).to(self.device)
    else:
        test_latent = test_latent.unsqueeze(0).to(self.device)
    if test_latent.shape[1] == 1:
        test_latent = test_latent.repeat(1, 18, 1)

    # Auxiliary input
    aux_input = test_latent[:, 4:8]

    # Load audio features
    audio_paths = sorted(glob(test_sentence_path + f'*.{self.audio_type}.npy'))
    audios = torch.stack([torch.tensor(np.load(p), dtype=torch.float32)
                          for p in audio_paths]).to(self.device)
    if max_sec is not None:
        max_frames = 25 * max_sec
        audios = audios[:max_frames]

    # Pad audio features and build sliding windows of length T
    pad = self.T // 2
    audios = F.pad(audios, (0, 0, 0, 0, pad, pad - 1), 'constant', 0.)
    audios = audios.unfold(0, self.T, 1).permute(0, 3, 1, 2)

    # Load direction if provided
    if direction is not None:
        if type(direction) is str:
            ext = direction.split('.')[-1]
            if ext == 'npy':
                direction = torch.tensor(
                    np.load(direction),
                    dtype=torch.float32).unsqueeze(0).to(self.device)
            elif ext == 'pt':
                direction = torch.load(direction).unsqueeze(0).to(self.device)
            else:
                raise RuntimeError(f"Unknown direction file extension '{ext}'")
        else:
            direction = direction.unsqueeze(0).to(self.device)

    video = []

    # Generate
    for i, audio in enumerate(audios):
        audio = audio.unsqueeze(0)

        with torch.no_grad():
            input_latent = test_latent.clone()
            latent = self.forward(audio, input_latent, aux_input, direction,
                                  audio_multiplier=audio_multiplier,
                                  audio_truncation=audio_truncation,
                                  direction_multiplier=direction_multiplier)

            # Generate image
            pred = self.g([latent], input_is_latent=True, noise=self.g.noises)[0]

            # Downsample
            pred = utils.downsample_256(pred)

        # Normalize
        pred = make_grid(pred.cpu(), normalize=True, range=(-1, 1))
        video.append(pred)

    return torch.stack(video)
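# The pad/unfold pattern used in test_video and __call__ turns a sequence of
# per-frame audio features into one sliding window of T frames per video frame.
# A minimal sketch with illustrative shapes (e.g. (16, 29) DeepSpeech features;
# the real feature shape depends on the audio type used):
import torch
import torch.nn.functional as F

T = 8
audios = torch.randn(10, 16, 29)                       # 10 frames of features

pad = T // 2
audios = F.pad(audios, (0, 0, 0, 0, pad, pad - 1), 'constant', 0.)
windows = audios.unfold(0, T, 1).permute(0, 3, 1, 2)   # -> (10, T, 16, 29)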