Example #1
0
    def provide_few_shot_dataset_splits(
        self, dataset_splits: DatasetSplits
    ) -> FewShotDatasetSplits:
        """Build few-shot variants of the validation and test splits.

        The training split is passed through untouched. The validation and
        test splits are each passed to ``create_few_shot_datasets`` together
        with ``self.config.dataset.num_samples``.

        Args:
            dataset_splits: Source splits exposing ``train``, ``valid`` and
                ``test`` attributes.

        Returns:
            ``FewShotDatasetSplits(train, few_shot_valid, few_shot_test)``.
        """
        num_samples = self.config.dataset.num_samples
        # NOTE(review): create_few_shot_datasets is unpacked into a
        # (train, test) pair at other call sites, so each value below is
        # presumably a pair rather than a single dataset. Confirm that
        # FewShotDatasetSplits expects that shape.
        few_shot_valid, few_shot_test = (
            create_few_shot_datasets(split, num_samples)
            for split in (dataset_splits.valid, dataset_splits.test)
        )
        return FewShotDatasetSplits(
            dataset_splits.train, few_shot_valid, few_shot_test
        )
Example #2
0
    def test_shouldNotHaveDupplicatedSamples(self):
        """No sample may be duplicated across the few-shot train/test parts."""
        source = create_random_dataset(NUM_ITEMS, NUM_CLASSES, SHAPE)
        train_dataset, test_dataset = create_few_shot_datasets(source, 5)

        # If every sample is distinct, the number of unique samples equals
        # the combined size of both parts.
        distinct = unique_samples([train_dataset, test_dataset])
        combined_size = len(train_dataset) + len(test_dataset)
        self.assertEqual(len(distinct), combined_size)
Example #3
0
    def test_givenOneNumSample_trainDatasetShouldHaveOneSamplePerClass(self):
        """With ``num_samples=1`` the train part has one item per class."""
        source = create_random_dataset(NUM_ITEMS, NUM_CLASSES, SHAPE)
        train_part, _ = create_few_shot_datasets(source, 1)
        seen_classes = unique_classes(train_part)

        # Exactly NUM_CLASSES items, covering NUM_CLASSES distinct classes.
        self.assertEqual(len(train_part), NUM_CLASSES)
        self.assertEqual(len(seen_classes), NUM_CLASSES)
Example #4
0
    def test_given5NumSamples_trainDatasetShouldHave5SamplesPerClass(self):
        """With ``num_samples=5`` the train part has five items per class."""
        num_samples = 5
        # More items are generated so that all classes end up included.
        item_count = NUM_ITEMS * num_samples
        source = create_random_dataset(item_count, NUM_CLASSES, SHAPE)

        train_part, _ = create_few_shot_datasets(source, num_samples)
        seen_classes = unique_classes(train_part)

        self.assertEqual(len(train_part), num_samples * NUM_CLASSES)
        self.assertEqual(len(seen_classes), NUM_CLASSES)
 def setUp(self):
     """Create the few-shot dataset and the data-loader factory per test."""
     random_dataset = create_random_dataset(NUM_ITEMS, NUM_CLASSES, SHAPE)
     samples = random_dataset.samples  # type: ignore
     # NOTE(review): N_WAY is passed as the num_samples argument, and
     # self.dataset holds whatever create_few_shot_datasets returns. At
     # other call sites that is unpacked as a (train, test) pair. Confirm
     # that this is intended.
     self.dataset = create_few_shot_datasets(ListDataset(samples), N_WAY)
     self.dataloader_factory = DataLoaderFactory(
         BATCH_SIZE, SHUFFLE, PIN_MEMORY
     )
Example #6
0
    def test_shouldKeepAllSamples(self):
        """The train and test parts together hold exactly NUM_ITEMS samples."""
        source = create_random_dataset(NUM_ITEMS, NUM_CLASSES, SHAPE)
        train_part, test_part = create_few_shot_datasets(source, 5)

        kept = len(train_part) + len(test_part)
        self.assertEqual(NUM_ITEMS, kept)