Beispiel #1
0
def test_invalid_model(tmpdir):
    """An unknown ``--architecture`` value must raise ``SystemExit``.

    A valid FactorizedPrior state dict is written first, so the only thing
    wrong with the invocation is the architecture name ``"foobar"``.
    """
    checkpoint = tmpdir.join("model.pth.tar").strpath
    torch.save(FactorizedPrior(32, 64).state_dict(), checkpoint)

    bad_args = ("--architecture", "foobar")
    with pytest.raises(SystemExit):
        run_update_model(checkpoint, *bad_args)
Beispiel #2
0
def test_load(tmpdir):
    """The updater accepts checkpoints that nest the weights under a key.

    For each of the ``'network'`` and ``'state_dict'`` wrapper keys the
    FactorizedPrior state dict is saved as ``{key: state_dict}`` and the
    update run is expected to write nothing to stdout or stderr.
    """
    checkpoint = tmpdir.join('model.pth.tar').strpath
    model = FactorizedPrior(32, 64)

    for wrapper_key in ('network', 'state_dict'):
        torch.save({wrapper_key: model.state_dict()}, checkpoint)
        out, err = run_update_model(
            checkpoint, '--architecture', 'factorized-prior', '--dir', tmpdir
        )
        assert len(out) == 0
        assert len(err) == 0
Beispiel #3
0
def test_load(tmpdir):
    """Checkpoints wrapping the weights under a known key update silently.

    The state dict is saved nested under ``"network"`` and then under
    ``"state_dict"``; each update run must produce empty stdout and stderr.
    """
    checkpoint = tmpdir.join("model.pth.tar").strpath
    model = FactorizedPrior(32, 64)
    cli_args = ("--architecture", "factorized-prior", "--dir", tmpdir)

    for wrapper_key in ("network", "state_dict"):
        torch.save({wrapper_key: model.state_dict()}, checkpoint)
        out, err = run_update_model(checkpoint, *cli_args)
        assert len(out) == 0
        assert len(err) == 0
Beispiel #4
0
def test_valid_name(tmpdir):
    """``--name`` sets the filename prefix of the updated checkpoint.

    After the update run, ``tmpdir`` must hold exactly two ``.pth.tar``
    files: the original ``model.pth.tar`` and one whose name starts with
    ``yolo-``. The run must write nothing to stdout or stderr.
    """
    p = tmpdir.join('model.pth.tar').strpath

    net = FactorizedPrior(32, 64)
    torch.save(net.state_dict(), p)

    stdout, stderr = run_update_model(p, '--architecture', 'factorized-prior',
                                      '--dir', tmpdir, '--name', 'yolo')
    assert len(stdout) == 0
    assert len(stderr) == 0

    # sorted() accepts any iterable, so materialising the glob with list()
    # first is redundant. Lexicographic order puts 'model...' before 'yolo-...'.
    files = sorted(Path(tmpdir).glob('*.pth.tar'))
    assert len(files) == 2

    assert files[0].name == 'model.pth.tar'
    # startswith avoids hard-coding the prefix length as a slice bound.
    assert files[1].name.startswith('yolo-')
Beispiel #5
0
def test_valid_name(tmpdir):
    """Passing ``--name yolo`` yields a second checkpoint prefixed ``yolo-``.

    The original ``model.pth.tar`` must remain alongside the new file, and
    the run must print nothing to stdout or stderr.
    """
    checkpoint = tmpdir.join("model.pth.tar").strpath
    torch.save(FactorizedPrior(32, 64).state_dict(), checkpoint)

    out, err = run_update_model(
        checkpoint,
        "--architecture", "factorized-prior",
        "--dir", tmpdir,
        "--name", "yolo",
    )
    assert len(out) == 0
    assert len(err) == 0

    # Lexicographic order places the original "model..." ahead of "yolo-...".
    checkpoints = sorted(Path(tmpdir).glob("*.pth.tar"))
    assert len(checkpoints) == 2

    first, second = checkpoints
    assert first.name == "model.pth.tar"
    assert second.name.startswith("yolo-")
Beispiel #6
0
def test_valid_no_update(tmpdir):
    """With ``--no-update`` the ``_cdf_length`` buffer keeps its size.

    After the run, exactly one ``.pth.tar`` file must be present in
    ``tmpdir``. Its ``entropy_bottleneck._cdf_length`` entry must have the
    same first-dimension size as that of a freshly constructed model.
    """
    checkpoint = tmpdir.join('model.pth.tar').strpath
    model = FactorizedPrior(32, 64)
    torch.save(model.state_dict(), checkpoint)

    out, err = run_update_model(
        checkpoint, '--architecture', 'factorized-prior',
        '--dir', tmpdir, '--no-update',
    )
    assert len(out) == 0
    assert len(err) == 0

    saved = list(Path(tmpdir).glob('*.pth.tar'))
    assert len(saved) == 1

    key = 'entropy_bottleneck._cdf_length'
    before = model.state_dict()[key]
    after = torch.load(saved[0])[key]
    assert before.size(0) == after.size(0)
Beispiel #7
0
def test_valid(tmpdir):
    """A plain update run changes the size of the ``_cdf_length`` buffer.

    After the run, exactly one ``.pth.tar`` file must be present in
    ``tmpdir``. Its ``entropy_bottleneck._cdf_length`` entry must differ in
    first-dimension size from that of a freshly constructed model. This
    presumably happens because the update fills in the CDF tables (to be
    confirmed against the update script).
    """
    checkpoint = tmpdir.join("model.pth.tar").strpath
    model = FactorizedPrior(32, 64)
    torch.save(model.state_dict(), checkpoint)

    out, err = run_update_model(
        checkpoint, "--architecture", "factorized-prior", "--dir", tmpdir
    )
    assert len(out) == 0
    assert len(err) == 0

    saved = list(Path(tmpdir).glob("*.pth.tar"))
    assert len(saved) == 1

    key = "entropy_bottleneck._cdf_length"
    before = model.state_dict()[key]
    after = torch.load(saved[0])[key]
    assert before.size(0) != after.size(0)
Beispiel #8
0
    def test_factorized_prior(self):
        """Forward ``FactorizedPrior(128, 192)`` on a random 1x3x64x64 input.

        Checks the output dict contains ``x_hat`` and ``likelihoods["y"]``.
        The reconstruction must match the input shape. The ``y`` likelihoods
        must have 192 channels, with both spatial dims divided by 2 ** 4.
        """
        net = FactorizedPrior(128, 192)
        inp = torch.rand(1, 3, 64, 64)
        result = net(inp)

        assert "x_hat" in result
        assert "likelihoods" in result
        assert "y" in result["likelihoods"]

        assert result["x_hat"].shape == inp.shape

        y_shape = result["likelihoods"]["y"].shape
        downsample = 2 ** 4
        assert y_shape[0] == inp.shape[0]
        assert y_shape[1] == 192
        for dim in (2, 3):
            assert y_shape[dim] == inp.shape[dim] / downsample
Beispiel #9
0
    def test_factorized_prior(self):
        """Check the output layout and shapes of a FactorizedPrior forward pass.

        Uses ``FactorizedPrior(128, 192)`` on a random 1x3x64x64 tensor. The
        ``y`` likelihoods are expected to have 192 channels, with height and
        width reduced by a factor of 2**4.
        """
        model = FactorizedPrior(128, 192)
        image = torch.rand(1, 3, 64, 64)
        output = model(image)

        for required_key in ('x_hat', 'likelihoods'):
            assert required_key in output
        likelihoods = output['likelihoods']
        assert 'y' in likelihoods

        assert output['x_hat'].shape == image.shape

        y_shape = likelihoods['y'].shape
        factor = 2**4
        assert y_shape[0] == image.shape[0]
        assert y_shape[1] == 192
        assert y_shape[2] == image.shape[2] / factor
        assert y_shape[3] == image.shape[3] / factor