Beispiel #1
0
def add_typed_leaves_text_support():
    """Register text conversion for the typed-mixture node classes.

    First calls ``add_parametric_text_support()``, then registers a
    node-to-string converter and a string-to-SPN grammar rule for both
    ``TypeMixture`` and ``TypeMixtureUnconstrained``.
    """
    # Grammar rules handed to add_str_to_spn; each one describes a weighted
    # sum of child nodes wrapped in the class name.
    constrained_grammar = """typemixture: "TypeMixture(" [PARAMNAME]"," [DECIMAL "*" node ("+" DECIMAL "*" node)*] ")" """
    unconstrained_grammar = """typemixtureunconstrained: "TypeMixtureUnconstrained(" [PARAMNAME]"," [DECIMAL "*" node ("+" DECIMAL "*" node)*] ")" """

    add_parametric_text_support()

    # Node -> text direction.
    add_node_to_str(TypeMixture, type_mixture_to_str)
    add_node_to_str(TypeMixtureUnconstrained, type_mixture_uncon_to_str)

    # Text -> node direction.
    add_str_to_spn(
        "typemixture",
        type_mixture_tree_to_spn,
        constrained_grammar,
        TypeMixture,
    )
    add_str_to_spn(
        "typemixtureunconstrained",
        type_mixture_uncon_tree_to_spn,
        unconstrained_grammar,
        TypeMixtureUnconstrained,
    )
Beispiel #2
0
from spn.structure.leaves.histogram.Text import add_histogram_text_support
from spn.structure.leaves.parametric.Symbolic import add_parametric_symbolic_support
from spn.structure.leaves.parametric.Text import add_parametric_text_support
from spn.structure.leaves.piecewise.Text import add_piecewise_text_support
from spn.structure.leaves.cltree.Text import add_cltree_text_support

# These calls run as top-level statements, i.e. once when this module is
# executed/imported. Each one invokes the registration hook imported above
# for one leaf family (parametric, piecewise, histogram, cltree) plus the
# symbolic support for parametric leaves.
# NOTE(review): presumably the hooks populate global converter registries;
# confirm in the respective spn.structure.leaves.* modules.
add_parametric_text_support()
add_piecewise_text_support()
add_histogram_text_support()
add_cltree_text_support()
add_parametric_symbolic_support()
def train_spn(window_size=3,
              min_instances_slice=10000,
              features=None,
              number_of_classes=3):
    """Train (or load a cached) SPN on windowed data and score it on a hold-out split.

    The data returned by ``get_data_in_window`` is treated as a 2-D array whose
    first ``len(features) * window_size**2`` columns are real-valued features
    and whose last ``window_size**2`` columns are discrete per-pixel labels.
    The label of the window's centre pixel is the prediction target.

    Args:
        window_size: Edge length of the square pixel window.
        min_instances_slice: Forwarded to ``learn_parametric`` and to the
            SPN cache helpers ``load_spn`` / ``save_spn``.
        features: Feature identifiers forwarded to ``get_data_in_window``;
            defaults to ``[20, 120]``.
        number_of_classes: Number of candidate labels ``0 .. n-1`` evaluated
            for the centre pixel; ``3`` selects the three-class data set.

    Returns:
        Tuple ``(spn, accuracy)`` where ``accuracy`` is the fraction of test
        rows whose centre-pixel label equals the most likely class.
    """
    if features is None:
        features = [20, 120]

    add_parametric_inference_support()
    add_parametric_text_support()

    data = get_data_in_window(window_size=window_size,
                              features=features,
                              three_classes=number_of_classes == 3)

    pixels_per_window = window_size * window_size
    number_of_features = len(features)
    feature_columns = number_of_features * pixels_per_window
    # Absolute column index of the centre pixel's label. Computed once so the
    # stratification, the inference overwrite and the scoring all address the
    # same column (previously inference wrote to a neighbouring column and the
    # scoring hard-coded -5, which is only right for window_size == 3).
    center_column = feature_columns + int(pixels_per_window / 2)

    sss = sk.model_selection.StratifiedShuffleSplit(test_size=0.2,
                                                    train_size=0.8,
                                                    random_state=42)
    # split() may yield several partitions; as before, the last one is kept.
    for train_index, test_index in sss.split(data[:, 0:feature_columns],
                                             data[:, center_column]):
        X_train, X_test = data[train_index], data[test_index]

    # Feature columns are modelled as Gaussians, label columns as Categoricals.
    context_list = ([MetaType.REAL] * feature_columns +
                    [MetaType.DISCRETE] * pixels_per_window)
    parametric_list = ([Gaussian] * feature_columns +
                       [Categorical] * pixels_per_window)

    ds_context = Context(meta_types=context_list)
    ds_context.add_domains(data)
    ds_context.parametric_types = parametric_list

    spn = load_spn(window_size, features, min_instances_slice,
                   number_of_classes)
    if spn is None:
        # No cached model: build a mixture with one branch per
        # (label column, label value) pair, weighted by training-slice size.
        spn = Sum()
        for class_pixel in tqdm(range(-pixels_per_window, 0)):
            for label in np.unique(data[:, class_pixel]):
                train_data = X_train[X_train[:, class_pixel] == label, :]
                if train_data.shape[0] == 0:
                    # The label occurs in `data` but not in the training
                    # split; an empty slice cannot be learned from and would
                    # only add a zero-weight child.
                    continue
                branch = learn_parametric(
                    train_data,
                    ds_context,
                    min_instances_slice=min_instances_slice)
                spn.children.append(branch)
                spn.weights.append(train_data.shape[0])

        # Every branch is learned over the same columns, so the scope of the
        # last one is reused for the root.
        spn.scope.extend(branch.scope)
        spn.weights = (np.array(spn.weights) / sum(spn.weights)).tolist()

        assign_ids(spn)
        save_spn(spn, window_size, features, min_instances_slice,
                 number_of_classes)

    # Column i holds the log-likelihood of each test row with its centre
    # label replaced by class i; every column is filled below.
    res = np.empty((X_test.shape[0], number_of_classes))

    for i in tqdm(range(number_of_classes)):
        tmp = X_test.copy()
        tmp[:, center_column] = i
        res[:, i] = log_likelihood(spn, tmp)[:, 0]

    predicted_classes = np.argmax(res, axis=1)

    correct_predicted = int(
        np.count_nonzero(X_test[:, center_column] == predicted_classes))
    accuracy = correct_predicted / X_test.shape[0]
    return spn, accuracy
Beispiel #4
0
 def setUp(self):
     # Fixture hook: calls add_parametric_text_support() and nothing else.
     # NOTE(review): presumably a unittest.TestCase.setUp override that runs
     # before each test; the enclosing class is not visible here - confirm.
     add_parametric_text_support()