Beispiel #1
0
def main():
    """Evaluate a saved model on the dev and test sets and log the PER.

    The model directory given by ``args.model_path`` must contain the
    ``config.yml`` that was saved at training time. The model itself is built
    and restored only once, right after the first dataset has been created,
    because the number of output classes is read from the dataset.
    """
    args = parser.parse_args()

    # Evaluation settings are taken from the config saved next to the model
    config_path = join(args.model_path, 'config.yml')
    params = load_config(config_path, is_eval=True)

    logger = set_logger(args.model_path)

    model = None
    for data_type in ('dev', 'test'):
        dataset = Dataset(
            data_save_path=args.data_save_path,
            backend=params['backend'],
            input_freq=params['input_freq'],
            use_delta=params['use_delta'],
            use_double_delta=params['use_double_delta'],
            data_type=data_type,
            label_type=params['label_type'],
            batch_size=args.eval_batch_size,
            splice=params['splice'],
            num_stack=params['num_stack'],
            num_skip=params['num_skip'],
            sort_utt=False,
            tool=params['tool'])

        if model is None:
            # First pass only: build the model, restore it and move it to GPU
            params['num_classes'] = dataset.num_classes
            model = load(
                model_type=params['model_type'],
                params=params,
                backend=params['backend'])

            checkpoint = model.load_checkpoint(
                save_path=args.model_path, epoch=args.epoch)
            epoch, _, _, _ = checkpoint

            model.set_cuda(deterministic=False, benchmark=True)

            logger.info('beam width: %d' % args.beam_width)
            logger.info('epoch: %d' % (epoch - 1))

        per, df = eval_phone(
            model=model,
            dataset=dataset,
            map_file_path='./conf/phones.60-48-39.map',
            eval_batch_size=args.eval_batch_size,
            beam_width=args.beam_width,
            max_decode_len=MAX_DECODE_LEN_PHONE,
            min_decode_len=MIN_DECODE_LEN_PHONE,
            length_penalty=args.length_penalty,
            coverage_penalty=args.coverage_penalty,
            progressbar=True)
        logger.info('  PER (%s): %.3f %%' % (data_type, (per * 100)))
        logger.info(df)
Beispiel #2
0
def main():
    """Restore a saved model and run ``decode`` over the test set.

    Everything except the batch size and the decoding options (which come
    from the command line) is read from the ``config.yml`` stored in
    ``args.model_path``.
    """
    args = parser.parse_args()

    config_path = join(args.model_path, 'config.yml')
    params = load_config(config_path, is_eval=True)

    # Test set, sorted by utterance with reverse=True
    test_data = Dataset(
        data_save_path=args.data_save_path,
        backend=params['backend'],
        input_freq=params['input_freq'],
        use_delta=params['use_delta'],
        use_double_delta=params['use_double_delta'],
        data_type='test',
        label_type=params['label_type'],
        batch_size=args.eval_batch_size,
        splice=params['splice'],
        num_stack=params['num_stack'],
        num_skip=params['num_skip'],
        sort_utt=True,
        reverse=True,
        tool=params['tool'])

    # The output layer size is defined by the dataset vocabulary
    params['num_classes'] = test_data.num_classes

    model = load(
        model_type=params['model_type'],
        params=params,
        backend=params['backend'])
    model.load_checkpoint(save_path=args.model_path, epoch=args.epoch)
    model.set_cuda(deterministic=False, benchmark=True)

    # save_path=None is passed through to decode() unchanged
    decode(
        model=model,
        dataset=test_data,
        eval_batch_size=args.eval_batch_size,
        beam_width=args.beam_width,
        length_penalty=args.length_penalty,
        save_path=None)
def main():
    """Plot the CTC posteriors of every test utterance.

    One ``<input_name>.png`` per utterance is written into the ``ctc_probs``
    directory below ``args.model_path``. An already existing output directory
    is removed and re-created first.
    """
    args = parser.parse_args()

    config_path = join(args.model_path, 'config.yml')
    params = load_config(config_path, is_eval=True)

    dataset = Dataset(
        data_save_path=args.data_save_path,
        backend=params['backend'],
        input_freq=params['input_freq'],
        use_delta=params['use_delta'],
        use_double_delta=params['use_double_delta'],
        data_type='test',
        label_type=params['label_type'],
        batch_size=args.eval_batch_size,
        splice=params['splice'],
        num_stack=params['num_stack'],
        num_skip=params['num_skip'],
        sort_utt=True,
        reverse=True,
        tool=params['tool'])

    params['num_classes'] = dataset.num_classes

    model = load(
        model_type=params['model_type'],
        params=params,
        backend=params['backend'])
    model.load_checkpoint(save_path=args.model_path, epoch=args.epoch)
    model.set_cuda(deterministic=False, benchmark=True)

    save_path = mkdir_join(args.model_path, 'ctc_probs')

    # Start from an empty output directory
    if save_path is not None and isdir(save_path):
        shutil.rmtree(save_path)
        mkdir(save_path)

    for batch, is_new_epoch in dataset:
        # NOTE: probs: '[B, T, num_classes]'
        probs, x_lens, _ = model.posteriors(
            batch['xs'], batch['x_lens'], temperature=1)

        for b in range(len(batch['xs'])):
            num_frames = x_lens[b]
            figure_path = join(save_path, batch['input_names'][b] + '.png')
            # NOTE(review): the spectrogram slice is hard-coded to the first
            # 40 feature dimensions -- confirm it matches params['input_freq']
            plot_ctc_probs(
                probs[b, :num_frames, :],
                frame_num=num_frames,
                num_stack=dataset.num_stack,
                spectrogram=batch['xs'][b, :, :40],
                save_path=figure_path,
                figsize=(14, 7))

        # A single pass over the test set is enough
        if is_new_epoch:
            break
Beispiel #4
0
def main():
    """Evaluate a saved model on dev and test and write a RESULTS file.

    The PER of both sets is printed together with the data frame returned by
    ``eval_phone`` and is additionally written to ``RESULTS`` inside
    ``args.model_path`` (one entry per line).
    """
    args = parser.parse_args()

    # Load a config file (.yml)
    params = load_config(join(args.model_path, 'config.yml'), is_eval=True)

    def build_dataset(data_type):
        # dev and test only differ in ``data_type``
        return Dataset(data_save_path=args.data_save_path,
                       backend=params['backend'],
                       input_freq=params['input_freq'],
                       use_delta=params['use_delta'],
                       use_double_delta=params['use_double_delta'],
                       data_type=data_type,
                       label_type=params['label_type'],
                       batch_size=args.eval_batch_size,
                       splice=params['splice'],
                       num_stack=params['num_stack'],
                       num_skip=params['num_skip'],
                       sort_utt=False,
                       tool=params['tool'])

    # Load dataset
    dev_data = build_dataset('dev')
    test_data = build_dataset('test')

    params['num_classes'] = test_data.num_classes

    # Load model
    model = load(model_type=params['model_type'],
                 params=params,
                 backend=params['backend'])

    # Restore the saved parameters
    model.load_checkpoint(save_path=args.model_path, epoch=args.epoch)

    # GPU setting
    model.set_cuda(deterministic=False, benchmark=True)

    print('beam width: %d' % args.beam_width)

    # Evaluate dev first, then test; keep the summary lines for RESULTS
    result_lines = []
    for data_type, dataset in (('dev', dev_data), ('test', test_data)):
        per, df = eval_phone(model=model,
                             dataset=dataset,
                             map_file_path='./conf/phones.60-48-39.map',
                             eval_batch_size=args.eval_batch_size,
                             beam_width=args.beam_width,
                             max_decode_len=MAX_DECODE_LEN_PHONE,
                             length_penalty=args.length_penalty,
                             progressbar=True)
        summary = '  PER (%s): %.3f %%' % (data_type, per * 100)
        print(summary)
        print(df)
        result_lines.append(summary)

    with open(join(args.model_path, 'RESULTS'), 'w') as f:
        f.write('beam width: %d\n' % args.beam_width)
        # Each result gets its own line; previously the dev and test entries
        # were written without a newline and ran together
        for summary in result_lines:
            f.write(summary + '\n')
Beispiel #5
0
def main():
    """Decode the test set and print reference, hypothesis and PER.

    For every utterance the reference transcript, the decoded hypothesis and
    the resulting PER are printed. When the model is an attention model with
    a positive CTC loss weight, the CTC hypothesis and its PER are printed as
    well.
    """
    args = parser.parse_args()

    config_path = join(args.model_path, 'config.yml')
    params = load_config(config_path, is_eval=True)

    dataset = Dataset(
        data_save_path=args.data_save_path,
        backend=params['backend'],
        input_freq=params['input_freq'],
        use_delta=params['use_delta'],
        use_double_delta=params['use_double_delta'],
        data_type='test',
        label_type=params['label_type'],
        batch_size=args.eval_batch_size,
        splice=params['splice'],
        num_stack=params['num_stack'],
        num_skip=params['num_skip'],
        sort_utt=True,
        reverse=True,
        tool=params['tool'])

    params['num_classes'] = dataset.num_classes

    model = load(
        model_type=params['model_type'],
        params=params,
        backend=params['backend'])
    model.load_checkpoint(save_path=args.model_path, epoch=args.epoch)
    model.set_cuda(deterministic=False, benchmark=True)

    for batch, is_new_epoch in dataset:
        best_hyps, _, perm_idx = model.decode(
            batch['xs'],
            batch['x_lens'],
            beam_width=args.beam_width,
            max_decode_len=MAX_DECODE_LEN_PHONE,
            min_decode_len=MIN_DECODE_LEN_PHONE,
            length_penalty=args.length_penalty,
            coverage_penalty=args.coverage_penalty)

        # Jointly trained attention models also get a CTC decoding pass
        with_ctc = (model.model_type == 'attention'
                    and model.ctc_loss_weight > 0)
        if with_ctc:
            best_hyps_ctc, perm_idx = model.decode_ctc(
                batch['xs'],
                batch['x_lens'],
                beam_width=args.beam_width)

        # Bring the references into the same order as the hypotheses
        ys = batch['ys'][perm_idx]
        y_lens = batch['y_lens'][perm_idx]

        for b in range(len(batch['xs'])):
            # Reference
            if dataset.is_test:
                # NOTE: transcript is separated by space(' ')
                str_ref = ys[b][0]
            else:
                str_ref = dataset.idx2phone(ys[b][: y_lens[b]])

            # Hypothesis (list of indices -> string)
            str_hyp = dataset.idx2phone(best_hyps[b])

            print('----- wav: %s -----' % batch['input_names'][b])
            print('Ref      : %s' % str_ref)
            print('Hyp      : %s' % str_hyp)
            if with_ctc:
                str_hyp_ctc = dataset.idx2phone(best_hyps_ctc[b])
                print('Hyp (CTC): %s' % str_hyp_ctc)

            # The regex keeps only the text before the last ' >' in the
            # hypothesis before it is scored
            trimmed_hyp = re.sub(r'(.*) >(.*)', r'\1', str_hyp)
            per, _, _, _ = compute_wer(
                ref=str_ref.split(' '),
                hyp=trimmed_hyp.split(' '),
                normalize=True)
            print('PER: %.3f %%' % (per * 100))

            if with_ctc:
                per_ctc, _, _, _ = compute_wer(
                    ref=str_ref.split(' '),
                    hyp=str_hyp_ctc.split(' '),
                    normalize=True)
                print('PER (CTC): %.3f %%' % (per_ctc * 100))

        # A single pass over the test set is enough
        if is_new_epoch:
            break
Beispiel #6
0
def main():
    """Train a model on TIMIT, or resume training from the last checkpoint.

    Exactly one of two modes is selected from the command line:

    * ``args.model_save_path`` is set: a new model is built from
      ``args.config_path`` and saved below ``args.model_save_path``.
    * otherwise ``args.saved_model_path`` is set: the model stored there is
      restored from its last checkpoint and training continues.

    During training the loss is logged every ``params['print_step']`` steps.
    After each epoch the model is either just saved (before
    ``params['eval_start_epoch']``) or evaluated on the dev set; a new best
    dev PER triggers a checkpoint and a test evaluation. Training stops on
    early stopping or after ``params['num_epoch']`` epochs, followed by a
    final beam-search evaluation on the test set.

    Raises:
        ValueError: if neither ``model_save_path`` nor ``saved_model_path``
            is given.
    """
    args = parser.parse_args()

    ##################################################
    # DATASET
    ##################################################
    if args.model_save_path is not None:
        # Load a config file (.yml)
        params = load_config(args.config_path)
    # NOTE: Retrain the saved model from the last checkpoint
    elif args.saved_model_path is not None:
        params = load_config(os.path.join(args.saved_model_path, 'config.yml'))
    else:
        raise ValueError("Set model_save_path or saved_model_path.")

    # Load dataset
    train_data = Dataset(data_save_path=args.data_save_path,
                         backend=params['backend'],
                         input_freq=params['input_freq'],
                         use_delta=params['use_delta'],
                         use_double_delta=params['use_double_delta'],
                         data_type='train',
                         label_type=params['label_type'],
                         batch_size=params['batch_size'],
                         max_epoch=params['num_epoch'],
                         splice=params['splice'],
                         num_stack=params['num_stack'],
                         num_skip=params['num_skip'],
                         sort_utt=True,
                         sort_stop_epoch=params['sort_stop_epoch'],
                         tool=params['tool'],
                         num_enque=None,
                         dynamic_batching=params['dynamic_batching'])
    dev_data = Dataset(data_save_path=args.data_save_path,
                       backend=params['backend'],
                       input_freq=params['input_freq'],
                       use_delta=params['use_delta'],
                       use_double_delta=params['use_double_delta'],
                       data_type='dev',
                       label_type=params['label_type'],
                       batch_size=params['batch_size'],
                       splice=params['splice'],
                       num_stack=params['num_stack'],
                       num_skip=params['num_skip'],
                       shuffle=True,
                       tool=params['tool'])
    test_data = Dataset(data_save_path=args.data_save_path,
                        backend=params['backend'],
                        input_freq=params['input_freq'],
                        use_delta=params['use_delta'],
                        use_double_delta=params['use_double_delta'],
                        data_type='test',
                        label_type=params['label_type'],
                        batch_size=1,
                        splice=params['splice'],
                        num_stack=params['num_stack'],
                        num_skip=params['num_skip'],
                        tool=params['tool'])

    params['num_classes'] = train_data.num_classes

    ##################################################
    # MODEL
    ##################################################
    if args.model_save_path is not None:
        # Model setting
        model = load(model_type=params['model_type'],
                     params=params,
                     backend=params['backend'])

        # Set save path
        save_path = mkdir_join(args.model_save_path, params['backend'],
                               params['model_type'], params['label_type'],
                               model.name)
        model.set_save_path(save_path)

        # Save config file
        save_config(config_path=args.config_path, save_path=model.save_path)

        # Setting for logging
        logger = set_logger(model.save_path)

        if os.path.isdir(params['char_init']):
            # NOTE: Start training from the pre-trained character model
            model.load_checkpoint(save_path=params['char_init'],
                                  epoch=-1,
                                  load_pretrained_model=True)

        # Count total parameters
        for name in sorted(list(model.num_params_dict.keys())):
            num_params = model.num_params_dict[name]
            logger.info("%s %d" % (name, num_params))
        logger.info("Total %.3f M parameters" %
                    (model.total_parameters / 1000000))

        # Define optimizer
        model.set_optimizer(optimizer=params['optimizer'],
                            learning_rate_init=float(params['learning_rate']),
                            weight_decay=float(params['weight_decay']),
                            clip_grad_norm=params['clip_grad_norm'],
                            lr_schedule=False,
                            factor=params['decay_rate'],
                            patience_epoch=params['decay_patient_epoch'])

        epoch, step = 1, 0
        learning_rate = float(params['learning_rate'])
        # Initial "best" dev metric; only lower values count as improvement
        metric_dev_best = 1

    # NOTE: Retrain the saved model from the last checkpoint
    elif args.saved_model_path is not None:
        # Load model
        model = load(model_type=params['model_type'],
                     params=params,
                     backend=params['backend'])

        # Set save path
        model.save_path = args.saved_model_path

        # Setting for logging
        logger = set_logger(model.save_path)

        # Define optimizer
        model.set_optimizer(
            optimizer=params['optimizer'],
            learning_rate_init=float(params['learning_rate']),  # on-the-fly
            weight_decay=float(params['weight_decay']),
            clip_grad_norm=params['clip_grad_norm'],
            lr_schedule=False,
            factor=params['decay_rate'],
            patience_epoch=params['decay_patient_epoch'])

        # Restore the last saved model
        epoch, step, learning_rate, metric_dev_best = model.load_checkpoint(
            save_path=args.saved_model_path, epoch=-1, restart=True)

    else:
        raise ValueError("Set model_save_path or saved_model_path.")

    train_data.epoch = epoch - 1

    # GPU setting
    model.set_cuda(deterministic=False, benchmark=True)

    logger.info('PID: %s' % os.getpid())
    # NOTE: os.uname()[1] is the node (host) name, although the label says
    # USERNAME; the label is kept so existing log parsing keeps working
    logger.info('USERNAME: %s' % os.uname()[1])

    # Set process name
    setproctitle('timit_' + params['backend'] + '_' + params['model_type'] +
                 '_' + params['label_type'])

    ##################################################
    # TRAINING LOOP
    ##################################################
    # Define learning rate controller
    lr_controller = Controller(
        learning_rate_init=params['learning_rate'],
        backend=params['backend'],
        decay_type=params['decay_type'],
        decay_start_epoch=params['decay_start_epoch'],
        decay_rate=params['decay_rate'],
        decay_patient_epoch=params['decay_patient_epoch'],
        lower_better=True)

    # Setting for tensorboard
    if params['backend'] == 'pytorch':
        tf_writer = SummaryWriter(model.save_path)

    # Train model
    csv_steps, csv_loss_train, csv_loss_dev = [], [], []
    start_time_train = time.time()
    start_time_epoch = time.time()
    start_time_step = time.time()
    not_improved_epoch = 0
    loss_train_mean = 0.
    pbar_epoch = tqdm(total=len(train_data))
    best_model = None
    while True:
        # Compute loss in the training set (including parameter update)
        batch_train, is_new_epoch = train_data.next()
        model, loss_train_val = train_step(model, batch_train,
                                           params['clip_grad_norm'],
                                           params['backend'])
        loss_train_mean += loss_train_val

        pbar_epoch.update(len(batch_train['xs']))

        if (step + 1) % params['print_step'] == 0:

            # Compute loss in the dev set
            batch_dev = dev_data.next()[0]
            loss_dev = model(batch_dev['xs'],
                             batch_dev['ys'],
                             batch_dev['x_lens'],
                             batch_dev['y_lens'],
                             is_eval=True)

            # Average of the training loss since the last report
            loss_train_mean /= params['print_step']
            csv_steps.append(step)
            csv_loss_train.append(loss_train_mean)
            csv_loss_dev.append(loss_dev)

            # Logging by tensorboard
            if params['backend'] == 'pytorch':
                tf_writer.add_scalar('train/loss', loss_train_mean, step + 1)
                tf_writer.add_scalar('dev/loss', loss_dev, step + 1)
                for name, param in model.named_parameters():
                    name = name.replace('.', '/')
                    tf_writer.add_histogram(name,
                                            param.data.cpu().numpy(), step + 1)
                    # Parameters that received no gradient have grad=None;
                    # skip them instead of failing on ``None.data``
                    if param.grad is not None:
                        tf_writer.add_histogram(name + '/grad',
                                                param.grad.data.cpu().numpy(),
                                                step + 1)

            duration_step = time.time() - start_time_step
            logger.info(
                "...Step:%d(epoch:%.3f) loss:%.3f(%.3f)/lr:%.5f/batch:%d/x_lens:%d (%.3f min)"
                % (step + 1, train_data.epoch_detail, loss_train_mean,
                   loss_dev, learning_rate, train_data.current_batch_size,
                   max(batch_train['x_lens']) * params['num_stack'],
                   duration_step / 60))
            start_time_step = time.time()
            loss_train_mean = 0.
        step += 1

        # Save checkpoint and evaluate model per epoch
        if is_new_epoch:
            duration_epoch = time.time() - start_time_epoch
            logger.info('===== EPOCH:%d (%.3f min) =====' %
                        (epoch, duration_epoch / 60))

            # Save figure of loss
            plot_loss(csv_loss_train,
                      csv_loss_dev,
                      csv_steps,
                      save_path=model.save_path)

            if epoch < params['eval_start_epoch']:
                # Save the model
                model.save_checkpoint(model.save_path, epoch, step,
                                      learning_rate, metric_dev_best)
            else:
                start_time_eval = time.time()
                # dev
                per_dev_epoch, _ = eval_phone(
                    model=model,
                    dataset=dev_data,
                    map_file_path='./conf/phones.60-48-39.map',
                    eval_batch_size=1,
                    beam_width=1,
                    max_decode_len=MAX_DECODE_LEN_PHONE)
                logger.info('  PER (dev): %.3f %%' % (per_dev_epoch * 100))

                if per_dev_epoch < metric_dev_best:
                    metric_dev_best = per_dev_epoch
                    not_improved_epoch = 0
                    # Snapshot the weights; ``model`` keeps being updated
                    best_model = copy.deepcopy(model)
                    logger.info('||||| Best Score (PER) |||||')

                    # Save the model
                    model.save_checkpoint(model.save_path, epoch, step,
                                          learning_rate, metric_dev_best)

                    # test
                    per_test, _ = eval_phone(
                        model=model,
                        dataset=test_data,
                        map_file_path='./conf/phones.60-48-39.map',
                        eval_batch_size=1,
                        beam_width=1,
                        max_decode_len=MAX_DECODE_LEN_PHONE)
                    logger.info('  PER (test): %.3f %%' % (per_test * 100))
                else:
                    not_improved_epoch += 1

                duration_eval = time.time() - start_time_eval
                logger.info('Evaluation time: %.3f min' % (duration_eval / 60))

                # Early stopping
                if not_improved_epoch == params['not_improved_patient_epoch']:
                    break

                # Update learning rate
                model.optimizer, learning_rate = lr_controller.decay_lr(
                    optimizer=model.optimizer,
                    learning_rate=learning_rate,
                    epoch=epoch,
                    value=per_dev_epoch)

                if epoch == params['convert_to_sgd_epoch']:
                    # Convert to fine-tuning stage
                    model.set_optimizer(
                        'sgd',
                        learning_rate_init=learning_rate,
                        weight_decay=float(params['weight_decay']),
                        clip_grad_norm=params['clip_grad_norm'],
                        lr_schedule=False,
                        factor=params['decay_rate'],
                        patience_epoch=params['decay_patient_epoch'])
                    logger.info('========== Convert to SGD ==========')

                    # Inject Gaussian noise to all parameters
                    if float(params['weight_noise_std']) > 0:
                        model.weight_noise_injection = True

            # Close the finished bar before starting the one for the next
            # epoch so the old instance is not left dangling
            pbar_epoch.close()
            pbar_epoch = tqdm(total=len(train_data))
            print('========== EPOCH:%d (%.3f min) ==========' %
                  (epoch, duration_epoch / 60))

            if epoch == params['num_epoch']:
                break

            start_time_step = time.time()
            start_time_epoch = time.time()
            epoch += 1

    duration_train = time.time() - start_time_train
    logger.info('Total time: %.3f hour' % (duration_train / 3600))

    if params['backend'] == 'pytorch':
        tf_writer.close()
    pbar_epoch.close()

    # ``best_model`` is only set when a dev evaluation improved on
    # ``metric_dev_best``. It stays None when training ends before
    # ``eval_start_epoch`` or when no epoch improved (e.g. after a restart),
    # so fall back to the current model instead of evaluating ``None``.
    if best_model is None:
        logger.info('No improved model was recorded in this run; '
                    'evaluating the current model instead.')
        best_model = model

    # Evaluate the best model by beam search
    per_test_best, _ = eval_phone(model=best_model,
                                  dataset=test_data,
                                  beam_width=10,
                                  max_decode_len=MAX_DECODE_LEN_PHONE,
                                  eval_batch_size=1,
                                  map_file_path='./conf/phones.60-48-39.map')
    logger.info('  PER (test, beam: 10): %.3f %%' % (per_test_best * 100))

    # Training was finished correctly
    with open(os.path.join(model.save_path, 'COMPLETE'), 'w') as f:
        f.write('')
    def check(self,
              label_type,
              data_type='dev',
              backend='pytorch',
              shuffle=False,
              sort_utt=False,
              sort_stop_epoch=None,
              frame_stacking=False,
              splice=1):
        """Iterate over a dataset and print the first sample of each batch.

        The settings are printed first, then a ``Dataset`` is built from a
        hard-coded corpus path with ``max_epoch=2`` and consumed batch by
        batch.

        Args:
            label_type: passed through to ``Dataset``
            data_type: passed through to ``Dataset``; 'train' additionally
                enables the length check below (pytorch backend only)
            backend: passed through to ``Dataset``
            shuffle: passed through to ``Dataset``
            sort_utt: passed through to ``Dataset``
            sort_stop_epoch: passed through to ``Dataset``
            frame_stacking: if True, ``num_stack`` and ``num_skip`` are 3,
                otherwise 1
            splice: passed through to ``Dataset``

        Raises:
            ValueError: for pytorch training batches whose second input
                dimension is smaller than the second label dimension.
        """
        separator = '========================================'
        print(separator)
        print('  backend: %s' % backend)
        print('  label_type: %s' % label_type)
        print('  data_type: %s' % data_type)
        print('  shuffle: %s' % str(shuffle))
        print('  sort_utt: %s' % str(sort_utt))
        print('  sort_stop_epoch: %s' % str(sort_stop_epoch))
        print('  frame_stacking: %s' % str(frame_stacking))
        print('  splice: %d' % splice)
        print(separator)

        # Stacking and skipping always use the same factor here
        if frame_stacking:
            num_stack, num_skip = 3, 3
        else:
            num_stack, num_skip = 1, 1

        dataset = Dataset(
            data_save_path='/n/sd8/inaguma/corpus/timit/kaldi',
            backend=backend,
            input_freq=41,
            use_delta=True,
            use_double_delta=True,
            data_type=data_type,
            label_type=label_type,
            batch_size=64,
            max_epoch=2,
            splice=splice,
            num_stack=num_stack,
            num_skip=num_skip,
            shuffle=shuffle,
            sort_utt=sort_utt,
            sort_stop_epoch=sort_stop_epoch,
            tool='htk',
            num_enque=None)

        print('=> Loading mini-batch...')
        to_phones = Idx2phone(dataset.vocab_file_path)

        verify_lengths = data_type == 'train' and backend == 'pytorch'

        for batch, is_new_epoch in dataset:
            # The condition only looks at batch-level shapes, so testing it
            # once for a non-empty batch equals testing it per utterance
            if verify_lengths and len(batch['xs']) > 0:
                if batch['xs'].shape[1] < batch['ys'].shape[1]:
                    raise ValueError(
                        'input length must be longer than label length.')

            # Only the first utterance of each batch is shown
            first_label = batch['ys'][0]
            if dataset.is_test:
                str_true = first_label[0]
            else:
                str_true = to_phones(first_label[:batch['y_lens'][0]])

            header = (batch['input_names'][0], dataset.epoch_detail,
                      len(batch['xs']))
            print('----- %s (epoch: %.3f, batch: %d) -----' % header)
            print(str_true)
            print('x_lens: %d' % (batch['x_lens'][0] * num_stack))
            if not dataset.is_test:
                print('y_lens: %d' % batch['y_lens'][0])
def main():
    """Plot the attention weights of every test utterance.

    One ``<input_name>.png`` per utterance is written into the ``att_weights``
    directory below ``args.model_path``. An already existing output directory
    is removed and re-created first.
    """
    args = parser.parse_args()

    config_path = join(args.model_path, 'config.yml')
    params = load_config(config_path, is_eval=True)

    dataset = Dataset(
        data_save_path=args.data_save_path,
        backend=params['backend'],
        input_freq=params['input_freq'],
        use_delta=params['use_delta'],
        use_double_delta=params['use_double_delta'],
        data_type='test',
        label_type=params['label_type'],
        batch_size=args.eval_batch_size,
        splice=params['splice'],
        num_stack=params['num_stack'],
        num_skip=params['num_skip'],
        shuffle=False,
        tool=params['tool'])

    params['num_classes'] = dataset.num_classes

    model = load(
        model_type=params['model_type'],
        params=params,
        backend=params['backend'])
    model.load_checkpoint(save_path=args.model_path, epoch=args.epoch)
    model.set_cuda(deterministic=False, benchmark=True)

    save_path = mkdir_join(args.model_path, 'att_weights')

    # Start from an empty output directory
    if save_path is not None and isdir(save_path):
        shutil.rmtree(save_path)
        mkdir(save_path)

    for batch, is_new_epoch in dataset:
        best_hyps, aw, perm_idx = model.decode(
            batch['xs'],
            batch['x_lens'],
            beam_width=args.beam_width,
            max_decode_len=MAX_DECODE_LEN_PHONE,
            min_decode_len=MIN_DECODE_LEN_PHONE,
            length_penalty=args.length_penalty,
            coverage_penalty=args.coverage_penalty)

        # Bring the references into the same order as the hypotheses
        ys = batch['ys'][perm_idx]
        y_lens = batch['y_lens'][perm_idx]

        for b in range(len(batch['xs'])):
            # Reference
            if dataset.is_test:
                # NOTE: transcript is separated by space(' ')
                str_ref = ys[b][0]
            else:
                str_ref = dataset.idx2phone(ys[b][:y_lens[b]])

            # The decoded tokens are only used to crop the weight matrix;
            # no labels are passed to the plot
            token_list = dataset.idx2phone(best_hyps[b], return_list=True)
            num_tokens = len(token_list)
            figure_path = join(save_path, batch['input_names'][b] + '.png')

            plot_attention_weights(
                aw[b][:num_tokens, :batch['x_lens'][b]],
                label_list=[],
                spectrogram=batch['xs'][b, :, :dataset.input_freq],
                str_ref=str_ref,
                save_path=figure_path,
                figsize=(20, 8))

        # A single pass over the test set is enough
        if is_new_epoch:
            break