def test_to_backend_with_tf_and_pytorch():
    """Check that the TensorFlow and PyTorch views of a dataset agree with it.

    Loads ``mnist/mnist``, converts it with ``to_tensorflow()`` and
    ``to_pytorch()`` (the latter wrapped in a single-worker ``DataLoader``
    with ``batch_size=1``), and compares the ``labels`` entry of the first
    samples from each backend against the value computed from the dataset.

    If either framework cannot be imported, the test prints a notice and
    returns without checking anything.
    """
    try:
        import torch
        import tensorflow as tf
    except ImportError as err:
        # Either framework may be the missing one, so report the module that
        # actually failed to import instead of always blaming PyTorch.
        print(f"{err.name or err} hasn't been imported and tested")
        return

    tf.compat.v1.enable_eager_execution()
    ds = dataset.load("mnist/mnist")

    tfds = ds.to_tensorflow()
    ptds = torch.utils.data.DataLoader(
        ds.to_pytorch(),
        batch_size=1,
        num_workers=1,
        # Use the dataset's own collate function when it exposes one;
        # otherwise pass None so DataLoader applies its default collation.
        collate_fn=getattr(ds, "collate_fn", None),
    )

    for i, (batchtf, batchpt) in enumerate(zip(tfds, ptds)):
        gt = ds["labels"][i].compute()
        assert gt == batchtf["labels"].numpy()
        assert gt == batchpt["labels"].numpy()
        # Only the leading samples are compared; stop after index 11.
        if i > 10:
            break
def test_to_backend_with_tf_and_pytorch_multiworker():
    """Check that batched TensorFlow and multi-worker PyTorch loaders agree.

    Loads ``mnist/mnist``, batches the TensorFlow view in groups of 8, wraps
    the PyTorch view in a ``DataLoader`` with ``batch_size=8`` and
    ``num_workers=8``, and asserts that the ``labels`` of the first batches
    are element-wise equal between the two backends.

    If either framework cannot be imported, the test prints a notice and
    returns without checking anything, matching the guard used by
    ``test_to_backend_with_tf_and_pytorch``.
    """
    try:
        import tensorflow as tf
        import torch
    except ImportError as err:
        # Both frameworks are optional; name the one that is missing.
        print(f"{err.name or err} hasn't been imported and tested")
        return

    tf.compat.v1.enable_eager_execution()
    ds = dataset.load("mnist/mnist")

    tfds = ds.to_tensorflow().batch(8)
    ptds = torch.utils.data.DataLoader(
        ds.to_pytorch(),
        batch_size=8,
        num_workers=8,
        # Use the dataset's own collate function when it exposes one;
        # otherwise pass None so DataLoader applies its default collation.
        collate_fn=getattr(ds, "collate_fn", None),
    )

    for i, (batchtf, batchpt) in enumerate(zip(tfds, ptds)):
        assert np.all(batchtf["labels"].numpy() == batchpt["labels"].numpy())
        # Only the leading batches are compared; stop after index 11.
        if i > 10:
            break
def test_tensor_dtag():
    """Verify that a tensor's ``dtag`` is still set after store and load."""
    store_path = "./data/new/test"

    # Build a small int32 tensor tagged as an image and store it under "name".
    values = np.array([1, 2], dtype="int32")
    tagged = tensor.from_array(values, dtag="image")
    written = dataset.from_tensors({"name": tagged})
    written.store(store_path)

    # Load the dataset back from the same path and check the tag.
    reloaded = dataset.load(store_path)
    assert reloaded["name"].dtag == "image"