Beispiel #1
0
    def _validate_output_exit_consumers(exit_consumers):
        if len(exit_consumers) != 2:
            return None

        gather_node = None
        for n in exit_consumers:
            if is_tensor_array_gather_op(n):
                gather_node = n
            elif is_tensor_array_size_op(n):
                continue
            else:
                return None
        return gather_node
Beispiel #2
0
    def _parse_output_ta(self, context):
        for enter_name, loop_var in context.loop_variables.items():
            if not loop_var.is_tensor_array:
                continue

            output_ta = TensorArrayProp()
            output_ta.data_input_id = loop_var.next_iteration_input_id

            output_ta.index_input_id = loop_var.ta_index_id
            if loop_var.exit_output_id:
                exit_consumers = self.g.find_output_consumers(loop_var.exit_output_id)
                ta_gather_node = [n for n in exit_consumers if is_tensor_array_gather_op(n)][0]
                output_ta.output_id = ta_gather_node.output[0]

            context.output_tas.append(output_ta)
            log.debug("output ta %s - data input (%s) shape: %s, output (%s) shape: %s", enter_name,
                      output_ta.data_input_id, self.g.get_shape(output_ta.data_input_id),
                      output_ta.output_id, self.g.get_shape(output_ta.output_id))
    def _get_loop_var_from_switch(self, switch_node):
        if switch_node.type != 'Switch':
            log.error("not a switch node, skip")
            return None

        # the first input is data
        merge_node = switch_node.inputs[0]
        if merge_node.type != "Merge":
            log.error("switch node does not has Merge as its first input")
            return None

        # find the output_true consumers
        switch_consumers = self.g.find_output_consumers(switch_node.output[1])
        switch_true_consumer_cnt = len(switch_consumers)
        if switch_true_consumer_cnt == 0:
            switch_true_identity_output = None
        elif switch_true_consumer_cnt == 1:
            if switch_consumers[0].type != "Identity":
                raise ValueError("switch has consumer that is not Identity")
            switch_true_identity_output = switch_consumers[0].output[0]
        else:
            raise ValueError("switch_true " + switch_node.name + " has unexpected count of consumers:",
                             [n.name for n in switch_consumers])

        target_node_input_id = None
        enter_node = [n for n in merge_node.inputs if n.type == 'Enter'][0]
        target_node_input_id = enter_node.input[0]
        log.debug("a Switch >> Merge >> Enter is found called %s", enter_node.inputs[0].name)

        next_iteration_node = [n for n in merge_node.inputs if n.type == 'NextIteration'][0]
        last_iteration_output_id = next_iteration_node.input[0]

        # find the output_false consumers to see whether there is consumer for this var
        switch_false_consumers = self.g.find_output_consumers(switch_node.output[0])
        false_consumer_count = len(switch_false_consumers)
        exit_output_id = None
        if false_consumer_count == 1:
            exit_node = switch_false_consumers[0]
            if exit_node.type != "Exit":
                raise ValueError("switch false branch is followed by non-Exit")
            exit_output_id = exit_node.output[0]
        elif false_consumer_count == 0:
            # sometime, the variable output won't be used in the new iteration as input.
            exit_output_id = None
        else:
            raise ValueError("unexpected number of switch false consumers")

        is_ta = False
        ta_index_id = None
        if is_tensor_array_op(self.g.get_node_by_output(target_node_input_id)):
            is_ta = True

            ta_write_node = self.g.get_node_by_output(last_iteration_output_id)
            utils.make_sure(is_tensor_array_write_op(ta_write_node), "ta nextiteration is not following ta write op")
            last_iteration_output_id = ta_write_node.input[2]
            ta_index_id = ta_write_node.input[1]

            # here we parse patterns generated by
            # ta.write(), then ta.stack(), because this is the most frequent usage pattern.
            if exit_output_id:
                exit_consumers = self.g.find_output_consumers(exit_output_id)
                ta_gather_node = [n for n in exit_consumers if is_tensor_array_gather_op(n)][0]

                # update exit output id, treat the gather output as ta's output
                exit_output_id = ta_gather_node.output[0]

        loop_var = LoopVariable(enter_node.name, target_node_input_id, last_iteration_output_id,
                                switch_true_identity_output, exit_output_id, is_ta, ta_index_id, self.g)

        return loop_var