Example #1
0
    def __call__(self, example):
        """Cut (or pad to) one random chunk of ``self.chunk_size`` samples.

        The entries named in ``self.chunk_keys`` are removed from
        ``example``, sliced along ``self.axis`` and merged back into a
        deep copy of the remaining example. ``num_samples`` is set to the
        chunk size.

        Args:
            example: Dict whose ``self.chunk_keys`` entries are arrays
                with equal length along ``self.axis``.

        Returns:
            A single chunked example dict (or ``example`` unchanged when
            chunking is disabled with ``chunk_size == -1``).

        Raises:
            lazy_dataset.FilterException: If the signals are shorter than
                ``self.min_length``.
        """
        # Chunking disabled: pass the example through untouched.
        if self.chunk_size == -1:
            return example

        to_chunk = {key: example.pop(key) for key in self.chunk_keys}
        lengths = [signal.shape[self.axis] for signal in to_chunk.values()]
        assert all(length == lengths[0] for length in lengths), (
            f'The shapes along the chunk dimension of all entries to chunk '
            f'must be equal! {lengths}')
        signal_length = lengths[0]

        if signal_length < self.min_length:
            # Too short to be useful: drop this example entirely.
            raise lazy_dataset.FilterException()

        if signal_length < self.chunk_size:
            # Center-pad up to chunk_size. The trailing pad gets one extra
            # sample to cover the odd case; in the even case the surplus
            # sample is cut away again by the slice below.
            from paderbox.array.padding import pad_axis
            missing = self.chunk_size - signal_length
            pad_width = [missing // 2, missing // 2 + 1]
            to_chunk = {
                key: pad_axis(signal, pad_width, axis=self.axis)
                for key, signal in to_chunk.items()
            }
            start = 0
        else:
            # Long enough: choose a random chunk position (upper bound of
            # randint is exclusive, hence the +1).
            start = np.random.randint(0, signal_length - self.chunk_size + 1)

        chunk = deepcopy(example)
        stop = start + self.chunk_size
        for key, signal in to_chunk.items():
            chunk[key] = _getitem_on_axis(
                signal, slice(start, stop), axis=self.axis).copy()
        chunk['num_samples'] = self.chunk_size
        return chunk
Example #2
0
    def __call__(self, example):
        """Cut all full-size, half-overlapping chunks out of the example.

        The entries named in ``self.chunk_keys`` are removed from
        ``example``, sliced along ``self.axis`` into windows of
        ``self.chunk_size`` samples with a hop size of ``chunk_size // 2``,
        and each window is merged into a deep copy of the remaining
        example. Trailing samples that do not fill a complete chunk are
        discarded.

        Args:
            example: Dict whose ``self.chunk_keys`` entries are arrays
                with equal length along ``self.axis``.

        Returns:
            List of example dicts, one per chunk (``[example]`` unchanged
            when chunking is disabled with ``chunk_size == -1``).

        Raises:
            lazy_dataset.FilterException: If the signals are shorter than
                ``self.chunk_size``.
        """
        # Shortcut if chunking is disabled
        if self.chunk_size == -1:
            return [example]

        to_chunk = {k: example.pop(k) for k in self.chunk_keys}
        to_chunk_lengths = [c.shape[self.axis] for c in to_chunk.values()]
        assert to_chunk_lengths[1:] == to_chunk_lengths[:-1], (
            'The shapes along the chunk dimension of all entries to chunk must '
            'be equal!\n'
            f'chunk_keys: {self.chunk_keys}\n'
            f'to_chunk_lengths: {to_chunk_lengths}')
        to_chunk_length = to_chunk_lengths[0]

        # Discard examples that are shorter than `chunk_size`
        if to_chunk_length < self.chunk_size:
            raise lazy_dataset.FilterException()

        # Cut overlapping chunks with a hop of half a chunk.
        # max(1, ...) guards against a zero step, for which range() would
        # raise a ValueError when chunk_size == 1.
        shift = max(1, self.chunk_size // 2)

        chunks = []
        for chunk_beginning in range(
                0,
                to_chunk_length - self.chunk_size + 1,  # only full sizes
                shift,
        ):
            chunk_end = chunk_beginning + self.chunk_size
            chunk = deepcopy(example)
            chunk.update({
                k: _getitem_on_axis(v,
                                    slice(chunk_beginning, chunk_end),
                                    axis=self.axis)
                for k, v in to_chunk.items()
            })
            chunk.update(num_samples=self.chunk_size)
            chunks.append(chunk)

        return chunks
Example #3
0
    def __call__(self, example: dict, rng=np.random) -> List[dict]:
        """Segment the signals in ``example`` into a list of examples.

        The example is flattened, the entries selected by
        ``self.get_to_segment_keys`` are popped out and passed to
        ``self.segment``, and each resulting segment is merged with the
        copied (non-segmented) entries into its own (de-flattened) example.

        Args:
            example: dictionary with string keys
            rng: random number generator, maybe set using
                paderbox.utils.random_utils.str_to_random_state

        Returns:
            List of example dicts, each with ``segment_start`` and
            ``segment_stop`` entries added.

        Raises:
            lazy_dataset.FilterException: If the signals are shorter than
                ``self.length`` (unless ``self.mode == 'max'``).
            ValueError: If a selected entry is neither a numpy array nor a
                torch tensor.
            TypeError: If ``self.copy_keys`` has an unsupported type.
        """

        example = flatten(example, sep=self.flatten_separator)

        to_segment_keys = self.get_to_segment_keys(example)
        axis = self.get_axis_list(to_segment_keys)

        to_segment = {
            key: example.pop(key) for key in to_segment_keys
        }

        # `copy_keys` is either a list of keys to copy, or `[True]`
        # (copy everything that is not segmented) or `[False]` (copy
        # nothing).
        if all(isinstance(key, str) for key in self.copy_keys):
            to_copy = {key: example.pop(key) for key in self.copy_keys}
        elif self.copy_keys[0] is True:
            assert len(self.copy_keys) == 1, self.copy_keys
            to_copy = example
        elif self.copy_keys[0] is False:
            assert len(self.copy_keys) == 1, self.copy_keys
            to_copy = dict()
        else:
            raise TypeError('Unknown type for copy keys', self.copy_keys)

        if any(not isinstance(value, (np.ndarray, torch.Tensor))
               for value in to_segment.values()):
            raise ValueError(
                'This segmenter only works on numpy arrays and torch '
                'tensors',
                'However, the following keys point to other types:',
                '\n'.join([f'{key} points to a {type(to_segment[key])}'
                           for key in to_segment_keys])
            )

        to_segment_lengths = [
            v.shape[axis[i]] for i, v in enumerate(to_segment.values())]

        assert to_segment_lengths[1:] == to_segment_lengths[:-1], (
            'The shapes along the segment dimension of all entries to segment'
            ' must be equal!\n'
            f'segment keys: {to_segment_keys}\n'
            f'to_segment_lengths: {to_segment_lengths}'
        )
        assert len(to_segment) > 0, ('Did not find any signals to segment',
                                     self.include, self.exclude, to_segment)
        to_segment_length = to_segment_lengths[0]

        # Discard examples that are shorter than `length`
        if self.mode != 'max' and to_segment_length < self.length:
            import lazy_dataset
            raise lazy_dataset.FilterException()

        # Shortcut if segmentation is disabled
        if self.length == -1:
            to_copy.update(to_segment)
            to_copy.update(segment_start=0, segment_stop=to_segment_length)
            return [deflatten(to_copy)]

        boundaries, segmented = self.segment(to_segment, to_segment_length,
                                             axis=axis, rng=rng)

        segmented_examples = list()

        for idx, (start, stop) in enumerate(boundaries):
            # NOTE: shallow copy -- the non-segmented values in `to_copy`
            # are shared between all returned segments.
            example_copy = copy(to_copy)
            example_copy.update({key: value[idx]
                                 for key, value in segmented.items()})
            example_copy.update(segment_start=start, segment_stop=stop)
            segmented_examples.append(deflatten(example_copy))
        return segmented_examples
Example #4
0
 def map_function(example):
     """Pass *example* through unchanged, but filter out value == 3."""
     if example['value'] != 3:
         return example
     raise lazy_dataset.FilterException('Got 3')