コード例 #1
0
    def test_init_deterministic(self, init, dtype_device):
        """Check that ``KMeans.initialize`` is reproducible under a fixed torch seed.

        The torch RNG is reset to the same seed before each of two
        ``initialize`` calls on the same input; the two returned sets of
        cluster centers must match according to ``torch.allclose``.

        Args:
            init: Value forwarded as ``KMeans(init=...)``
                (presumably supplied by test parametrization -- confirm).
            dtype_device: Pair unpacked as ``(dtype, device)`` and used to
                create the input tensor.
        """
        dtype, device = dtype_device
        seed = 2

        # The input is drawn before any seeding, so only ``initialize`` itself
        # runs under the controlled RNG state.
        x = torch.rand((20, 5), dtype=dtype, device=device)
        model = KMeans(n_clusters=3, init=init)

        runs = []
        for _ in range(2):
            # Reset the global RNG so both calls start from an identical state.
            torch.manual_seed(seed)
            runs.append(model.initialize(x))

        first_centers, second_centers = runs
        assert torch.allclose(first_centers, second_centers)
コード例 #2
0
    def test_manual_init(self):
        """Exercise ``KMeans.initialize`` when the layer uses ``init='manual'``.

        Expectations checked, in order:
          * ``manual_init=None`` raises ``TypeError``;
          * a tensor with one row too many raises ``ValueError``;
          * a tensor with one column too many raises ``ValueError``;
          * a tensor of shape ``(n_clusters, n_features)`` is returned such
            that it is ``torch.allclose`` to the tensor that was passed in.
        """
        n_samples, n_features, n_clusters = 10, 3, 2

        model = KMeans(init='manual', n_clusters=n_clusters)

        data = torch.rand((n_samples, n_features))
        valid_centers = torch.rand((n_clusters, n_features))
        # Each candidate is off by one along exactly one dimension.
        invalid_centers = (
            torch.rand((n_clusters + 1, n_features)),
            torch.rand((n_clusters, n_features + 1)),
        )

        with pytest.raises(TypeError):
            model.initialize(data, manual_init=None)

        for bad_centers in invalid_centers:
            with pytest.raises(ValueError):
                model.initialize(data, manual_init=bad_centers)

        returned = model.initialize(data, manual_init=valid_centers)
        assert torch.allclose(valid_centers, returned)