def test_ExampleConcatenatedDataset_step():
    """Stepped example concatenation keeps dataset length and stacks labels."""
    base_a = DebugDataset(size=10)
    base_b = DebugDataset(size=10)

    concat = ExampleConcatenatedDataset(base_a, base_b)
    concat.set_example_pars(step=2)
    assert len(concat) == 10
    example = concat[2]
    assert example == {"val": [2], "index_": 2, "other": [2]}
    assert len(concat.labels["label1"]) == 10
    assert np.all(concat.labels["label1"] == [[i, i] for i in range(10)])

    # Concatenating datasets of unequal size must be rejected.
    mismatched = DebugDataset(size=20)
    with pytest.raises(AssertionError):
        ExampleConcatenatedDataset(base_a, mismatched)

    # Four equally-sized datasets with step=2 yield two entries per key.
    base_c = DebugDataset(size=10)
    base_d = DebugDataset(size=10)
    concat = ExampleConcatenatedDataset(base_a, base_b, base_c, base_d)
    concat.set_example_pars(step=2)
    assert len(concat) == 10
    example = concat[2]
    assert example == {"val": [2, 2], "index_": 2, "other": [2, 2]}
    assert len(concat.labels["label1"]) == 10
    assert np.all(concat.labels["label1"] == [[i, i] for i in range(10)])
def test_ExampleConcatenatedDataset_slicing():
    """start/stop/step example parameters select the expected sub-datasets."""
    offsets = [0, 1, 2, 3]
    parts = [DebugDataset(size=10, offset=o) if o else DebugDataset(size=10)
             for o in offsets]
    concat = ExampleConcatenatedDataset(*parts)

    # start=1, step=2 -> datasets with offsets 1 and 3.
    concat.set_example_pars(start=1, step=2)
    assert len(concat) == 10
    example = concat[2]
    assert example == {"val": [2 + 1, 2 + 3], "index_": 2, "other": [2 + 1, 2 + 3]}
    assert len(concat.labels["label1"]) == 10
    assert np.all(concat.labels["label1"] == [[i + 1, i + 3] for i in range(10)])

    # start=0, stop=-1, step=2 -> datasets with offsets 0 and 2.
    concat.set_example_pars(start=0, stop=-1, step=2)
    assert len(concat) == 10
    example = concat[2]
    assert example == {"val": [2, 2 + 2], "index_": 2, "other": [2, 2 + 2]}
    assert len(concat.labels["label1"]) == 10
    assert np.all(concat.labels["label1"] == [[i, i + 2] for i in range(10)])

    # start=1, stop=-1, step=2 -> only the dataset with offset 1.
    concat.set_example_pars(start=1, stop=-1, step=2)
    assert len(concat) == 10
    example = concat[2]
    assert example == {"val": [2 + 1], "index_": 2, "other": [2 + 1]}
    assert len(concat.labels["label1"]) == 10
    assert np.all(concat.labels["label1"] == [[i + 1] for i in range(10)])
class SequenceDataset(DatasetMixin):
    """Wraps around a dataset and returns sequences of examples.

    Given the length of those sequences the number of available examples
    is reduced by this length times the step taken. Additionally each
    example must have a frame id :attr:`fid_key` specified in the labels,
    by which it can be filtered. This is to ensure that each frame is
    taken from the same video.

    This class assumes that examples come sequentially with
    :attr:`fid_key` and that frame id ``0`` exists.

    The SequenceDataset also exposes the Attribute ``self.base_indices``,
    which holds at each index ``i`` the indices of the elements contained
    in the example from the sequentialized dataset.
    """

    def __init__(self, dataset, length, step=1, fid_key="fid", strategy="raise"):
        """
        Parameters
        ----------
        dataset : DatasetMixin
            Dataset from which single frame examples are taken.
        length : int
            Length of the returned sequences in frames.
        step : int
            Step between returned frames. Must be `>= 1`.
        fid_key : str
            Key in labels, at which the frame indices can be found.
        strategy : str
            How to handle bad sequences, i.e. sequences starting with a
            :attr:`fid_key` > 0.
            - ``raise``: Raise a ``ValueError``
            - ``remove``: remove the sequence
            - ``reset``: shift the frame ids of the sequence so that it
              starts at ``0`` again

        This dataset will have `len(dataset) - length * step` examples.
        """
        self.step = step
        self.length = length

        frame_ids = np.array(dataset.labels[fid_key])
        if frame_ids.ndim != 1 or len(frame_ids) != len(dataset):
            raise ValueError(
                "Frame ids must be supplied as a sequence of scalars with "
                "the same length as the dataset! Here we have "
                "np.shape(dataset.labels[{}]) = {}.".format(
                    fid_key, np.shape(frame_ids)))
        # ``np.int`` was removed in numpy >= 1.24; testing against the
        # abstract ``np.integer`` accepts every concrete integer dtype.
        if not np.issubdtype(frame_ids.dtype, np.integer):
            raise TypeError(
                "Frame ids must be supplied as ints, but are {}".format(
                    frame_ids.dtype))

        # Gradient of the frame ids: a new sequence starts wherever
        # consecutive ids do not increase by exactly one.
        diffs = frame_ids[1:] - frame_ids[:-1]
        # All indices where the fid is not monotonically growing,
        # i.e. the start index of each sequence (index 0 always starts one).
        idxs = np.array([0] + list(np.where(diffs != 1)[0] + 1))
        # Frame id at the start of each sequence
        start_fids = frame_ids[idxs]
        # Bad starts: sequences whose first frame id is not 0
        badboys = start_fids != 0
        if np.any(badboys):
            n = sum(badboys)
            # Pluralization helpers for the error message below.
            i_s = "" if n == 1 else "s"
            areis = "is" if n == 1 else "are"
            id_s = "ex" if n == 1 else "ices"
            if strategy == "raise":
                raise ValueError(
                    "Frame id sequences must always start with 0. "
                    "There {} {} sequence{} starting with the following id{}: "
                    "{} at ind{} {} in the dataset.".format(
                        areis, n, i_s, i_s,
                        start_fids[badboys], id_s, idxs[badboys]))
            elif strategy == "remove":
                # Drop every example belonging to a bad sequence by
                # masking its [start, stop) index range out of the dataset.
                idxs_stop = np.array(list(idxs[1:]) + [None])
                starts = idxs[badboys]
                stops = idxs_stop[badboys]

                bad_seq_mask = np.ones(len(dataset), dtype=bool)
                for bad_start_idx, bad_stop_idx in zip(starts, stops):
                    bad_seq_mask[bad_start_idx:bad_stop_idx] = False

                good_seq_idxs = np.arange(len(dataset))[bad_seq_mask]
                dataset = SubDataset(dataset, good_seq_idxs)
                frame_ids = dataset.labels[fid_key]
            elif strategy == "reset":
                frame_ids = np.copy(frame_ids)  # Don't try to override
                idxs_stop = np.array(list(idxs[1:]) + [None])
                starts = idxs[badboys]
                stops = idxs_stop[badboys]
                vals = start_fids[badboys]
                # Shift each bad sequence down by its first frame id so
                # that it starts at 0 again.
                for val, bad_sa_idx, bad_so_idx in zip(vals, starts, stops):
                    frame_ids[bad_sa_idx:bad_so_idx] = (
                        frame_ids[bad_sa_idx:bad_so_idx] - val)
                # NOTE(review): this writes the corrected ids back into the
                # wrapped dataset's labels, mutating it in place.
                dataset.labels[fid_key] = frame_ids
                frame_ids = dataset.labels[fid_key]
            else:
                raise ValueError("Strategy of SequenceDataset must be one of "
                                 "`raise`, `remove` or `reset` but is "
                                 "{}".format(strategy))

        # Frames that have enough predecessors to form a full sequence.
        top_indices = np.where(np.array(frame_ids) >= (length * step - 1))[0]

        all_subdatasets = []
        base_indices = []
        # One shifted SubDataset per sequence position; concatenating them
        # example-wise yields the frame sequences.
        for i in range(length * step):
            indices = top_indices - i
            base_indices += [indices]
            subdset = SubDataset(dataset, indices)
            all_subdatasets += [subdset]

        # Reverse so the sequences run oldest -> newest.
        all_subdatasets = all_subdatasets[::-1]
        self.data = ExampleConcatenatedDataset(*all_subdatasets)
        self.data.set_example_pars(step=self.step)
        # base_indices[i] lists, oldest first, the indices of the wrapped
        # dataset's elements that make up sequence example ``i``.
        self.base_indices = np.array(base_indices).transpose(1, 0)[:, ::-1]