Example #1
    def cb(frame_len, orth):
        if frame_len >= options.max_seq_frame_len:
            return
        orth_syms = parse_orthography(orth)
        if len(orth_syms) >= options.max_seq_orth_len:
            return

        Stats.count += 1
        Stats.total_frame_len += frame_len

        if options.dump_orth_syms:
            print("Orth:", "".join(orth_syms), file=log.v3)
        if options.filter_orth_sym:
            if options.filter_orth_sym in orth_syms:
                print("Found orth:", "".join(orth_syms), file=log.v3)
        if options.filter_orth_syms_seq:
            filter_seq = parse_orthography_into_symbols(
                options.filter_orth_syms_seq)
            if found_sub_seq(filter_seq, orth_syms):
                print("Found orth:", "".join(orth_syms), file=log.v3)
        Stats.orth_syms_set.update(orth_syms)
        Stats.total_orth_len += len(orth_syms)

        # Show some progress if it takes long.
        if time.time() - Stats.process_last_time > 2:
            Stats.process_last_time = time.time()
            if options.collect_time:
                print("Collect process, total frame len so far:",
                      hms(Stats.total_frame_len *
                          (options.frame_time / 1000.0)),
                      file=log.v3)
            else:
                print("Collect process, total orth len so far:",
                      human_size(Stats.total_orth_len),
                      file=log.v3)
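Example #1 (and its Python 2 counterpart, Example #2) is a statistics callback: sequences that are too long in frames or in orthography symbols are skipped, everything else is counted, and progress is reported at most every two seconds. Below is a self-contained sketch of that filter-and-accumulate pattern; the limits and data are made up, it takes already-parsed symbols instead of calling parse_orthography, and it prints to stdout instead of log.v3.

import time

MAX_FRAME_LEN = 2000  # made-up stand-in for options.max_seq_frame_len
MAX_ORTH_LEN = 500    # made-up stand-in for options.max_seq_orth_len

class Stats:
    count = 0
    total_frame_len = 0
    total_orth_len = 0
    orth_syms_set = set()
    process_last_time = time.time()

def cb(frame_len, orth_syms):
    # Skip sequences that exceed either limit, like the callback above.
    if frame_len >= MAX_FRAME_LEN or len(orth_syms) >= MAX_ORTH_LEN:
        return
    Stats.count += 1
    Stats.total_frame_len += frame_len
    Stats.orth_syms_set.update(orth_syms)
    Stats.total_orth_len += len(orth_syms)
    # Report progress at most every 2 seconds.
    if time.time() - Stats.process_last_time > 2:
        Stats.process_last_time = time.time()
        print("total orth len so far:", Stats.total_orth_len)

cb(120, list("hello world"))
cb(80, list("good morning"))
print("seqs kept:", Stats.count, "distinct symbols:", len(Stats.orth_syms_set))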
Example #2
  def cb(frame_len, orth):
    if frame_len >= options.max_seq_frame_len:
      return
    orth_syms = parse_orthography(orth)
    if len(orth_syms) >= options.max_seq_orth_len:
      return

    Stats.count += 1
    Stats.total_frame_len += frame_len

    if options.dump_orth_syms:
      print >> log.v3, "Orth:", "".join(orth_syms)
    if options.filter_orth_sym:
      if options.filter_orth_sym in orth_syms:
        print >> log.v3, "Found orth:", "".join(orth_syms)
    if options.filter_orth_syms_seq:
      filter_seq = parse_orthography_into_symbols(options.filter_orth_syms_seq)
      if found_sub_seq(filter_seq, orth_syms):
        print >> log.v3, "Found orth:", "".join(orth_syms)
    Stats.orth_syms_set.update(orth_syms)
    Stats.total_orth_len += len(orth_syms)

    # Show some progress if it takes long.
    if time.time() - Stats.process_last_time > 2:
      Stats.process_last_time = time.time()
      if options.collect_time:
        print >> log.v3, "Collect process, total frame len so far:", hms(Stats.total_frame_len * (options.frame_time / 1000.0))
      else:
        print >> log.v3, "Collect process, total orth len so far:", human_size(Stats.total_orth_len)
Example #3
  def _collect_single_seq(self, seq_idx):
    """
    :type seq_idx: int
    :rtype: DatasetSeq | None
    :returns DatasetSeq or None if seq_idx >= num_seqs.
    """
    while True:
      if self.next_orth_idx >= len(self.orths_epoch):
        assert self.next_seq_idx <= seq_idx, "We expect that we iterate through all seqs."
        return None
      assert self.next_seq_idx == seq_idx, "We expect that we iterate through all seqs."
      orth = self.orths_epoch[self.seq_order[self.next_orth_idx]]
      self.next_orth_idx += 1
      if orth == "</s>": continue  # special sentence end symbol. empty seq, ignore.

      if self.seq_gen:
        try:
          phones = self.seq_gen.generate_seq(orth)
        except KeyError as e:
          if self.log_skipped_seqs:
            print >> log.v4, "LmDataset: skipping sequence %r because of missing lexicon entry: %s" % (
                             orth, e)
          self.num_skipped += 1
          continue
        data = self.seq_gen.seq_to_class_idxs(phones, dtype=self.dtype)

      elif self.orth_symbols:
        orth_syms = parse_orthography(orth)
        orth_syms = sum([self.orth_replace_map.get(s, [s]) for s in orth_syms], [])
        i = 0
        while i < len(orth_syms) - 1:
          if orth_syms[i:i+2] == [" ", " "]:
            orth_syms[i:i+2] = [" "]  # collapse two spaces
          else:
            i += 1
        try:
          data = numpy.array(map(self.orth_symbols_map.__getitem__, orth_syms), dtype=self.dtype)
        except KeyError as e:
          if self.log_skipped_seqs:
            print >> log.v4, "LmDataset: skipping sequence %r because of missing orth symbol: %s" % (
                             "".join(orth_syms), e)
          self.num_skipped += 1
          continue

      else:
        assert False

      targets = {}
      for i in range(self.add_random_phone_seqs):
        assert self.seq_gen  # not implemented atm for orths
        phones = self.seq_gen.generate_garbage_seq(target_len=data.shape[0])
        targets["random%i" % i] = self.seq_gen.seq_to_class_idxs(phones, dtype=self.dtype)
      self._num_timesteps_accumulated += data.shape[0]
      self.next_seq_idx = seq_idx + 1
      return DatasetSeq(seq_idx=seq_idx, features=data, targets=targets)
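Example #3 targets Python 2: `print >> log.v4` and the bare `map(...)` passed to `numpy.array` only work there. Under Python 3, `map` returns a lazy iterator, so it has to be materialized before a missing symbol can raise KeyError and before NumPy can build the index array. A minimal sketch of the Python 3 equivalent of that single step, with a made-up symbol map:

import numpy

orth_symbols_map = {"a": 0, "b": 1, " ": 2}  # made-up mapping for illustration
orth_syms = ["a", " ", "b"]

# list(...) forces the lookups, so an unknown symbol raises KeyError here,
# and numpy.array receives a concrete list of class indices.
data = numpy.array(list(map(orth_symbols_map.__getitem__, orth_syms)), dtype="int32")
print(data)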
Example #4
  def _callback(self, orth):
    orth_words = parse_orthography(orth, prefix=[], postfix=[], word_based=True)

    self.seq_count += 1

    if self.options.dump_orth:
      print("Orth:", orth_words, file=log.v3)
    self.words.update(orth_words)
    self.total_word_len += len(orth_words)

    # Show some progress if it takes long.
    if time.time() - self.process_last_time > 2:
      self.process_last_time = time.time()
      print("Collect process, total word len so far:", human_size(self.total_word_len), file=log.v3)
Example #5
    def _callback(self, orth):
        orth_words = parse_orthography(orth,
                                       prefix=[],
                                       postfix=[],
                                       word_based=True)

        self.seq_count += 1

        if self.options.dump_orth:
            print("Orth:", orth_words, file=log.v3)
        self.words.update(orth_words)
        self.total_word_len += len(orth_words)

        # Show some progress if it takes long.
        if time.time() - self.process_last_time > 2:
            self.process_last_time = time.time()
            print("Collect process, total word len so far:",
                  human_size(self.total_word_len),
                  file=log.v3)
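Examples #4 and #5 are the same word-based callback, only formatted differently: the orthography is parsed into words, the word inventory and the total word count are accumulated, and progress is printed periodically. A minimal sketch of that accumulation pattern, assuming self.words is a collections.Counter (the surrounding class is not shown, so this is an assumption):

import collections

words = collections.Counter()
total_word_len = 0

for orth_words in (["hello", "world"], ["hello", "again"]):  # made-up parsed output
    words.update(orth_words)          # count how often each word occurs
    total_word_len += len(orth_words)  # running total of words seen

print(words.most_common(2), total_word_len)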
Example #6
    def _collect_single_seq(self, seq_idx):
        """
        :type seq_idx: int
        :rtype: DatasetSeq | None
        :returns DatasetSeq or None if seq_idx >= num_seqs.
        """
        while True:
            if self.next_orth_idx >= len(self.orths_epoch):
                assert self.next_seq_idx <= seq_idx, "We expect that we iterate through all seqs."
                if self.num_skipped > 0:
                    print("LmDataset: reached end, skipped %i sequences" %
                          self.num_skipped)
                return None
            assert self.next_seq_idx == seq_idx, "We expect that we iterate through all seqs."
            orth = self.orths_epoch[self.seq_order[self.next_orth_idx]]
            self.next_orth_idx += 1
            if orth == "</s>":
                continue  # special sentence end symbol. empty seq, ignore.

            if self.seq_gen:
                try:
                    phones = self.seq_gen.generate_seq(orth)
                except KeyError as e:
                    if self.log_skipped_seqs:
                        print(
                            "LmDataset: skipping sequence %r because of missing lexicon entry: %s"
                            % (orth, e),
                            file=log.v4)
                        self._reduce_log_skipped_seqs()
                    if self.error_on_invalid_seq:
                        raise Exception(
                            "LmDataset: invalid seq %r, missing lexicon entry %r"
                            % (orth, e))
                    self.num_skipped += 1
                    continue  # try another seq
                data = self.seq_gen.seq_to_class_idxs(phones, dtype=self.dtype)

            elif self.orth_symbols:
                orth_syms = parse_orthography(orth, **self.parse_orth_opts)
                while True:
                    orth_syms = sum(
                        [self.orth_replace_map.get(s, [s]) for s in orth_syms],
                        [])
                    i = 0
                    while i < len(orth_syms) - 1:
                        if orth_syms[i:i + 2] == [" ", " "]:
                            orth_syms[i:i + 2] = [" "]  # collapse two spaces
                        else:
                            i += 1
                    if self.auto_replace_unknown_symbol:
                        try:
                            # list(...) forces the lookups; in Python 3 a bare map()
                            # is lazy and would never raise KeyError here.
                            list(map(self.orth_symbols_map.__getitem__, orth_syms))
                        except KeyError as e:
                            orth_sym = e.args[0]  # e.message does not exist in Python 3
                            if self.log_auto_replace_unknown_symbols:
                                print(
                                    "LmDataset: unknown orth symbol %r, adding to orth_replace_map as %r"
                                    % (orth_sym, self.unknown_symbol),
                                    file=log.v3)
                                self._reduce_log_auto_replace_unknown_symbols()
                            self.orth_replace_map[orth_sym] = [
                                self.unknown_symbol
                            ] if self.unknown_symbol is not None else []
                            continue  # try this seq again with updated orth_replace_map
                    break
                self.num_unknown += orth_syms.count(self.unknown_symbol)
                if self.word_based:
                    orth_debug_str = repr(orth_syms)
                else:
                    orth_debug_str = repr("".join(orth_syms))
                try:
                    data = numpy.array(list(map(self.orth_symbols_map.__getitem__, orth_syms)),
                                       dtype=self.dtype)
                except KeyError as e:
                    if self.log_skipped_seqs:
                        print(
                            "LmDataset: skipping sequence %s because of missing orth symbol: %s"
                            % (orth_debug_str, e),
                            file=log.v4)
                        self._reduce_log_skipped_seqs()
                    if self.error_on_invalid_seq:
                        raise Exception(
                            "LmDataset: invalid seq %s, missing orth symbol %s"
                            % (orth_debug_str, e))
                    self.num_skipped += 1
                    continue  # try another seq

            else:
                assert False

            targets = {}
            for i in range(self.add_random_phone_seqs):
                assert self.seq_gen  # not implemented atm for orths
                phones = self.seq_gen.generate_garbage_seq(
                    target_len=data.shape[0])
                targets["random%i" % i] = self.seq_gen.seq_to_class_idxs(
                    phones, dtype=self.dtype)
            if self.add_delayed_seq_data:
                targets["delayed"] = numpy.concatenate(([
                    self.orth_symbols_map[self.delayed_seq_data_start_symbol]
                ], data[:-1])).astype(self.dtype)
                assert targets["delayed"].shape == data.shape
            self.next_seq_idx = seq_idx + 1
            return DatasetSeq(seq_idx=seq_idx, features=data, targets=targets)
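The trickiest part of Examples #3 and #6 is the inner loop that applies orth_replace_map, collapses the double spaces this can create, and retries the sequence after registering an unknown symbol. Below is a standalone sketch of just the replacement and space-collapsing steps, with made-up data; it is not the original implementation.

orth_replace_map = {"[NOISE]": []}  # made-up rule: drop noise markers entirely
orth_syms = ["h", "i", " ", "[NOISE]", " ", "y", "o", "u"]

# Expand each symbol via the replace map (missing keys keep the symbol as-is),
# then flatten the per-symbol lists into one flat symbol list.
orth_syms = sum([orth_replace_map.get(s, [s]) for s in orth_syms], [])

# Collapse runs of two spaces that the replacement may have created.
i = 0
while i < len(orth_syms) - 1:
    if orth_syms[i:i + 2] == [" ", " "]:
        orth_syms[i:i + 2] = [" "]
    else:
        i += 1

print("".join(orth_syms))  # -> "hi you"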