예제 #1
0
def test_tf_dataset_load_prepare_fashion_mnist_incompatible_shape():
    """Check that preparing Fashion MNIST for a lower-rank input shape fails.

    Loads the Fashion MNIST test data (quick load) and asks
    ``keras_prepare_input`` to fit it to a ``[None, 32]`` input shape, which
    has fewer dimensions than the data. The call is expected to raise
    ``ModelError``.
    """
    test_data = DataSet()
    # Third argument is presumably the number of samples to load -- confirm
    # against load_test_data.
    load_test_data('fashion_mnist', test_data, 128)

    # Built outside the ``with`` block so that only the call under test can
    # satisfy (or violate) the ModelError expectation.
    target_shape = np.array([None, 32])

    # The builtin ``float`` is what the removed ``np.float`` alias referred
    # to. With ``np.float``, NumPy >= 1.24 raises AttributeError before
    # keras_prepare_input is even called, so ModelError would never be seen.
    with pytest.raises(ModelError):
        test_data.x_format = keras_prepare_input(float, target_shape,
                                                 test_data.x)
예제 #2
0
def test_tf_dataset_load_prepare_fashion_mnist():
    """Check that Fashion MNIST loads and can be prepared for a 32x32 input.

    Loads the Fashion MNIST test data (quick load), verifies the mode, shapes
    and dtypes of the raw data, then runs ``keras_prepare_input`` with a
    ``[None, 32, 32, 1]`` target shape and verifies the shape and dtype of
    the prepared array.
    """
    test_data = DataSet()
    load_test_data('fashion_mnist', test_data, 128)

    # Raw data as loaded: 128 single-channel 28x28 images with integer labels.
    assert test_data.mode is DataSet.MODE_FILESET
    assert np.array_equal(test_data.x.shape, [128, 28, 28, 1])
    # dtypes are compared with ``==``; identity of dtype objects is not a
    # documented guarantee.
    assert test_data.x.dtype == np.dtype('|u1'), "uint8"
    assert np.array_equal(test_data.y.shape, [128])
    assert test_data.y.dtype == np.dtype('<i8'), "int64"

    # The builtin ``float`` replaces the ``np.float`` alias, which was
    # removed in NumPy 1.24 and always referred to the builtin anyway.
    target_shape = np.array([None, 32, 32, 1])
    test_data.x_format = keras_prepare_input(float, target_shape,
                                             test_data.x)

    # Prepared data: same sample count, resized to the requested 32x32x1.
    assert np.array_equal(test_data.x_format.shape, [128, 32, 32, 1])
    # '<f8' is an 8-byte float, i.e. float64.
    assert test_data.x_format.dtype == np.dtype('<f8'), "float64"