Example #1
    def __getitem_along_with_fa__(self, item):
        """
        Get landmark alignment online (deprecated).
        Can only run with num_workers=0, since the landmark model is
        evaluated on self.device inside this method.
        """
        fls_filename = self.fls_filenames[item]

        # load mp4 file
        # ================= raw VOX version ================================
        mp4_filename = fls_filename[:-4].split('_x_')
        mp4_id = mp4_filename[0].split('_')[-1]
        mp4_vname = mp4_filename[1]
        mp4_vid = mp4_filename[2]
        video_dir = os.path.join(self.mp4_dir, mp4_id, mp4_vname, mp4_vid + '.mp4')
        # print('============================\nvideo_dir : ' + video_dir, item)
        # ======================================================================

        video = cv2.VideoCapture(video_dir)
        if not video.isOpened():
            print('Unable to open video file')
            exit(0)
        length = int(video.get(cv2.CAP_PROP_FRAME_COUNT))

        # save video and landmark in parallel
        frames = []
        random_frame_indices = np.random.permutation(length-2)[0:self.num_random_frames]

        for j in range(length):
            ret, img = video.read()

            if j in random_frame_indices:
                # online landmark
                img_video = cv2.resize(img, (256, 256))
                img = img_video.transpose((2, 0, 1)) / 255.0
                inputs = torch.tensor(img, dtype=torch.float32, requires_grad=False).unsqueeze(0).to(self.device)
                with torch.no_grad():
                    outputs, boundary_channels = self.model(inputs)
                pred_heatmap = outputs[-1][:, :-1, :, :][0].detach().cpu()
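                # decode landmark coordinates from the predicted heatmaps; the
                # heatmaps are at a lower resolution than the 256x256 input
                # (1/4, judging from the x4 rescaling below)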
                pred_landmarks, _ = get_preds_fromhm(pred_heatmap.unsqueeze(0))
                pred_landmarks = pred_landmarks.squeeze().numpy() * 4

                img_fl = np.ones(shape=(256, 256, 3)) * 255.0
                img_fl = vis_landmark_on_img98(img_fl, pred_landmarks)  # 98x2

                frame = np.concatenate((img_fl, img_video), axis=2)
                frames.append(frame)

        frames = np.stack(frames, axis=0).astype(np.float32)/255.0  # N x 256 x 256 x 6

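        # Pair the landmark sketch of sampled frame i (channels 0:3) with the RGB of
        # sampled frame i+1 (channels 3:6) as the 6-channel network input; the RGB of
        # frame i is kept as the target. After the axis swap below, both arrays are
        # channel-first: image_in (N-1, 6, 256, 256), image_out (N-1, 3, 256, 256).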
        image_in = np.concatenate([frames[0:-1, :, :, 0:3], frames[1:, :, :, 3:6]], axis=3)
        image_out = frames[0:-1, :, :, 3:6]

        image_in, image_out = np.swapaxes(image_in, 1, 3), np.swapaxes(image_out, 1, 3)
        return image_in, image_out

    def test(self):
        if (self.opt_parser.use_vox_dataset == 'raw'):
            if (self.opt_parser.add_audio_in):
                from src.dataset.image_translation.image_translation_dataset import \
                    image_translation_raw98_with_audio_test_dataset as image_translation_test_dataset
            else:
                from src.dataset.image_translation.image_translation_dataset import image_translation_raw98_test_dataset as image_translation_test_dataset
        else:
            from src.dataset.image_translation.image_translation_dataset import image_translation_preprocessed98_test_dataset as image_translation_test_dataset
        self.dataset = image_translation_test_dataset(
            num_frames=self.opt_parser.num_frames)
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=1,
            shuffle=True,
            num_workers=self.opt_parser.num_workers)

        self.G.eval()
        for i, batch in enumerate(self.dataloader):
            print(i, 50)
            if (i > 50):
                break

            if (self.opt_parser.add_audio_in):
                image_in, image_out, audio_in = batch
                audio_in = audio_in.reshape(-1, 1, 256, 256).to(device)
            else:
                image_in, image_out = batch

            # # online landmark (AwingNet)
            with torch.no_grad():
                image_in, image_out = \
                    image_in.reshape(-1, 3, 256, 256).to(
                        device), image_out.reshape(-1, 3, 256, 256).to(device)

                pred_landmarks = []
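                # run the landmark detector on chunks of 16 frames at a time
                # (presumably to bound GPU memory); frames beyond the last full
                # chunk are dropped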
                for j in range(image_in.shape[0] // 16):
                    inputs = image_out[j * 16:j * 16 + 16]
                    outputs, boundary_channels = self.fa_model(inputs)
                    pred_heatmap = outputs[-1][:, :-1, :, :].detach().cpu()
                    pred_landmark, _ = get_preds_fromhm(pred_heatmap)
                    pred_landmarks.append(pred_landmark.numpy() * 4)
                pred_landmarks = np.concatenate(pred_landmarks, axis=0)

            # draw landmarks on a white background
            img_fls = []
            for pred_fl in pred_landmarks:
                img_fl = np.ones(shape=(256, 256, 3)) * 255.0
                img_fl = vis_landmark_on_img98(img_fl, pred_fl)  # 98x2
                img_fls.append(img_fl.transpose((2, 0, 1)))
            img_fls = np.stack(img_fls, axis=0).astype(np.float32) / 255.0
            image_fls_in = torch.tensor(img_fls,
                                        requires_grad=False).to(device)

            if (self.opt_parser.add_audio_in):
                # print(image_fls_in.shape, image_in.shape, audio_in.shape)
                image_in = torch.cat([
                    image_fls_in, image_in[0:image_fls_in.shape[0]],
                    audio_in[0:image_fls_in.shape[0]]
                ],
                                     dim=1)
            else:
                image_in = torch.cat(
                    [image_fls_in, image_in[0:image_fls_in.shape[0]]], dim=1)

            # normal 68 test dataset
            # image_in, image_out = image_in.reshape(-1, 6, 256, 256), image_out.reshape(-1, 3, 256, 256)

            # random single frame
            # cv2.imwrite('random_img_{}.jpg'.format(i), np.swapaxes(image_out[5].numpy(),0, 2)*255.0)

            image_in, image_out = image_in.to(device), image_out.to(device)

            writer = cv2.VideoWriter('tmp_{:04d}.mp4'.format(i),
                                     cv2.VideoWriter_fourcc(*'MJPG'), 25,
                                     (256 * 4, 256))

            for j in range(image_in.shape[0] // 16):
                g_out = self.G(image_in[j * 16:j * 16 + 16])
                g_out = torch.tanh(g_out)

                # norm 68 pts
                # g_out = np.swapaxes(g_out.cpu().detach().numpy(), 1, 3)
                # ref_out = np.swapaxes(image_out[j*16:j*16+16].cpu().detach().numpy(), 1, 3)
                # ref_in = np.swapaxes(image_in[j*16:j*16+16, 3:6, :, :].cpu().detach().numpy(), 1, 3)
                # fls_in = np.swapaxes(image_in[j * 16:j * 16 + 16, 0:3, :, :].cpu().detach().numpy(), 1, 3)
                g_out = g_out.cpu().detach().numpy().transpose((0, 2, 3, 1))
                g_out[g_out < 0] = 0
                ref_out = image_out[j * 16:j * 16 +
                                    16].cpu().detach().numpy().transpose(
                                        (0, 2, 3, 1))
                ref_in = image_in[j * 16:j * 16 + 16,
                                  3:6, :, :].cpu().detach().numpy().transpose(
                                      (0, 2, 3, 1))
                fls_in = image_in[j * 16:j * 16 + 16,
                                  0:3, :, :].cpu().detach().numpy().transpose(
                                      (0, 2, 3, 1))

                for k in range(g_out.shape[0]):
                    frame = np.concatenate(
                        (ref_in[k], g_out[k], fls_in[k], ref_out[k]),
                        axis=1) * 255.0
                    writer.write(frame.astype(np.uint8))

            writer.release()

            os.system(
                'ffmpeg -y -i tmp_{:04d}.mp4 -pix_fmt yuv420p random_{:04d}.mp4'
                .format(i, i))
            os.system('rm tmp_{:04d}.mp4'.format(i))

    def __train_pass__(self, epoch, is_training=True):
        epoch += self.ckpt['epoch']
        st_epoch = time.time()
        if (is_training):
            self.G.train()
            status = 'TRAIN'
        else:
            self.G.eval()
            status = 'EVAL'

        g_time = 0.0
        for i, batch in enumerate(self.dataloader):
            if (i >= len(self.dataloader) - 2):
                break
            st_batch = time.time()

            if (self.opt_parser.comb_fan_awing):
                image_in, image_out, fan_pred_landmarks = batch
                fan_pred_landmarks = fan_pred_landmarks.reshape(
                    -1, 68, 3).detach().cpu().numpy()
            elif (self.opt_parser.add_audio_in):
                image_in, image_out, audio_in = batch
                audio_in = audio_in.reshape(-1, 1, 256, 256).to(device)
            else:
                image_in, image_out = batch

            with torch.no_grad():
                # # online landmark (AwingNet)
                image_in, image_out = \
                    image_in.reshape(-1, 3, 256, 256).to(
                        device), image_out.reshape(-1, 3, 256, 256).to(device)
                inputs = image_out
                outputs, boundary_channels = self.fa_model(inputs)
                pred_heatmap = outputs[-1][:, :-1, :, :].detach().cpu()
                pred_landmarks, _ = get_preds_fromhm(pred_heatmap)
                pred_landmarks = pred_landmarks.numpy() * 4

                # online landmarks (FAN) -> replace AwingNet's jaw + eyebrow points with FAN's
                if (self.opt_parser.comb_fan_awing):
                    fl_jaw_eyebrow = fan_pred_landmarks[:, 0:27, 0:2]
                    fl_rest = pred_landmarks[:, 51:, :]
                    pred_landmarks = np.concatenate([fl_jaw_eyebrow, fl_rest],
                                                    axis=1).astype(int)

            # draw landmarks on a white background
            img_fls = []
            for pred_fl in pred_landmarks:
                img_fl = np.ones(shape=(256, 256, 3)) * 255.0
                if (self.opt_parser.comb_fan_awing):
                    img_fl = vis_landmark_on_img74(img_fl, pred_fl)  # 74x2
                else:
                    img_fl = vis_landmark_on_img98(img_fl, pred_fl)  # 98x2
                img_fls.append(img_fl.transpose((2, 0, 1)))
            img_fls = np.stack(img_fls, axis=0).astype(np.float32) / 255.0
            image_fls_in = torch.tensor(img_fls,
                                        requires_grad=False).to(device)
            if (self.opt_parser.add_audio_in):
                # print(image_fls_in.shape, image_in.shape, audio_in.shape)
                image_in = torch.cat([image_fls_in, image_in, audio_in], dim=1)
            else:
                image_in = torch.cat([image_fls_in, image_in], dim=1)

            # image_in, image_out = \
            #     image_in.reshape(-1, 6, 256, 256).to(device), image_out.reshape(-1, 3, 256, 256).to(device)

            # image2image net fp
            g_out = self.G(image_in)
            g_out = torch.tanh(g_out)

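            # total loss = pixel-wise L1 + VGG perceptual loss + VGG style loss,
            # summed with equal (unit) weights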
            loss_l1 = self.criterionL1(g_out, image_out)
            loss_vgg, loss_style = self.criterionVGG(g_out,
                                                     image_out,
                                                     style=True)

            loss_vgg, loss_style = torch.mean(loss_vgg), torch.mean(loss_style)

            loss = loss_l1 + loss_vgg + loss_style
            if (is_training):
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            # log
            if (self.opt_parser.write):
                self.writer.add_scalar('loss',
                                       loss.cpu().detach().numpy(), self.count)
                self.writer.add_scalar('loss_l1',
                                       loss_l1.cpu().detach().numpy(),
                                       self.count)
                self.writer.add_scalar('loss_vgg',
                                       loss_vgg.cpu().detach().numpy(),
                                       self.count)
                self.count += 1

            # save image to track training process
            if (i % self.opt_parser.jpg_freq == 0):
                vis_in = np.concatenate([
                    image_in[0, 3:6].cpu().detach().numpy().transpose(
                        (1, 2, 0)),
                    image_in[0, 0:3].cpu().detach().numpy().transpose(
                        (1, 2, 0))
                ],
                                        axis=1)
                vis_out = np.concatenate([
                    image_out[0].cpu().detach().numpy().transpose(
                        (1, 2, 0)), g_out[0].cpu().detach().numpy().transpose(
                            (1, 2, 0))
                ],
                                         axis=1)
                vis = np.concatenate([vis_in, vis_out], axis=0)
                os.makedirs(os.path.join(self.opt_parser.jpg_dir,
                                         self.opt_parser.name),
                            exist_ok=True)
                cv2.imwrite(
                    os.path.join(self.opt_parser.jpg_dir, self.opt_parser.name,
                                 'e{:03d}_b{:04d}.jpg'.format(epoch, i)),
                    vis * 255.0)
            # save ckpt
            if (i % self.opt_parser.ckpt_last_freq == 0):
                self.__save_model__('last', epoch)
                # os.system('!zip -r "/content/MakeItTalk/drive/MyDrive/MakeItTalk/last.zip" "PreprocessedVox_imagetranslation/ckpt/tmp/ckpt_last.pth"')

            print(
                "Epoch {}, Batch {}/{}, loss {:.4f}, l1 {:.4f}, vggloss {:.4f}, styleloss {:.4f} time {:.4f}"
                .format(epoch, i,
                        len(self.dataset) // self.opt_parser.batch_size,
                        loss.cpu().detach().numpy(),
                        loss_l1.cpu().detach().numpy(),
                        loss_vgg.cpu().detach().numpy(),
                        loss_style.cpu().detach().numpy(),
                        time.time() - st_batch))

            g_time += time.time() - st_batch

            if (self.opt_parser.test_speed):
                if (i >= 100):
                    break

        print('Epoch time usage:',
              time.time() - st_epoch, 'I/O time usage:',
              time.time() - st_epoch - g_time, '\n=========================')
        if (self.opt_parser.test_speed):
            exit(0)
        if (epoch % self.opt_parser.ckpt_epoch_freq == 0):
            self.__save_model__('{:02d}'.format(epoch), epoch)
Example #4
            # z = 1.8 / 2
            # new_half_w = int(max(r-l, t-b) * z)
            # c = [min(max(0+new_half_w, (l+r)//2), h-1-new_half_w), min(max(0+new_half_w, (b+t)//2), h-1-new_half_w)]
            # c = [int(item) for item in c]
            # print(c)
            # frame0 = cv2.resize(frame[c[0]-new_half_w:c[0]+new_half_w, c[1]-new_half_w:c[1]+new_half_w], (256, 256))
            # ===========================================================================================================

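            # AWing branch (98-point landmarks): resize the full frame to the
            # 256x256 model input, predict heatmaps, and decode landmark coords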
            frame0 = cv2.resize(frame, (256, 256))
            frame = frame0.copy().transpose(
                (2, 0, 1)).astype(np.float32) / 255.0
            inputs = torch.tensor(frame,
                                  requires_grad=False).unsqueeze(0).to(device)
            outputs, boundary_channels = fa_model(inputs)
            pred_heatmap = outputs[-1][:, :-1, :, :].detach().cpu()
            pred_landmarks, _ = get_preds_fromhm(pred_heatmap)
            pred_landmarks = pred_landmarks[0].numpy() * 4
            # pred_landmarks[:, 0], pred_landmarks[:, 1] = pred_landmarks[:, 1] * 1 , pred_landmarks[:, 0] * 1
            frame0 = vis_landmark_on_img98(frame0, pred_landmarks).astype(
                np.uint8)  # 98x2

        else:
            ''' FAN '''
            pred_landmarks = predictor.get_landmarks(frame)[0]
            frame0 = vis_landmark_on_img(frame, pred_landmarks,
                                         linewidth=2)  # 68x2
            frame0 = cv2.resize(frame0, (256, 256))
            pred_landmarks[:, 0], pred_landmarks[:, 1] = \
                pred_landmarks[:, 1] * 256. / h, pred_landmarks[:,