def __init__(self, enter_name, enter_input_id, next_iteration_input_id,
             switch_true_identity_output_id, exit_output_id, is_tensor_array, ta_index_id, g):
    """Capture the tensors that make up one loop-carried variable.

    Args:
        enter_name: name of the Enter node associated with this variable.
        enter_input_id: tensor id feeding that Enter node.
        next_iteration_input_id: tensor id the iteration body produces for this
            variable; must not be None (checked with ``utils.make_sure``).
        switch_true_identity_output_id: tensor id where the iteration body starts
            reading the variable; may be None when no body node consumes the value
            (either the initial value or the previous iteration's output).
        exit_output_id: tensor id on the Switch false branch, which ends in Exit
            (the loop boundary); may be None when the variable's final value has
            no consumers.
        is_tensor_array: whether this variable is a tensor array.
        ta_index_id: index tensor id, only meaningful for tensor-array variables.
        g: graph handed to ``TensorValueInfo`` for each wrapped tensor id.
    """
    def wrap(tensor_id):
        # Every tensor-valued attribute is described relative to the same graph.
        return TensorValueInfo(tensor_id, g)

    self.enter_name = enter_name
    self.enter_input_id = enter_input_id
    self.is_tensor_array = is_tensor_array
    # TODO: check that the tensor-array index is a scalar starting from 1 and
    # increasing by 1 each iteration; only then is this equivalent to scan-output
    # behavior.
    self.ta_index_id = ta_index_id

    # The body's output for this variable is mandatory, so validate it before wrapping.
    utils.make_sure(next_iteration_input_id, "next_iteration_input_id should not be None")
    self.next_iteration_input = wrap(next_iteration_input_id)
    # Start of the iteration body for this variable (may wrap None, see docstring).
    self.switch_true_identity_output = wrap(switch_true_identity_output_id)
    # Value leaving the loop through Exit (may wrap None, see docstring).
    self.exit_output = wrap(exit_output_id)
def _crop_loop_condition_sub_graph(self, context):
    """Extract the loop-condition sub-graph and detach it from the loop plumbing.

    The sub-graph is everything needed to compute LoopCond's input, stopping at
    Merge nodes (``merge_as_end=True``) and with no explicit inputs.

    Args:
        context: loop context; ``context.loop_cond`` is the LoopCond node and
            ``context.loop_properties.all_variables`` maps an Enter node name to
            its loop variable.

    Returns:
        A ``GraphInfo`` built from the condition ops, an empty input list and the
        tensor feeding LoopCond as its only output. Its ``dependent_vars``
        attribute lists the loop variables the condition reads through Merge.
    """
    output_ids = [context.loop_cond.input[0]]
    outputs = [TensorValueInfo(o, self.g) for o in output_ids]
    ops, enter_nodes, merge_nodes = self.find_subgraph(set(), set(output_ids), self.g,
                                                       merge_as_end=True)

    for enter_node in enter_nodes:
        # connect Enter's output to Enter's input, so condition ops bypass Enter
        self.g.replace_all_inputs(ops, enter_node.output[0], enter_node.input[0])

    dependent_vars = []
    for merge_node in merge_nodes:
        enter_node = [n for n in merge_node.inputs if n.type == "Enter"][0]
        loop_var = context.loop_properties.all_variables[enter_node.name]

        # Cut the connection between the condition graph and the Merge node, and
        # let the condition consume the value the iteration body produces for
        # this variable (the NextIteration input). That tensor is guaranteed to
        # exist, whereas switch_true_identity_output may be absent when the body
        # does not read the variable, which would leave a None input here.
        non_switch_consumers = [n for n in self.g.find_output_consumers(merge_node.output[0])
                                if n.type != "Switch"]
        self.g.replace_all_inputs(non_switch_consumers, merge_node.output[0],
                                  loop_var.next_iteration_input.id)
        dependent_vars.append(loop_var)

    # cut off connection between condition graph and LoopCond node.
    # NOTE(review): as written this only rewrites LoopCond inputs that reference
    # LoopCond's own output; confirm this is the intended detachment.
    self.g.replace_all_inputs([context.loop_cond], context.loop_cond.output[0], INVALID_INPUT_ID)

    graph_info = GraphInfo(ops, [], outputs)
    graph_info.dependent_vars = dependent_vars
    return graph_info
def __init__(self, data_input_id, index_input_id, consumer_id, g):
    """Describe how one element of an unstacked tensor array is read in the loop.

    Args:
        data_input_id: tensor id of the tensor-array data input.
        index_input_id: tensor id of the index used to pick an element.
        consumer_id: id identifying where the iteration body consumes one
            element of the tensor array (the array is unstacked before the loop).
        g: graph handed to ``TensorValueInfo`` for ``consumer_id``.
    """
    consumer_info = TensorValueInfo(consumer_id, g)
    self.data_input_id = data_input_id
    self.index_input_id = index_input_id
    self.consumer = consumer_info
def _crop_loop_condition_sub_graph(self, context):
    """Extract the loop-condition sub-graph and detach it from the loop plumbing.

    The sub-graph covers everything needed to compute LoopCond's input, stopping
    at Merge nodes (``merge_as_end=True``) and with no explicit inputs.

    Args:
        context: loop context; ``context.loop_cond`` is the LoopCond node and
            ``context.loop_properties.all_variables`` maps an
            ``(enter_name, merge_name)`` pair to its loop variable.

    Returns:
        A ``GraphInfo`` built from the condition ops, an empty input list and the
        tensor feeding LoopCond as its only output. Its ``dependent_vars``
        attribute lists the loop variables the condition reads through Merge.
    """
    graph = self.g
    loop_cond = context.loop_cond
    cond_tensor_ids = [loop_cond.input[0]]
    cond_outputs = [TensorValueInfo(tensor_id, graph) for tensor_id in cond_tensor_ids]
    ops, enter_nodes, merge_nodes = self.find_subgraph(set(), set(cond_tensor_ids), graph,
                                                       merge_as_end=True)

    # Enter marks the loop boundary: condition ops read Enter's input directly.
    for enter in enter_nodes:
        graph.replace_all_inputs(enter.output[0], enter.input[0], ops=ops)

    dependent_vars = []
    for merge in merge_nodes:
        enter = [node for node in merge.inputs if node.type == "Enter"][0]
        loop_var = context.loop_properties.all_variables[(enter.name, merge.name)]

        # Detach the condition graph from Merge; instead it consumes the value the
        # cell (iteration body) graph outputs for this variable.
        cond_readers = [
            consumer
            for consumer in graph.find_output_consumers(merge.output[0])
            if consumer.type != "Switch"
        ]
        graph.replace_all_inputs(merge.output[0], loop_var.next_iteration_input.id,
                                 ops=cond_readers)
        dependent_vars.append(loop_var)

    # Detach the condition graph from the LoopCond node.
    # NOTE(review): as written this only rewrites LoopCond inputs that reference
    # LoopCond's own output; confirm this is the intended detachment.
    graph.replace_all_inputs(loop_cond.output[0], INVALID_INPUT_ID, ops=[loop_cond])

    result = GraphInfo(ops, [], cond_outputs)
    result.dependent_vars = dependent_vars
    return result