コード例 #1
0
ファイル: test_sampler.py プロジェクト: Ares2013/coreml
 def test_binary_balanced_length(self):
     """Check len() of a shuffled mode='balanced' sampler over binary-1 data."""
     balanced_sampler = ClassificationDataSampler(
         self.binary_1_dataset,
         shuffle=True,
         target_transform=self.binary_target_transform,
         mode='balanced',
     )
     expected_length = 4
     self.assertEqual(len(balanced_sampler), expected_length)
コード例 #2
0
ファイル: test_sampler.py プロジェクト: Ares2013/coreml
 def test_binary_balanced_no_shuffle(self):
     """Check the fixed index order of an unshuffled mode='balanced' sampler.

     Runs on the binary-1 dataset; with shuffle=False the expected
     ordering is [0, 1, 3, 2].
     """
     expected = [0, 1, 3, 2]
     balanced_sampler = ClassificationDataSampler(
         self.binary_1_dataset,
         shuffle=False,
         target_transform=self.binary_target_transform,
         mode='balanced',
     )
     assert_array_equal(list(balanced_sampler), expected)
コード例 #3
0
ファイル: test_sampler.py プロジェクト: Ares2013/coreml
    def test_default_no_shuffle(self):
        """With mode='default' and shuffle=False, indices come out in dataset order."""
        default_sampler = ClassificationDataSampler(
            self.binary_1_dataset,
            shuffle=False,
            target_transform=self.binary_target_transform,
            mode='default',
        )

        # No shuffling was requested, so the sampler should yield the
        # natural ordering 0, 1, 2, 3, 4.
        self.assertEqual(list(default_sampler), list(range(5)))
コード例 #4
0
ファイル: test_sampler.py プロジェクト: Ares2013/coreml
    def test_multi_balanced_no_shuffle(self):
        """Check an unshuffled mode='balanced' sampler on the multiclass dataset."""
        balanced_sampler = ClassificationDataSampler(
            self.multi_dataset,
            shuffle=False,
            target_transform=self.multi_target_transform,
            mode='balanced',
        )

        # Without shuffling, the sampler should return the first occurrence
        # of each class, which for this dataset is the order [2, 1, 0].
        self.assertEqual(list(balanced_sampler), [2, 1, 0])
コード例 #5
0
ファイル: test_sampler.py プロジェクト: Ares2013/coreml
    def test_binary_1_balanced_shuffle(self):
        """Test sampling with shuffling on mode='balanced' for binary-1.

        Even with shuffle=True, the sampled labels are expected to alternate
        between the two classes: even positions carry label 0 and odd
        positions carry label 1.
        """
        sampler = ClassificationDataSampler(
            self.binary_1_dataset, shuffle=True,
            target_transform=self.binary_target_transform,
            mode='balanced')
        indices = list(sampler)

        # Guard against a vacuous pass: if the sampler yielded nothing, the
        # loop below would check nothing and the test would silently pass.
        self.assertTrue(indices, 'sampler yielded no indices')

        for position, index in enumerate(indices):
            # Labels alternate 0, 1, 0, 1, ..., so the expected label is
            # simply the parity of the position.
            self.assertEqual(sampler.labels[index], position % 2)
コード例 #6
0
ファイル: test_sampler.py プロジェクト: Ares2013/coreml
    def test_multi_balanced_with_shuffle(self):
        """Test sampling on mode='balanced' for multiclass with shuffle.

        Even with shuffle=True, the sampled labels are expected to cycle
        through the three classes in order: 0, 1, 2, 0, 1, 2, ...
        """
        sampler = ClassificationDataSampler(
            self.multi_dataset, shuffle=True,
            target_transform=self.multi_target_transform,
            mode='balanced')
        indices = list(sampler)

        # Guard against a vacuous pass: if the sampler yielded nothing, the
        # loop below would check nothing and the test would silently pass.
        self.assertTrue(indices, 'sampler yielded no indices')

        for position, index in enumerate(indices):
            # Labels cycle 0, 1, 2, ..., so the expected label is the
            # position modulo the number of classes (3).
            self.assertEqual(sampler.labels[index], position % 3)