def iter_asr(set_name, set_train_mode, set_rr):
            """Run one paired speech->text (ASR) batch and accumulate metrics.

            Looks up the features and transcripts for the given batch
            indices, runs the ASR batch function (train or eval according to
            ``set_train_mode``), and adds the count-weighted loss/accuracy
            into the ASR accumulators for ``set_name``.
            """
            assert set_name in ['train', 'dev', 'test']
            # Sort batch indices by feature length (via sort_reverse) first.
            sorted_idx = sort_reverse(set_rr, feat_len[set_name])
            batch_keys = feat_iterator[set_name].get_key_by_index(sorted_idx)
            feat_list = feat_iterator[set_name].get_feat_by_key(batch_keys)
            text_list = text_iterator[set_name].get_text_by_key(batch_keys)

            (feat_mat, feat_lens,
             text_mat, text_lens) = batch_speech_text(opts['gpu'], feat_list,
                                                      text_list)

            batch_loss, batch_acc = fn_batch_asr(
                model_asr,
                feat_mat,
                feat_lens,
                text_mat,
                text_lens,
                train_step=set_train_mode,
                coeff_loss=opts['coeff_pair'])
            # Undo the pairing coefficient so the reported loss is unscaled.
            batch_loss /= opts['coeff_pair']
            assert_nan(batch_loss)
            n_items = len(sorted_idx)
            m_asr_loss[set_name] += batch_loss * n_items
            m_asr_acc[set_name] += batch_acc * n_items
            m_asr_count[set_name] += n_items
            if tf_writer is not None:
                auto_writer_info_asr(set_name, batch_loss, batch_acc)
        def iter_tts(set_name, set_train_mode, set_rr):
            """Run one paired text->speech (TTS) batch and accumulate metrics.

            Looks up the transcripts and features for the given batch
            indices, runs the TTS batch function (train or eval according to
            ``set_train_mode``), and adds the count-weighted losses/accuracy
            into the TTS accumulators for ``set_name``.
            """
            assert set_name in ['train', 'dev', 'test']
            # Sort batch indices by text length (via sort_reverse) first.
            sorted_idx = sort_reverse(set_rr, text_len[set_name])
            batch_keys = text_iterator[set_name].get_key_by_index(sorted_idx)
            feat_list = feat_iterator[set_name].get_feat_by_key(batch_keys)
            text_list = text_iterator[set_name].get_text_by_key(batch_keys)
            # Multi-speaker Tacotron additionally needs speaker vectors.
            aux = None
            if model_tts.TYPE == TacotronType.MULTI_SPEAKER:
                aux = {
                    'speaker_vector':
                    feat_spkvec_iterator.get_feat_by_key(batch_keys)
                }
            (feat_mat, feat_lens,
             text_mat, text_lens) = batch_speech_text(
                 opts['gpu'],
                 feat_list,
                 text_list,
                 feat_sil=feat_sil,
                 group=opts['tts_group'],
                 start_sil=1,
                 end_sil=opts['tts_pad_sil'])
            (batch_loss, loss_feat, loss_bce_fend, loss_spk_emb,
             acc_fend) = fn_batch_tts(
                 model_tts,
                 text_mat,
                 text_lens,
                 feat_mat,
                 feat_lens,
                 aux_info=aux,
                 train_step=set_train_mode,
                 coeff_loss=opts['coeff_pair'])
            # Undo the pairing coefficient so the reported loss is unscaled.
            batch_loss /= opts['coeff_pair']
            assert_nan(batch_loss)
            n_items = len(sorted_idx)
            m_tts_loss[set_name] += batch_loss * n_items
            m_tts_loss_feat[set_name] += loss_feat * n_items
            m_tts_loss_bce[set_name] += loss_bce_fend * n_items
            m_tts_loss_spk_emb[set_name] += loss_spk_emb * n_items
            m_tts_acc[set_name] += acc_fend * n_items
            m_tts_count[set_name] += n_items

            if tf_writer is not None:
                auto_writer_info_tts(set_name, batch_loss, loss_feat,
                                     loss_bce_fend, loss_spk_emb, acc_fend)
# Beispiel #3 (German: "Example #3") — scraping/concatenation artifact that
# separated code snippets; the fragment that follows is incomplete (its
# enclosing function definition is missing).
                    curr_label_list,
                    feat_sil=feat_sil,
                    group=opts['group'],
                    start_sil=1,
                    end_sil=opts['pad_sil'])
                # print(2, timeit.default_timer() - tic); tic = timeit.default_timer()
                _tmp_loss, _tmp_loss_feat, _tmp_loss_bernend, _tmp_acc_bernend = fn_batch(
                    text_mat,
                    text_len,
                    feat_mat,
                    feat_len,
                    aux_info=aux_info,
                    train_step=set_train_mode)
                # print(3, timeit.default_timer() - tic); tic = timeit.default_timer()
                _tmp_count = len(rr)
                assert_nan(_tmp_loss)
                mloss[set_name] += _tmp_loss * _tmp_count
                mloss_feat[set_name] += _tmp_loss_feat * _tmp_count
                mloss_bernend[set_name] += _tmp_loss_bernend * _tmp_count
                macc_bernend[set_name] += _tmp_acc_bernend * _tmp_count
                mcount[set_name] += _tmp_count
            pass

        info_header = [
            'set', 'loss', 'loss feat', 'loss bern end', 'acc bern end'
        ]
        info_table = []
        logger.info("Epoch %d -- lrate %f -- time %.2fs" %
                    (ee, opt.param_groups[0]['lr'], time.time() - start_time))
        for set_name in mloss.keys():
            mloss[set_name] /= mcount[set_name]
                                                  feat_sil=feat_sil,
                                                  group=1,
                                                  start_sil=1,
                                                  end_sil=1)

                curr_key_list = feat_iterator[set_name].get_key_by_index(rr)
                curr_spk_list = [map_key2spk[x] for x in curr_key_list]

                _tmp_loss_ce, _tmp_acc = fn_batch_ce(
                    feat_mat,
                    feat_len,
                    speaker_list=curr_spk_list,
                    train_step=set_train_mode)
                _tmp_count = len(rr)
                # TODO : include margin loss later #
                assert_nan(_tmp_loss_ce)
                mloss_ce[set_name] += _tmp_loss_ce * _tmp_count
                mloss[set_name] += _tmp_loss_ce * _tmp_count
                macc_spk[set_name] += _tmp_acc * _tmp_count
                mcount[set_name] += _tmp_count
            pass

        info_header = ['set', 'loss', 'loss ce', 'loss margin', 'acc spk']
        info_table = []
        logger.info("Epoch %d -- lrate %f -- time %.2fs" %
                    (ee, opts['lrate'], time.time() - start_time))
        for set_name in mloss.keys():
            mloss[set_name] /= mcount[set_name]
            mloss_ce[set_name] /= mcount[set_name]
            mloss_margin[set_name] /= mcount[set_name]
            macc_spk[set_name] /= mcount[set_name]
        def iter_cycle_asr2tts(set_name, set_train_mode, set_rr):
            """Unpaired cycle step: speech -> ASR hypothesis text -> TTS loss.

            Decodes text from a speech batch with the ASR model (greedy or
            beam search), filters out hypotheses judged invalid, then runs
            the TTS model to reconstruct the original features from the
            generated text, accumulating the count-weighted TTS metric
            totals for ``set_name``.
            """
            rr = set_rr
            # Sort batch indices by feature length (the sort_reverse name and
            # the sorted(..., reverse=True) check elsewhere suggest descending
            # order — confirm against sort_reverse's definition).
            rr = sort_reverse(rr, feat_len[set_name])
            rr_key = feat_iterator[set_name].get_key_by_index(rr)
            curr_feat_list = feat_iterator[set_name].get_feat_by_key(rr_key)
            # Batch speech with 1 leading and opts['tts_pad_sil'] trailing
            # silence frames, grouped per opts['tts_group'] for the TTS side.
            curr_feat_mat, curr_feat_len = batch_speech(
                opts['gpu'],
                curr_feat_list,
                feat_sil=feat_sil,
                group=opts['tts_group'],
                start_sil=1,
                end_sil=opts['tts_pad_sil'])
            # modified feature for ASR #
            # Strip the silence padding added above and undo the grouping so
            # the ASR model sees per-frame features of width NDIM_FEAT.
            curr_feat_mat_for_asr = curr_feat_mat[:, 1:-opts[
                'tts_pad_sil']].contiguous().view(len(set_rr), -1, NDIM_FEAT)
            # Lengths taken from the raw (pre-padding) feature lists.
            curr_feat_len_for_asr = [len(x) for x in curr_feat_list]

            if opts['asr_gen_search']['type'] == 'greedy':
                curr_pred_text_list, curr_pred_text_len, curr_pred_att_mat = generator_text.greedy_search(
                    model_asr,
                    curr_feat_mat_for_asr,
                    curr_feat_len_for_asr,
                    map_text2idx=map_text2idx,
                    max_target=opts['asr_gen_cutoff'])
            elif opts['asr_gen_search']['type'] == 'beam':
                # Beam search is decoded in chunks of
                # opts['asr_gen_search']['chunk'] utterances at a time.
                curr_pred_text_list, curr_pred_text_len = [], []
                for ii in range(0, len(rr), opts['asr_gen_search']['chunk']):
                    _start_ii = ii
                    _end_ii = min(ii + opts['asr_gen_search']['chunk'],
                                  len(rr))
                    curr_pred_text_list_ii, curr_pred_text_len_ii, _ = generator_text.beam_search(
                        model_asr,
                        curr_feat_mat_for_asr[_start_ii:_end_ii],
                        curr_feat_len_for_asr[_start_ii:_end_ii],
                        map_text2idx=map_text2idx,
                        max_target=opts['asr_gen_cutoff'],
                        kbeam=opts['asr_gen_search']['kbeam'])
                    curr_pred_text_list.extend(curr_pred_text_list_ii)
                    curr_pred_text_len.extend(curr_pred_text_len_ii)
            # NOTE(review): any other search type silently falls through and
            # leaves curr_pred_text_list undefined (NameError below) —
            # consider an explicit raise for unknown types.
            if model_tts.TYPE == TacotronType.MULTI_SPEAKER:
                curr_spkvec_list = feat_spkvec_iterator.get_feat_by_key(rr_key)

            # TODO: filter bad text #
            # Keep only hypotheses the quality heuristic marks valid (== 1);
            # only the length list is passed, so presumably a length-based
            # check — confirm against eval_gen_text_quality.
            curr_pred_quality = generator_text.eval_gen_text_quality(
                None, curr_pred_text_len, None)
            curr_pred_valid_idx = [
                x for x, y in enumerate(curr_pred_quality) if y == 1
            ]

            m_asr_gen_info['total'] += len(rr)
            m_asr_gen_info['valid'] += len(curr_pred_valid_idx)
            if len(curr_pred_valid_idx) == 0:
                # Nothing usable was generated for this batch; skip TTS step.
                return None

            # Drop invalid hypotheses from every parallel structure.
            curr_pred_text_list = batch_select(curr_pred_text_list,
                                               curr_pred_valid_idx)
            if model_tts.TYPE == TacotronType.MULTI_SPEAKER:
                curr_spkvec_list = batch_select(curr_spkvec_list,
                                                curr_pred_valid_idx)
            curr_pred_text_len = batch_select(curr_pred_text_len,
                                              curr_pred_valid_idx)

            curr_feat_mat = batch_select(curr_feat_mat, curr_pred_valid_idx)
            curr_feat_len = batch_select(curr_feat_len, curr_pred_valid_idx)

            # zip & sort dec #
            # Re-sort all parallel structures by predicted text length; the
            # key list itself must be re-sorted last, otherwise the key order
            # would change under the remaining sorts.
            curr_pred_text_list = batch_sorter(curr_pred_text_list,
                                               curr_pred_text_len)
            if model_tts.TYPE == TacotronType.MULTI_SPEAKER:
                curr_spkvec_list = batch_sorter(curr_spkvec_list,
                                                curr_pred_text_len)
            curr_feat_mat = batch_sorter(curr_feat_mat, curr_pred_text_len)
            curr_feat_len = batch_sorter(curr_feat_len, curr_pred_text_len)
            curr_pred_text_len = batch_sorter(
                curr_pred_text_len,
                curr_pred_text_len)  # sort key must be on the last step

            curr_pred_text_mat, curr_pred_text_len = batch_text(
                opts['gpu'],
                curr_pred_text_list,
            )

            if model_tts.TYPE == TacotronType.MULTI_SPEAKER:
                curr_aux_info = {'speaker_vector': curr_spkvec_list}
            else:
                curr_aux_info = None
            # TTS forward/backward on the generated text, scaled by the
            # unpaired-loss coefficient (undone below for reporting).
            _loss, _loss_feat, _loss_bce_fend, _loss_spk_emb, _acc_fend = fn_batch_tts(
                model_tts,
                curr_pred_text_mat,
                curr_pred_text_len,
                curr_feat_mat,
                curr_feat_len,
                aux_info=curr_aux_info,
                train_step=set_train_mode,
                coeff_loss=opts['coeff_unpair'])
            _loss /= opts['coeff_unpair']

            assert_nan(_loss)
            # Metrics are weighted by the number of *valid* hypotheses only.
            _count = len(curr_pred_text_list)
            m_tts_loss[set_name] += _loss * _count
            m_tts_loss_feat[set_name] += _loss_feat * _count
            m_tts_loss_bce[set_name] += _loss_bce_fend * _count
            m_tts_loss_spk_emb[set_name] += _loss_spk_emb * _count
            m_tts_acc[set_name] += _acc_fend * _count
            m_tts_count[set_name] += _count

            if tf_writer is not None:
                auto_writer_info_tts(set_name, _loss, _loss_feat,
                                     _loss_bce_fend, _loss_spk_emb, _acc_fend)
        def iter_cycle_tts2asr(set_name, set_train_mode, set_rr):
            """Unpaired cycle step: text -> TTS-generated speech -> ASR loss.

            Synthesizes speech for a text batch with the TTS model, filters
            out generations judged invalid, then runs the ASR model to
            recover the original text from the synthesized speech,
            accumulating the count-weighted ASR metric totals for
            ``set_name``.
            """
            rr = set_rr
            # Sort batch indices by text length (the sort_reverse name
            # suggests descending order — confirm against its definition).
            rr = sort_reverse(rr, text_len[set_name])
            rr_key = text_iterator[set_name].get_key_by_index(rr)
            curr_text_list = text_iterator[set_name].get_text_by_key(rr_key)
            if model_tts.TYPE == TacotronType.MULTI_SPEAKER:
                # Speaker vectors: either the utterances' own speakers, or
                # keys sampled uniformly from the feature iterator.
                if opts['tts_spk_sample'] is None:
                    curr_spkvec_list = feat_spkvec_iterator.get_feat_by_key(
                        rr_key)
                elif opts['tts_spk_sample'] == 'uniform':
                    _sample_rr_key = random.sample(feat_iterator[set_name].key,
                                                   k=len(set_rr))
                    curr_spkvec_list = feat_spkvec_iterator.get_feat_by_key(
                        _sample_rr_key)
                else:
                    raise NotImplementedError()
                curr_aux_info = {'speaker_vector': curr_spkvec_list}
            else:
                curr_aux_info = None
            curr_text_mat, curr_text_len = batch_text(opts['gpu'],
                                                      curr_text_list)

            # Greedy TTS decoding; the cutoff is divided by the group size
            # since the decoder emits opts['tts_group'] frames per step.
            curr_pred_feat_list, curr_pred_feat_len, curr_pred_att_mat = generator_speech.decode_greedy_pred(
                model_tts,
                curr_text_mat,
                curr_text_len,
                group=opts['tts_group'],
                feat_sil=feat_sil,
                aux_info=curr_aux_info,
                max_target=opts['tts_gen_cutoff'] // opts['tts_group'])

            # filter bad speech #
            # NOTE(review): reuses the *text* quality heuristic on feature
            # lengths — presumably a length-only validity check; confirm.
            curr_pred_quality = generator_text.eval_gen_text_quality(
                None, curr_pred_feat_len, None)
            curr_pred_valid_idx = [
                x for x, y in enumerate(curr_pred_quality) if y == 1
            ]

            m_tts_gen_info['total'] += len(rr)
            m_tts_gen_info['valid'] += len(curr_pred_valid_idx)
            if len(curr_pred_valid_idx) == 0:
                # No usable synthesized speech; skip the ASR step.
                return None

            # Drop invalid generations from every parallel structure.
            curr_pred_feat_list = batch_select(curr_pred_feat_list,
                                               curr_pred_valid_idx)
            curr_pred_feat_len = batch_select(curr_pred_feat_len,
                                              curr_pred_valid_idx)

            curr_text_mat = batch_select(curr_text_mat, curr_pred_valid_idx)
            curr_text_len = batch_select(curr_text_len, curr_pred_valid_idx)

            # zip & sort dec #
            # Re-sort all parallel structures by predicted feature length;
            # the key list itself must be re-sorted last, otherwise the key
            # order would change under the remaining sorts.
            curr_pred_feat_list = batch_sorter(curr_pred_feat_list,
                                               curr_pred_feat_len)
            curr_text_mat = batch_sorter(curr_text_mat, curr_pred_feat_len)
            curr_text_len = batch_sorter(curr_text_len, curr_pred_feat_len)
            curr_pred_feat_len = batch_sorter(
                curr_pred_feat_len,
                curr_pred_feat_len)  # sort key must be on the last step

            curr_pred_feat_mat, curr_pred_feat_len = batch_speech(
                opts['gpu'], curr_pred_feat_list)
            # if sorted(curr_pred_feat_len, reverse=True) != curr_pred_feat_len :
            # import ipdb; ipdb.set_trace()
            # ASR forward/backward on the synthesized speech, scaled by the
            # unpaired-loss coefficient (undone below for reporting).
            _loss, _acc = fn_batch_asr(model_asr,
                                       curr_pred_feat_mat,
                                       curr_pred_feat_len,
                                       curr_text_mat,
                                       curr_text_len,
                                       train_step=set_train_mode,
                                       coeff_loss=opts['coeff_unpair'])
            _loss /= opts['coeff_unpair']
            assert_nan(_loss)
            # NOTE(review): _count is the full batch size len(rr), while the
            # asr2tts direction counts only valid items — confirm intended.
            _count = len(rr)
            m_asr_loss[set_name] += _loss * _count
            m_asr_acc[set_name] += _acc * _count
            m_asr_count[set_name] += _count
            if tf_writer is not None:
                auto_writer_info_asr(set_name, _loss, _acc)