def all_allophone_variations(self, phon, states=None, all_boundary_variations=False):
    """
    Generate the allophone states of ``phon`` for every combination of left
    context, right context, state and boundary value.

    :param str phon: phone identity, stored as ``id`` on every yielded state
    :param None|list[int] states: which states to yield for this phone.
      ``None`` means ``range(self.num_states_for_phone(phon))``.
    :param bool all_boundary_variations: if true, every combination is yielded
      once per boundary value 0, 1, 2 and 3. Otherwise the boundary starts at 0
      and ``mark_initial()`` / ``mark_final()`` are applied when the left /
      right context is empty.
    :return: yields AllophoneState's (this is a generator, not a list)
    :rtype: typing.Iterator[AllophoneState]
    """
    state_indices = range(self.num_states_for_phone(phon)) if states is None else states
    boundaries = [0, 1, 2, 3] if all_boundary_variations else [0]
    derive_marks = not all_boundary_variations
    for history in self._iter_possible_ctx(phon, -1):
        # The right contexts are requested anew for every left context.
        for future in self._iter_possible_ctx(phon, 1):
            for state_idx in state_indices:
                for boundary_value in boundaries:
                    allo = AllophoneState()
                    allo.id = phon
                    allo.context_history = history
                    allo.context_future = future
                    allo.state = state_idx
                    allo.boundary = boundary_value
                    if derive_marks:
                        if not history:
                            allo.mark_initial()
                        if not future:
                            allo.mark_final()
                    yield allo
示例#2
0
 def _allos_add_states(self, allos):
     for _a in allos:
         if _a.id == self.si_phone:
             yield _a
         else:  # non-silence
             for state in range(self.allo_num_states):
                 a = AllophoneState()
                 a.id = _a.id
                 a.context_history = _a.context_history
                 a.context_future = _a.context_future
                 a.boundary = _a.boundary
                 a.state = state
                 yield a
 def all_allophone_variations(self, phon, states=None):
     """
     Generate the allophone states of ``phon`` for every combination of left
     context, right context and state.

     The boundary of each yielded state is a bit mask: bit 1 is set when the
     left context is empty (initial), bit 2 when the right context is empty
     (final).

     :param phon: phone identity, stored as ``id`` on every yielded state
     :param states: states to yield; ``None`` means ``range(self._num_states(phon))``
     :return: yields AllophoneState's
     """
     state_indices = range(self._num_states(phon)) if states is None else states
     for history in self._iter_possible_ctx(phon, -1):
         for future in self._iter_possible_ctx(phon, 1):
             for state_idx in state_indices:
                 allo = AllophoneState()
                 allo.id = phon
                 allo.context_history = history
                 allo.context_future = future
                 allo.state = state_idx
                 boundary_bits = 0
                 if not history:
                     boundary_bits |= 1  # initial
                 if not future:
                     boundary_bits |= 2  # final
                 allo.boundary = boundary_bits
                 yield allo
 def _allos_add_states(self, allos):
   for _a in allos:
     if _a.id == self.si_phone:
       yield _a
     else:  # non-silence
       for state in range(self.allo_num_states):
         a = AllophoneState()
         a.id = _a.id
         a.context_history = _a.context_history
         a.context_future = _a.context_future
         a.boundary = _a.boundary
         a.state = state
         yield a
示例#5
0
def main():
    """
    Command-line entry point: collect an allophone-state -> label-index mapping
    and optionally write it out as a state-tying file.

    Steps, as implemented below:

    1. Parse the command line and load the optional dataset, corpus and lexicon.
    2. With ``--action show_corpus``, pretty-print the corpus and return.
    3. Depending on ``--state_tying_type``, either take the expected number of
       labels from the ``OrthHandler`` ("monophone"), or enumerate all allophone
       variations of every lexicon phoneme and index them ("full").
    4. If a dataset was given, pair the targets of each segment (passed through
       ``uniq`` and with the silence label removed) with the allophone states
       of the segment's orthography, and record which allophone belongs to
       which label index.
    5. With ``--state_tying_output``, write one ``"<allophone> <index>"`` line
       per recorded allophone.

    :raises Exception: if ``--state_tying_type`` is neither "monophone" nor "full"
    """
    arg_parser = ArgumentParser()
    arg_parser.add_argument("--action")
    arg_parser.add_argument("--print_seq", action='store_true')
    arg_parser.add_argument("--print_allos", action='store_true')
    arg_parser.add_argument("--print_targets", action='store_true')
    arg_parser.add_argument("--dataset")
    arg_parser.add_argument("--corpus")
    arg_parser.add_argument("--lexicon", help="filename")
    arg_parser.add_argument("--silence", type=int, help="index")
    arg_parser.add_argument("--context", default=1, type=int)
    arg_parser.add_argument("--hmm_states", default=3, type=int)
    arg_parser.add_argument("--state_tying_type", help="'monophone' or 'full'")
    arg_parser.add_argument("--state_tying_output", help="filename")
    arg_parser.add_argument("--allo_add_all", action="store_true")
    args = arg_parser.parse_args()

    # Each resource is only loaded when its option was given; otherwise None.
    dataset = init_dataset_via_str(
        config_str=args.dataset) if args.dataset else None
    corpus = dict(iter_bliss_orth(
        filename=args.corpus)) if args.corpus else None
    lexicon = Lexicon(filename=args.lexicon) if args.lexicon else None
    silence_label = args.silence

    if args.action == "show_corpus":
        pprint(corpus)
        return

    # NOTE: the lexicon is dereferenced unconditionally from here on, so
    # --lexicon is effectively required for everything except show_corpus.
    print("Num phones: %i" % len(lexicon.phonemes), file=log.v1)
    print("Phones: %r" % sorted(lexicon.phonemes.keys()), file=log.v1)

    orth_handler = OrthHandler(lexicon=lexicon,
                               allo_context_len=args.context,
                               allo_num_states=args.hmm_states)
    map_idx_to_allo = defaultdict(set)  # type: dict[int, set[AllophoneState]]
    map_allo_to_idx = {}  # type: dict[AllophoneState, int]
    if args.allo_add_all:
        orth_handler.allo_add_all = True

    print("Num HMM states: %i" % orth_handler.allo_num_states, file=log.v1)
    if args.state_tying_type == "monophone":
        print("Monophone state tying.", file=log.v1)
        num_labels = orth_handler.expected_num_labels_for_monophone_state_tying(
        )
        all_label_idx_are_used = True
    elif args.state_tying_type == "full":
        print("Full state tying.", file=log.v1)
        phone_idxs = {k: i + 1
                      for (i, k) in enumerate(lexicon.phoneme_list)
                      }  # +1 to keep 0 reserved as the term-symbol
        for phon in lexicon.phoneme_list:
            for allo in orth_handler.all_allophone_variations(
                    phon, all_boundary_variations=True):
                allo_idx = allo.index(
                    phone_idxs=phone_idxs,
                    num_states=orth_handler.allo_num_states,
                    context_length=orth_handler.allo_context_len)
                map_idx_to_allo[allo_idx].add(allo)
        # The label count is derived from the largest index seen, so not every
        # index below it has to be populated.
        num_labels = max(map_idx_to_allo.keys()) + 1
        all_label_idx_are_used = False
    else:
        raise Exception("invalid state tying type %r" % args.state_tying_type)
    print("Num labels: %i" % num_labels, file=log.v1)

    if dataset:
        count = 0
        for segment_name, targets in iter_dataset_targets(dataset):
            count += 1
            # Always runs for the first segment; silence_label is set afterwards,
            # so later segments skip this.
            if silence_label is None or count == 1:
                # The most frequent target of this segment is the silence candidate.
                likely_silence_label = collections.Counter(
                    targets).most_common(1)[0][0]
                if silence_label is None:
                    silence_label = likely_silence_label
                if silence_label != likely_silence_label:
                    print("warning: silence %i but likely %i" %
                          (silence_label, likely_silence_label),
                          file=log.v2)
                print("Silence label: %i" % silence_label, file=log.v1)
                orth_handler.si_label = silence_label
                # Monophone state tying:
                for allo in orth_handler.all_allophone_variations(
                        orth_handler.si_phone):
                    map_idx_to_allo[silence_label].add(allo)
                    map_allo_to_idx[allo] = silence_label
            # NOTE: needs --corpus; with corpus=None this membership test raises TypeError.
            assert segment_name in corpus
            orth = corpus[segment_name]
            allo_states = orth_handler.orth_to_allophone_states(orth=orth)
            if args.print_seq:
                print("%r %r" % (segment_name, orth))
            if args.print_allos:
                print("  allophone state seq: %r" % allo_states)
            tgt_seq = [t for t in uniq(targets) if t != silence_label]
            if args.print_targets:
                print("  target seq: %r" % (tgt_seq, ))
            # Allophone states and targets are paired position by position below.
            assert len(allo_states) == len(tgt_seq), "check --hmm_states or so"
            for allo, t in zip(allo_states, tgt_seq):
                allo.boundary = 0  # do not differ between boundaries
                allos = map_idx_to_allo[t]
                if allo in map_allo_to_idx:
                    assert allo in allos, "bad mapping"
                else:
                    assert allo not in allos
                    allos.add(allo)
                    map_allo_to_idx[allo] = t
            # Stop early once every label index has at least one allophone.
            if len(map_idx_to_allo) >= num_labels:
                assert len(map_idx_to_allo) == num_labels
                assert 0 in map_idx_to_allo
                assert num_labels - 1 in map_idx_to_allo
                print("Finished with uniq mapping after %i sequences." % count,
                      file=log.v1)
                break
            if count % 100 == 0:
                print("Have indices: %i (num labels: %i)" %
                      (len(map_idx_to_allo), num_labels),
                      file=log.v1)

        print("Finished. Have indices: %i (num labels: %i)" %
              (len(map_idx_to_allo), num_labels),
              file=log.v1)
        if len(map_idx_to_allo) < num_labels:
            # Report which phonemes (probed via their state-0 allophone) were seen.
            found = []
            not_found = []
            for p in sorted(lexicon.phonemes.keys()):
                allo = AllophoneState(p, state=0)
                if allo in map_allo_to_idx:
                    found.append(p)
                else:
                    not_found.append(p)
            print("Phonemes found: %r" % found)
            print("Phonemes not found: %r" % not_found)

    if args.state_tying_output:
        # Refuse to overwrite an existing file.
        assert not os.path.exists(args.state_tying_output)
        if all_label_idx_are_used:
            assert len(map_idx_to_allo) == num_labels
            assert 0 in map_idx_to_allo
            assert num_labels - 1 in map_idx_to_allo
        # The context manager closes the file even if allo.format() or a write raises.
        with open(args.state_tying_output, "w") as f:
            for i, allos in sorted(map_idx_to_allo.items()):
                for allo in allos:
                    f.write("%s %i\n" % (allo.format(), i))
        print("Wrote state tying to %r." % args.state_tying_output,
              file=log.v1)

    print("The end.")
示例#6
0
 def _phones_to_allos(self, phones):
     for p in phones:
         a = AllophoneState()
         a.id = p
         yield a
示例#7
0
 def all_allophone_variations(self, phon, states=None, all_boundary_variations=False):
     """
     Enumerate AllophoneState's for ``phon`` over all left contexts, right
     contexts, states and boundary values.

     :param str phon: phone identity, stored as ``id`` on every yielded state
     :param None|list[int] states: which states to yield for this phone;
       defaults to ``range(self.num_states_for_phone(phon))``
     :param bool all_boundary_variations: yield each combination with the
       boundary values 0, 1, 2, 3. If false, the boundary starts at 0 and the
       state is marked initial / final when its left / right context is empty.
     :return: yields AllophoneState's (generator)
     :rtype: typing.Iterator[AllophoneState]
     """
     if states is None:
         states = range(self.num_states_for_phone(phon))
     boundary_variations = (0, 1, 2, 3) if all_boundary_variations else (0,)
     for ctx_left in self._iter_possible_ctx(phon, -1):
         for ctx_right in self._iter_possible_ctx(phon, 1):
             for hmm_state in states:
                 for boundary_flags in boundary_variations:
                     variation = AllophoneState()
                     variation.id = phon
                     variation.context_history = ctx_left
                     variation.context_future = ctx_right
                     variation.state = hmm_state
                     variation.boundary = boundary_flags
                     # Marks are only derived from the contexts when the
                     # boundary values are not enumerated explicitly.
                     if not (all_boundary_variations or ctx_left):
                         variation.mark_initial()
                     if not (all_boundary_variations or ctx_right):
                         variation.mark_final()
                     yield variation
def main():
  """
  Command-line entry point: collect an allophone-state -> label-index mapping
  and optionally write it out as a state-tying file.

  Steps, as implemented below:

  1. Parse the command line and load the optional dataset, corpus and lexicon.
  2. With ``--action show_corpus``, pretty-print the corpus and return.
  3. Depending on ``--state_tying_type``, either take the expected number of
     labels from the ``OrthHandler`` ("monophone"), or enumerate all allophone
     variations of every lexicon phoneme and index them ("full").
  4. If a dataset was given, pair the targets of each segment (passed through
     ``uniq`` and with the silence label removed) with the allophone states of
     the segment's orthography, and record which allophone belongs to which
     label index.
  5. With ``--state_tying_output``, write one ``"<allophone> <index>"`` line
     per recorded allophone.

  :raises Exception: if ``--state_tying_type`` is neither "monophone" nor "full"
  """
  arg_parser = ArgumentParser()
  arg_parser.add_argument("--action")
  arg_parser.add_argument("--print_seq", action='store_true')
  arg_parser.add_argument("--print_allos", action='store_true')
  arg_parser.add_argument("--print_targets", action='store_true')
  arg_parser.add_argument("--dataset")
  arg_parser.add_argument("--corpus")
  arg_parser.add_argument("--lexicon", help="filename")
  arg_parser.add_argument("--silence", type=int, help="index")
  arg_parser.add_argument("--context", default=1, type=int)
  arg_parser.add_argument("--hmm_states", default=3, type=int)
  arg_parser.add_argument("--state_tying_type", help="'monophone' or 'full'")
  arg_parser.add_argument("--state_tying_output", help="filename")
  arg_parser.add_argument("--allo_add_all", action="store_true")
  args = arg_parser.parse_args()

  # Each resource is only loaded when its option was given; otherwise None.
  dataset = init_dataset(config_str=args.dataset) if args.dataset else None
  corpus = dict(iter_bliss_orth(filename=args.corpus)) if args.corpus else None
  lexicon = Lexicon(filename=args.lexicon) if args.lexicon else None
  silence_label = args.silence

  if args.action == "show_corpus":
    pprint(corpus)
    return

  # NOTE: the lexicon is dereferenced unconditionally from here on, so
  # --lexicon is effectively required for everything except show_corpus.
  print("Num phones: %i" % len(lexicon.phonemes), file=log.v1)
  print("Phones: %r" % sorted(lexicon.phonemes.keys()), file=log.v1)

  orth_handler = OrthHandler(
    lexicon=lexicon,
    allo_context_len=args.context,
    allo_num_states=args.hmm_states)
  map_idx_to_allo = defaultdict(set)  # type: dict[int, set[AllophoneState]]
  map_allo_to_idx = {}  # type: dict[AllophoneState, int]
  if args.allo_add_all:
    orth_handler.allo_add_all = True

  print("Num HMM states: %i" % orth_handler.allo_num_states, file=log.v1)
  if args.state_tying_type == "monophone":
    print("Monophone state tying.", file=log.v1)
    num_labels = orth_handler.expected_num_labels_for_monophone_state_tying()
    all_label_idx_are_used = True
  elif args.state_tying_type == "full":
    print("Full state tying.", file=log.v1)
    phone_idxs = {k: i + 1 for (i, k) in enumerate(lexicon.phoneme_list)}  # +1 to keep 0 reserved as the term-symbol
    for phon in lexicon.phoneme_list:
      for allo in orth_handler.all_allophone_variations(phon, all_boundary_variations=True):
        allo_idx = allo.index(
          phone_idxs=phone_idxs,
          num_states=orth_handler.allo_num_states,
          context_length=orth_handler.allo_context_len)
        map_idx_to_allo[allo_idx].add(allo)
    # The label count is derived from the largest index seen, so not every
    # index below it has to be populated.
    num_labels = max(map_idx_to_allo.keys()) + 1
    all_label_idx_are_used = False
  else:
    raise Exception("invalid state tying type %r" % args.state_tying_type)
  print("Num labels: %i" % num_labels, file=log.v1)

  if dataset:
    count = 0
    for segment_name, targets in iter_dataset_targets(dataset):
      count += 1
      # Always runs for the first segment; silence_label is set afterwards,
      # so later segments skip this.
      if silence_label is None or count == 1:
        # The most frequent target of this segment is the silence candidate.
        likely_silence_label = collections.Counter(targets).most_common(1)[0][0]
        if silence_label is None:
          silence_label = likely_silence_label
        if silence_label != likely_silence_label:
          print("warning: silence %i but likely %i" % (silence_label, likely_silence_label), file=log.v2)
        print("Silence label: %i" % silence_label, file=log.v1)
        orth_handler.si_label = silence_label
        # Monophone state tying:
        for allo in orth_handler.all_allophone_variations(orth_handler.si_phone):
          map_idx_to_allo[silence_label].add(allo)
          map_allo_to_idx[allo] = silence_label
      # NOTE: needs --corpus; with corpus=None this membership test raises TypeError.
      assert segment_name in corpus
      orth = corpus[segment_name]
      allo_states = orth_handler.orth_to_allophone_states(orth=orth)
      if args.print_seq:
        print("%r %r" % (segment_name, orth))
      if args.print_allos:
        print("  allophone state seq: %r" % allo_states)
      tgt_seq = [t for t in uniq(targets) if t != silence_label]
      if args.print_targets:
        print("  target seq: %r" % (tgt_seq,))
      # Allophone states and targets are paired position by position below.
      assert len(allo_states) == len(tgt_seq), "check --hmm_states or so"
      for allo, t in zip(allo_states, tgt_seq):
        allo.boundary = 0  # do not differ between boundaries
        allos = map_idx_to_allo[t]
        if allo in map_allo_to_idx:
          assert allo in allos, "bad mapping"
        else:
          assert allo not in allos
          allos.add(allo)
          map_allo_to_idx[allo] = t
      # Stop early once every label index has at least one allophone.
      if len(map_idx_to_allo) >= num_labels:
        assert len(map_idx_to_allo) == num_labels
        assert 0 in map_idx_to_allo
        assert num_labels - 1 in map_idx_to_allo
        print("Finished with uniq mapping after %i sequences." % count, file=log.v1)
        break
      if count % 100 == 0:
        print("Have indices: %i (num labels: %i)" % (len(map_idx_to_allo), num_labels), file=log.v1)

    print("Finished. Have indices: %i (num labels: %i)" % (len(map_idx_to_allo), num_labels), file=log.v1)
    if len(map_idx_to_allo) < num_labels:
      # Report which phonemes (probed via their state-0 allophone) were seen.
      found = []
      not_found = []
      for p in sorted(lexicon.phonemes.keys()):
        allo = AllophoneState(p, state=0)
        if allo in map_allo_to_idx:
          found.append(p)
        else:
          not_found.append(p)
      print("Phonemes found: %r" % found)
      print("Phonemes not found: %r" % not_found)

  if args.state_tying_output:
    # Refuse to overwrite an existing file.
    assert not os.path.exists(args.state_tying_output)
    if all_label_idx_are_used:
      assert len(map_idx_to_allo) == num_labels
      assert 0 in map_idx_to_allo
      assert num_labels - 1 in map_idx_to_allo
    # The context manager closes the file even if allo.format() or a write raises.
    with open(args.state_tying_output, "w") as f:
      for i, allos in sorted(map_idx_to_allo.items()):
        for allo in allos:
          f.write("%s %i\n" % (allo.format(), i))
    print("Wrote state tying to %r." % args.state_tying_output, file=log.v1)

  print("The end.")
 def _phones_to_allos(self, phones):
   for p in phones:
     a = AllophoneState()
     a.id = p
     yield a
def main():
    """
    Command-line entry point: collect a monophone allophone-state -> label-index
    mapping from a dataset and optionally write it out as a state-tying file.

    Steps, as implemented below:

    1. Parse the command line and load the optional dataset, corpus and lexicon.
    2. With ``--action show_corpus``, pretty-print the corpus and return.
    3. Take the expected number of labels from the ``OrthHandler`` (monophone
       state tying is assumed).
    4. Pair the targets of each segment (passed through ``uniq`` and with the
       silence label removed) with the allophone states of the segment's
       orthography, and record which allophone belongs to which label index.
    5. With ``--state_tying_output``, require that every label index maps to
       exactly one (phone, state) pair and write all allophone variations of
       that pair as ``"<allophone> <index>"`` lines.
    """
    arg_parser = ArgumentParser()
    arg_parser.add_argument("--action")
    arg_parser.add_argument("--print_seq", action='store_true')
    arg_parser.add_argument("--print_allos", action='store_true')
    arg_parser.add_argument("--print_targets", action='store_true')
    arg_parser.add_argument("--dataset")
    arg_parser.add_argument("--corpus")
    arg_parser.add_argument("--lexicon")
    arg_parser.add_argument("--silence", type=int)
    arg_parser.add_argument("--context", default=1, type=int)
    arg_parser.add_argument("--hmm_states", default=3, type=int)
    arg_parser.add_argument("--state_tying_output")
    arg_parser.add_argument("--allo_add_all", action="store_true")
    args = arg_parser.parse_args()

    # Each resource is only loaded when its option was given; otherwise None.
    dataset = init_dataset_via_str(
        config_str=args.dataset) if args.dataset else None
    corpus = dict(iter_bliss_orth(
        filename=args.corpus)) if args.corpus else None
    lexicon = Lexicon(filename=args.lexicon) if args.lexicon else None
    silence_label = args.silence

    if args.action == "show_corpus":
        pprint(corpus)
        return

    # NOTE: the lexicon is dereferenced unconditionally from here on, so
    # --lexicon is effectively required for everything except show_corpus.
    print("Num phones: %i" % len(lexicon.phonemes), file=log.v1)
    print("Phones: %r" % sorted(lexicon.phonemes.keys()), file=log.v1)

    orth_handler = OrthHandler(lexicon=lexicon,
                               allo_context_len=args.context,
                               allo_num_states=args.hmm_states)
    map_idx_to_allo = defaultdict(set)  # type: dict[int, set[AllophoneState]]
    map_allo_to_idx = {}  # type: dict[AllophoneState, int]
    if args.allo_add_all:
        orth_handler.allo_add_all = True

    # NOTE: Assume monophone state tying for now!
    num_labels = orth_handler.expected_num_labels_for_monophone_state_tying()
    print("Num labels: %i" % num_labels, file=log.v1)

    count = 0
    # NOTE(review): dataset is passed on even when --dataset was not given
    # (None); presumably --dataset is required here - confirm.
    for segment_name, targets in iter_dataset_targets(dataset):
        count += 1
        # Always runs for the first segment; silence_label is set afterwards,
        # so later segments skip this.
        if silence_label is None or count == 1:
            # The most frequent target of this segment is the silence candidate.
            likely_silence_label = collections.Counter(targets).most_common(
                1)[0][0]
            if silence_label is None:
                silence_label = likely_silence_label
            if silence_label != likely_silence_label:
                print("warning: silence %i but likely %i" %
                      (silence_label, likely_silence_label),
                      file=log.v2)
            print("Silence label: %i" % silence_label, file=log.v1)
            orth_handler.si_label = silence_label
            # Monophone state tying:
            for allo in orth_handler.all_allophone_variations(
                    orth_handler.si_phone):
                map_idx_to_allo[silence_label].add(allo)
                map_allo_to_idx[allo] = silence_label
        # NOTE: needs --corpus; with corpus=None this membership test raises TypeError.
        assert segment_name in corpus
        orth = corpus[segment_name]
        allo_states = orth_handler.orth_to_allophone_states(orth=orth)
        if args.print_seq:
            print("%r %r" % (segment_name, orth))
        if args.print_allos:
            print("  allophone state seq: %r" % allo_states)
        tgt_seq = [t for t in uniq(targets) if t != silence_label]
        if args.print_targets:
            print("  target seq: %r" % (tgt_seq, ))
        # Allophone states and targets are paired position by position below.
        assert len(allo_states) == len(tgt_seq), "check --hmm_states or so"
        for allo, t in zip(allo_states, tgt_seq):
            allo.boundary = 0  # do not differ between boundaries
            allos = map_idx_to_allo[t]
            if allo in map_allo_to_idx:
                assert allo in allos, "bad mapping"
            else:
                assert allo not in allos
                allos.add(allo)
                map_allo_to_idx[allo] = t
        # Stop early once every label index has at least one allophone.
        if len(map_idx_to_allo) >= num_labels:
            assert len(map_idx_to_allo) == num_labels
            assert 0 in map_idx_to_allo
            assert num_labels - 1 in map_idx_to_allo
            print("Finished with uniq mapping after %i sequences." % count,
                  file=log.v1)
            break
        if count % 100 == 0:
            print("Have indices: %i (num labels: %i)" %
                  (len(map_idx_to_allo), num_labels),
                  file=log.v1)

    print("Finished. Have indices: %i (num labels: %i)" %
          (len(map_idx_to_allo), num_labels),
          file=log.v1)
    if len(map_idx_to_allo) < num_labels:
        # Report which phonemes (probed via their state-0 allophone) were seen.
        found = []
        not_found = []
        for p in sorted(lexicon.phonemes.keys()):
            allo = AllophoneState(p, state=0)
            if allo in map_allo_to_idx:
                found.append(p)
            else:
                not_found.append(p)
        print("Phonemes found: %r" % found)
        print("Phonemes not found: %r" % not_found)

    if args.state_tying_output:
        # Refuse to overwrite an existing file, and require a complete mapping.
        assert not os.path.exists(args.state_tying_output)
        assert len(map_idx_to_allo) == num_labels
        assert 0 in map_idx_to_allo
        assert num_labels - 1 in map_idx_to_allo
        # The context manager closes the file even if an assertion or a write
        # inside the loop raises.
        with open(args.state_tying_output, "w") as f:
            for i in range(num_labels):
                # Under monophone tying, all allophones of one label must share
                # a single (phone, state) pair.
                phons = sorted(
                    {(allo.id, allo.state) for allo in map_idx_to_allo[i]})
                assert len(phons) == 1
                phon, state = phons[0]
                for allo in orth_handler.all_allophone_variations(
                        phon, states=[state]):
                    f.write("%s %i\n" % (allo, i))
        print("Wrote state tying to %r." % args.state_tying_output,
              file=log.v1)