def update_core_impl(self, batch):
    """Run one training step on a single minibatch.

    The batch mode is read from ``data[0][1]['utt2mode']``:
      * ``'p'`` — parallel speech/text: train ASR, TTS and (optionally)
        the two autoencoders plus an inter-domain MMD/KL loss.
      * ``'a'`` — audio only: train the speech autoencoder.
      * ``'t'`` — text only: train the text autoencoder.

    Each sub-loss is back-propagated separately via
    ``self.gradient_decent`` (sic) with its own optimizer from
    ``self.opts``.  Per-loss scalars are pushed to the chainer reporter
    under ``t/*`` keys.

    :param batch: list whose single element is the minibatch of
        (utt_id, info-dict) pairs produced by the dataset iterator
        (batch[0] is the only minibatch; see converter).
    """
    # read scp files
    # x: original json with loaded features
    # will be converted to chainer variable later
    # batch only has one minibatch utterance, which is specified by batch[0]
    if len(batch[0]) < self.num_gpu:
        logging.warning('batch size is less than number of gpus. Ignored')
        return
    data = self.converter(
        batch, use_speaker_embedding=self.model.tts_loss.model.spk_embed_dim)
    # Compute the loss at this time step and accumulate it
    asr_texts, asr_feats, asr_featlens = get_asr_torch(self.model, data)
    tts_texts, tts_textlens, tts_feats, tts_labels, tts_featlens, spembs = \
        get_tts_data(self.model, data, 'text',
                     use_speaker_embedding=self.model.tts_loss.model.spk_embed_dim)
    # mean target text length; used to length-normalize token-level losses
    avg_textlen = float(np.mean(tts_textlens.data.cpu().numpy()))
    if data[0][1]['utt2mode'] == 'p':
        logging.info("parallel data mode")
        if ALL_MODE:
            modes = ['asr', 'tts', 's2s', 't2t']
        else:
            modes = ['asr', 'tts']
        random.shuffle(modes)  # randomize update order across steps
        loss_data_sum = 0.0
        use_mmd = self.args.mmd_weight != 0.0
        for mode in modes:
            if mode == 'asr':
                if self.args.asr_weight == 0.0:
                    asr_loss_data = 0  # use nan?
                    asr_acc = 0
                    continue
                loss, asr_acc = self.model.asr_loss(
                    asr_feats, asr_featlens, asr_texts,
                    do_report=False, report_acc=True)  # disable reporter
                if not NO_AVG:
                    loss = loss / avg_textlen
                asr_loss_data = loss.item()
                chainer.reporter.report({'t/asr_loss': asr_loss_data})
                chainer.reporter.report({'t/asr_acc': asr_acc})
                loss_data_sum += asr_loss_data
                logging.info("asr_loss_data: %f", asr_loss_data)
                self.gradient_decent(loss, self.opts[mode])
            if mode == 'tts':
                if self.args.tts_weight == 0.0:
                    tts_loss_data = 0  # use nan?
                    continue
                loss = self.model.tts_loss(
                    tts_texts, tts_textlens, tts_feats, tts_labels,
                    tts_featlens, spembs=spembs, do_report=False)
                tts_loss_data = loss.item()
                chainer.reporter.report({'t/tts_loss': tts_loss_data})
                loss_data_sum += tts_loss_data
                logging.info("tts_loss_data: %f", tts_loss_data)
                self.gradient_decent(loss, self.opts[mode])
            if mode == 's2s':
                # run even with zero weight when MMD needs the hidden states
                if self.args.s2s_weight == 0.0 and not use_mmd:
                    s2s_loss_data = 0.0
                    continue
                result = self.model.ae_speech(
                    data, return_hidden=True,
                    return_inout=self.args.use_inout_mmd)
                if self.args.use_inout_mmd:
                    loss, speech_in, speech_in_len, speech_out, speech_out_len = result
                else:
                    loss, hspad, hslen = result
                s2s_loss_data = loss.item()
                chainer.reporter.report({'t/s2s_loss': s2s_loss_data})
                loss_data_sum += s2s_loss_data
                logging.info("s2s_loss_data: %f", s2s_loss_data)
                if self.args.s2s_weight != 0.0:
                    # retain_graph so the MMD backward below can reuse it
                    self.gradient_decent(loss, self.opts[mode],
                                         freeze_att=FREEZE_ATT,
                                         retain_graph=use_mmd)
            if mode == 't2t':
                if self.args.t2t_weight == 0.0 and not use_mmd:
                    t2t_loss_data = 0.0
                    continue
                result = self.model.ae_text(
                    data, return_hidden=True,
                    return_inout=self.args.use_inout_mmd)
                if self.args.use_inout_mmd:
                    loss, t2t_acc, text_in, text_in_len, text_out, text_out_len = result
                else:
                    loss, t2t_acc, htpad, htlen = result
                if not NO_AVG:
                    loss = loss / avg_textlen
                t2t_loss_data = loss.item()
                chainer.reporter.report({'t/t2t_loss': t2t_loss_data})
                chainer.reporter.report({'t/t2t_acc': t2t_acc})
                loss_data_sum += t2t_loss_data
                logging.info("t2t_loss_data: %f", t2t_loss_data)
                if self.args.t2t_weight != 0.0:
                    self.gradient_decent(loss, self.opts[mode],
                                         freeze_att=FREEZE_ATT,
                                         retain_graph=use_mmd)
        logging.info("loss_data_sum: %f", loss_data_sum)
        if use_mmd:
            # NOTE(review): this path assumes ALL_MODE so that the hidden
            # states / in-out pairs above were produced — confirm config.
            if self.args.use_inout_mmd:
                logging.warning((speech_in.shape, speech_out.shape,
                                 speech_in_len, speech_out_len))
                speech_loss = packed_mmd(speech_in, speech_in_len,
                                         speech_out, speech_out_len)
                chainer.reporter.report({'t/ps_mmd': speech_loss.item()})
                text_loss = packed_mmd(text_in, text_in_len,
                                       text_out, text_out_len)
                chainer.reporter.report({'t/pt_mmd': text_loss.item()})
                loss = (speech_loss + text_loss) / 2.0
                chainer.reporter.report({'t/p_mmd': loss.item()})
            else:
                if self.args.inter_domain_loss == "kl":
                    loss_fun = packed_gauss_kld
                else:
                    loss_fun = packed_mmd
                loss = loss_fun(hspad, hslen, htpad, htlen)
                chainer.reporter.report({'t/mmd_loss': loss.item()})
            self.gradient_decent(loss, self.opts["mmd"],
                                 freeze_att=FREEZE_ATT)
            loss_data = (loss_data_sum + loss.item()) / 5.0
        elif ALL_MODE:
            loss_data = loss_data_sum / 4.0
        else:
            loss_data = loss_data_sum / 2.0
        logging.info("loss_data: %f", loss_data)
        chainer.reporter.report({'t/loss': loss_data})
    elif data[0][1]['utt2mode'] == 'a':
        logging.info("audio only mode")
        result = self.model.ae_speech(data, return_hidden=True,
                                      return_inout=self.args.use_inout_mmd)
        if self.args.use_inout_mmd:
            s2s_loss, speech_in, speech_in_len, speech_out, speech_out_len = result
            speech_loss = packed_mmd(speech_in, speech_in_len,
                                     speech_out, speech_out_len)
            chainer.reporter.report({'t/us_mmd': speech_loss.item()})
            loss = (s2s_loss + speech_loss) / 2.0
        else:
            s2s_loss, hspad, hslen = result
            loss = s2s_loss
        self.gradient_decent(loss, self.opts['s2s'], freeze_att=FREEZE_ATT)
        # report the reconstruction term only, not the combined loss
        loss_data = s2s_loss.item()
        logging.info("loss: %f", loss_data)
        chainer.reporter.report({'t/s2s_loss': loss_data})
    elif data[0][1]['utt2mode'] == 't':
        logging.info("text only mode")
        # t2t_loss, t2t_acc = self.model.ae_text(data)
        result = self.model.ae_text(data, return_hidden=True,
                                    return_inout=self.args.use_inout_mmd)
        if self.args.use_inout_mmd:
            t2t_loss, t2t_acc, text_in, text_in_len, text_out, text_out_len = result
            text_loss = packed_mmd(text_in, text_in_len,
                                   text_out, text_out_len)
            chainer.reporter.report({'t/ut_mmd': text_loss.item()})
            loss = (t2t_loss + text_loss) / 2.0
        else:
            t2t_loss, t2t_acc, htpad, htlen = result
            loss = t2t_loss
        # BUGFIX: this used to rebuild ``loss`` from ``t2t_loss`` alone,
        # silently discarding the inout-MMD term selected above.  Now the
        # already-chosen loss is only length-normalized.
        if not NO_AVG:
            loss = loss / avg_textlen
        self.gradient_decent(loss, self.opts['t2t'], freeze_att=FREEZE_ATT)
        loss_data = t2t_loss.item()
        logging.info("loss: %f", loss_data)
        chainer.reporter.report({'t/t2t_loss': loss_data})
        chainer.reporter.report({'t/t2t_acc': t2t_acc})
    else:
        logging.error("Error: cannot find correct mode ('p', 'a', 't')")
        sys.exit()
    delete_feat(data)
def update_core_impl_autoencode(self, speech_batch, text_batch):
    """One joint autoencoder update from a speech batch and a text batch.

    Runs the speech autoencoder on *speech_batch* and the text
    autoencoder on *text_batch*, then back-propagates a single weighted
    sum of both reconstruction losses plus an inter-domain loss
    (MMD or Gaussian KL) between their hidden states.

    NOTE(review): ``s2s_loss``/``hspad``/``hslen`` and
    ``t2t_loss``/``htpad``/``htlen`` deliberately leak out of the loop —
    one iteration must be mode 'a' and the other mode 't', otherwise the
    combined loss below raises ``NameError``.  Confirm the callers
    guarantee this pairing.
    """
    # read scp files
    # x: original json with loaded features
    # will be converted to chainer variable later
    # batch only has one minibatch utterance, which is specified by batch[0]
    s_data = self.converter(
        speech_batch,
        use_speaker_embedding=self.model.tts_loss.model.spk_embed_dim)
    t_data = self.converter(
        text_batch,
        use_speaker_embedding=self.model.tts_loss.model.spk_embed_dim)
    for data in (s_data, t_data):
        # NOTE(review): the ASR/TTS tensors and avg_textlen extracted here
        # are unused in this method — presumably kept for parity with
        # update_core_impl; verify before removing.
        asr_texts, asr_feats, asr_featlens = get_asr_torch(
            self.model, data)
        tts_texts, tts_textlens, tts_feats, tts_labels, tts_featlens, spembs = \
            get_tts_data(self.model, data, 'text',
                         use_speaker_embedding=self.model.tts_loss.model.spk_embed_dim)
        avg_textlen = float(np.mean(tts_textlens.data.cpu().numpy()))
        if data[0][1]['utt2mode'] == 'a':
            logging.info("audio only mode")
            # speech autoencoder; hidden states kept for the domain loss
            result = self.model.ae_speech(
                data, return_hidden=True,
                return_inout=self.args.use_inout_mmd)
            s2s_loss, hspad, hslen = result
            loss_data = s2s_loss.item()
            logging.info("loss: %f", loss_data)
            chainer.reporter.report({'t/s2s_loss': loss_data})
        elif data[0][1]['utt2mode'] == 't':
            logging.info("text only mode")
            # t2t_loss, t2t_acc = self.model.ae_text(data)
            # text autoencoder; hidden states kept for the domain loss
            result = self.model.ae_text(
                data, return_hidden=True,
                return_inout=self.args.use_inout_mmd)
            t2t_loss, t2t_acc, htpad, htlen = result
            loss_data = t2t_loss.item()
            logging.info("loss: %f", loss_data)
            chainer.reporter.report({'t/t2t_loss': loss_data})
            chainer.reporter.report({'t/t2t_acc': t2t_acc})
        else:
            logging.error("Error: cannot find correct mode ('a', 't')")
            sys.exit()
    # single combined backward: clear both optimizers first
    self.opts["s2s"].zero_grad()
    self.opts["t2t"].zero_grad()
    if self.args.inter_domain_loss == "kl":
        domain_loss_fun = packed_gauss_kld
    else:
        domain_loss_fun = packed_mmd
    # inter-domain loss between speech and text hidden representations
    mmd_loss = domain_loss_fun(hspad, hslen, htpad, htlen)
    chainer.reporter.report({'t/ae_mmd_loss': mmd_loss.item()})
    loss = self.args.s2s_weight * s2s_loss + self.args.t2t_weight * t2t_loss + self.args.mmd_weight * mmd_loss
    loss.backward()
    # backward=False: gradients already computed above, just step the opts
    self.gradient_decent(loss, self.opts['s2s'],
                         freeze_att=FREEZE_ATT, backward=False)
    self.gradient_decent(loss, self.opts['t2t'],
                         freeze_att=FREEZE_ATT, backward=False)
    delete_feat(s_data)
    delete_feat(t_data)
def get_tts_torch(model, data):
    """Return the TTS training tensors for *data*.

    Thin convenience wrapper around :func:`get_tts_data` that always
    selects the ``'text'`` input and forwards the model's speaker
    embedding dimension.
    """
    spk_embed_dim = model.tts_loss.model.spk_embed_dim
    return get_tts_data(model, data, 'text',
                        use_speaker_embedding=spk_embed_dim)
def evaluate(self):
    """Evaluate the model over the 'main' iterator and return mean metrics.

    Only parallel-mode ('p') batches are supported; any other mode
    aborts the process.  Reports per-network losses under ``d/*`` keys
    and returns ``summary.compute_mean()``.

    Cleanup: the original code guarded every ``.item()`` call with
    ``x.item() if torch_is_old else x.item()`` — both branches were
    token-identical, so the dead ternaries were removed (no behavior
    change).
    """
    iterator = self._iterators['main']
    if self.eval_hook:
        self.eval_hook(self)
    # reuse a resettable iterator; otherwise shallow-copy to preserve state
    if hasattr(iterator, 'reset'):
        iterator.reset()
        it = iterator
    else:
        it = copy.copy(iterator)
    summary = reporter_module.DictSummary()
    if not torch_is_old:
        torch.set_grad_enabled(False)
    for batch in it:
        observation = {}
        with reporter_module.report_scope(observation):
            # read scp files
            # x: original json with loaded features
            # will be converted to chainer variable later
            data = self.converter(
                batch,
                use_speaker_embedding=self.model.tts_loss.model.spk_embed_dim)
            self.model.eval()
            if data[0][1]['utt2mode'] != 'p':
                logging.error(
                    "Error: evaluation only support a parallel data mode ('p')"
                )
                sys.exit()
            asr_texts, asr_feats, asr_featlens = get_asr_torch(
                self.model, data)
            asr_loss, asr_acc = self.model.asr_loss(
                asr_feats, asr_featlens, asr_texts,
                do_report=False, report_acc=True)  # disable reporter
            tts_texts, tts_textlens, tts_feats, tts_labels, tts_featlens, spembs = \
                get_tts_data(self.model, data, 'text',
                             use_speaker_embedding=self.model.tts_loss.model.spk_embed_dim)
            # mean target text length for length-normalizing token losses
            avg_textlen = float(np.mean(tts_textlens.data.cpu().numpy()))
            tts_loss = self.model.tts_loss(
                tts_texts, tts_textlens, tts_feats, tts_labels,
                tts_featlens, spembs=spembs, do_report=False)
            s2s_loss = self.model.ae_speech(data)
            t2t_loss, t2t_acc = self.model.ae_text(data)
            # average loss for all four networks
            if NO_AVG:
                loss = (asr_loss + tts_loss + s2s_loss + t2t_loss) / 4.0
            else:
                loss = (asr_loss / avg_textlen + tts_loss + s2s_loss
                        + t2t_loss / avg_textlen) / 4.0
            loss_data = loss.item()
            if NO_AVG:
                asr_loss_data = asr_loss.item()
            else:
                asr_loss_data = asr_loss.item() / avg_textlen
            tts_loss_data = tts_loss.item()
            s2s_loss_data = s2s_loss.item()
            if NO_AVG:
                t2t_loss_data = t2t_loss.item()
            else:
                t2t_loss_data = t2t_loss.item() / avg_textlen
            chainer.reporter.report({'d/loss': loss_data})
            chainer.reporter.report({'d/asr_loss': asr_loss_data})
            chainer.reporter.report({'d/tts_loss': tts_loss_data})
            chainer.reporter.report({'d/asr_acc': asr_acc})
            chainer.reporter.report({'d/s2s_loss': s2s_loss_data})
            chainer.reporter.report({'d/t2t_loss': t2t_loss_data})
            chainer.reporter.report({'d/t2t_acc': t2t_acc})
            delete_feat(data)
        summary.add(observation)
    # restore training state
    if not torch_is_old:
        torch.set_grad_enabled(True)
    self.model.train()
    return summary.compute_mean()
def update_core(self):
    """Fetch the next 'main' batch and run one training step.

    Modes (from ``data[0][1]['utt2mode']``): 'p' parallel (ASR + TTS and,
    when ALL_MODE, both autoencoders plus an optional MMD loss), 'a'
    speech-autoencoder only, 't' text-autoencoder only.

    NOTE(review): this largely duplicates ``update_core_impl`` (without
    the ``use_inout_mmd``/``inter_domain_loss`` options) — consider
    consolidating the two.
    """
    # When we pass one iterator and optimizer to StandardUpdater.__init__,
    # they are automatically named 'main'.
    train_iter = self.get_iterator('main')
    # Get the next batch ( a list of json files)
    batch = train_iter.__next__()
    # read scp files
    # x: original json with loaded features
    # will be converted to chainer variable later
    # batch only has one minibatch utterance, which is specified by batch[0]
    if len(batch[0]) < self.num_gpu:
        logging.warning('batch size is less than number of gpus. Ignored')
        return
    data = self.converter(
        batch, use_speaker_embedding=self.model.tts_loss.model.spk_embed_dim)
    # Compute the loss at this time step and accumulate it
    asr_texts, asr_feats, asr_featlens = get_asr_torch(self.model, data)
    tts_texts, tts_textlens, tts_feats, tts_labels, tts_featlens, spembs = \
        get_tts_data(self.model, data, 'text',
                     use_speaker_embedding=self.model.tts_loss.model.spk_embed_dim)
    # mean target text length; used to length-normalize token-level losses
    avg_textlen = float(np.mean(tts_textlens.data.cpu().numpy()))
    if data[0][1]['utt2mode'] == 'p':
        logging.info("parallel data mode")
        if ALL_MODE:
            modes = ['asr', 'tts', 's2s', 't2t']
        else:
            modes = ['asr', 'tts']
        random.shuffle(modes)  # shuffle
        loss_data_sum = 0.0
        use_mmd = self.args.mmd_weight != 0.0
        for mode in modes:
            if mode == 'asr':
                if self.args.asr_weight == 0.0:
                    asr_loss_data = 0  # use nan?
                    asr_acc = 0
                    continue
                loss, asr_acc = self.model.asr_loss(
                    asr_feats, asr_featlens, asr_texts,
                    do_report=False, report_acc=True)  # disable reporter
                if NO_AVG:
                    pass
                else:
                    loss = loss / avg_textlen
                asr_loss_data = loss.item()
                chainer.reporter.report({'t/asr_loss': asr_loss_data})
                chainer.reporter.report({'t/asr_acc': asr_acc})
                loss_data_sum += asr_loss_data
                logging.info("asr_loss_data: %f", asr_loss_data)
                self.gradient_decent(loss, self.opts[mode])
            if mode == 'tts':
                if self.args.tts_weight == 0.0:
                    tts_loss_data = 0  # use nan?
                    continue
                loss = self.model.tts_loss(tts_texts, tts_textlens,
                                           tts_feats, tts_labels,
                                           tts_featlens, spembs=spembs,
                                           do_report=False)
                tts_loss_data = loss.item()
                chainer.reporter.report({'t/tts_loss': tts_loss_data})
                loss_data_sum += tts_loss_data
                logging.info("tts_loss_data: %f", tts_loss_data)
                self.gradient_decent(loss, self.opts[mode])
            if mode == 's2s':
                # still runs when weight is 0 if MMD needs hidden states
                if self.args.s2s_weight == 0.0 and not use_mmd:
                    s2s_loss_data = 0.0
                    continue
                loss, hspad, hslen = self.model.ae_speech(
                    data, return_hidden=True)
                s2s_loss_data = loss.item()
                chainer.reporter.report({'t/s2s_loss': s2s_loss_data})
                loss_data_sum += s2s_loss_data
                logging.info("s2s_loss_data: %f", s2s_loss_data)
                if self.args.s2s_weight != 0.0:
                    # retain_graph so the MMD backward below can reuse it
                    self.gradient_decent(loss, self.opts[mode],
                                         freeze_att=FREEZE_ATT,
                                         retain_graph=use_mmd)
            if mode == 't2t':
                if self.args.t2t_weight == 0.0 and not use_mmd:
                    t2t_loss_data = 0.0
                    continue
                loss, t2t_acc, htpad, htlen = self.model.ae_text(
                    data, return_hidden=True)
                if NO_AVG:
                    pass
                else:
                    loss = loss / avg_textlen
                t2t_loss_data = loss.item()
                chainer.reporter.report({'t/t2t_loss': t2t_loss_data})
                chainer.reporter.report({'t/t2t_acc': t2t_acc})
                loss_data_sum += t2t_loss_data
                logging.info("t2t_loss_data: %f", t2t_loss_data)
                if self.args.t2t_weight != 0.0:
                    self.gradient_decent(loss, self.opts[mode],
                                         freeze_att=FREEZE_ATT,
                                         retain_graph=use_mmd)
        logging.info("loss_data_sum: %f", loss_data_sum)
        if use_mmd:
            # NOTE(review): assumes ALL_MODE so hspad/htpad exist — confirm.
            loss = packed_mmd(hspad, hslen, htpad, htlen)
            chainer.reporter.report({'t/mmd_loss': loss.item()})
            self.gradient_decent(loss, self.opts["mmd"],
                                 freeze_att=FREEZE_ATT)
            loss_data = (loss_data_sum + loss.item()) / 5.0
        elif ALL_MODE:
            loss_data = loss_data_sum / 4.0
        else:
            loss_data = loss_data_sum / 2.0
        logging.info("loss_data: %f", loss_data)
        chainer.reporter.report({'t/loss': loss_data})
    elif data[0][1]['utt2mode'] == 'a':
        logging.info("audio only mode")
        s2s_loss = self.model.ae_speech(data)
        loss = s2s_loss
        self.gradient_decent(loss, self.opts['s2s'], freeze_att=FREEZE_ATT)
        loss_data = loss.item()
        logging.info("loss: %f", loss_data)
        chainer.reporter.report({'t/loss': loss_data})
        chainer.reporter.report({'t/s2s_loss': loss_data})
    elif data[0][1]['utt2mode'] == 't':
        logging.info("text only mode")
        t2t_loss, t2t_acc = self.model.ae_text(data)
        if NO_AVG:
            loss = t2t_loss
        else:
            loss = t2t_loss / avg_textlen
        self.gradient_decent(loss, self.opts['t2t'], freeze_att=FREEZE_ATT)
        loss_data = loss.item()
        logging.info("loss: %f", loss_data)
        chainer.reporter.report({'t/loss': loss_data})
        chainer.reporter.report({'t/t2t_loss': loss_data})
        chainer.reporter.report({'t/t2t_acc': t2t_acc})
    else:
        logging.error("Error: cannot find correct mode ('p', 'a', 't')")
        sys.exit()
    delete_feat(data)