def convert_reshape(self, op):
    """Convert a TFLite RESHAPE operator into a Relay reshape.

    The target shape is read from the operator's ``ReshapeOptions``
    (``NewShapeAsNumpy``). The first input tensor's expression is reshaped to
    that shape with no layout transposes.

    Args:
        op: The TFLite ``Operator`` describing the reshape.

    Returns:
        The Relay expression produced by ``_op.reshape``.

    Raises:
        ImportError: If the ``tflite`` package cannot be imported. The original
            import error is chained as ``__cause__``.
    """
    try:
        from tflite.BuiltinOptions import BuiltinOptions
        from tflite.Operator import Operator
        from tflite.ReshapeOptions import ReshapeOptions
    except ImportError as err:
        # Chain explicitly so the underlying import failure stays visible.
        raise ImportError("The tflite package must be installed") from err

    assert isinstance(op, Operator)
    input_tensors = self.get_input_tensors(op)
    # NOTE(review): two inputs are required, but only the first is used below.
    # The second (presumably a shape tensor) is ignored in favour of the options.
    assert len(input_tensors) == 2, "input tensors length should be 2"

    input_tensor = input_tensors[0]
    input_tensor_idx = input_tensor.tensor_idx

    # The target shape comes from the ReshapeOptions table attached to the op.
    assert op.BuiltinOptionsType() == BuiltinOptions.ReshapeOptions
    op_options = op.BuiltinOptions()
    reshape_options = ReshapeOptions()
    reshape_options.Init(op_options.Bytes, op_options.Pos)
    target_shape = reshape_options.NewShapeAsNumpy()

    in_expr = self.get_expr(input_tensor_idx)
    out = _op.reshape(in_expr, newshape=tuple(target_shape))
    return out
def convert_reshape(self, op):
    """Convert a TFLite RESHAPE operator into a Relay reshape with layout fixes.

    The input expression is treated as channel-first (channel right after N).
    It is transposed to channel-last before the reshape, because the TFLite
    target shape is expressed in that order. The result is then transposed
    back to channel-first according to the target rank.

    Args:
        op: The TFLite ``Operator`` describing the reshape.

    Returns:
        The Relay expression for the reshaped (and possibly transposed) tensor.

    Raises:
        ImportError: If the ``tflite`` package cannot be imported. The original
            import error is chained as ``__cause__``.
        tvm.error.OpAttributeInvalid: If the input rank or the target rank is
            not in 1..4.
    """
    try:
        from tflite.BuiltinOptions import BuiltinOptions
        from tflite.Operator import Operator
        from tflite.ReshapeOptions import ReshapeOptions
    except ImportError as err:
        # Chain explicitly so the underlying import failure stays visible.
        raise ImportError("The tflite package must be installed") from err

    assert isinstance(op, Operator)
    input_tensors = self.get_input_tensors(op)
    assert len(input_tensors) == 2, "input tensors length should be 2"

    input_tensor = input_tensors[0]
    input_tensor_idx = input_tensor.tensor_idx

    # The target shape comes from the ReshapeOptions table attached to the op.
    assert op.BuiltinOptionsType() == BuiltinOptions.ReshapeOptions
    op_options = op.BuiltinOptions()
    reshape_options = ReshapeOptions()
    reshape_options.Init(op_options.Bytes, op_options.Pos)
    target_shape = reshape_options.NewShapeAsNumpy()

    input_shape_length = len(input_tensor.tensor.ShapeAsNumpy())
    target_shape_length = len(target_shape)
    in_expr = self.get_expr(input_tensor_idx)

    # Move the channel axis to the end before reshaping.
    #   1: N*H*W*C      -> unchanged
    #   2: N*H*W, C     -> unchanged
    #   3: N C H*W      -> N H*W C
    #   4: N C H W      -> N H W C
    if input_shape_length in (1, 2):
        pass
    elif input_shape_length == 3:
        in_expr = _op.transpose(in_expr, axes=(0, 2, 1))
    elif input_shape_length == 4:
        in_expr = _op.transpose(in_expr, axes=(0, 2, 3, 1))
    else:
        msg = 'Input shape length {} for operator Reshape is not valid.'
        raise tvm.error.OpAttributeInvalid(msg.format(input_shape_length))

    out = _op.reshape(in_expr, newshape=tuple(target_shape))

    # Move the channel axis back to position 1 after reshaping.
    #   1, 2: unchanged
    #   3: N H*W C  -> N C H*W
    #   4: N H W C  -> N C H W
    if target_shape_length in (1, 2):
        pass
    elif target_shape_length == 3:
        out = _op.transpose(out, axes=(0, 2, 1))
    elif target_shape_length == 4:
        out = _op.transpose(out, axes=(0, 3, 1, 2))
    else:
        # Only ranks 1..4 are handled above; the message states that range.
        raise tvm.error.OpAttributeInvalid(
            'Length of target shape must be between 1 and 4 for operator Reshape.'
        )
    return out
class Reshape(Layer):
    """Convert a TFLite RESHAPE operator into ONNX nodes.

    ``generate`` always emits ``Transpose -> Constant(shape) -> Reshape``. The
    leading Transpose moves axis 1 to the last position. When the TFLite
    ``ReshapeOptions`` carry an explicit ``new_shape``, a trailing Transpose
    moves the last axis back to position 1. Downstream nodes are then rewired
    to consume that trailing Transpose.
    """

    def __init__(self, op, op_type, tflite_interpreter):
        Layer.__init__(self, op, op_type, tflite_interpreter)

        builtin_options = op.BuiltinOptions()
        if builtin_options is None:
            # No ReshapeOptions table: generate() falls back to the output shape.
            self.tflite_reshape_parser = None
        else:
            self.tflite_reshape_parser = ReshapeOptions()
            self.tflite_reshape_parser.Init(builtin_options.Bytes, builtin_options.Pos)

    @staticmethod
    def _axis1_to_last_perm(rank):
        """Return the permutation that moves axis 1 to the last position."""
        axes = list(range(rank))
        return axes[:1] + axes[2:] + axes[1:2]

    @staticmethod
    def _last_to_axis1_perm(rank):
        """Return the permutation that moves the last axis to position 1."""
        axes = list(range(rank))
        if rank < 2:
            # For rank 1 the slice formula gives [0, 0], which is not a valid
            # permutation; for rank 0 and 1 the identity is the right answer.
            return axes
        return axes[:1] + axes[-1:] + axes[1:-1]

    def generate(self):
        """Emit the ONNX nodes for this reshape.

        Returns:
            The tuple ``(node_list, value_infos, weight_node_list)`` held by
            this layer, after the new nodes have been appended to
            ``node_list``.
        """
        node_output_detail = self.tflite_interpreter._get_tensor_details(
            self.op.Outputs(0))
        node_input_detail = self.tflite_interpreter._get_tensor_details(
            self.op.Inputs(0))

        out_dim = node_output_detail['shape']
        in_dim = node_input_detail['shape']

        # Leading transpose: move axis 1 to the end before reshaping.
        transpose_before_node_name = 'transpose_node_before_reshape_' + self.node_name
        transpose_before_node = onnx.helper.make_node(
            'Transpose',
            inputs=self.input_nodes_name,
            outputs=[transpose_before_node_name],
            perm=self._axis1_to_last_perm(len(in_dim)),
            name=transpose_before_node_name)
        self.node_list.append(transpose_before_node)

        reshape_node_name = self.node_name
        shape_tensor_name = 'shape_tensor_' + self.node_name
        shape_node_name = 'shape_const_' + self.node_name

        parser = self.tflite_reshape_parser
        has_new_shape = parser is not None and not parser.NewShapeIsNone()
        if has_new_shape:
            new_shape = np.array(parser.NewShapeAsNumpy(), dtype='int64')
        else:
            # No 'new_shape' attribute: either the options table is missing or
            # its new_shape vector is absent. Use the static output shape.
            # Previously the absent-vector case left new_shape unassigned.
            # NOTE(review): this mirrors the missing-options path ("should be
            # op 'squeeze'"), which also skips the trailing transpose.
            new_shape = np.array(out_dim, dtype='int64')

        shape_tensor = onnx.helper.make_tensor(
            shape_tensor_name, TensorProto.INT64, new_shape.shape, new_shape)
        shape_node = onnx.helper.make_node(
            "Constant", [], [shape_node_name], name=shape_node_name, value=shape_tensor)
        reshape_node = onnx.helper.make_node(
            'Reshape',
            inputs=[transpose_before_node_name, shape_node_name],
            outputs=[reshape_node_name],
            name=reshape_node_name)
        self.node_list.append(shape_node)
        self.node_list.append(reshape_node)

        # Trailing transpose only when an explicit new_shape was provided.
        # The original guard referenced NewShapeIsNone without calling it, so
        # it was always truthy.
        if has_new_shape:
            transpose_after_node_name = 'transpose_node_after_reshape_' + self.node_name
            transpose_after_node = onnx.helper.make_node(
                'Transpose',
                inputs=[reshape_node_name],
                outputs=[transpose_after_node_name],
                perm=self._last_to_axis1_perm(len(out_dim)),
                name=transpose_after_node_name)
            self.node_list.append(transpose_after_node)

            # Downstream nodes reference this layer by its node name; redirect
            # them to the trailing transpose's output.
            for out_node in self.output_nodes:
                for idx, in_name in enumerate(out_node.input_nodes_name):
                    if in_name == self.node_name:
                        out_node.input_nodes_name[idx] = transpose_after_node_name

        return self.node_list, self.value_infos, self.weight_node_list