def _deserialize_sum(self, node, node_map):
    """Reconstruct a ``Sum`` node from its serialized protobuf form.

    Args:
        node: serialized node message; ``node.sum.children`` holds child
            node IDs and ``node.sum.weights`` the mixture weights.
        node_map: mapping of node ID -> already-deserialized node; children
            must have been deserialized before their parent.

    Returns:
        A ``Sum`` node with resolved children, weights, and the original id.
    """
    child_ids = node.sum.children
    # Resolve child references by ID (renamed from builtin-shadowing `id`).
    children = [node_map.get(child_id) for child_id in child_ids]
    # Every child must already be present in node_map.
    assert None not in children, "Child node ID could not be resolved"
    sum_node = Sum(children=children, weights=node.sum.weights)
    sum_node.id = node.id
    return sum_node
def create_sum(data=None, node_id=0, parent_id=0, pos=0, context=None, scope=None,
               split_rows=None, split_on_sum=True, **kwargs):
    """Create a Sum node by splitting ``data`` into row clusters.

    Args:
        data: data matrix for this subtree.
        node_id: id to assign to the created Sum node.
        parent_id: id of the parent node (propagated into retry tasks).
        pos: child slot of this node in its parent.
        context: learning context passed to ``split_rows``.
        scope: list of variable indices covered by this node.
        split_rows: callable ``(data, context, scope) -> [(slice, scope, w), ...]``.
        split_on_sum: when True, each child receives its own data slice;
            otherwise every child sees the full ``data``.

    Returns:
        ``(node, tasks)`` where ``node`` is the new Sum (or None when
        clustering failed) and ``tasks`` is a list of follow-up operations.
    """
    assert split_rows is not None, "No split_rows lambda"
    assert scope is not None, "No scope"

    result = []
    data_slices = split_rows(data, context, scope)

    if len(data_slices) == 1:
        # Clustering produced a single cluster: no Sum node is created.
        # Re-enqueue with no_clusters=True so the caller picks another op.
        result.append((
            SplittingOperations.GET_NEXT_OP,
            {
                "data": data,
                "parent_id": parent_id,
                "pos": pos,
                "no_clusters": True,
                "scope": scope,
            },
        ))
        # BUGFIX: return the same (node, tasks) shape as the success path;
        # returning the bare list made the caller's tuple-unpack raise.
        return None, result

    node = Sum()
    node.scope.extend(scope)
    node.id = node_id
    # assert parent.scope == node.scope

    for data_slice, scope_slice, proportion in data_slices:
        assert isinstance(scope_slice, list), "slice must be a list"

        # Children either see their own row slice or the full data.
        child_data = data_slice if split_on_sum else data

        # Reserve the child slot; it is filled when the subtask resolves.
        node.children.append(None)
        node.weights.append(proportion)

        result.append((
            SplittingOperations.GET_NEXT_OP,
            {
                "data": child_data,
                "parent_id": node.id,
                "pos": len(node.children) - 1,
                "scope": scope,
            },
        ))

    return node, result
def learn_structure(dataset, context, op_lambdas=_op_lambdas, prune=True, validate=True,
                    compress=True, num_worker_threads=30,
                    parallelized_ops=set([
                        SplittingOperations.GET_NEXT_OP,
                        SplittingOperations.CREATE_PRODUCT_NODE,
                        SplittingOperations.NAIVE_FACTORIZATION,
                        SplittingOperations.CREATE_CONDITIONAL_NODE,
                        SplittingOperations.REMOVE_UNINFORMATIVE_FEATURES
                    ]), **kwargs):
    """Learn an SPN structure by draining a queue of splitting operations.

    Tasks of the form ``(operation, params)`` are pulled from a work queue
    and dispatched to the lambda registered for that operation; each lambda
    returns a node to attach plus follow-up tasks to enqueue.

    Args:
        dataset: full training data matrix.
        context: learning context forwarded to every op lambda.
        op_lambdas: mapping of SplittingOperations -> callable.
        prune/validate/compress: post-processing switches (currently unused;
            see the commented block at the end).
        num_worker_threads: pool size for parallel dispatch.
        parallelized_ops: operations eligible for the serial fast path
            (see `serial_only` note below).
        **kwargs: extra parameters layered under each task's own params.

    Returns:
        The root node of the learned structure.
    """
    assert dataset is not None
    assert context is not None
    assert op_lambdas is not None

    # Non-consecutive but monotonically increasing id generator.
    id_counter = IdCounter()

    # Placeholder root; the learned structure hangs off root.children[0].
    root = Sum()
    root.children.append(None)
    root.id = id_counter.increment()

    nodes = {root.id: root}

    tasks = Queue()
    tasks.put((
        SplittingOperations.GET_NEXT_OP,
        {
            "data": dataset,
            "parent_id": root.id,
            "parent_type": type(root),
            "pos": 0,
            'node_id': id_counter.increment(),
            "is_first": True
        },
    ))

    def op_lambda_eval(params):
        # Dispatch one (op, context, params) triple to its registered lambda.
        next_op, task_context, all_params = params
        func = op_lambdas.get(next_op, None)
        assert func is not None, "No lambda function associated with operation: %s" % (
            next_op)
        # (Previous code branched on func == create_leaf_node with two
        # identical call sites; collapsed to one.)
        result = func(context=task_context, **all_params)
        return (all_params['parent_id'], all_params['pos']), result

    def handle_op_lambdas_result(result):
        # Attach the produced node into its parent slot and enqueue subtasks.
        (parent_id, pos), (node, subtasks) = result
        if node is not None:
            nodes[parent_id].children[pos] = node
            nodes[node.id] = node
        if subtasks is not None:
            assert isinstance(subtasks, list)
            for e in subtasks:
                assert isinstance(e, tuple)
                tasks.put(e)

    # NOTE(review): parallel dispatch was disabled via an inline `if True or`
    # override; kept disabled here under a named flag. Flip to False once the
    # op lambdas are confirmed safe to ship through the worker Pool.
    serial_only = True

    parallelizable_tasks = []
    while True:
        while not tasks.empty():
            task = tasks.get()
            assert task is not None
            next_op, op_params = task
            # Layer task-specific params over the global kwargs.
            all_params = ChainMap(op_params, kwargs)
            all_params['node_id'] = id_counter.increment()
            all_params['parent_type'] = type(nodes[all_params['parent_id']])

            if serial_only or (parallelized_ops is not None and next_op in parallelized_ops):
                # Evaluate inline and integrate the result immediately.
                handle_op_lambdas_result(op_lambda_eval((next_op, context, all_params)))
            else:
                parallelizable_tasks.append((next_op, context, all_params))

        if parallelizable_tasks:
            # Only spin up the worker pool when there is actual work for it
            # (previously a fresh Pool was created every iteration regardless).
            with Pool(num_worker_threads) as pool:
                results = pool.imap(op_lambda_eval, parallelizable_tasks)
                for r in tqdm.tqdm(results, total=len(parallelizable_tasks)):
                    handle_op_lambdas_result(r)
            parallelizable_tasks.clear()

        if tasks.empty():
            break

    node = root.children[0]
    assign_ids(node)
    #
    # if compress:
    #     node = Compress(node)
    # if prune:
    #     node = Prune(node)
    #
    # if validate:
    #     valid, err = is_valid(node)
    #     assert valid, "invalid spn: " + err

    return node