Example #1
def _write_pickle_to_table_empty(wspecifier, logger):
    '''Special case when pickle file(s) was/were empty'''
    # doesn't matter what type we choose; we're not writing anything
    try:
        kaldi_open(wspecifier, 'bm', 'w')
    except IOError as error:
        logger.error(str(error), exc_info=True)
        return 1
    logger.warning('No entries were written (pickle file(s) was/were empty)')
    return 0
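
Since nothing gets written here, the 'bm' (base matrix) type is an arbitrary
choice; any valid data type would create the same empty table. For context, a
minimal sketch of the writer/reader round trip all of these examples rely on
(the import path is an assumption based on pydrobert-kaldi's documented usage):

import numpy as np
from pydrobert.kaldi.io import open as kaldi_open  # assumed import path

# write a two-entry table of base-precision matrices ('bm')
with kaldi_open('ark:feats.ark', 'bm', 'w') as writer:
    writer.write('utt1', np.zeros((3, 4)))
    writer.write('utt2', np.ones((2, 4)))

# read the entries back in the order they were written
with kaldi_open('ark:feats.ark', 'bm') as reader:
    for key, mat in reader.items():
        print(key, mat.shape)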
Example #2
def test_write_table_to_pickle(values, temp_file_1_name, temp_file_2_name):
    if len(values):
        kaldi_dtype = infer_kaldi_data_type(values[0]).value
    else:
        kaldi_dtype = 'bm'
    with kaldi_open('ark:' + temp_file_1_name, kaldi_dtype, 'w') as writer:
        for num, value in enumerate(values):
            writer.write(str(num), value)
    ret_code = command_line.write_table_to_pickle(
        ['ark:' + temp_file_1_name, temp_file_2_name, '-i', kaldi_dtype])
    assert ret_code == 0
    num_entries = 0
    pickle_file = open(temp_file_2_name, 'rb')
    try:
        while True:
            key, value = pickle.load(pickle_file)
            num_entries = int(key) + 1
            try:
                values[num_entries - 1].dtype
                assert np.allclose(value, values[num_entries - 1])
            except AttributeError:
                assert value == values[num_entries - 1]
    except EOFError:
        pass
    pickle_file.close()
    assert num_entries == len(values)
Example #3
def _write_pickle_to_table_value_only(options, logger):
    '''write_pickle_to_table when only value_in has been specified'''
    from six.moves import cPickle as pickle
    try:
        if options.value_in.endswith('.gz'):
            import gzip
            value_in = gzip.open(options.value_in, 'rb')
        else:
            value_in = open(options.value_in, 'rb')
    except IOError as error:
        logger.error(str(error), exc_info=True)
        return 1
    try:
        key, value = pickle.load(value_in)
    except pickle.UnpicklingError as error:
        logger.error(str(error), exc_info=True)
        return 1
    except EOFError:
        value_in.close()
        return _write_pickle_to_table_empty(options.wspecifier, logger)
    out_type = options.out_type
    try:
        writer = kaldi_open(options.wspecifier, out_type, 'w')
    except IOError as error:
        value_in.close()
        logger.error(str(error), exc_info=True)
        return 1
    num_entries = 0
    try:
        while True:
            if out_type.is_floating_point:
                if out_type.is_double:
                    try:
                        value = value.astype(np.float64, copy=False)
                    except AttributeError:
                        pass
                else:
                    try:
                        value = value.astype(np.float32, copy=False)
                    except AttributeError:
                        pass
            writer.write(key, value)
            num_entries += 1
            if num_entries % 10 == 0:
                logger.info('Processed {} entries'.format(num_entries))
            logger.debug('Processed key {}'.format(key))
            key, value = pickle.load(value_in)
    except EOFError:
        pass
    except (IOError, ValueError, TypeError, pickle.UnpicklingError) as error:
        logger.error(str(error), exc_info=True)
        return 1
        return 1
    finally:
        value_in.close()
    logger.info("Wrote {} entries".format(num_entries))
    return 0
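
The load-until-EOFError loop above is the usual way to consume a stream of
concatenated pickles. The same idiom as a reusable generator (a hypothetical
helper, not part of the example):

from six.moves import cPickle as pickle

def iter_pickled(file_):
    '''Yield objects from a stream of concatenated pickles until EOF'''
    while True:
        try:
            yield pickle.load(file_)
        except EOFError:
            return

# e.g. the (key, value) loop above could then be written as
#     for key, value in iter_pickled(value_in): ...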
Example #4
def __init__(self, rxfilename=None, along_axis=0):
    self.along_axis = along_axis
    if KaldiDataType.BaseMatrix.is_double:
        self.stats_dtype = np.float64
    else:
        self.stats_dtype = np.float32
    if rxfilename:
        with kaldi_open(rxfilename) as stats_file:
            self.stats = stats_file.read('bm')
    else:
        self.stats = None
Example #5
def test_normalize_feat_lens(temp_file_1_name, temp_file_2_name,
                             temp_file_3_name):
    feats_a = np.random.random((10, 4))
    feats_b = np.random.random((5, 4))
    feats_c = np.random.random((4, 4))
    with kaldi_open('ark:' + temp_file_1_name, 'dm', 'w') as feats_in_writer:
        feats_in_writer.write('A', feats_a)
        feats_in_writer.write('B', feats_b)
        feats_in_writer.write('C', feats_c)
    with kaldi_open('ark:' + temp_file_2_name, 'i', 'w') as len_in_writer:
        len_in_writer.write('A', 9)
        len_in_writer.write('B', 7)
        len_in_writer.write('C', 4)
    ret_code = command_line.normalize_feat_lens([
        'ark:' + temp_file_1_name,
        'ark:' + temp_file_2_name,
        'ark:' + temp_file_3_name,
        '--type=dm',
        '--pad-mode=zero',
    ])
    assert ret_code == 0
    with kaldi_open('ark:' + temp_file_3_name, 'dm') as feats_out_reader:
        out_a = next(feats_out_reader)
        out_b = next(feats_out_reader)
        out_c = next(feats_out_reader)
        assert out_a.shape == (9, 4)
        assert np.allclose(out_a, feats_a[:9])
        assert out_b.shape == (7, 4)
        assert np.allclose(out_b[:5], feats_b)
        assert np.allclose(out_b[5:], 0)
        assert out_c.shape == (4, 4)
        assert np.allclose(out_c, feats_c)
    ret_code = command_line.normalize_feat_lens([
        'ark:' + temp_file_1_name,
        'ark:' + temp_file_2_name,
        'ark:' + temp_file_3_name,
        '--type=dm',
        '--tolerance=1',
        '--strict=true',
    ])
    assert ret_code == 1
Example #6
def test_write_table_to_torch_dir(temp_dir):
    import torch
    out_dir = os.path.join(temp_dir, 'test_write_table_to_torch_dir')
    os.makedirs(out_dir)
    rwspecifier = 'ark:' + os.path.join(out_dir, 'table.ark')
    a = torch.rand(10, 4)
    b = torch.rand(5, 2)
    c = torch.rand(5, 100)
    with kaldi_open(rwspecifier, 'bm', mode='w') as table:
        table.write('a', a.numpy())
        table.write('b', b.numpy())
        table.write('c', c.numpy())
    assert not command_line.write_table_to_torch_dir([rwspecifier, out_dir])
    assert torch.allclose(c, torch.load(os.path.join(out_dir, 'c.pt')))
    assert torch.allclose(b, torch.load(os.path.join(out_dir, 'b.pt')))
    assert torch.allclose(a, torch.load(os.path.join(out_dir, 'a.pt')))
Example #7
def test_write_torch_dir_to_table(temp_dir):
    import torch
    in_dir = os.path.join(temp_dir, 'test_write_torch_dir_to_table')
    rwspecifier = 'ark:' + os.path.join(in_dir, 'table.ark')
    os.makedirs(in_dir)
    a = torch.rand(5, 4)
    b = torch.rand(4, 3)
    c = torch.rand(3, 2)
    torch.save(a, os.path.join(in_dir, 'a.pt'))
    torch.save(b, os.path.join(in_dir, 'b.pt'))
    torch.save(c, os.path.join(in_dir, 'c.pt'))
    assert not command_line.write_torch_dir_to_table([in_dir, rwspecifier])
    with kaldi_open(rwspecifier, 'bm') as table:
        keys, vals = zip(*table.items())
        keys = tuple(keys)
        vals = tuple(vals)
    assert keys == ('a', 'b', 'c')
    assert len(vals) == 3
    for dval, tval in zip((a, b, c), vals):
        assert torch.allclose(dval, torch.from_numpy(tval))
Example #8
def test_write_pickle_to_table(values, temp_file_1_name, temp_file_2_name):
    if len(values):
        kaldi_dtype = infer_kaldi_data_type(values[0]).value
    else:
        kaldi_dtype = 'bm'
    with open(temp_file_1_name, 'wb') as pickle_file:
        for num, value in enumerate(values):
            pickle.dump((str(num), value), pickle_file)
    ret_code = command_line.write_pickle_to_table(
        [temp_file_1_name, 'ark:' + temp_file_2_name, '-o', kaldi_dtype])
    assert ret_code == 0
    num_entries = 0
    with kaldi_open(
            'ark:' + temp_file_2_name, kaldi_dtype, 'r') as kaldi_reader:
        for key, value in kaldi_reader.items():
            num_entries = int(key) + 1
            try:
                values[num_entries - 1].dtype
                assert np.allclose(value, values[num_entries - 1])
            except AttributeError:
                assert value == values[num_entries - 1]
    assert num_entries == len(values)
Example #9
def compute_error_rate(args=None):
    '''Compute error rates between reference and hypothesis token vectors

    Two common error rates in speech are the word (WER) and phone (PER),
    though the computation is the same. Given a reference and hypothesis
    sequence, the error rate is

    >>> error_rate = (
    ...     substitutions + insertions + deletions) / ref_tokens * 100

    where the number of substitutions (e.g. ``A B C -> A D C``),
    deletions (e.g. ``A B C -> A C``), and insertions (e.g.
    ``A B C -> A D B C``) are determined by Levenshtein distance.
    '''
    logger = logging.getLogger(sys.argv[0])
    if not logger.handlers:
        logger.addHandler(logging.StreamHandler())
    register_logger_for_kaldi(sys.argv[0])
    options = _compute_error_rate_parse_args(args, logger)
    global_edit = 0
    global_token_count = 0
    global_sents = 0
    global_processed = 0
    inserts = dict()
    deletes = dict()
    subs = dict()
    totals = dict()

    def _err_on_utt_id(utt_id, missing_rxspecifier):
        msg = "Utterance '{}' absent in '{}'".format(
            utt_id, missing_rxspecifier)
        if options.strict:
            logger.error(msg)
            return 1
        else:
            logger.warning(msg)
            return 0

    return_tables = (
        options.print_tables or not options.include_inserts_in_cost)
    with kaldi_open(options.ref_rspecifier, 'tv') as ref_table, \
            kaldi_open(options.hyp_rspecifier, 'tv') as hyp_table:
        while not ref_table.done() and not hyp_table.done():
            global_sents += 1
            if ref_table.key() > hyp_table.key():
                if _err_on_utt_id(hyp_table.key(), options.ref_rspecifier):
                    return 1
                hyp_table.move()
                continue
            elif hyp_table.key() > ref_table.key():
                if _err_on_utt_id(ref_table.key(), options.hyp_rspecifier):
                    return 1
                ref_table.move()
                continue
            else:
                logger.debug('Processing {}: ref [{}] hyp [{}]'.format(
                    ref_table.key(),
                    ' '.join(ref_table.value()),
                    ' '.join(hyp_table.value())))
                global_token_count += len(ref_table.value())
                res = kaldi_eval_util.edit_distance(
                    ref_table.value(), hyp_table.value(),
                    return_tables=return_tables,
                    insertion_cost=options.insertion_cost,
                    deletion_cost=options.deletion_cost,
                    substitution_cost=options.substitution_cost,
                )
                if return_tables:
                    global_edit += res[0]
                    for global_dict, utt_dict in zip(
                            (inserts, deletes, subs, totals), res[1:]):
                        for token in ref_table.value() + hyp_table.value():
                            global_dict.setdefault(token, 0)
                        for token, count in utt_dict.items():
                            global_dict[token] += count
                else:
                    global_edit += res
            global_processed += 1
            ref_table.move()
            hyp_table.move()
        while not ref_table.done():
            if _err_on_utt_id(ref_table.key(), options.hyp_rspecifier):
                return 1
            global_sents += 1
            ref_table.move()
        while not hyp_table.done():
            if _err_on_utt_id(hyp_table.key(), options.ref_rspecifier):
                return 1
            global_sents += 1
            hyp_table.move()
    if options.out_path is None:
        out_file = sys.stdout
    else:
        out_file = open(options.out_path, 'w')
    print(
        "Processed {}/{}.".format(global_processed, global_sents),
        file=out_file, end=' '
    )
    if not options.include_inserts_in_cost:
        global_edit -= sum(inserts.values())
    if options.report_accuracy:
        print(
            'Accuracy: {:.2f}%'.format(
                (1 - global_edit / global_token_count) * 100),
            file=out_file,
        )
    else:
        print(
            'Error rate: {:.2f}%'.format(
                global_edit / global_token_count * 100),
            file=out_file,
        )
    if options.print_tables:
        print(
            "Total insertions: {}, deletions: {}, substitutions: {}".format(
                sum(inserts.values()), sum(deletes.values()),
                sum(subs.values())),
            file=out_file,
        )
        print("", file=out_file)
        tokens = list(set(inserts) | set(deletes) | set(subs))
        tokens.sort()
        token_len = max(max(len(token) for token in tokens), 5)
        max_count = max(
            chain(inserts.values(), deletes.values(), subs.values()))
        max_count_len = int(log10(max_count) + 1)
        divider_str = '+' + ('-' * (token_len + 1))
        divider_str += ('+' + ('-' * (max_count_len + 9))) * 4
        divider_str += '+'
        format_str = '|{{:<{}}}|'.format(token_len + 1)
        format_str += 4 * '{{:>{}}}({{:05.2f}}%)|'.format(max_count_len + 1)
        print(
            '|{2:<{0}}|{3:>{1}}(%)|{4:>{1}}(%)|{5:>{1}}(%)|{6:>{1}}(%)|'
            ''.format(
                token_len + 1, max_count_len + 6, 'token', 'inserts',
                'deletes', 'subs', 'errs',
            ),
            file=out_file,
        )
        print(divider_str, file=out_file)
        print(divider_str, file=out_file)
        for token in tokens:
            i, d, s = inserts[token], deletes[token], subs[token]
            t = totals[token]
            print(
                format_str.format(
                    token,
                    i, i / t * 100,
                    d, d / t * 100,
                    s, s / t * 100,
                    i + d + s, (i + d + s) / t * 100,
                ),
                file=out_file
            )
            print(divider_str, file=out_file)
    return 0
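
A worked instance of the docstring formula, with hypothetical token sequences:
for reference ``A B C D`` and hypothesis ``A X C``, the Levenshtein alignment
gives one substitution (B -> X), one deletion (D), and no insertions over four
reference tokens:

substitutions, insertions, deletions, ref_tokens = 1, 0, 1, 4
error_rate = (substitutions + insertions + deletions) / ref_tokens * 100
assert error_rate == 50.0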
Example #10
def write_table_to_pickle(args=None):
    '''Write a kaldi table to pickle file(s)

    The inverse is write-pickle-to-table
    '''
    logger = logging.getLogger(sys.argv[0])
    if not logger.handlers:
        logger.addHandler(logging.StreamHandler())
    register_logger_for_kaldi(logger)
    try:
        options = _write_table_to_pickle_parse_args(args, logger)
    except SystemExit as ex:
        return ex.code
    out_type = options.out_type
    if out_type is None:
        if options.in_type.is_floating_point:
            if options.in_type.is_double:
                out_type = np.float64
            else:
                out_type = np.float32
        else:
            out_type = str
    from six.moves import cPickle as pickle
    try:
        reader = kaldi_open(options.rspecifier, options.in_type, 'r')
        if options.value_out.endswith('.gz'):
            import gzip
            value_out = gzip.open(options.value_out, 'wb')
        else:
            value_out = open(options.value_out, 'wb')
        if options.key_out:
            if options.key_out.endswith('.gz'):
                import gzip
                key_out = gzip.open(options.key_out, 'wb')
            else:
                key_out = open(options.key_out, 'wb')
        else:
            key_out = None
    except IOError as error:
        logger.error(str(error), exc_info=True)
        return 1
    num_entries = 0
    try:
        for key, value in reader.items():
            num_entries += 1
            if not np.issubdtype(out_type, np.dtype(str).type):
                value = value.astype(out_type)
            if key_out:
                pickle.dump(value, value_out)
                pickle.dump(key, key_out)
            else:
                pickle.dump((key, value), value_out)
            if num_entries % 10 == 0:
                logger.info('Processed {} entries'.format(num_entries))
            logger.debug('Processed key {}'.format(key))
    except (IOError, ValueError) as error:
        logger.error(str(error), exc_info=True)
        return 1
    finally:
        value_out.close()
        if key_out:
            key_out.close()
    if num_entries == 0:
        logger.warning("No entries were written (table was empty)")
    else:
        logger.info("Wrote {} entries".format(num_entries))
    return 0
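
Because a '.gz' suffix on the output path switches the function to gzip, the
pickle stream reads back the same way it was written. A sketch, assuming a
table was dumped to a hypothetical 'feats.pkl.gz' without a separate key file:

import gzip
from six.moves import cPickle as pickle

with gzip.open('feats.pkl.gz', 'rb') as pickle_in:
    while True:
        try:
            key, value = pickle.load(pickle_in)
        except EOFError:
            break
        # value is a numpy array for matrix/vector tables
        print(key, value.shape)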
Example #11
def write_torch_dir_to_table(args=None):
    '''Write a data directory containing PyTorch data files to a Kaldi table

    Reads from a folder in the format:

    ::

        folder/
          <file_prefix><key_1><file_suffix>
          <file_prefix><key_2><file_suffix>
          ...

    where each file contains a PyTorch tensor. The contents of the file
    ``<file_prefix><key_1><file_suffix>`` will be written as a value in
    a Kaldi table with key ``<key_1>``.
    '''
    logger = logging.getLogger(sys.argv[0])
    if not logger.handlers:
        logger.addHandler(logging.StreamHandler())
    register_logger_for_kaldi(logger)
    try:
        options = _write_torch_dir_to_table_parse_args(args, logger)
    except SystemExit as ex:
        return ex.code
    if not os.path.isdir(options.dir):
        print("'{}' is not a directory".format(options.dir), file=sys.stderr)
        return 1
    import torch
    is_bool = False
    if options.out_type in {
            enums.KaldiDataType.BaseMatrix, enums.KaldiDataType.BaseVector,
            enums.KaldiDataType.WaveMatrix, enums.KaldiDataType.Base,
            enums.KaldiDataType.BasePairVector
    }:
        if options.out_type.is_double:
            torch_type = torch.double
        else:
            torch_type = torch.float
    elif options.out_type in {
            enums.KaldiDataType.FloatMatrix, enums.KaldiDataType.FloatVector
    }:
        torch_type = torch.float
    elif options.out_type in {
            enums.KaldiDataType.DoubleMatrix, enums.KaldiDataType.Double
    }:
        torch_type = torch.double
    elif options.out_type in {
            enums.KaldiDataType.Int32, enums.KaldiDataType.Int32Vector,
            enums.KaldiDataType.Int32VectorVector
    }:
        torch_type = torch.int
    elif options.out_type == enums.KaldiDataType.Boolean:
        torch_type = torch.uint8
        is_bool = True
    else:
        print('Do not know how to convert {} from torch type'.format(
            options.out_type),
              file=sys.stderr)
        return 1
    neg_fsl = -len(options.file_suffix)
    if not neg_fsl:
        neg_fsl = None
    fpl = len(options.file_prefix)
    utt_ids = sorted(
        os.path.basename(x)[fpl:neg_fsl] for x in os.listdir(options.dir) if
        x.startswith(options.file_prefix) and x.endswith(options.file_suffix))
    with kaldi_open(options.wspecifier, options.out_type, mode='w') as table:
        for utt_id in utt_ids:
            val = torch.load(
                os.path.join(
                    options.dir,
                    options.file_prefix + utt_id + options.file_suffix))
            val = val.cpu().type(torch_type).numpy()
            if is_bool:
                val = bool(val)  # make sure val is a scalar!
            table.write(utt_id, val)
    return 0
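
A usage sketch of the directory layout the docstring describes, mirroring the
test in Example #7. The '--file-prefix' and '--file-suffix' flag spellings are
assumed to map to <file_prefix> and <file_suffix>; all paths are hypothetical:

import os
import torch

os.makedirs('feats_dir', exist_ok=True)
torch.save(torch.rand(5, 4), os.path.join('feats_dir', 'feat_utt1.pt'))
torch.save(torch.rand(3, 4), os.path.join('feats_dir', 'feat_utt2.pt'))
write_torch_dir_to_table([
    'feats_dir', 'ark:feats.ark',
    '--file-prefix=feat_', '--file-suffix=.pt',
])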
Example #12
def write_table_to_torch_dir(args=None):
    '''Write a Kaldi table to a series of PyTorch data files in a directory

    Writes to a folder in the format:

    ::

        folder/
          <file_prefix><key_1><file_suffix>
          <file_prefix><key_2><file_suffix>
          ...

    The contents of the file ``<file_prefix><key_1><file_suffix>`` will be
    a PyTorch tensor corresponding to the entry in the table for ``<key_1>``.
    '''
    logger = logging.getLogger(sys.argv[0])
    if not logger.handlers:
        logger.addHandler(logging.StreamHandler())
    register_logger_for_kaldi(logger)
    try:
        options = _write_table_to_torch_dir_parse_args(args, logger)
    except SystemExit as ex:
        return ex.code
    out_type = options.out_type
    if out_type is None:
        if options.in_type in {
                enums.KaldiDataType.BaseMatrix, enums.KaldiDataType.BaseVector,
                enums.KaldiDataType.WaveMatrix, enums.KaldiDataType.Base,
                enums.KaldiDataType.BasePairVector
        }:
            if options.in_type.is_double:
                out_type = 'double'
            else:
                out_type = 'float'
        elif options.in_type in {
                enums.KaldiDataType.FloatMatrix,
                enums.KaldiDataType.FloatVector
        }:
            out_type = 'float'
        elif options.in_type in {
                enums.KaldiDataType.DoubleMatrix, enums.KaldiDataType.Double
        }:
            out_type = 'double'
        elif options.in_type in {
                enums.KaldiDataType.Int32, enums.KaldiDataType.Int32Vector,
                enums.KaldiDataType.Int32VectorVector
        }:
            out_type = 'int'
        elif options.in_type == enums.KaldiDataType.Boolean:
            out_type = 'byte'
        else:
            print('Do not know how to convert {} to torch type'.format(
                options.in_type),
                  file=sys.stderr)
            return 1
    import torch
    if out_type == 'float':
        out_type = torch.float
    elif out_type == 'double':
        out_type = torch.double
    elif out_type == 'half':
        out_type = torch.half
    elif out_type == 'byte':
        out_type = torch.uint8
    elif out_type == 'char':
        out_type = torch.int8
    elif out_type == 'short':
        out_type = torch.short
    elif out_type == 'int':
        out_type = torch.int
    elif out_type == 'long':
        out_type = torch.long
    try:
        os.makedirs(options.dir)
    except FileExistsError:
        pass
    with kaldi_open(options.rspecifier, options.in_type) as table:
        for key, value in table.items():
            value = torch.tensor(value).type(out_type)
            torch.save(
                value,
                os.path.join(options.dir,
                             options.file_prefix + key + options.file_suffix))
    return 0
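
The string-to-dtype dispatch above could equally be written as a lookup table;
a behavior-equivalent sketch using only the names the chain already handles
(unknown strings pass through unchanged, as in the original):

import torch

TORCH_DTYPES = {
    'float': torch.float, 'double': torch.double, 'half': torch.half,
    'byte': torch.uint8, 'char': torch.int8, 'short': torch.short,
    'int': torch.int, 'long': torch.long,
}
out_type = TORCH_DTYPES.get(out_type, out_type)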
Example #13
def _write_pickle_to_table_key_value(options, logger):
    try:
        logger.info('Opening {}'.format(options.value_in))
        if options.value_in.endswith('.gz'):
            import gzip
            value_in = gzip.open(options.value_in, 'rb')
        else:
            value_in = open(options.value_in, 'rb')
        logger.info('Opening {}'.format(options.key_in))
        if options.key_in.endswith('.gz'):
            import gzip
            key_in = gzip.open(options.key_in, 'rb')
        else:
            key_in = open(options.key_in, 'rb')
    except IOError as error:
        logger.error(str(error), exc_info=True)
        return 1
    try:
        value = pickle.load(value_in)
    except pickle.UnpicklingError as error:
        value_in.close()
        key_in.close()
        logger.error(str(error), exc_info=True)
        return 1
    except EOFError:
        value_in.close()
        try:
            pickle.load(key_in)
            logger.error('Number of keys (1) and values (0) do not match')
            return 1
        except EOFError:
            # both key and value streams were empty
            pass
        except pickle.UnpicklingError as error:
            logger.error(str(error), exc_info=True)
            return 1
        finally:
            key_in.close()
        return _write_pickle_to_table_empty(options.wspecifier, logger)
    try:
        key = pickle.load(key_in)
    except EOFError:
        value_in.close()
        key_in.close()
        logger.error('Number of keys (0) and values (1) do not match')
        return 1
    except pickle.UnpicklingError as error:
        value_in.close()
        key_in.close()
        logger.error(str(error), exc_info=True)
        return 1
    out_type = options.out_type
    try:
        logger.info('Opening {}'.format(options.wspecifier))
        writer = kaldi_open(options.wspecifier, out_type, 'w')
    except IOError as error:
        value_in.close()
        key_in.close()
        logger.error(str(error), exc_info=True)
        return 1
    num_entries = 0
    try:
        while True:
            if out_type.is_floating_point:
                if out_type.is_double:
                    try:
                        value = value.astype(np.float64, copy=False)
                    except AttributeError:
                        pass  # will happen implicitly
                else:
                    try:
                        value = value.astype(np.float32, copy=False)
                    except AttributeError:
                        pass  # will happen implicitly
            writer.write(key, value)
            num_entries += 1
            if num_entries % 10 == 0:
                logger.info('Processed {} entries'.format(num_entries))
            logger.debug('Processed key {}'.format(key))
            key = pickle.load(key_in)
            value = pickle.load(value_in)
    except EOFError:
        pass
    except (IOError, ValueError, TypeError, pickle.UnpicklingError) as error:
        value_in.close()
        key_in.close()
        logger.error(str(error), exc_info=True)
        return 1
    try:
        pickle.load(value_in)
        value_in.close()
        key_in.close()
        logger.error('Number of keys ({}) and values ({}) do not match'.format(
            num_entries, num_entries + 1))
        return 1
    except EOFError:
        pass
    except (IOError, pickle.UnpicklingError) as error:
        value_in.close()
        key_in.close()
        logger.error(str(error), exc_info=True)
        return 1
    try:
        pickle.load(key_in)
        value_in.close()
        key_in.close()
        logger.error('Number of keys ({}) and values ({}) do not match'.format(
            num_entries + 1, num_entries))
        return 1
    except EOFError:
        pass
    except (IOError, pickle.UnpicklingError) as error:
        logger.error(str(error), exc_info=True)
        return 1
    finally:
        value_in.close()
        key_in.close()
    logger.info("Wrote {} entries".format(num_entries))
    return 0
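
A sketch of producing the paired streams this reader consumes: one pickle per
key in the key file and one pickle per value in the value file, written in the
same order (the file names are hypothetical):

import numpy as np
from six.moves import cPickle as pickle

keys = ['utt1', 'utt2']
values = [np.zeros((2, 3)), np.ones((4, 3))]
with open('keys.pkl', 'wb') as key_out, \
        open('values.pkl', 'wb') as value_out:
    for key, value in zip(keys, values):
        pickle.dump(key, key_out)
        pickle.dump(value, value_out)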
Example #14
def save(self, wxfilename):
    '''Save statistics to an extended file name'''
    with kaldi_open(wxfilename, mode='w') as stats_file:
        stats_file.write(self.stats, 'bm')
Example #15
def normalize_feat_lens(args=None):
    '''Ensure features match some reference lengths

    Incoming features are either clipped or padded to match
    reference lengths (stored as an int32 table), if they are within
    tolerance.
    '''
    logger = logging.getLogger(sys.argv[0])
    if not logger.handlers:
        logger.addHandler(logging.StreamHandler())
    register_logger_for_kaldi(sys.argv[0])
    options = _normalize_feat_lens_parse_args(args, logger)
    if options.pad_mode == 'zero':
        options.pad_mode = 'constant'
    feats_in = kaldi_open(options.feats_in_rspecifier, options.type, mode='r')
    len_in = kaldi_open(options.len_in_rspecifier, 'i', mode='r+')
    feats_out = kaldi_open(options.feats_out_wspecifier,
                           options.type,
                           mode='w')
    total_utts = 0
    processed_utts = 0
    for utt_id, feats in feats_in.items():
        total_utts += 1
        if utt_id not in len_in:
            msg = "Utterance '{}' absent in '{}'".format(
                utt_id, options.len_in_rspecifier)
            if options.strict:
                logger.error(msg)
                return 1
            else:
                logger.warning(msg)
                continue
        exp_feat_len = len_in[utt_id]
        act_feat_len = len(feats)
        logger.debug('{} exp len: {} act len: {}'.format(
            utt_id, exp_feat_len, act_feat_len))
        if act_feat_len < exp_feat_len:
            if act_feat_len < exp_feat_len - options.tolerance:
                msg = '{} has feature length {}, which is below the '
                msg += 'tolerance ({}) of the expected length {}'
                msg = msg.format(utt_id, act_feat_len, options.tolerance,
                                 exp_feat_len)
                if options.strict:
                    logger.error(msg)
                    return 1
                else:
                    logger.warning(msg)
                    continue
            # for matrices or vectors, this cast shouldn't be necessary.
            # If the user tries some special type like token vectors,
            # however, this *might* work as intended
            feats = np.array(feats, copy=False)
            pad_list = [(0, 0)] * len(feats.shape)
            if options.side == 'right':
                pad_list[0] = (0, exp_feat_len - act_feat_len)
            elif options.side == 'left':
                pad_list[0] = (exp_feat_len - act_feat_len, 0)
            else:
                pad_list[0] = ((exp_feat_len - act_feat_len) // 2,
                               (exp_feat_len - act_feat_len + 1) // 2)
            feats = np.pad(feats, pad_list, options.pad_mode)
        elif act_feat_len > exp_feat_len:
            if act_feat_len > exp_feat_len + options.tolerance:
                msg = '{} has feature length {}, which is above the '
                msg += 'tolerance ({}) of the expected length {}'
                msg = msg.format(utt_id, act_feat_len, options.tolerance,
                                 exp_feat_len)
                if options.strict:
                    logger.error(msg)
                    return 1
                else:
                    logger.warning(msg)
                    continue
            if options.side == 'right':
                feats = feats[:exp_feat_len]
            elif options.side == 'left':
                feats = feats[act_feat_len - exp_feat_len:]
            else:
                # split the surplus between the ends; the odd frame
                # comes off the right, mirroring the padding above
                feats = feats[
                    (act_feat_len - exp_feat_len) // 2:
                    act_feat_len - (act_feat_len - exp_feat_len + 1) // 2]
        feats_out.write(utt_id, feats)
        processed_utts += 1
    logger.info('Processed {}/{} utterances'.format(processed_utts,
                                                    total_utts))
    return 0
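
For the 'center' branches, the surplus or deficit of frames is split between
the two ends, with the odd frame going to (or coming off) the right side. A
small check of the clipping arithmetic with hypothetical lengths:

import numpy as np

feats = np.arange(20).reshape(10, 2)
exp_feat_len, act_feat_len = 7, 10
diff = act_feat_len - exp_feat_len  # three frames too many
clipped = feats[diff // 2:act_feat_len - (diff + 1) // 2]  # 1 left, 2 right
assert clipped.shape == (exp_feat_len, 2)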