# Example #1
def test_mnist_build_a_dichotomy():
    """Smoke-test building two independent binary dichotomies.

    Builds a "parity" split and a "smaller_than_4" split of labels 0-7 on
    the dataset returned by ``make_dataset``; the test passes as long as
    none of the calls raises.
    """
    dataset = make_dataset()

    # Parity: even labels first, odd labels second -> [[0, 2, 4, 6], [1, 3, 5, 7]].
    parity_groups = [[2 * k + offset for k in range(4)] for offset in (0, 1)]
    dataset.build_dichLabels(parity_groups, 'parity')

    # Smallness: labels 0-3 versus labels 4-7, passed as range objects.
    smallness_groups = [range(start, start + 4) for start in (0, 4)]
    dataset.build_dichLabels(smallness_groups, 'smaller_than_4')

    # No explicit check on the result: reaching this line is the pass condition.
    assert True


# Example #2
def test_mnist_hstack_dichotomies():
    """Smoke-test ``hstack_dichs`` on the parity and smallness dichotomies.

    Builds the "smaller_than_4" and "parity" dichotomies, then stacks them
    with ``hstack_dichs``; the test passes if no call raises.
    """
    dataset = make_dataset()

    # Currently dichotomies will only be binary.
    even_labels = [2 * k for k in range(4)]
    odd_labels = [2 * k + 1 for k in range(4)]
    parity_groups = [even_labels, odd_labels]
    smallness_groups = [range(0, 4), range(4, 8)]

    dataset.build_dichLabels(smallness_groups, 'smaller_than_4')
    dataset.build_dichLabels(parity_groups, 'parity')
    dataset.hstack_dichs('parity', 'smaller_than_4')

    # Smoke test only: reaching this line means every call succeeded.
    assert True


# Example #3
def test_mnist_product_dichotomies():
    """Smoke-test building the product of the parity and smallness dichotomies.

    Builds the binary "smaller_than_4" and "parity" dichotomies, stacks
    them, and then builds the four-class product dichotomy whose classes are
    the pairwise intersections of the parity and smallness classes.  The
    product labels were previously computed but never used, so the test did
    not actually exercise any product dichotomy.  The test passes if no call
    raises.
    """
    mnist = make_dataset()
    # Currently dichotomies will only be binary.
    mnist_parity = [list(range(0, 8, 2)), list(range(1, 8, 2))]
    mnist_smallness = [range(0, 4), range(4, 8)]
    # Intersections in the order (even & <4), (odd & <4), (even & >=4),
    # (odd & >=4), i.e. [{0, 2}, {1, 3}, {4, 6}, {5, 7}].
    mnist_prod = [
        set(s1).intersection(s2)
        for s2 in mnist_smallness
        for s1 in mnist_parity
    ]
    mnist.build_dichLabels(mnist_smallness, 'smaller_than_4')
    mnist.build_dichLabels(mnist_parity, 'parity')
    mnist.hstack_dichs('parity', 'smaller_than_4')
    # Same product-building sequence that build_mnist_ds uses.
    mnist.compstack_dichs('parity', 'smaller_than_4')
    mnist.build_dichLabels(mnist_prod, 'parity_prod_smaller_than_4')


# Example #4
def build_mnist_ds(filt_lbls=range(8), spl=0.04):
    """Load MNIST through Keras and wrap it in an ``ImageDataset`` with dichotomies.

    Args:
        filt_lbls: Forwarded to ``ImageDataset`` as ``filt_labels``
            (presumably the subset of digit labels to keep — confirm
            against ``data_tools.ImageDataset``).
        spl: Forwarded unchanged to ``ImageDataset`` as ``spl``.

    Returns:
        The ``ImageDataset`` with these dichotomies built:
        ``'smaller_than_4'``, ``'parity'`` and
        ``'parity_prod_smaller_than_4'``. ``hstack_dichs`` and
        ``compstack_dichs`` are also applied to parity/smaller_than_4.
    """
    # Alias the Keras module so the dataset variable below does not rebind it.
    from keras.datasets import mnist as keras_mnist
    from data_tools import ImageDataset

    (x_train, y_train), (x_test, y_test) = keras_mnist.load_data()
    dataset = ImageDataset(x_train,
                           y_train,
                           x_test,
                           y_test,
                           filt_labels=filt_lbls,
                           spl=spl)

    # NOTE(review): the dichotomies below are hard-coded for labels 0-7 and
    # do not follow ``filt_lbls``; confirm behavior when other labels are kept.
    mnist_parity = [list(range(0, 8, 2)), list(range(1, 8, 2))]
    mnist_smallness = [range(0, 4), range(4, 8)]
    # Pairwise intersections: [{0, 2}, {1, 3}, {4, 6}, {5, 7}].
    mnist_prod = [
        set(s1).intersection(s2)
        for s2 in mnist_smallness
        for s1 in mnist_parity
    ]
    dataset.build_dichLabels(mnist_smallness, 'smaller_than_4')
    dataset.build_dichLabels(mnist_parity, 'parity')
    dataset.hstack_dichs('parity', 'smaller_than_4')
    dataset.compstack_dichs('parity', 'smaller_than_4')
    dataset.build_dichLabels(mnist_prod, 'parity_prod_smaller_than_4')

    return dataset