예제 #1
0
def test_ExampleConcatenatedDataset_step():
    """Check ``set_example_pars(step=...)`` on an ExampleConcatenatedDataset.

    With ``step=2`` only every other concatenated dataset contributes to an
    example and to the labels. Concatenating datasets of different lengths
    must fail with an ``AssertionError``.
    """
    D1 = DebugDataset(size=10)
    D2 = DebugDataset(size=10)

    E = ExampleConcatenatedDataset(D1, D2)
    E.set_example_pars(step=2)
    assert len(E) == 10
    d = E[2]

    # Two datasets with step=2: only the first one is selected.
    assert d == {"val": [2], "index_": 2, "other": [2]}

    assert len(E.labels["label1"]) == 10
    # np.array_equal also checks the shape; `np.all(a == b)` would broadcast
    # a single selected label column against two expected columns and pass
    # even though the slicing differs from the expectation.
    assert np.array_equal(E.labels["label1"], [[i] for i in range(10)])

    D3 = DebugDataset(size=20)

    with pytest.raises(AssertionError):
        ExampleConcatenatedDataset(D1, D3)

    D4 = DebugDataset(size=10)
    D5 = DebugDataset(size=10)

    E = ExampleConcatenatedDataset(D1, D2, D4, D5)
    E.set_example_pars(step=2)
    assert len(E) == 10
    d = E[2]

    # Four datasets with step=2: the first and third ones are selected.
    assert d == {"val": [2, 2], "index_": 2, "other": [2, 2]}

    assert len(E.labels["label1"]) == 10
    assert np.array_equal(E.labels["label1"], [[i, i] for i in range(10)])
예제 #2
0
def test_ExampleConcatenatedDataset_slicing():
    """Exercise start/stop/step slicing over four offset datasets.

    Dataset ``k`` is built with offset ``k``, so the offsets seen in an
    example identify exactly which concatenated datasets were selected.
    """
    datasets = [
        DebugDataset(size=10),
        DebugDataset(size=10, offset=1),
        DebugDataset(size=10, offset=2),
        DebugDataset(size=10, offset=3),
    ]
    concat = ExampleConcatenatedDataset(*datasets)

    # (slice parameters, offsets of the datasets that slice selects)
    cases = [
        (dict(start=1, step=2), [1, 3]),
        (dict(start=0, stop=-1, step=2), [0, 2]),
        (dict(start=1, stop=-1, step=2), [1]),
    ]

    for pars, offsets in cases:
        concat.set_example_pars(**pars)
        assert len(concat) == 10

        example = concat[2]
        expected = [2 + off for off in offsets]
        assert example == {
            "val": list(expected),
            "index_": 2,
            "other": list(expected),
        }

        assert len(concat.labels["label1"]) == 10
        expected_labels = [[row + off for off in offsets] for row in range(10)]
        assert np.all(concat.labels["label1"] == expected_labels)
예제 #3
0
파일: sequence.py 프로젝트: mritv/edflow
class SequenceDataset(DatasetMixin):
    """Wraps around a dataset and returns sequences of examples.

    Given the length of those sequences the number of available examples
    is reduced by this length times the step taken. Additionally each
    example must have a frame id :attr:`fid_key` specified in the labels, by
    which it can be filtered.
    This is to ensure that each frame is taken from the same video.

    This class assumes that examples come sequentially with :attr:`fid_key` and
    that frame id ``0`` exists.

    The SequenceDataset also exposes the attribute ``self.base_indices``,
    which holds at each index ``i`` the indices (into the wrapped dataset) of
    the frames spanned by example ``i``, earliest first. It has
    ``length * step`` columns; for ``step > 1`` the example itself only uses
    every ``step``-th of them, i.e. ``base_indices[i, ::step]``.
    """
    def __init__(self,
                 dataset,
                 length,
                 step=1,
                 fid_key="fid",
                 strategy="raise"):
        """
        Parameters
        ----------
        dataset : DatasetMixin
            Dataset from which single frame examples are taken.
        length : int
            Length of the returned sequences in frames.
        step : int
            Step between returned frames. Must be `>= 1`.
        fid_key : str
            Key in labels, at which the frame indices can be found.
        strategy : str
            How to handle bad sequences, i.e. sequences starting with a
            :attr:`fid_key` > 0.
            - ``raise``: Raise a ``ValueError``
            - ``remove``: remove the sequence
            - ``reset``: shift the frame ids of the sequence so that it
              starts at ``0``. The shifted ids are written back into
              ``dataset.labels[fid_key]``.

        Raises
        ------
        ValueError
            If the frame ids are not a 1-D sequence with one entry per
            example, if bad sequences exist and ``strategy == "raise"``, or
            if bad sequences exist and ``strategy`` is not recognised.
        TypeError
            If the frame ids do not have an integer dtype.

        The resulting dataset has one example per frame whose id is at least
        ``length * step - 1``, i.e. ``max(0, n - length * step + 1)`` examples
        for every sequence of ``n`` consecutive frames starting at id ``0``.
        """

        self.step = step
        self.length = length

        # np.array copies, so later in-place edits never touch the labels.
        frame_ids = np.array(dataset.labels[fid_key])
        if frame_ids.ndim != 1 or len(frame_ids) != len(dataset):
            raise ValueError(
                "Frame ids must be supplied as a sequence of "
                "scalars with the same length as the dataset! Here we "
                "have np.shape(dataset.labels[{}]) = {}.".format(
                    fid_key, np.shape(frame_ids)))
        # `np.int` was removed in NumPy 1.24 and only ever matched the
        # platform's default integer; accept every integer dtype instead.
        if not np.issubdtype(frame_ids.dtype, np.integer):
            raise TypeError(
                "Frame ids must be supplied as ints, but are {}".format(
                    frame_ids.dtype))

        # Differences between consecutive frame ids.
        diffs = frame_ids[1:] - frame_ids[:-1]
        # Start index of every run of consecutive frame ids: index 0 plus
        # each index where the id does not grow by exactly 1.
        idxs = np.array([0] + list(np.where(diffs != 1)[0] + 1))
        # Frame ids at these start indices.
        start_fids = frame_ids[idxs]

        # Runs that do not start at frame id 0.
        bad_starts = start_fids != 0
        if np.any(bad_starts):
            n = int(np.count_nonzero(bad_starts))
            i_s = "" if n == 1 else "s"
            areis = "is" if n == 1 else "are"
            id_s = "ex" if n == 1 else "ices"

            # Exclusive end index of every run; `None` slices to the end.
            idxs_stop = np.array(list(idxs[1:]) + [None])
            starts = idxs[bad_starts]
            stops = idxs_stop[bad_starts]

            if strategy == "raise":
                raise ValueError(
                    "Frame id sequences must always start with 0. "
                    "There {} {} sequence{} starting with the following id{}: "
                    "{} at ind{} {} in the dataset.".format(
                        areis, n, i_s, i_s, start_fids[bad_starts], id_s,
                        idxs[bad_starts]))

            elif strategy == "remove":
                keep = np.ones(len(dataset), dtype=bool)
                for bad_start, bad_stop in zip(starts, stops):
                    keep[bad_start:bad_stop] = False

                dataset = SubDataset(dataset, np.flatnonzero(keep))

                frame_ids = dataset.labels[fid_key]

            elif strategy == "reset":
                vals = start_fids[bad_starts]

                # Shift each bad run so that its first frame gets id 0.
                for val, bad_start, bad_stop in zip(vals, starts, stops):
                    frame_ids[bad_start:bad_stop] -= val

                # NOTE: this overwrites the labels of the passed-in dataset.
                dataset.labels[fid_key] = frame_ids

                frame_ids = dataset.labels[fid_key]
            else:
                raise ValueError("Strategy of SequenceDataset must be one of "
                                 "`raise`, `remove` or `reset` but is "
                                 "{}".format(strategy))

        # Frames whose id is at least `length * step - 1` have enough
        # predecessors in their own run to span a whole sequence.
        top_indeces = np.where(np.array(frame_ids) >= (length * step - 1))[0]

        all_subdatasets = []
        base_indices = []
        for i in range(length * step):
            # The i-th predecessor of every top frame.
            indices = top_indeces - i
            base_indices.append(indices)
            all_subdatasets.append(SubDataset(dataset, indices))

        # Order the sub-datasets from the earliest to the latest frame.
        all_subdatasets = all_subdatasets[::-1]

        self.data = ExampleConcatenatedDataset(*all_subdatasets)
        # Select every `step`-th of the `length * step` spanned frames.
        self.data.set_example_pars(step=self.step)
        # Shape (num_examples, length * step), earliest index first.
        self.base_indices = np.array(base_indices).transpose(1, 0)[:, ::-1]