def test_func_type_alpha_equal():
    """Check that FuncType equality is alpha-equivalence.

    Consistently renaming the type parameters (and the relations that use
    them) keeps two function types equal. Changing the argument types,
    their order, the return/type parameters, or the relation list makes
    them unequal.
    """

    def make_func(arg_types, ret_type, type_params, relations):
        # Wrap the three list arguments so each case below reads as plain lists.
        return relay.FuncType(
            tvm.convert(arg_types),
            ret_type,
            tvm.convert(type_params),
            tvm.convert(relations),
        )

    small = relay.TensorType((1, 2), "float32")
    large = relay.TensorType((1, 2, 3), "float32")

    ty_a = relay.TypeVar("v1", relay.Kind.Type)
    ty_b = relay.TypeVar("v2", relay.Kind.Type)
    # Two separate shape variables that intentionally share the name "v3".
    shape_a = relay.TypeVar("v3", relay.Kind.Shape)
    shape_b = relay.TypeVar("v3", relay.Kind.Shape)

    broadcast = tvm.get_env_func("tvm.relay.type_relation.Broadcast")
    identity = tvm.get_env_func("tvm.relay.type_relation.Identity")

    rel_a = relay.TypeRelation(broadcast, tvm.convert([ty_a, shape_a]), 1, None)
    rel_b = relay.TypeRelation(broadcast, tvm.convert([ty_b, shape_b]), 1, None)
    rel_identity = relay.TypeRelation(identity, tvm.convert([ty_a, shape_a]), 1, None)

    reference = make_func([small, large], ty_a, [ty_a, shape_a], [rel_a])

    # Same structure with every bound variable renamed: still equal.
    assert reference == make_func([small, large], ty_b, [ty_b, shape_b], [rel_b])

    # Argument list shortened, then reordered.
    assert reference != make_func([small], ty_a, [ty_a, shape_a], [rel_a])
    assert reference != make_func([large, small], ty_a, [ty_a, shape_a], [rel_a])
    # Relation list emptied.
    assert reference != make_func([small, large], ty_a, [ty_a, shape_a], [])
    # Extra type parameters (first case also changes the return type variable).
    assert reference != make_func([small, large], ty_b, [ty_a, ty_b, shape_a], [rel_a])
    assert reference != make_func(
        [small, large], ty_a, [ty_a, ty_b, shape_a, shape_b], [rel_a, rel_b]
    )
    # Different relation function, then an additional relation.
    assert reference != make_func([small, large], ty_a, [ty_a, shape_a], [rel_identity])
    assert reference != make_func(
        [small, large], ty_a, [ty_a, shape_a], [rel_a, rel_identity]
    )
def test_type_relation_sequal():
    """Check structural equality of ``TypeRelation`` nodes.

    Two relations compare equal only when their function, argument list
    (including order), number of inputs and attributes all match. Attrs
    are compared by value, so a distinct attrs object with identical fields
    still compares equal.
    """
    t1 = relay.TensorType((1, 2), "float32")
    t2 = relay.TensorType((1, 2, 3), "float32")
    t3 = relay.TensorType((1, 2, 3, 4), "float32")

    # functions are compared only by pointer equality so
    # we need to be sure to use the same pointers
    broadcast = tvm.ir.EnvFunc.get("tvm.relay.type_relation.Broadcast")
    identity = tvm.ir.EnvFunc.get("tvm.relay.type_relation.Identity")

    attr1 = tvm.ir.make_node("attrs.TestAttrs", name="attr", padding=(3, 4))
    # A separate object carrying the same field values as attr1.
    attr1_same = tvm.ir.make_node("attrs.TestAttrs", name="attr", padding=(3, 4))
    attr2 = tvm.ir.make_node("attrs.TestAttrs", name="attr", padding=(3, 4, 4))

    tr = relay.TypeRelation(broadcast, tvm.runtime.convert([t1, t2]), 1, attr1)
    same = relay.TypeRelation(broadcast, tvm.runtime.convert([t1, t2]), 1, attr1)
    diff_func = relay.TypeRelation(identity, tvm.runtime.convert([t1, t2]), 1, attr1)
    diff_order = relay.TypeRelation(broadcast, tvm.runtime.convert([t2, t1]), 1, attr1)
    diff_args = relay.TypeRelation(broadcast, tvm.runtime.convert([t2, t3]), 1, attr1)
    diff_attr = relay.TypeRelation(broadcast, tvm.runtime.convert([t1, t2]), 1, attr2)
    same_attr = relay.TypeRelation(broadcast, tvm.runtime.convert([t1, t2]), 1, attr1_same)

    bigger = relay.TypeRelation(identity, tvm.runtime.convert([t1, t3, t2]), 2, attr1)
    # Identical to ``bigger`` except for num_inputs (same func, args and attrs),
    # so the final assertion fails only if num_inputs is ignored by equality.
    diff_num_inputs = relay.TypeRelation(identity, tvm.runtime.convert([t1, t3, t2]), 1, attr1)

    # func, args (count and order), num_inputs and attrs must all match
    assert tr == same
    assert tr != diff_func
    assert tr != diff_order
    assert tr != diff_args
    assert tr != diff_attr
    assert tr == same_attr
    assert tr != bigger

    assert bigger != diff_num_inputs
def test_invalid_relation_kind():
    """Check that a relation over Shape/BaseType/ShapeVar-kinded vars fails kind checking."""
    shape_var = relay.TypeVar('tp1', relay.Kind.Shape)
    base_var = relay.TypeVar('tp2', relay.Kind.BaseType)
    shape_var_2 = relay.TypeVar('tp3', relay.Kind.ShapeVar)

    relation = relay.TypeRelation(
        None, tvm.convert([shape_var, base_var, shape_var_2]), 2, None
    )
    assert not check_kind(relation)
def test_invalid_relation_kind():
    """Kind-check a Broadcast relation whose arguments are not ``Type``-kinded.

    NOTE(review): the result of ``check_kind`` is not asserted. This test
    presumably relies on an ``xfail(raises=...)`` marker that is not visible
    in this chunk. Confirm.
    """
    type_vars = [
        relay.TypeVar("tp1", relay.TypeKind.ShapeVar),
        relay.TypeVar("tp2", relay.TypeKind.BaseType),
        relay.TypeVar("tp3", relay.TypeKind.Constraint),
    ]
    broadcast = tvm.ir.EnvFunc.get("tvm.relay.type_relation.Broadcast")
    relation = relay.TypeRelation(broadcast, tvm.runtime.convert(type_vars), 2, None)
    check_kind(relation)
def test_invalid_relation_kind():
    """Kind-check a Broadcast relation over Shape/BaseType/ShapeVar-kinded vars.

    NOTE(review): ``check_kind``'s result is not asserted. This presumably
    relies on an expected-failure marker not visible here. Confirm.
    """
    relation_args = tvm.convert([
        relay.TypeVar('tp1', relay.Kind.Shape),
        relay.TypeVar('tp2', relay.Kind.BaseType),
        relay.TypeVar('tp3', relay.Kind.ShapeVar),
    ])
    broadcast = tvm.get_env_func("tvm.relay.type_relation.Broadcast")
    check_kind(relay.TypeRelation(broadcast, relation_args, 2, None))
def test_relation_kind():
    """Check that a relation whose arguments are all ``Type``-kinded passes kind checking."""
    # only have type kinds for arguments
    type_var = relay.TypeVar('tp', relay.Kind.Type)
    tensor_ty = relay.TensorType(tvm.convert([1, 2, 3]), 'float32')
    func_ty = relay.FuncType(
        tvm.convert([]), tensor_ty, tvm.convert([]), tvm.convert([])
    )

    relation = relay.TypeRelation(
        None, tvm.convert([func_ty, tensor_ty, type_var]), 2, None
    )
    assert check_kind(relation)
def test_relation_kind():
    """Check that a relation over ``Type``-kinded arguments is itself Constraint-kinded."""
    # only have type kinds for arguments
    type_var = relay.TypeVar("tp", relay.TypeKind.Type)
    tensor_ty = relay.TensorType(tvm.runtime.convert([1, 2, 3]), "float32")
    empty = tvm.runtime.convert
    func_ty = relay.FuncType(empty([]), tensor_ty, empty([]), empty([]))

    relation_args = tvm.runtime.convert([func_ty, tensor_ty, type_var])
    relation = relay.TypeRelation(None, relation_args, 2, None)
    assert check_kind(relation) == relay.TypeKind.Constraint
def test_func_with_invalid_relation():
    """Check that a FuncType whose relation ranges over Shape/ShapeVar vars fails kind checking."""
    ty = relay.TypeVar('tp1', relay.Kind.Type)
    shape = relay.TypeVar('tp2', relay.Kind.Shape)
    shape_var = relay.TypeVar('tp3', relay.Kind.ShapeVar)

    bad_relation = relay.TypeRelation(None, tvm.convert([shape, shape_var]), 1, None)
    func_ty = relay.FuncType(
        tvm.convert([ty]),
        ty,
        tvm.convert([ty, shape, shape_var]),
        tvm.convert([bad_relation]),
    )
    assert not check_kind(func_ty)
def test_func_with_invalid_relation():
    """Kind-check a FuncType constrained by an Identity relation over non-Type vars.

    NOTE(review): nothing is asserted about ``check_kind``'s result. This
    presumably relies on an expected-failure marker not visible in this
    chunk. Confirm.
    """
    ty = relay.TypeVar('tp1', relay.TypeKind.Type)
    shape_var = relay.TypeVar('tp2', relay.TypeKind.ShapeVar)
    constraint_var = relay.TypeVar('tp3', relay.TypeKind.Constraint)

    identity = tvm.ir.EnvFunc.get("tvm.relay.type_relation.Identity")
    relation = relay.TypeRelation(
        identity, tvm.runtime.convert([shape_var, constraint_var]), 1, None
    )
    func_ty = relay.FuncType(
        tvm.runtime.convert([ty]),
        ty,
        tvm.runtime.convert([ty, shape_var, constraint_var]),
        tvm.runtime.convert([relation]),
    )
    check_kind(func_ty)
def test_func_with_invalid_relation():
    """Kind-check a FuncType constrained by an Identity relation over Shape/ShapeVar vars.

    NOTE(review): ``check_kind``'s result is not asserted. This presumably
    relies on an expected-failure marker not visible here. Confirm.
    """
    ty = relay.TypeVar('tp1', relay.Kind.Type)
    shape = relay.TypeVar('tp2', relay.Kind.Shape)
    shape_var = relay.TypeVar('tp3', relay.Kind.ShapeVar)

    identity = tvm.get_env_func("tvm.relay.type_relation.Identity")
    relation = relay.TypeRelation(identity, tvm.convert([shape, shape_var]), 1, None)
    func_ty = relay.FuncType(
        tvm.convert([ty]),
        ty,
        tvm.convert([ty, shape, shape_var]),
        tvm.convert([relation]),
    )
    check_kind(func_ty)
def test_type_relation():
    """Check that a TypeRelation exposes the args and num_inputs it was built with."""
    type_var = relay.TypeVar('tp', relay.Kind.Type)
    func_ty = relay.FuncType(tvm.convert([]), None, tvm.convert([]), tvm.convert([]))
    tensor_ty = relay.TensorType(tvm.convert([1, 2, 3]), 'float32')
    relation_args = tvm.convert([func_ty, tensor_ty, type_var])

    expected_inputs = 2
    relation = relay.TypeRelation(
        tvm.get_env_func("tvm.relay.type_relation.Broadcast"),
        relation_args,
        expected_inputs,
        tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4)),
    )

    assert relation.args == relation_args
    assert relation.num_inputs == expected_inputs
def test_type_relation():
    """Check that a TypeRelation with no func/attrs keeps its args and num_inputs."""
    type_param = relay.TypeParam('tp', relay.Kind.Type)
    func_ty = relay.FuncType(tvm.convert([]), None, tvm.convert([]), tvm.convert([]))
    tensor_ty = relay.TensorType(tvm.convert([1, 2, 3]), 'float32')
    relation_args = tvm.convert([func_ty, tensor_ty, type_param])

    expected_inputs = 2
    # Neither a relation function nor attributes are supplied in this version.
    relation = relay.TypeRelation(None, relation_args, expected_inputs, None)

    assert relation.args == relation_args
    assert relation.num_inputs == expected_inputs
def test_type_relation():
    """Build a Broadcast TypeRelation and check its fields, printing and JSON round-trip.

    Uses ``tvm.runtime.convert`` for the container arguments, consistent
    with the other ``tvm.ir.EnvFunc.get`` / ``tvm.ir.make_node`` based tests,
    instead of the legacy top-level ``tvm.convert`` alias.
    """
    tp = relay.TypeVar('tp', relay.TypeKind.Type)
    tf = relay.FuncType(
        tvm.runtime.convert([]), None, tvm.runtime.convert([]), tvm.runtime.convert([])
    )
    tt = relay.TensorType(tvm.runtime.convert([1, 2, 3]), 'float32')
    args = tvm.runtime.convert([tp, tf, tt])
    num_inputs = 2
    func = tvm.ir.EnvFunc.get("tvm.relay.type_relation.Broadcast")
    attrs = tvm.ir.make_node("attrs.TestAttrs", name="attr", padding=(3,4))

    tr = relay.TypeRelation(func, args, num_inputs, attrs)
    assert tr.args == args
    assert tr.num_inputs == num_inputs
    # Smoke test: converting the relation to text must not raise.
    str(tr)
    # Helper defined elsewhere in the test module; presumably serializes to
    # JSON, loads it back and compares with the original. Verify.
    check_json_roundtrip(tr)
def test_func_kind():
    """Check that a FuncType built only from ``Type``-kinded pieces passes kind checking."""
    # only contain type kinds
    first = relay.TypeVar('tp1', relay.Kind.Type)
    second = relay.TypeVar('tp2', relay.Kind.Type)
    tensor_type = relay.TensorType(tvm.convert([1, 2, 3]), 'float32')

    relation = relay.TypeRelation(None, tvm.convert([tensor_type, first]), 1, None)
    func_ty = relay.FuncType(
        tvm.convert([first, tensor_type]),
        relay.TupleType(tvm.convert([second, tensor_type])),
        tvm.convert([first, second]),
        tvm.convert([relation]),
    )
    assert check_kind(func_ty)