def _deserialize_sum(self, node, node_map):
    child_ids = node.sum.children
    # Resolve references to child nodes by ID.
    children = [node_map.get(child_id) for child_id in child_ids]
    # Check that all children have been resolved.
    assert None not in children, "Child node ID could not be resolved"
    sum_node = Sum(children=children, weights=node.sum.weights)
    sum_node.id = node.id
    return sum_node
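
# A minimal sketch of how _deserialize_sum might be exercised, assuming the
# serialized `node` is a protobuf-like message exposing `id`, `sum.children`
# (child IDs) and `sum.weights`; SimpleNamespace stands in for that message,
# and `deserializer` is a hypothetical instance of the enclosing class.
from types import SimpleNamespace

child_a, child_b = Sum(), Sum()  # stand-ins for already-deserialized children
child_a.id, child_b.id = 1, 2
node_map = {1: child_a, 2: child_b}
msg = SimpleNamespace(id=0,
                      sum=SimpleNamespace(children=[1, 2], weights=[0.4, 0.6]))
root = deserializer._deserialize_sum(msg, node_map)
assert root.id == 0 and root.weights == [0.4, 0.6]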
def create_sum(data=None,
               node_id=0,
               parent_id=0,
               pos=0,
               context=None,
               scope=None,
               split_rows=None,
               split_on_sum=True,
               **kwargs):
    assert split_rows is not None, "No split_rows lambda"
    assert scope is not None, "No scope"

    result = []

    # split_rows is expected to return a list of
    # (data_slice, scope_slice, proportion) triples.
    data_slices = split_rows(data, context, scope)

    # A single slice means clustering found no split; hand the data back to
    # the operation dispatcher instead of building a degenerate sum node.
    if len(data_slices) == 1:
        result.append((
            SplittingOperations.GET_NEXT_OP,
            {
                "data": data,
                "parent_id": parent_id,
                "pos": pos,
                "no_clusters": True,
                "scope": scope,
            },
        ))
        return None, result

    node = Sum()
    node.scope.extend(scope)
    node.id = node_id
    # assert parent.scope == node.scope

    for data_slice, scope_slice, proportion in data_slices:
        assert isinstance(scope_slice, list), "slice must be a list"

        # Pass the row slice down only if the sum actually splits the data;
        # otherwise every child sees the full data set.
        child_data = data_slice if split_on_sum else data

        # Reserve a child slot and weight; the subtask emitted below will
        # fill the slot once the child is built.
        node.children.append(None)
        node.weights.append(proportion)
        result.append((
            SplittingOperations.GET_NEXT_OP,
            {
                "data": child_data,
                "parent_id": node.id,
                "pos": len(node.children) - 1,
                "scope": scope,
            },
        ))

    return node, result
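
# A hedged sketch of calling create_sum directly, assuming numpy data and a toy
# split_rows that halves the rows; toy_split_rows and its fixed 0.5 proportions
# are illustrative stand-ins, not part of the original API.
import numpy as np

def toy_split_rows(data, context, scope):
    half = data.shape[0] // 2
    return [(data[:half], list(scope), 0.5), (data[half:], list(scope), 0.5)]

node, tasks = create_sum(data=np.zeros((10, 3)), node_id=1, parent_id=0, pos=0,
                         scope=[0, 1, 2], split_rows=toy_split_rows)
assert len(node.children) == 2 and len(tasks) == 2
# Each task is a (SplittingOperations.GET_NEXT_OP, params) tuple that a driver
# such as learn_structure below pushes onto its work queue.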
# Sum, assign_ids, SplittingOperations, IdCounter, Compress, Prune, is_valid
# and _op_lambdas are assumed to be defined in the enclosing module. A thread
# pool is assumed here: op_lambda_eval below is a closure, which a
# process-based pool could not pickle.
from collections import ChainMap
from multiprocessing.pool import ThreadPool as Pool
from queue import Queue

import tqdm


def learn_structure(dataset,
                    context,
                    op_lambdas=_op_lambdas,
                    prune=True,
                    validate=True,
                    compress=True,
                    num_worker_threads=30,
                    parallelized_ops=frozenset([
                        SplittingOperations.GET_NEXT_OP,
                        SplittingOperations.CREATE_PRODUCT_NODE,
                        SplittingOperations.NAIVE_FACTORIZATION,
                        SplittingOperations.CREATE_CONDITIONAL_NODE,
                        SplittingOperations.REMOVE_UNINFORMATIVE_FEATURES,
                    ]),
                    **kwargs):
    assert dataset is not None
    assert context is not None
    assert op_lambdas is not None

    # Node IDs are monotonically increasing but not necessarily consecutive.
    id_counter = IdCounter()

    # Temporary root: a sum node whose single child slot will hold the real
    # root once learning finishes.
    root = Sum()
    root.children.append(None)
    root.id = id_counter.increment()

    nodes = {root.id: root}

    tasks = Queue()
    tasks.put((
        SplittingOperations.GET_NEXT_OP,
        {
            "data": dataset,
            "parent_id": root.id,
            "parent_type": type(root),
            "pos": 0,
            "node_id": id_counter.increment(),
            "is_first": True
        },
    ))

    def op_lambda_eval(params):
        next_op, context, all_params = params
        func = op_lambdas.get(next_op, None)
        assert func is not None, "No lambda function associated with operation: %s" % (
            next_op)
        result = func(context=context, **all_params)
        return (all_params['parent_id'], all_params['pos']), result

    def handle_op_lambdas_result(result):
        (parent_id, pos), (node, subtasks) = result
        if node is not None:
            nodes[parent_id].children[pos] = node
            nodes[node.id] = node

        if subtasks is not None:
            assert isinstance(subtasks, list)
            for e in subtasks:
                assert isinstance(e, tuple)
                tasks.put(e)

    # Worklist loop: drain the serial task queue, then flush batched
    # parallelizable ops through the pool, until both are empty.
    parallelizable_tasks = []

    while True:

        while not tasks.empty():
            task = tasks.get()
            assert task is not None

            next_op, op_params = task

            # Task-specific params take precedence over global kwargs.
            all_params = ChainMap(op_params, kwargs)
            all_params['node_id'] = id_counter.increment()
            all_params['parent_type'] = type(nodes[all_params['parent_id']])

            if parallelized_ops is None or next_op not in parallelized_ops:
                # Serial ops execute inline on the main thread.
                op_result = op_lambda_eval((next_op, context, all_params))
                handle_op_lambdas_result(op_result)
            else:
                # Parallelizable ops are batched and flushed through the
                # pool once the serial queue has been drained.
                parallelizable_tasks.append((next_op, context, all_params))

        if parallelizable_tasks:
            with Pool(num_worker_threads) as pool:
                results = pool.imap(op_lambda_eval, parallelizable_tasks)
                for r in tqdm.tqdm(results, total=len(parallelizable_tasks)):
                    handle_op_lambdas_result(r)
            parallelizable_tasks.clear()

        if tasks.empty():
            break

    # The learned structure hangs off the temporary root's only child slot.
    node = root.children[0]
    assign_ids(node)

    if compress:
        node = Compress(node)
    if prune:
        node = Prune(node)
    if validate:
        valid, err = is_valid(node)
        assert valid, "invalid spn: " + err

    return node
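
# A hedged usage sketch, assuming an SPFlow-style Context with REAL meta-types
# and that the module-level _op_lambdas default maps every SplittingOperations
# member to a handler; the random dataset is purely illustrative.
import numpy as np
from spn.structure.Base import Context
from spn.structure.StatisticalTypes import MetaType

data = np.random.rand(200, 3)
ctx = Context(meta_types=[MetaType.REAL] * 3).add_domains(data)
spn = learn_structure(data, ctx, num_worker_threads=4)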