예제 #1
0
    def test_updating_dataset_increases_its_size(self, dataset):
        """Updating with a single (0, 0) sample leaves exactly one entry."""
        label_tensor = types.TikTensor(np.ones(shape=(3, 3)), id_=(0, 0))
        labels = types.TikTensorBatch([label_tensor])
        data_tensor = types.TikTensor(np.ones(shape=(3, 3)), id_=(0, 0))
        data = types.TikTensorBatch([data_tensor])

        dataset.update(data, labels)

        assert len(dataset) == 1
예제 #2
0
    def test_access_by_index_to_deleted_element_is_allowed(self, dataset):
        """Positional indexing keeps working after the entry's id is removed.

        After ``dataset.remove((0, 0))``, ``dataset[0]`` must not raise and
        must still produce a two-item result (presumably ``(data, label)``,
        mirroring how ``test_access_by_index`` unpacks it -- confirm).
        """
        labels = types.TikTensorBatch(
            [types.TikTensor(np.ones(shape=(3, 3)), id_=(0, 0))])
        data = types.TikTensorBatch(
            [types.TikTensor(np.ones(shape=(3, 3)), id_=(0, 0))])
        dataset.update(data, labels)

        dataset.remove((0, 0))

        # A bare ``assert dataset[0]`` is always true for a non-empty tuple,
        # so it only proved that indexing did not raise. Check the shape of
        # the returned item explicitly instead.
        item = dataset[0]
        assert len(item) == 2
예제 #3
0
 def simple_dataset(self, dataset):
     """Fill *dataset* with all-ones samples for ids (0, 0) and (1, 0) and return it."""
     sample_ids = [(0, 0), (1, 0)]
     labels = types.TikTensorBatch([
         types.TikTensor(np.ones(shape=(3, 3)), id_=sample_id)
         for sample_id in sample_ids
     ])
     data = types.TikTensorBatch([
         types.TikTensor(np.ones(shape=(3, 3)), id_=sample_id)
         for sample_id in sample_ids
     ])
     dataset.update(data, labels)
     return dataset
예제 #4
0
    def test_removing_entries_from_dataset(self, dataset):
        """Removing one of two inserted ids leaves a dataset of length 1."""
        sample_ids = ((0, 0), (1, 0))
        labels = types.TikTensorBatch([
            types.TikTensor(np.ones(shape=(3, 3)), id_=sid)
            for sid in sample_ids
        ])
        data = types.TikTensorBatch([
            types.TikTensor(np.ones(shape=(3, 3)), id_=sid)
            for sid in sample_ids
        ])
        dataset.update(data, labels)
        assert len(dataset) == 2

        dataset.remove(sample_ids[0])

        assert len(dataset) == 1
예제 #5
0
    def test_access_by_index(self, dataset):
        """``dataset[0]`` yields the (data, label) pair stored under id (0, 0)."""
        expected_label = torch.Tensor(np.arange(9).reshape(3, 3))
        expected_data = torch.Tensor(np.arange(1, 10).reshape(3, 3))

        labels = types.TikTensorBatch([
            types.TikTensor(expected_label, id_=(0, 0)),
            types.TikTensor(np.ones(shape=(3, 3)), id_=(1, 0)),
        ])
        data = types.TikTensorBatch([
            types.TikTensor(expected_data, id_=(0, 0)),
            types.TikTensor(np.ones(shape=(3, 3)), id_=(1, 0)),
        ])
        dataset.update(data, labels)

        got_data, got_label = dataset[0]

        assert torch.equal(expected_label, got_label)
        assert torch.equal(expected_data, got_data)
예제 #6
0
    def test_updating_removed_entries_recovers_them(self, dataset):
        """Re-sending a batch that contains a removed id restores the full size."""
        sample_ids = [(0, 0), (1, 0)]
        labels = types.TikTensorBatch([
            types.TikTensor(np.ones(shape=(3, 3)), id_=sample_id)
            for sample_id in sample_ids
        ])
        data = types.TikTensorBatch([
            types.TikTensor(np.ones(shape=(3, 3)), id_=sample_id)
            for sample_id in sample_ids
        ])
        dataset.update(data, labels)
        assert len(dataset) == 2

        # Drop one entry, then push the very same batches again.
        dataset.remove((0, 0))
        dataset.update(data, labels)

        assert len(dataset) == 2