def make_first_level(all_features, complete_x, loss, x_size, y_test, errors,
                     loss_type, top_k, alpha, w):
    """Create and evaluate one single-feature slice per entry of ``all_features``.

    First-level slices are enumerated the "classic" way: every slice is
    materialised via ``Node.process_slice`` and scored on the data, instead
    of being filtered through bound analysis.

    For each feature, in order, this function:
      * builds a ``Node`` whose parents/attributes are the single pair
        ``(feature, index)``;
      * names it, keys it, and registers it in the returned key -> node dict;
      * processes the slice with ``loss_type`` and scores it with
        ``opt_fun(node.loss, node.size, loss, x_size, w)``;
      * sets ``c_upper`` to that exact score;
      * calls ``node.print_debug(top_k, 0)``;
      * passes it to ``top_k.add_new_top_slice`` when
        ``node.check_constraint(top_k, x_size, alpha)`` is truthy.

    Returns:
        ``(first_level, all_nodes)``: the created nodes in feature order, and
        a dict mapping each node's key to the node.
    """
    level_nodes = []
    registry = {}
    for position, feature in enumerate(all_features):
        node = Node(complete_x, loss, x_size, y_test, errors)
        # A first-level slice is described by exactly one (feature, index) pair.
        node.parents = [(feature, position)]
        node.attributes.append((feature, position))
        node.name = node.make_name()
        # The key id is the number of nodes registered so far.
        node.key = node.make_key(len(registry))
        registry[node.key] = node
        node.process_slice(loss_type)
        node.score = opt_fun(node.loss, node.size, loss, x_size, w)
        # The slice is fully evaluated, so c_upper is just the concrete score.
        node.c_upper = node.score
        level_nodes.append(node)
        node.print_debug(top_k, 0)
        # Only nodes that satisfy the constraint become top-k candidates.
        if node.check_constraint(top_k, x_size, alpha):
            top_k.add_new_top_slice(node)
    return level_nodes, registry
def make_first_level(all_features, complete_x, loss, x_size, y_test, errors,
                     loss_type, w, alpha, top_k):
    """Create, evaluate and bound one single-feature slice per feature.

    NOTE(review): this definition has the same name as another
    ``make_first_level`` with a different parameter order (``top_k, alpha, w``).
    If both live in the same module, this one shadows the other. Confirm
    which one callers expect.

    For each feature, in order, this function:
      * builds a ``Node`` whose parents/attributes are the single pair
        ``(feature, index)``;
      * names it, keys it, and registers it in the returned key -> node dict;
      * processes the slice with ``loss_type``;
      * fills its bounds directly from the computed metrics
        (``s_upper = size``, ``s_lower = 0``, ``e_upper = loss``,
        ``e_max_upper = e_max``);
      * scores it with ``opt_fun(node.loss, node.size, loss, x_size, w)``
        and sets ``c_upper`` to that score;
      * calls ``node.print_debug(top_k, 0)``;
      * passes it to ``top_k.add_new_top_slice`` when its score exceeds 1
        and its size is at least ``x_size / alpha``.

    Returns:
        ``(first_level, all_nodes)``: the created nodes in feature order, and
        a dict mapping each node's key to the node.
    """
    registry = {}
    level_nodes = []
    for position, feature in enumerate(all_features):
        node = Node(complete_x, loss, x_size, y_test, errors)
        # A first-level slice is described by exactly one (feature, index) pair.
        node.parents = [(feature, position)]
        node.attributes.append((feature, position))
        node.name = node.make_name()
        # The key id is the number of nodes registered so far.
        node.key = node.make_key(len(registry))
        registry[node.key] = node
        node.process_slice(loss_type)
        # First-level bounds come straight from the concrete metrics.
        node.s_upper = node.size
        node.s_lower = 0
        node.e_upper = node.loss
        node.e_max_upper = node.e_max
        node.score = opt_fun(node.loss, node.size, loss, x_size, w)
        node.c_upper = node.score
        level_nodes.append(node)
        node.print_debug(top_k, 0)
        # Candidate test: the score must beat 1 and the slice must be large
        # enough. The size test is only evaluated when the score test passes.
        if node.score > 1 and node.size >= x_size / alpha:
            top_k.add_new_top_slice(node)
    return level_nodes, registry