def update_core_impl(self, batch):
        """Run one training step on a single minibatch.

        ``batch[0]`` holds the minibatch of (uttid, info-json) pairs read
        from the scp files; ``self.converter`` turns it into framework
        tensors.  The utterance-level flag ``utt2mode`` selects the update:

        * ``'p'`` -- parallel data: update ASR and TTS (plus the speech/text
          autoencoders when ``ALL_MODE``) in a random order, optionally
          followed by an inter-domain MMD/KLD update.
        * ``'a'`` -- audio only: update the speech autoencoder.
        * ``'t'`` -- text only: update the text autoencoder.

        Any other mode logs an error and terminates the process.
        """
        # Skip minibatches too small to scatter across the GPUs.
        if len(batch[0]) < self.num_gpu:
            logging.warning('batch size is less than number of gpus. Ignored')
            return
        data = self.converter(
            batch,
            use_speaker_embedding=self.model.tts_loss.model.spk_embed_dim)

        # Build per-task tensor views of the same minibatch.
        asr_texts, asr_feats, asr_featlens = get_asr_torch(self.model, data)
        tts_texts, tts_textlens, tts_feats, tts_labels, tts_featlens, spembs = \
            get_tts_data(self.model, data, 'text', use_speaker_embedding=self.model.tts_loss.model.spk_embed_dim)

        # Mean target text length; used to normalize token-level losses
        # (ASR, text AE) unless NO_AVG is set.
        avg_textlen = float(np.mean(tts_textlens.data.cpu().numpy()))
        if data[0][1]['utt2mode'] == 'p':
            logging.info("parallel data mode")
            if ALL_MODE:
                modes = ['asr', 'tts', 's2s', 't2t']
            else:
                modes = ['asr', 'tts']
            # Randomize the task order so no task is systematically updated
            # first within a step.
            random.shuffle(modes)
            loss_data_sum = 0.0
            use_mmd = self.args.mmd_weight != 0.0
            for mode in modes:
                if mode == 'asr':
                    if self.args.asr_weight == 0.0:
                        asr_loss_data = 0  # use nan?
                        asr_acc = 0
                        continue
                    loss, asr_acc = self.model.asr_loss(
                        asr_feats,
                        asr_featlens,
                        asr_texts,
                        do_report=False,
                        report_acc=True)  # disable reporter
                    if not NO_AVG:
                        loss = loss / avg_textlen
                    asr_loss_data = loss.item()
                    chainer.reporter.report({'t/asr_loss': asr_loss_data})
                    chainer.reporter.report({'t/asr_acc': asr_acc})
                    loss_data_sum += asr_loss_data
                    logging.info("asr_loss_data: %f", asr_loss_data)
                    self.gradient_decent(loss, self.opts[mode])
                if mode == 'tts':
                    if self.args.tts_weight == 0.0:
                        tts_loss_data = 0  # use nan?
                        continue

                    loss = self.model.tts_loss(tts_texts,
                                               tts_textlens,
                                               tts_feats,
                                               tts_labels,
                                               tts_featlens,
                                               spembs=spembs,
                                               do_report=False)
                    tts_loss_data = loss.item()
                    chainer.reporter.report({'t/tts_loss': tts_loss_data})
                    loss_data_sum += tts_loss_data
                    logging.info("tts_loss_data: %f", tts_loss_data)
                    self.gradient_decent(loss, self.opts[mode])
                if mode == 's2s':
                    # Runs even with zero s2s weight when the MMD loss needs
                    # the hidden states / in-out tensors produced here.
                    if self.args.s2s_weight == 0.0 and not use_mmd:
                        s2s_loss_data = 0.0
                        continue
                    result = self.model.ae_speech(
                        data,
                        return_hidden=True,
                        return_inout=self.args.use_inout_mmd)
                    if self.args.use_inout_mmd:
                        loss, speech_in, speech_in_len, speech_out, speech_out_len = result
                    else:
                        loss, hspad, hslen = result
                    s2s_loss_data = loss.item()
                    chainer.reporter.report({'t/s2s_loss': s2s_loss_data})
                    loss_data_sum += s2s_loss_data
                    logging.info("s2s_loss_data: %f", s2s_loss_data)
                    if self.args.s2s_weight != 0.0:
                        # retain_graph so the MMD backward below can reuse
                        # this autoencoder graph.
                        self.gradient_decent(loss,
                                             self.opts[mode],
                                             freeze_att=FREEZE_ATT,
                                             retain_graph=use_mmd)
                if mode == 't2t':
                    if self.args.t2t_weight == 0.0 and not use_mmd:
                        t2t_loss_data = 0.0
                        continue
                    result = self.model.ae_text(
                        data,
                        return_hidden=True,
                        return_inout=self.args.use_inout_mmd)
                    if self.args.use_inout_mmd:
                        loss, t2t_acc, text_in, text_in_len, text_out, text_out_len = result
                    else:
                        loss, t2t_acc, htpad, htlen = result
                    if not NO_AVG:
                        loss = loss / avg_textlen
                    t2t_loss_data = loss.item()
                    chainer.reporter.report({'t/t2t_loss': t2t_loss_data})
                    chainer.reporter.report({'t/t2t_acc': t2t_acc})
                    loss_data_sum += t2t_loss_data
                    logging.info("t2t_loss_data: %f", t2t_loss_data)
                    if self.args.t2t_weight != 0.0:
                        self.gradient_decent(loss,
                                             self.opts[mode],
                                             freeze_att=FREEZE_ATT,
                                             retain_graph=use_mmd)
                logging.info("loss_data_sum: %f", loss_data_sum)

            if use_mmd:
                # NOTE(review): the tensors consumed below are only defined
                # when the s2s/t2t branches ran, i.e. when ALL_MODE is set;
                # use_mmd without ALL_MODE would raise NameError -- confirm
                # that configuration is excluded upstream.
                if self.args.use_inout_mmd:
                    logging.warning((speech_in.shape, speech_out.shape,
                                     speech_in_len, speech_out_len))
                    speech_loss = packed_mmd(speech_in, speech_in_len,
                                             speech_out, speech_out_len)
                    chainer.reporter.report({'t/ps_mmd': speech_loss.item()})
                    text_loss = packed_mmd(text_in, text_in_len, text_out,
                                           text_out_len)
                    chainer.reporter.report({'t/pt_mmd': text_loss.item()})
                    loss = (speech_loss + text_loss) / 2.0
                    chainer.reporter.report({'t/p_mmd': loss.item()})
                else:
                    if self.args.inter_domain_loss == "kl":
                        loss_fun = packed_gauss_kld
                    else:
                        loss_fun = packed_mmd
                    loss = loss_fun(hspad, hslen, htpad, htlen)
                    chainer.reporter.report({'t/mmd_loss': loss.item()})
                self.gradient_decent(loss,
                                     self.opts["mmd"],
                                     freeze_att=FREEZE_ATT)
                loss_data = (loss_data_sum + loss.item()) / 5.0
            elif ALL_MODE:
                loss_data = loss_data_sum / 4.0
            else:
                loss_data = loss_data_sum / 2.0

            logging.info("loss_data: %f", loss_data)
            chainer.reporter.report({'t/loss': loss_data})
        elif data[0][1]['utt2mode'] == 'a':
            logging.info("audio only mode")
            result = self.model.ae_speech(data,
                                          return_hidden=True,
                                          return_inout=self.args.use_inout_mmd)
            if self.args.use_inout_mmd:
                s2s_loss, speech_in, speech_in_len, speech_out, speech_out_len = result
                speech_loss = packed_mmd(speech_in, speech_in_len, speech_out,
                                         speech_out_len)
                chainer.reporter.report({'t/us_mmd': speech_loss.item()})
                loss = (s2s_loss + speech_loss) / 2.0
            else:
                s2s_loss, hspad, hslen = result
                loss = s2s_loss
            self.gradient_decent(loss, self.opts['s2s'], freeze_att=FREEZE_ATT)

            loss_data = s2s_loss.item()
            logging.info("loss: %f", loss_data)
            chainer.reporter.report({'t/s2s_loss': loss_data})
        elif data[0][1]['utt2mode'] == 't':
            logging.info("text only mode")
            result = self.model.ae_text(data,
                                        return_hidden=True,
                                        return_inout=self.args.use_inout_mmd)
            if self.args.use_inout_mmd:
                t2t_loss, t2t_acc, text_in, text_in_len, text_out, text_out_len = result
                text_loss = packed_mmd(text_in, text_in_len, text_out,
                                       text_out_len)
                chainer.reporter.report({'t/ut_mmd': text_loss.item()})
                loss = (t2t_loss + text_loss) / 2.0
            else:
                t2t_loss, t2t_acc, htpad, htlen = result
                loss = t2t_loss

            # BUG FIX: the original rebuilt ``loss`` from ``t2t_loss`` here,
            # silently discarding the in/out MMD term combined above when
            # use_inout_mmd was set (the audio-only branch keeps its
            # combined loss).  Normalize the already-combined loss instead.
            if not NO_AVG:
                loss = loss / avg_textlen
            self.gradient_decent(loss, self.opts['t2t'], freeze_att=FREEZE_ATT)

            loss_data = t2t_loss.item()
            logging.info("loss: %f", loss_data)
            chainer.reporter.report({'t/t2t_loss': loss_data})
            chainer.reporter.report({'t/t2t_acc': t2t_acc})
        else:
            logging.error("Error: cannot find correct mode ('p', 'a', 't')")
            sys.exit()

        # Release the feature arrays attached to the json entries.
        delete_feat(data)
def update_core_impl_autoencode(self, speech_batch, text_batch):
        """Run one joint autoencoder update from a speech batch and a text batch.

        ``speech_batch`` is expected to carry ``utt2mode == 'a'`` utterances
        and ``text_batch`` ``utt2mode == 't'`` ones; any other mode aborts
        the process.  Both autoencoder losses plus an inter-domain loss
        (MMD or Gaussian KLD, per ``args.inter_domain_loss``) are combined
        into one weighted sum, back-propagated once, and then both the
        's2s' and 't2t' optimizers step on the accumulated gradients.
        """
        # read scp files
        # x: original json with loaded features
        #    will be converted to chainer variable later
        # batch only has one minibatch utterance, which is specified by batch[0]
        s_data = self.converter(
            speech_batch,
            use_speaker_embedding=self.model.tts_loss.model.spk_embed_dim)
        t_data = self.converter(
            text_batch,
            use_speaker_embedding=self.model.tts_loss.model.spk_embed_dim)
        for data in (s_data, t_data):
            # The ASR tensors are prepared here but not consumed in this
            # method; only the TTS text lengths feed avg_textlen below.
            asr_texts, asr_feats, asr_featlens = get_asr_torch(
                self.model, data)
            tts_texts, tts_textlens, tts_feats, tts_labels, tts_featlens, spembs = \
                get_tts_data(self.model, data, 'text', use_speaker_embedding=self.model.tts_loss.model.spk_embed_dim)
            # NOTE(review): avg_textlen is recomputed per batch but never
            # used in this method -- confirm it is intentionally unused.
            avg_textlen = float(np.mean(tts_textlens.data.cpu().numpy()))
            if data[0][1]['utt2mode'] == 'a':
                logging.info("audio only mode")
                result = self.model.ae_speech(
                    data,
                    return_hidden=True,
                    return_inout=self.args.use_inout_mmd)
                # NOTE(review): only the 3-tuple unpacking is handled, yet
                # return_inout is forwarded; this path assumes
                # use_inout_mmd is False here (cf. the 5-tuple returned in
                # update_core_impl) -- confirm.
                s2s_loss, hspad, hslen = result
                loss_data = s2s_loss.item()
                logging.info("loss: %f", loss_data)
                chainer.reporter.report({'t/s2s_loss': loss_data})
            elif data[0][1]['utt2mode'] == 't':
                logging.info("text only mode")
                # t2t_loss, t2t_acc = self.model.ae_text(data)
                result = self.model.ae_text(
                    data,
                    return_hidden=True,
                    return_inout=self.args.use_inout_mmd)
                t2t_loss, t2t_acc, htpad, htlen = result
                loss_data = t2t_loss.item()
                logging.info("loss: %f", loss_data)
                chainer.reporter.report({'t/t2t_loss': loss_data})
                chainer.reporter.report({'t/t2t_acc': t2t_acc})
            else:
                logging.error("Error: cannot find correct mode ('a', 't')")
                sys.exit()

        # The loop-carried names (s2s_loss, hspad, hslen, t2t_loss, htpad,
        # htlen) are only all bound when one batch was audio and the other
        # text; a mismatched pairing would raise NameError below.
        self.opts["s2s"].zero_grad()
        self.opts["t2t"].zero_grad()
        if self.args.inter_domain_loss == "kl":
            domain_loss_fun = packed_gauss_kld
        else:
            domain_loss_fun = packed_mmd
        mmd_loss = domain_loss_fun(hspad, hslen, htpad, htlen)
        chainer.reporter.report({'t/ae_mmd_loss': mmd_loss.item()})
        # One backward over the weighted sum; the gradient_decent calls that
        # follow only clip/step (backward=False) on the shared gradients.
        loss = self.args.s2s_weight * s2s_loss + self.args.t2t_weight * t2t_loss + self.args.mmd_weight * mmd_loss
        loss.backward()
        self.gradient_decent(loss,
                             self.opts['s2s'],
                             freeze_att=FREEZE_ATT,
                             backward=False)
        self.gradient_decent(loss,
                             self.opts['t2t'],
                             freeze_att=FREEZE_ATT,
                             backward=False)
        # Release the feature arrays attached to both batches.
        delete_feat(s_data)
        delete_feat(t_data)
def get_tts_torch(model, data):
    """Return the TTS training tensors for ``data`` via :func:`get_tts_data`.

    Thin convenience wrapper that fixes the text field to ``'text'`` and
    forwards the model's speaker-embedding dimension.
    """
    spk_embed_dim = model.tts_loss.model.spk_embed_dim
    return get_tts_data(model, data, 'text',
                        use_speaker_embedding=spk_embed_dim)
    def evaluate(self):
        """Run validation over the 'main' iterator and return mean metrics.

        Only parallel ('p') utterances are supported; for each batch the
        four losses (ASR, TTS, speech AE, text AE) are computed with
        autograd disabled, averaged, and reported under ``d/*`` keys.

        Returns:
            dict: mean of every reported metric over the whole iterator
            (``summary.compute_mean()``).
        """
        iterator = self._iterators['main']

        if self.eval_hook:
            self.eval_hook(self)

        if hasattr(iterator, 'reset'):
            iterator.reset()
            it = iterator
        else:
            it = copy.copy(iterator)

        summary = reporter_module.DictSummary()

        if not torch_is_old:
            torch.set_grad_enabled(False)
        try:
            for batch in it:
                observation = {}
                with reporter_module.report_scope(observation):
                    # read scp files
                    # x: original json with loaded features
                    #    will be converted to chainer variable later
                    data = self.converter(batch,
                                          use_speaker_embedding=self.model.
                                          tts_loss.model.spk_embed_dim)
                    self.model.eval()
                    if data[0][1]['utt2mode'] != 'p':
                        logging.error(
                            "Error: evaluation only support a parallel data mode ('p')"
                        )
                        sys.exit()

                    asr_texts, asr_feats, asr_featlens = get_asr_torch(
                        self.model, data)
                    asr_loss, asr_acc = self.model.asr_loss(
                        asr_feats,
                        asr_featlens,
                        asr_texts,
                        do_report=False,
                        report_acc=True)  # disable reporter

                    tts_texts, tts_textlens, tts_feats, tts_labels, tts_featlens, spembs = \
                        get_tts_data(self.model, data, 'text', use_speaker_embedding=self.model.tts_loss.model.spk_embed_dim)
                    avg_textlen = float(np.mean(tts_textlens.data.cpu().numpy()))
                    tts_loss = self.model.tts_loss(tts_texts,
                                                   tts_textlens,
                                                   tts_feats,
                                                   tts_labels,
                                                   tts_featlens,
                                                   spembs=spembs,
                                                   do_report=False)
                    s2s_loss = self.model.ae_speech(data)
                    t2t_loss, t2t_acc = self.model.ae_text(data)

                    # Average loss for all four networks; token-level losses
                    # (ASR, T2T) are length-normalized unless NO_AVG.
                    if NO_AVG:
                        loss = (asr_loss + tts_loss + s2s_loss + t2t_loss) / 4.0
                    else:
                        loss = (asr_loss / avg_textlen + tts_loss + s2s_loss +
                                t2t_loss / avg_textlen) / 4.0
                    # The original guarded every .item() with
                    # `x.item() if torch_is_old else x.item()` -- both
                    # branches were identical, so the dead ternaries are
                    # dropped here.
                    loss_data = loss.item()
                    asr_loss_data = asr_loss.item()
                    t2t_loss_data = t2t_loss.item()
                    if not NO_AVG:
                        asr_loss_data /= avg_textlen
                        t2t_loss_data /= avg_textlen
                    tts_loss_data = tts_loss.item()
                    s2s_loss_data = s2s_loss.item()

                    chainer.reporter.report({'d/loss': loss_data})
                    chainer.reporter.report({'d/asr_loss': asr_loss_data})
                    chainer.reporter.report({'d/tts_loss': tts_loss_data})
                    chainer.reporter.report({'d/asr_acc': asr_acc})
                    chainer.reporter.report({'d/s2s_loss': s2s_loss_data})
                    chainer.reporter.report({'d/t2t_loss': t2t_loss_data})
                    chainer.reporter.report({'d/t2t_acc': t2t_acc})

                    delete_feat(data)

                summary.add(observation)
        finally:
            # Re-enable autograd even if evaluation raised, so training
            # can continue with gradients on.
            if not torch_is_old:
                torch.set_grad_enabled(True)
        self.model.train()

        return summary.compute_mean()
    def update_core(self):
        """Fetch the next minibatch from the 'main' iterator and run one
        training step.

        Dispatches on the utterance flag ``utt2mode``: 'p' updates ASR and
        TTS (plus both autoencoders when ``ALL_MODE``) in a random order,
        optionally followed by an MMD update coupling the two hidden
        spaces; 'a' updates the speech autoencoder; 't' the text
        autoencoder.  Any other mode terminates the process.
        """
        # When we pass one iterator and optimizer to StandardUpdater.__init__,
        # they are automatically named 'main'.
        train_iter = self.get_iterator('main')

        # Get the next batch ( a list of json files)
        batch = train_iter.__next__()

        # read scp files
        # x: original json with loaded features
        #    will be converted to chainer variable later
        # batch only has one minibatch utterance, which is specified by batch[0]
        if len(batch[0]) < self.num_gpu:
            logging.warning('batch size is less than number of gpus. Ignored')
            return
        data = self.converter(
            batch,
            use_speaker_embedding=self.model.tts_loss.model.spk_embed_dim)

        # Compute the loss at this time step and accumulate it
        asr_texts, asr_feats, asr_featlens = get_asr_torch(self.model, data)
        tts_texts, tts_textlens, tts_feats, tts_labels, tts_featlens, spembs = \
            get_tts_data(self.model, data, 'text', use_speaker_embedding=self.model.tts_loss.model.spk_embed_dim)

        # Mean target text length; normalizes token-level losses (ASR, T2T)
        # unless NO_AVG is set.
        avg_textlen = float(np.mean(tts_textlens.data.cpu().numpy()))
        if data[0][1]['utt2mode'] == 'p':
            logging.info("parallel data mode")
            if ALL_MODE:
                modes = ['asr', 'tts', 's2s', 't2t']
            else:
                modes = ['asr', 'tts']
            random.shuffle(modes)
            # shuffle: randomize the per-step task update order
            loss_data_sum = 0.0
            use_mmd = self.args.mmd_weight != 0.0
            for mode in modes:
                if mode == 'asr':
                    if self.args.asr_weight == 0.0:
                        asr_loss_data = 0  # use nan?
                        asr_acc = 0
                        continue
                    loss, asr_acc = self.model.asr_loss(
                        asr_feats,
                        asr_featlens,
                        asr_texts,
                        do_report=False,
                        report_acc=True)  # disable reporter
                    if NO_AVG:
                        pass
                    else:
                        loss = loss / avg_textlen
                    asr_loss_data = loss.item()
                    chainer.reporter.report({'t/asr_loss': asr_loss_data})
                    chainer.reporter.report({'t/asr_acc': asr_acc})
                    loss_data_sum += asr_loss_data
                    logging.info("asr_loss_data: %f", asr_loss_data)
                    self.gradient_decent(loss, self.opts[mode])
                if mode == 'tts':
                    if self.args.tts_weight == 0.0:
                        tts_loss_data = 0  # use nan?
                        continue

                    loss = self.model.tts_loss(tts_texts,
                                               tts_textlens,
                                               tts_feats,
                                               tts_labels,
                                               tts_featlens,
                                               spembs=spembs,
                                               do_report=False)
                    tts_loss_data = loss.item()
                    chainer.reporter.report({'t/tts_loss': tts_loss_data})
                    loss_data_sum += tts_loss_data
                    logging.info("tts_loss_data: %f", tts_loss_data)
                    self.gradient_decent(loss, self.opts[mode])
                if mode == 's2s':
                    # Runs even with zero s2s weight when the MMD loss needs
                    # the hidden states (hspad/hslen) produced here.
                    if self.args.s2s_weight == 0.0 and not use_mmd:
                        s2s_loss_data = 0.0
                        continue
                    loss, hspad, hslen = self.model.ae_speech(
                        data, return_hidden=True)
                    s2s_loss_data = loss.item()
                    chainer.reporter.report({'t/s2s_loss': s2s_loss_data})
                    loss_data_sum += s2s_loss_data
                    logging.info("s2s_loss_data: %f", s2s_loss_data)
                    if self.args.s2s_weight != 0.0:
                        # retain_graph so the MMD backward below can reuse
                        # this autoencoder graph.
                        self.gradient_decent(loss,
                                             self.opts[mode],
                                             freeze_att=FREEZE_ATT,
                                             retain_graph=use_mmd)
                if mode == 't2t':
                    if self.args.t2t_weight == 0.0 and not use_mmd:
                        t2t_loss_data = 0.0
                        continue
                    loss, t2t_acc, htpad, htlen = self.model.ae_text(
                        data, return_hidden=True)
                    if NO_AVG:
                        pass
                    else:
                        loss = loss / avg_textlen
                    t2t_loss_data = loss.item()
                    chainer.reporter.report({'t/t2t_loss': t2t_loss_data})
                    chainer.reporter.report({'t/t2t_acc': t2t_acc})
                    loss_data_sum += t2t_loss_data
                    logging.info("t2t_loss_data: %f", t2t_loss_data)
                    if self.args.t2t_weight != 0.0:
                        self.gradient_decent(loss,
                                             self.opts[mode],
                                             freeze_att=FREEZE_ATT,
                                             retain_graph=use_mmd)
                logging.info("loss_data_sum: %f", loss_data_sum)

            if use_mmd:
                # NOTE(review): hspad/htpad are only bound when the s2s/t2t
                # branches ran, i.e. when ALL_MODE is set; use_mmd without
                # ALL_MODE would raise NameError here -- confirm excluded
                # upstream.
                loss = packed_mmd(hspad, hslen, htpad, htlen)
                chainer.reporter.report({'t/mmd_loss': loss.item()})
                self.gradient_decent(loss,
                                     self.opts["mmd"],
                                     freeze_att=FREEZE_ATT)
                loss_data = (loss_data_sum + loss.item()) / 5.0
            elif ALL_MODE:
                loss_data = loss_data_sum / 4.0
            else:
                loss_data = loss_data_sum / 2.0

            logging.info("loss_data: %f", loss_data)
            chainer.reporter.report({'t/loss': loss_data})
        elif data[0][1]['utt2mode'] == 'a':
            logging.info("audio only mode")
            s2s_loss = self.model.ae_speech(data)
            loss = s2s_loss
            self.gradient_decent(loss, self.opts['s2s'], freeze_att=FREEZE_ATT)

            loss_data = loss.item()
            logging.info("loss: %f", loss_data)
            chainer.reporter.report({'t/loss': loss_data})
            chainer.reporter.report({'t/s2s_loss': loss_data})
        elif data[0][1]['utt2mode'] == 't':
            logging.info("text only mode")
            t2t_loss, t2t_acc = self.model.ae_text(data)
            if NO_AVG:
                loss = t2t_loss
            else:
                loss = t2t_loss / avg_textlen
            self.gradient_decent(loss, self.opts['t2t'], freeze_att=FREEZE_ATT)

            loss_data = loss.item()
            logging.info("loss: %f", loss_data)
            chainer.reporter.report({'t/loss': loss_data})
            chainer.reporter.report({'t/t2t_loss': loss_data})
            chainer.reporter.report({'t/t2t_acc': t2t_acc})
        else:
            logging.error("Error: cannot find correct mode ('p', 'a', 't')")
            sys.exit()

        # Release the feature arrays attached to the json entries.
        delete_feat(data)