Exemplo n.º 1
0
def test_type():
    """Check that classes and functions dump to / load from ``__typename__`` JSON."""
    # A class reached through a package attribute dumps to a dotted name and
    # loads back to the very same object.
    adam_json = '{"__typename__": "torch.optim.adam.Adam"}'
    assert json_dumps(torch.optim.Adam) == adam_json
    assert json_loads(adam_json) == torch.optim.Adam

    # Only the tail of Foo's dotted name is pinned; whatever prefix precedes
    # ``test_serializer.Foo`` is accepted by the ``(.*)`` group.
    foo_pattern = r'{"__typename__": "(.*)test_serializer.Foo"}'
    foo_json = json_dumps(Foo)
    assert re.match(foo_pattern, foo_json)

    # Plain functions follow the same naming scheme as classes.
    floor_json = '{"__typename__": "math.floor"}'
    assert json_dumps(math.floor) == floor_json
    assert json_loads(floor_json) == math.floor
Exemplo n.º 2
0
def test_dataset():
    """Round-trip ``serialize``-wrapped MNIST datasets and loaders through JSON.

    Three scenarios are covered: a bare dataset compared against a hand-written
    expected dump, a dataset whose transform pipeline is itself built with
    ``serialize``, and a dataset whose transform is an ordinary
    ``transforms.Compose`` instance.
    """
    # NOTE(review): download=True presumably fetches MNIST into data/mnist on
    # first run -- confirm the test environment has network access.

    def check_first_batch(loader_spec):
        # Dump, reload, and pull one batch; both tensors must have 10 entries
        # along the first dimension (the batch_size used throughout this test).
        reloaded = json_loads(json_dumps(loader_spec))
        inputs, targets = next(iter(reloaded))
        assert inputs.size() == torch.Size([10, 1, 28, 28])
        assert targets.size() == torch.Size([10])

    # --- Scenario 1: dump must equal the dump of a hand-written dict ---------
    plain_mnist = serialize(MNIST, root='data/mnist', train=False, download=True)
    plain_loader = serialize(DataLoader, plain_mnist, batch_size=10)

    expected_mnist = {
        "__type__": "torchvision.datasets.mnist.MNIST",
        "arguments": {
            "root": "data/mnist",
            "train": False,
            "download": True
        }
    }
    expected_loader = {
        "__type__": "torch.utils.data.dataloader.DataLoader",
        "arguments": {
            "batch_size": 10,
            "dataset": expected_mnist
        }
    }
    assert json_dumps(plain_loader) == json_dumps(expected_loader)
    # Loading the hand-written description must yield a DataLoader.
    rebuilt_loader = json_loads(json_dumps(expected_loader))
    assert isinstance(rebuilt_loader, DataLoader)

    # --- Scenario 2: transform pipeline assembled from serialize() calls -----
    to_tensor = serialize(transforms.ToTensor)
    normalize = serialize(transforms.Normalize, (0.1307, ), (0.3081, ))
    wrapped_pipeline = serialize(transforms.Compose, [to_tensor, normalize])
    wrapped_mnist = serialize(MNIST,
                              root='data/mnist',
                              train=False,
                              download=True,
                              transform=wrapped_pipeline)
    check_first_batch(serialize(DataLoader, wrapped_mnist, batch_size=10))

    # --- Scenario 3: transform pipeline given as a regular Compose object ----
    direct_pipeline = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307, ), (0.3081, )),
    ])
    direct_mnist = serialize(MNIST,
                             root='data/mnist',
                             train=False,
                             download=True,
                             transform=direct_pipeline)
    check_first_batch(serialize(DataLoader, direct_mnist, batch_size=10))
Exemplo n.º 3
0
def test_serialize():
    """Exercise ``serialize(Foo, ...)`` wrappers through ``json_dumps``/``json_loads``."""
    # Positional and keyword construction must each reload to an equal object.
    positional = serialize(Foo, 3)
    assert json_loads(json_dumps(positional)) == positional
    keyword = serialize(Foo, b=2, a=1)
    assert json_loads(json_dumps(keyword)) == keyword

    # The inner argument here is a plain ``Foo(1)`` rather than a
    # ``serialize(Foo, 1)`` wrapper, and the dump is expected to be long.
    # NOTE(review): presumably the raw instance cannot be described compactly
    # and is embedded wholesale -- confirm against the serializer.
    raw_nested = serialize(Foo, Foo(1), 5)
    raw_nested_json = json_dumps(raw_nested)
    assert len(raw_nested_json) > 200

    # With the inner argument wrapped by serialize(), the dump stays short
    # and reloads to an equal object.
    wrapped_nested = serialize(Foo, serialize(Foo, 1), 5)
    wrapped_nested_json = json_dumps(wrapped_nested)
    assert len(wrapped_nested_json) < 200
    assert json_loads(wrapped_nested_json) == wrapped_nested
Exemplo n.º 4
0
def test_basic_unit():
    """The object built by ``ImportTest(3, 0.5)`` must reload equal after a JSON dump."""
    original = ImportTest(3, 0.5)
    dumped = json_dumps(original)
    restored = json_loads(dumped)
    assert restored == original
Exemplo n.º 5
0
def test_blackbox_module():
    """Dump ``ImportTest(3, 0.5)`` to JSON and require the reloaded value to compare equal."""
    source_module = ImportTest(3, 0.5)
    # Left operand is the reloaded object, so its ``__eq__`` drives the check.
    round_tripped = json_loads(json_dumps(source_module))
    assert round_tripped == source_module