def _get_loop_var_from_switch(self, switch_node):
    """Build a LoopVariable for the loop-carried value that flows through ``switch_node``.

    Expected pattern: ``Enter`` and ``NextIteration`` nodes feed a ``Merge``,
    and that Merge is the first input of ``switch_node``.
    ``switch_node.output[1]`` (true branch) must feed exactly one
    ``Identity``; ``switch_node.output[0]`` (false branch) feeds at most one
    ``Exit``.

    Args:
        switch_node: node expected to be of type ``Switch``.

    Returns:
        A ``LoopVariable`` built from:
            - the Enter node's name and its input;
            - the NextIteration node's input;
            - the Identity output;
            - the Exit output, or None;
            - whether the Enter's source is a tensor array op.
        The variable is passed through ``_tune_shape_for_loop_var`` and then
        ``_tune_shape_for_loop_ta_var`` before being returned. Returns None,
        after logging an error, if ``switch_node`` is not a Switch or its
        first input is not a Merge.

    Raises:
        ValueError: if any of the following holds:
            - the true branch does not have exactly one Identity consumer;
            - the Merge has no Enter or no NextIteration input;
            - the false branch has a non-Exit consumer or more than one
              consumer.
    """
    if switch_node.type != 'Switch':
        log.error("not a switch node, skip")
        return None

    # the first input is data
    merge_node = switch_node.inputs[0]
    if merge_node.type != "Merge":
        log.error("switch node does not has Merge as its first input")
        return None

    # find the output_true consumers
    switch_consumers = self.g.find_output_consumers(switch_node.output[1])
    if len(switch_consumers) != 1:
        raise ValueError("switch has non-1 consumers")
    if switch_consumers[0].type != "Identity":
        raise ValueError("switch has consumer that is not Identity")
    identity_node = switch_consumers[0]

    # Use next(..., None) plus an explicit ValueError rather than `[...][0]`,
    # so a malformed Merge is reported like every other structural
    # violation instead of as a bare IndexError.
    enter_node = next((n for n in merge_node.inputs if n.type == 'Enter'), None)
    if enter_node is None:
        raise ValueError("merge node %s has no Enter input" % merge_node.name)
    target_node_input_id = enter_node.input[0]
    log.debug("a Switch >> Merge >> Enter is found called %s", enter_node.inputs[0].name)

    next_iteration_node = next(
        (n for n in merge_node.inputs if n.type == 'NextIteration'), None)
    if next_iteration_node is None:
        raise ValueError("merge node %s has no NextIteration input" % merge_node.name)
    last_iteration_output_id = next_iteration_node.input[0]

    # find the output_false consumers to see whether there is consumer for this var
    switch_false_consumers = self.g.find_output_consumers(switch_node.output[0])
    false_consumer_count = len(switch_false_consumers)
    if false_consumer_count == 1:
        exit_node = switch_false_consumers[0]
        if exit_node.type != "Exit":
            raise ValueError("switch false branch is followed by non-Exit")
        exit_output_id = exit_node.output[0]
    elif false_consumer_count == 0:
        # nothing consumes the false branch, so there is no exit output
        exit_output_id = None
    else:
        raise ValueError("unexpected number of switch false consumers")

    is_ta = bool(is_tensor_array_op(self.g.get_node_by_output(target_node_input_id)))

    loop_var = LoopVariable(enter_node.name, target_node_input_id, last_iteration_output_id,
                            identity_node.output[0], exit_output_id, is_ta)
    loop_var = self._tune_shape_for_loop_var(loop_var)
    loop_var = self._tune_shape_for_loop_ta_var(loop_var)
    return loop_var
def _output_switch_check(self, enter_target_node_input_id, identity_consumers, match):
    """Decide whether a switch is an "output switch" for a tensor array.

    The check passes only when both of these hold:
        - exactly one of ``identity_consumers`` is a tensor array write op;
        - the node producing ``enter_target_node_input_id`` is a tensor
          array op.

    Args:
        enter_target_node_input_id: output id consumed by the loop's Enter node.
        identity_consumers: nodes consuming the switch's Identity output.
        match: accepted for interface compatibility; not read here.

    Returns:
        ``enter_target_node_input_id`` when the check passes, otherwise None.
    """
    write_candidates = list(filter(is_tensor_array_write_op, identity_consumers))

    # Guard: exactly one TA write must hang off the Identity.
    if len(write_candidates) != 1:
        log.debug("%d TensorArrayWriteV3 matching found, cannot validate output switch",
                  len(write_candidates))
        return None

    producer = self.g.get_node_by_output(enter_target_node_input_id)
    if not is_tensor_array_op(producer):
        log.debug("found enter target node is not ta node")
        return None

    log.debug("found output switch node")
    return enter_target_node_input_id
def _get_loop_var_from_switch(self, switch_node):
    """Build a LoopVariable for the loop-carried value that flows through ``switch_node``.

    Expected pattern: ``Enter`` and ``NextIteration`` nodes feed a ``Merge``,
    and that Merge is the first input of ``switch_node``.
    ``switch_node.output[1]`` (true branch) feeds zero or one ``Identity``;
    ``switch_node.output[0]`` (false branch) feeds zero or one ``Exit``.

    Tensor-array variables are those whose Enter source is a tensor array op.
    For these:
        - the NextIteration input must be produced by a tensor array write
          op. That write's ``input[2]`` becomes the last-iteration output,
          and its ``input[1]`` is stored as ``ta_index_id``;
        - if there is an Exit, the output of a tensor array gather op that
          consumes it is used as the variable's exit output. Only the
          ``ta.write()`` then ``ta.stack()`` pattern is handled.

    Args:
        switch_node: node expected to be of type ``Switch``.

    Returns:
        A ``LoopVariable``. Returns None, after logging an error, when
        ``switch_node`` is not a Switch or its first input is not a Merge.

    Raises:
        ValueError: if any of the following holds:
            - a branch has unexpected consumers;
            - the Merge has no Enter or no NextIteration input;
            - a tensor-array Exit has no gather consumer.
        ``utils.make_sure`` also fails when a tensor-array NextIteration is
        not fed by a write op.
    """
    if switch_node.type != 'Switch':
        log.error("not a switch node, skip")
        return None

    # the first input is data
    merge_node = switch_node.inputs[0]
    if merge_node.type != "Merge":
        log.error("switch node does not has Merge as its first input")
        return None

    # find the output_true consumers
    switch_consumers = self.g.find_output_consumers(switch_node.output[1])
    switch_true_consumer_cnt = len(switch_consumers)
    if switch_true_consumer_cnt == 0:
        # nothing consumes the true branch
        switch_true_identity_output = None
    elif switch_true_consumer_cnt == 1:
        if switch_consumers[0].type != "Identity":
            raise ValueError("switch has consumer that is not Identity")
        switch_true_identity_output = switch_consumers[0].output[0]
    else:
        # Build one message string; passing the name list as a second
        # ValueError argument would render the message as a tuple.
        raise ValueError("switch_true %s has unexpected count of consumers: %s"
                         % (switch_node.name, [n.name for n in switch_consumers]))

    # Use next(..., None) plus an explicit ValueError rather than `[...][0]`,
    # so a malformed Merge is reported clearly instead of as an IndexError.
    enter_node = next((n for n in merge_node.inputs if n.type == 'Enter'), None)
    if enter_node is None:
        raise ValueError("merge node %s has no Enter input" % merge_node.name)
    target_node_input_id = enter_node.input[0]
    log.debug("a Switch >> Merge >> Enter is found called %s", enter_node.inputs[0].name)

    next_iteration_node = next(
        (n for n in merge_node.inputs if n.type == 'NextIteration'), None)
    if next_iteration_node is None:
        raise ValueError("merge node %s has no NextIteration input" % merge_node.name)
    last_iteration_output_id = next_iteration_node.input[0]

    # find the output_false consumers to see whether there is consumer for this var
    switch_false_consumers = self.g.find_output_consumers(switch_node.output[0])
    false_consumer_count = len(switch_false_consumers)
    if false_consumer_count == 1:
        exit_node = switch_false_consumers[0]
        if exit_node.type != "Exit":
            raise ValueError("switch false branch is followed by non-Exit")
        exit_output_id = exit_node.output[0]
    elif false_consumer_count == 0:
        # sometimes the variable's value is not consumed after the loop,
        # so there is no Exit and hence no exit output
        exit_output_id = None
    else:
        raise ValueError("unexpected number of switch false consumers")

    ta_index_id = None
    is_ta = bool(is_tensor_array_op(self.g.get_node_by_output(target_node_input_id)))
    if is_ta:
        ta_write_node = self.g.get_node_by_output(last_iteration_output_id)
        utils.make_sure(is_tensor_array_write_op(ta_write_node),
                        "ta nextiteration is not following ta write op")
        # track the value written each iteration and the index it is written at
        last_iteration_output_id = ta_write_node.input[2]
        ta_index_id = ta_write_node.input[1]

        # here we parse patterns generated by
        # ta.write(), then ta.stack(), because this is the most frequent usage pattern.
        if exit_output_id:
            exit_consumers = self.g.find_output_consumers(exit_output_id)
            ta_gather_node = next(
                (n for n in exit_consumers if is_tensor_array_gather_op(n)), None)
            if ta_gather_node is None:
                raise ValueError(
                    "ta exit output %s is not consumed by a ta gather op; "
                    "only the ta.write() then ta.stack() pattern is supported" % exit_output_id)
            # update exit output id, treat the gather output as ta's output
            exit_output_id = ta_gather_node.output[0]

    loop_var = LoopVariable(enter_node.name, target_node_input_id, last_iteration_output_id,
                            switch_true_identity_output, exit_output_id, is_ta, ta_index_id,
                            self.g)

    return loop_var