def test_concat():
    """``concat`` / ``Dataset.concat`` append datasets end to end."""
    ds_pos = (
        from_dummy_data(with_label=True)
        .named("data", "label")
        .filter(label=allow_unique(2))
        .reorder(0)
        .transform([lambda x: x + 1])
    )
    ds_neg = ds_pos.transform([lambda x: -x])
    ds_100x = ds_pos.transform([lambda x: 100 * x])

    # two datasets: free function and method must both match plain list concatenation
    cat2 = concat(ds_pos, ds_neg)
    cat2_method = ds_pos.concat(ds_neg)
    assert len(cat2) == len(cat2_method) == len(ds_pos) + len(ds_neg)
    assert list(cat2) == list(cat2_method) == list(ds_pos) + list(ds_neg)

    # three datasets
    cat3 = concat(ds_pos, ds_neg, ds_100x)
    cat3_method = ds_pos.concat(ds_neg, ds_100x)
    assert len(cat3) == len(cat3_method) == len(ds_pos) + len(ds_neg) + len(ds_100x)
    assert list(cat3) == list(cat3_method) == list(ds_pos) + list(ds_neg) + list(ds_100x)

    # error scenarios
    with pytest.raises(ValueError), pytest.warns(UserWarning):
        concat()
    with pytest.warns(UserWarning):
        concat(ds_pos)
    with pytest.warns(UserWarning):
        ds_pos.concat()
    with pytest.warns(UserWarning):
        ds_pos.concat(from_dummy_numpy_data())  # different shapes result in warning
def test_take():
    """``take(n)`` keeps the leading ``n`` items and rejects oversized requests."""
    ds = from_dummy_data().transform(lambda x: 10 * x)
    first_five = ds.take(5)
    assert list(ds)[:5] == list(first_five)

    # asking for far more items than the dataset holds is an error
    with pytest.raises(ValueError):
        ds.take(10_000_000)
def test_zip():
    """Check ``zipped`` / ``Dataset.zip``: length, shape and item-name handling."""
    ds_pos = from_dummy_data(num_total=10).named("pos")
    ds_neg = from_dummy_data(num_total=11).transform([lambda x: -x]).named("neg")
    ds_np = from_dummy_numpy_data()
    ds_labelled = from_dummy_data(num_total=10, with_label=True)

    # syntax 1: free function
    zds = zipped(ds_pos, ds_neg)
    assert len(zds) == min(len(ds_pos), len(ds_neg))
    assert zds.shape == (*ds_pos.shape, *ds_neg.shape)
    # item names survive because there were no clashes
    assert zds.names == ["pos", "neg"]

    # syntax 2: method form must match the free function
    zds_alt = ds_pos.zip(ds_neg)
    assert len(zds_alt) == len(zds)
    assert zds_alt.shape == zds.shape

    # zipping a dataset with itself
    zds_self = zipped(ds_pos, ds_pos)
    assert len(zds_self) == len(ds_pos)
    assert zds_self.shape == (*ds_pos.shape, *ds_pos.shape)
    # item names are discarded because there are clashes
    assert zds_self.names == []

    # mix labelled and unlabelled data
    zds_mix_labelling = ds_neg.zip(ds_labelled)
    assert len(zds_mix_labelling) == min(len(ds_neg), len(ds_labelled))
    assert zds_mix_labelling.shape == (*ds_neg.shape, *ds_labelled.shape)

    # zip three: check the length of the three-way result itself,
    # not the two-way ``zds`` from above
    zds_all = zipped(ds_pos, ds_neg, ds_np)
    assert len(zds_all) == min(len(ds_pos), len(ds_neg), len(ds_np))
    assert zds_all.shape == (*ds_pos.shape, *ds_neg.shape, *ds_np.shape)

    # error scenarios
    with pytest.raises(ValueError):
        with pytest.warns(UserWarning):
            zipped()
    with pytest.warns(UserWarning):
        zipped(ds_pos)
    with pytest.warns(UserWarning):
        ds_pos.zip()
def test_unique():
    """``unique`` on a named item, and on whole items when no key is given."""
    ds = from_dummy_data(with_label=True).named('data', 'label')
    assert ds.unique('label') == ['a', 'b']

    # without a key a warning is emitted; here every whole item is unique,
    # so the result equals the full dataset
    with pytest.warns(UserWarning):
        whole_items = ds.unique()
    assert whole_items == list(ds)
def test_shuffle():
    """A seeded shuffle keeps the same items but changes their order."""
    seed = 42
    ds = from_dummy_data()
    before = list(ds)
    after = list(ds.shuffle(seed))

    # same set of items ...
    assert set(before) == set(after)
    # ... in a different sequence
    assert before != after
def test_categorical_template():
    """A shared ``categorical_template`` gives matching encodings across datasets."""
    ds1 = from_dummy_data(with_label=True).named("data", "label")
    ds2 = ds1.shuffle(42)

    # Independent categorical encoders build their index as items are loaded,
    # so the order matters and two orderings of the same data may disagree.
    independent1 = set(ds1.transform(label=categorical()))
    independent2 = set(ds2.transform(label=categorical()))
    assert independent1 != independent2

    # A template built from ds1 and passed to both encoders makes them agree.
    mapping_fn = categorical_template(ds1, "label")
    shared1 = set(ds1.transform(label=categorical(mapping_fn)))
    shared2 = set(ds2.transform(label=categorical(mapping_fn)))
    assert shared1 == shared2

    # The class-member function does this implicitly.
    assert set(ds1.categorical("label")) == set(ds2.categorical("label"))
def test_one_hot():
    """One-hot encoding via index, name, ``transform`` and automatic sizing."""
    ds = from_dummy_data(with_label=True).reorder(0, 1, 1).named('data', 'label', 'label_duplicate')
    assert ds.unique('label') == ['a', 'b']

    # equivalent ways of one-hot encoding the 'label' item
    ds_oh = ds.one_hot(1, encoding_size=2)
    ds_oh_alt1 = ds.one_hot("label", encoding_size=2)
    ds_oh_alt2 = ds.transform(label=one_hot(encoding_size=2))
    ds_oh_alt3 = ds.transform([None, one_hot(encoding_size=2)])
    ds_oh_auto = ds.one_hot("label")  # automatically compute encoding size

    expected = [np.array([True, False]), np.array([False, True])]
    rows = zip(
        ds_oh.unique('label'),
        ds_oh_alt1.unique('label'),
        ds_oh_alt2.unique('label'),
        ds_oh_alt3.unique('label'),
        ds_oh_auto.unique('label'),
        expected,
    )
    # every variant (and the expected array) must equal the index-based result
    for reference, *candidates in rows:
        for candidate in candidates:
            assert np.array_equal(reference, candidate)  # type:ignore

    # the untouched duplicate column tells which encoding each row should have
    for _, encoded, raw_label in ds_oh:
        idx = 0 if raw_label == 'a' else 1
        assert np.array_equal(encoded, expected[idx])  # type:ignore

    # spiced up: user-defined mapping, bigger encoding, integer dtype
    ds_oh_userdef = ds.one_hot(
        'label',
        encoding_size=3,
        mapping_fn=lambda x: 1 if x == 'a' else 0,
        dtype='int',
    )
    for got, want in zip(ds_oh_userdef.unique('label'), [np.array([0, 1, 0]), np.array([1, 0, 0])]):
        assert np.array_equal(got, want)  # type:ignore

    # error scenarios
    with pytest.raises(TypeError):
        ds.one_hot()  # we need some arguments
    with pytest.raises(IndexError):
        ds.one_hot(42, encoding_size=2)  # wrong key
    with pytest.raises(IndexError):
        list(ds.one_hot('label', encoding_size=1))  # encoding size too small -- found at runtime
    with pytest.raises(KeyError):
        ds.one_hot("wrong", encoding_size=2)  # wrong key
def test_counts():
    """``counts`` by name, by index, and over whole items."""
    num_total = 11
    ds = from_dummy_data(num_total=num_total, with_label=True).named('data', 'label')
    by_name = ds.counts('label')
    by_index = ds.counts(1)
    assert by_name == by_index == [('a', 5), ('b', num_total - 5)]

    # with no key a warning is emitted and each whole item is counted
    with pytest.warns(UserWarning):
        per_item = ds.counts()
    assert set(per_item) == {(item, 1) for item in ds}
def test_repeat():
    """``repeat`` in its default/'itemwise' mode and in 'whole' mode."""
    ds = from_dummy_data()

    # the default mode must behave like an explicit 'itemwise'
    repeated_default = ds.repeat(3)
    repeated_itemwise = ds.repeat(3, mode='itemwise')
    assert set(ds) == set(repeated_itemwise)
    assert list(repeated_default) == list(repeated_itemwise)

    # 'whole' yields complete copies of the dataset back to back
    repeated_whole = ds.repeat(2, mode='whole')
    assert set(ds) == set(repeated_whole)
    n = len(ds)
    whole_items = list(repeated_whole)
    assert list(ds) == whole_items[:n] == whole_items[n:]
def test_split_filter():
    """Check ``split_filter`` with itemwise, keyed, bulk and mixed predicates.

    Each call returns two datasets: the items for which all given predicates
    hold, and the remaining items, both in their original order.
    """
    num_total = 10
    ds = from_dummy_data(num_total=num_total, with_label=True).named('data', 'label')

    # expected items
    a = [(x, 'a') for x in range(5)]
    b = [(x, 'b') for x in range(5, num_total)]
    even_a = [x for x in a if x[0] % 2 == 0]
    odd_a = [x for x in a if x[0] % 2 == 1]
    even_b = [x for x in b if x[0] % 2 == 0]
    odd_b = [x for x in b if x[0] % 2 == 1]

    # itemwise: one predicate per item position
    ds_even, ds_odd = ds.split_filter([lambda x: x % 2 == 0])
    assert list(ds_even) == even_a + even_b
    assert list(ds_odd) == odd_a + odd_b

    ds_even_a, ds_not_even_a = ds.split_filter([lambda x: x % 2 == 0, lambda x: x == 'a'])
    assert list(ds_even_a) == even_a
    assert list(ds_not_even_a) == odd_a + b

    # by key
    ds_b, ds_a = ds.split_filter(label=lambda x: x == 'b')
    assert list(ds_b) == b
    assert list(ds_a) == a

    # bulk: a single predicate receiving the whole item
    ds_odd_b, ds_not_odd_b = ds.split_filter(lambda x: x[0] % 2 == 1 and x[1] == 'b')
    assert list(ds_odd_b) == odd_b
    assert list(ds_not_odd_b) == a + even_b

    # mix of itemwise and keyed predicates
    ds_even_b, ds_not_even_b = ds.split_filter([lambda x: x % 2 == 0], label=lambda x: x == 'b')
    assert list(ds_even_b) == even_b
    assert list(ds_not_even_b) == [x for x in ds if x not in even_b]

    # sample_classwise
    ds_classwise_2, ds_classwise_rest = ds.split_filter(label=allow_unique(2))
    assert list(ds_classwise_2) == a[:2] + b[:2]
    assert list(ds_classwise_rest) == a[2:] + b[2:]

    # error scenarios
    with pytest.raises(ValueError):
        ds.split_filter()  # no args
    with pytest.raises(ValueError):
        ds.split_filter([None, None, None])  # more predicates than item positions
    with pytest.raises(KeyError):
        ds.split_filter(badkey=lambda x: True)  # key doesn't exist
def test_sample():
    """Seeded ``sample`` draws distinct, reproducible, seed-dependent items."""
    seed = 42
    ds = from_dummy_data()
    picked = list(ds.sample(5, seed))

    # no item is drawn twice
    assert len(picked) == len(set(picked))
    # the fixed seed yields a known selection
    assert {(i,) for i in [10, 1, 0, 4, 9]} == set(picked)

    # a different seed yields a different selection
    picked_other = list(ds.sample(5, seed + 1))
    assert set(picked_other) != set(picked)
def test_categorical():
    """Check ``categorical`` label encoding via index, name and ``transform``."""
    ds = from_dummy_data(with_label=True).reorder(0, 1, 1).named('data', 'label', 'label_duplicate')
    assert ds.unique('label') == ['a', 'b']

    # alternative syntaxes (each built exactly once)
    ds_label = ds.categorical(1)
    ds_label_alt1 = ds.categorical("label")
    ds_label_alt2 = ds.transform(label=categorical())
    ds_label_alt3 = ds.transform([None, categorical()])
    expected = [0, 1]
    for l, l1, l2, l3, e in zip(
        ds_label.unique('label'),
        ds_label_alt1.unique('label'),
        ds_label_alt2.unique('label'),
        ds_label_alt3.unique('label'),
        expected,
    ):
        assert np.array_equal(l, l1)
        assert np.array_equal(l, l2)
        assert np.array_equal(l, l3)
        assert np.array_equal(l, e)  # type:ignore
    # only 'label' is encoded; 'label_duplicate' keeps the raw value
    assert list(ds_label) == [(d, 0 if l == 'a' else 1, l2) for d, l, l2 in ds]

    # user-defined mapping function
    ds_label_userdef = ds.categorical('label', lambda x: 1 if x == 'a' else 0)
    assert ds_label_userdef.unique('label') == [1, 0]
    assert list(ds_label_userdef) == [(d, 1 if l == 'a' else 0, l2) for d, l, l2 in ds]

    # error scenarios
    with pytest.raises(TypeError):
        ds.categorical()  # we need to know what to label
    with pytest.raises(IndexError):
        ds.categorical(42)  # wrong key
    with pytest.raises(KeyError):
        ds.categorical("wrong")  # wrong key
def test_split():
    """Check that ``split`` partitions the data and that ``-1`` acts as a wildcard ratio."""
    seed = 42
    ds = from_dummy_data()
    ds1, ds2, ds3 = ds.split([0.6, 0.3, 0.1], seed=seed)
    s1, s2, s3 = set(ds1), set(ds2), set(ds3)

    # new sets are distinct and share no items
    assert s1 != s2 and s1.isdisjoint(s2)
    assert s1 != s3 and s1.isdisjoint(s3)
    assert s2 != s3 and s2.isdisjoint(s3)
    # no values are lost
    assert set(ds) == s1.union(s2, s3)

    # repeat with a wildcard (-1) for the middle ratio
    ds1w, ds2w, ds3w = ds.split([0.6, -1, 0.1], seed=seed)
    # using the wildcard produces the same results
    assert s1 == set(ds1w)
    assert s2 == set(ds2w)
    assert s3 == set(ds3w)
def test_cartesian_product():
    """``cartesian_product`` of two and three datasets, plus degenerate calls."""
    ds_pos = from_dummy_data().take(2).transform([lambda x: x + 1])
    ds_10x = ds_pos.transform([lambda x: 10 * x])
    ds_100x = ds_pos.transform([lambda x: 100 * x])

    # two inputs: the first dataset varies fastest
    expected2 = [(1, 10), (2, 10), (1, 20), (2, 20)]
    prod2 = cartesian_product(ds_pos, ds_10x)
    prod2_method = ds_pos.cartesian_product(ds_10x)
    assert list(prod2) == list(prod2_method) == expected2
    assert len(prod2) == len(prod2_method) == len(set(expected2))
    _ = prod2.shape  # smoke check: computing the shape must not raise

    # three inputs
    expected3 = [
        (1, 10, 100), (2, 10, 100), (1, 20, 100), (2, 20, 100),
        (1, 10, 200), (2, 10, 200), (1, 20, 200), (2, 20, 200),
    ]
    prod3 = cartesian_product(ds_pos, ds_10x, ds_100x)
    prod3_method = ds_pos.cartesian_product(ds_10x, ds_100x)
    assert list(prod3) == list(prod3_method) == expected3
    assert len(prod3) == len(prod3_method) == len(set(expected3))

    # error scenarios
    with pytest.raises(ValueError), pytest.warns(UserWarning):
        cartesian_product()
    with pytest.warns(UserWarning):
        cartesian_product(ds_pos)
    with pytest.warns(UserWarning):
        ds_pos.cartesian_product()