Example #1
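These snippets are excerpts from a larger module, so names such as Data, create_model, SpeechRecognizer, wer, weights_std, monotonicity_penalty, dict_subset, CandidateNotFoundError and logger are resolved at module level. A sketch of the surrounding imports follows; the exact lvsr/blocks module paths are assumptions, not verified against a particular revision.

# Assumed module-level context for the snippets below; module paths are
# best guesses based on the lvsr/blocks package layout.
import logging
import os
import sys
import time

import numpy
import theano
from theano import tensor

from blocks.search import CandidateNotFoundError  # assumed path
from blocks.utils import dict_subset              # assumed path
from lvsr.datasets import Data                    # assumed path
from lvsr.error_rate import wer                   # assumed path
from lvsr.expressions import monotonicity_penalty, weights_std  # assumed path

# create_model and SpeechRecognizer are provided by the same package;
# their exact import paths are omitted here.
logger = logging.getLogger(__name__)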
def search(config, params, load_path, part, decode_only, report, decoded_save,
           nll_only, seed):
    import matplotlib
    matplotlib.use("Agg")
    from matplotlib import pyplot
    from lvsr.notebook import show_alignment

    data = Data(**config['data'])
    search_conf = config['monitoring']['search']

    logger.info("Recognizer initialization started")
    recognizer = create_model(config, data, load_path)
    recognizer.init_beam_search(search_conf['beam_size'])
    logger.info("Recognizer is initialized")

    has_uttids = 'uttids' in data.info_dataset.provides_sources
    add_sources = ('uttids', ) if has_uttids else ()
    dataset = data.get_dataset(part, add_sources)
    stream = data.get_stream(part,
                             batches=False,
                             shuffle=part == 'train',
                             add_sources=add_sources,
                             num_examples=500 if part == 'train' else None,
                             seed=seed)
    it = stream.get_epoch_iterator(as_dict=True)
    if decode_only is not None:
        # decode_only is a string expression such as "[0, 5, 7]";
        # ast.literal_eval would be a safer choice than eval here
        decode_only = eval(decode_only)

    weights = tensor.matrix('weights')
    weight_statistics = theano.function([weights], [
        weights_std(weights.dimshuffle(0, 'x', 1)),
        monotonicity_penalty(weights.dimshuffle(0, 'x', 1))
    ])

    print_to = sys.stdout
    if report:
        alignments_path = os.path.join(report, "alignments")
        # os.makedirs also covers the case where the report directory
        # exists but the alignments subdirectory does not
        if not os.path.exists(alignments_path):
            os.makedirs(alignments_path)
        print_to = open(os.path.join(report, "report.txt"), 'w')

    decoded_file = None
    if decoded_save:
        decoded_file = open(decoded_save, 'w')

    num_examples = 0.0
    total_nll = 0.0
    total_errors = 0.0
    total_length = 0.0
    total_wer_errors = 0.0
    total_word_length = 0.0

    if config.get('vocabulary'):
        with open(os.path.expandvars(config['vocabulary'])) as f:
            vocabulary = dict(line.split() for line in f)

        def to_words(chars):
            words = chars.split()
            words = [vocabulary.get(word, vocabulary['<UNK>'])
                     for word in words]
            return words

    for number, example in enumerate(it):
        if decode_only and number not in decode_only:
            continue
        uttids = example.pop('uttids', None)
        raw_groundtruth = example.pop('labels')
        required_inputs = dict_subset(example, recognizer.inputs.keys())

        print("Utterance {} ({})".format(number, uttids), file=print_to)

        groundtruth = dataset.decode(raw_groundtruth)
        groundtruth_text = dataset.pretty_print(raw_groundtruth, example)
        costs_groundtruth, weights_groundtruth = recognizer.analyze(
            inputs=required_inputs,
            groundtruth=raw_groundtruth,
            prediction=raw_groundtruth)[:2]
        weight_std_groundtruth, mono_penalty_groundtruth = weight_statistics(
            weights_groundtruth)
        total_nll += costs_groundtruth.sum()
        num_examples += 1
        print("Groundtruth:", groundtruth_text, file=print_to)
        print("Groundtruth cost:", costs_groundtruth.sum(), file=print_to)
        print("Groundtruth weight std:", weight_std_groundtruth, file=print_to)
        print("Groundtruth monotonicity penalty:",
              mono_penalty_groundtruth,
              file=print_to)
        print("Average groundtruth cost: {}".format(total_nll / num_examples),
              file=print_to)
        if nll_only:
            print_to.flush()
            continue

        before = time.time()
        try:
            search_kwargs = dict(
                char_discount=search_conf.get('char_discount'),
                round_to_inf=search_conf.get('round_to_inf'),
                stop_on=search_conf.get('stop_on'),
                validate_solution_function=getattr(data.info_dataset,
                                                   'validate_solution', None))
            # drop unset options; note this also discards explicit
            # zero/False values
            search_kwargs = {k: v for k, v in search_kwargs.items() if v}
            outputs, search_costs = recognizer.beam_search(
                required_inputs, **search_kwargs)
        except CandidateNotFoundError:
            logger.error('Candidate not found!')
            outputs = [[]]
            search_costs = [[numpy.nan]]

        took = time.time() - before
        recognized = dataset.decode(outputs[0])
        recognized_text = dataset.pretty_print(outputs[0], example)
        if recognized:
            # Theano scan doesn't work with 0 length sequences
            costs_recognized, weights_recognized = recognizer.analyze(
                inputs=required_inputs,
                groundtruth=raw_groundtruth,
                prediction=outputs[0])[:2]
            weight_std_recognized, mono_penalty_recognized = weight_statistics(
                weights_recognized)
            error = min(1, wer(groundtruth, recognized))
        else:
            error = 1
        total_errors += len(groundtruth) * error
        total_length += len(groundtruth)

        if config.get('vocabulary'):
            wer_error = min(
                1, wer(to_words(groundtruth_text), to_words(recognized_text)))
            total_wer_errors += len(groundtruth) * wer_error
            total_word_length += len(groundtruth)

        if report and recognized:
            show_alignment(weights_groundtruth, groundtruth, bos_symbol=True)
            pyplot.savefig(
                os.path.join(alignments_path,
                             "{}.groundtruth.png".format(number)))
            show_alignment(weights_recognized, recognized, bos_symbol=True)
            pyplot.savefig(
                os.path.join(alignments_path,
                             "{}.recognized.png".format(number)))

        if decoded_file is not None:
            print("{} {}".format(uttids, ' '.join(recognized)),
                  file=decoded_file)

        print("Decoding took:", took, file=print_to)
        print("Beam search cost:", search_costs[0], file=print_to)
        print("Recognized:", recognized_text, file=print_to)
        if recognized:
            print("Recognized cost:", costs_recognized.sum(), file=print_to)
            print("Recognized weight std:",
                  weight_std_recognized,
                  file=print_to)
            print("Recognized monotonicity penalty:",
                  mono_penalty_recognized,
                  file=print_to)
        print("CER:", error, file=print_to)
        print("Average CER:", total_errors / total_length, file=print_to)
        if config.get('vocabulary'):
            print("WER:", wer_error, file=print_to)
            print("Average WER:",
                  total_wer_errors / total_word_length,
                  file=print_to)
        print_to.flush()
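Note how the running CER is accumulated above: each utterance's error rate is clipped at 1 and weighted by the reference length, so the reported average is a corpus-level rate rather than a mean of per-utterance rates. A standalone sketch of that aggregation, with error_rate as a hypothetical stand-in for the wer helper:

# Minimal sketch of the length-weighted averaging used in the loop above.
# error_rate(ref, hyp) is a hypothetical stand-in for wer(); it is assumed
# to return a normalized edit distance.
def corpus_error_rate(pairs, error_rate):
    total_errors = 0.0
    total_length = 0.0
    for groundtruth, recognized in pairs:
        # empty hypotheses count as a full error, as in the code above
        rate = min(1, error_rate(groundtruth, recognized)) if recognized else 1
        total_errors += len(groundtruth) * rate
        total_length += len(groundtruth)
    return total_errors / total_length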
Example #2
def search(config, params, load_path, part, decode_only, report, decoded_save,
           nll_only, seed):
    import matplotlib
    matplotlib.use("Agg")
    from matplotlib import pyplot
    from lvsr.notebook import show_alignment

    data = Data(**config['data'])
    search_conf = config['monitoring']['search']

    logger.info("Recognizer initialization started")
    recognizer = SpeechRecognizer(data.recordings_source,
                                  data.labels_source,
                                  data.eos_label,
                                  data.num_features,
                                  data.num_labels,
                                  character_map=data.character_map,
                                  name='recognizer',
                                  **config["net"])
    recognizer.load_params(load_path)
    recognizer.init_beam_search(search_conf['beam_size'])
    logger.info("Recognizer is initialized")

    stream = data.get_stream(part,
                             batches=False,
                             shuffle=part == 'train',
                             add_sources=(data.uttid_source, ),
                             num_examples=500 if part == 'train' else None,
                             seed=seed)
    it = stream.get_epoch_iterator()
    if decode_only is not None:
        # decode_only is a string expression such as "[0, 5, 7]";
        # ast.literal_eval would be a safer choice than eval here
        decode_only = eval(decode_only)

    weights = tensor.matrix('weights')
    weight_statistics = theano.function([weights], [
        weights_std(weights.dimshuffle(0, 'x', 1)),
        monotonicity_penalty(weights.dimshuffle(0, 'x', 1))
    ])

    print_to = sys.stdout
    if report:
        alignments_path = os.path.join(report, "alignments")
        # os.makedirs also covers the case where the report directory
        # exists but the alignments subdirectory does not
        if not os.path.exists(alignments_path):
            os.makedirs(alignments_path)
        print_to = open(os.path.join(report, "report.txt"), 'w')

    decoded_file = None
    if decoded_save:
        decoded_file = open(decoded_save, 'w')

    num_examples = 0.0
    total_nll = 0.0
    total_errors = 0.0
    total_length = 0.0
    total_wer_errors = 0.0
    total_word_length = 0.0

    if config.get('vocabulary'):
        with open(os.path.expandvars(config['vocabulary'])) as f:
            vocabulary = dict(line.split() for line in f)

        def to_words(chars):
            words = chars.split()
            words = [vocabulary.get(word, vocabulary['<UNK>'])
                     for word in words]
            return words

    for number, example in enumerate(it):
        if decode_only and number not in decode_only:
            continue
        # example is a (recordings, labels, uttid) tuple
        print("Utterance {} ({})".format(number, example[2]), file=print_to)
        groundtruth = data.decode(example[1])
        groundtruth_text = data.pretty_print(example[1])
        costs_groundtruth, weights_groundtruth = (recognizer.analyze(
            example[0], example[1], example[1])[:2])
        weight_std_groundtruth, mono_penalty_groundtruth = weight_statistics(
            weights_groundtruth)
        total_nll += costs_groundtruth.sum()
        num_examples += 1
        print("Groundtruth:", groundtruth_text, file=print_to)
        print("Groundtruth cost:", costs_groundtruth.sum(), file=print_to)
        print("Groundtruth weight std:", weight_std_groundtruth, file=print_to)
        print("Groundtruth monotonicity penalty:",
              mono_penalty_groundtruth,
              file=print_to)
        print("Average groundtruth cost: {}".format(total_nll / num_examples),
              file=print_to)
        if nll_only:
            print_to.flush()
            continue

        before = time.time()
        outputs, search_costs = recognizer.beam_search(
            example[0],
            char_discount=search_conf['char_discount'],
            round_to_inf=search_conf['round_to_inf'],
            stop_on=search_conf['stop_on'])
        took = time.time() - before
        recognized = data.decode(outputs[0])
        recognized_text = data.pretty_print(outputs[0])
        if recognized:
            # Theano scan doesn't work with 0 length sequences
            costs_recognized, weights_recognized = (recognizer.analyze(
                example[0], example[1], outputs[0])[:2])
            weight_std_recognized, mono_penalty_recognized = weight_statistics(
                weights_recognized)
            error = min(1, wer(groundtruth, recognized))
        else:
            error = 1
        total_errors += len(groundtruth) * error
        total_length += len(groundtruth)

        if config.get('vocabulary'):
            wer_error = min(
                1, wer(to_words(groundtruth_text), to_words(recognized_text)))
            total_wer_errors += len(groundtruth) * wer_error
            total_word_length += len(groundtruth)

        if report and recognized:
            show_alignment(weights_groundtruth, groundtruth, bos_symbol=True)
            pyplot.savefig(
                os.path.join(alignments_path,
                             "{}.groundtruth.png".format(number)))
            show_alignment(weights_recognized, recognized, bos_symbol=True)
            pyplot.savefig(
                os.path.join(alignments_path,
                             "{}.recognized.png".format(number)))

        if decoded_file is not None:
            print("{} {}".format(example[2], ' '.join(recognized)),
                  file=decoded_file)

        print("Decoding took:", took, file=print_to)
        print("Beam search cost:", search_costs[0], file=print_to)
        print("Recognized:", recognized_text, file=print_to)
        if recognized:
            print("Recognized cost:", costs_recognized.sum(), file=print_to)
            print("Recognized weight std:",
                  weight_std_recognized,
                  file=print_to)
            print("Recognized monotonicity penalty:",
                  mono_penalty_recognized,
                  file=print_to)
        print("CER:", error, file=print_to)
        print("Average CER:", total_errors / total_length, file=print_to)
        if config.get('vocabulary'):
            print("WER:", wer_error, file=print_to)
            print("Average WER:",
                  total_wer_errors / total_word_length,
                  file=print_to)
        print_to.flush()
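The optional vocabulary file loaded by these variants is expected to contain two whitespace-separated fields per line, including an <UNK> row that absorbs out-of-vocabulary tokens. A minimal sketch with hypothetical contents:

# Hypothetical vocabulary contents: one "token id" pair per line.
sample = "the 1\ncat 2\n<UNK> 0\n"
vocabulary = dict(line.split() for line in sample.splitlines())

def to_words(chars):
    # unknown tokens fall back to the <UNK> entry, as in the code above
    return [vocabulary.get(word, vocabulary['<UNK>'])
            for word in chars.split()]

assert to_words("the dog") == ['1', '0']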
Example #3
def search(config, params, load_path, beam_size, part, decode_only, report,
           decoded_save, nll_only, char_discount):
    from matplotlib import pyplot
    from lvsr.notebook import show_alignment

    data = Data(**config['data'])

    recognizer = SpeechRecognizer(
        data.recordings_source, data.labels_source,
        data.eos_label, data.num_features, data.num_labels,
        character_map=data.character_map,
        name='recognizer', **config["net"])
    recognizer.load_params(load_path)
    recognizer.init_beam_search(beam_size)

    dataset = data.get_dataset(part, add_sources=(data.uttid_source,))
    stream = data.get_stream(part, batches=False, shuffle=False,
                             add_sources=(data.uttid_source,))
    it = stream.get_epoch_iterator()
    if decode_only is not None:
        # decode_only is a string expression such as "[0, 5, 7]";
        # ast.literal_eval would be a safer choice than eval here
        decode_only = eval(decode_only)

    weights = tensor.matrix('weights')
    weight_statistics = theano.function(
        [weights],
        [weights_std(weights.dimshuffle(0, 'x', 1)),
         monotonicity_penalty(weights.dimshuffle(0, 'x', 1))])

    print_to = sys.stdout
    if report:
        alignments_path = os.path.join(report, "alignments")
        # os.makedirs also covers the case where the report directory
        # exists but the alignments subdirectory does not
        if not os.path.exists(alignments_path):
            os.makedirs(alignments_path)
        print_to = open(os.path.join(report, "report.txt"), 'w')

    decoded_file = None
    if decoded_save:
        decoded_file = open(decoded_save, 'w')

    num_examples = 0.0
    total_nll = 0.0
    total_errors = 0.0
    total_length = 0.0
    total_wer_errors = 0.0
    total_word_length = 0.0
    with open(os.path.expandvars(config['vocabulary'])) as f:
        vocabulary = dict(line.split() for line in f)

    def to_words(chars):
        words = chars.split()
        words = [vocabulary.get(word, vocabulary['<UNK>'])
                 for word in words]
        return words

    # name the iterator tuple 'example' so it does not shadow the Data
    # object created above
    for number, example in enumerate(it):
        if decode_only and number not in decode_only:
            continue
        # example is a (recordings, labels, uttid) tuple
        print("Utterance {} ({})".format(number, example[2]), file=print_to)
        groundtruth = dataset.decode(example[1])
        groundtruth_text = dataset.pretty_print(example[1])
        costs_groundtruth, weights_groundtruth = (
            recognizer.analyze(example[0], example[1])[:2])
        weight_std_groundtruth, mono_penalty_groundtruth = weight_statistics(
            weights_groundtruth)
        total_nll += costs_groundtruth.sum()
        num_examples += 1
        print("Groundtruth:", groundtruth_text, file=print_to)
        print("Groundtruth cost:", costs_groundtruth.sum(), file=print_to)
        print("Groundtruth weight std:", weight_std_groundtruth, file=print_to)
        print("Groundtruth monotonicity penalty:", mono_penalty_groundtruth, file=print_to)
        print("Average groundtruth cost: {}".format(total_nll / num_examples),
                file=print_to)
        if nll_only:
            continue

        before = time.time()
        outputs, search_costs = recognizer.beam_search(
            example[0], char_discount=char_discount)
        took = time.time() - before
        recognized = dataset.decode(outputs[0])
        recognized_text = dataset.pretty_print(outputs[0])
        if recognized:
            # Theano scan doesn't work with 0 length sequences
            costs_recognized, weights_recognized = (
                recognizer.analyze(example[0], outputs[0])[:2])
            weight_std_recognized, mono_penalty_recognized = weight_statistics(
                weights_recognized)
            error = min(1, wer(groundtruth, recognized))
        else:
            error = 1
        total_errors += len(groundtruth) * error
        total_length += len(groundtruth)

        wer_error = min(1, wer(to_words(groundtruth_text),
                               to_words(recognized_text)))
        total_wer_errors += len(groundtruth) * wer_error
        total_word_length += len(groundtruth)

        if report and recognized:
            show_alignment(weights_groundtruth, groundtruth, bos_symbol=True)
            pyplot.savefig(os.path.join(
                alignments_path, "{}.groundtruth.png".format(number)))
            show_alignment(weights_recognized, recognized, bos_symbol=True)
            pyplot.savefig(os.path.join(
                alignments_path, "{}.recognized.png".format(number)))

        if decoded_file is not None:
            print("{} {}".format(example[2], ' '.join(recognized)),
                  file=decoded_file)

        print("Decoding took:", took, file=print_to)
        print("Beam search cost:", search_costs[0], file=print_to)
        print("Recognized:", recognized_text, file=print_to)
        print("Recognized cost:", costs_recognized.sum(), file=print_to)
        print("Recognized weight std:", weight_std_recognized, file=print_to)
        print("Recognized monotonicity penalty:", mono_penalty_recognized, file=print_to)
        print("CER:", error, file=print_to)
        print("Average CER:", total_errors / total_length, file=print_to)
        print("WER:", wer_error, file=print_to)
        print("Average WER:", total_wer_errors / total_word_length, file=print_to)
Example #4
def search(config, params, load_path, part, decode_only, report,
           decoded_save, nll_only, seed):
    import matplotlib
    matplotlib.use("Agg")
    from matplotlib import pyplot
    from lvsr.notebook import show_alignment

    data = Data(**config['data'])
    search_conf = config['monitoring']['search']

    logger.info("Recognizer initialization started")
    recognizer = create_model(config, data, load_path)
    recognizer.init_beam_search(search_conf['beam_size'])
    logger.info("Recognizer is initialized")

    has_uttids = 'uttids' in data.info_dataset.provides_sources
    add_sources = ('uttids',) if has_uttids else ()
    dataset = data.get_dataset(part, add_sources)
    stream = data.get_stream(part, batches=False,
                             shuffle=part == 'train',
                             add_sources=add_sources,
                             num_examples=500 if part == 'train' else None,
                             seed=seed)
    it = stream.get_epoch_iterator(as_dict=True)
    if decode_only is not None:
        # decode_only is a string expression such as "[0, 5, 7]";
        # ast.literal_eval would be a safer choice than eval here
        decode_only = eval(decode_only)

    weights = tensor.matrix('weights')
    weight_statistics = theano.function(
        [weights],
        [weights_std(weights.dimshuffle(0, 'x', 1)),
         monotonicity_penalty(weights.dimshuffle(0, 'x', 1))])

    print_to = sys.stdout
    if report:
        alignments_path = os.path.join(report, "alignments")
        # os.makedirs also covers the case where the report directory
        # exists but the alignments subdirectory does not
        if not os.path.exists(alignments_path):
            os.makedirs(alignments_path)
        print_to = open(os.path.join(report, "report.txt"), 'w')

    decoded_file = None
    if decoded_save:
        decoded_file = open(decoded_save, 'w')

    num_examples = 0.0
    total_nll = 0.0
    total_errors = 0.0
    total_length = 0.0
    total_wer_errors = 0.0
    total_word_length = 0.0

    if config.get('vocabulary'):
        with open(os.path.expandvars(config['vocabulary'])) as f:
            vocabulary = dict(line.split() for line in f)

        def to_words(chars):
            words = chars.split()
            words = [vocabulary.get(word, vocabulary['<UNK>'])
                     for word in words]
            return words

    for number, example in enumerate(it):
        if decode_only and number not in decode_only:
            continue
        uttids = example.pop('uttids', None)
        raw_groundtruth = example.pop('labels')
        required_inputs = dict_subset(example, recognizer.inputs.keys())

        print("Utterance {} ({})".format(number, uttids), file=print_to)

        groundtruth = dataset.decode(raw_groundtruth)
        groundtruth_text = dataset.pretty_print(raw_groundtruth, example)
        costs_groundtruth, weights_groundtruth = recognizer.analyze(
            inputs=required_inputs,
            groundtruth=raw_groundtruth,
            prediction=raw_groundtruth)[:2]
        weight_std_groundtruth, mono_penalty_groundtruth = weight_statistics(
            weights_groundtruth)
        total_nll += costs_groundtruth.sum()
        num_examples += 1
        print("Groundtruth:", groundtruth_text, file=print_to)
        print("Groundtruth cost:", costs_groundtruth.sum(), file=print_to)
        print("Groundtruth weight std:", weight_std_groundtruth, file=print_to)
        print("Groundtruth monotonicity penalty:", mono_penalty_groundtruth,
              file=print_to)
        print("Average groundtruth cost: {}".format(total_nll / num_examples),
              file=print_to)
        if nll_only:
            print_to.flush()
            continue

        before = time.time()
        try:
            search_kwargs = dict(
                char_discount=search_conf.get('char_discount'),
                round_to_inf=search_conf.get('round_to_inf'),
                stop_on=search_conf.get('stop_on'),
                validate_solution_function=getattr(
                    data.info_dataset, 'validate_solution', None))
            # drop unset options; note this also discards explicit
            # zero/False values
            search_kwargs = {k: v for k, v in search_kwargs.items() if v}
            outputs, search_costs = recognizer.beam_search(
                required_inputs, **search_kwargs)
        except CandidateNotFoundError:
            logger.error('Candidate not found!')
            outputs = [[]]
            search_costs = [[numpy.nan]]

        took = time.time() - before
        recognized = dataset.decode(outputs[0])
        recognized_text = dataset.pretty_print(outputs[0], example)
        if recognized:
            # Theano scan doesn't work with 0 length sequences
            costs_recognized, weights_recognized = recognizer.analyze(
                inputs=required_inputs,
                groundtruth=raw_groundtruth,
                prediction=outputs[0])[:2]
            weight_std_recognized, mono_penalty_recognized = weight_statistics(
                weights_recognized)
            error = min(1, wer(groundtruth, recognized))
        else:
            error = 1
        total_errors += len(groundtruth) * error
        total_length += len(groundtruth)

        if config.get('vocabulary'):
            wer_error = min(1, wer(to_words(groundtruth_text),
                                   to_words(recognized_text)))
            total_wer_errors += len(groundtruth) * wer_error
            total_word_length += len(groundtruth)

        if report and recognized:
            show_alignment(weights_groundtruth, groundtruth, bos_symbol=True)
            pyplot.savefig(os.path.join(
                alignments_path, "{}.groundtruth.png".format(number)))
            show_alignment(weights_recognized, recognized, bos_symbol=True)
            pyplot.savefig(os.path.join(
                alignments_path, "{}.recognized.png".format(number)))

        if decoded_file is not None:
            print("{} {}".format(uttids, ' '.join(recognized)),
                  file=decoded_file)

        print("Decoding took:", took, file=print_to)
        print("Beam search cost:", search_costs[0], file=print_to)
        print("Recognized:", recognized_text, file=print_to)
        if recognized:
            print("Recognized cost:", costs_recognized.sum(), file=print_to)
            print("Recognized weight std:", weight_std_recognized,
                  file=print_to)
            print("Recognized monotonicity penalty:", mono_penalty_recognized,
                  file=print_to)
        print("CER:", error, file=print_to)
        print("Average CER:", total_errors / total_length, file=print_to)
        if config.get('vocabulary'):
            print("WER:", wer_error, file=print_to)
            print("Average WER:", total_wer_errors / total_word_length, file=print_to)
        print_to.flush()
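A final note on resource handling: every variant opens report.txt and the decoded_save file and relies on interpreter shutdown to flush and close them. A sketch of the same setup with deterministic cleanup via contextlib.ExitStack (Python 3.3+); the names mirror the function arguments above:

# Sketch: deterministic cleanup for the report/decoded files opened above.
import contextlib
import os
import sys

def open_outputs(report, decoded_save):
    # Returns an ExitStack plus the two output targets used above;
    # closing the stack closes whichever files were actually opened.
    stack = contextlib.ExitStack()
    print_to = sys.stdout
    if report:
        print_to = stack.enter_context(
            open(os.path.join(report, "report.txt"), 'w'))
    decoded_file = None
    if decoded_save:
        decoded_file = stack.enter_context(open(decoded_save, 'w'))
    return stack, print_to, decoded_file

# usage:
#     stack, print_to, decoded_file = open_outputs(report, decoded_save)
#     with stack:
#         ...decoding loop...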