Example #1
0
    def __init__(self,
                 path2data='../dataset/groot/data/speech2gesture_data',
                 path2outdata='../dataset/groot/data',
                 speaker='all',
                 preprocess_methods=['data']):
        """Load the interval table and store the preprocessing settings.

        Args:
            path2data: directory containing ``cmu_intervals_df.csv``; also
                forwarded to the parent constructor.
            path2outdata: root directory used for outputs; also passed to
                ``MissingData``.
            speaker: speaker selection (``'all'`` by default).
            preprocess_methods: preprocessing method(s) to apply.
        """
        super(Skeleton2D, self).__init__(path2data=path2data)
        self.path2data = path2data
        self.df = pd.read_csv(Path(self.path2data) / 'cmu_intervals_df.csv',
                              dtype=object)
        # Assign whole columns instead of `.loc[:, col] = ...`: since pandas
        # 2.0 the `.loc` form writes in place and keeps the column's `object`
        # dtype, so 'delta_time' would not actually become float.
        self.df['delta_time'] = self.df['delta_time'].apply(float)
        self.df['interval_id'] = self.df['interval_id'].apply(str)

        self.path2outdata = path2outdata
        self.speaker = speaker
        self.preprocess_methods = preprocess_methods

        # Record of intervals that could not be processed.
        self.missing = MissingData(self.path2outdata)
Example #2
0
File: text.py  Project: chahuja/mix-stage
    def __init__(self,
                 path2data='../dataset/groot/data',
                 path2outdata='../dataset/groot/data',
                 speaker='all',
                 preprocess_methods=['w2v'],
                 text_aligned=0):
        """Load the interval table and store the preprocessing settings.

        Args:
            path2data: directory containing ``cmu_intervals_df.csv``; also
                forwarded to the parent constructor and to ``MissingData``.
            path2outdata: root directory used for outputs.
            speaker: speaker selection (``'all'`` by default).
            preprocess_methods: list of text preprocessing method names.
            text_aligned: flag stored on the instance as ``text_aligned``.
        """
        super(Text, self).__init__(path2data=path2data)
        self.path2data = path2data
        self.df = pd.read_csv(Path(self.path2data) / 'cmu_intervals_df.csv',
                              dtype=object)
        # Assign whole columns instead of `.loc[:, col] = ...`: since pandas
        # 2.0 the `.loc` form writes in place and keeps the column's `object`
        # dtype, so 'delta_time' would not actually become float.
        self.df['delta_time'] = self.df['delta_time'].apply(float)
        self.df['interval_id'] = self.df['interval_id'].apply(str)

        self.path2outdata = path2outdata
        self.speaker = speaker
        self.preprocess_methods = preprocess_methods

        # Record of intervals that could not be processed.
        self.missing = MissingData(self.path2data)

        ## list of word2-vec models
        self.w2v_models = []
        self.text_aligned = text_aligned
Example #3
0
class Skeleton2D(Modality):
    """2D pose-keypoint modality.

    Reads per-frame keypoint files for each interval listed in
    ``cmu_intervals_df.csv`` and appends the processed arrays to
    ``<path2outdata>/processed/<speaker>/<interval_id>.h5``.
    """

    def __init__(self,
                 path2data='../dataset/groot/data/speech2gesture_data',
                 path2outdata='../dataset/groot/data',
                 speaker='all',
                 preprocess_methods=['data']):
        """Load the interval table and store the preprocessing settings.

        Args:
            path2data: directory holding ``cmu_intervals_df.csv`` and the
                per-speaker keypoint folders; also forwarded to the parent.
            path2outdata: root directory of the processed ``.h5`` files.
            speaker: a speaker name, a list of names, or ``'all'``.
            preprocess_methods: ``'data'``, ``'normalize'`` or
                ``'confidence'``, as a string or a one-element list.
        """
        super(Skeleton2D, self).__init__(path2data=path2data)
        self.path2data = path2data
        self.df = pd.read_csv(Path(self.path2data) / 'cmu_intervals_df.csv',
                              dtype=object)
        # Assign whole columns instead of `.loc[:, col] = ...`: since pandas
        # 2.0 the `.loc` form writes in place and keeps the `object` dtype.
        self.df['delta_time'] = self.df['delta_time'].apply(float)
        self.df['interval_id'] = self.df['interval_id'].apply(str)

        self.path2outdata = path2outdata
        self.speaker = speaker
        self.preprocess_methods = preprocess_methods

        # Record of intervals that failed to process (skipped on re-runs).
        self.missing = MissingData(self.path2outdata)

    def preprocess(self):
        """Process all intervals of the selected speakers in parallel.

        Intervals already listed by ``self.missing.load_intervals()`` are
        skipped; the per-interval results of ``save_intervals`` (failed
        interval ids or ``None``) are handed to ``self.missing``.
        """
        # A bare string ('all' or one speaker name) must not be iterated
        # character by character.
        requested = ([self.speaker] if isinstance(self.speaker, str) else
                     self.speaker)
        if requested[0] != 'all':
            speakers = requested
        else:
            speakers = self.speakers

        for speaker in tqdm(speakers, desc='speakers', leave=False):
            tqdm.write('Speaker: {}'.format(speaker))
            df_speaker = self.get_df_subset('speaker', speaker)
            interval_ids = df_speaker['interval_id'].unique()
            interval_ids = np.array(
                list(set(interval_ids) - self.missing.load_intervals()))

            missing_data_list = Parallel(n_jobs=-1)(
                delayed(self.save_intervals)(interval_id, speaker)
                for interval_id in tqdm(interval_ids))
            self.missing.save_intervals(missing_data_list)

    def save_intervals(self, interval_id, speaker):
        """Process one interval and append the result to its ``.h5`` file.

        Returns:
            ``None`` on success, otherwise ``interval_id``.

        Raises:
            ValueError: if ``preprocess_methods`` names no known method.
        """
        processors = {
            'data': self.process_interval,
            'normalize': self.normalize,
            'confidence': self.confidence,
        }
        method = self.preprocess_methods
        # Accept the list form of the constructor default (['data']) as well
        # as a plain string.
        if isinstance(method, (list, tuple)) and len(method) == 1:
            method = method[0]
        if not isinstance(method, str) or method not in processors:
            raise ValueError('preprocess_methods = {} not found'.format(
                self.preprocess_methods))
        process_interval = processors[method]

        keypoints = process_interval(interval_id)
        if keypoints is None:
            return interval_id

        ## save keypoints
        filename = Path(self.path2outdata
                        ) / 'processed' / speaker / '{}.h5'.format(interval_id)
        key = self.add_key(self.h5_key, self.preprocess_methods)
        try:
            self.append(filename, key, keypoints)
        except Exception:
            # A failed write is reported as a missing interval.
            return interval_id
        return None

    def normalize(self, interval_id):
        """Rescale stored poses so the Neck-to-RShoulder length is 167.

        Reads ``pose/data`` from the interval's processed ``.h5`` file
        (presumably the flattened ``(T, 2 * J)`` root-relative layout written
        by ``process_interval``). Every frame is scaled by
        ``ref_len / |joint 1|`` and the root's coordinates (columns 0 and 52)
        are copied back unscaled.

        Returns:
            The normalized array, or ``None`` if the data cannot be read or
            is 3-D.
        """
        ## get filename from interval_id
        speaker = self.get_df_subset('interval_id',
                                     interval_id).iloc[0].speaker
        filename = Path(self.path2outdata
                        ) / 'processed' / speaker / '{}.h5'.format(interval_id)

        ## Reference shoulder length
        ref_len = 167

        ## load keypoints
        try:
            data, h5 = self.load(filename, 'pose/data')
        except Exception:
            warnings.warn(
                'pose/data not found in filename {}'.format(filename))
            return None
        # Close the file even when reading the dataset fails.
        try:
            data = data[()]
        except Exception:
            warnings.warn(
                'pose/data not found in filename {}'.format(filename))
            return None
        finally:
            h5.close()

        ## exception
        if len(data.shape) == 3:
            return None
        ## normalize: joint 1 (RShoulder) is stored relative to the root
        ratio = ref_len / (
            (data.reshape(data.shape[0], 2, -1)[..., 1]**2).sum(1)**0.5)
        keypoints = ratio.reshape(-1, 1) * data
        keypoints[:, [0, 52]] = data[:, [0, 52]]

        return keypoints

    def berk_confidence(self, interval_id):
        """Return per-frame joint confidences from the ``*_pose.yml`` files.

        Used for interval ids that do not start with ``'c'``. Keeps the last
        channel of each loaded pose and duplicates the ``(Time, Joints)``
        result along axis 1.

        Returns:
            The confidence array, or ``None`` when the keypoint files are
            unavailable or cannot be stacked.
        """
        file_list = self.get_filelist(interval_id)
        if file_list is None:
            return None

        pose_files = [
            replace_Nth_parent(filename[:-4] + '_pose.yml',
                               by='keypoints_all',
                               N=2) for filename in file_list
        ]
        keypoints_list = [loadPose(filename) for filename in pose_files]

        try:
            keypoints = np.stack(keypoints_list, axis=0)
        except Exception:
            # Report the interval as missing instead of stopping in a
            # debugger (this runs inside joblib workers).
            warnings.warn(
                '[BERK_CONFIDENCE] interval_id: {}'.format(interval_id))
            return None
        keypoints = keypoints[..., -1]

        return np.concatenate([keypoints] * 2, axis=1)  ## (Time, Joints)

    def get_speaker(self, interval_id):
        """Return the speaker of the first row matching ``interval_id``."""
        return self.df[self.df['interval_id'] == interval_id].speaker.iloc[0]

    def cmu_confidence(self, interval_id):
        """Return per-frame joint confidences from the raw CMU keypoints.

        Reads ``pose/data`` from
        ``<path2outdata>/raw_keypoints/<speaker>/<interval_id>.h5``, keeps
        index -1 of axis 1 and duplicates the result along axis 1.

        Returns:
            The confidence array, or ``None`` if the file cannot be read.
        """
        filename = Path(
            self.path2outdata) / 'raw_keypoints' / self.get_speaker(
                interval_id) / '{}.h5'.format(interval_id)
        try:
            data, h5 = self.load(filename.as_posix(), 'pose/data')
        except Exception:
            # `h5` and `data` are unbound here, so bail out.
            warnings.warn('interval {} not found'.format(interval_id))
            return None
        try:
            data = data[()]
        except Exception:
            warnings.warn('interval {} not found'.format(interval_id))
            return None
        finally:
            h5.close()

        keypoints = data[:, -1, :]
        return np.concatenate([keypoints] * 2, axis=1)  ## (Time, Joints)

    def confidence(self, interval_id):
        """Dispatch to the CMU (id starts with 'c') or Berkeley confidences."""
        if interval_id[0] == 'c':
            return self.cmu_confidence(interval_id)
        else:
            return self.berk_confidence(interval_id)

    def process_interval(self, interval_id):
        """Load an interval's keypoint text files as a root-relative array.

        Returns:
            ``(T, 2 * J)`` array from ``process_keypoints``, or ``None`` if
            the interval's file list is unavailable.
        """
        file_list = self.get_filelist(interval_id)
        if file_list is None:
            return None

        keypoints_list = [np.loadtxt(filename) for filename in file_list]

        keypoints = np.stack(keypoints_list, axis=0)
        keypoints = self.process_keypoints(keypoints)

        return keypoints

    def process_keypoints(self, keypoints, inv=False):
        """Convert keypoints to (or, with ``inv``, from) root-relative form.

        Forward: subtracts the root joint from every joint (keeping the
        root's absolute position) and flattens to ``(T, -1)``.
        Inverse: reshapes to ``(T, 2, -1)`` and adds the root back.
        """
        if not inv:
            keypoints_new = keypoints - keypoints[..., self.root:self.root + 1]
            keypoints_new[..., self.root] = keypoints[..., self.root]
            keypoints_new = keypoints_new.reshape(keypoints_new.shape[0], -1)
        else:
            keypoints = keypoints.reshape(keypoints.shape[0], 2, -1)
            keypoints_new = keypoints + keypoints[..., self.root:self.root + 1]
            keypoints_new[..., self.root] = keypoints[..., self.root]
        return keypoints_new

    def get_filelist(self, interval_id):
        """Return the per-frame keypoint files covering ``interval_id``.

        Files live in ``<path2data>/<speaker>/keypoints_simple/<video>/`` and
        are ordered by the timestamp parsed from their names.

        Returns:
            Paths from the file at ``start_time`` to the one at ``end_time``
            (inclusive), or ``None`` if either boundary is not found or the
            frames in between are not contiguous.
        """
        df = self.df[self.df['interval_id'] == interval_id]
        start_time = df['start_time'].values[0].split(' ')[-1][1:]
        end_time = df['end_time'].values[0].split(' ')[-1][1:]
        speaker = df['speaker'].values[0]
        video_fn = df['video_fn'].values[0].split('.')[
            0]  ## the folder names end at the first period of the video_fn
        video_fn = Path('_'.join(
            video_fn.split(' ')))  ## the folder names have `_` instead of ` `
        path2keypoints = '{}/{}/keypoints_simple/{}/'.format(
            self.path2data, speaker, video_fn)
        file_df = pd.DataFrame(data=os.listdir(path2keypoints),
                               columns=['files_temp'])
        file_df['files'] = file_df['files_temp'].apply(
            lambda x: (Path(path2keypoints) / x).as_posix())
        file_df['start_time'] = file_df['files_temp'].apply(
            self.get_time_from_file)
        file_df = file_df.sort_values(by='start_time').reset_index()

        try:
            start_id = file_df[file_df['start_time'] == start_time].index[0]
            end_id = file_df[file_df['start_time'] == end_time].index[0]
        except IndexError:  ## boundary frame not present
            return None
        if not (self.are_keypoints_complete(file_df, start_id, end_id)):
            warnings.warn('interval_id: {} not found.'.format(interval_id))
            return None
        return file_df.iloc[start_id:end_id + 1]['files'].values

    def are_keypoints_complete(self, file_df, start_id, end_id):
        """Check that consecutive frames are exactly ``1 / fs`` seconds apart.

        Returns ``False`` if any gap between neighbouring timestamps in rows
        ``start_id..end_id`` differs from ``1 / self.fs('pose/data')`` by more
        than 8e-5 s.
        """
        flag = (
            ((file_df.iloc[start_id + 1:end_id + 1].start_time.apply(
                pd.to_timedelta).reset_index() -
              file_df.iloc[start_id:end_id].start_time.apply(pd.to_timedelta).
              reset_index())['start_time'].apply(lambda x: x.total_seconds()) -
             1 / self.fs('pose/data')).apply(abs) > 0.00008).any()
        if flag:
            return False

        return True

    def get_time_from_file(self, x):
        """Parse ``'H:M:S.ffffff'`` from a keypoint file name.

        Drops the extension, joins the last three ``'_'``-separated fields
        with ``':'`` and appends ``'.000000'`` when no fractional seconds are
        present.
        """
        x_cap = ':'.join('.'.join(
            x.split('.')[:-1]).split('_')[-3:]).split('.')
        if len(
                x_cap
        ) == 1:  ## sometimes the filnames do not have miliseconds as it is all zeros
            x_cap = '.'.join(x_cap + ['000000'])
        else:
            x_cap = '.'.join(x_cap)
        return x_cap

    @property
    def parents(self):
        """Parent joint index of each joint (-1 for the root)."""
        return [
            -1, 0, 1, 2, 0, 4, 5, 0, 7, 7, 6, 10, 11, 12, 13, 10, 15, 16, 17,
            10, 19, 20, 21, 10, 23, 24, 25, 10, 27, 28, 29, 3, 31, 32, 33, 34,
            31, 36, 37, 38, 31, 40, 41, 42, 31, 44, 45, 46, 31, 48, 49, 50
        ]

    @property
    def joint_subset(self):
        ## choose only the relevant skeleton key-points (removed nose and eyes)
        return np.r_[range(7), range(10, len(self.parents))]

    @property
    def root(self):
        """Index of the root joint ('Neck')."""
        return 0

    @property
    def joint_names(self):
        """Names of the 52 joints, in storage order."""
        return [
            'Neck', 'RShoulder', 'RElbow', 'RWrist', 'LShoulder', 'LElbow',
            'LWrist', 'Nose', 'REye', 'LEye', 'LHandRoot', 'LHandThumb1',
            'LHandThumb2', 'LHandThumb3', 'LHandThumb4', 'LHandIndex1',
            'LHandIndex2', 'LHandIndex3', 'LHandIndex4', 'LHandMiddle1',
            'LHandMiddle2', 'LHandMiddle3', 'LHandMiddle4', 'LHandRing1',
            'LHandRing2', 'LHandRing3', 'LHandRing4', 'LHandLittle1',
            'LHandLittle2', 'LHandLittle3', 'LHandLittle4', 'RHandRoot',
            'RHandThumb1', 'RHandThumb2', 'RHandThumb3', 'RHandThumb4',
            'RHandIndex1', 'RHandIndex2', 'RHandIndex3', 'RHandIndex4',
            'RHandMiddle1', 'RHandMiddle2', 'RHandMiddle3', 'RHandMiddle4',
            'RHandRing1', 'RHandRing2', 'RHandRing3', 'RHandRing4',
            'RHandLittle1', 'RHandLittle2', 'RHandLittle3', 'RHandLittle4'
        ]

    def fs(self, modality):
        """Sampling rate (frames per second) of the pose data."""
        return 15

    @property
    def h5_key(self):
        """Top-level HDF5 group for this modality."""
        return 'pose'
Example #4
0
File: text.py  Project: chahuja/mix-stage
class Text(Modality):
    """Text (transcript) modality.

    Aligns transcript words to frames, stores the word/frame table under the
    ``meta`` key and one feature array per preprocessing method (``w2v``,
    ``bert``, ``tokens``, ``pos``) in
    ``<path2outdata>/processed/<speaker>/<interval_id>.h5``.
    """

    def __init__(self,
                 path2data='../dataset/groot/data',
                 path2outdata='../dataset/groot/data',
                 speaker='all',
                 preprocess_methods=['w2v'],
                 text_aligned=0):
        """Load the interval table and store the preprocessing settings.

        Args:
            path2data: directory containing ``cmu_intervals_df.csv``, the raw
                transcripts and (for the aligned path) existing ``.h5`` files.
            path2outdata: root directory of the processed ``.h5`` files.
            speaker: a speaker name, a list of names, or ``'all'``.
            preprocess_methods: list of method names (see ``preprocess``).
            text_aligned: if truthy, reuse the stored ``text/meta`` tables
                instead of re-aligning the raw transcripts.
        """
        super(Text, self).__init__(path2data=path2data)
        self.path2data = path2data
        self.df = pd.read_csv(Path(self.path2data) / 'cmu_intervals_df.csv',
                              dtype=object)
        # Assign whole columns instead of `.loc[:, col] = ...`: since pandas
        # 2.0 the `.loc` form writes in place and keeps the `object` dtype.
        self.df['delta_time'] = self.df['delta_time'].apply(float)
        self.df['interval_id'] = self.df['interval_id'].apply(str)

        self.path2outdata = path2outdata
        self.speaker = speaker
        self.preprocess_methods = preprocess_methods

        # Record of intervals that could not be processed.
        self.missing = MissingData(self.path2data)

        ## list of word2-vec models
        self.w2v_models = []
        self.text_aligned = text_aligned

    def preprocess(self):
        """Load one embedding model per method, then process all speakers.

        Raises:
            ValueError: if a preprocessing method is unknown.
        """
        ## load Glove/Word2Vec
        for pre_meth in self.preprocess_methods:
            if pre_meth == 'w2v':
                self.w2v_models.append(Word2Vec())
            elif pre_meth == 'bert':
                self.w2v_models.append(
                    BertForSequenceEmbedding(hidden_size=512))
            elif pre_meth == 'tokens':
                self.w2v_models.append(BertSentenceBatching())
            elif pre_meth == 'pos':
                self.w2v_models.append(POStagging())
            else:
                raise ValueError(
                    'preprocess_method {} not found'.format(pre_meth))
        print('Embedding models loaded')

        # A bare string ('all' or one speaker name) must not be iterated
        # character by character.
        requested = ([self.speaker] if isinstance(self.speaker, str) else
                     self.speaker)
        if requested[0] != 'all':
            speakers = requested
        else:
            speakers = self.speakers

        if self.text_aligned:
            self.text_aligned_preprocessing(speakers)
        else:
            self.text_notAligned_preprocessing(speakers)

    def text_aligned_preprocessing(self, speakers):
        """Recompute features from the ``text/meta`` tables already on disk."""
        for speaker in tqdm(speakers, desc='speakers', leave=False):
            tqdm.write('Speaker: {}'.format(speaker))
            df_speaker = self.get_df_subset('speaker', speaker)
            filename_dict = {}
            interval_id_list = []
            for interval_id in tqdm(df_speaker.interval_id.unique(),
                                    desc='load'):
                path2interval = Path(
                    self.path2data) / 'processed' / speaker / '{}.h5'.format(
                        interval_id)
                try:
                    text = pd.read_hdf(path2interval, 'text/meta', 'r')
                except Exception:
                    warnings.warn(
                        'text/meta not found for {}'.format(interval_id))
                    continue
                filename_dict[interval_id] = text
                interval_id_list.append(interval_id)
            missing_data_list = []
            for interval_id in tqdm(interval_id_list, desc='save'):
                inter = self.save_intervals(interval_id, speaker,
                                            filename_dict, None)
                missing_data_list.append(inter)
            self.missing.save_intervals(set(missing_data_list))

    def text_notAligned_preprocessing(self, speakers):
        """Align the raw ``*_transcripts`` CSVs to intervals and save them.

        Intervals of a speaker that no transcript word falls into are handed
        to ``self.missing``.
        """
        for speaker in tqdm(speakers, desc='speakers', leave=False):
            tqdm.write('Speaker: {}'.format(speaker))
            # Work on an explicit copy: new columns are added below.
            df_speaker = self.get_df_subset('speaker', speaker).copy()
            df_speaker['video_id'] = df_speaker['video_link'].apply(
                lambda x: x.split('=')[-1])
            df_speaker['Start'] = pd.to_timedelta(
                df_speaker['start_time'].str.split().str[1]).dt.total_seconds(
                )
            df_speaker['End'] = pd.to_timedelta(
                df_speaker['end_time'].str.split().str[1]).dt.total_seconds()
            interval_ids = df_speaker['interval_id'].unique()
            ## find path to processed files
            parent = Path(self.path2data) / 'raw' / '{}'.format(speaker)
            filenames = os.listdir(parent)
            filenames = [
                filename for filename in filenames
                if filename.split('_')[-1] == 'transcripts'
            ]
            filenames = [
                '{}/{}.csv'.format(filename,
                                   '_'.join(filename.split('_')[:-1]))
                for filename in filenames
            ]
            ## remove paths that don't exist
            filenames = [
                filename for filename in filenames
                if os.path.exists(Path(parent) / filename)
            ]
            filename_dict = {
                Path(filename).stem: filename
                for filename in filenames
            }

            interval_lists = []
            for key in tqdm(filename_dict):
                interval_list = self.get_intervals_from_videos(
                    key, df_speaker, filename_dict, parent, speaker)
                interval_lists += interval_list
            missing_data_list = set(interval_ids) - set(interval_lists)
            self.missing.save_intervals(missing_data_list)

    def get_intervals_from_videos(self, key, df, filename_dict, basepath,
                                  speaker):
        """Split one video's transcript into intervals and save each one.

        Each word is assigned to the interval whose ``(Start, End]`` window
        contains the word's end time. For intervals with pose data, word
        start times become frame indices at ``self.fs('text')`` (clipped to
        the pose length) and the interval is saved via ``save_intervals``.

        Returns:
            Ids of all intervals that received at least one word.
        """
        ## Read the transcript
        path2text = Path(basepath) / filename_dict[key]
        text = pd.read_csv(path2text)

        ## get all intervals from the video id in a sorted table
        if key[:2] == '_-':
            key = key[2:]
        df_video = df[df['video_id'] == key].sort_values(by='start_time')
        if df_video.empty:  ## non youtube videos
            new_key = '-'.join(key.split('-')[-5:])
            df_video = df[df['video_id'].apply(
                lambda x: new_key in x)].sort_values(by='start_time')
        text.loc[:, 'interval_id'] = text['End'].apply(
            self.find_interval_for_words, args=(df_video, ))

        # Drop words that fell outside every interval (None).
        interval_ids = list(filter(None, text['interval_id'].unique()))
        for interval_id in interval_ids:
            try:
                ## get max_len of the pose data
                interval_path = replace_Nth_parent(
                    basepath, 'processed') / '{}.h5'.format(interval_id)
                data, h5 = self.load(interval_path, 'pose/data')
                max_len = data.shape[0]
                h5.close()
            except Exception:  ## sometimes the interval is missing
                continue

            start_offset = pd.to_timedelta(
                self.df[self.df['interval_id'] == interval_id]
                ['start_time'].str.split().str[1]).dt.total_seconds().iloc[0]

            # Word i spans [start_i, start_{i+1}); the first starts at frame
            # 0 and the last ends at the pose length.
            start_frames, end_frames = [], []
            for i, row in text[text['interval_id'] ==
                               interval_id].reset_index().iterrows():
                start = row['Start']
                if i == 0:
                    start_frames.append(0)
                else:
                    start_frames.append(
                        int(
                            min(int((start - start_offset) * self.fs('text')),
                                max_len)))
                    end_frames.append(start_frames[-1])
            end_frames.append(max_len)
            text.loc[text['interval_id'] == interval_id,
                     'start_frame'] = start_frames
            text.loc[text['interval_id'] == interval_id,
                     'end_frame'] = end_frames
            subtext = text[text['interval_id'] == interval_id].reset_index()
            self.save_intervals(interval_id, speaker, {interval_id: subtext},
                                basepath)
        return interval_ids

    ## Find intervals corresponding to each word
    def find_interval_for_words(self, end_time, df):
        """Return the id of the interval with ``Start < end_time <= End``.

        Warns if several intervals match (the first is used); returns
        ``None`` if none does.
        """
        interval_ids = df[(df['End'] >= end_time)
                          & (df['Start'] < end_time)]['interval_id']
        if interval_ids.shape[0] > 1:
            warnings.warn('More than one interval for one word')
        if interval_ids.shape[0] == 0:
            return None
        return str(interval_ids.iloc[0])

    def save_intervals(self, interval_id, speaker, filename_dict, parent):
        """Write the word table and every method's features for an interval.

        Returns:
            ``None`` on success, otherwise ``interval_id``.
        """
        if interval_id in filename_dict:
            ## Store Meta
            text = filename_dict[interval_id][[
                'Word', 'start_frame', 'end_frame'
            ]]
            filename = Path(
                self.path2outdata) / 'processed' / speaker / '{}.h5'.format(
                    interval_id)
            key = self.add_key(self.h5_key, ['meta'])

            if not HDF5.isDatasetInFile(filename, key):
                text.to_hdf(filename, key, mode='a')

            ## process data for each preprocess_method
            processed_datas = self.process_interval(interval_id, parent,
                                                    filename_dict)

            ## save processed_data
            for preprocess_method, processed_data in zip(
                    self.preprocess_methods, processed_datas):
                if processed_data is None:
                    warnings.warn('interval_id: {} not processed.'.format(
                        interval_id))
                    return interval_id
                filename = Path(
                    self.path2outdata
                ) / 'processed' / speaker / '{}.h5'.format(interval_id)
                key = self.add_key(self.h5_key, [preprocess_method])
                try:
                    self.append(filename, key, processed_data)
                except Exception:
                    warnings.warn('interval_id: {} busy.'.format(interval_id))
                    return interval_id
            return None
        else:
            warnings.warn('interval_id: {} not found.'.format(interval_id))
            return interval_id

    def process_interval(self, interval_id, parent, filename_dict):
        """Compute one feature array per (method, model) pair.

        ``words_repeated`` holds each word once per frame it spans.
        """
        ## get filename
        text = filename_dict[interval_id]
        words_repeated = []
        for word, start_frame, end_frame in zip(text['Word'],
                                                text['start_frame'],
                                                text['end_frame']):
            words_repeated += [word] * int((end_frame - start_frame))

        processed_datas = []
        ## process file
        for preprocess_method, model in zip(self.preprocess_methods,
                                            self.w2v_models):
            if preprocess_method in ['w2v']:
                processed_datas.append(self.preprocess_map[preprocess_method](
                    words_repeated, model))
            elif preprocess_method in ['bert', 'tokens']:
                processed_datas.append(self.preprocess_map[preprocess_method](
                    text, model))
            elif preprocess_method in ['pos']:
                processed_datas.append(self.preprocess_map[preprocess_method](
                    text, model, words_repeated))

        ## return processed output
        return processed_datas

    # PreProcess Methods

    @property
    def preprocess_map(self):
        """Map from method name to the feature function implementing it."""
        return {
            'w2v': self.w2v,
            'bert': self.bert,
            'tokens': self.bert_tokens,
            'pos': self.pos
        }

    def w2v(self, words, model):
        """Word2Vec features of the per-frame word list."""
        return model(words)[0].squeeze(1)

    @staticmethod
    def _frames_per_wordpiece(wordpieces, words, text_delta_frames):
        """Split each word's frame count over its word-pieces.

        Pieces starting with ``'##'`` continue the previous piece. Pieces
        are concatenated until they spell ``words[count]``; that word's
        ``text_delta_frames`` are then divided evenly over its pieces, with
        the last piece absorbing the remainder.
        """
        delta_frames_cap = []
        temp_word = []
        count = 0
        for word in wordpieces:
            temp_word.append(word[2:] if '##' == word[:2] else word)
            if ''.join(temp_word) == words[count]:
                n_pieces = len(temp_word)
                word_frames = text_delta_frames.iloc[count]
                delta_frames_cap += [int(word_frames / n_pieces)] * n_pieces
                if n_pieces > 1:
                    delta_frames_cap[-1] = word_frames - sum(
                        delta_frames_cap[-n_pieces + 1:])
                temp_word = []
                count += 1
        return delta_frames_cap

    @staticmethod
    def _concat_frame_features(feats, text_delta_frames):
        """Concatenate per-frame feature slices.

        Returns ``None`` (with a warning) when there is nothing to
        concatenate; warns when the frame count disagrees with
        ``sum(text_delta_frames)``.
        """
        if not feats:
            warnings.warn('no word-piece features to concatenate')
            return None
        feats = torch.cat(feats, dim=0)
        if not feats.shape[0] == sum(text_delta_frames):
            warnings.warn('feature frames ({}) != transcript frames ({})'.format(
                feats.shape[0], sum(text_delta_frames)))
        return feats

    def bert(self, text, model):
        """Per-frame contextual BERT features of one interval's words."""
        text['delta_frames'] = (text['end_frame'] -
                                text['start_frame']).apply(int)
        text_delta_frames = text.delta_frames
        words = [word.lower() for word in text['Word'].values]
        sentence = [' '.join(words)]
        outs, _pool, words_cap, _mask = model(sentence)
        # Skip the first/last tokens of the single sentence (presumably the
        # special [CLS]/[SEP] markers).
        delta_frames_cap = self._frames_per_wordpiece(words_cap[0][1:-1],
                                                      words,
                                                      text_delta_frames)

        feats = []
        for i, frames in enumerate(delta_frames_cap):
            # +1 because `outs` still contains the leading special token.
            feats += [outs[0, i + 1:i + 2]] * frames
        return self._concat_frame_features(feats, text_delta_frames)

    def bert_tokens(self, text, model):
        """Per-frame BERT token features, with the sentence split in batches."""
        text['delta_frames'] = (text['end_frame'] -
                                text['start_frame']).apply(int)
        text_delta_frames = text.delta_frames
        words = [word.lower() for word in text['Word'].values]
        sentence = [' '.join(words)]
        outs, mask, words_cap = model(sentence)

        # Strip the first and last valid positions of every batch (assuming
        # `mask` is 1 for real tokens) and stitch the batches together.
        wordpieces = []
        outs_list = []
        for wc, mk, ot in zip(words_cap, mask, outs):
            n_valid = sum(mk).item()
            wordpieces += wc[1:n_valid - 1]
            outs_list.append(ot[1:n_valid - 1])
        outs = torch.cat(outs_list)

        delta_frames_cap = self._frames_per_wordpiece(wordpieces, words,
                                                      text_delta_frames)

        feats = []
        for i, frames in enumerate(delta_frames_cap):
            feats += [outs[i:i + 1]] * frames
        return self._concat_frame_features(feats, text_delta_frames)

    def pos(self, text, model, words_repeated):
        """Part-of-speech features computed by ``model``."""
        return model(text, words_repeated)

    def fs(self, modality):
        """Frame rate used to convert word times to frame indices."""
        return 15

    @property
    def h5_key(self):
        """Top-level HDF5 group for this modality."""
        return 'text'
Example #5
0
File: audio.py  Project: chahuja/mix-stage
class Audio(Modality):
    """Audio modality: log-mel spectrograms and a per-frame silence mask.

    Reads ``<path2data>/raw/<speaker>_cropped/*.mp3`` (keyed by the last
    ``'_'``-separated part of each file name) and appends one array per
    preprocessing method to
    ``<path2outdata>/processed/<speaker>/<interval_id>.h5``.
    """

    def __init__(self,
                 path2data='../dataset/groot/data',
                 path2outdata='../dataset/groot/data',
                 speaker='all',
                 preprocess_methods=['log_mel_512']):
        """Load the interval table and store the preprocessing settings.

        Args:
            path2data: directory containing ``cmu_intervals_df.csv`` and the
                raw audio folders.
            path2outdata: root directory of the processed ``.h5`` files.
            speaker: a speaker name, a list of names, or ``'all'``.
            preprocess_methods: list of keys of ``preprocess_map``.
        """
        super(Audio, self).__init__(path2data=path2data)
        self.path2data = path2data
        self.df = pd.read_csv(Path(self.path2data) / 'cmu_intervals_df.csv',
                              dtype=object)
        # Assign whole columns instead of `.loc[:, col] = ...`: since pandas
        # 2.0 the `.loc` form writes in place and keeps the `object` dtype.
        self.df['delta_time'] = self.df['delta_time'].apply(float)
        self.df['interval_id'] = self.df['interval_id'].apply(str)

        self.path2outdata = path2outdata
        self.speaker = speaker
        self.preprocess_methods = preprocess_methods

        # Record of intervals that could not be processed.
        self.missing = MissingData(self.path2data)

    def preprocess(self):
        """Process every interval of the selected speakers in parallel."""
        # A bare string ('all' or one speaker name) must not be iterated
        # character by character.
        requested = ([self.speaker] if isinstance(self.speaker, str) else
                     self.speaker)
        if requested[0] != 'all':
            speakers = requested
        else:
            speakers = self.speakers

        for speaker in tqdm(speakers, desc='speakers', leave=False):
            tqdm.write('Speaker: {}'.format(speaker))
            df_speaker = self.get_df_subset('speaker', speaker)
            interval_ids = df_speaker['interval_id'].unique()

            ## find path to processed files
            parent = Path(
                self.path2data) / 'raw' / '{}_cropped'.format(speaker)
            filenames = os.listdir(parent)
            filenames = [
                filename for filename in filenames
                if filename.split('.')[-1] == 'mp3'
            ]
            filename_dict = {
                filename.split('.')[0].split('_')[-1]: filename
                for filename in filenames
            }
            missing_data_list = Parallel(n_jobs=-1)(
                delayed(self.save_intervals)(interval_id, speaker,
                                             filename_dict, parent)
                for interval_id in tqdm(interval_ids, desc='intervals'))
            self.missing.save_intervals(missing_data_list)

    def save_intervals(self, interval_id, speaker, filename_dict, parent):
        """Compute and append every method's features for one interval.

        Returns:
            ``None`` on success, otherwise ``interval_id``.
        """
        if interval_id in filename_dict:
            ## process data for each preprocess_method
            processed_datas = self.process_interval(interval_id, parent,
                                                    filename_dict)

            ## save processed_data
            for preprocess_method, processed_data in zip(
                    self.preprocess_methods, processed_datas):
                if processed_data is None:
                    warnings.warn('{}.mp3 not readable.'.format(interval_id))
                    return interval_id
                filename = Path(
                    self.path2outdata
                ) / 'processed' / speaker / '{}.h5'.format(interval_id)
                key = self.add_key(self.h5_key, [preprocess_method])
                self.append(filename, key, processed_data)
            return None
        else:
            warnings.warn('interval_id: {} not found.'.format(interval_id))
            return interval_id

    def process_interval(self, interval_id, parent, filename_dict):
        """Load the interval's mp3 and apply every preprocessing method.

        Returns:
            One result per method, or a list of ``None`` if loading fails.
        """
        ## get filename
        filename = parent / filename_dict[interval_id]

        ## read file
        try:
            y, sr = librosa.load(filename, sr=None, mono=True)
        except Exception:
            return [None] * len(self.preprocess_methods)
        processed_datas = []
        ## process file
        for preprocess_method in self.preprocess_methods:
            processed_datas.append(self.preprocess_map[preprocess_method](y,
                                                                          sr))
        ## return processed output
        return processed_datas

    # PreProcess Methods

    @property
    def preprocess_map(self):
        """Map from method name to the feature function implementing it."""
        return {
            'log_mel_512': self.log_mel_512,
            'log_mel_400': self.log_mel_400,
            'silence': self.silence
        }

    def log_mel_512(self, y, sr, eps=1e-10):
        """Log mel spectrogram (n_fft=2048, hop 512), shape (frames, mels).

        Exact zeros are replaced with ``eps`` before the log.
        """
        spec = librosa.feature.melspectrogram(y=y,
                                              sr=sr,
                                              n_fft=2048,
                                              hop_length=512)
        # `np.float` was removed in NumPy 1.24; np.float64 is equivalent.
        mask = (spec == 0).astype(np.float64)
        spec = mask * eps + (1 - mask) * spec
        return np.log(spec).transpose(1, 0)

    def log_mel_400(self, y, sr, eps=1e-6):
        """64-band log mel spectrogram of 16 kHz audio (win 400, hop 160).

        Computed from a magnitude STFT (n_fft=512, no centering) with
        fmin=125 Hz and fmax=7500 Hz; shape (frames, 64).
        """
        y = librosa.core.resample(y, orig_sr=sr,
                                  target_sr=16000)  ## resampling to 16k Hz
        sr = 16000
        n_fft = 512
        hop_length = 160
        win_length = 400
        S = librosa.core.stft(y=y.reshape((-1)),
                              n_fft=n_fft,
                              hop_length=hop_length,
                              win_length=win_length,
                              center=False)

        S = np.abs(S)
        spec = librosa.feature.melspectrogram(S=S,
                                              sr=sr,
                                              n_fft=n_fft,
                                              hop_length=hop_length,
                                              power=1,
                                              n_mels=64,
                                              fmin=125.0,
                                              fmax=7500.0,
                                              norm=None)
        mask = (spec == 0).astype(np.float64)
        spec = mask * eps + (1 - mask) * spec
        return np.log(spec).transpose(1, 0)

    def silence(self, y, sr, eps=1e-6):
        """Per-frame silence flags (15 per second) from WebRTC VAD.

        The signal is resampled to 16 kHz and split into 1/15 s chunks, each
        checked in 10 ms frames. A chunk is flagged 1 when at most half of
        its frames are classified as speech. A single trailing 0 stands in
        for the final partial chunk, giving one value per 1/15 s.
        """
        vad = webrtcvad.Vad(3)
        y = librosa.core.resample(y, orig_sr=sr,
                                  target_sr=16000)  ## resampling to 16k Hz
        # WebRTC VAD only accepts 16-bit mono PCM; the raw float32 bytes
        # would be decoded as meaningless samples.
        pcm = (np.clip(y, -1.0, 1.0) * 32767).astype(np.int16)
        fs_old = 16000
        fs_new = 15
        ranges = np.arange(0, pcm.shape[0], fs_old / fs_new)
        starts = ranges[0:-1]
        ends = ranges[1:]

        is_speeches = []
        for start, end in zip(starts, ends):
            frame_edges = np.arange(start, end, fs_old / 100)
            is_speech = []
            for s, e in zip(frame_edges[:-1], frame_edges[1:]):
                try:
                    is_speech.append(
                        vad.is_speech(pcm[int(s):int(e)].tobytes(), fs_old))
                except Exception:
                    # Skip frames the VAD rejects instead of halting.
                    warnings.warn('VAD failed on samples {}:{}'.format(
                        int(s), int(e)))
            is_speeches.append(
                int(np.array(is_speech, dtype=int).mean() <= 0.5))
        # One value for the trailing partial chunk dropped by ranges[0:-1];
        # appending inside the loop would double the output rate.
        is_speeches.append(0)
        return np.array(is_speeches, dtype=int)

    @property
    def fs_map(self):
        """Frames per second of each method's output."""
        return {
            'log_mel_512':
            int(45.6 * 1000 /
                512),  #int(44.1*1000/512) #112 #round(22.5*1000/512)
            'log_mel_400': int(16.52 * 1000 / 160),
            'silence': 15
        }

    def fs(self, modality):
        """Frame rate for ``modality``; only the last '/'-part is used."""
        modality = modality.split('/')[-1]
        return self.fs_map[modality]

    @property
    def h5_key(self):
        """Top-level HDF5 group for this modality."""
        return 'audio'