def test_rescale_inf():
    """Rescaling must handle bounds that are infinite on one or both sides.

    Each case row holds: source low/high, requested target low/high, and the
    two point lists forwarded to ``check_convert`` (presumably target point
    then source point -- confirm against ``check_convert``).
    """
    cases = [
        # positive infinite
        ([0.0, 0.0], [1.0, np.inf], [1.0, 1.0], [3.0, np.inf],
         [3.0, 1.0], [1.0, 0.0]),
        # negative infinite
        ([-1.0, -np.inf], [0.0, 0.0], [1.0, -np.inf], [3.0, 1.0],
         [1.0, 1.0], [-1.0, 0.0]),
        # two sided
        ([-1.0, -np.inf], [0.0, np.inf], [1.0, -np.inf], [3.0, np.inf],
         [1.0, 12.0], [-1.0, 12.0]),
    ]
    for src_low, src_high, dst_low, dst_high, point_a, point_b in cases:
        source = Box(np.array(src_low), np.array(src_high), dtype=np.float32)
        trafo = rescale(source, np.array(dst_low), np.array(dst_high))
        expected = Box(np.array(dst_low), np.array(dst_high), dtype=np.float32)
        assert trafo.target == expected
        check_convert(trafo, point_a, point_b)
def test_rescale_checks():
    """``rescale`` must reject target ranges it cannot realise with ValueError.

    The spaces are built outside the ``pytest.raises`` blocks so that only the
    ``rescale`` call itself can satisfy the expected exception; previously a
    ValueError raised while constructing a ``Box`` would have passed the test.
    """
    # check that invalid target range (low and high both inf) causes error
    unit = Box(np.array([0.0]), np.array([1.0]))
    with pytest.raises(ValueError):
        rescale(unit, np.inf, np.inf)

    # cannot linearly transform infinite to finite range: dimension 0 of the
    # source is (-inf, inf) but the requested target is [1, 3]
    s = Box(np.array([-np.inf, 0.0]), np.array([np.inf, np.inf]))
    with pytest.raises(ValueError):
        rescale(s, np.array([1.0, 1.0]), np.array([3.0, np.inf]))
def test_rescale_box():
    """Rescale a finite Box with per-dimension array bounds and with scalar bounds."""
    # Per-dimension target bounds.
    source = Box(np.array([0.0, 1.0]), np.array([1.0, 2.0]))
    new_low = np.array([1.0, 0.0])
    new_high = np.array([2.0, 1.0])
    trafo = rescale(source, new_low, new_high)
    assert trafo.target == Box(np.array([1.0, 0.0]), np.array([2.0, 1.0]))
    for point_a, point_b in (([1.0, 0.0], [0.0, 1.0]), ([2.0, 1.0], [1.0, 2.0])):
        check_convert(trafo, point_a, point_b)

    # scalar rescale: the scalar bounds end up applied to every dimension
    fresh_source = Box(np.array([0.0, 1.0]), np.array([1.0, 2.0]))
    scalar_trafo = rescale(fresh_source, 0.0, 1.0)
    assert scalar_trafo.target == Box(np.array([0.0, 0.0]), np.array([1.0, 1.0]))
def test_rescale_tuple():
    """Composite ``Tuple`` spaces are not supported by ``rescale``.

    The space is built before entering ``pytest.raises`` so that a
    NotImplementedError can only come from ``rescale``, not from the
    construction of the ``Tuple``/``Box`` spaces.
    """
    composite = Tuple([Box(0, 1, (1, 1))])
    with pytest.raises(NotImplementedError):
        rescale(composite, 0.0, 1.0)
def test_rescale_discrete(space):
    """Rescaling ``space`` to [0, 1] must raise TypeError."""
    # NOTE(review): a second function with this exact name is defined right
    # after this one in the module, so this definition is shadowed and pytest
    # never collects it -- consider removing the duplicate. Any decorator that
    # supplies ``space`` (e.g. a parametrize) is not visible here; confirm
    # before deleting.
    low, high = 0.0, 1.0
    with pytest.raises(TypeError):
        rescale(space, low, high)
def test_rescale_discrete(space):
    """Discrete spaces cannot be rescaled: ``rescale`` must raise TypeError.

    ``space`` is presumably provided by a parametrize decorator or fixture
    that is not visible here -- confirm.
    """
    target_low, target_high = 0.0, 1.0
    with pytest.raises(TypeError):
        rescale(space, target_low, target_high)