Example #1
    def cb(frame_len, orth):
        if frame_len >= options.max_seq_frame_len:
            return
        orth_syms = parse_orthography(orth)
        if len(orth_syms) >= options.max_seq_orth_len:
            return

        Stats.count += 1
        Stats.total_frame_len += frame_len

        if options.dump_orth_syms:
            print("Orth:", "".join(orth_syms), file=log.v3)
        if options.filter_orth_sym:
            if options.filter_orth_sym in orth_syms:
                print("Found orth:", "".join(orth_syms), file=log.v3)
        if options.filter_orth_syms_seq:
            filter_seq = parse_orthography_into_symbols(
                options.filter_orth_syms_seq)
            if found_sub_seq(filter_seq, orth_syms):
                print("Found orth:", "".join(orth_syms), file=log.v3)
        Stats.orth_syms_set.update(orth_syms)
        Stats.total_orth_len += len(orth_syms)

        # Show some progress if it takes long.
        if time.time() - Stats.process_last_time > 2:
            Stats.process_last_time = time.time()
            if options.collect_time:
                print("Collect process, total frame len so far:",
                      hms(Stats.total_frame_len *
                          (options.frame_time / 1000.0)),
                      file=log.v3)
            else:
                print("Collect process, total orth len so far:",
                      human_size(Stats.total_orth_len),
                      file=log.v3)
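
This callback is handed to an `iter_corpus` helper that invokes it once per corpus segment with the frame length and the orthography string; everything it accumulates lives on the `Stats` class. A minimal self-contained sketch of that contract, with a stub corpus and a character-level split standing in for the real `parse_orthography` (all names here are illustrative, not the RETURNN API):

def parse_orthography(orth):
    # Stand-in: the real parser maps the orthography to a symbol sequence.
    return list(orth)

def iter_corpus(callback):
    # Stub corpus: (frame_len, orthography) pairs.
    for frame_len, orth in [(120, "hello world"), (80, "foo bar")]:
        callback(frame_len, orth)

class Stats:
    count = 0
    total_frame_len = 0
    total_orth_len = 0
    orth_syms_set = set()

def cb(frame_len, orth):
    orth_syms = parse_orthography(orth)
    Stats.count += 1
    Stats.total_frame_len += frame_len
    Stats.orth_syms_set.update(orth_syms)
    Stats.total_orth_len += len(orth_syms)

iter_corpus(cb)
print("seqs:", Stats.count, "frames:", Stats.total_frame_len,
      "distinct symbols:", len(Stats.orth_syms_set))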
Example #2
  def cb(frame_len, orth):
    if frame_len >= options.max_seq_frame_len:
      return
    orth_syms = parse_orthography(orth)
    if len(orth_syms) >= options.max_seq_orth_len:
      return

    Stats.count += 1
    Stats.total_frame_len += frame_len

    if options.dump_orth_syms:
      print >> log.v3, "Orth:", "".join(orth_syms)
    if options.filter_orth_sym:
      if options.filter_orth_sym in orth_syms:
        print >> log.v3, "Found orth:", "".join(orth_syms)
    if options.filter_orth_syms_seq:
      filter_seq = parse_orthography_into_symbols(options.filter_orth_syms_seq)
      if found_sub_seq(filter_seq, orth_syms):
        print >> log.v3, "Found orth:", "".join(orth_syms)
    Stats.orth_syms_set.update(orth_syms)
    Stats.total_orth_len += len(orth_syms)

    # Show some progress if it takes long.
    if time.time() - Stats.process_last_time > 2:
      Stats.process_last_time = time.time()
      if options.collect_time:
        print >> log.v3, "Collect process, total frame len so far:", hms(Stats.total_frame_len * (options.frame_time / 1000.0))
      else:
        print >> log.v3, "Collect process, total orth len so far:", human_size(Stats.total_orth_len)
Example #3
  def _callback(self, orth):
    orth_words = parse_orthography(orth, prefix=[], postfix=[], word_based=True)

    self.seq_count += 1

    if self.options.dump_orth:
      print("Orth:", orth_words, file=log.v3)
    self.words.update(orth_words)
    self.total_word_len += len(orth_words)

    # Show some progress if it takes long.
    if time.time() - self.process_last_time > 2:
      self.process_last_time = time.time()
      print("Collect process, total word len so far:", human_size(self.total_word_len), file=log.v3)
Example #4
  def __init__(self, options, iter_corpus):
    """
    :param options: argparse.Namespace
    """

    self.options = options
    self.seq_count = 0
    self.words = set()
    self.total_word_len = 0
    self.process_last_time = time.time()

    iter_corpus(self._callback)

    print("Total word len:", self.total_word_len, "(%s)" % human_size(self.total_word_len), file=log.v3)
    print("Average orth len:", float(self.total_word_len) / self.seq_count, file=log.v3)
    print("Num word symbols:", len(self.words), file=log.v3)
Example #5
    def _callback(self, orth):
        orth_words = parse_orthography(orth,
                                       prefix=[],
                                       postfix=[],
                                       word_based=True)

        self.seq_count += 1

        if self.options.dump_orth:
            print("Orth:", orth_words, file=log.v3)
        self.words.update(orth_words)
        self.total_word_len += len(orth_words)

        # Show some progress if it takes long.
        if time.time() - self.process_last_time > 2:
            self.process_last_time = time.time()
            print("Collect process, total word len so far:",
                  human_size(self.total_word_len),
                  file=log.v3)
Example #6
    def __init__(self, options, iter_corpus):
        """
    :param options: argparse.Namespace
    """

        self.options = options
        self.seq_count = 0
        self.words = set()
        self.total_word_len = 0
        self.process_last_time = time.time()

        iter_corpus(self._callback)

        print("Total word len:",
              self.total_word_len,
              "(%s)" % human_size(self.total_word_len),
              file=log.v3)
        print("Average orth len:",
              float(self.total_word_len) / self.seq_count,
              file=log.v3)
        print("Num word symbols:", len(self.words), file=log.v3)
Example #7
def hdf_dump_from_dataset(dataset, hdf_dataset, parser_args):
    """
  :param Dataset dataset: could be any dataset implemented as child of Dataset
  :type hdf_dataset: h5py._hl.files.File
  :param parser_args: argparse object from main()
  :return:
  """
    print("Work on epoch: %i" % parser_args.epoch, file=log.v3)
    dataset.init_seq_order(parser_args.epoch)

    data_keys = sorted(dataset.get_data_keys())
    print("Data keys:", data_keys, file=log.v3)
    if "orth" in data_keys:
        data_keys.remove("orth")

    # We need to do one run through the dataset to collect some stats like total len.
    print("Collect stats, iterate through all data...", file=log.v3)
    seq_idx = parser_args.start_seq
    seq_idxs = []
    seq_tags = []
    seq_lens = []
    total_seq_len = NumbersDict(0)
    max_tag_len = 0
    dataset_num_seqs = try_run(lambda: dataset.num_seqs,
                               default=None)  # can be unknown
    if parser_args.end_seq != float("inf"):
        if dataset_num_seqs is not None:
            dataset_num_seqs = min(dataset_num_seqs, parser_args.end_seq)
        else:
            dataset_num_seqs = parser_args.end_seq
    if dataset_num_seqs is not None:
        dataset_num_seqs -= parser_args.start_seq
        assert dataset_num_seqs > 0
    while dataset.is_less_than_num_seqs(
            seq_idx) and seq_idx <= parser_args.end_seq:
        seq_idxs += [seq_idx]
        dataset.load_seqs(seq_idx, seq_idx + 1)
        seq_len = dataset.get_seq_length(seq_idx)
        seq_lens += [seq_len]
        tag = dataset.get_tag(seq_idx)
        seq_tags += [tag]
        max_tag_len = max(len(tag), max_tag_len)
        total_seq_len += seq_len
        if dataset_num_seqs is not None:
            progress_bar_with_time(
                float(seq_idx - parser_args.start_seq) / dataset_num_seqs)
        seq_idx += 1
    num_seqs = len(seq_idxs)

    assert num_seqs > 0
    shapes = {}
    for data_key in data_keys:
        assert data_key in total_seq_len.dict
        shape = [total_seq_len[data_key]]
        shape += dataset.get_data_shape(data_key)
        print("Total len of %r is %s, shape %r, dtype %s" %
              (data_key, human_size(
                  shape[0]), shape, dataset.get_data_dtype(data_key)),
              file=log.v3)
        shapes[data_key] = shape

    print("Set seq tags...", file=log.v3)
    hdf_dataset.create_dataset('seqTags',
                               shape=(num_seqs, ),
                               dtype="S%i" % (max_tag_len + 1))
    for i, tag in enumerate(seq_tags):
        hdf_dataset['seqTags'][i] = numpy.array(tag,
                                                dtype="S%i" %
                                                (max_tag_len + 1))
        progress_bar_with_time(float(i) / num_seqs)

    print("Set seq len info...", file=log.v3)
    hdf_dataset.create_dataset(HDFDataset.attr_seqLengths,
                               shape=(num_seqs, 2),
                               dtype="int32")
    for i, seq_len in enumerate(seq_lens):
        data_len = seq_len["data"]
        targets_len = seq_len["classes"]
        for data_key in dataset.get_target_list():
            if data_key == "orth":
                continue
            assert seq_len[
                data_key] == targets_len, "different lengths in multi-target not supported"
        if targets_len is None:
            targets_len = data_len
        hdf_dataset[HDFDataset.attr_seqLengths][i] = [data_len, targets_len]
        progress_bar_with_time(float(i) / num_seqs)

    print("Create arrays in HDF...", file=log.v3)
    hdf_dataset.create_group('targets/data')
    hdf_dataset.create_group('targets/size')
    hdf_dataset.create_group('targets/labels')
    for data_key in data_keys:
        if data_key == "data":
            hdf_dataset.create_dataset('inputs',
                                       shape=shapes[data_key],
                                       dtype=dataset.get_data_dtype(data_key))
        else:
            hdf_dataset['targets/data'].create_dataset(
                data_key,
                shape=shapes[data_key],
                dtype=dataset.get_data_dtype(data_key))
            hdf_dataset['targets/size'].attrs[data_key] = dataset.num_outputs[
                data_key]

        if data_key in dataset.labels:
            labels = dataset.labels[data_key]
            assert len(labels) == dataset.num_outputs[data_key][0]
        else:
            labels = [
                "%s-class-%i" % (data_key, i)
                for i in range(dataset.get_data_dim(data_key))
            ]
        print("Labels for %s:" % data_key, labels[:3], "...", file=log.v5)
        max_label_len = max(map(len, labels))
        if data_key != "data":
            hdf_dataset['targets/labels'].create_dataset(
                data_key, (len(labels), ), dtype="S%i" % (max_label_len + 1))
            for i, label in enumerate(labels):
                hdf_dataset['targets/labels'][data_key][i] = numpy.array(
                    label, dtype="S%i" % (max_label_len + 1))

    # Again iterate through dataset, and set the data
    print("Write data...", file=log.v3)
    dataset.init_seq_order(parser_args.epoch)
    offsets = NumbersDict(0)
    for seq_idx, tag in zip(seq_idxs, seq_tags):
        dataset.load_seqs(seq_idx, seq_idx + 1)
        tag_ = dataset.get_tag(seq_idx)
        assert tag == tag_  # Just a check for sanity. We expect the same order.
        seq_len = dataset.get_seq_length(seq_idx)
        for data_key in data_keys:
            if data_key == "data":
                hdf_data = hdf_dataset['inputs']
            else:
                hdf_data = hdf_dataset['targets/data'][data_key]
            data = dataset.get_data(seq_idx, data_key)
            hdf_data[offsets[data_key]:offsets[data_key] +
                     seq_len[data_key]] = data

        progress_bar_with_time(float(offsets["data"]) / total_seq_len["data"])

        offsets += seq_len

    assert offsets == total_seq_len  # Sanity check.

    # Set some old-format attribs. Not needed for newer CRNN versions.
    hdf_dataset.attrs[HDFDataset.attr_inputPattSize] = dataset.num_inputs
    hdf_dataset.attrs[HDFDataset.attr_numLabels] = dataset.num_outputs.get(
        "classes", (0, 0))[0]

    print("All done.", file=log.v3)
Example #8
File: hdf_dump.py Project: atuxhe/returnn
def hdf_dump_from_dataset(dataset, hdf_dataset, parser_args):
  """
  :param Dataset dataset: could be any dataset implemented as child of Dataset
  :type hdf_dataset: h5py._hl.files.File
  :param parser_args: argparse object from main()
  :return:
  """
  print >> log.v3, "Work on epoch: %i" % parser_args.epoch
  dataset.init_seq_order(parser_args.epoch)

  data_keys = sorted(dataset.get_data_keys())
  print >> log.v3, "Data keys:", data_keys

  # We need to do one run through the dataset to collect some stats like total len.
  print >> log.v3, "Collect stats, iterate through all data..."
  seq_idx = parser_args.start_seq
  seq_idxs = []
  seq_tags = []
  seq_lens = []
  total_seq_len = NumbersDict(0)
  max_tag_len = 0
  dataset_num_seqs = try_run(lambda: dataset.num_seqs, default=None)  # can be unknown
  if parser_args.end_seq != float("inf"):
    if dataset_num_seqs is not None:
      dataset_num_seqs = min(dataset_num_seqs, parser_args.end_seq)
    else:
      dataset_num_seqs = parser_args.end_seq
  if dataset_num_seqs is not None:
    dataset_num_seqs -= parser_args.start_seq
    assert dataset_num_seqs > 0
  while dataset.is_less_than_num_seqs(seq_idx) and seq_idx <= parser_args.end_seq:
    seq_idxs += [seq_idx]
    dataset.load_seqs(seq_idx, seq_idx + 1)
    seq_len = dataset.get_seq_length(seq_idx)
    seq_lens += [seq_len]
    tag = dataset.get_tag(seq_idx)
    seq_tags += [tag]
    max_tag_len = max(len(tag), max_tag_len)
    total_seq_len += seq_len
    if dataset_num_seqs is not None:
      progress_bar_with_time(float(seq_idx - parser_args.start_seq) / dataset_num_seqs)
    seq_idx += 1
  num_seqs = len(seq_idxs)

  assert num_seqs > 0
  shapes = {}
  for data_key in data_keys:
    assert data_key in total_seq_len.dict
    shape = [total_seq_len[data_key]]
    shape += dataset.get_data_shape(data_key)
    print >> log.v3, "Total len of %r is %s, shape %r, dtype %s" % (
                     data_key, human_size(shape[0]), shape, dataset.get_data_dtype(data_key))
    shapes[data_key] = shape

  print >> log.v3, "Set seq tags..."
  hdf_dataset.create_dataset('seqTags', shape=(num_seqs,), dtype="S%i" % (max_tag_len + 1))
  for i, tag in enumerate(seq_tags):
    hdf_dataset['seqTags'][i] = tag
    progress_bar_with_time(float(i) / num_seqs)

  print >> log.v3, "Set seq len info..."
  hdf_dataset.create_dataset(HDFDataset.attr_seqLengths, shape=(num_seqs, 2), dtype="int32")
  for i, seq_len in enumerate(seq_lens):
    data_len = seq_len["data"]
    targets_len = seq_len["classes"]
    for data_key in dataset.get_target_list():
      assert seq_len[data_key] == targets_len, "different lengths in multi-target not supported"
    if targets_len is None:
      targets_len = data_len
    hdf_dataset[HDFDataset.attr_seqLengths][i] = [data_len, targets_len]
    progress_bar_with_time(float(i) / num_seqs)

  print >> log.v3, "Create arrays in HDF..."
  hdf_dataset.create_group('targets/data')
  hdf_dataset.create_group('targets/size')
  hdf_dataset.create_group('targets/labels')
  for data_key in data_keys:
    if data_key == "data":
      hdf_dataset.create_dataset(
        'inputs', shape=shapes[data_key], dtype=dataset.get_data_dtype(data_key))
    else:
      hdf_dataset['targets/data'].create_dataset(
        data_key, shape=shapes[data_key], dtype=dataset.get_data_dtype(data_key))
      hdf_dataset['targets/size'].attrs[data_key] = dataset.num_outputs[data_key]

    if data_key in dataset.labels:
      labels = dataset.labels[data_key]
      assert len(labels) == dataset.num_outputs[data_key][0]
    else:
      labels = ["%s-class-%i" % (data_key, i) for i in range(dataset.get_data_dim(data_key))]
    print >> log.v5, "Labels for %s:" % data_key, labels[:3], "..."
    max_label_len = max(map(len, labels))
    hdf_dataset['targets/labels'].create_dataset(data_key, (len(labels),), dtype="S%i" % (max_label_len + 1))
    for i, label in enumerate(labels):
      hdf_dataset['targets/labels'][data_key][i] = label

  # Again iterate through dataset, and set the data
  print >> log.v3, "Write data..."
  dataset.init_seq_order(parser_args.epoch)
  offsets = NumbersDict(0)
  for seq_idx, tag in zip(seq_idxs, seq_tags):
    dataset.load_seqs(seq_idx, seq_idx + 1)
    tag_ = dataset.get_tag(seq_idx)
    assert tag == tag_  # Just a check for sanity. We expect the same order.
    seq_len = dataset.get_seq_length(seq_idx)
    for data_key in data_keys:
      if data_key == "data":
        hdf_data = hdf_dataset['inputs']
      else:
        hdf_data = hdf_dataset['targets/data'][data_key]
      data = dataset.get_data(seq_idx, data_key)
      hdf_data[offsets[data_key]:offsets[data_key] + seq_len[data_key]] = data

    progress_bar_with_time(float(offsets["data"]) / total_seq_len["data"])

    offsets += seq_len

  assert offsets == total_seq_len  # Sanity check.

  # Set some old-format attribs. Not needed for newer CRNN versions.
  hdf_dataset.attrs[HDFDataset.attr_inputPattSize] = dataset.num_inputs
  hdf_dataset.attrs[HDFDataset.attr_numLabels] = dataset.num_outputs.get("classes", (0, 0))[0]

  print >> log.v3, "All done."
Example #9
def collect_stats(options, iter_corpus):
    """
  :param options: argparse.Namespace
  """
    orth_symbols_filename = options.output
    if orth_symbols_filename:
        assert not os.path.exists(orth_symbols_filename)

    class Stats:
        count = 0
        process_last_time = time.time()
        total_frame_len = 0
        total_orth_len = 0
        orth_syms_set = set()

    if options.add_numbers:
        Stats.orth_syms_set.update(
            map(chr, list(range(ord("0"),
                                ord("9") + 1))))
    if options.add_lower_alphabet:
        Stats.orth_syms_set.update(
            map(chr, list(range(ord("a"),
                                ord("z") + 1))))
    if options.add_upper_alphabet:
        Stats.orth_syms_set.update(
            map(chr, list(range(ord("A"),
                                ord("Z") + 1))))

    def cb(frame_len, orth):
        if frame_len >= options.max_seq_frame_len:
            return
        orth_syms = parse_orthography(orth)
        if len(orth_syms) >= options.max_seq_orth_len:
            return

        Stats.count += 1
        Stats.total_frame_len += frame_len

        if options.dump_orth_syms:
            print("Orth:", "".join(orth_syms), file=log.v3)
        if options.filter_orth_sym:
            if options.filter_orth_sym in orth_syms:
                print("Found orth:", "".join(orth_syms), file=log.v3)
        if options.filter_orth_syms_seq:
            filter_seq = parse_orthography_into_symbols(
                options.filter_orth_syms_seq)
            if found_sub_seq(filter_seq, orth_syms):
                print("Found orth:", "".join(orth_syms), file=log.v3)
        Stats.orth_syms_set.update(orth_syms)
        Stats.total_orth_len += len(orth_syms)

        # Show some progress if it takes long.
        if time.time() - Stats.process_last_time > 2:
            Stats.process_last_time = time.time()
            if options.collect_time:
                print("Collect process, total frame len so far:",
                      hms(Stats.total_frame_len *
                          (options.frame_time / 1000.0)),
                      file=log.v3)
            else:
                print("Collect process, total orth len so far:",
                      human_size(Stats.total_orth_len),
                      file=log.v3)

    iter_corpus(cb)

    if options.remove_symbols:
        filter_syms = parse_orthography_into_symbols(options.remove_symbols)
        Stats.orth_syms_set -= set(filter_syms)

    if options.collect_time:
        print("Total frame len:",
              Stats.total_frame_len,
              "time:",
              hms(Stats.total_frame_len * (options.frame_time / 1000.0)),
              file=log.v3)
    else:
        print("No time stats (--collect_time False).", file=log.v3)
    print("Total orth len:",
          Stats.total_orth_len,
          "(%s)" % human_size(Stats.total_orth_len),
          end=' ',
          file=log.v3)
    if options.collect_time:
        print("fraction:",
              float(Stats.total_orth_len) / Stats.total_frame_len,
              file=log.v3)
    else:
        print("", file=log.v3)
    print("Average orth len:",
          float(Stats.total_orth_len) / Stats.count,
          file=log.v3)
    print("Num symbols:", len(Stats.orth_syms_set), file=log.v3)

    if orth_symbols_filename:
        orth_syms_file = open(orth_symbols_filename, "wb")
        for orth_sym in sorted(Stats.orth_syms_set):
            orth_syms_file.write("%s\n" % unicode(orth_sym).encode("utf8"))
        orth_syms_file.close()
        print("Wrote orthography symbols to",
              orth_symbols_filename,
              file=log.v3)
    else:
        print("Provide --output to save the symbols.", file=log.v3)
Example #10
def collect_stats(options, iter_corpus):
  """
  :param options: argparse.Namespace
  """
  orth_symbols_filename = options.output
  if orth_symbols_filename:
    assert not os.path.exists(orth_symbols_filename)

  class Stats:
    count = 0
    process_last_time = time.time()
    total_frame_len = 0
    total_orth_len = 0
    orth_syms_set = set()

  if options.add_numbers:
    Stats.orth_syms_set.update(map(chr, range(ord("0"), ord("9") + 1)))
  if options.add_lower_alphabet:
    Stats.orth_syms_set.update(map(chr, range(ord("a"), ord("z") + 1)))
  if options.add_upper_alphabet:
    Stats.orth_syms_set.update(map(chr, range(ord("A"), ord("Z") + 1)))

  def cb(frame_len, orth):
    if frame_len >= options.max_seq_frame_len:
      return
    orth_syms = parse_orthography(orth)
    if len(orth_syms) >= options.max_seq_orth_len:
      return

    Stats.count += 1
    Stats.total_frame_len += frame_len

    if options.dump_orth_syms:
      print >> log.v3, "Orth:", "".join(orth_syms)
    if options.filter_orth_sym:
      if options.filter_orth_sym in orth_syms:
        print >> log.v3, "Found orth:", "".join(orth_syms)
    if options.filter_orth_syms_seq:
      filter_seq = parse_orthography_into_symbols(options.filter_orth_syms_seq)
      if found_sub_seq(filter_seq, orth_syms):
        print >> log.v3, "Found orth:", "".join(orth_syms)
    Stats.orth_syms_set.update(orth_syms)
    Stats.total_orth_len += len(orth_syms)

    # Show some progress if it takes long.
    if time.time() - Stats.process_last_time > 2:
      Stats.process_last_time = time.time()
      if options.collect_time:
        print >> log.v3, "Collect process, total frame len so far:", hms(Stats.total_frame_len * (options.frame_time / 1000.0))
      else:
        print >> log.v3, "Collect process, total orth len so far:", human_size(Stats.total_orth_len)

  iter_corpus(cb)

  if options.remove_symbols:
    filter_syms = parse_orthography_into_symbols(options.remove_symbols)
    Stats.orth_syms_set -= set(filter_syms)

  if options.collect_time:
    print >> log.v3, "Total frame len:", Stats.total_frame_len, "time:", hms(Stats.total_frame_len * (options.frame_time / 1000.0))
  else:
    print >> log.v3, "No time stats (--collect_time False)."
  print >> log.v3, "Total orth len:", Stats.total_orth_len, "(%s)" % human_size(Stats.total_orth_len),
  if options.collect_time:
    print >> log.v3, "fraction:", float(Stats.total_orth_len) / Stats.total_frame_len
  else:
    print >> log.v3, ""
  print >> log.v3, "Average orth len:", float(Stats.total_orth_len) / Stats.count
  print >> log.v3, "Num symbols:", len(Stats.orth_syms_set)

  if orth_symbols_filename:
    orth_syms_file = open(orth_symbols_filename, "wb")
    for orth_sym in sorted(Stats.orth_syms_set):
      orth_syms_file.write("%s\n" % unicode(orth_sym).encode("utf8"))
    orth_syms_file.close()
    print >> log.v3, "Wrote orthography symbols to", orth_symbols_filename
  else:
    print >> log.v3, "Provide --output to save the symbols."
Example #11
  def dump_from_dataset(self, dataset, epoch=1, start_seq=0, end_seq=float("inf"), use_progress_bar=True):
    """
    :param Dataset dataset: could be any dataset implemented as child of Dataset
    :param int epoch: for dataset
    :param int start_seq:
    :param int|float end_seq:
    :param bool use_progress_bar:
    """
    from Util import NumbersDict, human_size, progress_bar_with_time, try_run, PY3
    hdf_dataset = self.file

    print("Work on epoch: %i" % epoch, file=log.v3)
    dataset.init_seq_order(epoch)

    data_keys = sorted(dataset.get_data_keys())
    print("Data keys:", data_keys, file=log.v3)
    if "orth" in data_keys:  # special workaround for now, not handled
      data_keys.remove("orth")
    data_target_keys = [key for key in dataset.get_target_list() if key in data_keys]
    data_input_keys = [key for key in data_keys if key not in data_target_keys]
    assert len(data_input_keys) > 0 and len(data_target_keys) > 0
    if len(data_input_keys) > 1:
      if "data" in data_input_keys:
        default_data_input_key = "data"
      else:
        raise Exception("not sure which input data key to use from %r" % (data_input_keys,))
    else:
      default_data_input_key = data_input_keys[0]
    print("Using input data key:", default_data_input_key)
    if len(data_target_keys) > 1:
      if "classes" in data_target_keys:
        default_data_target_key = "classes"
      else:
        raise Exception("not sure which target data key to use from %r" % (data_target_keys,))
    else:
      default_data_target_key = data_target_keys[0]
    print("Using target data key:", default_data_target_key)

    hdf_data_key_map = {key: key for key in data_keys if key != default_data_input_key}
    if "data" in hdf_data_key_map:
      hdf_data_key_map["data"] = "classes"  # Replace "data" which is reserved for input key in HDFDataset.
      assert "classes" not in hdf_data_key_map

    # We need to do one run through the dataset to collect some stats like total len.
    print("Collect stats, iterate through all data...", file=log.v3)
    seq_idx = start_seq
    seq_idxs = []
    seq_tags = []
    seq_lens = []
    total_seq_len = NumbersDict(0)
    max_tag_len = 0
    dataset_num_seqs = try_run(lambda: dataset.num_seqs, default=None)  # can be unknown
    if end_seq != float("inf"):
      if dataset_num_seqs is not None:
        dataset_num_seqs = min(dataset_num_seqs, end_seq)
      else:
        dataset_num_seqs = end_seq
    if dataset_num_seqs is not None:
      dataset_num_seqs -= start_seq
      assert dataset_num_seqs > 0
    while dataset.is_less_than_num_seqs(seq_idx) and seq_idx <= end_seq:
      seq_idxs += [seq_idx]
      dataset.load_seqs(seq_idx, seq_idx + 1)
      seq_len = dataset.get_seq_length(seq_idx)
      seq_lens += [seq_len]
      tag = dataset.get_tag(seq_idx)
      seq_tags += [tag]
      max_tag_len = max(len(tag), max_tag_len)
      total_seq_len += seq_len
      if use_progress_bar and dataset_num_seqs is not None:
        progress_bar_with_time(float(seq_idx - start_seq) / dataset_num_seqs)
      seq_idx += 1
    num_seqs = len(seq_idxs)

    assert num_seqs > 0
    shapes = {}
    for data_key in data_keys:
      assert data_key in total_seq_len.dict
      shape = [total_seq_len[data_key]]
      shape += dataset.get_data_shape(data_key)
      print("Total len of %r is %s, shape %r, dtype %s" % (
        data_key, human_size(shape[0]), shape, dataset.get_data_dtype(data_key)), file=log.v3)
      shapes[data_key] = shape

    print("Set seq tags...", file=log.v3)
    hdf_dataset.create_dataset('seqTags', shape=(num_seqs,), dtype="S%i" % (max_tag_len + 1))
    for i, tag in enumerate(seq_tags):
      hdf_dataset['seqTags'][i] = numpy.array(tag, dtype="S%i" % (max_tag_len + 1))
      if use_progress_bar:
        progress_bar_with_time(float(i) / num_seqs)

    print("Set seq len info...", file=log.v3)
    hdf_dataset.create_dataset(attr_seqLengths, shape=(num_seqs, 2), dtype="int32")
    for i, seq_len in enumerate(seq_lens):
      data_len = seq_len[default_data_input_key]
      targets_len = seq_len[default_data_target_key]
      for data_key in data_target_keys:
        assert seq_len[data_key] == targets_len, "different lengths in multi-target not supported"
      if targets_len is None:
        targets_len = data_len
      hdf_dataset[attr_seqLengths][i] = [data_len, targets_len]
      if use_progress_bar:
        progress_bar_with_time(float(i) / num_seqs)

    print("Create arrays in HDF...", file=log.v3)
    hdf_dataset.create_group('targets/data')
    hdf_dataset.create_group('targets/size')
    hdf_dataset.create_group('targets/labels')
    for data_key in data_keys:
      if data_key == default_data_input_key:
        hdf_dataset.create_dataset(
          'inputs', shape=shapes[data_key], dtype=dataset.get_data_dtype(data_key))
      else:
        hdf_dataset['targets/data'].create_dataset(
          hdf_data_key_map[data_key], shape=shapes[data_key], dtype=dataset.get_data_dtype(data_key))
        hdf_dataset['targets/size'].attrs[hdf_data_key_map[data_key]] = dataset.num_outputs[data_key]
      if data_key in dataset.labels:
        labels = dataset.labels[data_key]
        if PY3:
          labels = [label.encode("utf8") for label in labels]
        assert len(labels) == dataset.num_outputs[data_key][0]
      else:
        labels = ["%s-class-%i" % (data_key, i) for i in range(dataset.get_data_dim(data_key))]
      print("Labels for %s:" % data_key, labels[:3], "...", file=log.v5)
      max_label_len = max(map(len, labels))
      if data_key != default_data_input_key:
        hdf_dataset['targets/labels'].create_dataset(hdf_data_key_map[data_key],
                                                     (len(labels),), dtype="S%i" % (max_label_len + 1))
        for i, label in enumerate(labels):
          hdf_dataset['targets/labels'][hdf_data_key_map[data_key]][i] = numpy.array(
            label, dtype="S%i" % (max_label_len + 1))

    # Again iterate through dataset, and set the data
    print("Write data...", file=log.v3)
    dataset.init_seq_order(epoch)
    offsets = NumbersDict(0)
    for seq_idx, tag in zip(seq_idxs, seq_tags):
      dataset.load_seqs(seq_idx, seq_idx + 1)
      tag_ = dataset.get_tag(seq_idx)
      assert tag == tag_  # Just a check for sanity. We expect the same order.
      seq_len = dataset.get_seq_length(seq_idx)
      for data_key in data_keys:
        if data_key == default_data_input_key:
          hdf_data = hdf_dataset['inputs']
        else:
          hdf_data = hdf_dataset['targets/data'][hdf_data_key_map[data_key]]
        data = dataset.get_data(seq_idx, data_key)
        hdf_data[offsets[data_key]:offsets[data_key] + seq_len[data_key]] = data

      if use_progress_bar:
        progress_bar_with_time(float(offsets[default_data_input_key]) / total_seq_len[default_data_input_key])

      offsets += seq_len

    assert offsets == total_seq_len  # Sanity check.

    # Set some old-format attribs. Not needed for newer CRNN versions.
    hdf_dataset.attrs[attr_inputPattSize] = dataset.num_inputs
    hdf_dataset.attrs[attr_numLabels] = dataset.num_outputs.get(default_data_target_key, (0, 0))[0]

    print("All done.", file=log.v3)