def test_hybrid_discriminative_rbm_fit():
    """Fit a HybridDiscriminativeRBM on KMNIST for one epoch.

    The KMNIST training split is loaded (downloading it into ``./data`` if
    needed), the model is trained on it, and the returned loss and accuracy
    are checked to be non-negative.
    """
    to_tensor = torchvision.transforms.ToTensor()
    training_set = torchvision.datasets.KMNIST(
        root="./data",
        train=True,
        download=True,
        transform=to_tensor,
    )

    model = discriminative_rbm.HybridDiscriminativeRBM(
        n_visible=784,
        n_hidden=128,
        n_classes=10,
        learning_rate=0.1,
        alpha=0.01,
        momentum=0,
        decay=0,
        use_gpu=False,
    )

    # ``fit`` is unpacked into exactly two values: loss and accuracy.
    loss, acc = model.fit(training_set, batch_size=128, epochs=1)

    assert loss >= 0
    assert acc >= 0
def test_hybrid_discriminative_rbm_class_sampling():
    """Check the second dimension of the outputs of ``class_sampling``.

    A model with 10 classes is given a ``(1, 128)`` tensor of ones; both
    returned values must have size 10 along dimension 1.
    """
    n_classes = 10
    model = discriminative_rbm.HybridDiscriminativeRBM(n_classes=n_classes)

    # 128 presumably matches the model's default hidden size -- TODO confirm.
    hidden = torch.ones(1, 128)

    probs, states = model.class_sampling(hidden)

    for sampled in (probs, states):
        assert sampled.size(1) == n_classes
def test_hybrid_discriminative_rbm_hidden_sampling():
    """Check the second dimension of the outputs of ``hidden_sampling``.

    A model with 10 classes is given a ``(1, 128)`` tensor of ones as ``v``
    and a ``(128, 10)`` tensor of ones as ``y``; both returned values must
    have size 128 along dimension 1.
    """
    model = discriminative_rbm.HybridDiscriminativeRBM(n_classes=10)

    # 128 presumably matches the model's default visible/hidden sizes
    # -- TODO confirm against the constructor defaults.
    expected_width = 128
    visible = torch.ones(1, 128)
    labels = torch.ones(128, 10)

    probs, states = model.hidden_sampling(visible, labels)

    for sampled in (probs, states):
        assert sampled.size(1) == expected_width
def test_hybrid_discriminative_rbm_alpha_setter():
    """Exercise the ``alpha`` setter with values the test expects it to reject.

    For each candidate (``'a'`` and ``-1``) the assignment is attempted. If it
    raises, ``alpha`` is reset to ``0.01``. Either way ``alpha`` must equal
    ``0.01`` afterwards, so the assertion fails if the setter accepts the
    candidate without raising.
    """
    model = discriminative_rbm.HybridDiscriminativeRBM()

    for rejected_value in ('a', -1):
        try:
            model.alpha = rejected_value
        # Catch ``Exception`` rather than using a bare ``except`` so that
        # KeyboardInterrupt / SystemExit are not swallowed by the test.
        except Exception:
            model.alpha = 0.01

        assert model.alpha == 0.01
def test_hybrid_discriminative_rbm_alpha():
    """Check that a model built with no arguments reports ``alpha == 0.01``."""
    model = discriminative_rbm.HybridDiscriminativeRBM()
    expected_alpha = 0.01

    assert model.alpha == expected_alpha