예제 #1
0
    def test_whenIndexToBig_shouldRaise(self):
        """Out-of-range indices on a ComposedDataset must raise ValueError.

        The composed length is derived from NUM_ITEMS instead of being
        hard-coded. The probed indices therefore stay exactly at the
        boundary if NUM_ITEMS changes.
        """
        datasets = [
            create_random_dataset(NUM_ITEMS, NUM_CLASSES, SHAPE),
            create_random_dataset(NUM_ITEMS, NUM_CLASSES, SHAPE),
        ]
        dataset = ComposedDataset(datasets)
        # Each part is built with NUM_ITEMS samples. The original literal 200
        # equals 2 * NUM_ITEMS only when NUM_ITEMS == 100.
        total = NUM_ITEMS * len(datasets)

        # First index past the end, one further, and one before the start
        # when counting backwards from the end.
        for index in (total, total + 1, -(total + 1)):
            with self.subTest(index=index), self.assertRaises(ValueError):
                _ = dataset[index]
예제 #2
0
    def test_shouldNotHaveDupplicatedSamples(self):
        """The train and test splits together must contain no duplicates."""
        source = create_random_dataset(NUM_ITEMS, NUM_CLASSES, SHAPE)
        train_split, test_split = create_few_shot_datasets(source, 5)

        distinct = unique_samples([train_split, test_split])
        # A duplicate anywhere would make the distinct count smaller than
        # the combined size of both splits.
        combined_size = len(train_split) + len(test_split)
        self.assertEqual(len(distinct), combined_size)
예제 #3
0
    def test_givenOneNumSample_trainDatasetShouldHaveOneSamplePerClass(self):
        """Asking for one sample per class yields one train item per class.

        The train split must hold NUM_CLASSES items covering NUM_CLASSES
        distinct classes.
        """
        source = create_random_dataset(NUM_ITEMS, NUM_CLASSES, SHAPE)
        train_split, _ = create_few_shot_datasets(source, 1)

        seen_classes = unique_classes(train_split)
        self.assertEqual(len(train_split), NUM_CLASSES)
        self.assertEqual(len(seen_classes), NUM_CLASSES)
예제 #4
0
    def setUp(self):
        """Split one random dataset into three ListDatasets and compose them.

        The slices [0:20], [20:40] and [70:100] are used. Samples 40-69 are
        not part of any component.
        """
        source = create_random_dataset(NUM_ITEMS, NUM_CLASSES, SHAPE)
        samples = source.samples  # type: ignore
        bounds = ((0, 20), (20, 40), (70, 100))
        self.dataset_1, self.dataset_2, self.dataset_3 = (
            ListDataset(samples[start:stop]) for start, stop in bounds)
        self.composed_dataset = ComposedDataset(
            [self.dataset_1, self.dataset_2, self.dataset_3])
예제 #5
0
    def test_given5NumSamples_trainDatasetShouldHave5SamplesPerClass(self):
        """Asking for five samples per class yields 5 * NUM_CLASSES items.

        The train split must also cover all NUM_CLASSES classes.
        """
        samples_per_class = 5
        # Scale the item count up so every class is represented with enough
        # members (labels are presumably assigned randomly — see the helper).
        source = create_random_dataset(NUM_ITEMS * samples_per_class,
                                       NUM_CLASSES, SHAPE)
        train_split, _ = create_few_shot_datasets(source, samples_per_class)

        seen_classes = unique_classes(train_split)
        self.assertEqual(len(train_split), samples_per_class * NUM_CLASSES)
        self.assertEqual(len(seen_classes), NUM_CLASSES)
 def setUp(self):
     """Build a few-shot dataset from random samples and a loader factory."""
     source = create_random_dataset(NUM_ITEMS, NUM_CLASSES, SHAPE)
     random_samples = source.samples  # type: ignore
     # NOTE(review): create_few_shot_datasets is elsewhere unpacked into a
     # (train, test) pair, so self.dataset is presumably that pair. Its second
     # argument is used elsewhere as a per-class sample count, yet N_WAY is
     # passed here — confirm this is intended.
     self.dataset = create_few_shot_datasets(ListDataset(random_samples),
                                             N_WAY)
     self.dataloader_factory = DataLoaderFactory(BATCH_SIZE, SHUFFLE,
                                                 PIN_MEMORY)
 def setUp(self):
     """Create a small random dataset and the KorniaTransforms under test."""
     num_items, num_classes, shape = 100, 3, (3, 12, 12)
     self.dataset = create_random_dataset(num_items, num_classes, shape)
     # Arguments are passed through as-is: MEAN, STD, (10, 10), (10, 10), 2.
     # Their meaning is defined by KorniaTransforms' signature.
     self.transforms = KorniaTransforms(MEAN, STD, (10, 10), (10, 10), 2)
예제 #8
0
    def test_shouldKeepAllSamples(self):
        """The train and test splits together hold all NUM_ITEMS samples."""
        source = create_random_dataset(NUM_ITEMS, NUM_CLASSES, SHAPE)
        train_split, test_split = create_few_shot_datasets(source, 5)

        split_total = sum(map(len, (train_split, test_split)))
        self.assertEqual(NUM_ITEMS, split_total)