def test_k_means_initialization_grid_crf():
    """With a single hidden state per label, latent initialization is trivial.

    Nothing should happen: the hidden states returned by ``init_latent`` must
    equal the observed labels ``Y``.
    """
    X, Y = toy.generate_big_checker(n_samples=10)
    problem = LatentGridCRF(
        n_labels=2,
        n_states_per_label=1,
        inference_method='lp',
    )
    latent_init = problem.init_latent(X, Y)
    assert_array_equal(Y, latent_init)

def test_with_crosses_bad_init():
    """Fit a LatentSSVM on the crosses data from a deliberately degraded init.

    The hidden-state initialization from ``crf.init_latent`` is perturbed on
    a random ~30% of positions, then the model is fit and must still
    reproduce ``Y`` exactly on the training data.
    """
    # use less perfect initialization
    X, Y = toy.generate_crosses(n_samples=10, noise=5, n_crosses=1,
                                total_size=8)
    n_labels = 2
    crf = LatentGridCRF(n_labels=n_labels, n_states_per_label=2,
                        inference_method='lp')
    clf = LatentSSVM(problem=crf, max_iter=50, C=10. ** 3, verbose=2,
                     check_constraints=True, n_jobs=-1, break_on_bad=True)
    H_init = crf.init_latent(X, Y)
    # uniform in [0, 1) > .7 selects roughly 30% of the entries.
    # NOTE(review): unseeded global RNG, so the corruption differs per run;
    # consider seeding if this test turns out to be flaky.
    mask = np.random.uniform(size=H_init.shape) > .7
    # Round each selected hidden state down to the even state of its pair
    # (presumably the first of the two states owned by the same label with
    # n_states_per_label=2 -- confirm against LatentGridCRF). Floor division
    # is required: with true division (`/`, Python 3 semantics)
    # 2 * (h / 2) == h and the "bad init" would be a no-op.
    H_init[mask] = 2 * (H_init[mask] // 2)
    clf.fit(X, Y, H_init=H_init)
    Y_pred = clf.predict(X)
    assert_array_equal(np.array(Y_pred), Y)