Example #1
0
    def __init__(self,
                 tsv_path,
                 dict_path,
                 unit,
                 batch_size,
                 nlsyms=False,
                 n_epochs=1e10,
                 is_test=False,
                 min_n_tokens=1,
                 bptt=2,
                 shuffle=False,
                 backward=False,
                 serialize=False,
                 wp_model=None,
                 corpus=''):
        """Load a transcription tsv file and concatenate it into one sequence.

        Args:
            tsv_path (str): path to the dataset tsv file
            dict_path (str): path to the dictionary
            unit (str): word or wp or char or phone or word_char
            batch_size (int): size of mini-batch
            nlsyms (str): path to the non-linguistic symbols file
                (only passed to the character-level converter)
            n_epochs (int): total epochs for training
            is_test (bool): if True, only empty utterances (ylen == 0) are
                removed and `min_n_tokens` is ignored
            min_n_tokens (int): exclude utterances shorter than this value
                (training only)
            bptt (int): BPTT length, must be >= 2
            shuffle (bool): shuffle utterances. The tsv records are permuted
                here and the flag is stored on the instance. Cannot be
                combined with `serialize`.
            backward (bool): flip all text in the corpus
            serialize (bool): serialize text according to contexts in
                dialogue (sorted by session and onset; requires
                corpus == 'swbd')
            wp_model (): path to the word-piece model for sentencepiece
            corpus (str): name of corpus

        Raises:
            AssertionError: if `bptt` < 2, if `shuffle` and `serialize` are
                both set, or if `serialize` is set for a corpus other than
                'swbd'
            ValueError: if `unit` is not supported

        """
        super(Dataset, self).__init__()

        # Fail fast on a cheap argument check before touching any file
        assert bptt >= 2

        self.epoch = 0
        self.iteration = 0
        self.offset = 0

        self.set = os.path.basename(tsv_path).split('.')[0]
        self.is_test = is_test
        self.unit = unit
        self.batch_size = batch_size
        self.bptt = bptt
        # <sos> and <eos> share the same index
        self.sos = 2
        self.eos = 2
        self.max_epoch = n_epochs
        self.shuffle = shuffle
        self.backward = backward
        self.vocab = count_vocab_size(dict_path)

        self.idx2token = []
        self.token2idx = []

        # Set index converter
        if unit in ['word', 'word_char']:
            self.idx2token += [Idx2word(dict_path)]
            self.token2idx += [
                Word2idx(dict_path, word_char_mix=(unit == 'word_char'))
            ]
        elif unit == 'wp':
            self.idx2token += [Idx2wp(dict_path, wp_model)]
            self.token2idx += [Wp2idx(dict_path, wp_model)]
        elif unit == 'char':
            self.idx2token += [Idx2char(dict_path)]
            self.token2idx += [Char2idx(dict_path, nlsyms=nlsyms)]
        elif 'phone' in unit:
            self.idx2token += [Idx2phone(dict_path)]
            self.token2idx += [Phone2idx(dict_path)]
        else:
            raise ValueError(unit)

        # Load dataset tsv file
        self.df = pd.read_csv(tsv_path, encoding='utf-8', delimiter='\t')
        self.df = self.df.loc[:, [
            'utt_id', 'speaker', 'feat_path', 'xlen', 'xdim', 'text',
            'token_id', 'ylen', 'ydim'
        ]]

        # Remove inappropriate utterances
        # (vectorized column masks instead of a Python call per row)
        n_utts = len(self.df)
        print('Original utterance num: %d' % n_utts)
        if is_test:
            self.df = self.df[self.df['ylen'] > 0]
            print('Removed %d empty utterances' % (n_utts - len(self.df)))
        else:
            self.df = self.df[self.df['ylen'] >= min_n_tokens]
            print('Removed %d utterances (threshold)' %
                  (n_utts - len(self.df)))

        # Sort tsv records
        if shuffle:
            assert not serialize
            self.df = self.df.reindex(np.random.permutation(self.df.index))
        elif serialize:
            assert corpus == 'swbd'
            # `assign` returns a new frame, so the columns are not written
            # into a filtered view of the original frame
            self.df = self.df.assign(
                session=self.df['speaker'].apply(
                    lambda x: str(x).split('-')[0]),
                onset=self.df['utt_id'].apply(
                    lambda x: int(x.split('_')[-1].split('-')[0])))
            self.df = self.df.sort_values(by=['session', 'onset'],
                                          ascending=True)
        else:
            self.df = self.df.sort_values(by='utt_id', ascending=True)

        # Concatenate into a single sentence
        self.concat_ids = self.concat_utterances(self.df)
Example #2
0
    def __init__(self,
                 corpus,
                 tsv_path,
                 dict_path,
                 unit,
                 nlsyms,
                 wp_model,
                 is_test,
                 min_n_frames,
                 max_n_frames,
                 sort_by,
                 short2long,
                 tsv_path_sub1,
                 tsv_path_sub2,
                 ctc,
                 ctc_sub1,
                 ctc_sub2,
                 subsample_factor,
                 subsample_factor_sub1,
                 subsample_factor_sub2,
                 dict_path_sub1,
                 dict_path_sub2,
                 unit_sub1,
                 unit_sub2,
                 wp_model_sub1,
                 wp_model_sub2,
                 discourse_aware=False,
                 simulate_longform=False,
                 first_n_utterances=-1,
                 word_alignment_dir=None,
                 ctc_alignment_dir=None):
        """Custom Dataset class.

        Loads the main tsv file (and up to two auxiliary ones), builds the
        index converters, filters out unusable utterances and sorts the
        remaining records.

        Args:
            corpus (str): name of corpus
            tsv_path (str): path to the dataset tsv file
            dict_path (str): path to the dictionary
            unit (str): word/wp/char/phone/word_char
            nlsyms (str): path to the non-linguistic symbols file
            wp_model (): path to the word-piece model for sentencepiece
            is_test (bool): if True, only empty utterances are removed and
                the records are not sorted
            min_n_frames (int): exclude utterances shorter than this value
            max_n_frames (int): exclude utterances longer than this value
            sort_by (str): how to order utterances (training only)
                input: sort by input length
                output: sort by output length
                shuffle: shuffle all utterances
            short2long (bool): passed as `ascending` when sorting, i.e.
                True sorts from short to long
            tsv_path_sub1 (str): path to the tsv file for the 1st sub task
            tsv_path_sub2 (str): path to the tsv file for the 2nd sub task
            ctc (bool): if True and subsample_factor > 1, remove training
                utterances with ylen > xlen // subsample_factor
            ctc_sub1 (bool): same as `ctc` for the 1st sub task
            ctc_sub2 (bool): same as `ctc` for the 2nd sub task
            subsample_factor (int): subsampling factor for the main task
            subsample_factor_sub1 (int): same for the 1st sub task
            subsample_factor_sub2 (int): same for the 2nd sub task
            dict_path_sub1 (str): path to the dictionary for the 1st sub task
            dict_path_sub2 (str): path to the dictionary for the 2nd sub task
            unit_sub1 (str): wp/char/phone for the 1st sub task
            unit_sub2 (str): wp/char/phone for the 2nd sub task
            wp_model_sub1 (): word-piece model for the 1st sub task
            wp_model_sub2 (): word-piece model for the 2nd sub task
            discourse_aware (bool): sort in the discourse order
                (not allowed together with is_test)
            simulate_longform (bool): simulate long-form utterance
                (requires is_test)
            first_n_utterances (int): evaluate the first N utterances
            word_alignment_dir (str): path to word alignment directory
            ctc_alignment_dir (str): path to CTC alignment directory

        Raises:
            AssertionError: if `discourse_aware` is combined with `is_test`
                or `simulate_longform` is used without `is_test`
            ValueError: if `unit` / `unit_sub*` is not supported
            NotImplementedError: if `discourse_aware` is requested for a
                corpus other than swbd/csj/tedlium2

        """
        super(Dataset, self).__init__()

        self.epoch = 0

        # meta data accessed by dataloader
        self._corpus = corpus
        self._set = os.path.basename(tsv_path).split('.')[0]
        self._vocab = count_vocab_size(dict_path)
        self._unit = unit
        self._unit_sub1 = unit_sub1
        self._unit_sub2 = unit_sub2

        self.is_test = is_test
        self.sort_by = sort_by
        if discourse_aware:
            assert not is_test
        if simulate_longform:
            assert is_test
        self.simulate_longform = simulate_longform

        self.subsample_factor = subsample_factor
        self.word_alignment_dir = word_alignment_dir
        self.ctc_alignment_dir = ctc_alignment_dir

        self._idx2token = []
        self._token2idx = []

        # Set index converter
        if unit in ['word', 'word_char']:
            self._idx2token += [Idx2word(dict_path)]
            self._token2idx += [
                Word2idx(dict_path, word_char_mix=(unit == 'word_char'))
            ]
        elif unit == 'wp':
            self._idx2token += [Idx2wp(dict_path, wp_model)]
            self._token2idx += [Wp2idx(dict_path, wp_model)]
        elif unit in ['char']:
            self._idx2token += [Idx2char(dict_path)]
            self._token2idx += [Char2idx(dict_path, nlsyms=nlsyms)]
        elif 'phone' in unit:
            self._idx2token += [Idx2phone(dict_path)]
            self._token2idx += [Phone2idx(dict_path)]
        else:
            raise ValueError(unit)

        # Index converters for the sub tasks (explicit tuples instead of
        # looking the arguments up by name through locals())
        sub_dicts = (
            (1, dict_path_sub1, unit_sub1, wp_model_sub1),
            (2, dict_path_sub2, unit_sub2, wp_model_sub2),
        )
        for i, dict_path_sub, unit_sub, wp_model_sub in sub_dicts:
            if not dict_path_sub:
                setattr(self, '_vocab_sub' + str(i), -1)
                continue

            setattr(self, '_vocab_sub' + str(i),
                    count_vocab_size(dict_path_sub))
            if not unit_sub:
                continue
            if unit_sub == 'wp':
                self._idx2token += [Idx2wp(dict_path_sub, wp_model_sub)]
                self._token2idx += [Wp2idx(dict_path_sub, wp_model_sub)]
            elif unit_sub == 'char':
                self._idx2token += [Idx2char(dict_path_sub)]
                self._token2idx += [Char2idx(dict_path_sub, nlsyms=nlsyms)]
            elif 'phone' in unit_sub:
                self._idx2token += [Idx2phone(dict_path_sub)]
                self._token2idx += [Phone2idx(dict_path_sub)]
            else:
                raise ValueError(unit_sub)

        # Load dataset tsv file
        columns = [
            'utt_id', 'speaker', 'feat_path', 'xlen', 'xdim', 'text',
            'token_id', 'ylen', 'ydim'
        ]
        df = pd.read_csv(tsv_path, encoding='utf-8', delimiter='\t')
        df = df.loc[:, columns]
        for i, tsv_path_sub in ((1, tsv_path_sub1), (2, tsv_path_sub2)):
            if tsv_path_sub:
                df_sub = pd.read_csv(tsv_path_sub,
                                     encoding='utf-8',
                                     delimiter='\t')
                setattr(self, 'df_sub' + str(i), df_sub.loc[:, columns])
            else:
                setattr(self, 'df_sub' + str(i), None)
        # The feature dimension is read from the first utterance
        self._input_dim = kaldiio.load_mat(df['feat_path'][0]).shape[-1]

        # Remove inappropriate utterances
        # (vectorized column masks instead of a Python call per row)
        n_utts = len(df)
        print('Original utterance num: %d' % n_utts)
        if is_test or discourse_aware:
            df = df[df['ylen'] > 0]
            print('Removed %d empty utterances' % (n_utts - len(df)))
            if first_n_utterances > 0:
                # NOTE: truncate() works on index labels, so rows whose
                # original line number is below first_n_utterances are kept
                df = df.truncate(before=0, after=first_n_utterances - 1)
                print('Select first %d utterances' % len(df))
        else:
            df = df[(df['xlen'] >= min_n_frames) &
                    (df['xlen'] <= max_n_frames)]
            df = df[df['ylen'] > 0]
            print('Removed %d utterances (threshold)' % (n_utts - len(df)))

            if ctc and subsample_factor > 1:
                n_utts = len(df)
                df = df[df['ylen'] <= (df['xlen'] // subsample_factor)]
                print('Removed %d utterances (for CTC)' % (n_utts - len(df)))

            sub_ctcs = (
                (1, ctc_sub1, subsample_factor_sub1),
                (2, ctc_sub2, subsample_factor_sub2),
            )
            for i, ctc_sub, subsample_factor_sub in sub_ctcs:
                df_sub = getattr(self, 'df_sub' + str(i))
                if df_sub is None:
                    continue

                if ctc_sub and subsample_factor_sub > 1:
                    df_sub = df_sub[df_sub['ylen'] <=
                                    (df_sub['xlen'] // subsample_factor_sub)]

                if len(df) != len(df_sub):
                    n_utts = len(df)
                    df = df.drop(df.index.difference(df_sub.index))
                    print('Removed %d utterances (for CTC, sub%d)' %
                          (n_utts - len(df), i))
                    # Keep every sub task loaded so far aligned with df
                    for j in range(1, i + 1):
                        name = 'df_sub' + str(j)
                        df_sub_j = getattr(self, name)
                        if df_sub_j is None:
                            # e.g. only the 2nd sub task is given
                            continue
                        setattr(
                            self, name,
                            df_sub_j.drop(
                                df_sub_j.index.difference(df.index)))

        # NOTE: the speaker ID is used as the session ID for every corpus
        # (for swbd, the conversation-level ID str(x).split('-')[0] would be
        # needed to serialize both sides of a conversation)
        df = df.assign(session=df['speaker'].apply(str))

        # Sort tsv records
        if discourse_aware:
            # Sort by onset (start time)
            df = df.assign(prev_utt='')
            df = df.assign(line_no=list(range(len(df))))
            if corpus == 'swbd':
                df['onset'] = df['utt_id'].apply(
                    lambda x: int(x.split('_')[-1].split('-')[0]))
            elif corpus == 'csj':
                df['onset'] = df['utt_id'].apply(
                    lambda x: int(x.split('_')[1]))
            elif corpus == 'tedlium2':
                df['onset'] = df['utt_id'].apply(
                    lambda x: int(x.split('-')[-2]))
            else:
                raise NotImplementedError(corpus)
            df = df.sort_values(by=['session', 'onset'], ascending=True)

            # Extract previous utterances
            groups = df.groupby('session').groups
            df['prev_utt'] = df.apply(lambda x: [
                df.loc[i, 'line_no'] for i in groups[x['session']]
                if df.loc[i, 'onset'] < x['onset']
            ],
                                      axis=1)
            df['n_prev_utt'] = df.apply(lambda x: len(x['prev_utt']), axis=1)
            df['n_utt_in_session'] = df.apply(
                lambda x: len(groups[x['session']]), axis=1)
            df = df.sort_values(by=['n_utt_in_session'], ascending=short2long)

        elif not is_test:
            if sort_by == 'input':
                df = df.sort_values(by=['xlen'], ascending=short2long)
            elif sort_by == 'output':
                df = df.sort_values(by=['ylen'], ascending=short2long)
            elif sort_by == 'shuffle':
                # self.df is not assigned yet at this point, so the
                # permutation must be drawn from the local frame
                df = df.reindex(np.random.permutation(df.index))

        # Fit word alignment to vocabulary
        if word_alignment_dir is not None:
            alignment2boundary = WordAlignmentConverter(dict_path, wp_model)
            n_utts = len(df)
            df['trigger_points'] = df.apply(lambda x: alignment2boundary(
                word_alignment_dir, x['speaker'], x['utt_id'], x['text']),
                                            axis=1)
            # remove utterances which do not have the alignment
            df = df[df.apply(lambda x: x['trigger_points'] is not None,
                             axis=1)]
            print('Removed %d utterances (for word alignment)' %
                  (n_utts - len(df)))
        elif ctc_alignment_dir is not None:
            n_utts = len(df)
            df['trigger_points'] = df.apply(lambda x: load_ctc_alignment(
                ctc_alignment_dir, x['speaker'], x['utt_id']),
                                            axis=1)
            # remove utterances which do not have the alignment
            df = df[df.apply(lambda x: x['trigger_points'] is not None,
                             axis=1)]
            print('Removed %d utterances (for CTC alignment)' %
                  (n_utts - len(df)))

        # Re-indexing
        if discourse_aware:
            # Keep the original index labels
            self.df = df
            for i in range(1, 3):
                if getattr(self, 'df_sub' + str(i)) is not None:
                    setattr(self, 'df_sub' + str(i),
                            getattr(self, 'df_sub' + str(i)).reindex(df.index))
        else:
            self.df = df.reset_index()
            for i in range(1, 3):
                if getattr(self, 'df_sub' + str(i)) is not None:
                    setattr(
                        self, 'df_sub' + str(i),
                        getattr(self, 'df_sub' + str(i)).reindex(
                            df.index).reset_index())
Example #3
0
def main(args):
    """Train a language model and return the directory it was saved to.

    The function seeds torch, builds the train/dev/eval datasets, builds the
    model, optimizer and learning rate scheduler, optionally restores a
    checkpoint, and then alternates training epochs with perplexity
    evaluation on the dev and eval sets.

    Args:
        args: configuration object. It is accessed both through attributes
            (e.g. `args.seed`) and through dict-style calls (`args.items()`,
            `args.get(...)`), and it is mutated in place (`vocab`,
            `save_path`, `use_apex`, `wandb_id`, and every key of the saved
            conf file when resuming).

    Returns:
        the value of `args.save_path` (directory holding the logs, the conf
        file and the checkpoints)

    """

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    # Load a conf file
    # When resuming, the saved conf overrides the given arguments except for
    # 'resume' and 'local_rank'
    if args.resume:
        conf = load_config(
            os.path.join(os.path.dirname(args.resume), 'conf.yml'))
        for k, v in conf.items():
            if k not in ['resume', 'local_rank']:
                setattr(args, k, v)

    # Load dataset
    train_set = Dataset(corpus=args.corpus,
                        tsv_path=args.train_set,
                        batch_size=args.batch_size,
                        bptt=args.bptt,
                        distributed=args.distributed,
                        min_n_tokens=args.min_n_tokens,
                        shuffle=args.shuffle,
                        backward=args.backward,
                        serialize=args.serialize)
    dev_set = Dataset(corpus=args.corpus,
                      tsv_path=args.dev_set,
                      batch_size=args.batch_size,
                      bptt=args.bptt,
                      backward=args.backward,
                      serialize=args.serialize)
    eval_sets = [
        Dataset(corpus=args.corpus,
                tsv_path=s,
                batch_size=1,
                bptt=args.bptt,
                backward=args.backward,
                serialize=args.serialize) for s in args.eval_sets
    ]
    args.vocab = count_vocab_size(args.dict)

    # Set save path
    if args.resume:
        # Continue in the directory that holds the checkpoint
        args.save_path = os.path.dirname(args.resume)
        dir_name = os.path.basename(args.save_path)
    else:
        dir_name = set_lm_name(args)
        args.save_path = mkdir_join(
            args.model_save_dir,
            '_'.join(os.path.basename(args.train_set).split('.')[:-1]),
            dir_name)
        # NOTE(review): presumably gives rank 0 time to create the directory
        # before the other ranks resolve the save path -- confirm
        if args.local_rank > 0:
            time.sleep(1)
        args.save_path = set_save_path(args.save_path)  # avoid overwriting

    # Set logger
    set_logger(os.path.join(args.save_path, 'train.log'), args.stdout,
               args.local_rank)

    # Model setting
    model = build_lm(args, args.save_path)

    if not args.resume:
        # Save nlsyms, dictionary, and wp_model
        if args.nlsyms:
            shutil.copy(args.nlsyms, os.path.join(args.save_path,
                                                  'nlsyms.txt'))
        shutil.copy(args.dict, os.path.join(args.save_path, 'dict.txt'))
        if args.unit == 'wp':
            shutil.copy(args.wp_model, os.path.join(args.save_path,
                                                    'wp.model'))

        # Log every configuration entry in alphabetical order
        for k, v in sorted(args.items(), key=lambda x: x[0]):
            logger.info('%s: %s' % (k, str(v)))

        # Count total parameters
        for n in sorted(list(model.num_params_dict.keys())):
            n_params = model.num_params_dict[n]
            logger.info("%s %d" % (n, n_params))
        logger.info("Total %.2f M parameters" %
                    (model.total_parameters / 1000000))
        logger.info('torch version: %s' % str(torch.__version__))
        logger.info(model)

    # Set optimizer
    # The epoch to resume from is parsed from the text after the last '-' in
    # the checkpoint path; past convert_to_sgd_epoch, plain SGD is used
    resume_epoch = int(args.resume.split('-')[-1]) if args.resume else 0
    optimizer = set_optimizer(
        model,
        'sgd' if resume_epoch > args.convert_to_sgd_epoch else args.optimizer,
        args.lr, args.weight_decay)

    # Wrap optimizer by learning rate scheduler
    is_transformer = args.lm_type in ['transformer', 'transformer_xl']
    scheduler = LRScheduler(
        optimizer,
        args.lr,
        decay_type=args.lr_decay_type,
        decay_start_epoch=args.lr_decay_start_epoch,
        decay_rate=args.lr_decay_rate,
        decay_patient_n_epochs=args.lr_decay_patient_n_epochs,
        early_stop_patient_n_epochs=args.early_stop_patient_n_epochs,
        warmup_start_lr=args.warmup_start_lr,
        warmup_n_steps=args.warmup_n_steps,
        model_size=args.get('transformer_d_model', 0),
        factor=args.lr_factor,
        noam=args.optimizer == 'noam',
        save_checkpoints_topk=10 if is_transformer else 1)

    if args.resume:
        # Restore the last saved model
        load_checkpoint(args.resume, model, scheduler)

        # Resume between convert_to_sgd_epoch -1 and convert_to_sgd_epoch
        if resume_epoch == args.convert_to_sgd_epoch:
            scheduler.convert_to_sgd(model,
                                     args.lr,
                                     args.weight_decay,
                                     decay_type='always',
                                     decay_rate=0.5)

    # GPU setting
    # Mixed precision is enabled when train_dtype is one of the "O*" levels
    args.use_apex = args.train_dtype in ["O0", "O1", "O2", "O3"]
    amp, scaler = None, None
    if args.n_gpus >= 1:
        model.cudnn_setting(
            deterministic=((not is_transformer) and (not args.cudnn_benchmark))
            or args.cudnn_deterministic,
            benchmark=(not is_transformer) and args.cudnn_benchmark)

        # Mixed precision training setting
        if args.use_apex:
            # torch >= 1.6.0 uses the built-in GradScaler, older versions
            # fall back to the apex package
            if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
                scaler = torch.cuda.amp.GradScaler()
            else:
                from apex import amp
                model, scheduler.optimizer = amp.initialize(
                    model, scheduler.optimizer, opt_level=args.train_dtype)
                amp.init()
                if args.resume:
                    load_checkpoint(args.resume, amp=amp)

        # Split the visible devices evenly across the local processes
        n = torch.cuda.device_count() // args.local_world_size
        device_ids = list(range(args.local_rank * n,
                                (args.local_rank + 1) * n))

        torch.cuda.set_device(device_ids[0])
        model.cuda(device_ids[0])
        scheduler.cuda(device_ids[0])
        if args.distributed:
            model = DDP(model, device_ids=device_ids)
        else:
            model = CustomDataParallel(model,
                                       device_ids=list(range(args.n_gpus)))
    else:
        model = CPUWrapperLM(model)

    # Set process name
    logger.info('PID: %s' % os.getpid())
    # NOTE(review): os.uname()[1] is the node (host) name, although the log
    # label says USERNAME
    logger.info('USERNAME: %s' % os.uname()[1])
    logger.info('#GPU: %d' % torch.cuda.device_count())
    setproctitle(args.job_name if args.job_name else dir_name)

    # Set reporter
    reporter = Reporter(args, model, args.local_rank)
    args.wandb_id = reporter.wandb_id
    if args.resume:
        n_steps = scheduler.n_steps * max(
            1, args.accum_grad_n_steps // args.local_world_size)
        reporter.resume(n_steps, resume_epoch)

    # Save conf file as a yaml file
    if args.local_rank == 0:
        save_config(args, os.path.join(args.save_path, 'conf.yml'))
        # NOTE: save after reporter for wandb ID

    start_time_train = time.time()
    # NOTE: `ep` is not used in the loop body; the tests below rely on
    # reporter.n_epochs
    for ep in range(resume_epoch, args.n_epochs):
        train_one_epoch(model, train_set, dev_set, scheduler, reporter, logger,
                        args, amp, scaler)

        # Save checkpoint and validate model per epoch
        if reporter.n_epochs + 1 < args.eval_start_epoch:
            # Before eval_start_epoch: no evaluation, only lr decay and a
            # checkpoint
            scheduler.epoch()  # lr decay
            reporter.epoch()  # plot

            # Save model
            if args.local_rank == 0:
                scheduler.save_checkpoint(model,
                                          args.save_path,
                                          amp=amp,
                                          remove_old=(not is_transformer)
                                          and args.remove_old_checkpoints)
        else:
            start_time_eval = time.time()
            # dev
            model.module.reset_length(args.bptt)
            ppl_dev, _ = eval_ppl([model.module],
                                  dev_set,
                                  batch_size=1,
                                  bptt=args.bptt)
            model.module.reset_length(args.bptt)
            scheduler.epoch(ppl_dev)  # lr decay
            reporter.epoch(ppl_dev, name='perplexity')  # plot
            reporter.add_scalar('dev/perplexity', ppl_dev)
            logger.info('PPL (%s, ep:%d): %.2f' %
                        (dev_set.set, reporter.n_epochs, ppl_dev))

            # The eval sets are scored only when a checkpoint is saved
            if scheduler.is_topk or is_transformer:
                # Save model
                if args.local_rank == 0:
                    scheduler.save_checkpoint(model,
                                              args.save_path,
                                              amp=amp,
                                              remove_old=(not is_transformer)
                                              and args.remove_old_checkpoints)

                # test
                ppl_test_avg = 0.
                for eval_set in eval_sets:
                    model.module.reset_length(args.bptt)
                    ppl_test, _ = eval_ppl([model.module],
                                           eval_set,
                                           batch_size=1,
                                           bptt=args.bptt)
                    model.module.reset_length(args.bptt)
                    logger.info('PPL (%s, ep:%d): %.2f' %
                                (eval_set.set, reporter.n_epochs, ppl_test))
                    ppl_test_avg += ppl_test
                if len(eval_sets) > 0:
                    logger.info(
                        'PPL (avg., ep:%d): %.2f' %
                        (reporter.n_epochs, ppl_test_avg / len(eval_sets)))

            logger.info('Evaluation time: %.2f min' %
                        ((time.time() - start_time_eval) / 60))

            # Early stopping
            if scheduler.is_early_stop:
                break

            # Convert to fine-tuning stage
            if reporter.n_epochs == args.convert_to_sgd_epoch:
                scheduler.convert_to_sgd(model,
                                         args.lr,
                                         args.weight_decay,
                                         decay_type='always',
                                         decay_rate=0.5)

        if reporter.n_epochs >= args.n_epochs:
            break

    logger.info('Total time: %.2f hour' %
                ((time.time() - start_time_train) / 3600))
    reporter.close()

    return args.save_path