コード例 #1
0
    def cb(frame_len, orth):
        """Per-sequence callback: skip over-long seqs, log filter matches,
        and accumulate counts into the enclosing Stats class."""
        # Length filters: drop the whole sequence if either limit is hit.
        if frame_len >= options.max_seq_frame_len:
            return
        symbols = parse_orthography(orth)
        if len(symbols) >= options.max_seq_orth_len:
            return

        Stats.count += 1
        Stats.total_frame_len += frame_len

        joined = "".join(symbols)
        if options.dump_orth_syms:
            print("Orth:", joined, file=log.v3)
        if options.filter_orth_sym and options.filter_orth_sym in symbols:
            print("Found orth:", joined, file=log.v3)
        if options.filter_orth_syms_seq:
            needle = parse_orthography_into_symbols(options.filter_orth_syms_seq)
            if found_sub_seq(needle, symbols):
                print("Found orth:", joined, file=log.v3)
        Stats.orth_syms_set.update(symbols)
        Stats.total_orth_len += len(symbols)

        # Progress report, throttled to at most once every 2 seconds.
        if time.time() - Stats.process_last_time > 2:
            Stats.process_last_time = time.time()
            if options.collect_time:
                elapsed = hms(Stats.total_frame_len * (options.frame_time / 1000.0))
                print("Collect process, total frame len so far:", elapsed,
                      file=log.v3)
            else:
                print("Collect process, total orth len so far:",
                      human_size(Stats.total_orth_len), file=log.v3)
コード例 #2
0
  def cb(frame_len, orth):
    """
    Callback for the corpus iterator: filters out over-long sequences and
    accumulates frame/orthography-symbol statistics into the Stats class.

    :param int frame_len: sequence length in frames
    :param str orth: orthography (transcription) of the sequence
    """
    if frame_len >= options.max_seq_frame_len:
      return
    orth_syms = parse_orthography(orth)
    if len(orth_syms) >= options.max_seq_orth_len:
      return

    Stats.count += 1
    Stats.total_frame_len += frame_len

    # Py3 print() function instead of the Py2-only "print >> file" statement.
    if options.dump_orth_syms:
      print("Orth:", "".join(orth_syms), file=log.v3)
    if options.filter_orth_sym:
      if options.filter_orth_sym in orth_syms:
        print("Found orth:", "".join(orth_syms), file=log.v3)
    if options.filter_orth_syms_seq:
      filter_seq = parse_orthography_into_symbols(options.filter_orth_syms_seq)
      if found_sub_seq(filter_seq, orth_syms):
        print("Found orth:", "".join(orth_syms), file=log.v3)
    Stats.orth_syms_set.update(orth_syms)
    Stats.total_orth_len += len(orth_syms)

    # Show some progress if it takes long.
    if time.time() - Stats.process_last_time > 2:
      Stats.process_last_time = time.time()
      if options.collect_time:
        # frame_time is in ms (see the /1000.0 conversion), then format via hms().
        print("Collect process, total frame len so far:",
              hms(Stats.total_frame_len * (options.frame_time / 1000.0)),
              file=log.v3)
      else:
        print("Collect process, total orth len so far:",
              human_size(Stats.total_orth_len), file=log.v3)
コード例 #3
0
ファイル: LmDataset.py プロジェクト: atuxhe/returnn
  def __init__(self,
               corpus_file, phone_info=None, orth_symbols_file=None, orth_replace_map_file=None,
               add_random_phone_seqs=0,
               partition_epoch=1,
               log_skipped_seqs=False, **kwargs):
    """
    :param str corpus_file: Bliss XML or line-based txt. optionally can be gzip.
    :param dict | None phone_info: if you want to get phone seqs, dict with lexicon_file etc. see _PhoneSeqGenerator
    :param str | None orth_symbols_file: list of orthography symbols, if you want to get orth symbol seqs
    :param str | None orth_replace_map_file: JSON file with replacement dict for orth symbols
    :param int add_random_phone_seqs: will add random seqs with the same len as the real seq as additional data
    :param int partition_epoch: split the epoch into this many parts
    :param bool log_skipped_seqs: log skipped seqs
    """
    super(LmDataset, self).__init__(**kwargs)

    if orth_symbols_file:
      # Orth-symbol mode: the label set is read from the symbols file.
      assert not phone_info
      with open(orth_symbols_file) as f:  # with-block so the handle is closed
        orth_symbols = f.read().splitlines()
      self.orth_symbols_map = {sym: i for (i, sym) in enumerate(orth_symbols)}
      self.orth_symbols = orth_symbols
      self.labels["data"] = orth_symbols
      self.seq_gen = None
    else:
      # Phone-seq mode: labels come from the phone sequence generator.
      assert isinstance(phone_info, dict)
      self.seq_gen = _PhoneSeqGenerator(**phone_info)
      self.orth_symbols = None
      self.labels["data"] = self.seq_gen.get_class_labels()
    if orth_replace_map_file:
      orth_replace_map = load_json(filename=orth_replace_map_file)
      assert isinstance(orth_replace_map, dict)
      self.orth_replace_map = {key: parse_orthography_into_symbols(v)
                               for (key, v) in orth_replace_map.items()}
    else:
      self.orth_replace_map = {}

    # int8 can only hold values up to 127, i.e. at most 2**7 label indices.
    # (A "<= 256" bound would overflow int8 for label indices >= 128.)
    if len(self.labels["data"]) <= 2 ** 7:
      self.dtype = "int8"
    else:
      self.dtype = "int32"
    self.num_outputs = {"data": [len(self.labels["data"]), 1]}
    self.num_inputs = self.num_outputs["data"][0]
    self.seq_order = None
    self.log_skipped_seqs = log_skipped_seqs
    self.partition_epoch = partition_epoch
    self.add_random_phone_seqs = add_random_phone_seqs
    for i in range(add_random_phone_seqs):
      self.num_outputs["random%i" % i] = self.num_outputs["data"]

    if _is_bliss(corpus_file):
      iter_f = _iter_bliss
    else:
      iter_f = _iter_txt
    self.orths = []
    # Py3 print() function instead of the Py2-only "print >> file" statement.
    print("LmDataset, loading file", corpus_file, file=log.v4)
    iter_f(corpus_file, self.orths.append)
    # It's only estimated because we might filter some out or so.
    self._estimated_num_seqs = len(self.orths) // self.partition_epoch
コード例 #4
0
def collect_stats(options, iter_corpus):
    """
    Iterate over the corpus, collect orthography-symbol and frame-length
    statistics, print a summary, and optionally write the collected symbol
    set to ``options.output``.

    :param options: argparse.Namespace
    :param iter_corpus: function taking a callback ``cb(frame_len, orth)``
    """
    orth_symbols_filename = options.output
    if orth_symbols_filename:
        assert not os.path.exists(orth_symbols_filename)

    class Stats:
        # Accumulators shared by the cb() closure below.
        count = 0
        process_last_time = time.time()
        total_frame_len = 0
        total_orth_len = 0
        orth_syms_set = set()

    # Optionally pre-seed the symbol set with digits / letters.
    if options.add_numbers:
        Stats.orth_syms_set.update(map(chr, range(ord("0"), ord("9") + 1)))
    if options.add_lower_alphabet:
        Stats.orth_syms_set.update(map(chr, range(ord("a"), ord("z") + 1)))
    if options.add_upper_alphabet:
        Stats.orth_syms_set.update(map(chr, range(ord("A"), ord("Z") + 1)))

    def cb(frame_len, orth):
        # Skip sequences exceeding either length limit.
        if frame_len >= options.max_seq_frame_len:
            return
        orth_syms = parse_orthography(orth)
        if len(orth_syms) >= options.max_seq_orth_len:
            return

        Stats.count += 1
        Stats.total_frame_len += frame_len

        if options.dump_orth_syms:
            print("Orth:", "".join(orth_syms), file=log.v3)
        if options.filter_orth_sym:
            if options.filter_orth_sym in orth_syms:
                print("Found orth:", "".join(orth_syms), file=log.v3)
        if options.filter_orth_syms_seq:
            filter_seq = parse_orthography_into_symbols(
                options.filter_orth_syms_seq)
            if found_sub_seq(filter_seq, orth_syms):
                print("Found orth:", "".join(orth_syms), file=log.v3)
        Stats.orth_syms_set.update(orth_syms)
        Stats.total_orth_len += len(orth_syms)

        # Show some progress if it takes long.
        if time.time() - Stats.process_last_time > 2:
            Stats.process_last_time = time.time()
            if options.collect_time:
                print("Collect process, total frame len so far:",
                      hms(Stats.total_frame_len *
                          (options.frame_time / 1000.0)),
                      file=log.v3)
            else:
                print("Collect process, total orth len so far:",
                      human_size(Stats.total_orth_len),
                      file=log.v3)

    iter_corpus(cb)

    if options.remove_symbols:
        filter_syms = parse_orthography_into_symbols(options.remove_symbols)
        Stats.orth_syms_set -= set(filter_syms)

    if options.collect_time:
        # options.frame_time is in ms; convert to seconds for hms().
        print("Total frame len:",
              Stats.total_frame_len,
              "time:",
              hms(Stats.total_frame_len * (options.frame_time / 1000.0)),
              file=log.v3)
    else:
        print("No time stats (--collect_time False).", file=log.v3)
    print("Total orth len:",
          Stats.total_orth_len,
          "(%s)" % human_size(Stats.total_orth_len),
          end=' ',
          file=log.v3)
    if options.collect_time:
        print("fraction:",
              float(Stats.total_orth_len) / Stats.total_frame_len,
              file=log.v3)
    else:
        print("", file=log.v3)
    print("Average orth len:",
          float(Stats.total_orth_len) / Stats.count,
          file=log.v3)
    print("Num symbols:", len(Stats.orth_syms_set), file=log.v3)

    if orth_symbols_filename:
        # Fix: the old code did unicode(orth_sym).encode("utf8") into a "wb"
        # file -- unicode() does not exist in Python 3. Write text with an
        # explicit encoding instead, and close the file via a with-block.
        with open(orth_symbols_filename, "w", encoding="utf8") as orth_syms_file:
            for orth_sym in sorted(Stats.orth_syms_set):
                orth_syms_file.write("%s\n" % orth_sym)
        print("Wrote orthography symbols to",
              orth_symbols_filename,
              file=log.v3)
    else:
        print("Provide --output to save the symbols.", file=log.v3)
コード例 #5
0
def collect_stats(options, iter_corpus):
  """
  Iterate over the corpus, collect orthography-symbol and frame-length
  statistics, print a summary, and optionally write the collected symbol
  set to ``options.output``.

  :param options: argparse.Namespace
  :param iter_corpus: function taking a callback ``cb(frame_len, orth)``
  """
  orth_symbols_filename = options.output
  if orth_symbols_filename:
    assert not os.path.exists(orth_symbols_filename)

  class Stats:
    # Accumulators shared by the cb() closure below.
    count = 0
    process_last_time = time.time()
    total_frame_len = 0
    total_orth_len = 0
    orth_syms_set = set()

  # Optionally pre-seed the symbol set with digits / letters.
  if options.add_numbers:
    Stats.orth_syms_set.update(map(chr, range(ord("0"), ord("9") + 1)))
  if options.add_lower_alphabet:
    Stats.orth_syms_set.update(map(chr, range(ord("a"), ord("z") + 1)))
  if options.add_upper_alphabet:
    Stats.orth_syms_set.update(map(chr, range(ord("A"), ord("Z") + 1)))

  def cb(frame_len, orth):
    # Skip sequences exceeding either length limit.
    if frame_len >= options.max_seq_frame_len:
      return
    orth_syms = parse_orthography(orth)
    if len(orth_syms) >= options.max_seq_orth_len:
      return

    Stats.count += 1
    Stats.total_frame_len += frame_len

    # Py3 print() function instead of the Py2-only "print >> file" statement.
    if options.dump_orth_syms:
      print("Orth:", "".join(orth_syms), file=log.v3)
    if options.filter_orth_sym:
      if options.filter_orth_sym in orth_syms:
        print("Found orth:", "".join(orth_syms), file=log.v3)
    if options.filter_orth_syms_seq:
      filter_seq = parse_orthography_into_symbols(options.filter_orth_syms_seq)
      if found_sub_seq(filter_seq, orth_syms):
        print("Found orth:", "".join(orth_syms), file=log.v3)
    Stats.orth_syms_set.update(orth_syms)
    Stats.total_orth_len += len(orth_syms)

    # Show some progress if it takes long.
    if time.time() - Stats.process_last_time > 2:
      Stats.process_last_time = time.time()
      if options.collect_time:
        print("Collect process, total frame len so far:",
              hms(Stats.total_frame_len * (options.frame_time / 1000.0)),
              file=log.v3)
      else:
        print("Collect process, total orth len so far:",
              human_size(Stats.total_orth_len), file=log.v3)

  iter_corpus(cb)

  if options.remove_symbols:
    filter_syms = parse_orthography_into_symbols(options.remove_symbols)
    Stats.orth_syms_set -= set(filter_syms)

  if options.collect_time:
    # options.frame_time is in ms; convert to seconds for hms().
    print("Total frame len:", Stats.total_frame_len, "time:",
          hms(Stats.total_frame_len * (options.frame_time / 1000.0)), file=log.v3)
  else:
    print("No time stats (--collect_time False).", file=log.v3)
  # end=' ' replicates the Py2 trailing-comma print (no newline).
  print("Total orth len:", Stats.total_orth_len,
        "(%s)" % human_size(Stats.total_orth_len), end=' ', file=log.v3)
  if options.collect_time:
    print("fraction:", float(Stats.total_orth_len) / Stats.total_frame_len, file=log.v3)
  else:
    print("", file=log.v3)
  print("Average orth len:", float(Stats.total_orth_len) / Stats.count, file=log.v3)
  print("Num symbols:", len(Stats.orth_syms_set), file=log.v3)

  if orth_symbols_filename:
    # Fix: the old code did unicode(orth_sym).encode("utf8") into a "wb" file --
    # unicode() does not exist in Python 3. Write text with an explicit
    # encoding instead, and close the file via a with-block.
    with open(orth_symbols_filename, "w", encoding="utf8") as orth_syms_file:
      for orth_sym in sorted(Stats.orth_syms_set):
        orth_syms_file.write("%s\n" % orth_sym)
    print("Wrote orthography symbols to", orth_symbols_filename, file=log.v3)
  else:
    print("Provide --output to save the symbols.", file=log.v3)
コード例 #6
0
ファイル: LmDataset.py プロジェクト: sharmaannapurna/returnn
    def __init__(self,
                 corpus_file,
                 orth_symbols_file=None,
                 orth_symbols_map_file=None,
                 orth_replace_map_file=None,
                 word_based=False,
                 seq_end_symbol="[END]",
                 unknown_symbol="[UNKNOWN]",
                 parse_orth_opts=None,
                 phone_info=None,
                 add_random_phone_seqs=0,
                 partition_epoch=1,
                 auto_replace_unknown_symbol=False,
                 log_auto_replace_unknown_symbols=10,
                 log_skipped_seqs=10,
                 error_on_invalid_seq=True,
                 add_delayed_seq_data=False,
                 delayed_seq_data_start_symbol="[START]",
                 **kwargs):
        """
        :param str|()->str corpus_file: Bliss XML or line-based txt. optionally can be gzip.
        :param dict|None phone_info: if you want to get phone seqs, dict with lexicon_file etc. see PhoneSeqGenerator
        :param str|()->str|None orth_symbols_file: list of orthography symbols, if you want to get orth symbol seqs
        :param str|()->str|None orth_symbols_map_file: list of orth symbols, each line: "symbol index"
        :param str|()->str|None orth_replace_map_file: JSON file with replacement dict for orth symbols
        :param bool word_based: whether to parse single words, or otherwise will be char-based
        :param str|None seq_end_symbol: what to add at the end, if given.
          will be set as postfix=[seq_end_symbol] or postfix=[] for parse_orth_opts.
        :param str|None unknown_symbol: presumably the replacement token used when
          auto_replace_unknown_symbol is enabled -- only stored here; confirm against usage elsewhere.
        :param dict[str]|None parse_orth_opts: kwargs for parse_orthography()
        :param int add_random_phone_seqs: will add random seqs with the same len as the real seq as additional data
        :param bool auto_replace_unknown_symbol: replace unknown symbols (stored here; used outside __init__)
        :param bool|int log_auto_replace_unknown_symbols: write about auto-replacements with unknown symbol.
          if this is an int, it will only log the first N replacements, and then keep quiet.
        :param bool|int log_skipped_seqs: write about skipped seqs to logging, due to missing lexicon entry or so.
          if this is an int, it will only log the first N entries, and then keep quiet.
        :param bool error_on_invalid_seq: if there is a seq we would have to skip, error
        :param bool add_delayed_seq_data: will add another data-key "delayed" which will have the sequence
          delayed_seq_data_start_symbol + original_sequence[:-1]
        :param str delayed_seq_data_start_symbol: used for add_delayed_seq_data
        :param int partition_epoch: whether to partition the epochs into multiple parts. like epoch_split
        """
        super(LmDataset, self).__init__(**kwargs)

        # The file arguments may be given as callables for lazy evaluation.
        if callable(corpus_file):
            corpus_file = corpus_file()
        if callable(orth_symbols_file):
            orth_symbols_file = orth_symbols_file()
        if callable(orth_symbols_map_file):
            orth_symbols_map_file = orth_symbols_map_file()
        if callable(orth_replace_map_file):
            orth_replace_map_file = orth_replace_map_file()

        print("LmDataset, loading file", corpus_file, file=log.v4)

        self.word_based = word_based
        self.seq_end_symbol = seq_end_symbol
        self.unknown_symbol = unknown_symbol
        self.parse_orth_opts = parse_orth_opts or {}
        self.parse_orth_opts.setdefault("word_based", self.word_based)
        # Append the seq-end symbol (if any) to every parsed sequence.
        self.parse_orth_opts.setdefault(
            "postfix",
            [self.seq_end_symbol] if self.seq_end_symbol is not None else [])

        if orth_symbols_file:
            # Label set read directly from the symbols file, one per line.
            assert not phone_info
            assert not orth_symbols_map_file
            with open(orth_symbols_file) as f:  # close the handle explicitly
                orth_symbols = f.read().splitlines()
            self.orth_symbols_map = {
                sym: i
                for (i, sym) in enumerate(orth_symbols)
            }
            self.orth_symbols = orth_symbols
            self.labels["data"] = orth_symbols
            self.seq_gen = None
        elif orth_symbols_map_file:
            # Each line is "symbol index"; build (index, symbol) pairs.
            assert not phone_info
            with open(orth_symbols_map_file) as f:
                orth_symbols_imap_list = [(int(b), a) for (a, b) in [
                    l.split(None, 1)
                    for l in f.read().splitlines()
                ]]
            orth_symbols_imap_list.sort()
            # Indices must form a dense 0..N-1 range.
            assert orth_symbols_imap_list[0][0] == 0
            assert orth_symbols_imap_list[-1][0] == len(
                orth_symbols_imap_list) - 1
            self.orth_symbols_map = {
                sym: i
                for (i, sym) in orth_symbols_imap_list
            }
            self.orth_symbols = [sym for (i, sym) in orth_symbols_imap_list]
            self.labels["data"] = self.orth_symbols
            self.seq_gen = None
        else:
            # Phone-seq mode: labels come from the phone sequence generator.
            assert isinstance(phone_info, dict)
            self.seq_gen = PhoneSeqGenerator(**phone_info)
            self.orth_symbols = None
            self.labels["data"] = self.seq_gen.get_class_labels()
        if orth_replace_map_file:
            orth_replace_map = load_json(filename=orth_replace_map_file)
            assert isinstance(orth_replace_map, dict)
            self.orth_replace_map = {
                key: parse_orthography_into_symbols(v,
                                                    word_based=self.word_based)
                for (key, v) in orth_replace_map.items()
            }
            if self.orth_replace_map:
                if len(self.orth_replace_map) <= 5:
                    print("  orth_replace_map: %r" % self.orth_replace_map,
                          file=log.v5)
                else:
                    print("  orth_replace_map: %i entries" %
                          len(self.orth_replace_map),
                          file=log.v5)
        else:
            self.orth_replace_map = {}

        # Pick the smallest integer dtype that can hold all label indices.
        # Unsigned dtypes are only used with the TF backend.
        num_labels = len(self.labels["data"])
        use_uint_types = False
        if BackendEngine.is_tensorflow_selected():
            use_uint_types = True
        if num_labels <= 2**7:
            self.dtype = "int8"
        elif num_labels <= 2**8 and use_uint_types:
            self.dtype = "uint8"
        elif num_labels <= 2**31:
            self.dtype = "int32"
        elif num_labels <= 2**32 and use_uint_types:
            self.dtype = "uint32"
        elif num_labels <= 2**61:
            self.dtype = "int64"
        elif num_labels <= 2**62 and use_uint_types:
            self.dtype = "uint64"
        else:
            raise Exception("cannot handle so much labels: %i" % num_labels)
        self.num_outputs = {"data": [len(self.labels["data"]), 1]}
        self.num_inputs = self.num_outputs["data"][0]
        self.seq_order = None
        self.auto_replace_unknown_symbol = auto_replace_unknown_symbol
        self.log_auto_replace_unknown_symbols = log_auto_replace_unknown_symbols
        self.log_skipped_seqs = log_skipped_seqs
        self.error_on_invalid_seq = error_on_invalid_seq
        self.partition_epoch = partition_epoch
        self.add_random_phone_seqs = add_random_phone_seqs
        for i in range(add_random_phone_seqs):
            self.num_outputs["random%i" % i] = self.num_outputs["data"]
        self.add_delayed_seq_data = add_delayed_seq_data
        self.delayed_seq_data_start_symbol = delayed_seq_data_start_symbol
        if add_delayed_seq_data:
            self.num_outputs["delayed"] = self.num_outputs["data"]

        if _is_bliss(corpus_file):
            iter_f = _iter_bliss
        else:
            iter_f = _iter_txt
        self.orths = []
        iter_f(corpus_file, self.orths.append)
        # It's only estimated because we might filter some out or so.
        self._estimated_num_seqs = len(self.orths) // self.partition_epoch
        print("  done, loaded %i sequences" % len(self.orths), file=log.v4)