コード例 #1
0
ファイル: test_masks.py プロジェクト: cirKITers/masKIT
    def test_percentage_perturbation(self):
        size = 3
        mp = DropoutMask((size, ))

        for i in [0.01, 0.1, 0.5, 0.9]:
            mp.perturb(amount=i)
            assert pnp.sum(mp.mask) == round(i * mp.mask.size)
            mp.clear()
コード例 #2
0
ファイル: test_masks.py プロジェクト: cirKITers/masKIT
    def test_wrong_percentage_perturbation(self):
        size = 3
        mp = DropoutMask((size, ))

        for i in [1.1, 1.5, 3.1]:
            mp.perturb(amount=i)
            assert pnp.sum(mp.mask) == round(i)
            mp.clear()
コード例 #3
0
ファイル: test_masks.py プロジェクト: cirKITers/masKIT
    def test_perturbation_remove_add(self):
        size = 3
        mp = DropoutMask((size, ))

        for amount in [random.randrange(size), 0, size, size + 1]:
            mp.perturb(amount=amount, mode=PerturbationMode.RESET)
            assert pnp.sum(mp.mask) == 0
            mp.perturb(amount=amount, mode=PerturbationMode.SET)
            assert pnp.sum(mp.mask) == min(amount, size)
            mp.clear()
コード例 #4
0
ファイル: test_masks.py プロジェクト: cirKITers/masKIT
 def test_setting(self):
     size = 3
     mp = DropoutMask((size, ))
     assert mp
     assert len(mp.mask) == mp.mask.size
     assert pnp.sum(mp.mask) == 0
     mp[1] = True
     assert mp[1] == True  # noqa: E712
     with pytest.raises(IndexError):
         mp[size] = True
     assert pnp.sum(mp.mask) == 1
     mp.clear()
     assert pnp.sum(mp.mask) == 0
     mp[:] = True
     result = mp[:]
     assert len(result) == size
     assert pnp.all(result)
     assert pnp.sum(mp.mask) == size
     mp.clear()
     with pytest.raises(IndexError):
         mp[1, 2] = True