Example #1
    def __init__(self, enter_name, enter_input_id, next_iteration_input_id,
                 switch_true_identity_output_id, exit_output_id,
                 is_tensor_array, ta_index_id, g):
        self.enter_name = enter_name
        self.enter_input_id = enter_input_id

        # the output of the iteration body graph for this variable;
        # it should not be None
        utils.make_sure(next_iteration_input_id,
                        "next_iteration_input_id should not be None")
        self.next_iteration_input = TensorValueInfo(next_iteration_input_id, g)

        # the starting point of the iteration body graph;
        # might be None when this variable's value (either the initial value or the last iteration's output value)
        # is not consumed by any iteration body graph nodes.
        self.switch_true_identity_output = TensorValueInfo(
            switch_true_identity_output_id, g)

        # the switch_false branch ends with Exit, which is a boundary of the loop;
        # might be None when the variable output has no consumers.
        self.exit_output = TensorValueInfo(exit_output_id, g)

        # only applicable to tensor array variables
        self.is_tensor_array = is_tensor_array
        # todo: need to check that the tensor array's index variable is a scalar starting from 1 and
        # increased by 1 each iteration; then we can be sure this is equivalent to scan output behavior.
        self.ta_index_id = ta_index_id
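
These fields map onto the node chain TensorFlow emits for each while_loop variable (Enter → Merge → Switch → Identity → body → NextIteration, with the Switch false branch ending in Exit). Below is a minimal sketch, assuming a simplified stand-in for TensorValueInfo that keeps only the tensor id and using hypothetical tensor names, of what one such record captures:

class TensorValueInfoStub:
    # simplified stand-in for TensorValueInfo: keeps only the tensor id
    def __init__(self, tensor_id, g=None):
        self.id = tensor_id

# hypothetical tensor names following the usual while_loop naming scheme
loop_variable_record = {
    "enter_input_id": "init_value:0",                                        # value fed into Enter
    "next_iteration_input": TensorValueInfoStub("while/body/add:0"),         # produced by the loop body
    "switch_true_identity_output": TensorValueInfoStub("while/Identity:0"),  # consumed inside the body
    "exit_output": TensorValueInfoStub("while/Exit:0"),                      # value visible after the loop
}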
    def _crop_loop_condition_sub_graph(self, context):
        input_ids = []
        output_ids = [context.loop_cond.input[0]]
        outputs = [TensorValueInfo(o, self.g) for o in output_ids]
        ops, enter_nodes, merge_nodes = self.find_subgraph(set(input_ids), set(output_ids), self.g, merge_as_end=True)

        for enter_node in enter_nodes:
            # bypass the Enter node: rewire consumers of Enter's output to read Enter's input directly
            self.g.replace_all_inputs(ops, enter_node.output[0], enter_node.input[0])

        dependent_vars = []
        for merge_node in merge_nodes:
            enter_node = [n for n in merge_node.inputs if n.type == "Enter"][0]
            loop_var = context.loop_properties.all_variables[enter_node.name]

            # cut off the connection between the condition graph and the Merge node.
            non_switch_consumers = [n for n in self.g.find_output_consumers(merge_node.output[0]) if n.type != "Switch"]
            self.g.replace_all_inputs(non_switch_consumers, merge_node.output[0],
                                      loop_var.switch_true_identity_output.id)
            dependent_vars.append(loop_var)

        # cut off the connection between the condition graph and the LoopCond node.
        self.g.replace_all_inputs([context.loop_cond], context.loop_cond.output[0], INVALID_INPUT_ID)

        graph_info = GraphInfo(ops, [], outputs)
        graph_info.dependent_vars = dependent_vars
        return graph_info
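
The cropping above relies on Graph.replace_all_inputs in its older positional form, replace_all_inputs(ops, old_input, new_input): within the given nodes, every edge reading the old tensor is rewired to the new one. A minimal self-contained sketch of that assumed semantics, using a SimpleNamespace stand-in for graph nodes (an illustration, not the tf2onnx implementation):

from types import SimpleNamespace

def replace_all_inputs(ops, old_input, new_input):
    # rewire every edge inside `ops` that reads old_input so it reads new_input instead
    for node in ops:
        node.input = [new_input if i == old_input else i for i in node.input]

# toy usage: a consumer of a Merge output is redirected to the loop variable's Identity output
consumer = SimpleNamespace(input=["while/Merge:0", "Const:0"])
replace_all_inputs([consumer], "while/Merge:0", "while/Identity:0")
assert consumer.input == ["while/Identity:0", "Const:0"]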
    def __init__(self, data_input_id, index_input_id, consumer_id, g):
        self.index_input_id = index_input_id
        self.data_input_id = data_input_id

        # the tensor array is unstacked before being used in the loop; consumer_id is the node
        # (in the iteration body graph) consuming one of the elements of the tensor array.
        self.consumer = TensorValueInfo(consumer_id, g)
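
This record describes the scan-input pattern: a stacked tensor is unstacked into a tensor array before the loop, and each iteration reads one element by index. A minimal NumPy sketch of that per-iteration read (an illustration, not tf2onnx code):

import numpy as np

data = np.arange(6).reshape(3, 2)    # plays the role of data_input_id: the stacked values
for index in range(data.shape[0]):   # plays the role of index_input_id: the per-iteration index
    element = data[index]            # the consumer node reads exactly one element per iteration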
Example #4
    def _crop_loop_condition_sub_graph(self, context):
        input_ids = []
        output_ids = [context.loop_cond.input[0]]
        outputs = [TensorValueInfo(o, self.g) for o in output_ids]
        ops, enter_nodes, merge_nodes = self.find_subgraph(set(input_ids), set(output_ids), self.g, merge_as_end=True)

        for enter_node in enter_nodes:
            # bypass the Enter node: rewire consumers of Enter's output to read Enter's input directly
            self.g.replace_all_inputs(enter_node.output[0], enter_node.input[0], ops=ops)

        dependent_vars = []
        for merge_node in merge_nodes:
            enter_node = [n for n in merge_node.inputs if n.type == "Enter"][0]
            loop_var = context.loop_properties.all_variables[(enter_node.name, merge_node.name)]

            # cut off the connection between the condition graph and the Merge node.
            # replace the condition graph's inputs with the cell graph's outputs, because we want the
            # condition graph to consume the cell graph's outputs.
            non_switch_consumers = [n for n in self.g.find_output_consumers(merge_node.output[0]) if n.type != "Switch"]
            self.g.replace_all_inputs(merge_node.output[0], loop_var.next_iteration_input.id,
                                      ops=non_switch_consumers)
            dependent_vars.append(loop_var)

        # cut off the connection between the condition graph and the LoopCond node.
        self.g.replace_all_inputs(context.loop_cond.output[0], INVALID_INPUT_ID, ops=[context.loop_cond])

        graph_info = GraphInfo(ops, [], outputs)
        graph_info.dependent_vars = dependent_vars
        return graph_info
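
Example #4 is the same cropping routine written against a newer Graph API: replace_all_inputs now takes the old and new tensor ids first with the affected nodes passed via the ops= keyword, all_variables is keyed by the (enter_node.name, merge_node.name) pair, and the Merge consumers are redirected to the variable's next_iteration_input so the condition graph consumes the cell graph's outputs. A minimal sketch of that keyword-style signature, under the same stand-in node assumption as the earlier sketch:

from types import SimpleNamespace

def replace_all_inputs(old_input, new_input, ops=None):
    # keyword-style variant: rewire edges inside `ops` from old_input to new_input
    for node in ops or []:
        node.input = [new_input if i == old_input else i for i in node.input]

# toy usage mirroring the rewiring above: the Merge consumer now reads the body's NextIteration input
consumer = SimpleNamespace(input=["while/Merge:0"])
replace_all_inputs("while/Merge:0", "while/body/add:0", ops=[consumer])
assert consumer.input == ["while/body/add:0"]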