def parse_options(self):
    """Decode the operator's ResizeBilinearOptions flatbuffer table.

    Side effects on ``self``:
        align_corners: value of the table's AlignCorners field.
        _flatbuf_options_obj: the raw builtin-options table object.
        _supported_options: names of the options this parser extracts.
    """
    raw_table = self._flatbuf_op.BuiltinOptions()
    buf, pos = raw_table.Bytes, raw_table.Pos

    bilinear = ResizeBilinearOptions()
    bilinear.Init(buf, pos)

    self.align_corners = bilinear.AlignCorners()
    self._flatbuf_options_obj = raw_table
    self._supported_options = ['align_corners']
def _convert_resize(self, method, op):
    """Generic method to Convert TFLite RESIZE operators.

    Args:
        method: resize method name forwarded to ``_op.image.resize``.
            "BILINEAR" reads ResizeBilinearOptions; any other value reads
            ResizeNearestNeighborOptions when the installed tflite package
            defines that table, otherwise no options are read.
        op: the tflite ``Operator`` being converted; it must have exactly
            two inputs (the image and the target size).

    Returns:
        The expression built by ``_op.image.resize`` with layout "NHWC".

    Raises:
        ImportError: if the tflite package cannot be imported.
    """
    try:
        from tflite.BuiltinOptions import BuiltinOptions
        from tflite.Operator import Operator
        from tflite.ResizeBilinearOptions import ResizeBilinearOptions
        # The ResizeNearestNeighborOptions table was introduced in tflite
        # v1.13; older packages expose no options for nearest-neighbor.
        nearest_has_options = 'ResizeNearestNeighborOptions' in dir(BuiltinOptions)
        if nearest_has_options:
            from tflite.ResizeNearestNeighborOptions import ResizeNearestNeighborOptions
    except ImportError:
        raise ImportError("The tflite package must be installed")

    assert isinstance(op, Operator)
    tensors = self.get_input_tensors(op)
    assert len(tensors) == 2, "input tensors length should be 2"

    # Input 0: the image batch (NHWC). Input 1: int32 [new_height, new_width].
    image_expr = self.get_expr(tensors[0].tensor_idx)
    new_size = tuple(self.get_tensor_value(tensors[1]))

    # Select the flatbuffer table that carries align_corners for this method.
    if method == "BILINEAR":
        assert op.BuiltinOptionsType() == BuiltinOptions.ResizeBilinearOptions
        options_table = ResizeBilinearOptions()
    elif nearest_has_options:
        assert op.BuiltinOptionsType() == BuiltinOptions.ResizeNearestNeighborOptions
        options_table = ResizeNearestNeighborOptions()
    else:
        options_table = None

    align_corners = False
    if options_table is not None:
        raw_table = op.BuiltinOptions()
        options_table.Init(raw_table.Bytes, raw_table.Pos)
        align_corners = options_table.AlignCorners()

    # The data is kept in TFLite's NHWC layout.
    return _op.image.resize(image_expr, new_size, "NHWC", method, align_corners)
class ResizeBilinear(Layer):
    """Convert a TFLite RESIZE_BILINEAR operator into an ONNX ``Resize`` node."""

    def __init__(self, op, op_type, tflite_interpreter):
        Layer.__init__(self, op, op_type, tflite_interpreter)

        # Decode the operator's ResizeBilinearOptions table once.
        builtin_options = self.op.BuiltinOptions()
        self.tflite_resize_bilinear_parser = ResizeBilinearOptions()
        self.tflite_resize_bilinear_parser.Init(builtin_options.Bytes,
                                                builtin_options.Pos)

    def _coordinate_transformation_mode(self):
        """Map TFLite's align_corners / half_pixel_centers flags to an ONNX mode.

        TFLite resize samples at ``in = out * scale`` by default (ONNX
        'asymmetric'). It uses corner-aligned sampling when align_corners is
        set, and pixel-center sampling when half_pixel_centers is set.
        """
        parser = self.tflite_resize_bilinear_parser
        if parser.AlignCorners():
            return 'align_corners'
        half_pixel_centers = getattr(parser, 'HalfPixelCenters', None)
        if half_pixel_centers is None:
            # Older tflite packages do not expose half_pixel_centers; keep
            # this converter's historical default in that case.
            return 'half_pixel'
        return 'half_pixel' if half_pixel_centers() else 'asymmetric'

    def generate(self):
        """Build the ONNX nodes implementing this RESIZE_BILINEAR operator.

        Appends a scales constant, an roi constant and the Resize node to
        ``self.node_list``, and their value infos to ``self.value_infos``.

        Returns:
            (node_list, value_infos, weight_node_list)
        """
        interpreter = self.tflite_interpreter
        node_output_detail = interpreter._get_tensor_details(self.op.Outputs(0))
        node_input_detail = interpreter._get_tensor_details(self.op.Inputs(0))
        size_tensor_detail = interpreter._get_tensor_details(self.op.Inputs(1))

        # TFLite tensors are NHWC, so shape[1:3] is (height, width).
        source_height, source_width = node_input_detail['shape'].tolist()[1:3]
        # The size input of RESIZE_BILINEAR holds [new_height, new_width].
        target_height, target_width = interpreter.get_tensor(
            size_tensor_detail['index']).tolist()

        # ONNX Resize operates on NCHW, so scales are ordered (N, C, H, W).
        source_size = np.array(
            [1.0, 1.0, source_height, source_width], dtype=np.float64)
        target_size = np.array(
            [1.0, 1.0, target_height, target_width], dtype=np.float64)
        # Stored as float32 to match the TensorProto.FLOAT value info below
        # (the former np.float alias was removed in NumPy 1.24).
        scale_val = (target_size / source_size).astype(np.float32)

        scale_constant_node = tflite_utils.create_constant_node(
            self.node_name + '_scales', scale_val.shape, scale_val)
        constant_info = onnx.helper.make_tensor_value_info(
            name=scale_constant_node.name,
            elem_type=TensorProto.FLOAT,
            shape=scale_val.shape)
        self.node_list.append(scale_constant_node)
        self.value_infos.append(constant_info)

        # Resize takes roi positionally before scales; per the ONNX spec it
        # only takes effect in 'tf_crop_and_resize' mode.
        roi_constant_node = tflite_utils.create_constant_node(
            self.node_name + 'resize_roi', [], np.array([-1], dtype=np.float32))
        self.node_list.append(roi_constant_node)

        resize_inputs = self.input_nodes_name.copy()
        resize_inputs.extend([roi_constant_node.name, scale_constant_node.name])

        resize_node = onnx.helper.make_node(
            op_type='Resize',
            inputs=resize_inputs,
            outputs=[self.node_name],
            name=self.node_name,
            mode='linear',
            coordinate_transformation_mode=self._coordinate_transformation_mode())
        resize_info = onnx.helper.make_tensor_value_info(
            name=self.node_name,
            elem_type=TensorProto.FLOAT,
            shape=tflite_utils.tflite2onnx_shape_map(
                node_output_detail['shape'].tolist()))

        # update tables
        self.node_list.append(resize_node)
        self.value_infos.append(resize_info)

        return self.node_list, self.value_infos, self.weight_node_list