def test(self):
    """Sampling a Dirichlet built from the hyper-parameters' ``alpha`` yields shape (K,).

    Constructs ``HyperParameters`` for a D=3, K=4 model, wraps
    ``hyper_params.alpha`` in a ``Dirichlet`` distribution, draws one sample
    and checks that it has exactly one entry per mixture component.
    """
    D = 3
    K = 4
    hyper_params = pa.HyperParameters(dim=D, k=K, nu=D * torch.ones(K))
    d = Dirichlet(hyper_params.alpha)
    s = d.sample()
    # assertEqual (rather than assertTrue(a == b)) reports both shapes on failure.
    self.assertEqual(s.size(), (K,))
def test(self):
    """``QpiUpdater.update`` produces an ``alpha`` with one entry per component.

    Feeds an (N, D) dataset and an (N, K) ``eta`` matrix, both filled with
    ``arange`` values (only shapes are checked, not numerical results), and
    verifies that ``updater.alpha`` has shape (K,).
    """
    N = 2
    D = 3
    K = 4
    # Arbitrary deterministic values; presumably eta plays the role of
    # per-sample component weights -- only its (N, K) shape matters here.
    eta = torch.arange(N * K, dtype=torch.float32).reshape(N, K)
    dataset = torch.arange(N * D, dtype=torch.float32).reshape(N, D)
    hyper_params = pa.HyperParameters(dim=D, k=K, nu=D * torch.ones(K))
    updater = QpiUpdater(hyper_params)
    updater.update(dataset, eta)
    # assertEqual (rather than assertTrue(a == b)) reports both shapes on failure.
    self.assertEqual(updater.alpha.size(), (K,))
def test(self):
    """``QmuUpdater.update`` produces ``beta`` of shape (K,) and ``m`` of shape (K, D).

    Uses an (N, D) dataset and an (N, K) ``eta`` matrix filled with
    ``arange`` values; only the output shapes are checked.
    """
    K = 3
    D = 2
    N = 4
    dataset = torch.arange(N * D, dtype=torch.float32).reshape(N, D)
    eta = torch.arange(N * K, dtype=torch.float32).reshape(N, K)
    hyper_params = pa.HyperParameters(dim=D, k=K, nu=D * torch.ones(K))
    updater = QmuUpdater(hyper_params)
    updater.update(dataset, eta)
    # assertEqual (rather than assertTrue(a == b)) reports both shapes on failure.
    self.assertEqual(updater.beta.size(), (K,))
    self.assertEqual(updater.m.size(), (K, D))
def test(self):
    """``QlambdaUpdater.update`` produces ``W`` of shape (K, D, D) and ``nu`` of shape (K,).

    Supplies the dataset (N, D), ``eta`` (N, K), ``beta`` (K,) and ``m``
    (K, D) inputs, all filled with ``arange`` values, and checks only the
    shapes of the updated parameters.
    """
    N = 2
    D = 3
    K = 4
    eta = torch.arange(N * K, dtype=torch.float32).reshape(N, K)
    beta = torch.arange(K, dtype=torch.float32)
    m = torch.arange(K * D, dtype=torch.float32).reshape(K, D)
    dataset = torch.arange(N * D, dtype=torch.float32).reshape(N, D)
    hyper_params = pa.HyperParameters(dim=D, k=K, nu=D * torch.ones(K))
    updater = QlambdaUpdater(hyper_params)
    updater.update(dataset, eta, beta, m)
    # assertEqual (rather than assertTrue(a == b)) reports both shapes on failure.
    self.assertEqual(updater.W.size(), (K, D, D))
    self.assertEqual(updater.nu.size(), (K,))
plt.scatter(dataset[:, 0], dataset[:, 1], marker='.', c=colors) plt.scatter(xs.ravel(), ys.ravel(), marker=".", c=pcolors, alpha=0.1) plt.xlim(X_MIN, X_MAX) plt.ylim(Y_MIN, Y_MAX) plt.savefig('./predict.jpg') def make_initial_positions_with_kmeans(dataset, k): p = cl.KMeans(n_clusters=k).fit(dataset) return p.cluster_centers_ if __name__ == "__main__": try: hyper_params = pa.HyperParameters(dim=DIM, k=K, nu=NU) qs_updater = qs.QsUpdater() qp_updater = qp.QpiUpdater(hyper_params) qm_updater = qm.QmuUpdater(hyper_params) ql_updater = ql.QlambdaUpdater(hyper_params) dataset = ds.make_dataset_0(OBS_NUM, DIM, K) std, mean = torch.std_mean(dataset, dim=0) dataset = (dataset - mean) / std display_graph(dataset) cs = make_initial_positions_with_kmeans(dataset, K) # initialize mu qm_updater.m = torch.tensor(cs).float() prev_m = qm_updater.m.clone()