Example #1
class Speech2Speech(nn.Module):
    def __init__(self, src_vocab_size, trg_vocab_size):
        super(Speech2Speech, self).__init__()
        #self.inter_speech=InterSpeech(src_vocab_size,trg_vocab_size)
        #self.transcoder=BaseEncoder(hp.nmt_h_size,hp.tts_h_size,hp.drop,hp.bidirectional,hp.rnn)
        # Cascade of sub-models: speech recognition, translation, then synthesis.
        self.asr = ASR(src_vocab_size)
        self.nmt = NMT(src_vocab_size, trg_vocab_size)
        self.taco = Tacotron(trg_vocab_size)
Example #2
class train(base.base):
    def __init__(self, src_lang, trg_lang, segment, log_dir, batch_size,
                 n_workers, epochs):
        self.name = "NMT"
        self.data_type = "TextText"
        super().__init__(src_lang, trg_lang, segment, log_dir, batch_size,
                         n_workers, epochs)
        self.model = NMT(len(self.dataset.src_vocab[self.segment]),
                         len(self.dataset.trg_vocab[self.segment]))
        self.path = os.path.join(self.path, self.model.get_path(),
                                 self.segment)
        if not os.path.exists(self.path + "/CHAMP"):
            os.makedirs(self.path + "/CHAMP")
        self.criterion = nn.CrossEntropyLoss(ignore_index=0)
        self.lr = hp.lr
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
        self.train_loss = list()
        self.dev_loss = list()
        self.test_result = 0
        self.decode_mode = 'normal'
        self.load()
Example #3
class MTTTS(nn.Module):
    def __init__(self, src_vocab_size, trg_vocab_size):
        super(MTTTS, self).__init__()
        self.nmt = NMT(src_vocab_size, trg_vocab_size)
        self.tts = Tacotron(trg_vocab_size)

    def __call__(self, src_txt=None, trg_txt=None, trg_mel=None,
                 teacher_forcing_ratio=1., mode="normal"):
        nmt_out_txt, nmt_att_score = self.nmt(src_txt, trg_txt, teacher_forcing_ratio)
        if mode == 'normal':
            # Hard cascade: feed the NMT's argmax token ids to the TTS.
            mel_out, stop_targets, tts_att_score = self.tts(
                nmt_out_txt.argmax(-1), trg_mel)
        elif mode == 'osamura':
            # Soft cascade: feed the softmax distribution over target tokens.
            mel_out, stop_targets, tts_att_score = self.tts(
                F.softmax(nmt_out_txt, -1), trg_mel)

        return nmt_out_txt, mel_out, stop_targets, nmt_att_score, tts_att_score

    def get_path(self):
        return "_".join(hp.rnn) + "_nmt_hid" + str(hp.nmt_h_size) + "_depth" + str(len(hp.rnn)) + "/"

    def load(self, args):
        nmt_path = os.path.join("./exp", "NMT", args.src_lang + "2" + args.trg_lang,
                                self.nmt.get_path(), args.segment)
        self.nmt.load_state_dict(torch.load(os.path.join(nmt_path, "CHAMP", "NMT")))
        tts_path = os.path.join("./exp", "Tacotron", args.trg_lang,
                                self.tts.get_path(), args.segment)
        self.tts.load_state_dict(torch.load(os.path.join(tts_path, "CHAMP", "Tacotron")))
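
A minimal usage sketch for the MTTTS cascade above. The vocabulary sizes and tensor shapes are illustrative assumptions, not values from the source; only the call signature and return values come from the example. In "normal" mode the Tacotron front end receives the hard argmax token ids from the NMT output, while "osamura" mode passes the soft token distribution instead.

import torch

# MTTTS, NMT and Tacotron as defined in the example above; sizes are assumed.
model = MTTTS(src_vocab_size=1000, trg_vocab_size=1200)

src_txt = torch.randint(1, 1000, (4, 20))   # (batch, src_len) source token ids
trg_txt = torch.randint(1, 1200, (4, 25))   # (batch, trg_len) target token ids
trg_mel = torch.randn(4, 200, 80)           # (batch, frames, mel_bins) reference mel; shape assumed

nmt_out_txt, mel_out, stop_targets, nmt_att, tts_att = model(
    src_txt, trg_txt, trg_mel, teacher_forcing_ratio=1., mode="normal")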
Example #4
class InterSpeech(nn.Module):
    def __init__(self, src_vocab_size, trg_vocab_size):
        super(InterSpeech, self).__init__()
        self.asr = ASR(src_vocab_size)
        self.nmt = NMT(src_vocab_size, trg_vocab_size)
        self.transcoder = BaseEncoder(hp.asr_h_size, hp.nmt_h_size, hp.drop,
                                      hp.bidirectional, hp.rnn)
Example #5
class InterSpeech(nn.Module):
    def __init__(self, src_vocab_size, trg_vocab_size):
        super(InterSpeech, self).__init__()
        self.asr = ASR(src_vocab_size)
        self.nmt = NMT(src_vocab_size, trg_vocab_size)
        self.transcoder = BaseEncoder(hp.asr_h_size, hp.nmt_h_size, hp.drop,
                                      hp.bidirectional, hp.rnn)

    def __call__(self,
                 mel,
                 src_txt=None,
                 trg_txt=None,
                 teacher_forcing_ratio=1.,
                 phase=2):
        #self.flatten_parameters()
        #        with torch.set_grad_enabled(self.training):
        #import pdb; pdb.set_trace()
        # Phases 1 and 2 backpropagate through the ASR; other phases keep it frozen.
        with torch.set_grad_enabled(phase in {1, 2} and self.training):
            asr_enc_out = self.asr.encode(mel, skip_step=2)
            asr_out_txt, asr_att_score, context = self.asr.decode(
                src_txt, asr_enc_out, teacher_forcing_ratio)

        # The transcoder bridging the ASR and NMT hidden sizes always gets gradients in training.
        with torch.set_grad_enabled(self.training):
            trans_out = self.transcoder(context)

        # Only phase 2 backpropagates through the NMT decoder.
        with torch.set_grad_enabled(phase == 2 and self.training):
            out_txt, att_score, _ = self.nmt.decode(trg_txt, trans_out,
                                                    teacher_forcing_ratio)

        return trans_out, out_txt, att_score

    def get_path(self):
        return "_".join(hp.rnn) + "_asr_hid" + str(
            hp.asr_h_size) + "_nmt_hid" + str(hp.nmt_h_size) + "_depth" + str(
                len(hp.rnn)) + "/"

    def load(self, args):

        asr_path = os.path.join("./exp", "ASR", args.src_lang,
                                self.asr.get_path(), args.segment)
        self.asr.load_state_dict(
            torch.load(os.path.join(asr_path, "CHAMP", "ASR")))
        nmt_path = os.path.join("./exp", "NMT",
                                args.src_lang + "2" + args.trg_lang,
                                self.nmt.get_path(), args.segment)
        self.nmt.load_state_dict(
            torch.load(os.path.join(nmt_path, "CHAMP", "NMT")))

    def encode_wo_dump(self,
                       mel,
                       src_txt=None,
                       trg_txt=None,
                       teacher_forcing_ratio=1.):
        with torch.no_grad():
            asr_enc_out = self.asr.encode(mel, skip_step=2)
            asr_out_txt, asr_att_score, context = self.asr.decode(
                src_txt, asr_enc_out, teacher_forcing_ratio)
            trans_out = self.transcoder(context)
            out_txt, att_score, result = self.nmt.decode(
                trg_txt, trans_out, teacher_forcing_ratio)

        return asr_out_txt, out_txt, result

    def encode(self,
               mel,
               src_txt=None,
               trg_txt=None,
               teacher_forcing_ratio=1.):

        asr_enc_out = self.asr.encode(mel, skip_step=2)
        asr_out_txt, asr_att_score, context = self.asr.decode(
            src_txt, asr_enc_out, teacher_forcing_ratio)
        trans_out = self.transcoder(context)
        out_txt, att_score, _ = self.nmt.decode(trg_txt, trans_out,
                                                teacher_forcing_ratio)

        return trans_out, out_txt, att_score
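
A hedged sketch of calling the InterSpeech model above during staged training; the vocabulary sizes and tensor shapes are assumptions, only the call signature and the gradient-enabling behaviour are taken from the example. With phase=1 gradients reach the ASR and the transcoder while the NMT decoder stays frozen; phase=2 enables gradients for all three parts.

import torch

# InterSpeech as defined in the example above; sizes and shapes are assumed.
model = InterSpeech(src_vocab_size=1000, trg_vocab_size=1200)
model.train()

mel = torch.randn(4, 300, 80)               # (batch, frames, mel_bins) input speech features; shape assumed
src_txt = torch.randint(1, 1000, (4, 20))   # source transcript ids for teacher forcing
trg_txt = torch.randint(1, 1200, (4, 25))   # target translation ids for teacher forcing

trans_out, out_txt, att_score = model(mel, src_txt, trg_txt,
                                      teacher_forcing_ratio=1., phase=1)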
Example #6
class MTTTS(nn.Module):
    def __init__(self, src_vocab_size, trg_vocab_size):
        super(MTTTS, self).__init__()
        self.nmt = NMT(src_vocab_size, trg_vocab_size)
        self.tts = Tacotron(trg_vocab_size)
Example #7
class train(base.base):
    def __init__(self, src_lang, trg_lang, segment, log_dir, batch_size,
                 n_workers, epochs):
        self.name = "NMT"
        self.data_type = "TextText"
        super().__init__(src_lang, trg_lang, segment, log_dir, batch_size,
                         n_workers, epochs)
        self.model = NMT(len(self.dataset.src_vocab[self.segment]),
                         len(self.dataset.trg_vocab[self.segment]))
        self.path = os.path.join(self.path, self.model.get_path(),
                                 self.segment)
        if not os.path.exists(self.path + "/CHAMP"):
            os.makedirs(self.path + "/CHAMP")
        self.criterion = nn.CrossEntropyLoss(ignore_index=0)
        self.lr = hp.lr
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
        self.train_loss = list()
        self.dev_loss = list()
        self.test_result = 0
        self.decode_mode = 'normal'
        self.load()

    def load(self):
        if super().__load__():
            if os.path.exists(self.path + "/train_loss"):
                with open(self.path + "/train_loss") as f:
                    for i in f.readlines():
                        i = i.strip()
                        self.train_loss.append(float(i))

            if os.path.exists(self.path + "/dev_loss"):
                with open(self.path + "/dev_loss") as f:
                    for i in f.readlines():
                        i = i.strip()
                        self.dev_loss.append(float(i))

            #self.epoch=len(self.dev_loss)
            #print("Epoch %d train loss %3f dev BLEU %2f"%(self.epoch,self.train_loss[-1],max(self.dev_loss)))

    def save(self):
        #super().__save__()
        with open(self.path + "/train_loss", "w") as f:
            for i in self.train_loss:
                f.write(str(float(i)) + '\n')

        with open(self.path + "/dev_loss", "w") as f:
            for i in self.dev_loss:
                f.write(str(float(i)) + '\n')

        # dev_loss stores BLEU+1 scores, so higher is better; keep the best checkpoint.
        if max(self.dev_loss) == self.dev_loss[-1]:
            torch.save(self.model.state_dict(),
                       os.path.join(self.path, "CHAMP", self.name))
            print("Saving model at epoch", self.epoch)
            if self.dev_loss[-1] >= hp.stop_learning:
                self.epoch = self.epochs

        show_loss(self.train_loss, self.path + "/train_loss.pdf",
                  self.criterion._get_name())
        show_loss(self.dev_loss, self.path + "/dev_loss.pdf", "BLEUp1")

    def dump(self):
        pass

    def eval(self, mode=None):
        with torch.no_grad():
            out_txt, att_score = self.model(
                self.data['src']['id_' + self.segment].cuda())
        out_txt = out_txt[0].argmax(-1).cpu().numpy()
        try:
            ref = self.data['trg']['raw_' + self.segment]
        except KeyError:
            # Fall back to detokenizing the id sequence when no raw reference is stored.
            ref = self.dataset.i2w(
                self.data['trg']['id_' + self.segment][0].cpu().numpy(),
                self.segment, self.trg_lang)
        hyp = self.dataset.i2w(out_txt, self.segment, self.trg_lang)
        loss = BLEUp1(hyp, ref)
        self.total_loss += loss
        return loss, hyp, ref, "BLEU+1 %0.1f " % (self.total_loss /
                                                  (self.i + 1))

    def train(self):
        src = self.data['src']['id_' + self.segment].cuda()
        trg = self.data['trg']['id_' + self.segment].cuda()
        out_txt, att_score = self.model(src,
                                        trg,
                                        self.teacher_forcing_ratio,
                                        mode=self.decode_mode)
        loss = self.criterion(out_txt.transpose(1, 2), trg)
        loss.backward()
        nn.utils.clip_grad_norm_(self.model.parameters(), 1.)
        self.optimizer.step()
        self.model.zero_grad()
        self.optimizer.zero_grad()
        self.total_loss += loss.item()
        return "Epoch %d tr %0.1f CE_loss %0.1f" % (
            self.epoch, self.teacher_forcing_ratio, self.total_loss /
            (self.i + 1))