Example #1
def decode_cnn_ctc(args=None):
    '''Decode CNN w/ CTC using kaldi data tables'''
    logger = logging.getLogger(sys.argv[0])
    if not logger.handlers:
        logger.addHandler(logging.StreamHandler())
    register_logger_for_kaldi(sys.argv[0])
    options = _decode_cnn_ctc_parse_args(args, logger)
    logger.log(9, 'Parsed options')
    id2label_map = dict()
    with open(options.label_to_id_map_path) as file_obj:
        for line in file_obj:
            label, idee = line.strip().split()
            idee = int(idee)
            if idee < 0:
                logger.error('All label ids must be nonnegative')
                return 1
            id2label_map[idee] = label
    if (len(id2label_map) + 1) != options.num_labels:
        logger.error('Expected {} labels in id to label map, got {}'.format(
            options.num_labels - 1, len(id2label_map)))
        return 1
    for idee in range(options.num_labels - 1):
        if idee not in id2label_map:
            logger.error('label to id map missing id: {}'.format(idee))
            return 1
    logger.log(9, 'Loaded label to id map')
    model_config = ModelConfig(**vars(options))
    decode_config = DecodeConfig(**vars(options))
    decode_data = DecodeData(
        options.data_rspecifier,
        delta_order=model_config.delta_order,
        cmvn_rxfilename=model_config.cmvn_rxfilename,
        batch_size=decode_config.batch_size,
    )
    total_batches = len(decode_data)
    labels_out = io_open(options.output_wspecifier, 'tv', mode='w')
    logger.log(9, 'Set up eval data and opened label output file')
    with redirect_stdout_to_stderr():
        logger.log(9, 'Creating model')
        from pydrobert.mol.model import ConvCTC
        model = ConvCTC(model_config)
        logger.log(9, 'Beginning decoding')
        batches_decoded = 0
        logger.log(9, '000/{:03d} batches decoded'.format(total_batches))
        for label_batch in model.decode_generator(decode_config, decode_data):
            if decode_data.batch_size:
                for key, label_ids in label_batch:
                    labels_out.write(
                        key, tuple(id2label_map[idee] for idee in label_ids))
            else:
                labels_out.write(label_batch[0], tuple(label_batch[1]))
            batches_decoded += 1
            if batches_decoded % max(1, total_batches // 10) == 0:
                logger.log(
                    9, '{:03d}/{:03d} batches decoded'.format(
                        batches_decoded, total_batches))
    logger.info('Done decoding')
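
A minimal sketch of the label-to-id map file the loop above parses: one
"label id" pair per line, with ids covering 0 through num_labels - 2 without
gaps (the extra id implied by num_labels is presumably reserved for the CTC
blank, given the model is CTC-based). The file name and labels are made up:

# hypothetical map consistent with --num-labels 4
with open('phones.map', 'w') as file_obj:
    for idee, label in enumerate(['a', 'b', 'c']):
        file_obj.write('{} {}\n'.format(label, idee))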
Example #2
def registered_regular_logger():
    logger_name = ''.join(chr(x + 97) for x in np.random.choice(26, 100))
    ret_logger = logging.getLogger(logger_name)
    s_stream = StringIO()
    handler = logging.StreamHandler(s_stream)
    ret_logger.addHandler(handler)
    register_logger_for_kaldi(logger_name)
    yield ret_logger
    deregister_logger_for_kaldi(logger_name)
    ret_logger.removeHandler(handler)  # remove the handler, not its stream
Example #3
def test_do_not_callback_unregistered(kaldi_logger):
    kaldi_logger.setLevel(logging.WARNING)
    verbose_log(-1, 'should see this')
    deregister_logger_for_kaldi(kaldi_logger.name)
    verbose_log(-1, 'should not see this')
    register_logger_for_kaldi('bingobangobongo')
    verbose_log(-1, 'still nothing')
    register_logger_for_kaldi(kaldi_logger.name)
    verbose_log(-1, 'but see this')
    s_stream = kaldi_logger.handlers[-1].stream
    assert 'should see this\nbut see this\n' == s_stream.getvalue()
Example #4
def kaldi_logger():
    logger_name = ''.join(chr(x + 97) for x in np.random.choice(26, 100))
    old_class = logging.getLoggerClass()
    logging.setLoggerClass(KaldiLogger)
    ret_logger = logging.getLogger(logger_name)
    logging.setLoggerClass(old_class)
    s_stream = StringIO()
    ret_logger.addHandler(logging.StreamHandler(s_stream))
    register_logger_for_kaldi(logger_name)
    yield ret_logger
    deregister_logger_for_kaldi(logger_name)
    # iterate over a copy: removeHandler mutates the handlers list
    for handler in list(ret_logger.handlers):
        ret_logger.removeHandler(handler)
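
The two generator functions above read like pytest fixtures: everything
before the yield is setup, everything after is teardown. A simplified,
self-contained sketch of that pattern; the @pytest.fixture decorator and the
test are assumptions, not part of the source:

import logging
from io import StringIO

import pytest


@pytest.fixture
def string_logger():  # hypothetical, simplified analogue of kaldi_logger
    logger = logging.getLogger('string-logger-test')
    stream = StringIO()
    handler = logging.StreamHandler(stream)
    logger.addHandler(handler)
    yield logger  # the test runs here, then execution resumes for teardown
    logger.removeHandler(handler)


def test_warnings_are_captured(string_logger):
    string_logger.warning('hello')
    assert 'hello' in string_logger.handlers[-1].stream.getvalue()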
Example #5
def write_pickle_to_table(args=None):
    '''Write pickle file(s) contents to a table

    The inverse is write-table-to-pickle
    '''
    logger = logging.getLogger(sys.argv[0])
    if not logger.handlers:
        logger.addHandler(logging.StreamHandler())
    register_logger_for_kaldi(logger)
    try:
        options = _write_pickle_to_table_parse_args(args, logger)
    except SystemExit as ex:
        return ex.code
    if options.key_in is None:
        return _write_pickle_to_table_value_only(options, logger)
    else:
        return _write_pickle_to_table_key_value(options, logger)
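
Which pickle layout the command expects depends on whether a separate key
file was given (options.key_in). Judging from its inverse,
write-table-to-pickle (Example #12), the combined layout is one (key, value)
tuple per pickle record. A hedged sketch of producing such a file by hand;
'data.pkl' and the keys are hypothetical:

import pickle

import numpy as np

with open('data.pkl', 'wb') as value_out:
    # one (key, value) tuple per pickle.dump call
    pickle.dump(('utt1', np.zeros((10, 4), dtype=np.float32)), value_out)
    pickle.dump(('utt2', np.ones((7, 4), dtype=np.float32)), value_out)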
Example #6
def alt_compute_cmvn_stats(args=None):
    '''Python-based code for CMVN statistics computation

    Used for debugging
    '''
    logger = logging.getLogger(sys.argv[0])
    if not logger.handlers:
        logger.addHandler(logging.StreamHandler())
    register_logger_for_kaldi(sys.argv[0])
    options = _alt_compute_cmvn_stats_parse_args(args, logger)
    cmvn = CMVNCalculator()
    feat_table = io_open(options.feats_in, 'bm')
    num_utts = 0
    for feats in feat_table:
        cmvn.accumulate(feats)
        num_utts += 1
    logger.info('Accumulated stats for {} utterances'.format(num_utts))
    cmvn.save(options.cmvn_stats_out)
    logger.info('Wrote stats to {}'.format(options.cmvn_stats_out))
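
CMVNCalculator is a black box here, but cepstral mean and variance
normalization itself is standard: accumulate per-dimension sums and sums of
squares over all frames, then shift and scale each dimension to zero mean
and unit variance. A self-contained NumPy sketch of that math, independent
of the actual CMVNCalculator API:

import numpy as np

dim = 4
total = np.zeros(dim)   # running per-dimension sum
sumsq = np.zeros(dim)   # running per-dimension sum of squares
count = 0
for feats in (np.random.rand(10, dim), np.random.rand(7, dim)):
    total += feats.sum(0)
    sumsq += (feats ** 2).sum(0)
    count += len(feats)
mean = total / count
std = np.sqrt(sumsq / count - mean ** 2)
normalized = (np.random.rand(5, dim) - mean) / std  # zero mean, unit var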
Example #7
def alt_apply_cmvn(args=None):
    '''Python-based code for CMVN application

    Used for debugging
    '''
    logger = logging.getLogger(sys.argv[0])
    if not logger.handlers:
        logger.addHandler(logging.StreamHandler())
    register_logger_for_kaldi(sys.argv[0])
    options = _alt_apply_cmvn_parse_args(args, logger)
    cmvn = CMVNCalculator(options.cmvn_stats_in)
    feats_in = io_open(options.feats_in, 'bm')
    feats_out = io_open(options.feats_out, 'bm', mode='w')
    num_utts = 0
    for utt_id, feats in feats_in.items():
        feats = cmvn.apply(feats, in_place=True)
        feats_out.write(utt_id, feats)
        num_utts += 1
    logger.info('Applied CMVN to {} utterances'.format(num_utts))
Example #8
def alt_add_deltas(args=None):
    '''Python-based code for adding deltas

    Used for debugging
    '''
    logger = logging.getLogger(sys.argv[0])
    if not logger.handlers:
        logger.addHandler(logging.StreamHandler())
    register_logger_for_kaldi(sys.argv[0])
    options = _alt_add_deltas_parse_args(args, logger)
    feats_in = io_open(options.feats_in, 'bm')
    feats_out = io_open(options.feats_out, 'bm', mode='w')
    num_utts = 0
    for utt_id, feats in feats_in.items():
        feats = calculate_deltas(feats, options.delta_order)
        feats_out.write(utt_id, feats)
        num_utts += 1
    logger.info('Added {} deltas to {} utterances'.format(
        options.delta_order, num_utts))
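
calculate_deltas is likewise used as a black box. The usual delta recipe it
presumably implements is a per-dimension linear regression over a small
frame window; a self-contained sketch (the window size and edge padding are
illustrative choices, not the actual defaults):

import numpy as np

def simple_deltas(feats, window=2):
    # first-order dynamic coefficients over a +/- window frame context
    denom = 2 * sum(k ** 2 for k in range(1, window + 1))
    padded = np.pad(feats, ((window, window), (0, 0)), 'edge')
    out = np.zeros_like(feats)
    for k in range(1, window + 1):
        out += k * (padded[window + k:len(padded) - window + k]
                    - padded[window - k:len(padded) - window - k])
    return out / denom

deltas = simple_deltas(np.random.rand(10, 4))  # same shape as the input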
Example #9
def compute_loss(args=None):
    '''Compute the average loss over the data'''
    logger = logging.getLogger(sys.argv[0])
    if not logger.handlers:
        logger.addHandler(logging.StreamHandler())
    register_logger_for_kaldi(sys.argv[0])
    options = _compute_loss_parse_args(args, logger)
    logger.log(9, 'Parsed options')
    label2id_map = dict()
    with open(options.label_to_id_map_path) as file_obj:
        for line in file_obj:
            label, idee = line.strip().split()
            idee = int(idee)
            if idee < 0:
                logger.error('All label ids must be nonnegative')
                return 1
            label2id_map[label] = idee
    if (len(label2id_map) + 1) != options.num_labels:
        logger.error('Expected {} labels in label_to_id_map, got {}'.format(
            options.num_labels - 1, len(label2id_map)))
        return 1
    for idee in range(options.num_labels - 1):
        if idee not in label2id_map.values():
            raise ValueError('label to id map missing id: {}'.format(idee))
    model_config = ModelConfig(**vars(options))
    decode_config = DecodeConfig(**vars(options))
    eval_data = ValidationData(
        options.data_rspecifier,
        options.labels_rspecifier,
        label2id_map,
        batch_size=decode_config.batch_size,
        delta_order=model_config.delta_order,
        cmvn_rxfilename=model_config.cmvn_rxfilename,
    )
    with redirect_stdout_to_stderr():
        logger.log(9, 'Creating model')
        from pydrobert.mol.model import ConvCTC
        model = ConvCTC(model_config)
        logger.log(9, 'Calculating loss')
        loss = model.evaluate_generator(decode_config, eval_data)
        logger.log(9, 'Calculated loss')
    print('Model loss:', loss)
Example #10
def find_best_model_from_log(args=None):
    '''Find the best model from training log and return its path on stdout

    In a given training stage, the 'best' model is the one that minimizes or
    maximizes the target quantity (dependent on 'mode' and 'monitored').
    '''
    logger = logging.getLogger(sys.argv[0])
    if not logger.handlers:
        logger.addHandler(logging.StreamHandler())
    register_logger_for_kaldi(sys.argv[0])
    options = _find_best_model_from_log_parse_args(args, logger)
    logger.log(9, 'Parsed options')
    csv_file = open(options.csv_path, mode='r')
    csv_reader = DictReader(csv_file, delimiter=options.csv_delimiter)
    fields = csv_reader.fieldnames
    if 'model_path' not in fields:
        logger.error('"model_path" must be a field in the csv')
        return 1
    if options.monitored is None:
        if 'val_acc' in fields:
            options.monitored = 'val_acc'
        elif 'val_loss' in fields:
            options.monitored = 'val_loss'
        elif 'acc' in fields:
            options.monitored = 'acc'
        elif 'loss' in fields:
            options.monitored = 'loss'
        else:
            logger.error(
                'Unable to find suitable value to monitor. Set --monitored '
                'manually')
            return 1
        logger.log(9,
                   'Using "{}" as monitored value'.format(options.monitored))
    elif options.monitored not in fields:
        logger.error('Monitored value "{}" not in CSV'.format(
            options.monitored))
        return 1
    if options.mode == 'min' or (options.mode == 'auto'
                                 and 'acc' not in options.monitored):
        monitor_op = np.less
        best = np.inf
    else:  # max
        monitor_op = np.greater
        best = -np.inf
    if options.training_stage and 'training_stage' not in fields:
        logger.error(
            '--training-stage specified but training_stage not in csv')
        return 1
    logger.log(9, 'Looking through CSV')
    best_path = None
    for row in csv_reader:
        if options.training_stage and (row['training_stage'] !=
                                       options.training_stage):
            continue
        current = float(row[options.monitored])
        if monitor_op(current, best):
            best = current
            best_path = row['model_path']
    logger.log(9, 'Looked through CSV')
    if best_path is None:
        logger.error('Could not find any model')
        return 1
    else:
        print(best_path)
        return 0
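
To make the search concrete, here is a hedged, in-memory illustration of
the CSV this command scans; the field names follow the auto-detection order
above and all values are made up:

import io
from csv import DictReader

csv_text = (
    'model_path,val_loss,training_stage\n'
    'ckpt-1.h5,2.31,adam\n'
    'ckpt-2.h5,1.87,adam\n'
)
rows = list(DictReader(io.StringIO(csv_text)))
# with 'val_loss' monitored in 'min' (or 'auto') mode, ckpt-2.h5 wins
best = min(rows, key=lambda row: float(row['val_loss']))
assert best['model_path'] == 'ckpt-2.h5'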
Example #11
def compute_error_rate(args=None):
    '''Compute error rates between reference and hypothesis token vectors

    Two common error rates in speech are the word (WER) and phone (PER),
    though the computation is the same. Given a reference and hypothesis
    sequence, the error rate is

    >>> error_rate = (substitutions + insertions + deletions) / (
    ...     ref_tokens) * 100

    Where the number of substitutions (e.g. ``A B C -> A D C``),
    deletions (e.g. ``A B C -> A C``), and insertions (e.g.
    ``A B C -> A D B C``) are determined by Levenshtein distance.
    '''
    logger = logging.getLogger(sys.argv[0])
    if not logger.handlers:
        logger.addHandler(logging.StreamHandler())
    register_logger_for_kaldi(sys.argv[0])
    options = _compute_error_rate_parse_args(args, logger)
    global_edit = 0
    global_token_count = 0
    global_sents = 0
    global_processed = 0
    inserts = dict()
    deletes = dict()
    subs = dict()
    totals = dict()

    def _err_on_utt_id(utt_id, missing_rxspecifier):
        msg = "Utterance '{}' absent in '{}'".format(
            utt_id, missing_rxspecifier)
        if options.strict:
            logger.error(msg)
            return 1
        else:
            logger.warning(msg)
            return 0

    return_tables = options.print_tables or not options.include_inserts_in_cost
    with kaldi_open(options.ref_rspecifier, 'tv') as ref_table, \
            kaldi_open(options.hyp_rspecifier, 'tv') as hyp_table:
        while not ref_table.done() and not hyp_table.done():
            global_sents += 1
            if ref_table.key() > hyp_table.key():
                if _err_on_utt_id(hyp_table.key(), options.ref_rspecifier):
                    return 1
                hyp_table.move()
            elif hyp_table.key() > ref_table.key():
                if _err_on_utt_id(ref_table.key(), options.hyp_rspecifier):
                    return 1
                ref_table.move()
            else:
                logger.debug('Processing {}: ref [{}] hyp [{}]'.format(
                    ref_table.key(),
                    ' '.join(ref_table.value()),
                    ' '.join(hyp_table.value())))
                global_token_count += len(ref_table.value())
                res = kaldi_eval_util.edit_distance(
                    ref_table.value(), hyp_table.value(),
                    return_tables=return_tables,
                    insertion_cost=options.insertion_cost,
                    deletion_cost=options.deletion_cost,
                    substitution_cost=options.substitution_cost,
                )
                if return_tables:
                    global_edit += res[0]
                    for global_dict, utt_dict in zip(
                            (inserts, deletes, subs, totals), res[1:]):
                        for token in ref_table.value() + hyp_table.value():
                            global_dict.setdefault(token, 0)
                        for token, count in utt_dict.items():
                            global_dict[token] += count
                else:
                    global_edit += res
                # advance both tables only after processing a matched key
                global_processed += 1
                ref_table.move()
                hyp_table.move()
        while not ref_table.done():
            if _err_on_utt_id(ref_table.key(), options.hyp_rspecifier):
                return 1
            global_sents += 1
            ref_table.move()
        while not hyp_table.done():
            if _err_on_utt_id(hyp_table.key(), options.ref_rspecifier):
                return 1
            global_sents += 1
            hyp_table.move()
    if options.out_path is None:
        out_file = sys.stdout
    else:
        out_file = open(options.out_path, 'w')
    print(
        "Processed {}/{}.".format(global_processed, global_sents),
        file=out_file, end=' '
    )
    if not options.include_inserts_in_cost:
        global_edit -= sum(inserts.values())
    if options.report_accuracy:
        print(
            'Accuracy: {:.2f}%'.format(
                (1 - global_edit / global_token_count) * 100),
            file=out_file,
        )
    else:
        print(
            'Error rate: {:.2f}%'.format(
                global_edit / global_token_count * 100),
            file=out_file,
        )
    if options.print_tables:
        print(
            "Total insertions: {}, deletions: {}, substitutions: {}".format(
                sum(inserts.values()), sum(deletes.values()),
                sum(subs.values())),
            file=out_file,
        )
        print("", file=out_file)
        tokens = list(set(inserts) | set(deletes) | set(subs))
        tokens.sort()
        token_len = max(max(len(token) for token in tokens), 5)
        max_count = max(
            chain(inserts.values(), deletes.values(), subs.values()))
        max_count_len = int(log10(max_count) + 1)
        divider_str = '+' + ('-' * (token_len + 1))
        divider_str += ('+' + ('-' * (max_count_len + 9))) * 4
        divider_str += '+'
        format_str = '|{{:<{}}}|'.format(token_len + 1)
        format_str += 4 * '{{:>{}}}({{:05.2f}}%)|'.format(max_count_len + 1)
        print(
            '|{2:<{0}}|{3:>{1}}(%)|{4:>{1}}(%)|{5:>{1}}(%)|{6:>{1}}(%)|'
            ''.format(
                token_len + 1, max_count_len + 6, 'token', 'inserts',
                'deletes', 'subs', 'errs',
            ),
            file=out_file,
        )
        print(divider_str, file=out_file)
        print(divider_str, file=out_file)
        for token in tokens:
            i, d, s = inserts[token], deletes[token], subs[token]
            t = totals[token]
            print(
                format_str.format(
                    token,
                    i, i / t * 100,
                    d, d / t * 100,
                    s, s / t * 100,
                    i + d + s, (i + d + s) / t * 100,
                ),
                file=out_file
            )
            print(divider_str, file=out_file)
    return 0
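
To make the docstring's formula concrete: for reference 'A B C' and
hypothesis 'A D C' there is one substitution and three reference tokens, so
the error rate is 1 / 3 * 100, or about 33.33%. A textbook Levenshtein
routine reproduces the count; kaldi_eval_util.edit_distance presumably
computes the equivalent (with configurable costs):

def edit_distance(ref, hyp):
    # standard dynamic program with unit costs
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, 1):
        curr = [i]
        for j, h in enumerate(hyp, 1):
            curr.append(min(prev[j] + 1,              # deletion
                            curr[j - 1] + 1,          # insertion
                            prev[j - 1] + (r != h)))  # substitution
        prev = curr
    return prev[-1]

ref, hyp = 'A B C'.split(), 'A D C'.split()
print(edit_distance(ref, hyp) / len(ref) * 100)  # 33.33...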
Example #12
def write_table_to_pickle(args=None):
    '''Write a kaldi table to pickle file(s)

    The inverse is write-pickle-to-table
    '''
    logger = logging.getLogger(sys.argv[0])
    if not logger.handlers:
        logger.addHandler(logging.StreamHandler())
    register_logger_for_kaldi(logger)
    try:
        options = _write_table_to_pickle_parse_args(args, logger)
    except SystemExit as ex:
        return ex.code
    out_type = options.out_type
    if out_type is None:
        if options.in_type.is_floating_point:
            if options.in_type.is_double:
                out_type = np.float64
            else:
                out_type = np.float32
        else:
            out_type = np.str_  # np.str was removed from modern NumPy
    from six.moves import cPickle as pickle
    try:
        reader = kaldi_open(options.rspecifier, options.in_type, 'r')
        if options.value_out.endswith('.gz'):
            import gzip
            value_out = gzip.open(options.value_out, 'wb')
        else:
            value_out = open(options.value_out, 'wb')
        if options.key_out:
            # pickle writes bytes, so open the key file in binary mode too
            if options.key_out.endswith('.gz'):
                import gzip
                key_out = gzip.open(options.key_out, 'wb')
            else:
                key_out = open(options.key_out, 'wb')
        else:
            key_out = None
    except IOError as error:
        logger.error(str(error), exc_info=True)
        return 1
    num_entries = 0
    try:
        for key, value in reader.items():
            num_entries += 1
            if not np.issubdtype(out_type, np.dtype(str).type):
                value = value.astype(out_type)
            if key_out:
                pickle.dump(value, value_out)
                pickle.dump(key, key_out)
            else:
                pickle.dump((key, value), value_out)
            if num_entries % 10 == 0:
                logger.info('Processed {} entries'.format(num_entries))
            logger.debug('Processed key {}'.format(key))
    except (IOError, ValueError) as error:
        logger.error(str(error), exc_info=True)
        return 1
    finally:
        value_out.close()
        if key_out:
            key_out.close()
    if num_entries == 0:
        logger.warning("No entries were written (table was empty)")
    else:
        logger.info("Wrote {} entries".format(num_entries))
    return 0
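
Reading such a file back (the job of write-pickle-to-table, Example #5)
amounts to unpickling records until end-of-file. A minimal sketch for the
combined (key, value) layout; 'data.pkl' is again hypothetical:

import pickle

pairs = []
with open('data.pkl', 'rb') as value_in:
    while True:
        try:
            pairs.append(pickle.load(value_in))  # one (key, value) per load
        except EOFError:
            break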
Example #13
def write_torch_dir_to_table(args=None):
    '''Write a data directory containing PyTorch data files to a Kaldi table

    Reads from a folder in the format:

    ::
        folder/
          <file_prefix><key_1><file_suffix>
          <file_prefix><key_2><file_suffix>
          ...

    Where each file contains a PyTorch tensor. The contents of the file
    ``<file_prefix><key_1><file_suffix>`` will be written as a value in
    a Kaldi table with key ``<key_1>``
    '''
    logger = logging.getLogger(sys.argv[0])
    if not logger.handlers:
        logger.addHandler(logging.StreamHandler())
    register_logger_for_kaldi(logger)
    try:
        options = _write_torch_dir_to_table_parse_args(args, logger)
    except SystemExit as ex:
        return ex.code
    if not os.path.isdir(options.dir):
        print("'{}' is not a directory".format(options.dir), file=sys.stderr)
        return 1
    import torch
    is_bool = False
    if options.out_type in {
            enums.KaldiDataType.BaseMatrix, enums.KaldiDataType.BaseVector,
            enums.KaldiDataType.WaveMatrix, enums.KaldiDataType.Base,
            enums.KaldiDataType.BasePairVector
    }:
        if options.out_type.is_double:
            torch_type = torch.double
        else:
            torch_type = torch.float
    elif options.out_type in {
            enums.KaldiDataType.FloatMatrix, enums.KaldiDataType.FloatVector
    }:
        torch_type = torch.float
    elif options.out_type in {
            enums.KaldiDataType.DoubleMatrix, enums.KaldiDataType.Double
    }:
        torch_type = torch.double
    elif options.out_type in {
            enums.KaldiDataType.Int32, enums.KaldiDataType.Int32Vector,
            enums.KaldiDataType.Int32VectorVector
    }:
        torch_type = torch.int
    elif options.out_type == enums.KaldiDataType.Boolean:
        torch_type = torch.uint8
        is_bool = True
    else:
        print('Do not know how to convert {} from a torch type'.format(
            options.out_type), file=sys.stderr)
        return 1
    neg_fsl = -len(options.file_suffix)
    if not neg_fsl:
        neg_fsl = None
    fpl = len(options.file_prefix)
    utt_ids = sorted(
        os.path.basename(x)[fpl:neg_fsl] for x in os.listdir(options.dir) if
        x.startswith(options.file_prefix) and x.endswith(options.file_suffix))
    with kaldi_open(options.wspecifier, options.out_type, mode='w') as table:
        for utt_id in utt_ids:
            val = torch.load(
                os.path.join(
                    options.dir,
                    options.file_prefix + utt_id + options.file_suffix))
            val = val.cpu().type(torch_type).numpy()
            if is_bool:
                val = bool(val)  # make sure val is a scalar!
            table.write(utt_id, val)
    return 0
Example #14
def train_cnn_ctc(args=None):
    '''Train CNN w/ CTC decoding using kaldi data tables'''
    logger = logging.getLogger(sys.argv[0])
    if not logger.handlers:
        logger.addHandler(logging.StreamHandler())
    register_logger_for_kaldi(sys.argv[0])
    options = _train_cnn_ctc_parse_args(args, logger)
    logger.log(9, 'Parsed options')
    label2id_map = dict()
    with open(options.label_to_id_map_path) as file_obj:
        for line in file_obj:
            label, idee = line.strip().split()
            idee = int(idee)
            if idee < 0:
                logger.error('All label ids must be nonnegative')
                return 1
            label2id_map[label] = idee
    if (len(label2id_map) + 1) != options.num_labels:
        logger.error('Expected {} labels in label_to_id_map, got {}'.format(
            options.num_labels - 1, len(label2id_map)))
        return 1
    for idee in range(options.num_labels - 1):
        if idee not in label2id_map.values():
            raise ValueError('label to id map missing id: {}'.format(idee))
    logger.log(9, 'Loaded label to id map')
    model_config = ModelConfig(**vars(options))
    train_config = TrainConfig(**vars(options))
    train_data = TrainData(
        (options.data_rspecifier, 'bm', {
            'cache': train_config.cache
        }),
        (options.labels_rspecifier, 'tv', {
            'cache': train_config.cache
        }),
        label2id_map,
        batch_size=train_config.batch_size,
        delta_order=model_config.delta_order,
        cmvn_rxfilename=model_config.cmvn_rxfilename,
        rng=train_config.train_seed,
    )
    if options.val_data_rspecifier or options.val_labels_rspecifier:
        if None in (options.val_data_rspecifier,
                    options.val_labels_rspecifier):
            logger.error(
                "Both 'val_data_rspecifier' and 'val_labels_rspecifier' must "
                "be specified, or neither")
            return 1
        val_data = ValidationData(
            options.val_data_rspecifier,
            options.val_labels_rspecifier,
            label2id_map,
            batch_size=train_config.batch_size,
            delta_order=model_config.delta_order,
            cmvn_rxfilename=model_config.cmvn_rxfilename,
        )
    else:
        val_data = None
    logger.log(9, 'Set up training/validation data generators')
    with redirect_stdout_to_stderr():
        logger.log(9, 'Creating model')
        from pydrobert.mol.model import ConvCTC
        model = ConvCTC(model_config)
        logger.log(9, 'Beginning training')
        model.fit_generator(train_config, train_data, val_data=val_data)
        logger.log(9, 'Finished training')
Ejemplo n.º 19
0
def write_table_to_torch_dir(args=None):
    '''Write a Kaldi table to a series of PyTorch data files in a directory

    Writes to a folder in the format:

    ::
        folder/
          <file_prefix><key_1><file_suffix>
          <file_prefix><key_2><file_suffix>
          ...

    The contents of the file ``<file_prefix><key_1><file_suffix>`` will be
    a PyTorch tensor corresponding to the entry in the table for ``<key_1>``
    '''
    logger = logging.getLogger(sys.argv[0])
    if not logger.handlers:
        logger.addHandler(logging.StreamHandler())
    register_logger_for_kaldi(logger)
    try:
        options = _write_table_to_torch_dir_parse_args(args, logger)
    except SystemExit as ex:
        return ex.code
    out_type = options.out_type
    if out_type is None:
        if options.in_type in {
                enums.KaldiDataType.BaseMatrix, enums.KaldiDataType.BaseVector,
                enums.KaldiDataType.WaveMatrix, enums.KaldiDataType.Base,
                enums.KaldiDataType.BasePairVector
        }:
            if options.in_type.is_double:
                out_type = 'double'
            else:
                out_type = 'float'
        elif options.in_type in {
                enums.KaldiDataType.FloatMatrix,
                enums.KaldiDataType.FloatVector
        }:
            out_type = 'float'
        elif options.in_type in {
                enums.KaldiDataType.DoubleMatrix, enums.KaldiDataType.Double
        }:
            out_type = 'double'
        elif options.in_type in {
                enums.KaldiDataType.Int32, enums.KaldiDataType.Int32Vector,
                enums.KaldiDataType.Int32VectorVector
        }:
            out_type = 'int'
        elif options.in_type == enums.KaldiDataType.Boolean:
            out_type = 'byte'
        else:
            print('Do not know how to convert {} to a torch type'.format(
                options.in_type), file=sys.stderr)
            return 1
    import torch
    if out_type == 'float':
        out_type = torch.float
    elif out_type == 'double':
        out_type = torch.double
    elif out_type == 'half':
        out_type = torch.half
    elif out_type == 'byte':
        out_type = torch.uint8
    elif out_type == 'char':
        out_type = torch.int8
    elif out_type == 'short':
        out_type = torch.short
    elif out_type == 'int':
        out_type = torch.int
    elif out_type == 'long':
        out_type = torch.long
    os.makedirs(options.dir, exist_ok=True)
    with kaldi_open(options.rspecifier, options.in_type) as table:
        for key, value in table.items():
            value = torch.tensor(value).type(out_type)
            torch.save(
                value,
                os.path.join(options.dir,
                             options.file_prefix + key + options.file_suffix))
    return 0
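
Examples #13 and #15 are inverses (command names inferred from the function
names), so a table should survive a round trip through a directory, up to
any dtype conversion. A sketch of the directory side, assuming an empty
file_prefix and a '.pt' file_suffix; both are assumptions, not the actual
defaults:

import os

import torch

os.makedirs('feats_dir', exist_ok=True)
# what write-table-to-torch-dir would leave behind for keys utt1 and utt2
for key in ('utt1', 'utt2'):
    torch.save(torch.zeros(10, 4), os.path.join('feats_dir', key + '.pt'))
# and what write-torch-dir-to-table would load back for each key
val = torch.load(os.path.join('feats_dir', 'utt1.pt'))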
Example #16
def normalize_feat_lens(args=None):
    '''Ensure features match some reference lengths

    Incoming features are either clipped or padded to match
    reference lengths (stored as an int32 table), if they are within
    tolerance.
    '''
    logger = logging.getLogger(sys.argv[0])
    if not logger.handlers:
        logger.addHandler(logging.StreamHandler())
    register_logger_for_kaldi(sys.argv[0])
    options = _normalize_feat_lens_parse_args(args, logger)
    if options.pad_mode == 'zero':
        options.pad_mode = 'constant'
    feats_in = kaldi_open(options.feats_in_rspecifier, options.type, mode='r')
    len_in = kaldi_open(options.len_in_rspecifier, 'i', mode='r+')
    feats_out = kaldi_open(options.feats_out_wspecifier,
                           options.type,
                           mode='w')
    total_utts = 0
    processed_utts = 0
    for utt_id, feats in feats_in.items():
        total_utts += 1
        if utt_id not in len_in:
            msg = "Utterance '{}' absent in '{}'".format(
                utt_id, options.len_in_rspecifier)
            if options.strict:
                logger.error(msg)
                return 1
            else:
                logger.warning(msg)
                continue
        exp_feat_len = len_in[utt_id]
        act_feat_len = len(feats)
        logger.debug('{} exp len: {} act len: {}'.format(
            utt_id, exp_feat_len, act_feat_len))
        if act_feat_len < exp_feat_len:
            if act_feat_len < exp_feat_len - options.tolerance:
                msg = '{} has feature length {}, which is below the '
                msg += 'tolerance ({}) of the expected length {}'
                msg = msg.format(utt_id, act_feat_len, options.tolerance,
                                 exp_feat_len)
                if options.strict:
                    logger.error(msg)
                    return 1
                else:
                    logger.warning(msg)
                    continue
            # for matrices or vectors, this cast shouldn't be necessary.
            # If the user tries some special type like token vectors,
            # however, this *might* work as intended
            feats = np.array(feats, copy=False)
            pad_list = [(0, 0)] * len(feats.shape)
            if options.side == 'right':
                pad_list[0] = (0, exp_feat_len - act_feat_len)
            elif options.side == 'left':
                pad_list[0] = (exp_feat_len - act_feat_len, 0)
            else:
                pad_list[0] = ((exp_feat_len - act_feat_len) // 2,
                               (exp_feat_len - act_feat_len + 1) // 2)
            feats = np.pad(feats, pad_list, options.pad_mode)
        elif act_feat_len > exp_feat_len:
            if act_feat_len > exp_feat_len + options.tolerance:
                msg = '{} has feature length {}, which is above the '
                msg += 'tolerance ({}) of the expected length {}'
                msg = msg.format(utt_id, act_feat_len, options.tolerance,
                                 exp_feat_len)
                if options.strict:
                    logger.error(msg)
                    return 1
                else:
                    logger.warning(msg)
                    continue
            if options.side == 'right':
                feats = feats[:exp_feat_len - act_feat_len]
            elif options.side == 'left':
                # keep the final exp_feat_len frames
                feats = feats[act_feat_len - exp_feat_len:]
            else:
                # drop roughly half the excess frames from each end
                start = (act_feat_len - exp_feat_len) // 2
                feats = feats[start:start + exp_feat_len]
        feats_out.write(utt_id, feats)
        processed_utts += 1
    logger.info('Processed {}/{} utterances'.format(processed_utts,
                                                    total_utts))
    return 0
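
The padding and clipping arithmetic above is easy to sanity-check in
isolation. A small demonstration of the 'center' case with NumPy; the
lengths are arbitrary:

import numpy as np

feats = np.arange(12, dtype=np.float32).reshape(6, 2)  # act_feat_len = 6

# pad 6 -> 9 frames, centered: (9 - 6) // 2 = 1 frame in front and
# (9 - 6 + 1) // 2 = 2 frames behind, as in the pad_list above
padded = np.pad(feats, [(1, 2), (0, 0)], 'edge')
assert len(padded) == 9

# clip 6 -> 4 frames, centered: drop (6 - 4) // 2 = 1 leading frame, keep 4
clipped = feats[(6 - 4) // 2:(6 - 4) // 2 + 4]
assert len(clipped) == 4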