def test_to_tensorflow():
    """Train a small Keras classifier on a dataset exported via ``to_tensorflow``.

    The test passes when fitting runs without error and more than half of the
    predicted labels match the labels read back from the dataset.
    """
    # data: name the items, apply ``one_hot`` to the label, shuffle with seed 42
    dataset = (
        from_dummy_numpy_data()
        .named("data", "label")
        .one_hot("label")
        .shuffle(42)
    )
    batched = dataset.to_tensorflow().batch(2)

    # model: tensorflow is imported inside the test rather than at module level
    import tensorflow as tf  # type:ignore

    tf.random.set_seed(42)

    layers = [
        tf.keras.layers.Input(dataset.shape[0]),
        tf.keras.layers.Dense(10, activation='relu'),
        tf.keras.layers.Dense(2, activation='softmax'),
    ]
    net = tf.keras.Sequential(layers)
    net.compile(
        optimizer='adam',
        loss=tf.keras.losses.CategoricalCrossentropy(),
        metrics=['accuracy'],
    )

    # the model should be able to fit the data
    net.fit(batched, epochs=10)
    probabilities = net.predict(batched)
    pred_labels = np.argmax(probabilities, axis=1)

    expected_labels = np.array(
        [row[0] for row in dataset.reorder('label').categorical(0)]
    )
    hits = sum(pred_labels == expected_labels)  # type:ignore
    assert hits > len(dataset) // 2
def test_concat():
    """Exercise ``concat`` as a free function and as a dataset method.

    Covers concatenating two and three datasets (length and item order) and
    the warning/error behaviour for zero, one, or shape-mismatched inputs.
    """
    labelled = from_dummy_data(with_label=True).named("data", "label")
    ds_pos = (
        labelled
        .filter(label=allow_unique(2))
        .reorder(0)
        .transform([lambda x: x + 1])
    )
    ds_neg = ds_pos.transform([lambda x: -x])
    ds_100x = ds_pos.transform([lambda x: 100 * x])

    # two datasets: function form and method form must agree
    pair_fn = concat(ds_pos, ds_neg)
    pair_method = ds_pos.concat(ds_neg)
    assert len(pair_fn) == len(pair_method) == len(ds_pos) + len(ds_neg)
    assert list(pair_fn) == list(pair_method) == list(ds_pos) + list(ds_neg)

    # three datasets
    triple_fn = concat(ds_pos, ds_neg, ds_100x)
    triple_method = ds_pos.concat(ds_neg, ds_100x)
    assert (
        len(triple_fn)
        == len(triple_method)
        == len(ds_pos) + len(ds_neg) + len(ds_100x)
    )
    assert (
        list(triple_fn)
        == list(triple_method)
        == list(ds_pos) + list(ds_neg) + list(ds_100x)
    )

    # error scenarios
    # nothing to concatenate: expected to warn and raise
    with pytest.raises(ValueError), pytest.warns(UserWarning):
        concat()

    # a single dataset only triggers a warning
    with pytest.warns(UserWarning):
        concat(ds_pos)

    with pytest.warns(UserWarning):
        ds_pos.concat()

    # different shapes result in a warning
    with pytest.warns(UserWarning):
        ds_pos.concat(from_dummy_numpy_data())
def test_transform():
    """Check the calling conventions of ``transform`` and its error handling.

    The same scaling is applied item-wise (list), key-wise (keyword) and via a
    single whole-item function; all three must yield identical results.
    """
    ds = from_dummy_numpy_data().named("data", "label")

    # simple: three equivalent ways of scaling the data item
    ds_itemwise = ds.transform([lambda x: x / 255.0])
    ds_keywise = ds.transform(data=lambda x: x / 255.0)
    ds_build = ds.transform(lambda x: (x[0] / 255.0, x[1]))

    originals = list(ds)
    by_item = list(ds_itemwise)
    by_key = list(ds_keywise)
    by_build = list(ds_build)

    for (raw, lbl), (d_item, l_item), (d_key, l_key), (d_build, l_build) in zip(
        originals, by_item, by_key, by_build
    ):
        assert np.array_equal(raw / 255.0, d_item)
        assert np.array_equal(d_item, d_key)
        assert np.array_equal(d_item, d_build)
        assert lbl == l_item == l_key == l_build

    # complex: chained transforms on the data, one-hot on the label
    ds_complex = ds.transform(
        data=[reshape(DUMMY_NUMPY_DATA_SHAPE_2D), image_resize((10, 10))],
        label=one_hot(encoding_size=2),
    )
    assert ds_complex.shape == ((10, 10), (2,))

    # error scenarios
    # no arguments at all only warns
    with pytest.warns(UserWarning):
        ds.transform()

    # more transforms than items in the dataset
    with pytest.raises(ValueError):
        ds.transform([reshape(DUMMY_NUMPY_DATA_SHAPE_2D), None, None])
def test_item_naming():
    """Check that item names can be set, overridden and used as transform keys."""
    ds = from_dummy_numpy_data()
    original_items = list(ds)
    assert ds.names == []

    # keyword-style transforms are rejected while no names are set
    with pytest.raises(Exception):
        ds.transform(moddata=reshape(DUMMY_NUMPY_DATA_SHAPE_2D))

    # names passed one by one as positional arguments
    first_names = ['mydata', 'mylabel']
    ds.named(*first_names)
    assert ds.names == first_names

    # names passed as a single list, overriding the previous ones
    second_names = ['moddata', 'modlabel']
    ds.named(second_names)  # type: ignore
    assert ds.names == second_names

    # keyword-style transform now works with a valid name
    ds_trans = ds.transform(moddata=reshape(DUMMY_NUMPY_DATA_SHAPE_2D))
    transformed_items = list(ds_trans)
    for (before, _), (after, _) in zip(original_items, transformed_items):
        # same values, different layout
        assert set(before) == set(after.flatten())
        assert before.shape != after.shape

    # an unknown name is rejected
    with pytest.raises(Exception):
        ds.transform(badname=reshape(DUMMY_NUMPY_DATA_SHAPE_2D))
def test_shape():
    """Check the ``shape`` property for loaders, reshaped and resized datasets.

    Covers an empty ``Loader``, a loader with one appended id, numpy data
    reshaped to 2D and 3D, and both of those after ``image_resize``.
    """
    def get_data(i):
        return i, i

    # no items yet: shape falls back to the default
    ds = loaders.Loader(get_data)
    assert ds.shape == _DEFAULT_SHAPE

    # one id appended: ``get_data`` returns a 2-tuple, so shape has two entries
    ds.append(1)
    assert ds.shape == (_DEFAULT_SHAPE, _DEFAULT_SHAPE)

    # numpy data
    ds_np = from_dummy_numpy_data().reshape(DUMMY_NUMPY_DATA_SHAPE_2D)
    assert ds_np.shape == (DUMMY_NUMPY_DATA_SHAPE_2D, _DEFAULT_SHAPE)

    # changed to new size
    IMG_SIZE = (6, 6)
    ds_img = ds_np.image_resize(IMG_SIZE)
    assert ds_img.shape == (IMG_SIZE, _DEFAULT_SHAPE)

    # image with three channels
    ds_np3 = ds_np.reshape(DUMMY_NUMPY_DATA_SHAPE_3D)
    assert ds_np3.shape == (DUMMY_NUMPY_DATA_SHAPE_3D, _DEFAULT_SHAPE)

    # resizing keeps the trailing channel dimension of 3
    ds_img3 = ds_np3.image_resize(IMG_SIZE)
    assert ds_img3.shape == ((*IMG_SIZE, 3), _DEFAULT_SHAPE)
def test_reorder():
    """Check ``reorder`` with indexes, keys, mixed arguments and duplicates.

    Also verifies the warning for an empty order, the errors for bad indexes
    or keys, and the warning when duplicated items clash with item names.
    """
    ds = from_dummy_numpy_data()
    ds.named("mydata", "mylabel")

    # --- error scenarios ---
    # no order given: warns and the dataset compares equal to the original
    with pytest.warns(UserWarning):
        unchanged = ds.reorder()
        assert ds == unchanged

    # indexes out of range
    with pytest.raises(ValueError):
        ds.reorder(3, 4)

    # a key that doesn't exist
    with pytest.raises(KeyError):
        ds.reorder("badkey", "mydata")

    # --- working scenarios ---
    # using indexes: label first, data second
    ds_re = ds.reorder(1, 0)
    originals = list(ds)
    swapped = list(ds_re)
    for (src_data, src_lbl), (out_lbl, out_data) in zip(originals, swapped):
        assert np.array_equal(src_data, out_data)
        assert src_lbl == out_lbl

    # same results using keys
    ds_re_key = ds.reorder("mylabel", "mydata")
    by_key = list(ds_re_key)
    swapped = list(ds_re)
    for (key_lbl, key_data), (ref_lbl, ref_data) in zip(by_key, swapped):
        assert np.array_equal(key_data, ref_data)
        assert key_lbl == ref_lbl

    # same result using a mix of index and key
    ds_re_mix = ds.reorder(1, "mydata")
    by_mix = list(ds_re_mix)
    swapped = list(ds_re)
    for (mix_lbl, mix_data), (ref_lbl, ref_data) in zip(by_mix, swapped):
        assert np.array_equal(mix_data, ref_data)
        assert mix_lbl == ref_lbl

    # the same element may even be placed multiple times
    ds_re_creative = ds.reorder(0, 1, 1, 0)
    originals = list(ds)
    repeated = list(ds_re_creative)
    for (src_data, src_lbl), (data_a, lbl_a, lbl_b, data_b) in zip(originals, repeated):
        assert np.array_equal(src_data, data_a)
        assert np.array_equal(src_data, data_b)
        assert src_lbl == lbl_a == lbl_b

    # shape updates accordingly
    expected_shape = (
        DUMMY_NUMPY_DATA_SHAPE_1D,
        _DEFAULT_SHAPE,
        _DEFAULT_SHAPE,
        DUMMY_NUMPY_DATA_SHAPE_1D,
    )
    assert ds_re_creative.shape == expected_shape

    # --- error scenarios ---
    # keys need to be unique, but wouldn't be when an item is repeated
    with pytest.warns(UserWarning):
        ds.named('one', 'two').reorder(0, 1, 1)
def test_numpy_image_numpy_conversion():
    """Check conversion numpy -> PIL image -> numpy and the related warnings.

    Warnings are captured with ``warnings.catch_warnings(record=True)``
    because ``pytest.warns(None)`` is deprecated in pytest 7 and rejected by
    pytest 8.
    """
    import warnings

    ds_1d = from_dummy_numpy_data()
    items_1d = list(ds_1d)

    # Warns because no elements were converted
    with warnings.catch_warnings(record=True) as record:
        # "always" so that repeated identical warnings are all recorded
        warnings.simplefilter("always")
        ds2 = ds_1d.image()  # skipped all because they couldn't be converted
        ds3 = ds_1d.image(False, False)
    assert len(record) == 2  # warns on both

    # The two previous statements didn't create any changes
    items2 = list(ds2)
    items3 = list(ds3)
    for (one, _), (two, _), (three, _) in zip(items_1d, items2, items3):
        assert np.array_equal(one, two)
        assert np.array_equal(two, three)

    # Force conversion of first arg - doesn't work due to shape incompatibility
    with pytest.raises(Exception):
        # Tries to convert first argument
        ds_1d.image(True)

    ds_2d = ds_1d.reshape(DUMMY_NUMPY_DATA_SHAPE_2D)
    items_2d = list(ds_2d)

    # Successful conversion should happen here, without any warning
    with warnings.catch_warnings(record=True) as record:
        warnings.simplefilter("always")
        ds_img = ds_2d.image()
    assert len(record) == 0

    items_img = list(ds_img)
    for (one, lbl1), (two, lbl2) in zip(items_2d, items_img):
        assert isinstance(one, np.ndarray)
        assert isinstance(two, Image.Image)
        assert lbl1 == lbl2

    # test the backward-conversion
    ds_np = ds_img.numpy()
    items_np = list(ds_np)
    for (one, lbl1), (two, lbl2) in zip(items_2d, items_np):
        assert type(one) is type(two)
        assert np.array_equal(one, two)
        assert lbl1 == lbl2

    # we'll get a warning if it doesn't convert any
    with pytest.warns(UserWarning):
        ds_img.numpy(False)
def test_zip():
    """Check ``zipped`` / ``Dataset.zip`` for length, shape, names and errors.

    Covers zipping two datasets (function and method syntax), a dataset with
    itself, labelled with unlabelled data, and three datasets at once.
    """
    ds_pos = from_dummy_data(num_total=10).named("pos")
    ds_neg = from_dummy_data(num_total=11).transform([lambda x: -x]).named("neg")
    ds_np = from_dummy_numpy_data()
    ds_labelled = from_dummy_data(num_total=10, with_label=True)

    # syntax 1
    zds = zipped(ds_pos, ds_neg)
    assert len(zds) == min(len(ds_pos), len(ds_neg))
    assert zds.shape == (*ds_pos.shape, *ds_neg.shape)
    # item names survive because there were no clashes
    assert zds.names == ["pos", "neg"]

    # syntax 2
    zds_alt = ds_pos.zip(ds_neg)
    assert len(zds_alt) == len(zds)
    assert zds_alt.shape == zds.shape

    # with self
    zds_self = zipped(ds_pos, ds_pos)
    assert len(zds_self) == len(ds_pos)
    assert zds_self.shape == (*ds_pos.shape, *ds_pos.shape)
    # item names are discarded because there are clashes
    assert zds_self.names == []

    # mix labelled and unlabelled data
    zds_mix_labelling = ds_neg.zip(ds_labelled)
    assert len(zds_mix_labelling) == min(len(ds_neg), len(ds_labelled))
    assert zds_mix_labelling.shape == (*ds_neg.shape, *ds_labelled.shape)

    # zip three
    zds_all = zipped(ds_pos, ds_neg, ds_np)
    # check the three-way result itself, not the earlier two-way ``zds``
    assert len(zds_all) == min(len(ds_pos), len(ds_neg), len(ds_np))
    assert zds_all.shape == (*ds_pos.shape, *ds_neg.shape, *ds_np.shape)

    # error scenarios
    # nothing to zip: expected to warn and raise
    with pytest.raises(ValueError):
        with pytest.warns(UserWarning):
            zipped()

    # a single dataset only triggers a warning
    with pytest.warns(UserWarning):
        zipped(ds_pos)

    with pytest.warns(UserWarning):
        ds_pos.zip()
def test_to_pytorch():
    """Check that the first batch from a PyTorch ``DataLoader`` matches the dataset.

    The dataset is exported with ``to_pytorch`` and loaded unshuffled with a
    batch size of 2, so the batch is compared against items 0 and 1.
    """
    # prep data
    ds = from_dummy_numpy_data().named("data", "label").one_hot("label")
    pt_ds = ds.to_pytorch()

    import torch
    from torch.utils.data import DataLoader

    loader = DataLoader(pt_ds, batch_size=2, shuffle=False)
    first_batch = next(iter(loader))
    batch_data = first_batch[0]
    batch_labels = first_batch[1]

    # data equals
    for idx in (0, 1):
        assert torch.all(torch.eq(batch_data[idx], torch.Tensor(ds[idx][0])))  # type:ignore

    # labels equal
    for idx in (0, 1):
        assert torch.all(torch.eq(batch_labels[idx], torch.Tensor(ds[idx][1])))  # type:ignore
def test_image_resize():
    """Check ``image_resize`` on int arrays, float arrays, file paths and images.

    Also verifies the errors raised for missing, surplus or malformed size
    arguments.
    """
    ds = from_dummy_numpy_data().reshape(DUMMY_NUMPY_DATA_SHAPE_2D)
    for item in ds:
        assert item[0].shape == DUMMY_NUMPY_DATA_SHAPE_2D

    NEW_SIZE = (5, 5)

    # works directly on numpy arrays (ints)
    ds_resized = ds.image_resize(NEW_SIZE)
    for item in ds_resized:
        img = item[0]
        assert img.size == NEW_SIZE
        assert img.mode == 'L'  # grayscale int

    # also if they are floats
    ds_float = ds.transform([custom(np.float32)])
    ds_resized_float = ds_float.image_resize(NEW_SIZE)
    for item in ds_resized_float:
        img = item[0]
        assert img.size == NEW_SIZE
        assert img.mode == 'F'  # grayscale float

    # works directly on strings
    folder_path = get_test_dataset_path(DATASET_PATHS.FOLDER_DATA)
    ds_str = loaders.from_folder_data(folder_path)
    ds_resized_from_str = ds_str.image_resize(NEW_SIZE)
    for item in ds_resized_from_str:
        assert item[0].size == NEW_SIZE

    # works on other images (resizing the already resized images again)
    ds_resized_again = ds_resized.image_resize(DUMMY_NUMPY_DATA_SHAPE_2D)
    for item in ds_resized_again:
        assert item[0].size == DUMMY_NUMPY_DATA_SHAPE_2D

    # error scenarios
    # no args
    with pytest.raises(ValueError):
        ds.image_resize()

    # too many args
    with pytest.raises(ValueError):
        ds.image_resize(NEW_SIZE, NEW_SIZE, NEW_SIZE)

    # invalid size
    with pytest.raises(AssertionError):
        ds.image_resize((4, 4, 4))
def test_reshape():
    """Check ``reshape`` in its positional, keyword and transform forms.

    Covers adding a dimension, the ``-1`` wildcard, reshaping back, passing
    ``None`` to leave items untouched, and the errors raised for string data,
    missing, malformed, surplus or incompatible shape arguments.
    """
    ds = from_dummy_numpy_data().named('data', 'label')
    items = list(ds)

    assert ds.shape == (DUMMY_NUMPY_DATA_SHAPE_1D, _DEFAULT_SHAPE)
    assert ds[0][0].shape == DUMMY_NUMPY_DATA_SHAPE_1D

    # reshape adding extra dim
    ds_r = ds.reshape(DUMMY_NUMPY_DATA_SHAPE_2D)
    ds_r_alt = ds.reshape(data=DUMMY_NUMPY_DATA_SHAPE_2D)
    items_r = list(ds_r)
    items_r_alt = list(ds_r_alt)

    assert ds_r.shape == (DUMMY_NUMPY_DATA_SHAPE_2D, _DEFAULT_SHAPE)
    assert ds_r[0][0].shape == DUMMY_NUMPY_DATA_SHAPE_2D

    for (old_data, l), (new_data, ln), (new_data_alt, lna) in zip(items, items_r, items_r_alt):
        # same values, different layout; both syntaxes agree
        assert set(old_data) == set(new_data.flatten()) == set(new_data_alt.flatten())
        assert old_data.shape != new_data.shape == new_data_alt.shape
        assert l == ln == lna

    # use wildcard
    ds_wild = ds.reshape((-1, DUMMY_NUMPY_DATA_SHAPE_2D[1]))
    items_wild = list(ds_wild)
    for (old_data, _), (new_data, _) in zip(items_r, items_wild):
        assert np.array_equal(old_data, new_data)

    # reshape back, alternative syntax
    ds_back = ds_r.reshape(DUMMY_NUMPY_DATA_SHAPE_1D, None)
    items_back = list(ds_back)
    for (old_data, _), (new_data, _) in zip(items, items_back):
        assert np.array_equal(old_data, new_data)

    # yet another syntax
    ds_trans = ds.transform([reshape(DUMMY_NUMPY_DATA_SHAPE_2D)])
    items_trans = list(ds_trans)
    for (old_data, _), (new_data, _) in zip(items_r, items_trans):
        assert np.array_equal(old_data, new_data)

    # doing nothing also works
    ds_dummy = ds.reshape(None, None)
    items_dummy = list(ds_dummy)
    for (old_data, _), (new_data, _) in zip(items, items_dummy):
        assert np.array_equal(old_data, new_data)

    # reshape on string data (only the failure path is covered here)
    ds_str = loaders.from_folder_data(get_test_dataset_path(DATASET_PATHS.FOLDER_DATA))

    with pytest.raises(ValueError):
        # string has no shape
        ds_str.reshape((1, 2))

    with pytest.raises(ValueError):
        # No input
        ds.reshape()

    with pytest.raises(TypeError):
        # bad input
        ds.reshape('whazzagh')

    with pytest.raises(ValueError):
        # Too many inputs
        ds.reshape(None, None, None)

    with pytest.raises(ValueError):
        # Dimensions don't match
        ds.reshape((13, 13))