def __init__(self, dsl_dict, edge_cost_dict, input_type, output_type, input_size, output_size,
             max_num_units, min_num_units, max_num_children, max_depth, penalty, ite_beta=1.0):
    """Record the search configuration and create the root program node.

    Every argument is stored unchanged as an attribute of the same name.
    The root node wraps a ``dsl.StartFunction`` built from the input/output
    types and sizes, with ``num_units`` set to ``max_num_units``.

    Args:
        dsl_dict: DSL description, stored as ``self.dsl_dict``.
        edge_cost_dict: stored as ``self.edge_cost_dict``.
        input_type, output_type: type signature passed to the start function.
        input_size, output_size: sizes passed to the start function.
        max_num_units: upper unit limit; also the start function's ``num_units``.
        min_num_units, max_num_children, max_depth, penalty: stored as-is.
        ite_beta: stored as-is (defaults to 1.0). Presumably a weight for
            if-then-else handling; TODO confirm against its users.
    """
    # Grammar description and edge costs.
    self.dsl_dict, self.edge_cost_dict = dsl_dict, edge_cost_dict

    # Type and shape signature of the programs being searched for.
    self.input_type, self.output_type = input_type, output_type
    self.input_size, self.output_size = input_size, output_size

    # Search limits and scoring parameters.
    self.max_num_units, self.min_num_units = max_num_units, min_num_units
    self.max_num_children = max_num_children
    self.max_depth = max_depth
    self.penalty = penalty
    self.ite_beta = ite_beta

    # The root is always a StartFunction wrapped in a ProgramNode. The trailing
    # positional arguments (0, None, 0, 0, 0) are interpreted by ProgramNode.
    root_function = dsl.StartFunction(
        input_type=input_type,
        output_type=output_type,
        input_size=input_size,
        output_size=output_size,
        num_units=max_num_units,
    )
    self.root_node = ProgramNode(root_function, 0, None, 0, 0, 0)
def enumerative_synthesis(self, graph, enumeration_depth, typesig=None, input_size=None, output_size=None):
    """Enumerate every fully symbolic program reachable within ``enumeration_depth``.

    The program graph is expanded depth-first from a root node. Each program
    for which ``graph.is_fully_symbolic`` is true is collected and not
    expanded further. Other programs are expanded only while their node's
    ``depth`` is below ``enumeration_depth``. Nodes are deduplicated by
    ``print_program(..., ignore_constants=True)``, so each distinct printed
    structure is visited at most once.

    While enumerating, ``graph.max_depth`` is temporarily set to
    ``enumeration_depth``. It is restored afterwards, even if an exception is
    raised.

    Args:
        graph: program graph exposing ``root_node``, ``max_depth``,
            ``max_num_units``, ``is_fully_symbolic`` and ``get_all_children``.
        enumeration_depth: depth limit for expansion; also the temporary
            value of ``graph.max_depth``.
        typesig: optional ``(input_type, output_type)`` pair. When given, a
            fresh ``dsl.StartFunction`` root with that signature is used.
            Otherwise a deep copy of ``graph.root_node`` is used.
        input_size, output_size: sizes for the fresh root built from
            ``typesig``. They default to ``self.input_size`` /
            ``self.output_size`` and are unused when ``typesig`` is None.

    Returns:
        A list of dicts, in depth-first discovery order, with keys
        ``"program"`` (a deep copy of the node's program), ``"struct_cost"``
        (the node's ``cost``) and ``"depth"`` (the node's ``depth``).
    """
    input_size = self.input_size if input_size is None else input_size
    output_size = self.output_size if output_size is None else output_size

    all_programs = []
    seen = set()  # printed (constant-free) forms of nodes already visited

    def program_key(node):
        # Dedup key: the printed program with constants ignored.
        return print_program(node.program, ignore_constants=True)

    def enumerate_helper(currentnode, key):
        # Callers guarantee `key` is not yet in `seen`, so no node is visited twice.
        seen.add(key)
        if graph.is_fully_symbolic(currentnode.program):
            all_programs.append({
                "program": copy.deepcopy(currentnode.program),
                "struct_cost": currentnode.cost,
                "depth": currentnode.depth,
            })
        elif currentnode.depth < enumeration_depth:
            for childnode in graph.get_all_children(currentnode, in_enumeration=True):
                # Compute each child's key once. It is checked here, after any
                # earlier sibling's subtree has been explored.
                child_key = program_key(childnode)
                if child_key not in seen:
                    enumerate_helper(childnode, child_key)

    max_depth_copy = graph.max_depth
    # get_all_children presumably consults graph.max_depth, so bound it by
    # the enumeration depth for the duration of this call (TODO confirm).
    graph.max_depth = enumeration_depth
    try:
        if typesig is None:
            root = copy.deepcopy(graph.root_node)
        else:
            new_start = dsl.StartFunction(input_type=typesig[0], output_type=typesig[1],
                                          input_size=input_size, output_size=output_size,
                                          num_units=graph.max_num_units)
            root = ProgramNode(new_start, 0, None, 0, 0, 0)
        enumerate_helper(root, program_key(root))
    finally:
        # Always restore the caller's depth limit, even if enumeration fails.
        graph.max_depth = max_depth_copy
    return all_programs