def generate_seq(self, seq_idx):
    data = self.data[seq_idx]
    return DatasetSeq(
        seq_idx=seq_idx,
        features=data["data"],
        targets={target: data[target] for target in self.target_list})
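Every example on this page builds a DatasetSeq, RETURNN's per-sequence container holding a sequence index, an optional sequence tag, a features array, and an optional dict of target arrays. A minimal standalone sketch with dummy numpy data; the shapes, the "classes" key, and the import path (assumed from the current RETURNN package layout) are illustrative, not taken from the examples:

import numpy
from returnn.datasets.basic import DatasetSeq  # assumed import path

# Hypothetical dummy sequence: 100 frames of 40-dim features,
# plus one integer class label per frame under the "classes" key.
features = numpy.zeros((100, 40), dtype="float32")
targets = {"classes": numpy.zeros((100,), dtype="int32")}
seq = DatasetSeq(seq_idx=0, seq_tag="seq-0", features=features, targets=targets)
print(seq.num_frames)  # DatasetSeq derives the frame count from the data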
Example 2
    def _collect_single_seq(self, seq_idx):
        """
        :param int seq_idx: sequence id
        :rtype: DatasetSeq
        """
        if seq_idx >= len(self.seq_order):
            return None
        real_seq_index = self.seq_order[seq_idx]
        seq_name = self.all_seq_names[real_seq_index]

        curr_triplet = self.curr_epoch_triplets[seq_idx]
        targets = {}
        for idx, sample in enumerate(curr_triplet):  # idx: position within the triplet
            real_sample_seq_idx = sample
            sample_seq_name = self.all_seq_names[real_sample_seq_idx]
            sample_seq_file_index = self.file_indices[real_sample_seq_idx]
            norm_sample_seq_name = self._normalize_seq_name(sample_seq_name)
            for name, parsers in self.all_parsers.items():
                targets['%s_%d' % (name, idx)] = (
                    parsers[sample_seq_file_index].get_data(norm_sample_seq_name))

        features = targets['%s_%d' % (self.input_stream_name, 0)]
        return DatasetSeq(seq_idx=seq_idx,
                          seq_tag=seq_name,
                          features=features,
                          targets=targets)
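With a single stream, the '%s_%d' key scheme above yields one target entry per triplet position, and the features come from the first (anchor) entry. A tiny sketch of just that naming logic; the stream name "data" stands in for input_stream_name and the payloads are placeholders:

targets = {}
for idx, sample in enumerate(("anchor", "positive", "negative")):  # a triplet
    targets["%s_%d" % ("data", idx)] = sample  # placeholder payloads
assert sorted(targets) == ["data_0", "data_1", "data_2"]
features = targets["data_0"]  # the anchor entry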
Example 3
 def _collect_single_seq(self, seq_idx):
   """
   :param int seq_idx:
   :rtype: DatasetSeq | None
   :returns: DatasetSeq, or None if seq_idx >= num_seqs.
   """
   assert self.cur_seq_list is not None, "call init_seq_order"
   if seq_idx >= len(self.cur_seq_list):
     return None
   seq_tag = self.cur_seq_list[seq_idx]
   sub_seq_tags = seq_tag.split(self.seq_tag_delim)
   sub_seq_idxs = self.cur_sub_seq_idxs[seq_idx]
   assert len(sub_seq_tags) == len(sub_seq_idxs)
   features = {key: [] for key in self.get_data_keys()}
   if seq_idx == 0:  # some extra check, but enough to do for first seq only
     sub_dataset_keys = self.sub_dataset.get_data_keys()
     for key in self.remove_in_between_postfix:
       assert key in sub_dataset_keys, "%s: remove_in_between_postfix key %r not in sub dataset data-keys %r" % (
         self, key, sub_dataset_keys)
   for sub_seq_idx, sub_seq_tag in zip(sub_seq_idxs, sub_seq_tags):
     self.sub_dataset.load_seqs(sub_seq_idx, sub_seq_idx + 1)
     sub_dataset_tag = self.sub_dataset.get_tag(sub_seq_idx)
     assert sub_dataset_tag == sub_seq_tag, "%s: expected tag %r for sub seq idx %i but got %r, part of seq %i %r" % (
       self, sub_seq_tag, sub_seq_idx, sub_dataset_tag, seq_idx, seq_tag)
     for key in self.get_data_keys():
       data = self.sub_dataset.get_data(sub_seq_idx, key)
       if key in self.remove_in_between_postfix and sub_seq_idx != sub_seq_idxs[-1]:
         assert data.ndim == 1 and data[-1] == self.remove_in_between_postfix[key]
         data = data[:-1]
       features[key].append(data)
   features = {key: numpy.concatenate(values, axis=0) for (key, values) in features.items()}
   return DatasetSeq(seq_idx=seq_idx, seq_tag=seq_tag, features=features)
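remove_in_between_postfix strips a trailing label (typically an end-of-sequence symbol) from every sub-sequence except the last, so that the symbol appears only once in the concatenated result. A standalone numpy sketch of that step; the label id 0 is an assumption:

import numpy

eos = 0  # assumed postfix label configured in remove_in_between_postfix
subs = [numpy.array([3, 1, eos]), numpy.array([4, 2, eos])]
joined = numpy.concatenate(
    [s[:-1] if i + 1 < len(subs) else s for i, s in enumerate(subs)], axis=0)
assert joined.tolist() == [3, 1, 4, 2, eos]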
Example 4
 def generate_seq(self, seq_idx):
   seq_len = self.get_random_seq_len()
   input_seq = self.generate_input_seq(seq_len)
   output_seq = self.make_output_seq(input_seq)
   features = class_idx_seq_to_1_of_k(input_seq, num_classes=len(self._input_classes))
   targets = numpy.array(output_seq)
   return DatasetSeq(seq_idx=seq_idx, features=features, targets=targets)
Example 5
 def get_dataset_seq_for_name(self, name, seq_idx=-1):
   """
   :param str name:
   :param int seq_idx:
   :rtype: DatasetSeq
   """
   data = {key: d.read(name) for (key, d) in self.data.items()}  # type: typing.Dict[str,numpy.ndarray]
   return DatasetSeq(seq_idx=seq_idx, seq_tag=name, features=data["data"], targets=data)
Example 6
 def generate_seq(self, seq_idx):
     seq_len = self.seq_len
     i1 = seq_idx
     i2 = i1 + seq_len * self.num_inputs
     features = np.array(range(i1, i2)).reshape((seq_len, self.num_inputs))
     i1, i2 = i2, i2 + seq_len
     targets = np.array(range(i1, i2))
     return DatasetSeq(seq_idx=seq_idx, features=features, targets=targets)
Example 7
 def get_dataset_seq_for_name(self, name, seq_idx=-1):
     data = {key: d.read(name)
             for (key, d) in self.data.items()
             }  # type: dict[str,numpy.ndarray]
     return DatasetSeq(seq_idx=seq_idx,
                       seq_tag=name,
                       features=data["data"],
                       targets=data)
Example 8
 def _collect_single_seq(self, seq_idx):
   dataset_idx = self._get_dataset_for_seq_idx(seq_idx)
   dataset = self.datasets[dataset_idx]
   dataset_seq_idx = seq_idx + self.dataset_seq_idx_offsets[dataset_idx]
   seq_tag = dataset.get_tag(dataset_seq_idx)
   features = dataset.get_input_data(dataset_seq_idx)
   targets = {k: dataset.get_targets(k, dataset_seq_idx) for k in dataset.get_target_list()}
   return DatasetSeq(seq_idx=seq_idx, seq_tag=seq_tag, features=features, targets=targets)
Example 9
 def generate_seq(self, seq_idx):
   """
   :type seq_idx: int
   :rtype: DatasetSeq
   """
   seq_len = self.get_random_seq_len()
   seq = [self.random.randint(0, self.nsymbols) for i in range(seq_len)]
   seq_np = numpy.array(seq, dtype="int8")
   return DatasetSeq(seq_idx=seq_idx, features=seq_np, targets={"classes": seq_np})
Example 10
 def _collect_single_seq(self, seq_idx):
   """
   :type seq_idx: int
   :rtype: DatasetSeq
   """
   seq_tag = self.seq_list_ordered[seq_idx]
   features = self._get_data(seq_idx, "data")
   targets = {target: self._get_data(seq_idx, target) for target in self.target_list}
   return DatasetSeq(seq_idx=seq_idx, seq_tag=seq_tag, features=features, targets=targets)
Example 11
 def generate_seq(self, seq_idx):
   seq_len = self.seq_len
   i1 = seq_idx
   i2 = i1 + seq_len * self.num_inputs
   features = numpy.array([((i % self.input_max_value) + self.input_shift) * self.input_scale
                           for i in range(i1, i2)]).reshape((seq_len, self.num_inputs))
   i1, i2 = i2, i2 + seq_len
   targets = numpy.array([i % self.num_outputs["classes"][0]
                          for i in range(i1, i2)])
   return DatasetSeq(seq_idx=seq_idx, features=features, targets=targets)
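To make the pattern concrete, a worked instance with small, made-up parameters (seq_len=2, num_inputs=3, input_max_value=10, shift=0, scale=1, 4 output classes) at seq_idx=0:

import numpy

seq_len, num_inputs, max_val, shift, scale, num_classes = 2, 3, 10, 0, 1, 4
i1, i2 = 0, seq_len * num_inputs
features = numpy.array([((i % max_val) + shift) * scale
                        for i in range(i1, i2)]).reshape((seq_len, num_inputs))
targets = numpy.array([i % num_classes for i in range(i2, i2 + seq_len)])
assert features.tolist() == [[0, 1, 2], [3, 4, 5]]
assert targets.tolist() == [2, 3]  # 6 % 4 and 7 % 4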
Example 12
 def _collect_single_seq(self, seq_idx):
     seq_name = self.get_tag(seq_idx)
     #print >> log.v5, "ClusteringDataset: _collect_single_seq: seq_name", seq_name
     data = {
         key: self.dataset.get_data(seq_idx=seq_idx, key=key)
         for key in self.dataset.get_data_keys()
     }
     data["cluster_idx"] = numpy.array([self.cluster_map[seq_name]],
                                       dtype=self.cluster_idx_dtype)
     return DatasetSeq(seq_idx=seq_idx, features=data["data"], targets=data)
Example 13
 def _collect_single_seq(self, seq_idx):
     if seq_idx >= self._num_seqs:
         return None
     line_nr = self._get_line_nr(seq_idx)
     features = self._get_data(key="data", line_nr=line_nr)
     targets = self._get_data(key="classes", line_nr=line_nr)
     assert features is not None and targets is not None
     return DatasetSeq(seq_idx=seq_idx,
                       seq_tag="line-%i" % line_nr,
                       features=features,
                       targets=targets)
Example 14
    def _collect_single_seq(self, seq_idx):
        """Returns the sequence specified by the index seq_idx.
    Normalization is applied to the input features if mean and variance
    have been specified during dataset creating (see the constructor).

    :type seq_idx: int
    :rtype: DatasetSeq | None
    :returns: None if seq_idx >= num_seqs or the corresponding sequence.
    """
        if self._seq_index_list is None:
            self.init_seq_order()

        if seq_idx >= len(self._seq_index_list):
            return None

        # map the seq_idx to the shuffled sequence indices
        shuf_seq_idx = self._seq_index_list[seq_idx]
        partition_offset = int(
            np.sum([
                self._get_partition_size(i1)
                for i1 in range(self._current_partition)
            ]))
        shuf_seq_idx += partition_offset

        seqMapping = self._seqMap[shuf_seq_idx]
        fileIdx = seqMapping[0]
        datasetName = seqMapping[1]
        fileHandler = self._fileHandlers[fileIdx]
        inputFeatures = fileHandler['inputs'][datasetName][...]
        targets = None
        if 'outputs' in fileHandler:
            targets = fileHandler['outputs'][datasetName][...]

        # optional normalization
        if self._normData is not None:
            assert isinstance(self._normData, NormalizationData)
            if self._flag_normalizeInputs:
                inputFeatures = StereoHdfDataset._normalizeVector(
                    inputFeatures, self._normData.inputMean,
                    self._normData.inputVariance)
            if self._flag_normalizeTargets:
                targets = StereoHdfDataset._normalizeVector(
                    targets, self._normData.outputMean,
                    self._normData.outputVariance)

        # enforce float32 to enable Theano optimizations
        inputFeatures = inputFeatures.astype(np.float32)
        if (targets is not None) and targets.shape[1] > 1:
            targets = targets.astype(np.float32)
        elif targets.shape[1] == 1:
            targets = np.reshape(targets.astype(np.int32),
                                 (targets.shape[0], ))

        return DatasetSeq(seq_idx, inputFeatures, targets)
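StereoHdfDataset._normalizeVector itself is not shown in this example. A plausible stand-in, assuming it performs per-dimension zero-mean, unit-variance scaling; the exact formula in the original class may differ:

import numpy as np

def _normalize_vector(x, mean, variance, eps=1e-10):
    # Hypothetical replacement for StereoHdfDataset._normalizeVector:
    # subtract the per-dimension mean, divide by the per-dimension std dev.
    return (x - mean) / np.sqrt(variance + eps)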
Example 15
 def producer_add_data(self, data, seq_tag=None):
   """
   :param numpy.ndarray data:
   :param str|None seq_tag:
   """
   with self.condition:
     if seq_tag is None:
       seq_tag = "seq-%i" % self.producer_seq_idx
     seq = DatasetSeq(features=data, seq_idx=self.producer_seq_idx, seq_tag=seq_tag)
     self.producer_seq_idx += 1
     self.producer_data.append(seq)
     self.condition.notify()
Example 16
 def _add_data(self, data, original_tag):
   """
   :type data: dict[str,numpy.ndarray]
   :type original_tag: str
   """
   features = data["data"]
   if not self.added_data:
     seq_idx = 0
     assert self.expected_load_seq_start == 0
   else:
     seq_idx = self.added_data[-1].seq_idx + 1
   tag = "%s.%i" % (original_tag, seq_idx)
   seq = DatasetSeq(seq_idx=seq_idx, features=features, targets=data, seq_tag=tag)
   self._num_timesteps_accumulated += seq.num_frames
   self.added_data += [seq]
Example 17
    def _collect_single_seq_from_buffer(self, wavFileId, seq_idx):
        """Returns the sequence specified by the index seq_idx.

        :type wavFileId: int
        :type seq_idx: int
        :rtype: DatasetSeq | None
        :returns: DatasetSeq, or None if seq_idx >= num_seqs.
        """
        inputFeatures = self._getInputFeatures(wavFileId)
        outputFeatures = self._getOutputFeatures(wavFileId)
        inputFeatures = inputFeatures.astype(np.float32)
        if outputFeatures is not None:
            outputFeatures = outputFeatures.astype(np.float32)
        return DatasetSeq(seq_idx, inputFeatures, outputFeatures)
Example 18
  def _collect_single_seq(self, seq_idx):
    """
    :type seq_idx: int
    :rtype: DatasetSeq
    """
    if not self.is_less_than_num_seqs(seq_idx):
      return None
    dataset_idx, dataset_seq_idx = self.dataset_seq_idxs[seq_idx]
    dataset_key = self.dataset_idxs[dataset_idx]
    dataset = self.datasets[dataset_key]

    seq_tag = dataset.get_tag(dataset_seq_idx)
    features = self._get_data(dataset_key, dataset_seq_idx, "data")
    targets = {target: self._get_data(dataset_key, dataset_seq_idx, target) for target in self.target_list}
    return DatasetSeq(seq_idx=seq_idx, seq_tag=seq_tag, features=features, targets=targets)
Example 19
    def _collect_single_seq(self, seq_idx):
        """Returns the sequence specified by the index seq_idx.
    Normalization is applied to the input features if mean and variance
    have been specified during dataset creating (see the constructor).

    :type seq_idx: int
    :rtype: DatasetSeq | None
    :returns: None if seq_idx >= num_seqs or the corresponding sequence.
    """
        if seq_idx >= self.num_seqs:
            return None

        seqMapping = self._seqMap[seq_idx]
        fileIdx = seqMapping[0]
        datasetName = seqMapping[1]
        fileHandler = self._fileHandlers[fileIdx]
        inputFeatures = fileHandler['inputs'][datasetName][...]
        targets = None
        if 'outputs' in fileHandler:
            targets = fileHandler['outputs'][datasetName][...]

        # optional normalization
        if self._normData is not None:
            assert isinstance(self._normData, NormalizationData)
            # inputs
            if self._flag_normalizeInputs:
                inputFeatures = StereoHdfDataset._normalizeVector(
                    inputFeatures,
                    self._normData.inputMean,
                    self._normData.inputVariance,
                )
            # outputs
            if self._flag_normalizeTargets:
                targets = StereoHdfDataset._normalizeVector(
                    targets,
                    self._normData.outputMean,
                    self._normData.outputVariance,
                )

        # enforce float32 to enable Theano optimizations
        inputFeatures = inputFeatures.astype(np.float32)
        if targets is not None:
            targets = targets.astype(np.float32)

        return DatasetSeq(seq_idx, inputFeatures, targets)
Example 20
  def _collect_single_seq(self, seq_idx):
    """
    :type seq_idx: int
    :rtype: DatasetSeq
    """
    if seq_idx >= len(self.seq_order):
      return None

    real_seq_index = self.seq_order[seq_idx]
    file_index = self.file_indices[real_seq_index]
    seq_name = self.all_seq_names[real_seq_index]
    norm_seq_name = self._normalize_seq_name(seq_name)
    targets = {name: parsers[file_index].get_data(norm_seq_name) for name, parsers in self.all_parsers.items()}
    features = targets[self.input_stream_name]
    return DatasetSeq(seq_idx=seq_idx,
                      seq_tag=seq_name,
                      features=features,
                      targets=targets)
Example 21
    def _collect_single_seq(self, seq_idx):
        """this method implements stacking the features

    :type seq_idx: int
    :param seq_idx: index of a sequence
    :rtype: DatasetSeq
    :return: DatasetSeq
    """
        if seq_idx >= self.num_seqs:
            return None
        originalSeq = super(DatasetWithTimeContext,
                            self)._collect_single_seq(seq_idx)
        inputFeatures = originalSeq.get_data('data')
        frames, bins = inputFeatures.shape
        leftContext = deque()
        rightContext = deque()
        inFeatWithContext = []
        for i in range(self._tau):
            leftContext.append(np.zeros(bins))
            if i + 1 < frames:
                rightContext.append(inputFeatures[i + 1, ...])
            else:
                rightContext.append(np.zeros(bins))
        for t in range(frames):
            f = inputFeatures[t, ...]
            newFeature = np.concatenate(
                [np.concatenate(leftContext, axis=0),
                 f,
                 np.concatenate(rightContext, axis=0)],
                axis=0)
            inFeatWithContext.append(newFeature)
            leftContext.popleft()
            leftContext.append(f)
            rightContext.popleft()
            if t + 1 + self._tau < frames:
                rightContext.append(inputFeatures[t + 1 + self._tau, ...])
            else:
                rightContext.append(np.zeros(bins))
        inputFeatures = np.array(inFeatWithContext)
        targets = None
        if 'classes' in originalSeq.get_data_keys():
            targets = originalSeq.get_data('classes')
        return DatasetSeq(seq_idx, inputFeatures, targets)
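Each stacked frame concatenates tau left-context frames, the center frame, and tau right-context frames, so the output has shape (frames, (2 * tau + 1) * bins). An equivalent, more compact numpy sketch with zero padding at the edges; tau and the input values are illustrative:

import numpy as np

tau = 2
feats = np.arange(12, dtype=np.float32).reshape(4, 3)  # (frames=4, bins=3)
padded = np.pad(feats, ((tau, tau), (0, 0)))  # zeros beyond both sequence ends
stacked = np.stack([padded[t:t + 2 * tau + 1].reshape(-1)
                    for t in range(feats.shape[0])])
assert stacked.shape == (4, (2 * tau + 1) * 3)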
Example 22
 def _collect_single_seq(self, seq_idx):
   """
   :param int seq_idx:
   :rtype: DatasetSeq | None
   :returns: DatasetSeq, or None if seq_idx >= num_seqs.
   """
   assert self.cur_seq_list is not None, "call init_seq_order"
   if seq_idx >= len(self.cur_seq_list):
     return None
   seq_tag = self.cur_seq_list[seq_idx]
   sub_seq_idxs = self.cur_sub_seq_idxs[seq_idx]
   features = {key: [] for key in self.get_data_keys()}
   for sub_seq_idx in sub_seq_idxs:
     self.sub_dataset.load_seqs(sub_seq_idx, sub_seq_idx + 1)
     for key in self.get_data_keys():
        data = self.sub_dataset.get_data(sub_seq_idx, key)  # index into the sub dataset, not the merged seq_idx
       features[key].append(data)
   features = {key: numpy.concatenate(values, axis=0) for (key, values) in features.items()}
   return DatasetSeq(seq_idx=seq_idx, seq_tag=seq_tag, features=features)
Example 23
 def _add_cache_seq(self, seq_idx, features, targets):
     last_seq_idx = self._get_cache_last_seq_idx()
     assert seq_idx == last_seq_idx + 1
     self.cached_seqs += [DatasetSeq(seq_idx, features, targets)]
Example 24
    def addNewData(self, features, targets=None, segmentName=None):
        """
    Adds a new seq.
    This is called via the Sprint main thread.
    :param numpy.ndarray features: format (input-feature,time) (via Sprint)
    :param dict[str,numpy.ndarray|str] targets: format (time) (idx of output-feature)
    :returns the sorted seq index
    :rtype: int
    """

        # is in format (feature,time)
        assert self.num_inputs == features.shape[0]
        T = features.shape[1]
        # must be in format: (time,feature)
        features = features.transpose()
        assert features.shape == (T, self.num_inputs)

        if targets is None:
            targets = {}
        if not isinstance(targets, dict):
            targets = {"classes": targets}
        if "classes" in targets:
            # 'classes' is always the alignment
            assert targets["classes"].shape == (
                T,
            ), "Number of targets %s does not equal to number of features %s" % (
                targets["classes"].shape, (T, ))  # is in format (time,)

        # Maybe convert some targets.
        if self.target_maps:
            for key, tmap in self.target_maps.items():
                assert key in targets
                v = tmap[targets[key]]
                v = numpy.asarray(v)
                if v.ndim == 0:
                    v = numpy.zeros(
                        (T, ), dtype=v.dtype) + v  # add time dimension
                targets[key] = v

        # Maybe remove some targets.
        for key in self._target_black_list:
            if key in targets:
                del targets[key]

        # Check if all targets are valid.
        for key, v in sorted(list(targets.items())):
            if isinstance(v, numpy.ndarray):
                continue  # ok
            if isinstance(v, unicode):
                v = v.encode("utf8")
            if isinstance(v, (str, bytes)):
                v = map(ord, v)
                v = numpy.array(v, dtype="uint8")
                targets[key] = v
                continue
            print >> log.v3, "SprintDataset, we will ignore the target %r because it is not a numpy array: %r" % (
                key, v)
            self._target_black_list += [key]
            del targets[key]

        with self.lock:
            # This gets called in the Sprint main thread.
            # If this is used together with SprintInterface.getSegmentList(), we are always in a state where
            # we just yielded a segment name, thus we are always in a Sprint epoch and thus ready for data.
            assert self.ready_for_data
            assert not self.reached_final_seq
            assert not self.sprintFinalized

            seq_idx = self.next_seq_to_be_added

            if self.predefined_seq_list_order:
                # Note: Only in ExternSprintDataset, we can reliably set the seq order for now.
                assert self.predefined_seq_list_order[
                    seq_idx] == segmentName, "seq-order not as expected"

            self.next_seq_to_be_added += 1
            self._num_timesteps += T
            self.cond.notify_all()

            if seq_idx > self.requested_load_seq_end - 1 + self.SprintCachedSeqsMax:
                print >> log.v5, "SprintDataset addNewData: seq=%i, len=%i. Cache filled, waiting to get loaded..." % (
                    seq_idx, T)
                while seq_idx > self.requested_load_seq_end - 1 + self.SprintCachedSeqsMin:
                    assert not self.reached_final_seq
                    assert seq_idx + 1 == self.next_seq_to_be_added
                    self.cond.wait()

            self.added_data += [
                DatasetSeq(seq_idx, features, targets, seq_tag=segmentName)
            ]
            self.cond.notify_all()
            return seq_idx
Example 25
    def _collect_single_seq(self, seq_idx):
        """
    :type seq_idx: int
    :rtype: DatasetSeq | None
    :returns DatasetSeq or None if seq_idx >= num_seqs.
    """
        while True:
            if self.next_orth_idx >= len(self.orths_epoch):
                assert self.next_seq_idx <= seq_idx, "We expect that we iterate through all seqs."
                if self.num_skipped > 0:
                    print("LmDataset: reached end, skipped %i sequences" %
                          self.num_skipped)
                return None
            assert self.next_seq_idx == seq_idx, "We expect that we iterate through all seqs."
            orth = self.orths_epoch[self.seq_order[self.next_orth_idx]]
            self.next_orth_idx += 1
            if orth == "</s>":
                continue  # special sentence end symbol. empty seq, ignore.

            if self.seq_gen:
                try:
                    phones = self.seq_gen.generate_seq(orth)
                except KeyError as e:
                    if self.log_skipped_seqs:
                        print(
                            "LmDataset: skipping sequence %r because of missing lexicon entry: %s"
                            % (orth, e),
                            file=log.v4)
                        self._reduce_log_skipped_seqs()
                    if self.error_on_invalid_seq:
                        raise Exception(
                            "LmDataset: invalid seq %r, missing lexicon entry %r"
                            % (orth, e))
                    self.num_skipped += 1
                    continue  # try another seq
                data = self.seq_gen.seq_to_class_idxs(phones, dtype=self.dtype)

            elif self.orth_symbols:
                orth_syms = parse_orthography(orth, **self.parse_orth_opts)
                while True:
                    orth_syms = sum(
                        [self.orth_replace_map.get(s, [s]) for s in orth_syms],
                        [])
                    i = 0
                    while i < len(orth_syms) - 1:
                        if orth_syms[i:i + 2] == [" ", " "]:
                            orth_syms[i:i + 2] = [" "]  # collapse two spaces
                        else:
                            i += 1
                    if self.auto_replace_unknown_symbol:
                        try:
                            # list() forces evaluation; in Python 3 a bare map()
                            # is lazy and would never raise the KeyError here.
                            list(map(self.orth_symbols_map.__getitem__, orth_syms))
                        except KeyError as e:
                            orth_sym = e.args[0]  # e.message exists only in Python 2
                            if self.log_auto_replace_unknown_symbols:
                                print(
                                    "LmDataset: unknown orth symbol %r, adding to orth_replace_map as %r"
                                    % (orth_sym, self.unknown_symbol),
                                    file=log.v3)
                                self._reduce_log_auto_replace_unknown_symbols()
                            self.orth_replace_map[orth_sym] = [
                                self.unknown_symbol
                            ] if self.unknown_symbol is not None else []
                            continue  # try this seq again with updated orth_replace_map
                    break
                self.num_unknown += orth_syms.count(self.unknown_symbol)
                if self.word_based:
                    orth_debug_str = repr(orth_syms)
                else:
                    orth_debug_str = repr("".join(orth_syms))
                try:
                    data = numpy.array(
                        list(map(self.orth_symbols_map.__getitem__, orth_syms)),
                        dtype=self.dtype)  # list() needed for Python 3 compatibility
                except KeyError as e:
                    if self.log_skipped_seqs:
                        print(
                            "LmDataset: skipping sequence %s because of missing orth symbol: %s"
                            % (orth_debug_str, e),
                            file=log.v4)
                        self._reduce_log_skipped_seqs()
                    if self.error_on_invalid_seq:
                        raise Exception(
                            "LmDataset: invalid seq %s, missing orth symbol %s"
                            % (orth_debug_str, e))
                    self.num_skipped += 1
                    continue  # try another seq

            else:
                assert False

            targets = {}
            for i in range(self.add_random_phone_seqs):
                assert self.seq_gen  # not implemented atm for orths
                phones = self.seq_gen.generate_garbage_seq(
                    target_len=data.shape[0])
                targets["random%i" % i] = self.seq_gen.seq_to_class_idxs(
                    phones, dtype=self.dtype)
            if self.add_delayed_seq_data:
                targets["delayed"] = numpy.concatenate(([
                    self.orth_symbols_map[self.delayed_seq_data_start_symbol]
                ], data[:-1])).astype(self.dtype)
                assert targets["delayed"].shape == data.shape
            self.next_seq_idx = seq_idx + 1
            return DatasetSeq(seq_idx=seq_idx, features=data, targets=targets)
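The "delayed" target built near the end is just the class sequence shifted right by one step, with a start symbol prepended; a minimal numpy sketch, where the start-symbol id 0 is an assumption:

import numpy

data = numpy.array([5, 7, 9], dtype="int32")
start_id = 0  # assumed id of delayed_seq_data_start_symbol in orth_symbols_map
delayed = numpy.concatenate(([start_id], data[:-1])).astype("int32")
assert delayed.tolist() == [0, 5, 7] and delayed.shape == data.shape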
Example 26
  def add_new_data(self, features, targets=None, segment_name=None):
    """
    Adds a new seq.
    This is called via the Sprint main thread.

    :param numpy.ndarray features: format (input-feature,time) (via Sprint)
    :param dict[str,numpy.ndarray|str] targets: format (time) (idx of output-feature)
    :param str|None segment_name:
    :returns the sorted seq index
    :rtype: int
    """

    # is in format (feature,time)
    assert self.num_inputs == features.shape[0]
    num_frames = features.shape[1]
    # must be in format: (time,feature)
    features = features.transpose()
    assert features.shape == (num_frames, self.num_inputs)
    if self.input_stddev != 1:
      features /= self.input_stddev
    if self.window > 1:
      features = self.sliding_window(features)
      assert features.shape == (num_frames, self.num_inputs * self.window)

    if targets is None:
      targets = {}
    if not isinstance(targets, dict):
      targets = {"classes": targets}
    if "classes" in targets:
      # 'classes' is always the alignment
      assert targets["classes"].shape == (num_frames,), (  # is in format (time,)
        "Number of targets %s does not equal to number of features %s" % (targets["classes"].shape, (num_frames,)))
    if "orth" in targets:
      targets["orth"] = targets["orth"].decode("utf8").strip()
    if "orth" in targets and self.orth_post_process:
      targets["orth"] = self.orth_post_process(targets["orth"])
    if self.bpe:
      assert "orth" in targets
      orth = targets["orth"]
      assert isinstance(orth, (str, unicode))
      assert "bpe" not in targets
      targets["bpe"] = numpy.array(self.bpe.get_seq(orth), dtype="int32")
    if self.orth_vocab:
      assert "orth" in targets
      orth = targets["orth"]
      assert isinstance(orth, (str, unicode))
      assert "orth_classes" not in targets
      targets["orth_classes"] = numpy.array(self.orth_vocab.get_seq(orth), dtype="int32")

    # Maybe convert some targets.
    if self.target_maps:
      for key, target_map in self.target_maps.items():
        assert key in targets
        v = target_map[targets[key]]
        v = numpy.asarray(v)
        if v.ndim == 0:
          v = numpy.zeros((num_frames,), dtype=v.dtype) + v  # add time dimension
        targets[key] = v

    # Maybe remove some targets.
    for key in self._target_black_list:
      if key in targets:
        del targets[key]

    # Check if all targets are valid.
    for key, v in sorted(list(targets.items())):
      if isinstance(v, numpy.ndarray):
        continue  # ok
      if isinstance(v, unicode):
        v = v.encode("utf8")
      if isinstance(v, (str, bytes)):
        if PY3:
          assert isinstance(v, bytes)
          v = list(v)
        else:
          v = list(map(ord, v))
        v = numpy.array(v, dtype="uint8")
        targets[key] = v
        if self.str_add_final_zero:
          v = numpy.append(v, numpy.array([0], dtype=v.dtype))
          assert key + "0" not in targets
          targets[key + "0"] = v
        continue
      print("%s, we will ignore the target %r because it is not a numpy array: %r" % (self, key, v), file=log.v3)
      self._target_black_list += [key]
      del targets[key]

    with self.lock:
      # This gets called in the Sprint main thread.
      # If this is used together with SprintInterface.getSegmentList(), we are always in a state where
      # we just yielded a segment name, thus we are always in a Sprint epoch and thus ready for data.
      assert self.ready_for_data
      assert not self.reached_final_seq
      assert not self.sprintFinalized

      seq_idx = self.next_seq_to_be_added

      if self.predefined_seq_list_order:
        # Note: Only in ExternSprintDataset, we can reliably set the seq order for now.
        assert seq_idx < len(self.predefined_seq_list_order), "seq_idx %i, expected predef num seqs %i" % (
          seq_idx, len(self.predefined_seq_list_order))
        expected_seq_name = self.predefined_seq_list_order[seq_idx]
        if expected_seq_name != segment_name:
          if segment_name in self.predefined_seq_list_order:
            raise Exception("seq_idx %i expected to be tag %r but got tag %r; tag %r is at idx %i" % (
              seq_idx, expected_seq_name, segment_name, segment_name,
              self.predefined_seq_list_order.index(segment_name)))
          raise Exception("seq_idx %i expected to be tag %r but got tag %r; tag %r not found" % (
            seq_idx, expected_seq_name, segment_name, segment_name))

      self.next_seq_to_be_added += 1
      self._num_timesteps += num_frames
      self.cond.notify_all()

      if seq_idx > self.requested_load_seq_end - 1 + self.SprintCachedSeqsMax:
        print("%s add_new_data: seq=%i, len=%i. Cache filled, waiting to get loaded..." % (
          self, seq_idx, num_frames), file=log.v5)
        while seq_idx > self.requested_load_seq_end - 1 + self.SprintCachedSeqsMin:
          assert not self.reached_final_seq
          assert seq_idx + 1 == self.next_seq_to_be_added
          self.cond.wait()

      self.added_data += [DatasetSeq(seq_idx, features, targets, seq_tag=segment_name)]
      self.cond.notify_all()
      return seq_idx
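The target_maps conversion in both Sprint examples (v = target_map[targets[key]]) relies on numpy fancy indexing: a 1-d lookup array re-labels every frame at once, and a scalar (0-dim) result is broadcast over the time axis, exactly as in the `v.ndim == 0` branch above. A self-contained sketch with made-up label values:

import numpy

target_map = numpy.array([10, 20, 30])  # hypothetical old-class -> new-class map
classes = numpy.array([0, 2, 1, 0])     # per-frame labels, format (time,)
v = numpy.asarray(target_map[classes])  # fancy indexing remaps every frame
assert v.tolist() == [10, 30, 20, 10]

scalar = numpy.asarray(target_map[1])   # ndim == 0, as handled in the code above
num_frames = 4
v = numpy.zeros((num_frames,), dtype=scalar.dtype) + scalar  # add time dimension
assert v.tolist() == [20, 20, 20, 20]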