class Speech2Speech(nn.Module):
    """Cascaded speech-to-speech translation: ASR -> NMT -> Tacotron."""

    def __init__(self, src_vocab_size, trg_vocab_size):
        super(Speech2Speech, self).__init__()
        # self.inter_speech = InterSpeech(src_vocab_size, trg_vocab_size)
        # self.transcoder = BaseEncoder(hp.nmt_h_size, hp.tts_h_size, hp.drop, hp.bidirectional, hp.rnn)
        self.asr = ASR(src_vocab_size)
        self.nmt = NMT(src_vocab_size, trg_vocab_size)
        self.taco = Tacotron(trg_vocab_size)
class MTTTS(nn.Module):
    """Cascaded machine translation -> text-to-speech model."""

    def __init__(self, src_vocab_size, trg_vocab_size):
        super(MTTTS, self).__init__()
        self.nmt = NMT(src_vocab_size, trg_vocab_size)
        self.tts = Tacotron(trg_vocab_size)

    def __call__(self, src_txt=None, trg_txt=None, trg_mel=None,
                 teacher_forcing_ratio=1., mode="normal"):
        nmt_out_txt, nmt_att_score = self.nmt(src_txt, trg_txt, teacher_forcing_ratio)
        if mode == 'normal':
            # Feed the hard (argmax) NMT token sequence into the TTS model.
            mel_out, stop_targets, tts_att_score = self.tts(nmt_out_txt.argmax(-1), trg_mel)
        elif mode == 'osamura':
            # Feed the soft (softmax) NMT output distribution into the TTS model.
            mel_out, stop_targets, tts_att_score = self.tts(F.softmax(nmt_out_txt, -1), trg_mel)
        return nmt_out_txt, mel_out, stop_targets, nmt_att_score, tts_att_score

    def get_path(self):
        return ("_".join(hp.rnn) + "_nmt_hid" + str(hp.nmt_h_size)
                + "_depth" + str(len(hp.rnn)) + "/")

    def load(self, args):
        # Restore the best ("CHAMP") checkpoints of the pre-trained NMT and Tacotron sub-models.
        nmt_path = os.path.join("./exp", "NMT", args.src_lang + "2" + args.trg_lang,
                                self.nmt.get_path(), args.segment)
        self.nmt.load_state_dict(torch.load(os.path.join(nmt_path, "CHAMP", "NMT")))
        tts_path = os.path.join("./exp", "Tacotron", args.trg_lang,
                                self.tts.get_path(), args.segment)
        self.tts.load_state_dict(torch.load(os.path.join(tts_path, "CHAMP", "Tacotron")))
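# Illustrative usage sketch (not part of the original file). It assumes `model`
# is an MTTTS instance and that padded token-id tensors and a target mel
# spectrogram come from the repo's dataset code; the argument names and shapes
# below are assumptions, not guarantees.
def _example_mttts_step(model, src_ids, trg_ids, trg_mel):
    # One teacher-forced pass: NMT translates, then Tacotron synthesizes from
    # the soft ("osamura") NMT output distribution.
    nmt_out, mel_out, stop_targets, nmt_att, tts_att = model(
        src_txt=src_ids, trg_txt=trg_ids, trg_mel=trg_mel,
        teacher_forcing_ratio=1., mode="osamura")
    return nmt_out, mel_out, stop_targets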
class InterSpeech(nn.Module):
    """Cascaded ASR -> transcoder -> NMT model for speech translation."""

    def __init__(self, src_vocab_size, trg_vocab_size):
        super(InterSpeech, self).__init__()
        self.asr = ASR(src_vocab_size)
        self.nmt = NMT(src_vocab_size, trg_vocab_size)
        # Maps the ASR decoder context into the NMT decoder's hidden space.
        self.transcoder = BaseEncoder(hp.asr_h_size, hp.nmt_h_size, hp.drop,
                                      hp.bidirectional, hp.rnn)

    def __call__(self, mel, src_txt=None, trg_txt=None,
                 teacher_forcing_ratio=1., phase=2):
        # phase 1: gradients flow through ASR and the transcoder, NMT stays frozen;
        # phase 2: the full cascade is trained; otherwise only the transcoder is updated.
        with torch.set_grad_enabled(phase in {1, 2} and self.training):
            asr_enc_out = self.asr.encode(mel, skip_step=2)
            asr_out_txt, asr_att_score, context = self.asr.decode(
                src_txt, asr_enc_out, teacher_forcing_ratio)
        with torch.set_grad_enabled(self.training):
            trans_out = self.transcoder(context)
        with torch.set_grad_enabled(phase == 2 and self.training):
            out_txt, att_score, _ = self.nmt.decode(trg_txt, trans_out,
                                                    teacher_forcing_ratio)
        return trans_out, out_txt, att_score

    def get_path(self):
        return ("_".join(hp.rnn) + "_asr_hid" + str(hp.asr_h_size)
                + "_nmt_hid" + str(hp.nmt_h_size)
                + "_depth" + str(len(hp.rnn)) + "/")

    def load(self, args):
        # Restore the best ("CHAMP") checkpoints of the pre-trained ASR and NMT sub-models.
        asr_path = os.path.join("./exp", "ASR", args.src_lang,
                                self.asr.get_path(), args.segment)
        self.asr.load_state_dict(
            torch.load(os.path.join(asr_path, "CHAMP", "ASR")))
        nmt_path = os.path.join("./exp", "NMT",
                                args.src_lang + "2" + args.trg_lang,
                                self.nmt.get_path(), args.segment)
        self.nmt.load_state_dict(
            torch.load(os.path.join(nmt_path, "CHAMP", "NMT")))

    def encode_wo_dump(self, mel, src_txt=None, trg_txt=None,
                       teacher_forcing_ratio=1.):
        # Full pass without gradients; also returns the decoded hypotheses.
        with torch.no_grad():
            asr_enc_out = self.asr.encode(mel, skip_step=2)
            asr_out_txt, asr_att_score, context = self.asr.decode(
                src_txt, asr_enc_out, teacher_forcing_ratio)
            trans_out = self.transcoder(context)
            out_txt, att_score, result = self.nmt.decode(
                trg_txt, trans_out, teacher_forcing_ratio)
        return asr_out_txt, out_txt, result

    def encode(self, mel, src_txt=None, trg_txt=None, teacher_forcing_ratio=1.):
        asr_enc_out = self.asr.encode(mel, skip_step=2)
        asr_out_txt, asr_att_score, context = self.asr.decode(
            src_txt, asr_enc_out, teacher_forcing_ratio)
        trans_out = self.transcoder(context)
        out_txt, att_score, _ = self.nmt.decode(trg_txt, trans_out,
                                                teacher_forcing_ratio)
        return trans_out, out_txt, att_score
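# Illustrative usage sketch (not part of the original file). It assumes `model`
# is an InterSpeech instance and that `mel`, `src_ids`, and `trg_ids` are padded
# batch tensors produced by the repo's dataset code; those names are assumptions.
def _example_interspeech_step(model, mel, src_ids, trg_ids, phase=2):
    # phase=1 trains ASR and the transcoder with the NMT decoder frozen;
    # phase=2 back-propagates through the whole ASR -> transcoder -> NMT cascade.
    trans_out, out_txt, att_score = model(mel, src_txt=src_ids, trg_txt=trg_ids,
                                          teacher_forcing_ratio=1., phase=phase)
    return out_txt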
class train(base.base):
    def __init__(self, src_lang, trg_lang, segment, log_dir, batch_size,
                 n_workers, epochs):
        self.name = "NMT"
        self.data_type = "TextText"
        super().__init__(src_lang, trg_lang, segment, log_dir, batch_size,
                         n_workers, epochs)
        self.model = NMT(len(self.dataset.src_vocab[self.segment]),
                         len(self.dataset.trg_vocab[self.segment]))
        self.path = os.path.join(self.path, self.model.get_path(), self.segment)
        if not os.path.exists(self.path + "/CHAMP"):
            os.makedirs(self.path + "/CHAMP")
        self.criterion = nn.CrossEntropyLoss(ignore_index=0)  # index 0 is padding
        self.lr = hp.lr
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
        self.train_loss = list()
        self.dev_loss = list()
        self.test_result = 0
        self.decode_mode = 'normal'
        self.load()

    def load(self):
        # Resume from a previous run if a checkpoint and loss logs exist.
        if super().__load__():
            if os.path.exists(self.path + "/train_loss"):
                with open(self.path + "/train_loss") as f:
                    for i in f.readlines():
                        self.train_loss.append(float(i.strip()))
            if os.path.exists(self.path + "/dev_loss"):
                with open(self.path + "/dev_loss") as f:
                    for i in f.readlines():
                        self.dev_loss.append(float(i.strip()))

    def save(self):
        with open(self.path + "/train_loss", "w") as f:
            for i in self.train_loss:
                f.write(str(float(i)) + '\n')
        with open(self.path + "/dev_loss", "w") as f:
            for i in self.dev_loss:
                f.write(str(float(i)) + '\n')
        # dev_loss holds BLEU+1 scores, so higher is better.
        if max(self.dev_loss) == self.dev_loss[-1]:
            torch.save(self.model.state_dict(),
                       os.path.join(self.path, "CHAMP", self.name))
            print("Saving model at ", self.epoch)
        if self.dev_loss[-1] >= hp.stop_learning:
            self.epoch = self.epochs  # stop training once the target score is reached
        show_loss(self.train_loss, self.path + "/train_loss.pdf",
                  self.criterion._get_name())
        show_loss(self.dev_loss, self.path + "/dev_loss.pdf", "BLEUp1")

    def dump(self):
        pass

    def eval(self, mode=None):
        with torch.no_grad():
            out_txt, att_score = self.model(
                self.data['src']['id_' + self.segment].cuda())
            out_txt = out_txt[0].argmax(-1).cpu().numpy()
            try:
                ref = self.data['trg']['raw_' + self.segment]
            except KeyError:
                ref = self.dataset.i2w(
                    self.data['trg']['id_' + self.segment][0].cpu().numpy(),
                    self.segment, self.trg_lang)
            hyp = self.dataset.i2w(out_txt, self.segment, self.trg_lang)
            loss = BLEUp1(hyp, ref)
            self.total_loss += loss
        return loss, hyp, ref, "BLEU+1 %0.1f " % (self.total_loss / (self.i + 1))

    def train(self):
        src = self.data['src']['id_' + self.segment].cuda()
        trg = self.data['trg']['id_' + self.segment].cuda()
        out_txt, att_score = self.model(src, trg, self.teacher_forcing_ratio,
                                        mode=self.decode_mode)
        loss = self.criterion(out_txt.transpose(1, 2), trg)
        loss.backward()
        nn.utils.clip_grad_norm_(self.model.parameters(), 1.)
        self.optimizer.step()
        self.model.zero_grad()
        self.optimizer.zero_grad()
        self.total_loss += loss.item()
        return "Epoch %d tr %0.1f CE_loss %0.1f" % (
            self.epoch, self.teacher_forcing_ratio,
            self.total_loss / (self.i + 1))
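# Illustrative usage sketch (not part of the original file). The argument values
# and the idea of a separate driver script are assumptions about how the base
# trainer is invoked; only the constructor signature above is taken from this file.
def _example_launch_nmt_training():
    # Construct the NMT trainer; the constructor builds the dataset, the model,
    # the optimizer, and resumes from an existing checkpoint if one is found.
    trainer = train(src_lang="en", trg_lang="ja", segment="word",
                    log_dir="./exp", batch_size=32, n_workers=4, epochs=20)
    return trainer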