def test_init_deterministic(self, init, dtype_device):
    """Check that ``KMeans.initialize`` is reproducible under a fixed seed.

    The same data is initialized twice. Torch's global RNG is re-seeded
    with the same value immediately before each call, and the two sets of
    cluster centers must match (``torch.allclose``).

    ``init`` and ``dtype_device`` are presumably supplied by parametrized
    fixtures outside this view (TODO confirm); ``dtype_device`` is unpacked
    as a ``(dtype, device)`` pair.
    """
    dtype, device = dtype_device
    seed = 2
    data = torch.rand((20, 5), dtype=dtype, device=device)
    layer = KMeans(n_clusters=3, init=init)

    def _seeded_initialize():
        # Reset the global torch RNG so any randomness used inside
        # initialize() starts from the identical state on every call.
        torch.manual_seed(seed)
        return layer.initialize(data)

    # Tuple elements evaluate left to right: first call, then second.
    first, second = _seeded_initialize(), _seeded_initialize()
    assert torch.allclose(first, second)
def test_manual_init(self):
    """Check input validation and pass-through for ``init='manual'``.

    ``initialize`` must raise ``TypeError`` when ``manual_init`` is
    ``None``. It must raise ``ValueError`` when the supplied centers have
    one row too many (clusters) or one column too many (features).
    Otherwise it must return centers equal to the supplied ones (up to
    ``torch.allclose``).
    """
    n_samples, n_features, n_clusters = 10, 3, 2
    layer = KMeans(init='manual', n_clusters=n_clusters)
    data = torch.rand((n_samples, n_features))
    centers = torch.rand((n_clusters, n_features))
    extra_cluster = torch.rand((n_clusters + 1, n_features))
    extra_feature = torch.rand((n_clusters, n_features + 1))

    # Omitting the centers entirely is reported as a TypeError.
    with pytest.raises(TypeError):
        layer.initialize(data, manual_init=None)

    # A shape mismatch along either axis is reported as a ValueError.
    for bad_centers in (extra_cluster, extra_feature):
        with pytest.raises(ValueError):
            layer.initialize(data, manual_init=bad_centers)

    # Correctly shaped centers come back unchanged.
    returned = layer.initialize(data, manual_init=centers)
    assert torch.allclose(centers, returned)