Example 1
def transfer(config, generator, kp_detector, checkpoint, log_dir, dataset):
    log_dir = os.path.join(log_dir, 'transfer')
    png_dir = os.path.join(log_dir, 'png')
    transfer_params = config['transfer_params']

    dataset = PairedDataset(initial_dataset=dataset,
                            number_of_pairs=transfer_params['num_pairs'])
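    # PairedDataset draws num_pairs (driving, source) video pairs from the
    # underlying dataset; each batch below holds one such pair.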
    dataloader = DataLoader(dataset,
                            batch_size=1,
                            shuffle=False,
                            num_workers=1)

    if checkpoint is not None:
        Logger.load_cpk(checkpoint,
                        generator=generator,
                        kp_detector=kp_detector)
    else:
        raise AttributeError(
            "Checkpoint should be specified for mode='transfer'.")

    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    if not os.path.exists(png_dir):
        os.makedirs(png_dir)

    generator = DataParallelWithCallback(generator)
    kp_detector = DataParallelWithCallback(kp_detector)

    generator.eval()
    kp_detector.eval()

    for it, x in tqdm(enumerate(dataloader)):
        with torch.no_grad():
            x = {
                key: value if not hasattr(value, 'cuda') else value.cuda()
                for key, value in x.items()
            }
            driving_video = x['driving_video']
            source_image = x['source_video'][:, :, :1, :, :]
            out = transfer_one(generator, kp_detector, source_image,
                               driving_video, transfer_params)
            img_name = "-".join([x['driving_name'][0], x['source_name'][0]])

            # Store to .png for evaluation
            out_video_batch = out['video_prediction'].data.cpu().numpy()
            out_video_batch = np.concatenate(np.transpose(
                out_video_batch, [0, 2, 3, 4, 1])[0],
                                             axis=1)
            imageio.imsave(os.path.join(png_dir, img_name + '.png'),
                           (255 * out_video_batch).astype(np.uint8))

            image = Visualizer(
                **config['visualizer_params']).visualize_transfer(
                    driving_video=driving_video,
                    source_image=source_image,
                    out=out)
            imageio.mimsave(
                os.path.join(log_dir, img_name + transfer_params['format']),
                image)
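The .png written above is a horizontal strip of frames. A self-contained NumPy sketch of that layout (array sizes are illustrative, not from the source):

import numpy as np

# A (B, C, T, H, W) prediction is transposed to (T, H, W, C), then the T
# frames are concatenated side by side along the width axis.
video = np.random.rand(1, 3, 4, 64, 64)           # B, C, T, H, W
frames = np.transpose(video, [0, 2, 3, 4, 1])[0]  # T, H, W, C
strip = np.concatenate(frames, axis=1)            # H, T*W, C
print(strip.shape)                                # (64, 256, 3)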
Example 2
def animate(config, generator, kp_detector, checkpoint, log_dir, dataset):
    log_dir = os.path.join(log_dir, 'animation')
    png_dir = os.path.join(log_dir, 'png')
    animate_params = config['animate_params']

    dataset = PairedDataset(initial_dataset=dataset, number_of_pairs=animate_params['num_pairs'])
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)

    if checkpoint is not None:
        Logger.load_cpk(checkpoint, generator=generator, kp_detector=kp_detector)
    else:
        raise AttributeError("Checkpoint should be specified for mode='animate'.")

    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    if not os.path.exists(png_dir):
        os.makedirs(png_dir)

    if torch.cuda.is_available():
        generator = DataParallelWithCallback(generator)
        kp_detector = DataParallelWithCallback(kp_detector)

    generator.eval()
    kp_detector.eval()

    for it, x in tqdm(enumerate(dataloader)):
        with torch.no_grad():
            predictions = []
            visualizations = []

            driving_video = x['driving_video']
            source_frame = x['source_video'][:, :, 0, :, :]

            kp_source = kp_detector(source_frame)
            kp_driving_initial = kp_detector(driving_video[:, :, 0])

            for frame_idx in range(driving_video.shape[2]):
                driving_frame = driving_video[:, :, frame_idx]
                kp_driving = kp_detector(driving_frame)
                kp_norm = normalize_kp(kp_source=kp_source, kp_driving=kp_driving,
                                       kp_driving_initial=kp_driving_initial, **animate_params['normalization_params'])
                out = generator(source_frame, kp_source=kp_source, kp_driving=kp_norm)

                out['kp_driving'] = kp_driving
                out['kp_source'] = kp_source
                out['kp_norm'] = kp_norm

                del out['sparse_deformed']

                predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])

                visualization = Visualizer(**config['visualizer_params']).visualize(source=source_frame,
                                                                                    driving=driving_frame, out=out)
                visualizations.append(visualization)

            predictions = np.concatenate(predictions, axis=1)
            result_name = "-".join([x['driving_name'][0], x['source_name'][0]])
            imageio.imsave(os.path.join(png_dir, result_name + '.png'), (255 * predictions).astype(np.uint8))

            image_name = result_name + animate_params['format']
            imageio.mimsave(os.path.join(log_dir, image_name), visualizations)
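normalize_kp retargets the driving motion onto the source keypoints. Its implementation is not shown in this example; the sketch below illustrates one common relative formulation (the helper name and the 'mean' dict key are assumptions):

def relative_kp_sketch(kp_source, kp_driving, kp_driving_initial):
    # Apply the displacement of the driving keypoints, measured against the
    # first driving frame, to the source keypoints instead of copying
    # absolute driving positions.
    kp_new = dict(kp_driving)
    kp_new['mean'] = kp_source['mean'] + (kp_driving['mean'] -
                                          kp_driving_initial['mean'])
    return kp_new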
Example 3
def reconstruction(config, generator, kp_detector, checkpoint, log_dir,
                   dataset):
    png_dir = os.path.join(log_dir, 'reconstruction/png')
    log_dir = os.path.join(log_dir, 'reconstruction')

    if checkpoint is not None:
        Logger.load_cpk(checkpoint,
                        generator=generator,
                        kp_detector=kp_detector)
    else:
        raise AttributeError(
            "Checkpoint should be specified for mode='reconstruction'.")
    dataloader = DataLoader(dataset,
                            batch_size=1,
                            shuffle=False,
                            num_workers=1)

    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    if not os.path.exists(png_dir):
        os.makedirs(png_dir)

    loss_list = []
    if torch.cuda.is_available():
        generator = DataParallelWithCallback(generator)
        kp_detector = DataParallelWithCallback(kp_detector)

    generator.eval()
    kp_detector.eval()
    recon_gen_dir = os.path.join(log_dir, 'recon_gen')
    os.makedirs(recon_gen_dir, exist_ok=False)

    for it, x in tqdm(enumerate(dataloader)):
        if config['reconstruction_params']['num_videos'] is not None:
            if it > config['reconstruction_params']['num_videos']:
                break
        with torch.no_grad():
            predictions = []
            visualizations = []
            if torch.cuda.is_available():
                x['video'] = x['video'].cuda()
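            # Frame 0 is the appearance source; every frame of the same video
            # then drives its own reconstruction from that single frame.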
            kp_source = kp_detector(x['video'][:, :, 0])

            video_gen_dir = os.path.join(recon_gen_dir, x['name'][0])
            os.makedirs(video_gen_dir, exist_ok=False)

            for frame_idx in range(x['video'].shape[2]):
                source = x['video'][:, :, 0]
                driving = x['video'][:, :, frame_idx]
                kp_driving = kp_detector(driving)
                out = generator(source,
                                kp_source=kp_source,
                                kp_driving=kp_driving)
                out['kp_source'] = kp_source
                out['kp_driving'] = kp_driving
                predictions.append(
                    np.transpose(out['upscaled_prediction'].data.cpu().numpy(),
                                 [0, 2, 3, 1])[0])

                visualization = Visualizer(
                    **config['visualizer_params']).visualize(source=source,
                                                             driving=driving,
                                                             out=out)
                visualizations.append(visualization)

                loss_list.append(
                    torch.abs(out['upscaled_prediction'] -
                              driving).mean().cpu().numpy())

                frame_name = str(frame_idx).zfill(7) + '.png'
                upscaled_prediction = out['upscaled_prediction'].data.cpu().numpy()
                upscaled_prediction = np.transpose(upscaled_prediction, [0, 2, 3, 1])
                upscaled_prediction = (255 * upscaled_prediction).astype(np.uint8)
                imageio.imsave(os.path.join(video_gen_dir, frame_name),
                               upscaled_prediction[0])

            predictions = np.concatenate(predictions, axis=1)
            imageio.imsave(os.path.join(png_dir, x['name'][0] + '.png'),
                           (255 * predictions).astype(np.uint8))

            image_name = x['name'][0] + config['reconstruction_params'][
                'format']
            imageio.mimsave(os.path.join(log_dir, image_name), visualizations)

    print(f'Reconstruction loss: {np.mean(loss_list)}')
Example 4
    original_img.save(os.path.join(opt.log_dir, 'original_image.png'), "PNG")

    ############## skeletons[0]
    skeletons_kps /= (128 - 1)
    config['model_params']['KTS_params']['edges'] = skeletons['humans']
    kp_to_skl = KPToSkl(**config['model_params']['KTS_params'])
    with torch.no_grad():
        skeleton_images = kp_to_skl(skeletons_kps)
    skeleton_image = (skeleton_images.unsqueeze(-1).repeat(1, 1, 1,
                                                           3).cpu().numpy())
    skeleton_image = (skeleton_image * 255).astype(np.uint8)
    Image.fromarray(skeleton_image[0]).save(
        os.path.join(opt.log_dir, 'single_skeleton.png'), "PNG")

    ############### Kp image
    kps_np_img = np.array(kp_img)
    pair_img = np.concatenate([kps_np_img, skeleton_image[0]], axis=1)
    Image.fromarray(pair_img).save(os.path.join(opt.log_dir, 'kp_skel.png'),
                                   "PNG")

    ############### Grid
    visualizer = Visualizer()
    image_grid = []
    idx = [i * 5 for i in range(images.shape[0] // 5 + 1)]
    for i in range(len(idx) - 1):
        image_grid += [skeleton_image[idx[i]:idx[i + 1]]]
        image_grid += [images[idx[i]:idx[i + 1]]]
    column_img = visualizer.create_image_grid(*image_grid)
    Image.fromarray(column_img).save(
        os.path.join(opt.log_dir, 'column_img.png'), "PNG")
Example 5
        use_dmm_attention=opt.use_dmm_attention,
        use_generator_attention=opt.use_generator_attention)
    #   if not opt.cpu:
    #       generator.cuda()

    kp_detector = KPDetector(**config['model_params']['kp_detector_params'],
                             **config['model_params']['common_params'])
    #   if not opt.cpu:
    #       kp_detector = kp_detector.cuda()

    Logger.load_cpk(opt.checkpoint,
                    generator=generator,
                    kp_detector=kp_detector,
                    use_cpu=False)

    vis = Visualizer()

    # generator = DataParallelWithCallback(generator)
    # kp_detector = DataParallelWithCallback(kp_detector)

    generator.eval()
    kp_detector.eval()

    with torch.no_grad():
        driving_video = VideoToTensor()(read_video(
            opt.driving_video, opt.image_shape + (3, )))['video']
        source_image = VideoToTensor()(read_video(
            opt.source_image, opt.image_shape + (3, )))['video'][:, :1]
        print(source_image.shape)

        driving_video = torch.from_numpy(driving_video).unsqueeze(0)
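        # VideoToTensor returns a (C, T, H, W) float array (hence the [:, :1]
        # slice above to keep a single frame); unsqueeze(0) adds the batch
        # dimension the models expect.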
Example 6
def prediction(config, generator, kp_detector, checkpoint, log_dir):
    dataset = FramesDataset(is_train=True, transform=VideoToTensor(), **config['dataset_params'])
    log_dir = os.path.join(log_dir, 'prediction')
    png_dir = os.path.join(log_dir, 'png')

    if checkpoint is not None:
        Logger.load_cpk(checkpoint, generator=generator, kp_detector=kp_detector)
    else:
        raise AttributeError("Checkpoint should be specified for mode='prediction'.")
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)

    generator = DataParallelWithCallback(generator)
    kp_detector = DataParallelWithCallback(kp_detector)

    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    if not os.path.exists(png_dir):
        os.makedirs(png_dir)

    print("Extracting keypoints...")

    kp_detector.eval()
    generator.eval()

    keypoints_array = []

    prediction_params = config['prediction_params']

    for it, x in tqdm(enumerate(dataloader)):
        if prediction_params['train_size'] is not None:
            if it > prediction_params['train_size']:
                break
        with torch.no_grad():
            keypoints = []
            for i in range(x['video'].shape[2]):
                kp = kp_detector(x['video'][:, :, i:(i + 1)])
                kp = {k: v.data.cpu().numpy() for k, v in kp.items()}
                keypoints.append(kp)
            keypoints_array.append(keypoints)

    predictor = PredictionModule(num_kp=config['model_params']['common_params']['num_kp'],
                                 kp_variance=config['model_params']['common_params']['kp_variance'],
                                 **prediction_params['rnn_params']).cuda()

    num_epochs = prediction_params['num_epochs']
    lr = prediction_params['lr']
    bs = prediction_params['batch_size']
    num_frames = prediction_params['num_frames']
    init_frames = prediction_params['init_frames']
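    # Training scheme: frames after init_frames are zeroed in the input
    # sequence and the RNN is trained with an L1 loss to recover them,
    # i.e. to extrapolate keypoint trajectories from the initial frames.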

    optimizer = torch.optim.Adam(predictor.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, verbose=True, patience=50)

    kp_dataset = KPDataset(keypoints_array, num_frames=num_frames)

    kp_dataloader = DataLoader(kp_dataset, batch_size=bs)

    print("Training prediction...")
    for _ in trange(num_epochs):
        loss_list = []
        for x in kp_dataloader:
            x = {k: v.cuda() for k, v in x.items()}
            gt = {k: v.clone() for k, v in x.items()}
            for k in x:
                x[k][:, init_frames:] = 0
            prediction = predictor(x)

            loss = sum([torch.abs(gt[k][:, init_frames:] - prediction[k][:, init_frames:]).mean() for k in x])

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            loss_list.append(loss.detach().data.cpu().numpy())

        loss = np.mean(loss_list)
        scheduler.step(loss)

    dataset = FramesDataset(is_train=False, transform=VideoToTensor(), **config['dataset_params'])
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)

    print("Make predictions...")
    for it, x in tqdm(enumerate(dataloader)):
        with torch.no_grad():
            x['video'] = x['video'][:, :, :num_frames]
            kp_init = kp_detector(x['video'])
            for k in kp_init:
                kp_init[k][:, init_frames:] = 0

            kp_source = kp_detector(x['video'][:, :, :1])

            kp_video = predictor(kp_init)
            for k in kp_video:
                kp_video[k][:, :init_frames] = kp_init[k][:, :init_frames]
            if 'var' in kp_video and prediction_params['predict_variance']:
                kp_video['var'] = kp_init['var'][:, (init_frames - 1):init_frames].repeat(1, kp_video['var'].shape[1],
                                                                                          1, 1, 1)
            out = generate(generator, appearance_image=x['video'][:, :, :1], kp_appearance=kp_source,
                           kp_video=kp_video)

            x['source'] = x['video'][:, :, :1]

            out_video_batch = out['video_prediction'].data.cpu().numpy()
            out_video_batch = np.concatenate(np.transpose(out_video_batch, [0, 2, 3, 4, 1])[0], axis=1)
            imageio.imsave(os.path.join(png_dir, x['name'][0] + '.png'), (255 * out_video_batch).astype(np.uint8))

            image = Visualizer(**config['visualizer_params']).visualize_reconstruction(x, out)
            image_name = x['name'][0] + prediction_params['format']
            imageio.mimsave(os.path.join(log_dir, image_name), image)

            del x, kp_video, kp_source, out
Example 7
def reconstruction(config, generator, mask_generator, checkpoint, log_dir,
                   dataset):
    png_dir = os.path.join(log_dir, 'reconstruction/png')
    log_dir = os.path.join(log_dir, 'reconstruction')

    if checkpoint is not None:
        epoch = Logger.load_cpk(checkpoint,
                                generator=generator,
                                mask_generator=mask_generator)
        print('checkpoint:' + str(epoch))
    else:
        raise AttributeError(
            "Checkpoint should be specified for mode='reconstruction'.")
    dataloader = DataLoader(dataset,
                            batch_size=1,
                            shuffle=False,
                            num_workers=1)

    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    if not os.path.exists(png_dir):
        os.makedirs(png_dir)

    loss_list = []
    if torch.cuda.is_available():
        generator = DataParallelWithCallback(generator)
        mask_generator = DataParallelWithCallback(mask_generator)

    generator.eval()
    mask_generator.eval()

    recon_gen_dir = './log/recon_gen'
    os.makedirs(recon_gen_dir, exist_ok=False)
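    # exist_ok=False: fail fast rather than silently mixing output with a
    # previous run.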

    for it, x in tqdm(enumerate(dataloader)):
        if config['reconstruction_params']['num_videos'] is not None:
            if it > config['reconstruction_params']['num_videos']:
                break
        with torch.no_grad():
            predictions = []
            visualizations = []
            if torch.cuda.is_available():
                x['video'] = x['video'].cuda()
            mask_source = mask_generator(x['video'][:, :, 0])

            video_gen_dir = os.path.join(recon_gen_dir, x['name'][0])
            os.makedirs(video_gen_dir, exist_ok=False)

            for frame_idx in range(x['video'].shape[2]):
                source = x['video'][:, :, 0]
                driving = x['video'][:, :, frame_idx]
                mask_driving = mask_generator(driving)
                out = generator(source,
                                driving,
                                mask_source=mask_source,
                                mask_driving=mask_driving,
                                mask_driving2=None,
                                animate=False,
                                predict_mask=False)
                out['mask_source'] = mask_source
                out['mask_driving'] = mask_driving

                predictions.append(
                    np.transpose(
                        out['second_phase_prediction'].data.cpu().numpy(),
                        [0, 2, 3, 1])[0])

                visualization = Visualizer(
                    **config['visualizer_params']).visualize(source=source,
                                                             driving=driving,
                                                             target=None,
                                                             out=out,
                                                             driving2=None)
                visualizations.append(visualization)

                loss_list.append(
                    torch.abs(out['second_phase_prediction'] -
                              driving).mean().cpu().numpy())

                frame_name = str(frame_idx).zfill(7) + '.png'
                second_phase_prediction = out['second_phase_prediction'].data.cpu().numpy()
                second_phase_prediction = np.transpose(second_phase_prediction, [0, 2, 3, 1])
                second_phase_prediction = (255 * second_phase_prediction).astype(np.uint8)
                imageio.imsave(os.path.join(video_gen_dir, frame_name),
                               second_phase_prediction[0])

            predictions = np.concatenate(predictions, axis=1)

            image_name = x['name'][0] + config['reconstruction_params'][
                'format']
            imageio.mimsave(os.path.join(log_dir, image_name), visualizations)

    print("Reconstruction loss: %s" % np.mean(loss_list))
Example 8
def animate(config, generator, kp_detector, checkpoint, log_dir, dataset,
            kp_after_softmax):
    log_dir = os.path.join(log_dir, 'animation')
    png_dir = os.path.join(log_dir, 'png')
    animate_params = config['animate_params']
    frame_size = config['dataset_params']['frame_shape'][0]
    latent_size = int(config['model_params']['common_params']['scale_factor'] *
                      frame_size)
    dataset = PairedDataset(initial_dataset=dataset,
                            number_of_pairs=animate_params['num_pairs'])
    dataloader = DataLoader(dataset,
                            batch_size=1,
                            shuffle=False,
                            num_workers=1)

    if checkpoint is not None:
        Logger.load_cpk(checkpoint,
                        generator=generator,
                        kp_detector=kp_detector)
    else:
        raise AttributeError(
            "Checkpoint should be specified for mode='animate'.")

    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    if not os.path.exists(png_dir):
        os.makedirs(png_dir)

    if torch.cuda.is_available():
        generator = DataParallelWithCallback(generator)
        kp_detector = DataParallelWithCallback(kp_detector)

    generator.eval()
    kp_detector.eval()

    for it, x in tqdm(enumerate(dataloader)):
        with torch.no_grad():
            predictions = []
            visualizations = []

            driving_video = x['driving_video']
            source_frame = x['source_video'][:, :, 0, :, :]

            kp_source = kp_detector(source_frame)
            kp_driving_initial = kp_detector(driving_video[:, :, 0])

            for frame_idx in range(driving_video.shape[2]):
                driving_frame = driving_video[:, :, frame_idx]
                kp_driving = kp_detector(driving_frame)

                if kp_after_softmax:
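                    # Normalize keypoints relative to the initial driving
                    # frame, rasterize them into latent-resolution maps
                    # (draw_kp), normalize the maps (norm_mask), and feed the
                    # generator these maps instead of raw keypoints.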
                    kp_norm = normalize_kp(
                        kp_source=kp_source,
                        kp_driving=kp_driving,
                        kp_driving_initial=kp_driving_initial,
                        **animate_params['normalization_params'])
                    kp_source = draw_kp([latent_size, latent_size], kp_source)
                    kp_norm = draw_kp([latent_size, latent_size], kp_norm)
                    kp_source = norm_mask(latent_size, kp_source)
                    kp_norm = norm_mask(latent_size, kp_norm)
                    out = generator(source_frame,
                                    kp_source=kp_source,
                                    kp_driving=kp_norm)
                    kp_norm_int = F.interpolate(kp_norm,
                                                size=source_frame.shape[2:],
                                                mode='bilinear',
                                                align_corners=False)
                    out['kp_norm_int'] = kp_norm_int.repeat(1, 3, 1, 1)
                else:
                    out = generator(source_frame,
                                    kp_source=kp_source,
                                    kp_driving=kp_driving)

                out['kp_driving'] = kp_driving
                out['kp_source'] = kp_source

                predictions.append(
                    np.transpose(out['low_res_prediction'].data.cpu().numpy(),
                                 [0, 2, 3, 1])[0])
                predictions.append(
                    np.transpose(out['upscaled_prediction'].data.cpu().numpy(),
                                 [0, 2, 3, 1])[0])

                visualization = Visualizer(
                    **config['visualizer_params']).visualize(
                        source=source_frame, driving=driving_frame, out=out)
                visualizations.append(visualization)

            predictions = np.concatenate(predictions, axis=1)
            result_name = "-".join([x['driving_name'][0], x['source_name'][0]])
            imageio.imsave(os.path.join(png_dir, result_name + '.png'),
                           (255 * predictions).astype(np.uint8))

            image_name = result_name + animate_params['format']
            imageio.mimsave(os.path.join(log_dir, image_name), visualizations)
Example 9
def reconstruction(config, generator, kp_detector, checkpoint, log_dir,
                   dataset):
    png_dir = os.path.join(log_dir, 'reconstruction/png')
    log_dir = os.path.join(log_dir, 'reconstruction')

    if checkpoint is not None:
        Logger.load_cpk(checkpoint,
                        generator=generator,
                        kp_detector=kp_detector)
    else:
        raise AttributeError("Checkpoint should be specified for mode='test'.")
    dataloader = DataLoader(dataset,
                            batch_size=1,
                            shuffle=False,
                            num_workers=1)

    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    if not os.path.exists(png_dir):
        os.makedirs(png_dir)

    loss_list = []
    generator = DataParallelWithCallback(generator)
    kp_detector = DataParallelWithCallback(kp_detector)

    generator.eval()
    kp_detector.eval()

    cat_dict = lambda l, dim: {
        k: torch.cat([v[k] for v in l], dim=dim)
        for k in l[0]
    }
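    # cat_dict merges a list of per-frame keypoint dicts into a single dict
    # whose tensors are concatenated along `dim`, turning d per-frame
    # detections into one keypoint trajectory over the whole video.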
    for it, x in tqdm(enumerate(dataloader)):
        if config['reconstruction_params']['num_videos'] is not None:
            if it > config['reconstruction_params']['num_videos']:
                break
        with torch.no_grad():
            kp_appearance = kp_detector(x['video'][:, :, :1])
            d = x['video'].shape[2]
            kp_video = cat_dict(
                [kp_detector(x['video'][:, :, i:(i + 1)]) for i in range(d)],
                dim=1)

            out = generate(generator,
                           appearance_image=x['video'][:, :, :1],
                           kp_appearance=kp_appearance,
                           kp_video=kp_video)
            x['source'] = x['video'][:, :, :1]

            # Store to .png for evaluation
            out_video_batch = out['video_prediction'].data.cpu().numpy()
            out_video_batch = np.concatenate(np.transpose(
                out_video_batch, [0, 2, 3, 4, 1])[0],
                                             axis=1)
            imageio.imsave(os.path.join(png_dir, x['name'][0] + '.png'),
                           (255 * out_video_batch).astype(np.uint8))

            image = Visualizer(
                **config['visualizer_params']).visualize_reconstruction(
                    x, out)
            image_name = x['name'][0] + config['reconstruction_params'][
                'format']
            imageio.mimsave(os.path.join(log_dir, image_name), image)

            loss = reconstruction_loss(out['video_prediction'].cpu(),
                                       x['video'].cpu(), 1)
            loss_list.append(loss.data.cpu().numpy())
            del x, kp_video, kp_appearance, out, loss
    print("Reconstruction loss: %s" % np.mean(loss_list))
Example 10
def animate(config, generator, region_predictor, avd_network, checkpoint,
            log_dir, dataset):
    animate_params = config['animate_params']
    log_dir = os.path.join(log_dir, 'animation')

    dataset = PairedDataset(initial_dataset=dataset,
                            number_of_pairs=animate_params['num_pairs'])
    dataloader = DataLoader(dataset,
                            batch_size=1,
                            shuffle=False,
                            num_workers=1)

    if checkpoint is not None:
        Logger.load_cpk(checkpoint,
                        generator=generator,
                        region_predictor=region_predictor,
                        avd_network=avd_network)
    else:
        raise AttributeError(
            "Checkpoint should be specified for mode='animate'.")

    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    if torch.cuda.is_available():
        generator = DataParallelWithCallback(generator)
        region_predictor = DataParallelWithCallback(region_predictor)
        avd_network = DataParallelWithCallback(avd_network)

    generator.eval()
    region_predictor.eval()
    avd_network.eval()

    for it, x in tqdm(enumerate(dataloader)):
        with torch.no_grad():
            visualizations = []

            driving_video = x['driving_video']
            source_frame = x['source_video'][:, :, 0, :, :]

            source_region_params = region_predictor(source_frame)
            driving_region_params_initial = region_predictor(
                driving_video[:, :, 0])

            for frame_idx in range(driving_video.shape[2]):
                driving_frame = driving_video[:, :, frame_idx]
                driving_region_params = region_predictor(driving_frame)
                new_region_params = get_animation_region_params(
                    source_region_params,
                    driving_region_params,
                    driving_region_params_initial,
                    mode=animate_params['mode'],
                    avd_network=avd_network)
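                # animate_params['mode'] selects how driving motion is
                # retargeted to the source (e.g. absolute transfer, relative
                # motion, or the learned avd_network passed above).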
                out = generator(source_frame,
                                source_region_params=source_region_params,
                                driving_region_params=new_region_params)

                out['driving_region_params'] = driving_region_params
                out['source_region_params'] = source_region_params
                out['new_region_params'] = new_region_params

                visualization = Visualizer(
                    **config['visualizer_params']).visualize(
                        source=source_frame, driving=driving_frame, out=out)
                visualizations.append(visualization)

            result_name = "-".join([x['driving_name'][0], x['source_name'][0]])
            image_name = result_name + animate_params['format']
            imageio.mimsave(os.path.join(log_dir, image_name), visualizations)
Example 11
        blocks_discriminator = config['model_params']['discriminator_params']['num_blocks']
        assert len(config['train_params']['loss_weights']['reconstruction']) == blocks_discriminator + 1

    generator = MotionTransferGenerator(**config['model_params']['generator_params'],
                                        **config['model_params']['common_params'])
    if not opt.cpu:
        generator.cuda()

    kp_detector = KPDetector(**config['model_params']['kp_detector_params'],
                             **config['model_params']['common_params'])
    if not opt.cpu:
        kp_detector = kp_detector.cuda()

    Logger.load_cpk(opt.checkpoint, generator=generator, kp_detector=kp_detector, use_cpu=True)

    vis = Visualizer()

    if not opt.cpu: 
        generator = DataParallelWithCallback(generator)
        kp_detector = DataParallelWithCallback(kp_detector)

    generator.eval()
    kp_detector.eval()

    '''
    Logic: loop over every GIF in a directory and extract the pose points
    from the first frame of each GIF. This allows an alignment based on the
    first frame only.

    TODO: Extend this to extract poses from the driving video as well, so
    that poses are available at every frame for alignment.