def test_concat():
    """concat() and Dataset.concat() chain datasets end-to-end, in argument order."""
    base = (
        from_dummy_data(with_label=True)
        .named("data", "label")
        .filter(label=allow_unique(2))
        .reorder(0)
        .transform([lambda x: x + 1])
    )
    negated = base.transform([lambda x: -x])
    scaled = base.transform([lambda x: 100 * x])

    # Two inputs: the free function and the method form must agree with each
    # other and with plain list concatenation.
    pair = concat(base, negated)
    pair_method = base.concat(negated)
    assert len(pair) == len(pair_method) == len(base) + len(negated)
    assert list(pair) == list(pair_method) == list(base) + list(negated)

    # Three inputs.
    triple = concat(base, negated, scaled)
    triple_method = base.concat(negated, scaled)
    assert (
        len(triple)
        == len(triple_method)
        == len(base) + len(negated) + len(scaled)
    )
    assert (
        list(triple)
        == list(triple_method)
        == list(base) + list(negated) + list(scaled)
    )

    # Nothing to concatenate: a warning is emitted and a ValueError raised.
    with pytest.raises(ValueError), pytest.warns(UserWarning):
        concat()

    # A single dataset only warns, in both call forms.
    with pytest.warns(UserWarning):
        concat(base)

    with pytest.warns(UserWarning):
        base.concat()

    # Mismatched item shapes are accepted but warned about.
    with pytest.warns(UserWarning):
        base.concat(from_dummy_numpy_data())
def test_take():
    """take(n) yields the first n items; requesting far too many raises ValueError."""
    scaled = from_dummy_data().transform(lambda x: 10 * x)

    first_five = scaled.take(5)
    assert list(scaled)[:5] == list(first_five)

    # Asking for more items than the dataset can provide is rejected.
    with pytest.raises(ValueError):
        scaled.take(10_000_000)
def test_zip():
    """Check zipped() / Dataset.zip(), which combine datasets item by item.

    The zipped dataset is as long as its shortest input. Its shape is the
    concatenation of the input shapes. Item names survive only when they do
    not clash. Zipping nothing raises (after a warning); zipping a single
    dataset only warns.
    """
    ds_pos = from_dummy_data(num_total=10).named("pos")
    ds_neg = from_dummy_data(num_total=11).transform([lambda x: -x]).named("neg")
    ds_np = from_dummy_numpy_data()
    ds_labelled = from_dummy_data(num_total=10, with_label=True)

    # syntax 1: free function
    zds = zipped(ds_pos, ds_neg)
    assert len(zds) == min(len(ds_pos), len(ds_neg))
    assert zds.shape == (*ds_pos.shape, *ds_neg.shape)
    # item names survive because there were no clashes
    assert zds.names == ["pos", "neg"]

    # syntax 2: the method form must match the free function
    zds_alt = ds_pos.zip(ds_neg)
    assert len(zds_alt) == len(zds)
    assert zds_alt.shape == zds.shape

    # zipping a dataset with itself
    zds_self = zipped(ds_pos, ds_pos)
    assert len(zds_self) == len(ds_pos)
    assert zds_self.shape == (*ds_pos.shape, *ds_pos.shape)
    # item names are discarded because there are clashes
    assert zds_self.names == []

    # mix labelled and unlabelled data
    zds_mix_labelling = ds_neg.zip(ds_labelled)
    assert len(zds_mix_labelling) == min(len(ds_neg), len(ds_labelled))
    assert zds_mix_labelling.shape == (*ds_neg.shape, *ds_labelled.shape)

    # zip three: the length check must measure the three-way zip itself,
    # not the two-way `zds` from above
    zds_all = zipped(ds_pos, ds_neg, ds_np)
    assert len(zds_all) == min(len(ds_pos), len(ds_neg), len(ds_np))
    assert zds_all.shape == (*ds_pos.shape, *ds_neg.shape, *ds_np.shape)

    # error scenarios: zipping nothing warns and then raises
    with pytest.raises(ValueError):
        with pytest.warns(UserWarning):
            zipped()

    # a single dataset only warns, in both call forms
    with pytest.warns(UserWarning):
        zipped(ds_pos)

    with pytest.warns(UserWarning):
        ds_pos.zip()
def test_unique():
    """unique(key) lists the distinct values of one named item.

    Calling it without a key emits a UserWarning, and the result compares
    equal to iterating the whole dataset.
    """
    labelled = from_dummy_data(with_label=True).named('data', 'label')

    assert labelled.unique('label') == ['a', 'b']

    with pytest.warns(UserWarning):
        distinct_items = labelled.unique()
        assert distinct_items == list(labelled)
def test_shuffle():
    """A seeded shuffle keeps the same items but changes their order."""
    seed = 42
    source = from_dummy_data()
    before = list(source)
    after = list(source.shuffle(seed))

    # The same items are present...
    assert set(before) == set(after)

    # ...but not in the same sequence.
    assert before != after
def test_categorical_template():
    """categorical_template() yields a mapping that encodes datasets consistently."""
    ordered = from_dummy_data(with_label=True).named("data", "label")
    shuffled = ordered.shuffle(42)

    # Without a shared mapping, each dataset builds its category index as its
    # items are loaded, so the load order matters. Two differently ordered
    # datasets can therefore end up with different encodings.
    encoded_ordered = set(ordered.transform(label=categorical()))
    encoded_shuffled = set(shuffled.transform(label=categorical()))
    assert encoded_ordered != encoded_shuffled

    # A template derived from one dataset makes the encodings agree.
    mapping_fn = categorical_template(ordered, "label")
    assert set(ordered.transform(label=categorical(mapping_fn))) == set(
        shuffled.transform(label=categorical(mapping_fn))
    )

    # The Dataset.categorical() method does this implicitly.
    assert set(ordered.categorical("label")) == set(shuffled.categorical("label"))
def test_one_hot():
    """one_hot() encodes a label item as a vector; all call forms must agree."""
    ds = (
        from_dummy_data(with_label=True)
        .reorder(0, 1, 1)
        .named('data', 'label', 'label_duplicate')
    )
    assert ds.unique('label') == ['a', 'b']

    # Equivalent ways of requesting the same encoding; the first one is the
    # reference the others are compared against.
    variants = [
        ds.one_hot(1, encoding_size=2),
        ds.one_hot("label", encoding_size=2),
        ds.transform(label=one_hot(encoding_size=2)),
        ds.transform([None, one_hot(encoding_size=2)]),
        ds.one_hot("label"),  # encoding size computed automatically
    ]
    ds_oh = variants[0]

    expected = [np.array([True, False]), np.array([False, True])]

    uniques_per_variant = [variant.unique('label') for variant in variants]
    for reference, *others, want in zip(*uniques_per_variant, expected):
        for other in others:
            assert np.array_equal(reference, other)
        assert np.array_equal(reference, want)  # type:ignore

    # 'label_duplicate' keeps the raw label, so each encoded row can be
    # checked against the value it was derived from.
    for _data, encoded, raw_label in ds_oh:
        position = 0 if raw_label == 'a' else 1
        assert np.array_equal(encoded, expected[position])  # type:ignore

    # User-defined mapping, wider encoding and integer dtype.
    ds_oh_userdef = ds.one_hot(
        'label',
        encoding_size=3,
        mapping_fn=lambda x: 1 if x == 'a' else 0,
        dtype='int',
    )
    userdef_expected = [np.array([0, 1, 0]), np.array([1, 0, 0])]
    for encoded, want in zip(ds_oh_userdef.unique('label'), userdef_expected):
        assert np.array_equal(encoded, want)  # type:ignore

    # error scenarios
    with pytest.raises(TypeError):
        ds.one_hot()  # at least a key is required

    with pytest.raises(IndexError):
        ds.one_hot(42, encoding_size=2)  # index out of range

    with pytest.raises(IndexError):
        # encoding size too small -- only detected when items are produced
        list(ds.one_hot('label', encoding_size=1))

    with pytest.raises(KeyError):
        ds.one_hot("wrong", encoding_size=2)  # unknown item name
def test_counts():
    """counts() tallies the values of an item addressed by name or by index."""
    num_total = 11
    labelled = from_dummy_data(num_total=num_total, with_label=True).named(
        'data', 'label'
    )

    by_name = labelled.counts('label')
    by_index = labelled.counts(1)

    expected_counts = [('a', 5), ('b', num_total - 5)]
    assert by_name == by_index == expected_counts

    # With no key, whole items are counted and a UserWarning is emitted.
    with pytest.warns(UserWarning):
        counts_all = labelled.counts()
        assert set(counts_all) == {(item, 1) for item in labelled}
def test_repeat():
    """repeat() supports 'itemwise' (the default) and 'whole' repetition."""
    ds = from_dummy_data()

    # Item-wise repetition keeps the same set of items, and the default mode
    # produces the same sequence as an explicit mode='itemwise'.
    default_mode = ds.repeat(3)
    itemwise = ds.repeat(3, mode='itemwise')
    assert set(ds) == set(itemwise)
    assert list(default_mode) == list(itemwise)

    # Whole-dataset repetition: the result is two back-to-back copies.
    whole = ds.repeat(2, mode='whole')
    assert set(ds) == set(whole)
    assert list(ds) == list(whole)[:len(ds)] == list(whole)[len(ds):]
def test_split_filter():
    """split_filter() partitions a dataset into (matching, non-matching) parts.

    Predicates can be given per item position (list), per named item
    (keyword), on the whole item (single callable), or mixed. When several
    are given, all must hold. Both returned parts keep the original item
    order.
    """
    num_total = 10
    ds = from_dummy_data(num_total=num_total, with_label=True).named('data', 'label')

    # Expected items: values 0-4 carry label 'a', values 5..num_total-1 'b'.
    a = [(x, 'a') for x in range(5)]
    b = [(x, 'b') for x in range(5, num_total)]
    even_a = [x for x in a if x[0] % 2 == 0]
    odd_a = [x for x in a if x[0] % 2 == 1]
    even_b = [x for x in b if x[0] % 2 == 0]
    odd_b = [x for x in b if x[0] % 2 == 1]

    # item-wise: one predicate per item position
    ds_even, ds_odd = ds.split_filter([lambda x: x % 2 == 0])
    assert list(ds_even) == even_a + even_b
    assert list(ds_odd) == odd_a + odd_b

    # several item-wise predicates must all hold
    ds_even_a, ds_not_even_a = ds.split_filter([lambda x: x % 2 == 0, lambda x: x == 'a'])
    assert list(ds_even_a) == even_a
    assert list(ds_not_even_a) == odd_a + b

    # by key
    ds_b, ds_a = ds.split_filter(label=lambda x: x == 'b')
    assert list(ds_b) == b
    assert list(ds_a) == a

    # bulk: a single predicate receives the whole item; the second part is the
    # complement of "odd and labelled 'b'"
    ds_odd_b, ds_not_odd_b = ds.split_filter(lambda x: x[0] % 2 == 1 and x[1] == 'b')
    assert list(ds_odd_b) == odd_b
    assert list(ds_not_odd_b) == a + even_b

    # mix item-wise and keyed predicates
    ds_even_b, ds_not_even_b = ds.split_filter([lambda x: x % 2 == 0], label=lambda x: x == 'b')
    assert list(ds_even_b) == even_b
    assert list(ds_not_even_b) == [x for x in ds if x not in even_b]

    # sample classwise: keep the first 2 items of each label
    ds_classwise_2, ds_classwise_rest = ds.split_filter(label=allow_unique(2))
    assert list(ds_classwise_2) == a[:2] + b[:2]
    assert list(ds_classwise_rest) == a[2:] + b[2:]

    # error scenarios
    with pytest.raises(ValueError):
        ds.split_filter()  # no predicates given

    with pytest.raises(ValueError):
        ds.split_filter([None, None, None])  # more predicates than items

    with pytest.raises(KeyError):
        ds.split_filter(badkey=lambda x: True)  # key doesn't exist
def test_sample():
    """sample(n, seed) draws n distinct items, and the seed determines which."""
    seed = 42
    ds = from_dummy_data()
    drawn = list(ds.sample(5, seed))

    # No item is drawn twice.
    assert len(drawn) == len(set(drawn))

    # With this seed the drawn items are known exactly.
    expected = {(i,) for i in (10, 1, 0, 4, 9)}
    assert expected == set(drawn)

    # A different seed yields a different sample.
    drawn_other = list(ds.sample(5, seed + 1))
    assert set(drawn_other) != set(drawn)
def test_categorical():
    """categorical() maps a label item to integer category indices.

    All call forms must agree. Only the targeted item is encoded; the
    duplicated 'label_duplicate' item keeps the raw label, so every row can
    be checked against the value it was derived from.
    """
    ds = from_dummy_data(with_label=True).reorder(0, 1, 1).named(
        'data', 'label', 'label_duplicate'
    )

    assert ds.unique('label') == ['a', 'b']

    # alternative syntaxes (previously `ds_label` was built twice, and an
    # unused `ds_label_alt` was also created)
    ds_label = ds.categorical(1)
    ds_label_alt1 = ds.categorical("label")
    ds_label_alt2 = ds.transform(label=categorical())
    ds_label_alt3 = ds.transform([None, categorical()])

    expected = [0, 1]

    for enc, enc1, enc2, enc3, want in zip(
        ds_label.unique('label'),
        ds_label_alt1.unique('label'),
        ds_label_alt2.unique('label'),
        ds_label_alt3.unique('label'),
        expected,
    ):
        assert np.array_equal(enc, enc1)
        assert np.array_equal(enc, enc2)
        assert np.array_equal(enc, enc3)
        assert np.array_equal(enc, want)  # type:ignore

    # with the default mapping, 'a' -> 0 and 'b' -> 1 for this dataset
    assert list(ds_label) == [
        (data, 0 if label == 'a' else 1, raw) for data, label, raw in ds
    ]

    # user-defined mapping function
    ds_label_userdef = ds.categorical('label', lambda x: 1 if x == 'a' else 0)

    assert ds_label_userdef.unique('label') == [1, 0]
    assert list(ds_label_userdef) == [
        (data, 1 if label == 'a' else 0, raw) for data, label, raw in ds
    ]

    # error scenarios
    with pytest.raises(TypeError):
        ds.categorical()  # we need to know what to label

    with pytest.raises(IndexError):
        ds.categorical(42)  # index out of range

    with pytest.raises(KeyError):
        ds.categorical("wrong")  # unknown item name
def test_split():
    """split(fractions, seed) partitions a dataset; -1 acts as a wildcard fraction."""
    seed = 42
    ds = from_dummy_data()
    first, second, third = ds.split([0.6, 0.3, 0.1], seed=seed)

    # The parts differ pairwise.
    for left, right in [(first, second), (first, third), (second, third)]:
        assert set(left) != set(right)

    # Together the parts cover every original item.
    assert set(ds) == set(first).union(set(second), set(third))

    # -1 stands in for the remaining fraction (0.3 here) and must reproduce
    # the explicit split exactly.
    first_w, second_w, third_w = ds.split([0.6, -1, 0.1], seed=seed)
    for explicit, wildcard in [(first, first_w), (second, second_w), (third, third_w)]:
        assert set(explicit) == set(wildcard)
def test_cartesian_product():
    """cartesian_product() combines every item of each input with every other.

    The free function and the Dataset method must agree. As the expected
    orderings below show, the first dataset varies fastest. Calling it with
    nothing raises (after a warning); calling it with a single dataset only
    warns.
    """
    ds_pos = from_dummy_data().take(2).transform([lambda x: x + 1])
    ds_10x = ds_pos.transform([lambda x: 10 * x])
    ds_100x = ds_pos.transform([lambda x: 100 * x])

    # two
    ds_prod2 = cartesian_product(ds_pos, ds_10x)
    ds_prod2_alt = ds_pos.cartesian_product(ds_10x)

    expected2 = [(1, 10), (2, 10), (1, 20), (2, 20)]
    assert list(ds_prod2) == list(ds_prod2_alt) == expected2
    assert len(ds_prod2) == len(ds_prod2_alt) == len(set(expected2))
    # Smoke check only: computing the shape of a product must not raise.
    # (Previously a bare `ds_prod2.shape` expression with no stated intent.)
    _ = ds_prod2.shape

    # three
    expected3 = [
        (1, 10, 100),
        (2, 10, 100),
        (1, 20, 100),
        (2, 20, 100),
        (1, 10, 200),
        (2, 10, 200),
        (1, 20, 200),
        (2, 20, 200),
    ]
    ds_prod3 = cartesian_product(ds_pos, ds_10x, ds_100x)
    ds_prod3_alt = ds_pos.cartesian_product(ds_10x, ds_100x)
    assert list(ds_prod3) == list(ds_prod3_alt) == expected3
    assert len(ds_prod3) == len(ds_prod3_alt) == len(set(expected3))

    # error scenarios: no inputs warns and then raises
    with pytest.raises(ValueError):
        with pytest.warns(UserWarning):
            cartesian_product()

    # a single input only warns, in both call forms
    with pytest.warns(UserWarning):
        cartesian_product(ds_pos)

    with pytest.warns(UserWarning):
        ds_pos.cartesian_product()