Ejemplo n.º 1
0
class CER(Metric):
    """Character error rate (CER) metric based on greedy CTC decoding.

    By default ``output_transform`` drops the first element of the engine
    output. The remaining 4-tuple ``(output, target, output_lengths,
    target_lengths)`` is consumed by :meth:`_update`.
    """

    def __init__(self,
                 name='CER',
                 output_transform=lambda x: x[1:],
                 fmt='.02%',
                 **kwargs):
        """Create the metric.

        Args:
            name: display name forwarded to ``Metric``.
            output_transform: callable applied to the engine output before
                :meth:`_update`; the default drops the first element.
            fmt: format spec forwarded to ``Metric``.
            **kwargs: must contain ``alphabet`` (used to build the decoder);
                every other keyword is forwarded to ``Metric``.

        Raises:
            ValueError: if ``alphabet`` is missing or ``None``.
        """
        # `alphabet` is consumed here so it is not forwarded to Metric.
        alphabet = kwargs.pop('alphabet', None)
        super().__init__(name, output_transform, fmt=fmt, **kwargs)

        if alphabet is None:
            raise ValueError('Arg. `alphabet` is required.')

        self.decoder = GreedyCTCDecoder(alphabet)

    def _update(self, output):
        """Decode a batch and feed the per-sample normalised CER to ``Metric``.

        Each sample's CER is the decoder's character edit count divided by
        the reference length. A sample with an empty reference counts as 1.0.
        """
        output, target, output_lengths, target_lengths = output

        # decode() returns a tuple; element 0 holds the transcripts.
        transcripts = self.decoder.decode(output, output_lengths)[0]
        references = self.decoder.tensor2str(target, target_lengths)

        per_sample = [
            self.decoder.cer(t, r) / float(len(r)) if len(r) else 1.0
            for t, r in zip(transcripts, references)
        ]

        # Force a float dtype: with an int fallback, a batch whose references
        # are all empty would otherwise produce an int64 tensor.
        super()._update(
            torch.tensor(per_sample, dtype=torch.float, device=output.device))
Ejemplo n.º 2
0
    def __init__(self,
                 name='CER',
                 output_transform=lambda x: x[1:],
                 fmt='.02%',
                 **kwargs):
        """Set up the CER metric and its greedy CTC decoder.

        ``alphabet`` must be supplied as a keyword argument. It is withheld
        from the keywords forwarded to the parent constructor.

        Raises:
            ValueError: if no ``alphabet`` keyword was given (or it is None).
        """
        alphabet = kwargs.get('alphabet')
        forwarded = {key: value for key, value in kwargs.items()
                     if key != 'alphabet'}
        super().__init__(name, output_transform, fmt=fmt, **forwarded)

        # Validation runs only after the parent constructor has finished.
        if alphabet is not None:
            self.decoder = GreedyCTCDecoder(alphabet)
            return
        raise ValueError('Arg. `alphabet` is required.')
Ejemplo n.º 3
0
def evaluate_from_args(args):
    """Generate cached logits for ``args.input_file`` and print WER / CER.

    If ``<serialization_dir>/logits/<basename>.pth`` already exists, nothing
    is recomputed and the function only logs that fact. Otherwise:

    1. The model is rebuilt from the archived config and weights.
    2. The evaluation set is run through it, and the list of
       ``(logits[:length], reference)`` pairs is saved to that file.
    3. Corpus-level WER and CER from greedy decoding are printed.
    """
    # Disable some of the more verbose logging statements
    logging.getLogger('asr.common.params').disabled = True
    logging.getLogger('asr.common.registrable').disabled = True

    # Load from archive
    _, weights_file = load_archive(args.serialization_dir, args.overrides,
                                   args.weights_file)

    params = Params.load(os.path.join(args.serialization_dir, CONFIG_NAME),
                         args.overrides)

    prepare_environment(params)

    # Try to use the validation dataset reader if there is one - otherwise fall back
    # to the default dataset_reader used for both training and validation.
    dataset_params = params.pop('val_dataset', params.get('dataset_reader'))

    logger.info("Reading evaluation data from %s", args.input_file)
    dataset_params['manifest_filepath'] = args.input_file
    dataset = datasets.from_params(dataset_params)

    # Prefer the alphabet saved alongside the model; fall back to the config.
    alphabet_dir = os.path.join(args.serialization_dir, "alphabet")
    if os.path.exists(alphabet_dir):
        alphabet = Alphabet.from_file(os.path.join(alphabet_dir, "tokens"))
    else:
        alphabet = Alphabet.from_params(params.pop("alphabet", {}))

    logits_dir = os.path.join(args.serialization_dir, 'logits')
    os.makedirs(logits_dir, exist_ok=True)

    basename = os.path.splitext(os.path.split(args.input_file)[1])[0]
    logger.debug("Evaluation file basename: %s", basename)
    logits_file = os.path.join(logits_dir, basename + '.pth')

    if os.path.exists(logits_file):
        logger.info(f'Logits file `{logits_file}` already generated.')
        return

    model = models.from_params(alphabet=alphabet, params=params.pop('model'))
    model.load_state_dict(
        torch.load(weights_file,
                   map_location=lambda storage, loc: storage)['model'])
    model.eval()

    decoder = GreedyCTCDecoder(alphabet)

    loader_params = params.pop("val_data_loader", params.get("data_loader"))
    batch_sampler = samplers.BucketingSampler(dataset,
                                              batch_size=args.batch_size)
    loader = loaders.from_params(loader_params,
                                 dataset=dataset,
                                 batch_sampler=batch_sampler)

    logger.info(f'Logits file `{logits_file}` not found. Generating...')

    with torch.no_grad():
        model.to(args.device)

        logits = []
        total_cer, total_wer, num_tokens, num_chars = 0, 0, 0, 0
        for batch in tqdm.tqdm(loader):
            sample, target, sample_lengths, target_lengths = batch
            sample = sample.to(args.device)
            sample_lengths = sample_lengths.to(args.device)

            output, output_lengths = model(sample, sample_lengths)

            # Decode and slice on the host. Assumes output_lengths is a
            # tensor (it is sliced/zipped alongside output below).
            output = output.to('cpu')
            output_lengths = output_lengths.to('cpu')

            references = decoder.tensor2str(target, target_lengths)

            # Pass per-sample lengths so frames past each sequence's end
            # (batch padding) are not decoded into the transcript.
            transcripts = decoder.decode(output, output_lengths)[0]

            logits.extend(
                (o[:l, ...], r)
                for o, l, r in zip(output, output_lengths, references))

            del sample, sample_lengths, output

            for reference, transcript in zip(references, transcripts):
                total_wer += decoder.wer(transcript, reference)
                total_cer += decoder.cer(transcript, reference)
                num_tokens += float(len(reference.split()))
                num_chars += float(len(reference))

        torch.save(logits, logits_file)

        # Guard against an empty evaluation set / all-empty references.
        wer = float(total_wer) / num_tokens if num_tokens else float('nan')
        cer = float(total_cer) / num_chars if num_chars else float('nan')

        print(f'WER: {wer:.02%}\nCER: {cer:.02%}')

    del model
Ejemplo n.º 4
0
def tune_from_args(args):
    """Grid-search decoder ``alpha`` / ``beta`` over cached logits.

    If ``<serialization_dir>/logits/<basename>.pth`` does not exist, the model
    is run over ``args.input_file`` first. The list of
    ``(logits[:length], reference)`` pairs is cached to that file.

    Every (alpha, beta) pair from the two ``torch.linspace`` grids is then
    evaluated by ``tune_step`` in a worker pool initialised with ``init``.
    The resulting ``(alpha, beta, wer, cer)`` rows are written to
    ``<serialization_dir>/tune/<basename>.csv``.
    """
    # Disable some of the more verbose logging statements
    logging.getLogger('asr.common.params').disabled = True
    logging.getLogger('asr.common.registrable').disabled = True

    # Load from archive
    _, weights_file = load_archive(args.serialization_dir, args.overrides,
                                   args.weights_file)

    params = Params.load(os.path.join(args.serialization_dir, CONFIG_NAME),
                         args.overrides)

    prepare_environment(params)

    # Try to use the validation dataset reader if there is one - otherwise fall back
    # to the default dataset_reader used for both training and validation.
    dataset_params = params.pop('val_dataset', params.get('dataset_reader'))

    logger.info("Reading evaluation data from %s", args.input_file)
    dataset_params['manifest_filepath'] = args.input_file
    dataset = datasets.from_params(dataset_params)

    # Prefer the alphabet saved alongside the model; fall back to the config.
    alphabet_dir = os.path.join(args.serialization_dir, "alphabet")
    if os.path.exists(alphabet_dir):
        alphabet = Alphabet.from_file(os.path.join(alphabet_dir, "tokens"))
    else:
        alphabet = Alphabet.from_params(params.pop("alphabet", {}))

    logits_dir = os.path.join(args.serialization_dir, 'logits')
    os.makedirs(logits_dir, exist_ok=True)

    basename = os.path.splitext(os.path.split(args.input_file)[1])[0]
    logits_file = os.path.join(logits_dir, basename + '.pth')

    if not os.path.exists(logits_file):
        model = models.from_params(alphabet=alphabet,
                                   params=params.pop('model'))
        model.load_state_dict(
            torch.load(weights_file,
                       map_location=lambda storage, loc: storage)['model'])
        model.eval()

        decoder = GreedyCTCDecoder(alphabet)

        loader_params = params.pop("val_data_loader",
                                   params.get("data_loader"))
        batch_sampler = samplers.BucketingSampler(dataset,
                                                  batch_size=args.batch_size)
        loader = loaders.from_params(loader_params,
                                     dataset=dataset,
                                     batch_sampler=batch_sampler)

        logger.info(f'Logits file `{logits_file}` not found. Generating...')

        with torch.no_grad():
            model.to(args.device)

            logits = []
            for batch in tqdm.tqdm(loader):
                sample, target, sample_lengths, target_lengths = batch
                sample = sample.to(args.device)
                sample_lengths = sample_lengths.to(args.device)

                output, output_lengths = model(sample, sample_lengths)

                output = output.to('cpu')

                references = decoder.tensor2str(target, target_lengths)

                # Keep only the valid (unpadded) frames of each sample.
                logits.extend(
                    (o[:l, ...], r)
                    for o, l, r in zip(output, output_lengths, references))

                del sample, sample_lengths, output

            torch.save(logits, logits_file)

        del model
    else:
        logger.info(f'Logits file `{logits_file}` already generated.')

    tune_dir = os.path.join(args.serialization_dir, 'tune')
    os.makedirs(tune_dir, exist_ok=True)

    params_grid = list(
        product(
            torch.linspace(args.alpha_from, args.alpha_to, args.alpha_steps),
            torch.linspace(args.beta_from, args.beta_to, args.beta_steps)))

    print(
        'Scheduling {} jobs for alphas=linspace({}, {}, {}), betas=linspace({}, {}, {})'
        .format(len(params_grid), args.alpha_from, args.alpha_to,
                args.alpha_steps, args.beta_from, args.beta_to,
                args.beta_steps))

    # start worker processes
    logger.info(
        f"Using {args.num_workers} processes and {args.lm_workers} for each CTCDecoder."
    )
    extract_start = default_timer()

    scores = []
    best_wer = float('inf')
    # The context manager guarantees the worker processes are shut down,
    # including when an exception escapes the loop. All results are consumed
    # inside the block.
    with Pool(args.num_workers, init, [
            logits_file, alphabet, args.lm_path, args.cutoff_top_n,
            args.cutoff_prob, args.beam_width, args.lm_workers
    ]) as pool:
        with tqdm.tqdm(pool.imap(tune_step, params_grid),
                       total=len(params_grid),
                       desc='Grid search') as pbar:
            # Distinct name so the `params` config object is not shadowed.
            for alpha, beta, wer, cer in pbar:
                scores.append([alpha, beta, wer, cer])

                if wer < best_wer:
                    best_wer = wer
                    pbar.set_postfix(alpha=alpha, beta=beta, wer=wer, cer=cer)

    logger.info(
        f"Finished {len(params_grid)} processes in {default_timer() - extract_start:.1f}s"
    )

    df = pd.DataFrame(scores, columns=['alpha', 'beta', 'wer', 'cer'])
    df.to_csv(os.path.join(tune_dir, basename + '.csv'), index=False)
Ejemplo n.º 5
0
def test_greedy_decoder():
    """Check GreedyCTCDecoder transcripts, scores and offsets on a 2-item batch.

    Code adapted from tensorflow's greedy CTC decoder test. Frames past each
    item's sequence length are marked as ignored and must not affect the
    result.
    """
    seq_len_0 = 4
    # dimensions are time x depth
    input_prob_matrix_0 = torch.tensor(
        [
            [1.0, 0.0, 0.0, 0.0],  # t=0
            [0.0, 0.0, 0.4, 0.6],  # t=1
            [0.0, 0.0, 0.4, 0.6],  # t=2
            [0.0, 0.9, 0.1, 0.0],  # t=3
            [0.0, 0.0, 0.0, 0.0],  # t=4 (ignored)
            [0.0, 0.0, 0.0, 0.0],  # t=5 (ignored)
        ],
        dtype=torch.float32)
    input_log_prob_matrix_0 = input_prob_matrix_0.log()

    seq_len_1 = 5
    # dimensions are time x depth
    input_prob_matrix_1 = torch.tensor(
        [
            [0.1, 0.9, 0.0, 0.0],  # t=0
            [0.0, 0.9, 0.1, 0.0],  # t=1
            [0.0, 0.0, 0.1, 0.9],  # t=2
            [0.0, 0.9, 0.1, 0.1],  # t=3
            [0.9, 0.1, 0.0, 0.0],  # t=4
            [0.0, 0.0, 0.0, 0.0],  # t=5 (ignored)
        ],
        dtype=torch.float32)
    input_log_prob_matrix_1 = input_prob_matrix_1.log()

    # batch_size x max_time_steps x depth tensor of log-probabilities
    inputs = torch.stack([input_log_prob_matrix_0, input_log_prob_matrix_1])
    # batch_size length vector of sequence_lengths
    seq_lens = torch.tensor([seq_len_0, seq_len_1], dtype=torch.int32)

    # batch_size length vector of negative log probabilities of the best path
    log_prob_truth = torch.tensor([
        -(torch.tensor([1.0, 0.6, 0.6, 0.9]).log()).sum().item(),
        -(torch.tensor([0.9, 0.9, 0.9, 0.9, 0.9]).log()).sum().item()
    ])

    decode_truth = ['ab', 'bba']
    offsets_truth = [
        torch.tensor([0, 3]),
        torch.tensor([0, 3, 4]),
    ]

    # '-' (index 3) is the blank symbol.
    alphabet = Alphabet('abc-', blank_index=3)

    decoder = GreedyCTCDecoder(alphabet)
    out, scores, offsets = decoder.decode(inputs, seq_lens)

    assert out[0] == decode_truth[0]
    assert out[1] == decode_truth[1]

    assert torch.allclose(scores, log_prob_truth)

    assert torch.all(offsets[0] == offsets_truth[0])
    assert torch.all(offsets[1] == offsets_truth[1])