Example No. 1
def main():

    # Load configuration
    args, dir_name = parse_args_eval(sys.argv[1:])

    # Set up logging
    if os.path.isfile(os.path.join(args.recog_dir, 'plot.log')):
        os.remove(os.path.join(args.recog_dir, 'plot.log'))
    set_logger(os.path.join(args.recog_dir, 'plot.log'),
               stdout=args.recog_stdout)

    for i, s in enumerate(args.recog_sets):
        # Load dataloader
        dataloader = build_dataloader(
            args=args,
            tsv_path=s,
            batch_size=1,
            is_test=True,
            first_n_utterances=args.recog_first_n_utt,
            longform_max_n_frames=args.recog_longform_max_n_frames)

        if i == 0:
            # Load ASR model
            model = Speech2Text(args, dir_name)
            epoch = int(float(args.recog_model[0].split('-')[-1]) * 10) / 10
            if args.recog_n_average > 1:
                # Model averaging for Transformer
                model = average_checkpoints(model,
                                            args.recog_model[0],
                                            n_average=args.recog_n_average)
            else:
                load_checkpoint(args.recog_model[0], model)

            if not args.recog_unit:
                args.recog_unit = args.unit

            logger.info('recog unit: %s' % args.recog_unit)
            logger.info('epoch: %d' % epoch)
            logger.info('batch size: %d' % args.recog_batch_size)

            # GPU setting
            if args.recog_n_gpus >= 1:
                model.cudnn_setting(deterministic=True, benchmark=False)
                model.cuda()

        save_path = mkdir_join(args.recog_dir, 'ctc_probs')

        # Clean directory
        if save_path is not None and os.path.isdir(save_path):
            shutil.rmtree(save_path)
            os.mkdir(save_path)

        for batch in dataloader:
            nbest_hyps_id, _ = model.decode(batch['xs'], args,
                                            dataloader.idx2token[0])
            best_hyps_id = [h[0] for h in nbest_hyps_id]

            # Get CTC probs
            ctc_probs, topk_ids, xlens = model.get_ctc_probs(batch['xs'],
                                                             temperature=1,
                                                             topk=min(
                                                                 100,
                                                                 model.vocab))
            # NOTE: ctc_probs: '[B, T, topk]'

            for b in range(len(batch['xs'])):
                tokens = dataloader.idx2token[0](best_hyps_id[b],
                                                 return_list=True)
                spk = batch['speakers'][b]

                plot_ctc_probs(
                    ctc_probs[b, :xlens[b]],
                    topk_ids[b],
                    factor=args.subsample_factor,
                    spectrogram=batch['xs'][b][:, :dataloader.input_dim],
                    save_path=mkdir_join(save_path, spk,
                                         batch['utt_ids'][b] + '.png'),
                    figsize=(20, 8))

                hyp = ' '.join(tokens)
                logger.info('utt-id: %s' % batch['utt_ids'][b])
                logger.info('Ref: %s' % batch['text'][b].lower())
                logger.info('Hyp: %s' % hyp)
                logger.info('-' * 50)
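
The plot_ctc_probs call above is neural_sp's own plotting helper. As a rough standalone sketch of the same idea, drawing the per-frame CTC posterior curves for the top-k token ids with matplotlib, something like the following could be used. The [T, topk] layout, the 10 ms frame shift, and the function name plot_ctc_probs_sketch are assumptions made for illustration, not values taken from the library:

import matplotlib.pyplot as plt
import numpy as np


def plot_ctc_probs_sketch(ctc_probs, topk_ids, save_path, frame_shift=0.01):
    """Plot per-frame CTC posteriors for the top-k token ids.

    ctc_probs: np.ndarray of shape [T, topk] (assumed layout)
    topk_ids: iterable of the topk token ids, used as curve labels
    frame_shift: seconds per encoder frame (assumed 10 ms)
    """
    times = np.arange(ctc_probs.shape[0]) * frame_shift
    plt.figure(figsize=(20, 8))
    for k, token_id in enumerate(topk_ids):
        plt.plot(times, ctc_probs[:, k], label=str(token_id))
    plt.xlabel('Time [sec]')
    plt.ylabel('Posterior')
    plt.legend(ncol=10, fontsize=6)
    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()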
Example No. 2
def main():

    # Load configuration
    args, recog_params, dir_name = parse_args_eval(sys.argv[1:])

    # Set up logging
    if os.path.isfile(os.path.join(args.recog_dir, 'plot.log')):
        os.remove(os.path.join(args.recog_dir, 'plot.log'))
    set_logger(os.path.join(args.recog_dir, 'plot.log'),
               stdout=args.recog_stdout)

    for i, s in enumerate(args.recog_sets):
        # Load dataset
        dataset = Dataset(
            corpus=args.corpus,
            tsv_path=s,
            dict_path=os.path.join(dir_name, 'dict.txt'),
            dict_path_sub1=os.path.join(dir_name, 'dict_sub1.txt') if
            os.path.isfile(os.path.join(dir_name, 'dict_sub1.txt')) else False,
            nlsyms=args.nlsyms,
            wp_model=os.path.join(dir_name, 'wp.model'),
            unit=args.unit,
            unit_sub1=args.unit_sub1,
            batch_size=args.recog_batch_size,
            is_test=True)

        if i == 0:
            # Load the ASR model
            model = Speech2Text(args, dir_name)
            epoch = int(args.recog_model[0].split('-')[-1])
            if args.recog_n_average > 1:
                # Model averaging for Transformer
                model = average_checkpoints(model,
                                            args.recog_model[0],
                                            n_average=args.recog_n_average)
            else:
                load_checkpoint(args.recog_model[0], model)

            if not args.recog_unit:
                args.recog_unit = args.unit

            logger.info('recog unit: %s' % args.recog_unit)
            logger.info('epoch: %d' % epoch)
            logger.info('batch size: %d' % args.recog_batch_size)

            # GPU setting
            if args.recog_n_gpus >= 1:
                model.cudnn_setting(deterministic=True, benchmark=False)
                model.cuda()

        save_path = mkdir_join(args.recog_dir, 'ctc_probs')

        # Clean directory
        if save_path is not None and os.path.isdir(save_path):
            shutil.rmtree(save_path)
            os.mkdir(save_path)

        while True:
            batch, is_new_epoch = dataset.next(
                recog_params['recog_batch_size'])
            best_hyps_id, _ = model.decode(batch['xs'], recog_params)

            # Get CTC probs
            ctc_probs, topk_ids, xlens = model.get_ctc_probs(batch['xs'],
                                                             temperature=1,
                                                             topk=min(
                                                                 100,
                                                                 model.vocab))
            # NOTE: ctc_probs: '[B, T, topk]'

            for b in range(len(batch['xs'])):
                tokens = dataset.idx2token[0](best_hyps_id[b],
                                              return_list=True)
                spk = batch['speakers'][b]

                plot_ctc_probs(
                    ctc_probs[b, :xlens[b]],
                    topk_ids[b],
                    subsample_factor=args.subsample_factor,
                    spectrogram=batch['xs'][b][:, :dataset.input_dim],
                    save_path=mkdir_join(save_path, spk,
                                         batch['utt_ids'][b] + '.png'),
                    figsize=(20, 8))

                hyp = ' '.join(tokens)
                logger.info('utt-id: %s' % batch['utt_ids'][b])
                logger.info('Ref: %s' % batch['text'][b].lower())
                logger.info('Hyp: %s' % hyp)
                logger.info('-' * 50)

            if is_new_epoch:
                break
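
average_checkpoints above is a neural_sp utility. As background, checkpoint averaging for Transformer models usually amounts to averaging the parameter tensors of the last N saved checkpoints before loading them into the model. A minimal sketch under the assumption that each file stores a plain PyTorch state_dict (neural_sp's actual checkpoint format may differ):

import torch


def average_checkpoints_sketch(model, checkpoint_paths):
    """Average the parameters of several checkpoints and load them into `model`.

    Assumes every file in `checkpoint_paths` stores a plain state_dict
    with identical keys; values are cast to float for the running sum.
    """
    avg_state = None
    for path in checkpoint_paths:
        state = torch.load(path, map_location='cpu')
        if avg_state is None:
            avg_state = {k: v.clone().float() for k, v in state.items()}
        else:
            for k, v in state.items():
                avg_state[k] += v.float()
    avg_state = {k: v / len(checkpoint_paths) for k, v in avg_state.items()}
    model.load_state_dict(avg_state)
    return model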
Example No. 3
def main():

    # Load configuration
    args, dir_name = parse_args_eval(sys.argv[1:])

    # Set up logging
    if os.path.isfile(os.path.join(args.recog_dir, 'align.log')):
        os.remove(os.path.join(args.recog_dir, 'align.log'))
    set_logger(os.path.join(args.recog_dir, 'align.log'),
               stdout=args.recog_stdout)

    for i, s in enumerate(args.recog_sets):
        # Align all utterances
        args.min_n_frames = 0
        args.max_n_frames = 1e5

        # Load dataloader
        dataloader = build_dataloader(args=args,
                                      tsv_path=s,
                                      batch_size=args.recog_batch_size)

        if i == 0:
            # Load ASR model
            model = Speech2Text(args, dir_name)
            epoch = int(args.recog_model[0].split('-')[-1])
            if args.recog_n_average > 1:
                # Model averaging for Transformer
                model = average_checkpoints(model,
                                            args.recog_model[0],
                                            n_average=args.recog_n_average)
            else:
                load_checkpoint(args.recog_model[0], model)

            if not args.recog_unit:
                args.recog_unit = args.unit

            logger.info('recog unit: %s' % args.recog_unit)
            logger.info('epoch: %d' % epoch)
            logger.info('batch size: %d' % args.recog_batch_size)

            # GPU setting
            if args.recog_n_gpus >= 1:
                model.cudnn_setting(deterministic=True, benchmark=False)
                model.cuda()

        save_path = mkdir_join(args.recog_dir, 'ctc_forced_alignments')

        # Clean directory
        if save_path is not None and os.path.isdir(save_path):
            shutil.rmtree(save_path)
            os.mkdir(save_path)

        pbar = tqdm(total=len(dataloader))
        while True:
            batch, is_new_epoch = dataloader.next()
            trigger_points = model.ctc_forced_align(batch['xs'],
                                                    batch['ys'])  # `[B, L]`

            for b in range(len(batch['xs'])):
                save_path_spk = mkdir_join(save_path, batch['speakers'][b])
                save_path_utt = mkdir_join(save_path_spk,
                                           batch['utt_ids'][b] + '.txt')

                tokens = dataloader.idx2token[0](batch['ys'][b],
                                                 return_list=True)
                with codecs.open(save_path_utt, 'w', encoding="utf-8") as f:
                    for t, tok in enumerate(tokens):
                        f.write('%s %d\n' % (tok, trigger_points[b, t]))
                    f.write('%s %d\n' %
                            ('<eos>', trigger_points[b, len(tokens)]))

            pbar.update(len(batch['xs']))

            if is_new_epoch:
                break

        pbar.close()
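
Each alignment file written above contains one '<token> <frame_index>' pair per line, closed by an '<eos>' entry. A small sketch of how such a file could be read back and converted to seconds; the 10 ms frame shift and the subsampling factor of 4 are assumptions about the acoustic front-end, not values taken from this script:

import codecs


def read_ctc_alignment(path, frame_shift=0.01, subsample_factor=4):
    """Parse '<token> <frame_index>' lines into (token, time_in_seconds) pairs."""
    alignment = []
    with codecs.open(path, 'r', encoding='utf-8') as f:
        for line in f:
            token, frame = line.rsplit(' ', 1)
            alignment.append((token, int(frame) * subsample_factor * frame_shift))
    return alignment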
Example No. 4
def main():

    # Load configuration
    args, dir_name = parse_args_eval(sys.argv[1:])

    # Set up logging
    if os.path.isfile(os.path.join(args.recog_dir, 'decode.log')):
        os.remove(os.path.join(args.recog_dir, 'decode.log'))
    set_logger(os.path.join(args.recog_dir, 'decode.log'),
               stdout=args.recog_stdout)

    wer_avg, cer_avg, per_avg = 0, 0, 0
    ppl_avg, loss_avg = 0, 0
    acc_avg = 0
    bleu_avg = 0
    for i, s in enumerate(args.recog_sets):
        # Load dataloader
        dataloader = build_dataloader(
            args=args,
            tsv_path=s,
            batch_size=1,
            is_test=True,
            first_n_utterances=args.recog_first_n_utt,
            longform_max_n_frames=args.recog_longform_max_n_frames)

        if i == 0:
            # Load ASR model
            model = Speech2Text(args, dir_name)
            epoch = int(float(args.recog_model[0].split('-')[-1]) * 10) / 10
            if args.recog_n_average > 1:
                # Model averaging for Transformer
                # topk_list = load_checkpoint(args.recog_model[0], model)
                model = average_checkpoints(
                    model,
                    args.recog_model[0],
                    # topk_list=topk_list,
                    n_average=args.recog_n_average)
            else:
                load_checkpoint(args.recog_model[0], model)

            # Ensemble (different models)
            ensemble_models = [model]
            if len(args.recog_model) > 1:
                for recog_model_e in args.recog_model[1:]:
                    conf_e = load_config(
                        os.path.join(os.path.dirname(recog_model_e),
                                     'conf.yml'))
                    args_e = copy.deepcopy(args)
                    for k, v in conf_e.items():
                        if 'recog' not in k:
                            setattr(args_e, k, v)
                    model_e = Speech2Text(args_e)
                    load_checkpoint(recog_model_e, model_e)
                    if args.recog_n_gpus >= 1:
                        model_e.cuda()
                    ensemble_models += [model_e]

            # Load LM for shallow fusion
            if not args.lm_fusion:
                # first-pass LM
                if args.recog_lm is not None and args.recog_lm_weight > 0:
                    conf_lm = load_config(
                        os.path.join(os.path.dirname(args.recog_lm),
                                     'conf.yml'))
                    args_lm = argparse.Namespace()
                    for k, v in conf_lm.items():
                        setattr(args_lm, k, v)
                    args_lm.recog_mem_len = args.recog_mem_len
                    lm = build_lm(args_lm,
                                  wordlm=args.recog_wordlm,
                                  lm_dict_path=os.path.join(
                                      os.path.dirname(args.recog_lm),
                                      'dict.txt'),
                                  asr_dict_path=os.path.join(
                                      dir_name, 'dict.txt'))
                    load_checkpoint(args.recog_lm, lm)
                    if args_lm.backward:
                        model.lm_bwd = lm
                    else:
                        model.lm_fwd = lm

                # second-pass LM (forward)
                if args.recog_lm_second is not None and args.recog_lm_second_weight > 0:
                    conf_lm_second = load_config(
                        os.path.join(os.path.dirname(args.recog_lm_second),
                                     'conf.yml'))
                    args_lm_second = argparse.Namespace()
                    for k, v in conf_lm_second.items():
                        setattr(args_lm_second, k, v)
                    args_lm_second.recog_mem_len = args.recog_mem_len
                    lm_second = build_lm(args_lm_second)
                    load_checkpoint(args.recog_lm_second, lm_second)
                    model.lm_second = lm_second

                # second-pass LM (backward)
                if args.recog_lm_bwd is not None and args.recog_lm_bwd_weight > 0:
                    conf_lm = load_config(
                        os.path.join(os.path.dirname(args.recog_lm_bwd),
                                     'conf.yml'))
                    args_lm_bwd = argparse.Namespace()
                    for k, v in conf_lm.items():
                        setattr(args_lm_bwd, k, v)
                    args_lm_bwd.recog_mem_len = args.recog_mem_len
                    lm_bwd = build_lm(args_lm_bwd)
                    load_checkpoint(args.recog_lm_bwd, lm_bwd)
                    model.lm_bwd = lm_bwd

            if not args.recog_unit:
                args.recog_unit = args.unit

            logger.info('recog unit: %s' % args.recog_unit)
            logger.info('recog metric: %s' % args.recog_metric)
            logger.info('recog oracle: %s' % args.recog_oracle)
            logger.info('epoch: %d' % epoch)
            logger.info('batch size: %d' % args.recog_batch_size)
            logger.info('beam width: %d' % args.recog_beam_width)
            logger.info('min length ratio: %.3f' % args.recog_min_len_ratio)
            logger.info('max length ratio: %.3f' % args.recog_max_len_ratio)
            logger.info('length penalty: %.3f' % args.recog_length_penalty)
            logger.info('length norm: %s' % args.recog_length_norm)
            logger.info('coverage penalty: %.3f' % args.recog_coverage_penalty)
            logger.info('coverage threshold: %.3f' %
                        args.recog_coverage_threshold)
            logger.info('CTC weight: %.3f' % args.recog_ctc_weight)
            logger.info('first LM path: %s' % args.recog_lm)
            logger.info('second LM path: %s' % args.recog_lm_second)
            logger.info('backward LM path: %s' % args.recog_lm_bwd)
            logger.info('LM weight (first-pass): %.3f' % args.recog_lm_weight)
            logger.info('LM weight (second-pass): %.3f' %
                        args.recog_lm_second_weight)
            logger.info('LM weight (backward): %.3f' %
                        args.recog_lm_bwd_weight)
            logger.info('GNMT: %s' % args.recog_gnmt_decoding)
            logger.info('forward-backward attention: %s' %
                        args.recog_fwd_bwd_attention)
            logger.info('resolving UNK: %s' % args.recog_resolving_unk)
            logger.info('ensemble: %d' % (len(ensemble_models)))
            logger.info('ASR decoder state carry over: %s' %
                        (args.recog_asr_state_carry_over))
            logger.info('LM state carry over: %s' %
                        (args.recog_lm_state_carry_over))
            logger.info('model average (Transformer): %d' %
                        (args.recog_n_average))

            # GPU setting
            if args.recog_n_gpus >= 1:
                model.cudnn_setting(deterministic=True, benchmark=False)
                model.cuda()

        start_time = time.time()

        if args.recog_metric == 'edit_distance':
            if args.recog_unit in ['word', 'word_char']:
                wer, cer, _ = eval_word(ensemble_models,
                                        dataloader,
                                        args,
                                        epoch=epoch - 1,
                                        recog_dir=args.recog_dir,
                                        progressbar=True,
                                        fine_grained=True,
                                        oracle=True)
                wer_avg += wer
                cer_avg += cer
            elif args.recog_unit == 'wp':
                wer, cer = eval_wordpiece(ensemble_models,
                                          dataloader,
                                          args,
                                          epoch=epoch - 1,
                                          recog_dir=args.recog_dir,
                                          streaming=args.recog_streaming,
                                          progressbar=True,
                                          fine_grained=True,
                                          oracle=True)
                wer_avg += wer
                cer_avg += cer
            elif 'char' in args.recog_unit:
                wer, cer = eval_char(ensemble_models,
                                     dataloader,
                                     args,
                                     epoch=epoch - 1,
                                     recog_dir=args.recog_dir,
                                     progressbar=True,
                                     task_idx=0,
                                     fine_grained=True,
                                     oracle=True)
                #  task_idx=1 if args.recog_unit and 'char' in args.recog_unit else 0)
                wer_avg += wer
                cer_avg += cer
            elif 'phone' in args.recog_unit:
                per = eval_phone(ensemble_models,
                                 dataloader,
                                 args,
                                 epoch=epoch - 1,
                                 recog_dir=args.recog_dir,
                                 progressbar=True,
                                 fine_grained=True,
                                 oracle=True)
                per_avg += per
            else:
                raise ValueError(args.recog_unit)
        elif args.recog_metric in ['ppl', 'loss']:
            ppl, loss = eval_ppl(ensemble_models, dataloader, progressbar=True)
            ppl_avg += ppl
            loss_avg += loss
        elif args.recog_metric == 'accuracy':
            acc_avg += eval_accuracy(ensemble_models,
                                     dataloader,
                                     progressbar=True)
        elif args.recog_metric == 'bleu':
            bleu = eval_wordpiece_bleu(ensemble_models,
                                       dataloader,
                                       args,
                                       epoch=epoch - 1,
                                       recog_dir=args.recog_dir,
                                       streaming=args.recog_streaming,
                                       progressbar=True,
                                       fine_grained=True,
                                       oracle=True)
            bleu_avg += bleu
        else:
            raise NotImplementedError(args.recog_metric)
        elapsed_time = time.time() - start_time
        logger.info('Elapsed time: %.3f [sec]' % elapsed_time)
        logger.info('RTF: %.3f' % (elapsed_time /
                                   (dataloader.n_frames * 0.01)))

    if args.recog_metric == 'edit_distance':
        if 'phone' in args.recog_unit:
            logger.info('PER (avg.): %.2f %%\n' %
                        (per_avg / len(args.recog_sets)))
        else:
            logger.info('WER / CER (avg.): %.2f / %.2f %%\n' %
                        (wer_avg / len(args.recog_sets),
                         cer_avg / len(args.recog_sets)))
    elif args.recog_metric in ['ppl', 'loss']:
        logger.info('PPL (avg.): %.2f\n' % (ppl_avg / len(args.recog_sets)))
        print('PPL (avg.): %.3f' % (ppl_avg / len(args.recog_sets)))
        logger.info('Loss (avg.): %.2f\n' % (loss_avg / len(args.recog_sets)))
        print('Loss (avg.): %.3f' % (loss_avg / len(args.recog_sets)))
    elif args.recog_metric == 'accuracy':
        logger.info('Accuracy (avg.): %.2f\n' %
                    (acc_avg / len(args.recog_sets)))
        print('Accuracy (avg.): %.3f' % (acc_avg / len(args.recog_sets)))
    elif args.recog_metric == 'bleu':
        logger.info('BLEU (avg.): %.2f\n' % (bleu_avg / len(args.recog_sets)))
        print('BLEU (avg.): %.3f' % (bleu_avg / len(args.recog_sets)))
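
The eval_word, eval_wordpiece and eval_char evaluators above report WER/CER via edit distance. For orientation only, word error rate reduces to the Levenshtein distance between the reference and hypothesis word sequences, normalized by the reference length; a minimal sketch independent of neural_sp's evaluators:

def wer_sketch(ref, hyp):
    """Word error rate (%): edit distance between word lists / reference length."""
    r, h = ref.split(), hyp.split()
    # Standard dynamic-programming edit distance over words
    d = [[0] * (len(h) + 1) for _ in range(len(r) + 1)]
    for i in range(len(r) + 1):
        d[i][0] = i
    for j in range(len(h) + 1):
        d[0][j] = j
    for i in range(1, len(r) + 1):
        for j in range(1, len(h) + 1):
            sub = d[i - 1][j - 1] + (r[i - 1] != h[j - 1])
            d[i][j] = min(sub, d[i - 1][j] + 1, d[i][j - 1] + 1)
    return 100.0 * d[len(r)][len(h)] / max(len(r), 1)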
Example No. 5
def main():

    # Load configuration
    args, dir_name = parse_args_eval(sys.argv[1:])

    # Set up logging
    if os.path.isfile(os.path.join(args.recog_dir, 'plot.log')):
        os.remove(os.path.join(args.recog_dir, 'plot.log'))
    set_logger(os.path.join(args.recog_dir, 'plot.log'), stdout=args.recog_stdout)

    for i, s in enumerate(args.recog_sets):
        # Load dataloader
        dataloader = build_dataloader(args=args,
                                      tsv_path=s,
                                      batch_size=1,
                                      is_test=True,
                                      first_n_utterances=args.recog_first_n_utt,
                                      longform_max_n_frames=args.recog_longform_max_n_frames)

        if i == 0:
            # Load ASR model
            model = Speech2Text(args, dir_name)
            epoch = int(float(args.recog_model[0].split('-')[-1]) * 10) / 10
            if args.recog_n_average > 1:
                # Model averaging for Transformer
                model = average_checkpoints(model, args.recog_model[0],
                                            n_average=args.recog_n_average)
            else:
                load_checkpoint(args.recog_model[0], model)

            # Ensemble (different models)
            ensemble_models = [model]
            if len(args.recog_model) > 1:
                for recog_model_e in args.recog_model[1:]:
                    conf_e = load_config(os.path.join(os.path.dirname(recog_model_e), 'conf.yml'))
                    args_e = copy.deepcopy(args)
                    for k, v in conf_e.items():
                        if 'recog' not in k:
                            setattr(args_e, k, v)
                    model_e = Speech2Text(args_e)
                    load_checkpoint(recog_model_e, model_e)
                    if args.recog_n_gpus >= 1:
                        model_e.cuda()
                    ensemble_models += [model_e]

            # Load LM for shallow fusion
            if not args.lm_fusion:
                # first-pass LM
                if args.recog_lm is not None and args.recog_lm_weight > 0:
                    conf_lm = load_config(os.path.join(os.path.dirname(args.recog_lm), 'conf.yml'))
                    args_lm = argparse.Namespace()
                    for k, v in conf_lm.items():
                        setattr(args_lm, k, v)
                    lm = build_lm(args_lm)
                    load_checkpoint(args.recog_lm, lm)
                    if args_lm.backward:
                        model.lm_bwd = lm
                    else:
                        model.lm_fwd = lm
                # NOTE: only the first-pass LM is supported here

            if not args.recog_unit:
                args.recog_unit = args.unit

            logger.info('recog unit: %s' % args.recog_unit)
            logger.info('recog oracle: %s' % args.recog_oracle)
            logger.info('epoch: %d' % epoch)
            logger.info('batch size: %d' % args.recog_batch_size)
            logger.info('beam width: %d' % args.recog_beam_width)
            logger.info('min length ratio: %.3f' % args.recog_min_len_ratio)
            logger.info('max length ratio: %.3f' % args.recog_max_len_ratio)
            logger.info('length penalty: %.3f' % args.recog_length_penalty)
            logger.info('length norm: %s' % args.recog_length_norm)
            logger.info('coverage penalty: %.3f' % args.recog_coverage_penalty)
            logger.info('coverage threshold: %.3f' % args.recog_coverage_threshold)
            logger.info('CTC weight: %.3f' % args.recog_ctc_weight)
            logger.info('first LM path: %s' % args.recog_lm)
            logger.info('LM weight: %.3f' % args.recog_lm_weight)
            logger.info('GNMT: %s' % args.recog_gnmt_decoding)
            logger.info('forward-backward attention: %s' % args.recog_fwd_bwd_attention)
            logger.info('resolving UNK: %s' % args.recog_resolving_unk)
            logger.info('ensemble: %d' % (len(ensemble_models)))
            logger.info('ASR decoder state carry over: %s' % (args.recog_asr_state_carry_over))
            logger.info('LM state carry over: %s' % (args.recog_lm_state_carry_over))
            logger.info('model average (Transformer): %d' % (args.recog_n_average))

            # GPU setting
            if args.recog_n_gpus >= 1:
                model.cudnn_setting(deterministic=True, benchmark=False)
                model.cuda()

        save_path = mkdir_join(args.recog_dir, 'att_weights')

        # Clean directory
        if save_path is not None and os.path.isdir(save_path):
            shutil.rmtree(save_path)
            os.mkdir(save_path)

        for batch in dataloader:
            nbest_hyps_id, aws = model.decode(
                batch['xs'], args, dataloader.idx2token[0],
                exclude_eos=False,
                refs_id=batch['ys'],
                ensemble_models=ensemble_models[1:] if len(ensemble_models) > 1 else [],
                speakers=batch['sessions'] if dataloader.corpus == 'swbd' else batch['speakers'])
            best_hyps_id = [h[0] for h in nbest_hyps_id]

            # Get CTC probs
            ctc_probs, topk_ids = None, None
            if args.ctc_weight > 0:
                ctc_probs, topk_ids, xlens = model.get_ctc_probs(
                    batch['xs'], task='ys', temperature=1, topk=min(100, model.vocab))
                # NOTE: ctc_probs: '[B, T, topk]'
            ctc_probs_sub1, topk_ids_sub1 = None, None
            if args.ctc_weight_sub1 > 0:
                ctc_probs_sub1, topk_ids_sub1, xlens_sub1 = model.get_ctc_probs(
                    batch['xs'], task='ys_sub1', temperature=1, topk=min(100, model.vocab_sub1))

            if model.bwd_weight > 0.5:
                # Reverse the order
                best_hyps_id = [hyp[::-1] for hyp in best_hyps_id]
                aws = [[aw[0][:, ::-1]] for aw in aws]

            for b in range(len(batch['xs'])):
                tokens = dataloader.idx2token[0](best_hyps_id[b], return_list=True)
                spk = batch['speakers'][b]

                plot_attention_weights(
                    aws[b][0][:, :len(tokens)], tokens,
                    spectrogram=batch['xs'][b][:, :dataloader.input_dim] if args.input_type == 'speech' else None,
                    factor=args.subsample_factor,
                    ref=batch['text'][b].lower(),
                    save_path=mkdir_join(save_path, spk, batch['utt_ids'][b] + '.png'),
                    figsize=(20, 8),
                    ctc_probs=ctc_probs[b, :xlens[b]] if ctc_probs is not None else None,
                    ctc_topk_ids=topk_ids[b] if topk_ids is not None else None,
                    ctc_probs_sub1=ctc_probs_sub1[b, :xlens_sub1[b]] if ctc_probs_sub1 is not None else None,
                    ctc_topk_ids_sub1=topk_ids_sub1[b] if topk_ids_sub1 is not None else None)

                if model.bwd_weight > 0.5:
                    hyp = ' '.join(tokens[::-1])
                else:
                    hyp = ' '.join(tokens)
                logger.info('utt-id: %s' % batch['utt_ids'][b])
                logger.info('Ref: %s' % batch['text'][b].lower())
                logger.info('Hyp: %s' % hyp)
                logger.info('-' * 50)
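
plot_attention_weights above is neural_sp's visualization helper. A minimal matplotlib sketch of the core of such a plot, an attention matrix drawn as a heat map with the hypothesis tokens along one axis; the [len(tokens), T] orientation and the function name are assumptions for illustration:

import matplotlib.pyplot as plt


def plot_attention_sketch(aw, tokens, save_path, figsize=(20, 8)):
    """Draw an attention weight matrix of shape [len(tokens), T] as a heat map."""
    plt.figure(figsize=figsize)
    plt.imshow(aw, aspect='auto', origin='lower', interpolation='nearest')
    plt.yticks(range(len(tokens)), tokens, fontsize=6)
    plt.xlabel('Encoder frames')
    plt.ylabel('Output tokens')
    plt.colorbar()
    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()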