コード例 #1
0
    def _get_loop_var_from_switch(self, switch_node):
        """Build a LoopVariable from the Switch node of a while-loop variable.

        Expected pattern around ``switch_node``:
          * input 0 of the Switch is a Merge node whose inputs include an
            Enter node and a NextIteration node;
          * ``output[1]`` (true branch) is consumed by exactly one Identity;
          * ``output[0]`` (false branch) is consumed by zero or one Exit.

        Args:
            switch_node: candidate Switch node.

        Returns:
            The LoopVariable after ``_tune_shape_for_loop_var`` and
            ``_tune_shape_for_loop_ta_var`` have been applied, or None when
            ``switch_node`` is not a Switch or its first input is not a Merge.

        Raises:
            ValueError: if the nodes around the Switch do not match the
                expected pattern (including a Merge with no Enter or no
                NextIteration input, which previously surfaced as IndexError).
        """
        if switch_node.type != 'Switch':
            log.error("not a switch node, skip")
            return None

        # the first input is data
        merge_node = switch_node.inputs[0]
        if merge_node.type != "Merge":
            log.error("switch node does not has Merge as its first input")
            return None

        # find the output_true consumers: exactly one Identity is expected
        switch_consumers = self.g.find_output_consumers(switch_node.output[1])
        if len(switch_consumers) != 1:
            raise ValueError("switch has non-1 consumers")

        identity_node = switch_consumers[0]
        if identity_node.type != "Identity":
            raise ValueError("switch has consumer that is not Identity")

        # The Enter input of the Merge supplies the variable's outer input id.
        enter_node = next(
            (n for n in merge_node.inputs if n.type == 'Enter'), None)
        if enter_node is None:
            raise ValueError("merge node " + merge_node.name + " has no Enter input")
        target_node_input_id = enter_node.input[0]
        log.debug("a Switch >> Merge >> Enter is found called %s",
                  enter_node.inputs[0].name)

        # The NextIteration input of the Merge carries the last-iteration value.
        next_iteration_node = next(
            (n for n in merge_node.inputs if n.type == 'NextIteration'), None)
        if next_iteration_node is None:
            raise ValueError("merge node " + merge_node.name + " has no NextIteration input")
        last_iteration_output_id = next_iteration_node.input[0]

        # find the output_false consumers to see whether there is consumer for this var
        switch_false_consumers = self.g.find_output_consumers(
            switch_node.output[0])
        if len(switch_false_consumers) > 1:
            raise ValueError("unexpected number of switch false consumers")

        exit_output_id = None
        if switch_false_consumers:
            exit_node = switch_false_consumers[0]
            if exit_node.type != "Exit":
                raise ValueError("switch false branch is followed by non-Exit")
            exit_output_id = exit_node.output[0]

        # bool() keeps the flag a plain boolean as in the original if/True form.
        is_ta = bool(is_tensor_array_op(
            self.g.get_node_by_output(target_node_input_id)))

        loop_var = LoopVariable(enter_node.name, target_node_input_id,
                                last_iteration_output_id,
                                identity_node.output[0], exit_output_id, is_ta)
        loop_var = self._tune_shape_for_loop_var(loop_var)
        loop_var = self._tune_shape_for_loop_ta_var(loop_var)
        return loop_var
コード例 #2
0
 def _output_switch_check(self, enter_target_node_input_id, identity_consumers, match):
     """Decide whether a switch is a TensorArray output switch.

     Args:
         enter_target_node_input_id: output id resolved through
             ``self.g.get_node_by_output`` to find the Enter target node.
         identity_consumers: nodes to scan for TensorArray write ops
             (presumably the consumers of the switch's Identity; confirm
             against callers).
         match: not used by this check.

     Returns:
         ``enter_target_node_input_id`` when exactly one consumer is a
         TensorArray write op AND the target node is a TensorArray op;
         otherwise None.
     """
     write_count = sum(1 for consumer in identity_consumers if is_tensor_array_write_op(consumer))
     if write_count != 1:
         log.debug("%d TensorArrayWriteV3 matching found, cannot validate output switch", write_count)
         return None

     target = self.g.get_node_by_output(enter_target_node_input_id)
     if not is_tensor_array_op(target):
         log.debug("found enter target node is not ta node")
         return None

     log.debug("found output switch node")
     return enter_target_node_input_id
コード例 #3
0
    def _get_loop_var_from_switch(self, switch_node):
        """Build a LoopVariable from the Switch node of a while-loop variable.

        Expected pattern around ``switch_node``:
          * input 0 of the Switch is a Merge node whose inputs include an
            Enter node and a NextIteration node;
          * ``output[1]`` (true branch) has zero or one consumer, which must be
            an Identity;
          * ``output[0]`` (false branch) has zero or one consumer, which must
            be an Exit.

        When the Enter's input node is a TensorArray op, the NextIteration's
        input must be a TensorArray write op: the write's ``input[2]`` replaces
        the last-iteration output id and its ``input[1]`` becomes
        ``ta_index_id``. If an Exit exists, its TensorArray gather consumer's
        output is used as the variable's exit output.

        Args:
            switch_node: candidate Switch node.

        Returns:
            The LoopVariable, or None when ``switch_node`` is not a Switch or
            its first input is not a Merge.

        Raises:
            ValueError: if the nodes around the Switch do not match the
                expected pattern (missing Enter / NextIteration / gather
                previously surfaced as IndexError). The TensorArray write
                check is delegated to ``utils.make_sure``.
        """
        if switch_node.type != 'Switch':
            log.error("not a switch node, skip")
            return None

        # the first input is data
        merge_node = switch_node.inputs[0]
        if merge_node.type != "Merge":
            log.error("switch node does not has Merge as its first input")
            return None

        # find the output_true consumers: none, or a single Identity
        switch_consumers = self.g.find_output_consumers(switch_node.output[1])
        switch_true_consumer_cnt = len(switch_consumers)
        if switch_true_consumer_cnt == 0:
            switch_true_identity_output = None
        elif switch_true_consumer_cnt == 1:
            if switch_consumers[0].type != "Identity":
                raise ValueError("switch has consumer that is not Identity")
            switch_true_identity_output = switch_consumers[0].output[0]
        else:
            # Build one message string; passing two args to ValueError would
            # render the message as a tuple.
            raise ValueError("switch_true %s has unexpected count of consumers: %s"
                             % (switch_node.name, [n.name for n in switch_consumers]))

        enter_node = next((n for n in merge_node.inputs if n.type == 'Enter'), None)
        if enter_node is None:
            raise ValueError("merge node " + merge_node.name + " has no Enter input")
        target_node_input_id = enter_node.input[0]
        log.debug("a Switch >> Merge >> Enter is found called %s", enter_node.inputs[0].name)

        next_iteration_node = next((n for n in merge_node.inputs if n.type == 'NextIteration'), None)
        if next_iteration_node is None:
            raise ValueError("merge node " + merge_node.name + " has no NextIteration input")
        last_iteration_output_id = next_iteration_node.input[0]

        # find the output_false consumers to see whether there is consumer for this var
        switch_false_consumers = self.g.find_output_consumers(switch_node.output[0])
        false_consumer_count = len(switch_false_consumers)
        exit_output_id = None
        if false_consumer_count == 1:
            exit_node = switch_false_consumers[0]
            if exit_node.type != "Exit":
                raise ValueError("switch false branch is followed by non-Exit")
            exit_output_id = exit_node.output[0]
        elif false_consumer_count > 1:
            raise ValueError("unexpected number of switch false consumers")
        # else: sometime, the variable output won't be used in the new iteration as input.

        is_ta = False
        ta_index_id = None
        if is_tensor_array_op(self.g.get_node_by_output(target_node_input_id)):
            is_ta = True

            ta_write_node = self.g.get_node_by_output(last_iteration_output_id)
            utils.make_sure(is_tensor_array_write_op(ta_write_node), "ta nextiteration is not following ta write op")
            last_iteration_output_id = ta_write_node.input[2]
            ta_index_id = ta_write_node.input[1]

            # here we parse patterns generated by
            # ta.write(), then ta.stack(), because this is the most frequent usage pattern.
            if exit_output_id:
                exit_consumers = self.g.find_output_consumers(exit_output_id)
                ta_gather_node = next((n for n in exit_consumers if is_tensor_array_gather_op(n)), None)
                if ta_gather_node is None:
                    raise ValueError("ta exit output %s is not consumed by a TensorArray gather op"
                                     % (exit_output_id,))

                # update exit output id, treat the gather output as ta's output
                exit_output_id = ta_gather_node.output[0]

        loop_var = LoopVariable(enter_node.name, target_node_input_id, last_iteration_output_id,
                                switch_true_identity_output, exit_output_id, is_ta, ta_index_id, self.g)

        return loop_var