def run_internal(self):
    """Locate every LoopCond op in the graph and attempt to rewrite its loop.

    For each LoopCond node a new context is built via ``create_context``, the
    node is stored on it as ``loop_cond``, and ``_check_in_read_only_mode`` is
    run on it. Only when ``need_rewrite`` accepts the context are the cell and
    condition sub-graphs detached, cropped onto the context and handed to
    ``rewrite``.

    Once every node has been visited, and only if the graph has outputs,
    ``delete_unused_nodes`` is called with those outputs.

    Returns:
        Whatever ``self.g.get_nodes()`` returns after the pass.

    Raises:
        ValueError: when ``rewrite`` reports ``REWRITER_RESULT.FAIL``.
    """
    for loop_cond in self.g.get_nodes():
        # Guard clause: only LoopCond ops drive a rewrite.
        if not is_loopcond_op(loop_cond):
            continue
        log.debug("======================\n handling loop cond node called %s", loop_cond.name)

        ctx = self.create_context()
        ctx.loop_cond = loop_cond
        self._check_in_read_only_mode(ctx)
        if not self.need_rewrite(ctx):
            continue

        # Sever the links between the cell/cond graphs and nodes that become
        # useless after rewriting (e.g. Merge, NextIteration), then crop both
        # sub-graphs onto the context. Order matters: cut first, then crop.
        self._cut_off_connection_for_cell(ctx)
        ctx.cell_graph = self._crop_loop_body_sub_graph(ctx)
        ctx.cond_graph = self._crop_loop_condition_sub_graph(ctx)

        outcome = self.rewrite(ctx)
        if outcome == REWRITER_RESULT.OK:
            log.debug("rewrite successfully")
        elif outcome == REWRITER_RESULT.SKIP:
            log.debug("rewrite skipped for LoopCond called %s", loop_cond.name)
        elif outcome == REWRITER_RESULT.FAIL:
            # Abort the whole pass on the first failed rewrite.
            raise ValueError("rewrite failed, so just fast fail it")

    # Clean the graph up based on its output names (skipped when there are none).
    if self.g.outputs:
        self.g.delete_unused_nodes(self.g.outputs)
    return self.g.get_nodes()
def run_internal(self):
    """Visit every LoopCond op in the graph and try to rewrite its loop.

    For each LoopCond node a context is created via ``create_context`` and
    filled by ``_parse_loop_variables``. If ``need_rewrite`` accepts it, the
    context is passed to ``rewrite``.

    Afterwards, when the graph has outputs, ``delete_unused_nodes`` is called
    with those output names plus the names returned by
    ``BodyGraphDict.get_body_graph_output_names()``.

    Returns:
        Whatever ``self.g.get_nodes()`` returns after the pass.

    Raises:
        ValueError: when ``rewrite`` reports ``REWRITER_RESULT.FAIL``.
    """
    log.debug("enter loop rewriter")
    for node in self.g.get_nodes():
        # Guard clause: only LoopCond ops drive a rewrite.
        if not is_loopcond_op(node):
            continue
        log.debug("======================")
        log.debug("found LoopCond op named %s", node.name)

        ctx = self.create_context()
        self._parse_loop_variables(node, ctx)
        if not self.need_rewrite(ctx):
            continue

        outcome = self.rewrite(ctx)
        if outcome == REWRITER_RESULT.OK:
            log.debug("rewrite successfully")
        elif outcome == REWRITER_RESULT.SKIP:
            log.debug("rewrite skipped for LoopCond called %s", node.name)
        elif outcome == REWRITER_RESULT.FAIL:
            # Abort the whole pass on the first failed rewrite.
            raise ValueError("rewrite failed, so just fast fail it")

    # Deep-copy so that extending the list below does not mutate self.g.outputs.
    keep_names = copy.deepcopy(self.g.outputs)
    if keep_names:
        keep_names.extend(BodyGraphDict.get_body_graph_output_names())
        self.g.delete_unused_nodes(keep_names)
    return self.g.get_nodes()