def _compare_with_sklearn(self, cov_type):
        """Checks a fitted GMM against hard-coded reference values.

        The reference assignments, means and covariances are literals in this
        method (labelled as the sklearn results); nothing is recomputed here.

        Args:
          cov_type: Forwarded to the GMM as `covariance_type`. For 'full' the
            complete reference covariance matrices are compared; for any
            other value only their diagonals are compared.
        """
        num_steps = 40
        np.random.seed(5)

        # Reference (sklearn) values.
        expected_assignments = np.asarray([0, 0, 1, 0, 0, 0, 1, 0, 0, 1])
        expected_means = np.asarray([
            [144.83417719, 254.20130341],
            [274.38754816, 353.16074346],
        ])
        expected_covs = np.asarray([
            [[395.0081194, -4.50389512],
             [-4.50389512, 408.27543989]],
            [[385.17484203, -31.27834935],
             [-31.27834935, 391.74249925]],
        ])

        # Model under test (skflow).
        model = gmm_lib.GMM(
            self.num_centers,
            initial_clusters=self.initial_means,
            covariance_type=cov_type,
            config=run_config.RunConfig(tf_random_seed=2))
        model.fit(input_fn=self.input_fn(), steps=num_steps)

        # Only the first ten points are compared against the reference labels.
        first_ten = self.points[:10, :]
        predicted = list(model.predict_assignments(
            input_fn=self.input_fn(points=first_ten, batch_size=10)))

        self.assertAllClose(expected_assignments,
                            np.ravel(predicted).astype(int))
        self.assertAllClose(expected_means, model.clusters())

        if cov_type == 'full':
            self.assertAllClose(expected_covs, model.covariances(), rtol=0.01)
            return

        # Non-full covariance: compare each model row with the diagonal of
        # the matching reference matrix.
        for idx in (0, 1):
            reference_diag = np.diag(expected_covs[idx])
            self.assertAllClose(reference_diag,
                                model.covariances()[idx, :],
                                rtol=0.01)
 def test_weights(self):
     """Tests that the weights have one entry per center."""
     config = run_config.RunConfig(tf_random_seed=2)
     estimator = gmm_lib.GMM(
         self.num_centers,
         initial_clusters=self.initial_means,
         random_seed=4,
         config=config)
     estimator.fit(input_fn=self.input_fn(), steps=0)

     # The weights are expected to be 1-D with length `num_centers`.
     observed_shape = list(estimator.weights().shape)
     self.assertAllEqual(observed_shape, [self.num_centers])
 def test_fit(self):
     """Tests that the score after more fitting steps is higher."""
     estimator = gmm_lib.GMM(
         self.num_centers,
         initial_clusters='random',
         random_seed=4,
         config=run_config.RunConfig(tf_random_seed=2))

     def current_score():
         # Score in a single step using a batch of `num_points` points.
         return estimator.score(
             input_fn=self.input_fn(batch_size=self.num_points), steps=1)

     estimator.fit(input_fn=self.input_fn(), steps=1)
     score_after_one_step = current_score()

     estimator.fit(input_fn=self.input_fn(), steps=10)
     score_after_more_steps = current_score()

     self.assertLess(score_after_one_step, score_after_more_steps)
    def test_infer(self):
        """Tests predicted assignments on points drawn around fitted centers."""
        estimator = gmm_lib.GMM(
            self.num_centers,
            initial_clusters=self.initial_means,
            random_seed=4,
            config=run_config.RunConfig(tf_random_seed=2))
        estimator.fit(input_fn=self.input_fn(), steps=60)
        fitted_centers = estimator.clusters()

        # Build a small test set from the fitted centers, together with the
        # assignments returned by the fixture helper.
        sample_count = 40
        samples, expected = self.make_random_points(
            fitted_centers, sample_count)

        predicted = np.ravel(list(estimator.predict_assignments(
            input_fn=self.input_fn(points=samples, batch_size=sample_count))))
        self.assertAllEqual(expected, predicted)
    def test_random_input_large(self):
        """Tests that fitting on random input yields no NaN cluster values.

        The input has `num_classes` rows of 100 random values each, and the
        GMM is built with `num_classes` clusters and full covariances.
        """
        np.random.seed(5)
        num_classes = 20
        num_steps = 5  # that should be enough to know whether this diverges

        # Draw the values one scalar at a time, row by row.
        samples = np.array(
            [[np.random.random() for _ in range(100)]
             for _ in range(num_classes)],
            dtype=np.float32)

        estimator = gmm_lib.GMM(
            num_classes,
            covariance_type='full',
            config=run_config.RunConfig(tf_random_seed=2))

        def input_fn():
            # Features as a float32 constant; the second element is None.
            features = constant_op.constant(samples.astype(np.float32))
            return features, None

        estimator.fit(input_fn=input_fn, steps=num_steps)

        has_nan = np.isnan(estimator.clusters()).any()
        self.assertFalse(has_nan)
 def test_queues(self):
     """Fits a two-cluster 'diag' GMM for a single step.

     There are no assertions; the test passes if `fit` returns.
     NOTE(review): presumably `self.input_fn()` is queue-backed here, going
     by the test name -- confirm against the fixture.
     """
     estimator = gmm_lib.GMM(2, covariance_type='diag')
     estimator.fit(input_fn=self.input_fn(), steps=1)