Example #1
    def forward(self, pyd):
        pyd = create_pyramid(pyd, self.n_local_enhancers)

        if len(pyd) == 1:
            return self.base(pyd[-1])

        x = pyd[-1]
        x = self.base.in_conv(x)
        x = self.base.inner(x)

        for n in range(1, len(pyd)):
            enhancer = getattr(self, 'enhancer%d' % n)
            x = enhancer.extract_features(pyd[self.n_local_enhancers - n], x)
            if n == self.n_local_enhancers:
                x = enhancer.out_conv(x)

        return x
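
Every variant of this forward pass relies on `create_pyramid` to turn the input into a fine-to-coarse list of tensors (index 0 is the full resolution, the last entry the coarsest level). A minimal sketch of such a helper, assuming average-pool downsampling and that inputs which are already lists are passed through unchanged; the project's own implementation may differ in kernel and padding choices:

import torch.nn.functional as F

def create_pyramid(img, n=1):
    # If the input is already a pyramid (a list or tuple of tensors), return it as is.
    if isinstance(img, (list, tuple)):
        return img
    # Otherwise build an n-level pyramid: index 0 is the full resolution,
    # and each following level halves the spatial resolution by averaging.
    pyd = [img]
    for _ in range(n - 1):
        pyd.append(F.avg_pool2d(pyd[-1], 2))
    return pyd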
Example #2
    def forward(self, pyd):
        pyd = create_pyramid(pyd, self.n_local_enhancers)

        # Call global at the coarsest level
        if len(pyd) == 1:
            return self.base(pyd[-1])

        x = pyd[-1]
        x = self.base.in_conv(x)
        x = self.base.inner(x)

        # Apply enhancer for each level
        for n in range(1, len(pyd)):
            enhancer = getattr(self, 'enhancer%d' % n)
            # x = enhancer(pyd[self.n_local_enhancers - n], x)
            x = enhancer.extract_features(pyd[self.n_local_enhancers - n], x)
            if n == self.n_local_enhancers:
                x = enhancer.out_conv(x)

        return x
Example #3
    def forward(self, pyd):
        pyd = create_pyramid(pyd, self.n_local_enhancers)

        if len(pyd) == 1:
            return self.base(pyd[-1])

        x = pyd[-1]
        x = self.base.in_conv(x)
        x = self.base.inner(x)

        for n in range(1, len(pyd)):
            enhancer = getattr(self, 'enhancer%d' % n)
            x = enhancer.extract_features(pyd[self.n_local_enhancers - n], x)
            if n == self.n_local_enhancers:
                output = []
                for i in range(len(self.out_nc)):
                    out_conv = getattr(enhancer, 'out_conv%d' % (i + 1))
                    output.append(out_conv(x))

                return tuple(output)
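
This variant ends with one output head per entry of `self.out_nc`, so the caller receives a tuple of tensors, for example an image and a segmentation map when `out_nc=(3, 3)`, which matches the unpacking `img_pred, seg_pred = G(...)` seen in the later examples. A hypothetical usage sketch, reusing the arch string from Example #8 (the `obj_factory` import path and the 3 + 68 channel layout are assumptions):

import torch
import torch.nn.functional as F
from fsgan.utils.obj_factory import obj_factory  # import path assumed

net = obj_factory('res_unet_split.MultiScaleResUNet(in_nc=71,out_nc=(3,3),flat_layers=(2,0,2,3),ngf=128)')
x = torch.randn(1, 71, 256, 256)                 # 3 image channels + 68 landmark heatmap channels (assumed)
pyd = [x, F.avg_pool2d(x, 2)]                    # 2-level pyramid, as the forward passes above expect
img_pred, seg_pred = net(pyd)                    # one output tensor per out_nc head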
Example #4
    def forward(self, pyd):
        pyd[1] = torch.nn.functional.interpolate(pyd[0], scale_factor=0.5, mode='area')

        pyd = create_pyramid(pyd, self.n_local_enhancers)

        # Call global at the coarsest level
        if len(pyd) == 1:
            return self.base(pyd[-1])

        x = pyd[-1]
        x = self.base.in_conv(x)
        x = self.base.inner(x)

        # Apply enhancer for each level
        for n in range(1, len(pyd)):
            enhancer = getattr(self, 'enhancer%d' % n)
            # x = enhancer(pyd[self.n_local_enhancers - n], x)
            x = enhancer.extract_features(pyd[self.n_local_enhancers - n], x)
            if n == self.n_local_enhancers:
                x = enhancer.out_conv(x)

        return x
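
The only change from Example #2 is that the second pyramid level is recomputed from the full-resolution input with 'area' interpolation before `create_pyramid` is called. For an exact factor of 2 on an even-sized input this is the same as 2x2 average pooling, which a quick check confirms:

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 256, 256)
a = F.interpolate(x, scale_factor=0.5, mode='area')  # as in the forward pass above
b = F.avg_pool2d(x, 2)                                # plain 2x2 average pooling
print(torch.allclose(a, b, atol=1e-6))                # True: both halve the resolution by averaging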
Example #5
File: swap.py  Project: KSRawal/fsgan
    def __call__(self,
                 source_path,
                 target_path,
                 output_path=None,
                 select_source='longest',
                 select_target='longest',
                 finetune=None):
        is_vid = os.path.splitext(source_path)[1] == '.mp4'
        finetune = self.finetune_enabled and is_vid if finetune is None else finetune and is_vid

        # Validation
        assert os.path.isfile(
            source_path), 'Source path "%s" does not exist' % source_path
        assert os.path.isfile(
            target_path), 'Target path "%s" does not exist' % target_path

        # Cache input
        source_cache_dir, source_seq_file_path, _ = self.cache(source_path)
        target_cache_dir, target_seq_file_path, _ = self.cache(target_path)

        # Load sequences from file
        with open(source_seq_file_path, "rb") as fp:  # Unpickling
            source_seq_list = pickle.load(fp)
        with open(target_seq_file_path, "rb") as fp:  # Unpickling
            target_seq_list = pickle.load(fp)

        # Select source and target sequence
        source_seq = select_seq(source_seq_list, select_source)
        target_seq = select_seq(target_seq_list, select_target)

        # Set source and target sequence videos paths
        src_path_no_ext, src_ext = os.path.splitext(source_path)
        src_vid_seq_name = os.path.basename(
            src_path_no_ext) + '_seq%02d%s' % (source_seq.id, src_ext)
        src_vid_seq_path = os.path.join(source_cache_dir, src_vid_seq_name)
        tgt_path_no_ext, tgt_ext = os.path.splitext(target_path)
        tgt_vid_seq_name = os.path.basename(
            tgt_path_no_ext) + '_seq%02d%s' % (target_seq.id, tgt_ext)
        tgt_vid_seq_path = os.path.join(target_cache_dir, tgt_vid_seq_name)

        # Set output path
        if output_path is not None:
            if os.path.isdir(output_path):
                output_filename = f'{os.path.basename(src_path_no_ext)}_{os.path.basename(tgt_path_no_ext)}.mp4'
                output_path = os.path.join(output_path, output_filename)

        # Initialize appearance map
        src_transform = img_lms_pose_transforms.Compose(
            [Rotate(), Pyramids(2),
             ToTensor(), Normalize()])
        tgt_transform = img_lms_pose_transforms.Compose(
            [ToTensor(), Normalize()])
        appearance_map = AppearanceMapDataset(
            src_vid_seq_path, tgt_vid_seq_path, src_transform, tgt_transform,
            self.landmarks_postfix, self.pose_postfix,
            self.segmentation_postfix, self.min_radius)
        appearance_map_loader = DataLoader(appearance_map,
                                           batch_size=self.batch_size,
                                           num_workers=1,
                                           pin_memory=True,
                                           drop_last=False,
                                           shuffle=False)

        # Initialize video writer
        self.video_renderer.init(target_path,
                                 target_seq,
                                 output_path,
                                 _appearance_map=appearance_map)

        # Finetune reenactment model on source sequences
        if finetune:
            self.finetune(src_vid_seq_path, self.finetune_save)

        print(
            f'=> Face swapping: "{src_vid_seq_name}" -> "{tgt_vid_seq_name}"...'
        )

        # For each batch of frames in the target video
        for i, (src_frame, src_landmarks, src_poses, bw, tgt_frame, tgt_landmarks, tgt_pose, tgt_mask) \
                in enumerate(tqdm(appearance_map_loader, unit='batches', file=sys.stdout)):
            # Prepare input
            for p in range(len(src_frame)):
                src_frame[p] = src_frame[p].to(self.device)
            tgt_frame = tgt_frame.to(self.device)
            tgt_landmarks = tgt_landmarks.to(self.device)
            # tgt_mask = tgt_mask.unsqueeze(1).to(self.device)
            tgt_mask = tgt_mask.unsqueeze(1).int().to(self.device).bool()  # TODO: check if the boolean tensor bug is fixed
            bw = bw.to(self.device)
            bw_indices = torch.nonzero(torch.any(bw > 0, dim=0),
                                       as_tuple=True)[0]
            bw = bw[:, bw_indices]

            # For each source frame perform reenactment
            reenactment_triplet = []
            for j in bw_indices:
                input = []
                for p in range(len(src_frame)):
                    context = self.landmarks_decoders[p](tgt_landmarks)
                    input.append(
                        torch.cat((src_frame[p][:, j], context), dim=1))

                # Reenactment
                reenactment_triplet.append(self.Gr(input).unsqueeze(1))
            reenactment_tensor = torch.cat(reenactment_triplet, dim=1)

            # Barycentric interpolation of reenacted frames
            reenactment_tensor = (reenactment_tensor *
                                  bw.view(*bw.shape, 1, 1, 1)).sum(dim=1)

            # Compute reenactment segmentation
            reenactment_seg = self.S(reenactment_tensor)
            reenactment_background_mask_tensor = (reenactment_seg.argmax(1) !=
                                                  1).unsqueeze(1)

            # Remove the background of the aligned face
            reenactment_tensor.masked_fill_(reenactment_background_mask_tensor,
                                            -1.0)

            # Soften target mask
            soft_tgt_mask, eroded_tgt_mask = self.smooth_mask(tgt_mask)

            # Complete face
            inpainting_input_tensor = torch.cat(
                (reenactment_tensor, eroded_tgt_mask.float()), dim=1)
            inpainting_input_tensor_pyd = create_pyramid(
                inpainting_input_tensor, 2)
            completion_tensor = self.Gc(inpainting_input_tensor_pyd)

            # Blend faces
            transfer_tensor = transfer_mask(completion_tensor, tgt_frame,
                                            eroded_tgt_mask)
            blend_input_tensor = torch.cat(
                (transfer_tensor, tgt_frame, eroded_tgt_mask.float()), dim=1)
            blend_input_tensor_pyd = create_pyramid(blend_input_tensor, 2)
            blend_tensor = self.Gb(blend_input_tensor_pyd)

            result_tensor = blend_tensor * soft_tgt_mask + tgt_frame * (
                1 - soft_tgt_mask)

            # Write output
            if self.verbose == 0:
                self.video_renderer.write(result_tensor)
            elif self.verbose == 1:
                curr_src_frames = [
                    src_frame[0][:, i] for i in range(src_frame[0].shape[1])
                ]
                self.video_renderer.write(*curr_src_frames, result_tensor,
                                          tgt_frame)
            else:
                curr_src_frames = [
                    src_frame[0][:, i] for i in range(src_frame[0].shape[1])
                ]
                tgt_seg_blend = blend_seg_label(tgt_frame,
                                                tgt_mask.squeeze(1),
                                                alpha=0.2)
                soft_tgt_mask = soft_tgt_mask.mul(2.).sub(1.).repeat(
                    1, 3, 1, 1)
                self.video_renderer.write(*curr_src_frames, result_tensor,
                                          tgt_frame, reenactment_tensor,
                                          completion_tensor, transfer_tensor,
                                          soft_tgt_mask, tgt_seg_blend,
                                          tgt_pose)

        # Load original reenactment weights
        if finetune:
            if self.gpus and len(self.gpus) > 1:
                self.Gr.module.load_state_dict(self.reenactment_state_dict)
            else:
                self.Gr.load_state_dict(self.reenactment_state_dict)

        # Finalize video and wait for the video writer to finish writing
        self.video_renderer.finalize()
        self.video_renderer.wait_until_finished()
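
`transfer_mask` and `smooth_mask` are project utilities that are not shown here. A minimal sketch of the masked transfer step, assuming a boolean mask of shape (B, 1, H, W) that selects the face region; the project's own implementation may differ:

import torch

def transfer_mask(img1, img2, mask):
    # Take img1 pixels where the mask is set and img2 pixels elsewhere.
    mask = mask.expand_as(img1).float()
    return img1 * mask + img2 * (1.0 - mask)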
Example #6
    def proces_epoch(dataset_loader, train=True):
        stage = 'TRAINING' if train else 'VALIDATION'
        total_iter = len(dataset_loader) * dataset_loader.batch_size * epoch
        pbar = tqdm(dataset_loader, unit='batches')

        # Set networks training mode
        Gc.train(train)
        D.train(train)
        Gr.train(False)
        S.train(False)
        L.train(False)

        # Reset logger
        logger.reset(prefix='{} {}X{}: Epoch: {} / {}; LR: {:.0e}; '.format(
            stage, res, res, epoch + 1, res_epochs,
            scheduler_G.get_lr()[0]))

        # For each batch in the training data
        for i, (img, target) in enumerate(pbar):
            # Prepare input
            with torch.no_grad():
                # For each view's images
                for j in range(len(img)):
                    # For each pyramid image: push to device
                    for p in range(len(img[j])):
                        img[j][p] = img[j][p].to(device)

                # Compute context
                context = L(img[1][0].sub(context_mean).div(context_std))
                context = landmarks_utils.filter_landmarks(context)

                # Normalize each of the pyramid images
                for j in range(len(img)):
                    for p in range(len(img[j])):
                        img[j][p].sub_(img_mean).div_(img_std)

                # # Compute segmentation
                # seg = []
                # for j in range(len(img)):
                #     curr_seg = S(img[j][0])
                #     if curr_seg.shape[2:] != (res, res):
                #         curr_seg = F.interpolate(curr_seg, (res, res), mode='bicubic', align_corners=False)
                #     seg.append(curr_seg)

                # Compute segmentation
                target_seg = S(img[1][0])
                if target_seg.shape[2:] != (res, res):
                    target_seg = F.interpolate(target_seg, (res, res),
                                               mode='bicubic',
                                               align_corners=False)

                # Concatenate pyramid images with context to derive the final input
                input = []
                for p in range(len(img[0]) - 1, -1, -1):
                    context = F.interpolate(context,
                                            size=img[0][p].shape[2:],
                                            mode='bicubic',
                                            align_corners=False)
                    input.insert(0, torch.cat((img[0][p], context), dim=1))

                # Reenactment
                reenactment_img = Gr(input)
                reenactment_seg = S(reenactment_img)
                if reenactment_img.shape[2:] != (res, res):
                    reenactment_img = F.interpolate(reenactment_img,
                                                    (res, res),
                                                    mode='bilinear',
                                                    align_corners=False)
                    reenactment_seg = F.interpolate(reenactment_seg,
                                                    (res, res),
                                                    mode='bilinear',
                                                    align_corners=False)

                # Remove unnecessary pyramids
                for j in range(len(img)):
                    img[j] = img[j][-ri - 1:]

                # Source face
                reenactment_face_mask = reenactment_seg.argmax(1) == 1
                inpainting_mask = seg_utils.random_hair_inpainting_mask_tensor(
                    reenactment_face_mask).to(device)
                reenactment_face_mask = reenactment_face_mask * (
                    inpainting_mask == 0)
                reenactment_img_with_hole = reenactment_img.masked_fill(
                    ~reenactment_face_mask.unsqueeze(1), background_value)

                # Target face
                target_face_mask = (target_seg.argmax(1) == 1).unsqueeze(1)
                inpainting_target = img[1][0]
                inpainting_target.masked_fill_(~target_face_mask,
                                               background_value)

                # Inpainting input
                inpainting_input = torch.cat(
                    (reenactment_img_with_hole, target_face_mask.float()),
                    dim=1)
                inpainting_input_pyd = img_utils.create_pyramid(
                    inpainting_input, len(img[0]))

            # Face inpainting
            inpainting_pred = Gc(inpainting_input_pyd)

            # Fake Detection and Loss
            inpainting_pred_pyd = img_utils.create_pyramid(
                inpainting_pred, len(img[0]))
            pred_fake_pool = D([x.detach() for x in inpainting_pred_pyd])
            loss_D_fake = criterion_gan(pred_fake_pool, False)

            # Real Detection and Loss
            inpainting_target_pyd = img_utils.create_pyramid(
                inpainting_target, len(img[0]))
            pred_real = D(inpainting_target_pyd)
            loss_D_real = criterion_gan(pred_real, True)

            loss_D_total = (loss_D_fake + loss_D_real) * 0.5

            # GAN loss (Fake Passability Loss)
            pred_fake = D(inpainting_pred_pyd)
            loss_G_GAN = criterion_gan(pred_fake, True)

            # Reconstruction
            loss_pixelwise = criterion_pixelwise(inpainting_pred,
                                                 inpainting_target)
            loss_id = criterion_id(inpainting_pred, inpainting_target)
            loss_attr = criterion_attr(inpainting_pred, inpainting_target)
            loss_rec = pix_weight * loss_pixelwise + 0.5 * loss_id + 0.5 * loss_attr

            loss_G_total = rec_weight * loss_rec + gan_weight * loss_G_GAN

            if train:
                # Update generator weights
                optimizer_G.zero_grad()
                loss_G_total.backward()
                optimizer_G.step()

                # Update discriminator weights
                optimizer_D.zero_grad()
                loss_D_total.backward()
                optimizer_D.step()

            logger.update('losses',
                          pixelwise=loss_pixelwise,
                          id=loss_id,
                          attr=loss_attr,
                          rec=loss_rec,
                          g_gan=loss_G_GAN,
                          d_gan=loss_D_total)
            total_iter += dataset_loader.batch_size

            # Batch logs
            pbar.set_description(str(logger))
            if train and i % log_freq == 0:
                logger.log_scalars_val('%dx%d/batch' % (res, res), total_iter)

        # Epoch logs
        logger.log_scalars_avg(
            '%dx%d/epoch/%s' % (res, res, 'train' if train else 'val'), epoch)
        if not train:
            # Log images
            grid = img_utils.make_grid(img[0][0], reenactment_img,
                                       reenactment_img_with_hole,
                                       inpainting_pred, inpainting_target)
            logger.log_image('%dx%d/vis' % (res, res), grid, epoch)

        return logger.log_dict['losses']['rec'].avg
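
`criterion_gan` is called above with the discriminator output and a boolean target. A minimal least-squares GAN sketch that also accepts the per-scale outputs of a multi-scale discriminator; the objective and output structure are assumptions, and the project's criterion may differ:

import torch
import torch.nn as nn

class GANLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.loss = nn.MSELoss()  # least-squares GAN objective (assumed)

    def forward(self, preds, target_is_real):
        # Accept a single prediction tensor or a list of per-scale predictions.
        preds = preds if isinstance(preds, (list, tuple)) else [preds]
        total = 0.0
        for pred in preds:
            # Some multi-scale discriminators return intermediate features per scale;
            # in that case only the last map is scored.
            pred = pred[-1] if isinstance(pred, (list, tuple)) else pred
            target = torch.full_like(pred, 1.0 if target_is_real else 0.0)
            total = total + self.loss(pred, target)
        return total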
Example #7
    def proces_epoch(dataset_loader, train=True):
        stage = 'TRAINING' if train else 'VALIDATION'
        total_iter = len(dataset_loader) * dataset_loader.batch_size * epoch
        pbar = tqdm(dataset_loader, unit='batches')

        # Set networks training mode
        G.train(train)
        D.train(train)
        S.train(False)
        L.train(False)

        # Reset logger
        logger.reset(prefix='{} {}X{}: Epoch: {} / {}; LR: {:.0e}; '.format(
            stage, res, res, epoch + 1, res_epochs, scheduler_G.get_lr()[0]))

        # For each batch in the training data
        for i, (img, target) in enumerate(pbar):
            # Prepare input
            with torch.no_grad():
                # For each view's images
                for j in range(len(img)):
                    # For each pyramid image: push to device
                    for p in range(len(img[j])):
                        img[j][p] = img[j][p].to(device)

                # Compute context
                context = L(img[1][0].sub(context_mean).div(context_std))
                context = landmarks_utils.filter_context(context)

                # Normalize each of the pyramid images
                for j in range(len(img)):
                    for p in range(len(img[j])):
                        img[j][p].sub_(img_mean).div_(img_std)

                # Compute segmentation
                seg = S(img[1][0])
                if seg.shape[2:] != (res, res):
                    seg = F.interpolate(seg, (res, res), mode='bicubic', align_corners=False)
                # seg = img_utils.create_pyramid(seg, len(img[0]))[-ri - 1:]

                # Remove unnecessary pyramids
                for j in range(len(img)):
                    img[j] = img[j][-ri - 1:]

                # Concatenate pyramid images with context to derive the final input
                input = []
                for p in range(len(img[0]) - 1, -1, -1):
                    context = F.interpolate(context, size=img[0][p].shape[2:], mode='bicubic', align_corners=False)
                    input.append(torch.cat((img[0][p], context), dim=1))
                input = input[::-1]

            # Reenactment
            img_pred, seg_pred = G(input)

            # Fake Detection and Loss
            img_pred_pyd = img_utils.create_pyramid(img_pred, len(img[0]))
            pred_fake_pool = D([x.detach() for x in img_pred_pyd])
            loss_D_fake = criterion_gan(pred_fake_pool, False)

            # Real Detection and Loss
            pred_real = D(img[1])
            loss_D_real = criterion_gan(pred_real, True)

            loss_D_total = (loss_D_fake + loss_D_real) * 0.5

            # GAN loss (Fake Passability Loss)
            pred_fake = D(img_pred_pyd)
            loss_G_GAN = criterion_gan(pred_fake, True)

            # Reconstruction and segmentation loss
            loss_pixelwise = criterion_pixelwise(img_pred, img[1][0])
            loss_id = criterion_id(img_pred, img[1][0])
            loss_attr = criterion_attr(img_pred, img[1][0])
            loss_rec = 0.1 * loss_pixelwise + 0.5 * loss_id + 0.5 * loss_attr
            loss_seg = criterion_pixelwise(seg_pred, seg)

            loss_G_total = rec_weight * loss_rec + seg_weight * loss_seg + gan_weight * loss_G_GAN

            if train:
                # Update generator weights
                optimizer_G.zero_grad()
                loss_G_total.backward()
                optimizer_G.step()

                # Update discriminator weights
                optimizer_D.zero_grad()
                loss_D_total.backward()
                optimizer_D.step()

            logger.update('losses', pixelwise=loss_pixelwise, id=loss_id, attr=loss_attr, rec=loss_rec, seg=loss_seg,
                          g_gan=loss_G_GAN, d_gan=loss_D_total)
            total_iter += dataset_loader.batch_size

            # Batch logs
            pbar.set_description(str(logger))
            if train and i % log_freq == 0:
                logger.log_scalars_val('%dx%d/batch' % (res, res), total_iter)

        # Epoch logs
        logger.log_scalars_avg('%dx%d/epoch/%s' % (res, res, 'train' if train else 'val'), epoch)
        if not train:
            # Log images
            blend_seg_pred = seg_utils.blend_seg_pred(img[1][0], seg_pred)
            blend_seg = seg_utils.blend_seg_pred(img[1][0], seg)
            grid = img_utils.make_grid(img[0][0], img_pred, img[1][0], blend_seg_pred, blend_seg)
            logger.log_image('%dx%d/vis' % (res, res), grid, epoch)

        return logger.log_dict['losses']['rec'].avg
Example #8
def main(
        source_path,
        target_path,
        arch='res_unet_split.MultiScaleResUNet(in_nc=71,out_nc=(3,3),flat_layers=(2,0,2,3),ngf=128)',
        model_path='../weights/ijbc_msrunet_256_1_2_reenactment_stepwise_v1.pth',
        pose_model_path='../weights/hopenet_robust_alpha1.pth',
        pil_transforms1=('landmark_transforms.FaceAlignCrop(bbox_scale=1.2)',
                         'landmark_transforms.Resize(256)',
                         'landmark_transforms.Pyramids(2)'),
        pil_transforms2=('landmark_transforms.FaceAlignCrop(bbox_scale=1.2)',
                         'landmark_transforms.Resize(256)',
                         'landmark_transforms.Pyramids(2)'),
        tensor_transforms1=(
            'landmark_transforms.ToTensor()',
            'transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])'),
        tensor_transforms2=(
            'landmark_transforms.ToTensor()',
            'transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])'),
        output_path=None,
        crop_size=256,
        display=False):
    torch.set_grad_enabled(False)

    # Initialize models
    fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._3D,
                                      flip_input=True)
    device, gpus = utils.set_device()
    G = obj_factory(arch).to(device)
    checkpoint = torch.load(model_path)
    G.load_state_dict(checkpoint['state_dict'])
    G.train(False)

    # Initialize pose
    Gp = Hopenet().to(device)
    checkpoint = torch.load(pose_model_path)
    Gp.load_state_dict(checkpoint['state_dict'])
    Gp.train(False)

    # Initialize landmarks to heatmaps
    landmarks2heatmaps = [
        LandmarkHeatmap(kernel_size=13, size=(256, 256)).to(device),
        LandmarkHeatmap(kernel_size=7, size=(128, 128)).to(device)
    ]

    # Initialize transformations
    pil_transforms1 = obj_factory(
        pil_transforms1) if pil_transforms1 is not None else []
    pil_transforms2 = obj_factory(
        pil_transforms2) if pil_transforms2 is not None else []
    tensor_transforms1 = obj_factory(
        tensor_transforms1) if tensor_transforms1 is not None else []
    tensor_transforms2 = obj_factory(
        tensor_transforms2) if tensor_transforms2 is not None else []
    img_transforms1 = landmark_transforms.ComposePyramids(pil_transforms1 +
                                                          tensor_transforms1)
    img_transforms2 = landmark_transforms.ComposePyramids(pil_transforms2 +
                                                          tensor_transforms2)

    # Process source image
    source_bgr = cv2.imread(source_path)
    source_rgb = source_bgr[:, :, ::-1]
    source_landmarks, source_bbox = process_image(fa, source_rgb, crop_size)
    if source_bbox is None:
        raise RuntimeError("Couldn't detect a face in source image: " +
                           source_path)
    source_tensor, source_landmarks, source_bbox = img_transforms1(
        source_rgb, source_landmarks, source_bbox)
    source_cropped_bgr = tensor2bgr(
        source_tensor[0] if isinstance(source_tensor, list) else source_tensor)
    for i in range(len(source_tensor)):
        source_tensor[i] = source_tensor[i].unsqueeze(0).to(device)

    # Extract landmarks, bounding boxes, euler angles, and 3D landmarks from target video
    frame_indices, landmarks, bboxes, eulers, landmarks_3d = \
        extract_landmarks_bboxes_euler_3d_from_video(target_path, Gp, fa, device=device)
    if frame_indices.size == 0:
        raise RuntimeError('No faces were detected in the target video: ' +
                           target_path)

    # Open target video file
    cap = cv2.VideoCapture(target_path)
    if not cap.isOpened():
        raise RuntimeError('Failed to read target video: ' + target_path)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = cap.get(cv2.CAP_PROP_FPS)

    # Initialize output video file
    if output_path is not None:
        if os.path.isdir(output_path):
            output_filename = os.path.splitext(os.path.basename(source_path))[0] + '_' + \
                              os.path.splitext(os.path.basename(target_path))[0] + '.mp4'
            output_path = os.path.join(output_path, output_filename)
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out_vid = cv2.VideoWriter(
            output_path, fourcc, fps,
            (source_cropped_bgr.shape[1] * 3, source_cropped_bgr.shape[0]))
    else:
        out_vid = None

    # For each frame in the target video
    valid_frame_ind = 0
    for i in tqdm(range(total_frames)):
        ret, target_bgr = cap.read()
        if target_bgr is None:
            continue
        if i not in frame_indices:
            continue
        target_rgb = target_bgr[:, :, ::-1]
        target_tensor, target_landmarks, target_bbox = img_transforms2(
            target_rgb, landmarks_3d[valid_frame_ind], bboxes[valid_frame_ind])
        target_euler = eulers[valid_frame_ind]
        valid_frame_ind += 1

        # TODO: Calculate the number of required reenactment iterations
        reenactment_iterations = 2

        # Generate landmarks sequence
        target_landmarks_sequence = []
        for ri in range(1, reenactment_iterations):
            interp_landmarks = []
            for j in range(len(source_tensor)):
                alpha = float(ri) / reenactment_iterations
                curr_interp_landmarks_np = interpolate_points(
                    source_landmarks[j].cpu().numpy(),
                    target_landmarks[j].cpu().numpy(),
                    alpha=alpha)
                interp_landmarks.append(
                    torch.from_numpy(curr_interp_landmarks_np))
            target_landmarks_sequence.append(interp_landmarks)
        target_landmarks_sequence.append(target_landmarks)

        # Iterative reenactment
        out_img_tensor = source_tensor
        for curr_target_landmarks in target_landmarks_sequence:
            out_img_tensor = create_pyramid(out_img_tensor, 2)
            input_tensor = []
            for j in range(len(out_img_tensor)):
                curr_target_landmarks[j] = curr_target_landmarks[j].unsqueeze(
                    0).to(device)
                curr_target_landmarks[j] = landmarks2heatmaps[j](
                    curr_target_landmarks[j])
                input_tensor.append(
                    torch.cat((out_img_tensor[j], curr_target_landmarks[j]),
                              dim=1))
            out_img_tensor, out_seg_tensor = G(input_tensor)

        # Convert back to numpy images
        out_img_bgr = tensor2bgr(out_img_tensor)
        frame_cropped_bgr = tensor2bgr(target_tensor[0])

        # Render
        # for point in np.round(frame_landmarks).astype(int):
        #     cv2.circle(frame_cropped_bgr, (point[0], point[1]), 2, (0, 0, 255), -1)
        render_img = np.concatenate(
            (source_cropped_bgr, out_img_bgr, frame_cropped_bgr), axis=1)
        if out_vid is not None:
            out_vid.write(render_img)
        if out_vid is None or display:
            cv2.imshow('render_img', render_img)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
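
`interpolate_points` produces the intermediate landmark sets used for the iterative reenactment above. A minimal sketch, assuming plain linear interpolation between the source and target landmark arrays; the project's implementation may add alignment or normalization steps:

import numpy as np

def interpolate_points(src_points, dst_points, alpha=0.5):
    # Linearly blend two landmark arrays of the same shape:
    # alpha=0 returns the source points, alpha=1 the target points.
    return (1.0 - alpha) * np.asarray(src_points) + alpha * np.asarray(dst_points)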
Example #9
    def proces_epoch(dataset_loader, train=True):
        stage = 'TRAINING' if train else 'VALIDATION'
        total_iter = len(dataset_loader) * dataset_loader.batch_size * epoch
        pbar = tqdm(dataset_loader, unit='batches')

        # Set networks training mode
        G.train(train)
        D.train(train)

        # Reset logger
        logger.reset(prefix='{} {}X{}: Epoch: {} / {}; LR: {:.0e}; '.format(
            stage, res, res, epoch + 1, res_epochs,  optimizer_G.param_groups[0]['lr']))

        # For each batch in the training data
        for i, (img, landmarks, target) in enumerate(pbar):
            # Prepare input
            with torch.no_grad():
                # For each view's images and landmarks
                landmarks[1] = landmarks[1].to(device)
                for j in range(len(img)):
                    # landmarks[j] = landmarks[j].to(device)

                    # For each pyramid image: push to device
                    for p in range(len(img[j])):
                        img[j][p] = img[j][p].to(device)

                # Remove unnecessary pyramids
                for j in range(len(img)):
                    img[j] = img[j][-ri - 1:]

                # Concatenate pyramid images with context to derive the final input
                input = []
                for p in range(len(img[0])):
                    context = res_landmarks_decoders[p](landmarks[1])
                    input.append(torch.cat((img[0][p], context), dim=1))

            # Reenactment
            img_pred = G(input)

            # Fake Detection and Loss
            img_pred_pyd = img_utils.create_pyramid(img_pred, len(img[0]))
            pred_fake_pool = D([x.detach() for x in img_pred_pyd])
            loss_D_fake = criterion_gan(pred_fake_pool, False)

            # Real Detection and Loss
            pred_real = D(img[1])
            loss_D_real = criterion_gan(pred_real, True)

            loss_D_total = (loss_D_fake + loss_D_real) * 0.5

            # GAN loss (Fake Passability Loss)
            pred_fake = D(img_pred_pyd)
            loss_G_GAN = criterion_gan(pred_fake, True)

            # Reconstruction and segmentation loss
            loss_pixelwise = criterion_pixelwise(img_pred, img[1][0])
            loss_id = criterion_id(img_pred, img[1][0])
            loss_attr = criterion_attr(img_pred, img[1][0])
            loss_rec = 0.1 * loss_pixelwise + 0.5 * loss_id + 0.5 * loss_attr

            loss_G_total = rec_weight * loss_rec + gan_weight * loss_G_GAN

            if train:
                # Update generator weights
                optimizer_G.zero_grad()
                loss_G_total.backward()
                optimizer_G.step()

                # Update discriminator weights
                optimizer_D.zero_grad()
                loss_D_total.backward()
                optimizer_D.step()

            logger.update('losses', pixelwise=loss_pixelwise, id=loss_id, attr=loss_attr, rec=loss_rec,
                          g_gan=loss_G_GAN, d_gan=loss_D_total)
            total_iter += dataset_loader.batch_size

            # Batch logs
            pbar.set_description(str(logger))
            if train and i % log_freq == 0:
                logger.log_scalars_val('%dx%d/batch' % (res, res), total_iter)

        # Epoch logs
        logger.log_scalars_avg('%dx%d/epoch/%s' % (res, res, 'train' if train else 'val'), epoch)
        if not train:
            # Log images
            grid = img_utils.make_grid(img[0][0], img_pred, img[1][0])
            logger.log_image('%dx%d/vis' % (res, res), grid, epoch)

        return logger.log_dict['losses']['rec'].avg
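
The epoch routine returns the average reconstruction loss, which the surrounding training script can use to step schedulers and select checkpoints. A hypothetical outer loop; the loader, scheduler, and checkpoint names are assumptions and the actual script may differ:

import torch

best_loss = float('inf')
for epoch in range(res_epochs):
    proces_epoch(train_loader, train=True)                # one training pass
    val_rec = proces_epoch(val_loader, train=False)       # validation pass, returns the average rec loss
    scheduler_G.step()
    scheduler_D.step()
    if val_rec < best_loss:                               # keep the best generator by validation loss
        best_loss = val_rec
        torch.save({'state_dict': G.state_dict()}, 'model_best.pth')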