# Inner callback from collect_stats() (see below), shown standalone here; it
# references `options` and the `Stats` holder class from the enclosing scope.
def cb(frame_len, orth):
  if frame_len >= options.max_seq_frame_len:
    return
  orth_syms = parse_orthography(orth)
  if len(orth_syms) >= options.max_seq_orth_len:
    return
  Stats.count += 1
  Stats.total_frame_len += frame_len

  if options.dump_orth_syms:
    print("Orth:", "".join(orth_syms), file=log.v3)
  if options.filter_orth_sym:
    if options.filter_orth_sym in orth_syms:
      print("Found orth:", "".join(orth_syms), file=log.v3)
  if options.filter_orth_syms_seq:
    filter_seq = parse_orthography_into_symbols(options.filter_orth_syms_seq)
    if found_sub_seq(filter_seq, orth_syms):
      print("Found orth:", "".join(orth_syms), file=log.v3)

  Stats.orth_syms_set.update(orth_syms)
  Stats.total_orth_len += len(orth_syms)

  # Show some progress if it takes long.
  if time.time() - Stats.process_last_time > 2:
    Stats.process_last_time = time.time()
    if options.collect_time:
      print("Collect process, total frame len so far:",
            hms(Stats.total_frame_len * (options.frame_time / 1000.0)), file=log.v3)
    else:
      print("Collect process, total orth len so far:",
            human_size(Stats.total_orth_len), file=log.v3)
# Python 2 variant of the callback above (print-statement syntax).
def cb(frame_len, orth):
  if frame_len >= options.max_seq_frame_len:
    return
  orth_syms = parse_orthography(orth)
  if len(orth_syms) >= options.max_seq_orth_len:
    return
  Stats.count += 1
  Stats.total_frame_len += frame_len

  if options.dump_orth_syms:
    print >> log.v3, "Orth:", "".join(orth_syms)
  if options.filter_orth_sym:
    if options.filter_orth_sym in orth_syms:
      print >> log.v3, "Found orth:", "".join(orth_syms)
  if options.filter_orth_syms_seq:
    filter_seq = parse_orthography_into_symbols(options.filter_orth_syms_seq)
    if found_sub_seq(filter_seq, orth_syms):
      print >> log.v3, "Found orth:", "".join(orth_syms)

  Stats.orth_syms_set.update(orth_syms)
  Stats.total_orth_len += len(orth_syms)

  # Show some progress if it takes long.
  if time.time() - Stats.process_last_time > 2:
    Stats.process_last_time = time.time()
    if options.collect_time:
      print >> log.v3, "Collect process, total frame len so far:", hms(Stats.total_frame_len * (options.frame_time / 1000.0))
    else:
      print >> log.v3, "Collect process, total orth len so far:", human_size(Stats.total_orth_len)
def _callback(self, orth):
  orth_words = parse_orthography(orth, prefix=[], postfix=[], word_based=True)

  self.seq_count += 1
  if self.options.dump_orth:
    print("Orth:", orth_words, file=log.v3)
  self.words.update(orth_words)
  self.total_word_len += len(orth_words)

  # Show some progress if it takes long.
  if time.time() - self.process_last_time > 2:
    self.process_last_time = time.time()
    print("Collect process, total word len so far:", human_size(self.total_word_len), file=log.v3)
def __init__(self, options, iter_corpus):
  """
  :param options: argparse.Namespace
  """
  self.options = options
  self.seq_count = 0
  self.words = set()
  self.total_word_len = 0
  self.process_last_time = time.time()

  iter_corpus(self._callback)

  print("Total word len:", self.total_word_len, "(%s)" % human_size(self.total_word_len), file=log.v3)
  print("Average orth len:", float(self.total_word_len) / self.seq_count, file=log.v3)
  print("Num word symbols:", len(self.words), file=log.v3)
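# Hypothetical usage sketch (not from the original source). The class above only
# needs an `iter_corpus` callable that invokes the given callback once per
# sequence with its orthography string. A minimal driver reading one
# transcription per line from a plain-text file could look like this; the
# function name, file name, and the `CollectWordStats` class name are assumptions.
def iter_corpus_from_txt(callback, filename="corpus.txt"):
  with open(filename, "r", encoding="utf8") as f:
    for line in f:
      line = line.strip()
      if line:
        callback(line)  # ends up in _callback(self, orth) above

# stats = CollectWordStats(options, iter_corpus_from_txt)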
def hdf_dump_from_dataset(dataset, hdf_dataset, parser_args):
  """
  :param Dataset dataset: could be any dataset implemented as child of Dataset
  :type hdf_dataset: h5py._hl.files.File
  :param parser_args: argparse object from main()
  """
  print("Work on epoch: %i" % parser_args.epoch, file=log.v3)
  dataset.init_seq_order(parser_args.epoch)

  data_keys = sorted(dataset.get_data_keys())
  print("Data keys:", data_keys, file=log.v3)
  if "orth" in data_keys:
    data_keys.remove("orth")

  # We need to do one run through the dataset to collect some stats like total len.
  print("Collect stats, iterate through all data...", file=log.v3)
  seq_idx = parser_args.start_seq
  seq_idxs = []
  seq_tags = []
  seq_lens = []
  total_seq_len = NumbersDict(0)
  max_tag_len = 0
  dataset_num_seqs = try_run(lambda: dataset.num_seqs, default=None)  # can be unknown
  if parser_args.end_seq != float("inf"):
    if dataset_num_seqs is not None:
      dataset_num_seqs = min(dataset_num_seqs, parser_args.end_seq)
    else:
      dataset_num_seqs = parser_args.end_seq
  if dataset_num_seqs is not None:
    dataset_num_seqs -= parser_args.start_seq
    assert dataset_num_seqs > 0
  while dataset.is_less_than_num_seqs(seq_idx) and seq_idx <= parser_args.end_seq:
    seq_idxs += [seq_idx]
    dataset.load_seqs(seq_idx, seq_idx + 1)
    seq_len = dataset.get_seq_length(seq_idx)
    seq_lens += [seq_len]
    tag = dataset.get_tag(seq_idx)
    seq_tags += [tag]
    max_tag_len = max(len(tag), max_tag_len)
    total_seq_len += seq_len
    if dataset_num_seqs is not None:
      progress_bar_with_time(float(seq_idx - parser_args.start_seq) / dataset_num_seqs)
    seq_idx += 1
  num_seqs = len(seq_idxs)
  assert num_seqs > 0

  shapes = {}
  for data_key in data_keys:
    assert data_key in total_seq_len.dict
    shape = [total_seq_len[data_key]]
    shape += dataset.get_data_shape(data_key)
    print("Total len of %r is %s, shape %r, dtype %s" % (
      data_key, human_size(shape[0]), shape, dataset.get_data_dtype(data_key)), file=log.v3)
    shapes[data_key] = shape

  print("Set seq tags...", file=log.v3)
  hdf_dataset.create_dataset('seqTags', shape=(num_seqs,), dtype="S%i" % (max_tag_len + 1))
  for i, tag in enumerate(seq_tags):
    hdf_dataset['seqTags'][i] = numpy.array(tag, dtype="S%i" % (max_tag_len + 1))
    progress_bar_with_time(float(i) / num_seqs)

  print("Set seq len info...", file=log.v3)
  hdf_dataset.create_dataset(HDFDataset.attr_seqLengths, shape=(num_seqs, 2), dtype="int32")
  for i, seq_len in enumerate(seq_lens):
    data_len = seq_len["data"]
    targets_len = seq_len["classes"]
    for data_key in dataset.get_target_list():
      if data_key == "orth":
        continue
      assert seq_len[data_key] == targets_len, "different lengths in multi-target not supported"
    if targets_len is None:
      targets_len = data_len
    hdf_dataset[HDFDataset.attr_seqLengths][i] = [data_len, targets_len]
    progress_bar_with_time(float(i) / num_seqs)

  print("Create arrays in HDF...", file=log.v3)
  hdf_dataset.create_group('targets/data')
  hdf_dataset.create_group('targets/size')
  hdf_dataset.create_group('targets/labels')
  for data_key in data_keys:
    if data_key == "data":
      hdf_dataset.create_dataset(
        'inputs', shape=shapes[data_key], dtype=dataset.get_data_dtype(data_key))
    else:
      hdf_dataset['targets/data'].create_dataset(
        data_key, shape=shapes[data_key], dtype=dataset.get_data_dtype(data_key))
      hdf_dataset['targets/size'].attrs[data_key] = dataset.num_outputs[data_key]
    if data_key in dataset.labels:
      labels = dataset.labels[data_key]
      assert len(labels) == dataset.num_outputs[data_key][0]
    else:
      labels = ["%s-class-%i" % (data_key, i) for i in range(dataset.get_data_dim(data_key))]
    print("Labels for %s:" % data_key, labels[:3], "...", file=log.v5)
    max_label_len = max(map(len, labels))
    if data_key != "data":
      hdf_dataset['targets/labels'].create_dataset(
        data_key, (len(labels),), dtype="S%i" % (max_label_len + 1))
      for i, label in enumerate(labels):
        hdf_dataset['targets/labels'][data_key][i] = numpy.array(
          label, dtype="S%i" % (max_label_len + 1))

  # Again iterate through dataset, and set the data
  print("Write data...", file=log.v3)
  dataset.init_seq_order(parser_args.epoch)
  offsets = NumbersDict(0)
  for seq_idx, tag in zip(seq_idxs, seq_tags):
    dataset.load_seqs(seq_idx, seq_idx + 1)
    tag_ = dataset.get_tag(seq_idx)
    assert tag == tag_  # Just a check for sanity. We expect the same order.
    seq_len = dataset.get_seq_length(seq_idx)
    for data_key in data_keys:
      if data_key == "data":
        hdf_data = hdf_dataset['inputs']
      else:
        hdf_data = hdf_dataset['targets/data'][data_key]
      data = dataset.get_data(seq_idx, data_key)
      hdf_data[offsets[data_key]:offsets[data_key] + seq_len[data_key]] = data
    progress_bar_with_time(float(offsets["data"]) / total_seq_len["data"])
    offsets += seq_len
  assert offsets == total_seq_len  # Sanity check.

  # Set some old-format attribs. Not needed for newer CRNN versions.
  hdf_dataset.attrs[HDFDataset.attr_inputPattSize] = dataset.num_inputs
  hdf_dataset.attrs[HDFDataset.attr_numLabels] = dataset.num_outputs.get("classes", (0, 0))[0]

  print("All done.", file=log.v3)
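# Minimal invocation sketch (an assumption, not part of the original source):
# hdf_dump_from_dataset() expects an initialized Dataset instance, a writable
# h5py File, and an argparse-style namespace carrying the fields it reads
# (epoch, start_seq, end_seq). How `dataset` is constructed is left open here.
import argparse
import h5py

def hdf_dump_example(dataset, out_filename="dump.hdf"):
  parser_args = argparse.Namespace(epoch=1, start_seq=0, end_seq=float("inf"))
  with h5py.File(out_filename, "w") as hdf_dataset:
    hdf_dump_from_dataset(dataset, hdf_dataset, parser_args)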
# Python 2 variant of hdf_dump_from_dataset() above (print-statement syntax);
# it also lacks the "orth" workaround and the explicit numpy string conversion
# of the Python 3 variant.
def hdf_dump_from_dataset(dataset, hdf_dataset, parser_args):
  """
  :param Dataset dataset: could be any dataset implemented as child of Dataset
  :type hdf_dataset: h5py._hl.files.File
  :param parser_args: argparse object from main()
  """
  print >> log.v3, "Work on epoch: %i" % parser_args.epoch
  dataset.init_seq_order(parser_args.epoch)

  data_keys = sorted(dataset.get_data_keys())
  print >> log.v3, "Data keys:", data_keys

  # We need to do one run through the dataset to collect some stats like total len.
  print >> log.v3, "Collect stats, iterate through all data..."
  seq_idx = parser_args.start_seq
  seq_idxs = []
  seq_tags = []
  seq_lens = []
  total_seq_len = NumbersDict(0)
  max_tag_len = 0
  dataset_num_seqs = try_run(lambda: dataset.num_seqs, default=None)  # can be unknown
  if parser_args.end_seq != float("inf"):
    if dataset_num_seqs is not None:
      dataset_num_seqs = min(dataset_num_seqs, parser_args.end_seq)
    else:
      dataset_num_seqs = parser_args.end_seq
  if dataset_num_seqs is not None:
    dataset_num_seqs -= parser_args.start_seq
    assert dataset_num_seqs > 0
  while dataset.is_less_than_num_seqs(seq_idx) and seq_idx <= parser_args.end_seq:
    seq_idxs += [seq_idx]
    dataset.load_seqs(seq_idx, seq_idx + 1)
    seq_len = dataset.get_seq_length(seq_idx)
    seq_lens += [seq_len]
    tag = dataset.get_tag(seq_idx)
    seq_tags += [tag]
    max_tag_len = max(len(tag), max_tag_len)
    total_seq_len += seq_len
    if dataset_num_seqs is not None:
      progress_bar_with_time(float(seq_idx - parser_args.start_seq) / dataset_num_seqs)
    seq_idx += 1
  num_seqs = len(seq_idxs)
  assert num_seqs > 0

  shapes = {}
  for data_key in data_keys:
    assert data_key in total_seq_len.dict
    shape = [total_seq_len[data_key]]
    shape += dataset.get_data_shape(data_key)
    print >> log.v3, "Total len of %r is %s, shape %r, dtype %s" % (
      data_key, human_size(shape[0]), shape, dataset.get_data_dtype(data_key))
    shapes[data_key] = shape

  print >> log.v3, "Set seq tags..."
  hdf_dataset.create_dataset('seqTags', shape=(num_seqs,), dtype="S%i" % (max_tag_len + 1))
  for i, tag in enumerate(seq_tags):
    hdf_dataset['seqTags'][i] = tag
    progress_bar_with_time(float(i) / num_seqs)

  print >> log.v3, "Set seq len info..."
  hdf_dataset.create_dataset(HDFDataset.attr_seqLengths, shape=(num_seqs, 2), dtype="int32")
  for i, seq_len in enumerate(seq_lens):
    data_len = seq_len["data"]
    targets_len = seq_len["classes"]
    for data_key in dataset.get_target_list():
      assert seq_len[data_key] == targets_len, "different lengths in multi-target not supported"
    if targets_len is None:
      targets_len = data_len
    hdf_dataset[HDFDataset.attr_seqLengths][i] = [data_len, targets_len]
    progress_bar_with_time(float(i) / num_seqs)

  print >> log.v3, "Create arrays in HDF..."
  hdf_dataset.create_group('targets/data')
  hdf_dataset.create_group('targets/size')
  hdf_dataset.create_group('targets/labels')
  for data_key in data_keys:
    if data_key == "data":
      hdf_dataset.create_dataset(
        'inputs', shape=shapes[data_key], dtype=dataset.get_data_dtype(data_key))
    else:
      hdf_dataset['targets/data'].create_dataset(
        data_key, shape=shapes[data_key], dtype=dataset.get_data_dtype(data_key))
      hdf_dataset['targets/size'].attrs[data_key] = dataset.num_outputs[data_key]
    if data_key in dataset.labels:
      labels = dataset.labels[data_key]
      assert len(labels) == dataset.num_outputs[data_key][0]
    else:
      labels = ["%s-class-%i" % (data_key, i) for i in range(dataset.get_data_dim(data_key))]
    print >> log.v5, "Labels for %s:" % data_key, labels[:3], "..."
    max_label_len = max(map(len, labels))
    hdf_dataset['targets/labels'].create_dataset(data_key, (len(labels),), dtype="S%i" % (max_label_len + 1))
    for i, label in enumerate(labels):
      hdf_dataset['targets/labels'][data_key][i] = label

  # Again iterate through dataset, and set the data
  print >> log.v3, "Write data..."
  dataset.init_seq_order(parser_args.epoch)
  offsets = NumbersDict(0)
  for seq_idx, tag in zip(seq_idxs, seq_tags):
    dataset.load_seqs(seq_idx, seq_idx + 1)
    tag_ = dataset.get_tag(seq_idx)
    assert tag == tag_  # Just a check for sanity. We expect the same order.
    seq_len = dataset.get_seq_length(seq_idx)
    for data_key in data_keys:
      if data_key == "data":
        hdf_data = hdf_dataset['inputs']
      else:
        hdf_data = hdf_dataset['targets/data'][data_key]
      data = dataset.get_data(seq_idx, data_key)
      hdf_data[offsets[data_key]:offsets[data_key] + seq_len[data_key]] = data
    progress_bar_with_time(float(offsets["data"]) / total_seq_len["data"])
    offsets += seq_len
  assert offsets == total_seq_len  # Sanity check.

  # Set some old-format attribs. Not needed for newer CRNN versions.
  hdf_dataset.attrs[HDFDataset.attr_inputPattSize] = dataset.num_inputs
  hdf_dataset.attrs[HDFDataset.attr_numLabels] = dataset.num_outputs.get("classes", (0, 0))[0]

  print >> log.v3, "All done."
def collect_stats(options, iter_corpus):
  """
  :param options: argparse.Namespace
  """
  orth_symbols_filename = options.output
  if orth_symbols_filename:
    assert not os.path.exists(orth_symbols_filename)

  class Stats:
    count = 0
    process_last_time = time.time()
    total_frame_len = 0
    total_orth_len = 0
    orth_syms_set = set()

  if options.add_numbers:
    Stats.orth_syms_set.update(map(chr, list(range(ord("0"), ord("9") + 1))))
  if options.add_lower_alphabet:
    Stats.orth_syms_set.update(map(chr, list(range(ord("a"), ord("z") + 1))))
  if options.add_upper_alphabet:
    Stats.orth_syms_set.update(map(chr, list(range(ord("A"), ord("Z") + 1))))

  def cb(frame_len, orth):
    if frame_len >= options.max_seq_frame_len:
      return
    orth_syms = parse_orthography(orth)
    if len(orth_syms) >= options.max_seq_orth_len:
      return
    Stats.count += 1
    Stats.total_frame_len += frame_len

    if options.dump_orth_syms:
      print("Orth:", "".join(orth_syms), file=log.v3)
    if options.filter_orth_sym:
      if options.filter_orth_sym in orth_syms:
        print("Found orth:", "".join(orth_syms), file=log.v3)
    if options.filter_orth_syms_seq:
      filter_seq = parse_orthography_into_symbols(options.filter_orth_syms_seq)
      if found_sub_seq(filter_seq, orth_syms):
        print("Found orth:", "".join(orth_syms), file=log.v3)

    Stats.orth_syms_set.update(orth_syms)
    Stats.total_orth_len += len(orth_syms)

    # Show some progress if it takes long.
    if time.time() - Stats.process_last_time > 2:
      Stats.process_last_time = time.time()
      if options.collect_time:
        print("Collect process, total frame len so far:",
              hms(Stats.total_frame_len * (options.frame_time / 1000.0)), file=log.v3)
      else:
        print("Collect process, total orth len so far:",
              human_size(Stats.total_orth_len), file=log.v3)

  iter_corpus(cb)

  if options.remove_symbols:
    filter_syms = parse_orthography_into_symbols(options.remove_symbols)
    Stats.orth_syms_set -= set(filter_syms)

  if options.collect_time:
    print("Total frame len:", Stats.total_frame_len, "time:",
          hms(Stats.total_frame_len * (options.frame_time / 1000.0)), file=log.v3)
  else:
    print("No time stats (--collect_time False).", file=log.v3)
  print("Total orth len:", Stats.total_orth_len, "(%s)" % human_size(Stats.total_orth_len), end=' ', file=log.v3)
  if options.collect_time:
    print("fraction:", float(Stats.total_orth_len) / Stats.total_frame_len, file=log.v3)
  else:
    print("", file=log.v3)
  print("Average orth len:", float(Stats.total_orth_len) / Stats.count, file=log.v3)
  print("Num symbols:", len(Stats.orth_syms_set), file=log.v3)

  if orth_symbols_filename:
    # Write the symbols as UTF-8 text (Python 3 has no `unicode` builtin).
    with open(orth_symbols_filename, "w", encoding="utf8") as orth_syms_file:
      for orth_sym in sorted(Stats.orth_syms_set):
        orth_syms_file.write("%s\n" % orth_sym)
    print("Wrote orthography symbols to", orth_symbols_filename, file=log.v3)
  else:
    print("Provide --output to save the symbols.", file=log.v3)
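# Sketch of a matching corpus iterator (assumed here, for illustration only):
# collect_stats() merely requires that iter_corpus(cb) invokes cb(frame_len, orth)
# once per sequence. E.g., driven from a list of (frames, transcription) pairs:
def make_iter_corpus(seqs):
  """
  :param list[(int,str)] seqs: (frame_len, orth) pairs, loaded elsewhere
  """
  def iter_corpus(callback):
    for frame_len, orth in seqs:
      callback(frame_len, orth)
  return iter_corpus

# collect_stats(options, make_iter_corpus([(500, "hello world"), (320, "foo bar")]))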
# Python 2 variant of collect_stats() above (print-statement syntax; `unicode`
# and binary-mode writing are valid here).
def collect_stats(options, iter_corpus):
  """
  :param options: argparse.Namespace
  """
  orth_symbols_filename = options.output
  if orth_symbols_filename:
    assert not os.path.exists(orth_symbols_filename)

  class Stats:
    count = 0
    process_last_time = time.time()
    total_frame_len = 0
    total_orth_len = 0
    orth_syms_set = set()

  if options.add_numbers:
    Stats.orth_syms_set.update(map(chr, range(ord("0"), ord("9") + 1)))
  if options.add_lower_alphabet:
    Stats.orth_syms_set.update(map(chr, range(ord("a"), ord("z") + 1)))
  if options.add_upper_alphabet:
    Stats.orth_syms_set.update(map(chr, range(ord("A"), ord("Z") + 1)))

  def cb(frame_len, orth):
    if frame_len >= options.max_seq_frame_len:
      return
    orth_syms = parse_orthography(orth)
    if len(orth_syms) >= options.max_seq_orth_len:
      return
    Stats.count += 1
    Stats.total_frame_len += frame_len

    if options.dump_orth_syms:
      print >> log.v3, "Orth:", "".join(orth_syms)
    if options.filter_orth_sym:
      if options.filter_orth_sym in orth_syms:
        print >> log.v3, "Found orth:", "".join(orth_syms)
    if options.filter_orth_syms_seq:
      filter_seq = parse_orthography_into_symbols(options.filter_orth_syms_seq)
      if found_sub_seq(filter_seq, orth_syms):
        print >> log.v3, "Found orth:", "".join(orth_syms)

    Stats.orth_syms_set.update(orth_syms)
    Stats.total_orth_len += len(orth_syms)

    # Show some progress if it takes long.
    if time.time() - Stats.process_last_time > 2:
      Stats.process_last_time = time.time()
      if options.collect_time:
        print >> log.v3, "Collect process, total frame len so far:", hms(Stats.total_frame_len * (options.frame_time / 1000.0))
      else:
        print >> log.v3, "Collect process, total orth len so far:", human_size(Stats.total_orth_len)

  iter_corpus(cb)

  if options.remove_symbols:
    filter_syms = parse_orthography_into_symbols(options.remove_symbols)
    Stats.orth_syms_set -= set(filter_syms)

  if options.collect_time:
    print >> log.v3, "Total frame len:", Stats.total_frame_len, "time:", hms(Stats.total_frame_len * (options.frame_time / 1000.0))
  else:
    print >> log.v3, "No time stats (--collect_time False)."
  print >> log.v3, "Total orth len:", Stats.total_orth_len, "(%s)" % human_size(Stats.total_orth_len),  # trailing comma: no newline
  if options.collect_time:
    print >> log.v3, "fraction:", float(Stats.total_orth_len) / Stats.total_frame_len
  else:
    print >> log.v3, ""
  print >> log.v3, "Average orth len:", float(Stats.total_orth_len) / Stats.count
  print >> log.v3, "Num symbols:", len(Stats.orth_syms_set)

  if orth_symbols_filename:
    orth_syms_file = open(orth_symbols_filename, "wb")
    for orth_sym in sorted(Stats.orth_syms_set):
      orth_syms_file.write("%s\n" % unicode(orth_sym).encode("utf8"))
    orth_syms_file.close()
    print >> log.v3, "Wrote orthography symbols to", orth_symbols_filename
  else:
    print >> log.v3, "Provide --output to save the symbols."
def dump_from_dataset(self, dataset, epoch=1, start_seq=0, end_seq=float("inf"), use_progress_bar=True):
  """
  :param Dataset dataset: could be any dataset implemented as child of Dataset
  :param int epoch: for dataset
  :param int start_seq:
  :param int|float end_seq:
  :param bool use_progress_bar:
  """
  from Util import NumbersDict, human_size, progress_bar_with_time, try_run, PY3
  hdf_dataset = self.file

  print("Work on epoch: %i" % epoch, file=log.v3)
  dataset.init_seq_order(epoch)

  data_keys = sorted(dataset.get_data_keys())
  print("Data keys:", data_keys, file=log.v3)
  if "orth" in data_keys:  # special workaround for now, not handled
    data_keys.remove("orth")
  data_target_keys = [key for key in dataset.get_target_list() if key in data_keys]
  data_input_keys = [key for key in data_keys if key not in data_target_keys]
  assert len(data_input_keys) > 0 and len(data_target_keys) > 0
  if len(data_input_keys) > 1:
    if "data" in data_input_keys:
      default_data_input_key = "data"
    else:
      raise Exception("not sure which input data key to use from %r" % (data_input_keys,))
  else:
    default_data_input_key = data_input_keys[0]
  print("Using input data key:", default_data_input_key)
  if len(data_target_keys) > 1:
    if "classes" in data_target_keys:
      default_data_target_key = "classes"
    else:
      raise Exception("not sure which target data key to use from %r" % (data_target_keys,))
  else:
    default_data_target_key = data_target_keys[0]
  print("Using target data key:", default_data_target_key)
  hdf_data_key_map = {key: key for key in data_keys if key != default_data_input_key}
  if "data" in hdf_data_key_map:
    hdf_data_key_map["data"] = "classes"  # Replace "data" which is reserved for input key in HDFDataset.
    assert "classes" not in hdf_data_key_map

  # We need to do one run through the dataset to collect some stats like total len.
  print("Collect stats, iterate through all data...", file=log.v3)
  seq_idx = start_seq
  seq_idxs = []
  seq_tags = []
  seq_lens = []
  total_seq_len = NumbersDict(0)
  max_tag_len = 0
  dataset_num_seqs = try_run(lambda: dataset.num_seqs, default=None)  # can be unknown
  if end_seq != float("inf"):
    if dataset_num_seqs is not None:
      dataset_num_seqs = min(dataset_num_seqs, end_seq)
    else:
      dataset_num_seqs = end_seq
  if dataset_num_seqs is not None:
    dataset_num_seqs -= start_seq
    assert dataset_num_seqs > 0
  while dataset.is_less_than_num_seqs(seq_idx) and seq_idx <= end_seq:
    seq_idxs += [seq_idx]
    dataset.load_seqs(seq_idx, seq_idx + 1)
    seq_len = dataset.get_seq_length(seq_idx)
    seq_lens += [seq_len]
    tag = dataset.get_tag(seq_idx)
    seq_tags += [tag]
    max_tag_len = max(len(tag), max_tag_len)
    total_seq_len += seq_len
    if use_progress_bar and dataset_num_seqs is not None:
      progress_bar_with_time(float(seq_idx - start_seq) / dataset_num_seqs)
    seq_idx += 1
  num_seqs = len(seq_idxs)
  assert num_seqs > 0

  shapes = {}
  for data_key in data_keys:
    assert data_key in total_seq_len.dict
    shape = [total_seq_len[data_key]]
    shape += dataset.get_data_shape(data_key)
    print("Total len of %r is %s, shape %r, dtype %s" % (
      data_key, human_size(shape[0]), shape, dataset.get_data_dtype(data_key)), file=log.v3)
    shapes[data_key] = shape

  print("Set seq tags...", file=log.v3)
  hdf_dataset.create_dataset('seqTags', shape=(num_seqs,), dtype="S%i" % (max_tag_len + 1))
  for i, tag in enumerate(seq_tags):
    hdf_dataset['seqTags'][i] = numpy.array(tag, dtype="S%i" % (max_tag_len + 1))
    if use_progress_bar:
      progress_bar_with_time(float(i) / num_seqs)

  print("Set seq len info...", file=log.v3)
  hdf_dataset.create_dataset(attr_seqLengths, shape=(num_seqs, 2), dtype="int32")
  for i, seq_len in enumerate(seq_lens):
    data_len = seq_len[default_data_input_key]
    targets_len = seq_len[default_data_target_key]
    for data_key in data_target_keys:
      assert seq_len[data_key] == targets_len, "different lengths in multi-target not supported"
    if targets_len is None:
      targets_len = data_len
    hdf_dataset[attr_seqLengths][i] = [data_len, targets_len]
    if use_progress_bar:
      progress_bar_with_time(float(i) / num_seqs)

  print("Create arrays in HDF...", file=log.v3)
  hdf_dataset.create_group('targets/data')
  hdf_dataset.create_group('targets/size')
  hdf_dataset.create_group('targets/labels')
  for data_key in data_keys:
    if data_key == default_data_input_key:
      hdf_dataset.create_dataset(
        'inputs', shape=shapes[data_key], dtype=dataset.get_data_dtype(data_key))
    else:
      hdf_dataset['targets/data'].create_dataset(
        hdf_data_key_map[data_key], shape=shapes[data_key], dtype=dataset.get_data_dtype(data_key))
      hdf_dataset['targets/size'].attrs[hdf_data_key_map[data_key]] = dataset.num_outputs[data_key]
    if data_key in dataset.labels:
      labels = dataset.labels[data_key]
      if PY3:
        labels = [label.encode("utf8") for label in labels]
      assert len(labels) == dataset.num_outputs[data_key][0]
    else:
      labels = ["%s-class-%i" % (data_key, i) for i in range(dataset.get_data_dim(data_key))]
    print("Labels for %s:" % data_key, labels[:3], "...", file=log.v5)
    max_label_len = max(map(len, labels))
    if data_key != default_data_input_key:
      hdf_dataset['targets/labels'].create_dataset(
        hdf_data_key_map[data_key], (len(labels),), dtype="S%i" % (max_label_len + 1))
      for i, label in enumerate(labels):
        hdf_dataset['targets/labels'][hdf_data_key_map[data_key]][i] = numpy.array(
          label, dtype="S%i" % (max_label_len + 1))

  # Again iterate through dataset, and set the data
  print("Write data...", file=log.v3)
  dataset.init_seq_order(epoch)
  offsets = NumbersDict(0)
  for seq_idx, tag in zip(seq_idxs, seq_tags):
    dataset.load_seqs(seq_idx, seq_idx + 1)
    tag_ = dataset.get_tag(seq_idx)
    assert tag == tag_  # Just a check for sanity. We expect the same order.
    seq_len = dataset.get_seq_length(seq_idx)
    for data_key in data_keys:
      if data_key == default_data_input_key:
        hdf_data = hdf_dataset['inputs']
      else:
        hdf_data = hdf_dataset['targets/data'][hdf_data_key_map[data_key]]
      data = dataset.get_data(seq_idx, data_key)
      hdf_data[offsets[data_key]:offsets[data_key] + seq_len[data_key]] = data
    if use_progress_bar:
      progress_bar_with_time(float(offsets[default_data_input_key]) / total_seq_len[default_data_input_key])
    offsets += seq_len
  assert offsets == total_seq_len  # Sanity check.

  # Set some old-format attribs. Not needed for newer CRNN versions.
  hdf_dataset.attrs[attr_inputPattSize] = dataset.num_inputs
  hdf_dataset.attrs[attr_numLabels] = dataset.num_outputs.get(default_data_target_key, (0, 0))[0]

  print("All done.", file=log.v3)
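# Usage sketch, under the assumption (not shown in this snippet) that the method
# above lives on a writer class that opens `self.file` as a writable h5py File;
# the class name and close() call below are illustrative, not confirmed here:
#
#   writer = HDFDatasetWriter("out.hdf")   # hypothetical host class
#   writer.dump_from_dataset(dataset, epoch=1)
#   writer.close()
#
# Note the NumbersDict bookkeeping: `total_seq_len` and `offsets` act as
# per-data-key integer counters with elementwise `+=`, so the final
# `offsets == total_seq_len` assert verifies that every collected frame and
# target label was written exactly once.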