Exemplo n.º 1
0
def test_hybrid_discriminative_rbm_fit():
    """Fit a HybridDiscriminativeRBM on KMNIST and sanity-check the result.

    Loads the KMNIST training split from ``./data`` (with ``download=True``),
    runs ``fit`` for one epoch with a batch size of 128, and asserts that
    both values returned by ``fit`` are non-negative.
    """
    # Training split; samples are converted with ToTensor().
    to_tensor = torchvision.transforms.ToTensor()
    kmnist_train = torchvision.datasets.KMNIST(
        root="./data", train=True, download=True, transform=to_tensor
    )

    # Hyper-parameters for the model under test (784 visible units,
    # 10 classes, CPU only).
    model_kwargs = {
        "n_visible": 784,
        "n_hidden": 128,
        "n_classes": 10,
        "learning_rate": 0.1,
        "alpha": 0.01,
        "momentum": 0,
        "decay": 0,
        "use_gpu": False,
    }
    model = discriminative_rbm.HybridDiscriminativeRBM(**model_kwargs)

    # `fit` is expected to return exactly two values (unpacked here as
    # loss and accuracy).
    train_loss, train_acc = model.fit(kmnist_train, batch_size=128, epochs=1)

    assert train_loss >= 0
    assert train_acc >= 0
def test_hybrid_discriminative_rbm_class_sampling():
    """Check the output width of ``class_sampling`` for a 10-class model.

    Feeds a (1, 128) tensor of ones to ``class_sampling`` and asserts that
    both returned tensors have 10 entries along dimension 1.
    """
    n_classes = 10
    model = discriminative_rbm.HybridDiscriminativeRBM(n_classes=n_classes)

    # A single all-ones input row of width 128.
    hidden = torch.ones(1, 128)

    class_probs, class_states = model.class_sampling(hidden)

    assert class_probs.size(1) == n_classes
    assert class_states.size(1) == n_classes
def test_hybrid_discriminative_rbm_hidden_sampling():
    """Check the output width of ``hidden_sampling`` for a 10-class model.

    Calls ``hidden_sampling`` with a (1, 128) tensor of ones and a (128, 10)
    tensor of ones, then asserts that both returned tensors have 128 entries
    along dimension 1.
    """
    model = discriminative_rbm.HybridDiscriminativeRBM(n_classes=10)

    # All-ones inputs; shapes are kept exactly as the test has always used.
    visible = torch.ones(1, 128)
    labels = torch.ones(128, 10)

    hidden_probs, hidden_states = model.hidden_sampling(visible, labels)

    expected_width = 128
    assert hidden_probs.size(1) == expected_width
    assert hidden_states.size(1) == expected_width
def test_hybrid_discriminative_rbm_alpha_setter():
    """Exercise the ``alpha`` setter with invalid values.

    For each invalid candidate (the string ``'a'`` and the negative number
    ``-1``) the test tries to assign it to ``alpha``. If the assignment
    raises, ``alpha`` is reset to 0.01. The test then asserts that ``alpha``
    equals 0.01.
    """
    model = discriminative_rbm.HybridDiscriminativeRBM()

    for invalid_alpha in ("a", -1):
        try:
            model.alpha = invalid_alpha
        except Exception:
            # The concrete exception type raised by the setter is defined in
            # the project and not visible here, so catch `Exception` rather
            # than using a bare `except:`; this lets KeyboardInterrupt and
            # SystemExit propagate instead of being silently swallowed.
            model.alpha = 0.01

        assert model.alpha == 0.01
def test_hybrid_discriminative_rbm_alpha():
    """Check that a model built with no arguments reports ``alpha == 0.01``."""
    model = discriminative_rbm.HybridDiscriminativeRBM()

    observed_alpha = model.alpha

    assert observed_alpha == 0.01