def decode_cnn_ctc(args=None):
    '''Decode CNN w/ CTC using kaldi data tables'''
    logger = logging.getLogger(sys.argv[0])
    if not logger.handlers:
        logger.addHandler(logging.StreamHandler())
    register_logger_for_kaldi(sys.argv[0])
    options = _decode_cnn_ctc_parse_args(args, logger)
    logger.log(9, 'Parsed options')
    id2label_map = dict()
    with open(options.label_to_id_map_path) as file_obj:
        for line in file_obj:
            label, idee = line.strip().split()
            idee = int(idee)
            if idee < 0:
                logger.error('All label ids must be nonnegative')
                return 1
            id2label_map[idee] = label
    if (len(id2label_map) + 1) != options.num_labels:
        logger.error('Expected {} labels in id to label map, got {}'.format(
            options.num_labels - 1, len(id2label_map)))
        return 1
    for idee in range(options.num_labels - 1):
        if idee not in id2label_map:
            logger.error('label to id map missing id: {}'.format(idee))
            return 1
    logger.log(9, 'Loaded label to id map')
    model_config = ModelConfig(**vars(options))
    decode_config = DecodeConfig(**vars(options))
    decode_data = DecodeData(
        options.data_rspecifier,
        delta_order=model_config.delta_order,
        cmvn_rxfilename=model_config.cmvn_rxfilename,
        batch_size=decode_config.batch_size,
    )
    total_batches = len(decode_data)
    labels_out = io_open(options.output_wspecifier, 'tv', mode='w')
    logger.log(9, 'Set up eval data and opened label output file')
    with redirect_stdout_to_stderr():
        logger.log(9, 'Creating model')
        from pydrobert.mol.model import ConvCTC
        model = ConvCTC(model_config)
        logger.log(9, 'Beginning decoding')
        batches_decoded = 0
        logger.log(9, '000/{:03d} batches decoded'.format(total_batches))
        for label_batch in model.decode_generator(decode_config, decode_data):
            if decode_data.batch_size:
                for key, label_ids in label_batch:
                    labels_out.write(
                        key, tuple(id2label_map[idee] for idee in label_ids))
            else:
                # unbatched: label_batch is a single (key, label_ids) pair,
                # so map ids back to labels as in the batched branch
                labels_out.write(
                    label_batch[0],
                    tuple(id2label_map[idee] for idee in label_batch[1]))
            batches_decoded += 1
            if batches_decoded % max(1, total_batches // 10) == 0:
                logger.log(
                    9, '{:03d}/{:03d} batches decoded'.format(
                        batches_decoded, total_batches))
    logger.info('Done decoding')
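# For reference, the label-to-id map read above is a plain-text file of
# "<label> <id>" pairs, one per line, whose ids must cover 0 through
# num_labels - 2 (the leftover id is presumably reserved for the CTC
# blank). A hypothetical three-label map, with ids contiguous from zero:
#
#   aa 0
#   ae 1
#   sil 2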
@pytest.fixture
def registered_regular_logger():
    logger_name = ''.join(chr(x + 97) for x in np.random.choice(26, 100))
    ret_logger = logging.getLogger(logger_name)
    s_stream = StringIO()
    handler = logging.StreamHandler(s_stream)
    ret_logger.addHandler(handler)
    register_logger_for_kaldi(logger_name)
    yield ret_logger
    deregister_logger_for_kaldi(logger_name)
    # remove the handler itself (removeHandler silently ignores objects,
    # such as the raw stream, that were never registered as handlers)
    ret_logger.removeHandler(handler)
def test_do_not_callback_unregistered(kaldi_logger):
    kaldi_logger.setLevel(logging.WARNING)
    verbose_log(-1, 'should see this')
    deregister_logger_for_kaldi(kaldi_logger.name)
    verbose_log(-1, 'should not see this')
    register_logger_for_kaldi('bingobangobongo')
    verbose_log(-1, 'still nothing')
    register_logger_for_kaldi(kaldi_logger.name)
    verbose_log(-1, 'but see this')
    s_stream = kaldi_logger.handlers[-1].stream
    assert 'should see this\nbut see this\n' == s_stream.getvalue()
@pytest.fixture
def kaldi_logger():
    logger_name = ''.join(chr(x + 97) for x in np.random.choice(26, 100))
    old_class = logging.getLoggerClass()
    logging.setLoggerClass(KaldiLogger)
    ret_logger = logging.getLogger(logger_name)
    logging.setLoggerClass(old_class)
    s_stream = StringIO()
    ret_logger.addHandler(logging.StreamHandler(s_stream))
    register_logger_for_kaldi(logger_name)
    yield ret_logger
    deregister_logger_for_kaldi(logger_name)
    # iterate over a copy: removeHandler mutates the handlers list
    for handler in list(ret_logger.handlers):
        ret_logger.removeHandler(handler)
def write_pickle_to_table(args=None):
    '''Write pickle file(s) contents to a table

    The inverse is write-table-to-pickle
    '''
    logger = logging.getLogger(sys.argv[0])
    if not logger.handlers:
        logger.addHandler(logging.StreamHandler())
    register_logger_for_kaldi(logger)
    try:
        options = _write_pickle_to_table_parse_args(args, logger)
    except SystemExit as ex:
        return ex.code
    if options.key_in is None:
        return _write_pickle_to_table_value_only(options, logger)
    else:
        return _write_pickle_to_table_key_value(options, logger)
def alt_compute_cmvn_stats(args=None):
    '''Python-based code for CMVN statistics computation

    Used for debugging
    '''
    logger = logging.getLogger(sys.argv[0])
    if not logger.handlers:
        logger.addHandler(logging.StreamHandler())
    register_logger_for_kaldi(sys.argv[0])
    options = _alt_compute_cmvn_stats_parse_args(args, logger)
    cmvn = CMVNCalculator()
    feat_table = io_open(options.feats_in, 'bm')
    num_utts = 0
    for feats in feat_table:
        cmvn.accumulate(feats)
        num_utts += 1
    logger.info('Accumulated stats for {} utterances'.format(num_utts))
    cmvn.save(options.cmvn_stats_out)
    logger.info('Wrote stats to {}'.format(options.cmvn_stats_out))
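# CMVNCalculator is defined elsewhere; as a rough sketch, CMVN statistics
# are just a frame count plus per-dimension sums and sums of squares, from
# which a mean and variance are recovered at application time. This toy
# class is an assumption about the shape of the computation, not the
# actual API:
import numpy as np


class _ToyCMVN(object):
    '''Illustrative CMVN: accumulate stats, then mean/variance normalize'''

    def __init__(self, num_feats):
        self.count = 0
        self.sum = np.zeros(num_feats)
        self.sumsq = np.zeros(num_feats)

    def accumulate(self, feats):
        # feats: (frames, num_feats)
        self.count += len(feats)
        self.sum += feats.sum(0)
        self.sumsq += (feats ** 2).sum(0)

    def apply(self, feats):
        mean = self.sum / self.count
        std = np.sqrt(np.maximum(self.sumsq / self.count - mean ** 2, 0))
        return (feats - mean) / np.maximum(std, 1e-10)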
def alt_apply_cmvn(args=None):
    '''Python-based code for CMVN application

    Used for debugging
    '''
    logger = logging.getLogger(sys.argv[0])
    if not logger.handlers:
        logger.addHandler(logging.StreamHandler())
    register_logger_for_kaldi(sys.argv[0])
    options = _alt_apply_cmvn_parse_args(args, logger)
    cmvn = CMVNCalculator(options.cmvn_stats_in)
    feats_in = io_open(options.feats_in, 'bm')
    feats_out = io_open(options.feats_out, 'bm', mode='w')
    num_utts = 0
    for utt_id, feats in feats_in.items():
        feats = cmvn.apply(feats, in_place=True)
        feats_out.write(utt_id, feats)
        num_utts += 1
    logger.info('Applied CMVN to {} utterances'.format(num_utts))
def alt_add_deltas(args=None):
    '''Python-based code for adding deltas

    Used for debugging
    '''
    logger = logging.getLogger(sys.argv[0])
    if not logger.handlers:
        logger.addHandler(logging.StreamHandler())
    register_logger_for_kaldi(sys.argv[0])
    options = _alt_add_deltas_parse_args(args, logger)
    feats_in = io_open(options.feats_in, 'bm')
    feats_out = io_open(options.feats_out, 'bm', mode='w')
    num_utts = 0
    for utt_id, feats in feats_in.items():
        feats = calculate_deltas(feats, options.delta_order)
        feats_out.write(utt_id, feats)
        num_utts += 1
    logger.info('Added {} deltas to {} utterances'.format(
        options.delta_order, num_utts))
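# calculate_deltas is imported from elsewhere; Kaldi-style deltas are
# least-squares slopes over a small symmetric window, applied recursively
# for higher orders. A rough NumPy sketch under that assumption (the
# window size and edge handling here are illustrative):
import numpy as np


def _toy_deltas(feats, delta_order, window=2):
    '''Append delta features: regression over +/- `window` frames'''
    out = [np.asarray(feats, dtype=float)]
    denom = 2 * sum(i ** 2 for i in range(1, window + 1))
    for _ in range(delta_order):
        prev = out[-1]
        # repeat edge frames so every frame has a full window
        padded = np.pad(prev, ((window, window), (0, 0)), 'edge')
        delta = sum(
            i * (padded[window + i:window + i + len(prev)] -
                 padded[window - i:window - i + len(prev)])
            for i in range(1, window + 1)) / denom
        out.append(delta)
    # shape: (frames, num_feats * (delta_order + 1))
    return np.concatenate(out, axis=1)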
def compute_loss(args=None):
    '''Compute the average loss over the data'''
    logger = logging.getLogger(sys.argv[0])
    if not logger.handlers:
        logger.addHandler(logging.StreamHandler())
    register_logger_for_kaldi(sys.argv[0])
    options = _compute_loss_parse_args(args, logger)
    logger.log(9, 'Parsed options')
    label2id_map = dict()
    with open(options.label_to_id_map_path) as file_obj:
        for line in file_obj:
            label, idee = line.strip().split()
            idee = int(idee)
            if idee < 0:
                logger.error('All label ids must be nonnegative')
                return 1
            label2id_map[label] = idee
    if (len(label2id_map) + 1) != options.num_labels:
        logger.error('Expected {} labels in label_to_id_map, got {}'.format(
            options.num_labels - 1, len(label2id_map)))
        return 1
    for idee in range(options.num_labels - 1):
        if idee not in label2id_map.values():
            logger.error('label to id map missing id: {}'.format(idee))
            return 1
    model_config = ModelConfig(**vars(options))
    decode_config = DecodeConfig(**vars(options))
    eval_data = ValidationData(
        options.data_rspecifier,
        options.labels_rspecifier,
        label2id_map,
        batch_size=decode_config.batch_size,
        delta_order=model_config.delta_order,
        cmvn_rxfilename=model_config.cmvn_rxfilename,
    )
    with redirect_stdout_to_stderr():
        logger.log(9, 'Creating model')
        from pydrobert.mol.model import ConvCTC
        model = ConvCTC(model_config)
        logger.log(9, 'Calculating loss')
        loss = model.evaluate_generator(decode_config, eval_data)
        logger.log(9, 'Calculated loss')
    # print outside the redirect so the loss actually lands on stdout
    print('Model loss:', loss)
def find_best_model_from_log(args=None):
    '''Find the best model from training log and return its path on stdout

    In a given training stage, the 'best' model is the one that minimizes
    or maximizes the target quantity (dependent on 'mode' and 'monitored').
    '''
    logger = logging.getLogger(sys.argv[0])
    if not logger.handlers:
        logger.addHandler(logging.StreamHandler())
    register_logger_for_kaldi(sys.argv[0])
    options = _find_best_model_from_log_parse_args(args, logger)
    logger.log(9, 'Parsed options')
    with open(options.csv_path, mode='r') as csv_file:
        csv_reader = DictReader(csv_file, delimiter=options.csv_delimiter)
        fields = csv_reader.fieldnames
        if 'model_path' not in fields:
            logger.error('"model_path" must be a field in the csv')
            return 1
        if options.monitored is None:
            if 'val_acc' in fields:
                options.monitored = 'val_acc'
            elif 'val_loss' in fields:
                options.monitored = 'val_loss'
            elif 'acc' in fields:
                options.monitored = 'acc'
            elif 'loss' in fields:
                options.monitored = 'loss'
            else:
                logger.error(
                    'Unable to find suitable value to monitor. Set '
                    '--monitored manually')
                return 1
            logger.log(
                9, 'Using "{}" as monitored value'.format(options.monitored))
        elif options.monitored not in fields:
            logger.error('Monitored value "{}" not in CSV'.format(
                options.monitored))
            return 1
        if options.mode == 'min' or (
                options.mode == 'auto' and 'acc' not in options.monitored):
            monitor_op = np.less
            best = np.inf
        else:  # max
            monitor_op = np.greater
            best = -np.inf
        if options.training_stage and 'training_stage' not in fields:
            logger.error(
                '--training-stage specified but training_stage not in csv')
            return 1
        logger.log(9, 'Looking through CSV')
        best_path = None
        for row in csv_reader:
            if options.training_stage and (
                    row['training_stage'] != options.training_stage):
                continue
            current = float(row[options.monitored])
            if monitor_op(current, best):
                best = current
                best_path = row['model_path']
        logger.log(9, 'Looked through CSV')
    if best_path is None:
        logger.error('Could not find any model')
        return 1
    else:
        print(best_path)
        return 0
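# To make the search concrete, here is a toy log in the format the function
# walks (field names from the function; the epochs, losses, and paths are
# hypothetical). Monitoring val_loss in 'min'/'auto' mode, the second row
# wins:
from csv import DictReader
from io import StringIO

_toy_log = StringIO(
    'epoch,val_loss,model_path\n'
    '1,42.1,exp/model_01.h5\n'
    '2,37.5,exp/model_02.h5\n'
    '3,39.0,exp/model_03.h5\n')
_best, _best_path = float('inf'), None
for _row in DictReader(_toy_log):
    if float(_row['val_loss']) < _best:  # np.less in 'min' mode
        _best, _best_path = float(_row['val_loss']), _row['model_path']
print(_best_path)  # exp/model_02.h5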
def compute_error_rate(args=None):
    '''Compute error rates between reference and hypothesis token vectors

    Two common error rates in speech are the word (WER) and phone (PER),
    though the computation is the same. Given a reference and hypothesis
    sequence, the error rate is

    >>> error_rate = (substitutions + insertions + deletions) * 100 / (
    ...     ref_tokens)

    Where the number of substitutions (e.g. ``A B C -> A D C``), deletions
    (e.g. ``A B C -> A C``), and insertions (e.g. ``A B C -> A D B C``) are
    determined by Levenshtein distance.
    '''
    logger = logging.getLogger(sys.argv[0])
    if not logger.handlers:
        logger.addHandler(logging.StreamHandler())
    register_logger_for_kaldi(sys.argv[0])
    options = _compute_error_rate_parse_args(args, logger)
    global_edit = 0
    global_token_count = 0
    global_sents = 0
    global_processed = 0
    inserts = dict()
    deletes = dict()
    subs = dict()
    totals = dict()

    def _err_on_utt_id(utt_id, missing_rxspecifier):
        msg = "Utterance '{}' absent in '{}'".format(
            utt_id, missing_rxspecifier)
        if options.strict:
            logger.error(msg)
            return 1
        else:
            logger.warning(msg)
            return 0
    return_tables = options.print_tables or (
        not options.include_inserts_in_cost)
    with kaldi_open(options.ref_rspecifier, 'tv') as ref_table, \
            kaldi_open(options.hyp_rspecifier, 'tv') as hyp_table:
        while not ref_table.done() and not hyp_table.done():
            global_sents += 1
            if ref_table.key() > hyp_table.key():
                if _err_on_utt_id(hyp_table.key(), options.ref_rspecifier):
                    return 1
                hyp_table.move()
            elif hyp_table.key() > ref_table.key():
                if _err_on_utt_id(ref_table.key(), options.hyp_rspecifier):
                    return 1
                ref_table.move()
            else:
                logger.debug('Processing {}: ref [{}] hyp [{}]'.format(
                    ref_table.key(),
                    ' '.join(ref_table.value()),
                    ' '.join(hyp_table.value())))
                global_token_count += len(ref_table.value())
                res = kaldi_eval_util.edit_distance(
                    ref_table.value(), hyp_table.value(),
                    return_tables=return_tables,
                    insertion_cost=options.insertion_cost,
                    deletion_cost=options.deletion_cost,
                    substitution_cost=options.substitution_cost,
                )
                if return_tables:
                    global_edit += res[0]
                    for global_dict, utt_dict in zip(
                            (inserts, deletes, subs, totals), res[1:]):
                        for token in ref_table.value() + hyp_table.value():
                            global_dict.setdefault(token, 0)
                        for token, count in utt_dict.items():
                            global_dict[token] += count
                else:
                    global_edit += res
                global_processed += 1
                ref_table.move()
                hyp_table.move()
        while not ref_table.done():
            if _err_on_utt_id(ref_table.key(), options.hyp_rspecifier):
                return 1
            global_sents += 1
            ref_table.move()
        while not hyp_table.done():
            if _err_on_utt_id(hyp_table.key(), options.ref_rspecifier):
                return 1
            global_sents += 1
            hyp_table.move()
    if options.out_path is None:
        out_file = sys.stdout
    else:
        out_file = open(options.out_path, 'w')
    print(
        "Processed {}/{}.".format(global_processed, global_sents),
        file=out_file, end=' ')
    if not options.include_inserts_in_cost:
        global_edit -= sum(inserts.values())
    if options.report_accuracy:
        print(
            'Accuracy: {:.2f}%'.format(
                (1 - global_edit / global_token_count) * 100),
            file=out_file)
    else:
        print(
            'Error rate: {:.2f}%'.format(
                global_edit / global_token_count * 100),
            file=out_file)
    if options.print_tables:
        print(
            "Total insertions: {}, deletions: {}, substitutions: {}".format(
                sum(inserts.values()), sum(deletes.values()),
                sum(subs.values())),
            file=out_file)
        print("", file=out_file)
        tokens = list(set(inserts) | set(deletes) | set(subs))
        tokens.sort()
        token_len = max(max(len(token) for token in tokens), 5)
        max_count = max(
            chain(inserts.values(), deletes.values(), subs.values()))
        max_count_len = int(log10(max_count) + 1)
        divider_str = '+' + ('-' * (token_len + 1))
        divider_str += ('+' + ('-' * (max_count_len + 9))) * 4
        divider_str += '+'
        format_str = '|{{:<{}}}|'.format(token_len + 1)
        format_str += 4 * '{{:>{}}}({{:05.2f}}%)|'.format(max_count_len + 1)
        print(
            '|{2:<{0}}|{3:>{1}}(%)|{4:>{1}}(%)|{5:>{1}}(%)|{6:>{1}}(%)|'
            ''.format(
                token_len + 1, max_count_len + 6,
                'token', 'inserts', 'deletes', 'subs', 'errs',
            ),
            file=out_file)
        print(divider_str, file=out_file)
        print(divider_str, file=out_file)
        for token in tokens:
            i, d, s = inserts[token], deletes[token], subs[token]
            t = totals[token]
            print(
                format_str.format(
                    token,
                    i, i / t * 100,
                    d, d / t * 100,
                    s, s / t * 100,
                    i + d + s, (i + d + s) / t * 100,
                ),
                file=out_file)
        print(divider_str, file=out_file)
    return 0
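# As a sanity check on the docstring's formula, here is a minimal unit-cost
# Levenshtein distance (a plain dynamic program; it is NOT
# kaldi_eval_util.edit_distance, which also supports per-operation costs
# and count tables):
def _toy_edit_distance(ref, hyp):
    '''Unit-cost Levenshtein distance via the classic DP recurrence'''
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref):
        cur = [i + 1]
        for j, h in enumerate(hyp):
            cur.append(min(
                prev[j + 1] + 1,      # deletion
                cur[j] + 1,           # insertion
                prev[j] + (r != h),   # substitution (or free match)
            ))
        prev = cur
    return prev[-1]


_ref, _hyp = 'A B C'.split(), 'A D B C'.split()
# one insertion over three reference tokens: 33.33...%
print(_toy_edit_distance(_ref, _hyp) * 100.0 / len(_ref))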
def write_table_to_pickle(args=None):
    '''Write a kaldi table to pickle file(s)

    The inverse is write-pickle-to-table
    '''
    logger = logging.getLogger(sys.argv[0])
    if not logger.handlers:
        logger.addHandler(logging.StreamHandler())
    register_logger_for_kaldi(logger)
    try:
        options = _write_table_to_pickle_parse_args(args, logger)
    except SystemExit as ex:
        return ex.code
    out_type = options.out_type
    if out_type is None:
        if options.in_type.is_floating_point:
            if options.in_type.is_double:
                out_type = np.float64
            else:
                out_type = np.float32
        else:
            out_type = str
    from six.moves import cPickle as pickle
    try:
        reader = kaldi_open(options.rspecifier, options.in_type, 'r')
        if options.value_out.endswith('.gz'):
            import gzip
            value_out = gzip.open(options.value_out, 'wb')
        else:
            value_out = open(options.value_out, 'wb')
        if options.key_out:
            # pickle needs a binary stream
            if options.key_out.endswith('.gz'):
                import gzip
                key_out = gzip.open(options.key_out, 'wb')
            else:
                key_out = open(options.key_out, 'wb')
        else:
            key_out = None
    except IOError as error:
        logger.error(str(error), exc_info=True)
        return 1
    num_entries = 0
    try:
        for key, value in reader.items():
            num_entries += 1
            if not np.issubdtype(out_type, np.dtype(str).type):
                value = value.astype(out_type)
            if key_out:
                pickle.dump(value, value_out)
                pickle.dump(key, key_out)
            else:
                pickle.dump((key, value), value_out)
            if num_entries % 10 == 0:
                logger.info('Processed {} entries'.format(num_entries))
            logger.debug('Processed key {}'.format(key))
    except (IOError, ValueError) as error:
        logger.error(str(error), exc_info=True)
        return 1
    finally:
        value_out.close()
        if key_out:
            key_out.close()
    if num_entries == 0:
        logger.warning("No entries were written (table was empty)")
    else:
        logger.info("Wrote {} entries".format(num_entries))
    return 0
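# Since values are streamed as back-to-back pickles in one file, reading
# them back means calling pickle.load until EOFError. A minimal reader for
# the (key, value) layout used when --key-out is absent (the path below is
# hypothetical):
from six.moves import cPickle as pickle


def _read_back_pickles(path):
    '''Yield (key, value) pairs from a file of concatenated pickles'''
    with open(path, 'rb') as file_obj:
        while True:
            try:
                yield pickle.load(file_obj)
            except EOFError:
                return


# e.g.: for key, value in _read_back_pickles('values.pkl'): print(key)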
def write_torch_dir_to_table(args=None):
    '''Write a data directory containing PyTorch data files to a Kaldi table

    Reads from a folder in the format:

    ::

        folder/
            <file_prefix><key_1><file_suffix>
            <file_prefix><key_2><file_suffix>
            ...

    Where each file contains a PyTorch tensor. The contents of the file
    ``<file_prefix><key_1><file_suffix>`` will be written as a value in a
    Kaldi table with key ``<key_1>``
    '''
    logger = logging.getLogger(sys.argv[0])
    if not logger.handlers:
        logger.addHandler(logging.StreamHandler())
    register_logger_for_kaldi(logger)
    try:
        options = _write_torch_dir_to_table_parse_args(args, logger)
    except SystemExit as ex:
        return ex.code
    if not os.path.isdir(options.dir):
        print("'{}' is not a directory".format(options.dir), file=sys.stderr)
        return 1
    import torch
    is_bool = False
    if options.out_type in {
            enums.KaldiDataType.BaseMatrix,
            enums.KaldiDataType.BaseVector,
            enums.KaldiDataType.WaveMatrix,
            enums.KaldiDataType.Base,
            enums.KaldiDataType.BasePairVector}:
        if options.out_type.is_double:
            torch_type = torch.double
        else:
            torch_type = torch.float
    elif options.out_type in {
            enums.KaldiDataType.FloatMatrix,
            enums.KaldiDataType.FloatVector}:
        torch_type = torch.float
    elif options.out_type in {
            enums.KaldiDataType.DoubleMatrix,
            enums.KaldiDataType.Double}:
        torch_type = torch.double
    elif options.out_type in {
            enums.KaldiDataType.Int32,
            enums.KaldiDataType.Int32Vector,
            enums.KaldiDataType.Int32VectorVector}:
        torch_type = torch.int
    elif options.out_type == enums.KaldiDataType.Boolean:
        torch_type = torch.uint8
        is_bool = True
    else:
        print(
            'Do not know how to convert {} from torch type'.format(
                options.out_type),
            file=sys.stderr)
        return 1
    neg_fsl = -len(options.file_suffix)
    if not neg_fsl:
        neg_fsl = None
    fpl = len(options.file_prefix)
    utt_ids = sorted(
        os.path.basename(x)[fpl:neg_fsl]
        for x in os.listdir(options.dir)
        if x.startswith(options.file_prefix) and
        x.endswith(options.file_suffix))
    with kaldi_open(options.wspecifier, options.out_type, mode='w') as table:
        for utt_id in utt_ids:
            val = torch.load(os.path.join(
                options.dir,
                options.file_prefix + utt_id + options.file_suffix))
            val = val.cpu().type(torch_type).numpy()
            if is_bool:
                val = bool(val)  # make sure val is a scalar!
            table.write(utt_id, val)
    return 0
def train_cnn_ctc(args=None):
    '''Train CNN w/ CTC decoding using kaldi data tables'''
    logger = logging.getLogger(sys.argv[0])
    if not logger.handlers:
        logger.addHandler(logging.StreamHandler())
    register_logger_for_kaldi(sys.argv[0])
    options = _train_cnn_ctc_parse_args(args, logger)
    logger.log(9, 'Parsed options')
    label2id_map = dict()
    with open(options.label_to_id_map_path) as file_obj:
        for line in file_obj:
            label, idee = line.strip().split()
            idee = int(idee)
            if idee < 0:
                logger.error('All label ids must be nonnegative')
                return 1
            label2id_map[label] = idee
    if (len(label2id_map) + 1) != options.num_labels:
        logger.error('Expected {} labels in label_to_id_map, got {}'.format(
            options.num_labels - 1, len(label2id_map)))
        return 1
    for idee in range(options.num_labels - 1):
        if idee not in label2id_map.values():
            logger.error('label to id map missing id: {}'.format(idee))
            return 1
    logger.log(9, 'Loaded label to id map')
    model_config = ModelConfig(**vars(options))
    train_config = TrainConfig(**vars(options))
    train_data = TrainData(
        (options.data_rspecifier, 'bm', {'cache': train_config.cache}),
        (options.labels_rspecifier, 'tv', {'cache': train_config.cache}),
        label2id_map,
        batch_size=train_config.batch_size,
        delta_order=model_config.delta_order,
        cmvn_rxfilename=model_config.cmvn_rxfilename,
        rng=train_config.train_seed,
    )
    if options.val_data_rspecifier or options.val_labels_rspecifier:
        if None in (
                options.val_data_rspecifier, options.val_labels_rspecifier):
            logger.error(
                "Both 'val_data_rspecifier' and 'val_labels_rspecifier' "
                "must be specified, or neither")
            return 1
        val_data = ValidationData(
            options.val_data_rspecifier,
            options.val_labels_rspecifier,
            label2id_map,
            batch_size=train_config.batch_size,
            delta_order=model_config.delta_order,
            cmvn_rxfilename=model_config.cmvn_rxfilename,
        )
    else:
        val_data = None
    logger.log(9, 'Set up training/validation data generators')
    with redirect_stdout_to_stderr():
        logger.log(9, 'Creating model')
        from pydrobert.mol.model import ConvCTC
        model = ConvCTC(model_config)
        logger.log(9, 'Beginning training')
        model.fit_generator(train_config, train_data, val_data=val_data)
    logger.log(9, 'Finished training')
def write_table_to_torch_dir(args=None):
    '''Write a Kaldi table to a series of PyTorch data files in a directory

    Writes to a folder in the format:

    ::

        folder/
            <file_prefix><key_1><file_suffix>
            <file_prefix><key_2><file_suffix>
            ...

    The contents of the file ``<file_prefix><key_1><file_suffix>`` will be a
    PyTorch tensor corresponding to the entry in the table for ``<key_1>``
    '''
    logger = logging.getLogger(sys.argv[0])
    if not logger.handlers:
        logger.addHandler(logging.StreamHandler())
    register_logger_for_kaldi(logger)
    try:
        options = _write_table_to_torch_dir_parse_args(args, logger)
    except SystemExit as ex:
        return ex.code
    out_type = options.out_type
    if out_type is None:
        if options.in_type in {
                enums.KaldiDataType.BaseMatrix,
                enums.KaldiDataType.BaseVector,
                enums.KaldiDataType.WaveMatrix,
                enums.KaldiDataType.Base,
                enums.KaldiDataType.BasePairVector}:
            if options.in_type.is_double:
                out_type = 'double'
            else:
                out_type = 'float'
        elif options.in_type in {
                enums.KaldiDataType.FloatMatrix,
                enums.KaldiDataType.FloatVector}:
            out_type = 'float'
        elif options.in_type in {
                enums.KaldiDataType.DoubleMatrix,
                enums.KaldiDataType.Double}:
            out_type = 'double'
        elif options.in_type in {
                enums.KaldiDataType.Int32,
                enums.KaldiDataType.Int32Vector,
                enums.KaldiDataType.Int32VectorVector}:
            out_type = 'int'
        elif options.in_type == enums.KaldiDataType.Boolean:
            out_type = 'byte'
        else:
            print(
                'Do not know how to convert {} to torch type'.format(
                    options.in_type),
                file=sys.stderr)
            return 1
    import torch
    if out_type == 'float':
        out_type = torch.float
    elif out_type == 'double':
        out_type = torch.double
    elif out_type == 'half':
        out_type = torch.half
    elif out_type == 'byte':
        out_type = torch.uint8
    elif out_type == 'char':
        out_type = torch.int8
    elif out_type == 'short':
        out_type = torch.short
    elif out_type == 'int':
        out_type = torch.int
    elif out_type == 'long':
        out_type = torch.long
    try:
        os.makedirs(options.dir)
    except FileExistsError:
        pass
    with kaldi_open(options.rspecifier, options.in_type) as table:
        for key, value in table.items():
            value = torch.tensor(value).type(out_type)
            torch.save(value, os.path.join(
                options.dir,
                options.file_prefix + key + options.file_suffix))
    return 0
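# A quick round-trip sketch of the directory layout described in the
# docstring above (the directory, prefix, and suffix here are hypothetical,
# not the command's defaults):
import os
import torch

os.makedirs('toy_dir', exist_ok=True)
torch.save(torch.zeros(3, 4), os.path.join('toy_dir', 'feat_utt1.pt'))
assert torch.load(os.path.join('toy_dir', 'feat_utt1.pt')).shape == (3, 4)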
def normalize_feat_lens(args=None):
    '''Ensure features match some reference lengths

    Incoming features are either clipped or padded to match reference
    lengths (stored as an int32 table), if they are within tolerance.
    '''
    logger = logging.getLogger(sys.argv[0])
    if not logger.handlers:
        logger.addHandler(logging.StreamHandler())
    register_logger_for_kaldi(sys.argv[0])
    options = _normalize_feat_lens_parse_args(args, logger)
    if options.pad_mode == 'zero':
        options.pad_mode = 'constant'
    feats_in = kaldi_open(
        options.feats_in_rspecifier, options.type, mode='r')
    len_in = kaldi_open(options.len_in_rspecifier, 'i', mode='r+')
    feats_out = kaldi_open(
        options.feats_out_wspecifier, options.type, mode='w')
    total_utts = 0
    processed_utts = 0
    for utt_id, feats in feats_in.items():
        total_utts += 1
        if utt_id not in len_in:
            msg = "Utterance '{}' absent in '{}'".format(
                utt_id, options.len_in_rspecifier)
            if options.strict:
                logger.error(msg)
                return 1
            else:
                logger.warning(msg)
                continue
        exp_feat_len = len_in[utt_id]
        act_feat_len = len(feats)
        logger.debug('{} exp len: {} act len: {}'.format(
            utt_id, exp_feat_len, act_feat_len))
        if act_feat_len < exp_feat_len:
            if act_feat_len < exp_feat_len - options.tolerance:
                msg = (
                    '{} has feature length {}, which is below the '
                    'tolerance ({}) of the expected length {}'.format(
                        utt_id, act_feat_len, options.tolerance,
                        exp_feat_len))
                if options.strict:
                    logger.error(msg)
                    return 1
                else:
                    logger.warning(msg)
                    continue
            # for matrices or vectors, this cast shouldn't be necessary.
            # If the user tries some special type like token vectors,
            # however, this *might* work as intended
            feats = np.array(feats, copy=False)
            pad_list = [(0, 0)] * len(feats.shape)
            if options.side == 'right':
                pad_list[0] = (0, exp_feat_len - act_feat_len)
            elif options.side == 'left':
                pad_list[0] = (exp_feat_len - act_feat_len, 0)
            else:
                pad_list[0] = (
                    (exp_feat_len - act_feat_len) // 2,
                    (exp_feat_len - act_feat_len + 1) // 2)
            feats = np.pad(feats, pad_list, options.pad_mode)
        elif act_feat_len > exp_feat_len:
            if act_feat_len > exp_feat_len + options.tolerance:
                msg = (
                    '{} has feature length {}, which is above the '
                    'tolerance ({}) of the expected length {}'.format(
                        utt_id, act_feat_len, options.tolerance,
                        exp_feat_len))
                if options.strict:
                    logger.error(msg)
                    return 1
                else:
                    logger.warning(msg)
                    continue
            if options.side == 'right':
                # drop the extra frames from the back
                feats = feats[:exp_feat_len - act_feat_len]
            elif options.side == 'left':
                # drop the extra frames from the front
                feats = feats[act_feat_len - exp_feat_len:]
            else:
                # drop the extra frames evenly from both ends,
                # favouring the back
                start = (act_feat_len - exp_feat_len) // 2
                feats = feats[start:start + exp_feat_len]
        feats_out.write(utt_id, feats)
        processed_utts += 1
    logger.info('Processed {}/{} utterances'.format(
        processed_utts, total_utts))
    return 0
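# To make the padding arithmetic above concrete: padding a 3-frame
# utterance out to 6 frames in the centre-split case puts the extra frame
# on the right, since (diff // 2, (diff + 1) // 2) rounds the right side
# up. 'edge' here stands in for one of np.pad's supported --pad-mode
# values:
import numpy as np

_feats = np.arange(6, dtype=np.float32).reshape(3, 2)  # 3 frames, 2 dims
_diff = 6 - len(_feats)                                # 3 frames short
_pad_list = [(_diff // 2, (_diff + 1) // 2), (0, 0)]   # 1 left, 2 right
_padded = np.pad(_feats, _pad_list, 'edge')
assert len(_padded) == 6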