コード例 #1
0
def _test_index_iterator_no_shuffle():
    """Without shuffling, IndexIterator must walk ``index_list`` cyclically.

    Draws 3 indices, then 6, then one default-sized batch via ``__next__``
    (``num=2``).  Each batch must be a numpy array of the requested size,
    and the batches concatenated must equal ``index_list`` repeated in its
    original order, wrapping back to the start after the last element.
    """
    index_list = [1, 3, 5, 10]
    iterator = IndexIterator(index_list, shuffle=False, num=2)

    first = iterator.get_next_indices(3)
    second = iterator.get_next_indices(6)
    third = iterator.__next__()

    # Type and size of every batch, checked in draw order.
    for batch, size in ((first, 3), (second, 6), (third, 2)):
        assert isinstance(batch, numpy.ndarray)
        assert len(batch) == size

    # The k-th drawn index overall should be index_list[k % len(index_list)].
    drawn = numpy.concatenate([first, second, third])
    expected = [index_list[k % len(index_list)] for k in range(len(drawn))]
    assert numpy.array_equal(drawn, expected)
コード例 #2
0
def _test_index_iterator_no_shuffle():
    """Check that an unshuffled IndexIterator yields indices in cyclic order.

    Three draws are made (3 indices, 6 indices, then ``__next__`` with the
    default ``num=2``).  Every draw must be a numpy array of the requested
    length, and the full sequence of drawn indices must be ``index_list``
    tiled end-to-end.
    """
    index_list = [1, 3, 5, 10]
    it = IndexIterator(index_list, shuffle=False, num=2)

    draws = [it.get_next_indices(3), it.get_next_indices(6), it.__next__()]
    sizes = [3, 6, 2]

    for draw, size in zip(draws, sizes):
        assert isinstance(draw, numpy.ndarray)
        assert len(draw) == size

    # numpy.resize repeats its input cyclically to fill the requested size,
    # which is exactly the order a non-shuffling iterator should produce.
    stream = numpy.concatenate(draws)
    assert numpy.array_equal(
        stream, numpy.resize(numpy.asarray(index_list), stream.size))
コード例 #3
0
def _test_index_iterator_with_shuffle():
    """With shuffling, batches keep their sizes and contain only known indices.

    The draw order is random, so only the batch type, the batch length and
    the membership of every drawn index in ``index_list`` are checked.
    """
    index_list = [1, 3, 5, 10]
    iterator = IndexIterator(index_list, shuffle=True, num=2)

    # (batch, expected length) pairs; the draws happen left to right.
    batches = [
        (iterator.get_next_indices(3), 3),
        (iterator.get_next_indices(6), 6),
        (iterator.__next__(), 2),
    ]

    for batch, size in batches:
        assert isinstance(batch, numpy.ndarray)
        assert len(batch) == size

    for batch, _ in batches:
        for value in batch:
            assert value in index_list
コード例 #4
0
def _test_index_iterator_with_shuffle():
    """Check shape and membership of batches from a shuffling IndexIterator.

    Makes three draws (3, 6, then ``__next__`` with ``num=2``).  Each must be
    a numpy array of the requested length, and every value in it must be one
    of the entries of ``index_list``.
    """
    index_list = [1, 3, 5, 10]
    it = IndexIterator(index_list, shuffle=True, num=2)

    draws = [it.get_next_indices(3), it.get_next_indices(6), it.__next__()]

    for draw, size in zip(draws, (3, 6, 2)):
        assert isinstance(draw, numpy.ndarray)
        assert len(draw) == size

    for draw in draws:
        # Element-wise membership test against the allowed index values.
        assert numpy.isin(draw, index_list).all()
コード例 #5
0
def _test_index_iterator_serialization_no_shuffle():
    """Round-trip an unshuffled IndexIterator through (de)serialization.

    The iterator is advanced by 3, 6 and one ``__next__`` batch (``num=2``).
    Its state is written into a plain dict with ``DummySerializer`` and
    loaded into a fresh iterator with ``DummyDeserializer``.  Before and
    after the round trip the index list must be ``index_list`` unchanged
    and the read position must be ``(3 + 6) % len(index_list) + 2``.
    """
    index_list = [1, 3, 5, 10]
    expected_list = numpy.asarray(index_list)
    expected_pos = (3 + 6) % len(index_list) + 2

    def check_state(it):
        # Shared assertions for the original and the restored iterator.
        assert len(it.current_index_list) == len(index_list)
        assert numpy.array_equal(it.current_index_list, expected_list)
        assert it.current_pos == expected_pos

    original = IndexIterator(index_list, shuffle=False, num=2)
    # Only the side effect of advancing the iterator matters here.
    original.get_next_indices(3)
    original.get_next_indices(6)
    original.__next__()
    check_state(original)

    state = {}
    original.serialize(DummySerializer(state))

    restored = IndexIterator(index_list, shuffle=False, num=2)
    restored.serialize(DummyDeserializer(state))
    check_state(restored)
コード例 #6
0
def _test_index_iterator_serialization_with_shuffle():
    """Round-trip a shuffled IndexIterator through (de)serialization.

    After advancing the iterator by 3, 6 and one ``__next__`` batch
    (``num=2``), its current index order (a shuffle of ``index_list``) and
    read position are saved with ``DummySerializer``.  A freshly built
    shuffling iterator is then loaded with ``DummyDeserializer``.  It must
    end up with exactly the saved order and the same position.
    """
    index_list = [1, 3, 5, 10]
    expected_pos = (3 + 6) % len(index_list) + 2

    source = IndexIterator(index_list, shuffle=True, num=2)
    source.get_next_indices(3)
    source.get_next_indices(6)
    source.__next__()

    assert len(source.current_index_list) == len(index_list)
    for value in source.current_index_list:
        assert value in index_list
    assert source.current_pos == expected_pos

    state = {}
    source.serialize(DummySerializer(state))
    saved_order = source.current_index_list

    # A new shuffling iterator would normally draw its own random order;
    # deserialization must overwrite it with the saved one.
    restored = IndexIterator(index_list, shuffle=True, num=2)
    restored.serialize(DummyDeserializer(state))
    assert numpy.array_equal(restored.current_index_list, saved_order)
    assert restored.current_pos == expected_pos
コード例 #7
0
    def __init__(self,
                 dataset,
                 batch_size,
                 labels,
                 repeat=True,
                 shuffle=True,
                 batch_balancing=False,
                 ignore_labels=None,
                 logger=getLogger(__name__)):
        """Set up one ``IndexIterator`` per distinct label.

        Args:
            dataset: Sized dataset; ``len(dataset)`` must match both
                ``len(labels)`` and the total number of label elements.
            batch_size (int): Stored as ``self.batch_size``.
            labels: Per-example labels.  Converted with ``numpy.asarray`` and
                flattened with ``numpy.ravel`` before use.
            repeat (bool): Stored as ``self._repeat``.
            shuffle (bool): Passed to every per-label ``IndexIterator`` and
                stored as ``self._shuffle``.
            batch_balancing (bool): Stored as ``self._batch_balancing``.
            ignore_labels: ``None``, a single integer label (Python ``int``
                or numpy integer scalar), or an iterable of labels.  Ignored
                labels still get an iterator in ``labels_iterator_dict`` but
                are excluded from ``max_label_count`` and ``N_augmented``.
            logger: Stored as ``self.logger``.

        Raises:
            ValueError: If ``len(dataset)`` differs from ``len(labels)`` or
                from ``numpy.asarray(labels).size``.
        """
        # Explicit exception instead of ``assert``: input validation must not
        # disappear under ``python -O``.
        if len(dataset) != len(labels):
            raise ValueError('dataset length {} and labels length {} must be '
                             'same!'.format(len(dataset), len(labels)))
        labels = numpy.asarray(labels)
        # Also guards multi-column labels such as shape (N, 2), which would
        # otherwise flatten to 2N entries.
        if len(dataset) != labels.size:
            raise ValueError('dataset length {} and labels size {} must be '
                             'same!'.format(len(dataset), labels.size))
        labels = numpy.ravel(labels)
        self.dataset = dataset
        self.batch_size = batch_size
        self.labels = labels
        self.logger = logger

        if ignore_labels is None:
            ignore_labels = []
        elif isinstance(ignore_labels, (int, numpy.integer)):
            # Accept numpy integer scalars too; ``list()`` on one would
            # raise TypeError.
            ignore_labels = [
                ignore_labels,
            ]
        self.ignore_labels = list(ignore_labels)
        self._repeat = repeat
        self._shuffle = shuffle
        self._batch_balancing = batch_balancing

        # label value -> IndexIterator over the positions holding that label.
        self.labels_iterator_dict = {}

        max_label_count = -1
        include_label_count = 0
        for label in numpy.unique(labels):
            label_index = numpy.argwhere(labels == label).ravel()
            label_count = len(label_index)
            ii = IndexIterator(label_index, shuffle=shuffle)
            self.labels_iterator_dict[label] = ii
            if label in self.ignore_labels:
                continue
            if max_label_count < label_count:
                max_label_count = label_count
            include_label_count += 1

        # If every label is ignored (or there are none), max_label_count
        # stays -1 and N_augmented is 0.  Otherwise N_augmented is the
        # largest non-ignored class size times the number of non-ignored
        # classes (presumably the length of one class-balanced epoch; confirm
        # against reset()/__next__).
        self.max_label_count = max_label_count
        self.N_augmented = max_label_count * include_label_count
        self.reset()
コード例 #8
0
def _test_index_iterator_serialization_no_shuffle():
    """Round-trip an unshuffled IndexIterator through (de)serialization.

    Advances the iterator by 3, 6 and one ``__next__`` batch (``num=2``).
    It then checks the index list and read position, serializes into a dict,
    deserializes into a fresh iterator, and re-checks the same state there.
    """
    index_list = [1, 3, 5, 10]
    ii = IndexIterator(index_list, shuffle=False, num=2)

    # Only the side effect of advancing the iterator matters.  The returned
    # batches are not bound to names, avoiding unused locals (flake8 F841).
    ii.get_next_indices(3)
    ii.get_next_indices(6)
    ii.__next__()

    expected_pos = (3 + 6) % len(index_list) + 2
    assert len(ii.current_index_list) == len(index_list)
    assert numpy.array_equal(ii.current_index_list, numpy.asarray(index_list))
    assert ii.current_pos == expected_pos

    target = dict()
    ii.serialize(DummySerializer(target))

    ii = IndexIterator(index_list, shuffle=False, num=2)
    ii.serialize(DummyDeserializer(target))
    assert len(ii.current_index_list) == len(index_list)
    assert numpy.array_equal(ii.current_index_list, numpy.asarray(index_list))
    assert ii.current_pos == expected_pos
コード例 #9
0
def _test_index_iterator_serialization_with_shuffle():
    """Round-trip a shuffled IndexIterator through (de)serialization.

    After advancing by 3, 6 and one ``__next__`` batch (``num=2``), the
    iterator's shuffled index order and read position are serialized.  A
    fresh shuffling iterator loaded from that state must reproduce the saved
    order exactly and resume at the same position.
    """
    index_list = [1, 3, 5, 10]
    ii = IndexIterator(index_list, shuffle=True, num=2)

    # Advance the iterator for its side effect only.  The batches are not
    # bound to names, avoiding unused locals (flake8 F841).
    ii.get_next_indices(3)
    ii.get_next_indices(6)
    ii.__next__()

    expected_pos = (3 + 6) % len(index_list) + 2
    assert len(ii.current_index_list) == len(index_list)
    for index in ii.current_index_list:
        assert index in index_list
    assert ii.current_pos == expected_pos

    target = dict()
    ii.serialize(DummySerializer(target))
    current_index_list_orig = ii.current_index_list

    ii = IndexIterator(index_list, shuffle=True, num=2)
    ii.serialize(DummyDeserializer(target))
    assert numpy.array_equal(ii.current_index_list, current_index_list_orig)
    assert ii.current_pos == expected_pos