Example #1
0
def _rescaled_masses(weights, log_values):
    """Return ``(offset, masses)`` with ``masses[i] = weights[i] * exp(log_values[i] - offset)``.

    ``offset`` is the largest log value (or 0.0 when there are no values or
    the largest one is not finite), so ``offset + log(sum(masses))`` equals
    ``log(sum(weights[i] * exp(log_values[i])))``.  Shifting by the maximum
    avoids the underflow to 0.0 that exponentiating very negative log values
    directly would cause.
    """
    if not log_values:
        return 0.0, np.zeros(0)
    offset = float(np.max(log_values))
    if not np.isfinite(offset):
        # All terms are -inf (or a term is +inf / nan); subtracting a
        # non-finite offset would itself produce nan, so do not shift.
        offset = 0.0
    masses = np.asarray(weights, dtype=float) * np.exp(
        np.asarray(log_values, dtype=float) - offset)
    return offset, masses


def sum_condition(node, children, input_vals=None, scope=None):
    """Condition a Sum node on the features in ``scope``.

    Args:
        node: the Sum node being conditioned; its ``scope``, ``weights`` and
            children order are read.
        children: one entry per child of ``node`` (same order as
            ``node.weights``), each indexable as ``(new_child, log_value)``.
            ``new_child`` is falsy when that child was removed entirely;
            ``log_value`` is a log-domain value (it is exponentiated here).
        input_vals: unused by this function.
        scope: feature ids to condition on.  Any iterable is accepted;
            ``None`` means "condition on nothing".

    Returns:
        ``(Copy(node), 0)`` when ``scope`` does not overlap ``node.scope``.
        ``(None, log_mass)`` when every feature of ``node`` is in ``scope``;
        ``log_mass`` is ``log(sum(w_i * exp(log_value_i)))`` over the removed
        children.
        Otherwise ``(new_sum, log_mass)``.  ``new_sum`` is a fresh Sum over
        the remaining features with the surviving children and renormalized
        weights.  ``log_mass`` is the log of the unnormalized retained mass.

    Raises:
        AssertionError: if the renormalized weights contain nan (e.g. all
            surviving children carry zero mass).
    """
    # Normalize to a set so lists/tuples work and None means "nothing".
    scope = set() if scope is None else set(scope)
    if not scope.intersection(node.scope):
        return Copy(node), 0

    new_node = Sum()
    new_node.scope = list(set(node.scope) - scope)

    kept_weights, kept_logs = [], []
    dropped_weights, dropped_logs = [], []
    for i, c in enumerate(children):
        if c[0]:
            new_node.children.append(c[0])
            kept_weights.append(node.weights[i])
            kept_logs.append(c[1])
        else:
            dropped_weights.append(node.weights[i])
            dropped_logs.append(c[1])

    # Work with max-shifted masses so tiny child probabilities do not
    # underflow to zero before normalization; the total is computed once.
    offset, masses = _rescaled_masses(kept_weights, kept_logs)
    total = masses.sum()
    new_node.weights = list(masses / total)
    assert np.all(np.logical_not(np.isnan(
        new_node.weights))), 'Found nan weights'

    if not new_node.scope:
        dropped_offset, dropped_masses = _rescaled_masses(dropped_weights, dropped_logs)
        return None, dropped_offset + np.log(dropped_masses.sum())
    return new_node, offset + np.log(total)
Example #2
0
def get_flat_spn(spn, target_id):
    """Flatten ``spn`` into a root Sum of Products with respect to ``target_id``.

    The SPN is walked from the root; every root-to-stop path becomes one
    Product child of the returned root Sum.

    * At a Sum node the walk branches into every child, multiplying the path
      probability by that child's weight.
    * At a Product node each child is handled by scope:
      - A child that does not contain ``target_id`` is collapsed feature by
        feature.  ``get_nodes_with_weight`` and the mixing table
        (currently Categorical only) merge it into one node per feature.
      - A child whose scope is exactly the target is deep-copied and ends
        the path.
      - A child that contains the target among other features is descended
        into.

    Each finished path contributes its collected nodes as one Product,
    weighted by the path probability.

    Args:
        spn: root node of the SPN (a Sum or Product).
        target_id: feature id that is kept un-mixed.

    Returns:
        The flat SPN after ``assign_ids`` and ``Prune``.

    Raises:
        Exception: if the walk reaches a node that is neither Sum nor
            Product, or a Product with no child whose scope contains
            ``target_id`` (e.g. ``target_id`` is not in ``spn.scope``).
        KeyError: if a collapsed node type has no entry in the mixing table.
        AssertionError: if the resulting structure fails ``is_valid``.
    """
    from copy import deepcopy

    from spn.algorithms.TransformStructure import Prune
    from spn.algorithms.Validity import is_valid
    from spn.structure.Base import Product, Sum, assign_ids
    # Categorical was used as the mixing-table key without being imported
    # in this function; the key must be the SPN leaf class itself.
    from spn.structure.leaves.parametric.Parametric import Categorical

    from simple_spn.internal.MixDistributions import mix_categorical

    # Maps a node type to a function that merges the weighted nodes returned
    # by get_nodes_with_weight (presumably (weight, node) pairs -- the node
    # type is read from index 1) into one node.
    distribution_mix = {Categorical: mix_categorical}

    flat_spn = Sum()
    flat_spn.scope = list(spn.scope)

    def create_flat_spn_recursive(node, distribution_mix, prob=1.0, independent_nodes=None):
        # A mutable default list would be shared between calls; use a fresh one.
        if independent_nodes is None:
            independent_nodes = []

        if isinstance(node, Sum):
            # Each branch gets its own copy of the nodes collected so far.
            for i, c in enumerate(node.children):
                forwarded_weight = node.weights[i] * prob
                create_flat_spn_recursive(c, distribution_mix, forwarded_weight, independent_nodes.copy())

        elif isinstance(node, Product):
            stop = False
            next_node = None

            for c in node.children:
                if target_id in c.scope:
                    if len(c.scope) == 1:
                        # The target's own distribution ends this path.
                        stop = True
                        independent_nodes.append(deepcopy(c))
                    else:
                        next_node = c
                else:
                    # Collapse every non-target feature of this child into a
                    # single mixed node per feature.
                    for feature_id in c.scope:
                        weighted_nodes = get_nodes_with_weight(c, feature_id)
                        t_node = type(weighted_nodes[0][1])
                        mixed_node = distribution_mix[t_node](weighted_nodes)
                        independent_nodes.append(mixed_node)

            if stop:
                flat_spn.weights.append(prob)
                prod = Product(children=independent_nodes)
                prod.scope = list(spn.scope)
                flat_spn.children.append(prod)
            elif next_node is None:
                raise Exception(
                    "Product node has no child whose scope contains target_id %r" % (target_id,))
            else:
                create_flat_spn_recursive(next_node, distribution_mix, prob, independent_nodes)

        else:
            raise Exception("Can only iterate over Sum and Product nodes")

    create_flat_spn_recursive(spn, distribution_mix)
    assign_ids(flat_spn)
    flat_spn = Prune(flat_spn)
    valid, err = is_valid(flat_spn)
    assert valid, err

    return flat_spn