def test_updating_dataset_increases_its_size(self, dataset):
    """After one update with single-entry batches, the dataset has length 1."""

    def one_entry_batch():
        # A batch holding a single 3x3 all-ones tensor tagged with id (0, 0).
        return types.TikTensorBatch(
            [types.TikTensor(np.ones(shape=(3, 3)), id_=(0, 0))]
        )

    labels = one_entry_batch()
    data = one_entry_batch()

    dataset.update(data, labels)

    assert len(dataset) == 1
def test_access_by_index_to_deleted_element_is_allowed(self, dataset):
    """Positional indexing must still succeed after the entry was removed.

    A single entry with id ``(0, 0)`` is added and then removed; ``dataset[0]``
    must still return a ``(data, label)`` pair of 3x3 tensors.
    """
    labels = types.TikTensorBatch(
        [types.TikTensor(np.ones(shape=(3, 3)), id_=(0, 0))]
    )
    data = types.TikTensorBatch(
        [types.TikTensor(np.ones(shape=(3, 3)), id_=(0, 0))]
    )
    dataset.update(data, labels)
    dataset.remove((0, 0))

    # Indexing yields a (data, label) pair. A non-empty pair is always
    # truthy, so a bare ``assert dataset[0]`` only proved that no exception
    # was raised. Unpack it and check that real tensors come back.
    ret_data, ret_label = dataset[0]
    assert tuple(ret_data.shape) == (3, 3)
    assert tuple(ret_label.shape) == (3, 3)
def simple_dataset(self, dataset):
    """Populate ``dataset`` with two entries and return it.

    Both entries are 3x3 all-ones tensors with ids ``(0, 0)`` and ``(1, 0)``.
    Data and labels are built identically.
    """
    entry_ids = [(0, 0), (1, 0)]

    def two_entry_batch():
        # One all-ones 3x3 tensor per id, in the order listed above.
        return types.TikTensorBatch(
            [
                types.TikTensor(np.ones(shape=(3, 3)), id_=entry_id)
                for entry_id in entry_ids
            ]
        )

    labels = two_entry_batch()
    data = two_entry_batch()
    dataset.update(data, labels)
    return dataset
def test_removing_entries_from_dataset(self, dataset):
    """Removing one of two entries by id drops the dataset length from 2 to 1."""

    def two_entry_batch():
        # Two all-ones 3x3 tensors tagged (0, 0) and (1, 0).
        return types.TikTensorBatch(
            [
                types.TikTensor(np.ones(shape=(3, 3)), id_=entry_id)
                for entry_id in ((0, 0), (1, 0))
            ]
        )

    labels = two_entry_batch()
    data = two_entry_batch()

    dataset.update(data, labels)
    assert len(dataset) == 2

    dataset.remove((0, 0))
    assert len(dataset) == 1
def test_access_by_index(self, dataset):
    """``dataset[0]`` returns the data and label that were stored with id ``(0, 0)``."""
    expected_label = torch.Tensor(np.arange(9).reshape(3, 3))
    expected_data = torch.Tensor(np.arange(1, 10).reshape(3, 3))

    def batch_starting_with(head):
        # ``head`` gets id (0, 0); a fresh all-ones 3x3 tensor fills id (1, 0).
        return types.TikTensorBatch(
            [
                types.TikTensor(head, id_=(0, 0)),
                types.TikTensor(np.ones(shape=(3, 3)), id_=(1, 0)),
            ]
        )

    labels = batch_starting_with(expected_label)
    data = batch_starting_with(expected_data)
    dataset.update(data, labels)

    returned_data, returned_label = dataset[0]
    assert torch.equal(expected_label, returned_label)
    assert torch.equal(expected_data, returned_data)
def test_updating_removed_entries_recovers_them(self, dataset):
    """Re-running ``update`` with a removed entry's id brings it back.

    The dataset goes from 2 entries, to 1 after removing ``(0, 0)``, and back
    to 2 after updating again with the same batches.
    """
    labels = types.TikTensorBatch([
        types.TikTensor(np.ones(shape=(3, 3)), id_=(0, 0)),
        types.TikTensor(np.ones(shape=(3, 3)), id_=(1, 0)),
    ])
    data = types.TikTensorBatch([
        types.TikTensor(np.ones(shape=(3, 3)), id_=(0, 0)),
        types.TikTensor(np.ones(shape=(3, 3)), id_=(1, 0)),
    ])
    dataset.update(data, labels)
    assert 2 == len(dataset)

    dataset.remove((0, 0))
    # Without this check the test would pass even if ``remove`` were a no-op:
    # the length would stay 2 throughout and nothing would be "recovered".
    assert 1 == len(dataset)

    dataset.update(data, labels)
    assert 2 == len(dataset)