def test_binary_balanced_length(self):
    """Balanced mode on the binary-1 dataset should report a length of 4."""
    expected_length = 4
    balanced_sampler = ClassificationDataSampler(
        self.binary_1_dataset,
        shuffle=True,
        target_transform=self.binary_target_transform,
        mode='balanced',
    )
    self.assertEqual(len(balanced_sampler), expected_length)
def test_binary_balanced_no_shuffle(self):
    """Unshuffled balanced sampling on binary-1 yields the fixed order [0, 1, 3, 2]."""
    balanced_sampler = ClassificationDataSampler(
        self.binary_1_dataset,
        shuffle=False,
        target_transform=self.binary_target_transform,
        mode='balanced',
    )
    expected_order = [0, 1, 3, 2]
    # Actual first, expected second, matching numpy.testing's argument order.
    assert_array_equal(list(balanced_sampler), expected_order)
def test_default_no_shuffle(self):
    """Default mode without shuffling keeps the dataset's natural index order."""
    default_sampler = ClassificationDataSampler(
        self.binary_1_dataset,
        shuffle=False,
        target_transform=self.binary_target_transform,
        mode='default',
    )
    # With shuffling disabled, indices should come back as 0..4 in order.
    expected_order = list(range(5))
    self.assertEqual(list(default_sampler), expected_order)
def test_multi_balanced_no_shuffle(self):
    """Unshuffled balanced sampling on the multiclass dataset yields [2, 1, 0]."""
    balanced_sampler = ClassificationDataSampler(
        self.multi_dataset,
        shuffle=False,
        target_transform=self.multi_target_transform,
        mode='balanced',
    )
    # Without shuffling, the sampler is expected to return the first
    # occurrence of each class.
    expected_order = [2, 1, 0]
    self.assertEqual(list(balanced_sampler), expected_order)
def test_binary_1_balanced_shuffle(self):
    """Shuffled balanced sampling on binary-1 alternates labels 0, 1, 0, 1, ...

    Each yielded index is looked up in ``sampler.labels``. The label at
    position ``i`` of the sampled sequence must equal ``i % 2``.
    """
    sampler = ClassificationDataSampler(
        self.binary_1_dataset,
        shuffle=True,
        target_transform=self.binary_target_transform,
        mode='balanced',
    )
    indices = list(sampler)
    # Guard against a vacuous pass: if the sampler yielded nothing, the
    # per-index loop below would never run and the test would succeed.
    self.assertGreater(len(indices), 0)
    for position, index in enumerate(indices):
        # Even positions must carry class 0 and odd positions class 1.
        self.assertEqual(sampler.labels[index], position % 2)
def test_multi_balanced_with_shuffle(self):
    """Shuffled balanced sampling on the multiclass set cycles labels 0, 1, 2, ...

    Each yielded index is looked up in ``sampler.labels``. The label at
    position ``i`` of the sampled sequence must equal ``i % 3``.
    """
    sampler = ClassificationDataSampler(
        self.multi_dataset,
        shuffle=True,
        target_transform=self.multi_target_transform,
        mode='balanced',
    )
    indices = list(sampler)
    # Guard against a vacuous pass: an empty sampler would skip the loop.
    self.assertGreater(len(indices), 0)
    for position, index in enumerate(indices):
        # Classes must appear in round-robin order 0, 1, 2, 0, 1, 2, ...
        self.assertEqual(sampler.labels[index], position % 3)