def __init__(self, config, paras, id2accent):
    super(FOMetaASRInterface, self).__init__(config, paras, id2accent)

    assert paras.meta_k is not None
    self.meta_k = paras.meta_k
    if paras.meta_batch_size is None:
        logger.log("Meta batch_size not set...", prefix='info')
        self.meta_batch_size = self.num_pretrain
    else:
        self.meta_batch_size = paras.meta_batch_size

    self.asr_model = None
    self.asr_opt = None
    self.max_step = paras.max_step if paras.max_step > 0 else config['solver']['total_steps']

    self.dashboard.set_status('pretraining')
    self._train = partial(self.run_batch, train=True)
    self._eval = partial(self.run_batch, train=False)

    self._updates = None
    self._counter = 0

    # Max. of the lr in the noam schedule (reached at step == warmup_steps)
    d_model = config['asr_model']['d_model']
    opt_k = config['asr_model']['meta']['optimizer_opt']['k']
    warmup_steps = config['asr_model']['meta']['optimizer_opt']['warmup_steps']
    self.inner_lr = d_model ** (-0.5) * opt_k * warmup_steps ** (-0.5)

    logger.notice("Meta Interface Information")
    logger.log(f"Sampling strategy: {self.sample_strategy}", prefix='info')
    logger.log(f"Meta batch size  : {self.meta_batch_size}", prefix='info')
    logger.log(f"# inner-loop step: {self.meta_k}", prefix='info')
    logger.log(f"Inner loop lr    : {round(self.inner_lr, 5)}", prefix='info')
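# For reference: `self.inner_lr` above is the peak value of the noam
# (Transformer) learning-rate schedule, reached at step == warmup_steps.
# A minimal sketch of that schedule, for illustration only (the name
# `noam_lr` is not part of this codebase):
def noam_lr(step, d_model, k, warmup_steps):
    """Noam schedule: linear warm-up, then inverse-square-root decay."""
    return k * d_model ** (-0.5) * min(step ** (-0.5), step * warmup_steps ** (-1.5))

# At step == warmup_steps both terms of the min() coincide, giving the maximum
#   k * d_model ** (-0.5) * warmup_steps ** (-0.5)
# which is exactly the value stored in `self.inner_lr`.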
def load_data(self):
    logger.notice(f"Loading data from {self.data_dir}")
    if self.model_name == 'blstm':
        self.id2ch = [BLANK_SYMBOL]
    elif self.model_name == 'transformer':
        self.id2ch = [SOS_SYMBOL]
    else:
        raise NotImplementedError

    with open(self.config['solver']['spm_mapping']) as fin:
        for line in fin.readlines():
            self.id2ch.append(line.rstrip().split(' ')[0])
    self.id2ch.append(EOS_SYMBOL)
    logger.log(f"Train units: {self.id2ch}")

    setattr(self, 'eval_set', get_loader(
        self.data_dir.joinpath('test'),
        # self.data_dir.joinpath('dev'),
        batch_size=self.batch_size,
        half_batch_ilen=512 if self.batch_size > 1 else None,
        is_memmap=self.is_memmap,
        is_bucket=False,
        shuffle=False,
        num_workers=1,
    ))
def set_model(self):
    logger.notice(f"Load trained ASR model from {self.model_path}")
    if self.model_name == 'blstm':
        from src.model.blstm.mono_blstm import MonoBLSTM as ASRModel
    elif self.model_name == 'las':
        from src.model.seq2seq.mono_las import MonoLAS as ASRModel
    elif self.model_name == 'transformer':
        from src.model.transformer_pytorch.mono_transformer_torch import MyTransformer as ASRModel
    else:
        raise NotImplementedError

    self.asr_model = ASRModel(self.id2ch, self.config['asr_model'])
    if self.use_gpu:
        self.asr_model.load_state_dict(torch.load(self.model_path))
        self.asr_model = self.asr_model.cuda()
    else:
        self.asr_model.load_state_dict(
            torch.load(self.model_path, map_location=torch.device('cpu')))
    self.asr_model.eval()
    logger.log(f"ASR model device {self.asr_model.device}", prefix='debug')

    self.sos_id = self.asr_model.sos_id
    self.eos_id = self.asr_model.eos_id
    self.blank_id = None if self.model_name != 'blstm' else self.asr_model.blank_id

    #FIXME: will merge it later
    if self.decode_mode != 'greedy':
        if self.model_name == 'blstm':
            raise NotImplementedError
        elif self.model_name == 'transformer':
            raise NotImplementedError
        else:
            raise NotImplementedError
def load_data(self):
    logger.notice(
        f"Loading data from {self.data_dir} with {self.paras.njobs} threads")
    #TODO: combine the following with Metric
    self.id2ch = self.id2units

    setattr(self, 'train_set', get_loader(
        self.data_dir.joinpath('train'),
        batch_size=self.config['solver']['batch_size'],
        min_ilen=self.config['solver']['min_ilen'],
        max_ilen=self.config['solver']['max_ilen'],
        half_batch_ilen=self.config['solver']['half_batch_ilen'],
        # bucket_reverse=True,
        bucket_reverse=False,
        is_memmap=self.is_memmap,
        is_bucket=self.is_bucket,
        num_workers=self.paras.njobs,
        # shuffle=False,  # debug
    ))
    setattr(self, 'dev_set', get_loader(
        self.data_dir.joinpath('dev'),
        batch_size=self.config['solver']['dev_batch_size'],
        is_memmap=self.is_memmap,
        is_bucket=False,
        shuffle=False,
        num_workers=self.paras.njobs,
    ))
def get_loader(data_dir, batch_size, is_memmap, is_bucket, num_workers=0,
               min_ilen=None, max_ilen=None, half_batch_ilen=None,
               bucket_reverse=False, shuffle=True, read_file=False,
               drop_last=False, pin_memory=True):
    assert not read_file, "Loading from Kaldi ark has not been implemented yet"
    dset = ESPDataset(data_dir, is_memmap)

    # If the data is already loaded in memory, extra workers only add overhead
    if not is_memmap:
        num_workers = 0
    logger.notice(f"Loading data from {data_dir} with {num_workers} threads")

    if is_bucket:
        my_sampler = BucketSampler(dset.ilens,
                                   min_ilen=min_ilen,
                                   max_ilen=max_ilen,
                                   half_batch_ilen=half_batch_ilen,
                                   batch_size=batch_size,
                                   bucket_size=BUCKET_SIZE,
                                   bucket_reverse=bucket_reverse,
                                   drop_last=drop_last)
        # batch_size/drop_last are handled by the sampler; passing them again
        # to DataLoader alongside batch_sampler is disallowed by PyTorch
        loader = DataLoader(dset,
                            num_workers=num_workers,
                            collate_fn=collate_fn,
                            batch_sampler=my_sampler,
                            pin_memory=pin_memory)
    else:
        loader = DataLoader(dset,
                            batch_size=batch_size,
                            num_workers=num_workers,
                            collate_fn=collate_fn,
                            shuffle=shuffle,
                            drop_last=drop_last,
                            pin_memory=pin_memory)
    return loader
def get_trainer(cls, config, paras, id2accent):
    logger.notice("LAS Trainer Init...")

    class LASTrainer(cls):
        def __init__(self, config, paras, id2accent):
            super(LASTrainer, self).__init__(config, paras, id2accent)

        def set_model(self):
            self.asr_model = MonoLAS(self.id2ch, self.config['asr_model']).cuda()
            self.seq_loss = nn.CrossEntropyLoss(ignore_index=IGNORE_ID, reduction='none')
            self.sos_id = self.asr_model.sos_id
            self.eos_id = self.asr_model.eos_id

            assert self.config['asr_model']['optimizer']['type'] == 'adadelta', \
                "Use AdaDelta for pure seq2seq"
            self.asr_opt = getattr(torch.optim,
                                   self.config['asr_model']['optimizer']['type'])
            self.asr_opt = self.asr_opt(self.asr_model.parameters(),
                                        **self.config['asr_model']['optimizer_opt'])
            # self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            #     self.asr_opt, mode='min', factor=0.2, patience=3, verbose=True)
            super().load_model()

        def exec(self):
            self.train()

        def run_batch(self, cur_b, x, ilens, ys, olens, train):
            sos = ys[0].new([self.sos_id])
            eos = ys[0].new([self.eos_id])
            ys_out = [torch.cat([sos, y, eos], dim=0) for y in ys]
            # NOTE: the forward pass producing `pred`, `loss` and `acc` is
            # elided in this excerpt.
            if train:
                info = {'loss': loss.item(), 'acc': acc}
                # if self.global_step % 5 == 0:
                if self.global_step % 500 == 0:
                    self.probe_model(pred, ys)
                self.asr_opt.zero_grad()
                loss.backward()
            else:
                wer = self.metric_observer.batch_cal_wer(
                    pred.detach(), ys, ['att'])['att']
                info = {'wer': wer, 'loss': loss.item(), 'acc': acc}
            return info

        def probe_model(self, pred, ys):
            self.metric_observer.cal_att_wer(torch.argmax(pred[0], dim=-1), ys[0], show=True)

    return LASTrainer(config, paras, id2accent)
def load_model(self):
    logger.log("MultiASR model for pretraining initialization")

    if self.paras.resume:
        logger.notice(f"Resume pretraining from {self.global_step}")
        self.asr_model.load_state_dict(torch.load(self.resume_model_path))
        self.dashboard.set_step(self.global_step)
    else:
        logger.notice(f"Start pretraining from {self.global_step}")
def save_best_model(self, tpe='wer', only_stat=False):
    assert self.asr_model is not None
    if not only_stat:
        model_save_path = self.log_dir.joinpath(f'model.{tpe}.best')
        logger.notice('Current best {}: {:.3f}, save model to {}'.format(
            tpe.upper(), getattr(self, f'best_{tpe}'), model_save_path))
        torch.save(self.asr_model.state_dict(), model_save_path)

    with open(self.log_dir.joinpath(f'best_{tpe}'), 'w') as fout:
        print('{} {}'.format(self.global_step, getattr(self, f'best_{tpe}')), file=fout)
def train(self):
    try:
        while self.global_step < self.max_step:
            tbar = get_bar(total=self.eval_ival,
                           desc=f"Step {self.global_step}", leave=True)

            for _ in range(self.eval_ival):
                #TODO: we can add a sampling method so Meta and Multi are compared fairly
                idx, (x, ilens, ys, olens) = self.data_container.get_item()[0]
                batch_size = len(ys)
                info = self._train(idx, x, ilens, ys, olens, accent_idx=idx)
                self.train_info.add(info, batch_size)

                grad_norm = nn.utils.clip_grad_norm_(
                    self.asr_model.parameters(), GRAD_CLIP)
                if math.isnan(grad_norm):
                    logger.warning(f"grad norm NaN @ step {self.global_step}")
                else:
                    self.asr_opt.step()

                if isinstance(self.asr_opt, TransformerOptimizer):
                    self.log_msg(self.asr_opt.lr)
                else:
                    self.log_msg()
                self.check_evaluate()

                self.global_step += 1
                self.dashboard.step()

                del x, ilens, ys, olens
                tbar.update(1)

            if self.global_step % self.save_ival == 0:
                self.save_per_steps()
            self.dashboard.check()
            tbar.close()

    except KeyboardInterrupt:
        logger.warning("Pretraining stopped")
        self.save_per_steps()
        self.dashboard.set_status('pretrained(SIGINT)')
    else:
        logger.notice("Pretraining completed")
        self.dashboard.set_status('pretrained')
def train(self):
    self.evaluate()
    try:
        if self.save_verbose:
            self.save_init()

        while self.ep < self.max_epoch:
            tbar = get_bar(total=len(self.train_set),
                           desc=f"Epoch {self.ep}", leave=True)

            for cur_b, (x, ilens, ys, olens) in enumerate(self.train_set):
                batch_size = len(ys)
                info = self._train(cur_b, x, ilens, ys, olens)
                self.train_info.add(info, batch_size)

                grad_norm = nn.utils.clip_grad_norm_(
                    self.asr_model.parameters(), GRAD_CLIP)
                if math.isnan(grad_norm):
                    logger.warning(f"grad norm NaN @ step {self.global_step}")
                else:
                    self.asr_opt.step()

                if isinstance(self.asr_opt, TransformerOptimizer):
                    self.log_msg(self.asr_opt.lr)
                else:
                    self.log_msg()
                self.check_evaluate()

                self.global_step += 1
                self.dashboard.step()

                del x, ilens, ys, olens
                tbar.update(1)

            self.ep += 1
            self.save_per_epoch()
            self.dashboard.check()
            tbar.close()

            if self.eval_every_epoch:
                self.evaluate()

    except KeyboardInterrupt:
        logger.warning("Training stopped")
        self.evaluate()
        self.dashboard.set_status('trained(SIGINT)')
    else:
        logger.notice("Training completed")
        self.dashboard.set_status('trained')
def train(self):
    try:
        task_ids = list(range(self.num_pretrain))
        while self.global_step < self.max_step:
            tbar = get_bar(total=self.eval_ival,
                           desc=f"Step {self.global_step}", leave=True)

            for _ in range(self.eval_ival):
                shuffle(task_ids)
                #FIXME: split into inner-train and inner-test sets here
                #       (should observe whether the performance drops)
                for accent_id in task_ids[:self.meta_batch_size]:
                    # inner-loop learning
                    tr_batches = self.data_container.get_item(accent_id, self.meta_k)
                    self.run_task(tr_batches)

                    # inner-loop test
                    val_batch = self.data_container.get_item(accent_id)[0]
                    batch_size = len(val_batch[1][2])
                    info = self._train(val_batch[0], *val_batch[1],
                                       accent_idx=val_batch[0])

                    grad_norm = nn.utils.clip_grad_norm_(
                        self.asr_model.parameters(), GRAD_CLIP)
                    if math.isnan(grad_norm):
                        logger.warning(
                            f"grad norm NaN @ step {self.global_step} on "
                            f"{self.accents[accent_id]}, ignore...")
                    self._partial_meta_update()

                    del val_batch
                    self.train_info.add(info, batch_size)

                self._final_meta_update()
                self.log_msg(self.meta_opt.lr)
                self.check_evaluate()

                self.global_step += 1
                self.dashboard.step()
                tbar.update(1)

            if self.global_step % self.save_ival == 0:
                self.save_per_steps()
            self.dashboard.check()
            tbar.close()

    except KeyboardInterrupt:
        logger.warning("Pretraining stopped")
        self.save_per_steps()
        self.dashboard.set_status('pretrained(SIGINT)')
    else:
        logger.notice("Pretraining completed")
        self.dashboard.set_status('pretrained')
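# Note: `run_task`, `_partial_meta_update` and `_final_meta_update` are not
# shown in this excerpt. As orientation only, a hypothetical first-order
# (FOMAML-style) accumulate-then-apply pair could look like the sketch below;
# every name here (`meta_grads`, `original_params`, `num_tasks`) is
# illustrative and not taken from this codebase.
def _partial_meta_update_sketch(model, meta_grads):
    # After the inner-loop test batch has been backpropagated through the
    # adapted model, treat those gradients as first-order surrogates for
    # gradients w.r.t. the initialization and accumulate them per task.
    for name, p in model.named_parameters():
        if p.grad is not None:
            meta_grads[name] = meta_grads.get(name, 0.) + p.grad.detach().clone()

def _final_meta_update_sketch(model, original_params, meta_grads, meta_opt, num_tasks):
    # Restore the maintained initialization, attach the averaged per-task
    # gradients, and let the outer-loop optimizer take a single step
    # (assuming meta_opt optimizes the model's parameters).
    model.load_state_dict({k: v.detach().clone() for k, v in original_params.items()})
    meta_opt.zero_grad()
    for name, p in model.named_parameters():
        if name in meta_grads:
            p.grad = meta_grads[name] / num_tasks
    meta_opt.step()
    meta_grads.clear()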
def load_model(self):
    logger.log("ASR model initialization")

    if self.paras.resume:
        logger.notice(
            f"Resume training from epoch {self.ep} "
            f"(best cer: {self.best_cer}, best wer: {self.best_wer})")
        self.asr_model.load_state_dict(torch.load(self.resume_model_path))
        if isinstance(self.asr_opt, TransformerOptimizer):
            with open(self.optimizer_path, 'rb') as fin:
                self.asr_opt = pickle.load(fin)
        else:
            self.asr_opt.load_state_dict(torch.load(self.optimizer_path))
        self.dashboard.set_step(self.global_step)
    elif self.paras.pretrain:
        model_dict = self.asr_model.state_dict()
        logger.notice(
            f"Load pretrained modules ({','.join(self.pretrain_module)}) "
            f"from {self.pretrain_model_path}")
        pretrain_dict = self.filter_model(torch.load(self.pretrain_model_path))
        model_dict.update(pretrain_dict)
        self.asr_model.load_state_dict(model_dict)
        if 'freeze_module' in self.config['solver']:
            logger.warning("Part of the model will be frozen during fine-tuning")
            self.freeze_module(self.config['solver']['freeze_module'])
        logger.notice("Done!")
    else:  # simple monolingual training from step 0
        logger.notice("Training from scratch")
def get_trainer(cls, config, paras, id2accent):
    logger.notice("Transformer Trainer Init...")

    class TransformerTrainer(cls):
        def __init__(self, config, paras, id2accent):
            super(TransformerTrainer, self).__init__(config, paras, id2accent)

        def set_model(self):
            self.asr_model = Transformer(self.id2ch, self.config['asr_model']).cuda()
            self.asr_opt = optim.RAdam(self.asr_model.parameters(),
                                       betas=(0.9, 0.98), eps=1e-9)
            # self.asr_opt = TransformerOptimizer(
            #     torch.optim.Adam(self.asr_model.parameters(), betas=(0.9, 0.98), eps=1e-09),
            #     # optim.RAdam(self.asr_model.parameters()),
            #     self.config['asr_model']['optimizer_opt']['k'],
            #     self.config['asr_model']['encoder']['d_model'],
            #     self.config['asr_model']['optimizer_opt']['warmup_steps'])
            self.label_smoothing = self.config['solver']['label_smoothing']
            self.sos_id = self.asr_model.sos_id
            self.eos_id = self.asr_model.eos_id
            super().load_model()

        def exec(self):
            self.train()

        def run_batch(self, cur_b, x, ilens, ys, olens, train):
            pred, gold = self.asr_model(x, ilens, ys, olens)
            loss, acc = cal_performance(pred, gold, self.label_smoothing)

            if train:
                info = {'loss': loss.item(), 'acc': acc}
                if self.global_step % 500 == 0:
                    self.probe_model(pred[0], gold[0])
                self.asr_opt.zero_grad()
                loss.backward()
            else:
                wer = self.metric_observer.batch_cal_wer(pred.detach(), gold)
                info = {'wer': wer, 'loss': loss.item(), 'acc': acc}
            return info

        def probe_model(self, pred, ys_out):
            self.metric_observer.cal_wer(torch.argmax(pred, dim=-1), ys_out, show=True)

    return TransformerTrainer(config, paras, id2accent)
def set_model(self):
    self.asr_model = MyTransformer(self.id2ch, self.config['asr_model']).cuda()

    self.label_smooth_rate = self.config['solver']['label_smoothing']
    if self.label_smooth_rate > 0.0:
        logger.log(f"Use label smoothing rate {self.label_smooth_rate}", prefix='info')

    # self.asr_opt = optim.RAdam(self.asr_model.parameters(), betas=(0.9, 0.98), eps=1e-9)
    # if self.config['asr_model']['optimizer_cls'] == 'noam':
    if 'inner_optimizer_cls' not in self.config['asr_model']:  # multi or mono training
        if self.config['asr_model']['optimizer_cls'] == 'noam':
            logger.notice("Use noam optimizer; recommended for monolingual training")
            self.asr_opt = TransformerOptimizer(
                torch.optim.Adam(self.asr_model.parameters(),
                                 betas=(0.9, 0.98), eps=1e-09, weight_decay=1e-3),
                self.config['asr_model']['optimizer_opt']['k'],
                self.config['asr_model']['d_model'],
                self.config['asr_model']['optimizer_opt']['warmup_steps'])
        elif self.config['asr_model']['optimizer_cls'] == 'RAdam':
            logger.notice(
                f"Use third-party {self.config['asr_model']['optimizer_cls']} optimizer")
            self.asr_opt = getattr(extra_optim,
                                   self.config['asr_model']['optimizer_cls'])
            self.asr_opt = self.asr_opt(self.asr_model.parameters(),
                                        **self.config['asr_model']['optimizer_opt'])
        else:
            logger.notice(
                f"Use {self.config['asr_model']['optimizer_cls']} optimizer; "
                "recommended for fine-tuning")
            self.asr_opt = getattr(torch.optim,
                                   self.config['asr_model']['optimizer_cls'])
            self.asr_opt = self.asr_opt(self.asr_model.parameters(),
                                        **self.config['asr_model']['optimizer_opt'])
    else:
        logger.notice(
            "During meta-training, the model optimizer is reset after each task")

    self.sos_id = self.asr_model.sos_id
    self.eos_id = self.asr_model.eos_id
    super().load_model()
def exec(self):
    if self.decode_mode != 'greedy':
        logger.notice(
            f"Start decoding with beam search "
            f"(beam size: {self.config['solver']['beam_decode']['beam_size']})")
        raise NotImplementedError(f"{self.decode_mode} is not supported yet")
        self._decode = self.beam_decode
    else:
        logger.notice("Start greedy decoding")
        if self.batch_size > 1:
            dev = 'gpu' if self.use_gpu else 'cpu'
            logger.log(
                f"Number of utterance batches to decode: {len(self.eval_set)}, "
                f"decoding with batch_size {self.batch_size} using {dev}")
            self._decode = self.batch_greedy_decode
            self.njobs = 1
        else:
            logger.log(
                f"Number of utterances to decode: {len(self.eval_set)}, "
                f"decoding with {self.njobs} threads using cpu")
            self._decode = self.greedy_decode

    if self.njobs > 1:
        try:
            _ = Parallel(n_jobs=self.njobs)(
                delayed(self._decode)(i, x, ilen, y, olen)
                for i, (x, ilen, y, olen) in enumerate(self.eval_set))
            #NOTE: cannot log to comet here, since the experiment cannot be serialized
        except KeyboardInterrupt:
            logger.warning("Decoding stopped")
        else:
            logger.notice("Decoding done")
            # self.comet_exp.log_other('status', 'decoded')
    else:
        tbar = get_bar(total=len(self.eval_set), leave=True)
        for cur_b, (xs, ilens, ys, olens) in enumerate(self.eval_set):
            self.batch_greedy_decode(xs, ilens, ys, olens)
            tbar.update(1)
def load_model(self):
    logger.log("First-order GBML ASR model for pretraining initialization")

    if self.paras.resume:
        logger.notice(f"Resume pretraining from {self.global_step}")
        self.asr_model.load_state_dict(torch.load(self.resume_model_path))
        self.dashboard.set_step(self.global_step)
    else:
        logger.notice(f"Start pretraining from {self.global_step}")

    # Keep a separate copy of the initialization; the outer-loop optimizer
    # updates this copy rather than the task-adapted working model
    self._original = clone_state_dict(self.asr_model.state_dict(keep_vars=True))
    params = [p for p in self._original.values() if getattr(p, 'requires_grad', False)]

    if self.config['asr_model']['meta_opt_cls'] == 'noam':
        self.meta_opt = TransformerOptimizer(
            torch.optim.Adam(params, betas=(0.9, 0.98), eps=1e-09),
            self.config['asr_model']['meta']['optimizer_opt']['k'],
            self.config['asr_model']['d_model'],
            self.config['asr_model']['meta']['optimizer_opt']['warmup_steps'])
    else:
        raise NotImplementedError(
            "Should use the noam optimizer for outer-loop transformer learning, "
            f"but got {self.config['asr_model']['meta_opt_cls']}")
def get_loader(data_dir, batch_size, is_memmap, is_bucket, num_workers=0,
               split_rate=1.0, split_seed=531, min_ilen=None, max_ilen=None,
               half_batch_ilen=None, bucket_reverse=False, shuffle=True,
               read_file=False, drop_last=False, pin_memory=True):
    assert not read_file, "Loading from Kaldi ark has not been implemented yet"
    dset = CommonVoiceDataset(data_dir, is_memmap)

    if split_rate < 1.0:
        logger.notice(
            f"Only use {split_rate * 100}% of the data for training "
            f"(split seed: {split_seed})")
        tot_sz = len(dset)
        num_tr = int(tot_sz * split_rate)
        num_drop = tot_sz - num_tr
        dset, _ = random_split(
            dset, [num_tr, num_drop],
            generator=torch.Generator().manual_seed(split_seed))

    # If the data is already loaded in memory, extra workers only add overhead
    if not is_memmap:
        num_workers = 0
    logger.notice(f"Loading data from {data_dir} with {num_workers} threads")

    if is_bucket and split_rate == 1.0:
        my_sampler = BucketSampler(dset.ilens,
                                   min_ilen=min_ilen,
                                   max_ilen=max_ilen,
                                   half_batch_ilen=half_batch_ilen,
                                   batch_size=batch_size,
                                   bucket_size=BUCKET_SIZE,
                                   bucket_reverse=bucket_reverse,
                                   drop_last=drop_last)
        # batch_size/drop_last are handled by the sampler; passing them again
        # to DataLoader alongside batch_sampler is disallowed by PyTorch
        loader = DataLoader(dset,
                            num_workers=num_workers,
                            collate_fn=collate_fn,
                            batch_sampler=my_sampler,
                            pin_memory=pin_memory)
    else:
        logger.notice("No bucket sampling; efficiency will be slightly worse")
        loader = DataLoader(dset,
                            batch_size=batch_size,
                            num_workers=num_workers,
                            collate_fn=collate_fn,
                            shuffle=shuffle,
                            drop_last=drop_last,
                            pin_memory=pin_memory)
    return loader
def get_trainer(cls, config, paras, id2accent):
    logger.notice("Transformer Trainer Init...")

    class TransformerTrainer(cls):
        def __init__(self, config, paras, id2accent):
            super(TransformerTrainer, self).__init__(config, paras, id2accent)

        def set_model(self):
            self.asr_model = MyTransformer(self.id2ch, self.config['asr_model']).cuda()

            self.label_smooth_rate = self.config['solver']['label_smoothing']
            if self.label_smooth_rate > 0.0:
                logger.log(f"Use label smoothing rate {self.label_smooth_rate}", prefix='info')

            # self.asr_opt = optim.RAdam(self.asr_model.parameters(), betas=(0.9, 0.98), eps=1e-9)
            # if self.config['asr_model']['optimizer_cls'] == 'noam':
            if 'inner_optimizer_cls' not in self.config['asr_model']:  # multi or mono training
                if self.config['asr_model']['optimizer_cls'] == 'noam':
                    logger.notice("Use noam optimizer; recommended for monolingual training")
                    self.asr_opt = TransformerOptimizer(
                        torch.optim.Adam(self.asr_model.parameters(),
                                         betas=(0.9, 0.98), eps=1e-09),
                        self.config['asr_model']['optimizer_opt']['k'],
                        self.config['asr_model']['d_model'],
                        self.config['asr_model']['optimizer_opt']['warmup_steps'])
                elif self.config['asr_model']['optimizer_cls'] == 'RAdam':
                    logger.notice(
                        f"Use third-party {self.config['asr_model']['optimizer_cls']} optimizer")
                    self.asr_opt = getattr(extra_optim,
                                           self.config['asr_model']['optimizer_cls'])
                    self.asr_opt = self.asr_opt(self.asr_model.parameters(),
                                                **self.config['asr_model']['optimizer_opt'])
                else:
                    logger.notice(
                        f"Use {self.config['asr_model']['optimizer_cls']} optimizer; "
                        "recommended for fine-tuning")
                    self.asr_opt = getattr(torch.optim,
                                           self.config['asr_model']['optimizer_cls'])
                    self.asr_opt = self.asr_opt(self.asr_model.parameters(),
                                                **self.config['asr_model']['optimizer_opt'])
            else:
                logger.notice(
                    "During meta-training, the model optimizer is reset after each task")

            self.sos_id = self.asr_model.sos_id
            self.eos_id = self.asr_model.eos_id
            super().load_model()

        def exec(self):
            self.train()

        def run_batch(self, cur_b, x, ilens, ys, olens, train, accent_idx=None):
            # accent_idx is None -> monolingual training
            batch_size = len(ys)
            pred, gold = self.asr_model(x, ilens, ys, olens)

            pred_cat = pred.view(-1, pred.size(2))
            gold_cat = gold.contiguous().view(-1)
            non_pad_mask = gold_cat.ne(IGNORE_ID)
            n_total = non_pad_mask.sum().item()

            if self.label_smooth_rate > 0.0:
                # Label-smoothed cross entropy over non-padding positions
                eps = self.label_smooth_rate
                n_class = pred_cat.size(1)
                gold_for_scatter = gold_cat.ne(IGNORE_ID).long() * gold_cat
                one_hot = torch.zeros_like(pred_cat).scatter(
                    1, gold_for_scatter.view(-1, 1), 1)
                one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / n_class
                log_prb = F.log_softmax(pred_cat, dim=-1)
                loss = -(one_hot * log_prb).sum(dim=1)
                loss = loss.masked_select(non_pad_mask).sum() / n_total
            else:
                loss = F.cross_entropy(pred_cat, gold_cat, ignore_index=IGNORE_ID)

            pred_cat = pred_cat.detach().max(1)[1]
            n_correct = pred_cat.eq(gold_cat).masked_select(non_pad_mask).sum().item()

            if train:
                info = {'loss': loss.item(), 'acc': float(n_correct) / n_total}
                # if self.global_step % 5 == 0:
                if self.global_step % 500 == 0:
                    self.probe_model(pred.detach(), gold, accent_idx)
                self.asr_opt.zero_grad()
                loss.backward()
            else:
                cer = self.metric_observer.batch_cal_er(
                    pred.detach(), gold, ['att'], ['cer'])['att_cer']
                wer = self.metric_observer.batch_cal_er(
                    pred.detach(), gold, ['att'], ['wer'])['att_wer']
                info = {
                    'cer': cer,
                    'wer': wer,
                    'loss': loss.item(),
                    'acc': float(n_correct) / n_total,
                }
            return info

        def probe_model(self, pred, ys_out, accent_idx):
            if accent_idx is not None:
                logger.log(f"Probe on {self.accents[accent_idx]}", prefix='debug')
            self.metric_observer.cal_att_cer(torch.argmax(pred[0], dim=-1), ys_out[0],
                                             show=True, show_decode=True)
            self.metric_observer.cal_att_wer(torch.argmax(pred[0], dim=-1), ys_out[0],
                                             show=True)

    return TransformerTrainer(config, paras, id2accent)
def to_list(s):
    ret = list(map(int, s.split(' ')))
    ret = [id2units[i] for i in ret]
    return ret

def remove_accent(s):
    return unidecode.unidecode(s)

def filter(s):
    # return re.sub(r'([,.!:-?"\'])\1+', r'\1', s)
    return s.translate({ord(c): None for c in IGNORE_CH_LIST})

### Cal CER ####################################################################
logger.notice("CER calculating...")
cer = 0.0
with open(Path(decode_dir, 'best-hyp'), 'r') as hyp_ref_in:
    cnt = 0
    for line in hyp_ref_in.readlines():
        cnt += 1
        ref, hyp = line.rstrip().split('\t')
        ref = filter(remove_accent(spm.DecodePieces(to_list(ref)))).upper()
        hyp = filter(remove_accent(spm.DecodePieces(to_list(hyp)))).upper()
        cer += editdistance.eval(ref, hyp) / len(ref) * 100
cer = cer / cnt
logger.log(f"CER: {cer}", prefix='test')
comet_exp.log_other(f"cer({paras.decode_mode})", round(cer, 2))
with open(Path(decode_dir, 'cer'), 'w') as fout:
    print(cer, file=fout)
def __init__(self, config, paras, log_dir, train_type, resume=False):
    self.log_dir = log_dir
    self.expkey_f = Path(self.log_dir, 'exp_key')
    self.global_step = 1

    if resume:
        assert self.expkey_f.exists(), f"Cannot find comet exp key in {self.log_dir}"
        with open(self.expkey_f, 'r') as f:
            exp_key = f.read().strip()
        self.exp = ExistingExperiment(
            previous_experiment=exp_key,
            project_name=COMET_PROJECT_NAME,
            workspace=COMET_WORKSPACE,
            auto_output_logging=None,
            auto_metric_logging=None,
            display_summary_level=0,
        )
    else:
        self.exp = Experiment(
            project_name=COMET_PROJECT_NAME,
            workspace=COMET_WORKSPACE,
            auto_output_logging=None,
            auto_metric_logging=None,
            display_summary_level=0,
        )
        #TODO: is there a better way to do this?
        with open(self.expkey_f, 'w') as f:
            print(self.exp.get_key(), file=f)

        self.exp.log_other('seed', paras.seed)
        self.log_config(config)

        if train_type == 'evaluation':
            if paras.pretrain:
                self.exp.set_name(f"{paras.pretrain_suffix}-{paras.eval_suffix}")
                self.exp.add_tags([
                    paras.pretrain_suffix, config['solver']['setting'],
                    paras.accent, paras.algo, paras.eval_suffix
                ])
                if paras.pretrain_model_path:
                    self.exp.log_other("pretrain-model-path", paras.pretrain_model_path)
                else:
                    self.exp.log_other("pretrain-runs", paras.pretrain_runs)
                    self.exp.log_other("pretrain-setting", paras.pretrain_setting)
                    self.exp.log_other("pretrain-tgt-accent", paras.pretrain_tgt_accent)
            else:
                self.exp.set_name(paras.eval_suffix)
                self.exp.add_tags(["mono", config['solver']['setting'], paras.accent])
        else:
            self.exp.set_name(paras.pretrain_suffix)
            self.exp.log_others({
                f"accent{i}": k for i, k in enumerate(paras.pretrain_accents)
            })
            self.exp.log_other('accent', paras.tgt_accent)
            self.exp.add_tags([paras.algo, config['solver']['setting'], paras.tgt_accent])
            #TODO: need to add pretrain setting

    ## slurm-related
    hostname = os.uname()[1]
    if len(hostname.split('.')) == 2 and hostname.split('.')[1] == 'speech':
        logger.notice(f"Running on Battleship {hostname}")
        self.exp.log_other('jobid', int(os.getenv('SLURM_JOBID')))
    else:
        logger.notice(f"Running on {hostname}")
def __init__(self, config, paras, id2accent):
    self.config = config
    self.paras = paras
    self.train_type = 'evaluation'
    self.is_memmap = paras.is_memmap
    self.model_name = paras.model_name
    self.njobs = paras.njobs

    if paras.algo == 'no' and paras.pretrain_suffix is None:
        paras.pretrain_suffix = paras.eval_suffix

    ### Set paths
    cur_path = Path.cwd()
    self.data_dir = Path(config['solver']['data_root'], id2accent[paras.accent])
    self.log_dir = Path(cur_path, LOG_DIR, self.train_type,
                        config['solver']['setting'], paras.algo,
                        paras.pretrain_suffix, paras.eval_suffix,
                        id2accent[paras.accent], str(paras.runs))
    self.model_path = Path(self.log_dir, paras.test_model)
    assert self.model_path.exists(), f"{self.model_path.as_posix()} does not exist..."
    self.decode_dir = Path(self.log_dir, paras.decode_suffix)

    ### Decode settings
    self.decode_mode = paras.decode_mode
    self.beam_decode_param = config['solver']['beam_decode']
    self.batch_size = paras.decode_batch_size
    self.use_gpu = paras.cuda

    if paras.decode_mode == 'lm_beam':
        assert paras.lm_model_path is not None, \
            "In LM beam decode mode, lm_model_path should be specified"
        # assert self.model_name == 'blstm', "LM beam decode is only supported for the blstm model"
        self.lm_model_path = paras.lm_model_path
    else:
        self.lm_model_path = None

    # if paras.decode_mode == 'greedy':
    #     self._decode = self.greedy_decode
    # elif paras.decode_mode == 'beam' or paras.decode_mode == 'lm_beam':
    #     self._decode = self.beam_decode
    # else:
    #     raise NotImplementedError
    #####################################################################

    ### Resume mechanism
    if not paras.resume:
        if self.decode_dir.exists():
            assert paras.overwrite, \
                f"Path exists ({self.decode_dir}). Use --overwrite or change the decode suffix"
            # time.sleep(10)
            logger.warning('Overwrite existing directory')
            rmtree(self.decode_dir)
        self.decode_dir.mkdir(parents=True)
        self.prev_decode_step = -1
    else:
        with open(Path(self.decode_dir, 'best-hyp'), 'r') as f:
            for i, l in enumerate(f):
                pass
        self.prev_decode_step = i + 1
        logger.notice(f"Decode resume from {self.prev_decode_step}")

    ### Comet
    with open(Path(self.log_dir, 'exp_key'), 'r') as f:
        exp_key = f.read().strip()
    comet_exp = ExistingExperiment(
        previous_experiment=exp_key,
        project_name=COMET_PROJECT_NAME,
        workspace=COMET_WORKSPACE,
        auto_output_logging=None,
        auto_metric_logging=None,
        display_summary_level=0,
    )
    comet_exp.log_other('status', 'decode')
def get_trainer(cls, config, paras, id2accent):
    logger.notice("BLSTM Trainer Init...")

    class BLSTMTrainer(cls):
        def __init__(self, config, paras, id2accent):
            super(BLSTMTrainer, self).__init__(config, paras, id2accent)

        def set_model(self):
            self.asr_model = MonoBLSTM(self.id2ch, self.config['asr_model']).cuda()
            self.ctc_loss = nn.CTCLoss(blank=0, reduction='mean', zero_infinity=True)
            self.sos_id = self.asr_model.sos_id
            self.eos_id = self.asr_model.eos_id

            self.asr_opt = getattr(torch.optim,
                                   self.config['asr_model']['optimizer']['type'])
            self.asr_opt = self.asr_opt(self.asr_model.parameters(),
                                        **self.config['asr_model']['optimizer_opt'])
            self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                self.asr_opt, mode='min', factor=0.2, patience=3, verbose=True)
            super().load_model()
            self.freeze_encoder(paras.freeze_layer)

        def freeze_encoder(self, module):
            if module is not None:
                if module == 'VGG':
                    for p in self.asr_model.encoder.vgg.parameters():
                        p.requires_grad = False
                elif module == 'VGG_BLSTM':
                    for p in self.asr_model.encoder.parameters():
                        p.requires_grad = False
                else:
                    raise ValueError(f"Unknown freeze layer {module} (VGG, VGG_BLSTM)")
                logger.log(f"Freeze {' '.join(module.split('_'))} layer", prefix='info')

        def exec(self):
            self.train()

        def run_batch(self, cur_b, x, ilens, ys, olens, train):
            sos = ys[0].new([self.sos_id])
            eos = ys[0].new([self.eos_id])
            ys_out = [torch.cat([sos, y, eos], dim=0) for y in ys]
            olens += 2  # pad <sos> and <eos>
            y_true = torch.cat(ys_out)

            pred, enc_lens = self.asr_model(x, ilens)
            olens = to_device(self.asr_model, olens)
            pred = F.log_softmax(pred, dim=-1)  # (T, o_dim)
            loss = self.ctc_loss(
                pred.transpose(0, 1).contiguous(),
                y_true.cuda().to(dtype=torch.long),
                enc_lens.cpu().to(dtype=torch.long),
                olens.cpu().to(dtype=torch.long))

            if train:
                info = {'loss': loss.item()}
                # if self.global_step % 5 == 0:
                if self.global_step % 500 == 0:
                    self.probe_model(pred, ys)
                self.asr_opt.zero_grad()
                loss.backward()
            else:
                cer = self.metric_observer.batch_cal_er(
                    pred.detach(), ys, ['ctc'], ['cer'])['ctc_cer']
                wer = self.metric_observer.batch_cal_er(
                    pred.detach(), ys, ['ctc'], ['wer'])['ctc_wer']
                info = {'cer': cer, 'wer': wer, 'loss': loss.item()}
            return info

        def probe_model(self, pred, ys):
            self.metric_observer.cal_ctc_cer(torch.argmax(pred[0], dim=-1), ys[0],
                                             show=True, show_decode=True)
            self.metric_observer.cal_ctc_wer(torch.argmax(pred[0], dim=-1), ys[0],
                                             show=True)

    return BLSTMTrainer(config, paras, id2accent)