Example #1
def init(config_filename, cmd_line_opts, dataset_config_str):
  """
  :param str config_filename: global config for CRNN
  :param list[str] cmd_line_opts: options for initConfig method
  :param str dataset_config_str: dataset via init_dataset_via_str()
  """
  rnn.initBetterExchook()
  rnn.initThreadJoinHack()
  if config_filename:
    rnn.initConfig(config_filename, cmd_line_opts)
    rnn.initLog()
  else:
    log.initialize(verbosity=[5])
  print >> log.v3, "CRNN dump-dataset starting up."
  rnn.initFaulthandler()
  rnn.initConfigJsonNetwork()
  if config_filename:
    rnn.initData()
    rnn.printTaskProperties()
    assert isinstance(rnn.train_data, Dataset)
    return rnn.train_data
  else:
    assert dataset_config_str
    dataset = init_dataset_via_str(dataset_config_str)
    print >> log.v3, "Source dataset:", dataset.len_info()
    return dataset
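Example #1 is written against CRNN (the old name of RETURNN), so rnn, log, Dataset, and init_dataset_via_str are that project's module-level helpers. A minimal sketch of how such an init() is typically driven, assuming a hypothetical entry point where the first positional argument is the config file and the remaining arguments are forwarded options:

import sys

def main(argv):
    # Hypothetical driver: argv[1] is the CRNN config file, any remaining
    # arguments are passed through unchanged to rnn.initConfig().
    assert len(argv) >= 2, "usage: %s <crnn-config> [options...]" % argv[0]
    dataset = init(config_filename=argv[1], cmd_line_opts=argv[2:],
                   dataset_config_str="")
    print("Loaded dataset: %r" % dataset)

if __name__ == "__main__":
    main(sys.argv)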
Example #2
def init(config_filename, cmd_line_opts, dataset_config_str):
    """
  :param str config_filename: global config for CRNN
  :param list[str] cmd_line_opts: options for initConfig method
  :param str dataset_config_str: dataset via init_dataset_via_str()
  """
    rnn.initBetterExchook()
    rnn.initThreadJoinHack()
    if config_filename:
        rnn.initConfig(config_filename, cmd_line_opts)
        rnn.initLog()
    else:
        log.initialize(verbosity=[5])
    print("Returnn hdf_dump starting up.", file=log.v3)
    rnn.initFaulthandler()
    if config_filename:
        rnn.initData()
        rnn.printTaskProperties()
        assert isinstance(rnn.train_data, Dataset)
        return rnn.train_data
    else:
        assert dataset_config_str
        dataset = init_dataset_via_str(dataset_config_str)
        print("Source dataset:", dataset.len_info(), file=log.v3)
        return dataset
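Example #2 is the Python 3 variant from a Returnn hdf_dump style tool; it differs from Example #1 mainly in the missing rnn.initConfigJsonNetwork() call. Running without a config file exercises the other branch, where the dataset is described entirely by a string for init_dataset_via_str(); the string syntax is defined by that helper, so the value below is a hypothetical placeholder:

# No config file: everything comes from the dataset string.
dataset = init(config_filename="", cmd_line_opts=[],
               dataset_config_str="<dataset-spec-for-init_dataset_via_str>")
print("Source dataset:", dataset.len_info())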
Example #3
def load_data(config, cache_byte_size, files_config_key, **kwargs):
  """
  :param Config config:
  :param int cache_byte_size:
  :param str files_config_key: such as "train" or "dev"
  :param kwargs: passed on to init_dataset() or init_dataset_via_str()
  :rtype: (Dataset,int)
  :returns: the dataset, and the cache byte size left over if we cache the whole dataset.
  """
  if not config.bool_or_other(files_config_key, None):
    return None, 0
  kwargs = kwargs.copy()
  kwargs.setdefault("name", files_config_key)
  if config.is_typed(files_config_key) and isinstance(config.typed_value(files_config_key), dict):
    config_opts = config.typed_value(files_config_key)
    assert isinstance(config_opts, dict)
    kwargs.update(config_opts)
    if 'cache_byte_size' not in config_opts:
      if kwargs.get('class', None) == 'HDFDataset':
        kwargs["cache_byte_size"] = cache_byte_size
    Dataset.kwargs_update_from_config(config, kwargs)
    data = init_dataset(kwargs)
  else:
    config_str = config.value(files_config_key, "")
    data = init_dataset_via_str(config_str, config=config, cache_byte_size=cache_byte_size, **kwargs)
  cache_leftover = 0
  if isinstance(data, HDFDataset):
    cache_leftover = data.definite_cache_leftover
  return data, cache_leftover
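load_data() accepts two shapes for the config entry. If the entry is a typed Python dict, it is merged into kwargs and its "class" key selects the dataset type. A minimal sketch of such an entry in a RETURNN-style Python config, with placeholder file names:

# Inside the (Python) config file: a typed dataset description.
# Because "class" is "HDFDataset" and no explicit "cache_byte_size" is set,
# load_data() fills in the caller's cache budget automatically.
train = {
    "class": "HDFDataset",
    "files": ["/path/to/train.0001.hdf"],  # placeholder
}

# At the call site:
data, cache_leftover = load_data(config, cache_byte_size=1024 ** 3,
                                 files_config_key="train")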
Example #4
def load_data(self, configT, cache_byte_size, files_config_key, **kwargs):
  """
  :param Config configT:
  :param int cache_byte_size:
  :param str files_config_key: such as "train" or "dev"
  :param kwargs: passed on to init_dataset() or init_dataset_via_str()
  :rtype: (Dataset,int)
  :returns: the dataset, and the cache byte size left over if we cache the whole dataset.
  """
  if not configT.has(files_config_key):
    return None, 0
  kwargs = kwargs.copy()
  kwargs.setdefault("name", files_config_key)
  # error somewhere here
  if configT.is_typed(files_config_key) and isinstance(configT.typed_value(files_config_key), dict):
    config_opts = configT.typed_value(files_config_key)
    assert isinstance(config_opts, dict)
    kwargs.update(config_opts)
    if 'cache_byte_size' not in config_opts:
      if kwargs.get('class', None) == 'HDFDataset':
        kwargs["cache_byte_size"] = cache_byte_size
    Dataset.kwargs_update_from_config(configT, kwargs)
    data = init_dataset(kwargs)
  else:
    config_str = configT.value(files_config_key, "")
    data = init_dataset_via_str(config_str, config=configT, cache_byte_size=cache_byte_size, **kwargs)
  cache_leftover = 0
  if isinstance(data, HDFDataset):
    cache_leftover = data.definite_cache_leftover
  return data, cache_leftover
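All of these load_data() variants assume the same small query interface on the config object. A stand-in for illustration only (the real Config class does much more):

class StubConfig:
    """Illustrative stand-in for the Config interface load_data() relies on."""

    def __init__(self, entries):
        self._entries = entries  # maps key -> raw string or typed object

    def has(self, key):
        return key in self._entries

    def is_typed(self, key):
        # "Typed" means the entry is a Python object rather than a raw string.
        return not isinstance(self._entries.get(key), str)

    def typed_value(self, key):
        return self._entries.get(key)

    def value(self, key, default):
        return self._entries.get(key, default)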
Example #5
def load_data(config, cache_byte_size, files_config_key, **kwargs):
  """
  :type config: Config
  :type cache_byte_size: int
  :type chunking: str
  :type seq_ordering: str
  :rtype: (Dataset,int)
  :returns the dataset, and the cache byte size left over if we cache the whole dataset.
  """
  if not config.has(files_config_key):
    return None, 0
  if config.is_typed(files_config_key) and isinstance(config.typed_value(files_config_key), dict):
    new_kwargs = config.typed_value(files_config_key)
    assert isinstance(new_kwargs, dict)
    kwargs.update(new_kwargs)
    if 'cache_byte_size' not in new_kwargs:
      if kwargs.get('class', None) == 'HDFDataset':
        kwargs["cache_byte_size"] = cache_byte_size
    Dataset.kwargs_update_from_config(config, kwargs)
    data = init_dataset(kwargs)
  else:
    config_str = config.value(files_config_key, "")
    data = init_dataset_via_str(config_str, config=config, cache_byte_size=cache_byte_size, **kwargs)
  cache_leftover = 0
  if isinstance(data, HDFDataset):
    cache_leftover = data.definite_cache_leftover
  return data, cache_leftover
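A hypothetical call site showing how the returned (dataset, leftover) tuple is meant to chain: the train set consumes the cache budget first, and whatever remains is offered to the dev set:

cache_budget = 512 * 1024 ** 2  # 512 MB; arbitrary number for the sketch
train_data, cache_left = load_data(config, cache_budget, "train")
dev_data, _ = load_data(config, cache_left, "dev")
if train_data is None:
    print("no train dataset configured")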
Example #6
def main():
    arg_parser = ArgumentParser()
    arg_parser.add_argument("--action")
    arg_parser.add_argument("--print_seq", action='store_true')
    arg_parser.add_argument("--print_allos", action='store_true')
    arg_parser.add_argument("--print_targets", action='store_true')
    arg_parser.add_argument("--dataset")
    arg_parser.add_argument("--corpus")
    arg_parser.add_argument("--lexicon", help="filename")
    arg_parser.add_argument("--silence", type=int, help="index")
    arg_parser.add_argument("--context", default=1, type=int)
    arg_parser.add_argument("--hmm_states", default=3, type=int)
    arg_parser.add_argument("--state_tying_type", help="'monophone' or 'full'")
    arg_parser.add_argument("--state_tying_output", help="filename")
    arg_parser.add_argument("--allo_add_all", action="store_true")
    args = arg_parser.parse_args()

    dataset = init_dataset_via_str(
        config_str=args.dataset) if args.dataset else None
    corpus = dict(iter_bliss_orth(
        filename=args.corpus)) if args.corpus else None
    lexicon = Lexicon(filename=args.lexicon) if args.lexicon else None
    silence_label = args.silence

    if args.action == "show_corpus":
        pprint(corpus)
        return

    print("Num phones: %i" % len(lexicon.phonemes), file=log.v1)
    print("Phones: %r" % sorted(lexicon.phonemes.keys()), file=log.v1)

    orth_handler = OrthHandler(lexicon=lexicon,
                               allo_context_len=args.context,
                               allo_num_states=args.hmm_states)
    map_idx_to_allo = defaultdict(set)  # type: dict[int, set[AllophoneState]]
    map_allo_to_idx = {}  # type: dict[AllophoneState, int]
    if args.allo_add_all:
        orth_handler.allo_add_all = True

    print("Num HMM states: %i" % orth_handler.allo_num_states, file=log.v1)
    if args.state_tying_type == "monophone":
        print("Monophone state tying.", file=log.v1)
        num_labels = orth_handler.expected_num_labels_for_monophone_state_tying()
        all_label_idx_are_used = True
    elif args.state_tying_type == "full":
        print("Full state tying.", file=log.v1)
        phone_idxs = {k: i + 1
                      for (i, k) in enumerate(lexicon.phoneme_list)
                      }  # +1 to keep 0 reserved as the term-symbol
        for phon in lexicon.phoneme_list:
            for allo in orth_handler.all_allophone_variations(
                    phon, all_boundary_variations=True):
                allo_idx = allo.index(
                    phone_idxs=phone_idxs,
                    num_states=orth_handler.allo_num_states,
                    context_length=orth_handler.allo_context_len)
                map_idx_to_allo[allo_idx].add(allo)
        num_labels = max(map_idx_to_allo.keys()) + 1
        all_label_idx_are_used = False
    else:
        raise Exception("invalid state tying type %r" % args.state_tying_type)
    print("Num labels: %i" % num_labels, file=log.v1)

    if dataset:
        count = 0
        for segment_name, targets in iter_dataset_targets(dataset):
            count += 1
            if silence_label is None or count == 1:
                likely_silence_label = collections.Counter(targets).most_common(1)[0][0]
                if silence_label is None:
                    silence_label = likely_silence_label
                if silence_label != likely_silence_label:
                    print("warning: silence %i but likely %i" %
                          (silence_label, likely_silence_label),
                          file=log.v2)
                print("Silence label: %i" % silence_label, file=log.v1)
                orth_handler.si_label = silence_label
                # Monophone state tying:
                for allo in orth_handler.all_allophone_variations(
                        orth_handler.si_phone):
                    map_idx_to_allo[silence_label].add(allo)
                    map_allo_to_idx[allo] = silence_label
            assert segment_name in corpus
            orth = corpus[segment_name]
            allo_states = orth_handler.orth_to_allophone_states(orth=orth)
            if args.print_seq:
                print("%r %r" % (segment_name, orth))
            if args.print_allos:
                print("  allophone state seq: %r" % allo_states)
            tgt_seq = [t for t in uniq(targets) if t != silence_label]
            if args.print_targets:
                print("  target seq: %r" % (tgt_seq, ))
            assert len(allo_states) == len(tgt_seq), "check --hmm_states or so"
            for allo, t in zip(allo_states, tgt_seq):
                allo.boundary = 0  # do not differ between boundaries
                allos = map_idx_to_allo[t]
                if allo in map_allo_to_idx:
                    assert allo in allos, "bad mapping"
                else:
                    assert allo not in allos
                    allos.add(allo)
                    map_allo_to_idx[allo] = t
            if len(map_idx_to_allo) >= num_labels:
                assert len(map_idx_to_allo) == num_labels
                assert 0 in map_idx_to_allo
                assert num_labels - 1 in map_idx_to_allo
                print("Finished with uniq mapping after %i sequences." % count,
                      file=log.v1)
                break
            if count % 100 == 0:
                print("Have indices: %i (num labels: %i)" %
                      (len(map_idx_to_allo), num_labels),
                      file=log.v1)

        print("Finished. Have indices: %i (num labels: %i)" %
              (len(map_idx_to_allo), num_labels),
              file=log.v1)
        if len(map_idx_to_allo) < num_labels:
            found = []
            not_found = []
            for p in sorted(lexicon.phonemes.keys()):
                allo = AllophoneState(p, state=0)
                if allo in map_allo_to_idx:
                    found.append(p)
                else:
                    not_found.append(p)
            print("Phonemes found: %r" % found)
            print("Phonemes not found: %r" % not_found)

    if args.state_tying_output:
        assert not os.path.exists(args.state_tying_output)
        if all_label_idx_are_used:
            assert len(map_idx_to_allo) == num_labels
            assert 0 in map_idx_to_allo
            assert num_labels - 1 in map_idx_to_allo
        with open(args.state_tying_output, "w") as f:
            for i, allos in sorted(map_idx_to_allo.items()):
                for allo in allos:
                    f.write("%s %i\n" % (allo.format(), i))
        print("Wrote state tying to %r." % args.state_tying_output,
              file=log.v1)

    print("The end.")
Example #7
def main():
    arg_parser = ArgumentParser()
    arg_parser.add_argument("--action")
    arg_parser.add_argument("--print_seq", action='store_true')
    arg_parser.add_argument("--print_allos", action='store_true')
    arg_parser.add_argument("--print_targets", action='store_true')
    arg_parser.add_argument("--dataset")
    arg_parser.add_argument("--corpus")
    arg_parser.add_argument("--lexicon")
    arg_parser.add_argument("--silence", type=int)
    arg_parser.add_argument("--context", default=1, type=int)
    arg_parser.add_argument("--hmm_states", default=3, type=int)
    arg_parser.add_argument("--state_tying_output")
    arg_parser.add_argument("--allo_add_all", action="store_true")
    args = arg_parser.parse_args()

    dataset = init_dataset_via_str(
        config_str=args.dataset) if args.dataset else None
    corpus = dict(iter_bliss_orth(
        filename=args.corpus)) if args.corpus else None
    lexicon = Lexicon(filename=args.lexicon) if args.lexicon else None
    silence_label = args.silence

    if args.action == "show_corpus":
        pprint(corpus)
        return

    print("Num phones: %i" % len(lexicon.phonemes), file=log.v1)
    print("Phones: %r" % sorted(lexicon.phonemes.keys()), file=log.v1)

    orth_handler = OrthHandler(lexicon=lexicon,
                               allo_context_len=args.context,
                               allo_num_states=args.hmm_states)
    map_idx_to_allo = defaultdict(set)  # type: dict[int, set[AllophoneState]]
    map_allo_to_idx = {}  # type: dict[AllophoneState, int]
    if args.allo_add_all:
        orth_handler.allo_add_all = True

    # NOTE: Assume monophone state tying for now!
    num_labels = orth_handler.expected_num_labels_for_monophone_state_tying()
    print("Num labels: %i" % num_labels, file=log.v1)

    count = 0
    for segment_name, targets in iter_dataset_targets(dataset):
        count += 1
        if silence_label is None or count == 1:
            likely_silence_label = collections.Counter(targets).most_common(1)[0][0]
            if silence_label is None:
                silence_label = likely_silence_label
            if silence_label != likely_silence_label:
                print("warning: silence %i but likely %i" %
                      (silence_label, likely_silence_label),
                      file=log.v2)
            print("Silence label: %i" % silence_label, file=log.v1)
            orth_handler.si_label = silence_label
            # Monophone state tying:
            for allo in orth_handler.all_allophone_variations(
                    orth_handler.si_phone):
                map_idx_to_allo[silence_label].add(allo)
                map_allo_to_idx[allo] = silence_label
        assert segment_name in corpus
        orth = corpus[segment_name]
        allo_states = orth_handler.orth_to_allophone_states(orth=orth)
        if args.print_seq:
            print("%r %r" % (segment_name, orth))
        if args.print_allos:
            print("  allophone state seq: %r" % allo_states)
        tgt_seq = [t for t in uniq(targets) if t != silence_label]
        if args.print_targets:
            print("  target seq: %r" % (tgt_seq, ))
        assert len(allo_states) == len(tgt_seq), "check --hmm_states or so"
        for allo, t in zip(allo_states, tgt_seq):
            allo.boundary = 0  # do not differ between boundaries
            allos = map_idx_to_allo[t]
            if allo in map_allo_to_idx:
                assert allo in allos, "bad mapping"
            else:
                assert allo not in allos
                allos.add(allo)
                map_allo_to_idx[allo] = t
        if len(map_idx_to_allo) >= num_labels:
            assert len(map_idx_to_allo) == num_labels
            assert 0 in map_idx_to_allo
            assert num_labels - 1 in map_idx_to_allo
            print("Finished with uniq mapping after %i sequences." % count,
                  file=log.v1)
            break
        if count % 100 == 0:
            print("Have indices: %i (num labels: %i)" %
                  (len(map_idx_to_allo), num_labels),
                  file=log.v1)

    print("Finished. Have indices: %i (num labels: %i)" %
          (len(map_idx_to_allo), num_labels),
          file=log.v1)
    if len(map_idx_to_allo) < num_labels:
        found = []
        not_found = []
        for p in sorted(lexicon.phonemes.keys()):
            allo = AllophoneState(p, state=0)
            if allo in map_allo_to_idx:
                found.append(p)
            else:
                not_found.append(p)
        print("Phonemes found: %r" % found)
        print("Phonemes not found: %r" % not_found)

    if args.state_tying_output:
        assert not os.path.exists(args.state_tying_output)
        assert len(map_idx_to_allo) == num_labels
        assert 0 in map_idx_to_allo
        assert num_labels - 1 in map_idx_to_allo
        with open(args.state_tying_output, "w") as f:
            for i in range(num_labels):
                phons = sorted(set((allo.id, allo.state) for allo in map_idx_to_allo[i]))
                assert len(phons) == 1
                phon, state = phons[0]
                for allo in orth_handler.all_allophone_variations(phon, states=[state]):
                    f.write("%s %i\n" % (allo, i))
        print("Wrote state tying to %r." % args.state_tying_output,
              file=log.v1)
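Both main() variants write the state tying as plain text, one "<allophone-state> <label-index>" pair per line. A small sketch for reading that format back, assuming the allophone part itself carries no trailing spaces:

def read_state_tying(filename):
    # Inverse of the f.write("%s %i\n" % (allo, i)) loops above.
    allo_to_idx = {}
    with open(filename) as f:
        for line in f:
            allo, idx = line.rstrip("\n").rsplit(" ", 1)
            allo_to_idx[allo] = int(idx)
    return allo_to_idx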