Exemplo n.º 1
0
    def __init__(self, tsv_path, dict_path,
                 unit, batch_size, n_epochs=1e10,
                 is_test=False, min_n_frames=40, max_n_frames=2000,
                 shuffle_bucket=False, sort_by='utt_id',
                 short2long=False, sort_stop_epoch=1000, dynamic_batching=False,
                 corpus='',
                 tsv_path_sub1=False, tsv_path_sub2=False,
                 dict_path_sub1=False, dict_path_sub2=False, nlsyms=False,
                 unit_sub1=False, unit_sub2=False,
                 wp_model=False, wp_model_sub1=False, wp_model_sub2=False,
                 ctc=False, ctc_sub1=False, ctc_sub2=False,
                 subsample_factor=1, subsample_factor_sub1=1, subsample_factor_sub2=1,
                 discourse_aware=False, first_n_utterances=-1):
        """Load the utterance table(s) from tsv and prepare them for batching.

        Builds the index<->token converters for the main task and up to two
        sub tasks, reads the tsv file(s), removes unusable utterances, orders
        the remaining rows and finally stores either the row indices
        (``self.df_indices``) or buckets of them (``self.df_indices_buckets``).

        Args:
            tsv_path (str): path to the dataset tsv file
            dict_path (str): path to the dictionary
            unit (str): word/wp/char/phone/word_char
            batch_size (int): size of mini-batch
            n_epochs (int): total epochs for training (stored only)
            is_test (bool): if True, only empty utterances are removed and the
                rows are not sorted
            min_n_frames (int): exclude utterances shorter than this value
                (applied only when neither is_test nor discourse_aware)
            max_n_frames (int): exclude utterances longer than this value
                (applied only when neither is_test nor discourse_aware)
            shuffle_bucket (bool): gather the similar length of utterances and
                shuffle them. Requires sort_by to be 'input' or 'output'.
            sort_by (str): ordering applied when is_test is False
                input: sort by input length
                output: sort by output length
                shuffle: shuffle all utterances
                utt_id: keep the row order of the tsv file
            short2long (bool): sort in ascending order (short to long) if
                True, in descending order otherwise
            sort_stop_epoch (int): stored only; presumably the epoch after
                which training reverts to a random order (applied elsewhere)
            dynamic_batching (bool): change batch size dynamically in training
                (stored only)
            corpus (str): name of corpus
            tsv_path_sub1/tsv_path_sub2 (str): tsv files of the sub tasks
            dict_path_sub1/dict_path_sub2 (str): dictionaries of the sub tasks
            nlsyms (str): path to the non-linguistic symbols file
            unit_sub1/unit_sub2 (str): wp/char/phone units of the sub tasks
            wp_model (): path to the word-piece model for sentencepiece
            wp_model_sub1/wp_model_sub2 (): word-piece models of the sub tasks
            ctc (bool): if True (and subsample_factor > 1), drop utterances
                whose output is longer than the subsampled input
            ctc_sub1/ctc_sub2 (bool): same as ctc, for the sub tasks
            subsample_factor (int): divisor applied to xlen for the CTC check
            subsample_factor_sub1/subsample_factor_sub2 (int): same as
                subsample_factor, for the sub tasks
            discourse_aware (bool): order utterances by session and onset and
                bucket them per session. Must not be combined with is_test.
            first_n_utterances (int): evaluate the first N utterances
                (used only when is_test or discourse_aware)

        Raises:
            ValueError: if unit (or a sub-task unit) is not supported
            NotImplementedError: if discourse_aware is set for a corpus whose
                utt_id format is not handled
            AssertionError: if the option combination is inconsistent

        """
        super(Dataset, self).__init__()

        self.epoch = 0
        self.iteration = 0
        self.offset = 0

        self.set = os.path.basename(tsv_path).split('.')[0]
        self.is_test = is_test
        self.unit = unit
        self.unit_sub1 = unit_sub1
        self.batch_size = batch_size
        self.n_epochs = n_epochs
        self.shuffle_bucket = shuffle_bucket
        if shuffle_bucket:
            assert sort_by in ['input', 'output']
        self.sort_stop_epoch = sort_stop_epoch
        self.sort_by = sort_by
        assert sort_by in ['input', 'output', 'shuffle', 'utt_id']
        self.dynamic_batching = dynamic_batching
        self.corpus = corpus
        self.discourse_aware = discourse_aware
        if discourse_aware:
            assert not is_test

        self.vocab = count_vocab_size(dict_path)
        self.eos = 2
        self.pad = 3
        # NOTE: reserved in advance

        self.idx2token = []
        self.token2idx = []

        # Set index converter
        if unit in ['word', 'word_char']:
            self.idx2token += [Idx2word(dict_path)]
            self.token2idx += [Word2idx(dict_path, word_char_mix=(unit == 'word_char'))]
        elif unit == 'wp':
            self.idx2token += [Idx2wp(dict_path, wp_model)]
            self.token2idx += [Wp2idx(dict_path, wp_model)]
        elif unit in ['char']:
            self.idx2token += [Idx2char(dict_path)]
            self.token2idx += [Char2idx(dict_path, nlsyms=nlsyms)]
        elif 'phone' in unit:
            self.idx2token += [Idx2phone(dict_path)]
            self.token2idx += [Phone2idx(dict_path)]
        else:
            raise ValueError(unit)

        # Sub-task arguments keyed by sub-task number, so that the loops
        # below do not depend on locals() lookups by constructed name.
        tsv_path_subs = {1: tsv_path_sub1, 2: tsv_path_sub2}
        dict_path_subs = {1: dict_path_sub1, 2: dict_path_sub2}
        unit_subs = {1: unit_sub1, 2: unit_sub2}
        wp_model_subs = {1: wp_model_sub1, 2: wp_model_sub2}
        ctc_subs = {1: ctc_sub1, 2: ctc_sub2}
        subsample_factor_subs = {1: subsample_factor_sub1, 2: subsample_factor_sub2}

        for i in range(1, 3):
            dict_path_sub = dict_path_subs[i]
            wp_model_sub = wp_model_subs[i]
            unit_sub = unit_subs[i]
            if dict_path_sub:
                setattr(self, 'vocab_sub' + str(i), count_vocab_size(dict_path_sub))

                # Set index converter
                if unit_sub:
                    if unit_sub == 'wp':
                        self.idx2token += [Idx2wp(dict_path_sub, wp_model_sub)]
                        self.token2idx += [Wp2idx(dict_path_sub, wp_model_sub)]
                    elif unit_sub == 'char':
                        self.idx2token += [Idx2char(dict_path_sub)]
                        self.token2idx += [Char2idx(dict_path_sub, nlsyms=nlsyms)]
                    elif 'phone' in unit_sub:
                        self.idx2token += [Idx2phone(dict_path_sub)]
                        self.token2idx += [Phone2idx(dict_path_sub)]
                    else:
                        raise ValueError(unit_sub)
            else:
                # -1 marks "no such sub task"
                setattr(self, 'vocab_sub' + str(i), -1)

        # Load dataset tsv file
        columns = ['utt_id', 'speaker', 'feat_path',
                   'xlen', 'xdim', 'text', 'token_id', 'ylen', 'ydim']
        df = pd.read_csv(tsv_path, encoding='utf-8', delimiter='\t')
        df = df.loc[:, columns]
        for i in range(1, 3):
            if tsv_path_subs[i]:
                df_sub = pd.read_csv(tsv_path_subs[i], encoding='utf-8', delimiter='\t')
                df_sub = df_sub.loc[:, columns]
                setattr(self, 'df_sub' + str(i), df_sub)
            else:
                setattr(self, 'df_sub' + str(i), None)
        # The feature dimension is read from the first utterance's features
        self.input_dim = kaldiio.load_mat(df['feat_path'][0]).shape[-1]

        # Remove inappropriate utterances
        print('Original utterance num: %d' % len(df))
        n_utts = len(df)
        if is_test or discourse_aware:
            df = df[df['ylen'] > 0]
            print('Removed %d empty utterances' % (n_utts - len(df)))
            if first_n_utterances > 0:
                # NOTE: truncate() is label based, i.e. it keeps the rows
                # whose original line number is below first_n_utterances.
                df = df.truncate(before=0, after=first_n_utterances - 1)
                print('Select first %d utterances' % len(df))
        else:
            df = df[(df['xlen'] >= min_n_frames) & (df['xlen'] <= max_n_frames)]
            df = df[df['ylen'] > 0]
            print('Removed %d utterances (threshold)' % (n_utts - len(df)))

            if ctc and subsample_factor > 1:
                # CTC needs the output not to be longer than the subsampled input
                n_utts = len(df)
                df = df[df['ylen'] <= (df['xlen'] // subsample_factor)]
                print('Removed %d utterances (for CTC)' % (n_utts - len(df)))

            for i in range(1, 3):
                df_sub = getattr(self, 'df_sub' + str(i))
                if df_sub is None:
                    continue
                if ctc_subs[i] and subsample_factor_subs[i] > 1:
                    df_sub = df_sub[df_sub['ylen'] <= (df_sub['xlen'] // subsample_factor_subs[i])]

                # Reconcile the main table with the (filtered) sub table.
                # Besides a length mismatch, also catch the case where both
                # tables have the same length but different surviving rows.
                if len(df) != len(df_sub) or not df.index.isin(df_sub.index).all():
                    n_utts = len(df)
                    df = df.drop(df.index.difference(df_sub.index))
                    print('Removed %d utterances (for CTC, sub%d)' % (n_utts - len(df), i))
                    # Keep the sub tables seen so far consistent with df
                    for j in range(1, i + 1):
                        df_sub_j = getattr(self, 'df_sub' + str(j))
                        if df_sub_j is None:
                            continue
                        setattr(self, 'df_sub' + str(j),
                                df_sub_j.drop(df_sub_j.index.difference(df.index)))

        # The session is the speaker label for every corpus.
        # NOTE: for swbd, a serialized variant would instead use
        # str(speaker).split('-')[0] as the session.
        df = df.assign(session=df['speaker'].apply(str))

        # Sort tsv records
        if discourse_aware:
            # Sort by onset (start time)
            df = df.assign(prev_utt='')
            df = df.assign(line_no=list(range(len(df))))
            if corpus == 'swbd':
                df['onset'] = df['utt_id'].apply(lambda x: int(x.split('_')[-1].split('-')[0]))
            elif corpus == 'csj':
                df['onset'] = df['utt_id'].apply(lambda x: int(x.split('_')[1]))
            elif corpus == 'tedlium2':
                df['onset'] = df['utt_id'].apply(lambda x: int(x.split('-')[-2]))
            else:
                raise NotImplementedError(corpus)
            df = df.sort_values(by=['session', 'onset'], ascending=True)

            # Extract previous utterances: for every row, the line numbers of
            # the utterances of the same session with an earlier onset
            groups = df.groupby('session').groups
            df['prev_utt'] = df.apply(
                lambda x: [df.loc[i, 'line_no']
                           for i in groups[x['session']] if df.loc[i, 'onset'] < x['onset']], axis=1)
            df['n_prev_utt'] = df.apply(lambda x: len(x['prev_utt']), axis=1)
            df['n_utt_in_session'] = df.apply(
                lambda x: len(groups[x['session']]), axis=1)
            df = df.sort_values(by=['n_utt_in_session'], ascending=short2long)

        elif not is_test:
            if sort_by == 'input':
                df = df.sort_values(by=['xlen'], ascending=short2long)
            elif sort_by == 'output':
                df = df.sort_values(by=['ylen'], ascending=short2long)
            elif sort_by == 'shuffle':
                # Use the local table: self.df is not assigned yet here
                df = df.reindex(np.random.permutation(df.index))

        # Re-indexing
        if discourse_aware:
            # Keep the original labels; sub tables are aligned to them
            self.df = df
            for i in range(1, 3):
                df_sub = getattr(self, 'df_sub' + str(i))
                if df_sub is not None:
                    setattr(self, 'df_sub' + str(i), df_sub.reindex(df.index))
        else:
            self.df = df.reset_index()
            for i in range(1, 3):
                df_sub = getattr(self, 'df_sub' + str(i))
                if df_sub is not None:
                    setattr(self, 'df_sub' + str(i),
                            df_sub.reindex(df.index).reset_index())

        if discourse_aware:
            self.df_indices_buckets = self.discourse_bucketing(batch_size)
        elif shuffle_bucket:
            self.df_indices_buckets = self.shuffle_bucketing(batch_size)
        else:
            self.df_indices = list(self.df.index)
Exemplo n.º 2
0
    def __init__(self,
                 corpus,
                 tsv_path,
                 dict_path,
                 unit,
                 nlsyms,
                 wp_model,
                 is_test,
                 min_n_frames,
                 max_n_frames,
                 sort_by,
                 short2long,
                 tsv_path_sub1,
                 tsv_path_sub2,
                 ctc,
                 ctc_sub1,
                 ctc_sub2,
                 subsample_factor,
                 subsample_factor_sub1,
                 subsample_factor_sub2,
                 dict_path_sub1,
                 dict_path_sub2,
                 unit_sub1,
                 unit_sub2,
                 wp_model_sub1,
                 wp_model_sub2,
                 discourse_aware=False,
                 simulate_longform=False,
                 first_n_utterances=-1,
                 word_alignment_dir=None,
                 ctc_alignment_dir=None):
        """Custom Dataset class.

        Builds the index<->token converters for the main task and up to two
        sub tasks, reads the tsv file(s), removes unusable utterances, orders
        the remaining rows, optionally attaches alignment trigger points and
        stores the resulting tables in ``self.df`` / ``self.df_sub1`` /
        ``self.df_sub2``.

        Args:
            corpus (str): name of corpus
            tsv_path (str): path to the dataset tsv file
            dict_path (str): path to the dictionary
            unit (str): word/wp/char/phone/word_char
            nlsyms (str): path to the non-linguistic symbols file
            wp_model (): path to the word-piece model for sentencepiece
            is_test (bool): if True, only empty utterances are removed and the
                rows are not sorted
            min_n_frames (int): exclude utterances shorter than this value
                (applied only when neither is_test nor discourse_aware)
            max_n_frames (int): exclude utterances longer than this value
                (applied only when neither is_test nor discourse_aware)
            sort_by (str): ordering applied when is_test is False
                input: sort by input length
                output: sort by output length
                shuffle: shuffle all utterances
                any other value keeps the row order of the tsv file
            short2long (bool): sort in ascending order (short to long) if
                True, in descending order otherwise
            tsv_path_sub1/tsv_path_sub2 (str): tsv files of the sub tasks
            ctc (bool): if True (and subsample_factor > 1), drop utterances
                whose output is longer than the subsampled input
            ctc_sub1/ctc_sub2 (bool): same as ctc, for the sub tasks
            subsample_factor (int): divisor applied to xlen for the CTC check
            subsample_factor_sub1/subsample_factor_sub2 (int): same as
                subsample_factor, for the sub tasks
            dict_path_sub1/dict_path_sub2 (str): dictionaries of the sub tasks
            unit_sub1/unit_sub2 (str): wp/char/phone units of the sub tasks
            wp_model_sub1/wp_model_sub2 (): word-piece models of the sub tasks
            discourse_aware (bool): sort in the discourse order. Must not be
                combined with is_test.
            simulate_longform (bool): simulate long-form utterance (stored
                only). Requires is_test.
            first_n_utterances (int): evaluate the first N utterances
                (used only when is_test or discourse_aware)
            word_alignment_dir (str): path to word alignment directory
            ctc_alignment_dir (str): path to CTC alignment directory (used
                only when word_alignment_dir is None)

        Raises:
            ValueError: if unit (or a sub-task unit) is not supported
            NotImplementedError: if discourse_aware is set for a corpus whose
                utt_id format is not handled
            AssertionError: if the option combination is inconsistent

        """
        super(Dataset, self).__init__()

        self.epoch = 0

        # meta data accessed by dataloader
        self._corpus = corpus
        self._set = os.path.basename(tsv_path).split('.')[0]
        self._vocab = count_vocab_size(dict_path)
        self._unit = unit
        self._unit_sub1 = unit_sub1
        self._unit_sub2 = unit_sub2

        self.is_test = is_test
        self.sort_by = sort_by
        if discourse_aware:
            assert not is_test
        if simulate_longform:
            assert is_test
        self.simulate_longform = simulate_longform

        self.subsample_factor = subsample_factor
        self.word_alignment_dir = word_alignment_dir
        self.ctc_alignment_dir = ctc_alignment_dir

        self._idx2token = []
        self._token2idx = []

        # Set index converter
        if unit in ['word', 'word_char']:
            self._idx2token += [Idx2word(dict_path)]
            self._token2idx += [
                Word2idx(dict_path, word_char_mix=(unit == 'word_char'))
            ]
        elif unit == 'wp':
            self._idx2token += [Idx2wp(dict_path, wp_model)]
            self._token2idx += [Wp2idx(dict_path, wp_model)]
        elif unit in ['char']:
            self._idx2token += [Idx2char(dict_path)]
            self._token2idx += [Char2idx(dict_path, nlsyms=nlsyms)]
        elif 'phone' in unit:
            self._idx2token += [Idx2phone(dict_path)]
            self._token2idx += [Phone2idx(dict_path)]
        else:
            raise ValueError(unit)

        # Sub-task arguments keyed by sub-task number, so that the loops
        # below do not depend on locals() lookups by constructed name.
        tsv_path_subs = {1: tsv_path_sub1, 2: tsv_path_sub2}
        dict_path_subs = {1: dict_path_sub1, 2: dict_path_sub2}
        unit_subs = {1: unit_sub1, 2: unit_sub2}
        wp_model_subs = {1: wp_model_sub1, 2: wp_model_sub2}
        ctc_subs = {1: ctc_sub1, 2: ctc_sub2}
        subsample_factor_subs = {
            1: subsample_factor_sub1,
            2: subsample_factor_sub2,
        }

        for i in range(1, 3):
            dict_path_sub = dict_path_subs[i]
            wp_model_sub = wp_model_subs[i]
            unit_sub = unit_subs[i]
            if dict_path_sub:
                setattr(self, '_vocab_sub' + str(i),
                        count_vocab_size(dict_path_sub))

                # Set index converter
                if unit_sub:
                    if unit_sub == 'wp':
                        self._idx2token += [
                            Idx2wp(dict_path_sub, wp_model_sub)
                        ]
                        self._token2idx += [
                            Wp2idx(dict_path_sub, wp_model_sub)
                        ]
                    elif unit_sub == 'char':
                        self._idx2token += [Idx2char(dict_path_sub)]
                        self._token2idx += [
                            Char2idx(dict_path_sub, nlsyms=nlsyms)
                        ]
                    elif 'phone' in unit_sub:
                        self._idx2token += [Idx2phone(dict_path_sub)]
                        self._token2idx += [Phone2idx(dict_path_sub)]
                    else:
                        raise ValueError(unit_sub)
            else:
                # -1 marks "no such sub task"
                setattr(self, '_vocab_sub' + str(i), -1)

        # Load dataset tsv file
        columns = [
            'utt_id', 'speaker', 'feat_path', 'xlen', 'xdim', 'text',
            'token_id', 'ylen', 'ydim'
        ]
        df = pd.read_csv(tsv_path, encoding='utf-8', delimiter='\t')
        df = df.loc[:, columns]
        for i in range(1, 3):
            if tsv_path_subs[i]:
                df_sub = pd.read_csv(tsv_path_subs[i],
                                     encoding='utf-8',
                                     delimiter='\t')
                df_sub = df_sub.loc[:, columns]
                setattr(self, 'df_sub' + str(i), df_sub)
            else:
                setattr(self, 'df_sub' + str(i), None)
        # The feature dimension is read from the first utterance's features
        self._input_dim = kaldiio.load_mat(df['feat_path'][0]).shape[-1]

        # Remove inappropriate utterances
        print('Original utterance num: %d' % len(df))
        n_utts = len(df)
        if is_test or discourse_aware:
            df = df[df['ylen'] > 0]
            print('Removed %d empty utterances' % (n_utts - len(df)))
            if first_n_utterances > 0:
                # NOTE: truncate() is label based, i.e. it keeps the rows
                # whose original line number is below first_n_utterances.
                df = df.truncate(before=0, after=first_n_utterances - 1)
                print('Select first %d utterances' % len(df))
        else:
            df = df[(df['xlen'] >= min_n_frames)
                    & (df['xlen'] <= max_n_frames)]
            df = df[df['ylen'] > 0]
            print('Removed %d utterances (threshold)' % (n_utts - len(df)))

            if ctc and subsample_factor > 1:
                # CTC needs the output not to be longer than the subsampled
                # input
                n_utts = len(df)
                df = df[df['ylen'] <= (df['xlen'] // subsample_factor)]
                print('Removed %d utterances (for CTC)' % (n_utts - len(df)))

            for i in range(1, 3):
                df_sub = getattr(self, 'df_sub' + str(i))
                if df_sub is None:
                    continue
                if ctc_subs[i] and subsample_factor_subs[i] > 1:
                    df_sub = df_sub[df_sub['ylen'] <= (
                        df_sub['xlen'] // subsample_factor_subs[i])]

                # Reconcile the main table with the (filtered) sub table.
                # Besides a length mismatch, also catch the case where both
                # tables have the same length but different surviving rows.
                if (len(df) != len(df_sub)
                        or not df.index.isin(df_sub.index).all()):
                    n_utts = len(df)
                    df = df.drop(df.index.difference(df_sub.index))
                    print('Removed %d utterances (for CTC, sub%d)' %
                          (n_utts - len(df), i))
                    # Keep the sub tables seen so far consistent with df
                    for j in range(1, i + 1):
                        df_sub_j = getattr(self, 'df_sub' + str(j))
                        if df_sub_j is None:
                            continue
                        setattr(
                            self, 'df_sub' + str(j),
                            df_sub_j.drop(
                                df_sub_j.index.difference(df.index)))

        # The session is the speaker label for every corpus.
        # NOTE: for swbd, a serialized variant would instead use
        # str(speaker).split('-')[0] as the session.
        df = df.assign(session=df['speaker'].apply(str))

        # Sort tsv records
        if discourse_aware:
            # Sort by onset (start time)
            df = df.assign(prev_utt='')
            df = df.assign(line_no=list(range(len(df))))
            if corpus == 'swbd':
                df['onset'] = df['utt_id'].apply(
                    lambda x: int(x.split('_')[-1].split('-')[0]))
            elif corpus == 'csj':
                df['onset'] = df['utt_id'].apply(
                    lambda x: int(x.split('_')[1]))
            elif corpus == 'tedlium2':
                df['onset'] = df['utt_id'].apply(
                    lambda x: int(x.split('-')[-2]))
            else:
                raise NotImplementedError(corpus)
            df = df.sort_values(by=['session', 'onset'], ascending=True)

            # Extract previous utterances: for every row, the line numbers of
            # the utterances of the same session with an earlier onset
            groups = df.groupby('session').groups
            df['prev_utt'] = df.apply(lambda x: [
                df.loc[i, 'line_no'] for i in groups[x['session']]
                if df.loc[i, 'onset'] < x['onset']
            ],
                                      axis=1)
            df['n_prev_utt'] = df.apply(lambda x: len(x['prev_utt']), axis=1)
            df['n_utt_in_session'] = df.apply(
                lambda x: len(groups[x['session']]), axis=1)
            df = df.sort_values(by=['n_utt_in_session'], ascending=short2long)

        elif not is_test:
            if sort_by == 'input':
                df = df.sort_values(by=['xlen'], ascending=short2long)
            elif sort_by == 'output':
                df = df.sort_values(by=['ylen'], ascending=short2long)
            elif sort_by == 'shuffle':
                # Use the local table: self.df is not assigned yet here
                df = df.reindex(np.random.permutation(df.index))

        # Fit word alignment to vocabulary
        if word_alignment_dir is not None:
            alignment2boundary = WordAlignmentConverter(dict_path, wp_model)
            n_utts = len(df)
            df['trigger_points'] = df.apply(lambda x: alignment2boundary(
                word_alignment_dir, x['speaker'], x['utt_id'], x['text']),
                                            axis=1)
            # remove utterances which do not have the alignment
            df = df[df.apply(lambda x: x['trigger_points'] is not None,
                             axis=1)]
            print('Removed %d utterances (for word alignment)' %
                  (n_utts - len(df)))
        elif ctc_alignment_dir is not None:
            n_utts = len(df)
            df['trigger_points'] = df.apply(lambda x: load_ctc_alignment(
                ctc_alignment_dir, x['speaker'], x['utt_id']),
                                            axis=1)
            # remove utterances which do not have the alignment
            df = df[df.apply(lambda x: x['trigger_points'] is not None,
                             axis=1)]
            print('Removed %d utterances (for CTC alignment)' %
                  (n_utts - len(df)))

        # Re-indexing
        if discourse_aware:
            # Keep the original labels; sub tables are aligned to them
            self.df = df
            for i in range(1, 3):
                df_sub = getattr(self, 'df_sub' + str(i))
                if df_sub is not None:
                    setattr(self, 'df_sub' + str(i), df_sub.reindex(df.index))
        else:
            self.df = df.reset_index()
            for i in range(1, 3):
                df_sub = getattr(self, 'df_sub' + str(i))
                if df_sub is not None:
                    setattr(self, 'df_sub' + str(i),
                            df_sub.reindex(df.index).reset_index())
Exemplo n.º 3
0
    def __init__(self,
                 tsv_path,
                 dict_path,
                 unit,
                 batch_size,
                 nlsyms=False,
                 n_epochs=None,
                 is_test=False,
                 min_n_tokens=1,
                 bptt=2,
                 shuffle=False,
                 backward=False,
                 serialize=False,
                 wp_model=None,
                 corpus=''):
        """Read a text corpus from tsv and fold it into ``batch_size`` rows.

        All utterances are ordered, joined into one token-id stream with
        ``self.eos`` as separator, cut to a multiple of ``batch_size`` and
        stored as a 2-D array in ``self.concat_ids``.

        Args:
            tsv_path (str): path to the dataset tsv file
            dict_path (str): path to the dictionary
            unit (str): word or wp or char or phone or word_char
            batch_size (int): size of mini-batch
            nlsyms (str): path to the non-linguistic symbols file
            n_epochs (int): max epoch, stored as ``self.max_epoch``
            is_test (bool): if True, only empty utterances are removed
                (min_n_tokens is ignored)
            min_n_tokens (int): exclude utterances shorter than this value
            bptt (int): BPTT length (must be 2 or larger)
            shuffle (bool): shuffle utterances. Takes precedence over
                serialize.
            backward (bool): visit the utterances in reversed order when
                building the token stream
            serialize (bool): order utterances by session and onset
                (swbd only)
            wp_model (): path to the word-piece model for sentencepiece
            corpus (str): name of corpus

        Raises:
            ValueError: if unit is not supported

        """
        super(Dataset, self).__init__()

        tsv_name = os.path.basename(tsv_path)
        self.set = tsv_name.split('.')[0]
        self.is_test = is_test
        self.unit = unit
        self.batch_size = batch_size
        self.bptt = bptt
        # NOTE: <sos> and <eos> share the same index
        self.sos = 2
        self.eos = 2
        self.max_epoch = n_epochs
        self.shuffle = shuffle
        self.vocab = self.count_vocab_size(dict_path)
        assert bptt >= 2

        self.idx2token = []
        self.token2idx = []

        # Pick the index<->token converters that match the unit
        if unit in ('word', 'word_char'):
            self.idx2token.append(Idx2word(dict_path))
            mix_chars = unit == 'word_char'
            self.token2idx.append(Word2idx(dict_path, word_char_mix=mix_chars))
        elif unit == 'wp':
            self.idx2token.append(Idx2wp(dict_path, wp_model))
            self.token2idx.append(Wp2idx(dict_path, wp_model))
        elif unit == 'char':
            self.idx2token.append(Idx2char(dict_path))
            self.token2idx.append(Char2idx(dict_path, nlsyms=nlsyms))
        elif 'phone' in unit:
            self.idx2token.append(Idx2phone(dict_path))
            self.token2idx.append(Phone2idx(dict_path))
        else:
            raise ValueError(unit)

        # Read the tsv file and keep the known columns only
        wanted_columns = [
            'utt_id', 'speaker', 'feat_path', 'xlen', 'xdim', 'text',
            'token_id', 'ylen', 'ydim'
        ]
        table = pd.read_csv(tsv_path, encoding='utf-8', delimiter='\t')
        self.df = table.loc[:, wanted_columns]

        # Drop utterances that are empty (test) or too short (training)
        n_before = len(self.df)
        print('Original utterance num: %d' % n_before)
        if is_test:
            keep = self.df.apply(lambda row: row['ylen'] > 0, axis=1)
            self.df = self.df[keep]
            print('Removed %d empty utterances' % (n_before - len(self.df)))
        else:
            keep = self.df.apply(
                lambda row: row['ylen'] >= min_n_tokens, axis=1)
            self.df = self.df[keep]
            print('Removed %d utterances (threshold)' %
                  (n_before - len(self.df)))

        # Order the rows
        if shuffle:
            shuffled_labels = np.random.permutation(self.df.index)
            self.df = self.df.reindex(shuffled_labels)
        elif serialize:
            assert corpus == 'swbd'

            def _session_of(speaker):
                return str(speaker).split('-')[0]

            def _onset_of(utt_id):
                return int(utt_id.split('_')[-1].split('-')[0])

            self.df['session'] = self.df['speaker'].apply(_session_of)
            self.df['onset'] = self.df['utt_id'].apply(_onset_of)
            self.df = self.df.sort_values(by=['session', 'onset'],
                                          ascending=True)
        else:
            self.df = self.df.sort_values(by='utt_id', ascending=True)

        # Join everything into one stream: <eos> utt <eos> utt ... <eos>
        row_labels = list(self.df.index)
        if backward:
            row_labels.reverse()
        token_stream = []
        for label in row_labels:
            token_str = self.df['token_id'][label]
            assert token_str != ''
            token_stream.append(self.eos)
            token_stream.extend(int(tok) for tok in token_str.split())
        token_stream.append(self.eos)

        # Cut the tail so that the stream splits evenly into batch_size rows
        n_tokens = len(token_stream)
        n_usable = n_tokens // batch_size * batch_size
        token_stream = token_stream[:n_usable]
        print('Removed %d tokens / %d tokens' %
              (n_tokens - len(token_stream), n_tokens))
        self.concat_ids = np.array(token_stream).reshape((batch_size, -1))
Exemplo n.º 4
0
    def __init__(self, tsv_path, dict_path,
                 unit, batch_size, nlsyms=False, n_epochs=None,
                 is_test=False, min_n_frames=40, max_n_frames=2000,
                 shuffle_bucket=False, sort_by='utt_id',
                 short2long=False, sort_stop_epoch=None, dynamic_batching=False,
                 ctc=False, subsample_factor=1, wp_model=False, corpus='',
                 tsv_path_sub1=False, dict_path_sub1=False, unit_sub1=False,
                 wp_model_sub1=False, ctc_sub1=False, subsample_factor_sub1=1,
                 tsv_path_sub2=False, dict_path_sub2=False, unit_sub2=False,
                 wp_model_sub2=False, ctc_sub2=False, subsample_factor_sub2=1,
                 discourse_aware=False, skip_thought=False):
        """A class for loading dataset.

        Args:
            tsv_path (str): path to the dataset tsv file
            dict_path (str): path to the dictionary
            unit (str): word or wp or char or phone or word_char
            batch_size (int): size of mini-batch
            nlsyms (str): path to the non-linguistic symbols file
            n_epochs (int): max epoch. None means infinite loop.
            is_test (bool):
            min_n_frames (int): exclude utterances shorter than this value
            max_n_frames (int): exclude utterances longer than this value
            shuffle_bucket (bool): gather the similar length of utterances and shuffle them
            sort_by (str): sort all utterances in the ascending order
                input: sort by input length
                output: sort by output length
                shuffle: shuffle all utterances
            short2long (bool): sort utterances in the descending order
            sort_stop_epoch (int): After sort_stop_epoch, training will revert
                back to a random order
            dynamic_batching (bool): change batch size dynamically in training
            ctc (bool):
            subsample_factor (int):
            wp_model (): path to the word-piece model for sentencepiece
            corpus (str): name of corpus
            discourse_aware (bool):
            skip_thought (bool):

        """
        super(Dataset, self).__init__()

        self.epoch = 0
        self.iteration = 0
        self.offset = 0

        self.set = os.path.basename(tsv_path).split('.')[0]
        self.is_test = is_test
        self.unit = unit
        self.unit_sub1 = unit_sub1
        self.batch_size = batch_size
        self.max_epoch = n_epochs
        self.shuffle_bucket = shuffle_bucket
        if shuffle_bucket:
            assert sort_by in ['input', 'output']
        self.sort_stop_epoch = sort_stop_epoch
        self.sort_by = sort_by
        assert sort_by in ['input', 'output', 'shuffle', 'utt_id', 'no_sort']
        self.dynamic_batching = dynamic_batching
        self.corpus = corpus
        self.discourse_aware = discourse_aware
        self.skip_thought = skip_thought

        self.vocab = count_vocab_size(dict_path)
        self.eos = 2
        self.pad = 3
        # NOTE: reserved in advance

        self.idx2token = []
        self.token2idx = []

        # Set index converter
        if unit in ['word', 'word_char']:
            self.idx2token += [Idx2word(dict_path)]
            self.token2idx += [Word2idx(dict_path, word_char_mix=(unit == 'word_char'))]
        elif unit == 'wp':
            self.idx2token += [Idx2wp(dict_path, wp_model)]
            self.token2idx += [Wp2idx(dict_path, wp_model)]
        elif unit == 'char':
            self.idx2token += [Idx2char(dict_path)]
            self.token2idx += [Char2idx(dict_path, nlsyms=nlsyms)]
        elif 'phone' in unit:
            self.idx2token += [Idx2phone(dict_path)]
            self.token2idx += [Phone2idx(dict_path)]
        else:
            raise ValueError(unit)

        for i in range(1, 3):
            dict_path_sub = locals()['dict_path_sub' + str(i)]
            wp_model_sub = locals()['wp_model_sub' + str(i)]
            unit_sub = locals()['unit_sub' + str(i)]
            if dict_path_sub:
                setattr(self, 'vocab_sub' + str(i), count_vocab_size(dict_path_sub))

                # Set index converter
                if unit_sub:
                    if unit_sub == 'wp':
                        self.idx2token += [Idx2wp(dict_path_sub, wp_model_sub)]
                        self.token2idx += [Wp2idx(dict_path_sub, wp_model_sub)]
                    elif unit_sub == 'char':
                        self.idx2token += [Idx2char(dict_path_sub)]
                        self.token2idx += [Char2idx(dict_path_sub, nlsyms=nlsyms)]
                    elif 'phone' in unit_sub:
                        self.idx2token += [Idx2phone(dict_path_sub)]
                        self.token2idx += [Phone2idx(dict_path_sub)]
                    else:
                        raise ValueError(unit_sub)
            else:
                setattr(self, 'vocab_sub' + str(i), -1)

        # Load dataset tsv file
        df = pd.read_csv(tsv_path, encoding='utf-8', delimiter='\t', quoting=csv.QUOTE_NONE, dtype={'utt_id': 'str'})
        df = df.loc[:, ['utt_id', 'speaker', 'feat_path',
                                  'xlen', 'xdim', 'text', 'token_id', 'ylen', 'ydim']]
        for i in range(1, 3):
            if locals()['tsv_path_sub' + str(i)]:
                df_sub = pd.read_csv(locals()['tsv_path_sub' + str(i)], encoding='utf-8', delimiter='\t')
                df_sub = df_sub.loc[:, ['utt_id', 'speaker', 'feat_path',
                                        'xlen', 'xdim', 'text', 'token_id', 'ylen', 'ydim']]
                setattr(self, 'df_sub' + str(i), df_sub)
            else:
                setattr(self, 'df_sub' + str(i), None)
        self.input_dim = kaldiio.load_mat(df['feat_path'][0]).shape[-1]

        if corpus == 'swbd':
            df['session'] = df['speaker'].apply(lambda x: str(x).split('-')[0])
        else:
            df['session'] = df['speaker'].apply(lambda x: str(x))

        if discourse_aware or skip_thought:
            max_n_frames = 10000
            min_n_frames = 100

            # Sort by onset
            df = df.assign(prev_utt='')
            if corpus == 'swbd':
                df['onset'] = df['utt_id'].apply(lambda x: int(x.split('_')[-1].split('-')[0]))
            elif corpus == 'csj':
                df['onset'] = df['utt_id'].apply(lambda x: int(x.split('_')[1]))
            elif corpus == 'wsj':
                df['onset'] = df['utt_id'].apply(lambda x: x)
            else:
                raise NotImplementedError
            df = df.sort_values(by=['session', 'onset'], ascending=True)

            # Extract previous utterances
            if not skip_thought:
                # df = df.assign(line_no=list(range(len(df))))
                groups = df.groupby('session').groups
                df['n_session_utt'] = df.apply(
                    lambda x: len([i for i in groups[x['session']]]), axis=1)

                # df['prev_utt'] = df.apply(
                #     lambda x: [df.loc[i, 'line_no']
                #                for i in groups[x['session']] if df.loc[i, 'onset'] < x['onset']], axis=1)
                # df['n_prev_utt'] = df.apply(lambda x: len(x['prev_utt']), axis=1)

        elif is_test and corpus == 'swbd':
            # Sort by onset
            df['onset'] = df['utt_id'].apply(lambda x: int(x.split('_')[-1].split('-')[0]))
            df = df.sort_values(by=['session', 'onset'], ascending=True)

        # Remove inappropriate utterances
        if is_test:
            print('Original utterance num: %d' % len(df))
            n_utts = len(df)
            df = df[df.apply(lambda x: x['ylen'] > 0, axis=1)]
            print('Removed %d empty utterances' % (n_utts - len(df)))
        else:
            print('Original utterance num: %d' % len(df))
            n_utts = len(df)
            df = df[df.apply(lambda x: min_n_frames <= x[
                'xlen'] <= max_n_frames, axis=1)]
            df = df[df.apply(lambda x: x['ylen'] > 0, axis=1)]
            print('Removed %d utterances (threshold)' % (n_utts - len(df)))

            if ctc and subsample_factor > 1:
                n_utts = len(df)
                df = df[df.apply(lambda x: x['ylen'] <= (x['xlen'] // subsample_factor), axis=1)]
                print('Removed %d utterances (for CTC)' % (n_utts - len(df)))

            for i in range(1, 3):
                df_sub = getattr(self, 'df_sub' + str(i))
                ctc_sub = locals()['ctc_sub' + str(i)]
                subsample_factor_sub = locals()['subsample_factor_sub' + str(i)]
                if df_sub is not None:
                    if ctc_sub and subsample_factor_sub > 1:
                        df_sub = df_sub[df_sub.apply(
                            lambda x: x['ylen'] <= (x['xlen'] // subsample_factor_sub), axis=1)]

                    if len(df) != len(df_sub):
                        n_utts = len(df)
                        df = df.drop(df.index.difference(df_sub.index))
                        print('Removed %d utterances (for CTC, sub%d)' % (n_utts - len(df), i))
                        for j in range(1, i + 1):
                            setattr(self, 'df_sub' + str(j),
                                    getattr(self, 'df_sub' + str(j)).drop(getattr(self, 'df_sub' + str(j)).index.difference(df.index)))

            # Re-indexing
            for i in range(1, 3):
                if getattr(self, 'df_sub' + str(i)) is not None:
                    setattr(self, 'df_sub' + str(i), getattr(self, 'df_sub' + str(i)).reset_index())

        # Sort tsv records
        if not is_test:
            if discourse_aware:
                self.utt_offset = 0
                self.n_utt_session_dict = {}
                self.session_offset_dict = {}
                for session_id, ids in sorted(df.groupby('session').groups.items(), key=lambda x: len(x[1])):
                    n_utt = len(ids)
                    # key: n_utt, value: session_id
                    if n_utt not in self.n_utt_session_dict.keys():
                        self.n_utt_session_dict[n_utt] = []
                    self.n_utt_session_dict[n_utt].append(session_id)

                    # key: session_id, value: id for the first utterance in each session
                    self.session_offset_dict[session_id] = ids[0]

                self.n_utt_session_dict_epoch = copy.deepcopy(self.n_utt_session_dict)
                # if discourse_aware == 'state_carry_over':
                #     df = df.sort_values(by=['n_session_utt', 'utt_id'], ascending=short2long)
                # else:
                #     df = df.sort_values(by=['n_prev_utt'], ascending=short2long)
            elif sort_by == 'input':
                df = df.sort_values(by=['xlen'], ascending=short2long)
            elif sort_by == 'output':
                df = df.sort_values(by=['ylen'], ascending=short2long)
            elif sort_by == 'shuffle':
                df = df.reindex(np.random.permutation(self.df.index))
            elif sort_by == 'no_sort':
                pass

        for i in range(1, 3):
            if getattr(self, 'df_sub' + str(i)) is not None:
                setattr(self, 'df_sub' + str(i),
                        getattr(self, 'df_sub' + str(i)).reindex(df.index).reset_index())

        # Re-indexing
        self.df = df.reset_index()
        self.df_indices = list(self.df.index)
Exemplo n.º 5
0
    def __init__(self, csv_path, dict_path,
                 label_type, batch_size, max_epoch=None,
                 is_test=False, max_num_frames=2000, min_num_frames=40,
                 shuffle=False, sort_by_input_length=False,
                 short2long=False, sort_stop_epoch=None,
                 num_enques=None, dynamic_batching=False,
                 use_ctc=False, subsample_factor=1,
                 skip_speech=False,
                 csv_path_sub=None, dict_path_sub=None, label_type_sub=None,
                 use_ctc_sub=False, subsample_factor_sub=1):
        """A class for loading dataset.

        Args:
            csv_path (str): path to the dataset csv file
            dict_path (str): path to the dictionary
            label_type (str): word, wordpiece, char, char_capital_divide,
                or any name containing 'phone'
            batch_size (int): the size of mini-batch
            max_epoch (int): the max epoch. None means infinite loop.
            is_test (bool): if True, no utterance is removed
            max_num_frames (int): Exclude utterances longer than this value
            min_num_frames (int): Exclude utterances shorter than this value
            shuffle (bool): if True, shuffle utterances.
                This is ignored when sort_by_input_length is True.
            sort_by_input_length (bool): if True, sort utterances by x_len
            short2long (bool): used as `ascending` for that sort, i.e. True
                orders utterances from the shortest to the longest
            sort_stop_epoch (int): After sort_stop_epoch, training will revert
                back to a random order
            num_enques (int): the number of elements to enqueue
            dynamic_batching (bool): if True, batch size will be changed
                dynamically in training
            use_ctc (bool): if True and subsample_factor > 1, remove
                utterances whose y_len exceeds x_len // subsample_factor
            subsample_factor (int): see use_ctc
            skip_speech (bool): skip loading speech features
            csv_path_sub (str): path to the csv file of the sub task
            dict_path_sub (str): path to the dictionary of the sub task
            label_type_sub (str): label type of the sub task
            use_ctc_sub (bool): same as use_ctc, for the sub task
            subsample_factor_sub (int): same as subsample_factor, for the
                sub task

        Raises:
            ValueError: if label_type or label_type_sub is not supported

        """
        super(Dataset, self).__init__()

        self.set = os.path.basename(csv_path).split('.')[0]
        self.is_test = is_test
        self.label_type = label_type
        self.label_type_sub = label_type_sub
        self.batch_size = batch_size
        self.max_epoch = max_epoch
        self.shuffle = shuffle
        self.sort_by_input_length = sort_by_input_length
        self.sort_stop_epoch = sort_stop_epoch
        self.num_enques = num_enques
        self.dynamic_batching = dynamic_batching
        self.skip_speech = skip_speech
        self.num_classes = self.count_vocab_size(dict_path)

        # Set index converter
        if label_type in ['word', 'wordpiece']:
            self.idx2word = Idx2word(dict_path)
            self.word2idx = Word2idx(dict_path)
        elif label_type == 'char':
            self.idx2char = Idx2char(dict_path)
            self.char2idx = Char2idx(dict_path)
        elif label_type == 'char_capital_divide':
            self.idx2char = Idx2char(dict_path, capital_divide=True)
            self.char2idx = Char2idx(dict_path, capital_divide=True)
        elif 'phone' in label_type:
            self.idx2phone = Idx2phone(dict_path)
            self.phone2idx = Phone2idx(dict_path)
        else:
            raise ValueError(label_type)

        if dict_path_sub is not None:
            self.num_classes_sub = self.count_vocab_size(dict_path_sub)

            # Set index converter
            if label_type_sub is not None:
                # Dispatch on the *sub* label type. Testing the main
                # label_type here made a wordpiece main task replace the
                # word converters and skip the char/phone sub converter.
                if label_type_sub in ['word', 'wordpiece']:
                    self.idx2word = Idx2word(dict_path_sub)
                    self.word2idx = Word2idx(dict_path_sub)
                elif label_type_sub == 'char':
                    self.idx2char = Idx2char(dict_path_sub)
                    self.char2idx = Char2idx(dict_path_sub)
                elif label_type_sub == 'char_capital_divide':
                    self.idx2char = Idx2char(dict_path_sub, capital_divide=True)
                    self.char2idx = Char2idx(dict_path_sub, capital_divide=True)
                elif 'phone' in label_type_sub:
                    self.idx2phone = Idx2phone(dict_path_sub)
                    self.phone2idx = Phone2idx(dict_path_sub)
                else:
                    raise ValueError(label_type_sub)
        else:
            self.num_classes_sub = -1

        # Load dataset csv file
        columns = ['utt_id', 'feat_path', 'x_len', 'x_dim', 'text', 'token_id', 'y_len', 'y_dim']
        df = pd.read_csv(csv_path, encoding='utf-8')
        df = df.loc[:, columns]
        if csv_path_sub is not None:
            df_sub = pd.read_csv(csv_path_sub, encoding='utf-8')
            df_sub = df_sub.loc[:, columns]
        else:
            df_sub = None

        # Remove inappropriate utterances
        if not self.is_test:
            logger.info('Original utterance num: %d' % len(df))
            num_utt_org = len(df)

            # Remove by threshold
            df = df[df.apply(lambda x: min_num_frames <= x['x_len'] <= max_num_frames, axis=1)]
            logger.info('Removed %d utterances (threshold)' % (num_utt_org - len(df)))

            # Remove for CTC loss calculation
            if use_ctc and subsample_factor > 1:
                logger.info('Checking utterances for CTC...')
                logger.info('Original utterance num: %d' % len(df))
                num_utt_org = len(df)
                df = df[df.apply(lambda x: x['y_len'] <= x['x_len'] // subsample_factor, axis=1)]
                logger.info('Removed %d utterances (for CTC)' % (num_utt_org - len(df)))

            if df_sub is not None:
                logger.info('Original utterance num (sub): %d' % len(df_sub))
                num_utt_org = len(df_sub)

                # Remove by threshold
                df_sub = df_sub[df_sub.apply(lambda x: min_num_frames <= x['x_len'] <= max_num_frames, axis=1)]
                logger.info('Removed %d utterances (threshold, sub)' % (num_utt_org - len(df_sub)))

                # Remove for CTC loss calculation
                if use_ctc_sub and subsample_factor_sub > 1:
                    logger.info('Checking utterances for CTC...')
                    logger.info('Original utterance num (sub): %d' % len(df_sub))
                    num_utt_org = len(df_sub)
                    df_sub = df_sub[df_sub.apply(lambda x: x['y_len'] <= x['x_len'] // subsample_factor_sub, axis=1)]
                    logger.info('Removed %d utterances (for CTC, sub)' % (num_utt_org - len(df_sub)))

                # Keep only the utterances present in both tables. This is
                # done unconditionally: the two filters can remove different
                # rows and still leave tables of equal length, and dropping
                # an empty difference is a no-op.
                df = df.drop(df.index.difference(df_sub.index))
                df_sub = df_sub.drop(df_sub.index.difference(df.index))

        # Sort csv records
        if sort_by_input_length:
            df = df.sort_values(by='x_len', ascending=short2long)
        else:
            if shuffle:
                df = df.reindex(np.random.permutation(df.index))
            else:
                df = df.sort_values(by='utt_id', ascending=True)

        self.df = df
        self.df_sub = df_sub
        self.rest = set(df.index)
        # Positional access: the row labelled 0 may have been filtered out
        # above, in which case a label lookup ([0]) raises KeyError.
        self.input_dim = kaldi_io.read_mat(self.df['feat_path'].iloc[0]).shape[-1]
Exemplo n.º 6
0
    def __init__(self,
                 tsv_path,
                 dict_path,
                 unit,
                 batch_size,
                 nlsyms=False,
                 n_epochs=None,
                 is_test=False,
                 min_n_frames=40,
                 max_n_frames=2000,
                 shuffle=False,
                 sort_by_input_length=False,
                 short2long=False,
                 sort_stop_epoch=None,
                 n_ques=None,
                 dynamic_batching=False,
                 ctc=False,
                 subsample_factor=1,
                 wp_model=False,
                 corpus='',
                 tsv_path_sub1=False,
                 dict_path_sub1=False,
                 unit_sub1=False,
                 wp_model_sub1=False,
                 ctc_sub1=False,
                 subsample_factor_sub1=1,
                 wp_model_sub2=False,
                 tsv_path_sub2=False,
                 dict_path_sub2=False,
                 unit_sub2=False,
                 ctc_sub2=False,
                 subsample_factor_sub2=1,
                 contextualize=False,
                 skip_thought=False):
        """A class for loading dataset.

        Args:
            tsv_path (str): path to the dataset tsv file
            dict_path (str): path to the dictionary
            unit (str): word or wp or char or phone or word_char
            batch_size (int): size of mini-batch
            nlsyms (str): path to the non-linguistic symbols file
            n_epochs (int): max epoch. None means infinite loop.
            is_test (bool): if True, only utterances with ylen <= 0 are
                removed and the final sort/shuffle step is skipped
            min_n_frames (int): exclude utterances shorter than this value
                (replaced by 100 when contextualize or skip_thought is set)
            max_n_frames (int): exclude utterances longer than this value
                (replaced by 10000 when contextualize or skip_thought is set)
            shuffle (bool): shuffle utterances.
                This is ignored when contextualize or sort_by_input_length
                is True.
            sort_by_input_length (bool): sort utterances by xlen
            short2long (bool): used as `ascending` when sorting, i.e. True
                puts the smallest xlen (or n_prev_utt) first
            sort_stop_epoch (int): After sort_stop_epoch, training will revert
                back to a random order
            n_ques (int): number of elements to enqueue
            dynamic_batching (bool): change batch size dynamically in training
            ctc (bool): if True and subsample_factor > 1, remove utterances
                whose ylen exceeds xlen // subsample_factor
            subsample_factor (int): see ctc
            wp_model (): path to the word-piece model for sentencepiece
            corpus (str): name of corpus. Must be swbd, csj or wsj when
                contextualize or skip_thought is set.
            tsv_path_sub1, tsv_path_sub2 (str): tsv files of the sub tasks;
                a falsy value leaves df_sub1 / df_sub2 as None
            dict_path_sub1, dict_path_sub2 (str): dictionaries of the sub
                tasks; a falsy value sets vocab_sub1 / vocab_sub2 to -1
            unit_sub1, unit_sub2 (str): wp or char or phone
            wp_model_sub1, wp_model_sub2 (): word-piece models of the
                sub tasks
            ctc_sub1, ctc_sub2 (bool): same as ctc, for the sub tasks
            subsample_factor_sub1, subsample_factor_sub2 (int): same as
                subsample_factor, for the sub tasks
            contextualize (bool): sort utterances by session and onset and
                record the preceding utterances of each session
            skip_thought (bool): sort utterances by session and onset
                without recording the preceding utterances

        Raises:
            ValueError: if unit, unit_sub1 or unit_sub2 is not supported
            NotImplementedError: if contextualize or skip_thought is set
                for a corpus other than swbd, csj or wsj

        """
        super(Dataset, self).__init__()

        self.set = os.path.basename(tsv_path).split('.')[0]
        self.is_test = is_test
        self.unit = unit
        self.unit_sub1 = unit_sub1
        self.batch_size = batch_size
        self.max_epoch = n_epochs
        self.shuffle = shuffle
        self.sort_stop_epoch = sort_stop_epoch
        self.sort_by_input_length = sort_by_input_length
        self.n_ques = n_ques
        self.dynamic_batching = dynamic_batching
        self.corpus = corpus
        self.contextualize = contextualize
        self.skip_thought = skip_thought

        self.vocab = self.count_vocab_size(dict_path)
        self.eos = 2
        self.pad = 3
        # NOTE: reserved in advance

        self.idx2token = []
        self.token2idx = []

        # Set index converter
        if unit in ['word', 'word_char']:
            self.idx2token += [Idx2word(dict_path)]
            self.token2idx += [
                Word2idx(dict_path, word_char_mix=(unit == 'word_char'))
            ]
        elif unit == 'wp':
            self.idx2token += [Idx2wp(dict_path, wp_model)]
            self.token2idx += [Wp2idx(dict_path, wp_model)]
        elif unit == 'char':
            self.idx2token += [Idx2char(dict_path)]
            self.token2idx += [Char2idx(dict_path, nlsyms=nlsyms)]
        elif 'phone' in unit:
            self.idx2token += [Idx2phone(dict_path)]
            self.token2idx += [Phone2idx(dict_path)]
        else:
            raise ValueError(unit)

        # Sub-task arguments collected in an explicit table, so they do not
        # have to be looked up by name through locals().
        # Tuple layout: (tsv_path, dict_path, unit, wp_model, ctc,
        #                subsample_factor)
        sub_args = {
            1: (tsv_path_sub1, dict_path_sub1, unit_sub1, wp_model_sub1,
                ctc_sub1, subsample_factor_sub1),
            2: (tsv_path_sub2, dict_path_sub2, unit_sub2, wp_model_sub2,
                ctc_sub2, subsample_factor_sub2),
        }

        for i in range(1, 3):
            dict_path_sub = sub_args[i][1]
            unit_sub = sub_args[i][2]
            wp_model_sub = sub_args[i][3]
            if dict_path_sub:
                setattr(self, 'vocab_sub' + str(i),
                        self.count_vocab_size(dict_path_sub))

                # Set index converter
                if unit_sub:
                    if unit_sub == 'wp':
                        self.idx2token += [Idx2wp(dict_path_sub, wp_model_sub)]
                        self.token2idx += [Wp2idx(dict_path_sub, wp_model_sub)]
                    elif unit_sub == 'char':
                        self.idx2token += [Idx2char(dict_path_sub)]
                        self.token2idx += [
                            Char2idx(dict_path_sub, nlsyms=nlsyms)
                        ]
                    elif 'phone' in unit_sub:
                        self.idx2token += [Idx2phone(dict_path_sub)]
                        self.token2idx += [Phone2idx(dict_path_sub)]
                    else:
                        raise ValueError(unit_sub)
            else:
                setattr(self, 'vocab_sub' + str(i), -1)

        # Load dataset tsv file
        columns = [
            'utt_id', 'speaker', 'feat_path', 'xlen', 'xdim', 'text',
            'token_id', 'ylen', 'ydim'
        ]
        self.df = pd.read_csv(tsv_path, encoding='utf-8', delimiter='\t')
        self.df = self.df.loc[:, columns]
        for i in range(1, 3):
            tsv_path_sub = sub_args[i][0]
            if tsv_path_sub:
                df_sub = pd.read_csv(tsv_path_sub,
                                     encoding='utf-8',
                                     delimiter='\t')
                df_sub = df_sub.loc[:, columns]
                setattr(self, 'df_sub' + str(i), df_sub)
            else:
                setattr(self, 'df_sub' + str(i), None)
        self.input_dim = kaldiio.load_mat(self.df['feat_path'][0]).shape[-1]

        # For swbd, the session is the part of the speaker name before '-'
        if corpus == 'swbd':
            self.df['session'] = self.df['speaker'].apply(
                lambda x: str(x).split('-')[0])
        else:
            self.df['session'] = self.df['speaker'].apply(lambda x: str(x))

        if contextualize or skip_thought:
            # The caller-supplied frame thresholds are replaced here
            max_n_frames = 10000
            min_n_frames = 100

            # Sort by onset
            self.df = self.df.assign(prev_utt='')
            if corpus == 'swbd':
                self.df['onset'] = self.df['utt_id'].apply(
                    lambda x: int(x.split('_')[-1].split('-')[0]))
            elif corpus == 'csj':
                self.df['onset'] = self.df['utt_id'].apply(
                    lambda x: int(x.split('_')[1]))
            elif corpus == 'wsj':
                self.df['onset'] = self.df['utt_id'].apply(lambda x: x)
            else:
                raise NotImplementedError
            self.df = self.df.sort_values(by=['session', 'onset'],
                                          ascending=True)

            # Extract previous utterances
            if not skip_thought and not is_test:
                self.df = self.df.assign(line_no=list(range(len(self.df))))
                groups = self.df.groupby('session').groups  # dict
                # For every row: line numbers of the utterances of the same
                # session that have a smaller onset
                self.df['prev_utt'] = self.df.apply(lambda x: [
                    self.df.loc[i, 'line_no'] for i in groups[x['session']]
                    if self.df.loc[i, 'onset'] < x['onset']
                ],
                                                    axis=1)
                self.df['n_prev_utt'] = self.df.apply(
                    lambda x: len(x['prev_utt']), axis=1)

        elif is_test and corpus == 'swbd':
            # Sort by onset
            self.df['onset'] = self.df['utt_id'].apply(
                lambda x: int(x.split('_')[-1].split('-')[0]))
            self.df = self.df.sort_values(by=['session', 'onset'],
                                          ascending=True)

        # Remove inappropriate utterances
        if is_test:
            print('Original utterance num: %d' % len(self.df))
            n_utts = len(self.df)
            self.df = self.df[self.df.apply(lambda x: x['ylen'] > 0, axis=1)]
            print('Removed %d empty utterances' % (n_utts - len(self.df)))
        else:
            print('Original utterance num: %d' % len(self.df))
            n_utts = len(self.df)
            self.df = self.df[self.df.apply(
                lambda x: min_n_frames <= x['xlen'] <= max_n_frames, axis=1)]
            self.df = self.df[self.df.apply(lambda x: x['ylen'] > 0, axis=1)]
            print('Removed %d utterances (threshold)' %
                  (n_utts - len(self.df)))

            if ctc and subsample_factor > 1:
                n_utts = len(self.df)
                self.df = self.df[self.df.apply(
                    lambda x: x['ylen'] <= (x['xlen'] // subsample_factor),
                    axis=1)]
                print('Removed %d utterances (for CTC)' %
                      (n_utts - len(self.df)))

            for i in range(1, 3):
                df_sub = getattr(self, 'df_sub' + str(i))
                if df_sub is None:
                    continue
                ctc_sub = sub_args[i][4]
                subsample_factor_sub = sub_args[i][5]
                if ctc_sub and subsample_factor_sub > 1:
                    df_sub = df_sub[df_sub.apply(
                        lambda x: x['ylen'] <=
                        (x['xlen'] // subsample_factor_sub),
                        axis=1)]

                # Align whenever the utterance sets differ. Comparing only
                # the lengths misses the case where the main and sub filters
                # removed different rows but the same number of them.
                if (len(self.df) != len(df_sub)
                        or not self.df.index.isin(df_sub.index).all()):
                    n_utts = len(self.df)
                    self.df = self.df.drop(
                        self.df.index.difference(df_sub.index))
                    print('Removed %d utterances (for CTC, sub%d)' %
                          (n_utts - len(self.df), i))
                    for j in range(1, i + 1):
                        df_sub_j = getattr(self, 'df_sub' + str(j))
                        # sub2 may be configured without sub1
                        if df_sub_j is None:
                            continue
                        setattr(
                            self, 'df_sub' + str(j),
                            df_sub_j.drop(
                                df_sub_j.index.difference(self.df.index)))

        # Sort tsv records
        if not is_test:
            if contextualize:
                # NOTE(review): n_prev_utt is only created above when
                # skip_thought is False
                self.df = self.df.sort_values(by='n_prev_utt',
                                              ascending=short2long)
            elif sort_by_input_length:
                self.df = self.df.sort_values(by='xlen', ascending=short2long)
            elif shuffle:
                self.df = self.df.reindex(np.random.permutation(self.df.index))

        self.rest = set(self.df.index)
Exemplo n.º 7
0
    def __init__(self,
                 tsv_path,
                 dict_path,
                 unit,
                 batch_size,
                 n_epochs=None,
                 is_test=False,
                 bptt=2,
                 wp_model=None,
                 corpus='',
                 shuffle=False,
                 serialize=False):
        """A class for loading dataset.

        All token ids are concatenated into one sequence, with <eos>
        between utterances, and reshaped to `batch_size` rows.

        Args:
            tsv_path (str): path to the dataset tsv file
            dict_path (str): path to the dictionary
            unit (str): word or wp or char or phone or word_char
            batch_size (int): size of mini-batch
            n_epochs (int): max epoch. None means infinite loop.
            is_test (bool): stored on the instance; not otherwise used here
            bptt (int): BPTT length (must be >= 2)
            wp_model (): path to the word-piece model for sentencepiece
            corpus (str): name of corpus (must be swbd when serialize is
                used)
            shuffle (bool): shuffle utterances.
                This takes precedence over serialize.
            serialize (bool): serialize text according to contexts in dialogue

        Raises:
            ValueError: if unit is not supported

        """
        super(Dataset, self).__init__()

        self.set = os.path.basename(tsv_path).split('.')[0]
        self.is_test = is_test
        self.unit = unit
        self.batch_size = batch_size
        self.bptt = bptt
        self.sos = 2
        self.eos = 2
        self.max_epoch = n_epochs
        self.shuffle = shuffle
        self.vocab = self.count_vocab_size(dict_path)
        assert bptt >= 2

        # Set index converter
        if unit in ['word', 'word_char']:
            self.idx2word = Idx2word(dict_path)
            self.word2idx = Word2idx(dict_path,
                                     word_char_mix=(unit == 'word_char'))
        elif unit == 'wp':
            self.idx2wp = Idx2wp(dict_path, wp_model)
            self.wp2idx = Wp2idx(dict_path, wp_model)
        elif unit == 'char':
            self.idx2char = Idx2char(dict_path)
            self.char2idx = Char2idx(dict_path)
        elif 'phone' in unit:
            self.idx2phone = Idx2phone(dict_path)
            self.phone2idx = Phone2idx(dict_path)
        else:
            raise ValueError(unit)

        # Load dataset tsv file
        self.df = pd.read_csv(tsv_path, encoding='utf-8', delimiter='\t')
        self.df = self.df.loc[:, [
            'utt_id', 'speaker', 'feat_path', 'xlen', 'xdim', 'text',
            'token_id', 'ylen', 'ydim'
        ]]
        # Drop empty utterances (vectorized mask instead of a row-wise apply)
        self.df = self.df[self.df['ylen'] > 0]

        # Sort tsv records
        if shuffle:
            self.df = self.df.reindex(np.random.permutation(self.df.index))
        elif serialize:
            assert corpus == 'swbd'
            self.df['session'] = self.df['speaker'].apply(
                lambda x: str(x).split('-')[0])
            self.df['onset'] = self.df['utt_id'].apply(
                lambda x: int(x.split('_')[-1].split('-')[0]))
            self.df = self.df.sort_values(by=['session', 'onset'],
                                          ascending=True)
        else:
            self.df = self.df.sort_values(by='utt_id', ascending=True)

        # Concatenate into a single sentence
        concat_ids = []
        for token_ids in self.df['token_id']:
            assert token_ids != ''
            concat_ids.append(self.eos)
            concat_ids.extend(map(int, token_ids.split()))
        concat_ids.append(self.eos)
        # NOTE: <sos> and <eos> have the same index

        # Reshape: truncate to a multiple of batch_size so the sequence
        # splits evenly into batch_size rows
        n_tokens = len(concat_ids)
        concat_ids = concat_ids[:n_tokens // batch_size * batch_size]
        print('Removed %d tokens / %d tokens' %
              (n_tokens - len(concat_ids), n_tokens))
        self.concat_ids = np.array(concat_ids).reshape((batch_size, -1))