def _cut_off_connection_for_cell(self, context):
        nodes_to_remove = []
        all_vars = [context.time_var]
        all_vars += [val for _, val in context.other_loop_vars.items()]
        for val in all_vars:
            # remove the node to cut off a starting node of the cell (e.g. loop body).
            nodes_to_remove.append(
                self.g.get_node_by_name(val.switch_true_identity_output_id))

            # connect NextIteration to an invalid node, to cut off an ending node of the cell.
            next_iter_nodes = [
                n for n in self.g.get_nodes() if n.type == "NextIteration"
            ]
            self.g.replace_all_inputs(next_iter_nodes,
                                      val.next_iteration_input_id,
                                      INVALID_INPUT_ID)

        for input_ta in context.input_tas:
            # remove the node to cut off connection between scan_input and the cell.
            nodes_to_remove.append(self.g.get_node_by_name(input_ta.output_id))

        for output_ta in context.output_tas:
            # remove the node to cut off connection between scan_output and the cell.
            ta_write_nodes = [
                n for n in self.g.get_nodes() if is_tensor_array_write_op(n)
            ]
            self.g.replace_all_inputs(ta_write_nodes, output_ta.data_input_id,
                                      INVALID_INPUT_ID)

        return nodes_to_remove
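These snippets appear to come from tf2onnx's TensorFlow loop rewriters, and the cut-off trick is plain input redirection: any consumer that would reach back into the loop plumbing gets its input swapped for a sentinel id, disconnecting the cell from the while-loop scaffolding. A minimal sketch of the idea, assuming a toy node model (Node and replace_all_inputs here are illustrative stand-ins, not the tf2onnx API):

INVALID_INPUT_ID = "invalid:0"

class Node:
    def __init__(self, name, inputs):
        self.name = name
        self.input = list(inputs)

def replace_all_inputs(nodes, old_id, new_id):
    # rewrite every occurrence of old_id in the given nodes' input lists
    for n in nodes:
        n.input = [new_id if i == old_id else i for i in n.input]

# a NextIteration-like consumer fed by the cell's output
next_iter = Node("next_iteration", ["cell_out:0"])
replace_all_inputs([next_iter], "cell_out:0", INVALID_INPUT_ID)
assert next_iter.input == [INVALID_INPUT_ID]  # the cell is now cut off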
Example #2
    def _tune_shape_for_loop_ta_var(self, loop_var):
        if loop_var.is_tensor_array:
            ta_write_node = self.g.get_node_by_output(
                loop_var.next_iteration_input_id)
            if not is_tensor_array_write_op(ta_write_node):
                raise ValueError(
                    "ta var nextiteration is not following ta write op")

            loop_var.next_iteration_input_id = ta_write_node.input[2]
            loop_var.ta_index_id = ta_write_node.input[1]

            ta_output_shape = None
            next_iteration_shape = self.g.get_shape(
                loop_var.next_iteration_input_id)
            if next_iteration_shape is None:
                enter_node = ta_write_node.inputs[0]
                ta_node_output = enter_node.input[0]
                ta_element_shape = self.g.get_shape(ta_node_output)
                ta_output_shape = ta_element_shape
                log.debug(
                    "loop var [%s, %s] output shapes are inferred from TA element shape",
                    loop_var.enter_name, loop_var.enter_input_id)
            else:
                log.debug(
                    "loop var [%s, %s] output shapes are inferred from cell output %s",
                    loop_var.enter_name, loop_var.enter_input_id,
                    loop_var.next_iteration_input_id)
                ta_output_shape = next_iteration_shape

            self.g.set_shape(loop_var.next_iteration_input_id, ta_output_shape)
            self.g.set_shape(loop_var.switch_true_identity_output_id,
                             ta_output_shape)
            self.g.set_shape(loop_var.exit_output_id, ta_output_shape)

        return loop_var
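The indexing above relies on the TensorArrayWriteV3 input layout in TensorFlow's v1 control flow: input[0] is the TensorArray handle (reached through an Enter node), input[1] the write index, input[2] the value written this iteration, and input[3] the flow-in. A toy stand-in to make the indexing concrete (TAWrite is illustrative, not a TensorFlow class):

from collections import namedtuple

TAWrite = namedtuple("TAWrite", ["input"])
ta_write = TAWrite(input=["ta_handle:0", "index:0", "value:0", "flow:0"])

value_id = ta_write.input[2]  # what the cell produced this iteration
index_id = ta_write.input[1]  # where it lands in the TensorArray
assert (value_id, index_id) == ("value:0", "index:0")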
Example #3
    def _cut_off_connection_for_cell(self, context):
        for val in context.loop_properties.all_variables.values():
            if val.switch_true_identity_output.id:
                # remove the node to cut off a starting node of the cell (e.g. loop body).
                n = self.g.get_node_by_output(
                    val.switch_true_identity_output.id)
                self.g.remove_node(n.name)

            if val.is_tensor_array:
                # connect NextIteration to an invalid node, to cut off an ending node of the cell.
                ta_write_nodes = [
                    n for n in self.g.get_nodes()
                    if is_tensor_array_write_op(n)
                ]
                self.g.replace_all_inputs(ta_write_nodes,
                                          val.next_iteration_input.id,
                                          INVALID_INPUT_ID)
            else:
                # connect NextIteration to an invalid node, to cut off an ending node of the cell.
                next_iter_nodes = [
                    n for n in self.g.get_nodes() if n.type == "NextIteration"
                ]
                self.g.replace_all_inputs(next_iter_nodes,
                                          val.next_iteration_input.id,
                                          INVALID_INPUT_ID)

        for scan_input in context.loop_properties.scan_inputs:
            # remove the node to cut off connection between scan_input and the cell.
            self.g.remove_node(self.g.get_node_by_output(scan_input.id).name)
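Both versions of _cut_off_connection_for_cell depend on the same output-id convention: an output id is typically the producing node's name plus a port, "name:N", so helpers like get_node_by_output can map an id back to its producer. A minimal sketch of that mapping (illustrative, not the tf2onnx implementation):

def node_name_of(output_id: str) -> str:
    # "scan_input/Enter:0" -> "scan_input/Enter"
    return output_id.rsplit(":", 1)[0]

assert node_name_of("scan_input/Enter:0") == "scan_input/Enter"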
Example #4
    def _tune_shape_for_loop_ta_var(self, loop_var):
        if loop_var.is_tensor_array:
            ta_write_node = self.g.get_node_by_name(
                loop_var.next_iteration_input_id)
            if not is_tensor_array_write_op(ta_write_node):
                raise ValueError(
                    "ta var nextiteration is not following ta write op")

            loop_var.next_iteration_input_id = ta_write_node.input[2]
            loop_var.ta_index_id = ta_write_node.input[1]

            log.debug(
                "loop var [%s, %s] output shapes are inferred from TA element shape",
                loop_var.enter_name, loop_var.enter_input_id)
            enter_node = ta_write_node.inputs[0]
            output_ta_node = enter_node.inputs[0]
            log.debug(self.g.get_shape(output_ta_node.output[0]))
            self.g.copy_shape(output_ta_node.output[0],
                              loop_var.next_iteration_input_id)
            self.g.copy_shape(output_ta_node.output[0],
                              loop_var.switch_true_identity_output_id)
            self.g.copy_shape(output_ta_node.output[0],
                              loop_var.exit_output_id)

        return loop_var
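This variant works because every element of a TensorArray shares a single element shape, so the shape recorded on the output TensorArray node is a valid stand-in for the per-iteration values. The copy_shape calls then just mirror one output's inferred shape onto the loop var's three outputs; a dict-backed sketch of that behavior, assuming a plain shape table (not the tf2onnx API):

shapes = {"output_ta/TensorArrayV3:0": [3, 4]}

def copy_shape(src_id, dst_id, table=shapes):
    # mirror whatever shape is recorded for src_id onto dst_id
    table[dst_id] = table.get(src_id)

copy_shape("output_ta/TensorArrayV3:0", "next_iteration:0")
assert shapes["next_iteration:0"] == [3, 4]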
Example #5
    def _output_switch_check(self, enter_target_node_input_id, identity_consumers, match):
        ta_write_nodes = [c for c in identity_consumers if is_tensor_array_write_op(c)]
        if len(ta_write_nodes) == 1:
            enter_target_node = self.g.get_node_by_output(enter_target_node_input_id)
            if is_tensor_array_op(enter_target_node):
                log.debug("found output switch node")
                return enter_target_node_input_id
            log.debug("found enter target node is not ta node")
            return None
        log.debug("%d TensorArrayWriteV3 matches found, cannot validate output switch",
                  len(ta_write_nodes))
        return None

    def _find_state_variable_with_select(self, context,
                                         next_iteration_input,
                                         switch_true_identity_consumers):
        """
        Find state variables from switch_true_identity_consumers to next_iteration_input.
        A Select may be added after next_iteration_input.
        """
        # find all select not followed by TensorArrayWrite
        select = []
        for c in self.g.find_output_consumers(next_iteration_input):
            if not is_select_op(c):
                continue
            out_ta_writer = [
                o for o in self.g.find_output_consumers(c.output[0]) if is_tensor_array_write_op(o)
            ]
            if out_ta_writer:
                continue
            select.append(c)
        if len(select) == 1:
            next_iteration_input = select[0].output[0]
            switch_true_identity_consumers.append(select[0])

        log.debug(
            "try to find state variable from [%s, %s]",
            next_iteration_input,
            switch_true_identity_consumers
        )

        def checker(state_variable):
            if state_variable.next_iteration_input.id != next_iteration_input:
                return False
            for consumer in switch_true_identity_consumers:
                if state_variable.switch_true_identity_output.id not in consumer.input:
                    return False
            return True

        state_variables = context.loop_properties.get_variables(checker)
        if len(state_variables) != 1:
            log.debug("found %d state variables", len(state_variables))
            return None
        return state_variables[0]

    def _get_loop_var_from_switch(self, switch_node):
        if switch_node.type != 'Switch':
            log.error("not a switch node, skip")
            return None

        # the first input is data
        merge_node = switch_node.inputs[0]
        if merge_node.type != "Merge":
            log.error("switch node does not has Merge as its first input")
            return None

        # find the output_true consumers
        switch_consumers = self.g.find_output_consumers(switch_node.output[1])
        switch_true_consumer_cnt = len(switch_consumers)
        if switch_true_consumer_cnt == 0:
            switch_true_identity_output = None
        elif switch_true_consumer_cnt == 1:
            if switch_consumers[0].type != "Identity":
                raise ValueError("switch has consumer that is not Identity")
            switch_true_identity_output = switch_consumers[0].output[0]
        else:
            raise ValueError("switch_true " + switch_node.name + " has unexpected count of consumers:",
                             [n.name for n in switch_consumers])

        enter_node = [n for n in merge_node.inputs if n.type == 'Enter'][0]
        target_node_input_id = enter_node.input[0]
        log.debug("found a Switch >> Merge >> Enter chain; Enter's input comes from %s",
                  enter_node.inputs[0].name)

        next_iteration_node = [n for n in merge_node.inputs if n.type == 'NextIteration'][0]
        last_iteration_output_id = next_iteration_node.input[0]

        # find the output_false consumers to see whether there is consumer for this var
        switch_false_consumers = self.g.find_output_consumers(switch_node.output[0])
        false_consumer_count = len(switch_false_consumers)
        exit_output_id = None
        if false_consumer_count == 1:
            exit_node = switch_false_consumers[0]
            if exit_node.type != "Exit":
                raise ValueError("switch false branch is followed by non-Exit")
            exit_output_id = exit_node.output[0]
        elif false_consumer_count == 0:
            # sometimes the variable's output isn't used as an input to the next iteration.
            exit_output_id = None
        else:
            raise ValueError("unexpected number of switch false consumers")

        is_ta = False
        ta_index_id = None
        if is_tensor_array_op(self.g.get_node_by_output(target_node_input_id)):
            is_ta = True

            ta_write_node = self.g.get_node_by_output(last_iteration_output_id)
            utils.make_sure(is_tensor_array_write_op(ta_write_node), "ta nextiteration is not following ta write op")
            last_iteration_output_id = ta_write_node.input[2]
            ta_index_id = ta_write_node.input[1]

            # here we parse patterns generated by
            # ta.write(), then ta.stack(), because this is the most frequent usage pattern.
            if exit_output_id:
                exit_consumers = self.g.find_output_consumers(exit_output_id)
                ta_gather_node = [n for n in exit_consumers if is_tensor_array_gather_op(n)][0]

                # update exit output id, treat the gather output as ta's output
                exit_output_id = ta_gather_node.output[0]

        loop_var = LoopVariable(enter_node.name, target_node_input_id, last_iteration_output_id,
                                switch_true_identity_output, exit_output_id, is_ta, ta_index_id, self.g)

        return loop_var
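_get_loop_var_from_switch walks the standard TensorFlow v1 while-loop wiring for one loop variable; Switch's output[0] is the false branch (loop exit) and output[1] the true branch (loop body). Roughly:

    Enter --> Merge --> Switch --> output[1] (true)  --> Identity --> cell body
                ^                  output[0] (false) --> Exit
                |
          NextIteration <-- cell body output

For TensorArray variables, the Exit output typically feeds a TensorArrayGather (the ta.write()/ta.stack() pattern mentioned in the comments above), which is why the gather's output is swapped in as the variable's exit output.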