コード例 #1
0
ファイル: multi_label.py プロジェクト: wattlebird/pystruct
# Multi-label classification setup: build a fully connected label graph,
# an edge-free (independent) model and a tree-structured label graph, then
# wrap the two structured models in one-slack structural SVMs.
# Switch datasets by editing the variable below.
dataset = "scene"
# dataset = "yeast"

if dataset == "yeast":
    # NOTE(review): sklearn.datasets.fetch_mldata was removed in
    # scikit-learn 0.22 (mldata.org is offline); this branch needs a
    # replacement loader (e.g. fetch_openml) on current scikit-learn.
    yeast = fetch_mldata("yeast")

    X = yeast.data
    # Append a constant column of ones so the linear model gets a bias term.
    X = np.hstack([X, np.ones((X.shape[0], 1))])
    # Densify the label matrix and transpose it so its rows line up with the
    # rows of X (both are split at the same index below). The builtin ``int``
    # replaces ``np.int``, a pure alias for it that NumPy 1.20 deprecated and
    # NumPy 1.24 removed.
    y = yeast.target.toarray().astype(int).T

    # Fixed split: first 1500 samples train, the remainder test.
    X_train, X_test = X[:1500], X[1500:]
    y_train, y_test = y[:1500], y[1500:]

else:
    scene = load_scene()
    X_train, X_test = scene['X_train'], scene['X_test']
    y_train, y_test = scene['y_train'], scene['y_test']

n_labels = y_train.shape[1]
# Edge list connecting every pair of labels: one (i, j) row per pair, i < j.
full = np.vstack(list(itertools.combinations(range(n_labels), 2)))
# Tree-shaped edge set estimated from the training labels
# (presumably a Chow-Liu tree -- see chow_liu_tree for the exact method).
tree = chow_liu_tree(y_train)

# The full graph has cycles, so it uses QPBO inference; the independent model
# gets no edges and 'unary' inference; the tree uses max-product.
full_model = MultiLabelClf(edges=full, inference_method='qpbo')
independent_model = MultiLabelClf(inference_method='unary')
tree_model = MultiLabelClf(edges=tree, inference_method="max-product")

# Both structured models share the same one-slack SSVM hyperparameters.
full_ssvm = OneSlackSSVM(full_model, inference_cache=50, C=.1, tol=0.01)

tree_ssvm = OneSlackSSVM(tree_model, inference_cache=50, C=.1, tol=0.01)
コード例 #2
0
ファイル: multi_label_tree.py プロジェクト: DATAQC/pystruct
# Fit a tree-structured multi-label model with a one-slack structural SVM.
# Switch datasets by editing the variable below.
dataset = "scene"
# dataset = "yeast"

if dataset == "yeast":
    # NOTE(review): sklearn.datasets.fetch_mldata was removed in
    # scikit-learn 0.22 (mldata.org is offline); this branch needs a
    # replacement loader (e.g. fetch_openml) on current scikit-learn.
    yeast = fetch_mldata("yeast")

    X = yeast.data
    # Append a constant column of ones so the linear model gets a bias term.
    X = np.hstack([X, np.ones((X.shape[0], 1))])
    # Densify the label matrix and transpose it so its rows line up with the
    # rows of X (both are split at the same index below). The builtin ``int``
    # replaces ``np.int``, a pure alias for it that NumPy 1.20 deprecated and
    # NumPy 1.24 removed.
    y = yeast.target.toarray().astype(int).T

    # Fixed split: first 1500 samples train, the remainder test.
    X_train, X_test = X[:1500], X[1500:]
    y_train, y_test = y[:1500], y[1500:]

else:
    scene = load_scene()
    X_train, X_test = scene['X_train'], scene['X_test']
    y_train, y_test = scene['y_train'], scene['y_test']

n_labels = y_train.shape[1]
# Edge list connecting every pair of labels: one (i, j) row per pair, i < j.
# Not used by the visible code below; kept as a module-level name.
full = np.vstack(list(itertools.combinations(range(n_labels), 2)))
# Tree-shaped edge set estimated from the training labels
# (presumably a Chow-Liu tree -- see chow_liu_tree for the exact method).
tree = chow_liu_tree(y_train)

# Max-product inference on the tree-structured label graph. An alternative
# that was tried here uses OpenGM's dynamic programming solver:
#     inference_method=('ogm', {'alg': 'dyn'})
tree_model = MultiLabelClf(edges=tree, inference_method='max-product')

tree_ssvm = OneSlackSSVM(tree_model, inference_cache=50, C=.1, tol=0.01)

print("fitting tree model...")
tree_ssvm.fit(X_train, y_train)
コード例 #3
0
ファイル: test_datasets.py プロジェクト: SachithS/UnitGener
def test_dataset_loading():
    """Smoke test: each dataset loader can be called without raising."""
    # Call the loaders in the original order; return values are ignored.
    for loader in (load_scene, load_letters, load_snakes):
        loader()