def conv2d(expr, type_map):
    """Rewrite a conv2d op into ``relay.qnn.op.conv2d``.

    Parameters
    ----------
    expr : relay expression
        The conv2d expression being rewritten (presumably an ``nn.conv2d``
        Call). Its ``args`` are unpacked as ``(x, weight)`` and its ``attrs``
        (minus ``out_dtype``) are forwarded to the QNN op.
    type_map : dict
        Maps each argument expression to an affine type exposing ``scale``
        and ``zero_point``; used for both the data and the weight.

    Returns
    -------
    list
        ``[out, TensorAffineType(...)]`` where ``out`` is the QNN conv2d call
        and the affine type carries scale ``x_scale * w_scale`` (folded), the
        zero point produced by ``get_zeros``, ``out``'s ``out_dtype`` and the
        index of the channel axis in the output layout.
    """
    attrs = {**expr.attrs}
    # out_dtype is not forwarded to the QNN op; the output dtype is read back
    # from the constructed call (out.attrs.out_dtype) instead.
    attrs.pop("out_dtype")
    x, weight = expr.args
    x_t = type_map[x]
    w_t = type_map[weight]
    # Output scale is the product of the data and weight scales.
    conv_scale = fold_constant(x_t.scale * w_t.scale)
    conv_zp = get_zeros(conv_scale)
    out = relay.qnn.op.conv2d(
        x, weight, x_t.zero_point, w_t.zero_point, x_t.scale, w_t.scale, **attrs
    )
    # An empty out_layout means the output uses the data layout.
    out_layout = attrs["out_layout"] if attrs["out_layout"] != "" else attrs["data_layout"]
    # The affine axis is the position of the channel axis "C" in the output
    # layout. It is located directly rather than via
    # bijective_layout(out_layout, "NCHW").backward_index(range(4))[1]:
    # backward_index converts an NCHW coordinate into out_layout coordinates,
    # so element 1 is the coordinate of out_layout's second axis, which equals
    # the channel position only for NCHW (NHWC would give 2 instead of 3).
    out_axis = str(out_layout).index("C")
    return [out, TensorAffineType(conv_scale, conv_zp, out.attrs.out_dtype, out_axis)]
def get_shape(src_shape, src_layout, dst_layout):
    """Given a source shape, a source layout and a destination layout, infer
    the destination shape.

    Parameters
    ----------
    src_shape : tuple of int or IntImm
        Source shape
    src_layout : str or Layout
        Source layout
    dst_layout : str or Layout
        Destination layout

    Returns
    -------
    dst_shape : tuple of int
        Destination shape

    Raises
    ------
    AssertionError
        If the two layouts have a different number of axes.
    """
    if src_layout == dst_layout:
        return get_const_tuple(src_shape)

    if isinstance(src_layout, str):
        src_layout = layout(src_layout)
    if isinstance(dst_layout, str):
        dst_layout = layout(dst_layout)

    # Explicit raise rather than ``assert`` so the check is not stripped under
    # ``python -O``; AssertionError and the message are kept for existing callers.
    if len(src_layout) != len(dst_layout):
        raise AssertionError("Incompatible layout %s vs %s" % (src_layout, dst_layout))

    # forward_index is applied to the axis positions 0..n-1; each resulting
    # entry is then used as the src_shape position feeding that dst axis.
    layout_mapping = bijective_layout(src_layout, dst_layout)
    dst_indices = layout_mapping.forward_index(
        tvm.runtime.convert(list(range(len(src_layout))))
    )
    return get_const_tuple(tuple(src_shape[i.value] for i in dst_indices))