def main():
    """Visualize a restored model on the ``test_eval92`` set.

    The model is rebuilt from the ``config.yml`` found under
    ``args.model_path``, the checkpoint selected by ``args.epoch`` is
    restored, and ``plot`` is asked to save its output under
    ``<model_path>/att_weights``.
    """
    args = parser.parse_args()

    # The run's configuration is stored alongside its checkpoints
    config = load_config(join(args.model_path, 'config.yml'), is_eval=True)

    # Evaluation data: utterances are not sorted, batch size comes from the CLI
    dataset_kwargs = {
        'data_save_path': args.data_save_path,
        'backend': config['backend'],
        'input_freq': config['input_freq'],
        'use_delta': config['use_delta'],
        'use_double_delta': config['use_double_delta'],
        'data_type': 'test_eval92',
        'data_size': config['data_size'],
        'label_type': config['label_type'],
        'label_type_sub': config['label_type_sub'],
        'batch_size': args.eval_batch_size,
        'splice': config['splice'],
        'num_stack': config['num_stack'],
        'num_skip': config['num_skip'],
        'sort_utt': False,
        'reverse': False,
        'tool': config['tool'],
    }
    eval_set = Dataset(**dataset_kwargs)

    # The output layer sizes are taken from the dataset vocabularies
    config['num_classes'] = eval_set.num_classes
    config['num_classes_sub'] = eval_set.num_classes_sub

    # Build the network, then restore its saved parameters
    model = load(model_type=config['model_type'],
                 params=config,
                 backend=config['backend'])
    model.load_checkpoint(save_path=args.model_path, epoch=args.epoch)

    # GPU setting
    model.set_cuda(deterministic=False, benchmark=True)

    # Visualize (the a2c oracle is always disabled in this script)
    plot(model=model,
         dataset=eval_set,
         beam_width=args.beam_width,
         beam_width_sub=args.beam_width_sub,
         eval_batch_size=args.eval_batch_size,
         a2c_oracle=False,
         save_path=mkdir_join(args.model_path, 'att_weights'))
Beispiel #2
0
def main():
    """Train a model with a main and a sub label stream on WSJ.

    Two modes are supported, selected by the command-line arguments:

    * ``args.model_save_path`` is set: start a new run from the config file
      at ``args.config_path``.
    * otherwise ``args.saved_model_path`` is set: resume from the last
      checkpoint stored in that directory, using its ``config.yml``.

    Every ``params['print_step']`` steps the averaged training losses and the
    losses on one dev93 mini-batch are logged. At the end of every epoch the
    loss curves are plotted and, from ``params['eval_start_epoch']`` on, the
    model is evaluated on dev93; a checkpoint is saved and eval92 is scored
    only when the dev93 metric improves. Training stops on early stopping or
    when ``params['num_epoch']`` is reached, after which an empty ``COMPLETE``
    file is written into the model directory.

    Raises:
        ValueError: if neither ``model_save_path`` nor ``saved_model_path``
            is given.
    """

    args = parser.parse_args()

    ##################################################
    # DATASET
    ##################################################
    if args.model_save_path is not None:
        # Load a config file (.yml)
        params = load_config(args.config_path)
    # NOTE: Retrain the saved model from the last checkpoint
    elif args.saved_model_path is not None:
        params = load_config(os.path.join(args.saved_model_path, 'config.yml'))
    else:
        raise ValueError("Set model_save_path or saved_model_path.")

    # Load dataset
    train_data = Dataset(data_save_path=args.data_save_path,
                         backend=params['backend'],
                         input_freq=params['input_freq'],
                         use_delta=params['use_delta'],
                         use_double_delta=params['use_double_delta'],
                         data_type=params['data_size'],
                         data_size=params['data_size'],
                         label_type=params['label_type'],
                         label_type_sub=params['label_type_sub'],
                         batch_size=params['batch_size'],
                         max_epoch=params['num_epoch'],
                         splice=params['splice'],
                         num_stack=params['num_stack'],
                         num_skip=params['num_skip'],
                         min_frame_num=params['min_frame_num'],
                         sort_utt=True,
                         sort_stop_epoch=params['sort_stop_epoch'],
                         tool=params['tool'],
                         num_enque=None,
                         dynamic_batching=params['dynamic_batching'])
    dev93_data = Dataset(data_save_path=args.data_save_path,
                         backend=params['backend'],
                         input_freq=params['input_freq'],
                         use_delta=params['use_delta'],
                         use_double_delta=params['use_double_delta'],
                         data_type='test_dev93',
                         data_size=params['data_size'],
                         label_type=params['label_type'],
                         label_type_sub=params['label_type_sub'],
                         batch_size=params['batch_size'],
                         splice=params['splice'],
                         num_stack=params['num_stack'],
                         num_skip=params['num_skip'],
                         shuffle=True,
                         tool=params['tool'])
    eval92_data = Dataset(data_save_path=args.data_save_path,
                          backend=params['backend'],
                          input_freq=params['input_freq'],
                          use_delta=params['use_delta'],
                          use_double_delta=params['use_double_delta'],
                          data_type='test_eval92',
                          data_size=params['data_size'],
                          label_type=params['label_type'],
                          label_type_sub=params['label_type_sub'],
                          batch_size=params['batch_size'],
                          splice=params['splice'],
                          num_stack=params['num_stack'],
                          num_skip=params['num_skip'],
                          tool=params['tool'])

    # The output layer sizes are taken from the training vocabularies
    params['num_classes'] = train_data.num_classes
    params['num_classes_sub'] = train_data.num_classes_sub

    ##################################################
    # MODEL
    ##################################################
    # Model setting
    model = load(model_type=params['model_type'],
                 params=params,
                 backend=params['backend'])

    if args.model_save_path is not None:

        # Set save path
        save_path = mkdir_join(
            args.model_save_path, params['backend'], params['model_type'],
            params['label_type'] + '_' + params['label_type_sub'],
            params['data_size'], model.name)
        model.set_save_path(save_path)

        # Save config file
        save_config(config_path=args.config_path, save_path=model.save_path)

        # Setting for logging
        logger = set_logger(model.save_path)

        if os.path.isdir(params['char_init']):
            # NOTE: Start training from the pre-trained character model
            model.load_checkpoint(save_path=params['char_init'],
                                  epoch=-1,
                                  load_pretrained_model=True)

        # Count total parameters
        for name in sorted(list(model.num_params_dict.keys())):
            num_params = model.num_params_dict[name]
            logger.info("%s %d" % (name, num_params))
        logger.info("Total %.3f M parameters" %
                    (model.total_parameters / 1000000))

        # Define optimizer
        model.set_optimizer(optimizer=params['optimizer'],
                            learning_rate_init=float(params['learning_rate']),
                            weight_decay=float(params['weight_decay']),
                            clip_grad_norm=params['clip_grad_norm'],
                            lr_schedule=False,
                            factor=params['decay_rate'],
                            patience_epoch=params['decay_patient_epoch'])

        # A fresh run starts at epoch 1 with the worst possible metric (100%)
        epoch, step = 1, 0
        learning_rate = float(params['learning_rate'])
        metric_dev_best = 1

    # NOTE: Retrain the saved model from the last checkpoint
    elif args.saved_model_path is not None:

        # Set save path
        model.save_path = args.saved_model_path

        # Setting for logging
        logger = set_logger(model.save_path, restart=True)

        # Define optimizer
        model.set_optimizer(
            optimizer=params['optimizer'],
            learning_rate_init=float(params['learning_rate']),  # on-the-fly
            weight_decay=float(params['weight_decay']),
            clip_grad_norm=params['clip_grad_norm'],
            lr_schedule=False,
            factor=params['decay_rate'],
            patience_epoch=params['decay_patient_epoch'])

        # Restore the last saved model
        epoch, step, learning_rate, metric_dev_best = model.load_checkpoint(
            save_path=args.saved_model_path, epoch=-1, restart=True)

    else:
        raise ValueError("Set model_save_path or saved_model_path.")

    # Tell the data iterator how many epochs have already been completed
    train_data.epoch = epoch - 1

    # GPU setting
    model.set_cuda(deterministic=False, benchmark=True)

    logger.info('PID: %s' % os.getpid())
    # os.uname()[1] is the node (host) name of the machine, not a user name
    logger.info('HOSTNAME: %s' % os.uname()[1])

    # Set process name
    setproctitle('wsj_' + params['backend'] + '_' + params['model_type'] +
                 '_' + params['label_type'] + '_' + params['label_type_sub'] +
                 '_' + params['data_size'])

    ##################################################
    # TRAINING LOOP
    ##################################################
    # Define learning rate controller
    lr_controller = Controller(
        learning_rate_init=learning_rate,
        backend=params['backend'],
        decay_start_epoch=params['decay_start_epoch'],
        decay_rate=params['decay_rate'],
        decay_patient_epoch=params['decay_patient_epoch'],
        lower_better=True)

    # Setting for tensorboard
    if params['backend'] == 'pytorch':
        tf_writer = SummaryWriter(model.save_path)

    # Train model
    csv_steps, csv_loss_train, csv_loss_dev = [], [], []
    start_time_train = time.time()
    start_time_epoch = time.time()
    start_time_step = time.time()
    not_improved_epoch = 0
    loss_train_mean, loss_main_train_mean, loss_sub_train_mean = 0., 0., 0.
    pbar_epoch = tqdm(total=len(train_data))
    while True:
        # Compute loss in the training set (including parameter update)
        batch_train, is_new_epoch = train_data.next()
        model, loss_train_val, loss_main_train_val, loss_sub_train_val = train_hierarchical_step(
            model,
            batch_train,
            params['clip_grad_norm'],
            backend=params['backend'])
        loss_train_mean += loss_train_val
        loss_main_train_mean += loss_main_train_val
        loss_sub_train_mean += loss_sub_train_val

        pbar_epoch.update(len(batch_train['xs']))

        if (step + 1) % params['print_step'] == 0:

            # Compute loss in the dev set
            batch_dev = dev93_data.next()[0]
            loss_dev, loss_main_dev, loss_sub_dev = model(
                batch_dev['xs'],
                batch_dev['ys'],
                batch_dev['x_lens'],
                batch_dev['y_lens'],
                batch_dev['ys_sub'],
                batch_dev['y_lens_sub'],
                is_eval=True)

            # Turn the accumulated sums into means over the reporting window
            loss_train_mean /= params['print_step']
            loss_main_train_mean /= params['print_step']
            loss_sub_train_mean /= params['print_step']
            csv_steps.append(step)
            csv_loss_train.append(loss_train_mean)
            csv_loss_dev.append(loss_dev)

            # Logging by tensorboard
            if params['backend'] == 'pytorch':
                tf_writer.add_scalar('train/loss', loss_train_mean, step + 1)
                tf_writer.add_scalar('train/loss_main', loss_main_train_mean,
                                     step + 1)
                tf_writer.add_scalar('train/loss_sub', loss_sub_train_mean,
                                     step + 1)
                tf_writer.add_scalar('dev/loss', loss_dev, step + 1)
                tf_writer.add_scalar('dev/loss_main', loss_main_dev, step + 1)
                tf_writer.add_scalar('dev/loss_sub', loss_sub_dev, step + 1)
                # for name, param in model.named_parameters():
                #     name = name.replace('.', '/')
                #     tf_writer.add_histogram(
                #         name, param.data.cpu().numpy(), step + 1)
                #     if param.grad is not None:
                #         tf_writer.add_histogram(
                #             name + '/grad', param.grad.data.cpu().numpy(), step + 1)
                #     # TODO: fix this

            duration_step = time.time() - start_time_step
            logger.info(
                "...Step:%d(epoch:%.3f) loss:%.3f/%.3f/%.3f(%.3f/%.3f/%.3f)/lr:%.5f/batch:%d/x_lens:%d (%.3f min)"
                % (step + 1, train_data.epoch_detail, loss_train_mean,
                   loss_main_train_mean, loss_sub_train_mean, loss_dev,
                   loss_main_dev, loss_sub_dev, learning_rate,
                   train_data.current_batch_size, max(batch_train['x_lens']) *
                   params['num_stack'], duration_step / 60))
            start_time_step = time.time()
            loss_train_mean, loss_main_train_mean, loss_sub_train_mean = 0., 0., 0.
        step += 1

        # Save checkpoint and evaluate model per epoch
        if is_new_epoch:
            duration_epoch = time.time() - start_time_epoch
            logger.info('===== EPOCH:%d (%.3f min) =====' %
                        (epoch, duration_epoch / 60))

            # Save figure of loss
            plot_loss(csv_loss_train,
                      csv_loss_dev,
                      csv_steps,
                      save_path=model.save_path)

            if epoch < params['eval_start_epoch']:
                # Save the model
                model.save_checkpoint(model.save_path, epoch, step,
                                      learning_rate, metric_dev_best)
            else:
                start_time_eval = time.time()
                # dev
                if model.main_loss_weight > 0:
                    metric_dev, _ = eval_word(
                        models=[model],
                        dataset=dev93_data,
                        beam_width=1,
                        max_decode_len=MAX_DECODE_LEN_WORD,
                        eval_batch_size=1)
                    logger.info('  WER (dev93, main): %.3f %%' %
                                (metric_dev * 100))
                else:
                    # Without a main-task loss the sub-task CER is the metric
                    wer_dev_sub, metric_dev, _ = eval_char(
                        models=[model],
                        dataset=dev93_data,
                        beam_width=1,
                        max_decode_len=MAX_DECODE_LEN_CHAR,
                        eval_batch_size=1)
                    logger.info('  WER / CER (dev93, sub): %.3f / %.3f %%' %
                                ((wer_dev_sub * 100), (metric_dev * 100)))

                if metric_dev < metric_dev_best:
                    metric_dev_best = metric_dev
                    not_improved_epoch = 0
                    logger.info('||||| Best Score |||||')

                    # Save the model
                    model.save_checkpoint(model.save_path, epoch, step,
                                          learning_rate, metric_dev_best)

                    # test
                    if model.main_loss_weight > 0:
                        wer_eval92, _ = eval_word(
                            models=[model],
                            dataset=eval92_data,
                            beam_width=1,
                            max_decode_len=MAX_DECODE_LEN_WORD,
                            eval_batch_size=1)
                        logger.info('  WER (eval92, main): %.3f %%' %
                                    (wer_eval92 * 100))
                    else:
                        wer_eval92_sub, cer_eval92_sub, _ = eval_char(
                            models=[model],
                            dataset=eval92_data,
                            beam_width=1,
                            max_decode_len=MAX_DECODE_LEN_CHAR,
                            eval_batch_size=1)
                        logger.info(
                            ' WER / CER (eval92, sub): %.3f / %.3f %%' %
                            ((wer_eval92_sub * 100), (cer_eval92_sub * 100)))
                else:
                    not_improved_epoch += 1

                duration_eval = time.time() - start_time_eval
                logger.info('Evaluation time: %.3f min' % (duration_eval / 60))

                # Early stopping
                if not_improved_epoch == params['not_improved_patient_epoch']:
                    break

                # Update learning rate
                model.optimizer, learning_rate = lr_controller.decay_lr(
                    optimizer=model.optimizer,
                    learning_rate=learning_rate,
                    epoch=epoch,
                    value=metric_dev)

                if epoch == params['convert_to_sgd_epoch']:
                    # Convert to fine-tuning stage
                    model.set_optimizer(
                        'sgd',
                        learning_rate_init=learning_rate,
                        weight_decay=float(params['weight_decay']),
                        clip_grad_norm=params['clip_grad_norm'],
                        lr_schedule=False,
                        factor=params['decay_rate'],
                        patience_epoch=params['decay_patient_epoch'])
                    logger.info('========== Convert to SGD ==========')

                    # Inject Gaussian noise to all parameters
                    if float(params['weight_noise_std']) > 0:
                        model.weight_noise_injection = True

            # Finish the bar of the completed epoch before starting a new one
            pbar_epoch.close()
            pbar_epoch = tqdm(total=len(train_data))
            print('========== EPOCH:%d (%.3f min) ==========' %
                  (epoch, duration_epoch / 60))

            if epoch == params['num_epoch']:
                break

            start_time_step = time.time()
            start_time_epoch = time.time()
            epoch += 1

    # Release the progress bar that was still open when the loop ended
    pbar_epoch.close()

    duration_train = time.time() - start_time_train
    logger.info('Total time: %.3f hour' % (duration_train / 3600))

    if params['backend'] == 'pytorch':
        tf_writer.close()

    # TODO: evaluate the best model by beam search here
    # NOTE: no in-memory copy of the best model is kept; the best parameters
    # are the ones written by save_checkpoint() when the dev metric improved.

    # Training was finished correctly
    with open(os.path.join(model.save_path, 'COMPLETE'), 'w') as f:
        f.write('')
    def check(self,
              label_type,
              label_type_sub,
              data_type='test_dev93',
              data_size='train_si284',
              backend='pytorch',
              shuffle=False,
              sort_utt=True,
              sort_stop_epoch=None,
              frame_stacking=False,
              splice=1,
              num_gpus=1,
              data_save_path='/n/sd8/inaguma/corpus/wsj/kaldi'):
        """Load mini-batches from a ``Dataset`` and print their references.

        The settings are printed first, then a dataset is built and iterated
        until ``dataset.epoch_detail`` reaches 1. For every mini-batch the
        main and sub references of its first utterance and the lengths of
        that utterance are printed.

        Args:
            label_type: label type of the main task, passed to ``Dataset``.
            label_type_sub: label type of the sub task, passed to ``Dataset``.
            data_type: data split to load.
            data_size: training-set size identifier.
            backend: backend name passed to ``Dataset``.
            shuffle: passed to ``Dataset``.
            sort_utt: passed to ``Dataset``.
            sort_stop_epoch: passed to ``Dataset``.
            frame_stacking: if True, ``num_stack`` and ``num_skip`` are 3,
                otherwise both are 1.
            splice: passed to ``Dataset``.
            num_gpus: passed to ``Dataset``.
            data_save_path: root directory of the prepared corpus; defaults
                to the location that was previously hard-coded here.

        Raises:
            ValueError: for ``data_type == 'train'`` with the pytorch
                backend, if the second dimension of ``batch['xs']`` is
                smaller than that of ``batch['ys']``.
        """

        print('========================================')
        print('  backend: %s' % backend)
        print('  label_type: %s' % label_type)
        print('  label_type_sub: %s' % label_type_sub)
        print('  data_type: %s' % data_type)
        print('  data_size: %s' % data_size)
        print('  shuffle: %s' % str(shuffle))
        print('  sort_utt: %s' % str(sort_utt))
        print('  sort_stop_epoch: %s' % str(sort_stop_epoch))
        print('  frame_stacking: %s' % str(frame_stacking))
        print('  splice: %d' % splice)
        print('  num_gpus: %d' % num_gpus)
        print('========================================')

        num_stack = 3 if frame_stacking else 1
        num_skip = 3 if frame_stacking else 1
        dataset = Dataset(data_save_path=data_save_path,
                          backend=backend,
                          input_freq=81,
                          use_delta=True,
                          use_double_delta=True,
                          data_type=data_type,
                          data_size=data_size,
                          label_type=label_type,
                          label_type_sub=label_type_sub,
                          batch_size=64,
                          max_epoch=1,
                          splice=splice,
                          num_stack=num_stack,
                          num_skip=num_skip,
                          min_frame_num=40,
                          shuffle=shuffle,
                          sort_utt=sort_utt,
                          reverse=True,
                          sort_stop_epoch=sort_stop_epoch,
                          num_gpus=num_gpus,
                          tool='htk',
                          num_enque=None)

        print('=> Loading mini-batch...')

        for batch, is_new_epoch in dataset:
            if data_type == 'train' and backend == 'pytorch':
                # The comparison is on batch-level shapes, so one check per
                # (non-empty) batch is enough; no per-utterance loop is needed.
                if (len(batch['xs']) > 0 and
                        batch['xs'].shape[1] < batch['ys'].shape[1]):
                    raise ValueError(
                        'input length must be longer than label length.')

            if dataset.is_test:
                # NOTE(review): test-set references appear to be stored in a
                # nested form that is read directly -- confirm against Dataset
                str_ref = batch['ys'][0][0]
                str_ref_sub = batch['ys_sub'][0][0]
            else:
                # Convert the unpadded index sequences back to text
                str_ref = dataset.idx2word(batch['ys'][0][:batch['y_lens'][0]])
                str_ref_sub = dataset.idx2char(
                    batch['ys_sub'][0][:batch['y_lens_sub'][0]])

            print('----- %s (epoch: %.3f, batch: %d) -----' %
                  (batch['input_names'][0], dataset.epoch_detail,
                   len(batch['xs'])))
            print('=' * 20)
            print(str_ref)
            print('-' * 10)
            print(str_ref_sub)
            print('x_lens: %d' % (batch['x_lens'][0] * num_stack))
            if not dataset.is_test:
                print('y_lens (word): %d' % batch['y_lens'][0])
                print('y_lens_sub (char): %d' % batch['y_lens_sub'][0])

            # One pass over the data is enough
            if dataset.epoch_detail >= 1:
                break
Beispiel #4
0
def main():
    """Decode ``test_eval92`` and plot main/sub attention weights.

    The model described by ``config.yml`` under ``args.model_path`` is
    restored, every mini-batch of the dataset is decoded for the main (word)
    and sub (character) task, and one ``.png`` per utterance is written
    under ``<model_path>/att_weights``. An existing output directory is
    emptied first.
    """
    args = parser.parse_args()

    # Load a config file (.yml)
    config = load_config(join(args.model_path, 'config.yml'), is_eval=True)

    # Load dataset
    dataset_kwargs = {
        'data_save_path': args.data_save_path,
        'backend': config['backend'],
        'input_freq': config['input_freq'],
        'use_delta': config['use_delta'],
        'use_double_delta': config['use_double_delta'],
        'data_type': 'test_eval92',
        'data_size': config['data_size'],
        'label_type': config['label_type'],
        'label_type_sub': config['label_type_sub'],
        'batch_size': args.eval_batch_size,
        'splice': config['splice'],
        'num_stack': config['num_stack'],
        'num_skip': config['num_skip'],
        'sort_utt': False,
        'reverse': False,
        'tool': config['tool'],
    }
    dataset = Dataset(**dataset_kwargs)

    config['num_classes'] = dataset.num_classes
    config['num_classes_sub'] = dataset.num_classes_sub

    # Load model and restore the saved parameters
    model = load(model_type=config['model_type'],
                 params=config,
                 backend=config['backend'])
    model.load_checkpoint(save_path=args.model_path, epoch=args.epoch)

    # GPU setting
    model.set_cuda(deterministic=False, benchmark=True)

    out_dir = mkdir_join(args.model_path, 'att_weights')

    ######################################################################

    # Start from an empty output directory
    if out_dir is not None and isdir(out_dir):
        shutil.rmtree(out_dir)
        mkdir(out_dir)

    word2char = Word2char(dataset.vocab_file_path, dataset.vocab_file_path_sub)

    for batch, is_new_epoch in dataset:
        joint = (model.model_type == 'hierarchical_attention'
                 and args.joint_decoding is not None)

        if not joint:
            # Decode the main and the sub task independently
            best_hyps, aw, perm_idx = model.decode(
                batch['xs'],
                batch['x_lens'],
                beam_width=args.beam_width,
                max_decode_len=MAX_DECODE_LEN_WORD,
                min_decode_len=MIN_DECODE_LEN_WORD,
                length_penalty=args.length_penalty,
                coverage_penalty=args.coverage_penalty)
            best_hyps_sub, aw_sub, _ = model.decode(
                batch['xs'],
                batch['x_lens'],
                beam_width=args.beam_width_sub,
                max_decode_len=MAX_DECODE_LEN_CHAR,
                min_decode_len=MIN_DECODE_LEN_CHAR,
                length_penalty=args.length_penalty,
                coverage_penalty=args.coverage_penalty,
                task_index=1)
        else:
            # A single joint decoding pass returns both hypotheses
            best_hyps, aw, best_hyps_sub, aw_sub, _ = model.decode(
                batch['xs'],
                batch['x_lens'],
                beam_width=args.beam_width,
                max_decode_len=MAX_DECODE_LEN_WORD,
                min_decode_len=MIN_DECODE_LEN_WORD,
                length_penalty=args.length_penalty,
                coverage_penalty=args.coverage_penalty,
                joint_decoding=args.joint_decoding,
                space_index=dataset.char2idx('_')[0],
                oov_index=dataset.word2idx('OOV')[0],
                word2char=word2char,
                idx2word=dataset.idx2word,
                idx2char=dataset.idx2char,
                score_sub_weight=args.score_sub_weight)

        batch_len = len(batch['xs'])
        for b in range(batch_len):
            words = dataset.idx2word(best_hyps[b], return_list=True)
            chars = dataset.idx2char(best_hyps_sub[b], return_list=True)

            # Crop the attention matrices to the hypothesis and input lengths
            aw_main = aw[b][:len(words), :batch['x_lens'][b]]
            aw_char = aw_sub[b][:len(chars), :batch['x_lens'][b]]

            # Visualize
            plot_hierarchical_attention_weights(
                aw_main,
                aw_char,
                label_list=words,
                label_list_sub=chars,
                spectrogram=batch['xs'][b, :, :dataset.input_freq],
                save_path=mkdir_join(out_dir,
                                     batch['input_names'][b] + '.png'),
                figsize=(40, 8))

        if is_new_epoch:
            break
def main():
    """Score a restored model on ``test_eval92`` and print the results.

    Prints the decoding settings, then the value and result table returned
    by ``eval_word`` (reported as main-task WER) and the values and table
    returned by ``eval_char`` (reported as sub-task WER / CER).
    """
    args = parser.parse_args()

    # Load a config file (.yml)
    config = load_config(join(args.model_path, 'config.yml'), is_eval=True)

    # Load dataset
    dataset_kwargs = {
        'data_save_path': args.data_save_path,
        'backend': config['backend'],
        'input_freq': config['input_freq'],
        'use_delta': config['use_delta'],
        'use_double_delta': config['use_double_delta'],
        'data_type': 'test_eval92',
        'data_size': config['data_size'],
        'label_type': config['label_type'],
        'label_type_sub': config['label_type_sub'],
        'batch_size': args.eval_batch_size,
        'splice': config['splice'],
        'num_stack': config['num_stack'],
        'num_skip': config['num_skip'],
        'sort_utt': False,
        'tool': config['tool'],
    }
    eval_set = Dataset(**dataset_kwargs)

    config['num_classes'] = eval_set.num_classes
    config['num_classes_sub'] = eval_set.num_classes_sub

    # Load model and restore the saved parameters
    model = load(model_type=config['model_type'],
                 params=config,
                 backend=config['backend'])
    model.load_checkpoint(save_path=args.model_path, epoch=args.epoch)

    # GPU setting
    model.set_cuda(deterministic=False, benchmark=True)

    # Decoding options that are fixed in this script
    use_a2c_oracle = False
    resolve_unk = True

    print('beam width (main): %d' % args.beam_width)
    print('beam width (sub) : %d' % args.beam_width_sub)
    print('a2c oracle: %s' % str(use_a2c_oracle))
    print('resolving_unk: %s' % str(resolve_unk))

    # Main task
    word_result = eval_word(models=[model],
                            dataset=eval_set,
                            beam_width=args.beam_width,
                            beam_width_sub=args.beam_width_sub,
                            max_decode_len=MAX_DECODE_LEN_WORD,
                            max_decode_len_sub=MAX_DECODE_LEN_CHAR,
                            eval_batch_size=args.eval_batch_size,
                            progressbar=True,
                            resolving_unk=resolve_unk,
                            a2c_oracle=use_a2c_oracle)
    wer_main, df_main = word_result
    print('  WER (eval92, main): %.3f %%' % (wer_main * 100))
    print(df_main)

    # Sub task
    char_result = eval_char(
        models=[model],
        dataset=eval_set,
        beam_width=args.beam_width_sub,
        max_decode_len=MAX_DECODE_LEN_CHAR,
        eval_batch_size=args.eval_batch_size,
        progressbar=True)
    wer_sub, cer_sub, df_sub = char_result
    print(' WER / CER (eval92, sub): %.3f / %.3f %%' %
          ((wer_sub * 100), (cer_sub * 100)))
    print(df_sub)
Beispiel #6
0
def main():
    """Score a restored model on the test sets and log the main-task WER.

    The model is built and restored once, while the first dataset is being
    processed, and the decoding settings are logged at that point. Each
    dataset is then scored with ``eval_word`` and the result is written to
    the logger attached to ``args.model_path``.
    """
    args = parser.parse_args()

    # Load a config file (.yml)
    config = load_config(join(args.model_path, 'config.yml'), is_eval=True)

    # Setting for logging
    logger = set_logger(args.model_path)

    # Only eval92 is scored; 'test_dev93' can be listed first to score it too
    data_types = ['test_eval92']
    for idx, data_type in enumerate(data_types):
        # Load dataset
        dataset_kwargs = {
            'data_save_path': args.data_save_path,
            'backend': config['backend'],
            'input_freq': config['input_freq'],
            'use_delta': config['use_delta'],
            'use_double_delta': config['use_double_delta'],
            'data_type': data_type,
            'data_size': config['data_size'],
            'label_type': config['label_type'],
            'label_type_sub': config['label_type_sub'],
            'batch_size': args.eval_batch_size,
            'splice': config['splice'],
            'num_stack': config['num_stack'],
            'num_skip': config['num_skip'],
            'sort_utt': False,
            'tool': config['tool'],
        }
        dataset = Dataset(**dataset_kwargs)

        is_first_set = idx == 0
        if is_first_set:
            # The model needs the vocabulary sizes, so it is created only
            # once the first dataset is available
            config['num_classes'] = dataset.num_classes
            config['num_classes_sub'] = dataset.num_classes_sub

            # Load model
            model = load(model_type=config['model_type'],
                         params=config,
                         backend=config['backend'])

            # Restore the saved parameters
            restored = model.load_checkpoint(save_path=args.model_path,
                                             epoch=args.epoch)
            epoch, _, _, _ = restored

            # GPU setting
            model.set_cuda(deterministic=False, benchmark=True)

            logger.info('beam width (main): %d' % args.beam_width)
            logger.info('beam width (sub) : %d' % args.beam_width_sub)
            logger.info('epoch: %d' % (epoch - 1))
            logger.info('a2c oracle: %s' % str(args.a2c_oracle))
            logger.info('resolving_unk: %s' % str(args.resolving_unk))
            logger.info('joint_decoding: %s' % str(args.joint_decoding))
            logger.info('score_sub_weight : %f' % args.score_sub_weight)

        result = eval_word(models=[model],
                           dataset=dataset,
                           eval_batch_size=args.eval_batch_size,
                           beam_width=args.beam_width,
                           max_decode_len=MAX_DECODE_LEN_WORD,
                           min_decode_len=MIN_DECODE_LEN_WORD,
                           beam_width_sub=args.beam_width_sub,
                           max_decode_len_sub=MAX_DECODE_LEN_CHAR,
                           min_decode_len_sub=MIN_DECODE_LEN_CHAR,
                           length_penalty=args.length_penalty,
                           coverage_penalty=args.coverage_penalty,
                           progressbar=True,
                           resolving_unk=args.resolving_unk,
                           a2c_oracle=args.a2c_oracle,
                           joint_decoding=args.joint_decoding,
                           score_sub_weight=args.score_sub_weight)
        wer, df = result
        logger.info('  WER (%s, main): %.3f %%' % (dataset.data_type,
                                                   (wer * 100)))
        logger.info(df)
def main():
    """Decode the ``test_eval92`` data and print per-utterance results.

    Reads the command-line arguments from the module-level ``parser``,
    loads ``config.yml`` from ``args.model_path``, builds the evaluation
    dataset, restores the checkpoint selected by ``args.epoch`` and moves
    the model to the GPU.  For every utterance it then prints the
    reference, the main-task (word) and sub-task (character) hypotheses,
    the hypothesis with OOV words resolved (when ``args.resolving_unk`` is
    set and the hypothesis contains ``OOV``) and the corresponding WERs.
    For a ``hierarchical_attention`` model with ``args.joint_decoding``
    given, the jointly decoded hypotheses and their WERs are printed too.

    Decoding stops after one pass over the dataset.  Nothing is returned;
    all results go to stdout.
    """
    args = parser.parse_args()

    # Load a config file (.yml)
    params = load_config(join(args.model_path, 'config.yml'), is_eval=True)

    # Load dataset
    dataset = Dataset(
        data_save_path=args.data_save_path,
        backend=params['backend'],
        input_freq=params['input_freq'],
        use_delta=params['use_delta'],
        use_double_delta=params['use_double_delta'],
        # data_type='test_dev93',
        data_type='test_eval92',
        data_size=params['data_size'],
        label_type=params['label_type'],
        label_type_sub=params['label_type_sub'],
        batch_size=args.eval_batch_size,
        splice=params['splice'],
        num_stack=params['num_stack'],
        num_skip=params['num_skip'],
        sort_utt=False,
        reverse=False,
        tool=params['tool'])

    # The output layer sizes are taken from the dataset vocabulary
    params['num_classes'] = dataset.num_classes
    params['num_classes_sub'] = dataset.num_classes_sub

    # Load model
    model = load(model_type=params['model_type'],
                 params=params,
                 backend=params['backend'])

    # Restore the saved parameters
    model.load_checkpoint(save_path=args.model_path, epoch=args.epoch)

    # GPU setting
    model.set_cuda(deterministic=False, benchmark=True)

    # sys.stdout = open(join(model.model_dir, 'decode.txt'), 'w')

    ######################################################################

    word2char = Word2char(dataset.vocab_file_path, dataset.vocab_file_path_sub)

    # Loop invariants: whether the extra joint-decoding pass is run, and the
    # patterns that cut a hypothesis at its last end-of-sentence mark ('>')
    use_joint = (model.model_type == 'hierarchical_attention' and
                 args.joint_decoding is not None)
    eos_word_pattern = re.compile(r'(.*)_>(.*)')
    eos_char_pattern = re.compile(r'(.*)>(.*)')

    for batch, is_new_epoch in dataset:
        # Decode
        if model.model_type == 'nested_attention':
            # A single call yields both the main and the sub hypotheses
            best_hyps, aw, best_hyps_sub, aw_sub, _, perm_idx = model.decode(
                batch['xs'],
                batch['x_lens'],
                beam_width=args.beam_width,
                max_decode_len=MAX_DECODE_LEN_WORD,
                min_decode_len=MIN_DECODE_LEN_WORD,
                beam_width_sub=args.beam_width_sub,
                max_decode_len_sub=MAX_DECODE_LEN_CHAR,
                min_decode_len_sub=MIN_DECODE_LEN_CHAR,
                length_penalty=args.length_penalty,
                coverage_penalty=args.coverage_penalty)
        else:
            # Main task
            best_hyps, aw, perm_idx = model.decode(
                batch['xs'],
                batch['x_lens'],
                beam_width=args.beam_width,
                max_decode_len=MAX_DECODE_LEN_WORD,
                min_decode_len=MIN_DECODE_LEN_WORD,
                length_penalty=args.length_penalty,
                coverage_penalty=args.coverage_penalty)
            # Sub task (selected with task_index=1)
            best_hyps_sub, aw_sub, _ = model.decode(
                batch['xs'],
                batch['x_lens'],
                beam_width=args.beam_width_sub,
                max_decode_len=MAX_DECODE_LEN_CHAR,
                min_decode_len=MIN_DECODE_LEN_CHAR,
                length_penalty=args.length_penalty,
                coverage_penalty=args.coverage_penalty,
                task_index=1)

        if use_joint:
            best_hyps_joint, aw_joint, best_hyps_sub_joint, aw_sub_joint, _ = model.decode(
                batch['xs'],
                batch['x_lens'],
                beam_width=args.beam_width,
                max_decode_len=MAX_DECODE_LEN_WORD,
                min_decode_len=MIN_DECODE_LEN_WORD,
                length_penalty=args.length_penalty,
                coverage_penalty=args.coverage_penalty,
                joint_decoding=args.joint_decoding,
                space_index=dataset.char2idx('_')[0],
                oov_index=dataset.word2idx('OOV')[0],
                word2char=word2char,
                idx2word=dataset.idx2word,
                idx2char=dataset.idx2char,
                score_sub_weight=args.score_sub_weight)

        # Reorder the labels with the permutation returned by decode
        # (presumably the order in which decode arranged the batch -- TODO
        # confirm against the model implementation)
        ys = batch['ys'][perm_idx]
        y_lens = batch['y_lens'][perm_idx]
        ys_sub = batch['ys_sub'][perm_idx]
        y_lens_sub = batch['y_lens_sub'][perm_idx]

        for b in range(len(batch['xs'])):
            ##############################
            # Reference
            ##############################
            if dataset.is_test:
                str_ref = ys[b][0]
                str_ref_sub = ys_sub[b][0]
                # NOTE: transcript is separated by space('_')
            else:
                # Convert from list of index to string
                str_ref = dataset.idx2word(ys[b][:y_lens[b]])
                str_ref_sub = dataset.idx2char(ys_sub[b][:y_lens_sub[b]])

            ##############################
            # Hypothesis
            ##############################
            # Convert from list of index to string
            str_hyp = dataset.idx2word(best_hyps[b])
            str_hyp_sub = dataset.idx2char(best_hyps_sub[b])
            if use_joint:
                str_hyp_joint = dataset.idx2word(best_hyps_joint[b])
                str_hyp_sub_joint = dataset.idx2char(best_hyps_sub_joint[b])

            ##############################
            # Resolving UNK
            ##############################
            # Only attempted when requested and the hypothesis has an OOV
            has_unk = 'OOV' in str_hyp and args.resolving_unk
            if has_unk:
                str_hyp_no_unk = resolve_unk(str_hyp, best_hyps_sub[b], aw[b],
                                             aw_sub[b], dataset.idx2char)
            if use_joint:
                has_unk_joint = 'OOV' in str_hyp_joint and args.resolving_unk
                if has_unk_joint:
                    str_hyp_no_unk_joint = resolve_unk(str_hyp_joint,
                                                       best_hyps_sub_joint[b],
                                                       aw_joint[b],
                                                       aw_sub_joint[b],
                                                       dataset.idx2char)

            # NOTE(review): input_names is indexed without perm_idx, unlike
            # the labels above -- confirm this is intended.
            print('----- wav: %s -----' % batch['input_names'][b])
            print('Ref         : %s' % str_ref.replace('_', ' '))
            print('Hyp (main)  : %s' % str_hyp.replace('_', ' '))
            print('Hyp (sub)   : %s' % str_hyp_sub.replace('_', ' '))
            if has_unk:
                print('Hyp (no UNK): %s' % str_hyp_no_unk.replace('_', ' '))
            if use_joint:
                print('===== joint decoding =====')
                print('Hyp (main)  : %s' % str_hyp_joint.replace('_', ' '))
                print('Hyp (sub)   : %s' % str_hyp_sub_joint.replace('_', ' '))
                if has_unk_joint:
                    print('Hyp (no UNK): %s' %
                          str_hyp_no_unk_joint.replace('_', ' '))

            try:
                wer, _, _, _ = compute_wer(
                    ref=str_ref.split('_'),
                    hyp=eos_word_pattern.sub(r'\1', str_hyp).split('_'),
                    normalize=True)
                print('WER (main)  : %.3f %%' % (wer * 100))
                wer_sub, _, _, _ = compute_wer(
                    ref=str_ref_sub.split('_'),
                    hyp=eos_char_pattern.sub(r'\1', str_hyp_sub).split('_'),
                    normalize=True)
                print('WER (sub)   : %.3f %%' % (wer_sub * 100))
                if has_unk:
                    # '*' characters are stripped from the resolved
                    # hypothesis before scoring
                    wer_no_unk, _, _, _ = compute_wer(
                        ref=str_ref.split('_'),
                        hyp=eos_word_pattern.sub(
                            r'\1',
                            str_hyp_no_unk.replace('*', '')).split('_'),
                        normalize=True)
                    print('WER (no UNK): %.3f %%' % (wer_no_unk * 100))

                if use_joint:
                    print('===== joint decoding =====')
                    wer_joint, _, _, _ = compute_wer(
                        ref=str_ref.split('_'),
                        hyp=eos_word_pattern.sub(
                            r'\1', str_hyp_joint).split('_'),
                        normalize=True)
                    print('WER (main)  : %.3f %%' % (wer_joint * 100))
                    if has_unk_joint:
                        wer_no_unk_joint, _, _, _ = compute_wer(
                            ref=str_ref.split('_'),
                            hyp=eos_word_pattern.sub(
                                r'\1',
                                str_hyp_no_unk_joint.replace(
                                    '*', '')).split('_'),
                            normalize=True)
                        print('WER (no UNK): %.3f %%' %
                              (wer_no_unk_joint * 100))
            except Exception:
                # Scoring is best-effort: skip an utterance whose WER cannot
                # be computed.  Exception (not a bare except) is caught so
                # that KeyboardInterrupt / SystemExit still stop the run.
                print('--- skipped ---')

        # One pass over the evaluation data is enough
        if is_new_epoch:
            break