def __init__(self, tsv_path, dict_path, unit, batch_size, nlsyms=False,
             n_epochs=1e10, is_test=False, min_n_tokens=1,
             bptt=2, shuffle=False, backward=False, serialize=False,
             wp_model=None, corpus=''):
    """A class for loading a text dataset for LM training.

    Args:
        tsv_path (str): path to the dataset tsv file
        dict_path (str): path to the dictionary
        unit (str): word or wp or char or phone or word_char
        batch_size (int): size of mini-batch
        nlsyms (str): path to the non-linguistic symbols file
        n_epochs (int): total epochs for training
        is_test (bool): evaluation mode (remove empty utterances only)
        min_n_tokens (int): exclude utterances shorter than this value
        bptt (int): BPTT length
        shuffle (bool): shuffle utterances per epoch
        backward (bool): flip all text in the corpus
        serialize (bool): serialize text according to contexts in dialogue
        wp_model (str): path to the word-piece model for sentencepiece
        corpus (str): name of corpus

    """
    super(Dataset, self).__init__()

    self.epoch = 0
    self.iteration = 0
    self.offset = 0

    self.set = os.path.basename(tsv_path).split('.')[0]
    self.is_test = is_test
    self.unit = unit
    self.batch_size = batch_size
    self.bptt = bptt
    self.sos = 2
    self.eos = 2
    self.max_epoch = n_epochs
    self.shuffle = shuffle
    self.backward = backward
    self.vocab = count_vocab_size(dict_path)
    assert bptt >= 2

    self.idx2token = []
    self.token2idx = []

    # Set index converter
    if unit in ['word', 'word_char']:
        self.idx2token += [Idx2word(dict_path)]
        self.token2idx += [Word2idx(dict_path, word_char_mix=(unit == 'word_char'))]
    elif unit == 'wp':
        self.idx2token += [Idx2wp(dict_path, wp_model)]
        self.token2idx += [Wp2idx(dict_path, wp_model)]
    elif unit == 'char':
        self.idx2token += [Idx2char(dict_path)]
        self.token2idx += [Char2idx(dict_path, nlsyms=nlsyms)]
    elif 'phone' in unit:
        self.idx2token += [Idx2phone(dict_path)]
        self.token2idx += [Phone2idx(dict_path)]
    else:
        raise ValueError(unit)

    # Load dataset tsv file
    self.df = pd.read_csv(tsv_path, encoding='utf-8', delimiter='\t')
    self.df = self.df.loc[:, ['utt_id', 'speaker', 'feat_path',
                              'xlen', 'xdim', 'text', 'token_id', 'ylen', 'ydim']]

    # Remove inappropriate utterances
    print('Original utterance num: %d' % len(self.df))
    n_utts = len(self.df)
    if is_test:
        self.df = self.df[self.df.apply(lambda x: x['ylen'] > 0, axis=1)]
        print('Removed %d empty utterances' % (n_utts - len(self.df)))
    else:
        self.df = self.df[self.df.apply(lambda x: x['ylen'] >= min_n_tokens, axis=1)]
        print('Removed %d utterances (threshold)' % (n_utts - len(self.df)))

    # Sort tsv records
    if shuffle:
        assert not serialize
        self.df = self.df.reindex(np.random.permutation(self.df.index))
    elif serialize:
        assert not shuffle
        assert corpus == 'swbd'
        self.df['session'] = self.df['speaker'].apply(lambda x: str(x).split('-')[0])
        self.df['onset'] = self.df['utt_id'].apply(
            lambda x: int(x.split('_')[-1].split('-')[0]))
        self.df = self.df.sort_values(by=['session', 'onset'], ascending=True)
    else:
        self.df = self.df.sort_values(by='utt_id', ascending=True)

    # Concatenate all utterances into a single token sequence
    self.concat_ids = self.concat_utterances(self.df)
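# ---------------------------------------------------------------------------
# Usage sketch (illustrative addition, not part of the original source): how
# the LM Dataset above might be instantiated for word-piece training. All
# file paths below are hypothetical placeholders; only keyword arguments
# defined in __init__ above are used.
def _lm_dataset_usage_sketch():
    train_set = Dataset(tsv_path='data/dataset/train_wp10k.tsv',  # hypothetical path
                        dict_path='data/dict/train_wp10k.txt',    # hypothetical path
                        unit='wp',
                        batch_size=64,
                        bptt=200,                                 # truncated BPTT length
                        min_n_tokens=2,                           # drop one-token utterances
                        shuffle=True,
                        wp_model='data/dict/wp10k.model',         # hypothetical path
                        corpus='swbd')
    print('vocab size: %d' % train_set.vocab)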
def __init__(self, corpus, tsv_path, dict_path, unit, nlsyms, wp_model,
             is_test, min_n_frames, max_n_frames, sort_by, short2long,
             tsv_path_sub1, tsv_path_sub2,
             ctc, ctc_sub1, ctc_sub2,
             subsample_factor, subsample_factor_sub1, subsample_factor_sub2,
             dict_path_sub1, dict_path_sub2, unit_sub1, unit_sub2,
             wp_model_sub1, wp_model_sub2,
             discourse_aware=False, simulate_longform=False,
             first_n_utterances=-1,
             word_alignment_dir=None, ctc_alignment_dir=None):
    """Custom Dataset class.

    Args:
        corpus (str): name of corpus
        tsv_path (str): path to the dataset tsv file
        dict_path (str): path to the dictionary
        unit (str): word/wp/char/phone/word_char
        nlsyms (str): path to the non-linguistic symbols file
        wp_model (str): path to the word-piece model for sentencepiece
        is_test (bool): evaluation mode (keep all non-empty utterances)
        min_n_frames (int): exclude utterances shorter than this value
        max_n_frames (int): exclude utterances longer than this value
        sort_by (str): sorting criterion
            input: sort by input length
            output: sort by output length
            shuffle: shuffle all utterances
        short2long (bool): sort utterances from short to long (ascending order)
        ctc (bool): filter out utterances that are too short for the CTC objective
        subsample_factor (int): frame subsampling factor of the encoder,
            used together with ctc for length filtering
        discourse_aware (bool): sort in the discourse order
        simulate_longform (bool): simulate long-form utterance
        first_n_utterances (int): evaluate the first N utterances only
        word_alignment_dir (str): path to word alignment directory
        ctc_alignment_dir (str): path to CTC alignment directory

    """
    super(Dataset, self).__init__()

    self.epoch = 0

    # meta data accessed by dataloader
    self._corpus = corpus
    self._set = os.path.basename(tsv_path).split('.')[0]
    self._vocab = count_vocab_size(dict_path)
    self._unit = unit
    self._unit_sub1 = unit_sub1
    self._unit_sub2 = unit_sub2

    self.is_test = is_test
    self.sort_by = sort_by
    # if shuffle_bucket:
    #     assert sort_by in ['input', 'output']
    if discourse_aware:
        assert not is_test
    if simulate_longform:
        assert is_test
    self.simulate_longform = simulate_longform
    self.subsample_factor = subsample_factor
    self.word_alignment_dir = word_alignment_dir
    self.ctc_alignment_dir = ctc_alignment_dir

    self._idx2token = []
    self._token2idx = []

    # Set index converter
    if unit in ['word', 'word_char']:
        self._idx2token += [Idx2word(dict_path)]
        self._token2idx += [Word2idx(dict_path, word_char_mix=(unit == 'word_char'))]
    elif unit == 'wp':
        self._idx2token += [Idx2wp(dict_path, wp_model)]
        self._token2idx += [Wp2idx(dict_path, wp_model)]
    elif unit == 'char':
        self._idx2token += [Idx2char(dict_path)]
        self._token2idx += [Char2idx(dict_path, nlsyms=nlsyms)]
    elif 'phone' in unit:
        self._idx2token += [Idx2phone(dict_path)]
        self._token2idx += [Phone2idx(dict_path)]
    else:
        raise ValueError(unit)

    for i in range(1, 3):
        # Look up the sub-task arguments by name
        dict_path_sub = locals()['dict_path_sub' + str(i)]
        wp_model_sub = locals()['wp_model_sub' + str(i)]
        unit_sub = locals()['unit_sub' + str(i)]
        if dict_path_sub:
            setattr(self, '_vocab_sub' + str(i), count_vocab_size(dict_path_sub))
            # Set index converter
            if unit_sub:
                if unit_sub == 'wp':
                    self._idx2token += [Idx2wp(dict_path_sub, wp_model_sub)]
                    self._token2idx += [Wp2idx(dict_path_sub, wp_model_sub)]
                elif unit_sub == 'char':
                    self._idx2token += [Idx2char(dict_path_sub)]
                    self._token2idx += [Char2idx(dict_path_sub, nlsyms=nlsyms)]
                elif 'phone' in unit_sub:
                    self._idx2token += [Idx2phone(dict_path_sub)]
                    self._token2idx += [Phone2idx(dict_path_sub)]
                else:
                    raise ValueError(unit_sub)
        else:
            setattr(self, '_vocab_sub' + str(i), -1)

    # Load dataset tsv file
    df = pd.read_csv(tsv_path, encoding='utf-8', delimiter='\t')
    df = df.loc[:, ['utt_id', 'speaker', 'feat_path',
                    'xlen', 'xdim', 'text', 'token_id', 'ylen', 'ydim']]
    for i in range(1, 3):
        if locals()['tsv_path_sub' + str(i)]:
            df_sub = pd.read_csv(locals()['tsv_path_sub' + str(i)],
                                 encoding='utf-8', delimiter='\t')
            df_sub = df_sub.loc[:, ['utt_id', 'speaker', 'feat_path',
                                    'xlen', 'xdim', 'text', 'token_id', 'ylen', 'ydim']]
            setattr(self, 'df_sub' + str(i), df_sub)
        else:
            setattr(self, 'df_sub' + str(i), None)
    self._input_dim = kaldiio.load_mat(df['feat_path'][0]).shape[-1]

    # Remove inappropriate utterances
    print('Original utterance num: %d' % len(df))
    n_utts = len(df)
    if is_test or discourse_aware:
        df = df[df.apply(lambda x: x['ylen'] > 0, axis=1)]
        print('Removed %d empty utterances' % (n_utts - len(df)))
        if first_n_utterances > 0:
            n_utts = len(df)
            df = df[df.apply(lambda x: x['ylen'] > 0, axis=1)]
            df = df.truncate(before=0, after=first_n_utterances - 1)
            print('Select first %d utterances' % len(df))
    else:
        df = df[df.apply(lambda x: min_n_frames <= x['xlen'] <= max_n_frames, axis=1)]
        df = df[df.apply(lambda x: x['ylen'] > 0, axis=1)]
        print('Removed %d utterances (threshold)' % (n_utts - len(df)))

        if ctc and subsample_factor > 1:
            n_utts = len(df)
            df = df[df.apply(lambda x: x['ylen'] <= (x['xlen'] // subsample_factor), axis=1)]
            print('Removed %d utterances (for CTC)' % (n_utts - len(df)))

        for i in range(1, 3):
            df_sub = getattr(self, 'df_sub' + str(i))
            ctc_sub = locals()['ctc_sub' + str(i)]
            subsample_factor_sub = locals()['subsample_factor_sub' + str(i)]
            if df_sub is not None:
                if ctc_sub and subsample_factor_sub > 1:
                    df_sub = df_sub[df_sub.apply(
                        lambda x: x['ylen'] <= (x['xlen'] // subsample_factor_sub), axis=1)]

                if len(df) != len(df_sub):
                    n_utts = len(df)
                    df = df.drop(df.index.difference(df_sub.index))
                    print('Removed %d utterances (for CTC, sub%d)' % (n_utts - len(df), i))
                    for j in range(1, i + 1):
                        setattr(self, 'df_sub' + str(j),
                                getattr(self, 'df_sub' + str(j)).drop(
                                    getattr(self, 'df_sub' + str(j)).index.difference(df.index)))

    if corpus == 'swbd':
        # 1. serialize
        # df['session'] = df['speaker'].apply(lambda x: str(x).split('-')[0])
        # 2. not serialize
        df['session'] = df['speaker'].apply(lambda x: str(x))
    else:
        df['session'] = df['speaker'].apply(lambda x: str(x))

    # Sort tsv records
    if discourse_aware:
        # Sort by onset (start time)
        df = df.assign(prev_utt='')
        df = df.assign(line_no=list(range(len(df))))
        if corpus == 'swbd':
            df['onset'] = df['utt_id'].apply(lambda x: int(x.split('_')[-1].split('-')[0]))
        elif corpus == 'csj':
            df['onset'] = df['utt_id'].apply(lambda x: int(x.split('_')[1]))
        elif corpus == 'tedlium2':
            df['onset'] = df['utt_id'].apply(lambda x: int(x.split('-')[-2]))
        else:
            raise NotImplementedError(corpus)
        df = df.sort_values(by=['session', 'onset'], ascending=True)

        # Extract previous utterances
        groups = df.groupby('session').groups
        df['prev_utt'] = df.apply(
            lambda x: [df.loc[i, 'line_no'] for i in groups[x['session']]
                       if df.loc[i, 'onset'] < x['onset']], axis=1)
        df['n_prev_utt'] = df.apply(lambda x: len(x['prev_utt']), axis=1)
        df['n_utt_in_session'] = df.apply(
            lambda x: len([i for i in groups[x['session']]]), axis=1)
        df = df.sort_values(by=['n_utt_in_session'], ascending=short2long)
        # NOTE: this is used only when the LM is trained with serialize: true

    # if is_test and corpus == 'swbd':
    #     # Sort by onset
    #     df['onset'] = df['utt_id'].apply(lambda x: int(x.split('_')[-1].split('-')[0]))
    #     df = df.sort_values(by=['session', 'onset'], ascending=True)

    elif not is_test:
        if sort_by == 'input':
            df = df.sort_values(by=['xlen'], ascending=short2long)
        elif sort_by == 'output':
            df = df.sort_values(by=['ylen'], ascending=short2long)
        elif sort_by == 'shuffle':
            df = df.reindex(np.random.permutation(df.index))

    # Fit word alignment to vocabulary
    if word_alignment_dir is not None:
        alignment2boundary = WordAlignmentConverter(dict_path, wp_model)
        n_utts = len(df)
        df['trigger_points'] = df.apply(
            lambda x: alignment2boundary(word_alignment_dir, x['speaker'],
                                         x['utt_id'], x['text']), axis=1)
        # Remove utterances which do not have the alignment
        df = df[df.apply(lambda x: x['trigger_points'] is not None, axis=1)]
        print('Removed %d utterances (for word alignment)' % (n_utts - len(df)))
    elif ctc_alignment_dir is not None:
        n_utts = len(df)
        df['trigger_points'] = df.apply(
            lambda x: load_ctc_alignment(ctc_alignment_dir, x['speaker'],
                                         x['utt_id']), axis=1)
        # Remove utterances which do not have the alignment
        df = df[df.apply(lambda x: x['trigger_points'] is not None, axis=1)]
        print('Removed %d utterances (for CTC alignment)' % (n_utts - len(df)))

    # Re-indexing
    if discourse_aware:
        self.df = df
        for i in range(1, 3):
            if getattr(self, 'df_sub' + str(i)) is not None:
                setattr(self, 'df_sub' + str(i),
                        getattr(self, 'df_sub' + str(i)).reindex(df.index))
    else:
        self.df = df.reset_index()
        for i in range(1, 3):
            if getattr(self, 'df_sub' + str(i)) is not None:
                setattr(self, 'df_sub' + str(i),
                        getattr(self, 'df_sub' + str(i)).reindex(df.index).reset_index())
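# ---------------------------------------------------------------------------
# Usage sketch (illustrative addition, not part of the original source): how
# the ASR Dataset above might be constructed for a single-task setup. All
# paths are hypothetical; the sub-task inputs are disabled by passing empty
# strings, which the constructor treats as "not set".
def _asr_dataset_usage_sketch():
    train_set = Dataset(corpus='csj',
                        tsv_path='data/dataset/train_wp10k.tsv',  # hypothetical path
                        dict_path='data/dict/train_wp10k.txt',    # hypothetical path
                        unit='wp',
                        nlsyms='',
                        wp_model='data/dict/wp10k.model',         # hypothetical path
                        is_test=False,
                        min_n_frames=40, max_n_frames=2000,       # frame-length filtering
                        sort_by='input', short2long=True,
                        tsv_path_sub1='', tsv_path_sub2='',       # no sub-tasks
                        ctc=True, ctc_sub1=False, ctc_sub2=False,
                        subsample_factor=4,                       # encoder subsampling
                        subsample_factor_sub1=1, subsample_factor_sub2=1,
                        dict_path_sub1='', dict_path_sub2='',
                        unit_sub1='', unit_sub2='',
                        wp_model_sub1='', wp_model_sub2='')
    print('input dim: %d, vocab size: %d' % (train_set._input_dim, train_set._vocab))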
def main(args):

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    # Load a conf file
    if args.resume:
        conf = load_config(os.path.join(os.path.dirname(args.resume), 'conf.yml'))
        for k, v in conf.items():
            if k not in ['resume', 'local_rank']:
                setattr(args, k, v)

    # Load dataset
    train_set = Dataset(corpus=args.corpus,
                        tsv_path=args.train_set,
                        batch_size=args.batch_size,
                        bptt=args.bptt,
                        distributed=args.distributed,
                        min_n_tokens=args.min_n_tokens,
                        shuffle=args.shuffle,
                        backward=args.backward,
                        serialize=args.serialize)
    dev_set = Dataset(corpus=args.corpus,
                      tsv_path=args.dev_set,
                      batch_size=args.batch_size,
                      bptt=args.bptt,
                      backward=args.backward,
                      serialize=args.serialize)
    eval_sets = [Dataset(corpus=args.corpus,
                         tsv_path=s,
                         batch_size=1,
                         bptt=args.bptt,
                         backward=args.backward,
                         serialize=args.serialize) for s in args.eval_sets]
    args.vocab = count_vocab_size(args.dict)

    # Set save path
    if args.resume:
        args.save_path = os.path.dirname(args.resume)
        dir_name = os.path.basename(args.save_path)
    else:
        dir_name = set_lm_name(args)
        args.save_path = mkdir_join(
            args.model_save_dir,
            '_'.join(os.path.basename(args.train_set).split('.')[:-1]),
            dir_name)
        if args.local_rank > 0:
            time.sleep(1)
        args.save_path = set_save_path(args.save_path)  # avoid overwriting

    # Set logger
    set_logger(os.path.join(args.save_path, 'train.log'), args.stdout, args.local_rank)

    # Model setting
    model = build_lm(args, args.save_path)

    if not args.resume:
        # Save nlsyms, dictionary, and wp_model
        if args.nlsyms:
            shutil.copy(args.nlsyms, os.path.join(args.save_path, 'nlsyms.txt'))
        shutil.copy(args.dict, os.path.join(args.save_path, 'dict.txt'))
        if args.unit == 'wp':
            shutil.copy(args.wp_model, os.path.join(args.save_path, 'wp.model'))

        for k, v in sorted(args.items(), key=lambda x: x[0]):
            logger.info('%s: %s' % (k, str(v)))

        # Count total parameters
        for n in sorted(list(model.num_params_dict.keys())):
            n_params = model.num_params_dict[n]
            logger.info("%s %d" % (n, n_params))
        logger.info("Total %.2f M parameters" % (model.total_parameters / 1000000))
        logger.info('torch version: %s' % str(torch.__version__))
        logger.info(model)

    # Set optimizer
    resume_epoch = int(args.resume.split('-')[-1]) if args.resume else 0
    optimizer = set_optimizer(
        model,
        'sgd' if resume_epoch > args.convert_to_sgd_epoch else args.optimizer,
        args.lr, args.weight_decay)

    # Wrap optimizer by learning rate scheduler
    is_transformer = args.lm_type in ['transformer', 'transformer_xl']
    scheduler = LRScheduler(
        optimizer, args.lr,
        decay_type=args.lr_decay_type,
        decay_start_epoch=args.lr_decay_start_epoch,
        decay_rate=args.lr_decay_rate,
        decay_patient_n_epochs=args.lr_decay_patient_n_epochs,
        early_stop_patient_n_epochs=args.early_stop_patient_n_epochs,
        warmup_start_lr=args.warmup_start_lr,
        warmup_n_steps=args.warmup_n_steps,
        model_size=args.get('transformer_d_model', 0),
        factor=args.lr_factor,
        noam=args.optimizer == 'noam',
        save_checkpoints_topk=10 if is_transformer else 1)

    if args.resume:
        # Restore the last saved model
        load_checkpoint(args.resume, model, scheduler)

        # Resume between convert_to_sgd_epoch - 1 and convert_to_sgd_epoch
        if resume_epoch == args.convert_to_sgd_epoch:
            scheduler.convert_to_sgd(model, args.lr, args.weight_decay,
                                     decay_type='always', decay_rate=0.5)

    # GPU setting
    args.use_apex = args.train_dtype in ["O0", "O1", "O2", "O3"]
    amp, scaler = None, None
    if args.n_gpus >= 1:
        model.cudnn_setting(
            deterministic=((not is_transformer) and (not args.cudnn_benchmark)) or args.cudnn_deterministic,
            benchmark=(not is_transformer) and args.cudnn_benchmark)

        # Mixed precision training setting
        if args.use_apex:
            if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
                scaler = torch.cuda.amp.GradScaler()
            else:
                from apex import amp
                model, scheduler.optimizer = amp.initialize(
                    model, scheduler.optimizer, opt_level=args.train_dtype)
                amp.init()
                if args.resume:
                    load_checkpoint(args.resume, amp=amp)

        n = torch.cuda.device_count() // args.local_world_size
        device_ids = list(range(args.local_rank * n, (args.local_rank + 1) * n))
        torch.cuda.set_device(device_ids[0])
        model.cuda(device_ids[0])
        scheduler.cuda(device_ids[0])
        if args.distributed:
            model = DDP(model, device_ids=device_ids)
        else:
            model = CustomDataParallel(model, device_ids=list(range(args.n_gpus)))
    else:
        model = CPUWrapperLM(model)

    # Set process name
    logger.info('PID: %s' % os.getpid())
    logger.info('USERNAME: %s' % os.uname()[1])  # NOTE: os.uname()[1] is the host name
    logger.info('#GPU: %d' % torch.cuda.device_count())
    setproctitle(args.job_name if args.job_name else dir_name)

    # Set reporter
    reporter = Reporter(args, model, args.local_rank)
    args.wandb_id = reporter.wandb_id
    if args.resume:
        n_steps = scheduler.n_steps * max(1, args.accum_grad_n_steps // args.local_world_size)
        reporter.resume(n_steps, resume_epoch)

    # Save conf file as a yaml file
    if args.local_rank == 0:
        save_config(args, os.path.join(args.save_path, 'conf.yml'))
        # NOTE: save after reporter for wandb ID

    start_time_train = time.time()
    for ep in range(resume_epoch, args.n_epochs):
        train_one_epoch(model, train_set, dev_set, scheduler, reporter,
                        logger, args, amp, scaler)

        # Save checkpoint and validate model per epoch
        if reporter.n_epochs + 1 < args.eval_start_epoch:
            scheduler.epoch()  # lr decay
            reporter.epoch()  # plot

            # Save model
            if args.local_rank == 0:
                scheduler.save_checkpoint(
                    model, args.save_path, amp=amp,
                    remove_old=(not is_transformer) and args.remove_old_checkpoints)
        else:
            start_time_eval = time.time()
            # dev
            model.module.reset_length(args.bptt)
            ppl_dev, _ = eval_ppl([model.module], dev_set, batch_size=1, bptt=args.bptt)
            model.module.reset_length(args.bptt)
            scheduler.epoch(ppl_dev)  # lr decay
            reporter.epoch(ppl_dev, name='perplexity')  # plot
            reporter.add_scalar('dev/perplexity', ppl_dev)
            logger.info('PPL (%s, ep:%d): %.2f' % (dev_set.set, reporter.n_epochs, ppl_dev))

            if scheduler.is_topk or is_transformer:
                # Save model
                if args.local_rank == 0:
                    scheduler.save_checkpoint(
                        model, args.save_path, amp=amp,
                        remove_old=(not is_transformer) and args.remove_old_checkpoints)

                # test
                ppl_test_avg = 0.
                for eval_set in eval_sets:
                    model.module.reset_length(args.bptt)
                    ppl_test, _ = eval_ppl([model.module], eval_set,
                                           batch_size=1, bptt=args.bptt)
                    model.module.reset_length(args.bptt)
                    logger.info('PPL (%s, ep:%d): %.2f' %
                                (eval_set.set, reporter.n_epochs, ppl_test))
                    ppl_test_avg += ppl_test
                if len(eval_sets) > 0:
                    logger.info('PPL (avg., ep:%d): %.2f' %
                                (reporter.n_epochs, ppl_test_avg / len(eval_sets)))

            logger.info('Evaluation time: %.2f min' % ((time.time() - start_time_eval) / 60))

            # Early stopping
            if scheduler.is_early_stop:
                break

            # Convert to fine-tuning stage
            if reporter.n_epochs == args.convert_to_sgd_epoch:
                scheduler.convert_to_sgd(model, args.lr, args.weight_decay,
                                         decay_type='always', decay_rate=0.5)

        if reporter.n_epochs >= args.n_epochs:
            break

    logger.info('Total time: %.2f hour' % ((time.time() - start_time_train) / 3600))
    reporter.close()

    return args.save_path
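# ---------------------------------------------------------------------------
# Note on the resume convention (inferred from main() above, not documented
# in the original): `args.resume` is expected to end with the epoch number
# after the final '-', since resume_epoch is recovered with
# int(args.resume.split('-')[-1]). A minimal sketch with a hypothetical
# checkpoint path:
def _parse_resume_epoch_sketch():
    resume_path = 'results/swbd/lm/transformer/model.epoch-10'  # hypothetical path
    resume_epoch = int(resume_path.split('-')[-1])
    assert resume_epoch == 10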