Пример #1
0
def _test_balanced_serial_iterator_no_batch_balancing():
    """Check label oversampling when ``batch_balancing`` is disabled.

    Ignoring label -1, the labels are 0 (2 examples), 1 (3 examples) and
    2 (1 example). Each of the 3 kept labels is drawn as many times as the
    largest class (label 1, 3 examples). One epoch therefore holds
    3 * 3 = 9 "augmented" examples. A single batch of size 9 must contain
    exactly 3 examples of each kept label.
    """
    features = numpy.arange(8)
    targets = numpy.asarray([0, 0, -1, 1, 1, 2, -1, 1])
    dataset = NumpyTupleDataset(features, targets)
    iterator = BalancedSerialIterator(
        dataset, batch_size=9, labels=targets, ignore_labels=-1,
        batch_balancing=False)

    # 3 kept label types * 3 draws each (size of the largest class).
    assert iterator.N_augmented == 3 * 3
    # Label statistics can be inspected with iterator.show_label_stats().

    batch = iterator.next()
    assert len(batch) == 9

    # The label is the last element of each example tuple.
    batch_labels = numpy.array([example[-1] for example in batch])
    for label in (0, 1, 2):
        assert numpy.sum(batch_labels == label) == 3
Пример #2
0
def _test_balanced_serial_iterator_with_batch_balancing():
    """Check that batch balancing puts one example of each label per batch.

    With label -1 ignored, the augmented epoch size is 9 (3 kept labels,
    3 draws each). Three batches of size 3 are drawn, and each must contain
    labels 0, 1 and 2 exactly once.
    """
    features = numpy.arange(8)
    targets = numpy.asarray([0, 0, -1, 1, 1, 2, -1, 1])
    iterator = BalancedSerialIterator(
        NumpyTupleDataset(features, targets), batch_size=3, labels=targets,
        ignore_labels=-1, batch_balancing=True)
    assert iterator.N_augmented == 9

    # Draw all three batches up front, then inspect each one.
    batches = [iterator.next() for _ in range(3)]
    for batch in batches:
        assert len(batch) == 3
        batch_labels = numpy.array([example[-1] for example in batch])
        for label in (0, 1, 2):
            assert numpy.sum(batch_labels == label) == 1
Пример #3
0
def _test_balanced_serial_iterator_serialization_with_batch_balancing():
    """Round-trip the iterator state through DummySerializer/DummyDeserializer.

    Three batches of size 3 are consumed, after which the iterator reports
    epoch 1 at position 0. Its state is serialized into a dict and restored
    into a freshly constructed iterator. The restored iterator must match
    the original on the top-level fields and on every per-label index
    iterator (``current_index_list`` and ``current_pos``).
    """
    x = numpy.arange(8)
    t = numpy.asarray([0, 0, -1, 1, 1, 2, -1, 1])
    iterator = BalancedSerialIterator(NumpyTupleDataset(x, t),
                                      batch_size=3,
                                      labels=t,
                                      ignore_labels=-1,
                                      batch_balancing=True)
    # Advance the iterator; the batch contents are not inspected here.
    for _ in range(3):
        iterator.next()

    assert iterator.current_position == 0
    assert iterator.epoch == 1
    assert iterator.is_new_epoch

    target = dict()
    iterator.serialize(DummySerializer(target))

    # Snapshot the per-label index iterators, keyed the same way the
    # serialized state is keyed. The index list is copied so the expected
    # value cannot change if the original array is later mutated in place.
    current_index_list_orig = dict()
    current_pos_orig = dict()
    for label, index_iterator in iterator.labels_iterator_dict.items():
        ii_label = 'index_iterator_{}'.format(label)
        current_index_list_orig[ii_label] = numpy.array(
            index_iterator.current_index_list)
        current_pos_orig[ii_label] = index_iterator.current_pos
    # Guard against a vacuous comparison below.
    assert current_pos_orig

    iterator = BalancedSerialIterator(NumpyTupleDataset(x, t),
                                      batch_size=3,
                                      labels=t,
                                      ignore_labels=-1,
                                      batch_balancing=True)
    iterator.serialize(DummyDeserializer(target))
    assert iterator.current_position == 0
    assert iterator.epoch == 1
    assert iterator.is_new_epoch

    # The restored iterator must have exactly the same set of per-label
    # iterators. Otherwise the loop below could silently skip labels.
    restored_keys = {'index_iterator_{}'.format(label)
                     for label in iterator.labels_iterator_dict}
    assert restored_keys == set(current_pos_orig)
    for label, index_iterator in iterator.labels_iterator_dict.items():
        ii_label = 'index_iterator_{}'.format(label)
        assert numpy.array_equal(index_iterator.current_index_list,
                                 current_index_list_orig[ii_label])
        assert index_iterator.current_pos == current_pos_orig[ii_label]