예제 #1
0
def remove_redundant(steps, outputs):
    """Keep only the steps needed to compute ``outputs`` and renumber objects.

    Object indices are assumed to be assigned consecutively in step order, each
    step contributing ``len(step.tool.out_types)`` of them.  Starting from the
    object indices in ``outputs``, the steps that produce them are marked as
    used, transitively through each used step's ``local_args``.  The used steps
    are then rebuilt with ``ToolStep`` in their original order, their outputs
    renumbered contiguously from 0 and their arguments remapped.

    Returns:
        ``(new_steps, old_to_new, used)``: the rebuilt steps; a dict mapping
        each kept step's old ``local_outputs`` entries to the rebuilt step's
        ``local_outputs``; and one boolean per input step telling whether it
        was kept.
    """
    # producer[obj] is the index of the step whose outputs include obj.
    producer = [
        step_i
        for step_i, step in enumerate(steps)
        for _ in step.tool.out_types
    ]

    # Depth-first marking of every step reachable from the requested outputs.
    used = [False] * len(steps)
    pending = list(outputs)
    while pending:
        step_i = producer[pending.pop()]
        if not used[step_i]:
            used[step_i] = True
            pending.extend(steps[step_i].local_args)

    kept_steps = []
    old_to_new = {}
    next_obj = 0
    for step, keep in zip(steps, used):
        if keep:
            # Arguments of a kept step come from earlier kept steps, so they
            # are already present in old_to_new.
            remapped_args = tuple(old_to_new[arg] for arg in step.local_args)
            replacement = ToolStep(step.tool, step.hyper_params, remapped_args,
                                   next_obj, step.debug_msg)
            next_obj += len(step.tool.out_types)
            old_to_new.update(zip(step.local_outputs, replacement.local_outputs))
            kept_steps.append(replacement)

    return kept_steps, old_to_new, used
예제 #2
0
파일: gtool.py 프로젝트: thtrieu/geo_logic
 def run_tool(self, tool, *args, update=True):
     """Add a step applying ``tool`` to ``args`` and return ``env.add_step``'s result.

     Every argument is first passed through ``self.instantiate_obj``.  When
     ``tool`` is given as a string it is looked up in ``self.tools`` under the
     key ``(tool, types)``, where ``types`` are the argument types reported by
     ``self.vis.gi_to_type``.  The step has no hyper-parameters and its output
     numbering starts at ``len(self.env.gi_to_step_i)``.
     """
     inputs = tuple(map(self.instantiate_obj, args))
     if isinstance(tool, str):
         signature = tuple(self.vis.gi_to_type(obj) for obj in inputs)
         lookup_key = (tool, signature)
         tool = self.tools[lookup_key]
     first_out = len(self.env.gi_to_step_i)
     new_step = ToolStep(tool, (), inputs, first_out)
     return self.env.add_step(new_step, update=update)
예제 #3
0
def steps_var_replace(steps, old_to_new):
    """Return rebuilt copies of ``steps`` with their argument indices remapped.

    Each step is recreated with ``ToolStep``, keeping its tool, hyper-parameters,
    ``start_out`` and debug message, while every entry ``arg`` of its
    ``local_args`` is replaced by ``old_to_new[arg]``.  ``old_to_new`` may be
    any object indexable by the old indices (list or dict).
    """
    return [
        ToolStep(
            step.tool,
            step.hyper_params,
            tuple(old_to_new[arg] for arg in step.local_args),
            step.start_out,
            step.debug_msg,
        )
        for step in steps
    ]
예제 #4
0
def copy_steps(steps):
    """Return a new list with a fresh ``ToolStep`` for every step in ``steps``.

    The copy is shallow: each new step is built from the original's ``tool``,
    ``hyper_params``, ``local_args``, ``start_out`` and ``debug_msg``, and those
    attribute values are shared rather than copied.
    """
    copies = []
    for original in steps:
        duplicate = ToolStep(
            tool=original.tool,
            hyper_params=original.hyper_params,
            local_args=original.local_args,
            start_out=original.start_out,
            debug_msg=original.debug_msg,
        )
        copies.append(duplicate)
    return copies
예제 #5
0
파일: gtool.py 프로젝트: thtrieu/geo_logic
 def run_m_tool(self, name, res_obj, *args, update=True):
     """Add a step for the ``self.tools.m`` tool ``name`` and return ``env.add_step``'s result.

     The arguments are instantiated with ``self.instantiate_obj`` and converted
     to numeric values with ``self.vis.gi_to_num``.  The tool is selected by
     ``(name, types of those numeric values, type(res_obj))``, and its
     hyper-parameters come from ``tool.get_hyperpar(res_obj, *numeric_values)``.
     Output numbering starts at ``len(self.env.gi_to_step_i)``.
     """
     inputs = tuple(map(self.instantiate_obj, args))
     numeric = tuple(self.vis.gi_to_num(obj) for obj in inputs)
     lookup_key = (name, tuple(type(value) for value in numeric), type(res_obj))
     tool = self.tools.m[lookup_key]
     hyper_params = tool.get_hyperpar(res_obj, *numeric)
     first_out = len(self.env.gi_to_step_i)
     new_step = ToolStep(tool, hyper_params, inputs, first_out)
     return self.env.add_step(new_step, update=update)
예제 #6
0
    def parse_line(self, line_info, line):
        """Parse one ``outputs <- tool_name args...`` line into a ``ToolStep``.

        Tokens after the tool name that parse as ``int``, ``float`` or
        ``Fraction`` (tried in that order) become hyper-parameters; all other
        tokens are object-variable names.  The tool is looked up in
        ``self.tool_dict`` under ``(tool_name, in_types)``, where ``in_types``
        lists the hyper-parameter types followed by ``self.var_type`` of each
        variable, falling back to the key ``(tool_name, None)``.  Each output
        name is registered with ``self.add_var`` using the tool's matching
        output type.

        Args:
            line_info: supplies the two values formatted into the debug
                message ``"l{}: {}"`` -- presumably (line number, line text);
                confirm with the caller.
            line: the source line to parse.

        Returns:
            ``ToolStep(tool, hyper_params, arg_indices, start_var_num,
            debug_msg)`` where ``start_var_num`` is ``self.var_num`` before the
            outputs were added and ``arg_indices`` is
            ``self.var_indices(obj_args)``.

        Raises:
            Exception: for an unknown tool or a mismatched number of outputs.
                Any error raised while parsing (e.g. ``ValueError`` when the
                line has no ``<-``) is re-raised after the debug message has
                been printed.
        """
        # Built before the ``try`` so the handler below can always print it.
        # Previously it was assigned inside the ``try``, and an early failure
        # made the handler raise UnboundLocalError, hiding the real error.
        debug_msg = "l{}: {}".format(*line_info)
        try:
            start_var_num = self.var_num
            tokens = line.split()
            i = tokens.index('<-')
            outputs = tokens[:i]
            tool_name = tokens[i+1]
            hyper_params = []
            obj_args = []
            for arg in tokens[i+2:]:
                # NOTE(review): float() also accepts "inf"/"nan"/"infinity", so a
                # variable with such a name would be read as a hyper-parameter.
                for hyper_type in (int, float, Fraction):
                    try:
                        hyper_params.append(hyper_type(arg))
                        break
                    except ValueError:
                        pass
                else:
                    obj_args.append(arg)

            # Hyper-parameter types come first, then variable types, regardless
            # of how the tokens were interleaved on the line.
            in_types = tuple(
                [type(x) for x in hyper_params]
                + [self.var_type(x) for x in obj_args]
            )
            tool = self.tool_dict.get((tool_name, in_types), None)
            if tool is None:
                # Fallback entry registered without a type signature.
                tool = self.tool_dict.get((tool_name, None), None)
            if tool is None:
                raise Exception(
                    "Unknown tool: {} : {}".format(
                        tool_name, ' '.join(x.__name__ for x in in_types))
                )
            if len(tool.out_types) != len(outputs):
                raise Exception("Numbers of outputs do not match: {} : {}".format(
                    ' '.join(outputs), ' '.join(x.__name__ for x in tool.out_types)
                ))

            # Resolve argument names before registering the outputs. Then a line
            # that reuses an input name as an output cannot make the step refer
            # to its own result, in case add_var permits rebinding a name.
            local_args = self.var_indices(obj_args)
            for o, t in zip(outputs, tool.out_types):
                self.add_var(o, t)

            return ToolStep(
                tool, hyper_params, local_args,
                start_var_num, debug_msg,
            )
        except Exception:
            print(debug_msg)
            raise
예제 #7
0
def expand_step(steps, expand_predicate):
    """Inline selected composite-tool steps and renumber every object.

    ``steps`` are walked in order.  A step for which
    ``expand_predicate(step_i, step)`` is true must use a ``CompositeTool``
    without implications or proof (checked by ``assert``).  It is replaced by
    rebuilt copies of the tool's ``assumptions`` sub-steps, wired to the step's
    remapped arguments.  Every other step is rebuilt with remapped arguments.
    New outputs are numbered contiguously from 0.

    Returns:
        ``(new_steps, old_to_new)``.  ``old_to_new`` is a list with one entry
        appended per original output, in step order (for expanded steps, one
        per index in ``tool.result``).  This assumes the original objects were
        numbered consecutively in step order.
    """
    old_to_new = []
    new_steps = []
    next_obj = 0

    for step_i, step in enumerate(steps):
        args = tuple(old_to_new[x] for x in step.local_args)

        if not expand_predicate(step_i, step):
            copied = ToolStep(step.tool, step.hyper_params, args,
                              next_obj, step.debug_msg)
            next_obj += len(step.tool.out_types)
            old_to_new.extend(copied.local_outputs)
            new_steps.append(copied)
            continue

        tool = step.tool
        assert isinstance(tool, CompositeTool)
        assert not tool.implications
        assert not tool.proof

        # The composite's local variables: the step's arguments first, then
        # the outputs of each assumption sub-step as it is emitted.  Both the
        # sub-steps' local_args and tool.result index into this list.
        local_vars = list(args)
        for sub in tool.assumptions:
            n_out = len(sub.tool.out_types)
            new_steps.append(ToolStep(
                sub.tool,
                sub.hyper_params,
                tuple(local_vars[x] for x in sub.local_args),
                next_obj,
                sub.debug_msg,
            ))
            local_vars.extend(range(next_obj, next_obj + n_out))
            next_obj += n_out

        old_to_new.extend(local_vars[x] for x in tool.result)

    return new_steps, old_to_new
예제 #8
0
def merge_duplicities(steps, imported_tools):
    """Redirect step arguments so objects with a common root share one index.

    ``steps`` are run (without catching errors) in a fresh ``ToolStepEnv``
    over a ``LogicalCore`` built from ``imported_tools``.  Each local object is
    mapped to the first local index whose global object has the same
    ``logic.ufd.obj_to_root`` root.  Every step's ``local_args`` are then
    rewritten through that map.

    Returns:
        ``(new_steps, old_to_new)``: exactly one rewritten step per input step
        (``start_out`` preserved), and the list mapping each old local index to
        its representative local index.
    """
    logic = LogicalCore(basic_tools=imported_tools)
    step_env = ToolStepEnv(logic)
    step_env.run_steps(steps, 0, catch_errors=False)

    root_to_local = dict()
    old_to_new = []
    for loc, glob in enumerate(step_env.local_to_global):
        root = logic.ufd.obj_to_root(glob)
        # setdefault keeps the first (lowest) local index seen for this root.
        old_to_new.append(root_to_local.setdefault(root, loc))

    # steps_var_replace already rebuilds every step.  The former extra loop
    # appended a second copy of each step, all with start_out=len(old_to_new).
    # That doubled the step list and produced overlapping output indices.
    new_steps = steps_var_replace(steps, old_to_new)
    return new_steps, old_to_new