def optimize(self, graph: Graph):
    """Rewrite channel-mode conversions whose endpoints already share a mode.

    Two patterns are rewritten, each found with
    ``traverse.search_sub_structure`` before any change is made:

    1. A ``ConvertRtoRGBA`` whose input and output are both RGBA::

           before) v0[RGBA] -{ConvertRtoRGBA}- v1[RGBA]
           after)  v0[RGBA] -{ConvertRGBAtoR}- v2[Order=v0.order][R]
                   -{Transpose}- v3[Order=v1.order][R]
                   -{ConvertRtoRGBA}- v1[RGBA]

    2. A ``ConvertRGBAtoR`` whose input and output are both R::

           before) v0[R] -{ConvertRGBAtoR}- v1[R]
           after)  v0[R] -{Transpose}- v1[R]

    Matches whose endpoints do not satisfy the mode condition are skipped.

    Args:
        graph: the graph to rewrite in place.

    Returns:
        ``(graph, flag_changed)``; ``flag_changed`` is True when at least
        one conversion operator was removed and rewritten.
    """
    changed = False

    # --- Pattern 1: RGBA -{ConvertRtoRGBA}- RGBA --------------------------
    candidates = traverse.search_sub_structure(graph, [Variable, ConvertRtoRGBA, Variable])
    while candidates:
        src, conversion, dst = candidates.pop()  # type: Variable, ConvertRtoRGBA, Variable
        src_mode = ChannelMode.get(src)
        dst_mode = ChannelMode.get(dst)
        if not (src_mode == dst_mode == ChannelModeEnum.RGBA):
            continue

        changed = True
        conversion.remove_all()

        # Unpack to R in the source layout, reorder to the destination
        # layout, then repack to RGBA in the destination layout.
        unpacked = convert_rgba_to_r(src)
        unpacked.change_order(src.order)
        reordered = unpacked.transpose(dst.order)
        repacked = convert_r_to_rgba(reordered)
        repacked.change_order(dst.order)

        # Splice the rebuilt chain into dst's position (per the "after"
        # diagram above, the chain ends at v1 == dst).
        OptimizeRule.replace_variable(graph, repacked, dst)

    # --- Pattern 2: R -{ConvertRGBAtoR}- R --------------------------------
    candidates = traverse.search_sub_structure(graph, [Variable, ConvertRGBAtoR, Variable])
    while candidates:
        src, conversion, dst = candidates.pop()  # type: Variable, ConvertRGBAtoR, Variable
        src_mode = ChannelMode.get(src)
        dst_mode = ChannelMode.get(dst)
        if not (src_mode == dst_mode == ChannelModeEnum.R):
            continue

        changed = True
        conversion.remove_all()
        # Only a layout change is needed; no channel conversion.
        OptimizeRule.replace_variable(graph, src.transpose(dst.order), dst)

    return graph, changed
def _replace_output(op: Operator, var_name: str, target: ChannelModeEnum):
    """Make ``op`` produce its output ``var_name`` in channel mode ``target``.

    before) -{op}- v
    after)  -{op}- v' -{conversion}- v

    ``op`` is rewired to write a fresh variable ``v'`` (same shape and order
    as ``v``, channel mode ``target``). A conversion is inserted from ``v'``
    back to ``v``, so that downstream consumers of ``v`` are untouched.

    Args:
        op: operator whose output is rewired.
        var_name: key of the output in ``op.outputs``.
        target: channel mode ``op`` should write in.

    Returns:
        True when the graph was modified; False when ``v`` already has
        mode ``target``.
    """
    original = op.outputs[var_name]
    if ChannelMode.get(original) == target:
        return False

    produced = Variable(original.shape, original.order)
    ChannelMode.set(produced, target)
    op.replace_output(original, produced)

    # ``produced`` is in ``target`` mode and ``original`` is not, so convert
    # away from ``target`` and let the converted variable take over
    # ``original``'s place (``replace``).
    converter = convert_rgba_to_r if target == ChannelModeEnum.RGBA else convert_r_to_rgba
    converter(produced).change_order(original.order).replace(original)
    return True
def _replace_input(op: Operator, var_name: str, target: ChannelModeEnum):
    """Make ``op`` read its input ``var_name`` in channel mode ``target``.

    before) v -{op}-
    after)  v -{conversion}- v' -{op}-

    Args:
        op: operator whose input is rewired.
        var_name: key of the input in ``op.inputs``.
        target: channel mode ``op`` should read in.

    Returns:
        True when a conversion was inserted; False when ``v`` already has
        mode ``target``.
    """
    original = op.inputs[var_name]
    if ChannelMode.get(original) == target:
        return False

    # Convert *into* ``target``: R -> RGBA when RGBA is wanted, else RGBA -> R.
    converter = convert_r_to_rgba if target == ChannelModeEnum.RGBA else convert_rgba_to_r
    op.replace_input(original, converter(original))
    return True