Example #1
def conv2d(expr, type_map):
    """Rewrite a fake-quantized conv2d op into a QNN conv2d."""
    # fold_constant, get_zeros, and TensorAffineType are helpers from the
    # surrounding TVM module; relay and bijective_layout are imported there.
    attrs = {**expr.attrs}
    # Drop the fake-quantized out_dtype; the QNN conv2d supplies its own.
    attrs.pop("out_dtype")
    x, weight = expr.args
    # Affine (scale / zero-point) types already inferred for the two inputs.
    x_t = type_map[x]
    w_t = type_map[weight]
    # The integer conv2d is scaled by the product of the input scales and has
    # a zero point of zero.
    conv_scale = fold_constant(x_t.scale * w_t.scale)
    conv_zp = get_zeros(conv_scale)
    out = relay.qnn.op.conv2d(
        x, weight, x_t.zero_point, w_t.zero_point, x_t.scale, w_t.scale, **attrs
    )
    # Map the output layout against NCHW to record which axis the resulting
    # affine type refers to.
    out_layout = attrs["out_layout"] if attrs["out_layout"] != "" else attrs["data_layout"]
    out_axis = bijective_layout(out_layout, "NCHW").backward_index(list(range(4)))[1]
    return [out, TensorAffineType(conv_scale, conv_zp, out.attrs.out_dtype, out_axis.value)]
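A rewrite like this is registered as a per-op hook of TVM's FakeQuantizationToInteger Relay pass and is not normally called by hand. Below is a minimal sketch of driving that pass, assuming mod is a Relay IRModule that already contains a fake-quantized conv2d (module construction omitted):

import tvm
from tvm import relay

def to_integer(mod):
    # The pass locates quantize -> op -> dequantize regions and applies the
    # registered per-op rewrites (such as conv2d above) to emit QNN operators.
    seq = tvm.transform.Sequential(
        [relay.transform.InferType(), relay.transform.FakeQuantizationToInteger()]
    )
    return seq(mod)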
Example #2
import tvm
from tvm.tir import layout, bijective_layout
from tvm.topi.utils import get_const_tuple


def get_shape(src_shape, src_layout, dst_layout):
    """Given a source shape, a source layout and a destination layout, infer
    the destination shape.

    Parameters
    ----------
    src_shape : tuple of int or IntImm
        Source shape

    src_layout : str or Layout
        Source layout

    dst_layout : str or Layout
        Destination layout

    Returns
    -------
    dst_shape : tuple of int
        Destination shape
    """
    if src_layout == dst_layout:
        return get_const_tuple(src_shape)

    if isinstance(src_layout, str):
        src_layout = layout(src_layout)
    if isinstance(dst_layout, str):
        dst_layout = layout(dst_layout)

    assert len(src_layout) == len(dst_layout), "Incompatible layout %s vs %s" % (
        src_layout,
        dst_layout,
    )

    layout_mapping = bijective_layout(src_layout, dst_layout)
    # forward_index over [0, n) gives, for each destination axis, the position
    # of the corresponding axis in the source layout.
    dst_indices = layout_mapping.forward_index(
        tvm.runtime.convert(list(range(len(src_layout)))))

    return get_const_tuple(tuple(src_shape[i.value] for i in dst_indices))
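As a quick usage sketch (illustrative values, assuming the imports above), converting an NCHW shape to NHWC simply moves the channel extent to the last position:

shape = get_shape((1, 32, 56, 56), "NCHW", "NHWC")
print(shape)  # (1, 56, 56, 32)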