def translate_using_latent(nets, args, x_src, y_trg_list, z_trg_list, psi,
                           filename, n_images=100):
    """Translate ``x_src`` with latent-code styles and save all results as one grid.

    For each target label in ``y_trg_list``:

    1. ``n_images`` random latent codes are passed through
       ``nets.mapping_network`` to estimate that domain's average style.
    2. Every latent code in ``z_trg_list`` is mapped to a style for the domain.
    3. That style is pulled toward the average with
       ``porch.lerp(s_avg, s_trg, psi)``.
    4. The result is fed to ``nets.generator``.

    The source batch, followed by every generated batch, is concatenated
    along dim 0 and written with ``save_image(x_concat, N, filename)``.

    Args:
        nets: container exposing ``fan``, ``mapping_network`` and ``generator``.
        args: options object. ``args.w_hpf > 0`` enables heatmap masks
            computed by ``nets.fan.get_heatmap``.
        x_src: source image batch. It is unpacked as ``N, C, H, W``, so it
            must be 4-D. Its ``stop_gradient`` flag is set to True.
        y_trg_list: iterable of target-label tensors. Only ``y_trg[0]`` of
            each entry is used to build the labels for the averaging pass.
        z_trg_list: list of latent-code tensors. The latent dimension is read
            from ``z_trg_list[0].shape[1]``.
        psi: weight passed to ``porch.lerp(s_avg, s_trg, psi)``. Assuming
            torch.lerp semantics, 0 gives the domain average and 1 the raw
            sampled style.
        filename: output path handed to ``save_image``.
        n_images: number of random latent codes used to estimate each
            domain's average style. Defaults to 100, the previously
            hard-coded value.

    Raises:
        ValueError: if ``n_images`` is smaller than 1. An average over zero
            samples is undefined.
    """
    n_images = int(n_images)
    if n_images < 1:
        raise ValueError(f"n_images must be >= 1, got {n_images}")

    x_src.stop_gradient = True
    N, _, _, _ = x_src.shape  # the 4-way unpack enforces a 4-D input batch
    latent_dim = z_trg_list[0].shape[1]
    x_concat = [x_src]
    masks = nets.fan.get_heatmap(x_src) if args.w_hpf > 0 else None

    for y_trg in y_trg_list:
        z_many = porch.randn(n_images, latent_dim)
        # Repeat this domain's label n_images times.
        # NOTE(review): np.empty(...).fill(...) produces float64 labels, although
        # the original intent was porch.LongTensor(...).fill_(y_trg[0]).
        # Confirm that mapping_network accepts float labels before changing
        # the dtype.
        y_many = np.empty([n_images])
        y_many.fill(y_trg[0].numpy()[0])
        y_many = to_variable(y_many)

        s_many = nets.mapping_network(z_many, y_many)
        s_avg = porch.mean(s_many, dim=0, keepdim=True)
        # Broadcast the averaged style to one row per source image.
        s_avg = s_avg.repeat(N, 1)

        for z_trg in z_trg_list:
            s_trg = nets.mapping_network(z_trg, y_trg)
            # Pull the sampled style toward the domain average by psi.
            s_trg = porch.lerp(s_avg, s_trg, psi)
            x_fake = nets.generator(x_src, s_trg, masks=masks)
            x_concat.append(x_fake)

    x_concat = porch.cat(x_concat, dim=0)
    save_image(x_concat, N, filename)
def video_latent(nets, args, x_src, y_list, z_list, psi, fname,
                 n_samples=10000, n_hold_frames=10):
    """Render and save a video that morphs between latent-code styles.

    For every label in ``y_list``:

    - An average style is estimated from ``n_samples`` random latent codes.
    - Each code in ``z_list`` is mapped to a style and pulled toward that
      average with ``porch.lerp(s_avg, s_trg, psi)``.

    Consecutive styles within the same domain are then rendered with
    ``interpolate``. The chain restarts at the first style of each domain,
    so no frames interpolate across domains. The last frame is repeated
    ``n_hold_frames`` times before the clip is written with ``save_video``.

    Args:
        nets: container exposing ``mapping_network`` (plus whatever
            ``interpolate`` uses).
        args: options object forwarded to ``interpolate``.
        x_src: source image batch. Its ``stop_gradient`` flag is set to True.
        y_list: iterable of target-label tensors. ``y_trg[0]`` of each one
            labels the averaging pass.
        z_list: list of latent-code tensors. The latent dimension is read
            from ``z_list[0].size(1)``.
        psi: weight passed to ``porch.lerp(s_avg, s_trg, psi)``.
        fname: output path handed to ``save_video``.
        n_samples: random latent codes per domain used for the average style.
            Defaults to 10000, the previously hard-coded value.
        n_hold_frames: how many times the final frame is appended at the end.
            Defaults to 10, the previously hard-coded value.

    Raises:
        ValueError: if no frames could be interpolated, e.g. ``z_list`` has
            fewer than two entries or ``y_list`` is empty. This previously
            surfaced as an ``UnboundLocalError`` on ``frames``.
    """
    x_src.stop_gradient = True
    latent_dim = z_list[0].size(1)
    n_per_domain = len(z_list)

    s_list = []
    for y_trg in y_list:
        z_many = porch.randn(n_samples, latent_dim)
        # NOTE(review): translate_using_latent builds these labels with numpy
        # instead of LongTensor.fill_; confirm this path works under porch.
        y_many = porch.LongTensor(n_samples).fill_(y_trg[0])
        s_many = nets.mapping_network(z_many, y_many)
        s_avg = porch.mean(s_many, dim=0, keepdim=True)
        s_avg = s_avg.repeat(x_src.size(0), 1)
        for z_trg in z_list:
            s_trg = nets.mapping_network(z_trg, y_trg)
            s_trg = porch.lerp(s_avg, s_trg, psi)
            s_list.append(s_trg)

    s_prev = None
    video = []
    frames = None
    # Fetch reference styles and render the transition between neighbours.
    for idx_ref, s_next in enumerate(tqdm(s_list, 'video_latent', len(s_list))):
        # Start a new chain at the very first style and at the first style of
        # each domain, so the video never interpolates across domains.
        if s_prev is None or idx_ref % n_per_domain == 0:
            s_prev = s_next
            continue
        frames = interpolate(nets, args, x_src, s_prev, s_next).cpu()
        video.append(frames)
        s_prev = s_next

    if frames is None:
        raise ValueError(
            "video_latent needs at least one domain and at least two latent "
            f"codes per domain to interpolate; got len(z_list)={n_per_domain}"
        )

    # Hold the final frame so the clip does not end abruptly.
    for _ in range(n_hold_frames):
        video.append(frames[-1:])
    video = tensor2ndarray255(porch.cat(video))
    save_video(fname, video)
def interpolate(nets, args, x_src, s_prev, s_next):
    """Render one grid frame per interpolation weight between two styles.

    For each weight returned by ``get_alphas()``:

    1. The style is blended with ``porch.lerp(s_prev, s_next, weight)``.
    2. ``nets.generator`` produces a translated batch from that style.
    3. The source and translated batches are concatenated along dim 2. This
       stacks them vertically, assuming NCHW layout (TODO confirm).
    4. ``porchvision.utils.make_grid`` lays the batch out in a single row of
       ``x_src.shape[0]`` tiles, with no padding.

    Returns:
        Tensor of shape T x C x H x W, where T == len(get_alphas()). The
        frames are concatenated along dim 0 after each grid is unsqueezed.
    """
    batch_size = x_src.shape[0]
    grid_frames = []

    heatmaps = None
    if args.w_hpf > 0:
        heatmaps = nets.fan.get_heatmap(x_src)

    for weight in get_alphas():
        blended_style = porch.lerp(s_prev, s_next, weight)
        translated = nets.generator(x_src, blended_style, masks=heatmaps)
        stacked = porch.cat([x_src, translated], dim=2)
        grid = porchvision.utils.make_grid(
            stacked, nrow=batch_size, padding=0, pad_value=-1
        )
        grid_frames.append(grid.unsqueeze(0))

    return porch.cat(grid_frames)
def moving_average(model, model_test, beta=0.999):
    """Blend ``model``'s parameters into ``model_test``'s parameters.

    Parameters are paired positionally with ``zip``, so pairing stops at the
    shorter parameter list. For each pair, ``porch.lerp(live, ema, beta)`` is
    computed. Assuming torch.lerp semantics, that is
    ``live + beta * (ema - live)``: an exponential moving average that keeps
    a ``beta`` share of the existing ``model_test`` value. The result is
    handed to ``porch.copy(..., ema)``, which presumably writes it into the
    ``model_test`` parameter in place (TODO confirm against porch).
    """
    paired_params = zip(model.parameters(), model_test.parameters())
    for live, ema in paired_params:
        blended = porch.lerp(live, ema, beta)
        porch.copy(blended, ema)