def test_bilayout_index():
    """Map a single NCHW index to NCHW16c and back again."""
    layout = tvm.bijective_layout("NCHW", "NCHW16c")
    # Channel 18 lands in outer channel block 1 at inner lane 2
    # (18 == 1 * 16 + 2), with the remaining axes passed through.
    packed = get_const_tuple(layout.forward_index([0, 18, 6, 6]))
    assert packed == (0, 1, 6, 6, 2)
    unpacked = get_const_tuple(layout.backward_index([0, 1, 6, 6, 2]))
    assert unpacked == (0, 18, 6, 6)
def test_bilayout_shape():
    """Round-trip a full NCHW shape through the NCHW16c layout and back."""
    bilayout = tvm.bijective_layout("NCHW", "NCHW16c")
    assert isinstance(bilayout, tvm.tir.BijectiveLayout)
    # 32 channels are packed into 2 outer blocks of 16 inner lanes.
    dst_shape = bilayout.forward_shape((1, 32, 7, 7))
    assert get_const_tuple(dst_shape) == (1, 2, 7, 7, 16)
    # The backward mapping must recover the original NCHW shape exactly.
    src_shape = bilayout.backward_shape(dst_shape)
    assert get_const_tuple(src_shape) == (1, 32, 7, 7)
def test_bilayout_convertible():
    """Check which layout pairs ``tvm.bijective_layout`` accepts."""
    # Pairs with no bijective mapping are expected to yield None: differing
    # axis names, the "__undef__" layout, and empty layout strings.
    unconvertible = [
        ("NCHW", "ABCD"),
        ("__undef__", "NCHW"),
        ("NCHW", "__undef__"),
        ("__undef__", "__undef__"),
        ("", "NCHW"),
        ("NCHW", ""),
        ("", ""),
    ]
    for src_layout, dst_layout in unconvertible:
        assert tvm.bijective_layout(src_layout, dst_layout) is None
    # Splitting the channel axis into 16-wide blocks is a valid conversion.
    assert tvm.bijective_layout("NCHW", "NCHW16c") is not None