def _cut_off_connection_for_cell(self, context):
    """Disconnect the cell from its surrounding loop and collect nodes for removal.

    Inputs of NextIteration nodes and of TensorArray write nodes that match
    the cell's ids are redirected to ``INVLAID_INPUT_ID``. This method does
    not remove anything from the graph itself; it only returns the nodes it
    looked up.

    Args:
        context: object exposing ``time_var``, ``other_loop_vars`` (a dict
            whose values are loop variables), ``input_tas`` and
            ``output_tas``.

    Returns:
        A list with one node per loop variable (looked up by
        ``switch_true_identity_output_id``) followed by one node per input
        TensorArray (looked up by ``output_id``).
    """
    loop_vars = [context.time_var]
    loop_vars.extend(context.other_loop_vars.values())

    doomed_nodes = []
    for loop_var in loop_vars:
        # The identity behind the switch is where the cell (e.g. the loop
        # body) starts, so that node is queued for removal.
        doomed_nodes.append(
            self.g.get_node_by_name(loop_var.switch_true_identity_output_id))

        # Point NextIteration at an invalid input to detach the cell's end.
        next_iterations = [
            node for node in self.g.get_nodes()
            if node.type == "NextIteration"
        ]
        self.g.replace_all_inputs(next_iterations,
                                  loop_var.next_iteration_input_id,
                                  INVLAID_INPUT_ID)

    # Nodes connecting scan inputs to the cell are queued for removal too.
    doomed_nodes.extend(
        self.g.get_node_by_name(input_ta.output_id)
        for input_ta in context.input_tas)

    for output_ta in context.output_tas:
        # Detach scan outputs from the cell by invalidating the data input
        # of the TensorArray write nodes.
        ta_writers = [
            node for node in self.g.get_nodes()
            if is_tensor_array_write_op(node)
        ]
        self.g.replace_all_inputs(ta_writers, output_ta.data_input_id,
                                  INVLAID_INPUT_ID)

    return doomed_nodes
def _tune_shape_for_loop_ta_var(self, loop_var):
    """Rewire a TensorArray loop variable and set the shapes of its outputs.

    Loop variables whose ``is_tensor_array`` is false are returned untouched.
    Otherwise ``next_iteration_input_id`` and ``ta_index_id`` are replaced by
    inputs 2 and 1 of the TensorArray write node, and one shape is set on the
    next-iteration input, the switch-true identity output and the exit output.

    Args:
        loop_var: loop variable object; mutated in place.

    Returns:
        The same ``loop_var`` object.

    Raises:
        ValueError: if the node producing ``next_iteration_input_id`` is not
            a TensorArray write op.
    """
    if not loop_var.is_tensor_array:
        return loop_var

    writer = self.g.get_node_by_output(loop_var.next_iteration_input_id)
    if not is_tensor_array_write_op(writer):
        raise ValueError("ta var nextiteration is not following ta write op")

    loop_var.next_iteration_input_id = writer.input[2]
    loop_var.ta_index_id = writer.input[1]

    shape = self.g.get_shape(loop_var.next_iteration_input_id)
    if shape is not None:
        log.debug(
            "loop var [%s, %s] output shapes are inferred from cell output %s",
            loop_var.enter_name, loop_var.enter_input_id,
            loop_var.next_iteration_input_id)
    else:
        # No shape on the written value: fall back to the shape recorded for
        # the first input of the node feeding the write node's first input.
        ta_source_id = writer.inputs[0].input[0]
        shape = self.g.get_shape(ta_source_id)
        log.debug(
            "loop var [%s, %s] output shapes are inferred from TA element shape",
            loop_var.enter_name, loop_var.enter_input_id)

    for output_id in (loop_var.next_iteration_input_id,
                      loop_var.switch_true_identity_output_id,
                      loop_var.exit_output_id):
        self.g.set_shape(output_id, shape)

    return loop_var
def _cut_off_connection_for_cell(self, context):
    """Disconnect the cell from the loop scaffolding, editing the graph in place.

    For every loop variable the node producing its switch-true identity
    output (if any) is removed, and the consumers of its next-iteration input
    are redirected to ``INVALID_INPUT_ID``. For every scan input the
    producing node is removed.

    Args:
        context: object whose ``loop_properties`` exposes ``all_variables``
            (a dict of loop variables) and ``scan_inputs``.
    """
    for loop_var in context.loop_properties.all_variables.values():
        identity_output_id = loop_var.switch_true_identity_output.id
        if identity_output_id:
            # Removing this node detaches the start of the cell (e.g. the
            # loop body).
            start_node = self.g.get_node_by_output(identity_output_id)
            self.g.remove_node(start_node.name)

        # The end of the cell is detached by pointing its consumer at an
        # invalid input: TensorArray write nodes for TensorArray variables,
        # NextIteration nodes otherwise.
        if loop_var.is_tensor_array:
            end_nodes = [
                node for node in self.g.get_nodes()
                if is_tensor_array_write_op(node)
            ]
        else:
            end_nodes = [
                node for node in self.g.get_nodes()
                if node.type == "NextIteration"
            ]
        self.g.replace_all_inputs(end_nodes, loop_var.next_iteration_input.id,
                                  INVALID_INPUT_ID)

    for scan_input in context.loop_properties.scan_inputs:
        # Removing the producer cuts the scan input off from the cell.
        producer = self.g.get_node_by_output(scan_input.id)
        self.g.remove_node(producer.name)
def _tune_shape_for_loop_ta_var(self, loop_var):
    """Rewire a TensorArray loop variable and copy a shape onto its outputs.

    Loop variables whose ``is_tensor_array`` is false are returned untouched.
    Otherwise ``next_iteration_input_id`` and ``ta_index_id`` are replaced by
    inputs 2 and 1 of the TensorArray write node, and the shape of the first
    output of the node two hops upstream of the write node's first input is
    copied to the next-iteration input, the switch-true identity output and
    the exit output.

    Args:
        loop_var: loop variable object; mutated in place.

    Returns:
        The same ``loop_var`` object.

    Raises:
        ValueError: if the node looked up from ``next_iteration_input_id`` is
            not a TensorArray write op.
    """
    if not loop_var.is_tensor_array:
        return loop_var

    writer = self.g.get_node_by_name(loop_var.next_iteration_input_id)
    if not is_tensor_array_write_op(writer):
        raise ValueError("ta var nextiteration is not following ta write op")

    loop_var.next_iteration_input_id = writer.input[2]
    loop_var.ta_index_id = writer.input[1]

    log.debug(
        "loop var [%s, %s] output shapes are inferred from TA element shape",
        loop_var.enter_name, loop_var.enter_input_id)

    # write node -> its first input node -> that node's first input node
    shape_source_id = writer.inputs[0].inputs[0].output[0]
    log.debug(self.g.get_shape(shape_source_id))

    for target_id in (loop_var.next_iteration_input_id,
                      loop_var.switch_true_identity_output_id,
                      loop_var.exit_output_id):
        self.g.copy_shape(shape_source_id, target_id)

    return loop_var
def _output_switch_check(self, enter_target_node_input_id, identity_consumers, match):
    """Validate that a switch belongs to an output TensorArray.

    The check passes when exactly one of ``identity_consumers`` is a
    TensorArray write op and the node producing
    ``enter_target_node_input_id`` is a TensorArray op.

    Args:
        enter_target_node_input_id: output id of the node feeding the Enter.
        identity_consumers: consumer nodes of the switch-true identity.
        match: not used by this check.

    Returns:
        ``enter_target_node_input_id`` when the check passes, else ``None``.
    """
    writers = [
        consumer for consumer in identity_consumers
        if is_tensor_array_write_op(consumer)
    ]
    if len(writers) != 1:
        log.debug("%d TensorArrayWriteV3 matching found, cannot validate output switch", len(writers))
        return None

    producer = self.g.get_node_by_output(enter_target_node_input_id)
    if not is_tensor_array_op(producer):
        log.debug("found enter target node is not ta node")
        return None

    log.debug("found output switch node")
    return enter_target_node_input_id
def _find_state_variable_with_select(self, context, next_iteration_input, switch_true_identity_consumers):
    """Locate the single state variable linking the identity consumers to
    ``next_iteration_input``, allowing for a Select placed after it.

    If exactly one Select consumer of ``next_iteration_input`` does not feed
    a TensorArray write, the search uses that Select's first output as the
    next-iteration input, and the Select is appended to
    ``switch_true_identity_consumers`` (the caller's list is modified).

    Args:
        context: object whose ``loop_properties.get_variables`` accepts a
            predicate and returns the matching variables.
        next_iteration_input: output id expected to feed NextIteration.
        switch_true_identity_consumers: list of nodes that must all consume
            the variable's switch-true identity output.

    Returns:
        The matching state variable, or ``None`` unless exactly one matches.
    """
    # Select consumers that are not themselves followed by a TensorArray write.
    plain_selects = []
    for consumer in self.g.find_output_consumers(next_iteration_input):
        if is_select_op(consumer):
            ta_writers = [
                downstream
                for downstream in self.g.find_output_consumers(consumer.output[0])
                if is_tensor_array_write_op(downstream)
            ]
            if not ta_writers:
                plain_selects.append(consumer)

    if len(plain_selects) == 1:
        only_select = plain_selects[0]
        next_iteration_input = only_select.output[0]
        switch_true_identity_consumers.append(only_select)

    log.debug(
        "try to find state variable from [%s, %s]",
        next_iteration_input,
        switch_true_identity_consumers
    )

    def checker(state_variable):
        if state_variable.next_iteration_input.id != next_iteration_input:
            return False
        return all(
            state_variable.switch_true_identity_output.id in consumer.input
            for consumer in switch_true_identity_consumers
        )

    candidates = context.loop_properties.get_variables(checker)
    if len(candidates) == 1:
        return candidates[0]

    log.debug("found %d state variables", len(candidates))
    return None
def _get_loop_var_from_switch(self, switch_node):
    """Build a ``LoopVariable`` from a ``Switch << Merge << Enter/NextIteration`` chain.

    Args:
        switch_node: graph node expected to be of type ``Switch`` whose first
            input is a ``Merge`` node.

    Returns:
        A ``LoopVariable`` describing the variable, or ``None`` (after
        logging an error) when ``switch_node`` is not a Switch or its first
        input is not a Merge.

    Raises:
        ValueError: if the switch's true output has a single non-Identity
            consumer or more than one consumer, or if its false output has a
            single non-Exit consumer or more than one consumer.
        IndexError: if the Merge has no Enter or no NextIteration input, or
            if a TensorArray variable's exit output has no gather consumer
            (all three are taken as the first element of a filtered list).
    """
    if switch_node.type != 'Switch':
        log.error("not a switch node, skip")
        return None

    # the first input is data
    merge_node = switch_node.inputs[0]
    if merge_node.type != "Merge":
        log.error("switch node does not has Merge as its first input")
        return None

    # find the output_true consumers
    switch_consumers = self.g.find_output_consumers(switch_node.output[1])
    switch_true_consumer_cnt = len(switch_consumers)
    if switch_true_consumer_cnt == 0:
        switch_true_identity_output = None
    elif switch_true_consumer_cnt == 1:
        if switch_consumers[0].type != "Identity":
            raise ValueError("switch has consumer that is not Identity")
        switch_true_identity_output = switch_consumers[0].output[0]
    else:
        # Build one message string; passing the name list as a second
        # positional argument made the error render as a tuple repr.
        raise ValueError(
            "switch_true {} has unexpected count of consumers: {}".format(
                switch_node.name, [n.name for n in switch_consumers]))

    enter_node = [n for n in merge_node.inputs if n.type == 'Enter'][0]
    target_node_input_id = enter_node.input[0]
    log.debug("a Switch >> Merge >> Enter is found called %s", enter_node.inputs[0].name)

    next_iteration_node = [n for n in merge_node.inputs if n.type == 'NextIteration'][0]
    last_iteration_output_id = next_iteration_node.input[0]

    # find the output_false consumers to see whether there is consumer for this var
    switch_false_consumers = self.g.find_output_consumers(switch_node.output[0])
    false_consumer_count = len(switch_false_consumers)
    # With no consumer on the false branch, exit_output_id stays None.
    exit_output_id = None
    if false_consumer_count == 1:
        exit_node = switch_false_consumers[0]
        if exit_node.type != "Exit":
            raise ValueError("switch false branch is followed by non-Exit")
        exit_output_id = exit_node.output[0]
    elif false_consumer_count > 1:
        raise ValueError("unexpected number of switch false consumers")

    is_ta = False
    ta_index_id = None
    if is_tensor_array_op(self.g.get_node_by_output(target_node_input_id)):
        is_ta = True

        ta_write_node = self.g.get_node_by_output(last_iteration_output_id)
        # NOTE(review): utils.make_sure presumably raises when the condition
        # is false - confirm against its definition.
        utils.make_sure(is_tensor_array_write_op(ta_write_node), "ta nextiteration is not following ta write op")
        last_iteration_output_id = ta_write_node.input[2]
        ta_index_id = ta_write_node.input[1]

        # here we parse patterns generated by
        # ta.write(), then ta.stack(), because this is the most frequent usage pattern.
        if exit_output_id:
            exit_consumers = self.g.find_output_consumers(exit_output_id)
            ta_gather_node = [n for n in exit_consumers if is_tensor_array_gather_op(n)][0]

            # update exit output id, treat the gather output as ta's output
            exit_output_id = ta_gather_node.output[0]

    loop_var = LoopVariable(enter_node.name, target_node_input_id, last_iteration_output_id,
                            switch_true_identity_output, exit_output_id, is_ta, ta_index_id, self.g)

    return loop_var