Exemplo n.º 1
0
def test_to_backend_with_tf_and_pytorch():
    try:
        import torch
        import tensorflow as tf
    except ImportError:
        print("Pytorch hasn't been imported and tested")
        return

    tf.compat.v1.enable_eager_execution()
    ds = dataset.load("mnist/mnist")

    tfds = ds.to_tensorflow()
    ptds = ds.to_pytorch()
    ptds = torch.utils.data.DataLoader(
        ptds,
        batch_size=1,
        num_workers=1,
        collate_fn=ds.collate_fn if "collate_fn" in dir(ds) else None,
    )

    for i, (batchtf, batchpt) in enumerate(zip(tfds, ptds)):
        gt = ds["labels"][i].compute()
        assert gt == batchtf["labels"].numpy()
        assert gt == batchpt["labels"].numpy()
        if i > 10:
            break
Exemplo n.º 2
0
def test_to_backend_with_tf_and_pytorch_multiworker():
    """Compare batched labels from the TensorFlow and PyTorch backends.

    Loads ``mnist/mnist``, batches the TensorFlow view in groups of 8 and
    reads the PyTorch view through a ``DataLoader`` with ``batch_size=8`` and
    ``num_workers=8``, then asserts that the ``"labels"`` arrays of paired
    batches are element-wise equal.
    """
    import tensorflow as tf
    import torch

    tf.compat.v1.enable_eager_execution()
    ds = dataset.load("mnist/mnist")

    tf_pipeline = ds.to_tensorflow()
    tf_batches = tf_pipeline.batch(8)

    torch_view = ds.to_pytorch()
    # Fall back to the DataLoader default when the dataset has no collate_fn.
    collate = None
    if "collate_fn" in dir(ds):
        collate = ds.collate_fn
    loader = torch.utils.data.DataLoader(
        torch_view,
        batch_size=8,
        num_workers=8,
        collate_fn=collate,
    )

    for step, (tf_batch, pt_batch) in enumerate(zip(tf_batches, loader)):
        tf_labels = tf_batch["labels"].numpy()
        pt_labels = pt_batch["labels"].numpy()
        assert np.all(tf_labels == pt_labels)
        # Batches 0..11 are checked; the loop ends after the twelfth.
        if step >= 11:
            break
Exemplo n.º 3
0
def test_tensor_dtag():
    """Check that a tensor's ``dtag`` is still ``"image"`` after store and load.

    Builds a tensor from a two-element int32 array with ``dtag="image"``,
    stores a dataset containing it under ``./data/new/test``, loads the
    dataset back from that path and asserts on the ``dtag`` of ``"name"``.
    """
    store_path = "./data/new/test"

    values = np.array([1, 2], dtype="int32")
    tagged = tensor.from_array(values, dtag="image")

    written = dataset.from_tensors({"name": tagged})
    written.store(store_path)

    reloaded = dataset.load(store_path)
    assert reloaded["name"].dtag == "image"