Example #1
0
    def _connect_lstm_output_to_graph(self, lstm_node, exit_node, rnn_props):
        """Rewire consumers of the output TensorArray gather to the ONNX LSTM output.

        A Squeeze(axes=[1]) node is built on ``lstm_node.output[0]``.
        When ``rnn_props.time_major`` is false, a new Transpose with perm
        [1, 0, 2] is also placed after it. Consumers of the old transpose's
        output are moved to the new Transpose. Consumers of the gather output
        are moved to the Squeeze. Both new nodes are appended to
        ``self.all_nodes``.

        Args:
            lstm_node: ONNX LSTM node whose first output is wired in.
            exit_node: loop exit node whose ``output[0]`` should feed the gather.
            rnn_props: RNN properties; only ``time_major`` is read here.

        Raises:
            ValueError: if the exit node does not have exactly two consumers,
                or ``_validate_output_exit_consumers`` yields no gather node;
                or, in batch-major mode, if the gather does not have exactly
                one consumer that passes ``check_is_timemajor_transpose``.
        """
        consumers_of_exit = self.g.find_output_consumers(exit_node.output[0])
        ta_gather = self._validate_output_exit_consumers(consumers_of_exit)
        consumer_count = len(consumers_of_exit)
        if not (consumer_count == 2 and ta_gather):
            log.debug("lstm output exit node has %d consumers",
                      consumer_count)
            raise ValueError("lstm output exit node check failed")

        gather_out = ta_gather.output[0]
        log.debug("found output ta gather node %s", gather_out)
        # Layouts involved:
        #   TF batch-major output : [batch, time, hidden]
        #   TF time-major output  : [time, batch, hidden]
        #   ONNX LSTM Y output    : [time, num_directions, batch, hidden]
        # The gather output is always [time, batch, hidden].

        lstm_y = lstm_node.output[0]
        y_shape = self.g.get_shape(lstm_y)
        time_dim, batch_dim, hidden_dim = y_shape[0], y_shape[2], y_shape[3]

        # Drop axis 1 (num_directions). NOTE(review): Squeeze on this axis only
        # succeeds when num_directions == 1; confirm bidirectional LSTMs never
        # reach this path.
        squeezed = self.g.make_node("Squeeze", [lstm_y],
                                    attr={"axes": [1]},
                                    shapes=[[time_dim, batch_dim, hidden_dim]],
                                    dtypes=[self.g.get_dtype(lstm_y)])

        if not rnn_props.time_major:
            transposes = [
                consumer
                for consumer in self.g.find_output_consumers(gather_out)
                if check_is_timemajor_transpose(consumer)
            ]
            if len(transposes) != 1:
                raise ValueError(
                    "batch major should expect a transpose after gather")
            old_transpose = transposes[0]  # the matched transpose carries the rnn scope name

            # The matched transpose is only a pattern check. Its perm input may
            # be non-constant, so a fresh Transpose with a constant perm is
            # created and takes over its consumers.
            fresh_transpose = self.g.make_node(
                "Transpose", [squeezed.output[0]],
                {"perm": np.array([1, 0, 2], dtype=np.int64)},
                shapes=[[batch_dim, time_dim, hidden_dim]],
                dtypes=[self.g.get_dtype(squeezed.output[0])])
            self.g.replace_all_inputs(self.all_nodes, old_transpose.output[0],
                                      fresh_transpose.output[0])
            self.all_nodes.extend([fresh_transpose])

        self.g.replace_all_inputs(self.all_nodes, gather_out,
                                  squeezed.output[0])
        self.all_nodes.extend([squeezed])
    def _convert_timemajor_transpose(self, node):
        """Replace a time-major transpose with a Transpose using a constant perm.

        Args:
            node: candidate node, tested with ``check_is_timemajor_transpose``.

        Returns:
            ``None`` if ``node`` does not pass the check. Otherwise, the new
            Transpose node (perm [1, 0, 2]) created by ``make_onnx_node`` on
            ``node.input[0]``. In that case, the shape of ``node.output[0]`` is
            copied onto the new output, and every consumer in
            ``self.g.get_nodes()`` is rewired from ``node.output[0]`` to the
            new output.
        """
        if check_is_timemajor_transpose(node):
            log.debug("found timemajor transpose")
            # NOTE(review): whether make_onnx_node registers the node in self.g
            # is not visible here; confirm the caller adds the returned node.
            replacement = make_onnx_node(
                self.g, "Transpose", [node.input[0]],
                {"perm": np.array([1, 0, 2], dtype=np.int64)})
            old_out, new_out = node.output[0], replacement.output[0]
            # Only the shape is copied here. NOTE(review): the dtype of new_out
            # is not set in this method; confirm it is filled in elsewhere.
            self.g.copy_shape(old_out, new_out)
            self.g.replace_all_inputs(self.g.get_nodes(), old_out, new_out)
            return replacement

        log.debug("not found timemajor transpose")
        return None