def test_subspace_tracking_momentum(self, subspace_momentum):
    """Exercise the MomentumSGD method with a given ``subspace_momentum``.

    The factorisation W @ H is checked against ``self.X`` with
    ``compare_norms``. An out-of-range momentum (1.9) must raise a
    ValueError whose message mentions the valid [0, 1] float range.
    """
    W, H = ornmf(
        self.X, self.rank, method="MomentumSGD", subspace_momentum=subspace_momentum
    )
    compare_norms(W @ H, self.X)

    # Plain string: the original f-string had no placeholders (ruff F541).
    with pytest.raises(ValueError, match="must be a float between 0 and 1"):
        _ = ornmf(self.X, self.rank, method="MomentumSGD", subspace_momentum=1.9)
def test_subspace_tracking_learning_rate(self, subspace_learning_rate):
    """Run MomentumSGD with the supplied ``subspace_learning_rate``.

    The product of the returned factors is compared with ``self.X`` using
    ``compare_norms``.
    """
    options = dict(
        method="MomentumSGD",
        subspace_learning_rate=subspace_learning_rate,
    )
    basis, coeffs = ornmf(self.X, self.rank, **options)
    compare_norms(basis @ coeffs, self.X)
def test_no_method(self):
    """An unknown ``method`` name ("uniform") must raise ValueError.

    The error message is expected to contain "'method' not recognised".
    """
    # Plain string: the original f-string had no placeholders (ruff F541).
    with pytest.raises(ValueError, match="'method' not recognised"):
        _ = ornmf(self.X, self.rank, method="uniform")
def test_subspace_tracking(self):
    """Run MomentumSGD with its default settings.

    The reconstruction is compared with ``self.X`` via ``compare_norms``.
    """
    basis, coeffs = ornmf(self.X, self.rank, method="MomentumSGD")
    approx = basis @ coeffs
    compare_norms(approx, self.X)
def test_corrupted_robust(self):
    """Factorise ``self.Y`` with RobustPGD.

    The reconstruction is compared with ``self.X``, not ``self.Y``.
    NOTE(review): per the test name, ``self.Y`` is presumably a corrupted
    copy of ``self.X``; confirm against the fixture setup.
    """
    basis, coeffs = ornmf(self.Y, self.rank, method="RobustPGD")
    approx = basis @ coeffs
    compare_norms(approx, self.X)
def test_corrupted_default(self):
    """Factorise ``self.Y`` using ornmf's default method.

    The reconstruction is compared with ``self.X``.
    NOTE(review): ``self.Y`` is presumably a corrupted copy of ``self.X``
    (per the test name); confirm against the fixture setup.
    """
    basis, coeffs = ornmf(self.Y, self.rank)
    approx = basis @ coeffs
    compare_norms(approx, self.X)
def test_store_error(self):
    """With ``store_error=True``, ornmf returns four values.

    The first is the reconstruction and the second is the error estimate.
    The reconstruction is compared with ``self.X``. The two estimates must
    have the same shapes as ``self.X`` and ``self.E`` respectively. The
    returned factors are not inspected here.
    """
    x_hat, e_hat, _, _ = ornmf(self.X, self.rank, store_error=True)
    compare_norms(x_hat, self.X)
    assert x_hat.shape == self.X.shape
    assert e_hat.shape == self.E.shape
def test_batch_size(self):
    """Run ornmf with ``batch_size=2``.

    The reconstruction is compared with ``self.X``. The factor shapes must
    match ``self.U`` and the transpose of ``self.V``.
    """
    basis, coeffs = ornmf(self.X, self.rank, batch_size=2)
    compare_norms(basis @ coeffs, self.X)
    expected_basis_shape = self.U.shape
    expected_coeffs_shape = self.V.T.shape
    assert basis.shape == expected_basis_shape
    assert coeffs.shape == expected_coeffs_shape
def test_default(self, project):
    """Run ornmf with default settings and the supplied ``project`` option.

    The reconstruction is compared with ``self.X``. The factor shapes must
    match ``self.U`` and the transpose of ``self.V``.
    """
    basis, coeffs = ornmf(self.X, self.rank, project=project)
    compare_norms(basis @ coeffs, self.X)
    expected_basis_shape = self.U.shape
    expected_coeffs_shape = self.V.T.shape
    assert basis.shape == expected_basis_shape
    assert coeffs.shape == expected_coeffs_shape