예제 #1
0
def plot_probs(model, dataset, eval_batch_size, save_path=None):
    """Plot CTC posteriors for every utterance in one epoch of `dataset`.

    Args:
        model: model to evaluate; must provide
            `posteriors(xs, x_lens, temperature=...)`
        dataset: An instance of a `Dataset` class; iterating it yields
            `(batch, is_new_epoch)` pairs and it exposes `num_stack`
        eval_batch_size (int): the batch size when evaluating the model.
            NOTE: accepted for interface compatibility but not read by
            this function
        save_path (string, optional): directory in which to save figures
            of CTC posteriors. If it already exists it is removed and
            recreated empty. If None, no file path is built and None is
            forwarded to `plot_ctc_probs`
    """
    # Clean directory
    if save_path is not None and isdir(save_path):
        shutil.rmtree(save_path)
        mkdir(save_path)

    for batch, is_new_epoch in dataset:

        # Get CTC probs
        probs = model.posteriors(batch['xs'], batch['x_lens'], temperature=1)
        # NOTE: probs: '[B, T, num_classes]'

        # Visualize
        for b in range(len(batch['xs'])):
            frame_num = batch['x_lens'][b]

            # join() raises TypeError on None, so only build a file path
            # when a target directory was actually given.
            # NOTE(review): assumes plot_ctc_probs accepts save_path=None
            # -- confirm against its definition
            if save_path is None:
                fig_path = None
            else:
                fig_path = join(save_path, batch['input_names'][b] + '.png')

            plot_ctc_probs(
                # Drop the frames beyond this utterance's length
                probs[b, :frame_num, :],
                frame_num=frame_num,
                num_stack=dataset.num_stack,
                # Only the first 40 entries of the last input axis are shown
                spectrogram=batch['xs'][b, :, :40],
                save_path=fig_path,
                figsize=(14, 7))

        # Stop after a single pass over the dataset
        if is_new_epoch:
            break
def plot(model, dataset, eval_batch_size=None, save_path=None,
         space_index=None):
    """Decode one epoch of `dataset` and plot CTC posteriors per utterance.

    Args:
        model: the model to evaluate; must provide
            `posteriors(xs, x_lens, temperature=...)` and
            `decode(xs, x_lens, beam_width=...)`
        dataset: An instance of a `Dataset` class; iterating it yields
            `(batch, is_new_epoch)` pairs and it exposes `label_type`,
            `data_size` and `num_stack`
        eval_batch_size (int, optional): the batch size when evaluating the
            model. If not None it is assigned to `dataset.batch_size`
        save_path (string, optional): root directory for figures of CTC
            posteriors. If it already exists it is removed and recreated
            empty; figures are placed under `<save_path>/<speaker>/<book>/`.
            If None, cleaning is skipped and None is forwarded to
            `plot_ctc_probs`
        space_index (int, optional): forwarded unchanged to
            `plot_ctc_probs`
    """
    # Set batch size in the evaluation
    if eval_batch_size is not None:
        dataset.batch_size = eval_batch_size

    # Clean directory (isdir() raises TypeError on None, hence the guard)
    if save_path is not None and isdir(save_path):
        shutil.rmtree(save_path)
        mkdir(save_path)

    # Pick the index-to-string converter matching the label type
    vocab_file_path = ('../metrics/vocab_files/' +
                       dataset.label_type + '_' + dataset.data_size + '.txt')
    if dataset.label_type == 'character':
        map_fn = Idx2char(vocab_file_path)
    elif dataset.label_type == 'character_capital_divide':
        map_fn = Idx2char(vocab_file_path, capital_divide=True)
    else:
        map_fn = Idx2word(vocab_file_path)

    for batch, is_new_epoch in dataset:

        # Get CTC probs
        probs = model.posteriors(batch['xs'], batch['x_lens'], temperature=1)
        # NOTE: probs: '[B, T, num_classes]'

        # Decode (the second returned value is not needed here)
        best_hyps, _ = model.decode(batch['xs'], batch['x_lens'], beam_width=1)

        # Visualize
        for b in range(len(batch['xs'])):
            input_name = batch['input_names'][b]
            frame_num = batch['x_lens'][b]

            # Convert from list of index to string
            str_pred = map_fn(best_hyps[b])

            # The first two '-'-separated fields of the utterance name are
            # used as sub-directory names
            speaker, book = input_name.split('-')[:2]

            # NOTE(review): assumes plot_ctc_probs accepts save_path=None
            # -- confirm against its definition
            if save_path is None:
                fig_path = None
            else:
                fig_path = mkdir_join(
                    save_path, speaker, book, input_name + '.png')

            plot_ctc_probs(
                probs[b, :frame_num, :],
                frame_num=frame_num,
                num_stack=dataset.num_stack,
                space_index=space_index,
                str_pred=str_pred,
                save_path=fig_path)

        # Stop after a single pass over the dataset
        if is_new_epoch:
            break
def main():
    """Restore a trained model and plot its CTC posteriors on the test set.

    Reads the command-line arguments via the module-level `parser`, loads
    `config.yml` from the model directory, builds the test `Dataset`,
    restores the checkpoint for the requested epoch and writes one `.png`
    per utterance into `<model_path>/ctc_probs`.
    """
    args = parser.parse_args()
    model_dir = args.model_path

    # Load a config file (.yml)
    config = load_config(join(model_dir, 'config.yml'), is_eval=True)

    # Load dataset
    dataset_kwargs = {
        'data_save_path': args.data_save_path,
        'backend': config['backend'],
        'input_freq': config['input_freq'],
        'use_delta': config['use_delta'],
        'use_double_delta': config['use_double_delta'],
        'data_type': 'test',
        'label_type': config['label_type'],
        'batch_size': args.eval_batch_size,
        'splice': config['splice'],
        'num_stack': config['num_stack'],
        'num_skip': config['num_skip'],
        'sort_utt': True,
        'reverse': True,
        'tool': config['tool'],
    }
    dataset = Dataset(**dataset_kwargs)

    # The number of output classes is taken from the dataset
    config['num_classes'] = dataset.num_classes

    # Load model
    model = load(
        model_type=config['model_type'],
        params=config,
        backend=config['backend'])

    # Restore the saved parameters
    model.load_checkpoint(save_path=model_dir, epoch=args.epoch)

    # GPU setting
    model.set_cuda(deterministic=False, benchmark=True)

    save_path = mkdir_join(model_dir, 'ctc_probs')

    ######################################################################

    # Clean directory: start from an empty output directory
    if save_path is not None:
        if isdir(save_path):
            shutil.rmtree(save_path)
            mkdir(save_path)

    for batch, is_new_epoch in dataset:
        inputs = batch['xs']

        # Get CTC probs
        probs, frame_lens, _ = model.posteriors(
            inputs, batch['x_lens'], temperature=1)
        # NOTE: probs: '[B, T, num_classes]'

        # Visualize
        for utt_idx in range(len(inputs)):
            num_frames = frame_lens[utt_idx]
            fig_path = join(save_path, batch['input_names'][utt_idx] + '.png')
            plot_ctc_probs(
                probs[utt_idx, :num_frames, :],
                frame_num=num_frames,
                num_stack=dataset.num_stack,
                spectrogram=inputs[utt_idx, :, :40],
                save_path=fig_path,
                figsize=(14, 7))

        # One pass over the dataset is enough
        if is_new_epoch:
            break