def _connect_lstm_output_to_graph(self, lstm_node, exit_node, rnn_props):
    """Route the ONNX LSTM output to the consumers of the TF output TensorArray.

    The consumers of ``exit_node.output[0]`` must be exactly two nodes, one of
    which is a TensorArray gather (as located by
    ``self._validate_output_exit_consumers``). A Squeeze(axes=[1]) is built on
    ``lstm_node.output[0]`` and its output replaces the gather's output for all
    nodes in ``self.all_nodes``. When ``rnn_props.time_major`` is false, the
    single time-major transpose that consumes the gather output is superseded
    by a newly created Transpose(perm=[1, 0, 2]) fed from the Squeeze.

    Args:
        lstm_node: node whose ``output[0]`` is the LSTM result to wire in.
        exit_node: loop exit node whose ``output[0]`` feeds the output gather.
        rnn_props: only its ``time_major`` attribute is read here.

    Raises:
        ValueError: if the exit node check fails (consumer count is not 2 or
            no gather node is found), or if in batch-major mode the gather
            output is not consumed by exactly one time-major transpose.
    """
    consumers_of_exit = self.g.find_output_consumers(exit_node.output[0])
    ta_gather = self._validate_output_exit_consumers(consumers_of_exit)
    if not (len(consumers_of_exit) == 2 and ta_gather):
        log.debug("lstm output exit node has %d consumers", len(consumers_of_exit))
        raise ValueError("lstm output exit node check failed")

    # The gather output is always laid out as [time, batch, hidden].
    gathered = ta_gather.output[0]
    log.debug("found output ta gather node %s", gathered)

    # Layouts involved:
    #   TF batch-major output: [batch, time, hidden]
    #   TF time-major output:  [time, batch, hidden]
    #   ONNX LSTM output[0]:   [time, num_directions, batch, hidden]
    # Dropping axis 1 (num_directions) yields the TF time-major layout.
    lstm_y = lstm_node.output[0]
    y_shape = self.g.get_shape(lstm_y)
    squeezed_shape = [y_shape[i] for i in (0, 2, 3)]
    y_dtype = self.g.get_dtype(lstm_y)
    squeeze = self.g.make_node("Squeeze", [lstm_y], attr={"axes": [1]},
                               shapes=[squeezed_shape], dtypes=[y_dtype])

    if not rnn_props.time_major:
        transposes = [
            candidate
            for candidate in self.g.find_output_consumers(gathered)
            if check_is_timemajor_transpose(candidate)
        ]
        if len(transposes) != 1:
            raise ValueError(
                "batch major should expect a transpose after gather")
        old_transpose = transposes[0]
        # The matched transpose lives in the RNN scope and its perm may not be
        # a constant, so it only serves as a marker. A new Transpose with a
        # constant perm takes over its consumers instead of reusing it.
        perm_attr = {"perm": np.array([1, 0, 2], dtype=np.int64)}
        swapped_shape = [squeezed_shape[1], squeezed_shape[0], squeezed_shape[2]]
        squeezed_out = squeeze.output[0]
        replacement = self.g.make_node(
            "Transpose",
            [squeezed_out],
            perm_attr,
            shapes=[swapped_shape],
            dtypes=[self.g.get_dtype(squeezed_out)])
        self.g.replace_all_inputs(self.all_nodes, old_transpose.output[0],
                                  replacement.output[0])
        self.all_nodes.extend([replacement])

    # Every remaining reader of the gather output now reads the Squeeze.
    self.g.replace_all_inputs(self.all_nodes, gathered, squeeze.output[0])
    self.all_nodes.extend([squeeze])
def _convert_timemajor_transpose(self, node):
    """Swap a time-major transpose for a new Transpose with a constant perm.

    If ``node`` passes ``check_is_timemajor_transpose``, this does four things:
      - Builds a Transpose(perm=[1, 0, 2]) over ``node.input[0]`` via
        ``make_onnx_node``.
      - Copies the shape of ``node.output[0]`` onto the new node's output.
      - Redirects every input in ``self.g.get_nodes()`` that referenced
        ``node.output[0]`` to the new output.
      - Returns the new node.

    Otherwise it logs and returns None.

    NOTE(review): whether ``make_onnx_node`` registers the node with the graph
    is not visible here; presumably the caller takes care of that.

    Args:
        node: candidate transpose node.

    Returns:
        The newly created Transpose node, or None if ``node`` is not a
        time-major transpose.
    """
    if check_is_timemajor_transpose(node):
        log.debug("found timemajor transpose")
        replacement = make_onnx_node(
            self.g,
            "Transpose",
            [node.input[0]],
            {"perm": np.array([1, 0, 2], dtype=np.int64)})
        old_out, new_out = node.output[0], replacement.output[0]
        self.g.copy_shape(old_out, new_out)
        # Rewire consumers across every node the graph currently reports.
        self.g.replace_all_inputs(self.g.get_nodes(), old_out, new_out)
        return replacement

    log.debug("not found timemajor transpose")
    return None