def _test_index_iterator_no_shuffle():
    """IndexIterator without shuffling walks ``index_list`` cyclically, in order.

    Three successive draws (3 indices, 6 indices, then ``__next__`` with
    ``num=2``) must each be numpy arrays of the requested length whose
    concatenation equals ``index_list`` repeated end to end.
    """
    index_list = [1, 3, 5, 10]
    iterator = IndexIterator(index_list, shuffle=False, num=2)

    draws = [
        iterator.get_next_indices(3),
        iterator.get_next_indices(6),
        iterator.__next__(),
    ]
    expected_lengths = [3, 6, 2]

    for draw, expected_len in zip(draws, expected_lengths):
        assert isinstance(draw, numpy.ndarray)
        assert len(draw) == expected_len

    # Element k of the combined stream should be index_list[k % len].
    total = sum(expected_lengths)
    stream = [index_list[k % len(index_list)] for k in range(total)]
    offset = 0
    for draw, expected_len in zip(draws, expected_lengths):
        assert draw.tolist() == stream[offset:offset + expected_len]
        offset += expected_len
def _test_index_iterator_with_shuffle():
    """Shuffled IndexIterator honours batch sizes and only yields known indices.

    Draws 3 indices, 6 indices, then ``num`` (=2) via ``__next__``; each draw
    must be a numpy array of that length containing only values from
    ``index_list``.
    """
    index_list = [1, 3, 5, 10]
    iterator = IndexIterator(index_list, shuffle=True, num=2)

    # Built in order, so the three draws happen in the same sequence.
    draws_with_lengths = [
        (iterator.get_next_indices(3), 3),
        (iterator.get_next_indices(6), 6),
        (iterator.__next__(), 2),
    ]

    for draw, expected_len in draws_with_lengths:
        assert isinstance(draw, numpy.ndarray)
        assert len(draw) == expected_len

    allowed = set(index_list)
    for draw, _ in draws_with_lengths:
        assert all(value in allowed for value in draw)
def _test_index_iterator_serialization_no_shuffle():
    """Round-trip an unshuffled IndexIterator's state through (de)serialization.

    Advances the iterator by 3, 6 and then ``num`` (=2) indices, serializes it
    into a plain dict, restores that dict into a fresh iterator and checks that
    the restored iterator holds the same state and continues the index stream
    identically.

    NOTE(review): a second function with this exact name is defined later in
    this file and shadows this one at import time; one copy should be removed.
    """
    index_list = [1, 3, 5, 10]
    # 3 + 6 indices wrap around the 4-element list, then __next__ takes 2 more.
    expected_pos = (3 + 6) % len(index_list) + 2

    ii = IndexIterator(index_list, shuffle=False, num=2)
    # Only the position advance matters here; the drawn indices are discarded
    # rather than bound to unused locals.
    ii.get_next_indices(3)
    ii.get_next_indices(6)
    ii.__next__()

    assert len(ii.current_index_list) == len(index_list)
    assert numpy.array_equal(ii.current_index_list, numpy.asarray(index_list))
    assert ii.current_pos == expected_pos

    target = dict()
    ii.serialize(DummySerializer(target))

    restored = IndexIterator(index_list, shuffle=False, num=2)
    restored.serialize(DummyDeserializer(target))
    assert len(restored.current_index_list) == len(index_list)
    assert numpy.array_equal(restored.current_index_list,
                             numpy.asarray(index_list))
    assert restored.current_pos == expected_pos

    # Without shuffling the stream is deterministic, so the restored iterator
    # must pick up exactly where the original left off.
    assert numpy.array_equal(restored.get_next_indices(5),
                             ii.get_next_indices(5))
def _test_index_iterator_serialization_with_shuffle():
    """Round-trip a shuffled IndexIterator's state through (de)serialization.

    Advances the iterator by 3, 6 and then ``num`` (=2) indices, serializes it
    into a plain dict, restores that dict into a fresh shuffled iterator and
    checks that the random order and read position survive the round trip.

    NOTE(review): a second function with this exact name is defined later in
    this file and shadows this one at import time; one copy should be removed.
    """
    index_list = [1, 3, 5, 10]
    # 3 + 6 indices wrap around the 4-element list, then __next__ takes 2 more.
    expected_pos = (3 + 6) % len(index_list) + 2

    ii = IndexIterator(index_list, shuffle=True, num=2)
    # Only the position advance matters here; the drawn indices are discarded
    # rather than bound to unused locals.
    ii.get_next_indices(3)
    ii.get_next_indices(6)
    ii.__next__()

    # A shuffled order must be a permutation of index_list: same elements with
    # the same multiplicities, not merely values that appear in it.
    assert len(ii.current_index_list) == len(index_list)
    assert sorted(ii.current_index_list) == sorted(index_list)
    assert ii.current_pos == expected_pos

    target = dict()
    ii.serialize(DummySerializer(target))
    # Copy the random order so later in-place writes cannot alter the snapshot.
    current_index_list_orig = numpy.array(ii.current_index_list)

    restored = IndexIterator(index_list, shuffle=True, num=2)
    restored.serialize(DummyDeserializer(target))
    assert numpy.array_equal(restored.current_index_list,
                             current_index_list_orig)
    assert restored.current_pos == expected_pos
def __init__(self, dataset, batch_size, labels, repeat=True, shuffle=True,
             batch_balancing=False, ignore_labels=None,
             logger=getLogger(__name__)):
    """Set up one IndexIterator per distinct label over ``dataset``.

    Args:
        dataset: Sized dataset; ``len(dataset)`` must equal the total number
            of elements in ``labels``.
        batch_size: Stored as ``self.batch_size``.
        labels: Array-like of per-example labels. Any shape is accepted as
            long as its total size equals ``len(dataset)``; it is flattened.
        repeat: Stored as ``self._repeat``.
        shuffle: Passed to every per-label ``IndexIterator`` and stored as
            ``self._shuffle``.
        batch_balancing: Stored as ``self._batch_balancing``.
        ignore_labels: ``None``, a single scalar label, or an iterable of
            labels. Ignored labels still get an iterator but are excluded
            from ``max_label_count`` and ``N_augmented``.
        logger: Stored as ``self.logger``.

    Raises:
        ValueError: If ``len(dataset)`` differs from the size of ``labels``.
    """
    labels = numpy.asarray(labels)
    # Validate with an explicit ValueError on the total size rather than an
    # ``assert`` on len(): assert is stripped under ``python -O``, raises the
    # wrong exception type, and would reject valid shapes such as (1, N).
    if len(dataset) != labels.size:
        raise ValueError('dataset length {} and labels size {} must be '
                         'same!'.format(len(dataset), labels.size))
    labels = numpy.ravel(labels)
    self.dataset = dataset
    self.batch_size = batch_size
    self.labels = labels
    self.logger = logger

    if ignore_labels is None:
        ignore_labels = []
    elif numpy.isscalar(ignore_labels):
        # A single label. numpy.isscalar (unlike isinstance(..., int)) also
        # accepts numpy integer scalars, and keeps a string label whole
        # instead of letting list() split it into characters.
        ignore_labels = [ignore_labels]
    self.ignore_labels = list(ignore_labels)
    self._repeat = repeat
    self._shuffle = shuffle
    self._batch_balancing = batch_balancing

    # Map each distinct label to an IndexIterator over the positions that
    # carry that label.
    self.labels_iterator_dict = {}
    max_label_count = -1
    include_label_count = 0
    for label in numpy.unique(labels):
        label_index = numpy.flatnonzero(labels == label)
        label_count = len(label_index)
        self.labels_iterator_dict[label] = IndexIterator(label_index,
                                                         shuffle=shuffle)
        if label in self.ignore_labels:
            continue
        max_label_count = max(max_label_count, label_count)
        include_label_count += 1

    # Largest per-label count among non-ignored labels (-1 if there are none).
    # N_augmented is that count times the number of non-ignored labels --
    # presumably the epoch size when batch balancing upsamples every label to
    # the largest one; confirm against reset()/__next__.
    self.max_label_count = max_label_count
    self.N_augmented = max_label_count * include_label_count
    self.reset()
def _test_index_iterator_serialization_no_shuffle():
    """Round-trip an unshuffled IndexIterator's state through (de)serialization.

    Advances the iterator by 3, 6 and then ``num`` (=2) indices, serializes it
    into a plain dict, restores that dict into a fresh iterator and checks that
    the restored iterator holds the same state and continues the index stream
    identically.

    NOTE(review): this redefines an identically-named function that appears
    earlier in this file; this later copy is the one that takes effect. One of
    the two should be removed.
    """
    index_list = [1, 3, 5, 10]
    # 3 + 6 indices wrap around the 4-element list, then __next__ takes 2 more.
    expected_pos = (3 + 6) % len(index_list) + 2

    ii = IndexIterator(index_list, shuffle=False, num=2)
    # Only the position advance matters here; the drawn indices are discarded
    # rather than bound to unused locals.
    ii.get_next_indices(3)
    ii.get_next_indices(6)
    ii.__next__()

    assert len(ii.current_index_list) == len(index_list)
    assert numpy.array_equal(ii.current_index_list, numpy.asarray(index_list))
    assert ii.current_pos == expected_pos

    target = dict()
    ii.serialize(DummySerializer(target))

    restored = IndexIterator(index_list, shuffle=False, num=2)
    restored.serialize(DummyDeserializer(target))
    assert len(restored.current_index_list) == len(index_list)
    assert numpy.array_equal(restored.current_index_list,
                             numpy.asarray(index_list))
    assert restored.current_pos == expected_pos

    # Without shuffling the stream is deterministic, so the restored iterator
    # must pick up exactly where the original left off.
    assert numpy.array_equal(restored.get_next_indices(5),
                             ii.get_next_indices(5))
def _test_index_iterator_serialization_with_shuffle():
    """Round-trip a shuffled IndexIterator's state through (de)serialization.

    Advances the iterator by 3, 6 and then ``num`` (=2) indices, serializes it
    into a plain dict, restores that dict into a fresh shuffled iterator and
    checks that the random order and read position survive the round trip.

    NOTE(review): this redefines an identically-named function that appears
    earlier in this file; this later copy is the one that takes effect. One of
    the two should be removed.
    """
    index_list = [1, 3, 5, 10]
    # 3 + 6 indices wrap around the 4-element list, then __next__ takes 2 more.
    expected_pos = (3 + 6) % len(index_list) + 2

    ii = IndexIterator(index_list, shuffle=True, num=2)
    # Only the position advance matters here; the drawn indices are discarded
    # rather than bound to unused locals.
    ii.get_next_indices(3)
    ii.get_next_indices(6)
    ii.__next__()

    # A shuffled order must be a permutation of index_list: same elements with
    # the same multiplicities, not merely values that appear in it.
    assert len(ii.current_index_list) == len(index_list)
    assert sorted(ii.current_index_list) == sorted(index_list)
    assert ii.current_pos == expected_pos

    target = dict()
    ii.serialize(DummySerializer(target))
    # Copy the random order so later in-place writes cannot alter the snapshot.
    current_index_list_orig = numpy.array(ii.current_index_list)

    restored = IndexIterator(index_list, shuffle=True, num=2)
    restored.serialize(DummyDeserializer(target))
    assert numpy.array_equal(restored.current_index_list,
                             current_index_list_orig)
    assert restored.current_pos == expected_pos