示例#1
0
def main():
    """Score a trained phone recognizer on the TIMIT dev and test splits.

    Arguments come from the module-level ``parser``. The model is built and
    its checkpoint restored once, while handling the first split, using the
    class count of that split's dataset. For every split the phone error
    rate and the table returned by ``eval_phone`` are sent to the logger
    created by ``set_logger``.
    """
    args = parser.parse_args()
    config_path = join(args.model_path, 'config.yml')
    params = load_config(config_path, is_eval=True)
    logger = set_logger(args.model_path)

    model_ready = False
    for data_type in ['dev', 'test']:
        dataset = Dataset(
            data_save_path=args.data_save_path,
            backend=params['backend'],
            input_freq=params['input_freq'],
            use_delta=params['use_delta'],
            use_double_delta=params['use_double_delta'],
            data_type=data_type,
            label_type=params['label_type'],
            batch_size=args.eval_batch_size,
            splice=params['splice'],
            num_stack=params['num_stack'],
            num_skip=params['num_skip'],
            sort_utt=False,
            tool=params['tool'])

        if not model_ready:
            # The class count is read from the first dataset, so the model
            # is built lazily here and reused for the remaining split.
            params['num_classes'] = dataset.num_classes
            model = load(model_type=params['model_type'], params=params,
                         backend=params['backend'])
            epoch, _, _, _ = model.load_checkpoint(save_path=args.model_path,
                                                   epoch=args.epoch)
            model.set_cuda(deterministic=False, benchmark=True)
            logger.info('beam width: %d' % args.beam_width)
            logger.info('epoch: %d' % (epoch - 1))
            model_ready = True

        per, df = eval_phone(
            model=model,
            dataset=dataset,
            map_file_path='./conf/phones.60-48-39.map',
            eval_batch_size=args.eval_batch_size,
            beam_width=args.beam_width,
            max_decode_len=MAX_DECODE_LEN_PHONE,
            min_decode_len=MIN_DECODE_LEN_PHONE,
            length_penalty=args.length_penalty,
            coverage_penalty=args.coverage_penalty,
            progressbar=True)
        logger.info('  PER (%s): %.3f %%' % (data_type, per * 100))
        logger.info(df)
def main():
    """Plot attention weights of a CSJ word/character model on ``eval1``.

    The model is rebuilt from ``config.yml`` in ``--model_path``, restored
    from the checkpoint selected by ``--epoch`` and passed to ``plot``,
    which receives ``<model_path>/att_weights`` as its save path.
    """
    args = parser.parse_args()
    params = load_config(join(args.model_path, 'config.yml'), is_eval=True)

    # Other evaluation sets: data_type='eval2' or data_type='eval3'.
    eval_data = Dataset(data_save_path=args.data_save_path,
                        backend=params['backend'],
                        input_freq=params['input_freq'],
                        use_delta=params['use_delta'],
                        use_double_delta=params['use_double_delta'],
                        data_type='eval1',
                        data_size=params['data_size'],
                        label_type=params['label_type'],
                        label_type_sub=params['label_type_sub'],
                        batch_size=args.eval_batch_size,
                        splice=params['splice'],
                        num_stack=params['num_stack'],
                        num_skip=params['num_skip'],
                        sort_utt=False,
                        reverse=False,
                        tool=params['tool'])

    # Both output layer sizes come from the dataset's vocabularies.
    params['num_classes'] = eval_data.num_classes
    params['num_classes_sub'] = eval_data.num_classes_sub

    model = load(model_type=params['model_type'], params=params,
                 backend=params['backend'])
    model.load_checkpoint(save_path=args.model_path, epoch=args.epoch)
    model.set_cuda(deterministic=False, benchmark=True)

    # a2c oracle mode is always disabled for this plot.
    plot(model=model,
         dataset=eval_data,
         eval_batch_size=args.eval_batch_size,
         beam_width=args.beam_width,
         beam_width_sub=args.beam_width_sub,
         length_penalty=args.length_penalty,
         a2c_oracle=False,
         save_path=mkdir_join(args.model_path, 'att_weights'))
示例#3
0
def main():
    """Decode LibriSpeech ``test_clean`` with a word/character model.

    Vocabulary files for the main and sub label types are looked up under
    ``../metrics/vocab_files/``. The restored model and the dataset are then
    handed to ``decode`` with ``save_path=None``.
    """
    args = parser.parse_args()
    params = load_config(join(args.model_path, 'config.yml'), is_eval=True)

    def vocab_path(label_type):
        # Vocabulary files are named '<label_type>_<data_size>.txt'.
        return ('../metrics/vocab_files/' + label_type + '_' +
                params['data_size'] + '.txt')

    vocab_file_path = vocab_path(params['label_type'])
    vocab_file_path_sub = vocab_path(params['label_type_sub'])

    # Other split: data_type='test_other'.
    test_data = Dataset(backend=params['backend'],
                        input_channel=params['input_channel'],
                        use_delta=params['use_delta'],
                        use_double_delta=params['use_double_delta'],
                        data_type='test_clean',
                        data_size=params['data_size'],
                        label_type=params['label_type'],
                        label_type_sub=params['label_type_sub'],
                        vocab_file_path=vocab_file_path,
                        vocab_file_path_sub=vocab_file_path_sub,
                        batch_size=args.eval_batch_size,
                        splice=params['splice'],
                        num_stack=params['num_stack'],
                        num_skip=params['num_skip'],
                        sort_utt=True,
                        reverse=True,
                        save_format=params['save_format'])
    params['num_classes'] = test_data.num_classes
    params['num_classes_sub'] = test_data.num_classes_sub

    model = load(model_type=params['model_type'], params=params,
                 backend=params['backend'])
    model.load_checkpoint(save_path=args.model_path, epoch=args.epoch)
    model.set_cuda(deterministic=False, benchmark=True)

    # Decode the test set (nothing is saved: save_path=None)
    decode(model=model,
           dataset=test_data,
           beam_width=args.beam_width,
           max_decode_len=args.max_decode_len,
           max_decode_len_sub=args.max_decode_len_sub,
           eval_batch_size=args.eval_batch_size,
           save_path=None)
def main():
    """Decode Switchboard ``eval2000`` (CallHome part) with a word/character model.

    Builds the dataset and model from the ``config.yml`` saved in
    ``--model_path``, restores the checkpoint selected by ``--epoch`` and
    passes both to ``decode``.
    """
    args = parser.parse_args()

    # Load a config file (.yml)
    params = load_config(join(args.model_path, 'config.yml'), is_eval=True)

    # Load dataset
    # NOTE: alternative settings are data_type='eval2000_swbd' (Switchboard
    # portion) and sort_utt=True, reverse=True.
    test_data = Dataset(
        data_save_path=args.data_save_path,
        backend=params['backend'],
        input_freq=params['input_freq'],
        use_delta=params['use_delta'],
        use_double_delta=params['use_double_delta'],
        data_type='eval2000_ch',
        data_size=params['data_size'],
        label_type=params['label_type'],
        label_type_sub=params['label_type_sub'],
        batch_size=args.eval_batch_size,
        splice=params['splice'],
        num_stack=params['num_stack'],
        num_skip=params['num_skip'],
        sort_utt=False,
        reverse=False,
        tool=params['tool'])

    params['num_classes'] = test_data.num_classes
    params['num_classes_sub'] = test_data.num_classes_sub

    # Load model
    model = load(model_type=params['model_type'],
                 params=params,
                 backend=params['backend'])

    # Restore the saved parameters
    model.load_checkpoint(save_path=args.model_path, epoch=args.epoch)

    # GPU setting
    model.set_cuda(deterministic=False, benchmark=True)

    # Decode. The comma after save_path was missing in the original, which
    # made this call a SyntaxError.
    decode(model=model,
           dataset=test_data,
           beam_width=args.beam_width,
           beam_width_sub=args.beam_width_sub,
           eval_batch_size=args.eval_batch_size,
           save_path=None,  # alternative: save_path=args.model_path
           resolving_unk=False)
def main():
    """Plot CTC outputs of a LibriSpeech model on ``test_clean``.

    The model is restored from ``--model_path`` / ``--epoch``. ``plot``
    receives ``<model_path>/ctc_probs`` as its save path, plus the index of
    the space symbol for character models.
    """
    args = parser.parse_args()
    params = load_config(join(args.model_path, 'config.yml'), is_eval=True)

    vocab_file_path = ('../metrics/vocab_files/' + params['label_type'] +
                       '_' + params['data_size'] + '.txt')

    # Other split: data_type='test_other'.
    test_data = Dataset(backend=params['backend'],
                        input_channel=params['input_channel'],
                        use_delta=params['use_delta'],
                        use_double_delta=params['use_double_delta'],
                        data_type='test_clean',
                        data_size=params['data_size'],
                        label_type=params['label_type'],
                        vocab_file_path=vocab_file_path,
                        batch_size=args.eval_batch_size,
                        splice=params['splice'],
                        num_stack=params['num_stack'],
                        num_skip=params['num_skip'],
                        sort_utt=True,
                        reverse=False,
                        save_format=params['save_format'])
    params['num_classes'] = test_data.num_classes

    model = load(model_type=params['model_type'], params=params,
                 backend=params['backend'])
    model.load_checkpoint(save_path=args.model_path, epoch=args.epoch)
    model.set_cuda(deterministic=False, benchmark=True)

    # Space symbol index for character labels; other label types pass None.
    # NOTE: index 0 is reserved for blank in warpctc_pytorch
    if params['label_type'] == 'character':
        space_index = 27
    else:
        space_index = None

    plot(model=model,
         dataset=test_data,
         eval_batch_size=args.eval_batch_size,
         save_path=mkdir_join(args.model_path, 'ctc_probs'),
         space_index=space_index)
示例#6
0
def main():
    """Plot CTC output probabilities of a TIMIT model on the test split.

    The model is restored from ``--model_path`` / ``--epoch`` and passed to
    ``plot_probs`` with ``<model_path>/ctc_probs`` as the save path.
    """
    args = parser.parse_args()
    params = load_config(join(args.model_path, 'config.yml'), is_eval=True)

    test_data = Dataset(data_save_path=args.data_save_path,
                        backend=params['backend'],
                        input_freq=params['input_freq'],
                        use_delta=params['use_delta'],
                        use_double_delta=params['use_double_delta'],
                        data_type='test',
                        label_type=params['label_type'],
                        batch_size=args.eval_batch_size,
                        splice=params['splice'],
                        num_stack=params['num_stack'],
                        num_skip=params['num_skip'],
                        sort_utt=True,
                        reverse=True,
                        tool=params['tool'])
    # The model's class count is taken from the dataset.
    params['num_classes'] = test_data.num_classes

    model = load(model_type=params['model_type'], params=params,
                 backend=params['backend'])
    model.load_checkpoint(save_path=args.model_path, epoch=args.epoch)
    model.set_cuda(deterministic=False, benchmark=True)

    plot_probs(model=model,
               dataset=test_data,
               eval_batch_size=args.eval_batch_size,
               save_path=mkdir_join(args.model_path, 'ctc_probs'))
示例#7
0
def main():
    """Evaluate a TIMIT phone model on dev and test and record the PERs.

    Each split's phone error rate and detail table are printed. The beam
    width and both PERs are then written, one per line, to
    ``<model_path>/RESULTS``.
    """
    args = parser.parse_args()

    # Load a config file (.yml)
    params = load_config(join(args.model_path, 'config.yml'), is_eval=True)

    # Load dataset
    dev_data = Dataset(data_save_path=args.data_save_path,
                       backend=params['backend'],
                       input_freq=params['input_freq'],
                       use_delta=params['use_delta'],
                       use_double_delta=params['use_double_delta'],
                       data_type='dev',
                       label_type=params['label_type'],
                       batch_size=args.eval_batch_size,
                       splice=params['splice'],
                       num_stack=params['num_stack'],
                       num_skip=params['num_skip'],
                       sort_utt=False,
                       tool=params['tool'])
    test_data = Dataset(data_save_path=args.data_save_path,
                        backend=params['backend'],
                        input_freq=params['input_freq'],
                        use_delta=params['use_delta'],
                        use_double_delta=params['use_double_delta'],
                        data_type='test',
                        label_type=params['label_type'],
                        batch_size=args.eval_batch_size,
                        splice=params['splice'],
                        num_stack=params['num_stack'],
                        num_skip=params['num_skip'],
                        sort_utt=False,
                        tool=params['tool'])

    params['num_classes'] = test_data.num_classes

    # Load model
    model = load(model_type=params['model_type'],
                 params=params,
                 backend=params['backend'])

    # Restore the saved parameters
    model.load_checkpoint(save_path=args.model_path, epoch=args.epoch)

    # GPU setting
    model.set_cuda(deterministic=False, benchmark=True)

    print('beam width: %d' % args.beam_width)

    # Evaluate each split in turn (dev first, then test)
    results = []
    for split_name, split_data in (('dev', dev_data), ('test', test_data)):
        per, df = eval_phone(model=model,
                             dataset=split_data,
                             map_file_path='./conf/phones.60-48-39.map',
                             eval_batch_size=args.eval_batch_size,
                             beam_width=args.beam_width,
                             max_decode_len=MAX_DECODE_LEN_PHONE,
                             length_penalty=args.length_penalty,
                             progressbar=True)
        print('  PER (%s): %.3f %%' % (split_name, per * 100))
        print(df)
        results.append((split_name, per))

    # file.write does not append a newline. Terminate every record so the
    # dev and test results land on separate lines.
    with open(join(args.model_path, 'RESULTS'), 'w') as f:
        f.write('beam width: %d\n' % args.beam_width)
        for split_name, per in results:
            f.write('  PER (%s): %.3f %%\n' % (split_name, per * 100))
示例#8
0
def main():
    """Evaluate a CSJ word/character model on eval1, eval2 and eval3.

    The model is built and restored while handling the first set, using that
    set's class counts. For each set, word error rate (main task) and
    word/character error rates (sub task) are logged with their detail
    tables. The means over all evaluated sets are logged at the end.
    """
    args = parser.parse_args()

    # Load a config file (.yml)
    params = load_config(join(args.model_path, 'config.yml'), is_eval=True)

    # Setting for logging
    logger = set_logger(args.model_path)

    data_types = ['eval1', 'eval2', 'eval3']
    # Running sums; divided by the number of sets when the means are logged.
    wer_sum, wer_sub_sum, cer_sub_sum = 0, 0, 0
    for i, data_type in enumerate(data_types):
        # Load dataset
        dataset = Dataset(data_save_path=args.data_save_path,
                          backend=params['backend'],
                          input_freq=params['input_freq'],
                          use_delta=params['use_delta'],
                          use_double_delta=params['use_double_delta'],
                          data_type=data_type,
                          data_size=params['data_size'],
                          label_type=params['label_type'],
                          label_type_sub=params['label_type_sub'],
                          batch_size=args.eval_batch_size,
                          splice=params['splice'],
                          num_stack=params['num_stack'],
                          num_skip=params['num_skip'],
                          shuffle=False,
                          tool=params['tool'])

        if i == 0:
            # Class counts come from the first dataset; the model is then
            # reused for the remaining sets.
            params['num_classes'] = dataset.num_classes
            params['num_classes_sub'] = dataset.num_classes_sub

            # Load model
            model = load(model_type=params['model_type'],
                         params=params,
                         backend=params['backend'])

            # Restore the saved parameters
            epoch, _, _, _ = model.load_checkpoint(save_path=args.model_path,
                                                   epoch=args.epoch)

            # GPU setting
            model.set_cuda(deterministic=False, benchmark=True)

            logger.info('beam width (main): %d\n' % args.beam_width)
            logger.info('beam width (sub) : %d\n' % args.beam_width_sub)
            logger.info('epoch: %d' % (epoch - 1))
            logger.info('a2c oracle: %s\n' % str(args.a2c_oracle))
            logger.info('resolving_unk: %s\n' % str(args.resolving_unk))
            logger.info('joint_decoding: %s\n' % str(args.joint_decoding))
            logger.info('score_sub_weight : %f' % args.score_sub_weight)

        wer, df = eval_word(models=[model],
                            dataset=dataset,
                            eval_batch_size=args.eval_batch_size,
                            beam_width=args.beam_width,
                            max_decode_len=MAX_DECODE_LEN_WORD,
                            min_decode_len=MIN_DECODE_LEN_WORD,
                            beam_width_sub=args.beam_width_sub,
                            max_decode_len_sub=MAX_DECODE_LEN_CHAR,
                            min_decode_len_sub=MIN_DECODE_LEN_CHAR,
                            length_penalty=args.length_penalty,
                            coverage_penalty=args.coverage_penalty,
                            progressbar=True,
                            resolving_unk=args.resolving_unk,
                            a2c_oracle=args.a2c_oracle,
                            joint_decoding=args.joint_decoding,
                            score_sub_weight=args.score_sub_weight)
        wer_sum += wer
        logger.info('  WER (%s, main): %.3f %%' % (data_type, (wer * 100)))
        logger.info(df)

        wer_sub, cer_sub, df_sub = eval_char(
            models=[model],
            dataset=dataset,
            eval_batch_size=args.eval_batch_size,
            beam_width=args.beam_width_sub,
            max_decode_len=MAX_DECODE_LEN_CHAR,
            min_decode_len=MIN_DECODE_LEN_CHAR,
            length_penalty=args.length_penalty,
            coverage_penalty=args.coverage_penalty,
            progressbar=True)
        wer_sub_sum += wer_sub
        cer_sub_sum += cer_sub
        logger.info(' WER / CER (%s, sub): %.3f / %.3f %%' % (data_type,
                                                              (wer_sub * 100),
                                                              (cer_sub * 100)))
        logger.info(df_sub)

    # Average over the sets actually evaluated, not a hard-coded count, so
    # editing data_types cannot silently skew the means.
    num_sets = len(data_types)
    logger.info('  WER (mean, main): %.3f %%' % (wer_sum * 100 / num_sets))
    logger.info('  WER / CER (mean, sub): %.3f / %.3f %%' %
                ((wer_sub_sum * 100 / num_sets),
                 (cer_sub_sum * 100 / num_sets)))
def main():
    """Decode one CSJ evaluation set and print per-utterance results.

    For every utterance this prints the reference, the main (word) and sub
    (character) hypotheses and their WER/CER.

    - With ``--resolving_unk``, hypotheses containing 'OOV' are also passed
      through ``resolve_unk`` and rescored.
    - For ``hierarchical_attention`` models with ``--joint_decoding`` set,
      an additional joint-decoding pass is printed and scored.

    Iteration stops after the first epoch of the dataset.
    """
    args = parser.parse_args()

    # Load a config file (.yml)
    params = load_config(join(args.model_path, 'config.yml'), is_eval=True)

    # Load dataset (other sets: data_type='eval2' or data_type='eval3')
    dataset = Dataset(
        data_save_path=args.data_save_path,
        backend=params['backend'],
        input_freq=params['input_freq'],
        use_delta=params['use_delta'],
        use_double_delta=params['use_double_delta'],
        data_type='eval1',
        data_size=params['data_size'],
        label_type=params['label_type'],
        label_type_sub=params['label_type_sub'],
        batch_size=args.eval_batch_size,
        splice=params['splice'],
        num_stack=params['num_stack'],
        num_skip=params['num_skip'],
        sort_utt=False,
        reverse=False,
        tool=params['tool'])

    params['num_classes'] = dataset.num_classes
    params['num_classes_sub'] = dataset.num_classes_sub

    # Load model
    model = load(model_type=params['model_type'],
                 params=params,
                 backend=params['backend'])

    # Restore the saved parameters
    model.load_checkpoint(save_path=args.model_path, epoch=args.epoch)

    # GPU setting
    model.set_cuda(deterministic=False, benchmark=True)

    # sys.stdout = open(join(model.model_dir, 'decode.txt'), 'w')

    word2char = Word2char(dataset.vocab_file_path, dataset.vocab_file_path_sub)

    # Joint decoding only runs for hierarchical attention models. The
    # condition cannot change between batches, so evaluate it once.
    use_joint = (model.model_type == 'hierarchical_attention' and
                 args.joint_decoding is not None)

    for batch, is_new_epoch in dataset:
        # Decode
        if model.model_type == 'nested_attention':
            best_hyps, aw, best_hyps_sub, aw_sub, _, perm_idx = model.decode(
                batch['xs'],
                batch['x_lens'],
                beam_width=args.beam_width,
                beam_width_sub=args.beam_width_sub,
                max_decode_len=MAX_DECODE_LEN_WORD,
                max_decode_len_sub=MAX_DECODE_LEN_CHAR,
                length_penalty=args.length_penalty,
                coverage_penalty=args.coverage_penalty)
        else:
            best_hyps, aw, perm_idx = model.decode(
                batch['xs'],
                batch['x_lens'],
                beam_width=args.beam_width,
                max_decode_len=MAX_DECODE_LEN_WORD,
                min_decode_len=MIN_DECODE_LEN_WORD,
                length_penalty=args.length_penalty,
                coverage_penalty=args.coverage_penalty)
            best_hyps_sub, aw_sub, _ = model.decode(
                batch['xs'],
                batch['x_lens'],
                beam_width=args.beam_width_sub,
                max_decode_len=MAX_DECODE_LEN_CHAR,
                min_decode_len=MIN_DECODE_LEN_CHAR,
                length_penalty=args.length_penalty,
                coverage_penalty=args.coverage_penalty,
                task_index=1)

        if use_joint:
            best_hyps_joint, aw_joint, best_hyps_sub_joint, aw_sub_joint, _ = model.decode(
                batch['xs'],
                batch['x_lens'],
                beam_width=args.beam_width,
                max_decode_len=MAX_DECODE_LEN_WORD,
                min_decode_len=MIN_DECODE_LEN_WORD,
                length_penalty=args.length_penalty,
                coverage_penalty=args.coverage_penalty,
                joint_decoding=args.joint_decoding,
                space_index=dataset.char2idx('_')[0],
                oov_index=dataset.word2idx('OOV')[0],
                word2char=word2char,
                idx2word=dataset.idx2word,
                idx2char=dataset.idx2char,
                score_sub_weight=args.score_sub_weight)

        # References are indexed with perm_idx returned by decode,
        # presumably so that they line up with best_hyps.
        ys = batch['ys'][perm_idx]
        y_lens = batch['y_lens'][perm_idx]
        ys_sub = batch['ys_sub'][perm_idx]
        y_lens_sub = batch['y_lens_sub'][perm_idx]

        for b in range(len(batch['xs'])):
            ##############################
            # Reference
            ##############################
            if dataset.is_test:
                str_ref = ys[b][0]
                str_ref_sub = ys_sub[b][0]
                # NOTE: transcript is separated by space('_')
            else:
                # Convert from list of index to string
                str_ref = dataset.idx2word(ys[b][:y_lens[b]])
                str_ref_sub = dataset.idx2char(ys_sub[b][:y_lens_sub[b]])

            ##############################
            # Hypothesis
            ##############################
            # Convert from list of index to string
            str_hyp = dataset.idx2word(best_hyps[b])
            str_hyp_sub = dataset.idx2char(best_hyps_sub[b])
            if use_joint:
                str_hyp_joint = dataset.idx2word(best_hyps_joint[b])
                str_hyp_sub_joint = dataset.idx2char(best_hyps_sub_joint[b])

            ##############################
            # Resolving UNK
            ##############################
            # The 'OOV' token is only replaced when --resolving_unk is set.
            # use_joint short-circuits before str_hyp_joint is read.
            fix_unk = 'OOV' in str_hyp and args.resolving_unk
            fix_unk_joint = (use_joint and 'OOV' in str_hyp_joint and
                             args.resolving_unk)
            if fix_unk:
                str_hyp_no_unk = resolve_unk(str_hyp, best_hyps_sub[b], aw[b],
                                             aw_sub[b], dataset.idx2char)
            if fix_unk_joint:
                str_hyp_no_unk_joint = resolve_unk(str_hyp_joint,
                                                   best_hyps_sub_joint[b],
                                                   aw_joint[b],
                                                   aw_sub_joint[b],
                                                   dataset.idx2char)

            print('----- wav: %s -----' % batch['input_names'][b])
            print('Ref         : %s' % str_ref.replace('_', ' '))
            print('Hyp (main)  : %s' % str_hyp.replace('_', ' '))
            print('Hyp (sub)   : %s' % str_hyp_sub.replace('_', ' '))
            if fix_unk:
                print('Hyp (no UNK): %s' % str_hyp_no_unk.replace('_', ' '))
            if use_joint:
                print('===== joint decoding =====')
                print('Hyp (main)  : %s' % str_hyp_joint.replace('_', ' '))
                print('Hyp (sub)   : %s' % str_hyp_sub_joint.replace('_', ' '))
                if fix_unk_joint:
                    print('Hyp (no UNK): %s' %
                          str_hyp_no_unk_joint.replace('_', ' '))

            # Scoring: everything from the first '>' onward is stripped from
            # hypotheses before comparison.
            try:
                wer, _, _, _ = compute_wer(ref=str_ref.split('_'),
                                           hyp=re.sub(r'(.*)_>(.*)', r'\1',
                                                      str_hyp).split('_'),
                                           normalize=True)
                print('WER (main)  : %.3f %%' % (wer * 100))
                if dataset.label_type_sub == 'character_wb':
                    wer_sub, _, _, _ = compute_wer(ref=str_ref_sub.split('_'),
                                                   hyp=re.sub(
                                                       r'(.*)>(.*)', r'\1',
                                                       str_hyp_sub).split('_'),
                                                   normalize=True)
                    print('WER (sub)   : %.3f %%' % (wer_sub * 100))
                else:
                    cer, _, _, _ = compute_wer(
                        ref=list(str_ref_sub.replace('_', '')),
                        hyp=list(
                            re.sub(r'(.*)>(.*)', r'\1',
                                   str_hyp_sub).replace('_', '')),
                        normalize=True)
                    print('CER (sub)   : %.3f %%' % (cer * 100))
                if fix_unk:
                    wer_no_unk, _, _, _ = compute_wer(
                        ref=str_ref.split('_'),
                        hyp=re.sub(r'(.*)_>(.*)', r'\1',
                                   str_hyp_no_unk.replace('*', '')).split('_'),
                        normalize=True)
                    print('WER (no UNK): %.3f %%' % (wer_no_unk * 100))

                if use_joint:
                    print('===== joint decoding =====')
                    wer_joint, _, _, _ = compute_wer(
                        ref=str_ref.split('_'),
                        hyp=re.sub(r'(.*)_>(.*)', r'\1',
                                   str_hyp_joint).split('_'),
                        normalize=True)
                    print('WER (main)  : %.3f %%' % (wer_joint * 100))
                    if fix_unk_joint:
                        wer_no_unk_joint, _, _, _ = compute_wer(
                            ref=str_ref.split('_'),
                            hyp=re.sub(r'(.*)_>(.*)', r'\1',
                                       str_hyp_no_unk_joint.replace(
                                           '*', '')).split('_'),
                            normalize=True)
                        print('WER (no UNK): %.3f %%' %
                              (wer_no_unk_joint * 100))

            except Exception:
                # Scoring is best effort: an utterance that fails to score is
                # reported and skipped. A bare `except:` would also swallow
                # KeyboardInterrupt/SystemExit and make Ctrl-C ineffective.
                print('--- skipped ---')
            print('\n')

        if is_new_epoch:
            break
示例#10
0
def main():
    """Evaluate a LibriSpeech model on ``test_clean`` and ``test_other``.

    Word-level label types are scored with ``do_eval_wer``; all others use
    ``do_eval_cer``, which yields both CER and WER. Per-split scores and
    their two-split means are printed.
    """
    args = parser.parse_args()
    params = load_config(join(args.model_path, 'config.yml'), is_eval=True)

    vocab_file_path = ('../metrics/vocab_files/' + params['label_type'] +
                       '_' + params['data_size'] + '.txt')

    def build_split(data_type):
        # Both test splits share every setting except data_type.
        return Dataset(backend=params['backend'],
                       input_channel=params['input_channel'],
                       use_delta=params['use_delta'],
                       use_double_delta=params['use_double_delta'],
                       data_type=data_type,
                       data_size=params['data_size'],
                       label_type=params['label_type'],
                       vocab_file_path=vocab_file_path,
                       batch_size=args.eval_batch_size,
                       splice=params['splice'],
                       num_stack=params['num_stack'],
                       num_skip=params['num_skip'],
                       sort_utt=False,
                       save_format=params['save_format'])

    test_clean_data = build_split('test_clean')
    test_other_data = build_split('test_other')
    params['num_classes'] = test_clean_data.num_classes

    model = load(model_type=params['model_type'], params=params,
                 backend=params['backend'])
    model.load_checkpoint(save_path=args.model_path, epoch=args.epoch)
    model.set_cuda(deterministic=False, benchmark=True)

    splits = (('clean', test_clean_data), ('other', test_other_data))

    if 'word' in params['label_type']:
        wers = []
        for tag, split_data in splits:
            wer = do_eval_wer(model=model,
                              dataset=split_data,
                              beam_width=args.beam_width,
                              max_decode_len=args.max_decode_len,
                              eval_batch_size=args.eval_batch_size,
                              progressbar=True)
            print('  WER (%s): %f %%' % (tag, wer * 100))
            wers.append(wer)
        wer_clean, wer_other = wers
        print('  WER (mean): %f %%' % ((wer_clean + wer_other) * 100 / 2))
        return

    cers, wers = [], []
    for tag, split_data in splits:
        cer, wer = do_eval_cer(model=model,
                               dataset=split_data,
                               beam_width=args.beam_width,
                               max_decode_len=args.max_decode_len,
                               eval_batch_size=args.eval_batch_size,
                               progressbar=True)
        print('  CER (%s): %f %%' % (tag, cer * 100))
        print('  WER (%s): %f %%' % (tag, wer * 100))
        cers.append(cer)
        wers.append(wer)
    cer_clean, cer_other = cers
    wer_clean, wer_other = wers
    print('  CER (mean): %f %%' % ((cer_clean + cer_other) * 100 / 2))
    print('  WER (mean): %f %%' % ((wer_clean + wer_other) * 100 / 2))
def main():
    """Plot CTC posteriors for every utterance of the 'test' split.

    The model configuration is read from ``<model_path>/config.yml``, the
    checkpoint for ``args.epoch`` is restored, and one PNG per utterance is
    written into ``<model_path>/ctc_probs`` (emptied first).
    """
    args = parser.parse_args()

    # Configuration stored next to the trained model
    params = load_config(join(args.model_path, 'config.yml'), is_eval=True)

    # Test split, loaded with sort_utt=True / reverse=True
    test_set = Dataset(data_save_path=args.data_save_path,
                       backend=params['backend'],
                       input_freq=params['input_freq'],
                       use_delta=params['use_delta'],
                       use_double_delta=params['use_double_delta'],
                       data_type='test',
                       label_type=params['label_type'],
                       batch_size=args.eval_batch_size,
                       splice=params['splice'],
                       num_stack=params['num_stack'],
                       num_skip=params['num_skip'],
                       sort_utt=True,
                       reverse=True,
                       tool=params['tool'])

    # Output layer size is taken from the dataset
    params['num_classes'] = test_set.num_classes

    # Build the network, restore its weights and configure CUDA
    model = load(model_type=params['model_type'],
                 params=params,
                 backend=params['backend'])
    model.load_checkpoint(save_path=args.model_path, epoch=args.epoch)
    model.set_cuda(deterministic=False, benchmark=True)

    out_dir = mkdir_join(args.model_path, 'ctc_probs')

    # Start from an empty output directory
    if out_dir is not None and isdir(out_dir):
        shutil.rmtree(out_dir)
        mkdir(out_dir)

    for batch, is_new_epoch in test_set:
        # NOTE: probs: '[B, T, num_classes]'
        probs, x_lens, _ = model.posteriors(batch['xs'], batch['x_lens'],
                                            temperature=1)

        for idx in range(len(batch['xs'])):
            num_frames = x_lens[idx]
            png_path = join(out_dir, batch['input_names'][idx] + '.png')
            plot_ctc_probs(probs[idx, :num_frames, :],
                           frame_num=num_frames,
                           num_stack=test_set.num_stack,
                           # the first 40 input feature dims are drawn
                           spectrogram=batch['xs'][idx, :, :40],
                           save_path=png_path,
                           figsize=(14, 7))

        # Stop after a single pass over the split
        if is_new_epoch:
            break
示例#12
0
def main():
    """Score a single-task model on the WSJ ``test_eval92`` split.

    Word-level models are scored with ``eval_word`` (WER only); any other
    label type is scored with ``eval_char`` (WER and CER). The per-utterance
    result table returned by the evaluator is printed afterwards.
    """
    args = parser.parse_args()

    # Configuration stored next to the trained model
    params = load_config(join(args.model_path, 'config.yml'), is_eval=True)

    # NOTE: use data_type='test_dev93' to score the dev set instead
    eval92_set = Dataset(
        data_save_path=args.data_save_path,
        backend=params['backend'],
        input_freq=params['input_freq'],
        use_delta=params['use_delta'],
        use_double_delta=params['use_double_delta'],
        data_type='test_eval92',
        data_size=params['data_size'],
        label_type=params['label_type'],
        batch_size=args.eval_batch_size,
        splice=params['splice'],
        num_stack=params['num_stack'],
        num_skip=params['num_skip'],
        sort_utt=False,
        tool=params['tool'])
    params['num_classes'] = eval92_set.num_classes

    # Build the network, restore its weights and configure CUDA
    model = load(model_type=params['model_type'],
                 params=params,
                 backend=params['backend'])
    model.load_checkpoint(save_path=args.model_path, epoch=args.epoch)
    model.set_cuda(deterministic=False, benchmark=True)

    print('beam width: %d' % args.beam_width)

    # Arguments common to both evaluators
    decode_kwargs = dict(models=[model],
                         dataset=eval92_set,
                         beam_width=args.beam_width,
                         eval_batch_size=args.eval_batch_size,
                         length_penalty=args.length_penalty,
                         progressbar=True)

    if params['label_type'] != 'word':
        wer, cer, result_df = eval_char(max_decode_len=MAX_DECODE_LEN_CHAR,
                                        **decode_kwargs)
        print('  WER / CER (eval92): %.3f / %.3f %%' % ((wer * 100),
                                                        (cer * 100)))
    else:
        wer, result_df = eval_word(max_decode_len=MAX_DECODE_LEN_WORD,
                                   **decode_kwargs)
        print('  WER (eval92): %.3f %%' % (wer * 100))
    print(result_df)
示例#13
0
def main():
    """Evaluate a two-task (main + sub) model on the eval2000 subsets.

    Switchboard (``eval2000_swbd``) and CallHome (``eval2000_ch``) are scored
    separately. The main task is scored with ``eval_word`` and the sub task
    with ``eval_char``. The mean over the two subsets is then printed.
    """
    args = parser.parse_args()

    # Load a config file (.yml)
    params = load_config(join(args.model_path, 'config.yml'), is_eval=True)

    # Feature / label settings shared by both eval2000 subsets
    dataset_kwargs = dict(data_save_path=args.data_save_path,
                          backend=params['backend'],
                          input_freq=params['input_freq'],
                          use_delta=params['use_delta'],
                          use_double_delta=params['use_double_delta'],
                          data_size=params['data_size'],
                          label_type=params['label_type'],
                          label_type_sub=params['label_type_sub'],
                          batch_size=args.eval_batch_size,
                          splice=params['splice'],
                          num_stack=params['num_stack'],
                          num_skip=params['num_skip'],
                          sort_utt=False,
                          tool=params['tool'])
    eval2000_swbd_data = Dataset(data_type='eval2000_swbd', **dataset_kwargs)
    eval2000_ch_data = Dataset(data_type='eval2000_ch', **dataset_kwargs)

    params['num_classes'] = eval2000_swbd_data.num_classes
    params['num_classes_sub'] = eval2000_swbd_data.num_classes_sub

    # Load model
    model = load(model_type=params['model_type'],
                 params=params,
                 backend=params['backend'])

    # Restore the saved parameters
    model.load_checkpoint(save_path=args.model_path, epoch=args.epoch)

    # GPU setting
    model.set_cuda(deterministic=False, benchmark=True)

    # Options forwarded to eval_word for the main task
    a2c_oracle = False
    resolving_unk = False

    ##############################
    # Switchboard
    ##############################
    wer_eval2000_swbd, df_wer_eval2000_swbd = eval_word(
        models=[model],
        dataset=eval2000_swbd_data,
        beam_width=args.beam_width,
        beam_width_sub=args.beam_width_sub,
        max_decode_len=MAX_DECODE_LEN_WORD,
        max_decode_len_sub=MAX_DECODE_LEN_CHAR,
        eval_batch_size=args.eval_batch_size,
        progressbar=True,
        resolving_unk=resolving_unk,
        a2c_oracle=a2c_oracle)
    print('  WER (SWB, main): %.3f %%' % (wer_eval2000_swbd * 100))
    print(df_wer_eval2000_swbd)
    wer_eval2000_swbd_sub, cer_eval2000_swbd_sub, _ = eval_char(
        models=[model],
        dataset=eval2000_swbd_data,
        beam_width=args.beam_width_sub,
        max_decode_len=MAX_DECODE_LEN_CHAR,
        eval_batch_size=args.eval_batch_size,
        progressbar=True)
    print(' WER / CER (SWB, sub): %.3f / %.3f %%' %
          ((wer_eval2000_swbd_sub * 100), (cer_eval2000_swbd_sub * 100)))

    ##############################
    # Callhome
    ##############################
    wer_eval2000_ch, df_wer_eval2000_ch = eval_word(
        models=[model],
        dataset=eval2000_ch_data,
        beam_width=args.beam_width,
        beam_width_sub=args.beam_width_sub,
        max_decode_len=MAX_DECODE_LEN_WORD,
        max_decode_len_sub=MAX_DECODE_LEN_CHAR,
        eval_batch_size=args.eval_batch_size,
        progressbar=True,
        resolving_unk=resolving_unk,
        a2c_oracle=a2c_oracle)
    print('  WER (CHE, main): %.3f %%' % (wer_eval2000_ch * 100))
    print(df_wer_eval2000_ch)
    # Use the same sub-task decode limit as the Switchboard evaluation
    # (previously read from args.max_decode_len_sub, unlike every other
    # sub-task evaluation).
    wer_eval2000_ch_sub, cer_eval2000_ch_sub, _ = eval_char(
        models=[model],
        dataset=eval2000_ch_data,
        beam_width=args.beam_width_sub,
        max_decode_len=MAX_DECODE_LEN_CHAR,
        eval_batch_size=args.eval_batch_size,
        progressbar=True)
    print('  WER / CER (CHE, sub): %.3f / %.3f %%' %
          ((wer_eval2000_ch_sub * 100), (cer_eval2000_ch_sub * 100)))

    # Unweighted mean over the two subsets
    print('  WER (mean, main): %.3f %%' %
          ((wer_eval2000_swbd + wer_eval2000_ch) * 100 / 2))
    print('  WER / CER (mean, sub): %.3f / %.3f %%' %
          (((wer_eval2000_swbd_sub + wer_eval2000_ch_sub) * 100 / 2),
           ((cer_eval2000_swbd_sub + cer_eval2000_ch_sub) * 100 / 2)))
def main():
    """Save one attention-weight figure per utterance of the 'eval1' split.

    Each utterance is decoded, and its attention weights are plotted together
    with the hypothesis tokens, the reference and the input features. Figures
    go to ``<model_path>/att_weights/<speaker>/<utterance>.png``. The
    ``att_weights`` directory is wiped first.
    """
    args = parser.parse_args()

    # Configuration stored next to the trained model
    params = load_config(join(args.model_path, 'config.yml'), is_eval=True)

    # NOTE: switch data_type to 'eval2' / 'eval3' to plot the other splits
    dataset = Dataset(
        data_save_path=args.data_save_path,
        backend=params['backend'],
        input_freq=params['input_freq'],
        use_delta=params['use_delta'],
        use_double_delta=params['use_double_delta'],
        data_type='eval1',
        data_size=params['data_size'],
        label_type=params['label_type'],
        batch_size=args.eval_batch_size,
        splice=params['splice'],
        num_stack=params['num_stack'],
        num_skip=params['num_skip'],
        sort_utt=False,
        reverse=False,
        tool=params['tool'])
    params['num_classes'] = dataset.num_classes

    # Build the network, restore its weights and configure CUDA
    model = load(model_type=params['model_type'],
                 params=params,
                 backend=params['backend'])
    model.load_checkpoint(save_path=args.model_path, epoch=args.epoch)
    model.set_cuda(deterministic=False, benchmark=True)

    out_dir = mkdir_join(args.model_path, 'att_weights')

    # Remove figures left over from an earlier run
    if out_dir is not None and isdir(out_dir):
        shutil.rmtree(out_dir)
        mkdir(out_dir)

    # Index-to-token mapping and decode limits depend on the label unit
    use_words = dataset.label_type == 'word'
    idx2token = dataset.idx2word if use_words else dataset.idx2char
    max_len = MAX_DECODE_LEN_WORD if use_words else MAX_DECODE_LEN_CHAR
    min_len = MIN_DECODE_LEN_WORD if use_words else MIN_DECODE_LEN_CHAR

    for batch, is_new_epoch in dataset:
        hyps, att_weights, perm_idx = model.decode(
            batch['xs'], batch['x_lens'],
            beam_width=args.beam_width,
            max_decode_len=max_len,
            min_decode_len=min_len,
            length_penalty=args.length_penalty,
            coverage_penalty=args.coverage_penalty)

        # Reorder the references with the permutation returned by decode()
        refs = batch['ys'][perm_idx]
        ref_lens = batch['y_lens'][perm_idx]

        for i in range(len(batch['xs'])):
            if dataset.is_test:
                # Test references are used as-is
                # NOTE: transcript is separated by space('_')
                str_ref = refs[i][0]
            else:
                # Otherwise map the label indices back to a string
                str_ref = idx2token(refs[i][:ref_lens[i]])

            tokens = idx2token(hyps[i], return_list=True)

            utt_name = batch['input_names'][i]
            speaker = utt_name.split('_')[0]
            plot_attention_weights(
                att_weights[i][:len(tokens), :batch['x_lens'][i]],
                label_list=tokens,
                spectrogram=batch['xs'][i, :, :dataset.input_freq],
                str_ref=str_ref,
                save_path=mkdir_join(out_dir, speaker, utt_name + '.png'),
                figsize=(20, 8))

        # Stop after a single pass over the split
        if is_new_epoch:
            break
def main():
    """Score a two-task (main + sub) model on WSJ ``test_eval92``.

    The main task is scored with ``eval_word`` (sub-task decoding settings
    are passed along). The sub task is scored on its own with ``eval_char``.
    Both result tables are printed.
    """
    args = parser.parse_args()

    # Configuration stored next to the trained model
    params = load_config(join(args.model_path, 'config.yml'), is_eval=True)

    eval92_set = Dataset(data_save_path=args.data_save_path,
                         backend=params['backend'],
                         input_freq=params['input_freq'],
                         use_delta=params['use_delta'],
                         use_double_delta=params['use_double_delta'],
                         data_type='test_eval92',
                         data_size=params['data_size'],
                         label_type=params['label_type'],
                         label_type_sub=params['label_type_sub'],
                         batch_size=args.eval_batch_size,
                         splice=params['splice'],
                         num_stack=params['num_stack'],
                         num_skip=params['num_skip'],
                         sort_utt=False,
                         tool=params['tool'])
    params['num_classes'] = eval92_set.num_classes
    params['num_classes_sub'] = eval92_set.num_classes_sub

    # Build the network, restore its weights and configure CUDA
    model = load(model_type=params['model_type'],
                 params=params,
                 backend=params['backend'])
    model.load_checkpoint(save_path=args.model_path, epoch=args.epoch)
    model.set_cuda(deterministic=False, benchmark=True)

    # Options forwarded to eval_word for the main task
    a2c_oracle = False
    resolving_unk = True

    print('beam width (main): %d' % args.beam_width)
    print('beam width (sub) : %d' % args.beam_width_sub)
    print('a2c oracle: %s' % str(a2c_oracle))
    print('resolving_unk: %s' % str(resolving_unk))

    # Main task
    wer_main, df_main = eval_word(models=[model],
                                  dataset=eval92_set,
                                  beam_width=args.beam_width,
                                  beam_width_sub=args.beam_width_sub,
                                  max_decode_len=MAX_DECODE_LEN_WORD,
                                  max_decode_len_sub=MAX_DECODE_LEN_CHAR,
                                  eval_batch_size=args.eval_batch_size,
                                  progressbar=True,
                                  resolving_unk=resolving_unk,
                                  a2c_oracle=a2c_oracle)
    print('  WER (eval92, main): %.3f %%' % (wer_main * 100))
    print(df_main)

    # Sub task
    wer_sub, cer_sub, df_sub = eval_char(models=[model],
                                         dataset=eval92_set,
                                         beam_width=args.beam_width_sub,
                                         max_decode_len=MAX_DECODE_LEN_CHAR,
                                         eval_batch_size=args.eval_batch_size,
                                         progressbar=True)
    print(' WER / CER (eval92, sub): %.3f / %.3f %%' %
          ((wer_sub * 100), (cer_sub * 100)))
    print(df_sub)
示例#16
0
def main():
    """Score a single-task model on eval1/eval2/eval3 and log to RESULTS.

    Each split is scored with ``eval_word`` (word labels, WER) or
    ``eval_char`` (other labels, WER and CER). Every per-split result and the
    mean over all splits are printed and written, one per line, to
    ``<model_path>/RESULTS``.
    """
    args = parser.parse_args()

    # Load a config file (.yml)
    params = load_config(join(args.model_path, 'config.yml'), is_eval=True)

    data_types = ['eval1', 'eval2', 'eval3']

    # Accumulators for the mean over splits (previously the mean was
    # computed from the last split only).
    wer_sum, cer_sum = 0, 0
    with open(join(args.model_path, 'RESULTS'), 'w') as f:

        def log(msg):
            """Print ``msg`` and write it as one line of RESULTS."""
            print(msg)
            f.write(msg + '\n')

        log('beam width: %d' % args.beam_width)

        for i, data_type in enumerate(data_types):
            # Load dataset
            eval_data = Dataset(
                data_save_path=args.data_save_path,
                backend=params['backend'],
                input_freq=params['input_freq'],
                use_delta=params['use_delta'],
                use_double_delta=params['use_double_delta'],
                data_type=data_type, data_size=params['data_size'],
                label_type=params['label_type'],
                batch_size=args.eval_batch_size, splice=params['splice'],
                num_stack=params['num_stack'], num_skip=params['num_skip'],
                shuffle=False, tool=params['tool'])

            if i == 0:
                # The model is built once, sized by the first split's
                # num_classes, and reused for the remaining splits.
                params['num_classes'] = eval_data.num_classes

                # Load model
                model = load(model_type=params['model_type'],
                             params=params,
                             backend=params['backend'])

                # Restore the saved parameters
                model.load_checkpoint(
                    save_path=args.model_path, epoch=args.epoch)

                # GPU setting
                model.set_cuda(deterministic=False, benchmark=True)

            if params['label_type'] == 'word':
                wer, df = eval_word(
                    models=[model],
                    dataset=eval_data,
                    eval_batch_size=args.eval_batch_size,
                    beam_width=args.beam_width,
                    max_decode_len=MAX_DECODE_LEN_WORD,
                    length_penalty=args.length_penalty,
                    progressbar=True)
                wer_sum += wer
                log('  WER (%s): %.3f %%' % (data_type, (wer * 100)))
            else:
                wer, cer, df = eval_char(
                    models=[model],
                    dataset=eval_data,
                    eval_batch_size=args.eval_batch_size,
                    beam_width=args.beam_width,
                    max_decode_len=MAX_DECODE_LEN_CHAR,
                    length_penalty=args.length_penalty,
                    progressbar=True)
                wer_sum += wer
                cer_sum += cer
                log('  WER / CER (%s): %.3f / %.3f %%' %
                    (data_type, (wer * 100), (cer * 100)))
            print(df)

        num_sets = len(data_types)
        if params['label_type'] == 'word':
            log('  WER (mean): %.3f %%' % (wer_sum * 100 / num_sets))
        else:
            log('  WER / CER (mean): %.3f / %.3f %%' %
                ((wer_sum * 100 / num_sets), (cer_sum * 100 / num_sets)))
示例#17
0
def main():
    """Evaluate an ensemble of models on LibriSpeech test_clean / test_other.

    Every sub-directory of the hard-coded results directory below is an
    ensemble candidate. Only those containing ``complete.txt`` are loaded.
    Dataset settings come from the first candidate's ``config.yml``. Only
    label types containing 'char' are supported.

    Raises:
        ValueError: if no candidate directory exists, or none is complete.
        NotImplementedError: for label types without 'char'.
    """
    # Sorted so that "the first model" (whose config defines the dataset) is
    # deterministic; glob() returns entries in arbitrary order.
    model_paths = sorted(glob(join(
        '/n/sd8/inaguma/result/pytorch/librispeech/ctc/character/100h',
        '*')))

    if len(model_paths) == 0:
        raise ValueError('There are no model path.')

    args = parser.parse_args()

    # Load a config file (.yml)
    params = load_config(join(model_paths[0], 'config.yml'), is_eval=True)

    # Load dataset
    vocab_file_path = '../metrics/vocab_files/' + \
        params['label_type'] + '_' + params['data_size'] + '.txt'
    dataset_kwargs = dict(backend=params['backend'],
                          input_channel=params['input_channel'],
                          use_delta=params['use_delta'],
                          use_double_delta=params['use_double_delta'],
                          data_size=params['data_size'],
                          label_type=params['label_type'],
                          vocab_file_path=vocab_file_path,
                          batch_size=args.eval_batch_size,
                          splice=params['splice'],
                          num_stack=params['num_stack'],
                          num_skip=params['num_skip'],
                          sort_utt=False,
                          save_format=params['save_format'])
    test_clean_data = Dataset(data_type='test_clean', **dataset_kwargs)
    test_other_data = Dataset(data_type='test_other', **dataset_kwargs)

    models = []
    for model_path in model_paths:
        # Skip models whose training did not finish
        if not isfile(join(model_path, 'complete.txt')):
            continue

        # Load a config file (.yml)
        params = load_config(join(model_path, 'config.yml'), is_eval=True)

        params['num_classes'] = test_clean_data.num_classes

        # Load model
        model = load(model_type=params['model_type'],
                     params=params,
                     backend=params['backend'])

        # Restore this member's own checkpoint (previously every member was
        # restored from args.model_path, so the ensemble was N copies).
        model.load_checkpoint(save_path=model_path, epoch=-1)

        # GPU setting
        model.set_cuda(deterministic=False, benchmark=True)

        models.append(model)

    if not models:
        raise ValueError('There are no completed models (complete.txt).')

    # NOTE: params now holds the config of the last loaded model
    print('=' * 30)
    print('  frame stack %d' % int(params['num_stack']))
    print('  beam width: %d' % args.beam_width)
    print('  ensemble: %d' % len(models))
    # %g instead of %d so fractional temperatures are not truncated
    print('  temperature (training): %g' % params['logits_temperature'])
    print('  temperature (inference): %g' % args.temperature)
    print('=' * 30)

    if 'char' not in params['label_type']:
        raise NotImplementedError

    cers, wers = [], []
    for name, data in (('clean', test_clean_data),
                       ('other', test_other_data)):
        # The same inference temperature is applied to both test sets
        # (test_other previously fell back to do_eval_cer's default).
        cer, wer = do_eval_cer(
            models=models,
            dataset=data,
            beam_width=args.beam_width,
            max_decode_len=args.max_decode_len,
            eval_batch_size=args.eval_batch_size,
            temperature=args.temperature,
            progressbar=True)
        print('  CER (%s): %f %%' % (name, cer * 100))
        print('  WER (%s): %f %%' % (name, wer * 100))
        cers.append(cer)
        wers.append(wer)

    print('  CER (mean): %f %%' % (sum(cers) * 100 / len(cers)))
    print('  WER (mean): %f %%' % (sum(wers) * 100 / len(wers)))
def main():
    """Plot main-task and sub-task attention weights side by side.

    Each utterance of the 'eval1' split is decoded twice: once with the main
    task and once with the sub task (``task_index=1``). One combined figure
    per utterance is written under
    ``<model_path>/att_weights/<speaker>/``. The directory is wiped first.
    """
    args = parser.parse_args()

    # Configuration stored next to the trained model
    params = load_config(join(args.model_path, 'config.yml'), is_eval=True)

    # NOTE: switch data_type to 'eval2' / 'eval3' to plot the other splits
    dataset = Dataset(
        data_save_path=args.data_save_path,
        backend=params['backend'],
        input_freq=params['input_freq'],
        use_delta=params['use_delta'],
        use_double_delta=params['use_double_delta'],
        data_type='eval1',
        data_size=params['data_size'],
        label_type=params['label_type'],
        label_type_sub=params['label_type_sub'],
        batch_size=args.eval_batch_size,
        splice=params['splice'],
        num_stack=params['num_stack'],
        num_skip=params['num_skip'],
        sort_utt=False,
        reverse=False,
        tool=params['tool'])
    params['num_classes'] = dataset.num_classes
    params['num_classes_sub'] = dataset.num_classes_sub

    # Build the network, restore its weights and configure CUDA
    model = load(model_type=params['model_type'],
                 params=params,
                 backend=params['backend'])
    model.load_checkpoint(save_path=args.model_path, epoch=args.epoch)
    model.set_cuda(deterministic=False, benchmark=True)

    fig_dir = mkdir_join(args.model_path, 'att_weights')

    # Remove figures left over from an earlier run
    if fig_dir is not None and isdir(fig_dir):
        shutil.rmtree(fig_dir)
        mkdir(fig_dir)

    # Penalties shared by both decoding passes
    penalties = dict(length_penalty=args.length_penalty,
                     coverage_penalty=args.coverage_penalty)

    for batch, is_new_epoch in dataset:
        xs, x_lens = batch['xs'], batch['x_lens']

        # Main task, decoded with the word-length limits
        word_hyps, word_aw, _ = model.decode(
            xs, x_lens,
            beam_width=args.beam_width,
            max_decode_len=MAX_DECODE_LEN_WORD,
            min_decode_len=MIN_DECODE_LEN_WORD,
            **penalties)
        # Sub task (task_index=1), decoded with the char-length limits
        char_hyps, char_aw, _ = model.decode(
            xs, x_lens,
            beam_width=args.beam_width_sub,
            max_decode_len=MAX_DECODE_LEN_CHAR,
            min_decode_len=MIN_DECODE_LEN_CHAR,
            task_index=1,
            **penalties)

        for i in range(len(xs)):
            words = dataset.idx2word(word_hyps[i], return_list=True)
            chars = dataset.idx2char(char_hyps[i], return_list=True)

            utt_name = batch['input_names'][i]
            speaker = utt_name.split('_')[0]
            num_frames = x_lens[i]

            plot_hierarchical_attention_weights(
                word_aw[i][:len(words), :num_frames],
                char_aw[i][:len(chars), :num_frames],
                label_list=words,
                label_list_sub=chars,
                spectrogram=xs[i, :, :dataset.input_freq],
                save_path=mkdir_join(fig_dir, speaker, utt_name + '.png'),
                figsize=(40, 8))

        # Stop after a single pass over the split
        if is_new_epoch:
            break
示例#19
0
def _eval_split(model, dataset, label_type, split_name, logger):
    """Greedy-decode one split with ``model`` and log its error rate(s).

    Args:
        model: model to evaluate
        dataset (Dataset): split to decode
        label_type (str): ``params['label_type']``; types containing
            ``'word'`` are scored with ``do_eval_wer``, all others with
            ``do_eval_cer``
        split_name (str): name shown in the log line, e.g. ``'dev-clean'``
        logger: logger returned by ``set_logger``
    Returns:
        tuple: ``(metric, wer)``. ``metric`` is the WER for word models and
            the CER otherwise; ``wer`` is always the WER.
    """
    if 'word' in label_type:
        wer, _ = do_eval_wer(models=[model],
                             dataset=dataset,
                             beam_width=1,
                             max_decode_len=MAX_DECODE_LEN_WORD,
                             eval_batch_size=1)
        logger.info('  WER (%s): %.3f %%' % (split_name, wer * 100))
        return wer, wer

    cer, wer, _ = do_eval_cer(models=[model],
                              dataset=dataset,
                              beam_width=1,
                              max_decode_len=MAX_DECODE_LEN_CHAR,
                              eval_batch_size=1)
    logger.info('  CER / WER (%s): %.3f %% / %.3f %%' %
                (split_name, cer * 100, wer * 100))
    return cer, wer


def main():
    """Train a model, either from scratch or resumed from its last checkpoint.

    Exactly one of ``args.model_save_path`` (new run; config read from
    ``args.config_path``) or ``args.saved_model_path`` (resume; config read
    from ``<saved_model_path>/config.yml``) must be set, otherwise
    ``ValueError`` is raised. Training stops after ``num_epoch`` epochs or
    once the dev-clean metric has not improved for
    ``not_improved_patient_epoch`` evaluated epochs; an empty ``COMPLETE``
    file is then written into the model's save directory.
    """
    args = parser.parse_args()

    ##################################################
    # DATASET
    ##################################################
    if args.model_save_path is not None:
        # Load a config file (.yml)
        params = load_config(args.config_path)
    # NOTE: Retrain the saved model from the last checkpoint
    elif args.saved_model_path is not None:
        params = load_config(os.path.join(args.saved_model_path, 'config.yml'))
    else:
        raise ValueError("Set model_save_path or saved_model_path.")

    # Load dataset
    train_data = Dataset(data_save_path=args.data_save_path,
                         backend=params['backend'],
                         input_channel=params['input_channel'],
                         use_delta=params['use_delta'],
                         use_double_delta=params['use_double_delta'],
                         data_type='train',
                         data_size=params['data_size'],
                         label_type=params['label_type'],
                         batch_size=params['batch_size'],
                         max_epoch=params['num_epoch'],
                         splice=params['splice'],
                         num_stack=params['num_stack'],
                         num_skip=params['num_skip'],
                         sort_utt=True,
                         sort_stop_epoch=params['sort_stop_epoch'],
                         tool=params['tool'],
                         num_enque=None,
                         dynamic_batching=params['dynamic_batching'])
    dev_clean_data = Dataset(data_save_path=args.data_save_path,
                             backend=params['backend'],
                             input_channel=params['input_channel'],
                             use_delta=params['use_delta'],
                             use_double_delta=params['use_double_delta'],
                             data_type='dev_clean',
                             data_size=params['data_size'],
                             label_type=params['label_type'],
                             batch_size=params['batch_size'],
                             splice=params['splice'],
                             num_stack=params['num_stack'],
                             num_skip=params['num_skip'],
                             shuffle=True,
                             tool=params['tool'])
    dev_other_data = Dataset(data_save_path=args.data_save_path,
                             backend=params['backend'],
                             input_channel=params['input_channel'],
                             use_delta=params['use_delta'],
                             use_double_delta=params['use_double_delta'],
                             data_type='dev_other',
                             data_size=params['data_size'],
                             label_type=params['label_type'],
                             batch_size=params['batch_size'],
                             splice=params['splice'],
                             num_stack=params['num_stack'],
                             num_skip=params['num_skip'],
                             shuffle=True,
                             tool=params['tool'])
    test_clean_data = Dataset(data_save_path=args.data_save_path,
                              backend=params['backend'],
                              input_channel=params['input_channel'],
                              use_delta=params['use_delta'],
                              use_double_delta=params['use_double_delta'],
                              data_type='test_clean',
                              data_size=params['data_size'],
                              label_type=params['label_type'],
                              batch_size=params['batch_size'],
                              splice=params['splice'],
                              num_stack=params['num_stack'],
                              num_skip=params['num_skip'],
                              tool=params['tool'])
    test_other_data = Dataset(data_save_path=args.data_save_path,
                              backend=params['backend'],
                              input_channel=params['input_channel'],
                              use_delta=params['use_delta'],
                              use_double_delta=params['use_double_delta'],
                              data_type='test_other',
                              data_size=params['data_size'],
                              label_type=params['label_type'],
                              batch_size=params['batch_size'],
                              splice=params['splice'],
                              num_stack=params['num_stack'],
                              num_skip=params['num_skip'],
                              tool=params['tool'])

    params['num_classes'] = train_data.num_classes

    ##################################################
    # MODEL
    ##################################################
    # Model setting
    model = load(model_type=params['model_type'],
                 params=params,
                 backend=params['backend'])

    if args.model_save_path is not None:

        # Set save path
        save_path = mkdir_join(args.model_save_path, params['backend'],
                               params['model_type'], params['label_type'],
                               params['data_size'], model.name)
        model.set_save_path(save_path)

        # Save config file
        save_config(config_path=args.config_path, save_path=model.save_path)

        # Setting for logging
        logger = set_logger(model.save_path)

        # os.path.isdir(None) raises TypeError, so guard against an
        # unset / empty 'char_init' entry before testing the path.
        char_init = params.get('char_init')
        if char_init and os.path.isdir(char_init):
            # NOTE: Start training from the pre-trained character model
            model.load_checkpoint(save_path=char_init,
                                  epoch=-1,
                                  load_pretrained_model=True)

        # Count total parameters
        for name in sorted(model.num_params_dict.keys()):
            logger.info("%s %d" % (name, model.num_params_dict[name]))
        logger.info("Total %.3f M parameters" %
                    (model.total_parameters / 1000000))

        # Define optimizer
        model.set_optimizer(optimizer=params['optimizer'],
                            learning_rate_init=float(params['learning_rate']),
                            weight_decay=float(params['weight_decay']),
                            clip_grad_norm=params['clip_grad_norm'],
                            lr_schedule=False,
                            factor=params['decay_rate'],
                            patience_epoch=params['decay_patient_epoch'])

        epoch, step = 1, 0
        learning_rate = float(params['learning_rate'])
        # WER/CER can exceed 1.0 (many insertions), so "no best yet" must be
        # +inf; a sentinel of 1 would never be beaten by such a score and
        # no checkpoint would be saved.
        metric_dev_best = float('inf')

    # NOTE: Retrain the saved model from the last checkpoint
    elif args.saved_model_path is not None:

        # Set save path
        model.save_path = args.saved_model_path

        # Setting for logging
        logger = set_logger(model.save_path, restart=True)

        # Define optimizer
        model.set_optimizer(
            optimizer=params['optimizer'],
            learning_rate_init=float(params['learning_rate']),  # on-the-fly
            weight_decay=float(params['weight_decay']),
            clip_grad_norm=params['clip_grad_norm'],
            lr_schedule=False,
            factor=params['decay_rate'],
            patience_epoch=params['decay_patient_epoch'])

        # Restore the last saved model
        epoch, step, learning_rate, metric_dev_best = model.load_checkpoint(
            save_path=args.saved_model_path, epoch=-1, restart=True)

    else:
        raise ValueError("Set model_save_path or saved_model_path.")

    train_data.epoch = epoch - 1

    # GPU setting
    model.set_cuda(deterministic=False, benchmark=True)

    logger.info('PID: %s' % os.getpid())
    # os.uname()[1] is the node (host) name, not the user name
    logger.info('HOSTNAME: %s' % os.uname()[1])

    # Set process name
    setproctitle('libri_' + params['backend'] + '_' + params['model_type'] +
                 '_' + params['label_type'] + '_' + params['data_size'])

    ##################################################
    # TRAINING LOOP
    ##################################################
    # Define learning rate controller
    lr_controller = Controller(
        learning_rate_init=learning_rate,
        backend=params['backend'],
        decay_start_epoch=params['decay_start_epoch'],
        decay_rate=params['decay_rate'],
        decay_patient_epoch=params['decay_patient_epoch'],
        lower_better=True)

    # Setting for tensorboard
    if params['backend'] == 'pytorch':
        tf_writer = SummaryWriter(model.save_path)

    # Train model
    csv_steps, csv_loss_train, csv_loss_dev = [], [], []
    start_time_train = time.time()
    start_time_epoch = time.time()
    start_time_step = time.time()
    not_improved_epoch = 0
    loss_train_mean = 0.
    pbar_epoch = tqdm(total=len(train_data))
    while True:
        # Compute loss in the training set (including parameter update)
        batch_train, is_new_epoch = train_data.next()
        model, loss_train_val = train_step(model,
                                           batch_train,
                                           params['clip_grad_norm'],
                                           backend=params['backend'])
        loss_train_mean += loss_train_val

        pbar_epoch.update(len(batch_train['xs']))

        if (step + 1) % params['print_step'] == 0:

            # Compute loss on one dev batch
            batch_dev = dev_clean_data.next()[0]
            loss_dev = model(batch_dev['xs'],
                             batch_dev['ys'],
                             batch_dev['x_lens'],
                             batch_dev['y_lens'],
                             is_eval=True)

            loss_train_mean /= params['print_step']
            csv_steps.append(step)
            csv_loss_train.append(loss_train_mean)
            csv_loss_dev.append(loss_dev)

            # Logging by tensorboard
            if params['backend'] == 'pytorch':
                tf_writer.add_scalar('train/loss', loss_train_mean, step + 1)
                tf_writer.add_scalar('dev/loss', loss_dev, step + 1)
                for name, param in model.named_parameters():
                    name = name.replace('.', '/')
                    tf_writer.add_histogram(name,
                                            param.data.cpu().numpy(), step + 1)
                    # .grad stays None for parameters autograd produced no
                    # gradient for; dereferencing it would crash training.
                    if param.grad is not None:
                        tf_writer.add_histogram(name + '/grad',
                                                param.grad.data.cpu().numpy(),
                                                step + 1)

            duration_step = time.time() - start_time_step
            logger.info(
                "...Step:%d(epoch:%.3f) loss:%.3f(%.3f)/lr:%.5f/batch:%d/x_lens:%d (%.3f min)"
                % (step + 1, train_data.epoch_detail, loss_train_mean,
                   loss_dev, learning_rate, train_data.current_batch_size,
                   max(batch_train['x_lens']) * params['num_stack'],
                   duration_step / 60))
            start_time_step = time.time()
            loss_train_mean = 0.
        step += 1

        # Save checkpoint and evaluate model per epoch
        if is_new_epoch:
            duration_epoch = time.time() - start_time_epoch
            logger.info('===== EPOCH:%d (%.3f min) =====' %
                        (epoch, duration_epoch / 60))

            # Save figure of loss
            plot_loss(csv_loss_train,
                      csv_loss_dev,
                      csv_steps,
                      save_path=model.save_path)

            if epoch < params['eval_start_epoch']:
                # Save the model
                model.save_checkpoint(model.save_path, epoch, step,
                                      learning_rate, metric_dev_best)
            else:
                start_time_eval = time.time()
                # dev
                metric_dev_epoch, _ = _eval_split(
                    model, dev_clean_data, params['label_type'], 'dev-clean',
                    logger)

                if metric_dev_epoch < metric_dev_best:
                    metric_dev_best = metric_dev_epoch
                    not_improved_epoch = 0
                    logger.info('||||| Best Score |||||')

                    # Save the model
                    model.save_checkpoint(model.save_path, epoch, step,
                                          learning_rate, metric_dev_best)

                    # dev-other & test
                    _eval_split(model, dev_other_data, params['label_type'],
                                'dev-other', logger)
                    metric_test_clean, wer_test_clean = _eval_split(
                        model, test_clean_data, params['label_type'],
                        'test-clean', logger)
                    metric_test_other, wer_test_other = _eval_split(
                        model, test_other_data, params['label_type'],
                        'test-other', logger)

                    if 'word' in params['label_type']:
                        logger.info(
                            '  WER (test-mean): %.3f %%' %
                            ((wer_test_clean + wer_test_other) * 100 / 2))
                    else:
                        logger.info(
                            '  CER / WER (test-mean): %.3f %% / %.3f %%' %
                            (((metric_test_clean + metric_test_other) * 100 / 2),
                             ((wer_test_clean + wer_test_other) * 100 / 2)))
                else:
                    not_improved_epoch += 1

                duration_eval = time.time() - start_time_eval
                logger.info('Evaluation time: %.3f min' % (duration_eval / 60))

                # Early stopping (>= so an already-exceeded counter also stops)
                if not_improved_epoch >= params['not_improved_patient_epoch']:
                    break

                # Update learning rate
                model.optimizer, learning_rate = lr_controller.decay_lr(
                    optimizer=model.optimizer,
                    learning_rate=learning_rate,
                    epoch=epoch,
                    value=metric_dev_epoch)

                if epoch == params['convert_to_sgd_epoch']:
                    # Convert to fine-tuning stage
                    model.set_optimizer(
                        'sgd',
                        learning_rate_init=learning_rate,
                        weight_decay=float(params['weight_decay']),
                        clip_grad_norm=params['clip_grad_norm'],
                        lr_schedule=False,
                        factor=params['decay_rate'],
                        patience_epoch=params['decay_patient_epoch'])
                    logger.info('========== Convert to SGD ==========')

                    # Inject Gaussian noise to all parameters
                    if float(params['weight_noise_std']) > 0:
                        model.weight_noise_injection = True

            # Close the finished epoch's progress bar before starting a new one
            pbar_epoch.close()
            pbar_epoch = tqdm(total=len(train_data))
            print('========== EPOCH:%d (%.3f min) ==========' %
                  (epoch, duration_epoch / 60))

            if epoch == params['num_epoch']:
                break

            start_time_step = time.time()
            start_time_epoch = time.time()
            epoch += 1

    pbar_epoch.close()

    # TODO: evaluate the best model by beam search here

    duration_train = time.time() - start_time_train
    logger.info('Total time: %.3f hour' % (duration_train / 3600))

    if params['backend'] == 'pytorch':
        tf_writer.close()

    # Training was finished correctly
    with open(os.path.join(model.save_path, 'COMPLETE'), 'w') as f:
        f.write('')
# Example #20
# 0
def main():
    """Decode the test split with a phone-level model and print the results.

    For every utterance this prints the reference, the attention-decoder
    hypothesis and its PER. When the model is an attention model with a
    positive ``ctc_loss_weight``, the CTC hypothesis and its PER are printed
    as well.
    """
    opts = parser.parse_args()

    # Restore the configuration saved next to the model
    config = load_config(join(opts.model_path, 'config.yml'), is_eval=True)

    dataset = Dataset(
        data_save_path=opts.data_save_path,
        backend=config['backend'],
        input_freq=config['input_freq'],
        use_delta=config['use_delta'],
        use_double_delta=config['use_double_delta'],
        data_type='test',
        label_type=config['label_type'],
        batch_size=opts.eval_batch_size,
        splice=config['splice'],
        num_stack=config['num_stack'],
        num_skip=config['num_skip'],
        sort_utt=True,
        reverse=True,
        tool=config['tool'])
    config['num_classes'] = dataset.num_classes

    # Build the network, restore the requested epoch and move it to the GPU
    model = load(model_type=config['model_type'],
                 params=config,
                 backend=config['backend'])
    model.load_checkpoint(save_path=opts.model_path, epoch=opts.epoch)
    model.set_cuda(deterministic=False, benchmark=True)

    for batch, is_new_epoch in dataset:
        best_hyps, _, perm_idx = model.decode(
            batch['xs'], batch['x_lens'],
            beam_width=opts.beam_width,
            max_decode_len=MAX_DECODE_LEN_PHONE,
            min_decode_len=MIN_DECODE_LEN_PHONE,
            length_penalty=opts.length_penalty,
            coverage_penalty=opts.coverage_penalty)

        with_ctc = model.model_type == 'attention' and model.ctc_loss_weight > 0
        if with_ctc:
            # NOTE(review): this replaces the attention decoder's perm_idx,
            # i.e. it assumes both decoders order the batch identically.
            best_hyps_ctc, perm_idx = model.decode_ctc(
                batch['xs'], batch['x_lens'],
                beam_width=opts.beam_width)

        ys = batch['ys'][perm_idx]
        y_lens = batch['y_lens'][perm_idx]

        for b in range(len(batch['xs'])):
            # Test references are already space-separated transcripts;
            # otherwise they are padded index sequences to be converted.
            if dataset.is_test:
                ref = ys[b][0]
            else:
                ref = dataset.idx2phone(ys[b][:y_lens[b]])
            hyp = dataset.idx2phone(best_hyps[b])

            print('----- wav: %s -----' % batch['input_names'][b])
            print('Ref      : %s' % ref)
            print('Hyp      : %s' % hyp)
            if with_ctc:
                hyp_ctc = dataset.idx2phone(best_hyps_ctc[b])
                print('Hyp (CTC): %s' % hyp_ctc)

            # The greedy regex keeps only the text before the LAST " >"
            # (presumably an end-of-sentence marker) before scoring.
            hyp_trimmed = re.sub(r'(.*) >(.*)', r'\1', hyp)
            per, _, _, _ = compute_wer(ref=ref.split(' '),
                                       hyp=hyp_trimmed.split(' '),
                                       normalize=True)
            print('PER: %.3f %%' % (per * 100))
            if with_ctc:
                per_ctc, _, _, _ = compute_wer(ref=ref.split(' '),
                                               hyp=hyp_ctc.split(' '),
                                               normalize=True)
                print('PER (CTC): %.3f %%' % (per_ctc * 100))

        if is_new_epoch:
            break
# Example #21
# 0
def main():
    """Evaluate a saved model on the dev_clean/dev_other/test_clean/test_other splits.

    Loads ``config.yml`` from ``args.model_path``, builds the model once (on
    the first split) and restores the checkpoint for ``args.epoch``, then
    decodes every split and logs its error rate(s) together with the ``df``
    returned by the evaluator.
    """
    args = parser.parse_args()

    # Load a config file (.yml)
    params = load_config(join(args.model_path, 'config.yml'), is_eval=True)

    logger = set_logger(args.model_path)

    # Match the training script's test (`'word' in label_type`): label types
    # that merely contain 'word' (e.g. with a suffix) were previously sent to
    # eval_char by an exact `== 'word'` comparison.
    is_word = 'word' in params['label_type']

    for i, data_type in enumerate(
            ['dev_clean', 'dev_other', 'test_clean', 'test_other']):
        # Load dataset
        dataset = Dataset(data_save_path=args.data_save_path,
                          backend=params['backend'],
                          input_freq=params['input_freq'],
                          use_delta=params['use_delta'],
                          use_double_delta=params['use_double_delta'],
                          data_type=data_type,
                          data_size=params['data_size'],
                          label_type=params['label_type'],
                          batch_size=args.eval_batch_size,
                          splice=params['splice'],
                          num_stack=params['num_stack'],
                          num_skip=params['num_skip'],
                          sort_utt=False,
                          tool=params['tool'])

        if i == 0:
            # The output size is only known once a dataset exists
            params['num_classes'] = dataset.num_classes

            # Load model
            model = load(model_type=params['model_type'],
                         params=params,
                         backend=params['backend'])

            # Restore the saved parameters
            epoch, _, _, _ = model.load_checkpoint(save_path=args.model_path,
                                                   epoch=args.epoch)

            # GPU setting
            model.set_cuda(deterministic=False, benchmark=True)

            logger.info('beam width: %d' % args.beam_width)
            logger.info('epoch: %d' % (epoch - 1))

        if is_word:
            wer, df = eval_word(models=[model],
                                dataset=dataset,
                                eval_batch_size=args.eval_batch_size,
                                beam_width=args.beam_width,
                                max_decode_len=MAX_DECODE_LEN_WORD,
                                min_decode_len=MIN_DECODE_LEN_WORD,
                                length_penalty=args.length_penalty,
                                coverage_penalty=args.coverage_penalty,
                                progressbar=True)
            # Label the line with the split name; label_type is identical
            # for every split and made the four results indistinguishable.
            logger.info('  WER (%s): %.3f %%' % (data_type, wer * 100))
            logger.info(df)
        else:
            wer, cer, df = eval_char(models=[model],
                                     dataset=dataset,
                                     eval_batch_size=args.eval_batch_size,
                                     beam_width=args.beam_width,
                                     max_decode_len=MAX_DECODE_LEN_CHAR,
                                     min_decode_len=MIN_DECODE_LEN_CHAR,
                                     length_penalty=args.length_penalty,
                                     coverage_penalty=args.coverage_penalty,
                                     progressbar=True)
            logger.info('  WER / CER (%s): %.3f %% / %.3f %%' %
                        (data_type, wer * 100, cer * 100))
            logger.info(df)
def main():
    """Plot the attention weights of a hierarchical word / sub-unit model.

    Decodes the ``eval1`` split and, for each utterance, saves two figures
    under ``<model_path>/att_weights/<speaker>/``:
    ``<utt>.png`` shows the word and sub-unit attention over the input
    frames, and ``<utt>_word2char.png`` shows the decoder's word-to-sub-unit
    attention. Any previous contents of ``att_weights`` are removed first.
    """
    args = parser.parse_args()

    # Configuration stored alongside the trained model
    params = load_config(join(args.model_path, 'config.yml'), is_eval=True)

    # Change data_type to 'eval2' or 'eval3' to plot the other eval sets
    dataset = Dataset(
        data_save_path=args.data_save_path,
        backend=params['backend'],
        input_freq=params['input_freq'],
        use_delta=params['use_delta'],
        use_double_delta=params['use_double_delta'],
        data_type='eval1',
        data_size=params['data_size'],
        label_type=params['label_type'],
        label_type_sub=params['label_type_sub'],
        batch_size=args.eval_batch_size,
        splice=params['splice'],
        num_stack=params['num_stack'],
        num_skip=params['num_skip'],
        sort_utt=False,
        reverse=False,
        tool=params['tool'])
    params['num_classes'] = dataset.num_classes
    params['num_classes_sub'] = dataset.num_classes_sub

    # Build the network, restore the requested epoch and move it to the GPU
    model = load(model_type=params['model_type'],
                 params=params,
                 backend=params['backend'])
    model.load_checkpoint(save_path=args.model_path, epoch=args.epoch)
    model.set_cuda(deterministic=False, benchmark=True)

    # When True, the sub-task decoder is teacher-forced with the references
    a2c_oracle = False

    save_path = mkdir_join(args.model_path, 'att_weights')

    # Start from an empty output directory
    if save_path is not None and isdir(save_path):
        shutil.rmtree(save_path)
        mkdir(save_path)

    for batch, is_new_epoch in dataset:
        num_utts = len(batch['xs'])

        if not a2c_oracle:
            ys_sub, y_lens_sub = None, None
        elif dataset.is_test:
            # Test references are raw transcripts: convert them to indices
            # and right-pad with -1 into a (num_utts, max_len) int32 matrix.
            transcripts = [batch['ys_sub'][b][0] for b in range(num_utts)]
            max_label_num = max((len(list(t)) for t in transcripts),
                                default=0)
            ys_sub = np.full((num_utts, max_label_num), -1, dtype=np.int32)
            y_lens_sub = np.zeros((num_utts, ), dtype=np.int32)
            for b, transcript in enumerate(transcripts):
                indices = dataset.char2idx(transcript)
                ys_sub[b, :len(indices)] = indices
                y_lens_sub[b] = len(indices)
        else:
            ys_sub, y_lens_sub = batch['ys_sub'], batch['y_lens_sub']

        best_hyps, aw, best_hyps_sub, aw_sub, aw_dec, _ = model.decode(
            batch['xs'],
            batch['x_lens'],
            beam_width=args.beam_width,
            max_decode_len=MAX_DECODE_LEN_WORD,
            min_decode_len=MIN_DECODE_LEN_WORD,
            beam_width_sub=args.beam_width_sub,
            max_decode_len_sub=MAX_DECODE_LEN_CHAR,
            min_decode_len_sub=MIN_DECODE_LEN_CHAR,
            length_penalty=args.length_penalty,
            coverage_penalty=args.coverage_penalty,
            teacher_forcing=a2c_oracle,
            ys_sub=ys_sub,
            y_lens_sub=y_lens_sub)

        for b in range(num_utts):
            utt_name = batch['input_names'][b]
            x_len = batch['x_lens'][b]

            word_list = dataset.idx2word(best_hyps[b], return_list=True)
            if dataset.label_type_sub == 'word':
                sub_list = dataset.idx2word(best_hyps_sub[b],
                                            return_list=True)
            else:
                sub_list = dataset.idx2char(best_hyps_sub[b],
                                            return_list=True)

            speaker = utt_name.split('_')[0]

            # Attention of both tasks over the input, cropped to the decoded
            # label counts and to x_lens[b] input frames
            plot_hierarchical_attention_weights(
                aw[b][:len(word_list), :x_len],
                aw_sub[b][:len(sub_list), :x_len],
                label_list=word_list,
                label_list_sub=sub_list,
                spectrogram=batch['xs'][b, :, :dataset.input_freq],
                save_path=mkdir_join(save_path, speaker, utt_name + '.png'),
                figsize=(40, 8))

            # Decoder-side attention from each word to the sub-unit sequence
            plot_nested_attention_weights(
                aw_dec[b][:len(word_list), :len(sub_list)],
                label_list=word_list,
                label_list_sub=sub_list,
                save_path=mkdir_join(save_path, speaker,
                                     utt_name + '_word2char.png'),
                figsize=(40, 8))

        if is_new_epoch:
            break