def test_rescale_axis(request: Any, dummy: ep.Tensor) -> None:
    """Compare ``rescale_images`` on the active backend against NumPy.

    A random batch is upscaled from (16, 3, 64, 64) to (16, 3, 128, 128)
    once as a plain NumPy array and once after converting it to the
    backend of ``dummy``. The two results must agree within ``atol=1e-5``.
    The third argument passed to ``rescale_images`` is ``1``; presumably
    this is the channel axis (TODO confirm against ``rescale_images``).
    """
    # Running against the NumPy backend would only compare NumPy with itself.
    if request.config.option.backend == "numpy":
        pytest.skip()

    target_shape = (16, 3, 128, 128)

    def upscale(raw: Any) -> Any:
        # Wrap as an eagerpy tensor, rescale, and return a NumPy array.
        return rescale_images(ep.astensor(raw), target_shape, 1).numpy()

    source = np.random.uniform(0.0, 1.0, size=(16, 3, 64, 64))
    reference = upscale(source)
    backend_result = upscale(ep.from_numpy(dummy, source))
    assert np.allclose(reference, backend_result, atol=1e-5)
def test_pytorch_numpy_compatibility() -> None:
    """Check that ``rescale_images`` gives matching results for NumPy and PyTorch.

    The same (16, 3, 64, 64) array is upscaled to (16, 3, 128, 128) both as a
    NumPy array and as a ``torch`` tensor built with ``torch.from_numpy``.
    The two outputs are compared with ``np.allclose``. The test is skipped,
    rather than erroring, when PyTorch is not installed.
    """
    import numpy as np
    import pytest

    # Skip cleanly instead of failing with ImportError on machines without torch.
    torch = pytest.importorskip("torch")

    # Seeded generator so that any mismatch can be reproduced exactly.
    rng = np.random.default_rng(0)
    x_np = rng.uniform(0.0, 1.0, size=(16, 3, 64, 64))
    x_torch = torch.from_numpy(x_np)

    x_np_ep = ep.astensor(x_np)
    x_torch_ep = ep.astensor(x_torch)

    x_up_np_ep = rescale_images(x_np_ep, (16, 3, 128, 128), 1)
    x_up_torch_ep = rescale_images(x_torch_ep, (16, 3, 128, 128), 1)

    x_up_np = x_up_np_ep.numpy()
    x_up_torch = x_up_torch_ep.numpy()

    assert np.allclose(x_up_np, x_up_torch)
def test_jax_numpy_compatibility() -> None:
    """Check that ``rescale_images`` gives matching results for NumPy and JAX.

    The same (16, 3, 64, 64) array is upscaled to (16, 3, 128, 128) both as a
    NumPy array and as a ``jax.numpy`` array. The test is skipped, rather than
    erroring, when JAX is not installed.
    """
    import numpy as np
    import pytest

    # Skip cleanly instead of failing with ImportError on machines without jax.
    jnp = pytest.importorskip("jax.numpy")

    # Seeded generator so that any mismatch can be reproduced exactly.
    rng = np.random.default_rng(0)
    x_np = rng.uniform(0.0, 1.0, size=(16, 3, 64, 64))
    # Unless 64-bit mode is enabled, JAX stores this float64 input as float32.
    x_jax = jnp.array(x_np)

    x_np_ep = ep.astensor(x_np)
    x_jax_ep = ep.astensor(x_jax)

    x_up_np_ep = rescale_images(x_np_ep, (16, 3, 128, 128), 1)
    x_up_jax_ep = rescale_images(x_jax_ep, (16, 3, 128, 128), 1)

    x_up_np = x_up_np_ep.numpy()
    x_up_jax = x_up_jax_ep.numpy()

    # The default atol (1e-8) is tighter than float32 rounding for near-zero
    # outputs, so use the same absolute tolerance as the backend-vs-NumPy test.
    assert np.allclose(x_up_np, x_up_jax, atol=1e-5)
def test_pytorch_tensorflow_compatibility() -> None:
    """Check that ``rescale_images`` gives matching results for TensorFlow and PyTorch.

    The same (16, 3, 64, 64) NumPy array is converted to a TensorFlow tensor
    and to a PyTorch tensor. Both are upscaled to (16, 3, 128, 128) and
    compared with ``np.allclose``. The test is skipped, rather than erroring,
    when either framework is not installed.
    """
    import numpy as np
    import pytest

    # Skip cleanly instead of failing with ImportError when a framework is missing.
    torch = pytest.importorskip("torch")
    tf = pytest.importorskip("tensorflow")

    # Seeded generator so that any mismatch can be reproduced exactly.
    rng = np.random.default_rng(0)
    x_np = rng.uniform(0.0, 1.0, size=(16, 3, 64, 64))
    x_torch = torch.from_numpy(x_np)
    x_tf = tf.convert_to_tensor(x_np)

    x_tf_ep = ep.astensor(x_tf)
    x_torch_ep = ep.astensor(x_torch)

    x_up_tf_ep = rescale_images(x_tf_ep, (16, 3, 128, 128), 1)
    x_up_torch_ep = rescale_images(x_torch_ep, (16, 3, 128, 128), 1)

    x_up_tf = x_up_tf_ep.numpy()
    x_up_torch = x_up_torch_ep.numpy()

    assert np.allclose(x_up_tf, x_up_torch)