def find_sequence_length_node(self, context):
    """Locate the sequence length node reachable from a state variable.

    The first entry of ``context.state_variables`` is used (any one is
    considered sufficient). The node that ``self.g.get_node_by_output``
    returns for that variable's ``next_iteration_input.id`` is inspected.

    Returns:
        The op registered as ``"seq_len_node"`` in the ``seq_len_pattern``
        match, or ``None`` when the inspected node is not a select op.

    Raises:
        RuntimeError: the inspected node is a select op but does not match
            ``seq_len_pattern``.
        IndexError: ``context.state_variables`` is empty.
    """
    # Any state variable will do, so just take the first one.
    any_state_variable = list(context.state_variables.values())[0]
    candidate = self.g.get_node_by_output(any_state_variable.next_iteration_input.id)

    if is_tf_select_op(candidate):
        match = GraphMatcher(seq_len_pattern).match_op(candidate)
        if match:
            return match.get_op("seq_len_node")
        raise RuntimeError("failed to find sequence length.")

    logger.debug("no sequence length node is given")
    return None
Ejemplo n.º 2
0
    def _find_state_variable_with_select(self, context,
                                         next_iteration_input,
                                         switch_true_identity_consumers):
        """
        Look up the one state variable tying switch_true_identity_consumers to next_iteration_input.

        A Select op may have been added after next_iteration_input. When exactly one
        Select consumer of next_iteration_input has no TensorArrayWrite consumer on
        its first output, that Select's first output is used in place of
        next_iteration_input, and the Select is appended to
        switch_true_identity_consumers (the caller's list is modified in place).

        Returns the matching state variable, or None unless exactly one variable matches.
        """
        def feeds_tensor_array_write(node):
            # True when any consumer of the node's first output is a TensorArrayWrite.
            return any(
                is_tf_tensor_array_write_op(user)
                for user in self.g.find_output_consumers(node.output[0])
            )

        # Select consumers that are not followed by a TensorArrayWrite.
        candidate_selects = [
            node
            for node in self.g.find_output_consumers(next_iteration_input)
            if is_tf_select_op(node) and not feeds_tensor_array_write(node)
        ]
        if len(candidate_selects) == 1:
            only_select = candidate_selects[0]
            next_iteration_input = only_select.output[0]
            switch_true_identity_consumers.append(only_select)

        logger.debug(
            "try to find state variable from [%s, %s]",
            next_iteration_input,
            switch_true_identity_consumers
        )

        def is_target(state_variable):
            # The variable must end at next_iteration_input and its switch-true
            # identity output must be an input of every listed consumer.
            return (
                state_variable.next_iteration_input.id == next_iteration_input
                and all(
                    state_variable.switch_true_identity_output.id in consumer.input
                    for consumer in switch_true_identity_consumers
                )
            )

        matches = context.loop_properties.get_variables(is_target)
        if len(matches) == 1:
            return matches[0]
        logger.debug("found %d state variables", len(matches))
        return None