Пример #1
0
def translate_using_latent(nets, args, x_src, y_trg_list, z_trg_list, psi,
                           filename, n_images=100):
    """Save a grid of source images translated with latent-sampled styles.

    For every target domain in ``y_trg_list`` an average style code is
    estimated by mapping ``n_images`` random latent codes through
    ``nets.mapping_network``. Each latent in ``z_trg_list`` is then mapped to
    a style and blended towards that average with ``porch.lerp(s_avg, s_trg,
    psi)`` (presumably torch.lerp semantics: psi=0 -> average style,
    psi=1 -> sampled style). The blended style is fed to the generator.

    Args:
        nets: object exposing ``fan``, ``mapping_network`` and ``generator``.
        args: options object; ``args.w_hpf > 0`` enables FAN heatmap masks.
        x_src: batch of source images; ``x_src.shape[0]`` is the batch size.
        y_trg_list: target-domain label tensors, one per domain.
        z_trg_list: latent code tensors; ``shape[1]`` is the latent size.
        psi: truncation weight passed to ``porch.lerp``.
        filename: output path handed to ``save_image``.
        n_images: number of random latents used to estimate each domain's
            average style (defaults to the previously hard-coded 100).
    """
    x_src.stop_gradient = True
    N = x_src.shape[0]
    latent_dim = z_trg_list[0].shape[1]
    x_concat = [x_src]
    masks = nets.fan.get_heatmap(x_src) if args.w_hpf > 0 else None

    for y_trg in y_trg_list:
        z_many = porch.randn(n_images, latent_dim)
        # Domain labels are integer indices (cf. the LongTensor used in
        # video_latent). np.empty would default to float64, so build the
        # repeated label explicitly as int64.
        label = y_trg[0].numpy()[0]
        y_many = to_variable(np.full([n_images], label, dtype=np.int64))
        s_many = nets.mapping_network(z_many, y_many)
        # Average style for this domain, tiled to the source batch size.
        s_avg = porch.mean(s_many, dim=0, keepdim=True)
        s_avg = s_avg.repeat(N, 1)

        for z_trg in z_trg_list:
            s_trg = nets.mapping_network(z_trg, y_trg)
            s_trg = porch.lerp(s_avg, s_trg, psi)
            x_fake = nets.generator(x_src, s_trg, masks=masks)
            x_concat += [x_fake]

    x_concat = porch.cat(x_concat, dim=0)
    save_image(x_concat, N, filename)
Пример #2
0
def video_latent(nets, args, x_src, y_list, z_list, psi, fname):
    """Render and save a video interpolating between latent-sampled styles.

    For each domain label in ``y_list``, an average style is estimated from
    10000 random latents. Each latent in ``z_list`` is mapped to a style and
    blended towards that average with ``porch.lerp(s_avg, s_trg, psi)``.
    Consecutive styles *within the same domain* are interpolated with
    ``interpolate``; no transition is rendered across a domain boundary.
    The last frame is repeated 10 extra times before saving.

    Args:
        nets: object exposing ``mapping_network`` (and whatever
            ``interpolate`` needs).
        args: options object forwarded to ``interpolate``.
        x_src: batch of source images; ``x_src.size(0)`` is the batch size.
        y_list: target-domain label tensors, one per domain.
        z_list: latent code tensors; ``size(1)`` is the latent size.
        psi: truncation weight passed to ``porch.lerp``.
        fname: output path handed to ``save_video``.

    Raises:
        ValueError: if there is no within-domain pair of styles to
            interpolate (empty ``y_list`` or ``len(z_list) < 2``). Such input
            previously crashed with UnboundLocalError on ``frames``.
    """
    x_src.stop_gradient = True

    latent_dim = z_list[0].size(1)
    s_list = []
    for y_trg in y_list:
        z_many = porch.randn(10000, latent_dim)
        y_many = porch.LongTensor(10000).fill_(y_trg[0])
        s_many = nets.mapping_network(z_many, y_many)
        s_avg = porch.mean(s_many, dim=0, keepdim=True)
        s_avg = s_avg.repeat(x_src.size(0), 1)

        for z_trg in z_list:
            s_trg = nets.mapping_network(z_trg, y_trg)
            s_trg = porch.lerp(s_avg, s_trg, psi)
            s_list.append(s_trg)

    styles_per_domain = len(z_list)
    s_prev = None
    video = []
    for idx_ref, s_next in enumerate(tqdm(s_list, 'video_latent',
                                          len(s_list))):
        # The first style of every domain (idx_ref == 0 included) only seeds
        # s_prev, so no transition is rendered across a domain boundary.
        if s_prev is None or idx_ref % styles_per_domain == 0:
            s_prev = s_next
            continue
        frames = interpolate(nets, args, x_src, s_prev, s_next).cpu()
        video.append(frames)
        s_prev = s_next

    if not video:
        raise ValueError(
            "video_latent needs a non-empty y_list and at least two latent "
            "codes in z_list to produce any interpolation frames")

    # Hold the final frame for 10 extra frames at the end of the clip.
    last_frame = video[-1][-1:]
    video.extend([last_frame] * 10)
    video = tensor2ndarray255(porch.cat(video))
    save_video(fname, video)
Пример #3
0
def interpolate(nets, args, x_src, s_prev, s_next):
    """Render one frame per alpha from ``get_alphas()`` between two styles.

    For each alpha, the style ``porch.lerp(s_prev, s_next, alpha)`` drives
    ``nets.generator`` on ``x_src``. The source and generated batches are
    concatenated along dim 2, and the result is laid out with ``make_grid``
    (one row of ``batch_size`` images, no padding). Each grid gets a leading
    axis, and all grids are concatenated along dim 0.

    Returns:
        Tensor of stacked frames, T x C x H x W in the original author's
        notation, where T == len(get_alphas()). NOTE(review): assumes NCHW
        input, in which case each frame is twice the source height — confirm.
    """
    batch_size = x_src.shape[0]
    heatmaps = None
    if args.w_hpf > 0:
        heatmaps = nets.fan.get_heatmap(x_src)

    def _render_frame(alpha):
        # Blend the two style codes, generate, and tile source above fake.
        style = porch.lerp(s_prev, s_next, alpha)
        generated = nets.generator(x_src, style, masks=heatmaps)
        stacked = porch.cat([x_src, generated], dim=2)
        grid = porchvision.utils.make_grid(stacked,
                                           nrow=batch_size,
                                           padding=0,
                                           pad_value=-1)
        return grid.unsqueeze(0)

    rendered = [_render_frame(alpha) for alpha in get_alphas()]
    return porch.cat(rendered)
Пример #4
0
def moving_average(model, model_test, beta=0.999):
    """Update ``model_test``'s parameters towards ``model``'s in place.

    Parameters are paired positionally via ``zip`` of the two
    ``parameters()`` sequences. Each averaged parameter is replaced by
    ``porch.lerp(live, averaged, beta)``. Assuming torch.lerp semantics, that
    is ``beta * averaged + (1 - beta) * live``, an exponential moving average.
    NOTE(review): assumes ``porch.copy(src, dst)`` writes ``src`` into
    ``dst`` — confirm.
    """
    live_params = model.parameters()
    averaged_params = model_test.parameters()
    for live, averaged in zip(live_params, averaged_params):
        blended = porch.lerp(live, averaged, beta)
        porch.copy(blended, averaged)