Example #1
    def __init__(self,
                 split,
                 data_root,
                 inner_shuffle=True,
                 limit=0,
                 sampling_half_window_size_seconds=2.0,
                 unmask_fringe_width=10,
                 img_augment=True,
                 filelists_dir='filelists'):
        self.all_videos = list(
            filter(
                lambda vidname: os.path.exists(join(vidname, "audio.wav")),
                get_image_list(data_root,
                               split,
                               limit=limit,
                               filelists_dir=filelists_dir)))
        self.img_names = {
            vidname:
            sorted(glob(join(vidname, '*.png')),
                   key=lambda name: int(os.path.basename(name).split('.')[0]))
            for vidname in self.all_videos
        }

        # Cache one mel spectrogram per video; recompute from audio.wav
        # when mel.npy is missing or unreadable.
        self.orig_mels = {}
        for vidname in tqdm(self.all_videos, desc="load mels"):
            mel_path = join(vidname, "mel.npy")
            wavpath = join(vidname, "audio.wav")
            assert os.path.exists(wavpath), wavpath
            if os.path.exists(mel_path):
                try:
                    orig_mel = np.load(mel_path)
                except Exception as err:
                    print(err)
                    wav = audio.load_wav(wavpath, hparams.sample_rate)
                    orig_mel = audio.melspectrogram(wav).T
                    np.save(mel_path, orig_mel)
            else:
                wav = audio.load_wav(wavpath, hparams.sample_rate)
                orig_mel = audio.melspectrogram(wav).T
                np.save(mel_path, orig_mel)
            self.orig_mels[vidname] = orig_mel
        self.data_root = data_root
        self.inner_shuffle = inner_shuffle
        self.all_videos_p = None
        self.linear_space = np.array(range(len(self.all_videos)))
        if inner_shuffle:
            # Weight video sampling by frame count so longer clips are
            # drawn proportionally more often.
            imgs_counts = [
                len(self.img_names[vidname]) for vidname in self.all_videos
            ]
            self.all_videos_p = np.array(imgs_counts) / np.sum(imgs_counts)
        self.sampling_half_window_size_seconds = sampling_half_window_size_seconds
        self.unmask_fringe_width = int(unmask_fringe_width)
        # Unmasked fringe: keep a border of unmask_fringe_width pixels
        # visible around the masked region.
        self.fringe_x1 = self.unmask_fringe_width
        self.fringe_x2 = hparams.img_size - self.unmask_fringe_width
        assert self.fringe_x2 > self.fringe_x1
        self.fringe_y2 = hparams.img_size - self.unmask_fringe_width
        assert self.fringe_y2 > hparams.img_size // 2
        self.img_augment = img_augment
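A minimal usage sketch for this constructor, assuming the enclosing class is named Wav2LipDataset (the class name is not shown above) and that it implements __len__/__getitem__ for use with a PyTorch DataLoader:

# Hypothetical usage; Wav2LipDataset is an assumed name for the enclosing
# class, and filelists/train.txt is assumed to list the video directories.
from torch.utils.data import DataLoader

train_dataset = Wav2LipDataset(
    split='train',
    data_root='preprocessed_root',
    inner_shuffle=True,                     # draw videos weighted by frame count
    sampling_half_window_size_seconds=2.0,  # audio window around a chosen frame
    unmask_fringe_width=10,                 # border pixels left unmasked
    img_augment=True)
train_loader = DataLoader(train_dataset, batch_size=16, num_workers=4)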
Example #2
    def __init__(self, data_root, only_true_image=True, img_size=96):
        self.all_videos = []
        for dirname in os.listdir(data_root):
            dirpath = os.path.join(data_root, dirname)
            if not os.path.isdir(dirpath):
                continue
            for vid_dirname in os.listdir(dirpath):
                video_path = os.path.join(dirpath, vid_dirname)
                wavpath = os.path.join(video_path, "audio.wav")
                if len(os.listdir(video_path)) < 3 * hp.syncnet_T + 2:
                    print("insufficient files of dir:", vid_dirname)
                    continue
                if not os.path.exists(wavpath):
                    print("skip missing audio of:", vid_dirname)
                    continue
                self.all_videos.append(video_path)

        self.img_names = {
            vidname:
            sorted(glob(os.path.join(vidname, '*.png')),
                   key=lambda name: int(os.path.basename(name).split('.')[0]))
            for vidname in self.all_videos
        }

        self.orig_mels = {}
        for vidname in tqdm(self.all_videos, desc="load mels"):
            mel_path = os.path.join(vidname, "mel.npy")
            wavpath = os.path.join(vidname, "audio.wav")
            if os.path.exists(mel_path):
                try:
                    orig_mel = np.load(mel_path)
                except Exception as err:
                    print(err)
                    wav = audio.load_wav(wavpath, hp.sample_rate)
                    orig_mel = audio.melspectrogram(wav).T
                    np.save(mel_path, orig_mel)
            else:
                wav = audio.load_wav(wavpath, hp.sample_rate)
                orig_mel = audio.melspectrogram(wav).T
                np.save(mel_path, orig_mel)
            self.orig_mels[vidname] = orig_mel
        self.data_root = data_root
        self.inner_shuffle = False
        self.sampling_half_window_size_seconds = 1e10

        # Experiments show that the model discriminates wrong images very
        # well, so there is no need to sample wrong images.
        self.only_true_image = only_true_image
        self.img_size = img_size
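Usage is analogous to Example #1; the enclosing class name is again an assumption:

# Hypothetical usage; SyncnetEvalDataset is an assumed name. data_root is
# expected to hold speaker dirs, each containing per-video dirs with
# numbered .png frames and an audio.wav.
eval_dataset = SyncnetEvalDataset('preprocessed_root',
                                  only_true_image=True,
                                  img_size=96)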
Example #3
def stream_mel_chunk(filepath, fps):
    wav = audio.load_wav(filepath, hp.sample_rate)
    mel = audio.melspectrogram(wav)
    print("mel", mel.shape)

    if np.isnan(mel.reshape(-1)).sum() > 0:
        raise ValueError(
            'Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again'
        )

    # hp.num_mels / fps is the number of mel frames per video frame: at the
    # default hparams (sample_rate=16000, hop_size=200) the mel frame rate
    # is 80 per second, which coincides with hp.num_mels.
    mel_idx_multiplier = hp.num_mels / fps
    i = 0
    while True:
        start_idx = int(i * mel_idx_multiplier)
        if start_idx + hp.syncnet_mel_step_size > len(mel[0]):
            # Final window: align to the end of the spectrogram.
            yield mel[:, len(mel[0]) - hp.syncnet_mel_step_size:]
            break
        yield mel[:, start_idx:start_idx + hp.syncnet_mel_step_size]
        i += 1
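One way to consume this generator is to pair each yielded mel window with a video frame; a sketch, assuming frame_iter yields frames at the same fps:

# Sketch: zip stops at the shorter iterator, so trailing tail windows or
# extra frames are dropped automatically.
for frame, mel_chunk in zip(frame_iter, stream_mel_chunk('audio.wav', fps=25)):
    pass  # feed (frame, mel_chunk) to the lip-sync model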
Example #4
def to_mels(audio_path, fps, num_mels=80, mel_step_size=16, sample_rate=16000):
    wav = audio.load_wav(audio_path, sample_rate)
    mel = audio.melspectrogram(wav)

    if np.isnan(mel.reshape(-1)).sum() > 0:
        raise ValueError(
            'Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again'
        )

    mel_chunks = []
    # num_mels / fps works out to mel frames per video frame at the default
    # hparams (the mel frame rate is 80 per second).
    mel_idx_multiplier = num_mels / fps
    i = 0
    while True:
        start_idx = int(i * mel_idx_multiplier)
        if start_idx + mel_step_size > len(mel[0]):
            # Final window: align to the end of the spectrogram.
            mel_chunks.append(mel[:, len(mel[0]) - mel_step_size:])
            break
        mel_chunks.append(mel[:, start_idx:start_idx + mel_step_size])
        i += 1
    return mel_chunks
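Unlike the streaming variant above, to_mels materializes the whole list, which makes it easy to trim the video to match; a sketch, assuming full_frames was loaded elsewhere:

# Sketch: chunk the audio for a 25 fps video and trim frames to match.
mel_chunks = to_mels('audio.wav', fps=25)
full_frames = full_frames[:len(mel_chunks)]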
Example #5
def main():
    assert args.data_root is not None
    data_root = args.data_root
    # fps, mel_idx_multiplier and mel_step_size are used below but never
    # defined in this function; presumably module-level constants in the
    # original script.

    os.makedirs(args.results_dir, exist_ok=True)

    with open(args.filelist, 'r') as filelist:
        lines = filelist.readlines()

    for idx, line in enumerate(tqdm(lines)):
        audio_src, video = line.strip().split()

        audio_src = os.path.join(data_root, audio_src) + '.mp4'
        video = os.path.join(data_root, video) + '.mp4'

        command = "ffmpeg -loglevel panic -y -i '{}' -strict -2 '{}'".format(
            audio_src, '../temp/temp.wav')
        subprocess.call(command, shell=True)
        temp_audio = '../temp/temp.wav'

        wav = audio.load_wav(temp_audio, 16000)
        mel = audio.melspectrogram(wav)
        if np.isnan(mel.reshape(-1)).sum() > 0:
            continue

        mel_chunks = []
        i = 0
        while True:
            start_idx = int(i * mel_idx_multiplier)
            if start_idx + mel_step_size > len(mel[0]):
                break
            mel_chunks.append(mel[:, start_idx:start_idx + mel_step_size])
            i += 1

        video_stream = cv2.VideoCapture(video)

        full_frames = []
        while True:
            still_reading, frame = video_stream.read()
            if not still_reading or len(full_frames) > len(mel_chunks):
                video_stream.release()
                break
            full_frames.append(frame)

        if len(full_frames) < len(mel_chunks):
            continue

        full_frames = full_frames[:len(mel_chunks)]

        try:
            face_det_results = face_detect(full_frames.copy())
        except ValueError:
            # Skip videos where face detection fails.
            continue

        batch_size = args.wav2lip_batch_size
        gen = datagen(full_frames.copy(), face_det_results, mel_chunks)

        for i, (img_batch, mel_batch, frames, coords) in enumerate(gen):
            if i == 0:
                frame_h, frame_w = full_frames[0].shape[:-1]
                out = cv2.VideoWriter('../temp/result.avi',
                                      cv2.VideoWriter_fourcc(*'DIVX'), fps,
                                      (frame_w, frame_h))

            img_batch = torch.FloatTensor(np.transpose(
                img_batch, (0, 3, 1, 2))).to(device)
            mel_batch = torch.FloatTensor(np.transpose(
                mel_batch, (0, 3, 1, 2))).to(device)

            with torch.no_grad():
                pred = model(mel_batch, img_batch)

            pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.

            for pl, f, c in zip(pred, frames, coords):
                y1, y2, x1, x2 = c
                pl = cv2.resize(pl.astype(np.uint8), (x2 - x1, y2 - y1))
                f[y1:y2, x1:x2] = pl
                out.write(f)

        out.release()

        vid = os.path.join(args.results_dir, '{}.mp4'.format(idx))

        command = "ffmpeg -loglevel panic -y -i '{}' -i '{}' -strict -2 -q:v 1 '{}'".format(
            temp_audio, '../temp/result.avi', vid)
        subprocess.call(command, shell=True)
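The shell-interpolated ffmpeg command breaks on paths containing single quotes; passing an argument list avoids shell quoting entirely. A sketch of the equivalent call:

# Sketch: the same ffmpeg invocation without shell=True.
subprocess.call(['ffmpeg', '-loglevel', 'panic', '-y',
                 '-i', temp_audio, '-i', '../temp/result.avi',
                 '-strict', '-2', '-q:v', '1', vid])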
Example #6
def main():
    os.makedirs(args.results_dir, exist_ok=True)

    if args.mode == 'dubbed':
        files = os.listdir(args.data_root)
        lines = ['{} {}'.format(f, f) for f in files]

    else:
        assert args.filelist is not None
        with open(args.filelist, 'r') as filelist:
            lines = filelist.readlines()

    for idx, line in enumerate(tqdm(lines)):
        video, audio_src = line.strip().split()

        audio_src = os.path.join(args.data_root, audio_src)
        video = os.path.join(args.data_root, video)

        command = "ffmpeg -loglevel panic -y -i '{}' -strict -2 '{}'".format(
            audio_src, '../temp/temp.wav')
        subprocess.call(command, shell=True)
        temp_audio = '../temp/temp.wav'

        wav = audio.load_wav(temp_audio, 16000)
        mel = audio.melspectrogram(wav)

        if np.isnan(mel.reshape(-1)).sum() > 0:
            raise ValueError('Mel contains nan!')

        video_stream = cv2.VideoCapture(video)

        fps = video_stream.get(cv2.CAP_PROP_FPS)
        mel_idx_multiplier = 80. / fps

        full_frames = []
        while True:
            still_reading, frame = video_stream.read()
            if not still_reading:
                video_stream.release()
                break

            if min(frame.shape[:-1]) > args.max_frame_res:
                # Downscale frames whose short side exceeds max_frame_res,
                # preserving the aspect ratio.
                h, w = frame.shape[:-1]
                scale_factor = min(h, w) / float(args.max_frame_res)
                h = int(h / scale_factor)
                w = int(w / scale_factor)
                frame = cv2.resize(frame, (w, h))
            full_frames.append(frame)

        mel_chunks = []
        i = 0
        while True:
            start_idx = int(i * mel_idx_multiplier)
            if start_idx + mel_step_size > len(mel[0]):
                break
            mel_chunks.append(mel[:, start_idx:start_idx + mel_step_size])
            i += 1

        if len(full_frames) < len(mel_chunks):
            if args.mode == 'tts':
                full_frames = increase_frames(full_frames, len(mel_chunks))
            else:
                raise ValueError('#Frames, audio length mismatch')

        else:
            full_frames = full_frames[:len(mel_chunks)]

        try:
            face_det_results, full_frames = face_detect(full_frames.copy())
        except ValueError:
            # Skip videos where face detection fails.
            continue

        batch_size = args.wav2lip_batch_size
        gen = datagen(full_frames.copy(), face_det_results, mel_chunks)

        for i, (img_batch, mel_batch, frames, coords) in enumerate(gen):
            if i == 0:
                frame_h, frame_w = full_frames[0].shape[:-1]

                out = cv2.VideoWriter('../temp/result.avi',
                                      cv2.VideoWriter_fourcc(*'DIVX'), fps,
                                      (frame_w, frame_h))

            img_batch = torch.FloatTensor(np.transpose(
                img_batch, (0, 3, 1, 2))).to(device)
            mel_batch = torch.FloatTensor(np.transpose(
                mel_batch, (0, 3, 1, 2))).to(device)

            with torch.no_grad():
                pred = model(mel_batch, img_batch)

            pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.

            for pl, f, c in zip(pred, frames, coords):
                y1, y2, x1, x2 = c
                pl = cv2.resize(pl.astype(np.uint8), (x2 - x1, y2 - y1))
                f[y1:y2, x1:x2] = pl
                out.write(f)

        out.release()

        vid = os.path.join(args.results_dir, '{}.mp4'.format(idx))
        command = "ffmpeg -loglevel panic -y -i '{}' -i '{}' -strict -2 -q:v 1 '{}'".format(
            '../temp/temp.wav', '../temp/result.avi', vid)
        subprocess.call(command, shell=True)
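increase_frames is referenced above but not shown; a plausible sketch that stretches the frame list by evenly repeating frames while preserving temporal order:

def increase_frames(frames, target_len):
    # Hypothetical implementation of the undefined helper: map target_len
    # evenly spaced positions back onto the existing frames.
    idxs = np.round(np.linspace(0, len(frames) - 1, num=target_len)).astype(int)
    return [frames[i] for i in idxs]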
Example #7
def get_mel_chunks_count(filepath, fps):
    wav = audio.load_wav(filepath, hp.sample_rate)
    mel = audio.melspectrogram(wav)
    mel_idx_multiplier = hp.num_mels / fps
    # Count one window per frame index until the window would overrun the
    # spectrogram, plus the final tail window, mirroring stream_mel_chunk.
    count = 0
    while int(count * mel_idx_multiplier) + hp.syncnet_mel_step_size <= len(mel[0]):
        count += 1
    return count + 1
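Knowing the window count up front lets callers size buffers or progress bars before streaming; a sketch combining it with stream_mel_chunk from Example #3:

# Sketch: drive a progress bar over the streamed mel windows.
n = get_mel_chunks_count('audio.wav', fps=25)
for chunk in tqdm(stream_mel_chunk('audio.wav', fps=25), total=n):
    pass  # process each (num_mels, syncnet_mel_step_size) window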