Example #1
  def __init__(self,
               corpus_file, phone_info=None, orth_symbols_file=None, orth_replace_map_file=None,
               add_random_phone_seqs=0,
               partition_epoch=1,
               log_skipped_seqs=False, **kwargs):
    """
    :param str corpus_file: Bliss XML or line-based txt. optionally can be gzip.
    :param dict | None phone_info: if you want to get phone seqs, dict with lexicon_file etc. see _PhoneSeqGenerator
    :param str | None orth_symbols_file: list of orthography symbols, if you want to get orth symbol seqs
    :param str | None orth_replace_map_file: JSON file with replacement dict for orth symbols
    :param int add_random_phone_seqs: will add random seqs with the same len as the real seq as additional data
    :param bool log_skipped_seqs: log skipped seqs
    """
    super(LmDataset, self).__init__(**kwargs)

    if orth_symbols_file:
      assert not phone_info
      orth_symbols = open(orth_symbols_file).read().splitlines()
      self.orth_symbols_map = {sym: i for (i, sym) in enumerate(orth_symbols)}
      self.orth_symbols = orth_symbols
      self.labels["data"] = orth_symbols
      self.seq_gen = None
    else:
      assert not orth_symbols_file
      assert isinstance(phone_info, dict)
      self.seq_gen = _PhoneSeqGenerator(**phone_info)
      self.orth_symbols = None
      self.labels["data"] = self.seq_gen.get_class_labels()
    if orth_replace_map_file:
      orth_replace_map = load_json(filename=orth_replace_map_file)
      assert isinstance(orth_replace_map, dict)
      self.orth_replace_map = {key: parse_orthography_into_symbols(v)
                               for (key, v) in orth_replace_map.items()}
    else:
      self.orth_replace_map = {}

    if len(self.labels["data"]) <= 256:
      self.dtype = "int8"
    else:
      self.dtype = "int32"
    self.num_outputs = {"data": [len(self.labels["data"]), 1]}
    self.num_inputs = self.num_outputs["data"][0]
    self.seq_order = None
    self.log_skipped_seqs = log_skipped_seqs
    self.partition_epoch = partition_epoch
    self.add_random_phone_seqs = add_random_phone_seqs
    for i in range(add_random_phone_seqs):
      self.num_outputs["random%i" % i] = self.num_outputs["data"]

    if _is_bliss(corpus_file):
      iter_f = _iter_bliss
    else:
      iter_f = _iter_txt
    self.orths = []
    print("LmDataset, loading file", corpus_file, file=log.v4)
    iter_f(corpus_file, self.orths.append)
    # It's only estimated because we might filter some out or so.
    self._estimated_num_seqs = len(self.orths) // self.partition_epoch
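
The constructor above accepts either orth_symbols_file or phone_info, but not both. A minimal usage sketch for the orthography-symbol case, assuming hypothetical file names, could look like this:

# Hedged sketch: character-level orth symbols; exactly one of
# orth_symbols_file / phone_info must be given.
dataset = LmDataset(
    corpus_file="corpus.txt.gz",           # line-based text, optionally gzipped
    orth_symbols_file="orth_symbols.txt",  # one symbol per line; the line number becomes the label index
    orth_replace_map_file=None,
    partition_epoch=1,
    log_skipped_seqs=True)
assert dataset.num_inputs == len(dataset.labels["data"])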
Example #2
  def __init__(self,
               seq_list_file, seq_lens_file,
               datasets,
               data_map, data_dims,
               data_dtypes=None,
               window=1, **kwargs):
    """
    :param str seq_list_file: filename. line-separated
    :param str seq_lens_file: filename. json. dict[str,dict[str,int]], seq-tag -> data-key -> len
    :param dict[str,dict[str]] datasets: dataset-key -> dataset-kwargs. including keyword 'class' and maybe 'files'
    :param dict[str,(str,str)] data_map: self-data-key -> (dataset-key, dataset-data-key).
      Should contain 'data' as key. Also defines the target-list, which is all except 'data'.
    :param dict[str,(int,int)] data_dims: self-data-key -> data-dimension, len(shape) (1 ==> sparse repr).
    :param dict[str,str] data_dtypes: self-data-key -> dtype. automatic if not specified
    """
    assert window == 1  # not implemented
    super(MetaDataset, self).__init__(**kwargs)
    assert self.shuffle_frames_of_nseqs == 0  # not implemented. anyway only for non-recurrent nets

    self.seq_list_original = open(seq_list_file).read().splitlines()
    self.tag_idx = {tag: idx for (idx, tag) in enumerate(self.seq_list_original)}
    self._num_seqs = len(self.seq_list_original)

    self.data_map = data_map
    self.dataset_keys = set([m[0] for m in self.data_map.values()]); ":type: set[str]"
    self.data_keys = set(self.data_map.keys()); ":type: set[str]"
    assert "data" in self.data_keys
    self.target_list = sorted(self.data_keys - {"data"})  # everything except "data" is a target

    data_dims = convert_data_dims(data_dims)
    self.data_dims = data_dims
    assert "data" in data_dims
    for key in self.target_list:
      assert key in data_dims
    self.num_inputs = data_dims["data"][0]
    self.num_outputs = data_dims

    self.data_dtypes = {data_key: _select_dtype(data_key, data_dims, data_dtypes) for data_key in self.data_keys}

    if seq_lens_file:
      seq_lens = load_json(filename=seq_lens_file)
      assert isinstance(seq_lens, dict)
      # dict[str,NumbersDict], seq-tag -> data-key -> len
      self._seq_lens = {tag: NumbersDict(l) for (tag, l) in seq_lens.items()}
    else:
      self._seq_lens = None

    if self._seq_lens:
      self._num_timesteps = sum([self._seq_lens[s] for s in self.seq_list_original])
    else:
      self._num_timesteps = None

    # Will only init the needed datasets.
    self.datasets = {key: init_dataset(datasets[key]) for key in self.dataset_keys}
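
The data_map and data_dims formats described in the docstring are easiest to see in a construction sketch; the sub-dataset classes, file names and dimensions below are illustrative assumptions, not taken from the source:

meta = MetaDataset(
    seq_list_file="seq_list.txt",   # one sequence tag per line, shared by all sub-datasets
    seq_lens_file="seq_lens.json",  # e.g. {"seq-001": {"data": 538, "classes": 13}, ...}
    datasets={
        "feat": {"class": "HDFDataset", "files": ["features.hdf"]},
        "align": {"class": "HDFDataset", "files": ["alignment.hdf"]}},
    data_map={
        "data": ("feat", "data"),          # "data" must be present
        "classes": ("align", "classes")},  # every other key becomes a target
    data_dims={
        "data": (40, 2),        # dense 40-dim feature frames
        "classes": (4501, 1)})  # sparse labels, 4501 classes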
Example #3
File: Config.py  Project: rwth-i6/returnn
  def load_file(self, f):
    """
    Reads the configuration parameters from a file and adds them to the inner set of parameters.

    :param string|io.TextIOBase|io.StringIO f:
    """
    if isinstance(f, str):
      import os
      assert os.path.isfile(f), "config file not found: %r" % f
      self.files.append(f)
      filename = f
      content = open(filename).read()
    else:
      # assume stream-like
      filename = "<config string>"
      content = f.read()
    content = content.strip()
    if content.startswith("#!"):  # assume Python
      from Util import custom_exec
      # Operate inplace on ourselves.
      # Also, we want that it's available as the globals() dict, so that defined functions behave well
      # (they would lose the local context otherwise).
      user_ns = self.typed_dict
      # Always overwrite:
      user_ns.update({"config": self, "__file__": filename, "__name__": "__crnn_config__"})
      custom_exec(content, filename, user_ns, user_ns)
      return
    if content.startswith("{"):  # assume JSON
      from Util import load_json
      json_content = load_json(content=content)
      assert isinstance(json_content, dict)
      self.update(json_content)
      return
    # old line-based format
    for line in content.splitlines():
      if "#" in line:  # Strip away comment.
        line = line[:line.index("#")]
      line = line.strip()
      if not line:
        continue
      line = line.split(None, 1)
      assert len(line) == 2, "unable to parse config line: %r" % line
      self.add_line(key=line[0], value=line[1])
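
load_file dispatches on the content itself: "#!" marks a Python config executed via custom_exec, "{" marks JSON, and anything else is parsed as the old line-based "key value" format. A hedged sketch, assuming Config() needs no constructor arguments and feeding content through StringIO instead of real files:

from io import StringIO

config = Config()
# JSON content (starts with "{"): parsed via load_json and merged with config.update().
config.load_file(StringIO('{"num_inputs": 40, "learning_rate": 0.001}'))
# Old line-based format: one "key value" pair per line, "#" starts a comment.
config.load_file(StringIO("num_outputs 4501  # number of classes\n"))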
Example #4
    def __init__(self,
                 datasets,
                 data_map,
                 data_dims,
                 seq_list_file,
                 seq_lens_file=None,
                 data_dtypes=None,
                 window=1,
                 **kwargs):
        """
        :param dict[str,dict[str]] datasets: dataset-key -> dataset-kwargs. including keyword 'class' and maybe 'files'
        :param dict[str,(str,str)] data_map: self-data-key -> (dataset-key, dataset-data-key).
          Should contain 'data' as key. Also defines the target-list, which is all except 'data'.
        :param dict[str,(int,int)] data_dims: self-data-key -> data-dimension, len(shape) (1 ==> sparse repr).
        :param str seq_list_file: filename. pickle. dict[str,list[str]], dataset-key -> list of sequence tags. If tag is the same for all datasets a line-separated plain text file can be used.
        :param str seq_lens_file: filename. json. dict[str,dict[str,int]], seq-tag -> data-key -> len. Use if getting sequence length from loading data is too costly.
        :param dict[str,str] data_dtypes: self-data-key -> dtype. automatic if not specified
        """
        assert window == 1  # not implemented
        super(MetaDataset, self).__init__(**kwargs)
        assert self.shuffle_frames_of_nseqs == 0  # not implemented. anyway only for non-recurrent nets

        self.data_map = data_map
        self.dataset_keys = set([m[0] for m in self.data_map.values()])
        ":type: set[str]"
        self.data_keys = set(self.data_map.keys())
        ":type: set[str]"
        assert "data" in self.data_keys
        self.target_list = sorted(self.data_keys - {"data"})
        self.default_dataset_key = self.data_map["data"][0]

        if seq_list_file.endswith(".pkl"):
            import pickle
            seq_list = pickle.load(open(seq_list_file, 'rb'))
        else:
            seq_list = open(seq_list_file).read().splitlines()
        assert isinstance(seq_list, (list, dict))
        if isinstance(seq_list, list):
            seq_list = {key: seq_list for key in self.dataset_keys}
        self.seq_list_original = seq_list  # type: dict[str,list[str]]  # dataset key -> seq list
        self._num_seqs = len(self.seq_list_original[self.default_dataset_key])
        for key in self.dataset_keys:
            assert len(self.seq_list_original[key]) == self._num_seqs
        self.tag_idx = {
            tag: idx
            for (idx, tag) in enumerate(self.seq_list_original[
                self.default_dataset_key])
        }

        data_dims = convert_data_dims(data_dims)
        self.data_dims = data_dims
        assert "data" in data_dims
        for key in self.target_list:
            assert key in data_dims
        self.num_inputs = data_dims["data"][0]
        self.num_outputs = data_dims

        self.data_dtypes = {
            data_key: _select_dtype(data_key, data_dims, data_dtypes)
            for data_key in self.data_keys
        }

        if seq_lens_file:
            seq_lens = load_json(filename=seq_lens_file)
            assert isinstance(seq_lens, dict)
            # dict[str,NumbersDict], seq-tag -> data-key -> len
            self._seq_lens = {
                tag: NumbersDict(l)
                for (tag, l) in seq_lens.items()
            }
        else:
            self._seq_lens = None

        if self._seq_lens:
            self._num_timesteps = sum([
                self._seq_lens[s]
                for s in self.seq_list_original[self.default_dataset_key]
            ])
        else:
            self._num_timesteps = None

        # Will only init the needed datasets.
        self.datasets = {
            key:
            init_dataset(datasets[key],
                         extra_kwargs={"name": "%s_%s" % (self.name, key)})
            for key in self.dataset_keys
        }
        for data_key in self.data_keys:
            dataset_key, dataset_data_key = self.data_map[data_key]
            dataset = self.datasets[dataset_key]
            if dataset_data_key in dataset.labels:
                self.labels[data_key] = dataset.labels[dataset_data_key]
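
The main additions over the older MetaDataset version above are per-dataset sequence-tag lists (the .pkl form of seq_list_file) and a name passed to each sub-dataset. A hedged sketch of writing such a pickle file, with made-up dataset keys and tags:

import pickle

# dataset-key -> list of sequence tags; all lists must have the same length,
# and entry i of every list must refer to the same underlying sequence.
seq_list = {
    "feat": ["corpus/rec1/seq1", "corpus/rec1/seq2"],
    "align": ["corpus/rec1/seq1", "corpus/rec1/seq2"]}
with open("seq_list.pkl", "wb") as f:
    pickle.dump(seq_list, f)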
Example #5
    def __init__(self,
                 corpus_file,
                 orth_symbols_file=None,
                 orth_symbols_map_file=None,
                 orth_replace_map_file=None,
                 word_based=False,
                 seq_end_symbol="[END]",
                 unknown_symbol="[UNKNOWN]",
                 parse_orth_opts=None,
                 phone_info=None,
                 add_random_phone_seqs=0,
                 partition_epoch=1,
                 auto_replace_unknown_symbol=False,
                 log_auto_replace_unknown_symbols=10,
                 log_skipped_seqs=10,
                 error_on_invalid_seq=True,
                 add_delayed_seq_data=False,
                 delayed_seq_data_start_symbol="[START]",
                 **kwargs):
        """
        :param str|()->str corpus_file: Bliss XML or line-based txt. optionally can be gzip.
        :param dict|None phone_info: if you want to get phone seqs, dict with lexicon_file etc. see PhoneSeqGenerator
        :param str|()->str|None orth_symbols_file: list of orthography symbols, if you want to get orth symbol seqs
        :param str|()->str|None orth_symbols_map_file: list of orth symbols, each line: "symbol index"
        :param str|()->str|None orth_replace_map_file: JSON file with replacement dict for orth symbols
        :param bool word_based: whether to parse single words, or otherwise will be char-based
        :param str|None seq_end_symbol: what to add at the end, if given.
          will be set as postfix=[seq_end_symbol] or postfix=[] for parse_orth_opts.
        :param str|None unknown_symbol: token used as replacement if auto_replace_unknown_symbol is enabled
        :param dict[str]|None parse_orth_opts: kwargs for parse_orthography()
        :param int add_random_phone_seqs: will add random seqs with the same len as the real seq as additional data
        :param bool auto_replace_unknown_symbol: replace unknown symbols by unknown_symbol instead of skipping the seq
        :param bool|int log_auto_replace_unknown_symbols: write about auto-replacements with unknown symbol.
          if this is an int, it will only log the first N replacements, and then keep quiet.
        :param bool|int log_skipped_seqs: write about skipped seqs to logging, due to missing lexicon entry or so.
          if this is an int, it will only log the first N entries, and then keep quiet.
        :param bool error_on_invalid_seq: if there is a seq we would have to skip, error
        :param bool add_delayed_seq_data: will add another data-key "delayed" which will have the sequence
          delayed_seq_data_start_symbol + original_sequence[:-1]
        :param str delayed_seq_data_start_symbol: used for add_delayed_seq_data
        :param int partition_epoch: whether to partition the epochs into multiple parts. like epoch_split
        """
        super(LmDataset, self).__init__(**kwargs)

        if callable(corpus_file):
            corpus_file = corpus_file()
        if callable(orth_symbols_file):
            orth_symbols_file = orth_symbols_file()
        if callable(orth_symbols_map_file):
            orth_symbols_map_file = orth_symbols_map_file()
        if callable(orth_replace_map_file):
            orth_replace_map_file = orth_replace_map_file()

        print("LmDataset, loading file", corpus_file, file=log.v4)

        self.word_based = word_based
        self.seq_end_symbol = seq_end_symbol
        self.unknown_symbol = unknown_symbol
        self.parse_orth_opts = parse_orth_opts or {}
        self.parse_orth_opts.setdefault("word_based", self.word_based)
        self.parse_orth_opts.setdefault(
            "postfix",
            [self.seq_end_symbol] if self.seq_end_symbol is not None else [])

        if orth_symbols_file:
            assert not phone_info
            assert not orth_symbols_map_file
            orth_symbols = open(orth_symbols_file).read().splitlines()
            self.orth_symbols_map = {
                sym: i
                for (i, sym) in enumerate(orth_symbols)
            }
            self.orth_symbols = orth_symbols
            self.labels["data"] = orth_symbols
            self.seq_gen = None
        elif orth_symbols_map_file:
            assert not phone_info
            orth_symbols_imap_list = [(int(b), a) for (a, b) in [
                l.split(None, 1)
                for l in open(orth_symbols_map_file).read().splitlines()
            ]]
            orth_symbols_imap_list.sort()
            assert orth_symbols_imap_list[0][0] == 0
            assert orth_symbols_imap_list[-1][0] == len(
                orth_symbols_imap_list) - 1
            self.orth_symbols_map = {
                sym: i
                for (i, sym) in orth_symbols_imap_list
            }
            self.orth_symbols = [sym for (i, sym) in orth_symbols_imap_list]
            self.labels["data"] = self.orth_symbols
            self.seq_gen = None
        else:
            assert not orth_symbols_file
            assert isinstance(phone_info, dict)
            self.seq_gen = PhoneSeqGenerator(**phone_info)
            self.orth_symbols = None
            self.labels["data"] = self.seq_gen.get_class_labels()
        if orth_replace_map_file:
            orth_replace_map = load_json(filename=orth_replace_map_file)
            assert isinstance(orth_replace_map, dict)
            self.orth_replace_map = {
                key: parse_orthography_into_symbols(v,
                                                    word_based=self.word_based)
                for (key, v) in orth_replace_map.items()
            }
            if self.orth_replace_map:
                if len(self.orth_replace_map) <= 5:
                    print("  orth_replace_map: %r" % self.orth_replace_map,
                          file=log.v5)
                else:
                    print("  orth_replace_map: %i entries" %
                          len(self.orth_replace_map),
                          file=log.v5)
        else:
            self.orth_replace_map = {}

        num_labels = len(self.labels["data"])
        use_uint_types = False
        if BackendEngine.is_tensorflow_selected():
            use_uint_types = True
        if num_labels <= 2**7:
            self.dtype = "int8"
        elif num_labels <= 2**8 and use_uint_types:
            self.dtype = "uint8"
        elif num_labels <= 2**31:
            self.dtype = "int32"
        elif num_labels <= 2**32 and use_uint_types:
            self.dtype = "uint32"
        elif num_labels <= 2**61:
            self.dtype = "int64"
        elif num_labels <= 2**62 and use_uint_types:
            self.dtype = "uint64"
        else:
            raise Exception("cannot handle so much labels: %i" % num_labels)
        self.num_outputs = {"data": [len(self.labels["data"]), 1]}
        self.num_inputs = self.num_outputs["data"][0]
        self.seq_order = None
        self.auto_replace_unknown_symbol = auto_replace_unknown_symbol
        self.log_auto_replace_unknown_symbols = log_auto_replace_unknown_symbols
        self.log_skipped_seqs = log_skipped_seqs
        self.error_on_invalid_seq = error_on_invalid_seq
        self.partition_epoch = partition_epoch
        self.add_random_phone_seqs = add_random_phone_seqs
        for i in range(add_random_phone_seqs):
            self.num_outputs["random%i" % i] = self.num_outputs["data"]
        self.add_delayed_seq_data = add_delayed_seq_data
        self.delayed_seq_data_start_symbol = delayed_seq_data_start_symbol
        if add_delayed_seq_data:
            self.num_outputs["delayed"] = self.num_outputs["data"]

        if _is_bliss(corpus_file):
            iter_f = _iter_bliss
        else:
            iter_f = _iter_txt
        self.orths = []
        iter_f(corpus_file, self.orths.append)
        # It's only estimated because we might filter some out or so.
        self._estimated_num_seqs = len(self.orths) // self.partition_epoch
        print("  done, loaded %i sequences" % len(self.orths), file=log.v4)
Example #6
    def __init__(self, program_path, patterns_path):
        # Load the program and pattern descriptions from JSON and start with an empty flow graph.
        self.program_json = load_json(program_path)
        self.patterns_json = load_json(patterns_path)
        self.flow_graph = Node('root')