def make_ethosu_unary_elementwise(
    ifm,
    ofm_channels,
    operator_type,
    activation="NONE",
    ifm_layout="NHWC",
    ofm_layout="NHWC",
    rounding_mode="TFL",
):
    """Build an ``ethosu_unary_elementwise`` call with fixed quantization values.

    The IFM and OFM scales are set to 1 and both zero points to 0, and the
    look-up table is an empty int8 constant. When ``activation`` is
    ``"CLIP"`` the clip range is 10..100; for any other activation both clip
    bounds are 0.

    Parameters
    ----------
    ifm
        Input feature map expression forwarded to the operator.
    ofm_channels
        Channel count forwarded as ``ofm_channels``.
    operator_type
        Unary operator type forwarded unchanged.
    activation
        Activation name forwarded unchanged; only ``"CLIP"`` changes the
        clip bounds.
    ifm_layout, ofm_layout
        Layout strings forwarded unchanged.
    rounding_mode
        Rounding mode forwarded unchanged.

    Returns
    -------
    The value returned by ``ethosu_ops.ethosu_unary_elementwise``.
    """
    # Only the CLIP activation gets a non-trivial clipping window.
    uses_clip = activation == "CLIP"
    lower_bound = 10 if uses_clip else 0
    upper_bound = 100 if uses_clip else 0

    # No look-up table is supplied: an empty int8 constant stands in for it.
    empty_lut = relay.const([], dtype="int8")

    return ethosu_ops.ethosu_unary_elementwise(
        ifm=ifm,
        lut=empty_lut,
        operator_type=operator_type,
        ifm_scale=1,
        ifm_zero_point=0,
        ofm_scale=1,
        ofm_zero_point=0,
        ofm_channels=ofm_channels,
        activation=activation,
        clip_min=lower_bound,
        clip_max=upper_bound,
        rounding_mode=rounding_mode,
        ifm_layout=ifm_layout,
        ofm_layout=ofm_layout,
    )
def callback(
    self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map
) -> tvm.relay.Expr:
    """Replace a matched unary elementwise expression with an Ethos-U operator.

    Parameters are extracted by ``self.params_class`` from ``post.op.body``,
    and the input tensor is taken from ``post.args[0]``. Inputs whose shape
    does not have exactly four dimensions are reshaped to 4D (by prepending
    1s) before the operator and reshaped back to the original shape after it.

    ``pre`` and ``node_map`` are accepted but not used by this method.

    Returns
    -------
    tvm.relay.Expr
        The ``ethosu_unary_elementwise`` call, wrapped in a reshape when the
        input was not 4D.

    Raises
    ------
    UnsupportedLayout
        If the output feature map layout is not ``"NHWC"``.
    KeyError
        If an activation is present whose op name is not ``"clip"``.
    """
    params = self.params_class(post.op.body)
    params.ifm.tensor = post.args[0]

    ofm_layout = str(params.ofm.layout)
    if ofm_layout != "NHWC":
        raise UnsupportedLayout(ofm_layout)

    # Start from "no activation" and override only when one is present.
    activation, clip_min, clip_max = "NONE", 0, 0
    if params.activation:
        supported_activations = {"clip": "CLIP"}
        activation = supported_activations[params.activation.op.name]
        clip_min = int(params.activation.attrs.a_min)
        clip_max = int(params.activation.attrs.a_max)

    # Activation functions that need a LUT are not supported yet, so the
    # table is always empty.
    lut = relay.const([], dtype="int8")

    # The operator is fed a 4D tensor; other ranks are padded with leading
    # 1s here and restored to the original shape afterwards.
    original_shape = params.ifm.shape
    needs_reshape = len(original_shape) != 4
    if needs_reshape:
        leading_ones = [1] * (4 - len(original_shape))
        operator_shape = leading_ones + original_shape
        operator_input = relay.op.reshape(params.ifm.tensor, newshape=operator_shape)
    else:
        operator_shape = original_shape
        operator_input = params.ifm.tensor

    ifm_quant = params.ifm.q_params
    ofm_quant = params.ofm.q_params
    unary_op = ethosu_ops.ethosu_unary_elementwise(
        ifm=operator_input,
        lut=lut,
        operator_type=params.operator_type,
        ifm_scale=float(ifm_quant.scale_f32),
        ifm_zero_point=int(ifm_quant.zero_point),
        ofm_scale=float(ofm_quant.scale_f32),
        ofm_zero_point=int(ofm_quant.zero_point),
        ofm_channels=operator_shape[3],
        activation=activation,
        clip_min=clip_min,
        clip_max=clip_max,
        ifm_layout=str(params.ifm.layout),
        ofm_layout=ofm_layout,
    )

    if needs_reshape:
        return relay.op.reshape(unary_op, newshape=original_shape)
    return unary_op