def __init__(self, tsv_path, dict_path, unit, batch_size, n_epochs=1e10,
             is_test=False, min_n_frames=40, max_n_frames=2000,
             shuffle_bucket=False, sort_by='utt_id',
             short2long=False, sort_stop_epoch=1000, dynamic_batching=False,
             corpus='',
             tsv_path_sub1=False, tsv_path_sub2=False,
             dict_path_sub1=False, dict_path_sub2=False, nlsyms=False,
             unit_sub1=False, unit_sub2=False,
             wp_model=False, wp_model_sub1=False, wp_model_sub2=False,
             ctc=False, ctc_sub1=False, ctc_sub2=False,
             subsample_factor=1, subsample_factor_sub1=1, subsample_factor_sub2=1,
             discourse_aware=False, first_n_utterances=-1):
    """A class for loading dataset.

    Args:
        tsv_path (str): path to the dataset tsv file
        dict_path (str): path to the dictionary
        unit (str): word/wp/char/phone/word_char
        batch_size (int): size of mini-batch
        n_epochs (int): total epochs for training.
        is_test (bool): if True, only empty utterances are removed and no
            sorting is applied
        min_n_frames (int): exclude utterances shorter than this value
        max_n_frames (int): exclude utterances longer than this value
        shuffle_bucket (bool): gather the similar length of utterances and
            shuffle them (requires sort_by to be 'input' or 'output')
        sort_by (str): how to order utterances in training
            input: sort by input length
            output: sort by output length
            shuffle: shuffle all utterances
            utt_id: keep the order of the tsv file
        short2long (bool): passed as `ascending` when sorting, i.e. True
            sorts from short to long and False from long to short
        sort_stop_epoch (int): After sort_stop_epoch, training will revert
            back to a random order
        dynamic_batching (bool): change batch size dynamically in training
        corpus (str): name of corpus
        tsv_path_sub1/tsv_path_sub2 (str): tsv files for the sub tasks
        dict_path_sub1/dict_path_sub2 (str): dictionaries for the sub tasks
        nlsyms (str): path to the non-linguistic symbols file
        unit_sub1/unit_sub2 (str): wp/char/phone units for the sub tasks
        wp_model (): path to the word-piece model for sentencepiece
        wp_model_sub1/wp_model_sub2 (): word-piece models for the sub tasks
        ctc (bool): in training, drop utterances whose ylen exceeds
            xlen // subsample_factor (only when subsample_factor > 1)
        ctc_sub1/ctc_sub2 (bool): same check for the sub tasks
        subsample_factor (int): divisor applied to xlen in the CTC check
        subsample_factor_sub1/subsample_factor_sub2 (int): same for the
            sub tasks
        discourse_aware (bool): order utterances by session and onset and
            bucket them with `discourse_bucketing` (not allowed with is_test)
        first_n_utterances (int): evaluate the first N utterances

    Raises:
        ValueError: if `unit` or a sub unit is not supported
        NotImplementedError: if discourse_aware is set for a corpus whose
            onset cannot be parsed from utt_id

    """
    super(Dataset, self).__init__()

    self.epoch = 0
    self.iteration = 0
    self.offset = 0

    self.set = os.path.basename(tsv_path).split('.')[0]
    self.is_test = is_test
    self.unit = unit
    self.unit_sub1 = unit_sub1
    self.batch_size = batch_size
    self.n_epochs = n_epochs
    self.shuffle_bucket = shuffle_bucket
    if shuffle_bucket:
        assert sort_by in ['input', 'output']
    self.sort_stop_epoch = sort_stop_epoch
    self.sort_by = sort_by
    assert sort_by in ['input', 'output', 'shuffle', 'utt_id']
    self.dynamic_batching = dynamic_batching
    self.corpus = corpus
    self.discourse_aware = discourse_aware
    if discourse_aware:
        assert not is_test

    self.vocab = count_vocab_size(dict_path)
    self.eos = 2
    self.pad = 3
    # NOTE: reserved in advance

    self.idx2token = []
    self.token2idx = []

    # Set index converter
    if unit in ['word', 'word_char']:
        self.idx2token += [Idx2word(dict_path)]
        self.token2idx += [Word2idx(dict_path, word_char_mix=(unit == 'word_char'))]
    elif unit == 'wp':
        self.idx2token += [Idx2wp(dict_path, wp_model)]
        self.token2idx += [Wp2idx(dict_path, wp_model)]
    elif unit in ['char']:
        self.idx2token += [Idx2char(dict_path)]
        self.token2idx += [Char2idx(dict_path, nlsyms=nlsyms)]
    elif 'phone' in unit:
        self.idx2token += [Idx2phone(dict_path)]
        self.token2idx += [Phone2idx(dict_path)]
    else:
        raise ValueError(unit)

    # Converters for the sub tasks are appended after the main one
    for i, dict_path_sub, unit_sub, wp_model_sub in (
            (1, dict_path_sub1, unit_sub1, wp_model_sub1),
            (2, dict_path_sub2, unit_sub2, wp_model_sub2)):
        if dict_path_sub:
            setattr(self, 'vocab_sub' + str(i), count_vocab_size(dict_path_sub))

            # Set index converter
            if unit_sub:
                if unit_sub == 'wp':
                    self.idx2token += [Idx2wp(dict_path_sub, wp_model_sub)]
                    self.token2idx += [Wp2idx(dict_path_sub, wp_model_sub)]
                elif unit_sub == 'char':
                    self.idx2token += [Idx2char(dict_path_sub)]
                    self.token2idx += [Char2idx(dict_path_sub, nlsyms=nlsyms)]
                elif 'phone' in unit_sub:
                    self.idx2token += [Idx2phone(dict_path_sub)]
                    self.token2idx += [Phone2idx(dict_path_sub)]
                else:
                    raise ValueError(unit_sub)
        else:
            setattr(self, 'vocab_sub' + str(i), -1)

    # Load dataset tsv file
    columns = ['utt_id', 'speaker', 'feat_path',
               'xlen', 'xdim', 'text', 'token_id', 'ylen', 'ydim']
    df = pd.read_csv(tsv_path, encoding='utf-8', delimiter='\t')
    df = df.loc[:, columns]
    for i, tsv_path_sub in ((1, tsv_path_sub1), (2, tsv_path_sub2)):
        if tsv_path_sub:
            df_sub = pd.read_csv(tsv_path_sub, encoding='utf-8', delimiter='\t')
            df_sub = df_sub.loc[:, columns]
            setattr(self, 'df_sub' + str(i), df_sub)
        else:
            setattr(self, 'df_sub' + str(i), None)
    # Read before any filtering so that row label 0 is still present
    self.input_dim = kaldiio.load_mat(df['feat_path'][0]).shape[-1]

    # Remove inappropriate utterances
    print('Original utterance num: %d' % len(df))
    n_utts = len(df)
    if is_test or discourse_aware:
        df = df[df['ylen'] > 0]
        print('Removed %d empty utterances' % (n_utts - len(df)))
        if first_n_utterances > 0:
            # truncate() is label-based: rows whose original line number
            # is below first_n_utterances are kept
            df = df.truncate(before=0, after=first_n_utterances - 1)
            print('Select first %d utterances' % len(df))
    else:
        df = df[df['xlen'].between(min_n_frames, max_n_frames)]
        df = df[df['ylen'] > 0]
        print('Removed %d utterances (threshold)' % (n_utts - len(df)))

        if ctc and subsample_factor > 1:
            n_utts = len(df)
            df = df[df['ylen'] <= (df['xlen'] // subsample_factor)]
            print('Removed %d utterances (for CTC)' % (n_utts - len(df)))

        for i, ctc_sub, subsample_factor_sub in (
                (1, ctc_sub1, subsample_factor_sub1),
                (2, ctc_sub2, subsample_factor_sub2)):
            df_sub = getattr(self, 'df_sub' + str(i))
            if df_sub is None:
                continue

            if ctc_sub and subsample_factor_sub > 1:
                df_sub = df_sub[df_sub['ylen'] <= (df_sub['xlen'] // subsample_factor_sub)]

            if len(df) != len(df_sub):
                n_utts = len(df)
                df = df.drop(df.index.difference(df_sub.index))
                print('Removed %d utterances (for CTC, sub%d)' % (n_utts - len(df), i))
                # Keep the sub tasks seen so far aligned with the main task
                for j in range(1, i + 1):
                    name = 'df_sub' + str(j)
                    df_sub_j = getattr(self, name)
                    # sub1 may be absent when only sub2 is given
                    if df_sub_j is not None:
                        setattr(self, name,
                                df_sub_j.drop(df_sub_j.index.difference(df.index)))

    # NOTE: sessions are not serialized (the speaker id is used as it is),
    # for every corpus including swbd
    df['session'] = df['speaker'].apply(lambda x: str(x))

    # Sort tsv records
    if discourse_aware:
        # Sort by onset (start time)
        df = df.assign(prev_utt='')
        df = df.assign(line_no=list(range(len(df))))
        if corpus == 'swbd':
            df['onset'] = df['utt_id'].apply(lambda x: int(x.split('_')[-1].split('-')[0]))
        elif corpus == 'csj':
            df['onset'] = df['utt_id'].apply(lambda x: int(x.split('_')[1]))
        elif corpus == 'tedlium2':
            df['onset'] = df['utt_id'].apply(lambda x: int(x.split('-')[-2]))
        else:
            raise NotImplementedError(corpus)
        df = df.sort_values(by=['session', 'onset'], ascending=True)

        # Extract previous utterances
        groups = df.groupby('session').groups
        df['prev_utt'] = df.apply(
            lambda x: [df.loc[i, 'line_no']
                       for i in groups[x['session']] if df.loc[i, 'onset'] < x['onset']], axis=1)
        df['n_prev_utt'] = df.apply(lambda x: len(x['prev_utt']), axis=1)
        df['n_utt_in_session'] = df.apply(
            lambda x: len(groups[x['session']]), axis=1)
        df = df.sort_values(by=['n_utt_in_session'], ascending=short2long)
    elif not is_test:
        if sort_by == 'input':
            df = df.sort_values(by=['xlen'], ascending=short2long)
        elif sort_by == 'output':
            df = df.sort_values(by=['ylen'], ascending=short2long)
        elif sort_by == 'shuffle':
            # self.df does not exist yet at this point; permute the local frame
            df = df.reindex(np.random.permutation(df.index))

    # Re-indexing
    if discourse_aware:
        # Original row labels are kept (prev_utt/groups refer to them)
        self.df = df
        for i in range(1, 3):
            name = 'df_sub' + str(i)
            if getattr(self, name) is not None:
                setattr(self, name, getattr(self, name).reindex(df.index))
    else:
        self.df = df.reset_index()
        for i in range(1, 3):
            name = 'df_sub' + str(i)
            if getattr(self, name) is not None:
                setattr(self, name,
                        getattr(self, name).reindex(df.index).reset_index())

    if discourse_aware:
        self.df_indices_buckets = self.discourse_bucketing(batch_size)
    elif shuffle_bucket:
        self.df_indices_buckets = self.shuffle_bucketing(batch_size)
    else:
        self.df_indices = list(self.df.index)
def __init__(self, corpus, tsv_path, dict_path,
             unit, nlsyms, wp_model,
             is_test, min_n_frames, max_n_frames,
             sort_by, short2long,
             tsv_path_sub1, tsv_path_sub2,
             ctc, ctc_sub1, ctc_sub2,
             subsample_factor, subsample_factor_sub1, subsample_factor_sub2,
             dict_path_sub1, dict_path_sub2,
             unit_sub1, unit_sub2,
             wp_model_sub1, wp_model_sub2,
             discourse_aware=False, simulate_longform=False,
             first_n_utterances=-1,
             word_alignment_dir=None, ctc_alignment_dir=None):
    """Custom Dataset class.

    Args:
        corpus (str): name of corpus
        tsv_path (str): path to the dataset tsv file
        dict_path (str): path to the dictionary
        unit (str): word/wp/char/phone/word_char
        nlsyms (str): path to the non-linguistic symbols file
        wp_model (): path to the word-piece model for sentencepiece
        is_test (bool): if True, only empty utterances are removed and no
            sorting is applied
        min_n_frames (int): exclude utterances shorter than this value
        max_n_frames (int): exclude utterances longer than this value
        sort_by (str): how to order utterances in training
            input: sort by input length
            output: sort by output length
            shuffle: shuffle all utterances
        short2long (bool): passed as `ascending` when sorting, i.e. True
            sorts from short to long and False from long to short
        tsv_path_sub1/tsv_path_sub2 (str): tsv files for the sub tasks
        ctc (bool): in training, drop utterances whose ylen exceeds
            xlen // subsample_factor (only when subsample_factor > 1)
        ctc_sub1/ctc_sub2 (bool): same check for the sub tasks
        subsample_factor (int): divisor applied to xlen in the CTC check
        subsample_factor_sub1/subsample_factor_sub2 (int): same for the
            sub tasks
        dict_path_sub1/dict_path_sub2 (str): dictionaries for the sub tasks
        unit_sub1/unit_sub2 (str): wp/char/phone units for the sub tasks
        wp_model_sub1/wp_model_sub2 (): word-piece models for the sub tasks
        discourse_aware (bool): sort in the discourse order
            (not allowed with is_test)
        simulate_longform (bool): simulate long-form utterance
            (only allowed with is_test)
        first_n_utterances (int): evaluate the first N utterances
        word_alignment_dir (str): path to word alignment directory
        ctc_alignment_dir (str): path to CTC alignment directory
            (ignored when word_alignment_dir is given)

    Raises:
        ValueError: if `unit` or a sub unit is not supported
        NotImplementedError: if discourse_aware is set for a corpus whose
            onset cannot be parsed from utt_id

    """
    super(Dataset, self).__init__()

    self.epoch = 0

    # meta data accessed by dataloader
    self._corpus = corpus
    self._set = os.path.basename(tsv_path).split('.')[0]
    self._vocab = count_vocab_size(dict_path)
    self._unit = unit
    self._unit_sub1 = unit_sub1
    self._unit_sub2 = unit_sub2

    self.is_test = is_test
    self.sort_by = sort_by
    if discourse_aware:
        assert not is_test
    if simulate_longform:
        assert is_test
    self.simulate_longform = simulate_longform
    self.subsample_factor = subsample_factor
    self.word_alignment_dir = word_alignment_dir
    self.ctc_alignment_dir = ctc_alignment_dir

    self._idx2token = []
    self._token2idx = []

    # Set index converter
    if unit in ['word', 'word_char']:
        self._idx2token += [Idx2word(dict_path)]
        self._token2idx += [Word2idx(dict_path, word_char_mix=(unit == 'word_char'))]
    elif unit == 'wp':
        self._idx2token += [Idx2wp(dict_path, wp_model)]
        self._token2idx += [Wp2idx(dict_path, wp_model)]
    elif unit in ['char']:
        self._idx2token += [Idx2char(dict_path)]
        self._token2idx += [Char2idx(dict_path, nlsyms=nlsyms)]
    elif 'phone' in unit:
        self._idx2token += [Idx2phone(dict_path)]
        self._token2idx += [Phone2idx(dict_path)]
    else:
        raise ValueError(unit)

    # Converters for the sub tasks are appended after the main one
    for i, dict_path_sub, unit_sub, wp_model_sub in (
            (1, dict_path_sub1, unit_sub1, wp_model_sub1),
            (2, dict_path_sub2, unit_sub2, wp_model_sub2)):
        if dict_path_sub:
            setattr(self, '_vocab_sub' + str(i), count_vocab_size(dict_path_sub))

            # Set index converter
            if unit_sub:
                if unit_sub == 'wp':
                    self._idx2token += [Idx2wp(dict_path_sub, wp_model_sub)]
                    self._token2idx += [Wp2idx(dict_path_sub, wp_model_sub)]
                elif unit_sub == 'char':
                    self._idx2token += [Idx2char(dict_path_sub)]
                    self._token2idx += [Char2idx(dict_path_sub, nlsyms=nlsyms)]
                elif 'phone' in unit_sub:
                    self._idx2token += [Idx2phone(dict_path_sub)]
                    self._token2idx += [Phone2idx(dict_path_sub)]
                else:
                    raise ValueError(unit_sub)
        else:
            setattr(self, '_vocab_sub' + str(i), -1)

    # Load dataset tsv file
    columns = ['utt_id', 'speaker', 'feat_path',
               'xlen', 'xdim', 'text', 'token_id', 'ylen', 'ydim']
    df = pd.read_csv(tsv_path, encoding='utf-8', delimiter='\t')
    df = df.loc[:, columns]
    for i, tsv_path_sub in ((1, tsv_path_sub1), (2, tsv_path_sub2)):
        if tsv_path_sub:
            df_sub = pd.read_csv(tsv_path_sub, encoding='utf-8', delimiter='\t')
            df_sub = df_sub.loc[:, columns]
            setattr(self, 'df_sub' + str(i), df_sub)
        else:
            setattr(self, 'df_sub' + str(i), None)
    # Read before any filtering so that row label 0 is still present
    self._input_dim = kaldiio.load_mat(df['feat_path'][0]).shape[-1]

    # Remove inappropriate utterances
    print('Original utterance num: %d' % len(df))
    n_utts = len(df)
    if is_test or discourse_aware:
        df = df[df['ylen'] > 0]
        print('Removed %d empty utterances' % (n_utts - len(df)))
        if first_n_utterances > 0:
            # truncate() is label-based: rows whose original line number
            # is below first_n_utterances are kept
            df = df.truncate(before=0, after=first_n_utterances - 1)
            print('Select first %d utterances' % len(df))
    else:
        df = df[df['xlen'].between(min_n_frames, max_n_frames)]
        df = df[df['ylen'] > 0]
        print('Removed %d utterances (threshold)' % (n_utts - len(df)))

        if ctc and subsample_factor > 1:
            n_utts = len(df)
            df = df[df['ylen'] <= (df['xlen'] // subsample_factor)]
            print('Removed %d utterances (for CTC)' % (n_utts - len(df)))

        for i, ctc_sub, subsample_factor_sub in (
                (1, ctc_sub1, subsample_factor_sub1),
                (2, ctc_sub2, subsample_factor_sub2)):
            df_sub = getattr(self, 'df_sub' + str(i))
            if df_sub is None:
                continue

            if ctc_sub and subsample_factor_sub > 1:
                df_sub = df_sub[df_sub['ylen'] <= (df_sub['xlen'] // subsample_factor_sub)]

            if len(df) != len(df_sub):
                n_utts = len(df)
                df = df.drop(df.index.difference(df_sub.index))
                print('Removed %d utterances (for CTC, sub%d)' % (n_utts - len(df), i))
                # Keep the sub tasks seen so far aligned with the main task
                for j in range(1, i + 1):
                    name = 'df_sub' + str(j)
                    df_sub_j = getattr(self, name)
                    # sub1 may be absent when only sub2 is given
                    if df_sub_j is not None:
                        setattr(self, name,
                                df_sub_j.drop(df_sub_j.index.difference(df.index)))

    # NOTE: sessions are not serialized (the speaker id is used as it is),
    # for every corpus including swbd
    df['session'] = df['speaker'].apply(lambda x: str(x))

    # Sort tsv records
    if discourse_aware:
        # Sort by onset (start time)
        df = df.assign(prev_utt='')
        df = df.assign(line_no=list(range(len(df))))
        if corpus == 'swbd':
            df['onset'] = df['utt_id'].apply(lambda x: int(x.split('_')[-1].split('-')[0]))
        elif corpus == 'csj':
            df['onset'] = df['utt_id'].apply(lambda x: int(x.split('_')[1]))
        elif corpus == 'tedlium2':
            df['onset'] = df['utt_id'].apply(lambda x: int(x.split('-')[-2]))
        else:
            raise NotImplementedError(corpus)
        df = df.sort_values(by=['session', 'onset'], ascending=True)

        # Extract previous utterances
        groups = df.groupby('session').groups
        df['prev_utt'] = df.apply(
            lambda x: [df.loc[i, 'line_no']
                       for i in groups[x['session']] if df.loc[i, 'onset'] < x['onset']], axis=1)
        df['n_prev_utt'] = df.apply(lambda x: len(x['prev_utt']), axis=1)
        df['n_utt_in_session'] = df.apply(
            lambda x: len(groups[x['session']]), axis=1)
        df = df.sort_values(by=['n_utt_in_session'], ascending=short2long)
    elif not is_test:
        if sort_by == 'input':
            df = df.sort_values(by=['xlen'], ascending=short2long)
        elif sort_by == 'output':
            df = df.sort_values(by=['ylen'], ascending=short2long)
        elif sort_by == 'shuffle':
            # self.df does not exist yet at this point; permute the local frame
            df = df.reindex(np.random.permutation(df.index))

    # Fit word alignment to vocabulary
    if word_alignment_dir is not None:
        alignment2boundary = WordAlignmentConverter(dict_path, wp_model)
        n_utts = len(df)
        df['trigger_points'] = df.apply(lambda x: alignment2boundary(
            word_alignment_dir, x['speaker'], x['utt_id'], x['text']), axis=1)
        # remove utterances which do not have the alignment
        df = df[df.apply(lambda x: x['trigger_points'] is not None, axis=1)]
        print('Removed %d utterances (for word alignment)' % (n_utts - len(df)))
    elif ctc_alignment_dir is not None:
        n_utts = len(df)
        df['trigger_points'] = df.apply(lambda x: load_ctc_alignment(
            ctc_alignment_dir, x['speaker'], x['utt_id']), axis=1)
        # remove utterances which do not have the alignment
        df = df[df.apply(lambda x: x['trigger_points'] is not None, axis=1)]
        print('Removed %d utterances (for CTC alignment)' % (n_utts - len(df)))

    # Re-indexing
    if discourse_aware:
        # Original row labels are kept (prev_utt/groups refer to them)
        self.df = df
        for i in range(1, 3):
            name = 'df_sub' + str(i)
            if getattr(self, name) is not None:
                setattr(self, name, getattr(self, name).reindex(df.index))
    else:
        self.df = df.reset_index()
        for i in range(1, 3):
            name = 'df_sub' + str(i)
            if getattr(self, name) is not None:
                setattr(self, name,
                        getattr(self, name).reindex(df.index).reset_index())
def __init__(self, tsv_path, dict_path,
             unit, batch_size, nlsyms=False, n_epochs=None,
             is_test=False, min_n_frames=40, max_n_frames=2000,
             shuffle_bucket=False, sort_by='utt_id',
             short2long=False, sort_stop_epoch=None, dynamic_batching=False,
             ctc=False, subsample_factor=1, wp_model=False, corpus='',
             tsv_path_sub1=False, dict_path_sub1=False, unit_sub1=False,
             wp_model_sub1=False, ctc_sub1=False, subsample_factor_sub1=1,
             tsv_path_sub2=False, dict_path_sub2=False, unit_sub2=False,
             wp_model_sub2=False, ctc_sub2=False, subsample_factor_sub2=1,
             discourse_aware=False, skip_thought=False):
    """A class for loading dataset.

    Args:
        tsv_path (str): path to the dataset tsv file
        dict_path (str): path to the dictionary
        unit (str): word or wp or char or phone or word_char
        batch_size (int): size of mini-batch
        nlsyms (str): path to the non-linguistic symbols file
        n_epochs (int): max epoch. None means infinite loop.
        is_test (bool): if True, only empty utterances are removed and no
            length-based sorting is applied
        min_n_frames (int): exclude utterances shorter than this value
            (overridden to 100 when discourse_aware or skip_thought is set)
        max_n_frames (int): exclude utterances longer than this value
            (overridden to 10000 when discourse_aware or skip_thought is set)
        shuffle_bucket (bool): gather the similar length of utterances and
            shuffle them (requires sort_by to be 'input' or 'output')
        sort_by (str): how to order utterances in training
            input: sort by input length
            output: sort by output length
            shuffle: shuffle all utterances
            utt_id/no_sort: keep the current order
        short2long (bool): passed as `ascending` when sorting, i.e. True
            sorts from short to long and False from long to short
        sort_stop_epoch (int): After sort_stop_epoch, training will revert
            back to a random order
        dynamic_batching (bool): change batch size dynamically in training
        ctc (bool): in training, drop utterances whose ylen exceeds
            xlen // subsample_factor (only when subsample_factor > 1)
        subsample_factor (int): divisor applied to xlen in the CTC check
        wp_model (): path to the word-piece model for sentencepiece
        corpus (str): name of corpus
        tsv_path_sub*/dict_path_sub*/unit_sub*/wp_model_sub*/ctc_sub*/
        subsample_factor_sub*: counterparts of the above for sub tasks 1, 2
        discourse_aware (bool): order utterances by session and onset and,
            in training, build per-session lookup tables
        skip_thought (bool): order utterances by session and onset

    Raises:
        ValueError: if `unit` or a sub unit is not supported
        NotImplementedError: if discourse_aware/skip_thought is set for a
            corpus whose onset cannot be derived from utt_id

    """
    super(Dataset, self).__init__()

    self.epoch = 0
    self.iteration = 0
    self.offset = 0

    self.set = os.path.basename(tsv_path).split('.')[0]
    self.is_test = is_test
    self.unit = unit
    self.unit_sub1 = unit_sub1
    self.batch_size = batch_size
    self.max_epoch = n_epochs
    self.shuffle_bucket = shuffle_bucket
    if shuffle_bucket:
        assert sort_by in ['input', 'output']
    self.sort_stop_epoch = sort_stop_epoch
    self.sort_by = sort_by
    assert sort_by in ['input', 'output', 'shuffle', 'utt_id', 'no_sort']
    self.dynamic_batching = dynamic_batching
    self.corpus = corpus
    self.discourse_aware = discourse_aware
    self.skip_thought = skip_thought

    self.vocab = count_vocab_size(dict_path)
    self.eos = 2
    self.pad = 3
    # NOTE: reserved in advance

    self.idx2token = []
    self.token2idx = []

    # Set index converter
    if unit in ['word', 'word_char']:
        self.idx2token += [Idx2word(dict_path)]
        self.token2idx += [Word2idx(dict_path, word_char_mix=(unit == 'word_char'))]
    elif unit == 'wp':
        self.idx2token += [Idx2wp(dict_path, wp_model)]
        self.token2idx += [Wp2idx(dict_path, wp_model)]
    elif unit == 'char':
        self.idx2token += [Idx2char(dict_path)]
        self.token2idx += [Char2idx(dict_path, nlsyms=nlsyms)]
    elif 'phone' in unit:
        self.idx2token += [Idx2phone(dict_path)]
        self.token2idx += [Phone2idx(dict_path)]
    else:
        raise ValueError(unit)

    # Converters for the sub tasks are appended after the main one
    for i, dict_path_sub, unit_sub, wp_model_sub in (
            (1, dict_path_sub1, unit_sub1, wp_model_sub1),
            (2, dict_path_sub2, unit_sub2, wp_model_sub2)):
        if dict_path_sub:
            setattr(self, 'vocab_sub' + str(i), count_vocab_size(dict_path_sub))

            # Set index converter
            if unit_sub:
                if unit_sub == 'wp':
                    self.idx2token += [Idx2wp(dict_path_sub, wp_model_sub)]
                    self.token2idx += [Wp2idx(dict_path_sub, wp_model_sub)]
                elif unit_sub == 'char':
                    self.idx2token += [Idx2char(dict_path_sub)]
                    self.token2idx += [Char2idx(dict_path_sub, nlsyms=nlsyms)]
                elif 'phone' in unit_sub:
                    self.idx2token += [Idx2phone(dict_path_sub)]
                    self.token2idx += [Phone2idx(dict_path_sub)]
                else:
                    raise ValueError(unit_sub)
        else:
            setattr(self, 'vocab_sub' + str(i), -1)

    # Load dataset tsv file
    columns = ['utt_id', 'speaker', 'feat_path',
               'xlen', 'xdim', 'text', 'token_id', 'ylen', 'ydim']
    df = pd.read_csv(tsv_path, encoding='utf-8', delimiter='\t',
                     quoting=csv.QUOTE_NONE, dtype={'utt_id': 'str'})
    df = df.loc[:, columns]
    for i, tsv_path_sub in ((1, tsv_path_sub1), (2, tsv_path_sub2)):
        if tsv_path_sub:
            df_sub = pd.read_csv(tsv_path_sub, encoding='utf-8', delimiter='\t')
            df_sub = df_sub.loc[:, columns]
            setattr(self, 'df_sub' + str(i), df_sub)
        else:
            setattr(self, 'df_sub' + str(i), None)
    # Read before any filtering so that row label 0 is still present
    self.input_dim = kaldiio.load_mat(df['feat_path'][0]).shape[-1]

    if corpus == 'swbd':
        df['session'] = df['speaker'].apply(lambda x: str(x).split('-')[0])
    else:
        df['session'] = df['speaker'].apply(lambda x: str(x))

    if discourse_aware or skip_thought:
        # The caller-supplied frame thresholds are replaced in this mode
        max_n_frames = 10000
        min_n_frames = 100

        # Sort by onset
        df = df.assign(prev_utt='')
        if corpus == 'swbd':
            df['onset'] = df['utt_id'].apply(lambda x: int(x.split('_')[-1].split('-')[0]))
        elif corpus == 'csj':
            df['onset'] = df['utt_id'].apply(lambda x: int(x.split('_')[1]))
        elif corpus == 'wsj':
            # utt_id itself is used as the sort key
            df['onset'] = df['utt_id'].apply(lambda x: x)
        else:
            raise NotImplementedError
        df = df.sort_values(by=['session', 'onset'], ascending=True)

        # Count utterances per session
        if not skip_thought:
            groups = df.groupby('session').groups
            df['n_session_utt'] = df.apply(
                lambda x: len(groups[x['session']]), axis=1)
    elif is_test and corpus == 'swbd':
        # Sort by onset
        df['onset'] = df['utt_id'].apply(lambda x: int(x.split('_')[-1].split('-')[0]))
        df = df.sort_values(by=['session', 'onset'], ascending=True)

    # Remove inappropriate utterances
    print('Original utterance num: %d' % len(df))
    n_utts = len(df)
    if is_test:
        df = df[df['ylen'] > 0]
        print('Removed %d empty utterances' % (n_utts - len(df)))
    else:
        df = df[df['xlen'].between(min_n_frames, max_n_frames)]
        df = df[df['ylen'] > 0]
        print('Removed %d utterances (threshold)' % (n_utts - len(df)))

        if ctc and subsample_factor > 1:
            n_utts = len(df)
            df = df[df['ylen'] <= (df['xlen'] // subsample_factor)]
            print('Removed %d utterances (for CTC)' % (n_utts - len(df)))

        for i, ctc_sub, subsample_factor_sub in (
                (1, ctc_sub1, subsample_factor_sub1),
                (2, ctc_sub2, subsample_factor_sub2)):
            df_sub = getattr(self, 'df_sub' + str(i))
            if df_sub is None:
                continue

            if ctc_sub and subsample_factor_sub > 1:
                df_sub = df_sub[df_sub['ylen'] <= (df_sub['xlen'] // subsample_factor_sub)]

            if len(df) != len(df_sub):
                n_utts = len(df)
                df = df.drop(df.index.difference(df_sub.index))
                print('Removed %d utterances (for CTC, sub%d)' % (n_utts - len(df), i))
                # Keep the sub tasks seen so far aligned with the main task
                for j in range(1, i + 1):
                    name = 'df_sub' + str(j)
                    df_sub_j = getattr(self, name)
                    # sub1 may be absent when only sub2 is given
                    if df_sub_j is not None:
                        setattr(self, name,
                                df_sub_j.drop(df_sub_j.index.difference(df.index)))

    # Re-indexing
    # NOTE(review): this replaces the sub-task row labels with positions
    # before the label-based reindex(df.index) further below; if rows were
    # dropped above, the two label sets may no longer line up -- confirm.
    for i in range(1, 3):
        name = 'df_sub' + str(i)
        if getattr(self, name) is not None:
            setattr(self, name, getattr(self, name).reset_index())

    # Sort tsv records
    if not is_test:
        if discourse_aware:
            self.utt_offset = 0
            self.n_utt_session_dict = {}
            self.session_offset_dict = {}
            for session_id, ids in sorted(df.groupby('session').groups.items(),
                                          key=lambda x: len(x[1])):
                n_utt = len(ids)
                # key: n_utt, value: session_id
                self.n_utt_session_dict.setdefault(n_utt, []).append(session_id)
                # key: session_id, value: id for the first utterance in each session
                self.session_offset_dict[session_id] = ids[0]

            self.n_utt_session_dict_epoch = copy.deepcopy(self.n_utt_session_dict)
        elif sort_by == 'input':
            df = df.sort_values(by=['xlen'], ascending=short2long)
        elif sort_by == 'output':
            df = df.sort_values(by=['ylen'], ascending=short2long)
        elif sort_by == 'shuffle':
            # self.df does not exist yet at this point; permute the local frame
            df = df.reindex(np.random.permutation(df.index))
        elif sort_by == 'no_sort':
            pass

    for i in range(1, 3):
        name = 'df_sub' + str(i)
        if getattr(self, name) is not None:
            setattr(self, name,
                    getattr(self, name).reindex(df.index).reset_index())

    # Re-indexing
    self.df = df.reset_index()
    self.df_indices = list(self.df.index)
def __init__(self, tsv_path, dict_path,
             unit, batch_size, nlsyms=False, n_epochs=None,
             is_test=False, min_n_tokens=1, bptt=2,
             shuffle=False, backward=False, serialize=False,
             wp_model=None, corpus=''):
    """Load a text corpus and lay it out as one token stream per batch row.

    Args:
        tsv_path (str): path to the dataset tsv file
        dict_path (str): path to the dictionary
        unit (str): word or wp or char or phone or word_char
        batch_size (int): size of mini-batch
        nlsyms (str): path to the non-linguistic symbols file
        n_epochs (int): max epoch. None means infinite loop
        is_test (bool): if True, only empty utterances are removed
        min_n_tokens (int): exclude utterances shorter than this value
        bptt (int): BPTT length (must be at least 2)
        shuffle (bool): shuffle utterances; takes precedence over serialize
        backward (bool): concatenate the utterances in reverse order
        serialize (bool): serialize text according to contexts in dialogue
            (swbd only)
        wp_model (): path to the word-piece model for sentencepiece
        corpus (str): name of corpus

    """
    super(Dataset, self).__init__()

    self.set = os.path.basename(tsv_path).split('.')[0]
    self.is_test = is_test
    self.unit = unit
    self.batch_size = batch_size
    self.bptt = bptt
    # NOTE: <sos> and <eos> share the same index
    self.sos = 2
    self.eos = 2
    self.max_epoch = n_epochs
    self.shuffle = shuffle
    self.vocab = self.count_vocab_size(dict_path)
    assert bptt >= 2

    self.idx2token = []
    self.token2idx = []

    # Pick the index <-> token converters matching the output unit
    if unit == 'word' or unit == 'word_char':
        self.idx2token.append(Idx2word(dict_path))
        self.token2idx.append(
            Word2idx(dict_path, word_char_mix=(unit == 'word_char')))
    elif unit == 'wp':
        self.idx2token.append(Idx2wp(dict_path, wp_model))
        self.token2idx.append(Wp2idx(dict_path, wp_model))
    elif unit == 'char':
        self.idx2token.append(Idx2char(dict_path))
        self.token2idx.append(Char2idx(dict_path, nlsyms=nlsyms))
    elif 'phone' in unit:
        self.idx2token.append(Idx2phone(dict_path))
        self.token2idx.append(Phone2idx(dict_path))
    else:
        raise ValueError(unit)

    # Read the tsv file and keep only the known columns
    wanted_columns = ['utt_id', 'speaker', 'feat_path',
                      'xlen', 'xdim', 'text', 'token_id', 'ylen', 'ydim']
    self.df = pd.read_csv(tsv_path, encoding='utf-8', delimiter='\t')
    self.df = self.df.loc[:, wanted_columns]

    # Drop utterances that are empty (test) or too short (training)
    print('Original utterance num: %d' % len(self.df))
    n_before = len(self.df)
    if is_test:
        keep = self.df.apply(lambda row: row['ylen'] > 0, axis=1)
        self.df = self.df[keep]
        print('Removed %d empty utterances' % (n_before - len(self.df)))
    else:
        keep = self.df.apply(lambda row: row['ylen'] >= min_n_tokens, axis=1)
        self.df = self.df[keep]
        print('Removed %d utterances (threshold)' % (n_before - len(self.df)))

    # Decide the utterance order
    if shuffle:
        permuted = np.random.permutation(self.df.index)
        self.df = self.df.reindex(permuted)
    elif serialize:
        assert corpus == 'swbd'
        self.df['session'] = self.df['speaker'].apply(
            lambda spk: str(spk).split('-')[0])
        self.df['onset'] = self.df['utt_id'].apply(
            lambda uid: int(uid.split('_')[-1].split('-')[0]))
        self.df = self.df.sort_values(by=['session', 'onset'], ascending=True)
    else:
        self.df = self.df.sort_values(by='utt_id', ascending=True)

    # Join all utterances into one stream; every utterance is preceded
    # by <eos> and the stream is closed with a final <eos>
    order = list(self.df.index)
    if backward:
        order.reverse()
    token_col = self.df['token_id']
    stream = []
    for row_label in order:
        assert token_col[row_label] != ''
        stream.append(self.eos)
        stream.extend(int(tok) for tok in token_col[row_label].split())
    stream.append(self.eos)

    # Cut the tail so that the stream divides evenly into batch_size rows
    n_tokens = len(stream)
    n_kept = n_tokens // batch_size * batch_size
    stream = stream[:n_kept]
    print('Removed %d tokens / %d tokens' % (n_tokens - len(stream), n_tokens))
    self.concat_ids = np.array(stream).reshape((batch_size, -1))
def __init__(self, csv_path, dict_path,
             label_type, batch_size, max_epoch=None,
             is_test=False, max_num_frames=2000, min_num_frames=40,
             shuffle=False, sort_by_input_length=False,
             short2long=False, sort_stop_epoch=None,
             num_enques=None, dynamic_batching=False,
             use_ctc=False, subsample_factor=1,
             skip_speech=False,
             csv_path_sub=None, dict_path_sub=None, label_type_sub=None,
             use_ctc_sub=False, subsample_factor_sub=1):
    """A class for loading dataset.

    Args:
        csv_path (str): path to the dataset csv file
        dict_path (str): path to the dictionary
        label_type (str): word or wordpiece or char or
            char_capital_divide or phone
        batch_size (int): the size of mini-batch
        max_epoch (int): the max epoch. None means infinite loop.
        is_test (bool): if True, no utterance is removed
        max_num_frames (int): Exclude utterances longer than this value
        min_num_frames (int): Exclude utterances shorter than this value
        shuffle (bool): if True, shuffle utterances. This is disabled when
            sort_by_input_length is True.
        sort_by_input_length (bool): if True, sort all utterances by
            input length
        short2long (bool): passed as `ascending` when sorting by input
            length, i.e. True sorts from short to long
        sort_stop_epoch (int): After sort_stop_epoch, training will revert
            back to a random order
        num_enques (int): the number of elements to enqueue
        dynamic_batching (bool): if True, batch size will be changed
            dynamically in training
        use_ctc (bool): in training, drop utterances whose y_len exceeds
            x_len // subsample_factor (only when subsample_factor > 1)
        subsample_factor (int): divisor applied to x_len in the CTC check
        skip_speech (bool): skip loading speech features
        csv_path_sub (str): path to the csv file for the sub task
        dict_path_sub (str): path to the dictionary for the sub task
        label_type_sub (str): wordpiece or char or char_capital_divide
            or phone
        use_ctc_sub (bool): same as use_ctc, for the sub task
        subsample_factor_sub (int): same as subsample_factor, for the
            sub task

    Raises:
        ValueError: if label_type or label_type_sub is not supported

    """
    super(Dataset, self).__init__()

    self.set = os.path.basename(csv_path).split('.')[0]
    self.is_test = is_test
    self.label_type = label_type
    self.label_type_sub = label_type_sub
    self.batch_size = batch_size
    self.max_epoch = max_epoch
    self.shuffle = shuffle
    self.sort_by_input_length = sort_by_input_length
    self.sort_stop_epoch = sort_stop_epoch
    self.num_enques = num_enques
    self.dynamic_batching = dynamic_batching
    self.skip_speech = skip_speech
    self.num_classes = self.count_vocab_size(dict_path)

    # Set index converter
    if label_type in ['word', 'wordpiece']:
        self.idx2word = Idx2word(dict_path)
        self.word2idx = Word2idx(dict_path)
    elif label_type == 'char':
        self.idx2char = Idx2char(dict_path)
        self.char2idx = Char2idx(dict_path)
    elif label_type == 'char_capital_divide':
        self.idx2char = Idx2char(dict_path, capital_divide=True)
        self.char2idx = Char2idx(dict_path, capital_divide=True)
    elif 'phone' in label_type:
        self.idx2phone = Idx2phone(dict_path)
        self.phone2idx = Phone2idx(dict_path)
    else:
        raise ValueError(label_type)

    if dict_path_sub is not None:
        self.num_classes_sub = self.count_vocab_size(dict_path_sub)

        # Set index converter
        # NOTE: sub-task converters are stored under the same attribute
        # names as the main-task ones
        if label_type_sub is not None:
            # The sub-task label type (not the main one) selects the converter
            if label_type_sub == 'wordpiece':
                self.idx2word = Idx2word(dict_path_sub)
                self.word2idx = Word2idx(dict_path_sub)
            elif label_type_sub == 'char':
                self.idx2char = Idx2char(dict_path_sub)
                self.char2idx = Char2idx(dict_path_sub)
            elif label_type_sub == 'char_capital_divide':
                self.idx2char = Idx2char(dict_path_sub, capital_divide=True)
                self.char2idx = Char2idx(dict_path_sub, capital_divide=True)
            elif 'phone' in label_type_sub:
                self.idx2phone = Idx2phone(dict_path_sub)
                self.phone2idx = Phone2idx(dict_path_sub)
            else:
                raise ValueError(label_type_sub)
    else:
        self.num_classes_sub = -1

    # Load dataset csv file
    columns = ['utt_id', 'feat_path', 'x_len', 'x_dim',
               'text', 'token_id', 'y_len', 'y_dim']
    df = pd.read_csv(csv_path, encoding='utf-8')
    df = df.loc[:, columns]
    if csv_path_sub is not None:
        df_sub = pd.read_csv(csv_path_sub, encoding='utf-8')
        df_sub = df_sub.loc[:, columns]
    else:
        df_sub = None

    # Remove inappropriate utterances
    if not self.is_test:
        logger.info('Original utterance num: %d', len(df))
        num_utt_org = len(df)

        # Remove by threshold
        df = df[df['x_len'].between(min_num_frames, max_num_frames)]
        logger.info('Removed %d utterances (threshold)', num_utt_org - len(df))

        # Remove for CTC loss calculation
        if use_ctc and subsample_factor > 1:
            logger.info('Checking utterances for CTC...')
            logger.info('Original utterance num: %d', len(df))
            num_utt_org = len(df)
            df = df[df['y_len'] <= df['x_len'] // subsample_factor]
            logger.info('Removed %d utterances (for CTC)', num_utt_org - len(df))

        if df_sub is not None:
            logger.info('Original utterance num (sub): %d', len(df_sub))
            num_utt_org = len(df_sub)

            # Remove by threshold
            df_sub = df_sub[df_sub['x_len'].between(min_num_frames, max_num_frames)]
            logger.info('Removed %d utterances (threshold, sub)',
                        num_utt_org - len(df_sub))

            # Remove for CTC loss calculation
            if use_ctc_sub and subsample_factor_sub > 1:
                logger.info('Checking utterances for CTC...')
                logger.info('Original utterance num (sub): %d', len(df_sub))
                num_utt_org = len(df_sub)
                df_sub = df_sub[df_sub['y_len'] <= df_sub['x_len'] // subsample_factor_sub]
                logger.info('Removed %d utterances (for CTC, sub)',
                            num_utt_org - len(df_sub))

            # Keep only the utterances present in both tasks
            if len(df) != len(df_sub):
                df = df.drop(df.index.difference(df_sub.index))
                df_sub = df_sub.drop(df_sub.index.difference(df.index))

    # Sort csv records
    if sort_by_input_length:
        df = df.sort_values(by='x_len', ascending=short2long)
    elif shuffle:
        df = df.reindex(np.random.permutation(df.index))
    else:
        df = df.sort_values(by='utt_id', ascending=True)

    self.df = df
    self.df_sub = df_sub
    self.rest = set(df.index)
    # Positional access: the row labelled 0 may have been removed above
    self.input_dim = kaldi_io.read_mat(self.df['feat_path'].iloc[0]).shape[-1]
def __init__(self, csv_path, dict_path, label_type, batch_size, bptt, eos,
             nepochs=None, is_test=False, shuffle=False, wp_model=None):
    """Load a text dataset and fold it into ``batch_size`` token streams.

    All utterances are joined into one long token sequence (``eos`` at the
    very beginning and after every utterance), truncated to a multiple of
    ``batch_size`` and reshaped to ``(batch_size, -1)``.

    Args:
        csv_path (str): path to the dataset csv file
        dict_path (str): path to the dictionary
        label_type (str): word or wp or char
        batch_size (int): the size of mini-batch
        bptt (int): stored as ``self.bptt`` (not used in the constructor)
        eos (int): index of the end-of-sentence token
        nepochs (int): the max epoch (stored as ``self.max_epoch``).
            None means infinite loop.
        is_test (bool): stored as ``self.is_test``
        shuffle (bool): if True, shuffle utterances. Otherwise they are
            sorted by ``utt_id``.
        wp_model (): passed to the word-piece converters when
            ``label_type`` is 'wp'

    Raises:
        ValueError: if ``label_type`` is none of word / wp / char

    """
    super(Dataset, self).__init__()

    # Plain configuration
    self.set = os.path.basename(csv_path).split('.')[0]
    self.is_test = is_test
    self.label_type = label_type
    self.batch_size = batch_size
    self.bptt = bptt
    self.eos = eos
    self.max_epoch = nepochs
    self.shuffle = shuffle
    self.vocab = self.count_vocab_size(dict_path)

    # Index converters (reject unknown units up front)
    if label_type not in ('word', 'wp', 'char'):
        raise ValueError(label_type)
    if label_type == 'word':
        self.idx2word = Idx2word(dict_path)
        self.word2idx = Word2idx(dict_path)
    elif label_type == 'wp':
        self.idx2wp = Idx2wp(dict_path, wp_model)
        self.wp2idx = Wp2idx(dict_path, wp_model)
    else:
        self.idx2char = Idx2char(dict_path)
        self.char2idx = Char2idx(dict_path)

    # Read the csv and keep only the columns of interest
    columns = ['utt_id', 'feat_path', 'x_len', 'x_dim',
               'text', 'token_id', 'y_len', 'y_dim']
    records = pd.read_csv(csv_path, encoding='utf-8')
    records = records.loc[:, columns]

    # Order the records
    if shuffle:
        shuffled_index = np.random.permutation(records.index)
        self.df = records.reindex(shuffled_index)
    else:
        self.df = records.sort_values(by='utt_id', ascending=True)

    # Join every utterance into one stream: <eos> utt1 <eos> utt2 <eos> ...
    token_stream = [eos]
    token_column = self.df['token_id']
    for row_label in self.df.index:
        token_str = token_column[row_label]
        assert token_str != ''
        token_stream.extend(int(tok) for tok in token_str.split())
        token_stream.append(eos)

    # Drop the tail that does not fill a whole column, then fold
    n_keep = len(token_stream) // batch_size * batch_size
    token_stream = token_stream[:n_keep]
    self.concat_ids = np.array(token_stream).reshape((batch_size, -1))
def __init__(self, tsv_path, dict_path, unit, batch_size, nlsyms=False,
             n_epochs=None, is_test=False, min_n_frames=40, max_n_frames=2000,
             shuffle=False, sort_by_input_length=False, short2long=False,
             sort_stop_epoch=None, n_ques=None, dynamic_batching=False,
             ctc=False, subsample_factor=1, wp_model=False, corpus='',
             tsv_path_sub1=False, dict_path_sub1=False, unit_sub1=False,
             wp_model_sub1=False, ctc_sub1=False, subsample_factor_sub1=1,
             wp_model_sub2=False, tsv_path_sub2=False, dict_path_sub2=False,
             unit_sub2=False, ctc_sub2=False, subsample_factor_sub2=1,
             contextualize=False, skip_thought=False):
    """A class for loading dataset.

    Args:
        tsv_path (str): path to the dataset tsv file
        dict_path (str): path to the dictionary
        unit (str): word or wp or char or phone or word_char
        batch_size (int): size of mini-batch
        nlsyms (str): path to the non-linguistic symbols file
            (passed to the character converters)
        n_epochs (int): max epoch. None means infinite loop.
        is_test (bool): if True, only empty utterances (ylen == 0) are
            removed and no sorting/shuffling is applied at the end
        min_n_frames (int): exclude utterances shorter than this value
        max_n_frames (int): exclude utterances longer than this value
        shuffle (bool): shuffle utterances. Ignored when contextualize or
            sort_by_input_length is set.
        sort_by_input_length (bool): sort all utterances by input length
        short2long (bool): sort direction; True sorts in ascending order
            (it is passed as ``ascending`` to the sort)
        sort_stop_epoch (int): After sort_stop_epoch, training will revert
            back to a random order
        n_ques (int): number of elements to enqueue
        dynamic_batching (bool): change batch size dynamically in training
        ctc (bool): if True (and subsample_factor > 1), training utterances
            with ylen > xlen // subsample_factor are removed
        subsample_factor (int): subsampling factor used for the CTC check
        wp_model (): path to the word-piece model for sentencepiece
        corpus (str): name of corpus
        tsv_path_sub1 / tsv_path_sub2 (str): tsv files of the sub tasks
        dict_path_sub1 / dict_path_sub2 (str): dictionaries of the sub tasks
        unit_sub1 / unit_sub2 (str): wp or char or phone
        wp_model_sub1 / wp_model_sub2 (): word-piece models of the sub tasks
        ctc_sub1 / ctc_sub2 (bool): same as ctc, for the sub tasks
        subsample_factor_sub1 / subsample_factor_sub2 (int): same as
            subsample_factor, for the sub tasks
        contextualize (bool): sort by session/onset and collect the
            previous utterances of every utterance
        skip_thought (bool): sort by session/onset without collecting
            previous utterances

    Raises:
        ValueError: if unit (or a sub-task unit) is not supported
        NotImplementedError: if contextualize/skip_thought is requested for
            a corpus other than swbd / csj / wsj

    """
    super(Dataset, self).__init__()

    self.set = os.path.basename(tsv_path).split('.')[0]
    self.is_test = is_test
    self.unit = unit
    self.unit_sub1 = unit_sub1
    self.batch_size = batch_size
    self.max_epoch = n_epochs
    self.shuffle = shuffle
    self.sort_stop_epoch = sort_stop_epoch
    self.sort_by_input_length = sort_by_input_length
    self.n_ques = n_ques
    self.dynamic_batching = dynamic_batching
    self.corpus = corpus
    self.contextualize = contextualize
    self.skip_thought = skip_thought

    self.vocab = self.count_vocab_size(dict_path)
    self.eos = 2
    self.pad = 3
    # NOTE: reserved in advance

    self.idx2token = []
    self.token2idx = []

    # Set index converter
    if unit in ['word', 'word_char']:
        self.idx2token += [Idx2word(dict_path)]
        self.token2idx += [
            Word2idx(dict_path, word_char_mix=(unit == 'word_char'))]
    elif unit == 'wp':
        self.idx2token += [Idx2wp(dict_path, wp_model)]
        self.token2idx += [Wp2idx(dict_path, wp_model)]
    elif unit == 'char':
        self.idx2token += [Idx2char(dict_path)]
        self.token2idx += [Char2idx(dict_path, nlsyms=nlsyms)]
    elif 'phone' in unit:
        self.idx2token += [Idx2phone(dict_path)]
        self.token2idx += [Phone2idx(dict_path)]
    else:
        raise ValueError(unit)

    # Sub tasks: explicit tuples instead of looking names up in locals()
    sub_converter_configs = (
        (1, dict_path_sub1, wp_model_sub1, unit_sub1),
        (2, dict_path_sub2, wp_model_sub2, unit_sub2),
    )
    for i, dict_path_sub, wp_model_sub, unit_sub in sub_converter_configs:
        if dict_path_sub:
            setattr(self, 'vocab_sub' + str(i),
                    self.count_vocab_size(dict_path_sub))

            # Set index converter
            if unit_sub:
                if unit_sub == 'wp':
                    self.idx2token += [Idx2wp(dict_path_sub, wp_model_sub)]
                    self.token2idx += [Wp2idx(dict_path_sub, wp_model_sub)]
                elif unit_sub == 'char':
                    self.idx2token += [Idx2char(dict_path_sub)]
                    self.token2idx += [
                        Char2idx(dict_path_sub, nlsyms=nlsyms)]
                elif 'phone' in unit_sub:
                    self.idx2token += [Idx2phone(dict_path_sub)]
                    self.token2idx += [Phone2idx(dict_path_sub)]
                else:
                    raise ValueError(unit_sub)
        else:
            setattr(self, 'vocab_sub' + str(i), -1)

    # Load dataset tsv file
    columns = ['utt_id', 'speaker', 'feat_path', 'xlen', 'xdim',
               'text', 'token_id', 'ylen', 'ydim']
    self.df = pd.read_csv(tsv_path, encoding='utf-8', delimiter='\t')
    self.df = self.df.loc[:, columns]
    for i, tsv_path_sub in ((1, tsv_path_sub1), (2, tsv_path_sub2)):
        if tsv_path_sub:
            df_sub = pd.read_csv(tsv_path_sub, encoding='utf-8',
                                 delimiter='\t')
            df_sub = df_sub.loc[:, columns]
            setattr(self, 'df_sub' + str(i), df_sub)
        else:
            setattr(self, 'df_sub' + str(i), None)

    # Input dimension is read from the feature matrix at row label 0
    self.input_dim = kaldiio.load_mat(self.df['feat_path'][0]).shape[-1]

    if corpus == 'swbd':
        # The session is the part of the speaker field before the first '-'
        self.df['session'] = self.df['speaker'].apply(
            lambda x: str(x).split('-')[0])
    else:
        self.df['session'] = self.df['speaker'].apply(lambda x: str(x))

    if contextualize or skip_thought:
        # The caller's length thresholds are overridden in these modes
        max_n_frames = 10000
        min_n_frames = 100

        # Sort by onset
        self.df = self.df.assign(prev_utt='')
        if corpus == 'swbd':
            self.df['onset'] = self.df['utt_id'].apply(
                lambda x: int(x.split('_')[-1].split('-')[0]))
        elif corpus == 'csj':
            self.df['onset'] = self.df['utt_id'].apply(
                lambda x: int(x.split('_')[1]))
        elif corpus == 'wsj':
            self.df['onset'] = self.df['utt_id'].apply(lambda x: x)
        else:
            raise NotImplementedError
        self.df = self.df.sort_values(by=['session', 'onset'],
                                      ascending=True)

        # Extract previous utterances
        if not skip_thought and not is_test:
            self.df = self.df.assign(line_no=list(range(len(self.df))))
            groups = self.df.groupby('session').groups  # dict
            # For every utterance: line numbers of all utterances of the
            # same session that start earlier
            self.df['prev_utt'] = self.df.apply(
                lambda x: [self.df.loc[i, 'line_no']
                           for i in groups[x['session']]
                           if self.df.loc[i, 'onset'] < x['onset']],
                axis=1)
            self.df['n_prev_utt'] = self.df['prev_utt'].apply(len)

    elif is_test and corpus == 'swbd':
        # Sort by onset
        self.df['onset'] = self.df['utt_id'].apply(
            lambda x: int(x.split('_')[-1].split('-')[0]))
        self.df = self.df.sort_values(by=['session', 'onset'],
                                      ascending=True)

    # Remove inappropriate utterances
    if is_test:
        print('Original utterance num: %d' % len(self.df))
        n_utts = len(self.df)
        self.df = self.df[self.df['ylen'] > 0]
        print('Removed %d empty utterances' % (n_utts - len(self.df)))
    else:
        print('Original utterance num: %d' % len(self.df))
        n_utts = len(self.df)
        self.df = self.df[(self.df['xlen'] >= min_n_frames)
                          & (self.df['xlen'] <= max_n_frames)]
        self.df = self.df[self.df['ylen'] > 0]
        print('Removed %d utterances (threshold)' % (n_utts - len(self.df)))

        if ctc and subsample_factor > 1:
            # CTC needs at least as many (subsampled) frames as labels
            n_utts = len(self.df)
            self.df = self.df[
                self.df['ylen'] <= (self.df['xlen'] // subsample_factor)]
            print('Removed %d utterances (for CTC)' %
                  (n_utts - len(self.df)))

        sub_filter_configs = (
            (1, ctc_sub1, subsample_factor_sub1),
            (2, ctc_sub2, subsample_factor_sub2),
        )
        for i, ctc_sub, subsample_factor_sub in sub_filter_configs:
            df_sub = getattr(self, 'df_sub' + str(i))
            if df_sub is None:
                continue

            if ctc_sub and subsample_factor_sub > 1:
                df_sub = df_sub[
                    df_sub['ylen'] <= (df_sub['xlen'] // subsample_factor_sub)]

            # Out of sync when the sizes differ, or when the sizes match
            # but some main-task row is missing from the sub task
            out_of_sync = (len(self.df) != len(df_sub)
                           or not self.df.index.isin(df_sub.index).all())
            if out_of_sync:
                n_utts = len(self.df)
                self.df = self.df.drop(
                    self.df.index.difference(df_sub.index))
                print('Removed %d utterances (for CTC, sub%d)' %
                      (n_utts - len(self.df), i))
                # Re-align every sub task seen so far with the main task
                for j in range(1, i + 1):
                    name_j = 'df_sub' + str(j)
                    df_sub_j = getattr(self, name_j)
                    if df_sub_j is None:
                        # e.g. only sub2 is configured
                        continue
                    setattr(self, name_j, df_sub_j.drop(
                        df_sub_j.index.difference(self.df.index)))

    # Sort tsv records
    if not is_test:
        if contextualize:
            # NOTE(review): 'n_prev_utt' is only created when skip_thought
            # is False, so contextualize=True with skip_thought=True would
            # fail here with a KeyError - confirm the flags are exclusive.
            self.df = self.df.sort_values(by='n_prev_utt',
                                          ascending=short2long)
        elif sort_by_input_length:
            self.df = self.df.sort_values(by='xlen', ascending=short2long)
        elif shuffle:
            self.df = self.df.reindex(np.random.permutation(self.df.index))

    self.rest = set(list(self.df.index))
def __init__(self, tsv_path, dict_path, unit, batch_size, n_epochs=None,
             is_test=False, bptt=2, wp_model=None, corpus='', shuffle=False,
             serialize=False):
    """A class for loading dataset.

    All non-empty utterances are joined into one token sequence
    (``<eos> utt1 <eos> utt2 ... <eos>``), truncated to a multiple of
    ``batch_size`` and reshaped to ``(batch_size, -1)``.

    Args:
        tsv_path (str): path to the dataset tsv file
        dict_path (str): path to the dictionary
        unit (str): word or wp or char or phone or word_char
        batch_size (int): size of mini-batch
        n_epochs (int): max epoch. None means infinite loop.
        is_test (bool): stored as ``self.is_test``
        bptt (int): BPTT length (must be >= 2)
        wp_model (): path to the word-piece model for sentencepiece
        corpus (str): name of corpus
        shuffle (bool): shuffle utterances. Takes precedence over serialize.
        serialize (bool): serialize text according to contexts in dialogue
            (sort by session and onset; only supported for corpus 'swbd')

    Raises:
        ValueError: if unit is not supported
        AssertionError: if bptt < 2, if serialize is used with a corpus
            other than 'swbd', or if an utterance has an empty token_id

    """
    super(Dataset, self).__init__()

    self.set = os.path.basename(tsv_path).split('.')[0]
    self.is_test = is_test
    self.unit = unit
    self.batch_size = batch_size
    self.bptt = bptt
    self.sos = 2
    self.eos = 2
    self.max_epoch = n_epochs
    self.shuffle = shuffle
    self.vocab = self.count_vocab_size(dict_path)
    assert bptt >= 2

    # Set index converter
    if unit in ['word', 'word_char']:
        self.idx2word = Idx2word(dict_path)
        self.word2idx = Word2idx(dict_path,
                                 word_char_mix=(unit == 'word_char'))
    elif unit == 'wp':
        self.idx2wp = Idx2wp(dict_path, wp_model)
        self.wp2idx = Wp2idx(dict_path, wp_model)
    elif unit == 'char':
        self.idx2char = Idx2char(dict_path)
        self.char2idx = Char2idx(dict_path)
    elif 'phone' in unit:
        self.idx2phone = Idx2phone(dict_path)
        self.phone2idx = Phone2idx(dict_path)
    else:
        raise ValueError(unit)

    # Load dataset tsv file
    self.df = pd.read_csv(tsv_path, encoding='utf-8', delimiter='\t')
    self.df = self.df.loc[:, ['utt_id', 'speaker', 'feat_path',
                              'xlen', 'xdim', 'text', 'token_id',
                              'ylen', 'ydim']]
    # Drop empty utterances (column-wise mask instead of a row-wise apply)
    self.df = self.df[self.df['ylen'] > 0]

    # Sort tsv records
    if shuffle:
        self.df = self.df.reindex(np.random.permutation(self.df.index))
    elif serialize:
        assert corpus == 'swbd'
        # assign() builds a new frame, so no columns are written into the
        # filtered slice created above
        self.df = self.df.assign(
            session=self.df['speaker'].apply(
                lambda x: str(x).split('-')[0]),
            onset=self.df['utt_id'].apply(
                lambda x: int(x.split('_')[-1].split('-')[0])))
        self.df = self.df.sort_values(by=['session', 'onset'],
                                      ascending=True)
    else:
        self.df = self.df.sort_values(by='utt_id', ascending=True)

    # Concatenate into a single sentence
    concat_ids = []
    for token_str in self.df['token_id']:
        assert token_str != ''
        concat_ids.append(self.eos)
        concat_ids.extend(map(int, token_str.split()))
    concat_ids.append(self.eos)
    # NOTE: <sos> and <eos> have the same index

    # Reshape (the tail that does not fill a whole column is dropped)
    n_tokens = len(concat_ids)
    concat_ids = concat_ids[:n_tokens // batch_size * batch_size]
    print('Removed %d tokens / %d tokens' %
          (n_tokens - len(concat_ids), n_tokens))
    self.concat_ids = np.array(concat_ids).reshape((batch_size, -1))
def __init__(self, csv_path, dict_path, label_type, batch_size, bptt, eos,
             max_epoch=None, is_test=False, shuffle=False):
    """Load a text dataset and join it into one flat token sequence.

    Args:
        csv_path (str): path to the dataset csv file
        dict_path (str): path to the dictionary
        label_type (str): word or wordpiece or char or char_capital_divide
        batch_size (int): the size of mini-batch
        bptt (int): stored as ``self.bptt`` (not used in the constructor)
        eos (int): index appended after every utterance
        max_epoch (int): the max epoch. None means infinite loop.
        is_test (bool): stored as ``self.is_test``
        shuffle (bool): if True, shuffle utterances. Otherwise they are
            sorted by ``utt_id``.

    Raises:
        ValueError: if ``label_type`` is not supported

    """
    super(Dataset, self).__init__()

    # Plain configuration
    self.set = os.path.basename(csv_path).split('.')[0]
    self.is_test = is_test
    self.label_type = label_type
    self.batch_size = batch_size
    self.bptt = bptt
    self.max_epoch = max_epoch
    self.num_classes = self.count_vocab_size(dict_path)

    # Index converters (reject unknown units up front)
    word_units = ('word', 'wordpiece')
    char_units = ('char', 'char_capital_divide')
    if label_type not in word_units + char_units:
        raise ValueError(label_type)
    if label_type in word_units:
        self.idx2word = Idx2word(dict_path)
        self.word2idx = Word2idx(dict_path)
    elif label_type == 'char':
        self.idx2char = Idx2char(dict_path)
        self.char2idx = Char2idx(dict_path)
    else:
        self.idx2char = Idx2char(dict_path, capital_divide=True)
        self.char2idx = Char2idx(dict_path, capital_divide=True)

    # Read the csv and keep only the columns of interest
    columns = ['utt_id', 'feat_path', 'x_len', 'x_dim',
               'text', 'token_id', 'y_len', 'y_dim']
    records = pd.read_csv(csv_path, encoding='utf-8')
    records = records.loc[:, columns]

    # Order the records
    if shuffle:
        shuffled_index = np.random.permutation(records.index)
        records = records.reindex(shuffled_index)
    else:
        records = records.sort_values(by='utt_id', ascending=True)

    # Join every utterance into one stream: utt1 <eos> utt2 <eos> ...
    self.concat_ids = []
    stream = self.concat_ids
    token_column = records['token_id']
    for row_label in records.index:
        token_str = token_column[row_label]
        assert token_str != ''
        stream.extend(int(tok) for tok in token_str.split())
        stream.append(eos)