def _validate_output_exit_consumers(exit_consumers): if len(exit_consumers) != 2: return None gather_node = None for n in exit_consumers: if is_tensor_array_gather_op(n): gather_node = n elif is_tensor_array_size_op(n): continue else: return None return gather_node
def _parse_output_ta(self, context):
    """Build a TensorArrayProp for every tensor-array loop variable in *context*.

    For each loop variable flagged ``is_tensor_array``, records its data input
    and write-index input; when the variable has an Exit output, the
    TensorArrayGather consuming that Exit (the ``ta.stack()`` node) becomes the
    tensor array's output. The resulting props are appended to
    ``context.output_tas``.

    Raises (via utils.make_sure) when an Exit output exists but no
    TensorArrayGather consumes it.
    """
    for enter_name, loop_var in context.loop_variables.items():
        if not loop_var.is_tensor_array:
            continue

        output_ta = TensorArrayProp()
        output_ta.data_input_id = loop_var.next_iteration_input_id
        output_ta.index_input_id = loop_var.ta_index_id

        if loop_var.exit_output_id:
            exit_consumers = self.g.find_output_consumers(loop_var.exit_output_id)
            # BUGFIX: the original indexed `[...][0]`, which raised a bare
            # IndexError when no gather consumer exists; fail with a clear
            # message instead (consistent with the utils.make_sure checks
            # used elsewhere in this rewriter).
            ta_gather_node = next(
                (n for n in exit_consumers if is_tensor_array_gather_op(n)), None)
            utils.make_sure(ta_gather_node is not None,
                            "ta exit output is not consumed by a TensorArrayGather op")
            # treat the gather output as the ta's output
            output_ta.output_id = ta_gather_node.output[0]

        context.output_tas.append(output_ta)
        log.debug("output ta %s - data input (%s) shape: %s, output (%s) shape: %s",
                  enter_name, output_ta.data_input_id,
                  self.g.get_shape(output_ta.data_input_id),
                  output_ta.output_id,
                  self.g.get_shape(output_ta.output_id))
def _get_loop_var_from_switch(self, switch_node):
    """Parse one while-loop variable from its Switch node.

    Walks the Switch >> Merge >> (Enter, NextIteration) pattern that
    TensorFlow emits for each loop variable, plus the Exit node on the
    Switch's false branch when present. For tensor-array variables it
    additionally unwraps the TensorArrayWrite feeding NextIteration and the
    TensorArrayGather consuming the Exit (the ``ta.write()`` / ``ta.stack()``
    pattern).

    Returns a LoopVariable, or None when switch_node is not a Switch or its
    first input is not a Merge. Raises ValueError on malformed variants of
    the pattern.
    """
    if switch_node.type != 'Switch':
        log.error("not a switch node, skip")
        return None

    # the first input is data
    merge_node = switch_node.inputs[0]
    if merge_node.type != "Merge":
        log.error("switch node does not has Merge as its first input")
        return None

    # find the output_true consumers
    switch_consumers = self.g.find_output_consumers(switch_node.output[1])
    switch_true_consumer_cnt = len(switch_consumers)
    if switch_true_consumer_cnt == 0:
        switch_true_identity_output = None
    elif switch_true_consumer_cnt == 1:
        if switch_consumers[0].type != "Identity":
            raise ValueError("switch has consumer that is not Identity")
        switch_true_identity_output = switch_consumers[0].output[0]
    else:
        # BUGFIX: the original passed two arguments to ValueError (note the
        # trailing comma), so the exception carried a tuple instead of a
        # readable message; build one formatted string instead.
        raise ValueError("switch_true " + switch_node.name +
                         " has unexpected count of consumers: " +
                         str([n.name for n in switch_consumers]))

    # the Enter input of the Merge carries the variable's initial value
    # (dead `target_node_input_id = None` pre-assignment removed)
    enter_node = [n for n in merge_node.inputs if n.type == 'Enter'][0]
    target_node_input_id = enter_node.input[0]
    log.debug("a Switch >> Merge >> Enter is found called %s", enter_node.inputs[0].name)

    # the NextIteration input carries the value produced by the body
    next_iteration_node = [n for n in merge_node.inputs if n.type == 'NextIteration'][0]
    last_iteration_output_id = next_iteration_node.input[0]

    # find the output_false consumers to see whether there is consumer for this var
    switch_false_consumers = self.g.find_output_consumers(switch_node.output[0])
    false_consumer_count = len(switch_false_consumers)
    exit_output_id = None
    if false_consumer_count == 1:
        exit_node = switch_false_consumers[0]
        if exit_node.type != "Exit":
            raise ValueError("switch false branch is followed by non-Exit")
        exit_output_id = exit_node.output[0]
    elif false_consumer_count == 0:
        # sometime, the variable output won't be used in the new iteration as input.
        exit_output_id = None
    else:
        raise ValueError("unexpected number of switch false consumers")

    is_ta = False
    ta_index_id = None
    if is_tensor_array_op(self.g.get_node_by_output(target_node_input_id)):
        is_ta = True

        ta_write_node = self.g.get_node_by_output(last_iteration_output_id)
        utils.make_sure(is_tensor_array_write_op(ta_write_node),
                        "ta nextiteration is not following ta write op")
        # NOTE(review): assumes TensorArrayWrite input layout where input[1]
        # is the write index and input[2] is the written value — confirm
        # against the op definition.
        last_iteration_output_id = ta_write_node.input[2]
        ta_index_id = ta_write_node.input[1]

        # here we parse patterns generated by
        # ta.write(), then ta.stack(), because this is the most frequent usage pattern.
        if exit_output_id:
            exit_consumers = self.g.find_output_consumers(exit_output_id)
            ta_gather_node = [n for n in exit_consumers if is_tensor_array_gather_op(n)][0]

            # update exit output id, treat the gather output as ta's output
            exit_output_id = ta_gather_node.output[0]

    loop_var = LoopVariable(enter_node.name, target_node_input_id, last_iteration_output_id,
                            switch_true_identity_output, exit_output_id, is_ta,
                            ta_index_id, self.g)
    return loop_var