def test_percentage_perturbation(self):
    """Check perturbation with fractional amounts below one.

    For every fraction, the summed mask after ``perturb`` must equal the
    fraction multiplied by the number of mask entries, rounded with the
    built-in ``round``.
    """
    length = 3
    dropout = DropoutMask((length, ))
    fractions = (0.01, 0.1, 0.5, 0.9)

    for fraction in fractions:
        dropout.perturb(amount=fraction)
        observed = pnp.sum(dropout.mask)
        expected = round(fraction * dropout.mask.size)
        assert observed == expected
        # Clear the mask so the next fraction does not build on this one.
        dropout.clear()
def test_wrong_percentage_perturbation(self):
    """Check perturbation with non-integer amounts greater than one.

    These values are not scaled by the mask size: the summed mask after
    ``perturb`` is expected to equal ``round(amount)``.
    """
    length = 3
    dropout = DropoutMask((length, ))
    oversized_amounts = (1.1, 1.5, 3.1)

    for value in oversized_amounts:
        dropout.perturb(amount=value)
        observed = pnp.sum(dropout.mask)
        expected = round(value)
        assert observed == expected
        # Clear the mask so the next amount does not build on this one.
        dropout.clear()
def test_perturbation_remove_add(self):
    """Check the ``RESET`` and ``SET`` perturbation modes on one mask.

    The amounts tried are one random value below the mask length, zero,
    the mask length itself, and one more than the mask length.
    """
    length = 3
    dropout = DropoutMask((length, ))
    # The random amount is drawn once, before the loop starts.
    amounts = (random.randrange(length), 0, length, length + 1)

    for count in amounts:
        # RESET on a mask with nothing set must leave the sum at zero.
        dropout.perturb(amount=count, mode=PerturbationMode.RESET)
        assert pnp.sum(dropout.mask) == 0

        # SET is expected to saturate at the mask length.
        dropout.perturb(amount=count, mode=PerturbationMode.SET)
        upper_bound = min(count, length)
        assert pnp.sum(dropout.mask) == upper_bound

        dropout.clear()
def test_setting(self):
    """Check item assignment, slicing, clearing and index errors.

    Covers single-index assignment, an index equal to the mask length,
    whole-mask slice assignment and read-back, and a two-element index on
    the one-dimensional mask.
    """
    length = 3
    dropout = DropoutMask((length, ))

    # A freshly created mask is truthy and has nothing set.
    assert dropout
    assert len(dropout.mask) == dropout.mask.size
    assert pnp.sum(dropout.mask) == 0

    # Single-element assignment and read-back.
    dropout[1] = True
    assert dropout[1] == True  # noqa: E712

    # An index equal to the length must raise and leave the mask untouched.
    with pytest.raises(IndexError):
        dropout[length] = True
    assert pnp.sum(dropout.mask) == 1

    dropout.clear()
    assert pnp.sum(dropout.mask) == 0

    # Slice assignment sets every entry; slice read-back returns all of them.
    dropout[:] = True
    everything = dropout[:]
    assert len(everything) == length
    assert pnp.all(everything)
    assert pnp.sum(dropout.mask) == length

    dropout.clear()

    # A two-element index on this one-dimensional mask must raise.
    with pytest.raises(IndexError):
        dropout[1, 2] = True